# Self stabilising priors for robust Bayesian deep learning

This notebook is designed to demonstrate the principle to give the idea and a basic implementation to make it accessible. We demonstrate accelerated training with an experiment on MNIST.


This notebook is based on  https://github.com/senya-ashukha/sparse-vd-pytorch/blob/master/svdo-solution.ipynb

<img src="intuition1.png" />

### Installation and to run on google colab  


In [1]:
# Logger
#!pip install tabulate -q
#from google.colab import files
#src = list(files.upload().values())[0]
#open('logger.py','wb').write(src)

In [2]:
from logger import Logger

# Implementation

In [3]:
import math
import torch
import numpy as np
import pandas as pd

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.nn import Parameter
from torchvision import datasets, transforms

In [4]:
device = 'cpu'
if torch.cuda.is_available():
    use_cuda = True
    device = 'cuda'

In [5]:
# hyperparameters
width = 256
input_shape = 28*28
output_size = 10
batch_size = 100
init_var = 0.001

n_samples = 20
kl_weight = 1.0
epochs = 10

## Our proposed self stabilising layer for Bayesian Neural Network
For the stabilising prior to be effective we sample from a reparametrised, $\tilde{q}(W)$, which is the product of the current posterior, $q(W)$, and the prior, $p(W)$. This allows the the influence of the prior on the forward pass so we can propagate cleaner signals. 

The other main differences between this layer and a normal Bayesian layer is the update prior function, because our prior adapts based on the current settings of the weights to stabilise the signal.

In [6]:
class SelfStabilisingLayer(nn.Module):
    '''
    Iteratively updating self stabilising prior.
    Fully factorised Gaussian priors and posteriors.
    Local reparametrisation trick.
    '''

    def __init__(self, in_features, out_features, init_var=0.001, prior_var=0.02):
        super(SelfStabilisingLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.W = Parameter(torch.Tensor(out_features, in_features))
        self.log_sigma = Parameter(torch.Tensor(out_features, in_features))
        self.bias = Parameter(torch.Tensor(1, out_features))

        # initialisation values of parameters
        self.init_var = np.log(init_var)
        self.reset_parameters()
        

    def reset_parameters(self):

        self.log_sigma.data.fill_(self.init_var / 2)
        self.bias.data.zero_()

        # He initialisation
        init = np.sqrt(2 / self.in_features)
        self.W.data.normal_(0, init)
        

    def forward(self, x):

        # local reparametrisation trick with new parameters from q_tilde(W)
        # i.e. use self.new_mu and self.new_sigma_sq
        lrt_mean = F.linear(x, self.new_mu) + self.bias
        lrt_std = torch.sqrt(F.linear(x * x, self.new_sigma_sq) + 1e-8)
        eps = lrt_std.data.new(lrt_std.size()).normal_()
        pre_activation = lrt_mean + lrt_std * eps

        return pre_activation
    

    def update_prior(self):
        
        #####################################################################
        # Main difference between normal BNN and Stabilising prior
        #####################################################################

        # Sum of all incoming nodes to specific hidden units
        mu_L = torch.sum(self.W, dim=1)
        sig_sq_L = torch.sum(torch.exp(self.log_sigma * 2.0), dim=1)

        # PRIOR VARIANCE        
        gamma = 2- (1-1/math.pi) *mu_L * mu_L
        self.prior_var = (gamma * sig_sq_L)/(sig_sq_L - gamma)
        self.prior_var = self.prior_var / self.in_features

        # shared prior across all weights feeding into the same hidden unit
        self.prior_var = self.prior_var.expand(self.in_features, self.prior_var.shape[0]).t()
        self.prior_var = torch.abs(self.prior_var)
        
        
        # PRIOR MEAN (mean preserving)
        self.prior_mean = self.W

        
        # PRODUCT, set the parameters from which we will be sampling q_tilde(W)
        self.new_mu, self.new_sigma_sq = multipy_gaussian(self.W, self.prior_mean, 
                                                            torch.exp(self.log_sigma * 2.0), self.prior_var)

        
    def kl_reg(self):

        # cross entropy term
        sigma_sq = torch.exp(self.log_sigma.view(-1) * 2)
        new_sigma_sq = torch.exp(self.new_sigma_sq.view(-1))
        pi = math.pi

        H = 0.5 * torch.log(2 * pi * sigma_sq) + (new_sigma_sq / sigma_sq)
        H = torch.sum(H)

        return H


## Normal Bayesian Neural Network layer
Non-conjugate Gaussian prior and Gaussian posterior. We also make use of the Local Reparametrisation trick.

In [7]:
class LocalReparametrisationLayer(nn.Module):
    '''
    Doubly stochastic Variational Bayes for non-conjugate inference.
    Fully factorised Gaussian priors and posteriors.
    Local reparametrisation trick.
    '''

    def __init__(self, in_features, out_features, bias=True, init_var=0.001, prior_var=0.0001):
        super(LocalReparametrisationLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.W = Parameter(torch.Tensor(out_features, in_features))
        self.log_sigma = Parameter(torch.Tensor(out_features, in_features))
        self.bias = Parameter(torch.Tensor(1, out_features))

        # add priors 
        self.prior_mean = torch.Tensor([0]).to(device)
        self.prior_var = torch.Tensor([prior_var]).to(device)

        # initialisation values of parameters
        self.init_var = np.log(init_var)
        self.reset_parameters()


    def reset_parameters(self):

        self.log_sigma.data.fill_(self.init_var / 2)
        self.bias.data.zero_()

        # critical initialisation for normal Bayesian Neural networks
        init = np.sqrt(np.abs((2 - self.in_features * np.exp(self.init_var)) / self.in_features))
        self.W.data.normal_(0, init)

    def forward(self, x):

        # local reparametrisation trick 
        lrt_mean = F.linear(x, self.W) #+ self.bias
        lrt_std = torch.sqrt(F.linear(x * x, torch.exp(self.log_sigma * 2.0)) + 1e-8)
        eps = lrt_std.data.new(lrt_std.size()).normal_()
        pre_activation = lrt_mean + lrt_std * eps

        if self.training:
            self.signal_variance = pre_activation.var(dim=1)[0].data.cpu().numpy()          

        return pre_activation
    
    def kl_reg(self):

        # KL divergence 
        mean = self.W.view(-1)
        sigma = torch.exp(self.log_sigma).view(-1)

        prior_sigma = torch.sqrt(self.prior_var).view(-1)
        prior_mean = self.prior_mean.view(-1)

        p = torch.distributions.normal.Normal(prior_mean, prior_sigma)
        q = torch.distributions.normal.Normal(mean, sigma)

        kl = torch.distributions.kl.kl_divergence(q, p)

        kl = torch.sum(kl)
        return kl

### Generic ReLU network architecture where we specify the type of layers

In [8]:
# Define a simple fully connected ReLU Network
class Net(nn.Module):
    def __init__(self, layer_type, input_size, width=256, init_var=0.001):
        super(Net, self).__init__()
        self.fc_in = layer_type(input_size, width, init_var=init_var)
        self.fc_h1 = layer_type(width, width, init_var=init_var)
        self.fc_h2 = layer_type(width, width, init_var=init_var)
        self.fc_h3 = layer_type(width, width, init_var=init_var)
        self.fc_out = layer_type(width,  output_size, init_var=init_var)

    def forward(self, x):
        x = F.relu(self.fc_in(x))
        x = F.relu(self.fc_h1(x))
        x = F.relu(self.fc_h2(x))
        x = F.relu(self.fc_h3(x))
        x = F.log_softmax(self.fc_out(x), dim=1)
        return x
    
    
    def update_priors(self):
        if hasattr(self.fc_h1, 'update_prior'):
            self.fc_in.update_prior()
            self.fc_out.update_prior()
            self.fc_h1.update_prior()
            self.fc_h2.update_prior()
            self.fc_h3.update_prior()


### Loss function

In [9]:
# Define New Loss Function -- SGVLB 
class SGVLB(nn.Module):
    def __init__(self, net, train_size, batch_size):
        super(SGVLB, self).__init__()
        self.train_size = train_size
        self.batch_size = batch_size
        self.num_batches = batch_size / train_size
        self.net = net

    def forward(self, output, target, kl_weight=1.0):
        assert not target.requires_grad
        kl = 0.0
        for module in self.net.children():
            if hasattr(module, 'kl_reg'):
                kl = kl + module.kl_reg()
        kl = kl * self.num_batches
        kl = kl / (self.batch_size*self.train_size)
        #kl = kl_weight * kl
        return F.nll_loss(output, target, reduction='mean') + kl

### Data loaders

In [10]:
def get_mnist(batch_size):
    trsnform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
        transform=trsnform), batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, download=True,
        transform=trsnform), batch_size=batch_size, shuffle=True)

    return train_loader, test_loader

### Utility functions

In [11]:
def multipy_gaussian(mean1, mean2, var1, var2):

    # calculate variance
    new_var = 1 / ((1 / var1) + (1 / var2))

    # calculate mean
    new_mu = new_var * (mean2 / var2 + mean1 / var1)

    return new_mu, new_var

## Training loop

In [12]:
# training loop with logging
def train(model, epochs, optimizer, train_loader, test_loader, loss_fn, logger):
    
    for epoch in range(1, epochs + 1):

        model.train() 
        train_loss, train_acc = 0, 0

        for batch_idx, (data, target) in enumerate(train_loader):

            #####################################################################
            # Training
            #####################################################################

            optimizer.zero_grad()
            if use_cuda:
                data, target = data.cuda(), target.cuda()
            data = data.view(-1, input_shape)
              
            # update priors
            model.update_priors()

            # forward prop
            output = model(data)
            pred = output.data.max(1)[1]

            # backprop
            loss = loss_fn(output, target, kl_weight)
            loss.backward(retain_graph=True)
            optimizer.step()

            # training loss and accuracy (training accuracy does not reflect ensemble)
            train_loss += loss.item()
            train_acc += torch.sum(pred.eq(target))

        #####################################################################
        # Evaluate on test set
        #####################################################################
        model.eval()

        test_loss, avg_test_acc, total_brier = 0, 0, 0
        total_90, total_70, total_50 = 0, 0, 0
        for tbatch_idx, (data, target) in enumerate(test_loader):
            # prep
            if use_cuda:
                data, target = data.cuda(), target.cuda()
            data = data.view(-1, input_shape)

            # average over weights with samples of different parameters
            probs, test_loss = 0, 0
            for i in range(n_samples):
                output = model(data)
                probs += output.data
                test_loss += float(loss_fn(output, target).item())
            mean_probs = probs / n_samples
            avg_test_loss = test_loss / n_samples
            pred = torch.argmax(mean_probs, dim=1)
            avg_test_acc += torch.sum(pred == target)

        

        #####################################################################
        # Logging
        #####################################################################
        # log training and test loss and accuracy
        logger.add_scalar(epoch, 'trlos', train_loss)
        logger.add_scalar(epoch, 'telos', avg_test_loss)

        logger.add_scalar(epoch, 'tracc', (float(train_acc) / (batch_size * (batch_idx + 1)) * 100))
        logger.add_scalar(epoch, 'teacc', float(avg_test_acc) / len(test_loader.dataset) * 100)

        logger.iter_info()
    logger.save()

## Experiment 1
A very simple example on MNIST to demonstrate accelerated training

In [13]:
width = 512
init_var = 0.01
epochs = 10
input_shape = 28*28

#### 1. Normal BNN

In [14]:
model = Net(LocalReparametrisationLayer, input_shape, width=width, init_var=init_var)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
logger = Logger(name='LRTNet')

train_loader, test_loader = get_mnist(batch_size=batch_size)
loss_fn = SGVLB(model, len(train_loader.dataset), batch_size)

if device == 'cuda':
    model.cuda()
    
train(model, epochs, optimizer, train_loader, test_loader, loss_fn, logger)

  step    trlos    telos    tracc    teacc
------  -------  -------  -------  -------
     1  2010.57     1.87    15.36    23.09
     2   950.00     1.28    41.93    54.99
     3   728.50     1.22    55.36    62.84
     4   458.09     0.50    74.33    91.99
     5   233.44     0.24    90.56    93.72
     6   189.43     0.30    92.45    94.55
     7   164.02     0.32    93.33    95.32
     8   140.17     0.17    94.41    95.30
     9   124.87     0.28    95.07    95.85
    10   113.41     0.28    95.53    96.27
The log/output have been saved to: ./logs//LRTNet-otb-05-22-12:52 + .csv/.out/


#### 2. Self stabilising prior

In [15]:
kl_weight = 1.0

model = Net(SelfStabilisingLayer, input_shape, width=width, init_var=init_var)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
logger = Logger(name='StabilisedNet')

train_loader, test_loader = get_mnist(batch_size=batch_size)
loss_fn = SGVLB(model, len(train_loader.dataset), batch_size)

if device == 'cuda':
    model.cuda()
    
train(model, epochs, optimizer, train_loader, test_loader, loss_fn, logger)

  step    trlos    telos    tracc    teacc
------  -------  -------  -------  -------
     1   346.23     0.24    84.03    94.91
     2   147.74     0.14    93.84    96.65
     3   116.52     0.15    95.38    97.07
     4    95.47     0.24    96.46    97.29
     5    84.51     0.14    96.97    97.61
     6    73.30     0.15    97.53    97.45
     7    66.33     0.07    97.76    97.83
     8    61.38     0.14    98.07    97.83
     9    56.38     0.09    98.21    98.08
    10    53.54     0.12    98.28    97.82
The log/output have been saved to: ./logs//StabilisedNet-xne-05-22-12:55 + .csv/.out/
