In [1]:
import torch
import torchvision
from torch import nn, optim, autograd
from torch.nn import functional as F
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.autograd import Variable
import numpy as np
#import input_data
from sklearn.utils import shuffle as skshuffle
from math import *
from backpack import backpack, extend
from backpack.extensions import KFAC, DiagHessian, DiagGGNMC
from sklearn.metrics import roc_auc_score
import scipy
from tqdm import tqdm, trange
from bpjacext import NetJac
import pytest
from DirLPA_utils import * 
import time

import matplotlib.pyplot as plt

# Seed NumPy and PyTorch (CPU and CUDA) so repeated runs draw the same random numbers.
s = 123
np.random.seed(s)
torch.manual_seed(s)
torch.cuda.manual_seed(s)
# Ask cuDNN for deterministic kernels and switch off its benchmark-based algorithm selection.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
# NOTE: DO NOT RUN THIS CODE. The function NetJac comes from a private repository that is not yet available to you.
# If you search for it aggressively, you might find out my identity. I would therefore prefer that you just look at the provided results.

# Also, PyTorch has since been updated and is now incompatible with NetJac. The results of the experiments can,
# however, still be found in the code.

# Since our package is no longer compatible with PyTorch, the experiments with the other integral approximations
# were conducted with a last-layer Laplace approximation of the network in the other Jupyter notebook.

In [2]:
# Probe CUDA once and derive the device string from that single answer.
cuda_status = torch.cuda.is_available()
if cuda_status:
    device = 'cuda'
else:
    device = 'cpu'
print("device: ", device)
print("cuda status: ", cuda_status)

device:  cuda
cuda status:  True


In [3]:
def LPADirNN(num_classes=10, num_LL=256):
    """Build the small convolutional classifier used throughout this notebook.

    The network takes single-channel images, applies two 5x5 convolutions
    (each followed by ReLU and 2x2 max-pooling), flattens the result and
    passes it through two linear layers with no activation in between.

    Args:
        num_classes: width of the final linear layer (number of logits).
        num_LL: width of the second-to-last linear layer.

    Returns:
        A ``torch.nn.Sequential``; the layer order (and therefore the
        state-dict keys ``0.weight`` ... ``8.bias``) is fixed so that saved
        weights keep loading.
    """
    conv_stage = [
        torch.nn.Conv2d(1, 32, 5),
        torch.nn.ReLU(),
        torch.nn.MaxPool2d(2, 2),
        torch.nn.Conv2d(32, 64, 5),
        torch.nn.ReLU(),
        torch.nn.MaxPool2d(2, 2),
    ]
    # The first linear layer expects 4 * 4 * 64 flattened features.
    dense_stage = [
        torch.nn.Flatten(),
        torch.nn.Linear(4 * 4 * 64, num_LL),  # changed from 500
        torch.nn.Linear(num_LL, num_classes),  # changed from 500
    ]
    return torch.nn.Sequential(*conv_stage, *dense_stage)

In [4]:
# Mini-batch sizes for the MNIST train/test loaders (the trailing comments list previously tried values).
BATCH_SIZE_TRAIN_MNIST = 32#64#128
BATCH_SIZE_TEST_MNIST = 32#64#128
# Number of passes over the training set; passed as max_iter in the (commented-out) train(...) call.
MAX_ITER_MNIST = 6
# NOTE(review): 10e-6 equals 1e-5, and the Adam optimizer in this notebook is built with a literal lr=1e-3,
# so this constant looks unused in the visible code -- confirm before relying on it.
LR_TRAIN_MNIST = 10e-6

In [5]:
# ToTensor is the only transform; it is reused for every data set in this notebook.
MNIST_transform = torchvision.transforms.ToTensor()

# Training split; download=True fetches the data into ~/data/mnist if it is missing.
MNIST_train = torchvision.datasets.MNIST(
        '~/data/mnist',
        train=True,
        download=True,
        transform=MNIST_transform)

# Shuffled mini-batches for training and for the Hessian estimate.
mnist_train_loader = torch.utils.data.dataloader.DataLoader(
    MNIST_train,
    batch_size=BATCH_SIZE_TRAIN_MNIST,
    shuffle=True
)


# Test split; download=False relies on the files already being on disk (same root as above).
MNIST_test = torchvision.datasets.MNIST(
        '~/data/mnist',
        train=False,
        download=False,
        transform=MNIST_transform)

# Fixed order (shuffle=False) so predictions stay aligned with MNIST_test.targets.
mnist_test_loader = torch.utils.data.dataloader.DataLoader(
    MNIST_test,
    batch_size=BATCH_SIZE_TEST_MNIST,
    shuffle=False,
)

In [6]:
# Build the network on the GPU together with the loss that train() reads as a global.
mnist_model = LPADirNN(num_LL=256).cuda()
loss_function = torch.nn.CrossEntropyLoss()

# Adam with weight decay; the learning rate is given literally here rather than via LR_TRAIN_MNIST.
mnist_train_optimizer = torch.optim.Adam(mnist_model.parameters(), lr=1e-3, weight_decay=5e-4)
# File the trained weights are saved to / loaded from.
MNIST_PATH = "pretrained_weights/MNIST_pretrained_10_classes.pth"

In [7]:
#Training routine

def train(model, train_loader, optimizer, max_iter, path, verbose=True):
    """Train ``model`` with the module-level ``loss_function`` and save its weights.

    Args:
        model: network to optimise; batches are moved to the device its
            parameters live on.
        train_loader: iterable of ``(inputs, targets)`` mini-batches.
        optimizer: optimiser already bound to ``model``'s parameters.
        max_iter: number of full passes over ``train_loader``.
        path: file the final ``state_dict`` is written to.
        verbose: if True, print loss and accuracy every 50 mini-batches.
    """
    max_len = len(train_loader)
    # Follow the model's own device so inputs, targets and logits always agree
    # (previously the targets stayed on the CPU while the logits were on the GPU).
    device = next(model.parameters()).device

    for epoch in range(max_iter):
        for batch_idx, (x, y) in enumerate(train_loader):

            x, y = x.to(device), y.to(device)

            output = model(x)

            accuracy = get_accuracy(output, y)

            loss = loss_function(output, y)
            # Clear gradients first so nothing left over from before this call leaks into the step.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if verbose and batch_idx % 50 == 0:
                print(
                    "Iteration {}; {}/{} \t".format(epoch, batch_idx, max_len) +
                    "Minibatch Loss %.3f  " % (loss) +
                    "Accuracy %.0f" % (accuracy * 100) + "%"
                )

    print("saving model at: {}".format(path))
    # Save the model that was actually trained, not whatever the global `mnist_model` points to.
    torch.save(model.state_dict(), path)

In [8]:
#train(mnist_model, mnist_train_loader, mnist_train_optimizer, MAX_ITER_MNIST, MNIST_PATH, verbose=True)

In [9]:
#predict in distribution
MNIST_PATH = "pretrained_weights/MNIST_pretrained_10_classes.pth"

# Rebuild the architecture and restore the saved weights before evaluating.
mnist_model = LPADirNN(num_LL=256).cuda()
print("loading model from: {}".format(MNIST_PATH))
mnist_model.load_state_dict(torch.load(MNIST_PATH))
mnist_model.eval()

acc = []                # per-batch accuracies, kept for inspection
correct_weighted = 0.0  # sum of (batch accuracy * batch size)
num_seen = 0            # number of evaluated samples

max_len = len(mnist_test_loader)
# Evaluation needs no gradients, so avoid building the autograd graph.
with torch.no_grad():
    for batch_idx, (x, y) in enumerate(mnist_test_loader):

        x, y = x.cuda(), y.cuda()
        output = mnist_model(x)

        accuracy = get_accuracy(output, y)
        if batch_idx % 10 == 0:
            print(
                "Batch {}/{} \t".format(batch_idx, max_len) +
                "Accuracy %.0f" % (accuracy * 100) + "%"
            )
        acc.append(accuracy)
        # Weight by batch size: the last batch can be smaller than the others, and a
        # plain mean over batches would give its samples too much influence.
        # assumes get_accuracy returns a scalar fraction (it is formatted with %.0f above)
        correct_weighted += float(accuracy) * x.size(0)
        num_seen += x.size(0)

avg_acc = correct_weighted / num_seen if num_seen else float('nan')
print('overall test accuracy on MNIST: {:.02f} %'.format(avg_acc * 100))


loading model from: pretrained_weights/MNIST_pretrained_10_classes.pth
Batch 0/313 	Accuracy 100%
Batch 10/313 	Accuracy 97%
Batch 20/313 	Accuracy 94%
Batch 30/313 	Accuracy 97%
Batch 40/313 	Accuracy 97%
Batch 50/313 	Accuracy 97%
Batch 60/313 	Accuracy 100%
Batch 70/313 	Accuracy 100%
Batch 80/313 	Accuracy 100%
Batch 90/313 	Accuracy 97%
Batch 100/313 	Accuracy 100%
Batch 110/313 	Accuracy 94%
Batch 120/313 	Accuracy 97%
Batch 130/313 	Accuracy 97%
Batch 140/313 	Accuracy 100%
Batch 150/313 	Accuracy 97%
Batch 160/313 	Accuracy 100%
Batch 170/313 	Accuracy 100%
Batch 180/313 	Accuracy 100%
Batch 190/313 	Accuracy 97%
Batch 200/313 	Accuracy 100%
Batch 210/313 	Accuracy 100%
Batch 220/313 	Accuracy 100%
Batch 230/313 	Accuracy 100%
Batch 240/313 	Accuracy 100%
Batch 250/313 	Accuracy 100%
Batch 260/313 	Accuracy 100%
Batch 270/313 	Accuracy 100%
Batch 280/313 	Accuracy 100%
Batch 290/313 	Accuracy 100%
Batch 300/313 	Accuracy 100%
Batch 310/313 	Accuracy 100%
overall test accuracy o

In [10]:
## play around with Backpack
def get_Hessian_NN(model, train_loader, var0, device='cpu', verbose=True):
    """Estimate the diagonal Hessian of the cross-entropy loss for all parameters.

    For every mini-batch the per-parameter Hessian diagonal is read from
    ``param.diag_h`` (filled in by BackPACK's ``DiagHessian`` extension during
    ``backward``), the prior precision ``1/var0`` is added, and the result is
    merged into a running average over batches: an exact running mean for the
    first 200 batches, an exponential moving average with decay 0.995 after.

    Args:
        model: network whose parameters are analysed; it is extended by
            BackPACK in place.
        train_loader: iterable of ``(inputs, targets)`` mini-batches.
        var0: prior variance; ``1/var0`` is added to every diagonal entry.
        device: device for the accumulators and the mini-batches.
        verbose: if True, print a progress line per mini-batch.

    Returns:
        1-D tensor with one entry per scalar parameter, concatenated in
        ``model.parameters()`` order.
    """
    lossfunc = torch.nn.CrossEntropyLoss()

    extend(lossfunc, debug=False)
    extend(model, debug=False)

    # One zero accumulator per parameter tensor of the model that was passed in
    # (not of the global `mnist_model`).
    Hessian_diag = []
    for param in model.parameters():
        ps = param.size()
        print("parameter size: ", ps)
        Hessian_diag.append(torch.zeros(ps, device=device))

    tau = 1/var0
    max_len = len(train_loader)

    with backpack(DiagHessian()):

        for batch_idx, (x, y) in enumerate(train_loader):

            x, y = x.to(device), y.to(device)

            model.zero_grad()
            loss = lossfunc(model(x), y)
            loss.backward()

            with torch.no_grad():
                # The averaging weight depends only on the batch index.
                rho = min(1-1/(batch_idx+1), 0.995)

                for idx, param in enumerate(model.parameters()):
                    # Out-of-place add, so BackPACK's `diag_h` buffer is left untouched.
                    H_ = param.diag_h + tau

                    Hessian_diag[idx] = rho*Hessian_diag[idx] + (1-rho)*H_

            if verbose:
                print("Batch: {}/{}".format(batch_idx, max_len))

    #combine all elements of the Hessian to one big vector
    Hessian_diag = torch.cat([el.view(-1) for el in Hessian_diag])
    print("Hessian_size: ", Hessian_diag.size())
    num_params = sum(p.numel() for p in model.parameters())
    assert(num_params == Hessian_diag.size(-1))
    return(Hessian_diag)
        

In [11]:
# Diagonal Hessian over the MNIST training set; var0=200 means a prior precision of 1/200 is added to every entry.
Hessian_MNIST = get_Hessian_NN(model=mnist_model, train_loader=mnist_train_loader, var0=200, verbose=False, device='cuda')

parameter size:  torch.Size([32, 1, 5, 5])
parameter size:  torch.Size([32])
parameter size:  torch.Size([64, 32, 5, 5])
parameter size:  torch.Size([64])
parameter size:  torch.Size([256, 1024])
parameter size:  torch.Size([256])
parameter size:  torch.Size([10, 256])
parameter size:  torch.Size([10])
Hessian_size:  torch.Size([317066])


In [12]:
# Quick look at the estimate (each entry already includes the prior precision 1/var0).
print(Hessian_MNIST)

tensor([0.0056, 0.0059, 0.0067,  ..., 0.0068, 0.0074, 0.0070], device='cuda:0')


In [13]:
## play around with Backpack
def get_Hessian_KFAC_NN(model, train_loader, var0, device='cpu', verbose=True):
    """Accumulate damped KFAC factors of the cross-entropy loss for every parameter.

    For each mini-batch the Kronecker factors are read from ``param.kfac``
    (filled in by BackPACK's ``KFAC`` extension during ``backward``), damped
    with the prior precision ``tau = 1/var0`` and merged into a running
    average over batches (exact running mean for the first 20 batches,
    exponential moving average with decay 0.95 afterwards):

    * two factors ``(U, V)``: ``sqrt(B) * F + sqrt(tau) * I`` for each factor,
    * a single factor:        ``B * F + tau * I``,

    where ``B`` is the size of the current mini-batch.

    Args:
        model: network whose parameters are analysed; it is extended by
            BackPACK in place.
        train_loader: iterable of ``(inputs, targets)`` mini-batches.
        var0: prior variance.
        device: device the mini-batches are moved to.
        verbose: if True, print a progress line per mini-batch.

    Returns:
        ``(Hessian_KFAC_U, Hessian_KFAC_V)``: two lists aligned with
        ``model.parameters()``. For a parameter with a single factor the
        ``V`` entry is ``None``. Entries stay ``None`` if ``train_loader``
        yields no batches.

    Raises:
        ValueError: if a parameter reports neither one nor two factors.
    """
    lossfunc = torch.nn.CrossEntropyLoss()

    extend(lossfunc, debug=False)
    extend(model, debug=False)

    params = list(model.parameters())
    for param in params:
        print("parameter size: ", param.size())

    # Factor shapes differ from the parameter shapes and are only known after the
    # first backward pass, so the accumulators are created lazily.
    Hessian_KFAC_U = [None] * len(params)
    Hessian_KFAC_V = [None] * len(params)

    tau = 1/var0
    max_len = len(train_loader)

    def _damp(factor, scale, prior):
        # scale * factor + prior * identity, on the factor's own device and dtype
        eye = torch.eye(factor.size(0), device=factor.device, dtype=factor.dtype)
        return scale * factor + prior * eye

    def _average(old, new, rho):
        # running average; the first batch simply initialises the accumulator
        return new if old is None else rho * old + (1 - rho) * new

    with backpack(KFAC()):

        for batch_idx, (x, y) in enumerate(train_loader):

            x, y = x.to(device), y.to(device)
            batch_size = x.size(0)

            model.zero_grad()
            loss = lossfunc(model(x), y)
            loss.backward()

            with torch.no_grad():
                rho = min(1-1/(batch_idx+1), 0.95)

                for idx, param in enumerate(params):
                    factors = param.kfac

                    # assumes BackPACK reports two factors for weights and one for
                    # biases -- TODO confirm against the installed BackPACK version
                    if len(factors) == 2:
                        U_ = _damp(factors[0], batch_size ** 0.5, tau ** 0.5)
                        V_ = _damp(factors[1], batch_size ** 0.5, tau ** 0.5)
                        Hessian_KFAC_V[idx] = _average(Hessian_KFAC_V[idx], V_, rho)
                    elif len(factors) == 1:
                        U_ = _damp(factors[0], batch_size, tau)
                    else:
                        raise ValueError(
                            "expected 1 or 2 KFAC factors, got {}".format(len(factors)))

                    Hessian_KFAC_U[idx] = _average(Hessian_KFAC_U[idx], U_, rho)

            if verbose:
                print("Batch: {}/{}".format(batch_idx, max_len))

    return Hessian_KFAC_U, Hessian_KFAC_V
        

In [28]:
def compute_jacobians_with_backpack(model, x, y, lossfunc):
    """
    Collect the per-parameter network Jacobians stored by the NetJac extension.

    The loss is evaluated on (x, y), printed, and back-propagated inside a
    ``backpack(NetJac())`` context; afterwards each parameter's ``netjacs``
    attribute is read.

    Returns a list with one detached tensor per entry of model.parameters(),
    in the same order. Per the original notes each tensor has the form
    [N, *, C]: N is the batch dimension, * the shape of the parameter and
    C the number of network outputs (classes).
    """
    batch_loss = lossfunc(model(x), y)
    print(batch_loss)

    with backpack(NetJac()):
        batch_loss.backward()

    return [param.netjacs.data.detach() for param in model.parameters()]

def transform2full_jac(backpack_jacobian):
    """
    Stack per-parameter Jacobians into one Jacobian per sample.

    Every tensor in ``backpack_jacobian`` is viewed as [N, P_i, C] (N and C are
    taken from the first tensor's first and last dimension), transposed to
    [N, C, P_i] and concatenated along the last axis.

    Returns a tensor of shape [N, C, sum_i P_i].
    """
    reference = backpack_jacobian[0]
    n_batch, n_out = reference.size(0), reference.size(-1)
    per_param = [
        jac.view(n_batch, -1, n_out).transpose(1, 2)
        for jac in backpack_jacobian
    ]
    return torch.cat(per_param, dim=-1)

def get_Jacobian(model, x, y, lossfunc):
    """Compute the per-parameter Jacobians for (x, y) and merge them into one [N, C, P] tensor."""
    per_param_jacobians = compute_jacobians_with_backpack(model, x, y, lossfunc)
    return transform2full_jac(per_param_jacobians)

In [29]:
def predict_Diagonal_full(model, test_loader, Hessian, verbose=True, num_samples=100, cuda=False, timing=False):
    """Monte-Carlo softmax predictions from a Gaussian over the network outputs.

    For each batch a Gaussian over the logits is built with the network output
    as mean and ``J diag(Hessian) J^T`` as covariance, where ``J`` is the
    Jacobian returned by ``get_Jacobian``. ``num_samples`` logit vectors are
    drawn and their softmax outputs averaged.

    Args:
        model: the network to evaluate.
        test_loader: iterable of ``(inputs, targets)`` mini-batches.
        Hessian: 1-D tensor with one entry per parameter; multiplied
            elementwise onto the Jacobian.
        verbose: if True, print tensor sizes and a progress line per batch.
        num_samples: number of Monte-Carlo samples per batch.
        cuda: if True, move every batch to the GPU first.
        timing: if True, print the time spent in the sampling part.

    Returns:
        Tensor of averaged class probabilities, all batches concatenated
        along dimension 0.
    """
    lossfunc = torch.nn.CrossEntropyLoss()
    extend(lossfunc, debug=False)

    batch_probs = []
    time_sum = 0  # only reported when `timing` is set
    max_len = len(test_loader)

    for batch_idx, (x, y) in enumerate(test_loader):

        if cuda:
            x, y = x.cuda(), y.cuda()

        J = get_Jacobian(model, x, y, lossfunc).detach()
        # NOTE(review): `Hessian` is used directly as the diagonal parameter covariance.
        # A Laplace posterior covariance is normally the inverse of the Hessian/precision;
        # confirm that callers pass the quantity that is intended here.
        Cov_pred = torch.bmm(J * Hessian, J.permute(0, 2, 1)).detach()
        if verbose:
            print("Jacobian size: ", J.size())
            print("cov pred size: ", Cov_pred.size())

        mu_pred = model(x).detach()
        post_pred = MultivariateNormal(mu_pred, Cov_pred)

        # MC-integral
        tic = time.time()
        py_ = 0
        for _ in range(num_samples):
            py_ += torch.softmax(post_pred.rsample(), 1)
        py_ = (py_ / num_samples).detach()

        batch_probs.append(py_)
        toc = time.time()
        if timing:
            time_sum += (toc - tic)

        if verbose:
            print("Batch: {}/{}".format(batch_idx, max_len))

    if timing:
        print("total time used for transform: {:.05f}".format(time_sum))

    return torch.cat(batch_probs, dim=0)

In [30]:
# Mini-batch sizes for the out-of-distribution test loaders (trailing comments list previously tried values).
BATCH_SIZE_TEST_FMNIST = 32#64#128
BATCH_SIZE_TEST_KMNIST = 32#64#128

In [31]:
# Fashion-MNIST test split (downloaded if missing), with the same ToTensor transform as MNIST.
FMNIST_test = torchvision.datasets.FashionMNIST(
        '~/data/fmnist', train=False, download=True,
        transform=MNIST_transform)   #torchvision.transforms.ToTensor())

# Fixed order (shuffle=False) so predictions stay aligned with FMNIST_test.targets.
FMNIST_test_loader = torch.utils.data.DataLoader(
    FMNIST_test,
    batch_size=BATCH_SIZE_TEST_FMNIST, shuffle=False)

In [32]:
# KMNIST test split (downloaded if missing), with the same ToTensor transform as MNIST.
KMNIST_test = torchvision.datasets.KMNIST(
        '~/data/kmnist', train=False, download=True,
        transform=MNIST_transform)

# Fixed order (shuffle=False) so predictions stay aligned with KMNIST_test.targets.
KMNIST_test_loader = torch.utils.data.DataLoader(
    KMNIST_test,
    batch_size=BATCH_SIZE_TEST_KMNIST, shuffle=False)

In [33]:
"""Load notMNIST"""

import os
import numpy as np
import torch
from PIL import Image
from torch.utils.data.dataset import Dataset
from matplotlib.pyplot import imread
from torch import Tensor

"""
Loads the train/test set. 
Every image in the dataset is 28x28 pixels and the labels are numbered from 0-9
for A-J respectively.
Set root to point to the Train/Test folders.
"""

# Creating a sub class of torch.utils.data.dataset.Dataset
class notMNIST(Dataset):
    """In-memory notMNIST data set read from a directory of per-letter folders.

    ``root`` is expected to contain one sub-folder per class named ``A`` ... ``J``.
    Every readable image inside becomes one sample labelled ``ord(name) - 65``,
    i.e. 0-9. Files that cannot be read are reported on stdout and skipped;
    entries of ``root`` that are not directories are ignored.
    """

    # The init method is called when this class will be instantiated
    def __init__(self, root, transform):
        """
        Args:
            root: directory holding the class folders.
            transform: callable applied to each image in ``__getitem__``, or None.
        """
        self.transform = transform

        Images, Y = [], []

        for folder in os.listdir(root):
            folder_path = os.path.join(root, folder)
            # Stray files next to the class folders (e.g. a README) are not classes
            # and would make the inner os.listdir fail.
            if not os.path.isdir(folder_path):
                continue
            for ims in os.listdir(folder_path):
                try:
                    img_path = os.path.join(folder_path, ims)
                    image = np.array(imread(img_path))
                    label = ord(folder) - 65  # Folders are A-J so labels will be 0-9
                except Exception:
                    # Some images in the dataset are damaged
                    print("File {}/{} is broken".format(folder, ims))
                    continue
                # Append only after both image and label succeeded, so the two
                # lists can never get out of step.
                Images.append(image)
                Y.append(label)
        self.data = list(zip(Images, Y))
        self.targets = torch.Tensor(Y)

    # The number of items in the dataset
    def __len__(self):
        """Return the number of successfully loaded images."""
        return len(self.data)

    # The Dataloader is a generator that repeatedly calls the getitem method.
    # getitem is supposed to return (X, Y) for the specified index.
    def __getitem__(self, index):
        """Return ``(image, label)``: a float tensor of shape (1, 28, 28) and the stored label."""
        img = self.data[index][0]

        if self.transform is not None:
            img = self.transform(img)

        # Input for Conv2D should be Channels x Height x Width
        img_tensor = Tensor(img).view(1, 28, 28).float()
        label = self.data[index][1]
        return (img_tensor, label)

In [34]:
#root = os.path.abspath('~/data')
# expanduser resolves the leading '~' to the user's home directory.
root = os.path.expanduser('~/data')

# Instantiating the notMNIST dataset class we created
notMNIST_test = notMNIST(root=os.path.join(root, 'notMNIST_small'),
                               transform=MNIST_transform)

# Creating a dataloader
# NOTE(review): the batch size is taken from BATCH_SIZE_TEST_KMNIST; there is no notMNIST-specific constant.
not_mnist_test_loader = torch.utils.data.dataloader.DataLoader(
                            dataset=notMNIST_test,
                            batch_size=BATCH_SIZE_TEST_KMNIST,
                            shuffle=False)

File F/Q3Jvc3NvdmVyIEJvbGRPYmxpcXVlLnR0Zg==.png is broken
File A/RGVtb2NyYXRpY2FCb2xkT2xkc3R5bGUgQm9sZC50dGY=.png is broken


# MAP estimate

In [35]:
# Ground-truth labels of every test set as NumPy arrays, in data set order.
targets = MNIST_test.targets.numpy()
targets_FMNIST = FMNIST_test.targets.numpy()
# notMNIST builds its targets with torch.Tensor(Y), i.e. as floats, hence the cast back to int.
targets_notMNIST = notMNIST_test.targets.numpy().astype(int)
targets_KMNIST = KMNIST_test.targets.numpy()

In [36]:
# Predictions of the MAP network on MNIST (in-distribution) and on the three out-of-distribution test sets.
# predict_MAP is not defined in this notebook (presumably it comes from the DirLPA_utils star import).
mnist_test_in_MAP = predict_MAP(mnist_model, mnist_test_loader, cuda=True).cpu().numpy()
mnist_test_out_fmnist_MAP = predict_MAP(mnist_model, FMNIST_test_loader, cuda=True).cpu().numpy()
mnist_test_out_notMNIST_MAP = predict_MAP(mnist_model, not_mnist_test_loader, cuda=True).cpu().numpy()
mnist_test_out_KMNIST_MAP = predict_MAP(mnist_model, KMNIST_test_loader, cuda=True).cpu().numpy()

In [37]:
# Summary metrics for the MAP predictions. The helpers are not defined in this notebook; judging by the
# variable names and the printed report they yield accuracy, probability at the correct class, entropy,
# MMC and (out-of-distribution only) AUROC -- confirm against DirLPA_utils.
acc_in_MAP, prob_correct_in_MAP, ent_in_MAP, MMC_in_MAP = get_in_dist_values(mnist_test_in_MAP, targets)
acc_out_FMNIST_MAP, prob_correct_out_FMNIST_MAP, ent_out_FMNIST_MAP, MMC_out_FMNIST_MAP, auroc_out_FMNIST_MAP = get_out_dist_values(mnist_test_in_MAP, mnist_test_out_fmnist_MAP, targets_FMNIST)
acc_out_notMNIST_MAP, prob_correct_out_notMNIST_MAP, ent_out_notMNIST_MAP, MMC_out_notMNIST_MAP, auroc_out_notMNIST_MAP = get_out_dist_values(mnist_test_in_MAP, mnist_test_out_notMNIST_MAP, targets_notMNIST)
acc_out_KMNIST_MAP, prob_correct_out_KMNIST_MAP, ent_out_KMNIST_MAP, MMC_out_KMNIST_MAP, auroc_out_KMNIST_MAP = get_out_dist_values(mnist_test_in_MAP, mnist_test_out_KMNIST_MAP, targets_KMNIST)

In [38]:
# Report the MAP metrics.
# NOTE(review): the recorded output labels the out-of-distribution rows "[Out-MAP, KFAC, ...]" although 'MAP'
# is passed, and shows "Prob @ correct: 0.100" next to an accuracy of 0.990 -- both look suspicious; check the
# helpers and whether the positional order (test name, method) matches the keyword form used in later cells.
print_in_dist_values(acc_in_MAP, prob_correct_in_MAP, ent_in_MAP, MMC_in_MAP, 'mnist', 'MAP')
print_out_dist_values(acc_out_FMNIST_MAP, prob_correct_out_FMNIST_MAP, ent_out_FMNIST_MAP, MMC_out_FMNIST_MAP, auroc_out_FMNIST_MAP, 'FMNIST', 'MAP')
print_out_dist_values(acc_out_notMNIST_MAP, prob_correct_out_notMNIST_MAP, ent_out_notMNIST_MAP, MMC_out_notMNIST_MAP, auroc_out_notMNIST_MAP, 'notMNIST', 'MAP')
print_out_dist_values(acc_out_KMNIST_MAP, prob_correct_out_KMNIST_MAP, ent_out_KMNIST_MAP, MMC_out_KMNIST_MAP, auroc_out_KMNIST_MAP, 'KMNIST', 'MAP')

[In, MAP, mnist] Accuracy: 0.990; average entropy: 0.035;     MMC: 0.989; Prob @ correct: 0.100
[Out-MAP, KFAC, FMNIST] Accuracy: 0.108; Average entropy: 1.390;    MMC: 0.514; AUROC: 0.993; Prob @ correct: 0.100
[Out-MAP, KFAC, notMNIST] Accuracy: 0.108; Average entropy: 0.808;    MMC: 0.709; AUROC: 0.956; Prob @ correct: 0.100
[Out-MAP, KFAC, KMNIST] Accuracy: 0.083; Average entropy: 0.861;    MMC: 0.688; AUROC: 0.973; Prob @ correct: 0.100


In [39]:
#MAP estimate
#seeds are 123,124,125,126,127 (only a single run is recorded in the lists below)
acc_in = [0.990]
mmc_in = [0.989]
mmc_out_fmnist = [0.514]
mmc_out_notmnist = [0.709]
mmc_out_kmnist = [0.688]

auroc_out_fmnist = [0.993]
auroc_out_notmnist = [0.956]
auroc_out_kmnist = [0.973]

# Each row pairs a report line with the recorded values it summarises (mean and standard deviation).
_map_report = (
    ("accuracy: {:.03f} with std {:.03f}", acc_in),
    ("MMC in: {:.03f} with std {:.03f}", mmc_in),
    ("MMC out fmnist: {:.03f} with std {:.03f}", mmc_out_fmnist),
    ("MMC out notmnist: {:.03f} with std {:.03f}", mmc_out_notmnist),
    ("MMC out kmnist: {:.03f} with std {:.03f}", mmc_out_kmnist),
    ("AUROC out fmnist: {:.03f} with std {:.03f}", auroc_out_fmnist),
    ("AUROC out notmnist: {:.03f} with std {:.03f}", auroc_out_notmnist),
    ("AUROC out kmnist: {:.03f} with std {:.03f}", auroc_out_kmnist),
)
for _line, _runs in _map_report:
    print(_line.format(np.mean(_runs), np.std(_runs)))

accuracy: 0.990 with std 0.000
MMC in: 0.989 with std 0.000
MMC out fmnist: 0.514 with std 0.000
MMC out notmnist: 0.709 with std 0.000
MMC out kmnist: 0.688 with std 0.000
AUROC out fmnist: 0.993 with std 0.000
AUROC out notmnist: 0.956 with std 0.000
AUROC out kmnist: 0.973 with std 0.000


# Diag Hessian Sampling estimate

In [40]:
# Number of Monte-Carlo samples per batch, passed to predict_Diagonal_full in the next cell.
num_samples = 1000

In [41]:
# Sampling-based predictions with the diagonal Hessian on MNIST and the three out-of-distribution sets.
# NOTE(review): in the recorded run this cell stops with a RuntimeError (shape mismatch), so the
# variables assigned here were not produced in that session.
mnist_test_in_D = predict_Diagonal_full(mnist_model, mnist_test_loader, Hessian_MNIST, verbose=False, cuda=True, timing=True, num_samples=num_samples).cpu().numpy()
mnist_test_out_FMNIST_D = predict_Diagonal_full(mnist_model, FMNIST_test_loader, Hessian_MNIST, verbose=False, cuda=True, timing=True, num_samples=num_samples).cpu().numpy()
mnist_test_out_notMNIST_D = predict_Diagonal_full(mnist_model, not_mnist_test_loader, Hessian_MNIST, verbose=False, cuda=True, timing=True, num_samples=num_samples).cpu().numpy()
mnist_test_out_KMNIST_D = predict_Diagonal_full(mnist_model, KMNIST_test_loader, Hessian_MNIST, verbose=False, cuda=True, timing=True, num_samples=num_samples).cpu().numpy()

tensor(0.0100, device='cuda:0', grad_fn=<NllLossBackward>)


RuntimeError: ('Compared shapes [10, 10] and [32, 10] do not match. ', 'Got [32, 10, 10] and [32, 10]')

In [None]:
# Same summary metrics as for the MAP predictions, now for the diagonal-Hessian sampling predictions.
acc_in_D, prob_correct_in_D, ent_in_D, MMC_in_D = get_in_dist_values(mnist_test_in_D, targets)
acc_out_FMNIST_D, prob_correct_out_FMNIST_D, ent_out_FMNIST_D, MMC_out_FMNIST_D, auroc_out_FMNIST_D = get_out_dist_values(mnist_test_in_D, mnist_test_out_FMNIST_D, targets_FMNIST)
acc_out_notMNIST_D, prob_correct_out_notMNIST_D, ent_out_notMNIST_D, MMC_out_notMNIST_D, auroc_out_notMNIST_D = get_out_dist_values(mnist_test_in_D, mnist_test_out_notMNIST_D, targets_notMNIST)
acc_out_KMNIST_D, prob_correct_out_KMNIST_D, ent_out_KMNIST_D, MMC_out_KMNIST_D, auroc_out_KMNIST_D = get_out_dist_values(mnist_test_in_D, mnist_test_out_KMNIST_D, targets_KMNIST)

In [None]:
# Report the diagonal-sampling metrics; the out-of-distribution calls name `test` and `method` by keyword.
print_in_dist_values(acc_in_D, prob_correct_in_D, ent_in_D, MMC_in_D, 'mnist', 'Diag')
print_out_dist_values(acc_out_FMNIST_D, prob_correct_out_FMNIST_D, ent_out_FMNIST_D, MMC_out_FMNIST_D, auroc_out_FMNIST_D, test='fmnist', method='Diag')
print_out_dist_values(acc_out_notMNIST_D, prob_correct_out_notMNIST_D, ent_out_notMNIST_D, MMC_out_notMNIST_D, auroc_out_notMNIST_D, test='notMNIST', method='Diag')
print_out_dist_values(acc_out_KMNIST_D, prob_correct_out_KMNIST_D, ent_out_KMNIST_D, MMC_out_KMNIST_D, auroc_out_KMNIST_D, test='KMNIST', method='Diag')

In [None]:
#Diag Sampling
#seeds are 123,124,125,126,127
time_lpb_in = [26.45687, 27.79014, 26.57382, 26.47316, 26.88896]
time_lpb_out_fmnist = [26.78206, 27.77104, 26.32712, 26.37834, 26.62998]
time_lpb_out_notmnist = [50.15894, 52.16657, 49.43548, 49.79360, 49.76200]
time_lpb_out_kmnist = [26.71048, 27.93303, 26.52897, 26.47667, 26.88779]

acc_in = [0.990, 0.990, 0.990, 0.990, 0.990]
mmc_in = [0.942, 0.941, 0.942, 0.942, 0.941]
mmc_out_fmnist = [0.398, 0.397, 0.397, 0.398, 0.396]
mmc_out_notmnist = [0.543, 0.542, 0.543, 0.543, 0.542]
mmc_out_kmnist = [0.514, 0.512, 0.513, 0.514, 0.512]

auroc_out_fmnist = [0.992, 0.992, 0.992, 0.992, 0.992]
auroc_out_notmnist = [0.960, 0.960, 0.960, 0.960, 0.960]
auroc_out_kmnist = [0.974, 0.974, 0.974, 0.974, 0.974]

# Each row pairs a report line with the per-seed values it summarises (mean and standard deviation).
_diag_report = (
    ("Sampling Bridge time in: {:.03f} with std {:.03f}", time_lpb_in),
    ("Sampling Bridge time out fmnist: {:.03f} with std {:.03f}", time_lpb_out_fmnist),
    ("Sampling Bridge time out notmnist: {:.03f} with std {:.03f}", time_lpb_out_notmnist),
    ("Sampling Bridge time out kmnist: {:.03f} with std {:.03f}", time_lpb_out_kmnist),
    ("accuracy: {:.03f} with std {:.03f}", acc_in),
    ("MMC in: {:.03f} with std {:.03f}", mmc_in),
    ("MMC out fmnist: {:.03f} with std {:.03f}", mmc_out_fmnist),
    ("MMC out notmnist: {:.03f} with std {:.03f}", mmc_out_notmnist),
    ("MMC out kmnist: {:.03f} with std {:.03f}", mmc_out_kmnist),
    ("AUROC out fmnist: {:.03f} with std {:.03f}", auroc_out_fmnist),
    ("AUROC out notmnist: {:.03f} with std {:.03f}", auroc_out_notmnist),
    ("AUROC out kmnist: {:.03f} with std {:.03f}", auroc_out_kmnist),
)
for _line, _runs in _diag_report:
    print(_line.format(np.mean(_runs), np.std(_runs)))

# Full KFAC Laplace Approximation sampling

In [None]:
# The KFAC Hessian (and even its individual components) is too large to store in memory, so this condition is skipped.

# Dirichlet Laplace Approximation

In [None]:
# Dirichlet Laplace approximation outputs for MNIST and the three out-of-distribution sets.
# predict_DIR_LPA is not defined in this notebook (presumably it comes from the DirLPA_utils star import);
# its outputs are normalised row-wise in the next cell.
mnist_test_in_DIR_LPA = predict_DIR_LPA(mnist_model, mnist_test_loader, Hessian_MNIST, verbose=False, cuda=True, timing=True).cpu().numpy()
mnist_test_out_FMNIST_DIR_LPA = predict_DIR_LPA(mnist_model, FMNIST_test_loader, Hessian_MNIST, verbose=False, cuda=True, timing=True).cpu().numpy()
mnist_test_out_notMNIST_DIR_LPA = predict_DIR_LPA(mnist_model, not_mnist_test_loader, Hessian_MNIST, verbose=False, cuda=True, timing=True).cpu().numpy()
mnist_test_out_KMNIST_DIR_LPA = predict_DIR_LPA(mnist_model, KMNIST_test_loader, Hessian_MNIST, verbose=False, cuda=True, timing=True).cpu().numpy()

In [None]:
# Divide every row by its sum (sum over axis 1, reshaped to a column for broadcasting) so each row adds up to 1.
mnist_test_in_DIR_LPAn = mnist_test_in_DIR_LPA/mnist_test_in_DIR_LPA.sum(1).reshape(-1,1)
mnist_test_out_FMNIST_DIR_LPAn = mnist_test_out_FMNIST_DIR_LPA/mnist_test_out_FMNIST_DIR_LPA.sum(1).reshape(-1,1)
mnist_test_out_notMNIST_DIR_LPAn = mnist_test_out_notMNIST_DIR_LPA/mnist_test_out_notMNIST_DIR_LPA.sum(1).reshape(-1,1)
mnist_test_out_KMNIST_DIR_LPAn = mnist_test_out_KMNIST_DIR_LPA/mnist_test_out_KMNIST_DIR_LPA.sum(1).reshape(-1,1)

In [None]:
# Same summary metrics as for the MAP predictions, now for the row-normalised Dirichlet Laplace outputs.
acc_in_DIR_LPAn, prob_correct_in_DIR_LPAn, ent_in_DIR_LPAn, MMC_in_DIR_LPAn = get_in_dist_values(mnist_test_in_DIR_LPAn, targets)
acc_out_FMNIST_DIR_LPAn, prob_correct_out_FMNIST_DIR_LPAn, ent_out_FMNIST_DIR_LPAn, MMC_out_FMNIST_DIR_LPAn, auroc_out_FMNIST_DIR_LPAn = get_out_dist_values(mnist_test_in_DIR_LPAn, mnist_test_out_FMNIST_DIR_LPAn, targets_FMNIST)
acc_out_notMNIST_DIR_LPAn, prob_correct_out_notMNIST_DIR_LPAn, ent_out_notMNIST_DIR_LPAn, MMC_out_notMNIST_DIR_LPAn, auroc_out_notMNIST_DIR_LPAn = get_out_dist_values(mnist_test_in_DIR_LPAn, mnist_test_out_notMNIST_DIR_LPAn, targets_notMNIST)
acc_out_KMNIST_DIR_LPAn, prob_correct_out_KMNIST_DIR_LPAn, ent_out_KMNIST_DIR_LPAn, MMC_out_KMNIST_DIR_LPAn, auroc_out_KMNIST_DIR_LPAn = get_out_dist_values(mnist_test_in_DIR_LPAn, mnist_test_out_KMNIST_DIR_LPAn, targets_KMNIST)

In [None]:
# Report the metrics of the normalised Dirichlet Laplace outputs; `test` and `method` are passed by keyword.
print_in_dist_values(acc_in_DIR_LPAn, prob_correct_in_DIR_LPAn, ent_in_DIR_LPAn, MMC_in_DIR_LPAn, 'mnist', 'DIR_LPAn')
print_out_dist_values(acc_out_FMNIST_DIR_LPAn, prob_correct_out_FMNIST_DIR_LPAn, ent_out_FMNIST_DIR_LPAn, MMC_out_FMNIST_DIR_LPAn, auroc_out_FMNIST_DIR_LPAn, test='fmnist', method='DIR_LPAn')
print_out_dist_values(acc_out_notMNIST_DIR_LPAn, prob_correct_out_notMNIST_DIR_LPAn, ent_out_notMNIST_DIR_LPAn, MMC_out_notMNIST_DIR_LPAn, auroc_out_notMNIST_DIR_LPAn, test='notMNIST', method='DIR_LPAn')
print_out_dist_values(acc_out_KMNIST_DIR_LPAn, prob_correct_out_KMNIST_DIR_LPAn, ent_out_KMNIST_DIR_LPAn, MMC_out_KMNIST_DIR_LPAn, auroc_out_KMNIST_DIR_LPAn, test='KMNIST', method='DIR_LPAn')

In [None]:
#Laplace Bridge
#seeds are 123,124,125,126,127
time_lpb_in = [0.06198, 0.06407, 0.06117, 0.06193, 0.06098]
time_lpb_out_fmnist = [0.06195, 0.06337, 0.06156, 0.06214, 0.06095]
time_lpb_out_notmnist = [0.11700, 0.12246, 0.11510, 0.11567, 0.11433]
time_lpb_out_kmnist = [0.06194, 0.06504, 0.06136, 0.06200, 0.06136]

acc_in = [0.990, 0.990, 0.990, 0.990, 0.990]
mmc_in = [0.987, 0.987, 0.987, 0.987, 0.987]
mmc_out_fmnist = [0.363, 0.363, 0.363, 0.363, 0.363]
mmc_out_notmnist = [0.649, 0.649, 0.649, 0.649, 0.649]
mmc_out_kmnist = [0.637, 0.637, 0.637, 0.638, 0.637]

auroc_out_fmnist = [0.996, 0.996, 0.996, 0.996, 0.996]
auroc_out_notmnist = [0.961, 0.961, 0.961, 0.961, 0.961]
auroc_out_kmnist = [0.973, 0.973, 0.973, 0.973, 0.973]

# Each row pairs a report line with the per-seed values it summarises (mean and standard deviation).
_bridge_report = (
    ("Laplace Bridge time in: {:.03f} with std {:.03f}", time_lpb_in),
    ("Laplace Bridge time out fmnist: {:.03f} with std {:.03f}", time_lpb_out_fmnist),
    ("Laplace Bridge time out notmnist: {:.03f} with std {:.03f}", time_lpb_out_notmnist),
    ("Laplace Bridge time out kmnist: {:.03f} with std {:.03f}", time_lpb_out_kmnist),
    ("accuracy: {:.03f} with std {:.03f}", acc_in),
    ("MMC in: {:.03f} with std {:.03f}", mmc_in),
    ("MMC out fmnist: {:.03f} with std {:.03f}", mmc_out_fmnist),
    ("MMC out notmnist: {:.03f} with std {:.03f}", mmc_out_notmnist),
    ("MMC out kmnist: {:.03f} with std {:.03f}", mmc_out_kmnist),
    ("AUROC out fmnist: {:.03f} with std {:.03f}", auroc_out_fmnist),
    ("AUROC out notmnist: {:.03f} with std {:.03f}", auroc_out_notmnist),
    ("AUROC out kmnist: {:.03f} with std {:.03f}", auroc_out_kmnist),
)
for _line, _runs in _bridge_report:
    print(_line.format(np.mean(_runs), np.std(_runs)))

In [None]:
# check if condition holds

def check_condition(alpha_vecs):
    """Return the fraction of rows whose largest entry exceeds a threshold.

    For each row, ``rest`` is the row sum minus the row maximum and the
    threshold is ``0.25 * (sqrt(9*rest**2 + 10*rest + 1) - rest - 1)``.
    The computation is vectorised over rows; ``alpha_vecs`` is expected to be
    a 2-D NumPy array (``.max(1)`` must return the values only).
    """
    largest = alpha_vecs.max(1)
    rest = alpha_vecs.sum(1) - largest
    threshold = 0.25 * (np.sqrt(9 * rest**2 + 10 * rest + 1) - rest - 1)
    satisfied = largest > threshold
    return np.sum(satisfied) / len(satisfied)

In [None]:
# Fraction of predictions per data set for which the condition in check_condition holds.
# check_condition already returns a single number, so the outer np.sum leaves it unchanged.
print(np.sum(check_condition(mnist_test_in_DIR_LPA)))
print(np.sum(check_condition(mnist_test_out_FMNIST_DIR_LPA)))
print(np.sum(check_condition(mnist_test_out_notMNIST_DIR_LPA)))
print(np.sum(check_condition(mnist_test_out_KMNIST_DIR_LPA)))

In [2]:
# Since our package is no longer compatible with PyTorch, the results were obtained with a last-layer
# Laplace approximation of the network in the other Jupyter notebook.