# SFT Canary: Supervised Fine-Tuning Validation

This notebook demonstrates how to use RLHF Canary for SFT (Supervised Fine-Tuning) validation.

**What you'll learn:**
1. SFT vs DPO vs PPO: When to use each
2. Running SFT canaries
3. SFT-specific characteristics and metrics
4. Memory efficiency of SFT
5. Comparing SFT baselines

**Requirements:** GPU runtime (Runtime > Change runtime type > T4 GPU)

**Runtime:** ~8-10 minutes

## 1. Setup

In [None]:
import os
import re
import sys

print("Starting Environment Setup...")

# --- 1. Clone the repo first ---
if not os.path.exists("/content/rlhf-canary"):
    !git clone https://github.com/mmcmanus1/rlhf-canary.git /content/rlhf-canary

%cd /content/rlhf-canary

# --- 2. Force-Install the "Safe Harbor" Stack ---
!pip install "trl==0.11.4" "transformers==4.44.2" "peft==0.12.0" "accelerate==0.34.2" "tokenizers==0.19.1" --force-reinstall --no-deps --quiet
!pip install -q datasets pydantic click PyYAML bitsandbytes
print("Libraries installed (TRL 0.11.4 / Transformers 4.44.2)")

# --- 3. Patch pyproject.toml (Prevent future drift) ---
project_file = "/content/rlhf-canary/pyproject.toml"
if os.path.exists(project_file):
    with open(project_file, "r") as f:
        content = f.read()
    
    if "trl==0.11.4" not in content:
        content = re.sub(r'trl[<>=!~]+[\d\.]+', 'trl==0.11.4', content)
        with open(project_file, "w") as f:
            f.write(content)
        print("Config file patched to lock TRL 0.11.4")

# --- 4. Patch Source Code (Compatibility Fix) ---
runner_file = "/content/rlhf-canary/canary/runner/local.py"
if os.path.exists(runner_file):
    with open(runner_file, "r") as f:
        code = f.read()
    
    if "processing_class=" in code:
        code = code.replace("processing_class=", "tokenizer=")
        with open(runner_file, "w") as f:
            f.write(code)
        print("Code patched: Reverted 'processing_class' to 'tokenizer'")
    else:
        print("Code is already compatible.")

# --- 5. Install the package ---
!pip install -e . --quiet

print("Environment Ready!")

In [None]:
# Verify GPU and installation
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

import canary
print(f"Canary module loaded from: {canary.__file__}")

## 2. Understanding SFT vs DPO vs PPO

### Training Methods Comparison

| Aspect | SFT | DPO | PPO |
|--------|-----|-----|-----|
| **Goal** | Imitate good responses | Learn from preferences | Optimize reward |
| **Data** | Text completions | Preference pairs | Prompts + reward model |
| **Memory** | ~1x model | ~2x model | ~3x model |
| **Complexity** | Simple | Medium | Complex |
| **Speed** | Fastest | Medium | Slowest |

### When to Use SFT Canaries

1. **Pre-flight validation**: Catch basic failures before running expensive DPO/PPO jobs
2. **Baseline comparison**: Establish a performance floor
3. **Infrastructure testing**: Verify the training pipeline works
4. **Memory-constrained hardware**: When DPO/PPO won't fit
5. **Quick iteration**: Fastest feedback loop

### SFT Data Format

```python
# SFT uses single text sequences
{"text": "Human: What is 2+2?\n\nAssistant: 2+2 equals 4."}

# vs DPO which needs preference pairs
{"chosen": "Good response", "rejected": "Bad response"}

# vs PPO which generates a response from a query, then scores it
{"query": "Human: What is 2+2?\n\nAssistant:"}  # → generate → score
```
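
Since this notebook's SFT runs read only the `chosen` column of a preference dataset, the conversion from a DPO-style record to an SFT example can be sketched as below. The helper `to_sft_example` is hypothetical and not part of RLHF Canary. The sample record is made up for illustration.

```python
# Sketch: derive an SFT example from a preference-pair record.
# Assumption: the record has "chosen"/"rejected" string fields, as in the DPO
# format above; to_sft_example is a hypothetical helper, not a library function.
def to_sft_example(record: dict) -> dict:
    """Keep only the preferred completion as a single training sequence."""
    return {"text": record["chosen"]}

pair = {
    "chosen": "Human: What is 2+2?\n\nAssistant: 2+2 equals 4.",
    "rejected": "Human: What is 2+2?\n\nAssistant: 5.",
}
sft_example = to_sft_example(pair)
print(sft_example["text"])
```

The rejected response is simply dropped, which is why SFT cannot surface preference-specific issues.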

## 3. SFT Configuration

In [None]:
# Show the SFT smoke config
!cat configs/sft_smoke.yaml

In [None]:
# Key differences from DPO config:
print("SFT vs DPO Configuration Differences:")
print("="*76)

diffs = [
    ("training_type", "sft", "dpo"),
    ("beta", "N/A (not needed)", "0.1 (KL penalty)"),
    ("max_prompt_length", "N/A (not needed)", "64"),
    ("Data usage", "'chosen' column only", "Both 'chosen' and 'rejected'"),
    ("Reference model", "Not created", "Created internally (~2x memory)"),
]

# Column widths are sized to the longest value so the table stays aligned
print(f"\n{'Parameter':<20} {'SFT':<22} {'DPO':<32}")
print("-"*76)
for param, sft_val, dpo_val in diffs:
    print(f"{param:<20} {sft_val:<22} {dpo_val:<32}")

## 4. Run SFT Baseline Canary

In [None]:
# Run SFT smoke test (~5-8 min)
!python -m canary.cli run configs/sft_smoke.yaml -o ./sft_output/baseline

In [None]:
import json
from pathlib import Path

# Load and display SFT metrics
sft_baseline_paths = list(Path('./sft_output/baseline').rglob('metrics.json'))
if not sft_baseline_paths:
    raise FileNotFoundError("No metrics.json found for SFT baseline. Did the training complete?")

sft_baseline_path = sft_baseline_paths[0]
with open(sft_baseline_path) as f:
    sft_metrics = json.load(f)

print("="*60)
print("SFT BASELINE METRICS")
print("="*60)

print(f"\nPerformance:")
print(f"  Step time (mean): {sft_metrics['perf']['step_time']['mean']:.4f}s")
print(f"  Tokens/sec: {sft_metrics['perf']['approx_tokens_per_sec']:.0f}")
print(f"  Peak memory: {sft_metrics['perf']['max_mem_mb']:.0f}MB")

print(f"\nStability:")
print(f"  NaN steps: {sft_metrics['stability']['nan_steps']}")
print(f"  Inf steps: {sft_metrics['stability']['inf_steps']}")
print(f"  Loss diverged: {sft_metrics['stability']['loss_diverged']}")
print(f"  Final loss: {sft_metrics['stability']['final_loss']:.4f}")

## 5. Memory Comparison: SFT vs DPO

Let's run the DPO smoke config and compare its peak memory, step time, and throughput against the SFT baseline.

In [None]:
# Run DPO for comparison
!python -m canary.cli run configs/dpo_smoke.yaml -o ./sft_output/dpo_comparison

In [None]:
# Load DPO metrics
dpo_paths = list(Path('./sft_output/dpo_comparison').rglob('metrics.json'))
if not dpo_paths:
    raise FileNotFoundError("No metrics.json found for DPO comparison. Did the training complete?")

dpo_path = dpo_paths[0]
with open(dpo_path) as f:
    dpo_metrics = json.load(f)

print("="*60)
print("MEMORY & PERFORMANCE COMPARISON")
print("="*60)

sft_mem = sft_metrics['perf']['max_mem_mb']
dpo_mem = dpo_metrics['perf']['max_mem_mb']

sft_step = sft_metrics['perf']['step_time']['mean']
dpo_step = dpo_metrics['perf']['step_time']['mean']

sft_tps = sft_metrics['perf']['approx_tokens_per_sec']
dpo_tps = dpo_metrics['perf']['approx_tokens_per_sec']

# Sanity check for valid values before computing ratios
if sft_mem > 0 and sft_step > 0 and dpo_tps > 0:
    print(f"\n{'Metric':<20} {'SFT':>15} {'DPO':>15} {'Ratio':>15}")
    print("-"*65)
    print(f"{'Peak Memory (MB)':<20} {sft_mem:>15.0f} {dpo_mem:>15.0f} {dpo_mem/sft_mem:>14.1f}x")
    print(f"{'Step Time (s)':<20} {sft_step:>15.4f} {dpo_step:>15.4f} {dpo_step/sft_step:>14.1f}x")
    print(f"{'Tokens/sec':<20} {sft_tps:>15.0f} {dpo_tps:>15.0f} {sft_tps/dpo_tps:>14.1f}x")

    print(f"\nKey Insight:")
    print(f"  - DPO uses {dpo_mem/sft_mem:.1f}x more memory than SFT")
    print(f"  - SFT processes {sft_tps/dpo_tps:.1f}x more tokens per second")
    print(f"  - This is because DPO also evaluates a reference model and processes both chosen and rejected responses")
else:
    print("Warning: Some metrics appear invalid, skipping ratio calculations")

## 6. Save SFT Baseline and Compare Runs
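
The compare step in this section checks a new run against the saved baseline. As a rough illustration of what such a check might involve, the sketch below flags metrics that drift past a tolerance. The tolerance values and the `find_regressions` helper are hypothetical placeholders, not the thresholds behind `--threshold-tier smoke`. The sample numbers are made up.

```python
# Sketch of a baseline regression check (hypothetical thresholds).
# Assumption: metrics dicts use the same 'perf' keys as metrics.json;
# the tolerance values are placeholders, not the CLI's actual smoke tier.
HYPOTHETICAL_TOLERANCES = {
    "max_mem_mb": 0.10,             # allow up to 10% more peak memory
    "approx_tokens_per_sec": 0.10,  # allow up to 10% lower throughput
}

def find_regressions(current: dict, baseline: dict) -> list[str]:
    """Return human-readable descriptions of metrics outside tolerance."""
    problems = []
    cur, base = current["perf"], baseline["perf"]
    mem_limit = base["max_mem_mb"] * (1 + HYPOTHETICAL_TOLERANCES["max_mem_mb"])
    if cur["max_mem_mb"] > mem_limit:
        problems.append(f"peak memory {cur['max_mem_mb']:.0f}MB exceeds {mem_limit:.0f}MB")
    tps_floor = base["approx_tokens_per_sec"] * (1 - HYPOTHETICAL_TOLERANCES["approx_tokens_per_sec"])
    if cur["approx_tokens_per_sec"] < tps_floor:
        problems.append(f"throughput {cur['approx_tokens_per_sec']:.0f} tok/s below {tps_floor:.0f}")
    return problems

# Made-up numbers for illustration only.
baseline = {"perf": {"max_mem_mb": 1000.0, "approx_tokens_per_sec": 5000.0}}
current = {"perf": {"max_mem_mb": 1200.0, "approx_tokens_per_sec": 4900.0}}
for problem in find_regressions(current, baseline):
    print("REGRESSION:", problem)
```

Runs with different seeds should normally stay within tolerance on performance metrics, so a flagged regression points to a real change in the environment or code.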

In [None]:
# Save SFT baseline
!mkdir -p baselines
!cp {sft_baseline_path} baselines/sft_baseline.json
print("Saved SFT baseline to baselines/sft_baseline.json")

In [None]:
# Run another SFT canary with different seed
sft_run2_config = """
name: sft_run2
description: Second SFT run for comparison

model_name: EleutherAI/pythia-70m
use_peft: true
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05

training_type: sft
max_steps: 100
batch_size: 2
gradient_accumulation_steps: 4
learning_rate: 5.0e-5
max_length: 256
warmup_steps: 10

dataset_name: Anthropic/hh-rlhf
dataset_split: train
dataset_size: 512
seed: 123  # Different seed

output_dir: ./sft_output
metrics_warmup_steps: 10

profiler:
  enabled: false
"""

with open('configs/sft_run2.yaml', 'w') as f:
    f.write(sft_run2_config)

print("Created second SFT config with different seed")

In [None]:
# Run second SFT canary
!python -m canary.cli run configs/sft_run2.yaml -o ./sft_output/run2

In [None]:
# Compare to baseline
run2_paths = list(Path('./sft_output/run2').rglob('metrics.json'))
if not run2_paths:
    raise FileNotFoundError("No metrics.json found for run2. Did the training complete?")

run2_path = run2_paths[0]
!python -m canary.cli compare {run2_path} baselines/sft_baseline.json --threshold-tier smoke

## 7. SFT with Profiler

Unlike PPO, which runs a manual training loop, SFT supports profiler integration. Let's run a profiled SFT job.
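
Before launching a profiled run, it helps to confirm that the profiling window fits inside the training run. The sketch below assumes the profiler records steps from `start_step` up to but not including `start_step + num_steps`. The actual semantics in RLHF Canary may differ, and `profiled_steps` is a hypothetical helper.

```python
# Sketch: validate a profiler window against the training length.
# Assumption: the profiler records steps [start_step, start_step + num_steps);
# the exact semantics in RLHF Canary may differ.
def profiled_steps(start_step: int, num_steps: int, max_steps: int) -> range:
    """Return the steps that would be profiled, or raise if the window is invalid."""
    if start_step < 0 or num_steps <= 0:
        raise ValueError("start_step must be >= 0 and num_steps must be > 0")
    end_step = start_step + num_steps
    if end_step > max_steps:
        raise ValueError(
            f"profiler window ends at step {end_step}, but training stops at {max_steps}"
        )
    return range(start_step, end_step)

# Values matching the profiled config in this section.
window = profiled_steps(start_step=50, num_steps=20, max_steps=80)
print(f"Profiling steps {window.start}..{window.stop - 1}")
```

Starting at step 50 also keeps the window clear of the warmup steps, so the trace reflects steady-state behavior.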

In [None]:
# Create SFT config with profiler
sft_profiled_config = """
name: sft_profiled
description: SFT with profiler enabled

model_name: EleutherAI/pythia-70m
use_peft: true
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05

training_type: sft
max_steps: 80
batch_size: 2
gradient_accumulation_steps: 4
learning_rate: 5.0e-5
max_length: 256
warmup_steps: 10

dataset_name: Anthropic/hh-rlhf
dataset_split: train
dataset_size: 512
seed: 42

output_dir: ./sft_output
metrics_warmup_steps: 10

# Enable profiler
profiler:
  enabled: true
  start_step: 50
  num_steps: 20
  output_dir: ./sft_profiler_traces
  record_shapes: true
  profile_memory: true
  with_stack: false
"""

with open('configs/sft_profiled.yaml', 'w') as f:
    f.write(sft_profiled_config)

print("Created profiled SFT config")

In [None]:
# Run profiled SFT
!python -m canary.cli run configs/sft_profiled.yaml -o ./sft_output/profiled

In [None]:
# Check profiler output
profiled_paths = list(Path('./sft_output/profiled').rglob('metrics.json'))
if not profiled_paths:
    raise FileNotFoundError("No metrics.json found for profiled run. Did the training complete?")

profiled_path = profiled_paths[0]
with open(profiled_path) as f:
    profiled_metrics = json.load(f)

if profiled_metrics.get('profiler'):
    prof = profiled_metrics['profiler']
    print("="*60)
    print("SFT PROFILER SUMMARY")
    print("="*60)
    print(f"\nTotal CUDA time: {prof.get('cuda_time_total_ms', 0):.2f} ms")
    print(f"Total CPU time: {prof.get('cpu_time_total_ms', 0):.2f} ms")
    
    if prof.get('top_cuda_ops'):
        print(f"\nTop 5 CUDA operations:")
        for op in prof['top_cuda_ops'][:5]:
            print(f"  {op['name'][:40]:<40} {op['self_cuda_time_ms']:>8.2f}ms")
else:
    print("No profiler data in metrics")

## 8. SFT Stability Testing

Let's test SFT stability detection with an intentionally unstable config.
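
For intuition about the stability fields reported in `metrics.json`, the sketch below derives similar values from a list of per-step losses. The divergence rule here is an assumption chosen for illustration: the last finite loss exceeding the first finite loss by a fixed factor. It is not RLHF Canary's actual criterion, and the loss curve is made up.

```python
import math

# Sketch of loss-based stability checks (hypothetical criteria).
# Assumption: divergence means the last finite loss exceeds the first finite
# loss by a fixed factor; RLHF Canary's actual rule may differ.
def summarize_stability(losses: list[float], divergence_factor: float = 2.0) -> dict:
    """Count NaN/Inf steps and flag divergence from a list of per-step losses."""
    nan_steps = sum(1 for x in losses if math.isnan(x))
    inf_steps = sum(1 for x in losses if math.isinf(x))
    finite = [x for x in losses if math.isfinite(x)]
    diverged = len(finite) >= 2 and finite[-1] > divergence_factor * finite[0]
    return {
        "nan_steps": nan_steps,
        "inf_steps": inf_steps,
        "loss_diverged": diverged,
        "final_loss": losses[-1] if losses else None,
    }

# Made-up loss curve resembling a run with too high a learning rate.
report = summarize_stability([2.5, 2.7, 4.1, 9.8, float("inf"), float("nan")])
print(report)
```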

In [None]:
# Create unstable SFT config
sft_unstable_config = """
name: sft_unstable
description: Intentionally unstable SFT - high learning rate

model_name: EleutherAI/pythia-70m
use_peft: true
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05

training_type: sft
max_steps: 50
batch_size: 2
gradient_accumulation_steps: 4
learning_rate: 1.0e-2   # 200x the 5.0e-5 used in the other SFT configs
max_length: 256
warmup_steps: 2

dataset_name: Anthropic/hh-rlhf
dataset_split: train
dataset_size: 256
seed: 42

output_dir: ./sft_output
metrics_warmup_steps: 2

profiler:
  enabled: false
"""

with open('configs/sft_unstable.yaml', 'w') as f:
    f.write(sft_unstable_config)

print("Created unstable SFT config with learning_rate=1.0e-2")

In [None]:
# Run unstable SFT
!python -m canary.cli run configs/sft_unstable.yaml -o ./sft_output/unstable

In [None]:
# Check stability metrics
unstable_paths = list(Path('./sft_output/unstable').rglob('metrics.json'))

if unstable_paths:
    with open(unstable_paths[0]) as f:
        unstable_metrics = json.load(f)
    
    print("="*60)
    print("UNSTABLE SFT ANALYSIS")
    print("="*60)
    
    stability = unstable_metrics['stability']
    print(f"\nStability Metrics:")
    print(f"  NaN steps: {stability['nan_steps']} {'FAIL!' if stability['nan_steps'] > 0 else ''}")
    print(f"  Inf steps: {stability['inf_steps']} {'FAIL!' if stability['inf_steps'] > 0 else ''}")
    print(f"  Loss diverged: {stability['loss_diverged']} {'WARNING!' if stability['loss_diverged'] else ''}")
    print(f"  Final loss: {stability['final_loss']}")
    
    # Compare to baseline
    print("\n" + "="*60)
    print("COMPARISON TO BASELINE")
    print("="*60)
    !python -m canary.cli compare {unstable_paths[0]} baselines/sft_baseline.json --threshold-tier smoke
else:
    print("Unstable run failed to produce metrics")

## 9. When to Use SFT Canaries

### SFT Canary Use Cases

| Scenario | Why SFT Canary? |
|----------|----------------|
| **New model testing** | Quick validation before DPO/PPO |
| **Infrastructure changes** | Fastest feedback on training pipeline |
| **Memory-limited GPUs** | SFT fits when DPO/PPO won't |
| **Baseline establishment** | Compare DPO/PPO improvements against SFT |
| **CI/CD gating** | Fastest canary for PR validation |

### SFT Canary Limitations

- **No preference learning**: Can't surface DPO-specific issues
- **No RL dynamics**: Can't surface PPO instabilities (KL, entropy, etc.)
- **Simpler failure modes**: Fewer things can go wrong, so a passing SFT canary does not imply DPO/PPO will pass

### Recommended Canary Strategy

```
PR Gating:     SFT smoke (5 min) → DPO smoke (10 min)
Daily:         DPO perf (45 min) → PPO perf (60 min)
Weekly:        Full suite with nightly configs
```
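
One way to read the arrows in the strategy above is as fail-fast ordering: the cheaper canary runs first, and the more expensive one runs only if it passes. The sketch below illustrates that reading. The `run_gate` helper is hypothetical, and the lambdas stand in for real canary invocations.

```python
from typing import Callable

# Sketch of fail-fast gating: later stages run only if earlier ones pass.
# Assumption: the arrows in the strategy mean sequential, stop-on-failure;
# the stage callables here are stand-ins for real canary invocations.
def run_gate(stages: list[tuple[str, Callable[[], bool]]]) -> list[str]:
    """Run stages in order and return the names of the stages that executed."""
    executed = []
    for name, check in stages:
        executed.append(name)
        if not check():
            print(f"{name}: FAILED, skipping remaining stages")
            break
        print(f"{name}: passed")
    return executed

pr_gate = [
    ("sft_smoke", lambda: True),      # stand-in for a passing SFT smoke canary
    ("dpo_smoke", lambda: False),     # stand-in for a failing DPO smoke canary
    ("never_reached", lambda: True),  # skipped because the previous stage failed
]
executed = run_gate(pr_gate)
```

Ordering stages from cheapest to most expensive keeps the common failure cases fast to detect.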

In [None]:
# Summary comparison of the two training types run in this notebook
print("="*70)
print("FINAL COMPARISON: SFT vs DPO")
print("="*70)

print(f"\n{'Metric':<25} {'SFT':>20} {'DPO':>20}")
print("-"*65)

print(f"{'Training type':<25} {'sft':>20} {'dpo':>20}")
print(f"{'Memory (MB)':<25} {sft_metrics['perf']['max_mem_mb']:>20.0f} {dpo_metrics['perf']['max_mem_mb']:>20.0f}")
print(f"{'Step time (s)':<25} {sft_metrics['perf']['step_time']['mean']:>20.4f} {dpo_metrics['perf']['step_time']['mean']:>20.4f}")
print(f"{'Tokens/sec':<25} {sft_metrics['perf']['approx_tokens_per_sec']:>20.0f} {dpo_metrics['perf']['approx_tokens_per_sec']:>20.0f}")
print(f"{'Final loss':<25} {sft_metrics['stability']['final_loss']:>20.4f} {dpo_metrics['stability']['final_loss']:>20.4f}")

print("\nKey Takeaways:")
print(f"  - DPO peak memory is {dpo_metrics['perf']['max_mem_mb']/sft_metrics['perf']['max_mem_mb']:.1f}x that of SFT")
print(f"  - SFT throughput is {sft_metrics['perf']['approx_tokens_per_sec']/dpo_metrics['perf']['approx_tokens_per_sec']:.1f}x that of DPO (tokens/sec)")
print("  - Use SFT for quick validation, DPO for preference learning")

## 10. Summary

### Key Takeaways:

1. **SFT is the simplest and fastest** training method
2. **Memory efficient**: ~1x model vs DPO's ~2x and PPO's ~3x
3. **Supports profiler**: Unlike PPO's manual loop
4. **Same stability metrics**: NaN/Inf detection, loss divergence
5. **Best for quick validation**: Before running expensive DPO/PPO

### When to Choose SFT Canaries:

- Infrastructure testing
- PR gating (fastest feedback)
- Memory-constrained environments
- Baseline establishment

### Related Notebooks:

- `01_quickstart.ipynb` - DPO workflow basics
- `05_ppo_canary.ipynb` - PPO with RL metrics
- `02_profiler_deep_dive.ipynb` - Performance profiling
- `03_stability_monitoring.ipynb` - Stability metrics