In [34]:
import pandas as pd

import torch.nn.utils.prune as prune

import torch

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from typing import Callable, Tuple
from pathlib import Path

from dataclasses import dataclass, asdict

### Load the data

In [2]:
# load FashionMNIST data
transform = transforms.Compose([transforms.ToTensor()])

# split into validation and train datasets
train_ds = datasets.FashionMNIST("data", train=True, transform=transform, download=True)
train_ds, valid_ds = random_split(train_ds, [0.8, 0.2])

test_ds = datasets.FashionMNIST("data", train=False, transform=transform, download=True)

## Define the model architecture

In [3]:
# Define a simple CNN model
class LeNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(
            1, 6, kernel_size=5, stride=1, padding=2
        )  # 28*28->32*32-->28*28
        self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1)
        self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.flatten1 = nn.Flatten()
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = self.flatten1(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

### Define utilities for training and testing

In [4]:
class EarlyStopper:
    def __init__(self, patience: int = 1, min_delta: int = 0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = float("inf")

    def early_stop(self, validation_loss: float) -> bool:
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss > (self.min_validation_loss + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False

In [5]:
# Define a function to train the model
def fit(
    model: nn.Module,
    train_dl,
    valid_dl,
    optimizer: optim.Optimizer,
    loss_function: Callable,
    epochs: int,
    early_stopper: EarlyStopper | None = None,
    device: torch.device = torch.device("cpu"),
) -> Tuple[float, float]:
    valid_loss = 0
    valid_accuracy = 0

    for epoch in range(epochs):
        model.train()
        for X, y in train_dl:
            X, y = X.to(device), y.to(device)

            # Compute prediction error
            pred = model(X)
            train_loss = loss_function(pred, y)

            # Backpropagation
            train_loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        model.eval()
        valid_loss = 0
        valid_accuracy = 0
        with torch.no_grad():
            for X, y in valid_dl:
                X, y = X.to(device), y.to(device)

                # Compute prediction error
                pred = model(X)
                valid_loss += loss_function(pred, y)

                # Compute accuracy
                valid_accuracy += (pred.argmax(1) == y).float().mean()

        valid_loss /= len(valid_dl)
        valid_accuracy /= len(valid_dl)

        print(
            f"Epoch #{epoch + 1}:\t validation loss: {valid_loss:.4f}\t validation accuracy: {valid_accuracy:.4f}"
        )

        if early_stopper is not None and early_stopper.early_stop(valid_loss):
            print("Early stopping")
            return (valid_loss, valid_accuracy)

    return (valid_loss, valid_accuracy)

In [6]:
# Define a function to test the model
def test(
    model: nn.Module,
    test_dl,
    loss_function: Callable,
    device: torch.device = torch.device("cpu"),
) -> Tuple[float, float]:
    size = len(test_dl.dataset)
    num_batches = len(test_dl)
    model.eval()

    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in test_dl:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_function(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    accuracy = (correct / size) * 100

    return (test_loss, accuracy)

## Training Phase

In [7]:
# define the constants
BATCH_SIZE: int = 32
LEARNING_RATE: float = 0.01
EPOCHS: int = 10
MOMENTUM: float = 0.9

In [8]:
# Get cpu, gpu or mps device for training.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using {torch.cuda.get_device_name(torch.cuda.current_device())}")

Using NVIDIA GeForce GTX 1660 Ti


In [9]:
base_model = LeNet().to(device)

# Define the loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
early_stopper = EarlyStopper(patience=3, min_delta=0)
optimizer = optim.SGD(base_model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

# create the data loaders
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
validation_loader = DataLoader(valid_ds, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE)

### Training loop

In [10]:
valid_loss, valid_accuracy = fit(
    base_model,
    train_dl=train_loader,
    valid_dl=validation_loader,
    optimizer=optimizer,
    loss_function=loss_fn,
    epochs=EPOCHS,
    device=device,
)

  return F.conv2d(input, weight, bias, self.stride,


Epoch #1:	 validation loss: 0.4725	 validation accuracy: 0.8167
Epoch #2:	 validation loss: 0.4276	 validation accuracy: 0.8394
Epoch #3:	 validation loss: 0.3471	 validation accuracy: 0.8722
Epoch #4:	 validation loss: 0.3434	 validation accuracy: 0.8725
Epoch #5:	 validation loss: 0.3186	 validation accuracy: 0.8820
Epoch #6:	 validation loss: 0.3093	 validation accuracy: 0.8837
Epoch #7:	 validation loss: 0.2925	 validation accuracy: 0.8913
Epoch #8:	 validation loss: 0.3004	 validation accuracy: 0.8886
Epoch #9:	 validation loss: 0.2960	 validation accuracy: 0.8902
Epoch #10:	 validation loss: 0.2934	 validation accuracy: 0.8943


In [11]:
test_loss, accuracy = test(
    base_model, test_dl=test_loader, loss_function=loss_fn, device=device
)
print(f"Test Error: \n Accuracy: {accuracy:>0.1f}%, Avg loss: {test_loss:>8f} \n")

Test Error: 
 Accuracy: 88.8%, Avg loss: 0.312858 



In [12]:
for i, module in enumerate(base_model.modules()):
    print(i, module)

0 LeNet(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (pool1): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (pool2): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (flatten1): Flatten(start_dim=1, end_dim=-1)
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)
1 Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
2 AvgPool2d(kernel_size=2, stride=2, padding=0)
3 Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
4 AvgPool2d(kernel_size=2, stride=2, padding=0)
5 Flatten(start_dim=1, end_dim=-1)
6 Linear(in_features=400, out_features=120, bias=True)
7 Linear(in_features=120, out_features=84, bias=True)
8 Linear(in_features=84, out_features=10, bias=True)


Save the model

In [13]:
torch.save(base_model.state_dict(), f"models/{type(base_model).__name__}_fmnist.pth")

## Pruning Phase

### One shot pruning

In [14]:
def get_model_sparsity(m: nn.Module) -> float:
    """Get the sparsity of the model

    Args:
        model (nn.Module): The model to get the sparsity of

    Returns:
        float: percentage of weights that are zero
    """

    total_weights = 0
    total_zero_weights = 0
    for _, module in m.named_children():
        for param_name, param in module.named_parameters():
            if "weight" in param_name:
                total_weights += float(param.nelement())
                total_zero_weights += float(torch.sum(param == 0))

    sparsity = 100.0 * total_zero_weights / total_weights
    return sparsity


def get_layers_sparsity(model: nn.Module) -> list[tuple[str, float]]:
    """Get the sparsity of each layer in the model

    Args:
        model (nn.Module): The model to get the sparsity of

    Returns:
        list[tuple[str, float]]: List of tuples containing the layer name and the sparsity
    """

    layers_sparsity = []
    for layer_name, module in model.named_children():
        for param_name, param in module.named_parameters():
            if "weight" in param_name:
                layer_sparsity = (
                    100.0 * float(torch.sum(param == 0)) / float(param.nelement())
                )
                layers_sparsity.append((layer_name, layer_sparsity))

    return layers_sparsity

In [15]:
def get_parameters_to_prune(model: nn.Module) -> list[nn.Parameter]:
    return [
        (module, "weight")
        for module in model.modules()
        if isinstance(module, nn.Conv2d | nn.Linear)
    ]

In [16]:
PRUNING_VALUES = [0.2, 0.4, 0.6, 0.8, 0.9, 0.95]
PRUNING_METHODS = {
    "RandomUnstructured": prune.RandomUnstructured,
    "L1Unstructured": prune.L1Unstructured,
}

In [17]:
@dataclass
class PruningResult:
    method: str
    pruning_rate: float
    val_accuracy: float
    val_loss: float

In [19]:
results = []
for pruning_rate in PRUNING_VALUES:
    for method_name, method in PRUNING_METHODS.items():
        # load the model
        temp_model = LeNet().to(device)
        temp_model.load_state_dict(torch.load("models/LeNet_fmnist.pth"))
        model_parameters = get_parameters_to_prune(temp_model)

        loss_fn = nn.CrossEntropyLoss()
        optimizer = optim.SGD(
            temp_model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM
        )

        # create the data loaders
        train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
        validation_loader = DataLoader(valid_ds, batch_size=BATCH_SIZE)
        test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE)

        print(f"Pre-pruning sparsity: {100 - get_model_sparsity(temp_model):.2f}%")

        # prune the model
        prune.global_unstructured(
            parameters=model_parameters,
            pruning_method=method,
            amount=pruning_rate,
        )

        print(f"Pruning rate: {pruning_rate}, method: {method_name}")

        val_loss, val_accuracy = fit(
            model=temp_model,
            train_dl=train_loader,
            valid_dl=validation_loader,
            optimizer=optimizer,
            loss_function=loss_fn,
            epochs=3,
            device=device,
        )

        results.append(
            PruningResult(
                method=method_name,
                pruning_rate=pruning_rate,
                val_accuracy=val_accuracy,
                val_loss=val_loss,
            )
        )

        for module, name in model_parameters:
            prune.remove(module, name)

        print(f"Post-pruning sparsity: {100 - get_model_sparsity(temp_model):.2f}%")
        torch.save(
            temp_model.state_dict(),
            f"models/{type(temp_model).__name__}_pruned_{pruning_rate}_{method_name}.pth",
        )

Pre-pruning sparsity: 100.00%
Pruning rate: 0.2, method: RandomUnstructured
Epoch #1:	 validation loss: 0.2953	 validation accuracy: 0.8905
Epoch #2:	 validation loss: 0.2848	 validation accuracy: 0.8933
Epoch #3:	 validation loss: 0.2930	 validation accuracy: 0.8933
Post-pruning sparsity: 80.00%
Pre-pruning sparsity: 100.00%
Pruning rate: 0.2, method: L1Unstructured
Epoch #1:	 validation loss: 0.3034	 validation accuracy: 0.8901
Epoch #2:	 validation loss: 0.2719	 validation accuracy: 0.8984
Epoch #3:	 validation loss: 0.2677	 validation accuracy: 0.9031
Post-pruning sparsity: 80.00%
Pre-pruning sparsity: 100.00%
Pruning rate: 0.4, method: RandomUnstructured
Epoch #1:	 validation loss: 0.3284	 validation accuracy: 0.8778
Epoch #2:	 validation loss: 0.3062	 validation accuracy: 0.8879
Epoch #3:	 validation loss: 0.2904	 validation accuracy: 0.8940
Post-pruning sparsity: 60.00%
Pre-pruning sparsity: 100.00%
Pruning rate: 0.4, method: L1Unstructured
Epoch #1:	 validation loss: 0.2796	 va

In [36]:
print("Method\t Pruning Rate\t Validation Loss\t Validation Accuracy")
for result in results:
    print(result)

Method	 Pruning Rate	 Validation Loss	 Validation Accuracy
PruningResult(method='RandomUnstructured', pruning_rate=0.2, val_accuracy=tensor(0.8933, device='cuda:0'), val_loss=0.29304863429069516)
PruningResult(method='L1Unstructured', pruning_rate=0.2, val_accuracy=tensor(0.9031, device='cuda:0'), val_loss=0.26767243194580076)
PruningResult(method='RandomUnstructured', pruning_rate=0.4, val_accuracy=tensor(0.8940, device='cuda:0'), val_loss=0.29040530423323313)
PruningResult(method='L1Unstructured', pruning_rate=0.4, val_accuracy=tensor(0.8997, device='cuda:0'), val_loss=0.274792817885677)
PruningResult(method='RandomUnstructured', pruning_rate=0.6, val_accuracy=tensor(0.8795, device='cuda:0'), val_loss=0.3278882878422737)
PruningResult(method='L1Unstructured', pruning_rate=0.6, val_accuracy=tensor(0.8947, device='cuda:0'), val_loss=0.28920095206797125)
PruningResult(method='RandomUnstructured', pruning_rate=0.8, val_accuracy=tensor(0.8488, device='cuda:0'), val_loss=0.4138503221273422

### Iterative pruning

In [None]:
RANGE: int = 20
ITER_PRUNING_AMOUNT: float = 0.01
for method_name, method in PRUNING_METHODS.items():
    print(
        f"Iterative pruning using {method_name} for {RANGE} iterations with amount {ITER_PRUNING_AMOUNT}"
    )

    iterative_model = LeNet().to(device)
    iterative_model.load_state_dict(torch.load("models/LeNet_fmnist.pth"))

    iterative_model_parameters = get_parameters_to_prune(iterative_model)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.SGD(temp_model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

    # create the data loaders
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    validation_loader = DataLoader(valid_ds, batch_size=BATCH_SIZE)
    test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE)

    for iteration in range(RANGE):
        prune.global_unstructured(
            parameters=iterative_model_parameters,
            pruning_method=method,
            amount=ITER_PRUNING_AMOUNT,
        )

        val_loss, val_accuracy = fit(
            model=iterative_model,
            train_dl=train_loader,
            valid_dl=validation_loader,
            optimizer=optimizer,
            loss_function=loss_fn,
            epochs=1,
            device=device,
        )

        print(
            f"Iteration #{iteration + 1}:\t validation loss: {val_loss:.4f}\t validation accuracy: {val_accuracy:.4f}"
        )

    for module, name in iterative_model_parameters:
        prune.remove(module, name)

    torch.save(
        iterative_model.state_dict(),
        f"models/{type(iterative_model).__name__}_iterative_pruned_0.{RANGE}_{method_name}.pth",
    )

Epoch #1:	 validation loss: 0.2889	 validation accuracy: 0.8903
Iteration #1:	 validation loss: 0.2889	 validation accuracy: 0.8903
Epoch #1:	 validation loss: 0.2889	 validation accuracy: 0.8902
Iteration #2:	 validation loss: 0.2889	 validation accuracy: 0.8902
Epoch #1:	 validation loss: 0.2889	 validation accuracy: 0.8903
Iteration #3:	 validation loss: 0.2889	 validation accuracy: 0.8903
Epoch #1:	 validation loss: 0.2890	 validation accuracy: 0.8903
Iteration #4:	 validation loss: 0.2890	 validation accuracy: 0.8903
Epoch #1:	 validation loss: 0.2890	 validation accuracy: 0.8903
Iteration #5:	 validation loss: 0.2890	 validation accuracy: 0.8903
Epoch #1:	 validation loss: 0.2889	 validation accuracy: 0.8904
Iteration #6:	 validation loss: 0.2889	 validation accuracy: 0.8904
Epoch #1:	 validation loss: 0.2890	 validation accuracy: 0.8905
Iteration #7:	 validation loss: 0.2890	 validation accuracy: 0.8905
Epoch #1:	 validation loss: 0.2889	 validation accuracy: 0.8906
Iteration #8

## Load and test the models

In [38]:
models = []
for file in Path("models").glob("*.pth"):
    model = LeNet().to(device)
    temp = torch.load(file)
    model.load_state_dict(temp)
    print(f"Loaded {file.stem}")
    models.append((file.stem, model))

Loaded LeNet_pruned_0.2_L1Unstructured
Loaded LeNet_pruned_0.8_RandomUnstructured
Loaded LeNet_pruned_0.2_RandomUnstructured
Loaded LeNet_fmnist
Loaded LeNet_pruned_0.8_L1Unstructured
Loaded LeNet_pruned_0.4_RandomUnstructured
Loaded LeNet_pruned_0.9_RandomUnstructured
Loaded LeNet_pruned_0.95_L1Unstructured
Loaded LeNet_pruned_0.4_L1Unstructured
Loaded LeNet_pruned_0.6_L1Unstructured
Loaded LeNet_pruned_0.9_L1Unstructured
Loaded LeNet_pruned_0.95_RandomUnstructured
Loaded LeNet_pruned_0.6_RandomUnstructured


In [39]:
# create the data loaders
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
validation_loader = DataLoader(valid_ds, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE)
loss_fn = nn.CrossEntropyLoss()

In [40]:
for name, model in sorted(models, key=lambda x: x[0]):
    test_loss, accuracy = test(
        model, test_dl=test_loader, loss_function=loss_fn, device=device
    )
    print(f"Model {name}")
    print(f"Test Error: \n Accuracy: {accuracy:>0.1f}%, Avg loss: {test_loss:>8f} \n")

Model LeNet_fmnist
Test Error: 
 Accuracy: 88.8%, Avg loss: 0.312858 

Model LeNet_pruned_0.2_L1Unstructured
Test Error: 
 Accuracy: 90.1%, Avg loss: 0.292792 

Model LeNet_pruned_0.2_RandomUnstructured
Test Error: 
 Accuracy: 88.9%, Avg loss: 0.312677 

Model LeNet_pruned_0.4_L1Unstructured
Test Error: 
 Accuracy: 89.8%, Avg loss: 0.295847 

Model LeNet_pruned_0.4_RandomUnstructured
Test Error: 
 Accuracy: 89.0%, Avg loss: 0.300837 

Model LeNet_pruned_0.6_L1Unstructured
Test Error: 
 Accuracy: 89.4%, Avg loss: 0.303628 

Model LeNet_pruned_0.6_RandomUnstructured
Test Error: 
 Accuracy: 87.5%, Avg loss: 0.346925 

Model LeNet_pruned_0.8_L1Unstructured
Test Error: 
 Accuracy: 89.7%, Avg loss: 0.290749 

Model LeNet_pruned_0.8_RandomUnstructured
Test Error: 
 Accuracy: 83.6%, Avg loss: 0.436412 

Model LeNet_pruned_0.95_L1Unstructured
Test Error: 
 Accuracy: 88.5%, Avg loss: 0.320566 

Model LeNet_pruned_0.95_RandomUnstructured
Test Error: 
 Accuracy: 78.6%, Avg loss: 0.581298 

Model L

### Print the weights of the models

In [None]:
for name, model in sorted(models, key=lambda x: x[0]):
    print(f"Model {name}")
    print("#" * 10)

    print("Layer Sparsity:")
    for layer_name, layer_sparsity in get_layers_sparsity(model):
        print(f"{layer_name}: {100 - layer_sparsity:.2f}%")
    print(f"Model Sparsity: {100 - get_model_sparsity(model)}%")

Model LeNet_fmnist
##########
Layer Sparsity:
conv1: 100.00%
conv2: 100.00%
fc1: 100.00%
fc2: 100.00%
fc3: 100.00%
Model Sparsity: 100.0%
Model LeNet_pruned_0.2
##########
Layer Sparsity:
conv1: 95.33%
conv2: 89.62%
fc1: 77.76%
fc2: 86.80%
fc3: 95.95%
Model Sparsity: 80.0%
Model LeNet_pruned_0.4
##########
Layer Sparsity:
conv1: 94.00%
conv2: 80.50%
fc1: 55.49%
fc2: 73.50%
fc3: 90.95%
Model Sparsity: 60.0%
Model LeNet_pruned_0.6
##########
Layer Sparsity:
conv1: 90.00%
conv2: 69.92%
fc1: 33.73%
fc2: 58.30%
fc3: 84.40%
Model Sparsity: 40.0%
Model LeNet_pruned_0.8
##########
Layer Sparsity:
conv1: 83.33%
conv2: 50.75%
fc1: 15.16%
fc2: 30.66%
fc3: 69.17%
Model Sparsity: 20.0%
Model LeNet_pruned_0.9
##########
Layer Sparsity:
conv1: 75.33%
conv2: 35.38%
fc1: 6.55%
fc2: 15.65%
fc3: 54.88%
Model Sparsity: 10.0%
Model LeNet_pruned_0.95
##########
Layer Sparsity:
conv1: 67.33%
conv2: 24.29%
fc1: 2.63%
fc2: 7.53%
fc3: 43.81%
Model Sparsity: 5.000813404912961%
Model LeNet_pruned_iterative_pruned