# 🔧 02 - Quantize MobileNetV2 with PyTorch

This notebook applies **post-training static quantization** to a MobileNetV2 model trained on CIFAR-10.  
It compares accuracy and model size before and after quantization.


In [7]:
import torch
import torch.nn as nn
import torch.quantization
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import os


In [8]:
# Everything in this notebook runs on the CPU: the quantized model produced
# below is meant for CPU inference, so the FP32 baseline is loaded onto the
# same device to keep the comparison like-for-like.
device = torch.device("cpu")
print(f"Using device: {device}")


Using device: cpu


In [9]:
# Preprocessing applied to every CIFAR-10 test image: upscale to 224 px,
# convert to a tensor, then normalise each channel with mean 0.5 / std 0.5.
# NOTE(review): these statistics presumably match the ones used when the
# baseline checkpoint was trained -- confirm against the training notebook.
transform = transforms.Compose([
    transforms.Resize(size=224),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# The dataset must already exist under ../data (download is disabled).
testset = torchvision.datasets.CIFAR10(
    root='../data',
    train=False,
    download=False,
    transform=transform,
)
# Fixed order (no shuffling) so repeated evaluations see identical batches.
testloader = torch.utils.data.DataLoader(dataset=testset, batch_size=32, shuffle=False)


In [10]:
# Rebuild the MobileNetV2 architecture with a 10-way classifier head, then
# restore the baseline weights from disk and switch to inference mode.
model_fp32 = models.mobilenet_v2(pretrained=False)
model_fp32.classifier[1] = nn.Linear(in_features=model_fp32.last_channel, out_features=10)

baseline_state = torch.load("../models/mobilenetv2_cifar10_baseline.pth", map_location=device)
model_fp32.load_state_dict(baseline_state)
model_fp32.eval()
print(" Loaded baseline FP32 model.")


 Loaded baseline FP32 model.


In [11]:
def evaluate(model, dataloader):
    """Return the top-1 accuracy of ``model`` over ``dataloader`` as a percentage.

    The model is put into eval mode and run without gradient tracking.
    ``dataloader`` must yield ``(images, labels)`` pairs; the prediction for
    each sample is the index of the largest model output along dimension 1.
    """
    model.eval()
    num_correct = 0
    num_seen = 0
    with torch.no_grad():
        for batch_images, batch_labels in dataloader:
            logits = model(batch_images)
            predictions = logits.argmax(dim=1)
            num_seen += batch_labels.size(0)
            num_correct += predictions.eq(batch_labels).sum().item()
    return 100 * num_correct / num_seen


In [12]:
# Reference accuracy of the unquantized model over the full test loader.
acc_fp32 = evaluate(model=model_fp32, dataloader=testloader)
print(f" Baseline FP32 Accuracy: {acc_fp32:.2f}%")


KeyboardInterrupt: 

In [None]:
# Post-training static quantization (eager mode).
#
# The plain `models.mobilenet_v2` cannot be converted as-is: it has no
# QuantStub/DeQuantStub around its forward pass and its residual connection
# uses a raw tensor `+`, so after `convert` the quantized conv layers would be
# fed float tensors and the forward pass would raise. torchvision ships a
# quantization-ready variant that adds the stubs, a FloatFunctional for the
# skip connection, and a `fuse_model()` helper, so that one is used here.
# NOTE(review): the quantizable variant swaps ReLU6 for ReLU so the
# Conv+BN+ReLU triples can be fused; the baseline was presumably trained with
# ReLU6, so compare acc_int8 against acc_fp32 to confirm the gap is acceptable.
model_to_quantize = models.quantization.mobilenet_v2(quantize=False)
model_to_quantize.classifier[1] = nn.Linear(model_to_quantize.last_channel, 10)
# load_state_dict is strict by default, so any key mismatch between the float
# checkpoint and the quantizable architecture fails loudly here.
model_to_quantize.load_state_dict(torch.load("../models/mobilenetv2_cifar10_baseline.pth", map_location=device))
model_to_quantize.eval()

# Fuse Conv+BN(+ReLU) groups before inserting observers; fusion must happen in
# eval mode and before `prepare`.
model_to_quantize.fuse_model()

model_to_quantize.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(model_to_quantize, inplace=True)

# Calibration: run exactly NUM_CALIBRATION_BATCHES batches through the
# observers. The bound is checked before the forward pass so the batch count
# is not off by one, and gradients are disabled because only the activation
# statistics are needed.
# NOTE(review): these are test images that are also used for evaluation;
# a held-out slice of the training data would be a cleaner calibration set.
NUM_CALIBRATION_BATCHES = 5
with torch.no_grad():
    for i, (images, _) in enumerate(testloader):
        if i >= NUM_CALIBRATION_BATCHES:
            break
        model_to_quantize(images)

torch.quantization.convert(model_to_quantize, inplace=True)
print(" Model quantized.")


In [None]:
# Accuracy of the converted INT8 model on the same test loader as the baseline.
acc_int8 = evaluate(model=model_to_quantize, dataloader=testloader)
print(f" Quantized INT8 Accuracy: {acc_int8:.2f}%")


In [None]:
# Write the converted model's state dict next to the baseline checkpoint.
# NOTE: only the state dict is stored, so loading it later needs a model that
# already has the same converted module structure.
os.makedirs("../models", exist_ok=True)
quantized_state = model_to_quantize.state_dict()
torch.save(quantized_state, "../models/mobilenetv2_quantized.pth")
print(" Saved to '../models/mobilenetv2_quantized.pth'")


In [None]:
# Compare on-disk checkpoint sizes (bytes -> MB using 1e6 bytes per MB).
baseline_bytes = os.path.getsize("../models/mobilenetv2_cifar10_baseline.pth")
quantized_bytes = os.path.getsize("../models/mobilenetv2_quantized.pth")
size_fp32 = baseline_bytes / 1e6
size_int8 = quantized_bytes / 1e6

print(f" FP32 Model Size: {size_fp32:.2f} MB")
print(f" INT8 Model Size: {size_int8:.2f} MB")
print(f" Size Reduction: {(size_fp32 - size_int8) / size_fp32 * 100:.2f}%")
