In [1]:
import sys
sys.path.append('../')
import torch
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from model import NetV2, MCMC
from torch.utils.data import DataLoader, SequentialSampler
from tqdm import tqdm
from torch.optim import Adam

### Fixed-Masks with bayesian inference for training
This notebook evaluates the performance of the performance of using fixed dropout masks and evaluating a weighted average of them at inference time.

This notebook shows training and pruning for one of the better set of hyperparameters that also had higher dropout probabilities on each layer

#### We apply random and rotations and blur to make task harder

In [2]:
transform = transforms.Compose([
    transforms.RandomRotation(degrees=180),
    transforms.ToTensor(), transforms.GaussianBlur(kernel_size=7, sigma=(4, 5)),  # Stronger blur
    transforms.Lambda(lambda x: torch.flatten(x)),
])

In [3]:
dataset1 = datasets.MNIST('../data', train=True, download=True,
                       transform=transform)
dataset2 = datasets.MNIST('../data', train=False,
                       transform=transform)


In [4]:
BATCH_SIZE = 32
EPOCHS = 15
NUM_MASKS = 3
LR = 0.001
dropout_probs=[0.7, 0.4]

In [5]:
seed = 42
torch.manual_seed(seed)
indices = torch.randperm(len(dataset1)).tolist()  # Shuffled once

# Use SubsetRandomSampler to keep the same shuffle order on each epoch for the over-fitting step
sampler = SequentialSampler(indices)  # Keeps the order fixed
train_dataloader = DataLoader(dataset1, batch_size=BATCH_SIZE, sampler=sampler)
test_dataloader = DataLoader(dataset2, batch_size=1)

### Overfitting portion

In [6]:
model = NetV2(num_masks=NUM_MASKS, dropout_probs=dropout_probs)
opt = Adam(model.parameters(), lr=LR)
lossFn = torch.nn.NLLLoss() # Use NLL since we our model is outputting a probability


In [7]:
NGROUPS = 1 # dividing groups to use
mask_groups = [list(range(i, NUM_MASKS, NGROUPS)) for i in range(NGROUPS)]  # Partition masks

for epoch in range(EPOCHS):
    model.train()
    trainCorrect = 0
    totalLoss = 0
    tot = 0
    for idx, (x, y) in tqdm(enumerate(train_dataloader)):
        group_id = idx % NGROUPS  # Assign batch to a group
        masks = mask_groups[group_id]  # Get all masks in this group
        
        for mask in masks:
            logits = model.forward(x, mask=mask)
            loss = lossFn(logits, y)
            totalLoss += loss.item()
            opt.zero_grad()
            loss.backward()
            opt.step()
            trainCorrect += (logits.argmax(1) == y).type(torch.float).sum().item()
            tot += len(y)

    print(f"Train Accuracy: {trainCorrect} / {tot}: {trainCorrect / tot}")
    print(f"Total loss: {totalLoss}")


1875it [00:15, 123.62it/s]


Train Accuracy: 112039.0 / 180000: 0.6224388888888889
Total loss: 5982.230939865112


1875it [00:15, 120.25it/s]


Train Accuracy: 142833.0 / 180000: 0.7935166666666666
Total loss: 3533.7145986557007


1875it [00:14, 133.73it/s]


Train Accuracy: 150362.0 / 180000: 0.8353444444444444
Total loss: 2876.504527039826


1875it [00:14, 133.45it/s]


Train Accuracy: 153363.0 / 180000: 0.8520166666666666
Total loss: 2605.307093206793


1875it [00:14, 128.07it/s]


Train Accuracy: 155368.0 / 180000: 0.8631555555555556
Total loss: 2409.059836709872


1875it [00:15, 119.09it/s]


Train Accuracy: 156915.0 / 180000: 0.87175
Total loss: 2282.3711749296635


1875it [00:15, 119.40it/s]


Train Accuracy: 157643.0 / 180000: 0.8757944444444444
Total loss: 2218.7956836656667


1875it [00:17, 107.86it/s]


Train Accuracy: 158330.0 / 180000: 0.8796111111111111
Total loss: 2142.3615517392755


1875it [00:17, 108.06it/s]


Train Accuracy: 159116.0 / 180000: 0.8839777777777778
Total loss: 2085.7453461978585


1875it [00:16, 113.67it/s]


Train Accuracy: 159338.0 / 180000: 0.8852111111111111
Total loss: 2062.041484077461


1875it [00:16, 115.44it/s]


Train Accuracy: 159677.0 / 180000: 0.8870944444444444
Total loss: 2013.9672125540674


1875it [00:16, 112.78it/s]


Train Accuracy: 160029.0 / 180000: 0.88905
Total loss: 2000.577449085191


1875it [00:16, 114.98it/s]


Train Accuracy: 160442.0 / 180000: 0.8913444444444445
Total loss: 1944.8180303834379


1875it [00:16, 115.03it/s]


Train Accuracy: 160569.0 / 180000: 0.89205
Total loss: 1957.5500510726124


1875it [00:16, 114.97it/s]

Train Accuracy: 160765.0 / 180000: 0.8931388888888889
Total loss: 1939.1459698001854





In [8]:
torch.save(model.state_dict(), "../model.pth")

In [9]:
# model = NetV2(num_masks=NUM_MASKS, dropout_probs=dropout_probs)
# model.load_state_dict(torch.load("../model.pth", weights_only=True))
# model.eval()

In [10]:
mcmc = MCMC(model=model, increment_amt=1)

In [11]:
seed = 42
torch.manual_seed(seed)
train_dataloader = DataLoader(dataset1, batch_size=BATCH_SIZE, shuffle=True)

In [12]:
model.eval()
for i in range(3):
    trainCorrect = 0
    for idx, (x, y)  in tqdm(enumerate(train_dataloader)):
        mcmc.transition(x=x, y=y)

1875it [00:08, 229.72it/s]
1875it [00:08, 221.45it/s]
1875it [00:08, 222.32it/s]


### Distribution ends up being close to uniform

In [13]:
dist = torch.tensor([(val / mcmc.tot).item() for val in mcmc.ocurrences])
print(dist)

tensor([0.3328, 0.3293, 0.3379])


In [15]:
values, indices = torch.topk(dist, k=3)
print(indices)
print(indices.shape)

tensor([2, 0, 1])
torch.Size([3])


In [16]:
test_correct_top_3 = 0
test_correct = 0
model.eval()
for idx, (x, y)  in tqdm(enumerate(test_dataloader)):
    # logits2 = model.forward(x, mask=1)
    logits = mcmc.predict(x, chosen_masks=None)
    pred = torch.argmax(logits, dim=1)
    test_correct += (pred == y).sum().item()


    logits = mcmc.predict(x, chosen_masks=indices)
    pred = torch.argmax(logits, dim=1)
    test_correct_top_3 += (pred == y).sum().item()
print(f"Test accuracy with all masks: {test_correct / len(dataset2)}")
# print(f"Test accuracy with top 3: {test_correct_top_3 / len(dataset2)}")

  keep_indices = torch.tensor(keep_indices, dtype=torch.long)
10000it [00:06, 1477.21it/s]

Test accuracy with all masks: 0.9107





# Pruning

We can take the union of all the masks used to see what neurons were needed in inference. 

####  First Layer (338 / 512) ~66% neurons need to be used 
#### Second layer (240 / 256) ~93% neurons need to be used


In [17]:
indices_nonmasked = []
for index in indices:
    indices_nonmasked.append((torch.nonzero(model.dropout1.mask_dict[index.item()] == 0)))

(torch.unique(torch.cat(tuple(indices_nonmasked)))).shape

torch.Size([338])

In [18]:
indices_nonmasked = []
for index in indices:
    indices_nonmasked.append((torch.nonzero(model.dropout2.mask_dict[index.item()] == 0)))

(torch.unique(torch.cat(tuple(indices_nonmasked)))).shape

torch.Size([240])