# INT8 Quantization from Scratch

## Problem Statement

Large language models require massive memory and compute. A 70B parameter model in FP16 needs ~140GB just for weights! **Quantization** reduces this by using lower-precision data types.

Your task is to implement **INT8 quantization from scratch** and understand the trade-offs between memory savings and accuracy.

---

## Background

### Why Quantization?

| Precision | Bits | Memory per 7B params | Typical Use |
|-----------|------|---------------------|-------------|
| FP32 | 32 | 28 GB | Training |
| FP16/BF16 | 16 | 14 GB | Training/Inference |
| INT8 | 8 | 7 GB | Inference |
| INT4 | 4 | 3.5 GB | Inference |
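
The table's numbers come from bytes = parameters × bits / 8, with 1 GB taken as 10^9 bytes. Activations, KV cache, and quantization metadata such as scales are ignored. Below is a minimal sketch of that arithmetic. The helper name `weight_memory_gb` is hypothetical.

```python
def weight_memory_gb(num_params: float, bits: int) -> float:
    """Weight storage in decimal gigabytes (1 GB = 1e9 bytes); excludes scales/zero-points."""
    return num_params * bits / 8 / 1e9

for name, bits in [("FP32", 32), ("FP16/BF16", 16), ("INT8", 8), ("INT4", 4)]:
    # Same formula for the 7B column of the table and the 70B example above
    print(f"{name:>10}: 7B -> {weight_memory_gb(7e9, bits):5.1f} GB, "
          f"70B -> {weight_memory_gb(70e9, bits):6.1f} GB")
```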

### Quantization Types

1. **Post-Training Quantization (PTQ)**: Quantize after training (what we'll implement)
2. **Quantization-Aware Training (QAT)**: Train with quantization in the loop

### Quantization Schemes

1. **Symmetric**: Zero-point is 0, range is `[-max, max]`
2. **Asymmetric**: Zero-point can be non-zero, range is `[min, max]`

### The Math

**Symmetric Quantization:**
```
scale = max(|x|) / 127
x_quant = round(x / scale)  # clamp to [-128, 127]
x_dequant = x_quant * scale
```

**Asymmetric Quantization:**
```
x_min, x_max = min(min(x), 0), max(max(x), 0)  # range must contain 0
scale = (x_max - x_min) / 255
zero_point = round(-x_min / scale)  # integer that represents 0.0
x_quant = round(x / scale) + zero_point  # clamp to [0, 255]
x_dequant = (x_quant - zero_point) * scale
```
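
Here is a worked example of both schemes. The snippet applies the formulas above, using plain Python floats, to the hypothetical vector `[-1.0, 0.0, 0.5, 3.0]`. Notice that in the asymmetric scheme 0.0 maps to the integer `zero_point` and dequantizes back to exactly 0.0.

```python
x = [-1.0, 0.0, 0.5, 3.0]  # hypothetical example values

def clamp(v, lo, hi):
    return max(lo, min(hi, v))

# Symmetric: scale maps max(|x|) = 3.0 to 127
s_sym = max(abs(v) for v in x) / 127
q_sym = [clamp(round(v / s_sym), -128, 127) for v in x]
deq_sym = [q * s_sym for q in q_sym]

# Asymmetric: scale maps [min, max] = [-1.0, 3.0] onto [0, 255]
s_asym = (max(x) - min(x)) / 255
zero_point = round(-min(x) / s_asym)
q_asym = [clamp(round(v / s_asym) + zero_point, 0, 255) for v in x]
deq_asym = [(q - zero_point) * s_asym for q in q_asym]

print("symmetric: ", q_sym, [f"{d:.4f}" for d in deq_sym])
print("asymmetric:", zero_point, q_asym, [f"{d:.4f}" for d in deq_asym])
```

Rounding to the nearest step keeps every reconstruction error within half a step (`scale / 2`).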

---

## Requirements

1. Implement `symmetric_quantize()` and `symmetric_dequantize()`
2. Implement `asymmetric_quantize()` and `asymmetric_dequantize()`
3. Create `QuantizedLinear` layer that stores INT8 weights
4. Compare memory usage and accuracy vs FP32

---

<details>
<summary>Hint 1: Scale Calculation</summary>

For symmetric quantization to INT8:
- Range is [-128, 127], so use 127 as the max quantized value
- Scale = max(|tensor|) / 127

</details>

<details>
<summary>Hint 2: Clamping</summary>

After dividing by scale and rounding, you must clamp:
- Symmetric: `torch.clamp(x, -128, 127).to(torch.int8)`
- Asymmetric: `torch.clamp(x, 0, 255).to(torch.uint8)`

</details>

<details>
<summary>Hint 3: Per-Channel vs Per-Tensor</summary>

Per-tensor uses one scale for entire tensor. Per-channel uses one scale per output channel, giving better accuracy for weight quantization.

</details>

---

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple

## Part 1: Symmetric Quantization

Symmetric quantization is simpler and more common for weights:
- Zero-point is always 0
- Range is symmetric around zero: `[-max, max]`
- Uses signed INT8: `[-128, 127]`
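
Two consequences of the symmetric formula can be checked numerically:

- Because `max(|x|)` maps exactly to 127, the code -128 is never produced.
- Rounding to the nearest step bounds the per-element reconstruction error by `scale / 2`, since clamping never triggers.

The standalone check below re-implements the formula inline on arbitrary random data.

```python
import torch

torch.manual_seed(0)
x = torch.randn(10_000) * 3  # arbitrary test data

scale = x.abs().max() / 127
x_quant = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
x_dequant = x_quant.float() * scale

print("min code used:", x_quant.min().item())  # never below -127
print("max error / scale:", ((x - x_dequant).abs().max() / scale).item())  # at most ~0.5
```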

In [None]:
def symmetric_quantize(x: torch.Tensor, num_bits: int = 8) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Symmetric quantization to signed integers.
    
    Args:
        x: Input tensor (float)
        num_bits: Number of bits (default 8 for INT8)
    
    Returns:
        x_quant: Quantized tensor (int8)
        scale: Scale factor for dequantization
    """
    # Calculate the maximum quantized value (127 for 8-bit)
    qmax = 2 ** (num_bits - 1) - 1  # 127 for 8-bit
    qmin = -qmax - 1  # -128 for 8-bit
    
    # Calculate scale: largest absolute value maps to qmax
    max_val = x.abs().max()
    scale = max_val / qmax
    
    # Avoid division by zero
    if scale == 0:
        scale = torch.tensor(1.0)
    
    # Quantize: divide by scale, round, and clamp
    x_quant = torch.round(x / scale)
    x_quant = torch.clamp(x_quant, qmin, qmax).to(torch.int8)
    
    return x_quant, scale


def symmetric_dequantize(x_quant: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """
    Dequantize from signed integers back to float.
    
    Args:
        x_quant: Quantized tensor (int8)
        scale: Scale factor
    
    Returns:
        x_dequant: Dequantized tensor (float)
    """
    return x_quant.float() * scale

In [None]:
# Test symmetric quantization
print("=== Testing Symmetric Quantization ===")

torch.manual_seed(42)
x = torch.randn(4, 4) * 2  # Random values roughly in [-4, 4]
print(f"Original tensor (FP32):")
print(x)
print(f"\nOriginal dtype: {x.dtype}")
print(f"Original memory: {x.numel() * 4} bytes")

# Quantize
x_quant, scale = symmetric_quantize(x)
print(f"\nQuantized tensor (INT8):")
print(x_quant)
print(f"\nScale: {scale.item():.6f}")
print(f"Quantized dtype: {x_quant.dtype}")
print(f"Quantized memory: {x_quant.numel() * 1} bytes (4x smaller!)")

# Dequantize and check error
x_dequant = symmetric_dequantize(x_quant, scale)
print(f"\nDequantized tensor:")
print(x_dequant)

error = (x - x_dequant).abs()
print(f"\nMax absolute error: {error.max().item():.6f}")
print(f"Mean absolute error: {error.mean().item():.6f}")
print(f"Relative error: {(error / x.abs().clamp(min=1e-6)).mean().item() * 100:.2f}%")

## Part 2: Asymmetric Quantization

Asymmetric quantization is better suited to activations, which are often non-negative (e.g., after ReLU):
- Zero-point can be non-zero
- Range is `[min, max]`
- Uses unsigned INT8: `[0, 255]`
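
To see why the unsigned range helps for non-negative data, the sketch below quantizes hypothetical post-ReLU values both ways. Symmetric quantization spends half of its 256 codes on negative values that never occur. As a result its step is about twice as large as the asymmetric one when the minimum is 0: max/127 versus max/255.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = F.relu(torch.randn(10_000))  # hypothetical post-ReLU activations, min is 0

# Symmetric (signed INT8): step = max(|x|) / 127
s_sym = x.abs().max() / 127
deq_sym = torch.clamp(torch.round(x / s_sym), -128, 127) * s_sym

# Asymmetric (unsigned INT8): step = (max - min) / 255; zero_point is 0 because min == 0
x_min, x_max = x.min(), x.max()
s_asym = (x_max - x_min) / 255
zp = torch.round(-x_min / s_asym)
deq_asym = (torch.clamp(torch.round(x / s_asym) + zp, 0, 255) - zp) * s_asym

print(f"step ratio sym/asym: {(s_sym / s_asym).item():.3f}")
print(f"mean abs error  sym: {(x - deq_sym).abs().mean().item():.6f}")
print(f"mean abs error asym: {(x - deq_asym).abs().mean().item():.6f}")
```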

In [None]:
def asymmetric_quantize(x: torch.Tensor, num_bits: int = 8) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Asymmetric quantization to unsigned integers.
    
    Args:
        x: Input tensor (float)
        num_bits: Number of bits (default 8)
    
    Returns:
        x_quant: Quantized tensor (uint8)
        scale: Scale factor
        zero_point: Zero point offset
    """
    qmax = 2 ** num_bits - 1  # 255 for 8-bit
    qmin = 0
    
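    # Find min and max of input, extended so the range always contains 0.0;
    # otherwise all-positive (or all-negative) inputs would push zero_point
    # outside [qmin, qmax] and values near the far end would be clamped
    x_min = torch.clamp(x.min(), max=0.0)
    x_max = torch.clamp(x.max(), min=0.0)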
    
    # Calculate scale and zero point
    scale = (x_max - x_min) / qmax
    
    # Avoid division by zero
    if scale == 0:
        scale = torch.tensor(1.0)
    
    # Zero point: where 0.0 maps to in quantized space
    zero_point = torch.round(-x_min / scale)
    zero_point = torch.clamp(zero_point, qmin, qmax)
    
    # Quantize
    x_quant = torch.round(x / scale) + zero_point
    x_quant = torch.clamp(x_quant, qmin, qmax).to(torch.uint8)
    
    return x_quant, scale, zero_point


def asymmetric_dequantize(x_quant: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor) -> torch.Tensor:
    """
    Dequantize from unsigned integers back to float.
    
    Args:
        x_quant: Quantized tensor (uint8)
        scale: Scale factor
        zero_point: Zero point offset
    
    Returns:
        x_dequant: Dequantized tensor (float)
    """
    return (x_quant.float() - zero_point) * scale

In [None]:
# Test asymmetric quantization - good for ReLU activations (mostly positive)
print("=== Testing Asymmetric Quantization ===")

torch.manual_seed(42)
# Simulate ReLU activations (positive values)
x = F.relu(torch.randn(4, 4) * 2)
print(f"Original tensor (post-ReLU):")
print(x)

# Quantize
x_quant, scale, zero_point = asymmetric_quantize(x)
print(f"\nQuantized tensor (UINT8):")
print(x_quant)
print(f"\nScale: {scale.item():.6f}")
print(f"Zero point: {zero_point.item()}")

# Dequantize and check error
x_dequant = asymmetric_dequantize(x_quant, scale, zero_point)
print(f"\nDequantized tensor:")
print(x_dequant)

error = (x - x_dequant).abs()
print(f"\nMax absolute error: {error.max().item():.6f}")
print(f"Mean absolute error: {error.mean().item():.6f}")

## Part 3: Per-Channel Quantization

Per-tensor quantization uses a single scale for the entire tensor. Per-channel quantization uses a different scale for each output channel, giving better accuracy.

This is especially important for weights where different channels may have very different magnitudes.
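
Per-channel quantization relies on `keepdim=True`, which lets the per-channel maxima broadcast back against the weight. Below is a minimal shape check for a `(out_features, in_features)` weight, with sizes chosen arbitrarily.

```python
import torch

weight = torch.randn(4, 8)  # (out_features, in_features)

# One max per output channel: reduce over in_features, keep that axis as size 1
max_vals = weight.abs().amax(dim=1, keepdim=True)
print(max_vals.shape)  # torch.Size([4, 1])

scales = max_vals / 127
q = torch.clamp(torch.round(weight / scales), -128, 127).to(torch.int8)

# Each row's largest-magnitude entry lands on +/-127, so every row uses the full range
print(q.int().abs().amax(dim=1))
```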

In [None]:
def per_channel_symmetric_quantize(
    weight: torch.Tensor, 
    axis: int = 0,
    num_bits: int = 8
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Per-channel symmetric quantization for weights.
    
    Args:
        weight: Weight tensor (out_features, in_features)
        axis: Channel axis (0 for output channels)
        num_bits: Number of bits
    
    Returns:
        weight_quant: Quantized weights (int8)
        scales: Per-channel scales
    """
    qmax = 2 ** (num_bits - 1) - 1
    qmin = -qmax - 1
    
    # Calculate per-channel max absolute values
    # For (out_features, in_features), reduce over in_features
    reduce_dims = [i for i in range(weight.dim()) if i != axis]
    max_vals = weight.abs().amax(dim=reduce_dims, keepdim=True)
    
    # Calculate scales
    scales = max_vals / qmax
    scales = torch.where(scales == 0, torch.ones_like(scales), scales)
    
    # Quantize
    weight_quant = torch.round(weight / scales)
    weight_quant = torch.clamp(weight_quant, qmin, qmax).to(torch.int8)
    
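    # Flatten to 1-D (one scale per channel); unlike squeeze(), this stays 1-D
    # even when there is only a single channel
    return weight_quant, scales.reshape(-1)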


def per_channel_symmetric_dequantize(
    weight_quant: torch.Tensor,
    scales: torch.Tensor,
    axis: int = 0
) -> torch.Tensor:
    """
    Dequantize per-channel quantized weights.
    """
    # Reshape scales for broadcasting
    shape = [1] * weight_quant.dim()
    shape[axis] = -1
    scales = scales.view(*shape)
    
    return weight_quant.float() * scales

In [None]:
# Compare per-tensor vs per-channel quantization
print("=== Per-Tensor vs Per-Channel Quantization ===")

torch.manual_seed(42)
# Create weights with varying magnitudes per channel
weight = torch.randn(4, 8)
weight[0] *= 10  # First channel has much larger values
weight[3] *= 0.1  # Last channel has much smaller values

print(f"Weight tensor (varying magnitudes per channel):")
print(f"Channel 0 max: {weight[0].abs().max().item():.4f}")
print(f"Channel 3 max: {weight[3].abs().max().item():.4f}")

# Per-tensor quantization
w_quant_tensor, scale_tensor = symmetric_quantize(weight)
w_dequant_tensor = symmetric_dequantize(w_quant_tensor, scale_tensor)
error_tensor = (weight - w_dequant_tensor).abs()

print(f"\n--- Per-Tensor Quantization ---")
print(f"Single scale: {scale_tensor.item():.6f}")
print(f"Mean error: {error_tensor.mean().item():.6f}")
print(f"Max error: {error_tensor.max().item():.6f}")
print(f"Channel 3 error (small values lost): {error_tensor[3].mean().item():.6f}")

# Per-channel quantization
w_quant_channel, scales_channel = per_channel_symmetric_quantize(weight)
w_dequant_channel = per_channel_symmetric_dequantize(w_quant_channel, scales_channel)
error_channel = (weight - w_dequant_channel).abs()

print(f"\n--- Per-Channel Quantization ---")
print(f"Scales: {scales_channel.tolist()}")
print(f"Mean error: {error_channel.mean().item():.6f}")
print(f"Max error: {error_channel.max().item():.6f}")
print(f"Channel 3 error (preserved!): {error_channel[3].mean().item():.6f}")

print(f"\n>>> Per-channel reduces error by {error_tensor.mean() / error_channel.mean():.1f}x!")

## Part 4: Quantized Linear Layer

Now let's create a quantized linear layer that stores INT8 weights but computes in FP32.
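
The layer in this part dequantizes its weights for simplicity. An INT8 kernel would work differently, and the hedged emulation below gives some intuition for how:

1. Quantize activations per tensor (symmetric) and weights per output channel.
2. Accumulate the integer products in a wide integer type.
3. Apply one float rescale `s_x * s_w[j]` to recover each output.

Real kernels typically use int32 accumulators and fused hardware instructions. This sketch uses int64 elementwise products only so that it runs on any PyTorch build.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 64)         # activations (batch, in_features)
w = torch.randn(16, 64) * 0.1  # weights (out_features, in_features)

# Per-tensor symmetric quantization of activations
s_x = x.abs().max() / 127
x_q = torch.clamp(torch.round(x / s_x), -128, 127).to(torch.int8)

# Per-channel symmetric quantization of weights (one scale per output row)
s_w = w.abs().amax(dim=1) / 127
w_q = torch.clamp(torch.round(w / s_w[:, None]), -128, 127).to(torch.int8)

# Integer accumulation: (batch, 1, in) * (1, out, in), summed over in_features
acc = (x_q.long()[:, None, :] * w_q.long()[None, :, :]).sum(dim=-1)  # int64, (batch, out)

# Rescale once per output element: y ~= acc * s_x * s_w
y_int = acc.float() * s_x * s_w[None, :]
y_ref = F.linear(x, w)

rel_err = ((y_int - y_ref).norm() / y_ref.norm()).item()
print(f"accumulator dtype: {acc.dtype}, relative error: {rel_err:.4%}")
```

The largest possible accumulator magnitude here is 64 × 127 × 127, which fits comfortably in int32.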

In [None]:
class QuantizedLinear(nn.Module):
    """
    Linear layer with INT8 quantized weights.
    
    Stores weights as INT8 to save memory.
    Dequantizes to FP32 for computation (can be optimized with INT8 matmul kernels).
    """
    
    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        
        # Quantized weights (INT8) - registered as buffer, not parameter
        self.register_buffer('weight_quant', torch.zeros(out_features, in_features, dtype=torch.int8))
        self.register_buffer('weight_scales', torch.zeros(out_features))
        
        if bias:
            self.register_buffer('bias', torch.zeros(out_features))
        else:
            self.register_buffer('bias', None)
    
    @classmethod
    def from_float(cls, linear: nn.Linear) -> 'QuantizedLinear':
        """
        Create a quantized linear layer from a float linear layer.
        """
        quant_linear = cls(linear.in_features, linear.out_features, bias=linear.bias is not None)
        
        # Quantize weights using per-channel quantization
        weight_quant, scales = per_channel_symmetric_quantize(linear.weight.data)
        quant_linear.weight_quant.copy_(weight_quant)
        quant_linear.weight_scales.copy_(scales)
        
        if linear.bias is not None:
            quant_linear.bias.copy_(linear.bias.data)
        
        return quant_linear
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with dequantized weights.
        
        Note: In production, you'd use INT8 matmul kernels for speed.
        Here we dequantize for simplicity.
        """
        # Dequantize weights
        weight = per_channel_symmetric_dequantize(self.weight_quant, self.weight_scales)
        
        # Compute linear
        return F.linear(x, weight, self.bias)
    
    def memory_bytes(self) -> int:
        """Calculate memory usage of quantized weights."""
        weight_bytes = self.weight_quant.numel() * 1  # INT8 = 1 byte
        scale_bytes = self.weight_scales.numel() * 4  # FP32 = 4 bytes
        bias_bytes = self.bias.numel() * 4 if self.bias is not None else 0
        return weight_bytes + scale_bytes + bias_bytes
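
`QuantizedLinear` keeps its INT8 weights in buffers rather than `nn.Parameter`s, and a quick check shows why. Integer tensors cannot require gradients, so wrapping one in a default `nn.Parameter` fails. A buffer, by contrast, is still saved in `state_dict()` and moved by `.to(device)`. The module `Holder` below is a hypothetical minimal example.

```python
import torch
import torch.nn as nn

class Holder(nn.Module):  # hypothetical minimal module for illustration
    def __init__(self):
        super().__init__()
        self.register_buffer("weight_quant", torch.zeros(2, 3, dtype=torch.int8))

m = Holder()
print(len(list(m.parameters())))             # 0: buffers are not parameters
print(m.state_dict()["weight_quant"].dtype)  # torch.int8: buffers are still serialized

raised = False
try:
    nn.Parameter(torch.zeros(2, 3, dtype=torch.int8))  # requires_grad=True by default
except RuntimeError as err:
    raised = True
    print("Parameter rejected:", err)
```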

In [None]:
# Test quantized linear layer
print("=== Testing Quantized Linear Layer ===")

torch.manual_seed(42)
in_features = 512
out_features = 256
batch_size = 4

# Create float linear layer
linear_fp32 = nn.Linear(in_features, out_features)

# Quantize it
linear_int8 = QuantizedLinear.from_float(linear_fp32)

# Compare memory
fp32_bytes = linear_fp32.weight.numel() * 4 + (linear_fp32.bias.numel() * 4 if linear_fp32.bias is not None else 0)
int8_bytes = linear_int8.memory_bytes()

print(f"FP32 linear: {fp32_bytes:,} bytes ({fp32_bytes / 1024:.1f} KB)")
print(f"INT8 linear: {int8_bytes:,} bytes ({int8_bytes / 1024:.1f} KB)")
print(f"Compression: {fp32_bytes / int8_bytes:.2f}x")

# Compare outputs
x = torch.randn(batch_size, in_features)
out_fp32 = linear_fp32(x)
out_int8 = linear_int8(x)

error = (out_fp32 - out_int8).abs()
print(f"\nOutput comparison:")
print(f"Max absolute error: {error.max().item():.6f}")
print(f"Mean absolute error: {error.mean().item():.6f}")
print(f"Relative error: {(error / out_fp32.abs().clamp(min=1e-6)).mean().item() * 100:.4f}%")

print("\n✓ Quantized linear layer works!")

## Part 5: Quantize a Full Model

Let's quantize a small MLP and see the impact on accuracy and memory.
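
The expected compression can be worked out by hand for the layer sizes used here (784 → 512 → 512 → 10). Each quantized layer stores INT8 weights plus one FP32 scale and one FP32 bias per output channel, so the ratio lands slightly under 4x.

```python
layers = [(784, 512), (512, 512), (512, 10)]  # (in_features, out_features) per nn.Linear

fp32_bytes = sum((i * o + o) * 4 for i, o in layers)             # FP32 weights + biases
int8_bytes = sum(i * o * 1 + o * 4 + o * 4 for i, o in layers)   # INT8 weights + FP32 scales + FP32 biases

print(fp32_bytes, int8_bytes, f"{fp32_bytes / int8_bytes:.3f}x")
# → 2678824 676944 3.957x
```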

In [None]:
class SimpleMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


def quantize_model(model: nn.Module) -> nn.Module:
    """
    Quantize all linear layers in a model.
    """
    for name, module in model.named_children():
        if isinstance(module, nn.Linear):
            setattr(model, name, QuantizedLinear.from_float(module))
        else:
            quantize_model(module)
    return model


def count_parameters_bytes(model: nn.Module) -> int:
    """Count total bytes for all parameters and buffers."""
    total = 0
    for name, param in model.named_parameters():
        total += param.numel() * param.element_size()
    for name, buffer in model.named_buffers():
        total += buffer.numel() * buffer.element_size()
    return total

In [None]:
# Create and quantize model
print("=== Quantizing Full Model ===")

torch.manual_seed(42)
input_dim = 784
hidden_dim = 512
output_dim = 10

# Original model
model_fp32 = SimpleMLP(input_dim, hidden_dim, output_dim)
fp32_bytes = count_parameters_bytes(model_fp32)

# Quantized model (copy first to preserve original)
import copy
model_int8 = copy.deepcopy(model_fp32)
model_int8 = quantize_model(model_int8)
int8_bytes = count_parameters_bytes(model_int8)

print(f"FP32 model size: {fp32_bytes:,} bytes ({fp32_bytes / 1024 / 1024:.2f} MB)")
print(f"INT8 model size: {int8_bytes:,} bytes ({int8_bytes / 1024 / 1024:.2f} MB)")
print(f"Compression ratio: {fp32_bytes / int8_bytes:.2f}x")

# Test accuracy
batch_size = 32
x = torch.randn(batch_size, input_dim)

model_fp32.eval()
model_int8.eval()

with torch.no_grad():
    out_fp32 = model_fp32(x)
    out_int8 = model_int8(x)

# Check if predictions match
pred_fp32 = out_fp32.argmax(dim=1)
pred_int8 = out_int8.argmax(dim=1)
accuracy = (pred_fp32 == pred_int8).float().mean().item()

print(f"\nPrediction agreement: {accuracy * 100:.1f}%")
print(f"Output MSE: {F.mse_loss(out_fp32, out_int8).item():.6f}")

print("\n✓ Model quantization complete!")

## Summary

### Key Concepts

1. **Quantization reduces memory** by representing weights in lower precision (INT8 = 4x smaller than FP32)

2. **Symmetric quantization** is simpler, good for weights centered around zero:
   - `scale = max(|x|) / 127`
   - `x_quant = round(x / scale)`

3. **Asymmetric quantization** is better for non-centered data (like ReLU activations):
   - Uses both scale and zero_point
   - Maps `[min, max]` to `[0, 255]`

4. **Per-channel quantization** uses a separate scale per output channel, reducing quantization error significantly

5. **Trade-offs:**
   - Memory: 4x reduction with INT8
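   - Accuracy: Degradation is usually small for INT8 weights, but it is model-dependent and should be measured on real validation data; the random-input agreement check above is only a sanity test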
   - Speed: Can be faster with INT8 kernels (not implemented here)

---

## Interview Tips

**Q: What is the difference between symmetric and asymmetric quantization?**
A: Symmetric uses zero as the center point and maps to signed integers [-128, 127]. Asymmetric uses a zero_point offset and maps to unsigned integers [0, 255]. Symmetric is simpler and better for weights; asymmetric is better for activations that are often positive.

**Q: Why is per-channel quantization better than per-tensor?**
A: Different channels can have very different value ranges. Per-tensor uses one scale for all, so small-magnitude channels lose precision. Per-channel gives each channel its own scale, preserving precision.

**Q: What is the memory savings of INT8 quantization?**
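A: 4x relative to FP32 (32 bits → 8 bits) and 2x relative to FP16. In practice the ratio is slightly lower, because scales, zero-points, biases, and any layers kept in higher precision still take space. For the MLP in Part 5, which keeps FP32 per-channel scales and biases, it comes out just under 4x.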

**Q: When does quantization fail?**
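A: When a few outliers (very large values) dominate the range, so the scale becomes too coarse for every other value. In large transformers this mostly affects activations, where a few feature dimensions carry large outliers. Solutions include:
- clipping outliers during calibration;
- keeping sensitive layers or outlier features in higher precision (mixed precision, as in LLM.int8());
- techniques like SmoothQuant that shift quantization difficulty from activations to weights.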
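
The small illustration below uses made-up data. A single large value sets the per-tensor scale, so every other element is quantized with a much coarser step. Clipping the range, here hypothetically at the 99.9th percentile of |x| (one common calibration heuristic), restores precision for the bulk of the values. The cost is a large error on the clipped outlier.

```python
import torch

def sym_qdq(x, clip_val):
    """Symmetric INT8 quantize-dequantize with a given clipping threshold (range [-127, 127])."""
    scale = clip_val / 127
    return torch.clamp(torch.round(x / scale), -127, 127) * scale

torch.manual_seed(0)
x = torch.randn(10_000)
x[0] = 100.0  # one artificial outlier

bulk = slice(1, None)  # every element except the outlier
err_max = (x - sym_qdq(x, x.abs().max()))[bulk].abs().mean().item()
clip = torch.quantile(x.abs(), 0.999)
err_clip = (x - sym_qdq(x, clip))[bulk].abs().mean().item()

print(f"bulk error, scale from max:      {err_max:.5f}")
print(f"bulk error, scale from 99.9th %: {err_clip:.5f}")
print(f"outlier after clipping: {sym_qdq(x, clip)[0].item():.3f}")
```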

**Q: What is dynamic vs static quantization?**
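A: Static: activation scales (and zero-points) are calibrated once on representative data and fixed at inference time. Dynamic: weights are quantized ahead of time, but activation scales are computed on the fly from each input. Dynamic is simpler (no calibration data) but adds runtime overhead for computing scales. Static avoids that overhead but can clip inputs whose range exceeds what calibration saw.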
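
The difference can be sketched in a few lines, assuming symmetric per-tensor INT8 activation scales:

- Static quantization fixes the scale from the maximum seen over calibration batches, like a simple min/max-style observer.
- Dynamic quantization recomputes the scale from each incoming tensor.

The function names are hypothetical.

```python
import torch

def calibrate_static_scale(calib_batches):
    """Static: one scale from the max |activation| over all calibration batches."""
    running_max = torch.tensor(0.0)
    for batch in calib_batches:
        running_max = torch.maximum(running_max, batch.abs().max())
    return running_max / 127

def dynamic_scale(x):
    """Dynamic: scale recomputed from the current input at runtime."""
    return x.abs().max() / 127

torch.manual_seed(0)
calib = [torch.randn(8, 16) for _ in range(10)]
static_s = calibrate_static_scale(calib)

x_new = torch.randn(8, 16) * 5  # hypothetical input with a larger range than the calibration data
frac_clipped = (x_new.abs() > 127 * static_s).float().mean().item()

print(f"static scale {static_s.item():.4f}, dynamic scale {dynamic_scale(x_new).item():.4f}")
print(f"fraction of x_new outside the static range (would be clipped): {frac_clipped:.1%}")
```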

**Q: What is GPTQ/AWQ?**
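A: Weight-only post-training quantization methods that go beyond round-to-nearest by minimizing the layer's output reconstruction error. GPTQ quantizes weights one column at a time and uses approximate second-order (Hessian) information to update the not-yet-quantized weights to compensate. AWQ uses activation magnitudes to identify the most salient weight channels and rescales them before quantization to protect them.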

---

## References

- [Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference](https://arxiv.org/abs/1712.05877) (Jacob et al., 2018)
- [LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale](https://arxiv.org/abs/2208.07339) (Dettmers et al., 2022)
- [GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers](https://arxiv.org/abs/2210.17323) (Frantar et al., 2022)
- [AWQ: Activation-aware Weight Quantization](https://arxiv.org/abs/2306.00978) (Lin et al., 2023)