# Experiment Analysis for Report

This notebook generates figures and tables for the experiments section:
1. Training curves (accuracy vs steps)
2. Encoder collapse evidence (cross-sample variance)
3. Qualitative examples (predictions vs ground truth)

In [None]:
import sys
sys.path.insert(0, '../..')

import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import wandb
from pathlib import Path
from IPython.display import display, Markdown

# Plot appearance: seaborn whitegrid theme, larger default canvas and font
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams.update({
    'figure.figsize': (10, 6),
    'font.size': 12,
})

# Directory that receives every figure saved by this notebook
OUTPUT_DIR = Path('./outputs')
OUTPUT_DIR.mkdir(exist_ok=True)

# W&B entity under which the runs below are looked up
ENTITY = 'bdsaglam'

## 1. Training Curves

Fetch training accuracy over steps for the final runs (F1, F2, F3) and the TRM baseline from W&B.

In [None]:
# Final experiment runs, keyed by experiment name. Every entry holds the W&B
# project and run id plus the legend label and line colour used when plotting.
FINAL_RUNS = {
    'F1_standard': dict(
        project='etrm-final',
        run_id='z31hae14',
        label='ETRM-Deterministic',
        color='#1f77b4',
    ),
    'F2_hybrid_var': dict(
        project='etrm-final',
        run_id='7km7llbl',
        label='ETRM-Variational',
        color='#ff7f0e',
    ),
    'F3_etrmtrm': dict(
        project='etrm-final',
        run_id='wj3xu8md',
        label='ETRM-TRM',
        color='#2ca02c',
    ),
}

# TRM baseline run, described with the same fields as the entries above
TRM_RUN = dict(
    project='Arc1concept-aug-1000-ACT-torch',
    run_id='2jpjeuav',
    label='TRM (baseline)',
    color='#d62728',
)

In [None]:
def fetch_training_history(entity: str, project: str, run_id: str, metric: str = 'train/exact_accuracy') -> pd.DataFrame:
    """Fetch the sampled history of one metric for a single W&B run.

    Args:
        entity: W&B entity that owns the project.
        project: W&B project name.
        run_id: Id of the run inside the project.
        metric: Name of the logged metric to fetch.

    Returns:
        The frame returned by ``run.history`` for ``_step`` and ``metric``
        (requested with ``samples=1000``), restricted to rows where ``metric``
        is not NaN. If the returned frame has no ``metric`` column, an empty
        frame with the columns ``_step`` and ``metric`` is returned instead.
    """
    api = wandb.Api()
    run = api.run(f'{entity}/{project}/{run_id}')

    history = run.history(keys=['_step', metric], samples=1000)

    # The frame may come back without the requested column (e.g. the run never
    # logged it). Indexing would then raise an opaque KeyError; hand back an
    # empty, correctly-shaped frame so callers can simply test `.empty`.
    if metric not in history.columns:
        return pd.DataFrame(columns=['_step', metric])

    return history[history[metric].notna()]


# Fetch histories for all runs: the final ETRM runs first, then the TRM
# baseline for comparison. Each target is (key in `histories`, name shown in
# the progress message, run description).
histories = {}

fetch_targets = [(run_name, run_name, run_info) for run_name, run_info in FINAL_RUNS.items()]
fetch_targets.append(('TRM', 'TRM baseline', TRM_RUN))

for history_key, shown_name, run_info in fetch_targets:
    print(f"Fetching {shown_name}...")
    try:
        fetched = fetch_training_history(ENTITY, run_info['project'], run_info['run_id'])
        histories[history_key] = fetched
        print(f"  Got {len(fetched)} data points")
    except Exception as e:
        # Best effort: a run that cannot be fetched is reported and skipped
        print(f"  Error: {e}")

In [None]:
# Plot training curves
fig, ax = plt.subplots(figsize=(10, 6))

# (key in `histories`, run description, extra line kwargs):
# solid lines for the ETRM runs, followed by the dashed TRM baseline
curves = [(run_name, run_info, {}) for run_name, run_info in FINAL_RUNS.items()]
curves.append(('TRM', TRM_RUN, {'linestyle': '--'}))

for history_key, run_info, line_kwargs in curves:
    # Runs whose fetch failed or returned nothing are left out of the figure
    if history_key not in histories or histories[history_key].empty:
        continue
    history = histories[history_key]
    steps_k = history['_step'] / 1000
    accuracy_pct = history['train/exact_accuracy'] * 100
    ax.plot(steps_k, accuracy_pct,
            label=run_info['label'], color=run_info['color'], linewidth=2, **line_kwargs)

ax.set_xlabel('Training Steps (k)', fontsize=14)
ax.set_ylabel('Training Accuracy (%)', fontsize=14)
ax.set_title('Training Accuracy Over Time', fontsize=16)
ax.legend(loc='lower right', fontsize=12)
ax.set_ylim(0, 105)
ax.grid(True, alpha=0.3)

fig.tight_layout()

# Raster copy for quick viewing, vector copy for the report
png_path = OUTPUT_DIR / 'training_curves.png'
fig.savefig(png_path, dpi=150, bbox_inches='tight')
fig.savefig(OUTPUT_DIR / 'training_curves.pdf', bbox_inches='tight')
plt.show()

print(f"Saved to {png_path}")

## 2. Encoder Collapse Evidence

To demonstrate encoder collapse, we need to:
1. Load trained ETRM models
2. Run inference on multiple different puzzles
3. Capture encoder outputs
4. Compute variance across puzzles

If the encoder has collapsed, the variance will be very low (its outputs are similar regardless of the input demos).

In [None]:
# Checkpoint paths for final experiments
CHECKPOINT_DIR = Path('../../checkpoints')

# Experiment name -> checkpoint path, laid out as
# <CHECKPOINT_DIR>/etrm-final/<experiment>/<step directory name>
CHECKPOINTS = {
    experiment: CHECKPOINT_DIR / 'etrm-final' / experiment / step_name
    for experiment, step_name in [
        ('F1_standard', 'step_174622'),
        ('F2_hybrid_var', 'step_174240'),
        ('F3_etrmtrm', 'step_87310'),
    ]
}

TRM_CHECKPOINT = CHECKPOINT_DIR / 'Arc1concept-aug-1000-ACT-torch' / 'pretrain_att_arc1concept_4' / 'step_518071'

# Report which of the configured checkpoints are present on disk
for label, ckpt_path in [*CHECKPOINTS.items(), ('TRM', TRM_CHECKPOINT)]:
    status = 'EXISTS' if ckpt_path.exists() else 'NOT FOUND'
    print(f"{label}: {ckpt_path} - {status}")

In [None]:
# Find actual checkpoint paths
import glob

checkpoint_base = Path('../../checkpoints')
print("Looking for checkpoints...")
print()

if not checkpoint_base.is_dir():
    # Say so explicitly; otherwise the listing below is just silently empty
    print(f"Checkpoint directory not found: {checkpoint_base.resolve()}")
else:
    # Two-level listing: <project>/<run>/ with a checkpoint count per run
    for item in sorted(checkpoint_base.glob('*')):
        if not item.is_dir():
            continue
        print(f"{item.name}/")
        for sub in sorted(item.glob('*')):
            if sub.is_dir():
                # Count `*.pt` files as well as `step_*` entries (the naming
                # used by the checkpoint paths configured in this notebook).
                # The set keeps a name like `step_1.pt` from counting twice.
                ckpts = sorted({*sub.glob('*.pt'), *sub.glob('step_*')})
                print(f"  {sub.name}/ ({len(ckpts)} checkpoints)")

In [None]:
# This cell will be filled in once we identify the correct checkpoint paths
# For now, let's define the analysis function

def compute_encoder_output_stats(model, dataloader, num_samples=100, device='cuda'):
    """
    Compute statistics of encoder outputs across different puzzle demos.

    Args:
        model: Object exposing ``eval()`` and, for the analysis to run, an
            ``encoder(demo_inputs, demo_labels, demo_mask)`` callable.
        dataloader: Iterable of batches; each batch is indexed with the keys
            'demo_inputs', 'demo_labels' and 'demo_mask', whose values are
            moved to ``device``.
        num_samples: Maximum number of samples (rows along dim 0 of the
            encoder output) to include. Iteration stops once this many have
            been collected and any surplus from the last batch is dropped.
        device: Device the batch tensors are moved to before encoding.

    Returns:
        None if the model has no ``encoder`` attribute or no outputs were
        collected; otherwise a dict with
            'mean': mean over samples (dim 0),
            'std': standard deviation over samples (dim 0),
            'overall_variance': variance over samples averaged over all
                remaining dimensions, as a Python float,
            'num_samples': number of samples actually used.
        With a single collected sample the std/variance are NaN, because
        torch's default estimators are unbiased.
    """
    model.eval()

    # Nothing to measure without an encoder; bail out before touching the data
    if not hasattr(model, 'encoder'):
        return None

    encoder_outputs = []
    collected = 0  # samples gathered so far (not batches)

    with torch.no_grad():
        for batch in dataloader:
            if collected >= num_samples:
                break

            # Move batch to device
            demo_inputs = batch['demo_inputs'].to(device)
            demo_labels = batch['demo_labels'].to(device)
            demo_mask = batch['demo_mask'].to(device)

            enc_out = model.encoder(demo_inputs, demo_labels, demo_mask)
            encoder_outputs.append(enc_out.cpu())
            collected += enc_out.shape[0]

    if not encoder_outputs:
        return None

    # Stack along the sample dimension and trim to the requested budget
    # assumes encoder output is [N, seq_len, hidden_dim] - TODO confirm
    all_outputs = torch.cat(encoder_outputs, dim=0)[:num_samples]

    # Per-position statistics across samples
    mean = all_outputs.mean(dim=0)
    std = all_outputs.std(dim=0)

    # Single scalar: variance across samples, averaged over every position
    overall_variance = all_outputs.var(dim=0).mean().item()

    return {
        'mean': mean,
        'std': std,
        'overall_variance': overall_variance,
        # True number of samples behind the statistics
        'num_samples': all_outputs.shape[0],
    }

print("Encoder analysis function defined.")
print("To run this analysis, we need to load the trained models and a dataloader.")

### Alternative: Analyze from W&B logged metrics

If we logged encoder output statistics during training, we can fetch them directly.

In [None]:
# Check what metrics are available in the runs
api = wandb.Api()

# Substrings (matched case-insensitively) that mark a metric as encoder-related
encoder_keywords = ('encoder', 'emb', 'variance')

for name, info in FINAL_RUNS.items():
    print(f"\n{name}:")
    run = api.run(f"{ENTITY}/{info['project']}/{info['run_id']}")

    # Summary keys, leaving out the ones that start with an underscore
    summary_keys = [key for key in run.summary.keys() if not key.startswith('_')]
    print(f"  Summary metrics: {summary_keys[:20]}...")  # First 20

    # Keep the keys that mention any of the encoder-related keywords
    encoder_metrics = [
        key for key in summary_keys
        if any(keyword in key.lower() for keyword in encoder_keywords)
    ]
    if encoder_metrics:
        print(f"  Encoder-related metrics: {encoder_metrics}")

## 3. Qualitative Examples

Show side-by-side panels for each example: Test Input → Ground Truth → TRM Prediction → ETRM Prediction

In [None]:
# ARC color palette: cell value (list position, 0-9) -> hex colour
ARC_COLORS = dict(enumerate([
    '#000000',  # 0: Black
    '#0074D9',  # 1: Blue
    '#FF4136',  # 2: Red
    '#2ECC40',  # 3: Green
    '#FFDC00',  # 4: Yellow
    '#AAAAAA',  # 5: Gray
    '#F012BE',  # 6: Magenta
    '#FF851B',  # 7: Orange
    '#7FDBFF',  # 8: Cyan
    '#870C25',  # 9: Brown
]))

def plot_grid(grid, ax, title=''):
    """Draw one ARC grid on `ax` using the ARC palette.

    Args:
        grid: 2-D array exposing ``.shape``; values are mapped onto the ten
            palette colours with a fixed 0..9 range.
        ax: Matplotlib axes to draw on.
        title: Title shown above the grid.
    """
    from matplotlib.colors import ListedColormap

    # Colours for values 0..9, in order
    palette = ListedColormap(list(map(ARC_COLORS.__getitem__, range(10))))

    ax.imshow(grid, cmap=palette, vmin=0, vmax=9)
    ax.set_title(title, fontsize=10)
    # Hide tick marks on both axes
    ax.set_xticks([])
    ax.set_yticks([])

    # Thin white separators on the half-integer boundaries between cells
    row_edges = [row - 0.5 for row in range(grid.shape[0] + 1)]
    for y in row_edges:
        ax.axhline(y, color='white', linewidth=0.5)
    col_edges = [col - 0.5 for col in range(grid.shape[1] + 1)]
    for x in col_edges:
        ax.axvline(x, color='white', linewidth=0.5)


def plot_example(demo_inputs, demo_outputs, test_input, test_output, 
                 trm_pred=None, etrm_pred=None, title=''):
    """Build a three-row figure for one puzzle.

    Row 1 shows the demo inputs, row 2 the demo outputs, and row 3 the test
    input, the ground truth and (when given) the TRM and ETRM predictions.
    Unused panels are switched off.

    Args:
        demo_inputs: Sequence of demo input grids.
        demo_outputs: Sequence of demo output grids, indexed like `demo_inputs`.
        test_input: Test input grid.
        test_output: Ground-truth grid for the test input.
        trm_pred: Optional TRM prediction grid.
        etrm_pred: Optional ETRM prediction grid.
        title: Figure-level title.

    Returns:
        The created matplotlib figure.
    """
    num_demos = len(demo_inputs)
    # The bottom row needs four panels even when there are fewer demos
    num_cols = max(num_demos, 4)

    fig, axes = plt.subplots(3, num_cols, figsize=(3 * num_cols, 9))

    # Rows 1 and 2: demo inputs, then demo outputs
    demo_rows = ((demo_inputs, 'Input'), (demo_outputs, 'Output'))
    for row, (grids, kind) in enumerate(demo_rows):
        for col in range(num_cols):
            if col >= num_demos:
                axes[row, col].axis('off')
                continue
            plot_grid(grids[col], axes[row, col], f'Demo {col+1} {kind}')

    # Row 3: test input and ground truth are always drawn
    bottom = axes[2]
    plot_grid(test_input, bottom[0], 'Test Input')
    plot_grid(test_output, bottom[1], 'Ground Truth')

    # Predictions are optional; a missing one leaves its panel blank
    prediction_panels = ((2, trm_pred, 'TRM Prediction'), (3, etrm_pred, 'ETRM Prediction'))
    for col, prediction, panel_title in prediction_panels:
        if prediction is None:
            bottom[col].axis('off')
        else:
            plot_grid(prediction, bottom[col], panel_title)

    # Any columns beyond the four used in the bottom row stay empty
    for spare_ax in bottom[4:]:
        spare_ax.axis('off')

    fig.suptitle(title, fontsize=14, fontweight='bold')
    plt.tight_layout()

    return fig

print("Plotting functions defined.")

In [None]:
# Load test puzzles for qualitative analysis
import json

test_puzzles_path = Path('../../data/arc1concept-aug-1000/test_puzzles.json')

# Set both names up front so that every run of this cell defines them, whether
# or not the file is found (no stale or missing `puzzle_names` on re-run).
test_puzzles = None
puzzle_names = []

if test_puzzles_path.exists():
    # JSON text is UTF-8; do not rely on the platform default encoding
    with open(test_puzzles_path, encoding='utf-8') as f:
        test_puzzles = json.load(f)
    print(f"Loaded {len(test_puzzles)} test puzzles")

    # Show first few puzzle names
    puzzle_names = list(test_puzzles.keys())[:10]
    print(f"First 10 puzzles: {puzzle_names}")
else:
    print(f"Test puzzles not found at {test_puzzles_path}")

In [None]:
# Plot a sample puzzle (without model predictions for now)
if test_puzzles:
    # Use the first listed puzzle as the sample
    puzzle_name = puzzle_names[0]
    puzzle = test_puzzles[puzzle_name]

    print(f"Puzzle: {puzzle_name}")
    print(f"  Demos: {len(puzzle['train'])}")
    print(f"  Test cases: {len(puzzle['test'])}")

    # Convert the nested lists of each pair into arrays for plotting
    demo_inputs = [np.array(pair['input']) for pair in puzzle['train']]
    demo_outputs = [np.array(pair['output']) for pair in puzzle['train']]
    first_test = puzzle['test'][0]  # only the first test case is shown
    test_input = np.array(first_test['input'])
    test_output = np.array(first_test['output'])

    # No predictions are passed, so those panels stay blank
    fig = plot_example(
        demo_inputs,
        demo_outputs,
        test_input,
        test_output,
        title=f'Puzzle: {puzzle_name}',
    )

    fig.savefig(OUTPUT_DIR / f'example_{puzzle_name}.png', dpi=150, bbox_inches='tight')
    plt.show()

## Summary

This notebook provides:
1. **Training curves** - Saved to `outputs/training_curves.png`
2. **Encoder collapse analysis** - Requires loading trained models (paths to be configured)
3. **Qualitative examples** - Puzzle visualization ready, model predictions need checkpoint loading

### Next Steps
- [ ] Configure correct checkpoint paths
- [ ] Run encoder variance analysis
- [ ] Generate predictions for qualitative examples