# nanoAhaMoment: Single File "RL for LLM" Library
Single GPU · No TRL or Verl · Efficient · 3B Base Model · Full Parameter Tuning Implementation of R1-zero training.

Inspired by [TinyZero](https://github.com/Jiayi-Pan/TinyZero) and [Mini-R1](https://www.philschmid.de/mini-deepseek-r1), but designed to be **simpler**, **cleaner**, and **faster**, with every line of code visible and understandable.

R1-Zero is arguably the more interesting contribution from the DeepSeek R1 paper. The core idea: take a freshly pre-trained LLM (straight out of the unsupervised pretraining oven) and continue its training using reinforcement learning *without* any human feedback or supervision. The result? A model that starts showing emergent behaviors like self-reflection, verification, and backtracking, which researchers have tried to bake into LLMs using handcrafted tricks and inductive biases, at least since O1.

In this notebook, we’ll build an R1-Zero-style training loop **from scratch**. The goal is to create a crystal-clear, hackable foundation for RL-style LLM training; one that gives you a bird’s-eye view of every moving part and how they fit together. Perfect for playing around, extending, or hacking.

---

### Why another R1-Zero implementation?

There are already great implementations like [TinyZero](https://github.com/Jiayi-Pan/TinyZero) and [Mini-R1](https://www.philschmid.de/mini-deepseek-r1). But they rely on full-fledged RL libraries (like `trl` or `verl`) to handle training.

These libraries exist for good reason; efficient RL training for LLMs sits at the crossroads of scalable training and fast inference. Making that work takes a lot of engineering. But that also means the internals are often abstracted away, hard to read, and even harder to tweak.

This notebook is different: **no abstractions, no hiding**. You’ll see everything, top to bottom. A lightweight, readable codebase that still follows best practices and runs efficiently on a single GPU.

### What is this notebook, exactly?

We'll train a base LLM using RL to solve a reasoning-heavy algorithmic task. The setup:

- **Model**: Qwen2.5 3B-Base  
- **Dataset**: Countdown-Tasks-3to4  
- **Algorithm**: GRPO (a variant of policy gradient)

Yes, the task is a bit toy-ish—but it captures the essence of R1-Zero: emergent behaviors like self-reflection, verification, backtracking, even language-switching. This setup is ideal for rapid prototyping and experimentation.

### Who is this notebook for?

- Anyone interested in RL training for LLMs  
- Researchers, especially the ones in academia, exploring reasoning in language models

### What should I know before jumping in?

- A working knowledge of the HuggingFace Transformers library  
- Some experience fine-tuning LLMs  
- Familiarity with policy gradient methods (helpful but not required)

## R1-Zero Recipe

The goal is to train a base LLM to **reason** in a way that allows it to **reevaluate** its own outputs and **improve** them, all without human supervision. The DeepSeek R1 paper proposes a surprisingly simple recipe to achieve this, and that's exactly what we'll implement in this notebook.

### The Recipe

Here's the high-level procedure:

1. **Start** with a base LLM and a dataset containing problem prompts paired only with their *final answers* (no intermediate reasoning steps).  
2. For each iteration $i = 0$ to `NUM_ITERATIONS`:
   - Sample a batch of prompts $\{x_i\}_{i=1}^N$ from the dataset.
   - For each prompt, sample $G$ responses from the model:  
     $ y_1, y_2, \cdots, y_G \sim \pi_\theta(y|x) $

     These $G$ responses form what is called a *group* in GRPO.
   - Compute a reward $R_i$ for each response and normalize the rewards within each group to calculate the GRPO advantages.
   - Create a list of $N \times G$ episodes, i.e., pairs of $(x_i, y_i)$ along with their corresponding advantages.
   - Estimate the policy gradient $\vec{g}_{pg}$ from these episodes.
   - Update the model parameters:  
     $\theta \leftarrow \theta + \eta \vec{g}_{pg}$
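
To make the recipe concrete before touching an LLM, here is a minimal sketch on a toy problem. It is not the notebook's implementation: the "model" is just a softmax over four possible responses (a hypothetical stand-in for $\pi_\theta$), exactly one of which earns a reward of 1, and all constants are illustrative assumptions. The loop still follows the same steps: sample a group, score it, normalize rewards into advantages, estimate the policy gradient, and update.

```python
import math
import random
from statistics import fmean, pstdev

# All values below are illustrative assumptions for a toy problem.
NUM_ACTIONS = 4      # toy "responses" the policy can choose from
CORRECT_ACTION = 2   # the only response that earns a reward of 1
GROUP_SIZE = 8       # G: responses sampled per prompt
LEARNING_RATE = 0.5  # eta
NUM_STEPS = 200


def softmax(logits):
    peak = max(logits)
    exps = [math.exp(l - peak) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


def grpo_step(logits, rng):
    probs = softmax(logits)
    # Sample a group of G responses from the current policy.
    group = rng.choices(range(NUM_ACTIONS), weights=probs, k=GROUP_SIZE)
    rewards = [1.0 if y == CORRECT_ACTION else 0.0 for y in group]
    # Group-normalized advantages (population std, small epsilon for stability).
    mean, std = fmean(rewards), pstdev(rewards)
    advantages = [(r - mean) / (std + 1e-4) for r in rewards]
    # Policy gradient estimate: average of advantage * grad log pi(y).
    # For a softmax policy, d log pi(y) / d logit_k = 1[k == y] - probs[k].
    grad = [0.0] * NUM_ACTIONS
    for y, adv in zip(group, advantages):
        for k in range(NUM_ACTIONS):
            grad_log_prob = (1.0 if k == y else 0.0) - probs[k]
            grad[k] += adv * grad_log_prob / GROUP_SIZE
    # Gradient ascent on the expected reward.
    return [l + LEARNING_RATE * g for l, g in zip(logits, grad)]


rng = random.Random(0)
logits = [0.0] * NUM_ACTIONS
for _ in range(NUM_STEPS):
    logits = grpo_step(logits, rng)

final_probs = softmax(logits)
print(f"P(correct response) after training: {final_probs[CORRECT_ACTION]:.3f}")
```

Groups in which every response gets the same reward produce zero advantages, so the toy policy only moves when a group contains both correct and incorrect responses.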

### Code Structure Overview

The code you will see is structured directly following this recipe. It boils down to three main components:

1. **Episode Generation**  
   - Generate $ (x, y) $ pairs along with their advantages for each RL iteration.
   
2. **Reward Calculation**  
   - Compute rewards for each generated response.
   
3. **Policy Gradient Estimation**  
   - Use the generated episodes to estimate the policy gradient and perform the model update.

In the end, these three components come together in a simple loop that trains the model, step by step, to develop reasoning capabilities through reinforcement learning.


## Checkpoint Playground

In `notebooks/checkpoint_playground.ipynb`, you can load the model we already trained with this notebook and interactively test its reasoning capabilities. That notebook lets you input custom prompts and observe the model's responses.

## Prerequisites

### Installing Dependencies

Before we begin, let's install the necessary Python packages. We'll be using:

- PyTorch  
- Hugging Face Transformers  
- Hugging Face Datasets  
- DeepSpeed  
- vLLM

For a detailed, step-by-step installation guide, refer to the [README](https://github.com/McGill-NLP/tiny-aha-moment.git) of this project.

In [1]:
import os
from pathlib import Path

# Set the environment variables for HuggingFace
# This is done to ensure that the cache directory for HuggingFace is set to a specific location,
# preventing the storage from being overwhelmed with model files and other data.
SCRATCH = Path.home() / "scratch"
os.environ["HF_HOME"] = str(SCRATCH / "hf_home")

### Import the required libraries

In [2]:
import gc
import re
import time
from typing import Any, Dict, List, Tuple, Union

import deepspeed
import numpy as np
import torch
from datasets import load_dataset
from deepspeed import DeepSpeedEngine
from tqdm import trange
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
from vllm import LLM, SamplingParams

import wandb
from utils import (
    compute_token_log_probs,
    dump_episodes,
    evaluate_on_test_set,
    find_free_port,
    find_last_checkpoint,
    prepare_model_inputs,
    load_model_into_vllm
)

# Needed to stop DeepSpeed from complaining
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(find_free_port())
os.environ["RANK"] = "0"
os.environ["LOCAL_RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"


**We do have a few helper functions in `utils.py` that are used to keep the code clean.**

## Hyperparameters

Let's define the hyperparameters for the training. These are mostly taken from the [Mini-R1](https://www.philschmid.de/mini-deepseek-r1) implementation.

In [3]:
# Model configuration
MODEL_NAME = "Qwen/Qwen2.5-3B"
MODEL_CHAT_NAME = MODEL_NAME + "-Instruct"

# Dataset configuration
DATASET_NAME = "Jiayi-Pan/Countdown-Tasks-3to4"

# Total number of training iterations
NUM_ITERATIONS = 1000
# Number of episodes to collect per iteration for training
EPISODES_PER_ITERATION = 64
# Number of responses to generate for each input prompt (i.e. group size in GRPO)
GENERATIONS_PER_SAMPLE = 4
# Controls how much the policy can deviate from the reference model
KL_COEFFICIENT = 0.001

# Training hyperparameters
# Batch size for each GPU device during training
PER_DEVICE_BATCH_SIZE = 4
# Learning rate for model updates
LEARNING_RATE = 1e-6

# Sampling parameters
# Maximum number of tokens to generate in each response
MAX_RESPONSE_TOKENS = 1024
# Controls randomness in generation (higher = more random)
TEMPERATURE = 1.0
# Nucleus sampling parameter (1.0 = disabled)
TOP_P = 1.0
# Top-k sampling parameter (-1 = disabled)
TOP_K = -1  # no top k

# DeepSpeed configuration
# DeepSpeed config for the policy model
deepspeed_config = {
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 2, "overlap_comm": False},
    "train_batch_size": EPISODES_PER_ITERATION,
    "train_micro_batch_size_per_gpu": PER_DEVICE_BATCH_SIZE,
    "gradient_accumulation_steps": EPISODES_PER_ITERATION // PER_DEVICE_BATCH_SIZE,
    "gradient_clipping": 1.0,
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": LEARNING_RATE,
            "betas": (0.9, 0.999),
            "eps": 1e-8,
            "weight_decay": 0.0,
            "torch_adam": True,
        },
    },
}
# DeepSpeed config for the reference model
ref_deepspeed_config = {
    "bf16": {"enabled": True},
    # Note that we don't train the reference model
    # These are just for compatibility with DeepSpeed.
    "train_batch_size": EPISODES_PER_ITERATION,
    "train_micro_batch_size_per_gpu": PER_DEVICE_BATCH_SIZE,
    "gradient_accumulation_steps": EPISODES_PER_ITERATION // PER_DEVICE_BATCH_SIZE,
}

RUN_NAME = "r1-zero"
EXP_DIR = SCRATCH / "deepseek_r1z_hackathon" / RUN_NAME
EXP_DIR.mkdir(parents=True, exist_ok=True)
print(f"Logs and Checkpoints will be saved to: {EXP_DIR}")

Logs and Checkpoints will be saved to: /run/determined/workdir/scratch/deepseek_r1z_hackathon/r1-zero
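
The batch-related numbers in the configuration above are tied together: DeepSpeed expects `train_batch_size` to equal the micro batch size times the gradient accumulation steps times the number of GPUs. The sketch below re-derives these quantities from the same constants. `WORLD_SIZE` and the derived variable names are introduced here for illustration only, and the prompts-per-iteration figure assumes each iteration's episodes come from `EPISODES_PER_ITERATION // GENERATIONS_PER_SAMPLE` distinct prompts.

```python
# Values copied from the hyperparameter cell above.
EPISODES_PER_ITERATION = 64
GENERATIONS_PER_SAMPLE = 4
PER_DEVICE_BATCH_SIZE = 4
WORLD_SIZE = 1  # single GPU (illustrative name, not defined in the notebook)

# Each prompt yields GENERATIONS_PER_SAMPLE episodes, so fewer prompts are needed.
prompts_per_iteration = EPISODES_PER_ITERATION // GENERATIONS_PER_SAMPLE

# Number of micro-batches accumulated before one optimizer step.
gradient_accumulation_steps = EPISODES_PER_ITERATION // PER_DEVICE_BATCH_SIZE

# The relation DeepSpeed checks for train_batch_size.
effective_batch_size = PER_DEVICE_BATCH_SIZE * gradient_accumulation_steps * WORLD_SIZE

print(f"prompts per iteration:        {prompts_per_iteration}")
print(f"gradient accumulation steps:  {gradient_accumulation_steps}")
print(f"effective train batch size:   {effective_batch_size}")
```

With these values, all episodes of an iteration are consumed by a single optimizer step.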


## Generating the training prompts

For training, we'll use the [Countdown-Tasks-3to4](https://huggingface.co/datasets/Jiayi-Pan/Countdown-Tasks-3to4) dataset, which provides problem statements paired with their final answers (but no reasoning steps).

### The Countdown Task

The Countdown game is a numerical puzzle where the player must reach a target number using a set of randomly chosen numbers and basic arithmetic operations: addition, subtraction, multiplication, and division. Each number must be used exactly once.

Example:

```yaml
Target: 622
Available Numbers: [25, 3, 6, 100]

# Not provided in the dataset
Solution: (100 × 6) + (25 − 3) = 622
```

This task is ideal for training LLMs to practice reasoning, searching, and self-verification.
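
To get a feel for the search involved, here is a small brute-force solver. It is only a sketch for illustration and plays no role in training; `solve_countdown` and `_search` are hypothetical helpers, not part of the notebook or its `utils.py`. It repeatedly combines two of the remaining values with one of the four operations until a single value is left, using exact fractions to avoid rounding issues.

```python
import re
from fractions import Fraction
from itertools import combinations


def _search(items, target):
    """items: list of (value, expression) pairs still available."""
    if len(items) == 1:
        value, expr = items[0]
        return expr if value == target else None
    for i, j in combinations(range(len(items)), 2):
        (a, expr_a), (b, expr_b) = items[i], items[j]
        rest = [items[k] for k in range(len(items)) if k not in (i, j)]
        candidates = [
            (a + b, f"({expr_a} + {expr_b})"),
            (a - b, f"({expr_a} - {expr_b})"),
            (b - a, f"({expr_b} - {expr_a})"),
            (a * b, f"({expr_a} * {expr_b})"),
        ]
        if b != 0:
            candidates.append((a / b, f"({expr_a} / {expr_b})"))
        if a != 0:
            candidates.append((b / a, f"({expr_b} / {expr_a})"))
        for value, expr in candidates:
            found = _search(rest + [(value, expr)], target)
            if found is not None:
                return found
    return None


def solve_countdown(numbers, target):
    """Return an expression using every number exactly once, or None."""
    items = [(Fraction(n), str(n)) for n in numbers]
    return _search(items, Fraction(target))


numbers, target = [25, 3, 6, 100], 622
expression = solve_countdown(numbers, target)
used_numbers = sorted(int(n) for n in re.findall(r"\d+", expression))
print(expression, "=", eval(expression))
```

The solver may return a different expression than the one shown in the example, since several solutions can exist for the same target.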


Since we are using the base version of the model, which has only been pretrained on raw internet data, it has no prior understanding of system prompts or chat formatting. However, we will still use the chat format to make the resulting model compatible with downstream tools and frameworks that expect it.

In [4]:
SYSTEM_MESSAGE = (
    "You are a helpful assistant. You first think about the reasoning process in the mind "
    "and then provide the user with the answer."
)
PROMPT_TEMPLATE = (
    "Using the numbers {numbers}, create an equation that equals {target}. "
    "You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. "
    "Show your work in <think> </think> tags. And return the final equation and answer in "
    "<answer> </answer> tags, for example <answer>(1 + 2) / (3 * 5)</answer>."
)

Now that we have the system message and prompt template, we can generate the training prompts.

In [5]:
# Load and process dataset
def preprocess_example(example: Dict[str, Any]):
    numbers: List[int] = example["nums"]
    target: int = example["target"]

    prefix = [
        {"role": "system", "content": SYSTEM_MESSAGE},
        {"role": "user", "content": PROMPT_TEMPLATE.format(numbers=numbers, target=target)},
        {"role": "assistant", "content": "Let me solve this step by step.\n<think>"},
    ]
    input_ids = tokenizer.apply_chat_template(
        prefix, tokenize=True, continue_final_message=True
    )
    prompt = tokenizer.decode(
        input_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
    )
    return {"prompt": prompt, "input_ids": input_ids}

# Note that the base model and the "instruct" model use different EOS tokens.
# Here we make sure to use the correct one.
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHAT_NAME)
EOS_TOKEN_ID = AutoTokenizer.from_pretrained(MODEL_NAME).eos_token_id
EOS_TOKEN = tokenizer.convert_ids_to_tokens(EOS_TOKEN_ID)

dataset = load_dataset(DATASET_NAME, split="train")
dataset = dataset.map(preprocess_example, num_proc=6)

# Split dataset
train_test_split = dataset.train_test_split(test_size=500, seed=42)
train_dataset = train_test_split["train"]
test_dataset = train_test_split["test"]

len(train_dataset), len(test_dataset)


Let's look at some examples from the dataset.

In [None]:
print("Target: ", train_dataset[0]["target"])
print("Available Numbers: ", train_dataset[0]["nums"])

Using the system message and prompt template, we generate the following prompt for this example:

In [None]:
print(train_dataset[0]["prompt"])

As you may have noticed, each prompt also ends with the start of the assistant turn, prefilled with the phrase *"Let me solve this step by step."* and an opening `<think>` tag. This helps guide the model into **answering mode**. Without this, the base model might simply continue the prompt rather than attempting to solve the task, since it has no inherent understanding of instruction-following.

Additionally, we tokenize each prompt and store the result as `input_ids`, which will be used later during training.

In [None]:
print(train_dataset[0]["input_ids"])

## Reward Function


The DeepSeek R1 paper introduced **rule-based rewards** to evaluate whether the model-generated solutions were correct. We'll adopt a similar approach by defining two custom reward functions:

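- **Format Reward**: Checks if the output follows the required format:  
  `<think> [thinking] </think>\n<answer> [answer] </answer>`  
  Full credit (1.0) also requires the answer to contain only digits, arithmetic operators, parentheses, dots, and whitespace; a correctly tagged output that violates this earns 0.5.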

- **Equation Reward**: Extracts the equation from within the `<answer>` tag, verifies that it evaluates to the target result, and ensures that all available numbers are used exactly once.

The purpose of enforcing the format is mainly to make answer extraction easier. It isn't strictly necessary for the correctness of the answer itself but simplifies parsing during training.

The final reward assigned to an episode/trajectory (prompt+response) is simply the sum of these two components. Importantly, the reward is only computed at the **last token** of the output. From an RL perspective, this means that all intermediate actions receive zero reward. We also do not apply any discounting here (i.e., $\gamma = 1$).
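
A small sketch of what "reward only at the last token, no discounting" implies: the return seen from every token position equals the final reward. The helper `returns_to_go` and the example values are hypothetical and only meant to illustrate the statement above.

```python
def returns_to_go(token_rewards, gamma=1.0):
    """Discounted sum of future rewards at each token position."""
    returns = []
    running = 0.0
    for reward in reversed(token_rewards):
        running = reward + gamma * running
        returns.append(running)
    return returns[::-1]


# Illustrative episode: 5 response tokens, both reward components earned.
format_reward, equation_reward = 1.0, 1.0
num_response_tokens = 5

# Zero reward for every intermediate token, the full reward at the last one.
token_rewards = [0.0] * (num_response_tokens - 1) + [format_reward + equation_reward]

undiscounted = returns_to_go(token_rewards, gamma=1.0)
discounted = returns_to_go(token_rewards, gamma=0.9)  # for comparison only

print("gamma = 1.0:", undiscounted)
print("gamma = 0.9:", [round(r, 4) for r in discounted])
```

With $\gamma = 1$ every token sees the same return, whereas any $\gamma < 1$ would credit early tokens less than late ones.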

In [8]:
def format_reward_func(completion: str) -> float:
    """
    Format: <think>...</think>\n<answer>...</answer>

    Also checks that the content within <answer>...</answer> conforms to a
    specified pattern (only digits, + - * / ( ) . and whitespace).

    Args:
        completion (str): Generated output

    Returns:
        float: Reward score
    """
    # Define the allowed pattern (only numbers, +, -, *, /, (, ), ., and whitespace)
    allowed_pattern = r"^[\d+\-*/().\s]+$"

    try:
        # Add a synthetic <think> since it is already part of the prompt (prefilled
        # for the assistant); this makes the regex easier to match.
        completion = "<think>" + completion

        # Strip EOS token if present
        if completion.endswith(EOS_TOKEN):
            completion = completion[:-len(EOS_TOKEN)]

        # Check if the format is correct
        # Pattern means:
        # 1) <think>...contents not including other <think> tags...</think>
        # 2) \n
        # 3) <answer>...anything...</answer>
        regex = r"^<think>([^<]*(?:<(?!/?think>)[^<]*)*)<\/think>\n<answer>([\s\S]*?)<\/answer>$"
        match = re.search(regex, completion, re.DOTALL)

        if match is None or len(match.groups()) != 2:
            # Format is incorrect
            return 0.0
        else:
            # Extract the content inside <answer>...</answer>
            answer_content = match.group(2).strip()

            # Check if answer content matches the allowed pattern
            if not re.match(allowed_pattern, answer_content):
                # If it doesn't match, reward is 0.5
                return 0.5
            else:
                # If both format and pattern are correct, reward is 1
                return 1.0
    except Exception:
        # Any error leads to 0 reward
        return 0.0


def equation_reward_func(completion: str, nums: List[int], target: int) -> float:
    """
    Evaluates completion based on mathematical correctness of the answer

    Args:
        completion (str): Generated output
        nums (List[int]): Available numbers to use in the equation
        target (int): Expected result of the equation

    Returns:
        float: Reward score
    """
    try:
        # Check if the format is correct
        match = re.search(r"<answer>(.*?)<\/answer>", completion)
        if match is None:
            return 0.0
        # Extract the "answer" part from the completion
        equation = match.group(1).strip()
        # Extract all numbers from the equation
        used_numbers = [int(n) for n in re.findall(r"\d+", equation)]

        # Check if all numbers are used exactly once
        if sorted(used_numbers) != sorted(nums):
            return 0.0
        # Define a regex pattern that only allows numbers, operators, parentheses, and whitespace
        allowed_pattern = r"^[\d+\-*/().\s]+$"
        if not re.match(allowed_pattern, equation):
            return 0.0

        # Evaluate the equation with restricted globals and locals
        result = eval(equation, {"__builtins__": None}, {})
        # Check if the equation is correct and matches the ground truth
        if abs(float(result) - float(target)) < 1e-5:
            return 1.0
        else:
            return 0.0
    except Exception:
        # If evaluation fails, reward is 0
        return 0.0
    

def compute_reward(completion: str, sample: Dict[str, Any]) -> Tuple[float, Dict[str, float]]:
    nums = sample["nums"]
    target = sample["target"]

    format_reward = format_reward_func(completion)
    equation_reward = equation_reward_func(
        completion=completion, nums=nums, target=target
    )

    reward = format_reward + equation_reward

    metrics = {
        "format_reward": format_reward,
        "equation_reward": equation_reward,
    }   

    return reward, metrics

In [None]:
# <think> is prefilled in the prompt. So, repeating it in the completion would be incorrect.
format_reward_func("<think>I think the answer is </think>\n<answer>1+2</answer>")

In [None]:
format_reward_func("I think the answer is </think>\n<answer>1+2</answer>")

In [None]:
format_reward_func("<think>I think the<think>and even more</think> answer is </think>\n<answer>1+2</answer>")

In [None]:
equation_reward_func("I think the answer is </think>\n<answer>1+2+2</answer>", [1,2], 3)

## Episode Generation

The goal of episode generation is to create a collection of query-response pairs that will be used for policy training. From the reinforcement learning (RL) perspective, the **query** serves as the initial state, and the generated tokens in the **response** represent the actions taken by the policy.

The `create_training_episodes` function takes a list of prompts (initial states) and their corresponding completions, which we generate using the model. In GRPO, we always generate multiple responses per prompt, i.e., `GENERATIONS_PER_SAMPLE` > 1. This means that, after episode generation, we end up with `batch_size × GENERATIONS_PER_SAMPLE` episodes in every RL iteration.

### Advantage Computation

In addition to generating episodes, `create_training_episodes` is also responsible for computing the **advantage** for every response token. 

In RL terms, the advantage of a token represents how much better or worse that token's action is compared to the average generated token at that specific state (prompt + prefix). Ideally, we would compute an advantage for every token individually to capture how each step contributes to the overall reward.

However, in GRPO, there's no per-token advantage computation. Instead, we compute a single advantage value per response. This value reflects how good the entire response is relative to other responses generated for the same prompt. We then assign this single advantage value uniformly to all tokens within that response.

GRPO uses a simple formula for this:

1. For each prompt $x$ with a group of generated responses $y_1, y_2, \ldots, y_G \sim \pi(\cdot|x)$, compute their rewards $R_1, R_2, \ldots, R_G$.
2. Compute the group's mean and standard deviation:  
   $ \mu = \text{mean}(R_1, R_2, \ldots, R_G) $  
   $ \sigma = \text{std}(R_1, R_2, \ldots, R_G) $
3. Compute a **relative score** for each response:  
   $ R^*_i = \frac{R_i - \mu}{\sigma} $
4. Assign this relative score $R^*_i$ as the advantage to all tokens of the $i$-th response:  
   $ A_t^{(i)} = R^*_i $

This **per-group normalization** encourages responses that are better than average and penalizes those that are worse.

### Example: Advantage in Action

Consider a binary reward scenario where each response is either correct (1) or incorrect (0):

```python
>>> rewards = np.array([1, 1, 0, 0, 0])
>>> (rewards - rewards.mean()) / (rewards.std())
array([ 1.22474487,  1.22474487, -0.81649658, -0.81649658, -0.81649658])
```

Here, the correct responses receive higher advantage scores, promoting them in future updates.


If only one response is correct:

```python
>>> rewards = np.array([1, 0, 0, 0, 0])
>>> (rewards - rewards.mean()) / (rewards.std())
array([ 2. , -0.5, -0.5, -0.5, -0.5])
```

This corresponds to a prompt that is hard for the model: most sampled responses are incorrect.
The single correct response is assigned a large positive advantage, and all incorrect responses are assigned a negative relative score.

If all responses are incorrect:

```python
>>> rewards = np.array([0, 0, 0, 0, 0])
>>> (rewards - rewards.mean()) / (rewards.std() + 1e-6)
array([0., 0., 0., 0., 0.])
```

Since no response is better than the average, the model receives no learning signal.

If all responses are correct:

```python
>>> rewards = np.array([1, 1, 1, 1, 1])
>>> (rewards - rewards.mean()) / (rewards.std() + 1e-6)
array([0., 0., 0., 0., 0.])
```

Again, no learning signal is provided because there is nothing to improve upon.

In a more mixed case:

```python
>>> rewards = np.array([1, 1, 1, 1, 0])
>>> (rewards - rewards.mean()) / (rewards.std() + 1e-6)
array([0.5, 0.5, 0.5, 0.5, -2.])
```

This represents an easier question for the model. Most responses are correct, but occasional incorrect ones are heavily penalized.

In [11]:
def create_training_episodes(
    samples: List[Dict[str, Any]],
    all_generations: List[List[int]],
    all_finish_reasons: List[str],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Process model generations and calculate rewards for training episodes.

    This function processes generated responses and calculates rewards for training episodes by:
    1. Grouping generations by sample (GENERATIONS_PER_SAMPLE responses per input)
    2. Computing rewards and advantages for each response
    3. Processing response tokens

    Args:
        samples: List of input samples, each containing:
            - input_ids: List[int], tokenized input prompt
            - nums: List[int], numbers to use in equation
            - target: int, target value for equation
        all_generations: List of token ID sequences for each generated response
        all_finish_reasons: List of finish reasons for each generation ("stop" or other)

    Returns:
        Tuple containing:
        1. Dictionary with processed data for training:
            - all_query_token_ids: List[List[int]], input token IDs repeated for each generation
            - all_response_token_ids: List[List[int]], response token IDs with EOS tokens added
            - all_advantages: List[List[float]], advantage values repeated for each token
        2. Dictionary with generation statistics:
            - response_lengths: List[int], lengths of generated responses
            - rewards: List[float], raw reward values
            - non_stop_rate: List[bool], whether each generation was cut off (finish reason other than "stop")
            - reward_metrics/*: Various reward component metrics

    Example:
        >>> samples = [{"input_ids": [1,2,3], "nums": [1,2,3], "target": 6}]
        >>> generations = [[4,5, EOS_TOKEN_ID], [6,7], [8,9, EOS_TOKEN_ID]]  # 3 generations per sample
        >>> finish_reasons = ["stop", "length", "stop"]
        >>> episodes, stats = create_training_episodes(samples, generations, finish_reasons)
        >>> episodes
        {
            'all_query_token_ids': [[1,2,3], [1,2,3], [1,2,3]],
            'all_response_token_ids': [[4,5,EOS_TOKEN_ID], [6,7], [8,9,EOS_TOKEN_ID]],
            'all_advantages': [[0.5,0.5,0.5], [-1.0,-1.0], [0.5,0.5,0.5]]
        }
    """
    assert len(all_generations) == len(all_finish_reasons)
    assert len(all_generations) == len(samples) * GENERATIONS_PER_SAMPLE

    # Process responses and calculate rewards
    groups = [
        list(range(i, i + GENERATIONS_PER_SAMPLE))
        for i in range(0, len(all_generations), GENERATIONS_PER_SAMPLE)
    ]  # example: [[0, 1, 2], [3, 4, 5], [6, 7, 8]]

    all_query_token_ids, all_responses_token_ids, all_advantages = [], [], []

    stats = {
        "response_lengths": [],
        "rewards": [],
        "non_stop_rate": [],
    }

    for sample, group_indices in zip(samples, groups):
        finish_reasons = [all_finish_reasons[i] for i in group_indices]
        response_token_ids = [all_generations[i] for i in group_indices]
        responses = tokenizer.batch_decode(response_token_ids, skip_special_tokens=False)

        rewards_and_metrics = [compute_reward(resp, sample) for resp in responses]
        rewards, reward_metrics = zip(*rewards_and_metrics)

        rewards = np.array(rewards) # [group_size]
        response_advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-4)
        
        advantages = [
            [resp_adv] * len(resp) 
            for resp_adv, resp in zip(response_advantages, response_token_ids)
        ]

        all_query_token_ids.extend([sample["input_ids"]] * GENERATIONS_PER_SAMPLE)
        all_responses_token_ids.extend(response_token_ids)
        all_advantages.extend(advantages)

        stats["rewards"].extend(rewards)
        stats["non_stop_rate"].extend([fr != "stop" for fr in finish_reasons])
        stats["response_lengths"].extend([len(ids) for ids in response_token_ids])
        for rm in reward_metrics:
            for k, v in rm.items():
                stats.setdefault(f"reward_metrics/{k}", []).append(v)

    episodes = {
        "all_query_token_ids": all_query_token_ids,
        "all_response_token_ids": all_responses_token_ids,
        "all_advantages": all_advantages,
    }

    return episodes, stats

In [None]:
case_0 = {
    "sample": {"input_ids": [1,2,3], "nums": [1,2,3], "target": 6},
    "generations": [[4,5, 22, 33], [6,7], [8,9, 11], [10,11]],
    "finish_reasons": ["stop", "length", "stop", "stop"]
}

case = case_0
episodes, stats = create_training_episodes([case["sample"]], case["generations"], case["finish_reasons"])
episodes

In [None]:
case_1 = {
    "sample": {"input_ids": [33, 44], "nums": [11, 7, 8], "target": 26},
    "generations": [[1,2], [3,4], [5,6], [7,8]],
    "finish_reasons": ["stop", "stop", "length", "stop"]
}
case = case_1
episodes, stats = create_training_episodes([case["sample"]], case["generations"], case["finish_reasons"])
episodes

In [None]:
case_2 = {
    "sample": {"input_ids": [9, 8, 7, 6, 5, 4], "nums": [1,2,3,4], "target": 10},
    "generations": [[9,10], [11,12], [13,14], [15,16]],
    "finish_reasons": ["length", "length", "stop", "stop"]
}
case = case_2
episodes, stats = create_training_episodes([case["sample"]], case["generations"], case["finish_reasons"])
episodes

As you can see, the `input_ids` of this single example are repeated in all of the generated episodes.

## Policy Gradient


Now that we have a batch of episodes with corresponding advantages, we can compute the **policy gradient loss** to update the model.

GRPO uses the same loss formulation as PPO, but the key difference lies in how advantages are computed. To understand the implementation in `compute_pg_loss`, let’s first recall the original PPO objective:

$$
\mathcal{L}_{\text{PPO}} = \mathbb{E}\left[\min\left( 
\frac{\pi_\theta(y_t \mid y_{<t}, x)}{\pi_{\theta_{\text{old}}}(y_t \mid y_{<t}, x)} A_t, \;
\text{clip}\left(
\frac{\pi_\theta(y_t \mid y_{<t}, x)}{\pi_{\theta_{\text{old}}}(y_t \mid y_{<t}, x)}, \;
1 - \epsilon, \; 1 + \epsilon
\right) A_t \right)\right]
$$

where:
- $ \pi_{\theta} $ is the current policy,
- $ \pi_{\theta_{\text{old}}} $ is the policy from the previous iteration (the policy we sampled episodes from),
- $ A_t $ is the advantage.

This objective tries to increase or decrease the probability of tokens based on the advantage $A_t$ only when the ratio between the new and old policy probabilities stays within a small range, controlled by the clipping threshold $\epsilon$. This clipping mechanism prevents large, destabilizing updates during training.
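
The clipping behavior is easy to see on single tokens. Below is a minimal per-token sketch of the clipped term inside the expectation; the function name `clipped_surrogate` and the value $\epsilon = 0.2$ are assumptions chosen for illustration, not values taken from this notebook.

```python
def clipped_surrogate(ratio: float, advantage: float, epsilon: float = 0.2) -> float:
    """Per-token PPO term: min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A)."""
    clipped_ratio = max(1.0 - epsilon, min(ratio, 1.0 + epsilon))
    return min(ratio * advantage, clipped_ratio * advantage)


examples = [
    (1.0, 2.0),   # on-policy: ratio is 1, clipping is inactive
    (1.5, 1.0),   # good token already pushed up a lot: gain is capped
    (0.5, -1.0),  # bad token already pushed down a lot: term stops decreasing
    (1.5, -1.0),  # bad token whose probability went up: fully penalized
]
for ratio, advantage in examples:
    value = clipped_surrogate(ratio, advantage)
    print(f"ratio={ratio:.1f}  advantage={advantage:+.1f}  term={value:+.2f}")
```

Whenever the clipped branch is selected, the term no longer depends on $\theta$, so that token contributes no gradient.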

### Fully Online Setting: Simplifying the Objective

In general PPO, multiple gradient steps might be taken using the same batch of episodes. However, in our case, we apply only **one gradient step per iteration** using freshly sampled episodes. That means:

- $ \pi_{\theta} = \pi_{\theta_{\text{old}}} $
- Consequently,  
  $$
  \frac{\pi_\theta(y_t \mid y_{<t}, x)}{\pi_{\theta_{\text{old}}}(y_t \mid y_{<t}, x)} = 1
  $$
  
Since the ratio is exactly 1:
- The clipping function becomes inactive.
- The $\min(\cdot,\cdot)$ operator simply returns the unclipped term.

So, the objective simplifies **to**:

$$
\mathcal{L}_{\text{PPO}} = \mathbb{E}\left[ \frac{\pi_\theta(y_t \mid y_{<t}, x)}{\pi_{\theta_{\text{old}}}(y_t \mid y_{<t}, x)} A_t \right]
$$


Taking the gradient of this objective with respect to $\theta$, using the identity $\nabla_\theta \pi_\theta = \pi_\theta \nabla_\theta \log \pi_\theta$, and evaluating at $\pi_\theta = \pi_{\theta_{\text{old}}}$ (where the ratio equals 1), we get:

$$
\vec{g}_{\text{PPO}} = \nabla_\theta \mathcal{L}_{\text{PPO}} = \underbrace{\mathbb{E}\left[ \nabla_\theta \log \pi_\theta(y_t \mid y_{<t}, x) \cdot A_t \right]}_{\text{vanilla policy gradient with advantage}}
$$

This is the **standard policy gradient** formula, where the gradients of the log-probabilities are weighted by the advantage. In effect, we recover vanilla REINFORCE-style learning.

> Note: Any positive constant multiplier on the objective does not affect the direction of the gradient and can be safely absorbed into the learning rate.

In fact, this behavior is not unique to GRPO. In methods such as PPO and TRPO, the very first gradient step after collecting new data always reduces to this same form. Only after that first optimization step do the clipping or trust-region constraints start to take effect.
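
The claim that the ratio objective and the log-probability objective share the same gradient at $\theta = \theta_{\text{old}}$ can be checked numerically. The sketch below uses a toy softmax policy over four tokens and central finite differences; the logits, the chosen token, and all helper names are illustrative assumptions.

```python
import math


def softmax(logits):
    peak = max(logits)
    exps = [math.exp(l - peak) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


def ratio(logits, old_logits, token):
    """pi_theta(token) / pi_theta_old(token)."""
    return softmax(logits)[token] / softmax(old_logits)[token]


def log_prob(logits, token):
    return math.log(softmax(logits)[token])


def finite_difference(fn, logits, index, step=1e-6):
    """Central-difference estimate of d fn / d logits[index]."""
    up, down = list(logits), list(logits)
    up[index] += step
    down[index] -= step
    return (fn(up) - fn(down)) / (2 * step)


old_logits = [0.3, -1.2, 0.8, 0.1]  # toy parameters theta_old
token = 2                           # the sampled token y_t

# Both gradients are evaluated at theta = theta_old.
ratio_grad = [
    finite_difference(lambda l: ratio(l, old_logits, token), old_logits, k)
    for k in range(len(old_logits))
]
log_prob_grad = [
    finite_difference(lambda l: log_prob(l, token), old_logits, k)
    for k in range(len(old_logits))
]
max_gap = max(abs(a - b) for a, b in zip(ratio_grad, log_prob_grad))
print("largest difference between the two gradients:", max_gap)
```

The two gradients agree up to finite-difference error, which is why the first step on fresh data behaves like a plain policy gradient update.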

### KL Penalty

The final loss also has a **KL penalty** term to ensure the new policy doesn't drift too far from a reference policy:

$$
\mathcal{L} = \mathcal{L}_{\text{PPO}} - \beta \cdot \text{KL}(\pi_\theta \parallel \pi_{\theta_{\text{ref}}})
$$

We estimate the KL divergence using the **k3 estimator** from [this blog post by Schulman](http://joschu.net/blog/kl-approx.html):

$$
\text{KL}(\pi_\theta \parallel \pi_{\theta_{\text{ref}}}) = \mathbb{E}\left[\frac{\pi_{\theta_{\text{ref}}}(y_t \mid y_{<t}, x)}{\pi_\theta(y_t \mid y_{<t}, x)} - \log\left(\frac{\pi_{\theta_{\text{ref}}}(y_t \mid y_{<t}, x)}{\pi_\theta(y_t \mid y_{<t}, x)}\right) - 1\right]
$$

This regularization term softly constrains the updated model to remain close to the reference.
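
Here is a minimal sketch of the k3 estimator on per-token log-probabilities. The function name `k3_estimate` and the two toy distributions are assumptions for illustration. Because the distributions are tiny, the expectation under $\pi_\theta$ can be computed exactly and compared with the true KL divergence.

```python
import math


def k3_estimate(policy_log_prob: float, ref_log_prob: float) -> float:
    """k3 term for one sampled token: r - log(r) - 1 with r = pi_ref / pi_theta."""
    log_ratio = ref_log_prob - policy_log_prob
    return math.exp(log_ratio) - log_ratio - 1.0


# Toy next-token distributions over a 3-token vocabulary (illustrative values).
policy_probs = [0.7, 0.2, 0.1]  # pi_theta
ref_probs = [0.5, 0.3, 0.2]     # pi_theta_ref

# Exact KL(pi_theta || pi_ref).
exact_kl = sum(p * math.log(p / q) for p, q in zip(policy_probs, ref_probs))

# Expectation of the k3 term when tokens are sampled from pi_theta.
per_token_k3 = [k3_estimate(math.log(p), math.log(q)) for p, q in zip(policy_probs, ref_probs)]
expected_k3 = sum(p * k for p, k in zip(policy_probs, per_token_k3))

print(f"exact KL:    {exact_kl:.6f}")
print(f"E[k3]:       {expected_k3:.6f}")
print(f"per-token k3: {[round(k, 6) for k in per_token_k3]}")
```

Each per-token term is non-negative, and it is exactly zero when the policy and the reference assign the same probability to the sampled token.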


### GRPO vs PPO/VinePPO: Key Difference

The main difference between **GRPO** and methods like **PPO/VinePPO** lies in **how the advantage is computed and applied**:

- In **PPO/VinePPO**, each token/step's advantage is computed individually. This allows for fine-grained credit assignment across the sequence.
- In **GRPO**, a **single scalar advantage** is computed for the entire response and is applied **uniformly to all tokens** in that response.

This distinction is illustrated below:

#### A successful response in GRPO:
<img src="https://github.com/McGill-NLP/nano-aha-moment/blob/main/assets/grpo_successful.png?raw=true" alt="GRPO vs PPO/VinePPO: successful response" width="500">

#### A failed response in GRPO:
<img src="https://github.com/McGill-NLP/nano-aha-moment/blob/main/assets/grpo_unsuccessful.png?raw=true" alt="GRPO vs PPO/VinePPO: failed response" width="500">

In GRPO, all tokens in a response are updated with the same magnitude. In contrast, PPO/VinePPO updates each token/step with a different advantage value:

<img src="https://github.com/McGill-NLP/nano-aha-moment/blob/main/assets/ppo_and_vineppo.png?raw=true" alt="GRPO vs PPO/VinePPO: PPO and VinePPO" width="500">
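
The GRPO side of this comparison can be sketched in a few lines. The example below assumes the common GRPO recipe of normalizing rewards within the group of responses to one prompt (the population standard deviation and the `eps` value are assumptions made for illustration), then copies each response's scalar advantage to every one of its tokens. The function name is hypothetical.

```python
from statistics import mean, pstdev


def grpo_token_advantages(rewards, response_lengths, eps=1e-4):
    """Return one advantage list per response, constant across its tokens."""
    mu = mean(rewards)
    sigma = pstdev(rewards)
    # One scalar advantage per response, normalized within the group.
    per_response = [(r - mu) / (sigma + eps) for r in rewards]
    # Broadcast the scalar to every token of the corresponding response.
    return [[a] * n for a, n in zip(per_response, response_lengths)]


rewards = [1.0, 0.0, 0.0, 1.0]   # toy rewards for 4 responses to one prompt
response_lengths = [3, 5, 2, 4]  # toy token counts per response
token_advantages = grpo_token_advantages(rewards, response_lengths)
for adv in token_advantages:
    print(adv)
```

Every token of a successful response receives the same positive value and every token of a failed response the same negative value, which is the uniform credit assignment shown in the figures above.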


In [12]:
def compute_pg_loss(
    policy_model: Union[DeepSpeedEngine, PreTrainedModel],
    reference_model: Union[DeepSpeedEngine, PreTrainedModel],
    batch: Dict[str, torch.Tensor],
    total_response_len: int,
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """
    Compute the policy gradient loss with KL penalty between policy and reference models.

    This function:
    1. Computes log probabilities for both policy and reference models
    2. Calculates KL divergence penalty between the models
    3. Computes policy gradient loss using advantages
    4. Combines the losses with KL coefficient

    Args:
        policy_model: The model being trained
        reference_model: The reference model for KL penalty calculation
        batch: Dictionary containing:
            - input_ids: Tensor of shape [batch_size, seq_len]
            - attention_mask: Tensor of shape [batch_size, seq_len]
            - labels: Tensor of shape [batch_size, seq_len] with -100 for ignored positions
            - advantages: Tensor of shape [batch_size, seq_len]
        total_response_len: Total number of response tokens across all micro-batches
            of the current iteration, used to normalize the loss

    Returns:
        Tuple containing:
            - loss: Combined policy gradient and KL penalty loss (scalar tensor)
            - metrics: Dictionary with detailed loss components:
                - policy_loss: Pure policy gradient loss
                - kl_penalty: KL divergence penalty
                - entropy: Policy entropy
    """
    input_ids = batch["input_ids"]  # [batch_size, seq_len]
    attention_mask = batch["attention_mask"]  # [batch_size, seq_len]
    labels = batch["labels"]  # [batch_size, seq_len]
    advantages = batch["advantages"]  # [batch_size, seq_len]

    model_inputs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }

    labels_mask = (labels[..., 1:] != -100).float()  # [batch_size, seq_len-1]

    with torch.no_grad():
        ref_logps = compute_token_log_probs(
            reference_model, model_inputs, TEMPERATURE
        )  # [batch_size, seq_len-1]

    logps = compute_token_log_probs(policy_model, model_inputs, TEMPERATURE)  # [batch_size, seq_len-1]

    kl_penalty = torch.exp(ref_logps - logps) - (ref_logps - logps) - 1  # [batch_size, seq_len-1]
    kl_penalty = kl_penalty * labels_mask  # [batch_size, seq_len-1]

    entropy = -logps.sum() / labels_mask.sum()  # scalar

    policy_loss = -logps * advantages[..., 1:]  # [batch_size, seq_len-1]
    policy_loss = policy_loss * labels_mask  # [batch_size, seq_len-1]

    loss = (policy_loss + KL_COEFFICIENT * kl_penalty).sum() / total_response_len  # scalar

    metrics = {
        "policy_loss": policy_loss.sum().item() / total_response_len,
        "kl_penalty": kl_penalty.sum().item() / total_response_len,
        # entropy is already a per-token mean over this micro-batch, so it is not rescaled
        "entropy": entropy.item(),
    }

    return loss, metrics

## Training

Before starting the RL loop, we need to set up all necessary components:

- **Policy Model**: The main model that will be trained using policy gradients.
- **Reference Model**: A frozen copy of the base model used for KL regularization.
- **DeepSpeed**: Both models are initialized with DeepSpeed.
- **vLLM Inference Engine**: Used for fast, batched inference during episode generation.
- **WandB Logging**: We initialize WandB to track training metrics, hyperparameters, and checkpoints.

Finally, if an existing checkpoint is detected, we automatically resume training from where it left off. 
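
For intuition, checkpoint discovery could look roughly like the sketch below. It is a hypothetical stand-in for the notebook's `find_last_checkpoint` helper, not its actual implementation, and it assumes the `checkpoints/ckpt_XXXXXX/deepspeed` layout used by the saving code at the end of the training loop.

```python
import re
import tempfile
from pathlib import Path
from typing import Optional, Tuple


def find_last_checkpoint_sketch(exp_dir: Path) -> Tuple[Optional[Path], Optional[int]]:
    """Return the newest ckpt_XXXXXX directory that contains a deepspeed/ subdir."""
    ckpt_root = exp_dir / "checkpoints"
    if not ckpt_root.is_dir():
        return None, None
    best_path, best_iter = None, None
    for path in ckpt_root.glob("ckpt_*"):
        match = re.fullmatch(r"ckpt_(\d+)", path.name)
        # Skip unrelated entries and checkpoints without optimizer state.
        if match is None or not (path / "deepspeed").is_dir():
            continue
        iteration = int(match.group(1))
        if best_iter is None or iteration > best_iter:
            best_path, best_iter = path, iteration
    return best_path, best_iter


# Demo on a throwaway directory tree.
with tempfile.TemporaryDirectory() as tmp:
    exp_dir = Path(tmp)
    for it in (50, 100):
        (exp_dir / "checkpoints" / f"ckpt_{it:06d}" / "deepspeed").mkdir(parents=True)
    # An incomplete checkpoint (no deepspeed/ subdir) is ignored.
    (exp_dir / "checkpoints" / "ckpt_000150").mkdir()
    ckpt_path, ckpt_iter = find_last_checkpoint_sketch(exp_dir)
    ckpt_name = ckpt_path.name
    empty_result = find_last_checkpoint_sketch(exp_dir / "missing")

print(ckpt_name, ckpt_iter, empty_result)
```

Requiring the `deepspeed` subdirectory is a design choice of this sketch: it avoids resuming from a checkpoint whose optimizer state was never written.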

A couple of remarks:
- We move the reference model to CPU and only bring it back to GPU during policy gradient computation. Because the model is relatively small, moving it back and forth between GPU and CPU is very fast.
- Although the entire training runs on a single GPU, we still use DeepSpeed ZeRO stage 2. Stage 2 comes with optimizations that avoid memory fragmentation, allowing us to fully utilize GPU memory.
- Flash Attention is required in our setup, as it reduces the memory requirement of attention from $\mathcal{O}(n^2)$ to $\mathcal{O}(n)$, where $n$ is the sequence length.

In [None]:
# Initialize main and reference models
policy_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map=0,
)
reference_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map=0,
)
policy_model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})


# Initialize DeepSpeed engines
policy_model, *_ = deepspeed.initialize(
    model=policy_model,
    config=deepspeed_config,
    model_parameters=policy_model.parameters(),
)
reference_model, *_ = deepspeed.initialize(
    model=reference_model,
    config=ref_deepspeed_config,
)

reference_model.module.cpu()

############################################
# Initialize vLLM (Inference) engine
############################################

inference_engine = LLM(
    model=MODEL_NAME,
    skip_tokenizer_init=False,
    gpu_memory_utilization=0.2,
    enable_prefix_caching=True,
    swap_space=1,
    scheduling_policy="fcfs",
    dtype=torch.bfloat16,
    max_model_len=2048,
    enable_sleep_mode=True,
)

# Wandb for logging
wandb.init(
    project="r1-aha-moment",
    name=RUN_NAME,
    config={
        "model_name": MODEL_NAME,
        "learning_rate": LEARNING_RATE,
        "num_iterations": NUM_ITERATIONS,
        "episodes_per_iteration": EPISODES_PER_ITERATION,
        "rollouts_per_episode": GENERATIONS_PER_SAMPLE,
        "kl_coefficient": KL_COEFFICIENT,
        "temperature": TEMPERATURE,
    },
)

# Load checkpoint if it exists
begin_iter = 0
ckpt_path, ckpt_iter = find_last_checkpoint(EXP_DIR)
if ckpt_path is not None:
    print(f"Resuming from checkpoint {ckpt_path} at iteration {ckpt_iter}")
    out = policy_model.load_checkpoint(ckpt_path / "deepspeed")
    if out is None:
        raise RuntimeError(f"Failed to load checkpoint {ckpt_path}")
    begin_iter = ckpt_iter + 1
    load_model_into_vllm(policy_model, inference_engine)

### Training loop

With everything set up, we are ready to start the main training loop. Each iteration of the loop performs the following steps:

1. **Evaluation** (optional): 
Every 25 iterations, the model is evaluated on a test set to monitor progress.
2. **Episode Generation**
A batch of prompts is sampled, and multiple responses are generated for each prompt using the inference engine. Then we put the inference engine to sleep.
3. **Reward Computation**
Rewards and advantages for each generated episode are computed.
4. **Policy Gradient Training**
Using the computed advantages, we calculate the policy gradient loss and update the model parameters. Training uses gradient accumulation to handle large batches. Note that we apply a single gradient update per iteration.
5. **Inference Engine Update**
The inference engine is woken up and updated with the latest model weights.
6. **Logging**
Training and evaluation metrics are logged using WandB.
7. **Checkpointing**
Every 50 iterations, the model and optimizer states are saved.

This loop continues until the specified number of iterations is completed.
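
One detail of step 4 is worth spelling out. Each micro-batch loss is divided by the total number of response tokens in the whole iteration, not by the micro-batch's own token count, and the backward pass does not rescale by the number of accumulation steps. Under those assumptions, summing the micro-batch losses gives the token-level mean over the full batch. The plain-Python sketch below illustrates this with made-up per-token loss values; the function name is hypothetical.

```python
def micro_batch_losses(token_losses_per_episode, micro_batch_size):
    """Loss of each micro-batch, normalized by the token count of the FULL batch."""
    total_tokens = sum(len(ep) for ep in token_losses_per_episode)
    losses = []
    for start in range(0, len(token_losses_per_episode), micro_batch_size):
        chunk = token_losses_per_episode[start:start + micro_batch_size]
        losses.append(sum(sum(ep) for ep in chunk) / total_tokens)
    return losses


# Toy per-token losses for 4 episodes of different lengths (assumption).
episodes = [[0.5, 1.0, 1.5], [2.0], [0.25, 0.75], [1.0, 1.0, 1.0, 1.0]]

per_step = micro_batch_losses(episodes, micro_batch_size=2)
accumulated = sum(per_step)  # what gradient accumulation effectively sums up
full_batch_mean = sum(sum(ep) for ep in episodes) / sum(len(ep) for ep in episodes)
print(per_step, accumulated, full_batch_mean)
```

Because differentiation is linear, the same equality holds for the accumulated gradients, so long and short responses contribute per token rather than per micro-batch.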

**Sleeping of vLLM**
Before each training step, we put vLLM into sleep mode to free up its KV cache and model weights, ensuring enough GPU memory is available for policy training. After the training step is complete, vLLM is woken up, reinitializing its KV cache and preparing for the next round of sampling using the updated model parameters.

In [None]:
# Start from begin_iter so that a resumed run continues where the checkpoint left off
for iteration in trange(begin_iter, NUM_ITERATIONS):
    print(f"Iteration {iteration}/{NUM_ITERATIONS}")

    metrics = {}

    #########################################################
    # Evaluation
    #########################################################

    eval_stats = None
    if iteration % 25 == 0:
        print("Evaluating on eval set...")
        eval_episodes, eval_stats = evaluate_on_test_set(
            inference_engine=inference_engine,
            test_dataset=test_dataset,
            tokenizer=tokenizer,
            eos_token=EOS_TOKEN,
            eval_sampling_params=SamplingParams(
                temperature=0.3,
                max_tokens=1024,
                n=1,
                detokenize=False,
                stop_token_ids=[EOS_TOKEN_ID],
            ),
            reward_func=lambda completion, sample: compute_reward(
                completion, sample
            ),
        )
        eval_episode_table = dump_episodes(
            episodes=eval_episodes,
            episodes_stats=eval_stats,
            exp_dir=EXP_DIR,
            tokenizer=tokenizer,
            iteration=iteration,
            is_eval=True,
        )
        wandb.log({"eval/episodes": eval_episode_table, "iteration": iteration})


    #########################################################
    # Generate Episodes
    #########################################################

    # Sample training batch
    num_samples = EPISODES_PER_ITERATION // GENERATIONS_PER_SAMPLE
    indices = np.random.choice(
        len(train_dataset), size=num_samples, replace=False
    )
    samples = train_dataset.select(indices)

    # Sample responses
    outputs = inference_engine.generate(
        prompt_token_ids=samples["input_ids"],
        sampling_params=SamplingParams(
            n=GENERATIONS_PER_SAMPLE,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            top_k=TOP_K,
            max_tokens=MAX_RESPONSE_TOKENS,
            detokenize=False,
            stop_token_ids=[EOS_TOKEN_ID],
        )
    )
    all_generations = [list(g.token_ids) for out in outputs for g in out.outputs]
    all_finish_reasons = [g.finish_reason for out in outputs for g in out.outputs]
    inference_engine.sleep(1)

    print(f"Generated {len(all_generations)} responses")
    gc.collect()
    torch.cuda.empty_cache()
    time.sleep(1)

    # Process responses and calculate rewards
    episodes, episodes_stats = create_training_episodes(
        samples,
        all_generations,
        all_finish_reasons,
    )
    for k, v in episodes_stats.items():
        metrics.setdefault(k, []).extend(v)

    episode_table = dump_episodes(
        episodes=episodes,
        episodes_stats=episodes_stats,
        exp_dir=EXP_DIR,
        tokenizer=tokenizer,
        iteration=iteration,
    )

    #########################################################
    # Training
    #########################################################

    # Prepare training batch
    model_inputs = prepare_model_inputs(
        query_token_ids=episodes["all_query_token_ids"],
        response_token_ids=episodes["all_response_token_ids"],
        advantages=episodes["all_advantages"],
        device="cuda"
    )

    # Calculate losses and update model
    policy_model.train()
    reference_model.module.cuda()
    reference_model.eval()

    total_response_len = (model_inputs["labels"] != -100).sum().item()

    for i in trange(0, EPISODES_PER_ITERATION, PER_DEVICE_BATCH_SIZE, desc="Gradient Accumulation"):
        batch = {
            k: v[i : i + PER_DEVICE_BATCH_SIZE]
            for k, v in model_inputs.items()
        }

        # Compute policy gradient loss
        loss, loss_metrics = compute_pg_loss(
            policy_model=policy_model,
            reference_model=reference_model,
            batch=batch,
            total_response_len=total_response_len,
        )

        # Track metrics
        metrics.setdefault("loss", []).append(loss.item())
        grad_norm = policy_model.get_global_grad_norm()
        if grad_norm is not None:
            grad_norm = grad_norm.item()
        metrics.setdefault("grad_norm", []).append(grad_norm)
        for k, v in loss_metrics.items():
            metrics.setdefault(k, []).append(v.item() if isinstance(v, torch.Tensor) else v)

        # Backpropagation and optimization step
        policy_model.backward(loss, scale_wrt_gas=False)
        
        # Free memory
        del loss, loss_metrics
        if policy_model.is_gradient_accumulation_boundary():
            reference_model.module.cpu()

        policy_model.step()

    #########################################################
    # Update inference engine weights
    #########################################################
    
    gc.collect()
    torch.cuda.empty_cache()
    time.sleep(1)

    inference_engine.wake_up()
    load_model_into_vllm(policy_model, inference_engine)

    gc.collect()
    torch.cuda.empty_cache()
    time.sleep(1)


    #########################################################
    # Log metrics
    #########################################################

    train_metrics = {
        k: np.mean(v) for k, v in metrics.items() if None not in v
    }
    train_metrics["learning_rate"] = policy_model.get_lr()[0]
    logs = {
        "iteration": iteration,
        f"episodes/iter_{iteration:06d}": episode_table,
        **{f"train/{k}": v for k, v in train_metrics.items()},
    }
    if eval_stats is not None:
        eval_metrics = {k: np.mean(v) for k, v in eval_stats.items() if None not in v}
        logs.update({f"eval/{k}": v for k, v in eval_metrics.items()})
    wandb.log(logs)

    selected_keys = [
        "train/kl_penalty",
        "train/rewards",
        "train/reward_metrics/format_reward",
        "train/reward_metrics/equation_reward",
        "eval/rewards",
        "eval/reward_metrics/format_reward",
        "eval/reward_metrics/equation_reward",
    ]
    selected_metrics = {k: logs[k] for k in selected_keys if k in logs}
    print(f"KEY METRICS: {selected_metrics}")

    if iteration % 50 == 0 and iteration != 0:
        policy_model.module.save_pretrained(
            str(EXP_DIR / "checkpoints" / f"ckpt_{iteration:06d}" / "hf_model")
        )
        policy_model.save_checkpoint(
            str(EXP_DIR / "checkpoints" / f"ckpt_{iteration:06d}" / "deepspeed")
        )

## Citation

If you use this codebase in your research, please cite us using:

```bibtex
@misc{Kazemnejad2025:NanoAhaMoment,
  author       = {Amirhossein Kazemnejad and Milad Aghajohari and Alessandro Sordoni and Aaron Courville and Siva Reddy},
  title        = {Nano Aha! Moment: Lunch Break Reproduction of DeepSeek R1-Zero from Scratch},
  year         = {2025},
  howpublished = {\url{https://github.com/McGill-NLP/nano-aha-moment}},
  note         = {GitHub repository}
}
```