In [None]:
import sys
sys.path.append('../')
import torch
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from model_dropout import NetNormalDropout
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.optim import Adam

In [2]:
# Degradation / augmentation pipeline shared by BOTH the train and test splits:
#   1. rotate each image by a random angle drawn from [-180, 180] degrees,
#   2. convert the PIL image to a float tensor,
#   3. apply a heavy Gaussian blur (sigma sampled from [5, 7] on every call),
#   4. flatten the image tensor into a 1-D tensor
#      (presumably the input layout NetNormalDropout expects — confirm in model_dropout).
# NOTE(review): because the same pipeline is reused for the test split, test accuracy is
# measured on randomly rotated + blurred digits — confirm this is the intended setup.
_pipeline_steps = [
    transforms.RandomRotation(degrees=180),
    transforms.ToTensor(),
    transforms.GaussianBlur(kernel_size=11, sigma=(5, 7)),  # deliberately strong blur
    transforms.Lambda(lambda img: torch.flatten(img)),
]
transform = transforms.Compose(_pipeline_steps)

In [3]:
# MNIST train / test splits, both routed through the shared `transform` pipeline.
# Only the training split requests a download; both splits use the same 'data' root,
# so the test files are expected to be present once the training split has been fetched.
dataset1 = datasets.MNIST(
    'data',
    train=True,
    download=True,
    transform=transform,
)
dataset2 = datasets.MNIST(
    'data',
    train=False,
    transform=transform,
)


In [4]:
# Training configuration.
LR = 0.01         # Adam learning rate. NOTE(review): 10x Adam's 1e-3 default; training plateaus ~46% — worth tuning.
EPOCHS = 6        # number of full passes over the training split
BATCH_SIZE = 32   # samples per training mini-batch
NUM_MASKS = 1     # forwarded as NetNormalDropout(num_masks=...); semantics live in model_dropout

In [5]:
# Seed torch's global RNG before any random work below: the training-set shuffle, the
# per-sample random rotation/blur in `transform`, and the model's weight init (next cell)
# all draw from it.
seed = 42
torch.manual_seed(seed)

# Training mini-batches are reshuffled every epoch. The test loader keeps DataLoader's
# defaults (batch_size=1, fixed order), so evaluation runs one sample per forward pass.
train_dataloader = DataLoader(dataset1, shuffle=True, batch_size=BATCH_SIZE)
test_dataloader = DataLoader(dataset2)

In [6]:
# Build the network under test and its optimiser.
model = NetNormalDropout(num_masks=NUM_MASKS)
opt = Adam(model.parameters(), lr=LR)
# NLLLoss expects *log*-probabilities (e.g. the output of log_softmax), not probabilities.
# NOTE(review): this assumes NetNormalDropout ends in log_softmax — confirm in model_dropout.
# If it returns plain softmax probabilities, apply torch.log first; if it returns raw
# logits, use torch.nn.CrossEntropyLoss instead.
lossFn = torch.nn.NLLLoss()


In [7]:
# Train for EPOCHS passes over the training split, reporting accuracy and loss per epoch.
# Note: accuracy is computed from each batch's predictions *before* that batch's optimiser
# step and with the model in train mode, so it is not directly comparable to test accuracy.
for epoch in range(EPOCHS):
    model.train()
    train_correct = 0
    total_loss = 0.0
    # Iterate the loader directly (not enumerate(...)) so tqdm can read len() and show total/ETA.
    for x, y in tqdm(train_dataloader, desc=f"epoch {epoch + 1}/{EPOCHS}"):
        # Call the module rather than .forward() so nn.Module hooks run.
        logits = model(x)
        loss = lossFn(logits, y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        total_loss += loss.item()
        train_correct += (logits.argmax(dim=1) == y).sum().item()
    print(f"Train Accuracy: {train_correct/len(dataset1)}")
    print(f"Total loss: {total_loss}")
    # Sum of per-batch mean losses scales with the batch count; the mean is easier to compare.
    print(f"Mean batch loss: {total_loss/len(train_dataloader)}")

0it [00:00, ?it/s]

1875it [00:09, 190.41it/s]


Train Accuracy: 0.3834166666666667
Total loss: 3098.7840610146523


1875it [00:09, 194.89it/s]


Train Accuracy: 0.4459666666666667
Total loss: 2794.993779361248


1875it [00:09, 193.20it/s]


Train Accuracy: 0.4585
Total loss: 2735.7522891163826


1875it [00:09, 192.24it/s]


Train Accuracy: 0.4579
Total loss: 2742.133153319359


1875it [00:10, 184.77it/s]


Train Accuracy: 0.45855
Total loss: 2726.2131394147873


1875it [00:10, 178.14it/s]

Train Accuracy: 0.46068333333333333
Total loss: 2724.1490750312805





In [9]:
# Evaluate accuracy on the test split.
# model.eval() switches every submodule out of training mode (whether NetNormalDropout's
# masking honours this is defined in model_dropout — confirm). torch.no_grad() stops
# autograd from recording a graph for each forward pass, which evaluation never needs.
model.eval()
test_correct = 0
with torch.no_grad():
    for x, y in tqdm(test_dataloader, desc="test"):
        logits = model(x)
        pred = logits.argmax(dim=1)
        test_correct += (pred == y).sum().item()
print(f"Test Accuracy: {test_correct / len(dataset2)}")

10000it [00:02, 4494.59it/s]

0.5685



