In [3]:
# ===== Colab 4: GRPO Reasoning RL with Unsloth (Tiny Math Demo, FINAL) =====

!pip install -q "unsloth>=2025.3.0" "unsloth_zoo" "trl>=0.9.6" "datasets" "accelerate" "transformers>=4.44.0"

from unsloth import FastLanguageModel, is_bfloat16_supported
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer
import torch
import re

# 1) Tiny math dataset
raw_data = [
    {"question": "What is 2 + 3?",  "answer": "5"},
    {"question": "What is 7 - 4?",  "answer": "3"},
    {"question": "What is 3 * 3?",  "answer": "9"},
    {"question": "What is 10 - 6?","answer": "4"},
]
dataset = Dataset.from_list(raw_data)

def to_prompt(example):
    example["prompt"] = f"Question: {example['question']}\nAnswer:"
    return example

dataset = dataset.map(to_prompt)

max_seq_length        = 256
max_prompt_length     = 64
max_completion_length = max_seq_length - max_prompt_length

# 2) Reward fn — MUST return a list/tensor of floats, not a dict
# GRPO will call it like:
#   simple_math_reward(prompts=..., completions=..., completion_ids=..., question=..., answer=..., ...)
def simple_math_reward(
    prompts,
    completions,
    completion_ids=None,
    question=None,
    answer=None,
    **kwargs,
):
    rewards = []
    for q, a, comp in zip(question, answer, completions):
        nums = re.findall(r"-?\d+", comp)
        pred = nums[-1] if nums else None
        rewards.append(1.0 if pred == a else 0.0)
    # IMPORTANT: return a list of floats, not {"rewards": ...}
    return rewards

reward_funcs = [simple_math_reward]

# 3) Load model with QLoRA
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name      = "unsloth/SmolLM2-135M-Instruct",
    max_seq_length  = max_seq_length,
    load_in_4bit    = True,
    full_finetuning = False,
    dtype           = None,
)

model = FastLanguageModel.get_peft_model(
    model,
    r                           = 16,
    lora_alpha                  = 16,
    lora_dropout                = 0,
    bias                        = "none",
    use_gradient_checkpointing  = "unsloth",
    max_seq_length              = max_seq_length,
)

# 4) GRPO config
training_args = GRPOConfig(
    output_dir                  = "smollm2_grpo_rl",
    learning_rate               = 5e-6,
    adam_beta1                  = 0.9,
    adam_beta2                  = 0.99,
    weight_decay                = 0.1,
    warmup_ratio                = 0.1,
    lr_scheduler_type           = "cosine",
    logging_steps               = 1,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1,
    num_generations             = 2,
    max_steps                   = 10,
    max_prompt_length           = max_prompt_length,
    max_completion_length       = max_completion_length,
    remove_unused_columns       = False,
    fp16                        = not is_bfloat16_supported(),
    bf16                        = is_bfloat16_supported(),
    report_to                   = "none",    # no wandb
)

trainer = GRPOTrainer(
    model            = model,
    processing_class = tokenizer,
    reward_funcs     = reward_funcs,
    args             = training_args,
    train_dataset    = dataset,
    prompt_column    = "prompt",
)

trainer.train()

# 5) Inference demo
model.eval()

question = "What is 2 + 3?"
prompt   = f"Question: {question}\nAnswer:"
inputs   = tokenizer(prompt, return_tensors="pt").to(model.device)

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens = 30,
        do_sample      = True,
        temperature    = 0.7,
        use_cache      = False,
    )

print(tokenizer.decode(outputs[0], skip_special_tokens=True))


Map:   0%|          | 0/4 [00:00<?, ? examples/s]

==((====))==  Unsloth 2025.11.4: Fast Llama patching. Transformers: 4.57.2.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.0+cu126. CUDA: 7.5. CUDA Toolkit: 12.6. Triton: 3.5.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.33.post1. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


The model is already on multiple devices. Skipping the move to device specified in `args`.


Unsloth: We now expect `per_device_train_batch_size` * `gradient_accumulation_steps` * `world_size` to be a multiple of `num_generations`.
We will change the batch size of 1 to the `num_generations` of 2


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 4 | Num Epochs = 3 | Total steps = 10
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 1
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 1 x 1) = 2
 "-____-"     Trainable parameters = 4,884,480 of 139,400,064 (3.50% trained)


Step,Training Loss,reward,reward_std,completions / mean_length,completions / min_length,completions / max_length,completions / clipped_ratio,completions / mean_terminated_length,completions / min_terminated_length,completions / max_terminated_length,sampling / sampling_logp_difference / mean,sampling / sampling_logp_difference / max,sampling / importance_sampling_ratio / min,sampling / importance_sampling_ratio / mean,sampling / importance_sampling_ratio / max,kl,rewards / simple_math_reward / mean,rewards / simple_math_reward / std
1,0.0,0.0,0.0,100.5,9.0,192.0,0.5,9.0,9.0,9.0,0,0,0,0,0,0.0,0.0,0.0
2,0.0,0.0,0.0,148.0,104.0,192.0,0.5,104.0,104.0,104.0,No Log,No Log,No Log,No Log,No Log,0.0,0.0,0.0
3,-0.0,0.5,0.707107,91.0,20.0,162.0,0.0,91.0,20.0,162.0,No Log,No Log,No Log,No Log,No Log,0.0,0.5,0.707107
4,0.0,0.0,0.0,50.5,47.0,54.0,0.0,50.5,47.0,54.0,No Log,No Log,No Log,No Log,No Log,2.2e-05,0.0,0.0
5,0.0,0.0,0.0,91.0,16.0,166.0,0.0,91.0,16.0,166.0,No Log,No Log,No Log,No Log,No Log,2.1e-05,0.0,0.0
6,0.0,0.0,0.0,17.5,11.0,24.0,0.0,17.5,11.0,24.0,No Log,No Log,No Log,No Log,No Log,1.6e-05,0.0,0.0
7,0.0,0.5,0.707107,169.0,146.0,192.0,0.5,146.0,146.0,146.0,No Log,No Log,No Log,No Log,No Log,2e-05,0.5,0.707107
8,0.0,0.0,0.0,104.5,17.0,192.0,0.5,17.0,17.0,17.0,No Log,No Log,No Log,No Log,No Log,2.8e-05,0.0,0.0
9,0.0,0.0,0.0,192.0,192.0,192.0,1.0,0.0,0.0,0.0,No Log,No Log,No Log,No Log,No Log,2.6e-05,0.0,0.0
10,0.0,0.0,0.0,192.0,192.0,192.0,1.0,0.0,0.0,0.0,No Log,No Log,No Log,No Log,No Log,2.8e-05,0.0,0.0


Question: What is 2 + 3?
Answer: 2 + 3 is '2 + 3' not '2 + 3+ 1'.
