In [None]:
!pip install dm-haiku numpyro



In [None]:
from typing import Sequence, NamedTuple
import copy, time

import numpy as np
import plotly.express as px

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, HMC, NUTS

import jax
import jax.numpy as jnp
import jax.tree_util as jtree
import haiku as hk
import optax

In [None]:
class SGLDConfig(NamedTuple):
    """Hyperparameters describing one SGLD sampling run."""

    # Step size handed to the SGLD optimizer.
    epsilon: float
    # Weight of the quadratic term that pulls the chain back towards its starting parameters.
    gamma: float
    # Intended number of SGLD updates.
    num_steps: int

def generate_rngkey_tree(key_or_seed, tree_or_treedef):
    """Build a pytree of PRNG keys with the same structure as ``tree_or_treedef``.

    Args:
        key_or_seed: Value handed to ``hk.PRNGSequence`` to start the key stream.
        tree_or_treedef: Pytree whose structure is mirrored; leaf values are ignored.

    Returns:
        A pytree in which every leaf is the next key drawn from the sequence.
    """
    key_stream = hk.PRNGSequence(key_or_seed)

    def _draw_key(_leaf):
        # The leaf only marks a position in the tree; its value is never read.
        return next(key_stream)

    return jtree.tree_map(_draw_key, tree_or_treedef)

def optim_sgld(epsilon, rngkey_or_seed, preconditioning=None):
    """Build an optax gradient transformation that performs SGLD updates.

    For every gradient leaf ``g`` the emitted update is
    ``-epsilon * p * g / 2 + eta`` with ``eta ~ Normal(0, epsilon * p)``, where ``p`` is the
    matching leaf of ``preconditioning`` (or 1 when no preconditioning is given).

    Args:
        epsilon: SGLD step size.
        rngkey_or_seed: JAX PRNG key, or an integer seed, used as the initial optimizer state.
        preconditioning: Optional pytree with the same structure as the gradients holding the
            per-leaf multiplier ``p``.

    Returns:
        An ``optax.GradientTransformation`` whose state is the PRNG key for the next update.
    """
    @jax.jit
    def sgld_delta(g, rngkey, p=1):
        # Gaussian noise with variance epsilon * p, plus half a (preconditioned) gradient step.
        eta = jax.random.normal(rngkey, shape=g.shape) * jnp.sqrt(epsilon * p)
        return -epsilon * p * g / 2 + eta

    def init_fn(_):
        # An integer seed has to be turned into a real key here; update_fn passes the state
        # straight to jax.random.split, which does not accept a plain integer.
        if isinstance(rngkey_or_seed, (int, np.integer)):
            return jax.random.PRNGKey(rngkey_or_seed)
        return rngkey_or_seed

    def update_map(grads, rngkey_tree):
        # jax.tree_util.tree_map replaces the deprecated top-level jax.tree_map alias.
        if preconditioning is None:
            return jax.tree_util.tree_map(sgld_delta, grads, rngkey_tree)
        return jax.tree_util.tree_map(sgld_delta, grads, rngkey_tree, preconditioning)

    @jax.jit
    def update_fn(grads, state, params=None):
        # `params` is accepted (and ignored) so the transformation can be called with the
        # optional third argument that optax passes to update functions.
        del params
        rngkey, new_rngkey = jax.random.split(state)
        rngkey_tree = generate_rngkey_tree(rngkey, grads)
        updates = update_map(grads, rngkey_tree)
        return updates, new_rngkey
    return optax.GradientTransformation(init_fn, update_fn)


def create_local_logposterior(avgnegloglikelihood_fn, num_training_data, w_init, gamma, itemp):
    """Create a tempered log-posterior localized around ``w_init``.

    Args:
        avgnegloglikelihood_fn: Callable ``(w, x, y) -> scalar`` giving the average negative
            log-likelihood of the batch ``(x, y)`` under parameters ``w``.
        num_training_data: Factor that scales the average up to a dataset total.
        w_init: Parameter pytree at the centre of the quadratic localizing term.
        gamma: Strength of the localizing term ``-gamma / 2 * ||w - w_init||^2``.
        itemp: Inverse temperature multiplying the log-likelihood.

    Returns:
        ``logprob(w, x, y)`` evaluating ``itemp * loglike + logprior``.
    """
    def _leaf_sq_dist(leaf, leaf_init):
        return jnp.sum((leaf - leaf_init) ** 2)

    def _sq_dist_from_init(w):
        # Squared Euclidean distance to w_init, accumulated leaf by leaf.
        per_leaf = jax.tree_util.tree_map(_leaf_sq_dist, w, w_init)
        return jax.tree_util.tree_reduce(jnp.add, per_leaf)

    def logprob(w, x, y):
        avg_nll = avgnegloglikelihood_fn(w, x, y)
        loglike = -num_training_data * avg_nll
        logprior = -gamma / 2 * _sq_dist_from_init(w)
        return itemp * loglike + logprior
    return logprob


In [None]:
def mala_acceptance_probability(current_point, proposed_point, loss_and_grad_fn, step_size):
    """
    Metropolis-Hastings acceptance probability of a single MALA move.

    The target density is proportional to ``exp(-loss)`` and the proposal from a point ``x`` is
    Gaussian with mean ``x - step_size / 2 * grad_loss(x)`` and variance ``step_size``.

    Args:
    current_point: Array holding the current position in parameter space.
    proposed_point: Array holding the proposed position in parameter space.
    loss_and_grad_fn (function): Maps a point to the pair ``(loss, gradient of the loss)``.
    step_size (float): Step size parameter of the MALA proposal.

    Returns:
    Scalar acceptance probability, clipped to at most 1.
    """
    def _log_proposal(src, dst, grad_at_src):
        # Log-density, up to an additive constant, of proposing `dst` when standing at `src`.
        residual = dst - src - (step_size * 0.5 * -grad_at_src)
        return -jnp.sum(residual ** 2) / (2 * step_size)

    # Loss and gradient are needed at both ends of the move.
    loss_here, grad_here = loss_and_grad_fn(current_point)
    loss_there, grad_there = loss_and_grad_fn(proposed_point)

    log_q_reverse = _log_proposal(proposed_point, current_point, grad_there)
    log_q_forward = _log_proposal(current_point, proposed_point, grad_here)

    # log of [pi(proposed) q(current | proposed)] / [pi(current) q(proposed | current)]
    log_ratio = log_q_reverse - log_q_forward + loss_here - loss_there
    return jnp.minimum(1.0, jnp.exp(log_ratio))

# Example usage:
# Sanity check of mala_acceptance_probability on a quadratic toy problem.
def log_likelihood_example(x):
    """Toy log-likelihood: -0.5 * sum(x**2)."""
    return -0.5 * jnp.sum(x**2)

# The loss is the negative log-likelihood; value_and_grad returns (loss, gradient) in one call.
loss_and_grad_fn = jax.jit(jax.value_and_grad(lambda x: -log_likelihood_example(x), argnums=0))

# Current and proposed points for demonstration
current_point = jnp.array([1.0, 2.0])
proposed_point = jnp.array([1.5, 2.5])
# NOTE: `step_size` is a notebook-level name that later cells assign again.
step_size = 0.1

# Calculate the acceptance probability (the bare last expression displays it as the cell output)
acceptance_prob = mala_acceptance_probability(current_point, proposed_point, loss_and_grad_fn, step_size)
acceptance_prob


Array(0.95719343, dtype=float32)

In [None]:
def pack_params(params):
    """Flatten a parameter pytree into one 1-D vector.

    Args:
        params: Pytree of arrays.

    Returns:
        ``(params_packed, pack_info)``: the concatenation of all raveled leaves, and the tuple
        ``(treedef, shapes, indices)`` where ``indices`` is the cumulative sum of the leaf sizes.
    """
    leaves, treedef = jax.tree_util.tree_flatten(params)

    shapes, sizes, raveled = [], [], []
    for leaf in leaves:
        shapes.append(leaf.shape)
        sizes.append(leaf.size)
        raveled.append(jnp.ravel(leaf))

    # End offset of every leaf inside the packed vector (the last entry is the total length).
    indices = np.cumsum(sizes)
    return jnp.concatenate(raveled), (treedef, shapes, indices)

def unpack_params(params_packed, pack_info):
    """Rebuild a parameter pytree from its packed 1-D vector.

    Args:
        params_packed: 1-D array holding all leaves back to back.
        pack_info: Tuple ``(treedef, shapes, indices)`` as returned by ``pack_params``.

    Returns:
        Pytree with structure ``treedef`` whose leaves have the recorded shapes.
    """
    treedef, shapes, indices = pack_info
    chunks = jnp.split(params_packed, indices)

    leaves = []
    # zip stops at the shorter sequence, so any chunk beyond the recorded shapes is ignored.
    for chunk, shape in zip(chunks, shapes):
        leaves.append(jnp.reshape(chunk, shape))

    return jax.tree_util.tree_unflatten(treedef, leaves)

# MLPs

In [None]:
class ReluNetwork(hk.Module):
    """Stack of ``hk.Linear`` layers with a ReLU after every layer except the last."""

    def __init__(self, layer_widths):
        """Store ``layer_widths``, the output size of each linear layer in order."""
        super().__init__()
        self.layer_widths = layer_widths

    def __call__(self, x):
        """Apply the network to ``x`` and return the output of the final linear layer."""
        hidden_widths = self.layer_widths[:-1]
        activations = x
        for n_units in hidden_widths:
            activations = jax.nn.relu(hk.Linear(n_units)(activations))

        # The last layer is left linear (no activation).
        output_width = self.layer_widths[-1]
        return hk.Linear(output_width)(activations)

# Forward pass: instantiate a ReluNetwork with the given widths and apply it to x
def forward_fn(x, layer_widths):
    """Return ``ReluNetwork(layer_widths)`` applied to ``x`` (wrapped by ``hk.transform`` in ``create_model``)."""
    return ReluNetwork(layer_widths)(x)

# Haiku-transformed version of the network, with the rng argument stripped from apply
def create_model(layer_widths):
    """Return the ``hk.without_apply_rng(hk.transform(...))`` wrapper around ``forward_fn``.

    Args:
        layer_widths: Output size of each linear layer, forwarded to ``forward_fn``.
    """
    def _apply_network(x):
        return forward_fn(x, layer_widths)

    transformed = hk.transform(_apply_network)
    return hk.without_apply_rng(transformed)

def generate_training_data(true_param, model, input_dim, num_samples):
    """Sample inputs uniformly from [-10, 10) and label them with the true model.

    Args:
        true_param: Parameters passed to ``model.apply``.
        model: Object exposing ``apply(params, inputs)``.
        input_dim: Number of features per input row.
        num_samples: Number of input rows to draw.

    Returns:
        ``(inputs, true_outputs)`` where ``inputs`` has shape ``(num_samples, input_dim)``.
    """
    # Drawn from NumPy's global RNG, so seed np.random beforehand for reproducible data.
    sample_shape = (num_samples, input_dim)
    inputs = np.random.uniform(low=-10, high=10, size=sample_shape)
    ## Alternative (disabled): generate random inputs uniformly from the input ball
    #input_directions = np.random.normal(size=(num_samples, input_dim))
    #inputs = (np.random.rand(num_samples, 1)**(1/input_dim)) * (input_directions/np.linalg.norm(input_directions, axis=-1, keepdims=True))

    # Labels are the noiseless outputs of the true model.
    return inputs, model.apply(true_param, inputs)

def mse_loss(param, model, inputs, targets):
    """Mean squared error between ``model.apply(param, inputs)`` and ``targets``, averaged over all elements."""
    residuals = model.apply(param, inputs) - targets
    return jnp.mean(residuals ** 2)

def create_minibatches(inputs, targets, batch_size, num_iter=None, shuffle=True):
    """Yield ``(inputs, targets)`` minibatches, cycling over the data until ``num_iter`` batches.

    The example order is fixed once (a single permutation when ``shuffle`` is true) and reused
    on every pass. Only full batches are produced; trailing examples that do not fill a batch
    are skipped on each pass.

    Args:
        inputs: Indexable array of examples.
        targets: Indexable array with the same length as ``inputs``.
        batch_size: Number of examples per batch; must satisfy ``1 <= batch_size <= len(inputs)``.
        num_iter: Total number of batches to yield. Defaults to ``len(inputs)``. A value of zero
            or less yields nothing.
        shuffle: Whether to permute the example order using NumPy's global RNG.

    Yields:
        Tuples ``(inputs[batch_indices], targets[batch_indices])``.

    Raises:
        ValueError: If ``batch_size`` is not in ``[1, len(inputs)]``. Raised on the first
            iteration, since this is a generator.
    """
    assert len(inputs) == len(targets)
    num_examples = len(inputs)

    # Without this guard no full batch exists, the inner loop never runs, and the
    # surrounding `while True` would spin forever without yielding anything.
    if batch_size <= 0 or batch_size > num_examples:
        raise ValueError(
            f"batch_size must be between 1 and len(inputs)={num_examples}, got {batch_size}"
        )

    if num_iter is None:
        num_iter = num_examples
    if num_iter <= 0:
        return

    if shuffle:
        indices = np.random.permutation(num_examples)
    else:
        indices = np.arange(num_examples)

    batch_starts = range(0, num_examples - batch_size + 1, batch_size)
    num_yielded = 0
    while True:
        for start_idx in batch_starts:
            excerpt = indices[start_idx:start_idx + batch_size]
            yield inputs[excerpt], targets[excerpt]

            num_yielded += 1
            if num_yielded >= num_iter:
                return



In [None]:
input_dim = 10  # Dimension of the input
layer_widths = [20, 20, 10]

# Create the ReLU network model
model = create_model(layer_widths)

# Initialize the "true" model parameters
rngkey = jax.random.PRNGKey(0)
dummy_input = jnp.zeros((1, input_dim), dtype=jnp.float32)
# NOTE(review): `subkey` is never used; `rngkey` is passed to model.init and then split again
# in the loop below. Changing this would change every recorded result, so it is only flagged.
rngkey, subkey = jax.random.split(rngkey)
true_param = model.init(rngkey, dummy_input)

# Overwrite both the weights and the biases of every layer except the last with
# standard-normal draws; the last layer keeps the values produced by model.init.
for i in range(len(layer_widths)-1):
    rngkey, subkey1, subkey2 = jax.random.split(rngkey, num=3)
    # Parameter groups are keyed 'relu_network/linear', 'relu_network/linear_1', ...
    loc = f'relu_network/linear_{i}' if i > 0 else 'relu_network/linear'
    true_param[loc]['w'] = jax.random.normal(subkey1, true_param[loc]['w'].shape)
    true_param[loc]['b'] = jax.random.normal(subkey2, true_param[loc]['b'].shape)

## Introduce a degeneracy
#degen_directions = [1, 2, 3, 4, 5, 8]
#for dir in degen_directions:
#    true_param['relu_network/linear']['w'] = true_param['relu_network/linear']['w'].at[:, dir].set(true_param['relu_network/linear']['w'][:, 0])
#    true_param['relu_network/linear']['b'] = true_param['relu_network/linear']['b'].at[dir].set(true_param['relu_network/linear']['b'][0])

## Slight degeneracy breaking
#eps = 1e-2
#for i in range(len(layer_widths)-1):
#    rngkey, subkey1, subkey2 = jax.random.split(rngkey, num=3)
#    loc = f'relu_network/linear_{i}' if i > 0 else 'relu_network/linear'
#    #true_param[loc]['w'] = true_param[loc]['w'] + eps*jax.random.normal(subkey1, true_param[loc]['w'].shape)
#    true_param[loc]['b'] = true_param[loc]['b'] + eps*jax.random.normal(subkey2, true_param[loc]['b'].shape)

In [None]:
num_training_data = 10000

# Generate training data
# generate_training_data draws from NumPy's global RNG, so seed it for reproducibility.
np.random.seed(0)
x_train, y_train = generate_training_data(true_param, model, input_dim, num_training_data)
print(f"x_train shape: {x_train.shape}, y_train shape: {y_train.shape}")
# Last expression of the cell: displays the shape of every parameter array.
jtree.tree_map_with_path(lambda path, x: x.shape, true_param)

x_train shape: (10000, 10), y_train shape: (10000, 10)


{'relu_network/linear': {'b': (20,), 'w': (10, 20)},
 'relu_network/linear_1': {'b': (20,), 'w': (20, 20)},
 'relu_network/linear_2': {'b': (10,), 'w': (20, 10)}}

In [None]:
sgld_config = SGLDConfig(
    epsilon=1e-6,
    gamma=0.0,
    num_steps=10000,
)
batch_size = 32
# Inverse temperature 1 / log(n) used to temper the log-likelihood.
itemp = 1 / np.log(num_training_data)
# The chain starts at (a copy of) the true parameters.
param_init = copy.deepcopy(true_param)
#param_init = trained_param

loss_fn = jax.jit(lambda param, inputs, targets: mse_loss(param, model, inputs, targets))
local_logprob = create_local_logposterior(
        avgnegloglikelihood_fn=loss_fn,
        num_training_data=num_training_data,
        w_init=param_init,
        gamma=sgld_config.gamma,
        itemp=itemp,
    )
# SGLD descends the negative log-posterior; value_and_grad returns (nll, gradient).
sgld_grad_fn = jax.jit(jax.value_and_grad(lambda w, x, y: -local_logprob(w, x, y), argnums=0))


rngkey = jax.random.PRNGKey(0)
sgldoptim = optim_sgld(sgld_config.epsilon, rngkey)
samples = []
nlls = []
losses = []
accept_probs = []
opt_state = sgldoptim.init(param_init)
param = param_init
t = 0
# The number of SGLD steps comes from sgld_config.num_steps. The previous `num_samples` is
# only defined in a later cell, so this cell failed with a NameError on a fresh kernel.
for x_batch, y_batch in create_minibatches(x_train, y_train, batch_size=batch_size, num_iter=sgld_config.num_steps):
    old_param = param.copy()

    nll, grads = sgld_grad_fn(param, x_batch, y_batch)
    nlls.append(float(nll))
    updates, opt_state = sgldoptim.update(grads, opt_state)
    param = optax.apply_updates(param, updates)
    samples.append(param)
    losses.append(loss_fn(param, x_batch, y_batch))

    t += 1

    # Every 20 steps, record the MALA acceptance probability of the move just taken,
    # evaluated on the current minibatch, as a diagnostic for the SGLD step size.
    if t % 20 == 0:
        old_param_packed, pack_info = pack_params(old_param)
        param_packed, _ = pack_params(param)
        def grad_fn_packed(w):
            # mala_acceptance_probability works on flat vectors, so unpack/pack around the call.
            nll, grad = sgld_grad_fn(unpack_params(w, pack_info), x_batch, y_batch)
            grad_packed, _ = pack_params(grad)
            return nll, grad_packed
        accept_probs.append(mala_acceptance_probability(
            old_param_packed, param_packed, grad_fn_packed, sgld_config.epsilon))
    if t % 200 == 0:
        print(f"Step {t}, nll: {nll}")

# lambdahat uses the full-data loss of every sample; lambdahat2 uses the minibatch losses
# recorded during the run. Both subtract the full-data loss at the starting point.
init_loss = loss_fn(param_init, x_train, y_train)
loss_trace = [loss_fn(p, x_train, y_train) for p in samples]
lambdahat = (np.mean(loss_trace) - init_loss) * num_training_data * itemp
lambdahat2 = (np.mean(losses) - init_loss) * num_training_data * itemp

print(lambdahat)
print(lambdahat2)
print(np.mean(accept_probs))

px.line(np.array(loss_trace) * num_training_data * itemp)

Step 200, nll: 255.84332275390625
Step 400, nll: 335.3076171875
Step 600, nll: 322.603271484375
Step 800, nll: 315.90093994140625
Step 1000, nll: 705.8600463867188
Step 1200, nll: 452.5219421386719
Step 1400, nll: 2122.71142578125
Step 1600, nll: 565.5903930664062
Step 1800, nll: 341.3694763183594
Step 2000, nll: 529.1797485351562
Step 2200, nll: 412.7966003417969
Step 2400, nll: 456.3123779296875
Step 2600, nll: 402.4379577636719
Step 2800, nll: 596.6969604492188
Step 3000, nll: 575.4529418945312
Step 3200, nll: 563.9436645507812
Step 3400, nll: 1034.941162109375
Step 3600, nll: 422.4517822265625
Step 3800, nll: 384.63507080078125
Step 4000, nll: 463.5841979980469
Step 4200, nll: 361.4891662597656
Step 4400, nll: 431.0010681152344
Step 4600, nll: 364.58819580078125
Step 4800, nll: 407.14044189453125
Step 5000, nll: 548.845703125
Step 5200, nll: 313.79437255859375
Step 5400, nll: 410.74835205078125
Step 5600, nll: 547.4313354492188
Step 5800, nll: 668.2737426757812
Step 6000, nll: 490.

In [None]:
rngkey = jax.random.PRNGKey(0)

# This cell relies on itemp, num_training_data, sgld_config and local_logprob from the
# SGLD cell above, and rebinds the notebook-level names step_size, loss_fn, samples and losses.
param_init = copy.deepcopy(true_param)
beta = itemp*num_training_data
# NOTE(review): HMC with a single leapfrog step of size sqrt(epsilon) is presumably meant to
# match a MALA proposal with step size epsilon (the sweep below labels this sampler "MALA") - confirm.
step_size = np.sqrt(sgld_config.epsilon)
num_steps = 1

# Full-data loss as a function of the parameters only.
loss_fn = jax.jit(lambda w: mse_loss(w, model, x_train, y_train))
#loss_fn = jax.jit(make_population_loss_fn(true_param))
# Potential energy = negative local log-posterior evaluated on the full training set.
potential_fn = jax.jit(lambda w: -local_logprob(w, x_train, y_train))

# Set up HMC with a fixed step size (both adaptation options are switched off)
hmc_kernel = HMC(
    #potential_fn=lambda param: beta * loss_fn(param),
    potential_fn=potential_fn,
    step_size=step_size,
    trajectory_length=step_size*num_steps,
    adapt_step_size=False,
    adapt_mass_matrix=False
)
# No warmup: every transition is kept as a sample.
mcmc = MCMC(hmc_kernel, num_samples=10000, num_warmup=0)
rngkey, subkey = jax.random.split(rngkey)
mcmc.run(subkey, init_params=param_init)

# Extract samples
samples = mcmc.get_samples()

# Estimate: (mean full-data loss over the samples - loss at the starting point) * beta
init_loss = loss_fn(param_init)
losses = (jax.lax.map(loss_fn, samples))
lambdahat = (np.mean(losses) - init_loss) * beta
print(lambdahat)

px.line(losses * beta)


sample: 100%|██████████| 10000/10000 [01:21<00:00, 123.37it/s, 1 steps of size 3.16e-03. acc. prob=0.95]


168.07602


## Systematic comparison

In [None]:
input_dim = 10  # Dimension of the input
layer_widths = [20, 20, 10]

# Create the ReLU network model
model = create_model(layer_widths)

# Initialize the "true" model parameters
rngkey = jax.random.PRNGKey(0)
dummy_input = jnp.zeros((1, input_dim), dtype=jnp.float32)
# NOTE(review): `subkey` is unused and `rngkey` is both passed to model.init and split again
# below. Kept as is so the parameters match the ones built earlier in the notebook.
rngkey, subkey = jax.random.split(rngkey)
true_param = model.init(rngkey, dummy_input)

# Overwrite the weights and biases of every layer except the last with standard-normal draws.
for i in range(len(layer_widths)-1):
    rngkey, subkey1, subkey2 = jax.random.split(rngkey, num=3)
    loc = f'relu_network/linear_{i}' if i > 0 else 'relu_network/linear'
    # Draw with the shape of the existing array. The weight is a 2-D matrix, so sampling a
    # vector of length layer_widths[i] for it would break the network's forward pass.
    true_param[loc]['w'] = jax.random.normal(subkey1, true_param[loc]['w'].shape)
    true_param[loc]['b'] = jax.random.normal(subkey2, true_param[loc]['b'].shape)


In [None]:
step_size = 1e-5   # SGLD step size; the HMC/MALA run below uses sqrt(step_size)
gamma = 0.0        # strength of the localizing quadratic term
batch_size = 32    # SGLD minibatch size
num_samples = 10000  # number of SGLD steps and of HMC/MALA samples per dataset size

# Sweep over different training data size
dataset_sizes = np.round(10**np.linspace(2, 5, 7)).astype(np.int64)
sgld_times = []
mala_times = []
memories = []
sgld_lambdas = []
mala_lambdas = []
for num_training_data in dataset_sizes:
    # Generate training data (re-seeded so every dataset size starts from the same RNG state)
    np.random.seed(0)
    x_train, y_train = generate_training_data(true_param, model, input_dim, num_training_data)
    print(f"x_train shape: {x_train.shape}, y_train shape: {y_train.shape}")

    # Bytes held by the training arrays.
    memory = (x_train.size * x_train.itemsize) + (y_train.size * y_train.itemsize)
    memories.append(memory)
    print("Training data memory: ", memory)

    itemp = 1 / np.log(num_training_data)
    param_init = copy.deepcopy(true_param)

    loss_fn = jax.jit(lambda param, inputs, targets: mse_loss(param, model, inputs, targets))
    local_logprob = create_local_logposterior(
            avgnegloglikelihood_fn=loss_fn,
            num_training_data=num_training_data,
            w_init=param_init,
            gamma=gamma,
            itemp=itemp,
        )

    # Run SGLD (the timed region also covers the acceptance-probability diagnostic)
    sgld_start_time = time.time()

    sgld_grad_fn = jax.jit(jax.value_and_grad(lambda w, x, y: -local_logprob(w, x, y), argnums=0))

    rngkey = jax.random.PRNGKey(0)
    sgldoptim = optim_sgld(step_size, rngkey)
    samples = []
    losses = []
    accept_probs = []
    opt_state = sgldoptim.init(param_init)
    param = param_init
    t = 0
    for x_batch, y_batch in create_minibatches(x_train, y_train, batch_size=batch_size, num_iter=num_samples):
        old_param = param.copy()

        nll, grads = sgld_grad_fn(param, x_batch, y_batch)
        updates, opt_state = sgldoptim.update(grads, opt_state)
        param = optax.apply_updates(param, updates)
        samples.append(param)
        losses.append(loss_fn(param, x_batch, y_batch))

        t += 1

        if t % 20 == 0:
            old_param_packed, pack_info = pack_params(old_param)
            param_packed, _ = pack_params(param)
            def grad_fn_packed(w):
                nll, grad = sgld_grad_fn(unpack_params(w, pack_info), x_batch, y_batch)
                grad_packed, _ = pack_params(grad)
                return nll, grad_packed
            # The diagnostic must use the step size this SGLD run actually takes (`step_size`),
            # not sgld_config.epsilon, which belongs to a different cell and has another value.
            accept_probs.append(mala_acceptance_probability(
                old_param_packed, param_packed, grad_fn_packed, step_size))
        #if t % 200 == 0:
        #    print(f"Step {t}, nll: {nll}")

    # SGLD estimate: mean minibatch loss along the chain minus the full-data loss at the start.
    init_loss = loss_fn(param_init, x_train, y_train)
    lambdahat = (np.mean(losses) - init_loss) * num_training_data * itemp

    sgld_time = time.time() - sgld_start_time

    sgld_lambdas.append(lambdahat)
    sgld_times.append(sgld_time)

    print("SGLD lambda: ", lambdahat)
    print("SGLD accept prob: ", np.mean(accept_probs))
    print("SGLD execution time: ", sgld_time)

    # Run MALA (implemented as HMC with a single leapfrog step)
    mala_start_time = time.time()

    rngkey = jax.random.PRNGKey(0)

    param_init = copy.deepcopy(true_param)
    beta = itemp*num_training_data
    mala_step_size = np.sqrt(step_size)
    num_steps = 1

    # loss_fn is rebound to a full-data, parameters-only loss. local_logprob keeps the
    # three-argument version it was given when it was created above.
    loss_fn = jax.jit(lambda w: mse_loss(w, model, x_train, y_train))
    #loss_fn = jax.jit(make_population_loss_fn(true_param))
    potential_fn = jax.jit(lambda w: -local_logprob(w, x_train, y_train))

    # Set up HMC
    hmc_kernel = HMC(
        #potential_fn=lambda param: beta * loss_fn(param),
        potential_fn=potential_fn,
        step_size=mala_step_size,
        trajectory_length=mala_step_size*num_steps,
        adapt_step_size=False,
        adapt_mass_matrix=False
    )
    # Use the same number of samples as SGLD steps instead of a separate hard-coded count.
    mcmc = MCMC(hmc_kernel, num_samples=num_samples, num_warmup=0)
    rngkey, subkey = jax.random.split(rngkey)
    mcmc.run(subkey, init_params=param_init)

    # Extract samples
    samples = mcmc.get_samples()

    init_loss = loss_fn(param_init)
    losses = (jax.lax.map(loss_fn, samples))
    lambdahat = (np.mean(losses) - init_loss) * beta

    mala_time = time.time() - mala_start_time

    mala_lambdas.append(lambdahat)
    mala_times.append(mala_time)

    print("MALA lambda: ", lambdahat)
    print("MALA execution time: ", mala_time)





x_train shape: (100, 10), y_train shape: (100, 10)
Training data memory:  12000
SGLD lambda:  74.55827
SGLD accept prob:  0.99775195
SGLD execution time:  25.3698627948761


sample: 100%|██████████| 10000/10000 [00:51<00:00, 192.85it/s, 1 steps of size 3.16e-03. acc. prob=1.00]


MALA lambda:  69.984344
MALA execution time:  52.84654903411865
x_train shape: (316, 10), y_train shape: (316, 10)
Training data memory:  37920
SGLD lambda:  112.74543
SGLD accept prob:  0.9952274
SGLD execution time:  24.883032083511353


sample: 100%|██████████| 10000/10000 [00:57<00:00, 174.54it/s, 1 steps of size 3.16e-03. acc. prob=1.00]


MALA lambda:  105.73059
MALA execution time:  58.712841510772705
x_train shape: (1000, 10), y_train shape: (1000, 10)
Training data memory:  120000
SGLD lambda:  146.32349
SGLD accept prob:  0.9906568
SGLD execution time:  23.815698623657227


sample: 100%|██████████| 10000/10000 [00:54<00:00, 183.47it/s, 1 steps of size 3.16e-03. acc. prob=1.00]


MALA lambda:  136.32379
MALA execution time:  56.07924246788025
x_train shape: (3162, 10), y_train shape: (3162, 10)
Training data memory:  379440
SGLD lambda:  164.90524
SGLD accept prob:  0.9854525
SGLD execution time:  22.964507579803467


sample: 100%|██████████| 10000/10000 [00:58<00:00, 171.05it/s, 1 steps of size 3.16e-03. acc. prob=0.99]


MALA lambda:  159.19926
MALA execution time:  61.11035346984863
x_train shape: (10000, 10), y_train shape: (10000, 10)
Training data memory:  1200000
SGLD lambda:  169.98203
SGLD accept prob:  0.9788261
SGLD execution time:  23.709234476089478


sample: 100%|██████████| 10000/10000 [01:22<00:00, 121.49it/s, 1 steps of size 3.16e-03. acc. prob=0.95]


MALA lambda:  168.07602
MALA execution time:  89.83836102485657
x_train shape: (31623, 10), y_train shape: (31623, 10)
Training data memory:  3794760
SGLD lambda:  167.63695
SGLD accept prob:  0.9822504
SGLD execution time:  23.688312530517578


sample: 100%|██████████| 10000/10000 [02:17<00:00, 72.94it/s, 1 steps of size 3.16e-03. acc. prob=0.77]


MALA lambda:  170.92712
MALA execution time:  159.33368277549744
x_train shape: (100000, 10), y_train shape: (100000, 10)
Training data memory:  12000000
SGLD lambda:  170.55006
SGLD accept prob:  0.92354137
SGLD execution time:  24.173840761184692


sample: 100%|██████████| 10000/10000 [04:26<00:00, 37.53it/s, 1 steps of size 3.16e-03. acc. prob=0.18]


MALA lambda:  165.62125
MALA execution time:  333.373015165329


In [None]:
# Compare MALA and SGLD: wall-clock execution time against training-set size
# NOTE(review): pandas is imported here rather than in the import cell at the top of the notebook.
import pandas as pd

# One row per dataset size of the sweep. Swap in the commented-out columns to plot the
# lambda estimates instead of the execution times.
data = pd.DataFrame({
    #"SGLD": sgld_lambdas,
    #"MALA": mala_lambdas,
    "SGLD": sgld_times,
    "MALA": mala_times,
    "Dataset size": dataset_sizes
})

# Log-scaled x axis.
fig = px.line(data, x="Dataset size", y=["SGLD", "MALA"], log_x=True)
fig.update_layout(yaxis_title="Execution time (sec)", legend_title=None)
fig.show()