# In-Context Learning Model Evaluation

Comprehensive evaluation notebook for trained sequence models (transformers and LSTMs) on in-context linear regression tasks.

**Features:**
- Load trained models from checkpoint
- Full evaluation suite with multiple test scenarios
- Baseline comparisons (OLS, k-NN, Averaging)
- Visualization of in-context learning curves
- Out-of-distribution robustness tests (random quadrants, orthogonal, scaling)
- Works with Qwen2.5, GPT-2, and LSTM models

**Prerequisites:**
- A trained model checkpoint (from train_colab.ipynb)
- Model's run_id and output directory

## 1. Check GPU and Environment


In [None]:
# Check GPU availability
!nvidia-smi

import sys
print(f"\nPython version: {sys.version}")


## 2. Install Required Packages


In [None]:
print("Installing packages...\n")

# Core ML packages
%pip install -q transformers>=4.30.0
%pip install -q xgboost
%pip install -q matplotlib seaborn tqdm pandas
%pip install -q pyyaml
%pip install -q munch
%pip install -q scikit-learn

# PyTorch usually comes pre-installed in Colab
try:
    import torch
    print(f"‚úì PyTorch already installed: {torch.__version__}")
except ImportError:
    print("Installing PyTorch...")
    %pip install -q torch torchvision torchaudio

print("\n" + "="*60)
print("‚úì All required packages installed successfully!")
print("="*60)

# Verify key packages
import torch
import transformers
import yaml
import matplotlib.pyplot as plt
import numpy as np

print(f"\nPackage Versions:")
print(f"  PyTorch: {torch.__version__}")
print(f"  Transformers: {transformers.__version__}")

print(f"\nGPU Information:")
print(f"  CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"  CUDA version: {torch.version.cuda}")
    print(f"  GPU: {torch.cuda.get_device_name(0)}")
    print(f"  GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("  ‚ö†Ô∏è  No GPU detected! Enable GPU: Runtime ‚Üí Change runtime type ‚Üí T4 GPU")

print("\n‚úì Ready for evaluation!")


## 3. Mount Google Drive (if using Drive for storage)


In [None]:
# Make the project checkout the current working directory; later cells use the
# relative paths './src' and './outputs'. Use exactly one of the two options.

# Option 1: If your model is stored in Google Drive
# Mounts Drive at /content/drive, then changes into the project folder there
# (adjust the path if your copy lives elsewhere in MyDrive).
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/MyDrive/in-context-learning

# Option 2: If using repository clone (uncomment if needed)
# Comment out Option 1 first; the clone is created relative to the current
# working directory.
# import os
# import subprocess
# REPO_URL = "https://github.com/hingma/in-context-learning.git"  # UPDATE THIS!
# if not os.path.exists("in-context-learning"):
#     print(f"Cloning repository from {REPO_URL}...")
#     subprocess.run(["git", "clone", REPO_URL], check=True)
#     print("✓ Repository cloned successfully")
# %cd in-context-learning


## 4. Import Evaluation Modules


In [None]:
# Make the project's ./src directory importable. The path is made absolute so
# a later working-directory change does not break it, and the membership check
# keeps re-runs of this cell from stacking duplicate sys.path entries.
import sys
import os
_src_dir = os.path.abspath('./src')
if _src_dir not in sys.path:
    sys.path.insert(0, _src_dir)

# Import required modules
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import json
from tqdm import tqdm
from munch import Munch

# Import project modules (resolved from ./src)
from eval import (
    get_model_from_run,
    get_run_metrics,
    eval_model,
    build_evals,
    baseline_names
)
from models import build_model, get_relevant_baselines
from tasks import get_task_sampler
from samplers import get_data_sampler

# Set plot style. The 'seaborn-v0_8-*' names only exist in matplotlib >= 3.6;
# older releases ship the same style as 'seaborn-darkgrid'. Fall back to the
# default style rather than failing when neither name is available.
_plot_style = next(
    (name for name in ('seaborn-v0_8-darkgrid', 'seaborn-darkgrid')
     if name in plt.style.available),
    'default',
)
plt.style.use(_plot_style)
sns.set_palette("husl")

print("✓ All modules imported successfully")


## 5. Load Trained Model

**IMPORTANT:** Update the `run_id` below with your trained model's run ID from training.


In [None]:
# ========================================
# UPDATE THIS WITH YOUR RUN ID!
# ========================================
run_id = "8a53116d-8c44-4687-9af4-bc8344eafbc7"  # Replace with your run_id from training
# ========================================

# This notebook expects each training run under ./outputs/<run_id>.
run_path = os.path.join("./outputs", run_id)

# Fail early with a clear message instead of a deep error from the loader.
if not os.path.exists(run_path):
    raise FileNotFoundError(f"Model not found at {run_path}. Please check your run_id!")

print(f"Loading model from: {run_path}\n")

# Load model and config
model, conf = get_model_from_run(run_path, step=-1)  # step=-1 loads final checkpoint

# Move model to GPU if available
if torch.cuda.is_available():
    model = model.cuda()
    print(f"✓ Model loaded on GPU: {torch.cuda.get_device_name(0)}")
else:
    print("⚠️  Running on CPU (slower)")

# Inference mode for the rest of the notebook.
model.eval()

# Display model configuration
print(f"\n{'='*60}")
print("Model Configuration:")
print(f"{'='*60}")
print(f"  Model Family: {conf.model.family}")
print(f"  Task: {conf.training.task}")
print(f"  n_dims: {conf.model.n_dims}")
print(f"  n_positions: {conf.model.n_positions}")
print(f"  n_embd: {conf.model.n_embd}")
print(f"  n_layer: {conf.model.n_layer}")
print(f"  n_head: {conf.model.n_head}")
print(f"  Training points: {conf.training.curriculum.points.end}")
print(f"{'='*60}")
print("\n✓ Model ready for evaluation")


## 6. Quick Test Evaluation

Run a quick evaluation to verify the model works correctly.


In [None]:
# Setup evaluation parameters from the training config so the quick test uses
# the same task, data distribution and context length the model was trained on.
n_dims = conf.model.n_dims
n_points = conf.training.curriculum.points.end
batch_size = 64
task_name = conf.training.task
data_name = conf.training.data

# Create data and task samplers
data_sampler = get_data_sampler(data_name, n_dims=n_dims)
task_sampler = get_task_sampler(task_name, n_dims, batch_size)

print("Evaluation Setup:")
print(f"  Task: {task_name}")
print(f"  Data: {data_name}")
print(f"  n_dims: {n_dims}")
print(f"  n_points: {n_points}")
print(f"  batch_size: {batch_size}")

# Generate test data: sample inputs, then label them with the sampled task.
task = task_sampler()
xs = data_sampler.sample_xs(n_points, batch_size)
ys = task.evaluate(xs)

print("\nGenerated test data:")
print(f"  xs shape: {xs.shape}  (batch_size, n_points, n_dims)")
print(f"  ys shape: {ys.shape}  (batch_size, n_points)")

# Run the model on the device its weights actually live on. The loading cell
# moves the model to GPU whenever CUDA is available, independent of the model
# family, so guessing the device from a name prefix could send inputs to the
# CPU while the weights sit on the GPU (a device-mismatch error).
device = next(model.parameters()).device
with torch.no_grad():
    pred = model(xs.to(device), ys.to(device)).cpu()

print(f"  pred shape: {pred.shape}")

# Per-example, per-position error from the task's own metric.
metric = task.get_metric()
loss = metric(pred, ys).numpy()

# NOTE(review): using n_dims as the zero-estimator error assumes standard
# Gaussian inputs and weights (linear regression) — confirm for other tasks.
print(f"\n{'='*60}")
print("Quick Test Results:")
print(f"{'='*60}")
print(f"  Mean loss (all points): {loss.mean():.4f}")
print(f"  First point mean loss: {loss[:, 0].mean():.4f}")
print(f"  Final point mean loss: {loss[:, -1].mean():.4f}")
print(f"  Baseline (zero estimator): {n_dims:.4f}")
print(f"  Improvement: {(1 - loss[:, -1].mean() / n_dims) * 100:.1f}%")
print(f"{'='*60}")
print("\n✓ Quick evaluation complete")


## 7. Visualize In-Context Learning Curve

Visualize how the model's performance improves with more in-context examples.


In [None]:
# Learning curve: mean squared error at each context position (averaged over
# the batch), a ±1 std band, and the zero-estimator reference line.
fig, ax = plt.subplots(figsize=(12, 6))
positions = range(n_points)
curve_color = '#2E86AB'

# Per-position statistics across the batch dimension (axis 0).
mean_loss = loss.mean(axis=0)
std_loss = loss.std(axis=0)

ax.plot(positions, mean_loss, lw=2.5,
        label=f"{conf.model.family} ({conf.model.n_layer}L)",
        marker='o', markersize=4, color=curve_color)
ax.fill_between(positions, mean_loss - std_loss, mean_loss + std_loss,
                alpha=0.2, color=curve_color)
ax.axhline(n_dims, ls="--", color="gray", lw=2, label="Zero estimator baseline")

ax.set_xlabel("Number of in-context examples", fontsize=13, fontweight='bold')
ax.set_ylabel("Squared Error", fontsize=13, fontweight='bold')
ax.set_title(f"In-Context Learning Performance: {task_name}", fontsize=15, fontweight='bold')
ax.legend(fontsize=11, loc='upper right')
ax.grid(True, alpha=0.3, linestyle='--')
fig.tight_layout()
plt.show()

# Summary statistics, emitted as one block of lines.
divider = "=" * 60
first_point_loss, last_point_loss = mean_loss[0], mean_loss[-1]
report_lines = [
    "\n" + divider,
    "Learning Curve Statistics:",
    divider,
    f"Baseline (zero estimator): {n_dims:.4f}",
    f"Initial loss (1 example): {first_point_loss:.4f}",
    f"Final loss ({n_points} examples): {last_point_loss:.4f}",
    f"Improvement over baseline: {(1 - last_point_loss / n_dims) * 100:.1f}%",
    f"Total improvement (first to last): {(1 - last_point_loss / first_point_loss) * 100:.1f}%",
    divider,
]
print("\n".join(report_lines))
print(f"{'='*60}")


## 8. Comprehensive Evaluation with Baselines

Run full evaluation suite including baseline comparisons. This may take a few minutes.


In [None]:
print("Running comprehensive evaluation (this may take a few minutes)...")
print("\nNote: This will compute metrics for:")
print("  - Your trained model")
print("  - Baseline methods (OLS, k-NN, Averaging)")
print("  - Multiple test scenarios (standard, random quadrants, orthogonal, scaling, etc.)")

# Evaluate the final checkpoint (step=-1) on every scenario, baselines
# included. With cache=True the results are stored in <run_path>/metrics.json.
metrics_path = os.path.join(run_path, 'metrics.json')
all_metrics = get_run_metrics(run_path, step=-1, cache=True, skip_model_load=False, skip_baselines=False)

print("\n✓ Comprehensive evaluation complete!")
print(f"  Results cached in: {metrics_path}")


## 9. Compare with Baseline Methods

Visualize how your model compares to traditional baseline methods.


In [None]:
# Compare our model's learning curve on the standard (in-distribution)
# evaluation with the classical baselines computed by get_run_metrics.
standard_metrics = all_metrics.get("standard", {})

# Metrics are keyed by model name. This is defined up front so it exists even
# when no standard results were found.
model_name = model.name

if not standard_metrics:
    print("No standard metrics found!")
else:
    fig, ax = plt.subplots(figsize=(14, 7))

    # Our model, with a ±1 std band. Plot labels are plain text: matplotlib's
    # default fonts have no emoji glyphs, so emoji would render as empty boxes
    # and trigger "Glyph missing" warnings.
    if model_name in standard_metrics:
        means = np.array(standard_metrics[model_name]["mean"])
        stds = np.array(standard_metrics[model_name]["std"])

        ax.plot(means, lw=3, label=f"{conf.model.family.upper()} (Ours)",
                marker='o', markersize=6, color='#E63946', zorder=10)
        ax.fill_between(range(len(means)),
                        means - stds,
                        means + stds,
                        alpha=0.2, color='#E63946')

    # Baselines to draw: metrics key -> (display label, color).
    baseline_styles = {
        'OLS_driver=None': ('Least Squares', '#457B9D'),
        'averaging': ('Averaging', '#F4A261'),
        'NN_n=3_uniform': ('3-Nearest Neighbors', '#2A9D8F'),
    }

    for baseline_key, (label, color) in baseline_styles.items():
        if baseline_key in standard_metrics:
            ax.plot(standard_metrics[baseline_key]["mean"], lw=2, label=label,
                    linestyle='--', marker='s', markersize=4,
                    alpha=0.8, color=color)

    # Formatting
    ax.set_xlabel("Number of in-context examples", fontsize=13, fontweight='bold')
    ax.set_ylabel("Squared Error", fontsize=13, fontweight='bold')
    ax.set_title("Model vs Baseline Methods: Standard Evaluation", fontsize=15, fontweight='bold')
    ax.legend(fontsize=11, loc='upper right', framealpha=0.95)
    ax.grid(True, alpha=0.3, linestyle='--')
    ax.set_ylim(bottom=0)

    plt.tight_layout()
    plt.show()

    # Print comparison table
    print(f"\n{'='*80}")
    print("PERFORMANCE COMPARISON (Final Point Loss)")
    print(f"{'='*80}")
    print(f"{'Method':<40} {'Final Loss':<15} {'vs Baseline':<15}")
    print(f"{'-'*80}")

    # Zero-estimator reference error, taken to be n_dims.
    # NOTE(review): assumes standard Gaussian linear regression — confirm for other tasks.
    baseline_loss = n_dims

    for method_name, method_metrics in standard_metrics.items():
        final_loss = method_metrics["mean"][-1]
        improvement = (1 - final_loss / baseline_loss) * 100

        # Human-readable method name
        if method_name == model_name:
            display_name = f"🚀 {conf.model.family.upper()} (Your Model)"
        else:
            display_name = f"📊 {baseline_names(method_name)}"

        print(f"{display_name:<40} {final_loss:<15.4f} {improvement:>+12.1f}%")

    print(f"{'-'*80}")
    print(f"{'Baseline (Zero estimator)':<40} {baseline_loss:<15.4f} {'0.0%':>15}")
    print(f"{'='*80}")


## 10. Out-of-Distribution Robustness Tests

Evaluate model performance on various out-of-distribution scenarios.


In [None]:
# Identify OOD test scenarios from the evaluation results
ood_scenarios = [
    'random_quadrants',
    'orthogonal_train_test',
    'overlapping_train_test',
    'half_subspace',
    'skewed'
]

# Collect OOD results for the model
ood_results = {}
model_name = model.name

for scenario in ood_scenarios:
    if scenario in all_metrics and model_name in all_metrics[scenario]:
        ood_results[scenario] = all_metrics[scenario][model_name]["mean"][-1]

# Also get standard performance for comparison
if "standard" in all_metrics and model_name in all_metrics["standard"]:
    standard_loss = all_metrics["standard"][model_name]["mean"][-1]
else:
    standard_loss = None

# Create visualization
if ood_results:
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
    
    # Bar plot of OOD performance
    scenarios = list(ood_results.keys())
    losses = list(ood_results.values())
    
    colors = ['#E63946' if l < standard_loss else '#F4A261' for l in losses]
    
    ax1.barh(scenarios, losses, color=colors, alpha=0.8, edgecolor='black', linewidth=1.5)
    if standard_loss:
        ax1.axvline(standard_loss, color='#457B9D', linestyle='--', linewidth=2, 
                   label=f'Standard (in-distribution): {standard_loss:.3f}')
    ax1.set_xlabel('Final Point Squared Error', fontsize=12, fontweight='bold')
    ax1.set_title('Out-of-Distribution Robustness', fontsize=14, fontweight='bold')
    ax1.legend(fontsize=10)
    ax1.grid(True, alpha=0.3, axis='x')
    
    # Relative performance table
    ax2.axis('off')
    
    table_data = []
    table_data.append(['Scenario', 'Final Loss', 'vs Standard'])
    table_data.append(['-'*30, '-'*12, '-'*12])
    
    if standard_loss:
        table_data.append(['Standard (in-dist)', f'{standard_loss:.4f}', '‚Äî'])
    
    for scenario, loss in ood_results.items():
        if standard_loss:
            relative = ((loss / standard_loss - 1) * 100)
            rel_str = f'{relative:+.1f}%'
        else:
            rel_str = '‚Äî'
        
        # Clean scenario name
        clean_name = scenario.replace('_', ' ').title()
        table_data.append([clean_name, f'{loss:.4f}', rel_str])
    
    # Draw table
    table = ax2.table(cellText=table_data, 
                     cellLoc='left',
                     loc='center',
                     colWidths=[0.5, 0.25, 0.25])
    
    table.auto_set_font_size(False)
    table.set_fontsize(10)
    table.scale(1, 2)
    
    # Style header
    for i in range(3):
        table[(0, i)].set_facecolor('#457B9D')
        table[(0, i)].set_text_props(weight='bold', color='white')
    
    # Style standard row
    if standard_loss:
        for i in range(3):
            table[(2, i)].set_facecolor('#E8F4F8')
            table[(2, i)].set_text_props(weight='bold')
    
    ax2.set_title('Performance Summary', fontsize=14, fontweight='bold', pad=20)
    
    plt.tight_layout()
    plt.show()
    
    # Print detailed analysis
    print(f"\n{'='*80}")
    print("OUT-OF-DISTRIBUTION ROBUSTNESS ANALYSIS")
    print(f"{'='*80}")
    print(f"\n‚úì Model shows ", end="")
    
    if standard_loss:
        avg_degradation = np.mean([abs(l - standard_loss) / standard_loss * 100 for l in losses])
        if avg_degradation < 10:
            print(f"EXCELLENT robustness (avg degradation: {avg_degradation:.1f}%)")
        elif avg_degradation < 25:
            print(f"GOOD robustness (avg degradation: {avg_degradation:.1f}%)")
        else:
            print(f"MODERATE robustness (avg degradation: {avg_degradation:.1f}%)")
    else:
        print("results on OOD scenarios")
    
    print(f"\n{'='*80}")
    
else:
    print("No OOD evaluation results found.")


## 11. Input/Output Scaling Tests

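Test model robustness to different input (x) and output (y) scales. Note that rescaling the inputs or targets by a factor s also rescales the squared error (roughly by s²), so raw losses here are not directly comparable to the standard setting.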


In [None]:
# Final-point loss of our model under each input (x) and output (y) scale
# factor, read from the "scale-{dim}={scale}" entries of the evaluation suite.
# NOTE(review): presumably the targets are rescaled by the same factor s, in
# which case squared error grows roughly as s**2 even for a scale-equivariant
# predictor, so "vs Standard" mixes genuine degradation with that rescaling.
scales = [0.333, 0.5, 2, 3]

# Recomputed here so this cell does not depend on the OOD cell having run.
model_name = model.name
if "standard" in all_metrics and model_name in all_metrics["standard"]:
    standard_loss = all_metrics["standard"][model_name]["mean"][-1]
else:
    standard_loss = None
# Positive reference for relative comparisons, or None when unavailable.
reference_loss = standard_loss if standard_loss else None

scaling_tests = {}
for dim in ['x', 'y']:
    scaling_tests[dim] = {}
    for scale in scales:
        key = f"scale-{dim}={scale}"
        if key in all_metrics and model_name in all_metrics[key]:
            scaling_tests[dim][scale] = all_metrics[key][model_name]["mean"][-1]


def _plot_scaling_curve(ax, results, marker, color, axis_name, title, reference):
    """Plot final-point loss against scale factor (log x-axis) on ``ax``.

    ``results`` maps scale factor -> loss. The in-distribution reference line
    is drawn only when ``reference`` is not None.
    """
    scale_factors = sorted(results)
    ax.plot(scale_factors, [results[s] for s in scale_factors], marker=marker,
            markersize=10, linewidth=2.5, color=color, label='Model performance')
    if reference is not None:
        ax.axhline(reference, color='#457B9D', linestyle='--',
                   linewidth=2, label=f'Standard: {reference:.3f}')
    # Scale factor 1 corresponds to the training distribution.
    ax.axvline(1.0, color='gray', linestyle=':', linewidth=1.5, alpha=0.5)

    ax.set_xlabel(f'{axis_name} Scale Factor', fontsize=12, fontweight='bold')
    ax.set_ylabel('Final Point Squared Error', fontsize=12, fontweight='bold')
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xscale('log')
    ax.grid(True, alpha=0.3)
    ax.legend(fontsize=10)


def _print_scaling_table(title, results, reference):
    """Print the loss per scale factor and its relative change vs ``reference``."""
    print(f"\n{title}:")
    # Header columns line up with the rows (2-space indent, 12-wide columns).
    print(f"  {'Scale':<12} {'Loss':<12} {'vs Standard':>13}")
    print(f"  {'-'*40}")
    for scale in sorted(results):
        scale_loss = results[scale]
        if reference is not None:
            diff_str = f"{(scale_loss / reference - 1) * 100:>+12.1f}%"
        else:
            diff_str = f"{'n/a':>13}"
        print(f"  {scale:<12.3f} {scale_loss:<12.4f} {diff_str}")


if scaling_tests['x'] or scaling_tests['y']:
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

    if scaling_tests['x']:
        _plot_scaling_curve(ax1, scaling_tests['x'], 'o', '#E63946',
                            'Input (X)', 'Input Scaling Robustness', reference_loss)
    if scaling_tests['y']:
        _plot_scaling_curve(ax2, scaling_tests['y'], 's', '#F4A261',
                            'Output (Y)', 'Output Scaling Robustness', reference_loss)

    plt.tight_layout()
    plt.show()

    # Print scaling analysis
    print(f"\n{'='*80}")
    print("SCALING ROBUSTNESS ANALYSIS")
    print(f"{'='*80}")

    if scaling_tests['x']:
        _print_scaling_table("Input (X) Scaling", scaling_tests['x'], reference_loss)
    if scaling_tests['y']:
        _print_scaling_table("Output (Y) Scaling", scaling_tests['y'], reference_loss)

    print(f"\n{'='*80}")
else:
    print("No scaling test results found.")


## 12. Complete Evaluation Summary

Generate a comprehensive summary of all evaluation results.


In [None]:
# Generate a comprehensive summary of the results computed by earlier cells.
# Expects `all_metrics`, `ood_results`, `standard_loss` and `scaling_tests` to
# already exist in the notebook namespace.
model_name = model.name

print(f"\n{'='*80}")
print(f"{'='*80}")
print("COMPLETE EVALUATION SUMMARY")
print(f"{'='*80}")
print(f"{'='*80}")

print("\n📋 MODEL INFORMATION:")
print(f"  Run ID: {run_id}")
print(f"  Model Family: {conf.model.family}")
print(f"  Architecture: {conf.model.n_layer} layers, {conf.model.n_head} heads, {conf.model.n_embd} embedding dim")
print(f"  Task: {conf.training.task}")
print(f"  Training dimensions: {conf.model.n_dims}")
print(f"  Context length: {conf.model.n_positions}")

print("\n📊 PERFORMANCE METRICS:")
if "standard" in all_metrics and model_name in all_metrics["standard"]:
    final_loss = all_metrics["standard"][model_name]["mean"][-1]
    # Relative to the zero estimator, whose error is taken to be n_dims.
    improvement = (1 - final_loss / n_dims) * 100
    print(f"  Standard (in-distribution) loss: {final_loss:.4f}")
    print(f"  Improvement over baseline: {improvement:.1f}%")

print("\n🎯 BASELINE COMPARISONS:")
standard_results = all_metrics.get("standard", {})
if model_name in standard_results:
    model_loss = standard_results[model_name]["mean"][-1]
    for baseline_key in ['OLS_driver=None', 'averaging', 'NN_n=3_uniform']:
        if baseline_key not in standard_results:
            continue
        baseline_loss = standard_results[baseline_key]["mean"][-1]
        baseline_name = baseline_names(baseline_key)
        if baseline_loss <= 0:
            # A zero-loss baseline leaves the relative advantage undefined.
            print(f"  ⚠️ vs {baseline_name}: baseline loss is 0, relative comparison undefined")
            continue
        # Positive advantage: our final-point loss is below the baseline's.
        advantage = (baseline_loss - model_loss) / baseline_loss * 100
        symbol, verdict = ("✅", "better") if advantage > 0 else ("⚠️", "worse")
        print(f"  {symbol} vs {baseline_name}: {abs(advantage):.1f}% {verdict}")

print("\n🌐 OUT-OF-DISTRIBUTION ROBUSTNESS:")
if ood_results:
    for scenario, scenario_loss in ood_results.items():
        clean_name = scenario.replace('_', ' ').title()
        if standard_loss:
            degradation = (scenario_loss / standard_loss - 1) * 100
            # Only an increase over the in-distribution loss counts against the
            # model; scenarios easier than standard are marked OK.
            if degradation < 15:
                status = "✅"
            elif degradation < 30:
                status = "⚠️"
            else:
                status = "❌"
            print(f"  {status} {clean_name}: {scenario_loss:.4f} ({degradation:+.1f}%)")
        else:
            print(f"  • {clean_name}: {scenario_loss:.4f}")

print("\n⚖️  SCALING ROBUSTNESS:")
if scaling_tests['x'] or scaling_tests['y']:
    for dim, dim_label in (('x', 'Input (X)'), ('y', 'Output (Y)')):
        if not scaling_tests[dim]:
            continue
        if standard_loss:
            # Largest relative increase over the in-distribution loss (0 if none).
            max_deg = max(
                max(0.0, (scale_loss / standard_loss - 1) * 100)
                for scale_loss in scaling_tests[dim].values()
            )
        else:
            max_deg = 0
        print(f"  {dim_label} scaling: max degradation {max_deg:.1f}%")

print("\n💾 SAVED RESULTS:")
print(f"  Metrics file: {os.path.join(run_path, 'metrics.json')}")
print(f"  Model checkpoint: {os.path.join(run_path, 'state.pt')}")
print(f"  Configuration: {os.path.join(run_path, 'config.yaml')}")

print(f"\n{'='*80}")
print("✅ EVALUATION COMPLETE!")
print(f"{'='*80}")
print("\nAll results have been cached and can be reloaded without recomputing.")
print(f"To share results, download the metrics.json file from: {run_path}")


## 13. Export/Download Results (Optional)

Download evaluation results to your local machine.


In [None]:
# Uncomment to download files when running on Colab
# from google.colab import files

# Download metrics
# metrics_file = os.path.join(run_path, "metrics.json")
# if os.path.exists(metrics_file):
#     print(f"Downloading {metrics_file}...")
#     files.download(metrics_file)
#     print("✓ Metrics downloaded")

# Download config
# config_file = os.path.join(run_path, "config.yaml")
# if os.path.exists(config_file):
#     print(f"Downloading {config_file}...")
#     files.download(config_file)
#     print("✓ Config downloaded")

# Write a plain-text summary next to the run's other artifacts. The encoding is
# explicit so the file does not depend on the platform's default encoding.
summary_file = os.path.join(run_path, "evaluation_summary.txt")
model_name = model.name
with open(summary_file, 'w', encoding='utf-8') as f:
    f.write("="*80 + "\n")
    f.write("EVALUATION SUMMARY\n")
    f.write("="*80 + "\n\n")

    f.write(f"Model: {conf.model.family}\n")
    f.write(f"Task: {conf.training.task}\n")
    f.write(f"Run ID: {run_id}\n\n")

    if "standard" in all_metrics and model_name in all_metrics["standard"]:
        final_loss = all_metrics["standard"][model_name]["mean"][-1]
        f.write(f"Standard Loss: {final_loss:.4f}\n")
        f.write(f"Improvement: {(1 - final_loss / n_dims) * 100:.1f}%\n\n")

    # Final-point loss of our model in every scenario that includes it.
    # `scenario_metrics` (not `loss`) keeps the quick-test `loss` array intact.
    f.write("Evaluation Scenarios:\n")
    for scenario, scenario_metrics in all_metrics.items():
        if model_name in scenario_metrics:
            f.write(f"  - {scenario}: {scenario_metrics[model_name]['mean'][-1]:.4f}\n")

print(f"\n✓ Summary saved to: {summary_file}")

# Uncomment to download summary
# files.download(summary_file)

print("\n💡 Tip: Run this notebook anytime to re-evaluate or visualize cached results!")
