In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import torchvision

# Load MNIST data

In [2]:
mnist_transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    torchvision.transforms.Lambda(lambda im: im.reshape(-1))
])

mnist_train = torchvision.datasets.MNIST(
    root="data", train=True, transform=mnist_transforms, download=True)
mnist_test = mnist_train = torchvision.datasets.MNIST(
    root="data", train=False, transform=mnist_transforms, download=True)

mnist_loader = torch.utils.data.DataLoader(mnist_train, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(mnist_train, batch_size=1000, shuffle=True)

# Define the classification model

This is the same model used by the author's in section 5.1 of the paper. It has two hidden layers of 1200 units each, RELU activations and softmax outputs.

In [3]:
class Weights(nn.Module):
    MU = 0.0
    SIGMA = np.log(1 + np.exp(-5.0))
    
    def __init__(self, n, m):
        super().__init__()
        
        self.mu = nn.Parameter(torch.zeros(n, m))
        rho_init = torch.zeros(n, m)
        rho_init.fill_(-5.0)
        self.rho = nn.Parameter(rho_init)
        
    def sample(self):
        epsilon = torch.randn_like(self.mu)
        
        return self.mu + epsilon * F.softplus(self.rho)
    
    def kl_divergence(self, w):
        # Compute only the parts that influence the gradient
        sigma = F.softplus(self.rho)
        return 0.5 * ((w - self.MU)**2 / self.SIGMA - (w - self.mu)**2 / sigma - sigma.log()).sum()
        

class BBB(torch.nn.Module):
    def __init__(self):
        super().__init__()
        
        self.w1 = Weights(1200, 784)
        self.w2 = Weights(1200, 1200)
        self.wo = Weights(10, 1200)
        
    def forward(self, X):
        X = F.relu(F.linear(X, self.w1.sample()))
        X = F.relu(F.linear(X, self.w2.sample()))
        X = F.linear(X, self.wo.sample())
        
        return F.softmax(X, dim=-1)
    
    def elbo(self, X, y, dataset_size):
        w1 = self.w1.sample()
        w2 = self.w2.sample()
        wo = self.wo.sample()
        kld = self.w1.kl_divergence(w1) + self.w2.kl_divergence(w2) + self.wo.kl_divergence(wo)
        
        X = F.relu(F.linear(X, self.w1.sample()))
        X = F.relu(F.linear(X, self.w2.sample()))
        logits = F.linear(X, self.wo.sample())
        
        max, _ = logits.max(dim=-1, keepdim=True)
        o = (logits - max).exp()
        log_likelihood = (torch.gather(logits, -1, y.view(-1, 1)) - max - torch.log(o.sum(dim=-1))).sum()
        
        return (len(logits) / dataset_size) * kld - log_likelihood

In [4]:
bbb = BBB()
optimizer = torch.optim.Adam(bbb.parameters(), lr=0.01)

for epoch in range(2):
    for batch, data in enumerate(mnist_loader):
        imgs, labels = data
        
        optimizer.zero_grad()
        loss = bbb.elbo(imgs, labels, len(mnist_train))
        loss.backward()
        optimizer.step()
        
        if batch % 10 == 0:
            imgs, labels = next(iter(test_loader))
            classifications = bbb.forward(imgs).argmax(dim=-1)
            accuracy = float((classifications == labels).sum()) / len(labels)
            print(f"Batch {batch}: accuracy {accuracy:.2f} loss {float(loss)}")

Batch 0: accuracy 0.09 loss 114309.40625
Batch 10: accuracy 0.61 loss 89058.25
Batch 20: accuracy 0.84 loss 87201.390625
Batch 30: accuracy 0.87 loss 90258.3671875
Batch 40: accuracy 0.90 loss 82404.0625
Batch 50: accuracy 0.91 loss 78465.9921875
Batch 60: accuracy 0.93 loss 82300.828125
Batch 70: accuracy 0.93 loss 80534.3671875
Batch 0: accuracy 0.93 loss 82542.3359375
Batch 10: accuracy 0.94 loss 76018.8125
Batch 20: accuracy 0.92 loss 77017.34375
Batch 30: accuracy 0.94 loss 77417.390625
Batch 40: accuracy 0.94 loss 74072.390625
Batch 50: accuracy 0.95 loss 73377.65625
Batch 60: accuracy 0.95 loss 76385.9453125
Batch 70: accuracy 0.93 loss 73083.828125
