# ERNIE-4.5-0.3B-PT LoRA 微调教程

本教程演示如何使用 LoRA 技术对 ERNIE-4.5-0.3B-PT 模型进行微调，构建一个能够模拟甄嬛对话风格的个性化 LLM。

## 环境要求
- Python 3.12+
- CUDA 12.4+
- PyTorch 2.5.1+
- Transformers 4.54.0+（原生支持 `ernie4_5` 模型架构；版本过旧会在第 4 节加载模型时报 `ValueError`）
- 显存需求：约 24GB

## 1. 导入必要的库

In [7]:
import os
import json
import torch
from datasets import Dataset
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM, 
    TrainingArguments, 
    Trainer, 
    DataCollatorForSeq2Seq
)
from peft import LoraConfig, TaskType, get_peft_model, PeftModel
import swanlab
from swanlab.integration.transformers import SwanLabCallback
from config import Paths, LoRAConfig, TrainingConfig, SwanLabConfig, InferenceConfig

print("✅ 库导入成功")

✅ 库导入成功


## 2. 检查环境和路径

In [11]:
# Report whether the configured model and dataset paths exist on disk.
for label, path in (("模型", Paths.MODEL_PATH), ("数据集", Paths.DATASET_PATH)):
    print(f"{label}路径: {path}")
    print(f"{label}存在: {os.path.exists(path)}")

# Report CUDA availability and, when a GPU is present, its count and current device name.
cuda_ok = torch.cuda.is_available()
print(f"CUDA 可用: {cuda_ok}")
if cuda_ok:
    print(f"GPU 数量: {torch.cuda.device_count()}")
    print(f"当前 GPU: {torch.cuda.get_device_name()}")

模型路径: d:\Sources\self-llm\examples\ERNIE-4.5-0.3B-PT-Lora\models\PaddlePaddle\ERNIE-4___5-0___3B-PT
模型存在: True
数据集路径: d:\Sources\self-llm\examples\ERNIE-4.5-0.3B-PT-Lora\..\..\dataset\huanhuan.json
数据集存在: True
CUDA 可用: True
GPU 数量: 1
当前 GPU: NVIDIA GeForce RTX 2080 Ti


## 3. 加载数据集

In [9]:
def load_dataset(data_path):
    """Read the Huanhuan JSON file and wrap its records in a ``datasets.Dataset``.

    Args:
        data_path: Path to a UTF-8 encoded JSON file holding a list of records.

    Returns:
        The ``Dataset`` produced by ``Dataset.from_list`` over the parsed records.
    """
    print(f"加载数据集: {data_path}")
    with open(data_path, encoding='utf-8') as fh:
        records = json.load(fh)
    print(f"数据集大小: {len(records)}")
    return Dataset.from_list(records)

# Load the dataset and print up to three raw instruction/output pairs for inspection.
dataset = load_dataset(Paths.DATASET_PATH)

print("\n数据集样例:")
preview_count = min(3, len(dataset))
for idx in range(preview_count):
    sample = dataset[idx]
    print(f"样例 {idx + 1}:")
    print(f"  指令: {sample['instruction']}")
    print(f"  输出: {sample['output']}")
    print()

加载数据集: d:\Sources\self-llm\examples\ERNIE-4.5-0.3B-PT-Lora\..\..\dataset\huanhuan.json
数据集大小: 3729

数据集样例:
样例 1:
  指令: 小姐，别的秀女都在求中选，唯有咱们小姐想被撂牌子，菩萨一定记得真真儿的——
  输出: 嘘——都说许愿说破是不灵的。

样例 2:
  指令: 这个温太医啊，也是古怪，谁不知太医不得皇命不能为皇族以外的人请脉诊病，他倒好，十天半月便往咱们府里跑。
  输出: 你们俩话太多了，我该和温太医要一剂药，好好治治你们。

样例 3:
  指令: 嬛妹妹，刚刚我去府上请脉，听甄伯母说你来这里进香了。
  输出: 出来走走，也是散心。



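## 4. 加载模型和 Tokenizer

> 注意：ERNIE-4.5 的模型类型为 `ernie4_5`，需要安装原生支持该架构的 Transformers（4.54.0 及以上）。若版本过旧，加载模型时会出现下方输出中的 `ValueError`，可通过 `pip install --upgrade transformers` 解决。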

In [10]:
# Load the tokenizer (slow/Python implementation; the checkpoint's custom code is allowed).
print("加载 tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
    Paths.MODEL_PATH,
    use_fast=False,
    trust_remote_code=True,
)
# process_func appends tokenizer.pad_token_id to every sample and the Seq2Seq collator
# pads with it, so a tokenizer without a pad token would break preprocessing.
# Fall back to the EOS token in that case.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
print("✅ tokenizer 加载成功")

# Load the base model in bfloat16 and let accelerate place it on the available device(s).
# NOTE(review): a "model type `ernie4_5` ... does not recognize this architecture"
# ValueError means the installed transformers is too old for ERNIE-4.5; upgrade it.
# NOTE(review): bf16 on pre-Ampere GPUs (e.g. an RTX 2080 Ti) may be slow or unsupported;
# confirm, or switch torch_dtype to float16/float32 on such hardware.
print("\n加载模型...")
model = AutoModelForCausalLM.from_pretrained(
    Paths.MODEL_PATH,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
print("✅ 模型加载成功")

# Make the input-embedding outputs require grad. This does NOT turn gradient checkpointing
# on (that is TrainingArguments(gradient_checkpointing=...)). It is the companion step that
# lets gradients flow through checkpointed layers while the base weights stay frozen
# under LoRA.
model.enable_input_require_grads()
print("✅ 已启用输入梯度（配合梯度检查点使用）")

加载 tokenizer...
✅ tokenizer 加载成功

加载模型...


ValueError: The checkpoint you are trying to load has model type `ernie4_5` but Transformers does not recognize this architecture. This could be because of an issue with the checkpoint, or because your version of Transformers is out of date.

You can update Transformers with the command `pip install --upgrade transformers`. If this does not work, and the checkpoint is very new, then there may not be a release version that supports this model yet. In this case, you can get the most up-to-date code by installing Transformers from source with the command `pip install git+https://github.com/huggingface/transformers.git`

## 5. 查看模型结构

In [None]:
# Print the module tree, then tally total vs. trainable parameter counts in one pass.
print("模型结构:")
print(model)

total_params = 0
trainable_params = 0
for param in model.parameters():
    count = param.numel()
    total_params += count
    if param.requires_grad:
        trainable_params += count
print(f"\n总参数量: {total_params:,}")
print(f"可训练参数量: {trainable_params:,}")

## 6. 数据预处理

In [None]:
def process_func(example, tokenizer, max_length=None):
    """Tokenize one instruction/output pair into causal-LM training features.

    The prompt is rendered by hand in the "User: ... / Assistant: " layout with a fixed
    role-play system line. Only the response tokens (plus the trailing tail token)
    contribute to the loss.

    Args:
        example: A record with ``"instruction"`` and ``"output"`` string fields.
        tokenizer: Callable as ``tokenizer(text, add_special_tokens=False)`` returning a
            mapping with ``"input_ids"``; must also expose ``pad_token_id``/``eos_token_id``.
        max_length: Maximum sequence length. ``None`` (the default) uses
            ``TrainingConfig.MAX_LENGTH``.

    Returns:
        A dict with equal-length ``input_ids``, ``attention_mask`` and ``labels`` lists.
        Prompt positions in ``labels`` are ``-100`` so the loss ignores them.
    """
    if max_length is None:
        max_length = TrainingConfig.MAX_LENGTH

    # NOTE(review): inference builds its prompt with tokenizer.apply_chat_template and a
    # configurable system message. Confirm that rendering matches this hand-written layout
    # and system line, otherwise the model sees a different prompt than it was trained on.
    prompt_ids = tokenizer(
        f"<|begin_of_sentence|>现在你要扮演皇帝身边的女人--甄嬛\n"
        f"User: {example['instruction']}\n"
        f"Assistant: ",
        add_special_tokens=False,
    )["input_ids"]
    response_ids = tokenizer(
        f"{example['output']}<|end_of_sentence|>", add_special_tokens=False
    )["input_ids"]

    # One extra token is appended after the end-of-sentence marker and kept in the labels
    # (tutorial convention). Use EOS when the tokenizer defines no pad token, so the
    # sequence never contains None.
    tail_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

    input_ids = prompt_ids + response_ids + [tail_id]
    labels = [-100] * len(prompt_ids) + response_ids + [tail_id]

    # Right-truncate; long samples may lose the end of the response (and its end marker).
    input_ids = input_ids[:max_length]
    labels = labels[:max_length]
    attention_mask = [1] * len(input_ids)

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }

# Tokenize every record with process_func, dropping the raw text columns.
# fn_kwargs passes the tokenizer explicitly instead of capturing it in a lambda.
print("开始数据预处理...")
tokenized_dataset = dataset.map(
    process_func,
    fn_kwargs={"tokenizer": tokenizer},
    remove_columns=dataset.column_names,
)
print("✅ 数据预处理完成")

# Inspect the result: dataset size, then the token length and content of the first sample.
# (The previous label said "shape" but printed the whole record.)
print(f"\n预处理后数据集大小: {len(tokenized_dataset)}")
first_sample = tokenized_dataset[0]
print(f"样例数据长度: {len(first_sample['input_ids'])}")
print(f"样例数据: {first_sample}")

## 7. 配置 LoRA

In [None]:
# Build the LoRA adapter config from the project's LoRAConfig, echo it, then wrap the model.
print("配置 LoRA...")
lora_settings = {
    "r": LoRAConfig.R,
    "lora_alpha": LoRAConfig.LORA_ALPHA,
    "lora_dropout": LoRAConfig.LORA_DROPOUT,
}
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules=LoRAConfig.TARGET_MODULES,
    inference_mode=False,
    **lora_settings,
)

print("LoRA 配置:")
print(f"  秩 (r): {lora_settings['r']}")
print(f"  Alpha: {lora_settings['lora_alpha']}")
print(f"  Dropout: {lora_settings['lora_dropout']}")
print(f"  目标模块: {LoRAConfig.TARGET_MODULES}")

# Wrap the base model with the adapter; only the LoRA weights remain trainable.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
print("✅ LoRA 配置完成")

## 8. 配置训练参数

In [None]:
# Gather Trainer hyperparameters from TrainingConfig. Built-in reporting integrations are
# turned off with report_to="none".
train_kwargs = dict(
    output_dir=Paths.OUTPUT_DIR,
    per_device_train_batch_size=TrainingConfig.PER_DEVICE_TRAIN_BATCH_SIZE,
    gradient_accumulation_steps=TrainingConfig.GRADIENT_ACCUMULATION_STEPS,
    logging_steps=TrainingConfig.LOGGING_STEPS,
    num_train_epochs=TrainingConfig.NUM_TRAIN_EPOCHS,
    save_steps=TrainingConfig.SAVE_STEPS,
    learning_rate=TrainingConfig.LEARNING_RATE,
    save_on_each_node=True,
    gradient_checkpointing=TrainingConfig.GRADIENT_CHECKPOINTING,
    report_to="none",
)
training_args = TrainingArguments(**train_kwargs)

# Echo the key settings as label/value pairs.
summary = (
    ("批量大小", TrainingConfig.PER_DEVICE_TRAIN_BATCH_SIZE),
    ("梯度累积步数", TrainingConfig.GRADIENT_ACCUMULATION_STEPS),
    ("学习率", TrainingConfig.LEARNING_RATE),
    ("训练轮次", TrainingConfig.NUM_TRAIN_EPOCHS),
    ("输出目录", Paths.OUTPUT_DIR),
)
print("训练参数:")
for label, value in summary:
    print(f"  {label}: {value}")

## 9. 配置 SwanLab（可选）

In [None]:
# Optionally attach SwanLab experiment tracking. The broad except is deliberate:
# visualisation is optional and a setup failure must not block training.
callbacks = []
try:
    swanlab_callback = SwanLabCallback(
        project=SwanLabConfig.PROJECT_NAME,
        experiment_name=SwanLabConfig.EXPERIMENT_NAME,
    )
except Exception as e:
    print(f"⚠️  SwanLab 配置失败: {e}")
    print("将不使用 SwanLab 进行可视化")
else:
    callbacks = [swanlab_callback]
    print("✅ SwanLab 回调配置成功")

## 10. 开始训练

In [None]:
# Assemble the Trainer. DataCollatorForSeq2Seq pads each batch to its longest sample.
collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=collator,
    callbacks=callbacks,
)

print("开始训练...")
print("注意：训练过程可能需要较长时间，请耐心等待")
trainer.train()

# Persist the final weights to training_args.output_dir (for a PEFT-wrapped model this
# is presumably just the adapter).
print("保存模型...")
trainer.save_model()

print("🎉 训练完成！")

## 11. 加载训练好的模型进行推理

In [None]:
def find_latest_checkpoint(output_dir):
    """Return the path of the newest saved LoRA adapter under ``output_dir``, or None.

    Prefers the ``checkpoint-<step>`` sub-directory with the highest numeric step. Entries
    that are not directories, or whose suffix is not a plain integer, are ignored instead
    of crashing ``int()``. If no checkpoint directory exists, falls back to ``output_dir``
    itself when it contains ``adapter_config.json`` (e.g. written by ``trainer.save_model()``).
    """
    if not os.path.isdir(output_dir):
        return None

    step_to_name = {}
    for name in os.listdir(output_dir):
        prefix, sep, step = name.partition("-")
        if (
            prefix == "checkpoint"
            and sep
            and step.isdecimal()
            and os.path.isdir(os.path.join(output_dir, name))
        ):
            step_to_name[int(step)] = name

    if step_to_name:
        return os.path.join(output_dir, step_to_name[max(step_to_name)])
    if os.path.isfile(os.path.join(output_dir, "adapter_config.json")):
        return output_dir
    return None


lora_path = find_latest_checkpoint(Paths.OUTPUT_DIR)
if lora_path:
    print(f"使用 checkpoint: {os.path.basename(os.path.normpath(lora_path))}")
else:
    print("未找到训练好的模型")

In [None]:
# Reload a clean base model and attach the trained LoRA adapter for inference.
# Skipped entirely when no adapter path was found.
if lora_path:
    base_load_kwargs = {
        "device_map": "auto",
        "torch_dtype": torch.bfloat16,
        "trust_remote_code": True,
    }
    base_model = AutoModelForCausalLM.from_pretrained(Paths.MODEL_PATH, **base_load_kwargs)

    # Wrap the base model with the saved adapter weights.
    inference_model = PeftModel.from_pretrained(base_model, model_id=lora_path)
    print("✅ 推理模型加载成功")

## 12. 推理测试

In [None]:
def generate_response(model, tokenizer, prompt, system_message=None, **generation_kwargs):
    """Generate the model's reply to a single user prompt.

    Args:
        model: A causal LM (e.g. the PEFT-wrapped inference model) exposing ``generate``
            and ``device``.
        tokenizer: Tokenizer providing ``apply_chat_template`` and pad/eos token ids.
        prompt: The user's message.
        system_message: System prompt. ``None`` uses ``InferenceConfig.SYSTEM_MESSAGE``.
        **generation_kwargs: Optional overrides for ``model.generate`` (e.g.
            ``max_new_tokens=64``). They take precedence over the ``InferenceConfig``
            defaults.

    Returns:
        The decoded reply, with special tokens removed and surrounding whitespace stripped.
    """
    if system_message is None:
        system_message = InferenceConfig.SYSTEM_MESSAGE

    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": prompt},
    ]

    # NOTE(review): training uses a hand-built "User: / Assistant: " prompt with a fixed
    # system line; confirm this chat template + system message renders the same text.
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Don't add special tokens again; the rendered template is used as-is.
    model_inputs = tokenizer([text], add_special_tokens=False, return_tensors="pt").to(model.device)

    gen_config = {
        "max_new_tokens": InferenceConfig.MAX_NEW_TOKENS,
        "do_sample": InferenceConfig.DO_SAMPLE,
        "temperature": InferenceConfig.TEMPERATURE,
        "top_p": InferenceConfig.TOP_P,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    gen_config.update(generation_kwargs)

    with torch.no_grad():
        # Pass the attention mask explicitly so generate() does not have to infer it
        # from pad tokens, which is unreliable when pad and eos ids coincide.
        generated_ids = model.generate(
            model_inputs.input_ids,
            attention_mask=model_inputs.attention_mask,
            **gen_config,
        )

    # Keep only the newly generated tokens (drop the echoed prompt).
    prompt_len = model_inputs.input_ids.shape[1]
    output_ids = generated_ids[0][prompt_len:].tolist()
    return tokenizer.decode(output_ids, skip_special_tokens=True).strip()

# Run a fixed set of probe questions through the fine-tuned model. A failure on one
# prompt is reported and the loop moves on to the next.
if lora_path:
    test_prompts = [
        "你是谁？",
        "你的家人都有谁？",
        "你最喜欢什么？",
        "你对皇上有什么看法？",
    ]
    separator = "-" * 30

    print("开始测试对话...")
    print("=" * 50)

    for idx, prompt in enumerate(test_prompts, start=1):
        print(f"\n测试 {idx}: {prompt}")
        print(separator)
        try:
            response = generate_response(inference_model, tokenizer, prompt)
        except Exception as e:
            print(f"生成回复失败: {e}")
        else:
            print(f"甄嬛: {response}")
        print(separator)

## 13. 交互式对话

In [None]:
# Single-turn "interactive" demo for notebooks: edit user_input and re-run this cell
# to ask a different question. Exit words suppress generation.
if lora_path:
    print("交互式对话模式（输入 'quit' 退出）:")
    print("=" * 50)

    user_input = "你在宫中的生活如何？"  # change this to try a different question
    exit_words = {'quit', 'exit', '退出'}
    wants_reply = bool(user_input) and user_input.lower() not in exit_words

    if wants_reply:
        try:
            response = generate_response(inference_model, tokenizer, user_input)
        except Exception as e:
            print(f"生成回复失败: {e}")
        else:
            print(f"你: {user_input}")
            print(f"甄嬛: {response}")

## 总结

恭喜！您已经成功完成了 ERNIE-4.5-0.3B-PT 模型的 LoRA 微调。

### 主要步骤回顾：
1. 环境和路径检查
2. 数据集加载和预处理
3. 模型和 tokenizer 加载
4. LoRA 配置和应用
5. 训练参数设置
6. 模型训练
7. 推理测试

### 下一步可以尝试：
- 调整 LoRA 参数（r, alpha, dropout）
- 修改训练参数（学习率、批量大小等）
- 使用更多的数据进行训练
- 尝试不同的目标模块组合

### 相关资源：
- [LoRA 原理详解](https://zhuanlan.zhihu.com/p/650197598)
- [SwanLab 官网](https://swanlab.cn/)
- [ERNIE-4.5-0.3B-PT 模型](https://www.modelscope.cn/models/PaddlePaddle/ERNIE-4.5-0.3B-PT)