# Quantization Aware Training quick implementation using Pytorch

This small and simple notebook will guide you to quantize a model using QAT with only a few more line of code.

## Imports

In [59]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.quantization import QuantStub, DeQuantStub
import numpy as np

## Base class for the model

In order to achieve quantization we can use an existing model architecture and simply create an other class adding only 2 pytorch functions

In [60]:
class Model(torch.nn.Module): # Define a simple neural network
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(input_dim, hidden_dim)
        self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim//2)
        self.linear3 = torch.nn.Linear(hidden_dim//2, 10)
        self.output = torch.nn.Linear(10, output_dim)

    def forward(self, x): # Forward pass
        x = torch.nn.functional.relu(self.linear1(x))
        x = torch.nn.functional.relu(self.linear2(x))
        x = torch.nn.functional.relu(self.linear3(x))
        x = self.output(x)
        return x
    
class QuantizedModel(nn.Module): # Define a quantized model
    def __init__(self, model):
        super(QuantizedModel, self).__init__() # Define a quantized model using QuantStub, DeQuantStub
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.model = model

    def forward(self, x): # Forward pass using QuantStub, DeQuantStub to quantize and dequantize the input, output
        x = self.quant(x)
        x = self.model(x)
        x = self.dequant(x)
        return x

## Base functions
We define the basics to generate data and to train the model, note that you need nothing more than usual here for the QAT to work.

In [61]:
def get_train_test_data(input_dim, output_dim, num_samples=100): # Generate random input and output data
    np.random.seed(0)
    inputs = np.random.rand(num_samples, input_dim).astype(np.float32) 
    targets = np.random.rand(num_samples, output_dim).astype(np.float32)
    inputs = torch.from_numpy(inputs)
    targets = torch.from_numpy(targets)
    train_loader = torch.utils.data.DataLoader(list(zip(inputs, targets)), batch_size=10)
    test_loader = torch.utils.data.DataLoader(list(zip(inputs, targets)), batch_size=10)
    return train_loader, test_loader

def training_loop(model, train_loader, test_loader, criterion, optimizer, num_epochs=10): # Training loop
    model.train() # Set the model to training mode
    for epoch in range(num_epochs):
        for inputs, targets in train_loader: # Train the model
            optimizer.zero_grad() # Zero the gradients before the backward pass in order to avoid accumulating them
            outputs = model(inputs) # Forward pass
            loss = criterion(outputs, targets) # Compute the loss
            loss.backward() # Backward pass
            optimizer.step() # Update the weights
        #print(f"Epoch: {epoch}, Loss: {loss.item()}")
        model.eval() 
        for inputs, targets in test_loader:
            outputs = model(inputs)
            loss = criterion(outputs, targets)
        #print(f"Test Loss: {loss.item()}")

## Preparation

Define the dimension of your model's layers, be aware that you need it to be large enough so that the quantization doesn't outweights the benefits of reduced computational complexity.

In [62]:
input_dim = 8 
hidden_dim = 4000 
output_dim = 1

train_loader, test_loader = get_train_test_data(input_dim, output_dim) # Get the training and test data

Training of the fp32 model

In [63]:
model_fp32 = Model(input_dim, hidden_dim, output_dim) # Create a model
optimizer = optim.Adam(model_fp32.parameters(), lr=0.01) # Create an optimizer
criterion = nn.MSELoss() # Create a loss function

training_loop(model_fp32, train_loader, test_loader, criterion, optimizer) # Train the model

## Quantization

We can now proceed to our model quantization (note that it works the same with a non pre-trained model).

#### We only need 3 more lines and a calibration :

-set the model qconfig (depending on the hardware your using, here fbgemm on Intel), which is the setup of default config for QAT (int8 etc..)

-prepare the model to QAT which converts the model to quantized, applies qconfig, enable observers etc..

-now we have to calibrate the model using sample inputs in order for it to find the quantization parameters

now we can train / fine-tune our model

-finally, we use the convert function which ends all the simulation stuff happening in training and gives us the trully quantized model

note that the inplace flags are optionnal, they simply indicates if you want your model to be impacted or if it creates a copy of it (here, the same model is being modified again and again)

In [64]:
def calibrate_model(model, train_loader): # Calibrate the model to find the quantization parameters
    model.eval() # Set the model to evaluation mode
    for inputs, _ in train_loader: # Run the model on the training data
        _ = model(inputs) # Forward pass

In [65]:
import copy
model_to_quantize = copy.deepcopy(model_fp32) # Create a copy of the model to quantize
 
model_int8_qat = QuantizedModel(model_to_quantize) # Create a quantized model
model_int8_qat.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm') # Set the quantization configuration
torch.quantization.prepare_qat(model_int8_qat, inplace=True) # Prepare the model for quantization
optimizer_int8_qat = optim.Adam(model_int8_qat.parameters(), lr=0.01) # Create an optimizer for the quantized model

calibrate_model(model_int8_qat, train_loader) # Calibrate the quantized model
training_loop(model_int8_qat, train_loader, test_loader, criterion, optimizer_int8_qat) # Train the quantized model if needed

model_int8 = torch.quantization.convert(model_int8_qat, inplace=True) # Convert the quantized model to int8



## Comparison

Let's create a small function to calculate our model's inference time, and carbon emissions using the codecarbon library.

We also create a small function to compare both model's sizes.

In [66]:
# compare both models inference time
import time
import codecarbon
from codecarbon import EmissionsTracker

def inference_time(model, data_loader, num_iterations=100, warmup_iterations = 10):

    tracker = EmissionsTracker(log_level='critical', save_to_file=False) # Initialize the co2 tracker
    tracker.start() # Start the tracker

    model.eval() # Set the model to evaluation mode
    for _ in range(warmup_iterations): # Warmup the model in order to reduce the variance of the measurements
        for inputs, targets in data_loader:
            with torch.no_grad(): # Disable gradient tracking
                model.forward(inputs)

    start = time.time()
    for _ in range(num_iterations): # Measure the inference time based on multiple iterations to reduce the variance
        for inputs, targets in data_loader: 
            with torch.no_grad(): # Disable gradient tracking
                model.forward(inputs) 
    end = time.time()
    total_emissions = tracker.stop() # Stop the tracker and get the total emissions

    return (end - start) / num_iterations, total_emissions / num_iterations

import os
def get_model_size(mdl):
    
    torch.save(mdl.state_dict(), "tmp.pt") # Save the model to disk
    size = os.path.getsize("tmp.pt") # Get the size of the model
    os.remove('tmp.pt') # Remove the temporary file

    return size

def evaluate_model(model, data_loader, criterion):
    model.eval() # Set the model to evaluation mode
    total_loss = 0 # Initialize the total loss
    with torch.no_grad(): # Disable gradient tracking
        for inputs, targets in data_loader:
            outputs = model(inputs) # Forward pass
            loss = criterion(outputs, targets) # Compute the loss
            total_loss += loss.item() # Accumulate the loss

    return total_loss / len(data_loader)

Compute everything we want to compare

In [67]:
fp32_inference_time, fp32_emissions = inference_time(model_fp32, test_loader) # Measure the inference time of the FP32 model

int8_inference_time, int8_emissions = inference_time(model_int8, test_loader) # Measure the inference time of the INT8 model

fp32_size = get_model_size(model_fp32) # Get the size of the FP32 model

int8_size = get_model_size(model_int8) # Get the size of the INT8 model

fp32_loss = evaluate_model(model_fp32, test_loader, criterion) # Evaluate the FP32 model

int8_loss = evaluate_model(model_int8, test_loader, criterion) # Evaluate the INT8 model



## Final dataframe

In [68]:
import pandas as pd

results = pd.DataFrame({    
    'Model': ['FP32', 'INT8'],
    'Inference Time (s)': [fp32_inference_time, int8_inference_time],
    'Emissions (kgCO2)': [fp32_emissions, int8_emissions],
    'Model Size (MB)': [fp32_size, int8_size],
    'Loss': [fp32_loss, int8_loss]
})

print(results)

  Model  Inference Time (s)  Emissions (kgCO2)  Model Size (MB)      Loss
0  FP32            0.023889       9.942014e-09         32234768  0.683736
1  INT8            0.020083       7.959728e-09          8179664  0.298309


Here the results are better in very aspects ! Usually loss will decrease a very small amount but the gain elsewhere is huge.

## Print the models

In [69]:
print(model_fp32)

Model(
  (linear1): Linear(in_features=8, out_features=4000, bias=True)
  (linear2): Linear(in_features=4000, out_features=2000, bias=True)
  (linear3): Linear(in_features=2000, out_features=10, bias=True)
  (output): Linear(in_features=10, out_features=1, bias=True)
)


In [70]:
print(model_int8)

QuantizedModel(
  (quant): Quantize(scale=tensor([0.0078]), zero_point=tensor([0]), dtype=torch.quint8)
  (dequant): DeQuantize()
  (model): Model(
    (linear1): QuantizedLinear(in_features=8, out_features=4000, scale=0.022292088717222214, zero_point=76, qscheme=torch.per_channel_affine)
    (linear2): QuantizedLinear(in_features=4000, out_features=2000, scale=0.3334226906299591, zero_point=46, qscheme=torch.per_channel_affine)
    (linear3): QuantizedLinear(in_features=2000, out_features=10, scale=0.6604927778244019, zero_point=127, qscheme=torch.per_channel_affine)
    (output): QuantizedLinear(in_features=10, out_features=1, scale=0.00018686012481339276, zero_point=0, qscheme=torch.per_channel_affine)
  )
)
