In [None]:
!pip install dm-haiku numpyro



In [None]:
from typing import Sequence, NamedTuple
import copy

import numpy as np
import plotly.express as px

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, HMC, NUTS

import jax
import jax.numpy as jnp
import jax.tree_util as jtree
import haiku as hk
import optax

In [None]:
class SGLDConfig(NamedTuple):
    """Hyperparameters for one SGLD sampling run in this notebook."""
    # Step size handed to `optim_sgld`.
    epsilon: float
    # Strength of the quadratic localizing term handed to `create_local_logposterior`.
    gamma: float
    # Target number of SGLD updates used as the bound of the sampling loop.
    num_steps: int

def generate_rngkey_tree(key_or_seed, tree_or_treedef):
    """Build a pytree with the structure of `tree_or_treedef` whose leaves are PRNG keys.

    Each leaf is replaced by the next key drawn from a `hk.PRNGSequence`
    seeded with `key_or_seed`, so every leaf receives its own key.
    """
    key_stream = hk.PRNGSequence(key_or_seed)

    def _fresh_key(_leaf):
        # The leaf value itself is ignored; only the tree structure matters.
        return next(key_stream)

    return jtree.tree_map(_fresh_key, tree_or_treedef)

def optim_sgld(epsilon, rngkey_or_seed, preconditioning=None):
    """Build an Optax gradient transformation that performs SGLD updates.

    For every gradient leaf `g` the produced update is
    `-epsilon * p * g / 2 + sqrt(epsilon * p) * N(0, I)`, where `p` is the
    matching leaf of `preconditioning` (or 1 when no preconditioning is given).

    Args:
        epsilon: SGLD step size.
        rngkey_or_seed: JAX PRNG key, or a plain Python int seed, used as the
            initial optimizer state.
        preconditioning: optional pytree with the same structure as the
            gradients; each leaf must broadcast against the matching gradient leaf.

    Returns:
        An `optax.GradientTransformation` whose state is a single PRNG key.
    """
    @jax.jit
    def sgld_delta(g, rngkey, p=1):
        # Gaussian noise with variance epsilon * p, matching the drift scaling below.
        eta = jax.random.normal(rngkey, shape=g.shape) * jnp.sqrt(epsilon * p)
        return -epsilon * p * g / 2 + eta

    def init_fn(_):
        # `update_fn` splits the state with `jax.random.split`, which needs a
        # real key, so a bare integer seed is turned into one here.
        if isinstance(rngkey_or_seed, int):
            return jax.random.PRNGKey(rngkey_or_seed)
        return rngkey_or_seed

    # `jax.tree_map` is deprecated/removed in recent JAX releases; use the
    # `jax.tree_util.tree_map` spelling that the rest of this file relies on.
    if preconditioning is None:
        update_map = lambda grads, rngkey_tree: jax.tree_util.tree_map(sgld_delta, grads, rngkey_tree)
    else:
        update_map = lambda grads, rngkey_tree: jax.tree_util.tree_map(sgld_delta, grads, rngkey_tree, preconditioning)

    @jax.jit
    def update_fn(grads, state, params=None):
        # `params` is accepted only for compatibility with the Optax update
        # signature (e.g. inside `optax.chain`); SGLD does not use it.
        del params
        # One half of the split drives this step's noise, the other is the next state.
        rngkey, new_rngkey = jax.random.split(state)
        rngkey_tree = generate_rngkey_tree(rngkey, grads)
        updates = update_map(grads, rngkey_tree)
        return updates, new_rngkey
    return optax.GradientTransformation(init_fn, update_fn)


def create_local_logposterior(avgnegloglikelihood_fn, num_training_data, w_init, gamma, itemp):
    """Build a tempered, localized log-posterior function `logprob(w, x, y)`.

    The returned function evaluates
    `itemp * (-num_training_data * avgnegloglikelihood_fn(w, x, y))
     - gamma / 2 * ||w - w_init||^2`,
    where the squared norm is summed over all leaves of the parameter pytree.
    """
    def _squared_distance(leaf, anchor):
        diff = leaf - anchor
        return jnp.sum(diff ** 2)

    def _total_squared_distance(w):
        # Per-leaf squared distances to w_init, then folded into a single scalar.
        per_leaf = jax.tree_util.tree_map(_squared_distance, w, w_init)
        return jax.tree_util.tree_reduce(lambda acc, term: acc + term, per_leaf)

    def logprob(w, x, y):
        # The supplied loss is an average, so rescale it to the full dataset size.
        loglike = -num_training_data * avgnegloglikelihood_fn(w, x, y)
        logprior = -gamma / 2 * _total_squared_distance(w)
        return itemp * loglike + logprior

    return logprob


In [None]:
def mala_acceptance_probability(current_point, proposed_point, loss_and_grad_fn, step_size, preconditioning=None):
    """
    Calculate the acceptance probability for a MALA transition.

    The proposal is assumed to be Gaussian with mean
    `x - (step_size / 2) * P * grad_loss(x)` and per-coordinate variance
    `step_size * P`, where `P` is `preconditioning` (1 when it is None). The
    target density is taken to be proportional to `exp(-loss)`.

    Args:
    current_point: The current point in parameter space.
    proposed_point: The proposed point in parameter space.
    loss_and_grad_fn (function): Function to compute loss and loss gradient at a point.
    step_size (float): Step size parameter for MALA.
    preconditioning: Optional positive scalar or array broadcastable to the
        points, holding the diagonal preconditioner of the proposal. It must
        not depend on the position. Defaults to None (identity).

    Returns:
    float: Acceptance probability for the proposed transition.
    """
    # Loss and gradient at both ends of the transition
    current_loss, current_grad = loss_and_grad_fn(current_point)
    proposed_loss, proposed_grad = loss_and_grad_fn(proposed_point)

    # Per-coordinate proposal variance; a scalar when no preconditioner is given
    if preconditioning is None:
        variance = step_size
    else:
        variance = step_size * jnp.asarray(preconditioning)

    def log_proposal_density(source, source_grad, destination):
        # log q(destination | source) up to an additive constant. The constant
        # (including the log-determinant of the covariance) is the same in both
        # directions because the preconditioner is position independent.
        mean = source - 0.5 * variance * source_grad
        return -jnp.sum((destination - mean) ** 2 / (2 * variance))

    log_q_proposed_to_current = log_proposal_density(proposed_point, proposed_grad, current_point)
    log_q_current_to_proposed = log_proposal_density(current_point, current_grad, proposed_point)

    # Metropolis-Hastings log ratio for the target exp(-loss)
    acceptance_log_prob = log_q_proposed_to_current - log_q_current_to_proposed + current_loss - proposed_loss
    # Clamp in log space so exp() never produces an infinite intermediate value
    return jnp.exp(jnp.minimum(0.0, acceptance_log_prob))

# Example usage:
# Define a simple log likelihood function for demonstration
def log_likelihood_example(x):
    """Return -0.5 * sum(x**2), i.e. an unnormalized standard-normal log density."""
    return -0.5 * jnp.sum(x**2)

# MALA works with a loss, so negate the log likelihood; value_and_grad yields
# the (loss, gradient) pair that mala_acceptance_probability expects.
loss_and_grad_fn = jax.jit(jax.value_and_grad(lambda x: -log_likelihood_example(x), argnums=0))

# Current and proposed points for demonstration
current_point = jnp.array([1.0, 2.0])
proposed_point = jnp.array([1.5, 2.5])
step_size = 0.1

# Calculate the acceptance probability
acceptance_prob = mala_acceptance_probability(current_point, proposed_point, loss_and_grad_fn, step_size)
# Bare last expression so the notebook displays the value
acceptance_prob


Array(0.95719343, dtype=float32)

In [None]:
def pack_params(params):
    """Flatten a parameter pytree into a single 1-D array.

    Returns:
        A pair `(params_packed, pack_info)` where `pack_info` is the tuple
        `(treedef, shapes, indices)` consumed by `unpack_params`. `indices`
        holds the cumulative leaf sizes (the last entry is the total size).
    """
    leaves, treedef = jax.tree_util.tree_flatten(params)

    shapes = []
    sizes = []
    for leaf in leaves:
        shapes.append(leaf.shape)
        sizes.append(leaf.size)
    indices = np.cumsum(sizes)

    flat_leaves = [jnp.ravel(leaf) for leaf in leaves]
    params_packed = jnp.concatenate(flat_leaves)
    return params_packed, (treedef, shapes, indices)

def unpack_params(params_packed, pack_info):
    """Rebuild a parameter pytree from a packed 1-D array.

    `pack_info` is the `(treedef, shapes, indices)` tuple produced by
    `pack_params`. The array is split at `indices`; pieces are paired with
    `shapes` in order, so any surplus trailing piece is ignored.
    """
    treedef, shapes, indices = pack_info
    pieces = jnp.split(params_packed, indices)

    leaves = []
    for piece, shape in zip(pieces, shapes):
        leaves.append(jnp.reshape(piece, shape))

    return jax.tree_util.tree_unflatten(treedef, leaves)

# MLPs

In [None]:
class ReluNetwork(hk.Module):
    """Fully connected network with ReLU activations on every layer but the last.

    `layer_widths` lists the output width of each linear layer in order; the
    final entry is the width of the (linear, non-activated) output layer.
    """

    def __init__(self, layer_widths):
        super().__init__()
        self.layer_widths = layer_widths

    def __call__(self, x):
        hidden_widths = self.layer_widths[:-1]
        activations = x
        for width in hidden_widths:
            activations = jax.nn.relu(hk.Linear(width)(activations))
        # The output layer is purely linear.
        output_width = self.layer_widths[-1]
        return hk.Linear(output_width)(activations)

# Forward pass of the ReLU network: builds a ReluNetwork and applies it to x
def forward_fn(x, layer_widths):
    """Instantiate a `ReluNetwork` with `layer_widths` and return its output on `x`."""
    return ReluNetwork(layer_widths)(x)

# Wrap the forward pass in a Haiku transform whose apply() takes no RNG key
def create_model(layer_widths):
    """Return the Haiku-transformed model (init/apply pair) for `layer_widths`."""
    def _forward(x):
        return forward_fn(x, layer_widths)

    transformed = hk.transform(_forward)
    return hk.without_apply_rng(transformed)

def generate_training_data(true_param, model, input_dim, num_samples):
    """Sample inputs uniformly from [-10, 10]^input_dim and label them with the true model.

    Uses NumPy's global random state. Returns `(inputs, true_outputs)` where
    `true_outputs = model.apply(true_param, inputs)`.
    """
    # Disabled alternative kept from the original: sample uniformly from the unit ball.
    #input_directions = np.random.normal(size=(num_samples, input_dim))
    #inputs = (np.random.rand(num_samples, 1)**(1/input_dim)) * (input_directions/np.linalg.norm(input_directions, axis=-1, keepdims=True))
    sample_shape = (num_samples, input_dim)
    features = np.random.uniform(-10, 10, size=sample_shape)

    # Noise-free targets produced by the true parameters
    labels = model.apply(true_param, features)
    return features, labels

def mse_loss(param, model, inputs, targets):
    """Mean squared error between `model.apply(param, inputs)` and `targets`, averaged over all elements."""
    residual = model.apply(param, inputs) - targets
    return jnp.mean(residual ** 2)

def create_minibatches(inputs, targets, batch_size, shuffle=True):
    """Yield `(inputs, targets)` minibatches of exactly `batch_size` examples.

    With `shuffle=True` the examples are visited in a random order drawn from
    NumPy's global random state; otherwise in their original order. A trailing
    partial batch is dropped.
    """
    assert len(inputs) == len(targets)
    num_examples = len(inputs)
    order = np.random.permutation(num_examples) if shuffle else np.arange(num_examples)

    # Only start positions that leave room for a full batch are visited.
    stop = num_examples - batch_size + 1
    for lo in range(0, stop, batch_size):
        batch_idx = order[lo:lo + batch_size]
        yield inputs[batch_idx], targets[batch_idx]



In [None]:
input_dim = 10  # Dimension of the input
layer_widths = [20, 10]

# Create the ReLU MLP model (create_model is expected to be the ReluNetwork version defined above)
model = create_model(layer_widths)

# Initialize the "true" model parameters
rngkey = jax.random.PRNGKey(0)
dummy_input = jnp.zeros((1, input_dim), dtype=jnp.float32)
# NOTE(review): `subkey` from this split is never used; init consumes `rngkey`
# itself, and the same `rngkey` is split again in the loop below.
rngkey, subkey = jax.random.split(rngkey)
true_param = model.init(rngkey, dummy_input)

# Set biases to not be zero
# NOTE(review): the range stops before the last layer, so only the hidden-layer
# biases are resampled here — confirm that leaving the output bias as-is is intended.
for i in range(len(layer_widths)-1):
    rngkey, subkey = jax.random.split(rngkey)
    # Parameter-dict keys for the linear layers: 'linear', then 'linear_1', ...
    loc = f'relu_network/linear_{i}' if i > 0 else 'relu_network/linear'
    true_param[loc]['b'] = jax.random.normal(subkey, (layer_widths[i],))

# Introduce a degeneracy
# Copy the incoming weights and bias of hidden unit 0 onto the listed hidden
# units, making them exact duplicates of unit 0.
degen_directions = [1, 2, 3, 4, 5, 8]
# NOTE(review): the loop variable `dir` shadows the Python builtin of the same name.
for dir in degen_directions:
    true_param['relu_network/linear']['w'] = true_param['relu_network/linear']['w'].at[:, dir].set(true_param['relu_network/linear']['w'][:, 0])
    true_param['relu_network/linear']['b'] = true_param['relu_network/linear']['b'].at[dir].set(true_param['relu_network/linear']['b'][0])

## Slight degeneracy breaking
#eps = 1e-2
#for i in range(len(layer_widths)-1):
#    rngkey, subkey1, subkey2 = jax.random.split(rngkey, num=3)
#    loc = f'relu_network/linear_{i}' if i > 0 else 'relu_network/linear'
#    #true_param[loc]['w'] = true_param[loc]['w'] + eps*jax.random.normal(subkey1, true_param[loc]['w'].shape)
#    true_param[loc]['b'] = true_param[loc]['b'] + eps*jax.random.normal(subkey2, true_param[loc]['b'].shape)

In [None]:
num_training_data = 10000

# Generate training data
# Seed NumPy's global RNG because generate_training_data draws its inputs from it.
np.random.seed(0)
x_train, y_train = generate_training_data(true_param, model, input_dim, num_training_data)
print(f"x_train shape: {x_train.shape}, y_train shape: {y_train.shape}")
# Display the shape of every parameter leaf as the cell output
jtree.tree_map_with_path(lambda path, x: x.shape, true_param)

x_train shape: (10000, 10), y_train shape: (10000, 10)


{'relu_network/linear': {'b': (20,), 'w': (10, 20)},
 'relu_network/linear_1': {'b': (10,), 'w': (20, 10)}}

In [None]:
# SGLD run on the ReLU MLP, started at the true parameters, without preconditioning.
sgld_config = SGLDConfig(
    epsilon=1e-5,
    gamma=0.0,
    num_steps=10000,
)
batch_size = 500
# Inverse temperature 1 / log(n)
itemp = 1 / np.log(num_training_data)
param_init = copy.deepcopy(true_param)
#param_init = trained_param

loss_fn = jax.jit(lambda param, inputs, targets: mse_loss(param, model, inputs, targets))
local_logprob = create_local_logposterior(
        avgnegloglikelihood_fn=loss_fn,
        num_training_data=num_training_data,
        w_init=param_init,
        gamma=sgld_config.gamma,
        itemp=itemp,
    )
# SGLD descends the negative log-posterior; returns (value, gradient w.r.t. the parameters).
sgld_grad_fn = jax.jit(jax.value_and_grad(lambda w, x, y: -local_logprob(w, x, y), argnums=0))


rngkey = jax.random.PRNGKey(0)
sgldoptim = optim_sgld(sgld_config.epsilon, rngkey)
samples = []
nlls = []
accept_probs = []
opt_state = sgldoptim.init(param_init)
param = param_init
t = 0
# The step budget is only checked between passes over the data, so the run can
# overshoot num_steps unless it is a multiple of the number of batches per pass.
while t < sgld_config.num_steps:
    # shuffle=False: batches are visited in the same order on every pass
    for x_batch, y_batch in create_minibatches(x_train, y_train, batch_size=batch_size, shuffle=False):
        # Shallow copy of the parameter dict, kept for the MALA diagnostic below
        old_param = param.copy()

        nll, grads = sgld_grad_fn(param, x_batch, y_batch)
        nlls.append(float(nll))
        updates, opt_state = sgldoptim.update(grads, opt_state)
        param = optax.apply_updates(param, updates)
        # Every iterate is kept in memory for the loss trace computed after the loop
        samples.append(param)

        t += 1

        # Every 20 steps, record the MALA acceptance probability of the step just
        # taken. This is a diagnostic only: the step is never rejected.
        if t % 20 == 0:
            old_param_packed, pack_info = pack_params(old_param)
            param_packed, _ = pack_params(param)
            # Adapter: evaluate the minibatch objective on a packed 1-D parameter vector
            def grad_fn_packed(w):
                nll, grad = sgld_grad_fn(unpack_params(w, pack_info), x_batch, y_batch)
                grad_packed, _ = pack_params(grad)
                return nll, grad_packed
            accept_probs.append(mala_acceptance_probability(
                old_param_packed, param_packed, grad_fn_packed, sgld_config.epsilon))
        if t % 200 == 0:
            print(f"Step {t}, nll: {nll}")

# Full-dataset loss at the start point and at every stored iterate
init_loss = loss_fn(param_init, x_train, y_train)
loss_trace = [loss_fn(p, x_train, y_train) for p in samples]
# n * itemp * (mean loss along the chain - initial loss); presumably the local
# learning coefficient estimate — TODO confirm the intended estimator.
lambdahat = (np.mean(loss_trace) - init_loss) * num_training_data * itemp

print(lambdahat)
print(np.mean(accept_probs))

px.line(np.array(loss_trace) * num_training_data * itemp)

Step 200, nll: 104.05753326416016
Step 400, nll: 152.78582763671875
Step 600, nll: 163.82955932617188
Step 800, nll: 175.67556762695312
Step 1000, nll: 148.97361755371094
Step 1200, nll: 172.58216857910156
Step 1400, nll: 177.95628356933594
Step 1600, nll: 176.10824584960938
Step 1800, nll: 205.50314331054688
Step 2000, nll: 177.27899169921875
Step 2200, nll: 181.50462341308594
Step 2400, nll: 152.79153442382812
Step 2600, nll: 176.04031372070312
Step 2800, nll: 178.98538208007812
Step 3000, nll: 186.8394012451172
Step 3200, nll: 177.44558715820312
Step 3400, nll: 185.38978576660156
Step 3600, nll: 204.4461212158203
Step 3800, nll: 196.8759002685547
Step 4000, nll: 170.34083557128906
Step 4200, nll: 189.96640014648438
Step 4400, nll: 178.54931640625
Step 4600, nll: 200.50038146972656
Step 4800, nll: 167.3976287841797
Step 5000, nll: 172.20196533203125
Step 5200, nll: 184.03282165527344
Step 5400, nll: 176.38372802734375
Step 5600, nll: 185.7010955810547
Step 5800, nll: 180.536392211914

In [None]:
# Rescale the parameters: first-layer weights and biases are multiplied by
# scaling_factor and second-layer weights divided by it. The second-layer bias
# is left unchanged.
scaling_factor = 5000
adjusted_true_param = copy.deepcopy(true_param)
adjusted_true_param['relu_network/linear']['w'] = adjusted_true_param['relu_network/linear']['w'] * scaling_factor
adjusted_true_param['relu_network/linear']['b'] = adjusted_true_param['relu_network/linear']['b'] * scaling_factor
adjusted_true_param['relu_network/linear_1']['w'] = adjusted_true_param['relu_network/linear_1']['w'] / scaling_factor

In [None]:
# Preconditioned SGLD run on the ReLU MLP, started at the rescaled true parameters.
sgld_config = SGLDConfig(
    epsilon=1e-5,
    gamma=0.0,
    num_steps=10000,
)
batch_size = 500
# Inverse temperature 1 / log(n)
itemp = 1 / np.log(num_training_data)
#param_init = copy.deepcopy(true_param)
param_init = copy.deepcopy(adjusted_true_param)

loss_fn = jax.jit(lambda param, inputs, targets: mse_loss(param, model, inputs, targets))
local_logprob = create_local_logposterior(
        avgnegloglikelihood_fn=loss_fn,
        num_training_data=num_training_data,
        w_init=param_init,
        gamma=sgld_config.gamma,
        itemp=itemp,
    )
# SGLD descends the negative log-posterior; returns (value, gradient w.r.t. the parameters).
sgld_grad_fn = jax.jit(jax.value_and_grad(lambda w, x, y: -local_logprob(w, x, y), argnums=0))

# Per-leaf preconditioner: the square of the factor each parameter group was
# rescaled by (scaling_factor for the first layer, 1/scaling_factor for the
# second-layer weights, 1 for the untouched second-layer bias). Shape-(1,)
# arrays broadcast against each gradient leaf.
preconditioning = {
    "relu_network/linear": {
        "w": jnp.ones(1)*scaling_factor*scaling_factor,
        "b": jnp.ones(1)*scaling_factor*scaling_factor
    },
    "relu_network/linear_1": {
        "w": jnp.ones(1)/(scaling_factor*scaling_factor),
        "b": jnp.ones(1)
    }
}


rngkey = jax.random.PRNGKey(0)
sgldoptim = optim_sgld(sgld_config.epsilon, rngkey, preconditioning)
samples = []
nlls = []
accept_probs = []
opt_state = sgldoptim.init(param_init)
param = param_init
t = 0
# The step budget is only checked between passes over the data, so the run can
# overshoot num_steps unless it is a multiple of the number of batches per pass.
while t < sgld_config.num_steps:
    # shuffle=False: batches are visited in the same order on every pass
    for x_batch, y_batch in create_minibatches(x_train, y_train, batch_size=batch_size, shuffle=False):
        # Shallow copy of the parameter dict, kept for the MALA diagnostic below
        old_param = param.copy()

        nll, grads = sgld_grad_fn(param, x_batch, y_batch)
        nlls.append(float(nll))
        updates, opt_state = sgldoptim.update(grads, opt_state)
        param = optax.apply_updates(param, updates)
        # Every iterate is kept in memory for the loss trace computed after the loop
        samples.append(param)

        t += 1

        # Every 20 steps, record the MALA acceptance probability of the step just
        # taken. This is a diagnostic only: the step is never rejected.
        # NOTE(review): only the scalar epsilon is passed here, so the diagnostic
        # does not account for the preconditioner used by the sampler.
        if t % 20 == 0:
            old_param_packed, pack_info = pack_params(old_param)
            param_packed, _ = pack_params(param)
            # Adapter: evaluate the minibatch objective on a packed 1-D parameter vector
            def grad_fn_packed(w):
                nll, grad = sgld_grad_fn(unpack_params(w, pack_info), x_batch, y_batch)
                grad_packed, _ = pack_params(grad)
                return nll, grad_packed
            accept_probs.append(mala_acceptance_probability(
                old_param_packed, param_packed, grad_fn_packed, sgld_config.epsilon))
        if t % 200 == 0:
            print(f"Step {t}, nll: {nll}")

# Full-dataset loss at the start point and at every stored iterate
init_loss = loss_fn(param_init, x_train, y_train)
loss_trace = [loss_fn(p, x_train, y_train) for p in samples]
# n * itemp * (mean loss along the chain - initial loss); presumably the local
# learning coefficient estimate — TODO confirm the intended estimator.
lambdahat = (np.mean(loss_trace) - init_loss) * num_training_data * itemp

print(lambdahat)
print(np.mean(accept_probs))

px.line(np.array(loss_trace) * num_training_data * itemp)

Step 200, nll: 104.0574722290039
Step 400, nll: 152.78582763671875
Step 600, nll: 163.82957458496094
Step 800, nll: 175.67556762695312
Step 1000, nll: 148.9735870361328
Step 1200, nll: 172.5821533203125
Step 1400, nll: 177.956298828125
Step 1600, nll: 176.10842895507812
Step 1800, nll: 205.50343322753906
Step 2000, nll: 177.27914428710938
Step 2200, nll: 181.50460815429688
Step 2400, nll: 152.79150390625
Step 2600, nll: 176.04031372070312
Step 2800, nll: 178.9853057861328
Step 3000, nll: 186.83949279785156
Step 3200, nll: 177.4461212158203
Step 3400, nll: 185.3896942138672
Step 3600, nll: 204.44618225097656
Step 3800, nll: 196.87591552734375
Step 4000, nll: 170.34075927734375
Step 4200, nll: 189.9663848876953
Step 4400, nll: 178.54881286621094
Step 4600, nll: 200.49925231933594
Step 4800, nll: 167.396484375
Step 5000, nll: 172.2021942138672
Step 5200, nll: 184.03208923339844
Step 5400, nll: 176.38352966308594
Step 5600, nll: 185.7005615234375
Step 5800, nll: 180.53594970703125
Step 600

In [None]:
# Sweep over rescaling factors from 1e-4 to 1e4 (17 log-spaced values): for each
# one, rescale the true parameters, run preconditioned SGLD and record lambdahat.
scaling_factors = 10.0**(jnp.linspace(-4, 4, 17))
lambdahats = []
for scaling_factor in scaling_factors:
    # First layer multiplied by the factor, second-layer weights divided by it
    adjusted_true_param = copy.deepcopy(true_param)
    adjusted_true_param['relu_network/linear']['w'] = adjusted_true_param['relu_network/linear']['w'] * scaling_factor
    adjusted_true_param['relu_network/linear']['b'] = adjusted_true_param['relu_network/linear']['b'] * scaling_factor
    adjusted_true_param['relu_network/linear_1']['w'] = adjusted_true_param['relu_network/linear_1']['w'] / scaling_factor

    sgld_config = SGLDConfig(
        epsilon=1e-5,
        gamma=0.0,
        num_steps=10000,
    )
    batch_size = 500
    # Inverse temperature 1 / log(n)
    itemp = 1 / np.log(num_training_data)
    #param_init = copy.deepcopy(true_param)
    param_init = copy.deepcopy(adjusted_true_param)

    loss_fn = jax.jit(lambda param, inputs, targets: mse_loss(param, model, inputs, targets))
    local_logprob = create_local_logposterior(
            avgnegloglikelihood_fn=loss_fn,
            num_training_data=num_training_data,
            w_init=param_init,
            gamma=sgld_config.gamma,
            itemp=itemp,
        )
    # SGLD descends the negative log-posterior; returns (value, gradient w.r.t. the parameters).
    sgld_grad_fn = jax.jit(jax.value_and_grad(lambda w, x, y: -local_logprob(w, x, y), argnums=0))

    # Per-leaf preconditioner: the square of the factor each parameter group was
    # rescaled by; shape-(1,) arrays broadcast against each gradient leaf.
    preconditioning = {
        "relu_network/linear": {
            "w": jnp.ones(1)*scaling_factor*scaling_factor,
            "b": jnp.ones(1)*scaling_factor*scaling_factor
        },
        "relu_network/linear_1": {
            "w": jnp.ones(1)/(scaling_factor*scaling_factor),
            "b": jnp.ones(1)
        }
    }


    # The same PRNG seed is used for every scaling factor
    rngkey = jax.random.PRNGKey(0)
    sgldoptim = optim_sgld(sgld_config.epsilon, rngkey, preconditioning)
    samples = []
    nlls = []
    accept_probs = []
    opt_state = sgldoptim.init(param_init)
    param = param_init
    t = 0
    # The step budget is only checked between passes over the data.
    while t < sgld_config.num_steps:
        # shuffle=False: batches are visited in the same order on every pass
        for x_batch, y_batch in create_minibatches(x_train, y_train, batch_size=batch_size, shuffle=False):
            # Shallow copy of the parameter dict, kept for the MALA diagnostic below
            old_param = param.copy()

            nll, grads = sgld_grad_fn(param, x_batch, y_batch)
            nlls.append(float(nll))
            updates, opt_state = sgldoptim.update(grads, opt_state)
            param = optax.apply_updates(param, updates)
            # Every iterate is kept in memory for the loss trace computed after the loop
            samples.append(param)

            t += 1

            # Diagnostic only: the step is never rejected.
            # NOTE(review): only the scalar epsilon is passed here, so the diagnostic
            # does not account for the preconditioner used by the sampler.
            if t % 20 == 0:
                old_param_packed, pack_info = pack_params(old_param)
                param_packed, _ = pack_params(param)
                # Adapter: evaluate the minibatch objective on a packed 1-D parameter vector
                def grad_fn_packed(w):
                    nll, grad = sgld_grad_fn(unpack_params(w, pack_info), x_batch, y_batch)
                    grad_packed, _ = pack_params(grad)
                    return nll, grad_packed
                accept_probs.append(mala_acceptance_probability(
                    old_param_packed, param_packed, grad_fn_packed, sgld_config.epsilon))
            #if t % 200 == 0:
            #    print(f"Step {t}, nll: {nll}")

    # Full-dataset loss at the start point and at every stored iterate
    init_loss = loss_fn(param_init, x_train, y_train)
    loss_trace = [loss_fn(p, x_train, y_train) for p in samples]
    # n * itemp * (mean loss along the chain - initial loss); presumably the local
    # learning coefficient estimate — TODO confirm the intended estimator.
    lambdahat = (np.mean(loss_trace) - init_loss) * num_training_data * itemp

    print(lambdahat, np.mean(accept_probs))

    lambdahats.append(lambdahat)

px.line(x=scaling_factors, y=lambdahats)


173.9974 0.552
173.99753 0.552
173.99727 0.552
173.99762 0.5520907
173.99731 0.55619544
173.9974 0.57685196
173.99724 0.7213258
173.99718 0.9439179
173.99768 0.9620022
173.99765 0.7787605
173.99742 0.57731235
173.99721 0.5383957
173.99756 0.5340522
173.99721 0.534
173.99742 0.534
173.99738 0.534
173.99757 0.534


In [None]:
# lambdahat against the rescaling factor, with a logarithmic x-axis because the
# factors are log-spaced.
fig = px.line(x=scaling_factors, y=lambdahats, log_x=True, labels={'x': 'Rescaling factor', 'y': 'lambdahat'})
#fig.update_yaxes(range=[0, 300])
# Bare last expression so the notebook renders the figure
fig

# DLNs

In [None]:
# Deep linear network (DLN): a stack of linear layers with no nonlinearity
class DeepLinearNetwork(hk.Module):
    """Composition of `hk.Linear` layers with the given output widths.

    `with_bias` is forwarded to every layer; it defaults to False, in which
    case the network computes a pure product of weight matrices.
    """

    def __init__(self, layer_widths: Sequence[int], name: str = None, with_bias=False):
        super().__init__(name=name)
        self.with_bias = with_bias
        self.layer_widths = layer_widths

    def __call__(self, x):
        out = x
        for width in self.layer_widths:
            layer = hk.Linear(width, with_bias=self.with_bias)
            out = layer(out)
        return out

# Forward pass of the DLN: builds a DeepLinearNetwork and applies it to x
def forward_fn(x, layer_widths):
    """Instantiate a `DeepLinearNetwork` (default settings, so no biases) and apply it to `x`."""
    return DeepLinearNetwork(layer_widths)(x)

# Wrap the forward pass in a Haiku transform whose apply() takes no RNG key
def create_model(layer_widths):
    """Return the Haiku-transformed model (init/apply pair) for `layer_widths`."""
    def _forward(x):
        return forward_fn(x, layer_widths)

    transformed = hk.transform(_forward)
    return hk.without_apply_rng(transformed)


def generate_training_data(true_param, model, input_dim, num_samples):
    """Sample inputs uniformly from the unit ball and label them with the true model.

    Uses NumPy's global random state. Returns `(inputs, true_outputs)` where
    `true_outputs = model.apply(true_param, inputs)`.
    """
    # Disabled alternative kept from the original: uniform inputs on a cube.
    #inputs = np.random.uniform(-10, 10, size=(num_samples, input_dim))

    # Uniform direction: a normalized Gaussian vector.
    gaussian = np.random.normal(size=(num_samples, input_dim))
    unit_directions = gaussian / np.linalg.norm(gaussian, axis=-1, keepdims=True)
    # Radius U**(1/d) makes the points uniform in volume inside the ball.
    radii = np.random.rand(num_samples, 1) ** (1 / input_dim)
    features = radii * unit_directions

    # Noise-free targets produced by the true parameters
    labels = model.apply(true_param, features)
    return features, labels

def mse_loss(param, model, inputs, targets):
    """Mean squared error between `model.apply(param, inputs)` and `targets`, averaged over all elements."""
    residual = model.apply(param, inputs) - targets
    return jnp.mean(residual ** 2)

# Closed-form expected MSE of a bias-free deep linear network, assuming the
# inputs are drawn uniformly from the unit ball (the average of the empirical loss).
def make_population_loss_fn(true_param):
    """Return `population_loss(param)` for a DLN whose ground truth is `true_param`.

    Both parameter dicts are multiplied out layer by layer in dict order; the
    loss is the squared Frobenius norm of the difference of the two products
    divided by `(input_dim + 2) * output_dim`.
    """
    input_dim = true_param['deep_linear_network/linear']['w'].shape[0]

    # End-to-end matrix of the true network; output_dim ends up as the column
    # count of the last layer visited.
    target_product = np.eye(input_dim)
    for layer in true_param.values():
        weight = layer['w']
        target_product = target_product @ weight
        output_dim = weight.shape[1]

    def population_loss(param):
        candidate_product = jnp.eye(input_dim)
        for layer in param.values():
            candidate_product = candidate_product @ layer['w']
        gap = target_product - candidate_product
        return jnp.sum(gap * gap) / ((input_dim + 2) * output_dim)

    return population_loss

def create_minibatches(inputs, targets, batch_size, shuffle=True):
    """Yield `(inputs, targets)` minibatches of exactly `batch_size` examples.

    With `shuffle=True` the examples are visited in a random order drawn from
    NumPy's global random state; otherwise in their original order. A trailing
    partial batch is dropped.
    """
    assert len(inputs) == len(targets)
    num_examples = len(inputs)
    order = np.random.permutation(num_examples) if shuffle else np.arange(num_examples)

    # Only start positions that leave room for a full batch are visited.
    stop = num_examples - batch_size + 1
    for lo in range(0, stop, batch_size):
        batch_idx = order[lo:lo + batch_size]
        yield inputs[batch_idx], targets[batch_idx]


In [None]:
input_dim = 20  # Dimension of the input
layer_widths = [20, 20]
# Inner dimension of the low-rank factorizations used for the true weights below
true_rank = 10

# Create the DLN model
model = create_model(layer_widths)

# Initialize the "true" model parameters
rngkey = jax.random.PRNGKey(0)
dummy_input = jnp.zeros((1, input_dim))
# NOTE(review): `subkey` from this split is never used; init consumes `rngkey`
# itself, and the same `rngkey` is split again below.
rngkey, subkey = jax.random.split(rngkey)
true_param = model.init(rngkey, dummy_input)

# Modify true param to be low rank
# Each layer's weight is the product of two Gaussian factors with inner
# dimension true_rank, so its rank is at most true_rank.
rngkey, subkey1, subkey2, subkey3, subkey4 = jax.random.split(rngkey, num=5)
true_param['deep_linear_network/linear']['w'] = jax.random.normal(subkey1, (input_dim, true_rank)) @ jax.random.normal(subkey2, (true_rank, layer_widths[0]))
true_param['deep_linear_network/linear_1']['w'] = jax.random.normal(subkey3, (layer_widths[0], true_rank)) @ jax.random.normal(subkey4, (true_rank, layer_widths[1]))
#true_param['deep_linear_network/linear']['w'] = jnp.zeros((20, 20))
#true_param['deep_linear_network/linear_1']['w'] = jnp.zeros((20, 20))

## Add small perturbation
#eps = 1e-1
#rngkey, subkey1, subkey2 = jax.random.split(rngkey, num=3)
#true_param['deep_linear_network/linear']['w'] = true_param['deep_linear_network/linear']['w'] + (eps * jax.random.normal(subkey1, (input_dim, layer_widths[0])))
#true_param['deep_linear_network/linear_1']['w'] = true_param['deep_linear_network/linear_1']['w'] + (eps * jax.random.normal(subkey2, (input_dim, layer_widths[0])))

# End-to-end matrix of the true network: the product of all layer weights in
# order (keys 'linear', 'linear_1', ...).
true_matrix = jnp.linalg.multi_dot(
    [true_param[f'deep_linear_network/linear{loc}']['w'] for loc in [''] + [f'_{i}' for i in range(1, len(layer_widths))]]
)
# Display the rank of the end-to-end matrix as the cell output
jnp.linalg.matrix_rank(true_matrix)

Array(10, dtype=int32)

In [None]:
num_training_data = 10000

# Generate training data
# Seed NumPy's global RNG because generate_training_data draws its inputs from it.
np.random.seed(0)
x_train, y_train = generate_training_data(true_param, model, input_dim, num_training_data)
print(f"x_train shape: {x_train.shape}, y_train shape: {y_train.shape}")
# Display the shape of every parameter leaf as the cell output
jtree.tree_map_with_path(lambda path, x: x.shape, true_param)

x_train shape: (10000, 20), y_train shape: (10000, 20)


{'deep_linear_network/linear': {'w': (20, 20)},
 'deep_linear_network/linear_1': {'w': (20, 20)}}

### SGLD

In [None]:
# SGLD run on the DLN, started at the true parameters, with an identity preconditioner.
sgld_config = SGLDConfig(
    epsilon=1e-5,
    gamma=0.0,
    num_steps=10000,
)
batch_size = 500
# Inverse temperature 1 / log(n)
itemp = 1 / np.log(num_training_data)
param_init = copy.deepcopy(true_param)

#loss_fn = jax.jit(lambda param, inputs, targets: mse_loss(param, model, inputs, targets))
# The closed-form population loss is used here: `inputs` and `targets` are
# accepted for signature compatibility but ignored.
loss_fn = jax.jit(lambda param, inputs, targets: make_population_loss_fn(true_param)(param))
local_logprob = create_local_logposterior(
        avgnegloglikelihood_fn=loss_fn,
        num_training_data=num_training_data,
        w_init=param_init,
        gamma=sgld_config.gamma,
        itemp=itemp,
    )
# SGLD descends the negative log-posterior; returns (value, gradient w.r.t. the parameters).
sgld_grad_fn = jax.jit(jax.value_and_grad(lambda w, x, y: -local_logprob(w, x, y), argnums=0))

# All-ones preconditioner (shape-(1,) arrays broadcast against each gradient leaf)
preconditioning = {"deep_linear_network/linear": {"w": jnp.ones(1)}, "deep_linear_network/linear_1": {"w": jnp.ones(1)}}

rngkey = jax.random.PRNGKey(0)
sgldoptim = optim_sgld(sgld_config.epsilon, rngkey, preconditioning)
samples = []
nlls = []
accept_probs = []
opt_state = sgldoptim.init(param_init)
param = param_init
t = 0
# The step budget is only checked between passes over the data, so the run can
# overshoot num_steps unless it is a multiple of the number of batches per pass.
while t < sgld_config.num_steps:
    # shuffle=False: batches are visited in the same order on every pass
    for x_batch, y_batch in create_minibatches(x_train, y_train, batch_size=batch_size, shuffle=False):
        # Shallow copy of the parameter dict, kept for the MALA diagnostic below
        old_param = param.copy()

        nll, grads = sgld_grad_fn(param, x_batch, y_batch)
        nlls.append(float(nll))
        updates, opt_state = sgldoptim.update(grads, opt_state)
        param = optax.apply_updates(param, updates)
        # Every iterate is kept in memory for the loss trace computed after the loop
        samples.append(param)

        t += 1

        # Every 20 steps, record the MALA acceptance probability of the step just
        # taken. This is a diagnostic only: the step is never rejected.
        if t % 20 == 0:
            old_param_packed, pack_info = pack_params(old_param)
            param_packed, _ = pack_params(param)
            # Adapter: evaluate the objective on a packed 1-D parameter vector
            def grad_fn_packed(w):
                nll, grad = sgld_grad_fn(unpack_params(w, pack_info), x_batch, y_batch)
                grad_packed, _ = pack_params(grad)
                return nll, grad_packed
            accept_probs.append(mala_acceptance_probability(
                old_param_packed, param_packed, grad_fn_packed, sgld_config.epsilon))
        if t % 200 == 0:
            print(f"Step {t}, nll: {nll}")

# Loss at the start point and at every stored iterate
init_loss = loss_fn(param_init, x_train, y_train)
loss_trace = [loss_fn(p, x_train, y_train) for p in samples]
# n * itemp * (mean loss along the chain - initial loss); presumably the local
# learning coefficient estimate — TODO confirm the intended estimator.
lambdahat = (np.mean(loss_trace) - init_loss) * num_training_data * itemp

print(lambdahat)
print(np.mean(accept_probs))

px.line(np.array(loss_trace) * num_training_data * itemp)

Step 200, nll: 138.78421020507812
Step 400, nll: 130.27525329589844
Step 600, nll: 159.2430877685547
Step 800, nll: 142.35394287109375
Step 1000, nll: 149.27467346191406
Step 1200, nll: 147.85491943359375
Step 1400, nll: 158.1136016845703
Step 1600, nll: 144.07101440429688
Step 1800, nll: 129.14523315429688
Step 2000, nll: 151.04434204101562
Step 2200, nll: 149.94813537597656
Step 2400, nll: 151.97445678710938
Step 2600, nll: 149.05511474609375
Step 2800, nll: 137.9267578125
Step 3000, nll: 158.9026336669922
Step 3200, nll: 132.2672576904297
Step 3400, nll: 139.561767578125
Step 3600, nll: 152.235595703125
Step 3800, nll: 135.31459045410156
Step 4000, nll: 141.98951721191406
Step 4200, nll: 182.0004119873047
Step 4400, nll: 159.73826599121094
Step 4600, nll: 166.7034454345703
Step 4800, nll: 163.236083984375
Step 5000, nll: 140.82090759277344
Step 5200, nll: 156.9680938720703
Step 5400, nll: 157.6902618408203
Step 5600, nll: 144.6896209716797
Step 5800, nll: 138.17564392089844
Step 600

In [None]:
# Rescale the DLN: multiply the first layer's weights by scaling_factor and
# divide the second layer's by it, which leaves the product of the two unchanged
# up to floating-point rounding.
scaling_factor = 10000
adjusted_true_param = copy.deepcopy(true_param)
adjusted_true_param['deep_linear_network/linear']['w'] = adjusted_true_param['deep_linear_network/linear']['w'] * scaling_factor
adjusted_true_param['deep_linear_network/linear_1']['w'] = adjusted_true_param['deep_linear_network/linear_1']['w'] / scaling_factor

In [None]:
# Preconditioned SGLD run on the DLN, started at the rescaled true parameters.
sgld_config = SGLDConfig(
    epsilon=1e-5,
    gamma=0.0,
    num_steps=10000,
)
batch_size = 500
# Inverse temperature 1 / log(n)
itemp = 1 / np.log(num_training_data)
#itemp = 0.1
#param_init = copy.deepcopy(true_param)
param_init = copy.deepcopy(adjusted_true_param)

#loss_fn = jax.jit(lambda param, inputs, targets: mse_loss(param, model, inputs, targets))
# The closed-form population loss is used here: `inputs` and `targets` are
# accepted for signature compatibility but ignored.
loss_fn = jax.jit(lambda param, inputs, targets: make_population_loss_fn(true_param)(param))
local_logprob = create_local_logposterior(
        avgnegloglikelihood_fn=loss_fn,
        num_training_data=num_training_data,
        w_init=param_init,
        gamma=sgld_config.gamma,
        itemp=itemp,
    )
# SGLD descends the negative log-posterior; returns (value, gradient w.r.t. the parameters).
sgld_grad_fn = jax.jit(jax.value_and_grad(lambda w, x, y: -local_logprob(w, x, y), argnums=0))

# Per-leaf preconditioner: the square of the factor each layer was rescaled by
# (scaling_factor for the first layer, 1/scaling_factor for the second).
preconditioning = {"deep_linear_network/linear": {"w": jnp.ones(1)*scaling_factor*scaling_factor}, "deep_linear_network/linear_1": {"w": jnp.ones(1)/(scaling_factor*scaling_factor)}}
#preconditioning = {"deep_linear_network/linear": {"w": jnp.ones(1)}, "deep_linear_network/linear_1": {"w": jnp.ones(1)}}

rngkey = jax.random.PRNGKey(0)
sgldoptim = optim_sgld(sgld_config.epsilon, rngkey, preconditioning)
samples = []
nlls = []
accept_probs = []
opt_state = sgldoptim.init(param_init)
param = param_init
t = 0
# The step budget is only checked between passes over the data, so the run can
# overshoot num_steps unless it is a multiple of the number of batches per pass.
while t < sgld_config.num_steps:
    # shuffle=False: batches are visited in the same order on every pass
    for x_batch, y_batch in create_minibatches(x_train, y_train, batch_size=batch_size, shuffle=False):
        # Shallow copy of the parameter dict, kept for the MALA diagnostic below
        old_param = param.copy()

        nll, grads = sgld_grad_fn(param, x_batch, y_batch)
        nlls.append(float(nll))
        updates, opt_state = sgldoptim.update(grads, opt_state)
        param = optax.apply_updates(param, updates)
        # Every iterate is kept in memory for the loss trace computed after the loop
        samples.append(param)

        t += 1

        # Every 20 steps, record the MALA acceptance probability of the step just
        # taken. This is a diagnostic only: the step is never rejected.
        # NOTE(review): only the scalar epsilon is passed here, so the diagnostic
        # does not account for the preconditioner used by the sampler.
        if t % 20 == 0:
            old_param_packed, pack_info = pack_params(old_param)
            param_packed, _ = pack_params(param)
            # Adapter: evaluate the objective on a packed 1-D parameter vector
            def grad_fn_packed(w):
                nll, grad = sgld_grad_fn(unpack_params(w, pack_info), x_batch, y_batch)
                grad_packed, _ = pack_params(grad)
                return nll, grad_packed
            accept_probs.append(mala_acceptance_probability(
                old_param_packed, param_packed, grad_fn_packed, sgld_config.epsilon))
        if t % 200 == 0:
            print(f"Step {t}, nll: {nll}")

# Loss at the start point and at every stored iterate
init_loss = loss_fn(param_init, x_train, y_train)
loss_trace = [loss_fn(p, x_train, y_train) for p in samples]
# n * itemp * (mean loss along the chain - initial loss); presumably the local
# learning coefficient estimate — TODO confirm the intended estimator.
lambdahat = (np.mean(loss_trace) - init_loss) * num_training_data * itemp

print(lambdahat)
print(np.mean(accept_probs))

px.line(np.array(loss_trace) * num_training_data * itemp)

Step 200, nll: 138.78526306152344
Step 400, nll: 130.27578735351562
Step 600, nll: 159.24212646484375
Step 800, nll: 142.35482788085938
Step 1000, nll: 149.27452087402344
Step 1200, nll: 147.8551788330078
Step 1400, nll: 158.113525390625
Step 1600, nll: 144.07118225097656
Step 1800, nll: 129.14468383789062
Step 2000, nll: 151.04429626464844
Step 2200, nll: 149.94754028320312
Step 2400, nll: 151.97390747070312
Step 2600, nll: 149.05519104003906
Step 2800, nll: 137.92721557617188
Step 3000, nll: 158.90097045898438
Step 3200, nll: 132.26710510253906
Step 3400, nll: 139.56126403808594
Step 3600, nll: 152.23484802246094
Step 3800, nll: 135.31427001953125
Step 4000, nll: 141.98988342285156
Step 4200, nll: 182.0003662109375
Step 4400, nll: 159.73780822753906
Step 4600, nll: 166.70310974121094
Step 4800, nll: 163.2351837158203
Step 5000, nll: 140.8211212158203
Step 5200, nll: 156.96812438964844
Step 5400, nll: 157.69175720214844
Step 5600, nll: 144.68997192382812
Step 5800, nll: 138.1757354736

In [None]:
# Change in the first-layer weights produced by the first stored SGLD iterate, relative to param_init
samples[0]['deep_linear_network/linear']['w']-param_init['deep_linear_network/linear']['w']

Array([[-8.88779272e-10, -2.22849222e-08, -8.31638687e-08,
        -6.15061615e-08, -4.23443804e-08,  5.43464651e-09,
        -1.04511183e-07,  1.67378615e-08, -1.51794226e-08,
         3.93386130e-08,  2.86129624e-08,  6.35464437e-09,
        -7.26197347e-08,  6.36349151e-09,  3.21397562e-08,
        -4.62708849e-08, -4.80294586e-08, -5.23015506e-08,
        -8.19166996e-08,  7.66127331e-08],
       [-4.21410694e-08, -5.97151057e-08, -5.93780332e-08,
        -4.50242681e-08, -1.66729504e-08,  2.11270210e-08,
        -1.11754801e-07, -1.43660905e-09, -5.29316455e-08,
         1.76894964e-08,  1.62190219e-08,  5.94418681e-09,
        -8.55040658e-08, -4.65295358e-09,  1.13144871e-09,
        -5.28320490e-08, -3.22071791e-08, -8.04133069e-08,
        -5.07399766e-08,  1.21491701e-07],
       [-3.38142997e-08, -5.21598276e-08, -6.41792010e-08,
        -4.83511362e-08, -2.18547254e-08,  1.79595170e-08,
        -1.10292677e-07,  2.23190570e-09, -4.53113640e-08,
         2.20593680e-08,  1.8

In [None]:
# Same first-step displacement as the previous cell, evaluated against whichever
# `samples` / `param_init` are currently in the kernel
(samples[0]['deep_linear_network/linear']['w']-param_init['deep_linear_network/linear']['w'])

Array([[-1.77755055e-11, -4.45698447e-10, -1.66327709e-09,
        -1.23012313e-09, -8.46887360e-10,  1.08693143e-10,
        -2.09022375e-09,  3.34757289e-10, -3.03588195e-10,
         7.86772247e-10,  5.72259253e-10,  1.27092790e-10,
        -1.45239483e-09,  1.27269573e-10,  6.42795044e-10,
        -9.25417680e-10, -9.60588977e-10, -1.04603089e-09,
        -1.63833420e-09,  1.53225468e-09],
       [-8.42821422e-10, -1.19430211e-09, -1.18756077e-09,
        -9.00485304e-10, -3.33458985e-10,  4.22540438e-10,
        -2.23509574e-09, -2.87320745e-11, -1.05863279e-09,
         3.53789886e-10,  3.24380380e-10,  1.18883747e-10,
        -1.71008125e-09, -9.30589295e-11,  2.26290808e-11,
        -1.05664071e-09, -6.44143654e-10, -1.60826602e-09,
        -1.01479955e-09,  2.42983390e-09],
       [-6.76285978e-10, -1.04319668e-09, -1.28358401e-09,
        -9.67022658e-10, -4.37094485e-10,  3.59190340e-10,
        -2.20585378e-09,  4.46381154e-11, -9.06227342e-10,
         4.41187353e-10,  3.7

In [None]:
# Rescaling check: multiply the first layer's weights by `scaling_factor`,
# divide the second layer's weights by it, and display the elementwise ratio
# of the first-layer entries of `sgld_grad_fn(...)[1]` (rescaled / original).
# NOTE(review): element [1] of sgld_grad_fn's result is presumably the gradient
# tree (value_and_grad convention) -- confirm against its definition.
first_w_layer = 'deep_linear_network/linear'
second_w_layer = 'deep_linear_network/linear_1'

# Fresh parameters from a fixed seed.
rngkey = jax.random.PRNGKey(1)
dummy_input = jnp.zeros((1, input_dim))
rand_param = model.init(rngkey, dummy_input)
#rand_param = copy.deepcopy(samples[0])

# Work on a deep copy so `rand_param` itself stays untouched.
adjusted_rand_param = copy.deepcopy(rand_param)
adjusted_rand_param[first_w_layer]['w'] = adjusted_rand_param[first_w_layer]['w'] * scaling_factor
adjusted_rand_param[second_w_layer]['w'] = adjusted_rand_param[second_w_layer]['w'] / scaling_factor

# Same 500-example slice for both evaluations.
probe_x = x_train[:500]
probe_y = y_train[:500]

#sgld_grad_fn(rand_param, x_train[:500], y_train[:500]), sgld_grad_fn(adjusted_rand_param, x_train[:500], y_train[:500])
grad_rescaled = sgld_grad_fn(adjusted_rand_param, probe_x, probe_y)[1]
grad_original = sgld_grad_fn(rand_param, probe_x, probe_y)[1]
grad_rescaled[first_w_layer]['w'] / grad_original[first_w_layer]['w']

# Disabled experiments, kept for reference:
#adjusted_true_param['deep_linear_network/linear']['w'] / true_param['deep_linear_network/linear']['w']
#sgld_grad_fn(adjusted_true_param, x_train[:500], y_train[:500])[1]['deep_linear_network/linear_1']['w'] / sgld_grad_fn(true_param, x_train[:500], y_train[:500])[1]['deep_linear_network/linear_1']['w']
#sgld_grad_fn(adjusted_true_param, x_train[:500], y_train[:500])[0], sgld_grad_fn(true_param, x_train[:500], y_train[:500])[0]

#param_init = copy.deepcopy(true_param)
#param_init_a = copy.deepcopy(adjusted_true_param)
#param_init['deep_linear_network/linear']['w'] = param_init['deep_linear_network/linear']['w'] + 1e-5
#param_init['deep_linear_network/linear_1']['w'] = param_init['deep_linear_network/linear_1']['w'] + 1e-5
#param_init_a['deep_linear_network/linear']['w'] = param_init_a['deep_linear_network/linear']['w'] + (1e-5*scaling_factor)
#param_init_a['deep_linear_network/linear_1']['w'] = param_init_a['deep_linear_network/linear_1']['w'] + (1e-5/scaling_factor)
##sgld_grad_fn(param_init, x_train[:500], y_train[:500])[0], sgld_grad_fn(param_init_a, x_train[:500], y_train[:500])[0]
##sgld_grad_fn(param_init, x_train[:500], y_train[:500])[1]['deep_linear_network/linear']['w'] / sgld_grad_fn(param_init_a, x_train[:500], y_train[:500])[1]['deep_linear_network/linear']['w']
#sgld_grad_fn(param_init_a, x_train[:500], y_train[:500])[1]['deep_linear_network/linear']['w'] * scaling_factor * (1e-5 / 2)


Array([[0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02,
        0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02],
       [0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02,
        0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02],
       [0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02,
        0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02],
       [0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02,
        0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02],
       [0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02,
        0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02],
       [0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02,
        0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02],
       [0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02,
        0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02],
       [0.02, 0.02, 0.02, 0.02, 0.02, 0.0

In [None]:
# Singular values of the first layer's weight matrix, taken from every 100th
# entry of `samples`, plotted as lines.
samples_thinned = samples[::100]
sing_vals = []
for thinned_sample in samples_thinned:
    first_layer_w = thinned_sample['deep_linear_network/linear']['w']
    # compute_uv=False: only the singular values are needed, not U / V^H.
    sing_vals.append(np.linalg.svd(first_layer_w, compute_uv=False))
px.line(sing_vals)

In [None]:
def layer_weight_norms(sample):
    """Return `jnp.linalg.norm` of the 'w' entry of every layer in `sample`, in dict order."""
    return [jnp.linalg.norm(layer['w']) for layer in sample.values()]


# One row of per-layer weight norms for each entry of `samples`.
norms = list(map(layer_weight_norms, samples))
px.line(norms)

In [None]:
# Collect row 0 of each layer's weight matrix from every entry of `samples`,
# then scatter first-layer values (x) against second-layer values (y).
first_rows_linear = []
first_rows_linear_1 = []
for drawn in samples:
    first_rows_linear.append(drawn['deep_linear_network/linear']['w'][0])
    first_rows_linear_1.append(drawn['deep_linear_network/linear_1']['w'][0])

x_samples = jnp.concatenate(first_rows_linear)
y_samples = jnp.concatenate(first_rows_linear_1)
px.scatter(x=x_samples, y=y_samples)

### MCMC

In [None]:
# Run NumPyro HMC with potential `beta * loss_fn(param)`, starting from a copy
# of `true_param`, then print lambdahat = beta * (mean sampled loss - initial
# loss) and plot the scaled loss trace.
rngkey = jax.random.PRNGKey(0)

# Sampler settings; both adaptation flags are off below, so these stay fixed.
beta = 10
step_size = 1e-2
num_steps = 1
param_init = copy.deepcopy(true_param)

#loss_fn = jax.jit(lambda w: mse_loss(w, model, x_train, y_train))
loss_fn = jax.jit(make_population_loss_fn(true_param))


def tempered_potential(param):
    """Potential energy handed to HMC: the loss scaled by `beta`."""
    return beta * loss_fn(param)


# Set up HMC
hmc_kernel = HMC(
    potential_fn=tempered_potential,
    step_size=step_size,
    # Trajectory length is expressed as (number of steps) x (step size).
    trajectory_length=step_size * num_steps,
    adapt_step_size=False,
    adapt_mass_matrix=False,
)
mcmc = MCMC(hmc_kernel, num_samples=10000, num_warmup=0)

rngkey, subkey = jax.random.split(rngkey)
mcmc.run(subkey, init_params=param_init)

# Extract samples
samples = mcmc.get_samples()

# Loss at the starting point and at every drawn sample (vmapped over the draws).
init_loss = loss_fn(param_init)
batched_loss_fn = jax.vmap(loss_fn)
losses = batched_loss_fn(samples)

lambdahat = beta * (np.mean(losses) - init_loss)
print(lambdahat)

px.line(losses * beta)


sample: 100%|██████████| 10000/10000 [01:21<00:00, 123.30it/s, 1 steps of size 1.00e-02. acc. prob=1.00]


132.43018
