# RL 训练v4

In [1]:
import os
# Both environment variables below are set before `unsloth` is imported
# (presumably so that they are already in place when the library initialises).
# Workaround taken from:
# https://github.com/unslothai/unsloth/issues/2299#issuecomment-2782067709
os.environ["VLLM_USE_V1"] = '0'
# Statistics reporting has to be disabled on networks that cannot reach the
# outside internet (e.g. mainland China); otherwise loading the model hangs.
# NOTE(review): the value assigned is "0" - presumably only the presence of the
# variable matters; confirm that "0" really disables the statistics call.
os.environ["UNSLOTH_DISABLE_STATISTICS"] = "0"

from unsloth import FastLanguageModel, is_bfloat16_supported
import torch

max_seq_length = 4096 # Can increase for longer reasoning traces
lora_rank = 64 # Larger rank = smarter, but slower

# Load the base checkpoint in 4-bit with vLLM-backed fast inference enabled.
model, tokenizer = FastLanguageModel.from_pretrained(
    # A model that was trained on 4-number problems is reused here for 3-number
    # problems, to avoid distilling yet another model.
    model_name = "/data/countdown/output/models/qwen2.5-1.5b-sft-distill-merged", # change to your model path
    max_seq_length = max_seq_length,
    load_in_4bit = True, # False for LoRA 16bit
    local_files_only=True,
    fast_inference = True, # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.4, # Reduce if out of memory
)

# Wrap the loaded model with LoRA adapters on the attention and MLP projections.
model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ], # Remove QKVO if out of memory
    lora_alpha = lora_rank * 2,
    use_gradient_checkpointing = "unsloth", # Enable long context finetuning
    random_state = 3407,
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
Unsloth: Failed to patch Gemma3ForConditionalGeneration.
🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 04-22 19:51:34 [__init__.py:239] Automatically detected platform cuda.
==((====))==  Unsloth 2025.3.19: Fast Qwen2 patching. Transformers: 4.51.3. vLLM: 0.8.4.
   \\   /|    NVIDIA GeForce RTX 4090. Num GPUs = 1. Max memory: 22.159 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.9. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: vLLM loading /data/countdown/output/models/qwen2.5-1.5b-sft-distill-merged with actual GPU utilization = 39.23%
Unsloth: Your GPU has CUDA compute capability 8.9 with VRAM = 22.16 GB.
Unsloth: Using conservativeness = 1.0. Chunked prefi



INFO 04-22 19:51:46 [loader.py:1166] Loading weights with BitsAndBytes quantization. May take a while ...


Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]


INFO 04-22 19:51:47 [punica_selector.py:18] Using PunicaWrapperGPU.
INFO 04-22 19:51:47 [model_runner.py:1146] Model loading took 1.2132 GiB and 1.005270 seconds
INFO 04-22 19:51:50 [worker.py:267] Memory profiling takes 1.74 seconds
INFO 04-22 19:51:50 [worker.py:267] the current vLLM instance can use total_gpu_memory (22.16GiB) x gpu_memory_utilization (0.39) = 8.69GiB
INFO 04-22 19:51:50 [worker.py:267] model weights take 1.21GiB; non_torch_memory takes 0.08GiB; PyTorch activation peak memory takes 1.05GiB; the rest of the memory reserved for KV Cache is 6.35GiB.
INFO 04-22 19:51:50 [executor_base.py:112] # cuda blocks: 14861, # CPU blocks: 14043
INFO 04-22 19:51:50 [executor_base.py:117] Maximum concurrency for 4096 tokens per request: 58.05x
INFO 04-22 19:51:53 [model_runner.py:1456] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If

Capturing CUDA graph shapes:   0%|          | 0/27 [00:00<?, ?it/s]

INFO 04-22 19:52:21 [model_runner.py:1598] Graph capturing finished in 28 secs, took 2.87 GiB
INFO 04-22 19:52:21 [llm_engine.py:449] init engine (profile, create kv cache, warmup model) took 33.90 seconds


Sliding Window Attention is enabled but not implemented for `eager`; unexpected results may be encountered.
Unsloth 2025.3.19 patched 28 layers with 28 QKV layers, 28 O layers and 28 MLP layers.


In [2]:
from constant import SYSTEM_PROMPT, USER_PROMPT_TPL, parse_user_prompt
from datasets import load_dataset
def get_countdown_questions(data_file):
    """Load a JSON dataset file and add a chat-style ``prompt`` column to it.

    Args:
        data_file: Path (or paths) forwarded to ``load_dataset("json", data_files=...)``.

    Returns:
        The ``"train"`` split with an extra ``prompt`` field per row: a system
        message holding ``SYSTEM_PROMPT`` followed by a user message rendered
        from ``USER_PROMPT_TPL`` with the row's ``numbers`` and ``target``.
    """
    def _with_prompt(example):
        # Render the user turn from the row's own numbers and target.
        user_text = parse_user_prompt(USER_PROMPT_TPL, example["numbers"], example["target"])
        return {
            'prompt': [
                {'role': 'system', 'content': SYSTEM_PROMPT},
                {'role': 'user', "content": user_text},
            ],
        }

    train_split = load_dataset("json", data_files=data_file)["train"]
    return train_split.map(_with_prompt)

# Build the RL training set (each row gets a `prompt` column) from the JSONL file.
dataset = get_countdown_questions("data/rl_data_simple_10k.jsonl")

# Display the first row to sanity-check the rendered prompt.
dataset[0]

{'numbers': [91, 100, 44],
 'target': 147,
 'ground_truth_solution': '(91 - 44) + 100',
 'prompt': [{'content': 'You are a helpful assistant. You first thinks about the reasoning process in the mind and then provides the user with the answer.',
   'role': 'system'},
  {'content': 'Using the numbers 91, 100, 44, create an equation that equals 147. You can use basic arithmetic operations (+, -, *, /) one or multiple times but each number can only be used once, and you must use all the numbers. Show your work in <think> </think> tags. And return the final equation in <answer> </answer> tags, for example <answer>(1 + 2) / 3</answer>. Think step by step inside <think> tags.',
   'role': 'user'}]}

In [3]:
import re
def extract_xml_answer(response):
    """Return the text enclosed by the first ``<answer>...</answer>`` pair.

    The match is non-greedy and may span several lines. The captured text is
    returned unmodified (no stripping); an empty string is returned when the
    response contains no complete answer tag pair.
    """
    answer_re = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)
    found = answer_re.search(response)
    if found is None:
        return ""
    return found.group(1)

# Reward function that checks if the completion follows a more relaxed format
def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Give 0.5 to every completion that follows the relaxed think/answer layout.

    A completion qualifies when its text *starts* with a ``<think>...</think>``
    section followed (optionally after whitespace) by an
    ``<answer>...</answer>`` section; anything after that is ignored.

    Args:
        completions: One entry per sample; ``completion[0]["content"]`` is the text.
        **kwargs: Additional keyword arguments are accepted and ignored.

    Returns:
        One reward per completion: 0.5 on a format match, otherwise 0.0.
    """
    format_re = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)
    rewards = []
    for completion in completions:
        text = completion[0]["content"]
        # `match` anchors at the beginning of the text only.
        rewards.append(0.5 if format_re.match(text) else 0.0)
    return rewards

# Correctness reward
def correctness_reward_func(completions, numbers, target, **kwargs) -> list[float]:
    """Give 2.0 to every completion whose ``<answer>`` is a valid solution, else 0.0.

    An answer is valid when, after dropping an optional ``= ...`` suffix, it
    (1) consists only of digits, ``+ - * /``, parentheses, dots and whitespace,
    (2) uses exactly the given numbers, each exactly once, and
    (3) evaluates to the target (within floating-point rounding error).

    Args:
        completions: One entry per sample; ``completion[0]['content']`` is the text.
        numbers: Per-sample list of the integers that have to be used.
        target: Per-sample value the expression has to evaluate to.
        **kwargs: Additional keyword arguments are accepted and ignored.

    Returns:
        One reward per completion, in input order.
    """
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]

    # Whitelist: numbers, operators, parentheses, dots and whitespace only.
    allowed_pattern = re.compile(r'^[\d+\-*/().\s]+$')

    def correct(numbers, target, solution):
        """Return ``(is_correct, reason)`` for a single extracted answer string."""
        try:
            # Drop the equals sign and everything after it ("a + b = c" -> "a + b").
            solution = solution.split('=')[0].strip()
            # Validate BEFORE evaluating: the text is model-generated, so it must
            # never reach eval() unless it is a plain arithmetic expression.
            if not allowed_pattern.match(solution):
                return False, "Solution is not allowed"
            # The whitelist admits "*" and "/", hence also "**" and "//". Neither
            # is one of the four basic operations, and "**" can make evaluation
            # arbitrarily expensive, so both are rejected.
            if '**' in solution or '//' in solution:
                return False, "Solution is not allowed"
            # Every given number has to be used, and each only once.
            used_numbers = [int(n) for n in re.findall(r'\d+', solution)]
            if sorted(used_numbers) != sorted(numbers):
                return False, "Solution is not using all numbers or each number is not used only once"
            # Evaluate with no builtins and no names in scope.
            value = eval(solution, {"__builtins__": {}}, {})
            # Division yields floats; compare with a small tolerance so that e.g.
            # 1 / 49 * 49 still counts as 1.
            if abs(value - target) > 1e-9 * max(1.0, abs(target)):
                return False, "Solution is not equal to target"
            return True, "Solution is correct"
        except Exception as e:
            # Any evaluation failure (syntax error, division by zero, ...) counts
            # as an incorrect answer.
            return False, f"Solution calculation failed: {e}"

    return [
        2.0 if correct(numbers[i], target[i], r)[0] else 0.0
        for i, r in enumerate(extracted_responses)
    ]

In [4]:
# Test the reward functions on a previously saved evaluation result set.
from datasets import load_from_disk
eval_dataset = load_from_disk("output/test_vllm__qwen2.5-1.5b-sft-distill-lora_results")
# Index of the stored sample used for the manual check in the next cell.
idx = 11
# Display that sample (target, numbers, prompt and completion).
eval_dataset[idx]

{'target': 656,
 'numbers': [85, 36, 16, 519],
 'prompt': 'Using the numbers 85, 36, 16, 519, create an equation that equals 656. You can use basic arithmetic operations (+, -, *, /) one or multiple times but each number can only be used once, and you must use all the numbers. Show your work in <think> </think> tags. And return the final equation in <answer> </answer> tags, for example <answer>(1 + 2) / 3</answer>. Think step by step inside <think> tags.',
 'completion': "<think>Okay, let's see. I need to make an equation using 85, 36, 16, and 519, each exactly once, with basic operations, and the result should be 656. Hmm.\n\nFirst, maybe I should look at the target number, 656. Let me think if there's any way to break it down into parts that could relate to these numbers. For instance, 656 divided by something... Maybe 656 is close to 519 plus something? Wait, 519 is already given as part of the numbers. So maybe 519 plus what else would get me to 656?\n\nWait, 519 + 137 = 656. But h

In [5]:
# Smoke-test both reward functions on the single stored sample selected above.
# The sample is wrapped in the nested list-of-messages shape that the reward
# functions index into (completion[0]["content"]).
print(f'Soft format reward: {soft_format_reward_func([[{"content": eval_dataset[idx]["completion"]}]])}')
score = correctness_reward_func(
    [[{"content": eval_dataset[idx]["completion"]}]],
    # `numbers` and `target` are passed as one-element batches.
    [eval_dataset[idx]["numbers"]],
    [eval_dataset[idx]["target"]])
print(f'Correctness reward: {score}')

Soft format reward: [0.5]
Correctness reward: [2]


In [6]:
import wandb
# Start a Weights & Biases run in the project that collects this experiment
# (the training config below sets report_to="wandb").
wandb.init(project="countdown-rl-simple-10k")

[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mswulling[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [7]:
from trl import GRPOConfig, GRPOTrainer
# GRPO training configuration.
# Note: max_prompt_length (512) + max_completion_length (3584) = 4096, which is
# the value of max_seq_length that is also passed as vllm_max_model_len.
training_args = GRPOConfig(
    use_vllm = True, # use vLLM for fast inference!
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.01,
    lr_scheduler_type = "linear",
    optim = "adamw_8bit",
    # Use bf16 when the GPU supports it, otherwise fall back to fp16.
    bf16 = is_bfloat16_supported(),
    fp16 = not is_bfloat16_supported(),
    per_device_train_batch_size = 8,
    gradient_accumulation_steps = 4, # Increase to 4 for smoother training
    num_generations = 8, # Decrease if out of memory
    vllm_max_model_len= max_seq_length,
    max_prompt_length = 512,
    max_completion_length = 3584,
    temperature = 1.0, # set to 1.0 for more diverse responses
    #num_train_epochs = 1, # Set to 1 for a full training run
    max_steps = 1000,
    save_steps = 100,
    max_grad_norm = 0.1,
    output_dir = "output/rl4",
    beta=0.001,
    report_to = "wandb", # Can use Weights & Biases
    log_completions=True,
    logging_steps=1,
    # Evaluation is left disabled because of an Unsloth GRPO eval bug:
    # unsloth grpo eval bug: https://github.com/unslothai/unsloth/issues/2367
    # do_eval=True,
    # eval_strategy="steps",
    # eval_steps=1,
    # per_device_eval_batch_size = 32,
)

In [8]:
# Build the GRPO trainer with the two reward functions defined above
# (format reward and correctness reward) and start training.
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        soft_format_reward_func,
        correctness_reward_func,
    ],
    args = training_args,
    train_dataset = dataset,
    # No eval dataset is passed; evaluation is disabled in the training config.
    # eval_dataset = eval_dataset,
)
trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 10,000 | Num Epochs = 1 | Total steps = 1,000
O^O/ \_/ \    Batch size per device = 8 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (8 x 4 x 1) = 32
 "-____-"     Trainable parameters = 73,859,072/5,000,000,000 (1.48% trained)


Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss,reward,reward_std,completion_length,kl,rewards / soft_format_reward_func,rewards / correctness_reward_func
1,-0.0,2.125,0.593677,1695.03125,0.0,0.4375,1.6875


KeyboardInterrupt: 

v4: https://wandb.ai/swulling/countdown-rl-simple-10k?nw=nwuserswulling

In [None]:
# Save the model (the LoRA weights) and the tokenizer side by side in the same
# local directory.
model.save_pretrained("output/qwen2.5-1.5b-rl-v4-lora")  # Local saving lora weights
tokenizer.save_pretrained("output/qwen2.5-1.5b-rl-v4-lora")


```bash
vllm serve output/models/Qwen2.5-1.5B-Instruct --port 8100 --api-key "$VLLM_API_KEY" --enforce-eager  --max-model-len 4096 --enable-lora --max-lora-rank 64 --lora-modules qwen2.5-1.5b-rl-v4-lora=output/qwen2.5-1.5b-rl-v4-lora

CURATOR_VIEWER=1 python eval.py --provider vllm --data_path data/test_simple.jsonl --model_name qwen2.5-1.5b-rl-v4-lora --temperature 0.01 --max_tokens 2048

https://curator.bespokelabs.ai/datasets/fc75d8c833ba4d7984d925604337a9d5  

Accuracy: 45/100 (45.00%)
```