In [1]:
import jax
import jax.tree_util as jtree
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
import numpy as np
from collections import namedtuple

MCMCConfig = namedtuple(
    "MCMCConfig", ["num_posterior_samples", "num_warmup", "num_chains", "thinning"]
)

def run_mcmc(model, X, Y, rngkey, mcmc_config, step_size=None, init_params=None, itemp=1.0, progress_bar=True):
    if step_size is None:
        kernel = NUTS(model)
    else:
        kernel = NUTS(model, step_size=step_size)
    mcmc = MCMC(
        kernel, 
        num_warmup=mcmc_config.num_warmup, 
        num_samples=mcmc_config.num_posterior_samples, 
        thinning=mcmc_config.thinning, 
        num_chains=mcmc_config.num_chains, 
        progress_bar=progress_bar
    )
    print("Running MCMC")
    mcmc.run(rngkey, X, Y, itemp=itemp, init_params=init_params)
    return mcmc



def build_model(forward_fn, prior_center_tree, prior_std=1.0, sigma_obs=1.0):
    leaves, treedef = jtree.tree_flatten(prior_center_tree)
    def model(X, Y=None, itemp=1.0):
        # parameter from gaussian prior
        prior_samples = [
            numpyro.sample(i, dist.Normal(x, scale=prior_std)) for i, x in enumerate(leaves)
        ]
        param = jtree.tree_unflatten(treedef, prior_samples)
        # mean
        mu = forward_fn(param, X)
        numpyro.sample('obs', dist.Normal(mu, sigma_obs / jnp.sqrt(itemp)), obs=Y)
        return 
    return model



In [45]:
from expt_dln import initialise_expt
from dln import mse_loss, true_dln_learning_coefficient
from utils import param_lp_dist
import plotly.graph_objects as go

rngkey = jax.random.PRNGKey(0)

layer_widths = [2, 2]
input_dim = 3
input_dist = "unit_ball"
num_training_data = 10000
itemp = 1 / np.log(num_training_data)
true_param_config = {
    "method": "random", 
    "prop_rank_reduce": 0.2,
    "mean": 0.0, 
    "std": 5.0, 
}
mcmc_config = MCMCConfig(
    num_posterior_samples=10000, 
    num_warmup=0, 
    num_chains=1, 
    thinning=2
)

rngkey, subkey = jax.random.split(rngkey)
model, true_param, x_train, y_train = initialise_expt(
    subkey, 
    layer_widths, 
    input_dim, 
    input_dist,
    num_training_data, 
    true_param_config
)
loss_fn = jax.jit(lambda param, inputs, targets: mse_loss(param, model, inputs, targets))
param_init = true_param
mcmc_model = build_model(model.apply, prior_center_tree=param_init, prior_std=0.1, sigma_obs=1.0)
mcmc_model = jax.jit(mcmc_model)

rngkey, subkey = jax.random.split(rngkey)

mcmc = run_mcmc(
    mcmc_model, 
    x_train, 
    y_train, 
    subkey, 
    mcmc_config, 
    step_size=1e-5, 
    init_params=dict(enumerate(jtree.tree_leaves(param_init))), 
    itemp=itemp, 
    progress_bar=True
)

Running MCMC


sample: 100%|██████████| 10000/10000 [00:05<00:00, 1675.13it/s, 1023 steps of size 1.00e-05. acc. prob=1.00]


In [46]:
_, treedef = jtree.tree_flatten(true_param)
samples = mcmc.get_samples()
nsamples = samples[0].shape[0]
loss_trace = []
distances = []
for i in range(nsamples):
    param = jtree.tree_unflatten(treedef, [samples[j][i] for j in sorted(samples.keys())])
    loss_trace.append(float(loss_fn(param, x_train, y_train)))
    distances.append(param_lp_dist(param_init, param))

init_loss = loss_fn(true_param, x_train, y_train)
lambdahat = (np.mean(loss_trace) - init_loss) * num_training_data * itemp
true_matrix = jnp.linalg.multi_dot(
    [true_param[f'deep_linear_network/linear{loc}']['w'] for loc in [''] + [f'_{i}' for i in range(1, len(layer_widths))]]
)

true_rank = jnp.linalg.matrix_rank(true_matrix)
true_lambda, true_multiplicity = true_dln_learning_coefficient(true_rank, layer_widths, input_dim)

print(lambdahat, true_lambda)

299.90012 3.0


In [47]:
fig = go.Figure()

# Add trace for loss_trace
fig.add_trace(
    go.Scatter(
        x=list(range(nsamples)),
        y=loss_trace,
        name="Loss Trace",
        yaxis="y1"
    )
)

# Add trace for distances
fig.add_trace(
    go.Scatter(
        x=list(range(nsamples)),
        y=distances,
        name="Distances",
        yaxis="y2"
    )
)

# Set layout for the graph
fig.update_layout(
    title="Loss Trace and Distances",
    xaxis_title="Sample Index",
    yaxis=dict(
        title="Loss Trace",
        side="left",
        showgrid=False,
        zeroline=False
    ),
    yaxis2=dict(
        title="Distances",
        side="right",
        overlaying="y",
        showgrid=False,
        zeroline=False
    )
)

# Show the graph
fig.show()
