In [None]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.distributions import Normal
from torch.utils.data import Subset
from torch.distributions import Categorical, Normal, StudentT
from torch.optim import SGD
from torch.optim.lr_scheduler import PolynomialLR

import torchvision
from torchvision import datasets, transforms

import torchmetrics
from torchmetrics.functional import calibration_error

import math
import matplotlib.pyplot as plt
import random
from collections import deque, OrderedDict
from tqdm import trange
import tqdm
import copy
import typing
from typing import Sequence, Optional, Callable, Tuple, Dict, Union

Load MNIST dataset

In [None]:
# MNIST train/test splits, downloaded to ./data on first use.
# ToTensor turns each image into a float tensor scaled to [0, 1]; it is applied
# only when samples are fetched by indexing the dataset, not to the raw `.data`.
transform = transforms.Compose([transforms.ToTensor()])
mnist_options = dict(root='./data', download=True, transform=transform)
train_set = datasets.MNIST(train=True, **mnist_options)
test_set = datasets.MNIST(train=False, **mnist_options)

In [None]:
# SUBSAMPLE IF NEEDED

# Draw a fixed-size random subset of the training set. A dedicated seeded RNG
# makes the subset reproducible across kernel restarts without touching the
# global `random` state; set SUBSAMPLE_SEED = None for a fresh subset each run.
n_subsamples_train = 2000  # size of subset
SUBSAMPLE_SEED = 0
sub_train_idx = random.Random(SUBSAMPLE_SEED).sample(range(len(train_set)), n_subsamples_train)
sub_train_set = Subset(train_set, sub_train_idx)


Priors

In [None]:
# Framework for Priors

class Prior:
    """Interface for priors over network weights.

    Subclasses implement `sample` and `log_likelihood`. The base versions raise
    NotImplementedError so a missing override fails loudly instead of
    silently returning None.
    """
    def __init__(self):
        pass

    def sample(self, n):
        """Draw `n` values from the prior."""
        raise NotImplementedError(f"{type(self).__name__} does not implement sample()")

    def log_likelihood(self, values):
        """Return the log prior density of `values` (a scalar)."""
        raise NotImplementedError(f"{type(self).__name__} does not implement log_likelihood()")
    

In [None]:
# Gaussian Prior

class IsotropicGaussian(Prior):
    """Independent N(mean, std^2) prior on every weight."""
    def __init__(self, mean=0, std=1):
        super(IsotropicGaussian, self).__init__()
        self.mean = mean
        self.std = std

    def sample(self, n):
        """Draw `n` i.i.d. samples as a NumPy array."""
        return np.random.normal(self.mean, self.std, size=n)

    def log_likelihood(self, weights):
        """Summed log-density of `weights` under N(mean, std^2).

        `torch.as_tensor` returns an input tensor as-is (keeping its autograd
        graph), so passing live network parameters lets the prior contribute
        to the gradient; `torch.tensor` would copy and detach it. NumPy
        arrays and lists are still converted.
        """
        return Normal(self.mean, self.std).log_prob(torch.as_tensor(weights)).sum()


In [None]:
# StudentT Prior

class StudentTPrior(Prior):
    """
    Student-T prior applied independently to every weight.

    Args:
        df: degrees of freedom.
        loc, scale: location and scale of the distribution.
        Temperature: the summed log-density is divided by this value.
    """
    def __init__(self, df=10, loc=0, scale=1, Temperature: float= 1.0):
        super().__init__()
        self.df = df
        self.loc = loc
        self.scale = scale
        self.Temperature = Temperature

    def log_likelihood(self, values) -> torch.Tensor:
        """Summed Student-T log-density of `values`, divided by `Temperature`.

        `torch.as_tensor` returns an input tensor as-is (keeping its autograd
        graph), so passing live network parameters lets the prior contribute
        to the gradient. NumPy arrays and lists are still converted.
        """
        dist = StudentT(self.df, self.loc, self.scale)
        return dist.log_prob(torch.as_tensor(values)).sum() / self.Temperature

    def sample(self, n):
        """Draw `n` i.i.d. samples as a tensor of shape (n,)."""
        return StudentT(self.df, self.loc, self.scale).sample((n,))

Base Networks

In [None]:
# Fully Connected Neural Network (Architecture: Fortuin et al. (2021))

class FullyConnectedNN(nn.Module):
    """MLP: `hidden_layers` ReLU layers of `hidden_units` each, then a linear output.

    Inputs are flattened to (-1, in_features). The output is unnormalized class
    logits of shape (batch, out_features).
    """
    def __init__(self, in_features = 28*28, out_features = 10, hidden_units = 100, hidden_layers = 2):
        super().__init__()
        # Kept so forward() flattens to the configured width, not a fixed 28*28.
        self.in_features = in_features

        # Input to first layer
        self.hidden_layers = nn.ModuleList()
        self.hidden_layers.append(nn.Linear(in_features, hidden_units))

        # Remaining hidden layers
        for _ in range(hidden_layers - 1):
            self.hidden_layers.append(nn.Linear(hidden_units, hidden_units))

        # Output layer
        self.output_layer = nn.Linear(hidden_units, out_features)

    def forward(self, x):
        x = x.reshape(-1, self.in_features)
        for layer in self.hidden_layers:
            x = F.relu(layer(x))
        logits = self.output_layer(x)
        return logits
        

In [None]:
# Convolutional Neural Network (Architecture: Fortuin et al. (2021))

class ConvolutionalNN(nn.Module):
    """Two 3x3 conv layers (64 channels, each ReLU + 2x2 max-pool) and a linear head.

    Expects inputs reshapeable to (-1, 1, 28, 28) and returns unnormalized
    class logits of shape (batch, 10); no softmax is applied here.
    """
    def __init__(self):
        super(ConvolutionalNN, self).__init__()
        # 1 -> 64 channels; padding=1 keeps the spatial size at 28x28
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, padding=1)
        # 64 -> 64 channels
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
        # Two 2x2 poolings shrink 28x28 to 7x7
        self.fc1 = nn.Linear(64 * 7 * 7, 10)

    def forward(self, x):
        # Restore the channel dimension: (batch, 1, 28, 28)
        x = x.view(-1, 1, 28, 28)
        # ReLU, then 2x2 max-pool. Max-pooling non-negative values keeps them
        # non-negative, so a second ReLU after the pool would be a no-op.
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)   # (batch, 64, 14, 14)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)   # (batch, 64, 7, 7)
        # Flatten and map linearly to class logits
        x = x.view(-1, 64 * 7 * 7)
        logits = self.fc1(x)
        return logits

Optimizer SGLD

In [None]:
# From https://github.com/ratschlab/bnn_priors/blob/main/bnn_priors/mcmc/sgld.py

def dot(a, b):
    """Inner product of two tensors viewed as flat vectors, as a Python number."""
    flat_a, flat_b = a.view(-1), b.view(-1)
    return torch.dot(flat_a, flat_b).item()


class Fortuin_SGLD(torch.optim.Optimizer):
    """SGLD with momentum, preconditioning and diagnostics from Wenzel et al. 2020.
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate
        num_data (int): the number of data points in this learning task
            (must be positive; every step divides by it)
        momentum (float): momentum factor (default: 0)
        temperature (float): Temperature for tempering the posterior.
                             temperature=0 corresponds to SGD with momentum.
        rmsprop_alpha: decay for the moving average of the squared gradients
        rmsprop_eps: the regularizer parameter for the RMSProp update
        raise_on_no_grad (bool): whether to complain if a parameter does not
                                 have a gradient
        raise_on_nan: whether to complain if a gradient is not all finite.

    Per step, with beta = momentum factor, T = temperature and M the
    per-parameter preconditioner factor:
        m <- beta * m - sqrt(lr * num_data) * M * grad + sqrt(2 * (1 - beta) * T) * noise
        p <- p + sqrt(lr / num_data) * M * m
    With beta = 0 and M = 1 this is p <- p - lr * grad + sqrt(2 * T * lr / num_data) * noise,
    i.e. Langevin dynamics at temperature T on a potential U with step
    lr / num_data when the loss being differentiated is U / num_data.
    """
    def __init__(self, params: Sequence[Union[torch.nn.Parameter, Dict]], lr: float,
                 num_data: int, momentum: float=0, temperature: float=1.,
                 rmsprop_alpha: float=0.99, rmsprop_eps: float=1e-8,  # Wenzel et al. use 1e-7
                 raise_on_no_grad: bool=True, raise_on_nan: bool=False):
        # num_data must be strictly positive: `_update_group_fn` divides by it.
        assert lr >= 0 and num_data > 0 and momentum >= 0 and temperature >= 0
        defaults = dict(lr=lr, num_data=num_data, momentum=momentum,
                        rmsprop_alpha=rmsprop_alpha, rmsprop_eps=rmsprop_eps,
                        temperature=temperature)
        super(Fortuin_SGLD, self).__init__(params, defaults)
        self.raise_on_no_grad = raise_on_no_grad
        self.raise_on_nan = raise_on_nan
        # OK to call this one, but not `sample_momentum`, because
        # `update_preconditioner` uses no random numbers.
        self.update_preconditioner()
        self._step_count = 0  # keep the `torch.optim.scheduler` happy

    def _preconditioner_default(self, state, p) -> float:
        try:
            return state['preconditioner']
        except KeyError:
            v = state['preconditioner'] = 1.
            return v

    def delta_energy(self, a, b) -> float:
        return math.inf

    @torch.no_grad()
    def sample_momentum(self, keep=0.0):
        "Sample the momenta for all the parameters"
        assert 0 <= keep and keep <= 1.
        if keep == 1.:
            return
        for group in self.param_groups:
            std = math.sqrt(group['temperature']*(1-keep))
            for p in group['params']:
                if keep == 0.0:
                    self.state[p]['momentum_buffer'] = torch.randn_like(p).mul_(std)
                else:
                    self.state[p]['momentum_buffer'].mul_(math.sqrt(keep)).add_(torch.randn_like(p), alpha=std)

    @torch.no_grad()
    def step(self, closure: Optional[Callable[..., torch.Tensor]]=None,
             calc_metrics=True, save_state=False):
        assert save_state is False
        return self._step_internal(self._update_group_fn, self._step_fn,
                                   closure, calc_metrics=calc_metrics)
    initial_step = step

    @torch.no_grad()
    def final_step(self, closure: Optional[Callable[..., torch.Tensor]]=None,
                   calc_metrics=True, save_state=False):
        assert save_state is False
        return self._step_internal(self._update_group_fn, self._step_fn,
                                   closure, calc_metrics=calc_metrics,
                                   is_final=True)


    def _step_internal(self, update_group_fn, step_fn, closure, **step_fn_kwargs):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        try:
            for group in self.param_groups:
                update_group_fn(group)
                for p in group['params']:
                    if p.grad is None:
                        if self.raise_on_no_grad:
                            raise RuntimeError(
                                f"No gradient for parameter with shape {p.shape}")
                        continue
                    if self.raise_on_nan and not torch.isfinite(p.grad).all():
                        raise ValueError(
                            f"Gradient of shape {p.shape} is not finite: {p.grad}")
                    step_fn(group, p, self.state[p], **step_fn_kwargs)

        except KeyError as e:
            if e.args[0] == "momentum_buffer":
                raise RuntimeError("No 'momentum_buffer' stored in state. "
                                   "Perhaps you forgot to call `sample_momentum`?")
            raise e
        return loss

    def _update_group_fn(self, g):
        g['hn'] = math.sqrt(g['lr'] * g['num_data'])
        g['h'] = math.sqrt(g['lr'] / g['num_data'])
        g['noise_std'] = math.sqrt(2*(1 - g['momentum']) * g['temperature'])

    def _step_fn(self, group, p, state, calc_metrics=True, is_final=False):
        """if is_final, do not change parameters or momentum"""
        M_rsqrt = self._preconditioner_default(state, p)
        d = p.numel()

        # Update the momentum with the gradient
        if group['momentum'] > 0:
            momentum = state['momentum_buffer']
            if calc_metrics:
                # NOTE: the momentum is from the previous time step
                state['est_temperature'] = dot(momentum, momentum) / d
            if not is_final:
                momentum.mul_(group['momentum']).add_(p.grad, alpha=-group['hn']*M_rsqrt)
        else:
            # Without momentum, the "momentum" is just this step's scaled
            # gradient. It is a fresh tensor (no state is touched), so build it
            # unconditionally: the metric below needs it on the final step too,
            # where it would otherwise be an unbound local.
            momentum = p.grad.detach().mul(-group['hn']*M_rsqrt)
            if calc_metrics:
                # TODO: make the momentum be from the previous time step
                state['est_temperature'] = dot(momentum, momentum) / d

        if not is_final:
            # Add noise to momentum
            if group['temperature'] > 0:
                momentum.add_(torch.randn_like(momentum), alpha=group['noise_std'])

        if calc_metrics:
            # NOTE: p and p.grad are from the same time step
            state['est_config_temp'] = dot(p, p.grad) * (group['num_data']/d)

        if not is_final:
            # Take the gradient step
            p.add_(momentum, alpha=group['h']*M_rsqrt)

            # RMSProp moving average
            alpha = group['rmsprop_alpha']
            state['square_avg'].mul_(alpha).addcmul_(p.grad, p.grad, value=1 - alpha)

    @torch.no_grad()
    def update_preconditioner(self):
        """Updates the preconditioner for each parameter `state['preconditioner']` using
        the estimated `state['square_avg']`.
        """
        precond = OrderedDict()
        min_s = math.inf

        for group in self.param_groups:
            eps = group['rmsprop_eps']
            for p in group['params']:
                state = self.state[p]
                try:
                    square_avg = state['square_avg']
                except KeyError:
                    square_avg = state['square_avg'] = torch.ones_like(p)

                precond[p] = square_avg.mean().item() + eps
                min_s = min(min_s, precond[p])

        for p, new_M in precond.items():
            # ^(1/2) to form the preconditioner,
            # ^(-1/2) because we want the preconditioner's inverse square root.
            self.state[p]['preconditioner'] = (new_M / min_s)**(-1/4)

Bayesian Neural Network by MCMC

In [None]:
class BayesianNN:
    """Bayesian neural network sampled with SGLD (`Fortuin_SGLD`) under a weight prior.

    `train` runs `num_epochs` epochs over `dataset_train`. After `burn_in`
    epochs, a deep copy of the network is stored every `sample_interval`
    epochs, keeping only the newest `max_size` copies. Predictions average the
    softmax outputs of the stored copies.

    Args:
        dataset_train: map-style dataset yielding (input, label) pairs.
        network: nn.Module returning class logits; it is trained in place.
        prior: object whose `log_likelihood(flat_params)` returns the summed
            log prior density as a scalar tensor.
        num_epochs: number of passes over `dataset_train`.
        max_size: maximum number of stored posterior samples.
        burn_in: number of initial epochs whose states are not stored.
        lr: sampler learning rate. The loss is the potential averaged over the
            N training points, so lr acts like an SGD step size on the
            per-datapoint loss (Langevin step lr / N on the full potential).
        sample_interval: store a sample every this many epochs after burn-in.
        temperature: posterior temperature for the sampler (1.0 = untempered).
    """
    def __init__(self, dataset_train, network, prior,
                 num_epochs = 300, max_size = 100, burn_in = 100, lr = 1e-3, sample_interval = 1,
                 temperature = 1.0):

        # Hyperparameters and general parameters
        self.learning_rate = lr
        self.num_epochs = num_epochs
        self.burn_in = burn_in
        self.sample_interval = sample_interval
        self.max_size = max_size
        self.temperature = temperature

        self.batch_size = 128
        self.print_interval = 50

        # Data Loader
        self.data_loader = DataLoader(dataset_train, batch_size=self.batch_size, shuffle=True)
        self.sample_size = len(dataset_train)

        # Prior over the flattened network parameters
        self.prior = prior

        # Network whose parameters are sampled (trained in place)
        self.network = network

        # num_data is the number of *training points*, not the batch size.
        # Together with the per-datapoint loss in `train`, this makes the
        # injected noise consistent with the gradient step at `temperature`.
        self.optimizer = Fortuin_SGLD(self.network.parameters(), lr=self.learning_rate,
                                      num_data=self.sample_size, temperature=self.temperature)

        # Scheduler for polynomially decreasing learning rates (stepped once per epoch)
        self.scheduler = PolynomialLR(self.optimizer, total_iters = self.num_epochs, power = 0.5)

        # Newest `max_size` posterior samples; appending to a full deque drops the oldest
        self.model_sequence = deque(maxlen=self.max_size)

    def train(self):
        """Run SGLD for `num_epochs` epochs, storing posterior samples after burn-in."""
        print('Training Modelihno')

        self.network.train()
        progress_bar = trange(self.num_epochs)

        N = self.sample_size

        for epoch in progress_bar:
            epochs_done = epoch + 1

            for batch_idx, (batch_x, batch_y) in enumerate(self.data_loader):
                self.network.zero_grad()

                # Perform forward pass
                logits = self.network(batch_x)

                # Log prior of the *live* parameters. Concatenating the
                # parameters themselves (not a detached state_dict / NumPy
                # copy) keeps the autograd graph, so a differentiable prior
                # contributes to the gradient.
                param_flat = torch.cat([p.reshape(-1) for p in self.network.parameters()])
                log_prior = self.prior.log_likelihood(param_flat)

                # Per-datapoint potential U / N, where U = N * (mean NLL) - log_prior.
                # The minibatch mean NLL is an unbiased estimate of the full-data
                # NLL divided by N; this is the scaling Fortuin_SGLD expects
                # when num_data = N.
                mean_nll = F.cross_entropy(logits, batch_y)
                loss = mean_nll - log_prior / N

                # Backpropagate to get the gradients
                loss.backward()

                # SGLD update of the weights
                self.optimizer.step()

                # Update metrics according to print_interval (evaluated after the step)
                if batch_idx % self.print_interval == 0:
                    with torch.no_grad():
                        logits = self.network(batch_x)
                        batch_nll = F.cross_entropy(logits, batch_y)
                        batch_acc = (logits.argmax(dim=1) == batch_y).float().mean()
                    progress_bar.set_postfix(loss=loss.item(), acc=batch_acc.item(),
                                             nll_loss=N * batch_nll.item(),
                                             log_prior_normalized=-log_prior.item() / param_flat.numel(),
                                             lr=self.optimizer.param_groups[0]['lr'])

            # Decrease lr based on scheduler
            self.scheduler.step()

            # Save a model sample once past burn-in, every sample_interval epochs
            if epochs_done > self.burn_in and epochs_done % self.sample_interval == 0:
                self.model_sequence.append(copy.deepcopy(self.network))

    @torch.no_grad()
    def predict_probabilities(self, x: torch.Tensor) -> torch.Tensor:
        """Posterior predictive: average softmax output over the stored samples.

        Each stored copy is evaluated directly, so `self.network` keeps its
        current weights.

        Args:
            x: inputs accepted by the network's forward.
        Returns:
            Tensor of shape (len(x), num_classes) whose rows sum to 1.
        Raises:
            RuntimeError: if no samples have been stored (e.g. `train` was not
                run, or num_epochs <= burn_in).
        """
        if len(self.model_sequence) == 0:
            raise RuntimeError("No posterior samples stored; call train() first "
                               "(and make sure num_epochs > burn_in).")

        summed = None
        for model in self.model_sequence:
            model.eval()
            probs = F.softmax(model(x), dim=1)
            summed = probs if summed is None else summed + probs

        # Normalize the summed predictions to get the average prediction
        estimated_probability = summed / len(self.model_sequence)

        assert estimated_probability.shape[0] == x.shape[0]
        return estimated_probability

    @staticmethod
    def _eval_tensors(dataset):
        """Return (inputs, targets) built from a dataset's raw `.data` / `.targets`.

        `.data` bypasses the dataset's transform, so uint8 images are rescaled
        to [0, 1] here to match the ToTensor-transformed training inputs.
        """
        inputs = torch.as_tensor(dataset.data)
        if inputs.dtype == torch.uint8:
            inputs = inputs.float() / 255.0
        else:
            inputs = inputs.float()
        return inputs, torch.as_tensor(dataset.targets)

    def test(self, x):
        """Accuracy (0-dim tensor) of the posterior predictive on dataset `x`."""
        x_test, y_test = self._eval_tensors(x)

        # predicted probabilities
        class_probs = self.predict_probabilities(x_test)

        return (class_probs.argmax(dim=1) == y_test).float().mean()

    def calibration(self, x):
        """Multiclass L1 calibration error (30 bins) of the posterior predictive on `x`."""
        x_test, y_test = self._eval_tensors(x)

        # predicted probabilities
        class_probs = self.predict_probabilities(x_test)

        return calibration_error(class_probs, y_test, n_bins = 30, task = "multiclass",
                                 norm="l1", num_classes=class_probs.shape[1])



List of all Networks

In [None]:
# Architectures under comparison (freshly initialized modules)
Networks = [
    FullyConnectedNN(),
    ConvolutionalNN(),
]

List of all Priors

In [None]:
# Weight priors under comparison (default hyperparameters)
Priors = [
    IsotropicGaussian(),
    StudentTPrior(),
]

List of all Temperatures

In [None]:
# Posterior temperatures to sweep over in the training loop
Temperatures = [1.0]

Training Loop

In [None]:
# Train one BNN per (network, prior, temperature) combination.
# BayesianNN trains the module it receives in place, so each run gets a fresh
# deep copy; otherwise every run after the first would start from the
# previous run's trained weights.
Accuracies = []
Calibrations = []
Run_configs = []  # (network, prior, temperature) labels aligned with the lists above

for network in Networks:
    for prior in Priors:
        for temperature in Temperatures:
            lol = BayesianNN(sub_train_set, network=copy.deepcopy(network), prior=prior, num_epochs=200)
            # Apply this run's temperature through the optimizer's param groups,
            # which Fortuin_SGLD reads on every step.
            for group in lol.optimizer.param_groups:
                group['temperature'] = temperature
            lol.train()
            Accuracies.append(lol.test(test_set).item())
            Calibrations.append(lol.calibration(test_set).item())
            Run_configs.append((type(network).__name__, type(prior).__name__, temperature))
    

Evaluation Metrics

In [None]:
# Test accuracy per run, in training-loop order (network, then prior, then temperature)
Accuracies

In [None]:
# L1 multiclass calibration error (30 bins) per run, in training-loop order
Calibrations

Plot Results