# Quantization Fundamentals and Mathematical Principles

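**SOTA educational standard** | Covers information-theoretic foundations, quantization error analysis, and derivation of the uniform quantization formulas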

---

## 1. Information-Theoretic Foundations

### 1.1 Why Quantize?

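**Core problem**: FP32 models are expensive to store and to run. Every parameter occupies 4 bytes, so memory footprint and memory traffic grow directly with model size.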
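To make the storage cost concrete, the sketch below computes weight memory at the bit widths used in this notebook for a hypothetical model with 7 billion parameters (the parameter count is an illustrative assumption). It counts weights only and ignores activations and per-tensor metadata such as scales and zero-points.

```python
# Bytes per parameter for each storage format (weights only)
BYTES_PER_PARAM = {"fp32": 4.0, "int8": 1.0, "int4": 0.5}


def weight_memory_gb(num_params: int, fmt: str) -> float:
    """Return weight storage in GB (1 GB = 1e9 bytes) for the given format."""
    return num_params * BYTES_PER_PARAM[fmt] / 1e9


num_params = 7_000_000_000  # hypothetical 7B-parameter model
for fmt in BYTES_PER_PARAM:
    print(f"{fmt}: {weight_memory_gb(num_params, fmt):.1f} GB")
```

Going from FP32 to INT8 shrinks weight storage by 4×, and INT4 by 8×.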

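**What quantization is**: representing continuous values with a finite set of discrete levels. Because many inputs map to the same level, the original values cannot be recovered exactly, so quantization is a form of **lossy compression**.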
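A minimal illustration of the many-to-one mapping with plain Python floats, assuming a step size of 0.1 and a zero-point of 0 (both chosen only for this example): three distinct inputs fall into the same step and come back as the same value.

```python
step = 0.1  # assumed quantization step (scale), illustrative only


def quantize(x: float) -> int:
    """Map a float to an integer code by rounding to the nearest step."""
    return round(x / step)


def dequantize(q: int) -> float:
    """Map an integer code back to its representative float."""
    return q * step


inputs = [0.12, 0.14, 0.149]
codes = [quantize(v) for v in inputs]
print(codes)  # -> [1, 1, 1]: all three inputs share one code
print([dequantize(q) for q in codes])  # every input is reconstructed as 0.1
```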

### 1.2 Quantization Error

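**Quantization noise power**: for a uniform quantizer with step size $\Delta$ (the scale), if the rounding error $e$ is uniformly distributed on $[-\Delta/2, \Delta/2]$, then $\sigma_q^2 = \frac{1}{\Delta}\int_{-\Delta/2}^{\Delta/2} e^2\,de = \Delta^2/12$.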
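An empirical check of $\Delta^2/12$, assuming a step of $\Delta = 0.05$ and inputs spread uniformly over many steps so that the uniform-error assumption holds (the values are illustrative). The quantizer here is plain round-to-nearest with no clipping.

```python
import numpy as np

rng = np.random.default_rng(0)
delta = 0.05  # assumed step size for this experiment
x = rng.uniform(-1.0, 1.0, size=1_000_000)  # inputs spread over many steps

x_hat = np.round(x / delta) * delta  # round-to-nearest uniform quantizer, no clipping
err = x - x_hat

print(f"measured noise power: {err.var():.3e}")
print(f"theory Delta^2 / 12 : {delta ** 2 / 12:.3e}")
```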

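**Signal-to-quantization-noise ratio (SQNR)**: for a fixed input range, each additional bit halves $\Delta$, which cuts $\sigma_q^2$ by a factor of 4, so SQNR improves by $10\log_{10} 4 \approx 6.02$ dB per bit.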
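To check the 6 dB-per-bit rule numerically, the sketch below measures SQNR for a full-scale uniform test signal on $[-1, 1)$. It assumes a mid-rise quantizer with exactly $2^b$ levels (reconstruction points at odd multiples of $\Delta/2$), a slight variant of the round-to-nearest quantizer used later in this notebook, chosen because it tiles the range evenly. For this signal, theory predicts about $6.02\,b$ dB. The helper name `sqnr_db` is hypothetical.

```python
import numpy as np


def sqnr_db(x: np.ndarray, bits: int) -> float:
    """SQNR in dB of a mid-rise uniform quantizer with 2**bits levels over [-1, 1)."""
    levels = 2 ** bits
    delta = 2.0 / levels  # step size covering the full range [-1, 1)
    q = np.floor(x / delta)  # integer index of the step containing x
    q = np.clip(q, -levels // 2, levels // 2 - 1)  # stay within the 2**bits levels
    x_hat = (q + 0.5) * delta  # reconstruct at the centre of the step
    noise = x - x_hat
    return 10 * np.log10(np.mean(x ** 2) / np.mean(noise ** 2))


rng = np.random.default_rng(0)
x = rng.uniform(-1.0, 1.0, size=1_000_000)  # full-scale uniform test signal
for b in (4, 6, 8):
    print(f"{b} bits: SQNR = {sqnr_db(x, b):.2f} dB")
```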

```python
from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")
```

---

## 2. Defining Quantization Schemes

```python
class QuantizationScheme(Enum):
    """Integer quantization schemes, each defined as (name, qmin, qmax)."""
    INT8 = ("int8", -128, 127)
    UINT8 = ("uint8", 0, 255)
    INT4 = ("int4", -8, 7)

    def __init__(self, name: str, qmin: int, qmax: int):
        # Enum passes the tuple value to __init__ as positional arguments
        self._name = name
        self.qmin = qmin
        self.qmax = qmax

    @property
    def levels(self) -> int:
        """Number of representable integer values."""
        return self.qmax - self.qmin + 1

    @property
    def bit_width(self) -> int:
        """Bits needed to encode all levels (log2 of the level count)."""
        return int(np.log2(self.levels))


# Quick check
scheme = QuantizationScheme.INT8
print(f"Scheme: {scheme._name}, Range: [{scheme.qmin}, {scheme.qmax}], Bits: {scheme.bit_width}")
```
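The three members above follow the usual integer ranges: a signed $b$-bit integer (two's complement) spans $[-2^{b-1}, 2^{b-1}-1]$ and an unsigned one spans $[0, 2^b-1]$, so both have $2^b$ levels. The hypothetical helper `integer_range` below (not part of the notebook's classes) derives these bounds from a bit width, which is a quick way to check the enum entries or to add other widths.

```python
def integer_range(bits: int, signed: bool) -> tuple[int, int]:
    """Return (qmin, qmax) for a b-bit integer grid.

    Signed grids use the two's-complement range; unsigned grids start at zero.
    """
    if bits < 1:
        raise ValueError("bits must be >= 1")
    if signed:
        return -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    return 0, 2 ** bits - 1


for bits, signed in [(8, True), (8, False), (4, True)]:
    qmin, qmax = integer_range(bits, signed)
    print(f"bits={bits}, signed={signed}: [{qmin}, {qmax}], levels={qmax - qmin + 1}")
```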

---

## 3. Computing Quantization Parameters

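```python
@dataclass
class QuantizationParams:
    """Quantization parameters: scale and zero_point, plus the integer grid bounds."""
    scale: Tensor
    zero_point: Tensor
    qmin: int
    qmax: int

    @classmethod
    def from_tensor(cls, x: Tensor, scheme: QuantizationScheme, symmetric: bool = True):
        """Compute per-tensor quantization parameters from the min/max of x."""
        x_min, x_max = x.min().item(), x.max().item()

        if symmetric:
            max_abs = max(abs(x_min), abs(x_max), 1e-8)
            if scheme.qmin < 0:
                # Signed grid: zero_point = 0, map max|x| to the restricted range [-qmax, qmax]
                q_range = min(abs(scheme.qmin), scheme.qmax)
                zero_point = torch.tensor(0.0)
            else:
                # Unsigned grid: centre the zero_point (128 for UINT8); avoids dividing by qmin = 0
                q_range = (scheme.qmax - scheme.qmin) / 2
                zero_point = torch.tensor(float((scheme.qmin + scheme.qmax + 1) // 2))
            scale = torch.tensor(max_abs / q_range)
        else:
            # Extend the range to include 0 so real zero is exactly representable
            x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
            scale = torch.tensor((x_max - x_min) / (scheme.qmax - scheme.qmin))
            scale = torch.clamp(scale, min=1e-8)
            zero_point = torch.round(torch.tensor(scheme.qmin - x_min / scale.item()))
            zero_point = torch.clamp(zero_point, scheme.qmin, scheme.qmax)

        return cls(scale=scale, zero_point=zero_point, qmin=scheme.qmin, qmax=scheme.qmax)


# Quick check
x = torch.randn(100) * 0.5
params = QuantizationParams.from_tensor(x, QuantizationScheme.INT8, symmetric=True)
print(f"Scale: {params.scale:.6f}, Zero-point: {params.zero_point.item()}")
```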
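A worked example of the asymmetric formulas with plain floats, assuming an observed range of $[-1.0, 3.0]$ and the UINT8 grid $[0, 255]$ (values chosen for illustration). The scale spreads the 4.0-wide range over 255 steps, and the zero-point is the integer that real 0.0 maps to.

```python
x_min, x_max = -1.0, 3.0  # assumed observed range
qmin, qmax = 0, 255       # UINT8 grid

scale = (x_max - x_min) / (qmax - qmin)        # 4 / 255, about 0.015686
zero_point = round(qmin - x_min / scale)       # round(63.75) = 64
zero_point = max(qmin, min(qmax, zero_point))  # keep it on the integer grid

print(f"scale = {scale:.6f}, zero_point = {zero_point}")
# Real 0.0 maps exactly onto the zero-point, so zero has no rounding error
print(round(0.0 / scale + zero_point))  # -> 64
# The endpoints of the observed range land at the ends of the grid
print(round(x_min / scale + zero_point), round(x_max / scale + zero_point))  # -> 0 255
```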

---

## 4. Implementing a Uniform Quantizer

```python
class UniformQuantizer:
    """Uniform (linear) quantizer.

    Core idea: standard affine quantization and dequantization with a fixed scale and zero_point.
    """

    def __init__(self, params: QuantizationParams):
        self.params = params

    def quantize(self, x: Tensor) -> Tensor:
        """Quantize FP32 -> integer codes in [qmin, qmax] (stored in a float tensor)."""
        q = x / self.params.scale + self.params.zero_point
        q = torch.round(q)  # round to the nearest integer level
        q = torch.clamp(q, self.params.qmin, self.params.qmax)  # clip to the integer grid
        return q

    def dequantize(self, q: Tensor) -> Tensor:
        """Dequantize integer codes -> FP32."""
        return self.params.scale * (q - self.params.zero_point)

    def forward(self, x: Tensor) -> Tensor:
        """Fake quantization: quantize then dequantize, returning FP32 values on the grid."""
        return self.dequantize(self.quantize(x))


# Quick check
x = torch.randn(4, 4) * 0.5
params = QuantizationParams.from_tensor(x, QuantizationScheme.INT8, symmetric=True)
quantizer = UniformQuantizer(params)
x_q = quantizer.forward(x)

print(f"Original tensor (first 2 values): {x[0, :2]}")
print(f"Fake-quantized tensor (first 2 values): {x_q[0, :2]}")
print(f"Quantization MSE: {torch.mean((x - x_q) ** 2):.6f}")
```
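For values inside the representable range, round-to-nearest moves each value by at most half a step, so $|x - \hat{x}| \le s/2$; values beyond the range are clipped and can incur much larger error. The sketch below re-implements symmetric INT8 fake quantization with plain PyTorch ops (rather than the classes above, so it runs on its own; `fake_quant` is a hypothetical helper) and checks both cases.

```python
import torch

torch.manual_seed(0)
x = torch.randn(10_000) * 0.5

# Symmetric INT8 parameters, mirroring the notebook's rule s = max|x| / 127
scale = x.abs().max() / 127
qmin, qmax = -128, 127


def fake_quant(v: torch.Tensor) -> torch.Tensor:
    """Quantize to the INT8 grid and dequantize back (zero_point = 0)."""
    q = torch.clamp(torch.round(v / scale), qmin, qmax)
    return q * scale


# In-range values: rounding error is bounded by half a step
in_range_err = (x - fake_quant(x)).abs().max()
print(f"max in-range error = {in_range_err.item():.6f}, bound s/2 = {(scale / 2).item():.6f}")

# An outlier far beyond the calibrated range is clipped to qmax * scale
outlier = torch.tensor([10 * x.abs().max().item()])
print(f"outlier error = {(outlier - fake_quant(outlier)).abs().item():.4f}")
```

This is why calibration matters: a range that is too narrow trades rounding error for clipping error.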

---

## 5. Visual Analysis

```python
def visualize_quantization(x: Tensor, scheme: QuantizationScheme = QuantizationScheme.INT8):
    """Visualize the effect of symmetric quantization on a tensor."""
    params = QuantizationParams.from_tensor(x, scheme, symmetric=True)
    quantizer = UniformQuantizer(params)
    x_q = quantizer.forward(x)
    error = x - x_q

    # matplotlib needs CPU NumPy arrays
    x_np = x.detach().cpu().flatten().numpy()
    x_q_np = x_q.detach().cpu().flatten().numpy()
    error_np = error.detach().cpu().flatten().numpy()

    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    # Distribution comparison: quantized values collapse onto discrete levels
    axes[0].hist(x_np, bins=50, alpha=0.5, label="Original", color="blue")
    axes[0].hist(x_q_np, bins=50, alpha=0.5, label="Quantized", color="red")
    axes[0].set_title("Distribution Comparison")
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    # Quantization error histogram
    axes[1].hist(error_np, bins=50, color="green", alpha=0.7)
    axes[1].set_title(f"Error Distribution (MSE={(error_np ** 2).mean():.6f})")
    axes[1].grid(True, alpha=0.3)

    # Scatter plot: points on the dashed y = x line would have zero error
    axes[2].scatter(x_np, x_q_np, alpha=0.3, s=1)
    lo, hi = float(x_np.min()), float(x_np.max())
    axes[2].plot([lo, hi], [lo, hi], "r--", linewidth=1)
    axes[2].set_title("Original vs Quantized")
    axes[2].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()


# Quick check
test_data = torch.randn(10000)
visualize_quantization(test_data)
```

---

## 6. Summary

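| Concept | Formula | Notes |
|:-----|:-----|:-----|
| **Uniform quantization** | $q = \text{clamp}(\text{round}(x/s + z),\ q_{min},\ q_{max})$ | $s$: scale, $z$: zero-point |
| **Dequantization** | $\hat{x} = s \cdot (q - z)$ | Recovers an approximate floating-point value |
| **Symmetric quantization** (signed grid) | $z = 0,\ s = \max\lvert x \rvert / (2^{b-1} - 1)$ | Simple and efficient; real 0 maps exactly to integer 0 |
| **Asymmetric quantization** | $s = (x_{max} - x_{min})/(q_{max} - q_{min}),\ z = \text{round}(q_{min} - x_{min}/s)$ | Uses the full integer range for skewed distributions |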