<a href="https://colab.research.google.com/github/chopratejas/LeetCode-solutions/blob/master/02_03_QAT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install torch torchvision



In [3]:
# Quantization Aware Training (QAT) Demo using PyTorch and CIFAR-10

import torch
import torch.nn as nn
import torch.optim as optim
import torch.quantization
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import random
import os
import time

# Set random seed for reproducibility
torch.manual_seed(42)
random.seed(42)

# Quantizable CNN Model
class QuantizableCNN(nn.Module):
    def __init__(self):
        super(QuantizableCNN, self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 8 * 8, 256), nn.ReLU(), nn.Linear(256, 10)
        )
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.features(x)
        x = self.classifier(x)
        x = self.dequant(x)
        return x

# Data Loaders
def get_data(batch_size=128):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
    test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)
    train_subset = Subset(train_dataset, list(range(10000)))
    return DataLoader(train_subset, batch_size=batch_size, shuffle=True), DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Utility Functions
def get_model_size(model):
    torch.save(model.state_dict(), "temp.p")
    size = os.path.getsize("temp.p") / 1e6
    os.remove("temp.p")
    return size

def evaluate_model(model, dataloader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            outputs = model(x)
            _, predicted = torch.max(outputs.data, 1)
            total += y.size(0)
            correct += (predicted == y).sum().item()
    return correct / total

def measure_inference_time(model, dataloader, device, num_runs=50):
    model.eval()
    x, _ = next(iter(dataloader))
    x = x.to(device)
    with torch.no_grad():
        for _ in range(5):
            _ = model(x)
        start = time.time()
        for _ in range(num_runs):
            _ = model(x)
        end = time.time()
    return (end - start) / num_runs

# Train QAT Model
def train_qat(model, loader, device):
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    criterion = nn.CrossEntropyLoss()

    # Initial FP32 training
    for i, (x, y) in enumerate(loader):
        if i > 100: break
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        out = model(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()

    # QAT Prep and Fine-tuning
    model.cpu()
    model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
    torch.quantization.prepare_qat(model, inplace=True)
    model.train()

    for epoch in range(2):
        for i, (x, y) in enumerate(loader):
            if i > 200: break
            x, y = x.to("cpu"), y.to("cpu")
            optimizer.zero_grad()
            out = model(x)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()

    # Convert to quantized model
    model.eval()
    quantized_model = torch.quantization.convert(model.eval(), inplace=False)
    return quantized_model

# Run Demo
def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, test_loader = get_data()

    # Train baseline model
    print("Training baseline FP32 model...")
    model_fp32 = QuantizableCNN().to(device)
    optimizer = optim.Adam(model_fp32.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    model_fp32.train()
    for i, (x, y) in enumerate(train_loader):
        if i > 100: break
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        out = model_fp32(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()

    acc_fp32 = evaluate_model(model_fp32, test_loader, device)
    size_fp32 = get_model_size(model_fp32)
    time_fp32 = measure_inference_time(model_fp32, test_loader, device)

    # Train QAT model
    print("\nTraining and converting QAT model...")
    qat_model = QuantizableCNN().to(device)
    qat_model.load_state_dict(model_fp32.state_dict())
    qat_model = train_qat(qat_model, train_loader, device)

    acc_qat = evaluate_model(qat_model, test_loader, "cpu")
    size_qat = get_model_size(qat_model)
    time_qat = measure_inference_time(qat_model, test_loader, "cpu")

    # Results
    print("\n✅ QAT training and conversion complete.")
    print("\n--- Model Comparison ---")
    print(f"FP32 -> Accuracy: {acc_fp32:.4f}, Inference Time: {time_fp32*1000:.2f}ms, Size: {size_fp32:.2f}MB")
    print(f"QAT  -> Accuracy: {acc_qat:.4f}, Inference Time: {time_qat*1000:.2f}ms, Size: {size_qat:.2f}MB")

if __name__ == "__main__":
    main()


100%|██████████| 170M/170M [00:04<00:00, 36.7MB/s]


Training baseline FP32 model...

Training and converting QAT model...





✅ QAT training and conversion complete.

--- Model Comparison ---
FP32 -> Accuracy: 0.4496, Inference Time: 88.28ms, Size: 4.29MB
QAT  -> Accuracy: 0.4971, Inference Time: 42.61ms, Size: 1.09MB
