Improve Math Reasoning by RLVR

This notebook uses [Tinker](https://thinkingmachines.ai/tinker/) by Thinking Machines Lab to implement **Group Relative Policy Optimization (GRPO)** for math reasoning on the GSM8K dataset, following the RL section of [Stanford CS336 Assignment 5](https://stanford-cs336.github.io/spring2025/assignments/5/). To learn the theoretical fundamentals of RL, we recommend watching Stanford CS336 on YouTube, from [lecture 15 to lecture 17](https://www.youtube.com/watch?v=Dfu7vC9jo4w&list=PLoROMvodv4rOY23Y0BoGoBGgQ1zmU_MT_&index=15).

## What You'll Build

- **On-Policy GRPO**: One gradient step per batch of rollouts sampled from the current policy.
- **Off-Policy GRPO**: Multiple gradient steps per batch of rollouts sampled from the old policy.

We also provide Weights & Biases tracking so that you can visualize the training process and check the training results easily.

## Key Concepts

**GRPO (Group Relative Policy Optimization)**:
- Sample multiple responses per question (group_size=8)
- Normalize rewards within each group
- Advantages = (reward - group_mean) / (group_std + eps)
- Use PPO-style clipping to prevent excessive policy updates

**On-Policy vs Off-Policy**:
- **On-policy** (`epochs_per_rollout_batch=1`): One gradient step per rollout batch
- **Off-policy** (`epochs_per_rollout_batch>1`): Multiple gradient steps per rollout batch
  - Freezes old_log_probs from sampling policy πθ_old
  - Computes new_log_probs from current policy πθ at each gradient step
  - Ratio πθ/πθ_old measures policy divergence

## References

- **DeepSeekMath**: https://arxiv.org/abs/2402.03300
- **DeepSeek-R1**: https://arxiv.org/abs/2501.12948  
- **Tinker Docs**: https://tinker-docs.thinkingmachines.ai/
- **GSM8K Dataset**: https://huggingface.co/datasets/openai/gsm8k

---
## 1. Setup & Imports

### Why Tinker?

If we were building the GRPO training loop directly in PyTorch or JAX, we would need to handle:

- FSDP/DeepSpeed for model sharding
- Custom data loaders for distributed sampling
- Apply LoRA to the model
- Apply mixed precision to the model
- GPU cluster setup and management
- Manage distributed checkpoints
- ...

With Tinker, we only need to focus on defining the correct loss objects, and the rest of the training/inference
hassle is taken care of automatically.

In [None]:
%pip install -qU tinker tinker_cookbook

In [None]:
import random
import re
import time

import chz
import datasets
import tinker
import torch
from tinker.types.tensor_data import TensorData
from tinker_cookbook import checkpoint_utils

---
## 2. Helper Function: Group Normalization

### Why Group Normalization?
GRPO normalizes rewards **within each group** (responses to the same question):
```
advantage = (reward - mean(group_rewards)) / (std(group_rewards) + eps)
```

Benefits:
- **Relative comparison**: "This response is better than average for this question"
- **Variance reduction**: Less sensitive to reward scale
- **Credit assignment**: Rewards differences, not absolute values
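
To make the formula concrete, here is a small worked sketch in plain Python (no Tinker or PyTorch required). The reward values are made up for illustration, and `group_normalize` is a hypothetical helper rather than part of this notebook's code. It uses the sample standard deviation, which is also the default behavior of `torch.std`.

```python
import statistics


def group_normalize(rewards: list[float], eps: float = 1e-6) -> list[float]:
    """Return (r - mean) / (std + eps) for one group of rewards."""
    mean = statistics.fmean(rewards)
    # Sample standard deviation (n - 1 in the denominator), like torch.std's default.
    std = statistics.stdev(rewards) if len(rewards) > 1 else 0.0
    return [(r - mean) / (std + eps) for r in rewards]


# Made-up rewards for 4 responses to the same question.
mixed_advantages = group_normalize([1.0, 0.2, 0.2, 1.0])
print(mixed_advantages)  # higher-reward responses get positive advantages

# If every response receives the same reward, all advantages are zero,
# so that group contributes no learning signal.
uniform_advantages = group_normalize([0.2, 0.2, 0.2, 0.2])
print(uniform_advantages)
```

The second case is worth keeping in mind: questions that the model always gets right (or always gets wrong) produce zero advantages and therefore no gradient.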

In [None]:
from typing import Callable

def run_compute_group_normalized_rewards(
    reward_fn: Callable,
    rollout_responses: list[str],
    repeated_ground_truths: list[str],
    group_size: int,
    advantage_eps: float,
    normalize_by_std: bool,
) -> tuple[torch.Tensor, torch.Tensor, dict[str, float]]:
    """
    Compute group-normalized rewards for GRPO.
    
    From DeepSeekMath: https://arxiv.org/abs/2402.03300
    """
    # Compute raw rewards
    rollout_batch_size = len(rollout_responses)
    raw_rewards_list = []

    for response, ground_truth in zip(rollout_responses, repeated_ground_truths):
        reward_dict = reward_fn(response, ground_truth)
        raw_rewards_list.append(reward_dict["reward"])

    # Convert to tensor
    raw_rewards = torch.tensor(raw_rewards_list, dtype=torch.float32)

    # Reshape into groups: (n_groups, group_size)
    n_groups = rollout_batch_size // group_size
    rewards_grouped = raw_rewards.view(n_groups, group_size)

    # Compute mean for each group
    group_means = rewards_grouped.mean(dim=1, keepdim=True)

    if normalize_by_std:
        # GRPO: normalize by std
        # A^(i) = (r^(i) - mean) / (std + eps)
        group_stds = rewards_grouped.std(dim=1, keepdim=True)
        advantages_grouped = (rewards_grouped - group_means) / (group_stds + advantage_eps)
    else:
        # Dr. GRPO: don't normalize by std
        # A^(i) = r^(i) - mean
        advantages_grouped = rewards_grouped - group_means

    # Flatten back
    advantages = advantages_grouped.view(-1)

    # Metadata
    metadata = {
        "mean_reward": raw_rewards.mean().item(),
        "std_reward": raw_rewards.std().item(),
        "min_reward": raw_rewards.min().item(),
        "max_reward": raw_rewards.max().item(),
    }

    return advantages, raw_rewards, metadata

print("✓ Group normalization function defined")

---
## 2. Configuration

### Key Hyperparameters

We mostly use the same hyperparameters as the [assignment](https://github.com/stanford-cs336/assignment5-alignment) (page 27), with slight changes to accommodate Tinker, e.g., using Qwen3-8B instead of Qwen2.5-Math-1.5B because Tinker doesn't support the latter.

**Model & Training:**
- `model_name`: Base LLM (Qwen3-8B)
- `lora_rank`: LoRA adapter rank (32)
- `learning_rate`: 1e-5

**GRPO Specific:**
- `rollout_batch_size`: Total rollouts per sampling step (256 = 32 inputs × 8 samples)
- `group_size`: Samples per input for group normalization (8)
- `cliprange`: PPO clip range for stability (0.2)

**On-Policy vs Off-Policy:**
- `epochs_per_rollout_batch`: 1 for on-policy GRPO, >1 for off-policy GRPO
- `train_batch_size`: Rollouts per gradient step (64). For on-policy GRPO, this should be equal to `rollout_batch_size`.

**Sampling:**
- `sampling_temperature`: Sampling temperature (1.0)
- `sampling_max_tokens`: Maximum tokens to sample (2048)
- `use_r1_zero_format`: Whether to use R1-Zero format for sampling (True)

**Logging & Checkpointing:**
- `log_path`: Path to save logs and checkpoints (tmp_logging/grpo-tinker-tutorial)
- `save_every`: Save checkpoint every N steps
- `eval_every`: Evaluate every N steps

### Example Configurations



In [None]:
@chz.chz
class GRPOConfig:
    """Configuration for GRPO training."""

    # Model & Service
    model_name: str = "Qwen/Qwen3-8B"
    base_url: str | None = None
    lora_rank: int = 32

    # Training hyperparameters
    n_grpo_steps: int = 60
    learning_rate: float = 1e-5
    rollout_batch_size: int = 256  # Total rollouts (questions = rollout_batch_size // group_size)
    group_size: int = 8  # Samples per question

    # Off-policy training (for efficiency)
    epochs_per_rollout_batch: int = 1  # On-policy=1, off-policy>1 (e.g., 2-4)
    train_batch_size: int = 64  # Rollouts per gradient step

    # GRPO-specific
    advantage_eps: float = 1e-6
    use_std_normalization: bool = True
    cliprange: float = 0.2  # PPO-style clipping

    # Sampling
    sampling_temperature: float = 1.0
    sampling_max_tokens: int = 2048
    use_r1_zero_format: bool = True  # <think></think><answer></answer>

    # Logging & Checkpointing
    log_path: str = "tmp_logging/grpo-tinker-tutorial"
    save_every: int = 6
    eval_every: int = 6

# Create config
config = GRPOConfig()

# Validate hyperparameters
assert config.rollout_batch_size % config.group_size == 0, "rollout_batch_size must be divisible by group_size"
assert config.train_batch_size <= config.rollout_batch_size, "train_batch_size must be <= rollout_batch_size"
assert config.rollout_batch_size % config.train_batch_size == 0, "rollout_batch_size must be divisible by train_batch_size"

n_questions = config.rollout_batch_size // config.group_size
batches_per_epoch = config.rollout_batch_size // config.train_batch_size
total_gradient_steps = config.epochs_per_rollout_batch * batches_per_epoch

print("Configuration:")
print(f"  Questions per step: {n_questions}")
print(f"  Rollouts per question: {config.group_size}")
print(f"  Total rollouts per step: {config.rollout_batch_size}")
print(f"  Train batch size: {config.train_batch_size}")
print(f"  Batches per epoch: {batches_per_epoch}")
print(f"  Epochs per rollout batch: {config.epochs_per_rollout_batch}")
print(f"  Total gradient steps per GRPO step: {total_gradient_steps}")
print(f"\n  Mode: {'On-policy' if config.epochs_per_rollout_batch == 1 else 'Off-policy'}")

---
## 3. Dataset Loading (GSM8K)

As mentioned in the assignment PDF, the MATH dataset is not publicly available due to copyright issues, so we use the GSM8K dataset as an alternative training dataset.

### GSM8K Overview
GSM8K is a dataset of grade-school math word problems with numerical answers. It's ideal for testing reasoning and reinforcement learning because:
- Clear ground truth (numerical answers)
- Requires multi-step reasoning
- Automatic reward evaluation

### Format
Each example has:
- `question`: Math word problem
- `answer`: Solution with final answer marked as `#### 42`

In [None]:
def load_gsm8k_dataset():
    """Load GSM8K from HuggingFace."""
    print("Loading GSM8K dataset...")
    dataset = datasets.load_dataset("openai/gsm8k", "main")
    assert isinstance(dataset, datasets.DatasetDict)
    train_data = dataset["train"]
    test_data = dataset["test"]
    print(f"  Train: {len(train_data)} examples")
    print(f"  Test: {len(test_data)} examples")
    return train_data, test_data

train_dataset, test_dataset = load_gsm8k_dataset()

# Show example
print("\nExample question:")
print(train_dataset[0]["question"])
print("\nGround truth answer:")
print(train_dataset[0]["answer"])

---
## 4. Reward Function

The reward function evaluates how good a model's response is. For GSM8K:
- **Format reward** (20%): Is the answer properly formatted?
- **Answer reward** (80%): Is the numerical answer correct?

We try multiple methods to extract the answer from the model's response, which is a combination of
reasoning content and the final answer:
1. `<answer>...</answer>` tags (R1-Zero format)
2. `\boxed{...}` (LaTeX format)
3. Last number in response (fallback)

The reward function is provided below. You can skip this section, since it is not directly
related to learning how GRPO works.

In [None]:
def gsm8k_reward_fn(response: str, ground_truth: str) -> dict[str, float]:
    """
    Compute reward for GSM8K responses.

    Returns:
        dict with keys: "reward", "format_reward", "answer_reward"
    """
    # Extract ground truth answer (after ####)
    gt_match = re.search(r"####\s*(.+)$", ground_truth.strip())
    if not gt_match:
        return {"reward": 0.0, "format_reward": 0.0, "answer_reward": 0.0}

    gt_answer = gt_match.group(1).strip()

    # Try multiple extraction methods
    # 1. <answer>...</answer> tags
    answer_match = re.search(r"<answer>(.*?)</answer>", response, re.DOTALL)
    if answer_match:
        model_answer = answer_match.group(1).strip()
        format_reward = 1.0
    # 2. \\boxed{}
    elif "\\boxed{" in response:
        boxed_match = re.search(r"\\boxed\{([^}]+)\}", response)
        model_answer = boxed_match.group(1).strip() if boxed_match else ""
        format_reward = 1.0 if boxed_match else 0.0
    else:
        # Fallback: last number
        numbers = re.findall(r"-?\d+\.?\d*", response)
        model_answer = numbers[-1] if numbers else ""
        format_reward = 0.5 if numbers else 0.0

    # Grade numerical answer
    try:
        model_num = float(model_answer.replace(",", ""))
        gt_num = float(gt_answer.replace(",", ""))
        answer_reward = 1.0 if abs(model_num - gt_num) < 1e-4 else 0.0
    except ValueError:
        # String comparison fallback
        answer_reward = 1.0 if model_answer == gt_answer else 0.0

    # Combined reward (format 20%, answer 80%)
    reward = format_reward * 0.2 + answer_reward * 0.8

    return {
        "reward": reward,
        "format_reward": format_reward,
        "answer_reward": answer_reward,
    }

# Test reward function
print("Testing reward function:")
print(gsm8k_reward_fn("<answer>42</answer>", "#### 42"))  # Perfect
print(gsm8k_reward_fn("The answer is 42", "#### 42"))    # Partial (format)
print(gsm8k_reward_fn("<answer>43</answer>", "#### 42")) # Wrong answer

---
## 5. Prompt Building

### Why R1-Zero Format?
We use the prompt format from DeepSeek R1 Zero, which is `<think></think><answer></answer>`. This format encourages:
- Explicit reasoning traces (thinking step)
- Separate final answer (makes extraction easier)

The template follows the one DeepSeek-R1-Zero was trained with using reinforcement learning.

In [None]:
def build_gsm8k_prompt(question: str, use_r1_zero_format: bool = True) -> str:
    """Build prompt for GSM8K question."""
    if use_r1_zero_format:
        return f"""A conversation between User and Assistant. The User asks a question, and the Assistant solves it. The Assistant first thinks about the reasoning process in the mind and then provides the User with the answer. The reasoning process is enclosed within <think> </think> and answer is enclosed within <answer> </answer> tags, respectively.

User: {question}
Assistant: <think>"""
    else:
        return f"""Solve the following math problem. Provide your final numerical answer inside \\boxed{{}}.

Question: {question}

Solution:"""

# Show example prompt
example_question = train_dataset[0]["question"]
print("Example prompt:")
print(build_gsm8k_prompt(example_question)[:300] + "...")

---
## 6. Training Setup

Now let's set up the training configuration, which includes:

- Pick the base model.
- Load the tokenizer.
- Apply LoRA to the model for efficient training.
- Configure the optimizer.
- Configure the sampling parameters.

In [None]:
print("Setting up Tinker clients...")
service_client = tinker.ServiceClient(base_url=config.base_url)

# Check for existing checkpoint to resume
resume_info = checkpoint_utils.get_last_checkpoint(config.log_path)
if resume_info:
    training_client = service_client.create_training_client_from_state(
        resume_info["state_path"]
    )
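    # The checkpoint record may store the step at the top level or under "loop_state"
    start_step = resume_info.get("grpo_step", resume_info.get("loop_state", {}).get("grpo_step", 0))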
    print(f"  Resuming from GRPO step {start_step}")
else:
    # Create fresh training client with LoRA
    training_client = service_client.create_lora_training_client(
        base_model=config.model_name,
        rank=config.lora_rank
    )
    start_step = 0
    print(f"  Starting fresh training with model: {config.model_name}")

# Setup optimizer (AdamW)
adam_params = tinker.types.AdamParams(
    learning_rate=config.learning_rate,
    beta1=0.9,
    beta2=0.95,
    eps=1e-8
)

# Sampling configuration
stop_sequences = ["</answer>", "\n\n\n"] if config.use_r1_zero_format else ["\n\n\n"]
sampling_params = tinker.types.SamplingParams(
    max_tokens=config.sampling_max_tokens,
    temperature=config.sampling_temperature,
    stop=stop_sequences,
)

# Get tokenizer for encoding/decoding
print("  Loading tokenizer...")
tokenizer = training_client.get_tokenizer()

print("✓ Tinker setup complete")

---
## 7. Main GRPO Training Loop

### High-Level Flow

```
FOR each GRPO step:
  1. Sample questions from dataset
  2. Save current policy weights → create sampling_client
  3. Generate rollouts (group_size per question) [CLOUD]
  4. Compute rewards (local)
  5. Compute group-normalized advantages (local)
  6. Build training datums with old_log_probs
  7. Off-policy training loop:
       FOR each epoch:
         FOR each batch:
           - Tinker computes NEW log_probs from current policy
           - Compares vs old_log_probs → compute ratio
           - GRPO-Clip loss with advantages
           - Gradient step
  8. Log metrics, save checkpoint, validate
```

### Key Difference: On-Policy vs Off-Policy

**On-Policy** (`epochs_per_rollout_batch=1`):
- Single pass through rollout batch
- Policy πθ ≈ πθ_old (minimal divergence)
- Stable, but less sample-efficient: every update needs fresh rollouts, so training is slower

**Off-Policy** (`epochs_per_rollout_batch>1`):
- Multiple passes through rollout batch
- Policy πθ diverges from πθ_old
- Ratio πθ/πθ_old with clipping prevents instability
- More gradient steps per sample → faster training
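
The number of gradient steps taken on one rollout batch follows directly from these settings. A small sketch of the arithmetic, where `gradient_steps_per_rollout_batch` is a hypothetical helper:

```python
def gradient_steps_per_rollout_batch(
    rollout_batch_size: int, train_batch_size: int, epochs_per_rollout_batch: int
) -> int:
    """Number of optimizer steps taken on one batch of rollouts."""
    if rollout_batch_size % train_batch_size != 0:
        raise ValueError("rollout_batch_size must be divisible by train_batch_size")
    batches_per_epoch = rollout_batch_size // train_batch_size
    return epochs_per_rollout_batch * batches_per_epoch


strict_on_policy = gradient_steps_per_rollout_batch(256, 256, 1)
default_config = gradient_steps_per_rollout_batch(256, 64, 1)
off_policy = gradient_steps_per_rollout_batch(256, 64, 4)
print(strict_on_policy, default_config, off_policy)  # 1 4 16
```

Note that with `epochs_per_rollout_batch=1` and `train_batch_size=64`, four gradient steps are taken per rollout batch, so only the first of them is strictly on-policy. Setting `train_batch_size` equal to `rollout_batch_size` gives exactly one step.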

In [None]:
# Main training loop setup
n_questions_per_step = config.rollout_batch_size // config.group_size

print(f"Starting GRPO training for {config.n_grpo_steps} steps...")
print(f"  Rollout batch size: {config.rollout_batch_size} ({n_questions_per_step} questions × {config.group_size} samples)")
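batches_per_epoch = config.rollout_batch_size // config.train_batch_size
if config.epochs_per_rollout_batch == 1:
    print(f"  Mode: On-policy (1 epoch, {batches_per_epoch} gradient step(s) per rollout batch)")
else:
    print(f"  Mode: Off-policy ({config.epochs_per_rollout_batch} epochs per rollout batch)")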
print()

### 7.1 Sample Questions from Dataset

We cycle through the training dataset, shuffling when we wrap around.
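
The cycling idea can also be sketched independently of the dataset object. The variant below keeps an explicit cursor over a list of indices and reshuffles whenever the next batch would run past the end, so every step gets a full batch even when the dataset size is not a multiple of the batch size (the leftover examples at the end of an epoch are dropped). `next_question_indices` is a hypothetical helper, not part of the notebook's code.

```python
import random


def next_question_indices(
    order: list[int], cursor: int, batch_size: int, rng: random.Random
) -> tuple[list[int], int]:
    """Return (indices for this step, new cursor); reshuffles `order` in place on wrap-around."""
    if cursor + batch_size > len(order):
        rng.shuffle(order)  # start a new epoch with a fresh order
        cursor = 0
    return order[cursor : cursor + batch_size], cursor + batch_size


rng = random.Random(0)
order = list(range(10))  # stand-in for indices into the training set
cursor = 0
batches = []
for _ in range(4):
    batch, cursor = next_question_indices(order, cursor, 4, rng)
    batches.append(batch)
print(batches)
```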

In [None]:
# This will be inside the main loop - shown here as example for one step
grpo_step = start_step
t_start = time.time()

# Sample batch of questions
batch_start = (grpo_step * n_questions_per_step) % len(train_dataset)
if batch_start == 0 and grpo_step > 0:
    train_dataset = train_dataset.shuffle()
    print("  Shuffled dataset for new epoch")

batch_end = min(batch_start + n_questions_per_step, len(train_dataset))
batch_rows = train_dataset.select(range(batch_start, batch_end))
actual_n_questions = len(batch_rows)

print(f"Step {grpo_step}/{config.n_grpo_steps}: Processing {actual_n_questions} questions ({actual_n_questions * config.group_size} rollouts)")

### 7.2 Save Policy & Create Sampling Client

### Why?
For **on-policy** GRPO, we need rollouts from the current policy. By saving weights and creating a sampling client, we ensure:
- Rollouts come from πθ (not an outdated policy)
- Log probabilities are correctly captured

### How Tinker Handles This:
```python
# 1. Save current policy weights to cloud
sampling_path = training_client.save_weights_for_sampler().result().path
# Returns: tinker://uuid/weights/step_000001

# 2. Create sampling client from saved weights
sampling_client = service_client.create_sampling_client(model_path=sampling_path)
```

In [None]:
t_save_start = time.time()

# Save current policy for sampling (uploads to Tinker cloud)
sampling_path = (
    training_client.save_weights_for_sampler(name=f"step_{grpo_step:06d}")
    .result()
    .path
)
print(f"  Saved weights: {sampling_path}")

# Create sampling client from saved weights
sampling_client = service_client.create_sampling_client(model_path=sampling_path)

print(f"  Save weights time: {time.time() - t_save_start:.2f}s")

### 7.3 Generate Rollouts

We sample multiple responses per question (group_size=8) for:
- **Group normalization**: Compare responses to the same question
- **Variance reduction**: More stable advantage estimates

### How Tinker Handles This:
Tinker's `sample()` with `num_samples=8` distributes sampling across nodes:
```python
future = sampling_client.sample(
    prompt=tinker.types.ModelInput.from_ints(prompt_tokens),
    num_samples=8,  # Sample 8 times in parallel!
    sampling_params=sampling_params,
)
```

Returns `sequence.tokens` and `sequence.logprobs` for each sample.

In [None]:
t_sample_start = time.time()

# Prepare prompts and launch sampling
batch_futures = []
batch_prompts = []
batch_answers = []

for question, answer in zip(batch_rows["question"], batch_rows["answer"]):
    prompt = build_gsm8k_prompt(question, config.use_r1_zero_format)
    prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True)
    batch_prompts.append(prompt_tokens)
    batch_answers.append(answer)

    # Sample group_size responses per question (distributed across Tinker nodes)
    future = sampling_client.sample(
        prompt=tinker.types.ModelInput.from_ints(prompt_tokens),
        num_samples=config.group_size,  # 8 samples in parallel
        sampling_params=sampling_params,
    )
    batch_futures.append(future)

print(f"  Launched {len(batch_futures)} sampling jobs ({config.group_size} samples each)")

# Collect rollout results
rollout_responses = []
rollout_ground_truths = []
rollout_prompt_tokens = []
rollout_response_tokens = []
rollout_logprobs = []  # OLD log probs from sampling (frozen reference)

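for prompt_tokens, answer, future in zip(batch_prompts, batch_answers, batch_futures):
    sample_result = future.result()
    sequences = sample_result.sequences

    # Advantages are normalized per group, so each question must keep exactly
    # group_size samples. Drop the whole group if any sample is unusable.
    if len(sequences) != config.group_size or any(s.logprobs is None for s in sequences):
        print("    Warning: Incomplete group or missing logprobs! Skipping question.")
        continue

    # Process each sample in the group
    for sequence in sequences:
        sampled_tokens = sequence.tokens
        sampled_logprobs = sequence.logprobs  # These are OLD log probs (πθ_old)!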

        # Decode response
        response_text = tokenizer.decode(sampled_tokens, skip_special_tokens=False)

        rollout_responses.append(response_text)
        rollout_ground_truths.append(answer)
        rollout_prompt_tokens.append(prompt_tokens)
        rollout_response_tokens.append(sampled_tokens)
        rollout_logprobs.append(sampled_logprobs)  # Store for off-policy training

n_rollouts = len(rollout_responses)
print(f"  Generated {n_rollouts} rollouts in {time.time() - t_sample_start:.2f}s")
print(f"  Example response: {rollout_responses[0][:200]}...")

### 7.4 Compute Rewards & Advantages

### Why Group Normalization?
GRPO normalizes rewards **within each group** (responses to the same question):
```
advantage = (reward - mean(group_rewards)) / (std(group_rewards) + eps)
```

Note: With Tinker, using the `"ppo"` loss with advantages obtained by subtracting the group mean reward is equivalent to the GRPO loss.


Benefits:
- **Relative comparison**: "This response is better than average for this question"
- **Variance reduction**: Less sensitive to reward scale
- **Credit assignment**: Rewards differences, not absolute values

In [None]:
t_reward_start = time.time()

# Compute rewards and group-normalized advantages
# Using the group normalization function defined earlier
advantages, raw_rewards, reward_metadata = run_compute_group_normalized_rewards(
    reward_fn=gsm8k_reward_fn,
    rollout_responses=rollout_responses,
    repeated_ground_truths=rollout_ground_truths,
    group_size=config.group_size,
    advantage_eps=config.advantage_eps,
    normalize_by_std=config.use_std_normalization,
)

print(f"  Computed rewards in {time.time() - t_reward_start:.2f}s")
print("  Reward stats:")
print(f"    Mean: {reward_metadata['mean_reward']:.3f}")
print(f"    Std:  {reward_metadata['std_reward']:.3f}")
print(f"    Min:  {reward_metadata['min_reward']:.3f}")
print(f"    Max:  {reward_metadata['max_reward']:.3f}")

### 7.5 Build Training Datums

We construct `tinker.types.Datum` objects that contain:
1. **`model_input`**: Input tokens for forward pass
2. **`loss_fn_inputs`**:
   - `target_tokens`: What the model should predict
   - `advantages`: RL signal (positive=reinforce, negative=discourage)
   - `logprobs`: **OLD log probs** from sampling (frozen reference πθ_old)

### Causal LM Shift
```python
full_tokens = prompt + response  # [1, 2, 3, 4, 5]
input_tokens = full_tokens[:-1]  # [1, 2, 3, 4] - predict next token
target_tokens = full_tokens[1:]  # [2, 3, 4, 5] - what to predict
```

### Advantage Masking
Only compute loss on **response tokens** (not prompt):
```python
all_advantages = [0.0] * (prompt_len - 1) + [advantage] * len(response_tokens)
```

### Off-Policy Key: OLD Log Probs
The `logprobs` in `loss_fn_inputs` are from **sampling** (πθ_old). These are **frozen** and reused across multiple gradient steps.
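
The shift and the masking can be checked on a toy example before touching real tokens. The sketch below uses made-up token ids and log probs, and `build_aligned_sequences` is a hypothetical helper that mirrors the steps described above without creating a `tinker.types.Datum`.

```python
def build_aligned_sequences(
    prompt_tokens: list[int],
    response_tokens: list[int],
    response_logprobs: list[float],
    advantage: float,
) -> dict[str, list]:
    """Build the four equally long per-position sequences for one rollout."""
    full_tokens = prompt_tokens + response_tokens
    # After the shift, the first (prompt_len - 1) targets are still prompt tokens.
    n_prompt_targets = len(prompt_tokens) - 1
    return {
        "input_tokens": full_tokens[:-1],
        "target_tokens": full_tokens[1:],
        "advantages": [0.0] * n_prompt_targets + [advantage] * len(response_tokens),
        "logprobs": [0.0] * n_prompt_targets + list(response_logprobs),
    }


# Toy rollout: 3 prompt tokens, 2 response tokens.
toy = build_aligned_sequences([1, 2, 3], [4, 5], [-0.5, -1.2], advantage=0.8)
for name, values in toy.items():
    print(f"{name:14s} {values}")
```

Position 2 is the first one with a non-zero advantage: its input is the last prompt token (3) and its target is the first response token (4).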

In [None]:
t_build_data_start = time.time()
training_datums = []
skipped_samples = 0

for idx, (prompt_tokens, response_tokens, old_logprobs, advantage) in enumerate(
    zip(
        rollout_prompt_tokens,
        rollout_response_tokens,
        rollout_logprobs,  # OLD log probs from sampling (πθ_old)
        advantages,
    )
):
    if len(response_tokens) == 0:
        skipped_samples += 1
        continue

    # Construct full sequence
    full_tokens = prompt_tokens + response_tokens
    input_tokens = full_tokens[:-1]  # Shift for causal LM
    target_tokens = full_tokens[1:]

    prompt_len = len(prompt_tokens)

    # Advantage masking: 0 for prompt, advantage for response
    all_advantages = [0.0] * (prompt_len - 1) + [advantage.item()] * len(response_tokens)

    # OLD log probs: 0 for prompt, actual logprobs for response
    all_logprobs = [0.0] * (prompt_len - 1) + old_logprobs

    # Validate lengths
    if not (len(input_tokens) == len(target_tokens) == len(all_advantages) == len(all_logprobs)):
        print(f"    Warning: Length mismatch at sample {idx}, skipping")
        skipped_samples += 1
        continue

    # Create Datum for PPO loss
    datum = tinker.types.Datum(
        model_input=tinker.types.ModelInput.from_ints(tokens=input_tokens),
        loss_fn_inputs={
            "target_tokens": TensorData.from_torch(torch.tensor(target_tokens)),
            "advantages": TensorData.from_torch(torch.tensor(all_advantages)),
            "logprobs": TensorData.from_torch(torch.tensor(all_logprobs)),  # OLD log probs!
        },
    )
    training_datums.append(datum)

n_training_samples = len(training_datums)
print(f"  Built {n_training_samples} training datums in {time.time() - t_build_data_start:.2f}s")
print(f"  Skipped {skipped_samples} samples")

if n_training_samples == 0:
    print("  ERROR: No valid training samples! Skipping this step.")

### 7.6 Off-Policy Training Loop

### The Core of Off-Policy GRPO!

**What Tinker Does Internally (Each Gradient Step)**:
1. Forward pass through CURRENT policy πθ with `model_input` tokens
2. Compute NEW log probs from current model outputs
3. Compare with OLD log probs from `loss_fn_inputs["logprobs"]`
4. Compute ratio: `ratio = exp(new_logprobs - old_logprobs) = πθ / πθ_old`
5. Apply clipping: `clipped_ratio = clip(ratio, 0.8, 1.2)`
6. GRPO-Clip loss: `-min(advantages × ratio, advantages × clipped_ratio)`
7. Backward pass

### Why This Works Off-Policy:

After the first gradient step, πθ ≠ πθ_old, but we keep using the same frozen `old_logprobs`!

The **ratio** πθ/πθ_old measures policy divergence:
- ratio ≈ 1.0: Policy hasn't changed much (safe)
- ratio >> 1.0: The current policy makes the token much more likely than πθ_old did (clipped to 1.2!)
- ratio << 1.0: The current policy makes the token much less likely than πθ_old did (clipped to 0.8!)

**Clipping** prevents excessive policy updates.
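
The ratio and clipping steps can be written out for a single token in plain Python. This is a sketch of the formula in steps 4 to 6 above, not Tinker's implementation, whose `"ppo"` loss may differ in details such as how per-token losses are aggregated. `grpo_clip_token_loss` is a hypothetical helper and the probabilities are made up.

```python
import math


def grpo_clip_token_loss(
    new_logprob: float,
    old_logprob: float,
    advantage: float,
    clip_low: float = 0.8,
    clip_high: float = 1.2,
) -> float:
    """Per-token loss: -min(A * ratio, A * clip(ratio, clip_low, clip_high))."""
    ratio = math.exp(new_logprob - old_logprob)
    clipped_ratio = min(max(ratio, clip_low), clip_high)
    return -min(advantage * ratio, advantage * clipped_ratio)


old = math.log(0.5)
# Policy unchanged: ratio is 1, loss is just -advantage.
unchanged = grpo_clip_token_loss(old, old, advantage=1.0)
# Good token already pushed up a lot: ratio is about 1.8, capped at 1.2.
pushed_up = grpo_clip_token_loss(math.log(0.9), old, advantage=1.0)
# Bad token already pushed down a lot: ratio is about 0.2, capped at 0.8.
pushed_down = grpo_clip_token_loss(math.log(0.1), old, advantage=-1.0)
# Good token that became less likely: the min picks the unclipped term.
not_clipped = grpo_clip_token_loss(math.log(0.1), old, advantage=1.0)
print(unchanged, pushed_up, pushed_down, not_clipped)
```

The last case shows that clipping is one-sided: it only caps the objective when the policy has already moved in the direction the advantage favors, so a move in the wrong direction is still fully penalized.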

In [None]:
t_train_start = time.time()

# Configure PPO clipping thresholds
clip_low = 1.0 - config.cliprange   # 0.8
clip_high = 1.0 + config.cliprange  # 1.2

# Calculate training schedule
n_batches_per_epoch = len(training_datums) // config.train_batch_size
n_total_updates = config.epochs_per_rollout_batch * n_batches_per_epoch

print(f"  Training: {config.epochs_per_rollout_batch} epochs, {n_batches_per_epoch} batches per epoch, {n_total_updates} total gradient steps")

# Outer loop: epochs (Off-policy: multiple passes through data)
for epoch in range(config.epochs_per_rollout_batch):
    # Shuffle datums each epoch to prevent overfitting
    shuffled_datums = training_datums.copy()
    random.shuffle(shuffled_datums)

    # Inner loop: batches (Split rollouts into smaller batches)
    for batch_idx in range(n_batches_per_epoch):
        # Extract batch of datums
        start_idx = batch_idx * config.train_batch_size
        end_idx = start_idx + config.train_batch_size
        batch_datums = shuffled_datums[start_idx:end_idx]

        # Tinker's forward_backward does:
        # 1. Forward pass through CURRENT policy → NEW log probs
        # 2. Compare NEW vs OLD log probs from loss_fn_inputs["logprobs"]
        # 3. Compute ratio = exp(new_logprobs - old_logprobs)
        # 4. Apply GRPO-Clip loss with advantages
        # 5. Backward pass
        fwd_bwd_future = training_client.forward_backward(
            batch_datums,  # Contains OLD log probs from sampling!
            loss_fn="ppo",  # GRPO-Clip (PPO-style clipping)
            loss_fn_config={
                "clip_low_threshold": clip_low,
                "clip_high_threshold": clip_high,
            },
        )

        # Optimizer step (updates policy parameters)
        optim_step_future = training_client.optim_step(adam_params)

        # Wait for completion
        fwd_bwd_result = fwd_bwd_future.result()
        optim_result = optim_step_future.result()

        if (batch_idx + 1) % max(1, n_batches_per_epoch // 2) == 0:
            print(f"    Epoch {epoch+1}/{config.epochs_per_rollout_batch}, Batch {batch_idx+1}/{n_batches_per_epoch}")

print(f"  Training completed in {time.time() - t_train_start:.2f}s")
print(f"  Total gradient steps: {n_total_updates}")

### 7.7 Validation

We need to periodically evaluate on held-out test examples to check how training is going, because
a converged loss does not necessarily imply a good model.

### Key: Greedy Sampling
Use `temperature=0.0` for deterministic evaluation (no randomness).

In [None]:
def validate_on_test_set(
    sampling_client,
    test_dataset,
    tokenizer,
    use_r1_zero_format: bool,
    n_samples: int = 40,
) -> dict[str, float]:
    """Evaluate model on GSM8K test set."""
    print(f"  Running validation on {n_samples} test examples...")
    test_sample = test_dataset.shuffle().select(range(min(n_samples, len(test_dataset))))

    correct = 0
    format_correct = 0

    for question, answer in zip(test_sample["question"], test_sample["answer"]):
        prompt = build_gsm8k_prompt(question, use_r1_zero_format)

        try:
            prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True)
            response = sampling_client.sample(
                prompt=tinker.types.ModelInput.from_ints(prompt_tokens),
                num_samples=1,
                sampling_params=tinker.types.SamplingParams(
                    max_tokens=2048,
                    temperature=0.0,  # Greedy for evaluation
                    stop=["</answer>", "\n\n\n"] if use_r1_zero_format else ["\n\n\n"],
                ),
            ).result()

            response_tokens = response.sequences[0].tokens
            response_text = tokenizer.decode(response_tokens, skip_special_tokens=False)
            reward_dict = gsm8k_reward_fn(response_text, answer)

            if reward_dict["answer_reward"] == 1.0:
                correct += 1
            if reward_dict["format_reward"] == 1.0:
                format_correct += 1

        except Exception as e:
            print(f"    Warning: Validation error: {e}")
            continue

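    # Divide by the number of examples actually selected (may be smaller than n_samples)
    n_total = max(1, len(test_sample))
    accuracy = correct / n_total
    format_accuracy = format_correct / n_total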

    print(f"  Validation - Accuracy: {accuracy:.1%}, Format: {format_accuracy:.1%}")

    return {
        "accuracy": accuracy,
        "format_accuracy": format_accuracy,
    }

# Run validation (example - would be inside training loop)
if grpo_step % config.eval_every == 0:
    try:
        val_metrics = validate_on_test_set(
            sampling_client,
            test_dataset,
            tokenizer,
            config.use_r1_zero_format,
            n_samples=40,
        )
        print(f"  Validation metrics: {val_metrics}")
    except Exception as e:
        print(f"  Validation failed: {e}")

### 7.8 Checkpointing

### Why?
Save training progress to:
- Resume if interrupted
- Recover best model
- Track training history

### How Tinker Handles This:

Tinker stores checkpoints in **cloud storage** (not local disk!):

**Locally** (`log_path/checkpoints.jsonl`):
```json
{
  "name": "grpo_step_000005",
  "grpo_step": 6,
  "state_path": "tinker://uuid/weights/grpo_step_000005"
}
```

**In Tinker cloud**: Actual model weights (GBs)

In [None]:
# Save checkpoint (example - would be inside training loop)
if (grpo_step + 1) % config.save_every == 0:
    print(f"  Saving checkpoint at step {grpo_step}...")
    checkpoint_utils.save_checkpoint(
        training_client=training_client,
        name=f"grpo_step_{grpo_step:06d}",
        log_path=config.log_path,
        kind="both",  # Save both training state and sampler weights
        loop_state={"grpo_step": grpo_step + 1},
    )
    print("  ✓ Checkpoint saved to Tinker cloud")

---
## 8. Summary: On-Policy vs Off-Policy GRPO

### On-Policy GRPO
**Config**: `epochs_per_rollout_batch=1`, `train_batch_size=rollout_batch_size`

**Flow**:
1. Sample 256 rollouts from current policy πθ
2. Compute advantages
3. **One** gradient step on all 256 rollouts
4. Policy updated: πθ → πθ'
5. Repeat (sample from πθ' next time)

**Pros**: Stable, because the policy being updated is the same policy that generated the rollouts  
**Cons**: Sample inefficient (need new rollouts every step)
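
Step 2 of the flow above ("Compute advantages") uses the group-relative formula from Key Concepts, `(reward - group_mean) / (group_std + eps)`. The following is a minimal pure-Python sketch of that formula, not the notebook's `run_compute_group_normalized_rewards`. It assumes rollouts are laid out in consecutive groups and uses the population standard deviation, which may differ from the notebook's implementation; `group_normalized_advantages` is a hypothetical name.

```python
from statistics import mean, pstdev


def group_normalized_advantages(rewards, group_size, eps=1e-6):
    """Normalize rewards within consecutive groups of `group_size`."""
    assert len(rewards) % group_size == 0, "rewards must form complete groups"
    advantages = []
    for start in range(0, len(rewards), group_size):
        group = rewards[start:start + group_size]
        mu = mean(group)
        sigma = pstdev(group)  # assumption: population std, not sample std
        advantages.extend((r - mu) / (sigma + eps) for r in group)
    return advantages


# Two questions with four rollouts each (the notebook uses group_size=8).
rewards = [1.0, 0.0, 0.0, 1.0,   # mixed group: non-zero advantages
           0.0, 0.0, 0.0, 0.0]   # all-wrong group: every advantage is 0
adv = group_normalized_advantages(rewards, group_size=4)
print(adv)
```

A group in which every rollout earns the same reward gets all-zero advantages, so it contributes no gradient signal for that question.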

### Off-Policy GRPO
**Config**: `epochs_per_rollout_batch=4`, `train_batch_size=64`

**Flow**:
1. Sample 256 rollouts from current policy πθ_old
2. Compute advantages, freeze old_log_probs
3. **16 gradient steps** (4 epochs × 4 batches; ratio values below are illustrative):
   - Epoch 1, Batch 1: πθ_old → πθ₁ (ratio πθ₁/πθ_old ≈ 1.0)
   - Epoch 1, Batch 2: πθ₁ → πθ₂ (ratio πθ₂/πθ_old ≈ 1.05)
   - ...
   - Epoch 4, Batch 4: πθ₁₅ → πθ₁₆ (ratio πθ₁₆/πθ_old ≈ 1.15, still inside the clip range; clipping only applies once a token's ratio leaves [0.8, 1.2])
4. Policy updated: πθ_old → πθ₁₆
5. Repeat (sample from πθ₁₆ next time)

**Pros**: Sample efficient (16 gradient steps per rollout batch)  
**Cons**: The policy drifts away from the sampling policy πθ_old (mitigated by clipping)
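
The step counts in the two configurations follow from simple arithmetic. The sketch below mirrors how the full training loop derives the number of updates, assuming every rollout produces a valid datum; the function name `updates_per_rollout_batch` is hypothetical.

```python
def updates_per_rollout_batch(rollout_batch_size, train_batch_size, epochs_per_rollout_batch):
    """Number of optimizer steps taken before new rollouts are sampled."""
    n_batches_per_epoch = rollout_batch_size // train_batch_size
    return epochs_per_rollout_batch * n_batches_per_epoch


on_policy = updates_per_rollout_batch(256, 256, 1)   # one pass, one batch
off_policy = updates_per_rollout_batch(256, 64, 4)   # 4 epochs x 4 batches
print(on_policy, off_policy)  # prints: 1 16
```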

### Key Insight

Off-policy works because:
- **Frozen old_log_probs**: Reference from sampling policy πθ_old
- **Fresh new_log_probs**: Computed at each gradient step from current πθ
- **Ratio clipping**: Removes the incentive to push the ratio πθ/πθ_old outside [0.8, 1.2] (that is, `cliprange=0.2`), which limits divergence

This allows us to take **multiple gradient steps** per rollout batch while staying reasonably close to the sampling distribution.
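
The three ingredients above combine into a per-token clipped objective. Below is a minimal sketch of the standard PPO clipped surrogate, written in plain Python for readability. It is an illustration of the idea, under the assumption that Tinker's `loss_fn="ppo"` follows the same form; the notebook itself does not show the internals, and `clipped_surrogate` is a hypothetical helper.

```python
import math


def clipped_surrogate(new_logprob, old_logprob, advantage, cliprange=0.2):
    """Per-token PPO objective (to be maximized)."""
    # Importance ratio pi_theta / pi_theta_old, computed from log-probabilities.
    ratio = math.exp(new_logprob - old_logprob)
    clipped_ratio = min(max(ratio, 1.0 - cliprange), 1.0 + cliprange)
    # Taking the minimum keeps the more pessimistic of the two estimates.
    return min(ratio * advantage, clipped_ratio * advantage)


old = math.log(0.10)  # frozen log-prob from the sampling policy

# First gradient step: current policy equals sampling policy, ratio = 1.
same = clipped_surrogate(old, old, advantage=1.0)

# Ratio 1.5 with a positive advantage: the gain is capped at 1.2 * A.
capped = clipped_surrogate(math.log(0.15), old, advantage=1.0)

# Ratio 0.5 with a negative advantage: the objective is held at 0.8 * A.
floored = clipped_surrogate(math.log(0.05), old, advantage=-1.0)
```

Once the ratio is clipped, the objective no longer depends on the new log-probability, so that token stops pushing the policy further from πθ_old.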

---
## 9. Full Training Loop

Now let's put it all together in a complete training loop!

In [None]:
# Full training loop
n_questions_per_step = config.rollout_batch_size // config.group_size

for grpo_step in range(start_step, config.n_grpo_steps):
    t_start = time.time()
    
    print(f"\n{'='*60}")
    print(f"GRPO Step {grpo_step}/{config.n_grpo_steps}")
    print(f"{'='*60}")
    
    # 1. Sample questions from dataset
    batch_start = (grpo_step * n_questions_per_step) % len(train_dataset)
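    # A start index below the batch size means the previous batch reached the
    # end of the dataset, so a new epoch begins. Checking `batch_start == 0`
    # misses the wrap whenever the dataset size is not a multiple of the batch size.
    if grpo_step > 0 and batch_start < n_questions_per_step: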
        train_dataset = train_dataset.shuffle()
        print("  Shuffled dataset for new epoch")
    
    batch_end = min(batch_start + n_questions_per_step, len(train_dataset))
    batch_rows = train_dataset.select(range(batch_start, batch_end))
    actual_n_questions = len(batch_rows)
    
    print(f"  Processing {actual_n_questions} questions ({actual_n_questions * config.group_size} rollouts)")
    
    # 2. Save policy & create sampling client
    t_save_start = time.time()
    sampling_path = (
        training_client.save_weights_for_sampler(name=f"step_{grpo_step:06d}")
        .result()
        .path
    )
    sampling_client = service_client.create_sampling_client(model_path=sampling_path)
    print(f"  Saved weights in {time.time() - t_save_start:.2f}s")
    
    # 3. Generate rollouts
    t_sample_start = time.time()
    batch_futures = []
    batch_prompts = []
    batch_answers = []
    
    for question, answer in zip(batch_rows["question"], batch_rows["answer"]):
        prompt = build_gsm8k_prompt(question, config.use_r1_zero_format)
        prompt_tokens = tokenizer.encode(prompt, add_special_tokens=True)
        batch_prompts.append(prompt_tokens)
        batch_answers.append(answer)
        
        future = sampling_client.sample(
            prompt=tinker.types.ModelInput.from_ints(prompt_tokens),
            num_samples=config.group_size,
            sampling_params=sampling_params,
        )
        batch_futures.append(future)
    
    # Collect rollouts
    rollout_responses = []
    rollout_ground_truths = []
    rollout_prompt_tokens = []
    rollout_response_tokens = []
    rollout_logprobs = []
    
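    for prompt_tokens, answer, future in zip(batch_prompts, batch_answers, batch_futures):
        sample_result = future.result()
        
        # Group normalization assumes exactly group_size consecutive rollouts per
        # question, so drop the whole group rather than a single sample.
        sequences = sample_result.sequences
        if len(sequences) != config.group_size or any(s.logprobs is None for s in sequences):
            print("  Warning: incomplete group (missing samples or logprobs), skipping question")
            continue
        
        for sequence in sequences:
            sampled_tokens = sequence.tokens
            sampled_logprobs = sequence.logprobs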
            
            response_text = tokenizer.decode(sampled_tokens, skip_special_tokens=False)
            
            rollout_responses.append(response_text)
            rollout_ground_truths.append(answer)
            rollout_prompt_tokens.append(prompt_tokens)
            rollout_response_tokens.append(sampled_tokens)
            rollout_logprobs.append(sampled_logprobs)
    
    n_rollouts = len(rollout_responses)
    print(f"  Generated {n_rollouts} rollouts in {time.time() - t_sample_start:.2f}s")
    
    # 4. Compute rewards & advantages
    t_reward_start = time.time()
    advantages, raw_rewards, reward_metadata = run_compute_group_normalized_rewards(
        reward_fn=gsm8k_reward_fn,
        rollout_responses=rollout_responses,
        repeated_ground_truths=rollout_ground_truths,
        group_size=config.group_size,
        advantage_eps=config.advantage_eps,
        normalize_by_std=config.use_std_normalization,
    )
    print(f"  Rewards: mean={reward_metadata['mean_reward']:.3f}, std={reward_metadata['std_reward']:.3f}")
    
    # 5. Build training datums
    t_build_start = time.time()
    training_datums = []
    skipped = 0
    
    for prompt_tokens, response_tokens, old_logprobs, advantage in zip(
        rollout_prompt_tokens, rollout_response_tokens, rollout_logprobs, advantages
    ):
        if len(response_tokens) == 0:
            skipped += 1
            continue
        
        full_tokens = prompt_tokens + response_tokens
        input_tokens = full_tokens[:-1]
        target_tokens = full_tokens[1:]
        prompt_len = len(prompt_tokens)
        
        all_advantages = [0.0] * (prompt_len - 1) + [advantage.item()] * len(response_tokens)
        all_logprobs = [0.0] * (prompt_len - 1) + old_logprobs
        
        if not (len(input_tokens) == len(target_tokens) == len(all_advantages) == len(all_logprobs)):
            skipped += 1
            continue
        
        datum = tinker.types.Datum(
            model_input=tinker.types.ModelInput.from_ints(tokens=input_tokens),
            loss_fn_inputs={
                "target_tokens": TensorData.from_torch(torch.tensor(target_tokens)),
                "advantages": TensorData.from_torch(torch.tensor(all_advantages)),
                "logprobs": TensorData.from_torch(torch.tensor(all_logprobs)),
            },
        )
        training_datums.append(datum)
    
    print(f"  Built {len(training_datums)} datums ({skipped} skipped)")
    
    if len(training_datums) == 0:
        print("  ERROR: No valid training samples! Skipping this step.")
        continue
    
    # 6. Off-policy training loop
    t_train_start = time.time()
    clip_low = 1.0 - config.cliprange
    clip_high = 1.0 + config.cliprange
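    # Run at least one batch even when fewer than train_batch_size datums are
    # available (for example on a short final batch); otherwise this step would
    # silently perform zero updates.
    n_batches_per_epoch = max(1, len(training_datums) // config.train_batch_size)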
    n_total_updates = config.epochs_per_rollout_batch * n_batches_per_epoch
    
    for epoch in range(config.epochs_per_rollout_batch):
        shuffled_datums = training_datums.copy()
        random.shuffle(shuffled_datums)
        
        for batch_idx in range(n_batches_per_epoch):
            start_idx = batch_idx * config.train_batch_size
            end_idx = start_idx + config.train_batch_size
            batch_datums = shuffled_datums[start_idx:end_idx]
            
            fwd_bwd_future = training_client.forward_backward(
                batch_datums,
                loss_fn="ppo",
                loss_fn_config={
                    "clip_low_threshold": clip_low,
                    "clip_high_threshold": clip_high,
                },
            )
            optim_step_future = training_client.optim_step(adam_params)
            
            fwd_bwd_result = fwd_bwd_future.result()
            optim_result = optim_step_future.result()
    
    print(f"  Training: {n_total_updates} gradient steps in {time.time() - t_train_start:.2f}s")
    
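    # 7. Validation (uses the sampling client from step 2, i.e. the policy
    #    as it was before this step's gradient updates)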
    if grpo_step % config.eval_every == 0:
        try:
            val_metrics = validate_on_test_set(
                sampling_client, test_dataset, tokenizer, 
                config.use_r1_zero_format, n_samples=40
            )
            print(f"  Validation: accuracy={val_metrics['accuracy']:.1%}")
        except Exception as e:
            print(f"  Validation failed: {e}")
    
    # 8. Checkpointing
    if (grpo_step + 1) % config.save_every == 0:
        print("  Saving checkpoint...")
        checkpoint_utils.save_checkpoint(
            training_client=training_client,
            name=f"grpo_step_{grpo_step:06d}",
            log_path=config.log_path,
            kind="both",
            loop_state={"grpo_step": grpo_step + 1},
        )
    
    print(f"  Step {grpo_step} completed in {time.time() - t_start:.1f}s")

# Final checkpoint
print("\nTraining complete! Saving final checkpoint...")
checkpoint_utils.save_checkpoint(
    training_client=training_client,
    name="final",
    log_path=config.log_path,
    kind="both",
    loop_state={"grpo_step": config.n_grpo_steps},
)
print("✓ Training completed successfully!")
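
Step 5 of the loop is the easiest place to introduce an off-by-one error, so here is a small standalone sketch of the same alignment on toy data. It reproduces the shift-by-one and prompt masking from the loop in plain Python lists, without Tinker or PyTorch; the function name `align_datum_fields` and the token ids are made up for illustration.

```python
def align_datum_fields(prompt_tokens, response_tokens, old_logprobs, advantage):
    """Build per-position inputs, targets, advantages, and old log-probs."""
    full_tokens = prompt_tokens + response_tokens
    input_tokens = full_tokens[:-1]   # the model sees tokens 0..n-2
    target_tokens = full_tokens[1:]   # and predicts tokens 1..n-1
    prompt_len = len(prompt_tokens)

    # Targets that are still prompt tokens get zero advantage (no learning signal).
    # There are prompt_len - 1 of them because of the shift by one.
    advantages = [0.0] * (prompt_len - 1) + [advantage] * len(response_tokens)
    logprobs = [0.0] * (prompt_len - 1) + list(old_logprobs)

    assert len(input_tokens) == len(target_tokens) == len(advantages) == len(logprobs)
    return input_tokens, target_tokens, advantages, logprobs


inputs, targets, advs, lps = align_datum_fields(
    prompt_tokens=[11, 12, 13],
    response_tokens=[21, 22],
    old_logprobs=[-0.1, -0.2],
    advantage=0.5,
)
for row in zip(inputs, targets, advs, lps):
    print(row)
```

Each printed row pairs an input token with the target it should predict, so the first response token `21` appears as the target of the last prompt token `13` and carries the first sampled log-probability.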