In [1]:
!pip install dm-haiku numpyro

Collecting dm-haiku
  Downloading dm_haiku-0.0.11-py3-none-any.whl (370 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m371.0/371.0 kB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting numpyro
  Downloading numpyro-0.13.2-py3-none-any.whl (312 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m312.7/312.7 kB[0m [31m10.2 MB/s[0m eta [36m0:00:00[0m
Collecting jmp>=0.0.2 (from dm-haiku)
  Downloading jmp-0.0.4-py3-none-any.whl (18 kB)
Installing collected packages: jmp, numpyro, dm-haiku
Successfully installed dm-haiku-0.0.11 jmp-0.0.4 numpyro-0.13.2


In [34]:
from typing import Sequence, NamedTuple
import copy, json

import numpy as np
import plotly.express as px
import plotly.graph_objects as go

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, HMC, NUTS

import jax
import jax.numpy as jnp
import jax.tree_util as jtree
import haiku as hk
import optax

In [3]:
class SGLDConfig(NamedTuple):
  epsilon: float
  gamma: float
  num_steps: int

def generate_rngkey_tree(key_or_seed, tree_or_treedef):
    rngseq = hk.PRNGSequence(key_or_seed)
    return jtree.tree_map(lambda _: next(rngseq), tree_or_treedef)

def optim_sgld(epsilon, rngkey_or_seed, preconditioning=None):
    @jax.jit
    def sgld_delta(g, rngkey, p=1):
        eta = jax.random.normal(rngkey, shape=g.shape) * jnp.sqrt(epsilon * p)
        return -epsilon * p * g / 2 + eta

    def init_fn(_):
        return rngkey_or_seed

    if preconditioning is None:
        update_map = lambda grads, rngkey_tree: jax.tree_map(sgld_delta, grads, rngkey_tree)
    else:
        update_map = lambda grads, rngkey_tree: jax.tree_map(sgld_delta, grads, rngkey_tree, preconditioning)

    @jax.jit
    def update_fn(grads, state):
        rngkey, new_rngkey = jax.random.split(state)
        rngkey_tree = generate_rngkey_tree(rngkey, grads)
        updates = update_map(grads, rngkey_tree)
        return updates, new_rngkey
    return optax.GradientTransformation(init_fn, update_fn)


def create_local_logposterior(avgnegloglikelihood_fn, num_training_data, w_init, gamma, itemp):
    def helper(x, y):
        return jnp.sum((x - y)**2)

    def _logprior_fn(w):
        sqnorm = jax.tree_util.tree_map(helper, w, w_init)
        return jax.tree_util.tree_reduce(lambda a,b: a + b, sqnorm)

    def logprob(w, x, y):
        loglike = -num_training_data * avgnegloglikelihood_fn(w, x, y)
        logprior = -gamma / 2 * _logprior_fn(w)
        return itemp * loglike + logprior
    return logprob


In [5]:
def mala_acceptance_probability(current_point, proposed_point, loss_and_grad_fn, step_size):
    """
    Calculate the acceptance probability for a MALA transition.

    Args:
    current_point: The current point in parameter space.
    proposed_point: The proposed point in parameter space.
    loss_and_grad_fn (function): Function to compute loss and loss gradient at a point.
    step_size (float): Step size parameter for MALA.

    Returns:
    float: Acceptance probability for the proposed transition.
    """
    # Compute the gradient of the loss at the current point
    current_loss, current_grad = loss_and_grad_fn(current_point)
    proposed_loss, proposed_grad = loss_and_grad_fn(proposed_point)

    # Compute the log of the proposal probabilities (using the Gaussian proposal distribution)
    log_q_proposed_to_current = -jnp.sum((current_point - proposed_point - (step_size * 0.5 * -proposed_grad)) ** 2) / (2 * step_size)
    log_q_current_to_proposed = -jnp.sum((proposed_point - current_point - (step_size * 0.5 * -current_grad)) ** 2) / (2 * step_size)

    # Compute the acceptance probability
    acceptance_log_prob = log_q_proposed_to_current - log_q_current_to_proposed + current_loss - proposed_loss
    return jnp.minimum(1.0, jnp.exp(acceptance_log_prob))

# Example usage:
# Define a simple log likelihood function for demonstration
def log_likelihood_example(x):
    return -0.5 * jnp.sum(x**2)

loss_and_grad_fn = jax.jit(jax.value_and_grad(lambda x: -log_likelihood_example(x), argnums=0))

# Current and proposed points for demonstration
current_point = jnp.array([1.0, 2.0])
proposed_point = jnp.array([1.5, 2.5])
step_size = 0.1

# Calculate the acceptance probability
acceptance_prob = mala_acceptance_probability(current_point, proposed_point, loss_and_grad_fn, step_size)
acceptance_prob


Array(0.95719343, dtype=float32)

In [15]:
def pack_params(params):
    params_flat, treedef = jax.tree_util.tree_flatten(params)
    shapes = [p.shape for p in params_flat]
    indices = np.cumsum([p.size for p in params_flat])
    params_packed = jnp.concatenate([jnp.ravel(p) for p in params_flat])
    pack_info = (treedef, shapes, indices)
    return params_packed, pack_info

def unpack_params(params_packed, pack_info):
    treedef, shapes, indices = pack_info
    params_split = jnp.split(params_packed, indices)
    params_flat = [jnp.reshape(p, shape) for p, shape in zip(params_split, shapes)]
    params = jax.tree_util.tree_unflatten(treedef, params_flat)
    return params

# MLPs

In [6]:
class ReluNetwork(hk.Module):
    def __init__(self, layer_widths):
        super().__init__()
        self.layer_widths = layer_widths

    def __call__(self, x):
        for width in self.layer_widths[:-1]:
            x = hk.Linear(width)(x)
            x = jax.nn.relu(x)
        x = hk.Linear(self.layer_widths[-1])(x)
        return x

# Function to initialize and apply the DLN model
def forward_fn(x, layer_widths):
    net = ReluNetwork(layer_widths)
    return net(x)

# Create a Haiku-transformed version of the model
def create_model(layer_widths):
    model = hk.without_apply_rng(hk.transform(lambda x: forward_fn(x, layer_widths)))
    return model

def generate_training_data(true_param, model, input_dim, num_samples):
    inputs = np.random.uniform(-10, 10, size=(num_samples, input_dim))

    # Apply the true model to generate outputs
    true_outputs = model.apply(true_param, inputs)

    return inputs, true_outputs

def mse_loss(param, model, inputs, targets):
    predictions = model.apply(param, inputs)
    return jnp.mean((predictions - targets) ** 2)

def create_minibatches(inputs, targets, batch_size, shuffle=True):
    assert len(inputs) == len(targets)
    if shuffle:
        indices = np.random.permutation(len(inputs))
    else:
        indices = np.arange(len(inputs))

    for start_idx in range(0, len(inputs) - batch_size + 1, batch_size):
        excerpt = indices[start_idx:start_idx + batch_size]
        yield inputs[excerpt], targets[excerpt]



In [7]:
input_dim = 10  # Dimension of the input
layer_widths = [20, 10]

# Create the DLN model
model = create_model(layer_widths)

# Initialize the "true" model parameters
rngkey = jax.random.PRNGKey(0)
dummy_input = jnp.zeros((1, input_dim), dtype=jnp.float32)
rngkey, subkey = jax.random.split(rngkey)
true_param = model.init(rngkey, dummy_input)

# Set biases to not be zero
for i in range(len(layer_widths)-1):
    rngkey, subkey = jax.random.split(rngkey)
    loc = f'relu_network/linear_{i}' if i > 0 else 'relu_network/linear'
    true_param[loc]['b'] = jax.random.normal(subkey, (layer_widths[i],))

# Introduce a degeneracy
degen_directions = [1, 2, 3, 4, 5, 8]
for dir in degen_directions:
    true_param['relu_network/linear']['w'] = true_param['relu_network/linear']['w'].at[:, dir].set(true_param['relu_network/linear']['w'][:, 0])
    true_param['relu_network/linear']['b'] = true_param['relu_network/linear']['b'].at[dir].set(true_param['relu_network/linear']['b'][0])

## Slight degeneracy breaking
#eps = 1e-2
#for i in range(len(layer_widths)-1):
#    rngkey, subkey1, subkey2 = jax.random.split(rngkey, num=3)
#    loc = f'relu_network/linear_{i}' if i > 0 else 'relu_network/linear'
#    #true_param[loc]['w'] = true_param[loc]['w'] + eps*jax.random.normal(subkey1, true_param[loc]['w'].shape)
#    true_param[loc]['b'] = true_param[loc]['b'] + eps*jax.random.normal(subkey2, true_param[loc]['b'].shape)

In [8]:
num_training_data = 10000

# Generate training data
np.random.seed(0)
x_train, y_train = generate_training_data(true_param, model, input_dim, num_training_data)
print(f"x_train shape: {x_train.shape}, y_train shape: {y_train.shape}")
jtree.tree_map_with_path(lambda path, x: x.shape, true_param)

x_train shape: (10000, 10), y_train shape: (10000, 10)


{'relu_network/linear': {'b': (20,), 'w': (10, 20)},
 'relu_network/linear_1': {'b': (10,), 'w': (20, 10)}}

## Single experiment

In [None]:
sgld_config = SGLDConfig(
    epsilon=1e-5,
    gamma=0.0,
    num_steps=10000,
)
batch_size = 500
itemp = 1 / np.log(num_training_data)
param_init = copy.deepcopy(true_param)

loss_fn = jax.jit(lambda param, inputs, targets: mse_loss(param, model, inputs, targets))
local_logprob = create_local_logposterior(
        avgnegloglikelihood_fn=loss_fn,
        num_training_data=num_training_data,
        w_init=param_init,
        gamma=sgld_config.gamma,
        itemp=itemp,
    )
sgld_grad_fn = jax.jit(jax.value_and_grad(lambda w, x, y: -local_logprob(w, x, y), argnums=0))


rngkey = jax.random.PRNGKey(0)
sgldoptim = optim_sgld(sgld_config.epsilon, rngkey)
samples = []
nlls = []
accept_probs = []
opt_state = sgldoptim.init(param_init)
param = param_init
t = 0
while t < sgld_config.num_steps:
    for x_batch, y_batch in create_minibatches(x_train, y_train, batch_size=batch_size, shuffle=False):
        old_param = param.copy()

        nll, grads = sgld_grad_fn(param, x_batch, y_batch)
        nlls.append(float(nll))
        updates, opt_state = sgldoptim.update(grads, opt_state)
        param = optax.apply_updates(param, updates)
        samples.append(param)

        t += 1

        if t % 20 == 0:
            old_param_packed, pack_info = pack_params(old_param)
            param_packed, _ = pack_params(param)
            def grad_fn_packed(w):
                nll, grad = sgld_grad_fn(unpack_params(w, pack_info), x_batch, y_batch)
                grad_packed, _ = pack_params(grad)
                return nll, grad_packed
            accept_probs.append(mala_acceptance_probability(
                old_param_packed, param_packed, grad_fn_packed, sgld_config.epsilon))
        if t % 200 == 0:
            print(f"Step {t}, nll: {nll}")

init_loss = loss_fn(param_init, x_train, y_train)
loss_trace = [loss_fn(p, x_train, y_train) for p in samples]
lambdahat = (np.mean(loss_trace) - init_loss) * num_training_data * itemp

print(lambdahat)
print(np.mean(accept_probs))

px.line(np.array(loss_trace) * num_training_data * itemp)

Step 200, nll: 104.05753326416016
Step 400, nll: 152.78582763671875
Step 600, nll: 163.82955932617188
Step 800, nll: 175.67556762695312
Step 1000, nll: 148.97361755371094
Step 1200, nll: 172.58216857910156
Step 1400, nll: 177.95628356933594
Step 1600, nll: 176.10824584960938
Step 1800, nll: 205.50314331054688
Step 2000, nll: 177.27899169921875
Step 2200, nll: 181.50462341308594
Step 2400, nll: 152.79153442382812
Step 2600, nll: 176.04031372070312
Step 2800, nll: 178.98538208007812
Step 3000, nll: 186.8394012451172
Step 3200, nll: 177.44558715820312
Step 3400, nll: 185.38978576660156
Step 3600, nll: 204.4461212158203
Step 3800, nll: 196.8759002685547
Step 4000, nll: 170.34083557128906
Step 4200, nll: 189.96640014648438
Step 4400, nll: 178.54931640625
Step 4600, nll: 200.50038146972656
Step 4800, nll: 167.3976287841797
Step 5000, nll: 172.20196533203125
Step 5200, nll: 184.03282165527344
Step 5400, nll: 176.38372802734375
Step 5600, nll: 185.7010955810547
Step 5800, nll: 180.536392211914

In [None]:
scaling_factor = 5000
adjusted_true_param = copy.deepcopy(true_param)
adjusted_true_param['relu_network/linear']['w'] = adjusted_true_param['relu_network/linear']['w'] * scaling_factor
adjusted_true_param['relu_network/linear']['b'] = adjusted_true_param['relu_network/linear']['b'] * scaling_factor
adjusted_true_param['relu_network/linear_1']['w'] = adjusted_true_param['relu_network/linear_1']['w'] / scaling_factor

In [None]:
sgld_config = SGLDConfig(
    epsilon=1e-5,
    gamma=0.0,
    num_steps=10000,
)
batch_size = 500
itemp = 1 / np.log(num_training_data)
param_init = copy.deepcopy(adjusted_true_param)

loss_fn = jax.jit(lambda param, inputs, targets: mse_loss(param, model, inputs, targets))
local_logprob = create_local_logposterior(
        avgnegloglikelihood_fn=loss_fn,
        num_training_data=num_training_data,
        w_init=param_init,
        gamma=sgld_config.gamma,
        itemp=itemp,
    )
sgld_grad_fn = jax.jit(jax.value_and_grad(lambda w, x, y: -local_logprob(w, x, y), argnums=0))

preconditioning = {
    "relu_network/linear": {
        "w": jnp.ones(1)*scaling_factor*scaling_factor,
        "b": jnp.ones(1)*scaling_factor*scaling_factor
    },
    "relu_network/linear_1": {
        "w": jnp.ones(1)/(scaling_factor*scaling_factor),
        "b": jnp.ones(1)
    }
}


rngkey = jax.random.PRNGKey(0)
sgldoptim = optim_sgld(sgld_config.epsilon, rngkey, preconditioning)
samples = []
nlls = []
accept_probs = []
opt_state = sgldoptim.init(param_init)
param = param_init
t = 0
while t < sgld_config.num_steps:
    for x_batch, y_batch in create_minibatches(x_train, y_train, batch_size=batch_size, shuffle=False):
        old_param = param.copy()

        nll, grads = sgld_grad_fn(param, x_batch, y_batch)
        nlls.append(float(nll))
        updates, opt_state = sgldoptim.update(grads, opt_state)
        param = optax.apply_updates(param, updates)
        samples.append(param)

        t += 1

        if t % 20 == 0:
            old_param_packed, pack_info = pack_params(old_param)
            param_packed, _ = pack_params(param)
            def grad_fn_packed(w):
                nll, grad = sgld_grad_fn(unpack_params(w, pack_info), x_batch, y_batch)
                grad_packed, _ = pack_params(grad)
                return nll, grad_packed
            accept_probs.append(mala_acceptance_probability(
                old_param_packed, param_packed, grad_fn_packed, sgld_config.epsilon))
        if t % 200 == 0:
            print(f"Step {t}, nll: {nll}")

init_loss = loss_fn(param_init, x_train, y_train)
loss_trace = [loss_fn(p, x_train, y_train) for p in samples]
lambdahat = (np.mean(loss_trace) - init_loss) * num_training_data * itemp

print(lambdahat)
print(np.mean(accept_probs))

px.line(np.array(loss_trace) * num_training_data * itemp)

Step 200, nll: 104.0574722290039
Step 400, nll: 152.78582763671875
Step 600, nll: 163.82957458496094
Step 800, nll: 175.67556762695312
Step 1000, nll: 148.9735870361328
Step 1200, nll: 172.5821533203125
Step 1400, nll: 177.956298828125
Step 1600, nll: 176.10842895507812
Step 1800, nll: 205.50343322753906
Step 2000, nll: 177.27914428710938
Step 2200, nll: 181.50460815429688
Step 2400, nll: 152.79150390625
Step 2600, nll: 176.04031372070312
Step 2800, nll: 178.9853057861328
Step 3000, nll: 186.83949279785156
Step 3200, nll: 177.4461212158203
Step 3400, nll: 185.3896942138672
Step 3600, nll: 204.44618225097656
Step 3800, nll: 196.87591552734375
Step 4000, nll: 170.34075927734375
Step 4200, nll: 189.9663848876953
Step 4400, nll: 178.54881286621094
Step 4600, nll: 200.49925231933594
Step 4800, nll: 167.396484375
Step 5000, nll: 172.2021942138672
Step 5200, nll: 184.03208923339844
Step 5400, nll: 176.38352966308594
Step 5600, nll: 185.7005615234375
Step 5800, nll: 180.53594970703125
Step 600

## Systematic experiment

In [19]:
num_runs = 5
scaling_factors = 10.0**(jnp.linspace(-4, 4, 17))
lambdahats = []
for scaling_factor in scaling_factors:
    print("Scaling factor: ", scaling_factor)
    lambdahats.append([])
    for run in range(num_runs):
        print("Run: ", run)
        adjusted_true_param = copy.deepcopy(true_param)
        adjusted_true_param['relu_network/linear']['w'] = adjusted_true_param['relu_network/linear']['w'] * scaling_factor
        adjusted_true_param['relu_network/linear']['b'] = adjusted_true_param['relu_network/linear']['b'] * scaling_factor
        adjusted_true_param['relu_network/linear_1']['w'] = adjusted_true_param['relu_network/linear_1']['w'] / scaling_factor

        sgld_config = SGLDConfig(
            epsilon=1e-5,
            gamma=0.0,
            num_steps=10000,
        )
        batch_size = 500
        itemp = 1 / np.log(num_training_data)
        param_init = copy.deepcopy(adjusted_true_param)

        loss_fn = jax.jit(lambda param, inputs, targets: mse_loss(param, model, inputs, targets))
        local_logprob = create_local_logposterior(
                avgnegloglikelihood_fn=loss_fn,
                num_training_data=num_training_data,
                w_init=param_init,
                gamma=sgld_config.gamma,
                itemp=itemp,
            )
        sgld_grad_fn = jax.jit(jax.value_and_grad(lambda w, x, y: -local_logprob(w, x, y), argnums=0))

        preconditioning = {
            "relu_network/linear": {
                "w": jnp.ones(1)*scaling_factor*scaling_factor,
                "b": jnp.ones(1)*scaling_factor*scaling_factor
            },
            "relu_network/linear_1": {
                "w": jnp.ones(1)/(scaling_factor*scaling_factor),
                "b": jnp.ones(1)
            }
        }


        rngkey = jax.random.PRNGKey(run)
        sgldoptim = optim_sgld(sgld_config.epsilon, rngkey, preconditioning)
        samples = []
        nlls = []
        accept_probs = []
        opt_state = sgldoptim.init(param_init)
        param = param_init
        t = 0
        while t < sgld_config.num_steps:
            for x_batch, y_batch in create_minibatches(x_train, y_train, batch_size=batch_size, shuffle=False):
                old_param = param.copy()

                nll, grads = sgld_grad_fn(param, x_batch, y_batch)
                nlls.append(float(nll))
                updates, opt_state = sgldoptim.update(grads, opt_state)
                param = optax.apply_updates(param, updates)
                samples.append(param)

                t += 1

                if t % 20 == 0:
                    old_param_packed, pack_info = pack_params(old_param)
                    param_packed, _ = pack_params(param)
                    def grad_fn_packed(w):
                        nll, grad = sgld_grad_fn(unpack_params(w, pack_info), x_batch, y_batch)
                        grad_packed, _ = pack_params(grad)
                        return nll, grad_packed
                    accept_probs.append(mala_acceptance_probability(
                        old_param_packed, param_packed, grad_fn_packed, sgld_config.epsilon))

        init_loss = loss_fn(param_init, x_train, y_train)
        loss_trace = [loss_fn(p, x_train, y_train) for p in samples]
        lambdahat = ((np.mean(loss_trace) - init_loss) * num_training_data * itemp).item()

        print("Lambdahat: ", lambdahat)
        print("Mean MALA accept prob: ", np.mean(accept_probs))

        lambdahats[-1].append(lambdahat)

results = {
    'scaling_factors': scaling_factors.tolist(),
    'lambdahats': lambdahats
}
json.dump(results, open('llc_rescaling_results.json', 'w'))

Scaling factor:  1e-04
Run:  0
Lambdahat:  173.99740600585938
Mean MALA accept prob:  0.552
Run:  1
Lambdahat:  173.74806213378906
Mean MALA accept prob:  0.514
Run:  2
Lambdahat:  168.34881591796875
Mean MALA accept prob:  0.532
Run:  3
Lambdahat:  168.26902770996094
Mean MALA accept prob:  0.512
Run:  4
Lambdahat:  161.14703369140625
Mean MALA accept prob:  0.54
Scaling factor:  0.00031622776
Run:  0
Lambdahat:  173.99752807617188
Mean MALA accept prob:  0.552
Run:  1
Lambdahat:  173.7481689453125
Mean MALA accept prob:  0.514
Run:  2
Lambdahat:  168.34890747070312
Mean MALA accept prob:  0.532
Run:  3
Lambdahat:  168.26889038085938
Mean MALA accept prob:  0.512
Run:  4
Lambdahat:  161.14706420898438
Mean MALA accept prob:  0.54
Scaling factor:  0.001
Run:  0
Lambdahat:  173.9972686767578
Mean MALA accept prob:  0.552
Run:  1
Lambdahat:  173.74815368652344
Mean MALA accept prob:  0.514
Run:  2
Lambdahat:  168.34886169433594
Mean MALA accept prob:  0.532
Run:  3
Lambdahat:  168.268905

TypeError: Object of type ArrayImpl is not JSON serializable

In [37]:
avg_lambdahats = np.mean(np.array(lambdahats), axis=1)
std_lambdahats = np.std(np.array(lambdahats), axis=1)
fig = px.line(x=scaling_factors, y=avg_lambdahats, log_x=True, labels={'x': 'Rescaling factor', 'y': 'lambdahat'})

# Add error bands
fig.add_trace(go.Scatter(x=scaling_factors, y=avg_lambdahats - std_lambdahats, fill=None, mode='lines', line=dict(color='lightblue', width=0), showlegend=False))
fig.add_trace(go.Scatter(x=scaling_factors, y=avg_lambdahats + std_lambdahats, fill='tonexty', mode='lines', line=dict(color='lightblue', width=0), showlegend=False))

# Set the y-axis range
fig.update_yaxes(range=[0, 300])

# Show the figure
fig.show()