In [None]:
!pip install torch torchvision matplotlib pandas seaborn -q

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, Subset
import torchvision
import torchvision.transforms as transforms

In [None]:
# Quantization imports
import torch.quantization as quantization
from torch.quantization import QuantStub, DeQuantStub
from torch.quantization.quantize_fx import prepare_fx, convert_fx

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import numpy as np
import time
import os
from pathlib import Path

In [None]:
# Report the runtime environment once, then pick the devices used below.
_cuda_ready = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {_cuda_ready}")

# Train on the GPU when one is present; inference is pinned to the CPU
# because that is where the quantized model is evaluated in this notebook.
device_train = torch.device('cuda') if _cuda_ready else torch.device('cpu')
device_inference = torch.device('cpu')

print(f"Training device: {device_train}")
print(f"Inference device: {device_inference}")

PyTorch version: 2.8.0+cu126
CUDA available: False
Training device: cpu
Inference device: cpu


In [None]:
# Steps shared by both pipelines: convert the image to a tensor, then
# standardize it with mean 0.1307 and std 0.3081.
_to_normalized_tensor = [
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
]

# Training additionally applies a small random rotation (RandomRotation(5))
# as augmentation; the test pipeline is deterministic.
transform_train = transforms.Compose([transforms.RandomRotation(5), *_to_normalized_tensor])
transform_test = transforms.Compose(list(_to_normalized_tensor))

# Both splits live under ./data and are downloaded on first use.
_mnist_kwargs = dict(root='./data', download=True)
train_dataset = torchvision.datasets.MNIST(train=True, transform=transform_train, **_mnist_kwargs)
test_dataset_clean = torchvision.datasets.MNIST(train=False, transform=transform_test, **_mnist_kwargs)


100%|██████████| 9.91M/9.91M [00:00<00:00, 18.1MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 501kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.54MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 5.86MB/s]


In [None]:
# Add noise to test data
def add_noise(dataset, noise_level: float = 0.3, seed: "int | None" = None) -> TensorDataset:
    """Build a noisy, standardized copy of an image dataset.

    The raw pixels in ``dataset.data`` are scaled to [0, 1], perturbed with
    zero-mean Gaussian noise, clamped back to [0, 1] and finally standardized
    with mean 0.1307 and std 0.3081. ``dataset.transform`` is not applied:
    the function reads ``dataset.data`` directly.

    Args:
        dataset: Object exposing ``data`` (pixel tensor with values in 0..255)
            and ``targets`` (label tensor).
        noise_level: Standard deviation of the Gaussian noise, expressed on
            the [0, 1] pixel scale.
        seed: Optional seed for a private random generator. When given, the
            same noise is produced on every call and the global RNG state is
            left untouched. When ``None`` the global RNG is used, as before.

    Returns:
        A ``TensorDataset`` of ``(images, targets)`` where ``images`` has a
        channel axis inserted at dim 1 (for ``(N, H, W)`` data this gives
        ``(N, 1, H, W)``).
    """
    raw = dataset.data.float() / 255.0

    # A dedicated generator keeps the noisy set reproducible without
    # reseeding (and thereby disturbing) the global RNG used for training.
    generator = None
    if seed is not None:
        generator = torch.Generator(device=raw.device).manual_seed(seed)

    noise = torch.randn(
        raw.shape, generator=generator, dtype=raw.dtype, device=raw.device
    ) * noise_level
    noisy = torch.clamp(raw + noise, 0., 1.)
    # Same standardization constants as the Normalize step used for clean data.
    noisy = (noisy - 0.1307) / 0.3081
    return TensorDataset(noisy.unsqueeze(1), dataset.targets)

# Noisy counterpart of the clean test split (Gaussian noise, std 0.3 on the [0, 1] pixel scale).
test_dataset_noisy = add_noise(test_dataset_clean, noise_level=0.3)


In [None]:
# Data loaders: shuffled mini-batches for training, fixed-order batches of
# 1000 samples for both evaluation sets.
batch_size = 128
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader_clean, test_loader_noisy = (
    DataLoader(dataset=eval_set, batch_size=1000, shuffle=False)
    for eval_set in (test_dataset_clean, test_dataset_noisy)
)

In [None]:
class QAT_CNN(nn.Module):
    """Small CNN for 1-channel 28x28 images, bracketed by quantization stubs.

    Two 3x3 convolutions (padding 1), each followed by ReLU and a 2x2
    max-pool, reduce a 28x28 input to 32 feature maps of 7x7. These are
    flattened and passed through two linear layers that produce 10 logits.
    """

    def __init__(self):
        super().__init__()
        # QuantStub / DeQuantStub mark where tensors enter and leave the
        # quantized part of the network.
        self.quant = QuantStub()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(in_features=32 * 7 * 7, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)
        self.dequant = DeQuantStub()

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        features = self.quant(x)
        features = F.max_pool2d(F.relu(self.conv1(features)), 2)
        features = F.max_pool2d(F.relu(self.conv2(features)), 2)

        # reshape (rather than view) is used to flatten each sample to
        # 32 * 7 * 7 features.
        flat = features.reshape(-1, 32 * 7 * 7)

        hidden = F.relu(self.fc1(flat))
        logits = self.fc2(hidden)
        return self.dequant(logits)


In [None]:
def train(model, dataloader, criterion, optimizer, device):
    """Run one optimization pass over ``dataloader``.

    Puts ``model`` in training mode and, for every batch, moves it to
    ``device``, computes ``criterion(model(inputs), targets)``, back-propagates
    and applies one optimizer step. Returns nothing.
    """
    model.train()
    for batch in dataloader:
        inputs = batch[0].to(device)
        targets = batch[1].to(device)

        batch_loss = criterion(model(inputs), targets)

        # Clear gradients left over from the previous step before
        # accumulating the new ones.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()


In [None]:
def evaluate(model, dataloader, device):
    """Measure top-1 accuracy and elapsed time of ``model`` on ``dataloader``.

    Args:
        model: Module mapping a batch of inputs to per-class scores.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(accuracy_percent, elapsed_seconds)``. The elapsed time spans the
        whole loop, so it includes fetching batches from ``dataloader`` as
        well as the forward passes. An empty ``dataloader`` yields an
        accuracy of ``0.0`` instead of raising ``ZeroDivisionError``.
    """
    model.eval()
    correct = 0
    total = 0
    # perf_counter is monotonic and high-resolution, unlike wall-clock
    # time.time(), which can jump if the system clock is adjusted.
    start_time = time.perf_counter()
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    elapsed = time.perf_counter() - start_time

    accuracy = 100 * correct / total if total else 0.0
    return accuracy, elapsed


In [None]:
import torch.ao.quantization.quantize_fx as quant_fx
from torch.ao.quantization import QConfigMapping, get_default_qat_qconfig

NUM_QAT_EPOCHS = 3

# 1. Build the float model and the QAT configuration.
#    FX graph mode takes the qconfig through a QConfigMapping; set_global
#    applies the fbgemm QAT qconfig to every supported op in the model.
model_fp = QAT_CNN().to(device_train)
qconfig_mapping = QConfigMapping().set_global(get_default_qat_qconfig("fbgemm"))
# example_inputs must be a tuple of the positional arguments of forward().
example_inputs = (torch.randn(1, 1, 28, 28, device=device_train),)

# 2. Prepare for QAT.
#    prepare_qat_fx (not prepare_fx, which is the post-training path) swaps
#    layers for their QAT counterparts so weights are fake-quantized during
#    training. It expects the model to be in training mode.
model_prepared = quant_fx.prepare_qat_fx(model_fp.train(), qconfig_mapping, example_inputs)

# 3. Optimizer and loss
optimizer = optim.Adam(model_prepared.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# 4. Train (QAT)
for epoch in range(NUM_QAT_EPOCHS):
    train(model_prepared, train_loader, criterion, optimizer, device_train)
    print(f"Epoch {epoch+1}/{NUM_QAT_EPOCHS} completed")

# 5. Convert to a quantized model. The prepared model is switched to eval
#    mode and moved to the CPU first, where the quantized model is run.
model_quantized = quant_fx.convert_fx(model_prepared.eval().to(device_inference))


For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  model_prepared = quant_fx.prepare_fx(model_fp, qconfig_dict, example_inputs)
  prepared = prepare(


Epoch 1/3 completed
Epoch 2/3 completed
Epoch 3/3 completed


For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  model_quantized = quant_fx.convert_fx(model_prepared.eval().to(device_inference))


In [None]:
# Score the quantized model on both test sets (on the inference device).
# The reported times cover the whole loader loop: the clean set applies its
# transforms per sample while iterating, whereas the noisy set is already a
# tensor, so the two timings are not a like-for-like model comparison.
acc_clean, t_clean = evaluate(model_quantized, test_loader_clean, device_inference)
acc_noisy, t_noisy = evaluate(model_quantized, test_loader_noisy, device_inference)

print(f"\n✅ Clean MNIST: Accuracy = {acc_clean:.2f}%, Inference Time = {t_clean:.2f}s")
print(f"✅ Noisy MNIST: Accuracy = {acc_noisy:.2f}%, Inference Time = {t_noisy:.2f}s")



✅ Clean MNIST: Accuracy = 98.97%, Inference Time = 4.20s
✅ Noisy MNIST: Accuracy = 94.45%, Inference Time = 2.64s


In [None]:
# Persist the quantized weights and report the checkpoint's on-disk size.
# The path is defined once so the save and the size lookup cannot drift apart.
model_path = Path("qat_model.pth")
torch.save(model_quantized.state_dict(), model_path)
size_MB = model_path.stat().st_size / (1024 ** 2)
print(f"📦 Quantized Model Size: {size_MB:.2f} MB")


📦 Quantized Model Size: 0.21 MB
