In [1]:
import sys
sys.path.append('../')
import torch
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
from model import NetV2, MCMC
from torch.utils.data import DataLoader, SequentialSampler
from tqdm import tqdm
from torch.optim import Adam
import os
from NISP import NISP
import torch.nn.utils.prune as prune


### Fixed-Masks with bayesian inference for training
This notebook evaluates the performance of the performance of using fixed dropout masks and evaluating a weighted average of them at inference time.

This notebook shows training and pruning for one of the better set of hyperparameters that also had higher dropout probabilities on each layer

#### We apply random and rotations and blur to make task harder

In [2]:
transform = transforms.Compose([
    transforms.RandomRotation(degrees=50),
    transforms.ToTensor(), transforms.GaussianBlur(kernel_size=5, sigma=(4, 5)),  # Stronger blur
    transforms.Lambda(lambda x: torch.flatten(x)),
])

In [3]:
dataset1 = datasets.MNIST('../data', train=True, download=True,
                       transform=transform)
dataset2 = datasets.MNIST('../data', train=False,
                       transform=transform)


In [4]:
BATCH_SIZE = 32
EPOCHS = 10
NUM_MASKS = 10
LR = 0.001
dropout_probs=[0.7, 0.7]

In [5]:
seed = 42
torch.manual_seed(seed)
indices = torch.randperm(len(dataset1)).tolist()  # Shuffled once

# Use SubsetRandomSampler to keep the same shuffle order on each epoch for the over-fitting step
sampler = SequentialSampler(indices)  # Keeps the order fixed
train_dataloader = DataLoader(dataset1, batch_size=BATCH_SIZE, sampler=sampler)
test_dataloader = DataLoader(dataset2, batch_size=1)

### Overfitting portion

In [6]:
model_iteration = 0
if not os.path.exists(f"../model{model_iteration}.pth"):
    model = NetV2(num_masks=NUM_MASKS, dropout_probs=dropout_probs)
    opt = Adam(model.parameters(), lr=LR)
    lossFn = torch.nn.NLLLoss() # Use NLL since we our model is outputting a probability
    torch.save(model.state_dict(), f"../model{model_iteration}.pth")
model_iteration += 1

In [7]:
if not os.path.exists(f"../model{model_iteration}.pth"):
    NGROUPS = 10 # dividing groups to use
    mask_groups = [list(range(i, NUM_MASKS, NGROUPS)) for i in range(NGROUPS)]  # Partition masks

    for epoch in range(EPOCHS):
        model.train()
        trainCorrect = 0
        totalLoss = 0
        tot = 0
        for idx, (x, y) in tqdm(enumerate(train_dataloader)):
            group_id = idx % NGROUPS  # Assign batch to a group
            masks = mask_groups[group_id]  # Get all masks in this group
            
            for mask in masks:
                logits = model.forward(x, mask=mask)
                loss = lossFn(logits, y)
                totalLoss += loss.item()
                opt.zero_grad()
                loss.backward()
                opt.step()
                trainCorrect += (logits.argmax(1) == y).type(torch.float).sum().item()
                tot += len(y)

        print(f"Train Accuracy: {trainCorrect} / {tot}: {trainCorrect / tot}")
        print(f"Total loss: {totalLoss}")
    torch.save(model.state_dict(), f"../model{model_iteration}.pth")
    model1 = model

In [8]:
model_iteration += 1

In [57]:
model = NetV2(num_masks=NUM_MASKS, dropout_probs=dropout_probs)
model.load_state_dict(torch.load(f"../model{1}.pth"), strict=True)
model.eval()

NetV2(
  (fc1): Linear(in_features=784, out_features=1024, bias=True)
  (dropout1): ConsistentMCDropout(p=0.7)
  (fc2): Linear(in_features=1024, out_features=512, bias=True)
  (dropout2): ConsistentMCDropout(p=0.7)
  (fc3): Linear(in_features=512, out_features=10, bias=True)
)

In [7]:
mcmc = MCMC(model=model, increment_amt=10)

In [11]:
seed = 42
torch.manual_seed(seed)
train_dataloader = DataLoader(dataset1, batch_size=BATCH_SIZE, shuffle=True)

In [12]:
model.eval()
for i in range(3):
    trainCorrect = 0
    for idx, (x, y)  in tqdm(enumerate(train_dataloader)):
        mcmc.transition(x=x, y=y)

Consider using tensor.detach() first. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/Scalar.cpp:23.)
  acceptance_prob = torch.tensor([min(torch.tensor([1]).float(), ratio)])
1875it [00:12, 147.43it/s]
1875it [00:13, 140.21it/s]
1875it [00:12, 148.63it/s]


### Distribution ends up being close to uniform

In [8]:
dist = torch.tensor([(val / mcmc.tot).item() for val in mcmc.ocurrences])
print(dist)

tensor([0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000,
        0.1000])


In [11]:
values, indices = torch.topk(dist, k=10)
print(indices)
print(indices.shape)

NameError: name 'dist' is not defined

In [15]:
test_correct_top_3 = 0
test_correct = 0
model.eval()
for idx, (x, y)  in tqdm(enumerate(test_dataloader)):
    # logits2 = model.forward(x, mask=1)
    logits = mcmc.predict(x, chosen_masks=indices)
    pred = torch.argmax(logits, dim=1)
    test_correct += (pred == y).sum().item()


    logits = mcmc.predict(x, chosen_masks=indices)
    pred = torch.argmax(logits, dim=1)
    test_correct_top_3 += (pred == y).sum().item()
print(f"Test accuracy with all masks: {test_correct / len(dataset2)}")
# print(f"Test accuracy with top 3: {test_correct_top_3 / len(dataset2)}")

  keep_indices = torch.tensor(keep_indices, dtype=torch.long)
10000it [00:28, 350.31it/s]

Test accuracy with all masks: 0.9436





# Pruning


In [58]:
import torch.nn.utils.prune as prune
from torch import nn

def apply_nisp_pruning(model: nn.Module, layer_name: str, neuron_mask: torch.Tensor):
    """
    Applies structured pruning to a given layer using a neuron-wise mask.

    Args:
        model: The PyTorch model.
        layer_name: Name of the layer to prune (e.g., 'fc1').
        neuron_mask: 1D tensor with 1s for neurons to keep, 0s to prune (length = out_features).
    """
    layer = dict(model.named_modules())[layer_name]
    
    if not isinstance(layer, nn.Linear):
        raise ValueError("Only Linear layers are supported for this kind of pruning.")

    weight_mask = neuron_mask[:, None].expand_as(layer.weight.data)

    # Use custom mask with PyTorch's pruning
    prune.CustomFromMask.apply(layer, name='weight', mask=weight_mask)

    # Optionally prune the bias too
    if layer.bias is not None:
        bias_mask = neuron_mask.clone()
        prune.CustomFromMask.apply(layer, name='bias', mask=bias_mask)


In [59]:
dropout_layers = {
    'fc1': model.dropout1,
    'fc2': model.dropout2,
}

nisp = NISP(model)
importance_scores = nisp.compute_aggregated_importance_scores(dropout_layers)
# Let's say you've already computed the scores and pruning mask:
fc1_scores = importance_scores['fc1']
fc1_mask = nisp.get_pruning_mask(fc1_scores, pruning_rate=0.5)
apply_nisp_pruning(model, 'fc1', fc1_mask)


fc2_scores = importance_scores['fc2']
fc2_mask = nisp.get_pruning_mask(fc2_scores, pruning_rate=0.4)
apply_nisp_pruning(model, 'fc2', fc2_mask)



In [61]:
model.fc1.bias.count_nonzero()

tensor(513)

In [21]:
prune.remove(model.fc1, 'weight')
prune.remove(model.fc1, 'bias')
prune.remove(model.fc2, 'weight')
prune.remove(model.fc2, 'bias')

Linear(in_features=1024, out_features=512, bias=True)

In [24]:
model.fc1.weight.count_nonzero()

tensor(242256)

In [6]:
initial_model = NetV2(num_masks=5, dropout_probs=[0.4, 0.4])
initial_model.load_state_dict(torch.load(f"../model{0}.pth"))
initial_model.num_masks = 5

In [84]:
prune_mask_weight = final_mask1.bool()[:, None].expand_as(initial_model.fc1.weight)
prune_mask_bias = final_mask1.bool() if initial_model.fc1.bias is not None else None

prune.custom_from_mask(
    module=initial_model.fc1,
    name='weight',
    mask=prune_mask_weight
)

if prune_mask_bias is not None:
    prune.custom_from_mask(
        module=initial_model.fc1,
        name='bias',
        mask=prune_mask_bias
    )

In [85]:
prune_mask_weight = final_mask2.bool()[:, None].expand_as(initial_model.fc2.weight)
prune_mask_bias = final_mask2.bool() if initial_model.fc2.bias is not None else None

prune.custom_from_mask(
    module=initial_model.fc2,
    name='weight',
    mask=prune_mask_weight
)

if prune_mask_bias is not None:
    prune.custom_from_mask(
        module=initial_model.fc2,
        name='bias',
        mask=prune_mask_bias
    )

In [56]:
prune.remove(initial_model.fc1, 'weight')
prune.remove(initial_model.fc2, 'weight')
prune.remove(initial_model.fc1, 'bias')
prune.remove(initial_model.fc2, 'bias')

Linear(in_features=1024, out_features=512, bias=True)

In [57]:
print(initial_model.fc1.bias.count_nonzero())
print(initial_model.fc2.bias.count_nonzero())


tensor(411)
tensor(206)


In [10]:
model = initial_model
opt = Adam(model.parameters(), lr=LR)
lossFn = torch.nn.NLLLoss() # Use NLL since we our model is outputting a probability

NameError: name 'initial_model' is not defined

In [11]:
model  = initial_model
if not os.path.exists(f"../model{model_iteration}.pth"):
    NGROUPS = 5 # dividing groups to use
    mask_groups = [list(range(i, NUM_MASKS, NGROUPS)) for i in range(NGROUPS)]  # Partition masks

    for epoch in range(EPOCHS):
        model.train()
        trainCorrect = 0
        totalLoss = 0
        tot = 0
        for idx, (x, y) in tqdm(enumerate(train_dataloader)):
            group_id = idx % NGROUPS  # Assign batch to a group
            masks = mask_groups[group_id]  # Get all masks in this group
            
            for mask in masks:
                logits = model.forward(x, mask=mask)
                loss = lossFn(logits, y)
                totalLoss += loss.item()
                opt.zero_grad()
                loss.backward()
                opt.step()
                trainCorrect += (logits.argmax(1) == y).type(torch.float).sum().item()
                tot += len(y)

        print(f"Train Accuracy: {trainCorrect} / {tot}: {trainCorrect / tot}")
        print(f"Total loss: {totalLoss}")
    torch.save(model.state_dict(), f"../model{model_iteration}.pth")
    model1 = model

NameError: name 'initial_model' is not defined

In [None]:
model_iteration = 3

In [41]:
model = NetV2(num_masks=5, dropout_probs=dropout_probs)
model.load_state_dict(torch.load(f"../model{2}.pth"), strict=True)
model.eval()
model.fc1.bias.count_nonzero()

tensor(411)

In [7]:
mcmc = MCMC(model=model, increment_amt=10)
train_dataloader = DataLoader(dataset1, batch_size=BATCH_SIZE, shuffle=True)
model.num_masks

5

In [8]:
model.eval()
for i in range(3):
    trainCorrect = 0
    for idx, (x, y) in tqdm(enumerate(train_dataloader)):
        mcmc.transition(x=x, y=y)
dist = torch.tensor([(val / mcmc.tot).item() for val in mcmc.ocurrences])
print(dist)

0it [00:00, ?it/s]

1875it [00:09, 195.21it/s]
1875it [00:09, 205.66it/s]
1875it [00:09, 205.85it/s]

tensor([0.2983, 0.1273, 0.1347, 0.1483, 0.2914])





In [9]:
print(mcmc.num_masks)

5


In [10]:
values, indices = torch.topk(dist, k=5)
print(indices)
print(indices.shape)

tensor([0, 4, 3, 2, 1])
torch.Size([5])


In [11]:
dist = torch.tensor([(val / mcmc.tot).item() for val in mcmc.ocurrences])
test_correct_top_3 = 0
test_correct = 0
model.eval()
for idx, (x, y)  in tqdm(enumerate(test_dataloader)):
    # logits2 = model.forward(x, mask=1)
    logits = mcmc.predict(x, chosen_masks=indices)
    pred = torch.argmax(logits, dim=1)
    test_correct += (pred == y).sum().item()


    logits = mcmc.predict(x, chosen_masks=indices)
    pred = torch.argmax(logits, dim=1)
    test_correct_top_3 += (pred == y).sum().item()
print(f"Test accuracy with all masks: {test_correct / len(dataset2)}")
# print(f"Test accuracy with top 3: {test_correct_top_3 / len(dataset2)}")

  keep_indices = torch.tensor(keep_indices, dtype=torch.long)
10000it [00:16, 613.40it/s]

Test accuracy with all masks: 0.9552





In [68]:
nisp = NISP(model)
importance_scores = nisp.compute_importance_scores()
prune_mask1 = nisp.get_pruning_mask('fc1', 0.4)
prune_mask2 = nisp.get_pruning_mask('fc2', 0.3)

fc1_majority_mask = model.dropout1.get_majority_vote_mask(threshold=0.5)
fc2_majority_mask = model.dropout2.get_majority_vote_mask(threshold=0.5)


def conservative_combine(nisp_mask, dropout_mask):
    """MOST AGGRESSIVE PRUNING: Only keep if both methods agree"""
    return (nisp_mask.bool() & dropout_mask.bool()).float()  # 1=keep, 0=prune

final_mask1 = conservative_combine(prune_mask1, fc1_majority_mask)  # Uses &
final_mask2 = conservative_combine(prune_mask2, fc2_majority_mask)  # Uses &

In [69]:
print(fc1_majority_mask.count_nonzero())
print(prune_mask1.count_nonzero())
print(final_mask1.count_nonzero())
print(final_mask2.count_nonzero())

tensor(311)
tensor(614)
tensor(188)
tensor(112)


In [70]:
initial_model = NetV2(num_masks=1, dropout_probs=[0.0, 0.0])
initial_model.load_state_dict(torch.load(f"../model{0}.pth"))

In [71]:
prune_mask_weight = final_mask1.bool()[:, None].expand_as(initial_model.fc1.weight)
prune_mask_bias = final_mask1.bool() if initial_model.fc1.bias is not None else None

prune.custom_from_mask(
    module=initial_model.fc1,
    name='weight',
    mask=prune_mask_weight
)

if prune_mask_bias is not None:
    prune.custom_from_mask(
        module=initial_model.fc1,
        name='bias',
        mask=prune_mask_bias
    )

In [72]:
prune_mask_weight = final_mask2.bool()[:, None].expand_as(initial_model.fc2.weight)
prune_mask_bias = final_mask2.bool() if initial_model.fc2.bias is not None else None

prune.custom_from_mask(
    module=initial_model.fc2,
    name='weight',
    mask=prune_mask_weight
)

if prune_mask_bias is not None:
    prune.custom_from_mask(
        module=initial_model.fc2,
        name='bias',
        mask=prune_mask_bias
    )

In [73]:
prune.remove(initial_model.fc1, 'weight')
prune.remove(initial_model.fc2, 'weight')
prune.remove(initial_model.fc1, 'bias')
prune.remove(initial_model.fc2, 'bias')

print(initial_model.fc1.bias.count_nonzero())
print(initial_model.fc2.bias.count_nonzero())


tensor(188)
tensor(112)


In [75]:
model_iteration = 3

In [76]:
model = initial_model
opt = Adam(model.parameters(), lr=LR)
lossFn = torch.nn.NLLLoss() # Use NLL since we our model is outputting a probability

In [77]:
model  = initial_model
if not os.path.exists(f"../model{model_iteration}.pth"):
    NGROUPS = 1 # dividing groups to use
    mask_groups = [list(range(i, NUM_MASKS, NGROUPS)) for i in range(NGROUPS)]  # Partition masks

    for epoch in range(EPOCHS):
        model.train()
        trainCorrect = 0
        totalLoss = 0
        tot = 0
        for idx, (x, y) in tqdm(enumerate(train_dataloader)):
            group_id = idx % NGROUPS  # Assign batch to a group
            masks = mask_groups[group_id]  # Get all masks in this group
            
            for mask in masks:
                logits = model.forward(x, mask=mask)
                loss = lossFn(logits, y)
                totalLoss += loss.item()
                opt.zero_grad()
                loss.backward()
                opt.step()
                trainCorrect += (logits.argmax(1) == y).type(torch.float).sum().item()
                tot += len(y)

        print(f"Train Accuracy: {trainCorrect} / {tot}: {trainCorrect / tot}")
        print(f"Total loss: {totalLoss}")
    torch.save(model.state_dict(), f"../model{model_iteration}.pth")
    model1 = model

1875it [01:13, 25.46it/s]


Train Accuracy: 547479.0 / 600000: 0.912465
Total loss: 5760.077270619338


1875it [01:14, 25.19it/s]


Train Accuracy: 563824.0 / 600000: 0.9397066666666667
Total loss: 4242.937383973622


1875it [01:18, 24.01it/s]


Train Accuracy: 564214.0 / 600000: 0.9403566666666666
Total loss: 4224.361388082732


1875it [01:22, 22.78it/s]


Train Accuracy: 565901.0 / 600000: 0.9431683333333334
Total loss: 3988.9579778346233


1875it [01:22, 22.61it/s]


Train Accuracy: 565209.0 / 600000: 0.942015
Total loss: 4038.8632180804852


1875it [01:23, 22.33it/s]


Train Accuracy: 565722.0 / 600000: 0.94287
Total loss: 4001.72815497173


1875it [01:21, 22.96it/s]


Train Accuracy: 567245.0 / 600000: 0.9454083333333333
Total loss: 3884.2582408610033


1875it [01:18, 23.98it/s]


Train Accuracy: 567426.0 / 600000: 0.94571
Total loss: 3827.468564099603


1875it [01:22, 22.85it/s]


Train Accuracy: 567423.0 / 600000: 0.945705
Total loss: 3847.038970327354


1875it [01:37, 19.31it/s]


Train Accuracy: 567941.0 / 600000: 0.9465683333333333
Total loss: 3808.8674488131655


In [78]:
model = NetV2(num_masks=1, dropout_probs=[0.0, 0.0])
model.load_state_dict(torch.load(f"../model{3}.pth"), strict=True)
model.eval()
model.fc1.bias.count_nonzero()

tensor(188)

In [79]:
mcmc = MCMC(model=model, increment_amt=10)
train_dataloader =DataLoader(dataset1, batch_size=BATCH_SIZE, shuffle=True)
model.num_masks

1

In [80]:
model.eval()
for i in range(1):
    trainCorrect = 0
    for idx, (x, y) in tqdm(enumerate(train_dataloader)):
        mcmc.transition(x=x, y=y)
dist = torch.tensor([(val / mcmc.tot).item() for val in mcmc.ocurrences])
print(dist)

1875it [00:08, 211.08it/s]

tensor([1.])





In [81]:
values, indices = torch.topk(dist, k=1)
print(indices)
print(indices.shape)

tensor([0])
torch.Size([1])


In [82]:
test_correct_top_3 = 0
test_correct = 0
model.eval()
for idx, (x, y)  in tqdm(enumerate(test_dataloader)):
    # logits2 = model.forward(x, mask=1)
    logits = mcmc.predict(x, chosen_masks=indices)
    pred = torch.argmax(logits, dim=1)
    test_correct += (pred == y).sum().item()


    logits = mcmc.predict(x, chosen_masks=indices)
    pred = torch.argmax(logits, dim=1)
    test_correct_top_3 += (pred == y).sum().item()
print(f"Test accuracy with all masks: {test_correct / len(dataset2)}")
# print(f"Test accuracy with top 3: {test_correct_top_3 / len(dataset2)}")

  keep_indices = torch.tensor(keep_indices, dtype=torch.long)
10000it [00:04, 2135.31it/s]

Test accuracy with all masks: 0.9386





In [94]:
print(model.fc1.weight.shape)
print(model.fc1.weight.count_nonzero())
print(model.fc1.bias.shape)
print(model.fc1.bias.count_nonzero())
print(model.fc2.bias.count_nonzero())

torch.Size([1024, 784])
tensor(147392)
torch.Size([1024])
tensor(188)
tensor(112)
