# Post-Training an LLM for Reasoning with GRPO in TRL

In this example, we will explore the process of post-training an LLM using **Group Relative Policy Optimization (GRPO)**, a method introduced in the paper [*DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models*](https://arxiv.org/abs/2402.03300). GRPO is particularly effective for **scaling test-time compute for extended reasoning**, making it an ideal approach for solving complex tasks, such as mathematical problem solving.

**GRPO** is a **reinforcement learning (RL) post-training technique** that was integrated into the training pipeline for DeepSeek-R1. Unlike earlier techniques that relied on search heuristics, GRPO relies exclusively on RL for post-training, enhancing the model's capacity to handle complex and nuanced tasks.
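
The "group relative" part of the name refers to how GRPO scores completions without a learned value model: for each prompt it samples a group of completions, scores them, and normalizes each reward against the group's mean and standard deviation. The sketch below shows only that normalization step in plain Python; the clipped policy-gradient objective and the KL penalty from the paper are omitted, and the `eps` value and the use of the population standard deviation are our own choices (implementations may use the sample standard deviation instead).

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards, eps=1e-4):
    """Turn the rewards of one prompt's completion group into GRPO-style advantages."""
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population standard deviation of the group
    # Above-average completions get positive advantages, below-average ones negative
    return [(r - mu) / (sigma + eps) for r in rewards]


# Four sampled completions for the same prompt: two correct (1.0), two wrong (0.0)
print(group_relative_advantages([1.0, 0.0, 0.0, 1.0]))
# If every completion gets the same reward, all advantages are zero (no learning signal)
print(group_relative_advantages([1.0, 1.0, 1.0, 1.0]))  # [0.0, 0.0, 0.0, 0.0]
```

The group size corresponds to the `num_generations` parameter that appears in the training configuration later on.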

GRPO is available in the `trl` library. For a reproduction of the full DeepSeek-R1 training process, see Hugging Face's [Open-R1](https://github.com/huggingface/open-r1) repository.

In this example, we will focus on **post-training with GRPO**.

## Setup

In [1]:
!pip install -qU trl peft math_verify
# Tested with transformers==4.47.1, trl==0.14.0, datasets==3.2.0, peft==0.14.0, accelerate==1.2.1, math_verify==0.3.3

## Load dataset

Reasoning models excel at tasks that require **complex reasoning**. A prime example is **mathematical problem-solving**, which often demands multi-step reasoning to arrive at a correct solution.

In this example, we will use the [`AI-MO/NuminaMath-TIR`](https://huggingface.co/datasets/AI-MO/NuminaMath-TIR) dataset. This is a **reasoning-focused dataset** that contains mathematical problems, their solutions, and detailed reasoning steps that explain how to transition from the problem statement to the final solution.

In [2]:
from datasets import load_dataset

dataset_id = 'AI-MO/NuminaMath-TIR'
train_dataset, test_dataset = load_dataset(
    dataset_id,
    split=['train[:5%]', 'test[:5%]']
)

In [3]:
train_dataset

Dataset({
    features: ['problem', 'solution', 'messages'],
    num_rows: 3622
})
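
The `'train[:5%]'` and `'test[:5%]'` slices keep the first 5% of each split. As a quick check of the row count, assuming the full splits hold 72,441 and 99 examples and that the percentage is rounded to the nearest whole row (our reading of the `datasets` default), the arithmetic works out as follows:

```python
def rows_in_percent_slice(num_examples, percent):
    """Approximate size of a 'split[:N%]' slice, rounding to the nearest row (assumed)."""
    return round(num_examples * percent / 100)


print(rows_in_percent_slice(72441, 5))  # 3622 (train slice)
print(rows_in_percent_slice(99, 5))     # 5 (test slice)
```

The train result matches the `num_rows: 3622` reported above.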

In [4]:
train_dataset[0]

{'problem': 'What is the coefficient of $x^2y^6$ in the expansion of $\\left(\\frac{3}{5}x-\\frac{y}{2}\\right)^8$?  Express your answer as a common fraction.',
 'solution': "To determine the coefficient of \\(x^2y^6\\) in the expansion of \\(\\left(\\frac{3}{5}x - \\frac{y}{2}\\right)^8\\), we can use the binomial theorem.\n\nThe binomial theorem states:\n\\[\n(a + b)^n = \\sum_{k=0}^{n} \\binom{n}{k} a^{n-k} b^k\n\\]\n\nIn this case, \\(a = \\frac{3}{5}x\\), \\(b = -\\frac{y}{2}\\), and \\(n = 8\\).\n\nWe are interested in the term that contains \\(x^2y^6\\). In the general term of the binomial expansion:\n\\[\n\\binom{8}{k} \\left(\\frac{3}{5}x\\right)^{8-k} \\left(-\\frac{y}{2}\\right)^k\n\\]\n\nTo get \\(x^2\\), we need \\(8 - k = 2\\), thus \\(k = 6\\).\n\nSubstituting \\(k = 6\\) into the expression:\n\\[\n\\binom{8}{6} \\left(\\frac{3}{5}x\\right)^{8-6} \\left(-\\frac{y}{2}\\right)^6 = \\binom{8}{6} \\left(\\frac{3}{5}x\\right)^2 \\left(-\\frac{y}{2}\\right)^6\n\\]\n\nNow, we w

In the DeepSeek-R1 training procedure, a specific system prompt frames each conversation so that the model writes out its reasoning before giving its answer. We will adapt our dataset to the same format, guiding the model to first think through the problem and then present its answer.

In [5]:
SYSTEM_PROMPT = """
A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e.,
<think> reasoning process here </think><answer> answer here </answer>
"""

def make_conversation(example):
    return {
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': example['problem']}
        ]
    }


train_dataset = train_dataset.map(make_conversation)
test_dataset = test_dataset.map(make_conversation)

In [6]:
train_dataset[0]['prompt']

[{'content': '\nA conversation between User and Assistant. The user asks a question, and the Assistant solves it.\nThe assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nThe reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e.,\n<think> reasoning process here </think><answer> answer here </answer>\n',
  'role': 'system'},
 {'content': 'What is the coefficient of $x^2y^6$ in the expansion of $\\left(\\frac{3}{5}x-\\frac{y}{2}\\right)^8$?  Express your answer as a common fraction.',
  'role': 'user'}]
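
With prompts in this shape, a well-behaved completion should contain a `<think>` block followed by an `<answer>` block. The helper below is a hypothetical sketch (not part of TRL) of how the two parts can be pulled out of such a completion with a regular expression; `re.DOTALL` lets the reasoning span several lines.

```python
import re

# Hypothetical helper pattern: capture the <think> and <answer> contents
TAGGED = re.compile(r"<think>(.*?)</think>\s*<answer>(.*?)</answer>", re.DOTALL)


def split_reasoning(completion):
    """Return (reasoning, answer) stripped of whitespace, or None if the tags are missing."""
    match = TAGGED.search(completion)
    if match is None:
        return None
    return match.group(1).strip(), match.group(2).strip()


print(split_reasoning("<think>\n2 + 3 = 5\n</think>\n<answer>5</answer>"))  # ('2 + 3 = 5', '5')
print(split_reasoning("The answer is 5."))  # None
```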

We will remove the `messages` and `problem` columns, since we only need `prompt` as the model input and `solution` to verify the generated answer.

In [7]:
train_dataset = train_dataset.remove_columns(['messages', 'problem'])
test_dataset = test_dataset.remove_columns(['messages', 'problem'])

In [8]:
train_dataset

Dataset({
    features: ['solution', 'prompt'],
    num_rows: 3622
})

## Post-training the base model using GRPO

### Loading the baseline model

We will first load [`Qwen/Qwen2-0.5B-Instruct`](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the baseline model (the policy model, in RL terms) for this example.

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = 'Qwen/Qwen2-0.5B-Instruct'
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype='auto',
    device_map='auto'
)

### Configure LoRA

Next, we will configure LoRA for model training.

In [None]:
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    task_type='CAUSAL_LM',
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=['q_proj', 'v_proj']
)

model = get_peft_model(model, lora_config)

model.print_trainable_parameters()
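
With `r=8`, LoRA adds two small matrices per targeted projection: for a frozen weight of shape `d_out × d_in` it trains `A` (`r × d_in`) and `B` (`d_out × r`), and the update is scaled by `lora_alpha / r` (here 32 / 8 = 4). The sketch below counts the resulting trainable parameters. The model dimensions are assumptions based on our understanding of the Qwen2-0.5B configuration (24 layers, hidden size 896, 2 key/value heads of dimension 64, so `v_proj` maps 896 to 128); check `model.config` before relying on them.

```python
def lora_trainable_params(r, shapes, num_layers):
    """Count LoRA parameters: each (d_out, d_in) projection gains r * (d_in + d_out) weights."""
    per_layer = sum(r * (d_in + d_out) for d_out, d_in in shapes)
    return per_layer * num_layers


# Assumed Qwen2-0.5B shapes (d_out, d_in) for the two targeted projections
hidden_size = 896
kv_dim = 2 * 64  # num_key_value_heads * head_dim
target_shapes = {
    'q_proj': (hidden_size, hidden_size),
    'v_proj': (kv_dim, hidden_size),
}

print(lora_trainable_params(r=8, shapes=target_shapes.values(), num_layers=24))  # 540672
print('LoRA scaling factor:', 32 / 8)  # lora_alpha / r
```

If the assumed dimensions hold, this count should agree with the trainable-parameter figure reported by `model.print_trainable_parameters()`.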

### Load reward functions

For the reward component of the system, we can use either pretrained reward models or reward functions defined directly in code.

For training, the DeepSeek-R1 authors used an accuracy reward that evaluates whether the response is correct, alongside a format reward that ensures the model places its reasoning process between `<think>` and `</think>` tags.

We will implement these reward functions as generic Python functions.

1. **Format Enforcement**: ensures that the generation follows a specific format using `<think> </think> <answer> </answer>` tags for reasoning.

In [None]:
import re

def format_reward(completions, **kwargs):
    """Reward 1.0 if a completion is exactly a <think> block followed by an <answer> block, else 0.0"""
    pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"
    # Each completion is a list of messages; take the content of the generated message
    completion_contents = [completion[0]['content'] for completion in completions]

    # re.DOTALL lets '.' match newlines, so multi-line reasoning can still match
    matches = [re.match(pattern, content, re.DOTALL) for content in completion_contents]
    rewards_list = [1.0 if _match else 0.0 for _match in matches]

    return rewards_list
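
Because `.` in a Python regular expression does not match newline characters unless `re.DOTALL` is set, it is worth checking how the format pattern behaves on reasoning that spans several lines, which is common in practice:

```python
import re

pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"
single_line = "<think>2 + 3 = 5</think><answer>5</answer>"
multi_line = "<think>\n2 + 3 = 5\n</think>\n<answer>5</answer>"

print(re.match(pattern, single_line) is not None)            # True
print(re.match(pattern, multi_line) is not None)             # False: '.' stops at '\n'
print(re.match(pattern, multi_line, re.DOTALL) is not None)  # True
```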

2. **Solution Accuracy**: verifies whether the solution to the problem is correct.

In [None]:
from math_verify import LatexExtractionConfig, parse, verify

def accuracy_reward(completions, **kwargs):
    """Reward 1.0 if the completion's answer is mathematically equivalent to the ground truth"""
    # Extra dataset columns (here `solution`) reach reward functions as keyword arguments
    solutions = kwargs['solution']
    completion_contents = [completion[0]['content'] for completion in completions]

    rewards = []
    for content, solution in zip(completion_contents, solutions):
        # Parse the reference answer from the LaTeX in the dataset solution
        gold_parsed = parse(
            solution,
            extraction_mode='first_match',
            extraction_config=[LatexExtractionConfig()]
        )
        # Parse the model's answer the same way
        answer_parsed = parse(
            content,
            extraction_mode='first_match',
            extraction_config=[LatexExtractionConfig()]
        )

        if len(gold_parsed) != 0:
            try:
                # math_verify expects the gold answer first: verify(gold, target)
                rewards.append(float(verify(gold_parsed, answer_parsed)))
            except Exception:
                rewards.append(0.0)
        else:
            # Reference could not be parsed: reward 1.0 to effectively skip this example
            rewards.append(1.0)

    return rewards
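
Both functions follow the same contract: they receive a batch of completions (here in conversational form, a list of message dicts per completion) plus the dataset's other columns as keyword arguments, and return one float per completion. That is why the `solution` column has to be kept. The toy reward below is a hypothetical stand-in (plain string comparison of the `<answer>` content instead of `math_verify`) that makes the contract easy to exercise without a model:

```python
import re


def exact_answer_reward(completions, **kwargs):
    """Hypothetical reward: 1.0 if the <answer> text equals the reference string exactly."""
    solutions = kwargs['solution']  # one reference per completion, from the dataset column
    rewards = []
    for completion, solution in zip(completions, solutions):
        content = completion[0]['content']  # the generated message of this completion
        match = re.search(r"<answer>(.*?)</answer>", content, re.DOTALL)
        answer = match.group(1).strip() if match else None
        rewards.append(1.0 if answer == solution.strip() else 0.0)
    return rewards


# Two completions for the same prompt, in the conversational format used above
completions = [
    [{'role': 'assistant', 'content': '<think>2 + 3 = 5</think><answer>5</answer>'}],
    [{'role': 'assistant', 'content': '<think>2 + 3 = 6</think><answer>6</answer>'}],
]
print(exact_answer_reward(completions, solution=['5', '5']))  # [1.0, 0.0]
```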

### Configure GRPO training parameters

In [None]:
from trl import GRPOConfig

training_args = GRPOConfig(
    output_dir='Qwen2-0.5B-GRPO-test',
    learning_rate=1e-5,
    remove_unused_columns=False, # to access the solution column in accuracy_reward
    gradient_accumulation_steps=16,
    num_train_epochs=1,
    bf16=True,
    # Parameters controlling prompt/completion lengths and the group size
    max_completion_length=64, # default 256
    num_generations=4, # default 8
    max_prompt_length=128, # default 512
    # Parameters related to reporting and saving
    report_to=['tensorboard'],
    logging_steps=10,
    push_to_hub=False,
    save_strategy='steps',
    save_steps=10
)
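
`max_completion_length=64` caps every sampled completion at 64 tokens, so reasoning that runs longer is cut off before `</answer>` and earns 0 from `format_reward`. The sketch below illustrates the effect with a crude whitespace split standing in for the real tokenizer (an assumption for illustration only; real token counts differ):

```python
import re

FORMAT = re.compile(r"^<think>.*?</think>\s*<answer>.*?</answer>$", re.DOTALL)


def truncate_words(text, max_tokens):
    """Crude stand-in for token-level truncation: keep the first max_tokens whitespace pieces."""
    return " ".join(text.split()[:max_tokens])


reasoning = " ".join(["step"] * 80)  # 80 "tokens" of reasoning
completion = f"<think> {reasoning} </think> <answer> 42 </answer>"

for limit in (64, 256):
    clipped = truncate_words(completion, limit)
    print(limit, bool(FORMAT.match(clipped)))  # 64 False, then 256 True
```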

### Train the model

In [None]:
from trl import GRPOTrainer

trainer = GRPOTrainer(
    model=model,
    reward_funcs=[
        format_reward,
        accuracy_reward
    ],
    args=training_args,
    train_dataset=train_dataset
)
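
`GRPOTrainer` calls every function in `reward_funcs` on the same batch. As far as we understand, TRL adds their outputs into one reward per completion (newer releases also accept per-function weights), and that total is what gets normalized within each group. The hypothetical helper below sketches that combination with two toy reward functions that share the signature of `format_reward` and `accuracy_reward`:

```python
def combine_rewards(reward_funcs, completions, weights=None, **kwargs):
    """Hypothetical helper: weighted sum of several reward functions, one total per completion."""
    if weights is None:
        weights = [1.0] * len(reward_funcs)  # equal weights, i.e. a plain sum
    # Every reward function sees the same completions and the same extra dataset columns
    per_func = [func(completions, **kwargs) for func in reward_funcs]
    return [
        sum(weight * scores[i] for weight, scores in zip(weights, per_func))
        for i in range(len(completions))
    ]


def has_tags(completions, **kwargs):
    return [1.0 if '<answer>' in c[0]['content'] else 0.0 for c in completions]


def says_five(completions, **kwargs):
    return [1.0 if '5' in c[0]['content'] else 0.0 for c in completions]


completions = [
    [{'role': 'assistant', 'content': '<think>2 + 3</think><answer>5</answer>'}],
    [{'role': 'assistant', 'content': '<think>2 + 3</think><answer>6</answer>'}],
    [{'role': 'assistant', 'content': 'It is 5.'}],
]
print(combine_rewards([has_tags, says_five], completions))  # [2.0, 1.0, 1.0]
```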

In [None]:
trainer.train()

In [None]:
trainer.save_model(training_args.output_dir)
trainer.push_to_hub(dataset_name=dataset_id)

## Check the model performance

In [9]:
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = 'sergiopaniego/Qwen2-0.5B-GRPO'

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype='auto',
    device_map='auto'
)

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.

In [10]:
test_dataset['prompt'][0]

[{'content': '\nA conversation between User and Assistant. The user asks a question, and the Assistant solves it.\nThe assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nThe reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e.,\n<think> reasoning process here </think><answer> answer here </answer>\n',
  'role': 'system'},
 {'content': "In 1988, a person's age was equal to the sum of the digits of their birth year. How old was this person?",
  'role': 'user'}]

We will create a function to interact with the model. In addition to generating the answer, it measures the inference duration and counts the number of generated tokens, which gives a rough sense of how much the model reasons during generation.

In [13]:
import time

import torch

def generate_with_reasoning(prompt):
    # Join the system and user message contents into one plain-text prompt (no chat template)
    prompt = " ".join(entry['content'] for entry in prompt)

    # Tokenize and move to the same device as the model
    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)

    # Generate text without gradients; max_length counts prompt + generated tokens
    start_time = time.time()
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_length=500)
    end_time = time.time()

    # Decode the full sequence: the prompt followed by the generated continuation
    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    # Get inference time
    inference_duration = end_time - start_time

    # Get number of generated tokens
    num_input_tokens = inputs['input_ids'].shape[1]
    num_generated_tokens = output_ids.shape[1] - num_input_tokens

    return generated_text, inference_duration, num_generated_tokens
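
The function above joins the message contents into plain text, so the model never sees the chat template that, as far as we understand, `GRPOTrainer` applies to conversational prompts during training; that mismatch may explain part of the rambling output below. A hypothetical alternative (not part of the original example) renders the messages with `tokenizer.apply_chat_template`, uses `max_new_tokens` so the budget excludes the prompt, and decodes only the newly generated tokens:

```python
import time


def generate_with_chat_template(model, tokenizer, messages, max_new_tokens=500):
    """Hypothetical helper: generate a reply after formatting `messages` with the chat template."""
    # Render the system/user turns and open an assistant turn for the model to complete
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # The rendered template already contains the special tokens it needs
    inputs = tokenizer(text, return_tensors='pt', add_special_tokens=False).to(model.device)

    start_time = time.time()
    # transformers' generate() already runs without gradient tracking
    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    inference_duration = time.time() - start_time

    # Keep only the newly generated tokens, dropping the echoed prompt
    num_input_tokens = inputs['input_ids'].shape[1]
    new_tokens = output_ids[0, num_input_tokens:]
    completion = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return completion, inference_duration, len(new_tokens)
```

With the objects defined in this notebook, it would be called as `generate_with_chat_template(model, tokenizer, prompt)`.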

In [11]:
prompt = test_dataset['prompt'][0]
prompt

[{'content': '\nA conversation between User and Assistant. The user asks a question, and the Assistant solves it.\nThe assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nThe reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e.,\n<think> reasoning process here </think><answer> answer here </answer>\n',
  'role': 'system'},
 {'content': "In 1988, a person's age was equal to the sum of the digits of their birth year. How old was this person?",
  'role': 'user'}]

In [16]:
generated_text, inference_duration, num_generated_tokens = generate_with_reasoning(prompt)
print(generated_text)


A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e.,
<think> reasoning process here </think><answer> answer here </answer>
 In 1988, a person's age was equal to the sum of the digits of their birth year. How old was this person?<think> 1988 + digit1 + digit2 + digit3 = digit4 </think><answer> 1988 + 0 + 8 + 8 = 1988 </answer>

Assistant: 1988 + 0 + 8 + 8 = 1988
1988 + digit1 + digit2 + digit3 = digit4

1988 + 0 + 8 + 8 = 1988
1988 + digit1 + digit2 + digit3 = digit4
1988 + digit1 + digit2 + digit3 = digit4
1988 + digit1 + digit2 + digit3 = digit4
1988 + digit1 + digit2 + digit3 = digit4
1988 + digit1 + digit2 + digit3 = digit4
1988 + digit1 + digit2 + digit3 = digit4
1988 + digit1 + digit2 + digit3 = digit4
1988 + 

In [17]:
print(f"Inference time: {inference_duration:.2f} seconds")
print(f"Generated tokens: {num_generated_tokens}")

Inference time: 97.94 seconds
Generated tokens: 386
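
The completion above does not solve the problem and keeps going until it hits the length limit. For reference, a brute-force search over birth years gives the expected answer, and the reported timing works out to roughly four generated tokens per second:

```python
def digit_sum(n):
    return sum(int(d) for d in str(n))


# Birth years whose age in 1988 equals the digit sum of the year
solutions = [(year, 1988 - year) for year in range(1, 1989) if 1988 - year == digit_sum(year)]
print(solutions)  # [(1966, 22)]

# Throughput of the run above: generated tokens / inference time
print(f"{386 / 97.94:.2f} tokens/s")  # 3.94 tokens/s
```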
