# Applying the RoundToNearest Post-Training Quantization Algorithm


## Introduction to RoundToNearest Post-Training Quantization

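"Round to nearest" originally names a rounding method, namely rounding half up. In this framework, RoundToNearest refers to a family of fairly naive post-training quantization algorithms whose rounding step uses simple round-half-up.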
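Rounding half up and the default rounding of many numerical libraries differ on ties. The NumPy sketch below contrasts round-half-away-from-zero (the usual reading of round-half-up for signed values) with `np.round`, which rounds ties to the nearest even integer. Which tie-breaking rule Golden Stick's kernels actually use is not stated in this document; this is only an illustration, and the helper name is hypothetical.

```python
import numpy as np


def round_half_away_from_zero(x: np.ndarray) -> np.ndarray:
    """Round to the nearest integer; exact .5 ties move away from zero (hypothetical helper)."""
    return np.sign(x) * np.floor(np.abs(x) + 0.5)


values = np.array([-2.5, -1.5, 0.5, 1.5, 2.5, 2.6])
print(round_half_away_from_zero(values).tolist())  # ties: -2.5 -> -3, 0.5 -> 1, 2.5 -> 3
print(np.round(values).tolist())                   # ties go to the even neighbour: 0.5 -> 0, 2.5 -> 2
```

For real-valued weights exact ties are rare, so the choice mostly matters for reproducibility rather than accuracy.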

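The RoundToNearest post-training quantization currently provided in Golden Stick (abbreviated RTN below) mainly targets large language model (LLM) scenarios. It uses a MinMax calibrator to quantize the linear (Linear) layers. The structure of the fake-quantized network is illustrated below: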

[Figure: structure of the fake-quantized network]
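A rough NumPy sketch of what MinMax calibration plus round-to-nearest does to a single Linear weight. It assumes symmetric, per-output-channel int8 quantization with the range [-127, 127]; the granularity, integer range, and scale dtype that Golden Stick actually uses are not specified in this document, and the function names are hypothetical.

```python
import numpy as np


def minmax_rtn_quantize(weight: np.ndarray, num_bits: int = 8):
    """Symmetric per-output-channel quantization of a [out_features, in_features] weight (num_bits <= 8)."""
    w = weight.astype(np.float32)
    qmax = 2 ** (num_bits - 1) - 1  # 127 for int8
    # MinMax calibration: the larger of |min| and |max| in each row defines that row's range.
    absmax = np.maximum(np.abs(w.min(axis=1)), np.abs(w.max(axis=1)))
    scale = np.where(absmax > 0, absmax / qmax, 1.0).astype(np.float32)  # all-zero rows keep scale 1
    # Round to nearest with ties away from zero, then clip into the signed integer range.
    scaled = w / scale[:, None]
    q = np.clip(np.sign(scaled) * np.floor(np.abs(scaled) + 0.5), -qmax, qmax)
    return q.astype(np.int8), scale


def dequantize(q: np.ndarray, scale: np.ndarray) -> np.ndarray:
    """Approximate the original weight as q * scale, one scale per output channel."""
    return q.astype(np.float32) * scale[:, None]


rng = np.random.default_rng(0)
w = rng.normal(scale=0.02, size=(4, 16)).astype(np.float16)
q, s = minmax_rtn_quantize(w)
max_err = np.abs(dequantize(q, s) - w.astype(np.float32)).max()
print(q.dtype, q.shape, f"max abs error = {max_err:.1e}")
```

With round-to-nearest, every reconstructed element is within half a quantization step (`scale / 2`) of the original value in its row.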

Table 1: RTN algorithm specifications

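| Specification | Description |
| --- | --- |
| Hardware | Quantization runs on CPU; inference with the quantized model is supported only on Ascend 910B |
| Networks | Llama2 13B/70B; for details, see the [Llama2 network](https://gitee.com/hangangqiang/mindformers/tree/dev/mindformers/models/llama). |
| Execution modes | Graph mode and PyNative mode |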

## Example

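Like every algorithm in the Golden Stick repository, applying RTN consists of two phases: quantization and deployment. The quantization phase is completed ahead of deployment; its main work is to collect the weight distribution, compute the quantization parameters, quantize the weight data, and insert dequantization nodes. The deployment phase usually refers to running inference with the quantized model using the MindSpore framework in a production environment. This example uses the Llama2 network and consists of three steps: environment preparation, model quantization, and model deployment and evaluation.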
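The two phases can be mimicked on one Linear layer with NumPy. The quantization phase stores int8 weights plus a scale; at deployment an inserted dequantization step restores float16 weights before a float16 matmul, which is what W8A16 (8-bit weights, 16-bit activations) refers to. This is a sketch under assumptions: the function names are hypothetical, per-output-channel symmetric scales are assumed, and a real Ascend kernel may fuse dequantization into the matmul instead of materializing the float16 weight.

```python
import numpy as np


def quantize_phase(weight: np.ndarray):
    """Offline: collect per-row ranges, compute scales, and round weights to int8 (hypothetical helper)."""
    w = weight.astype(np.float32)
    absmax = np.abs(w).max(axis=1)  # MinMax statistics per output channel
    scale = np.where(absmax > 0, absmax / 127.0, 1.0).astype(np.float32)
    scaled = w / scale[:, None]
    q = np.clip(np.sign(scaled) * np.floor(np.abs(scaled) + 0.5), -127, 127)
    return q.astype(np.int8), scale


def w8a16_linear(x: np.ndarray, q: np.ndarray, scale: np.ndarray) -> np.ndarray:
    """Online: a dequantization step restores float16 weights, then a float16 matmul runs."""
    w_fp16 = (q.astype(np.float32) * scale[:, None]).astype(np.float16)
    return x.astype(np.float16) @ w_fp16.T


rng = np.random.default_rng(1)
weight = rng.normal(scale=0.02, size=(8, 32)).astype(np.float16)  # [out_features, in_features]
x = rng.normal(size=(2, 32)).astype(np.float16)                    # activations stay float16

q, scale = quantize_phase(weight)    # done once, before deployment
y_quant = w8a16_linear(x, q, scale)  # int8 weights, float16 activations
y_ref = x @ weight.T                 # float16 baseline without quantization
print(q.dtype, y_quant.dtype, float(np.abs(y_quant.astype(np.float32) - y_ref.astype(np.float32)).max()))
```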

### Step 1. Prepare the Environment

#### 1.1. Ascend 910B Environment

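Inference with the RTN-quantized model must run on Ascend 910B hardware (the quantization step itself runs on CPU, see Table 1). For Ascend 910B environment setup, refer to todo.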

#### 1.2. MindSpore Environment

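Golden Stick depends on MindSpore, so a suitable version of MindSpore must be installed in advance. A prebuilt [2.3 installation package](https://www.mindspore.cn/versions#2.2.11) can be downloaded from the MindSpore website.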

#### 1.3. MindFormers Environment

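MindFormers is MindSpore's large-model suite. This example uses the Llama2 7B network from MindFormers, so a suitable version of MindFormers must be installed in advance. To get the latest MindFormers features, we install MindFormers by building it from source: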

```bash
mkdir workspace
cd workspace
git clone https://gitee.com/mindspore/mindformers.git
cd mindformers
git checkout dev
bash build.sh
cd ../..
```

#### 1.4. Golden Stick Environment

Build Golden Stick from source (branch `r0.4`) and install the resulting wheel:

```bash
cd workspace
git clone https://gitee.com/mindspore/golden-stick.git
cd golden-stick
git checkout r0.4
bash build.sh
pip install ./output/mindspore_gs-0.4.0-py3-none-any.whl
cd ../..
```
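Before moving on, it can help to confirm that all three packages are installed. A small standard-library sketch; the distribution names `mindspore`, `mindformers`, and `mindspore_gs` are assumptions (the last is inferred from the wheel file name above), and `installed_version` is a hypothetical helper:

```python
from importlib import metadata


def installed_version(dist_name: str):
    """Return the installed version of a distribution, or None if it is not installed."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


# Distribution names below are assumptions; adjust them if pip lists the packages differently.
for name in ("mindspore", "mindformers", "mindspore_gs"):
    print(f"{name}: {installed_version(name) or 'NOT INSTALLED'}")
```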

#### 1.5. Prepare the Required Files

Download the MindFormers Llama2 files and the evaluation dataset in advance: the Llama2 7B checkpoint, the Llama2 tokenizer file, the Llama2 model configuration file, and the WikiText-2 dataset:

```bash
cd workspace
wget -O tokenizer.model https://ascend-repo-modelzoo.obs.cn-east-2.myhuaweicloud.com/MindFormers/llama2/tokenizer.model
wget -O llama2_7b.ckpt https://ascend-repo-modelzoo.obs.cn-east-2.myhuaweicloud.com/MindFormers/llama2/llama2_7b.ckpt
cp golden-stick/mindspore_gs/ptq/round_to_nearest/configs/run_llama2_7b_910b.yaml ./
wget -O wikitext-2-v1.zip https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip
unzip wikitext-2-v1.zip
cd ..
```

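> If you run into network problems while downloading, you can try to:
> 1. Add the `--no-check-certificate` option to `wget`.
> 2. Download the files manually in a browser and place them in the corresponding directories.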

After all preparation is done, check the directory structure:

```bash
cd workspace
tree -L 2 -U
cd ..
```

The `workspace` directory should contain at least the following files:

```bash
.
├── llama2_7b.ckpt
├── tokenizer.model
├── run_llama2_7b_910b.yaml
├── wikitext-2-v1.zip
├── wikitext-2
│   ├── wiki.valid.tokens
│   ├── wiki.test.tokens
│   └── wiki.train.tokens
├── mindformers
│   └── xxx
└── golden-stick
    └── xxx
```
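The same check can be scripted, which is handy when files were downloaded manually. A `pathlib` sketch; the file list mirrors the tree above plus the configuration file copied from Golden Stick, and `missing_files` is a hypothetical helper (adjust the list if your layout differs):

```python
from pathlib import Path

# Files the later steps read, relative to the workspace directory.
REQUIRED_FILES = [
    "llama2_7b.ckpt",
    "tokenizer.model",
    "run_llama2_7b_910b.yaml",
    "wikitext-2/wiki.valid.tokens",
]


def missing_files(workspace: str = "./workspace") -> list:
    """Return the required files that are absent from the workspace directory."""
    root = Path(workspace)
    return [name for name in REQUIRED_FILES if not (root / name).is_file()]


missing = missing_files()
print("all files present" if not missing else f"missing: {missing}")
```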

### Step 2. Quantize the Model

#### 2.1. Instantiate MindFormerConfig

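To construct the Llama2 network from the MindFormers repository, you first need a MindFormerConfig configuration. In this example, we first write a function that builds a MindFormerConfig and then instantiate a MindFormerConfig object: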

```python
import mindspore as ms
from mindspore import context
from mindformers import LlamaForCausalLM, MindFormerConfig, LlamaConfig, init_context, TransformerOpParallelConfig


def _set_config(config_path, device_target, device_id):
    """Set up a MindFormerConfig from a YAML file and initialize the runtime context."""
    mfconfig = MindFormerConfig(config_path)
    if device_id != -1:
        mfconfig.context.device_id = device_id
    mfconfig.context.device_target = device_target
    mfconfig.model.model_config = LlamaConfig(**mfconfig.model.model_config)

    init_context(use_parallel=mfconfig.use_parallel, context_config=mfconfig.context, parallel_config=mfconfig.parallel)

    parallel_config = TransformerOpParallelConfig(**mfconfig.parallel_config)
    mfconfig.model.model_config.parallel_config = parallel_config
    mfconfig.model.model_config.checkpoint_name_or_path = mfconfig.load_checkpoint
    return mfconfig


def create_mfconfig(config_path, device_target, device_id, bs, seq_len, tokenizer_path="", ckpt_path="", model_parallel=1):
    """Create a MindFormers config for the Llama2 network used in this example."""
    # MindSpore parallel mode does not support bfloat16 yet, so float16 is used in every case.
    compute_dtype = ms.float16
    use_parallel = model_parallel > 1
    model_parallel = max(model_parallel, 1)
    config = _set_config(config_path, device_target, device_id)
    config.model.model_config.batch_size = bs
    config.model.model_config.seq_length = seq_len
    config.model.model_config.compute_dtype = compute_dtype
    config.model.model_config.layernorm_compute_type = ms.float32
    config.model.model_config.softmax_compute_type = ms.float16
    config.model.model_config.rotary_dtype = ms.float32
    config.model.model_config.param_init_type = ms.float32
    config.processor.tokenizer.vocab_file = tokenizer_path
    config.load_checkpoint = ckpt_path
    config.model.model_config.checkpoint_name_or_path = ckpt_path
    config.use_parallel = use_parallel
    config.parallel_config.model_parallel = model_parallel
    return config


# The quantization phase runs on CPU.
context.set_context(device_target="CPU", mode=ms.GRAPH_MODE)
llama2_config_file = "./workspace/run_llama2_7b_910b.yaml"
llama2_w16a16_ckpt_file = "./workspace/llama2_7b.ckpt"
llama2_w8a16_ckpt_file = "./workspace/llama2-7b-w8a16.ckpt"
vocab_file = "./workspace/tokenizer.model"
wikitext2_ds_path = "./workspace/wikitext-2/wiki.valid.tokens"
bs = 1
seq_len = 256
device_id = 0  # change this according to which Ascend devices are idle in your environment

quant_network_config = create_mfconfig(llama2_config_file, "CPU", device_id, bs, seq_len, ckpt_path=llama2_w16a16_ckpt_file)
```

> When constructing the Llama2 network object, set `device_id` according to which Ascend devices are idle in your environment.

#### 2.2. Instantiate the Llama2 Network

```python
import mindspore as ms
from mindformers import LlamaForCausalLM

network = LlamaForCausalLM(quant_network_config.model.model_config)
network.set_train(False)
network.phase = 'predict'
```

#### 2.3. Instantiate the RTN Algorithm

```python
from mindspore_gs.common.gs_enum import PTQMode, BackendTarget
from mindspore_gs.ptq import PTQConfig
from mindspore_gs.ptq import RoundToNearest as RTN

# QUANTIZE mode: quantize the network, targeting the Ascend backend for later deployment.
cfg = PTQConfig(mode=PTQMode.QUANTIZE, backend=BackendTarget.ASCEND)
ptq = RTN(config=cfg)
```

#### 2.4. Quantize the Llama2 Network and Save the Checkpoint

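```python
# Apply RTN to the Llama2 backbone, then convert it into the quantized network.
qnet = ptq.apply(network.model)
qnet = ptq.convert(qnet)
# Put the quantized backbone back into the network and save the quantized weights.
network.model = qnet
ms.save_checkpoint(network, llama2_w8a16_ckpt_file)
```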

After a successful run, the file `llama2-7b-w8a16.ckpt` is generated in the `workspace` directory.

### Step 3. Deploy and Evaluate the Model

#### 3.1. Instantiate MindFormerConfig and the Llama2 Network

```python
context.set_context(device_target="Ascend")
deploy_network_config = create_mfconfig(llama2_config_file, "Ascend", device_id, bs, seq_len, ckpt_path=llama2_w8a16_ckpt_file)
deploy_network = LlamaForCausalLM(deploy_network_config.model.model_config)
deploy_network.set_train(False)
deploy_network.phase = 'predict'
```

#### 3.2. Load the Quantized Checkpoint

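MindSpore currently cannot save the modified network structure, so before loading the quantized checkpoint you must first use the algorithm to rebuild the network with its quantization structure, and only then load the checkpoint into that network.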

```python
# DEPLOY mode: rebuild the quantized network structure, then load the quantized weights into it.
deploy_cfg = PTQConfig(mode=PTQMode.DEPLOY, backend=BackendTarget.ASCEND)
deploy_ptq = RTN(config=deploy_cfg)
deploy_network.model = deploy_ptq.apply(deploy_network.model)
deploy_network.model = deploy_ptq.convert(deploy_network.model)
ms.load_checkpoint(llama2_w8a16_ckpt_file, deploy_network)
```
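Why the structure has to be rebuilt first can be seen by comparing parameter names. In this plain-Python sketch, all parameter names and the `check_loadable` helper are hypothetical and do not reflect Golden Stick's real naming: the quantized checkpoint holds an extra scale tensor that a plain float network has no slot for, while the rebuilt network matches it.

```python
def check_loadable(network_params: set, checkpoint_params: set) -> dict:
    """Compare the parameter names a network holds with those stored in a checkpoint."""
    return {
        "unused_in_checkpoint": sorted(checkpoint_params - network_params),
        "missing_in_checkpoint": sorted(network_params - checkpoint_params),
    }


# Hypothetical parameter names for one Linear layer (illustration only).
float_network = {"layers.0.wq.weight"}
quant_network = {"layers.0.wq.weight", "layers.0.wq.dequant_scale"}  # after apply() + convert()
quant_checkpoint = {"layers.0.wq.weight", "layers.0.wq.dequant_scale"}

print(check_loadable(float_network, quant_checkpoint))  # the scale has nowhere to go
print(check_loadable(quant_network, quant_checkpoint))  # names line up
```

Names are only part of the story: in a real quantized checkpoint the weight would presumably also have a different dtype (int8 rather than float16), which is another reason the structure must match before loading.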

#### 3.3. Evaluate the Quantized Network

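This example evaluates the perplexity (PPL) of Llama2 on the WikiText-2 validation split (`wiki.valid.tokens`). The tokenizer and dataset files prepared in Step 1 are used to instantiate the tokenizer and dataset objects, and a PerplexityMetric object is instantiated as the metric.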

```python
from mindspore_gs.datasets import create_wikitext_dataset
from mindformers import LlamaTokenizer
from mindformers.core.metric import PerplexityMetric

tokenizer = LlamaTokenizer(vocab_file=vocab_file)
deploy_ds = create_wikitext_dataset(wikitext2_ds_path, bs, seq_len, tokenizer)
deploy_metrics = {"PerplexityMetric": PerplexityMetric()}
deploy_model = ms.Model(deploy_network, metrics=deploy_metrics, eval_network=deploy_network)
quant_ppl = deploy_model.eval(deploy_ds, dataset_sink_mode=deploy_network_config.runner_config.sink_mode)
print(f"W8A16 Perplexity: {quant_ppl}")
```

Output: `W8A16 Perplexity: {'PerplexityMetric': {'loss': 2.9108531734988285, 'PPL': 18.372466783981967}}`
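Perplexity is, by definition, the exponential of the average per-token cross-entropy (negative log-likelihood, in nats). The plain-Python sketch below computes it from per-token losses; `perplexity` is a hypothetical helper, and how MindFormers' `PerplexityMetric` aggregates over batches is not covered here, so treat it as the textbook definition rather than that class's implementation.

```python
import math


def perplexity(token_nlls):
    """Perplexity = exp(mean per-token negative log-likelihood in nats)."""
    nlls = list(token_nlls)
    if not nlls:
        raise ValueError("at least one token is required")
    return math.exp(sum(nlls) / len(nlls))


# A model that gives every target token probability 1/4 has a perplexity of (about) 4.
print(perplexity([-math.log(0.25)] * 8))

# Exponentiating the reported W8A16 loss gives roughly the reported PPL (about 18.37).
print(math.exp(2.9108531734988285))
```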

#### 3.4. Evaluate the Perplexity of the FP16 Network

```python
fp16_network_config = create_mfconfig(llama2_config_file, "Ascend", device_id, bs, seq_len, ckpt_path=llama2_w16a16_ckpt_file)
fp16_network = LlamaForCausalLM(fp16_network_config.model.model_config)
fp16_network.set_train(False)
fp16_network.phase = 'predict'
ms.load_checkpoint(llama2_w16a16_ckpt_file, fp16_network)
fp16_ds = create_wikitext_dataset(wikitext2_ds_path, bs, seq_len, tokenizer)
fp16_metrics = {"PerplexityMetric": PerplexityMetric()}
fp16_model = ms.Model(fp16_network, metrics=fp16_metrics, eval_network=fp16_network)
fp16_ppl = fp16_model.eval(fp16_ds, dataset_sink_mode=fp16_network_config.runner_config.sink_mode)
print(f"FP16 Perplexity: {fp16_ppl}")
```

Output: `FP16 Perplexity: {'PerplexityMetric': {'loss': 3.0012421822877363, 'PPL': 18.581457892192436}}`

#### 3.5. Compare the Results

Table 2: Llama2 7B network before and after RTN quantization

| Metric | FP16 | W8A16 | Change (W8A16 vs. FP16) |
| --- | --- | --- | --- |
| Checkpoint size (GB) | 13 | 7.1 | -45.38% (relative) |
| WikiText-2 PPL ↓ | 18.581 | 18.372 | -0.209 (absolute) |
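The change column can be reproduced from the other two columns. A short sketch using only the numbers in the table: the size change is relative to FP16, while the PPL change is a plain difference.

```python
# Values copied from Table 2.
fp16 = {"size_gb": 13.0, "ppl": 18.581}
w8a16 = {"size_gb": 7.1, "ppl": 18.372}

size_change = (w8a16["size_gb"] - fp16["size_gb"]) / fp16["size_gb"]  # relative change
ppl_change = w8a16["ppl"] - fp16["ppl"]                                # absolute change

print(f"checkpoint size: {size_change:+.2%}")  # -45.38%
print(f"WikiText-2 PPL:  {ppl_change:+.3f}")   # -0.209
```

The size reduction is below 50% presumably because only the Linear weights are stored in int8, while other parameters and the quantization scales stay in higher precision; the per-parameter breakdown is not given in this document.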
