In [1]:
from laplace.baselaplace import FullLaplace
from laplace.curvature.backpack import BackPackGGN
import numpy as np
import torch

from laplace import Laplace, marglik_training
import torch
from torchvision import datasets, transforms
import torch.utils.data as data_utils
import matplotlib.pyplot as plt



In [4]:
import torch.nn as nn


In [7]:
import torch.nn.functional as F


In [56]:
model_path = '/Users/georgioszefkilis/Bayesian_Deep_Learning/models/best_checkpoint.pth'

In [5]:
config = {
    "kernel_size": 5,
    "channels":1,
    "filter_1_out" :16,
    "filter_2_out" :32,
    "padding" :0,
    "stride" :1, 
    "pool":2,
    "learning_rate": 0.01,
    "epochs": 50,
    "batch_size": 64,
    "crop_size":128
}


In [3]:


def compute_conv_dim(dim_size, kernel_size, padding, stride):
  # (I-F)+2*P/S +1
    return int((dim_size - kernel_size + 2 * padding) / stride + 1)

def compute_pool_dim(dim_size, kernel_size, stride):
  #(I-F)/S +1
  return int((dim_size - kernel_size) / stride + 1)

In [8]:
from numpy.lib import polynomial
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()


        self.conv1 = nn.Conv2d(config["channels"], config["filter_1_out"], config["kernel_size"])
        #evaluating image dimensions after first connvolution
        self.conv1_out_height = compute_conv_dim(28, config["kernel_size"], config["padding"], config["stride"])
        self.conv1_out_width = compute_conv_dim(28,  config["kernel_size"],  config["padding"],  config["stride"])


        #first pooling
        self.pool1 = nn.MaxPool2d(config["pool"], config["pool"])
        #evaluating image dimensions after first pooling
        self.conv2_out_height = compute_pool_dim(self.conv1_out_height, config["pool"], config["pool"])
        self.conv2_out_width = compute_pool_dim(self.conv1_out_width,  config["pool"],  config["pool"])
        
        
        #Second Convolution
        self.conv2 = nn.Conv2d(config["filter_1_out"], config["filter_2_out"], config["kernel_size"])
        #evaluating image dimensions after second convolution
        self.conv3_out_height = compute_conv_dim(self.conv2_out_height, config["kernel_size"], config["padding"], config["stride"])
        self.conv3_out_width = compute_conv_dim(self.conv2_out_width,  config["kernel_size"], config["padding"], config["stride"])
        self.conv2_drop = nn.Dropout2d()

        
        #Second pooling
        self.pool2 = nn.MaxPool2d(config["pool"], config["pool"])
        #evaluating image dimensions after second pooling
        self.conv4_out_height = compute_pool_dim(self.conv3_out_height, config["pool"], config["pool"])
        self.conv4_out_width = compute_pool_dim(self.conv3_out_width,  config["pool"], config["pool"])
        
        
        self.fc1 = nn.Linear(config["filter_2_out"]* self.conv4_out_height * self.conv4_out_width, 50)
        #print(self.fc1)
        self.fc2 = nn.Linear(50, 10)


        # Spatial transformer localization-network
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True)
        )
#apply laplace to the last linera layer for the first attempt
        # Regressor for the 3 * 2 affine matrix
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * 3 * 3, 32),#original
            #nn.Linear(10 * 28* 28, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2)
        )

        # Initialize the weights/bias with identity transformation
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    # Spatial transformer network forward function
    def stn(self, x):
        #print('x',x.size())
        xs = self.localization(x)

        #print('xs',xs.size())
        xs = xs.view(-1, 10 * 3 * 3) #original
        #xs = xs.view(-1, xs.size(0))

        #print('xs view',xs.size())
        theta = self.fc_loc(xs)

        #print('theta before view',theta.shape)
        theta = theta.view(-1, 2, 3)
        #print('theta',theta.shape)
        #print('size',x.size())


        grid = F.affine_grid(theta, x.size(),align_corners =True)
        x = F.grid_sample(x, grid)

        return x

    def forward(self, x):

        #print('input',x.size())
        # transform the input
        x = self.stn(x)

        #print('transform',x.size())
        # Perform the usual forward pass
        #convolutional layer 1
        x = F.relu(self.pool1(self.conv1(x)))
        #print('forward1',x.size())

        #convolutional layer 2
        x = F.relu(self.pool2(self.conv2_drop(self.conv2(x))))
        #print('forward2',x.size())

        #convolutional layer 3
        #x = F.relu(self.pool3(self.conv3_drop(self.conv3(x))))
        #print('forward3',x.size())

        #x = x.view(-1, 320) #original
        #print(self.conv3_out_height)
        #print(self.conv4_out_width)
        x = x.view(-1, config["filter_2_out"]* self.conv4_out_height * self.conv4_out_width)

        #print('flatten',x.size())

        x = F.relu(self.fc1(x))
        #print('forward4',x.size())

        x = F.dropout(x, training=self.training)
        #print('forward5',x.size())

        x = self.fc2(x)
        #print('forward6',x.size())

        return F.log_softmax(x, dim=1)


model = Net()

In [10]:
device = torch.device('cpu')

In [79]:
def load_checkpoint(filepath):
    checkpoint = torch.load(filepath,map_location=device)
    model = checkpoint['model']
    model.load_state_dict(checkpoint['state_dict'])
    for parameter in model.parameters():
        parameter.requires_grad = True
    
    model.eval()
    
    return model

In [80]:
model_load = load_checkpoint(model_path)

In [81]:
print(model_load)

Net(
  (conv1): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1))
  (conv2_drop): Dropout2d(p=0.5, inplace=False)
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=512, out_features=50, bias=True)
  (fc2): Linear(in_features=50, out_features=10, bias=True)
  (localization): Sequential(
    (0): Conv2d(1, 8, kernel_size=(7, 7), stride=(1, 1))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): ReLU(inplace=True)
    (3): Conv2d(8, 10, kernel_size=(5, 5), stride=(1, 1))
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): ReLU(inplace=True)
  )
  (fc_loc): Sequential(
    (0): Linear(in_features=90, out_features=32, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=32, out_features=6, 

In [54]:
#model.load_state_dict(torch.load(model_path,map_location=device))


In [12]:
model.eval()

Net(
  (conv1): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1))
  (conv2_drop): Dropout2d(p=0.5, inplace=False)
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=512, out_features=50, bias=True)
  (fc2): Linear(in_features=50, out_features=10, bias=True)
  (localization): Sequential(
    (0): Conv2d(1, 8, kernel_size=(7, 7), stride=(1, 1))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): ReLU(inplace=True)
    (3): Conv2d(8, 10, kernel_size=(5, 5), stride=(1, 1))
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): ReLU(inplace=True)
  )
  (fc_loc): Sequential(
    (0): Linear(in_features=90, out_features=32, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=32, out_features=6, 

In [18]:
train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(root='.', train=True, download=True,
                        transform=transforms.Compose([
                            transforms.ToTensor(),
                            transforms.Normalize((0.1307,), (0.3081,))
                        ])), batch_size=config["batch_size"], shuffle=True, num_workers=0)

In [16]:
test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(root='.', train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ])), batch_size=config["batch_size"], shuffle=True, num_workers=0)

In [82]:
targets = torch.cat([y for x, y in test_loader], dim=0).numpy()


In [30]:
@torch.no_grad()
def predict(dataloader, model, laplace=False):
    py = []

    for x, _ in dataloader:
        if laplace:
            py.append(model(x))
        else:
            py.append(torch.softmax(model(x), dim=-1))

    return torch.cat(py).cpu().numpy()

In [20]:
import torch.distributions as dists


In [83]:
probs_map = predict(test_loader, model_load, laplace=False)




In [84]:
probs_map

array([[4.18608527e-16, 1.28331623e-10, 1.26332495e-12, ...,
        1.17467425e-09, 2.65577421e-11, 4.58735556e-08],
       [2.97033861e-16, 3.29422600e-11, 1.25809736e-12, ...,
        5.78506854e-10, 5.44891779e-12, 3.53229179e-09],
       [1.00000000e+00, 1.61064619e-14, 3.30366916e-08, ...,
        1.13900195e-10, 5.40308388e-11, 1.88025498e-10],
       ...,
       [1.99723682e-10, 9.99999166e-01, 6.77935304e-07, ...,
        1.36583381e-07, 3.01092129e-08, 4.78407491e-09],
       [1.89883238e-08, 9.99910593e-01, 5.15634456e-05, ...,
        3.60269297e-07, 1.10769488e-05, 5.13353768e-07],
       [6.35115853e-08, 1.53136934e-10, 3.72874648e-10, ...,
        1.44481188e-13, 1.52956403e-08, 1.15241983e-11]], dtype=float32)

In [85]:
acc_map = (probs_map.argmax(-1) == targets).mean()


In [86]:
acc_map

0.1028

In [87]:
torch_prob = torch.from_numpy(probs_map)
torch_target = torch.from_numpy(targets)

In [88]:
nll_map = -dists.Categorical(torch_prob).log_prob(torch_target).mean()


In [89]:
nll_map

tensor(13.0784)

In [90]:

print(f'[MAP] Acc.: {acc_map:.1%}; NLL: {nll_map:.3}')

[MAP] Acc.: 10.3%; NLL: 13.1


In [91]:
la = Laplace(model_load, 'classification',
             subset_of_weights='last_layer',
             hessian_structure='kron')
la.fit(train_loader)
la.optimize_prior_precision(method='marglik')

AssertionError: BackPACK extension expects a backpropagation quantity but it is None. Module: Linear(in_features=50, out_features=10, bias=True), Extension: <backpack.extensions.secondorder.hbp.KFLR object at 0x7f9a07fed0d0>.