# Post-Training Quantization (PTQ)

**SOTA educational standard** | Covers static quantization, dynamic quantization, and calibration methods

---

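## 1. PTQ Overview

### 1.1 What Is PTQ?

**Definition**: Quantize the weights and activations of a model directly after training has finished.

**Advantages**: No retraining is required, so the model can be deployed quickly.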
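
To make the definition concrete, the following is a minimal sketch of symmetric INT8 quantization of a tensor followed by dequantization. The helper names `quantize_symmetric` and `dequantize` are illustrative assumptions, not part of any library, and the `int8` storage assumes `bits <= 8`.

```python
import torch

def quantize_symmetric(x: torch.Tensor, bits: int = 8):
    """Map a float tensor to signed integers with one scale (illustrative helper)."""
    qmax = 2 ** (bits - 1) - 1                        # 127 for INT8
    scale = max(x.abs().max().item() / qmax, 1e-12)   # guard against an all-zero tensor
    q = torch.clamp(torch.round(x / scale), -qmax, qmax).to(torch.int8)
    return q, scale

def dequantize(q: torch.Tensor, scale: float) -> torch.Tensor:
    """Recover an approximate float tensor from the integers."""
    return q.to(torch.float32) * scale

torch.manual_seed(0)
w = torch.randn(4, 8)
q, scale = quantize_symmetric(w)
w_hat = dequantize(q, scale)
max_err = (w - w_hat).abs().max().item()
print(f"scale={scale:.6f}, max abs error={max_err:.6f}")
```

Because every value is rounded to the nearest multiple of `scale`, the round-trip error of an unclipped value is at most half a quantization step.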

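### 1.2 PTQ Categories

| Type | Weights | Activations | Speed |
|:-----|:----:|:----:|:----:|
| Dynamic quantization | INT8 | FP32 (quantized on the fly at runtime) | Medium |
| Static quantization | INT8 | INT8 (range fixed by calibration) | Fast |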
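
The practical difference between the two rows is when the activation range is determined. The sketch below is a simplified illustration on raw tensors, not a framework API: the helpers `act_scale` and `fake_quant` are hypothetical. The static scale is fixed once from calibration batches, while the dynamic scale is recomputed from the tensor seen at inference time.

```python
import torch

def act_scale(x: torch.Tensor, bits: int = 8) -> float:
    """Symmetric scale from the observed range of x (illustrative helper)."""
    return max(x.abs().max().item() / (2 ** (bits - 1) - 1), 1e-12)

def fake_quant(x: torch.Tensor, scale: float, bits: int = 8) -> torch.Tensor:
    """Round to the integer grid, clip, and map back to float."""
    qmax = 2 ** (bits - 1) - 1
    return torch.clamp(torch.round(x / scale), -qmax, qmax) * scale

torch.manual_seed(0)
calib_batches = [torch.randn(32, 64) for _ in range(10)]
static_scale = max(act_scale(b) for b in calib_batches)   # fixed once, offline

x_new = torch.randn(32, 64) * 3.0        # wider than anything seen during calibration
dynamic_scale = act_scale(x_new)         # recomputed at inference time

err_static = (x_new - fake_quant(x_new, static_scale)).abs().max().item()
err_dynamic = (x_new - fake_quant(x_new, dynamic_scale)).abs().max().item()
print(f"static:  scale={static_scale:.5f}, max abs error={err_static:.4f}")
print(f"dynamic: scale={dynamic_scale:.5f}, max abs error={err_dynamic:.4f}")
```

The static scale clips inputs that fall outside the calibrated range, which is why representative calibration data matters. The dynamic scale adapts to each input but pays for the extra range computation on every forward pass.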

In [None]:
from __future__ import annotations
from dataclasses import dataclass
from typing import Callable, List, Optional
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

---

## 2. Calibrator Implementation

In [None]:
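class MinMaxCalibrator:
    """Min-max calibrator.

    Core idea: track the running minimum and maximum of the observed
    activations and use them as the quantization range.
    """

    def __init__(self):
        self.min_val = float('inf')
        self.max_val = float('-inf')

    def update(self, x: Tensor) -> None:
        """Update the running min/max with a new batch."""
        self.min_val = min(self.min_val, x.min().item())
        self.max_val = max(self.max_val, x.max().item())

    def compute_scale(self, bits: int = 8, symmetric: bool = True) -> tuple:
        """Compute the quantization parameters (scale, zero_point)."""
        eps = 1e-12  # avoids a zero scale when the observed range is degenerate
        if symmetric:
            # Signed range [-(2^(bits-1) - 1), 2^(bits-1) - 1], zero-point fixed at 0
            max_abs = max(abs(self.min_val), abs(self.max_val))
            scale = max(max_abs / (2 ** (bits - 1) - 1), eps)
            zero_point = 0
        else:
            # Unsigned range [0, 2^bits - 1]; min_val maps to integer 0
            scale = max((self.max_val - self.min_val) / (2 ** bits - 1), eps)
            zero_point = round(-self.min_val / scale)
        return scale, zero_point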


# Test
calibrator = MinMaxCalibrator()
for _ in range(10):
    x = torch.randn(32, 64) * 0.5
    calibrator.update(x)

scale, zp = calibrator.compute_scale()
print(f"Range: [{calibrator.min_val:.4f}, {calibrator.max_val:.4f}]")
print(f"Scale: {scale:.6f}, Zero-point: {zp}")
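
The test above only exercises the symmetric branch. As a small worked example, the sketch below applies the same formulas as `compute_scale` to a hand-picked, skewed range of [-0.5, 2.05]. The standalone function `minmax_qparams` and the range itself are assumptions made for illustration.

```python
def minmax_qparams(min_val: float, max_val: float, bits: int = 8, symmetric: bool = True):
    """Same formulas as MinMaxCalibrator.compute_scale, as a standalone function."""
    if symmetric:
        scale = max(abs(min_val), abs(max_val)) / (2 ** (bits - 1) - 1)
        return scale, 0
    scale = (max_val - min_val) / (2 ** bits - 1)
    return scale, round(-min_val / scale)

# Hand-picked skewed range: most of the mass is positive
min_val, max_val = -0.5, 2.05

s_sym, zp_sym = minmax_qparams(min_val, max_val, symmetric=True)
s_asym, zp_asym = minmax_qparams(min_val, max_val, symmetric=False)
print(f"symmetric:  scale={s_sym:.6f}, zero_point={zp_sym}")
print(f"asymmetric: scale={s_asym:.6f}, zero_point={zp_asym}")
```

For a skewed range the asymmetric scheme yields a smaller step, because all 256 levels cover [min, max] instead of a range mirrored around zero. The cost is a non-zero zero-point that the integer kernels must account for.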

---

## 3. Percentile Calibrator

In [None]:
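class PercentileCalibrator:
    """Percentile calibrator.

    Core idea: use percentiles instead of the extreme values, which is
    more robust to outliers.
    """

    def __init__(self, percentile: float = 99.9):
        self.percentile = percentile
        self.values: List[Tensor] = []

    def update(self, x: Tensor) -> None:
        """Store a flattened copy of the batch (memory grows with every call)."""
        self.values.append(x.detach().flatten().clone())

    def compute_scale(self, bits: int = 8, symmetric: bool = True) -> tuple:
        """Compute the quantization parameters (scale, zero_point)."""
        # torch.quantile may raise an error on very large tensors; subsample if needed
        all_values = torch.cat(self.values)
        low = torch.quantile(all_values, (100 - self.percentile) / 100)
        high = torch.quantile(all_values, self.percentile / 100)

        eps = 1e-12  # avoids a zero scale when the clipped range is degenerate
        if symmetric:
            max_abs = max(abs(low.item()), abs(high.item()))
            scale = max(max_abs / (2 ** (bits - 1) - 1), eps)
            zero_point = 0
        else:
            scale = max((high.item() - low.item()) / (2 ** bits - 1), eps)
            zero_point = round(-low.item() / scale)
        return scale, zero_point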


# Test
calibrator = PercentileCalibrator(percentile=99.9)
for _ in range(10):
    x = torch.randn(32, 64) * 0.5
    calibrator.update(x)

scale, zp = calibrator.compute_scale()
print(f"Percentile Scale: {scale:.6f}, Zero-point: {zp}")
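
The robustness claim is easier to see with an outlier in the data. The sketch below is an illustration with synthetic data and a hypothetical `fake_quant` helper. It injects a single large value into otherwise Gaussian activations and compares the quantization error on the remaining values under a min-max range and a 99.9th-percentile range.

```python
import torch

torch.manual_seed(0)
x = torch.randn(10_000) * 0.5
x[0] = 50.0                      # a single injected outlier

qmax = 2 ** (8 - 1) - 1          # 127 for symmetric INT8

# Min-max: the range is dictated by the outlier
scale_minmax = x.abs().max().item() / qmax

# Percentile: clip to the 0.1th / 99.9th percentiles, as PercentileCalibrator does
low = torch.quantile(x, 0.001).item()
high = torch.quantile(x, 0.999).item()
scale_pct = max(abs(low), abs(high)) / qmax

def fake_quant(t: torch.Tensor, scale: float) -> torch.Tensor:
    """Round to the integer grid, clip, and map back to float."""
    return torch.clamp(torch.round(t / scale), -qmax, qmax) * scale

inliers = x[1:]                  # every value except the outlier
mse_minmax = ((inliers - fake_quant(inliers, scale_minmax)) ** 2).mean().item()
mse_pct = ((inliers - fake_quant(inliers, scale_pct)) ** 2).mean().item()
print(f"min-max:    scale={scale_minmax:.5f}, inlier MSE={mse_minmax:.6f}")
print(f"percentile: scale={scale_pct:.5f}, inlier MSE={mse_pct:.6f}")
```

With min-max, one outlier stretches the step size for every value. The percentile range clips the outlier and a small fraction of the tail, and in exchange gives the bulk of the distribution a much finer grid.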

---

## 4. PTQ Workflow

In [None]:
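class PTQQuantizer:
    """PTQ quantizer."""

    def __init__(self, model: nn.Module, bits: int = 8):
        self.model = model
        self.bits = bits
        self.calibrators = {}
        self.scales = {}
        self.zero_points = {}

    def calibrate(self, dataloader, num_batches: int = 100) -> None:
        """Calibration: collect activation statistics."""
        self.model.eval()

        # Create a calibrator for each quantizable layer and attach a forward
        # hook so that the layer's output activations reach the calibrator.
        handles = []
        for name, module in self.model.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                cal = MinMaxCalibrator()
                self.calibrators[name] = cal
                handles.append(module.register_forward_hook(
                    lambda mod, inp, out, cal=cal: cal.update(out.detach())
                ))

        # Collect statistics
        try:
            with torch.no_grad():
                for i, (x, _) in enumerate(dataloader):
                    if i >= num_batches:
                        break
                    _ = self.model(x)
        finally:
            # Remove the hooks so later forward passes are unaffected
            for handle in handles:
                handle.remove()

        # Compute quantization parameters
        for name, cal in self.calibrators.items():
            scale, zp = cal.compute_scale(self.bits)
            self.scales[name] = scale
            self.zero_points[name] = zp

    def quantize_weights(self) -> None:
        """Fake-quantize weights in place (quantize, then dequantize to float)."""
        qmax = 2 ** (self.bits - 1) - 1
        for name, module in self.model.named_modules():
            # Cover the same layer types that calibrate() observes
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                w = module.weight.data
                max_abs = w.abs().max().item()
                if max_abs == 0:
                    continue  # all-zero weights need no quantization
                scale = max_abs / qmax
                w_q = torch.clamp(torch.round(w / scale), -qmax, qmax) * scale
                module.weight.data = w_q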


# Test
model = nn.Sequential(nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 10))
quantizer = PTQQuantizer(model, bits=8)
quantizer.quantize_weights()
print("PTQ weight quantization complete")
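
The test above confirms that weight quantization runs, but not how much it changes the model's outputs. The sketch below is one way to check that, under stated assumptions: `fake_quantize_weights` is a hypothetical standalone helper that mirrors the per-tensor symmetric scheme of `quantize_weights`, and the input is random data rather than a real validation set.

```python
import copy
import torch
import torch.nn as nn

def fake_quantize_weights(model: nn.Module, bits: int = 8) -> None:
    """Per-tensor symmetric weight fake-quantization (illustrative helper)."""
    qmax = 2 ** (bits - 1) - 1
    for module in model.modules():
        if isinstance(module, nn.Linear):
            w = module.weight.data
            scale = max(w.abs().max().item() / qmax, 1e-12)
            module.weight.data = torch.clamp(torch.round(w / scale), -qmax, qmax) * scale

torch.manual_seed(0)
fp32_model = nn.Sequential(nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 10)).eval()
x = torch.randn(128, 64)          # random stand-in for validation inputs

errors = {}
with torch.no_grad():
    ref = fp32_model(x)
    for bits in (8, 4):
        q_model = copy.deepcopy(fp32_model)   # keep the FP32 model untouched
        fake_quantize_weights(q_model, bits)
        errors[bits] = (q_model(x) - ref).abs().max().item()

for bits, err in errors.items():
    print(f"INT{bits} weights: max abs output difference = {err:.6f}")
```

Comparing against a deep copy keeps the original FP32 weights available as a reference. On a real model, a task metric such as accuracy on held-out data would be a more meaningful check than the raw output difference.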

---

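## 5. Summary

| Method | Advantages | Disadvantages |
|:-----|:-----|:-----|
| MinMax | Simple and fast | Sensitive to outliers |
| Percentile | Robust to outliers | Needs more memory (stores the observed values) |
| Histogram | Often the most accurate range selection | Computationally expensive |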
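
The histogram method appears in the table but is not implemented in this notebook. As a rough sketch of one possible variant, the code below builds a histogram of absolute values and searches for the symmetric clipping threshold that minimizes the estimated quantization MSE. The function `histogram_mse_threshold`, the bin count, and the search grid are all assumptions made for illustration. Other criteria, such as entropy (KL divergence), are also used in practice.

```python
import torch

def histogram_mse_threshold(x: torch.Tensor, bits: int = 8, num_bins: int = 2048) -> float:
    """Pick a clipping threshold that minimizes the MSE estimated from a histogram of |x|."""
    qmax = 2 ** (bits - 1) - 1
    abs_x = x.detach().abs().flatten().float()
    max_val = abs_x.max().item()
    if max_val == 0:
        return 0.0
    hist = torch.histc(abs_x, bins=num_bins, min=0.0, max=max_val)
    bin_width = max_val / num_bins
    centers = (torch.arange(num_bins, dtype=torch.float32) + 0.5) * bin_width

    best_t, best_err = max_val, float("inf")
    start, step = max(1, num_bins // 16), max(1, num_bins // 128)
    # Candidate thresholds sit on bin edges, up to the full observed range
    for i in range(start, num_bins + 1, step):
        t = i * bin_width
        scale = t / qmax
        # Quantize each bin center; values beyond t are clipped to t
        q = torch.clamp(torch.round(centers / scale), max=qmax) * scale
        err = (hist * (centers - q) ** 2).sum().item()
        if err < best_err:
            best_t, best_err = t, err
    return best_t

torch.manual_seed(0)
x = torch.randn(100_000) * 0.5
x[0] = 8.0                                   # a single injected outlier
t = histogram_mse_threshold(x)
print(f"min-max threshold: {x.abs().max().item():.3f}, histogram threshold: {t:.3f}")
```

The search balances two error sources. A smaller threshold reduces rounding error for the bulk of the values but increases clipping error for the tail. The extra cost relative to min-max comes from building the histogram and evaluating every candidate threshold.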