# Practical Quantization in PyTorch

## FUNDAMENTALS OF QUANTIZATION

Quantization has roots in information compression; in deep networks, it refers to reducing the numerical precision of their weights and/or activations.

Overparameterized DNNs have more degrees of freedom, which makes them good candidates for information compression. When you quantize a model, two things generally happen: the model gets smaller and it runs more efficiently.

- Hardware vendors explicitly support faster processing of 8-bit data than of 32-bit data, resulting in higher throughput.
- A smaller model has lower memory footprint and power consumption, crucial for deployment at the edge.

### Mapping function

The mapping function is what you might guess - a function that maps values from floating-point to integer space. A commonly used mapping function is a linear transformation given by:

$$Q(r)=round(r/S + Z)$$

where `r` is the input, and `S` and `Z` are the `quantization parameters`.

To convert back to floating-point space, the inverse function is:

$$\hat r=(Q(r) - Z)*S$$

In general $\hat r \neq r$, and their difference constitutes the quantization error.
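Here is a worked round trip through the two formulas above. The values of `S` and `Z` are hypothetical and were chosen by hand. The clamp to the unsigned 8-bit range [0, 255] is an assumption: the formula leaves it implicit, but any real 8-bit implementation has to apply it.

```python
import torch

# Hypothetical quantization parameters, picked by hand for illustration
S, Z = 0.02, 128          # scale and zero-point
qmin, qmax = 0, 255       # unsigned 8-bit output range

r = torch.tensor([-1.5, -0.013, 0.0, 0.5, 1.234])

# Q(r) = round(r / S + Z), clamped so it fits into 8 bits
q = torch.clamp(torch.round(r / S + Z), qmin, qmax).to(torch.uint8)

# r_hat = (Q(r) - Z) * S
r_hat = (q.to(torch.float32) - Z) * S

# Quantization error; for inputs inside the representable range it is at most S / 2
err = (r - r_hat).abs()
print(q, r_hat, err.max())
```

The input 0.0 maps to the integer `Z` and comes back as exactly 0. The other inputs land on the nearest multiple of `S`.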

### Quantization Parameters

The mapping function is parameterized by the scaling factor `S` and zero-point `Z`.

`S` is simply the ratio of the input range to the output range:

$$ S = \dfrac {\beta - \alpha} {\beta_q - \alpha_q}$$

where [$\alpha$, $\beta$] is the clipping range of the input, i.e. the boundaries of permissible inputs, and [$\alpha_q$, $\beta_q$] is the range in quantized output space that it is mapped to. For 8-bit quantization, the output range satisfies:

$$ \beta_q - \alpha_q \le 2^8 - 1$$

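`Z` acts as a bias that ensures a 0 in the input space maps exactly to an integer in the quantized space. That integer is `Z` itself once `Z` is rounded to an integer, since $Q(0) = \text{round}(Z) = Z$, so zero is represented without quantization error:

$$Z=-(\dfrac {\alpha}{S} - \alpha_q)$$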
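To see both parameters in action, we can compute `S` and `Z` from a clipping range. This is a minimal sketch: the range [-1.2, 3.4] is a hypothetical observed range, and the output range [0, 255] assumes unsigned 8-bit quantization.

```python
import torch

# Hypothetical clipping range observed for some activation tensor
alpha, beta = -1.2, 3.4
alpha_q, beta_q = 0, 255  # unsigned 8-bit output range

# S = (beta - alpha) / (beta_q - alpha_q)
S = (beta - alpha) / (beta_q - alpha_q)
# Z = -(alpha / S - alpha_q), rounded so the zero-point is itself an integer
Z = int(round(-(alpha / S - alpha_q)))

def quantize(r):
    # Q(r) = round(r / S + Z), clamped to the output range
    return torch.clamp(torch.round(r / S + Z), alpha_q, beta_q)

def dequantize(q):
    # r_hat = (Q(r) - Z) * S
    return (q - Z) * S

r = torch.tensor([alpha, 0.0, beta])
q = quantize(r)
print(S, Z, q, dequantize(q))
```

The ends of the clipping range land on the ends of the output range, and the real value 0 maps to the integer `Z`. It therefore dequantizes back to exactly 0, which matters for operations such as zero padding.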

### Calibration

The process of choosing the input clipping range is known as calibration. The simplest technique (also the default in PyTorch) is to record the running minimum and maximum values and assign them to $\alpha$ and $\beta$. `TensorRT` additionally uses entropy minimization (KL divergence), mean-square-error minimization, or percentiles of the input range.

In PyTorch, `Observer` modules collect statistics on the input values and calculate the qparams `S`, `Z`. Different calibration schemes result in different quantized outputs, and it’s best to empirically verify which scheme works best for your application and architecture (more on that later).
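The two simplest calibration schemes can be sketched by hand: a running min/max, and a moving average of per-batch min/max values. We can then check the hand-computed ranges against `MinMaxObserver` and `MovingAverageMinMaxObserver`. This sketch assumes that the observers keep the tracked range in their `min_val`/`max_val` buffers and that the moving average uses the update rule shown in the comments. Random batches stand in for real calibration data.

```python
import torch
from torch.quantization.observer import MinMaxObserver, MovingAverageMinMaxObserver

torch.manual_seed(0)
batches = [torch.randn(3, 4) for _ in range(5)]  # stand-in calibration data

c = 0.01  # smoothing factor for the moving average
run_min, run_max = float("inf"), float("-inf")
ema_min = ema_max = None
for x in batches:
    cur_min, cur_max = x.min().item(), x.max().item()
    # Running min/max: the clipping range can only widen
    run_min, run_max = min(run_min, cur_min), max(run_max, cur_max)
    # Moving average: the first batch initializes the range,
    # later batches pull it towards their own min/max by a factor c
    if ema_min is None:
        ema_min, ema_max = cur_min, cur_max
    else:
        ema_min += c * (cur_min - ema_min)
        ema_max += c * (cur_max - ema_max)

# The same statistics collected by PyTorch's observers
mm = MinMaxObserver()
ma = MovingAverageMinMaxObserver(averaging_constant=c)
for x in batches:
    mm(x)
    ma(x)

print(run_min, run_max, mm.min_val.item(), mm.max_val.item())
print(ema_min, ema_max, ma.min_val.item(), ma.max_val.item())
```

The moving-average range is never wider than the running min/max range. As a result, outliers in a single batch have less influence on the final qparams.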

```python
import torch
from torch.quantization.observer import MinMaxObserver, MovingAverageMinMaxObserver, HistogramObserver

C, L = 3, 4
normal = torch.distributions.normal.Normal(0, 1)
inputs = [normal.sample((C, L)), normal.sample((C, L))]
print(inputs)
```

```
[tensor([[-1.3661,  0.6686, -0.7628, -0.1713],
        [ 0.5461, -0.9687,  0.5894, -0.1433],
        [ 0.0744,  1.8963, -0.9171,  0.0334]]), tensor([[ 0.3307, -1.0928, -0.8456,  1.2337],
        [ 0.6180, -1.6796, -0.9699,  0.8957],
        [ 0.5770, -1.1889,  0.9793,  1.6308]])]
```


```python
observers = [MinMaxObserver(), MovingAverageMinMaxObserver(), HistogramObserver()]
for obs in observers:
  for x in inputs:
    obs(x)
  print(obs.__class__.__name__, obs.calculate_qparams())
```

```
MinMaxObserver (tensor([0.0185]), tensor([152], dtype=torch.int32))
MovingAverageMinMaxObserver (tensor([0.0139]), tensor([137], dtype=torch.int32))
HistogramObserver (tensor([0.0128]), tensor([159], dtype=torch.int32))
```


### Affine and Symmetric Quantization Schemes

__Affine or asymmetric quantization__ schemes assign the input range to the min and max observed values. Affine schemes generally offer tighter clipping ranges and are useful for quantizing non-negative activations (you don’t need the input range to contain negative values if your input tensors are never negative). The range is calculated as $\alpha=min(r)$, $\beta=max(r)$. Affine quantization leads to more computationally expensive inference when used for weight tensors.

__Symmetric quantization__ schemes center the input range around 0, eliminating the need to calculate a zero-point offset. The range is calculated as:

$$ -\alpha = \beta = max(|max(r)|, |min(r)|)$$

For skewed signals (like non-negative activations), this can result in poor quantization resolution, because the clipping range includes values that never show up in the input.

In PyTorch, you can specify affine or symmetric schemes while initializing the Observer. Note that not all observers support both schemes.

```python
for qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric]:
  obs = MovingAverageMinMaxObserver(qscheme=qscheme)
  for x in inputs:
    obs(x)
  print(f"Qscheme: {qscheme} | {obs.calculate_qparams()}")
```

```
Qscheme: torch.per_tensor_affine | (tensor([0.0128]), tensor([107], dtype=torch.int32))
Qscheme: torch.per_tensor_symmetric | (tensor([0.0149]), tensor([128]))
```
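We can check the resolution loss from a symmetric range on a skewed signal by hand, using the range formulas above. This is a minimal sketch under two assumptions: a ReLU output stands in for a non-negative activation, and both schemes map onto the same 255-step output range.

```python
import torch

torch.manual_seed(0)
# Hypothetical non-negative signal, e.g. the output of a ReLU
r = torch.relu(torch.randn(1000))

# Affine: clip to the observed [min(r), max(r)]
alpha_aff, beta_aff = r.min().item(), r.max().item()
S_aff = (beta_aff - alpha_aff) / 255

# Symmetric: clip to [-m, m] with m = max(|max(r)|, |min(r)|)
m = max(abs(r.max().item()), abs(r.min().item()))
S_sym = (2 * m) / 255

# Half of the symmetric grid covers negative values that never occur,
# so its step size is twice as coarse for this signal
print(S_aff, S_sym, S_sym / S_aff)
```

The input's minimum is 0, so the symmetric range is exactly twice as wide as the affine one. Every quantization step is therefore twice as large.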


### Per-Tensor and Per-Channel Quantization Schemes

Quantization parameters can be calculated for the layer’s entire weight tensor as a whole, or separately for each channel. In per-tensor quantization, the same clipping range is applied to all the channels in a layer; in per-channel quantization, each channel gets its own clipping range.

<img src="fig/per-channel-tensor.svg">

For weight quantization, symmetric per-channel quantization provides better accuracy. Per-tensor quantization performs poorly, possibly due to high variance in conv weights across channels from batchnorm folding.
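This sketch shows why per-channel scales help when channel magnitudes differ. It uses a hypothetical conv weight whose output channels are scaled very differently. The symmetric int8 grid is assumed to be restricted to [-127, 127], and both variants are fake-quantized by hand rather than through a PyTorch observer.

```python
import torch

torch.manual_seed(0)
# Hypothetical conv weight: 4 output channels with very different magnitudes
w = torch.randn(4, 3, 3, 3) * torch.tensor([0.01, 0.1, 1.0, 0.5]).view(4, 1, 1, 1)

def fake_quant_int8(t, scale):
    # Symmetric signed 8-bit grid, assumed restricted to [-127, 127]
    return torch.clamp(torch.round(t / scale), -127, 127) * scale

# Per-tensor: a single scale derived from the largest |weight| in the whole tensor
s_tensor = w.abs().max() / 127
# Per-channel (ch_axis=0): one scale per output channel
s_channel = w.abs().amax(dim=(1, 2, 3), keepdim=True) / 127

# Mean absolute quantization error, reported per output channel
err_tensor = (w - fake_quant_int8(w, s_tensor)).abs().mean(dim=(1, 2, 3))
err_channel = (w - fake_quant_int8(w, s_channel)).abs().mean(dim=(1, 2, 3))
print(err_tensor)
print(err_channel)
```

Under a per-tensor scale, the small-magnitude channel is rounded onto only a handful of grid points. Its own per-channel scale gives it the full 8-bit resolution.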

```python
from torch.quantization.observer import MovingAveragePerChannelMinMaxObserver
# Calculate qparams for all 'C' channels separately
obs = MovingAveragePerChannelMinMaxObserver(ch_axis=0)
for x in inputs:
  obs(x)
print(obs.calculate_qparams())
```

```
(tensor([0.0080, 0.0061, 0.0110]), tensor([171, 159,  83], dtype=torch.int32))
```


### Backend Engine

Currently, quantized operators run on x86 machines via the `FBGEMM` backend, or use `QNNPACK` primitives on ARM machines. 

```python
backend = 'fbgemm'
qconfig = torch.quantization.get_default_qconfig(backend)
torch.backends.quantized.engine = backend
```
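Which engines are available depends on how PyTorch was built. The minimal sketch below checks `torch.backends.quantized.supported_engines` before picking a backend. The fallback order (FBGEMM first, then QNNPACK) is an assumption that mirrors the x86/ARM split described above.

```python
import torch

# Engines compiled into this PyTorch build
engines = torch.backends.quantized.supported_engines
print(engines)

# Prefer FBGEMM (x86); fall back to QNNPACK (typical for ARM builds)
backend = "fbgemm" if "fbgemm" in engines else "qnnpack"
torch.backends.quantized.engine = backend
print(torch.backends.quantized.engine)
```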

### QConfig

The `QConfig` NamedTuple stores the `Observers` and the quantization schemes used to quantize activations and weights.

Be sure to pass the `Observer` class (not the instance), or a callable that can return Observer instances. Use `with_args()` to override the default arguments.

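```python
my_qconfig = torch.quantization.QConfig(
    activation=MovingAverageMinMaxObserver.with_args(qscheme=torch.per_tensor_affine),
    # torch.qint8 is a dtype, not a qscheme: pass it as dtype and pick a per-channel qscheme
    weight=MovingAveragePerChannelMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_channel_symmetric)
)
```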
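The reason to pass a class or a `with_args()` result rather than an observer instance is that each call produces a fresh, independent observer. Each layer can then collect its own statistics. This is a small sketch using `MinMaxObserver`; the argument values are illustrative.

```python
import torch
from torch.quantization.observer import MinMaxObserver

# with_args() returns a factory, not an observer instance
factory = MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)

# Each call builds a new observer with its own min/max buffers
obs_a, obs_b = factory(), factory()
obs_a(torch.tensor([-2.0, 1.0]))
print(obs_a is obs_b, obs_a.min_val.item(), obs_b.min_val.item())  # False -2.0 inf
```

Only `obs_a` has seen data. `obs_b` still holds its initial, empty range.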

## PyTorch Quantization

PyTorch offers a few different ways to quantize your model, depending on:

- whether you prefer a flexible but manual process, or a restricted but automagic one (Eager Mode v/s FX Graph Mode);
- whether qparams for quantizing activations (layer outputs) are precomputed for all inputs, or calculated afresh with each input (static v/s dynamic);
- whether qparams are computed with or without retraining (quantization-aware training v/s post-training quantization).

FX Graph Mode automatically fuses eligible modules, inserts observers and Quant/DeQuant stubs, and returns a quantized module in two method calls (`prepare_fx` and `convert_fx`, with your calibration pass in between). However, it only works for networks that are symbolically traceable with `torch.fx`. The examples below contain the calls using Eager Mode and FX Graph Mode for comparison.

In DNNs, eligible candidates for quantization are the FP32 weights (layer parameters) and activations (layer outputs). Quantizing weights reduces the model size. Quantized activations typically result in faster inference.

As an example, the 50-layer ResNet network has ~26 million weight parameters and computes ~16 million activations in the forward pass.
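To make the size reduction concrete, this minimal sketch quantizes a single hypothetical FP32 conv weight to int8 with `torch.quantize_per_tensor` and compares its storage. The weight shape and the symmetric per-tensor scale are illustrative assumptions.

```python
import torch

torch.manual_seed(0)
w = torch.randn(128, 64, 3, 3)  # hypothetical FP32 conv weight

# Symmetric per-tensor int8 quantization of the weight
scale = w.abs().max().item() / 127
qw = torch.quantize_per_tensor(w, scale=scale, zero_point=0, dtype=torch.qint8)

fp32_bytes = w.numel() * w.element_size()    # 4 bytes per element
int8_bytes = qw.numel() * qw.element_size()  # 1 byte per element
print(fp32_bytes, int8_bytes)  # 294912 73728
```

Applied to ResNet-50's ~26 million weights, the same 4x ratio means roughly 104 MB of FP32 weights shrink to about 26 MB, ignoring the small overhead of storing qparams.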

## Post-Training Static Quantization (PTQ)

Like dynamic quantization, PTQ pre-quantizes model weights. However, instead of calibrating activations on-the-fly, it pre-calibrates the clipping range and fixes it (“static”) using validation data. Activations stay in quantized precision between operations during inference. About 100 mini-batches of representative data are sufficient to calibrate the observers. The examples below use random data for calibration for convenience; doing the same in your application will result in poor qparams.

<img src="fig/ptq-flowchart.svg">

`Module fusion` combines multiple sequential modules (e.g., [Conv2d, BatchNorm, ReLU]) into one. Fusing modules means the compiler only needs to run one kernel instead of many. This speeds things up and improves accuracy by reducing quantization error.
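Here is a small check of what fusion does to a float model in eval mode. `torch.quantization.fuse_modules` folds the BatchNorm into the conv weights and wraps the result with the ReLU in a single fused module. In this sketch, the layer sizes and the randomized BatchNorm statistics are illustrative assumptions.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()).eval()
# Give BatchNorm non-trivial running statistics so that folding actually changes the conv
model[1].running_mean.uniform_(-0.5, 0.5)
model[1].running_var.uniform_(0.5, 2.0)

# Fuse Conv2d + BatchNorm2d + ReLU; returns a fused copy since inplace defaults to False
fused = torch.quantization.fuse_modules(model, [["0", "1", "2"]])
print(fused)  # slot 0 holds the fused module, slots 1 and 2 become Identity

x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    print(torch.allclose(model(x), fused(x), atol=1e-5))
```

Before any quantization happens, the fused model computes the same outputs as the original. It just does so with one module where there were three.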

- (+) Static quantization has faster inference than dynamic quantization because it eliminates the float<->int conversion costs between layers.

- (-) Static quantized models may need regular re-calibration to stay robust against distribution drift.

```python
# Static quantization of a model consists of the following steps:

#     Fuse modules
#     Insert Quant/DeQuant stubs
#     Prepare the fused module (insert observers before and after layers)
#     Calibrate the prepared module (pass it representative data)
#     Convert the calibrated module (replace with quantized version)

import torch
from torch import nn
import copy

# Running on an x86 CPU; use "qnnpack" if running on ARM
backend = "fbgemm"
torch.backends.quantized.engine = backend

model = nn.Sequential(
    nn.Conv2d(in_channels=2, out_channels=64, kernel_size=3),
    nn.ReLU(),
    nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3),
    nn.ReLU()
)

# EAGER MODE
m = copy.deepcopy(model)
m.eval()

# Fuse modules
# fuse first Conv-ReLU pair
torch.quantization.fuse_modules(m, ['0', '1'], inplace=True)
# fuse second Conv-ReLU pair
torch.quantization.fuse_modules(m, ['2', '3'], inplace=True)

# Insert stubs
m = nn.Sequential(torch.quantization.QuantStub(),
                  *m,
                  torch.quantization.DeQuantStub())

# Prepare
m.qconfig = torch.quantization.get_default_qconfig(backend)
torch.quantization.prepare(m, inplace=True)

# Calibrate
# This example uses random data for convenience. Use representative (validation) data instead
with torch.inference_mode():
    for _ in range(10):
        x = torch.rand(1, 2, 28, 28)
        m(x)

# Convert
torch.quantization.convert(m, inplace=True)

# Check
# 1 byte instead of 4 bytes for FP32
print(m[1].weight().element_size())
```

```
1
```
### FX GRAPH

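```python
from torch.quantization import quantize_fx

m = copy.deepcopy(model)
m.eval()

qconfig_dict = {"": torch.quantization.get_default_qconfig(backend)}
# prepare_fx needs example inputs (required since PyTorch 1.13) to trace the model
example_inputs = (torch.rand(1, 2, 28, 28),)

# Prepare
model_prepared = quantize_fx.prepare_fx(m, qconfig_dict, example_inputs)

# Calibrate - Use representative (validation) data
with torch.inference_mode():
    for _ in range(10):
        x = torch.rand(1, 2, 28, 28)
        model_prepared(x)

# Quantize
model_quantized = quantize_fx.convert_fx(model_prepared)

# Check: the converted GraphModule cannot be indexed like nn.Sequential,
# so look up the first quantized layer (its weight is a method, not a tensor)
for name, mod in model_quantized.named_modules():
    if callable(getattr(mod, "weight", None)):
        print(name, mod.weight().element_size())
        break
```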

## Quantization-Aware Training (QAT)

<img src="fig/qat-flowchart.svg">

The PTQ approach is great for large models, but accuracy suffers in smaller models. QAT tackles this by including the quantization error in the training loss, thereby training an INT8-first model.

All weights and biases are stored in FP32, and backpropagation happens as usual. However, in the forward pass, quantization is internally simulated via FakeQuantize modules. They are called fake because they quantize and immediately dequantize the data, adding quantization noise similar to what might be encountered during quantized inference. The final loss thus accounts for any expected quantization errors. Optimizing on this allows the model to identify FP32 parameters such that quantizing them to INT8 does not significantly affect accuracy.

<p align="center">
<img src="fig/qat-fake-quantization.png">
</p>
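Fake quantization can be reproduced with the functional op `torch.fake_quantize_per_tensor_affine`. This minimal sketch checks that its output matches a real quantize/dequantize round trip, and it inspects the gradient. The scale, zero-point, and inputs are hypothetical values chosen for illustration. The sketch does not claim this is the exact call path used inside PyTorch's `FakeQuantize` modules.

```python
import torch

# Hypothetical qparams for unsigned 8-bit: representable range is [-1.28, 1.27]
scale, zero_point = 0.01, 128
x = torch.tensor([-1.0, -0.26, 0.0, 0.37, 0.9, 2.0], requires_grad=True)

# Fake quantization: snap to the integer grid, clamp to [0, 255], dequantize,
# all in FP32 so the result can flow through the rest of the network
x_fq = torch.fake_quantize_per_tensor_affine(x, scale, zero_point, 0, 255)

# The same values from a real quantize -> dequantize round trip
x_rt = torch.quantize_per_tensor(x.detach(), scale, zero_point, torch.quint8).dequantize()
print(torch.allclose(x_fq.detach(), x_rt))  # True

# Backward pass uses a straight-through estimator: gradient 1 where the input
# was inside the representable range, 0 where it was clamped (here: 2.0)
x_fq.sum().backward()
print(x.grad)
```

The forward pass sees the rounding and clamping noise. Meanwhile, gradients flow straight through for in-range values, which is what lets ordinary backpropagation adapt the FP32 weights to the quantization grid.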

- (+) QAT yields higher accuracies than PTQ.

- (+) Qparams can be learned during model training for more fine-grained accuracy

- (-) Retraining a model with QAT is computationally expensive and can take several hundred epochs

```python
# QAT follows the same steps as PTQ, with the exception of the training loop before you actually convert the model to its quantized version
import torch
from torch import nn

backend = "fbgemm"

m = nn.Sequential(
    nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3),
    nn.ReLU(),
    nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3),
    nn.ReLU()
)

# Fuse
# Fuse first Conv-ReLU pair
torch.quantization.fuse_modules(m, ["0", "1"], inplace=True)
# Fuse second Conv-ReLU pair
torch.quantization.fuse_modules(m, ["2", "3"], inplace=True)

# Insert stubs
m = nn.Sequential(torch.quantization.QuantStub(),
                  *m,
                  torch.quantization.DeQuantStub())

# Prepare
m.train()
# Use the QAT qconfig: it inserts FakeQuantize modules that simulate quantization while training
m.qconfig = torch.quantization.get_default_qat_qconfig(backend)
torch.quantization.prepare_qat(m, inplace=True)

# Training loop (random inputs and targets, for illustration only)
n_epochs = 10
opt = torch.optim.SGD(m.parameters(), lr=0.1)
loss_fn = lambda out, tgt: torch.pow(tgt-out, 2).mean()
for epoch in range(n_epochs):
    x = torch.rand(10, 3, 28, 28)
    out = m(x)
    loss = loss_fn(out, torch.rand_like(out))
    opt.zero_grad()
    loss.backward()
    opt.step()

# Convert
m.eval()
torch.quantization.convert(m, inplace=True)
```

```
Sequential(
  (0): Quantize(scale=tensor([0.0081]), zero_point=tensor([0]), dtype=torch.quint8)
  (1): QuantizedConvReLU2d(3, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.012540503405034542, zero_point=0)
  (2): Identity()
  (3): QuantizedConvReLU2d(64, 128, kernel_size=(3, 3), stride=(1, 1), scale=0.0053305355831980705, zero_point=0)
  (4): Identity()
  (5): DeQuantize()
)
```