In [1]:
from laplace.baselaplace import FullLaplace
from laplace.curvature.backpack import BackPackGGN
import numpy as np
import torch

from laplace import Laplace, marglik_training
import torch
from torchvision import datasets, transforms
import torch.utils.data as data_utils
import matplotlib.pyplot as plt
import torchvision


In [2]:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True

In [3]:
import torch.nn as nn


In [4]:
import torch.nn.functional as F


In [13]:
config = {
    "num_classes":10,
    "kernel_size": 5,
    "channels":1,
    "filter_1_out" :16,
    "filter_2_out" :32,
    "padding" :0,
    "stride" :1, 
    "pool":2,
    "learning_rate": 0.001,
    "epochs": 20,
    "batch_size": 64,
    "crop_size":128
}


In [6]:
device = torch.device('cpu')

In [7]:


def compute_conv_dim(dim_size, kernel_size, padding, stride):
  # (I-F)+2*P/S +1
    return int((dim_size - kernel_size + 2 * padding) / stride + 1)

def compute_pool_dim(dim_size, kernel_size, stride):
  #(I-F)/S +1
  return int((dim_size - kernel_size) / stride + 1)

In [8]:
train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(root='.', train=True, download=True,
                        transform=transforms.Compose([
                            transforms.ToTensor(),
                            transforms.Normalize((0.1307,), (0.3081,))
                        ])), batch_size=config["batch_size"], shuffle=True, num_workers=2)

In [9]:
test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(root='.', train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ])), batch_size=config["batch_size"], shuffle=True, num_workers=2)

In [10]:
train_features, train_labels = next(iter(train_loader))
height = train_features.shape[2]
width = train_features.shape[3]
print(height,width)


28 28


In [11]:
class Net(nn.Module):
    def __init__(
        self,
        num_classes,
        channels,
        filter_1_out,
        filter_2_out,
        kernel_size,
        padding,
        stride,
        height,
        width,
        pool,
        parameterize,
    ):
        super(Net, self).__init__()
        self.num_classes = (num_classes,)
        self.channels = (channels,)
        self.filter_1_out = (filter_1_out,)
        self.filter_2_out = (filter_2_out,)
        self.kernel_size = (kernel_size,)
        self.padding = (padding,)
        self.stride = (stride,)
        self.height = (height,)
        self.width = (width,)
        self.pool = (pool,)
        self.parameterize = parameterize

        self.conv1 = nn.Conv2d(channels, filter_1_out, kernel_size)
        # evaluating image dimensions after first connvolution
        self.conv1_out_height = compute_conv_dim(
            height, kernel_size, padding, stride
        )
        self.conv1_out_width = compute_conv_dim(
            width, kernel_size, padding, stride
        )

        # first pooling
        self.pool1 = nn.MaxPool2d(pool, pool)
        # evaluating image dimensions after first pooling
        self.conv2_out_height = compute_pool_dim(
            self.conv1_out_height, pool, pool
        )
        self.conv2_out_width = compute_pool_dim(
            self.conv1_out_width, pool, pool
        )

        # Second Convolution
        self.conv2 = nn.Conv2d(filter_1_out, filter_2_out, kernel_size)
        # evaluating image dimensions after second convolution
        self.conv3_out_height = compute_conv_dim(
            self.conv2_out_height, kernel_size, padding, stride
        )
        self.conv3_out_width = compute_conv_dim(
            self.conv2_out_width, kernel_size, padding, stride
        )
        self.conv2_drop = nn.Dropout2d()

        # Second pooling
        self.pool2 = nn.MaxPool2d(pool, pool)
        # evaluating image dimensions after second pooling
        self.conv4_out_height = compute_pool_dim(
            self.conv3_out_height, pool, pool
        )
        self.conv4_out_width = compute_pool_dim(
            self.conv3_out_width, pool, pool
        )

        self.fc1 = nn.Linear(
            filter_2_out * self.conv4_out_height * self.conv4_out_width, 50
        )
        self.fc2 = nn.Linear(50, num_classes)

        

    def forward(self, x):

        
        # convolutional layer 1
        x = F.relu(self.pool1(self.conv1(x)))

        # convolutional layer 2
        x = F.relu(self.pool2(self.conv2_drop(self.conv2(x))))

        x = x.view(
            -1, self.filter_2_out[0] * self.conv4_out_height * self.conv4_out_width
        )

        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)

        return x

In [15]:
model = Net(
        config["num_classes"],
        config["channels"],
        config["filter_1_out"],
        config["filter_2_out"],
        config["kernel_size"],
        config["padding"],
        config["stride"],
        height,
        width,
        config["pool"],parameterize=False).to(device).eval()

In [16]:
# model_path = '/Users/georgioszefkilis/Bayesian_Deep_Learning/models/best_checkpoint.pth'
model_path = '/Users/georgioszefkilis/Bayesian_Deep_Learning/models/colab_best_Vanilla_MNIST_20.pth'
checkpoint = torch.load(model_path, map_location=device)
    # initialize state_dict from checkpoint to model
model.load_state_dict(checkpoint["state_dict"])

<All keys matched successfully>

In [17]:
targets = torch.cat([y for x, y in test_loader], dim=0).cpu()


In [18]:
train_labels

tensor([8, 3, 2, 4, 6, 9, 0, 9, 2, 9, 5, 4, 0, 3, 1, 2, 1, 9, 0, 2, 7, 5, 9, 4,
        3, 6, 1, 8, 1, 6, 9, 6, 1, 7, 6, 2, 7, 5, 2, 4, 7, 2, 3, 5, 6, 5, 4, 0,
        8, 8, 9, 5, 5, 5, 0, 8, 4, 4, 4, 1, 4, 1, 0, 1])

In [20]:
@torch.no_grad()
def predict(dataloader, model, laplace=False):
    py = []

    for x, _ in dataloader:
        if laplace:
            py.append(model(x.cpu()))
        else:
            py.append(torch.softmax(model(x.cpu()), dim=-1))

    return torch.cat(py).cpu()

In [21]:
import torch.distributions as dists


In [22]:
probs_map = predict(test_loader, model, laplace=False)


In [23]:
probs_map

tensor([[9.9998e-01, 4.3580e-13, 1.6771e-07,  ..., 2.1847e-12, 1.8537e-07,
         9.8797e-07],
        [1.7258e-09, 3.1709e-04, 9.8422e-01,  ..., 6.7594e-11, 4.6883e-09,
         8.3592e-13],
        [1.6016e-33, 3.6112e-19, 1.4754e-19,  ..., 3.6839e-17, 3.8276e-17,
         9.0575e-11],
        ...,
        [1.9967e-11, 1.1930e-07, 9.9996e-01,  ..., 3.0110e-05, 2.3017e-06,
         3.3882e-10],
        [6.2772e-17, 2.6773e-14, 3.0573e-13,  ..., 2.8562e-08, 4.6044e-09,
         9.9998e-01],
        [1.0000e+00, 1.0233e-13, 3.1913e-08,  ..., 1.3165e-13, 3.9375e-08,
         1.8443e-09]])

In [26]:
acc_map = (probs_map.argmax(-1) == targets).float().mean()


In [27]:
acc_map

tensor(0.0955)

In [25]:
torch_prob = torch.from_numpy(probs_map)
torch_target = torch.from_numpy(targets)

In [28]:
nll_map = -dists.Categorical(probs_map).log_prob(targets).mean()


In [29]:
nll_map

tensor(13.8170)

In [29]:

print(f'[MAP] Acc.: {acc_map:.1%}; NLL: {nll_map:.3}')

[MAP] Acc.: 10.1%; NLL: 13.8


In [30]:
la = Laplace(model, 'classification',
             subset_of_weights='last_layer',
             hessian_structure='kron')
la.fit(train_loader)
la.optimize_prior_precision(method='marglik')

In [32]:
probs_laplace = predict(test_loader, la, laplace=True)
acc_laplace = (probs_laplace.argmax(-1) == targets).float().mean()
#ece_laplace = ECE(bins=15).measure(probs_laplace.numpy(), targets.numpy())
nll_laplace = -dists.Categorical(probs_laplace).log_prob(targets).mean()

print(f'[Laplace] Acc.: {acc_laplace:.1%}; NLL: {nll_laplace:.3}')

[Laplace] Acc.: 10.0%; NLL: 11.5
