# 8-Reason

参数量太小的模型直接通过冷启动 SFT + GRPO 几乎不可能获得任何推理效果，也就是说，这种训练方法对小模型推理能力的提升十分有限. 因此，MiniMind 项目作者改用推理数据集对 MiniMind 系列模型进行黑盒蒸馏来训练推理模型.

使用的推理数据格式:

```
{
    "conversations": [
        {"role": "user", "content": "Q1?"},
        {"role": "assistant", "content": "<think>T1</think>\n<answer>A1</answer>"},
        {"role": "user", "content": "Q2?"},
        {"role": "assistant", "content": "<think>T2</think>\n<answer>A2</answer>"}
    ]
}
```

此笔记本的完整实现见主仓库 `/minimind/train_distill_reason.py`

In [1]:
import os
import platform
import argparse
import time
import math
import warnings

import pandas as pd
import torch
import torch.nn.functional as F
import torch.distributed as dist
from contextlib import nullcontext

from torch import optim, nn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
from transformers import AutoTokenizer, AutoModelForCausalLM
from model.model import MiniMindLM
from model.LMConfig import LMConfig
from model.dataset import SFTDataset

warnings.filterwarnings('ignore')

## 可选参数设置

首先，查看训练的可选参数，这些参数在实际使用时通过解析命令行进行导入，我们用 class 进行包装.

In [2]:
class args:
    """Training options for this notebook.

    In real use these values are parsed from the command line; here they are
    wrapped in a plain class so the cells below can read them as attributes.
    """
    # out_dir: str = "out" # where PyTorch-format weights would be saved; unused because we only demonstrate training
    epochs: int = 1 # number of training epochs
    batch_size: int = 2 # the toy dataset is said to hold only two samples, so the batch size is 2 (the original note said "pretrain", but data_path below points at the reasoning data)
    learning_rate: float = 5e-4 # learning rate
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    dtype: str = 'bfloat16' # 16-bit float: 8 exponent bits + 7 mantissa bits
    # use_wandb: bool = False # whether to use wandb; not used here
    wandb_project: str = 'MiniMind-Notebook'
    num_workers: int = 1 # number of worker processes
    # ddp: bool = False # single-node multi-GPU; not used here
    accumulation_steps: int = 1 # gradient accumulation steps
    grad_clip: float = 1.0 # gradient clipping threshold
    warmup_iters: int = 0 # learning-rate warm-up iterations
    log_interval: int = 1 # log every step, for observation only
    # save_interval: int = 100 # checkpoint interval; not used here
    local_rank: int = 1 # device index
    dim: int = 512 # embedding dimension (model hyperparameter)
    n_layers: int = 2 # number of MiniMind blocks (model hyperparameter)
    max_seq_len: int = 512 # sequence length limit
    use_moe: bool = False # whether to enable mixture-of-experts
    data_path: str = './toydata/r1_data.jsonl' # dataset path

In [3]:
# Show which device was selected in args.device.
print(f'查看工作设备 {args.device}')

查看工作设备 cuda


接下来，我们对分词器、MiniMind 学生模型以及数据迭代器执行初始化.

In [4]:
def init_model(lm_config):
    """Create the tokenizer and a MiniMind student model without loading weights.

    Prints the number of trainable parameters (in millions) and moves the
    model to ``args.device``.

    Args:
        lm_config: Model configuration passed straight to ``MiniMindLM``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained('../model/minimind_tokenizer')
    model = MiniMindLM(lm_config)
    # Warm start (disabled in this notebook): load previously trained weights.
    # moe_path = '_moe' if lm_config.use_moe else ''
    # ckp = f'./out/rlhf_{lm_config.dim}{moe_path}.pth'
    # state_dict = torch.load(ckp, map_location=args.device)
    # model.load_state_dict(state_dict, strict=False)
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'LLM总参数量：{trainable_params / 1e6:.3f} 百万')
    model = model.to(args.device)
    return model, tokenizer

In [5]:
# Build the model configuration from the notebook options, then create the
# student model and tokenizer.
lm_config = LMConfig(dim=args.dim, n_layers=args.n_layers, max_seq_len=args.max_seq_len, use_moe=args.use_moe)
model, tokenizer = init_model(lm_config)

# Dataset built from args.data_path; max_length is taken from the model config.
train_ds = SFTDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)

# NOTE(review): args.num_workers is declared in `args` but is not passed here,
# so the DataLoader uses its own default — confirm this is intended.
train_loader = DataLoader(
    train_ds,
    batch_size=args.batch_size,
    pin_memory=True,
    drop_last=False,
    shuffle=False,
)

# Sanity output: model device, tokenizer vocabulary size and the loader object.
print(f'模型位于设备：{model.device}, 词表长度：{tokenizer.vocab_size}, DataLoader：{train_loader}')

LLM总参数量：8.915 百万
模型位于设备：cuda:0, 词表长度：6400, DataLoader：<torch.utils.data.dataloader.DataLoader object at 0x000001DAD0FFD630>


## 启动训练

接下来，我们定义 MiniMind 推理蒸馏训练所使用的优化器、损失函数和学习率调度，并进行一轮简单的训练.

In [6]:
# Learning-rate schedule: cosine annealing.
def get_lr(current_step, total_steps, lr):
    """Return the cosine-annealed learning rate for ``current_step``.

    The schedule is a constant floor of ``lr / 10`` plus a half-cosine term,
    so it yields ``1.1 * lr`` at step 0 and ``lr / 10`` at ``total_steps``.

    Args:
        current_step: Global step index.
        total_steps: Total number of steps in the schedule (must be non-zero).
        lr: Base learning rate.
    """
    floor = lr / 10
    angle = math.pi * current_step / total_steps
    cosine_part = 0.5 * lr * (1 + math.cos(angle))
    return floor + cosine_part

# Optimizer setup: AdamW, plus a GradScaler created for mixed-precision training
# so that scaled gradients avoid numeric underflow.
# NOTE(review): the scaler is enabled for 'bfloat16' as well as 'float16' — confirm
# that gradient scaling is wanted for bfloat16.
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
# Optimise the student model's parameters.
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)

device_type = "cuda" if "cuda" in args.device else "cpu"
# NOTE(review): autocast() is called without a dtype argument, so args.dtype is
# not forwarded to it — confirm which precision is actually intended.
ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast() # mixed-precision autocast on CUDA, otherwise a no-op context

接下来，我们来看看训练函数.

蒸馏思考数据集的训练过程与 SFT 类似，区别在于模型生成序列中，思考标签位置的预测错误惩罚被放大.

In [7]:
def train_epoch(epoch):
    """Run one epoch of reasoning distillation over ``train_loader``.

    The loss is a masked per-token cross-entropy, as in SFT, except that every
    position whose target token is one of the ``<think>``, ``</think>``,
    ``<answer>`` or ``</answer>`` tag tokens gets a loss weight of 10.

    Relies on module-level state: ``model``, ``tokenizer``, ``train_loader``,
    ``optimizer``, ``scaler``, ``ctx``, ``args``, ``get_lr`` and
    ``iter_per_epoch`` (which must be assigned before the first call).

    Args:
        epoch: Zero-based epoch index, used for the learning-rate schedule and
            the log line.
    """
    # Token ids of the reasoning tags.
    start_of_think_ids = tokenizer('<think>').input_ids
    end_of_think_ids = tokenizer('</think>').input_ids
    start_of_answer_ids = tokenizer('<answer>').input_ids
    end_of_answer_ids = tokenizer('</answer>').input_ids
    # The tag ids never change between steps, so build the lookup tensor once
    # instead of re-creating and re-uploading it on every batch.
    special_ids = torch.tensor(
        start_of_think_ids + end_of_think_ids
        + start_of_answer_ids + end_of_answer_ids
    ).to(args.device)
    loss_fct = nn.CrossEntropyLoss(reduction='none')  # per-token cross-entropy
    start_time = time.time()
    for step, (X, Y, loss_mask) in enumerate(train_loader):
        X = X.to(args.device)
        Y = Y.to(args.device)
        loss_mask = loss_mask.to(args.device)
        lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        with ctx:
            res = model(X)
            loss = loss_fct(
                res.logits.view(-1, res.logits.size(-1)),
                Y.view(-1)
            ).view(Y.size())
            # Boolean mask over the flattened targets: True where the target
            # token is one of the reasoning tags.
            sp_ids = torch.isin(Y.view(-1), special_ids)
            # Add the extra penalty at the tag positions. The normaliser is
            # taken BEFORE the weights are raised, so the tag positions really
            # contribute more to the final loss.
            loss_mask = loss_mask.view(-1)
            loss_mask_sum = loss_mask.sum()
            loss_mask[sp_ids] = 10
            loss_mask = loss_mask.view(Y.size())
            loss = (loss * loss_mask).sum() / loss_mask_sum  # tag positions weigh ten times more
            loss += res.aux_loss  # load-balancing loss
            loss = loss / args.accumulation_steps

        scaler.scale(loss).backward()

        if (step + 1) % args.accumulation_steps == 0:
            # Unscale before clipping so the threshold applies to true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)

            scaler.step(optimizer)
            scaler.update()

            optimizer.zero_grad(set_to_none=True)

        if step % args.log_interval == 0:
            spend_time = time.time() - start_time
            print(
                'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
                    epoch + 1,
                    args.epochs,
                    step,
                    iter_per_epoch,
                    # Undo the accumulation scaling so the logged value is the
                    # real per-batch loss regardless of accumulation_steps.
                    loss.item() * args.accumulation_steps,
                    optimizer.param_groups[-1]['lr'],
                    spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))

        # Checkpointing (disabled in this notebook): save the weights as a
        # torch file every `save_interval` steps. Note that `ddp`,
        # `args.save_interval` and `args.save_dir` are not defined in this
        # notebook's `args`.
        # if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
        #     model.eval()
        #     moe_path = '_moe' if lm_config.use_moe else ''
        #     ckp = f'{args.save_dir}/reason_{lm_config.dim}{moe_path}.pth'

        #     if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        #         state_dict = model.module.state_dict()
        #     else:
        #         state_dict = model.state_dict()

        #     torch.save(state_dict, ckp)
        #     model.train()

接下来，我们启动一个 Epoch 的训练进行观察.

In [8]:
# Number of batches per epoch; train_epoch reads this module-level name for the
# learning-rate schedule and the progress log, so it must be set before the loop.
iter_per_epoch = len(train_loader)
for epoch in range(args.epochs):
    train_epoch(epoch)

Epoch:[1/1](0/1) loss:11.656 lr:0.000550000000 epoch_Time:0.0min:


In [9]:
# Drop the notebook's reference to the model.
# NOTE(review): `optimizer` was built from model.parameters() and still holds
# references to them — confirm whether it should be deleted too if the goal is
# to release memory.
del model