# 🔢 Lecture 5: Quantization Basics - Complete Demo

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/gaurav-redhat/efficientml_course/blob/main/05_quantization_1/demo.ipynb)

## What You'll Learn
- Data types: FP32, FP16, INT8, INT4
- Quantization math: scale and zero-point
- Symmetric vs asymmetric quantization
- Post-Training Quantization (PTQ)

In [None]:
!pip install torch matplotlib numpy -q
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np

torch.manual_seed(42)
print('Ready for quantization!')

## Part 1: Understanding Data Types

Different precision formats trade off range and accuracy for memory.

In [None]:
# Data type comparison: nominal bit width, representable range and precision
dtypes = {
    'FP32': {'bits': 32, 'range': '±3.4e38', 'precision': '~7 decimal digits'},
    'FP16': {'bits': 16, 'range': '±65504', 'precision': '~3 decimal digits'},
    'BF16': {'bits': 16, 'range': '±3.4e38', 'precision': '~2 decimal digits'},
    'INT8': {'bits': 8, 'range': '-128 to 127', 'precision': 'integer only'},
    'INT4': {'bits': 4, 'range': '-8 to 7', 'precision': 'integer only'},
}

print('📊 DATA TYPE COMPARISON')
print('=' * 70)
print(f'{"Type":<10} {"Bits":<8} {"Range":<20} {"Precision":<25}')
print('-' * 70)
for name, info in dtypes.items():
    print(f'{name:<10} {info["bits"]:<8} {info["range"]:<20} {info["precision"]:<25}')

# Memory savings visualization
fig, ax = plt.subplots(figsize=(10, 5))
types = list(dtypes.keys())
bits = [dtypes[t]['bits'] for t in types]
colors = ['#ef4444', '#f97316', '#eab308', '#22c55e', '#10b981']

bars = ax.barh(types, bits, color=colors)
ax.set_xlabel('Bits per Value', fontsize=12)
# No emoji in plot text: matplotlib's default font (DejaVu Sans) has no emoji
# glyphs, so they render as empty boxes with a "Glyph missing" warning.
ax.set_title('Memory per Value by Data Type', fontsize=14)
# Leave room right of the longest bar so its annotation is not cut off.
ax.set_xlim(0, max(bits) * 1.6)

for bar, b in zip(bars, bits):
    compression = 32 / b
    ax.text(bar.get_width() + 0.5, bar.get_y() + bar.get_height()/2,
            f'{compression:.1f}x compression vs FP32', va='center')

plt.tight_layout()
plt.show()

## Part 2: The Math of Quantization

**Quantization Formula:**
$$Q(x) = \text{clamp}\left(\text{round}\left(\frac{x}{s}\right) + z,\; q_{\min},\; q_{\max}\right)$$

**Dequantization:**
$$\hat{x} = s \cdot (Q(x) - z)$$

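Where:
- $s$ = scale factor (the real-valued step between adjacent integer levels)
- $z$ = zero-point (the integer code that represents the real value 0)
- $[q_{\min}, q_{\max}]$ = the integer range, e.g. $[-128, 127]$ for signed INT8 or $[0, 255]$ for unsigned INT8

For **symmetric** quantization, $z = 0$ and $s = \max|x| / q_{\max}$. For **asymmetric** quantization, $s = (x_{\max} - x_{\min}) / (q_{\max} - q_{\min})$ and $z = q_{\min} - \text{round}(x_{\min} / s)$.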

In [None]:
def quantize_tensor(x, num_bits=8, symmetric=True):
    """
    Quantize a tensor to the given bit width using per-tensor min/max calibration.

    Args:
        x: Input floating-point tensor (must be non-empty).
        num_bits: Target bit width.
        symmetric: If True, use a signed grid centred on zero
            ([-2**(b-1), 2**(b-1)-1], zero_point = 0). Otherwise use an
            unsigned grid [0, 2**b - 1] with a computed zero-point.

    Returns:
        (q, scale, zero_point):
        - q: integer codes; torch.int8 for symmetric 8-bit, torch.uint8 for
          asymmetric 8-bit, torch.int32 for every other bit width.
        - scale: 0-dim tensor, always > 0.
        - zero_point: the int 0 (symmetric) or a 0-dim float tensor (asymmetric).
    """
    # Lower bound for the scale: a constant (e.g. all-zero) input would
    # otherwise give scale == 0 and x / scale would produce NaNs.
    eps = torch.finfo(torch.float32).eps

    if symmetric:
        # Symmetric: zero-point = 0, range = [-max, max]
        qmin = -(2 ** (num_bits - 1))
        qmax = 2 ** (num_bits - 1) - 1

        max_val = x.abs().max()
        scale = torch.clamp(max_val / qmax, min=eps)
        zero_point = 0
    else:
        # Asymmetric: full range utilization
        qmin = 0
        qmax = 2 ** num_bits - 1

        # Extend the range to include 0 so that 0.0 is exactly representable
        # and the zero-point always lands inside [qmin, qmax] (e.g. for
        # all-positive inputs such as post-ReLU activations).
        min_val = torch.clamp(x.min(), max=0.0)
        max_val = torch.clamp(x.max(), min=0.0)
        scale = torch.clamp((max_val - min_val) / (qmax - qmin), min=eps)
        zero_point = qmin - torch.round(min_val / scale)

    # Quantize
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)

    if num_bits == 8:
        # Asymmetric 8-bit codes span 0..255, which overflows int8 and would
        # corrupt every code >= 128, so store them as uint8 instead.
        out_dtype = torch.int8 if symmetric else torch.uint8
    else:
        out_dtype = torch.int32
    return q.to(out_dtype), scale, zero_point

def dequantize_tensor(q, scale, zero_point):
    """Map integer codes back to FP32 via ``scale * (q - zero_point)``."""
    centered = q.to(torch.float32) - zero_point
    return scale * centered

# Example: quantize one tensor both ways and compare reconstruction quality
x = torch.randn(1000) * 2 + 0.5  # Non-zero mean tensor

print('📊 QUANTIZATION EXAMPLE')
print('=' * 50)
print(f'Original tensor: min={x.min():.4f}, max={x.max():.4f}')

# Symmetric quantization
q_sym, scale_sym, zp_sym = quantize_tensor(x, num_bits=8, symmetric=True)
x_deq_sym = dequantize_tensor(q_sym, scale_sym, zp_sym)

print('\n🔷 Symmetric INT8:')
print(f'   Scale: {scale_sym:.6f}')
print(f'   Zero-point: {zp_sym}')
print(f'   Reconstruction error: {torch.mean((x - x_deq_sym) ** 2):.6f}')

# Asymmetric quantization
q_asym, scale_asym, zp_asym = quantize_tensor(x, num_bits=8, symmetric=False)
x_deq_asym = dequantize_tensor(q_asym, scale_asym, zp_asym)

print('\n🔶 Asymmetric INT8:')
print(f'   Scale: {scale_asym:.6f}')
print(f'   Zero-point: {zp_asym:.0f}')
print(f'   Reconstruction error: {torch.mean((x - x_deq_asym) ** 2):.6f}')

In [None]:
# Visualize quantization
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Original distribution
axes[0, 0].hist(x.numpy(), bins=50, color='#3b82f6', alpha=0.7, edgecolor='black')
axes[0, 0].set_title('Original FP32 Values', fontsize=12)
axes[0, 0].set_xlabel('Value')

# Quantized values (symmetric)
axes[0, 1].hist(q_sym.numpy(), bins=50, color='#22c55e', alpha=0.7, edgecolor='black')
axes[0, 1].set_title('Quantized INT8 Values (Symmetric)', fontsize=12)
axes[0, 1].set_xlabel('Quantized Value')

# Reconstruction comparison: pick ~100 values evenly spaced through the sorted
# order so the plot spans the whole value range (taking the first 100 sorted
# values would show only the extreme negative tail).
order = torch.argsort(x)
indices = order[::max(1, order.numel() // 100)]
axes[1, 0].plot(x[indices].numpy(), 'b-', label='Original', linewidth=2)
axes[1, 0].plot(x_deq_sym[indices].numpy(), 'g--', label='Symmetric', linewidth=2)
axes[1, 0].plot(x_deq_asym[indices].numpy(), 'r:', label='Asymmetric', linewidth=2)
axes[1, 0].set_title('Original vs Reconstructed', fontsize=12)
axes[1, 0].legend()
axes[1, 0].set_xlabel('Rank (sorted by original value)')

# Quantization error
error_sym = (x - x_deq_sym).numpy()
error_asym = (x - x_deq_asym).numpy()
axes[1, 1].hist(error_sym, bins=50, alpha=0.5, label=f'Symmetric (MSE={np.mean(error_sym**2):.6f})', color='green')
axes[1, 1].hist(error_asym, bins=50, alpha=0.5, label=f'Asymmetric (MSE={np.mean(error_asym**2):.6f})', color='red')
axes[1, 1].set_title('Quantization Error Distribution', fontsize=12)
axes[1, 1].legend()
axes[1, 1].set_xlabel('Error')

plt.tight_layout()
plt.show()

## Part 3: Different Bit Widths

In [None]:
# Compare different bit widths
bit_widths = [8, 4, 2]
test_tensor = torch.randn(10000)

print('📊 BIT WIDTH COMPARISON')
print('=' * 60)
print(f'{"Bits":<8} {"Levels":<12} {"MSE":<15} {"Max Error":<15}')
print('-' * 60)

results = []
# Reconstructions keyed by bit width; the plots below reuse them so they show
# exactly the quantization (same scale) that the table reports on.
reconstructions = {}
for bits in bit_widths:
    q, s, z = quantize_tensor(test_tensor, num_bits=bits, symmetric=True)
    deq = dequantize_tensor(q, s, z)
    reconstructions[bits] = deq

    mse = torch.mean((test_tensor - deq) ** 2).item()
    max_err = torch.max(torch.abs(test_tensor - deq)).item()
    levels = 2 ** bits

    results.append({'bits': bits, 'levels': levels, 'mse': mse, 'max_err': max_err})
    print(f'{bits:<8} {levels:<12} {mse:<15.6f} {max_err:<15.4f}')

# Visualize the first 100 samples of each reconstruction
fig, axes = plt.subplots(1, len(bit_widths), figsize=(5 * len(bit_widths), 4), squeeze=False)

for ax, bits in zip(axes[0], bit_widths):
    ax.plot(test_tensor[:100].numpy(), 'b-', alpha=0.7, label='Original')
    ax.plot(reconstructions[bits][:100].numpy(), 'r--', alpha=0.7, label='Quantized')
    ax.set_title(f'{bits}-bit Quantization\n({2**bits} levels)', fontsize=12)
    ax.legend()
    ax.set_xlabel('Sample')

plt.tight_layout()
plt.show()

print('\n💡 Lower bits = more error, but much smaller model!')

## Part 4: Post-Training Quantization (PTQ)

In [None]:
# Create a simple model
class SimpleClassifier(nn.Module):
    """Three-layer MLP: 784 -> 256 -> 64 -> 10, with ReLU after the first two layers."""

    def __init__(self):
        super(SimpleClassifier, self).__init__()
        # Layers are created in this exact order so that, under a fixed seed,
        # the random initialisation is reproducible.
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)
        # A single ReLU module instance, applied after both fc1 and fc2.
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return the raw (unnormalised) class scores for input batch ``x``."""
        hidden = self.relu(self.fc1(x))
        hidden = self.relu(self.fc2(hidden))
        logits = self.fc3(hidden)
        return logits

# Build the demo model. It is NOT trained: the weights are PyTorch's default
# random initialisation.
model = SimpleClassifier()

# Create test data
X_test = torch.randn(1000, 784)
# Use the FP32 model's own predictions as the reference labels. With an
# untrained model, random labels would give ~10% "accuracy" (10 classes) and
# make the FP32-vs-quantized comparison pure noise. Against these labels, the
# FP32 model scores 100% and a quantized model's accuracy equals how often it
# agrees with the FP32 model.
with torch.no_grad():
    y_test = model(X_test).argmax(dim=1)

def get_model_size(model):
    """
    Return the in-memory size of a model's tensors in MB (1 MB = 1024**2 bytes).

    Counts both parameters and registered buffers (e.g. BatchNorm running
    statistics) as element count times element size. State that is held in
    neither parameters nor buffers (for example, the packed weights of
    dynamically quantized layers) is not included.
    """
    param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    buffer_bytes = sum(b.numel() * b.element_size() for b in model.buffers())
    return (param_bytes + buffer_bytes) / 1024 / 1024

def evaluate_model(model, X, y):
    """
    Return the top-1 accuracy (in percent) of ``model`` on inputs ``X`` and labels ``y``.

    Runs one forward pass in eval mode under ``torch.no_grad()``. The model's
    previous train/eval mode is restored afterwards, even if the forward
    pass raises.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs = model(X)
            _, predicted = outputs.max(1)
    finally:
        model.train(was_training)
    return (predicted == y).float().mean().item() * 100

# Baseline numbers for the FP32 model
print('📊 ORIGINAL MODEL')
print('=' * 40)
print(f'Size: {get_model_size(model):.2f} MB')
print(f'Accuracy: {evaluate_model(model, X_test, y_test):.1f}%')

In [None]:
# PyTorch Dynamic Quantization
import io

print('\n🔧 PYTORCH DYNAMIC QUANTIZATION')
print('=' * 50)

# Apply dynamic quantization: every nn.Linear is replaced by an INT8
# dynamically-quantized counterpart (the original model is left untouched).
quantized_model = torch.ao.quantization.quantize_dynamic(
    model,
    {nn.Linear},  # Layers to quantize
    dtype=torch.qint8
)

print('\nQuantized model structure:')
print(quantized_model)


def _serialized_size_mb(m):
    """Size in MB of ``m.state_dict()`` as written by ``torch.save`` (in memory, no disk I/O)."""
    buffer = io.BytesIO()
    torch.save(m.state_dict(), buffer)
    return buffer.getbuffer().nbytes / 1024 / 1024


# Compare. Serialized size is used because quantized Linear layers keep their
# INT8 weights in packed form rather than as ordinary parameters.
orig_acc = evaluate_model(model, X_test, y_test)
quant_acc = evaluate_model(quantized_model, X_test, y_test)
orig_size = _serialized_size_mb(model)
quant_size = _serialized_size_mb(quantized_model)

print('\n📊 COMPARISON')
print(f'Original accuracy: {orig_acc:.1f}%')
print(f'Quantized accuracy: {quant_acc:.1f}%')
print(f'\n💾 Serialized size: {orig_size:.2f} MB (FP32) -> {quant_size:.2f} MB (INT8), '
      f'{orig_size / quant_size:.1f}x smaller')

## Part 5: Calibration for Static Quantization

In [None]:
def collect_activation_stats(model, calibration_data):
    """
    Collect activation statistics for calibration.

    A forward hook on every ``nn.Linear`` and ``nn.ReLU`` submodule records the
    running min/max of that module's output. It also stores, per forward call,
    a copy of the first 1000 elements of the flattened output. These statistics
    help determine the scale and zero-point for activations.

    Args:
        model: ``nn.Module`` to calibrate.
        calibration_data: iterable of input batches; each is passed as ``model(batch)``.

    Returns:
        dict mapping module name -> {'min': float, 'max': float,
        'values': list of 1-D tensors}.

    Notes:
        - A module instance called several times per forward pass (e.g. one
          shared ReLU) fires its hook on every call, so its entry merges the
          statistics of all call sites.
        - Hooks are always removed, and the model's train/eval mode restored,
          even if a forward pass raises.
    """
    stats = {}
    hooks = []

    def hook_fn(name):
        def hook(module, inputs, output):
            if name not in stats:
                stats[name] = {'min': float('inf'), 'max': float('-inf'), 'values': []}
            entry = stats[name]

            out = output.detach()
            entry['min'] = min(entry['min'], out.min().item())
            entry['max'] = max(entry['max'], out.max().item())
            # clone(): a bare slice is a view and would keep the entire output
            # tensor's storage alive for as long as the stats dict lives.
            entry['values'].append(out.flatten()[:1000].clone())
        return hook

    was_training = model.training
    try:
        # Register hooks
        for name, module in model.named_modules():
            if isinstance(module, (nn.Linear, nn.ReLU)):
                hooks.append(module.register_forward_hook(hook_fn(name)))

        # Run calibration data
        model.eval()
        with torch.no_grad():
            for batch in calibration_data:
                model(batch)
    finally:
        # Remove hooks so later forward passes are not recorded
        for handle in hooks:
            handle.remove()
        model.train(was_training)

    return stats

# Run calibration
calibration_data = [torch.randn(32, 784) for _ in range(10)]
activation_stats = collect_activation_stats(model, calibration_data)

print('📊 ACTIVATION STATISTICS FOR CALIBRATION')
print('=' * 60)
print(f'{"Layer":<20} {"Min":<15} {"Max":<15} {"Range":<15}')
print('-' * 60)

# Note: model.relu is applied after both fc1 and fc2, so its row merges both outputs.
for name, stat in activation_stats.items():
    range_val = stat['max'] - stat['min']
    print(f'{name:<20} {stat["min"]:<15.4f} {stat["max"]:<15.4f} {range_val:<15.4f}')

# Visualize activation distributions (one panel per recorded layer, up to four)
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
axes = axes.flatten()
shown = list(activation_stats.items())[:len(axes)]

for ax, (name, stat) in zip(axes, shown):
    values = torch.cat(stat['values']).numpy()
    ax.hist(values, bins=50, color='#3b82f6', alpha=0.7, edgecolor='black')
    ax.axvline(x=0, color='red', linestyle='--')
    ax.set_title(f'{name}\nRange: [{stat["min"]:.2f}, {stat["max"]:.2f}]')

# Hide any panels left empty when fewer than four layers were recorded
for ax in axes[len(shown):]:
    ax.axis('off')

# No emoji in plot text: matplotlib's default font has no emoji glyphs.
plt.suptitle('Activation Distributions (for Calibration)', fontsize=14)
plt.tight_layout()
plt.show()

print('\n💡 Calibration data helps find optimal quantization parameters!')

## Part 6: Quantization Error Analysis

In [None]:
def analyze_quantization_error(model, num_bits=8):
    """
    Report per-layer weight quantization error for ``model``.

    Each parameter whose name contains 'weight' is symmetrically quantized to
    ``num_bits`` and dequantized again, then compared with the original.

    Returns:
        dict mapping parameter name -> {
            'mse': mean squared reconstruction error,
            'relative_error': mean of |w - w_hat| / (|w| + 1e-10), in percent,
            'weight_range': (min weight, max weight),
        }
    """
    report = {}

    for param_name, weight in model.named_parameters():
        if 'weight' not in param_name:
            continue

        w = weight.detach()
        q, scale, zp = quantize_tensor(w, num_bits=num_bits)
        w_hat = dequantize_tensor(q, scale, zp)
        diff = w - w_hat

        # The 1e-10 keeps the ratio finite for weights that are exactly zero
        rel = (diff.abs() / (w.abs() + 1e-10)).mean().item()
        report[param_name] = {
            'mse': diff.pow(2).mean().item(),
            'relative_error': rel * 100,
            'weight_range': (w.min().item(), w.max().item()),
        }

    return report

# Analyze errors at different bit widths
print('📊 PER-LAYER QUANTIZATION ERROR')
print('=' * 70)

for bits in [8, 4]:
    print(f'\n{bits}-bit Quantization:')
    print(f'{"Layer":<15} {"MSE":<15} {"Rel. Error (%)":<15} {"Weight Range":<25}')
    print('-' * 70)

    errors = analyze_quantization_error(model, num_bits=bits)
    for name, err in errors.items():
        # 'fc1.weight' -> 'fc1'
        short_name = name.split('.')[0]
        w_range = f'[{err["weight_range"][0]:.3f}, {err["weight_range"][1]:.3f}]'
        print(f'{short_name:<15} {err["mse"]:<15.6f} {err["relative_error"]:<15.2f} {w_range:<25}')

In [None]:
# Summary of the lecture's main points
print('🎯 KEY TAKEAWAYS')
print('=' * 60)
print('\n1. Quantization: FP32 → INT8/INT4 (4-8x memory reduction)')
print('\n2. Scale and zero-point map floating point to integer range')
print('\n3. Symmetric: simpler, works well for weights')
print('\n4. Asymmetric: better range utilization for activations')
print('\n5. PTQ: Quick, no retraining, slight accuracy loss')
print('\n6. Calibration: Run sample data to find optimal parameters')
print('\n7. Lower bits = more compression but more error')
print('\n' + '=' * 60)
print('\n📚 Next: Quantization-Aware Training for better accuracy!')