In [2]:
import torch
import torch.nn as nn
import datetime
from pathlib import Path
from constants import DATA_PATH, MODELS_PATH
from LeNet import LeNet, BATCH_SIZE
import utility
import json
from utility.pruning import (
    calculate_total_sparsity,
    get_parameters_to_prune,
)
from utility.cifar_dataset import get_dataloaders

In [None]:
current_date = datetime.datetime.utcnow().strftime("%Y-%m-%d_%H-%M-%S")

In [3]:
train_loader, validation_loader, test_loader = get_dataloaders(
    data_path=DATA_PATH, batch_size=BATCH_SIZE
)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Get cpu, gpu or mps device for training.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using {torch.cuda.get_device_name(torch.cuda.current_device())}")

Using NVIDIA GeForce GTX 1660 Ti


In [5]:
base_model = LeNet().to(device)

# Define the loss function
cross_entropy = nn.CrossEntropyLoss()

### Load models

In [6]:
models = []

for file in Path(MODELS_PATH).iterdir():
    if not file.is_dir():
        continue

    for inner_file in file.iterdir():
        if not inner_file.is_file():
            continue
        match inner_file.suffix:
            case ".pth":
                model = LeNet().to(device)
                model.load_state_dict(torch.load(inner_file))
            case ".json":
                metadata = json.load(inner_file.open())
    models.append((model, metadata, file.stem))

### Test the models

In [7]:
results = []
for model, meta, _ in models:
    test_loss, accuracy = utility.training.test(
        model=model, test_dl=test_loader, loss_function=cross_entropy, device=device
    )
    results.append((*meta.values(), accuracy))

In [8]:
import pandas as pd

columns = [*models[0][1].keys(), "accuracy"]
result_df = (
    pd.DataFrame.from_records(results, columns=columns)
    .sort_values(by="accuracy", ascending=False)
    .reset_index(drop=True)
)
result_df

Unnamed: 0,total_pruned,pruning_step,finetune_epochs,total_epochs,method,early_stopping,accuracy
0,0.2,0.04,1,5.0,L1Unstructured,False,64.45
1,0.4,0.04,1,10.0,L1Unstructured,False,64.38
2,20.0,0.2,5,5.0,L1Unstructured,False,64.31
3,0.6,0.04,1,,L1Unstructured,False,64.18
4,0.0,0.0,20,20.0,,False,64.16
5,40.0,0.4,10,10.0,L1Unstructured,False,63.83
6,0.2,0.02,1,10.0,L1Unstructured,False,63.82
7,60.0,0.6,15,15.0,L1Unstructured,False,63.61
8,0.4,0.02,1,20.0,L1Unstructured,False,63.39
9,0.6,0.02,1,,L1Unstructured,False,62.18


### Print models sparsity.

In [9]:
for model, _, name in models:
    print(f"Calculating sparsity for {name}")
    print(
        f"Total sparsity: {100 - calculate_total_sparsity(model, get_parameters_to_prune(model)):.2f}%"
    )
    print("-" * 20)

Calculating sparsity for LeNet_pruned_0
Total sparsity: 60.00%
--------------------
Calculating sparsity for LeNet_iterative_pruned_0.6_step_0
Total sparsity: 39.99%
--------------------
Calculating sparsity for LeNet_iterative_pruned_0.4_step_0
Total sparsity: 60.00%
--------------------
Calculating sparsity for LeNet_iterative_pruned_0.4_step_0
Total sparsity: 60.00%
--------------------
Calculating sparsity for LeNet_pruned_0
Total sparsity: 40.00%
--------------------
Calculating sparsity for LeNet_iterative_pruned_0.6_step_0
Total sparsity: 39.99%
--------------------
Calculating sparsity for LeNet_iterative_pruned_0.2_step_0
Total sparsity: 80.00%
--------------------
Calculating sparsity for LeNet_iterative_pruned_0.2_step_0
Total sparsity: 80.00%
--------------------
Calculating sparsity for LeNet_pruned_0
Total sparsity: 80.00%
--------------------
Calculating sparsity for LeNet_cifar10
Total sparsity: 100.00%
--------------------


### Visualize tables with the results

In [10]:
base_result = result_df[result_df["total_pruned"] == 0.0]
base_result

Unnamed: 0,total_pruned,pruning_step,finetune_epochs,total_epochs,method,early_stopping,accuracy
4,0.0,0.0,20,20.0,,False,64.16


In [11]:
one_shot_results = result_df[
    (result_df["total_pruned"] == result_df["pruning_step"])
    & (result_df["total_pruned"] != 0.0)
]
one_shot_results

Unnamed: 0,total_pruned,pruning_step,finetune_epochs,total_epochs,method,early_stopping,accuracy


In [12]:
iterative_results = result_df[result_df["total_pruned"] != result_df["pruning_step"]]
iterative_results
iterative_results.to_csv(f"{current_date}_iterative_results.csv")

Unnamed: 0,total_pruned,pruning_step,finetune_epochs,total_epochs,method,early_stopping,accuracy
0,0.2,0.04,1,5.0,L1Unstructured,False,64.45
1,0.4,0.04,1,10.0,L1Unstructured,False,64.38
2,20.0,0.2,5,5.0,L1Unstructured,False,64.31
3,0.6,0.04,1,,L1Unstructured,False,64.18
5,40.0,0.4,10,10.0,L1Unstructured,False,63.83
6,0.2,0.02,1,10.0,L1Unstructured,False,63.82
7,60.0,0.6,15,15.0,L1Unstructured,False,63.61
8,0.4,0.02,1,20.0,L1Unstructured,False,63.39
9,0.6,0.02,1,,L1Unstructured,False,62.18


In [13]:
pruned_20 = result_df[result_df["total_pruned"] == 0.20].sort_values(
    by=["pruning_step", "finetune_epochs"], ascending=False
)
pruned_40 = result_df[result_df["total_pruned"] == 0.40].sort_values(
    by=["pruning_step", "finetune_epochs"], ascending=False
)
pruned_60 = result_df[result_df["total_pruned"] == 0.60].sort_values(
    by=["pruning_step", "finetune_epochs"], ascending=False
)
pruned_80 = result_df[result_df["total_pruned"] == 0.80].sort_values(
    by=["pruning_step", "finetune_epochs"], ascending=False
)
pruned_92 = result_df[result_df["total_pruned"] == 0.92].sort_values(
    by=["pruning_step", "finetune_epochs"], ascending=False
)
pruned_96 = result_df[result_df["total_pruned"] == 0.96].sort_values(
    by=["pruning_step", "finetune_epochs"], ascending=False
)

In [14]:
pruned_20

Unnamed: 0,total_pruned,pruning_step,finetune_epochs,total_epochs,method,early_stopping,accuracy
0,0.2,0.04,1,5.0,L1Unstructured,False,64.45
6,0.2,0.02,1,10.0,L1Unstructured,False,63.82


In [15]:
pruned_40

Unnamed: 0,total_pruned,pruning_step,finetune_epochs,total_epochs,method,early_stopping,accuracy
1,0.4,0.04,1,10.0,L1Unstructured,False,64.38
8,0.4,0.02,1,20.0,L1Unstructured,False,63.39


In [16]:
pruned_60

Unnamed: 0,total_pruned,pruning_step,finetune_epochs,total_epochs,method,early_stopping,accuracy
3,0.6,0.04,1,,L1Unstructured,False,64.18
9,0.6,0.02,1,,L1Unstructured,False,62.18


In [17]:
pruned_80

Unnamed: 0,total_pruned,pruning_step,finetune_epochs,total_epochs,method,early_stopping,accuracy


In [18]:
pruned_92

Unnamed: 0,total_pruned,pruning_step,finetune_epochs,total_epochs,method,early_stopping,accuracy


In [19]:
pruned_96

Unnamed: 0,total_pruned,pruning_step,finetune_epochs,total_epochs,method,early_stopping,accuracy
