# Quantization-Aware Training (QAT)

**SOTA educational standard** | Covers QAT principles, fake quantization, STE gradient estimation, and a complete training workflow

---

## 1. Core Principles of QAT

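### 1.1 Why QAT?

**Limitations of post-training quantization (PTQ)**:
- Large accuracy loss at low bit widths (INT4/INT2)
- Sensitive to outliers
- No training signal to compensate for the error that quantization introduces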
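A small sketch of the first two points, assuming per-tensor symmetric quantization with the scale calibrated from $\max|x|$ (one common PTQ choice; `quantize_symmetric` is a hypothetical helper). Dropping from 8 to 4 bits raises the error sharply, and a single outlier stretches the scale so much that at 4 bits the ordinary values all collapse to zero.

```python
import torch


def quantize_symmetric(x: torch.Tensor, bits: int) -> torch.Tensor:
    """Hypothetical helper: per-tensor symmetric fake quantization, scale = max|x| / qmax."""
    qmax = 2 ** (bits - 1) - 1
    scale = x.abs().max() / qmax  # a single outlier inflates the step size for all values
    return torch.clamp(torch.round(x / scale), -qmax - 1, qmax) * scale


torch.manual_seed(0)
w = torch.randn(10_000)
w_outlier = w.clone()
w_outlier[0] = 50.0  # inject one large outlier

results = {}
for bits in (8, 4):
    mse_clean = torch.mean((quantize_symmetric(w, bits) - w) ** 2).item()
    # Error measured on the ordinary (non-outlier) entries only
    err = quantize_symmetric(w_outlier, bits)[1:] - w_outlier[1:]
    mse_outlier = torch.mean(err ** 2).item()
    results[bits] = (mse_clean, mse_outlier)
    print(f"INT{bits}: MSE without outlier = {mse_clean:.2e}, with outlier = {mse_outlier:.2e}")
```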

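**Advantages of QAT**:
- Simulates the effect of quantization during training
- The network learns to adapt to quantization noise
- Can often reach accuracy close to FP32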

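### 1.2 Fake Quantization

**Core idea**: simulate quantization in the forward pass and use the STE in the backward pass.

For symmetric quantization (zero point 0) with scale $s$ and integer range $[q_{min}, q_{max}]$:

$$\text{FakeQuant}(x) = s \cdot \text{clip}(\text{round}(x/s), q_{min}, q_{max})$$

- **Forward**: quantize, then dequantize (the output stays in floating point but carries the quantization error)
- **Backward**: Straight-Through Estimator (STE)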
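A minimal sketch of the forward pass in the symmetric case, assuming a 4-bit range $[q_{min}, q_{max}] = [-8, 7]$ and $s = 0.1$ chosen for illustration (`fake_quant` is a hypothetical helper name). For example, $0.37/0.1 = 3.7$ rounds to $4$ and dequantizes to $0.4$, while $1.0/0.1 = 10$ is clipped to $7$ and saturates at $0.7$.

```python
import torch


def fake_quant(x: torch.Tensor, s: float, qmin: int, qmax: int) -> torch.Tensor:
    """FakeQuant(x) = s * clip(round(x / s), qmin, qmax), symmetric case (zero point 0)."""
    x_int = torch.clamp(torch.round(x / s), qmin, qmax)  # quantize to integer levels
    return x_int * s                                     # dequantize back to float


x = torch.tensor([0.37, -0.04, 1.0])
y = fake_quant(x, s=0.1, qmin=-8, qmax=7)  # 4-bit signed range
print(y)      # approximately [0.4, 0.0, 0.7]
print(y - x)  # the simulated quantization error
```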

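### 1.3 STE Gradient Estimation

**Problem**: `round()` is piecewise constant, so its gradient is zero almost everywhere (and undefined at the jumps); backpropagating through it gives no useful signal.

**Solution**: the Straight-Through Estimator

$$\frac{\partial \text{FakeQuant}(x)}{\partial x} \approx \mathbf{1}_{x \in [q_{min} \cdot s, q_{max} \cdot s]}$$

That is: inside the quantization range the gradient passes straight through; outside it, the gradient is zero.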
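The problem and the fix can be compared directly. The sketch below first confirms that autograd assigns `torch.round` a zero gradient, then implements the clipped STE with the common "detach trick", an alternative to the custom `torch.autograd.Function` used in the next section. `fake_quant_ste` is a hypothetical name, and $s = 1$, $[q_{min}, q_{max}] = [-2, 2]$ are illustrative values.

```python
import torch

x = torch.tensor([0.26, -1.3, 2.7], requires_grad=True)

# 1) Plain round(): autograd treats its derivative as 0, so nothing flows back
torch.round(x).sum().backward()
round_grad = x.grad.clone()
print(round_grad)  # tensor([0., 0., 0.])


# 2) Clipped STE via the detach trick
def fake_quant_ste(x: torch.Tensor, s: float, qmin: int, qmax: int) -> torch.Tensor:
    x_q = torch.clamp(torch.round(x / s), qmin, qmax) * s  # forward value
    x_c = torch.clamp(x, qmin * s, qmax * s)               # carries the gradient
    # Forward returns x_q; backward only sees d(x_c)/dx: 1 inside the range, 0 outside
    return x_c + (x_q - x_c).detach()


x.grad = None
y = fake_quant_ste(x, s=1.0, qmin=-2, qmax=2)
y.sum().backward()
print(y.detach())  # approximately [0., -1., 2.]
print(x.grad)      # tensor([1., 1., 0.]): 2.7 lies outside [-2, 2]
```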

In [None]:
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

---

## 2. STE Implementation

In [None]:
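class StraightThroughEstimator(torch.autograd.Function):
    """Fake quantization with a straight-through estimator (STE).

    Core Idea:
        The forward pass quantizes and dequantizes; the backward pass passes the
        gradient straight through wherever the input lies inside the quantization range.

    Mathematical Theory:
        Forward:  y = (clip(round(x / s + z), qmin, qmax) - z) * s
        Backward: dx = dy * 1[qmin <= x / s + z <= qmax]   (clipped STE)

    Summary:
        This gradient approximation is what makes QAT trainable.
    """

    @staticmethod
    def forward(ctx, x: Tensor, scale: Tensor, zero_point: Tensor, qmin: int, qmax: int) -> Tensor:
        # Quantize
        x_int = torch.round(x / scale + zero_point)
        x_int = torch.clamp(x_int, qmin, qmax)
        # Dequantize
        x_q = (x_int - zero_point) * scale
        # Save what the backward pass needs
        ctx.save_for_backward(x, scale, zero_point)
        ctx.qmin = qmin
        ctx.qmax = qmax
        return x_q

    @staticmethod
    def backward(ctx, grad_output: Tensor) -> Tuple[Tensor, None, None, None, None]:
        x, scale, zero_point = ctx.saved_tensors
        # STE: pass the gradient through inside the quantization range, zero outside.
        # The zero point must be included, otherwise the mask is wrong for asymmetric quantization.
        x_normalized = x / scale + zero_point
        mask = (x_normalized >= ctx.qmin) & (x_normalized <= ctx.qmax)
        grad_input = grad_output * mask.to(grad_output.dtype)
        # scale, zero_point, qmin and qmax receive no gradient
        return grad_input, None, None, None, None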


# Test the STE
x = torch.randn(4, requires_grad=True)
scale = torch.tensor(0.1)
zero_point = torch.tensor(0.0)

x_q = StraightThroughEstimator.apply(x, scale, zero_point, -128, 127)
loss = x_q.sum()
loss.backward()

print(f"Input: {x.data}")
print(f"Quantized: {x_q.data}")
print(f"Gradient: {x.grad}")
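Every `torch.randn` value in that test lies well inside $[-12.8, 12.7]$, so the clipping branch of the mask is never exercised. PyTorch also provides a built-in op with a clipped-STE backward, `torch.fake_quantize_per_tensor_affine`. The sketch below uses it with asymmetric UINT8 parameters picked for illustration ($s = 0.01$, $z = 64$, so the representable range is $[-0.64, 1.91]$) to show an out-of-range entry being saturated and receiving zero gradient.

```python
import torch

x = torch.tensor([-0.5, 0.0, 0.3, 1.2, 3.0], requires_grad=True)
scale, zero_point = 0.01, 64  # illustrative asymmetric parameters, range [-0.64, 1.91]

# Built-in fake quantization: quantize, clamp to [0, 255], dequantize
y = torch.fake_quantize_per_tensor_affine(x, scale, zero_point, 0, 255)
y.sum().backward()

print(y.detach())  # 3.0 saturates at (255 - 64) * 0.01 = 1.91
print(x.grad)      # 1 inside the range, 0 for the clipped entry
```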

---

## 3. QAT Module Implementation

In [None]:
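class FakeQuantize(nn.Module):
    """Fake quantization module.

    Core Idea:
        Simulate quantization during training so the network adapts to quantization noise.

    Mathematical Theory:
        y = scale * (clip(round(x / scale + zp), qmin, qmax) - zp)

    Complexity:
        Time: O(n)
        Space: O(1) extra state (scale and zero point)
    """

    def __init__(self, bits: int = 8, symmetric: bool = True):
        super().__init__()
        self.bits = bits
        self.symmetric = symmetric

        if symmetric:
            self.qmin = -(2 ** (bits - 1))
            self.qmax = 2 ** (bits - 1) - 1
        else:
            self.qmin = 0
            self.qmax = 2 ** bits - 1

        # Quantization parameters: buffers calibrated from data, not learned by the optimizer
        self.register_buffer("scale", torch.tensor(1.0))
        self.register_buffer("zero_point", torch.tensor(0.0))
        self.register_buffer("initialized", torch.tensor(False))

    def update_params(self, x: Tensor) -> None:
        """Calibrate scale and zero point from the min/max of x."""
        with torch.no_grad():
            # Extend the range to include 0 so that 0.0 is exactly representable
            x_min = torch.clamp(x.min(), max=0.0)
            x_max = torch.clamp(x.max(), min=0.0)
            if self.symmetric:
                max_abs = torch.maximum(x_min.abs(), x_max.abs())
                # Guard against a zero scale when x is all zeros
                self.scale.copy_(torch.clamp(max_abs / self.qmax, min=1e-8))
                self.zero_point.fill_(0.0)
            else:
                scale = torch.clamp((x_max - x_min) / (self.qmax - self.qmin), min=1e-8)
                self.scale.copy_(scale)
                # The zero point must be an integer inside [qmin, qmax]
                zp = torch.round(self.qmin - x_min / scale).clamp(self.qmin, self.qmax)
                self.zero_point.copy_(zp)
            self.initialized.fill_(True)

    def forward(self, x: Tensor) -> Tensor:
        # Calibrate once, on the first training batch; the range is kept fixed afterwards
        if self.training and not self.initialized:
            self.update_params(x)
        return StraightThroughEstimator.apply(
            x, self.scale, self.zero_point, self.qmin, self.qmax
        )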
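`FakeQuantize` calibrates its range once, from the first training batch, and then keeps it fixed even though weights and activations keep changing during training. A common alternative, used for example by PyTorch's moving-average observers, is to track an exponential moving average (EMA) of the min/max. Below is a minimal sketch; the class name `EMAMinMaxObserver`, the momentum of 0.1, and the ReLU-shaped stand-in activations are assumptions for illustration.

```python
from __future__ import annotations

from typing import Optional

import torch
from torch import Tensor


class EMAMinMaxObserver:
    """Hypothetical helper: EMA of the running min/max, giving asymmetric (UINT8-style) qparams."""

    def __init__(self, bits: int = 8, momentum: float = 0.1) -> None:
        self.qmin, self.qmax = 0, 2 ** bits - 1
        self.momentum = momentum
        self.min_val: Optional[Tensor] = None
        self.max_val: Optional[Tensor] = None

    def update(self, x: Tensor) -> None:
        x_min, x_max = x.detach().min(), x.detach().max()
        if self.min_val is None or self.max_val is None:
            # The first batch initializes the range
            self.min_val, self.max_val = x_min, x_max
        else:
            # Later batches move the range only a fraction `momentum` of the way
            m = self.momentum
            self.min_val = (1 - m) * self.min_val + m * x_min
            self.max_val = (1 - m) * self.max_val + m * x_max

    def qparams(self) -> tuple[float, int]:
        assert self.min_val is not None and self.max_val is not None, "call update() first"
        # Keep 0 inside the range so that 0.0 maps to an exact integer
        lo = min(self.min_val.item(), 0.0)
        hi = max(self.max_val.item(), 0.0)
        scale = max((hi - lo) / (self.qmax - self.qmin), 1e-8)
        zero_point = int(round(self.qmin - lo / scale))
        return scale, min(max(zero_point, self.qmin), self.qmax)


torch.manual_seed(0)
obs = EMAMinMaxObserver(bits=8, momentum=0.1)
for _ in range(100):
    obs.update(torch.relu(torch.randn(256)))  # stand-in for post-ReLU activations

scale, zp = obs.qparams()
x = torch.relu(torch.randn(256))
x_fq = torch.fake_quantize_per_tensor_affine(x, scale, zp, obs.qmin, obs.qmax)
print(f"scale={scale:.5f}, zero_point={zp}")
```

Because post-ReLU activations are non-negative, the zero point comes out as 0 and the whole UINT8 range is spent on $[0, \text{max}]$; a symmetric INT8 scheme would waste half its levels on negative values here.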

In [None]:
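class QATLinear(nn.Module):
    """QAT linear layer.

    Core Idea:
        Fake-quantize both the weights and the input activations during training.

    Summary:
        The core building block of QAT. After training, its weights and calibrated
        scales can be used to convert it to a real INT8 layer (conversion not shown here).
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True, bits: int = 8):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.weight_fake_quant = FakeQuantize(bits=bits, symmetric=True)
        self.activation_fake_quant = FakeQuantize(bits=bits, symmetric=False)

    def forward(self, x: Tensor) -> Tensor:
        # Fake-quantize the weights (symmetric, zero point 0)
        w_q = self.weight_fake_quant(self.linear.weight)
        # Fake-quantize the input activations (asymmetric)
        x_q = self.activation_fake_quant(x)
        # Linear op in float; the bias is left unquantized
        return F.linear(x_q, w_q, self.linear.bias)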


# Test
qat_layer = QATLinear(64, 32)
x = torch.randn(8, 64)
y = qat_layer(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {y.shape}")
print(f"Weight scale: {qat_layer.weight_fake_quant.scale.item():.6f}")
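Fake quantization is only useful if it matches what integer hardware will compute. The sketch below, with shapes and per-tensor parameters assumed for illustration, quantizes a weight to symmetric INT8 and an activation to asymmetric UINT8, then checks that an integer matrix multiply followed by a single rescale by $s_x s_w$ reproduces the fake-quantized float result up to floating-point rounding. The bias is kept in float here, whereas deployed kernels typically store it as INT32.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 16)        # activations
w = torch.randn(8, 16) * 0.1  # weights of an 8x16 linear layer
b = torch.randn(8) * 0.1      # bias, kept in float in this sketch

# Weights: symmetric INT8, zero point 0
w_scale = (w.abs().max() / 127).item()
w_int = torch.clamp(torch.round(w / w_scale), -128, 127).to(torch.int8)

# Activations: asymmetric UINT8 with an integer zero point
lo, hi = min(x.min().item(), 0.0), max(x.max().item(), 0.0)
x_scale = (hi - lo) / 255
x_zp = int(round(-lo / x_scale))
x_int = torch.clamp(torch.round(x / x_scale) + x_zp, 0, 255).to(torch.uint8)

# Reference: what a fake-quantized layer computes (dequantize, then float linear)
y_fake = F.linear((x_int.float() - x_zp) * x_scale, w_int.float() * w_scale, b)

# Integer path: integer dot products, then one float rescale.
# int64 is used for portability; integer kernels typically accumulate in int32.
acc = (x_int.to(torch.int64) - x_zp) @ w_int.to(torch.int64).t()
y_int = acc.float() * (x_scale * w_scale) + b

max_err = (y_int - y_fake).abs().max().item()
print(f"max |y_int - y_fake| = {max_err:.2e}")
```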

---

## 4. QAT Training Workflow

In [None]:
class SimpleQATModel(nn.Module):
    """A simple QAT model example."""

    def __init__(self, input_dim: int = 784, hidden_dim: int = 256, output_dim: int = 10):
        super().__init__()
        self.fc1 = QATLinear(input_dim, hidden_dim)
        self.fc2 = QATLinear(hidden_dim, output_dim)
        self.relu = nn.ReLU()

    def forward(self, x: Tensor) -> Tensor:
        x = x.view(x.size(0), -1)
        x = self.relu(self.fc1(x))
        return self.fc2(x)


def train_qat_model(model: nn.Module, epochs: int = 5) -> list[float]:
    """QAT training loop (one random mini-batch per "epoch", for demonstration only)."""
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    losses = []

    # Simulated training data: random inputs and random labels, so the loss is not expected to decrease
    for epoch in range(epochs):
        x = torch.randn(32, 784)
        y = torch.randint(0, 10, (32,))

        optimizer.zero_grad()
        output = model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()

        losses.append(loss.item())
        print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}")

    return losses


# Train
model = SimpleQATModel()
losses = train_qat_model(model, epochs=10)

# Visualize
plt.figure(figsize=(8, 4))
plt.plot(losses, 'b-o')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('QAT Training Loss')
plt.grid(True, alpha=0.3)
plt.show()
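Because the loop above draws a fresh batch of random inputs with random labels at every step, there is nothing to learn and its loss curve mostly reflects noise. A quick way to check that gradients really reach the weights through the fake quantizers is to overfit one fixed batch. The sketch below is self-contained: `fq_weight` and `TinyQATMLP` are hypothetical names, it fake-quantizes weights only, and it uses the detach form of the STE instead of a custom `autograd.Function`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def fq_weight(w: torch.Tensor, bits: int = 8) -> torch.Tensor:
    """Hypothetical helper: symmetric per-tensor weight fake quantization with an STE."""
    qmax = 2 ** (bits - 1) - 1
    s = (w.detach().abs().max() / qmax).clamp(min=1e-8)
    w_q = torch.clamp(torch.round(w.detach() / s), -qmax - 1, qmax) * s
    # Forward value is w_q; the gradient w.r.t. w is the identity, because the scale
    # covers the full weight range, so nothing is clipped and the clipped STE reduces to 1
    return w + (w_q - w).detach()


class TinyQATMLP(nn.Module):
    """Hypothetical two-layer MLP that fake-quantizes its weights on every forward pass."""

    def __init__(self, bits: int = 8) -> None:
        super().__init__()
        self.bits = bits
        self.fc1 = nn.Linear(20, 64)
        self.fc2 = nn.Linear(64, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = F.relu(F.linear(x, fq_weight(self.fc1.weight, self.bits), self.fc1.bias))
        return F.linear(h, fq_weight(self.fc2.weight, self.bits), self.fc2.bias)


torch.manual_seed(0)
x = torch.randn(64, 20)         # one fixed batch
y = torch.randint(0, 4, (64,))  # fixed labels
model = TinyQATMLP(bits=8)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

losses = []
for step in range(200):
    optimizer.zero_grad()
    loss = F.cross_entropy(model(x), y)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first loss: {losses[0]:.3f}, last loss: {losses[-1]:.3f}")
```

If the loss on a fixed batch does not fall, the usual suspect is a quantizer that blocks gradients (for example, a `round()` without an STE).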

---

## 5. Summary

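| Concept | Formula / Method | Description |
|:-----|:---------|:-----|
| **Fake quantization** | $\text{FQ}(x) = s \cdot \text{clip}(\text{round}(x/s), q_{min}, q_{max})$ | Simulates the quantization error in the forward pass |
| **STE** | $\frac{\partial y}{\partial x} = 1$ inside the range, $0$ outside | Straight-through gradient estimate |
| **QAT advantage** | Adapts to quantization noise during training | Accuracy often close to FP32 |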

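**Key points**:
1. QAT simulates the effect of quantization during training.
2. The STE works around the fact that `round()` has zero gradient almost everywhere.
3. QAT is particularly useful for low-bit quantization (INT4/INT2), where PTQ loses the most accuracy.
4. After training, the model can be converted to a real quantized model (the conversion step is not shown in this notebook).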