# Pretrain阶段的训练

### 设置训练参数

In [1]:
import torch
class pretrain_args:
    out_dir = "../out"
    epochs = 1
    batch_size = 32
    learning_rate = 5e-4
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    dtype = "bfloat16"
    use_wandb = False
    wandb_project = "MiniMind-Pretrain"
    num_workers = 1
    ddp = False
    accumulation_steps = 8
    grad_clip = 1.0
    warmup_iters = 0
    log_interval = 100
    save_interval = 100
    local_rank = -1
    embed_dim = 512
    block_num = 8
    max_seqlen = 1024
    use_moe = False
    data_path = "../data/pretrain_data.jsonl"

### 加载model

In [2]:
import sys
import os

# 获取当前 notebook 所在目录（trainer/）
current_dir = os.path.dirname(os.path.abspath("__file__"))  # 注意 Jupyter 中可能需要调整
# 或者直接写死路径
current_dir = "/data/zyp/jinbu/ZZY/minimind-v-learn/trainer"

# 上一级目录就是项目根目录，拼接 model 路径
model_dir = os.path.join(os.path.dirname(current_dir), "model")
sys.path.append(model_dir)

# 现在可以用绝对导入
from model import MinimindForCausalLM, MinimindConfig
train_args = pretrain_args()
train_args.save_dir = os.path.join(train_args.out_dir)
# 确保输出目录存在
os.makedirs(train_args.save_dir, exist_ok=True)
# 初始化模型配置
config = MinimindConfig(
    embed_dim=train_args.embed_dim,
    block_num=train_args.block_num,
    max_seqlen=train_args.max_seqlen,
)
print(f'查看工作设备 {train_args.device}')

  from .autonotebook import tqdm as notebook_tqdm


查看工作设备 cuda:0


# 单卡加载和训练（不采用DDP，wandb）

In [3]:
from transformers import AutoTokenizer
import math
from torch.utils.data import DataLoader
import sys
from pathlib import Path

# 项目根目录：/data/zyp/jinbu/ZZY/minimind-v-learn
root_dir = Path("/data/zyp/jinbu/ZZY/minimind-v-learn")

# 将根目录添加到 Python 可搜索路径
sys.path.append(str(root_dir))
from dataset.lm_dataset import PretrainDataset

def Logger(content):
    print(content)

def init_model(lm_config):
    tokenizer = AutoTokenizer.from_pretrained('../model/')
    model = MinimindForCausalLM(lm_config).to(train_args.device)
    Logger(f'LLM可训练总参数量：{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} 百万')
    return model, tokenizer


model, tokenizer = init_model(config)
print(model)
print(tokenizer)
train_ds = PretrainDataset(
    data_path=train_args.data_path,
    tokenizer=tokenizer,
    max_seqlen=train_args.max_seqlen,
)   
train_loader = DataLoader(
    train_ds,
    batch_size=train_args.batch_size,
    shuffle=True,
    num_workers=train_args.num_workers,
    pin_memory=True,
    drop_last=False
)

loader = iter(train_loader)
print(f'打印一个 iter 的数据:\n{next(loader)}\n')
print(f'数据集大小：{len(train_ds)}, DataLoader 大小：{len(loader)}')

LLM可训练总参数量：38.075 百万
MinimindForCausalLM(
  (embed): Embed(
    (embedding): Embedding(6400, 512)
  )
  (rmsnorm): RMSNorm()
  (minimind_dense): Minimind_Dense(
    (blocks): ModuleList(
      (0-7): 8 x Minimind_Block(
        (attention): GroupQueryAttention(
          (q_proj): Linear(in_features=512, out_features=512, bias=True)
          (k_proj): Linear(in_features=512, out_features=256, bias=True)
          (v_proj): Linear(in_features=512, out_features=256, bias=True)
          (o_proj): Linear(in_features=512, out_features=512, bias=True)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (res_dropout): Dropout(p=0.1, inplace=False)
        )
        (rmsnorm1): RMSNorm()
        (ffn): FeedForward(
          (gate): Linear(in_features=512, out_features=2048, bias=True)
          (up_proj): Linear(in_features=512, out_features=2048, bias=True)
          (down_proj): Linear(in_features=2048, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace

  X=torch.tensor(input_ids[:-1],dtype=torch.long)  # 去掉最后一个token
  Y=torch.tensor(input_ids[1:],dtype=torch.long)  # 去掉第一个
  loss_mask = torch.tensor(loss_mask[1:], dtype=torch.long)


打印一个 iter 的数据:
[tensor([[  56,   56,   57,  ...,    3,    3,    3],
        [3002,  944, 1641,  ...,    3,    3,    3]]), tensor([[  56,   57, 1495,  ...,    3,    3,    3],
        [ 944, 1641,  273,  ...,    3,    3,    3]]), tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])]

数据集大小：2, DataLoader 大小：1


### 选定优化器和scaler，自动进行混合精度训练加速
[常见的optimizer](https://zhuanlan.zhihu.com/p/416979875)<br>
[最新的Muon optimizer](https://blog.csdn.net/weixin_44778145/article/details/148722786)<br>
[混合精度原理](https://www.cnblogs.com/jimchen1218/p/14315008.html)

In [4]:
# 优化器方面 选择 AdamW 优化器 并在混精度场景下创建 scaler 进行梯度缩放避免数值下溢
from torch import optim
from contextlib import nullcontext
scaler = torch.cuda.amp.GradScaler(enabled=(train_args.dtype in ['float16', 'bfloat16']))
optimizer = optim.AdamW(model.parameters(), lr=train_args.learning_rate)

device_type = "cuda" if "cuda" in train_args.device else "cpu"
ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast() # 在 cuda 上启动混精度训练，否则空白上下文

### 正式进行训练
[余弦退火学习率](https://blog.csdn.net/weixin_42392454/article/details/127766771)<br>
[CrossEntropy介绍]()<br>
[梯度裁剪-clip_grad_norm全解](https://www.hubtools.cn/2025/clip_grad_norm.html)

In [9]:
from torch import nn
import time
iter_per_epoch = len(train_loader) # 计算每个 epoch 的迭代次数
def get_lr(current_step, total_steps, lr):
    # 余弦退火学习率调度
    return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))

def train_epoch(epoch):
    loss_fct = nn.CrossEntropyLoss(reduction='none')
    start_time = time.time()
    for step, (X, Y, loss_mask) in enumerate(train_loader):
        X = X.to(train_args.device)
        Y = Y.to(train_args.device)
        loss_mask = loss_mask.to(train_args.device)

        lr = get_lr(epoch * iter_per_epoch + step, train_args.epochs * iter_per_epoch, train_args.learning_rate)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        with ctx:
            res = model(X)
            if torch.isnan(res.logits).any() or torch.isinf(res.logits).any():
                Logger(f"Warning: logits contains NaN/Inf at step {step}")
                # 打印logits的范围，辅助排查
                Logger(f"logits range: {res.logits.min().item()} ~ {res.logits.max().item()}")
            loss = loss_fct(
                res.logits.view(-1, res.logits.size(-1)),
                Y.view(-1)
            ).view(Y.size())
            print(f"loss_mask.sum(): {loss_mask.sum()}")
            loss = (loss * loss_mask).sum() / loss_mask.sum() # 这里的loss 是有效非pad的token的平均loss
            # loss += res.aux_loss
            loss = loss / train_args.accumulation_steps

        scaler.scale(loss).backward()

        if (step + 1) % train_args.accumulation_steps == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), train_args.grad_clip)

            scaler.step(optimizer)
            scaler.update()

            optimizer.zero_grad(set_to_none=True)  # 清空梯度，为下一个iter做准备

        if step % train_args.log_interval == 0:
            spend_time = time.time() - start_time
            Logger(
                'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
                    epoch + 1,
                    train_args.epochs,
                    step,
                    iter_per_epoch,
                    loss.item() * train_args.accumulation_steps,
                    optimizer.param_groups[-1]['lr'],
                    spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))


        if (step + 1) % train_args.save_interval == 0:
            model.eval()
            moe_path = '_moe' if train_args.use_moe else ''
            ckp = f'{train_args.save_dir}/pretrain_{config.embed_dim}{moe_path}.pth'
            Logger(f'保存模型到 {ckp}')
            state_dict = model.state_dict()

            state_dict = {k: v.half() for k, v in state_dict.items()}  # 半精度保存
            torch.save(state_dict, ckp)
            model.train()

## 开始训练！ 

In [10]:
for epoch in range(train_args.epochs):
    train_epoch(epoch)

  X=torch.tensor(input_ids[:-1],dtype=torch.long)  # 去掉最后一个token
  Y=torch.tensor(input_ids[1:],dtype=torch.long)  # 去掉第一个
  loss_mask = torch.tensor(loss_mask[1:], dtype=torch.long)


logits range: nan ~ nan
loss_mask.sum(): 409
Epoch:[1/1](0/1) loss:nan lr:0.000550000000 epoch_Time:0.0min:


# !重大问题，loss 怎么是Nan????