This is a simplified demo compared to the experiments in the paper. To reproduce the paper's results, increase the number of training epochs and the sample size, and add weight decay.

In [1]:
# Standard imports
import os
import copy
import collections
import numpy as np
from numpy.linalg import inv, cholesky
from typing import Union, List, Any, Dict
from tqdm import tqdm
import torch
from torch import Tensor
from torch.nn import Module, Sequential
import torchvision
from matplotlib import pyplot as plt
%matplotlib inline
from mpl_toolkits.mplot3d import Axes3D
import deeplake

# From the repository
from plot import surface_plot
from curvatures import Diagonal, KFAC, EFB, INF,Curvature, BlockDiagonal
from utils import calibration_curve,get_eigenvectors, kron, expected_calibration_error, predictive_entropy

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Use the GPU when one is available and fall back to the CPU otherwise,
# so the notebook also runs on machines without CUDA.
# Override manually (e.g. device = 'cpu') to force a specific device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def train(model, data, criterion, optimizer, epochs):
    """Fit `model` on `data` for a fixed number of epochs.

    Args:
        model: Module to optimise; it is switched to training mode.
        data: Iterable yielding `(images, labels)` batches.
        criterion: Callable mapping `(logits, labels)` to a scalar loss.
        optimizer: Optimizer whose `step()` is called once per batch.
        epochs: Number of full passes over `data`.
    """
    model.train()
    for _ in range(epochs):
        for batch_images, batch_labels in tqdm(data):
            # Both inputs and targets have to live on the module-level `device`.
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            loss = criterion(model(batch_images), batch_labels)
            # Clear gradients left over from the previous batch before backprop.
            model.zero_grad()
            loss.backward()
            optimizer.step()
            
def eval(model_, data):
    """Run `model_` over `data` and return softmax probabilities with the targets.

    Note: this intentionally keeps the notebook's original name and therefore
    shadows the builtin `eval` inside the notebook namespace.

    Args:
        model_: Module to evaluate; it is switched to eval mode.
        data: Iterable yielding `(images, labels)` batches.

    Returns:
        A tuple `(probabilities, targets)`: the softmax over dim 1 of the
        concatenated model outputs (on `device`), and the concatenated labels
        exactly as yielded by `data` (they are never moved to `device`).
    """
    model_.eval()
    # Leading empty 1-D tensors seed the float / long dtypes and the device;
    # `torch.cat` accepts such empty tensors next to tensors of any shape.
    logit_chunks = [torch.Tensor().to(device)]
    target_chunks = [torch.LongTensor()]

    with torch.no_grad():
        for images, labels in tqdm(data):
            logit_chunks.append(model_(images.to(device)))
            target_chunks.append(labels)

        # Concatenate once at the end: calling `torch.cat` inside the loop
        # would re-copy every previously accumulated row on each batch.
        logits = torch.cat(logit_chunks)
        targets = torch.cat(target_chunks)
    return torch.nn.functional.softmax(logits, dim=1), targets

def eval_ood(model_, data):
    """Run `model_` over an out-of-distribution loader and return softmax probabilities.

    Args:
        model_: Module to evaluate; it is switched to eval mode.
        data: Iterable yielding mapping-like batches with an 'images' entry.
            Each batch is cast to float and given an extra dimension at
            position 1 before being passed to the model
            (presumably the channel axis of a (batch, H, W) tensor - TODO confirm).

    Returns:
        The softmax over dim 1 of the concatenated model outputs (on `device`).
    """
    model_.eval()
    # Leading empty 1-D tensor seeds the float dtype and the device;
    # `torch.cat` accepts such empty tensors next to tensors of any shape.
    logit_chunks = [torch.Tensor().to(device)]

    with torch.no_grad():
        for item in tqdm(data):
            batch = item['images'].float().unsqueeze(1).to(device)
            logit_chunks.append(model_(batch))

        # Concatenate once at the end: calling `torch.cat` inside the loop
        # would re-copy every previously accumulated row on each batch.
        logits = torch.cat(logit_chunks)
    return torch.nn.functional.softmax(logits, dim=1)

def accuracy(predictions, labels):
    """Print the top-1 accuracy (in percent, two decimals) of `predictions`.

    Args:
        predictions: Tensor of per-class scores; the arg-max over axis 1 is
            taken as the predicted class.
        labels: CPU tensor of reference class indices.
    """
    predicted_classes = predictions.cpu().numpy().argmax(axis=1)
    hit_rate = np.mean(predicted_classes == labels.numpy())
    print(f"Accuracy: {100 * hit_rate:.2f}%")

In [3]:
# Define a PyTorch model (or load a pretrained one).
class Flatten(torch.nn.Module):
    """Collapse every dimension after the first (batch) dimension into one."""

    def forward(self, input):
        """Return `input` reshaped to `(input.size(0), -1)`.

        `reshape` returns a view for contiguous inputs and copies only when it
        must, so non-contiguous tensors (e.g. after a transpose or permute)
        are handled instead of raising.
        """
        return input.reshape(input.size(0), -1)

# This tutorial uses a LeNet-5 variant: two conv / ReLU / max-pool stages
# followed by a three-layer fully connected classifier with 10 outputs.
model = torch.nn.Sequential(
    # Convolutional stage 1
    torch.nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2),
    # Convolutional stage 2
    torch.nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2, stride=2),
    # Classifier head; the first linear layer expects 16 * 5 * 5 flattened features
    Flatten(),
    torch.nn.Linear(in_features=16 * 5 * 5, out_features=120),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=120, out_features=84),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=84, out_features=10),
)
model = model.to(device)

In [4]:
# MNIST is downloaded on first use into the standard PyTorch dataset location.
torch_data = "~/.torch/datasets"


def _mnist_split(train):
    """Return the MNIST training (`train=True`) or test (`train=False`) split as tensors."""
    return torchvision.datasets.MNIST(
        root=torch_data,
        train=train,
        transform=torchvision.transforms.ToTensor(),
        download=True,
    )


# Training data
train_set = _mnist_split(train=True)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=32)

# Evaluation / test data
test_set = _mnist_split(train=False)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=256)

In [5]:
# Out-of-distribution (OOD) dataset: load `not-mnist-small` from the Activeloop hub via deeplake.
# It is later wrapped in a loader and passed to the MNIST-trained model by `eval_ood`.
ds = deeplake.load('hub://activeloop/not-mnist-small')

/

Opening dataset in read-only mode as you don't have write permissions.


/

This dataset can be visualized in Jupyter Notebook by ds.visualize() or at https://app.activeloop.ai/activeloop/not-mnist-small



/

hub://activeloop/not-mnist-small loaded successfully.



-

In [6]:
# Wrap the OOD dataset in a PyTorch-style loader; `eval_ood` reads each batch's 'images' entry.
# Batches of 4, without shuffling.
ood_loader = ds.pytorch(num_workers=0, batch_size=4, shuffle=False)



In [7]:
# Train the model (or load a pretrained one): cross-entropy loss, SGD with momentum, two epochs.
criterion = torch.nn.CrossEntropyLoss()
criterion = criterion.to(device)
optimizer = torch.optim.SGD(params=model.parameters(), momentum=0.9, lr=0.001)
train(model=model, data=train_loader, criterion=criterion, optimizer=optimizer, epochs=2)

100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:59<00:00, 31.76it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:50<00:00, 36.97it/s]


In [8]:
# Evaluate the model (optional): softmax predictions of the SGD-trained network on the MNIST
# test loader, kept as the baseline for the calibration and entropy metrics computed later.
sgd_predictions, sgd_labels = eval(model, test_loader)
accuracy(sgd_predictions, sgd_labels)

100%|██████████████████████████████████████████████████████████████████████████████████| 40/40 [00:04<00:00,  9.59it/s]

Accuracy: 96.41%





In [9]:
# Expected calibration error of the baseline predictions on the test set.
# The trailing 10 is presumably the number of confidence bins - TODO confirm in
# utils.expected_calibration_error. Only element [0] of the returned value is printed.
sgd_ece = expected_calibration_error(sgd_predictions.cpu().detach().numpy(), sgd_labels.cpu().detach().numpy(), 10)
print(sgd_ece[0])

0.011540600928664237


In [10]:
# Baseline softmax predictions of the SGD-trained network on the out-of-distribution loader.
sgd_ood_predictions = eval_ood(model, ood_loader)

100%|█████████████████████████████████████████████████████████████████████████████| 4681/4681 [00:45<00:00, 103.64it/s]


In [11]:
# Predictive entropy of the baseline on the OOD inputs.
# NOTE(review): the meaning of the `True` flag is defined in utils.predictive_entropy
# (presumably "average over all inputs", since a single scalar is printed) - TODO confirm.
sgd_entropy = predictive_entropy(sgd_ood_predictions.cpu().detach().numpy(), True)
print(sgd_entropy)

0.0020425383


In [12]:
# Number of `llla.sample_and_replace()` draws whose predictions are averaged further below.
samples = 5

In [13]:
def _accumulate_curvature(estimators):
    """Make one pass over `train_loader`, updating every estimator after each backward pass."""
    for images, labels in tqdm(train_loader):
        logits = model(images.to(device))
        loss = criterion(logits, labels.to(device))
        model.zero_grad()
        loss.backward(retain_graph=True)

        for estimator in estimators:
            estimator.update(batch_size=images.size(0))


model.train()

# First pass: the diagonal and KFAC estimators are updated on the same gradients.
diag = Diagonal(model, last_layer_mode = True)
kfac = KFAC(model, last_layer_mode = True)
_accumulate_curvature([diag, kfac])

# Second pass: EFB is built from the finished KFAC state, so it needs its own sweep.
ckfac = EFB(model, kfac.state, last_layer_mode = True)
_accumulate_curvature([ckfac])

# Combine the three states into the INF estimator and compute it with rank 100.
llla = INF(model, diag.state, kfac.state, ckfac.state, last_layer_mode = True)
llla.update(rank=100)

100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:58<00:00, 32.11it/s]
The default behavior has changed from using the upper triangular portion of the matrix by default to using the lower triangular portion.
L, _ = torch.symeig(A, upper=upper)
should be replaced with
L = torch.linalg.eigvalsh(A, UPLO='U' if upper else 'L')
and
L, V = torch.symeig(A, eigenvectors=True)
should be replaced with
L, V = torch.linalg.eigh(A, UPLO='U' if upper else 'L') (Triggered internally at  C:\cb\pytorch_1000000000000\work\aten\src\ATen\native\BatchLinearAlgebra.cpp:3041.)
  _, xxt_eigvecs = torch.symeig(sym_xxt, eigenvectors=True)
100%|██████████████████████████████████████████████████████████████████████████████| 1875/1875 [00:57<00:00, 32.54it/s]
100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  5.38it/s]


In [14]:
# Total number of stored entries in the INF state: for every layer, the element
# counts of the two 2-D factors plus the lengths of the two remaining factors.
count_llla = sum(
    value[0].shape[0] * value[0].shape[1]
    + value[1].shape[0] * value[1].shape[1]
    + value[2].shape[0]
    + value[3].shape[0]
    for _, value in llla.state.items()
)
print(count_llla)

2350


In [15]:
# prior precision and likelihood scale parameter
# NOTE(review): both values are passed positionally to `llla.invert`; presumably `add` is the
# prior precision and `multiply` the likelihood scale, in the order listed above - TODO confirm
# against curvatures.INF.invert.
add = 100.0
multiply = 20.0
llla.invert(add, multiply)

L = torch.cholesky(A)
should be replaced with
L = torch.linalg.cholesky(A)
and
U = torch.cholesky(A, upper=True)
should be replaced with
U = torch.linalg.cholesky(A).mH().
This transform will produce equivalent results for all valid (symmetric positive definite) inputs. (Triggered internally at  C:\cb\pytorch_1000000000000\work\aten\src\ATen\native\BatchLinearAlgebra.cpp:1755.)
  A_c_inv = vtv.cholesky().inverse()


In [17]:
# Average the softmax predictions over `samples` draws from `llla`, on both the
# MNIST test loader and the out-of-distribution loader.
test_draws = []
ood_draws = []
with torch.no_grad():
    for _ in range(samples):
        # Replace the model's weights with a fresh sample before each evaluation.
        llla.sample_and_replace()
        predictions, labels = eval(model, test_loader)
        ood_predictions = eval_ood(model, ood_loader)
        test_draws.append(predictions)
        ood_draws.append(ood_predictions)

# `sum` starts from 0 and adds the draws in order, then the total is divided by the draw count.
mean_predictions = sum(test_draws) / samples
mean_ood_predictions = sum(ood_draws) / samples
accuracy(mean_predictions, labels)

100%|██████████████████████████████████████████████████████████████████████████████████| 40/40 [00:04<00:00,  8.29it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 4681/4681 [01:00<00:00, 77.20it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 40/40 [00:04<00:00,  9.36it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 4681/4681 [00:47<00:00, 98.07it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 40/40 [00:04<00:00,  9.25it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 4681/4681 [01:01<00:00, 75.69it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 40/40 [00:04<00:00,  9.34it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 4681/4681 [00:48<00:00, 95.91it/s]
100%|███████████████████████████████████

Accuracy: 96.40%


In [18]:
# Expected calibration error of the sample-averaged predictions on the test set, computed with
# the same arguments as the baseline `sgd_ece` so the two printed values are comparable.
# Only element [0] of the returned value is printed.
llla_ece = expected_calibration_error(mean_predictions.cpu().detach().numpy(), labels.cpu().detach().numpy(), 10)
print(llla_ece[0])

0.017216112053394336


In [19]:
# Predictive entropy of the sample-averaged predictions on the OOD inputs, computed with the
# same arguments as the baseline `sgd_entropy` so the two printed values are comparable.
llla_entropy = predictive_entropy(mean_ood_predictions.cpu().detach().numpy(), True)
print(llla_entropy)

0.14972506
