In [1]:
import torch
from pathlib import Path
from utils import evaluate_model, structured_prune, unstructured_prune
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from models.ResNet50 import ResNet50Baseline
import json

In [2]:
saved_model_path = Path("./saved_models/lenet")
baseline_model = torch.load(saved_model_path / "lenet5_baseline_model.pth", map_location="cpu",weights_only=False)
baseline_model.eval()

LeNet5(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=256, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

In [3]:
# the passed parameters are best param combinations, we used grid search in a defined search 
# space with a condition to filter combinations, where accurays drop is less than or equal to 0.02
pruned_structured_model = structured_prune(baseline_model, method="magnitude", sparsity=0.3, layer_scope="fc")
pruned_structured_model

LeNet5(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=256, out_features=120, bias=True) => (fc1): Linear(in_features=256, out_features=84, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True) => (fc2): Linear(in_features=84, out_features=58, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True) => (fc3): Linear(in_features=58, out_features=10, bias=True)
)



LeNet5(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=256, out_features=84, bias=True)
  (fc2): Linear(in_features=84, out_features=58, bias=True)
  (fc3): Linear(in_features=58, out_features=10, bias=True)
)

In [4]:
pruned_unstructured_model = unstructured_prune(baseline_model, sparsity=0.3, layer_scope="fc")
pruned_unstructured_model

LeNet5(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=256, out_features=84, bias=True)
  (fc2): Linear(in_features=84, out_features=58, bias=True)
  (fc3): Linear(in_features=58, out_features=10, bias=True)
)

In [5]:
torch.save(pruned_structured_model.state_dict(), saved_model_path / 'lenet5_base_str_prune_weights.pth')
torch.save(pruned_structured_model, saved_model_path / 'lenet5_base_str_prune_model.pth')

In [6]:
torch.save(pruned_unstructured_model.state_dict(), saved_model_path / 'lenet5_base_unstr_prune_weights.pth')
torch.save(pruned_unstructured_model, saved_model_path / 'lenet5_base_unstr_prune_model.pth')

In [7]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
test_dataset = datasets.MNIST('./data', train=False, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=64)

In [8]:
pruned_structured_model = torch.load(saved_model_path / "lenet5_base_str_prune_model.pth", map_location="cpu",weights_only=False)
pruned_structured_model.eval()

LeNet5(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=256, out_features=84, bias=True)
  (fc2): Linear(in_features=84, out_features=58, bias=True)
  (fc3): Linear(in_features=58, out_features=10, bias=True)
)

In [9]:
baseline_metrics = evaluate_model(baseline_model, test_loader, 'lenet5', high_granularity=True)

100%|██████████████████████████████████████████████████████| 157/157 [00:01<00:00, 83.43it/s, Loss=0.0883, Top1=97.28%]


In [10]:
baseline_str_metrics = evaluate_model(pruned_structured_model, test_loader, 'lenet5', high_granularity=True)

100%|██████████████████████████████████████████████████████| 157/157 [00:01<00:00, 82.63it/s, Loss=0.0883, Top1=97.28%]


In [11]:
pruned_unstructured_model = torch.load(saved_model_path / "lenet5_base_unstr_prune_model.pth", map_location="cpu",weights_only=False)
pruned_unstructured_model.eval()

LeNet5(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=256, out_features=84, bias=True)
  (fc2): Linear(in_features=84, out_features=58, bias=True)
  (fc3): Linear(in_features=58, out_features=10, bias=True)
)

In [12]:
baseline_unstr_metrics = evaluate_model(pruned_unstructured_model, test_loader, 'lenet5', high_granularity=True)

100%|██████████████████████████████████████████████████████| 157/157 [00:01<00:00, 81.27it/s, Loss=0.0883, Top1=97.28%]


In [13]:
all_metrics = {
    "baseline_metrics": baseline_metrics,
    "baseline_str_metrics": baseline_str_metrics,
    "baseline_unstr_metrics": baseline_unstr_metrics
}
metrics_folder = Path("./model_metrics/lenet5")
metrics_folder.mkdir(parents=True, exist_ok=True)

In [14]:
for name, metrics in all_metrics.items():
    with (metrics_folder / f"{name}.json").open("w") as file:
        json.dump(metrics, file, indent=1)