In [1]:
import pickle
import os
import argparse
import torch
from jax import random
import json
import datetime
from src.losses import sse_loss, cross_entropy_loss
from src.helper import calculate_exact_ggn, tree_random_normal_like
from src.sampling.predictive_samplers import sample_predictive, sample_hessian_predictive
from jax import numpy as jnp
import jax
from jax import flatten_util
import matplotlib.pyplot as plt
import tree_math as tm
from src.laplace.diagonal import hutchinson_diagonal

In [2]:
def f(x):
    """Evaluate sin(5x + 1) element-wise.

    Presumably the target curve behind the synthetic regression data; it is
    not called anywhere in the visible cells. Additional terms
    (cos(25x + 1), exp(0.1x) and a constant offset of 5) are currently
    switched off.
    """
    phase = 5 * x + 1
    return jnp.sin(phase)
# Load the trained checkpoint. The context manager guarantees the file handle
# is closed even if unpickling fails.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point this at checkpoints produced by this project.
with open("../checkpoints/syntetic_regression.pickle", "rb") as checkpoint_file:
    param_dict = pickle.load(checkpoint_file)

params = param_dict['params']
alpha = param_dict['alpha']
rho = param_dict['rho']

# Training artefacts stored alongside the parameters; look the sub-dict up once.
train_stats = param_dict["train_stats"]
x_train, y_train = train_stats['x_train'], train_stats['y_train']
x_val, y_val = train_stats['x_val'], train_stats['y_val']
model = train_stats['model']
D = train_stats['n_params']  # bound to n_params in the next cell

In [3]:
# PRNG key handed to the Hutchinson estimator in the regression check below.
sample_key = jax.random.PRNGKey(100)


def model_fn(params, x):
    """Apply the model to one input: add a leading axis of size 1, then strip it from the output."""
    batched_x = x[None, ...]
    return model.apply(params, batched_x)[0]


n_params = D
def sse_loss(preds, y):
    """Return half the sum of squared differences between ``preds`` and ``y``.

    Defined locally, so from this cell onward it replaces the ``sse_loss``
    imported from ``src.losses`` at the top of the notebook.
    """
    return 0.5 * jnp.sum(jnp.square(preds - y))

# Exact GGN from the project helper, built from the locally defined sse_loss
# and the single-example model_fn over the training data.
# NOTE(review): presumably a dense (n_params x n_params) matrix -- confirm
# against calculate_exact_ggn.
ggn = calculate_exact_ggn(sse_loss, model_fn, params, x_train, y_train, n_params)
# Reference diagonal that the estimates in the following cells are compared to.
true_diag = jnp.diag(ggn)

In [4]:
# Hutchinson estimate of the GGN diagonal for the regression model, compared
# against the reference diagonal `true_diag`.
model_fn = model.apply  # replaces the single-example wrapper defined earlier
gvp_batch_size = 25
n_samples = 10000

# Only whole chunks are kept: trailing rows that do not fill a chunk are dropped.
N = x_train.shape[0] // gvp_batch_size
chunked_shape = (N, gvp_batch_size) + x_train.shape[1:]
data_array = x_train[: N * gvp_batch_size].reshape(chunked_shape)

diag_hutch_raw = hutchinson_diagonal(
    model_fn,
    params,
    gvp_batch_size,
    n_samples,
    sample_key,
    data_array,
    "regression",
    num_levels=5,
    computation_type="serial",
)
# Flatten to one vector so it can be subtracted from true_diag.
diag_hutch, _ = jax.flatten_util.ravel_pytree(diag_hutch_raw)

# Relative L2 error of the estimate.
relative_error = jnp.linalg.norm(diag_hutch - true_diag) / jnp.linalg.norm(true_diag)
print(relative_error)

0.014148062


In [5]:
# Inspect the shape of the training inputs.
x_train.shape

(100, 1)

In [6]:
from src.laplace.diagonal import exact_diagonal
# Output dimension handed to exact_diagonal for the regression model.
output_dim = 1
# GGN diagonal from the project's exact_diagonal helper; the result is
# flattened in the next cell before being compared with true_diag.
exact_diag = exact_diagonal(model.apply, params, output_dim, x_train, "regression")

In [7]:
# Flatten the result into a single vector so it lines up with true_diag.
exact_diag, _ = jax.flatten_util.ravel_pytree(exact_diag)
# Relative L2 error against the diagonal of the exact GGN.
print(jnp.linalg.norm(exact_diag - true_diag) / jnp.linalg.norm(true_diag))

7.1693826e-08


### Classification

In [8]:
import optax
from src.models import ConvNet
from src.data import get_mnist
from src.losses import cross_entropy_loss

def accuracy(v, x, y):
    """Mean agreement between the arg-max logit and the arg-max of ``y``.

    ``v`` is passed as the first argument to the notebook-level ``model_fn``,
    which is looked up at call time. Because the arg-max of ``y`` is taken
    over the last axis, labels are expected as per-class rows (e.g. one-hot).
    """
    predicted_class = jnp.argmax(model_fn(v, x), axis=-1)
    target_class = jnp.argmax(y, axis=-1)
    return jnp.mean(predicted_class == target_class)

# Model and data for the classification experiment.
model = ConvNet(10)
batch_size = 100
train_loader, val_loader, _ = get_mnist(batch_size, n_samples_per_class=100)

# Fetch ONE validation batch and read both fields from it. Creating a separate
# iterator per field loads two batches, and pairs images with the wrong labels
# whenever the loader does not yield batches in a fixed order.
val_batch = next(iter(val_loader))
val_img, val_label = val_batch['image'], val_batch['label']

# Initialise the parameters from the shape of one training batch.
params = model.init(random.PRNGKey(0), next(iter(train_loader))['image'])
# Flat view of the freshly initialised parameters; it is recomputed from the
# trained parameters in the exact-GGN cell.
variables, unflatten = jax.flatten_util.ravel_pytree(params)
model_fn = model.apply


def loss_fn(v, x, y):
    """Cross-entropy of the model's logits on inputs ``x`` against labels ``y`` under parameters ``v``."""
    logits = model_fn(v, x)
    return cross_entropy_loss(logits, y)


# Jitted function returning (loss, gradient w.r.t. the parameters).
value_and_grad_fn = jax.jit(jax.value_and_grad(loss_fn))

train_lrate = 1e-2
optimizer = optax.adam(train_lrate)
optimizer_state = optimizer.init(params)
n_epochs = 10

def _train_step(current_params, current_opt_state, img, label):
    """Apply one optimizer update on a mini-batch; return new params, new optimizer state and the batch loss."""
    batch_loss, grad = value_and_grad_fn(current_params, img, label)
    updates, next_opt_state = optimizer.update(grad, current_opt_state)
    return optax.apply_updates(current_params, updates), next_opt_state, batch_loss


for epoch in range(n_epochs):
    for batch in train_loader:
        img, label = batch['image'], batch['label']
        params, optimizer_state, loss = _train_step(params, optimizer_state, img, label)

    # `loss` is the value from the epoch's last mini-batch, not an epoch average.
    acc = accuracy(params, val_img, val_label)
    print(f"Epoch {epoch}, loss {loss:.3f}, accuracy {acc:.3f}")



  self.targets = torch.nn.functional.one_hot(torch.tensor(self.targets), len(classes)).numpy()
  self.pid = os.fork()
  self.pid = os.fork()


Epoch 0, loss 226.290, accuracy 0.160


  self.pid = os.fork()
  self.pid = os.fork()


Epoch 1, loss 214.249, accuracy 0.320


  self.pid = os.fork()
  self.pid = os.fork()


Epoch 2, loss 196.323, accuracy 0.420


  self.pid = os.fork()
  self.pid = os.fork()


Epoch 3, loss 172.447, accuracy 0.440


  self.pid = os.fork()
  self.pid = os.fork()


Epoch 4, loss 145.665, accuracy 0.530


  self.pid = os.fork()
  self.pid = os.fork()


Epoch 5, loss 121.063, accuracy 0.610


  self.pid = os.fork()
  self.pid = os.fork()


Epoch 6, loss 104.405, accuracy 0.610


  self.pid = os.fork()
  self.pid = os.fork()


Epoch 7, loss 93.491, accuracy 0.650


  self.pid = os.fork()
  self.pid = os.fork()


Epoch 8, loss 85.288, accuracy 0.690


  self.pid = os.fork()
  self.pid = os.fork()


Epoch 9, loss 78.926, accuracy 0.700


In [9]:
# Exact GGN for the classifier, accumulated over the training loader.
# The model is re-expressed as a function of one flat parameter vector so the
# Jacobian comes out as a plain matrix.
variables, unflatten = jax.flatten_util.ravel_pytree(params)
model_vec_fn = lambda vec, x: model.apply(unflatten(vec), x)

n_params = len(variables)
GGN = 0
for batch in train_loader:
    img, label = batch['image'], batch['label']
    preds = model_vec_fn(variables, img)
    # Take the sizes from the actual predictions rather than from batch_size
    # and a literal class count, so a batch of a different size (e.g. a
    # shorter final batch) still reshapes correctly.
    b, o = preds.shape
    # Jacobian of the logits w.r.t. the flat parameters: (b, o, n_params).
    J = jax.jacfwd(model_vec_fn, argnums=0)(variables, img)
    # Hessian of the loss w.r.t. the logits: (b, o, b, o).
    H = jax.hessian(cross_entropy_loss, argnums=0)(preds, label)
    J = J.reshape(b * o, n_params)
    H = H.reshape(b * o, b * o)
    # Dense (n_params x n_params) contribution of this batch.
    GGN += J.T @ H @ J

# Reference diagonal for the classification comparisons in the next cells.
exact_diag = jnp.diag(GGN)


  self.pid = os.fork()


  self.pid = os.fork()


In [10]:
# Hutchinson estimate of the GGN diagonal for the classifier, summed over the
# training loader and compared against `exact_diag`.
n_samples = 200
gvp_batch_size = 50
sample_key = random.PRNGKey(0)  # the same key is passed for every mini-batch
# Each loader batch is cut into equal chunks of gvp_batch_size below.
assert batch_size % gvp_batch_size == 0

diag_hutch = 0
for batch in train_loader:
    img, label = batch['image'], batch['label']
    chunk_shape = (-1, gvp_batch_size) + img.shape[1:]
    data_array = jnp.asarray(img).reshape(chunk_shape)
    batch_diag = hutchinson_diagonal(
        model.apply,
        params,
        gvp_batch_size,
        n_samples,
        sample_key,
        data_array,
        "classification",
        num_levels=2,
        computation_type="parallel",
    )
    # Flatten this batch's estimate and add it to the running total.
    diag_hutch += jax.flatten_util.ravel_pytree(batch_diag)[0]

hutchinson_rel_error = jnp.linalg.norm(exact_diag - diag_hutch) / jnp.linalg.norm(exact_diag)
print("Hutchinson Relative Error:", hutchinson_rel_error)


  self.pid = os.fork()


  self.pid = os.fork()


Hutchinson Relative Error: 0.07119679


In [11]:
# GGN diagonal from the project's exact_diagonal helper, summed over the
# training loader and compared against `exact_diag`.
output_dim = 10
diag_autograd = 0
for batch in train_loader:
    img = jnp.asarray(batch['image'])
    label = jnp.asarray(batch['label'])  # not passed to exact_diagonal below
    batch_diag = exact_diagonal(model.apply, params, output_dim, img, "classification")
    diag_autograd += jax.flatten_util.ravel_pytree(batch_diag)[0]

autograd_rel_error = jnp.linalg.norm(exact_diag - diag_autograd) / jnp.linalg.norm(exact_diag)
print("Autograd Relative Error:", autograd_rel_error)


  self.pid = os.fork()


  self.pid = os.fork()


Autograd Relative Error: 9.113158e-08


### Samples

In [16]:
from src.sampling.diagonal_lapalce_sampling import sample_exact_diagonal, sample_hutchinson

# Draw posterior samples with the exact-diagonal sampler. The evaluation cell
# that follows maps over the leading axis of `posterior`, passing each entry
# through `unflatten` as a flat parameter vector.
# NOTE(review): the positional numbers are opaque here; 30 is presumably the
# number of samples -- confirm against sample_exact_diagonal's signature.
posterior, metrics = sample_exact_diagonal(model.apply, params, 30, 1.0, train_loader, 0,10)


  self.pid = os.fork()
  self.pid = os.fork()
                                                       

In [17]:
# Score the MAP parameters and the posterior samples on one validation batch.
params_vec, unflatten = jax.flatten_util.ravel_pytree(params)
model_fn_vec = lambda p, x: model.apply(unflatten(p), x)

# Fetch ONE batch and read both fields from it. Creating a separate iterator
# per field loads two batches, and pairs images with the wrong labels whenever
# the loader does not yield batches in a fixed order.
test_batch = next(iter(val_loader))
x_test, y_test = test_batch['image'], test_batch['label']
y_true = jnp.argmax(y_test, axis=-1)

# Map over the leading axis of `posterior`: one flat parameter vector per sample.
logits = jax.vmap(model_fn_vec, in_axes=(0, None))(posterior, x_test)
preds = jax.nn.softmax(logits, axis=-1)

acc_map = jnp.mean(jnp.argmax(model.apply(params, x_test), axis=-1) == y_true)
# Accuracy of each posterior sample, then averaged over samples.
per_sample_acc = jax.vmap(lambda y_pred: jnp.mean(jnp.argmax(y_pred, axis=-1) == y_true))(preds)
acc_posterior = per_sample_acc.mean()

In [18]:
# Report the MAP accuracy next to the mean accuracy over posterior samples.
print("MAP Accuracy", acc_map)
print("Posterior Accuracy", acc_posterior)

MAP Accuracy 0.7
Posterior Accuracy 0.64166665


In [19]:
# Draw posterior samples with the Hutchinson-based sampler; this overwrites the
# `posterior` and `metrics` produced by the exact-diagonal run.
# NOTE(review): the positional numbers are opaque here -- confirm their meaning
# against sample_hutchinson's signature.
posterior, metrics = sample_hutchinson(model.apply, params, 30, 1.0, 50, train_loader, 0, 3, 200, "classification", computation_type="parallel")


  self.pid = os.fork()
  self.pid = os.fork()
                                                       

In [20]:
# Score the MAP parameters and the posterior samples on one validation batch.
params_vec, unflatten = jax.flatten_util.ravel_pytree(params)
model_fn_vec = lambda p, x: model.apply(unflatten(p), x)

# Fetch ONE batch and read both fields from it. Creating a separate iterator
# per field loads two batches, and pairs images with the wrong labels whenever
# the loader does not yield batches in a fixed order.
test_batch = next(iter(val_loader))
x_test, y_test = test_batch['image'], test_batch['label']
y_true = jnp.argmax(y_test, axis=-1)

# Map over the leading axis of `posterior`: one flat parameter vector per sample.
logits = jax.vmap(model_fn_vec, in_axes=(0, None))(posterior, x_test)
preds = jax.nn.softmax(logits, axis=-1)

acc_map = jnp.mean(jnp.argmax(model.apply(params, x_test), axis=-1) == y_true)
# Accuracy of each posterior sample, then averaged over samples.
per_sample_acc = jax.vmap(lambda y_pred: jnp.mean(jnp.argmax(y_pred, axis=-1) == y_true))(preds)
acc_posterior = per_sample_acc.mean()

In [21]:
# Report the MAP accuracy next to the mean accuracy over the posterior samples
# from the most recent sampling run.
print("MAP Accuracy", acc_map)
print("Posterior Accuracy", acc_posterior)

MAP Accuracy 0.7
Posterior Accuracy 0.69633335
