In [5]:
import sys
sys.path.append('../')
import torch
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from model import Net, MCMC
from torch.utils.data import DataLoader, SequentialSampler
from tqdm import tqdm
from torch.optim import Adam

### Fixed-Mask with bayesian inference training
This notebook evaluates the performance of the performance of using fixed dropout masks and evaluating a weighted average of them at inference time

#### We apply random and rotations and blur to make task harder

In [6]:
transform = transforms.Compose([
    transforms.RandomRotation(degrees=180),
    transforms.ToTensor(), transforms.GaussianBlur(kernel_size=11, sigma=(5, 7)),  # Stronger blur
    transforms.Lambda(lambda x: torch.flatten(x)),
])

In [9]:
dataset1 = datasets.MNIST('../data', train=True, download=True,
                       transform=transform)
dataset2 = datasets.MNIST('../data', train=False,
                       transform=transform)


In [10]:
BATCH_SIZE = 32
EPOCHS = 8
NUM_MASKS = 10
LR = 0.001
dropout_probs=[0.4]

In [11]:
seed = 42
torch.manual_seed(seed)
indices = torch.randperm(len(dataset1)).tolist()  # Shuffled once

# Use SubsetRandomSampler to keep the same shuffle order on each epoch for the over-fitting step
sampler = SequentialSampler(indices)  # Keeps the order fixed
train_dataloader = DataLoader(dataset1, batch_size=BATCH_SIZE, sampler=sampler)
test_dataloader = DataLoader(dataset2, batch_size=1)

### Overfitting portion

In [23]:
model = Net(num_masks=NUM_MASKS, dropout_probs=dropout_probs)
opt = Adam(model.parameters(), lr=LR)
lossFn = torch.nn.NLLLoss() # Use NLL since we our model is outputting a probability


In [24]:
for i in range(EPOCHS):
    model.train()
    trainCorrect = 0
    totalLoss = 0
    for idx, (x, y)  in tqdm(enumerate(train_dataloader)):
        logits = model.forward(x, mask = idx % NUM_MASKS)
        loss = lossFn(logits, y)
        totalLoss += loss.item()
        opt.zero_grad()
        loss.backward()
        opt.step()
        trainCorrect += (logits.argmax(1) == y).type(
			torch.float).sum().item()
    print(f"Train Accuracy: {trainCorrect} / {len(dataset1)}: {trainCorrect/len(dataset1)}")
    print(f"Total loss: {totalLoss}")

1875it [00:09, 187.57it/s]


Train Accuracy: 25297.0 / 60000: 0.42161666666666664
Total loss: 3006.165900349617


1875it [00:10, 179.27it/s]


Train Accuracy: 33188.0 / 60000: 0.5531333333333334
Total loss: 2381.466608762741


1875it [00:10, 177.75it/s]


Train Accuracy: 36524.0 / 60000: 0.6087333333333333
Total loss: 2135.5018680095673


1875it [00:10, 180.06it/s]


Train Accuracy: 39234.0 / 60000: 0.6539
Total loss: 1949.2740721404552


1875it [00:10, 181.00it/s]


Train Accuracy: 40614.0 / 60000: 0.6769
Total loss: 1822.8774032592773


1875it [00:09, 192.66it/s]


Train Accuracy: 41838.0 / 60000: 0.6973
Total loss: 1722.6279719173908


1875it [00:09, 190.26it/s]


Train Accuracy: 42675.0 / 60000: 0.71125
Total loss: 1656.4186381697655


1875it [00:10, 176.30it/s]

Train Accuracy: 43452.0 / 60000: 0.7242
Total loss: 1597.8344276845455





In [None]:
torch.save(model.state_dict(), "../model.pth")

In [13]:
model = Net(num_masks=NUM_MASKS, dropout_probs=dropout_probs)
model.load_state_dict(torch.load("../model.pth", weights_only=True))
model.eval()

Net(
  (fc1): Linear(in_features=784, out_features=256, bias=True)
  (dropout1): ConsistentMCDropout(p=0.4)
  (fc2): Linear(in_features=256, out_features=10, bias=True)
)

In [14]:
mcmc = MCMC(model=model, increment_amt=1)

In [15]:
seed = 42
torch.manual_seed(seed)
train_dataloader = DataLoader(dataset1, batch_size=BATCH_SIZE, shuffle=True)

In [16]:
model.eval()
for i in range(3):
    trainCorrect = 0
    for idx, (x, y)  in tqdm(enumerate(train_dataloader)):
        mcmc.transition(x=x, y=y)

1875it [00:10, 173.16it/s]
1875it [00:09, 193.48it/s]
1875it [00:11, 164.10it/s]


In [17]:
dist = torch.tensor([(val / mcmc.tot).item() for val in mcmc.ocurrences])
print(dist)

tensor([0.1081, 0.0988, 0.0961, 0.0999, 0.1006, 0.0980, 0.1058, 0.0965, 0.1002,
        0.0960])


In [18]:
values, indices = torch.topk(dist, k=2)
print(indices)
print(indices.shape)

tensor([0, 6])
torch.Size([2])


In [31]:
test_correct = 0
model.eval()
for idx, (x, y)  in tqdm(enumerate(test_dataloader)):
    # logits2 = model.forward(x, mask=1)
    logits = mcmc.predict(x, chosen_masks=indices)
    pred = torch.argmax(logits, dim=1)
    test_correct += (pred == y).sum().item()
print(test_correct / len(dataset2))

  keep_indices = torch.tensor(keep_indices, dtype=torch.long)
10000it [00:07, 1389.66it/s]

0.762





## Test accuracy was 76% which is a somewhat stronger than the controls

### We show the number of overlapping neurons that can be pruned by finding the union-set of neurons between mask. In this case 212/256 neurons were used 

In [35]:
indices_nonmasked = []
for index in indices:
    indices_nonmasked.append((torch.nonzero(model.dropout1.mask_dict[index.item()] == 0)))

(torch.unique(torch.cat(tuple(indices_nonmasked)))).shape

torch.Size([212])