# Reinforcement Learning with GRPO

This Colab is a reinforcement-learning demo for math word problems using GRPO (Group Relative Policy Optimization) with Unsloth. It loads a small instruction model (Qwen2.5-1.5B-Instruct) in 4-bit for low VRAM, attaches LoRA adapters, and trains the policy with two simple rewards:

* Format reward: outputs must follow ```<think>…</think><answer>…</answer>```.
* Accuracy reward: the number inside ```<answer>…</answer>``` must match the GSM8K gold answer.

We’ll:

* Install & check hardware (BF16/FP16).
* Load the model (4-bit) + LoRA for efficient training.
* Prepare a small GSM8K sample into chat messages with ground-truth numbers.
* Define rewards (format + accuracy) and train with GRPO for a short run.
* Generate on a new problem, then save and reload the LoRA checkpoint.

# Install Unsloth

In [None]:
!pip install unsloth

Collecting unsloth
  Downloading unsloth-2025.11.2-py3-none-any.whl.metadata (61 kB)
Collecting unsloth_zoo>=2025.11.3 (from unsloth)
  Downloading unsloth_zoo-2025.11.3-py3-none-any.whl.metadata (32 kB)
Collecting tyro (from unsloth)
  Downloading tyro-0.9.35-py3-none-any.whl.metadata (12 kB)
Collecting xformers>=0.0.27.post2 (from unsloth)
  Downloading xformers-0.0.32.post2-cp39-abi3-manylinux_2_28_x86_64.whl.metadata (1.1 kB)
Collecting bitsandbytes!=0.46.0,!=0.48.0,>=0.45.5 (from unsloth)
  Downloading bitsandbytes-0.48.2-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Collecting datasets!=4.0.*,!=4.1.0,<4.4.0,>=3.4.1 (from unsloth)
  Downloading datasets-4.3.0-py3-none-any.whl.metadata (18 kB)
Collecting trl!=0.19.0,<=0.23.0,>=0.18.2 (from 

In [None]:
import platform, sys, subprocess, os, textwrap, json

# Imports, hardware check, and runtime flags
* Imports libraries, checks for CUDA and BF16 support, prints a short summary

In [None]:
import os, re, math, random, torch
from datasets import load_dataset, Dataset
from unsloth import FastLanguageModel
try:
    from unsloth import is_bfloat16_supported
    BF16 = bool(is_bfloat16_supported())
except Exception:
    BF16 = torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8

print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device:", torch.cuda.get_device_name(0))
print("Use BF16:", BF16)

CUDA available: True
Device: Tesla T4
Use BF16: False


Set a small sample size and max prompt/completion lengths.

In [None]:
sample_count = 150

MAX_PROMPT_LEN      = 256
MAX_COMPLETION_LEN  = 128   # completions longer than this are truncated

# Disable Weights & Biases and HF logging overhead
import os
os.environ["WANDB_DISABLED"] = "true"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"

# Load base model and attach LoRA

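* Loads Qwen2.5-1.5B-Instruct in 4-bit and requests Flash-Attention-2. On the Tesla T4 used in this run, the Unsloth banner reports FA2 = False and lists Xformers instead.
* Adds LoRA adapters to the attention and MLP projection modules, enables gradient checkpointing, and configures the tokenizer for left padding with a pad token.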

In [None]:
BASE_MODEL = "Qwen/Qwen2.5-1.5B-Instruct"

model, tokenizer = FastLanguageModel.from_pretrained(
    BASE_MODEL,
    max_seq_length      = MAX_PROMPT_LEN + MAX_COMPLETION_LEN,
    load_in_4bit        = True,
    dtype               = torch.bfloat16 if BF16 else torch.float16,
    attn_implementation = "flash_attention_2",
)

model = FastLanguageModel.get_peft_model(
    model,
    r=16, lora_alpha=16, lora_dropout=0.0, bias="none",
    target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"],
    use_gradient_checkpointing="unsloth",
)
tokenizer.padding_side = "left"
if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token


==((====))==  Unsloth 2025.11.2: Fast Qwen2 patching. Transformers: 4.57.1.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.8.0+cu126. CUDA: 7.5. CUDA Toolkit: 12.6. Triton: 3.4.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.32.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


model.safetensors:   0%|          | 0.00/1.53G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/270 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

added_tokens.json:   0%|          | 0.00/605 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/614 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

Unsloth 2025.11.2 patched 28 layers with 28 QKV layers, 28 O layers and 28 MLP layers.
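
As a sanity check on the LoRA configuration above, the number of trainable parameters can be estimated by hand: a rank-r adapter on a linear layer with `d_in` inputs and `d_out` outputs adds r × (d_in + d_out) weights. The sketch below assumes the layer shapes of Qwen2.5-1.5B (hidden size 1536, MLP size 8960, key/value projection width 256, 28 decoder layers). These shapes are not stated in this notebook and should be checked against `model.config`; the helper name `lora_param_count` is hypothetical.

```python
# Estimate LoRA trainable parameters for r=16 on the seven target modules.
# ASSUMPTION: the shapes below are taken to be those of Qwen2.5-1.5B; they are
# not stated in this notebook, so verify them against model.config.
HIDDEN, KV_WIDTH, MLP, NUM_LAYERS, RANK = 1536, 256, 8960, 28, 16

# (in_features, out_features) of each LoRA target module in one decoder layer
TARGET_SHAPES = {
    "q_proj":    (HIDDEN, HIDDEN),
    "k_proj":    (HIDDEN, KV_WIDTH),
    "v_proj":    (HIDDEN, KV_WIDTH),
    "o_proj":    (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, MLP),
    "up_proj":   (HIDDEN, MLP),
    "down_proj": (MLP, HIDDEN),
}

def lora_param_count(shapes, rank, num_layers):
    """Each adapter adds an A matrix (rank x in) and a B matrix (out x rank)."""
    per_layer = sum(rank * (d_in + d_out) for d_in, d_out in shapes.values())
    return per_layer * num_layers

total = lora_param_count(TARGET_SHAPES, RANK, NUM_LAYERS)
print(f"{total:,}")  # 18,464,768
```

Under these assumptions the estimate equals the 18,464,768 trainable parameters that the Unsloth training banner reports for this run.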


# Define system prompt and single-turn chat helper

* Sets a system instruction that asks the model to put reasoning in ```<think>``` and the final number in ```<answer>```.
* Defines chat_once(...) to build a chat prompt, generate, and return only the assistant text.


In [None]:
SYSTEM_PROMPT = (
    "You are a careful reasoning assistant. Think step by step and show your work "
    "between <think></think>, then give the final numeric answer inside <answer></answer>."
)

In [None]:
def chat_once(messages, max_new_tokens=384, temperature=0.7, top_p=1.0):
    device = model.device
    prompt = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=True, return_tensors="pt"
    ).to(device)
    with torch.no_grad():
        out = model.generate(
            prompt,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
    decoded = tokenizer.decode(out[0][prompt.shape[-1]:], skip_special_tokens=True)
    return decoded.strip()

# Load GSM8K math problems and build a training dataset

* Loads GSM8K (train), shuffles and selects up to sample_count items; extracts the gold numeric answer; converts each problem into {prompt: [system+user messages], ground_truth: <number>}.

In [None]:
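# SEED is not defined in any earlier cell. 42 is an assumed value, so the
# sampled subset may differ from the example output shown for this cell.
SEED = 42
raw = load_dataset("gsm8k", "main", split="train")
raw = raw.shuffle(seed=SEED).select(range(min(sample_count, len(raw))))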

def extract_gsm8k_answer(a: str):
    if "####" in a:
        z = a.split("####")[-1].strip()
        # keep just the number-ish part
        m = re.search(r"-?\d+(\.\d+)?", z.replace(",", ""))
        return m.group(0) if m else None
    return None

def to_messages(example):
    # Conversational format (system + user) is fine for TRL GRPO
    question = example["question"].strip()
    gt = extract_gsm8k_answer(example["answer"])
    if gt is None:
        return None
    msgs = [
        {"role":"system", "content": SYSTEM_PROMPT},
        {"role":"user",   "content": f"Solve step by step, then return the final numeric result in <answer></answer>.\n\nProblem: {question}"},
    ]
    return {"prompt": msgs, "ground_truth": gt}

mapped = []
for ex in raw:
    row = to_messages(ex)
    if row is not None: mapped.append(row)

dataset = Dataset.from_list(mapped)
print(dataset[0])
print("Dataset size:", len(dataset))

{'prompt': [{'content': 'You are a careful reasoning assistant. Think step by step and show your work between <think></think>, then give the final numeric answer inside <answer></answer>.', 'role': 'system'}, {'content': "Solve step by step, then return the final numeric result in <answer></answer>.\n\nProblem: Ahmed is 11 years old and Fouad is 26 years old. In how many years will Fouad's age be double Ahmed's current age?", 'role': 'user'}], 'ground_truth': '4'}
Dataset size: 150
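
To make the gold-answer extraction concrete, here is a small standalone check of the same logic on made-up strings in the GSM8K answer style. The sample strings are illustrative and not taken from the dataset; the function mirrors `extract_gsm8k_answer` from the cell above.

```python
import re

def extract_gsm8k_answer(a: str):
    """Return the number after the final '####' marker as a string, or None."""
    if "####" not in a:
        return None
    tail = a.split("####")[-1].strip().replace(",", "")
    m = re.search(r"-?\d+(\.\d+)?", tail)
    return m.group(0) if m else None

# Illustrative inputs (made up for this example, not real dataset rows).
samples = [
    "She sells 16 - 3 - 4 = 9 eggs.\nShe makes 9 * 2 = 18 dollars.\n#### 18",
    "The total is 1,250 dollars.\n#### 1,250",
    "No final marker here.",
]
results = [extract_gsm8k_answer(s) for s in samples]
print(results)  # ['18', '1250', None]
```

The extracted value stays a string with thousands separators removed, which is why `ground_truth` appears as `'4'` rather than `4` in the printed dataset row.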


# Reward functions
* format_reward: gives 1.0 if the reply matches ```<think>…</think><answer>…</answer>```, else 0.0.
* accuracy_reward: parses the number inside ```<answer>…</answer>``` and gives 1.0 if it matches ground truth.

In [None]:
# 1) Format reward: must contain <think>...</think><answer>...</answer>
_format_pat = re.compile(r"^<think>.*?</think>\s*<answer>.*?</answer>\s*$", re.S)

def format_reward(completions, **kwargs):
    # completions: list[list[{"role":"assistant","content": "..."}]]
    contents = [c[0]["content"] if isinstance(c, list) else str(c) for c in completions]
    return [1.0 if _format_pat.match(txt.strip()) else 0.0 for txt in contents]

# 2) Accuracy reward: pull the content inside <answer>...</answer> and compare to GT
def _extract_answer_field(s: str):
    m = re.search(r"<answer>\s*(.*?)\s*</answer>", s, re.S)
    if m:
        payload = m.group(1).strip()
        n = re.search(r"-?\d+(\.\d+)?", payload.replace(",", ""))
        return n.group(0) if n else None
    if "####" in s:
        z = s.split("####")[-1]
        n = re.search(r"-?\d+(\.\d+)?", z.replace(",", ""))
        return n.group(0) if n else None
    return None

def accuracy_reward(completions, ground_truth, **kwargs):
    contents = [c[0]["content"] if isinstance(c, list) else str(c) for c in completions]
    preds = [_extract_answer_field(t) for t in contents]
    rewards = []
    for p, g in zip(preds, ground_truth):
        rewards.append(1.0 if (p is not None and g is not None and p == g) else 0.0)
    return rewards
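
One thing to keep in mind: `accuracy_reward` compares the predicted and gold answers as strings, so a prediction of `18.0` would not match a gold answer of `18`. If that matters for your data, a value-based comparison is one possible alternative. The sketch below is an illustration, not part of the original notebook, and the helper name `numbers_match` is hypothetical.

```python
from decimal import Decimal, InvalidOperation

def numbers_match(pred, gold):
    """Compare two numeric strings by value rather than by spelling."""
    if pred is None or gold is None:
        return False
    try:
        return Decimal(pred) == Decimal(gold)
    except InvalidOperation:
        # Either string is not a parseable number.
        return False

print(numbers_match("18.0", "18"))  # True
print(numbers_match("17", "18"))    # False
print(numbers_match(None, "18"))    # False
```

Swapping `p == g` for `numbers_match(p, g)` inside `accuracy_reward` would apply this check; whether the looser match is desirable depends on how strictly you want the answer format enforced.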

# Build GRPO trainer and train

* Instantiates GRPOTrainer with the model, both reward functions, tokenizer, and dataset; starts training with .train().
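* For each prompt, the trainer samples num_generations completions, scores them with both rewards, and updates the policy based on how each completion's reward compares with the rest of its group.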
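
The "group relative" part of GRPO can be illustrated with a simplified advantage calculation: each completion's reward is normalized against the mean and spread of the rewards in its own group. This is a minimal sketch, assuming a sample standard deviation and a small `eps`; the exact normalization inside TRL depends on the version and on `loss_type`, and the function name is hypothetical.

```python
from statistics import mean, stdev

def group_relative_advantages(rewards, eps=1e-4):
    """Normalize each reward against the mean and spread of its own group.

    ASSUMPTION: sample standard deviation and eps=1e-4 are illustrative choices.
    """
    mu = mean(rewards)
    sigma = stdev(rewards) if len(rewards) > 1 else 0.0
    return [(r - mu) / (sigma + eps) for r in rewards]

mixed = group_relative_advantages([1.0, 0.0])  # one completion rewarded
tied = group_relative_advantages([0.0, 0.0])   # both completions unrewarded
print(mixed)
print(tied)
```

When every completion in a group receives the same reward, all advantages are zero and that group provides no learning signal. With only two completions per group, this happens often, which may help explain log rows where `reward` and `reward_std` are both 0.0.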

In [None]:
from trl import GRPOConfig, GRPOTrainer

NUM_GENERATIONS = 2
BATCH_PER_DEVICE = 2
GA_STEPS = 1

training_args = GRPOConfig(
    output_dir                    = "grpo-fast",
    learning_rate                 = 5e-6,
    weight_decay                  = 0.1,
    warmup_ratio                  = 0.1,
    lr_scheduler_type             = "cosine",
    optim                         = "adamw_8bit",
    logging_steps                 = 10,
    report_to                     = "none",
    save_strategy                 = "no",
    bf16                          = BF16,
    fp16                          = (not BF16),
    per_device_train_batch_size   = BATCH_PER_DEVICE,
    gradient_accumulation_steps   = GA_STEPS,
    num_generations               = NUM_GENERATIONS,
    max_steps                     = 100,
    temperature                   = 1.0,
    top_p                         = 1.0,
    max_prompt_length             = MAX_PROMPT_LEN,
    max_completion_length         = MAX_COMPLETION_LEN,
    loss_type                     = "dapo",
    epsilon_high                  = 0.28,
    beta                          = 0.0,
    mask_truncated_completions    = True,
    use_vllm                      = False,
    dataloader_num_workers        = 2,
)



trainer = GRPOTrainer(
    model            = model,
    reward_funcs     = [format_reward, accuracy_reward],
    args             = training_args,
    train_dataset    = dataset,
    processing_class = tokenizer,
)
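
To see how the batch settings translate into groups, a rough calculation follows. It assumes, as appears to be the case in recent TRL versions, that `per_device_train_batch_size` counts completions rather than prompts, so the number of distinct prompts per optimizer step is the global batch divided by `num_generations`. Treat it as an approximation; the helper name `prompts_per_step` is hypothetical.

```python
def prompts_per_step(batch_per_device, ga_steps, num_devices, num_generations):
    """Distinct prompts per optimizer step, assuming the batch counts completions."""
    completions = batch_per_device * ga_steps * num_devices
    if completions % num_generations != 0:
        raise ValueError("global batch must be divisible by num_generations")
    return completions // num_generations

# Values from this notebook: BATCH_PER_DEVICE=2, GA_STEPS=1, one GPU, NUM_GENERATIONS=2
per_step = prompts_per_step(batch_per_device=2, ga_steps=1, num_devices=1, num_generations=2)
total_prompts = per_step * 100  # max_steps = 100
print(per_step, total_prompts)  # 1 100
```

Under that assumption, each step scores a single prompt with two completions, and the 100-step run touches about 100 of the 150 sampled problems.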

In [None]:
trainer.train()

The model is already on multiple devices. Skipping the move to device specified in `args`.
==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 150 | Num Epochs = 1 | Total steps = 100
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 1
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 1 x 1) = 2
 "-____-"     Trainable parameters = 18,464,768 of 1,562,179,072 (1.18% trained)


Step,Training Loss,reward,reward_std,completions / mean_length,completions / min_length,completions / max_length,completions / clipped_ratio,completions / mean_terminated_length,completions / min_terminated_length,completions / max_terminated_length,sampling / sampling_logp_difference / mean,sampling / sampling_logp_difference / max,sampling / importance_sampling_ratio / min,sampling / importance_sampling_ratio / mean,sampling / importance_sampling_ratio / max,kl,rewards / format_reward / mean,rewards / format_reward / std,rewards / accuracy_reward / mean,rewards / accuracy_reward / std
10,0.0,0.0,0.0,128.0,128.0,128.0,1.0,0.0,0.0,0.0,0,0,0,0,0,0.0,0.0,0.0,0.0,0.0
20,0.0,0.0,0.0,126.95,125.9,128.0,0.95,10.7,10.7,10.7,No Log,No Log,No Log,No Log,No Log,0.0,0.0,0.0,0.0,0.0
30,0.0,0.05,0.070711,124.85,121.7,128.0,0.9,19.3,19.3,19.3,No Log,No Log,No Log,No Log,No Log,0.0,0.05,0.070711,0.0,0.0
40,-0.0,0.05,0.070711,124.85,123.0,126.7,0.85,21.0,20.6,21.4,No Log,No Log,No Log,No Log,No Log,0.0,0.05,0.070711,0.0,0.0
50,0.0,0.0,0.0,125.7,123.4,128.0,0.85,33.8,33.8,33.8,No Log,No Log,No Log,No Log,No Log,0.0,0.0,0.0,0.0,0.0
60,0.0,0.0,0.0,128.0,128.0,128.0,1.0,0.0,0.0,0.0,No Log,No Log,No Log,No Log,No Log,0.0,0.0,0.0,0.0,0.0
70,0.0,0.0,0.0,126.3,124.6,128.0,0.85,35.0,35.0,35.0,No Log,No Log,No Log,No Log,No Log,0.0,0.0,0.0,0.0,0.0
80,0.0,0.1,0.141421,126.2,125.4,127.0,0.85,23.05,23.0,23.1,No Log,No Log,No Log,No Log,No Log,0.0,0.05,0.070711,0.05,0.070711
90,0.0,0.0,0.0,126.2,124.4,128.0,0.95,9.2,9.2,9.2,No Log,No Log,No Log,No Log,No Log,0.0,0.0,0.0,0.0,0.0
100,0.0,0.1,0.141421,124.35,120.7,128.0,0.9,18.3,18.3,18.3,No Log,No Log,No Log,No Log,No Log,0.0,0.05,0.070711,0.05,0.070711


TrainOutput(global_step=100, training_loss=1.6038306682730763e-10, metrics={'train_runtime': 926.6675, 'train_samples_per_second': 0.216, 'train_steps_per_second': 0.108, 'total_flos': 0.0, 'train_loss': 1.6038306682730763e-10})

# Test generation after training

* Asks a new marble-ratio problem using the same system prompt; calls chat_once(...) to generate a solution.

In [None]:
test_messages = [
    {"role":"system", "content": SYSTEM_PROMPT},
    {"role":"user",   "content": "A jar has 18 red and 12 blue marbles. If you add x red marbles so that red:blue becomes 5:3, what is x? Return final answer in <answer></answer>."},
]
print(chat_once(test_messages, max_new_tokens=384, temperature=0.7))

Let's denote the number of additional red marbles added as \( x \).

Initially:
- Red marbles = 18
- Blue marbles = 12

After adding \( x \) red marbles:
- New total red marbles = 18 + \( x \)
- Blue marbles remain unchanged at 12.

According to the problem, after this addition, the ratio of red marbles to blue marbles should be 5:3. So we can write:

\[
\frac{18 + x}{12} = \frac{5}{3}
\]

To solve for \( x \), cross-multiply:

\[
3(18 + x) = 5 \times 12
\]

Simplify both sides:

\[
54 + 3x = 60
\]

Subtract 54 from both sides:

\[
3x = 6
\]

Divide both sides by 3:

\[
x = 2
\]

So, the value of \( x \) is 2. The final answer is:

<answer>2</answer>
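
It can be useful to score a generated reply with the same rules the reward functions use. The standalone sketch below re-creates the two regular expressions from the reward cell and applies them to two short strings: the closing lines of the reply above, and a hand-written variant that includes a `<think>` block. The gold value `"2"` comes from solving (18 + x) / 12 = 5 / 3 by hand, and the `score` helper is hypothetical.

```python
import re

FORMAT_PAT = re.compile(r"^<think>.*?</think>\s*<answer>.*?</answer>\s*$", re.S)
ANSWER_PAT = re.compile(r"<answer>\s*(.*?)\s*</answer>", re.S)

def score(reply, gold):
    """Return (format_reward, accuracy_reward) for one reply string."""
    fmt = 1.0 if FORMAT_PAT.match(reply.strip()) else 0.0
    m = ANSWER_PAT.search(reply)
    num = re.search(r"-?\d+(\.\d+)?", m.group(1).replace(",", "")) if m else None
    acc = 1.0 if (num is not None and num.group(0) == gold) else 0.0
    return fmt, acc

# Closing lines of the reply shown above (no <think> block).
no_think = "So, the value of x is 2. The final answer is:\n\n<answer>2</answer>"
# Hand-written variant that follows the requested format (not model output).
with_think = "<think>18 + x = 20, so x = 2.</think>\n<answer>2</answer>"

print(score(no_think, "2"))    # (0.0, 1.0)
print(score(with_think, "2"))  # (1.0, 1.0)
```

By these rules, a reply shaped like the one above would earn the accuracy reward but not the format reward, because it does not wrap its reasoning in `<think>…</think>`.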


# Save LoRA adapter and tokenizer

* Saves the fine-tuned LoRA weights and tokenizer files to grpo_saved_lora/, then reloads that directory with FastLanguageModel.from_pretrained to confirm the tuned policy can be reused or deployed.

In [None]:
OUT_DIR = "grpo_saved_lora"
model.save_pretrained(OUT_DIR)
tokenizer.save_pretrained(OUT_DIR)

lora_model, lora_tokenizer = FastLanguageModel.from_pretrained(
    OUT_DIR,
    max_seq_length      = MAX_PROMPT_LEN + MAX_COMPLETION_LEN,
    load_in_4bit        = True,
    dtype               = torch.bfloat16 if BF16 else torch.float16,
    attn_implementation = "flash_attention_2",
)
print("Loaded LoRA-only checkpoint with Unsloth.")

==((====))==  Unsloth 2025.11.2: Fast Qwen2 patching. Transformers: 4.57.1.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.8.0+cu126. CUDA: 7.5. CUDA Toolkit: 12.6. Triton: 3.4.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.32.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Loaded LoRA-only checkpoint with Unsloth.
