In [1]:
#!nvidia-smi
#using a GeForce GTX1080 Ti for reproducibility for all timing experiments

In [2]:
import torch
import torchvision
from torch import nn, optim, autograd
from torch.nn import functional as F
import numpy as np
from sklearn.metrics import roc_auc_score
import scipy
from utils.LB_utils import * 
from utils.load_not_MNIST import notMNIST
import os
import time
import matplotlib.pyplot as plt
from laplace import Laplace

# Single seed shared by every random number generator used in this notebook.
s = 1
for _seed_rng in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_rng(s)

# Ask cuDNN for deterministic kernels and switch off its autotuner so that
# repeated runs pick the same algorithms.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [3]:
# Query CUDA availability once and derive the device string from it.
cuda_status = torch.cuda.is_available()
device = 'cuda' if cuda_status else 'cpu'
print("device: ", device)
print("cuda status: ", cuda_status)

device:  cuda
cuda status:  True


In [4]:
### define network
class ConvNet(nn.Module):
    """Two conv/ReLU/max-pool stages followed by a single linear classifier.

    The linear layer expects ``4 * 4 * 32`` features, which is what the two
    5x5 convolutions and 2x2 poolings produce for a 1-channel 28x28 input
    (28 -> 24 -> 12 -> 8 -> 4).

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # Layers stay inside one Sequential called ``net`` so the state-dict
        # keys (``net.0.*``, ``net.3.*``, ``net.7.*``) remain unchanged.
        layers = [
            nn.Conv2d(1, 16, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 32, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Flatten(),
            nn.Linear(4 * 4 * 32, num_classes),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw class logits for the input batch ``x``."""
        return self.net(x)

In [5]:
# Mini-batch sizes used by the MNIST train / test DataLoaders.
BATCH_SIZE_TRAIN_MNIST = 128
BATCH_SIZE_TEST_MNIST = 128
# Number of passes over the training set; only passed as `max_iter` in the
# (currently commented-out) `train(...)` call.
MAX_ITER_MNIST = 6
# NOTE(review): 10e-6 equals 1e-5. This constant is not referenced in the visible
# cells -- the Adam optimizer is created with a hard-coded lr=1e-3. Confirm which
# learning rate is intended.
LR_TRAIN_MNIST = 10e-6

In [6]:
# Images are only converted to tensors; the same transform object is reused
# for every dataset in this notebook.
MNIST_transform = torchvision.transforms.ToTensor()

# Training split (downloaded on demand) and test split (expected to be on disk
# already, since both splits share the same root directory).
MNIST_train = torchvision.datasets.MNIST(
    root='~/data/mnist',
    train=True,
    transform=MNIST_transform,
    download=True,
)
MNIST_test = torchvision.datasets.MNIST(
    root='~/data/mnist',
    train=False,
    transform=MNIST_transform,
    download=False,
)

# Shuffle only the training batches; keep the test order fixed.
MNIST_train_loader = torch.utils.data.DataLoader(
    dataset=MNIST_train, batch_size=BATCH_SIZE_TRAIN_MNIST, shuffle=True)
MNIST_test_loader = torch.utils.data.DataLoader(
    dataset=MNIST_test, batch_size=BATCH_SIZE_TEST_MNIST, shuffle=False)

In [7]:
# Checkpoint location; the file name encodes the seed `s`.
MNIST_PATH = "pretrained_weights/MNIST_pretrained_10_classes_last_layer_s{}.pth".format(s)

# Loss, freshly initialised network on the selected device, and its optimizer.
loss_function = torch.nn.CrossEntropyLoss()
mnist_model = ConvNet().to(device)
mnist_train_optimizer = torch.optim.Adam(
    mnist_model.parameters(),
    lr=1e-3,
    weight_decay=5e-4,
)

In [8]:
#Training routine

def train(model, train_loader, optimizer, max_iter, path, verbose=True):
    """Train ``model`` for ``max_iter`` passes over ``train_loader`` and save its weights.

    Relies on the notebook-level names ``device``, ``loss_function`` and
    ``get_accuracy``.

    Args:
        model: network to optimise; its ``state_dict`` is written to ``path``.
        train_loader: iterable of ``(x, y)`` batches that also supports ``len``.
        optimizer: optimizer already bound to ``model``'s parameters.
        max_iter: number of passes over ``train_loader``.
        path: destination file for the saved ``state_dict``.
        verbose: if true, print loss and accuracy every 50 batches.
    """
    max_len = len(train_loader)

    # `epoch` instead of `iter` so the builtin is not shadowed.
    for epoch in range(max_iter):
        for batch_idx, (x, y) in enumerate(train_loader):

            x, y = x.to(device), y.to(device)

            output = model(x)

            accuracy = get_accuracy(output, y)

            loss = loss_function(output, y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if verbose and batch_idx % 50 == 0:
                print(
                    "Iteration {}; {}/{} \t".format(epoch, batch_idx, max_len) +
                    "Minibatch Loss %.3f  " % (loss) +
                    "Accuracy %.0f" % (accuracy * 100) + "%"
                )

    print("saving model at: {}".format(path))
    # Save the model that was actually trained, not the notebook global
    # `mnist_model`, which may be a different object or not exist at all.
    torch.save(model.state_dict(), path)

In [9]:
#train(mnist_model, MNIST_train_loader, mnist_train_optimizer, MAX_ITER_MNIST, MNIST_PATH, verbose=True)

In [10]:
#predict in distribution
MNIST_PATH = "pretrained_weights/MNIST_pretrained_10_classes_last_layer_s{}.pth".format(s)
#MNIST_PATH = "pretrained_weights/MNIST_pretrained_10_classes_last_layer.pth"

mnist_model = ConvNet().to(device)
print("loading model from: {}".format(MNIST_PATH))
# map_location makes a checkpoint saved on a GPU loadable when `device` falls
# back to 'cpu'.
mnist_model.load_state_dict(torch.load(MNIST_PATH, map_location=device))
mnist_model.eval()

acc = []          # per-batch accuracy
batch_sizes = []  # number of samples in each batch, used as averaging weights

max_len = len(MNIST_test_loader)
# Pure inference: no autograd graph is needed.
with torch.no_grad():
    for batch_idx, (x, y) in enumerate(MNIST_test_loader):

        x, y = x.to(device), y.to(device)
        output = mnist_model(x)

        accuracy = get_accuracy(output, y)
        if batch_idx % 10 == 0:
            print(
                "Batch {}/{} \t".format(batch_idx, max_len) +
                "Accuracy %.0f" % (accuracy * 100) + "%"
            )
        acc.append(accuracy)
        batch_sizes.append(len(y))

# Weight each batch by its size: the last batch is smaller whenever the test
# set size is not a multiple of the batch size, and a plain mean would
# over-weight it. Assumes get_accuracy returns the fraction correct in a batch.
avg_acc = np.average(acc, weights=batch_sizes)
print('overall test accuracy on MNIST: {:.02f} %'.format(avg_acc * 100))


loading model from: pretrained_weights/MNIST_pretrained_10_classes_last_layer_s1.pth
Batch 0/79 	Accuracy 100%
Batch 10/79 	Accuracy 96%
Batch 20/79 	Accuracy 98%
Batch 30/79 	Accuracy 98%
Batch 40/79 	Accuracy 100%
Batch 50/79 	Accuracy 99%
Batch 60/79 	Accuracy 100%
Batch 70/79 	Accuracy 98%
overall test accuracy on MNIST: 98.84 %


In [11]:
# Mini-batch sizes used by the FashionMNIST and KMNIST test DataLoaders.
BATCH_SIZE_TEST_FMNIST = 128
BATCH_SIZE_TEST_KMNIST = 128

In [12]:
# FashionMNIST test split, used as an out-of-distribution set for the MNIST
# model; it shares the ToTensor transform with MNIST.
FMNIST_test = torchvision.datasets.FashionMNIST(
    root='~/data/fmnist',
    train=False,
    transform=MNIST_transform,
    download=True,
)

FMNIST_test_loader = torch.utils.data.DataLoader(
    dataset=FMNIST_test,
    batch_size=BATCH_SIZE_TEST_FMNIST,
    shuffle=False,
)

In [13]:
# KMNIST test split, a second out-of-distribution set with the same transform.
KMNIST_test = torchvision.datasets.KMNIST(
    root='~/data/kmnist',
    train=False,
    transform=MNIST_transform,
    download=True,
)

KMNIST_test_loader = torch.utils.data.DataLoader(
    dataset=KMNIST_test,
    batch_size=BATCH_SIZE_TEST_KMNIST,
    shuffle=False,
)

In [14]:
# os.path.abspath does not expand '~', so expanduser is used to locate the
# data directory inside the user's home.
root = os.path.expanduser('~/data')

# Dedicated batch size for notMNIST (same value as before) so this loader no
# longer changes silently when the KMNIST batch size is edited.
BATCH_SIZE_TEST_NOTMNIST = 128

# Instantiating the notMNIST dataset class we created
notMNIST_test = notMNIST(root=os.path.join(root, 'notMNIST_small'),
                               transform=MNIST_transform)

# Creating a dataloader
notMNIST_test_loader = torch.utils.data.dataloader.DataLoader(
                            dataset=notMNIST_test,
                            batch_size=BATCH_SIZE_TEST_NOTMNIST,
                            shuffle=False)

File F/Q3Jvc3NvdmVyIEJvbGRPYmxpcXVlLnR0Zg==.png is broken
File A/RGVtb2NyYXRpY2FCb2xkT2xkc3R5bGUgQm9sZC50dGY=.png is broken


# MAP estimate

In [15]:
# Ground-truth labels of each test set as NumPy arrays.
targets, targets_FMNIST, targets_KMNIST = (
    dataset.targets.numpy() for dataset in (MNIST_test, FMNIST_test, KMNIST_test)
)
# The notMNIST labels are additionally cast to int.
targets_notMNIST = notMNIST_test.targets.numpy().astype(int)

In [16]:
# MAP predictions for the in-distribution set and the three OOD sets, in the
# same order as before; every result is moved to the CPU as a NumPy array.
(
    MNIST_test_in_MAP,
    MNIST_test_out_fmnist_MAP,
    MNIST_test_out_notMNIST_MAP,
    MNIST_test_out_KMNIST_MAP,
) = (
    predict_MAP(mnist_model, data_loader, device=device).cpu().numpy()
    for data_loader in (
        MNIST_test_loader,
        FMNIST_test_loader,
        notMNIST_test_loader,
        KMNIST_test_loader,
    )
)

In [17]:
# In-distribution summary values for the MAP predictions.
acc_in_MAP, prob_correct_in_MAP, ent_in_MAP, MMC_in_MAP = get_in_dist_values(
    MNIST_test_in_MAP, targets
)

# Out-of-distribution summary values; each call also receives the MNIST
# predictions as its first argument.
(
    (acc_out_FMNIST_MAP, prob_correct_out_FMNIST_MAP, ent_out_FMNIST_MAP,
     MMC_out_FMNIST_MAP, auroc_out_FMNIST_MAP),
    (acc_out_notMNIST_MAP, prob_correct_out_notMNIST_MAP, ent_out_notMNIST_MAP,
     MMC_out_notMNIST_MAP, auroc_out_notMNIST_MAP),
    (acc_out_KMNIST_MAP, prob_correct_out_KMNIST_MAP, ent_out_KMNIST_MAP,
     MMC_out_KMNIST_MAP, auroc_out_KMNIST_MAP),
) = (
    get_out_dist_values(MNIST_test_in_MAP, ood_probs, ood_targets)
    for ood_probs, ood_targets in (
        (MNIST_test_out_fmnist_MAP, targets_FMNIST),
        (MNIST_test_out_notMNIST_MAP, targets_notMNIST),
        (MNIST_test_out_KMNIST_MAP, targets_KMNIST),
    )
)

In [18]:
# Report the MAP results: one in-distribution line, then one line per OOD set.
print_in_dist_values(acc_in_MAP, prob_correct_in_MAP, ent_in_MAP, MMC_in_MAP, 'MNIST', 'MAP')

for _ood_name, _ood_stats in (
    ('FMNIST', (acc_out_FMNIST_MAP, prob_correct_out_FMNIST_MAP, ent_out_FMNIST_MAP,
                MMC_out_FMNIST_MAP, auroc_out_FMNIST_MAP)),
    ('notMNIST', (acc_out_notMNIST_MAP, prob_correct_out_notMNIST_MAP, ent_out_notMNIST_MAP,
                  MMC_out_notMNIST_MAP, auroc_out_notMNIST_MAP)),
    ('KMNIST', (acc_out_KMNIST_MAP, prob_correct_out_KMNIST_MAP, ent_out_KMNIST_MAP,
                MMC_out_KMNIST_MAP, auroc_out_KMNIST_MAP)),
):
    print_out_dist_values(*_ood_stats, 'MNIST', test=_ood_name, method='MAP')

[In, MAP, MNIST] Accuracy: 0.988; average entropy: 0.047;     MMC: 0.986; Prob @ correct: 0.100
[Out-FMNIST, MAP, MNIST] Accuracy: 0.073; Average entropy: 0.972;    MMC: 0.663; AUROC: 0.977; Prob @ correct: 0.100
[Out-notMNIST, MAP, MNIST] Accuracy: 0.148; Average entropy: 0.620;    MMC: 0.774; AUROC: 0.914; Prob @ correct: 0.100
[Out-KMNIST, MAP, MNIST] Accuracy: 0.093; Average entropy: 0.750;    MMC: 0.723; AUROC: 0.959; Prob @ correct: 0.100


In [19]:
# NOTE(review): not referenced by any cell visible here -- the predict_samples
# calls do not pass it. Confirm whether it should be forwarded or removed.
num_samples = 100

# Diagonal Laplace Approximation (sampling)

In [20]:
# Last-layer Laplace approximation with a diagonal Hessian structure, fitted on
# the MNIST training loader.
# Prior precision is fixed to 1.0 here; the earlier note suggested choosing it
# according to the weight decay (5e-4).
la_diag = Laplace(
    mnist_model,
    'classification',
    subset_of_weights='last_layer',
    hessian_structure='diag',
    prior_precision=1.0,
)
la_diag.fit(MNIST_train_loader)

In [21]:
# Sampling-based predictions from the diagonal Laplace posterior for all four
# test sets (same order as before), each converted to a CPU NumPy array.
(
    MNIST_test_in_D,
    MNIST_test_out_FMNIST_D,
    MNIST_test_out_notMNIST_D,
    MNIST_test_out_KMNIST_D,
) = (
    predict_samples(la_diag, data_loader, timing=True, device=device).cpu().numpy()
    for data_loader in (
        MNIST_test_loader,
        FMNIST_test_loader,
        notMNIST_test_loader,
        KMNIST_test_loader,
    )
)

time:  1.1629913090000006
time:  0.9629555029999999
time:  0.5176742340000011
time:  0.9637451180000021


In [22]:
# compute average log-likelihood for Diag
# Each pair is (predicted class probabilities, integer labels).
for _probs, _labels in (
    (MNIST_test_in_D, targets),
    (MNIST_test_out_FMNIST_D, targets_FMNIST),
    (MNIST_test_out_notMNIST_D, targets_notMNIST),
    (MNIST_test_out_KMNIST_D, targets_KMNIST),
):
    _predictive = torch.distributions.Categorical(probs=torch.tensor(_probs))
    print(_predictive.log_prob(torch.tensor(_labels)).mean())

tensor(-0.0605)
tensor(-3.9607)
tensor(-4.3349)
tensor(-4.9334)


In [23]:
import utils.scoring as scoring

In [24]:
# compute the expected calibration error (ECE) for Diag
for _labels, _probs in (
    (targets, MNIST_test_in_D),
    (targets_FMNIST, MNIST_test_out_FMNIST_D),
    (targets_notMNIST, MNIST_test_out_notMNIST_D),
    (targets_KMNIST, MNIST_test_out_KMNIST_D),
):
    print(scoring.expected_calibration_error(_labels, _probs))

0.03493014141414141
0.46846589898989915
0.409833633277876
0.43425194949494955


In [25]:
# In-distribution summary values for the diagonal Laplace predictions.
acc_in_D, prob_correct_in_D, ent_in_D, MMC_in_D = get_in_dist_values(
    MNIST_test_in_D, targets
)

# Out-of-distribution summary values; each call also receives the MNIST
# predictions as its first argument.
(
    (acc_out_FMNIST_D, prob_correct_out_FMNIST_D, ent_out_FMNIST_D,
     MMC_out_FMNIST_D, auroc_out_FMNIST_D),
    (acc_out_notMNIST_D, prob_correct_out_notMNIST_D, ent_out_notMNIST_D,
     MMC_out_notMNIST_D, auroc_out_notMNIST_D),
    (acc_out_KMNIST_D, prob_correct_out_KMNIST_D, ent_out_KMNIST_D,
     MMC_out_KMNIST_D, auroc_out_KMNIST_D),
) = (
    get_out_dist_values(MNIST_test_in_D, ood_probs, ood_targets)
    for ood_probs, ood_targets in (
        (MNIST_test_out_FMNIST_D, targets_FMNIST),
        (MNIST_test_out_notMNIST_D, targets_notMNIST),
        (MNIST_test_out_KMNIST_D, targets_KMNIST),
    )
)

In [26]:
# Report the diagonal Laplace results: in-distribution first, then each OOD set.
print_in_dist_values(acc_in_D, prob_correct_in_D, ent_in_D, MMC_in_D, 'MNIST', 'Diag')

for _ood_name, _ood_stats in (
    ('fmnist', (acc_out_FMNIST_D, prob_correct_out_FMNIST_D, ent_out_FMNIST_D,
                MMC_out_FMNIST_D, auroc_out_FMNIST_D)),
    ('notMNIST', (acc_out_notMNIST_D, prob_correct_out_notMNIST_D, ent_out_notMNIST_D,
                  MMC_out_notMNIST_D, auroc_out_notMNIST_D)),
    ('KMNIST', (acc_out_KMNIST_D, prob_correct_out_KMNIST_D, ent_out_KMNIST_D,
                MMC_out_KMNIST_D, auroc_out_KMNIST_D)),
):
    print_out_dist_values(*_ood_stats, 'MNIST', test=_ood_name, method='Diag')

[In, Diag, MNIST] Accuracy: 0.989; average entropy: 0.160;     MMC: 0.956; Prob @ correct: 0.100
[Out-fmnist, Diag, MNIST] Accuracy: 0.073; Average entropy: 1.313;    MMC: 0.541; AUROC: 0.975; Prob @ correct: 0.100
[Out-notMNIST, Diag, MNIST] Accuracy: 0.149; Average entropy: 1.212;    MMC: 0.559; AUROC: 0.963; Prob @ correct: 0.100
[Out-KMNIST, Diag, MNIST] Accuracy: 0.095; Average entropy: 1.272;    MMC: 0.529; AUROC: 0.973; Prob @ correct: 0.100


# KFAC Laplace Approximation (sampling)

In [27]:
# Last-layer Laplace approximation with a Kronecker-factored ('kron') Hessian
# structure, fitted on the MNIST training loader.
# Prior precision is fixed to 5.0 here; the earlier note suggested choosing it
# according to the weight decay (5e-4).
la_kron = Laplace(
    mnist_model,
    'classification',
    subset_of_weights='last_layer',
    hessian_structure='kron',
    prior_precision=5.0,
)
la_kron.fit(MNIST_train_loader)

The default behavior has changed from using the upper triangular portion of the matrix by default to using the lower triangular portion.
L, _ = torch.symeig(A, upper=upper)
should be replaced with
L = torch.linalg.eigvalsh(A, UPLO='U' if upper else 'L')
and
L, V = torch.symeig(A, eigenvectors=True)
should be replaced with
L, V = torch.linalg.eigh(A, UPLO='U' if upper else 'L') (Triggered internally at  /opt/conda/conda-bld/pytorch_1639180549130/work/aten/src/ATen/native/BatchLinearAlgebra.cpp:2499.)
  L, W = torch.symeig(M, eigenvectors=True)


In [28]:
# Sampling-based predictions from the KFAC Laplace posterior for all four test
# sets (same order as before), each converted to a CPU NumPy array.
(
    MNIST_test_in_KFAC,
    MNIST_test_out_FMNIST_KFAC,
    MNIST_test_out_notMNIST_KFAC,
    MNIST_test_out_KMNIST_KFAC,
) = (
    predict_samples(la_kron, data_loader, timing=True, device=device).cpu().numpy()
    for data_loader in (
        MNIST_test_loader,
        FMNIST_test_loader,
        notMNIST_test_loader,
        KMNIST_test_loader,
    )
)

time:  1.341417872000001
time:  1.3222581580000003
time:  1.052921824000002
time:  1.2505363789999997


In [29]:
# compute average log-likelihood for KFAC
# Each pair is (predicted class probabilities, integer labels).
for _probs, _labels in (
    (MNIST_test_in_KFAC, targets),
    (MNIST_test_out_FMNIST_KFAC, targets_FMNIST),
    (MNIST_test_out_notMNIST_KFAC, targets_notMNIST),
    (MNIST_test_out_KMNIST_KFAC, targets_KMNIST),
):
    _predictive = torch.distributions.Categorical(probs=torch.tensor(_probs))
    print(_predictive.log_prob(torch.tensor(_labels)).mean())

tensor(-0.0487)
tensor(-3.7331)
tensor(-4.3672)
tensor(-4.8779)


In [30]:
# compute the expected calibration error (ECE) for KFAC
for _labels, _probs in (
    (targets, MNIST_test_in_KFAC),
    (targets_FMNIST, MNIST_test_out_FMNIST_KFAC),
    (targets_notMNIST, MNIST_test_out_notMNIST_KFAC),
    (targets_KMNIST, MNIST_test_out_KMNIST_KFAC),
):
    print(scoring.expected_calibration_error(_labels, _probs))

0.024422959595959588
0.42929906060606066
0.40477670855100895
0.4167330707070709


In [31]:
# In-distribution summary values for the KFAC Laplace predictions.
acc_in_KFAC, prob_correct_in_KFAC, ent_in_KFAC, MMC_in_KFAC = get_in_dist_values(
    MNIST_test_in_KFAC, targets
)

# Out-of-distribution summary values; each call also receives the MNIST
# predictions as its first argument.
(
    (acc_out_FMNIST_KFAC, prob_correct_out_FMNIST_KFAC, ent_out_FMNIST_KFAC,
     MMC_out_FMNIST_KFAC, auroc_out_FMNIST_KFAC),
    (acc_out_notMNIST_KFAC, prob_correct_out_notMNIST_KFAC, ent_out_notMNIST_KFAC,
     MMC_out_notMNIST_KFAC, auroc_out_notMNIST_KFAC),
    (acc_out_KMNIST_KFAC, prob_correct_out_KMNIST_KFAC, ent_out_KMNIST_KFAC,
     MMC_out_KMNIST_KFAC, auroc_out_KMNIST_KFAC),
) = (
    get_out_dist_values(MNIST_test_in_KFAC, ood_probs, ood_targets)
    for ood_probs, ood_targets in (
        (MNIST_test_out_FMNIST_KFAC, targets_FMNIST),
        (MNIST_test_out_notMNIST_KFAC, targets_notMNIST),
        (MNIST_test_out_KMNIST_KFAC, targets_KMNIST),
    )
)

In [32]:
# Report the KFAC Laplace results: in-distribution first, then each OOD set.
print_in_dist_values(acc_in_KFAC, prob_correct_in_KFAC, ent_in_KFAC, MMC_in_KFAC, 'MNIST', 'KFAC')

for _ood_name, _ood_stats in (
    ('fmnist', (acc_out_FMNIST_KFAC, prob_correct_out_FMNIST_KFAC, ent_out_FMNIST_KFAC,
                MMC_out_FMNIST_KFAC, auroc_out_FMNIST_KFAC)),
    ('notMNIST', (acc_out_notMNIST_KFAC, prob_correct_out_notMNIST_KFAC, ent_out_notMNIST_KFAC,
                  MMC_out_notMNIST_KFAC, auroc_out_notMNIST_KFAC)),
    ('KMNIST', (acc_out_KMNIST_KFAC, prob_correct_out_KMNIST_KFAC, ent_out_KMNIST_KFAC,
                MMC_out_KMNIST_KFAC, auroc_out_KMNIST_KFAC)),
):
    print_out_dist_values(*_ood_stats, 'MNIST', test=_ood_name, method='KFAC')

[In, KFAC, MNIST] Accuracy: 0.988; average entropy: 0.111;     MMC: 0.967; Prob @ correct: 0.100
[Out-fmnist, KFAC, MNIST] Accuracy: 0.072; Average entropy: 1.402;    MMC: 0.501; AUROC: 0.985; Prob @ correct: 0.100
[Out-notMNIST, KFAC, MNIST] Accuracy: 0.148; Average entropy: 1.228;    MMC: 0.551; AUROC: 0.969; Prob @ correct: 0.100
[Out-KMNIST, KFAC, MNIST] Accuracy: 0.094; Average entropy: 1.315;    MMC: 0.511; AUROC: 0.982; Prob @ correct: 0.100


# Laplace Bridge (Diagonal)

In [33]:
# Laplace Bridge predictions built on the diagonal Laplace approximation for all
# four test sets (same order as before), each converted to a CPU NumPy array.
# NOTE(review): the saved output of this cell is a long stream of
# "step0/step1, f_mu[0] / f_var[0]" lines, presumably debug prints inside
# predict_LB -- consider silencing them.
(
    MNIST_test_in_LB_D,
    MNIST_test_out_FMNIST_LB_D,
    MNIST_test_out_notMNIST_LB_D,
    MNIST_test_out_KMNIST_LB_D,
) = (
    predict_LB(la_diag, data_loader, timing=True, device=device).cpu().numpy()
    for data_loader in (
        MNIST_test_loader,
        FMNIST_test_loader,
        notMNIST_test_loader,
        KMNIST_test_loader,
    )
)

step0, f_mu[0]:  tensor([ -3.6094,  -2.8937,   1.1239,   1.7301,  -9.0167,  -3.9072, -17.8319,
         13.2881,  -5.2003,   0.8784], device='cuda:0')
step0, f_var[0]:  tensor([ 9.1984, 19.1064,  3.7570,  3.2656,  8.3827,  6.0570, 21.9312,  2.3324,
         4.4827,  3.0562], device='cuda:0')
step1, f_mu[0]:  tensor([ -0.7407,   3.0649,   2.2955,   2.7485,  -6.4024,  -2.0182, -10.9923,
         14.0155,  -3.8023,   1.8315], device='cuda:0')
step1, f_var[0]:  tensor([ 8.1612, 14.6310,  3.5839,  3.1349,  7.5212,  5.6073, 16.0347,  2.2657,
         4.2364,  2.9417], device='cuda:0')
step0, f_mu[0]:  tensor([ -6.6111, -10.5723,  -3.6269,   0.6844,  -8.7426,  -2.2827,  -8.9403,
         -6.8173,  12.5674,  -2.6553], device='cuda:0')
step0, f_var[0]:  tensor([13.1140, 15.2359,  5.4258,  4.1949,  5.9955,  5.0723, 10.2419,  6.4304,
         2.6371,  3.3685], device='cuda:0')
step1, f_mu[0]:  tensor([ 0.1540, -2.7125, -0.8278,  2.8485, -5.6497,  0.3339, -3.6567, -3.5000,
        13.9278, -0.9176

step0, f_mu[0]:  tensor([ -2.4205,  -1.0927,  12.6614,  -3.7460,   1.5083,  -7.0049,  -7.0937,
         -3.2594,  -4.2696, -11.3067], device='cuda:0')
step0, f_var[0]:  tensor([ 9.8454, 19.2091,  3.2367,  5.1547,  5.1200,  8.0952, 10.6291,  5.5678,
         4.5728,  5.5046], device='cuda:0')
step1, f_mu[0]:  tensor([ 0.9098,  5.4048, 13.7562, -2.0024,  3.2401, -4.2667, -3.4984, -1.3760,
        -2.7228, -9.4447], device='cuda:0')
step1, f_var[0]:  tensor([ 8.5855, 14.4130,  3.1005,  4.8094,  4.7793,  7.2434,  9.1607,  5.1648,
         4.3010,  5.1108], device='cuda:0')
step0, f_mu[0]:  tensor([ 13.6350, -17.5037,  -3.3616, -10.7703, -11.6570,  -2.0494,   2.5390,
         -5.1557,  -2.8156,   1.4312], device='cuda:0')
step0, f_var[0]:  tensor([ 9.1573, 96.2460,  9.6556, 10.8528, 19.9625,  9.6480, 14.0623, 14.7582,
         9.7271,  8.2199], device='cuda:0')
step1, f_mu[0]:  tensor([15.2515, -0.5145, -1.6572, -8.8545, -8.1333, -0.3463,  5.0212, -2.5506,
        -1.0986,  2.8822], device=

step1, f_mu[0]:  tensor([-2.7132,  2.3923, -0.0475, -0.0757,  1.9512, -1.2021, -8.1042,  5.2061,
        -1.2412,  3.8345], device='cuda:0')
step1, f_var[0]:  tensor([3.6896, 4.2831, 1.8220, 1.3682, 1.4131, 2.1399, 4.9819, 1.1199, 1.4760,
        0.8313], device='cuda:0')
step0, f_mu[0]:  tensor([ -8.2103,  -6.5736,  -6.3433,  14.9028,  -9.8334,  -0.5142, -14.0434,
         -6.8587,  -2.5115,   3.5891], device='cuda:0')
step0, f_var[0]:  tensor([10.0924, 12.3575,  3.7774,  2.1256,  7.2197,  3.2871, 22.5039,  5.1997,
         3.3778,  2.7117], device='cuda:0')
step1, f_mu[0]:  tensor([-3.1544, -0.3829, -4.4509, 15.9676, -6.2165,  1.1325, -2.7698, -4.2538,
        -0.8193,  4.9476], device='cuda:0')
step1, f_var[0]:  tensor([ 8.6904, 10.2556,  3.5810,  2.0634,  6.5023,  3.1384, 15.5334,  4.8275,
         3.2208,  2.6105], device='cuda:0')
step0, f_mu[0]:  tensor([-12.3301, -12.9760,  -7.9699,  -1.2217,  -1.5923,   1.8071, -13.5090,
          1.4722,   0.8682,  11.3409], device='cuda:0')


step1, f_var[0]:  tensor([10.9020, 24.1866,  6.3537,  8.0504,  4.1662, 14.2940, 11.0293,  6.6561,
         5.4223,  4.7247], device='cuda:0')
step0, f_mu[0]:  tensor([-10.7227, -11.5696,  -4.5761,  11.9875,  -9.4027,  -2.1649, -13.8657,
         -9.1945,   0.4282,  -2.2268], device='cuda:0')
step0, f_var[0]:  tensor([14.0162, 39.8158,  6.6781,  4.1176,  9.9998,  5.4799, 14.8178, 10.2841,
         4.2093,  5.2372], device='cuda:0')
step1, f_mu[0]:  tensor([-4.4506,  6.2475, -1.5878, 13.8300, -4.9279,  0.2873, -7.2348, -4.5925,
         2.3119,  0.1168], device='cuda:0')
step1, f_var[0]:  tensor([12.3028, 25.9892,  6.2891,  3.9697,  9.1277,  5.2180, 12.9028,  9.3617,
         4.0548,  4.9980], device='cuda:0')
step0, f_mu[0]:  tensor([  7.4148,  -7.3448,  -0.1742, -11.0964,   5.2517,  -9.3631,   2.6035,
         -5.1806,  -3.8293,  -2.8050], device='cuda:0')
step0, f_var[0]:  tensor([ 5.3299, 37.0128,  5.1150,  7.0971,  5.6572,  9.3819,  7.3317,  7.6210,
         4.4433,  4.0729], device

step0, f_mu[0]:  tensor([ -7.3322, -16.1601,  -5.6659,  -0.1366,  -7.8021,  11.3574,   1.9363,
        -13.8387,  -0.2925,   2.7148], device='cuda:0')
step0, f_var[0]:  tensor([14.4723, 38.7383,  6.6153,  5.2797,  8.0176,  4.0354,  9.6286,  9.2956,
         3.9236,  4.6587], device='cuda:0')
step1, f_mu[0]:  tensor([ -2.4623,  -3.1248,  -3.4398,   1.6400,  -5.1042,  12.7152,   5.1763,
        -10.7108,   1.0278,   4.2825], device='cuda:0')
step1, f_var[0]:  tensor([12.4712, 24.4006,  6.1972,  5.0133,  7.4035,  3.8798,  8.7428,  8.4700,
         3.7765,  4.4513], device='cuda:0')
step0, f_mu[0]:  tensor([ -5.6137,  -5.1030,  10.2980,   0.6372,  -8.4535,  -9.4190, -16.5559,
          7.9529,  -1.2843,  -9.5548], device='cuda:0')
step0, f_var[0]:  tensor([11.4453, 22.1223,  3.3899,  4.2846,  7.0496,  7.6077, 16.0532,  4.4532,
         3.9972,  5.7107], device='cuda:0')
step1, f_mu[0]:  tensor([-0.6833,  4.4269, 11.7583,  2.4829, -5.4167, -6.1417, -9.6405,  9.8712,
         0.4376, -7.0947

step0, f_mu[0]:  tensor([-3.8912, -5.4246,  2.1494, -2.2608,  0.3674, -1.1233,  2.3072, -0.5586,
         0.1116, -7.2599], device='cuda:0')
step0, f_var[0]:  tensor([ 4.9928, 11.5326,  1.8702,  1.8708,  1.9759,  2.0407,  3.2257,  2.2997,
         1.3382,  2.0545], device='cuda:0')
step1, f_mu[0]:  tensor([-1.5479, -0.0118,  3.0271, -1.3827,  1.2948, -0.1655,  3.8212,  0.5208,
         0.7397, -6.2957], device='cuda:0')
step1, f_var[0]:  tensor([4.2420, 7.5267, 1.7649, 1.7654, 1.8583, 1.9153, 2.9123, 2.1404, 1.2843,
        1.9273], device='cuda:0')
step0, f_mu[0]:  tensor([-3.5915, -2.4685,  3.8069,  0.4177, -4.5619,  0.4311, -6.5017, -2.0782,
        -1.4323, -5.5802], device='cuda:0')
step0, f_var[0]:  tensor([ 6.3439, 17.1386,  2.6421,  2.7698,  7.7159,  3.6608, 14.1505,  3.9116,
         3.6667,  3.8419], device='cuda:0')
step1, f_mu[0]:  tensor([-1.5143,  3.1432,  4.6720,  1.3246, -2.0354,  1.6297, -1.8684, -0.7974,
        -0.2317, -4.3222], device='cuda:0')
step1, f_var[0]:  te

step1, f_mu[0]:  tensor([ 3.3334, -0.5674,  0.7072, -2.3143, -4.3257, -0.7810,  1.2389,  0.6388,
         3.0936, -1.0235], device='cuda:0')
step1, f_var[0]:  tensor([3.9703, 6.8629, 2.1552, 1.9830, 2.0614, 2.1379, 4.7232, 1.9206, 1.5717,
        1.5688], device='cuda:0')
step0, f_mu[0]:  tensor([-0.6002,  2.4073,  4.1011, -2.3116, -4.0396, -2.1470, -2.5534, -4.2742,
        -0.3344, -6.4289], device='cuda:0')
step0, f_var[0]:  tensor([3.3447, 4.0762, 1.1833, 1.4465, 2.3776, 1.9574, 5.4853, 2.4155, 1.3502,
        1.9136], device='cuda:0')
step1, f_mu[0]:  tensor([ 1.5180,  4.9887,  4.8505, -1.3955, -2.5339, -0.9074,  0.9204, -2.7444,
         0.5207, -5.2170], device='cuda:0')
step1, f_var[0]:  tensor([2.9068, 3.4259, 1.1285, 1.3646, 2.1563, 1.8074, 4.3077, 2.1872, 1.2788,
        1.7703], device='cuda:0')
step0, f_mu[0]:  tensor([-0.8420, -1.8745, -5.3110, -8.5985,  3.9689, -2.1066, -2.8101,  0.0902,
        -0.6538, -2.0654], device='cuda:0')
step0, f_var[0]:  tensor([ 4.2333, 12.91

step1, f_var[0]:  tensor([3.3528, 3.1961, 1.6847, 1.6665, 1.1985, 1.8714, 2.5825, 1.7585, 0.9729,
        1.1733], device='cuda:0')
step0, f_mu[0]:  tensor([ 0.0749, -1.5503, -3.1658, -3.6943,  3.6698, -1.9924, -0.0581,  0.3325,
         0.6844, -1.3276], device='cuda:0')
step0, f_var[0]:  tensor([1.1463, 6.7958, 1.0654, 1.1419, 0.8792, 1.3916, 1.2589, 1.0310, 0.5790,
        0.5894], device='cuda:0')
step1, f_mu[0]:  tensor([ 0.5822,  1.4571, -2.6943, -3.1890,  4.0589, -1.3765,  0.4990,  0.7887,
         0.9406, -1.0668], device='cuda:0')
step1, f_var[0]:  tensor([1.0635, 3.8873, 0.9939, 1.0598, 0.8306, 1.2697, 1.1591, 0.9640, 0.5579,
        0.5675], device='cuda:0')
step0, f_mu[0]:  tensor([-3.1099, -0.2227, -4.1336, -2.4969,  1.6403, -1.4508, -1.6236, -2.0231,
         2.1215, -2.1072], device='cuda:0')
step0, f_var[0]:  tensor([1.8495, 2.2515, 0.9576, 0.9211, 0.7215, 1.0404, 1.3367, 1.0534, 0.4757,
        0.6233], device='cuda:0')
step1, f_mu[0]:  tensor([-0.9021,  2.4648, -2.990

step0, f_mu[0]:  tensor([ 1.0713, -4.5338, -1.4217, -2.0636, -0.0055, -2.4988,  1.3922, -3.6649,
         2.7109, -0.2511], device='cuda:0')
step0, f_var[0]:  tensor([ 1.9586, 12.3959,  1.5417,  1.4738,  1.6830,  1.3295,  1.9623,  1.9270,
         0.8455,  0.9772], device='cuda:0')
step1, f_mu[0]:  tensor([ 1.7668, -0.1326, -0.8743, -1.5403,  0.5920, -2.0268,  2.0889, -2.9807,
         3.0111,  0.0958], device='cuda:0')
step1, f_var[0]:  tensor([1.8116, 6.5074, 1.4506, 1.3906, 1.5744, 1.2617, 1.8147, 1.7847, 0.8181,
        0.9406], device='cuda:0')
step0, f_mu[0]:  tensor([ 1.8767, -3.8146,  3.6277, -3.0152, -8.3996, -2.9255,  1.8478,  0.2345,
         1.2518, -3.0471], device='cuda:0')
step0, f_var[0]:  tensor([ 7.9338, 15.1925,  3.3516,  3.2767,  3.1756,  3.6347,  7.8280,  3.2262,
         2.5772,  2.7559], device='cuda:0')
step1, f_mu[0]:  tensor([ 3.7291, -0.2674,  4.4103, -2.2501, -7.6582, -2.0768,  3.6755,  0.9877,
         1.8535, -2.4037], device='cuda:0')
step1, f_var[0]:  te

step0, f_mu[0]:  tensor([ 0.8253, -3.5672,  1.4977,  2.7449, -8.0567, -0.3746, -6.9156,  1.1370,
        -4.2357, -0.8467], device='cuda:0')
step0, f_var[0]:  tensor([ 2.5163, 14.6274,  1.1939,  1.2852,  3.8519,  1.8051,  7.5871,  1.8567,
         1.5195,  1.2778], device='cuda:0')
step1, f_mu[0]:  tensor([ 2.0184,  3.3688,  2.0638,  3.3543, -6.2302,  0.4814, -3.3180,  2.0175,
        -3.5152, -0.2408], device='cuda:0')
step1, f_var[0]:  tensor([2.3475, 8.9250, 1.1559, 1.2412, 3.4565, 1.7182, 6.0529, 1.7649, 1.4579,
        1.2343], device='cuda:0')
step0, f_mu[0]:  tensor([-0.7383, -3.1145,  2.1581, -5.9347,  2.4458, -4.4684,  2.2279, -6.0496,
         0.4835, -1.3740], device='cuda:0')
step0, f_var[0]:  tensor([ 3.4548, 16.5480,  2.2714,  2.7440,  3.3378,  2.6265,  3.4804,  3.3670,
         1.6653,  2.3041], device='cuda:0')
step1, f_mu[0]:  tensor([ 0.4489,  2.5722,  2.9387, -4.9917,  3.5928, -3.5658,  3.4239, -4.8925,
         1.0558, -0.5822], device='cuda:0')
step1, f_var[0]:  te

step0, f_mu[0]:  tensor([  6.7073,  -1.5207,  -5.4298, -10.2947,  -0.2684,  -2.8746,   3.8916,
         -0.9931,  -0.3493,  -2.1968], device='cuda:0')
step0, f_var[0]:  tensor([ 5.6827, 73.0079,  6.8608,  7.6140, 11.1102,  5.6666,  8.6844,  8.7942,
         4.0148,  5.6608], device='cuda:0')
step1, f_mu[0]:  tensor([ 7.2597,  5.5772, -4.7628, -9.5545,  0.8118, -2.3237,  4.7359, -0.1381,
         0.0410, -1.6464], device='cuda:0')
step1, f_var[0]:  tensor([ 5.4471, 34.1290,  6.5175,  7.1911, 10.2098,  5.4324,  8.1343,  8.2301,
         3.8973,  5.4271], device='cuda:0')
step0, f_mu[0]:  tensor([ 1.7994,  0.7980,  0.6430,  1.5861, -5.1248, -0.5785, -1.2213, -1.7858,
        -2.7957,  0.3945], device='cuda:0')
step0, f_var[0]:  tensor([2.2793, 4.7840, 0.8610, 0.7122, 1.1624, 0.7573, 1.9010, 1.3789, 0.6309,
        0.7219], device='cuda:0')
step1, f_mu[0]:  tensor([ 2.7426,  2.7776,  0.9992,  1.8808, -4.6438, -0.2652, -0.4347, -1.2152,
        -2.5346,  0.6932], device='cuda:0')
step1, f_v

step1, f_mu[0]:  tensor([-1.5873,  3.9845, -4.3006,  1.9068, -3.6602,  3.0261, -3.1461,  2.0753,
         1.7292, -0.0277], device='cuda:0')
step1, f_var[0]:  tensor([14.1408, 48.5160, 10.7754,  6.8554, 17.0003,  6.9551, 20.4122, 12.1741,
         7.0632,  7.9712], device='cuda:0')
step0, f_mu[0]:  tensor([-3.2344, -5.6021, -8.0270, -6.3837,  5.4186, -2.4644,  0.6813, -1.3214,
         0.6719, -4.7977], device='cuda:0')
step0, f_var[0]:  tensor([15.1946, 76.5571, 10.5927,  9.1501, 12.8258,  8.9197, 14.4488, 12.8262,
         5.9346,  8.1726], device='cuda:0')
step1, f_mu[0]:  tensor([-1.0539,  5.3842, -6.5069, -5.0707,  7.2591, -1.1844,  2.7548,  0.5192,
         1.5236, -3.6249], device='cuda:0')
step1, f_var[0]:  tensor([13.8724, 42.9932,  9.9502,  8.6706, 11.8837,  8.4640, 13.2532, 11.8841,
         5.7329,  7.7901], device='cuda:0')
step0, f_mu[0]:  tensor([-6.4474, -2.2754, -4.4878,  6.0118, -2.3763,  3.2498, -7.9771, -1.5012,
        -3.2277, -2.6362], device='cuda:0')
step0, f_v

step0, f_var[0]:  tensor([3.8285, 5.3966, 1.4789, 2.0606, 2.6590, 2.8423, 4.4913, 2.1766, 1.2913,
        2.1732], device='cuda:0')
step1, f_mu[0]:  tensor([ 1.5583,  6.1410,  4.1440, -0.1945, -2.5783, -3.0054, -0.4140,  1.3782,
         0.4995, -7.5289], device='cuda:0')
step1, f_var[0]:  tensor([3.3124, 4.3711, 1.4019, 1.9111, 2.4100, 2.5578, 3.7810, 2.0098, 1.2326,
        2.0069], device='cuda:0')
step0, f_mu[0]:  tensor([-3.8305, -9.8095,  0.7674, -1.5661, -2.6121,  0.3019, -0.8441, -3.3783,
         2.2514, -1.8043], device='cuda:0')
step0, f_var[0]:  tensor([ 9.1939, 33.5619,  4.0158,  3.5788,  6.3018,  3.9180,  7.5733,  5.5444,
         2.7048,  3.9675], device='cuda:0')
step1, f_mu[0]:  tensor([-1.4823, -1.2377,  1.7930, -0.6521, -1.0026,  1.3025,  1.0901, -1.9622,
         2.9422, -0.7910], device='cuda:0')
step1, f_var[0]:  tensor([ 8.1420, 19.5450,  3.8151,  3.4195,  5.8076,  3.7270,  6.8596,  5.1618,
         2.6138,  3.7716], device='cuda:0')
step0, f_mu[0]:  tensor([-3.7

step1, f_mu[0]:  tensor([-3.1930,  0.2705, -4.5591, -2.3412,  5.0688,  2.2197, -1.3023, -1.4098,
         3.0307,  2.2157], device='cuda:0')
step1, f_var[0]:  tensor([ 5.0543, 13.3143,  3.7692,  3.2147,  4.0787,  2.6019,  5.9694,  4.2784,
         2.4637,  2.1265], device='cuda:0')
step0, f_mu[0]:  tensor([  0.7118,  -1.1749,  -1.5926,   1.0758,  -9.4648,   3.4224,  -1.0813,
        -11.2603,  -3.8820,   2.3123], device='cuda:0')
step0, f_var[0]:  tensor([13.9205, 43.9260,  7.3163,  5.6562, 21.9741,  6.9556, 33.0340, 12.7496,
         7.6191,  6.6708], device='cuda:0')
step1, f_mu[0]:  tensor([ 2.5351,  4.5786, -0.6343,  1.8166, -6.5866,  4.3334,  3.2455, -9.5904,
        -2.8841,  3.1860], device='cuda:0')
step1, f_var[0]:  tensor([12.7081, 31.8532,  6.9813,  5.4561, 18.9529,  6.6529, 26.2061, 11.7325,
         7.2559,  6.3923], device='cuda:0')
step0, f_mu[0]:  tensor([  3.0313,   0.1042,  -3.2226,  -0.3474,  -5.2227,   2.1393,   2.0763,
        -10.8434,  -3.2872,  -0.4879], device=

step0, f_var[0]:  tensor([ 6.2930, 27.6527,  3.3087,  2.8523,  7.1357,  3.7958,  8.1901,  5.0366,
         2.7109,  3.5468], device='cuda:0')
step1, f_mu[0]:  tensor([ 1.1613,  2.0298,  1.2787, -1.6954, -0.9694, -1.4112, -0.5415,  2.2922,
        -0.9357, -1.2087], device='cuda:0')
step1, f_var[0]:  tensor([ 5.7314, 16.8098,  3.1535,  2.7370,  6.4137,  3.5915,  7.2389,  4.6769,
         2.6067,  3.3684], device='cuda:0')
step0, f_mu[0]:  tensor([-4.1470, -3.9269, -5.6807, -8.0525,  4.8576, -1.6890,  2.3708, -5.1618,
         2.1250, -2.1735], device='cuda:0')
step0, f_var[0]:  tensor([ 8.6832, 32.3382,  6.5736,  6.5706,  7.9684,  6.5342,  8.8113,  8.7701,
         3.4150,  4.8141], device='cuda:0')
step1, f_mu[0]:  tensor([-2.1730,  3.4246, -4.1863, -6.5588,  6.6690, -0.2036,  4.3739, -3.1681,
         2.9014, -1.0792], device='cuda:0')
step1, f_var[0]:  tensor([ 7.8852, 21.2695,  6.1162,  6.1136,  7.2963,  6.0823,  7.9896,  7.9560,
         3.2916,  4.5688], device='cuda:0')
step0, f_

step0, f_mu[0]:  tensor([-13.5623,   3.5801,   1.4305,  -1.4436,   8.8247,  -1.6760, -10.5430,
          0.1322,  -5.3014,  -4.2684], device='cuda:0')
step0, f_var[0]:  tensor([13.6122, 24.1658,  4.5791,  4.8090,  3.8474,  6.7046, 10.9891,  5.2969,
         4.0161,  4.0542], device='cuda:0')
step1, f_mu[0]:  tensor([-9.7764, 10.3013,  2.7040, -0.1060,  9.8948,  0.1887, -7.4866,  1.6054,
        -4.1844, -3.1409], device='cuda:0')
step1, f_var[0]:  tensor([11.3546, 17.0505,  4.3236,  4.5272,  3.6670,  6.1569,  9.5177,  4.9551,
         3.8196,  3.8539], device='cuda:0')
step0, f_mu[0]:  tensor([-5.5741,  1.0453, -4.9247, -0.3809,  3.2019, -3.7375, -6.5397, -3.1175,
        -0.0580,  0.8246], device='cuda:0')
step0, f_var[0]:  tensor([ 5.2153, 11.3814,  2.3019,  2.1405,  3.0424,  2.6800,  4.7024,  2.9599,
         1.5735,  2.1246], device='cuda:0')
step1, f_mu[0]:  tensor([-2.9392,  6.7956, -3.7617,  0.7006,  4.7390, -2.3834, -4.1639, -1.6221,
         0.7370,  1.8980], device='cuda:0')


step0, f_var[0]:  tensor([ 3.2112, 10.3699,  2.4455,  2.4838,  1.5745,  2.4410,  2.7290,  2.4985,
         1.2041,  1.2462], device='cuda:0')
step1, f_mu[0]:  tensor([-0.5637,  1.9244, -4.0261, -2.3332,  3.1806, -0.6691,  2.2073, -0.6091,
         2.3598, -1.4709], device='cuda:0')
step1, f_var[0]:  tensor([2.8698, 6.8096, 2.2475, 2.2796, 1.4924, 2.2437, 2.4824, 2.2918, 1.1561,
        1.1948], device='cuda:0')
step0, f_mu[0]:  tensor([-0.1632, -1.5174,  0.3894,  0.2020, -0.0903, -0.4970,  0.2381, -2.1485,
         1.7345, -0.4975], device='cuda:0')
step0, f_var[0]:  tensor([0.8242, 2.5401, 0.3672, 0.3340, 0.4643, 0.3808, 0.6882, 0.5270, 0.2438,
        0.2868], device='cuda:0')
step1, f_mu[0]:  tensor([ 0.1278, -0.6206,  0.5190,  0.3199,  0.0736, -0.3625,  0.4811, -1.9625,
         1.8205, -0.3963], device='cuda:0')
step1, f_var[0]:  tensor([0.7221, 1.5708, 0.3470, 0.3172, 0.4319, 0.3590, 0.6170, 0.4853, 0.2349,
        0.2744], device='cuda:0')
step0, f_mu[0]:  tensor([-7.1064,  2.16

step0, f_mu[0]:  tensor([-5.8109, -4.8035, -2.9815, -6.9332,  8.6822, -7.3329, -4.2180, -1.6483,
        -0.8311, -1.9294], device='cuda:0')
step0, f_var[0]:  tensor([ 21.6076, 115.9355,  11.9453,   9.9753,  11.0736,  12.7537,  17.4023,
         12.3189,   6.9026,   7.7552], device='cuda:0')
step1, f_mu[0]:  tensor([-3.1718,  9.3564, -1.5225, -5.7148, 10.0347, -5.7753, -2.0926, -0.1438,
         0.0120, -0.9822], device='cuda:0')
step1, f_var[0]:  tensor([19.5569, 56.8981, 11.3185,  9.5382, 10.5350, 12.0393, 16.0721, 11.6523,
         6.6933,  7.4911], device='cuda:0')
step0, f_mu[0]:  tensor([-5.0628, -1.3775,  0.1699, -4.6908,  5.2528, -3.4406, -0.0125, -1.4704,
         0.2413, -2.3427], device='cuda:0')
step0, f_var[0]:  tensor([ 6.3372, 15.9618,  3.0543,  2.7440,  2.4836,  3.4499,  5.7189,  3.4058,
         1.8809,  2.1223], device='cuda:0')
step1, f_mu[0]:  tensor([-3.3516,  2.9323,  0.9946, -3.9499,  5.9234, -2.5091,  1.5316, -0.5508,
         0.7492, -1.7697], device='cuda:0')


step0, f_mu[0]:  tensor([  3.7460,  -8.5728,  -3.7698,  -5.1472, -11.0471,   5.2182,   0.9237,
         -5.5666,  -2.1872,   3.8587], device='cuda:0')
step0, f_var[0]:  tensor([ 8.6878, 58.5858,  6.6308,  6.5978, 13.5020,  5.4036, 11.1615, 10.1669,
         5.2914,  5.1790], device='cuda:0')
step1, f_mu[0]:  tensor([ 5.2387,  1.4935, -2.6305, -4.0135, -8.7272,  6.1466,  2.8415, -3.8197,
        -1.2780,  4.7485], device='cuda:0')
step1, f_var[0]:  tensor([ 8.1126, 32.4263,  6.2957,  6.2661, 12.1125,  5.1810, 10.2120,  9.3791,
         5.0780,  4.9746], device='cuda:0')
step0, f_mu[0]:  tensor([-5.8443, -4.2846,  0.5292, -1.4633, -1.3367, -1.8057,  5.6191, -4.6015,
        -5.0517, -6.1606], device='cuda:0')
step0, f_var[0]:  tensor([19.0945, 57.8240,  6.8980,  7.1098,  8.5194,  6.5201, 11.2555,  9.4169,
         5.4710,  8.0154], device='cuda:0')
step1, f_mu[0]:  tensor([-2.5193,  5.7844,  1.7304, -0.2252,  0.1468, -0.6703,  7.5791, -2.9617,
        -4.0990, -4.7649], device='cuda:0')


step0, f_mu[0]:  tensor([-5.6872, -2.1306, -0.0402, -1.6418,  3.0353, -4.7075, -9.1239, -4.4894,
         2.6481, -2.6662], device='cuda:0')
step0, f_var[0]:  tensor([5.4744, 9.6562, 2.7784, 2.6809, 3.8586, 3.9873, 6.1617, 4.2452, 1.7340,
        2.5535], device='cuda:0')
step1, f_mu[0]:  tensor([-2.5390,  3.4225,  1.5577, -0.1000,  5.2543, -2.4145, -5.5805, -2.0480,
         3.6452, -1.1978], device='cuda:0')
step1, f_var[0]:  tensor([4.7796, 7.4943, 2.5994, 2.5143, 3.5134, 3.6187, 5.2814, 3.8274, 1.6643,
        2.4023], device='cuda:0')
step0, f_mu[0]:  tensor([-4.4363, -1.5070, -3.7444, -1.6539, -8.5454,  4.5499, -4.1115, -3.2186,
        -0.8122, -8.2297], device='cuda:0')
step0, f_var[0]:  tensor([14.3154, 55.1203,  7.5657,  6.0215, 10.9372,  6.5771, 17.8484,  8.9522,
         5.3591,  6.2703], device='cuda:0')
step1, f_mu[0]:  tensor([-1.1699, 11.0701, -2.0181, -0.2799, -6.0498,  6.0507, -0.0389, -1.1760,
         0.4106, -6.7989], device='cuda:0')
step1, f_var[0]:  tensor([12.8

step0, f_mu[0]:  tensor([-4.6128,  3.4722,  1.6739, -8.6668,  8.4733, -7.6088, -6.4494,  2.0777,
        -5.7309,  1.4885], device='cuda:0')
step0, f_var[0]:  tensor([10.4677, 48.4370,  6.3375,  5.9794,  5.3429,  8.9858, 12.8983,  6.4673,
         5.2172,  3.3898], device='cuda:0')
step1, f_mu[0]:  tensor([-3.1482, 10.2490,  2.5606, -7.8303,  9.2208, -6.3516, -4.6448,  2.9826,
        -5.0009,  1.9628], device='cuda:0')
step1, f_var[0]:  tensor([ 9.5025, 27.7703,  5.9837,  5.6644,  5.0914,  8.2746, 11.4328,  6.0989,
         4.9774,  3.2886], device='cuda:0')
step0, f_mu[0]:  tensor([-11.2651,  -9.0484,  -0.3242,   1.4721,   2.4681,  -1.4897, -15.7098,
         -2.2819,  -2.3332,   2.2522], device='cuda:0')
step0, f_var[0]:  tensor([11.7344, 26.4107,  4.9182,  4.0704,  7.3134,  5.5347, 17.7437,  5.7253,
         4.5423,  4.4262], device='cuda:0')
step1, f_mu[0]:  tensor([-6.6612,  1.3136,  1.6054,  3.0691,  5.3374,  0.6818, -8.7483, -0.0356,
        -0.5511,  3.9888], device='cuda:0')


step1, f_var[0]:  tensor([ 9.1813, 19.7663,  5.3184,  4.5150,  7.2269,  4.9211, 14.5395,  4.1612,
         4.6689,  3.6753], device='cuda:0')
step0, f_mu[0]:  tensor([ -1.8975,   7.5904,   2.9152,  -6.5281,  -0.9744, -11.7696,  -7.5872,
         -1.0047,  -1.9743,  -6.6970], device='cuda:0')
step0, f_var[0]:  tensor([12.3668, 21.6634,  5.2225,  6.1124,  8.1313,  7.5872, 13.5003,  7.9034,
         4.9695,  6.7643], device='cuda:0')
step1, f_mu[0]:  tensor([ 1.7680, 14.0114,  4.4631, -4.7164,  1.4358, -9.5207, -3.5857,  1.3379,
        -0.5014, -4.6920], device='cuda:0')
step1, f_var[0]:  tensor([10.7437, 16.6825,  4.9330,  5.7158,  7.4296,  6.9762, 11.5660,  7.2405,
         4.7074,  6.2787], device='cuda:0')
step0, f_mu[0]:  tensor([ -2.8484,   3.3361,  -0.6029, -10.5454,   4.8237,  -1.3438,  -0.5975,
         -0.3477,  -1.2627,  -6.8201], device='cuda:0')
step0, f_var[0]:  tensor([ 6.8835, 25.2164,  4.3093,  4.8547,  5.9654,  4.6108,  7.4774,  6.2120,
         2.9619,  4.3460], device

step0, f_mu[0]:  tensor([ -4.0572,  -3.7610,   1.3121,   1.9487,   0.1938,  -2.5542, -10.3887,
         -2.1740,  -5.6533,  -2.4454], device='cuda:0')
step0, f_var[0]:  tensor([ 5.0774, 11.1545,  2.2036,  1.9875,  3.8324,  2.9345,  9.2353,  2.7417,
         2.3956,  2.0686], device='cuda:0')
step1, f_mu[0]:  tensor([-0.8478,  3.2898,  2.7049,  3.2050,  2.6163, -0.6993, -4.5511, -0.4410,
        -4.1390, -1.1379], device='cuda:0')
step1, f_var[0]:  tensor([4.4865, 8.3028, 2.0923, 1.8970, 3.4958, 2.7372, 7.2805, 2.5694, 2.2641,
        1.9705], device='cuda:0')
step0, f_mu[0]:  tensor([ -4.9548,   2.8154,  -1.3379,  -0.8271,   0.9781,  -3.0964,  -6.1968,
         -0.5623,  -1.4465, -13.2571], device='cuda:0')
step0, f_var[0]:  tensor([10.6329, 20.1923,  4.9239,  5.1486,  7.8031,  7.4536, 12.9151,  7.2983,
         3.8855,  5.9627], device='cuda:0')
step1, f_mu[0]:  tensor([ -1.5157,   9.3464,   0.2546,   0.8381,   3.5019,  -0.6856,  -2.0196,
          1.7982,  -0.1898, -11.3286], device=

In [34]:
# Average log-likelihood of the labels under the LB_D predictive probabilities.
# One value is printed per dataset, in the order MNIST, FMNIST, notMNIST, KMNIST.
# (Underscore-prefixed loop names keep the temporaries out of the way of notebook globals.)
for _ll_probs, _ll_labels in (
    (MNIST_test_in_LB_D, targets),
    (MNIST_test_out_FMNIST_LB_D, targets_FMNIST),
    (MNIST_test_out_notMNIST_LB_D, targets_notMNIST),
    (MNIST_test_out_KMNIST_LB_D, targets_KMNIST),
):
    # The prediction array is passed as the `probs` argument of the categorical.
    _ll_dist = torch.distributions.Categorical(torch.tensor(_ll_probs))
    print(_ll_dist.log_prob(torch.tensor(_ll_labels)).mean())

tensor(-0.1359)
tensor(-2.9849)
tensor(-2.7250)
tensor(-2.8325)


In [35]:
# Expected calibration error (ECE) of the LB_D predictions.
# One value is printed per dataset, in the order MNIST, FMNIST, notMNIST, KMNIST.
for _ece_labels, _ece_probs in (
    (targets, MNIST_test_in_LB_D),
    (targets_FMNIST, MNIST_test_out_FMNIST_LB_D),
    (targets_notMNIST, MNIST_test_out_notMNIST_LB_D),
    (targets_KMNIST, MNIST_test_out_KMNIST_LB_D),
):
    # Argument order is (labels, predictions), as in the original calls.
    print(scoring.expected_calibration_error(_ece_labels, _ece_probs))

0.09263091919191915
0.39792657575757584
0.2774017411888594
0.28065724242424245


In [36]:
# Summary values for the LB_D predictions on the MNIST test set.
(
    acc_in_LB_D,
    prob_correct_in_LB_D,
    ent_in_LB_D,
    MMC_in_LB_D,
) = get_in_dist_values(MNIST_test_in_LB_D, targets)

# Each out-of-distribution call also receives the MNIST predictions as its first
# argument and returns one extra value, stored under the auroc_* name.
(
    acc_out_FMNIST_LB_D,
    prob_correct_out_FMNIST_LB_D,
    ent_out_FMNIST_LB_D,
    MMC_out_FMNIST_LB_D,
    auroc_out_FMNIST_LB_D,
) = get_out_dist_values(
    MNIST_test_in_LB_D,
    MNIST_test_out_FMNIST_LB_D,
    targets_FMNIST,
)

(
    acc_out_notMNIST_LB_D,
    prob_correct_out_notMNIST_LB_D,
    ent_out_notMNIST_LB_D,
    MMC_out_notMNIST_LB_D,
    auroc_out_notMNIST_LB_D,
) = get_out_dist_values(
    MNIST_test_in_LB_D,
    MNIST_test_out_notMNIST_LB_D,
    targets_notMNIST,
)

(
    acc_out_KMNIST_LB_D,
    prob_correct_out_KMNIST_LB_D,
    ent_out_KMNIST_LB_D,
    MMC_out_KMNIST_LB_D,
    auroc_out_KMNIST_LB_D,
) = get_out_dist_values(
    MNIST_test_in_LB_D,
    MNIST_test_out_KMNIST_LB_D,
    targets_KMNIST,
)

In [37]:
# Report the LB_D values computed in the previous cell: first the in-distribution
# (MNIST) line, then one line per out-of-distribution dataset.
print_in_dist_values(
    acc_in_LB_D,
    prob_correct_in_LB_D,
    ent_in_LB_D,
    MMC_in_LB_D,
    'MNIST',
    'LB_D',
)

# The `test` string appears verbatim in the printed "[Out-...]" tag.
print_out_dist_values(
    acc_out_FMNIST_LB_D,
    prob_correct_out_FMNIST_LB_D,
    ent_out_FMNIST_LB_D,
    MMC_out_FMNIST_LB_D,
    auroc_out_FMNIST_LB_D,
    'MNIST',
    test='fmnist',
    method='LB_D',
)
print_out_dist_values(
    acc_out_notMNIST_LB_D,
    prob_correct_out_notMNIST_LB_D,
    ent_out_notMNIST_LB_D,
    MMC_out_notMNIST_LB_D,
    auroc_out_notMNIST_LB_D,
    'MNIST',
    test='notMNIST',
    method='LB_D',
)
print_out_dist_values(
    acc_out_KMNIST_LB_D,
    prob_correct_out_KMNIST_LB_D,
    ent_out_KMNIST_LB_D,
    MMC_out_KMNIST_LB_D,
    auroc_out_KMNIST_LB_D,
    'MNIST',
    test='KMNIST',
    method='LB_D',
)

[In, LB_D, MNIST] Accuracy: 0.988; average entropy: 0.437;     MMC: 0.896; Prob @ correct: 0.100
[Out-fmnist, LB_D, MNIST] Accuracy: 0.082; Average entropy: 1.618;    MMC: 0.478; AUROC: 0.952; Prob @ correct: 0.100
[Out-notMNIST, LB_D, MNIST] Accuracy: 0.128; Average entropy: 1.775;    MMC: 0.398; AUROC: 0.958; Prob @ correct: 0.100
[Out-KMNIST, LB_D, MNIST] Accuracy: 0.096; Average entropy: 1.823;    MMC: 0.376; AUROC: 0.972; Prob @ correct: 0.100


# KFAC Laplace Bridge

In [38]:
# Laplace Bridge predictions from the `la_kron` model on each test loader, moved to
# the CPU and converted to NumPy arrays. The generator is consumed in order while
# unpacking, so the loaders are processed in the sequence MNIST, FMNIST, notMNIST,
# KMNIST.
(
    MNIST_test_in_LB_KFAC,
    MNIST_test_out_FMNIST_LB_KFAC,
    MNIST_test_out_notMNIST_LB_KFAC,
    MNIST_test_out_KMNIST_LB_KFAC,
) = (
    predict_LB(la_kron, test_loader, timing=True, device=device).cpu().numpy()
    for test_loader in (
        MNIST_test_loader,
        FMNIST_test_loader,
        notMNIST_test_loader,
        KMNIST_test_loader,
    )
)

step0, f_mu[0]:  tensor([ -3.6094,  -2.8937,   1.1239,   1.7301,  -9.0167,  -3.9072, -17.8319,
         13.2881,  -5.2003,   0.8784], device='cuda:0')
step0, f_var[0]:  tensor([42.6037, 42.3545, 41.9556, 41.9227, 41.8064, 42.0326, 42.4786, 41.8525,
        41.5895, 41.5984], device='cuda:0')
step1, f_mu[0]:  tensor([ -1.0655,  -0.3498,   3.6677,   4.2740,  -6.4728,  -1.3633, -15.2881,
         15.8320,  -2.6564,   3.4222], device='cuda:0')
step1, f_var[0]:  tensor([1.8262, 1.5771, 1.1782, 1.1452, 1.0288, 1.2551, 1.7011, 1.0749, 0.8120,
        0.8208], device='cuda:0')
step0, f_mu[0]:  tensor([ -6.6111, -10.5723,  -3.6269,   0.6844,  -8.7426,  -2.2827,  -8.9403,
         -6.8173,  12.5674,  -2.6553], device='cuda:0')
step0, f_var[0]:  tensor([51.5873, 51.0035, 50.0460, 49.9629, 49.6766, 50.2312, 51.2922, 49.7903,
        49.1484, 49.1603], device='cuda:0')
step1, f_mu[0]:  tensor([-2.9115, -6.8726,  0.0728,  4.3841, -5.0429,  1.4169, -5.2406, -3.1176,
        16.2671,  1.0444], device=

step0, f_mu[0]:  tensor([ -2.4205,  -1.0927,  12.6614,  -3.7460,   1.5083,  -7.0049,  -7.0937,
         -3.2594,  -4.2696, -11.3067], device='cuda:0')
step0, f_var[0]:  tensor([38.0625, 37.5706, 36.7663, 36.6969, 36.4572, 36.9218, 37.8141, 36.5523,
        36.0143, 36.0255], device='cuda:0')
step1, f_mu[0]:  tensor([ 0.1819,  1.5097, 15.2638, -1.1437,  4.1107, -4.4025, -4.4913, -0.6570,
        -1.6672, -8.7043], device='cuda:0')
step1, f_var[0]:  tensor([3.7680, 3.2761, 2.4718, 2.4023, 2.1626, 2.6273, 3.5196, 2.2577, 1.7198,
        1.7308], device='cuda:0')
step0, f_mu[0]:  tensor([ 13.6350, -17.5037,  -3.3616, -10.7703, -11.6570,  -2.0494,   2.5390,
         -5.1557,  -2.8156,   1.4312], device='cuda:0')
step0, f_var[0]:  tensor([82.9309, 82.3296, 81.3485, 81.2641, 80.9722, 81.5381, 82.6274, 81.0881,
        80.4325, 80.4467], device='cuda:0')
step1, f_mu[0]:  tensor([ 17.2058, -13.9329,   0.2092,  -7.1995,  -8.0862,   1.5214,   6.1098,
         -1.5849,   0.7552,   5.0020], device=

step0, f_mu[0]:  tensor([-4.4196, -9.4462, -5.2212,  2.7948, -1.4093, -2.2677, -3.8090, -7.0787,
         7.5015, -0.4067], device='cuda:0')
step0, f_var[0]:  tensor([27.8787, 27.4645, 26.7881, 26.7300, 26.5290, 26.9189, 27.6696, 26.6086,
        26.1570, 26.1672], device='cuda:0')
step1, f_mu[0]:  tensor([-2.0433, -7.0700, -2.8450,  5.1710,  0.9669,  0.1085, -1.4328, -4.7025,
         9.8777,  1.9695], device='cuda:0')
step1, f_var[0]:  tensor([3.1554, 2.7412, 2.0648, 2.0066, 1.8057, 2.1956, 2.9463, 1.8852, 1.4336,
        1.4438], device='cuda:0')
step0, f_mu[0]:  tensor([ -5.7790,  -1.3176,  -1.4149,  -1.0817,   0.9101,  -2.8325, -12.6804,
          4.3914,  -2.3317,   3.2369], device='cuda:0')
step0, f_var[0]:  tensor([20.5517, 20.3218, 19.9492, 19.9177, 19.8080, 20.0213, 20.4359, 19.8513,
        19.6040, 19.6110], device='cuda:0')
step1, f_mu[0]:  tensor([ -3.8891,   0.5723,   0.4751,   0.8082,   2.8000,  -0.9425, -10.7904,
          6.2813,  -0.4417,   5.1268], device='cuda:0')


step0, f_mu[0]:  tensor([  7.4148,  -7.3448,  -0.1742, -11.0964,   5.2517,  -9.3631,   2.6035,
         -5.1806,  -3.8293,  -2.8050], device='cuda:0')
step0, f_var[0]:  tensor([56.1766, 55.5998, 54.6620, 54.5821, 54.3048, 54.8433, 55.8857, 54.4145,
        53.7902, 53.8062], device='cuda:0')
step1, f_mu[0]:  tensor([ 9.8671, -4.8925,  2.2782, -8.6440,  7.7040, -6.9108,  5.0558, -2.7283,
        -1.3770, -0.3526], device='cuda:0')
step1, f_var[0]:  tensor([4.3510, 3.7742, 2.8363, 2.7564, 2.4791, 3.0177, 4.0601, 2.5888, 1.9646,
        1.9804], device='cuda:0')
step0, f_mu[0]:  tensor([ -9.6564,   5.1480,  10.8865,   0.7957,   1.7643, -11.9897,  -7.1100,
          0.2010,  -2.8308, -14.3488], device='cuda:0')
step0, f_var[0]:  tensor([57.8330, 57.0823, 55.8508, 55.7438, 55.3752, 56.0889, 57.4535, 55.5217,
        54.6958, 54.7107], device='cuda:0')
step1, f_mu[0]:  tensor([ -6.9424,   7.8620,  13.6006,   3.5098,   4.4783,  -9.2757,  -4.3960,
          2.9150,  -0.1168, -11.6348], device=

step0, f_var[0]:  tensor([53.9938, 53.3516, 52.3125, 52.2249, 51.9192, 52.5133, 53.6703, 52.0401,
        51.3507, 51.3704], device='cuda:0')
step1, f_mu[0]:  tensor([ -1.9041,  -1.3934,  14.0076,   4.3468,  -4.7439,  -5.7094, -12.8463,
         11.6625,   2.4253,  -5.8452], device='cuda:0')
step1, f_var[0]:  tensor([4.7915, 4.1493, 3.1102, 3.0226, 2.7168, 3.3110, 4.4681, 2.8377, 2.1483,
        2.1679], device='cuda:0')
step0, f_mu[0]:  tensor([ 14.6610, -12.9010,  -7.8308,  -9.8906, -10.5951,   0.0912,   3.4128,
         -8.7345,  -2.7977,  -0.1604], device='cuda:0')
step0, f_var[0]:  tensor([66.3216, 65.8525, 65.0898, 65.0247, 64.7988, 65.2372, 66.0851, 64.8884,
        64.3801, 64.3925], device='cuda:0')
step1, f_mu[0]:  tensor([18.1355, -9.4265, -4.3563, -6.4161, -7.1205,  3.5657,  6.8873, -5.2600,
         0.6768,  3.3141], device='cuda:0')
step1, f_var[0]:  tensor([3.5495, 3.0804, 2.3176, 2.2525, 2.0265, 2.4650, 3.3130, 2.1161, 1.6080,
        1.6202], device='cuda:0')
step0, f_

step1, f_mu[0]:  tensor([ 2.2556, -2.7034, -0.0239, -2.3188,  0.9581, -0.3982,  3.2352, -2.8776,
         3.1386, -1.2656], device='cuda:0')
step1, f_var[0]:  tensor([4.3480, 3.8470, 2.9922, 2.9111, 2.6412, 3.1575, 4.0921, 2.7508, 2.1582,
        2.1486], device='cuda:0')
step0, f_mu[0]:  tensor([ 1.3488, -2.6110,  3.8952, -3.1071, -2.2667, -2.7795,  1.5745, -3.1273,
        -2.7688, -5.9954], device='cuda:0')
step0, f_var[0]:  tensor([27.9184, 27.3817, 26.4982, 26.4208, 26.1550, 26.6691, 27.6469, 26.2609,
        25.6665, 25.6753], device='cuda:0')
step1, f_mu[0]:  tensor([ 2.9325, -1.0273,  5.4790, -1.5234, -0.6830, -1.1958,  3.1582, -1.5435,
        -1.1851, -4.4116], device='cuda:0')
step1, f_var[0]:  tensor([4.1934, 3.6567, 2.7732, 2.6957, 2.4300, 2.9440, 3.9219, 2.5358, 1.9415,
        1.9502], device='cuda:0')
step0, f_mu[0]:  tensor([-2.2860, -3.2419, -0.0406, -3.6440, -0.1208, -1.6802, -0.0559,  0.9579,
         2.9380, -3.3645], device='cuda:0')
step0, f_var[0]:  tensor([17.2

step0, f_mu[0]:  tensor([-2.2622, -0.2706, -2.3813, -3.8307,  0.5721, -1.6168, -1.1459, -1.5402,
         2.1724, -5.2745], device='cuda:0')
step0, f_var[0]:  tensor([15.1957, 14.6251, 13.6841, 13.6017, 13.3197, 13.8663, 14.9069, 13.4313,
        12.7992, 12.8101], device='cuda:0')
step1, f_mu[0]:  tensor([-0.7044,  1.2872, -0.8236, -2.2730,  2.1299, -0.0590,  0.4118,  0.0175,
         3.7302, -3.7167], device='cuda:0')
step1, f_var[0]:  tensor([4.4320, 3.8614, 2.9204, 2.8380, 2.5560, 3.1026, 4.1432, 2.6676, 2.0355,
        2.0464], device='cuda:0')
step0, f_mu[0]:  tensor([ 1.8990, -2.4482, -0.6891, -2.9655, -0.3145, -2.3179,  1.1539, -1.7276,
         1.6950, -1.6661], device='cuda:0')
step0, f_var[0]:  tensor([6.6078, 6.2726, 5.6985, 5.6437, 5.4621, 5.8096, 6.4364, 5.5358, 5.1370,
        5.1303], device='cuda:0')
step1, f_mu[0]:  tensor([ 2.6371, -1.7101,  0.0490, -2.2274,  0.4236, -1.5798,  1.8920, -0.9895,
         2.4331, -0.9280], device='cuda:0')
step1, f_var[0]:  tensor([2.91

step0, f_mu[0]:  tensor([-1.7215, -2.4863, -3.7852, -4.8109,  1.6065, -1.4784,  1.3445, -2.6960,
         1.9836, -3.4805], device='cuda:0')
step0, f_var[0]:  tensor([17.9577, 17.4022, 16.4811, 16.3993, 16.1208, 16.6594, 17.6761, 16.2316,
        15.6095, 15.6168], device='cuda:0')
step1, f_mu[0]:  tensor([-0.1691, -0.9338, -2.2328, -3.2585,  3.1589,  0.0740,  2.8970, -1.1436,
         3.5360, -1.9281], device='cuda:0')
step1, f_var[0]:  tensor([4.3923, 3.8368, 2.9157, 2.8339, 2.5554, 3.0940, 4.1107, 2.6661, 2.0441,
        2.0513], device='cuda:0')
step0, f_mu[0]:  tensor([ 2.7414, -1.9384,  2.2260, -2.6212, -3.6777, -3.8004, -1.5636, -5.6638,
         0.9994, -2.1862], device='cuda:0')
step0, f_var[0]:  tensor([25.0928, 24.3512, 23.1090, 22.9958, 22.6136, 23.3492, 24.7160, 22.7671,
        21.9194, 21.9197], device='cuda:0')
step1, f_mu[0]:  tensor([ 4.2898, -0.3900,  3.7745, -1.0727, -2.1292, -2.2520, -0.0151, -4.1153,
         2.5479, -0.6378], device='cuda:0')
step1, f_var[0]:  te

step0, f_mu[0]:  tensor([ 0.8565, -2.8002, -0.5210, -3.0929, -1.0463, -2.5566,  3.2171, -6.3834,
         3.0237, -2.6628], device='cuda:0')
step0, f_var[0]:  tensor([16.8474, 16.1179, 14.8642, 14.7433, 14.3436, 15.1065, 16.4742, 14.5066,
        13.6320, 13.6126], device='cuda:0')
step1, f_mu[0]:  tensor([ 2.0531, -1.6036,  0.6756, -1.8963,  0.1503, -1.3600,  4.4137, -5.1869,
         4.2203, -1.4662], device='cuda:0')
step1, f_var[0]:  tensor([6.4573, 5.7278, 4.4741, 4.3532, 3.9535, 4.7164, 6.0841, 4.1165, 3.2419,
        3.2225], device='cuda:0')
step0, f_mu[0]:  tensor([-4.8002, -4.6081, -3.0572,  0.9947,  1.6846,  0.4399, -4.2891, -0.9894,
         1.5109,  1.9923], device='cuda:0')
step0, f_var[0]:  tensor([6.9858, 6.7493, 6.3508, 6.3140, 6.1905, 6.4278, 6.8654, 6.2402, 5.9670,
        5.9659], device='cuda:0')
step1, f_mu[0]:  tensor([-3.6880, -3.4959, -1.9451,  2.1069,  2.7968,  1.5521, -3.1769,  0.1228,
         2.6230,  3.1044], device='cuda:0')
step1, f_var[0]:  tensor([1.96

step0, f_mu[0]:  tensor([ 0.3529,  2.2088,  7.0155, -3.2133, -1.8579, -3.9718, -0.7929, -4.7171,
        -2.6040, -8.8867], device='cuda:0')
step0, f_var[0]:  tensor([25.7884, 25.1399, 24.0619, 23.9655, 23.6380, 24.2705, 25.4595, 23.7686,
        23.0385, 23.0448], device='cuda:0')
step1, f_mu[0]:  tensor([ 1.9996,  3.8554,  8.6621, -1.5667, -0.2112, -2.3251,  0.8537, -3.0704,
        -0.9574, -7.2400], device='cuda:0')
step1, f_var[0]:  tensor([5.1737, 4.5253, 3.4472, 3.3508, 3.0233, 3.6558, 4.8448, 3.1539, 2.4238,
        2.4300], device='cuda:0')
step0, f_mu[0]:  tensor([ 0.8690, -2.6540, -0.5564, -2.8267, -0.1812, -3.7669,  2.1301, -4.9279,
         3.1603, -2.3284], device='cuda:0')
step0, f_var[0]:  tensor([14.5532, 13.9384, 12.8855, 12.7849, 12.4513, 13.0891, 14.2390, 12.5868,
        11.8550, 11.8417], device='cuda:0')
step1, f_mu[0]:  tensor([ 1.9772, -1.5458,  0.5518, -1.7185,  0.9270, -2.6587,  3.2383, -3.8197,
         4.2685, -1.2202], device='cuda:0')
step1, f_var[0]:  te

step0, f_mu[0]:  tensor([ 1.7994,  0.7980,  0.6430,  1.5861, -5.1248, -0.5785, -1.2213, -1.7858,
        -2.7957,  0.3945], device='cuda:0')
step0, f_var[0]:  tensor([7.7522, 7.2400, 6.3738, 6.2933, 6.0238, 6.5414, 7.4913, 6.1324, 5.5371,
        5.5331], device='cuda:0')
step1, f_mu[0]:  tensor([ 2.4280,  1.4265,  1.2715,  2.2146, -4.4963,  0.0500, -0.5928, -1.1573,
        -2.1671,  1.0230], device='cuda:0')
step1, f_var[0]:  tensor([4.3088, 3.7965, 2.9304, 2.8499, 2.5804, 3.0979, 4.0478, 2.6890, 2.0937,
        2.0897], device='cuda:0')
step0, f_mu[0]:  tensor([  6.3065,  -2.4800,   1.9716,   1.5610, -16.4387,   0.3798,  -7.5889,
         -2.3173,  -6.8252,  -1.2498], device='cuda:0')
step0, f_var[0]:  tensor([38.3775, 36.9833, 34.6227, 34.4019, 33.6634, 35.0790, 37.6670, 33.9624,
        32.3352, 32.3185], device='cuda:0')
step1, f_mu[0]:  tensor([  8.9746,   0.1881,   4.6397,   4.2291, -13.7706,   3.0479,  -4.9208,
          0.3508,  -4.1571,   1.4183], device='cuda:0')
step1, f_v

step1, f_var[0]:  tensor([14.6233, 12.7674,  9.7005,  9.4295,  8.5033, 10.2936, 13.6835,  8.8724,
         6.8038,  6.8289], device='cuda:0')
step0, f_mu[0]:  tensor([-4.0005, -9.5164, -6.1010,  0.7874, -6.6176,  1.8897, -6.7852,  0.0236,
         0.5745, -1.3376], device='cuda:0')
step0, f_var[0]:  tensor([74.0684, 71.6090, 67.4377, 67.0478, 65.7465, 68.2448, 72.8145, 66.2708,
        63.3981, 63.3746], device='cuda:0')
step1, f_mu[0]:  tensor([-0.8922, -6.4081, -2.9927,  3.8957, -3.5093,  4.9980, -3.6769,  3.1319,
         3.6828,  1.7707], device='cuda:0')
step1, f_var[0]:  tensor([20.7671, 18.3077, 14.1364, 13.7464, 12.4451, 14.9434, 19.5132, 12.9694,
        10.0967, 10.0731], device='cuda:0')
step0, f_mu[0]:  tensor([-3.2344, -5.6021, -8.0270, -6.3837,  5.4186, -2.4644,  0.6813, -1.3214,
         0.6719, -4.7977], device='cuda:0')
step0, f_var[0]:  tensor([57.3146, 55.2484, 51.7728, 51.4526, 50.3759, 52.4447, 56.2635, 50.8099,
        48.4283, 48.4187], device='cuda:0')
step1, f_

step0, f_mu[0]:  tensor([-1.5284, -6.6120,  0.3215, -3.6891, -4.3560,  0.4351,  4.8094, -0.9859,
         1.1104, -5.0049], device='cuda:0')
step0, f_var[0]:  tensor([37.1595, 35.5047, 32.7223, 32.4668, 31.6071, 33.2605, 36.3178, 31.9527,
        30.0488, 30.0444], device='cuda:0')
step1, f_mu[0]:  tensor([ 0.0216, -5.0621,  1.8715, -2.1391, -2.8060,  1.9851,  6.3594,  0.5641,
         2.6604, -3.4549], device='cuda:0')
step1, f_var[0]:  tensor([13.6962, 12.0413,  9.2590,  9.0034,  8.1437,  9.7971, 12.8544,  8.4893,
         6.5855,  6.5810], device='cuda:0')
step0, f_mu[0]:  tensor([  1.5498,  -6.3419,  -4.8950, -13.9070,   6.1996,  -1.3368,   6.5363,
         -0.2737,   1.6118,  -6.5229], device='cuda:0')
step0, f_var[0]:  tensor([29.6266, 28.6238, 26.9373, 26.7822, 26.2606, 27.2635, 29.1165, 26.4705,
        25.3159, 25.3125], device='cuda:0')
step1, f_mu[0]:  tensor([  3.2878,  -4.6039,  -3.1570, -12.1690,   7.9376,   0.4012,   8.2743,
          1.4643,   3.3498,  -4.7849], device=

step0, f_mu[0]:  tensor([ 2.1912,  0.7637, -0.4709, -1.6132, -8.7094,  2.4929,  2.4084, -8.8440,
        -2.1088, -3.5158], device='cuda:0')
step0, f_var[0]:  tensor([34.7377, 32.9783, 30.0116, 29.7374, 28.8170, 30.5854, 33.8421, 29.1876,
        27.1525, 27.1426], device='cuda:0')
step1, f_mu[0]:  tensor([ 3.9318,  2.5043,  1.2697,  0.1274, -6.9688,  4.2335,  4.1490, -7.1034,
        -0.3682, -1.7753], device='cuda:0')
step1, f_var[0]:  tensor([14.6868, 12.9274,  9.9607,  9.6864,  8.7660, 10.5345, 13.7912,  9.1367,
         7.1016,  7.0917], device='cuda:0')
step0, f_mu[0]:  tensor([  2.8919,  -1.4732,  -0.2460,   0.1849,  -9.0092,   4.3906,   1.1304,
        -15.8447,  -3.4940,  -0.9726], device='cuda:0')
step0, f_var[0]:  tensor([41.6607, 39.7943, 36.5884, 36.2793, 35.2586, 37.2080, 40.7060, 35.6744,
        33.4389, 33.3916], device='cuda:0')
step1, f_mu[0]:  tensor([  5.1361,   0.7709,   1.9982,   2.4291,  -6.7650,   6.6348,   3.3746,
        -13.6005,  -1.2498,   1.2716], device=

step0, f_mu[0]:  tensor([ 2.2055, -6.8008, -0.2591, -7.9921,  2.7938, -4.2549,  2.8480, -6.3975,
        -0.1752, -2.2329], device='cuda:0')
step0, f_var[0]:  tensor([35.4387, 34.3352, 32.4596, 32.2833, 31.6960, 32.8224, 34.8758, 31.9332,
        30.6386, 30.6248], device='cuda:0')
step1, f_mu[0]:  tensor([ 4.2320, -4.7742,  1.7675, -5.9656,  4.8203, -2.2284,  4.8745, -4.3710,
         1.8514, -0.2064], device='cuda:0')
step1, f_var[0]:  tensor([9.3867, 8.2832, 6.4076, 6.2313, 5.6440, 6.7704, 8.8238, 5.8812, 4.5865,
        4.5727], device='cuda:0')
step0, f_mu[0]:  tensor([-11.6415,  -3.3189,  -0.4398,  -5.9514,   9.5302,  -3.5312, -10.5293,
         -2.9876,  -2.6965,  -0.4369], device='cuda:0')
step0, f_var[0]:  tensor([48.6324, 47.5598, 45.7942, 45.6389, 45.1062, 46.1354, 48.0898, 45.3192,
        44.1296, 44.1446], device='cuda:0')
step1, f_mu[0]:  tensor([-8.4412, -0.1186,  2.7605, -2.7511, 12.7305, -0.3309, -7.3290,  0.2127,
         0.5038,  2.7634], device='cuda:0')
step1, f_v

step0, f_mu[0]:  tensor([-2.1007, -3.0391, -5.1966, -3.5221,  2.4269, -1.8375,  0.9011, -1.8050,
         1.7834, -2.0674], device='cuda:0')
step0, f_var[0]:  tensor([20.6266, 19.7427, 18.2588, 18.1235, 17.6679, 18.5461, 20.1771, 17.8500,
        16.8379, 16.8401], device='cuda:0')
step1, f_mu[0]:  tensor([-0.6550, -1.5934, -3.7509, -2.0764,  3.8726, -0.3918,  2.3468, -0.3593,
         3.2291, -0.6217], device='cuda:0')
step1, f_var[0]:  tensor([7.2193, 6.3355, 4.8515, 4.7162, 4.2606, 5.1388, 6.7699, 4.4427, 3.4307,
        3.4328], device='cuda:0')
step0, f_mu[0]:  tensor([-0.1632, -1.5174,  0.3894,  0.2020, -0.0903, -0.4970,  0.2381, -2.1485,
         1.7345, -0.4975], device='cuda:0')
step0, f_var[0]:  tensor([5.2628, 4.9009, 4.2713, 4.2093, 4.0065, 4.3931, 5.0771, 4.0892, 3.6467,
        3.6342], device='cuda:0')
step1, f_mu[0]:  tensor([ 0.0718, -1.2824,  0.6244,  0.4370,  0.1447, -0.2620,  0.4731, -1.9135,
         1.9695, -0.2626], device='cuda:0')
step1, f_var[0]:  tensor([3.27

step0, f_mu[0]:  tensor([ -0.9127,  -5.1438,  -4.0114,   0.8094, -10.7545,   3.7146,   4.3282,
         -8.8630,  -5.7831,  -2.4193], device='cuda:0')
step0, f_var[0]:  tensor([57.1777, 54.9397, 51.1541, 50.8012, 49.6203, 51.8860, 56.0375, 50.0973,
        47.4918, 47.4701], device='cuda:0')
step1, f_mu[0]:  tensor([ 1.9909, -2.2403, -1.1078,  3.7129, -7.8509,  6.6181,  7.2318, -5.9594,
        -2.8796,  0.4843], device='cuda:0')
step1, f_var[0]:  tensor([18.8797, 16.6417, 12.8560, 12.5031, 11.3223, 13.5879, 17.7395, 11.7993,
         9.1938,  9.1720], device='cuda:0')
step0, f_mu[0]:  tensor([ 2.5312, -1.4428, -0.0719, -5.4600, -9.3628, -2.2836,  4.5146, -4.8090,
         0.7674, -9.9157], device='cuda:0')
step0, f_var[0]:  tensor([42.9360, 41.7962, 39.9046, 39.7357, 39.1609, 40.2704, 42.3581, 39.3906,
        38.1097, 38.1202], device='cuda:0')
step1, f_mu[0]:  tensor([ 5.0844,  1.1105,  2.4813, -2.9068, -6.8096,  0.2697,  7.0678, -2.2557,
         3.3206, -7.3624], device='cuda:0')


step0, f_mu[0]:  tensor([-11.8140,  -8.3711,  -6.6337,   3.5255,   1.7556,   1.8726, -12.0009,
         -1.9994,  -0.8098,   4.3781], device='cuda:0')
step0, f_var[0]:  tensor([40.7861, 39.6799, 37.8391, 37.6740, 37.1131, 38.1952, 40.2249, 37.3373,
        36.0887, 36.0969], device='cuda:0')
step1, f_mu[0]:  tensor([-8.8042, -5.3614, -3.6240,  6.5352,  4.7653,  4.8823, -8.9912,  1.0103,
         2.1999,  7.3878], device='cuda:0')
step1, f_var[0]:  tensor([8.8925, 7.7863, 5.9455, 5.7804, 5.2194, 6.3016, 8.3313, 5.4437, 4.1951,
        4.2032], device='cuda:0')
step0, f_mu[0]:  tensor([-4.8893, -4.2466, -7.0625,  2.5644, -2.3544, -0.6096, -2.6097, -4.0681,
         1.0175, -2.4631], device='cuda:0')
step0, f_var[0]:  tensor([29.8911, 29.0325, 27.6090, 27.4827, 27.0524, 27.8845, 29.4559, 27.2235,
        26.2623, 26.2737], device='cuda:0')
step1, f_mu[0]:  tensor([-2.4171, -1.7745, -4.5904,  5.0366,  0.1178,  1.8625, -0.1375, -1.5959,
         3.4896,  0.0090], device='cuda:0')
step1, f_v

step0, f_mu[0]:  tensor([-0.8280, -2.6134,  0.2171, -1.5644, -6.4498, -2.6862, -3.1101, -4.9353,
         5.0877, -3.7339], device='cuda:0')
step0, f_var[0]:  tensor([26.8021, 25.9732, 24.5766, 24.4479, 24.0158, 24.8468, 26.3802, 24.1894,
        23.2327, 23.2298], device='cuda:0')
step1, f_mu[0]:  tensor([ 1.2336, -0.5518,  2.2787,  0.4972, -4.3882, -0.6246, -1.0485, -2.8736,
         7.1493, -1.6723], device='cuda:0')
step1, f_var[0]:  tensor([6.8761, 6.0472, 4.6506, 4.5219, 4.0897, 4.9208, 6.4542, 4.2634, 3.3066,
        3.3037], device='cuda:0')
step0, f_mu[0]:  tensor([-3.8684, -3.9606,  2.8475, -4.1023, -0.4212, -5.2620, -3.9019,  2.4530,
        -3.8046, -1.9386], device='cuda:0')
step0, f_var[0]:  tensor([25.8541, 24.8929, 23.2973, 23.1545, 22.6686, 23.6057, 25.3668, 22.8632,
        21.7817, 21.7889], device='cuda:0')
step1, f_mu[0]:  tensor([-1.6725, -1.7647,  5.0434, -1.9063,  1.7748, -3.0661, -1.7060,  4.6489,
        -1.6087,  0.2574], device='cuda:0')
step1, f_var[0]:  te

step0, f_mu[0]:  tensor([-0.8200, -6.8518, -0.8141, -2.0715, -7.9859, -0.8577,  0.5524, -2.9372,
         1.5733, -3.4869], device='cuda:0')
step0, f_var[0]:  tensor([53.6581, 51.6194, 48.2193, 47.9124, 46.8724, 48.8768, 52.6232, 47.2890,
        44.9774, 44.9870], device='cuda:0')
step1, f_mu[0]:  tensor([ 1.5499, -4.4819,  1.5558,  0.2985, -5.6160,  1.5122,  2.9223, -0.5672,
         3.9432, -1.1169], device='cuda:0')
step1, f_var[0]:  tensor([16.5131, 14.4745, 11.0743, 10.7674,  9.7274, 11.7319, 15.4783, 10.1440,
         7.8325,  7.8419], device='cuda:0')
step0, f_mu[0]:  tensor([ -6.0668,  -2.5466,   1.1892,  -3.9820,  -3.0327,  -4.5590, -11.1181,
          3.4039,   3.6676, -11.0982], device='cuda:0')
step0, f_var[0]:  tensor([57.8597, 56.3075, 53.7204, 53.4860, 52.6910, 54.2201, 57.0720, 53.0114,
        51.2487, 51.2503], device='cuda:0')
step1, f_mu[0]:  tensor([-2.6526,  0.8676,  4.6035, -0.5677,  0.3815, -1.1447, -7.7038,  6.8181,
         7.0818, -7.6839], device='cuda:0')


step0, f_mu[0]:  tensor([ -1.3499,  -4.4015,  -1.7445,  -2.9443,  -8.1042,   4.3698, -11.2996,
          7.4982,  -1.6195,   0.2827], device='cuda:0')
step0, f_var[0]:  tensor([48.3675, 47.2552, 45.4164, 45.2533, 44.6960, 45.7718, 47.8042, 44.9188,
        43.6763, 43.6885], device='cuda:0')
step1, f_mu[0]:  tensor([ 0.5814, -2.4702,  0.1868, -1.0130, -6.1729,  6.3011, -9.3683,  9.4295,
         0.3118,  2.2140], device='cuda:0')
step1, f_var[0]:  tensor([8.8254, 7.7132, 5.8743, 5.7112, 5.1539, 6.2297, 8.2621, 5.3767, 4.1342,
        4.1463], device='cuda:0')
step0, f_mu[0]:  tensor([ -1.8975,   7.5904,   2.9152,  -6.5281,  -0.9744, -11.7696,  -7.5872,
         -1.0047,  -1.9743,  -6.6970], device='cuda:0')
step0, f_var[0]:  tensor([52.0578, 50.6201, 48.1625, 47.9270, 47.1449, 48.6370, 51.3234, 47.4648,
        45.7536, 45.7169], device='cuda:0')
step1, f_mu[0]:  tensor([ 0.8952, 10.3831,  5.7079, -3.7354,  1.8183, -8.9768, -4.7945,  1.7880,
         0.8184, -3.9042], device='cuda:0')


step0, f_mu[0]:  tensor([-5.0428, -8.7250, -3.2601,  1.9312, -1.3528, -6.1876, -4.3860, -1.6381,
         1.9026,  0.2402], device='cuda:0')
step0, f_var[0]:  tensor([34.7211, 33.7593, 32.1662, 32.0245, 31.5417, 32.4743, 34.2337, 31.7344,
        30.6574, 30.6681], device='cuda:0')
step1, f_mu[0]:  tensor([-2.3909, -6.0732, -0.6082,  4.5831,  1.2990, -3.5358, -1.7342,  1.0137,
         4.5545,  2.8921], device='cuda:0')
step1, f_var[0]:  tensor([7.6362, 6.6744, 5.0812, 4.9396, 4.4567, 5.3893, 7.1488, 4.6495, 3.5725,
        3.5831], device='cuda:0')
step0, f_mu[0]:  tensor([ -4.0572,  -3.7610,   1.3121,   1.9487,   0.1938,  -2.5542, -10.3887,
         -2.1740,  -5.6533,  -2.4454], device='cuda:0')
step0, f_var[0]:  tensor([29.9759, 29.1005, 27.6565, 27.5296, 27.0951, 27.9359, 29.5328, 27.2679,
        26.2960, 26.3101], device='cuda:0')
step1, f_mu[0]:  tensor([-1.2993, -1.0031,  4.0700,  4.7066,  2.9517,  0.2037, -7.6308,  0.5839,
        -2.8954,  0.3125], device='cuda:0')
step1, f_v

In [39]:
# Average log-likelihood of the true labels under the LB KFAC predictive
# probabilities, for the in-distribution set and each out-of-distribution set.
def _mean_log_likelihood_LB_KFAC(pred_probs, true_labels):
    """Mean Categorical log-probability of `true_labels` under `pred_probs`."""
    predictive = torch.distributions.Categorical(torch.tensor(pred_probs))
    return predictive.log_prob(torch.tensor(true_labels)).mean()

print(_mean_log_likelihood_LB_KFAC(MNIST_test_in_LB_KFAC, targets))
print(_mean_log_likelihood_LB_KFAC(MNIST_test_out_FMNIST_LB_KFAC, targets_FMNIST))
print(_mean_log_likelihood_LB_KFAC(MNIST_test_out_notMNIST_LB_KFAC, targets_notMNIST))
print(_mean_log_likelihood_LB_KFAC(MNIST_test_out_KMNIST_LB_KFAC, targets_KMNIST))

tensor(-0.0472)
tensor(-2.7533)
tensor(-2.7631)
tensor(-3.0229)


In [40]:
# Expected calibration error of the LB KFAC predictions on each dataset,
# printed in the order MNIST, FMNIST, notMNIST, KMNIST.
for _ece_labels, _ece_probs in (
    (targets, MNIST_test_in_LB_KFAC),
    (targets_FMNIST, MNIST_test_out_FMNIST_LB_KFAC),
    (targets_notMNIST, MNIST_test_out_notMNIST_LB_KFAC),
    (targets_KMNIST, MNIST_test_out_KMNIST_LB_KFAC),
):
    print(scoring.expected_calibration_error(_ece_labels, _ece_probs))

0.021417212121212118
0.3247582828282829
0.2768516666342986
0.2984889696969697


In [41]:
# In-distribution summary values for LB KFAC on MNIST.
(acc_in_LB_KFAC,
 prob_correct_in_LB_KFAC,
 ent_in_LB_KFAC,
 MMC_in_LB_KFAC) = get_in_dist_values(MNIST_test_in_LB_KFAC, targets)

# Out-of-distribution summary values; each call also receives the
# in-distribution predictions as its first argument.
(acc_out_FMNIST_LB_KFAC,
 prob_correct_out_FMNIST_LB_KFAC,
 ent_out_FMNIST_LB_KFAC,
 MMC_out_FMNIST_LB_KFAC,
 auroc_out_FMNIST_LB_KFAC) = get_out_dist_values(
    MNIST_test_in_LB_KFAC, MNIST_test_out_FMNIST_LB_KFAC, targets_FMNIST)

(acc_out_notMNIST_LB_KFAC,
 prob_correct_out_notMNIST_LB_KFAC,
 ent_out_notMNIST_LB_KFAC,
 MMC_out_notMNIST_LB_KFAC,
 auroc_out_notMNIST_LB_KFAC) = get_out_dist_values(
    MNIST_test_in_LB_KFAC, MNIST_test_out_notMNIST_LB_KFAC, targets_notMNIST)

(acc_out_KMNIST_LB_KFAC,
 prob_correct_out_KMNIST_LB_KFAC,
 ent_out_KMNIST_LB_KFAC,
 MMC_out_KMNIST_LB_KFAC,
 auroc_out_KMNIST_LB_KFAC) = get_out_dist_values(
    MNIST_test_in_LB_KFAC, MNIST_test_out_KMNIST_LB_KFAC, targets_KMNIST)

In [42]:
# Report the LB KFAC summary values: in-distribution first, then one line per
# out-of-distribution dataset.
print_in_dist_values(
    acc_in_LB_KFAC, prob_correct_in_LB_KFAC, ent_in_LB_KFAC, MMC_in_LB_KFAC,
    'MNIST', 'LB_KFAC')

for _ood_name, _ood_values in (
    ('fmnist', (acc_out_FMNIST_LB_KFAC, prob_correct_out_FMNIST_LB_KFAC, ent_out_FMNIST_LB_KFAC,
                MMC_out_FMNIST_LB_KFAC, auroc_out_FMNIST_LB_KFAC)),
    ('notMNIST', (acc_out_notMNIST_LB_KFAC, prob_correct_out_notMNIST_LB_KFAC, ent_out_notMNIST_LB_KFAC,
                  MMC_out_notMNIST_LB_KFAC, auroc_out_notMNIST_LB_KFAC)),
    ('KMNIST', (acc_out_KMNIST_LB_KFAC, prob_correct_out_KMNIST_LB_KFAC, ent_out_KMNIST_LB_KFAC,
                MMC_out_KMNIST_LB_KFAC, auroc_out_KMNIST_LB_KFAC)),
):
    print_out_dist_values(*_ood_values, 'MNIST', test=_ood_name, method='LB_KFAC')

[In, LB_KFAC, MNIST] Accuracy: 0.988; average entropy: 0.109;     MMC: 0.970; Prob @ correct: 0.100
[Out-fmnist, LB_KFAC, MNIST] Accuracy: 0.081; Average entropy: 1.840;    MMC: 0.401; AUROC: 0.991; Prob @ correct: 0.100
[Out-notMNIST, LB_KFAC, MNIST] Accuracy: 0.140; Average entropy: 1.793;    MMC: 0.390; AUROC: 0.986; Prob @ correct: 0.100
[Out-KMNIST, LB_KFAC, MNIST] Accuracy: 0.088; Average entropy: 1.794;    MMC: 0.386; AUROC: 0.991; Prob @ correct: 0.100


In [43]:
# NOTE(review): `break` outside a loop is a SyntaxError (see this cell's output).
# Presumably it is used on purpose to stop "Run All" before the comparison
# sections below - confirm, and remove or replace it before sharing the notebook.
break

SyntaxError: 'break' outside loop (668683560.py, line 1)

# Compare to extended MacKay approach

In [None]:
def _predict_emk_diag(loader):
    """Extended-MacKay predictions of `la_diag` on `loader`, as a NumPy array."""
    predictions = predict_extended_MacKay(la_diag, loader, timing=True, device=device)
    return predictions.cpu().numpy()

# In-distribution (MNIST) first, then the three out-of-distribution sets.
MNIST_test_in_EMK_D = _predict_emk_diag(MNIST_test_loader)
MNIST_test_out_FMNIST_EMK_D = _predict_emk_diag(FMNIST_test_loader)
MNIST_test_out_notMNIST_EMK_D = _predict_emk_diag(notMNIST_test_loader)
MNIST_test_out_KMNIST_EMK_D = _predict_emk_diag(KMNIST_test_loader)

In [None]:
# In-distribution summary values for the diagonal extended-MacKay predictions.
(acc_in_EMK,
 prob_correct_in_EMK,
 ent_in_EMK,
 MMC_in_EMK) = get_in_dist_values(MNIST_test_in_EMK_D, targets)

# Out-of-distribution summary values (FMNIST, notMNIST, KMNIST).
(acc_out_FMNIST_EMK,
 prob_correct_out_FMNIST_EMK,
 ent_out_FMNIST_EMK,
 MMC_out_FMNIST_EMK,
 auroc_out_FMNIST_EMK) = get_out_dist_values(
    MNIST_test_in_EMK_D, MNIST_test_out_FMNIST_EMK_D, targets_FMNIST)

(acc_out_notMNIST_EMK,
 prob_correct_out_notMNIST_EMK,
 ent_out_notMNIST_EMK,
 MMC_out_notMNIST_EMK,
 auroc_out_notMNIST_EMK) = get_out_dist_values(
    MNIST_test_in_EMK_D, MNIST_test_out_notMNIST_EMK_D, targets_notMNIST)

(acc_out_KMNIST_EMK,
 prob_correct_out_KMNIST_EMK,
 ent_out_KMNIST_EMK,
 MMC_out_KMNIST_EMK,
 auroc_out_KMNIST_EMK) = get_out_dist_values(
    MNIST_test_in_EMK_D, MNIST_test_out_KMNIST_EMK_D, targets_KMNIST)

In [None]:
# Report the extended-MacKay summary values currently held in the *_EMK names.
print_in_dist_values(
    acc_in_EMK, prob_correct_in_EMK, ent_in_EMK, MMC_in_EMK,
    'MNIST', 'EMK')

for _ood_name, _ood_values in (
    ('fmnist', (acc_out_FMNIST_EMK, prob_correct_out_FMNIST_EMK, ent_out_FMNIST_EMK,
                MMC_out_FMNIST_EMK, auroc_out_FMNIST_EMK)),
    ('notMNIST', (acc_out_notMNIST_EMK, prob_correct_out_notMNIST_EMK, ent_out_notMNIST_EMK,
                  MMC_out_notMNIST_EMK, auroc_out_notMNIST_EMK)),
    ('KMNIST', (acc_out_KMNIST_EMK, prob_correct_out_KMNIST_EMK, ent_out_KMNIST_EMK,
                MMC_out_KMNIST_EMK, auroc_out_KMNIST_EMK)),
):
    print_out_dist_values(*_ood_values, 'MNIST', test=_ood_name, method='EMK')

### EMK KFAC

In [None]:
def _predict_emk_kron(loader):
    """Extended-MacKay predictions of `la_kron` on `loader`, as a NumPy array."""
    predictions = predict_extended_MacKay(la_kron, loader, timing=True, device=device)
    return predictions.cpu().numpy()

# In-distribution (MNIST) first, then the three out-of-distribution sets.
MNIST_test_in_EMK_K = _predict_emk_kron(MNIST_test_loader)
MNIST_test_out_FMNIST_EMK_K = _predict_emk_kron(FMNIST_test_loader)
MNIST_test_out_notMNIST_EMK_K = _predict_emk_kron(notMNIST_test_loader)
MNIST_test_out_KMNIST_EMK_K = _predict_emk_kron(KMNIST_test_loader)

In [None]:
# In-distribution summary values for the KFAC extended-MacKay predictions.
# These assignments reuse (and overwrite) the *_EMK names.
(acc_in_EMK,
 prob_correct_in_EMK,
 ent_in_EMK,
 MMC_in_EMK) = get_in_dist_values(MNIST_test_in_EMK_K, targets)

# Out-of-distribution summary values (FMNIST, notMNIST, KMNIST).
(acc_out_FMNIST_EMK,
 prob_correct_out_FMNIST_EMK,
 ent_out_FMNIST_EMK,
 MMC_out_FMNIST_EMK,
 auroc_out_FMNIST_EMK) = get_out_dist_values(
    MNIST_test_in_EMK_K, MNIST_test_out_FMNIST_EMK_K, targets_FMNIST)

(acc_out_notMNIST_EMK,
 prob_correct_out_notMNIST_EMK,
 ent_out_notMNIST_EMK,
 MMC_out_notMNIST_EMK,
 auroc_out_notMNIST_EMK) = get_out_dist_values(
    MNIST_test_in_EMK_K, MNIST_test_out_notMNIST_EMK_K, targets_notMNIST)

(acc_out_KMNIST_EMK,
 prob_correct_out_KMNIST_EMK,
 ent_out_KMNIST_EMK,
 MMC_out_KMNIST_EMK,
 auroc_out_KMNIST_EMK) = get_out_dist_values(
    MNIST_test_in_EMK_K, MNIST_test_out_KMNIST_EMK_K, targets_KMNIST)

In [None]:
# Report the KFAC extended-MacKay summary values. The method label is
# 'EMK_KFAC' (mirroring 'LB_KFAC') so this printout can be told apart from the
# diagonal EMK printout, which uses the plain 'EMK' label.
print_in_dist_values(acc_in_EMK, prob_correct_in_EMK, ent_in_EMK, MMC_in_EMK, 'MNIST', 'EMK_KFAC')
print_out_dist_values(acc_out_FMNIST_EMK, prob_correct_out_FMNIST_EMK, ent_out_FMNIST_EMK, MMC_out_FMNIST_EMK, auroc_out_FMNIST_EMK, 'MNIST', test='fmnist', method='EMK_KFAC')
print_out_dist_values(acc_out_notMNIST_EMK, prob_correct_out_notMNIST_EMK, ent_out_notMNIST_EMK, MMC_out_notMNIST_EMK, auroc_out_notMNIST_EMK, 'MNIST', test='notMNIST', method='EMK_KFAC')
print_out_dist_values(acc_out_KMNIST_EMK, prob_correct_out_KMNIST_EMK, ent_out_KMNIST_EMK, MMC_out_KMNIST_EMK, auroc_out_KMNIST_EMK, 'MNIST', test='KMNIST', method='EMK_KFAC')

# Compare to Second-order Delta Posterior Predictive

In [None]:
def _predict_sodpp_diag(loader):
    """Second-order delta posterior predictive of `la_diag` on `loader`, as a NumPy array."""
    predictions = predict_second_order_dpp(la_diag, loader, timing=True, device=device)
    return predictions.cpu().numpy()

# In-distribution (MNIST) first, then the three out-of-distribution sets.
MNIST_test_in_SODPP_D = _predict_sodpp_diag(MNIST_test_loader)
MNIST_test_out_FMNIST_SODPP_D = _predict_sodpp_diag(FMNIST_test_loader)
MNIST_test_out_notMNIST_SODPP_D = _predict_sodpp_diag(notMNIST_test_loader)
MNIST_test_out_KMNIST_SODPP_D = _predict_sodpp_diag(KMNIST_test_loader)

In [None]:
# In-distribution summary values for the diagonal SODPP predictions.
(acc_in_SODPP,
 prob_correct_in_SODPP,
 ent_in_SODPP,
 MMC_in_SODPP) = get_in_dist_values(MNIST_test_in_SODPP_D, targets)

# Out-of-distribution summary values (FMNIST, notMNIST, KMNIST).
(acc_out_FMNIST_SODPP,
 prob_correct_out_FMNIST_SODPP,
 ent_out_FMNIST_SODPP,
 MMC_out_FMNIST_SODPP,
 auroc_out_FMNIST_SODPP) = get_out_dist_values(
    MNIST_test_in_SODPP_D, MNIST_test_out_FMNIST_SODPP_D, targets_FMNIST)

(acc_out_notMNIST_SODPP,
 prob_correct_out_notMNIST_SODPP,
 ent_out_notMNIST_SODPP,
 MMC_out_notMNIST_SODPP,
 auroc_out_notMNIST_SODPP) = get_out_dist_values(
    MNIST_test_in_SODPP_D, MNIST_test_out_notMNIST_SODPP_D, targets_notMNIST)

(acc_out_KMNIST_SODPP,
 prob_correct_out_KMNIST_SODPP,
 ent_out_KMNIST_SODPP,
 MMC_out_KMNIST_SODPP,
 auroc_out_KMNIST_SODPP) = get_out_dist_values(
    MNIST_test_in_SODPP_D, MNIST_test_out_KMNIST_SODPP_D, targets_KMNIST)

In [None]:
# Report the SODPP summary values currently held in the *_SODPP names.
print_in_dist_values(
    acc_in_SODPP, prob_correct_in_SODPP, ent_in_SODPP, MMC_in_SODPP,
    'MNIST', 'SODPP')

for _ood_name, _ood_values in (
    ('fmnist', (acc_out_FMNIST_SODPP, prob_correct_out_FMNIST_SODPP, ent_out_FMNIST_SODPP,
                MMC_out_FMNIST_SODPP, auroc_out_FMNIST_SODPP)),
    ('notMNIST', (acc_out_notMNIST_SODPP, prob_correct_out_notMNIST_SODPP, ent_out_notMNIST_SODPP,
                  MMC_out_notMNIST_SODPP, auroc_out_notMNIST_SODPP)),
    ('KMNIST', (acc_out_KMNIST_SODPP, prob_correct_out_KMNIST_SODPP, ent_out_KMNIST_SODPP,
                MMC_out_KMNIST_SODPP, auroc_out_KMNIST_SODPP)),
):
    print_out_dist_values(*_ood_values, 'MNIST', test=_ood_name, method='SODPP')

### SODPP KFAC

In [None]:
def _predict_sodpp_kron(loader):
    """Second-order delta posterior predictive of `la_kron` on `loader`, as a NumPy array."""
    predictions = predict_second_order_dpp(la_kron, loader, timing=True, device=device)
    return predictions.cpu().numpy()

# In-distribution (MNIST) first, then the three out-of-distribution sets.
MNIST_test_in_SODPP_K = _predict_sodpp_kron(MNIST_test_loader)
MNIST_test_out_FMNIST_SODPP_K = _predict_sodpp_kron(FMNIST_test_loader)
MNIST_test_out_notMNIST_SODPP_K = _predict_sodpp_kron(notMNIST_test_loader)
MNIST_test_out_KMNIST_SODPP_K = _predict_sodpp_kron(KMNIST_test_loader)

In [None]:
# In-distribution summary values for the KFAC SODPP predictions.
# These assignments reuse (and overwrite) the *_SODPP names.
(acc_in_SODPP,
 prob_correct_in_SODPP,
 ent_in_SODPP,
 MMC_in_SODPP) = get_in_dist_values(MNIST_test_in_SODPP_K, targets)

# Out-of-distribution summary values (FMNIST, notMNIST, KMNIST).
(acc_out_FMNIST_SODPP,
 prob_correct_out_FMNIST_SODPP,
 ent_out_FMNIST_SODPP,
 MMC_out_FMNIST_SODPP,
 auroc_out_FMNIST_SODPP) = get_out_dist_values(
    MNIST_test_in_SODPP_K, MNIST_test_out_FMNIST_SODPP_K, targets_FMNIST)

(acc_out_notMNIST_SODPP,
 prob_correct_out_notMNIST_SODPP,
 ent_out_notMNIST_SODPP,
 MMC_out_notMNIST_SODPP,
 auroc_out_notMNIST_SODPP) = get_out_dist_values(
    MNIST_test_in_SODPP_K, MNIST_test_out_notMNIST_SODPP_K, targets_notMNIST)

(acc_out_KMNIST_SODPP,
 prob_correct_out_KMNIST_SODPP,
 ent_out_KMNIST_SODPP,
 MMC_out_KMNIST_SODPP,
 auroc_out_KMNIST_SODPP) = get_out_dist_values(
    MNIST_test_in_SODPP_K, MNIST_test_out_KMNIST_SODPP_K, targets_KMNIST)

In [None]:
# Report the KFAC SODPP summary values. The method label is 'SODPP_KFAC'
# (mirroring 'LB_KFAC') so this printout can be told apart from the diagonal
# SODPP printout, which uses the plain 'SODPP' label.
print_in_dist_values(acc_in_SODPP, prob_correct_in_SODPP, ent_in_SODPP, MMC_in_SODPP, 'MNIST', 'SODPP_KFAC')
print_out_dist_values(acc_out_FMNIST_SODPP, prob_correct_out_FMNIST_SODPP, ent_out_FMNIST_SODPP, MMC_out_FMNIST_SODPP, auroc_out_FMNIST_SODPP, 'MNIST', test='fmnist', method='SODPP_KFAC')
print_out_dist_values(acc_out_notMNIST_SODPP, prob_correct_out_notMNIST_SODPP, ent_out_notMNIST_SODPP, MMC_out_notMNIST_SODPP, auroc_out_notMNIST_SODPP, 'MNIST', test='notMNIST', method='SODPP_KFAC')
print_out_dist_values(acc_out_KMNIST_SODPP, prob_correct_out_KMNIST_SODPP, ent_out_KMNIST_SODPP, MMC_out_KMNIST_SODPP, auroc_out_KMNIST_SODPP, 'MNIST', test='KMNIST', method='SODPP_KFAC')

In [None]:
# NOTE(review): `break` is only valid inside a loop, so this cell raises a
# SyntaxError. Presumably it is a deliberate stop marker that keeps "Run All"
# from reaching the rotated-MNIST experiments - confirm before removing.
break

# Experiments on Rotated MNIST

In [None]:
import torchvision.transforms as transforms
def make_rotated_mnist_test_loader(max_degrees):
    """Build the MNIST test set (and its loader) with a random rotation applied.

    Note that ``transforms.RandomRotation(d)`` with a single number draws a new
    angle from the range ``(-d, +d)`` each time an image is loaded; it does NOT
    rotate every image by exactly ``d`` degrees. ``max_degrees`` is therefore
    the maximum absolute rotation, and repeated passes over the loader see
    differently rotated images.

    Args:
        max_degrees: bound handed to ``transforms.RandomRotation``.

    Returns:
        Tuple ``(transform, dataset, loader)``; the loader uses
        ``BATCH_SIZE_TEST_MNIST`` and does not shuffle.
    """
    transform = transforms.Compose([
        transforms.RandomRotation(max_degrees),
        transforms.ToTensor(),
    ])
    dataset = torchvision.datasets.MNIST(
        '~/data/mnist',
        train=False,
        download=False,
        transform=transform)
    loader = torch.utils.data.dataloader.DataLoader(
        dataset,
        batch_size=BATCH_SIZE_TEST_MNIST,
        shuffle=False
    )
    return transform, dataset, loader


# One (transform, dataset, loader) triple per rotation bound, 15 to 180 degrees
# in steps of 15. The explicit names are kept because later cells refer to them.
MNIST_transform_r15, MNIST_test_r15, mnist_test_loader_r15 = make_rotated_mnist_test_loader(15)
MNIST_transform_r30, MNIST_test_r30, mnist_test_loader_r30 = make_rotated_mnist_test_loader(30)
MNIST_transform_r45, MNIST_test_r45, mnist_test_loader_r45 = make_rotated_mnist_test_loader(45)
MNIST_transform_r60, MNIST_test_r60, mnist_test_loader_r60 = make_rotated_mnist_test_loader(60)
MNIST_transform_r75, MNIST_test_r75, mnist_test_loader_r75 = make_rotated_mnist_test_loader(75)
MNIST_transform_r90, MNIST_test_r90, mnist_test_loader_r90 = make_rotated_mnist_test_loader(90)
MNIST_transform_r105, MNIST_test_r105, mnist_test_loader_r105 = make_rotated_mnist_test_loader(105)
MNIST_transform_r120, MNIST_test_r120, mnist_test_loader_r120 = make_rotated_mnist_test_loader(120)
MNIST_transform_r135, MNIST_test_r135, mnist_test_loader_r135 = make_rotated_mnist_test_loader(135)
MNIST_transform_r150, MNIST_test_r150, mnist_test_loader_r150 = make_rotated_mnist_test_loader(150)
MNIST_transform_r165, MNIST_test_r165, mnist_test_loader_r165 = make_rotated_mnist_test_loader(165)
MNIST_transform_r180, MNIST_test_r180, mnist_test_loader_r180 = make_rotated_mnist_test_loader(180)

In [None]:
from sklearn.metrics import brier_score_loss

def get_acc_brier(dataloader, targets, num_samples=1000):
    """Compare diagonal sampling and the Laplace Bridge (LB) on one dataloader.

    Both predictors are run with the notebook-level `mnist_model` and the
    diagonal posterior parameters (`M_W_post_D`, `M_b_post_D`, `C_W_post_D`,
    `C_b_post_D`). The LB output is normalized so each row sums to one before
    it is scored.

    The "Brier" value computed here is the mean squared gap between 1 and the
    probability assigned to the true class, i.e. `brier_score_loss` with an
    all-ones target vector.

    Args:
        dataloader: loader handed to both prediction helpers.
        targets: true class index per example, in loader order.
            Assumes the prediction helpers return one row per target and one
            column per class - TODO confirm against utils.LB_utils.
        num_samples: number of samples for `predict_diagonal_sampling`.

    Returns:
        Tuple `(acc_D, acc_LB, brier_D, brier_LB)`.
    """
    # Sampling-based predictions. `cuda_status` (set at the top of the notebook)
    # replaces a hard-coded cuda=True so the helper also runs without a GPU.
    probs_sampling = predict_diagonal_sampling(mnist_model, dataloader, M_W_post_D, M_b_post_D, C_W_post_D, C_b_post_D, verbose=False, cuda=cuda_status, timing=False, n_samples=num_samples).cpu().numpy()

    # Laplace Bridge predictions, rescaled so that every row sums to 1.
    probs_LB = predict_LB(mnist_model, dataloader, M_W_post_D, M_b_post_D, C_W_post_D, C_b_post_D, verbose=False, cuda=cuda_status, timing=False).cpu().numpy()
    probs_LB_norm = probs_LB / probs_LB.sum(1).reshape(-1, 1)

    targets = np.asarray(targets)

    # Accuracy for sampling and LB.
    acc_D = np.mean(np.argmax(probs_sampling, 1) == targets)
    acc_LB = np.mean(np.argmax(probs_LB_norm, 1) == targets)

    # Probability assigned to the true class of each example (vectorized
    # replacement for a per-row Python loop).
    row_idx = np.arange(len(targets))
    pred_at_true_D = probs_sampling[row_idx, targets]
    pred_at_true_LBn = probs_LB_norm[row_idx, targets]

    brier_D = brier_score_loss(np.ones_like(pred_at_true_D), pred_at_true_D)
    brier_LB = brier_score_loss(np.ones_like(pred_at_true_LBn), pred_at_true_LBn)

    return (acc_D, acc_LB, brier_D, brier_LB)

In [None]:
# Smoke test: run the helper once on the 15-degree loader with the default
# number of samples; the returned (acc_D, acc_LB, brier_D, brier_LB) tuple is
# displayed as the cell output.
get_acc_brier(mnist_test_loader_r15, targets)

In [None]:
## Predict on every rotation setting and collect accuracy and Brier score.

# The unrotated test loader is referenced as `MNIST_test_loader`, the name the
# notebook defines at the top and uses in the earlier evaluation cells.
# NOTE(review): the previous spelling `mnist_test_loader` is not defined in the
# visible parts of the notebook - confirm no lowercase alias is intended.
dataloader_list = [MNIST_test_loader, mnist_test_loader_r15, mnist_test_loader_r30, mnist_test_loader_r45,
                  mnist_test_loader_r60, mnist_test_loader_r75, mnist_test_loader_r90, mnist_test_loader_r105,
                  mnist_test_loader_r120, mnist_test_loader_r135, mnist_test_loader_r150, mnist_test_loader_r165,
                  mnist_test_loader_r180]

# One entry per loader, in the same order as `dataloader_list`.
Acc_D_list = []
Acc_LB_list = []
Brier_D_list = []
Brier_LB_list = []

for i, loader_ in enumerate(dataloader_list):
    print(i)  # progress indicator: index of the loader being evaluated
    Acc_D_, Acc_LB_, Brier_D_, Brier_LB_ = get_acc_brier(loader_, targets)
    Acc_D_list.append(Acc_D_)
    Acc_LB_list.append(Acc_LB_)
    Brier_D_list.append(Brier_D_)
    Brier_LB_list.append(Brier_LB_)
    

In [None]:
# Make inline plots vector graphics
import matplotlib
from IPython.display import set_matplotlib_formats

# Render inline figures as vector graphics.
set_matplotlib_formats("pdf", "svg")

# Serif (Computer Modern) fonts with LaTeX text rendering.
matplotlib.rc("font", family="serif", serif=["Computer Modern"])
plt.rcParams.update({
    "text.usetex": True,
    "text.latex.preamble": r"\usepackage{amsfonts} \usepackage{amsmath}",
})

In [None]:
# Show the per-rotation results of the current run, one list per line.
for _metric_history in (Acc_D_list, Acc_LB_list, Brier_D_list, Brier_LB_list):
    print(_metric_history)

In [None]:
#### compare over 5 seeds: 123, 124, 125, 126, 127
# Hard-coded results: each array has 5 rows and 13 columns.
# NOTE(review): presumably one row per seed (in the order listed above) and one
# column per entry of the rotation sweep (test, 15, 30, ..., 180 degrees),
# pasted from earlier runs of the evaluation loop - confirm the provenance.

# Accuracy of the sampling-based predictions.
Acc_D_all = np.array([
    [0.9906, 0.981, 0.9401, 0.8593, 0.7388, 0.6375, 0.5627, 0.5049, 0.4703, 0.4533, 0.4363, 0.4287, 0.4268],
    [0.9889, 0.9814, 0.9433, 0.8518, 0.7337, 0.6403, 0.561, 0.5016, 0.4669, 0.4306, 0.4292, 0.4104, 0.4183],
    [0.9904, 0.9829, 0.946, 0.8582, 0.747, 0.6495, 0.5602, 0.5075, 0.4684, 0.4437, 0.4334, 0.4245, 0.4217],
    [0.9875, 0.9776, 0.9353, 0.8422, 0.7212, 0.6377, 0.55, 0.4896, 0.4497, 0.4328, 0.416, 0.4117, 0.4095],
    [0.9894, 0.9816, 0.9451, 0.8522, 0.7403, 0.6391, 0.5565, 0.4994, 0.4656, 0.4308, 0.4284, 0.4221, 0.4152]
])

# Accuracy of the Laplace Bridge (LB) predictions.
Acc_LB_all = np.array([
    [0.9907, 0.983, 0.9394, 0.8526, 0.7419, 0.6386, 0.5703, 0.5152, 0.4702, 0.4484, 0.4335, 0.4236, 0.4182],
    [0.9889, 0.9804, 0.9471, 0.8547, 0.7411, 0.6296, 0.5624, 0.4938, 0.4609, 0.4404, 0.4295, 0.4172, 0.4152],
    [0.9905, 0.9831, 0.9436, 0.8616, 0.7461, 0.6517, 0.5647, 0.5127, 0.4673, 0.4454, 0.4273, 0.4291, 0.415],
    [0.9876, 0.9799, 0.9398, 0.851, 0.7352, 0.6368, 0.5445, 0.4992, 0.4632, 0.4269, 0.42, 0.417, 0.4181],
    [0.9893, 0.9808, 0.939, 0.8465, 0.7338, 0.6312, 0.5573, 0.505, 0.4539, 0.4458, 0.4158, 0.4189, 0.4145]
])

# Brier score of the sampling-based predictions.
Brier_D_all = np.array([
    [0.008377700125184932, 0.016529366393931315, 0.05263197122933036, 0.12622978433992207, 0.23428155265648049, 0.33001102154197504, 0.40364649390174306, 0.45833621095510935, 0.48925685879783176, 0.5084873524892697, 0.5265005890140535, 0.5351897862160594, 0.535655688821819],
    [0.010744372283432188, 0.01861786528128727, 0.04921392785665237, 0.130553247122029, 0.23708335123859933, 0.3264322724129021, 0.4020700676185214, 0.45858313992431765, 0.49138653020376494, 0.5267105075863001, 0.5302083038117827, 0.5471840382483096, 0.5406172102829678],
    [0.008842406659404037, 0.01589756693721347, 0.04813633596723046, 0.12450638545589118, 0.2251983306557004, 0.31596819055291797, 0.39958471677594976, 0.4511044612864316, 0.4908296869743894, 0.513979687459884, 0.5228012975599322, 0.5331324584577392, 0.5366587666968312],
    [0.010730850229531219, 0.019535984307980356, 0.054532256239393395, 0.13633893796280153, 0.24614911889753252, 0.32741910714203704, 0.4095507809925442, 0.4699407662727923, 0.5086184210935191, 0.525650285971151, 0.544528220493102, 0.5479489062984142, 0.5516986011718078],
    [0.009795412763015064, 0.01817015340220003, 0.04894315094823046, 0.12908701202114847, 0.23282846945381916, 0.3285278649762006, 0.4081318056690933, 0.46182442301821763, 0.4973149243671581, 0.5269344762781906, 0.5322562363481955, 0.5380640556117866, 0.5471252220414157]
])

# Brier score of the Laplace Bridge (LB) predictions.
Brier_LB_all = np.array([
    [0.007844242911062461, 0.015166313465830961, 0.05119071505966292, 0.13220723722716274, 0.23246172960926015, 0.3326357106190896, 0.4017439850542811, 0.45393964088759164, 0.49749242841385766, 0.5211752780472316, 0.5353101111950346, 0.5431750248692432, 0.5509519585301749],
    [0.009505463873917114, 0.016971207786010672, 0.0470778754224552, 0.12817579122361247, 0.2349776937700805, 0.339070595213455, 0.40568261118347, 0.4740733457035303, 0.5030623144832805, 0.5253332179429556, 0.5360191577762338, 0.549535112754868, 0.5511099660875561],
    [0.0076984171595750415, 0.014756499162356471, 0.0469422462627717, 0.12101585517824154, 0.22767872259093588, 0.31762503643888157, 0.4011728288118217, 0.4516875142101383, 0.4954514924412558, 0.5178717196068767, 0.5371807295796656, 0.5349908080424384, 0.5497019253514567],
    [0.01001831719153492, 0.016729483129191477, 0.05172484471829214, 0.13075484635707604, 0.23891760292945455, 0.3334325462508124, 0.42297503706889894, 0.4695227387634461, 0.5033610244309235, 0.5387111053297119, 0.5473753173717428, 0.5506346709903308, 0.5503556560492305],
    [0.008755419081030589, 0.01618864623953837, 0.05156303842716825, 0.13417919609729662, 0.23963138466766964, 0.34131713744443276, 0.4110757904444741, 0.46371903749768173, 0.5107108175950072, 0.5222953217647085, 0.5486314035009504, 0.5470332878374046, 0.5531479907833964]   
])

In [None]:
# Column-wise mean and standard deviation (axis 0 runs over the rows of each
# hard-coded results array).
Acc_D_all_mean, Acc_D_all_std = Acc_D_all.mean(axis=0), Acc_D_all.std(axis=0)
Acc_LB_all_mean, Acc_LB_all_std = Acc_LB_all.mean(axis=0), Acc_LB_all.std(axis=0)

Brier_D_all_mean, Brier_D_all_std = Brier_D_all.mean(axis=0), Brier_D_all.std(axis=0)
Brier_LB_all_mean, Brier_LB_all_std = Brier_LB_all.mean(axis=0), Brier_LB_all.std(axis=0)

In [None]:
# Accuracy figure: mean accuracy across the rows of the hard-coded results,
# one point per rotation setting.

# Tick labels: 'test' followed by 15, 30, ..., 180 degrees. A raw f-string is
# used so that `\circ` reaches LaTeX unchanged instead of relying on Python
# passing through an invalid `\c` escape sequence.
x_labels = ['test'] + [rf'{deg}$^{{\circ}}$' for deg in range(15, 181, 15)]

plt.figure(figsize=(5, 1.5))
plt.plot(x_labels, Acc_D_all_mean, color='blue', label='MCMC')
plt.plot(x_labels, Acc_LB_all_mean, color='firebrick', label='LB')
plt.ylim(0,1)
plt.ylabel('Accuracy', size=15)
#plt.legend()
# rotation must be a number (or 'horizontal'/'vertical'); the string '30' is
# rejected by recent matplotlib versions.
plt.xticks(x_labels, x_labels, rotation=30, size=13)

plt.tight_layout()

# savefig does not create missing directories.
os.makedirs('figures', exist_ok=True)
plt.savefig('figures/LB_vs_MCMC_Acc.pdf')
plt.show();

In [None]:
# Brier-score figure. It plots the means over the hard-coded multi-seed results
# (Brier_*_all_mean), matching the accuracy figure, which plots Acc_*_all_mean.
# The single-run Brier_*_list values used before only exist if the evaluation
# loop was run in this session, and they mixed one run with 5-seed averages.

plt.figure(figsize=(5, 1.5))
plt.plot(x_labels, Brier_D_all_mean, color='blue', label='MCMC')
plt.plot(x_labels, Brier_LB_all_mean, color='firebrick', label='LB')
plt.ylim(0,0.8)
plt.ylabel('Brier', size=15)
plt.legend()
# rotation must be a number (or 'horizontal'/'vertical'); the string '30' is
# rejected by recent matplotlib versions.
plt.xticks(x_labels, x_labels, rotation=30, size=13)

plt.tight_layout()

# savefig does not create missing directories.
os.makedirs('figures', exist_ok=True)
plt.savefig('figures/LB_vs_MCMC_Brier.pdf')
plt.show();

In [None]:
# Combined figure: accuracy (top) and Brier score (bottom) per rotation setting.
# Both panels plot means over the hard-coded multi-seed results; the bottom
# panel previously plotted the single-run Brier_*_list values, which was
# inconsistent with the top panel.

fig, ax = plt.subplots(2, 1, figsize=(5, 3))

# Top panel: accuracy, tick labels hidden because the bottom panel shows them.
ax[0].plot(x_labels, Acc_D_all_mean, color='blue', label='MCMC')
ax[0].plot(x_labels, Acc_LB_all_mean, color='firebrick', label='LB')
ax[0].set_ylim(0,1)
ax[0].set_ylabel('Accuracy', size=15)
ax[0].set_xticklabels([])

# Bottom panel: Brier score, carrying the legend and the rotated tick labels.
ax[1].plot(x_labels, Brier_D_all_mean, color='blue', label='MCMC')
ax[1].plot(x_labels, Brier_LB_all_mean, color='firebrick', label='LB')
ax[1].set_ylim(0,0.8)
ax[1].set_ylabel('Brier', size=15)
ax[1].legend()
ax[1].set_xticklabels(x_labels, rotation=30, size=13)


plt.tight_layout()

# savefig does not create missing directories.
os.makedirs('figures', exist_ok=True)
plt.savefig('figures/LB_vs_MCMC_Acc_Brier.pdf')
plt.show();