In [1]:
import torch
from pathlib import Path
from utils import evaluate_model, dynamic_quantization, static_quantization
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from models.ResNet50 import ResNet50Baseline
import json

In [2]:
# Pick the quantized backend: prefer 'fbgemm', fall back to 'qnnpack' when only that one is
# available; if neither is supported the engine is left at its current value.
supported_engines = torch.backends.quantized.supported_engines
if 'fbgemm' in supported_engines:
    torch.backends.quantized.engine = 'fbgemm'
elif 'qnnpack' in supported_engines:
    torch.backends.quantized.engine = 'qnnpack'
# Directory holding the float, pruned and quantized LeNet-5 checkpoints used below.
saved_model_path = Path("./saved_models/lenet")
# Display the engine that was actually selected.
torch.backends.quantized.engine

'fbgemm'

In [3]:
# Load the fully pickled float baseline onto the CPU and switch it to eval mode
# (nn.Module.eval() returns the module itself, so the call can be chained).
# weights_only=False is required to unpickle a whole nn.Module rather than a state_dict;
# only load trusted, locally produced files this way (it runs arbitrary pickle code).
baseline_model = torch.load(
    saved_model_path / "lenet5_baseline_model.pth",
    map_location="cpu",
    weights_only=False,
).eval()
baseline_model

LeNet5(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=256, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

In [4]:
# ===== Dynamic quantization of the float baseline, via the helper from utils.py.
# The example input is one 1x1x28x28 tensor matching conv1's single input channel; the helper
# presumably uses it to trace the model (the result is an FX GraphModule).
# NOTE(review): baseline_model is reused by later cells — confirm the helper does not modify it in place.
quantized_dy_model = dynamic_quantization(
    baseline_model,
    (torch.randn(1, 1, 28, 28),),
)
quantized_dy_model

GraphModule(
  (conv1): ConvReLU2d(
    (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
  )
  (conv2): ConvReLU2d(
    (0): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
  )
  (fc1): DynamicQuantizedLinearReLU(in_features=256, out_features=120, dtype=torch.qint8, qscheme=torch.per_channel_affine)
  (fc2): DynamicQuantizedLinearReLU(in_features=120, out_features=84, dtype=torch.qint8, qscheme=torch.per_channel_affine)
  (fc3): DynamicQuantizedLinear(in_features=84, out_features=10, dtype=torch.qint8, qscheme=torch.per_channel_affine)
)

In [5]:
# Persist the dynamic-quantized model twice: weights only (state_dict) and the full pickled module.
torch.save(
    obj=quantized_dy_model.state_dict(),
    f=saved_model_path / 'lenet5_base_dy_quant_weights.pth',
)
torch.save(
    obj=quantized_dy_model,
    f=saved_model_path / 'lenet5_base_dy_quant_model.pth',
)

In [6]:
# MNIST preprocessing shared by the test and calibration data: convert images to tensors,
# then standardize with the commonly used MNIST pixel mean/std (0.1307, 0.3081).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
# download=True so that a fresh environment (Restart & Run All) fetches MNIST instead of
# failing here; this cell runs before any other cell downloads the data. It is a no-op
# when ./data already contains the dataset.
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)
# Unshuffled, so every evaluation sees the same batch order.
test_loader = DataLoader(test_dataset, batch_size=64)

In [7]:
# Round-trip check: reload the pickled dynamic-quantized GraphModule from disk, replacing the
# in-memory object, and put it in eval mode (eval() returns the module itself).
# weights_only=False is needed to unpickle a full module; only load trusted files this way.
quantized_dy_model = torch.load(
    saved_model_path / "lenet5_base_dy_quant_model.pth",
    map_location="cpu",
    weights_only=False,
).eval()
quantized_dy_model

  device=storage.device,


GraphModule(
  (conv1): ConvReLU2d(
    (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
  )
  (conv2): ConvReLU2d(
    (0): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
  )
  (fc1): DynamicQuantizedLinearReLU(in_features=256, out_features=120, dtype=torch.qint8, qscheme=torch.per_channel_affine)
  (fc2): DynamicQuantizedLinearReLU(in_features=120, out_features=84, dtype=torch.qint8, qscheme=torch.per_channel_affine)
  (fc3): DynamicQuantizedLinear(in_features=84, out_features=10, dtype=torch.qint8, qscheme=torch.per_channel_affine)
)

In [8]:
# Reference loss / top-1 accuracy of the unquantized float model on the MNIST test split.
baseline_metrics = evaluate_model(
    baseline_model,
    test_loader,
    'lenet5',
    high_granularity=True,
)

100%|██████████████████████████████████████████████████████| 157/157 [00:01<00:00, 84.43it/s, Loss=0.0589, Top1=98.14%]


In [9]:
# Same evaluation for the (reloaded) dynamic-quantized baseline, for direct comparison.
baseline_dy_metrics = evaluate_model(
    quantized_dy_model,
    test_loader,
    'lenet5',
    high_granularity=True,
)

100%|██████████████████████████████████████████████████████| 157/157 [00:01<00:00, 83.70it/s, Loss=0.0589, Top1=98.13%]


In [10]:
# ======= Static Quantization. Uses the function inside the utils.py script
# Number of training images used for calibration (the loader below yields
# ceil(NUM_CALIBRATION_SAMPLES / 64) batches). Named here instead of a bare literal so it is easy to tune.
NUM_CALIBRATION_SAMPLES = 1000

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
# A fixed, unshuffled slice of the *training* split, so calibration never sees test data.
# Presumably the helper feeds these batches through the prepared model to set the
# activation scale/zero_point values shown in the quantized repr.
calibration_loader = DataLoader(
    torch.utils.data.Subset(train_dataset, range(NUM_CALIBRATION_SAMPLES)),
    batch_size=64,
)

base_st_quant_model = static_quantization(
    baseline_model,
    (torch.randn(1, 1, 28, 28),),  # example input: one 1x1x28x28 image
    calibration_loader,
)
base_st_quant_model

100%|██████████████████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 75.69it/s]


GraphModule(
  (conv1): QuantizedConvReLU2d(1, 6, kernel_size=(5, 5), stride=(1, 1), scale=0.02737204171717167, zero_point=0)
  (conv2): QuantizedConvReLU2d(6, 16, kernel_size=(5, 5), stride=(1, 1), scale=0.04742223024368286, zero_point=0)
  (fc1): QuantizedLinearReLU(in_features=256, out_features=120, scale=0.07808518409729004, zero_point=0, qscheme=torch.per_channel_affine)
  (fc2): QuantizedLinearReLU(in_features=120, out_features=84, scale=0.06925788521766663, zero_point=0, qscheme=torch.per_channel_affine)
  (fc3): QuantizedLinear(in_features=84, out_features=10, scale=0.1696537882089615, zero_point=137, qscheme=torch.per_channel_affine)
)

In [11]:
# Persist the statically quantized baseline: weights-only state_dict and the full pickled module.
torch.save(
    obj=base_st_quant_model.state_dict(),
    f=saved_model_path / 'lenet5_base_st_quant_weights.pth',
)
torch.save(
    obj=base_st_quant_model,
    f=saved_model_path / 'lenet5_base_st_quant_model.pth',
)

In [12]:
# Test-set loss / top-1 accuracy of the statically quantized baseline.
baseline_st_metrics = evaluate_model(
    base_st_quant_model,
    test_loader,
    'lenet5',
    high_granularity=True,
)

100%|██████████████████████████████████████████████████████| 157/157 [00:01<00:00, 85.59it/s, Loss=0.0580, Top1=98.14%]


In [13]:
# Load the pruned float model (presumably structured-pruned, per the "str_prune" file name)
# and quantize it dynamically. weights_only=False unpickles a full module: trusted files only.
pruned_model = torch.load(
    saved_model_path / "lenet5_base_str_prune_model.pth",
    map_location="cpu",
    weights_only=False,
)
# Post-training quantization expects an eval-mode float model, as done for the baseline.
# torch.load restores whatever train/eval flag the module was pickled with, so set it explicitly.
pruned_model.eval()
pruned_quantized_dy_model = dynamic_quantization(pruned_model, (torch.randn(1, 1, 28, 28),))

In [14]:
# Persist the pruned + dynamic-quantized model as a state_dict and as a full pickled module.
torch.save(
    obj=pruned_quantized_dy_model.state_dict(),
    f=saved_model_path / 'lenet5_pruned_quantized_dy_weights.pth',
)
torch.save(
    obj=pruned_quantized_dy_model,
    f=saved_model_path / 'lenet5_pruned_quantized_dy_model.pth',
)

In [15]:
# Evaluate the pruned + dynamic-quantized model; eval() switches it to eval mode and returns
# the module itself, so it can be passed straight to the evaluator.
pruned_quantized_dy_metrics = evaluate_model(
    pruned_quantized_dy_model.eval(),
    test_loader,
    'lenet5',
    high_granularity=True,
)

100%|██████████████████████████████████████████████████████| 157/157 [00:01<00:00, 84.60it/s, Loss=0.0881, Top1=97.32%]


In [16]:
# Reload a fresh copy of the pruned float model from disk for static quantization, rather than
# reusing the object already passed to dynamic_quantization above.
# weights_only=False unpickles a full module: trusted files only.
pruned_model = torch.load(
    saved_model_path / "lenet5_base_str_prune_model.pth",
    map_location="cpu",
    weights_only=False,
)
# Static PTQ calibration must run on an eval-mode float model; torch.load restores whatever
# train/eval flag the module was saved with, so set it explicitly.
pruned_model.eval()
pruned_quantized_st_model = static_quantization(
    pruned_model,
    (torch.randn(1, 1, 28, 28),),  # example input: one 1x1x28x28 image
    calibration_loader,
)

100%|██████████████████████████████████████████████████████████████████████████████████| 16/16 [00:00<00:00, 78.01it/s]


In [17]:
# Persist the pruned + statically quantized model as a state_dict and as a full pickled module.
torch.save(
    obj=pruned_quantized_st_model.state_dict(),
    f=saved_model_path / 'lenet5_pruned_quantized_st_weights.pth',
)
torch.save(
    obj=pruned_quantized_st_model,
    f=saved_model_path / 'lenet5_pruned_quantized_st_model.pth',
)

In [18]:
# Evaluate the pruned + statically quantized model; eval() returns the module itself,
# so mode switching and evaluation happen in one call.
pruned_quantized_st_metrics = evaluate_model(
    pruned_quantized_st_model.eval(),
    test_loader,
    'lenet5',
    high_granularity=True,
)

100%|██████████████████████████████████████████████████████| 157/157 [00:01<00:00, 85.91it/s, Loss=0.0880, Top1=97.48%]


In [19]:
# Collect every evaluation run under its variable name; each key becomes a JSON file name.
all_metrics = dict(
    baseline_metrics=baseline_metrics,
    baseline_dy_metrics=baseline_dy_metrics,
    baseline_st_metrics=baseline_st_metrics,
    pruned_quantized_dy_metrics=pruned_quantized_dy_metrics,
    pruned_quantized_st_metrics=pruned_quantized_st_metrics,
)
# Output directory for the metrics files; created (with parents) if missing.
metrics_folder = Path("./model_metrics/lenet5")
metrics_folder.mkdir(parents=True, exist_ok=True)

In [20]:
# Write one JSON file per evaluation run, e.g. model_metrics/lenet5/baseline_metrics.json.
for name, metrics in all_metrics.items():
    # Serialize before touching the file. If a value is not JSON-serializable, json.dumps
    # raises here and no truncated or partial file is left behind. Opening with "w" and
    # streaming json.dump would already have truncated or half-written it.
    # NOTE(review): assumes evaluate_model returns plain Python types — confirm.
    serialized = json.dumps(metrics, indent=1)
    (metrics_folder / f"{name}.json").write_text(serialized, encoding="utf-8")