## What are BNNs posterios really like?

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
import copy

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

Using device: cpu


## Chargement des données

Mettre fashion-mnist_train et test.csv dans data/

In [8]:
class FashionMNIST_CSV(Dataset): # Dataset pytorch
    def __init__(self, csv_path, transform=None):
        self.data = pd.read_csv(csv_path).values  # Charger le CSV en NumPy
        self.labels = self.data[:, 0]  # Première colonne = labels
        self.images = self.data[:, 1:].reshape(-1, 1, 28, 28).astype(np.float32)  # Reshape en (N, 1, 28, 28)
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return torch.tensor(image), torch.tensor(label, dtype=torch.long)

def get_fashion_mnist_loaders_from_csv(batch_size=128):
    #train-set 60000 images, test-set 10000
    train_dataset = FashionMNIST_CSV("../data/fashion-mnist_train.csv", transform=lambda x: (x / 255.0))  # Normalisation [0,1]
    test_dataset = FashionMNIST_CSV("../data/fashion-mnist_test.csv", transform=lambda x: (x / 255.0))

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader


## Bayesian Neural Network (ResNet quoi)

In [9]:
class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, use_bias=True, activation=nn.ReLU):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=use_bias)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.activation = activation()

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=use_bias)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=use_bias),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = self.activation(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        return self.activation(out)

class ResNet20(nn.Module):
    def __init__(self, num_classes=10, width=8, activation=nn.ReLU): # width 16 initialement
        super(ResNet20, self).__init__()
        self.num_blocks = 3  # ResNet-20 has 3 blocks per stage
        self.in_channels = width
        self.activation = activation()

        self.conv1 = nn.Conv2d(1, width, kernel_size=3, stride=1, padding=1, bias=True) #1 canal
        self.bn1 = nn.BatchNorm2d(width)

        self.stage1 = self._make_layer(width, self.num_blocks, stride=1)
        self.stage2 = self._make_layer(width * 2, self.num_blocks, stride=2)
        self.stage3 = self._make_layer(width * 4, self.num_blocks, stride=2)

        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(width * 4, num_classes)

    def _make_layer(self, out_channels, num_blocks, stride):
        layers = []
        layers.append(BasicBlock(self.in_channels, out_channels, stride))
        self.in_channels = out_channels
        for _ in range(1, num_blocks):
            layers.append(BasicBlock(out_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.activation(self.bn1(self.conv1(x)))
        out = self.stage1(out)
        out = self.stage2(out)
        out = self.stage3(out)
        out = self.avg_pool(out)
        out = torch.flatten(out, 1)
        out = self.fc(out)
        return out

## Test de l'architecture en supervisé

In [10]:
# Initialisation du modèle
model = ResNet20(num_classes=10).to(device)

# Définition de la fonction de perte et de l'optimiseur
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

def train(model, train_loader, criterion, optimizer, device, epochs=10):
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(train_loader):.4f}")

def evaluate(model, test_loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f'Accuracy on test set: {accuracy:.2f}%')
    return accuracy

# Initialise le dataloader
batch_size = 128*4
train_loader, test_loader = get_fashion_mnist_loaders_from_csv(batch_size=batch_size)

# Entraînement du modèle
train(model, train_loader, criterion, optimizer, device, epochs=10)

# Évaluation sur le dataset de test
evaluate(model, test_loader, device)

Epoch [1/10], Loss: 0.8066


KeyboardInterrupt: 

92% accuracy en 10 epoches, l'architecture semble bonne

## Fonctions utilitaires

In [35]:
def set_weights(model, new_weights):
    """Remplace les poids d'un modèle PyTorch par ceux issus de HMC"""
    state_dict = model.state_dict()  # Dictionnaire des paramètres du modèle
    param_keys = [name_layer for name_layer, _ in list(model.named_parameters())]

    #param_keys = list(state_dict.keys()) # liste des paramètres du modèle

    #with torch.no_grad():
    for i, key in enumerate(param_keys):
        state_dict[key] = new_weights[i]
    model.load_state_dict(state_dict) # charge les nouveaux paramètres

def model_predictions(model, dataloader):
    """Fait des prédictions d'un modèle donné sur tout le dataset et stocke les probabilités"""

    probabilities = []
    model.eval()
    with torch.no_grad():  # Désactive le calcul des gradients pour la prédiction
        for step, (image, label) in enumerate(dataloader):  # itère sur le dataset
            image = image.to(device)
            logits = model(image)
            probs = F.softmax(logits, dim=1)
            probabilities.append(probs)
    return torch.cat(probabilities, dim=0)  # (dataset_size, num_classes)

def BMA_predictions(probabilities):
    """Fait une prédiction moyenne Bayesian Model Average p(y|x, D) = 1/M * sum_i( p(y|x, wi))
    Args:
        probabilities: Tensor  (n_models, dataset_size, num_classes)"""

    n_models = probabilities.size(0)  # Nombre de modèles
    # Moyenne des prédictions sur la première dimension (celles des modèles)
    average_predictions = probabilities.mean(dim=0)  # (dataset_size, num_classes)

    # Prédiction finale : la classe qui a la probabilité la plus élevée
    class_predict = average_predictions.argmax(dim=1)  # (dataset_size)
    return class_predict


def calculate_accuracy(predictions, labels):
    """Calcule l'accuracy en comparant les prédictions aux labels."""
    correct_predictions = (predictions == labels).sum().item()  # Nombre de prédictions correctes
    total_predictions = labels.size(0)  # Nombre total d'exemples
    accuracy = correct_predictions / total_predictions  # Accuracy
    return accuracy


Fonctions de densité

In [13]:
def posterior_log_density_func(model, data_loader, weight_decay, device):
    """
    Approximation stochastique du posterior: log p(w|D) = log p(D|w) + log p(w) - log p(D)
    p(D) est une constante, on ne la calcule pas
    """
    # Échantillonne un mini-batch
    data, target = next(iter(data_loader))
    data, target = data.to(device), target.to(device)

    # Calcule la log-vraisemblance (log p(D | w))
    output = model(data)
    loss = F.cross_entropy(output, target, reduction="sum")  # NLL
    log_p_data = -loss  # On prend le négatif car HMC maximise log-likelihood

    # Calcule le log-prior (log p(w))
    log_p_w = -0.5 * sum(torch.sum(p**2) for p in model.parameters()) * weight_decay # + constante

    # f(w) = log p(D | w) + log p(w)
    f_w = log_p_data + log_p_w
    return f_w


def stochastic_grad_f(model, data_loader, weight_decay, device):
    """Approximation stochastique du gradient ∇f(w) avec un mini-batch"""
    model.zero_grad()

    # Échantillonne un mini-batch
    data, target = next(iter(data_loader))
    data, target = data.to(device), target.to(device)

    # Calcule la log-vraisemblance (log p(D | w))
    output = model(data)
    loss = F.cross_entropy(output, target, reduction="sum")  # NLL
    log_p_data = -loss  # On prend le négatif car HMC maximise log-likelihood

    # Calcule le log-prior (log p(w))
    log_p_w = -0.5 * sum(torch.sum(p**2) for p in model.parameters()) * weight_decay

    # f(w) = log p(D | w) + log p(w)
    f_w = log_p_data + log_p_w

    # Gradient ∇f(w)
    (-f_w).backward()
    gradients = [p.grad for p in model.parameters()]

    return gradients

## Algorithme MHC minibatch

In [65]:

#### CONDITION DE METROPOLIS VERSION GPT


def sg_leapfrog(w, m, delta, n_leapfrog, model, data_loader, weight_decay, eta, device):
    """
    Stochastic Gradient Leapfrog intégrant le bruit correctif pour SG-HMC.
    """
    for _ in range(n_leapfrog):
        # Étape 1 : Mise à jour du momentum
        grad_w = stochastic_grad_f(model, data_loader, weight_decay, device)
        for i, p in enumerate(model.parameters()):
            noise = torch.normal(mean=0, std=np.sqrt(2 * eta * delta), size=p.shape, device=device)
            m[i] += (delta / 2) * grad_w[i] + noise

        # Étape 2 : Mise à jour des poids
        with torch.no_grad():
            for i, p in enumerate(model.parameters()):
                p += delta * m[i]

        # Étape 3 : Dernière mise à jour du momentum
        grad_w = stochastic_grad_f(model, data_loader, weight_decay, device)
        for i, p in enumerate(model.parameters()):
            m[i] += (delta / 2) * grad_w[i]

    return w, [-mi for mi in m]  # Inversion du momentum pour réversibilité

def SG_HMC(trajectory_length, n_burnin, model, data_loader, delta, n_samples, weight_decay, eta, device):
    """
    Stochastic Gradient Hamiltonian Monte Carlo (SG-HMC)
    """
    n_leapfrog = int(trajectory_length / delta)
    model.to(device)

    # Initialisation des poids et momentums
    w = [p.clone().detach() for p in model.parameters()]
    m = [torch.normal(mean=torch.zeros_like(p), std=torch.ones_like(p)) for p in model.parameters()]

    # Burn-in phase
    for _ in tqdm(range(n_burnin), desc="Burn-in"):
        m = [torch.normal(mean=torch.zeros_like(p), std=torch.ones_like(p)) for p in model.parameters()]
        w, m = sg_leapfrog(w, m, delta, n_leapfrog, model, data_loader, weight_decay, eta, device)

    # Échantillonnage
    w_samples = []
    n_acceptations = 0
    for _ in tqdm(range(n_samples), desc="Sampling"):
        m = [torch.normal(mean=torch.zeros_like(p), std=torch.ones_like(p)) for p in model.parameters()]
        w_proposed, m_proposed = sg_leapfrog(w, m, delta, n_leapfrog, model, data_loader, weight_decay, eta, device)

        # Calcul du log-ratio d'acceptation
        model_copy = copy.deepcopy(model).to(device)
        set_weights(model_copy, w_proposed)

        log_acceptance_ratio = (
            posterior_log_density_func(model_copy, data_loader, weight_decay, device)
            - posterior_log_density_func(model, data_loader, weight_decay, device)
            + 0.5 * (sum(mi.pow(2).sum() for mi in m) - sum(mi_prop.pow(2).sum() for mi_prop in m_proposed))
        )

        # Test d'acceptation Metropolis-Hastings
        if torch.rand(1).to(device) < torch.exp(log_acceptance_ratio):
            w = [p.clone().detach() for p in w_proposed]
            n_acceptations +=1

        w_samples.append([p.clone().detach() for p in w])
    print( "n_acceptations : ", n_acceptations)
    return w_samples


In [73]:

# CONDITION DE METOPOLIS VERSION LE PAPIER

def sg_leapfrog(w, m, delta, n_leapfrog, model, data_loader, weight_decay, eta, device):
    """
    Stochastic Gradient Leapfrog intégrant le bruit correctif pour SG-HMC.
    """
    for _ in range(n_leapfrog):
        # Étape 1 : Mise à jour du momentum
        grad_w = stochastic_grad_f(model, data_loader, weight_decay, device)
        for i, p in enumerate(model.parameters()):
            noise = torch.normal(mean=0, std=np.sqrt(2 * eta * delta), size=p.shape, device=device)
            m[i] += (delta / 2) * grad_w[i] + noise

        # Étape 2 : Mise à jour des poids
        with torch.no_grad():
            for i, p in enumerate(model.parameters()):
                p += delta * m[i]

        # Étape 3 : Dernière mise à jour du momentum
        grad_w = stochastic_grad_f(model, data_loader, weight_decay, device)
        for i, p in enumerate(model.parameters()):
            m[i] += (delta / 2) * grad_w[i]

    return w, [mi for mi in m]  # Inversion du momentum pour réversibilité

def SG_HMC(trajectory_length, n_burnin, model, data_loader, delta, n_samples, weight_decay, eta, device):
    """
    Stochastic Gradient Hamiltonian Monte Carlo (SG-HMC)
    """
    n_leapfrog = int(trajectory_length / delta)
    model.to(device)

    # Initialisation des poids et momentums
    w = [p.clone().detach() for p in model.parameters()]
    m = [torch.normal(mean=torch.zeros_like(p), std=torch.ones_like(p)) for p in model.parameters()]

    # Burn-in phase
    for _ in tqdm(range(n_burnin), desc="Burn-in"):
        m = [torch.normal(mean=torch.zeros_like(p), std=torch.ones_like(p)) for p in model.parameters()]
        w, m = sg_leapfrog(w, m, delta, n_leapfrog, model, data_loader, weight_decay, eta, device)

    # Échantillonnage
    w_samples = []
    n_acceptations = 0
    for _ in tqdm(range(n_samples), desc="Sampling"):
        m = [torch.normal(mean=torch.zeros_like(p), std=torch.ones_like(p)) for p in model.parameters()]
        w_proposed, m_proposed = sg_leapfrog(w, m, delta, n_leapfrog, model, data_loader, weight_decay, eta, device)

        # Calcul du log-ratio d'acceptation
        model_copy = copy.deepcopy(model).to(device)
        set_weights(model_copy, w_proposed)

        f_ratio = posterior_log_density_func(model_copy, data_loader, weight_decay, device) / \
                  posterior_log_density_func(model, data_loader, weight_decay, device)

        p_accept = torch.min(
            torch.tensor(1.0, device=device),
            f_ratio * torch.exp(0.5 * (sum(mi.pow(2).sum() for mi in m) - sum(mi_prop.pow(2).sum() for mi_prop in m_proposed)))
        )

        # Test d'acceptation Metropolis-Hastings
        if torch.rand(1).to(device) < p_accept:
            w = [p.clone().detach() for p in w_proposed]
            print("poids acceptés")
            n_acceptations +=1

        w_samples.append([p.clone().detach() for p in w])
    print( "n_acceptations : ", n_acceptations)
    return w_samples


## Script global

In [72]:
### Hyper-paramètres ###

# Paramètres du prior gaussien :
prior_variance = 1/5
std = np.sqrt(prior_variance)
weight_decay = 1/(std**2) # definition

# Paramètres HMC :
trajectory_length = (np.pi*std)/2 # formule du papier
n_burnin = 20 # 50 dans le papier
delta = 5e-2 # 1e-5, 5e-5, 1e-4 dans le papier, 1e-3 = 3h,
n_samples = 300
eta = 1e-6

# Initialise le dataloader
batch_size = 128*4
train_loader, test_loader = get_fashion_mnist_loaders_from_csv(batch_size=batch_size)

# Choix du modèle et des fonctions
model = ResNet20(num_classes=10).to(device)
f = posterior_log_density_func(model, train_loader, weight_decay, device)
grad_f = stochastic_grad_f(model, train_loader, weight_decay, device)

# Initialise les poids du modèle suivant le prior
w_init = [torch.normal(mean=0, std=std, size=p.shape) for p in model.parameters()]
set_weights(model, w_init)

### HMC et BMA predictions ###
w_samples = SG_HMC(
    trajectory_length=trajectory_length,
    n_burnin=n_burnin,
    model=model,
    data_loader=train_loader,
    delta=delta,
    n_samples=n_samples,
    weight_decay=weight_decay,
    eta=eta,
    device=device
)

Burn-in: 100%|██████████| 20/20 [00:30<00:00,  1.52s/it]
Sampling: 100%|██████████| 300/300 [07:29<00:00,  1.50s/it]

n_acceptations :  0





In [60]:
import sys
print(f" Taille en mémoire des poids du ResNet : {sys.getsizeof(w_samples) / 1000 :.2f} GB")

 Taille en mémoire des poids du ResNet : 2.52 GB


In [63]:
# Faire des prédictions avec les échantillons de poids
probabilities = torch.zeros(len(w_samples), len(test_loader.dataset), 10, device=device)

for i, w_sample in enumerate(tqdm(w_samples, desc="Processing samples")):
    set_weights(model, w_sample)
    probabilities[i] = model_predictions(model, test_loader)  # Remplissage direct

class_predictions = BMA_predictions(probabilities)

accuracy = calculate_accuracy(class_predictions, torch.tensor(test_loader.dataset.labels).to(device))
print(f'\nAccuracy: {accuracy * 100:.2f}%')

Processing samples: 100%|██████████| 300/300 [01:45<00:00,  2.83it/s]


Accuracy: 10.00%





## Expériences du papier à reproduire

Idées:
- 3 chaînes vs 1 chaine
- Différents hyper paramètres à compute fixé
- visualisation du postérieur avec 3 points en 2D ! Très stylé