In [1]:
import sys
sys.path.append('../')
import torch
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from models_control import NetMCDropoutV2
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.optim import Adam

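### Control experiment: MC dropout without fixed masks
This notebook evaluates the performance of Monte Carlo (MC) dropout, i.e. dropout whose mask is sampled at random on every forward pass rather than taken from a fixed set of dropout masks.
Reference: Gal & Ghahramani, "Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning", https://arxiv.org/pdf/1506.02142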

In [2]:
# Preprocessing pipeline applied to each MNIST image before it reaches the model.
transform = transforms.Compose(
    [
        # Rotate the image by a random angle (degrees=180 means the range -180..+180).
        transforms.RandomRotation(degrees=180),
        # Convert the image to a tensor.
        transforms.ToTensor(),
        # Stronger blur
        transforms.GaussianBlur(kernel_size=7, sigma=(4, 5)),
        # Flatten the image tensor into a single 1-D vector.
        transforms.Lambda(lambda img: torch.flatten(img)),
    ]
)

In [3]:
# MNIST training split; download=True fetches the files into ../data if they are missing.
dataset1 = datasets.MNIST(
    root='../data',
    train=True,
    transform=transform,
    download=True,
)
# MNIST test split, read from the same root directory and given the same transform.
dataset2 = datasets.MNIST(
    root='../data',
    train=False,
    transform=transform,
)


In [4]:
# Hyperparameters used by the data loader, optimizer and training loop below.
BATCH_SIZE = 32  # samples per training mini-batch
EPOCHS = 10      # number of passes over the training set
LR = 1e-3        # learning rate passed to Adam

In [5]:
# Seed torch's global RNG before the loaders (and, in the next cell, the model) are created.
seed = 42
torch.manual_seed(seed)

# Training batches of BATCH_SIZE samples, reshuffled on every pass over the data.
train_dataloader = DataLoader(
    dataset=dataset1,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
# Test samples are served one at a time, in dataset order.
test_dataloader = DataLoader(
    dataset=dataset2,
    batch_size=1,
    shuffle=False,
)

In [6]:
# MC-dropout network from models_control.
# NOTE(review): num_samples=20 is presumably the number of stochastic forward
# passes combined per prediction — confirm against NetMCDropoutV2.
model = NetMCDropoutV2(num_samples=20)

# Adam over all model parameters with the learning rate from the config cell.
opt = Adam(params=model.parameters(), lr=LR)

# torch.nn.NLLLoss expects log-probabilities as its input.
# NOTE(review): this assumes mc_dropout_predict returns log-probabilities rather
# than plain probabilities — confirm in models_control.
lossFn = torch.nn.NLLLoss()


In [7]:
# Train the model for EPOCHS passes over the training set, printing the
# training accuracy and the summed batch loss after every epoch.
for epoch in range(EPOCHS):
    model.train()
    trainCorrect = 0
    totalLoss = 0
    # Wrap the loader itself (not enumerate(loader)) so tqdm can read its
    # length and render a real progress bar with percentage and ETA.
    for x, y in tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{EPOCHS}"):
        # Clear gradients left over from the previous step before the new forward pass.
        opt.zero_grad()
        # NOTE(review): the output is fed straight into NLLLoss, so it is assumed
        # to hold log-probabilities despite the variable name — confirm.
        logits = model.mc_dropout_predict(x)
        loss = lossFn(logits, y)
        loss.backward()
        opt.step()
        totalLoss += loss.item()
        # Accuracy bookkeeping does not need gradients, so count on a detached view.
        trainCorrect += (logits.detach().argmax(1) == y).sum().item()
    print(f"Train Accuracy: {trainCorrect/len(dataset1)}")
    print(f"Total loss: {totalLoss}")

1875it [01:25, 22.05it/s]


Train Accuracy: 0.6508333333333334
Total loss: 2742.368918955326


1875it [01:08, 27.40it/s]


Train Accuracy: 0.8289166666666666
Total loss: 2174.770607292652


1875it [01:08, 27.42it/s]


Train Accuracy: 0.8615833333333334
Total loss: 2062.0398649573326


1875it [01:06, 28.25it/s]


Train Accuracy: 0.8808166666666667
Total loss: 1997.402498781681


1875it [00:53, 34.75it/s]


Train Accuracy: 0.8912166666666667
Total loss: 1957.9469105005264


1875it [00:43, 42.80it/s]


Train Accuracy: 0.8986333333333333
Total loss: 1932.7272303700447


1875it [00:44, 42.47it/s]


Train Accuracy: 0.9042166666666667
Total loss: 1911.2175143957138


1875it [00:56, 33.44it/s]


Train Accuracy: 0.9103833333333333
Total loss: 1884.7668615579605


1875it [01:09, 26.95it/s]


Train Accuracy: 0.9120166666666667
Total loss: 1876.6839272975922


1875it [01:26, 21.64it/s]

Train Accuracy: 0.9158166666666666
Total loss: 1865.5554565787315





In [8]:
# Evaluate on the test set and print the fraction of correctly classified samples.
test_correct = 0
model.eval()
# NOTE(review): eval() turns standard dropout layers off; MC dropout only gives
# distinct samples here if mc_dropout_predict keeps dropout active on its own —
# confirm in models_control.
# Gradients are never used during evaluation, so skip building the autograd graph.
with torch.no_grad():
    # Wrap the loader itself so tqdm knows the total and shows real progress.
    for x, y in tqdm(test_dataloader, desc="Test"):
        logits = model.mc_dropout_predict(x)
        pred = torch.argmax(logits, dim=1)
        test_correct += (pred == y).sum().item()
print(test_correct / len(dataset2))

10000it [00:34, 292.84it/s]

0.924



