In [1]:
import torch.nn.utils.prune as prune

import torch

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from typing import Iterable
from pathlib import Path

import utility

In [2]:
MODELS_PATH: str = "models"


def save_model(model: nn.Module, name: str):
    torch.save(model.state_dict(), f"{MODELS_PATH}/{name}.pth")

### Load the data

In [3]:
# load FashionMNIST data
transform = transforms.Compose([transforms.ToTensor()])

# split into validation and train datasets
train_ds = datasets.FashionMNIST("data", train=True, transform=transform, download=True)
train_ds, valid_ds = random_split(train_ds, [0.8, 0.2])

test_ds = datasets.FashionMNIST("data", train=False, transform=transform, download=True)

## Define the model architecture

In [4]:
# Define a simple CNN model
class LeNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(
            1, 6, kernel_size=5, stride=1, padding=2
        )  # 28*28->32*32-->28*28
        self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1)
        self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.flatten1 = nn.Flatten()
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = self.flatten1(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

## Training Phase

In [5]:
# define the constants
BATCH_SIZE: int = 32
LEARNING_RATE: float = 0.01
EPOCHS: int = 50
MOMENTUM: float = 0.9

In [6]:
# Get cpu, gpu or mps device for training.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using {torch.cuda.get_device_name(torch.cuda.current_device())}")

Using NVIDIA GeForce GTX 1660 Ti


In [7]:
base_model = LeNet().to(device)

# Define the loss function and optimizer
cross_entropy = nn.CrossEntropyLoss()
early_stopper = utility.early_stopping.EarlyStopper(patience=3, min_delta=0)
optimizer = optim.SGD(base_model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

# create the data loaders
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
validation_loader = DataLoader(valid_ds, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE)

### Training loop

In [8]:
for epoch in range(EPOCHS):
    train_loss = utility.training.train_epoch(
        module=base_model,
        train_dl=train_loader,
        optimizer=optimizer,
        loss_function=cross_entropy,
        device=device,
    )

    valid_loss, valid_accuracy = utility.training.validate(
        module=base_model,
        valid_dl=validation_loader,
        loss_function=cross_entropy,
        device=device,
    )

    print(
        f"Epoch: {epoch:}\nTrain Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}, Valid Accuracy: {valid_accuracy:.4f}"
    )

    if early_stopper.early_stop(valid_loss):
        print("Early stopping")
        break

  return F.conv2d(input, weight, bias, self.stride,


Epoch: 0
Train Loss: 0.7884, Valid Loss: 0.4886, Valid Accuracy: 0.8151
Epoch: 1
Train Loss: 0.4554, Valid Loss: 0.4038, Valid Accuracy: 0.8444
Epoch: 2
Train Loss: 0.3905, Valid Loss: 0.3699, Valid Accuracy: 0.8601
Epoch: 3
Train Loss: 0.3568, Valid Loss: 0.3417, Valid Accuracy: 0.8772
Epoch: 4
Train Loss: 0.3247, Valid Loss: 0.3019, Valid Accuracy: 0.8880
Epoch: 5
Train Loss: 0.3057, Valid Loss: 0.2951, Valid Accuracy: 0.8907
Epoch: 6
Train Loss: 0.2894, Valid Loss: 0.2982, Valid Accuracy: 0.8910
Epoch: 7
Train Loss: 0.2729, Valid Loss: 0.2820, Valid Accuracy: 0.8919
Epoch: 8
Train Loss: 0.2611, Valid Loss: 0.2865, Valid Accuracy: 0.8935
Epoch: 9
Train Loss: 0.2524, Valid Loss: 0.2865, Valid Accuracy: 0.8964
Epoch: 10
Train Loss: 0.2433, Valid Loss: 0.2831, Valid Accuracy: 0.8967
Early stopping


In [9]:
test_loss, accuracy = utility.training.test(
    base_model, test_dl=test_loader, loss_function=cross_entropy, device=device
)
print(f"Test Error: \n Accuracy: {accuracy:>0.1f}%, Avg loss: {test_loss:>8f} \n")

Test Error: 
 Accuracy: 88.9%, Avg loss: 0.302720 



Save the model

In [10]:
save_model(base_model, f"{type(base_model).__name__}_fmnist")

## Pruning Phase

In [11]:
def get_parameters_to_prune(model: nn.Module) -> list[nn.Parameter]:
    return [
        (module, "weight")
        for module in model.modules()
        if isinstance(module, nn.Conv2d | nn.Linear)
    ]


def calculate_total_sparsity(
    module: nn.Module, parameters_to_prune: Iterable[tuple[nn.Module, str]]
) -> float:
    total_weights = 0
    total_zero_weights = 0

    pruned_parameters: set[tuple[nn.Module, str]] = set(parameters_to_prune)

    for _, module in module.named_children():
        for param_name, param in module.named_parameters():
            if (module, param_name) not in pruned_parameters:
                continue

            if "weight" in param_name:
                total_weights += float(param.nelement())
                total_zero_weights += float(torch.sum(param == 0))

    sparsity = 100.0 * total_zero_weights / total_weights
    return sparsity

### One shot pruning

In [12]:
PRUNING_VALUES = [0.2, 0.4, 0.6, 0.8, 0.9, 0.95]
PRUNING_METHODS = {
    "RandomUnstructured": prune.RandomUnstructured,
    "L1Unstructured": prune.L1Unstructured,
}

In [13]:
def calculate_parameters_amount(modules: Iterable[tuple[nn.Module, str]]) -> int:
    """Calculate the total amount of parameters in a list of modules.

    Args:
        modules (Iterable[tuple[nn.Module, str]]): List of modules and the parameter names.

    Returns:
        int: The total amount of parameters.
    """

    total_parameters = 0
    for module, parameter in modules:
        for param_name, param in module.named_parameters():
            if param_name == parameter:
                total_parameters += param.nelement()

    return total_parameters

In [14]:
results = []

for pruning_rate in PRUNING_VALUES:
    pruning_value = int(
        calculate_parameters_amount(get_parameters_to_prune(base_model)) * pruning_rate
    )

    for method_name, method in PRUNING_METHODS.items():
        # load the model
        temp_model = LeNet().to(device)
        temp_model.load_state_dict(torch.load("models/LeNet_fmnist.pth"))
        model_parameters = get_parameters_to_prune(temp_model)

        optimizer = optim.SGD(
            temp_model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM
        )
        early_stopper = utility.early_stopping.EarlyStopper(patience=3, min_delta=0)

        # prune the model
        prune.global_unstructured(
            parameters=model_parameters,
            pruning_method=method,
            amount=pruning_value,
        )

        print(f"Pruning rate: {pruning_rate}, method: {method_name}")

        valid_loss, valid_accuracy = utility.training.validate(
            module=temp_model,
            valid_dl=validation_loader,
            loss_function=cross_entropy,
            device=device,
        )

        print(
            f"After pruning:\nValid Loss: {valid_loss:.4f}, Valid Accuracy: {valid_accuracy:.4f}"
        )

        # retrain the model
        print("Retraining the model")
        for epoch in range(EPOCHS):
            utility.training.train_epoch(
                module=temp_model,
                train_dl=train_loader,
                optimizer=optimizer,
                loss_function=cross_entropy,
                device=device,
            )

            valid_loss, valid_accuracy = utility.training.validate(
                module=temp_model,
                valid_dl=validation_loader,
                loss_function=cross_entropy,
                device=device,
            )

            print(
                f"Epoch: {epoch:}\nTrain Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}, Valid Accuracy: {valid_accuracy:.4f}"
            )

            if early_stopper.early_stop(valid_loss):
                print("Early stopping")
                break

        for module, name in model_parameters:
            prune.remove(module, name)

        torch.save(
            temp_model.state_dict(),
            f"models/{type(temp_model).__name__}_pruned_{pruning_rate}_{method_name}.pth",
        )

Pruning rate: 0.2, method: RandomUnstructured
After pruning:
Valid Loss: 0.5672, Valid Accuracy: 0.7842
Retraining the model
Epoch: 0
Train Loss: 0.2433, Valid Loss: 0.2702, Valid Accuracy: 0.8997
Epoch: 1
Train Loss: 0.2433, Valid Loss: 0.2985, Valid Accuracy: 0.8941
Epoch: 2
Train Loss: 0.2433, Valid Loss: 0.2777, Valid Accuracy: 0.8969
Epoch: 3
Train Loss: 0.2433, Valid Loss: 0.2550, Valid Accuracy: 0.9051
Epoch: 4
Train Loss: 0.2433, Valid Loss: 0.2823, Valid Accuracy: 0.8952
Epoch: 5
Train Loss: 0.2433, Valid Loss: 0.2671, Valid Accuracy: 0.9010
Epoch: 6
Train Loss: 0.2433, Valid Loss: 0.2665, Valid Accuracy: 0.9049
Early stopping
Pruning rate: 0.2, method: L1Unstructured
After pruning:
Valid Loss: 0.2829, Valid Accuracy: 0.8968
Retraining the model
Epoch: 0
Train Loss: 0.2433, Valid Loss: 0.2832, Valid Accuracy: 0.8974
Epoch: 1
Train Loss: 0.2433, Valid Loss: 0.2741, Valid Accuracy: 0.9022
Epoch: 2
Train Loss: 0.2433, Valid Loss: 0.2821, Valid Accuracy: 0.8997
Epoch: 3
Train Loss

### Iterative pruning

In [15]:
RANGE: int = 20
ITER_PRUNING_RATE: float = 0.01

pruning_value = int(
    calculate_parameters_amount(get_parameters_to_prune(base_model)) * ITER_PRUNING_RATE
)

for method_name, method in PRUNING_METHODS.items():
    print(
        "-" * 20,
        f"Iterative pruning using {method_name} for {RANGE} iterations with amount {pruning_value}",
        "-" * 20,
        sep="\n",
    )

    iterative_model = LeNet().to(device)
    iterative_model.load_state_dict(torch.load("models/LeNet_fmnist.pth"))

    iterative_model_parameters = get_parameters_to_prune(iterative_model)
    optimizer = optim.SGD(temp_model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

    for iteration in range(RANGE):
        prune.global_unstructured(
            parameters=iterative_model_parameters,
            pruning_method=method,
            amount=pruning_value,
        )

        train_loss = utility.training.train_epoch(
            module=iterative_model,
            train_dl=train_loader,
            optimizer=optimizer,
            loss_function=cross_entropy,
            device=device,
        )

        val_loss, val_accuracy = utility.training.validate(
            module=iterative_model,
            valid_dl=validation_loader,
            loss_function=cross_entropy,
            device=device,
        )

        print(
            f"Iteration #{iteration + 1}:\t validation loss: {val_loss:.4f}\t validation accuracy: {val_accuracy:.4f}"
        )

    for module, name in iterative_model_parameters:
        prune.remove(module, name)

    save_model(
        iterative_model,
        f"{type(iterative_model).__name__}_iterative_pruned_0.{RANGE}_{method_name}",
    )

--------------------
Iterative pruning using RandomUnstructured for 20 iterations with amount 614
--------------------
Iteration #1:	 validation loss: 0.2971	 validation accuracy: 0.8908
Iteration #2:	 validation loss: 0.4532	 validation accuracy: 0.8253
Iteration #3:	 validation loss: 0.3920	 validation accuracy: 0.8545
Iteration #4:	 validation loss: 0.4299	 validation accuracy: 0.8335
Iteration #5:	 validation loss: 0.4612	 validation accuracy: 0.8205
Iteration #6:	 validation loss: 0.4395	 validation accuracy: 0.8308
Iteration #7:	 validation loss: 0.4811	 validation accuracy: 0.8157
Iteration #8:	 validation loss: 0.4731	 validation accuracy: 0.8187
Iteration #9:	 validation loss: 0.4953	 validation accuracy: 0.8127
Iteration #10:	 validation loss: 0.5172	 validation accuracy: 0.8096
Iteration #11:	 validation loss: 0.5368	 validation accuracy: 0.8055
Iteration #12:	 validation loss: 0.5637	 validation accuracy: 0.7971
Iteration #13:	 validation loss: 0.5623	 validation accuracy: 

## Load and test the models

In [16]:
models = []
for file in Path("models").glob("*.pth"):
    model = LeNet().to(device)
    temp = torch.load(file)
    model.load_state_dict(temp)
    print(f"Loaded {file.stem}")
    models.append((file.stem, model))

Loaded LeNet_pruned_0.2_L1Unstructured
Loaded LeNet_pruned_0.8_RandomUnstructured
Loaded LeNet_pruned_0.2_RandomUnstructured
Loaded LeNet_iterative_pruned_0.20_RandomUnstructured
Loaded LeNet_fmnist
Loaded LeNet_pruned_0.8_L1Unstructured
Loaded LeNet_pruned_0.4_RandomUnstructured
Loaded LeNet_pruned_0.9_RandomUnstructured
Loaded LeNet_iterative_pruned_0.20_L1Unstructured
Loaded LeNet_pruned_0.95_L1Unstructured
Loaded LeNet_pruned_0.4_L1Unstructured
Loaded LeNet_pruned_0.6_L1Unstructured
Loaded LeNet_pruned_0.9_L1Unstructured
Loaded LeNet_pruned_0.95_RandomUnstructured
Loaded LeNet_pruned_0.6_RandomUnstructured


In [17]:
results = []
for name, model in sorted(models, key=lambda x: x[0]):
    test_loss, accuracy = utility.training.test(
        model=model, test_dl=test_loader, loss_function=cross_entropy, device=device
    )
    results.append((name, accuracy))

In [18]:
for name, accuracy in results:
    print(f"Model {name:50s} accuracy: {accuracy:.2f}%")

Model LeNet_fmnist                                       accuracy: 88.95%
Model LeNet_iterative_pruned_0.20_L1Unstructured         accuracy: 88.98%
Model LeNet_iterative_pruned_0.20_RandomUnstructured     accuracy: 65.58%
Model LeNet_pruned_0.2_L1Unstructured                    accuracy: 90.04%
Model LeNet_pruned_0.2_RandomUnstructured                accuracy: 89.98%
Model LeNet_pruned_0.4_L1Unstructured                    accuracy: 89.90%
Model LeNet_pruned_0.4_RandomUnstructured                accuracy: 90.04%
Model LeNet_pruned_0.6_L1Unstructured                    accuracy: 89.79%
Model LeNet_pruned_0.6_RandomUnstructured                accuracy: 89.64%
Model LeNet_pruned_0.8_L1Unstructured                    accuracy: 89.15%
Model LeNet_pruned_0.8_RandomUnstructured                accuracy: 87.65%
Model LeNet_pruned_0.95_L1Unstructured                   accuracy: 88.74%
Model LeNet_pruned_0.95_RandomUnstructured               accuracy: 77.63%
Model LeNet_pruned_0.9_L1Unstructured 

Print the model sparsity

In [19]:
for name, model in models:
    print(f"Calculating sparsity for {name}")
    print(
        f"Total sparsity: {100 - calculate_total_sparsity(model, get_parameters_to_prune(model)):.2f}%"
    )
    print("-" * 20)

Calculating sparsity for LeNet_pruned_0.2_L1Unstructured
Total sparsity: 80.00%
--------------------
Calculating sparsity for LeNet_pruned_0.8_RandomUnstructured
Total sparsity: 20.00%
--------------------
Calculating sparsity for LeNet_pruned_0.2_RandomUnstructured
Total sparsity: 80.00%
--------------------
Calculating sparsity for LeNet_iterative_pruned_0.20_RandomUnstructured
Total sparsity: 80.02%
--------------------
Calculating sparsity for LeNet_fmnist
Total sparsity: 100.00%
--------------------
Calculating sparsity for LeNet_pruned_0.8_L1Unstructured
Total sparsity: 20.00%
--------------------
Calculating sparsity for LeNet_pruned_0.4_RandomUnstructured
Total sparsity: 60.00%
--------------------
Calculating sparsity for LeNet_pruned_0.9_RandomUnstructured
Total sparsity: 10.00%
--------------------
Calculating sparsity for LeNet_iterative_pruned_0.20_L1Unstructured
Total sparsity: 80.02%
--------------------
Calculating sparsity for LeNet_pruned_0.95_L1Unstructured
Total spa