In [1]:
import sys
sys.path.append('../')
import torch
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from model import NetV2, MCMC
from torch.utils.data import DataLoader, SequentialSampler
from tqdm import tqdm
from torch.optim import Adam

### Fixed-Mask with bayesian inference training
This notebook evaluates the performance of the performance of using fixed dropout masks and evaluating a weighted average of them at inference time

#### We apply random and rotations and blur to make task harder

In [2]:
transform = transforms.Compose([
    transforms.RandomRotation(degrees=180),
    transforms.ToTensor(), transforms.GaussianBlur(kernel_size=7, sigma=(4, 5)),  # Stronger blur
    transforms.Lambda(lambda x: torch.flatten(x)),
])

In [3]:
dataset1 = datasets.MNIST('../data', train=True, download=True,
                       transform=transform)
dataset2 = datasets.MNIST('../data', train=False,
                       transform=transform)


In [4]:
BATCH_SIZE = 32
EPOCHS = 15
NUM_MASKS = 5
LR = 0.001
dropout_probs=[0.4, 0.6]

In [5]:
seed = 42
torch.manual_seed(seed)
indices = torch.randperm(len(dataset1)).tolist()  # Shuffled once

# Use SubsetRandomSampler to keep the same shuffle order on each epoch for the over-fitting step
sampler = SequentialSampler(indices)  # Keeps the order fixed
train_dataloader = DataLoader(dataset1, batch_size=BATCH_SIZE, sampler=sampler)
test_dataloader = DataLoader(dataset2, batch_size=1)

### Overfitting portion

In [6]:
model = NetV2(num_masks=NUM_MASKS, dropout_probs=dropout_probs)
opt = Adam(model.parameters(), lr=LR)
lossFn = torch.nn.NLLLoss() # Use NLL since we our model is outputting a probability


In [7]:
NGROUPS = 3 # dividing groups to use
mask_groups = [list(range(i, NUM_MASKS, NGROUPS)) for i in range(NGROUPS)]  # Partition masks

for epoch in range(EPOCHS):
    model.train()
    trainCorrect = 0
    totalLoss = 0
    tot = 0
    for idx, (x, y) in tqdm(enumerate(train_dataloader)):
        group_id = idx % NGROUPS  # Assign batch to a group
        masks = mask_groups[group_id]  # Get all masks in this group
        
        for mask in masks:
            logits = model.forward(x, mask=mask)
            loss = lossFn(logits, y)
            totalLoss += loss.item()
            opt.zero_grad()
            loss.backward()
            opt.step()
            trainCorrect += (logits.argmax(1) == y).type(torch.float).sum().item()
            tot += len(y)

    print(f"Train Accuracy: {trainCorrect} / {tot}: {trainCorrect / tot}")
    print(f"Total loss: {totalLoss}")


1875it [00:26, 70.91it/s]


Train Accuracy: 58851.0 / 100000: 0.58851
Total loss: 3583.5722175836563


1875it [00:25, 74.34it/s]


Train Accuracy: 78537.0 / 100000: 0.78537
Total loss: 2034.3618046939373


1875it [00:26, 71.02it/s]


Train Accuracy: 82927.0 / 100000: 0.82927
Total loss: 1645.286963492632


1875it [00:31, 59.53it/s]


Train Accuracy: 84666.0 / 100000: 0.84666
Total loss: 1477.6174337789416


1875it [00:31, 59.02it/s]


Train Accuracy: 85828.0 / 100000: 0.85828
Total loss: 1371.9244211241603


1875it [00:35, 53.25it/s]


Train Accuracy: 86612.0 / 100000: 0.86612
Total loss: 1294.7676348499954


1875it [00:32, 57.19it/s]


Train Accuracy: 87264.0 / 100000: 0.87264
Total loss: 1243.6770866177976


1875it [00:32, 58.30it/s]


Train Accuracy: 87523.0 / 100000: 0.87523
Total loss: 1192.9507208913565


1875it [00:23, 78.40it/s]


Train Accuracy: 88231.0 / 100000: 0.88231
Total loss: 1154.3681839052588


1875it [00:23, 80.76it/s]


Train Accuracy: 88611.0 / 100000: 0.88611
Total loss: 1115.5614146962762


1875it [00:25, 72.95it/s]


Train Accuracy: 88732.0 / 100000: 0.88732
Total loss: 1109.8045058771968


1875it [00:24, 76.44it/s]


Train Accuracy: 88806.0 / 100000: 0.88806
Total loss: 1090.1483276933432


1875it [00:23, 81.03it/s]


Train Accuracy: 89238.0 / 100000: 0.89238
Total loss: 1061.869417335838


1875it [00:22, 82.37it/s]


Train Accuracy: 89488.0 / 100000: 0.89488
Total loss: 1040.6780602792278


1875it [00:24, 77.00it/s]

Train Accuracy: 89519.0 / 100000: 0.89519
Total loss: 1033.5377450855449





In [8]:
torch.save(model.state_dict(), "../model.pth")

In [9]:
# model = NetV2(num_masks=NUM_MASKS, dropout_probs=dropout_probs)
# model.load_state_dict(torch.load("../model.pth", weights_only=True))
# model.eval()

In [10]:
mcmc = MCMC(model=model, increment_amt=1)

In [11]:
seed = 42
torch.manual_seed(seed)
train_dataloader = DataLoader(dataset1, batch_size=BATCH_SIZE, shuffle=True)

In [12]:
model.eval()
for i in range(3):
    trainCorrect = 0
    for idx, (x, y)  in tqdm(enumerate(train_dataloader)):
        mcmc.transition(x=x, y=y)

1875it [00:17, 104.36it/s]
1875it [00:15, 120.37it/s]
1875it [00:14, 128.71it/s]


This gives a bit more of a non-uniform distribution

In [13]:
dist = torch.tensor([(val / mcmc.tot).item() for val in mcmc.ocurrences])
print(dist)

tensor([0.2153, 0.1761, 0.1937, 0.1733, 0.2416])


In [None]:
values, indices = torch.topk(dist, k=2)
print(indices)
print(indices.shape)

tensor([4, 0])
torch.Size([2])


In [21]:
test_correct_top_2 = 0
test_correct = 0
model.eval()
for idx, (x, y)  in tqdm(enumerate(test_dataloader)):
    # logits2 = model.forward(x, mask=1)
    logits = mcmc.predict(x, chosen_masks=None)
    pred = torch.argmax(logits, dim=1)
    test_correct += (pred == y).sum().item()


    logits = mcmc.predict(x, chosen_masks=indices)
    pred = torch.argmax(logits, dim=1)
    test_correct_top_2 += (pred == y).sum().item()
print(f"Test accuracy with all masks: {test_correct / len(dataset2)}")
print(f"Test accuracy with top 3: {test_correct_top_2 / len(dataset2)}")

  keep_indices = torch.tensor(keep_indices, dtype=torch.long)
10000it [00:12, 793.07it/s]

Test accuracy with all masks: 0.9099
Test accuracy with top 3: 0.9023





####  First Layer (428 / 512) neurons taken
#### Second layer (166 / 256) neurons taken

In [22]:
indices_nonmasked = []
for index in indices:
    indices_nonmasked.append((torch.nonzero(model.dropout1.mask_dict[index.item()] == 0)))

(torch.unique(torch.cat(tuple(indices_nonmasked)))).shape

torch.Size([428])

In [23]:
indices_nonmasked = []
for index in indices:
    indices_nonmasked.append((torch.nonzero(model.dropout2.mask_dict[index.item()] == 0)))

(torch.unique(torch.cat(tuple(indices_nonmasked)))).shape

torch.Size([166])