In [2]:
from laplace.baselaplace import FullLaplace
from laplace.curvature.backpack import BackPackGGN
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as dists
from netcal.metrics import ECE

from laplace import Laplace, marglik_training
import torch
from torchvision import datasets, transforms
import torch.utils.data as data_utils
import matplotlib.pyplot as plt
import torchvision
import sys
import os
# Reproducibility settings.
# Seed the global torch/numpy RNGs so that torch-based randomness (e.g.
# presumably the data-loader shuffling) is repeatable across runs.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)
# cudnn.benchmark autotunes convolution algorithms and may select different
# ones between runs, which undermines cudnn.deterministic; PyTorch's
# reproducibility notes recommend disabling it when determinism is wanted.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# Make the project's helper modules (compute_dim, make_dataset) importable.
# The project root defaults to the author's checkout but can be overridden
# with the BDL_PROJECT_ROOT environment variable so the notebook runs elsewhere.
_project_root = os.environ.get(
    "BDL_PROJECT_ROOT", "/Users/georgioszefkilis/Bayesian_Deep_Learning"
)
py_file_location = os.path.join(_project_root, "src", "utils")
py_file_location_src = os.path.join(_project_root, "src", "data")
for _extra_path in (py_file_location, py_file_location_src):
    _extra_path = os.path.abspath(_extra_path)
    # Guard so re-running the cell does not keep appending duplicates.
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)
import compute_dim
import make_dataset


In [7]:
# Hyper-parameters used for data loading and model construction below
# (some keys, e.g. learning rates / epochs, are not used in this notebook).
# NOTE(review): the batch-shape check below prints 28x28 although crop_size
# is 128 — presumably the crop is not applied with misplacement=False; confirm.
config = dict(
    batch_size=256,
    num_classes=10,
    channels=1,
    filter1_out=16,
    kernel_size=5,
    pool=2,
    filter2_out=32,
    padding=0,
    stride=1,
    learning_rate_base=0.001,
    learning_rate_stn=0.0001,
    epochs=5,
    crop_size=128,
    enc_sizes=[16, 32],  # output channels of each conv block in `Base`
    loc_sizes=[8, 16, 32, 64],
)


In [8]:
device: torch.device = torch.device("cpu")  # all tensors/models live on the CPU

In [None]:
# make_dataset.data returns three loaders; the middle one (presumably a
# validation loader — confirm in make_dataset) is not used here.
train_loader, _, test_loader = make_dataset.data(
    config["batch_size"],
    config["crop_size"],
    misplacement=False,
)

In [12]:
# Peek at one training batch to check the input spatial size (dims 2 and 3).
train_features, train_labels = next(iter(train_loader))
height, width = train_features.shape[2:4]
print(height, width)


28 28


In [13]:
class Base(nn.Module):
    """Convolutional feature extractor made of ``compute_dim.conv_block``s.

    ``enc_sizes`` lists channel counts ``[c0, c1, ..., ck]``; one block is
    created for every consecutive pair ``(c_i, c_{i+1})``, all sharing the
    same ``kernel`` size and ``pad``ding.
    """

    def __init__(self, enc_sizes, kernel, pad):
        super().__init__()

        blocks = []
        for in_ch, out_ch in zip(enc_sizes[:-1], enc_sizes[1:]):
            blocks.append(
                compute_dim.conv_block(in_ch, out_ch, kernel_size=kernel, padding=pad)
            )
        self.base_net = nn.Sequential(*blocks)

    def forward(self, x):
        """Run ``x`` through every conv block in order."""
        return self.base_net(x)

class Vanilla(nn.Module):
    """CNN classifier: a ``Base`` conv feature extractor followed by a
    two-layer head ``fc1`` -> ReLU -> dropout -> ``fc2`` that outputs logits.

    Args:
        in_c: number of input channels.
        enc_sizes: output channels of each conv block, in order.
        kernel: kernel size passed to every conv block.
        pad: padding passed to every conv block.
        n_classes: number of output logits.
        feature_size: side length of the (assumed square) feature map that
            ``Base`` produces. Defaults to 4, the value previously hard-coded;
            presumably right for 28x28 inputs with kernel 5, no padding and
            2x pooling per block — TODO confirm against compute_dim.conv_block.
            Must match the checkpoint being loaded.
    """

    def __init__(self, in_c, enc_sizes, kernel, pad, n_classes, feature_size=4):
        super().__init__()

        # Prepend the input channels so consecutive pairs are (in, out) per block.
        self.enc_sizes = [in_c, *enc_sizes]
        self.kernel = kernel
        self.pad = pad
        self.n_classes = n_classes
        self.feature_size = feature_size

        self.base = Base(self.enc_sizes, self.kernel, self.pad)
        # Input width of fc1 = channels * height * width of the flattened features.
        self.fc1 = nn.Linear(self.enc_sizes[-1] * feature_size * feature_size, 50)
        self.fc2 = nn.Linear(50, self.n_classes)

    def forward(self, x):
        """Return class logits for the batch ``x``."""
        x = self.base(x)
        x = x.flatten(1)  # flatten everything after the batch dimension
        x = F.relu(self.fc1(x))
        # Dropout is active only in training mode; .eval() disables it.
        x = F.dropout(x, training=self.training)
        return self.fc2(x)

In [14]:
# Build the classifier from `config`, move it to `device` and switch to eval
# mode (disables the dropout in Vanilla.forward) before loading weights.
# The original chained `.to(device).eval()` twice; once is sufficient.
model = Vanilla(
    config["channels"],
    config["enc_sizes"],
    config["kernel_size"],
    config["padding"],
    config["num_classes"],
).to(device).eval()

In [15]:
# Pretrained MAP weights for the Vanilla CNN. (An older local checkpoint lived
# at /Users/georgioszefkilis/Bayesian_Deep_Learning/models/best_checkpoint.pth.)
model_path = '/Users/georgioszefkilis/Bayesian_Deep_Learning/saved_models/colab_best_Vanilla_MNIST_20.pth'
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
checkpoint = torch.load(f=model_path, map_location=device)
# The checkpoint is a dict whose "state_dict" entry holds the model weights.
model.load_state_dict(state_dict=checkpoint["state_dict"])

<All keys matched successfully>

In [16]:
@torch.no_grad()
def predict(dataloader, model, laplace=False, n_bins=15, target_device=None):
    """Evaluate a classifier on ``dataloader`` and return accuracy, ECE and NLL.

    Args:
        dataloader: iterable of ``(inputs, targets)`` batches.
        model: a network returning logits (``laplace=False``), or a Laplace
            wrapper whose call is expected to return class probabilities
            (``laplace=True``).
        laplace: True if ``model`` already outputs probabilities, False if it
            outputs logits (a softmax is applied then).
        n_bins: number of confidence bins for the expected calibration error.
        target_device: device each batch is moved to; defaults to the
            notebook-level ``device``.

    Returns:
        ``(accuracy, ece, nll)`` — accuracy and mean negative log-likelihood
        as 0-d tensors, ECE as returned by ``netcal.metrics.ECE.measure``.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    # Only fall back to the global `device` when no explicit device is given.
    dev = device if target_device is None else target_device

    prob_batches = []
    label_batches = []
    for inputs, targets in dataloader:
        inputs, targets = inputs.to(dev), targets.to(dev)
        label_batches.append(targets)
        out = model(inputs)
        prob_batches.append(out if laplace else torch.softmax(out, dim=-1))

    if not prob_batches:
        # torch.cat([]) would otherwise fail with an unhelpful message.
        raise ValueError("predict() received an empty dataloader")

    probs = torch.cat(prob_batches).cpu()
    labels = torch.cat(label_batches, dim=0).cpu()
    acc = (probs.argmax(-1) == labels).float().mean()
    ece = ECE(bins=n_bins).measure(probs.numpy(), labels.numpy())
    nll = -dists.Categorical(probs=probs).log_prob(labels).mean()

    return acc, ece, nll

# Last-layer Laplace approximation

## Without Laplace

In [23]:
# Baseline: the deterministic MAP network (softmax over its logits).
acc_map, ece_map, nll_map = predict(dataloader=test_loader, model=model, laplace=False)
print(f"[MAP] Acc.: {acc_map:.1%}; ECE: {ece_map:.1%}; NLL: {nll_map:.3}")


Test set: Accuracy: 9913/10000 (99%)

Acc.: 100.0%; ECE: 1.7%; NLL: 0.0202


## With Laplace

In [18]:
# Last-layer Laplace approximation with a Kronecker-factored Hessian, fitted on
# the training data; the prior precision is then tuned via the marginal likelihood.
la = Laplace(
    model,
    'classification',
    subset_of_weights='last_layer',
    hessian_structure='kron',
)
la.fit(train_loader)
la.optimize_prior_precision(method='marglik')

In [21]:
# Laplace predictive: the wrapper's output is used directly as probabilities.
acc_laplace, ece_laplace, nll_laplace = predict(
    dataloader=test_loader, model=la, laplace=True
)
print(f"[Laplace] Acc.: {acc_laplace:.1%}; ECE: {ece_laplace:.1%}; NLL: {nll_laplace:.3}")


Test set: Accuracy: 9915/10000 (99%)

Acc.: 100.0%; ECE: 0.3%; NLL: 0.00351


# Subnetwork Laplace approximation

In [44]:
from laplace.baselaplace import FullLaplace
from laplace.curvature.backpack import BackPackGGN
from laplace.utils import ModuleNameSubnetMask

In [78]:
# One line per sub-module name — candidates for ModuleNameSubnetMask below.
print("\n".join(module_name for module_name, _ in model.named_modules()))


conv1
pool1
conv2
conv2_drop
pool2
fc1
fc2


In [45]:
# Subnetwork Laplace: only the parameters of the listed modules are treated
# probabilistically; all other weights stay at their MAP values.
SUBNET_MODULES = ["fc1"]
# Prior precision is fixed by hand (NOT optimised); a very small value means a
# very broad prior over the subnetwork weights.
SUBNET_PRIOR_PRECISION = 1e-5

subnetwork_mask = ModuleNameSubnetMask(model, module_names=SUBNET_MODULES)
subnetwork_mask.select()
subnetwork_indices = subnetwork_mask.indices

# NOTE(review): fc1 is Linear(32*4*4, 50), i.e. ~25.6k parameters, so a full
# Hessian over it is roughly 25.6k x 25.6k floats (several GB) — expect a
# slow, memory-hungry fit.
sub_laplace = Laplace(
    model,
    "classification",
    subset_of_weights="subnetwork",
    hessian_structure="full",
    subnetwork_indices=subnetwork_indices,
)
sub_laplace.fit(train_loader)
sub_laplace.prior_precision = torch.tensor([SUBNET_PRIOR_PRECISION])

# TODO: replace the hand-picked prior precision with
# `sub_laplace.optimize_prior_precision(method="marglik")`; do not tune it on
# `test_loader`, which would leak test data into the reported metrics.

start_laplace
step 2
step 3
step 4
fit




In [80]:
# Evaluate the subnetwork Laplace model on the test set. `predict` needs the
# loader AND the model and returns three metrics; they get their own names so
# the MAP results (acc_map, ece_map, nll_map) are not overwritten.
acc_sublaplace, ece_sublaplace, nll_sublaplace = predict(test_loader, sub_laplace, laplace=True)



Test set: Accuracy: 9903.0/10000 (99%)

Acc.: 100.0%; ECE: 4.0%; NLL: 0.0419


In [None]:
# Subnetwork Laplace predictive metrics on the test set.
acc_sublaplace, ece_sublaplace, nll_sublaplace = predict(
    dataloader=test_loader, model=sub_laplace, laplace=True
)
print(
    f"[Subnetwork Laplace] Acc.: {acc_sublaplace:.1%}; ECE: {ece_sublaplace:.1%}; NLL: {nll_sublaplace:.3}"
)