# Quantization Exercises

**From [rlbook.ai](https://rlbook.ai/chapters/quantization/summary-exercises)**

Practice implementing quantization from scratch. Each exercise has:
- A problem description
- Starter code with TODOs
- A solution (hidden by default)

Try to solve each exercise before looking at the solution!

## Setup

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# For Exercise 4
import torch
import torch.nn as nn

# Record library versions so results can be reproduced later.
print(f"NumPy version: {np.__version__}")
print(f"PyTorch version: {torch.__version__}")

---
## Exercise 1: Implement Symmetric Quantization

**Goal:** Implement functions that quantize float values to integers and recover them.

**Background:**
- Symmetric quantization maps the range `[-max(|x|), +max(|x|)]` to `[-127, +127]` for int8 (the code -128 is left unused so the range stays symmetric around 0)
- The scale factor is: `scale = max(|x|) / 127`
- To quantize: `q = round(x / scale)`
- To dequantize: `x_approx = q * scale`

In [None]:
# Exercise 1: Your code here

def symmetric_quantize(x: np.ndarray, bits: int = 8) -> tuple[np.ndarray, float]:
    """
    Quantize array x to signed integers with given bit width.

    Symmetric quantization uses a single scale and no zero-point, so 0.0
    always maps to the integer 0.

    Args:
        x: Input array of float values
        bits: Number of bits (default 8)

    Returns (once the TODOs are implemented):
        Tuple of (quantized values as integers, scale factor)
    """
    # TODO: Calculate qmax (maximum quantized value for signed int)
    # Hint: For 8 bits, signed range is -128 to 127, so qmax = 127.
    # In general qmax = 2**(bits - 1) - 1; the most negative code (-128) is
    # left unused so the integer range stays symmetric.
    qmax = None  # Your code here

    # TODO: Calculate scale factor
    # Hint: scale = max(|x|) / qmax
    # (Think about what should happen if x is all zeros.)
    scale = None  # Your code here

    # TODO: Quantize values
    # Hint: round(x / scale), then clip to [-qmax, qmax], then cast to int8
    # (int8 can only hold the codes when bits <= 8).
    q = None  # Your code here

    return q, scale


def symmetric_dequantize(q: np.ndarray, scale: float) -> np.ndarray:
    """Recover approximate float values from quantized integers.

    Args:
        q: Integer codes produced by ``symmetric_quantize``.
        scale: The scale factor returned alongside ``q``.

    Should return (once implemented):
        A float array of the same shape as ``q`` approximating the originals.
    """
    # TODO: Multiply by scale to recover approximate values
    # Hint: Cast q to float32 first
    return None  # Your code here


# Test your implementation
# NOTE: until the TODOs above are filled in, the stubs return None values, so
# the formatted prints below (e.g. `{scale:.6f}`) raise a TypeError.
# Compare your output against the expected values in the solution cell.
weights = np.array([0.247, -0.103, 0.089, -0.156, 0.312, -0.278])
q, scale = symmetric_quantize(weights, bits=8)
recovered = symmetric_dequantize(q, scale)

print(f"Original:  {weights}")
print(f"Quantized: {q}")
print(f"Scale:     {scale:.6f}")
print(f"Recovered: {recovered}")
print(f"Max error: {np.abs(weights - recovered).max():.6f}")

### Exercise 1: Solution

Click to expand the solution after you've tried it yourself.

In [None]:
# @title Exercise 1: Solution (double-click to show)

def symmetric_quantize_solution(x: np.ndarray, bits: int = 8) -> tuple[np.ndarray, float]:
    """Symmetric quantization to signed integers.

    Maps ``[-max|x|, +max|x|]`` onto ``[-qmax, +qmax]`` where
    ``qmax = 2**(bits - 1) - 1`` (127 for int8). The most negative code
    (e.g. -128) is left unused so the mapping stays symmetric around 0.

    Args:
        x: Array of float values (any shape).
        bits: Bit width of the signed integer codes (must be >= 2).

    Returns:
        Tuple ``(q, scale)`` with ``x ≈ q * scale``. ``q`` has the shape of
        ``x`` and is int8 for ``bits <= 8``, or the smallest wider signed
        integer type that holds ``[-qmax, qmax]`` otherwise.

    Raises:
        ValueError: If ``bits < 2`` (there would be no non-zero level).
    """
    if bits < 2:
        raise ValueError(f"bits must be >= 2, got {bits}")
    qmax = 2**(bits - 1) - 1  # 127 for int8
    max_abs = np.max(np.abs(x))
    # An all-zero input would give scale 0 and a 0/0 division below; any
    # positive scale represents zeros exactly, so fall back to 1.0.
    scale = max_abs / qmax if max_abs > 0 else 1.0
    # Smallest signed integer type that holds -qmax (int8 for bits <= 8).
    # A hard-coded int8 cast would silently wrap around for wider codes.
    dtype = np.min_scalar_type(-qmax)
    q = np.round(x / scale).clip(-qmax, qmax).astype(dtype)
    return q, scale


def symmetric_dequantize_solution(q: np.ndarray, scale: float) -> np.ndarray:
    """Map integer codes back to approximate floats via ``q * scale``.

    Args:
        q: Integer codes from a symmetric quantizer.
        scale: The scale factor that produced ``q``.

    Returns:
        Float array of the same shape as ``q``.
    """
    # Convert the codes to float32 first, then apply the scale.
    q_float = q.astype(np.float32)
    return q_float * scale


# Test
weights = np.array([0.247, -0.103, 0.089, -0.156, 0.312, -0.278])
q, scale = symmetric_quantize_solution(weights, bits=8)
recovered = symmetric_dequantize_solution(q, scale)

print(f"Original:  {weights}")
print(f"Quantized: {q}")  # Expected: [101, -42, 36, -63, 127, -113]  (0.247 / scale ≈ 100.5 rounds to 101)
print(f"Scale:     {scale:.6f}")  # Expected: ~0.002457  (= 0.312 / 127)
print(f"Recovered: {recovered}")
print(f"Max error: {np.abs(weights - recovered).max():.6f}")  # Expected: ~0.0012  (at most scale / 2)

---
## Exercise 2: Compare Per-Tensor vs Per-Channel

**Goal:** Implement per-channel quantization and compare the error to per-tensor.

**Background:**
- Per-tensor: One scale for the entire weight matrix
- Per-channel: One scale per output channel (row)
- Per-channel is better when channels have different weight magnitudes

In [None]:
# Exercise 2: Your code here

def per_tensor_quantize(weights: np.ndarray, bits: int = 8) -> tuple[np.ndarray, float]:
    """Quantize entire tensor with single scale (provided for reference).

    Args:
        weights: Array of float weights (any shape).
        bits: Bit width of the signed integer codes (must be >= 2).

    Returns:
        ``(q, scale)`` with ``weights ≈ q * scale``. ``q`` is int8 for
        ``bits <= 8`` and a wider signed integer type otherwise.

    Raises:
        ValueError: If ``bits < 2``.
    """
    if bits < 2:
        raise ValueError(f"bits must be >= 2, got {bits}")
    qmax = 2**(bits - 1) - 1
    max_abs = np.max(np.abs(weights))
    # All-zero weights would give scale 0 and a 0/0 division below; any
    # positive scale represents zeros exactly.
    scale = max_abs / qmax if max_abs > 0 else 1.0
    # Smallest signed integer type holding [-qmax, qmax]; a fixed int8 cast
    # would wrap around for bits > 8.
    dtype = np.min_scalar_type(-qmax)
    q = np.round(weights / scale).clip(-qmax, qmax).astype(dtype)
    return q, scale


def per_channel_quantize(weights: np.ndarray, bits: int = 8) -> tuple[np.ndarray, np.ndarray]:
    """
    Quantize each row (output channel) with its own scale.

    Args:
        weights: 2D array of shape (out_channels, in_channels)
        bits: Number of bits

    Returns (once the TODOs are implemented):
        Tuple of (quantized values, array of scales per channel), where the
        scales array has shape (out_channels,)
    """
    qmax = 2**(bits - 1) - 1

    # TODO: Calculate per-row scales (one scale per output channel)
    # Hint: Take np.abs(weights) first, then np.max with axis=1 to get the
    # largest magnitude per row (without abs, large negative weights would be
    # ignored); divide by qmax.
    scales = None  # Your code here

    # TODO: Quantize each row with its scale
    # Hint: Use broadcasting with scales[:, None] to divide each row by its scale,
    # then round, clip to [-qmax, qmax] and cast to int8 as in Exercise 1.
    q = None  # Your code here

    return q, scales


# Create weights with very different magnitudes per channel
np.random.seed(42)
weights = np.random.randn(4, 64)
weights[0] *= 0.01   # Tiny weights
weights[1] *= 0.1    # Small weights
weights[2] *= 1.0    # Normal weights
weights[3] *= 10.0   # Large weights

print("Weight magnitudes per channel:")
for i in range(4):
    print(f"  Channel {i}: max = {np.abs(weights[i]).max():.4f}")

# Quantize both ways
q_tensor, scale_tensor = per_tensor_quantize(weights)
q_channel, scales_channel = per_channel_quantize(weights)

# TODO: Calculate reconstruction errors
# Hint: For per-tensor, multiply all by scale_tensor
# Hint: For per-channel, multiply each row by its scale (use scales_channel[:, None])
# Hint: Cast the integer codes to float32 first, as in Exercise 1.
# NOTE: the MSE lines below raise a TypeError while these are still None.
recovered_tensor = None  # Your code here
recovered_channel = None  # Your code here

mse_tensor = np.mean((weights - recovered_tensor)**2)
mse_channel = np.mean((weights - recovered_channel)**2)

print(f"\nPer-tensor MSE:  {mse_tensor:.8f}")
print(f"Per-channel MSE: {mse_channel:.8f}")
print(f"Improvement:     {mse_tensor / mse_channel:.1f}x better")

### Exercise 2: Solution

In [None]:
# @title Exercise 2: Solution (double-click to show)

def per_tensor_quantize_solution(weights: np.ndarray, bits: int = 8) -> tuple[np.ndarray, float]:
    """Symmetric per-tensor quantization: one scale for the whole array.

    Args:
        weights: Array of float weights (any shape).
        bits: Bit width of the signed integer codes (must be >= 2).

    Returns:
        ``(q, scale)`` with ``weights ≈ q * scale``. ``q`` is int8 for
        ``bits <= 8`` and a wider signed integer type otherwise.

    Raises:
        ValueError: If ``bits < 2``.
    """
    if bits < 2:
        raise ValueError(f"bits must be >= 2, got {bits}")
    qmax = 2**(bits - 1) - 1
    max_abs = np.max(np.abs(weights))
    # Guard the all-zero case: a zero scale would make the division 0/0.
    scale = max_abs / qmax if max_abs > 0 else 1.0
    # Smallest signed integer type holding [-qmax, qmax] (int8 for bits <= 8).
    dtype = np.min_scalar_type(-qmax)
    q = np.round(weights / scale).clip(-qmax, qmax).astype(dtype)
    return q, scale


def per_channel_quantize_solution(weights: np.ndarray, bits: int = 8) -> tuple[np.ndarray, np.ndarray]:
    """Symmetric per-channel quantization: one scale per output channel.

    Axis 0 is treated as the output-channel axis. All remaining axes are
    reduced when computing a channel's scale, so both Linear weights
    ``(out, in)`` and conv-style weights ``(out, in, kh, kw)`` work.

    Args:
        weights: Float array whose first axis indexes output channels.
        bits: Bit width of the signed integer codes (must be >= 2).

    Returns:
        ``(q, scales)`` where ``q`` has the shape of ``weights`` (int8 for
        ``bits <= 8``, wider otherwise) and ``scales`` has shape
        ``(weights.shape[0],)``.

    Raises:
        ValueError: If ``bits < 2``.
    """
    if bits < 2:
        raise ValueError(f"bits must be >= 2, got {bits}")
    qmax = 2**(bits - 1) - 1
    # Largest magnitude in each output channel (each row for 2-D weights).
    max_abs = np.max(np.abs(weights), axis=tuple(range(1, weights.ndim)))
    # An all-zero channel would get scale 0 and a 0/0 division below; any
    # positive scale represents it exactly.
    scales = np.where(max_abs > 0, max_abs / qmax, 1.0)
    # Shape (out, 1, 1, ...) so each channel is divided by its own scale.
    channel_scales = scales.reshape((-1,) + (1,) * (weights.ndim - 1))
    # Smallest signed integer type holding [-qmax, qmax] (int8 for bits <= 8).
    dtype = np.min_scalar_type(-qmax)
    q = np.round(weights / channel_scales).clip(-qmax, qmax).astype(dtype)
    return q, scales


# Create weights with very different magnitudes per channel
np.random.seed(42)
weights = np.random.randn(4, 64)
weights[0] *= 0.01
weights[1] *= 0.1
weights[2] *= 1.0
weights[3] *= 10.0

# Quantize both ways
q_tensor, scale_tensor = per_tensor_quantize_solution(weights)
q_channel, scales_channel = per_channel_quantize_solution(weights)

# Reconstruct
recovered_tensor = q_tensor.astype(np.float32) * scale_tensor
recovered_channel = q_channel.astype(np.float32) * scales_channel[:, None]

# Compare errors
mse_tensor = np.mean((weights - recovered_tensor)**2)
mse_channel = np.mean((weights - recovered_channel)**2)

print(f"Per-tensor MSE:  {mse_tensor:.8f}")
print(f"Per-channel MSE: {mse_channel:.8f}")
print(f"Improvement:     {mse_tensor / mse_channel:.1f}x better")

# The aggregate MSE is dominated by channel 3 (scaled 10x): it holds the
# largest magnitude, so it sets the per-tensor scale and its error is the same
# under both schemes. The overall ratio above is therefore only a few-fold, not
# orders of magnitude. The real win shows up channel by channel: the small
# channels are reconstructed far more accurately with their own scale.
mse_tensor_rows = np.mean((weights - recovered_tensor)**2, axis=1)
mse_channel_rows = np.mean((weights - recovered_channel)**2, axis=1)
print("\nPer-channel breakdown (per-tensor MSE vs per-channel MSE):")
for i, (mse_t, mse_c) in enumerate(zip(mse_tensor_rows, mse_channel_rows)):
    print(f"  Channel {i}: {mse_t:.2e} vs {mse_c:.2e}")

---
## Exercise 3: Visualize Quantization Error vs Bit Width

**Goal:** Create a plot showing how quantization error decreases as you use more bits.

**Background:**
- More bits = more discrete levels = less error
- Each additional bit roughly doubles the number of levels
- Mean squared error decreases exponentially with bit width: each extra bit halves the step size, cutting the MSE by roughly 4x

In [None]:
# Exercise 3: Your code here

def quantize_and_measure_error(x: np.ndarray, bits: int) -> float:
    """Quantize x and return the mean squared error.

    Uses symmetric per-tensor quantization with ``qmax = 2**(bits - 1) - 1``
    and compares ``x`` with its reconstruction ``q * scale``.

    Raises:
        ValueError: If ``bits < 2`` (no non-zero quantization level).
    """
    if bits < 2:
        raise ValueError(f"bits must be >= 2, got {bits}")
    qmax = 2**(bits - 1) - 1
    max_abs = np.max(np.abs(x))
    # All-zero input is reproduced exactly; avoid the 0/0 a zero scale causes.
    scale = max_abs / qmax if max_abs > 0 else 1.0
    # Codes stay as floats here: only the reconstruction error matters.
    q = np.round(x / scale).clip(-qmax, qmax)
    recovered = q * scale
    return np.mean((x - recovered)**2)


# Generate random weights
np.random.seed(42)
weights = np.random.randn(1000)

# TODO: Test different bit widths (2, 3, 4, 5, 6, 7, 8)
# Hint: a list comprehension over bit_widths calling
# quantize_and_measure_error(weights, b) builds the list in one line.
bit_widths = [2, 3, 4, 5, 6, 7, 8]
errors = None  # Your code here: list of errors for each bit width

# TODO: Create the plot
# Hint: Use plt.semilogy() for log y-axis
# (on a log axis a constant reduction factor per bit looks like a straight line)
# Hint: Add labels, title, grid

# Your plotting code here

### Exercise 3: Solution

In [None]:
# @title Exercise 3: Solution (double-click to show)

def quantize_and_measure_error_solution(x: np.ndarray, bits: int) -> float:
    """Return the MSE of symmetric per-tensor quantization of ``x``.

    Args:
        x: Array of float values.
        bits: Bit width; ``qmax = 2**(bits - 1) - 1`` levels on each side of 0.

    Raises:
        ValueError: If ``bits < 2`` (no non-zero quantization level).
    """
    if bits < 2:
        raise ValueError(f"bits must be >= 2, got {bits}")
    qmax = 2**(bits - 1) - 1
    max_abs = np.max(np.abs(x))
    # Guard the all-zero case: a zero scale would make x / scale a 0/0 NaN.
    scale = max_abs / qmax if max_abs > 0 else 1.0
    q = np.round(x / scale).clip(-qmax, qmax)
    recovered = q * scale
    return np.mean((x - recovered)**2)


np.random.seed(42)
weights = np.random.randn(1000)

bit_widths = [2, 3, 4, 5, 6, 7, 8]
errors = [quantize_and_measure_error_solution(weights, b) for b in bit_widths]

plt.figure(figsize=(10, 6))
plt.semilogy(bit_widths, errors, 'o-', markersize=10, linewidth=2, color='#06b6d4')
plt.xlabel('Bits', fontsize=12)
plt.ylabel('Mean Squared Error (log scale)', fontsize=12)
plt.title('Quantization Error vs Bit Width', fontsize=14)
plt.grid(True, alpha=0.3)
plt.xticks(bit_widths)

# Add annotations
for b, e in zip(bit_widths, errors):
    plt.annotate(f'{e:.2e}', (b, e), textcoords="offset points",
                 xytext=(0, 10), ha='center', fontsize=9)

plt.tight_layout()
plt.show()

# Print the reduction factor between consecutive bit widths. Each extra bit
# halves the step size, so the MSE is expected to drop by roughly 4x per bit
# once there are enough levels.
print("\nError reduction per additional bit:")
for b_prev, b_next, e_prev, e_next in zip(bit_widths, bit_widths[1:], errors, errors[1:]):
    print(f"  {b_prev} → {b_next} bits: {e_prev/e_next:.1f}x reduction")

---
## Exercise 4: Quantize a PyTorch Model

**Goal:** Use PyTorch's built-in quantization and measure the benefits.

**Background:**
- `torch.quantization.quantize_dynamic` does post-training quantization
- It quantizes Linear layers to int8
- Benefits: smaller model size and often faster CPU inference (the speedup depends on the model and batch size, so measure it)

In [None]:
# Exercise 4: Your code here
import time
from io import BytesIO

# Simple model
class SimpleNet(nn.Module):
    """Two-layer MLP: Linear(784 -> 256), ReLU, Linear(256 -> 10)."""

    def __init__(self):
        super().__init__()
        # Attribute names determine the state_dict keys (fc1.*, fc2.*), and
        # creation order determines parameter initialization order.
        self.fc1 = nn.Linear(784, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        logits = self.fc2(hidden)
        return logits


# Create model (put in eval mode, as the solution does before quantizing)
model = SimpleNet()
model.eval()

# TODO: Use torch.quantization.quantize_dynamic to quantize the model
# Hint: quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)
# Hint: assign the returned model; keep `model` around as the float baseline.
quantized_model = None  # Your code here


# TODO: Compare model sizes
# Hint: Save to BytesIO buffer, check buffer.tell() for size
def get_model_size(m):
    """Return the number of bytes ``torch.save`` writes for ``m``'s state_dict."""
    with BytesIO() as buf:
        torch.save(m.state_dict(), buf)
        # After saving, the stream position equals the bytes written.
        n_bytes = buf.tell()
    return n_bytes

original_size = get_model_size(model)
quantized_size = None  # Your code here

# NOTE: these prints raise a TypeError until quantized_size is filled in.
print(f"Original size:  {original_size / 1024:.1f} KB")
print(f"Quantized size: {quantized_size / 1024:.1f} KB")
print(f"Compression:    {original_size / quantized_size:.1f}x")


# TODO: Compare inference speed
# Hint: Time 1000 forward passes with random input
# Hint: time.perf_counter() is meant for timing intervals (time.time() is not),
# and running the loops inside `with torch.no_grad():` measures inference only.
x = torch.randn(1, 784)
n_runs = 1000

# Warmup (also needs quantized_model from the TODO above)
for _ in range(100):
    _ = model(x)
    _ = quantized_model(x)

# Your timing code here

### Exercise 4: Solution

In [None]:
# @title Exercise 4: Solution (double-click to show)
import time
from io import BytesIO


class SimpleNet(nn.Module):
    """Small MLP used for the quantization demo (784 -> 256 -> 10)."""

    def __init__(self):
        super().__init__()
        # Keep these names/order: they define state_dict keys and init order.
        self.fc1 = nn.Linear(784, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))


# Float baseline in inference mode (.eval() returns the module itself)
model = SimpleNet().eval()

# Dynamic quantization: Linear layers are converted to int8 (qint8) weights
quantized_model = torch.quantization.quantize_dynamic(
    model,
    qconfig_spec={nn.Linear},
    dtype=torch.qint8,
)

# Compare sizes
def get_model_size(m):
    """Size in bytes of ``m.state_dict()`` when serialized with ``torch.save``."""
    with BytesIO() as stream:
        torch.save(m.state_dict(), stream)
        size_in_bytes = stream.tell()  # write position == bytes written
    return size_in_bytes

original_size = get_model_size(model)
quantized_size = get_model_size(quantized_model)

print(f"Original size:  {original_size / 1024:.1f} KB")
print(f"Quantized size: {quantized_size / 1024:.1f} KB")
print(f"Compression:    {original_size / quantized_size:.1f}x")

# Compare speed
x = torch.randn(1, 784)
n_runs = 1000

# Benchmark inference only: inside no_grad, neither model records autograd
# history, so no training-time bookkeeping is included in the timings.
with torch.no_grad():
    # Warmup
    for _ in range(100):
        _ = model(x)
        _ = quantized_model(x)

    # Time original. perf_counter is a monotonic, high-resolution clock meant
    # for measuring intervals; time.time() can jump with clock adjustments.
    start = time.perf_counter()
    for _ in range(n_runs):
        _ = model(x)
    original_time = time.perf_counter() - start

    # Time quantized
    start = time.perf_counter()
    for _ in range(n_runs):
        _ = quantized_model(x)
    quantized_time = time.perf_counter() - start

print(f"\nOriginal time:  {original_time*1000:.1f} ms for {n_runs} runs")
print(f"Quantized time: {quantized_time*1000:.1f} ms for {n_runs} runs")
print(f"Speedup:        {original_time/quantized_time:.2f}x")

---
## Summary

In these exercises, you learned:

1. **Symmetric quantization**: Map floats to integers using a scale factor
2. **Per-channel vs per-tensor**: Per-channel is much better when weights vary across channels
3. **Error vs bits**: Each additional bit roughly halves the quantization step (and the RMS error), cutting the mean squared error by about 4x
4. **PyTorch quantization**: `quantize_dynamic` gives ~3-4x size reduction with minimal code

**Next steps:**
- Try quantizing a larger model (ResNet, BERT)
- Experiment with static quantization (requires calibration data)
- Explore quantization-aware training for better accuracy