In [1]:
from laplace.baselaplace import FullLaplace
from laplace.curvature.backpack import BackPackGGN
import numpy as np
import torch

from laplace import Laplace, marglik_training
import torch
from torchvision import datasets, transforms
import torch.utils.data as data_utils
import matplotlib.pyplot as plt
import torchvision


In [2]:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True

In [3]:
import torch.nn as nn


In [4]:
import torch.nn.functional as F


In [5]:
import torch.distributions as dists
from netcal.metrics import ECE



In [6]:
config = {
    "num_classes":10,
    "kernel_size": 5,
    "channels":1,
    "filter_1_out" :16,
    "filter_2_out" :32,
    "enc_sizes":[16,32],
    "padding" :0,
    "stride" :1, 
    "pool":2,
    "learning_rate": 0.001,
    "epochs": 20,
    "batch_size": 64,
    "crop_size":128
}





In [7]:
device = torch.device('cpu')

In [8]:
train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(root='.', train=True, download=True,
                        transform=transforms.Compose([
                            transforms.ToTensor(),
                            transforms.Normalize((0.1307,), (0.3081,))
                        ])), batch_size=config["batch_size"], shuffle=True, num_workers=2)

In [9]:
test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(root='.', train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ])), batch_size=config["batch_size"], shuffle=True, num_workers=2)

In [10]:
train_features, train_labels = next(iter(train_loader))
height = train_features.shape[2]
width = train_features.shape[3]
print(height,width)


28 28


In [11]:
def conv_block(in_f, out_f, *args, **kwargs):
    return nn.Sequential(
        nn.Conv2d(in_f, out_f, *args, **kwargs),
        nn.MaxPool2d(2, stride=2),
        nn.BatchNorm2d(out_f),
        nn.ReLU()
    )




class Base(nn.Module):
    def __init__(self, enc_sizes, kernel, pad):
        super().__init__()

        
        conv_blocks =[conv_block(in_f, out_f, kernel_size=kernel, padding=pad) 
                       for in_f, out_f in zip(enc_sizes, enc_sizes[1:])]

        self.base_net = nn.Sequential(*conv_blocks)
        

    def forward(self,x):
        x = self.base_net(x)
    
        return x

class Net(nn.Module):
    def __init__(self, in_c, enc_sizes, kernel, pad,n_classes):
        super().__init__()
    
        self.enc_sizes = [in_c, *enc_sizes]
        self.kernel = kernel
        self.pad = pad
        self.n_classes = n_classes
        
        self.base = Base(self.enc_sizes,self.kernel,self.pad)
        self.fc1 = nn.Linear(
            self.enc_sizes[-1] * 4* 4 , 50
        )
        self.fc2 = nn.Linear(50, self.n_classes)
        

    def forward(self,x):
        
        x = self.base(x)        
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)

        return x

In [12]:
model = Net(
        config["channels"],
        config["enc_sizes"],
        config["kernel_size"],
        config["padding"],
        config["num_classes"]).to(device).eval()

In [13]:
# model_path = '/Users/georgioszefkilis/Bayesian_Deep_Learning/models/best_checkpoint.pth'
model_path = '/Users/georgioszefkilis/Bayesian_Deep_Learning/saved_models/colab_best_Vanilla_MNIST_20.pth'
checkpoint = torch.load(model_path, map_location=device)
    # initialize state_dict from checkpoint to model
model.load_state_dict(checkpoint["state_dict"])

<All keys matched successfully>

In [14]:
@torch.no_grad()
def predict(dataloader, model, laplace=False):
    py = []
    target = []
    for x, t in dataloader:
        x,t = x.to(device),t.to(device)
        target.append(t)
        if laplace:
            py.append(model(x))
        else:
            py.append(torch.softmax(model(x), dim=-1))

    images = torch.cat(py).cpu()
    labels =torch.cat(target, dim=0).cpu()
    acc_map = (images.argmax(-1) == labels).float().mean()
    ece_map = ECE(bins=15).measure(images.numpy(), labels.numpy())
    nll_map = -dists.Categorical(images).log_prob(labels).mean()
    
    return acc_map,ece_map,nll_map

# Last layer implementation

## Without Laplace

In [15]:
acc_map,ece_map,nll_map = predict(test_loader,model, laplace=False)
print(f"[MAP] Acc.: {acc_map:.1%}; ECE: {ece_map:.1%}; NLL: {nll_map:.3}")

[MAP] Acc.: 99.2%; ECE: 0.5%; NLL: 0.0349


## With Laplace

In [16]:
la = Laplace(model, 'classification',
             subset_of_weights='last_layer',
             hessian_structure='kron')
la.fit(train_loader)
la.optimize_prior_precision(method='marglik')

In [17]:
acc_laplace,ece_laplace,nll_laplace = predict(test_loader,la,laplace=True)
print(f"Acc.: {acc_laplace:.1%}; ECE: {ece_laplace:.1%}; NLL: {nll_laplace:.3}")

Acc.: 99.2%; ECE: 0.2%; NLL: 0.0257


In [30]:
torch.nn.utils.vector_to_parameters(la.sample(n_samples=10)[0], model.parameters())


RuntimeError: shape '[32, 16, 5, 5]' is invalid for input of size 62

# Subnetwork implementation

In [18]:
from laplace.baselaplace import FullLaplace
from laplace.curvature.backpack import BackPackGGN
from laplace.utils import ModuleNameSubnetMask

In [19]:
for name,m in model.named_modules():
    print(name)


base
base.base_net
base.base_net.0
base.base_net.0.0
base.base_net.0.1
base.base_net.0.2
base.base_net.0.3
base.base_net.1
base.base_net.1.0
base.base_net.1.1
base.base_net.1.2
base.base_net.1.3
fc1
fc2


In [21]:
print('start_laplace')
subnetwork_mask = ModuleNameSubnetMask(model, module_names=['base.base_net.1.0'])
print('step 2')
subnetwork_mask.select()
print('step 3')
subnetwork_indices = subnetwork_mask.indices
print('step 4')
sub_laplace = Laplace(
    model,
    "classification",
    subset_of_weights="subnetwork",
    hessian_structure="full",
    subnetwork_indices = subnetwork_indices#.type(torch.LongTensor),
)
print('fit')
sub_laplace.fit(train_loader)
print('optimize')
#sub_laplace.prior_precision=torch.tensor([0.00001])

#laplace.optimize_prior_precision(method="marglik",val_loader=test_loader)

start_laplace
step 2
step 3
step 4
fit




optimize


In [22]:
acc_subnet,ece_subnet,nll_subnet = predict(test_loader,sub_laplace,laplace=True)
print(f"Acc.: {acc_subnet:.1%}; ECE: {ece_subnet:.1%}; NLL: {nll_subnet:.3}")

Acc.: 99.2%; ECE: 18.7%; NLL: 0.248


In [32]:
sub_laplace.prior_precision=torch.tensor([0.00001])
acc_subnet,ece_subnet,nll_subnet = predict(test_loader,sub_laplace,laplace=True)
print(f"Acc.: {acc_subnet:.1%}; ECE: {ece_subnet:.1%}; NLL: {nll_subnet:.3}")

In [31]:
sub_laplace.optimize_prior_precision(method="marglik")

RuntimeError: The size of tensor a (39504) must match the size of tensor b (12832) at non-singleton dimension 0