In [1]:
import pickle
import os
import argparse
import torch
from jax import random
import json
import datetime
from src.losses import sse_loss, cross_entropy_loss
from src.helper import calculate_exact_ggn, tree_random_normal_like
from src.sampling.predictive_samplers import sample_predictive, sample_hessian_predictive
from jax import numpy as jnp
import jax
from jax import flatten_util
import matplotlib.pyplot as plt
import tree_math as tm
from src.laplace.last_layer.extract_last_layer import last_layer_ggn

In [2]:
def f(x):
    return jnp.sin(5 * x + 1) #+ jnp.cos(25 * x + 1) + jnp.exp(0.1 * x) + 5
param_dict = pickle.load(open("../checkpoints/syntetic_regression.pickle", "rb"))
params = param_dict['params']
alpha = param_dict['alpha']
rho = param_dict['rho']
x_train, y_train, x_val, y_val, model, D = param_dict["train_stats"]['x_train'],param_dict["train_stats"]['y_train'],param_dict["train_stats"]['x_val'],param_dict["train_stats"]['y_val'],param_dict["train_stats"]['model'], param_dict["train_stats"]['n_params']

In [3]:
sample_key = jax.random.PRNGKey(100)
model_fn = lambda params, x: model.apply(params, x[None, ...])[0]
n_params = D 
def sse_loss(preds, y):
    residual = preds - y
    return 0.5 * jnp.sum(residual**2)

ggn = calculate_exact_ggn(sse_loss, model_fn, params, x_train, y_train, n_params)


In [4]:
leafs, _ = jax.tree_util.tree_flatten(params)
N_llla = len(leafs[-1]) + len(leafs[-2])
ggn_ll = ggn[-N_llla:, -N_llla:]

In [5]:
ggn_ll_2 = last_layer_ggn(model.apply, params, x_train, "regression")
jnp.allclose(ggn_ll_2, ggn_ll)

Array(True, dtype=bool)

### Classification

In [6]:
import optax
from src.models import ConvNet
from src.data import get_mnist
from src.losses import cross_entropy_loss

def accuracy(v, x, y):
    logits = model_fn(v, x)
    return jnp.mean(jnp.argmax(logits, axis=-1) == jnp.argmax(y, axis=-1))

model = ConvNet(10)
batch_size = 100
train_loader, val_loader, _ = get_mnist(batch_size, n_samples_per_class=100)
val_img, val_label = next(iter(val_loader))['image'], next(iter(val_loader))['label']
params = model.init(random.PRNGKey(0), next(iter(train_loader))['image'])
variables, unflatten = jax.flatten_util.ravel_pytree(params)
# model_fn = lambda vec, x: model.apply(unflatten(vec), x)
model_fn = model.apply
def loss_fn(v, x, y):
    logits = model_fn(v, x)
    return cross_entropy_loss(logits, y)

value_and_grad_fn = jax.jit(jax.value_and_grad(loss_fn))

train_lrate = 1e-2
optimizer = optax.adam(train_lrate)
# optimizer_state = optimizer.init(variables)
optimizer_state = optimizer.init(params)
n_epochs = 10

for epoch in range(n_epochs):
    for batch in train_loader:
        img, label = batch['image'], batch['label']
        loss, grad = value_and_grad_fn(params, img, label)
        updates, optimizer_state = optimizer.update(grad, optimizer_state)
        params = optax.apply_updates(params, updates)
    
    acc = accuracy(params, val_img, val_label)
    print(f"Epoch {epoch}, loss {loss:.3f}, accuracy {acc:.3f}")



  self.targets = torch.nn.functional.one_hot(torch.tensor(self.targets), len(classes)).numpy()
  self.pid = os.fork()
  self.pid = os.fork()


Epoch 0, loss 226.290, accuracy 0.160


  self.pid = os.fork()
  self.pid = os.fork()


Epoch 1, loss 214.249, accuracy 0.320


  self.pid = os.fork()
  self.pid = os.fork()


Epoch 2, loss 196.323, accuracy 0.420


  self.pid = os.fork()
  self.pid = os.fork()


Epoch 3, loss 172.447, accuracy 0.440


  self.pid = os.fork()
  self.pid = os.fork()


Epoch 4, loss 145.665, accuracy 0.530


  self.pid = os.fork()
  self.pid = os.fork()


Epoch 5, loss 121.063, accuracy 0.610


  self.pid = os.fork()
  self.pid = os.fork()


Epoch 6, loss 104.405, accuracy 0.610


  self.pid = os.fork()
  self.pid = os.fork()


Epoch 7, loss 93.491, accuracy 0.650


  self.pid = os.fork()
  self.pid = os.fork()


Epoch 8, loss 85.288, accuracy 0.690


  self.pid = os.fork()
  self.pid = os.fork()


Epoch 9, loss 78.926, accuracy 0.700


In [10]:
#Exact GGN:
variables, unflatten = jax.flatten_util.ravel_pytree(params)
model_vec_fn = lambda vec, x: model.apply(unflatten(vec), x)

n_params = len(variables)
b = batch_size
o = 10
GGN = 0
for batch in train_loader:
    img, label = batch['image'], batch['label']
    preds = model_vec_fn(variables, img)
    pred = jax.nn.softmax(preds, axis=1)
    pred = jax.lax.stop_gradient(pred)
    D = jax.vmap(jnp.diag)(pred)
    H = jnp.einsum("bo, bi->boi", pred, pred)
    H = D - H # B, O, O

    J = jax.jacfwd(model_vec_fn, argnums=0)(variables, img)
    # J = J.reshape(b * o, n_params)
    # H = H.reshape(b * o, b * o)
    GGN = jnp.einsum("mob, boo, bon->mn", J.T, H, J)
    # GGN += J.T @ H @ J

leafs, _ = jax.tree_util.tree_flatten(params)
N_llla = len(leafs[-1]) + len(leafs[-2])
ggn_ll = ggn[-N_llla:, -N_llla:]

  self.pid = os.fork()


  self.pid = os.fork()


In [11]:
params_vec, unflatten_fn = jax.flatten_util.ravel_pytree(params)
def model_apply_vec(params_vectorized, x):
    return model_fn(unflatten_fn(params_vectorized), x)

def last_layer_model_fn(last_params_vec, first_params, x):
    first_params = jax.lax.stop_gradient(first_params)
    vectorized_params = jnp.concatenate([first_params, last_params_vec])
    return model_apply_vec(vectorized_params, x)

params_ll = params_vec[-N_llla:]
ggn_ll_2 = 0
for batch in train_loader:
    img, label = batch['image'], batch['label']
    img = jnp.asarray(img)
    pred_vec = model_apply_vec(params_vec, img)
    J_ll = jax.jacfwd(last_layer_model_fn, argnums=0)(params_ll, params_vec[:-N_llla], img)
    H = jax.hessian(cross_entropy_loss, argnums=0)(pred_vec, label)
    H = H.reshape(b * o, b * o)
    J_ll = J_ll.reshape(b * o, N_llla)
    ggn_ll_2 += J_ll.T @ H @ J_ll
    # ggn_ll_2 += last_layer_ggn(model.apply, params, img, "classification")

  self.pid = os.fork()


  self.pid = os.fork()


In [12]:
ggn_ll_2 = 0
for batch in train_loader:
    img, label = batch['image'], batch['label']
    img = jnp.asarray(img)
    ggn_ll_2 += last_layer_ggn(model.apply, params, img, "classification")

  self.pid = os.fork()


  self.pid = os.fork()


In [13]:
jnp.allclose(ggn_ll_2, ggn_ll)

Array(False, dtype=bool)

In [None]:
print(jnp.linalg.norm(ggn_ll - ggn_ll_2)/ jnp.linalg.norm(ggn_ll))

1.009508


### Samples

In [17]:
from src.sampling.last_layer_sampling import sample_last_layer

posterior, metrics = sample_last_layer(model.apply, params, 30, 1.0, train_loader, 0, "classification")


  self.pid = os.fork()
  self.pid = os.fork()


In [19]:
posterior.shape

(30, 26)

In [29]:
x_test = next(iter(val_loader))['image']
y_test = next(iter(val_loader))['label']
def last_layer_predictive(x_test, posterior, params):
    leafs, _ = jax.tree_util.tree_flatten(params)
    N_llla = len(leafs[-1]) + len(leafs[-2])
    params_vec, _ = jax.flatten_util.ravel_pytree(params)
    logits = jax.vmap(lambda p: last_layer_model_fn(p, params_vec[:-N_llla], x_test))(posterior)
    return logits

  self.pid = os.fork()


  self.pid = os.fork()


In [30]:
logits = last_layer_predictive(x_test, posterior, params)
preds = jax.nn.softmax(logits, axis=-1)

In [31]:
acc_map = jnp.mean(jnp.argmax(model.apply(params, x_test), axis=-1) == jnp.argmax(y_test, axis=-1))

In [33]:
acc_posterior = jnp.array(jax.vmap(lambda y_pred: jnp.mean(jnp.argmax(y_pred, axis=-1) == jnp.argmax(y_test, axis=-1)))(preds)).mean()

In [34]:
acc_posterior

Array(0.6526666, dtype=float32)

In [36]:
acc_map

Array(0.7, dtype=float32)