# PPO Canary: RL-Specific Training Validation

This notebook demonstrates how to use RLHF Canary for PPO (Proximal Policy Optimization) training validation.

**What you'll learn:**
1. How PPO training differs from DPO/SFT
2. Running PPO canaries with synthetic rewards
3. Understanding PPO-specific metrics (KL, entropy, clip fraction)
4. Memory considerations: PPO uses ~3x memory vs SFT
5. PPO failure modes and how to detect them

**Requirements:** GPU runtime (Runtime > Change runtime type > T4 GPU)

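**Runtime:** ~15 minutes for the baseline PPO smoke run; the DPO comparison and the two extra PPO runs in Sections 5-7 take additional time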

## 1. Setup

In [None]:
import os
import re
import sys

print("Starting Environment Setup...")

# --- 1. Clone the repo first ---
if not os.path.exists("/content/rlhf-canary"):
    !git clone https://github.com/mmcmanus1/rlhf-canary.git /content/rlhf-canary

%cd /content/rlhf-canary

# --- 2. Force-Install the "Safe Harbor" Stack ---
!pip install "trl==0.11.4" "transformers==4.44.2" "peft==0.12.0" "accelerate==0.34.2" "tokenizers==0.19.1" --force-reinstall --no-deps --quiet
!pip install -q datasets pydantic click PyYAML bitsandbytes
print("Libraries installed (TRL 0.11.4 / Transformers 4.44.2)")

# --- 3. Patch pyproject.toml (Prevent future drift) ---
project_file = "/content/rlhf-canary/pyproject.toml"
if os.path.exists(project_file):
    with open(project_file, "r") as f:
        content = f.read()
    
    if "trl==0.11.4" not in content:
        content = re.sub(r'trl[<>=!~]+[\d\.]+', 'trl==0.11.4', content)
        with open(project_file, "w") as f:
            f.write(content)
        print("Config file patched to lock TRL 0.11.4")

# --- 4. Patch Source Code (Compatibility Fix) ---
runner_file = "/content/rlhf-canary/canary/runner/local.py"
if os.path.exists(runner_file):
    with open(runner_file, "r") as f:
        code = f.read()
    
    if "processing_class=" in code:
        code = code.replace("processing_class=", "tokenizer=")
        with open(runner_file, "w") as f:
            f.write(code)
        print("Code patched: Reverted 'processing_class' to 'tokenizer'")
    else:
        print("Code is already compatible.")

# --- 5. Install the package ---
!pip install -e . --quiet

print("Environment Ready!")

In [None]:
# Verify GPU and installation
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"Memory: {mem_gb:.1f} GB")
    
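    # PPO holds a policy, a reference model, and a value head at the same time
    print("\nNOTE: PPO uses ~3x memory vs SFT (model + ref_model + value_head)")
    if mem_gb < 12:
        print("WARNING: less than 12 GB of GPU memory; reduce batch_size if you hit OOM.")
    else:
        print(f"{mem_gb:.1f} GB should be enough for PPO with pythia-70m (a 16GB T4 works fine).")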

import canary
print(f"Canary module loaded from: {canary.__file__}")
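
The setup cell pins five libraries, so it can be worth confirming that the running kernel actually sees those versions. The sketch below uses only the standard library; `PINNED` and `find_version_drift` are names introduced here for illustration and are not part of the canary package.

```python
from importlib import metadata

# Pins copied from the setup cell in this notebook.
PINNED = {
    "trl": "0.11.4",
    "transformers": "4.44.2",
    "peft": "0.12.0",
    "accelerate": "0.34.2",
    "tokenizers": "0.19.1",
}


def find_version_drift(pins, lookup=metadata.version):
    """Return {package: (expected, found)} for every pin that does not match."""
    drift = {}
    for package, expected in pins.items():
        try:
            found = lookup(package)
        except metadata.PackageNotFoundError:
            found = None  # the package is not installed at all
        if found != expected:
            drift[package] = (expected, found)
    return drift


drift = find_version_drift(PINNED)
print("All pins match" if not drift else f"Version drift: {drift}")
```

If drift is reported right after installation, restarting the runtime is usually the first thing to try, since already-imported modules keep their old version.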

## 2. Understanding PPO Training

PPO differs from SFT and DPO in its training loop, its memory footprint, and the metrics worth monitoring:

### PPO vs DPO/SFT

| Aspect | SFT/DPO | PPO |
|--------|---------|-----|
| Training loop | Forward + backward | Generate → Reward → Update |
| Memory usage | ~1-2x model | ~3x model |
| Components | Model (+ ref for DPO) | Model + ref + value head |
| Reward signal | Implicit (preferences) | Explicit (reward model) |
| Key metrics | Loss, KL | Policy loss, value loss, KL, entropy, clip fraction |

### PPO-Specific Metrics

| Metric | What it measures | Healthy range |
|--------|-----------------|---------------|
| `objective/kl` | Policy drift from reference | < target_kl (default: 6.0) |
| `objective/entropy` | Policy randomness | Decreasing slowly, not collapsing |
| `ppo/clipfrac` | Update clipping frequency | < 0.2 is normal |
| `ppo/policy_loss` | Actor loss | Stable, no NaN |
| `ppo/value_loss` | Critic loss | Stable, no NaN |
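
To make the first two rows of the table concrete, here is a minimal sketch of how a KL estimate and an entropy value can be computed from token probabilities. It assumes the common sample-based estimate (log-probability under the policy minus log-probability under the reference, summed over generated tokens); whether the logged `objective/kl` is computed exactly this way is an assumption, and the numbers are toy values rather than output from a real run.

```python
import math


def sequence_kl(policy_logprobs, ref_logprobs):
    """Sample-based KL estimate: sum over tokens of log pi(a) - log pi_ref(a)."""
    return sum(p - r for p, r in zip(policy_logprobs, ref_logprobs))


def entropy(probs):
    """Shannon entropy (in nats) of a single next-token distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)


# Toy log-probabilities of three sampled tokens under each model.
policy_lp = [-1.0, -0.5, -2.0]
ref_lp = [-1.2, -0.9, -2.1]

print(f"KL estimate: {sequence_kl(policy_lp, ref_lp):.2f}")
print(f"Entropy, uniform over 4 tokens: {entropy([0.25] * 4):.3f}")
print(f"Entropy, near-deterministic:    {entropy([0.97, 0.01, 0.01, 0.01]):.3f}")
```

A policy that assigns the sampled tokens higher probability than the reference does yields a positive estimate, and a distribution that concentrates on one token has entropy close to zero, which is the pattern the "collapsing" warning in the table refers to.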

### Memory Architecture

```
PPO Memory = Policy Model + Reference Model + Value Head
           ≈ 1x model     + 1x model        + small overhead
           ≈ 2x model weights, plus rollout and generation buffers
           ≈ 3x SFT memory (rule of thumb)
```
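
The weight portion of this budget is simple arithmetic. The sketch below counts weights only, assuming fp32 storage and a nominal 70M parameters for pythia-70m; gradients, optimizer state, activations, and generation buffers are deliberately left out, so measured peak memory will be higher. `weight_memory_mb` is a helper introduced here for illustration.

```python
def weight_memory_mb(n_params, bytes_per_param=4, n_copies=1):
    """Memory taken by model weights only, in MB (1 MB = 1e6 bytes)."""
    return n_params * bytes_per_param * n_copies / 1e6


N_PARAMS = 70_000_000  # nominal size of pythia-70m; an approximation

sft_weights = weight_memory_mb(N_PARAMS)              # one fp32 copy
ppo_weights = weight_memory_mb(N_PARAMS, n_copies=2)  # policy + reference

print(f"SFT weights: {sft_weights:.0f} MB")
print(f"PPO weights: {ppo_weights:.0f} MB (value head not included)")
```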

## 3. PPO Configuration Explained

In [None]:
# Show the PPO smoke config
!cat configs/ppo_smoke.yaml

In [None]:
# Explain key PPO parameters
print("PPO Configuration Parameters:")
print("="*60)

params = [
    ("ppo_epochs", "4", "PPO update epochs per batch (more = more aggressive)"),
    ("init_kl_coef", "0.2", "Initial KL penalty coefficient"),
    ("target_kl", "6.0", "Target KL for adaptive penalty"),
    ("cliprange", "0.2", "Policy ratio clipping range"),
    ("vf_coef", "0.1", "Value function loss weight"),
    ("max_new_tokens", "64", "Tokens to generate per prompt"),
    ("use_synthetic_reward", "true", "Use length-based rewards (for canary)"),
]

print(f"\n{'Parameter':<20} {'Default':>10} {'Description'}")
print("-"*70)
for name, default, desc in params:
    print(f"{name:<20} {default:>10} {desc}")
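
To show what `cliprange` controls, the following sketch computes the standard PPO clipped surrogate loss and the resulting clip fraction for a handful of tokens. It is written in plain Python for readability rather than as the trainer's actual implementation, and `clipped_policy_loss` is a name introduced here.

```python
import math


def clipped_policy_loss(logp_new, logp_old, advantages, cliprange=0.2):
    """Return (mean clipped surrogate loss, clip fraction) for a batch of tokens."""
    losses, n_clipped = [], 0
    for new, old, adv in zip(logp_new, logp_old, advantages):
        ratio = math.exp(new - old)  # pi_new(a) / pi_old(a)
        clipped_ratio = min(max(ratio, 1.0 - cliprange), 1.0 + cliprange)
        unclipped_loss = -adv * ratio
        clipped_loss = -adv * clipped_ratio
        # PPO takes the pessimistic (larger) of the two losses.
        losses.append(max(unclipped_loss, clipped_loss))
        # A token counts as clipped when clipping changed the loss.
        n_clipped += clipped_loss > unclipped_loss
    n = len(losses)
    return sum(losses) / n, n_clipped / n


# Toy batch: probability ratios of 1.5, 1.0, 0.5 and 1.1 against the old policy.
logp_old = [0.0, 0.0, 0.0, 0.0]
logp_new = [math.log(1.5), math.log(1.0), math.log(0.5), math.log(1.1)]
advantages = [1.0, 1.0, -1.0, -1.0]

loss, clipfrac = clipped_policy_loss(logp_new, logp_old, advantages)
print(f"policy loss: {loss:.3f}, clip fraction: {clipfrac:.2f}")
```

A larger learning rate or more `ppo_epochs` pushes ratios further from 1, so more tokens hit the clip boundary and the clip fraction rises.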

## 4. Run PPO Baseline Canary

In [None]:
# Run PPO smoke test (~10-15 min)
!python -m canary.cli run configs/ppo_smoke.yaml -o ./ppo_output/baseline

In [None]:
import json
from pathlib import Path

# Load and display PPO metrics
baseline_paths = list(Path('./ppo_output/baseline').rglob('metrics.json'))
if not baseline_paths:
    raise FileNotFoundError("No metrics.json found for PPO baseline. Did the training complete?")

baseline_path = baseline_paths[0]
with open(baseline_path) as f:
    ppo_metrics = json.load(f)

print("="*60)
print("PPO BASELINE METRICS")
print("="*60)

print(f"\nPerformance:")
print(f"  Step time (mean): {ppo_metrics['perf']['step_time']['mean']:.4f}s")
print(f"  Tokens/sec: {ppo_metrics['perf']['approx_tokens_per_sec']:.0f}")
print(f"  Peak memory: {ppo_metrics['perf']['max_mem_mb']:.0f}MB")

print(f"\nStability:")
print(f"  NaN steps: {ppo_metrics['stability']['nan_steps']}")
print(f"  Inf steps: {ppo_metrics['stability']['inf_steps']}")
print(f"  Loss diverged: {ppo_metrics['stability']['loss_diverged']}")
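
The stability fields printed above can also be turned into a pass/fail gate. The sketch below relies only on the three keys shown in that cell; `stability_failures` is a hypothetical helper, and the sample dictionary is made up for illustration rather than taken from a real run.

```python
def stability_failures(metrics):
    """Return human-readable reasons a run should fail a stability gate."""
    stability = metrics["stability"]
    failures = []
    if stability["nan_steps"]:
        failures.append(f"nan_steps = {stability['nan_steps']}")
    if stability["inf_steps"]:
        failures.append(f"inf_steps = {stability['inf_steps']}")
    if stability["loss_diverged"]:
        failures.append("loss diverged")
    return failures


# Made-up example in the same shape as the stability section of metrics.json.
sample = {"stability": {"nan_steps": 0, "inf_steps": 2, "loss_diverged": False}}

problems = stability_failures(sample)
print("PASS" if not problems else "FAIL: " + "; ".join(problems))
```

In the notebook, passing `ppo_metrics` instead of `sample` would apply the same gate to the baseline run.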

## 5. PPO vs DPO: Memory Comparison

Let's compare peak memory between PPO and DPO. The ~3x figure quoted earlier is relative to SFT; DPO already loads a reference model, so the PPO/DPO ratio may be smaller.

In [None]:
# Run a DPO smoke test for comparison
!python -m canary.cli run configs/dpo_smoke.yaml -o ./ppo_output/dpo_comparison

In [None]:
# Load DPO metrics
dpo_paths = list(Path('./ppo_output/dpo_comparison').rglob('metrics.json'))
if not dpo_paths:
    raise FileNotFoundError("No metrics.json found for DPO comparison. Did the training complete?")

dpo_path = dpo_paths[0]
with open(dpo_path) as f:
    dpo_metrics = json.load(f)

print("="*60)
print("MEMORY COMPARISON: PPO vs DPO")
print("="*60)

ppo_mem = ppo_metrics['perf']['max_mem_mb']
dpo_mem = dpo_metrics['perf']['max_mem_mb']

# Only report a ratio when the DPO measurement looks valid
if dpo_mem > 100:  # Reasonable minimum memory in MB
    ratio = ppo_mem / dpo_mem
    ratio_label = f"{ratio:.1f}x"
else:
    print("Warning: DPO memory metrics appear invalid")
    ratio = None
    ratio_label = "n/a"

print(f"\n{'Training Type':<15} {'Peak Memory':>15} {'Relative':>15}")
print("-"*45)
print(f"{'DPO':<15} {dpo_mem:>14.0f}MB {'1.0x':>15}")
print(f"{'PPO':<15} {ppo_mem:>14.0f}MB {ratio_label:>15}")

if ratio is not None:
    print(f"\nPPO uses {ratio:.1f}x the peak memory of DPO")
    print("Extra memory is expected: PPO = policy + reference + value head")

## 6. Comparing PPO Runs

Let's save the baseline and run another PPO canary for comparison.

In [None]:
# Save baseline
!mkdir -p baselines
!cp {baseline_path} baselines/ppo_baseline.json
print(f"Saved PPO baseline to baselines/ppo_baseline.json")

In [None]:
# Create a slightly different config for comparison
ppo_run2_config = """
name: ppo_run2
description: Second PPO run for comparison

model_name: EleutherAI/pythia-70m
use_peft: true
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05

training_type: ppo
max_steps: 50
batch_size: 2
gradient_accumulation_steps: 4
learning_rate: 1.0e-5
max_length: 256
warmup_steps: 5

# Same PPO params as baseline
ppo_epochs: 4
init_kl_coef: 0.2
target_kl: 6.0
cliprange: 0.2
vf_coef: 0.1
max_prompt_length: 64
max_new_tokens: 64
use_synthetic_reward: true

dataset_name: Anthropic/hh-rlhf
dataset_split: train
dataset_size: 256
seed: 123  # Different seed

output_dir: ./ppo_output
metrics_warmup_steps: 5

profiler:
  enabled: false
"""

with open('configs/ppo_run2.yaml', 'w') as f:
    f.write(ppo_run2_config)

print("Created second PPO config with different seed")

In [None]:
# Run second PPO canary
!python -m canary.cli run configs/ppo_run2.yaml -o ./ppo_output/run2

In [None]:
# Compare to baseline
run2_paths = list(Path('./ppo_output/run2').rglob('metrics.json'))
if not run2_paths:
    raise FileNotFoundError("No metrics.json found for run2. Did the training complete?")

run2_path = run2_paths[0]
!python -m canary.cli compare {run2_path} baselines/ppo_baseline.json --threshold-tier smoke

## 7. PPO Failure Modes

PPO has specific failure modes to watch for:

### 1. KL Explosion
- **Symptom**: `objective/kl` >> target_kl
- **Cause**: Policy drifting too far from reference
- **Fix**: Increase `init_kl_coef`, reduce learning rate

### 2. Entropy Collapse
- **Symptom**: `objective/entropy` drops to near zero
- **Cause**: Policy becoming deterministic too fast
- **Fix**: Add entropy bonus, check reward signal

### 3. Value Function Divergence
- **Symptom**: `ppo/value_loss` exploding
- **Cause**: Value head not learning properly
- **Fix**: Increase `vf_coef`, check reward scale

### 4. High Clip Fraction
- **Symptom**: `ppo/clipfrac` > 0.3 consistently
- **Cause**: Policy updates too aggressive
- **Fix**: Reduce learning rate or `cliprange`
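
The four symptoms above can be checked mechanically once per-step metric histories are available. This sketch assumes a dictionary mapping each metric name to a list of per-step values; the `0.3` clip-fraction limit comes from the list above, while the other cutoffs (2x `target_kl`, an entropy floor of 0.05, a 10x rise in value loss) are illustrative assumptions, not values used by the canary.

```python
KL_FACTOR = 2.0           # hypothetical: "KL >> target" means more than 2x target_kl
ENTROPY_FLOOR = 0.05      # hypothetical: "near zero" entropy
VALUE_LOSS_FACTOR = 10.0  # hypothetical: "exploding" means 10x the first value
CLIPFRAC_LIMIT = 0.3      # from the failure-mode list above


def detect_failure_modes(history, target_kl=6.0):
    """Return the names of the PPO failure modes visible in a metric history."""
    found = []
    if max(history["objective/kl"]) > KL_FACTOR * target_kl:
        found.append("kl_explosion")
    if history["objective/entropy"][-1] < ENTROPY_FLOOR:
        found.append("entropy_collapse")
    value_loss = history["ppo/value_loss"]
    if value_loss[-1] > VALUE_LOSS_FACTOR * value_loss[0]:
        found.append("value_divergence")
    clipfrac = history["ppo/clipfrac"]
    # "Consistently" is read here as: above the limit on more than half the steps.
    if sum(c > CLIPFRAC_LIMIT for c in clipfrac) > len(clipfrac) / 2:
        found.append("high_clipfrac")
    return found


# Toy history with made-up values.
unstable = {
    "objective/kl": [1.0, 8.0, 15.0],
    "objective/entropy": [3.0, 1.0, 0.01],
    "ppo/value_loss": [0.5, 2.0, 9.0],
    "ppo/clipfrac": [0.35, 0.40, 0.20],
}
print(detect_failure_modes(unstable))
```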

In [None]:
# Create an unstable PPO config to demonstrate failure modes
unstable_ppo_config = """
name: ppo_unstable
description: Intentionally unstable PPO settings

model_name: EleutherAI/pythia-70m
use_peft: true
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05

training_type: ppo
max_steps: 30
batch_size: 2
gradient_accumulation_steps: 4
learning_rate: 5.0e-4   # 50x higher LR than the 1.0e-5 baseline!
max_length: 256
warmup_steps: 2

# UNSTABLE PPO settings
ppo_epochs: 8           # More aggressive updates
init_kl_coef: 0.05      # Weak KL penalty
target_kl: 20.0         # Very high target
cliprange: 0.3          # Wider clipping
vf_coef: 0.1
max_prompt_length: 64
max_new_tokens: 64
use_synthetic_reward: true

dataset_name: Anthropic/hh-rlhf
dataset_split: train
dataset_size: 128
seed: 42

output_dir: ./ppo_output
metrics_warmup_steps: 2

profiler:
  enabled: false
"""

with open('configs/ppo_unstable.yaml', 'w') as f:
    f.write(unstable_ppo_config)

print("Created unstable PPO config:")
print("  - learning_rate: 5.0e-4 (50x normal)")
print("  - ppo_epochs: 8 (2x normal)")
print("  - init_kl_coef: 0.05 (weak KL penalty)")
print("  - target_kl: 20.0 (very high)")

In [None]:
# Run unstable config
!python -m canary.cli run configs/ppo_unstable.yaml -o ./ppo_output/unstable

In [None]:
# Compare unstable run to baseline
unstable_paths = list(Path('./ppo_output/unstable').rglob('metrics.json'))

if unstable_paths:
    unstable_path = unstable_paths[0]
    with open(unstable_path) as f:
        unstable_metrics = json.load(f)
    
    print("="*60)
    print("UNSTABLE PPO RUN ANALYSIS")
    print("="*60)
    
    print(f"\nStability Metrics:")
    print(f"  NaN steps: {unstable_metrics['stability']['nan_steps']}")
    print(f"  Inf steps: {unstable_metrics['stability']['inf_steps']}")
    print(f"  Loss diverged: {unstable_metrics['stability']['loss_diverged']}")
    
    print(f"\nPerformance Comparison:")
    print(f"  Baseline step time: {ppo_metrics['perf']['step_time']['mean']:.4f}s")
    print(f"  Unstable step time: {unstable_metrics['perf']['step_time']['mean']:.4f}s")
    
    # Run formal comparison
    print("\n" + "="*60)
    print("FORMAL COMPARISON")
    print("="*60)
    !python -m canary.cli compare {unstable_path} baselines/ppo_baseline.json --threshold-tier smoke
else:
    print("Unstable run produced no metrics.json (it likely crashed during training)")

## 8. PPO Canary Best Practices

### Configuration Tips

```yaml
# Conservative settings for stable canaries
ppo_epochs: 4          # Standard value
init_kl_coef: 0.2      # Strong KL penalty
target_kl: 6.0         # Reasonable target
cliprange: 0.2         # Standard clipping
learning_rate: 1.0e-5  # Conservative LR
```
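
`init_kl_coef` and `target_kl` interact through an adaptive KL penalty. The sketch below shows the proportional controller commonly used in RLHF-style PPO, where the coefficient grows when measured KL is above target and shrinks when it is below. Whether the canary's trainer uses exactly this rule is an assumption, and the `horizon` value is illustrative because it does not appear in the configs shown in this notebook.

```python
class AdaptiveKLController:
    """Proportional controller that nudges the KL coefficient toward a target KL."""

    def __init__(self, init_kl_coef, target_kl, horizon):
        self.value = init_kl_coef
        self.target = target_kl
        self.horizon = horizon

    def update(self, current_kl, n_steps):
        # Clip the relative error so one noisy batch cannot swing the coefficient.
        error = min(max(current_kl / self.target - 1.0, -0.2), 0.2)
        self.value *= 1.0 + error * n_steps / self.horizon
        return self.value


# init_kl_coef and target_kl match the conservative settings; horizon is assumed.
controller = AdaptiveKLController(init_kl_coef=0.2, target_kl=6.0, horizon=10_000)

for measured_kl in (12.0, 12.0, 3.0):
    coef = controller.update(measured_kl, n_steps=8)
    print(f"measured KL {measured_kl:>5.1f} -> kl_coef {coef:.6f}")
```

With a very high `target_kl`, the measured KL rarely exceeds the target, so the controller keeps lowering the penalty.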

### Memory Management

1. Use PEFT/LoRA to reduce memory
2. Use 4-bit quantization for larger models
3. Keep batch_size small on limited VRAM
4. Monitor peak memory in canary reports
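
The saving from the first item can be estimated by counting parameters. A LoRA adapter of rank `r` on a `d_out x d_in` weight matrix adds two small matrices with `r * (d_in + d_out)` parameters in total. The hidden size below is an illustrative value, not one read from a model config, and `lora_trainable_params` is a helper introduced here.

```python
def lora_trainable_params(d_in, d_out, r):
    """Parameters added by one LoRA adapter: A is (r x d_in), B is (d_out x r)."""
    return r * (d_in + d_out)


d = 512  # illustrative hidden size
r = 16   # lora_r used by the configs in this notebook

full = d * d
lora = lora_trainable_params(d, d, r)
print(f"Full matrix: {full:,} params")
print(f"LoRA adapter: {lora:,} params ({lora / full:.2%} of the full matrix)")
```

Only the adapter parameters need gradients and optimizer state, which is where most of the training-time saving comes from.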

### Synthetic vs Real Rewards

The canary uses **synthetic rewards** (length-based) for testing:
- Pros: No need for reward model, fast, consistent
- Cons: Doesn't test reward model integration

For production, you'd use a trained reward model, but for regression testing, synthetic rewards are sufficient.
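
As an illustration of why a length-based reward is fast and consistent, here is a minimal sketch of one. The canary's actual reward formula is not shown in this notebook, so the function name, the target length, and the linear shape are all hypothetical.

```python
def length_reward(n_tokens, target_len=32, max_len=64):
    """Reward in [0, 1] that peaks when a response has target_len tokens."""
    distance = abs(n_tokens - target_len)
    return max(0.0, 1.0 - distance / max_len)


# The reward depends only on response length, so identical lengths always
# receive identical rewards and no reward model has to be loaded.
for n in (8, 32, 64):
    print(f"{n:>3} tokens -> reward {length_reward(n):.3f}")
```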

In [None]:
# Show available PPO configs
print("Available PPO Configurations:")
print("="*60)

configs = [
    ("ppo_smoke.yaml", "~10-15 min", "PR gating, quick validation"),
    ("ppo_perf.yaml", "~45-60 min", "Performance regression detection"),
    ("ppo_nightly.yaml", "~90-120 min", "Comprehensive soak test"),
]

print(f"\n{'Config':<20} {'Runtime':>15} {'Use Case'}")
print("-"*60)
for config, runtime, use_case in configs:
    print(f"{config:<20} {runtime:>15} {use_case}")

## 9. Summary

### Key Takeaways:

1. **PPO uses ~3x memory** vs SFT (model + ref + value head)
2. **Synthetic rewards** allow canary testing without a reward model
3. **PPO-specific metrics**: KL, entropy, clip fraction, policy/value loss
4. **Failure modes**: KL explosion, entropy collapse, value function divergence, high clip fraction
5. **Conservative settings** ensure stable canaries

### When to Use PPO Canaries:

- After changes to PPO training code
- After TRL/HuggingFace library updates
- When optimizing PPO hyperparameters
- As part of nightly regression testing

### Configuration Tiers:

| Tier | Steps | Runtime | Use Case |
|------|-------|---------|----------|
| Smoke | 50 | ~15 min | PR gating |
| Perf | 200 | ~60 min | Performance analysis |
| Nightly | 500 | ~2 hr | Soak testing |

### Related Notebooks:

- `02_profiler_deep_dive.ipynb` - Performance profiling (DPO/SFT)
- `03_stability_monitoring.ipynb` - Stability metrics and failure detection
- `04_root_cause_analysis.ipynb` - Debugging regression causes