# Training Curves Figure

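This notebook generates the training-curve figures for all final experiments (F1-F4):
- Train EM% over training steps (single-panel figure)
- Train EM% alongside training LM loss (two-panel variant)
- Test Pass@1% is not plotted as a curve here; the per-run `arc_pass1` column in the section 1 summary table is the only Pass@1 value shown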

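Outputs (written via `save_figure`): `docs/project-report/figures/training_curves.png` (single panel) and `training_curves_with_loss` (two-panel variant) in the same directory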

In [None]:
import sys
from pathlib import Path

sys.path.insert(0, str(Path.cwd()))

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import wandb

from figure_utils import (
    setup_paper_style,
    save_figure,
    fetch_final_runs,
    COLORS,
    ENCODER_NAMES,
    ENTITY,
    FINAL_PROJECT,
)

# Apply the paper styling helper imported from figure_utils. It runs in this
# first cell, i.e. before any of the figures below are created.
setup_paper_style()

## 1. Fetch Final Experiments

In [None]:
# Load the summary table of final experiments and report how many rows came back
df_final = fetch_final_runs()
n_final = len(df_final)
print(f"Found {n_final} final experiments")

# Cell output: only the columns that matter for this figure
summary_columns = ['display_name', 'encoder_type', 'state', 'steps', 'train_exact_acc', 'arc_pass1']
df_final.loc[:, summary_columns]

## 2. Fetch Training History

In [None]:
# Initialize W&B API
api = wandb.Api()

# Metrics to pull from each run's logged history
keys_to_fetch = ['train/exact_accuracy', 'train/accuracy', 'train/lm_loss']


def _fetch_run_history(run, keys):
    """Return a DataFrame with `_step` plus every metric in `keys` the run logged.

    `run.scan_history(keys=[...])` only yields rows in which *all* requested
    keys are present, so requesting every metric in a single call drops each
    step (possibly the whole run) where any one metric is missing. Each metric
    is therefore scanned on its own and the per-metric frames are outer-joined
    on `_step`; a step where a metric was not logged holds NaN for that metric.

    Args:
        run: W&B run object exposing `scan_history`.
        keys: metric names to fetch.

    Returns:
        DataFrame sorted by `_step` with one column per metric that had data,
        or an empty DataFrame when none of the metrics has any logged row.
    """
    per_key_frames = []
    for key in keys:
        rows = list(run.scan_history(keys=['_step', key]))
        if not rows:
            continue
        frame = (
            pd.DataFrame(rows)[['_step', key]]
            # concat(axis=1) needs a unique index; keep the latest value per step
            .drop_duplicates(subset='_step', keep='last')
            .set_index('_step')
        )
        per_key_frames.append(frame)

    if not per_key_frames:
        return pd.DataFrame()

    merged = pd.concat(per_key_frames, axis=1).sort_index()
    return merged.rename_axis('_step').reset_index()


# Fetch history for each experiment, keyed by display name
histories = {}
failed_runs = []

for _, row in df_final.iterrows():
    run_name = row['display_name']
    run_id = row['name']  # the 'name' column is used as the run id in the W&B path

    print(f"Fetching history for {run_name} ({run_id})...")

    try:
        run = api.run(f"{ENTITY}/{FINAL_PROJECT}/{run_id}")
        history = _fetch_run_history(run, keys_to_fetch)
    except Exception as e:
        # Best-effort: one failing run must not stop the others, but remember
        # it so the failure is still visible in the summary below.
        print(f"  -> Error: {e}")
        failed_runs.append(run_name)
        continue

    if history.empty:
        print("  -> No history data found")
    else:
        histories[run_name] = history
        print(f"  -> {len(history)} data points")

print(f"\nFetched history for {len(histories)} experiments")
if failed_runs:
    print(f"Failed to fetch: {', '.join(map(str, failed_runs))}")

## 3. Create Training Curves Figure

In [None]:
# One row per final experiment: (history key / display name, COLORS key, legend label)
_experiment_style = [
    ('F1_standard', 'standard', 'F1: Standard (2L)'),
    ('F2_hybrid_var', 'hybrid_variational', 'F2: Hybrid VAE (4L)'),
    ('F3_etrmtrm', 'etrmtrm', 'F3: ETRMTRM'),
    ('F4_lpn_var', 'lpn_var', 'F4: LPN VAE'),
]

# Line color per experiment, looked up in the shared COLORS palette
experiment_colors = {
    name: COLORS[color_key] for name, color_key, _ in _experiment_style
}

# Legend label per experiment
experiment_labels = {
    name: legend_label for name, _, legend_label in _experiment_style
}

In [None]:
# Create figure with single panel (training curves)
fig, ax = plt.subplots(figsize=(7, 4))

for exp_name, history in histories.items():
    if 'train/exact_accuracy' not in history.columns:
        continue

    steps = history['_step'].to_numpy()
    # Coerce to float before the NaN test: a history column holding non-numeric
    # entries (e.g. None or the string "NaN") has object dtype, on which
    # np.isnan raises TypeError. Unparseable entries become NaN and are dropped.
    train_em = pd.to_numeric(
        history['train/exact_accuracy'], errors='coerce'
    ).to_numpy(dtype=float) * 100  # Convert to %

    # Keep only the steps where the metric has a value
    mask = ~np.isnan(train_em)
    if not mask.any():
        continue

    ax.plot(
        steps[mask],
        train_em[mask],
        label=experiment_labels.get(exp_name, exp_name),  # fall back to the raw name
        color=experiment_colors.get(exp_name, '#666666'),  # grey for unknown experiments
        linewidth=1.5,
        alpha=0.9,
    )

# Customize
ax.set_xlabel('Training Steps')
ax.set_ylabel('Train Exact Match (%)')
ax.set_ylim(0, 100)
ax.legend(loc='lower right', frameon=False)
ax.grid(True, alpha=0.3, linestyle='--')

# Format x-axis with K notation
ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f'{x/1000:.0f}k'))

plt.tight_layout()
plt.show()

## 4. Alternative: Two-Panel Figure (Train + Loss)

In [None]:
# Create two-panel figure
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
ax1, ax2 = axes

# One spec per panel so both panels share a single plotting loop.
# scale: factor applied to the metric (100 turns a fraction into a percentage)
# ylim:  fixed y-range, or None to keep matplotlib's automatic range
panels = [
    dict(ax=ax1, column='train/exact_accuracy', scale=100,
         ylabel='Train Exact Match (%)', ylim=(0, 100),
         legend_loc='lower right', title='(a) Training Accuracy'),
    dict(ax=ax2, column='train/lm_loss', scale=1,
         ylabel='LM Loss', ylim=None,
         legend_loc='upper right', title='(b) Training Loss'),
]

for panel in panels:
    panel_ax = panel['ax']

    for exp_name, history in histories.items():
        if panel['column'] not in history.columns:
            continue

        steps = history['_step'].to_numpy()
        # Coerce to float before the NaN test: np.isnan raises TypeError on
        # object-dtype columns (e.g. ones containing None or the string "NaN").
        values = pd.to_numeric(
            history[panel['column']], errors='coerce'
        ).to_numpy(dtype=float) * panel['scale']

        mask = ~np.isnan(values)
        if not mask.any():
            continue

        panel_ax.plot(
            steps[mask],
            values[mask],
            label=experiment_labels.get(exp_name, exp_name),
            color=experiment_colors.get(exp_name, '#666666'),
            linewidth=1.5,
        )

    panel_ax.set_xlabel('Training Steps')
    panel_ax.set_ylabel(panel['ylabel'])
    if panel['ylim'] is not None:
        panel_ax.set_ylim(*panel['ylim'])
    panel_ax.legend(loc=panel['legend_loc'], frameon=False, fontsize=8)
    panel_ax.grid(True, alpha=0.3, linestyle='--')
    # Show step counts in thousands ("k")
    panel_ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f'{x/1000:.0f}k'))
    panel_ax.set_title(panel['title'], fontsize=10)

plt.tight_layout()
plt.show()

## 5. Save Figure

In [None]:
# Recreate and save the single-panel figure
fig, ax = plt.subplots(figsize=(7, 4))

for exp_name, history in histories.items():
    if 'train/exact_accuracy' not in history.columns:
        continue

    steps = history['_step'].to_numpy()
    # Coerce to float before the NaN test: np.isnan raises TypeError on
    # object-dtype columns (e.g. ones containing None or the string "NaN").
    # Unparseable entries become NaN and are dropped by the mask below.
    train_em = pd.to_numeric(
        history['train/exact_accuracy'], errors='coerce'
    ).to_numpy(dtype=float) * 100  # fraction -> percent

    mask = ~np.isnan(train_em)
    if not mask.any():
        continue

    ax.plot(
        steps[mask],
        train_em[mask],
        label=experiment_labels.get(exp_name, exp_name),  # fall back to the raw name
        color=experiment_colors.get(exp_name, '#666666'),  # grey for unknown experiments
        linewidth=1.5,
        alpha=0.9,
    )

ax.set_xlabel('Training Steps')
ax.set_ylabel('Train Exact Match (%)')
ax.set_ylim(0, 100)
ax.legend(loc='lower right', frameon=False)
ax.grid(True, alpha=0.3, linestyle='--')
# Show step counts in thousands ("k")
ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f'{x/1000:.0f}k'))

plt.tight_layout()

# Save (after tight_layout, so the laid-out figure is what gets handed to save_figure)
save_figure(fig, 'training_curves')
plt.show()

In [None]:
# Also save the two-panel version
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
ax1, ax2 = axes

# One spec per panel so both panels share a single plotting loop.
# scale: factor applied to the metric (100 turns a fraction into a percentage)
# ylim:  fixed y-range, or None to keep matplotlib's automatic range
panels = [
    dict(ax=ax1, column='train/exact_accuracy', scale=100,
         ylabel='Train Exact Match (%)', ylim=(0, 100),
         legend_loc='lower right', title='(a) Training Accuracy'),
    dict(ax=ax2, column='train/lm_loss', scale=1,
         ylabel='LM Loss', ylim=None,
         legend_loc='upper right', title='(b) Training Loss'),
]

for panel in panels:
    panel_ax = panel['ax']

    for exp_name, history in histories.items():
        if panel['column'] not in history.columns:
            continue

        steps = history['_step'].to_numpy()
        # Coerce to float before the NaN test: np.isnan raises TypeError on
        # object-dtype columns (e.g. ones containing None or the string "NaN").
        values = pd.to_numeric(
            history[panel['column']], errors='coerce'
        ).to_numpy(dtype=float) * panel['scale']

        mask = ~np.isnan(values)
        if not mask.any():
            continue

        panel_ax.plot(
            steps[mask],
            values[mask],
            label=experiment_labels.get(exp_name, exp_name),
            color=experiment_colors.get(exp_name, '#666666'),
            linewidth=1.5,
        )

    panel_ax.set_xlabel('Training Steps')
    panel_ax.set_ylabel(panel['ylabel'])
    if panel['ylim'] is not None:
        panel_ax.set_ylim(*panel['ylim'])
    panel_ax.legend(loc=panel['legend_loc'], frameon=False, fontsize=8)
    panel_ax.grid(True, alpha=0.3, linestyle='--')
    # Show step counts in thousands ("k")
    panel_ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f'{x/1000:.0f}k'))
    panel_ax.set_title(panel['title'], fontsize=10)

plt.tight_layout()

save_figure(fig, 'training_curves_with_loss')
plt.show()

In [None]:
# Final status line for the notebook run. The directory in the message is a
# hard-coded string, not a value returned by save_figure.
# NOTE(review): confirm it matches where save_figure actually writes.
completion_message = "\nDone! Figures saved to docs/project-report/figures/"
print(completion_message)