# 量化

## 需求
在模型部署后做推理的过程中，有下面两个问题
1. 算力需求过大
2. 显存开销太大
## 收益
以float32 -> int8为例，暂时忽略量化相关的参数， 因为相对于矩阵的存储， 量化参数非常少。  
存储降低的情况： 32bit/number -> 8bit/number , 4倍存储下降。  
计算需求的降低： 32bit float GEMM -> 8bit int GEMM + elementwise multiply + quant + dequant（GEMM: general matrix multiply）。[收益示例](#gemm-accerlate)
## 核心问题
非MOE结构下（MOE存在路由变动问题），指标的变化受量化导致的数值变动影响。因此，如何降低量化导致的数值变动是核心问题。

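量化分类

|量化方案分类|主要做法|优点|缺点|
|-|-|-|-|
|QAT（量化感知训练）|训练（调优）时插入伪量化，让模型感知量化误差|精度损失通常较小|需要训练数据和额外的训练开销|
|PTQ（训练后量化）|训练完成后直接对模型做量化|无需重新训练，成本低|低比特下精度损失可能较大|

0. QAT：量化感知训练，让模型在训练（调优）时，感知到量化误差。核心是伪量化，使用梯度直通（STE）让数学上不可微的量化操作可以传递梯度。后文有具体[实现](#qat-ste-code)
1. PTQ：在训练完成后做量化，是本次分享重点介绍的方案。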

计算示例（视频）

In [None]:
# ste_fake_quant_demo.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# -------- Fake Quant with STE (autograd.Function) --------
class FakeQuantizeFunction(torch.autograd.Function):
    """Fake quantization (quantize, then dequantize) with a straight-through estimator.

    The forward pass derives ``scale`` / ``zero_point`` from the min/max of the
    current input, rounds onto the integer grid and maps the result back to
    floating point. The backward pass treats rounding as the identity and
    zeroes the gradient of elements that fall outside the representable range.
    """

    @staticmethod
    def forward(ctx, x, num_bits, symmetric, per_channel, ch_axis, eps):
        """Quantize ``x`` onto a ``num_bits`` integer grid and dequantize it.

        Args:
            ctx: Autograd context used to stash values for ``backward``.
            x: Input tensor.
            num_bits: Bit width of the integer grid.
            symmetric: If True, use a zero point of 0 and a scale based on the
                largest absolute value; otherwise map onto
                ``[0, 2**num_bits - 1]`` with a computed zero point.
            per_channel: If True, compute min/max separately for every index
                along ``ch_axis``; otherwise over the whole tensor.
            ch_axis: Channel axis, only used when ``per_channel`` is True.
                Negative values count from the last dimension.
            eps: Lower bound applied to ``scale``.

        Returns:
            The dequantized tensor ``x_hat``.
        """
        # Estimate min/max from the current input on every forward call.
        if per_channel:
            # Accept negative axes such as -1. Without this normalisation a
            # negative ``ch_axis`` never equals a (non-negative) dimension
            # index, so every dimension would be reduced and the statistics
            # would silently become per-tensor.
            if ch_axis < 0:
                ch_axis += x.dim()
            reduce_dims = [d for d in range(x.dim()) if d != ch_axis]
            x_min = x.amin(dim=reduce_dims, keepdim=True)
            x_max = x.amax(dim=reduce_dims, keepdim=True)
        else:
            x_min = x.min()
            x_max = x.max()

        if symmetric:
            # Symmetric quantization: the zero point is fixed at 0.
            qmax = 2 ** (num_bits - 1) - 1
            m = torch.maximum(x_max.abs(), x_min.abs())
            scale = torch.clamp(m / qmax, min=eps)
            zero_point = 0.0
            x_clamp_min = -qmax * scale
            x_clamp_max = qmax * scale
        else:
            # Asymmetric quantization: map onto [0, 2^b - 1].
            qmin, qmax = 0, 2 ** num_bits - 1
            scale = torch.clamp((x_max - x_min) / max(qmax - qmin, 1), min=eps)
            zero_point = torch.clamp(torch.round(qmin - x_min / scale), qmin, qmax)
            x_clamp_min = (0.0 - zero_point) * scale
            x_clamp_max = (qmax - zero_point) * scale

        # Keep the input and the representable range so that backward can
        # mask out-of-range elements.
        ctx.save_for_backward(x, x_clamp_min, x_clamp_max)
        ctx.num_bits = num_bits
        ctx.symmetric = symmetric
        ctx.per_channel = per_channel
        ctx.ch_axis = ch_axis
        ctx.eps = eps

        # Quantize -> dequantize.
        if symmetric:
            q = torch.round(x / scale)
            q = torch.clamp(q, -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1)
            x_hat = q * scale
        else:
            q = torch.round(x / scale + zero_point)
            q = torch.clamp(q, 0, 2 ** num_bits - 1)
            x_hat = (q - zero_point) * scale

        return x_hat

    @staticmethod
    def backward(ctx, grad_output):
        """Straight-through gradient for ``x``.

        Rounding is treated as the identity. Elements of ``x`` outside
        ``[x_clamp_min, x_clamp_max]`` receive a zero gradient.

        Returns:
            A 6-tuple: the gradient for ``x`` followed by ``None`` for each of
            the five non-tensor arguments of ``forward``.
        """
        x, x_clamp_min, x_clamp_max = ctx.saved_tensors
        pass_through = (x >= x_clamp_min) & (x <= x_clamp_max)
        grad_input = grad_output * pass_through.to(grad_output.dtype)
        return grad_input, None, None, None, None, None

# -------- 便捷模块封装 --------
class FakeQuantize(nn.Module):
    """``nn.Module`` wrapper that applies ``FakeQuantizeFunction`` with stored settings."""

    def __init__(self, num_bits=8, symmetric=True, per_channel=False, ch_axis=1, eps=1e-8):
        super().__init__()
        # Plain Python attributes; nothing here is a parameter or a buffer.
        self.num_bits, self.symmetric = num_bits, symmetric
        self.per_channel, self.ch_axis = per_channel, ch_axis
        self.eps = eps

    def forward(self, x):
        """Return the fake-quantized version of ``x``."""
        settings = (self.num_bits, self.symmetric, self.per_channel, self.ch_axis, self.eps)
        return FakeQuantizeFunction.apply(x, *settings)

# 将权重也做伪量化（常见做法：前向时对权重做伪量化）
class QuantLinear(nn.Linear):
    """``nn.Linear`` whose input and weight are fake-quantized on every forward pass."""

    def __init__(self, in_features, out_features, bias=True,
                 w_num_bits=8, a_num_bits=8, symmetric=True):
        super().__init__(in_features, out_features, bias=bias)
        # Weight: separate statistics per index along axis 0 of ``self.weight``
        # (the out_features dimension).
        self.w_fake = FakeQuantize(
            num_bits=w_num_bits, symmetric=symmetric, per_channel=True, ch_axis=0
        )
        # Input activation: one set of statistics for the whole tensor.
        self.a_fake = FakeQuantize(
            num_bits=a_num_bits, symmetric=symmetric, per_channel=False
        )

    def forward(self, x):
        """Fake-quantize the input and the weight, then apply the linear map.

        The bias is used as-is (it is not fake-quantized).
        """
        return F.linear(self.a_fake(x), self.w_fake(self.weight), self.bias)

# -------- 一个极简可跑的训练 demo --------
def synthetic_regression(n=1024, in_dim=16, out_dim=1, noise=0.1, seed=0, device="cpu"):
    """Build a seeded, noisy linear-regression dataset.

    Returns ``(X, y)`` with shapes ``(n, in_dim)`` and ``(n, out_dim)``, both
    moved to ``device``. ``y`` is ``X @ w_true`` plus Gaussian noise scaled by
    ``noise``.
    """
    rng = torch.Generator().manual_seed(seed)
    # Draw order matters for reproducibility: features, true weights, noise.
    features = torch.randn(n, in_dim, generator=rng)
    true_weight = torch.randn(in_dim, out_dim, generator=rng)
    perturbation = torch.randn(n, out_dim, generator=rng)
    targets = features @ true_weight + noise * perturbation
    return features.to(device), targets.to(device)

class TinyQATNet(nn.Module):
    """Two-layer MLP (QuantLinear -> ReLU -> QuantLinear) used by the QAT demo."""

    def __init__(self, in_dim=16, hidden=32, out_dim=1):
        super().__init__()
        # Both layers share the same 8-bit symmetric fake-quant settings.
        quant_cfg = dict(w_num_bits=8, a_num_bits=8, symmetric=True)
        self.fc1 = QuantLinear(in_dim, hidden, **quant_cfg)
        self.act = nn.ReLU()
        self.fc2 = QuantLinear(hidden, out_dim, **quant_cfg)

    def forward(self, x):
        """Run ``x`` through fc1, the ReLU and fc2, in that order."""
        return self.fc2(self.act(self.fc1(x)))

def main():
    """Train ``TinyQATNet`` on synthetic data, then re-evaluate without fake-quant.

    Prints the batch loss and full-data MSE every 50 steps, and finally the
    full-data MSE after every ``FakeQuantize`` module is turned into an
    identity mapping.
    """
    device = "cpu" if not torch.cuda.is_available() else "cuda"
    torch.manual_seed(42)

    features, targets = synthetic_regression(
        n=2048, in_dim=16, out_dim=1, noise=0.1, seed=123, device=device
    )
    net = TinyQATNet(in_dim=16, hidden=32, out_dim=1).to(device)

    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
    loss_fn = nn.MSELoss()

    net.train()
    for step in range(1, 301):
        # Sample a mini-batch of 128 rows (with replacement).
        batch_idx = torch.randint(0, features.size(0), (128,), device=device)
        loss = loss_fn(net(features[batch_idx]), targets[batch_idx])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 50:
            continue
        with torch.no_grad():
            mse_all = loss_fn(net(features), targets).item()
        print(f"step {step:4d} | batch_loss={loss.item():.4f} | full_MSE={mse_all:.4f}")

    # Simple comparison: switch fake-quant off and evaluate once more.
    net.eval()
    print("\nDisable fake-quant (eval) and re-evaluate:")
    for module in net.modules():
        if isinstance(module, FakeQuantize):
            # Override forward on the instance with an identity mapping.
            module.forward = lambda x: x
    with torch.no_grad():
        mse_no_fake = loss_fn(net(features), targets).item()
    print(f"Full-data MSE without fake-quant: {mse_no_fake:.4f}")

# Run the demo as soon as this cell/script is executed. The module-style
# guard is kept below, commented out, for reference.
# if __name__ == "__main__":
#     main()
main()

举例（以float32到int8的量化为例）: 
1. 把[float32](https://www.h-schmidt.net/FloatConverter/IEEE754.html)的范围直接变换到int8
   |数据类型|数据点|数据范围|
   |-|-|-|
   |float32|2^32|$[-3.4028235\times10^{38},\,-1.4013\times10^{-45}]\cup[1.4013\times10^{-45},\,3.4028235\times10^{38}]$|
   |int8|2^8|[−128,127]|

<center class='img'>
<img src="assets/img/most_general_quant.png" style="zoom: 40%;">
</center>

论文及分类    
综述: [A White Paper on Neural Network Quantization](https://arxiv.org/pdf/2106.08295)  
训练前分布约束技术：   
    [PACT: Parameterized Clipping Activation for Quantized Neural Networks](https://arxiv.org/pdf/1805.06085)   
    [R2 Loss: Range Restriction Loss for Model Compression and Quantization](https://arxiv.org/pdf/2303.08253)  
    [Robust Quantization: One Model to Rule Them All](https://arxiv.org/pdf/2002.07686)  
PTQ技术(目前主要面向Transformer结构, 实际的计算是Linear, 但是在语音中conformer中是有conv结构的, 也需要关注面向conv的量化方案)：  
    &nbsp; 直觉类：   
        &ensp; 最直观降低误差的手段：   
            &emsp; [LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale](https://arxiv.org/pdf/2208.07339)     
        &ensp; 基于数据的观察：   
            &emsp; [SmoothQuant: Accurate and Efficient Post-Training Quantization for Large Language Models](https://arxiv.org/pdf/2211.10438)    
        &emsp; [ZeroQuant: Efficient and Affordable Post-Training Quantization for Large-Scale Transformers](https://arxiv.org/pdf/2206.01861)    
        &emsp; [AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration](https://arxiv.org/pdf/2306.00978)  
    &nbsp;灵光一闪类：     
        &ensp; [Up or Down? Adaptive Rounding for Post-Training Quantization](https://arxiv.org/pdf/2004.10568)  
    &nbsp;数学分析类：   
        &ensp; [Data-Free Quantization Through Weight Equalization and Bias Correction](https://arxiv.org/pdf/1906.04721)  


<a id="qat-ste-code">QAT梯度直通示例</a>  
是否需要增加一个更加简单的示例，以及对pytorch的反向传播做下说明？（需要的, 做一个演示）

In [1]:
# ste_fake_quant_demo.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# -------- Fake Quant with STE (autograd.Function) --------
class FakeQuantizeFunction(torch.autograd.Function):
    """Fake quantization (quantize, then dequantize) with a straight-through estimator.

    The forward pass derives ``scale`` / ``zero_point`` from the min/max of the
    current input, rounds onto the integer grid and maps the result back to
    floating point. The backward pass treats rounding as the identity and
    zeroes the gradient of elements that fall outside the representable range.
    """

    @staticmethod
    def forward(ctx, x, num_bits, symmetric, per_channel, ch_axis, eps):
        """Quantize ``x`` onto a ``num_bits`` integer grid and dequantize it.

        Args:
            ctx: Autograd context used to stash values for ``backward``.
            x: Input tensor.
            num_bits: Bit width of the integer grid.
            symmetric: If True, use a zero point of 0 and a scale based on the
                largest absolute value; otherwise map onto
                ``[0, 2**num_bits - 1]`` with a computed zero point.
            per_channel: If True, compute min/max separately for every index
                along ``ch_axis``; otherwise over the whole tensor.
            ch_axis: Channel axis, only used when ``per_channel`` is True.
                Negative values count from the last dimension.
            eps: Lower bound applied to ``scale``.

        Returns:
            The dequantized tensor ``x_hat``.
        """
        # Estimate min/max from the current input on every forward call.
        if per_channel:
            # Accept negative axes such as -1. Without this normalisation a
            # negative ``ch_axis`` never equals a (non-negative) dimension
            # index, so every dimension would be reduced and the statistics
            # would silently become per-tensor.
            if ch_axis < 0:
                ch_axis += x.dim()
            reduce_dims = [d for d in range(x.dim()) if d != ch_axis]
            x_min = x.amin(dim=reduce_dims, keepdim=True)
            x_max = x.amax(dim=reduce_dims, keepdim=True)
        else:
            x_min = x.min()
            x_max = x.max()

        if symmetric:
            # Symmetric quantization: the zero point is fixed at 0.
            qmax = 2 ** (num_bits - 1) - 1
            m = torch.maximum(x_max.abs(), x_min.abs())
            scale = torch.clamp(m / qmax, min=eps)
            zero_point = 0.0
            x_clamp_min = -qmax * scale
            x_clamp_max = qmax * scale
        else:
            # Asymmetric quantization: map onto [0, 2^b - 1].
            qmin, qmax = 0, 2 ** num_bits - 1
            scale = torch.clamp((x_max - x_min) / max(qmax - qmin, 1), min=eps)
            zero_point = torch.clamp(torch.round(qmin - x_min / scale), qmin, qmax)
            x_clamp_min = (0.0 - zero_point) * scale
            x_clamp_max = (qmax - zero_point) * scale

        # Keep the input and the representable range so that backward can
        # mask out-of-range elements.
        ctx.save_for_backward(x, x_clamp_min, x_clamp_max)
        ctx.num_bits = num_bits
        ctx.symmetric = symmetric
        ctx.per_channel = per_channel
        ctx.ch_axis = ch_axis
        ctx.eps = eps

        # Quantize -> dequantize.
        if symmetric:
            q = torch.round(x / scale)
            q = torch.clamp(q, -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1)
            x_hat = q * scale
        else:
            q = torch.round(x / scale + zero_point)
            q = torch.clamp(q, 0, 2 ** num_bits - 1)
            x_hat = (q - zero_point) * scale

        return x_hat

    @staticmethod
    def backward(ctx, grad_output):
        """Straight-through gradient for ``x``.

        Rounding is treated as the identity. Elements of ``x`` outside
        ``[x_clamp_min, x_clamp_max]`` receive a zero gradient.

        Returns:
            A 6-tuple: the gradient for ``x`` followed by ``None`` for each of
            the five non-tensor arguments of ``forward``.
        """
        x, x_clamp_min, x_clamp_max = ctx.saved_tensors
        pass_through = (x >= x_clamp_min) & (x <= x_clamp_max)
        grad_input = grad_output * pass_through.to(grad_output.dtype)
        return grad_input, None, None, None, None, None

# -------- 便捷模块封装 --------
class FakeQuantize(nn.Module):
    """``nn.Module`` wrapper that applies ``FakeQuantizeFunction`` with stored settings."""

    def __init__(self, num_bits=8, symmetric=True, per_channel=False, ch_axis=1, eps=1e-8):
        super().__init__()
        # Plain Python attributes; nothing here is a parameter or a buffer.
        self.num_bits, self.symmetric = num_bits, symmetric
        self.per_channel, self.ch_axis = per_channel, ch_axis
        self.eps = eps

    def forward(self, x):
        """Return the fake-quantized version of ``x``."""
        settings = (self.num_bits, self.symmetric, self.per_channel, self.ch_axis, self.eps)
        return FakeQuantizeFunction.apply(x, *settings)

# 将权重也做伪量化（常见做法：前向时对权重做伪量化）
class QuantLinear(nn.Linear):
    """``nn.Linear`` whose input and weight are fake-quantized on every forward pass."""

    def __init__(self, in_features, out_features, bias=True,
                 w_num_bits=8, a_num_bits=8, symmetric=True):
        super().__init__(in_features, out_features, bias=bias)
        # Weight: separate statistics per index along axis 0 of ``self.weight``
        # (the out_features dimension).
        self.w_fake = FakeQuantize(
            num_bits=w_num_bits, symmetric=symmetric, per_channel=True, ch_axis=0
        )
        # Input activation: one set of statistics for the whole tensor.
        self.a_fake = FakeQuantize(
            num_bits=a_num_bits, symmetric=symmetric, per_channel=False
        )

    def forward(self, x):
        """Fake-quantize the input and the weight, then apply the linear map.

        The bias is used as-is (it is not fake-quantized).
        """
        return F.linear(self.a_fake(x), self.w_fake(self.weight), self.bias)

# -------- 一个极简可跑的训练 demo --------
def synthetic_regression(n=1024, in_dim=16, out_dim=1, noise=0.1, seed=0, device="cpu"):
    """Build a seeded, noisy linear-regression dataset.

    Returns ``(X, y)`` with shapes ``(n, in_dim)`` and ``(n, out_dim)``, both
    moved to ``device``. ``y`` is ``X @ w_true`` plus Gaussian noise scaled by
    ``noise``.
    """
    rng = torch.Generator().manual_seed(seed)
    # Draw order matters for reproducibility: features, true weights, noise.
    features = torch.randn(n, in_dim, generator=rng)
    true_weight = torch.randn(in_dim, out_dim, generator=rng)
    perturbation = torch.randn(n, out_dim, generator=rng)
    targets = features @ true_weight + noise * perturbation
    return features.to(device), targets.to(device)

class TinyQATNet(nn.Module):
    """Two-layer MLP (QuantLinear -> ReLU -> QuantLinear) used by the QAT demo."""

    def __init__(self, in_dim=16, hidden=32, out_dim=1):
        super().__init__()
        # Both layers share the same 8-bit symmetric fake-quant settings.
        quant_cfg = dict(w_num_bits=8, a_num_bits=8, symmetric=True)
        self.fc1 = QuantLinear(in_dim, hidden, **quant_cfg)
        self.act = nn.ReLU()
        self.fc2 = QuantLinear(hidden, out_dim, **quant_cfg)

    def forward(self, x):
        """Run ``x`` through fc1, the ReLU and fc2, in that order."""
        return self.fc2(self.act(self.fc1(x)))

def main():
    """Train ``TinyQATNet`` on synthetic data, then re-evaluate without fake-quant.

    Prints the batch loss and full-data MSE every 50 steps, and finally the
    full-data MSE after every ``FakeQuantize`` module is turned into an
    identity mapping.
    """
    device = "cpu" if not torch.cuda.is_available() else "cuda"
    torch.manual_seed(42)

    features, targets = synthetic_regression(
        n=2048, in_dim=16, out_dim=1, noise=0.1, seed=123, device=device
    )
    net = TinyQATNet(in_dim=16, hidden=32, out_dim=1).to(device)

    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
    loss_fn = nn.MSELoss()

    net.train()
    for step in range(1, 301):
        # Sample a mini-batch of 128 rows (with replacement).
        batch_idx = torch.randint(0, features.size(0), (128,), device=device)
        loss = loss_fn(net(features[batch_idx]), targets[batch_idx])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 50:
            continue
        with torch.no_grad():
            mse_all = loss_fn(net(features), targets).item()
        print(f"step {step:4d} | batch_loss={loss.item():.4f} | full_MSE={mse_all:.4f}")

    # Simple comparison: switch fake-quant off and evaluate once more.
    net.eval()
    print("\nDisable fake-quant (eval) and re-evaluate:")
    for module in net.modules():
        if isinstance(module, FakeQuantize):
            # Override forward on the instance with an identity mapping.
            module.forward = lambda x: x
    with torch.no_grad():
        mse_no_fake = loss_fn(net(features), targets).item()
    print(f"Full-data MSE without fake-quant: {mse_no_fake:.4f}")

# Run the demo as soon as this cell/script is executed. The module-style
# guard is kept below, commented out, for reference.
# if __name__ == "__main__":
#     main()
main()

step   50 | batch_loss=9.6628 | full_MSE=9.3308
step  100 | batch_loss=5.5163 | full_MSE=5.6729
step  150 | batch_loss=2.5716 | full_MSE=2.1575
step  200 | batch_loss=0.7124 | full_MSE=0.6377
step  250 | batch_loss=0.2451 | full_MSE=0.2592
step  300 | batch_loss=0.1310 | full_MSE=0.1408

Disable fake-quant (eval) and re-evaluate:
Full-data MSE without fake-quant: 0.1390


<a id="gemm-accerlate">计算加速效果示意</a>