In [1]:
import sys
sys.path.append('../')
import torch
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from models_control import NetNormalDropout
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.optim import Adam

### Control experiment: standard dropout without masks
This notebook evaluates the performance of a standard deep net trained with ordinary dropout layers. Dropout is active during training and disabled at inference (`model.eval()`).

In [2]:
# Stochastic preprocessing pipeline (used for both the train and test splits):
# random rotation, conversion to a tensor, a strong Gaussian blur, and finally
# flattening each image into a 1-D vector for the fully-connected model.
transform = transforms.Compose(
    [
        transforms.RandomRotation(degrees=180),                 # angle drawn from [-180, 180]
        transforms.ToTensor(),                                  # PIL image -> float tensor
        transforms.GaussianBlur(kernel_size=7, sigma=(4, 5)),  # strong blur; sigma drawn from [4, 5]
        transforms.Lambda(torch.flatten),                       # collapse all dims into one vector
    ]
)

In [3]:
# MNIST train (dataset1) and test (dataset2) splits. Both use the same
# stochastic `transform`, so test images are randomly rotated/blurred too.
dataset1 = datasets.MNIST(root='../data', train=True, transform=transform, download=True)
dataset2 = datasets.MNIST(root='../data', train=False, transform=transform)


In [4]:
# Training hyper-parameters.
EPOCHS = 15       # full passes over the training split
BATCH_SIZE = 32   # samples per optimisation step
LR = 1e-3         # Adam learning rate
NUM_MASKS = 1     # NOTE(review): not referenced by the visible cells of this no-mask control run

In [5]:
# Seed torch's global RNG. With the default num_workers=0, this drives both the
# training shuffle and the torchvision random transforms, so a fresh
# Restart & Run All reproduces the same numbers.
seed = 42
torch.manual_seed(seed)
train_dataloader = DataLoader(dataset1, batch_size=BATCH_SIZE, shuffle=True)
# Batch the test split as well. Relying on the default batch_size=1 forced
# 10 000 single-image forward passes. The order is unchanged (no shuffle), so
# per-sample transforms and predictions are identical, only faster.
test_dataloader = DataLoader(dataset2, batch_size=BATCH_SIZE)

In [6]:
model = NetNormalDropout()
# NLLLoss expects log-probabilities (e.g. a log_softmax output), not plain
# probabilities. NOTE(review): confirm NetNormalDropout ends in log_softmax.
lossFn = torch.nn.NLLLoss()
opt = Adam(params=model.parameters(), lr=LR)


In [7]:
# Train for EPOCHS passes over the training split. model.train() keeps the
# dropout layers active. After each epoch, report the running train accuracy
# and the summed (not averaged) batch loss.
for epoch in range(EPOCHS):
    model.train()
    trainCorrect = 0
    totalLoss = 0.0
    for x, y in tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{EPOCHS}"):
        # Call the module rather than .forward() so registered hooks run.
        logits = model(x)
        loss = lossFn(logits, y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        totalLoss += loss.item()
        # Predictions come from the forward pass made before this step's update,
        # with dropout active, so this lags/underestimates eval-mode accuracy.
        trainCorrect += (logits.argmax(dim=1) == y).sum().item()
    print(f"Train Accuracy: {trainCorrect/len(dataset1)}")
    print(f"Total loss: {totalLoss}")

0it [00:00, ?it/s]

1875it [00:08, 229.68it/s]


Train Accuracy: 0.49278333333333335
Total loss: 2664.430910408497


1875it [00:08, 226.25it/s]


Train Accuracy: 0.6377
Total loss: 1970.7157387137413


1875it [00:15, 122.29it/s]


Train Accuracy: 0.6768166666666666
Total loss: 1740.2488400638103


1875it [00:14, 133.72it/s]


Train Accuracy: 0.6879166666666666
Total loss: 1659.981947928667


1875it [00:13, 139.60it/s]


Train Accuracy: 0.7007833333333333
Total loss: 1588.5666856765747


1875it [00:14, 131.36it/s]


Train Accuracy: 0.7049833333333333
Total loss: 1557.2589680850506


1875it [00:14, 127.11it/s]


Train Accuracy: 0.7123833333333334
Total loss: 1512.752326130867


1875it [00:16, 112.19it/s]


Train Accuracy: 0.7198
Total loss: 1482.9898345470428


1875it [00:16, 112.59it/s]


Train Accuracy: 0.7247833333333333
Total loss: 1449.2435713261366


1875it [00:15, 121.99it/s]


Train Accuracy: 0.7239666666666666
Total loss: 1447.843842536211


1875it [00:14, 126.13it/s]


Train Accuracy: 0.7269
Total loss: 1429.1121688783169


1875it [00:15, 119.13it/s]


Train Accuracy: 0.7284166666666667
Total loss: 1420.2725238204002


1875it [00:17, 104.77it/s]


Train Accuracy: 0.73275
Total loss: 1396.3104232549667


1875it [00:18, 99.22it/s] 


Train Accuracy: 0.73115
Total loss: 1391.8905310034752


1875it [00:22, 82.19it/s]

Train Accuracy: 0.73525
Total loss: 1378.7975038066506





In [9]:
# Evaluate on the test split. model.eval() disables dropout, and torch.no_grad()
# avoids building an autograd graph that inference never uses.
# NOTE(review): dataset2 uses the same random-rotation/blur transform as
# training, so re-running this cell on its own yields a different accuracy.
test_correct = 0
model.eval()
with torch.no_grad():
    for x, y in tqdm(test_dataloader, desc="Test"):
        pred = model(x).argmax(dim=1)
        test_correct += (pred == y).sum().item()
print(test_correct / len(dataset2))

10000it [00:02, 4395.55it/s]

0.8374



