# TRL Introduction: RL on Actual LLMs

Now we apply the concepts from notebook 01 to real language models.

**What we'll do:**
1. Load a small LLM (Qwen2.5-0.5B-Instruct)
2. Create a simple reward function
3. Run GRPO training using HuggingFace's `trl` library

**Why GRPO over PPO?**
- GRPO doesn't need a separate value (critic) model, so it is simpler and uses less memory
- Uses the group mean as the baseline (as we saw in notebook 01)
- It's the algorithm DeepSeek introduced in DeepSeekMath and used to train its R1 reasoning model

**This notebook is for understanding.** The actual training runs on Modal via:
```bash
modal run modal_app.py::train_grpo_simple
```

## Part 1: The Components

Let's understand what goes into GRPO training.

In [None]:
# Component 1: MODEL
# The LLM we're training (the "policy")
MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"  # Small, fast for learning

# Component 2: DATASET
# Prompts to train on - simple arithmetic to verify our setup works
TRAIN_DATA = [
    {"prompt": "What is 2 + 3?", "answer": "5"},
    {"prompt": "What is 7 - 4?", "answer": "3"},
    {"prompt": "What is 5 * 2?", "answer": "10"},
    {"prompt": "What is 8 / 2?", "answer": "4"},
    {"prompt": "What is 10 + 5?", "answer": "15"},
    {"prompt": "What is 9 - 3?", "answer": "6"},
    {"prompt": "What is 4 * 3?", "answer": "12"},
    {"prompt": "What is 12 / 4?", "answer": "3"},
]

# Component 3: REWARD FUNCTION
# Score the model's output - this is the "verifiable" in RLVR
def compute_reward(response: str, correct_answer: str) -> float:
    """
    Simple reward: 1.0 if correct answer appears in response, 0.0 otherwise.
    """
    if correct_answer in response:
        return 1.0
    return 0.0

# Test reward function
print("Reward tests:")
print(f"  'The answer is 5' for '5': {compute_reward('The answer is 5', '5')}")
print(f"  'I think it is 7' for '5': {compute_reward('I think it is 7', '5')}")
print(f"  '5' for '5': {compute_reward('5', '5')}")
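
The substring check above can reward wrong answers that merely contain the right digits. A stricter alternative is sketched below; `compute_reward_strict` and its "compare only the last number in the response" rule are our own illustrative assumptions, not what `modal_app.py` does.

```python
import re

# Matches integers and decimals, with an optional leading minus sign
NUMBER_RE = re.compile(r"-?\d+(?:\.\d+)?")

def compute_reward_strict(response: str, correct_answer: str) -> float:
    """Hypothetical stricter reward: compare only the LAST number in the response."""
    numbers = NUMBER_RE.findall(response)
    if not numbers:
        return 0.0
    # Compare as floats so that "4.0" and "4" count as the same answer
    return 1.0 if float(numbers[-1]) == float(correct_answer) else 0.0

print(compute_reward_strict("The answer is 15", "5"))  # 0.0 (substring check gives 1.0)
print(compute_reward_strict("8 / 2 = 4.0", "4"))       # 1.0
print(compute_reward_strict("I don't know", "5"))      # 0.0
```

The "last number wins" rule is a common heuristic for short math answers, but it still fails if the model restates the question after answering; more robust extraction is the topic of notebook 03.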

## Part 2: How GRPO Works

`trl.GRPOTrainer` does this (simplified):

```
For each batch of prompts:
    1. Generate K responses per prompt (the "group", e.g., K=4)
    2. Compute reward for each response
    3. Compute baseline = mean(rewards in group)
    4. For each response:
         advantage = reward - baseline
         loss += -advantage * log_prob(response)
    5. Update model
```

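This is the same idea as the bandit in notebook 01, but now:
- "action" = the entire generated sequence
- "policy" = the LLM
- "baseline" = group mean (GRPO's approach)

The real trainer refines this sketch: advantages are usually also divided by the group's standard deviation, log-probabilities are computed per token, and the update uses a PPO-style clipped probability ratio, optionally with a KL penalty toward a reference model.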
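
To make the advantage step concrete, here is a minimal sketch for the K responses to one prompt. `group_advantages`, `scale_by_std`, and `eps` are hypothetical names; the optional division by the standard deviation mirrors GRPO's normalization, and this sketch uses the population standard deviation (libraries may use the sample version instead).

```python
from statistics import mean, pstdev

def group_advantages(rewards: list[float], scale_by_std: bool = True,
                     eps: float = 1e-4) -> list[float]:
    """GRPO-style advantages for ONE prompt's group of responses.

    baseline = mean reward of the group; optionally divide by the group's
    standard deviation (eps avoids dividing by zero when all rewards match).
    """
    baseline = mean(rewards)
    centered = [r - baseline for r in rewards]
    if not scale_by_std:
        return centered
    std = pstdev(rewards)
    return [c / (std + eps) for c in centered]

# K=4 responses to one prompt: two correct, two wrong
print(group_advantages([1.0, 0.0, 0.0, 1.0]))  # roughly [1.0, -1.0, -1.0, 1.0]
# All four correct: every advantage is zero
print(group_advantages([1.0, 1.0, 1.0, 1.0]))  # [0.0, 0.0, 0.0, 0.0]
```

The second call shows an important property of the group baseline: if every response in a group gets the same reward (all right or all wrong), all advantages are zero and that prompt contributes no learning signal. This is why a sampling temperature that produces some variety within each group matters.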

In [None]:
# The training hyperparameters

GRPO_CONFIG = {
    # How many responses to generate per prompt (the "group")
    "num_generations": 4,
    
    # Learning rate - small for LLMs!
    "learning_rate": 1e-6,
    
    # Batch size, counted in completions (not prompts). Must be divisible by
    # num_generations: here 4 / 4 = 1 unique prompt per step on one GPU
    "per_device_train_batch_size": 4,
    
    # How many update steps to run
    "max_steps": 50,
    
    # Max length of generated responses
    "max_completion_length": 64,
    
    # Temperature for sampling (higher = more random = more exploration)
    "temperature": 0.7,
}

print("GRPO Config:")
for k, v in GRPO_CONFIG.items():
    print(f"  {k}: {v}")
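
Because GRPO counts batch size in completions, it helps to check how many distinct prompts each step actually sees. The sketch below assumes no gradient accumulation; `batch_layout` is a hypothetical helper, and the exact divisibility rule in TRL can vary between versions.

```python
def batch_layout(config: dict, num_processes: int = 1) -> dict:
    """Hypothetical helper: break one training step into prompts x completions.

    Assumes no gradient accumulation, so one step = one batch per process.
    """
    completions_per_step = config["per_device_train_batch_size"] * num_processes
    k = config["num_generations"]
    if completions_per_step % k != 0:
        raise ValueError(
            f"{completions_per_step} completions per step is not divisible "
            f"by num_generations={k}"
        )
    return {
        "completions_per_step": completions_per_step,
        "prompts_per_step": completions_per_step // k,
        "total_completions": completions_per_step * config["max_steps"],
    }

# Same values as GRPO_CONFIG above
print(batch_layout({"num_generations": 4, "per_device_train_batch_size": 4, "max_steps": 50}))
# {'completions_per_step': 4, 'prompts_per_step': 1, 'total_completions': 200}
```

With these settings each step scores a single prompt's group of 4, so 50 steps sample 50 prompts from the 8-example dataset. Raising `per_device_train_batch_size` to 8 or 16 would compare 2 or 4 prompts per step.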

## Part 3: The Training Code

Here's what `modal_app.py::train_grpo_simple` does (simplified):

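```python
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOConfig, GRPOTrainer

# MODEL_NAME, TRAIN_DATA and GRPO_CONFIG come from the cells above

# Load model
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# The dataset needs a "prompt" column; extra columns ("answer") are
# forwarded to the reward function as keyword arguments
dataset = Dataset.from_list(TRAIN_DATA)

# Define reward function
def reward_fn(completions, answer, **kwargs):
    rewards = []
    for completion, correct in zip(completions, answer):
        rewards.append(1.0 if correct in completion else 0.0)
    return rewards

# Create trainer
trainer = GRPOTrainer(
    model=model,
    reward_funcs=reward_fn,  # This is where RLVR happens!
    args=GRPOConfig(output_dir="grpo_simple", **GRPO_CONFIG),  # output_dir name is arbitrary
    train_dataset=dataset,
    processing_class=tokenizer,
)

# Train
trainer.train()
```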

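That's it. `GRPOTrainer` handles the RL mechanics from notebook 01: sampling a group of responses per prompt, scoring them with `reward_funcs`, computing group-relative advantages, and updating the policy.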
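
It can help to see the data layout a reward function receives. In the TRL versions we are familiar with, the trainer repeats each prompt (and its extra dataset columns, such as `answer`) `num_generations` times in a row, so `prompts`, `completions`, and `answer` arrive as flat, index-aligned lists. The sketch below simulates that call by hand with made-up completions, so it runs without a GPU or TRL installed; this `reward_fn` reads the `answer` column rather than a lookup table.

```python
def reward_fn(completions, answer, **kwargs):
    # Substring check per completion; "prompts" lands in **kwargs, unused
    return [1.0 if correct in completion else 0.0
            for completion, correct in zip(completions, answer)]

# Simulate one batch: 2 prompts x K=2 generations each
K = 2
rows = [
    {"prompt": "What is 2 + 3?", "answer": "5"},
    {"prompt": "What is 4 * 3?", "answer": "12"},
]
# Each row is repeated K times consecutively, keeping all lists aligned
prompts = [r["prompt"] for r in rows for _ in range(K)]
answers = [r["answer"] for r in rows for _ in range(K)]
completions = ["5", "It is 6", "12", "The answer is twelve"]  # made-up samples

rewards = reward_fn(prompts=prompts, completions=completions, answer=answers)
print(rewards)  # [1.0, 0.0, 1.0, 0.0]
```

Note that "The answer is twelve" scores 0.0 even though it is correct: a verifiable reward only rewards what it can check.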

## Part 4: Run It

Run on Modal (needs GPU):

```bash
modal run modal_app.py::train_grpo_simple
```

**What to watch for:**
1. **Reward increasing** - Model learns to produce correct answers
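2. **Loss** - Not a reliable progress signal in GRPO: because advantages are centered within each group, the loss tends to stay close to zero (any drift mostly reflects the KL term, if enabled). Judge progress by reward instead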
3. **Test outputs** - At the end, it tests on 3 examples
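
Per-step rewards are noisy (each step here scores only a handful of completions), so a smoothed curve makes the "reward increasing" trend easier to judge. A minimal sketch, assuming you have the trainer's logs: `trainer.state.log_history` is the standard `transformers` Trainer log list, but the key holding the mean reward (`"reward"` below) is an assumption that may differ across TRL versions. `reward_curve` and `moving_average` are hypothetical helpers.

```python
def reward_curve(log_history: list[dict], key: str = "reward") -> list[float]:
    """Pull per-step mean rewards out of Trainer-style log entries.

    The key name is an assumption; check one entry of your own logs.
    """
    return [entry[key] for entry in log_history if key in entry]


def moving_average(values: list[float], window: int = 5) -> list[float]:
    """Trailing moving average; the first points average whatever exists so far."""
    smoothed = []
    for i in range(len(values)):
        chunk = values[max(0, i - window + 1): i + 1]
        smoothed.append(sum(chunk) / len(chunk))
    return smoothed


# Illustrative log entries only, not from a real run
logs = [{"reward": 0.25}, {"loss": 0.0}, {"reward": 0.5}, {"reward": 0.75}]
curve = reward_curve(logs)
print(curve)                            # [0.25, 0.5, 0.75]
print(moving_average(curve, window=2))  # [0.25, 0.375, 0.625]
```

After a real run you would call `reward_curve(trainer.state.log_history)` instead of using the toy `logs` list.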

**Expected:**
- Takes ~5-10 minutes
- Costs ~$0.10-0.20
- The model already knows arithmetic, so training mostly teaches the output format

## Next Steps

After running this:
1. **03_gsm8k_reward.ipynb** - Build proper answer extraction for GSM8K (real math problems)
2. **04_pseudo_labeling.ipynb** - Implement consensus-based pseudo-labeling
3. **Full noisy student loop** - Combine everything