## GRPO + LoRA Without Regret

### Install Library

In [None]:
!pip install -qqq trl peft math-verify latex2sympy2_extended trackio

### Load Library

In [None]:
import os
from typing import Optional
import torch
from datasets import load_dataset
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
from trl import (
    GRPOConfig,
    GRPOTrainer,
    ModelConfig,
    get_peft_config,
    get_quantization_config,
    get_kbit_device_map,
)

# Experiment-tracking settings, exported as environment variables before the
# trainer is built (training reports to "trackio").
# NOTE(review): presumably the Space hosting the dashboard and the project the
# run is filed under — confirm against the trackio documentation.
os.environ.update(
    {
        "TRACKIO_SPACE_ID": "trl-lora-without-regret",
        "TRACKIO_PROJECT": "trl-lora-without-regret",
    }
)

### Load Model Config

In [None]:
# Base checkpoint plus the quantization and LoRA adapter settings.
model_config = ModelConfig(
    # Base model: loaded in 4-bit with bfloat16 as the configured dtype.
    model_name_or_path="Qwen/Qwen3-0.6B",
    torch_dtype="bfloat16",
    load_in_4bit=True,
    # LoRA adapter: rank 1, alpha 32, targeting "all-linear" modules.
    use_peft=True,
    lora_target_modules="all-linear",
    lora_r=1,
    lora_alpha=32,
)

# GRPO training hyper-parameters, grouped by purpose.
training_args = GRPOConfig(
    output_dir="./grpo-lora-qwen3",
    # --- optimisation ---
    learning_rate=1e-6,
    max_steps=100,
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    bf16=True,
    # --- generation: 8 completions per prompt, generated 8 at a time ---
    num_generations=8,
    generation_batch_size=8,
    max_prompt_length=2048,
    max_completion_length=1024,
    # --- logging and checkpointing ---
    report_to=["trackio"],
    logging_steps=10,
    save_steps=50,
)

### Load Data From Hugging Face

In [None]:
# Load the "train" split, then keep at most the first 5000 rows.
dataset = load_dataset(
    "HuggingFaceH4/OpenR1-Math-220k-default-verified",
    split="train",
)
dataset = dataset.select(range(min(len(dataset), 5000)))

In [None]:
def make_conversation(example):
    """Wrap the example's ``problem`` text as a one-message chat prompt.

    Returns a dict with a single ``"prompt"`` key holding a list that contains
    one user message.
    """
    user_turn = {"role": "user", "content": example["problem"]}
    return {"prompt": [user_turn]}


# Add the chat-style "prompt" column, then drop every column except "prompt"
# and "solution" (the reward function takes a `solution` argument).
dataset = dataset.map(make_conversation)
dataset = dataset.remove_columns(
    [name for name in dataset.column_names if name not in ("prompt", "solution")]
)

### Reward Function

In [None]:
def _strip_think_blocks(text: str) -> str:
    """Return *text* with every closed ``<think>...</think>`` span removed.

    An opening tag with no closing tag after it is left untouched.
    """
    open_tag, close_tag = "<think>", "</think>"
    while True:
        start = text.find(open_tag)
        if start == -1:
            return text
        end = text.find(close_tag, start)
        if end == -1:
            return text
        text = text[:start] + text[end + len(close_tag) :]


def strip_reasoning_accuracy_reward(
    completions: list[list[dict[str, str]]], solution: list[str], **kwargs
) -> list[Optional[float]]:
    """Score completions against gold solutions, ignoring reasoning blocks.

    Closed ``<think>...</think>`` spans are removed from each completion before
    the answer is extracted, so only the text outside them is scored.

    Args:
        completions: One conversation per sample; only the ``"content"`` of
            the first message is scored.
        solution: Gold solutions, paired with ``completions`` by position.
        **kwargs: Accepted and ignored.

    Returns:
        One entry per (completion, solution) pair: ``float(verify(gold,
        answer))``, or ``None`` when the gold solution yields no parse or
        verification raises.
    """
    contents = [completion[0]["content"] for completion in completions]
    rewards: list[Optional[float]] = []

    for content, sol in zip(contents, solution):
        content = _strip_think_blocks(content)

        gold_parsed = parse(
            f"${sol}$",
            extraction_config=[
                LatexExtractionConfig(
                    boxed_match_priority=0, try_extract_without_anchor=True
                )
            ],
        )

        # Without a parseable gold answer there is nothing to compare against.
        if len(gold_parsed) == 0:
            rewards.append(None)
            continue

        # We require the answer to be provided in correct latex (no malformed operators)
        answer_parsed = parse(
            content,
            extraction_config=[
                LatexExtractionConfig(
                    boxed_match_priority=0,
                    normalization_config=NormalizationConfig(
                        basic_latex=True,
                        units=True,
                        malformed_operators=False,
                        nits=False,
                        boxed=True,
                    ),
                    try_extract_without_anchor=False,
                )
            ],
            extraction_mode="first_match",
        )

        try:
            reward = float(verify(gold_parsed, answer_parsed))
        except Exception:
            # Best effort: a failed verification yields no reward for this
            # sample. Only Exception is caught so that KeyboardInterrupt and
            # SystemExit still stop the run.
            reward = None

        rewards.append(reward)
    return rewards

### Initialize Trainer

In [None]:
# Turn the configured dtype name into the matching torch attribute; "auto" and
# None are passed through unchanged.
if model_config.torch_dtype in ["auto", None]:
    dtype = model_config.torch_dtype
else:
    dtype = getattr(torch, model_config.torch_dtype)

# Model-loading keyword arguments stored on the training config; the trainer
# below receives the model as a name rather than an instance.
training_args.model_init_kwargs = dict(
    torch_dtype=dtype,
    device_map=get_kbit_device_map(),
    quantization_config=get_quantization_config(model_config),
)

# LoRA adapter configuration derived from the model config.
peft_config = get_peft_config(model_config)

trainer = GRPOTrainer(
    model=model_config.model_name_or_path,
    reward_funcs=[strip_reasoning_accuracy_reward],
    args=training_args,
    train_dataset=dataset,
    peft_config=peft_config,
)

### Train the Model

In [None]:
# Run GRPO training with the trainer built in the previous cell.
trainer.train()

### Save the Model

In [None]:
# Save the trained model to the configured output directory.
# NOTE(review): with a PEFT config this is presumably the adapter weights
# only — confirm; the test cell reloads it via PeftModel.from_pretrained.
trainer.save_model(training_args.output_dir)

### Test the Model

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# The tokenizer comes from the base checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path)

# Reload the base model (no quantization_config is passed here, unlike at
# training time) and attach the adapter saved in the output directory.
base_model = AutoModelForCausalLM.from_pretrained(
    model_config.model_name_or_path,
    device_map="auto",
    torch_dtype=dtype,
)
model = PeftModel.from_pretrained(base_model, training_args.output_dir)
model.eval()

In [None]:
def test_model(problem: str, max_new_tokens: int = 512):
    """Sample one response to *problem* from the module-level ``model``.

    The problem is sent as a single user message, rendered through the
    tokenizer's chat template with a generation prompt appended.

    Args:
        problem: Text of the user message.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded continuation only (prompt tokens are sliced off), with
        special tokens skipped.
    """
    chat = [{"role": "user", "content": problem}]
    rendered = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )
    encoded = tokenizer(rendered, return_tensors="pt").to(model.device)
    # Generated sequences start with the prompt; remember where it ends.
    prompt_len = encoded["input_ids"].shape[1]

    sampling = {
        "max_new_tokens": max_new_tokens,
        "temperature": 0.7,
        "do_sample": True,
    }
    with torch.no_grad():
        generated = model.generate(**encoded, **sampling)

    new_tokens = generated[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)


# Smoke test: generate for the first training example and print truncated
# views of the problem (200 chars) and the response (500 chars).
example = dataset[0]
problem_text = example["prompt"][0]["content"]
result = test_model(problem_text)
print(f"Problem: {problem_text[:200]}...")
print(f"\nResponse: {result[:500]}...")