# Learned Analytical Formula Diagnostics

This notebook diagnoses why the learned analytical formula underperforms the current analytical baseline.

## Investigation Goals

1. **Load and analyze learned formulas** from notebook 8/8b results
2. **Compare predictions**: Learned vs Analytical vs Empirical
3. **Identify failure modes**: Where does the learned formula perform poorly?
4. **Test different configurations** (regularization strength, number of random starts) with fewer than 50 training permutations
5. **Recommend improvements**

## Key Questions

- What correlation does the learned formula achieve vs empirical?
- Are the learned parameters reasonable?
- Is the default regularization (λ=0.001) too strong or too weak?
- Is the amount of training data insufficient?
- Does the formula structure have fundamental limitations?

In [None]:
import sys
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.sparse as sp
from scipy.stats import pearsonr
import warnings
warnings.filterwarnings('ignore')

# Resolve the repository layout relative to the notebook's working directory
# (the notebook is expected to live one level below the repository root).
repo_dir = Path.cwd().parent
src_dir, data_dir, results_dir = (repo_dir / sub for sub in ('src', 'data', 'results'))

# Make the project's src/ directory importable
sys.path.append(str(src_dir))

from learned_analytical import LearnedAnalyticalFormula

print(f"Repository: {repo_dir}")

# Global plot styling
sns.set_style('whitegrid')
plt.rcParams.update({'figure.dpi': 100})

## Configuration

In [None]:
# Edge types under test: CtD (sparse) and AeG (dense)
test_edge_types = ['CtD', 'AeG']

# Training permutation counts, all below 50
test_N_values = [5, *range(10, 50, 10)]

# (name, regularization λ, number of random starts); every config uses the
# 'original' formula.
_config_specs = [
    ('baseline', 0.001, 10),
    ('no_reg', 0.0, 10),
    ('high_reg', 0.01, 10),
    ('more_starts', 0.001, 20),
]
test_configs = [
    {'name': name, 'regularization': reg, 'n_starts': starts, 'formula': 'original'}
    for name, reg, starts in _config_specs
]

print(f"Testing {len(test_edge_types)} edge types")
print(f"Testing {len(test_N_values)} N values: {test_N_values}")
print(f"Testing {len(test_configs)} configurations")

## Helper Functions

In [None]:
def load_empirical_200(edge_type: str) -> pd.DataFrame:
    """Read the 200-permutation empirical edge frequencies for ``edge_type``.

    The CSV is ``results/empirical_edge_frequencies/edge_frequency_by_degree_<edge_type>.csv``;
    downstream code reads its ``source_degree``, ``target_degree`` and
    ``frequency`` columns.
    """
    csv_name = f'edge_frequency_by_degree_{edge_type}.csv'
    csv_path = results_dir.joinpath('empirical_edge_frequencies', csv_name)
    return pd.read_csv(csv_path)

def compute_analytical_predictions(empirical_df: pd.DataFrame, m: int) -> np.ndarray:
    """Compute the current analytical formula's edge probabilities.

    For each row with source degree ``u`` and target degree ``v``::

        p = u*v / sqrt((u*v)**2 + (m - u - v + 1)**2)

    with ``p = 0`` wherever the denominator is not positive (or is NaN).

    Args:
        empirical_df: Table with ``source_degree`` and ``target_degree`` columns.
        m: Number of edges in the graph.

    Returns:
        Float array of predictions, one per row of ``empirical_df`` (in order).
    """
    # Vectorized over all rows; casting to float avoids int64 overflow in the
    # squared terms for large degrees / edge counts.
    u = empirical_df['source_degree'].to_numpy(dtype=float)
    v = empirical_df['target_degree'].to_numpy(dtype=float)
    uv = u * v
    denom = np.sqrt(uv ** 2 + (m - u - v + 1) ** 2)

    # Entries where denom <= 0 (or NaN) keep the initial 0.0.
    predictions = np.zeros_like(uv)
    np.divide(uv, denom, out=predictions, where=denom > 0)
    return predictions

def get_graph_stats(edge_type: str) -> dict:
    """Return edge count, density and matrix shape for ``edge_type``.

    Loads ``data/permutations/000.hetmat/edges/<edge_type>.sparse.npz``.

    Returns:
        dict with keys ``m`` (stored non-zeros), ``density``
        (``m / (n_sources * n_targets)``), ``n_sources`` and ``n_targets``.
    """
    npz_path = data_dir.joinpath('permutations', '000.hetmat', 'edges', f'{edge_type}.sparse.npz')
    adjacency = sp.load_npz(npz_path)
    rows, cols = adjacency.shape
    n_edges = adjacency.nnz
    return {
        'm': n_edges,
        'density': n_edges / (rows * cols),
        'n_sources': rows,
        'n_targets': cols,
    }

## Test Different Configurations

In [None]:
def _diagnostic_record(edge_type, config, N, baseline_corr, baseline_mae, learned=None):
    """Build one row of the diagnostics table.

    Args:
        edge_type: Edge type being tested (e.g. 'CtD').
        config: One entry of ``test_configs``.
        N: Number of training permutations requested.
        baseline_corr, baseline_mae: Metrics of the current analytical formula.
        learned: ``(correlation, mae, params)`` for a successful run, or
            ``None`` for a failed run. For a failed run, all learned and
            improvement fields are ``None``.

    The key order here fixes the column order of ``diagnostic_df`` and the saved CSV.
    """
    if learned is None:
        learned_corr = learned_mae = improvement_corr = improvement_mae = params = None
    else:
        learned_corr, learned_mae, params = learned
        improvement_corr = learned_corr - baseline_corr
        improvement_mae = baseline_mae - learned_mae  # positive = lower error than baseline
    return {
        'edge_type': edge_type,
        'config': config['name'],
        'N': N,
        'regularization': config['regularization'],
        'n_starts': config['n_starts'],
        'formula': config['formula'],
        'learned_corr': learned_corr,
        'learned_mae': learned_mae,
        'baseline_corr': baseline_corr,
        'baseline_mae': baseline_mae,
        'improvement_corr': improvement_corr,
        'improvement_mae': improvement_mae,
        'params': params,
    }


# Store results: one record per (edge type, configuration, N) run
all_diagnostic_results = []

for edge_type in test_edge_types:
    print(f"\n{'='*80}")
    print(f"TESTING {edge_type}")
    print(f"{'='*80}\n")

    # Load empirical and graph stats
    empirical_df = load_empirical_200(edge_type)
    stats = get_graph_stats(edge_type)
    m, density = stats['m'], stats['density']

    # Score the current analytical formula against the empirical frequencies
    analytical_preds = compute_analytical_predictions(empirical_df, m)
    empirical_vals = empirical_df['frequency'].values

    baseline_corr = pearsonr(analytical_preds, empirical_vals)[0]
    baseline_mae = np.mean(np.abs(analytical_preds - empirical_vals))

    print(f"Baseline (Current Analytical):")
    print(f"  Correlation: {baseline_corr:.6f}")
    print(f"  MAE: {baseline_mae:.6f}\n")

    for config in test_configs:
        print(f"{'-'*60}")
        print(f"Config: {config['name']}")
        print(f"  Regularization λ={config['regularization']}, Starts={config['n_starts']}, Formula={config['formula']}")
        print(f"{'-'*60}\n")

        # Train a fresh learner for each N value
        for N in test_N_values:
            learner = LearnedAnalyticalFormula(
                n_random_starts=config['n_starts'],
                regularization_lambda=config['regularization'],
                formula_type=config['formula']
            )

            try:
                # Train on N permutations and build the result row
                results = learner.find_minimum_permutations(
                    graph_name=edge_type,
                    data_dir=data_dir,
                    results_dir=results_dir,
                    N_candidates=[N],
                    convergence_threshold=0.02,
                    target_metric='correlation',
                    min_metric_value=0.95
                )
                metrics = results['final_metrics']
                record = _diagnostic_record(
                    edge_type, config, N, baseline_corr, baseline_mae,
                    learned=(metrics['correlation'], metrics['mae'], learner.params.tolist()),
                )
            except Exception as e:
                # Record the failure and keep sweeping. Only training and
                # result extraction are guarded, so each run yields exactly one row.
                print(f"  N={N:2d}: FAILED - {type(e).__name__}: {e}")
                all_diagnostic_results.append(
                    _diagnostic_record(edge_type, config, N, baseline_corr, baseline_mae)
                )
                continue

            all_diagnostic_results.append(record)

            learned_corr = record['learned_corr']
            improvement = record['improvement_corr']
            status = "✓" if improvement > 0 else "✗"
            print(f"  N={N:2d}: r={learned_corr:.4f} (baseline={baseline_corr:.4f}, Δ={improvement:+.4f}) {status}")

        print()

# Convert to DataFrame
diagnostic_df = pd.DataFrame(all_diagnostic_results)
print(f"\n{'='*80}")
print("ALL DIAGNOSTIC TESTS COMPLETE")
print(f"{'='*80}")

## Results Analysis

In [None]:
# Summary statistics: report the best run for each edge type
print("\n" + "="*80)
print("DIAGNOSTIC RESULTS SUMMARY")
print("="*80)

for edge_type in test_edge_types:
    print(f"\n{edge_type}:")
    edge_results = diagnostic_df[diagnostic_df['edge_type'] == edge_type]

    # Keep runs with a defined improvement: training succeeded and the
    # baseline correlation is defined. idxmax fails on an all-NaN column.
    valid_results = edge_results[edge_results['improvement_corr'].notna()]
    if len(valid_results) == 0:
        print(f"  No valid results")
        continue

    # Find best configuration
    best_idx = valid_results['improvement_corr'].idxmax()
    best = valid_results.loc[best_idx]
    improvement = best['improvement_corr']
    baseline = best['baseline_corr']

    print(f"  Best: {best['config']} with N={best['N']}")
    print(f"    Correlation: {best['learned_corr']:.4f} (baseline: {baseline:.4f})")

    # Relative change is taken against |baseline| so that a negative baseline
    # correlation does not flip its sign. It is omitted when the baseline is 0.
    if baseline != 0:
        relative = f" ({improvement / abs(baseline) * 100:+.1f}%)"
    else:
        relative = ""
    print(f"    Improvement: {improvement:+.4f}{relative}")

    if improvement > 0:
        print(f"    ✓ OUTPERFORMS BASELINE!")
    else:
        print(f"    ✗ Still underperforming")

# Display full results table (failed runs dropped)
print("\n" + "="*80)
print("Full Results Table:")
display(diagnostic_df[['edge_type', 'config', 'N', 'learned_corr', 'baseline_corr', 'improvement_corr']].dropna())

## Visualization: Performance by Configuration

In [None]:
# Plot improvement over the analytical baseline vs N, one panel per edge type.
# squeeze=False keeps `axes` 2-D, so this works for any number of edge types.
n_panels = len(test_edge_types)
fig, axes = plt.subplots(1, n_panels, figsize=(8 * n_panels, 6), squeeze=False)

for ax, edge_type in zip(axes[0], test_edge_types):
    edge_results = diagnostic_df[diagnostic_df['edge_type'] == edge_type]

    for config_name in diagnostic_df['config'].unique():
        config_results = edge_results[edge_results['config'] == config_name]
        # Skip failed runs; sort so each line is drawn left-to-right in N
        valid = config_results[config_results['learned_corr'].notna()].sort_values('N')

        if len(valid) > 0:
            ax.plot(valid['N'], valid['improvement_corr'], marker='o', label=config_name, linewidth=2, markersize=8)

    # Zero improvement corresponds to matching the analytical baseline
    ax.axhline(0, color='red', linestyle='--', linewidth=2, alpha=0.5, label='Baseline')
    ax.set_xlabel('Number of Training Permutations (N)', fontsize=12)
    ax.set_ylabel('Improvement over Baseline\n(Δ Correlation)', fontsize=12)
    ax.set_title(f'{edge_type} - Learned Formula Performance', fontsize=14, fontweight='bold')
    ax.legend(loc='best')
    ax.grid(alpha=0.3)

plt.tight_layout()
plt.savefig(results_dir / 'learned_formula_diagnostics_improvement.png', dpi=300, bbox_inches='tight')
plt.show()
print("Saved improvement plot")

## Parameter Analysis

In [None]:
# Analyze learned parameters of the best run per edge type
print("\n" + "="*80)
print("LEARNED PARAMETERS ANALYSIS")
print("="*80)

# Display names for the learned parameter vector. Presumably these follow the
# order of LearnedAnalyticalFormula.params (TODO confirm). Any parameters
# beyond this list are shown as p<index> instead of being silently dropped.
param_names = ['α', 'β', 'γ', 'δ', 'ε', 'ζ', 'η', 'θ', 'κ']

for edge_type in test_edge_types:
    print(f"\n{edge_type}:")
    # Successful runs only; also require a defined improvement so idxmax
    # cannot hit an all-NaN column.
    edge_results = diagnostic_df[
        (diagnostic_df['edge_type'] == edge_type)
        & diagnostic_df['params'].notna()
        & diagnostic_df['improvement_corr'].notna()
    ]

    if len(edge_results) == 0:
        continue

    # Get best performing configuration
    best_idx = edge_results['improvement_corr'].idxmax()
    best = edge_results.loc[best_idx]

    print(f"  Best config: {best['config']} (N={best['N']})")
    print(f"  Learned parameters:")
    for i, pval in enumerate(best['params']):
        pname = param_names[i] if i < len(param_names) else f'p{i}'
        print(f"    {pname}: {pval:10.6f}")

## Recommendations

In [None]:
print("\n" + "="*80)
print("RECOMMENDATIONS")
print("="*80)

# Only runs that produced a learned correlation feed the recommendations
valid_results = diagnostic_df[diagnostic_df['learned_corr'].notna()]

if not valid_results.empty:
    # Mean correlation gain per configuration, pooled over edge types and N
    config_means = valid_results.groupby('config')['improvement_corr'].mean()
    best_config = config_means.idxmax()
    best_avg_improvement = config_means[best_config]

    print(f"\n1. BEST CONFIGURATION: {best_config}")
    print(f"   Average improvement: {best_avg_improvement:+.4f}")

    # Mean correlation gain per N, pooled over edge types and configurations
    n_means = valid_results.groupby('N')['improvement_corr'].mean()
    best_N = n_means.idxmax()
    best_N_improvement = n_means[best_N]

    print(f"\n2. BEST N VALUE: {best_N} permutations")
    print(f"   Average improvement: {best_N_improvement:+.4f}")

    # Did any individual run beat the analytical baseline?
    beats_baseline = valid_results['improvement_corr'] > 0
    any_beat_baseline = beats_baseline.any()

    if not any_beat_baseline:
        print(f"\n3. ✗ NO CONFIGURATION BEATS BASELINE")
        print(f"   Possible reasons:")
        for reason in (
            "Formula structure is too simple",
            "Training data is insufficient (<50 permutations)",
            "Optimization is getting stuck in local minima",
            "Regularization is too strong/weak",
        ):
            print(f"   - {reason}")
    else:
        print(f"\n3. ✓ SOME CONFIGURATIONS BEAT BASELINE!")
        winners = valid_results[beats_baseline]
        print(f"   {len(winners)} out of {len(valid_results)} tests successful")

# Save diagnostic results (including failed runs)
output_csv = results_dir / 'learned_formula_diagnostics.csv'
diagnostic_df.to_csv(output_csv, index=False)
print(f"\n{'='*80}")
print(f"Results saved to: {output_csv}")
print(f"{'='*80}")