# Implementation in PyTorch

### Static Quantization Example

This example applies post-training static quantization to a simple PyTorch model using the FBGEMM (Facebook GEneral Matrix Multiplication) backend. The first step is to define a quantization-ready model.

The model needs **QuantStub** and **DeQuantStub** layers to mark the points where tensors are quantized and dequantized.

These stubs define the quantization boundaries of the model: everything between them runs on quantized tensors. The calibration step that follows collects the statistics needed to quantize activations accurately.
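As a rough illustration of what happens numerically at these two boundaries, the following pure-Python sketch applies affine quantization and dequantization to a few values. The helper names `quantize` and `dequantize` are hypothetical and not part of PyTorch. The scale and zero-point are the values printed for the `quant` module in the model summary at the end of this notebook, and the unsigned 8-bit range 0..255 is an assumption made for simplicity.

```python
def quantize(x: float, scale: float, zero_point: int,
             qmin: int = 0, qmax: int = 255) -> int:
    """Map a float to an integer code (affine quantization), clamped to [qmin, qmax]."""
    q = round(x / scale) + zero_point
    return max(qmin, min(qmax, q))


def dequantize(q: int, scale: float, zero_point: int) -> float:
    """Map an integer code back to an approximate float."""
    return (q - zero_point) * scale


# Values taken from the printed `quant` module; they change from run to run.
scale, zero_point = 0.0224, 75

for x in (-1.0, 0.0, 0.5, 10.0):
    q = quantize(x, scale, zero_point)
    x_hat = dequantize(q, scale, zero_point)
    print(f"{x:6.2f} -> q={q:3d} -> {x_hat:7.4f}")
```

Values inside the representable range come back with an error of at most half a scale step, while values outside it (such as `10.0` here) are clamped to the nearest end of the range.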

#### FBGEMM (Facebook GEneral Matrix Multiplication)

FBGEMM is a high-performance, low-precision matrix multiplication library that PyTorch uses as a backend for quantized operators on x86 machines.


## Load Modules

In [19]:
import torch
import torch.nn as nn
import torch.quantization

import warnings
warnings.filterwarnings('ignore')

## Define a simple neural network model

In [20]:
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.linear1 = nn.Linear(10, 20)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(20, 5)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        x = self.dequant(x)
        return x

## Instantiate and prepare the model for quantization

#### `torch.quantization.get_default_qconfig('fbgemm')` 

Returns the default quantization configuration **(QConfig)** for the **FBGEMM** backend in PyTorch.

This default QConfig is intended for **post-training** static quantization when targeting x86 CPUs. For quantization-aware training, PyTorch provides a separate helper, `torch.quantization.get_default_qat_qconfig`. The default configuration is a common starting point for gaining performance from reduced precision while keeping accuracy loss small.



## Prepare the Model for Quantization
Assign the default FBGEMM quantization configuration to the model, then call `prepare`, which inserts observers that collect statistics for quantization.



In [21]:
model_fp32 = SimpleModel()
model_fp32.eval() # Set model to evaluation mode for quantization

# Specify FBGEMM backend for quantization
model_fp32.qconfig = torch.quantization.get_default_qconfig('fbgemm')

# Prepare the model for static quantization
# This inserts observers to record activation statistics
model_prepared = torch.quantization.prepare(model_fp32, inplace=False)

## Calibrate the Model

Run the prepared model on a representative dataset. The observers record statistics of the activations (such as their min/max range), which are used to determine the quantization scale factors and zero-points.

In [22]:
# Create dummy data for calibration (a real workflow would use representative samples)
dummy_input = torch.randn(1, 10)

# Calibrate the model by running it with dummy data
with torch.inference_mode():
    model_prepared(dummy_input)
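To make the link between observed statistics and quantization parameters concrete, here is a minimal pure-Python sketch that derives a scale and zero-point from a min/max range. The function name `affine_qparams` is hypothetical, and the plain min/max rule with an unsigned 0..255 range is an assumption. PyTorch's built-in observers may use more refined estimates (for example, histogram-based) and a reduced integer range, which this sketch ignores.

```python
def affine_qparams(x_min: float, x_max: float,
                   qmin: int = 0, qmax: int = 255) -> tuple[float, int]:
    """Derive (scale, zero_point) for affine quantization from an observed range."""
    # Extend the range to include 0.0 so that zero is exactly representable.
    x_min = min(x_min, 0.0)
    x_max = max(x_max, 0.0)

    scale = (x_max - x_min) / (qmax - qmin)
    if scale == 0.0:
        # Degenerate case: every observed value was zero.
        return 1.0, 0

    # The zero-point is the integer code that represents the real value 0.0.
    zero_point = round(qmin - x_min / scale)
    zero_point = max(qmin, min(qmax, zero_point))
    return scale, zero_point


# Hypothetical activations recorded during calibration.
observed = [-0.8, 0.1, 2.3, 1.7]
scale, zero_point = affine_qparams(min(observed), max(observed))
print(f"scale={scale:.6f}, zero_point={zero_point}")
```

Because the parameters depend entirely on the values seen during calibration, a single random input gives only a rough estimate of the activation range; more representative data generally yields better parameters.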

## Convert to Quantized Model
Convert the prepared model to its quantized version using the collected statistics. This replaces floating-point operations with their quantized integer equivalents.

In [23]:
# Convert the prepared model to a quantized model
model_quantized = torch.quantization.convert(model_prepared, inplace=False)
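To give an idea of what a quantized integer equivalent of a linear layer looks like, the sketch below computes one output value using integer accumulation followed by rescaling. It is a conceptual illustration, not FBGEMM's actual kernel: the function name `quantized_dot`, the symmetric weights (zero-point 0), and the floating-point rescaling step are all assumptions made for readability.

```python
def quantized_dot(q_x, q_w, bias, s_x, z_x, s_w, s_y, z_y, qmin=0, qmax=255):
    """One output of a quantized linear layer (conceptual sketch).

    q_x: unsigned integer activation codes with scale s_x and zero-point z_x
    q_w: signed integer weight codes with scale s_w (zero-point assumed 0)
    bias: float bias; s_y, z_y: output scale and zero-point
    """
    # Integer multiply-accumulate; real kernels keep this in a 32-bit accumulator.
    acc = sum((qx - z_x) * qw for qx, qw in zip(q_x, q_w))

    # Rescale the accumulator to a real value and add the bias.
    y_real = s_x * s_w * acc + bias

    # Requantize to the output's integer range.
    q_y = round(y_real / s_y) + z_y
    return max(qmin, min(qmax, q_y))


# Hypothetical codes: real activations are (q_x - 75) * 0.02, real weights are q_w * 0.01.
q_y = quantized_dot(q_x=[85, 75, 65], q_w=[50, -20, 10], bias=0.02,
                    s_x=0.02, z_x=75, s_w=0.01, s_y=0.01, z_y=100)
print(q_y, (q_y - 100) * 0.01)  # integer code and its dequantized value
```

The expensive part of the computation, the multiply-accumulate loop, operates only on small integers. Scales and zero-points enter only when the accumulated sum is converted to the output's integer range.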

In [24]:
model_quantized

SimpleModel(
  (quant): Quantize(scale=tensor([0.0224]), zero_point=tensor([75]), dtype=torch.quint8)
  (linear1): QuantizedLinear(in_features=10, out_features=20, scale=0.020005200058221817, zero_point=71, qscheme=torch.per_channel_affine)
  (relu): ReLU()
  (linear2): QuantizedLinear(in_features=20, out_features=5, scale=0.003096354193985462, zero_point=122, qscheme=torch.per_channel_affine)
  (dequant): DeQuantize()
)
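The printed `qscheme=torch.per_channel_affine` indicates that the linear layers' weights carry quantization parameters per channel rather than one set for the whole tensor. The pure-Python sketch below shows the idea with one scale per output row. The helper names are hypothetical, and symmetric signed 8-bit scaling is an assumption chosen for simplicity; the exact observer and scheme PyTorch applies are not shown here.

```python
def per_channel_scales(weight_rows, qmax=127):
    """One scale per output row, chosen so the row's largest magnitude maps to qmax."""
    scales = []
    for row in weight_rows:
        max_abs = max(abs(w) for w in row)
        scales.append(max_abs / qmax if max_abs > 0 else 1.0)
    return scales


def quantize_rows(weight_rows, scales, qmin=-128, qmax=127):
    """Quantize each row with its own scale (symmetric, zero-point 0)."""
    return [
        [max(qmin, min(qmax, round(w / s))) for w in row]
        for row, s in zip(weight_rows, scales)
    ]


# Hypothetical weights: the second row is much smaller in magnitude than the first.
weights = [[0.5, -1.27], [0.005, 0.02]]
scales = per_channel_scales(weights)
print(scales)
print(quantize_rows(weights, scales))
```

With a single scale shared by the whole tensor, the small second row would collapse onto only a couple of integer levels. Giving each row its own scale lets both rows use the full integer range.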