In [None]:
import sys
sys.path.append('../')
import torch
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from models_control import NetMCDropout
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.optim import Adam

In [2]:
# Per-sample preprocessing pipeline: random rotation (degrees=180), conversion
# to a tensor, a strong Gaussian blur, and finally flattening to a 1-D vector.
# The rotation angle and blur sigma are re-sampled on every call.
_preprocessing_steps = [
    transforms.RandomRotation(degrees=180),
    transforms.ToTensor(),
    transforms.GaussianBlur(kernel_size=11, sigma=(5, 7)),  # Stronger blur
    transforms.Lambda(lambda img: img.flatten()),
]
transform = transforms.Compose(_preprocessing_steps)

In [3]:
# MNIST training and test splits, both stored under ./data and both passed
# through the same `transform`. Only the training split requests a download.
dataset1 = datasets.MNIST(root='data', train=True, download=True, transform=transform)
dataset2 = datasets.MNIST(root='data', train=False, transform=transform)


In [4]:
# Training hyperparameters (tune here).
BATCH_SIZE = 32  # samples per training mini-batch
EPOCHS = 10      # number of full passes over the training set
NUM_MASKS = 1    # NOTE(review): not used in the cells shown here -- confirm where it is consumed
LR = 1e-2        # learning rate handed to the Adam optimizer

In [10]:
# Seed torch's global RNG before the loaders are built.
seed = 42
torch.manual_seed(seed)

# Training batches are shuffled; the test set is served one sample at a time,
# in its stored order.
train_dataloader = DataLoader(dataset=dataset1, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(dataset=dataset2, batch_size=1, shuffle=False)

### Overfitting portion (set the dropout probabilities to 0 to indicate a fully connected network)

In [6]:
# Network under test; num_samples is presumably the number of stochastic
# forward passes used by mc_dropout_predict -- TODO confirm in models_control.
model = NetMCDropout(num_samples=20)

# NLLLoss takes log-probabilities as input.
# NOTE(review): confirm that mc_dropout_predict returns log-probabilities
# rather than raw probabilities.
lossFn = torch.nn.NLLLoss()

opt = Adam(params=model.parameters(), lr=LR)


In [7]:
# Train for EPOCHS passes over the training set. After each epoch, report the
# training accuracy (correct predictions / dataset size) and the sum of the
# per-batch losses.
for epoch in range(EPOCHS):
    model.train()
    trainCorrect = 0
    totalLoss = 0
    # Wrap the DataLoader itself rather than enumerate(...): tqdm can then read
    # the number of batches and show a percentage and ETA. The batch index was
    # never used, so it is dropped.
    for x, y in tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{EPOCHS}"):
        # NOTE(review): named `logits`, but NLLLoss expects log-probabilities --
        # confirm what mc_dropout_predict actually returns.
        logits = model.mc_dropout_predict(x)
        loss = lossFn(logits, y)
        totalLoss += loss.item()

        opt.zero_grad()
        loss.backward()
        opt.step()

        # Count correct predictions using the outputs computed before this
        # optimizer step.
        trainCorrect += (logits.argmax(1) == y).type(torch.float).sum().item()
    print(f"Train Accuracy: {trainCorrect/len(dataset1)}")
    print(f"Total loss: {totalLoss}")

1875it [00:16, 114.38it/s]


Train Accuracy: 0.4948166666666667
Total loss: 2964.8901982307434


1875it [00:18, 101.49it/s]


Train Accuracy: 0.5903333333333334
Total loss: 2621.1005333065987


1875it [00:22, 83.01it/s] 


Train Accuracy: 0.6169166666666667
Total loss: 2518.886078119278


1875it [00:18, 103.08it/s]


Train Accuracy: 0.6245166666666667
Total loss: 2488.490544319153


1875it [00:19, 98.03it/s] 


Train Accuracy: 0.6334666666666666
Total loss: 2463.157639205456


1875it [00:19, 94.84it/s] 


Train Accuracy: 0.63295
Total loss: 2464.2192950844765


1875it [00:18, 101.40it/s]


Train Accuracy: 0.6404833333333333
Total loss: 2437.470103919506


1875it [00:18, 100.88it/s]


Train Accuracy: 0.63465
Total loss: 2445.701071500778


1875it [00:19, 94.49it/s] 


Train Accuracy: 0.63995
Total loss: 2443.899631202221


1875it [00:18, 101.38it/s]

Train Accuracy: 0.6401666666666667
Total loss: 2441.0796144604683





In [11]:
# Evaluate on the test set and print the fraction of correct predictions.
test_correct = 0
model.eval()
# Gradients are not needed for evaluation; disabling autograd avoids building
# a graph for every forward pass.
with torch.no_grad():
    # Iterate the DataLoader directly so tqdm knows the total; the batch index
    # was unused.
    for x, y in tqdm(test_dataloader, desc="Evaluating"):
        logits = model.mc_dropout_predict(x)
        pred = torch.argmax(logits, dim=1)
        test_correct += (pred == y).sum().item()
print(test_correct / len(dataset2))

10000it [00:12, 803.30it/s]

0.6656



