# Test Model on Multiple Datasets

This notebook evaluates a trained model on three test datasets:
- **Natural**: Uniformly random samples from [1, 10^13]
- **Cheat**: Numbers with prime factors only within the first 100 primes
- **Non-cheat**: Numbers with at least one prime factor outside the first 100 primes

It reports overall and per-class performance for each dataset.

In [None]:
import sys
import os
import subprocess
import json
import re
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Use seaborn's "whitegrid" theme for every plot in this notebook
sns.set_style("whitegrid")
# Default figure size for figures that do not pass their own figsize
plt.rcParams['figure.figsize'] = (14, 8)

## Configuration

Specify the model checkpoint, encoding, and task to test:

In [None]:
# Model configuration
# ENCODING and TASK together determine the test-file names and the results
# sub-directory; ENCODING also selects the --data_types value used for evaluation.
ENCODING = "CRT100_with_stats"  # Options: interCRT100, CRT100, interCRT100_with_n, CRT100_with_stats
TASK = "mu"  # Options: mu, musq
MODEL_DIR = "models/train_CRT100stat"  # Path to the directory containing the trained model
CHECKPOINT_NAME = "best-valid_arithmetic_acc"  # Name of checkpoint file (without .pth)

# Root directory under which the per-dataset test input directories are looked up
INPUT_BASE_DIR = Path("input")

# Results output directory (one sub-directory per encoding/task pair);
# it is created immediately when this cell runs
RESULTS_DIR = Path("test_results") / f"{ENCODING}_{TASK}"
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

# Test dataset types; each name is used as a suffix in the input directory and file names
DATASET_TYPES = ["natural", "cheat", "non_cheat"]

# Int2Int directory; train.py is expected directly inside it
INT2INT_DIR = Path("Int2Int")

# Echo the active configuration so the cell output records what was tested
print(f"Testing model: {MODEL_DIR}/{CHECKPOINT_NAME}.pth")
print(f"Encoding: {ENCODING}")
print(f"Task: {TASK}")
print(f"Results will be saved to: {RESULTS_DIR}")

## Helper Functions

In [None]:
def get_encoding_params(encoding):
    """
    Return the data_types string that belongs to an encoding name.

    The known encodings differ only in the vector length. Any name that is
    not recognised falls back to the 200-element specification.
    """
    vector_lengths = {
        "interCRT100": 200,
        "CRT100": 100,
        "interCRT100_with_n": 201,
        "CRT100_with_stats": 103,
    }
    length = vector_lengths.get(encoding, 200)
    return f"int[{length}]:range(-1,2)"


def get_test_data_path(encoding, dataset_type, task):
    """
    Build the expected location of the test file for one dataset.

    The result has the form
    INPUT_BASE_DIR / input_dir_<encoding>_<dataset_type> /
    <task>_<encoding>_<dataset_type>.txt.test
    """
    # Every task, "mu" and "musq" included, uses the same "<task>_<encoding>" prefix
    file_name = f"{task}_{encoding}_{dataset_type}.txt.test"
    directory_name = f"input_dir_{encoding}_{dataset_type}"
    return INPUT_BASE_DIR / directory_name / file_name


def parse_log_output(log_text):
    """
    Pull the metrics dictionary out of raw evaluation output.

    Searches for the first "__log__:{...}" entry and decodes its JSON
    payload. Returns the decoded object, or None when no entry is present
    or the payload is not valid JSON (an error message is printed in the
    latter case).
    """
    found = re.search(r'__log__:({.+})', log_text)
    if found is None:
        return None

    payload = found.group(1)
    try:
        metrics = json.loads(payload)
    except json.JSONDecodeError as e:
        print(f"Error parsing JSON: {e}")
        return None
    return metrics


def interpret_class_id(class_id, task):
    """
    Translate a class ID string into a readable label for the given task.

    IDs without a dedicated label, and every ID of an unrecognised task,
    are rendered as "Class <id>".
    """
    known_labels = ()
    if task == 'mu':
        known_labels = (
            ('0', 'μ(n) = 0'),
            ('1', 'μ(n) = 1'),
            ('100', 'μ(n) = -1'),
        )
    elif task == 'musq':
        known_labels = (
            ('0', 'μ²(n) = 0'),
            ('1', 'μ²(n) = 1'),
        )

    for known_id, label in known_labels:
        if class_id == known_id:
            return label
    return f'Class {class_id}'


# Confirms that this cell ran, i.e. the helper functions above are now defined
print("Helper functions loaded!")

## Check if Model Checkpoint Exists

In [None]:
# Resolve the checkpoint location; later cells reuse checkpoint_path
model_dir_path = Path(MODEL_DIR)
checkpoint_path = model_dir_path / f"{CHECKPOINT_NAME}.pth"

if not checkpoint_path.exists():
    print(f"✗ Checkpoint NOT found: {checkpoint_path}")

    # Show what is in the model directory to help spot a misnamed checkpoint
    print(f"\nAvailable files in {MODEL_DIR}:")
    if not model_dir_path.exists():
        print(f"  Directory {MODEL_DIR} does not exist!")
    else:
        for entry in model_dir_path.iterdir():
            print(f"  - {entry.name}")

    # Hints on how a checkpoint gets produced
    rule = "=" * 80
    print("\n" + rule)
    print("NOTE: If no checkpoint exists, you need to:")
    print("1. Train a model with --save_periodic flag, OR")
    print("2. The model will auto-save 'best-valid_arithmetic_acc.pth' during training")
    print("3. Check the train.log to see if 'Saving best' messages appear")
    print(rule)
else:
    print(f"✓ Checkpoint found: {checkpoint_path}")

## Check Test Data Files

In [None]:
print("Checking test data files:\n")

# Expected test-file path for every dataset type; later cells read test_files
test_files = {
    name: get_test_data_path(ENCODING, name, TASK) for name in DATASET_TYPES
}

for dataset_type in DATASET_TYPES:
    test_file = test_files[dataset_type]
    if not test_file.exists():
        print(f"✗ {dataset_type:12s}: {test_file} NOT FOUND")
        continue

    file_size = test_file.stat().st_size / (1024 * 1024)  # bytes -> MB
    print(f"✓ {dataset_type:12s}: {test_file} ({file_size:.2f} MB)")

# Reminder of how the test files are generated
separator = "=" * 80
print("\n" + separator)
print("NOTE: If test files don't exist, generate them using:")
print(f"  make -C src/run_int2int_scripts data_all ENCODING={ENCODING}")
print(separator)

## Run Evaluation on Each Test Dataset

In [None]:
def run_evaluation(model_dir, checkpoint_name, test_data_path, encoding, task, output_file):
    """
    Run evaluation using the Int2Int training script with --eval_only flag.

    Args:
        model_dir: Directory of the trained experiment. It is passed as
            --eval_from_exp and used to locate the checkpoint file.
        checkpoint_name: Checkpoint file name without the ".pth" suffix.
        test_data_path: Path object pointing at the test data (--eval_data).
        encoding: Encoding name used to select the --data_types value.
        task: Task name. Not used by this function; kept so that existing
            callers keep working.
        output_file: File that receives the captured stdout and stderr.

    Returns:
        A tuple (metrics, result). metrics is the parsed "__log__"
        dictionary, looked up in stdout first and stderr second, or None if
        neither contains one. result is the subprocess.CompletedProcess.
        (None, None) is returned if the run times out or any other error
        occurs while running or recording the evaluation.
    """
    import shlex  # local import: only needed to render the command for display

    data_types = get_encoding_params(encoding)
    checkpoint_path = Path(model_dir) / f"{checkpoint_name}.pth"

    # Launch the interpreter that runs this notebook so the subprocess sees
    # the same environment; a bare "python" may resolve to another install.
    python_executable = sys.executable or "python"

    # Construct the command
    cmd = [
        python_executable, str(INT2INT_DIR / "train.py"),
        "--eval_only", "True",
        "--eval_from_exp", str(Path(model_dir).absolute()),
        "--reload_checkpoint", str(checkpoint_path.absolute()),
        "--eval_data", str(test_data_path.absolute()),
        "--eval_size", "10000",
        "--data_types", data_types,
        "--operation", "data",
        "--cpu", "True",  # Run on CPU for notebook compatibility
        "--num_workers", "0",
    ]

    # Quote each argument: data_types contains brackets and parentheses, so
    # the unquoted command could not be pasted into a shell to reproduce a run.
    print(f"\nRunning command:")
    print(" ".join(shlex.quote(part) for part in cmd))
    print("\n" + "="*80)

    # Run the command
    try:
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            timeout=600  # 10 minute timeout
        )

        # Save full output to file. UTF-8 is set explicitly so that non-ASCII
        # characters in the output cannot fail under a narrow platform default
        # encoding, which would otherwise discard the metrics below.
        with open(output_file, 'w', encoding='utf-8') as f:
            f.write("STDOUT:\n")
            f.write(result.stdout)
            f.write("\n\nSTDERR:\n")
            f.write(result.stderr)

        print(f"Full output saved to: {output_file}")

        # Parse metrics from output; fall back to stderr when stdout has none
        metrics = parse_log_output(result.stdout)
        if metrics is None:
            metrics = parse_log_output(result.stderr)

        # A non-zero exit code is reported but does not discard parsed metrics
        if result.returncode != 0:
            print(f"\n⚠ Warning: Command exited with code {result.returncode}")
            print("Last 50 lines of stderr:")
            print("\n".join(result.stderr.split("\n")[-50:]))

        return metrics, result

    except subprocess.TimeoutExpired:
        print("\n✗ Evaluation timed out after 10 minutes")
        return None, None
    except Exception as e:
        # Deliberately broad: the notebook should report the problem and
        # continue with the remaining cells instead of stopping here.
        print(f"\n✗ Error running evaluation: {e}")
        return None, None


# Confirms that this cell ran, i.e. run_evaluation is now defined
print("Evaluation function ready!")

## Evaluate on Natural Dataset

In [None]:
# Evaluate the checkpoint on the "natural" test set; the result is stored in
# natural_metrics (None when the evaluation was skipped or failed).
dataset_type = "natural"
test_file = test_files[dataset_type]
output_file = RESULTS_DIR / f"eval_{dataset_type}.log"

if not test_file.exists():
    print(f"✗ Test file not found: {test_file}")
    natural_metrics = None
elif not checkpoint_path.exists():
    print(f"✗ Checkpoint not found: {checkpoint_path}")
    natural_metrics = None
else:
    print(f"\n{'='*80}")
    print("EVALUATING ON NATURAL DATASET")
    print(f"{'='*80}")
    natural_metrics, result = run_evaluation(
        MODEL_DIR, CHECKPOINT_NAME, test_file, ENCODING, TASK, output_file
    )

    if natural_metrics:
        print("\n✓ Evaluation completed successfully!")
        # The accuracy key may be missing from the parsed metrics. Only a real
        # number is formatted with ':.2f'; applying that format to a string
        # placeholder would raise ValueError.
        overall_acc = natural_metrics.get('valid_arithmetic_acc')
        if isinstance(overall_acc, (int, float)):
            print(f"\nOverall Accuracy: {overall_acc:.2f}%")
        else:
            print("\nOverall Accuracy: N/A")
    else:
        print("\n✗ Failed to parse metrics from evaluation output")

## Evaluate on Cheat Dataset

In [None]:
# Evaluate the checkpoint on the "cheat" test set; the result is stored in
# cheat_metrics (None when the evaluation was skipped or failed).
dataset_type = "cheat"
test_file = test_files[dataset_type]
output_file = RESULTS_DIR / f"eval_{dataset_type}.log"

if not test_file.exists():
    print(f"✗ Test file not found: {test_file}")
    cheat_metrics = None
elif not checkpoint_path.exists():
    print(f"✗ Checkpoint not found: {checkpoint_path}")
    cheat_metrics = None
else:
    print(f"\n{'='*80}")
    print("EVALUATING ON CHEAT DATASET")
    print(f"{'='*80}")
    cheat_metrics, result = run_evaluation(
        MODEL_DIR, CHECKPOINT_NAME, test_file, ENCODING, TASK, output_file
    )

    if cheat_metrics:
        print("\n✓ Evaluation completed successfully!")
        # The accuracy key may be missing from the parsed metrics. Only a real
        # number is formatted with ':.2f'; applying that format to a string
        # placeholder would raise ValueError.
        overall_acc = cheat_metrics.get('valid_arithmetic_acc')
        if isinstance(overall_acc, (int, float)):
            print(f"\nOverall Accuracy: {overall_acc:.2f}%")
        else:
            print("\nOverall Accuracy: N/A")
    else:
        print("\n✗ Failed to parse metrics from evaluation output")

## Evaluate on Non-Cheat Dataset

In [None]:
# Evaluate the checkpoint on the "non_cheat" test set; the result is stored in
# non_cheat_metrics (None when the evaluation was skipped or failed).
dataset_type = "non_cheat"
test_file = test_files[dataset_type]
output_file = RESULTS_DIR / f"eval_{dataset_type}.log"

if not test_file.exists():
    print(f"✗ Test file not found: {test_file}")
    non_cheat_metrics = None
elif not checkpoint_path.exists():
    print(f"✗ Checkpoint not found: {checkpoint_path}")
    non_cheat_metrics = None
else:
    print(f"\n{'='*80}")
    print("EVALUATING ON NON-CHEAT DATASET")
    print(f"{'='*80}")
    non_cheat_metrics, result = run_evaluation(
        MODEL_DIR, CHECKPOINT_NAME, test_file, ENCODING, TASK, output_file
    )

    if non_cheat_metrics:
        print("\n✓ Evaluation completed successfully!")
        # The accuracy key may be missing from the parsed metrics. Only a real
        # number is formatted with ':.2f'; applying that format to a string
        # placeholder would raise ValueError.
        overall_acc = non_cheat_metrics.get('valid_arithmetic_acc')
        if isinstance(overall_acc, (int, float)):
            print(f"\nOverall Accuracy: {overall_acc:.2f}%")
        else:
            print("\nOverall Accuracy: N/A")
    else:
        print("\n✗ Failed to parse metrics from evaluation output")

## Compile Results

In [None]:
# Pair each dataset name with whatever its evaluation cell produced
collected = [
    ('natural', natural_metrics),
    ('cheat', cheat_metrics),
    ('non_cheat', non_cheat_metrics),
]

# Keep only the datasets whose evaluation actually yielded metrics
all_metrics = {name: found for name, found in collected if found is not None}

if all_metrics:
    print(f"\n✓ Successfully evaluated on {len(all_metrics)} dataset(s)")

    # Write the combined metrics of all datasets to a single JSON file
    results_json = RESULTS_DIR / "all_results.json"
    with open(results_json, 'w') as handle:
        json.dump(all_metrics, handle, indent=2)
    print(f"\nResults saved to: {results_json}")
else:
    print("\n✗ No evaluation results available. Please check:")
    print("  1. Model checkpoint exists")
    print("  2. Test data files exist")
    print("  3. Evaluation ran without errors")

## Overall Performance Comparison

In [None]:
if not all_metrics:
    print("No metrics to display")
else:
    banner = "=" * 80
    print(banner)
    print("OVERALL PERFORMANCE SUMMARY")
    print(banner)

    # One table row per evaluated dataset; metrics that are absent show as 0
    summary_data = [
        {
            'Dataset': dataset_type.replace('_', '-').title(),
            'Accuracy (%)': metrics.get('valid_arithmetic_acc', 0),
            'Perfect (%)': metrics.get('valid_arithmetic_perfect', 0),
            'Correct (%)': metrics.get('valid_arithmetic_correct', 0),
            'XE Loss': metrics.get('valid_arithmetic_xe_loss', 0),
        }
        for dataset_type, metrics in all_metrics.items()
    ]

    summary_df = pd.DataFrame(summary_data)
    display(summary_df)

    # Persist the table next to the other results
    summary_csv = RESULTS_DIR / "overall_summary.csv"
    summary_df.to_csv(summary_csv, index=False)
    print(f"\nSummary saved to: {summary_csv}")

## Per-Class Performance Analysis

In [None]:
def extract_class_metrics(metrics, task):
    """
    Build the per-class accuracy rows contained in a metrics dictionary.

    Keys starting with "valid_arithmetic_acc_" are treated as per-class
    entries; the class ID is the text after the last underscore. Entries
    whose ID is d1, d2 or d3 are skipped. Rows are ordered by numeric class
    ID, with non-numeric IDs placed last.
    """
    prefix = 'valid_arithmetic_acc_'
    skipped_ids = ('d1', 'd2', 'd3')

    def sort_rank(row):
        ident = row['Class ID']
        return int(ident) if ident.isdigit() else float('inf')

    rows = []
    for name, accuracy in metrics.items():
        if not name.startswith(prefix):
            continue
        ident = name.rsplit('_', 1)[-1]
        if ident in skipped_ids:
            continue
        rows.append({
            'Class ID': ident,
            'Class': interpret_class_id(ident, task),
            'Accuracy (%)': accuracy,
        })

    return sorted(rows, key=sort_rank)


if not all_metrics:
    print("No metrics to analyze")
else:
    print("="*80)
    print("PER-CLASS PERFORMANCE")
    print("="*80)

    # One table (and one CSV file) per evaluated dataset
    for dataset_type, metrics in all_metrics.items():
        heading = dataset_type.upper().replace('_', '-')
        print(f"\n{heading} DATASET:")
        print("-" * 80)

        class_data = extract_class_metrics(metrics, TASK)
        if not class_data:
            print("No per-class metrics found")
            continue

        class_df = pd.DataFrame(class_data)
        display(class_df)

        # Save to CSV
        class_csv = RESULTS_DIR / f"class_performance_{dataset_type}.csv"
        class_df.to_csv(class_csv, index=False)
        print(f"Saved to: {class_csv}")

## Visualization: Overall Performance Comparison

In [None]:
if all_metrics:
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))

    # Values shared by both panels
    datasets = [name.replace('_', '-').title() for name in all_metrics]
    colors = ['#2ecc71', '#e74c3c', '#3498db']
    bar_colors = colors[:len(datasets)]
    accuracies = [m.get('valid_arithmetic_acc', 0) for m in all_metrics.values()]
    losses = [m.get('valid_arithmetic_xe_loss', 0) for m in all_metrics.values()]

    # (axis, bar heights, value-label template, y-axis label, title, y-limits)
    panels = [
        (axes[0], accuracies, '{:.2f}%', 'Accuracy (%)',
         'Overall Accuracy Comparison', [0, 105]),
        (axes[1], losses, '{:.4f}', 'Cross-Entropy Loss',
         'Cross-Entropy Loss Comparison', None),
    ]

    for ax, heights, label_template, y_label, heading, y_limits in panels:
        bars = ax.bar(datasets, heights, color=bar_colors, alpha=0.8,
                      edgecolor='black', linewidth=2)

        # Write each bar's value just above it
        for bar, value in zip(bars, heights):
            ax.text(bar.get_x() + bar.get_width()/2., bar.get_height(),
                    label_template.format(value),
                    ha='center', va='bottom', fontweight='bold', fontsize=12)

        ax.set_ylabel(y_label, fontsize=13, fontweight='bold')
        ax.set_title(f'{heading}\n{ENCODING}, {TASK.upper()} Task',
                     fontsize=14, fontweight='bold')
        # Only the accuracy panel gets a fixed y-range
        if y_limits is not None:
            ax.set_ylim(y_limits)
        ax.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()

    # Save figure
    fig_path = RESULTS_DIR / "overall_comparison.png"
    plt.savefig(fig_path, dpi=300, bbox_inches='tight')
    print(f"Figure saved to: {fig_path}")

    plt.show()
else:
    print("No metrics to visualize")

## Visualization: Per-Class Performance Comparison

In [None]:
if all_metrics:
    # Collect per-class accuracies for every dataset. extract_class_metrics is
    # reused so the plot applies the same key filtering as the per-class tables.
    plot_data = []
    for dataset_type, metrics in all_metrics.items():
        for row in extract_class_metrics(metrics, TASK):
            plot_data.append({
                'Class ID': row['Class ID'],
                'Class': row['Class'],
                'Dataset': dataset_type.replace('_', '-').title(),
                'Accuracy': row['Accuracy (%)']
            })

    if plot_data:
        plot_df = pd.DataFrame(plot_data)

        # Create grouped bar chart
        fig, ax = plt.subplots(figsize=(14, 8))

        # Order classes by numeric ID (non-numeric IDs last) so the x-axis does
        # not depend on which dataset happened to report a class first.
        def class_rank(class_id):
            return int(class_id) if class_id.isdigit() else float('inf')

        class_ids = sorted(plot_df['Class ID'].unique(), key=class_rank)
        classes = [interpret_class_id(class_id, TASK) for class_id in class_ids]
        datasets = plot_df['Dataset'].unique()

        x = np.arange(len(classes))
        width = 0.25
        colors = ['#2ecc71', '#e74c3c', '#3498db']

        # Plot bars for each dataset
        for i, dataset in enumerate(datasets):
            dataset_rows = plot_df[plot_df['Dataset'] == dataset]
            acc_by_class = dict(zip(dataset_rows['Class ID'], dataset_rows['Accuracy']))
            # A class that this dataset did not report is drawn with height 0
            values = [acc_by_class.get(class_id, 0) for class_id in class_ids]

            # Centre the group of bars around each class tick
            offset = width * (i - len(datasets)/2 + 0.5)
            bars = ax.bar(x + offset, values, width, label=dataset,
                          color=colors[i % len(colors)], alpha=0.8, edgecolor='black')

            # Add value labels; bars with a value of 0 (which includes classes
            # missing from this dataset) are left unlabeled
            for bar, val in zip(bars, values):
                if val > 0:
                    height = bar.get_height()
                    ax.text(bar.get_x() + bar.get_width()/2., height,
                            f'{val:.1f}',
                            ha='center', va='bottom', fontsize=9, fontweight='bold')

        ax.set_xlabel('Class', fontsize=13, fontweight='bold')
        ax.set_ylabel('Accuracy (%)', fontsize=13, fontweight='bold')
        ax.set_title(f'Per-Class Accuracy Comparison Across Datasets\n{ENCODING}, {TASK.upper()} Task',
                     fontsize=14, fontweight='bold')
        ax.set_xticks(x)
        ax.set_xticklabels(classes, rotation=15, ha='right')
        ax.legend(fontsize=11)
        ax.set_ylim([0, 105])
        ax.grid(True, alpha=0.3, axis='y')

        plt.tight_layout()

        # Save figure
        fig_path = RESULTS_DIR / "per_class_comparison.png"
        plt.savefig(fig_path, dpi=300, bbox_inches='tight')
        print(f"Figure saved to: {fig_path}")

        plt.show()
    else:
        print("No per-class metrics to visualize")
else:
    print("No metrics to visualize")

## Final Summary Report

In [None]:
if not all_metrics:
    print("\nNo results to summarize. Please check the error messages above.")
else:
    rule = "=" * 80

    # Recap of the configuration that was evaluated
    print(rule)
    print("FINAL TEST PERFORMANCE SUMMARY")
    print(rule)
    print(f"\nModel: {MODEL_DIR}")
    print(f"Checkpoint: {CHECKPOINT_NAME}")
    print(f"Encoding: {ENCODING}")
    print(f"Task: {TASK.upper()}")
    print(f"\nResults saved to: {RESULTS_DIR}")
    print("\n" + rule)

    # List everything currently present in the results directory
    print("\nFiles generated:")
    for generated in RESULTS_DIR.iterdir():
        print(f"  - {generated.name}")
    print("\n" + rule)