# Goal

In this notebook, we will use GRPO to enable the reasoning capability for Llama3.1, but this can be applied to any LLM.




# SFT for Reasoning? 🤔
Can we use supervised finetuning to enable a model to reason? Yes! But it requires

- (question, good reasoning, good answer) triplets generated by a reasoning-capable model
- fine-tuning our target model with above triplets

Effectively we are distilling the knowledge and reasoning capability of a bigger and more capable model into a smaller model.

While this approach is straightforward, we are bounded by the limits of an existing reasoning model. This can be a challenge if we want reasoning for a particular domain say medical or law, which the teacher model may not be good at.

# Reasoning from scratch? 🤔
So how do we get a model to reason without relying on another model to guide it? How can we teach it to reason without feeding it tons of examples of “good” reasoning? Well, hiring annotators to write out solid reasoning for every hard question just isn’t practical or scalable. But here’s an alternative: instead of showing the model how to reason, we define what good reasoning looks like.

🧩 given → 🧑‍🎓 thinks 🤔 → answer ✅ or ❌ → improves via 🧠🔄
- Think of it like teaching someone to solve a puzzle without showing them the solution. Instead of walking them through the steps, you just tell them whether they got it right or wrong at the end. Over time, they have to figure out what kinds of moves or thought patterns lead to success. That’s the core idea behind reinforcement learning (RL): the model doesn’t get a play-by-play—it just gets a signal about whether the outcome was good or bad, and it has to learn the in-between steps on its own.

🧑‍🍳 cooks 🍝 → gets 👃👀👅 feedback → learns with 🔁🧠
- Or think of it like teaching a newbie chef to cook. Instead of giving them recipes, we let them experiment in the kitchen—and then we give feedback: “Does it look appetizing? Does it taste or smell good?” The chef has to figure out what steps lead to that tasty result.




# The Magic ✨
All we need is a simple SFT dataset with just questions and answers. 🚀

Here’s the idea: We ask the model to solve a problem after reasoning through it, and we reward it only if the final solution is correct and it provides a proper reasoning flow. That’s it—no extra complexity!

How It Works:
Let’s say we give the model a hard math problem. Instead of jumping straight to the answer, we set up the system prompt so that the model:

- Starts by reasoning through the problem and possible solutions.

- Wraps the reasoning between \<reasoning> and \</reasoning> tags.

- Provides the final solution only after reasoning is complete, placing it between \<answer> and \</answer> tags.

We reward the model if:

&nbsp;&nbsp;🏆 The answer is correct—obviously, that’s a must!


&nbsp;&nbsp;🏆 The final solution is properly tagged with \<answer>.


&nbsp;&nbsp;🏆 The reasoning is properly tagged with \<reasoning>.


&nbsp;&nbsp;🏆 The reasoning meets a certain length requirement (to avoid shortcuts).


It’s like telling a student: “I won’t just grade your final answer—you need to show your work!” 🎓✨


# RL Requires Reward 💰
RL relies on a reward signal. In the context of LLMs, we need a way to determine which completions or reasoning are good and which ones aren't as great. To generate the reward signal, one can

&nbsp;&nbsp; 👉 use a reward model, which has been trained in advance to mimic human preferences

&nbsp;&nbsp; 👉 design a set of reward functions that check out model completions align with our preferences (e.g., answer length, answer format)


In short, the goal is to teach the model what "good" looks like so it can improve its responses over time! 🚀


# The Art of Dynamic Rewards 🎯
One key aspect of RL is that rewards shouldn’t just reinforce consistency—they should also push for growth. Imagine a student who starts off getting B’s—at first, the teacher might reward them to encourage progress. But if they keep getting B’s forever, the teacher needs to raise the bar, rewarding them only when they improve to an A or A+. In other words, rewards should be dynamic and adaptive, reinforcing positive surprises rather than maintaining the status quo. This ensures that learning doesn’t plateau but instead keeps progressing.


Some popular RL-based training methods are Proximal Policy Optimization (PPO), Direct Preference Optimization (DPO), and GRPO. Let's see how each of these 3 approaches adjust the reward for model's continous improvemnt.

- PPO, a traditional RL algorithm, relies on a critic network to estimate this baseline reward for an LLM. PPO is like a traditional classroom, where the teacher (the critic network) evaluates every student not just on their latest assignment, but based on their overall performance history. The teacher has an internal benchmark for each student—if a consistent B student scores another B, it’s expected so not a large reward assigned. But if that same student suddenly earns an A+, that’s a big deal and deserves extra recognition. On the flip side, if a straight-A student suddenly drops to a B, the teacher knows something went wrong. This structured approach helps maintain steady progress, but it also makes learning complex and slow because the teacher has to track every student’s past performance, adjusting expectations individually for each one.


- DPO simplifies learning by replacing the critic with direct preference comparisons, where the model learns by evaluating pairs of outputs ranked by human feedback. This pairwise comparison helps establish a relative reward context—similar to a student competing with a peer. If both students consistently receive similar grades, neither is performing exceptionally well; true excellence is only recognized when one significantly outperforms the other. However, this approach depends on a carefully curated preference dataset and is inherently limited by its two-output comparison framework.

- GRPO simplifies learning by replacing the critic with a group-based comparison approach, where the model generates and evaluates multiple outputs at once rather than relying on pairwise comparisons. This enables the model to learn which responses are superior without needing a critic and without being limited to one-on-one assessments. An analogy would be evaluating a student’s performance relative to the entire class rather than just a single peer—providing a broader and more informative benchmark for assessing excellence. Also, unlike DPO, we don't need to create a preference dataset in advance - we let the model generate multiple answers for a given question and each answer will get a reward from the reward function and then adjust the rewards based on the mean and variance of these rewards - no need to involve human feedback!


# 🔭 GRPO at High-Level

1. Group Sampling: For a single prompt, the policy (our LLM) generates a batch of completions (instead of just one). This produces a small “group” of possible actions or answers.

2. Reward Scoring: Each output is scored by a series of reward functions, which reflect how good or desirable that output is for the task at hand.

3. Group-Based Advantage: The algorithm calculates each output’s “advantage” by comparing its reward to the average reward of the entire group (see the dynamic reward section). If the output’s reward is above average, it has a positive advantage (and vice versa).

4. Policy Update: The policy is adjusted to promote outputs with a positive advantage and discourage those with a negative advantage.
- A KL penalty term helps keep the policy from drifting too far from the original model. Without it, the model might learn to reason better, but at the cost of forgetting or degrading other important capabilities that aren't directly related to reasoning.

5. Iterative Process: The updated policy is used again to generate new groups, score them, and update—repeating until the policy converges or meets performance goals.

This group-based approach removes the need for a separate value function (critic) and helps the policy quickly learn which outputs are relatively better within each sampled group.

We will take a deeper dive into GRPO in the following section! 🔍 🌊

### Installation

As you can see, we'll only need a few dependencies thanks to the collective hard work of the community!

In [None]:
# create a conda environment: conda create -n grpo_test python=3.12.4
# activate conda environment in your colab notebook
# It took me a while to figure out the version of torch that works with bitsandbytes==0.45.4
!pip install torch==2.6.0 accelerate==1.5.2 transformers==4.50.1 diffusers==0.32.2 wandb pillow trl==0.17.0 bitsandbytes==0.45.4





In [None]:
# prompt: check cuda version

!nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Jun__6_02:18:23_PDT_2024
Cuda compilation tools, release 12.5, V12.5.82
Build cuda_12.5.r12.5/compiler.34385749_0


# Model
Let's set up the model and tokenizer.

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [2]:
#  load the model's tokenizer and properly set padding token
from transformers import AutoTokenizer
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct" #@param {type:"string"}

tokenizer = AutoTokenizer.from_pretrained(model_id,trust_remote_code=True)

# Setting the padding token is crucial for batch processing
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token # or choose another suitable token depending on the model

In [3]:
# !pip install flash-attn==2.6.3 --upgrade


In [4]:
# import torch, flash_attn
# print(torch.__version__, flash_attn.__version__)
# from transformers.modeling_flash_attention_utils import _flash_supports_window_size
# print("window flag:", _flash_supports_window_size)


In [5]:
# load model in NF4 and properly set torch_dtype and quantization_config
from transformers import BitsAndBytesConfig
from transformers import AutoModelForCausalLM
import torch
# Set up quantization configuration
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,  # Load model weights in 4-bit precision
    bnb_4bit_quant_type="nf4",  # Use NF4 quantization type
    bnb_4bit_use_double_quant=True,  # Use double quantization
    bnb_4bit_compute_dtype=torch.bfloat16  # Use bfloat16 for computation
)

# Load the model with the quantization configuration
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,  # Explicitly set torch_dtype
    device_map="auto",
    attn_implementation="sdpa",
)

# Ensure the model is in training mode
model.train()

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((409

In [6]:
#  Apply LoRA to the 4-bit model
from peft import LoraConfig, get_peft_model
lora_config = LoraConfig(
    r=8,  # LoRA attention dimension
    lora_alpha=32,  # Alpha parameter for LoRA scaling
    lora_dropout=0.05,  # Dropout probability for LoRA layers
    bias="none",  # Do not train bias
    task_type="CAUSAL_LM",  # Task type is causal language modeling
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] # Target modules for LoRA
)

# Apply LoRA to the model
model = get_peft_model(model, lora_config)

# Print trainable parameters
model.print_trainable_parameters()

trainable params: 20,971,520 || all params: 8,051,232,768 || trainable%: 0.2605


# Dataset
For this project, we utilize the GSM8K (Grade School Math 8K) dataset, which consists of 8,500 carefully curated grade school math word problems characterized by linguistic diversity and high quality. GSM8K was specifically developed to facilitate question-answering tasks that involve basic math problems requiring multi-step reasoning. https://huggingface.co/datasets/openai/gsm8k

One fundamental difference between Supervised Fine-Tuning (SFT) and Reinforcement Learning with GRPO lies in how inputs and outputs are structured during training. In SFT, we typically concatenate the question and the expected answer, providing both to the model as input-output pairs for supervised learning. In contrast, with GRPO, only the question is provided as input, and the model is tasked with generating both the reasoning and the final answer autonomously.

This distinction is crucial because, in GRPO, the model's output—consisting of the reasoning followed by the answer—is evaluated using reward functions. These rewards assess not only the correctness of the answer but also how well the model adheres to the specified output format and instructions, such as properly delineating reasoning and answers with predefined tags.

# Preprocess Dataset
Let's download openai/gsm8k dataset and convert each example into a dictionary {"prompt": ..., "answer": ...}  where
- prompt field is a list of {"role": ..., "content": ...} where role is either "system" or "user" and content is the prompt.
- answer field is the final numeric answer for the math problem

Since one of the roles is system, it is a good time to finalize the system prompt that instructs the model how to generate its completions in a specific format. Specifically, we want the model to provide its reasoning enclosed within <thinking> and </thinking> tags, followed by the final answer enclosed within <answer> and </answer> tags. This formatting is critical, as the reward functions we design will rely on it to evaluate whether the model's completions adhere to our instructions. Of course, depending on the capabilities of the base model, it may not consistently follow the system prompt as intended. This is precisely where reinforcement learning (RL) comes in — to reinforce and instill this structured format into the model's behavior.

In [7]:
# Here is an example system prompt that we can use
SYSTEM_PROMPT = """
Respond in the following format and make sure the entire response is wrapped in <reasoning> and <answer> tags with no other text outside of these tags:
<thinking>
your reasoning goes here and do not use newlines so that the entire reasoning becomes a single paragraph
</thinking>
<answer>
your answer goes here
</answer>
"""


If you review the implementation of apply_chat_template in trl/data_utils.py, you'll notice that it expects both a "prompt" and a "completion" in the input. When provided, Hugging Face (HF) will automatically concatenate the prompt and completion to form the final input sequence for the model. However, this behavior is exactly what we want to avoid in our setup.

Unlike Supervised Fine-Tuning (SFT), where exposing the answer in the completion is necessary for next-token prediction, our goal with reinforcement learning is fundamentally different. We do not want to show the answer to the model upfront. If we included the answer in the completion field, the model would simply learn to predict the next token based on the provided answer—defeating the purpose of encouraging autonomous reasoning.

What we actually want is for the model to generate both the reasoning and the final answer on its own, starting only from the question. To achieve this, we deliberately avoid using the "completion" field. Instead, we store the ground truth answer separately in an "answer" field, which is ignored by the trainer’s apply_chat_template function. This way, the model receives only the question as input, and its output can then be evaluated using reward functions that check both the quality of reasoning and the correctness of the answer—without ever having been exposed to the answer during generation.



In [8]:
#  Preprocess openai/gsm8k and apply chat template
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer
from trl import apply_chat_template                                   # thin wrapper around tokenizer.apply_chat_template
import random, textwrap

def gsm8k_to_chat(example):
    """
    Convert a raw GSM8K row into:
      {
        "prompt": [ {"role":"system", "content": SYSTEM_PROMPT},
                    {"role":"user",   "content": <question>} ],
        "answer": <gold_integer_string>
      }
    The gold answer is **not** put in a 'completion' field, so the trainer will
    pass only the 'prompt' list to apply_chat_template().
    """
    return {
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user",   "content": example["question"].strip()}
        ],
        "answer": example["answer"].strip()          # we keep it for rewards later
    }


In [9]:
raw_train = load_dataset("openai/gsm8k", "main", split="train")
raw_test  = load_dataset("openai/gsm8k", "main", split="test")

train_chat = raw_train.map(gsm8k_to_chat, remove_columns=raw_train.column_names)
test_chat  = raw_test.map(gsm8k_to_chat,  remove_columns=raw_test.column_names)


In [10]:
import random
from statistics import mean

def render_with_template(prompt_list, *, tokenize=False):
    """
    Wrapper that handles both old and new tokenizer signatures.
    If your tokenizer supports `add_generation_prompt`, we use it;
    otherwise we call it without that kwarg.
    """
    try:
        return tokenizer.apply_chat_template(
            prompt_list,
            tokenize=tokenize,
            add_generation_prompt=True
        )
    except TypeError:
        # older transformers version – no add_generation_prompt kwarg
        return tokenizer.apply_chat_template(
            prompt_list,
            tokenize=tokenize
        )

# --- show a single rendered example ---------------------------------------
sample = train_chat[random.randint(0, len(train_chat) - 1)]
rendered_text = render_with_template(sample["prompt"], tokenize=False)
print("---- rendered text sent to the model ----")
print(rendered_text[:800] + (" ..." if len(rendered_text) > 800 else ""))

# --- compute token lengths for the whole training split --------------------
def prompt_length(example):
    ids = render_with_template(example["prompt"], tokenize=True)
    return len(ids)

lengths = [prompt_length(ex) for ex in train_chat]

print(f"Average prompt length: {mean(lengths):.1f}")
print(f"Max prompt length: {max(lengths)}")


---- rendered text sent to the model ----
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

Respond in the following format and make sure the entire response is wrapped in <reasoning> and <answer> tags with no other text outside of these tags:
<thinking>
your reasoning goes here and do not use newlines so that the entire reasoning becomes a single paragraph
</thinking>
<answer>
your answer goes here
</answer><|eot_id|><|start_header_id|>user<|end_header_id|>

Mark is looking to buy a total of 12 pieces of fruit at the store. He has already chosen 3 apples. He has also selected a bunch of bananas containing 4 bananas. How many oranges does he need to pick out to have 12 total pieces of fruit?<|eot_id|><|start_header_id|>assistant<|end_header_id|>


Average prompt length: 161.3
Max prompt length: 313


# Reward Functions
GRPO needs a set of reward functions to determine how well model's completions adhere to our expectations. Let's add some reward functions. Note that the GRPO trainer sends the following input arguments, among others, to each reward function:
- prompts,
- completions,
- answer

You, as the designer, can choose to use them all or use whichever you need and pass the rest to **kwargs.

- Example: def reward_func(prompts, completions, answer, **kwargs) ---> uses all 3
- Example: def reward_func(completions, answer, **kwargs) ---> uses only completion

Also, you should make sure  there is a 1-to-1 correspondence between completions and returned rewards.

In [11]:
# Example: a reward function that penalizes the completion if the completion length is longer than 100 tokens
def length_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [0.0 if len(tokenizer(contents[i])) < 100 else -0.5 for i in range(len(completions))]

A reward function to extract the final answer from model completions and check if it is an integer.

In [12]:

import re
from typing import List

_ANS_TAG_RE = re.compile(r"<answer>\s*(.*?)\s*</answer>", re.IGNORECASE | re.DOTALL)

def int_reward_func(completions: List[list], **kwargs) -> List[float]:
    """
    Reward = +1.0  if the model's <answer>…</answer> tag contains
                     a legal integer (optionally signed).
             -1.0  otherwise.

    Notes
    -----
    • `completions` is a list of *chat turns*; each element is itself a
      list of messages.  We follow the same convention used in the earlier
      `length_reward_func`, extracting the first (and usually only) message
      via completions[i][0]["content"].

    • This function deliberately ignores the ground‑truth `answer`;
      it checks only *format* + *integer‑parsability* and leaves correctness
      to a separate reward.
    """
    rewards = []
    for comp in completions:
        text = comp[0]["content"] if comp and isinstance(comp[0], dict) else ""
        m = _ANS_TAG_RE.search(text)
        if m:
            candidate = m.group(1).strip()
            # integer test: optional sign followed by digits
            rewards.append( 1.0 if re.fullmatch(r"[+-]?\d+", candidate) else -1.0 )
        else:
            # no <answer> tag at all = strong penalty
            rewards.append(-1.0)
    return rewards



A reward function to ensure the completion follows this pattern: `^<thinking>\n.*?\n</thinking>\n<answer>\n.*?\n</answer>$`





In [13]:
import re
from typing import List

# Strict pattern: exactly two blocks in order, each wrapped in its XML tag,
# with *single* newline after each tag open/close (per your example pattern).
_STRICT_RE = re.compile(
    r"^<thinking>\n.*?\n</thinking>\n<answer>\n.*?\n</answer>$",
    re.DOTALL
)

def strict_format_reward_func(completions: List[list], **kwargs) -> List[float]:
    """
    +1.0  if the completion matches the exact multi‑line template
          <thinking>\n…\n</thinking>\n<answer>\n…\n</answer>
    -1.0  otherwise.
    """
    rewards = []
    for c in completions:
        txt = c[0]["content"] if c and isinstance(c[0], dict) else ""
        rewards.append(1.0 if _STRICT_RE.fullmatch(txt) else -1.0)
    return rewards


A reward function that penalizes the model if the completion has any tokens after the answer tag.

In [14]:
_END_TAG_RE = re.compile(r"</answer>\s*(.*)$", re.DOTALL)

def xmlcount_reward_func(completions: List[list], **kwargs) -> List[float]:
    """
    0.0   if *no* text remains after the closing </answer> tag.
    -0.5  if any non‑whitespace characters appear afterwards.
    """
    rewards = []
    for c in completions:
        txt = c[0]["content"] if c and isinstance(c[0], dict) else ""
        match = _END_TAG_RE.search(txt)
        if not match:                      # no </answer> tag at all → penalise
            rewards.append(-0.5)
        else:
            tail = match.group(1)
            rewards.append(0.0 if tail.strip() == "" else -0.5)
    return rewards


A reward function to extract the final answer from model completions and check if it is correct.

In [15]:
_ANS_RE   = re.compile(r"<answer>\s*([+-]?\d+)\s*</answer>", re.I | re.S)
FINAL_RE  = re.compile(r"####\s*([+-]?\d+)", re.S)

def extract_gold(ans_str):
    m = FINAL_RE.search(ans_str)
    return m.group(1).lstrip("+").lstrip("0") if m else None

def answer_correctness_func(
    prompts: List[list],
    completions: List[list],
    answer: List[str],
    **kwargs
) -> List[float]:
    """
    +1.0 if the integer inside <answer>...</answer> matches the integer
         after '####' in the GSM8K gold string.
     0.0 otherwise.
    """
    rewards = []
    for comp, gold_str in zip(completions, answer):
        # model prediction
        txt       = comp[0]["content"] if comp and isinstance(comp[0], dict) else ""
        pred_match = _ANS_RE.search(txt)
        pred_num   = pred_match.group(1).lstrip("+").lstrip("0") if pred_match else None

        # ground truth
        gold_num = extract_gold(gold_str)

        rewards.append(1.0 if pred_num is not None and pred_num == gold_num else 0.0)

    return rewards


### GRPO config



In [16]:
from trl import GRPOConfig

grpo_cfg = GRPOConfig(
    # ---------- folders / logging ----------
    output_dir                  = "ckpts/gsm8k_grpo_lora",
    logging_steps               = 10,
    evaluation_strategy         = "steps",
    eval_steps                  = 250,
    save_strategy               = "epoch",
    bf16                        = True,          # if we are using A100 / H100

    # ---------- optimiser ----------
    learning_rate               = 2e-4,
    weight_decay                = 0.01,
    lr_scheduler_type           = "cosine",
    warmup_steps                = 500,

    # ---------- batching ----------
    per_device_train_batch_size = 4,
    gradient_accumulation_steps = 8,              # ⇒ effective 32

    # ---------- GRPO specifics ----------
    num_generations             = 4,              # group size G
    beta                        = 0.02,           # KL coef
    temperature                 = 0.7,
    top_p                       = 0.9,

    # ---------- NEW length knobs ----------
    max_prompt_length           = 512,            # fits every GSM8K question + system msg
    max_completion_length       = 256,            # ample for chain‑of‑thought + answer
    label_names=[],          # suppress Trainer’s guess‑and‑warn

)
print("GRPOConfig built OK ✅")




GRPOConfig built OK ✅


# 🧠 Understanding GRPO Trainer Workflow

While setting up a GRPO config is straightforward, understanding *what each hyperparameter means* and *how the underlying training logic works* is essential. This guide walks through the GRPO training loop, highlighting the key code paths in `trl/trainer/grpo_trainer.py`.

---

## 📂 GRPO Trainer in Context

GRPO is built on top of the TRL `Trainer` class, but it customizes key parts of the training pipeline by overriding several critical methods:

- `compute_loss`
- `prepare_input`

> These functions are the heart of GRPO's training logic. Let’s look at what happens in each of them.

---

## 🧪 GRPO Workflow in `train_step`

In each `train_step`:

1. `Trainer.train_step()` in the parent class first calls `prepare_input()` in the GRPO trainer.
2. Then it calls `compute_loss()` with the output of `prepare_input()`.

---

## 🔄 `prepare_input()` Logic

This function handles **prompt generation**, **inference**, and **buffering** of completions. Here's a breakdown:

### ✅ High-Level Behavior:

- Called **once per `self.num_iterations`** to generate fresh completions.
- **Buffering** is used to reuse completions across multiple gradient updates.

> For example:
> - `num_iterations = 4`: generate completions every 4 global steps. Each global step can consists of k grad accumulation steps.
> - `gradient_accumulation_steps = 5`: buffer stores 5 sets of completions.

If `num_iterations = 1`, completions are not reused, and buffering may be unnecessary.

### 🧬 Core Logic (Simplified):

```python
@profiling_decorator
def _prepare_inputs(self, inputs: dict) -> dict:
    mode = "eval" if self.control.should_evaluate else "train"
    if mode == "train":
        if self.state.global_step % self.num_iterations == 0:
            inputs = self._generate_and_score_completions(inputs)
            self._buffered_inputs[self._step % self.args.gradient_accumulation_steps] = inputs
        else:
            inputs = self._buffered_inputs[self._step % self.args.gradient_accumulation_steps]
        self._step += 1
    else:
        inputs = self._generate_and_score_completions(inputs)
    return inputs
```

---

## ✍️ Prompt Construction

- Conversations are stored in the `"prompt"` field.
- The `apply_chat_template()` function is used to convert this into a generation-ready string.
  
```python
prompt = tokenizer.apply_chat_template(
    example["prompt"], tools=tools, tokenize=False, add_generation_prompt=True
)
```

---

## 🚀 Inference (Completions Generation)
You can check the code for vllm path which is more efficient. I didn't get a chance to look into it.

for non-vllm path:

- each rank tokenizes its prompts and truncates them by max_prompt_length
    
- each rank will send its text prompts to the on-device llm

    ```python
    prompt_completion_ids[:,-5:] => the last 5 tokens for 6 generations made on GPU0. With batched inference, we append EOS (128009) for early completed answers.
        tensor([[128009, 128009, 128009, 128009, 128009],
                [128009, 128009, 128009, 128009, 128009],
                [128009, 128009, 128009, 128009, 128009],
                [   279,   1176,    220,    605,  14741],
                [   220,    430,    568,  11021,    220],
                [128009, 128009, 128009, 128009, 128009]], device='cuda:0')
    ```


---

## 🧮 Completion Masking

Completion masks identify **valid** (non-EOS) completions.

Each rank compute completion masks based on the first occurrence of EOS. The following is the mask for completions above.

```python
completion_mask[:,-5:]
    tensor([[0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [1, 1, 1, 1, 1],
            [1, 1, 1, 1, 1],
            [0, 0, 0, 0, 0]], device='cuda:0', dtype=torch.int32)
```



---

## 🧠 Reward Computation and Normalization

Each rank:
- Decodes completions
- Computes reward for (prompt, completion)
- Gathers rewards from other ranks, because it's possible for a given prompt to have its replica across GPUs. See the sampler.
- Normalizes rewards by mean/std => This gives us advantages $A(s, a)$
- Discards completions for prompts it doesn’t own (called *alien prompts*)

Let’s walk through a concrete example to understand prompt replication and reward normalization across GPUs.

🧮 Setup:

- 3 GPUs (Ranks)
- Batch size per GPU: 8
- Number of generations per prompt: 6

📦 Total Prompts Processed per Iteration:
- Total batch size = 3 GPUs × 8 prompts = 24 prompts
- Since each prompt needs 6 completions, we can only afford to have:
 - 24 / 6 = 4 unique prompts

Each prompt is replicated 6 times across all ranks.

🔁 Prompt Distribution Example:
On GPU 0 (Rank 0):
- 8 total prompts:
 - 6 replicas of Prompt #1
 - 2 replicas of Prompt #2

> Note: Other ranks may hold the remaining replicas of Prompt #2 and additional prompts.

🎯 Reward Normalization:
- For Prompt #1:
    - All 6 completions are on GPU 0.
    - ✅ Rank 0 can compute mean and std of rewards locally.


- For Prompt #2:
    - Rank 0 has only 2 out of 6 completions.
    - ❌ It cannot compute accurate reward statistics alone.
    - ✅ Needs to gather the remaining 4 rewards from other ranks.


> That's why all gather is needed so each rank has access to the required replicas
---

## 🧾 Logprob Computation


- Rollouts (prompts + completions) are generated during the first iteration and stored per gradient accumulation step (indexed by `step % G`).

- These buffered inputs are reused for the remaining iterations within that `num_iterations` window.
    - Buffering occurs only during the first step of each `num_iterations` window (i.e., when `global_step % num_iterations == 0`).

- Logprobs for the old policy and the ref policy are computed during the first iteration only (i.e., when completions are freshly generated).

- Logprobs for the current policy are computed in every gradient accumulation step, including when using buffered inputs.


> Note: old and current policy logprobs will be identical when `num_iterations = 1`.
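
The buffering rule described above can be mimicked with a small toy class. This is only a sketch of the control flow: `RolloutBuffer`, `prepare`, and the `rollout_id` field are hypothetical names, and the real trainer stores tensors (including the cached old/ref logprobs) rather than strings.

```python
class RolloutBuffer:
    """Toy stand-in for rollout reuse across a num_iterations window."""

    def __init__(self, num_iterations, grad_accum_steps):
        self.num_iterations = num_iterations
        self.grad_accum_steps = grad_accum_steps
        self.buffer = [None] * grad_accum_steps
        self.generations = 0  # how many times we paid for generation

    def prepare(self, micro_step, global_step, prompts):
        slot = micro_step % self.grad_accum_steps
        if global_step % self.num_iterations == 0:
            # Fresh rollout: generate completions (and cache old/ref logprobs with them).
            self.generations += 1
            self.buffer[slot] = {"prompts": prompts, "rollout_id": self.generations}
        # Otherwise the incoming prompts are ignored and the buffered rollout is reused.
        return self.buffer[slot]


buf = RolloutBuffer(num_iterations=4, grad_accum_steps=2)
micro_step = 0
seen = []
for global_step in range(8):  # 8 optimizer steps
    for _ in range(buf.grad_accum_steps):
        batch = buf.prepare(micro_step, global_step, prompts=f"batch-{micro_step}")
        seen.append(batch["rollout_id"])
        micro_step += 1

print(buf.generations)  # 4
print(seen)  # [1, 2, 1, 2, 1, 2, 1, 2, 3, 4, 3, 4, 3, 4, 3, 4]
```

In this toy run, 16 micro-steps trigger only 4 generations: two at optimizer step 0 and two at optimizer step 4.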

---

## 🧮 `compute_loss()` Breakdown

$$
\mathcal{L}_{\text{GRPO}} = - \mathbb{E}_{(s, a)} \left[ \frac{\pi(a|s)}{\pi_{\text{old}}(a|s)} A(s, a) \right] + \beta \cdot \text{KL}[\pi(\cdot|s) \| \pi_{\text{ref}}(\cdot|s)]
$$

### Steps:

1. **Concatenate** `prompt_ids + completion_ids`.
2. Run a **forward pass** through the old policy to compute $\pi_{old}(a|s)$
    - This actually happens only once at the first iteration when we create the rollout.
3. Run a **forward pass** through the ref policy to compute  $\pi_{ref}(a|s)$
    - This actually happens only once at the first iteration when we create the rollout.
    - ref model is the original model without LoRA adapters
4. Run a **forward pass** through the current policy to compute $\pi(a|s)$
    - This pass runs on every step, because gradients flow through it. When `num_iterations = 1`, its (detached) logprobs double as the old-policy logprobs.
5. Compute **KL loss**: between $\pi(a|s)$ and $\pi_{ref}(a|s)$.

6. Compute **advantage-weighted logprobs**: $\frac{\pi(a|s) }{ \pi_{old}(a|s)} * A(s, a)$
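
To make the formula concrete, here is a minimal per-token sketch in plain Python. It follows the simplified loss above (no clipping), assumes one advantage per completion, and uses the `exp(d) - d - 1` estimator for the KL term, which is one common choice rather than something the formula itself prescribes. The function name `grpo_loss` and the default `beta` are hypothetical.

```python
import math


def grpo_loss(logp_cur, logp_old, logp_ref, advantages, mask, beta=0.04):
    """Unclipped GRPO-style loss, averaged over unmasked completion tokens."""
    total, count = 0.0, 0
    for i, adv in enumerate(advantages):  # one advantage per completion
        for t, m in enumerate(mask[i]):
            if not m:
                continue  # skip padding after EOS
            ratio = math.exp(logp_cur[i][t] - logp_old[i][t])  # pi / pi_old
            delta = logp_ref[i][t] - logp_cur[i][t]
            kl = math.exp(delta) - delta - 1.0  # assumed estimator of KL(pi || pi_ref)
            total += -(ratio * adv) + beta * kl
            count += 1
    return total / max(count, 1)


# Two completions of three tokens each; the first one ends after two tokens.
logp = [[-1.0, -2.0, -0.5], [-0.3, -0.7, -1.2]]
mask = [[1, 1, 0], [1, 1, 1]]
advantages = [1.0, -1.0]

# With identical current, old, and ref logprobs the ratio is 1 and the KL term is 0.
loss = grpo_loss(logp, logp, logp, advantages, mask)
print(loss)  # 0.2
```

The value 0.2 comes from two tokens contributing -1 each and three tokens contributing +1 each, divided by the five unmasked tokens.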

---

## 📋 GRPO Trainer Workflow Summary

| Component             | What It Does                                                                 | Why It Matters                                                                 |
|----------------------|------------------------------------------------------------------------------|--------------------------------------------------------------------------------|
| GRPO Trainer Class   | Extends the TRL Trainer, overrides `prepare_input` and `compute_loss`       | Customizes the loss and input preparation for rollout reuse and reward learning |
| `train_step`         | Calls `prepare_input` then `compute_loss`                                    | Controls how and when rollouts and gradients are processed                     |
| `prepare_input()`    | Generates and buffers rollouts once every `num_iterations` steps             | Allows reuse of expensive rollouts across multiple updates                     |
| Prompt Construction  | Uses `apply_chat_template` to create generation-ready input                  | Ensures the model sees correctly formatted conversational prompts             |
| Inference            | Uses on-device or vLLM backend to generate completions                       | Provides actions for which logprobs and rewards are computed                   |
| Completion Masking   | Identifies valid (non-EOS) completions                                       | Ensures reward/logprob computation only applies to meaningful tokens          |
| Reward Normalization | Aggregates and normalizes rewards across GPUs for each prompt                | Yields correct advantage estimates across distributed setup                   |
| Logprob Computation  | Computes logprobs of old, ref, and current policy                            | Needed for KL and advantage-weighted loss; reused if `num_iterations > 1`     |
| `compute_loss()`     | Combines KL divergence and advantage-weighted logprob ratio                 | Drives the optimization update direction for the policy                       |



In [17]:
#@title An easier setting for faster training
grpo_cfg.num_generations          = 2
grpo_cfg.max_completion_length    = 128
grpo_cfg.num_iterations           = 4
grpo_cfg.per_device_train_batch_size = 8
grpo_cfg.gradient_accumulation_steps = 4   # keeps effective batch = 32


In [18]:
grpo_cfg.deepspeed = "ds_zero2.json"     # or zero3


In [19]:
len(train_chat), len(test_chat)

(7473, 1319)

In [None]:
#@title Early experimentation setting
small_train = train_chat.shuffle(seed=42).select(range(3000))  # 3k examples
grpo_cfg.num_train_epochs = 1


In [20]:
from trl import GRPOTrainer

trainer = GRPOTrainer(
    model           = model,             # your LoRA‑patched Llama‑3‑8B
    args            = grpo_cfg,          # <-- keyword name is 'args'
    processing_class= tokenizer,         # tokenizer goes here
    reward_funcs    = [
        strict_format_reward_func,
        xmlcount_reward_func,
        int_reward_func,
        answer_correctness_func
    ],
    # train_dataset   = small_train,     # uncomment if only want to do a quick experiment
    train_dataset   = train_chat,

    eval_dataset    = test_chat,
    peft_config     = lora_config,       # only if you re‑instantiate; else drop
)




In [None]:
trainer.train()
trainer.save_model("llama3_8b_gsm8k_grpo_lora")


[34m[1mwandb[0m: Currently logged in as: [33mjhuang-daniel[0m ([33mjhuang-daniel-siclarity[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


| Step | Training Loss | Validation Loss |
|------|---------------|-----------------|
| 250  | 0.0221        | 0.046335        |
| 500  | 0.0203        | 0.031951        |
| 750  | 0.016         | 0.039535        |


In [None]:
# prompt: download the trained model to local machine

from google.colab import files

# Save the model checkpoint locally
!zip -r /content/llama3_8b_gsm8k_grpo_lora.zip /content/llama3_8b_gsm8k_grpo_lora

# Download the zip file
files.download('/content/llama3_8b_gsm8k_grpo_lora.zip')


In [None]:
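# prompt: clear memory

import gc
import torch

# del trainer             # uncomment to drop the trainer and its model references first
gc.collect()              # collect unreachable Python objects
torch.cuda.empty_cache()  # release cached CUDA memory blocks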

<a name="Inference"></a>
### Evaluate the model
Now let's try the model we just trained!

In [30]:
!huggingface-cli login


    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    A token is already saved on your machine. Run `huggingface-cli whoami` to get more information or `huggingface-cli logout` if you want to log out.
    Setting a new token will erase the existing one.
    To log in, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible): 
Add token as git credential? (Y/n) n
Token is valid (permission: write

In [None]:
#  load a LoRA checkpoint
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base_model_id   = "meta-llama/Meta-Llama-3.1-8B-Instruct"
lora_ckpt_path  = "llama3_8b_gsm8k_grpo_lora"

tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)

# IMPORTANT: load the unquantized base (no 4-bit/8-bit) so the LoRA merge is exact
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    torch_dtype="auto",       # fp16/bf16 on GPU → fine for merge
    device_map="auto"
)

model = PeftModel.from_pretrained(base_model, lora_ckpt_path)


In [None]:
merged_model = model.merge_and_unload()       # returns the base model with the LoRA weights folded in


In [None]:
# Run this right after we reload the tokenizer / model to make Llama 3 happy
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
    # If the model and tokenizer sizes now differ, resize embeddings:
    merged_model.resize_token_embeddings(len(tokenizer))
# Decoder-only models should be left-padded for batched generation
tokenizer.padding_side = "left"


In [None]:
#  load test dataset
from datasets import load_dataset

# If you already mapped the test split earlier, you can skip the next two lines
raw_test  = load_dataset("openai/gsm8k", "main", split="test")   # 1,319 rows
test_chat = raw_test.map(gsm8k_to_chat, remove_columns=raw_test.column_names)

print("Loaded", len(test_chat), "examples for evaluation")


In [None]:
import re, torch, textwrap, random
from tqdm import tqdm

answer_re = re.compile(r"<answer>\s*([+-]?\d+)\s*</answer>", re.I | re.S)

def extract_num(text):
    m = answer_re.search(text)
    return m.group(1).lstrip("+").lstrip("0") if m else None

final_re = re.compile(r"####\s*([+-]?\d+)", re.S)

def extract_gold(ans_str):
    m = final_re.search(ans_str)
    return m.group(1).lstrip("+").lstrip("0") if m else None

def evaluate(test_ds, model, tokenizer, batch_size=8):
    model.eval()
    correct = 0
    total   = 0
    samples = []

    for start in tqdm(range(0, len(test_ds), batch_size)):
        # .select() returns a Dataset;  .to_list() turns it into list‑of‑dicts
        batch_rows = test_ds.select(range(start, min(start+batch_size, len(test_ds)))).to_list()

        prompt_texts = [
            tokenizer.apply_chat_template(ex["prompt"], add_generation_prompt=True, tokenize=False)
            for ex in batch_rows
        ]
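        # The chat template already emits the BOS token, so don't add it a second time
        inputs = tokenizer(
            prompt_texts, return_tensors="pt", padding=True, add_special_tokens=False
        ).to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=128,
                do_sample=False,     # greedy decoding, deterministic
            )
        # Decode only the newly generated tokens, not the echoed prompt
        new_tokens = outputs[:, inputs["input_ids"].shape[1]:]
        decoded = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)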

        for ex, tex in zip(batch_rows, decoded):
            pred = extract_num(tex)
            gold = extract_gold(ex["answer"])
            total += 1
            if pred is not None and pred == gold:   # a missing answer never counts as correct
                correct += 1
            if len(samples) < 10:
                samples.append((ex["prompt"][1]["content"], pred, gold, tex[:160]))

    accuracy = correct / total
    return accuracy, samples

acc, samples = evaluate(test_chat, merged_model, tokenizer)
print(f"\nAccuracy on GSM8K test split: {acc:.2%}  ({round(acc*len(test_chat))}/{len(test_chat)})")

print("\n--- sample completions ---")
for q,p,g,snip in samples:
    print(textwrap.shorten(q, 70), "| pred:", p, "| gold:", g)
    print("  ↳", textwrap.shorten(snip, 120), "\n")


In [None]:
output_dir = "merged_llama3_gsm8k"
merged_model.save_pretrained(output_dir, safe_serialization=True)
tokenizer.save_pretrained(output_dir)


In [None]:
from transformers import BitsAndBytesConfig   # requires the bitsandbytes package to be installed

bnb_cfg = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True)
quant_model = AutoModelForCausalLM.from_pretrained(output_dir, quantization_config=bnb_cfg)
quant_model.save_pretrained(output_dir + "_4bit")


In [None]:
repo_name = "lordChipotle/Llama3GRPOReasoning"
merged_model.push_to_hub(repo_name, private=False)    # or quant_model.push_to_hub(...)
tokenizer.push_to_hub(repo_name)


# Next?
Perhaps extending this notebook for something similar to [medical reasoning](https://github.com/matthewchung74/qwen_2_5_3B_GRPO_medical_thinking/blob/main/Qwen2_5_(3B)_GRPO.ipynb)?

