# Quantization tutorial

This tutorial shows how to do post-training static quantization and illustrates two more advanced techniques, per-channel quantization and quantization-aware training, that can further improve the model’s accuracy. The task is to classify MNIST digits with a simple LeNet architecture.

This is a minimal tutorial meant as a starting point for quantization in PyTorch. For theory and more in-depth explanations, see [Quantizing deep convolutional networks for efficient inference: A whitepaper](https://arxiv.org/abs/1806.08342).

The tutorial is heavily adapted from: [Static Quantization Tutorial](https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html)

**Top-Level Explanation:** This notebook demonstrates how to prepare, calibrate, and convert a simple CNN for MNIST digit classification to a quantized model. It covers data loading, model training, post-training quantization, and quantization-aware training (QAT).

**Potential Lab Q&A:**
- *Q: What is post-training quantization?*
  *A: It is the process of converting a pre-trained floating point model to a quantized version without retraining.*
- *Q: What are QuantStub and DeQuantStub used for?*
  *A: They mark the boundaries for quantization and dequantization in the network, allowing for efficient inference with int8 arithmetic.*


### Initial Setup

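Before beginning the assignment, we load the MNIST dataset and train a simple convolutional neural network (CNN) to classify it. The following cell tries to install pinned versions of torch and torchvision, then imports the libraries needed to build and quantize the model. Note that `torch==1.5.0` has no wheels for the Python version in this environment, so, as the output below shows, the install fails and the notebook runs on the preinstalled PyTorch instead (the `QConfig` printouts later reference `torch.ao.quantization`, which only exists in newer releases).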

**Potential Lab Q&A:**
- *Q: Why install specific versions of torch?*
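  *A: The eager-mode quantization API has changed across PyTorch releases (newer versions move it to `torch.ao.quantization` and keep `torch.quantization` as an alias), so pinning versions helps reproduce the tutorial's behavior and results.*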

In [1]:
# Install specific versions of torch and torchvision
!pip3 install torch==1.5.0 torchvision==0.6.0  # torchvision 0.6.0 is the release paired with torch 1.5.0

# Import necessary libraries for deep learning and quantization
import torch  # PyTorch for building and training the neural network
import torchvision  # Contains vision datasets, models, and transforms
import torchvision.transforms as transforms  # For preprocessing images
import torch.nn as nn  # Neural network module
import torch.nn.functional as F  # Functional interface for NN operations
import torch.optim as optim  # Optimizers for training the network
import os  # Operating system interface
from torch.utils.data import DataLoader  # Utility for loading data in batches
import torch.quantization  # Tools for quantizing the model
from torch.quantization import QuantStub, DeQuantStub  # Stubs for quantization and dequantization

# Potential Q&A:
# Q: What is quantization?
# A: Quantization reduces the precision of weights and activations to reduce model size and increase inference speed.

ERROR: Could not find a version that satisfies the requirement torch==1.5.0 (from versions: 1.13.0, 1.13.1, 2.0.0, 2.0.1, 2.1.0, 2.1.1, 2.1.2, 2.2.0, 2.2.1, 2.2.2, 2.3.0, 2.3.1, 2.4.0, 2.4.1, 2.5.0, 2.5.1, 2.6.0)
ERROR: No matching distribution found for torch==1.5.0

Load training and test data from the MNIST dataset and apply a normalizing transformation.

**Key Point:** MNIST images are normalized to center the data, which helps with training stability.

**Potential Q&A:**
- *Q: Why do we normalize the images?*
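  *A: With mean 0.5 and std 0.5, `Normalize` maps pixel values from [0, 1] to [-1, 1]. Zero-centered inputs on a consistent scale keep the optimization well conditioned, which helps the model converge during training.*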
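For a single grayscale channel, `transforms.Normalize((0.5,), (0.5,))` applies `(x - 0.5) / 0.5` to every pixel after `ToTensor()` has scaled the raw 0..255 values into [0, 1]. The plain-Python sketch below (no torch needed; `normalize` is a hypothetical helper written only for this check) traces three pixel values through both steps, which also shows the input range the `QuantStub` will see later.

```python
def normalize(x, mean=0.5, std=0.5):
    """Per-pixel formula applied by transforms.Normalize: (x - mean) / std."""
    return (x - mean) / std

raw_pixels = [0, 128, 255]                    # black, mid-gray, white (8-bit values)
scaled = [p / 255 for p in raw_pixels]        # what ToTensor() produces: values in [0, 1]
normalized = [normalize(v) for v in scaled]   # what the network receives

print([round(v, 3) for v in normalized])      # [-1.0, 0.004, 1.0]
```

So the network's inputs lie in [-1, 1], which is roughly the range the input observer should record during calibration.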

In [2]:
# Define a transformation pipeline for MNIST images
transform = transforms.Compose([
    transforms.ToTensor(),  # Convert images to PyTorch tensors
    transforms.Normalize((0.5,), (0.5,))  # Single grayscale channel: (x - 0.5) / 0.5 maps [0, 1] to [-1, 1]
])

# Load the MNIST training dataset
trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
                                          shuffle=True, num_workers=16, pin_memory=True)

# Load the MNIST test dataset
testset = torchvision.datasets.MNIST(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,
                                         shuffle=False, num_workers=16, pin_memory=True)

# Potential Q&A:
# Q: Why use pin_memory in DataLoader?
# A: pin_memory=True speeds up data transfer to GPU by using page-locked memory.

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz
Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

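Define helper functions and classes to track loss and accuracy on the train/test data, measure model size, and prepare models for quantization.

**Key Points:**
- `AverageMeter` computes and stores the current value and the running average of a metric such as loss or accuracy.
- `accuracy` computes the top-1 accuracy of a batch, in percent.
- `print_size_of_model` prints the size of the saved `state_dict`, which is how we check how much quantization shrinks the model.
- `load_model` copies the trained floating-point weights into a fresh quantization-ready network.
- `fuse_modules` fuses each Conv/Linear layer with the ReLU that follows it.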

**Potential Q&A:**
- *Q: Why do we fuse modules in a quantized model?*
  *A: Fusing layers like Conv+ReLU reduces memory access and improves numerical accuracy after quantization.*

In [3]:
class AverageMeter(object):
    """Computes and stores the average and current value."""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

def accuracy(output, target):
    """Computes the top 1 accuracy."""
    with torch.no_grad():
        batch_size = target.size(0)

        _, pred = output.topk(1, 1, True, True)  # Get top prediction for each example
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        correct_one = correct[:1].view(-1).float().sum(0, keepdim=True)
        return correct_one.mul_(100.0 / batch_size).item()

def print_size_of_model(model):
    """Prints the size of the model in MB."""
    torch.save(model.state_dict(), "temp.p")
    print('Size (MB):', os.path.getsize("temp.p")/1e6)
    os.remove('temp.p')

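def load_model(quantized_model, model):
    """Copies the weights of a trained float model into a quantization-ready model with the same layer names."""
    # load_state_dict copies each tensor into quantized_model's existing CPU parameters,
    # so this works even if `model` is on the GPU, and `model` itself is not moved.
    quantized_model.load_state_dict(model.state_dict())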

def fuse_modules(model):
    """Fuses convolution/linear layers with ReLU for better quantization efficiency."""
    torch.quantization.fuse_modules(model, [['conv1', 'relu1'],
                                            ['conv2', 'relu2'],
                                            ['fc1', 'relu3'],
                                            ['fc2', 'relu4']], inplace=True)
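`AverageMeter.update(val, n)` weights each value by `n`, so `avg` is a per-example average rather than a per-batch one. This matters for the last training batch: 60,000 MNIST training images in batches of 64 leave a final batch of only 32. The sketch below restates the same update rule in plain Python with hypothetical accuracy numbers (`weighted_running_average` is an illustrative helper, not part of the notebook).

```python
def weighted_running_average(batches):
    """Same rule as AverageMeter.update: accumulate val * n and divide by the total n."""
    total, count = 0.0, 0
    for val, n in batches:
        total += val * n
        count += n
    return total / count

# Hypothetical (accuracy %, batch size) pairs: two full batches and a final partial batch
batches = [(100.0, 64), (100.0, 64), (50.0, 32)]

weighted = weighted_running_average(batches)                 # 14400 / 160 = 90.0
unweighted = sum(val for val, _ in batches) / len(batches)   # 250 / 3, about 83.3

print(weighted, round(unweighted, 1))  # 90.0 83.3
```

The weighted value matches the true accuracy (144 correct out of 160 examples); a plain mean of batch accuracies would over-count the small final batch.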

Define a simple CNN that classifies MNIST images.

**Key Points:**
- The network is a minimal LeNet-like CNN.
- Quantization stubs are included if quantization is enabled (`q=True`).

**Potential Q&A:**
- *Q: What is the purpose of using `reshape` instead of `view`?*
  *A: `reshape` is more flexible and can handle non-contiguous tensors, reducing potential errors during flattening.*

In [4]:
class Net(nn.Module):
    def __init__(self, q=False):
        # If q is True, quantization is enabled
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5, bias=False)  # First convolutional layer
        self.relu1 = nn.ReLU()  # Activation
        self.pool1 = nn.MaxPool2d(2, 2)  # Pooling layer
        self.conv2 = nn.Conv2d(6, 16, 5, bias=False)  # Second convolutional layer
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(256, 120, bias=False)  # First fully connected layer
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(120, 84, bias=False)  # Second fully connected layer
        self.relu4 = nn.ReLU()
        self.fc3 = nn.Linear(84, 10, bias=False)  # Output layer
        self.q = q
        if q:
            # QuantStub marks the beginning of quantization
            self.quant = QuantStub()
            # DeQuantStub marks the end of quantization
            self.dequant = DeQuantStub()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.q:
            x = self.quant(x)  # Quantize the input
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool2(x)
        # Flatten the tensor; use reshape instead of view for safety
        x = x.reshape(x.shape[0], -1)
        x = self.fc1(x)
        x = self.relu3(x)
        x = self.fc2(x)
        x = self.relu4(x)
        x = self.fc3(x)
        if self.q:
            x = self.dequant(x)  # Dequantize the output
        return x

# Potential Q&A:
# Q: What are QuantStub and DeQuantStub used for?
# A: They define the boundaries where the model converts between floating point and quantized representations.
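Where does the `256` in `nn.Linear(256, 120)` come from? A 5x5 convolution without padding shrinks each spatial side by 4, and each 2x2 max-pool with stride 2 halves it. The plain-Python check below (`out_size` is a hypothetical helper implementing the standard output-size formula for `Conv2d`/`MaxPool2d` with dilation 1) walks a 28x28 MNIST image through the layers.

```python
def out_size(size, kernel, stride=1, padding=0):
    """Spatial output size of Conv2d / MaxPool2d (dilation 1): floor((n + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1

size = 28                            # MNIST input is 1 x 28 x 28
size = out_size(size, kernel=5)      # conv1 -> 6 x 24 x 24
size = out_size(size, 2, stride=2)   # pool1 -> 6 x 12 x 12
size = out_size(size, kernel=5)      # conv2 -> 16 x 8 x 8
size = out_size(size, 2, stride=2)   # pool2 -> 16 x 4 x 4

flat_features = 16 * size * size     # what x.reshape(x.shape[0], -1) feeds to fc1
print(flat_features)                 # 256
```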

In [5]:
net = Net(q=False).cuda()  # Create an instance of the CNN and move it to the GPU
print_size_of_model(net)  # Print the size of the model (important to compare before/after quantization)

# Potential Q&A:
# Q: Why is model size important in quantization?
# A: Quantization can reduce model size, which is crucial for deploying models on resource-constrained devices.

Size (MB): 0.179057
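The 0.179 MB figure can be checked by hand. Every layer is created with `bias=False`, so the parameters are just the weight tensors. The arithmetic below counts them and estimates the raw payload at 4 bytes per FP32 weight and 1 byte per INT8 weight. The measured file sizes are somewhat larger because `torch.save` also stores tensor metadata (and, after conversion, per-layer scales and zero-points), so treat these numbers as rough lower bounds.

```python
# Weight-tensor sizes of Net (out x in x kh x kw for convs, out x in for linears); no biases
layer_params = {
    "conv1": 6 * 1 * 5 * 5,    # 150
    "conv2": 16 * 6 * 5 * 5,   # 2,400
    "fc1": 120 * 256,          # 30,720
    "fc2": 84 * 120,           # 10,080
    "fc3": 10 * 84,            # 840
}
n_params = sum(layer_params.values())

fp32_mb = n_params * 4 / 1e6   # 4 bytes per float32 weight
int8_mb = n_params * 1 / 1e6   # 1 byte per int8 weight

print(n_params, round(fp32_mb, 3), round(int8_mb, 3))  # 44190 0.177 0.044
```

Because the serialization overhead does not shrink with the weights, expect the measured INT8 model to be somewhat less than 4x smaller than the FP32 one.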


Train this CNN on the training dataset (this may take a few moments).

In [7]:
def train(model: nn.Module, dataloader: DataLoader, cuda=False, q=False):
    criterion = nn.CrossEntropyLoss()  # Loss function
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)  # Optimizer
    model.train()  # Set the model to training mode
    for epoch in range(10):  # Train for 10 epochs
        if q and epoch >= 3:
            # QAT only: after 3 epochs, freeze the observed ranges so the remaining
            # epochs fine-tune the weights against fixed scales and zero-points
            model.apply(torch.quantization.disable_observer)

        running_loss = AverageMeter('loss')
        acc = AverageMeter('train_acc')
        for i, data in enumerate(dataloader, 0):
            # Get the inputs and labels
            inputs, labels = data
            if cuda:
                inputs = inputs.cuda()
                labels = labels.cuda()

            optimizer.zero_grad()  # Zero the gradients

            outputs = model(inputs)  # Forward pass
            loss = criterion(outputs, labels)  # Compute loss
            loss.backward()  # Backward pass
            optimizer.step()  # Update weights

            # Update running loss and accuracy statistics
            running_loss.update(loss.item(), outputs.shape[0])
            acc.update(accuracy(outputs, labels), outputs.shape[0])
            if i % 100 == 0:  # Log every 100 mini-batches
                print('[%d, %5d] ' % (epoch + 1, i + 1), running_loss, acc)
    print('Finished Training')


def test(model: nn.Module, dataloader: DataLoader, cuda=False) -> float:
    correct = 0
    total = 0
    model.eval()  # Set model to evaluation mode
    with torch.no_grad():
        for data in dataloader:
            inputs, labels = data

            if cuda:
                inputs = inputs.cuda()
                labels = labels.cuda()

            outputs = model(inputs)  # Forward pass during testing
            _, predicted = torch.max(outputs, 1)  # Index of the largest logit = predicted class
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return 100 * correct / total  # Return accuracy as a percentage

# Potential Q&A:
# Q: How is accuracy computed in the test function?
# A: By comparing the predicted labels with the true labels and computing the percentage of correct predictions.

In [8]:
train(net, trainloader, cuda=True)  # Train the CNN on the MNIST training data

# Potential Q&A:
# Q: Why do we use SGD with momentum for training?
# A: Momentum accumulates past gradients, which damps oscillations and speeds up progress along consistent descent directions, usually giving faster convergence.

[1,     1]  loss 2.301825 (2.301825) train_acc 9.375000 (9.375000)
[1,   101]  loss 2.296545 (2.299753) train_acc 14.062500 (13.180693)
[1,   201]  loss 2.294123 (2.295533) train_acc 20.312500 (16.969838)
[1,   301]  loss 2.260918 (2.289851) train_acc 31.250000 (19.684385)
[1,   401]  loss 2.229809 (2.280524) train_acc 28.125000 (22.241272)
[1,   501]  loss 2.112481 (2.260747) train_acc 34.375000 (23.924027)
[1,   601]  loss 1.691649 (2.207187) train_acc 56.250000 (26.000936)
[1,   701]  loss 1.302064 (2.094674) train_acc 57.812500 (30.683845)
[1,   801]  loss 0.603729 (1.944274) train_acc 79.687500 (35.927747)
[1,   901]  loss 0.637056 (1.799278) train_acc 79.687500 (40.822697)
[2,     1]  loss 0.421752 (0.421752) train_acc 84.375000 (84.375000)
[2,   101]  loss 0.555698 (0.466051) train_acc 76.562500 (84.916460)
[2,   201]  loss 0.367405 (0.443531) train_acc 90.625000 (85.828669)
[2,   301]  loss 0.338774 (0.414789) train_acc 87.500000 (86.908223)
[2,   401]  loss 0.418230 (0.392866)

Now that the CNN has been trained, let's test it on our test dataset.

In [9]:
score = test(net, testloader, cuda=True)  # Evaluate the trained network on the test data
print('Accuracy of the network on the test images: {}% - FP32'.format(score))

# Potential Q&A:
# Q: What does FP32 indicate?
# A: FP32 indicates that the network is using 32-bit floating point precision.

Accuracy of the network on the test images: 98.09% - FP32


### Post-training quantization

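Create a quantization-ready copy of the network, `Net(q=True)`, and load the trained FP32 weights into it. Its `QuantStub` and `DeQuantStub` mark where, once the model is converted, the float input is quantized to int8 (start of `forward`) and the int8 output is dequantized back to float (end of `forward`).

Next, we “fuse modules”: each Conv/Linear layer is merged with the ReLU that follows it into a single module. This can make the model faster by saving memory accesses, and it can also improve numerical accuracy. Fusion can be used with any model, but it is especially common with quantized models.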
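Why can fusion help accuracy? When the ReLU is fused into the convolution, the output observer only sees post-ReLU, non-negative values, so none of the 8-bit range is spent on negative numbers the ReLU would discard anyway. The plain-Python sketch below uses the min/max affine formula (range extended to include 0, `scale = (max - min) / (qmax - qmin)`, `zero_point = qmin - round(min / scale)`) with a hypothetical conv output range of [-2, 6]; `affine_qparams` is an illustrative helper, not a PyTorch function.

```python
def affine_qparams(min_val, max_val, qmin=0, qmax=255):
    """Min/max affine quantization parameters for an unsigned 8-bit range."""
    min_val, max_val = min(min_val, 0.0), max(max_val, 0.0)  # 0 must be exactly representable
    scale = (max_val - min_val) / (qmax - qmin)
    zero_point = qmin - round(min_val / scale)
    return scale, zero_point

# Hypothetical conv output range: [-2, 6] before the ReLU, [0, 6] after it
scale_unfused, zp_unfused = affine_qparams(-2.0, 6.0)  # conv observed on its own
scale_fused, zp_fused = affine_qparams(0.0, 6.0)       # fused ConvReLU: negatives already clipped

print(round(scale_unfused, 4), zp_unfused)  # 0.0314 64
print(round(scale_fused, 4), zp_fused)      # 0.0235 0
```

The same 256 levels now cover a range 25% narrower, so each quantization step is finer, and the zero-point lands at 0, as on the `QuantizedConvReLU2d` layers printed in this notebook.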

**Process Overview:**
1. **Prepare:** Insert observers into the model to collect activation statistics (min/max values).
2. **Calibration:** Run the model on representative data to calibrate these observers.
3. **Convert:** Use the collected statistics to compute quantization parameters and convert the model's operations to quantized versions.
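Steps 2 and 3 can be pictured with a toy observer. The class below is a simplified, plain-Python stand-in for a min/max observer (`MinMaxSketch` is hypothetical, not PyTorch's `MinMaxObserver`): calibration only widens a running min and max, and conversion turns that range into a scale and zero-point. It assumes the 0..127 activation range that the `default_qconfig` printed in this section specifies, and made-up post-ReLU activation batches.

```python
class MinMaxSketch:
    """Toy min/max observer: remembers the smallest and largest value it has seen."""

    def __init__(self, qmin=0, qmax=127):
        self.min_val, self.max_val = float("inf"), float("-inf")
        self.qmin, self.qmax = qmin, qmax

    def observe(self, batch):
        # Calibration: only widen the running range; no weights change
        self.min_val = min(self.min_val, min(batch))
        self.max_val = max(self.max_val, max(batch))

    def qparams(self):
        # Conversion: map [min, max] (extended to include 0) onto [qmin, qmax]
        lo, hi = min(self.min_val, 0.0), max(self.max_val, 0.0)
        scale = (hi - lo) / (self.qmax - self.qmin)
        zero_point = self.qmin - round(lo / scale)
        return scale, zero_point

obs = MinMaxSketch()
for batch in ([0.0, 1.2, 3.5], [0.4, 6.35], [2.0]):  # hypothetical post-ReLU activations
    obs.observe(batch)

scale, zero_point = obs.qparams()
print(obs.max_val, round(scale, 4), zero_point)  # 6.35 0.05 0
```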

**Potential Q&A:**
- *Q: What is the purpose of fusing modules?*
  *A: It reduces memory access and improves numerical accuracy after quantization.*

In [10]:
qnet = Net(q=True)  # Create a quantized version of the network (with quantization stubs)
load_model(qnet, net)  # Load the pretrained floating point weights into the quantized model
fuse_modules(qnet)  # Fuse layers (e.g., Conv+ReLU) to optimize the model for quantization

# Inline Q&A:
# Q: Why do we need to load weights into qnet?
# A: We need to initialize the quantized network with the trained weights from the floating point model.

In [12]:
qnet.qconfig = torch.quantization.default_qconfig  # Set the default quantization configuration
print(qnet.qconfig)  # Print the quantization configuration (for verification)
torch.quantization.prepare(qnet, inplace=True)  # Prepare the model by inserting observers
print('Post Training Quantization Prepare: Inserting Observers')
print('\n Conv1: After observer insertion \n\n', qnet.conv1)  # Display the first convolution layer after observer insertion

test(qnet, trainloader, cuda=False)  # Run calibration using training data
print('Post Training Quantization: Calibration done')
torch.quantization.convert(qnet, inplace=True)  # Convert the calibrated model to a quantized version
print('Post Training Quantization: Convert done')
print('\n Conv1: After fusion and quantization \n\n', qnet.conv1)  # Display the fused and quantized conv layer
print("Size of model after quantization")
print_size_of_model(qnet)  # Print model size to observe reduction due to quantization

# Potential Q&A:
# Q: What does calibration do in post-training quantization?
# A: It collects activation statistics from representative data to determine optimal quantization parameters.

QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, quant_min=0, quant_max=127){}, weight=functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric){})
Post Training Quantization Prepare: Inserting Observers

 Conv1: After observer insertion 

 ConvReLU2d(
  (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1), bias=False)
  (1): ReLU()
  (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
)
Post Training Quantization: Calibration done
Post Training Quantization: Convert done

 Conv1: After fusion and quantization 

 QuantizedConvReLU2d(1, 6, kernel_size=(5, 5), stride=(1, 1), scale=0.05912807211279869, zero_point=0, bias=False)
Size of model after quantization
Size (MB): 0.050084
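The printed parameters can be read back directly. With `zero_point=0` and the 0..127 activation range of `default_qconfig`, conv1's output is stored as `q = clamp(round(x / scale), 0, 127)` and read back as `q * scale`, so the largest representable activation is `127 * scale`, which is approximately the maximum that calibration observed. The plain-Python sketch below (helper names are illustrative) plugs in the scale printed above.

```python
SCALE = 0.05912807211279869   # conv1 output scale printed above
ZERO_POINT = 0
QMIN, QMAX = 0, 127           # activation range from the printed default_qconfig

def quantize(x):
    return max(QMIN, min(QMAX, round(x / SCALE) + ZERO_POINT))

def dequantize(q):
    return (q - ZERO_POINT) * SCALE

largest = dequantize(QMAX)                   # about 7.509: top of conv1's representable range
err = abs(dequantize(quantize(1.0)) - 1.0)   # rounding error for an in-range value (<= SCALE / 2)
saturated = dequantize(quantize(10.0))       # out-of-range values clip to `largest`

print(round(largest, 3), round(err, 4), saturated == largest)  # 7.509 0.0052 True
```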


In [14]:
score = test(qnet, testloader, cuda=False)  # Evaluate the quantized model on test data
print('Accuracy of the fused and quantized network on the test images: {}% - INT8'.format(score))

# Potential Q&A:
# Q: How does quantization affect accuracy?
# A: Done carefully, it can leave accuracy essentially unchanged: here INT8 scores 98.11% vs. 98.09% for FP32, while the saved model shrinks from 0.179 MB to 0.050 MB.

Accuracy of the fused and quantized network on the test images: 98.11% - INT8


We can also define a custom quantization configuration that replaces the default observers. Instead of using the absolute min/max seen during calibration, `MovingAverageMinMaxObserver` tracks a moving average of the per-batch min/max, which makes the chosen range less sensitive to a few outlier batches and, hopefully, generalizes better to unseen data.
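The update rule behind `MovingAverageMinMaxObserver` is simple: the first batch sets the range, and each later batch moves it by `averaging_constant * (batch_min_or_max - running_value)`, with a default `averaging_constant` of 0.01. The plain-Python sketch below (`moving_average_range` is a hypothetical helper, not the PyTorch class) compares it with a plain min/max observer on a made-up calibration stream in which one batch contains an outlier.

```python
def moving_average_range(batches, averaging_constant=0.01):
    """First batch sets the range; later batches nudge it toward their own min/max."""
    min_val = max_val = None
    for batch in batches:
        b_min, b_max = min(batch), max(batch)
        if min_val is None:
            min_val, max_val = b_min, b_max
        else:
            min_val += averaging_constant * (b_min - min_val)
            max_val += averaging_constant * (b_max - max_val)
    return min_val, max_val

# Hypothetical calibration stream: batches peak at 4.0, except one outlier batch at 40.0
batches = [[0.0, 4.0]] * 50 + [[0.0, 40.0]] + [[0.0, 4.0]] * 49

global_max = max(max(b) for b in batches)                   # plain min/max keeps the outlier
averaged_min, averaged_max = moving_average_range(batches)  # the outlier barely moves the range

print(global_max, round(averaged_max, 2))  # 40.0 4.22
```

With plain min/max, a single extreme batch makes the range ten times wider than typical activations and wastes most of the 8-bit levels; the moving average stays near 4.2.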

In [15]:
from torch.quantization.observer import MovingAverageMinMaxObserver  # Import a custom observer for quantization

qnet = Net(q=True)  # Create a quantized network
load_model(qnet, net)  # Load pretrained weights into the quantized model
fuse_modules(qnet)  # Fuse modules (e.g., Conv+ReLU) for improved quantization

qnet.qconfig = torch.quantization.QConfig(
    activation=MovingAverageMinMaxObserver.with_args(reduce_range=True),
    weight=MovingAverageMinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)
)
print(qnet.qconfig)  # Print the custom quantization configuration
torch.quantization.prepare(qnet, inplace=True)  # Insert observers into the model for post-training calibration
print('Post Training Quantization Prepare: Inserting Observers')
print('\n Conv1: After observer insertion \n\n', qnet.conv1)  # Inspect the conv layer after observer insertion

test(qnet, trainloader, cuda=False)  # Run calibration on training data
print('Post Training Quantization: Calibration done')
torch.quantization.convert(qnet, inplace=True)  # Convert the model to its quantized version
print('Post Training Quantization: Convert done')
print('\n Conv1: After fusion and quantization \n\n', qnet.conv1)  # Inspect the quantized conv layer
print("Size of model after quantization")
print_size_of_model(qnet)  # Print model size to verify reduction
score = test(qnet, testloader, cuda=False)  # Evaluate the quantized model on test data
print('Accuracy of the fused and quantized network on the test images: {}% - INT8'.format(score))

# Potential Q&A:
# Q: What does QConfig specify?
# A: It specifies the quantization observers for activations and weights, determining how quantization parameters are computed.

QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.MovingAverageMinMaxObserver'>, reduce_range=True){}, weight=functools.partial(<class 'torch.ao.quantization.observer.MovingAverageMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric){})
Post Training Quantization: Calibration done
Post Training Quantization: Convert done

 Conv1: After fusion and quantization 

 QuantizedConvReLU2d(1, 6, kernel_size=(5, 5), stride=(1, 1), scale=0.05865493789315224, zero_point=0, bias=False)
Size of model after quantization
Size (MB): 0.050084
Accuracy of the fused and quantized network on the test images: 98.13% - INT8


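Next, we repeat the same exercise with the recommended default configuration for the arm64 (QNNPACK) backend. As the printed `QConfig` below shows, in the PyTorch version used here this configuration:

- Uses a histogram observer for activations: it collects a histogram of activation values and picks the quantization range that minimizes quantization error, instead of taking the raw min/max.
- Quantizes weights per tensor (a symmetric `MinMaxObserver`). Per-channel weight quantization is what the x86 default, `get_default_qconfig('fbgemm')`, uses.

On a model this small, the choice of configuration barely matters: this run reaches 98.02%, compared with 98.11% and 98.13% for the configurations above.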
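Per-tensor versus per-channel weight quantization is easiest to see with two output channels of very different magnitude. With one symmetric scale for the whole tensor, the scale is set by the largest |w| anywhere, and the small channel is left with only a handful of levels; a scale per channel fixes that. The sketch below is plain Python with hypothetical weights. It follows the symmetric int8 convention of dividing the largest |w| by half the width of the [-128, 127] range, but `symmetric_scale` and `max_round_trip_error` are illustrative helpers, not PyTorch APIs.

```python
QMIN, QMAX = -128, 127  # qint8

def symmetric_scale(weights):
    """One symmetric scale for a group of weights: largest |w| over half the int8 range."""
    return max(abs(w) for w in weights) / ((QMAX - QMIN) / 2)

def max_round_trip_error(weights, scale):
    """Largest |w - dequantize(quantize(w))| over the group."""
    errors = []
    for w in weights:
        q = max(QMIN, min(QMAX, round(w / scale)))
        errors.append(abs(q * scale - w))
    return max(errors)

# Hypothetical weights of two output channels with very different magnitudes
channel_a = [0.8, -1.0, 0.5]
channel_b = [0.011, -0.004, 0.007]

per_tensor_scale = symmetric_scale(channel_a + channel_b)  # set by |-1.0| from channel_a
err_b_per_tensor = max_round_trip_error(channel_b, per_tensor_scale)
err_b_per_channel = max_round_trip_error(channel_b, symmetric_scale(channel_b))

print(f"{err_b_per_tensor:.1e} {err_b_per_channel:.1e}")  # 3.8e-03 4.3e-05
```

With the shared scale, channel_b's error is about as large as its weights; its own scale cuts the error by nearly two orders of magnitude.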

In [16]:
qnet = Net(q=True)  # Reinitialize quantized network for custom config
load_model(qnet, net)  # Load pretrained weights
fuse_modules(qnet)  # Fuse modules before applying new quantization configuration

# Inline Q&A:
# Q: Why reinitialize qnet for a custom configuration?
# A: The previous qnet was already converted in place to INT8 modules, so it cannot be prepared again; we start from a fresh network holding the float weights.

In [17]:
qnet.qconfig = torch.quantization.get_default_qconfig('qnnpack')  # Use the default QNNPACK config for arm64
print(qnet.qconfig)  # Print the new quantization configuration

torch.quantization.prepare(qnet, inplace=True)  # Prepare model with new observers
test(qnet, trainloader, cuda=False)  # Run calibration
torch.quantization.convert(qnet, inplace=True)  # Convert to quantized model
print("Size of model after quantization")
print_size_of_model(qnet)  # Print model size after applying new configuration

# Potential Q&A:
# Q: What does get_default_qconfig('qnnpack') do?
# A: It returns a quantization configuration optimized for ARM architectures using QNNPACK.

QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.HistogramObserver'>, reduce_range=False){}, weight=functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric){})
Size of model after quantization
Size (MB): 0.050084


In [18]:
score = test(qnet, testloader, cuda=False)  # Evaluate the quantized network on the test dataset
print('Accuracy of the fused and quantized network on the test images: {}% - INT8'.format(score))

# Inline Q&A:
# Q: How is the accuracy affected after quantization?
# A: With proper calibration and conversion, accuracy remains high (here, around 98% using INT8 precision).

Accuracy of the fused and quantized network on the test images: 98.02% - INT8


### Quantization aware training

Quantization-aware training (QAT) is the quantization method that typically results in the highest accuracy. With QAT, all weights and activations are “fake quantized” during both the forward and backward passes of training: that is, float values are rounded to mimic int8 values, but all computations are still done with floating point numbers.

**Key Points:**
- QAT simulates quantization effects during training so the model learns to adapt to quantized weights and activations.
- This usually results in better accuracy compared to post-training quantization.

**Potential Q&A:**
- *Q: What is the main advantage of QAT?*
  *A: QAT typically yields higher accuracy because the network adjusts its weights during training to account for quantization effects.*
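In practice, the backward pass through fake quantization is usually handled with a straight-through estimator: rounding has zero gradient almost everywhere, so the backward pass treats quantize-then-dequantize as the identity inside the representable range and blocks the gradient where the forward pass clamped. The plain-Python sketch below illustrates both passes for a single value. `fake_quantize` and `fake_quantize_grad` are hypothetical helpers; PyTorch's fake-quantize modules implement the same idea with tensor ops, so treat this as an illustration rather than a reimplementation.

```python
def fake_quantize(x, scale, zero_point, qmin=0, qmax=255):
    """Forward pass: quantize to an integer level, then immediately dequantize (stays a float)."""
    q = max(qmin, min(qmax, round(x / scale) + zero_point))
    return (q - zero_point) * scale

def fake_quantize_grad(x, upstream_grad, scale, zero_point, qmin=0, qmax=255):
    """Backward pass (straight-through estimator): identity inside the range, zero where clamped."""
    lo = (qmin - zero_point) * scale
    hi = (qmax - zero_point) * scale
    return upstream_grad if lo <= x <= hi else 0.0

scale, zero_point = 0.1, 0                                # hypothetical quantization parameters
y = fake_quantize(0.234, scale, zero_point)               # snapped to the nearest level, 0.2
g_in = fake_quantize_grad(0.234, 1.0, scale, zero_point)  # 1.0: gradient passes straight through
g_out = fake_quantize_grad(30.0, 1.0, scale, zero_point)  # 0.0: 30.0 is above 255 * 0.1 = 25.5

print(round(y, 6), g_in, g_out)  # 0.2 1.0 0.0
```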

In [19]:
qnet = Net(q=True)  # Create a quantization-ready network for QAT (no load_model call: QAT starts from random weights here)
fuse_modules(qnet)  # Fuse modules for optimal QAT performance
qnet.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')  # Set QAT configuration for arm64
torch.quantization.prepare_qat(qnet, inplace=True)  # Prepare the model for QAT by inserting fake quantization modules
print('\n Conv1: After fusion and quantization \n\n', qnet.conv1)
qnet = qnet.cuda()  # Move model to GPU for training
train(qnet, trainloader, cuda=True, q=True)  # q=True so train() freezes the observers after 3 epochs
qnet = qnet.cpu()  # Move the trained model back to CPU
qnet.eval()  # Switch to eval mode before converting
torch.quantization.convert(qnet, inplace=True)  # Convert the QAT model to a fully quantized (INT8) model
print("Size of model after quantization")
print_size_of_model(qnet)  # Print the size of the quantized model

score = test(qnet, testloader, cuda=False)  # Evaluate the quantized model on test data
print('Accuracy of the fused and quantized network (trained quantized) on the test images: {}% - INT8'.format(score))

# Inline Q&A:
# Q: What is the purpose of QAT?
# A: QAT simulates quantization during training to help the network adjust its weights, resulting in better accuracy after conversion to INT8.


 Conv1: After fusion and quantization 

 ConvReLU2d(
  1, 6, kernel_size=(5, 5), stride=(1, 1), bias=False
  (weight_fake_quant): FusedMovingAvgObsFakeQuantize(
    fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.qint8, quant_min=-128, quant_max=127, qscheme=torch.per_tensor_symmetric, reduce_range=False
    (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  )
  (activation_post_process): FusedMovingAvgObsFakeQuantize(
    fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.quint8, quant_min=0, quant_max=255, qscheme=torch.per_tensor_affine, reduce_range=False
    (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
  )
)
[1,     1]  loss 2.304440 (2.304440) train_acc 9.375000 (9.375000)
...
Finished Training
Size of model after quantization
Size (MB): 