In [1]:
import torch.nn.utils.prune as prune

import torch

import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from typing import Iterable
from pathlib import Path
from constants import DATA_PATH, MODELS_PATH
from LeNet import LeNet, BATCH_SIZE, LEARNING_RATE, EPOCHS, MOMENTUM
from pruning_metadata import PruningMetadata

from dataclasses import asdict

import utility
import itertools
import copy

### Load the data

In [2]:
# load FashionMNIST data
transform = transforms.Compose([transforms.ToTensor()])

# split into validation and train datasets
train_ds = datasets.FashionMNIST(
    DATA_PATH, train=True, transform=transform, download=True
)
train_ds, valid_ds = random_split(train_ds, [0.8, 0.2])

test_ds = datasets.FashionMNIST(
    DATA_PATH, train=False, transform=transform, download=True
)

In [3]:
# Get cpu, gpu or mps device for training.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using {torch.cuda.get_device_name(torch.cuda.current_device())}")

Using NVIDIA GeForce GTX 1660 Ti


In [4]:
base_model = LeNet().to(device)

# Define the loss function and optimizer
cross_entropy = nn.CrossEntropyLoss()
early_stopper = utility.early_stopping.EarlyStopper(patience=3, min_delta=0)
optimizer = optim.SGD(base_model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

# create the data loaders
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
validation_loader = DataLoader(valid_ds, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE)

## Pruning Phase

In [5]:
def get_parameters_to_prune(model: nn.Module) -> list[nn.Parameter]:
    return [
        (module, "weight")
        for module in model.modules()
        if isinstance(module, nn.Conv2d | nn.Linear)
    ]


def calculate_total_sparsity(
    module: nn.Module, parameters_to_prune: Iterable[tuple[nn.Module, str]]
) -> float:
    total_weights = 0
    total_zero_weights = 0

    pruned_parameters: set[tuple[nn.Module, str]] = set(parameters_to_prune)

    for _, module in module.named_children():
        for param_name, param in module.named_parameters():
            if (module, param_name) not in pruned_parameters:
                continue

            if "weight" in param_name:
                total_weights += float(param.nelement())
                total_zero_weights += float(torch.sum(param == 0))

    sparsity = 100.0 * total_zero_weights / total_weights
    return sparsity

In [6]:
def calculate_parameters_amount(modules: Iterable[tuple[nn.Module, str]]) -> int:
    """Calculate the total amount of parameters in a list of modules.

    Args:
        modules (Iterable[tuple[nn.Module, str]]): List of modules and the parameter names.

    Returns:
        int: The total amount of parameters.
    """

    total_parameters = 0
    for module, parameter in modules:
        for param_name, param in module.named_parameters():
            if param_name == parameter:
                total_parameters += param.nelement()

    return total_parameters

### One shot pruning

In [7]:
PRUNING_VALUES = [0.2, 0.4, 0.6, 0.8, 0.88, 0.92, 0.96]
PRUNING_NAME_TO_METHOD = {
    "RandomUnstructured": prune.RandomUnstructured,
    "L1Unstructured": prune.L1Unstructured,
}

In [8]:
one_shot_configs = []

for pruning_rate in PRUNING_VALUES:
    one_shot_configs.append(
        PruningMetadata(
            total_pruned=pruning_rate,
            pruning_step=pruning_rate,
            finetune_epochs=None,  # to be filled in later
            method=prune.L1Unstructured,
            early_stopping=True,
        )
    )

In [9]:
results = []

for config in one_shot_configs:
    pruning_rate = config.pruning_step
    method = config.method
    early_stopping = config.early_stopping
    fine_tune_epochs = config.finetune_epochs

    pruning_value = int(
        round(
            calculate_parameters_amount(get_parameters_to_prune(base_model))
            * pruning_rate
        )
    )

    # load the model
    temp_model = LeNet().to(device)
    temp_model.load_state_dict(
        torch.load(MODELS_PATH / Path("LeNet_fmnist/LeNet_fmnist.pth"))
    )
    model_name = temp_model.__class__.__name__
    model_parameters = get_parameters_to_prune(temp_model)

    optimizer = optim.SGD(temp_model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
    early_stopper = utility.early_stopping.EarlyStopper(patience=3, min_delta=0)

    # prune the model
    prune.global_unstructured(
        parameters=model_parameters,
        pruning_method=method,
        amount=pruning_value,
    )

    print(f"Pruning rate: {pruning_rate}, method: {method.__name__}")

    valid_loss, valid_accuracy = utility.training.validate(
        module=temp_model,
        valid_dl=validation_loader,
        loss_function=cross_entropy,
        device=device,
    )

    print(
        f"After pruning:\nValid Loss: {valid_loss:.4f}, Valid Accuracy: {valid_accuracy:.4f}"
    )

    # retrain the model
    print("Retraining the model")
    for epoch in range(EPOCHS):
        train_loss = utility.training.train_epoch(
            module=temp_model,
            train_dl=train_loader,
            optimizer=optimizer,
            loss_function=cross_entropy,
            device=device,
        )

        valid_loss, valid_accuracy = utility.training.validate(
            module=temp_model,
            valid_dl=validation_loader,
            loss_function=cross_entropy,
            device=device,
        )

        print(
            f"Epoch: {epoch:}\nTrain Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}, Valid Accuracy: {valid_accuracy:.4f}"
        )

        if early_stopping and early_stopper.early_stop(valid_loss):
            last_epoch = epoch
            print("Early stopping")
            break

    for module, name in model_parameters:
        prune.remove(module, name)

    config.finetune_epochs = last_epoch + 1
    config.method = method.__name__
    utility.save.save_model_with_metadata(
        model=temp_model,
        path=f"{MODELS_PATH}/{model_name}_pruned_{pruning_rate}_{method.__name__}",
        model_name=f"{model_name}_pruned_{pruning_rate}_{method.__name__}",
        metadata=asdict(config),
    )

Pruning rate: 0.2, method: L1Unstructured


After pruning:
Valid Loss: 0.2376, Valid Accuracy: 0.9088
Retraining the model
Epoch: 0
Train Loss: 0.2286, Valid Loss: 0.2254, Valid Accuracy: 0.9148
Epoch: 1
Train Loss: 0.2167, Valid Loss: 0.2390, Valid Accuracy: 0.9118
Epoch: 2
Train Loss: 0.2098, Valid Loss: 0.2336, Valid Accuracy: 0.9118
Epoch: 3
Train Loss: 0.2016, Valid Loss: 0.2337, Valid Accuracy: 0.9145
Early stopping
Pruning rate: 0.4, method: L1Unstructured
After pruning:
Valid Loss: 0.2384, Valid Accuracy: 0.9078
Retraining the model
Epoch: 0
Train Loss: 0.2241, Valid Loss: 0.2423, Valid Accuracy: 0.9085
Epoch: 1
Train Loss: 0.2128, Valid Loss: 0.2403, Valid Accuracy: 0.9087
Epoch: 2
Train Loss: 0.2040, Valid Loss: 0.2318, Valid Accuracy: 0.9133
Epoch: 3
Train Loss: 0.2002, Valid Loss: 0.2331, Valid Accuracy: 0.9128
Epoch: 4
Train Loss: 0.1945, Valid Loss: 0.2409, Valid Accuracy: 0.9099
Epoch: 5
Train Loss: 0.1868, Valid Loss: 0.2295, Valid Accuracy: 0.9151
Epoch: 6
Train Loss: 0.1822, Valid Loss: 0.2550, Valid Accuracy: 

### Iterative pruning

In [10]:
ITER_PRUNING_RATES = [0.01, 0.02, 0.04]
RETRAIN_EPOCHS = [1, 2, 3, 4]
max_pruning_percent = max(PRUNING_VALUES)

Create iterative pruning configs 

In [11]:
iterative_pruning_configs = []

for pruning_step, finetune_epochs in itertools.product(
    ITER_PRUNING_RATES, RETRAIN_EPOCHS
):
    iterative_pruning_configs.append(
        PruningMetadata(
            total_pruned=max_pruning_percent,
            pruning_step=pruning_step,
            finetune_epochs=finetune_epochs,
            method=prune.L1Unstructured,
            early_stopping=False,
        )
    )

In [13]:
for config in iterative_pruning_configs:
    total_pruned = config.total_pruned
    pruning_step = config.pruning_step
    finetune_epochs = config.finetune_epochs
    method = config.method
    method_name = method.__name__

    pruning_value = int(
        round(
            calculate_parameters_amount(get_parameters_to_prune(base_model))
            * pruning_step
        )
    )

    print(
        f"Target total pruning: {total_pruned}, pruning_step: {pruning_step}, method: {method_name}, finetune_epochs: {finetune_epochs}"
    )

    iterative_model = LeNet().to(device)
    iterative_model.load_state_dict(
        torch.load(MODELS_PATH / Path("LeNet_fmnist/LeNet_fmnist.pth"))
    )
    model_name = iterative_model.__class__.__name__

    iterative_model_parameters = get_parameters_to_prune(iterative_model)
    optimizer = optim.SGD(temp_model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

    target_sparsity = int(total_pruned * 100)
    pruning_step_percent = int(pruning_step * 100)
    for pruned in range(
        pruning_step_percent,
        target_sparsity + pruning_step_percent,
        pruning_step_percent,
    ):
        prune.global_unstructured(
            parameters=iterative_model_parameters,
            pruning_method=method,
            amount=pruning_value,
        )

        print("Fine tuning the model")
        for epoch in range(finetune_epochs):
            train_loss = utility.training.train_epoch(
                module=iterative_model,
                train_dl=train_loader,
                optimizer=optimizer,
                loss_function=cross_entropy,
                device=device,
            )

            val_loss, val_accuracy = utility.training.validate(
                module=iterative_model,
                valid_dl=validation_loader,
                loss_function=cross_entropy,
                device=device,
            )

            print(
                f"Pruned {pruned} / {target_sparsity}% | Epoch #{epoch}\tvalidation loss: {val_loss:.4f}\t validation accuracy: {val_accuracy:.4f}"
            )

        if pruned in map(lambda x: int(x * 100), PRUNING_VALUES):
            tmp = copy.deepcopy(iterative_model)
            for module, name in get_parameters_to_prune(tmp):
                prune.remove(module, name)

            tmp_config = copy.deepcopy(config)
            tmp_config.total_pruned = pruned / 100
            tmp_config.method = method.__name__
            utility.save.save_model_with_metadata(
                model=tmp,
                path=f"{model_name}_iterative_pruned_{pruned / 100}_step_{pruning_step}_{method_name}_epochs_{finetune_epochs}",
                model_name=f"{model_name}_iterative_pruned_{pruned / 100}_step_{pruning_step}_{method_name}_epochs_{finetune_epochs}",
                metadata=asdict(tmp_config),
            )

    for module, name in iterative_model_parameters:
        prune.remove(module, name)

    config.method = method.__name__
    utility.save.save_model_with_metadata(
        model=iterative_model,
        path=f"{MODELS_PATH}/{model_name}_iterative_pruned_{total_pruned}_step_{pruning_step}_{method_name}_epochs_{finetune_epochs}",
        model_name=f"{model_name}_iterative_pruned_{total_pruned}_step_{pruning_step}_{method_name}_epochs_{finetune_epochs}",
        metadata=asdict(config),
    )

Target total pruning: 0.96, pruning_step: 0.01, method: L1Unstructured, finetune_epochs: 1
Fine tuning the model
Pruned 1 / 96% | Epoch #0	validation loss: 0.2378	 validation accuracy: 0.9091
Fine tuning the model
Pruned 2 / 96% | Epoch #0	validation loss: 0.2378	 validation accuracy: 0.9089
Fine tuning the model
Pruned 3 / 96% | Epoch #0	validation loss: 0.2378	 validation accuracy: 0.9089
Fine tuning the model
Pruned 4 / 96% | Epoch #0	validation loss: 0.2378	 validation accuracy: 0.9088
Fine tuning the model
Pruned 5 / 96% | Epoch #0	validation loss: 0.2378	 validation accuracy: 0.9087
Fine tuning the model
Pruned 6 / 96% | Epoch #0	validation loss: 0.2379	 validation accuracy: 0.9087
Fine tuning the model
Pruned 7 / 96% | Epoch #0	validation loss: 0.2378	 validation accuracy: 0.9086
Fine tuning the model
Pruned 8 / 96% | Epoch #0	validation loss: 0.2377	 validation accuracy: 0.9087
Fine tuning the model
Pruned 9 / 96% | Epoch #0	validation loss: 0.2379	 validation accuracy: 0.9086
