# Quantization tutorial

This tutorial shows how to do post-training static quantization, as well as illustrating two more advanced techniques – per-channel quantization and quantization-aware training – to further improve the model’s accuracy. The task is to classify MNIST digits with a simple LeNet architecture.

This is a minimalistic tutorial to show you a starting point for quantization in PyTorch. For theory and more in-depth explanations, please check out: [Quantizing deep convolutional networks for efficient inference: A whitepaper](https://arxiv.org/abs/1806.08342).

The tutorial is heavily adapted from: https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html

### Top-Level Explanation

This notebook demonstrates how to improve the inference efficiency of a neural network by quantizing a CNN trained on the MNIST dataset. It covers three key methods:
  1. **Post-Training Static Quantization:** Inserting observers, calibrating on sample data, and converting to quantized operators.
  2. **Custom Quantization Configuration:** Using alternative observers (e.g. moving average observers) to potentially improve generalization.
  3. **Quantization Aware Training (QAT):** Training the network while simulating quantization effects to achieve higher accuracy in the quantized model.

Potential Lab Test Q&A:
  - **Q:** What is the purpose of quantization in deep learning?
    **A:** Quantization reduces model size and increases inference speed by converting weights and activations from floating point to lower precision (e.g. int8) while trying to preserve accuracy.
  - **Q:** How does quantization-aware training differ from post-training quantization?
    **A:** QAT simulates quantization effects during training, which usually results in better accuracy than post-training quantization that only converts a pretrained model.

Quantization-aware training (QAT) integrates the quantization process into the training loop. During training, the forward pass simulates the effects of quantization, such as rounding errors and limited precision, and the backward pass propagates gradients through these simulated operations. As a result, the model learns to compensate for these inaccuracies, adjusting its weights to maintain performance under quantized conditions.

On the other hand, post-training quantization (PTQ) converts a fully-trained, full-precision model to lower precision after training is complete. Since the model wasn't exposed to quantization effects during training, it may not be as robust to the quantization errors introduced during conversion. This often leads to a larger drop in accuracy compared to QAT.

In summary, QAT usually results in better accuracy because it allows the model to learn and adapt to the quantization errors during training, while PTQ applies quantization retrospectively without any adaptive compensation.
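Both approaches rely on the same affine mapping between real values and 8-bit integer codes. Storing a value in 1 byte instead of a 4-byte float is where the roughly 4x size reduction comes from. The plain-Python sketch below (no PyTorch needed) uses hypothetical scale and zero-point values to show the quantize/dequantize round trip and the rounding error that PTQ has to tolerate and QAT learns to compensate for.

```python
def quantize(x, scale, zero_point, qmin=0, qmax=255):
    """Real value -> integer code: clamp(round(x / scale) + zero_point, qmin, qmax)."""
    q = round(x / scale) + zero_point
    return max(qmin, min(qmax, q))


def dequantize(q, scale, zero_point):
    """Integer code -> real value: scale * (q - zero_point)."""
    return scale * (q - zero_point)


# Hypothetical parameters: an unsigned 8-bit grid spanning roughly [-1.0, 1.0],
# the value range produced by Normalize((0.5,), (0.5,)) on MNIST pixels.
scale = 2.0 / 255
zero_point = 128

for x in [-1.0, -0.3, 0.0, 0.42, 0.99]:
    q = quantize(x, scale, zero_point)
    x_hat = dequantize(q, scale, zero_point)
    # For values inside the grid, |x - x_hat| is at most scale / 2 (about 0.0039 here)
    print(f"x={x:+.4f} -> q={q:3d} -> x_hat={x_hat:+.4f}  error={abs(x - x_hat):.4f}")
```

Values outside the grid are clamped to qmin or qmax, so their error can be much larger than scale / 2; choosing the grid well is the job of the observers discussed below.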

### Initial Setup

Before beginning the assignment, we import the MNIST dataset and train a simple convolutional neural network (CNN) to classify it. In the next cells, we install and import the required libraries and define helper functions and classes for training, evaluation, and quantization.

In [1]:
# Install PyTorch and torchvision. Versions are left unpinned: old releases such as torch 1.5.0
# are not available for current Python versions. The outputs shown in this notebook come from a
# newer release, in which torch.quantization re-exports torch.ao.quantization.
!pip3 install torch torchvision

# Import essential PyTorch packages and quantization tools
import torch                # Core PyTorch library
import torchvision          # Provides datasets and models for computer vision tasks
import torchvision.transforms as transforms  # For data preprocessing and augmentation
import torch.nn as nn       # For building neural network modules
import torch.nn.functional as F  # Provides functional interface for common operations
import torch.optim as optim # Optimizers for training
import os                   # For file and directory operations
from torch.utils.data import DataLoader  # For loading and batching datasets
import torch.quantization   # Eager-mode quantization tools (post-training quantization and QAT)
from torch.quantization import QuantStub, DeQuantStub  # For inserting quantization and dequantization steps


Load training and test data from the MNIST dataset and apply a normalizing transformation.

In [2]:
# Define a transformation: convert images to tensors and normalize them
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Normalize MNIST images (grayscale) with mean=0.5, std=0.5
])

# Load MNIST training dataset
trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
                                          shuffle=True, num_workers=16, pin_memory=True)

# Load MNIST test dataset
testset = torchvision.datasets.MNIST(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,
                                         shuffle=False, num_workers=16, pin_memory=True)

Define some helper functions and classes that track training statistics, compute accuracy, measure model size, copy pretrained weights into a quantizable model, and fuse layers.

In [3]:
class AverageMeter(object):
    """Helper class to compute and store the average and current value."""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

def accuracy(output, target):
    """Computes the top-1 accuracy for the given output and target labels."""
    with torch.no_grad():
        batch_size = target.size(0)
        _, pred = output.topk(1, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        correct_one = correct[:1].view(-1).float().sum(0, keepdim=True)
        return correct_one.mul_(100.0 / batch_size).item()

def print_size_of_model(model):
    """Saves the model temporarily and prints its size in MB."""
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p") / 1e6)
    os.remove('temp.p')

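def load_model(quantized_model, model):
    """Copies the pretrained weights of `model` into `quantized_model` without moving `model` itself."""
    # Copy each tensor to the CPU, where the quantized model is calibrated and run
    state_dict = {name: tensor.cpu() for name, tensor in model.state_dict().items()}
    quantized_model.load_state_dict(state_dict)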

def fuse_modules(model):
    """Fuses convolution/linear layers with their subsequent ReLU activation for improved performance and accuracy."""
    torch.quantization.fuse_modules(model, [['conv1', 'relu1'],
                                            ['conv2', 'relu2'],
                                            ['fc1', 'relu3'],
                                            ['fc2', 'relu4']], inplace=True)

Define a simple CNN (LeNet-style) to classify MNIST images. This network optionally includes quantization stubs to enable quantization-aware training or post-training quantization.

In [4]:
class Net(nn.Module):
    def __init__(self, q=False):
        # If q is True, quantization stubs are added for quantization-aware training or post-training quantization
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5, bias=False)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5, bias=False)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(256, 120, bias=False)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(120, 84, bias=False)
        self.relu4 = nn.ReLU()
        self.fc3 = nn.Linear(84, 10, bias=False)
        self.q = q
        if q:
            self.quant = QuantStub()    # Marks the beginning of quantization
            self.dequant = DeQuantStub()  # Marks the end of quantization

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.q:
            x = self.quant(x)  # Quantize the input
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool2(x)
        x = x.reshape(x.shape[0], -1)  # Flatten the tensor for the fully connected layers
        x = self.fc1(x)
        x = self.relu3(x)
        x = self.fc2(x)
        x = self.relu4(x)
        x = self.fc3(x)
        if self.q:
            x = self.dequant(x)  # Dequantize the output
        return x

In [5]:
# Instantiate the network (without quantization for initial training)
net = Net(q=False).cuda()
print_size_of_model(net)  # Print model size to benchmark before quantization

Size (MB): 0.179057


Train this CNN on the training dataset (this may take a few moments).

In [7]:
def train(model: nn.Module, dataloader: DataLoader, cuda=False, q=False):
    criterion = nn.CrossEntropyLoss()  # Loss function for classification
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)  # SGD optimizer
    model.train()  # Set the model to training mode
    for epoch in range(10):  # Train for 10 epochs
        if q and epoch >= 3:
            # QAT only: after 3 epochs, stop updating the observers so the scales/zero points stay fixed
            model.apply(torch.quantization.disable_observer)

        running_loss = AverageMeter('loss')
        acc = AverageMeter('train_acc')
        for i, data in enumerate(dataloader, 0):
            inputs, labels = data  # Get inputs and labels
            if cuda:
                inputs = inputs.cuda()
                labels = labels.cuda()

            optimizer.zero_grad()  # Zero the parameter gradients

            outputs = model(inputs)  # Forward pass
            loss = criterion(outputs, labels)  # Compute loss
            loss.backward()  # Backward pass
            optimizer.step()  # Optimize weights

            running_loss.update(loss.item(), outputs.shape[0])
            acc.update(accuracy(outputs, labels), outputs.shape[0])
            if i % 100 == 0:
                print('[%d, %5d] ' % (epoch + 1, i + 1), running_loss, acc)
    print('Finished Training')


def test(model: nn.Module, dataloader: DataLoader, cuda=False) -> float:
    correct = 0
    total = 0
    model.eval()  # Set model to evaluation mode
    with torch.no_grad():
        for data in dataloader:
            inputs, labels = data
            if cuda:
                inputs = inputs.cuda()
                labels = labels.cuda()
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total  # Return accuracy percentage

In [8]:
# Train the network on the MNIST training set
train(net, trainloader, cuda=True)

[1,     1]  loss 2.301825 (2.301825) train_acc 9.375000 (9.375000)
... (training log output) ...
Finished Training


Now that the CNN has been trained, let's test it on our test dataset.

In [9]:
# Evaluate the network on the MNIST test set
score = test(net, testloader, cuda=True)
print('Accuracy of the network on the test images: {}% - FP32'.format(score))

Accuracy of the network on the test images: 98.09% - FP32


### Post-Training Quantization

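Create a quantization-ready instance of the same network (q=True adds the quantization and dequantization stubs), load the trained FP32 weights into it, and fuse each convolution/linear layer with the ReLU that follows it, which both speeds up the model and improves numerical accuracy. Quantizing the fused model then involves three steps: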

1. **Prepare:** Inserting observers into the model to record activation statistics.
2. **Calibration:** Running the model on sample data to gather tensor statistics.
3. **Convert:** Converting floating-point operations to quantized operations using the recorded statistics.
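To make the Prepare and Calibration steps concrete, the sketch below is a simplified pure-Python stand-in for a min/max observer (the class name MinMaxSketch is hypothetical, not PyTorch's implementation). It tracks the range seen during calibration and turns it into a scale and zero point, using an affine scheme for activations and a symmetric scheme, as typically used for int8 weights.

```python
import math


class MinMaxSketch:
    """Simplified stand-in for a min/max observer (not PyTorch's implementation)."""

    def __init__(self):
        self.min_val = math.inf
        self.max_val = -math.inf

    def observe(self, batch):
        # Calibration: widen the tracked range so it covers this batch
        self.min_val = min(self.min_val, min(batch))
        self.max_val = max(self.max_val, max(batch))

    def affine_qparams(self, qmin=0, qmax=255):
        # Extend the range to include 0 so that 0.0 maps exactly onto an integer code
        lo = min(self.min_val, 0.0)
        hi = max(self.max_val, 0.0)
        scale = max((hi - lo) / (qmax - qmin), 1e-8)  # avoid a zero scale for constant inputs
        zero_point = qmin - round(lo / scale)
        return scale, max(qmin, min(qmax, zero_point))

    def symmetric_qparams(self, qmin=-128, qmax=127):
        # Symmetric scheme (typical for int8 weights): the zero point is fixed at 0
        bound = max(-min(self.min_val, 0.0), max(self.max_val, 0.0))
        scale = max(bound / ((qmax - qmin) / 2), 1e-8)
        return scale, 0


# Usage with hypothetical post-ReLU activations and a 0..127 activation range
obs = MinMaxSketch()
obs.observe([0.0, 1.2, 3.4])
obs.observe([0.5, 7.5, 2.0])
print(obs.affine_qparams(qmin=0, qmax=127))  # scale = 7.5 / 127, zero_point = 0
```

Under this formula, the conv1 activation parameters printed by the calibration cell further down (scale ≈ 0.0591, zero_point 0, with the 0..127 activation range listed in the default QConfig) would correspond to a calibrated maximum of about 0.0591 × 127 ≈ 7.51. PyTorch's observers additionally handle per-channel ranges and reduce_range, so treat this as an illustration of the idea rather than the exact library code.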

In [10]:
# Create a quantized network instance by enabling quantization (q=True)
qnet = Net(q=True)
load_model(qnet, net)   # Load the pretrained weights into the quantized network
fuse_modules(qnet)      # Fuse layers to improve efficiency and accuracy

The following code prepares the model for post-training quantization. Observers are inserted, calibration is performed, and the model is converted to a quantized version.

Potential Lab Test Q&A:
  - **Q:** What is the role of observers in post-training quantization?
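    **A:** Observers record the range (min/max values) of the tensors they monitor, mainly activations during calibration (weights get their own observers too). The quantization parameters, scale and zero point, are then computed from these ranges.
  - **Q:** Why do we fuse modules before converting to a quantized model?
    **A:** Fusion merges operations (e.g., Conv + ReLU) into a single module, which reduces memory accesses and computation overhead. It also improves accuracy, because the Conv output no longer has to be quantized on its own before the ReLU is applied.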
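The accuracy benefit of fusion can be checked with simple arithmetic. Without fusion, the convolution output is quantized before the ReLU, so its 8-bit grid must also cover negative values that the ReLU then discards; with a fused Conv+ReLU, only the non-negative output is quantized. The plain-Python sketch below assumes a hypothetical conv output range of [-6, 6]; in practice the ranges come from the observers.

```python
def affine_scale(lo, hi, num_levels=256):
    """Step size of an unsigned 8-bit affine grid covering [lo, hi] (range extended to include 0)."""
    lo, hi = min(lo, 0.0), max(hi, 0.0)
    return (hi - lo) / (num_levels - 1)


# Hypothetical range of a convolution's output before the ReLU
conv_min, conv_max = -6.0, 6.0

# Unfused: the Conv output is quantized first, so the grid must also cover negative values
unfused_scale = affine_scale(conv_min, conv_max)
# Fused Conv+ReLU: only the non-negative ReLU output is quantized
fused_scale = affine_scale(0.0, conv_max)

print(f"unfused step {unfused_scale:.4f}, worst-case rounding error {unfused_scale / 2:.4f}")
print(f"fused step   {fused_scale:.4f}, worst-case rounding error {fused_scale / 2:.4f}")
```

With a symmetric input range, fusion halves the step size, and therefore the worst-case rounding error, for every value that survives the ReLU.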

In [12]:
qnet.qconfig = torch.quantization.default_qconfig  # Use the default quantization configuration
print(qnet.qconfig)  # Print the QConfig to verify settings
torch.quantization.prepare(qnet, inplace=True)  # Insert observers into the model
print('Post Training Quantization Prepare: Inserting Observers')
print('\n Conv1: After observer insertion \n\n', qnet.conv1)

test(qnet, trainloader, cuda=False)  # Run calibration on the training set
print('Post Training Quantization: Calibration done')
torch.quantization.convert(qnet, inplace=True)  # Convert the model to use quantized operators
print('Post Training Quantization: Convert done')
print('\n Conv1: After fusion and quantization \n\n', qnet.conv1)
print("Size of model after quantization")
print_size_of_model(qnet)

QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, quant_min=0, quant_max=127){}, weight=functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric){})
Post Training Quantization Prepare: Inserting Observers

 Conv1: After observer insertion 

 ConvReLU2d(
  (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1), bias=False)
  (1): ReLU()
  (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
)
Post Training Quantization: Calibration done
Post Training Quantization: Convert done

 Conv1: After fusion and quantization 

 QuantizedConvReLU2d(1, 6, kernel_size=(5, 5), stride=(1, 1), scale=0.05912807211279869, zero_point=0, bias=False)
Size of model after quantization
Size (MB): 0.050084
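
The two sizes reported so far (0.179057 MB in FP32, 0.050084 MB after quantization) can be sanity-checked by counting parameters. The sketch below copies the weight shapes from Net.__init__ and computes the raw bytes needed at 4 bytes (float32) and 1 byte (int8) per weight; it ignores any file-format overhead.

```python
# Weight shapes copied from Net.__init__ (every layer uses bias=False).
# fc1 takes 256 = 16 channels * 4 * 4 inputs: 28x28 -> conv 24x24 -> pool 12x12 -> conv 8x8 -> pool 4x4.
layer_shapes = {
    "conv1": (6, 1, 5, 5),
    "conv2": (16, 6, 5, 5),
    "fc1": (120, 256),
    "fc2": (84, 120),
    "fc3": (10, 84),
}


def numel(shape):
    """Number of elements in a tensor of the given shape."""
    n = 1
    for dim in shape:
        n *= dim
    return n


num_params = sum(numel(shape) for shape in layer_shapes.values())
fp32_mb = num_params * 4 / 1e6  # float32: 4 bytes per weight
int8_mb = num_params * 1 / 1e6  # int8: 1 byte per weight

print(num_params)  # 44190
print(f"raw FP32 weights: {fp32_mb:.4f} MB")
print(f"raw INT8 weights: {int8_mb:.4f} MB")
```

These raw byte counts (about 0.177 MB and 0.044 MB) sit slightly below the measured sizes. The gap is presumably the serialization overhead of torch.save plus, for the quantized model, the per-layer scales, zero points and packed-weight metadata, which is why the measured reduction is about 3.6x rather than exactly 4x.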


In [14]:
# Evaluate the quantized model on the test dataset
score = test(qnet, testloader, cuda=False)
print('Accuracy of the fused and quantized network on the test images: {}% - INT8'.format(score))

Accuracy of the fused and quantized network on the test images: 98.11% - INT8


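We can also define a custom quantization configuration. Here we replace the default MinMaxObserver with MovingAverageMinMaxObserver, which estimates min/max as a moving average over the calibration batches instead of keeping the global extremes. This makes the range less sensitive to rare outlier values, which may slightly improve accuracy.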
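The difference between the two observers is easiest to see on a calibration stream with one outlier batch. The plain-Python sketch below mimics the two update rules rather than calling the library; it assumes an averaging constant of 0.01 (to our knowledge, PyTorch's default for MovingAverageMinMaxObserver) and uses hypothetical batches.

```python
import math


def global_min_max(batches):
    """Keeps the global extremes over all batches (the MinMax observer behaviour)."""
    lo, hi = math.inf, -math.inf
    for batch in batches:
        lo, hi = min(lo, min(batch)), max(hi, max(batch))
    return lo, hi


def moving_average_min_max(batches, averaging_constant=0.01):
    """Exponential moving average of each batch's min/max; the first batch initializes the estimate."""
    lo, hi = math.inf, -math.inf
    for batch in batches:
        batch_lo, batch_hi = min(batch), max(batch)
        if lo == math.inf and hi == -math.inf:
            lo, hi = batch_lo, batch_hi
        else:
            lo += averaging_constant * (batch_lo - lo)
            hi += averaging_constant * (batch_hi - hi)
    return lo, hi


# Hypothetical calibration stream: 99 ordinary batches (max 4.0) followed by one outlier batch (max 40.0)
batches = [[0.0, 2.0, 4.0]] * 99 + [[0.0, 1.0, 40.0]]

print(global_min_max(batches))          # (0.0, 40.0): one outlier stretches the range 10x
print(moving_average_min_max(batches))  # max only moves from 4.0 to about 4.36
```

The moving-average estimate depends on batch order: if the outlier arrived in the first batch, it would initialize the range and decay only slowly.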

In [15]:
from torch.quantization.observer import MovingAverageMinMaxObserver

qnet = Net(q=True)  # Create a new quantized network instance
load_model(qnet, net)  # Load pretrained weights
fuse_modules(qnet)  # Fuse layers as before

qnet.qconfig = torch.quantization.QConfig(
    activation=MovingAverageMinMaxObserver.with_args(reduce_range=True),
    weight=MovingAverageMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)
)
print(qnet.qconfig)

torch.quantization.prepare(qnet, inplace=True)  # Insert the custom observers
print('Post Training Quantization Prepare: Inserting Observers')
print('\n Conv1: After observer insertion \n\n', qnet.conv1)

test(qnet, trainloader, cuda=False)  # Calibrate using training data
print('Post Training Quantization: Calibration done')
torch.quantization.convert(qnet, inplace=True)  # Convert to quantized model
print('Post Training Quantization: Convert done')
print('\n Conv1: After fusion and quantization \n\n', qnet.conv1)
print("Size of model after quantization")
print_size_of_model(qnet)

score = test(qnet, testloader, cuda=False)  # Evaluate quantized model
print('Accuracy of the fused and quantized network on the test images: {}% - INT8'.format(score))

QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.MovingAverageMinMaxObserver'>, reduce_range=True){}, weight=functools.partial(<class 'torch.ao.quantization.observer.MovingAverageMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric){})
Post Training Quantization Prepare: Inserting Observers

 Conv1: After observer insertion 

 ConvReLU2d(
  (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1), bias=False)
  (1): ReLU()
  (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
)
Post Training Quantization: Calibration done
Post Training Quantization: Convert done

 Conv1: After fusion and quantization 

 QuantizedConvReLU2d(1, 6, kernel_size=(5, 5), stride=(1, 1), scale=0.05865493789315224, zero_point=0, bias=False)
Size of model after quantization
Size (MB): 0.050084
Accuracy of the fused and quantized network on the test images: 98.13% - INT8


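We can also try the default configuration for qnnpack, PyTorch's quantized backend for ARM (e.g. mobile) CPUs. As the QConfig printed below shows, it uses a HistogramObserver for activations, which chooses the quantization range by minimizing quantization error instead of taking the raw min/max, while weights are still quantized per tensor; in recent PyTorch versions, per-channel weight quantization is what the fbgemm (x86) default configuration uses. On this small model the qnnpack configuration does not improve accuracy: the result below is slightly lower than with the default configuration.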
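Per-channel weight quantization gives each output channel of a layer its own scale instead of one scale for the whole weight tensor; in recent PyTorch versions, get_default_qconfig('fbgemm') selects such a per-channel weight observer. The plain-Python sketch below uses a hypothetical two-channel weight matrix and a simplified symmetric scheme to show why this helps when channel magnitudes differ: with a shared scale, the small channel is represented by only a handful of integer levels.

```python
def symmetric_scale(values, qmax=127):
    """Simplified symmetric int8 scale: the largest magnitude maps to qmax, zero point 0."""
    return max(abs(v) for v in values) / qmax


def mean_abs_error(values, scale, qmin=-128, qmax=127):
    """Mean absolute error of an int8 quantize/dequantize round trip with zero point 0."""
    total = 0.0
    for v in values:
        q = max(qmin, min(qmax, round(v / scale)))
        total += abs(v - q * scale)
    return total / len(values)


# Hypothetical weights of a layer with two output channels of very different magnitude
weights = [
    [0.9, -0.7, 0.5, -0.8],        # output channel 0
    [0.01, -0.02, 0.015, -0.005],  # output channel 1
]

per_tensor_scale = symmetric_scale([v for row in weights for v in row])  # one scale for all channels
per_channel_scales = [symmetric_scale(row) for row in weights]           # one scale per channel

for c, row in enumerate(weights):
    print(f"channel {c}: per-tensor error {mean_abs_error(row, per_tensor_scale):.6f}, "
          f"per-channel error {mean_abs_error(row, per_channel_scales[c]):.6f}")
```

Channel 0 already contains the largest weight, so both schemes give it the same scale; channel 1 is where per-channel scaling cuts the error by more than an order of magnitude.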

In [16]:
qnet = Net(q=True)
load_model(qnet, net)
fuse_modules(qnet)

In [17]:
qnet.qconfig = torch.quantization.get_default_qconfig('qnnpack')  # Use qnnpack config for ARM architectures
print(qnet.qconfig)

torch.quantization.prepare(qnet, inplace=True)  # Insert observers
test(qnet, trainloader, cuda=False)
torch.quantization.convert(qnet, inplace=True)  # Convert the model
print("Size of model after quantization")
print_size_of_model(qnet)

QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.HistogramObserver'>, reduce_range=False){}, weight=functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric){})
Size of model after quantization
Size (MB): 0.050084


In [18]:
# Evaluate the model using the qnnpack quantization configuration
score = test(qnet, testloader, cuda=False)
print('Accuracy of the fused and quantized network on the test images: {}% - INT8'.format(score))

Accuracy of the fused and quantized network on the test images: 98.02% - INT8


### Quantization Aware Training

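Quantization-aware training (QAT) inserts fake-quantization modules that round weights and activations to the int8 grid in the forward pass while keeping all computation in floating point; in the backward pass, gradients flow through the rounding via a straight-through estimator. This typically leads to higher accuracy in the final quantized model, since the network learns to compensate for the quantization error during training. Note that in the cell below the QAT network is trained from scratch rather than initialized from the FP32 weights.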
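QAT's simulated quantization rests on two pieces: a fake-quantize step in the forward pass (quantize, then immediately dequantize, so values stay floating point) and a straight-through estimator (STE) in the backward pass, which lets gradients flow through the non-differentiable rounding. The plain-Python sketch below shows both for a single scalar with hypothetical parameters; PyTorch's FakeQuantize modules operate on tensors and may treat the range boundaries slightly differently.

```python
def fake_quantize(x, scale, zero_point, qmin=0, qmax=255):
    """Forward pass: quantize then immediately dequantize, so the result is still a float."""
    q = max(qmin, min(qmax, round(x / scale) + zero_point))
    return scale * (q - zero_point)


def fake_quantize_backward(x, grad_out, scale, zero_point, qmin=0, qmax=255):
    """Straight-through estimator: round() is treated as the identity, except that
    the gradient is zeroed where x lies outside the representable range (it was clamped)."""
    lo = scale * (qmin - zero_point)
    hi = scale * (qmax - zero_point)
    return grad_out if lo <= x <= hi else 0.0


# Hypothetical activation qparams: scale 0.05, zero_point 0, so the grid spans [0.0, 12.75]
scale, zero_point = 0.05, 0
for x in [1.23, 20.0, -1.0]:
    y = fake_quantize(x, scale, zero_point)
    g = fake_quantize_backward(x, 1.0, scale, zero_point)
    print(f"x={x:6.2f}  forward={y:6.2f}  gradient passed back={g}")
```

In the QAT loop, the optimizer updates the underlying float weights with these gradients while the forward pass only ever sees their fake-quantized values; convert later replaces the fake-quantize modules with real int8 operators.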

In [19]:
qnet = Net(q=True)  # Create a new network for QAT with quantization stubs enabled
fuse_modules(qnet)  # Fuse the appropriate modules
qnet.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')  # Set the QAT configuration
torch.quantization.prepare_qat(qnet, inplace=True)  # Prepare the model for quantization-aware training by inserting fake quantization modules
print('\n Conv1: After fusion and quantization \n\n', qnet.conv1)
qnet = qnet.cuda()  # Move model to GPU for training
train(qnet, trainloader, cuda=True, q=True)  # Train with fake quantization; q=True freezes the observers after 3 epochs
qnet = qnet.cpu()  # Move model back to CPU for conversion
torch.quantization.convert(qnet, inplace=True)  # Convert the QAT model to a fully quantized model
print("Size of model after quantization")
print_size_of_model(qnet)

score = test(qnet, testloader, cuda=False)  # Evaluate the quantized model
print('Accuracy of the fused and quantized network (trained quantized) on the test images: {}% - INT8'.format(score))


 Conv1: After fusion and quantization 

 ConvReLU2d(
  1, 6, kernel_size=(5, 5), stride=(1, 1), bias=False
  (weight_fake_quant): FusedMovingAvgObsFakeQuantize(
    fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.qint8, quant_min=-128, quant_max=127, qscheme=torch.per_tensor_symmetric, reduce_range=False
    (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  )
  (activation_post_process): FusedMovingAvgObsFakeQuantize(
    fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.quint8, quant_min=0, quant_max=255, qscheme=torch.per_tensor_affine, reduce_range=False
    (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  )
)
[... training log output ...]
Finished Training
Size of model after quantization
Size (MB): 0.050084
Accuracy of the fused and quanti