# In-Context Learning Model Evaluation

Comprehensive evaluation notebook for trained transformer models on linear regression tasks.

**Features:**
- Load trained models from checkpoint
- Full evaluation suite with multiple test scenarios
- Baseline comparisons (OLS, k-NN, Averaging)
- Visualization of in-context learning curves
- Out-of-distribution robustness tests (random quadrants, orthogonal, scaling)
- Works with Qwen2.5, GPT-2, and LSTM models

**Prerequisites:**
- A trained model checkpoint (from `train_colab.ipynb`)
- The model's `run_id` and output directory
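
Before running the notebook, it can help to confirm that the run directory actually contains the files the later cells read. The sketch below is a hypothetical helper (`missing_run_files` is not part of the project code); the file names `config.yaml` and `state.pt` are taken from the summary cell of this notebook and are assumed to sit directly under `./outputs/<run_id>`.

```python
import os

# Assumed layout: these file names are the ones printed in the summary cell.
EXPECTED_FILES = ("config.yaml", "state.pt")


def missing_run_files(run_path, expected=EXPECTED_FILES):
    """Return the expected files that are absent from run_path (hypothetical helper)."""
    return [
        name for name in expected
        if not os.path.isfile(os.path.join(run_path, name))
    ]


# Example usage (the run ID below is a placeholder):
# missing = missing_run_files(os.path.join("./outputs", "your-run-id"))
# if missing:
#     print("Missing files:", missing)
```

An empty list means both files were found; anything else is worth fixing before loading the model.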

## 1. Check GPU and Environment


In [None]:
# Check GPU availability
!nvidia-smi

import sys
print(f"\nPython version: {sys.version}")


## 2. Install Required Packages


In [None]:
print("Installing packages...\n")

# Core ML packages
# Quote the requirement so the shell does not treat ">=" as an output redirect
%pip install -q "transformers>=4.30.0"
%pip install -q xgboost
%pip install -q matplotlib seaborn tqdm pandas
%pip install -q pyyaml
%pip install -q munch
%pip install -q scikit-learn

# PyTorch usually comes pre-installed in Colab
try:
    import torch
    print(f"✓ PyTorch already installed: {torch.__version__}")
except ImportError:
    print("Installing PyTorch...")
    %pip install -q torch torchvision torchaudio

print("\n" + "="*60)
print("✓ All required packages installed successfully!")
print("="*60)

# Verify key packages
import torch
import transformers
import yaml
import matplotlib.pyplot as plt
import numpy as np

print(f"\nPackage Versions:")
print(f"  PyTorch: {torch.__version__}")
print(f"  Transformers: {transformers.__version__}")

print(f"\nGPU Information:")
print(f"  CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"  CUDA version: {torch.version.cuda}")
    print(f"  GPU: {torch.cuda.get_device_name(0)}")
    print(f"  GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("  ⚠️  No GPU detected! Enable GPU: Runtime → Change runtime type → T4 GPU")

print("\n✓ Ready for evaluation!")


## 3. Mount Google Drive (if using Drive for storage)


In [None]:
# Option 1: If your model is stored in Google Drive
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/MyDrive/in-context-learning

# Option 2: If using repository clone (uncomment if needed)
# import os
# import subprocess
# REPO_URL = "https://github.com/hingma/in-context-learning.git"  # UPDATE THIS!
# if not os.path.exists("in-context-learning"):
#     print(f"Cloning repository from {REPO_URL}...")
#     subprocess.run(["git", "clone", REPO_URL], check=True)
#     print("✓ Repository cloned successfully")
# %cd in-context-learning


## 4. Import Evaluation Modules


In [None]:
# Add src to path
import sys
import os
sys.path.insert(0, './src')

# Import required modules
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import json
from tqdm import tqdm
from munch import Munch

# Import project modules
from eval import (
    get_model_from_run, 
    get_run_metrics, 
    eval_model, 
    build_evals,
    baseline_names
)
from models import build_model, get_relevant_baselines
from tasks import get_task_sampler
from samplers import get_data_sampler

# Set plot style (matching eval.ipynb)
sns.set_theme('notebook', 'darkgrid')
palette = sns.color_palette('colorblind')

print("✓ All modules imported successfully")


## 5. Load Trained Model

**IMPORTANT:** Set `run_id` below to the run ID produced when you trained the model.


In [None]:
# ========================================
# UPDATE THIS WITH YOUR RUN ID!
# ========================================
run_id = "8a53116d-8c44-4687-9af4-bc8344eafbc7"  # Replace with your run_id from training
# ========================================

run_path = os.path.join("./outputs", run_id)

# Check if model exists
if not os.path.exists(run_path):
    raise FileNotFoundError(f"Model not found at {run_path}. Please check your run_id!")

print(f"Loading model from: {run_path}\n")

# Load model and config
model, conf = get_model_from_run(run_path, step=-1)  # step=-1 loads final checkpoint

# Move model to GPU if available
if torch.cuda.is_available():
    model = model.cuda()
    print(f"✓ Model loaded on GPU: {torch.cuda.get_device_name(0)}")
else:
    print("⚠️  Running on CPU (slower)")
    
model.eval()

# Display model configuration
print(f"\n{'='*60}")
print("Model Configuration:")
print(f"{'='*60}")
print(f"  Model Family: {conf.model.family}")
print(f"  Task: {conf.training.task}")
print(f"  n_dims: {conf.model.n_dims}")
print(f"  n_positions: {conf.model.n_positions}")
print(f"  n_embd: {conf.model.n_embd}")
print(f"  n_layer: {conf.model.n_layer}")
print(f"  n_head: {conf.model.n_head}")
print(f"  Training points: {conf.training.curriculum.points.end}")
print(f"{'='*60}")
print(f"\n✓ Model ready for evaluation")


## 6. Quick Test Evaluation

Run a quick evaluation to verify the model works correctly.
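
The cell below reports a "zero estimator" baseline equal to `n_dims`. A minimal sketch of where that number comes from, assuming isotropic Gaussian inputs and weights with noiseless targets (an assumption about the task, not something read from the project code): predicting 0 everywhere gives an expected squared error of E[(w·x)²] = `n_dims`. The dimension and sample count below are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
n_dims, n_samples = 20, 200_000  # illustrative values

# Assumption: w ~ N(0, I), x ~ N(0, I), and y = w . x with no noise.
w = rng.standard_normal((n_samples, n_dims))
x = rng.standard_normal((n_samples, n_dims))
y = np.sum(w * x, axis=1)

# The zero estimator always predicts 0, so its squared error is y**2.
zero_estimator_mse = float(np.mean(y ** 2))
print(f"{zero_estimator_mse:.2f} (expected to be close to {n_dims})")
```

If the task or data distribution in your config differs from this assumption, the `n_dims` reference line is only a rough guide.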


In [None]:
# Setup evaluation parameters from config
n_dims = conf.model.n_dims
n_points = conf.training.curriculum.points.end
batch_size = 64
task_name = conf.training.task
data_name = conf.training.data

# Create data and task samplers
data_sampler = get_data_sampler(data_name, n_dims=n_dims)
task_sampler = get_task_sampler(task_name, n_dims, batch_size)

print(f"Evaluation Setup:")
print(f"  Task: {task_name}")
print(f"  Data: {data_name}")
print(f"  n_dims: {n_dims}")
print(f"  n_points: {n_points}")
print(f"  batch_size: {batch_size}")

# Generate test data
task = task_sampler()
xs = data_sampler.sample_xs(n_points, batch_size)
ys = task.evaluate(xs)

print(f"\nGenerated test data:")
print(f"  xs shape: {xs.shape}  (batch_size, n_points, n_dims)")
print(f"  ys shape: {ys.shape}  (batch_size, n_points)")

# Get model predictions
# Use the device the model's parameters live on, so inputs always match the model
device = next(model.parameters()).device
with torch.no_grad():
    pred = model(xs.to(device), ys.to(device))
    pred = pred.cpu()

print(f"  pred shape: {pred.shape}")

# Compute loss
metric = task.get_metric()
loss = metric(pred, ys).numpy()

print(f"\n{'='*60}")
print("Quick Test Results:")
print(f"{'='*60}")
print(f"  Mean loss (all points): {loss.mean():.4f}")
print(f"  First point mean loss: {loss[:, 0].mean():.4f}")
print(f"  Final point mean loss: {loss[:, -1].mean():.4f}")
print(f"  Baseline (zero estimator): {n_dims:.4f}")
print(f"  Improvement: {(1 - loss[:, -1].mean() / n_dims) * 100:.1f}%")
print(f"{'='*60}")
print(f"\n✓ Quick evaluation complete")


## 7. Visualize In-Context Learning Curve

Visualize how the model's performance improves with more in-context examples.


In [None]:
# Plot learning curve (matching eval.ipynb style)
plt.figure(figsize=(10, 6))

# Calculate mean
mean_loss = loss.mean(axis=0)

# Plot model performance
plt.plot(mean_loss, lw=2, label=f"{conf.model.family} ({conf.model.n_layer}L)")

# Add baseline
plt.axhline(n_dims, ls="--", color="gray", label="zero estimator")

# Formatting
plt.xlabel("# in-context examples")
plt.ylabel("squared error")
plt.legend()
plt.show()

# Print summary statistics
print(f"\n{'='*60}")
print("Learning Curve Statistics:")
print(f"{'='*60}")
print(f"Baseline (zero estimator): {n_dims:.4f}")
print(f"Initial loss (1 example): {mean_loss[0]:.4f}")
print(f"Final loss ({n_points} examples): {mean_loss[-1]:.4f}")
print(f"Improvement over baseline: {(1 - mean_loss[-1] / n_dims) * 100:.1f}%")
print(f"Total improvement (first to last): {(1 - mean_loss[-1] / mean_loss[0]) * 100:.1f}%")
print(f"{'='*60}")


## 8. Comprehensive Evaluation with Baselines

Run full evaluation suite including baseline comparisons. This may take a few minutes.


In [None]:
print("Running comprehensive evaluation (this may take a few minutes)...")
print("\nNote: This will compute metrics for:")
print("  - Your trained model")
print("  - Baseline methods (OLS, k-NN, Averaging)")
print("  - Multiple test scenarios (standard, random quadrants, orthogonal, scaling, etc.)")

# Run comprehensive evaluation
# This will cache results in metrics.json
all_metrics = get_run_metrics(run_path, step=-1, cache=True, skip_model_load=False, skip_baselines=False)

print(f"\n✓ Comprehensive evaluation complete!")
print(f"  Results cached in: {os.path.join(run_path, 'metrics.json')}")


## 9. Compare with Baseline Methods

Visualize how your model compares to traditional baseline methods.
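
For intuition about the least-squares curve in this comparison, here is a minimal sketch of an in-context OLS baseline: the prediction for example `k` is fit only on examples `0..k-1`. This is an illustration written for this notebook, not the project's baseline implementation, and it assumes noiseless linear targets with Gaussian inputs.

```python
import numpy as np


def ols_in_context_errors(xs, ys):
    """Squared error of least squares at each position.

    xs has shape (n_points, n_dims) and ys has shape (n_points,).
    The prediction for point k uses only points 0..k-1; with no
    context at all, the prediction is 0.
    """
    n_points = xs.shape[0]
    errors = np.empty(n_points)
    for k in range(n_points):
        if k == 0:
            pred = 0.0
        else:
            # Minimum-norm least-squares fit on the first k examples.
            w_hat, *_ = np.linalg.lstsq(xs[:k], ys[:k], rcond=None)
            pred = xs[k] @ w_hat
        errors[k] = (pred - ys[k]) ** 2
    return errors


rng = np.random.default_rng(0)
n_dims, n_points = 5, 12  # illustrative sizes
w = rng.standard_normal(n_dims)
xs = rng.standard_normal((n_points, n_dims))
ys = xs @ w
errors = ols_in_context_errors(xs, ys)
```

Under these assumptions the error drops to numerical zero once the context holds at least `n_dims` examples, which is the behavior a well-trained model is usually compared against.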


In [None]:
# Extract standard evaluation results
standard_metrics = all_metrics.get("standard", {})

if not standard_metrics:
    print("No standard metrics found!")
else:
    # Create comparison plot (matching eval.ipynb style)
    fig, ax = plt.subplots(figsize=(10, 6))
    
    # Define model name
    model_name = model.name
    
    # Plot model performance with colorblind-friendly palette
    colors = sns.color_palette('colorblind')
    
    # Plot model performance
    if model_name in standard_metrics:
        means = standard_metrics[model_name]["mean"]
        ax.plot(means, lw=2, label=f"{conf.model.family} ({conf.model.n_layer}L)", 
                color=colors[0])
    
    # Plot baseline models
    baseline_keys = ['OLS_driver=None', 'averaging', 'NN_n=3_uniform']
    baseline_labels_map = {
        'OLS_driver=None': 'Least Squares',
        'averaging': 'Averaging',
        'NN_n=3_uniform': '3-NN',
    }
    
    for idx, baseline_key in enumerate(baseline_keys, 1):
        if baseline_key in standard_metrics:
            means = standard_metrics[baseline_key]["mean"]
            label = baseline_labels_map.get(baseline_key, baseline_key)
            ax.plot(means, lw=2, label=label, linestyle='--', color=colors[idx])
    
    # Formatting
    ax.set_xlabel("# in-context examples")
    ax.set_ylabel("squared error")
    ax.legend()
    ax.set_ylim(bottom=0)
    
    plt.show()
    
    # Print comparison table
    print(f"\n{'='*80}")
    print("PERFORMANCE COMPARISON (Final Point Loss)")
    print(f"{'='*80}")
    print(f"{'Method':<40} {'Final Loss':<15} {'vs Baseline':<15}")
    print(f"{'-'*80}")
    
    baseline_loss = n_dims  # Zero estimator
    
    for method_name, metrics in standard_metrics.items():
        final_loss = metrics["mean"][-1]
        improvement = (1 - final_loss / baseline_loss) * 100
        
        # Clean up method name
        if method_name == model_name:
            display_name = f"{conf.model.family.upper()} (Your Model)"
        else:
            display_name = baseline_names(method_name)
        
        print(f"{display_name:<40} {final_loss:<15.4f} {improvement:>+12.1f}%")
    
    print(f"{'-'*80}")
    print(f"{'Baseline (Zero estimator)':<40} {baseline_loss:<15.4f} {'0.0%':>15}")
    print(f"{'='*80}")


## 10. Out-of-Distribution Robustness Tests

Evaluate model performance on various out-of-distribution scenarios.


In [None]:
# Plot OOD scenarios individually with data-driven axis limits
# This creates separate plots for each scenario with appropriate scaling

ood_scenarios = [
    'random_quadrants',
    'orthogonal_train_test',
    'overlapping_train_test',
    'half_subspace',
    'skewed'
]

model_name = model.name
colors = sns.color_palette('colorblind')

# Baseline methods to include
baseline_keys = ['OLS_driver=None', 'averaging', 'NN_n=3_uniform']
baseline_labels_map = {
    'OLS_driver=None': 'Least Squares',
    'averaging': 'Averaging',
    'NN_n=3_uniform': '3-Nearest Neighbors',
}

print("Plotting OOD scenarios (each with baseline comparisons):\n")

for scenario in ood_scenarios:
    if scenario not in all_metrics:
        continue
    
    metric = all_metrics[scenario]
    
    # Determine scale factor for axis limits
    if "scale" in scenario:
        scale = float(scenario.split("=")[-1])**2
    else:
        scale = 1.0
    
    # Create plot with adaptive size based on data range
    fig, ax = plt.subplots(figsize=(7, 5))
    
    # Collect all values to determine proper axis limits
    all_values = []
    
    # Plot model performance
    if model_name in metric:
        means = metric[model_name]["mean"]
        ax.plot(means, lw=2, label=f"{conf.model.family}", color=colors[0])
        all_values.extend(means)
    
    # Plot baselines
    for idx, baseline_key in enumerate(baseline_keys, 1):
        if baseline_key in metric:
            means = metric[baseline_key]["mean"]
            label = baseline_labels_map.get(baseline_key, baseline_key)
            ax.plot(means, lw=2, label=label, linestyle='--', color=colors[idx])
            all_values.extend(means)
    
    # Add trivial baseline
    trivial = 1.0 * scale
    ax.axhline(trivial, ls="--", color="gray", alpha=0.5)
    
    # Set title
    ax.set_title(scenario, fontsize=12)
    
    # Skip scenarios where neither the model nor any baseline was found;
    # otherwise `means`, `y_min`, and `y_max` below would be undefined or stale
    if not all_values:
        plt.close(fig)
        print(f"  (skipped {scenario}: no curves found)")
        continue

    # Set axis limits based on actual data
    y_min = min(all_values)
    y_max = max(all_values)
    y_range = y_max - y_min

    # Add padding (10% on bottom, 15% on top for legend space)
    ax.set_ylim(max(-0.1 * scale, y_min - 0.1 * y_range),
                y_max + 0.15 * y_range)

    # X-axis: zoom in for orthogonal tests (most interesting in first n_dims examples)
    if "ortho" in scenario:
        ax.set_xlim(-1, min(n_dims, len(means)) - 1)
    else:
        ax.set_xlim(-1, len(means))

    # Labels
    ax.set_xlabel("in-context examples", fontsize=11)
    ax.set_ylabel("squared error", fontsize=11)
    ax.legend(loc="best", fontsize=9)

    plt.tight_layout()
    plt.show()
    print(f"  ✓ {scenario}: y-range [{y_min:.2f}, {y_max:.2f}]")

print("\nNote: Axis limits are automatically adjusted based on actual data ranges.")


## 11. Input/Output Scaling Tests

Test model robustness to different input (x) and output (y) scales.
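
The scaling plots draw their gray reference line at `scale**2`. A minimal sketch of why the squared error of a trivial predictor grows quadratically, assuming noiseless linear targets `y = w · x` (an assumption about the task): multiplying the inputs or the outputs by `s` multiplies the targets by `s`, and therefore the squared error by `s**2`. The function name and sizes are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
n_dims, batch = 20, 1000  # illustrative sizes
w = rng.standard_normal((batch, n_dims))
x = rng.standard_normal((batch, n_dims))


def zero_estimator_error(x, w, x_scale=1.0, y_scale=1.0):
    """Mean squared target when inputs and/or outputs are rescaled."""
    y = y_scale * np.sum(w * (x_scale * x), axis=1)
    return float(np.mean(y ** 2))


base = zero_estimator_error(x, w)
# Ratio of scaled to unscaled error for each scale factor used in this section.
x_ratios = {s: zero_estimator_error(x, w, x_scale=s) / base for s in (0.333, 0.5, 2, 3)}
y_ratios = {s: zero_estimator_error(x, w, y_scale=s) / base for s in (0.333, 0.5, 2, 3)}
```

Both dictionaries map each scale `s` to `s**2` up to floating-point error, which is why errors from scaled runs are often divided by `s**2` before being compared with the standard run.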


In [None]:
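# Plot scaling tests individually for better visibility
scales = [0.333, 0.5, 2, 3]

# Final-point loss on the standard (in-distribution) evaluation.
# Defined here because the summary plots and tables below use it as a reference.
standard_loss = None
if "standard" in all_metrics and model_name in all_metrics["standard"]:
    standard_loss = all_metrics["standard"][model_name]["mean"][-1]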

# Process and plot X-scaling
print("\nX-Scaling Robustness Tests:")
x_scaling_data = {}
for scale in scales:
    key = f"scale-x={scale}"
    if key in all_metrics and model_name in all_metrics[key]:
        x_scaling_data[scale] = {
            'final': all_metrics[key][model_name]["mean"][-1],
            'full': all_metrics[key][model_name]["mean"]
        }

# Process and plot Y-scaling  
print("\nY-Scaling Robustness Tests:")
y_scaling_data = {}
for scale in scales:
    key = f"scale-y={scale}"
    if key in all_metrics and model_name in all_metrics[key]:
        y_scaling_data[scale] = {
            'final': all_metrics[key][model_name]["mean"][-1],
            'full': all_metrics[key][model_name]["mean"]
        }

# Plot each scaling scenario separately with proper axis limits
for dim, scaling_data in [('x', x_scaling_data), ('y', y_scaling_data)]:
    if not scaling_data:
        continue
    
    print(f"\n{dim.upper()}-Scaling scenarios:")
    
    for scale, data in sorted(scaling_data.items()):
        key = f"scale-{dim}={scale}"
        if key not in all_metrics:
            continue
        
        metric = all_metrics[key]
        
        # Create individual plot for each scale
        fig, ax = plt.subplots(figsize=(7, 5))
        
        colors = sns.color_palette('colorblind')
        all_values = []
        
        # Plot model
        if model_name in metric:
            means = metric[model_name]["mean"]
            ax.plot(means, lw=2, label=f"{conf.model.family}", color=colors[0])
            all_values.extend(means)
        
        # Plot baselines
        for idx, baseline_key in enumerate(['OLS_driver=None', 'averaging', 'NN_n=3_uniform'], 1):
            if baseline_key in metric:
                means = metric[baseline_key]["mean"]
                label = baseline_labels_map.get(baseline_key, baseline_key)
                ax.plot(means, lw=2, label=label, linestyle='--', color=colors[idx])
                all_values.extend(means)
        
        # Add trivial baseline (scaled)
        trivial = scale**2  # the reference level is the same for x- and y-scaling
        ax.axhline(trivial, ls="--", color="gray", alpha=0.5)
        
        # Set title
        ax.set_title(f"{key} (scale factor = {scale})", fontsize=12)
        
        # Set axis limits based on data
        if all_values:
            y_min = min(all_values)
            y_max = max(all_values)
            y_range = y_max - y_min
            ax.set_ylim(y_min - 0.1 * y_range, y_max + 0.15 * y_range)
        
        ax.set_xlim(-1, len(metric[model_name]["mean"]))
        ax.set_xlabel("in-context examples", fontsize=11)
        ax.set_ylabel("squared error", fontsize=11)
        ax.legend(loc="best", fontsize=9)
        
        plt.tight_layout()
        plt.show()
        print(f"  ✓ {key}: final loss = {data['final']:.2f}, y-range [{y_min:.2f}, {y_max:.2f}]")

# Summary plot: final losses vs scale factor
if x_scaling_data or y_scaling_data:
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    
    # X-scaling summary
    if x_scaling_data:
        x_scales = sorted(x_scaling_data.keys())
        x_losses = [x_scaling_data[s]['final'] for s in x_scales]
        
        ax1.plot(x_scales, x_losses, lw=2, marker='o', markersize=8, label='model')
        if standard_loss:
            ax1.axhline(standard_loss, color='gray', linestyle='--', 
                       label=f'standard: {standard_loss:.3f}')
        
        ax1.set_xlabel('input (x) scale factor', fontsize=12)
        ax1.set_ylabel('final point squared error', fontsize=12)
        ax1.set_title('X-Scaling: Final Point Loss', fontsize=13)
        ax1.set_xscale('log')
        ax1.legend(fontsize=10)
        ax1.grid(True, alpha=0.3)
    
    # Y-scaling summary
    if y_scaling_data:
        y_scales = sorted(y_scaling_data.keys())
        y_losses = [y_scaling_data[s]['final'] for s in y_scales]
        
        ax2.plot(y_scales, y_losses, lw=2, marker='o', markersize=8, label='model')
        if standard_loss:
            ax2.axhline(standard_loss, color='gray', linestyle='--', 
                       label=f'standard: {standard_loss:.3f}')
        
        ax2.set_xlabel('output (y) scale factor', fontsize=12)
        ax2.set_ylabel('final point squared error', fontsize=12)
        ax2.set_title('Y-Scaling: Final Point Loss', fontsize=13)
        ax2.set_xscale('log')
        ax2.legend(fontsize=10)
        ax2.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    print("\n✓ Scaling summary plots completed")
    
    # Print scaling analysis
    print(f"\n{'='*80}")
    print("SCALING ROBUSTNESS ANALYSIS")
    print(f"{'='*80}")
    
    if x_scaling_data:
        print(f"\nInput (X) Scaling:")
        print(f"{'  Scale':<12} {'Final Loss':<15} {'vs Standard':<15}")
        print(f"  {'-'*42}")
        for scale in sorted(x_scaling_data.keys()):
            loss = x_scaling_data[scale]['final']
            diff = ((loss / standard_loss - 1) * 100) if standard_loss else 0
            print(f"  {scale:<12.3f} {loss:<15.4f} {diff:>+12.1f}%")
    
    if y_scaling_data:
        print(f"\nOutput (Y) Scaling:")
        print(f"{'  Scale':<12} {'Final Loss':<15} {'vs Standard':<15}")
        print(f"  {'-'*42}")
        for scale in sorted(y_scaling_data.keys()):
            loss = y_scaling_data[scale]['final']
            diff = ((loss / standard_loss - 1) * 100) if standard_loss else 0
            print(f"  {scale:<12.3f} {loss:<15.4f} {diff:>+12.1f}%")
    
    print(f"\n{'='*80}")
else:
    print("\n⚠️ No scaling test results found.")


## 12. Complete Evaluation Summary

Generate a comprehensive summary of all evaluation results.
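
The summary cell labels each out-of-distribution scenario by how far its final loss moves from the standard loss: under 15% is marked as fine, under 30% as a warning, and anything larger as a failure. The sketch below restates that rule as two small functions; the function names and the plain-text labels are illustrative and are not part of the notebook.

```python
def degradation_pct(loss, standard_loss):
    """Relative change of loss versus the standard loss, in percent."""
    return (loss / standard_loss - 1) * 100


def degradation_status(pct, ok_below=15.0, warn_below=30.0):
    """Classify a percentage change using the thresholds from the summary cell."""
    magnitude = abs(pct)
    if magnitude < ok_below:
        return "ok"
    if magnitude < warn_below:
        return "warning"
    return "degraded"


# Example: a scenario whose loss is 20% above the standard loss.
example_status = degradation_status(degradation_pct(1.2, 1.0))
```

Because the rule uses the absolute value, a scenario that is much easier than the standard one is flagged in the same way as one that is much harder.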


In [None]:
# Generate comprehensive summary
print(f"\n{'='*80}")
print(f"{'='*80}")
print(f"COMPLETE EVALUATION SUMMARY")
print(f"{'='*80}")
print(f"{'='*80}")

print(f"\nüìã MODEL INFORMATION:")
print(f"  Run ID: {run_id}")
print(f"  Model Family: {conf.model.family}")
print(f"  Architecture: {conf.model.n_layer} layers, {conf.model.n_head} heads, {conf.model.n_embd} embedding dim")
print(f"  Task: {conf.training.task}")
print(f"  Training dimensions: {conf.model.n_dims}")
print(f"  Context length: {conf.model.n_positions}")

print(f"\nüìä PERFORMANCE METRICS:")
if "standard" in all_metrics and model_name in all_metrics["standard"]:
    final_loss = all_metrics["standard"][model_name]["mean"][-1]
    standard_loss = final_loss
    improvement = (1 - final_loss / n_dims) * 100
    print(f"  Standard (in-distribution) loss: {final_loss:.4f}")
    print(f"  Improvement over baseline: {improvement:.1f}%")
else:
    standard_loss = None

print(f"\nüéØ BASELINE COMPARISONS:")
if "standard" in all_metrics:
    for baseline_key in ['OLS_driver=None', 'averaging', 'NN_n=3_uniform']:
        if baseline_key in all_metrics["standard"]:
            baseline_loss = all_metrics["standard"][baseline_key]["mean"][-1]
            baseline_name = baseline_names(baseline_key)
            if model_name in all_metrics["standard"]:
                model_loss = all_metrics["standard"][model_name]["mean"][-1]
                advantage = ((baseline_loss - model_loss) / baseline_loss * 100)
                symbol = "✅" if advantage > 0 else "⚠️"
                print(f"  {symbol} vs {baseline_name}: {advantage:+.1f}% {'better' if advantage > 0 else 'worse'}")

print(f"\nüåê OUT-OF-DISTRIBUTION ROBUSTNESS:")
# Collect OOD results for summary
ood_scenarios_summary = ['random_quadrants', 'orthogonal_train_test', 'overlapping_train_test', 'half_subspace', 'skewed']
ood_results = {}
for scenario in ood_scenarios_summary:
    if scenario in all_metrics and model_name in all_metrics[scenario]:
        ood_results[scenario] = all_metrics[scenario][model_name]["mean"][-1]

if ood_results:
    for scenario, loss in ood_results.items():
        clean_name = scenario.replace('_', ' ').title()
        if standard_loss:
            degradation = ((loss / standard_loss - 1) * 100)
            status = "✅" if abs(degradation) < 15 else "⚠️" if abs(degradation) < 30 else "❌"
            print(f"  {status} {clean_name}: {loss:.4f} ({degradation:+.1f}%)")
        else:
            print(f"  • {clean_name}: {loss:.4f}")

print(f"\n⚖️  SCALING ROBUSTNESS:")
# Collect scaling data for summary
x_scaling_summary = {}
y_scaling_summary = {}
for scale in [0.333, 0.5, 2, 3]:
    key_x = f"scale-x={scale}"
    key_y = f"scale-y={scale}"
    if key_x in all_metrics and model_name in all_metrics[key_x]:
        x_scaling_summary[scale] = all_metrics[key_x][model_name]["mean"][-1]
    if key_y in all_metrics and model_name in all_metrics[key_y]:
        y_scaling_summary[scale] = all_metrics[key_y][model_name]["mean"][-1]

if x_scaling_summary or y_scaling_summary:
    if x_scaling_summary:
        max_x_deg = max([abs((loss / standard_loss - 1) * 100) for loss in x_scaling_summary.values()]) if standard_loss else 0
        print(f"  Input (X) scaling: max degradation {max_x_deg:.1f}%")
    if y_scaling_summary:
        max_y_deg = max([abs((loss / standard_loss - 1) * 100) for loss in y_scaling_summary.values()]) if standard_loss else 0
        print(f"  Output (Y) scaling: max degradation {max_y_deg:.1f}%")

print(f"\nüíæ SAVED RESULTS:")
print(f"  Metrics file: {os.path.join(run_path, 'metrics.json')}")
print(f"  Model checkpoint: {os.path.join(run_path, 'state.pt')}")
print(f"  Configuration: {os.path.join(run_path, 'config.yaml')}")

print(f"\n{'='*80}")
print(f"✅ EVALUATION COMPLETE!")
print(f"{'='*80}")
print(f"\nAll results have been cached and can be reloaded without recomputing.")
print(f"To share results, download the metrics.json file from: {run_path}")


## 13. Export/Download Results (Optional)

Download evaluation results to your local machine.


In [None]:
# Uncomment to download files when running on Colab
# from google.colab import files

# Download metrics
# metrics_file = os.path.join(run_path, "metrics.json")
# if os.path.exists(metrics_file):
#     print(f"Downloading {metrics_file}...")
#     files.download(metrics_file)
#     print("✓ Metrics downloaded")

# Download config
# config_file = os.path.join(run_path, "config.yaml")
# if os.path.exists(config_file):
#     print(f"Downloading {config_file}...")
#     files.download(config_file)
#     print("✓ Config downloaded")

# Create a summary report
summary_file = os.path.join(run_path, "evaluation_summary.txt")
with open(summary_file, 'w') as f:
    f.write("="*80 + "\n")
    f.write("EVALUATION SUMMARY\n")
    f.write("="*80 + "\n\n")
    
    f.write(f"Model: {conf.model.family}\n")
    f.write(f"Task: {conf.training.task}\n")
    f.write(f"Run ID: {run_id}\n\n")
    
    if "standard" in all_metrics and model_name in all_metrics["standard"]:
        final_loss = all_metrics["standard"][model_name]["mean"][-1]
        f.write(f"Standard Loss: {final_loss:.4f}\n")
        f.write(f"Improvement: {(1 - final_loss / n_dims) * 100:.1f}%\n\n")
    
    f.write("Evaluation Scenarios:\n")
    for scenario in all_metrics.keys():
        if model_name in all_metrics[scenario]:
            loss = all_metrics[scenario][model_name]["mean"][-1]
            f.write(f"  - {scenario}: {loss:.4f}\n")

print(f"\n✓ Summary saved to: {summary_file}")

# Uncomment to download summary
# files.download(summary_file)

print("\nüí° Tip: Run this notebook anytime to re-evaluate or visualize cached results!")


## 14. Custom Robustness Test Example

This example tests how robust the model is to input scaling (similar to `eval.ipynb`) by doubling the inputs and comparing the resulting error curve with the standard one:


In [None]:
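# Test with doubled inputs
xs2 = 2 * xs
ys2 = task.evaluate(xs2)

with torch.no_grad():
    pred2 = model(xs2.to(device), ys2.to(device))
    pred2 = pred2.cpu()

# Recompute both losses here: earlier cells reuse the names `metric` and `loss`
# for other values, so they no longer hold the metric function and loss array
metric_fn = task.get_metric()
loss_std = metric_fn(pred, ys).numpy()
loss2 = metric_fn(pred2, ys2).numpy()

# Plot comparison
plt.figure(figsize=(10, 6))
plt.plot(loss_std.mean(axis=0), lw=2, label="standard inputs")
# Divide by 4 (= 2**2) so the doubled-input error is on the same scale
plt.plot(loss2.mean(axis=0) / 4, lw=2, label="doubled inputs (error / 4)")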
plt.axhline(n_dims, ls="--", color="gray", label="zero estimator")
plt.xlabel("# in-context examples")
plt.ylabel("squared error")
plt.legend()
plt.show()

print("\nThe error may increase with doubled inputs, especially when the number")
print("of in-context examples exceeds the dimension, but the model should")
print("still remain relatively accurate.")


## Additional Notes

**Plotting with Data-Driven Axis Limits**: The plots in this notebook follow these conventions:
- **Seaborn 'notebook' theme** with 'darkgrid' style and colorblind-friendly palette
- **Data-driven axis limits**: Each plot automatically adjusts its y-axis based on actual value ranges
- **Individual scenario plots**: Each test scenario (OOD, scaling) gets its own optimized plot
- **Smart zooming**: 
  - Orthogonal tests focus on the first `n_dims` examples (most informative region)
  - Scaling tests show full learning curves with proper scale adjustments
  - All plots ensure curves are fully visible without excessive whitespace
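
The axis-limit rule used in the plotting cells (10% padding below the data, 15% above to leave room for the legend) can be sketched as a small standalone function. `padded_ylim` and its parameter names are illustrative, not part of the notebook; the fallback for a completely flat curve is an added assumption, since the notebook cells do not handle that case.

```python
def padded_ylim(values, bottom_pad=0.10, top_pad=0.15, floor=None):
    """Return (lower, upper) y-limits padded around the data range.

    floor, if given, is a lower bound for the bottom limit (the OOD cell
    uses a small negative value so the axis never dips far below zero).
    """
    if not values:
        raise ValueError("values must be non-empty")
    y_min, y_max = min(values), max(values)
    y_range = y_max - y_min
    if y_range == 0:
        # Assumption: for a flat curve, pad relative to its magnitude instead.
        y_range = abs(y_max) or 1.0
    lower = y_min - bottom_pad * y_range
    if floor is not None:
        lower = max(floor, lower)
    return lower, y_max + top_pad * y_range


# Example usage with a matplotlib axis (not executed here):
# ax.set_ylim(*padded_ylim(all_values, floor=-0.1))
```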

**Key Features**:
1. **Cached Results**: All metrics are saved in `metrics.json` and can be reloaded without recomputation
2. **Comprehensive Tests**: Standard, OOD (with adaptive views), and scaling robustness evaluations  
3. **Baseline Comparisons**: The standard, OOD, and per-scale plots include OLS, k-NN, and averaging baselines
4. **Adaptive Visualization**: Handles diverse value ranges (from 0.2 to 80+) automatically
5. **Custom Tests**: Easy to add your own robustness tests (see doubled inputs example)
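
As a sketch of working with the cached results outside this notebook: the cells above index the metrics as `all_metrics[scenario][model_name]["mean"]`, so the final-point loss per scenario can be read straight from `metrics.json`. The helper names below are hypothetical, and the code assumes the JSON file has the same nesting as the in-memory dictionary.

```python
import json


def final_losses(all_metrics, model_name):
    """Map each scenario to the model's loss at the last in-context position."""
    return {
        scenario: results[model_name]["mean"][-1]
        for scenario, results in all_metrics.items()
        if model_name in results
    }


def load_final_losses(metrics_path, model_name):
    """Read a cached metrics file and summarize it with final_losses."""
    with open(metrics_path) as f:
        return final_losses(json.load(f), model_name)


# Example usage (path and model name are placeholders):
# losses = load_final_losses("./outputs/your-run-id/metrics.json", "your_model_name")
```

Scenarios that do not contain the model's key are skipped, which mirrors the `model_name in all_metrics[scenario]` checks in the summary cells.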

**Understanding the Plots**:
- **Standard evaluation**: Shows in-context learning across all examples
- **OOD scenarios**: Individual plots with data-driven axis limits, printed y-ranges for reference
- **Scaling tests**: Both individual learning curves AND summary plots comparing final losses
- **Gray dashed line**: Shows the trivial baseline (adjusted for scaling scenarios)
- **Value ranges**: Automatically printed for each scenario to understand data scale

**Plot Optimization Examples**:
- `random_quadrants`: y-range [~9, ~20] - model doesn't converge as well
- `half_subspace`: y-range [~0.5, ~20] - excellent convergence
- `scale-x=2`: y-range [~4, ~83] - scaled values, still shows convergence pattern
- Each plot adapts to show the full story without clipping

**Next Steps**:
- Compare different model architectures (Qwen2.5 vs GPT-2)
- Test on different tasks (sparse linear regression, decision trees, etc.)
- Experiment with different training curricula
- Analyze learned representations or attention patterns
- Use printed y-ranges to quickly identify problematic scenarios