In [2]:
import torch.nn.utils.prune as prune

import torch

import torch.nn as nn
import torch.optim as optim
from pathlib import Path
from constants import DATA_PATH, MODELS_PATH
from LeNet import LeNet, BATCH_SIZE
from pruning_metadata import PruningMetadata
from utility.pruning import (
    calculate_parameters_amount,
    get_parameters_to_prune,
)
from utility.cifar_dataset import get_dataloaders
from dataclasses import asdict

import utility
import itertools
import copy

ImportError: cannot import name 'LEARNING_RATE' from 'LeNet' (/home/bubuss/source/ml-pruning/LeNet.py)

In [None]:
# Build the train / validation / test loaders, batched with the model's BATCH_SIZE.
loaders = get_dataloaders(data_path=DATA_PATH, batch_size=BATCH_SIZE)
train_loader, validation_loader, test_loader = loaders

In [None]:
# Pick the device for training: the current CUDA device when available, otherwise the CPU.
# (Apple-silicon "mps" is not selected by this cell.)
if torch.cuda.is_available():
    device = torch.device("cuda")
    # Only query the GPU name when CUDA exists; this call raises on CPU-only machines.
    print(f"Using {torch.cuda.get_device_name(torch.cuda.current_device())}")
else:
    device = torch.device("cpu")
    print("Using CPU")

Using NVIDIA GeForce GTX 1660 Ti


In [None]:
# Freshly initialised LeNet on the chosen device. In the cells shown here it is only
# used to count the prunable parameters that size each pruning step.
base_model = LeNet()
base_model = base_model.to(device)

# Loss shared by validation and fine-tuning, plus an AdamW optimizer over base_model.
cross_entropy = nn.CrossEntropyLoss()
optimizer = optim.AdamW(params=base_model.parameters())

In [None]:
# Sparsity levels (fractions of prunable weights) at which pruned checkpoints are saved.
PRUNING_VALUES = [0.2, 0.4, 0.6]

# Display name -> torch pruning-method class.
PRUNING_NAME_TO_METHOD = dict(
    RandomUnstructured=prune.RandomUnstructured,
    L1Unstructured=prune.L1Unstructured,
)

In [None]:
# Iterative sweep grid: share of the prunable weights removed per round, and
# number of fine-tuning epochs after each round.
ITER_PRUNING_RATES, RETRAIN_EPOCHS = [0.01, 0.02, 0.04], [1, 2]

# Every iterative run prunes up to the largest requested sparsity level.
max_pruning_percent = max(PRUNING_VALUES)

In [None]:
# One run configuration per (pruning step, fine-tune epochs) pair of the sweep grid.
# Every run uses L1-magnitude unstructured pruning up to max_pruning_percent.
iterative_pruning_configs = [
    PruningMetadata(
        total_pruned=max_pruning_percent,
        pruning_step=step,
        finetune_epochs=epochs,
        total_epochs=None,  # set later
        method=prune.L1Unstructured,
        early_stopping=False,
    )
    for step, epochs in itertools.product(ITER_PRUNING_RATES, RETRAIN_EPOCHS)
]

In [None]:
# Iterative pruning sweep.
#
# For every config: reload the trained LeNet checkpoint, then repeat
#   1. globally prune `pruning_step` of the original prunable-parameter count,
#   2. report validation metrics before fine-tuning,
#   3. fine-tune for `finetune_epochs` epochs, reporting validation metrics each epoch,
# until `total_pruned` sparsity is reached. Whenever the cumulative sparsity equals one
# of PRUNING_VALUES, a copy with the pruning masks made permanent is saved; the final
# model is saved once the sweep for that config finishes.

BASE_CHECKPOINT = MODELS_PATH / Path("LeNet_cifar10/LeNet_cifar10.pth")


def _to_percent(fraction):
    """Convert a fraction (e.g. 0.29) to an integer percentage (29).

    round() is used instead of int() truncation: 0.29 * 100 evaluates to
    28.999999999999996, which int() would truncate to 28.
    """
    return int(round(fraction * 100))


def _save_pruned_checkpoint(model, config, pruned_fraction, completed_rounds):
    """Save `model` together with the metadata of the pruning run that produced it.

    Args:
        model: model whose pruning re-parametrisation has already been removed.
        config: PruningMetadata of the run. It is deep-copied, never mutated, so this
            cell can be re-run without rebuilding `iterative_pruning_configs`.
        pruned_fraction: sparsity reached by `model`; used in the file name and metadata.
        completed_rounds: number of prune + fine-tune rounds applied to `model`.
    """
    method_name = config.method.__name__
    checkpoint_config = copy.deepcopy(config)
    checkpoint_config.total_pruned = pruned_fraction
    checkpoint_config.method = method_name
    checkpoint_config.total_epochs = completed_rounds * config.finetune_epochs
    checkpoint_name = (
        f"{model.__class__.__name__}_iterative_pruned_{pruned_fraction}"
        f"_step_{config.pruning_step}_{method_name}_epochs_{config.finetune_epochs}"
    )
    utility.save.save_model_with_metadata(
        model=model,
        path=f"{MODELS_PATH}/{checkpoint_name}",
        model_name=checkpoint_name,
        metadata=asdict(checkpoint_config),
    )


# Sparsity levels (integer percents) at which intermediate checkpoints are written.
checkpoint_percents = {_to_percent(value) for value in PRUNING_VALUES}
# Prunable-parameter count of the architecture; per-round pruning amounts derive from it.
total_prunable = calculate_parameters_amount(get_parameters_to_prune(base_model))

for config in iterative_pruning_configs:
    total_pruned = config.total_pruned
    pruning_step = config.pruning_step
    finetune_epochs = config.finetune_epochs
    method = config.method
    method_name = method.__name__

    print(
        f"Target total pruning: {total_pruned}, pruning_step: {pruning_step}, method: {method_name}, finetune_epochs: {finetune_epochs}"
    )

    target_sparsity = _to_percent(total_pruned)
    pruning_step_percent = _to_percent(pruning_step)
    # The sweep must land exactly on the target; otherwise the last round would
    # overshoot it (or the loop would never run).
    if (
        target_sparsity <= 0
        or pruning_step_percent <= 0
        or target_sparsity % pruning_step_percent != 0
    ):
        raise ValueError(
            f"total_pruned ({total_pruned}) must be a positive multiple of "
            f"pruning_step ({pruning_step}) in whole percent"
        )
    n_rounds = target_sparsity // pruning_step_percent

    # Absolute number of weights removed per round. torch's pruning container applies
    # an integer amount to the still-unpruned weights only, so every round removes the
    # same count and the cumulative sparsity grows by `pruning_step` each time.
    pruning_value = int(round(total_prunable * pruning_step))

    # Start every config from the same trained weights; map_location lets a checkpoint
    # saved on GPU be loaded on a CPU-only machine.
    iterative_model = LeNet().to(device)
    iterative_model.load_state_dict(torch.load(BASE_CHECKPOINT, map_location=device))

    iterative_model_parameters = get_parameters_to_prune(iterative_model)
    iterative_optimizer = optim.AdamW(iterative_model.parameters())

    for completed_rounds in range(1, n_rounds + 1):
        pruned = completed_rounds * pruning_step_percent  # cumulative sparsity, in percent

        prune.global_unstructured(
            parameters=iterative_model_parameters,
            pruning_method=method,
            amount=pruning_value,
        )

        val_loss, val_accuracy = utility.training.validate(
            module=iterative_model,
            valid_dl=validation_loader,
            loss_function=cross_entropy,
            device=device,
        )
        print(
            f"Pre-Finetuning Validation loss: {val_loss:.4f}\t validation accuracy: {val_accuracy:.4f}"
        )

        print("Fine tuning the model")
        for epoch in range(finetune_epochs):
            train_loss = utility.training.train_epoch(
                module=iterative_model,
                train_dl=train_loader,
                optimizer=iterative_optimizer,
                loss_function=cross_entropy,
                device=device,
            )

            val_loss, val_accuracy = utility.training.validate(
                module=iterative_model,
                valid_dl=validation_loader,
                loss_function=cross_entropy,
                device=device,
            )

            print(
                f"Pruned {pruned} / {target_sparsity}% | Epoch #{epoch}\tvalidation loss: {val_loss:.4f}\t validation accuracy: {val_accuracy:.4f}"
            )

        if pruned in checkpoint_percents:
            # Save a snapshot with permanent masks; the live model keeps its pruning
            # re-parametrisation so later rounds can keep pruning it.
            snapshot = copy.deepcopy(iterative_model)
            for module, name in get_parameters_to_prune(snapshot):
                prune.remove(module, name)
            _save_pruned_checkpoint(snapshot, config, pruned / 100, completed_rounds)

    # Make the final pruning permanent and save it with complete metadata. When the
    # target is one of PRUNING_VALUES this rewrites the last snapshot's path with the
    # same weights and the same (now correct) metadata.
    for module, name in iterative_model_parameters:
        prune.remove(module, name)
    _save_pruned_checkpoint(iterative_model, config, total_pruned, n_rounds)

Target total pruning: 0.96, pruning_step: 0.01, method: L1Unstructured, finetune_epochs: 1
Pre-Finetuning Validation loss: 0.2411	 validation accuracy: 0.9047
Fine tuning the model
Pruned 1 / 96% | Epoch #0	validation loss: 0.2384	 validation accuracy: 0.9104
Pre-Finetuning Validation loss: 0.2384	 validation accuracy: 0.9104
Fine tuning the model
Pruned 2 / 96% | Epoch #0	validation loss: 0.2335	 validation accuracy: 0.9126
Pre-Finetuning Validation loss: 0.2335	 validation accuracy: 0.9126
Fine tuning the model
Pruned 3 / 96% | Epoch #0	validation loss: 0.2376	 validation accuracy: 0.9139
Pre-Finetuning Validation loss: 0.2375	 validation accuracy: 0.9139
Fine tuning the model
Pruned 4 / 96% | Epoch #0	validation loss: 0.2411	 validation accuracy: 0.9132
Pre-Finetuning Validation loss: 0.2411	 validation accuracy: 0.9132
Fine tuning the model
Pruned 5 / 96% | Epoch #0	validation loss: 0.2406	 validation accuracy: 0.9152
Pre-Finetuning Validation loss: 0.2405	 validation accuracy: 0.9