## **1. Install & import libraries**

In [None]:
%pip install unsloth vllm

In [None]:
import re
from vllm import SamplingParams
from unsloth import FastLanguageModel
from datasets import load_dataset, Dataset
from trl import GRPOConfig, GRPOTrainer

INFO 04-27 20:13:15 [__init__.py:239] Automatically detected platform cuda.


## **2. Load and Prepare LoRA-Enabled Model**

In [None]:
# Sequence-length budget handed to the loader; the GRPO config further down
# derives the completion budget from it.
max_seq_length = 2048
# LoRA rank; reused below for max_lora_rank, r and lora_alpha.
lora_rank = 64

# Load the base instruct model together with its tokenizer through Unsloth.
# NOTE(review): fast_inference=True presumably enables the vLLM-backed
# generation path and gpu_memory_utilization presumably caps the fraction of
# GPU memory it may reserve -- confirm against the Unsloth documentation.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="meta-llama/Llama-3.2-3B-Instruct",
    max_seq_length=max_seq_length,
    load_in_4bit=False,
    fast_inference=True,
    max_lora_rank=lora_rank,
    gpu_memory_utilization=0.8,
)

# Wrap the model with LoRA adapters. Only the q_proj and v_proj modules are
# targeted, and lora_alpha is set equal to the rank.
model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,
    target_modules=["q_proj", "v_proj"],
    lora_alpha=lora_rank,
    use_gradient_checkpointing="unsloth",
    random_state=3407,
)

## **3. Load & format dataset for reasoning**

In [None]:
# Load the "train" split of the dataset identified below (a Vietnamese
# translation of MetaMathQA, judging by its name).
dataset = load_dataset(
    "5CD-AI/Vietnamese-meta-math-MetaMathQA-40K-gg-translated", split="train"
)

In [None]:
# Show the dataset summary (column names and number of rows).
print("Dataset structure:", dataset)

Dataset structure: Dataset({
    features: ['response_vi', 'query_vi', 'response_en', 'type', 'query_en'],
    num_rows: 40000
})


In [None]:
# Matches the Vietnamese answer markers "đáp án là:" / "câu trả lời là:" (with
# or without a space before the colon) and captures, as group 2, the text that
# follows up to the end of that line.
# re.IGNORECASE already makes the marker case-insensitive, so the response is
# NOT lowercased before matching: lowercasing would also lowercase the captured
# answer and corrupt case-sensitive ground truths.
answer_pattern = re.compile(
    r"(đáp án là:|đáp án là :|câu trả lời là:|câu trả lời là :)\s*(.*)", re.IGNORECASE
)

# Keep only the rows whose Vietnamese solution contains an answer marker; the
# captured text (original casing preserved) becomes the ground-truth answer.
formatted_dataset = [
    {"question": item["query_vi"], "answer": found.group(2).strip()}
    for item in dataset
    if (found := answer_pattern.search(item["response_vi"].strip())) is not None
]

In [None]:
# Delimiter tags the model is asked to put around its reasoning and its answer.
reasoning_start, reasoning_end = "<thinking>", "</thinking>"
solution_start, solution_end = "<answer>", "</answer>"

# System instruction placed at the start of every chat prompt.
system_prompt = f"""You are given a problem.
Think about the problem and provide your thought process.
Place it between {reasoning_start} and {reasoning_end}.
Then, provide your final answer between {solution_start}{solution_end}"""


def _to_chat_example(example):
    """Build the system + user chat prompt for one question/answer row."""
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": example["question"]},
    ]
    return {"prompt": conversation, "answer": example["answer"]}


# Only the first 8000 extracted examples are used for training.
train_dataset = Dataset.from_list(formatted_dataset[:8000]).map(_to_chat_example)

In [None]:
# Display the first formatted training example as the cell output.
train_dataset[0]

## **4. Define reward functions**



### **4.1 Match format**

In [None]:
# Full-format pattern: optional leading whitespace, a reasoning section wrapped
# in the reasoning tags, then an answer section wrapped in the solution tags
# (its contents are capture group 1), then optional trailing whitespace.
# The doubled braces in the rf-strings produce the literal quantifier "{0,}".
# NOTE(review): with re.MULTILINE, ^ and $ anchor at line boundaries rather than
# only at the ends of the string, so a .search() with this pattern also accepts
# responses that have extra lines before or after the tagged sections --
# confirm this leniency is intended for an "exact" format check.
match_format = re.compile(
    rf"^[\s]{{0,}}"
    rf"{reasoning_start}.+?{reasoning_end}.*?"
    rf"{solution_start}(.+?){solution_end}"
    rf"[\s]{{0,}}$",
    flags=re.MULTILINE | re.DOTALL,
)


def match_format_exactly(completions, **kwargs):
    """Score 3.0 for every completion that matches ``match_format``, else 0.

    Args:
        completions: Sequence of completions; the text of each one is read
            from ``completion[0]["content"]``.
        **kwargs: Accepted and ignored.

    Returns:
        A list with one score per completion, in input order.
    """
    return [
        3.0 if match_format.search(completion[0]["content"]) is not None else 0
        for completion in completions
    ]


def match_format_approximately(completions, **kwargs):
    """Give partial credit per tag that appears exactly once in a completion.

    Each of the four tags (reasoning start/end, solution start/end) adds 0.5
    when it occurs exactly once in the text and subtracts 1.0 otherwise, so a
    score ranges from -4.0 to 2.0.

    Args:
        completions: Sequence of completions; the text of each one is read
            from ``completion[0]["content"]``.
        **kwargs: Accepted and ignored.

    Returns:
        A list with one score per completion, in input order.
    """
    tags = (reasoning_start, reasoning_end, solution_start, solution_end)
    scores = []
    for completion in completions:
        text = completion[0]["content"]
        scores.append(sum(0.5 if text.count(tag) == 1 else -1.0 for tag in tags))
    return scores

### **4.2 Match answer**

In [None]:
# Finds the first run of digits, dots and commas (optionally preceded by "-")
# that appears anywhere after the opening solution tag; that run is capture
# group 1. The closing solution tag is not part of the pattern.
match_numbers = re.compile(
    solution_start + r".*?(-?[\d\.\,]{1,})", flags=re.MULTILINE | re.DOTALL
)


def check_answer(prompts, completions, answer, **kwargs):
    """Score the text captured by ``match_format`` against the ground truth.

    Scoring per completion:
        * 0    -- the completion does not match ``match_format``;
        * 3.0  -- the captured answer equals the ground truth exactly;
        * 1.5  -- they are equal only after stripping surrounding whitespace;
        * -1.5 -- they differ.

    Args:
        prompts: Unused.
        completions: Sequence of completions; the text of each one is read
            from ``completion[0]["content"]``.
        answer: Ground-truth answers, paired with ``completions`` by position.
        **kwargs: Accepted and ignored.

    Returns:
        A list with one score per (completion, answer) pair.
    """
    found_matches = [
        match_format.search(completion[0]["content"]) for completion in completions
    ]

    scores = []
    for found, true_answer in zip(found_matches, answer):
        if found is None:
            scores.append(0)
            continue

        guess = found.group(1)
        if guess == true_answer:
            scores.append(3.0)
        elif guess.strip() == true_answer.strip():
            scores.append(1.5)
        else:
            scores.append(-1.5)
    return scores


def check_numbers(prompts, completions, answer, **kwargs):
    """Score the first number found after the solution tag against the truth.

    Scoring per completion:
        * 0    -- no number was found, or either value cannot be parsed as a
                  float;
        * 1.5  -- the parsed guess equals the parsed ground truth;
        * -0.5 -- both parse but differ.

    Commas are removed from BOTH the guess and the ground truth before
    parsing, so a ground truth such as "1,000" is comparable with the guess
    instead of always failing to parse.

    Side effects: keeps a call counter on the function object and, on every
    5th call, prints the question, the first response, its extracted number
    and the first ground-truth answer.

    Args:
        prompts: Chat prompts; the last message of the first prompt is what
            gets printed as the question.
        completions: Sequence of completions; the text of each one is read
            from ``completion[0]["content"]``.
        answer: Ground-truth answers, paired with ``completions`` by position.
        **kwargs: Accepted and ignored.

    Returns:
        A list with one score per (completion, answer) pair.
    """
    question = prompts[0][-1]["content"]
    responses = [completion[0]["content"] for completion in completions]

    extracted_responses = [
        found.group(1) if (found := match_numbers.search(r)) is not None else None
        for r in responses
    ]

    # Print a sample on every 5th call; the counter lives on the function
    # object so it persists between calls.
    count = getattr(check_numbers, "counter", 0) + 1
    check_numbers.counter = count
    if count % 5 == 0:
        print(
            "*" * 20,
            f"Question:{question}",
            f"\nResponse:\n{responses[0]}",
            f"\nExtracted: {extracted_responses[0]}",
            f"\nGT Answer: {answer[0]}",
        )

    scores = []
    for guess, true_answer in zip(extracted_responses, answer):
        if guess is None:
            scores.append(0)
            continue
        try:
            # Same comma normalisation on both sides keeps the comparison fair.
            expected = float(str(true_answer).strip().replace(",", ""))
            predicted = float(guess.strip().replace(",", ""))
        except ValueError:
            # Non-numeric ground truth or a degenerate capture such as "." or ",".
            scores.append(0)
            continue
        scores.append(1.5 if predicted == expected else -0.5)
    return scores

## **5. Training (GRPO)**

In [None]:
# Tokenise every chat prompt to find the longest one.
# This must run on train_dataset: it is the dataset that carries the "prompt"
# column, whereas the raw `dataset` only has the original query/response
# columns and would raise a KeyError on x["prompt"].
tokenized = train_dataset.map(
    lambda x: {
        "tokens": tokenizer.apply_chat_template(
            x["prompt"], add_generation_prompt=True, tokenize=True
        )
    },
    batched=True,
)
max_len = max(len(tokens) for tokens in tokenized["tokens"])

# One token of headroom on top of the longest prompt.
max_prompt_length = max_len + 1

In [None]:
# GRPO hyper-parameters. The completion budget is whatever remains of
# max_seq_length once the longest prompt has been accounted for.
training_args = GRPOConfig(
    learning_rate=5e-6,
    weight_decay=5e-4,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    optim="adamw_torch_fused",
    logging_steps=1,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=64,
    num_generations=8,
    max_prompt_length=max_prompt_length,
    max_completion_length=max_seq_length - max_prompt_length,
    num_train_epochs=1,
    max_steps=-1,
    save_steps=20,
    max_grad_norm=0.1,
    report_to="wandb",
    output_dir="grpo_lora",
)

# Train on train_dataset, not the raw `dataset`: only train_dataset has the
# "prompt" column the trainer generates from and the "answer" column that
# check_answer / check_numbers receive as their `answer` argument.
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        match_format_exactly,
        match_format_approximately,
        check_answer,
        check_numbers,
    ],
    args=training_args,
    train_dataset=train_dataset,
)
trainer.train()

## **6. Save model**

In [None]:
# Save the LoRA adapter under "saved_grpo_lora"; the evaluation cell at the end
# of the notebook loads it back from this same path.
model.save_lora("saved_grpo_lora")

## **7. Inference**

### Original Model

In [None]:
# Pick one training example and build the same chat prompt used for training.
idx = 0
example = train_dataset[idx]
messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": example["question"]},
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=1024)

# Render the conversation to a plain string ending with the generation prompt.
text = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=False
)

# Generate without a LoRA request.
generations = model.fast_generate(
    [text], sampling_params=sampling_params, lora_request=None
)
output = generations[0].outputs[0].text

print(f"Problem:\n{example['question']}")
print(f"Response:\n{output}")
print("GT Answer:", example["answer"])

### Load LoRA and evaluate

In [None]:
# Generate again for the same prompt, this time with the saved LoRA adapter.
path_lora = "saved_grpo_lora"
adapter_request = model.load_lora(path_lora)
generations = model.fast_generate(
    [text], sampling_params=sampling_params, lora_request=adapter_request
)
output = generations[0].outputs[0].text

question_text = train_dataset[idx]["question"]
print(f"Problem:\n{question_text}")
print(f"Response:\n{output}")
print("GT Answer:", train_dataset[idx]["answer"])