In [4]:
import torch.nn.utils.prune as prune

import torch

import torch.nn as nn
import torch.optim as optim

from pathlib import Path
from constants import DATA_PATH, MODELS_PATH
from LeNet import LeNet, BATCH_SIZE
from pruning_metadata import PruningMetadata
from utility.pruning import (
    calculate_parameters_amount,
    get_parameters_to_prune,
)
from utility.cifar_dataset import get_dataloaders
from dataclasses import asdict

import utility
import itertools

### Load the data

In [5]:
# Build the train, validation and test data loaders in a single call;
# get_dataloaders returns them in exactly this order.
(train_loader, validation_loader, test_loader) = get_dataloaders(
    batch_size=BATCH_SIZE,
    data_path=DATA_PATH,
)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
# Select the training device: the current CUDA GPU when one is available,
# otherwise the CPU. The CUDA name lookup is only performed on the CUDA
# branch because torch.cuda.current_device() raises on CPU-only machines.
if torch.cuda.is_available():
    device = torch.device("cuda")
    device_label = torch.cuda.get_device_name(torch.cuda.current_device())
else:
    device = torch.device("cpu")
    device_label = "CPU"
print(f"Using {device_label}")

Using NVIDIA GeForce GTX 1660 Ti


In [7]:
# Freshly constructed (no weights loaded here) LeNet placed on the selected device.
base_model = LeNet().to(device)

# Training helpers: classification loss, an early-stopping tracker
# (patience of 3, zero minimum delta) and an AdamW optimizer bound to the
# base model's parameters with default hyper-parameters.
cross_entropy = nn.CrossEntropyLoss()
early_stopper = utility.early_stopping.EarlyStopper(
    patience=3,
    min_delta=0,
)
optimizer = optim.AdamW(params=base_model.parameters())

## Pruning Phase

### One-shot pruning

In [8]:
# Fractions of the prunable parameters to remove (20 %, 40 %, 60 %).
PRUNING_VALUES = [0.20, 0.40, 0.60]

# Lookup from a pruning method's class name to the torch pruning class itself.
PRUNING_NAME_TO_METHOD = {
    pruning_cls.__name__: pruning_cls
    for pruning_cls in (prune.RandomUnstructured, prune.L1Unstructured)
}

# Step sizes and per-step fine-tune epoch counts combined when building
# the experiment configurations.
ITER_STEPS = [1, 2, 4]
FINETUNE_EPOCHS = [1, 2]

In [12]:
# Build one PruningMetadata per (pruning rate, iteration step, fine-tune
# epochs) combination. All configs use L1 unstructured pruning with early
# stopping disabled.
one_shot_configs = []

for pruning_rate, iter_step, finetune_epochs in itertools.product(
    PRUNING_VALUES, ITER_STEPS, FINETUNE_EPOCHS
):
    # round() rather than int(): rate * 100 can land just below a whole
    # number (e.g. 0.29 * 100 == 28.999999999999996), which int() would
    # truncate to the wrong percentage.
    pruned_percent = round(pruning_rate * 100)

    # Epoch budget for this run: (percent pruned // step size) * epochs per step.
    # NOTE(review): presumably this mirrors the budget of an iterative schedule
    # that prunes `iter_step` percent at a time - confirm.
    # Different (iter_step, finetune_epochs) pairs can yield the same budget
    # (e.g. 1 and 1 vs. 2 and 2 at 20 %), so the list may contain identical
    # configs.
    epoch_budget = (pruned_percent // iter_step) * finetune_epochs

    one_shot_configs.append(
        PruningMetadata(
            total_pruned=pruned_percent,
            pruning_step=pruning_rate,
            finetune_epochs=epoch_budget,
            total_epochs=epoch_budget,
            method=prune.L1Unstructured,
            early_stopping=False,
        )
    )

In [14]:
# One-shot pruning experiment: for every config, reload the trained LeNet
# checkpoint, prune it globally in a single step, fine-tune for the
# configured number of epochs, make the pruning permanent and save the
# model together with its metadata.
results = []

for config in one_shot_configs:
    pruning_rate = config.pruning_step
    method = config.method
    # NOTE: read from the config but never consulted below; this loop always
    # runs the full number of fine-tune epochs.
    early_stopping = config.early_stopping
    fine_tune_epochs = config.finetune_epochs

    # Load the checkpoint into a fresh model so runs do not affect each other.
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    temp_model = LeNet().to(device)
    temp_model.load_state_dict(
        torch.load(
            MODELS_PATH / Path("LeNet_cifar10/LeNet_cifar10.pth"),
            map_location=device,
        )
    )
    model_name = temp_model.__class__.__name__
    model_parameters = get_parameters_to_prune(temp_model)

    # Absolute number of parameters to prune, derived from the model that is
    # actually being pruned. Must be computed before pruning is applied.
    pruning_value = int(
        round(calculate_parameters_amount(model_parameters) * pruning_rate)
    )

    optimizer = optim.AdamW(temp_model.parameters())

    # prune the model
    prune.global_unstructured(
        parameters=model_parameters,
        pruning_method=method,
        amount=pruning_value,
    )

    print(
        f"Pruning rate: {pruning_rate}, method: {method.__name__}, fine tune epochs: {fine_tune_epochs}"
    )

    # Validation metrics right after pruning, before any fine-tuning.
    valid_loss, valid_accuracy = utility.training.validate(
        module=temp_model,
        valid_dl=validation_loader,
        loss_function=cross_entropy,
        device=device,
    )

    print(
        f"After pruning:\nValid Loss: {valid_loss:.4f}, Valid Accuracy: {valid_accuracy:.4f}"
    )

    # retrain the model
    print("Retraining the model")
    for epoch in range(fine_tune_epochs):
        train_loss = utility.training.train_epoch(
            module=temp_model,
            train_dl=train_loader,
            optimizer=optimizer,
            loss_function=cross_entropy,
            device=device,
        )

        valid_loss, valid_accuracy = utility.training.validate(
            module=temp_model,
            valid_dl=validation_loader,
            loss_function=cross_entropy,
            device=device,
        )

        print(
            f"Epoch: {epoch:}\nTrain Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}, Valid Accuracy: {valid_accuracy:.4f}"
        )

    # Make the pruning permanent by removing the re-parametrization from
    # every pruned (module, parameter name) pair.
    for module, name in model_parameters:
        prune.remove(module, name)

    # Store the method by name in the saved metadata WITHOUT mutating the
    # shared config object: overwriting config.method with a string would make
    # a second run of this cell fail on `method.__name__`.
    metadata = asdict(config)
    metadata["method"] = method.__name__

    # NOTE(review): the name does not include fine_tune_epochs, so configs that
    # share a pruning rate and method resolve to the same path - confirm
    # whether save_model_with_metadata overwrites earlier runs.
    run_name = f"{model_name}_pruned_{pruning_rate}_{method.__name__}"
    utility.save.save_model_with_metadata(
        model=temp_model,
        path=f"{MODELS_PATH}/{run_name}",
        model_name=run_name,
        metadata=metadata,
    )

Pruning rate: 0.2, method: L1Unstructured
After pruning:
Valid Loss: 0.6726, Valid Accuracy: 0.7714
Retraining the model
Epoch: 0
Train Loss: 0.6964, Valid Loss: 0.6990, Valid Accuracy: 0.7571
Epoch: 1
Train Loss: 0.6505, Valid Loss: 0.7201, Valid Accuracy: 0.7525
Epoch: 2
Train Loss: 0.6296, Valid Loss: 0.7416, Valid Accuracy: 0.7412
Epoch: 3
Train Loss: 0.6038, Valid Loss: 0.7647, Valid Accuracy: 0.7349
Epoch: 4
Train Loss: 0.5785, Valid Loss: 0.8027, Valid Accuracy: 0.7277
Epoch: 5
Train Loss: 0.5550, Valid Loss: 0.8325, Valid Accuracy: 0.7175
Epoch: 6
Train Loss: 0.5309, Valid Loss: 0.8190, Valid Accuracy: 0.7243
Epoch: 7
Train Loss: 0.5159, Valid Loss: 0.8729, Valid Accuracy: 0.7059
Epoch: 8
Train Loss: 0.5005, Valid Loss: 0.8717, Valid Accuracy: 0.7149
Epoch: 9
Train Loss: 0.4746, Valid Loss: 0.9020, Valid Accuracy: 0.7103
Epoch: 10
Train Loss: 0.4552, Valid Loss: 0.9444, Valid Accuracy: 0.6966
Epoch: 11
Train Loss: 0.4395, Valid Loss: 0.9775, Valid Accuracy: 0.6997
Epoch: 12
Tra