In [1]:
#!nvidia-smi
#using a GeForce GTX1080 Ti for reproducibility for all timing experiments

In [2]:
import torch
import torchvision
from torch import nn, optim, autograd
from torch.nn import functional as F
import numpy as np
from sklearn.metrics import roc_auc_score
import scipy
from utils.LB_utils import * 
from utils.load_not_MNIST import notMNIST
import os
import time
import matplotlib.pyplot as plt
from laplace import Laplace
import utils.scoring as scoring

s = 1
np.random.seed(s)
torch.manual_seed(s)
torch.cuda.manual_seed(s)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
cuda_status = torch.cuda.is_available()
print("device: ", device)
print("cuda status: ", cuda_status)

device:  cuda
cuda status:  True


In [4]:
### define network
class ConvNet(nn.Module):
    
    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()
        
        self.net = torch.nn.Sequential(
            torch.nn.Conv2d(1, 16, 5),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2,2),
            torch.nn.Conv2d(16, 32, 5),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2,2),
            torch.nn.Flatten(),
            torch.nn.Linear(4 * 4 * 32, num_classes)
        )
    def forward(self, x):
        out = self.net(x)
        return out

In [5]:
BATCH_SIZE_TRAIN_MNIST = 128
BATCH_SIZE_TEST_MNIST = 128
MAX_ITER_MNIST = 6
LR_TRAIN_MNIST = 10e-6

In [6]:
MNIST_transform = torchvision.transforms.ToTensor()

MNIST_train = torchvision.datasets.MNIST(
        '~/data/mnist',
        train=True,
        download=True,
        transform=MNIST_transform)

MNIST_train_loader = torch.utils.data.dataloader.DataLoader(
    MNIST_train,
    batch_size=BATCH_SIZE_TRAIN_MNIST,
    shuffle=True
)


MNIST_test = torchvision.datasets.MNIST(
        '~/data/mnist',
        train=False,
        download=False,
        transform=MNIST_transform)

MNIST_test_loader = torch.utils.data.dataloader.DataLoader(
    MNIST_test,
    batch_size=BATCH_SIZE_TEST_MNIST,
    shuffle=False,
)

In [7]:
mnist_model = ConvNet().to(device)
loss_function = torch.nn.CrossEntropyLoss()

mnist_train_optimizer = torch.optim.Adam(mnist_model.parameters(), lr=1e-3, weight_decay=5e-4)
MNIST_PATH = "pretrained_weights/MNIST_pretrained_10_classes_last_layer_s{}.pth".format(s)

In [8]:
#Training routine

def train(model, train_loader, optimizer, max_iter, path, verbose=True):
    max_len = len(train_loader)

    for iter in range(max_iter):
        for batch_idx, (x, y) in enumerate(train_loader):
            
            x, y = x.to(device), y.to(device)
            
            output = model(x)

            accuracy = get_accuracy(output, y)

            loss = loss_function(output, y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if verbose and batch_idx % 50 == 0:
                print(
                    "Iteration {}; {}/{} \t".format(iter, batch_idx, max_len) +
                    "Minibatch Loss %.3f  " % (loss) +
                    "Accuracy %.0f" % (accuracy * 100) + "%"
                )

    print("saving model at: {}".format(path))
    torch.save(mnist_model.state_dict(), path)

In [9]:
#train(mnist_model, MNIST_train_loader, mnist_train_optimizer, MAX_ITER_MNIST, MNIST_PATH, verbose=True)

In [10]:
#predict in distribution
MNIST_PATH = "pretrained_weights/MNIST_pretrained_10_classes_last_layer_s{}.pth".format(s)
#MNIST_PATH = "pretrained_weights/MNIST_pretrained_10_classes_last_layer.pth"

mnist_model = ConvNet().to(device)
print("loading model from: {}".format(MNIST_PATH))
mnist_model.load_state_dict(torch.load(MNIST_PATH))
mnist_model.eval()

acc = []

max_len = len(MNIST_test_loader)
for batch_idx, (x, y) in enumerate(MNIST_test_loader):

    x, y = x.to(device), y.to(device)
    output = mnist_model(x)

    accuracy = get_accuracy(output, y)
    if batch_idx % 10 == 0:
        print(
            "Batch {}/{} \t".format(batch_idx, max_len) + 
            "Accuracy %.0f" % (accuracy * 100) + "%"
        )
    acc.append(accuracy)

avg_acc = np.mean(acc)
print('overall test accuracy on MNIST: {:.02f} %'.format(avg_acc * 100))


loading model from: pretrained_weights/MNIST_pretrained_10_classes_last_layer_s1.pth
Batch 0/79 	Accuracy 100%
Batch 10/79 	Accuracy 96%
Batch 20/79 	Accuracy 98%
Batch 30/79 	Accuracy 98%
Batch 40/79 	Accuracy 100%
Batch 50/79 	Accuracy 99%
Batch 60/79 	Accuracy 100%
Batch 70/79 	Accuracy 98%
overall test accuracy on MNIST: 98.84 %


In [11]:
BATCH_SIZE_TEST_FMNIST = 128
BATCH_SIZE_TEST_KMNIST = 128

In [12]:
FMNIST_test = torchvision.datasets.FashionMNIST(
        '~/data/fmnist', train=False, download=True,
        transform=MNIST_transform)   #torchvision.transforms.ToTensor())

FMNIST_test_loader = torch.utils.data.DataLoader(
    FMNIST_test,
    batch_size=BATCH_SIZE_TEST_FMNIST, shuffle=False)

In [13]:
KMNIST_test = torchvision.datasets.KMNIST(
        '~/data/kmnist', train=False, download=True,
        transform=MNIST_transform)

KMNIST_test_loader = torch.utils.data.DataLoader(
    KMNIST_test,
    batch_size=BATCH_SIZE_TEST_KMNIST, shuffle=False)

In [14]:
#root = os.path.abspath('~/data')
root = os.path.expanduser('~/data')

# Instantiating the notMNIST dataset class we created
notMNIST_test = notMNIST(root=os.path.join(root, 'notMNIST_small'),
                               transform=MNIST_transform)

# Creating a dataloader
notMNIST_test_loader = torch.utils.data.dataloader.DataLoader(
                            dataset=notMNIST_test,
                            batch_size=BATCH_SIZE_TEST_KMNIST,
                            shuffle=False)

File F/Q3Jvc3NvdmVyIEJvbGRPYmxpcXVlLnR0Zg==.png is broken
File A/RGVtb2NyYXRpY2FCb2xkT2xkc3R5bGUgQm9sZC50dGY=.png is broken


# MAP estimate

In [15]:
targets = MNIST_test.targets.numpy()
targets_FMNIST = FMNIST_test.targets.numpy()
targets_notMNIST = notMNIST_test.targets.numpy().astype(int)
targets_KMNIST = KMNIST_test.targets.numpy()

In [16]:
MNIST_test_in_MAP = predict_MAP(mnist_model, MNIST_test_loader, device=device).cpu().numpy()
MNIST_test_out_FMNIST_MAP = predict_MAP(mnist_model, FMNIST_test_loader, device=device).cpu().numpy()
MNIST_test_out_notMNIST_MAP = predict_MAP(mnist_model, notMNIST_test_loader, device=device).cpu().numpy()
MNIST_test_out_KMNIST_MAP = predict_MAP(mnist_model, KMNIST_test_loader, device=device).cpu().numpy()

In [17]:
# compute average log-likelihood for Diag
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_in_MAP)).log_prob(torch.tensor(targets)).mean().item())
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_out_FMNIST_MAP)).log_prob(torch.tensor(targets_FMNIST)).mean().item())
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_out_notMNIST_MAP)).log_prob(torch.tensor(targets_notMNIST)).mean().item())
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_out_KMNIST_MAP)).log_prob(torch.tensor(targets_KMNIST)).mean().item())

0.03344982489943504
4.6731367111206055
6.672377586364746
6.748490810394287


In [18]:
#compute the Expected confidence estimate
print(scoring.expected_calibration_error(targets, MNIST_test_in_MAP))
print(scoring.expected_calibration_error(targets_FMNIST, MNIST_test_out_FMNIST_MAP))
print(scoring.expected_calibration_error(targets_notMNIST, MNIST_test_out_notMNIST_MAP))
print(scoring.expected_calibration_error(targets_KMNIST, MNIST_test_out_KMNIST_MAP))

0.009983010101010131
0.5903608484848487
0.6404300589747076
0.6294307171717172


In [19]:
##FPR95
print(get_fpr95(MNIST_test_in_MAP, MNIST_test_out_FMNIST_MAP))

(0.1311, 0.9366349875926971)


In [68]:
## Brier score
print(get_brier(MNIST_test_in_MAP, targets, n_classes=10))
print(get_brier(MNIST_test_out_FMNIST_MAP, targets_FMNIST, n_classes=10))
print(get_brier(MNIST_test_out_notMNIST_MAP, targets_notMNIST, n_classes=10))
print(get_brier(MNIST_test_out_KMNIST_MAP, targets_KMNIST, n_classes=10))

0.0016874086577445269
0.1371215432882309
0.14116381108760834
0.1430293172597885


In [21]:
acc_in_MAP, prob_correct_in_MAP, ent_in_MAP, MMC_in_MAP = get_in_dist_values(MNIST_test_in_MAP, targets)
acc_out_FMNIST_MAP, prob_correct_out_FMNIST_MAP, ent_out_FMNIST_MAP, MMC_out_FMNIST_MAP, auroc_out_FMNIST_MAP = get_out_dist_values(MNIST_test_in_MAP, MNIST_test_out_FMNIST_MAP, targets_FMNIST)
acc_out_notMNIST_MAP, prob_correct_out_notMNIST_MAP, ent_out_notMNIST_MAP, MMC_out_notMNIST_MAP, auroc_out_notMNIST_MAP = get_out_dist_values(MNIST_test_in_MAP, MNIST_test_out_notMNIST_MAP, targets_notMNIST)
acc_out_KMNIST_MAP, prob_correct_out_KMNIST_MAP, ent_out_KMNIST_MAP, MMC_out_KMNIST_MAP, auroc_out_KMNIST_MAP = get_out_dist_values(MNIST_test_in_MAP, MNIST_test_out_KMNIST_MAP, targets_KMNIST)

In [22]:
print_in_dist_values(acc_in_MAP, prob_correct_in_MAP, ent_in_MAP, MMC_in_MAP, 'MNIST', 'MAP')
print_out_dist_values(acc_out_FMNIST_MAP, prob_correct_out_FMNIST_MAP, ent_out_FMNIST_MAP, MMC_out_FMNIST_MAP, auroc_out_FMNIST_MAP, 'MNIST', test='FMNIST', method='MAP')
print_out_dist_values(acc_out_notMNIST_MAP, prob_correct_out_notMNIST_MAP, ent_out_notMNIST_MAP, MMC_out_notMNIST_MAP, auroc_out_notMNIST_MAP, 'MNIST', test='notMNIST', method='MAP')
print_out_dist_values(acc_out_KMNIST_MAP, prob_correct_out_KMNIST_MAP, ent_out_KMNIST_MAP, MMC_out_KMNIST_MAP, auroc_out_KMNIST_MAP, 'MNIST', test='KMNIST', method='MAP')

[In, MAP, MNIST] Accuracy: 0.988; average entropy: 0.047;     MMC: 0.986; Prob @ correct: 0.100
[Out-FMNIST, MAP, MNIST] Accuracy: 0.073; Average entropy: 0.972;    MMC: 0.663; AUROC: 0.977; Prob @ correct: 0.100
[Out-notMNIST, MAP, MNIST] Accuracy: 0.148; Average entropy: 0.620;    MMC: 0.774; AUROC: 0.914; Prob @ correct: 0.100
[Out-KMNIST, MAP, MNIST] Accuracy: 0.093; Average entropy: 0.750;    MMC: 0.723; AUROC: 0.959; Prob @ correct: 0.100


In [23]:
num_samples = 100

# Diag Hessian Sampling estimate

In [24]:
la_diag = Laplace(mnist_model, 'classification', 
                     subset_of_weights='last_layer', 
                     hessian_structure='diag',
                     prior_precision=1e-0) # 5e-4 # Choose prior precision according to weight decay
la_diag.fit(MNIST_train_loader)

In [25]:
MNIST_test_in_D = predict_samples(la_diag, MNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_FMNIST_D = predict_samples(la_diag, FMNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_notMNIST_D = predict_samples(la_diag, notMNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_KMNIST_D = predict_samples(la_diag, KMNIST_test_loader, timing=True, device=device).cpu().numpy()

time:  1.2002999909999978
time:  1.0010222549999988
time:  0.5089399970000024
time:  0.9723320810000011


In [26]:
# compute average log-likelihood for Diag
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_in_D)).log_prob(torch.tensor(targets)).mean().item())
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_out_FMNIST_D)).log_prob(torch.tensor(targets_FMNIST)).mean().item())
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_out_notMNIST_D)).log_prob(torch.tensor(targets_notMNIST)).mean().item())
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_out_KMNIST_D)).log_prob(torch.tensor(targets_KMNIST)).mean().item())

0.060485560446977615
3.96065354347229
4.334907531738281
4.933351993560791


In [27]:
#compute the Expected confidence estimate
print(scoring.expected_calibration_error(targets, MNIST_test_in_D))
print(scoring.expected_calibration_error(targets_FMNIST, MNIST_test_out_FMNIST_D))
print(scoring.expected_calibration_error(targets_notMNIST, MNIST_test_out_notMNIST_D))
print(scoring.expected_calibration_error(targets_KMNIST, MNIST_test_out_KMNIST_D))

0.03493014141414141
0.46846589898989915
0.409833633277876
0.43425194949494955


In [69]:
## Brier score
print(get_brier(MNIST_test_in_D, targets, n_classes=10))
print(get_brier(MNIST_test_out_FMNIST_D, targets_FMNIST, n_classes=10))
print(get_brier(MNIST_test_out_notMNIST_D, targets_notMNIST, n_classes=10))
print(get_brier(MNIST_test_out_KMNIST_D, targets_KMNIST, n_classes=10))

0.0022665667347609997
0.12107032537460327
0.11608327180147171
0.12063241004943848


In [28]:
acc_in_D, prob_correct_in_D, ent_in_D, MMC_in_D = get_in_dist_values(MNIST_test_in_D, targets)
acc_out_FMNIST_D, prob_correct_out_FMNIST_D, ent_out_FMNIST_D, MMC_out_FMNIST_D, auroc_out_FMNIST_D = get_out_dist_values(MNIST_test_in_D, MNIST_test_out_FMNIST_D, targets_FMNIST)
acc_out_notMNIST_D, prob_correct_out_notMNIST_D, ent_out_notMNIST_D, MMC_out_notMNIST_D, auroc_out_notMNIST_D = get_out_dist_values(MNIST_test_in_D, MNIST_test_out_notMNIST_D, targets_notMNIST)
acc_out_KMNIST_D, prob_correct_out_KMNIST_D, ent_out_KMNIST_D, MMC_out_KMNIST_D, auroc_out_KMNIST_D = get_out_dist_values(MNIST_test_in_D, MNIST_test_out_KMNIST_D, targets_KMNIST)

In [29]:
print_in_dist_values(acc_in_D, prob_correct_in_D, ent_in_D, MMC_in_D, 'MNIST', 'Diag')
print_out_dist_values(acc_out_FMNIST_D, prob_correct_out_FMNIST_D, ent_out_FMNIST_D, MMC_out_FMNIST_D, auroc_out_FMNIST_D, 'MNIST', test='fmnist', method='Diag')
print_out_dist_values(acc_out_notMNIST_D, prob_correct_out_notMNIST_D, ent_out_notMNIST_D, MMC_out_notMNIST_D, auroc_out_notMNIST_D, 'MNIST', test='notMNIST', method='Diag')
print_out_dist_values(acc_out_KMNIST_D, prob_correct_out_KMNIST_D, ent_out_KMNIST_D, MMC_out_KMNIST_D, auroc_out_KMNIST_D, 'MNIST', test='KMNIST', method='Diag')

[In, Diag, MNIST] Accuracy: 0.989; average entropy: 0.160;     MMC: 0.956; Prob @ correct: 0.100
[Out-fmnist, Diag, MNIST] Accuracy: 0.073; Average entropy: 1.313;    MMC: 0.541; AUROC: 0.975; Prob @ correct: 0.100
[Out-notMNIST, Diag, MNIST] Accuracy: 0.149; Average entropy: 1.212;    MMC: 0.559; AUROC: 0.963; Prob @ correct: 0.100
[Out-KMNIST, Diag, MNIST] Accuracy: 0.095; Average entropy: 1.272;    MMC: 0.529; AUROC: 0.973; Prob @ correct: 0.100


# KFAC Laplace Approximation (sampling)

In [30]:
la_kron = Laplace(mnist_model, 'classification', 
                     subset_of_weights='last_layer', 
                     hessian_structure='kron',
                     prior_precision=5e-0) # 5e-4 # Choose prior precision according to weight decay
la_kron.fit(MNIST_train_loader)

The default behavior has changed from using the upper triangular portion of the matrix by default to using the lower triangular portion.
L, _ = torch.symeig(A, upper=upper)
should be replaced with
L = torch.linalg.eigvalsh(A, UPLO='U' if upper else 'L')
and
L, V = torch.symeig(A, eigenvectors=True)
should be replaced with
L, V = torch.linalg.eigh(A, UPLO='U' if upper else 'L') (Triggered internally at  /opt/conda/conda-bld/pytorch_1639180549130/work/aten/src/ATen/native/BatchLinearAlgebra.cpp:2499.)
  L, W = torch.symeig(M, eigenvectors=True)


In [31]:
MNIST_test_in_KFAC = predict_samples(la_kron, MNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_FMNIST_KFAC = predict_samples(la_kron, FMNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_notMNIST_KFAC = predict_samples(la_kron, notMNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_KMNIST_KFAC = predict_samples(la_kron, KMNIST_test_loader, timing=True, device=device).cpu().numpy()

time:  1.3355387420000007
time:  1.3483620710000004
time:  1.0776169469999992
time:  1.2704046519999999


In [32]:
# compute average log-likelihood for KFAC
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_in_KFAC)).log_prob(torch.tensor(targets)).mean().item())
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_out_FMNIST_KFAC)).log_prob(torch.tensor(targets_FMNIST)).mean().item())
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_out_notMNIST_KFAC)).log_prob(torch.tensor(targets_notMNIST)).mean().item())
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_out_KMNIST_KFAC)).log_prob(torch.tensor(targets_KMNIST)).mean().item())

0.04870095103979111
3.7331199645996094
4.36720609664917
4.877889156341553


In [33]:
# compute ECE for KFAC
print(scoring.expected_calibration_error(targets, MNIST_test_in_KFAC))
print(scoring.expected_calibration_error(targets_FMNIST, MNIST_test_out_FMNIST_KFAC))
print(scoring.expected_calibration_error(targets_notMNIST, MNIST_test_out_notMNIST_KFAC))
print(scoring.expected_calibration_error(targets_KMNIST, MNIST_test_out_KMNIST_KFAC))

0.024422959595959588
0.42929906060606066
0.40477670855100895
0.4167330707070709


In [70]:
## Brier score
print(get_brier(MNIST_test_in_KFAC, targets, n_classes=10))
print(get_brier(MNIST_test_out_FMNIST_KFAC, targets_FMNIST, n_classes=10))
print(get_brier(MNIST_test_out_notMNIST_KFAC, targets_notMNIST, n_classes=10))
print(get_brier(MNIST_test_out_KMNIST_KFAC, targets_KMNIST, n_classes=10))

0.0020646171178668737
0.11669006198644638
0.11602698266506195
0.11858072876930237


In [34]:
acc_in_KFAC, prob_correct_in_KFAC, ent_in_KFAC, MMC_in_KFAC = get_in_dist_values(MNIST_test_in_KFAC, targets)
acc_out_FMNIST_KFAC, prob_correct_out_FMNIST_KFAC, ent_out_FMNIST_KFAC, MMC_out_FMNIST_KFAC, auroc_out_FMNIST_KFAC = get_out_dist_values(MNIST_test_in_KFAC, MNIST_test_out_FMNIST_KFAC, targets_FMNIST)
acc_out_notMNIST_KFAC, prob_correct_out_notMNIST_KFAC, ent_out_notMNIST_KFAC, MMC_out_notMNIST_KFAC, auroc_out_notMNIST_KFAC = get_out_dist_values(MNIST_test_in_KFAC, MNIST_test_out_notMNIST_KFAC, targets_notMNIST)
acc_out_KMNIST_KFAC, prob_correct_out_KMNIST_KFAC, ent_out_KMNIST_KFAC, MMC_out_KMNIST_KFAC, auroc_out_KMNIST_KFAC = get_out_dist_values(MNIST_test_in_KFAC, MNIST_test_out_KMNIST_KFAC, targets_KMNIST)

In [35]:
print_in_dist_values(acc_in_KFAC, prob_correct_in_KFAC, ent_in_KFAC, MMC_in_KFAC, 'MNIST', 'KFAC')
print_out_dist_values(acc_out_FMNIST_KFAC, prob_correct_out_FMNIST_KFAC, ent_out_FMNIST_KFAC, MMC_out_FMNIST_KFAC, auroc_out_FMNIST_KFAC, 'MNIST', test='fmnist', method='KFAC')
print_out_dist_values(acc_out_notMNIST_KFAC, prob_correct_out_notMNIST_KFAC, ent_out_notMNIST_KFAC, MMC_out_notMNIST_KFAC, auroc_out_notMNIST_KFAC, 'MNIST', test='notMNIST', method='KFAC')
print_out_dist_values(acc_out_KMNIST_KFAC, prob_correct_out_KMNIST_KFAC, ent_out_KMNIST_KFAC, MMC_out_KMNIST_KFAC, auroc_out_KMNIST_KFAC, 'MNIST', test='KMNIST', method='KFAC')

[In, KFAC, MNIST] Accuracy: 0.988; average entropy: 0.111;     MMC: 0.967; Prob @ correct: 0.100
[Out-fmnist, KFAC, MNIST] Accuracy: 0.072; Average entropy: 1.402;    MMC: 0.501; AUROC: 0.985; Prob @ correct: 0.100
[Out-notMNIST, KFAC, MNIST] Accuracy: 0.148; Average entropy: 1.228;    MMC: 0.551; AUROC: 0.969; Prob @ correct: 0.100
[Out-KMNIST, KFAC, MNIST] Accuracy: 0.094; Average entropy: 1.315;    MMC: 0.511; AUROC: 0.982; Prob @ correct: 0.100


# Laplace Bridge (Diagonal)

In [36]:
MNIST_test_in_LB_D = predict_LB(la_diag, MNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_FMNIST_LB_D = predict_LB(la_diag, FMNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_notMNIST_LB_D = predict_LB(la_diag, notMNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_KMNIST_LB_D = predict_LB(la_diag, KMNIST_test_loader, timing=True, device=device).cpu().numpy()

time:  0.8771512819999963
time:  0.8695184520000012
time:  0.30476124100000135
time:  0.888303295


In [37]:
# # compute average log-likelihood for LB
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_in_LB_D)).log_prob(torch.tensor(targets)).mean().item())
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_out_FMNIST_LB_D)).log_prob(torch.tensor(targets_FMNIST)).mean().item())
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_out_notMNIST_LB_D)).log_prob(torch.tensor(targets_notMNIST)).mean().item())
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_out_KMNIST_LB_D)).log_prob(torch.tensor(targets_KMNIST)).mean().item())

0.04091636836528778
3.654500722885132
5.585165977478027
5.520288944244385


In [38]:
#compute ECE for LB
print(scoring.expected_calibration_error(targets, MNIST_test_in_LB_D))
print(scoring.expected_calibration_error(targets_FMNIST, MNIST_test_out_FMNIST_LB_D))
print(scoring.expected_calibration_error(targets_notMNIST, MNIST_test_out_notMNIST_LB_D))
print(scoring.expected_calibration_error(targets_KMNIST, MNIST_test_out_KMNIST_LB_D))

0.011439090909090939
0.48319700000000004
0.6082117478998488
0.5931453232323233


In [71]:
## Brier score
print(get_brier(MNIST_test_in_LB_D, targets, n_classes=10))
print(get_brier(MNIST_test_out_FMNIST_LB_D, targets_FMNIST, n_classes=10))
print(get_brier(MNIST_test_out_notMNIST_LB_D, targets_notMNIST, n_classes=10))
print(get_brier(MNIST_test_out_KMNIST_LB_D, targets_KMNIST, n_classes=10))

0.0019700229167938232
0.12366575002670288
0.13661548495292664
0.13839513063430786


In [39]:
acc_in_LB_D, prob_correct_in_LB_D, ent_in_LB_D, MMC_in_LB_D = get_in_dist_values(MNIST_test_in_LB_D, targets)
acc_out_FMNIST_LB_D, prob_correct_out_FMNIST_LB_D, ent_out_FMNIST_LB_D, MMC_out_FMNIST_LB_D, auroc_out_FMNIST_LB_D = get_out_dist_values(MNIST_test_in_LB_D, MNIST_test_out_FMNIST_LB_D, targets_FMNIST)
acc_out_notMNIST_LB_D, prob_correct_out_notMNIST_LB_D, ent_out_notMNIST_LB_D, MMC_out_notMNIST_LB_D, auroc_out_notMNIST_LB_D = get_out_dist_values(MNIST_test_in_LB_D, MNIST_test_out_notMNIST_LB_D, targets_notMNIST)
acc_out_KMNIST_LB_D, prob_correct_out_KMNIST_LB_D, ent_out_KMNIST_LB_D, MMC_out_KMNIST_LB_D, auroc_out_KMNIST_LB_D = get_out_dist_values(MNIST_test_in_LB_D, MNIST_test_out_KMNIST_LB_D, targets_KMNIST)

In [40]:
print_in_dist_values(acc_in_LB_D, prob_correct_in_LB_D, ent_in_LB_D, MMC_in_LB_D, 'MNIST', 'LB_D')
print_out_dist_values(acc_out_FMNIST_LB_D, prob_correct_out_FMNIST_LB_D, ent_out_FMNIST_LB_D, MMC_out_FMNIST_LB_D, auroc_out_FMNIST_LB_D, 'MNIST', test='fmnist', method='LB_D')
print_out_dist_values(acc_out_notMNIST_LB_D, prob_correct_out_notMNIST_LB_D, ent_out_notMNIST_LB_D, MMC_out_notMNIST_LB_D, auroc_out_notMNIST_LB_D, 'MNIST', test='notMNIST', method='LB_D')
print_out_dist_values(acc_out_KMNIST_LB_D, prob_correct_out_KMNIST_LB_D, ent_out_KMNIST_LB_D, MMC_out_KMNIST_LB_D, auroc_out_KMNIST_LB_D, 'MNIST', test='KMNIST', method='LB_D')

[In, LB_D, MNIST] Accuracy: 0.987; average entropy: 0.061;     MMC: 0.982; Prob @ correct: 0.100
[Out-fmnist, LB_D, MNIST] Accuracy: 0.082; Average entropy: 1.324;    MMC: 0.565; AUROC: 0.982; Prob @ correct: 0.100
[Out-notMNIST, LB_D, MNIST] Accuracy: 0.152; Average entropy: 0.737;    MMC: 0.735; AUROC: 0.931; Prob @ correct: 0.100
[Out-KMNIST, LB_D, MNIST] Accuracy: 0.099; Average entropy: 0.866;    MMC: 0.692; AUROC: 0.960; Prob @ correct: 0.100


# LB diag norm

In [41]:
MNIST_test_in_LB_Dn = predict_LB_norm(la_diag, MNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_FMNIST_LB_Dn = predict_LB_norm(la_diag, FMNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_notMNIST_LB_Dn = predict_LB_norm(la_diag, notMNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_KMNIST_LB_Dn = predict_LB_norm(la_diag, KMNIST_test_loader, timing=True, device=device).cpu().numpy()

time:  0.9353996910000006
time:  0.9184220670000016
time:  0.3646903100000003
time:  0.9031150509999932


In [42]:
# # compute average log-likelihood for LB
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_in_LB_Dn)).log_prob(torch.tensor(targets)).mean().item())
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_out_FMNIST_LB_Dn)).log_prob(torch.tensor(targets_FMNIST)).mean().item())
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_out_notMNIST_LB_Dn)).log_prob(torch.tensor(targets_notMNIST)).mean().item())
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_out_KMNIST_LB_Dn)).log_prob(torch.tensor(targets_KMNIST)).mean().item())

0.07220260798931122
3.6098625659942627
3.163008451461792
3.329798460006714


In [43]:
#compute ECE for LB
print(scoring.expected_calibration_error(targets, MNIST_test_in_LB_Dn))
print(scoring.expected_calibration_error(targets_FMNIST, MNIST_test_out_FMNIST_LB_Dn))
print(scoring.expected_calibration_error(targets_notMNIST, MNIST_test_out_notMNIST_LB_Dn))
print(scoring.expected_calibration_error(targets_KMNIST, MNIST_test_out_KMNIST_LB_Dn))

0.04084181818181817
0.5142652929292929
0.37737078108579925
0.3928338787878788


In [72]:
## Brier score
print(get_brier(MNIST_test_in_LB_Dn, targets, n_classes=10))
print(get_brier(MNIST_test_out_FMNIST_LB_Dn, targets_FMNIST, n_classes=10))
print(get_brier(MNIST_test_out_notMNIST_LB_Dn, targets_notMNIST, n_classes=10))
print(get_brier(MNIST_test_out_KMNIST_LB_Dn, targets_KMNIST, n_classes=10))

0.002524588257074356
0.1262955218553543
0.112608402967453
0.11521024256944656


In [44]:
acc_in_LB_D, prob_correct_in_LB_D, ent_in_LB_D, MMC_in_LB_D = get_in_dist_values(MNIST_test_in_LB_Dn, targets)
acc_out_FMNIST_LB_D, prob_correct_out_FMNIST_LB_D, ent_out_FMNIST_LB_D, MMC_out_FMNIST_LB_D, auroc_out_FMNIST_LB_D = get_out_dist_values(MNIST_test_in_LB_Dn, MNIST_test_out_FMNIST_LB_Dn, targets_FMNIST)
acc_out_notMNIST_LB_D, prob_correct_out_notMNIST_LB_D, ent_out_notMNIST_LB_D, MMC_out_notMNIST_LB_D, auroc_out_notMNIST_LB_D = get_out_dist_values(MNIST_test_in_LB_Dn, MNIST_test_out_notMNIST_LB_Dn, targets_notMNIST)
acc_out_KMNIST_LB_D, prob_correct_out_KMNIST_LB_D, ent_out_KMNIST_LB_D, MMC_out_KMNIST_LB_D, auroc_out_KMNIST_LB_D = get_out_dist_values(MNIST_test_in_LB_Dn, MNIST_test_out_KMNIST_LB_Dn, targets_KMNIST)

In [45]:
print_in_dist_values(acc_in_LB_D, prob_correct_in_LB_D, ent_in_LB_D, MMC_in_LB_D, 'MNIST', 'LB_D')
print_out_dist_values(acc_out_FMNIST_LB_D, prob_correct_out_FMNIST_LB_D, ent_out_FMNIST_LB_D, MMC_out_FMNIST_LB_D, auroc_out_FMNIST_LB_D, 'MNIST', test='fmnist', method='LB_D')
print_out_dist_values(acc_out_notMNIST_LB_D, prob_correct_out_notMNIST_LB_D, ent_out_notMNIST_LB_D, MMC_out_notMNIST_LB_D, auroc_out_notMNIST_LB_D, 'MNIST', test='notMNIST', method='LB_D')
print_out_dist_values(acc_out_KMNIST_LB_D, prob_correct_out_KMNIST_LB_D, ent_out_KMNIST_LB_D, MMC_out_KMNIST_LB_D, auroc_out_KMNIST_LB_D, 'MNIST', test='KMNIST', method='LB_D')

[In, LB_D, MNIST] Accuracy: 0.987; average entropy: 0.213;     MMC: 0.949; Prob @ correct: 0.100
[Out-fmnist, LB_D, MNIST] Accuracy: 0.086; Average entropy: 1.243;    MMC: 0.600; AUROC: 0.953; Prob @ correct: 0.100
[Out-notMNIST, LB_D, MNIST] Accuracy: 0.138; Average entropy: 1.459;    MMC: 0.515; AUROC: 0.954; Prob @ correct: 0.100
[Out-KMNIST, LB_D, MNIST] Accuracy: 0.098; Average entropy: 1.506;    MMC: 0.491; AUROC: 0.971; Prob @ correct: 0.100


# KFAC Laplace Bridge

In [46]:
MNIST_test_in_LB_KFAC = predict_LB(la_kron, MNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_FMNIST_LB_KFAC = predict_LB(la_kron, FMNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_notMNIST_LB_KFAC = predict_LB(la_kron, notMNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_KMNIST_LB_KFAC = predict_LB(la_kron, KMNIST_test_loader, timing=True, device=device).cpu().numpy()

time:  0.9228590539999999
time:  0.9338816940000001
time:  0.6778247060000027
time:  0.9153987969999946


In [47]:
# compute average log-likelihood for LB KFAC
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_in_LB_KFAC)).log_prob(torch.tensor(targets)).mean().item())
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_out_FMNIST_LB_KFAC)).log_prob(torch.tensor(targets_FMNIST)).mean().item())
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_out_notMNIST_LB_KFAC)).log_prob(torch.tensor(targets_notMNIST)).mean().item())
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_out_KMNIST_LB_KFAC)).log_prob(torch.tensor(targets_KMNIST)).mean().item())

0.03415154293179512
3.8378238677978516
6.014162063598633
6.023524284362793


In [48]:
print(scoring.expected_calibration_error(targets, MNIST_test_in_LB_KFAC))
print(scoring.expected_calibration_error(targets_FMNIST, MNIST_test_out_FMNIST_LB_KFAC))
print(scoring.expected_calibration_error(targets_notMNIST, MNIST_test_out_notMNIST_LB_KFAC))
print(scoring.expected_calibration_error(targets_KMNIST, MNIST_test_out_KMNIST_LB_KFAC))

0.011948898989899039
0.5424864242424243
0.6367586083004797
0.6168500707070707


In [73]:
## Brier score
print(get_brier(MNIST_test_in_LB_KFAC, targets, n_classes=10))
print(get_brier(MNIST_test_out_FMNIST_LB_KFAC, targets_FMNIST, n_classes=10))
print(get_brier(MNIST_test_out_notMNIST_LB_KFAC, targets_notMNIST, n_classes=10))
print(get_brier(MNIST_test_out_KMNIST_LB_KFAC, targets_KMNIST, n_classes=10))

0.001700085704214871
0.13170121610164642
0.14033333957195282
0.14197257161140442


In [49]:
acc_in_LB_KFAC, prob_correct_in_LB_KFAC, ent_in_LB_KFAC, MMC_in_LB_KFAC = get_in_dist_values(MNIST_test_in_LB_KFAC, targets)
acc_out_FMNIST_LB_KFAC, prob_correct_out_FMNIST_LB_KFAC, ent_out_FMNIST_LB_KFAC, MMC_out_FMNIST_LB_KFAC, auroc_out_FMNIST_LB_KFAC = get_out_dist_values(MNIST_test_in_LB_KFAC, MNIST_test_out_FMNIST_LB_KFAC, targets_FMNIST)
acc_out_notMNIST_LB_KFAC, prob_correct_out_notMNIST_LB_KFAC, ent_out_notMNIST_LB_KFAC, MMC_out_notMNIST_LB_KFAC, auroc_out_notMNIST_LB_KFAC = get_out_dist_values(MNIST_test_in_LB_KFAC, MNIST_test_out_notMNIST_LB_KFAC, targets_notMNIST)
acc_out_KMNIST_LB_KFAC, prob_correct_out_KMNIST_LB_KFAC, ent_out_KMNIST_LB_KFAC, MMC_out_KMNIST_LB_KFAC, auroc_out_KMNIST_LB_KFAC = get_out_dist_values(MNIST_test_in_LB_KFAC, MNIST_test_out_KMNIST_LB_KFAC, targets_KMNIST)

In [50]:
print_in_dist_values(acc_in_LB_KFAC, prob_correct_in_LB_KFAC, ent_in_LB_KFAC, MMC_in_LB_KFAC, 'MNIST', 'LB_KFAC')
print_out_dist_values(acc_out_FMNIST_LB_KFAC, prob_correct_out_FMNIST_LB_KFAC, ent_out_FMNIST_LB_KFAC, MMC_out_FMNIST_LB_KFAC, auroc_out_FMNIST_LB_KFAC, 'MNIST', test='fmnist', method='LB_KFAC')
print_out_dist_values(acc_out_notMNIST_LB_KFAC, prob_correct_out_notMNIST_LB_KFAC, ent_out_notMNIST_LB_KFAC, MMC_out_notMNIST_LB_KFAC, auroc_out_notMNIST_LB_KFAC, 'MNIST', test='notMNIST', method='LB_KFAC')
print_out_dist_values(acc_out_KMNIST_LB_KFAC, prob_correct_out_KMNIST_LB_KFAC, ent_out_KMNIST_LB_KFAC, MMC_out_KMNIST_LB_KFAC, auroc_out_KMNIST_LB_KFAC, 'MNIST', test='KMNIST', method='LB_KFAC')

[In, LB_KFAC, MNIST] Accuracy: 0.988; average entropy: 0.051;     MMC: 0.985; Prob @ correct: 0.100
[Out-fmnist, LB_KFAC, MNIST] Accuracy: 0.075; Average entropy: 1.202;    MMC: 0.617; AUROC: 0.978; Prob @ correct: 0.100
[Out-notMNIST, LB_KFAC, MNIST] Accuracy: 0.147; Average entropy: 0.685;    MMC: 0.757; AUROC: 0.912; Prob @ correct: 0.100
[Out-KMNIST, LB_KFAC, MNIST] Accuracy: 0.094; Average entropy: 0.806;    MMC: 0.711; AUROC: 0.957; Prob @ correct: 0.100


# LB kfac norm

In [51]:
MNIST_test_in_LB_KFACn = predict_LB_norm(la_kron, MNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_FMNIST_LB_KFACn = predict_LB_norm(la_kron, FMNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_notMNIST_LB_KFACn = predict_LB_norm(la_kron, notMNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_KMNIST_LB_KFACn = predict_LB_norm(la_kron, KMNIST_test_loader, timing=True, device=device).cpu().numpy()

time:  1.2208120379999983
time:  1.2360889580000034
time:  0.890107215999997
time:  1.3019489289999981


In [52]:
# compute average log-likelihood for LB KFAC
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_in_LB_KFACn)).log_prob(torch.tensor(targets)).mean().item())
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_out_FMNIST_LB_KFACn)).log_prob(torch.tensor(targets_FMNIST)).mean().item())
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_out_notMNIST_LB_KFACn)).log_prob(torch.tensor(targets_notMNIST)).mean().item())
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_out_KMNIST_LB_KFACn)).log_prob(torch.tensor(targets_KMNIST)).mean().item())

0.0359264612197876
3.2363293170928955
3.2359938621520996
3.617896795272827


In [53]:
print(scoring.expected_calibration_error(targets, MNIST_test_in_LB_KFACn))
print(scoring.expected_calibration_error(targets_FMNIST, MNIST_test_out_FMNIST_LB_KFACn))
print(scoring.expected_calibration_error(targets_notMNIST, MNIST_test_out_notMNIST_LB_KFACn))
print(scoring.expected_calibration_error(targets_KMNIST, MNIST_test_out_KMNIST_LB_KFACn))

0.012604787878787842
0.46441325252525256
0.3865528226076186
0.4096595454545455


In [74]:
## Brier score
print(get_brier(MNIST_test_in_LB_KFACn, targets, n_classes=10))
print(get_brier(MNIST_test_out_FMNIST_LB_KFACn, targets_FMNIST, n_classes=10))
print(get_brier(MNIST_test_out_notMNIST_LB_KFACn, targets_notMNIST, n_classes=10))
print(get_brier(MNIST_test_out_KMNIST_LB_KFACn, targets_KMNIST, n_classes=10))

0.00174651516135782
0.12144472450017929
0.11362239718437195
0.11815816164016724


In [54]:
acc_in_LB_KFAC, prob_correct_in_LB_KFAC, ent_in_LB_KFAC, MMC_in_LB_KFAC = get_in_dist_values(MNIST_test_in_LB_KFACn, targets)
acc_out_FMNIST_LB_KFAC, prob_correct_out_FMNIST_LB_KFAC, ent_out_FMNIST_LB_KFAC, MMC_out_FMNIST_LB_KFAC, auroc_out_FMNIST_LB_KFAC = get_out_dist_values(MNIST_test_in_LB_KFACn, MNIST_test_out_FMNIST_LB_KFACn, targets_FMNIST)
acc_out_notMNIST_LB_KFAC, prob_correct_out_notMNIST_LB_KFAC, ent_out_notMNIST_LB_KFAC, MMC_out_notMNIST_LB_KFAC, auroc_out_notMNIST_LB_KFAC = get_out_dist_values(MNIST_test_in_LB_KFACn, MNIST_test_out_notMNIST_LB_KFACn, targets_notMNIST)
acc_out_KMNIST_LB_KFAC, prob_correct_out_KMNIST_LB_KFAC, ent_out_KMNIST_LB_KFAC, MMC_out_KMNIST_LB_KFAC, auroc_out_KMNIST_LB_KFAC = get_out_dist_values(MNIST_test_in_LB_KFACn, MNIST_test_out_KMNIST_LB_KFACn, targets_KMNIST)

In [55]:
print_in_dist_values(acc_in_LB_KFAC, prob_correct_in_LB_KFAC, ent_in_LB_KFAC, MMC_in_LB_KFAC, 'MNIST', 'LB_KFAC')
print_out_dist_values(acc_out_FMNIST_LB_KFAC, prob_correct_out_FMNIST_LB_KFAC, ent_out_FMNIST_LB_KFAC, MMC_out_FMNIST_LB_KFAC, auroc_out_FMNIST_LB_KFAC, 'MNIST', test='fmnist', method='LB_KFAC')
print_out_dist_values(acc_out_notMNIST_LB_KFAC, prob_correct_out_notMNIST_LB_KFAC, ent_out_notMNIST_LB_KFAC, MMC_out_notMNIST_LB_KFAC, auroc_out_notMNIST_LB_KFAC, 'MNIST', test='notMNIST', method='LB_KFAC')
print_out_dist_values(acc_out_KMNIST_LB_KFAC, prob_correct_out_KMNIST_LB_KFAC, ent_out_KMNIST_LB_KFAC, MMC_out_KMNIST_LB_KFAC, auroc_out_KMNIST_LB_KFAC, 'MNIST', test='KMNIST', method='LB_KFAC')

[In, LB_KFAC, MNIST] Accuracy: 0.988; average entropy: 0.061;     MMC: 0.982; Prob @ correct: 0.100
[Out-fmnist, LB_KFAC, MNIST] Accuracy: 0.077; Average entropy: 1.469;    MMC: 0.541; AUROC: 0.986; Prob @ correct: 0.100
[Out-notMNIST, LB_KFAC, MNIST] Accuracy: 0.145; Average entropy: 1.476;    MMC: 0.508; AUROC: 0.981; Prob @ correct: 0.100
[Out-KMNIST, LB_KFAC, MNIST] Accuracy: 0.090; Average entropy: 1.477;    MMC: 0.500; AUROC: 0.988; Prob @ correct: 0.100


# Compare to probit (extended MacKay approach)

In [56]:
MNIST_test_in_PROBIT_D = predict_probit(la_diag, MNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_FMNIST_PROBIT_D = predict_probit(la_diag, FMNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_notMNIST_PROBIT_D = predict_probit(la_diag, notMNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_KMNIST_PROBIT_D = predict_probit(la_diag, KMNIST_test_loader, timing=True, device=device).cpu().numpy()

time:  1.1141732800000028
time:  1.1267198460000003
time:  0.3676808479999991
time:  1.0509229290000022


In [57]:
# compute average log-likelihood for LB KFAC
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_in_PROBIT_D)).log_prob(torch.tensor(targets)).mean().item())
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_out_FMNIST_PROBIT_D)).log_prob(torch.tensor(targets_FMNIST)).mean().item())
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_out_notMNIST_PROBIT_D)).log_prob(torch.tensor(targets_notMNIST)).mean().item())
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_out_KMNIST_PROBIT_D)).log_prob(torch.tensor(targets_KMNIST)).mean().item())

0.04973114654421806
3.832324266433716
3.872626781463623
4.313175201416016


In [58]:
print(scoring.expected_calibration_error(targets, MNIST_test_in_PROBIT_D))
print(scoring.expected_calibration_error(targets_FMNIST, MNIST_test_out_FMNIST_PROBIT_D))
print(scoring.expected_calibration_error(targets_notMNIST, MNIST_test_out_notMNIST_PROBIT_D))
print(scoring.expected_calibration_error(targets_KMNIST, MNIST_test_out_KMNIST_PROBIT_D))

0.023886616161616154
0.5246621717171718
0.4899368983576419
0.4909745050505051


In [59]:
acc_in_PROBIT, prob_correct_in_PROBIT, ent_in_PROBIT, MMC_in_PROBIT = get_in_dist_values(MNIST_test_in_PROBIT_D, targets)
acc_out_FMNIST_PROBIT, prob_correct_out_FMNIST_PROBIT, ent_out_FMNIST_PROBIT, MMC_out_FMNIST_PROBIT, auroc_out_FMNIST_PROBIT = get_out_dist_values(MNIST_test_in_PROBIT_D, MNIST_test_out_FMNIST_PROBIT_D, targets_FMNIST)
acc_out_notMNIST_PROBIT, prob_correct_out_notMNIST_PROBIT, ent_out_notMNIST_PROBIT, MMC_out_notMNIST_PROBIT, auroc_out_notMNIST_PROBIT = get_out_dist_values(MNIST_test_in_PROBIT_D, MNIST_test_out_notMNIST_PROBIT_D, targets_notMNIST)
acc_out_KMNIST_PROBIT, prob_correct_out_KMNIST_PROBIT, ent_out_KMNIST_PROBIT, MMC_out_KMNIST_PROBIT, auroc_out_KMNIST_PROBIT = get_out_dist_values(MNIST_test_in_PROBIT_D, MNIST_test_out_KMNIST_PROBIT_D, targets_KMNIST)

In [60]:
print_in_dist_values(acc_in_PROBIT, prob_correct_in_PROBIT, ent_in_PROBIT, MMC_in_PROBIT, 'MNIST', 'PROBIT')
print_out_dist_values(acc_out_FMNIST_PROBIT, prob_correct_out_FMNIST_PROBIT, ent_out_FMNIST_PROBIT, MMC_out_FMNIST_PROBIT, auroc_out_FMNIST_PROBIT, 'MNIST', test='fmnist', method='PROBIT')
print_out_dist_values(acc_out_notMNIST_PROBIT, prob_correct_out_notMNIST_PROBIT, ent_out_notMNIST_PROBIT, MMC_out_notMNIST_PROBIT, auroc_out_notMNIST_PROBIT, 'MNIST', test='notMNIST', method='PROBIT')
print_out_dist_values(acc_out_KMNIST_PROBIT, prob_correct_out_KMNIST_PROBIT, ent_out_KMNIST_PROBIT, MMC_out_KMNIST_PROBIT, auroc_out_KMNIST_PROBIT, 'MNIST', test='KMNIST', method='PROBIT')

[In, PROBIT, MNIST] Accuracy: 0.987; average entropy: 0.126;     MMC: 0.967; Prob @ correct: 0.100
[Out-fmnist, PROBIT, MNIST] Accuracy: 0.073; Average entropy: 1.214;    MMC: 0.598; AUROC: 0.971; Prob @ correct: 0.100
[Out-notMNIST, PROBIT, MNIST] Accuracy: 0.141; Average entropy: 1.116;    MMC: 0.620; AUROC: 0.957; Prob @ correct: 0.100
[Out-KMNIST, PROBIT, MNIST] Accuracy: 0.092; Average entropy: 1.192;    MMC: 0.583; AUROC: 0.969; Prob @ correct: 0.100


### PROBIT KFAC

In [61]:
MNIST_test_in_PROBIT_K = predict_probit(la_kron, MNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_FMNIST_PROBIT_K = predict_probit(la_kron, FMNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_notMNIST_PROBIT_K = predict_probit(la_kron, notMNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_KMNIST_PROBIT_K = predict_probit(la_kron, KMNIST_test_loader, timing=True, device=device).cpu().numpy()

time:  1.175833780000005
time:  1.131725528000004
time:  0.6892226630000025
time:  1.0696401479999977


In [62]:
# compute average log-likelihood for LB KFAC
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_in_PROBIT_K)).log_prob(torch.tensor(targets)).mean().item())
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_out_FMNIST_PROBIT_K)).log_prob(torch.tensor(targets_FMNIST)).mean().item())
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_out_notMNIST_PROBIT_K)).log_prob(torch.tensor(targets_notMNIST)).mean().item())
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_out_KMNIST_PROBIT_K)).log_prob(torch.tensor(targets_KMNIST)).mean().item())

0.4061307907104492
2.7053489685058594
2.7153563499450684
2.879185199737549


In [63]:
print(scoring.expected_calibration_error(targets, MNIST_test_in_PROBIT_K))
print(scoring.expected_calibration_error(targets_FMNIST, MNIST_test_out_FMNIST_PROBIT_K))
print(scoring.expected_calibration_error(targets_notMNIST, MNIST_test_out_notMNIST_PROBIT_K))
print(scoring.expected_calibration_error(targets_KMNIST, MNIST_test_out_KMNIST_PROBIT_K))

0.3033561919191919
0.27609011111111115
0.24635375329885054
0.23612531313131316


In [64]:
acc_in_PROBIT, prob_correct_in_PROBIT, ent_in_PROBIT, MMC_in_PROBIT = get_in_dist_values(MNIST_test_in_PROBIT_K, targets)
acc_out_FMNIST_PROBIT, prob_correct_out_FMNIST_PROBIT, ent_out_FMNIST_PROBIT, MMC_out_FMNIST_PROBIT, auroc_out_FMNIST_PROBIT = get_out_dist_values(MNIST_test_in_PROBIT_K, MNIST_test_out_FMNIST_PROBIT_K, targets_FMNIST)
acc_out_notMNIST_PROBIT, prob_correct_out_notMNIST_PROBIT, ent_out_notMNIST_PROBIT, MMC_out_notMNIST_PROBIT, auroc_out_notMNIST_PROBIT = get_out_dist_values(MNIST_test_in_PROBIT_K, MNIST_test_out_notMNIST_PROBIT_K, targets_notMNIST)
acc_out_KMNIST_PROBIT, prob_correct_out_KMNIST_PROBIT, ent_out_KMNIST_PROBIT, MMC_out_KMNIST_PROBIT, auroc_out_KMNIST_PROBIT = get_out_dist_values(MNIST_test_in_PROBIT_K, MNIST_test_out_KMNIST_PROBIT_K, targets_KMNIST)

In [65]:
print_in_dist_values(acc_in_PROBIT, prob_correct_in_PROBIT, ent_in_PROBIT, MMC_in_PROBIT, 'MNIST', 'PROBIT')
print_out_dist_values(acc_out_FMNIST_PROBIT, prob_correct_out_FMNIST_PROBIT, ent_out_FMNIST_PROBIT, MMC_out_FMNIST_PROBIT, auroc_out_FMNIST_PROBIT, 'MNIST', test='fmnist', method='PROBIT')
print_out_dist_values(acc_out_notMNIST_PROBIT, prob_correct_out_notMNIST_PROBIT, ent_out_notMNIST_PROBIT, MMC_out_notMNIST_PROBIT, auroc_out_notMNIST_PROBIT, 'MNIST', test='notMNIST', method='PROBIT')
print_out_dist_values(acc_out_KMNIST_PROBIT, prob_correct_out_KMNIST_PROBIT, ent_out_KMNIST_PROBIT, MMC_out_KMNIST_PROBIT, auroc_out_KMNIST_PROBIT, 'MNIST', test='KMNIST', method='PROBIT')

[In, PROBIT, MNIST] Accuracy: 0.988; average entropy: 1.155;     MMC: 0.685; Prob @ correct: 0.100
[Out-fmnist, PROBIT, MNIST] Accuracy: 0.073; Average entropy: 1.890;    MMC: 0.349; AUROC: 0.962; Prob @ correct: 0.100
[Out-notMNIST, PROBIT, MNIST] Accuracy: 0.148; Average entropy: 1.805;    MMC: 0.376; AUROC: 0.929; Prob @ correct: 0.100
[Out-KMNIST, PROBIT, MNIST] Accuracy: 0.093; Average entropy: 1.891;    MMC: 0.328; AUROC: 0.966; Prob @ correct: 0.100


In [66]:
break

SyntaxError: 'break' outside loop (668683560.py, line 1)

# Compare to Second-order Delta Posterior Predictive

this kinda sucks. So we'll just leave it alone.

In [None]:
MNIST_test_in_SODPP_D = predict_second_order_dpp(la_diag, MNIST_test_loader, timing=True, device=device).cpu().nan_to_num(nan=0.0, neginf=0.0, posinf=1.0).numpy()
MNIST_test_out_FMNIST_SODPP_D = predict_second_order_dpp(la_diag, FMNIST_test_loader, timing=True, device=device).cpu().nan_to_num(nan=0.0, neginf=0.0, posinf=1.0).numpy()
MNIST_test_out_notMNIST_SODPP_D = predict_second_order_dpp(la_diag, notMNIST_test_loader, timing=True, device=device).cpu().nan_to_num(nan=0.0, neginf=0.0, posinf=1.0).numpy()
MNIST_test_out_KMNIST_SODPP_D = predict_second_order_dpp(la_diag, KMNIST_test_loader, timing=True, device=device).cpu().nan_to_num(nan=0.0, neginf=0.0, posinf=1.0).numpy()

In [None]:
# compute average log-likelihood for LB KFAC
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_in_SODPP_D)).log_prob(torch.tensor(targets)).mean().item())
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_out_FMNIST_SODPP_D)).log_prob(torch.tensor(targets_FMNIST)).mean().item())
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_out_notMNIST_SODPP_D)).log_prob(torch.tensor(targets_notMNIST)).mean().item())
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_out_KMNIST_SODPP_D)).log_prob(torch.tensor(targets_KMNIST)).mean().item())

In [None]:
print(scoring.expected_calibration_error(targets, MNIST_test_in_SODPP_D))
print(scoring.expected_calibration_error(targets_FMNIST, MNIST_test_out_FMNIST_SODPP_D))
print(scoring.expected_calibration_error(targets_notMNIST, MNIST_test_out_notMNIST_SODPP_D))
print(scoring.expected_calibration_error(targets_KMNIST, MNIST_test_out_KMNIST_SODPP_D))

In [None]:
acc_in_SODPP, prob_correct_in_SODPP, ent_in_SODPP, MMC_in_SODPP = get_in_dist_values(MNIST_test_in_SODPP_D, targets)
acc_out_FMNIST_SODPP, prob_correct_out_FMNIST_SODPP, ent_out_FMNIST_SODPP, MMC_out_FMNIST_SODPP, auroc_out_FMNIST_SODPP = get_out_dist_values(MNIST_test_in_SODPP_D, MNIST_test_out_FMNIST_SODPP_D, targets_FMNIST)
acc_out_notMNIST_SODPP, prob_correct_out_notMNIST_SODPP, ent_out_notMNIST_SODPP, MMC_out_notMNIST_SODPP, auroc_out_notMNIST_SODPP = get_out_dist_values(MNIST_test_in_SODPP_D, MNIST_test_out_notMNIST_SODPP_D, targets_notMNIST)
acc_out_KMNIST_SODPP, prob_correct_out_KMNIST_SODPP, ent_out_KMNIST_SODPP, MMC_out_KMNIST_SODPP, auroc_out_KMNIST_SODPP = get_out_dist_values(MNIST_test_in_SODPP_D, MNIST_test_out_KMNIST_SODPP_D, targets_KMNIST)

In [None]:
print_in_dist_values(acc_in_SODPP, prob_correct_in_SODPP, ent_in_SODPP, MMC_in_SODPP, 'MNIST', 'SODPP')
print_out_dist_values(acc_out_FMNIST_SODPP, prob_correct_out_FMNIST_SODPP, ent_out_FMNIST_SODPP, MMC_out_FMNIST_SODPP, auroc_out_FMNIST_SODPP, 'MNIST', test='fmnist', method='SODPP')
print_out_dist_values(acc_out_notMNIST_SODPP, prob_correct_out_notMNIST_SODPP, ent_out_notMNIST_SODPP, MMC_out_notMNIST_SODPP, auroc_out_notMNIST_SODPP, 'MNIST', test='notMNIST', method='SODPP')
print_out_dist_values(acc_out_KMNIST_SODPP, prob_correct_out_KMNIST_SODPP, ent_out_KMNIST_SODPP, MMC_out_KMNIST_SODPP, auroc_out_KMNIST_SODPP, 'MNIST', test='KMNIST', method='SODPP')

### KFAC SODPP

In [None]:
MNIST_test_in_SODPP_K = predict_second_order_dpp(la_kron, MNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_FMNIST_SODPP_K = predict_second_order_dpp(la_kron, FMNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_notMNIST_SODPP_K = predict_second_order_dpp(la_kron, notMNIST_test_loader, timing=True, device=device).cpu().numpy()
MNIST_test_out_KMNIST_SODPP_K = predict_second_order_dpp(la_kron, KMNIST_test_loader, timing=True, device=device).cpu().numpy()

In [None]:
# compute average log-likelihood for LB KFAC
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_in_SODPP_K)).log_prob(torch.tensor(targets)).mean().item())
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_out_FMNIST_SODPP_K)).log_prob(torch.tensor(targets_FMNIST)).mean().item())
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_out_notMNIST_SODPP_K)).log_prob(torch.tensor(targets_notMNIST)).mean().item())
print(-torch.distributions.Categorical(torch.tensor(MNIST_test_out_KMNIST_SODPP_K)).log_prob(torch.tensor(targets_KMNIST)).mean().item())

In [None]:
print(scoring.expected_calibration_error(targets, MNIST_test_in_SODPP_K))
print(scoring.expected_calibration_error(targets_FMNIST, MNIST_test_out_FMNIST_SODPP_K))
print(scoring.expected_calibration_error(targets_notMNIST, MNIST_test_out_notMNIST_SODPP_K))
print(scoring.expected_calibration_error(targets_KMNIST, MNIST_test_out_KMNIST_SODPP_K))

In [None]:
acc_in_SODPP, prob_correct_in_SODPP, ent_in_SODPP, MMC_in_SODPP = get_in_dist_values(MNIST_test_in_SODPP_K, targets)
acc_out_FMNIST_SODPP, prob_correct_out_FMNIST_SODPP, ent_out_FMNIST_SODPP, MMC_out_FMNIST_SODPP, auroc_out_FMNIST_SODPP = get_out_dist_values(MNIST_test_in_SODPP_K, MNIST_test_out_FMNIST_SODPP_K, targets_FMNIST)
acc_out_notMNIST_SODPP, prob_correct_out_notMNIST_SODPP, ent_out_notMNIST_SODPP, MMC_out_notMNIST_SODPP, auroc_out_notMNIST_SODPP = get_out_dist_values(MNIST_test_in_SODPP_K, MNIST_test_out_notMNIST_SODPP_K, targets_notMNIST)
acc_out_KMNIST_SODPP, prob_correct_out_KMNIST_SODPP, ent_out_KMNIST_SODPP, MMC_out_KMNIST_SODPP, auroc_out_KMNIST_SODPP = get_out_dist_values(MNIST_test_in_SODPP_K, MNIST_test_out_KMNIST_SODPP_K, targets_KMNIST)

In [None]:
print_in_dist_values(acc_in_SODPP, prob_correct_in_SODPP, ent_in_SODPP, MMC_in_SODPP, 'MNIST', 'SODPP')
print_out_dist_values(acc_out_FMNIST_SODPP, prob_correct_out_FMNIST_SODPP, ent_out_FMNIST_SODPP, MMC_out_FMNIST_SODPP, auroc_out_FMNIST_SODPP, 'MNIST', test='fmnist', method='SODPP')
print_out_dist_values(acc_out_notMNIST_SODPP, prob_correct_out_notMNIST_SODPP, ent_out_notMNIST_SODPP, MMC_out_notMNIST_SODPP, auroc_out_notMNIST_SODPP, 'MNIST', test='notMNIST', method='SODPP')
print_out_dist_values(acc_out_KMNIST_SODPP, prob_correct_out_KMNIST_SODPP, ent_out_KMNIST_SODPP, MMC_out_KMNIST_SODPP, auroc_out_KMNIST_SODPP, 'MNIST', test='KMNIST', method='SODPP')