# Static Quantization Notes

## Core Concept

Both weights AND activations are quantized ahead of time: weights directly, activations using ranges measured on calibration data. Faster inference than dynamic quantization, but needs more setup.

- Pre-compute quantization params for everything
- INT8 compute at runtime wherever the runtime has integer kernels; no activation ranges are computed on the fly
- Ideal for CNNs in production

## Pipeline

```
FP32 model → Prune → Optimize → Calibrate → Quantize → INT8 model
```

## Step 1: Pruning (Optional but Recommended)

**What is pruning?**
Remove unimportant weights by setting them to zero. Common approach: magnitude-based pruning - if `|weight| < threshold`, set it to 0.

**Why prune before quantizing?**
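- Sparse + low precision compounds compression: zeros compress well, and INT8 shrinks what is left
- Magnitude pruning removes near-zero weights, so it does not shrink the min/max range that sets the quantization scale; the gain is size, not range
- Activation distributions may get cleaner (fewer tiny, noisy contributions), but this is model-dependent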
- Model learns to compensate for pruned weights during training (if doing iterative pruning)

**Typical workflow:**
1. Train model normally
2. Prune small weights (e.g., an absolute threshold of 0.01 drops every weight with |w| < 0.01; a relative threshold would instead drop weights below 1% of max |w|)
3. Fine-tune briefly (optional, helps accuracy)
4. Quantize

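```python
import numpy as np
import onnx
from onnx import numpy_helper

def prune_model(model, threshold=0.01):
    """Magnitude pruning: zero every weight with |w| < threshold (absolute cutoff)."""
    # ONNX stores weights as graph initializers, not as node attributes
    for init in model.graph.initializer:
        w = numpy_helper.to_array(init).copy()  # copy: the decoded array may be read-only
        # Only prune float weight tensors; skip biases (1-D), shapes and integer tensors
        if w.dtype != np.float32 or w.ndim < 2:
            continue
        w[np.abs(w) < threshold] = 0.0  # zero out small weights
        init.CopyFrom(numpy_helper.from_array(w, init.name))
    return model

model = onnx.load('model.onnx')
pruned = prune_model(model, threshold=0.01)
onnx.save(pruned, 'model_pruned.onnx')
```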
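
The `threshold` in magnitude pruning can be absolute (as in `prune_model`) or relative to the largest weight, and the two are easy to confuse. A small pure-Python sketch comparing them on a toy weight list; `prune_abs`, `prune_rel` and `sparsity` are illustrative helpers, not part of any library. It also shows that the largest weight survives, so a max-based quantization range is unchanged by pruning.

```python
def prune_abs(weights, threshold):
    """Zero every weight whose magnitude is below an absolute threshold."""
    return [0.0 if abs(w) < threshold else w for w in weights]

def prune_rel(weights, fraction):
    """Zero weights below `fraction` of the largest magnitude (relative threshold)."""
    cutoff = fraction * max(abs(w) for w in weights)
    return prune_abs(weights, cutoff)

def sparsity(weights):
    """Fraction of weights that are exactly zero."""
    return sum(1 for w in weights if w == 0.0) / len(weights)

w = [0.9, -0.004, 0.02, -0.5, 0.008, 0.3]
abs_pruned = prune_abs(w, 0.01)   # drops |w| < 0.01
rel_pruned = prune_rel(w, 0.05)   # drops |w| < 0.05 * 0.9 = 0.045
print(abs_pruned, sparsity(abs_pruned))
print(rel_pruned, sparsity(rel_pruned))

# The largest magnitude is never pruned, so max|w| (and a symmetric scale) stays the same
print(max(map(abs, w)) == max(map(abs, rel_pruned)))
```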

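**Note:** This is basic (unstructured) magnitude pruning. The zeros shrink the compressed model file but do not speed up dense INT8 kernels by themselves. More advanced options: structured pruning (remove entire channels/filters), iterative pruning, lottery ticket hypothesis.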
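
Structured pruning removes whole filters, which actually shrinks the dense tensors that INT8 kernels process. A sketch of L1-norm filter pruning on a weight matrix stored as nested lists (one row per output channel, i.e. a flattened conv filter); the L1-norm criterion and the `prune_channels` helper are illustrative choices, not a specific library's method. The kept indices are returned because the next layer's input channels must be sliced to match.

```python
def l1_norm(row):
    """Sum of absolute values of one filter's weights."""
    return sum(abs(v) for v in row)

def prune_channels(weight, keep):
    """Structured pruning: keep the `keep` output channels (rows) with the largest L1 norm.

    Returns the kept rows and their original indices, so the next layer's
    input channels can be sliced to match.
    """
    ranked = sorted(range(len(weight)), key=lambda i: l1_norm(weight[i]), reverse=True)
    kept_idx = sorted(ranked[:keep])          # preserve original channel order
    return [weight[i] for i in kept_idx], kept_idx

W = [
    [0.5, -0.4, 0.3],      # L1 = 1.2
    [0.01, 0.02, -0.01],   # L1 = 0.04 (weak filter)
    [-0.2, 0.6, 0.1],      # L1 = 0.9
    [0.05, -0.03, 0.0],    # L1 = 0.08
]
pruned, idx = prune_channels(W, keep=2)
print(idx)     # [0, 2]
print(pruned)
```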

## Step 2: Optimization

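Fuse ops and simplify the graph (constant folding, redundant-node elimination). Skip the optimization step for models > 2GB: ORT's optimizer has known issues with models above the 2GB Protobuf limit.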

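```python
import onnx
import onnxoptimizer
from onnxruntime.quantization.shape_inference import quant_pre_process

# Pre-process: symbolic shape inference + ORT graph optimization + ONNX shape inference
quant_pre_process(
    'model_pruned.onnx',
    'model_opt.onnx',
    skip_optimization=False,    # set True for models > 2GB
    skip_symbolic_shape=False   # Important for transformers (dynamic shapes)
)

# Or use onnxoptimizer's fusion/elimination passes directly
model = onnx.load('model_pruned.onnx')
passes = onnxoptimizer.get_fuse_and_elimination_passes()
opt_model = onnxoptimizer.optimize(model, passes)
onnx.save(opt_model, 'model_opt_onnxoptimizer.onnx')
```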
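
A classic fusion of this kind (onnxoptimizer has a `fuse_bn_into_conv` pass) folds an inference-mode BatchNormalization into the preceding Conv, so only one op has to be quantized and calibrated. A pure-Python sketch of the per-channel arithmetic; a 1x1 conv at a single position stands in for a full convolution, and the helper names are illustrative.

```python
import math

def fold_batchnorm(w, b, gamma, beta, mean, var, eps=1e-5):
    """Fold y = gamma * (conv(x) - mean) / sqrt(var + eps) + beta into the conv.

    Works per output channel: `w` is that channel's weights, `b` its bias.
    """
    s = gamma / math.sqrt(var + eps)
    return [wi * s for wi in w], (b - mean) * s + beta

def conv_point(w, b, x):
    """1x1 'conv' at a single position: dot product plus bias."""
    return sum(wi * xi for wi, xi in zip(w, x)) + b

w, b = [0.5, -1.0, 2.0], 0.1
gamma, beta, mean, var = 1.5, -0.2, 0.3, 4.0
x = [1.0, 2.0, 3.0]

# Conv followed by BatchNorm (two ops)
unfused = gamma * (conv_point(w, b, x) - mean) / math.sqrt(var + 1e-5) + beta
# Single conv with folded weights and bias
w_f, b_f = fold_batchnorm(w, b, gamma, beta, mean, var)
fused = conv_point(w_f, b_f, x)
print(unfused, fused)  # equal up to float rounding
```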

## Calibration Data

**What is it?**
Representative subset of real data that the model will see in production. Typically 100-1000 samples.

**Why needed?**
To figure out the range of activations at each layer. Without knowing activation ranges, can't compute proper scale/zero-point values for quantization.

**How to choose:**
- Must match production data distribution
- Include edge cases (dark images, bright images, different object sizes, etc.)
- Can use subset of validation set
- DON'T just use random noise - will give terrible quantization params

**Bad calibration = broken model.**
If calibration data doesn't represent real inputs, the computed ranges will be wrong and accuracy tanks.
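
To see why, a pure-Python sketch that calibrates once on representative values and once on data whose range is too narrow, then measures round-trip error on "production" values. `qparams` and `fake_quant` are illustrative helpers using the asymmetric uint8 formula; the value ranges are made up for the demonstration.

```python
def qparams(lo, hi, qmin=0, qmax=255):
    """Asymmetric uint8 scale/zero-point from a calibrated [lo, hi] range."""
    lo, hi = min(lo, 0.0), max(hi, 0.0)     # range must include 0
    scale = (hi - lo) / (qmax - qmin)
    zero_point = round(qmin - lo / scale)
    return scale, zero_point

def fake_quant(x, scale, zp, qmin=0, qmax=255):
    """Quantize then dequantize, clamping to the integer range."""
    q = max(qmin, min(qmax, round(x / scale) + zp))
    return scale * (q - zp)

def mean_abs_error(xs, scale, zp):
    return sum(abs(x - fake_quant(x, scale, zp)) for x in xs) / len(xs)

real = [i * 0.05 for i in range(121)]      # production activations span [0, 6]
good = qparams(min(real), max(real))       # calibrated on representative data
bad = qparams(0.0, 1.0)                    # calibrated on data that never exceeds 1.0

print(mean_abs_error(real, *good))  # small: bounded by scale / 2
print(mean_abs_error(real, *bad))   # large: everything above 1.0 is clipped
```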

## Step 3: Calibration (Critical!)

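Run forward passes on calibration data and record activation statistics at each layer. With the default `CalibrationMethod.MinMax` these are the min/max values, which determine scale/zero-point; `Entropy`, `Percentile` and `Distribution` are alternatives selected via `calibrate_method`.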
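
A minimal sketch of what MinMax calibration does for one tensor: keep a running min/max across all calibration batches, then turn the final range into uint8 scale/zero-point. `MinMaxObserver` is a hypothetical class for illustration, not ONNX Runtime's implementation.

```python
class MinMaxObserver:
    """Track the running min/max of one tensor across calibration batches (illustrative)."""

    def __init__(self):
        self.lo = float("inf")
        self.hi = float("-inf")

    def update(self, batch):
        """Widen the observed range with one batch of activation values."""
        self.lo = min(self.lo, min(batch))
        self.hi = max(self.hi, max(batch))

    def qparams(self, qmin=0, qmax=255):
        """Asymmetric scale/zero-point from the observed range."""
        lo, hi = min(self.lo, 0.0), max(self.hi, 0.0)  # keep 0 exactly representable
        scale = (hi - lo) / (qmax - qmin) or 1.0       # avoid scale 0 for all-zero tensors
        return scale, round(qmin - lo / scale)

obs = MinMaxObserver()
for batch in ([0.2, 1.5, -0.3], [2.0, 0.1], [-1.0, 0.7]):   # activations of one layer
    obs.update(batch)

scale, zero_point = obs.qparams()
print(obs.lo, obs.hi)       # -1.0 2.0
print(scale, zero_point)    # scale = 3/255, zero_point = 85
```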

### Custom CalibrationDataReader

```python
import os
from concurrent.futures import ThreadPoolExecutor

import cv2
import numpy as np
from onnxruntime.quantization import CalibrationDataReader

IMAGE_EXTS = ('.jpg', '.jpeg', '.png', '.bmp')

class MyCalibrationDataReader(CalibrationDataReader):
    def __init__(self, image_folder, input_name='input', batch_size=1, input_size=(416, 416)):
        self.image_folder = image_folder
        self.input_name = input_name    # must match the model's input name
        self.batch_size = batch_size    # must be 1 if the model has a fixed batch dim of 1
        self.input_size = input_size    # (width, height), as cv2.resize expects
        # Only image files, in a deterministic order
        self.image_files = sorted(
            f for f in os.listdir(image_folder) if f.lower().endswith(IMAGE_EXTS)
        )
        self.index = 0

    def preprocess_image(self, img_path):
        img = cv2.imread(img_path)
        if img is None:
            raise ValueError(f"Could not read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # Must match the preprocessing used at inference (resize, normalize, etc.)
        img = cv2.resize(img, self.input_size)
        return img.astype(np.float32) / 255.0

    def get_next(self):
        if self.index >= len(self.image_files):
            return None  # signals end of calibration data

        batch_paths = [
            os.path.join(self.image_folder, name)
            for name in self.image_files[self.index:self.index + self.batch_size]
        ]

        # Decode and preprocess the batch's images in parallel
        with ThreadPoolExecutor() as executor:
            batch = list(executor.map(self.preprocess_image, batch_paths))

        # (batch, height, width, channels) -> (batch, channels, height, width)
        batch = np.ascontiguousarray(np.transpose(np.stack(batch, axis=0), (0, 3, 1, 2)))

        self.index += self.batch_size
        return {self.input_name: batch}
```
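
A reader is a one-shot iterator: once `get_next()` returns `None`, passing the same object to a second `quantize_static` call yields no data. A minimal in-memory sketch for inputs that are already preprocessed (for example, non-image models), with an illustrative `rewind()`. The class is hypothetical and written without the base class so it runs without onnxruntime; subclass `CalibrationDataReader` as above when passing it to `quantize_static`.

```python
class ListCalibrationDataReader:
    """In-memory reader: hands out one pre-built input dict per get_next() call.

    Hypothetical helper; in real code subclass
    onnxruntime.quantization.CalibrationDataReader.
    """

    def __init__(self, samples):
        self.samples = list(samples)   # each sample: {input_name: array}
        self.index = 0

    def get_next(self):
        if self.index >= len(self.samples):
            return None                # None means "no more calibration data"
        sample = self.samples[self.index]
        self.index += 1
        return sample

    def rewind(self):
        """Start over, e.g. to reuse the reader for a second quantization run."""
        self.index = 0

# Plain lists stand in for preprocessed NumPy arrays here
reader = ListCalibrationDataReader([{"input": [0.1, 0.2]}, {"input": [0.3, 0.4]}])
while (sample := reader.get_next()) is not None:
    print(sample)
reader.rewind()
```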

## Step 4: Quantize

```python
from onnxruntime.quantization import quantize_static, QuantType

calibration_reader = MyCalibrationDataReader(
    'path/to/calibration/images',
    batch_size=1
)

quantize_static(
    model_input='model_opt.onnx',
    model_output='model_quant.onnx',
    calibration_data_reader=calibration_reader,
    activation_type=QuantType.QUInt8,   # U8U8 here; ORT's documented default is S8S8 (QInt8) with QDQ
    weight_type=QuantType.QUInt8,
    per_channel=False,                  # True for better accuracy
    use_external_data_format=False      # True if > 2GB
)
```

## Math

```
val_fp32 = scale × (val_int8 - zero_point)

scale = max(|range_max|, |range_min|) × 2 / (quant_range_max - quant_range_min)
```

Calibration determines scale/zero-point by collecting min/max activation values on representative data.
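
Both schemes as a pure-Python sketch: compute scale/zero-point from a calibrated range, quantize one value, and dequantize it with `val_fp32 = scale × (val_quant - zero_point)`. The helper names are illustrative; the restricted [-127, 127] range for symmetric int8 is an assumption that keeps the grid symmetric around 0.

```python
def asymmetric_params(rmin, rmax, qmin=0, qmax=255):
    """uint8-style: scale = (rmax - rmin) / (qmax - qmin); zero_point maps rmin to ~qmin."""
    rmin, rmax = min(rmin, 0.0), max(rmax, 0.0)
    scale = (rmax - rmin) / (qmax - qmin)
    zero_point = round(qmin - rmin / scale)
    return scale, zero_point

def symmetric_params(rmin, rmax, qmin=-127, qmax=127):
    """int8-style: scale = max(|rmax|, |rmin|) * 2 / (qmax - qmin), zero_point = 0."""
    scale = max(abs(rmax), abs(rmin)) * 2 / (qmax - qmin)
    return scale, 0

def quantize(x, scale, zp, qmin, qmax):
    """Round to the nearest integer step and clamp to the type's range."""
    return max(qmin, min(qmax, round(x / scale) + zp))

def dequantize(q, scale, zp):
    return scale * (q - zp)        # val_fp32 = scale * (val_quant - zero_point)

rmin, rmax = -1.0, 3.0
s_a, z_a = asymmetric_params(rmin, rmax)   # scale = 4/255, zero_point = 64
s_s, z_s = symmetric_params(rmin, rmax)    # scale = 6/254, zero_point = 0

x = 1.7
qa = quantize(x, s_a, z_a, 0, 255)
qs = quantize(x, s_s, z_s, -127, 127)
print(qa, dequantize(qa, s_a, z_a))
print(qs, dequantize(qs, s_s, z_s))
```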

## Quantization Types

```python
QuantType.QInt8           # signed int8
QuantType.QUInt8          # unsigned int8
QuantType.QInt4           # int4 (requires opset 21+)
QuantType.QFLOAT8E4M3FN   # fp8 E4M3 (requires opset 19+)
```
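
Each integer type offers a fixed number of levels, and the scale formula spreads the calibrated range across them: going from 8 to 4 bits cuts 256 levels to 16, so each step is 17x wider. A sketch computing integer ranges and the resulting asymmetric step size for a range of [-1, 1]; the helpers are illustrative, and FP8 is a floating-point format not covered by this integer formula.

```python
def int_range(bits, signed):
    """(qmin, qmax) of a two's-complement or unsigned integer type."""
    if signed:
        return -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    return 0, 2 ** bits - 1

def step_size(rmin, rmax, bits, signed):
    """Asymmetric scale: the real-valued width of one integer step."""
    qmin, qmax = int_range(bits, signed)
    return (rmax - rmin) / (qmax - qmin)

for name, bits, signed in [("int8", 8, True), ("uint8", 8, False),
                           ("int4", 4, True), ("uint4", 4, False)]:
    print(name, int_range(bits, signed), step_size(-1.0, 1.0, bits, signed))
```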

## Per-Channel vs Per-Tensor

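- **Per-tensor**: a single scale/zero-point for the whole tensor (simpler, marginally cheaper)
- **Per-channel**: a separate scale/zero-point for each output channel of a weight tensor (more accurate)

Use per-channel for weights in conv layers for better accuracy. In ONNX Runtime, `per_channel=True` applies to weights only; activations stay per-tensor.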
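
Why per-channel helps: one shared scale is set by the largest weight anywhere in the tensor, so a channel with much smaller weights gets only a step or two of resolution. A pure-Python sketch using symmetric int8 weights; the helpers and toy weights are illustrative.

```python
def symmetric_scale(values, qmax=127):
    """Symmetric int8 scale for a group of weights: max|w| / qmax."""
    return max(abs(v) for v in values) / qmax

def fake_quant(v, scale, qmax=127):
    """Quantize to [-qmax, qmax] and dequantize back to float."""
    q = max(-qmax, min(qmax, round(v / scale)))
    return q * scale

def mean_abs_error(rows, scales):
    errs = [abs(v - fake_quant(v, s)) for row, s in zip(rows, scales) for v in row]
    return sum(errs) / len(errs)

# Conv weights flattened to one row per output channel; channel 1 has tiny weights
W = [
    [0.8, -0.6, 0.4],
    [0.004, -0.003, 0.002],
]
per_tensor = [symmetric_scale([v for row in W for v in row])] * len(W)  # one shared scale
per_channel = [symmetric_scale(row) for row in W]                        # one scale per row

print(mean_abs_error(W, per_tensor))   # channel 1 gets at most one integer step
print(mean_abs_error(W, per_channel))  # much smaller
```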

## Model Requirements

- **Opset 10+** (recommend 13+)
- INT4 needs opset 21+
- FP8 needs opset 19+
- Models > 2GB need `use_external_data_format=True`

## Benchmarking

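```python
import time

import numpy as np
import onnxruntime as ort

def benchmark(model_path, input_shape, num_runs=100, warmup=5):
    """Mean latency in seconds of one sess.run on a random float32 input."""
    sess = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])
    # Quantized models keep float32 graph inputs, so the same input works for both
    x = np.random.rand(*input_shape).astype(np.float32)
    input_name = sess.get_inputs()[0].name

    for _ in range(warmup):
        sess.run(None, {input_name: x})  # warm-up

    start = time.perf_counter()  # monotonic, high-resolution timer
    for _ in range(num_runs):
        sess.run(None, {input_name: x})

    return (time.perf_counter() - start) / num_runs

orig_time = benchmark('model.onnx', (1, 3, 416, 416))
quant_time = benchmark('model_quant.onnx', (1, 3, 416, 416))

print(f"Speedup: {orig_time / quant_time:.2f}x")
```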
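
A single averaged number hides run-to-run variance. A sketch that records per-run latency with `time.perf_counter` and reports the median and an approximate 90th percentile (nearest-rank, no interpolation); `time_runs` and `summarize` are illustrative helpers that wrap any callable, so they work with an ONNX Runtime session as shown in the comment.

```python
import statistics
import time

def time_runs(fn, num_runs=100, warmup=5):
    """Call fn() repeatedly and return per-run latencies in milliseconds."""
    for _ in range(warmup):
        fn()                                   # discard warm-up calls
    latencies = []
    for _ in range(num_runs):
        start = time.perf_counter()            # monotonic, high-resolution clock
        fn()
        latencies.append((time.perf_counter() - start) * 1000.0)
    return latencies

def summarize(latencies):
    """Median and ~p90 latency: less sensitive to one-off stalls than the mean."""
    ordered = sorted(latencies)
    p90 = ordered[min(len(ordered) - 1, int(0.9 * len(ordered)))]
    return {"median_ms": statistics.median(ordered), "p90_ms": p90}

# With ONNX Runtime (not executed here):
#   sess = ort.InferenceSession("model_quant.onnx", providers=["CPUExecutionProvider"])
#   stats = summarize(time_runs(lambda: sess.run(None, {input_name: x})))
print(summarize(time_runs(lambda: sum(range(10_000)), num_runs=50)))
```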

## Hardware-Specific

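**CPU (AVX2/AVX-512):**
- ORT runs U8U8, U8S8 and S8S8 (activation/weight signedness) on CPU
- ORT's docs recommend S8S8 with QDQ format as the first choice, falling back to U8U8 (`activation_type=QuantType.QUInt8`, `weight_type=QuantType.QUInt8`) if accuracy drops significantly

**GPU (TensorRT):**
- Needs Tensor Core INT8 support (e.g. T4, A100)
- ORT quantization on GPU supports S8S8 only
- With the TensorRT execution provider, TensorRT applies its own quantization logic: pass the full-precision model plus calibration results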

## Complete Workflow

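```python
import onnx
from onnxruntime.quantization import quantize_static, QuantType
from onnxruntime.quantization.shape_inference import quant_pre_process
# prune_model, MyCalibrationDataReader and benchmark are defined in the sections above

# 1. Prune
model = onnx.load('model.onnx')
pruned = prune_model(model, 0.01)
onnx.save(pruned, 'model_pruned.onnx')

# 2. Optimize
quant_pre_process('model_pruned.onnx', 'model_opt.onnx')

# 3. Calibrate & Quantize
calib_reader = MyCalibrationDataReader('calib_data/')
quantize_static(
    'model_opt.onnx',
    'model_quant.onnx',
    calib_reader,
    activation_type=QuantType.QUInt8,
    weight_type=QuantType.QUInt8
)

# 4. Benchmark
orig_time = benchmark('model.onnx', (1, 3, 416, 416))
quant_time = benchmark('model_quant.onnx', (1, 3, 416, 416))
print(f"Speedup: {orig_time / quant_time:.2f}x")

# 5. Validate accuracy on a held-out set before deploying
```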

## vs Dynamic Quantization

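| Aspect | Static | Dynamic |
|---|---|---|
| Calibration | Required | Not needed (activation ranges computed per input at runtime) |
| Setup | Complex | Simple |
| Speed | Faster (no runtime range computation) | Slower (computes activation ranges on the fly) |
| Accuracy | Good if calibration data is representative | Good (adapts to each input's range) |
| Hardware | GPU/TPU/CPU | CPU only |
| Typical fit | CNNs, production deployment | RNNs/transformers, prototyping |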
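
The core difference in the table: static quantization fixes activation scales offline from calibration data, while dynamic quantization recomputes them from each input at inference time (at a runtime cost). A pure-Python sketch of both on an input whose range is narrower than the calibration data; the helpers and numbers are illustrative, using asymmetric uint8. It also shows why static accuracy hinges on how well calibration matches real inputs.

```python
def uint8_params(lo, hi):
    """Asymmetric uint8 scale/zero-point for a [lo, hi] range (0 kept in range)."""
    lo, hi = min(lo, 0.0), max(hi, 0.0)
    scale = (hi - lo) / 255 or 1.0
    return scale, round(-lo / scale)

def roundtrip_error(xs, scale, zp):
    """Largest absolute error after quantize -> dequantize."""
    deq = [scale * (max(0, min(255, round(x / scale) + zp)) - zp) for x in xs]
    return max(abs(a - b) for a, b in zip(xs, deq))

calib_batches = [[0.0, 1.2, 3.5], [0.4, 2.8, 4.0]]
# Static: one scale fixed offline from calibration data, reused for every input
static = uint8_params(min(min(b) for b in calib_batches),
                      max(max(b) for b in calib_batches))

new_input = [0.1, 0.2, 0.3]                               # narrow range at inference time
dynamic = uint8_params(min(new_input), max(new_input))    # recomputed per input at runtime

print(roundtrip_error(new_input, *static))   # coarser: grid spans [0, 4]
print(roundtrip_error(new_input, *dynamic))  # finer: grid spans [0, 0.3]
```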

## When to Use

✅ Use static when:
- Deploying to production
- Need max speed
- Have calibration data
- Targeting GPU/TPU
- Working with CNNs

❌ Don't use when:
- No calibration data
- Quick prototyping
- Input distributions vary wildly

## Considerations

- Calibration data MUST be representative of real data
- Poor calibration = bad accuracy
- Models > 2GB need external data format
- Per-channel quant helps accuracy at a small memory cost (one scale/zero-point per channel)
- Always validate accuracy after quantizing

## Debugging

If accuracy drops:
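- Increase calibration dataset size (and check it is representative)
- Try per-channel quantization (`per_channel=True`)
- Try another `calibrate_method` (e.g. `CalibrationMethod.Entropy` or `CalibrationMethod.Percentile`) to reduce the impact of outliers
- Check whether certain layers are causing the drop, and keep them in FP32 via `nodes_to_exclude`
- Consider QAT (quantization-aware training) instead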
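
To find which layers cause the drop, compare each layer's quantized output against FP32 on the same input. A sketch of that comparison using signal-to-quantization-noise ratio (SQNR); capturing per-layer outputs (for example by exposing intermediate tensors as extra graph outputs) is not shown, and the helpers and toy numbers are illustrative. Low-SQNR layers are candidates for `nodes_to_exclude`.

```python
import math

def sqnr_db(reference, quantized):
    """Signal-to-quantization-noise ratio in dB: higher means closer to FP32."""
    signal = sum(r * r for r in reference)
    noise = sum((r - q) ** 2 for r, q in zip(reference, quantized))
    return float("inf") if noise == 0 else 10 * math.log10(signal / noise)

def worst_layers(fp32_outputs, quant_outputs, k=3):
    """Rank layers by SQNR, lowest first. Both args: {layer_name: flat list of values}."""
    scores = {name: sqnr_db(fp32_outputs[name], quant_outputs[name]) for name in fp32_outputs}
    return sorted(scores.items(), key=lambda item: item[1])[:k]

fp32 = {"conv1": [1.0, 2.0, 3.0], "conv2": [0.5, -0.5, 1.0], "head": [4.0, -2.0, 0.1]}
quant = {"conv1": [1.01, 1.99, 3.0], "conv2": [0.5, -0.5, 1.0], "head": [3.0, -1.0, 0.9]}
print(worst_layers(fp32, quant, k=2))   # "head" ranks first
```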

## Quick Reference

```bash
pip install onnxruntime onnxoptimizer opencv-python
```

```python
from onnxruntime.quantization import quantize_static, QuantType

quantize_static(
    'model.onnx',
    'model_quant.onnx',
    calibration_data_reader=my_reader,   # any CalibrationDataReader, e.g. MyCalibrationDataReader
    activation_type=QuantType.QUInt8,
    weight_type=QuantType.QUInt8,
    per_channel=False
)
```

Static = better performance, more work. Worth it for production deployment.