In [1]:
import pickle
import os
import argparse
import torch
from jax import random
import json
import datetime
from src.losses import sse_loss
from src.helper import calculate_exact_ggn, tree_random_normal_like, compute_num_params
from src.sampling.predictive_samplers import sample_predictive, sample_hessian_predictive
from jax import numpy as jnp
import jax
from jax import flatten_util
import matplotlib.pyplot as plt
from src.models import LeNet
from src.data.ood_datasets import get_rotated_cifar_loaders, get_cifar10_ood_loaders, load_corrupted_cifar10_per_type, load_corrupted_cifar10
from src.ood_functions.evaluate import evaluate, evaluate_map
from src.ood_functions.metrics import compute_metrics
from src.data import n_classes, MNIST
from collections import defaultdict
from src.data import CIFAR10, n_classes
from src.models import VisionTransformer
from flax import linen as nn



### Load All Models

In [5]:
lr_posterior = pickle.load(open("../checkpoints/CIFAR-10/proj_posterior_samples_big_vit_seed0_params.pickle", "rb"))['posterior_samples']
param_dict = pickle.load(open("../checkpoints/ViT2024-03-30-23-30-29/VisionTransformer_CIFAR10_42_params.pickle", "rb"))
params = param_dict['params']
batch_stats = param_dict['batch_stats']

In [6]:
import tree_math as tm
jax.vmap(lambda x: tm.Vector(x) @ tm.Vector(params)/ tm.Vector(params) @ tm.Vector(params))(lr_posterior)

Array([2.8182974e+11, 2.8042425e+11], dtype=float32)

In [9]:
from src.data.utils import get_mean_and_std


hparams = {
        "embed_dim": 256,
        "hidden_dim": 512,
        "num_heads": 8,
        "num_layers": 6,
        "patch_size": 4,
        "num_channels": 3,
        "num_patches": 64,
        "num_classes": 10,
        "dropout_prob": 0.1,
    }
model = VisionTransformer(**hparams)

n_samples_per_class = None
cls=list(range(10))
train_stats = get_mean_and_std(
        data_train=CIFAR10(path_root='/dtu/p1/hroy/data', set_purp="train", n_samples=None, download=True, cls=cls),
        val_frac=0.1,
        seed=0,
    )

model_fn = lambda p, x: model.apply({'params': p},
                                        x,
                                        train=False,
                                        rngs={'dropout': param_dict['rng']})

# dataset = CIFAR10(path_root='/dtu/p1/hroy/data', train=True, n_samples_per_class=n_samples_per_class, download=True, cls=cls, seed=0)

Files already downloaded and verified


In [10]:
from src.data import get_dataloaders
classes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

train_loader, val_loader, test_loader = get_dataloaders(
    dataset="CIFAR10",
    bs=500,
    data_path='/dtu/p1/hroy/data',
    seed=0,
    n_samples=None,
    cls=classes,
)
test_set = next(iter(test_loader))


True
Files already downloaded and verified
Normalizing with mean = (0.4911963999523427, 0.4819965876055322, 0.4464323029003272) and  std = (0.24715184509993043, 0.24362299808055643, 0.2616772535146749) 
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [11]:
for i, batch in enumerate(test_loader):
    x_test = batch['image']
    y_test = batch['label']
    if i==10:
        break

In [12]:
predictive_samples = sample_predictive(
    posterior_samples=lr_posterior,
    params=params,
    model_fn=model_fn,
    x_test=x_test,
    linearised_laplace=True,
    posterior_sample_type="Pytree",
)


In [13]:
logits = model_fn(params, x_test)
y_log_prob = jax.nn.log_softmax(logits, axis=-1)
print("MAP NLL:", -jnp.sum(jnp.sum(y_log_prob * y_test, axis=-1), axis=-1).mean())
print("Acc:", (jnp.argmax(y_log_prob, axis=1)==jnp.argmax(y_test, axis=1)).mean())

MAP NLL: 415.3896
Acc: 0.77400005


In [14]:
predictive_samples_mean = jnp.mean(predictive_samples, axis=0)
y_log_prob = jax.nn.log_softmax(predictive_samples_mean, axis=-1)
print("NLL:", -jnp.sum(jnp.sum(y_log_prob * y_test, axis=-1), axis=-1).mean())
print("Acc:", (jnp.argmax(y_log_prob, axis=1)==jnp.argmax(y_test, axis=1)).mean())

NLL: 969.5956
Acc: 0.67


### Rotated CIFAR

In [15]:
eval_args = {}
eval_args["linearised_laplace"] = True
eval_args["posterior_sample_type"] = "Pytree"
eval_args["likelihood"] = "classification"

ids = [0, 15, 30, 60, 90, 120, 150, 180]#, 210, 240, 270, 300, 330, 345, 360]
n_datapoint=500
ood_batch_size = 500
metrics_lr = []
for i, id in enumerate(ids):
    _, test_loader = get_rotated_cifar_loaders(id, data_path="data", train_stats=train_stats, download=True, batch_size=ood_batch_size, n_datapoint=n_datapoint)
    some_metrics, all_y_prob, all_y_true, all_y_var = evaluate(test_loader, lr_posterior, params, model_fn, eval_args)
    if i == 0:
        all_y_prob_in = all_y_prob
    more_metrics = compute_metrics(
            i, id, all_y_prob, test_loader, all_y_prob_in, all_y_var, benchmark="R-MNIST"
        )
    metrics_lr.append({**some_metrics, **more_metrics})
    print(", ".join([f"{k}: {v:.4f}" for k, v in metrics_lr[-1].items()]))
    


Files already downloaded and verified


  self.pid = os.fork()
  self.pid = os.fork()


R-MNIST with distribution shift intensity 0


  self.pid = os.fork()
  self.pid = os.fork()
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = ret.dtype.type(ret / rcount)


conf: 0.9191, nll: 457.9601, acc: 0.8040, brier: 0.3039, ece: 0.2274, mce: 0.9165
Files already downloaded and verified
R-MNIST with distribution shift intensity 1


  self.pid = os.fork()
  self.pid = os.fork()
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = ret.dtype.type(ret / rcount)


conf: 0.8524, nll: 4393.4990, acc: 0.1420, brier: 1.5020, ece: 0.6663, mce: 0.9657
Files already downloaded and verified
R-MNIST with distribution shift intensity 2


  self.pid = os.fork()
  self.pid = os.fork()
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = ret.dtype.type(ret / rcount)


conf: 0.8700, nll: 5675.8213, acc: 0.1000, brier: 1.6196, ece: 0.6949, mce: 0.9466
Files already downloaded and verified
R-MNIST with distribution shift intensity 3


  self.pid = os.fork()
  self.pid = os.fork()
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = ret.dtype.type(ret / rcount)


conf: 0.8967, nll: 6006.8208, acc: 0.1080, brier: 1.6512, ece: 0.7143, mce: 0.9851
Files already downloaded and verified
R-MNIST with distribution shift intensity 4


  self.pid = os.fork()
  self.pid = os.fork()
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = ret.dtype.type(ret / rcount)


conf: 0.8387, nll: 4675.3110, acc: 0.1040, brier: 1.5676, ece: 0.6836, mce: 0.9751
Files already downloaded and verified
R-MNIST with distribution shift intensity 5


  self.pid = os.fork()
  self.pid = os.fork()
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = ret.dtype.type(ret / rcount)


conf: 0.8964, nll: 5909.0049, acc: 0.1000, brier: 1.6419, ece: 0.7300, mce: 0.9421
Files already downloaded and verified
R-MNIST with distribution shift intensity 6


  self.pid = os.fork()
  self.pid = os.fork()
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = ret.dtype.type(ret / rcount)


conf: 0.8722, nll: 5910.8945, acc: 0.0980, brier: 1.6246, ece: 0.6880, mce: 0.8969
Files already downloaded and verified
R-MNIST with distribution shift intensity 7


  self.pid = os.fork()
  self.pid = os.fork()


conf: 0.8625, nll: 2110.1187, acc: 0.4440, brier: 0.9436, ece: 0.4534, mce: 0.7185


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = ret.dtype.type(ret / rcount)


In [5]:
eval_args = {}
eval_args["linearised_laplace"] = False
eval_args["posterior_sample_type"] = "Pytree"
eval_args["likelihood"] = "classification"

ids = [0, 15, 30, 60, 90, 120, 150, 180]#, 210, 240, 270, 300, 330, 345, 360]
n_datapoint=500
ood_batch_size = 500
metrics_map = []
for i, id in enumerate(ids):
    _, test_loader = get_rotated_cifar_loaders(id, data_path="data", train_stats=train_stats, download=True, batch_size=ood_batch_size, n_datapoint=n_datapoint)
    some_metrics, all_y_prob, all_y_true, all_y_var = evaluate_map(test_loader, params, model_fn, eval_args)
    if i == 0:
        all_y_prob_in = all_y_prob
    more_metrics = compute_metrics(
            i, id, all_y_prob, test_loader, all_y_prob_in, all_y_var, benchmark="R-MNIST"
        )
    metrics_map.append({**some_metrics, **more_metrics})
    print(", ".join([f"{k}: {v:.4f}" for k, v in metrics_map[-1].items()]))
    


Files already downloaded and verified
R-MNIST with distribution shift intensity 0


  self.pid = os.fork()
  self.pid = os.fork()
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = ret.dtype.type(ret / rcount)


conf: 0.9473, nll: 64.3102, acc: 0.9560, brier: 0.0593, ece: 0.0944, mce: 0.6918
Files already downloaded and verified
R-MNIST with distribution shift intensity 1


  self.pid = os.fork()
  self.pid = os.fork()
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = ret.dtype.type(ret / rcount)


conf: 0.8028, nll: 3083.2002, acc: 0.1340, brier: 1.4450, ece: 0.6527, mce: 0.9649
Files already downloaded and verified
R-MNIST with distribution shift intensity 2


  self.pid = os.fork()
  self.pid = os.fork()
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = ret.dtype.type(ret / rcount)


conf: 0.9168, nll: 4807.0908, acc: 0.0960, brier: 1.6951, ece: 0.7612, mce: 0.9455
Files already downloaded and verified
R-MNIST with distribution shift intensity 3


  self.pid = os.fork()
  self.pid = os.fork()
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = ret.dtype.type(ret / rcount)


conf: 0.9124, nll: 5045.6655, acc: 0.0960, brier: 1.6780, ece: 0.7581, mce: 0.9497
Files already downloaded and verified
R-MNIST with distribution shift intensity 4


  self.pid = os.fork()
  self.pid = os.fork()
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = ret.dtype.type(ret / rcount)


conf: 0.8158, nll: 4073.7588, acc: 0.0960, brier: 1.5423, ece: 0.7035, mce: 0.9547
Files already downloaded and verified
R-MNIST with distribution shift intensity 5


  self.pid = os.fork()
  self.pid = os.fork()
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = ret.dtype.type(ret / rcount)


conf: 0.9172, nll: 4942.9136, acc: 0.0960, brier: 1.6796, ece: 0.7584, mce: 0.9551
Files already downloaded and verified
R-MNIST with distribution shift intensity 6


  self.pid = os.fork()
  self.pid = os.fork()
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = ret.dtype.type(ret / rcount)


conf: 0.8997, nll: 4829.9312, acc: 0.0960, brier: 1.6690, ece: 0.7463, mce: 0.9443
Files already downloaded and verified
R-MNIST with distribution shift intensity 7


  self.pid = os.fork()
  self.pid = os.fork()


conf: 0.8148, nll: 952.3203, acc: 0.5740, brier: 0.6704, ece: 0.2915, mce: 0.7087


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = ret.dtype.type(ret / rcount)


In [7]:
import pandas as pd
metrics_dict = ['conf', 'nll', 'acc', 'brier', 'ece', 'mce']
method_list = ["Projection Laplace", "MAP"]
method_dict = {"Projection Laplace":metrics_lr, "MAP":metrics_map}
df_data = {metric: ["{:.3f}".format(method_dict[dic][0][metric]) for dic in method_dict] for metric in metrics_dict}
# df_data = {k: ["{:.3f}".format(dic[metric+'_mean']) + u"\u00B1" + "{:.3f}".format(dic[metric+'_std'])  for dic in method_dict[k][1:] for metric in metrics_dict] for k in method_dict}

df = pd.DataFrame.from_dict(df_data, orient='index', columns=method_list)
df = df.T


Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


In [8]:
df

Unnamed: 0,conf,nll,acc,brier,ece,mce
Projection Laplace,0.919,139.361,0.906,0.141,0.152,0.897
MAP,0.943,83.001,0.95,0.081,0.12,0.744


In [18]:
print(df.to_latex(index=True,
                  formatters={"name": str.upper},
                  float_format="{:.3f}".format,
))  

\begin{tabular}{lllllll}
\toprule
 & conf & nll & acc & brier & ece & mce \\
\midrule
Laplace Diffusion & 0.958±0.005 & 170.385±12.506 & 0.911±0.004 & 0.144±0.008 & 0.237±0.033 & 0.852±0.039 \\
Sampled Laplace & 0.845±0.005 & 438.840±123.643 & 0.742±0.053 & 0.385±0.084 & 0.205±0.045 & 0.799±0.066 \\
Linearised Laplace & 0.954±0.007 & 285.098±31.782 & 0.871±0.012 & 0.209±0.019 & 0.315±0.036 & 0.799±0.032 \\
MAP & 0.962±0.004 & 152.471±23.029 & 0.918±0.009 & 0.134±0.015 & 0.272±0.014 & 0.910±0.032 \\
\bottomrule
\end{tabular}



| Laplace Diffusion |  0.958±0.005 |  170.385±12.506 |  0.911±0.004 |  0.144±0.008 |  0.237±0.033 |  0.852±0.039 | 
| Sampled Laplace |  0.845±0.005 |  438.840±123.643 |  0.742±0.053 |  0.385±0.084 |  0.205±0.045 |  0.799±0.066 | 
| Linearised Laplace |  0.954±0.007 |  285.098±31.782 |  0.871±0.012 |  0.209±0.019 |  0.315±0.036 |  0.799±0.032 | 
| MAP |  0.962±0.004 |  152.471±23.029 |  0.918±0.009 |  0.134±0.015 |  0.272±0.014 |  0.910±0.032 | 


### OOD

In [7]:
eval_args = {}
eval_args["linearised_laplace"] = False
eval_args["posterior_sample_type"] = "Pytree"
eval_args["likelihood"] = "classification"

ids = ["CIFAR-10", "CIFAR-100", "SVHN"]
n_datapoint=500
ood_batch_size = 500
metrics_lr = []
for i, id in enumerate(ids):
    _, test_loader = get_cifar10_ood_loaders(id, train_stats=train_stats, data_path="data", download=True, batch_size=ood_batch_size, n_datapoint=n_datapoint)
    some_metrics, all_y_prob, all_y_true, all_y_var = evaluate(test_loader, lr_posterior, params, model_fn, eval_args)
    if i == 0:
        all_y_prob_in = all_y_prob
    more_metrics = compute_metrics(
            i, id, all_y_prob, test_loader, all_y_prob_in, all_y_var, benchmark="CIFAR-10-OOD"
        )
    metrics_lr.append({**some_metrics, **more_metrics})
    print(", ".join([f"{k}: {v:.4f}" for k, v in metrics_lr[-1].items()]))
    


Files already downloaded and verified


  self.pid = os.fork()
  self.pid = os.fork()


CIFAR-10-OOD - dataset: CIFAR-10
conf: 0.6800, nll: 1993.2935, acc: 0.0940
Files already downloaded and verified


  self.pid = os.fork()
  self.pid = os.fork()


TypeError: cannot reshape array of shape (500, 3, 32, 32) (size 1536000) into shape (500, 0, 4, 8, 4, 32) (size 0)

In [22]:
eval_args = {}
eval_args["linearised_laplace"] = False
eval_args["posterior_sample_type"] = "Pytree"
eval_args["likelihood"] = "classification"

ids = ["CIFAR-10", "CIFAR-100", "SVHN"]
n_datapoint=500
ood_batch_size = 50
metrics_map = []
for i, id in enumerate(ids):
    _, test_loader = get_cifar10_ood_loaders(id, data_path="data", download=True, batch_size=ood_batch_size, n_datapoint=n_datapoint)
    some_metrics, all_y_prob, all_y_true, all_y_var = evaluate_map(test_loader, params, model_fn, eval_args)
    if i == 0:
        all_y_prob_in = all_y_prob
    more_metrics = compute_metrics(
            i, id, all_y_prob, test_loader, all_y_prob_in, all_y_var, benchmark="CIFAR-10-OOD"
        )
    metrics_map.append({**some_metrics, **more_metrics})
    print(", ".join([f"{k}: {v:.4f}" for k, v in metrics_map[-1].items()]))
    


Files already downloaded and verified


  self.targets = F.one_hot(torch.tensor(self.dataset.targets), len(cls)).numpy()


CIFAR-10-OOD - dataset: CIFAR-10
Files already downloaded and verified
CIFAR-10-OOD - dataset: CIFAR-10
Files already downloaded and verified
CIFAR-10-OOD - dataset: CIFAR-10
conf_mean: 0.9616, nll_mean: 152.4712, acc_mean: 0.9180, conf_std: 0.0037, nll_std: 23.0288, acc_std: 0.0091
Files already downloaded and verified
CIFAR-10-OOD - dataset: CIFAR-100
Files already downloaded and verified
CIFAR-10-OOD - dataset: CIFAR-100
Files already downloaded and verified
CIFAR-10-OOD - dataset: CIFAR-100
conf_mean: 0.8033, nll_mean: 407.4996, acc_mean: 0.1227, conf_std: 0.0099, nll_std: 15.6155, acc_std: 0.0050, auroc_mean: 0.8766, fpr95_mean: 0.6807, auroc_std: 0.0048, fpr95_std: 0.0229
Using downloaded and verified file: data/test_32x32.mat
CIFAR-10-OOD - dataset: SVHN
Using downloaded and verified file: data/test_32x32.mat
CIFAR-10-OOD - dataset: SVHN
Using downloaded and verified file: data/test_32x32.mat
CIFAR-10-OOD - dataset: SVHN
conf_mean: 0.7905, nll_mean: 3739.2761, acc_mean: 0.0840, 

In [26]:
import pandas as pd
metrics_dict = ['conf', 'auroc']
method_list = ["Laplace Diffusion", "Laplace Approximation", "Linearised Laplace", "MAP"]
method_dict = {"Laplace Diffusion":metrics_lr, "Laplace Approximation": metrics_posterior,"Linearised Laplace":metrics_lienarised,  "MAP":metrics_map}
mux = pd.MultiIndex.from_product([ids[1:], metrics_dict])
# df_data = {k: [dic[metric+'_mean'] for metric in metrics_dict for dic in method_dict[k][1:]] for k in method_dict}
df_data = {k: ["{:.3f}".format(dic[metric+'_mean']) + u"\u00B1" + "{:.3f}".format(dic[metric+'_std'])  for dic in method_dict[k][1:] for metric in metrics_dict] for k in method_dict}

df = pd.DataFrame.from_dict(df_data, orient='index',
                       columns=mux)


In [27]:
df

Unnamed: 0_level_0,CIFAR-100,CIFAR-100,SVHN,SVHN
Unnamed: 0_level_1,conf,auroc,conf,auroc
Laplace Diffusion,0.791±0.001,0.856±0.002,0.793±0.019,0.852±0.013
Laplace Approximation,0.706±0.037,0.730±0.039,0.757±0.053,0.667±0.068
Linearised Laplace,0.818±0.004,0.839±0.005,0.805±0.028,0.853±0.020
MAP,0.803±0.010,0.877±0.005,0.790±0.028,0.884±0.013


In [28]:
print(df.to_latex(index=True,
                  formatters={"name": str.upper},
                  float_format="{:.3f}".format,
))  

\begin{tabular}{lllll}
\toprule
 & \multicolumn{2}{r}{CIFAR-100} & \multicolumn{2}{r}{SVHN} \\
 & conf & auroc & conf & auroc \\
\midrule
Laplace Diffusion & 0.791±0.001 & 0.856±0.002 & 0.793±0.019 & 0.852±0.013 \\
Laplace Approximation & 0.706±0.037 & 0.730±0.039 & 0.757±0.053 & 0.667±0.068 \\
Linearised Laplace & 0.818±0.004 & 0.839±0.005 & 0.805±0.028 & 0.853±0.020 \\
MAP & 0.803±0.010 & 0.877±0.005 & 0.790±0.028 & 0.884±0.013 \\
\bottomrule
\end{tabular}



| Laplace Diffusion | 0.791±0.001 | 0.856±0.002 | 0.793±0.019 | 0.852±0.013 |
| Laplace Approximation | 0.706±0.037 | 0.730±0.039 | 0.757±0.053 | 0.667±0.068 |
| Linearised Laplace | 0.818±0.004 | 0.839±0.005 | 0.805±0.028 | 0.853±0.020 |
| MAP | 0.803±0.010 | 0.877±0.005 | 0.790±0.028 | 0.884±0.013 |
