# What Are Bayesian Neural Network Posteriors Really Like?

In [179]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device)

Using device: cpu


### Chargement des données

In [180]:
class FashionMNIST_CSV(Dataset):  # PyTorch Dataset
    """PyTorch Dataset over a Fashion-MNIST CSV file.

    The CSV is expected to hold one image per row: the label in the first
    column, followed by 28*28 pixel values.

    Args:
        csv_path: path of the CSV file.
        transform: optional callable applied to each image (a float32 NumPy
            array of shape (1, 28, 28)) before it is converted to a tensor.
    """

    def __init__(self, csv_path, transform=None):
        # Whole table as a NumPy array
        self.data = pd.read_csv(csv_path).values
        # Column 0 holds the labels, the remaining columns the pixels
        self.labels = self.data[:, 0]
        pixels = self.data[:, 1:]
        # One grayscale channel per image: (N, 1, 28, 28)
        self.images = pixels.reshape(-1, 1, 28, 28).astype(np.float32)
        self.transform = transform

    def __len__(self):
        return self.labels.shape[0]

    def __getitem__(self, idx):
        image, label = self.images[idx], self.labels[idx]
        if self.transform:
            image = self.transform(image)
        # (image tensor, int64 label tensor)
        return torch.tensor(image), torch.tensor(label, dtype=torch.long)

def get_fashion_mnist_loaders_from_csv(batch_size=128, data_dir="data"):
    """Build the train/test DataLoaders for Fashion-MNIST stored as CSV files.

    Reads ``{data_dir}/fashion-mnist_train.csv`` (train set, 60000 images) and
    ``{data_dir}/fashion-mnist_test.csv`` (test set, 10000 images).

    Args:
        batch_size: minibatch size of both loaders.
        data_dir: directory containing the two CSV files (default ``"data"``).

    Returns:
        (train_loader, test_loader): the train loader shuffles every epoch, the
        test loader keeps the file order.
    """
    def scale_to_unit_interval(x):
        # Pixel values [0, 255] -> [0, 1]
        return x / 255.0

    train_dataset = FashionMNIST_CSV(f"{data_dir}/fashion-mnist_train.csv", transform=scale_to_unit_interval)
    test_dataset = FashionMNIST_CSV(f"{data_dir}/fashion-mnist_test.csv", transform=scale_to_unit_interval)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader


# Build the (shuffled) train loader and the (ordered) test loader from the CSV files
(train_loader, test_loader) = get_fashion_mnist_loaders_from_csv(128)

### Bayesian Neural Network (ResNet-20 architecture)

In [181]:
class BasicBlock(nn.Module):
    """Residual block: conv3x3-BN-act-conv3x3-BN, added to a shortcut, then act.

    The shortcut is the identity unless the stride or the channel count changes,
    in which case it is a 1x1 convolution followed by BatchNorm.

    NOTE: attribute names and their assignment order define the parameter order
    (``named_parameters``) and the ``state_dict`` keys; keep them stable.
    """

    def __init__(self, in_channels, out_channels, stride=1, use_bias=True, activation=nn.ReLU):
        super().__init__()
        # First half of the main path (may downsample through `stride`)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=use_bias)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.activation = activation()

        # Second half of the main path; the activation comes after the residual sum
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.bn2 = nn.BatchNorm2d(out_channels)

        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=use_bias),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        identity = self.shortcut(x)
        y = self.activation(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        return self.activation(y + identity)

class ResNet20(nn.Module):
    """ResNet-20: a stem conv, 3 stages of 3 BasicBlocks, global average pooling, linear head.

    Stage widths are ``width``, ``2*width`` and ``4*width``; stages 2 and 3 downsample
    by 2. Depth = 3 stages x 3 blocks x 2 convs + stem + head = 20.

    Args:
        num_classes: number of output logits.
        width: channels of the first stage (16 in the original ResNet-20).
        activation: activation *class* (e.g. ``nn.ReLU``). It is used in the stem
            AND in every residual block.
        input_channels: channels of the input images (1 for grayscale Fashion-MNIST).
    """

    def __init__(self, num_classes=10, width=8, activation=nn.ReLU, input_channels=1):
        super(ResNet20, self).__init__()
        self.num_blocks = 3  # ResNet-20 has 3 blocks per stage
        self.in_channels = width  # running channel count, updated by _make_layer
        self.activation = activation()

        self.conv1 = nn.Conv2d(input_channels, width, kernel_size=3, stride=1, padding=1, bias=True)
        self.bn1 = nn.BatchNorm2d(width)

        # Forward `activation` to the blocks: previously they silently always used ReLU
        self.stage1 = self._make_layer(width, self.num_blocks, stride=1, activation=activation)
        self.stage2 = self._make_layer(width * 2, self.num_blocks, stride=2, activation=activation)
        self.stage3 = self._make_layer(width * 4, self.num_blocks, stride=2, activation=activation)

        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(width * 4, num_classes)

    def _make_layer(self, out_channels, num_blocks, stride, activation=nn.ReLU):
        """Build one stage: a (possibly strided) block followed by ``num_blocks - 1`` blocks."""
        layers = [BasicBlock(self.in_channels, out_channels, stride, activation=activation)]
        self.in_channels = out_channels
        for _ in range(1, num_blocks):
            layers.append(BasicBlock(out_channels, out_channels, activation=activation))
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.activation(self.bn1(self.conv1(x)))
        out = self.stage1(out)
        out = self.stage2(out)
        out = self.stage3(out)
        out = self.avg_pool(out)
        out = torch.flatten(out, 1)  # (batch, 4 * width)
        out = self.fc(out)
        return out

#### Fonctions utilitaires

In [182]:
def set_weights(model, new_weights):
    """Overwrite the parameters of a PyTorch model with a list of tensors (e.g. an HMC sample).

    Args:
        model: module whose parameters are replaced in place (through ``load_state_dict``).
        new_weights: sequence of tensors, one per entry of ``model.named_parameters()``,
            in that order and with matching shapes. Buffers (e.g. BatchNorm running
            statistics) are left untouched.

    Raises:
        ValueError: if the number of tensors differs from the number of parameters.
    """
    param_names = [name for name, _ in model.named_parameters()]
    new_weights = list(new_weights)
    if len(new_weights) != len(param_names):
        raise ValueError(
            f"set_weights: expected {len(param_names)} tensors (one per parameter), got {len(new_weights)}"
        )

    state_dict = model.state_dict()  # parameters + buffers
    for name, value in zip(param_names, new_weights):
        state_dict[name] = value
    model.load_state_dict(state_dict)  # copies the new values into the existing parameters

def model_predictions(model, dataloader, device=None):
    """Run a model over a whole dataset and collect its softmax probabilities.

    Args:
        model: classifier returning logits with classes along dim 1.
        dataloader: iterable of ``(images, labels)`` batches; labels are ignored.
        device: device the images are moved to. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        Tensor of shape (dataset_size, num_classes), rows in dataloader order.

    Side effect: switches ``model`` to eval mode and leaves it there.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    probabilities = []
    with torch.no_grad():  # no gradients needed for prediction
        for images, _ in dataloader:
            # Was `image.device()`, which raises: Tensor.device is an attribute, not a method
            logits = model(images.to(device))
            probabilities.append(F.softmax(logits, dim=1))
    return torch.cat(probabilities, dim=0)  # (dataset_size, num_classes)

def BMA_predictions(probabilities):
    """Bayesian Model Average class predictions.

    Averages the predictive distributions of the M posterior samples,
    p(y|x, D) ~= 1/M * sum_i p(y|x, w_i), then takes the most probable class.

    Args:
        probabilities: tensor of shape (n_models, dataset_size, num_classes).

    Returns:
        Tensor of shape (dataset_size,) with the predicted class indices.
    """
    # Average over the first dimension (the models)
    average_predictions = probabilities.mean(dim=0)  # (dataset_size, num_classes)
    # Final prediction: the class with the highest averaged probability
    return average_predictions.argmax(dim=1)  # (dataset_size,)

    
def calculate_accuracy(predictions, labels):
    """Fraction of predictions that match the labels.

    Args:
        predictions: 1-D tensor of predicted class indices.
        labels: 1-D tensor or array-like (e.g. the NumPy ``labels`` array of a
            dataset) of true class indices; converted to a tensor on the same
            device as ``predictions``.

    Returns:
        Accuracy as a Python float in [0, 1].
    """
    labels = torch.as_tensor(labels, device=predictions.device)
    correct_predictions = (predictions == labels).sum().item()
    total_predictions = labels.size(0)
    return correct_predictions / total_predictions


In [183]:
# Scratch check: length of the first dimension of a 2x2 tensor built from NumPy
tensor = torch.tensor(np.array([[1, 2], [3, 4]]))
len(tensor)

2

### Fonctions de densité

In [184]:
def posterior_log_density_func(model, data_loader, weight_decay, device):
    """
    Minibatch estimate of the unnormalized log posterior
    f(w) = log p(D|w) + log p(w)  (log p(D) is constant in w and omitted).

    A fresh minibatch is drawn with ``next(iter(data_loader))`` on every call.
    log p(D|w) is minus the summed cross-entropy over that minibatch, and
    log p(w) = -0.5 * weight_decay * ||w||^2 (Gaussian prior, up to a constant).

    NOTE(review): the likelihood is not rescaled by dataset_size / batch_size, so
    relative to the prior it represents one minibatch, not the full dataset.
    Confirm this tempering is intended.

    Returns:
        Scalar tensor (still attached to the autograd graph).
    """
    batch_x, batch_y = (t.to(device) for t in next(iter(data_loader)))

    log_likelihood = -F.cross_entropy(model(batch_x), batch_y, reduction="sum")
    log_prior = -0.5 * sum(torch.sum(p**2) for p in model.parameters()) * weight_decay

    return log_likelihood + log_prior


def stochastic_grad_f(model, data_loader, weight_decay, device):
    """Minibatch estimate of the gradient of f, the unnormalized log posterior.

    f(w) = log p(B|w) + log p(w), where B is one minibatch drawn with
    ``next(iter(data_loader))``. log p(B|w) is minus the summed cross-entropy over B,
    and log p(w) = -0.5 * weight_decay * ||w||^2 (Gaussian prior, up to a constant).

    NOTE(review): the likelihood term is not rescaled by dataset_size / batch_size.
    Confirm this is intended.

    Args:
        model: network whose parameters are differentiated (its ``.grad`` fields are overwritten).
        data_loader: iterable yielding ``(inputs, targets)`` batches.
        weight_decay: precision of the Gaussian prior.
        device: device the minibatch is moved to.

    Returns:
        List of tensors, one per entry of ``model.parameters()``. It is the ascent
        direction of f (gradient of the log posterior, NOT of its negative).
        The tensors are detached copies, so later ``zero_grad``/``backward`` calls
        cannot alter them. A parameter that received no gradient gets zeros.
    """
    model.zero_grad()

    # Draw a minibatch
    data, target = next(iter(data_loader))
    data, target = data.to(device), target.to(device)

    # Log-likelihood of the minibatch
    output = model(data)
    log_p_data = -F.cross_entropy(output, target, reduction="sum")

    # Gaussian log-prior
    log_p_w = -0.5 * sum(torch.sum(p**2) for p in model.parameters()) * weight_decay

    f_w = log_p_data + log_p_w

    # Backpropagate f itself. Backpropagating -f would return -grad f, contradicting
    # the function's name. sg_leapfrog adds (delta / 2) * grad to the momentum, which
    # needs the ascent direction of the log posterior.
    f_w.backward()
    return [
        p.grad.detach().clone() if p.grad is not None else torch.zeros_like(p)
        for p in model.parameters()
    ]

### Algorithme SG-HMC (minibatch)

In [None]:
def sg_leapfrog(w, m, delta, n_leapfrog, model, data_loader, weight_decay, eta, device):
    """
    Stochastic-gradient leapfrog integrator used by SG-HMC.

    The trajectory starts from the weights ``w``, which are copied into ``model``'s
    parameters, and the momentum ``m``. It advances ``n_leapfrog`` steps using minibatch
    gradients from ``stochastic_grad_f``. In the first momentum half-step of every
    iteration, SG-HMC noise N(0, 2 * eta * delta) is injected together with the matching
    friction term -delta * eta * m (friction coefficient C = eta). Without that friction,
    the injected noise makes the momentum energy grow without bound.

    The momentum update ADDS the gradient, so it assumes ``stochastic_grad_f`` returns
    the gradient of the log posterior (ascent direction).

    Args:
        w: list of tensors, one per entry of ``model.parameters()`` (same order and shapes).
        m: list of momentum tensors with the same layout as ``w``. Not modified.
        delta: integration step size.
        n_leapfrog: number of leapfrog steps.
        model: network whose parameters are moved in place along the trajectory.
        data_loader: source of minibatches for the gradient estimates.
        weight_decay: precision of the Gaussian prior (forwarded to ``stochastic_grad_f``).
        eta: SG-HMC friction / noise coefficient.
        device: device the minibatches are moved to.

    Returns:
        (w_new, m_new): detached copies of the final parameters, and the negated final
        momentum (the negation makes the proposal reversible). ``model`` is left holding
        ``w_new``.
    """
    params = list(model.parameters())
    momentum = list(m)  # elements are rebound below; the caller's list stays intact

    # Start the trajectory from `w`, not from whatever the model currently holds
    with torch.no_grad():
        for p, w_i in zip(params, w):
            p.copy_(w_i)

    noise_std = float(np.sqrt(2 * eta * delta))
    for _ in range(n_leapfrog):
        # Step 1: momentum half-step, plus SG-HMC friction and noise
        grad_w = stochastic_grad_f(model, data_loader, weight_decay, device)
        for i, p in enumerate(params):
            noise = torch.randn_like(p) * noise_std
            momentum[i] = (1 - delta * eta) * momentum[i] + (delta / 2) * grad_w[i] + noise

        # Step 2: full step on the weights
        with torch.no_grad():
            for p, m_i in zip(params, momentum):
                p += delta * m_i

        # Step 3: second momentum half-step (noise already injected in step 1)
        grad_w = stochastic_grad_f(model, data_loader, weight_decay, device)
        for i in range(len(params)):
            momentum[i] = momentum[i] + (delta / 2) * grad_w[i]

    w_new = [p.detach().clone() for p in params]
    return w_new, [-m_i for m_i in momentum]


def _snapshot_weights(model):
    """Return detached copies of ``model``'s parameters, in ``model.parameters()`` order."""
    return [p.detach().clone() for p in model.parameters()]


def _load_weights(model, weights):
    """Copy ``weights`` (same order and shapes as ``model.parameters()``) into the model in place."""
    with torch.no_grad():
        for p, w_i in zip(model.parameters(), weights):
            p.copy_(w_i)


def _sample_momentum(model):
    """Draw a standard-normal momentum tensor for every parameter of ``model``."""
    return [torch.randn_like(p) for p in model.parameters()]


def _kinetic_energy(momentum):
    """Return K(m) = 0.5 * sum_i ||m_i||^2 (identity mass matrix) as a Python float."""
    return 0.5 * sum(float(torch.sum(m_i ** 2)) for m_i in momentum)


def _minibatch_log_posterior(model, weights, data, target, weight_decay):
    """Return the unnormalized log posterior of ``weights`` on a given minibatch, as a float.

    Loads ``weights`` into ``model``. Uses the same formula as the notebook's log
    posterior (minus the summed cross-entropy, minus 0.5 * weight_decay * ||w||^2), but
    on a caller-supplied batch. This lets both terms of the MH ratio be evaluated on the
    same data.
    """
    _load_weights(model, weights)
    with torch.no_grad():
        log_p_data = -F.cross_entropy(model(data), target, reduction="sum")
        log_p_w = -0.5 * sum(torch.sum(p ** 2) for p in model.parameters()) * weight_decay
    return float(log_p_data + log_p_w)


def SG_HMC(trajectory_length, n_burnin, model, data_loader, delta, n_samples, weight_decay, eta, device):
    """
    Stochastic Gradient Hamiltonian Monte Carlo (SG-HMC) with a Metropolis-Hastings step.

    Each iteration draws a fresh standard-normal momentum and runs ``sg_leapfrog`` for
    ``max(1, int(trajectory_length / delta))`` steps from the current weights. During
    sampling, the end point is then accepted or rejected with an MH test. Both
    log-posterior terms of that test are evaluated on one shared minibatch, so the ratio
    is itself only a stochastic estimate. Burn-in iterations always keep the end point.

    Args:
        trajectory_length: length of each simulated trajectory.
        n_burnin: number of burn-in trajectories (no MH test, not stored).
        model: neural network (torch.nn.Module). It is moved to ``device`` and ends up
            holding the last accepted weights.
        data_loader: DataLoader providing minibatches.
        delta: leapfrog step size.
        n_samples: number of samples to produce.
        weight_decay: precision of the Gaussian log-prior.
        eta: SG-HMC friction / noise coefficient (forwarded to ``sg_leapfrog``).
        device: CPU or GPU.

    Returns:
        List of ``n_samples`` weight samples. Each sample is a list of detached tensors
        in ``model.parameters()`` order.
    """
    # At least one step, even when delta > trajectory_length
    n_leapfrog = max(1, int(trajectory_length / delta))
    model.to(device)

    # Burn-in: follow the dynamics from the current weights, always keeping the end point
    for _ in tqdm(range(n_burnin), desc="Burn-in"):
        sg_leapfrog(_snapshot_weights(model), _sample_momentum(model), delta, n_leapfrog,
                    model, data_loader, weight_decay, eta, device)

    # Detached copies: later in-place updates of the model must not alias the chain state
    w_current = _snapshot_weights(model)
    w_samples = []
    for _ in tqdm(range(n_samples), desc="Sampling"):
        m_init = _sample_momentum(model)
        # Measured before the trajectory, since the integrator may update this list in place
        kinetic_init = _kinetic_energy(m_init)
        _, m_final = sg_leapfrog(w_current, m_init, delta, n_leapfrog,
                                 model, data_loader, weight_decay, eta, device)
        # The integrator leaves the end point of the trajectory in the model
        w_proposed = _snapshot_weights(model)

        # Metropolis-Hastings correction, both log posteriors on the same minibatch
        data, target = next(iter(data_loader))
        data, target = data.to(device), target.to(device)
        log_acceptance_ratio = (
            _minibatch_log_posterior(model, w_proposed, data, target, weight_decay)
            - _minibatch_log_posterior(model, w_current, data, target, weight_decay)
            + kinetic_init - _kinetic_energy(m_final)
        )
        if torch.rand(()).log().item() < log_acceptance_ratio:
            w_current = w_proposed

        # On rejection this restores the previous state; either way, the next
        # trajectory must start from w_current
        _load_weights(model, w_current)
        w_samples.append([w_i.clone() for w_i in w_current])

    return w_samples

In [186]:
# Smoke test: inject random "HMC-like" weights into a fresh ResNet-20
resnet = ResNet20(num_classes=10)
hmc_weights = list(map(torch.randn_like, resnet.parameters()))
set_weights(resnet, hmc_weights)

## Script global

In [187]:
### Hyper-parameters ###

# Seed every RNG used below (prior initialisation, minibatch shuffling, momenta,
# MH draws) so that Restart & Run All reproduces the same chain.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Gaussian prior parameters:
prior_variance = 1/5
std = np.sqrt(prior_variance)
weight_decay = 1/(std**2)  # prior precision, i.e. 1 / prior_variance

# HMC parameters:
trajectory_length = (np.pi*std)/2  # formula from the paper
n_burnin = 10  # 50 in the paper
delta = 1e-5  # 1e-5, 5e-5, 1e-4 in the paper
n_samples = 240
eta = 1e-5
# NOTE: SG_HMC runs int(trajectory_length / delta) ~ 70 000 leapfrog steps per
# iteration, each with two minibatch forward/backward passes. Expect very long
# runtimes on CPU.

# Model, plus a sanity check of the density and gradient functions
model = ResNet20(num_classes=10)
f = posterior_log_density_func(model, train_loader, weight_decay, device)
grad_f = stochastic_grad_f(model, train_loader, weight_decay, device)

# Initialise the weights from the prior N(0, prior_variance)
w_init = [torch.normal(mean=0, std=std, size=p.shape) for p in model.parameters()]
set_weights(model, w_init)

### HMC and BMA predictions ###
w_samples = SG_HMC(
    trajectory_length=trajectory_length,
    n_burnin=n_burnin,
    model=model,
    data_loader=train_loader,
    delta=delta,
    n_samples=n_samples,
    weight_decay=weight_decay,
    eta=eta,
    device=device
)

# Test-set predictions of every posterior sample
probabilities = []
for w_sample in w_samples:
    set_weights(model, w_sample)
    probabilities.append(model_predictions(model, test_loader))

probabilities = torch.stack(probabilities)  # (n_samples, dataset_size, num_classes)
class_predictions = BMA_predictions(probabilities)
# FashionMNIST_CSV stores its labels in the NumPy array `.labels` (it has no `.targets`)
test_labels = torch.as_tensor(test_loader.dataset.labels, device=class_predictions.device)
accuracy = calculate_accuracy(class_predictions, test_labels)
print(f'Accuracy: {accuracy * 100:.2f}%')

Computing the log posterior density function
Computing the gradient of the log posterior density function
Running SG-HMC


Burn-in:   0%|          | 0/10 [00:34<?, ?it/s]


KeyboardInterrupt: 