# Model Performance Analysis

This notebook provides a comprehensive analysis of trained Möbius function models, including:
- Overall accuracy metrics
- Per-class accuracy (for μ(n) = -1, 0, 1)
- Training configuration (data generation, encoding format, model architecture)
- Training curves and convergence analysis

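**Note:** Update the `path`, `xp_env`, and (if needed) `indicator` variables in the configuration cell below to point to your trained models. `path` is the models root directory; it must contain one sub-directory per entry of `xp_env`, each holding one sub-directory per trained run.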

In [None]:
# Configuration
path = "../models/model_CRT100_with_stats"  # Models root; contains one sub-directory per environment in `xp_env`
xp_env = ["mu"]  # Environment sub-directories to scan, e.g. basic=μ(n), musq=μ²(n); "mu" here presumably μ(n) — confirm
indicator = "valid_arithmetic"  # Metric prefix of the logged keys: {indicator}_acc, {indicator}_xe_loss, {indicator}_acc_<class>, ...

In [None]:
import os
import pickle
import matplotlib.pyplot as plt
import numpy as np
import ast
from datetime import datetime
from tabulate import tabulate
import pandas as pd
from collections import defaultdict

# Set plotting style. 'seaborn-v0_8-darkgrid' is the matplotlib >= 3.6 name of
# the former 'seaborn-darkgrid' style; try both so the cell also runs on older
# matplotlib, and keep the default style if neither is available.
for _style_name in ('seaborn-v0_8-darkgrid', 'seaborn-darkgrid'):
    if _style_name in plt.style.available:
        plt.style.use(_style_name)
        break
plt.rcParams['figure.figsize'] = (12, 6)
plt.rcParams['font.size'] = 11

## 1. Load Experiments

In [None]:
# Find all experiments: every sub-directory of <path>/<env> for each env in xp_env.
# os.path.join supplies the separator that plain string concatenation dropped
# (`path` has no trailing slash, so `path + env` pointed at a non-existent dir).
xps = []
for env in xp_env:
    env_dir = os.path.join(path, env)
    if not os.path.isdir(env_dir):
        print(f"Warning: environment directory not found: {env_dir}")
        continue
    # sorted() makes the experiment order deterministic across runs/filesystems.
    xps.extend((env, xp) for xp in sorted(os.listdir(env_dir))
               if os.path.isdir(os.path.join(env_dir, xp)))
names = [os.path.join(path, env, xp) for (env, xp) in xps]

print(f"Found {len(names)} experiments:")
for env, xp in xps:
    print(f"  - {env}/{xp}")

## 2. Parse Training Logs

In [None]:
def parse_training_log(log_path, indicator="valid_arithmetic"):
    """
    Parse a training log and extract per-epoch metrics, including per-class accuracy.

    Only lines containing the ``__log__:`` marker are used. The text after the
    marker is decoded as JSON first (so ``true``/``false``/``null``/``NaN`` are
    accepted) and, failing that, as a Python literal. Records that are not dicts
    or lack a non-negative ``epoch`` are skipped; undecodable lines are reported
    and skipped.

    Args:
        log_path: Path to the training log file.
        indicator: Metric prefix. Reads ``{indicator}_acc``,
            ``{indicator}_xe_loss`` and ``{indicator}_perfect`` (each 0.0 when
            absent) plus every ``{indicator}_acc_<class>`` key.

    Returns:
        None if ``log_path`` does not exist; otherwise a dict with
          - 'epochs', 'overall_acc', 'overall_loss', 'perfect_match': one entry
            per kept record, aligned by index;
          - 'class_acc': {class_name: list aligned with 'epochs'}, holding NaN
            for records that did not report that class, so index i always
            refers to epochs[i];
          - 'best_epoch', 'best_acc', 'best_idx': epoch, accuracy and list
            index of the highest non-NaN overall accuracy (-1 / 0.0 / -1 if none);
          - 'final_epoch', 'final_acc': last record's epoch and accuracy
            (-1 / 0.0 if there are no records).
    """
    import json  # local import keeps this cell self-contained

    if not os.path.exists(log_path):
        return None

    marker = '__log__:'
    class_prefix = f'{indicator}_acc_'

    epochs = []
    overall_acc = []
    overall_loss = []
    perfect_match = []
    record_classes = []  # one {class_name: accuracy} dict per kept record

    with open(log_path, 'r', encoding='utf-8', errors='replace') as f:
        for line_no, line in enumerate(f, start=1):
            pos = line.find(marker)
            if pos < 0:
                continue
            payload = line[pos + len(marker):].strip()
            try:
                try:
                    data = json.loads(payload)
                except ValueError:
                    data = ast.literal_eval(payload)

                if not isinstance(data, dict):
                    continue
                epoch = data.get('epoch', -1)
                if epoch < 0:
                    continue

                # Gather every value before appending anything, so a failure
                # part-way through a record cannot leave the lists misaligned.
                acc = data.get(f'{indicator}_acc', 0.0)
                loss = data.get(f'{indicator}_xe_loss', 0.0)
                perfect = data.get(f'{indicator}_perfect', 0.0)
                # Class name is the key suffix, e.g. "0", "1", "100" (for μ = -1).
                classes = {key[len(class_prefix):]: value
                           for key, value in data.items()
                           if key.startswith(class_prefix)}
            except Exception as e:
                print(f"Error parsing line {line_no} of {log_path}: {e}")
                continue

            epochs.append(epoch)
            overall_acc.append(acc)
            overall_loss.append(loss)
            perfect_match.append(perfect)
            record_classes.append(classes)

    # Build per-class series aligned with `epochs` (NaN where a class is missing),
    # keeping classes in order of first appearance.
    nan = float('nan')
    all_classes = dict.fromkeys(name for rec in record_classes for name in rec)
    class_acc = {name: [rec.get(name, nan) for rec in record_classes]
                 for name in all_classes}

    best_idx, best_epoch, best_acc = -1, -1, 0.0
    if overall_acc:
        acc_values = np.asarray(overall_acc, dtype=float)
        if not np.isnan(acc_values).all():
            # nanargmax ignores NaN entries; plain argmax would return the first NaN.
            best_idx = int(np.nanargmax(acc_values))
            best_epoch = epochs[best_idx]
            best_acc = overall_acc[best_idx]

    return {
        'epochs': epochs,
        'overall_acc': overall_acc,
        'overall_loss': overall_loss,
        'perfect_match': perfect_match,
        'class_acc': class_acc,
        'best_epoch': best_epoch,
        'best_acc': best_acc,
        'best_idx': best_idx,
        'final_epoch': epochs[-1] if epochs else -1,
        'final_acc': overall_acc[-1] if overall_acc else 0.0
    }


def load_params(params_path):
    """
    Unpickle and return the experiment parameters stored at `params_path`.

    Returns None when the file does not exist.
    """
    # NOTE(review): pickle.load can execute arbitrary code — only point this at
    # params files produced by your own training runs.
    if os.path.exists(params_path):
        with open(params_path, 'rb') as handle:
            return pickle.load(handle)
    return None


print("Functions defined successfully.")

In [None]:
# Load the training log and saved parameters of every discovered experiment.
# Runs missing either file are skipped and listed below rather than silently dropped.
experiments = []
skipped = []

for env, xp in xps:
    # os.path.join supplies the separator missing from `path + env`
    # (`path` has no trailing slash).
    exp_path = os.path.join(path, env, xp)
    log_path = os.path.join(exp_path, 'train.log')
    params_path = os.path.join(exp_path, 'params.pkl')

    # Parse training log (None if the file is missing)
    metrics = parse_training_log(log_path, indicator)

    # Load parameters (None if the file is missing)
    params = load_params(params_path)

    if metrics is not None and params is not None:
        experiments.append({
            'name': f"{env}/{xp}",
            'env': env,
            'xp': xp,
            'metrics': metrics,
            'params': params
        })
    else:
        skipped.append(f"{env}/{xp}")

print(f"Successfully loaded {len(experiments)} experiments with complete data.")
if skipped:
    print(f"Skipped {len(skipped)} experiment(s) missing train.log or params.pkl: {', '.join(skipped)}")

## 3. Experiment Summary

In [None]:
# Create summary table: one row per experiment with best/final accuracy and the
# per-class accuracy at the best epoch.

# Env name -> task label. 'basic' = μ(n) and 'musq' = μ²(n) per the configuration
# notes; 'mu' (the configured env) is presumably μ(n) — confirm. Unknown envs are
# shown by name instead of being silently labelled μ²(n).
task_labels = {'basic': 'μ(n)', 'mu': 'μ(n)', 'musq': 'μ²(n)'}

summary_data = []

for exp in experiments:
    m = exp['metrics']

    # Position of the best epoch in the metric lists (-1 when there is none).
    best_idx = m['epochs'].index(m['best_epoch']) if m['best_epoch'] >= 0 else -1

    class_acc_str = ""
    if best_idx >= 0:
        class_accs = []
        for class_name in sorted(m['class_acc'].keys()):
            values = m['class_acc'][class_name]
            if len(values) > best_idx:
                # The log labels μ(n) = -1 as class "100".
                display_name = "-1" if class_name == "100" else class_name
                class_accs.append(f"{display_name}:{values[best_idx]:.1f}%")
        class_acc_str = ", ".join(class_accs)

    summary_data.append({
        'Experiment': exp['name'],
        'Task': task_labels.get(exp['env'], exp['env']),
        'Best Epoch': m['best_epoch'],
        'Best Acc (%)': f"{m['best_acc']:.2f}",
        'Final Epoch': m['final_epoch'],
        'Final Acc (%)': f"{m['final_acc']:.2f}",
        'Per-Class Acc (%)': class_acc_str
    })

df_summary = pd.DataFrame(summary_data)
print("\n" + "="*80)
print("EXPERIMENT SUMMARY")
print("="*80)
print(df_summary.to_string(index=False))
print("="*80)

## 4. Experiment Configuration Details

In [None]:
# Print the saved configuration of every loaded experiment. Values are read
# from the unpickled params object, with a fallback when an attribute is absent.
# NOTE(review): apart from the integer base, the "Data Encoding Format" and
# "Dataset Generation" sections are fixed text, not read from params — confirm
# they describe the run being analysed.
for exp in experiments:
    cfg = exp['params']
    banner = "=" * 80

    print("\n" + banner)
    print(f"CONFIGURATION: {exp['name']}")
    print(banner)

    print("\n### Data Configuration ###")
    print(f"  Operation:          {getattr(cfg, 'operation', 'N/A')}")
    print(f"  Data Types:         {getattr(cfg, 'data_types', 'N/A')}")
    print(f"  Training Data:      {getattr(cfg, 'train_data', 'N/A')}")
    print(f"  Eval Data:          {getattr(cfg, 'eval_data', 'N/A')}")
    print(f"  Encoding Base:      {getattr(cfg, 'base', 'N/A')}")

    print("\n### Data Encoding Format ###")
    print("  Format: Interleaved CRT representation")
    print("  Structure: [n mod p₁, p₁, n mod p₂, p₂, ..., n mod p₁₀₀, p₁₀₀]")
    print("  Number of primes: 100 (primes ≤ 542)")
    print("  Vector length: 200 (2 × 100 primes)")
    print(f"  Integer encoding: Base-{getattr(cfg, 'base', 1000)} positional notation")

    print("\n### Model Architecture ###")
    print(f"  Architecture:       {getattr(cfg, 'architecture', 'N/A')}")
    print(f"  Model Type:         {'LSTM' if getattr(cfg, 'lstm', False) else 'Transformer'}")
    print(f"  Encoder Layers:     {getattr(cfg, 'n_enc_layers', 'N/A')}")
    print(f"  Decoder Layers:     {getattr(cfg, 'n_dec_layers', 'N/A')}")
    print(f"  Encoder Embed Dim:  {getattr(cfg, 'enc_emb_dim', 'N/A')}")
    print(f"  Decoder Embed Dim:  {getattr(cfg, 'dec_emb_dim', 'N/A')}")
    print(f"  Attention Heads:    {getattr(cfg, 'n_enc_heads', 'N/A')} (encoder), {getattr(cfg, 'n_dec_heads', 'N/A')} (decoder)")
    print(f"  Dropout:            {getattr(cfg, 'dropout', 'N/A')}")
    print(f"  Attention Dropout:  {getattr(cfg, 'attention_dropout', 'N/A')}")

    print("\n### Training Configuration ###")
    print(f"  Optimizer:          {getattr(cfg, 'optimizer', 'N/A')}")
    print(f"  Batch Size:         {getattr(cfg, 'batch_size', 'N/A')}")
    print(f"  Epoch Size:         {getattr(cfg, 'epoch_size', 'N/A')}")
    print(f"  Max Epochs:         {getattr(cfg, 'max_epoch', 'N/A')}")
    print(f"  Eval Size:          {getattr(cfg, 'eval_size', 'N/A')}")
    print(f"  Gradient Clipping:  {getattr(cfg, 'clip_grad_norm', 'N/A')}")
    print(f"  Env Base Seed:      {getattr(cfg, 'env_base_seed', 'N/A')}")
    print(f"  Random Seed:        {getattr(cfg, 'seed', 'Not set (non-reproducible)')}")

    print("\n### Dataset Generation ###")
    print("  Total samples:      1,000,000")
    print("  Training samples:   900,000")
    print("  Test samples:       100,000")
    print("  Integer range:      [2, 10¹³]")
    print("  Generation method:  Random sampling with Möbius function computed via C++ library")

    print(banner)

## 5. Training Curves

In [None]:
# Accuracy (left) and cross-entropy loss (right) per logged epoch for every
# experiment. Both series come from the `{indicator}_*` log keys — by default
# "valid_arithmetic", i.e. the validation metrics, not training-batch loss.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

for exp in experiments:
    m = exp['metrics']
    # Label by the full "env/xp" name: several runs can share one env, and
    # env-only labels made them indistinguishable in the legend.
    label = f"{exp['name']} (best: {m['best_acc']:.2f}% @ epoch {m['best_epoch']})"
    ax1.plot(m['epochs'], m['overall_acc'], linewidth=2, label=label, alpha=0.8)
    ax2.plot(m['epochs'], m['overall_loss'], linewidth=2, label=exp['name'], alpha=0.8)

ax1.set_xlabel('Epoch', fontsize=12)
ax1.set_ylabel('Accuracy (%)', fontsize=12)
ax1.set_title('Overall Accuracy', fontsize=14, fontweight='bold')
ax1.grid(True, alpha=0.3)

ax2.set_xlabel('Epoch', fontsize=12)
ax2.set_ylabel('Cross-Entropy Loss', fontsize=12)
# The loss is read from `{indicator}_xe_loss`, so it is an evaluation loss,
# not the training loss the old title claimed.
ax2.set_title(f'Evaluation Loss ({indicator})', fontsize=14, fontweight='bold')
ax2.grid(True, alpha=0.3)

if experiments:  # an empty figure has no labelled artists to put in a legend
    ax1.legend(loc='lower right')
    ax2.legend(loc='upper right')

plt.tight_layout()
os.makedirs('../notebooks', exist_ok=True)
plt.savefig('../notebooks/training_curves.png', dpi=300, bbox_inches='tight')
plt.show()

print("Training curves saved to: ../notebooks/training_curves.png")

## 6. Per-Class Accuracy Analysis

In [None]:
# Map class names for display
def map_class_name(class_name):
    """
    Translate a class label from the training log into a readable Möbius value.

    The log stores μ = -1 under the label "100"; "0" and "1" map to μ=0 and μ=1.
    Any other label is returned as "class <label>".
    """
    readable = {"100": "μ=-1", "0": "μ=0", "1": "μ=1"}
    return readable.get(class_name, f"class {class_name}")


# Plot per-class accuracy for each experiment (one figure per experiment),
# with the overall accuracy overlaid as a dashed reference line.
for exp in experiments:
    m = exp['metrics']

    if not m['class_acc']:
        print(f"No per-class data for {exp['name']}")
        continue

    fig, ax = plt.subplots(figsize=(14, 6))

    # Plot each class
    for class_name in sorted(m['class_acc'].keys()):
        values = m['class_acc'][class_name]
        if not values:
            continue
        # Pair the series with the leading epochs; assumes it starts at the
        # first logged epoch.
        class_epochs = m['epochs'][:len(values)]
        ax.plot(class_epochs, values,
                linewidth=2, label=map_class_name(class_name), alpha=0.8, marker='o', markersize=2)

    # Plot overall accuracy for comparison
    ax.plot(m['epochs'], m['overall_acc'],
            linewidth=2.5, label='Overall', alpha=0.9, linestyle='--', color='black')

    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel('Accuracy (%)', fontsize=12)
    ax.set_title(f'Per-Class Accuracy: {exp["name"]}', fontsize=14, fontweight='bold')
    ax.legend(loc='best', fontsize=11)
    ax.grid(True, alpha=0.3)
    ax.set_ylim(0, 105)

    plt.tight_layout()
    # Include the run directory in the filename: several runs can share an env,
    # and an env-only name made later figures overwrite earlier ones.
    filename = f'../notebooks/per_class_accuracy_{exp["env"]}_{exp["xp"]}.png'
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    plt.savefig(filename, dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)  # this loop creates one figure per experiment; release each

    print(f"Per-class accuracy plot saved to: {filename}")

## 7. Final Per-Class Accuracy Comparison

In [None]:
# Bar chart comparing per-class accuracy (plus overall accuracy) at each
# experiment's best epoch, one panel per experiment.
if not experiments:
    # plt.subplots needs at least one column, so there is nothing to draw.
    print("No experiments loaded; skipping per-class comparison plot.")
else:
    # squeeze=False always returns a 2-D array of Axes, so a single experiment
    # needs no special case.
    fig, axes = plt.subplots(1, len(experiments), figsize=(8 * len(experiments), 6),
                             squeeze=False)

    # One colour per bar: the classes in sorted-label order, then "Overall".
    colors = ['#ff7f0e', '#2ca02c', '#d62728', '#1f77b4']

    for ax, exp in zip(axes[0], experiments):
        m = exp['metrics']

        # Get accuracy at best epoch (-1 when no epoch was logged)
        best_idx = m['epochs'].index(m['best_epoch']) if m['best_epoch'] >= 0 else -1
        if best_idx < 0:
            continue

        class_names = []
        class_values = []
        for class_name in sorted(m['class_acc'].keys()):
            values = m['class_acc'][class_name]
            if len(values) > best_idx:
                class_names.append(map_class_name(class_name))
                class_values.append(values[best_idx])

        # Add overall accuracy
        class_names.append('Overall')
        class_values.append(m['best_acc'])

        bars = ax.bar(class_names, class_values, color=colors[:len(class_names)], alpha=0.8)

        # Add value labels on bars; skip missing (NaN) values, which have no
        # finite height to anchor a label at.
        for bar in bars:
            height = bar.get_height()
            if not np.isfinite(height):
                continue
            ax.text(bar.get_x() + bar.get_width() / 2., height,
                    f'{height:.1f}%',
                    ha='center', va='bottom', fontsize=11, fontweight='bold')

        ax.set_ylabel('Accuracy (%)', fontsize=12)
        ax.set_title(f'{exp["name"]}\n(Best Epoch: {m["best_epoch"]})',
                     fontsize=13, fontweight='bold')
        ax.set_ylim(0, 105)
        ax.grid(True, axis='y', alpha=0.3)

    plt.tight_layout()
    os.makedirs('../notebooks', exist_ok=True)
    plt.savefig('../notebooks/per_class_comparison.png', dpi=300, bbox_inches='tight')
    plt.show()

    print("Per-class comparison saved to: ../notebooks/per_class_comparison.png")

## 8. Detailed Statistics

In [None]:
# Print overall and per-class accuracy for every experiment, both at the best
# epoch and at the last logged epoch.
for exp in experiments:
    metrics = exp['metrics']
    per_class = metrics['class_acc']
    best_epoch, final_epoch = metrics['best_epoch'], metrics['final_epoch']

    print("\n" + "=" * 80)
    print(f"DETAILED STATISTICS: {exp['name']}")
    print("=" * 80)

    print("\n### Overall Performance ###")
    print(f"  Best Accuracy:      {metrics['best_acc']:.2f}% (epoch {best_epoch})")
    print(f"  Final Accuracy:     {metrics['final_acc']:.2f}% (epoch {final_epoch})")
    print(f"  Total Epochs:       {final_epoch + 1}")

    if per_class:
        ordered_classes = sorted(per_class)

        print(f"\n### Per-Class Performance (at best epoch {best_epoch}) ###")
        # Position of the best epoch in the metric lists (-1 when there is none).
        best_pos = -1 if best_epoch < 0 else metrics['epochs'].index(best_epoch)
        if best_pos >= 0:
            for label in ordered_classes:
                series = per_class[label]
                if len(series) > best_pos:
                    print(f"  {map_class_name(label):>8}:  {series[best_pos]:.2f}%")

        print(f"\n### Per-Class Performance (final epoch {final_epoch}) ###")
        for label in ordered_classes:
            series = per_class[label]
            if series:
                print(f"  {map_class_name(label):>8}:  {series[-1]:.2f}%")

    print("=" * 80)