In [5]:
import time
import argparse
import jax
import matplotlib.pyplot as plt
import optax
import matfree
import tree_math as tm
from flax import linen as nn
from jax import nn as jnn
from jax import numpy as jnp
from jax import random, jit
import pickle
from src.losses import mse_loss
from src.helper import calculate_exact_ggn, tree_random_normal_like
from src.sampling.predictive_samplers import sample_predictive, sample_hessian_predictive
from jax import flatten_util
import matplotlib.pyplot as plt

import torch


In [6]:
from src.data.torch_datasets import MNIST, numpy_collate_fn

## OOD datasets

In [7]:
from src.data.datasets import (
    get_rotated_mnist_loaders,
    get_rotated_fmnist_loaders,
    get_rotated_cifar_loaders,
    load_corrupted_cifar10,
    load_corrupted_cifar10_per_type,
    get_mnist_ood_loaders,
    get_cifar10_ood_loaders,
    get_cifar10_train_set,
)

In [8]:
# Experiment configuration. A non-positive `train_samples` selects the whole
# training set; a non-positive `test_batch_size` means one batch spanning the
# whole test set.
train_samples = 30  # 1000
classes_train = list(range(10))
n_classes = 10
batch_size = 20  # 256
test_batch_size = 256

use_subset = train_samples > 0
data_train = MNIST(
    path_root="/work3/hroy/data/",
    train=True,
    n_samples=train_samples if use_subset else None,
    cls=classes_train,
)
data_test = MNIST(path_root="/work3/hroy/data/", train=False, cls=classes_train)

# NOTE(review): N = train_samples * n_classes presumes `n_samples` is a
# per-class count inside MNIST -- confirm against src.data.torch_datasets.
N = train_samples * n_classes if use_subset else len(data_train)
N_test = len(data_test)
if test_batch_size <= 0:
    test_batch_size = len(data_test)

# Number of full batches per pass (both loaders drop the last partial batch).
n_test_batches = int(N_test / test_batch_size)
n_batches = int(N / batch_size)

loader_kwargs = dict(shuffle=True, collate_fn=numpy_collate_fn, drop_last=True)
train_loader = torch.utils.data.DataLoader(data_train, batch_size=batch_size, **loader_kwargs)
valid_loader = torch.utils.data.DataLoader(data_test, batch_size=test_batch_size, **loader_kwargs)

In [9]:
class ConvNet(nn.Module):
    """Small two-stage convolutional classifier returning `output_dim` logits.

    Each stage is a 3x3 stride-2 convolution with 4 features, a tanh and a
    2x2 stride-2 max-pool. The flattened features feed one dense layer.
    """

    output_dim: int = 10

    @nn.compact
    def __call__(self, x):
        # Any input that is not 4-D gets a leading axis of size 1.
        if x.ndim != 4:
            x = jnp.expand_dims(x, 0)
        # Move axis 1 to the end (presumably channels-first -> channels-last;
        # confirm against the dataset's image layout).
        h = jnp.transpose(x, (0, 2, 3, 1))
        for _ in range(2):
            h = nn.Conv(features=4, kernel_size=(3, 3), strides=(2, 2), padding=1)(h)
            h = nn.tanh(h)
            h = nn.max_pool(h, window_shape=(2, 2), strides=(2, 2))
        flat = h.reshape((h.shape[0], -1))
        return nn.Dense(features=self.output_dim)(flat)

def compute_num_params(pytree):
    """Return the total element count over all leaves of `pytree`.

    Leaves without a `size` attribute (e.g. plain Python scalars) count as 0.
    """
    total = 0
    for leaf in jax.tree_util.tree_leaves(pytree):
        total += getattr(leaf, "size", 0)
    return total


# Peek at one training batch: it supplies the example input used to trace the
# parameter shapes and the label width that sizes the output layer.
batch = next(iter(train_loader))
x_init, y_init = batch["image"], batch["label"]
# Labels are treated as one-hot elsewhere in this notebook (the loss and the
# accuracy both reduce over the last label axis), so that axis is the class count.
output_dim = y_init.shape[-1]

# Size the classifier from the data instead of relying on ConvNet's default.
model = ConvNet(output_dim=output_dim)

key, split_key = random.split(jax.random.PRNGKey(0))
params = model.init(key, x_init)
alpha = 1.
# Element-wise gradient clipping to [-1, 1], followed by Adam (lr = 1e-2).
optim = optax.chain(
    optax.clip(1.),
    optax.adam(1e-2),
)
opt_state = optim.init(params)
n_params = compute_num_params(params)
n_epochs = 100

In [10]:
def cross_entropy_loss(preds, y, rho=1.0):
    """Summed cross-entropy between logits and one-hot labels.

    preds: (n_samples, n_classes) logits
    y: (n_samples, n_classes) one-hot labels
    rho: factor the logits are multiplied by before the log-softmax

    Returns a scalar: the negative sum over all samples (no averaging).
    """
    log_probs = jax.nn.log_softmax(preds * rho, axis=-1)
    per_sample = jnp.sum(log_probs * y, axis=-1)
    return -jnp.sum(per_sample)

def accuracy(params, model, batch_x, batch_y):
    """Count the inputs whose arg-max prediction equals the label's arg-max.

    Returns a count, not a fraction; callers divide by the number of examples.
    """
    logits = model.apply(params, batch_x)
    predicted = jnp.argmax(logits, axis=-1)
    target = jnp.argmax(batch_y, axis=-1)
    return jnp.sum(predicted == target)


def map_loss(
    params,
    model,
    x_batch,
    y_batch,
    alpha,
    n_params: int,
    N_datapoints_max: int,
):
    """Batch-mean cross-entropy of `model` on (`x_batch`, `y_batch`).

    A Gaussian log-prior with precision `alpha` is evaluated, but it enters
    the result with weight zero, so the returned value is the data term alone
    whenever the prior term is finite. `N_datapoints_max` is accepted for
    interface compatibility and not used.
    """
    batch_len = x_batch.shape[0]
    weights = tm.Vector(params)
    rho = 1.

    logits = model.apply(params, x_batch)
    loglike_loss = 1/batch_len * cross_entropy_loss(logits, y_batch, rho)

    sq_norm = weights @ weights
    log_prior_term = (
        -n_params / 2 * jnp.log(2 * jnp.pi)
        - (1 / 2) * alpha * sq_norm
        + n_params / 2 * jnp.log(alpha)
    )
    # The prior is switched off through the zero weight.
    return loglike_loss - 0. * log_prior_term

def make_step(params, alpha, opt_state, x, y):
    """Apply one optimiser update computed from a single batch.

    Differentiates `map_loss` with respect to `params`, passes the gradients
    through the module-level `optim`, and returns
    `(loss, updated_params, updated_opt_state)`.
    """
    loss_and_grad = jax.value_and_grad(map_loss)
    batch_loss, gradients = loss_and_grad(params, model, x, y, alpha, n_params, N)
    updates, new_opt_state = optim.update(gradients, opt_state)
    new_params = optax.apply_updates(params, updates)
    return batch_loss, new_params, new_opt_state

# Compiled training step used by the training loop.
jit_make_step = jax.jit(make_step)



In [11]:
for epoch in range(1, n_epochs + 1):
    # epoch_loss accumulates the per-batch mean losses (a sum, not an average);
    # epoch_accuracy accumulates correct-prediction counts.
    epoch_loss, epoch_accuracy = 0, 0
    start_time = time.time()
    for _, batch in zip(range(n_batches), train_loader):
        X, y = batch["image"], batch["label"]
        B = X.shape[0]
        # Advances `split_key`; `train_key` itself is not consumed by the step.
        train_key, split_key = random.split(split_key)

        loss, params, opt_state = jit_make_step(params, alpha, opt_state, X, y)
        epoch_loss += loss.item()

        # Measured with the parameters obtained after this batch's update.
        epoch_accuracy += accuracy(params, model, X, y).item()

    # B is the size of the last batch; the loader drops partial batches.
    epoch_accuracy /= (n_batches * B)
    epoch_time = time.time() - start_time
    summary = f"epoch={epoch}, loss={epoch_loss:.3f}, , accuracy={epoch_accuracy:.2f}, alpha={alpha:.2f}, time={epoch_time:.3f}s"
    print(summary)


epoch=1, loss=34.398, , accuracy=0.19, alpha=1.00, time=1.365s
epoch=2, loss=32.283, , accuracy=0.33, alpha=1.00, time=0.249s
epoch=3, loss=29.460, , accuracy=0.43, alpha=1.00, time=0.251s
epoch=4, loss=26.536, , accuracy=0.51, alpha=1.00, time=0.242s
epoch=5, loss=23.831, , accuracy=0.57, alpha=1.00, time=0.241s
epoch=6, loss=21.496, , accuracy=0.59, alpha=1.00, time=0.242s
epoch=7, loss=19.704, , accuracy=0.64, alpha=1.00, time=0.256s
epoch=8, loss=18.153, , accuracy=0.65, alpha=1.00, time=0.242s
epoch=9, loss=16.757, , accuracy=0.69, alpha=1.00, time=0.241s
epoch=10, loss=15.860, , accuracy=0.70, alpha=1.00, time=0.241s
epoch=11, loss=14.923, , accuracy=0.74, alpha=1.00, time=0.247s
epoch=12, loss=14.055, , accuracy=0.75, alpha=1.00, time=0.239s
epoch=13, loss=13.380, , accuracy=0.77, alpha=1.00, time=0.246s
epoch=14, loss=12.759, , accuracy=0.76, alpha=1.00, time=0.271s
epoch=15, loss=12.163, , accuracy=0.79, alpha=1.00, time=0.239s
epoch=16, loss=11.796, , accuracy=0.79, alpha=1.0

In [12]:
# Materialise each split as one batch (batch size = split size) for the
# full-batch GGN and predictive computations in the following cells.
full_batch_opts = dict(shuffle=True, collate_fn=numpy_collate_fn, drop_last=True)

sampling_train_loader = torch.utils.data.DataLoader(data_train, batch_size=N, **full_batch_opts)
data = next(iter(sampling_train_loader))
x_train, y_train = jnp.array(data["image"]), jnp.array(data["label"])

sampling_val_loader = torch.utils.data.DataLoader(data_test, batch_size=N_test, **full_batch_opts)
data = next(iter(sampling_val_loader))
x_val, y_val = jnp.array(data["image"]), jnp.array(data["label"])

# Sampling configuration.
sample_key = jax.random.PRNGKey(0)
n_posterior_samples = 200
num_iterations = 1
n_sample_batch_size = 1
n_sample_batches = N // n_sample_batch_size

## Laplace approximation

In [13]:
from src.sampling.exact_ggn import exact_ggn_laplace
from src.sampling.laplace_ode import ode_ggn
from src.sampling.lanczos_diffusion import lanczos_diffusion


In [14]:
def _model_fn(params, x):
    """Apply the model to one unbatched input and return its single output row."""
    batched = x[None, ...]
    return model.apply(params, batched)[0]


# GGN of the cross-entropy loss at the trained parameters over the training set.
ggn = calculate_exact_ggn(cross_entropy_loss, _model_fn, params, x_train, y_train, n_params)


In [15]:
# Eigendecomposition of the GGN. `eigh` returns eigenvalues in ascending
# order, so the largest-eigenvalue directions are the last columns of `eigvecs`.
eigvals, eigvecs = jnp.linalg.eigh(ggn)

In [16]:
alpha = 1.0
rank = 100


def ggn_lr_vp(v):
    """Map a `rank`-sized vector through the top-`rank` eigen-directions.

    Computes `U_r diag(1 / sqrt(lambda_r + alpha)) v`, where `U_r` and
    `lambda_r` are the last `rank` eigenvectors/eigenvalues (the largest,
    since `eigh` sorts ascending). `eigvecs`, `eigvals`, `alpha` and `rank`
    are read from module scope at call time.
    """
    inv_sqrt = 1 / jnp.sqrt(eigvals[-rank:] + alpha)
    # Broadcasting scales the columns directly; this equals
    # U_r @ diag(inv_sqrt) without building the dense diagonal matrix.
    return (eigvecs[:, -rank:] * inv_sqrt) @ v


def ggn_vp(v):
    """Map a full-size vector through `U diag(1 / sqrt(lambda + alpha))`.

    Uses every eigenpair of the GGN; `eigvecs`, `eigvals` and `alpha` are
    read from module scope at call time.
    """
    inv_sqrt = 1 / jnp.sqrt(eigvals + alpha)
    return (eigvecs * inv_sqrt) @ v

In [17]:
n_posterior_samples = 20
D = compute_num_params(params)
sample_key = jax.random.PRNGKey(0)
# One standard-normal draw of length D per posterior sample.
eps = jax.random.normal(sample_key, (n_posterior_samples, D))
p0_flat, unravel_func_p = flatten_util.ravel_pytree(params)
# NOTE(review): despite the name, `var` multiplies the noise directly, i.e. it
# acts as a standard deviation -- confirm which one is intended.
var = 0.1


def get_posteriors(single_eps):
    """Turn one noise vector into three parameter samples around the trained weights.

    Returns `(low_rank, full, isotropic)` pytrees: the first uses the leading
    `rank` noise entries with `ggn_lr_vp`, the second the whole vector with
    `ggn_vp`, the third the noise scaled by `var`.
    """
    lr_offset = ggn_lr_vp(single_eps[:rank])
    full_offset = ggn_vp(single_eps)
    iso_offset = var * single_eps
    return tuple(
        unravel_func_p(offset + p0_flat)
        for offset in (lr_offset, full_offset, iso_offset)
    )


lr_posterior_samples, posterior_samples, isotropic_posterior_samples = jax.vmap(get_posteriors)(eps)

## Sampled Laplace

In [18]:
# Predictions on the validation inputs for the low-rank posterior samples.
# Presumably one model evaluation per sample -- confirm against sample_predictive.
predictive = sample_predictive(lr_posterior_samples, params, model, x_val, False, "Pytree")


In [19]:
# Sanity check on the layout of the stacked predictions.
predictive.shape

(20, 10000, 10)

In [20]:
# Baseline: fraction of validation examples classified correctly by the
# trained `params` (`accuracy` returns a count, hence the division).
accuracy(params, model, x_val, y_val)/x_val.shape[0]

Array(0.6786, dtype=float32)

In [21]:
def accuracy_preds(preds, batch_y):
    """Count rows where the arg-max of `preds` equals the arg-max of `batch_y`."""
    hits = jnp.argmax(preds, axis=-1) == jnp.argmax(batch_y, axis=-1)
    return jnp.sum(hits)
# Accuracy of every posterior sample: map over the leading (sample) axis of
# `predictive`, share the labels, and turn counts into fractions.
per_sample_hits = jax.vmap(accuracy_preds, in_axes=(0, None))(predictive, y_val)
accuracies = per_sample_hits / x_val.shape[0]

In [22]:
# Validation accuracy averaged over the posterior samples.
jnp.mean(accuracies)

Array(0.64703, dtype=float32)

## Lanczos Diffusion

In [9]:
from src.sampling.low_rank import lanczos_tridiag
from typing import Callable, Literal, Optional
from functools import partial

def get_gvp_fun(
    model_fn: Callable,
    loss_fn: Callable,
    params,
    x,
    y
  ) -> Callable:
    """Build a flat-vector GGN matrix-vector product at `params`.

    Args:
        model_fn: called as `model_fn(params, inputs)`; each example is fed
            with a leading axis of size 1.
        loss_fn: called as `loss_fn(preds, targets)`; it is differentiated
            twice with respect to `preds`.
        params: parameter pytree at which the GGN is linearised.
        x, y: inputs and targets, scanned over their leading axis.

    Returns:
        `matvec(v)`, mapping a flat vector with as many entries as `params`
        to `sum_i J_i^T H_i J_i v`, flattened the same way. `J_i` is the
        model Jacobian and `H_i` the loss Hessian with respect to the model
        output, both for example `i`.
    """
    # Imported locally so the function does not depend on `jax.flatten_util`
    # having been loaded as an attribute of `jax` by some other import.
    from jax.flatten_util import ravel_pytree

    def gvp(eps):
        def scan_fun(carry, batch):
            x_, y_ = batch

            def fn(p):
                return model_fn(p, x_[None, :])

            def loss_fn_(preds):
                return loss_fn(preds, y_)

            # J v: push the tangent through the model.
            out, Je = jax.jvp(fn, (params,), (eps,))
            # H (J v): forward-over-reverse Hessian-vector product of the loss.
            _, HJe = jax.jvp(jax.jacrev(loss_fn_, argnums=0), (out,), (Je,))
            # J^T (H J v): pull the result back to parameter space.
            _, vjp_fn = jax.vjp(fn, params)
            value = vjp_fn(HJe)[0]
            # jax.tree_util.tree_map replaces the deprecated jax.tree_map alias.
            return jax.tree_util.tree_map(jnp.add, carry, value), None

        init_value = jax.tree_util.tree_map(jnp.zeros_like, params)
        # Accumulate per-example contributions without materialising Jacobians.
        return jax.lax.scan(scan_fun, init_value, (x, y))[0]

    _, unravel_func_p = ravel_pytree(params)

    def matvec(v_like_params):
        p_unravelled = unravel_func_p(v_like_params)
        ggn_vp = gvp(p_unravelled)
        f_eval, _ = ravel_pytree(ggn_vp)
        return f_eval

    return matvec

# GGN-vector product at the trained `params` over the full training batch.
# NOTE(review): this uses `mse_loss`, whereas training used cross-entropy --
# confirm this is intended.
gvp = get_gvp_fun(model.apply, mse_loss, params, x_train, y_train)

In [11]:
# Random-walk configuration: one (n_steps, rank) noise path per sample.
n_steps = 20
n_samples = 50
alpha = 10.0
rank = 7
eps = jax.random.normal(sample_key, (n_samples, n_steps, rank))
p0_flat, unravel_func_p = jax.flatten_util.ravel_pytree(params)
# Start vector handed to the Lanczos routine.
v0 = jnp.ones(n_params) * 5


@jax.jit
def rw_nonker(single_eps_path):
    """Random walk in flat parameter space starting from the trained weights.

    Every step rebuilds the GGN matvec at the current point, calls
    `lanczos_tridiag` on it, and moves along the returned vectors weighted by
    that step's noise row, scaled by 1/sqrt(alpha) and 1/sqrt(n_steps).
    Returns the final point as a parameter pytree.

    NOTE(review): the loop runs n_steps - 1 iterations, so the last noise row
    is never used although increments are scaled by 1/sqrt(n_steps) --
    confirm whether the upper bound should be n_steps.
    """
    def body_fun(n, res):
        matvec = get_gvp_fun(model.apply, mse_loss, unravel_func_p(res), x_train, y_train)
        _, basis = lanczos_tridiag(matvec, v0, rank - 1)
        lr_sample = 1/jnp.sqrt(alpha) * basis @ single_eps_path[n]
        return res + 1/jnp.sqrt(n_steps) * lr_sample

    final_flat = jax.lax.fori_loop(0, n_steps - 1, body_fun, p0_flat)
    return unravel_func_p(final_flat)


# One independent walk per noise path.
nonker_posterior_samples = jax.vmap(rw_nonker)(eps)

In [13]:
# Random-walk configuration: one (n_steps, n_params) noise path per sample.
n_steps = 20
n_samples = 50
alpha = 10.0
rank = 10
eps = jax.random.normal(sample_key, (n_samples, n_steps, n_params))
p0_flat, unravel_func_p = jax.flatten_util.ravel_pytree(params)
# Start vector handed to the Lanczos routine.
v0 = jnp.ones(n_params)*5
# Shift added to the GGN matvec before the Lanczos call.
delta = 1.0


@jax.jit
def rw_ker(single_eps_path):
    """Random walk driven by the residual of a Lanczos approximation.

    Every step builds the shifted matvec `A v = GGN v + delta * v` at the
    current point, obtains `(vals, vecs)` from `lanczos_tridiag`, and moves
    along `(A eps - vecs diag(vals) vecs^T eps) / delta`, scaled by
    1/sqrt(alpha) and 1/sqrt(n_steps). Returns the final point as a
    parameter pytree.

    NOTE(review): the loop runs n_steps - 1 iterations, so the last noise row
    is never used although increments are scaled by 1/sqrt(n_steps) --
    confirm whether the upper bound should be n_steps.
    """
    def body_fun(n, res):
        ggn_matvec = get_gvp_fun(model.apply, mse_loss, unravel_func_p(res), x_train, y_train)
        shifted_matvec = lambda v: ggn_matvec(v) + delta * v
        lanczos_vals, lanczos_vecs = lanczos_tridiag(shifted_matvec, v0, rank - 1)
        noise = single_eps_path[n]
        # Evaluate right-to-left so every intermediate stays a vector; the
        # left-to-right chain built an (n_params x n_params) matrix per step.
        low_rank_part = lanczos_vecs @ (lanczos_vals * (lanczos_vecs.T @ noise))
        lr_sample = 1/jnp.sqrt(alpha) * 1/delta * (shifted_matvec(noise) - low_rank_part)
        return res + 1/jnp.sqrt(n_steps) * lr_sample

    final_flat = jax.lax.fori_loop(0, n_steps - 1, body_fun, p0_flat)
    return unravel_func_p(final_flat)


# One independent walk per noise path.
ker_posterior_samples = jax.vmap(rw_ker)(eps)

In [10]:
# Library implementation of the non-kernel walk, here with the cross-entropy
# loss. It rebinds `nonker_posterior_samples`.
n_steps = 20
n_samples = 50
alpha = 10.0
rank = 50
nonker_posterior_samples = lanczos_diffusion(
    cross_entropy_loss,
    model.apply,
    params,
    n_steps,
    n_samples,
    alpha,
    sample_key,
    n_params,
    rank,
    x_train,
    y_train,
    1.0,
    "non-kernel-eigvals",
)

In [15]:
def accuracy_preds(preds, batch_y):
    """Count rows where the arg-max of `preds` equals the arg-max of `batch_y`."""
    hits = jnp.argmax(preds, axis=-1) == jnp.argmax(batch_y, axis=-1)
    return jnp.sum(hits)


predictive_nonker = sample_predictive(nonker_posterior_samples, params, model, x_val, False, "Pytree")
# First printed value: accuracy of the trained `params` (baseline).
print(accuracy(params, model, x_val, y_val)/x_val.shape[0])
# Second printed value: mean accuracy over the non-kernel samples.
nonker_hits = jax.vmap(accuracy_preds, in_axes=(0, None))(predictive_nonker, y_val)
nonker_accuracies = nonker_hits / x_val.shape[0]
print(jnp.mean(nonker_accuracies))


0.6318
0.40998


In [16]:
# Validation predictions for the kernel and non-kernel random-walk samples
# (same call signature for both, only the sample set differs).
predictive_ker = sample_predictive(ker_posterior_samples, params, model, x_val, False, "Pytree")
predictive_nonker = sample_predictive(nonker_posterior_samples, params, model, x_val, False, "Pytree")


In [20]:
# Baseline: fraction of validation examples classified correctly by the
# trained `params` (`accuracy` returns a count, hence the division).
accuracy(params, model, x_val, y_val)/x_val.shape[0]

Array(0.6318, dtype=float32)

In [19]:
# Per-sample validation accuracy for both samplers, then the mean of each
# (kernel first, non-kernel second).
count_hits = jax.vmap(accuracy_preds, in_axes=(0, None))
ker_accuracies = count_hits(predictive_ker, y_val) / x_val.shape[0]
nonker_accuracies = count_hits(predictive_nonker, y_val) / x_val.shape[0]
print(jnp.mean(ker_accuracies))
print(jnp.mean(nonker_accuracies))


0.099659994
0.40998
