In [3]:
import torch
import torch.quantization as tq
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.quantization.qconfig_mapping import QConfigMapping
from torch.ao.quantization import default_dynamic_qconfig, MovingAverageMinMaxObserver, PerChannelMinMaxObserver, get_default_qconfig, HistogramObserver
import torch.fx as fx
from torch.ao.quantization.observer import PlaceholderObserver
from pathlib import Path
import copy
from tqdm import tqdm

In [4]:
# Pick the quantized backend from what this PyTorch build supports. Candidates
# are tried in order and every supported one overwrites the previous choice, so
# 'fbgemm' wins when both are available and 'qnnpack' is used otherwise.
supported_engines = torch.backends.quantized.supported_engines
for candidate_engine in ('qnnpack', 'fbgemm'):
    if candidate_engine in supported_engines:
        torch.backends.quantized.engine = candidate_engine

# Directory that holds the LeNet checkpoints read and written by this notebook.
saved_model_path = Path("./saved_models/lenet")

# Show which backend ended up active.
torch.backends.quantized.engine

'qnnpack'

In [5]:
# Load the pickled float LeNet-5 module onto the CPU.
# NOTE(review): weights_only=False unpickles arbitrary Python objects, so this
# checkpoint must come from a trusted source.
baseline_checkpoint = saved_model_path / "lenet5_baseline_model.pth"
baseline_model = torch.load(baseline_checkpoint, map_location="cpu", weights_only=False)
# Switch to eval mode; as the last expression of the cell this also displays the module.
baseline_model.eval()

LeNet5(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=256, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

In [6]:
# ===== dynamic quantisation on baseline model
# Activation side: a placeholder observer flagged as dynamic, using quint8 with
# the full 0..255 range.
dynamic_activation_observer = PlaceholderObserver.with_args(
    dtype=torch.quint8,
    quant_min=0,
    quant_max=255,
    is_dynamic=True,
)
# Weight side: per-channel min/max observer producing symmetric qint8 weights.
dynamic_weight_observer = PerChannelMinMaxObserver.with_args(
    dtype=torch.qint8,
    qscheme=torch.per_channel_symmetric,
)
dynamic_qconfig = tq.QConfig(
    activation=dynamic_activation_observer,
    weight=dynamic_weight_observer,
)

# Apply the dynamic qconfig to every module, then display the resulting mapping.
qconfig_mapping = QConfigMapping().set_global(dynamic_qconfig)
qconfig_mapping

QConfigMapping (
 global_qconfig
  QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.PlaceholderObserver'>, dtype=torch.quint8, quant_min=0, quant_max=255, is_dynamic=True){}, weight=functools.partial(<class 'torch.ao.quantization.observer.PerChannelMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_channel_symmetric){})
 object_type_qconfigs
  OrderedDict()
 module_name_regex_qconfigs
  OrderedDict()
 module_name_qconfigs
  OrderedDict()
 module_name_object_type_order_qconfigs
  OrderedDict()
)

In [7]:
def quant_model_prep(example_input, model, qconfig_mapping):
    """Return an FX-prepared copy of ``model``.

    The model is deep-copied first, so the object handed to ``prepare_fx`` is
    never the caller's own module.

    Args:
        example_input: Tuple of example positional inputs forwarded to ``prepare_fx``.
        model: Module to prepare for quantization.
        qconfig_mapping: ``QConfigMapping`` forwarded to ``prepare_fx``.

    Returns:
        The result of ``prepare_fx`` for the copied model.
    """
    model_copy = copy.deepcopy(model)
    prepared_model = prepare_fx(model_copy, qconfig_mapping, example_input)
    return prepared_model

In [8]:
# Example positional inputs for FX tracing: a one-element tuple with a random
# 1x1x28x28 tensor.
example_input = (torch.randn(1, 1, 28, 28),)
# No data is run between prepare and convert for the dynamic configuration.
prepared_dy_model = quant_model_prep(example_input, baseline_model, qconfig_mapping)
quantized_dy_model = convert_fx(prepared_dy_model)

In [9]:
# Persist the dynamically quantized model in two forms: its state_dict and the
# whole pickled module.
dy_weights_file = saved_model_path / 'lenet5_base_dy_quant_weights.pth'
dy_model_file = saved_model_path / 'lenet5_base_dy_quant_model.pth'
torch.save(quantized_dy_model.state_dict(), dy_weights_file)
torch.save(quantized_dy_model, dy_model_file)

In [10]:
from utils import evaluate_model
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [11]:
# Preprocessing applied to every MNIST image: convert to a tensor, then normalize
# with mean 0.1307 and std 0.3081 (presumably the MNIST training-set statistics
# — TODO confirm).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
# download=True fetches the files into ./data when they are missing and does
# nothing when they already exist. Without it this cell fails on a fresh
# checkout, because the training split is only downloaded further down the
# notebook.
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=64)

In [12]:
# Number of batches the test loader yields per pass.
len(test_loader)

157

In [13]:
# Round-trip check: reload the pickled dynamically quantized module from disk.
# The file stores a whole module rather than only tensors, hence weights_only=False.
dy_model_checkpoint = saved_model_path / "lenet5_base_dy_quant_model.pth"
quantized_dy_model = torch.load(dy_model_checkpoint, map_location="cpu", weights_only=False)
# Switch to eval mode; as the last expression of the cell this also displays the module.
quantized_dy_model.eval()

  device=storage.device,


GraphModule(
  (conv1): ConvReLU2d(
    (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
  )
  (conv2): ConvReLU2d(
    (0): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
  )
  (fc1): DynamicQuantizedLinearReLU(in_features=256, out_features=120, dtype=torch.qint8, qscheme=torch.per_channel_affine)
  (fc2): DynamicQuantizedLinearReLU(in_features=120, out_features=84, dtype=torch.qint8, qscheme=torch.per_channel_affine)
  (fc3): DynamicQuantizedLinear(in_features=84, out_features=10, dtype=torch.qint8, qscheme=torch.per_channel_affine)
)

In [14]:
# Evaluate the float baseline on the MNIST test loader to get the reference metrics.
# evaluate_model comes from the local utils module; the third argument looks
# like a model-name label — TODO confirm against utils.
evaluate_model(baseline_model, test_loader, 'lenet5')

100%|██████████████| 157/157 [00:00<00:00, 180.85it/s, Loss=0.0589, Top1=98.14%]


{'top1_acc': 0.9814,
 'top5_acc': 0.9999,
 'total_inference_time': 0.8684639930725098,
 'average_loss': 3.7494788712623177,
 'total_batches': 157,
 'all_losses': [],
 'all_top1_acc': [],
 'inference_times': [],
 'true_labels': [],
 'predicted_labels': []}

In [15]:
# Evaluate the dynamically quantized model on the same test loader, for
# comparison with the float baseline.
evaluate_model(quantized_dy_model, test_loader, 'lenet5')

100%|██████████████| 157/157 [00:00<00:00, 180.63it/s, Loss=0.0589, Top1=98.12%]


{'top1_acc': 0.9812,
 'top5_acc': 0.9999,
 'total_inference_time': 0.8693900108337402,
 'average_loss': 3.7515113297449485,
 'total_batches': 157,
 'all_losses': [],
 'all_top1_acc': [],
 'inference_times': [],
 'true_labels': [],
 'predicted_labels': []}

In [16]:
# Batch size the test loader was constructed with.
test_loader.batch_size

64

In [17]:
# ======== Static Quantisation of the baseline model
# Stock qnnpack qconfig; in this cell it is only printed, not applied.
qconfig = get_default_qconfig("qnnpack")

# Config that is actually applied: histogram-based activation observer
# (reduce_range disabled) and per-channel symmetric qint8 weight observer.
static_activation_observer = HistogramObserver.with_args(reduce_range=False)
static_weight_observer = PerChannelMinMaxObserver.with_args(
    dtype=torch.qint8,
    qscheme=torch.per_channel_symmetric,
)
custom_config = tq.QConfig(
    activation=static_activation_observer,
    weight=static_weight_observer,
)
print(qconfig)

# Use custom_config globally and also register it explicitly for Conv2d,
# Linear and ReLU, in that order.
qconfig_mapping = QConfigMapping().set_global(custom_config)
for module_type in (torch.nn.Conv2d, torch.nn.Linear, torch.nn.ReLU):
    qconfig_mapping.set_object_type(module_type, custom_config)

QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.HistogramObserver'>, reduce_range=False){}, weight=functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric){})


In [18]:
# Show the float model that is about to be prepared.
print(baseline_model)
# Example positional inputs for FX tracing: a one-element tuple with a random
# 1x1x28x28 tensor.
example_inputs = (torch.randn(1, 1, 28, 28),)
# Prepared copy of the baseline, still to be calibrated and converted.
base_st_quant_model_p = quant_model_prep(
    example_input=example_inputs,
    model=baseline_model,
    qconfig_mapping=qconfig_mapping,
)

LeNet5(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=256, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)


In [19]:
# MNIST training split, downloaded into ./data when missing.
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
# Calibration uses only the first 1000 training samples, served in batches of 64.
calibration_subset = torch.utils.data.Subset(train_dataset, range(1000))
calibration_loader = DataLoader(calibration_subset, batch_size=64)

In [20]:
def calibrate(model, data_loader):
    """Run every batch of ``data_loader`` through ``model`` without gradients.

    The model is put in eval mode and its outputs are discarded; only the
    forward passes matter. A tqdm bar advances by one per batch.

    Args:
        model: Callable module to feed the batches to.
        data_loader: Sized iterable yielding ``(inputs, labels)`` pairs; the
            labels are ignored.
    """
    model.eval()
    progress = tqdm(total=len(data_loader))
    with torch.no_grad(), progress:
        for batch in data_loader:
            inputs, _labels = batch
            model(inputs)
            progress.update(1)

In [21]:
# Feed the calibration batches through the prepared model before it is converted.
calibrate(base_st_quant_model_p, calibration_loader)

100%|██████████████████████████████████████████| 16/16 [00:00<00:00, 118.85it/s]


In [22]:
# Convert the calibrated, prepared model into its quantized form.
base_st_quant_model = convert_fx(base_st_quant_model_p)

In [23]:
# Print the converted GraphModule to inspect its quantized modules.
print(base_st_quant_model)

GraphModule(
  (conv1): QuantizedConvReLU2d(1, 6, kernel_size=(5, 5), stride=(1, 1), scale=0.02737203985452652, zero_point=0)
  (conv2): QuantizedConvReLU2d(6, 16, kernel_size=(5, 5), stride=(1, 1), scale=0.047422222793102264, zero_point=0)
  (fc1): QuantizedLinearReLU(in_features=256, out_features=120, scale=0.07808518409729004, zero_point=0, qscheme=torch.per_channel_affine)
  (fc2): QuantizedLinearReLU(in_features=120, out_features=84, scale=0.06925788521766663, zero_point=0, qscheme=torch.per_channel_affine)
  (fc3): QuantizedLinear(in_features=84, out_features=10, scale=0.1696537584066391, zero_point=137, qscheme=torch.per_channel_affine)
)



def forward(self, x):
    conv1_input_scale_0 = self.conv1_input_scale_0
    conv1_input_zero_point_0 = self.conv1_input_zero_point_0
    quantize_per_tensor = torch.quantize_per_tensor(x, conv1_input_scale_0, conv1_input_zero_point_0, torch.quint8);  x = conv1_input_scale_0 = conv1_input_zero_point_0 = None
    conv1 = self.conv1(quantize_p

In [24]:
# Persist the statically quantized model in two forms: its state_dict and the
# whole pickled module.
st_weights_file = saved_model_path / 'lenet5_base_st_quant_weights.pth'
st_model_file = saved_model_path / 'lenet5_base_st_quant_model.pth'
torch.save(base_st_quant_model.state_dict(), st_weights_file)
torch.save(base_st_quant_model, st_model_file)

In [25]:
# Evaluate the statically quantized model on the same test loader as the
# baseline and dynamic runs.
evaluate_model(base_st_quant_model, test_loader, 'lenet5')

100%|██████████████| 157/157 [00:00<00:00, 185.13it/s, Loss=0.0589, Top1=98.11%]


{'top1_acc': 0.9811,
 'top5_acc': 0.9999,
 'total_inference_time': 0.8482372760772705,
 'average_loss': 3.753511188885458,
 'total_batches': 157,
 'all_losses': [],
 'all_top1_acc': [],
 'inference_times': [],
 'true_labels': [],
 'predicted_labels': []}