# GRPO on GSM8K with Qwen3-0.6B + vLLM

Train a small language model on grade-school math problems using **Group Relative Policy Optimization (GRPO)**.

- Base model: `Qwen/Qwen3-0.6B`
- Dataset: `openai/gsm8k`
- Generation backend: **vLLM**
- Framework: HuggingFace **TRL**

In [1]:
# Cell 1: Install dependencies
# Pin vLLM to a version compatible with TRL 0.28.0 (supports up to 0.12.0)
!pip install "vllm==0.12.0" --quiet
!pip install trl transformers datasets accelerate hf_transfer --quiet

# Enable hf_transfer for faster model downloads
import os
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

In [None]:
# Cell 2: Imports
import re
import torch
from datasets import load_dataset
from trl import GRPOTrainer, GRPOConfig
from src.ibrl import IBRLTrainer, IBRLConfig

In [3]:
# Cell 3: Load and prepare GSM8K dataset
# Load the "main" configuration of openai/gsm8k; its "train" and "test" splits
# are formatted further down in this cell.
dataset = load_dataset("openai/gsm8k", "main")

def extract_gold_answer(answer_text: str) -> str:
    """Return the text following the ``####`` marker, without commas.

    Args:
        answer_text: A reference solution that ends with ``#### <answer>``.

    Returns:
        The remainder of the marker's line, stripped of surrounding
        whitespace and with thousands separators removed, or ``""`` when
        no marker is present.
    """
    found = re.search(r"####\s*(.+)", answer_text)
    if found is None:
        return ""
    raw_answer = found.group(1)
    return raw_answer.strip().replace(",", "")

def format_example(example):
    """Add a chat-style ``prompt`` and a parsed ``gold_answer`` to a row.

    Args:
        example: A dataset row with ``"question"`` and ``"answer"`` fields.

    Returns:
        The same row object, updated in place with:
        ``"prompt"`` - a single-turn user message holding the question, and
        ``"gold_answer"`` - the reference answer parsed from ``"answer"``.
    """
    user_turn = {"role": "user", "content": example["question"]}
    gold = extract_gold_answer(example["answer"])

    example["prompt"] = [user_turn]
    example["gold_answer"] = gold
    return example

# Apply the same formatting to both splits.
train_dataset, test_dataset = (
    dataset[split].map(format_example) for split in ("train", "test")
)

# Show split sizes and one formatted example as a quick visual check.
first_example = train_dataset[0]
print(f"Train: {len(train_dataset)}, Test: {len(test_dataset)}")
print(f"Example prompt: {first_example['prompt']}")
print(f"Gold answer: {first_example['gold_answer']}")

Train: 7473, Test: 1319
Example prompt: [{'content': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?', 'role': 'user'}]
Gold answer: 72


In [4]:
# Cell 4: Define reward functions

def extract_answer_from_completion(text: str) -> str:
    """Parse the final numeric answer from a model completion.

    The answer is looked up in two stages:

    1. The number that follows the first ``####`` marker, if there is one.
    2. Otherwise, the last number that appears anywhere in the text.

    Args:
        text: The raw completion text produced by the model.

    Returns:
        The number as a string with thousands separators removed
        (e.g. ``"1234.5"``), or ``""`` when the text contains no number.
    """
    # A number must contain at least one digit. The previous character-class
    # patterns also matched bare punctuation, so a prose comma after the real
    # answer ("... 72 clips, in total.") became the "last number" and the
    # function returned "". Requiring a digit also keeps a sentence-ending
    # period out of the result ("#### 72." -> "72").
    number = r"-?(?:\d[\d,]*(?:\.\d+)?|\.\d+)"

    # Look for the #### pattern first
    marked = re.search(r"####\s*(" + number + r")", text)
    if marked:
        return marked.group(1).replace(",", "")

    # Fallback: last number in the text
    numbers = re.findall(number, text)
    if numbers:
        return numbers[-1].replace(",", "")
    return ""

def correctness_reward(completions: list[list[dict]], gold_answer: list[str], **kwargs) -> list[float]:
    """Score each completion 1.0 when its final answer equals the gold answer.

    Args:
        completions: One entry per sample; the text is read from
            ``completion[0]["content"]``.
        gold_answer: Reference answers, paired with ``completions`` by position.
        **kwargs: Extra arguments passed by the caller; ignored here.

    Returns:
        One reward per (completion, gold) pair: ``1.0`` if both parse as
        floats and are equal, ``0.0`` otherwise (including parse failures).
    """

    def _is_match(completion: list[dict], gold: str) -> bool:
        predicted = extract_answer_from_completion(completion[0]["content"])
        try:
            # Compare numerically so "72" and "72.0" count as the same answer.
            return float(predicted) == float(gold)
        except (ValueError, TypeError):
            return False

    return [
        1.0 if _is_match(completion, gold) else 0.0
        for completion, gold in zip(completions, gold_answer)
    ]

def format_reward(completions: list[list[dict]], **kwargs) -> list[float]:
    """Award +0.5 if the response contains the ``#### <number>`` pattern.

    Args:
        completions: One entry per sample; the text is read from
            ``completion[0]["content"]``.
        **kwargs: Extra arguments passed by the caller; ignored here.

    Returns:
        One reward per completion: ``0.5`` when ``####`` is followed by a
        number, ``0.0`` otherwise.
    """
    # Require at least one digit after the marker. The previous character
    # class ``[\d,\.\-]+`` also accepted "#### ...", "#### -" or "#### ,",
    # paying the format reward for completions that contain no number at all.
    answer_marker = re.compile(r"####\s*-?(?:\d[\d,]*(?:\.\d+)?|\.\d+)")
    return [
        0.5 if answer_marker.search(completion[0]["content"]) else 0.0
        for completion in completions
    ]

# Smoke-test both reward functions on one hand-written completion
test_comp = [[{"content": "The answer is 2+3=5. #### 5"}]]
correctness_scores = correctness_reward(test_comp, gold_answer=["5"])
format_scores = format_reward(test_comp)
print("Correctness:", correctness_scores)
print("Format:", format_scores)

Correctness: [1.0]
Format: [0.5]


In [None]:
# Cell 5: Configure GRPO / IBRL with vLLM backend
# Same settings as before, grouped by concern.
config = GRPOConfig(
    # --- Output, checkpointing and logging ---
    output_dir="grpo_gsm8k_output",
    save_strategy="no",
    report_to="none",
    logging_steps=10,
    # --- Optimisation schedule ---
    learning_rate=5e-6,
    num_train_epochs=1,
    max_steps=20,                    # smoke test; remove for full run
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    bf16=True,
    # --- Rollouts ---
    num_generations=8,               # group size G
    max_completion_length=512,
    # --- vLLM generation backend ---
    use_vllm=True,
    vllm_mode="colocate",           # run vLLM in-process (no separate server needed)
    vllm_gpu_memory_utilization=0.3,
)

In [None]:
# Cell 6: Initialize GRPOTrainer / IBRLTrainer
# Both reward functions are passed to the trainer.
reward_functions = [correctness_reward, format_reward]

trainer = GRPOTrainer(
    args=config,
    model="Qwen/Qwen3-0.6B",
    train_dataset=train_dataset,
    reward_funcs=reward_functions,
)

[2026-02-24 15:21:00] INFO modeling.py:1576: Based on the current allocation process, no modules could be assigned to the following devices due to insufficient memory:
  - 0: 685253632 bytes required
These minimum requirements are specific to this allocation attempt and may vary. Consider increasing the available memory for these devices to at least the specified minimum, or adjusting the model config.


[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0


Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]


Capturing CUDA graphs (mixed prefill-decode, PIECEWISE): 100%|██████████| 5/5 [00:00<00:00, 28.76it/s]
Capturing CUDA graphs (decode, FULL): 100%|██████████| 4/4 [00:00<00:00, 32.13it/s]


In [None]:
# Cell 7: Train
# Run training with the trainer built in Cell 6 (model, reward functions,
# config and training split are all bound there).
trainer.train()

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None, 'pad_token_id': 151643}.
Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss


In [None]:
# Cell 8: Evaluate — sample generations on test split
import random
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
model = trainer.model
model.eval()

# Inspect five randomly chosen test problems.
samples = random.sample(range(len(test_dataset)), 5)
for idx in samples:
    question = test_dataset[idx]["question"]
    gold = test_dataset[idx]["gold_answer"]

    # Build the same single-turn chat prompt shape used for training.
    messages = [{"role": "user", "content": question}]
    input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=512, do_sample=False)
    # Drop the prompt tokens so only the newly generated text is decoded.
    response = tokenizer.decode(output_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)

    predicted = extract_answer_from_completion(response)
    # Compare numerically, exactly as correctness_reward does. A plain string
    # comparison would report "72.0" vs "72" as WRONG even though training
    # rewards it as correct.
    try:
        is_correct = float(predicted) == float(gold)
    except (ValueError, TypeError):
        is_correct = False
    match = "CORRECT" if is_correct else "WRONG"

    print(f"\n{'='*60}")
    print(f"Q: {question[:120]}...")
    print(f"A: {response[:300]}")
    print(f"Predicted: {predicted} | Gold: {gold} | {match}")

In [None]:
# Cell 9: Plot training reward curve
import matplotlib.pyplot as plt

log_history = trainer.state.log_history
# Keep only the log entries that carry a "reward" value.
reward_entries = [entry for entry in log_history if "reward" in entry]
steps = [entry["step"] for entry in reward_entries]
rewards = [entry["reward"] for entry in reward_entries]

if not steps:
    # Nothing to draw: no log entry contained a reward.
    print("No reward data in logs yet (try increasing max_steps or decreasing logging_steps).")
else:
    plt.figure(figsize=(8, 4))
    plt.plot(steps, rewards, marker="o")
    plt.xlabel("Step")
    plt.ylabel("Mean Reward")
    plt.title("GRPO Training — Mean Reward over Steps")
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()