###### Assignemnt: Bayesian Convolutional Network with pyro

**Objective:** Implement, train and evaluate a Bayesian Lenet5 model on the MNIST dataset. 

**Requirements:**

- Use `torchvision.datasets.MNIST` to obtain the training and test data. Only use digits 4 and 9 (discard the others)
- Implement a bayesian neural network using `pyro` based on the Lenet5 convolutional architecture
- Use a bernoulli likelihood, a diagonal normal for the approximate posterior and a diagonal normal prior. Use the Mean Field Trace ELBO
- Evaluate the performance of the BNN using precision/recall curves and uncertainty calibration plots [2, 3, 4]
- Study the influence of the scale of the prior and the initial scale of the approximate posterior
- Compare your best bayesian model with a Deterministic Lenet5 
- Discuss your results! 


**References**
1. https://www.kaggle.com/blurredmachine/lenet-architecture-a-complete-guide
2. https://arxiv.org/pdf/1703.04977.pdf (Section 5.1)
3. https://arxiv.org/pdf/2007.06823.pdf (Section 9)
4. https://arxiv.org/pdf/1706.04599.pdf 

**Deadline**

17:30, June 16th, 2021


In [581]:
#try to do a simple netowrk to visualize epistemic uncertanty
%load_ext autoreload
%autoreload 2
%matplotlib notebook
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from tqdm.notebook import tqdm
import torch
from torch.utils.data import Subset
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import pyro
display(pyro.__version__)
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'

np.random.seed(66)
torch.manual_seed(66)
pyro.set_rng_seed(66) # Fo

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


'1.6.0'

In [390]:
def getIdxForLabels(dataset, labels):
    mask = dataset.targets == labels[0]
    for label in labels[1:]:
        mask |= dataset.targets == label
    return np.where(mask)

def renameLabelsInOrder(dataset, labels):
    i=0
    for label in labels:
        dataset.targets[dataset.targets == label] = i
        i+=1

mnist_train = torchvision.datasets.MNIST('./Datasets', train=True, download=True,
                                        transform=torchvision.transforms.ToTensor())
mnist_test = torchvision.datasets.MNIST('./Datasets', train=False, download=True,
                                        transform=torchvision.transforms.ToTensor())
keepClasses = [4, 9]

train_idx = getIdxForLabels(mnist_train, keepClasses)
test_idx = getIdxForLabels(mnist_test, keepClasses)

train_dataset = Subset(mnist_train, train_idx[0])
test_dataset = Subset(mnist_test, test_idx[0])

renameLabelsInOrder(mnist_train, keepClasses)
renameLabelsInOrder(mnist_test, keepClasses)


print(len(train_dataset), " elements for training.")
print(len(test_dataset), " elements for testing.")

11791  elements for training.
1991  elements for testing.


In [252]:
fig, ax = plt.subplots(1, 10, figsize=(6, 1), tight_layout=True)
for i in range(10):
    image, label = train_dataset[np.random.randint(len(train_dataset))]
    ax[i].imshow(image.numpy()[0, :, :], cmap=plt.cm.Greys_r)
    ax[i].axis('off')
    ax[i].set_title(label)

<IPython.core.display.Javascript object>

In [354]:
from torch.utils.data import DataLoader, SubsetRandomSampler

validationSize = 0.85
trainBatchSize = 128
validBatchSize = 128
testBatchSize = 256



idx = list(range(len(train_dataset)))
np.random.shuffle(idx)
split = int(validationSize*len(idx))

train_loader = DataLoader(train_dataset, batch_size=trainBatchSize, drop_last=False,
                          sampler=SubsetRandomSampler(idx[:split]))

valid_loader = DataLoader(train_dataset, batch_size=validBatchSize, drop_last=False,
                          sampler=SubsetRandomSampler(idx[split:]))

test_loader = DataLoader(test_dataset, batch_size=testBatchSize, drop_last=False,
                          shuffle=True)

In [254]:
fig, ax = plt.subplots(figsize=(5, 3), tight_layout=True)
ax.hist(train_dataset.dataset.targets[train_dataset.indices].numpy(), bins=2)
ax.set_title("Histograma de Clases");

<IPython.core.display.Javascript object>

In [255]:
from NeuralNetworks import Lenet5Deterministic
deterministicModel = Lenet5Deterministic()

In [391]:
deterministicLearningRate = 3e-4

loss = nn.BCELoss()
optimizer = torch.optim.Adam(deterministicModel.parameters(), lr=deterministicLearningRate)

def train_loop(dataloader, model, loss_fn, optimizer, n_samples):
    size = n_samples
    epochLoss = 0
    nSamples = 0
    for batch, (X, y) in enumerate((dataloader)):
        # Compute prediction and loss
        y = y.type(torch.float)
        pred = model(X)
        loss = loss_fn(pred, y)
        
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        epochLoss += loss.item()
        nSamples += len(X)

    return epochLoss/nSamples
            

def valid_loop(dataloader, model, loss_fn, n_samples):
    size = n_samples
    test_loss, correct = 0, 0
    epochLoss = 0
    nSamples = 0
    correct = 0
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            y = y.type(torch.float)
            test_loss += loss_fn(pred, y).item()
            correct += (torch.round(pred) == y).type(torch.float).sum().item()
            epochLoss += test_loss
            nSamples += len(X)

    test_loss /= size
    correct /= size
    print(f"         Validation Error: Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f}", end='\r')
    return epochLoss/nSamples
    
def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    test_loss, correct = 0, 0
    print(size)
    with torch.no_grad():
        for X, y in dataloader:
            y = y.type(torch.float)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (torch.round(pred) == y).type(torch.float).sum().item()

    test_loss /= size
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [392]:
fig, ax = plt.subplots(figsize=(10, 4), tight_layout=True, dpi=80)

def update_plot(k, trainingLosses, validLosses):
    ax.cla()
    ax.plot(range(k), trainingLosses[:k], label="training loss")
    ax.plot(range(k), validLosses[:k], label="validation loss")

    #ax.set_yscale('log')
    ax.set_ylabel('Loss')
    ax.set_ylabel('Epoch')
    ax.legend()
    fig.canvas.draw()

<IPython.core.display.Javascript object>

In [393]:
from IPython.display import display, clear_output

epochs = 30

trainingLosses = np.zeros(epochs)
validLosses = np.zeros(epochs)

for t in range(epochs):
    #clear_output(wait=True)
    print(f"Epoch {t+1} ", end='\r')
    trainingLosses[t] = train_loop(train_loader, deterministicModel, loss, optimizer, split)
    validLosses[t] = valid_loop(valid_loader, deterministicModel, loss, len(train_loader.dataset)-split)
    update_plot(t, trainingLosses, validLosses)
    

test_loop(test_loader, deterministicModel, loss)


print("Done!")

Epoch 8  0.9898247597512719Accuracy: 99.0%, Avg loss: 0.000209

KeyboardInterrupt: 

In [420]:
def getCollapsedDecision(samples, delta):
    #return (samples.mean(dim=0))
    return (samples.mean(dim=0).numpy() > delta).astype(int)

In [421]:
from NeuralNetworks import Lenet5Bayesian
bayesianModel = Lenet5Bayesian(prior_scale=1., isFEBayesian=True, prior_scaleFE=3.)

In [422]:
pyro.enable_validation(True)
fig, ax = plt.subplots(figsize=(10, 4), tight_layout=True, dpi=80)

def update_plot(k, trainingLosses, validLosses):
    ax.cla()
    ax.plot(range(k), trainingLosses[:k], label="training loss")
    ax.plot(range(k), validLosses[:k], label="validation loss")

    ax.set_yscale('log')
    ax.set_ylabel('ELBO')
    ax.set_xlabel('Epoch')
    ax.legend()
    fig.canvas.draw()

<IPython.core.display.Javascript object>

In [423]:
# Turn this on for additional debugging
pyro.enable_validation(False) 
pyro.clear_param_store() 
epochs = 30

# Create a guide
from pyro.infer.autoguide import AutoDiagonalNormal
guide = AutoDiagonalNormal(bayesianModel, init_scale=1.01e-3)
predictive = pyro.infer.Predictive(bayesianModel, guide=guide, num_samples=100)
# Create SVI object
svi = pyro.infer.SVI(bayesianModel, guide, 
                     optim=pyro.optim.ClippedAdam({'lr':1e-3, 'clip_norm':1.0}), # Optimizer
                     loss=pyro.infer.TraceMeanField_ELBO()) # Loss function 

epoch_loss = np.zeros(shape=(epochs,))
validLoss = np.zeros(shape=(epochs,))


for k in tqdm(range(len(epoch_loss))):
    acumloss = 0
    n = 0
    for X, y in train_loader:
        acumloss += svi.step(x=X, y=y.type(torch.float)) # Actual training step
        n+=(X.shape[0])
    epoch_loss[k] = acumloss/n

    #update_plot(k, epoch_loss, validLoss)   
    res = 0
    total = 0
    acumloss = 0
    for X, y in valid_loader:
        acumloss += svi.evaluate_loss(X, y.type(torch.float))
        ans = getCollapsedDecision(predictive(X)['obs'].detach(), 0.5)
        correct = accuracy_score(y, ans, normalize=False)
        res += correct
        total += X.shape[0]
    validLoss[k] = acumloss/total
    #print(epoch_loss, validLoss)
    update_plot(k, epoch_loss, validLoss)
    print("Validation Score:", res/total*100, end='\r')
 
            

  0%|          | 0/30 [00:00<?, ?it/s]

Validation Score: 99.09553420011305

Cuando el modelo se entrena demasiado (ya sea durante muchas epocas o con un learning rate alto) el score baja a 0.5 por algun motivo. Que sucede cuando las redes bayesianas se sobreajustan?

In [424]:
predictive = pyro.infer.Predictive(bayesianModel, guide=guide, num_samples=100)

aciertos = 0
total = 0
errors= []
samples = []
with torch.no_grad():
    for X, y in tqdm(test_loader):
        out = predictive(X)['obs'].detach()
        ans = getCollapsedDecision(out, 0.5)
        
        hitormiss = (ans == y)
        
        fails = np.where(hitormiss == False)        
        correct = np.where(hitormiss == True)
        
        aciertos += accuracy_score(y, ans, normalize=False)
        total += X.shape[0]


        for error in X[fails[0]]:
            errors.append(error)
        for error in out.numpy().T[fails[0]]:
            samples.append(error)
        
        print("Test score:", aciertos/total*100, end='\r')


  0%|          | 0/8 [00:00<?, ?it/s]

Test score: 99.09593169261677

In [396]:
nerrors = len(errors)
print(nerrors)
rows = int(np.floor(np.sqrt(nerrors)))
cols = nerrors//rows
fig, ax = plt.subplots(rows, cols, figsize=(10, 5), tight_layout=True)
for i in range(rows):
    for j in range(cols):
        image = errors[i*rows+j].squeeze(0).squeeze(0).unsqueeze(2)
        conf = np.average(samples[i*rows+j])
        ax[i,j].imshow(image.numpy(), cmap=plt.cm.Greys_r)
        ax[i,j].axis('off')        
        ax[i,j].set_title(conf)

20


<IPython.core.display.Javascript object>

In [575]:
predictive = pyro.infer.Predictive(bayesianModel, guide=guide, num_samples=100)
bins = 5
#delta = 0.5
intervals = np.linspace(0.0, 1.0, bins)
frequency = np.zeros((bins))
total = np.zeros((bins))
with torch.no_grad():
    for X, y in tqdm(test_loader):
        samplesAvg = predictive(X)['obs'].detach().type(torch.float).mean(dim=0)
        for i, conf in enumerate(samplesAvg):
            ans = (torch.round(conf)).item()
            if ans==0:
                continue
            else:
                conf = 2*conf - 1.0
            for j, interval in enumerate(intervals):
                if conf < interval:
                    
                    if ans == y[i].item():
                        frequency[j]+=1
                    total[j]+=1
                    break
                    

  0%|          | 0/8 [00:00<?, ?it/s]

In [570]:
from sklearn.calibration import calibration_curve
TOTALTESTPRED = torch.tensor([])
TOTALTESTY = torch.tensor([])
with torch.no_grad():
    for X, y in tqdm(test_loader):
        samplesAvg = predictive(X)['obs'].detach().type(torch.float).mean(dim=0)
        ans = (torch.round(samplesAvg))
        
        TOTALTESTPRED = torch.cat((TOTALTESTPRED,samplesAvg))
        TOTALTESTY = torch.cat((TOTALTESTY,y))




prob_true, prob_pred = calibration_curve(TOTALTESTY, TOTALTESTPRED, n_bins=5)
fig, ax = plt.subplots()
ax.plot(prob_pred, prob_true, label="Clase", marker='.');
ax.plot(np.linspace(0, 1, 2),np.linspace(0, 1, 2), label='y = x')


  0%|          | 0/8 [00:00<?, ?it/s]

<IPython.core.display.Javascript object>

[<matplotlib.lines.Line2D at 0x7f832abeceb0>]

In [569]:
print(prob_true.shape)

(5,)


In [576]:
print(avg_probability)
frequencyByClass

[0.00000000e+00 1.60000086e-01 1.16000009e+00 5.00000048e+00
 1.33000002e+01 2.84820282e+02]


array([  0.,   1.,   4.,   5.,   8.,  26.,  28., 461.])

In [577]:
freq = frequency/total
freq[np.isnan(freq)] = 0
freq


  freq = frequency/total


array([0.        , 0.5       , 1.        , 0.88235294, 0.99365079])

In [580]:
fig, ax = plt.subplots( figsize=(6,4), tight_layout=True)
ax.plot(intervals, freq, label="Clase", marker='.');
ax.plot(np.linspace(0, 1, 2),np.linspace(0, 1, 2), label='y = x')
ax.set_ylabel("Frequency")
ax.set_xlabel("Probability")
ax.legend()

<IPython.core.display.Javascript object>

<matplotlib.legend.Legend at 0x7f83596b2cd0>

In [579]:
from sklearn.metrics import confusion_matrix
predictive = pyro.infer.Predictive(bayesianModel, guide=guide, num_samples=100)


corr= 0
tot=0
with torch.no_grad():
    for X, y in tqdm(test_loader):
        out = predictive(X)['obs'].detach()
        ans = getCollapsedDecision(out, 0.5)
        corr += accuracy_score(y, ans, normalize=False)
        tot += X.shape[0]
print(corr/tot)

  0%|          | 0/8 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [430]:
from sklearn.metrics import classification_report
print(classification_report(y,ans))

              precision    recall  f1-score   support

           0       0.98      0.99      0.98        88
           1       0.99      0.98      0.99       111

    accuracy                           0.98       199
   macro avg       0.98      0.99      0.98       199
weighted avg       0.98      0.98      0.98       199



In [358]:
y

tensor([0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1,
        1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1,
        0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1,
        0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1,
        1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0,
        0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1,
        0, 1, 1, 0, 0, 0, 1])

In [359]:
y_hat

tensor([0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1,
        1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0,
        1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1,
        0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0,
        0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1,
        0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1,
        1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0,
        0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1,
        0, 0, 0, 0, 0, 0, 1])