# 🔢 Lecture 5: Quantization - Complete Demo

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/gaurav-redhat/efficientml_course/blob/main/05_quantization_1/demo.ipynb)

## What You'll Learn
- How quantization reduces model size (FP32 → INT8 = 4x smaller)
- Implementing quantization from scratch
- Understanding scale and zero-point
- PyTorch dynamic quantization
- Measuring accuracy vs compression trade-off


In [None]:
# Setup
!pip install torch torchvision matplotlib -q

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# All tensors and models in this notebook are kept on the CPU device.
device = torch.device('cpu')  # Quantization works on CPU
print(f"üñ•Ô∏è Using device: {device}")
# Seed PyTorch's global RNG so the random tensors and weight initialisation
# in the following cells are repeatable from run to run.
torch.manual_seed(42)


## Part 1: Understanding Quantization

**Quantization** = Converting floating-point numbers to lower-precision integers

| Data Type | Bits | Memory per value | Range |
|-----------|------|------------------|-------|
| FP32 | 32 | 4 bytes | ±3.4e38 |
| FP16 | 16 | 2 bytes | ±65504 |
| INT8 | 8 | 1 byte | -128 to 127 |
| INT4 | 4 | 0.5 bytes | -8 to 7 |


In [None]:
# Implement quantization from scratch
def quantize_tensor(x, num_bits=8):
    """
    Quantize a floating-point tensor to unsigned integers (min/max, asymmetric).

    The observed range [x.min(), x.max()] is mapped linearly onto the integer
    range [0, 2**num_bits - 1].

    Quantization formula:
        q = clamp(round(x / scale + zero_point), qmin, qmax)

    Dequantization formula:
        x_approx = (q - zero_point) * scale

    Args:
        x: Non-empty floating-point tensor to quantize.
        num_bits: Number of bits per quantized value (default 8).

    Returns:
        A tuple ``(q, scale, zero_point, x_dequant)``:
            q: Quantized values; ``torch.uint8`` when ``num_bits <= 8``,
               ``torch.int64`` for wider settings so the codes do not wrap.
            scale: Python float, the step between adjacent quantization levels.
            zero_point: Python float offset (it is NOT rounded to an integer).
            x_dequant: Float tensor reconstructed from ``q``, same shape as ``x``.
    """
    qmin = 0
    qmax = 2**num_bits - 1

    # Calculate scale and zero point from the observed value range
    x_min, x_max = x.min().item(), x.max().item()
    value_range = x_max - x_min
    # A constant tensor has zero range; the naive scale would be 0 and the
    # zero-point computation below would divide by zero. Any non-zero scale
    # reconstructs a constant tensor exactly, so fall back to 1.0.
    scale = value_range / (qmax - qmin) if value_range != 0 else 1.0
    zero_point = qmin - x_min / scale

    # Quantize: shift/scale into the integer grid, round, and clip to range
    q = torch.clamp(torch.round(x / scale + zero_point), qmin, qmax)

    # Dequantize for comparison (done before the integer cast, on float codes)
    x_dequant = (q - zero_point) * scale

    # uint8 can only hold codes up to 255; use a wider integer type when
    # num_bits > 8 so large codes are stored faithfully.
    storage_dtype = torch.uint8 if num_bits <= 8 else torch.int64
    return q.to(storage_dtype), scale, zero_point, x_dequant

# Quantize a small random matrix, then report the reconstruction error and
# the storage saved by going from 4-byte floats to 1-byte integers.
print("üìä Quantization Demo")
print("=" * 50)

x = 2 * torch.randn(3, 3)  # standard-normal samples scaled by 2
print(f"Original FP32 tensor:\n{x}\n")

q, scale, zp, x_recon = quantize_tensor(x, num_bits=8)
print(f"Quantized INT8 tensor:\n{q}\n")
print(f"Scale: {scale:.6f}")
print(f"Zero point: {zp:.1f}")

# Element-wise absolute gap between the input and its dequantized copy
error = torch.abs(x - x_recon)
print("\nReconstruction error:")
print(f"  Mean absolute error: {error.mean():.6f}")
print(f"  Max error: {error.max():.6f}")

# Storage cost: 4 bytes per FP32 element versus 1 byte per quantized element
fp32_bytes = 4 * x.numel()
int8_bytes = q.numel()
print("\nüíæ Memory savings:")
print(f"  FP32: {fp32_bytes} bytes")
print(f"  INT8: {int8_bytes} bytes")
print(f"  Compression: {fp32_bytes / int8_bytes:.0f}x")


In [None]:
# Plot a linear ramp against its quantized reconstruction for 8, 4 and 2 bits
# (top row) together with a histogram of the resulting error (bottom row).
fig, axes = plt.subplots(2, 3, figsize=(14, 8))

# Evenly spaced ramp over [-2, 2] standing in for a continuous signal
x = torch.linspace(-2, 2, 1000)
x_np = x.numpy()

for i, bits in enumerate([8, 4, 2]):
    # Column i of the grid: transfer curve on top, error histogram below
    ax, ax2 = axes[:, i]

    q, scale, zp, x_recon = quantize_tensor(x, num_bits=bits)
    recon_np = x_recon.numpy()

    # Identity line versus the staircase produced by quantization
    ax.plot(x_np, x_np, 'b-', label='Original', alpha=0.5, linewidth=2)
    ax.plot(x_np, recon_np, 'r-', label='Quantized', linewidth=1)
    ax.set_title(f'{bits}-bit Quantization\n({2**bits} levels)')
    ax.set_xlabel('Original Value')
    ax.set_ylabel('Quantized Value')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Signed reconstruction error and its mean magnitude
    error = x_np - recon_np
    mean_abs_error = np.abs(error).mean()
    ax2.hist(error, bins=50, color='orange', alpha=0.7, edgecolor='black')
    ax2.set_title(f'{bits}-bit Error Distribution\nMean: {mean_abs_error:.4f}')
    ax2.set_xlabel('Quantization Error')
    ax2.set_ylabel('Count')

plt.suptitle('üîç Quantization at Different Bit Widths', fontsize=14)
plt.tight_layout()
plt.show()


## Part 2: Quantizing a Real Neural Network

Let's quantize an actual model and measure the impact on accuracy and size.


In [None]:
# Create and train a simple model
class SimpleMLP(nn.Module):
    """Fully connected classifier with layer widths 784 -> 256 -> 128 -> 10."""

    def __init__(self):
        super().__init__()
        # Layers are created in fc1, fc2, fc3 order; their attribute names
        # are the keys used in the model's state dict.
        self.fc1 = nn.Linear(in_features=784, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=128)
        self.fc3 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Flatten ``x`` to rows of 784 values and return the 10 output logits per row."""
        flat = x.view(-1, 784)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        # The last layer is left without an activation
        return self.fc3(hidden)

# MNIST input pipeline: convert each image to a tensor, then normalize it
# with the fixed mean/std constants below.
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize(mean=(0.1307,), std=(0.3081,))
transform = transforms.Compose([to_tensor, normalize])

# Only the training split asks for a download; both splits live in ./data
train_data = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_data = datasets.MNIST(root='./data', train=False, transform=transform)

# Shuffled batches of 128 for training, fixed-order batches of 1000 for testing
train_loader = DataLoader(dataset=train_data, batch_size=128, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=1000)

# Fresh, untrained network that the next steps will fit
model = SimpleMLP()

def train_model(model, epochs=3, loader=None, lr=0.001):
    """Train ``model`` in place with Adam and cross-entropy loss.

    Args:
        model: Module mapping a batch of inputs to class logits.
        epochs: Number of full passes over ``loader``.
        loader: Iterable of ``(data, target)`` batches. When omitted, the
            notebook-level ``train_loader`` is used, matching the original
            behavior.
        lr: Learning rate passed to ``torch.optim.Adam``.

    Returns:
        The same ``model`` object, left in training mode if at least one
        epoch ran.
    """
    if loader is None:
        # Fall back to the module-level loader so existing calls keep working
        loader = train_loader
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for _ in range(epochs):
        # Re-enter training mode every epoch in case the caller switched
        # the model to eval mode in between.
        model.train()
        for data, target in loader:
            optimizer.zero_grad()
            loss = F.cross_entropy(model(data), target)
            loss.backward()
            optimizer.step()
    return model

def evaluate(model, loader):
    """Return the top-1 accuracy of ``model`` over ``loader`` as a percentage.

    Args:
        model: Module mapping a batch of inputs to per-class scores.
        loader: Iterable of ``(data, target)`` batches.

    Returns:
        Accuracy in the range [0, 100]. An empty ``loader`` yields ``0.0``
        instead of raising ``ZeroDivisionError``.

    Note:
        The model is switched to eval mode and is left in that mode.
    """
    model.eval()
    correct = total = 0
    with torch.no_grad():  # inference only, no gradients needed
        for data, target in loader:
            pred = model(data).argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.size(0)
    if total == 0:
        # Nothing was evaluated; avoid dividing by zero below
        return 0.0
    return 100. * correct / total

print("üèãÔ∏è Training model...")
model = train_model(model, epochs=3)
original_acc = evaluate(model, test_loader)
print(f"‚úÖ Original FP32 accuracy: {original_acc:.2f}%")


In [None]:
# Get model size
def get_model_size_mb(model):
    """Return the size of ``model``'s serialized state dict in MiB.

    The state dict is written with ``torch.save`` into an in-memory buffer
    and the number of bytes produced is reported (bytes / 1024 / 1024).

    Why not sum ``parameters()`` and ``buffers()``: dynamically quantized
    ``nn.Linear`` replacements keep their weights in packed-parameter objects
    that are not registered as parameters or buffers, so that sum is 0 for a
    quantized model and any compression ratio computed from it divides by
    zero. Serializing the state dict covers both regular and quantized
    modules; for FP32 models the value is the raw tensor size plus a small
    serialization overhead.

    Args:
        model: Any ``nn.Module``.

    Returns:
        Serialized size as a float, in MiB.
    """
    import io  # local import so this cell does not rely on the setup cell

    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    return buffer.getbuffer().nbytes / 1024 / 1024

# Post-training dynamic quantization: quantize the nn.Linear layers to qint8
print("üî¢ Applying Dynamic Quantization...")
quantized_model = torch.quantization.quantize_dynamic(
    model, qconfig_spec={nn.Linear}, dtype=torch.qint8
)

# Accuracy of the quantized copy on the same test set as the baseline
quantized_acc = evaluate(quantized_model, test_loader)

# Footprint of both variants, FP32 first
original_size, quantized_size = (
    get_model_size_mb(variant) for variant in (model, quantized_model)
)

rule = "=" * 50
print("\n" + rule)
print("üìä QUANTIZATION RESULTS")
print(rule)

print("\nüìê Model Size:")
print(f"   Original (FP32):  {original_size:.2f} MB")
print(f"   Quantized (INT8): {quantized_size:.2f} MB")
print(f"   Compression:      {original_size/quantized_size:.1f}x smaller")

print("\nüéØ Accuracy:")
print(f"   Original:  {original_acc:.2f}%")
print(f"   Quantized: {quantized_acc:.2f}%")
print(f"   Drop:      {original_acc - quantized_acc:.2f}%")


In [None]:
# Side-by-side bar charts: model size on the left, test accuracy on the right
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
size_ax, acc_ax = axes

labels = ['FP32\n(Original)', 'INT8\n(Quantized)']
colors = ['#3b82f6', '#22c55e']

# Left panel: size in MB, with the value written just above each bar
sizes = [original_size, quantized_size]
size_ax.bar(labels, sizes, color=colors, edgecolor='black')
size_ax.set_ylabel('Size (MB)')
size_ax.set_title('üìê Model Size Comparison')
for i, s in enumerate(sizes):
    size_ax.text(i, s + 0.01, f'{s:.2f} MB', ha='center', fontsize=11)

# Right panel: accuracy, with the y-axis limited to 90-100
accs = [original_acc, quantized_acc]
acc_ax.bar(labels, accs, color=colors, edgecolor='black')
acc_ax.set_ylabel('Accuracy (%)')
acc_ax.set_title('üéØ Accuracy Comparison')
acc_ax.set_ylim([90, 100])
for i, a in enumerate(accs):
    acc_ax.text(i, a + 0.2, f'{a:.1f}%', ha='center', fontsize=11)

plt.tight_layout()
plt.show()

# Closing summary: a banner followed by one printed line per takeaway
banner = "=" * 50
print("\n" + banner)
print("üéØ KEY TAKEAWAYS")
print(banner)
for takeaway in (
    "‚Ä¢ Quantization: FP32 ‚Üí INT8 = 4x smaller model",
    "‚Ä¢ Post-Training Quantization (PTQ) is easy: one function call",
    "‚Ä¢ Minimal accuracy loss for most models (<1%)",
    "‚Ä¢ INT8 inference is 2-4x faster on supported hardware",
    "‚Ä¢ For even smaller: INT4 quantization (used in LLMs)",
):
    print(takeaway)
