In [1]:
import torch 
import torch.nn as nn
import torchbnn as bnn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid

import numpy as np 
import pandas as pd
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt

%matplotlib inline

In [2]:
# Convert MNIST images into 4D tensors (#images, height, width, color channel)

transform = transforms.ToTensor()

In [3]:
# Create training data 
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

In [4]:
# Create test data 
test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

In [5]:
# Params
filter = True
if filter:
    filtered_class = 5

load_model = False

In [6]:
# Remove the class 5 from the data 
if filter:
    filtered_indices = [i for i, (_,label) in enumerate(train_data) if label!=5]
    train_data = torch.utils.data.Subset(train_data, filtered_indices)
else:
    pass

In [7]:
# Create small batch size for images 
train_loader = DataLoader(train_data, batch_size=10, shuffle=True)
test_loader = DataLoader(test_data, batch_size=10, shuffle=False)

In [8]:
#Define the Model class 
class ConvolutionalNetwork(nn.Module):
    def __init__(self):
        super().__init__()

        # Bayesian Convolutional Layers
        self.conv1 = bnn.BayesConv2d(prior_mu=0, prior_sigma=0.1, in_channels=1, out_channels=6, kernel_size=3, stride=1)
        self.conv2 = bnn.BayesConv2d(prior_mu=0, prior_sigma=0.1, in_channels=6, out_channels=16, kernel_size=3, stride=1)

        # Bayesian Fully Connected Layers
        self.fc1 = bnn.BayesLinear(prior_mu=0, prior_sigma=0.1, in_features=16*5*5, out_features=120)      #16 filters, 5x5 size of each output "image" in the conv2 layer
        self.fc2 = bnn.BayesLinear(prior_mu=0, prior_sigma=0.1, in_features=120, out_features=84)
        self.fc3 = bnn.BayesLinear(prior_mu=0, prior_sigma=0.1, in_features=84, out_features=10)  

    def forward(self, X):
        # Pass through convolutional and pooling layers with ReLU activation
        X = F.relu(self.conv1(X))
        X = F.max_pool2d(X, 2,2)  #kernel = 2x2, stride = 2 
        X = F.relu(self.conv2(X))
        X = F.max_pool2d(X,2,2)

        # Flatten out the data
        X = X.view(-1, 16*5*5)      # -1 so that we can vary the batch size 

        # Pass through the fully connected layers 
        X = F.relu(self.fc1(X))
        X = F.relu(self.fc2(X))
        X = self.fc3(X)

        return F.log_softmax(X, dim=1)


In [9]:
# Create an instance of the model 
torch.manual_seed(41)
model = ConvolutionalNetwork()

In [10]:
# Select loss function and optimizer 
ce_loss = nn.CrossEntropyLoss()
kl_loss = bnn.BKLLoss(reduction='mean', last_layer_only=False)
kl_weight=0.1
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [11]:
# Variables to track losses
epochs = 5

if load_model:
    model.load_state_dict(torch.load('mnist_bnn.pt'))

else:
    #For loop over epochs
    for i in range(epochs):
        trn_corr = 0
        tst_corr = 0

        #Train
        for b, (X_train, y_train) in enumerate(train_loader):
            b+=1                                            #Start the batch at 1

            # Forward pass 
            y_pred = model(X_train)
            ce = ce_loss(y_pred, y_train)
            kl = kl_loss(model)
            loss = ce+ kl*kl_weight

            # Backward pass - update parameters
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            #Print out some results
            if b%3000==0:
                print(f"Epoch = {i} Batch = {b} Loss = {loss.item()}")

    # Save the model after training 
    torch.save(model.state_dict(), 'mnist_bnn.pt')

    X_test = torch.stack([data[0] for data in train_data])
    y_test = torch.LongTensor([data[1] for data in train_data])

    _, predicted = torch.max(model(X_train).data, 1)
    total = y_train.size(0)
    correct = (predicted == y_train).sum()
    print('- Accuracy: %f %%' % (100 * float(correct) / total))
    print('- CE : %2.2f, KL : %2.2f' % (ce.item(), kl.item()))

Epoch = 0 Batch = 3000 Loss = 0.14555498957633972
Epoch = 1 Batch = 3000 Loss = 0.14096182584762573
Epoch = 2 Batch = 3000 Loss = 0.14253216981887817
Epoch = 3 Batch = 3000 Loss = 0.12186375260353088
Epoch = 4 Batch = 3000 Loss = 0.279089093208313
- Accuracy: 100.000000 %
- CE : 0.00, KL : 1.11


In [12]:
#Test
tst_corr = 0

with torch.no_grad():
    for b, (X_test, y_test) in enumerate(test_loader):
        y_val = model(X_test)
        predicted = torch.max(y_val.data, 1)[1]
        batch_corr = (predicted == y_test).sum()
        tst_corr += batch_corr

print(f"Test accuracy : {(tst_corr/10000)*100} %")

Test accuracy : 87.58000183105469 %


In [13]:
X_test = torch.stack([data[0] for data in test_data])
y_test = torch.LongTensor([data[1] for data in test_data])

In [14]:
torch.argwhere(y_test==5)

tensor([[   8],
        [  15],
        [  23],
        [  45],
        [  52],
        [  53],
        [  59],
        [ 102],
        [ 120],
        [ 127],
        [ 129],
        [ 132],
        [ 152],
        [ 153],
        [ 155],
        [ 162],
        [ 165],
        [ 167],
        [ 182],
        [ 187],
        [ 207],
        [ 211],
        [ 218],
        [ 219],
        [ 240],
        [ 253],
        [ 261],
        [ 283],
        [ 289],
        [ 317],
        [ 319],
        [ 333],
        [ 340],
        [ 347],
        [ 351],
        [ 352],
        [ 356],
        [ 364],
        [ 367],
        [ 375],
        [ 395],
        [ 397],
        [ 406],
        [ 412],
        [ 433],
        [ 460],
        [ 469],
        [ 478],
        [ 483],
        [ 491],
        [ 502],
        [ 509],
        [ 518],
        [ 540],
        [ 570],
        [ 588],
        [ 604],
        [ 618],
        [ 638],
        [ 645],
        [ 654],
        [ 674],
        

In [None]:
n_models = 100
models_result = [model(X_test) for k in range(n_models)]

In [None]:
results = np.zeros((n_models, len(y_test)))     # num. of models, number of test datapoints
for i in range(n_models):
    for j in range(30):
        results[i][j] = models_result[i][j].argmax().item()