# Practicing Post-Training Static Quantization

This notebook provides a practical example of **post-training static quantization**, a technique used to make neural network models smaller and faster. Quantization works by converting the model's weights and activations from 32-bit floating-point numbers to lower-precision integers, typically 8-bit.
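To make the float-to-integer conversion concrete, here is a minimal sketch of affine quantization in plain Python. The helper names `quantize` and `dequantize` and the chosen `scale` and `zero_point` are illustrative assumptions, not values PyTorch would necessarily pick for a real layer.

```python
def quantize(x, scale, zero_point, qmin=0, qmax=255):
    """Map a float to an unsigned 8-bit code: clamp(round(x / scale) + zero_point)."""
    q = round(x / scale) + zero_point
    return max(qmin, min(qmax, q))


def dequantize(q, scale, zero_point):
    """Recover an approximate float from its integer code."""
    return (q - zero_point) * scale


# Hypothetical parameters covering roughly the range [-1.0, 1.0].
scale, zero_point = 2.0 / 255, 128

values = [-1.0, -0.25, 0.0, 0.4, 1.0]
codes = [quantize(v, scale, zero_point) for v in values]
restored = [dequantize(q, scale, zero_point) for q in codes]

for v, q, r in zip(values, codes, restored):
    print(f"{v:+.3f} -> {q:3d} -> {r:+.4f} (error {abs(v - r):.4f})")
```

Each value comes back with an error of at most about half a quantization step (`scale / 2`), which is the precision given up in exchange for 8-bit storage.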

**Benefits:**
- **Reduced Model Size**: Smaller models require less storage and are easier to deploy.
- **Faster Inference**: 8-bit integer arithmetic is often faster than 32-bit floating-point arithmetic, and moving a quarter of the data reduces memory bandwidth pressure, especially on CPUs and edge devices.
- **Lower Power Consumption**: Less computation and memory traffic generally mean lower energy use.

We will use:
- **Model**: A pre-trained `MobileNetV2` from `torchvision`.
- **Dataset**: `CIFAR-10`.
- **Technique**: Post-Training Static Quantization in PyTorch.

## 1. Imports and Setup

First, we'll import the necessary libraries. We run everything on the CPU, because the quantized kernels used by PyTorch's eager-mode quantization (such as `fbgemm`) are CPU backends.

In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
# Static quantization uses prepare/convert; quantize_dynamic belongs to dynamic quantization
from torch.quantization import prepare, convert
import os
import time
from copy import deepcopy

# Set device to CPU for quantization
device = torch.device("cpu")
print(f"Using device: {device}")

## 2. Load Data and Pre-trained Model

We will load the CIFAR-10 dataset and a pre-trained MobileNetV2 model. We'll also define a helper function to evaluate the model's accuracy.

In [None]:
# Data transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load CIFAR-10 test dataset
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)

# Helper function for evaluation
def evaluate_model(model, dataloader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total

# Load a pre-trained MobileNetV2 model
original_model = torchvision.models.mobilenet_v2(weights=torchvision.models.MobileNet_V2_Weights.IMAGENET1K_V1)
# Replace the classifier head for CIFAR-10's 10 classes (the new layer is randomly initialized)
original_model.classifier[1] = nn.Linear(original_model.last_channel, 10)
original_model.to(device)
original_model.eval()

print("Loaded pre-trained MobileNetV2 model.")

## 3. Prepare the Model for Quantization

Post-training static quantization involves a few preparation steps:
1.  **Fuse Modules**: We merge sequences such as Convolution -> BatchNorm -> ReLU into a single module. Fusion lets the sequence run as one quantized operation with no intermediate quantization steps, which helps both accuracy and speed.
2.  **Specify Quantization Config**: We define how the weights and activations should be quantized. `fbgemm` is a standard backend for x86 CPUs.
3.  **Prepare the Model**: We insert "observer" modules into the model. These observers will watch the range of activation values during a calibration step.
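The fusion step can be tried in isolation on a toy block. The sketch below is a stand-in for one Conv -> BatchNorm -> ReLU sequence, not MobileNetV2 itself; the layer sizes and random inputs are assumptions chosen only to keep it small. It checks that the fused block computes the same function as the unfused one.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-in for one Conv -> BatchNorm -> ReLU block.
block = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(8),
    nn.ReLU(),
)

# Give BatchNorm non-trivial running statistics so the folding is visible.
block.train()
with torch.no_grad():
    for _ in range(5):
        block(torch.randn(4, 3, 16, 16))
block.eval()  # fusion for post-training quantization expects eval mode

# Fold BatchNorm into the convolution and attach the ReLU to it.
fused = torch.quantization.fuse_modules(block, [["0", "1", "2"]], inplace=False)

x = torch.randn(1, 3, 16, 16)
with torch.no_grad():
    max_diff = (block(x) - fused(x)).abs().max().item()

print(fused)
print(f"max abs difference: {max_diff:.2e}")
```

After fusion, the first slot holds the combined module and the other two are replaced by `nn.Identity`, so the module indices of the container stay the same.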

In [None]:
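# The stock torchvision MobileNetV2 cannot be statically quantized in eager mode as-is:
# it has no QuantStub/DeQuantStub, adds residuals with `+`, and uses ReLU6, which does not fuse.
# torchvision ships a quantization-ready variant that addresses all three.
quantized_model = torchvision.models.quantization.mobilenet_v2(weights=None, quantize=False)
quantized_model.classifier[1] = nn.Linear(quantized_model.last_channel, 10)
# Parameter names match the float model, so its weights load directly.
# Note: this variant uses ReLU in place of ReLU6, so float outputs can differ slightly.
quantized_model.load_state_dict(original_model.state_dict())
quantized_model.to(device)
quantized_model.eval()

# Fuse Conv + BatchNorm (+ ReLU) throughout the network
quantized_model.fuse_model()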

# Specify quantization configuration
# 'fbgemm' is a backend for x86 CPUs
quantized_model.qconfig = torch.quantization.get_default_qconfig('fbgemm')

# Prepare the model for static quantization. This inserts observers in the model
torch.quantization.prepare(quantized_model, inplace=True)

print("Model prepared for quantization.")

## 4. Calibrate the Model

Now we need to run a small amount of representative data through the prepared model. The observers we inserted record statistics of the activations they see, such as their minimum and maximum values. These statistics are then used to compute the quantization parameters (scale and zero-point) for each quantized activation.
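To see what an observer actually records, the sketch below feeds two made-up batches of values into a standalone `MinMaxObserver`, one of the simpler observers in PyTorch. The default qconfig may use a different, histogram-based observer for activations, so treat this as an illustration of the idea, not of the exact numbers the notebook will produce.

```python
import torch
from torch.quantization import MinMaxObserver

# Observer configured for unsigned 8-bit activations with an affine mapping.
observer = MinMaxObserver(dtype=torch.quint8, qscheme=torch.per_tensor_affine)

# Two hypothetical "calibration batches" of activations.
observer(torch.tensor([-1.0, 0.5, 2.0]))
observer(torch.tensor([0.0, 1.5, 3.0]))

# The running min/max determine the scale and zero-point.
scale, zero_point = observer.calculate_qparams()
print(f"observed range: [{observer.min_val.item()}, {observer.max_val.item()}]")
print(f"scale={scale.item():.5f}, zero_point={zero_point.item()}")
```

The scale comes out as the observed range divided by the 255 available steps, and the zero-point is the integer code that represents the float value 0.0.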

In [None]:
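print("Calibrating the model with a subset of the test data...")
# Use a subset of the test set for calibration. In practice, prefer held-out training data
# so that the evaluation set stays unseen.
calibration_loader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=True)
num_calibration_batches = 10  # 10 batches of 100 images

with torch.no_grad():
    for batch_idx, (images, _) in enumerate(calibration_loader):
        if batch_idx >= num_calibration_batches:
            break
        quantized_model(images)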

print("Calibration complete.")

## 5. Convert to a Quantized Model

With the calibration data collected, we can now convert the model. This step replaces the floating-point operations with their integer-based, quantized equivalents.
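To see what conversion does apart from MobileNetV2, the sketch below runs the whole fuse, prepare, calibrate, convert sequence on a tiny hypothetical model. `TinyNet`, its layer sizes, and the random calibration tensors are all assumptions made for illustration. The backend is taken from whatever quantized engine the local PyTorch build selects by default.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical toy model; the stubs mark where tensors enter and leave int8."""

    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.relu(self.conv(self.quant(x))))


torch.manual_seed(0)
model = TinyNet().eval()

# Assumption: use the quantized engine this PyTorch build selects by default.
backend = torch.backends.quantized.engine
model.qconfig = torch.quantization.get_default_qconfig(backend)

fused = torch.quantization.fuse_modules(model, [["conv", "relu"]])
prepared = torch.quantization.prepare(fused)  # inserts observers

with torch.no_grad():
    for _ in range(4):  # random tensors stand in for calibration data
        prepared(torch.randn(2, 3, 8, 8))

int8_model = torch.quantization.convert(prepared)  # swaps in quantized modules

print(int8_model)
with torch.no_grad():
    out = int8_model(torch.randn(1, 3, 8, 8))
print(out.dtype, tuple(out.shape))
```

Printing the converted model shows the convolution replaced by a quantized module that stores its scale and zero-point, while the output is a regular float tensor again because `DeQuantStub` converts it back.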

In [None]:
print("Converting the model to a quantized version...")
torch.quantization.convert(quantized_model, inplace=True)
print("Model converted.")

## 6. Compare and Evaluate

Now for the moment of truth. Let's compare the original floating-point model with our new quantized integer model in three key areas: model size, accuracy, and inference speed.

In [None]:
# --- 1. Compare Model Size ---
def print_model_size(model, label):
    torch.save(model.state_dict(), "temp.p")
    size = os.path.getsize("temp.p") / 1e6 # size in MB
    print(f"Size of {label}: {size:.2f} MB")
    os.remove("temp.p")

print_model_size(original_model, "Original Model")
print_model_size(quantized_model, "Quantized Model")

# --- 2. Compare Accuracy ---
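# NOTE: The new 10-class classifier head is randomly initialized and the model is not
# fine-tuned on CIFAR-10, so both accuracies will be close to chance (about 10%).
# Treat this as a check that the pipeline runs; fine-tune first to measure a meaningful drop.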
print("\nEvaluating accuracy...")
original_accuracy = evaluate_model(original_model, testloader)
quantized_accuracy = evaluate_model(quantized_model, testloader)

print(f"Accuracy of Original Model: {original_accuracy:.2f}%")
print(f"Accuracy of Quantized Model: {quantized_accuracy:.2f}%")
print(f"Accuracy Drop: {original_accuracy - quantized_accuracy:.2f} percentage points")

# --- 3. Compare Inference Speed ---
def time_inference(model, dataloader):
    model.eval()
    latencies = []
    with torch.no_grad():
        for images, _ in dataloader:
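            # perf_counter is a monotonic, high-resolution clock suited to timing
            start = time.perf_counter()
            _ = model(images)
            end = time.perf_counter()
            latencies.append(end - start)
    # Mean per-batch latency in seconds
    return sum(latencies) / len(latencies)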

print("\nComparing inference speed...")
original_latency = time_inference(original_model, testloader)
quantized_latency = time_inference(quantized_model, testloader)

print(f"Average latency of Original Model: {original_latency * 1000:.2f} ms")
print(f"Average latency of Quantized Model: {quantized_latency * 1000:.2f} ms")
print(f"Speedup: {original_latency / quantized_latency:.2f}x")

## 7. Conclusion

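As the results show, post-training static quantization reduces the stored model size, by roughly 4x when most parameters move from 32-bit floats to 8-bit integers, and can speed up inference on the CPU. The trade-off is usually a small drop in accuracy, which is best measured on a model that has actually been trained for the task, since the untrained classifier head used here keeps both accuracies near chance. For many applications, especially on resource-constrained devices, this trade-off is highly favorable.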
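The roughly 4x size reduction follows from simple storage arithmetic. The sketch below assumes about 3.5 million parameters, the commonly quoted size of the ImageNet MobileNetV2; the helper `weights_size_mb` is hypothetical and counts weights only.

```python
def weights_size_mb(num_params, bits_per_param):
    """Storage for the weights alone, in megabytes (1 MB = 1e6 bytes)."""
    return num_params * bits_per_param / 8 / 1e6


# Assumption: roughly 3.5 million parameters.
num_params = 3_500_000

fp32_mb = weights_size_mb(num_params, 32)
int8_mb = weights_size_mb(num_params, 8)

print(f"float32: {fp32_mb:.1f} MB, int8: {int8_mb:.1f} MB, ratio: {fp32_mb / int8_mb:.1f}x")
```

A saved quantized checkpoint is usually somewhat larger than this estimate, because it also stores scales, zero-points, and any tensors that remain in floating point.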