# GRPO

## Setup

In [2]:
pip install trl

Collecting trl
  Downloading trl-0.27.1-py3-none-any.whl.metadata (11 kB)
Downloading trl-0.27.1-py3-none-any.whl (532 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m532.9/532.9 kB[0m [31m14.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: trl
Successfully installed trl-0.27.1


In [3]:
import torch
from transformers import TrainingArguments, AutoTokenizer, AutoModelForCausalLM
from trl import GRPOTrainer, GRPOConfig
from datasets import load_dataset, Dataset
import re
import pandas as pd
from tqdm import tqdm

## Global Params

In [None]:
SYSTEM_PROMPT = (
    "You are a helpful assistant that solves problems step-by-step. "
    "Always include the final numeric answer inside \\boxed{}."
)

In [None]:
MAX_TOKEN_SIZE = 100
MODEL_NAME = 'Qwen/Qwen2-0.5B-Instruct'
USE_ACCELERATOR = True

## Helpers

In [25]:

def generate_responses(model, tokenizer, user_message=None, system_message=None, max_new_tokens=MAX_TOKEN_SIZE, full_message=None):
    # Format chat using tokenizer's chat template
    if full_message:
        messages = full_message
    else:
        messages = []
        if system_message:
            messages.append({"role": "system", "content": system_message})
        messages.append({"role": "user", "content": user_message})

    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    input_len = inputs["input_ids"].shape[1]
    generated_ids = outputs[0][input_len:]
    response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    return response

In [26]:
def test_model_with_questions(model, tokenizer, questions,
                              system_message=None, title="Model Output"):
    print(f"\n=== {title} ===")
    for i, question in enumerate(questions, 1):
        response = generate_responses(model, tokenizer, question,
                                      system_message)
        print(f"\nModel Input {i}:\n{question}\nModel Output {i}:\n{response}\n")


In [27]:
def load_model_and_tokenizer(model_name, use_accelerator=False):

    # Load base model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    if use_accelerator:
        if torch.cuda.is_available():
            device = "cuda"
            print(f"Using CUDA device: {device}")
        elif hasattr(xm, 'xla_device') and xm.xla_device().type == 'xla': # More robust XLA check
            device = xm.xla_device()
            print(f"Using XLA device: {device}")
        else:
            device = "cpu"
            print("No accelerator found (CUDA or XLA), falling back to CPU.")
        model.to(device)
    else:
        device = "cpu"
        model.to(device)
        print("Accelerator disabled, falling back to CPU.")

    if not tokenizer.chat_template:
        tokenizer.chat_template = """{% for message in messages %}
                {% if message['role'] == 'system' %}System: {{ message['content'] }}\n
                {% elif message['role'] == 'user' %}User: {{ message['content'] }}\n
                {% elif message['role'] == 'assistant' %}Assistant: {{ message['content'] }} <|endoftext|>

                {% endif %}
                {% endfor %}"""

    # Tokenizer config
    if not tokenizer.pad_token:
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer

In [28]:
def display_dataset(dataset):
    # Visualize the dataset
    rows = []
    for i in range(3):
        example = dataset[i]
        user_msg = next(m['content'] for m in example['messages']
                        if m['role'] == 'user')
        assistant_msg = next(m['content'] for m in example['messages']
                             if m['role'] == 'assistant')
        rows.append({
            'User Prompt': user_msg,
            'Assistant Response': assistant_msg
        })

    # Display as table
    df = pd.DataFrame(rows)
    pd.set_option('display.max_colwidth', None)  # Avoid truncating long strings
    display(df)

# Prep

In [29]:
def reward_func(completions, ground_truth, **kwargs):
    # Regular expression to capture content inside \boxed{}
    matches = [re.search(r"\\boxed\{(.*?)\}", completion[0]['content']) for completion in completions]
    contents = [match.group(1) if match else "" for match in matches]
    # Reward 1 if the content is the same as the ground truth, 0 otherwise
    return [1.0 if c == gt else 0.0 for c, gt in zip(contents, ground_truth)]

In [30]:
sample_pred = [[{"role": "assistant",
                 "content": r"...Calculating the answer. \boxed{72}"}]]
ground_truth = ["72"]
reward = reward_func(sample_pred, ground_truth)
print(f"Positive Sample Reward: {reward}")

Positive Sample Reward: [1.0]


In [31]:
sample_pred = [[{"role": "assistant",
                 "content": r"...Calculating the answer \boxed{71}"}]]
ground_truth = ["72"]
reward = reward_func(sample_pred, ground_truth)
print(f"Negative Sample Reward: {reward}")

Negative Sample Reward: [0.0]


# Load the eval dataset

In [32]:
data_num = 5
eval_dataset = load_dataset("openai/gsm8k", "main")["test"].select(range(data_num))
sample_df = eval_dataset.to_pandas()
display(sample_df)

Unnamed: 0,question,answer
0,Janet’s ducks lay 16 eggs per day. She eats th...,Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eg...
1,A robe takes 2 bolts of blue fiber and half th...,It takes 2/2=<<2/2=1>>1 bolt of white fiber\nS...
2,Josh decides to try flipping a house. He buys...,The cost of the house and repairs came out to ...
3,James decides to run 3 sprints 3 times a week....,He sprints 3*3=<<3*3=9>>9 times\nSo he runs 9*...
4,"Every day, Wendi feeds each of her chickens th...","If each chicken eats 3 cups of feed per day, t..."


In [33]:
def post_processing(example):
    match = re.search(r"####\s*(-?\d+)", example["answer"])
    example["ground_truth"] = match.group(1) if match else None
    example["prompt"] = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": example["question"]}
    ]
    return example
eval_dataset = eval_dataset.map(post_processing).remove_columns(["question", "answer"])


In [34]:
sample_df = eval_dataset.select(range(5)).to_pandas()
display(sample_df)

Unnamed: 0,ground_truth,prompt
0,18,[{'content': 'You are a helpful assistant that...
1,3,[{'content': 'You are a helpful assistant that...
2,70000,[{'content': 'You are a helpful assistant that...
3,540,[{'content': 'You are a helpful assistant that...
4,20,[{'content': 'You are a helpful assistant that...


# load base model and evaluate base model

In [35]:
model, tokenizer = load_model_and_tokenizer(MODEL_NAME, USE_ACCELERATOR)

Using CUDA device: cuda


In [37]:
# Store predictions and ground truths
all_preds = []
all_labels = []

for example in tqdm(eval_dataset):
    input_prompt = example["prompt"]
    ground_truth = example["ground_truth"]
    # Run the model to generate an answer
    with torch.no_grad():
        response = generate_responses(model, tokenizer, full_message=input_prompt)
    all_preds.append([{"role": "assistant", "content": response}])
    all_labels.append(ground_truth)
    print(response)
    print("Ground truth: ", ground_truth)

# 3. Evaluate using reward_func
rewards = reward_func(all_preds, all_labels)

# 4. Report accuracy
accuracy = sum(rewards) / len(rewards)
print(f"Evaluation Accuracy: {accuracy:.2%}")
del model, tokenizer

  0%|          | 0/5 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
 20%|██        | 1/5 [00:08<00:33,  8.26s/it]

Janet's ducks lay 16 eggs per day, so they lay 16 x 3 = 48 eggs for breakfast.
She also bakes muffins with 4 eggs each day, so she bakes 4 x 4 = 16 eggs.
So far, she has eaten 48 + 16 = 64 eggs.
The remainder of eggs is 16 - 64 = -48 eggs.
Since she can't have negative eggs
Ground truth:  18


 40%|████      | 2/5 [00:11<00:16,  5.58s/it]

If a robe requires 2 bolts of blue fiber, then the number of bolts of white fiber is half the number of bolts of blue fiber.
So, the number of bolts of white fiber is $2/2 = 1$.
To find the total number of bolts required, we add the number of bolts of blue fiber and the number of bolts of white fiber: $2 + 1 = 3$.
The answer is $\boxed{3}$.
Ground truth:  3


 60%|██████    | 3/5 [00:14<00:08,  4.37s/it]

The value of the house after putting in repairs is $80,000 + ($50,000 * 150%) = $80,000 + ($50,000 * 1.5) = $80,000 + $75,000 = $155,000.
To find the profit, we subtract the cost from the value: $155,000
Ground truth:  70000


 80%|████████  | 4/5 [00:17<00:03,  3.83s/it]

James runs 3 sprints 3 times a week, so he runs 3 x 2 = 6 sprints per day.
Each sprint is 60 meters long, so he runs 6 x 60 = 360 meters per day.
There are 7 days in a week, so he runs a total of 360 x 7 = 2520 meters per week.
The answer is: $\boxed{2520}$
Ground truth:  540


100%|██████████| 5/5 [00:21<00:00,  4.22s/it]

Wendi gives her chickens 15 + 25 = 40 cups of feed in total.
Since each chicken gets 3 cups of feed per meal, then the number of chickens is 40 / 3 = 13.
If there are 20 chickens, then she needs to give out 13 x 3 = 39 cups of feed for the last meal.
The answer is: $\boxed{39}$
Ground truth:  20
Evaluation Accuracy: 20.00%





# load training dataset

In [39]:
dataset = load_dataset("openai/gsm8k", "main")
train_dataset = dataset["train"]

# Apply to dataset
train_dataset = train_dataset.map(post_processing)
train_dataset = train_dataset.remove_columns(["question", "answer"])

train_dataset = train_dataset.select(range(10))
print(train_dataset[0])

{'ground_truth': '72', 'prompt': [{'content': 'You are a helpful assistant that solves problems step-by-step. Always include the final numeric answer inside \\boxed{}.', 'role': 'system'}, {'content': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?', 'role': 'user'}]}


# Train

In [41]:
config = GRPOConfig(
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    num_generations=4, # Can set as high as 64 or 128
    num_train_epochs=1,
    learning_rate=5e-6,
    logging_steps=2,
    no_cuda= not USE_ACCELERATOR     # keeps the whole run on CPU, incl. MPS
)

In [42]:
## If this block hangs or the kernel restarts during training, please skip loading the previous 0.5B model for evaluation

model, tokenizer = load_model_and_tokenizer(MODEL_NAME, USE_ACCELERATOR)

grpo_trainer = GRPOTrainer(
    model=model,
    args=config,
    reward_funcs=reward_func,
    train_dataset=train_dataset
)

grpo_trainer.train()

Using CUDA device: cuda


The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None, 'pad_token_id': 151643}.
  | |_| | '_ \/ _` / _` |  _/ -_)
[34m[1mwandb[0m: (1) Create a W&B account
[34m[1mwandb[0m: (2) Use an existing W&B account
[34m[1mwandb[0m: (3) Don't visualize my results
[34m[1mwandb[0m: Enter your choice:

 3


[34m[1mwandb[0m: You chose "Don't visualize my results"
[34m[1mwandb[0m: Using W&B in offline mode.
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss
2,-0.0331
4,0.0674




TrainOutput(global_step=5, training_loss=0.017966465651988985, metrics={'train_runtime': 355.2527, 'train_samples_per_second': 0.028, 'train_steps_per_second': 0.014, 'total_flos': 0.0, 'train_loss': 0.017966465651988985})

# Test GRPO trained model

In [43]:
model = grpo_trainer.model

# Store predictions and ground truths
all_preds = []
all_labels = []

for example in tqdm(eval_dataset):
    input_prompt = example["prompt"]
    ground_truth = example["ground_truth"]
    # Run the model to generate an answer
    with torch.no_grad():
        response = generate_responses(model, tokenizer,
                                      full_message = input_prompt)
    all_preds.append([{"role": "assistant", "content": response}])
    all_labels.append(ground_truth)
    print(response)
    print("Ground truth: ", ground_truth)

# 3. Evaluate using reward_func
rewards = reward_func(all_preds, all_labels)

# 4. Report accuracy
accuracy = sum(rewards) / len(rewards)
print(f"Evaluation Accuracy: {accuracy:.2%}")

  0%|          | 0/5 [00:00<?, ?it/s]`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
Caching is incompatible with gradient checkpointing in Qwen2DecoderLayer. Setting `past_key_values=None`.
 20%|██        | 1/5 [00:05<00:22,  5.62s/it]

Janet and the (::synchronized String
二相和某公司所由一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批
Ground truth:  18


 40%|████      | 2/5 [00:10<00:15,  5.16s/it]

The following two years ago, 1、某公司所由一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批
Ground truth:  3


 60%|██████    | 3/5 [00:17<00:12,  6.04s/it]

The following two years ago，某公司所由一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批
Ground truth:  70000


 80%|████████  | 4/5 [00:22<00:05,  5.79s/it]

James and the (::synchronized String
二相和某公司所由一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批
Ground truth:  540


100%|██████████| 5/5 [00:27<00:00,  5.57s/it]

W. �，某公司所由一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批一批
Ground truth:  20
Evaluation Accuracy: 0.00%



