In [None]:
import argparse
import os
import subprocess
import sys
from datetime import datetime
from pathlib import Path

from loguru import logger
from unsloth import FastLanguageModel, PatchFastRL, is_bfloat16_supported

import wandb

PatchFastRL("GRPO", FastLanguageModel)

from trl import GRPOConfig, GRPOTrainer  # need to import after unsloth patched

In [2]:
from r1rl.datasets import GSM8kDataset

# Instantiate the dataset wrapper with a fixed shuffle seed so the split order is
# repeatable. is_chat=False presumably selects plain-text (non-chat) prompts —
# confirm against r1rl.datasets.
dataset_class = GSM8kDataset
dataset = dataset_class(shuffle_seed=3407, is_chat=False)

In [None]:
# Load the base checkpoint and its tokenizer through Unsloth.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="Qwen/Qwen2.5-3B",
    max_seq_length=1024,
    max_lora_rank=64,
    gpu_memory_utilization=0.6,
    # Load in 4-bit; switch to False for 16-bit LoRA.
    load_in_4bit=True,
    # Turn on vLLM-backed fast inference.
    fast_inference=True,
)


In [None]:
# Modules that receive LoRA adapters. The QKVO group can be dropped if the run
# goes out of memory.
qkvo_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]
gate_up_down_modules = ["gate_proj", "up_proj", "down_proj"]

# Wrap the loaded model with LoRA adapters.
model = FastLanguageModel.get_peft_model(
    model,
    target_modules=qkvo_modules + gate_up_down_modules,
    r=64,  # any rank > 0 works; 8, 16, 32, 64 or 128 are the suggested values
    lora_alpha=64,
    random_state=3407,
    use_gradient_checkpointing="unsloth",  # enables long-context finetuning
)

In [5]:
# Query bf16 support once and derive both precision flags from the same answer,
# so bf16 and fp16 are always exact complements of each other.
use_bf16 = is_bfloat16_supported()

training_args = GRPOConfig(
    # --- generation ---
    use_vllm=True,  # use vLLM for fast inference
    # --- optimizer / schedule ---
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    optim="adamw_8bit",
    logging_steps=1,
    # --- precision ---
    bf16=use_bf16,
    fp16=not use_bf16,
    # --- batch shape and sampling ---
    per_device_train_batch_size=4,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=1,
    num_generations=4,
    max_prompt_length=250,
    max_completion_length=500,
    # --- run length, evaluation and checkpoint cadence ---
    max_steps=500,
    eval_on_start=False,
    eval_strategy="steps",
    eval_steps=250,
    save_steps=250,
    max_grad_norm=0.1,
    # --- reporting / output ---
    report_to="none",
    # Plain string literal: the path has no placeholders, so no f-prefix is needed.
    output_dir="../dumps/outputs_tmp",
)


In [None]:
from typing import List, Dict
from r1rl.utils import extract_answer


# Reward functions
def correctness_reward_func(
    completions: List[List[Dict[str, str]] | List[str]],
    *,
    prompts: List[List[Dict[str, str]] | List[str]],
    answer: List[str],
    question: List[str],
    **kwargs,
) -> list[float]:
    def get_completion_content(completion: List[Dict[str, str]] | str) -> str:
        if isinstance(completion, str):
            return completion
        else:
            return completion[0]["content"]

    responses = [get_completion_content(completion) for completion in completions]
    extracted_responses = [extract_answer(r) for r in responses]

    infos = {
        "Prompt": prompts[0],
        "Question": question[0],
        "Response": responses[0],
        "Ground Truth": answer[0],
        "Extracted": "\n".join(
            [f"{i + 1}. {r}" for i, r in enumerate(extracted_responses)]
        ),
    }
    logger.info("\n".join([f"{k}:\n{v}" for k, v in infos.items()]))
    return [1.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]


# Wire the configuration, model, tokenizer, dataset splits and reward function
# into the GRPO trainer.
trainer = GRPOTrainer(
    args=training_args,
    model=model,
    processing_class=tokenizer,
    train_dataset=dataset.train_dataset,
    eval_dataset=dataset.eval_dataset,
    reward_funcs=[correctness_reward_func],
)

# Kick off training; as the cell's last expression, the return value is displayed.
trainer.train()