# MNIST Pruning Toy Example
We demonstrate how pruning can make a simple MLP pre-trained on MNIST less adaptable to fine-tuning on Fashion-MNIST.

## Imports

In [44]:
import copy

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.utils.prune as prune
import torch.nn.functional as F
from torchvision import datasets, transforms
# from torch.optim.lr_scheduler import StepLR
from tqdm.notebook import tqdm



## Config

In [41]:
# Hyperparameters: pre-train on MNIST, prune, then fine-tune on Fashion-MNIST
batch_size = 64  # training batch size
test_batch_size = 1000  # evaluation batch size
epochs = 10  # MNIST pre-training epochs  # TODO rename, also pretrain -> pt below
ft_epochs = 1  # Fashion-MNIST fine-tuning epochs
lr = 0.01  # AdamW learning rate
gamma = 0.7  # LR decay factor; currently unused (the StepLR scheduler is commented out)
prune_percentage = 0.8  # fraction of weights pruned in each pruned layer
seed = 1

# Seed the global torch RNG, then pick the compute device
torch.manual_seed(seed)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(f"Using {device} device")


Using cuda device


## Set Up Data and Models

In [15]:
# Data loaders for MNIST and Fashion-MNIST
def _make_loader(dataset_cls, train, mean, std, loader_batch_size, shuffle):
    """Build a DataLoader over one split of a torchvision dataset stored under "./data".

    Args:
        dataset_cls: torchvision dataset class, e.g. ``datasets.MNIST``.
        train: ``True`` for the training split, ``False`` for the test split.
        mean, std: Single-channel normalization statistics for ``transforms.Normalize``.
        loader_batch_size: Batch size of the returned loader.
        shuffle: Whether the loader reshuffles every time it is iterated.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding normalized ``(image, label)`` batches.
    """
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((mean,), (std,))]
    )
    # download=True on every split, so a test loader never relies on a train
    # loader having been constructed (and downloaded) first.
    dataset = dataset_cls("./data", train=train, download=True, transform=transform)
    return torch.utils.data.DataLoader(
        dataset, batch_size=loader_batch_size, shuffle=shuffle
    )


MNIST_MEAN, MNIST_STD = 0.1307, 0.3081  # MNIST stats
FASHION_MEAN, FASHION_STD = 0.2860, 0.3530  # Fashion-MNIST stats

# Training loaders shuffle each epoch. Test loaders do not: test() sums loss and
# correct counts over the whole split, so batch order cannot change the metrics.
mnist_train_loader = _make_loader(
    datasets.MNIST, True, MNIST_MEAN, MNIST_STD, batch_size, shuffle=True
)
mnist_test_loader = _make_loader(
    datasets.MNIST, False, MNIST_MEAN, MNIST_STD, test_batch_size, shuffle=False
)
fashion_train_loader = _make_loader(
    datasets.FashionMNIST, True, FASHION_MEAN, FASHION_STD, batch_size, shuffle=True
)
fashion_test_loader = _make_loader(
    datasets.FashionMNIST, False, FASHION_MEAN, FASHION_STD, test_batch_size, shuffle=False
)


# Define the MLP model (default: one hidden layer, i.e. two Linear layers)
class MLPNet(nn.Module):
    """Fully connected classifier that returns per-class log-probabilities.

    ``self.fc_layers`` is ``[Linear, ReLU]`` repeated once per entry of
    ``hidden_layer_dims``, followed by a final ``Linear`` to ``num_classes``.
    With the defaults this is ``[Linear(784, 128), ReLU, Linear(128, 10)]``.
    The pruning code relies on that layout, with the Linear layers at
    indices 0 and 2.

    Args:
        hidden_layer_dims: Width of each hidden layer, in order. An empty
            sequence yields a single Linear layer.
        input_dim: Number of input features; ``forward`` reshapes its input
            to ``(-1, input_dim)``. Defaults to 784.
        num_classes: Number of output classes. Defaults to 10.
    """

    # Tuple default: a list default would be one object shared across calls.
    def __init__(self, hidden_layer_dims=(128,), input_dim=784, num_classes=10):
        super().__init__()
        self.input_dim = input_dim
        # Dynamically create layers based on hidden_layer_dims
        prev_dim = input_dim
        fc_layers = []
        for dim in hidden_layer_dims:
            fc_layers.append(nn.Linear(prev_dim, dim))
            fc_layers.append(nn.ReLU())
            prev_dim = dim
        fc_layers.append(nn.Linear(prev_dim, num_classes))
        self.fc_layers = nn.Sequential(*fc_layers)

    def forward(self, x):
        """Flatten ``x`` to ``(-1, input_dim)`` and return log-softmax scores over classes."""
        # reshape (not view) so non-contiguous inputs are also accepted.
        x = x.reshape(-1, self.input_dim)
        x = self.fc_layers(x)
        return F.log_softmax(x, dim=1)


# Untrained network on the selected device; the pre-training cell deep-copies
# it rather than training this instance in place.
model_inital = MLPNet()
model_inital = model_inital.to(device)


In [30]:
# Training function
def train(model, device, train_loader, num_epochs=epochs):
    """Train ``model`` in place with AdamW on the negative log-likelihood loss.

    A fresh AdamW optimizer (learning rate from the global ``lr``) is created
    on every call, so each call starts with empty optimizer state.

    Args:
        model: Module whose forward output is passed to ``F.nll_loss``, so it
            should return log-probabilities.
        device: Device every batch is moved to; presumably the device the model
            already lives on.
        train_loader: ``DataLoader``-like object. It is iterated as
            ``(data, target)`` batches and must support ``len()`` and expose a
            ``.dataset`` that supports ``len()``.
        num_epochs: Number of passes over ``train_loader``. The default is the
            global ``epochs`` as it was when this cell was executed.

    Returns:
        None. ``model`` is updated in place.

    Every other epoch, and always after the last one, a line is printed with
    the number of samples processed and the loss of the final batch.
    """
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    model.train()
    num_samples = len(train_loader.dataset)
    num_total_batches = len(train_loader) * num_epochs
    # Context manager closes the progress bar even if training raises.
    with tqdm(total=num_total_batches, position=0, leave=True) as progress_bar:
        for epoch in range(1, num_epochs + 1):
            samples_seen = 0
            loss = None
            for data, target in train_loader:
                data, target = data.to(device), target.to(device)
                optimizer.zero_grad()
                output = model(data)
                loss = F.nll_loss(output, target)
                loss.backward()
                optimizer.step()
                # Count real samples: the final batch may be smaller than the rest.
                samples_seen += len(data)
                progress_bar.update(1)
            # Log every other epoch and always the last; skip if the loader yielded nothing.
            if loss is not None and (epoch % 2 == 0 or epoch == num_epochs):
                print(
                    f"Train Epoch: {epoch} [{samples_seen}/{num_samples}"
                    f" ({100. * samples_seen / num_samples:.0f}%)]\tLoss: {loss.item():.6f}"
                )


# Testing function
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # sum up batch loss
            test_loss += F.nll_loss(output, target, reduction="sum").item()
            # get the index of the max log-probability
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print(
        f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)}"
        f" ({100. * correct / len(test_loader.dataset):.0f}%)\n"
    )

    return test_loss, correct / len(test_loader.dataset)



In [None]:

# Pre-train a deep copy of the untrained network on MNIST, leaving model_inital untouched
model_base_pretrain = copy.deepcopy(model_inital)
train(model_base_pretrain, device, mnist_train_loader)

# Baseline MNIST test metrics; the pruned model is compared against these below
mnist_loss, mnist_accuracy = test(model_base_pretrain, device, mnist_test_loader)


## Prune the pre-trained model to create an unadaptable version

In [42]:
# Pruning function
def apply_pruning_to_model(model, amount=None):
    """Apply L1-magnitude unstructured pruning to the weight of every ``nn.Linear`` in ``model``.

    For each Linear layer, ``prune.l1_unstructured`` zeroes the ``amount`` of
    weight entries with the smallest absolute value. PyTorch re-parametrizes
    the weight as ``weight_orig * weight_mask`` and recomputes it before each
    forward pass. Pruned entries therefore stay zero even if the model is
    trained further, as happens in the fine-tuning cell.

    Args:
        model: Module to prune in place (e.g. an ``MLPNet``).
        amount: Fraction (float in [0, 1]) or absolute number (int) of weights
            to prune per layer. ``None`` (the default) uses the global
            ``prune_percentage`` from the config cell.

    Raises:
        ValueError: If ``model`` contains no ``nn.Linear`` layers.
    """
    if amount is None:
        amount = prune_percentage
    # Prune every Linear layer. Hard-coding fc_layers[0] / fc_layers[2] only
    # matched an MLPNet with exactly one hidden layer and silently skipped the rest.
    linear_layers = [module for module in model.modules() if isinstance(module, nn.Linear)]
    if not linear_layers:
        raise ValueError("apply_pruning_to_model: model has no nn.Linear layers to prune")
    for module in linear_layers:
        prune.l1_unstructured(module, "weight", amount=amount)

print("Pruning the model...")
# Prune a deep copy so model_base_pretrain keeps its dense weights
model_pruned_pt = copy.deepcopy(model_base_pretrain)
apply_pruning_to_model(model_pruned_pt)

# Measure how much MNIST performance pruning cost, side by side with the baseline
print("Evaluating the pruned model on MNIST...")
mnist_loss_pruned, mnist_accuracy_pruned = test(
    model_pruned_pt, device, mnist_test_loader
)
print(
    f"model_base_pt accuracy: {mnist_accuracy:.4f}, "
    f"model_pruned_pt accuracy: {mnist_accuracy_pruned:.4f}"
)
print(
    f"model_base_pt loss: {mnist_loss:.4f}, "
    f"model_pruned_pt loss: {mnist_loss_pruned:.4f}"
)

Pruning the model...
Evaluating the pruned model on MNIST...

Test set: Average loss: 0.8657, Accuracy: 8754/10000 (88%)

model_base_pt accuracy: 0.9569, model_pruned_pt accuracy: 0.8754
model_base_pt loss: 0.1959, model_pruned_pt loss: 0.8657


## Fine-tune both the original and the pruned models on Fashion-MNIST

In [43]:
# Independent copies to fine-tune; the *_pretrain / *_pt models stay untouched
model_base_ft = copy.deepcopy(model_base_pretrain)
model_pruned_ft = copy.deepcopy(model_pruned_pt)

# Evaluate the pre-trained models on Fashion-MNIST before any fine-tuning
print("Evaluating the pre-trained model on Fashion-MNIST")
ft_loss_base_pretrain, ft_acc_base_pretrain = test(model_base_pretrain, device, fashion_test_loader)
ft_loss_pruned_pretrain, ft_acc_pruned_pretrain = test(model_pruned_pt, device, fashion_test_loader)
print(f"model_base_pt accuracy: {ft_acc_base_pretrain:.4f}, model_pruned_pt accuracy: {ft_acc_pruned_pretrain:.4f}")
print(f"model_base_pt loss: {ft_loss_base_pretrain:.4f}, model_pruned_pt loss: {ft_loss_pruned_pretrain:.4f}")

# Fine-tune both models on Fashion-MNIST. Re-seeding before each run makes both
# draw the same shuffled batch order from the global torch RNG, so any
# difference comes from the starting weights, not from data order.
print("Training the base model on Fashion-MNIST")
torch.manual_seed(seed)
train(model_base_ft, device, fashion_train_loader, num_epochs=ft_epochs)

print("Training the pruned model on Fashion-MNIST")
torch.manual_seed(seed)
train(model_pruned_ft, device, fashion_train_loader, num_epochs=ft_epochs)

# Evaluate the fine-tuned models (header printed before the test() output it labels)
print("Evaluating the fine-tuned models on Fashion-MNIST")
ft_loss_base_ft, ft_acc_base_ft = test(model_base_ft, device, fashion_test_loader)
ft_loss_pruned_ft, ft_acc_pruned_ft = test(model_pruned_ft, device, fashion_test_loader)
print(f"model_base_ft accuracy: {ft_acc_base_ft:.4f}, model_pruned_ft accuracy: {ft_acc_pruned_ft:.4f}")
print(f"model_base_ft loss: {ft_loss_base_ft:.4f}, model_pruned_ft loss: {ft_loss_pruned_ft:.4f}")



Evaluating the pre-trained model on Fashion-MNIST

Test set: Average loss: 11.5359, Accuracy: 537/10000 (5%)


Test set: Average loss: 19.2064, Accuracy: 1030/10000 (10%)

model_base_pt accuracy: 0.0537, model_pruned_pt accuracy: 0.1030
model_base_pt loss: 11.5359, model_pruned_pt loss: 19.2064
Training the base model on Fashion-MNIST


100%|██████████| 938/938 [00:14<00:00, 65.23it/s]


Training the pruned model on Fashion-MNIST


100%|██████████| 938/938 [00:14<00:00, 64.79it/s]



Test set: Average loss: 0.5130, Accuracy: 8206/10000 (82%)


Test set: Average loss: 0.5591, Accuracy: 8060/10000 (81%)

Evaluating the fine-tuned models on Fashion-MNIST
model_base_ft accuracy: 0.8206, model_pruned_ft accuracy: 0.8060
model_base_ft loss: 0.5130, model_pruned_ft loss: 0.5591


In [None]:
# Print the results (nothing is written to disk)
# NOTE(review): despite its name, loss_gap_ratio is the absolute difference
# between the two fine-tuned Fashion-MNIST accuracies. TODO confirm the intended metric.
loss_gap_ratio = abs(ft_acc_base_ft - ft_acc_pruned_ft)

print(f"MNIST Test Accuracy before fine-tuning: {mnist_accuracy:.4f}")
print(
    f"Fashion-MNIST Test Accuracy after fine-tuning (Base Model): {ft_acc_base_ft:.4f}"
)
print(
    f"Fashion-MNIST Test Accuracy after fine-tuning (Pruned Model): {ft_acc_pruned_ft:.4f}"
)
print(f"Loss Gap Ratio: {loss_gap_ratio:.4f}")
