[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mmcmanus1/rlhf-canary/blob/main/notebooks/03_stability_monitoring.ipynb)

# Stability Monitoring: Detecting Training Instabilities

Catch NaN explosions, loss divergence, and gradient problems before they waste your training run. Learn to recognize the early warning signs that precede a failed run.

**What you'll learn:**
1. What stability metrics are tracked (NaN/Inf, loss divergence, gradient norms)
2. PPO-specific stability metrics (KL divergence, entropy, clip fraction)
3. How to inject artificial instabilities for testing
4. Interpreting stability regression reports
5. Early warning patterns - predict run failure before it happens

**Requirements:** GPU runtime (Runtime > Change runtime type > T4 GPU)

**Runtime:** ~12-15 minutes

## 1. Setup

In [None]:
import os
import re
import sys

print("Starting Environment Setup...")

# --- 1. Clone or update the repo ---
if not os.path.exists("/content/rlhf-canary"):
    !git clone https://github.com/mmcmanus1/rlhf-canary.git /content/rlhf-canary
else:
    !cd /content/rlhf-canary && git pull --ff-only

%cd /content/rlhf-canary

# --- 2. Force-Install the "Safe Harbor" Stack ---
!pip install "trl==0.11.4" "transformers==4.44.2" "peft==0.12.0" "accelerate==0.34.2" "tokenizers==0.19.1" --force-reinstall --no-deps --quiet
!pip install -q datasets pydantic click PyYAML bitsandbytes
print("Libraries installed (TRL 0.11.4 / Transformers 4.44.2)")

# --- 3. Patch pyproject.toml (Prevent future drift) ---
project_file = "/content/rlhf-canary/pyproject.toml"
if os.path.exists(project_file):
    with open(project_file, "r") as f:
        content = f.read()
    
    if "trl==0.11.4" not in content:
        content = re.sub(r'trl[<>=!~]+[\d\.]+', 'trl==0.11.4', content)
        with open(project_file, "w") as f:
            f.write(content)
        print("Config file patched to lock TRL 0.11.4")

# --- 4. Patch Source Code (Compatibility Fix) ---
runner_file = "/content/rlhf-canary/canary/runner/local.py"
if os.path.exists(runner_file):
    with open(runner_file, "r") as f:
        code = f.read()
    
    if "processing_class=" in code:
        code = code.replace("processing_class=", "tokenizer=")
        with open(runner_file, "w") as f:
            f.write(code)
        print("Code patched: Reverted 'processing_class' to 'tokenizer'")
    else:
        print("Code is already compatible.")

# --- 5. Install the package ---
!pip install -e . --quiet

print("Environment Ready!")

In [None]:
# Verify GPU and installation
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

import canary
print(f"Canary module loaded from: {canary.__file__}")

## 2. Understanding Stability Metrics

RLHF Canary tracks several stability indicators:

### Core Stability Metrics

| Metric | What it detects | Failure threshold |
|--------|-----------------|-------------------|
| `nan_steps` | NaN in loss/gradients | Any NaN = fail |
| `inf_steps` | Inf in loss/gradients | Any Inf = fail |
| `loss_diverged` | Loss increasing over time | `late_loss > 1.5 * early_loss` |
| `grad_norm_values` | Gradient explosion | None (tracked for analysis only) |

### PPO-Specific Metrics

| Metric | What it indicates | Warning sign |
|--------|-------------------|-------------|
| `objective/kl` | Policy drift from reference | KL spike = instability |
| `objective/entropy` | Exploration diversity | Entropy collapse = stuck |
| `ppo/clipfrac` | Update clipping frequency | High clip = too aggressive |
| `ppo/policy_loss` | Actor loss | NaN/Inf = fail |
| `ppo/value_loss` | Critic loss | NaN/Inf = fail |

### Stability Keys (only these trigger NaN/Inf detection)

```python
STABILITY_KEYS = {
    "loss", "train_loss", "policy_loss", "value_loss",
    "grad_norm", "rewards/chosen", "rewards/rejected", "kl",
    # PPO-specific
    "objective/kl", "objective/entropy", "ppo/policy_loss",
    "ppo/value_loss", "ppo/clipfrac", "ppo/mean_non_score_reward",
}
```
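
Key-restricted detection can be sketched as follows. This is a minimal illustration, assuming each training log entry is a flat dict of metric names to floats; `count_nonfinite_steps` and the sample `logs` are hypothetical and not taken from the canary source.

```python
import math

# Subset of the keys listed above, kept short for the example.
STABILITY_KEYS = {"loss", "grad_norm", "objective/kl", "ppo/clipfrac"}

def count_nonfinite_steps(log_entries, keys=STABILITY_KEYS):
    """Count log entries where any stability key is NaN or Inf (hypothetical helper)."""
    nan_steps = inf_steps = 0
    for entry in log_entries:
        # Only float values under a stability key are inspected.
        values = [v for k, v in entry.items() if k in keys and isinstance(v, float)]
        if any(math.isnan(v) for v in values):
            nan_steps += 1
        if any(math.isinf(v) for v in values):
            inf_steps += 1
    return nan_steps, inf_steps

logs = [
    {"loss": 0.69, "grad_norm": 1.2, "eval/score": float("nan")},  # NaN under an ignored key
    {"loss": float("nan"), "grad_norm": 3.4},
    {"loss": 0.71, "grad_norm": float("inf")},
]
print(count_nonfinite_steps(logs))  # (1, 1)
```

The NaN under `eval/score` is not counted because that key is outside the stability set.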

## 3. Run a Stable Baseline (DPO)

First, let's run a stable training job to establish a baseline.

In [None]:
# Run a stable DPO canary
!python -m canary.cli run configs/dpo_smoke.yaml -o ./stability_output/stable_dpo

In [None]:
import json
from pathlib import Path

# Load and display stability metrics
stable_paths = list(Path('./stability_output/stable_dpo').rglob('metrics.json'))
if not stable_paths:
    raise FileNotFoundError("No metrics.json found for stable DPO run. Did the training complete?")

stable_path = stable_paths[0]
with open(stable_path) as f:
    stable_metrics = json.load(f)

print("="*60)
print("STABLE DPO RUN - STABILITY METRICS")
print("="*60)

stability = stable_metrics['stability']
print(f"\nNaN steps:      {stability['nan_steps']}")
print(f"Inf steps:      {stability['inf_steps']}")
print(f"Loss diverged:  {stability['loss_diverged']}")
print(f"Final loss:     {stability['final_loss']:.4f}")

# Show loss trajectory
loss_values = stability['loss_values']
if len(loss_values) > 10:
    print(f"\nLoss trajectory (first 5): {[f'{v:.4f}' for v in loss_values[:5]]}")
    print(f"Loss trajectory (last 5):  {[f'{v:.4f}' for v in loss_values[-5:]]}")

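# Only claim the run is healthy if the metrics actually say so
if stability['nan_steps'] == 0 and stability['inf_steps'] == 0 and not stability['loss_diverged']:
    print("\nThis is a healthy run - no NaNs, no Infs, no loss divergence.")
else:
    print("\nWARNING: the baseline run shows instability - investigate before using it as a baseline.")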

## 4. Inject Instability: High Learning Rate

A common cause of training instability is a learning rate that's too high. Let's create a config that is likely to cause gradient explosion or NaNs.
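
A toy problem shows why an oversized learning rate diverges. The sketch below is illustrative only and unrelated to the canary code: gradient descent on `f(x) = 0.5 * a * x**2` multiplies `x` by `(1 - lr * a)` at every step, so the iterate grows without bound once `lr > 2 / a`.

```python
def gradient_descent_quadratic(lr, curvature=1.0, x0=1.0, steps=20):
    """Run gradient descent on f(x) = 0.5 * curvature * x**2 and return |x| after each step."""
    x = x0
    trajectory = []
    for _ in range(steps):
        grad = curvature * x   # f'(x)
        x = x - lr * grad      # same as x *= (1 - lr * curvature)
        trajectory.append(abs(x))
    return trajectory

converging = gradient_descent_quadratic(lr=0.5)  # |1 - 0.5| = 0.5 < 1: shrinks each step
diverging = gradient_descent_quadratic(lr=2.5)   # |1 - 2.5| = 1.5 > 1: grows each step
print(f"lr=0.5 final |x|: {converging[-1]:.2e}")
print(f"lr=2.5 final |x|: {diverging[-1]:.2e}")
```

Real loss surfaces are not quadratic, but the same mechanism (steps that overshoot the local curvature) is what drives gradient norms up in the run below.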

In [None]:
# Create an unstable config with high learning rate
unstable_config = """
name: dpo_unstable_lr
description: Intentionally unstable config - high learning rate

# Model configuration
model_name: EleutherAI/pythia-70m
use_peft: true
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05

# Training configuration - DANGEROUSLY HIGH LR
training_type: dpo
max_steps: 50
batch_size: 2
gradient_accumulation_steps: 4
learning_rate: 1.0e-2   # 200x higher than normal!
max_length: 256
warmup_steps: 5

# DPO-specific
beta: 0.1
max_prompt_length: 64

# Dataset configuration
dataset_name: Anthropic/hh-rlhf
dataset_split: train
dataset_size: 256
seed: 42

# Output configuration
output_dir: ./stability_output
metrics_warmup_steps: 5

profiler:
  enabled: false
"""

with open('configs/dpo_unstable_lr.yaml', 'w') as f:
    f.write(unstable_config)

print("Created unstable config with learning_rate=1.0e-2 (200x normal)")
print("This will likely cause gradient explosion or NaNs...")

In [None]:
# Run the unstable config
!python -m canary.cli run configs/dpo_unstable_lr.yaml -o ./stability_output/unstable_lr

In [None]:
# Check the stability metrics for the unstable run
unstable_paths = list(Path('./stability_output/unstable_lr').rglob('metrics.json'))

if unstable_paths:
    with open(unstable_paths[0]) as f:
        unstable_metrics = json.load(f)
    
    print("="*60)
    print("UNSTABLE RUN - STABILITY METRICS")
    print("="*60)
    
    stability = unstable_metrics['stability']
    print(f"\nNaN steps:      {stability['nan_steps']} {'FAIL!' if stability['nan_steps'] > 0 else ''}")
    print(f"Inf steps:      {stability['inf_steps']} {'FAIL!' if stability['inf_steps'] > 0 else ''}")
    print(f"Loss diverged:  {stability['loss_diverged']} {'WARNING!' if stability['loss_diverged'] else ''}")
    print(f"Final loss:     {stability['final_loss']}")
    
    # Show loss trajectory
    loss_values = stability['loss_values']
    if len(loss_values) > 10:
        print(f"\nLoss trajectory (first 5): {[f'{v:.4f}' for v in loss_values[:5]]}")
        print(f"Loss trajectory (last 5):  {[f'{v:.4f}' for v in loss_values[-5:]]}")
    elif loss_values:
        print(f"\nLoss values: {[f'{v:.4f}' for v in loss_values]}")
    
    # Check gradient norms if available
    grad_norms = stability.get('grad_norm_values', [])
    if grad_norms:
        print(f"\nGradient norms (last 5): {[f'{v:.2f}' for v in grad_norms[-5:]]}")
        max_grad = max(grad_norms)
        if max_grad > 100:
            print(f"Max gradient norm: {max_grad:.2f} - GRADIENT EXPLOSION!")
else:
    print("Run failed before producing metrics (training crashed early)")

## 5. Compare Stable vs Unstable with Canary

Let's use the canary comparison tool to see how it detects stability regressions.

In [None]:
# Compare unstable run to stable baseline
if unstable_paths:
    !python -m canary.cli compare {unstable_paths[0]} {stable_path} --threshold-tier smoke
else:
    print("No unstable metrics to compare (run crashed)")

## 6. PPO Stability Monitoring

PPO training has additional stability concerns:
- **KL divergence**: Policy drifting too far from reference
- **Entropy collapse**: Policy becoming too deterministic
- **Clip fraction**: Too much clipping = unstable updates

Let's run a PPO canary and examine these metrics.

In [None]:
# Run PPO smoke test
!python -m canary.cli run configs/ppo_smoke.yaml -o ./stability_output/ppo_stable

In [None]:
# Examine PPO stability metrics
ppo_paths = list(Path('./stability_output/ppo_stable').rglob('metrics.json'))

if ppo_paths:
    with open(ppo_paths[0]) as f:
        ppo_metrics = json.load(f)
    
    print("="*60)
    print("PPO RUN - STABILITY METRICS")
    print("="*60)
    
    stability = ppo_metrics['stability']
    print(f"\nCore Stability:")
    print(f"  NaN steps:      {stability['nan_steps']}")
    print(f"  Inf steps:      {stability['inf_steps']}")
    print(f"  Loss diverged:  {stability['loss_diverged']}")
    print(f"  Final loss:     {stability['final_loss']}")
    
    print("\nLoss trajectory:")
    loss_values = stability.get('loss_values', [])
    if loss_values:
        print(f"  Values recorded: {len(loss_values)}")
        print(f"  First 3 losses: {[f'{v:.4f}' for v in loss_values[:3]]}")
        print(f"  Last 3 losses:  {[f'{v:.4f}' for v in loss_values[-3:]]}")
    
    print("\n" + "="*60)
    print("PPO HEALTH INDICATORS")
    print("="*60)
    print("\nNote: PPO-specific metrics (KL, entropy, clipfrac) are logged")
    print("during training and checked for NaN/Inf. Look at training logs")
    print("above for 'objective/kl', 'objective/entropy', 'ppo/clipfrac'.")
    
    print("\nHealthy PPO indicators:")
    print("  - KL divergence: Should stay below target_kl (6.0)")
    print("  - Entropy: Should decrease slowly, not collapse")
    print("  - Clip fraction: < 0.2 is normal, > 0.3 is concerning")
else:
    print("PPO run failed to produce metrics")
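
The thresholds printed above can be turned into a simple check. This is a sketch under assumptions: `check_ppo_health` is a hypothetical helper, it expects a list of per-step log dicts keyed by the metric names from Section 2, and the entropy rule (flag when entropy falls below a fraction of its first value) is an illustrative choice, not a canary default.

```python
def check_ppo_health(log_entries, target_kl=6.0, clipfrac_limit=0.3, entropy_floor_ratio=0.5):
    """Return warning strings for KL, clip fraction, and entropy (hypothetical helper)."""
    warnings = []
    first_entropy = None
    for step, entry in enumerate(log_entries):
        kl = entry.get("objective/kl")
        clipfrac = entry.get("ppo/clipfrac")
        entropy = entry.get("objective/entropy")
        if kl is not None and kl > target_kl:
            warnings.append(f"step {step}: KL {kl:.2f} above target_kl {target_kl}")
        if clipfrac is not None and clipfrac > clipfrac_limit:
            warnings.append(f"step {step}: clip fraction {clipfrac:.2f} above {clipfrac_limit}")
        if entropy is not None:
            if first_entropy is None:
                first_entropy = entropy  # reference point for the collapse check
            elif entropy < first_entropy * entropy_floor_ratio:
                warnings.append(
                    f"step {step}: entropy {entropy:.2f} fell below "
                    f"{entropy_floor_ratio:.0%} of its starting value"
                )
    return warnings

# Synthetic log rows for illustration, not real training output.
sample_logs = [
    {"objective/kl": 0.5, "ppo/clipfrac": 0.10, "objective/entropy": 4.0},
    {"objective/kl": 2.0, "ppo/clipfrac": 0.25, "objective/entropy": 3.5},
    {"objective/kl": 7.5, "ppo/clipfrac": 0.40, "objective/entropy": 1.5},
]
for line in check_ppo_health(sample_logs):
    print(line)
```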

## 7. Inject PPO Instability: Extreme KL Coefficient

Let's create a PPO config that encourages KL explosion: a very weak KL penalty, a target KL too high to ever tighten that penalty, a wider clip range, and more PPO epochs per step.

In [None]:
# Create unstable PPO config
unstable_ppo_config = """
name: ppo_unstable_kl
description: Intentionally unstable PPO - extreme init_kl_coef

# Model configuration
model_name: EleutherAI/pythia-70m
use_peft: true
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05

# Training configuration
training_type: ppo
max_steps: 30
batch_size: 2
gradient_accumulation_steps: 4
learning_rate: 1.0e-4   # Higher LR
max_length: 256
warmup_steps: 2

# PPO-specific - UNSTABLE SETTINGS
ppo_epochs: 8         # More epochs per step
init_kl_coef: 0.01    # Very low KL penalty - allows policy to drift
target_kl: 100.0      # Very high target - won't adapt
cliprange: 0.4        # Wider clipping - less stable
vf_coef: 0.1
max_prompt_length: 64
max_new_tokens: 64
use_synthetic_reward: true

# Dataset configuration
dataset_name: Anthropic/hh-rlhf
dataset_split: train
dataset_size: 128
seed: 42

# Output configuration
output_dir: ./stability_output
metrics_warmup_steps: 2

profiler:
  enabled: false
"""

with open('configs/ppo_unstable_kl.yaml', 'w') as f:
    f.write(unstable_ppo_config)

print("Created unstable PPO config:")
print("  - init_kl_coef: 0.01 (normal: 0.2) - weak KL penalty")
print("  - target_kl: 100.0 (normal: 6.0) - won't adapt")
print("  - cliprange: 0.4 (normal: 0.2) - less stable updates")
print("  - ppo_epochs: 8 (normal: 4) - more aggressive updates")
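
To see why a very high `target_kl` keeps the penalty from ever tightening, consider the proportional (adaptive) KL controller found in several RLHF libraries. The sketch below follows that general scheme; whether the canary runner uses exactly this update is an assumption, and `AdaptiveKLSketch`, `horizon`, and the observed KL of 12.0 are illustrative.

```python
class AdaptiveKLSketch:
    """Proportional controller that nudges the KL coefficient toward a target KL."""

    def __init__(self, init_kl_coef, target_kl, horizon=10_000):
        self.kl_coef = init_kl_coef
        self.target_kl = target_kl
        self.horizon = horizon

    def update(self, observed_kl, n_steps):
        # Relative error, clipped so one update changes the coefficient only slightly.
        error = max(-0.2, min(0.2, observed_kl / self.target_kl - 1.0))
        self.kl_coef *= 1.0 + error * n_steps / self.horizon
        return self.kl_coef

normal = AdaptiveKLSketch(init_kl_coef=0.2, target_kl=6.0)
loose = AdaptiveKLSketch(init_kl_coef=0.01, target_kl=100.0)

# Feed both controllers the same observed KL.
for _ in range(50):
    normal.update(observed_kl=12.0, n_steps=256)
    loose.update(observed_kl=12.0, n_steps=256)

print(f"normal coef: 0.2  -> {normal.kl_coef:.4f}")  # grows: KL is above its target
print(f"loose coef:  0.01 -> {loose.kl_coef:.4f}")   # shrinks: KL is far below its target
```

With `target_kl: 100.0`, a KL that would alarm a normal run still counts as below target, so under this scheme the coefficient only decays and the policy is free to drift.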

In [None]:
# Run the unstable PPO config
!python -m canary.cli run configs/ppo_unstable_kl.yaml -o ./stability_output/ppo_unstable

In [None]:
# Compare PPO runs
ppo_unstable_paths = list(Path('./stability_output/ppo_unstable').rglob('metrics.json'))

if ppo_unstable_paths and ppo_paths:
    with open(ppo_unstable_paths[0]) as f:
        ppo_unstable_metrics = json.load(f)
    
    print("="*60)
    print("PPO COMPARISON: Stable vs Unstable")
    print("="*60)
    
    stable_stab = ppo_metrics['stability']
    unstable_stab = ppo_unstable_metrics['stability']
    
    print(f"\n{'Metric':<20} {'Stable':>15} {'Unstable':>15}")
    print("-"*50)
    print(f"{'NaN steps':<20} {stable_stab['nan_steps']:>15} {unstable_stab['nan_steps']:>15}")
    print(f"{'Inf steps':<20} {stable_stab['inf_steps']:>15} {unstable_stab['inf_steps']:>15}")
    print(f"{'Loss diverged':<20} {str(stable_stab['loss_diverged']):>15} {str(unstable_stab['loss_diverged']):>15}")
    
    stable_final = stable_stab.get('final_loss')
    unstable_final = unstable_stab.get('final_loss')
    # Compare against None explicitly so a final loss of 0.0 is still printed
    stable_str = f"{stable_final:.4f}" if stable_final is not None else "N/A"
    unstable_str = f"{unstable_final:.4f}" if unstable_final is not None else "N/A"
    print(f"{'Final loss':<20} {stable_str:>15} {unstable_str:>15}")

## 8. Early Warning Patterns

The goal of stability monitoring is to predict run failure **before** it happens. Here are patterns to watch for:

### Warning Signs (can appear 30+ minutes before failure on long runs)

| Pattern | What it means | Action |
|---------|--------------|--------|
| Loss increasing for 5+ steps | Training diverging | Reduce LR, check data |
| Gradient norm > 10x normal | Gradient explosion starting | Add gradient clipping |
| KL spike (PPO) | Policy drifting from reference | Increase KL penalty |
| Entropy dropping rapidly | Policy collapsing | Add entropy bonus |
| Clip fraction > 0.3 (PPO) | Updates too aggressive | Reduce LR or cliprange |

### The "Death Spiral" Pattern

```
Step N:   Loss slightly up      <- Early warning
Step N+5: Loss up more, grad norm rising
Step N+10: Grad norm spiking    <- Last chance to save
Step N+15: NaN detected         <- Too late
```
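
The gradient-norm signal in this pattern can be checked mechanically. A minimal sketch, assuming `grad_norm_values` is a plain list of floats as in the metrics shown earlier; `find_grad_spikes`, the 10-step baseline window, and using the median as "normal" are illustrative choices.

```python
import statistics

def find_grad_spikes(grad_norms, baseline_window=10, factor=10.0):
    """Return (step, value) pairs where the norm exceeds factor x the baseline median."""
    if len(grad_norms) <= baseline_window:
        return []  # not enough data to establish a baseline
    baseline = statistics.median(grad_norms[:baseline_window])
    return [
        (step, value)
        for step, value in enumerate(grad_norms)
        if step >= baseline_window and value > factor * baseline
    ]

# Synthetic norms: ten calm steps, then a rapid rise.
norms = [1.0, 1.2, 0.9, 1.1, 1.0, 1.3, 0.8, 1.0, 1.1, 0.9, 1.2, 4.0, 15.0, 60.0]
print(find_grad_spikes(norms))  # [(12, 15.0), (13, 60.0)]
```

A median baseline is less sensitive to a single noisy early step than a mean would be, which matters on short smoke runs.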

In [None]:
# Demonstrate loss trajectory analysis
def analyze_loss_trajectory(loss_values):
    """Analyze loss trajectory for early warning signs."""
    if len(loss_values) < 10:
        return "Insufficient data for trajectory analysis"
    
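    # Split into early and late phases (first and last third).
    # Compute the size once: -len(x)//3 floors toward negative infinity,
    # which would make the late slice longer than the early one.
    third = len(loss_values) // 3
    early = loss_values[:third]
    late = loss_values[-third:]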
    
    import statistics
    early_avg = statistics.mean(early)
    late_avg = statistics.mean(late)
    
    # Check for concerning patterns
    warnings = []
    
    if late_avg > early_avg * 1.2:
        warnings.append(f"WARN: Loss increased {(late_avg/early_avg - 1)*100:.1f}% from early to late")
    
    if late_avg > early_avg * 1.5:
        warnings.append("CRITICAL: Loss divergence detected!")
    
    # Check for increasing trend in recent steps
    recent = loss_values[-5:]
    if all(recent[i] < recent[i+1] for i in range(len(recent)-1)):
        warnings.append("WARN: Loss monotonically increasing in last 5 steps")
    
    # Check for high variance (instability)
    if len(late) > 1:
        late_std = statistics.stdev(late)
        if late_std > abs(late_avg) * 0.5:
            warnings.append(f"WARN: High loss variance in late phase (std={late_std:.4f})")
    
    if not warnings:
        return "OK: Loss trajectory looks healthy"
    
    return "\n".join(warnings)

# Analyze our runs
print("="*60)
print("LOSS TRAJECTORY ANALYSIS")
print("="*60)

print("\n--- Stable DPO ---")
print(analyze_loss_trajectory(stable_metrics['stability']['loss_values']))

if unstable_paths:
    print("\n--- Unstable DPO (high LR) ---")
    print(analyze_loss_trajectory(unstable_metrics['stability']['loss_values']))

## 9. Using Stability Gates in CI/CD

The canary system uses stability checks as hard gates. Here's how they work:

```yaml
# Stability thresholds (from canary/compare/thresholds.py)
nan_steps_allowed: 0      # Any NaN = FAIL
inf_steps_allowed: 0      # Any Inf = FAIL
```

These are **non-negotiable** gates - even one NaN or Inf will fail the canary.
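
In a CI job, the same gate can be applied directly to a `metrics.json` file. A minimal sketch, assuming the integer `stability.nan_steps` and `stability.inf_steps` fields shown earlier in this notebook; `stability_gate` is a hypothetical helper, and the `canary.cli compare` command used in Section 5 remains the primary path.

```python
import json
import tempfile
from pathlib import Path

def stability_gate(metrics_path, nan_steps_allowed=0, inf_steps_allowed=0):
    """Return (passed, reasons) for the NaN/Inf hard gates (hypothetical helper)."""
    stability = json.loads(Path(metrics_path).read_text())["stability"]
    reasons = []
    if stability["nan_steps"] > nan_steps_allowed:
        reasons.append(f"nan_steps={stability['nan_steps']} exceeds {nan_steps_allowed}")
    if stability["inf_steps"] > inf_steps_allowed:
        reasons.append(f"inf_steps={stability['inf_steps']} exceeds {inf_steps_allowed}")
    return (not reasons, reasons)

# Demo with a synthetic metrics file; in CI, pass the real path instead.
with tempfile.TemporaryDirectory() as tmp:
    demo_path = Path(tmp) / "metrics.json"
    demo_path.write_text(json.dumps({"stability": {"nan_steps": 1, "inf_steps": 0}}))
    passed, reasons = stability_gate(demo_path)

print(passed, reasons)  # False ['nan_steps=1 exceeds 0']
```

A CI script would typically call `sys.exit(1)` when `passed` is false so the pipeline stops on the first NaN or Inf.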

In [None]:
# Programmatic access to stability comparison
from canary.compare.stats import compare_to_baseline, load_metrics
from canary.compare.thresholds import SMOKE_THRESHOLDS

# Load metrics
stable = load_metrics(str(stable_path))

print("="*60)
print("STABILITY THRESHOLDS (Smoke Tier)")
print("="*60)

print(f"\nnan_steps_allowed: {SMOKE_THRESHOLDS.nan_steps_allowed}")
print(f"inf_steps_allowed: {SMOKE_THRESHOLDS.inf_steps_allowed}")
print(f"max_step_time_increase_pct: {SMOKE_THRESHOLDS.max_step_time_increase_pct}%")
print(f"max_tps_drop_pct: {SMOKE_THRESHOLDS.max_tps_drop_pct}%")
print(f"max_mem_increase_mb: {SMOKE_THRESHOLDS.max_mem_increase_mb}MB")
print(f"max_mem_increase_pct: {SMOKE_THRESHOLDS.max_mem_increase_pct}%")

print("\nNote: Stability checks (NaN/Inf) are HARD gates - zero tolerance.")

## 10. Summary

### Key Takeaways:

1. **Stability metrics** are tracked automatically during training
2. **NaN/Inf detection** only checks stability-relevant keys (loss, gradients, KL, etc.)
3. **Loss divergence** is detected by comparing early vs late training phases
4. **PPO has additional concerns**: KL divergence, entropy collapse, clip fraction
5. **Early warning patterns** can surface well before the first NaN (30+ minutes ahead on long runs)

### When Stability Checks Fail:

1. **NaN detected**: Check learning rate, gradient clipping, data preprocessing
2. **Loss diverging**: Reduce learning rate, check data quality
3. **KL explosion (PPO)**: Increase KL penalty, reduce number of PPO epochs
4. **Entropy collapse (PPO)**: Add entropy bonus, check reward signal

### Next Steps:

- See `04_root_cause_analysis.ipynb` for debugging regression causes
- See `05_ppo_canary.ipynb` for detailed PPO canary workflows