In [1]:
import torch.nn.utils.prune as prune

import torch

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from typing import Callable, Tuple
from pathlib import Path

### Calculate the mean and std of the training set for normalization

In [2]:
# Define a custom transformation to collect statistics
to_tensor = transforms.Compose([transforms.ToTensor()])

# Load the MNIST training dataset without normalization
dataset = datasets.MNIST(root="../data", train=True, download=True, transform=to_tensor)

# Initialize variables to accumulate sum and squared sum
sum_pixels = 0
sum_pixels_squared = 0

# Count the number of samples
num_samples = len(dataset)

# Iterate through the dataset to collect statistics
for data, _ in dataset:
    # Flatten the data tensor
    data = data.view(-1)

    # Accumulate the sum of pixel values and squared sum
    sum_pixels += data.sum()
    sum_pixels_squared += (data**2).sum()

# Calculate the mean and standard deviation
mean = sum_pixels / (num_samples * 28 * 28)  # 28x28 is the image size
std_dev = ((sum_pixels_squared / (num_samples * 28 * 28)) - mean**2) ** 0.5

print("Mean:", mean)
print("Standard Deviation:", std_dev)

Mean: tensor(0.1307)
Standard Deviation: tensor(0.3081)


### Load the data

In [3]:
# load MNIST data
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean, std_dev)]
)

# split into validation and train datasets
train_ds = datasets.MNIST("../data", train=True, transform=transform)
train_ds, valid_ds = random_split(train_ds, [50000, 10000])

test_ds = datasets.MNIST("../data", train=False, transform=transform)

## Define the model architecture

In [4]:
# Define a simple CNN model
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(
            1, 6, kernel_size=5, stride=1, padding=2
        )  # 28*28->32*32-->28*28
        self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1)
        self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.flatten1 = nn.Flatten()
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = self.flatten1(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

### Define utilities for training and testing

In [5]:
class EarlyStopper:
    def __init__(self, patience: int = 1, min_delta: int = 0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = float("inf")

    def early_stop(self, validation_loss: float) -> bool:
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss > (self.min_validation_loss + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False

In [6]:
# Define a function to train the model
def fit(
    model: nn.Module,
    train_dl,
    valid_dl,
    optimizer: optim.Optimizer,
    loss_function: Callable,
    epochs: int,
    early_stopper: EarlyStopper | None = None,
    device: torch.device = torch.device("cpu"),
) -> Tuple[float, float]:
    valid_loss = 0
    valid_accuracy = 0

    for epoch in range(epochs):
        model.train()
        for X, y in train_dl:
            X, y = X.to(device), y.to(device)

            # Compute prediction error
            pred = model(X)
            train_loss = loss_function(pred, y)

            # Backpropagation
            train_loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        model.eval()
        valid_loss = 0
        valid_accuracy = 0
        with torch.no_grad():
            for X, y in valid_dl:
                X, y = X.to(device), y.to(device)

                # Compute prediction error
                pred = model(X)
                valid_loss += loss_function(pred, y).item()

                # Compute accuracy
                valid_accuracy += (pred.argmax(1) == y).float().mean()

        valid_loss /= len(valid_dl)
        valid_accuracy /= len(valid_dl)

        print(
            f"Epoch #{epoch + 1}:\t validation loss: {valid_loss:.4f}\t validation accuracy: {valid_accuracy:.4f}"
        )

        if early_stopper is not None and early_stopper.early_stop(valid_loss):
            print("Early stopping")
            return (valid_loss, valid_accuracy)

    return (valid_loss, valid_accuracy)

In [7]:
# Define a function to test the model
def test(
    model: nn.Module,
    test_dl,
    loss_function: Callable,
    device: torch.device = torch.device("cpu"),
) -> Tuple[float, float]:
    size = len(test_dl.dataset)
    num_batches = len(test_dl)
    model.eval()

    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in test_dl:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_function(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    accuracy = (correct / size) * 100

    return (test_loss, accuracy)

## Training Phase

In [8]:
# define the constants
BATCH_SIZE: int = 32
LEARNING_RATE: float = 0.01
EPOCHS: int = 10
MOMENTUM: float = 0.9

In [9]:
# Get cpu, gpu or mps device for training.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using {torch.cuda.get_device_name(torch.cuda.current_device())}")

Using NVIDIA GeForce GTX 1660 Ti


In [10]:
base_model = Model().to(device)

# Define the loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
early_stopper = EarlyStopper(patience=3, min_delta=0)
optimizer = optim.SGD(base_model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

# create the data loaders
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
validation_loader = DataLoader(valid_ds, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE)

### Training loop

In [11]:
valid_loss, valid_accuracy = fit(
    base_model,
    train_dl=train_loader,
    valid_dl=validation_loader,
    optimizer=optimizer,
    loss_function=loss_fn,
    epochs=EPOCHS,
    device=device,
)

  return F.conv2d(input, weight, bias, self.stride,


Epoch #1:	 validation loss: 0.0870	 validation accuracy: 0.9722
Epoch #2:	 validation loss: 0.0657	 validation accuracy: 0.9790
Epoch #3:	 validation loss: 0.0615	 validation accuracy: 0.9801
Epoch #4:	 validation loss: 0.0692	 validation accuracy: 0.9810
Epoch #5:	 validation loss: 0.0539	 validation accuracy: 0.9843
Epoch #6:	 validation loss: 0.0484	 validation accuracy: 0.9866
Epoch #7:	 validation loss: 0.0443	 validation accuracy: 0.9861
Epoch #8:	 validation loss: 0.0399	 validation accuracy: 0.9876
Epoch #9:	 validation loss: 0.0470	 validation accuracy: 0.9876
Epoch #10:	 validation loss: 0.0418	 validation accuracy: 0.9881


In [12]:
test_loss, accuracy = test(
    base_model, test_dl=test_loader, loss_function=loss_fn, device=device
)
print(f"Test Error: \n Accuracy: {accuracy:>0.1f}%, Avg loss: {test_loss:>8f} \n")

Test Error: 
 Accuracy: 98.8%, Avg loss: 0.040892 



Save the model

In [13]:
torch.save(base_model.state_dict(), "models/lenet_mnist.pth")

## Pruning Phase

### One shot pruning

In [14]:
FINE_TUNING_EPOCHS = 5

In [15]:
def get_parameters_to_prune(model: nn.Module) -> list[nn.Parameter]:
    return [
        (module, "weight")
        for module in model.modules()
        if isinstance(module, nn.Conv2d | nn.Linear)
    ]

In [16]:
PRUNING_VALUES = [0.2, 0.4, 0.6, 0.8, 0.9, 0.95]

In [17]:
results = []
for pruning_rate in PRUNING_VALUES:
    # load the model
    temp_model = Model().to(device)
    temp_model.load_state_dict(torch.load("models/lenet_mnist.pth"))
    model_parameters = get_parameters_to_prune(temp_model)

    # prune the model
    prune.global_unstructured(
        parameters=model_parameters,
        pruning_method=prune.L1Unstructured,
        amount=pruning_rate,
    )

    print(f"Pruning rate: {pruning_rate}")

    val_loss, val_accuracy = fit(
        model=temp_model,
        train_dl=train_loader,
        valid_dl=validation_loader,
        optimizer=optimizer,
        loss_function=loss_fn,
        epochs=2,
        device=device,
    )

    results.append((pruning_rate, val_loss, val_accuracy))

    for module, name in model_parameters:
        prune.remove(module, name)

    torch.save(temp_model.state_dict(), f"models/lenet_mnist_pruned_{pruning_rate}.pth")

Pruning rate: 0.2
Epoch #1:	 validation loss: 0.0417	 validation accuracy: 0.9878
Epoch #2:	 validation loss: 0.0417	 validation accuracy: 0.9878
Pruning rate: 0.4
Epoch #1:	 validation loss: 0.0414	 validation accuracy: 0.9878
Epoch #2:	 validation loss: 0.0414	 validation accuracy: 0.9878
Pruning rate: 0.6
Epoch #1:	 validation loss: 0.0421	 validation accuracy: 0.9864
Epoch #2:	 validation loss: 0.0421	 validation accuracy: 0.9864
Pruning rate: 0.8
Epoch #1:	 validation loss: 0.0562	 validation accuracy: 0.9842
Epoch #2:	 validation loss: 0.0562	 validation accuracy: 0.9842
Pruning rate: 0.9
Epoch #1:	 validation loss: 0.2220	 validation accuracy: 0.9493
Epoch #2:	 validation loss: 0.2220	 validation accuracy: 0.9493
Pruning rate: 0.95
Epoch #1:	 validation loss: 1.0751	 validation accuracy: 0.7079
Epoch #2:	 validation loss: 1.0751	 validation accuracy: 0.7079


In [18]:
print("Pruning Rate\t Validation Loss\t Validation Accuracy")
for result in results:
    print(f"{result[0]:.2f}\t\t {result[1]:.4f}\t\t {result[2]:.4f}")

Pruning Rate	 Validation Loss	 Validation Accuracy
0.20		 0.0417		 0.9878
0.40		 0.0414		 0.9878
0.60		 0.0421		 0.9864
0.80		 0.0562		 0.9842
0.90		 0.2220		 0.9493
0.95		 1.0751		 0.7079


### Iterative pruning

In [19]:
iterative_model = Model().to(device)
iterative_model.load_state_dict(torch.load("models/lenet_mnist.pth"))

<All keys matched successfully>

In [23]:
iterative_model_parameters = get_parameters_to_prune(iterative_model)

RANGE = 30
for iteration in range(RANGE):
    prune.global_unstructured(
        parameters=iterative_model_parameters,
        pruning_method=prune.L1Unstructured,
        amount=0.01,
    )

    val_loss, val_accuracy = fit(
        model=iterative_model,
        train_dl=train_loader,
        valid_dl=validation_loader,
        optimizer=optimizer,
        loss_function=loss_fn,
        epochs=1,
        device=device,
    )

    print(
        f"Iteration #{iteration + 1}:\t validation loss: {val_loss:.4f}\t validation accuracy: {val_accuracy:.4f}"
    )

Epoch #1:	 validation loss: 0.0415	 validation accuracy: 0.9879
Iteration #1:	 validation loss: 0.0415	 validation accuracy: 0.9879
Epoch #1:	 validation loss: 0.0414	 validation accuracy: 0.9879
Iteration #2:	 validation loss: 0.0414	 validation accuracy: 0.9879
Epoch #1:	 validation loss: 0.0416	 validation accuracy: 0.9877
Iteration #3:	 validation loss: 0.0416	 validation accuracy: 0.9877
Epoch #1:	 validation loss: 0.0416	 validation accuracy: 0.9877
Iteration #4:	 validation loss: 0.0416	 validation accuracy: 0.9877
Epoch #1:	 validation loss: 0.0418	 validation accuracy: 0.9878
Iteration #5:	 validation loss: 0.0418	 validation accuracy: 0.9878
Epoch #1:	 validation loss: 0.0417	 validation accuracy: 0.9876
Iteration #6:	 validation loss: 0.0417	 validation accuracy: 0.9876
Epoch #1:	 validation loss: 0.0417	 validation accuracy: 0.9877
Iteration #7:	 validation loss: 0.0417	 validation accuracy: 0.9877
Epoch #1:	 validation loss: 0.0414	 validation accuracy: 0.9878
Iteration #8

In [25]:
for module, name in iterative_model_parameters:
    prune.remove(module, name)

torch.save(
    iterative_model.state_dict(),
    f"models/lenet_mnist_pruned_iterative_pruned_0.{RANGE}.pth",
)

In [26]:
models = []
for file in Path("models").glob("*.pth"):
    model = Model().to(device)
    temp = torch.load(file)
    model.load_state_dict(temp)
    print(f"Loaded {file.stem}")
    models.append((file.stem, model))

Loaded lenet_mnist_pruned_0.2
Loaded lenet_mnist_pruned_0.9
Loaded lenet_mnist_pruned_0.4
Loaded lenet_mnist_pruned_0.6
Loaded lenet_mnist_pruned_0.8
Loaded lenet_mnist
Loaded lenet_mnist_pruned_0.95
Loaded lenet_mnist_pruned_iterative_pruned_0.30


In [29]:
for name, model in sorted(models, key=lambda x: x[0]):
    test_loss, accuracy = test(
        model, test_dl=test_loader, loss_function=loss_fn, device=device
    )
    print(f"Model {name}")
    print(f"Test Error: \n Accuracy: {accuracy:>0.1f}%, Avg loss: {test_loss:>8f} \n")

Model lenet_mnist
Test Error: 
 Accuracy: 98.8%, Avg loss: 0.040892 

Model lenet_mnist_pruned_0.2
Test Error: 
 Accuracy: 98.8%, Avg loss: 0.040704 

Model lenet_mnist_pruned_0.4
Test Error: 
 Accuracy: 98.8%, Avg loss: 0.041048 

Model lenet_mnist_pruned_0.6
Test Error: 
 Accuracy: 98.9%, Avg loss: 0.039547 

Model lenet_mnist_pruned_0.8
Test Error: 
 Accuracy: 98.4%, Avg loss: 0.051847 

Model lenet_mnist_pruned_0.9
Test Error: 
 Accuracy: 94.7%, Avg loss: 0.223231 

Model lenet_mnist_pruned_0.95
Test Error: 
 Accuracy: 70.6%, Avg loss: 1.055350 

Model lenet_mnist_pruned_iterative_pruned_0.30
Test Error: 
 Accuracy: 98.8%, Avg loss: 0.041454 

