In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import torchvision

# Load MNIST data

In [2]:
mnist_transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    torchvision.transforms.Lambda(lambda im: im.reshape(-1))
])

mnist_train = torchvision.datasets.MNIST(
    root="data", train=True, transform=mnist_transforms, download=True)
mnist_test = mnist_train = torchvision.datasets.MNIST(
    root="data", train=False, transform=mnist_transforms, download=True)

mnist_loader = torch.utils.data.DataLoader(mnist_train, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(mnist_train, batch_size=1000, shuffle=True)

# Define the classification model

This is the same model used by the author's in section 5.1 of the paper. It has two hidden layers of 1200 units each, RELU activations and softmax outputs.

In [3]:
class Weights(nn.Module):
    MU = 0.0
    SIGMA = np.log(1 + np.exp(-5.0))
    
    def __init__(self, n, m):
        super().__init__()
        
        self.mu = nn.Parameter(torch.zeros(n, m))
        rho_init = torch.zeros(n, m)
        rho_init.fill_(-5.0)
        self.rho = nn.Parameter(rho_init)
        
    def sample(self):
        epsilon = torch.randn_like(self.mu)
        
        return self.mu + epsilon * F.softplus(self.rho)
    
    def kl_divergence(self, w):
        # Compute only the parts that influence the gradient
        sigma = F.softplus(self.rho)
        return 0.5 * ((w - self.MU)**2 / self.SIGMA - (w - self.mu)**2 / sigma - sigma.log()).sum()
        

class BBB(torch.nn.Module):
    def __init__(self):
        super().__init__()
        
        self.w1 = Weights(1200, 784)
        self.w2 = Weights(1200, 1200)
        self.wo = Weights(10, 1200)
        
    def forward(self, X):
        X = F.relu(F.linear(X, self.w1.sample()))
        X = F.relu(F.linear(X, self.w2.sample()))
        X = F.linear(X, self.wo.sample())
        
        return F.softmax(X, dim=-1)
    
    def elbo(self, X, y):
        w1 = self.w1.sample()
        w2 = self.w2.sample()
        wo = self.wo.sample()
        kld = self.w1.kl_divergence(w1) + self.w2.kl_divergence(w2) + self.wo.kl_divergence(wo)
        
        X = F.relu(F.linear(X, self.w1.sample()))
        X = F.relu(F.linear(X, self.w2.sample()))
        logits = F.linear(X, self.wo.sample())
        
        max, _ = logits.max(dim=-1, keepdim=True)
        o = F.softmax(logits - max, dim=-1)
        log_likelihood = (torch.gather(logits, -1, y.view(-1, 1)) - max - torch.log(o.sum(dim=-1))).sum()
        
        return kld - log_likelihood

In [4]:
bbb = BBB()
optimizer = torch.optim.Adam(bbb.parameters(), lr=0.0001)

for epoch in range(5):
    for batch, data in enumerate(mnist_loader):
        imgs, labels = data
        
        loss = bbb.elbo(imgs, labels)
        loss.backward()
        optimizer.step()
        
        if batch % 10 == 0:
            imgs, labels = next(iter(test_loader))
            classifications = bbb.forward(imgs).argmax(dim=-1)
            accuracy = float((classifications == labels).sum()) / len(labels)
            print(f"Batch {batch}: accuracy {accuracy:.2f} loss {float(loss)}")

Batch 0: accuracy 0.11 loss 5983774.5
Batch 10: accuracy 0.10 loss 5983862.0
Batch 20: accuracy 0.10 loss 5983708.0
Batch 30: accuracy 0.09 loss 5983533.5
Batch 40: accuracy 0.10 loss 5982471.5
Batch 50: accuracy 0.09 loss 5981312.5
Batch 60: accuracy 0.08 loss 5979854.0
Batch 70: accuracy 0.08 loss 5978455.5
Batch 0: accuracy 0.11 loss 5977111.0
Batch 10: accuracy 0.23 loss 5975173.5
Batch 20: accuracy 0.24 loss 5973715.0
Batch 30: accuracy 0.22 loss 5972642.5
Batch 40: accuracy 0.37 loss 5970956.5
Batch 50: accuracy 0.30 loss 5969783.0
Batch 60: accuracy 0.19 loss 5968726.0
Batch 70: accuracy 0.42 loss 5967480.0
Batch 0: accuracy 0.47 loss 5966655.5
Batch 10: accuracy 0.43 loss 5965555.5
Batch 20: accuracy 0.42 loss 5964541.0
Batch 30: accuracy 0.50 loss 5962734.5
Batch 40: accuracy 0.43 loss 5961326.5
Batch 50: accuracy 0.57 loss 5959913.0
Batch 60: accuracy 0.49 loss 5958515.5
Batch 70: accuracy 0.33 loss 5957094.5
Batch 0: accuracy 0.43 loss 5955914.5
Batch 10: accuracy 0.51 loss 