---
title: "quantizing explorations in pytorch"
description: "quantization..."
author: "me"
date: 2023-11-01
draft: true
---

In [1]:
# python 3.11 torch 2.1 torchvision 0.16 

In [1]:
# demo simple QAT
# https://pytorch.org/docs/stable/quantization.html#quantization-aware-training-for-static-quantization

import torch
import torchvision
from copy import deepcopy

In [17]:

# define a floating point model where some layers could benefit from QAT
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # QuantStub converts tensors from floating point to quantized
        self.quant = torch.quantization.QuantStub()
        self.conv1 = torch.nn.Conv2d(1, 1, 1)
        self.bn1 = torch.nn.BatchNorm2d(1)
        # self.bn1 = torchvision.ops.FrozenBatchNorm2d(1)
        self.relu1 = torch.nn.ReLU()
        # add another layer for perf comparison
        self.conv2 = torch.nn.Conv2d(1, 1, 1)
        self.bn2 = torch.nn.BatchNorm2d(1)
        # self.bn2 = torchvision.ops.FrozenBatchNorm2d(1)
        self.relu2 = torch.nn.ReLU()
        # DeQuantStub converts tensors from quantized to floating point
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.dequant(x)
        return x


In [18]:
# create a model instance
model_fp32 = M()
model_fp32

M(
  (quant): QuantStub()
  (conv1): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))
  (bn1): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu1): ReLU()
  (conv2): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))
  (bn2): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu2): ReLU()
  (dequant): DeQuantStub()
)

In [19]:
model = deepcopy(model_fp32)

model.qconfig = torch.quantization.get_default_qconfig('fbgemm')  # QAT config

In [20]:
# model must be set to eval for fusion to work
model.eval()

# fusions: [Conv, Relu], [Conv, BatchNorm], [Conv, BatchNorm, Relu], [Linear, Relu]
torch.quantization.fuse_modules(model, [['conv1', 'bn1', 'relu1'], ['conv2', 'bn2', 'relu2']], inplace=True)
model  # note 'Identity()'' where 'bn' and 'relu' modules were

M(
  (quant): QuantStub()
  (conv1): ConvReLU2d(
    (0): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))
    (1): ReLU()
  )
  (bn1): Identity()
  (relu1): Identity()
  (conv2): ConvReLU2d(
    (0): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))
    (1): ReLU()
  )
  (bn2): Identity()
  (relu2): Identity()
  (dequant): DeQuantStub()
)

In [21]:
# back to train for QAT
model.train()

torch.quantization.prepare_qat(model, inplace=True)

# train ...

model.eval()
torch.quantization.convert(model, inplace=True)
model

M(
  (quant): Quantize(scale=tensor([1.]), zero_point=tensor([0]), dtype=torch.quint8)
  (conv1): QuantizedConvReLU2d(1, 1, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
  (bn1): Identity()
  (relu1): Identity()
  (conv2): QuantizedConvReLU2d(1, 1, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)
  (bn2): Identity()
  (relu2): Identity()
  (dequant): DeQuantize()
)

In [22]:
from time import perf_counter

n = 1000

start = perf_counter()
for _ in range(n):
    model_fp32(torch.rand(8, 1, 32, 32))
print(f"float model avg time: {(perf_counter() - start) / n}")

start = perf_counter()
for _ in range(n):
    model(torch.rand(8, 1, 32, 32))
print(f"quant model avg time: {(perf_counter() - start) / n}")

# roughly 66%

float model avg time: 0.000531349700000078
quant model avg time: 0.00043306860000006964


In [None]:
# GRAPH FX MODE QUANTIZATION !!!
# https://pytorch.org/tutorials/prototype/fx_graph_mode_ptq_static.html
# post-training-static
# can transform syntax '+=' to proper quantized method which normally causes eager mode to fail
# need pytorch 1.11+

In [336]:
from torchvision.models import resnet18
from torchvision import datasets, transforms

resnet_18 = resnet18()

dataset = datasets.CIFAR10(root="data", download=False, train=False, transform=transforms.ToTensor())  # smaller test data
data_loader = torch.utils.data.DataLoader(dataset=dataset)

In [338]:
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.quantization import QConfigMapping

resnet_18.eval()
# The old 'fbgemm' is still available but 'x86' is the recommended default.
qconfig = get_default_qconfig("fbgemm")
qconfig_mapping = QConfigMapping().set_global(qconfig)

def calibrate(model, data_loader):
    model.eval()
    with torch.no_grad():
        for image, target in data_loader:
            model(image)
            
example_inputs = (next(iter(data_loader))[0]) # get an example input
prepared_model = prepare_fx(resnet_18, qconfig_mapping, example_inputs)  # fuse modules and insert observers
calibrate(prepared_model, data_loader)  # run calibration on sample data
quantized_model = convert_fx(prepared_model)  # convert the calibrated model to a quantized model

In [345]:
# compare with original resnet

resnet_18_original = resnet18()
resnet_18_original.eval()

n = 1

start = perf_counter()
for _ in range(n):
    for img, label in data_loader:
        resnet_18_original(img)
print(f"float model avg time: {(perf_counter() - start) / (n * len(data_loader))}")

start = perf_counter()
for _ in range(n):
    for img, label in data_loader:
        quantized_model(img)
print(f"quant model avg time: {(perf_counter() - start) / (n * len(data_loader))}")

# roughly 66%

quant model avg time: 0.005655732270001317
float model avg time: 0.00878128740999964


In [343]:
# compare with torch jit

torch.jit.save(torch.jit.script(quantized_model), "./data/quant_jit_model.pth")

quantized_jit_model = torch.jit.load("./data/quant_jit_model.pth", map_location=torch.device('cpu'))

In [344]:
start = perf_counter()
for _ in range(n):
    for img, label in data_loader:
        quantized_jit_model(img)
print(f"quant jit model avg time: {(perf_counter() - start) / (n * len(data_loader))}")

# 29%

quant jit model avg time: 0.0025585005400003864
