In [1]:
import torch
import torch.quantization as tq
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.quantization.qconfig_mapping import QConfigMapping
from torch.ao.quantization import default_dynamic_qconfig, MovingAverageMinMaxObserver, PerChannelMinMaxObserver, get_default_qconfig, HistogramObserver
import torch.fx as fx
from torch.ao.quantization.observer import PlaceholderObserver
from pathlib import Path
import copy
from tqdm import tqdm
from utils import evaluate_model, dynamic_quantization, static_quantization
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [2]:
supported_engines = torch.backends.quantized.supported_engines
if 'qnnpack' in supported_engines:
    torch.backends.quantized.engine = 'qnnpack'
if 'fbgemm' in supported_engines:
    torch.backends.quantized.engine = 'fbgemm'
saved_model_path = Path("./saved_models/lenet")
torch.backends.quantized.engine

'qnnpack'

In [3]:
baseline_model = torch.load(saved_model_path / "lenet5_baseline_model.pth", map_location="cpu",weights_only=False)
baseline_model.eval()

LeNet5(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=256, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

In [4]:
# ===== dynamic quantisation on baseline model. Uses the function inside the utils.py script
quantized_dy_model = dynamic_quantization(baseline_model, (torch.randn(1, 1, 28, 28),))
quantized_dy_model

GraphModule(
  (conv1): ConvReLU2d(
    (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
  )
  (conv2): ConvReLU2d(
    (0): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
  )
  (fc1): DynamicQuantizedLinearReLU(in_features=256, out_features=120, dtype=torch.qint8, qscheme=torch.per_channel_affine)
  (fc2): DynamicQuantizedLinearReLU(in_features=120, out_features=84, dtype=torch.qint8, qscheme=torch.per_channel_affine)
  (fc3): DynamicQuantizedLinear(in_features=84, out_features=10, dtype=torch.qint8, qscheme=torch.per_channel_affine)
)

In [5]:
torch.save(quantized_dy_model.state_dict(), saved_model_path / 'lenet5_base_dy_quant_weights.pth')
torch.save(quantized_dy_model, saved_model_path / 'lenet5_base_dy_quant_model.pth')

In [6]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
test_dataset = datasets.MNIST('./data', train=False, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=64)

In [7]:
quantized_dy_model = torch.load(saved_model_path / "lenet5_base_dy_quant_model.pth", map_location="cpu",weights_only=False)
quantized_dy_model.eval()

  device=storage.device,


GraphModule(
  (conv1): ConvReLU2d(
    (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
  )
  (conv2): ConvReLU2d(
    (0): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
  )
  (fc1): DynamicQuantizedLinearReLU(in_features=256, out_features=120, dtype=torch.qint8, qscheme=torch.per_channel_affine)
  (fc2): DynamicQuantizedLinearReLU(in_features=120, out_features=84, dtype=torch.qint8, qscheme=torch.per_channel_affine)
  (fc3): DynamicQuantizedLinear(in_features=84, out_features=10, dtype=torch.qint8, qscheme=torch.per_channel_affine)
)

In [8]:
evaluate_model(baseline_model, test_loader, 'lenet5')

100%|██████████████| 157/157 [00:00<00:00, 193.09it/s, Loss=0.0589, Top1=98.14%]


{'top1_acc': 0.9814,
 'top5_acc': 0.9999,
 'total_inference_time': 0.8132390975952148,
 'average_loss': 3.7494788712623177,
 'total_batches': 157,
 'all_losses': [],
 'all_top1_acc': [],
 'inference_times': [],
 'true_labels': [],
 'predicted_labels': []}

In [9]:
evaluate_model(quantized_dy_model, test_loader, 'lenet5')

100%|██████████████| 157/157 [00:00<00:00, 178.22it/s, Loss=0.0589, Top1=98.12%]


{'top1_acc': 0.9812,
 'top5_acc': 0.9999,
 'total_inference_time': 0.8811829090118408,
 'average_loss': 3.7515113297449485,
 'total_batches': 157,
 'all_losses': [],
 'all_top1_acc': [],
 'inference_times': [],
 'true_labels': [],
 'predicted_labels': []}

In [11]:
# ======= Static Quantization. Uses the function inside the utils.py script
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
calibration_loader = DataLoader(torch.utils.data.Subset(train_dataset, range(1000)), batch_size=64)

base_st_quant_model = static_quantization(baseline_model, (torch.randn(1, 1, 28, 28),), 
                                         calibration_loader)
base_st_quant_model

100%|██████████████████████████████████████████| 16/16 [00:00<00:00, 146.59it/s]


GraphModule(
  (conv1): QuantizedConvReLU2d(1, 6, kernel_size=(5, 5), stride=(1, 1), scale=0.02737203985452652, zero_point=0)
  (conv2): QuantizedConvReLU2d(6, 16, kernel_size=(5, 5), stride=(1, 1), scale=0.047422222793102264, zero_point=0)
  (fc1): QuantizedLinearReLU(in_features=256, out_features=120, scale=0.07808518409729004, zero_point=0, qscheme=torch.per_channel_affine)
  (fc2): QuantizedLinearReLU(in_features=120, out_features=84, scale=0.06925788521766663, zero_point=0, qscheme=torch.per_channel_affine)
  (fc3): QuantizedLinear(in_features=84, out_features=10, scale=0.1696537584066391, zero_point=137, qscheme=torch.per_channel_affine)
)

In [13]:
torch.save(base_st_quant_model.state_dict(), saved_model_path / 'lenet5_base_st_quant_weights.pth')
torch.save(base_st_quant_model, saved_model_path / 'lenet5_base_st_quant_model.pth')

In [16]:
evaluate_model(base_st_quant_model, test_loader, 'lenet5', high_granularity=True)

100%|██████████████| 157/157 [00:00<00:00, 168.09it/s, Loss=0.0589, Top1=98.11%]


{'top1_acc': 0.9811,
 'top5_acc': 0.9999,
 'total_inference_time': 0.9344391822814941,
 'average_loss': 3.753511188885458,
 'total_batches': 157,
 'all_losses': [3.046917676925659,
  4.085618495941162,
  3.1410038471221924,
  7.957815647125244,
  2.0265939235687256,
  5.179840087890625,
  4.267550945281982,
  1.553153395652771,
  2.4778952598571777,
  6.108922958374023,
  13.323774337768555,
  6.0849409103393555,
  3.5332915782928467,
  2.1239070892333984,
  3.2517364025115967,
  4.6555705070495605,
  5.39436674118042,
  2.9933700561523438,
  7.198803901672363,
  17.807357788085938,
  10.79166030883789,
  5.630599498748779,
  6.377561569213867,
  12.056469917297363,
  3.93253231048584,
  2.7144274711608887,
  8.547775268554688,
  7.338992595672607,
  1.7360212802886963,
  11.227951049804688,
  3.2218594551086426,
  11.109827995300293,
  2.7399678230285645,
  10.791813850402832,
  2.745317220687866,
  2.1642398834228516,
  1.544592261314392,
  8.101005554199219,
  8.088863372802734,
  1