In [1]:
import torch.nn.utils.prune as prune

import torch

import torch.nn as nn
import torch.optim as optim

from pathlib import Path
from constants import DATA_PATH, MODELS_PATH
from LeNet import LeNet, BATCH_SIZE
from pruning_metadata import PruningMetadata
from utility.pruning import (
    calculate_parameters_amount,
    get_parameters_to_prune,
)
from utility.cifar_dataset import get_dataloaders
from dataclasses import asdict

import utility
import itertools

### Load the data

In [2]:
# Build the train / validation / test loaders in a single call; the batch size
# is the one exported by the LeNet module.
(train_loader, validation_loader, test_loader) = get_dataloaders(
    batch_size=BATCH_SIZE,
    data_path=DATA_PATH,
)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Select the training device: CUDA when available, otherwise the CPU.
# (MPS is not probed here.)
if torch.cuda.is_available():
    device = torch.device("cuda")
    # Query the GPU name only when CUDA is usable; asking for it on a
    # CPU-only machine raises instead of falling back.
    print(f"Using {torch.cuda.get_device_name(torch.cuda.current_device())}")
else:
    device = torch.device("cpu")
    print("Using CPU")

Using NVIDIA GeForce GTX 1660 Ti


In [4]:
# Reference LeNet instance placed on the selected device.
base_model = LeNet().to(device)

# Training utilities: classification loss, early-stopping tracker
# (patience of 3, no minimum delta) and an AdamW optimizer with default
# hyper-parameters over the reference model's parameters.
cross_entropy = nn.CrossEntropyLoss()
early_stopper = utility.early_stopping.EarlyStopper(min_delta=0, patience=3)
optimizer = optim.AdamW(params=base_model.parameters())

## Pruning Phase

### One-shot pruning

In [5]:
# Fractions of the prunable parameters to remove.
PRUNING_VALUES = [0.2, 0.4, 0.6]

# Available pruning methods, keyed by their class name.
PRUNING_NAME_TO_METHOD = {
    pruning_cls.__name__: pruning_cls
    for pruning_cls in (prune.RandomUnstructured, prune.L1Unstructured)
}

# Divisors applied to the pruning percentage when deriving the epoch budget.
ITER_STEPS = [2, 4]

# Fine-tuning epochs granted per derived step.
FINETUNE_EPOCHS = [1]

In [6]:
# Build one one-shot configuration per (pruning rate, iteration step,
# fine-tune epochs) combination. The whole pruning rate is applied in a single
# step, and the epoch budget is (pruning percent // iteration step) *
# fine-tune epochs.
one_shot_configs = []

for pruning_rate, iter_step, finetune_epochs in itertools.product(
    PRUNING_VALUES, ITER_STEPS, FINETUNE_EPOCHS
):
    # round() rather than int(): int() truncates, so a rate such as 0.29
    # (0.29 * 100 == 28.999999999999996) would be treated as 28 percent.
    pruning_percent = round(pruning_rate * 100)
    epoch_budget = (pruning_percent // iter_step) * finetune_epochs

    one_shot_configs.append(
        PruningMetadata(
            total_pruned=pruning_rate,
            pruning_step=pruning_rate,
            finetune_epochs=epoch_budget,
            total_epochs=epoch_budget,
            method=prune.L1Unstructured,
            early_stopping=False,
        )
    )

In [7]:
# One-shot pruning experiment: for every configuration, reload the trained
# LeNet checkpoint, prune it globally in one step, fine-tune it for the
# configured number of epochs, make the pruning permanent, and save the model
# together with its metadata.
results = []

# Loop-invariant values, computed once: base_model is not modified inside the
# loop, so its prunable-parameter count is the same for every configuration.
prunable_parameter_count = calculate_parameters_amount(
    get_parameters_to_prune(base_model)
)
checkpoint_path = MODELS_PATH / Path("LeNet_cifar10/LeNet_cifar10.pth")

for config in one_shot_configs:
    pruning_rate = config.pruning_step
    method = config.method
    fine_tune_epochs = config.finetune_epochs
    # NOTE(review): config.early_stopping is not consulted by this loop.

    # Absolute number of parameters to prune (global_unstructured accepts an
    # integer amount as a parameter count).
    pruning_value = int(round(prunable_parameter_count * pruning_rate))

    # Load a fresh copy of the trained model. map_location keeps the load
    # working when the checkpoint was saved on a device that is not available
    # on this machine (e.g. a CUDA checkpoint on a CPU-only host).
    temp_model = LeNet().to(device)
    temp_model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model_name = temp_model.__class__.__name__
    model_parameters = get_parameters_to_prune(temp_model)

    optimizer = optim.AdamW(temp_model.parameters())

    # Prune the model.
    prune.global_unstructured(
        parameters=model_parameters,
        pruning_method=method,
        amount=pruning_value,
    )

    print(
        f"Pruning rate: {pruning_rate}, method: {method.__name__}, fine tune epochs: {fine_tune_epochs}"
    )

    valid_loss, valid_accuracy = utility.training.validate(
        module=temp_model,
        valid_dl=validation_loader,
        loss_function=cross_entropy,
        device=device,
    )

    print(
        f"After pruning:\nValid Loss: {valid_loss:.4f}, Valid Accuracy: {valid_accuracy:.4f}"
    )

    # Fine-tune the pruned model.
    print("Retraining the model")
    for epoch in range(fine_tune_epochs):
        train_loss = utility.training.train_epoch(
            module=temp_model,
            train_dl=train_loader,
            optimizer=optimizer,
            loss_function=cross_entropy,
            device=device,
        )

        valid_loss, valid_accuracy = utility.training.validate(
            module=temp_model,
            valid_dl=validation_loader,
            loss_function=cross_entropy,
            device=device,
        )

        print(
            f"Epoch: {epoch}\nTrain Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}, Valid Accuracy: {valid_accuracy:.4f}"
        )

    # Make the pruning permanent by removing the re-parametrization.
    for module, name in model_parameters:
        prune.remove(module, name)

    # Store the method by name in the saved metadata WITHOUT mutating the
    # shared config object: overwriting config.method with a string would make
    # this cell fail when re-run, because `method` would no longer be a
    # pruning class.
    metadata = asdict(config)
    metadata["method"] = method.__name__

    saved_model_name = f"{model_name}_pruned_{pruning_rate}_{method.__name__}"
    utility.save.save_model_with_metadata(
        model=temp_model,
        path=f"{MODELS_PATH}/{saved_model_name}",
        model_name=saved_model_name,
        metadata=metadata,
    )

Pruning rate: 0.2, method: L1Unstructured, fine tune epochs: 10
After pruning:
Valid Loss: 0.6920, Valid Accuracy: 0.7610
Retraining the model
Epoch: 0
Train Loss: 0.6923, Valid Loss: 0.7047, Valid Accuracy: 0.7571
Epoch: 1
Train Loss: 0.6493, Valid Loss: 0.7548, Valid Accuracy: 0.7405
Epoch: 2
Train Loss: 0.6217, Valid Loss: 0.7680, Valid Accuracy: 0.7344
Epoch: 3
Train Loss: 0.5941, Valid Loss: 0.7729, Valid Accuracy: 0.7279
Epoch: 4
Train Loss: 0.5775, Valid Loss: 0.7889, Valid Accuracy: 0.7307
Epoch: 5
Train Loss: 0.5510, Valid Loss: 0.8091, Valid Accuracy: 0.7296
Epoch: 6
Train Loss: 0.5372, Valid Loss: 0.8556, Valid Accuracy: 0.7151
Epoch: 7
Train Loss: 0.5119, Valid Loss: 0.8750, Valid Accuracy: 0.7034
Epoch: 8
Train Loss: 0.4956, Valid Loss: 0.8900, Valid Accuracy: 0.7032
Epoch: 9
Train Loss: 0.4698, Valid Loss: 0.9172, Valid Accuracy: 0.7101
Pruning rate: 0.2, method: L1Unstructured, fine tune epochs: 5
After pruning:
Valid Loss: 0.6920, Valid Accuracy: 0.7610
Retraining the m