In [2]:
import pickle
import os
import argparse
import torch
from jax import random
import json
import datetime
from src.losses import sse_loss
from src.helper import calculate_exact_ggn, tree_random_normal_like, compute_num_params
from src.sampling.predictive_samplers import sample_predictive, sample_hessian_predictive
from jax import numpy as jnp
import jax
from jax import flatten_util
import matplotlib.pyplot as plt
from src.models import LeNet
from src.data.datasets import get_rotated_cifar_loaders, get_cifar10_ood_loaders, load_corrupted_cifar10_per_type, load_corrupted_cifar10
from src.ood_functions.evaluate import evaluate, evaluate_map
from src.ood_functions.metrics import compute_metrics
from src.data import n_classes, MNIST
from collections import defaultdict
from src.data import CIFAR10, n_classes
from src.models import VisionTransformer
from flax import linen as nn



### Load All Models

In [13]:
lr_posterior = pickle.load(open("../checkpoints/CIFAR-10/proj_posterior_samples_big_vit_seed0_params.pickle", "rb"))['posterior_samples']
param_dict = pickle.load(open("../checkpoints/ViT2024-03-30-23-30-29/VisionTransformer_CIFAR10_42_params.pickle", "rb"))
params = param_dict['params']
batch_stats = param_dict['batch_stats']

In [14]:
from src.data.utils import get_mean_and_std


hparams = {
        "embed_dim": 256,
        "hidden_dim": 512,
        "num_heads": 8,
        "num_layers": 6,
        "patch_size": 4,
        "num_channels": 3,
        "num_patches": 64,
        "num_classes": 10,
        "dropout_prob": 0.1,
    }
model = VisionTransformer(**hparams)

n_samples_per_class = None
cls=list(range(10))
train_stats = get_mean_and_std(
        data_train=CIFAR10(path_root='/dtu/p1/hroy/data', set_purp="train", n_samples=None, download=True, cls=cls),
        val_frac=0.1,
        seed=0,
    )

model_fn = lambda p, x: model.apply({'params': p},
                                        x,
                                        train=False,
                                        rngs={'dropout': param_dict['rng']})

# dataset = CIFAR10(path_root='/dtu/p1/hroy/data', train=True, n_samples_per_class=n_samples_per_class, download=True, cls=cls, seed=0)

Files already downloaded and verified


In [15]:
eval_args = {}
eval_args["linearised_laplace"] = True
eval_args["posterior_sample_type"] = "Pytree"
eval_args["likelihood"] = "classification"

ids = [0, 15, 30, 60, 90, 120, 150, 180]#, 210, 240, 270, 300, 330, 345, 360]
n_datapoint=500
ood_batch_size = 50
metrics_lr = []
for i, id in enumerate(ids):
    _, test_loader = get_rotated_cifar_loaders(id, data_path="data", train_stats=train_stats, download=True, batch_size=ood_batch_size, n_datapoint=n_datapoint)
    some_metrics, all_y_prob, all_y_true, all_y_var = evaluate(test_loader, lr_posterior, params, model_fn, eval_args)
    if i == 0:
        all_y_prob_in = all_y_prob
    more_metrics = compute_metrics(
            i, id, all_y_prob, test_loader, all_y_prob_in, all_y_var, benchmark="R-MNIST"
        )
    metrics_lr.append({**some_metrics, **more_metrics})
    print(", ".join([f"{k}: {v:.4f}" for k, v in metrics_lr[-1].items()]))
    


Files already downloaded and verified


  self.pid = os.fork()
  self.pid = os.fork()


R-MNIST with distribution shift intensity 0


  self.pid = os.fork()
  self.pid = os.fork()
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = ret.dtype.type(ret / rcount)


conf: 0.9102, nll: 442.9528, acc: 0.7960, brier: 0.3370, ece: 0.2663, mce: 0.8162
Files already downloaded and verified
R-MNIST with distribution shift intensity 1


  self.pid = os.fork()
  self.pid = os.fork()
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = ret.dtype.type(ret / rcount)


conf: 0.8288, nll: 4222.8994, acc: 0.1440, brier: 1.4804, ece: 0.6468, mce: 0.9652
Files already downloaded and verified
R-MNIST with distribution shift intensity 2


  self.pid = os.fork()
  self.pid = os.fork()
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = ret.dtype.type(ret / rcount)


conf: 0.8874, nll: 5786.3560, acc: 0.0940, brier: 1.6449, ece: 0.7195, mce: 0.9447
Files already downloaded and verified
R-MNIST with distribution shift intensity 3


  self.pid = os.fork()
  self.pid = os.fork()
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = ret.dtype.type(ret / rcount)


conf: 0.8899, nll: 5985.2148, acc: 0.0880, brier: 1.6585, ece: 0.7173, mce: 0.9754
Files already downloaded and verified
R-MNIST with distribution shift intensity 4


  self.pid = os.fork()
  self.pid = os.fork()
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = ret.dtype.type(ret / rcount)


conf: 0.8439, nll: 4721.1152, acc: 0.1120, brier: 1.5601, ece: 0.6824, mce: 0.9758
Files already downloaded and verified
R-MNIST with distribution shift intensity 5


  self.pid = os.fork()
  self.pid = os.fork()
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = ret.dtype.type(ret / rcount)


conf: 0.9027, nll: 6142.6602, acc: 0.0940, brier: 1.6887, ece: 0.7143, mce: 0.9757
Files already downloaded and verified
R-MNIST with distribution shift intensity 6


  self.pid = os.fork()
  self.pid = os.fork()
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = ret.dtype.type(ret / rcount)


conf: 0.8781, nll: 5837.8037, acc: 0.0880, brier: 1.6521, ece: 0.7151, mce: 0.9520
Files already downloaded and verified
R-MNIST with distribution shift intensity 7


  self.pid = os.fork()
  self.pid = os.fork()


conf: 0.8560, nll: 2187.9541, acc: 0.4320, brier: 0.9659, ece: 0.4641, mce: 0.8367


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = ret.dtype.type(ret / rcount)


In [16]:
eval_args = {}
eval_args["linearised_laplace"] = False
eval_args["posterior_sample_type"] = "Pytree"
eval_args["likelihood"] = "classification"

ids = [0, 15, 30, 60, 90, 120, 150, 180]#, 210, 240, 270, 300, 330, 345, 360]
n_datapoint=500
ood_batch_size = 50
metrics_map = []
for i, id in enumerate(ids):
    _, test_loader = get_rotated_cifar_loaders(id, data_path="data", train_stats=train_stats, download=True, batch_size=ood_batch_size, n_datapoint=n_datapoint)
    some_metrics, all_y_prob, all_y_true, all_y_var = evaluate_map(test_loader, params, model_fn, eval_args)
    if i == 0:
        all_y_prob_in = all_y_prob
    more_metrics = compute_metrics(
            i, id, all_y_prob, test_loader, all_y_prob_in, all_y_var, benchmark="R-MNIST"
        )
    metrics_map.append({**some_metrics, **more_metrics})
    print(", ".join([f"{k}: {v:.4f}" for k, v in metrics_map[-1].items()]))
    


Files already downloaded and verified
R-MNIST with distribution shift intensity 0


  self.pid = os.fork()
  self.pid = os.fork()
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = ret.dtype.type(ret / rcount)


conf: 0.9446, nll: 78.8293, acc: 0.9520, brier: 0.0696, ece: 0.1094, mce: 0.5819
Files already downloaded and verified
R-MNIST with distribution shift intensity 1


  self.pid = os.fork()
  self.pid = os.fork()
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = ret.dtype.type(ret / rcount)


conf: 0.8127, nll: 3107.2917, acc: 0.1640, brier: 1.4193, ece: 0.6171, mce: 0.9171
Files already downloaded and verified
R-MNIST with distribution shift intensity 2


  self.pid = os.fork()
  self.pid = os.fork()
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = ret.dtype.type(ret / rcount)


conf: 0.9186, nll: 4818.6064, acc: 0.0920, brier: 1.6989, ece: 0.7605, mce: 0.9337
Files already downloaded and verified
R-MNIST with distribution shift intensity 3


  self.pid = os.fork()
  self.pid = os.fork()
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = ret.dtype.type(ret / rcount)


conf: 0.9214, nll: 5117.1377, acc: 0.0860, brier: 1.6977, ece: 0.7889, mce: 0.9350
Files already downloaded and verified
R-MNIST with distribution shift intensity 4


  self.pid = os.fork()
  self.pid = os.fork()
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = ret.dtype.type(ret / rcount)


conf: 0.8186, nll: 4141.9688, acc: 0.0880, brier: 1.5686, ece: 0.7198, mce: 0.9850
Files already downloaded and verified
R-MNIST with distribution shift intensity 5


  self.pid = os.fork()
  self.pid = os.fork()
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = ret.dtype.type(ret / rcount)


conf: 0.9301, nll: 5018.7827, acc: 0.0980, brier: 1.7003, ece: 0.7613, mce: 0.9545
Files already downloaded and verified
R-MNIST with distribution shift intensity 6


  self.pid = os.fork()
  self.pid = os.fork()
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = ret.dtype.type(ret / rcount)


conf: 0.9067, nll: 4760.2051, acc: 0.0940, brier: 1.6717, ece: 0.7705, mce: 0.9654
Files already downloaded and verified
R-MNIST with distribution shift intensity 7


  self.pid = os.fork()
  self.pid = os.fork()


conf: 0.8158, nll: 957.2025, acc: 0.5760, brier: 0.6675, ece: 0.3031, mce: 0.7687


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  ret = ret.dtype.type(ret / rcount)


In [17]:
import pandas as pd
metrics_dict = ['conf', 'nll', 'acc', 'brier', 'ece', 'mce']
method_list = ["Projection Laplace", "MAP"]
method_dict = {"Projection Laplace":metrics_lr, "MAP":metrics_map}
df_data = {metric: ["{:.3f}".format(method_dict[dic][0][metric]) for dic in method_dict] for metric in metrics_dict}
# df_data = {k: ["{:.3f}".format(dic[metric+'_mean']) + u"\u00B1" + "{:.3f}".format(dic[metric+'_std'])  for dic in method_dict[k][1:] for metric in metrics_dict] for k in method_dict}

df = pd.DataFrame.from_dict(df_data, orient='index', columns=method_list)
df = df.T


In [19]:
df

Unnamed: 0,conf,nll,acc,brier,ece,mce
Projection Laplace,0.91,442.953,0.796,0.337,0.266,0.816
MAP,0.945,78.829,0.952,0.07,0.109,0.582


In [18]:
print(df.to_latex(index=True,
                  formatters={"name": str.upper},
                  float_format="{:.3f}".format,
))  

\begin{tabular}{lllllll}
\toprule
 & conf & nll & acc & brier & ece & mce \\
\midrule
Laplace Diffusion & 0.958±0.005 & 170.385±12.506 & 0.911±0.004 & 0.144±0.008 & 0.237±0.033 & 0.852±0.039 \\
Sampled Laplace & 0.845±0.005 & 438.840±123.643 & 0.742±0.053 & 0.385±0.084 & 0.205±0.045 & 0.799±0.066 \\
Linearised Laplace & 0.954±0.007 & 285.098±31.782 & 0.871±0.012 & 0.209±0.019 & 0.315±0.036 & 0.799±0.032 \\
MAP & 0.962±0.004 & 152.471±23.029 & 0.918±0.009 & 0.134±0.015 & 0.272±0.014 & 0.910±0.032 \\
\bottomrule
\end{tabular}



| Laplace Diffusion |  0.958±0.005 |  170.385±12.506 |  0.911±0.004 |  0.144±0.008 |  0.237±0.033 |  0.852±0.039 | 
| Sampled Laplace |  0.845±0.005 |  438.840±123.643 |  0.742±0.053 |  0.385±0.084 |  0.205±0.045 |  0.799±0.066 | 
| Linearised Laplace |  0.954±0.007 |  285.098±31.782 |  0.871±0.012 |  0.209±0.019 |  0.315±0.036 |  0.799±0.032 | 
| MAP |  0.962±0.004 |  152.471±23.029 |  0.918±0.009 |  0.134±0.015 |  0.272±0.014 |  0.910±0.032 | 


### OOD

In [19]:
eval_args = {}
eval_args["linearised_laplace"] = False
eval_args["posterior_sample_type"] = "Pytree"
eval_args["likelihood"] = "classification"

ids = ["CIFAR-10", "CIFAR-100", "SVHN"]
n_datapoint=500
ood_batch_size = 50
metrics_lr = []
for i, id in enumerate(ids):
    some_metrics_all = defaultdict(list)
    more_metrics_all = defaultdict(list)
    for seed, (params, batch_stats, lr_posterior) in enumerate(zip(param_list, batch_stats_list, lr_posterior_list)):   
        model_fn = lambda p, x: model.apply({'params': p, 'batch_stats': batch_stats},
                                x,
                                train=False,
                                mutable=False)    
        _, test_loader = get_cifar10_ood_loaders(id, data_path="data", download=True, batch_size=ood_batch_size, n_datapoint=n_datapoint)
        some_metrics, all_y_prob, all_y_true, all_y_var = evaluate(test_loader, lr_posterior, params, model_fn, eval_args)
        if i == 0:
            all_y_prob_in = all_y_prob
        more_metrics = compute_metrics(
                i, id, all_y_prob, test_loader, all_y_prob_in, all_y_var, benchmark="CIFAR-10-OOD"
            )
        for k, v in some_metrics.items():
            some_metrics_all[k].append(v)
        for k, v in more_metrics.items():   
            more_metrics_all[k].append(v)
    seed_some_metric =  {**{k+"_mean": jnp.mean(jnp.array(v)).item() for k, v in some_metrics_all.items()}, **{k+"_std": jnp.std(jnp.array(v)).item() for k, v in some_metrics_all.items()}}
    seed_more_metric =  {**{k+"_mean": jnp.mean(jnp.array(v)).item() for k, v in more_metrics_all.items()}, **{k+"_std": jnp.std(jnp.array(v)).item() for k, v in more_metrics_all.items()}}
    metrics_lr.append({**seed_some_metric, **seed_more_metric})
    print(", ".join([f"{k}: {v:.4f}" for k, v in metrics_lr[-1].items()]))
    


Files already downloaded and verified


  self.targets = F.one_hot(torch.tensor(self.dataset.targets), len(cls)).numpy()


CIFAR-10-OOD - dataset: CIFAR-10
Files already downloaded and verified
CIFAR-10-OOD - dataset: CIFAR-10
Files already downloaded and verified
CIFAR-10-OOD - dataset: CIFAR-10
conf_mean: 0.9584, nll_mean: 170.3846, acc_mean: 0.9107, conf_std: 0.0046, nll_std: 12.5059, acc_std: 0.0041
Files already downloaded and verified
CIFAR-10-OOD - dataset: CIFAR-100
Files already downloaded and verified
CIFAR-10-OOD - dataset: CIFAR-100
Files already downloaded and verified
CIFAR-10-OOD - dataset: CIFAR-100
conf_mean: 0.7910, nll_mean: 396.0765, acc_mean: 0.1233, conf_std: 0.0013, nll_std: 9.8828, acc_std: 0.0105, auroc_mean: 0.8559, fpr95_mean: 0.7127, auroc_std: 0.0017, fpr95_std: 0.0050
Using downloaded and verified file: data/test_32x32.mat
CIFAR-10-OOD - dataset: SVHN
Using downloaded and verified file: data/test_32x32.mat
CIFAR-10-OOD - dataset: SVHN
Using downloaded and verified file: data/test_32x32.mat
CIFAR-10-OOD - dataset: SVHN
conf_mean: 0.7928, nll_mean: 3754.8203, acc_mean: 0.0793, c

In [20]:
eval_args = {}
eval_args["linearised_laplace"] = False
eval_args["posterior_sample_type"] = "Pytree"
eval_args["likelihood"] = "classification"

ids = ["CIFAR-10", "CIFAR-100", "SVHN"]
n_datapoint=500
ood_batch_size = 50
metrics_posterior = []
for i, id in enumerate(ids):
    some_metrics_all = defaultdict(list)
    more_metrics_all = defaultdict(list)
    for seed, (params, batch_stats, posterior) in enumerate(zip(param_list, batch_stats_list, posterior_list)):    
        model_fn = lambda p, x: model.apply({'params': p, 'batch_stats': batch_stats},
                                        x,
                                        train=False,
                                        mutable=False)    
        _, test_loader = get_cifar10_ood_loaders(id, data_path="data", download=True, batch_size=ood_batch_size, n_datapoint=n_datapoint)
        some_metrics, all_y_prob, all_y_true, all_y_var = evaluate(test_loader, posterior, params, model_fn, eval_args)
        if i == 0:
            all_y_prob_in = all_y_prob
        more_metrics = compute_metrics(
                i, id, all_y_prob, test_loader, all_y_prob_in, all_y_var, benchmark="CIFAR-10-OOD"
            )
        for k, v in some_metrics.items():
            some_metrics_all[k].append(v)
        for k, v in more_metrics.items():   
            more_metrics_all[k].append(v)
    seed_some_metric =  {**{k+"_mean": jnp.mean(jnp.array(v)).item() for k, v in some_metrics_all.items()}, **{k+"_std": jnp.std(jnp.array(v)).item() for k, v in some_metrics_all.items()}}
    seed_more_metric =  {**{k+"_mean": jnp.mean(jnp.array(v)).item() for k, v in more_metrics_all.items()}, **{k+"_std": jnp.std(jnp.array(v)).item() for k, v in more_metrics_all.items()}}
    metrics_posterior.append({**seed_some_metric, **seed_more_metric})
    print(", ".join([f"{k}: {v:.4f}" for k, v in metrics_posterior[-1].items()]))
    


Files already downloaded and verified


  self.targets = F.one_hot(torch.tensor(self.dataset.targets), len(cls)).numpy()


CIFAR-10-OOD - dataset: CIFAR-10
Files already downloaded and verified
CIFAR-10-OOD - dataset: CIFAR-10
Files already downloaded and verified
CIFAR-10-OOD - dataset: CIFAR-10
conf_mean: 0.8453, nll_mean: 438.8398, acc_mean: 0.7420, conf_std: 0.0048, nll_std: 123.6431, acc_std: 0.0534
Files already downloaded and verified
CIFAR-10-OOD - dataset: CIFAR-100
Files already downloaded and verified
CIFAR-10-OOD - dataset: CIFAR-100
Files already downloaded and verified
CIFAR-10-OOD - dataset: CIFAR-100
conf_mean: 0.7055, nll_mean: 311.3177, acc_mean: 0.1347, conf_std: 0.0370, nll_std: 39.9136, acc_std: 0.0401, auroc_mean: 0.7298, fpr95_mean: 0.8980, auroc_std: 0.0390, fpr95_std: 0.0261
Using downloaded and verified file: data/test_32x32.mat
CIFAR-10-OOD - dataset: SVHN
Using downloaded and verified file: data/test_32x32.mat
CIFAR-10-OOD - dataset: SVHN
Using downloaded and verified file: data/test_32x32.mat
CIFAR-10-OOD - dataset: SVHN
conf_mean: 0.7566, nll_mean: 3243.4170, acc_mean: 0.1007,

In [21]:
eval_args = {}
eval_args["linearised_laplace"] = True
eval_args["posterior_sample_type"] = "Pytree"
eval_args["likelihood"] = "classification"

ids = ["CIFAR-10", "CIFAR-100", "SVHN"]
n_datapoint=500
ood_batch_size = 50
metrics_lienarised = []
for i, id in enumerate(ids):
    some_metrics_all = defaultdict(list)
    more_metrics_all = defaultdict(list)
    for seed, (params, batch_stats, posterior) in enumerate(zip(param_list, batch_stats_list, posterior_list)):    
        model_fn = lambda p, x: model.apply({'params': p, 'batch_stats': batch_stats},
                                        x,
                                        train=False,
                                        mutable=False)    
        _, test_loader = get_cifar10_ood_loaders(id, data_path="data", download=True, batch_size=ood_batch_size, n_datapoint=n_datapoint)
        some_metrics, all_y_prob, all_y_true, all_y_var = evaluate(test_loader, posterior, params, model_fn, eval_args)
        if i == 0:
            all_y_prob_in = all_y_prob
        more_metrics = compute_metrics(
                i, id, all_y_prob, test_loader, all_y_prob_in, all_y_var, benchmark="CIFAR-10-OOD"
            )
        for k, v in some_metrics.items():
            some_metrics_all[k].append(v)
        for k, v in more_metrics.items():   
            more_metrics_all[k].append(v)
    seed_some_metric =  {**{k+"_mean": jnp.mean(jnp.array(v)).item() for k, v in some_metrics_all.items()}, **{k+"_std": jnp.std(jnp.array(v)).item() for k, v in some_metrics_all.items()}}
    seed_more_metric =  {**{k+"_mean": jnp.mean(jnp.array(v)).item() for k, v in more_metrics_all.items()}, **{k+"_std": jnp.std(jnp.array(v)).item() for k, v in more_metrics_all.items()}}
    metrics_lienarised.append({**seed_some_metric, **seed_more_metric})
    print(", ".join([f"{k}: {v:.4f}" for k, v in metrics_lienarised[-1].items()]))
    


Files already downloaded and verified


  self.targets = F.one_hot(torch.tensor(self.dataset.targets), len(cls)).numpy()


CIFAR-10-OOD - dataset: CIFAR-10
Files already downloaded and verified
CIFAR-10-OOD - dataset: CIFAR-10
Files already downloaded and verified
CIFAR-10-OOD - dataset: CIFAR-10
conf_mean: 0.9537, nll_mean: 285.0979, acc_mean: 0.8713, conf_std: 0.0069, nll_std: 31.7821, acc_std: 0.0118
Files already downloaded and verified
CIFAR-10-OOD - dataset: CIFAR-100
Files already downloaded and verified
CIFAR-10-OOD - dataset: CIFAR-100
Files already downloaded and verified
CIFAR-10-OOD - dataset: CIFAR-100
conf_mean: 0.8184, nll_mean: 474.2798, acc_mean: 0.1113, conf_std: 0.0044, nll_std: 23.0179, acc_std: 0.0105, auroc_mean: 0.8389, fpr95_mean: 0.7473, auroc_std: 0.0051, fpr95_std: 0.0172
Using downloaded and verified file: data/test_32x32.mat
CIFAR-10-OOD - dataset: SVHN
Using downloaded and verified file: data/test_32x32.mat
CIFAR-10-OOD - dataset: SVHN
Using downloaded and verified file: data/test_32x32.mat
CIFAR-10-OOD - dataset: SVHN
conf_mean: 0.8054, nll_mean: 4194.7354, acc_mean: 0.0880, 

In [22]:
eval_args = {}
eval_args["linearised_laplace"] = False
eval_args["posterior_sample_type"] = "Pytree"
eval_args["likelihood"] = "classification"

ids = ["CIFAR-10", "CIFAR-100", "SVHN"]
n_datapoint=500
ood_batch_size = 50
metrics_map = []
for i, id in enumerate(ids):
    some_metrics_all = defaultdict(list)
    more_metrics_all = defaultdict(list)
    for seed, (params, batch_stats) in enumerate(zip(param_list, batch_stats_list)):    
        model_fn = lambda p, x: model.apply({'params': p, 'batch_stats': batch_stats},
                                        x,
                                        train=False,
                                        mutable=False)    
        # params = params_dict['params']
        _, test_loader = get_cifar10_ood_loaders(id, data_path="data", download=True, batch_size=ood_batch_size, n_datapoint=n_datapoint)
        some_metrics, all_y_prob, all_y_true, all_y_var = evaluate_map(test_loader, params, model_fn, eval_args)
        if i == 0:
            all_y_prob_in = all_y_prob
        more_metrics = compute_metrics(
                i, id, all_y_prob, test_loader, all_y_prob_in, all_y_var, benchmark="CIFAR-10-OOD"
            )
        for k, v in some_metrics.items():
            some_metrics_all[k].append(v)
        for k, v in more_metrics.items():   
            more_metrics_all[k].append(v)
    seed_some_metric =  {**{k+"_mean": jnp.mean(jnp.array(v)).item() for k, v in some_metrics_all.items()}, **{k+"_std": jnp.std(jnp.array(v)).item() for k, v in some_metrics_all.items()}}
    seed_more_metric =  {**{k+"_mean": jnp.mean(jnp.array(v)).item() for k, v in more_metrics_all.items()}, **{k+"_std": jnp.std(jnp.array(v)).item() for k, v in more_metrics_all.items()}}
    metrics_map.append({**seed_some_metric, **seed_more_metric})
    print(", ".join([f"{k}: {v:.4f}" for k, v in metrics_map[-1].items()]))
    


Files already downloaded and verified


  self.targets = F.one_hot(torch.tensor(self.dataset.targets), len(cls)).numpy()


CIFAR-10-OOD - dataset: CIFAR-10
Files already downloaded and verified
CIFAR-10-OOD - dataset: CIFAR-10
Files already downloaded and verified
CIFAR-10-OOD - dataset: CIFAR-10
conf_mean: 0.9616, nll_mean: 152.4712, acc_mean: 0.9180, conf_std: 0.0037, nll_std: 23.0288, acc_std: 0.0091
Files already downloaded and verified
CIFAR-10-OOD - dataset: CIFAR-100
Files already downloaded and verified
CIFAR-10-OOD - dataset: CIFAR-100
Files already downloaded and verified
CIFAR-10-OOD - dataset: CIFAR-100
conf_mean: 0.8033, nll_mean: 407.4996, acc_mean: 0.1227, conf_std: 0.0099, nll_std: 15.6155, acc_std: 0.0050, auroc_mean: 0.8766, fpr95_mean: 0.6807, auroc_std: 0.0048, fpr95_std: 0.0229
Using downloaded and verified file: data/test_32x32.mat
CIFAR-10-OOD - dataset: SVHN
Using downloaded and verified file: data/test_32x32.mat
CIFAR-10-OOD - dataset: SVHN
Using downloaded and verified file: data/test_32x32.mat
CIFAR-10-OOD - dataset: SVHN
conf_mean: 0.7905, nll_mean: 3739.2761, acc_mean: 0.0840, 

In [26]:
import pandas as pd
metrics_dict = ['conf', 'auroc']
method_list = ["Laplace Diffusion", "Laplace Approximation", "Linearised Laplace", "MAP"]
method_dict = {"Laplace Diffusion":metrics_lr, "Laplace Approximation": metrics_posterior,"Linearised Laplace":metrics_lienarised,  "MAP":metrics_map}
mux = pd.MultiIndex.from_product([ids[1:], metrics_dict])
# df_data = {k: [dic[metric+'_mean'] for metric in metrics_dict for dic in method_dict[k][1:]] for k in method_dict}
df_data = {k: ["{:.3f}".format(dic[metric+'_mean']) + u"\u00B1" + "{:.3f}".format(dic[metric+'_std'])  for dic in method_dict[k][1:] for metric in metrics_dict] for k in method_dict}

df = pd.DataFrame.from_dict(df_data, orient='index',
                       columns=mux)


In [27]:
df

Unnamed: 0_level_0,CIFAR-100,CIFAR-100,SVHN,SVHN
Unnamed: 0_level_1,conf,auroc,conf,auroc
Laplace Diffusion,0.791±0.001,0.856±0.002,0.793±0.019,0.852±0.013
Laplace Approximation,0.706±0.037,0.730±0.039,0.757±0.053,0.667±0.068
Linearised Laplace,0.818±0.004,0.839±0.005,0.805±0.028,0.853±0.020
MAP,0.803±0.010,0.877±0.005,0.790±0.028,0.884±0.013


In [28]:
print(df.to_latex(index=True,
                  formatters={"name": str.upper},
                  float_format="{:.3f}".format,
))  

\begin{tabular}{lllll}
\toprule
 & \multicolumn{2}{r}{CIFAR-100} & \multicolumn{2}{r}{SVHN} \\
 & conf & auroc & conf & auroc \\
\midrule
Laplace Diffusion & 0.791±0.001 & 0.856±0.002 & 0.793±0.019 & 0.852±0.013 \\
Laplace Approximation & 0.706±0.037 & 0.730±0.039 & 0.757±0.053 & 0.667±0.068 \\
Linearised Laplace & 0.818±0.004 & 0.839±0.005 & 0.805±0.028 & 0.853±0.020 \\
MAP & 0.803±0.010 & 0.877±0.005 & 0.790±0.028 & 0.884±0.013 \\
\bottomrule
\end{tabular}



| Laplace Diffusion | 0.791±0.001 | 0.856±0.002 | 0.793±0.019 | 0.852±0.013 |
| Laplace Approximation | 0.706±0.037 | 0.730±0.039 | 0.757±0.053 | 0.667±0.068 |
| Linearised Laplace | 0.818±0.004 | 0.839±0.005 | 0.805±0.028 | 0.853±0.020 |
| MAP | 0.803±0.010 | 0.877±0.005 | 0.790±0.028 | 0.884±0.013 |
