# Quantization Tutorial: From Theory to Practice

**A comprehensive guide to neural network quantization techniques**

---

## Table of Contents

1. [Introduction to Quantization](#introduction)
2. [Quantization Fundamentals](#fundamentals)
3. [Ternary Quantization](#ternary)
4. [INT8 Quantization](#int8)
5. [Mixed Precision](#mixed-precision)
6. [QAT vs PTQ](#qat-vs-ptq)
7. [Performance Benchmarks](#benchmarks)
8. [Practical Guidelines](#guidelines)

## 1. Introduction to Quantization {#introduction}

**Quantization** is the process of constraining neural network weights and activations from continuous values to a discrete set of values.

### Why Quantize?

| Benefit | Description | Impact |
|---------|-------------|--------|
| **Memory** | Reduced bit-width per parameter | 2-32x smaller models |
| **Speed** | Faster arithmetic operations | 2-4x faster inference on hardware with low-precision support |
| **Power** | Lower energy consumption | Critical for edge devices |
| **Bandwidth** | Reduced data transfer | Faster model loading |

### Quantization Levels

```
Float32 (baseline)  →  32 bits per weight
Float16             →  16 bits (2x compression)
INT8                →   8 bits (4x compression)
INT4                →   4 bits (8x compression)
Ternary {-1,0,1}    →   2 bits (16x compression)
Binary {-1,1}       →   1 bit (32x compression)
```
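
To make these ratios concrete, the sketch below computes the theoretical weight storage for a hypothetical 10-million-parameter model at each bit-width. The parameter count and the `storage_mb` helper are illustrative assumptions; real formats also store scales, zero-points, and padding, so actual files are somewhat larger.

```python
# Theoretical weight storage at different bit-widths. Metadata such as
# scales and zero-points is ignored, so real files are slightly larger.
BITS_PER_WEIGHT = {
    "Float32": 32,
    "Float16": 16,
    "INT8": 8,
    "INT4": 4,
    "Ternary": 2,
    "Binary": 1,
}


def storage_mb(num_params: int, bits: int) -> float:
    """Return the storage in megabytes (1 MB = 1024 * 1024 bytes)."""
    return num_params * bits / 8 / (1024 * 1024)


num_params = 10_000_000  # hypothetical model size
baseline = storage_mb(num_params, BITS_PER_WEIGHT["Float32"])
for name, bits in BITS_PER_WEIGHT.items():
    size = storage_mb(num_params, bits)
    print(f"{name:8s} {size:8.2f} MB  ({baseline / size:.0f}x compression)")
```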

## Setup

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import time
from tqdm.notebook import tqdm
import pandas as pd

# Set style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (14, 6)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

## 2. Quantization Fundamentals {#fundamentals}

### 2.1 Uniform Quantization

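The basic quantization formula (the result is then clamped to the valid integer range):

$$
Q(x) = \text{round}\left(\frac{x}{s} + z\right)
$$

Where:
- $x$ is the floating-point value
- $s$ is the scale factor
- $z$ is the zero-point (the integer that represents the real value 0; $z = 0$ for symmetric quantization)
- $Q(x)$ is the quantized integer value

Dequantization:

$$
\tilde{x} = s \cdot \left(Q(x) - z\right)
$$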

In [None]:
def uniform_quantize(x, bits=8, symmetric=True):
    """
    Uniform quantization to n-bit integers
    
    Args:
        x: Input tensor
        bits: Number of bits (e.g., 8 for INT8)
        symmetric: Use symmetric quantization (zero-point = 0)
    """
    if symmetric:
        # Symmetric: range is [-max_val, max_val]
        max_val = torch.abs(x).max()
        qmax = 2 ** (bits - 1) - 1
        scale = max_val / qmax
        zero_point = 0
    else:
        # Asymmetric: range is [min_val, max_val]
        min_val, max_val = x.min(), x.max()
        qmax = 2 ** bits - 1
        scale = (max_val - min_val) / qmax
        zero_point = -torch.round(min_val / scale)
    
    # Quantize
    x_q = torch.round(x / scale + zero_point)
    
    # Clamp to valid range
    if symmetric:
        x_q = torch.clamp(x_q, -qmax, qmax)
    else:
        x_q = torch.clamp(x_q, 0, qmax)
    
    # Dequantize
    x_dq = scale * (x_q - zero_point)
    
    return x_dq, x_q, scale, zero_point


# Demonstrate quantization
x = torch.randn(1000) * 3

# 2x4 grid: top row = original + dequantized values, bottom row = stats + errors.
# (A 2x3 grid has too few panels, so some plots would overwrite each other.)
fig, axes = plt.subplots(2, 4, figsize=(20, 8))

# Original
axes[0, 0].hist(x.numpy(), bins=50, alpha=0.7, color='blue', edgecolor='black')
axes[0, 0].set_title('Original (Float32)', fontsize=12, fontweight='bold')
axes[0, 0].set_ylabel('Frequency')

# Quantization at different bit-widths, one column per configuration
bit_configs = [(8, 'INT8'), (4, 'INT4'), (2, 'Ternary-like')]
stats_text = "Quantization Metrics\n" + "="*25 + "\n\n"
for idx, (bits, name) in enumerate(bit_configs):
    x_dq, x_q, scale, zp = uniform_quantize(x, bits=bits)
    
    # Dequantized values (top row)
    ax = axes[0, idx + 1]
    ax.hist(x_dq.numpy(), bins=50, alpha=0.7, color='green', edgecolor='black')
    ax.set_title(f'{name} (dequantized)', fontsize=12, fontweight='bold')
    
    # Quantization error (bottom row, directly below the dequantized plot)
    error = (x - x_dq).abs()
    ax = axes[1, idx + 1]
    ax.hist(error.numpy(), bins=50, alpha=0.7, color='red', edgecolor='black')
    ax.set_title(f'{name} Error', fontsize=12, fontweight='bold')
    ax.set_xlabel('Absolute Error')
    if idx == 0:
        ax.set_ylabel('Frequency')
    
    # Accumulate error metrics for the statistics panel
    mse = F.mse_loss(x, x_dq).item()
    mae = error.mean().item()
    stats_text += f"{name}:\n  MSE: {mse:.4f}\n  MAE: {mae:.4f}\n\n"

# Statistics panel (bottom-left)
axes[1, 0].axis('off')
axes[1, 0].text(0.1, 0.5, stats_text, fontsize=11, family='monospace',
               verticalalignment='center')

plt.tight_layout()
plt.show()
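
The demo above only exercises the symmetric path. As a rough illustration of when the asymmetric path helps, the dependency-free sketch below follows the same convention as `uniform_quantize` (round `x / scale + zero_point`, clamp, then dequantize) and compares both modes on non-negative data, similar in shape to post-ReLU activations. The helper names and the synthetic data are assumptions made for this example.

```python
import random


def quantize_dequantize(values, bits=4, symmetric=True):
    """Quantize a list of floats to `bits` bits and map them back to floats."""
    if symmetric:
        qmax = 2 ** (bits - 1) - 1
        qmin = -qmax
        scale = max(abs(v) for v in values) / qmax
        zero_point = 0
    else:
        qmin, qmax = 0, 2 ** bits - 1
        lo, hi = min(values), max(values)
        scale = (hi - lo) / qmax
        zero_point = -round(lo / scale)
    result = []
    for v in values:
        q = round(v / scale + zero_point)
        q = max(qmin, min(qmax, q))  # clamp to the representable range
        result.append(scale * (q - zero_point))
    return result


def mse(a, b):
    """Mean squared error between two equally long lists."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


rng = random.Random(0)
# Non-negative synthetic data (illustrative stand-in for ReLU outputs)
data = [abs(rng.gauss(0.0, 1.0)) for _ in range(10_000)]

sym_mse = mse(data, quantize_dequantize(data, bits=4, symmetric=True))
asym_mse = mse(data, quantize_dequantize(data, bits=4, symmetric=False))
print(f"symmetric  MSE: {sym_mse:.5f}")
print(f"asymmetric MSE: {asym_mse:.5f}")
```

With symmetric quantization, the negative half of the integer range goes unused on this data, so the asymmetric variant ends up with a finer step size and a lower error.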

## 3. Ternary Quantization {#ternary}

Ternary quantization constrains weights to **{-1, 0, 1}**. Stored at 2 bits per weight, this gives roughly 16x compression over Float32, typically at the cost of a few percentage points of accuracy.
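
Three values need only about 1.58 bits of information, but 2 bits per weight is the simplest practical layout. The sketch below packs four ternary weights into each byte and unpacks them again. The bit layout (code = weight + 1, lowest bits first) is an assumption chosen for illustration, not a standard format.

```python
def pack_ternary(weights):
    """Pack ternary weights (-1, 0, 1) into bytes, four weights per byte."""
    codes = [w + 1 for w in weights]  # map {-1, 0, 1} -> {0, 1, 2}
    packed = bytearray()
    for i in range(0, len(codes), 4):
        byte = 0
        for j, code in enumerate(codes[i:i + 4]):
            byte |= code << (2 * j)  # 2 bits per weight
        packed.append(byte)
    return bytes(packed)


def unpack_ternary(packed, count):
    """Recover `count` ternary weights from bytes produced by pack_ternary."""
    weights = []
    for byte in packed:
        for j in range(4):
            if len(weights) == count:
                break
            weights.append(((byte >> (2 * j)) & 0b11) - 1)
    return weights


weights = [1, 0, -1, -1, 0, 1, 1, 0, -1]
packed = pack_ternary(weights)
restored = unpack_ternary(packed, len(weights))
print(f"{len(weights)} weights -> {len(packed)} bytes, round trip ok: {restored == weights}")
```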

### Methods:

1. **Deterministic**: Threshold-based quantization
2. **Stochastic**: Probabilistic quantization
3. **Learned**: Trainable thresholds

### 3.1 Deterministic Ternary Quantization

In [None]:
class TernaryQuantize(torch.autograd.Function):
    """Ternary quantization with STE"""
    
    @staticmethod
    def forward(ctx, input, threshold=0.5):
        output = torch.sign(input)
        output[torch.abs(input) < threshold] = 0
        return output
    
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None


def stochastic_ternary_quantize(x):
    """
    Stochastic ternary quantization
    
    Quantizes with probability proportional to value magnitude:
    P(Q(x) = sign(x)) = |x|
    P(Q(x) = 0) = 1 - |x|
    """
    x_scaled = torch.clamp(torch.abs(x), 0, 1)
    prob = torch.rand_like(x)
    output = torch.where(prob < x_scaled, torch.sign(x), torch.zeros_like(x))
    return output


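def learned_threshold_quantize(x, threshold):
    """
    Ternary quantization with learned threshold
    
    Args:
        x: Input weights
        threshold: Learnable threshold parameter (raw value, passed through a sigmoid)
    
    Note: the hard comparison below is not differentiable, so no gradient
    reaches `threshold` through this function as written. Training it requires
    a surrogate gradient (e.g. an STE) or a soft relaxation.
    """
    t = torch.sigmoid(threshold)  # Map the raw parameter into (0, 1)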
    x_abs = torch.abs(x)
    adaptive_t = t * x_abs.mean()  # Adaptive to weight distribution
    
    output = torch.sign(x)
    output[x_abs < adaptive_t] = 0
    return output
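
One reason to consider the stochastic variant is that it is unbiased for inputs in [-1, 1]: averaged over many draws, the quantized value equals the original weight. The scalar sketch below illustrates this with plain Python; the function names are hypothetical, and it mirrors the probabilities stated in the `stochastic_ternary_quantize` docstring.

```python
import random


def deterministic_ternary(x, threshold=0.5):
    """Threshold-based ternary quantization of a single float."""
    if abs(x) < threshold:
        return 0
    return 1 if x > 0 else -1


def stochastic_ternary(x, rng):
    """Return sign(x) with probability min(|x|, 1), otherwise 0."""
    p = min(abs(x), 1.0)
    if rng.random() < p:
        return 1 if x > 0 else -1
    return 0


rng = random.Random(0)
x = 0.3
n = 100_000
stochastic_mean = sum(stochastic_ternary(x, rng) for _ in range(n)) / n

# The deterministic rule maps every |x| < 0.5 to 0, so it is biased for x = 0.3
print(f"deterministic: {deterministic_ternary(x)}")
# The stochastic average should land close to x itself
print(f"stochastic mean over {n} draws: {stochastic_mean:.3f}")
```

The price of this unbiasedness is variance: any single draw is still one of -1, 0, or 1.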

### 3.2 Compare Ternary Methods

In [None]:
# Generate sample weights
weights = torch.randn(10000) * 1.5
threshold_param = torch.tensor(0.0)  # For learned threshold

# Apply different methods
det_quant = TernaryQuantize.apply(weights, 0.5)
stoch_quant = stochastic_ternary_quantize(weights)
learned_quant = learned_threshold_quantize(weights, threshold_param)

fig, axes = plt.subplots(2, 4, figsize=(18, 8))

# Original
axes[0, 0].hist(weights.numpy(), bins=50, alpha=0.7, color='blue', edgecolor='black')
axes[0, 0].set_title('Original Weights', fontsize=12, fontweight='bold')
axes[0, 0].set_ylabel('Frequency', fontsize=11)

# Quantization methods
methods = [
    (det_quant, 'Deterministic', 'green'),
    (stoch_quant, 'Stochastic', 'orange'),
    (learned_quant, 'Learned Threshold', 'purple')
]

for idx, (quant, name, color) in enumerate(methods):
    # Distribution
    counts = [(quant == -1).sum().item(), 
              (quant == 0).sum().item(), 
              (quant == 1).sum().item()]
    
    axes[0, idx + 1].bar([-1, 0, 1], counts, color=[color]*3, 
                         alpha=0.7, edgecolor='black', linewidth=2)
    axes[0, idx + 1].set_title(name, fontsize=12, fontweight='bold')
    axes[0, idx + 1].set_xticks([-1, 0, 1])
    axes[0, idx + 1].set_xlabel('Value', fontsize=11)
    
    # Scatter plot: original vs quantized
    sample_idx = np.random.choice(len(weights), 500, replace=False)
    axes[1, idx].scatter(weights[sample_idx].numpy(), quant[sample_idx].numpy(), 
                        alpha=0.3, s=10, color=color)
    axes[1, idx].axhline(0, color='gray', linestyle='--', linewidth=1)
    axes[1, idx].set_title(f'{name} Mapping', fontsize=12, fontweight='bold')
    axes[1, idx].set_xlabel('Original Weight', fontsize=11)
    axes[1, idx].set_ylabel('Quantized Weight', fontsize=11)
    axes[1, idx].set_yticks([-1, 0, 1])
    axes[1, idx].grid(True, alpha=0.3)

# Statistics
axes[1, 3].axis('off')
stats = "Sparsity Analysis\n" + "="*30 + "\n\n"
for quant, name, _ in methods:
    sparsity = (quant == 0).sum().item() / len(quant) * 100
    stats += f"{name}:\n  Sparsity: {sparsity:.1f}%\n"
    stats += f"  -1: {(quant == -1).sum().item()}\n"
    stats += f"   0: {(quant == 0).sum().item()}\n"
    stats += f"  +1: {(quant == 1).sum().item()}\n\n"

axes[1, 3].text(0.1, 0.5, stats, fontsize=10, family='monospace',
               verticalalignment='center')

plt.tight_layout()
plt.show()

## 4. INT8 Quantization {#int8}

INT8 quantization maps floating-point values to 8-bit integers, providing a good balance between compression and accuracy.

In [None]:
class INT8Quantize(torch.autograd.Function):
    """INT8 quantization with STE"""
    
    @staticmethod
    def forward(ctx, input):
        # Calculate scale
        max_val = torch.abs(input).max()
        scale = max_val / 127.0
        
        # Quantize to INT8
        input_q = torch.round(input / scale)
        input_q = torch.clamp(input_q, -127, 127)
        
        # Dequantize
        output = input_q * scale
        
        ctx.scale = scale
        return output
    
    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


class LinearINT8(nn.Module):
    """Linear layer with INT8 quantization"""
    
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.1)
        self.bias = nn.Parameter(torch.zeros(out_features))
    
    def forward(self, x):
        weight_q = INT8Quantize.apply(self.weight)
        return F.linear(x, weight_q, self.bias)


# Compare float32 vs INT8
weights = torch.randn(5000) * 2
weights_int8 = INT8Quantize.apply(weights)

fig, axes = plt.subplots(1, 3, figsize=(16, 4))

# Original
axes[0].hist(weights.numpy(), bins=50, alpha=0.7, color='blue', edgecolor='black')
axes[0].set_title('Float32 Weights', fontsize=12, fontweight='bold')
axes[0].set_xlabel('Weight Value')
axes[0].set_ylabel('Frequency')

# INT8
axes[1].hist(weights_int8.numpy(), bins=50, alpha=0.7, color='green', edgecolor='black')
axes[1].set_title('INT8 Quantized', fontsize=12, fontweight='bold')
axes[1].set_xlabel('Weight Value')

# Error distribution
error = (weights - weights_int8).abs()
axes[2].hist(error.numpy(), bins=50, alpha=0.7, color='red', edgecolor='black')
axes[2].set_title('Quantization Error', fontsize=12, fontweight='bold')
axes[2].set_xlabel('Absolute Error')

plt.tight_layout()
plt.show()

print(f"Mean Absolute Error: {error.mean().item():.6f}")
print(f"Max Absolute Error: {error.max().item():.6f}")
# SNR uses the signed error (noise), not the absolute error
print(f"SNR: {20 * np.log10(weights.std().item() / (weights - weights_int8).std().item()):.2f} dB")
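
Both `TernaryQuantize` and `INT8Quantize` mention an STE (straight-through estimator) without showing why it is needed. The sketch below, which assumes PyTorch is installed, compares plain rounding with a hypothetical `RoundSTE` function: rounding has a zero derivative almost everywhere, while the STE version passes the incoming gradient through unchanged.

```python
import torch


class RoundSTE(torch.autograd.Function):
    """Round in the forward pass; pass the gradient through unchanged."""

    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


w = torch.tensor([0.2, 1.7, -2.4], requires_grad=True)

# Plain rounding: the gradient with respect to w is zero everywhere
torch.round(w).sum().backward()
grad_plain = w.grad.clone()

# STE rounding: the gradient of sum() reaches w as if round were the identity
w.grad = None
RoundSTE.apply(w).sum().backward()
grad_ste = w.grad.clone()

print("plain round grad:", grad_plain.tolist())
print("STE round grad:  ", grad_ste.tolist())
```

Without the STE, a quantized layer would receive no learning signal, which is why the quantizer classes above override `backward`.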

## 5. Mixed Precision {#mixed-precision}

Mixed precision uses different quantization levels for different layers:

- **First layer**: Float32 or Float16 (input sensitivity)
- **Middle layers**: Ternary or INT8 (bulk of parameters)
- **Last layer**: Float32 (classification accuracy)
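
As a back-of-the-envelope check, the sketch below estimates storage for a 784-256-128-64-10 MLP with this layout. It assumes 2 bits per ternary weight and that biases stay in Float32; both are illustrative assumptions, and the `layer_bytes` helper is hypothetical.

```python
# (name, in_features, out_features, weight_bits)
layers = [
    ("fc1", 784, 256, 32),  # first layer kept in Float32
    ("fc2", 256, 128, 2),   # ternary weights, 2 bits each
    ("fc3", 128, 64, 2),    # ternary weights, 2 bits each
    ("fc4", 64, 10, 32),    # last layer kept in Float32
]


def layer_bytes(in_features, out_features, weight_bits, bias_bits=32):
    """Storage for one linear layer; biases are assumed to stay in Float32."""
    weight_bytes = in_features * out_features * weight_bits / 8
    bias_bytes = out_features * bias_bits / 8
    return weight_bytes + bias_bytes


mixed_total = sum(layer_bytes(i, o, b) for _, i, o, b in layers)
fp32_total = sum(layer_bytes(i, o, 32) for _, i, o, _ in layers)

for name, i, o, b in layers:
    print(f"{name}: {layer_bytes(i, o, b) / 1024:8.2f} KB at {b}-bit weights")
print(f"Float32 total: {fp32_total / 1024:.2f} KB")
print(f"Mixed total:   {mixed_total / 1024:.2f} KB "
      f"({fp32_total / mixed_total:.2f}x smaller)")
```

In this particular configuration the Float32 first layer holds most of the weights, so the overall saving is modest. Which layers to keep at full precision is worth checking against where the parameters actually are.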

In [None]:
class MixedPrecisionNet(nn.Module):
    """
    Mixed precision network:
    - Layer 1: Float32 (high precision for input)
    - Layer 2-3: Ternary (compression)
    - Layer 4: Float32 (accuracy for output)
    """
    
    def __init__(self, input_dim=784, hidden_dims=(256, 128, 64), output_dim=10):
        super().__init__()
        
        # First layer: Full precision
        self.fc1 = nn.Linear(input_dim, hidden_dims[0])
        
        # Middle layers: Ternary
        from examples.mnist_ternary import LinearTernary
        self.fc2 = LinearTernary(hidden_dims[0], hidden_dims[1])
        self.fc3 = LinearTernary(hidden_dims[1], hidden_dims[2])
        
        # Last layer: Full precision
        self.fc4 = nn.Linear(hidden_dims[2], output_dim)
    
    def forward(self, x):
        x = x.view(x.size(0), -1)  # flatten per sample; works for any input_dim
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        return x
    
    def get_precision_info(self):
        """Get precision information for each layer"""
        info = []
        for name, module in self.named_modules():
            # Check ternary layers first, in case LinearTernary subclasses nn.Linear
            if 'Ternary' in module.__class__.__name__:
                bits = 2
                precision = 'Ternary'
            elif isinstance(module, nn.Linear):
                bits = 32
                precision = 'Float32'
            else:
                continue
            # Note: biases are counted at the same bit-width as the weights
            params = sum(p.numel() for p in module.parameters())
            
            size_kb = params * bits / 8 / 1024
            info.append({
                'Layer': name,
                'Precision': precision,
                'Bits': bits,
                'Parameters': params,
                'Size (KB)': size_kb
            })
        return info


# Create model and analyze
try:
    model_mixed = MixedPrecisionNet()
    info = model_mixed.get_precision_info()
    
    # Create visualization
    df = pd.DataFrame(info)
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 5))
    
    # Layer sizes
    colors = ['steelblue' if p == 'Float32' else 'green' for p in df['Precision']]
    ax1.barh(df['Layer'], df['Size (KB)'], color=colors, alpha=0.7, edgecolor='black')
    ax1.set_xlabel('Size (KB)', fontsize=12)
    ax1.set_title('Layer Size by Precision', fontsize=14, fontweight='bold')
    ax1.grid(True, alpha=0.3)
    
    # Precision distribution
    precision_sizes = df.groupby('Precision')['Size (KB)'].sum()
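    # Match the bar-chart colors; groupby sorts labels, so derive colors from the index
    pie_colors = ['steelblue' if p == 'Float32' else 'green' for p in precision_sizes.index]
    ax2.pie(precision_sizes, labels=precision_sizes.index, autopct='%1.1f%%',
           colors=pie_colors, startangle=90,
           textprops={'fontsize': 12, 'fontweight': 'bold'})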
    ax2.set_title('Storage Distribution', fontsize=14, fontweight='bold')
    
    plt.tight_layout()
    plt.show()
    
    print("\nMixed Precision Model Analysis:")
    print("="*60)
    print(df.to_string(index=False))
    print("="*60)
    print(f"Total Size: {df['Size (KB)'].sum():.2f} KB")
    print(f"Ternary Portion: {precision_sizes.get('Ternary', 0) / precision_sizes.sum() * 100:.1f}%")
    
except ImportError:
    print("Note: Run notebook 01_introduction.ipynb first to have LinearTernary available")

## 6. QAT vs PTQ {#qat-vs-ptq}

Two main approaches to quantization:

### Post-Training Quantization (PTQ)
- Quantize after training
- Fast and simple
- May have accuracy drop

### Quantization-Aware Training (QAT)
- Simulate quantization during training
- Better accuracy preservation
- Requires retraining
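
A minimal QAT sketch, assuming PyTorch is installed: the forward pass uses quantize-dequantized ("fake quantized") INT8 weights, while the optimizer updates the underlying float weights through a straight-through estimator. The synthetic regression task, the `fake_quant_int8` helper, and the hyperparameters are illustrative assumptions, not a recipe from this tutorial.

```python
import torch

torch.manual_seed(0)


def fake_quant_int8(w):
    """Quantize-dequantize w to INT8 levels; gradients pass straight through."""
    scale = w.detach().abs().max().clamp(min=1e-8) / 127.0
    w_q = torch.clamp(torch.round(w / scale), -127, 127) * scale
    # Forward value is w_q; backward treats the operation as identity (STE)
    return w + (w_q - w).detach()


# Synthetic linear-regression task (illustrative only)
X = torch.randn(256, 8)
true_w = torch.randn(8, 1)
y = X @ true_w

w = (0.1 * torch.randn(8, 1)).requires_grad_()
optimizer = torch.optim.SGD([w], lr=0.05)

losses = []
for step in range(200):
    optimizer.zero_grad()
    pred = X @ fake_quant_int8(w)  # forward pass sees quantized weights
    loss = torch.mean((pred - y) ** 2)
    loss.backward()                # gradients update the float weights
    optimizer.step()
    losses.append(loss.item())

print(f"initial loss: {losses[0]:.4f}, final loss: {losses[-1]:.6f}")
```

PTQ would instead train `w` without `fake_quant_int8` and apply the quantizer once at the end, so the weights never get a chance to adapt to the rounding.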

In [None]:
class PTQModel(nn.Module):
    """Standard model for Post-Training Quantization"""
    
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)
    
    def forward(self, x):
        x = x.view(-1, 784)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
    
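    def quantize_ptq(self):
        """Apply post-training ternary quantization to both layers"""
        with torch.no_grad():
            for layer in (self.fc1, self.fc2):
                w = layer.weight
                # A fixed threshold of 0.5 would zero out typical nn.Linear weights,
                # so derive it from the weights (0.7 * mean|w|, the Ternary Weight
                # Networks heuristic).
                threshold = 0.7 * w.abs().mean().item()
                w_t = TernaryQuantize.apply(w, threshold)
                # Rescale so the ternary weights keep a comparable magnitude
                mask = w_t != 0
                alpha = w.abs()[mask].mean() if mask.any() else w.new_tensor(0.0)
                layer.weight.copy_(alpha * w_t)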


class QATModel(nn.Module):
    """Model with Quantization-Aware Training"""
    
    def __init__(self):
        super().__init__()
        # Import ternary layer
        try:
            from examples.mnist_ternary import LinearTernary
            self.fc1 = LinearTernary(784, 128)
            self.fc2 = LinearTernary(128, 10)
        except ImportError:
            # Fallback
            self.fc1 = nn.Linear(784, 128)
            self.fc2 = nn.Linear(128, 10)
    
    def forward(self, x):
        x = x.view(-1, 784)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


# Simulate training comparison
def simulate_training_curves():
    """Simulate typical QAT vs PTQ training curves"""
    epochs = np.arange(1, 11)
    
    # Simulated accuracies (based on typical behavior)
    # Float baseline
    float_acc = 90 + 8 * (1 - np.exp(-epochs / 3))
    
    # QAT: gradual learning
    qat_acc = 85 + 10 * (1 - np.exp(-epochs / 3.5))
    
    # PTQ: one-time accuracy drop, then constant (no retraining, so no recovery)
    ptq_base = float_acc[-1]  # Start from trained model
    ptq_drop = 5  # Immediate drop from quantization
    ptq_acc = np.full_like(epochs, ptq_base - ptq_drop, dtype=float)
    
    return epochs, float_acc, qat_acc, ptq_acc


epochs, float_acc, qat_acc, ptq_acc = simulate_training_curves()

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 5))

# Training curves
ax1.plot(epochs, float_acc, 'o-', linewidth=2, label='Float32 Baseline', color='blue')
ax1.plot(epochs, qat_acc, 's-', linewidth=2, label='QAT (Ternary)', color='green')
ax1.axhline(ptq_acc[0], linestyle='--', linewidth=2, label='PTQ (Ternary)', color='red')
ax1.fill_between(epochs, ptq_acc[0] - 1, ptq_acc[0] + 1, alpha=0.2, color='red')
ax1.set_xlabel('Epoch', fontsize=12)
ax1.set_ylabel('Accuracy (%)', fontsize=12)
ax1.set_title('QAT vs PTQ Training Curves (simulated)', fontsize=14, fontweight='bold')
ax1.legend(fontsize=11)
ax1.grid(True, alpha=0.3)

# Comparison table
comparison_data = [
    ['Method', 'Accuracy', 'Training Time', 'Difficulty'],
    ['Float32', '98.0%', 'Baseline', 'Easy'],
    ['PTQ', '93.0%', 'None (instant)', 'Very Easy'],
    ['QAT', '95.0%', '+20% overhead', 'Moderate']
]

ax2.axis('tight')
ax2.axis('off')
table = ax2.table(cellText=comparison_data, cellLoc='left', loc='center',
                 colWidths=[0.2, 0.2, 0.3, 0.2])
table.auto_set_font_size(False)
table.set_fontsize(11)
table.scale(1, 3)

# Style header row
for i in range(4):
    table[(0, i)].set_facecolor('#4CAF50')
    table[(0, i)].set_text_props(weight='bold', color='white')

ax2.set_title('Method Comparison (illustrative)', fontsize=14, fontweight='bold', pad=20)

plt.tight_layout()
plt.show()

print("\n" + "="*60)
print("QAT vs PTQ Summary")
print("="*60)
print("\nPTQ (Post-Training Quantization):")
print("  ✓ Very fast (no retraining)")
print("  ✓ Simple to implement")
print("  ✗ Larger accuracy drop (3-5%)")
print("\nQAT (Quantization-Aware Training):")
print("  ✓ Better accuracy (1-2% drop)")
print("  ✓ Model adapts to quantization")
print("  ✗ Requires retraining")
print("  ✗ More complex implementation")
print("="*60)

## 7. Performance Benchmarks {#benchmarks}

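Let's benchmark different quantization methods. Note that `LinearINT8` uses simulated (fake) quantization: weights are quantized and dequantized in floating point on every forward pass, so its timing measures simulation overhead rather than a real integer kernel. The reported sizes are theoretical storage estimates, not measured file sizes.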

In [None]:
def benchmark_quantization(input_size=1000, output_size=1000, num_runs=100):
    """Benchmark different quantization methods"""
    
    # Create test data
    x = torch.randn(64, input_size).to(device)
    
    results = []
    
    # Float32 baseline
    layer_fp32 = nn.Linear(input_size, output_size).to(device)
    torch.cuda.synchronize() if device.type == 'cuda' else None
    start = time.time()
    for _ in range(num_runs):
        _ = layer_fp32(x)
    torch.cuda.synchronize() if device.type == 'cuda' else None
    fp32_time = (time.time() - start) / num_runs * 1000  # ms
    fp32_size = sum(p.numel() * 4 for p in layer_fp32.parameters()) / 1024  # KB
    
    results.append({
        'Method': 'Float32',
        'Time (ms)': fp32_time,
        'Size (KB)': fp32_size,
        'Speedup': 1.0,
        'Compression': 1.0
    })
    
    # INT8
    layer_int8 = LinearINT8(input_size, output_size).to(device)
    torch.cuda.synchronize() if device.type == 'cuda' else None
    start = time.time()
    for _ in range(num_runs):
        _ = layer_int8(x)
    torch.cuda.synchronize() if device.type == 'cuda' else None
    int8_time = (time.time() - start) / num_runs * 1000
    int8_size = sum(p.numel() * 1 for p in layer_int8.parameters()) / 1024  # 8 bits = 1 byte
    
    results.append({
        'Method': 'INT8',
        'Time (ms)': int8_time,
        'Size (KB)': int8_size,
        'Speedup': fp32_time / int8_time,
        'Compression': fp32_size / int8_size
    })
    
    # Ternary (simulated with 2-bit storage)
    try:
        from examples.mnist_ternary import LinearTernary
        layer_ternary = LinearTernary(input_size, output_size).to(device)
        torch.cuda.synchronize() if device.type == 'cuda' else None
        start = time.time()
        for _ in range(num_runs):
            _ = layer_ternary(x)
        torch.cuda.synchronize() if device.type == 'cuda' else None
        ternary_time = (time.time() - start) / num_runs * 1000
        ternary_size = sum(p.numel() * 0.25 for p in layer_ternary.parameters()) / 1024  # 2 bits
        
        results.append({
            'Method': 'Ternary',
            'Time (ms)': ternary_time,
            'Size (KB)': ternary_size,
            'Speedup': fp32_time / ternary_time,
            'Compression': fp32_size / ternary_size
        })
    except ImportError:
        pass
    
    return pd.DataFrame(results)


# Run benchmark
print("Running benchmarks...")
df_results = benchmark_quantization()

# Visualize results
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Inference time
axes[0].bar(df_results['Method'], df_results['Time (ms)'], 
           color=['blue', 'green', 'orange'][:len(df_results)],
           alpha=0.7, edgecolor='black', linewidth=2)
axes[0].set_ylabel('Time (ms)', fontsize=12)
axes[0].set_title('Inference Time', fontsize=14, fontweight='bold')
axes[0].grid(True, alpha=0.3)
for i, (idx, row) in enumerate(df_results.iterrows()):
    axes[0].text(i, row['Time (ms)'], f"{row['Time (ms)']:.3f}",
                ha='center', va='bottom', fontweight='bold')

# Model size
axes[1].bar(df_results['Method'], df_results['Size (KB)'],
           color=['blue', 'green', 'orange'][:len(df_results)],
           alpha=0.7, edgecolor='black', linewidth=2)
axes[1].set_ylabel('Size (KB)', fontsize=12)
axes[1].set_title('Model Size', fontsize=14, fontweight='bold')
axes[1].grid(True, alpha=0.3)
for i, (idx, row) in enumerate(df_results.iterrows()):
    axes[1].text(i, row['Size (KB)'], f"{row['Size (KB)']:.1f}",
                ha='center', va='bottom', fontweight='bold')

# Speedup and compression
x = np.arange(len(df_results))
width = 0.35
axes[2].bar(x - width/2, df_results['Speedup'], width, label='Speedup',
           color='green', alpha=0.7, edgecolor='black')
axes[2].bar(x + width/2, df_results['Compression'], width, label='Compression',
           color='orange', alpha=0.7, edgecolor='black')
axes[2].set_ylabel('Ratio (x)', fontsize=12)
axes[2].set_title('Speedup & Compression', fontsize=14, fontweight='bold')
axes[2].set_xticks(x)
axes[2].set_xticklabels(df_results['Method'])
axes[2].legend(fontsize=11)
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Print results
print("\n" + "="*70)
print("QUANTIZATION BENCHMARK RESULTS")
print("="*70)
print(df_results.to_string(index=False, float_format=lambda x: f'{x:.3f}'))
print("="*70)
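
The timing loops above start measuring on the first call, so one-off costs such as lazy initialization are included. A slightly more careful pattern, sketched below with a hypothetical `time_fn` helper, runs a few untimed warm-up iterations and uses `time.perf_counter`. The `sync` argument is an assumption of this sketch: pass `torch.cuda.synchronize` when timing GPU work.

```python
import time


def time_fn(fn, warmup=10, runs=100, sync=None):
    """Return the mean wall-clock time of fn() in milliseconds.

    `sync` is an optional callable (for example torch.cuda.synchronize)
    that blocks until asynchronous work has finished.
    """
    for _ in range(warmup):  # warm-up runs are not timed
        fn()
    if sync is not None:
        sync()
    start = time.perf_counter()
    for _ in range(runs):
        fn()
    if sync is not None:
        sync()
    return (time.perf_counter() - start) / runs * 1000


# Usage with a stand-in workload; replace the lambda with a model call
elapsed_ms = time_fn(lambda: sum(i * i for i in range(10_000)))
print(f"{elapsed_ms:.3f} ms per call")
```

For inference benchmarks it also helps to wrap the model call in `torch.no_grad()` so autograd bookkeeping is not timed.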

## 8. Practical Guidelines {#guidelines}

### When to Use Each Method:

#### Ternary Quantization
**Best for:**
- Extreme memory constraints
- Edge devices with limited storage
- Models where a 2-4% accuracy drop is acceptable
- Binary classification tasks

**Pros:** 16x compression, very fast inference  
**Cons:** 2-4% accuracy drop, requires QAT

#### INT8 Quantization
**Best for:**
- General purpose deployment
- Production systems
- Models requiring high accuracy
- Hardware with INT8 support

**Pros:** 4x compression, <1% accuracy drop, PTQ works well  
**Cons:** Less compression than ternary

#### Mixed Precision
**Best for:**
- Balancing accuracy and efficiency
- Deep networks
- When some layers are more sensitive

**Pros:** Flexible, good accuracy-efficiency trade-off  
**Cons:** More complex to implement and tune

### Quantization Workflow:

```
1. Train baseline Float32 model
   ↓
2. Try PTQ first (quick baseline)
   ↓
3. If accuracy drop > 2%:
   → Use QAT
   → Try mixed precision
   ↓
4. Fine-tune quantization parameters
   ↓
5. Benchmark on target hardware
```
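
The decision in step 3 can be written down as a small helper. This is only a sketch of the rule of thumb above; the function name and the default 2-point threshold are taken from the workflow for illustration, not from any library.

```python
def choose_strategy(float_acc, ptq_acc, max_drop=2.0):
    """Suggest the next step given Float32 and PTQ accuracy (in percent)."""
    drop = float_acc - ptq_acc
    if drop <= max_drop:
        return "keep PTQ"
    return "try QAT or mixed precision"


print(choose_strategy(98.0, 97.1))  # small drop: PTQ is good enough
print(choose_strategy(98.0, 93.0))  # large drop: move on to QAT
```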

### Tips for Better Quantization:

1. **Use Batch Normalization**: Helps stabilize quantized activations
2. **Avoid quantizing first/last layers**: They're most sensitive
3. **Per-channel quantization**: Better than per-tensor for CNNs
4. **Calibration dataset**: Use representative data for PTQ
5. **Gradual quantization**: Start with higher bits, gradually reduce
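
To illustrate tip 3, the sketch below (assuming PyTorch is installed) quantizes a weight matrix whose two output channels have very different magnitudes. With a single per-tensor scale the small channel is rounded away almost entirely, while a per-channel scale preserves it. The tensor shapes and the `fake_quant` helper are illustrative assumptions.

```python
import torch

torch.manual_seed(0)


def fake_quant(w, scale):
    """Symmetric INT8 quantize-dequantize with the given scale(s)."""
    return torch.clamp(torch.round(w / scale), -127, 127) * scale


# Two output channels with very different weight magnitudes
w = torch.cat([torch.randn(1, 64) * 0.01, torch.randn(1, 64) * 1.0], dim=0)

# Per-tensor: one scale shared by every weight
scale_tensor = w.abs().max() / 127.0
err_tensor = (w - fake_quant(w, scale_tensor)).pow(2).mean(dim=1)

# Per-channel: one scale per output row (shape [2, 1] broadcasts over columns)
scale_channel = w.abs().amax(dim=1, keepdim=True) / 127.0
err_channel = (w - fake_quant(w, scale_channel)).pow(2).mean(dim=1)

print("per-tensor  MSE per channel:", err_tensor.tolist())
print("per-channel MSE per channel:", err_channel.tolist())
```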

---

## Summary

In this notebook, you learned:

✅ **Fundamentals** of neural network quantization  
✅ **Ternary quantization** methods (deterministic, stochastic, learned)  
✅ **INT8 quantization** for better accuracy-compression trade-offs  
✅ **Mixed precision** strategies for optimal performance  
✅ **QAT vs PTQ** comparison and when to use each  
✅ **Performance benchmarks** across different methods  
✅ **Practical guidelines** for production deployment  

### Next Steps:

📘 [03_performance_analysis.ipynb](03_performance_analysis.ipynb) - Deep dive into performance optimization

---

*Triton DSL - Advanced Quantization for Neural Networks*