# 🎯 Lecture 6: Quantization-Aware Training (QAT) - Complete Demo

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/gaurav-redhat/efficientml_course/blob/main/06_quantization_2/demo.ipynb)

## What You'll Learn
- Why QAT beats Post-Training Quantization
- Straight-Through Estimator (STE) for gradients
- Fake quantization during training
- Mixed-precision quantization

In [None]:
!pip install torch matplotlib numpy -q
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

# Seed PyTorch's global RNG so the random tensors and weight initialisations
# created later in the notebook are repeatable from run to run.
torch.manual_seed(42)
print('Ready for QAT!')

## Part 1: The Problem with PTQ

Post-Training Quantization has limitations, especially at low bit-widths.

In [None]:
# Demonstrate PTQ accuracy degradation
class SimpleNet(nn.Module):
    """Three-layer fully connected network: 784 -> 256 -> 64 -> 10.

    ReLU is applied after the first two linear layers; the output of the
    last layer is returned without any activation.
    """

    def __init__(self):
        super().__init__()
        # The attribute names double as state_dict key prefixes.
        self.fc1 = nn.Linear(in_features=784, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=64)
        self.fc3 = nn.Linear(in_features=64, out_features=10)

    def forward(self, x):
        """Run the three layers in sequence and return the final layer's output."""
        hidden = self.fc1(x).relu()
        hidden = self.fc2(hidden).relu()
        logits = self.fc3(hidden)
        return logits

def quantize_weight(w, num_bits):
    """Simple symmetric per-tensor quantization (quantize, then dequantize).

    The scale is chosen so that the largest absolute value in ``w`` maps to
    ``qmax = 2**(num_bits - 1) - 1``. Values are rounded to the integer grid,
    clamped to ``[-qmax - 1, qmax]`` and multiplied back by the scale, so the
    result has the same shape and dtype as ``w`` but only takes values on the
    quantization grid.

    Args:
        w: Tensor to quantize.
        num_bits: Bit width of the signed integer grid; must be at least 2.

    Returns:
        A new tensor holding the dequantized values. Empty and all-zero
        inputs are returned unchanged in value (as a fresh tensor).

    Raises:
        ValueError: If ``num_bits`` is smaller than 2 (``qmax`` would be 0 and
            the scale undefined).
    """
    if num_bits < 2:
        raise ValueError(f'num_bits must be >= 2, got {num_bits}')

    qmax = 2 ** (num_bits - 1) - 1

    # max() is undefined on an empty tensor; there is nothing to quantize.
    if w.numel() == 0:
        return w.clone()

    max_abs = w.abs().max()
    # An all-zero tensor would give scale == 0 and 0/0 = NaN below.
    if max_abs == 0:
        return torch.zeros_like(w)

    scale = max_abs / qmax
    q = torch.round(w / scale).clamp(-qmax - 1, qmax)
    return q * scale  # Dequantized

def apply_ptq(model, num_bits):
    """Return a copy of ``model`` whose weight tensors are fake-quantized.

    The input model is left untouched. Every parameter whose name contains
    ``'weight'`` is replaced, in the copy, by ``quantize_weight(param, num_bits)``;
    all other parameters (e.g. biases) keep their original values.

    Args:
        model: The ``nn.Module`` to quantize.
        num_bits: Bit width forwarded to ``quantize_weight``.

    Returns:
        A deep copy of ``model`` with quantized weights.
    """
    # Local import so this block does not depend on the file's import header.
    import copy

    # deepcopy (rather than type(model)() + load_state_dict) also works for
    # models whose constructor takes arguments, and keeps those settings.
    model_ptq = copy.deepcopy(model)

    with torch.no_grad():
        for name, param in model_ptq.named_parameters():
            if 'weight' in name:
                param.copy_(quantize_weight(param.detach(), num_bits))

    return model_ptq

# Create and "train" model
# The inputs are random noise and the labels are random class ids in [0, 10),
# so there is no real signal to learn; the model can only fit this fixed set.
model = SimpleNet()
X_test = torch.randn(1000, 784)
y_test = torch.randint(0, 10, (1000,))

# Quick training
# 20 full-batch Adam steps on the very same tensors that are later used for
# evaluation (X_test / y_test).
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
for _ in range(20):
    optimizer.zero_grad()
    loss = criterion(model(X_test), y_test)
    loss.backward()
    optimizer.step()

def evaluate(model, X, y):
    """Return the top-1 accuracy of ``model`` on ``(X, y)`` as a percentage.

    The model is switched to eval mode (and left in it) and the forward pass
    runs without gradient tracking. The prediction for each row is the argmax
    over dimension 1 of the model output.
    """
    model.eval()
    with torch.no_grad():
        predictions = model(X).argmax(1)
        hit_rate = (predictions == y).float().mean()
        return hit_rate.item() * 100

# Test PTQ at different bit widths
print('üìä PTQ ACCURACY DEGRADATION')
print('=' * 50)
# Reference accuracy of the unquantized model, measured on the same data it
# was fitted on above.
baseline = evaluate(model, X_test, y_test)
print(f'FP32 Baseline: {baseline:.1f}%')

# Collect (bit width, accuracy) pairs, from the widest to the narrowest grid.
ptq_results = []
for bits in [8, 6, 4, 3, 2]:
    # apply_ptq returns a quantized copy; `model` itself is not modified.
    model_ptq = apply_ptq(model, bits)
    acc = evaluate(model_ptq, X_test, y_test)
    ptq_results.append((bits, acc))
    print(f'{bits}-bit PTQ: {acc:.1f}% (drop: {baseline-acc:.1f}%)')

print('\n‚ö†Ô∏è Notice: Accuracy drops significantly at low bit widths!')

## Part 2: Straight-Through Estimator (STE)

The key to QAT: Approximate gradients through quantization.

**Problem**: `round()` has zero gradient almost everywhere

**Solution**: STE - In the forward pass, quantize. In the backward pass, pass the gradient through unchanged, as if quantization were the identity function.

In [None]:
class StraightThroughEstimator(torch.autograd.Function):
    """
    Straight-Through Estimator for symmetric fake quantization.

    Forward: scale ``x`` so its largest magnitude maps to ``qmax``, round to
    the integer grid, clamp to ``[-qmax - 1, qmax]`` and scale back.
    Backward: pass the incoming gradient through unchanged (identity), except
    at positions where the forward clamp actually saturated, which get zero.

    Use via ``StraightThroughEstimator.apply(x, num_bits)``.
    """
    @staticmethod
    def forward(ctx, x, num_bits):
        """Fake-quantize ``x`` to ``num_bits`` and remember which entries were in range."""
        qmax = 2 ** (num_bits - 1) - 1
        qmin = -qmax - 1
        # Small epsilon keeps the scale non-zero for an all-zero input.
        scale = x.abs().max() / qmax + 1e-8

        # Quantize
        q = torch.round(x / scale)

        # Decide the gradient mask here, from the very integers that get
        # clamped, instead of re-deriving it in backward with a second
        # abs-max reduction. That keeps the mask consistent with the clamp
        # bounds and immune to float rounding of |x| / scale around qmax.
        in_range = (q >= qmin) & (q <= qmax)
        ctx.save_for_backward(in_range)

        return q.clamp(qmin, qmax) * scale

    @staticmethod
    def backward(ctx, grad_output):
        """Return the STE gradient w.r.t. ``x``; ``num_bits`` gets no gradient."""
        (in_range,) = ctx.saved_tensors

        # STE: pass gradient through, zeroed where the forward pass clipped.
        return grad_output * in_range.to(grad_output.dtype), None

# Functional alias: fake_quantize(x, num_bits) runs the custom autograd Function.
fake_quantize = StraightThroughEstimator.apply

# Demonstrate STE
# Leaf tensor with requires_grad=True so that x.grad is populated by backward().
x = torch.randn(100, requires_grad=True)

# Forward: quantized
# 4-bit signed grid, i.e. at most 2**4 = 16 distinct output levels.
x_q = fake_quantize(x, 4)

# Backward: gradient flows!
# d(sum)/d(x_q) is 1 everywhere, so x.grad shows exactly what the STE
# backward lets through for each element.
loss = x_q.sum()
loss.backward()

print('üìä STRAIGHT-THROUGH ESTIMATOR DEMO')
print('=' * 50)
print(f'Input range: [{x.min():.3f}, {x.max():.3f}]')
print(f'Quantized output has {len(torch.unique(x_q))} unique values')
print(f'Gradient is non-zero: {(x.grad != 0).sum().item()} / {x.numel()}')

# Visualize: three side-by-side panels (forward mapping, backward gradient,
# and a schematic comparison of round() vs STE gradients).
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
ax_forward, ax_backward, ax_compare = axes

# NumPy view of the inputs, shared by the first two panels.
inputs_np = x.detach().numpy()

# Panel 1 - original value vs. its fake-quantized value, with y=x for reference
ax_forward.scatter(inputs_np, x_q.detach().numpy(), alpha=0.5)
ax_forward.plot([-3, 3], [-3, 3], 'r--', label='y=x')
ax_forward.set(xlabel='Original', ylabel='Quantized', title='Forward: Quantization')
ax_forward.legend()

# Panel 2 - gradient that reached each input element
ax_backward.scatter(inputs_np, x.grad.numpy(), alpha=0.5, c='green')
ax_backward.axhline(y=1, color='r', linestyle='--', label='STE: grad=1')
ax_backward.set(xlabel='Input Value', ylabel='Gradient', title='Backward: STE Gradient')
ax_backward.legend()

# Panel 3 - schematic: the true gradient of round() is 0, STE uses 1
ax_compare.bar(['Round (true)', 'STE'], [0, 1], color=['red', 'green'])
ax_compare.set(ylabel='Gradient Flow', title='Why STE Works')

plt.tight_layout()
plt.show()

## Part 3: Implementing QAT from Scratch

In [None]:
class QATLinear(nn.Module):
    """
    Linear layer for Quantization-Aware Training.

    Full-precision weight and bias are kept as learnable parameters. On every
    forward call the weight is fake-quantized to ``weight_bits`` and the input
    to ``act_bits`` before the linear transform; the bias is used as-is.
    """

    def __init__(self, in_features, out_features, weight_bits=8, act_bits=8):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        self.weight_bits, self.act_bits = weight_bits, act_bits

        # Latent full-precision parameters: small Gaussian weights, zero bias.
        initial_weight = torch.randn(out_features, in_features) * 0.01
        self.weight = nn.Parameter(initial_weight)
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x):
        """Apply the linear transform on fake-quantized input and weight."""
        quantized_input = fake_quantize(x, self.act_bits)
        quantized_weight = fake_quantize(self.weight, self.weight_bits)
        return F.linear(quantized_input, quantized_weight, self.bias)

class QATNet(nn.Module):
    """784 -> 256 -> 64 -> 10 network built from QATLinear layers.

    All three layers share the same ``weight_bits`` / ``act_bits`` setting.
    """

    def __init__(self, weight_bits=8, act_bits=8):
        super().__init__()
        precision = {'weight_bits': weight_bits, 'act_bits': act_bits}
        self.fc1 = QATLinear(784, 256, **precision)
        self.fc2 = QATLinear(256, 64, **precision)
        self.fc3 = QATLinear(64, 10, **precision)

    def forward(self, x):
        """ReLU after the first two layers; the last layer's output is returned raw."""
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

# Informational output only: summarises what the QAT layers defined above do.
print('‚úÖ QAT Network defined!')
print('\nüìä QAT Layer Details:')
print('- Forward: Apply fake quantization to weights and activations')
print('- Backward: Use STE to pass gradients through quantization')
print('- Result: Network learns to be robust to quantization noise')

In [None]:
# Train QAT model vs PTQ model
def train_model(model, X, y, epochs=50, lr=0.001):
    """Fit ``model`` on ``(X, y)`` with full-batch Adam and cross-entropy loss.

    Each epoch is a single optimizer step on the whole of ``X``. The model is
    put into train mode at the start of every epoch and updated in place.

    Args:
        model: Module to train.
        X: Input tensor passed to the model as one batch.
        y: Target class indices for ``nn.CrossEntropyLoss``.
        epochs: Number of optimizer steps.
        lr: Adam learning rate.

    Returns:
        List with the loss value (Python float) recorded at each epoch.
    """
    adam = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    history = []

    for _ in range(epochs):
        model.train()
        adam.zero_grad()
        batch_loss = loss_fn(model(X), y)
        batch_loss.backward()
        adam.step()
        history.append(batch_loss.item())

    return history

# Create training data
# NOTE(review): inputs and labels are independent random draws, and
# X_test / y_test used for evaluation below are a separate random draw, so
# accuracies are expected to sit around chance level (~10% for 10 classes).
X_train = torch.randn(2000, 784)
y_train = torch.randint(0, 10, (2000,))

print('üìä QAT vs PTQ COMPARISON')
print('=' * 60)

# Parallel lists, one entry per bit width, consumed by the plotting cell.
results = {'bits': [], 'ptq': [], 'qat': []}

for bits in [8, 6, 4]:
    print(f'\n{bits}-bit Quantization:')

    # Train FP32 model then apply PTQ
    fp32_model = SimpleNet()
    train_model(fp32_model, X_train, y_train, epochs=30)
    ptq_model = apply_ptq(fp32_model, bits)
    ptq_acc = evaluate(ptq_model, X_test, y_test)
    print(f'  PTQ accuracy: {ptq_acc:.1f}%')

    # Train with QAT from scratch
    qat_model = QATNet(weight_bits=bits, act_bits=bits)
    train_model(qat_model, X_train, y_train, epochs=30)
    qat_acc = evaluate(qat_model, X_test, y_test)
    print(f'  QAT accuracy: {qat_acc:.1f}%')
    # Signed format: a hard-coded '+' would print e.g. '+-1.2%' whenever QAT
    # ends up below PTQ.
    print(f'  Improvement: {qat_acc - ptq_acc:+.1f}%')

    results['bits'].append(bits)
    results['ptq'].append(ptq_acc)
    results['qat'].append(qat_acc)

In [None]:
# Visualize results: grouped bars, one PTQ/QAT pair per bit width.
fig, ax = plt.subplots(figsize=(10, 6))

x = np.arange(len(results['bits']))
width = 0.35
half_width = width / 2

# PTQ bars sit left of each tick, QAT bars right of it.
bars1 = ax.bar(x - half_width, results['ptq'], width, label='PTQ', color='#ef4444')
bars2 = ax.bar(x + half_width, results['qat'], width, label='QAT', color='#22c55e')

tick_labels = [f'{b}-bit' for b in results['bits']]
ax.set_xticks(x)
ax.set_xticklabels(tick_labels)
ax.set_xlabel('Bit Width', fontsize=12)
ax.set_ylabel('Accuracy (%)', fontsize=12)
ax.set_title('üìä QAT vs PTQ Accuracy', fontsize=14)
ax.legend()
ax.grid(True, alpha=0.3, axis='y')

# Add value labels: centred just above every bar, PTQ bars first.
for bar in [*bars1, *bars2]:
    height = bar.get_height()
    label_x = bar.get_x() + bar.get_width()/2.
    ax.text(label_x, height + 1, f'{height:.1f}%',
            ha='center', va='bottom', fontsize=10)

plt.tight_layout()
plt.show()

print('\nüí° QAT significantly outperforms PTQ at low bit widths!')

## Part 4: Mixed-Precision Quantization

In [None]:
class MixedPrecisionNet(nn.Module):
    """
    784 -> 256 -> 64 -> 10 QAT network with per-layer bit widths.

    The first and last layers run at 8 bits, the middle layer at 4 bits
    (for both weights and activations).
    """

    def __init__(self):
        super().__init__()
        high_bits, low_bits = 8, 4
        # Outer layers: higher precision (treated as the sensitive ones).
        self.fc1 = QATLinear(784, 256, weight_bits=high_bits, act_bits=high_bits)
        # Middle layer: compressed more aggressively.
        self.fc2 = QATLinear(256, 64, weight_bits=low_bits, act_bits=low_bits)
        self.fc3 = QATLinear(64, 10, weight_bits=high_bits, act_bits=high_bits)

    def forward(self, x):
        """ReLU after the first two layers; the last layer's output is returned raw."""
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

print('üìä MIXED-PRECISION QUANTIZATION')
print('=' * 50)

# Train uniform 4-bit model
uniform_4bit = QATNet(weight_bits=4, act_bits=4)
train_model(uniform_4bit, X_train, y_train, epochs=30)
uniform_acc = evaluate(uniform_4bit, X_test, y_test)

# Train mixed precision model
mixed_model = MixedPrecisionNet()
train_model(mixed_model, X_train, y_train, epochs=30)
mixed_acc = evaluate(mixed_model, X_test, y_test)

# Calculate average bits
def calc_avg_bits(model):
    """Return the weight-count-weighted mean ``weight_bits`` over all QATLinear layers.

    Only ``QATLinear`` sub-modules are considered, and only their ``weight``
    tensors (biases are ignored). Returns 0 when the model contains no
    QATLinear weights.
    """
    qat_layers = [m for m in model.modules() if isinstance(m, QATLinear)]
    weight_count = sum(layer.weight.numel() for layer in qat_layers)
    if not weight_count > 0:
        return 0
    bit_count = sum(layer.weight.numel() * layer.weight_bits for layer in qat_layers)
    return bit_count / weight_count

# Report accuracy and average weight bit width for both models. The average
# is computed with calc_avg_bits for the uniform model too, instead of a
# hard-coded "4.0" that would go stale if uniform_4bit were reconfigured.
print('\nUniform 4-bit:')
print(f'  Accuracy: {uniform_acc:.1f}%')
print(f'  Avg bits: {calc_avg_bits(uniform_4bit):.1f}')

print('\nMixed Precision (8-4-8):')
print(f'  Accuracy: {mixed_acc:.1f}%')
print(f'  Avg bits: ~{calc_avg_bits(mixed_model):.1f}')

# NOTE(review): this message is printed unconditionally; it is not derived
# from uniform_acc / mixed_acc. Also, the 8-bit first layer holds 784*256 of
# the ~218k weights, so the mixed model averages about 7.7 bits vs 4.0 --
# confirm that "similar compression" is the intended wording.
print('\nüí° Mixed precision: Better accuracy with similar compression!')

## Part 5: Sensitivity Analysis

In [None]:
def layer_sensitivity_analysis(model, X, y, bits_range=(8, 6, 4, 2)):
    """
    Analyze each layer's sensitivity to quantization.

    For every parameter whose name contains ``'weight'``, and for every bit
    width in ``bits_range``, a copy of ``model`` is made in which only that
    one tensor is quantized with ``quantize_weight``; the copy is then scored
    with ``evaluate``. ``model`` itself keeps its weights (``evaluate`` does
    switch it to eval mode).

    Args:
        model: Trained ``nn.Module`` to analyze.
        X: Evaluation inputs.
        y: Evaluation labels.
        bits_range: Iterable of bit widths to try for each layer.

    Returns:
        ``(results, baseline)`` where ``baseline`` is the accuracy of the
        unmodified model and ``results`` maps each weight parameter name to a
        list of ``{'bits', 'accuracy', 'drop'}`` dicts, one per bit width, in
        the order given by ``bits_range``. ``drop`` is ``baseline - accuracy``.
    """
    # Local import so this block does not depend on the file's import header.
    import copy

    # Materialise once: a one-shot iterator would otherwise be exhausted
    # after the first layer.
    bits_range = tuple(bits_range)

    baseline = evaluate(model, X, y)

    results = {}

    weight_names = [name for name, _ in model.named_parameters() if 'weight' in name]

    for name in weight_names:
        layer_results = []

        for bits in bits_range:
            # Create copy and quantize only this layer. deepcopy also works
            # for models whose constructor takes arguments.
            model_copy = copy.deepcopy(model)
            target = dict(model_copy.named_parameters())[name]

            with torch.no_grad():
                target.copy_(quantize_weight(target.detach(), bits))

            acc = evaluate(model_copy, X, y)
            layer_results.append({
                'bits': bits,
                'accuracy': acc,
                'drop': baseline - acc
            })

        results[name] = layer_results

    return results, baseline

# Run sensitivity analysis
sensitivity, baseline = layer_sensitivity_analysis(model, X_test, y_test)

print('üìä LAYER SENSITIVITY ANALYSIS')
print('=' * 60)
print(f'Baseline (FP32): {baseline:.1f}%')
print(f'\n{"Layer":<20} {"8-bit":<12} {"4-bit":<12} {"2-bit":<12}')
print('-' * 60)

# The loop variable is deliberately NOT called `results`: that name is the
# module-level dict filled by the QAT-vs-PTQ comparison, and rebinding it here
# would break re-running the comparison plot afterwards.
for name, layer_results in sensitivity.items():
    short_name = name.split('.')[0]
    row = f'{short_name:<20}'
    for r in layer_results:
        # Only the bit widths that have a column in the header above.
        if r['bits'] in [8, 4, 2]:
            cell = f'{r["drop"]:+.1f}%'
            # One separating space + 12-wide cell, matching the header layout.
            row += f' {cell:<12}'
    print(row)

# Visualize: grouped bars of accuracy drop, one group per layer and one bar
# per bit width.
fig, ax = plt.subplots(figsize=(10, 6))

layers = list(sensitivity.keys())
x = np.arange(len(layers))
width = 0.25

for i, bits in enumerate([8, 4, 2]):
    drops = []
    for layer in layers:
        # First record for this bit width, if the layer has one.
        hit = next((r for r in sensitivity[layer] if r['bits'] == bits), None)
        if hit is not None:
            drops.append(hit['drop'])

    ax.bar(x + i*width, drops, width, label=f'{bits}-bit')

# Centre the tick under the middle bar of each group.
ax.set_xticks(x + width)
ax.set_xticklabels([l.split('.')[0] for l in layers])
ax.set_xlabel('Layer')
ax.set_ylabel('Accuracy Drop (%)')
ax.set_title('üìä Layer Sensitivity to Quantization')
ax.legend()
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

print('\nüí° Use this analysis to decide per-layer bit widths!')

In [None]:
print('üéØ KEY TAKEAWAYS')
print('=' * 60)
print('\n1. QAT > PTQ, especially at low bit widths (4-bit, 2-bit)')
print('\n2. STE: Forward=quantize, Backward=pass gradient through')
print('\n3. Fake quantization: Quantize‚ÜíDequantize during training')
print('\n4. Network learns to be robust to quantization noise')
print('\n5. Mixed precision: Different bits for different layers')
print('\n6. First/last layers are most sensitive')
print('\n7. Sensitivity analysis guides bit-width allocation')
print('\n' + '=' * 60)
print('\nüìö Next: Neural Architecture Search!')