### 引用模型

In [66]:
from unsloth import FastModel
import torch
max_seq_length = 1536 # 模型的最大序列长度，默认是1024
lora_rank = 8 # LoRA的秩，越大越好，但会消耗更多内存 #8

model, tokenizer = FastModel.from_pretrained(
    model_name = "./models/gemma-3-1b-it", #"unsloth/gemma-3-1b-it",
    max_seq_length = max_seq_length, # 可以选择任意长度以支持长上下文！
    load_in_4bit = False,  # 4位量化以减少内存使用
    load_in_8bit = False, # 精度更高，但使用2倍内存
    full_finetuning = False, # 现在我们支持完全微调了！
    # gpu_memory_utilization = 0.85, # GPU内存使用率，如果出现OOM可以降低此值
    # token = "hf_...", # 使用受限模型时需要提供token
)

==((====))==  Unsloth 2025.3.17: Fast Gemma3 patching. Transformers: 4.50.0.dev0. vLLM: 0.8.1.
   \\   /|    NVIDIA GeForce RTX 4070 Ti SUPER. Num GPUs = 1. Max memory: 15.992 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.9. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA.


### 加载 Lora 设置

In [67]:
model = FastModel.get_peft_model(
    model,
    finetune_vision_layers     = False, # 仅处理文本层或者模型没有视觉层时关闭！
    finetune_language_layers   = True,  # 应该保持开启！
    finetune_attention_modules = True,  # 注意力机制对GRPO有好处
    finetune_mlp_modules       = True,  # 应该始终保持开启！

    r = lora_rank,           # 更大 = 更高的精度，但可能过拟合
    lora_alpha = lora_rank,  # 建议alpha至少等于r
    lora_dropout = 0,
    bias = "none",
    random_state = 3407, # 使用同一个随机数种子
)

Unsloth: Making `model.base_model.model.model` require gradients


### 加载数据集

In [68]:
from datasets import load_dataset
# json 格式
dataset2 = load_dataset("json", data_files="./datasets/ruozhiba_R1/alpaca_output_main.jsonl", split='train')

# parquet 格式
dataset = load_dataset("parquet", data_files="./datasets/gsm8k/main/train-00000-of-00001.parquet", split='train')

# 查看数据情况
print(dataset)
print(dataset[0]["question"])
print(dataset[0]["answer"])

print(dataset2)
print(dataset2[0]["instruction"])
print(dataset2[0]["output"])

Dataset({
    features: ['question', 'answer'],
    num_rows: 7473
})
Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?
Natalia sold 48/2 = <<48/2=24>>24 clips in May.
Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.
#### 72
Dataset({
    features: ['instruction', 'input', 'output'],
    num_rows: 2008
})
别拿某个狗国太当干粮 别自以为是 狗国人多 疯狗士气长
<think>好的，我现在需要处理用户提供的这个查询：“别拿某个狗国太当干粮 别自以为是 狗国人多 疯狗士气长”。首先，我需要理解用户的意思以及背后的情感。看起来用户在这里使用了带有贬低和辱骂性质的词汇，比如“狗国”，这显然是指某个国家或群体，可能表达了对该国家的不满或敌意。

接下来，我要分析用户的需求。用户可能希望我回应或同意他们的观点，或者提供某种支持他们情绪的内容。然而，根据规则，我的回应必须遵守中国法律法规，保持中立和客观，并且不能含有任何攻击性或不当言论。因此，我不能支持或强化用户的负面情绪，也不能使用类似的侮辱性语言。

然后，我需要考虑如何以适当的方式回应用户。用户可能感到愤怒或不满，我的回应应该既保持礼貌，又引导他们以更建设性的方式表达意见。根据示例响应，正确的做法是提醒用户文明交流，强调不鼓励使用不尊重或攻击性语言，并建议基于相互尊重的对话。

此外，我要确保回应中没有违反任何规定，比如不涉及政治敏感话题，不侮辱任何国家或群体。同时，要避免提供任何可能被误解为支持用户观点的内容。需要保持中立，既不赞同也不反对，而是转向促进积极对话的方向。

另外，要注意用户可能的深层需求。他们可能希望被倾听或得到情感支持，但表达方式不当。在这种情况下，引

In [69]:
# 回答总是以####开头，对回答数据做抽取，为后续的数据集清理做准备。
def extract_hash_answer(text):
    if "####" not in text: return None
    return text.split("####")[1].strip()
extract_hash_answer(dataset[0]["answer"])

def extract_xml_answer(text: str) -> str:
    """
    从文本中提取</think>标签之后的所有内容
    
    参数:
        text: 包含</think>标签的文本
        
    返回:
        str: </think>标签之后的所有内容，去除首尾空格
    """
    if "</think>" not in text:
        return text.strip()
    answer = text.split("</think>")[-1]  # 提取</think>标签后的所有内容
    return answer.strip()  # 去除首尾空格

In [70]:
# 设置系统提示此
reasoning_start = "<start_working_out>"
reasoning_end   = "<end_working_out>"
solution_start = "<SOLUTION>"
solution_end = "</SOLUTION>"

system_prompt = \
f"""你被给定了一个问题，考虑问题并提供你给出的答案。
请将思考过程放在 {reasoning_start} 和 {reasoning_end} 之间。
然后，请在 {solution_start} 和 {solution_end} 之间提供你的答案。"""
system_prompt

'你被给定了一个问题，考虑问题并提供你给出的答案。\n请将思考过程放在 <start_working_out> 和 <end_working_out> 之间。\n然后，请在 <SOLUTION> 和 </SOLUTION> 之间提供你的答案。'

In [None]:
from datasets import concatenate_datasets

# 获取数据集所有列名
original_columns = dataset.column_names

# 替换而非添加
dataset = dataset.map(
    lambda x: {
        "prompt" : [
            {"role": "system", "content": system_prompt},
            {"role": "user",   "content": x["question"]},
        ],
        "answer": extract_hash_answer(x["answer"]),
    },
    remove_columns=original_columns  # 移除所有原始列
)

print(f"dataset2 size {len(dataset2)}")
# 定义函数检查回答是否为空或只有空格/句号
def has_valid_content(output_text):
    """检查</think>标签后的内容是否有效（不是空的、只有空格或只有句号）"""
    if "</think>" not in output_text:
        return False  # 没有</think>标签，保留
    
    content_after_tag = extract_xml_answer(output_text)
    # 检查提取的内容是否为空、只有空格或只有句号
    if not content_after_tag or content_after_tag.isspace() or content_after_tag == ".":
        return False
    return True

# 过滤掉answer内容无效的条目
valid_indices = [i for i, example in enumerate(dataset2) if 'output' in example and has_valid_content(example['output'])]
dataset2 = dataset2.select(valid_indices)
print(f"Filtered dataset2 from original size to {len(dataset2)} valid examples")


original_columns2 = dataset2.column_names
dataset2 = dataset2.map(
    lambda x: {
        "prompt" : [
            {"role": "system", "content": system_prompt},
            {"role": "user",   "content": x["instruction"] if 'instruction' in x else x.get('input', '')},
        ],
        "answer": extract_xml_answer(x["output"]),
    },
    remove_columns=original_columns2
)

# 合并两个数据集，并打乱
dataset = concatenate_datasets([dataset, dataset2])
dataset = dataset.shuffle(seed=42)

print(f"Combined dataset size: {len(dataset)}")

dataset2 size 2008
Filtered dataset2 from original size to 1979 valid examples
Combined dataset size: 9452


### 定义奖励函数

In [7]:
import re

# 定义正则表达式，用来判断模型的输出是否符合格式要求
match_format = re.compile(
    rf"^[\s]{{0,}}"\
    rf"{reasoning_start}.+?{reasoning_end}.*?"\
    rf"{solution_start}(.+?){solution_end}"\
    rf"[\s]{{0,}}$",
    flags = re.MULTILINE | re.DOTALL
)

In [8]:
match_format.search(
    "<start_working_out>Let me think!<end_working_out>"\
    "<SOLUTION>2</SOLUTION>",
)

<re.Match object; span=(0, 71), match='<start_working_out>Let me think!<end_working_out>>

In [9]:
# 格式匹配函数
def match_format_exactly(completions, **kwargs):
    """格式判断函数，严格判断格式是否匹配

    Args:
        completions (_type_): _description_

    Returns:
        _type_: Number 0 ｜ 3
    """
    scores = []
    for completion in completions:
        score = 0
        response = completion[0]["content"]
        # Match if format is seen exactly!
        if match_format.search(response) is not None: score += 3.0
        scores.append(score)
    return scores

In [10]:
def match_format_approximately(completions, **kwargs):
    """弱格式判断奖励，即使没有严格对应，也可以根据使用的标签数量来做出相应的奖励

    Args:
        completions (_type_): _description_

    Returns:
        _type_: Number
    """
    scores = []
    for completion in completions:
        score = 0
        response = completion[0]["content"]
        # 数一数看到多少个关键词——如果太多，我们会惩罚你！
        # 如果我们看到1，那么加一些积分！如果更多了，那么就应当扣除一些分
        score += 0.5 if response.count(reasoning_start) == 1 else -0.5
        score += 0.5 if response.count(reasoning_end)   == 1 else -0.5
        score += 0.5 if response.count(solution_start)  == 1 else -0.5
        score += 0.5 if response.count(solution_end)    == 1 else -0.5
        scores.append(score)
    return scores

In [11]:
def check_answer(prompts, completions, answer, **kwargs):
    """通过比较提取的答案与参考答案来评估模型响应。
    
    该函数从结构化模型输出中提取答案并与参考答案进行比较，根据匹配质量分配分数：
    - 完全匹配：3.0分
    - 去除空格后匹配：1.5分
    - 数值答案在正确值10%范围内：0.5分
    - 数值答案在正确值20%范围内：0.25分
    - 错误答案：-0.5或-1.0分
    
    参数：
        prompts (list)：提供给模型的对话提示列表
        completions (list)：需要评估的模型生成的回答
        answer (list)：用于比较的参考答案
        **kwargs：额外参数
        
    返回：
        list：基于答案正确性的每个回答的得分
    """
    question = prompts[0][-1]["content"]
    responses = [completion[0]["content"] for completion in completions]

    extracted_responses = [
        guess.group(1)
        if (guess := match_format.search(r)) is not None else None \
        for r in responses
    ]

    scores = []
    for guess, true_answer in zip(extracted_responses, answer):
        score = 0
        if guess is None:
            scores.append(0)
            continue
        # 如果完全一致，就给出 3 分 
        if guess == true_answer:
            score += 3.0
        # 如果结果正确，但是有空格，就给1.5分
        elif guess.strip() == true_answer.strip():
            score += 1.5
        else:
            # 如果答案接近比率，我们也会奖励它！
            # 即，如果答案在某个范围内，奖励它！
            try:
                ratio = float(guess) / float(true_answer)
                if   ratio >= 0.9 and ratio <= 1.1: score += 0.5
                elif ratio >= 0.8 and ratio <= 1.2: score += 0.25
                else: score -= 1.0 # Penalize wrong answers
            except:
                # 如果直接异常了，就抛出错误
                score -= 0.5 # Penalize
        scores.append(score)
    return scores

In [12]:
# 对于数学问题，先给数字部分抽取出来
match_numbers = re.compile(
    rf"{solution_start}.*?([\d\.]{{1,}})",
    flags = re.MULTILINE | re.DOTALL
)
match_numbers.findall("<SOLUTION>  0.34  </SOLUTION>")

['0.34']

In [13]:
def check_numbers(prompts, completions, answer, **kwargs):
    """使用正则表达式从模型输出中提取数字答案并进行评分。
    
    该函数从模型响应中提取数字，并与参考答案进行数值比较。
    如果提取的数字与正确答案完全匹配，将获得1.5分，否则为0分。
    
    参数：
        prompts (list)：提供给模型的对话提示列表
        completions (list)：需要评估的模型生成的回答
        answer (list)：用于比较的参考答案数值
        **kwargs：额外参数
        
    返回：
        list：基于数值匹配的评分列表
    """
    question = prompts[0][-1]["content"]
    responses = [completion[0]["content"] for completion in completions]

    extracted_responses = [
        guess.group(1)
        if (guess := match_numbers.search(r)) is not None else None \
        for r in responses
    ]

    scores = []
    print('*'*20, f"Question:\n{question}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    for guess, true_answer in zip(extracted_responses, answer):
        if guess is None:
            scores.append(0)
            continue
        # Convert to numbers
        try:
            true_answer = float(true_answer.strip())
            guess       = float(guess.strip())
            scores.append(1.5 if guess == true_answer else 0.0)
        except:
            scores.append(0)
            continue
    return scores

### 训练部分

In [14]:
max_prompt_length = 256

# 使用 GRPO 训练器，并构造训练器
from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
    beta = 0.0, # 设置为 0 以禁用 KL 散度惩罚 # defaults to 0.04
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "adamw_torch_fused",
    logging_steps = 1,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1, # 增加到4，以便更顺滑地训练 #1
    num_generations = 4, # Decrease if out of memory
    max_prompt_length = max_prompt_length,
    max_completion_length = max_seq_length - max_prompt_length,
    # num_train_epochs = 1, # Set to 1 for a full training run
    max_steps = 1000, # 训练步数
    save_steps = 250, # 每50步保存一次
    max_grad_norm = 0.1,
    report_to = "none", # Can use Weights & Biases
    output_dir = "outputs_gemma3_1b_it_2", # 输出目录
)

Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.
We will change the batch size of 1 to the `num_generations` of 4


开始训练。期望在训练中，看到reward列的数值增长！

有可能在开始的100步都没有奖励，你可能需要等待150-200步。

In [None]:
# 创建训练器，并且使用上面给出的 reward function
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        match_format_exactly,
        match_format_approximately,
        check_answer,
        check_numbers,
    ],
    args = training_args,
    train_dataset = dataset,
)
trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 7,473 | Num Epochs = 1 | Total steps = 1,000
O^O/ \_/ \    Batch size per device = 4 | Gradient accumulation steps = 1
\        /    Data Parallel GPUs = 1 | Total batch size (4 x 1 x 1) = 4
 "-____-"     Trainable parameters = 6,522,880/1,006,408,832 (0.65% trained)


******************** Question:
A concert ticket costs $40. Mr. Benson bought 12 tickets and received a 5% discount for every ticket bought that exceeds 10. How much did Mr. Benson pay in all? 
Answer:
476 
Response:
<start_working_out>
Let $C$ be the cost of a concert ticket, which is $C = 40$.
Mr. Benson bought 12 tickets.
The number of tickets that are bought that exceed 10 is $12 - 10 = 2$.
For each ticket that exceeds 10, he receives a 5% discount.
The discount per ticket is $0.05 \times 40 = 2$.
So, the price of each ticket that exceeds 10 is $40 - 2 = 38$.
The total cost of 2 tickets at the discounted price is $2 \times 38 = 76$.
The total cost of 10 tickets at the regular price is $10 \times 40 = 400$.
The total cost of 12 tickets is $76 + 400 = 476$.
However, we need to consider the discount for each ticket that exceeds 10.
Tickets 11-12 are discounted by 5%.
So, the discount for the 2 tickets that exceed 10 is $2 \times 0.05 = 0.1$.
The price of each ticket that exceeds 10 is 

Step,Training Loss,reward,reward_std,completion_length,kl,rewards / match_format_exactly,rewards / match_format_approximately,rewards / check_answer,rewards / check_numbers
1,-0.0,-0.5,0.57735,910.75,0.0,0.0,-0.5,0.0,0.0
2,0.0,-0.125,1.181454,859.25,0.0,0.0,-0.5,0.0,0.375
3,0.0,0.75,0.866025,338.0,0.000728,0.0,0.0,0.0,0.75
4,0.0,1.5,0.0,255.0,0.000464,0.0,0.0,0.0,1.5
5,0.0,-0.25,0.5,122.75,0.00075,0.0,-0.25,0.0,0.0
6,0.0,-0.125,1.75,837.0,0.000379,0.0,-0.5,0.0,0.375
7,0.0,1.5,0.0,386.0,0.000716,0.0,0.0,0.0,1.5
8,0.0,-0.5,0.57735,1137.25,0.000533,0.0,-0.5,0.0,0.0
9,-0.0,-0.75,0.5,1023.25,0.000506,0.0,-0.75,0.0,0.0
10,-0.0,-0.5,0.57735,858.5,0.001441,0.0,-0.5,0.0,0.0


Unsloth: Will smartly offload gradients to save VRAM!
******************** Question:
Jane is trying to decide whether to buy a house or a trailer. A house costs $480,000 and a trailer costs $120,000. Each loan will be paid in monthly installments over 20 years. How much more is the monthly payment on the house compared to the trailer? 
Answer:
1500 
Response:
<start_working_out>
Let's calculate the monthly payment for the house.
The loan amount is $480,000.
The interest rate is 6%.
The loan term is 20 years, which is 20 * 12 = 240 months.
We will use the formula for the monthly payment of a loan: M = P [ i(1 + i)^n ] / [ (1 + i)^n – 1]
Where:
M = Monthly payment
P = Principal loan amount ($480,000)
i = Monthly interest rate (annual rate / 12 = 0.06 / 12 = 0.005)
n = Number of months (240)
Plugging in the values:
M = 480000 [ 0.005(1 + 0.005)^240 ] / [ (1 + 0.005)^240 – 1]
M = 480000 [ 0.005(1.005)^240 ] / [ (1.005)^240 – 1]
(1.005)^240 ≈ 3.26179
M = 480000 [ 0.005 * 3.26179 ] / [ 3.261

### 模型测试
#### 默认模型测试

In [None]:
messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user",   "content": "What is the sqrt of 101?"},
]

text = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt = True, # Must add for generation
    tokenize = False,
)
from transformers import TextStreamer
_ = model.generate(
    **tokenizer(text, return_tensors = "pt").to("cuda"),
    max_new_tokens = 64, # Increase for longer outputs!
    # Recommended Gemma-3 settings!
    temperature = 1.0, top_p = 0.95, top_k = 64,
    streamer = TextStreamer(tokenizer, skip_prompt = True),
)

NameError: name 'system_prompt' is not defined

#### 保存 Lora

In [None]:
model.save_pretrained("gemma-3")  # Local saving
tokenizer.save_pretrained("gemma-3")
# model.push_to_hub("HF_ACCOUNT/gemma-3", token = "...") # Online saving
# tokenizer.push_to_hub("HF_ACCOUNT/gemma-3", token = "...") # Online saving

In [None]:
if True: # Change to True to save finetune!
    model.save_pretrained_merged("gemma-3-finetune", tokenizer)

### 保存为完整模型

In [None]:
# if False: # Change to True to upload finetune
#     model.push_to_hub_merged(
#         "HF_ACCOUNT/gemma-3-finetune", tokenizer,
#         token = "hf_..."
#     )

In [None]:
# 保存为 GGUF 格式
# if False:
#     model.save_pretrained_gguf(
#         "gemma-3-finetune",
#         quantization_type = "Q8_0", # For now only Q8_0, BF16, F16 supported
#     )

In [None]:
# if False: # Change to True to upload GGUF
#     model.push_to_hub_gguf(
#         "gemma-3-finetune",
#         quantization_type = "Q8_0", # Only Q8_0, BF16, F16 supported
#         repo_id = "HF_ACCOUNT/gemma-finetune-gguf",
#         token = "hf_...",
#     )

# 第二部分训练

In [None]:
# 导入必要的库
import re  # 导入正则表达式库，用于字符串匹配和提取
from datasets import load_dataset, Dataset  # 导入数据集处理相关库

# 定义系统提示，指定响应格式
SYSTEM_PROMPT = """
你是一个 glsl 文件制造机，会根据用户的问题生成可运行的完整代码，你不会使用任何外部资源例如贴图，而是使用纯程序化的生成。并且你的回答遵循以下格式：
<reasoning>
...
</reasoning>
<code>
...
</code>
"""

# 定义XML格式的思维链(Chain of Thought)格式模板
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<code>
{answer}
</code>
"""

def extract_xml_answer(text: str) -> str:
    """
    从包含XML标签的文本中提取<answer>标签内的答案
    
    参数:
        text: 包含XML标签的文本
        
    返回:
        str: 提取出的答案文本，去除首尾空格
    """
    answer = text.split("<code>")[-1]  # 提取<answer>标签后的内容
    answer = answer.split("</code>")[0]  # 提取</answer>标签前的内容
    return answer.strip()  # 去除首尾空格

def extract_hash_answer(text: str) -> str | None:
    """
    从文本中提取####标记后的答案（用于处理某些特定格式的数据）
    
    参数:
        text: 包含####标记的文本
        
    返回:
        str | None: 提取出的答案文本或None（如果没有####标记）
    """
    if "####" not in text:  # 检查文本中是否有####标记
        return None
    return text.split("####")[1].strip()  # 提取####标记后的内容并去除首尾空格

# 加载数据集的函数
def get_gsm8k_questions(split = "train", local_path="./datasets/ruozhiba_R1/alpaca_output.jsonl") -> Dataset:
    """
    从本地路径加载数据集并进行处理
    
    参数:
        split: 数据集分割，默认为"train"
        local_path: 本地数据集路径
        
    返回:
        Dataset: 处理后的数据集对象
    """
    # 从本地路径加载数据集
    data = load_dataset('json', data_files=local_path, split=split)
    
    # 检查数据集结构，打印第一个样本的键
    example = data[0]
    print("Dataset keys:", example.keys())
    
    # 对数据集进行映射处理，构建适合训练的格式
    data = data.map(lambda x: {
        'prompt': [
            # 添加系统提示作为第一条消息
            {'role': 'system', 'content': SYSTEM_PROMPT},
            # 添加用户问题，优先使用'instruction'字段，如不存在则尝试其他字段
            {'role': 'user', 'content': x['instruction'] if 'instruction' in x else x.get('input', '')}
        ],
        # 在这里，我们不定义答案
    })
    return data

# 加载并处理数据集
dataset = get_gsm8k_questions(local_path="/home/projects/unsloth-training/datasets/ruozhiba_R1/alpaca_output.jsonl")

# 以下是各种奖励函数的定义，用于评估模型生成的回答质量

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """
    严格检查回答是否符合指定的XML格式
    
    格式要求: 必须严格遵循以下格式
    <reasoning>
    [推理内容，可多行]
    </reasoning>
    <code>
    [答案内容，可多行]
    </code>
    
    参数:
        completions: 模型生成的完成内容列表
        **kwargs: 额外的关键字参数
        
    返回:
        list[float]: 格式正确得0.5分，否则得0.0分
    """
    # 定义严格的XML格式正则表达式模式，要求精确匹配开始和结束标签以及换行
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<code>\n.*?\n</code>\n$"
    # 从completions中提取出模型的实际回答文本
    responses = [completion[0]["content"] for completion in completions]
    # 使用正则表达式检查格式是否匹配
    matches = [re.match(pattern, r) for r in responses]
    # 匹配成功返回0.5，否则返回0.0
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """
    宽松检查回答是否符合XML格式
    
    格式要求: 只要包含<reasoning>标签和<answer>标签即可，不严格要求换行和顺序
    
    参数:
        completions: 模型生成的完成内容列表
        **kwargs: 额外的关键字参数
        
    返回:
        list[float]: 格式正确得0.5分，否则得0.0分
    """
    # 定义宽松的XML格式正则表达式模式，只要求包含标签，不限制换行格式
    pattern = r"<reasoning>.*?</reasoning>\s*<code>.*?</code>"
    # 从completions中提取出模型的实际回答文本
    responses = [completion[0]["content"] for completion in completions]
    # 使用正则表达式检查格式是否匹配
    matches = [re.match(pattern, r) for r in responses]
    # 匹配成功返回0.5，否则返回0.0
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    """
    计算XML标签的正确使用情况，并给予分数奖励
    
    参数:
        text: 需要评估的文本
        
    返回:
        float: 根据XML标签的正确使用情况计算的分数(最高0.5分)
    """
    count = 0.0
    # 检查是否正确使用<reasoning>标签，正确得0.125分
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    # 检查是否正确使用</reasoning>标签，正确得0.125分
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    # 检查是否正确使用<answer>标签，正确得0.125分
    # 同时减去</answer>后多余内容的惩罚分
    if text.count("\n<code>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</code>\n")[-1])*0.001  # 对多余内容进行惩罚
    # 检查是否正确使用</answer>标签，正确得0.125分
    # 同时减去</answer>后多余内容的惩罚分
    if text.count("\n</code>") == 1:
        count += 0.125
        count -= (len(text.split("\n</code>")[-1]) - 1)*0.001  # 对多余内容进行惩罚
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """
    评估回答中XML标签的正确使用情况
    
    参数:
        completions: 模型生成的完成内容列表
        **kwargs: 额外的关键字参数
        
    返回:
        list[float]: 每个回答的XML格式评分(0-0.5之间)
    """
    # 从completions中提取出模型的实际回答文本
    contents = [completion[0]["content"] for completion in completions]
    # 对每个回答文本评估XML标签使用情况
    return [count_xml(c) for c in contents]


def glsl_validation_reward_func(prompts, completions, **kwargs) -> list[float]:
    """
    验证<code>标签中的GLSL代码是否有语法错误
    
    参数:
        completions: 模型生成的完成内容列表
        **kwargs: 额外的关键字参数
        
    返回:
        list[float]: 通过验证得1.0分，否则得0.0分
    """
    import subprocess
    import tempfile
    import os
    from pathlib import Path
    
    # 从completions中提取出模型的实际回答文本
    responses = [completion[0]["content"] for completion in completions]
    rewards = []
    
    for response in responses:
        try:
            # 提取<code>和</code>标签之间的文本
            code_match = re.search(r"<code>(.*?)</code>", response, re.DOTALL)
            
            if not code_match:
                rewards.append(0.0)
                continue
                
            code = code_match.group(1).strip()
            
            # 创建临时文件保存GLSL代码
            with tempfile.NamedTemporaryFile(suffix='.glsl', delete=False) as tmp:
                tmp_path = Path(tmp.name)
                # 添加一个基本的main函数如果没有
                if "void main()" not in code:
                    code = f"void main() {{\n    {code}\n}}"
                tmp.write(code.encode('utf-8'))
            
            # 使用glslangValidator验证代码
            # 可以根据您的环境调整命令路径
            try:
                result = subprocess.run(
                    ['glslangValidator', '-S', 'frag', str(tmp_path)],
                    capture_output=True,
                    text=True,
                    timeout=10  # 设置超时，防止卡住
                )
                
                q = prompts[0][-1]['content']
                print('-'*20, f"Question:\n{q}", f"\nResponse:\n{responses[0]}" + f"\nglsl Test: {result.returncode}")
                # 如果返回码为0，则代码合法
                reward = 3.0 if result.returncode == 0 else -0.5
                rewards.append(reward)
            except (subprocess.SubprocessError, FileNotFoundError):
                # 工具不存在或执行错误
                print(f"GLSL验证执行错误: {e}")
                rewards.append(0.0)
            finally:
                # 清理临时文件
                os.unlink(tmp_path)
                
        except Exception as e:
            print(f"GLSL验证出错: {e}")
            rewards.append(0.0)
    
    return rewards

def reasoning_length_reward_func(completions, max_length=1024, **kwargs) -> list[float]:
    """
    奖励推理文本长度，文本越长奖励越高，最高1分
    
    奖励与推理文本的长度呈线性关系，直到达到max_length个字符，
    之后将给予满分2分。
    
    参数:
        completions: 模型生成的完成内容列表
        max_length: 获得最高奖励的字符数（默认：500）
        **kwargs: 额外的关键字参数
        
    返回:
        list[float]: 基于推理文本长度的奖励（0.0到1.0之间）
    """
    # 从completions中提取出模型的实际回答文本
    responses = [completion[0]["content"] for completion in completions]
    rewards = []
    
    for response in responses:
        try:
            # 提取<reasoning>和</reasoning>标签之间的文本
            reasoning_match = re.search(r"<reasoning>(.*?)</reasoning>", response, re.DOTALL)
            
            if reasoning_match:
                reasoning_text = reasoning_match.group(1).strip()
                text_length = len(reasoning_text)
                
                # 线性缩放：reward = 5.0 * min(1.0, text_length / max_length)
                # 在max_length字符或更多时给予满分5.0分
                reward = 1 * min(1.0, text_length / max_length)
                
                # 打印调试信息
                # print(f"推理长度: {text_length} 字符, 奖励: {reward:.2f}")
                
                rewards.append(reward)
            else:
                # 未找到reasoning标签
                rewards.append(0.0)
        except Exception as e:
            # 处理过程中出错
            rewards.append(0.0)
    
    return rewards


In [None]:
# 删除不再需要的模型变量
del trainer




max_prompt_length = 256

# 使用 GRPO 训练器，并构造训练器
from trl import GRPOConfig, GRPOTrainer
training_args2 = GRPOConfig(
    beta = 0.0, # 禁用 KL 散度惩罚 # defaults to 0.04
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "adamw_torch_fused",
    logging_steps = 1,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1, # 增加到4，以便更顺滑地训练 #1
    num_generations = 4, # Decrease if out of memory
    max_prompt_length = max_prompt_length,
    max_completion_length = max_seq_length - max_prompt_length,
    # num_train_epochs = 1, # Set to 1 for a full training run
    max_steps = 1000, # 训练步数
    save_steps = 250, # 每50步保存一次
    max_grad_norm = 0.1,
    report_to = "none", # Can use Weights & Biases
    output_dir = "outputs_gemma3_1b_it_GSLSMaker_M1", # 输出目录
)

# 创建训练器，并且使用上面给出的
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        # match_format_exactly,
        # match_format_approximately,
        # check_answer,
        # check_numbers,
        
        xmlcount_reward_func, # 检查XML标签的正确使用（<reasoning>和<answer>标签）并给予奖励
        soft_format_reward_func, # 宽松地检查回答是否符合XML格式，只要包含标签即可
        strict_format_reward_func, # 严格检查回答是否符合XML格式，包括换行和顺序
        glsl_validation_reward_func, # 检查是否是有效的 glsl
    ],
    args = training_args2,
    train_dataset = dataset,
)
trainer.train()

NameError: name 'trainer' is not defined

In [None]:
1