# To run our code for the DL project in Google Colab, we load all scripts through notebook cells.

## data.py

In [None]:
import os
import numpy as np
import torch
import torch.optim
import torch.utils.data
from torch.utils.data import DataLoader, TensorDataset, ConcatDataset, Subset
import torchvision
import torchvision.transforms as tr
from torchvision.datasets import MNIST, FashionMNIST

class Data:
    """Load MNIST or FashionMNIST as flattened, [0, 1]-scaled ``TensorDataset`` objects.

    Args:
        dataset: Either ``"MNIST"`` or ``"FashionMNIST"``.
        augmentations: Optional torchvision transform. When given, a second,
            augmented copy of the training set is generated and concatenated
            to the original one (doubling the number of training samples).

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """

    def __init__(self, dataset, augmentations=None):
        # Validate first so that an unsupported name fails with a clear message
        # instead of an AttributeError further down.
        if dataset not in ("MNIST", "FashionMNIST"):
            raise ValueError(
                f"Unknown dataset {dataset!r}; expected 'MNIST' or 'FashionMNIST'"
            )
        dataset_cls = MNIST if dataset == "MNIST" else FashionMNIST

        self.dataset = dataset
        self.transforms = tr.Compose([tr.ToTensor()])

        self.trainset = dataset_cls(root="data", train=True, download=True, transform=self.transforms)
        self.testset = dataset_cls(root="data", train=False, download=True, transform=self.transforms)

        # The raw uint8 images are flattened and rescaled to [0, 1].
        x_train = self.trainset.data.reshape(-1, 28*28)/255.
        y_train = self.trainset.targets
        self.train_data = TensorDataset(x_train, y_train)

        x_test = self.testset.data.reshape(-1, 28*28)/255.
        y_test = self.testset.targets
        self.test_data = TensorDataset(x_test, y_test)

        if augmentations is not None:
            # increase the size of the training set by applying augmentations
            self.transforms_extra = tr.Compose([self.transforms, augmentations])
            self.trainset_extra = dataset_cls(root="data", train=True, download=True, transform=self.transforms_extra)

            # ``.data`` bypasses the transform pipeline, so the augmented images
            # have to be materialised by indexing the dataset item by item.
            # ToTensor already rescales to [0, 1], matching ``x_train`` above.
            augmented_images = [
                self.trainset_extra[i][0] for i in range(len(self.trainset_extra))
            ]
            x_train_extra = torch.stack(augmented_images).reshape(-1, 28*28)
            y_train_extra = self.trainset_extra.targets
            train_data_extra = TensorDataset(x_train_extra, y_train_extra)

            self.train_data = ConcatDataset([self.train_data, train_data_extra])

    def get_data(self, num_train_samples=None):
        """Return ``(train_data, test_data)``.

        Args:
            num_train_samples: If given, a random subset of this many training
                samples (drawn without replacement via ``np.random.choice``)
                is returned instead of the full training set.
        """
        if num_train_samples is None:
            return self.train_data, self.test_data

        sub_train_idx = np.random.choice(len(self.train_data), num_train_samples, replace=False)
        return Subset(self.train_data, sub_train_idx), self.test_data

## priors.py

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.distributions as dist
from scipy.special import gamma
import math



# Framework for priors --------------------------------------------------------

class Prior:
    """
    Base class for all priors that we use in this project.

    Subclasses must override both ``sample`` and ``log_likelihood``; calling
    either method on a class that did not override it raises
    ``NotImplementedError`` instead of silently returning ``None``.
    """
    def __init__(self):
        pass

    def sample(self, n):
        """Draw ``n`` samples from the prior."""
        raise NotImplementedError(f"{type(self).__name__} must implement sample()")

    def log_likelihood(self, values):
        """Return the log density of ``values`` under the prior."""
        raise NotImplementedError(f"{type(self).__name__} must implement log_likelihood()")


# Isotropic Gaussian prior ----------------------------------------------------

class Isotropic_Gaussian(Prior):
    """
    Gaussian prior that shares one mean (loc) and one standard deviation (scale)
    across all elements it is evaluated on.

    The summed log density is divided by ``Temperature``.
    """
    def __init__(self, loc: float = 0, scale: float = 1.0, Temperature: float = 1.0):
        super().__init__()
        assert scale > 0, "Scale must be positive"
        self.name = "Isotropic_Gaussian"
        # All hyperparameters are stored as float32 scalar tensors.
        for attr, value in (("loc", loc), ("scale", scale), ("Temperature", Temperature)):
            setattr(self, attr, torch.tensor(value, dtype=torch.float32))

    def log_likelihood(self, values: torch.Tensor) -> torch.Tensor:
        """Return the tempered sum of the element-wise Gaussian log densities."""
        normal = dist.Normal(self.loc, self.scale)
        total_log_prob = normal.log_prob(values).sum()
        return total_log_prob / self.Temperature

    def sample(self, n):
        """Draw ``n`` independent samples; the result has shape ``(n,)``."""
        normal = dist.Normal(self.loc, self.scale)
        return normal.sample((n,))


# Multivariate Gaussian prior -------------------------------------------------


# TODO: Implement this

class Multivariate_Diagonal_Gaussian(Prior):
    """
    Multivariate diagonal Gaussian distribution,
    i.e., assumes all elements to be independent Gaussians
    but with different means and standard deviations.
    This parameterizes the standard deviation via a parameter rho as
    sigma = softplus(rho).

    Note: ``sigma`` is computed once at construction time from ``rho``.
    """
    def __init__(self, mu: torch.Tensor, rho: torch.Tensor, Temperature: float = 1.0):
        super().__init__()
        self.mu = mu
        self.rho = rho
        # Use the library softplus: the naive log(1 + exp(rho)) overflows to
        # inf for large rho, whereas softplus switches to the identity there.
        self.sigma = nn.functional.softplus(rho)
        self.Temperature = Temperature
        self.name = "Multivariate_Diagonal_Gaussian"

    def log_likelihood(self, values: torch.Tensor) -> torch.Tensor:
        """Return the tempered sum of the element-wise Gaussian log densities."""
        return dist.Normal(self.mu, self.sigma).log_prob(values).sum() / self.Temperature

    def sample(self, n=None) -> torch.Tensor:
        """
        Draw reparameterized samples ``mu + sigma * eps`` with ``eps ~ N(0, 1)``.

        n: optional number of samples. If omitted, a single sample with the
           shape of ``mu`` is returned (the original behaviour); otherwise the
           result has shape ``(n, *mu.shape)``.
        """
        if n is None:
            eps = torch.randn_like(self.mu)
        else:
            eps = torch.randn((n, *self.mu.shape), dtype=self.mu.dtype, device=self.mu.device)
        return self.mu + self.sigma * eps



# Student-t prior -------------------------------------------------------------

class StudentT_prior(Prior):
    """
    Student-t prior parameterized by degrees of freedom (df), location (loc)
    and scale (scale). The summed log density is divided by ``Temperature``.
    """
    def __init__(self, df: float = 10, loc: float = 0, scale: float = 1.0, Temperature: float = 1.0):
        super().__init__()
        # Store every hyperparameter as a float32 scalar tensor.
        self.df, self.loc, self.scale, self.Temperature = (
            torch.tensor(value, dtype=torch.float32)
            for value in (df, loc, scale, Temperature)
        )
        self.name = "Student_T"

    def _distribution(self):
        """Build the underlying ``torch.distributions.StudentT`` object."""
        return dist.StudentT(self.df, self.loc, self.scale)

    def log_likelihood(self, values: torch.Tensor) -> torch.Tensor:
        """Return the tempered sum of the element-wise Student-t log densities."""
        summed = self._distribution().log_prob(values).sum()
        return summed / self.Temperature

    def sample(self, n):
        """Draw ``n`` independent samples; the result has shape ``(n,)``."""
        return self._distribution().sample((n,))


# Laplace prior ---------------------------------------------------------------

class Laplace_prior(Prior):
    """
    Laplace prior parameterized by location (loc) and scale (scale).
    The summed log density is divided by ``Temperature``.
    """
    def __init__(self, loc: float = 0, scale: float = 1.0, Temperature: float = 1.0):
        super().__init__()

        def as_float32(value):
            # Hyperparameters are kept as float32 scalar tensors.
            return torch.tensor(value, dtype=torch.float32)

        self.name = "Laplace"
        self.loc = as_float32(loc)
        self.scale = as_float32(scale)
        self.Temperature = as_float32(Temperature)

    def log_likelihood(self, values: torch.Tensor) -> torch.Tensor:
        """Return the tempered sum of the element-wise Laplace log densities."""
        laplace = dist.Laplace(self.loc, self.scale)
        elementwise = laplace.log_prob(values)
        return elementwise.sum() / self.Temperature

    def sample(self, n) -> torch.Tensor:
        """Draw ``n`` independent samples; the result has shape ``(n,)``."""
        laplace = dist.Laplace(self.loc, self.scale)
        return laplace.sample((n,))



# Gaussian mixture prior ------------------------------------------------------

class Gaussian_Mixture(Prior):
    """
    Mixture of two Gaussians with means (loc1, loc2), standard deviations (scale1, scale2) and mixing coefficient (mixing_coef) as parameters.

    The density is
        p(x) = mixing_coef * N(x; loc1, scale1) + (1 - mixing_coef) * N(x; loc2, scale2).
    """
    def __init__(self, loc1: float = 0, scale1: float = 3.0, loc2: float = 0, scale2: float = 1.0,
                mixing_coef: float = 0.7, Temperature: float = 1.0):
        super().__init__()
        if not 0.0 <= mixing_coef <= 1.0:
            raise ValueError("mixing_coef must lie in [0, 1]")
        self.loc1 = torch.tensor(loc1, dtype=torch.float32)
        self.loc2 = torch.tensor(loc2, dtype=torch.float32)
        self.scale1 = torch.tensor(scale1, dtype=torch.float32)
        self.scale2 = torch.tensor(scale2, dtype=torch.float32)
        self.mixing_coef = torch.tensor(mixing_coef, dtype=torch.float32)
        self.Temperature = torch.tensor(Temperature, dtype=torch.float32)
        self.name = "Gaussian_Mixture"

    def log_likelihood(self, values: torch.Tensor) -> torch.Tensor:
        """
        Return the tempered sum of the element-wise mixture log densities.

        The mixture density is a weighted sum of *densities*, so its log is
        log(w * p1 + (1 - w) * p2). This is evaluated with logaddexp on the
        component log-probabilities to stay in log space; a weighted sum of
        the log-probabilities would be a different (non-mixture) quantity.
        """
        log_p1 = torch.log(self.mixing_coef) + dist.Normal(self.loc1, self.scale1).log_prob(values)
        log_p2 = torch.log1p(-self.mixing_coef) + dist.Normal(self.loc2, self.scale2).log_prob(values)
        log_lik = torch.logaddexp(log_p1, log_p2).sum() / self.Temperature
        return log_lik

    def sample(self, n) -> torch.Tensor:
        """
        Draw ``n`` samples from the mixture; the result has shape ``(n,)``.

        Each sample comes from exactly one component: the first with
        probability ``mixing_coef``, otherwise the second.
        """
        sample1 = dist.Normal(self.loc1, self.scale1).sample((n,))
        sample2 = dist.Normal(self.loc2, self.scale2).sample((n,))
        from_first = torch.rand(n) < self.mixing_coef
        return torch.where(from_first, sample1, sample2)


# Normal Inverse Gamma prior --------------------------------------------------

class Inverse_Gamma(Prior):
    """ 
    Inverse Gamma distribution with shape (shape) and rate (rate) as parameters.
    This distribution is needed for the Normal Inverse Gamma prior.
    """
    def __init__(self, shape: float = 1.0, rate: float = 1.0, Temperature: float = 1.0):
        super().__init__()
        self.shape = torch.tensor(shape, dtype=torch.float32)
        self.rate = torch.tensor(rate, dtype=torch.float32)
        self.Temperature = torch.tensor(Temperature, dtype=torch.float32)
        # Every other prior exposes a ``name``; provide one here as well.
        self.name = "Inverse_Gamma"

    def log_likelihood(self, values: torch.Tensor) -> torch.Tensor:
        """
        Return the element-wise log density of the inverse gamma distribution:

            log p(x) = shape * log(rate) - lgamma(shape) - (shape + 1) * log(x) - rate / x

        The expression is evaluated directly in log space: forming the density
        first and taking the log afterwards under-/overflows for small or large
        ``values``.

        NOTE(review): unlike the other priors, the result is neither summed nor
        divided by ``Temperature``; this keeps the original return shape and
        scaling - confirm whether tempering is wanted here.
        """
        log_norm = self.shape * torch.log(self.rate) - torch.lgamma(self.shape)
        return log_norm - (self.shape + 1) * torch.log(values) - self.rate / values

    def sample(self, n) -> torch.Tensor:
        """Draw ``n`` samples by inverting Gamma(shape, rate) draws."""
        # sample from gamma and return 1/x
        x = dist.Gamma(self.shape, self.rate).sample((n,))
        return 1/x


class Normal_Inverse_Gamma(Prior):
    """ 
    Normal Inverse Gamma distribution with mean (loc), precision (lam), shape (alpha) and rate (beta) as parameters.
    """
    def __init__(self, loc: float = 0, lam: float = 1, alpha: float = 1, beta: float = 1, Temperature: float = 1.0,device="cuda"):
        """
        loc: loc of the normal distribution
        lam: precision of the normal distribution
        alpha: shape parameter of the inverse gamma distribution
        beta: rate parameter of the inverse gamma distribution
        Temperature: divisor applied to the log likelihood
        device: device on which the hyperparameter tensors are created
        """
        super().__init__()
        self.device = torch.device(device)
        self.loc = torch.tensor(loc, dtype=torch.float32,device=self.device)
        self.lam = torch.tensor(lam, dtype=torch.float32,device=self.device)
        self.alpha = torch.tensor(alpha, dtype=torch.float32,device=self.device)
        self.beta = torch.tensor(beta, dtype=torch.float32,device=self.device)
        self.Temperature = torch.tensor(Temperature, dtype=torch.float32,device=self.device)
        self.name = "Normal_Inverse_Gamma"

    def log_likelihood(self, values: torch.Tensor, var: torch.Tensor = None) -> torch.Tensor:
        """
        Compute the tempered, element-wise log density of the normal inverse
        gamma distribution for ``values`` and a variance ``var``:

            0.5 * log(lam / (2 * pi * var)) + alpha * log(beta) - lgamma(alpha)
            - (alpha + 1) * log(var) - (2 * beta + lam * (values - loc)^2) / (2 * var)

        values: points at which the density is evaluated
        var: variance; required (a scalar tensor or a tensor broadcastable
             against ``values``)

        Raises ValueError if ``var`` is not supplied.
        """
        if var is None:
            raise ValueError("log_likelihood of Normal_Inverse_Gamma requires a variance `var`")
        var = torch.as_tensor(var, dtype=torch.float32, device=self.device)

        # Constants are kept local: assigning them as attributes of the torch
        # module would overwrite the global ``torch.pi`` for the whole process.
        squared_dev = (values - self.loc) ** 2
        log_likelihood = (
            0.5 * torch.log(self.lam / (2 * math.pi * var))
            + torch.xlogy(self.alpha, self.beta)
            - torch.lgamma(self.alpha)
            - torch.xlogy(self.alpha + 1, var)
            - (2 * self.beta + self.lam * squared_dev) / (2 * var)
        )

        return log_likelihood / self.Temperature

    def sample(self, n) -> tuple:
        """
        Return ``(x, var)``: one variance drawn from the inverse gamma
        distribution and ``n`` normal draws with variance ``var / lam``.
        """
        # sample variance from inverse gamma and sample x from normal given the variance
        var = 1.0 / dist.Gamma(self.alpha, self.beta).sample((1,))
        x = dist.Normal(self.loc, torch.sqrt(var/self.lam)).sample((n,))
        return x, var



# class Normal_Inverse_Gamma(Prior):
#     """ 
#     Normal Inverse Gamma distribution with mean (mu), precision (lam), shape (alpha) and rate (beta) as parameters.
#     """
#     def __init__(self, loc: float = 0, lam: float = 1, alpha: float = 1, beta: float = 1, Temperature: float = 1.0):
#         """
#         loc: loc of the normal distribution
#         lam: precision of the normal distribution
#         alpha: shape parameter of the inverse gamma distribution
#         beta: rate parameter of the inverse gamma distribution
#         """
#         super().__init__()
#         self.loc = torch.tensor(loc, dtype=torch.float32)
#         self.lam = torch.tensor(lam, dtype=torch.float32)
#         self.alpha = torch.tensor(alpha, dtype=torch.float32)
#         self.beta = torch.tensor(beta, dtype=torch.float32)
#         self.Temperature = torch.tensor(Temperature, dtype=torch.float32)
#         self.name = "Normal_Inverse_Gamma"
    

#     def log_likelihood(self, values: torch.Tensor, var = torch.Tensor) -> torch.Tensor:
#         """
#         Compute the likelihood of the inverse gamma distribution for an x and a variance
#         """
#         # manually compute the likelihood
#         # var =  self.beta / (self.alpha + 1 + 0.5) + 1e-8 # to avoid division by zero

#         log_likelihood = torch.xlogy(0.5, self.lam / (2 * math.pi * var)) + \
#                     torch.xlogy(self.alpha,self.beta) - \
#                     torch.lgamma(self.alpha) - \
#                     torch.xlogy(self.alpha + 1,var) - \
#                     (2*self.beta + self.lam * (values - self.loc)**2) / (2 * var)
        
#         return log_likelihood/ self.Temperature

#         #new try:
#         # for all pos integers, gamma function:
#         # F(a) = (a-1)!

#         #var = self.beta / (self.alpha + 1 + 0.5)  # according to formula on wikipedia
#         #part1 = self.lam ** 0.5 / (2 * np.pi * var) ** 0.5
#         #part2 = self.beta ** self.alpha / np.math.factorial(self.alpha-1)
#         #part3 = (1/var) ** (self.alpha + 1)
#         #part4 = (-2*self.beta - self.lam * (values - self.mu) ** 0.5) / (2*var)
#         #likelihood = part1 * part2 * part3 * np.exp(part4)
#         #log_likelihood = np.log(likelihood)
#         #return log_likelihood / self.Temperature


#     def sample(self, n) -> torch.Tensor:
#         # sample variance from inverse gamma and sample x from normal given the variance
#         var = torch.div(1, dist.Gamma(self.alpha, self.beta).sample((1,)))
#         x = dist.Normal(self.loc, torch.sqrt(var/self.lam)).sample((n,))
#         return x, var


# Spike and slab prior --------------------------------------------------------

class GaussianSpikeNSlab(Prior):
    """
    Spike-and-slab prior built as a two-component Gaussian mixture.

    z ~ bern(theta)
    if z=0, then x=0 approximately ("spike", modelled as a very narrow normal distribution)
    if z=1, then x ~ p_slab ("slab", a normal distribution)

    As implemented, the mixture density is
    p_theta(x) = (1-theta) * p_spike(x) + theta * p_slab(x)
    """

    def __init__(self, loc_slab: float = 0, scale_slab: float = 1, loc_spike: float = 0, scale_spike: float = 1e-16, theta: float = 0.8, Temperature: float = 1.0,device="cuda"):
        """
        loc_slab: mean of the slab (normal) distribution
        scale_slab: standard deviation of the slab distribution
        loc_spike: mean of the spike distribution
        scale_spike: standard deviation of the spike distribution, should be very small to simulate a spike
        theta: bernoulli parameter; weight of the slab component in the mixture
        Temperature: divisor applied to the summed log likelihood
        device: device on which all tensors of this prior are created
        """
        super().__init__()
        self.device = torch.device(device)
        target = self.device

        def scalar(value):
            # float32 scalar tensor living on the prior's device
            return torch.tensor(value, dtype=torch.float32, device=target)

        self.theta = scalar(theta)
        self.loc_slab = scalar(loc_slab)
        self.scale_slab = scalar(scale_slab)
        self.loc_spike = scalar(loc_spike)
        self.scale_spike = scalar(scale_spike)
        self.Temperature = scalar(Temperature)
        self.name = "Gaussian_Spike_and_Slab"

        # Component order is (spike, slab) for weights, means and scales alike.
        weights = torch.tensor([1-self.theta, self.theta],device=target)
        means = torch.tensor([self.loc_spike, self.loc_slab], device = target)
        stds = torch.tensor([self.scale_spike, self.scale_slab], device = target)
        components = dist.Normal(means, stds)
        self.spike_n_slab = dist.MixtureSameFamily(dist.Categorical(probs=weights), components)

    def log_likelihood(self, values: torch.Tensor) -> torch.Tensor:
        """Return the tempered sum of the element-wise mixture log densities."""
        elementwise = self.spike_n_slab.log_prob(values)
        return elementwise.sum() / self.Temperature

    def sample(self,n) -> torch.Tensor:
        """Draw ``n`` samples from the mixture; the result has shape ``(n,)``."""
        shape = (n,)
        return self.spike_n_slab.sample(shape)


# Customized Laplace and Uniform Mixture prior --------------------------------


## Uniform at the middle, Laplace at the sides, 50% weight on uniform, 25% weight each side.

#   if x < -1:              f(x) = 0.67957*exp(x)
#   if -1 <= x <= 1:        f(x) = 1/4
#   if x > 1:               f(x) = 0.67957*exp(-x)


class MixedLaplaceUniform(Prior):
    """
    A mixture of Laplace and Uniform distributions.
    The density is flat (1/4) on the interval [-1, 1] and decays like a
    Laplace distribution, a * exp(-|x|) with a = e/4, outside of it, so the
    density is continuous at -1 and 1.
    """
    def __init__(self, Temperature:float=1.0):
        super().__init__()
        self.name = "Mixed_Laplace_and_Uniform"
        self.Temperature = Temperature
        # a = e / 4 makes the Laplace tails meet the uniform part at |x| = 1
        self.a = torch.exp(torch.tensor(1))/4

    def log_likelihood(self, values: torch.Tensor) -> torch.Tensor:
        """Return the tempered sum of the piecewise log density over ``values``."""
        left_tail = values + torch.log(self.a)
        right_tail = -values + torch.log(self.a)
        plateau = torch.tensor(np.log(1/4))
        # pick plateau inside [-1, 1], right tail above 1 ...
        middle_or_right = torch.where(values <= 1, plateau, right_tail)
        # ... and the left tail below -1
        log_likelihoods = torch.where(values < -1, left_tail, middle_or_right)
        return log_likelihoods.sum() / self.Temperature

    def sample(self, size=1) -> torch.Tensor:
        """Generates samples from the mixed probability distribution by inverting its CDF."""
        u = torch.rand(size)
        # inverse CDF on each of the three pieces (left tail / plateau / right tail)
        left_tail = torch.log(u/self.a)
        plateau = (u - 1/4)*4 - 1
        right_tail = torch.log(1/(1/np.exp(1) - ((u-0.75)/self.a)))
        upper = torch.where(u <= 3/4, plateau, right_tail)
        return torch.where(u < 1/4, left_tail, upper)



# Pre-train the prior on FashionMNIST -----------------------------------------

## Networks.py

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


""" Base Networks """

""" Fully Connected Neural Network """

""" - 3 layers (2 hidden, 1 output)
    - 100 units in each hidden layer
    - ReLU activations on the hidden layers
    """

class FullyConnectedNN(nn.Module):
    """Fully connected network: ``hidden_layers`` ReLU hidden layers plus a linear output layer.

    Args:
        in_features: Number of input features per sample after flattening.
        out_features: Number of output logits.
        hidden_units: Width of every hidden layer.
        hidden_layers: Number of hidden (Linear + ReLU) layers.
    """

    def __init__(self, in_features = 28*28, out_features = 10, hidden_units = 100, hidden_layers = 2):
        super(FullyConnectedNN, self).__init__()

        # Remembered so that forward() flattens to the configured input size
        # instead of a hard-coded 28*28.
        self.in_features = in_features

        # Input to first layer
        modules = [nn.Linear(in_features, hidden_units), nn.ReLU()]

        # Remaining hidden layers
        for _ in range(hidden_layers - 1):
            modules.append(nn.Linear(hidden_units, hidden_units))
            modules.append(nn.ReLU())

        # Output layer (no activation: the network returns logits)
        modules.append(nn.Linear(hidden_units, out_features))

        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Flatten ``x`` to ``(-1, in_features)`` and return the logits."""
        x = x.reshape(-1, self.in_features)
        output = self.layers(x)
        return output


""" Convolutional Neural Network """

""" - 3 layers (2 hidden, 1 output)
    - First two layers: convolutional layers with 64 channels, 3x3 convolutions, each followed by 2x2 max pooling
    - Both convolutional layers use ReLU activations; the linear output layer returns logits
    """

class ConvolutionalNN(nn.Module):
    """Small CNN for 28x28 single-channel images with 10 output logits.

    Two blocks of (3x3 convolution with 64 channels, ReLU, 2x2 max pooling)
    reduce the image to 64 x 7 x 7 features, which a linear layer maps to
    10 logits.
    """

    def __init__(self):
        super(ConvolutionalNN, self).__init__()

        # Create Network
        self.layers = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(inplace = True),
            nn.MaxPool2d(kernel_size=2),
            # in_channels must equal the 64 channels produced by the first
            # convolution; any other value fails at the first forward pass.
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(inplace = True),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 10))

    def forward(self, x):
        """Reshape ``x`` to ``(-1, 1, 28, 28)`` and return the logits."""
        x = x.reshape(-1, 1, 28, 28)
        output = self.layers(x)
        return output

## BayesianNN.py

In [None]:
""" Bayesian Neural Network """

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import SGD
from torch.optim.lr_scheduler import PolynomialLR
from torchmetrics.functional import calibration_error
from collections import deque, OrderedDict
from tqdm import trange
import copy
from sklearn.metrics import roc_auc_score
from SGLD import SGLD


class BNN_MCMC:
    """Bayesian neural network trained with an SGLD optimizer.

    After ``burn_in`` epochs, a deep copy of the network is stored every
    ``sample_interval`` epochs; at most ``max_size`` copies are kept (oldest
    dropped first). Predictions average the softmax outputs of all stored
    copies.

    Args:
        dataset_train: Training dataset yielding ``(x, y)`` pairs.
        network: ``nn.Module`` returning class logits.
        prior: Prior object exposing ``name`` and ``log_likelihood``.
        Temperature: Temperature forwarded to the SGLD optimizer.
        num_epochs: Number of training epochs.
        max_size: Maximum number of stored network samples.
        burn_in: Number of initial epochs without sampling.
        lr: Initial learning rate (decayed by ``PolynomialLR``).
        sample_interval: A sample is stored every this many epochs.
        device: Device used for the network and all computations.
    """

    def __init__(self, dataset_train, network, prior, Temperature = 1.,
     num_epochs = 300, max_size = 100, burn_in = 100, lr = 1e-3, sample_interval = 1, device = "cpu"):
        super().__init__()

        # set device
        self.device = torch.device(device)

        # Hyperparameters and general parameters
        self.Temperature = Temperature
        self.learning_rate = lr
        self.num_epochs = num_epochs
        self.burn_in = burn_in
        self.sample_interval = sample_interval
        self.max_size = max_size

        self.batch_size = 128
        self.print_interval = 50

        # Data Loader
        self.data_loader = DataLoader(dataset_train, batch_size=self.batch_size, shuffle=True)
        self.sample_size = len(dataset_train)

        # Set Prior
        self.prior = prior

        # Initialize the network
        self.network = network.to(self.device)

        # Set optimizer
        # NOTE(review): num_data receives the batch size, not the dataset size;
        # SGLD is defined outside this file - confirm this matches its contract.
        self.optimizer = SGLD(self.network.parameters(), lr=self.learning_rate, num_data=self.batch_size, temperature=self.Temperature)

        # Scheduler for polynomially decreasing learning rates
        self.scheduler = PolynomialLR(self.optimizer, total_iters = self.num_epochs, power = 0.5)

        # Deque to store model samples
        self.model_sequence = deque()

    def _log_prior(self):
        """Return the log prior of all trainable parameters as a scalar tensor."""
        params = [p for p in self.network.parameters() if p.requires_grad]

        if self.prior.name == 'Normal_Inverse_Gamma' and params:
            # This prior is evaluated on the full flattened parameter vector
            # together with its empirical variance (not on a single layer).
            flat_params = torch.cat([p.reshape(-1) for p in params])
            return self.prior.log_likelihood(flat_params, torch.var(flat_params)).sum()

        log_prior = torch.tensor(0, device=self.device, dtype=torch.float32)
        if self.prior.name == 'Normal_Inverse_Gamma':
            return log_prior
        for p in params:
            log_prior = log_prior + self.prior.log_likelihood(p).sum()
        return log_prior

    def train(self):
        """Run SGLD training and collect network samples after burn-in."""
        print('Training Model')

        self.network.train()
        progress_bar = trange(self.num_epochs)

        N = torch.tensor(self.sample_size, device = self.device)

        for epoch_idx in progress_bar:
            num_iter = epoch_idx + 1

            for batch_idx, (batch_x, batch_y) in enumerate(self.data_loader):
                batch_x = batch_x.to(self.device)
                batch_y = batch_y.to(self.device)
                self.network.zero_grad()
                n = len(batch_x)

                # Perform forward pass
                current_logits = self.network(batch_x)

                # Compute the (rescaled) NLL
                # NOTE(review): nll_loss already averages over the batch, so
                # this is N/n * mean rather than N * mean - kept as in the
                # original; confirm the intended scaling.
                nll = N/n*F.nll_loss(F.log_softmax(current_logits, dim=1), batch_y)

                # Compute the log prior
                log_prior = self._log_prior()

                # Calculate the loss
                loss = nll - log_prior

                # Backpropagate to get the gradients. The graph is rebuilt for
                # every batch, so it does not need to be retained.
                loss.backward()

                # Update the weights
                self.optimizer.step()

                # Update Metrics according to print_interval
                if batch_idx % self.print_interval == 0:
                    # Metrics only: no gradients needed for this extra pass.
                    with torch.no_grad():
                        current_logits = self.network(batch_x)
                        current_accuracy = (current_logits.argmax(axis=1) == batch_y).float().mean()
                        current_nll = N/n*F.nll_loss(F.log_softmax(current_logits, dim=1), batch_y)
                    progress_bar.set_postfix(loss=loss.item(), acc=current_accuracy.item(),
                    nll_loss=current_nll.item(),
                    log_prior_normalized = - log_prior.item(),
                    lr = self.optimizer.param_groups[0]['lr'])

            # Decrease lr based on scheduler
            self.scheduler.step()

            # Save the model samples if past the burn-in epochs according to sampling interval
            if num_iter > self.burn_in and num_iter % self.sample_interval == 0:
                self.model_sequence.append(copy.deepcopy(self.network))

            # If model_sequence too big, delete oldest model
            if len(self.model_sequence) > self.max_size:
                self.model_sequence.popleft()

    def predict_probabilities(self, x: torch.Tensor) -> torch.Tensor:
        """Return class probabilities averaged over all stored network samples.

        Side effect: ``self.network`` ends up holding the weights of the last
        stored sample.

        Raises:
            RuntimeError: If no network samples have been collected yet.
        """
        if not self.model_sequence:
            raise RuntimeError("No model samples available; run train() past the burn-in first")

        self.network.eval()
        x = x.to(self.device)

        # Sum predictions from all models in model_sequence
        estimated_probability = None
        with torch.no_grad():
            for model in self.model_sequence:
                self.network.load_state_dict(model.state_dict())
                probabilities = F.softmax(self.network(x), dim=1)
                if estimated_probability is None:
                    estimated_probability = probabilities
                else:
                    estimated_probability = estimated_probability + probabilities

        # Normalize the combined predictions to get average predictions
        estimated_probability = estimated_probability / len(self.model_sequence)

        assert estimated_probability.shape[0] == x.shape[0]
        return estimated_probability

    def _split_dataset(self, x):
        """Return detached ``(inputs, targets)`` copies of dataset ``x`` on ``self.device``."""
        inputs, targets = x[:][0], x[:][1]
        return inputs.clone().detach().to(self.device), targets.clone().detach().to(self.device)

    def test_accuracy(self,x):
        """Return the accuracy of the averaged prediction on dataset ``x`` (numpy scalar)."""
        x_test, y_test = self._split_dataset(x)

        # predicted probabilities
        class_probs = self.predict_probabilities(x_test)

        # accuracy
        accuracy = (class_probs.argmax(axis=1) == y_test).float().mean()
        return  accuracy.cpu().numpy()

    def test_calibration(self,x):
        """Return the L1 calibration error (30 bins) on dataset ``x`` (numpy scalar)."""
        x_test, y_test = self._split_dataset(x)

        # predicted probabilities
        class_probs = self.predict_probabilities(x_test)

        calib_err = calibration_error(class_probs, y_test, n_bins = 30, task = "multiclass", norm="l1",
                                      num_classes=class_probs.shape[1])
        return calib_err.cpu().numpy()

    def test_auroc(self,x):
        """Return the one-vs-rest AUROC on dataset ``x`` (numpy scalar)."""
        x_test, y_test = self._split_dataset(x)

        # predicted probabilities
        class_probs = self.predict_probabilities(x_test)

        # roc_auc_score works on host arrays and returns a plain float.
        auroc = roc_auc_score(y_test.cpu().numpy(), class_probs.cpu().numpy(), multi_class='ovr')
        return np.asarray(auroc)

    def get_metrics(self, x):
        """Return ``(accuracy, calibration error, AUROC)`` on dataset ``x``."""
        accuracy = self.test_accuracy(x)
        calib_err = self.test_calibration(x)
        auroc = self.test_auroc(x)

        return accuracy, calib_err, auroc

    def get_posterior_stats(self):
        """Return mean and variance over all weights of all stored samples.

        Raises:
            RuntimeError: If no network samples have been collected yet.
        """
        if not self.model_sequence:
            raise RuntimeError("No model samples available; run train() past the burn-in first")

        self.network.eval()

        # get flattened weights from all models
        flat_per_model = [
            torch.cat([v.flatten() for v in model.state_dict().values()])
            for model in self.model_sequence
        ]
        param_flat_all = torch.cat(flat_per_model)

        # get mean and variance
        mean = torch.mean(param_flat_all, dim=0)
        var = torch.var(param_flat_all, dim=0)

        return mean.cpu().numpy(), var.cpu().numpy()

# Run experiment

In [None]:
"""
This script runs the experiment for our project.

For every combination of network architecture, temperature and training-set
size it trains a BNN_MCMC model, evaluates it on the MNIST test set and
stores the metrics in a CSV file named after the prior.
"""

# Importing libraries ---------------------------------------------------------

import os

import numpy as np
import torch
import torchvision.transforms as tr
import pandas as pd

from data import Data
from priors import *
from Networks import *
from BayesianNN import BNN_MCMC

# Setting seeds ---------------------------------------------------------------
torch.manual_seed(0)
# The training subsets are drawn with np.random, so numpy must be seeded too.
np.random.seed(0)
# Fall back to the CPU when no GPU is available (e.g. a CPU-only Colab runtime).
device = "cuda" if torch.cuda.is_available() else "cpu"

# Specify the prior -----------------------------------------------------------

# Possible prior choices: 
#       Isotropic_Gaussian, 
#       StudentT_prior
#       Laplace_prior
#       Gaussian_Mixture
#       Normal_Inverse_Gamma
#       GaussianSpikeNSlab
#       MixedLaplaceUniform

prior = GaussianSpikeNSlab(device=device)


# Specify the iteration parameters --------------------------------------------

# network list: constructors rather than instances, so that every run starts
# from freshly initialised weights instead of continuing from the previous run
networks = {"FCNN": FullyConnectedNN, "CNN": ConvolutionalNN}

# Temperature list
Temperatures = [0.001, 0.01, 0.1, 1.]


# sample size list
sample_sizes = [3750, 15000, 60000, 120000]

total_runs = len(networks)*len(Temperatures)*len(sample_sizes)

# preallocate pandas dataframe for results
results = pd.DataFrame(columns = [
    "Network", 
    "Sample Size", 
    "Epochs", 
    "Burn in", 
    "sample interval", 
    "Temperature", 
    "Test Accuracy", 
    "Test ECE", 
    "Test AUROC",
    "Posterior mean",
    "Posterior var"],
    index = range(total_runs))


# Map every sample size to (epochs, burn in, sample interval). The base values
# apply to the largest sample size and are scaled up for smaller ones so that
# every run performs a comparable number of gradient steps.
base_epoch, base_burn_in, base_sample_interval, base_samplesize = 50, 10, 2, sample_sizes[-1]
args_dict = {
    sample_size: (
        base_epoch*base_samplesize/sample_size,
        base_burn_in*base_samplesize/sample_size,
        base_sample_interval*base_samplesize/sample_size,
    )
    for sample_size in sample_sizes
}

# Run the experiment ----------------------------------------------------------

iteration = 0

for net in networks:
    for T in Temperatures:
        for sample_size in sample_sizes:
            epochs, burn_in, sample_interval = args_dict[sample_size]

            # print iteration info
            print(50*"-")
            print("Iteration: ", iteration + 1, " of ", total_runs)
            print("Network:     ", net)
            print("Prior:       ", prior.name)
            print("Temperature: ", T)
            #print("Sample size: ", sample_size)
            #print("Epoch:       ", epochs)
            #print("Burn in:     ", burn_in)
            #print("Sample interval: ", sample_interval)

            # get data
            if sample_size == 120000:
                # if sample size is 120000, use data augmentation
                augmentations = tr.Compose([tr.RandomRotation(15)])
                train_data, test_data = Data("MNIST", augmentations = augmentations).get_data(num_train_samples=sample_size)
            else:
                # subsample original train data if sample size is smaller than 120000
                train_data, test_data = Data("MNIST", augmentations = None).get_data(num_train_samples=sample_size)


            # run BNN on a freshly initialised network
            model = BNN_MCMC(
                train_data,
                network = networks[net](),
                prior=prior,
                Temperature = T,
                num_epochs = int(epochs),
                max_size = 20,
                burn_in = int(burn_in),
                lr = 1e-3,
                sample_interval = int(sample_interval),
                device = device)

            model.train()

            # get test metrics
            acc, ece, auroc = model.get_metrics(test_data)
            post_mean, post_var = model.get_posterior_stats()

            #print("Test accuracy: ", acc)
            #print("Test ECE: ", ece)
            #print("Test AUROC: ", auroc)

            # save results
            results.iloc[iteration, :] = net, sample_size, epochs, burn_in, sample_interval, T, acc, ece, auroc, post_mean, post_var
            iteration += 1

# save results to csv
os.makedirs("results", exist_ok=True)
results.to_csv(f"results/results_{prior.name}.csv")