# [ASI Project] Weight Uncertainty in Neural Networks  
**Authors**: Miriam Lamari, Francesco Giannuzzo  


In [None]:
import csv
import math
import functools as ft
from typing import Any, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from kaggle_secrets import UserSecretsClient
import wandb


### W&B (Weights & Biases) utility code

In [2]:
# Authenticate with Weights & Biases using the API key stored under the secret
# name 'wandb-api-key'. NOTE(review): `kaggle_secrets` is presumably only
# available inside a Kaggle notebook environment — confirm before running elsewhere.
user_secrets = UserSecretsClient()
key = user_secrets.get_secret('wandb-api-key')

wandb.login(key=key)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mmiriam-lamari2[0m ([33mmiriam-lamari2-eurecom[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

## Minibatches
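**minibatch_weight**(batch_idx: int, num_batches: int): weight $\pi_i = \frac{2^{M-i}}{2^{M}-1}$ applied to the complexity (KL) cost of minibatch $i \in \{1, \dots, M\}$, where $M$ is the number of minibatches per epoch. The weights sum to one over an epoch, so early minibatches are dominated by the complexity cost and later ones by the data likelihood.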

In [None]:
def minibatch_weight(batch_idx: int, num_batches: int) -> float:
    """Return the complexity-cost weight pi_i for one minibatch.

    Implements pi_i = 2^(M - i) / (2^M - 1) for minibatch i in {1, ..., M},
    with M = ``num_batches``. ``batch_idx`` is 0-based (as produced by
    ``enumerate`` over a DataLoader), so i = batch_idx + 1. With this indexing
    the weights for batch_idx = 0 .. M-1 sum to exactly one; indexing the
    formula directly with a 0-based index would make them sum to two.

    Args:
        batch_idx: 0-based position of the minibatch within the epoch.
        num_batches: Number of minibatches M in the epoch; must be >= 1.

    Returns:
        The weight pi_{batch_idx + 1} as a float.

    Raises:
        ValueError: If ``num_batches`` < 1 (the normaliser 2^M - 1 would be 0).
    """
    if num_batches < 1:
        raise ValueError(f"num_batches must be >= 1, got {num_batches}")
    # Python ints are arbitrary precision, so 2 ** num_batches cannot overflow;
    # int / int true division then gives a correctly rounded float.
    return 2 ** (num_batches - batch_idx - 1) / (2 ** num_batches - 1)

## Variational Approximation

In [None]:
class BayesianModule(nn.Module):
    """Marker base class for Bayesian layers in a Bayesian neural network.

    ``variational_approximator`` finds instances of this class with
    ``isinstance`` and sums their ``kl_divergence`` attribute, so subclasses
    are expected to provide that attribute and to override :meth:`kld`.
    ``nn.Module.__init__`` is inherited unchanged.
    """

    def kld(self, *args):
        """Compute this layer's KL term; subclasses must override this."""
        raise NotImplementedError('BayesianModule::kld()')


def variational_approximator(model: nn.Module) -> nn.Module:
    """Attach ``kl_divergence`` and ``elbo`` methods to a model.

    Intended primarily as a class decorator on an ``nn.Module`` subclass, but
    an already-constructed module instance is also accepted: in that case the
    two functions are explicitly bound to the instance. A plain function stored
    on an instance is not bound, so ``model.elbo(...)`` would otherwise be
    called without ``self``.

    Args:
        model: An ``nn.Module`` subclass or an ``nn.Module`` instance.

    Returns:
        The same class or instance, with ``kl_divergence`` and ``elbo`` added.
    """
    import types

    def kl_divergence(self) -> Tensor:
        """Sum the ``kl_divergence`` attribute of every BayesianModule in the model.

        The value reflects whatever each layer stored on its latest forward
        pass (``BayesLinear`` updates it in ``forward``). Returns 0 if the
        model contains no BayesianModule.
        """
        # sum() folds left with out-of-place additions, so no layer's stored
        # KL tensor is ever modified in place.
        return sum(
            (module.kl_divergence
             for module in self.modules()
             if isinstance(module, BayesianModule)),
            0,
        )

    def elbo(self, inputs: Tensor, targets: Tensor, criterion: Any, n_samples: int,
             w_complexity: float = 1.0) -> Tensor:
        """Monte-Carlo estimate of the loss (negative ELBO) averaged over samples.

        Each of the ``n_samples`` stochastic forward passes contributes
        ``criterion(outputs, targets)`` plus the model's KL term scaled by
        ``w_complexity``. The KL term is read right after its own forward pass,
        so it matches the sampled weights.

        Raises:
            ValueError: If ``n_samples`` < 1.
        """
        if n_samples < 1:
            raise ValueError(f"n_samples must be >= 1, got {n_samples}")
        loss = 0
        for _ in range(n_samples):
            outputs = self(inputs)
            loss = loss + criterion(outputs, targets) + self.kl_divergence() * w_complexity
        return loss / n_samples

    if isinstance(model, type):
        # Class-decorator use: functions stored on a class become bound
        # methods automatically through the descriptor protocol.
        setattr(model, 'kl_divergence', kl_divergence)
        setattr(model, 'elbo', elbo)
    else:
        # Instance use: bind explicitly so `self` is supplied on call.
        setattr(model, 'kl_divergence', types.MethodType(kl_divergence, model))
        setattr(model, 'elbo', types.MethodType(elbo, model))

    return model


## Scale Mixture Prior

In [None]:
class ScaleMixture(nn.Module):
    """Two-component, zero-mean Gaussian scale-mixture prior.

    p(w) = pi * N(w | 0, sigma1^2) + (1 - pi) * N(w | 0, sigma2^2),
    applied independently to every element of a weight tensor.
    """

    def __init__(self, pi: float, sigma1: float, sigma2: float) -> None:
        super().__init__()
        # Mixture weight and the standard deviations of the two components.
        self.pi, self.sigma1, self.sigma2 = pi, sigma1, sigma2

        self.normal1 = torch.distributions.Normal(0, sigma1)
        self.normal2 = torch.distributions.Normal(0, sigma2)

    def log_prior(self, w: Tensor) -> Tensor:
        """Return the sum over elements of log p(w).

        The two densities are mixed in probability space (exp, then log). If
        both component densities underflow to 0 for some element, the result
        is -inf; the log-sum-exp variant avoids this.
        """
        density1 = self.normal1.log_prob(w).exp()
        density2 = self.normal2.log_prob(w).exp()
        mixture = density1 * self.pi + density2 * (1 - self.pi)
        return mixture.log().sum()

### Variant: mixing the components with log-sum-exp instead of exp-then-log (numerically stable)

In [None]:
class ScaleMixture(nn.Module):
    """Two-component, zero-mean Gaussian scale-mixture prior evaluated in log space.

    p(w) = pi * N(w | 0, sigma1^2) + (1 - pi) * N(w | 0, sigma2^2),
    applied independently to every element of a weight tensor. The mixture is
    combined with log-sum-exp, so it stays finite in the tails where the
    component densities themselves would underflow.

    Raises:
        ValueError: If ``pi`` is not in [0, 1].
    """

    def __init__(self, pi: float, sigma1: float, sigma2: float) -> None:
        super().__init__()
        if not 0.0 <= pi <= 1.0:
            raise ValueError(f"pi must lie in [0, 1], got {pi}")
        self.pi = pi
        self.sigma1 = sigma1
        self.sigma2 = sigma2

        self.normal1 = torch.distributions.Normal(0, sigma1)
        self.normal2 = torch.distributions.Normal(0, sigma2)

    def log_prior(self, w: Tensor) -> Tensor:
        """Return the sum over elements of log p(w)."""
        # self.pi is a Python float: torch.log only accepts tensors (and raised
        # TypeError here), so take the log of the mixture weights with math.
        # A weight of exactly 0 maps to -inf, which simply drops that component
        # out of the log-sum-exp instead of making math.log raise.
        log_pi = math.log(self.pi) if self.pi > 0 else -math.inf
        log_one_minus_pi = math.log1p(-self.pi) if self.pi < 1 else -math.inf

        log_mix1 = log_pi + self.normal1.log_prob(w)
        log_mix2 = log_one_minus_pi + self.normal2.log_prob(w)

        # Stable log-sum-exp over the two components, then sum over elements.
        return torch.logsumexp(torch.stack([log_mix1, log_mix2]), dim=0).sum()


## Gaussian Variational Inference

In [None]:
class GaussianVariational(nn.Module):
    """Factorised Gaussian variational posterior q(w) = N(mu, sigma^2), sigma = softplus(rho).

    Samples use the reparameterisation w = mu + sigma * eps with eps ~ N(0, 1),
    so gradients flow to the learnable ``mu`` and ``rho``.
    """

    def __init__(self, mu: Tensor, rho: Tensor) -> None:
        super().__init__()

        self.mu = nn.Parameter(mu)
        self.rho = nn.Parameter(rho)

        # Most recent sample and its standard deviation; set by sample().
        self.w = None
        self.sigma = None

        # Kept for backward compatibility; sample() now draws eps with randn_like.
        self.normal = torch.distributions.Normal(0, 1)

    def sample(self) -> Tensor:
        """Draw w = mu + softplus(rho) * eps, cache it on the module and return it."""
        # randn_like draws eps directly with mu's shape, dtype and device,
        # avoiding a CPU sample plus host-to-device copy on every call.
        epsilon = torch.randn_like(self.mu)
        # softplus(rho) == log(1 + exp(rho)), but stays accurate for very
        # negative rho (where 1 + exp(rho) rounds to 1, giving sigma == 0 and a
        # NaN/-inf log_posterior) and does not overflow for large rho.
        self.sigma = F.softplus(self.rho)
        self.w = self.mu + self.sigma * epsilon
        return self.w

    def log_posterior(self) -> Tensor:
        """Return log q(w) of the cached sample, summed over all elements.

        Raises:
            ValueError: If called before :meth:`sample`.
        """
        if self.w is None:
            raise ValueError('self.w must have a value.')

        log_const = 0.5 * math.log(2 * math.pi)
        log_exp = ((self.w - self.mu) ** 2) / (2 * self.sigma ** 2)
        log_posterior = -log_const - torch.log(self.sigma) - log_exp

        return log_posterior.sum()

## Bayesian Linear Layer ##

In [None]:
class BayesLinear(BayesianModule):
    """Fully-connected layer whose weight and bias are sampled on every forward pass.

    Each parameter tensor has a GaussianVariational posterior and a
    ScaleMixture prior. ``forward`` stores a single-sample estimate of
    log q - log p in ``self.kl_divergence``.
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 prior_pi: Optional[float] = 0.5,
                 prior_sigma1: Optional[float] = 1.0,
                 prior_sigma2: Optional[float] = 0.0025) -> None:

        super().__init__()

        weight_shape = (out_features, in_features)
        bias_shape = (out_features,)

        # Initial variational parameters (drawn in this order).
        weight_mu = torch.empty(weight_shape).uniform_(-0.2, 0.2)
        weight_rho = torch.empty(weight_shape).uniform_(-5.0, -4.0)
        bias_mu = torch.empty(bias_shape).uniform_(-0.2, 0.2)
        bias_rho = torch.empty(bias_shape).uniform_(-5.0, -4.0)

        self.w_posterior = GaussianVariational(weight_mu, weight_rho)
        self.bias_posterior = GaussianVariational(bias_mu, bias_rho)

        prior_args = (prior_pi, prior_sigma1, prior_sigma2)
        self.w_prior = ScaleMixture(*prior_args)
        self.bias_prior = ScaleMixture(*prior_args)

        # Overwritten on every forward pass; summed by the model's kl_divergence().
        self.kl_divergence = 0.0

    def kld(self, log_prior: Tensor, log_posterior: Tensor) -> Tensor:
        """Return the single-sample KL estimate log q(w) - log p(w)."""
        return log_posterior - log_prior

    def forward(self, x: Tensor) -> Tensor:
        """Sample weight and bias, record the KL estimate, and apply x @ W^T + b."""
        weight = self.w_posterior.sample()
        bias = self.bias_posterior.sample()

        log_p = self.w_prior.log_prior(weight) + self.bias_prior.log_prior(bias)
        log_q = self.w_posterior.log_posterior() + self.bias_posterior.log_posterior()
        self.kl_divergence = self.kld(log_p, log_q)

        return F.linear(x, weight, bias)

In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Compare the device *type*: a torch.device does not compare equal to the plain
# string 'cuda', so `device == 'cuda'` was always False and kwargs stayed {}.
kwargs = {'num_workers': 1, 'pin_memory': True} if device.type == 'cuda' else {}

# define transforms: convert images to tensors, then normalise with the
# commonly used MNIST pixel mean (0.1307) and standard deviation (0.3081).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])


In [None]:
# load / process data for tuning
# Disabled: the code below is wrapped in a string literal, so it does not run.
# When re-enabled it splits the MNIST training set into 50,000 training and
# 10,000 validation images (for hyper-parameter tuning) and builds the
# corresponding loaders plus the MNIST test loader.

"""

full_trainset = datasets.MNIST('./data', train=True, download=True, transform=transform)

#train val split

train_size = 50000
val_size = 10000
trainset, valset = random_split(full_trainset, [train_size, val_size])

kwargs = {'shuffle': True, 'num_workers': 2, 'pin_memory': True}
trainloader = DataLoader(trainset, batch_size=128, **kwargs)
valloader = DataLoader(valset, batch_size=128, **kwargs)

testset = datasets.MNIST('./data', train=False, download=True, transform=transform)
testloader = DataLoader(testset, batch_size=128, **kwargs)

"""

"\n\nfull_trainset = datasets.MNIST('./data', train=True, download=True, transform=transform)\n\n#train val split\n\ntrain_size = 50000\nval_size = 10000\ntrainset, valset = random_split(full_trainset, [train_size, val_size])\n\nkwargs = {'shuffle': True, 'num_workers': 2, 'pin_memory': True}\ntrainloader = DataLoader(trainset, batch_size=128, **kwargs)\nvalloader = DataLoader(valset, batch_size=128, **kwargs)\n\ntestset = datasets.MNIST('./data', train=False, download=True, transform=transform)\ntestloader = DataLoader(testset, batch_size=128, **kwargs)\n\n"

In [None]:
# load / process data

full_trainset = datasets.MNIST('./data', train=True, download=True, transform=transform)

# Train on the full MNIST training set (no validation split here; the tuning
# variant with a held-out split is the disabled cell above).
trainset = full_trainset

kwargs = {'shuffle': True, 'num_workers': 2, 'pin_memory': True}
trainloader = DataLoader(trainset, batch_size=128, **kwargs)

testset = datasets.MNIST('./data', train=False, download=True, transform=transform)
# Evaluate in a fixed order: shuffling the test set gains nothing and, because
# the evaluation loss weights each batch by its position (minibatch_weight),
# it would make the logged loss vary from epoch to epoch.
testloader = DataLoader(testset, batch_size=128, **{**kwargs, 'shuffle': False})


100%|██████████| 9.91M/9.91M [00:01<00:00, 5.48MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 160kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 1.51MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 9.75MB/s]


In [None]:
def train_loop(learning_rate, prior_pi, prior_sigma1, prior_sigma2, epochs=25, n_samples=3):
    """Train a Bayes-by-Backprop MLP on MNIST and log per-epoch metrics to W&B.

    The network is 784 -> 1200 -> 1200 -> 10 with BayesLinear layers and ReLU
    activations. It uses the module-level ``device``, ``trainloader`` and
    ``testloader``. Each epoch minimises the Monte-Carlo ELBO on
    ``trainloader`` (complexity cost weighted per minibatch by
    ``minibatch_weight``), then evaluates on ``testloader`` and logs
    ``train_loss``, ``val_loss``, ``accuracy_val`` and ``test_error_val``.
    Reported losses are per example: summed batch losses / dataset size.

    Args:
        learning_rate: Adam learning rate.
        prior_pi: Mixture weight of the scale-mixture prior.
        prior_sigma1: Standard deviation of the first prior component.
        prior_sigma2: Standard deviation of the second prior component.
        epochs: Number of passes over ``trainloader``.
        n_samples: Weight samples per ELBO estimate.

    Returns:
        The trained model.
    """
    # Reuse the caller's run (train_wrapper and the final-run cell already call
    # wandb.init, and sweep configs live on that run); only start one if none
    # is active, and in that case finish it when training ends.
    owns_run = wandb.run is None
    run = wandb.init(project="Project-ASI") if owns_run else wandb.run

    @variational_approximator
    class BayesianNetwork(nn.Module):
        def __init__(self, input_dim, output_dim):
            super().__init__()
            self.bl1 = BayesLinear(input_dim, 1200, prior_pi, prior_sigma1, prior_sigma2)
            self.bl2 = BayesLinear(1200, 1200, prior_pi, prior_sigma1, prior_sigma2)
            self.bl3 = BayesLinear(1200, output_dim, prior_pi, prior_sigma1, prior_sigma2)

        def forward(self, x):
            x = x.view(-1, 28 * 28)
            x = F.relu(self.bl1(x))
            x = F.relu(self.bl2(x))
            return self.bl3(x)

    model = BayesianNetwork(28 * 28, 10).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # Sum over the batch: the ELBO's data term is a total, not a mean.
    criterion = nn.CrossEntropyLoss(reduction='sum')

    # pi_i is normalised over the number of minibatches per epoch (previously
    # the batch size, 128, was passed here by mistake).
    num_train_batches = len(trainloader)
    num_test_batches = len(testloader)

    try:
        for _ in range(epochs):
            train_loss = 0.0
            model.train()
            for batch_idx, (data, labels) in enumerate(trainloader):
                data, labels = data.to(device), labels.to(device)

                optimizer.zero_grad()
                loss = model.elbo(
                    inputs=data,
                    targets=labels,
                    criterion=criterion,
                    n_samples=n_samples,
                    w_complexity=minibatch_weight(batch_idx=batch_idx, num_batches=num_train_batches),
                )
                # The loss is already a batch total (reduction='sum'); do not
                # multiply by the batch size again.
                train_loss += loss.item()

                loss.backward()
                optimizer.step()

            val_loss = 0.0
            correct = 0
            total = 0
            model.eval()
            with torch.no_grad():
                for batch_idx, (data, labels) in enumerate(testloader):
                    data, labels = data.to(device), labels.to(device)

                    # One weight sample for the accuracy estimate.
                    outputs = model(data)
                    loss = model.elbo(
                        inputs=data,
                        targets=labels,
                        criterion=criterion,
                        n_samples=n_samples,
                        w_complexity=minibatch_weight(batch_idx=batch_idx, num_batches=num_test_batches),
                    )
                    val_loss += loss.item()

                    # softmax is monotonic, so the argmax of the logits is the
                    # predicted class.
                    predicted = outputs.argmax(dim=1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()

            accuracy_val = correct / total
            train_loss /= len(trainloader.dataset)
            val_loss /= len(testloader.dataset)

            metrics = {'train_loss': train_loss, 'val_loss': val_loss, 'accuracy_val': accuracy_val, "test_error_val": (1 - accuracy_val)}
            wandb.log(metrics)
    finally:
        if owns_run:
            run.finish()

    return model


In [12]:
def train_wrapper():
    """W&B sweep entry point: open a run, then train with the sweep's hyper-parameters."""
    wandb.init(project="Project-ASI")
    sweep_cfg = wandb.config

    return train_loop(
        learning_rate=sweep_cfg.learning_rate,
        prior_pi=sweep_cfg.prior_pi,
        prior_sigma1=sweep_cfg.prior_sigma1,
        prior_sigma2=sweep_cfg.prior_sigma2,
    )

## Tuning Hyperparameters

In [13]:
# Grid sweep over the scale-mixture prior with the learning rate fixed at 1e-3.
sweep_configuration = {
    "method": "grid",
    "metric": {"goal": "minimize", "name": "val_loss"},
    "name": "sweep-BBB-Gaussian",
    "parameters": {
        "learning_rate": {"values": [1e-3]},
        "prior_pi": {"values": [0.25, 0.5, 0.75]},
        "prior_sigma1": {"values": [1, math.exp(-1), math.exp(-2)]},
        "prior_sigma2": {"values": [math.exp(-k) for k in (6, 7, 8)]},
    },
}

#sweep_id = wandb.sweep(sweep=sweep_configuration, project="Project-ASI")
#wandb.agent(sweep_id, function=train_wrapper);

In [14]:
# Same grid over the scale-mixture prior, with the learning rate fixed at 1e-4.
sweep_configuration = {
    "method": "grid",
    "metric": {"goal": "minimize", "name": "val_loss"},
    "name": "sweep-BBB-Gaussian",
    "parameters": {
        "learning_rate": {"values": [1e-4]},
        "prior_pi": {"values": [0.25, 0.5, 0.75]},
        "prior_sigma1": {"values": [1, math.exp(-1), math.exp(-2)]},
        "prior_sigma2": {"values": [math.exp(-k) for k in (6, 7, 8)]},
    },
}

#sweep_id = wandb.sweep(sweep=sweep_configuration, project="Project-ASI")
#wandb.agent(sweep_id, function=train_wrapper);

In [15]:
# Hyper-parameters selected from the sweeps; keys match train_loop's parameter names.
best_hyperparameters = dict(
    learning_rate=1e-3,
    prior_pi=0.5,
    prior_sigma1=1.0,
    prior_sigma2=math.exp(-7),
)

In [16]:
# Final 200-epoch training run with the selected hyper-parameters; the dict's
# keys are exactly train_loop's keyword parameter names.
wandb.init(project="Project-ASI")

train_loop(**best_hyperparameters, epochs=200)

[34m[1mwandb[0m: Tracking run with wandb version 0.19.9
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/working/wandb/run-20250526_105820-4nwgvpt3[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33msnowy-flower-83[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/miriam-lamari2-eurecom/Project-ASI[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/miriam-lamari2-eurecom/Project-ASI/runs/4nwgvpt3[0m
[34m[1mwandb[0m: uploading output.log; uploading wandb-summary.json; uploading config.yaml
[34m[1mwandb[0m:                                                                                
[34m[1mwandb[0m: 🚀 View run [33msnowy-flower-83[0m at: [34m[4mhttps://wandb.ai/miriam-lamari2-eurecom/Project-ASI/runs/4nwgvpt3[0m
[34m[1mwandb[0m: ⭐️ View project at: [34m[4mhttps://wandb.ai/miriam-lamari2-eurecom/Project-ASI[0m
[34m[1mwandb[0m: Synced 5 W&B file(s), 0 med