In [1]:
import sys
sys.path.append('../')
import torch
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from models_control import NetNormalDropoutV2
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.optim import Adam

### Control experiment: Not using masks and no dropout
This notebook evaluates the performance oof using a standard deep-net with 'normal' dropout layers in training which are removed at inferene

In [2]:
transform = transforms.Compose([
    transforms.RandomRotation(degrees=180),
    transforms.ToTensor(), transforms.GaussianBlur(kernel_size=7, sigma=(4, 5)),  # Stronger blur
    transforms.Lambda(lambda x: torch.flatten(x)),
])

In [3]:
dataset1 = datasets.MNIST('../data', train=True, download=True,
                       transform=transform)
dataset2 = datasets.MNIST('../data', train=False,
                       transform=transform)


In [4]:
BATCH_SIZE = 32
EPOCHS = 12
NUM_MASKS = 1
LR = 0.001

In [5]:
seed = 42
torch.manual_seed(seed)
train_dataloader = DataLoader(dataset1, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(dataset2)

In [6]:
model = NetNormalDropoutV2()
opt = Adam(model.parameters(), lr=LR)
lossFn = torch.nn.NLLLoss() # Use NLL since we our model is outputting a probability


In [7]:
for i in range(EPOCHS):
    model.train()
    trainCorrect = 0
    totalLoss = 0
    for idx, (x, y)  in tqdm(enumerate(train_dataloader)):
        logits = model.forward(x)
        loss = lossFn(logits, y)
        totalLoss += loss.item()
        opt.zero_grad()
        loss.backward()
        opt.step()
        trainCorrect += (logits.argmax(1) == y).type(
			torch.float).sum().item()
    print(f"Train Accuracy: {trainCorrect/len(dataset1)}")
    print(f"Total loss: {totalLoss}")

1875it [00:28, 64.98it/s]


Train Accuracy: 0.5317166666666666
Total loss: 2406.4092569947243


1875it [00:33, 55.26it/s]


Train Accuracy: 0.6812666666666667
Total loss: 1665.401730120182


1875it [00:26, 71.02it/s]


Train Accuracy: 0.7094
Total loss: 1520.6825696527958


1875it [00:24, 77.30it/s]


Train Accuracy: 0.7282166666666666
Total loss: 1416.0007178410888


1875it [00:26, 70.70it/s]


Train Accuracy: 0.7390166666666667
Total loss: 1351.6500516086817


1875it [00:31, 60.25it/s]


Train Accuracy: 0.7423666666666666
Total loss: 1334.666971027851


1875it [00:26, 71.31it/s]


Train Accuracy: 0.74885
Total loss: 1294.93528534472


1875it [00:27, 67.96it/s]


Train Accuracy: 0.7549333333333333
Total loss: 1259.646432340145


1875it [00:29, 64.37it/s]


Train Accuracy: 0.7566333333333334
Total loss: 1257.3940309658647


1875it [00:27, 68.07it/s]


Train Accuracy: 0.7576333333333334
Total loss: 1244.1516629680991


1875it [00:24, 75.95it/s]


Train Accuracy: 0.7613
Total loss: 1225.1169779524207


1875it [00:22, 83.03it/s]

Train Accuracy: 0.7604666666666666
Total loss: 1226.0475217327476





In [8]:
test_correct = 0
model.eval()
for idx, (x, y)  in tqdm(enumerate(test_dataloader)):
    logits = model.forward(x)
    pred = torch.argmax(logits, dim=1)
    test_correct += (pred == y).sum().item()
print(test_correct / len(dataset2))

10000it [00:05, 1916.54it/s]

0.9156



