# Module 9.7: Quantization Theory

**Goal**: Deep dive into quantization mathematics

**Time**: 90 minutes

**Concepts Covered**:
- Uniform quantization math
- Per-tensor vs per-channel
- GPTQ algorithm (simplified per-channel version)
- Quantization error analysis
- Weight distribution visualization
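For reference, the two quantizers in this notebook can be written as follows. The notation is ours, not the notebook's: $b$ is the bit width, $\Delta$ corresponds to `scale` and $z$ to `zero_point`. In practice $z$ is usually rounded to an integer so that $0.0$ is represented exactly. The final line assumes the rounding error is roughly uniform on $[-\Delta/2, \Delta/2]$.

```latex
% Per-tensor asymmetric (uniform_quantize)
\Delta = \frac{w_{\max} - w_{\min}}{2^{b} - 1}, \qquad
z = -\frac{w_{\min}}{\Delta}, \qquad
q = \operatorname{clamp}\!\Big(\operatorname{round}\big(\tfrac{w}{\Delta} + z\big),\; 0,\; 2^{b} - 1\Big), \qquad
\hat{w} = (q - z)\,\Delta

% Per-channel symmetric (gptq_quantize), for output channel c
\Delta_c = \frac{\max_j |W_{cj}|}{2^{b-1} - 1}, \qquad
q_{cj} = \operatorname{clamp}\!\Big(\operatorname{round}\big(\tfrac{W_{cj}}{\Delta_c}\big),\; -2^{b-1},\; 2^{b-1} - 1\Big), \qquad
\hat{W}_{cj} = q_{cj}\,\Delta_c

% Error for values inside the clipping range
|w - \hat{w}| \le \frac{\Delta}{2}, \qquad
\mathbb{E}\big[(w - \hat{w})^2\big] \approx \frac{\Delta^2}{12}
```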

## Setup

In [None]:
!pip install torch transformers accelerate matplotlib seaborn numpy -q

In [None]:
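import torch
import numpy as np
import matplotlib.pyplot as plt

def uniform_quantize(weights, bits=8):
    """Asymmetric (affine) per-tensor uniform quantization.

    Maps [w_min, w_max] onto the integers {0, ..., 2**bits - 1} and returns
    the dequantized tensor together with the scale and zero point.
    """
    # A single scale and zero point for the whole tensor (per-tensor)
    w_min = weights.min()
    w_max = weights.max()
    qmax = 2 ** bits - 1

    # Clamp the scale so a constant tensor (w_max == w_min) cannot divide by zero
    scale = torch.clamp((w_max - w_min) / qmax, min=1e-8)
    # Integer zero point, so that 0.0 is represented exactly
    zero_point = torch.round(-w_min / scale)

    # Quantize: scale, shift, round, and clip to the representable range
    q_weights = torch.round(weights / scale + zero_point)
    q_weights = torch.clamp(q_weights, 0, qmax)

    # Dequantize back to floating point so the error can be measured
    dequantized = (q_weights - zero_point) * scale

    return dequantized, scale, zero_point

def gptq_quantize(layer, bits=4):
    """Simplified GPTQ-style quantization of a linear layer's weights.

    Only the per-channel symmetric rounding step is implemented. Full GPTQ
    additionally uses second-order (Hessian) information from calibration
    inputs to compensate each column's rounding error.
    """
    weights = layer.weight.data.clone()
    qmax = 2 ** (bits - 1) - 1   # 7 for 4 bits
    qmin = -(2 ** (bits - 1))    # -8 for 4 bits

    scales = []
    quantized_weights = []

    # Per-channel: one scale per output channel (row of the weight matrix)
    for channel in range(weights.shape[0]):
        channel_weights = weights[channel]

        # Symmetric scale from this channel's largest magnitude,
        # clamped so an all-zero row cannot divide by zero
        w_abs_max = channel_weights.abs().max()
        scale = torch.clamp(w_abs_max / qmax, min=1e-8)

        # Quantize to signed integers in [qmin, qmax]
        q_weights = torch.round(channel_weights / scale)
        q_weights = torch.clamp(q_weights, qmin, qmax)

        # Dequantize back to floating point
        dequantized = q_weights * scale

        scales.append(scale)
        quantized_weights.append(dequantized)

    return torch.stack(quantized_weights), torch.stack(scales)

# Example: a random 128x256 weight matrix (out_features x in_features)
torch.manual_seed(0)
weights = torch.randn(128, 256) * 0.1

# Per-tensor asymmetric quantization at 8 bits
uniform_q, scale, zp = uniform_quantize(weights, bits=8)
uniform_error = (weights - uniform_q).abs().mean()

# Per-channel symmetric quantization at 4 bits, applied to the SAME weights
layer = torch.nn.Linear(256, 128, bias=False)
with torch.no_grad():
    layer.weight.copy_(weights)
gptq_q, scales = gptq_quantize(layer, bits=4)
gptq_error = (weights - gptq_q).abs().mean()

print(f"Per-tensor 8-bit mean abs error:  {uniform_error:.6f}")
print(f"Per-channel 4-bit mean abs error: {gptq_error:.6f}")
print("\nThe runs differ in bit width (8 vs 4) as well as granularity, so the "
      "4-bit error is expected to be larger; compare at equal bit width to "
      "isolate the effect of per-channel scales.")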
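The example cell compares 8-bit per-tensor against 4-bit per-channel quantization, which mixes two effects. A minimal NumPy sketch that holds the bit width fixed and changes only the granularity might look like this. The helper `quantize_symmetric` and the synthetic weights, whose rows span three orders of magnitude, are assumptions for illustration.

```python
import numpy as np

def quantize_symmetric(w, bits, axis=None):
    """Symmetric round-to-nearest quantization; returns dequantized values.

    axis=None -> one scale for the whole tensor (per-tensor)
    axis=1    -> one scale per row, i.e. per output channel (per-channel)
    """
    qmax = 2 ** (bits - 1) - 1
    # keepdims=True keeps the scale broadcastable against w
    abs_max = np.max(np.abs(w), axis=axis, keepdims=True)
    scale = np.maximum(abs_max / qmax, 1e-8)
    q = np.clip(np.round(w / scale), -qmax - 1, qmax)
    return q * scale

rng = np.random.default_rng(0)
# Hypothetical weights: row magnitudes range from 0.001 to 1.0
row_scales = np.logspace(-3, 0, 64)[:, None]
w = rng.standard_normal((64, 256)) * row_scales

per_tensor = quantize_symmetric(w, bits=4, axis=None)
per_channel = quantize_symmetric(w, bits=4, axis=1)

mse_tensor = np.mean((w - per_tensor) ** 2)
mse_channel = np.mean((w - per_channel) ** 2)
print(f"per-tensor MSE:  {mse_tensor:.3e}")
print(f"per-channel MSE: {mse_channel:.3e}")
```

With a single per-tensor scale, the step size is set by the largest rows, so values in the small-magnitude rows largely round to zero; per-channel scales give each row a step size matched to its own range.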
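`gptq_quantize` above performs only per-channel round-to-nearest (RTN), despite its name. The distinguishing step of GPTQ (Frantar et al., 2022) is that after each weight column is quantized, its rounding error is spread over the not-yet-quantized columns using the inverse of the Hessian $H = 2X^\top X$ of the layer's reconstruction error, where $X$ holds calibration inputs. The NumPy sketch below is a simplified illustration under our own assumptions: a single layer, per-row symmetric scales fixed from the original weights, a dampening of 1% of the mean Hessian diagonal, and no blocked ("lazy batch") updates. The names `gptq_sketch`, `round_to_grid` and `per_row_scales` are hypothetical.

```python
import numpy as np

def per_row_scales(W, bits):
    """Symmetric per-row (per-output-channel) scales from each row's max |w|."""
    qmax = 2 ** (bits - 1) - 1
    return np.maximum(np.abs(W).max(axis=1) / qmax, 1e-8)

def round_to_grid(w, scale, bits):
    """Round-to-nearest onto the signed integer grid, then dequantize."""
    qmax = 2 ** (bits - 1) - 1
    return np.clip(np.round(w / scale), -qmax - 1, qmax) * scale

def gptq_sketch(W, X, bits=4, damp=0.01):
    """GPTQ-style column-by-column quantization with error compensation.

    W: (out_features, in_features) weights.
    X: (n_samples, in_features) calibration inputs to the layer.
    """
    W = W.astype(np.float64).copy()
    n_cols = W.shape[1]
    scale = per_row_scales(W, bits)  # fixed from the original weights

    # Hessian of the reconstruction error ||X w - X q||^2 for each row w
    H = 2.0 * X.T @ X
    # Dampening: add a fraction of the mean diagonal for numerical stability
    H += damp * np.mean(np.diag(H)) * np.eye(n_cols)
    H_inv = np.linalg.inv(H)
    H_inv = (H_inv + H_inv.T) / 2           # symmetrize against round-off
    U = np.linalg.cholesky(H_inv).T         # upper factor: H_inv = U.T @ U

    Q = np.zeros_like(W)
    for i in range(n_cols):
        q = round_to_grid(W[:, i], scale, bits)
        Q[:, i] = q
        # Push column i's scaled rounding error onto columns i..end;
        # afterwards W[:, i] equals q and later columns absorb the error
        err = (W[:, i] - q) / U[i, i]
        W[:, i:] -= np.outer(err, U[i, i:])
    return Q

# Hypothetical calibration inputs: a shared rank-4 component plus noise,
# so the input features are strongly correlated
rng = np.random.default_rng(0)
n_samples, in_features, out_features = 512, 32, 16
X = (rng.standard_normal((n_samples, 4)) @ rng.standard_normal((4, in_features))
     + 0.5 * rng.standard_normal((n_samples, in_features)))
W = rng.standard_normal((out_features, in_features)) * 0.1

Q_gptq = gptq_sketch(W, X, bits=4)
Q_rtn = round_to_grid(W, per_row_scales(W, 4)[:, None], 4)

def output_error(Q):
    """Squared error of the layer's outputs on the calibration inputs."""
    return np.sum((X @ W.T - X @ Q.T) ** 2)

print(f"RTN  output error: {output_error(Q_rtn):.4f}")
print(f"GPTQ output error: {output_error(Q_gptq):.4f}")
```

Both methods place every weight on the same per-row grid; they differ only in which grid point each weight lands on. If $H$ is diagonal (uncorrelated inputs), the off-diagonal entries of `U` vanish, no error is propagated, and the sketch reduces to plain RTN.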
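For quantization error analysis, two standard properties of uniform rounding can be checked empirically: inside the clipping range the error never exceeds half a step, and if the error is roughly uniform its mean squared value is close to $\text{step}^2/12$. The sketch below re-implements the affine scheme of `uniform_quantize` in NumPy with an integer zero point; `affine_quantize` and the Gaussian weight sample are assumptions for illustration.

```python
import numpy as np

def affine_quantize(w, bits):
    """Per-tensor asymmetric quantization in NumPy; returns (dequantized, step)."""
    qmax = 2 ** bits - 1
    scale = max((w.max() - w.min()) / qmax, 1e-8)
    zero_point = np.round(-w.min() / scale)
    q = np.clip(np.round(w / scale + zero_point), 0, qmax)
    return (q - zero_point) * scale, scale

rng = np.random.default_rng(0)
w = rng.standard_normal(100_000) * 0.1   # hypothetical weight sample

results = {}
for bits in (8, 6, 4):
    w_hat, scale = affine_quantize(w, bits)
    err = w - w_hat
    # Worst-case error in units of the step size (should not exceed 0.5)
    max_ratio = np.abs(err).max() / scale
    # Measured MSE relative to the uniform-error prediction step**2 / 12
    mse_ratio = np.mean(err ** 2) / (scale ** 2 / 12)
    results[bits] = (max_ratio, mse_ratio)
    print(f"{bits}-bit: step={scale:.2e}  max|err|/step={max_ratio:.3f}  "
          f"MSE/(step^2/12)={mse_ratio:.3f}")
```

Each bit removed doubles the step size, so the mean squared error grows by roughly a factor of four per bit even though the ratios printed above stay near constant.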
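The notebook imports `matplotlib` but never plots. One way to visualize a weight distribution against its quantization grid is sketched below, assuming a synthetic Gaussian weight sample and 4-bit per-tensor symmetric quantization; the bin counts and colors are arbitrary choices.

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
w = rng.standard_normal(10_000) * 0.1   # hypothetical weight sample

bits = 4
qmax = 2 ** (bits - 1) - 1
scale = np.abs(w).max() / qmax
w_hat = np.clip(np.round(w / scale), -qmax - 1, qmax) * scale
levels = np.arange(-qmax - 1, qmax + 1) * scale   # all representable values

fig, (ax_w, ax_e) = plt.subplots(1, 2, figsize=(11, 4))

# Left: weight histogram with each quantization level drawn as a red line
ax_w.hist(w, bins=100, color="steelblue", alpha=0.7, label="original weights")
for level in levels:
    ax_w.axvline(level, color="red", linewidth=0.6, alpha=0.6)
ax_w.set_title(f"Weights and {bits}-bit grid ({len(levels)} levels)")
ax_w.set_xlabel("weight value")
ax_w.legend()

# Right: distribution of the rounding error, confined to [-scale/2, scale/2]
ax_e.hist(w - w_hat, bins=50, color="gray")
ax_e.set_title("Quantization error")
ax_e.set_xlabel("w - dequantized w")

fig.tight_layout()
plt.show()
```

Because the Gaussian mass is concentrated near zero, most weights fall into a few central levels while the outer levels are rarely used, which is one motivation for finer-grained (per-channel or per-group) scales.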

## Key Takeaways

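- Uniform (affine) quantization maps a floating-point range onto `2 ** bits` integer levels via a scale and a zero point; for values inside that range, the rounding error is at most half a quantization step.
- Per-tensor quantization shares one scale across the whole matrix, so the channel with the largest magnitude sets the step size for every channel; per-channel scales let each output channel use its own range.
- `gptq_quantize` in this notebook performs only per-channel round-to-nearest. Full GPTQ additionally uses second-order (Hessian) information from calibration data to compensate each column's rounding error.

✅ **Module Complete**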

## Next Steps

Continue to the next module in the course.