In [1]:
import torch
import torch.nn as nn

from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from typing import Iterable
from pathlib import Path
from constants import DATA_PATH, MODELS_PATH
from LeNet import LeNet, BATCH_SIZE
import utility
import json

In [2]:
# load FashionMNIST data
transform = transforms.Compose([transforms.ToTensor()])

# split into validation and train datasets
train_ds = datasets.FashionMNIST(
    DATA_PATH, train=True, transform=transform, download=True
)
train_ds, valid_ds = random_split(train_ds, [0.8, 0.2])

test_ds = datasets.FashionMNIST(
    DATA_PATH, train=False, transform=transform, download=True
)

In [3]:
# Get cpu, gpu or mps device for training.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using {torch.cuda.get_device_name(torch.cuda.current_device())}")

Using NVIDIA GeForce GTX 1660 Ti


In [4]:
base_model = LeNet().to(device)

# Define the loss function
cross_entropy = nn.CrossEntropyLoss()

# create the data loaders
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
validation_loader = DataLoader(valid_ds, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE)

In [5]:
def get_parameters_to_prune(model: nn.Module) -> list[nn.Parameter]:
    return [
        (module, "weight")
        for module in model.modules()
        if isinstance(module, nn.Conv2d | nn.Linear)
    ]


def calculate_total_sparsity(
    module: nn.Module, parameters_to_prune: Iterable[tuple[nn.Module, str]]
) -> float:
    total_weights = 0
    total_zero_weights = 0

    pruned_parameters: set[tuple[nn.Module, str]] = set(parameters_to_prune)

    for _, module in module.named_children():
        for param_name, param in module.named_parameters():
            if (module, param_name) not in pruned_parameters:
                continue

            if "weight" in param_name:
                total_weights += float(param.nelement())
                total_zero_weights += float(torch.sum(param == 0))

    sparsity = 100.0 * total_zero_weights / total_weights
    return sparsity

### Load models

In [6]:
models = []

for file in Path(MODELS_PATH).iterdir():
    if not file.is_dir():
        continue

    for inner_file in file.iterdir():
        if not inner_file.is_file():
            continue
        match inner_file.suffix:
            case ".pth":
                model = LeNet().to(device)
                model.load_state_dict(torch.load(inner_file))
            case ".json":
                metadata = json.load(inner_file.open())
    models.append((model, metadata, file.stem))

### Test the models

In [7]:
results = []
for model, meta, _ in models:
    test_loss, accuracy = utility.training.test(
        model=model, test_dl=test_loader, loss_function=cross_entropy, device=device
    )
    results.append((*meta.values(), accuracy))

In [9]:
import pandas as pd

columns = [*models[0][1].keys(), "accuracy"]
result_df = (
    pd.DataFrame.from_records(results, columns=columns)
    .sort_values(by="accuracy", ascending=False)
    .reset_index(drop=True)
)
result_df

Unnamed: 0,total_pruned,pruning_step,finetune_epochs,method,early_stopping,accuracy
0,0.60,0.60,4,L1Unstructured,True,90.63
1,0.20,0.20,4,L1Unstructured,True,90.27
2,0.88,0.88,8,L1Unstructured,True,89.99
3,0.80,0.80,6,L1Unstructured,True,89.94
4,0.40,0.40,9,L1Unstructured,True,89.25
...,...,...,...,...,...,...
87,0.96,0.01,4,L1Unstructured,False,42.19
88,0.96,0.04,4,L1Unstructured,False,41.36
89,0.96,0.04,2,L1Unstructured,False,41.36
90,0.96,0.04,1,L1Unstructured,False,41.36


### Print models sparsity.

In [10]:
for model, _, name in models:
    print(f"Calculating sparsity for {name}")
    print(
        f"Total sparsity: {100 - calculate_total_sparsity(model, get_parameters_to_prune(model)):.2f}%"
    )
    print("-" * 20)

Calculating sparsity for LeNet_iterative_pruned_0.92_step_0
Total sparsity: 7.96%
--------------------
Calculating sparsity for LeNet_pruned_0
Total sparsity: 60.00%
--------------------
Calculating sparsity for LeNet_iterative_pruned_0.96_step_0
Total sparsity: 3.95%
--------------------
Calculating sparsity for LeNet_iterative_pruned_0.88_step_0
Total sparsity: 11.96%
--------------------
Calculating sparsity for LeNet_iterative_pruned_0.8_step_0
Total sparsity: 20.03%
--------------------
Calculating sparsity for LeNet_pruned_0
Total sparsity: 12.00%
--------------------
Calculating sparsity for LeNet_iterative_pruned_0.96_step_0
Total sparsity: 3.95%
--------------------
Calculating sparsity for LeNet_iterative_pruned_0.96_step_0
Total sparsity: 4.03%
--------------------
Calculating sparsity for LeNet_iterative_pruned_0.8_step_0
Total sparsity: 19.96%
--------------------
Calculating sparsity for LeNet_iterative_pruned_0.88_step_0
Total sparsity: 11.99%
--------------------
Calcul

### Visualize tables with the results

In [11]:
base_result = result_df[result_df["total_pruned"] == 0.0]
base_result

Unnamed: 0,total_pruned,pruning_step,finetune_epochs,method,early_stopping,accuracy
17,0.0,0.0,12,,True,89.04


In [12]:
one_shot_results = result_df[
    (result_df["total_pruned"] == result_df["pruning_step"])
    & (result_df["total_pruned"] != 0.0)
]
one_shot_results

Unnamed: 0,total_pruned,pruning_step,finetune_epochs,method,early_stopping,accuracy
0,0.6,0.6,4,L1Unstructured,True,90.63
1,0.2,0.2,4,L1Unstructured,True,90.27
2,0.88,0.88,8,L1Unstructured,True,89.99
3,0.8,0.8,6,L1Unstructured,True,89.94
4,0.4,0.4,9,L1Unstructured,True,89.25
30,0.92,0.92,12,L1Unstructured,True,88.91
43,0.96,0.96,10,L1Unstructured,True,88.61


In [13]:
iterative_results = result_df[result_df["total_pruned"] != result_df["pruning_step"]]
iterative_results

Unnamed: 0,total_pruned,pruning_step,finetune_epochs,method,early_stopping,accuracy
5,0.40,0.04,2,L1Unstructured,False,89.15
6,0.40,0.04,1,L1Unstructured,False,89.15
7,0.40,0.04,4,L1Unstructured,False,89.15
8,0.40,0.04,3,L1Unstructured,False,89.15
9,0.40,0.02,3,L1Unstructured,False,89.14
...,...,...,...,...,...,...
87,0.96,0.01,4,L1Unstructured,False,42.19
88,0.96,0.04,4,L1Unstructured,False,41.36
89,0.96,0.04,2,L1Unstructured,False,41.36
90,0.96,0.04,1,L1Unstructured,False,41.36


In [38]:
pruned_20 = result_df[result_df["total_pruned"] == 0.20].sort_values(
    by=["pruning_step", "finetune_epochs"], ascending=False
)
pruned_40 = result_df[result_df["total_pruned"] == 0.40].sort_values(
    by=["pruning_step", "finetune_epochs"], ascending=False
)
pruned_60 = result_df[result_df["total_pruned"] == 0.60].sort_values(
    by=["pruning_step", "finetune_epochs"], ascending=False
)
pruned_80 = result_df[result_df["total_pruned"] == 0.80].sort_values(
    by=["pruning_step", "finetune_epochs"], ascending=False
)
pruned_92 = result_df[result_df["total_pruned"] == 0.92].sort_values(
    by=["pruning_step", "finetune_epochs"], ascending=False
)
pruned_96 = result_df[result_df["total_pruned"] == 0.96].sort_values(
    by=["pruning_step", "finetune_epochs"], ascending=False
)

In [39]:
pruned_20

Unnamed: 0,total_pruned,pruning_step,finetune_epochs,method,early_stopping,accuracy
1,0.2,0.2,4,L1Unstructured,True,90.27
20,0.2,0.04,4,L1Unstructured,False,88.97
21,0.2,0.04,3,L1Unstructured,False,88.97
18,0.2,0.04,2,L1Unstructured,False,88.97
22,0.2,0.04,1,L1Unstructured,False,88.97
19,0.2,0.02,4,L1Unstructured,False,88.97
25,0.2,0.02,3,L1Unstructured,False,88.97
23,0.2,0.02,2,L1Unstructured,False,88.97
24,0.2,0.02,1,L1Unstructured,False,88.97
27,0.2,0.01,4,L1Unstructured,False,88.96


In [40]:
pruned_40

Unnamed: 0,total_pruned,pruning_step,finetune_epochs,method,early_stopping,accuracy
4,0.4,0.4,9,L1Unstructured,True,89.25
7,0.4,0.04,4,L1Unstructured,False,89.15
8,0.4,0.04,3,L1Unstructured,False,89.15
5,0.4,0.04,2,L1Unstructured,False,89.15
6,0.4,0.04,1,L1Unstructured,False,89.15
12,0.4,0.02,4,L1Unstructured,False,89.14
9,0.4,0.02,3,L1Unstructured,False,89.14
11,0.4,0.02,2,L1Unstructured,False,89.14
10,0.4,0.02,1,L1Unstructured,False,89.14
13,0.4,0.01,4,L1Unstructured,False,89.13


In [41]:
pruned_60

Unnamed: 0,total_pruned,pruning_step,finetune_epochs,method,early_stopping,accuracy
0,0.6,0.6,4,L1Unstructured,True,90.63
33,0.6,0.04,4,L1Unstructured,False,88.68
32,0.6,0.04,3,L1Unstructured,False,88.68
31,0.6,0.04,2,L1Unstructured,False,88.68
34,0.6,0.04,1,L1Unstructured,False,88.68
36,0.6,0.02,4,L1Unstructured,False,88.66
38,0.6,0.02,3,L1Unstructured,False,88.66
37,0.6,0.02,2,L1Unstructured,False,88.66
35,0.6,0.02,1,L1Unstructured,False,88.66
40,0.6,0.01,4,L1Unstructured,False,88.65


In [42]:
pruned_80

Unnamed: 0,total_pruned,pruning_step,finetune_epochs,method,early_stopping,accuracy
3,0.8,0.8,6,L1Unstructured,True,89.94
50,0.8,0.04,4,L1Unstructured,False,87.78
51,0.8,0.04,3,L1Unstructured,False,87.78
48,0.8,0.04,2,L1Unstructured,False,87.78
49,0.8,0.04,1,L1Unstructured,False,87.78
47,0.8,0.02,4,L1Unstructured,False,87.83
45,0.8,0.02,3,L1Unstructured,False,87.83
44,0.8,0.02,2,L1Unstructured,False,87.83
46,0.8,0.02,1,L1Unstructured,False,87.83
53,0.8,0.01,4,L1Unstructured,False,87.75


In [43]:
pruned_92

Unnamed: 0,total_pruned,pruning_step,finetune_epochs,method,early_stopping,accuracy
30,0.92,0.92,12,L1Unstructured,True,88.91
79,0.92,0.04,4,L1Unstructured,False,70.08
76,0.92,0.04,3,L1Unstructured,False,70.08
77,0.92,0.04,2,L1Unstructured,False,70.08
78,0.92,0.04,1,L1Unstructured,False,70.08
74,0.92,0.02,4,L1Unstructured,False,70.12
75,0.92,0.02,3,L1Unstructured,False,70.12
73,0.92,0.02,2,L1Unstructured,False,70.12
72,0.92,0.02,1,L1Unstructured,False,70.12
69,0.92,0.01,4,L1Unstructured,False,70.8


In [44]:
pruned_96

Unnamed: 0,total_pruned,pruning_step,finetune_epochs,method,early_stopping,accuracy
43,0.96,0.96,10,L1Unstructured,True,88.61
88,0.96,0.04,4,L1Unstructured,False,41.36
91,0.96,0.04,3,L1Unstructured,False,41.36
89,0.96,0.04,2,L1Unstructured,False,41.36
90,0.96,0.04,1,L1Unstructured,False,41.36
82,0.96,0.02,4,L1Unstructured,False,44.17
80,0.96,0.02,3,L1Unstructured,False,44.17
83,0.96,0.02,2,L1Unstructured,False,44.17
81,0.96,0.02,1,L1Unstructured,False,44.17
87,0.96,0.01,4,L1Unstructured,False,42.19
