# Part 4, Lab 2: Quantization Fundamentals

**Time:** ~45 minutes

Quantization maps floating-point values to lower-precision representations. This lab covers the core math: scale factors, zero points, and error bounds.

## Learning Objectives

1. Understand symmetric vs asymmetric quantization
2. Implement scale factor computation
3. Analyze quantization error
4. Understand per-tensor vs per-channel vs per-group scaling

In [None]:
import numpy as np
import torch

np.random.seed(42)
torch.manual_seed(42)

---
## 1. Symmetric Quantization

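The simplest form: map the range [-max|x|, +max|x|] onto the integers [-127, +127] for INT8. The code -128 is left unused so the grid stays symmetric around zero.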

```
scale = max(|x|) / 127
quantized = round(x / scale)
dequantized = quantized * scale
```
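
As a quick sanity check before the full implementation, the three formulas above can be traced by hand on a tiny tensor. This is a minimal sketch; the five values in `x` are made up for illustration and are not part of the lab data.

```python
import numpy as np

# Hypothetical toy tensor, chosen so the arithmetic is easy to follow
x = np.array([-1.27, -0.504, 0.0, 0.333, 1.0], dtype=np.float32)

scale = float(np.abs(x).max()) / 127.0      # about 1.27 / 127 = 0.01
q = np.round(x / scale).astype(np.int8)     # integer codes in [-127, 127]
x_hat = q.astype(np.float32) * scale        # dequantized approximation

print("scale:       ", scale)
print("quantized:   ", q)
print("dequantized: ", x_hat)
print("max abs error:", float(np.abs(x - x_hat).max()))
```

Each value lands on the nearest multiple of `scale`, so no element should move by more than half a step.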

In [None]:
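def symmetric_quantize_int8(x):
    """Symmetric INT8 quantization. Returns (int8 array, float scale)."""
    # Scale maps the largest magnitude to 127; guard against an all-zero tensor
    abs_max = np.abs(x).max()
    scale = max(float(abs_max), 1e-8) / 127.0

    # Quantize: round to the nearest integer, then clip to the symmetric range
    x_quant = np.clip(np.round(x / scale), -127, 127).astype(np.int8)

    return x_quant, scale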

def symmetric_dequantize(x_quant, scale):
    """Dequantize back to float."""
    return x_quant.astype(np.float32) * scale

# Test on simulated weights
weights = np.random.randn(1024, 1024).astype(np.float32)

# Quantize and dequantize
w_quant, scale = symmetric_quantize_int8(weights)
w_dequant = symmetric_dequantize(w_quant, scale)

# Measure error
mse = np.mean((weights - w_dequant) ** 2)
rel_error = np.abs(weights - w_dequant) / (np.abs(weights) + 1e-8)

print(f"Scale factor: {scale:.6f}")
print(f"MSE: {mse:.8f}")
print(f"Mean relative error: {rel_error.mean() * 100:.2f}%")
print(f"Max relative error: {rel_error.max() * 100:.2f}%")
print(f"Memory: {weights.nbytes / 1024:.1f} KB → {w_quant.nbytes / 1024:.1f} KB ({weights.nbytes / w_quant.nbytes:.0f}x reduction)")
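
The relative error above can look alarming because weights close to zero have tiny denominators. The absolute error is easier to reason about: round-to-nearest moves each value by at most `scale / 2`, and if the rounding error is treated as roughly uniform the MSE should be close to `scale**2 / 12`. The sketch below checks both on fresh random data; the uniform-error model is an approximation, not an exact result.

```python
import numpy as np

rng = np.random.default_rng(0)
w = rng.standard_normal((512, 512)).astype(np.float32)

# Symmetric INT8 round trip, written inline so the block is self-contained
scale = float(np.abs(w).max()) / 127.0
w_hat = np.round(w / scale).astype(np.int8).astype(np.float32) * scale

err = w - w_hat
max_abs_err = float(np.abs(err).max())
mse = float(np.mean(err ** 2))
predicted_mse = scale ** 2 / 12.0   # variance of a uniform error of width `scale`

print(f"max |error|:   {max_abs_err:.6f}  (bound scale/2 = {scale / 2:.6f})")
print(f"measured MSE:  {mse:.8f}")
print(f"scale^2 / 12:  {predicted_mse:.8f}")
```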

---
## 2. Asymmetric Quantization

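For values that aren't centered at zero (like ReLU activations), symmetric quantization wastes part of the integer range on values that never occur. Asymmetric quantization shifts the grid with a zero point so that the full range [0, 255] covers [min, max].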

```
scale = (max - min) / 255
zero_point = round(-min / scale)
quantized = clip(round(x / scale) + zero_point, 0, 255)
dequantized = (quantized - zero_point) * scale
```
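
To make the role of the zero point concrete, here is a small worked example following the formulas above. The input values are hypothetical and were picked so that the scale comes out near 0.01.

```python
import numpy as np

# Hypothetical toy tensor with a range that is not centered at zero
x = np.array([-0.5, 0.0, 0.25, 1.0, 2.05], dtype=np.float32)

x_min, x_max = float(x.min()), float(x.max())
scale = (x_max - x_min) / 255.0             # about 2.55 / 255 = 0.01
zero_point = int(round(-x_min / scale))     # the integer code that represents 0.0

q = np.clip(np.round(x / scale) + zero_point, 0, 255).astype(np.uint8)
x_hat = (q.astype(np.float32) - zero_point) * scale

print("zero_point: ", zero_point)
print("quantized:  ", q)
print("dequantized:", x_hat)
```

The minimum maps to code 0, the maximum to code 255, and the real value 0.0 maps exactly onto `zero_point`, so it survives the round trip without error.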

In [None]:
def asymmetric_quantize_uint8(x):
    """Asymmetric UINT8 quantization with zero point."""
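    # Extend the range to include 0 so that 0.0 maps exactly onto zero_point
    x_min, x_max = min(float(x.min()), 0.0), max(float(x.max()), 0.0)

    # Compute scale and zero point (guard against a constant tensor)
    scale = max(x_max - x_min, 1e-8) / 255.0
    zero_point = int(round(-x_min / scale))
    zero_point = max(0, min(255, zero_point))  # Clamp to uint8 range

    # Quantize, clipping so rounding cannot wrap around in uint8
    x_quant = np.clip(np.round(x / scale + zero_point), 0, 255).astype(np.uint8)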
    
    return x_quant, scale, zero_point

def asymmetric_dequantize(x_quant, scale, zero_point):
    """Dequantize back to float."""
    return (x_quant.astype(np.float32) - zero_point) * scale

# Test on ReLU-like activations (non-negative)
activations = np.maximum(0, np.random.randn(1024, 1024)).astype(np.float32)

# Quantize both ways
a_sym, scale_sym = symmetric_quantize_int8(activations)
a_asym, scale_asym, zp = asymmetric_quantize_uint8(activations)

# Compare errors
a_dequant_sym = symmetric_dequantize(a_sym, scale_sym)
a_dequant_asym = asymmetric_dequantize(a_asym, scale_asym, zp)

mse_sym = np.mean((activations - a_dequant_sym) ** 2)
mse_asym = np.mean((activations - a_dequant_asym) ** 2)

print(f"Symmetric MSE:   {mse_sym:.8f}")
print(f"Asymmetric MSE:  {mse_asym:.8f}")
print(f"Asymmetric is {mse_sym / mse_asym:.1f}x better for non-negative data")
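
One way to see where the gap comes from is to compare step sizes. On non-negative data, symmetric INT8 only ever emits codes 0 to 127, while asymmetric UINT8 spreads all 256 codes over the same range. The step is about twice as large, so under a uniform-error approximation the MSE should be roughly (255/127)² ≈ 4 times larger. The sketch below checks this on synthetic ReLU-like data and is only an approximation of what the cell above measures.

```python
import numpy as np

rng = np.random.default_rng(0)
a = np.maximum(0, rng.standard_normal(100_000)).astype(np.float32)

# Step sizes implied by each scheme (min is 0 here, so zero_point is 0)
step_sym = float(a.max()) / 127.0
step_asym = float(a.max() - a.min()) / 255.0

q_sym = np.clip(np.round(a / step_sym), -127, 127).astype(np.int8)
q_asym = np.clip(np.round(a / step_asym), 0, 255).astype(np.uint8)

mse_sym = float(np.mean((a - q_sym.astype(np.float32) * step_sym) ** 2))
mse_asym = float(np.mean((a - q_asym.astype(np.float32) * step_asym) ** 2))

print(f"distinct codes used: symmetric {np.unique(q_sym).size}, asymmetric {np.unique(q_asym).size}")
print(f"step ratio:          {step_sym / step_asym:.3f}")
print(f"predicted MSE ratio: {(step_sym / step_asym) ** 2:.2f}")
print(f"measured MSE ratio:  {mse_sym / mse_asym:.2f}")
```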

---
## 3. Per-Channel vs Per-Tensor Scaling

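Per-tensor uses one scale for the entire tensor (simple, fast, but a single large channel sets the step size for everything).
Per-channel uses one scale per output channel (better accuracy when channel magnitudes differ, at the cost of storing one scale per channel).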
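
The difference is easiest to see on an exaggerated case. In the hypothetical 2×4 matrix below, the second row is 1000 times larger than the first. With a single per-tensor scale the small row collapses to zero; with one scale per row both rows use the full INT8 range.

```python
import numpy as np

# Hypothetical weights: row 1 is 1000x larger than row 0
w = np.array([[0.01, -0.025, 0.03, -0.04],
              [10.0, -25.0, 30.0, -40.0]], dtype=np.float32)

# Per-tensor: one scale, set by the largest value in the whole matrix
scale_pt = float(np.abs(w).max()) / 127.0
q_pt = np.round(w / scale_pt).astype(np.int8)

# Per-channel: one scale per row (reduce over axis=1)
scales_pc = np.abs(w).max(axis=1, keepdims=True) / 127.0
q_pc = np.round(w / scales_pc).astype(np.int8)

print("per-tensor codes:\n", q_pt)
print("per-channel codes:\n", q_pc)
```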

In [None]:
def per_channel_quantize_int8(x, axis=0):
    """Per-channel symmetric INT8 quantization.

    `axis` is the axis reduced over when computing each scale, so axis=1
    gives one scale per row of a 2-D weight matrix.
    """
    # Compute scale per channel
    abs_max = np.abs(x).max(axis=axis, keepdims=True)
    scales = abs_max / 127.0
    scales = np.maximum(scales, 1e-8)  # Avoid division by zero

    # Quantize
    x_quant = np.round(x / scales).astype(np.int8)

    return x_quant, scales.squeeze(axis=axis)

# Simulate weights with varying magnitudes per channel
weights = np.random.randn(256, 1024).astype(np.float32)
weights *= np.random.uniform(0.1, 10, size=(256, 1))  # Different scale per row

# Compare per-tensor vs per-channel
w_pt, scale_pt = symmetric_quantize_int8(weights)
w_pc, scales_pc = per_channel_quantize_int8(weights, axis=1)

# Dequantize
w_dequant_pt = symmetric_dequantize(w_pt, scale_pt)
w_dequant_pc = w_pc.astype(np.float32) * scales_pc[:, np.newaxis]

mse_pt = np.mean((weights - w_dequant_pt) ** 2)
mse_pc = np.mean((weights - w_dequant_pc) ** 2)

print(f"Per-tensor MSE:  {mse_pt:.8f}")
print(f"Per-channel MSE: {mse_pc:.8f}")
print(f"Per-channel is {mse_pt / mse_pc:.1f}x better for varying scales")

---
## 4. Group Quantization (Block Scaling)

Modern low-bit schemes (GPTQ, AWQ, NVFP4) go one step further and share a scale across a small group of consecutive elements, typically 16 to 128, so an outlier only degrades its own group.

In [None]:
def group_quantize_int4(x, group_size=128):
    """Group-wise INT4 quantization."""
    original_shape = x.shape
    x_flat = x.reshape(-1)
    
    # Pad to multiple of group_size
    pad_size = (group_size - len(x_flat) % group_size) % group_size
    x_padded = np.pad(x_flat, (0, pad_size), mode='constant')
    
    # Reshape into groups
    x_groups = x_padded.reshape(-1, group_size)
    
    # Compute scale per group. INT4 spans -8..7; the symmetric scale maps abs_max to 7
    abs_max = np.abs(x_groups).max(axis=1, keepdims=True)
    scales = abs_max / 7.0
    scales = np.maximum(scales, 1e-8)

    # Quantize to the INT4 range (clip before casting; codes are stored unpacked in int8)
    x_quant = np.clip(np.round(x_groups / scales), -8, 7).astype(np.int8)

    return x_quant, scales.squeeze(axis=1), original_shape, pad_size

def group_dequantize_int4(x_quant, scales, original_shape, pad_size, group_size=128):
    """Dequantize group-wise INT4."""
    x_dequant = x_quant.astype(np.float32) * scales[:, np.newaxis]
    x_flat = x_dequant.reshape(-1)
    if pad_size > 0:
        x_flat = x_flat[:-pad_size]
    return x_flat.reshape(original_shape)

# Test group quantization
weights = np.random.randn(1024, 1024).astype(np.float32)

# Different group sizes
for group_size in [32, 64, 128]:
    w_quant, scales, shape, pad = group_quantize_int4(weights, group_size)
    w_dequant = group_dequantize_int4(w_quant, scales, shape, pad, group_size)
    mse = np.mean((weights - w_dequant) ** 2)
    num_scales = len(scales)
    # FP16 scales (2 bytes each) relative to packed INT4 weights (0.5 bytes each)
    overhead = (num_scales * 2) / (weights.size * 0.5) * 100
    print(f"Group size {group_size:3d}: MSE = {mse:.6f}, scale overhead = {overhead:.1f}%")
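
The storage cost of group scales can also be expressed as effective bits per weight. The sketch below assumes 4-bit weights packed two per byte and one FP16 (16-bit) scale per group, with overhead measured relative to the 4-bit weight payload. Both helper functions are hypothetical names introduced for illustration.

```python
def effective_bits_per_weight(group_size, weight_bits=4, scale_bits=16):
    """Average bits stored per weight, including its share of the group scale."""
    return weight_bits + scale_bits / group_size

def scale_overhead_pct(group_size, weight_bits=4, scale_bits=16):
    """Scale storage as a percentage of the weight payload."""
    return (scale_bits / group_size) / weight_bits * 100

for g in [16, 32, 64, 128]:
    print(f"group size {g:3d}: {effective_bits_per_weight(g):.3f} bits/weight, "
          f"{scale_overhead_pct(g):.1f}% overhead")
```

Halving the group size doubles the overhead, which is the trade-off against the lower MSE seen in the loop above.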

---
## Exercises

1. **Calibration**: Implement calibration that finds optimal scale factors using a calibration dataset
2. **Mixed Precision**: Implement a scheme where sensitive layers use INT8 and others use INT4
3. **Outlier Handling**: Implement SmoothQuant-style outlier migration
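
As one possible starting point for Exercise 1, the sketch below grid-searches a clipping threshold that minimizes MSE on calibration data instead of always using the abs-max. The function name `calibrate_scale_mse`, the 30% to 100% search range, and the number of candidates are all assumptions made for illustration; real calibrators often use percentiles or KL divergence instead.

```python
import numpy as np

def calibrate_scale_mse(calib, qmax=127, num_candidates=50):
    """Hypothetical helper: pick the symmetric scale with the lowest round-trip MSE."""
    abs_max = max(float(np.abs(calib).max()), 1e-8)
    best_scale, best_mse = None, np.inf
    # Try clipping thresholds from 30% to 100% of the observed abs-max
    for frac in np.linspace(0.3, 1.0, num_candidates):
        scale = frac * abs_max / qmax
        q = np.clip(np.round(calib / scale), -qmax, qmax)
        mse = float(np.mean((calib - q * scale) ** 2))
        if mse < best_mse:
            best_scale, best_mse = scale, mse
    return best_scale, best_mse

rng = np.random.default_rng(0)
calib = rng.standard_normal(100_000).astype(np.float32)  # stand-in calibration data

qmax = 7  # INT4-style symmetric range, where clipping matters most
scale_absmax = float(np.abs(calib).max()) / qmax
q_absmax = np.clip(np.round(calib / scale_absmax), -qmax, qmax)
mse_absmax = float(np.mean((calib - q_absmax * scale_absmax) ** 2))

scale_cal, mse_cal = calibrate_scale_mse(calib, qmax=qmax)

print(f"abs-max:    scale = {scale_absmax:.4f}, MSE = {mse_absmax:.6f}")
print(f"calibrated: scale = {scale_cal:.4f}, MSE = {mse_cal:.6f}")
```

Clipping a few tail values buys a finer step for everything else. How much this helps depends on the bit width and on how heavy the tails of the data are.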

## Key Takeaways

- Symmetric quantization is simpler; asymmetric uses the full integer range for skewed data such as ReLU outputs
- Per-channel and per-group scaling reduce error when magnitudes vary across channels or blocks
- Smaller groups = better accuracy but more scale factor overhead
- INT4 with group_size=128 is a common balance between the two (a typical setting for GPTQ and AWQ)