In [1]:
import torch
import torch.nn as nn

from pathlib import Path
from constants import DATA_PATH, MODELS_PATH
from LeNet import LeNet, BATCH_SIZE
import utility
import json
from utility import (
    calculate_total_sparsity,
    get_parameters_to_prune,
)
from utility.cifar_dataset import get_dataloaders

In [2]:
train_loader, validation_loader, test_loader = get_dataloaders(
    data_path=DATA_PATH, batch_size=BATCH_SIZE
)

In [3]:
# Get cpu, gpu or mps device for training.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using {torch.cuda.get_device_name(torch.cuda.current_device())}")

Using NVIDIA GeForce GTX 1660 Ti


In [4]:
base_model = LeNet().to(device)

# Define the loss function
cross_entropy = nn.CrossEntropyLoss()

### Load models

In [6]:
models = []

for file in Path(MODELS_PATH).iterdir():
    if not file.is_dir():
        continue

    for inner_file in file.iterdir():
        if not inner_file.is_file():
            continue
        match inner_file.suffix:
            case ".pth":
                model = LeNet().to(device)
                model.load_state_dict(torch.load(inner_file))
            case ".json":
                metadata = json.load(inner_file.open())
    models.append((model, metadata, file.stem))

### Test the models

In [7]:
results = []
for model, meta, _ in models:
    test_loss, accuracy = utility.training.test(
        model=model, test_dl=test_loader, loss_function=cross_entropy, device=device
    )
    results.append((*meta.values(), accuracy))

In [8]:
import pandas as pd

columns = [*models[0][1].keys(), "accuracy"]
result_df = (
    pd.DataFrame.from_records(results, columns=columns)
    .sort_values(by="accuracy", ascending=False)
    .reset_index(drop=True)
)
result_df

Unnamed: 0,total_pruned,pruning_step,finetune_epochs,method,early_stopping,accuracy
0,0.60,0.60,4,L1Unstructured,True,90.63
1,0.40,0.04,1,L1Unstructured,False,90.52
2,0.20,0.04,3,L1Unstructured,False,90.50
3,0.60,0.04,1,L1Unstructured,False,90.43
4,0.40,0.02,1,L1Unstructured,False,90.39
...,...,...,...,...,...,...
87,0.96,0.01,2,L1Unstructured,False,79.42
88,0.92,0.01,3,L1Unstructured,False,76.51
89,0.96,0.01,3,L1Unstructured,False,56.29
90,0.92,0.01,4,L1Unstructured,False,47.76


### Print models sparsity.

In [9]:
for model, _, name in models:
    print(f"Calculating sparsity for {name}")
    print(
        f"Total sparsity: {100 - calculate_total_sparsity(model, get_parameters_to_prune(model)):.2f}%"
    )
    print("-" * 20)

Calculating sparsity for LeNet_iterative_pruned_0.92_step_0
Total sparsity: 7.96%
--------------------
Calculating sparsity for LeNet_pruned_0
Total sparsity: 60.00%
--------------------
Calculating sparsity for LeNet_iterative_pruned_0.96_step_0
Total sparsity: 3.95%
--------------------
Calculating sparsity for LeNet_iterative_pruned_0.88_step_0
Total sparsity: 11.96%
--------------------
Calculating sparsity for LeNet_iterative_pruned_0.8_step_0
Total sparsity: 20.03%
--------------------
Calculating sparsity for LeNet_pruned_0
Total sparsity: 12.00%
--------------------
Calculating sparsity for LeNet_iterative_pruned_0.96_step_0
Total sparsity: 3.95%
--------------------
Calculating sparsity for LeNet_iterative_pruned_0.96_step_0
Total sparsity: 4.03%
--------------------
Calculating sparsity for LeNet_iterative_pruned_0.8_step_0
Total sparsity: 19.96%
--------------------
Calculating sparsity for LeNet_iterative_pruned_0.88_step_0
Total sparsity: 11.99%
--------------------
Calcul

### Visualize tables with the results

In [10]:
base_result = result_df[result_df["total_pruned"] == 0.0]
base_result

Unnamed: 0,total_pruned,pruning_step,finetune_epochs,method,early_stopping,accuracy
67,0.0,0.0,12,,True,89.04


In [11]:
one_shot_results = result_df[
    (result_df["total_pruned"] == result_df["pruning_step"])
    & (result_df["total_pruned"] != 0.0)
]
one_shot_results

Unnamed: 0,total_pruned,pruning_step,finetune_epochs,method,early_stopping,accuracy
0,0.6,0.6,4,L1Unstructured,True,90.63
6,0.2,0.2,4,L1Unstructured,True,90.27
26,0.88,0.88,8,L1Unstructured,True,89.99
27,0.8,0.8,6,L1Unstructured,True,89.94
59,0.4,0.4,9,L1Unstructured,True,89.25
69,0.92,0.92,12,L1Unstructured,True,88.91
72,0.96,0.96,10,L1Unstructured,True,88.61


In [12]:
iterative_results = result_df[result_df["total_pruned"] != result_df["pruning_step"]]
iterative_results

Unnamed: 0,total_pruned,pruning_step,finetune_epochs,method,early_stopping,accuracy
1,0.40,0.04,1,L1Unstructured,False,90.52
2,0.20,0.04,3,L1Unstructured,False,90.50
3,0.60,0.04,1,L1Unstructured,False,90.43
4,0.40,0.02,1,L1Unstructured,False,90.39
5,0.40,0.04,4,L1Unstructured,False,90.29
...,...,...,...,...,...,...
87,0.96,0.01,2,L1Unstructured,False,79.42
88,0.92,0.01,3,L1Unstructured,False,76.51
89,0.96,0.01,3,L1Unstructured,False,56.29
90,0.92,0.01,4,L1Unstructured,False,47.76


In [13]:
pruned_20 = result_df[result_df["total_pruned"] == 0.20].sort_values(
    by=["pruning_step", "finetune_epochs"], ascending=False
)
pruned_40 = result_df[result_df["total_pruned"] == 0.40].sort_values(
    by=["pruning_step", "finetune_epochs"], ascending=False
)
pruned_60 = result_df[result_df["total_pruned"] == 0.60].sort_values(
    by=["pruning_step", "finetune_epochs"], ascending=False
)
pruned_80 = result_df[result_df["total_pruned"] == 0.80].sort_values(
    by=["pruning_step", "finetune_epochs"], ascending=False
)
pruned_92 = result_df[result_df["total_pruned"] == 0.92].sort_values(
    by=["pruning_step", "finetune_epochs"], ascending=False
)
pruned_96 = result_df[result_df["total_pruned"] == 0.96].sort_values(
    by=["pruning_step", "finetune_epochs"], ascending=False
)

In [14]:
pruned_20

Unnamed: 0,total_pruned,pruning_step,finetune_epochs,method,early_stopping,accuracy
6,0.2,0.2,4,L1Unstructured,True,90.27
22,0.2,0.04,4,L1Unstructured,False,90.07
2,0.2,0.04,3,L1Unstructured,False,90.5
14,0.2,0.04,2,L1Unstructured,False,90.17
49,0.2,0.04,1,L1Unstructured,False,89.51
17,0.2,0.02,4,L1Unstructured,False,90.15
7,0.2,0.02,3,L1Unstructured,False,90.26
16,0.2,0.02,2,L1Unstructured,False,90.16
15,0.2,0.02,1,L1Unstructured,False,90.17
25,0.2,0.01,4,L1Unstructured,False,90.04


In [15]:
pruned_40

Unnamed: 0,total_pruned,pruning_step,finetune_epochs,method,early_stopping,accuracy
59,0.4,0.4,9,L1Unstructured,True,89.25
5,0.4,0.04,4,L1Unstructured,False,90.29
43,0.4,0.04,3,L1Unstructured,False,89.65
31,0.4,0.04,2,L1Unstructured,False,89.87
1,0.4,0.04,1,L1Unstructured,False,90.52
34,0.4,0.02,4,L1Unstructured,False,89.82
45,0.4,0.02,3,L1Unstructured,False,89.6
33,0.4,0.02,2,L1Unstructured,False,89.84
4,0.4,0.02,1,L1Unstructured,False,90.39
13,0.4,0.01,4,L1Unstructured,False,90.18


In [16]:
pruned_60

Unnamed: 0,total_pruned,pruning_step,finetune_epochs,method,early_stopping,accuracy
0,0.6,0.6,4,L1Unstructured,True,90.63
20,0.6,0.04,4,L1Unstructured,False,90.14
23,0.6,0.04,3,L1Unstructured,False,90.04
19,0.6,0.04,2,L1Unstructured,False,90.14
3,0.6,0.04,1,L1Unstructured,False,90.43
28,0.6,0.02,4,L1Unstructured,False,89.93
37,0.6,0.02,3,L1Unstructured,False,89.78
53,0.6,0.02,2,L1Unstructured,False,89.34
8,0.6,0.02,1,L1Unstructured,False,90.24
56,0.6,0.01,4,L1Unstructured,False,89.31


In [17]:
pruned_80

Unnamed: 0,total_pruned,pruning_step,finetune_epochs,method,early_stopping,accuracy
27,0.8,0.8,6,L1Unstructured,True,89.94
50,0.8,0.04,4,L1Unstructured,False,89.5
12,0.8,0.04,3,L1Unstructured,False,90.19
24,0.8,0.04,2,L1Unstructured,False,90.04
9,0.8,0.04,1,L1Unstructured,False,90.24
51,0.8,0.02,4,L1Unstructured,False,89.43
29,0.8,0.02,3,L1Unstructured,False,89.9
35,0.8,0.02,2,L1Unstructured,False,89.81
39,0.8,0.02,1,L1Unstructured,False,89.67
71,0.8,0.01,4,L1Unstructured,False,88.62


In [18]:
pruned_92

Unnamed: 0,total_pruned,pruning_step,finetune_epochs,method,early_stopping,accuracy
69,0.92,0.92,12,L1Unstructured,True,88.91
64,0.92,0.04,4,L1Unstructured,False,89.19
30,0.92,0.04,3,L1Unstructured,False,89.88
46,0.92,0.04,2,L1Unstructured,False,89.58
62,0.92,0.04,1,L1Unstructured,False,89.24
78,0.92,0.02,4,L1Unstructured,False,87.85
75,0.92,0.02,3,L1Unstructured,False,88.31
65,0.92,0.02,2,L1Unstructured,False,89.1
54,0.92,0.02,1,L1Unstructured,False,89.33
90,0.92,0.01,4,L1Unstructured,False,47.76


In [19]:
pruned_96

Unnamed: 0,total_pruned,pruning_step,finetune_epochs,method,early_stopping,accuracy
72,0.96,0.96,10,L1Unstructured,True,88.61
80,0.96,0.04,4,L1Unstructured,False,87.58
76,0.96,0.04,3,L1Unstructured,False,88.15
74,0.96,0.04,2,L1Unstructured,False,88.37
70,0.96,0.04,1,L1Unstructured,False,88.8
86,0.96,0.02,4,L1Unstructured,False,81.49
84,0.96,0.02,3,L1Unstructured,False,83.75
82,0.96,0.02,2,L1Unstructured,False,87.48
77,0.96,0.02,1,L1Unstructured,False,87.92
91,0.96,0.01,4,L1Unstructured,False,19.71
