In [None]:
import os
import sys
sys.path.append("../..")

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from torchvision import transforms
from torchvision.datasets import CIFAR10

from utils.train import evaluate

In [None]:
# Per-channel normalization statistics commonly used for CIFAR-10 (RGB order).
# They must match whatever normalization the FP32 checkpoint was trained with.
# NOTE(review): this std triple is the widely copied one; the per-channel std of
# the CIFAR-10 training pixels is often quoted as ~(0.247, 0.243, 0.261).
# Only change it together with retraining the model.
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2023, 0.1994, 0.2010)

# Image -> float tensor, then standardize each channel with the stats above.
t = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=cifar10_mean, std=cifar10_std),
    ]
)

# Test split only; shuffle=False keeps batch order fixed so every model
# evaluated below sees identical batches.
ds = CIFAR10(root="../../assets/cifar10", train=False, download=True, transform=t)
dl = DataLoader(ds, batch_size=64, shuffle=False)

## Static Quantization

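Static quantization converts weights to low-precision integers ahead of time and also fixes the quantization parameters (scale and zero point) of every activation in advance, using statistics collected by running a small calibration dataset through the model. Because no activation ranges have to be computed at inference time, it gives the largest compression and speedup on supported CPU backends.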

<p align="center">
    <img src="../../assets/img/deployment/static_quantization.png" width="400">
</p>

In [None]:
class ConvBlock(nn.Module):
    """3x3 conv (same padding) -> BatchNorm -> ReLU -> 2x2 max-pool.

    Maps ``in_channels`` to ``out_channels`` and halves height and width.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        norm = nn.BatchNorm2d(out_channels)
        self.block = nn.Sequential(conv, norm, nn.ReLU(), nn.MaxPool2d(2, 2))

    def forward(self, x):
        return self.block(x)


class QuantizedCNN(nn.Module):
    """CIFAR-10 CNN wrapped in Quant/DeQuant stubs for eager-mode static quantization.

    Args:
        base_channels: output width of the first ConvBlock; each later block doubles it.
        num_conv_layers: number of ConvBlocks; each one halves the 32x32 input side.
        dropout_p: dropout probability between the two classifier Linear layers.

    The stubs are identity ops until ``prepare``/``convert`` replace them. They hold
    no parameters, so the state_dict keys match the plain FP32 ``SimpleCNN``.
    """

    def __init__(self, base_channels=32, num_conv_layers=3, dropout_p=0.5):
        super().__init__()

        # Boundaries of the quantized region: input is quantized here...
        self.quant = torch.quantization.QuantStub()
        # ...and logits are converted back to float here.
        self.dequant = torch.quantization.DeQuantStub()

        blocks = []
        width_in, width_out, side = 3, base_channels, 32
        for _ in range(num_conv_layers):
            blocks.append(ConvBlock(width_in, width_out))
            # the next block consumes this one's output; every pool halves the side
            width_in, width_out, side = width_out, width_out * 2, side // 2
        self.features = nn.Sequential(*blocks)

        flat_features = width_in * side * side
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_features, 128),
            nn.ReLU(),
            nn.Dropout(p=dropout_p),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        x = self.quant(x)
        logits = self.classifier(self.features(x))
        return self.dequant(logits)

In [None]:
# Load the trained FP32 weights. The Quant/DeQuant stubs carry no state, so the
# checkpoint keys must match exactly; strict loading surfaces any architecture
# mismatch instead of silently leaving layers randomly initialized.
model_fp32 = QuantizedCNN(num_conv_layers=3, base_channels=32, dropout_p=0.5)
model_fp32.load_state_dict(
    torch.load("../../assets/models/model_fp32.pth", map_location="cpu")
)
model_fp32.eval()  # required for Conv+BN fusion and for calibration

# Backend: "fbgemm"/"x86" targets Intel/AMD CPUs (Windows/Linux), "qnnpack" targets
# ARM (mobile). The runtime engine and the qconfig must use the same backend.
QUANT_ENGINE = "qnnpack"
assert QUANT_ENGINE in torch.backends.quantized.supported_engines, (
    f"{QUANT_ENGINE!r} is not available; supported engines: "
    f"{torch.backends.quantized.supported_engines}"
)
torch.backends.quantized.engine = QUANT_ENGINE

# Fuse Conv+BN+ReLU in every ConvBlock and Linear+ReLU in the classifier so each
# group becomes a single quantized op (BN is folded into the conv weights).
# fuse_modules(inplace=False) works on a deep copy, so model_fp32 remains an
# untouched FP32 baseline for the accuracy comparison below.
modules_to_fuse = [
    [f"features.{i}.block.0", f"features.{i}.block.1", f"features.{i}.block.2"]
    for i in range(len(model_fp32.features))
] + [["classifier.1", "classifier.2"]]
model_prepared = torch.quantization.fuse_modules(
    model_fp32, modules_to_fuse, inplace=False
)
model_prepared.qconfig = torch.quantization.get_default_qconfig(QUANT_ENGINE)
torch.quantization.prepare(model_prepared, inplace=True)  # inserts observers into the copy

# Calibration: run a few batches so the observers record activation ranges.
# NOTE: this reuses the test loader that is also used for evaluation; a held-out
# slice of the training set would avoid that overlap.
NUM_CALIBRATION_BATCHES = 10
with torch.no_grad():
    for batch_idx, (images, _) in enumerate(dl):
        if batch_idx >= NUM_CALIBRATION_BATCHES:
            break
        model_prepared(images)

model_int8 = torch.quantization.convert(model_prepared, inplace=False)

In [None]:
def get_file_size_mb(path):
    """Return the on-disk size of the file at ``path`` in MB (1 MB = 1024**2 bytes)."""
    n_bytes = os.path.getsize(path)
    return n_bytes / (1024 ** 2)


static_int8_path = '../../assets/models/model_int8_static.pth'
torch.save(model_int8.state_dict(), static_int8_path)

fp32_size = get_file_size_mb("../../assets/models/model_fp32.pth")
int8_size = get_file_size_mb(static_int8_path)

for label, size in (("FP32", fp32_size), ("INT8", int8_size)):
    print(f"{label} model size: {size:.2f} MB")
print(f"Compression ratio: {fp32_size / int8_size:.2f}×")

In [None]:
# Compare the FP32 baseline with the statically quantized model on the same
# test batches. CrossEntropyLoss holds no state, so one instance serves both runs.
criterion = nn.CrossEntropyLoss()
_, fp32_acc = evaluate(model_fp32, dl, criterion, "cpu")
_, int8_acc = evaluate(model_int8, dl, criterion, "cpu")

print(f"FP32 model accuracy: {fp32_acc:.2f}%")
print(f"INT8 model accuracy: {int8_acc:.2f}%")

## Dynamic Quantization

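Dynamic quantization converts weights to integers ahead of time, but computes activation scales on the fly from each input during inference, so it requires no calibration data. PyTorch's eager-mode API supports it for `nn.Linear` and recurrent layers (such as `nn.LSTM`), so below only the classifier's linear layers are quantized and the convolutional layers stay in FP32.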

In [None]:
class ConvBlock(nn.Module):
    """Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU -> MaxPool2d(2, 2).

    Redefined here so the dynamic-quantization section can run on its own;
    the module layout is the same as in the static section.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        return self.block(x)


class SimpleCNN(nn.Module):
    """Plain FP32 CIFAR-10 CNN: a stack of ConvBlocks followed by a 2-layer MLP head.

    Args:
        num_conv_layers: number of ConvBlocks; each halves the 32x32 input side.
        base_channels: width of the first ConvBlock; every further block doubles it.
        dropout_p: dropout probability inside the classifier head.
    """

    def __init__(self, num_conv_layers=3, base_channels=32, dropout_p=0.5):
        super().__init__()

        conv_stack = []
        prev_width = 3
        next_width = base_channels
        side = 32
        for _ in range(num_conv_layers):
            conv_stack.append(ConvBlock(prev_width, next_width))
            prev_width = next_width
            next_width = next_width * 2
            side = side // 2  # MaxPool2d(2, 2) halves height and width
        self.features = nn.Sequential(*conv_stack)

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(prev_width * side * side, 128),
            nn.ReLU(),
            nn.Dropout(p=dropout_p),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        return self.classifier(self.features(x))

In [None]:
# Build a fresh plain FP32 model (no quant stubs) and load the trained weights
# into it. Dynamic quantization must start from this model, not from the
# QuantizedCNN `model_fp32` used in the static-quantization section.
model = SimpleCNN(num_conv_layers=3, base_channels=32, dropout_p=0.5)
model.load_state_dict(
    torch.load("../../assets/models/model_fp32.pth", map_location="cpu")
)
model.eval()

# Linear weights are converted to int8 now; activation scales are computed per
# input at inference time. quantize_dynamic returns a new module, leaving `model`
# in FP32.
model_int8 = torch.quantization.quantize_dynamic(
    model,
    {torch.nn.Linear},   # only quantize Linear layers
    dtype=torch.qint8,
)

In [None]:
def get_file_size_mb(path):
    """Size of the file at ``path`` in MB, where 1 MB = 1024**2 bytes.

    Redefined here so the dynamic-quantization section can run on its own.
    """
    return os.path.getsize(path) / (1024 ** 2)


dynamic_int8_path = '../../assets/models/model_int8_dynamic.pth'
torch.save(model_int8.state_dict(), dynamic_int8_path)

fp32_size = get_file_size_mb("../../assets/models/model_fp32.pth")
int8_size = get_file_size_mb(dynamic_int8_path)
ratio = fp32_size / int8_size

print(f"FP32 model size: {fp32_size:.2f} MB")
print(f"INT8 model size: {int8_size:.2f} MB")
print(f"Compression ratio: {ratio:.2f}×")