# Per-Bin Accuracy Analysis with Anti-Overfitting Config

This notebook analyzes per-bin accuracy for both the LSTM and Attention LSTM models trained with the anti-overfitting configuration.

## Overview
- Loads existing trained models from anti_overfitting comparison
- Evaluates performance across quadtree bins
- Compares forecast accuracy, WMAPE, and other metrics per bin
- Visualizes results with heatmaps and comparison plots

In [None]:
import json
import os
import subprocess
import sys
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Put the relative 'src' directory on the import search path
sys.path.append('src')

# Plotting defaults shared by every figure in this notebook
plt.style.use('default')
sns.set_palette("husl")
plt.rcParams.update({
    'figure.figsize': (12, 8),
    'font.size': 12,
})

## Load Anti-Overfitting Configuration

In [None]:
# Read the anti-overfitting configuration from its JSON file
with open('anti_overfitting_config.json', 'r') as f:
    anti_overfitting_config = json.load(f)

# Identification fields stored at the top level of the config
print("Anti-Overfitting Configuration:")
print(f"Name: {anti_overfitting_config['name']}")
print(f"Description: {anti_overfitting_config['description']}")
print(f"Version: {anti_overfitting_config['version']}")

# Each remaining section is a mapping, printed as a heading followed by
# indented "key: value" lines. A blank line precedes every section.
config_sections = (
    ("Model Architecture:", 'model_architecture'),
    ("Training Parameters:", 'training_parameters'),
    ("Anti-Overfitting Features:", 'anti_overfitting_features'),
    ("Expected Behavior:", 'expected_behavior'),
)
for heading, section_key in config_sections:
    print()
    print(heading)
    for key, value in anti_overfitting_config[section_key].items():
        print(f"  {key}: {value}")

## Set Output Directory (Use Existing Trained Models)

In [None]:
# Training is skipped: point at the directory of an existing run
output_dir = "data/results_anti_overfitting_comparison"

# Resolve both locations once, then report them
output_path = Path(output_dir)
comparison_dir = output_path / 'results' / 'model_comparison'

print(f"Output directory: {output_dir}")
print(f"Full path: {output_path.absolute()}")
print(f"Results dir: {comparison_dir}")
print(f"Results dir exists: {comparison_dir.exists()}")

## Load and Analyze Results

In [None]:
def load_comparison_results(output_dir: str) -> Optional[Dict]:
    """Locate the saved model-comparison artifacts under ``output_dir``.

    Looks in ``<output_dir>/results/model_comparison`` for
    ``comparison_metrics.json`` and for the two checkpoints
    ``simple_lstm_model.pth`` and ``attention_lstm_model.pth``, printing a
    short report of what was found along the way. The checkpoints are only
    checked for existence; they are not loaded.

    Args:
        output_dir: Root directory of a previous comparison run.

    Returns:
        ``None`` if the results directory or either checkpoint is missing.
        Otherwise a dict with the keys ``models_dir``, ``simple_model_path``
        and ``attention_model_path`` (all strings) plus ``metrics`` (the
        parsed JSON content, or ``None`` when no metrics file exists).

    Raises:
        json.JSONDecodeError: If the metrics file exists but is not valid
            JSON.
    """
    results_dir = Path(output_dir) / "results" / "model_comparison"

    if not results_dir.exists():
        print(f"‚ùå Results directory not found: {results_dir}")
        return None

    print(f"‚úÖ Found results directory: {results_dir}")

    # List available files
    print(f"ÔøΩÔøΩ Available files:")
    for file in results_dir.glob("*"):
        print(f"  - {file.name}")

    # Metrics are optional: the name is bound up front so the returned dict
    # always carries an explicit 'metrics' entry, even without a JSON file.
    metrics = None
    metrics_file = results_dir / "comparison_metrics.json"
    if metrics_file.exists():
        with open(metrics_file, 'r', encoding='utf-8') as f:
            metrics = json.load(f)
        print(f"\nüìä Loaded comparison metrics:")

        # Build the whole summary before printing any of it, so that a file
        # lacking the expected fields produces one notice instead of aborting
        # the load halfway through the report.
        summary_lines = []
        try:
            for model_label, section in (
                ("Simple LSTM", 'simple_lstm_metrics'),
                ("Attention LSTM", 'attention_lstm_metrics'),
            ):
                for task_label, field in (
                    ("Magnitude", 'magnitude_accuracy'),
                    ("Frequency", 'frequency_accuracy'),
                ):
                    summary_lines.append(
                        f"  {model_label} - {task_label} Accuracy: "
                        f"{metrics[section][field]:.3f}"
                    )
        except (KeyError, TypeError, ValueError) as exc:
            print(f"  Accuracy summary unavailable: {exc!r}")
        else:
            for line in summary_lines:
                print(line)

    # Check if trained models exist
    simple_model_path = results_dir / "simple_lstm_model.pth"
    attention_model_path = results_dir / "attention_lstm_model.pth"

    if not (simple_model_path.exists() and attention_model_path.exists()):
        print(f"\n‚ùå Trained models not found")
        return None

    print(f"\n‚úÖ Found trained models:")
    print(f"  Simple LSTM: {simple_model_path}")
    print(f"  Attention LSTM: {attention_model_path}")
    return {
        'models_dir': str(results_dir),
        'simple_model_path': str(simple_model_path),
        'attention_model_path': str(attention_model_path),
        'metrics': metrics,
    }

# Attempt the lookup only when an output directory has been set
results = load_comparison_results(output_dir) if output_dir else None

## Display Overall Comparison Metrics

In [None]:
# Print the overall metrics of both models, the hyperparameters recorded with
# them, and a head-to-head comparison on magnitude and frequency accuracy.
if results and results.get('metrics'):
    metrics = results['metrics']
    simple_metrics = metrics['simple_lstm_metrics']
    attention_metrics = metrics['attention_lstm_metrics']

    print("ÔøΩÔøΩ OVERALL COMPARISON METRICS")
    print("=" * 50)

    # (display label, metrics key, format spec) in display order
    metric_rows = [
        ("Total Loss", 'total_loss', ".4f"),
        ("Magnitude Loss", 'magnitude_loss', ".4f"),
        ("Frequency Loss", 'frequency_loss', ".4f"),
        ("Magnitude Accuracy", 'magnitude_accuracy', ".3f"),
        ("Frequency Accuracy", 'frequency_accuracy', ".3f"),
        ("Magnitude Correlation", 'magnitude_corr', ".3f"),
        ("Frequency Correlation", 'frequency_corr', ".3f"),
    ]

    # The same table is printed once per model
    for heading, model_metrics in (
        ("\nüîµ Simple LSTM Performance:", simple_metrics),
        ("\nüü° Attention LSTM Performance:", attention_metrics),
    ):
        print(heading)
        for label, key, spec in metric_rows:
            print(f"  {label}: {model_metrics[key]:{spec}}")

    # Hyperparameters used
    print("\n‚öôÔ∏è  Hyperparameters Used:")
    for key, value in metrics['hyperparameters'].items():
        print(f"  {key}: {value}")

    # Performance comparison
    print("\nÔøΩÔøΩ PERFORMANCE COMPARISON:")

    simple_mag_acc = simple_metrics['magnitude_accuracy']
    attention_mag_acc = attention_metrics['magnitude_accuracy']
    simple_freq_acc = simple_metrics['frequency_accuracy']
    attention_freq_acc = attention_metrics['frequency_accuracy']

    # Both directions are tested with a strict comparison so that equal
    # scores are reported as a tie instead of defaulting to one model.
    for task, simple_value, attention_value in (
        ("Magnitude", simple_mag_acc, attention_mag_acc),
        ("Frequency", simple_freq_acc, attention_freq_acc),
    ):
        if simple_value > attention_value:
            print(f"  üéØ Simple LSTM wins on {task} Accuracy: {simple_value:.3f} vs {attention_value:.3f}")
        elif attention_value > simple_value:
            print(f"  üéØ Attention LSTM wins on {task} Accuracy: {attention_value:.3f} vs {simple_value:.3f}")
        else:
            print(f"  üéØ Tie on {task} Accuracy: {simple_value:.3f} vs {attention_value:.3f}")

    # Overall winner: sum of the two accuracies per model
    simple_total = simple_mag_acc + simple_freq_acc
    attention_total = attention_mag_acc + attention_freq_acc

    if simple_total > attention_total:
        print(f"\nüèÜ OVERALL WINNER: Simple LSTM ({simple_total:.3f} vs {attention_total:.3f})")
    elif attention_total > simple_total:
        print(f"\nüèÜ OVERALL WINNER: Attention LSTM ({attention_total:.3f} vs {simple_total:.3f})")
    else:
        print(f"\nüèÜ OVERALL: Tie ({simple_total:.3f} vs {attention_total:.3f})")
else:
    print("‚ö†Ô∏è  No metrics available to analyze.")

## Create Comparison Visualizations

In [None]:
if results and results.get('metrics'):
    metrics = results['metrics']

    # 2x2 grid: three grouped-bar panels and one heatmap
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))

    models = ['Simple LSTM', 'Attention LSTM']
    model_keys = ('simple_lstm_metrics', 'attention_lstm_metrics')
    x = np.arange(len(models))
    width = 0.35

    # Total loss is looked up here but is not drawn in any panel below
    total_loss = [metrics[model_key]['total_loss'] for model_key in model_keys]

    # One row per bar panel: axis, noun used in labels and titles, the
    # magnitude and frequency metric keys, and the two bar colors
    bar_panels = [
        (ax1, 'Accuracy', 'magnitude_accuracy', 'frequency_accuracy', ('skyblue', 'lightcoral')),
        (ax2, 'Loss', 'magnitude_loss', 'frequency_loss', ('lightblue', 'salmon')),
        (ax3, 'Correlation', 'magnitude_corr', 'frequency_corr', ('lightgreen', 'orange')),
    ]
    panel_values = {}
    for ax, noun, mag_key, freq_key, (mag_color, freq_color) in bar_panels:
        mag_values = [metrics[model_key][mag_key] for model_key in model_keys]
        freq_values = [metrics[model_key][freq_key] for model_key in model_keys]
        panel_values[noun] = (mag_values, freq_values)

        # Magnitude bars sit left of each tick, frequency bars right of it
        ax.bar(x - width/2, mag_values, width, label=f'Magnitude {noun}', alpha=0.8, color=mag_color)
        ax.bar(x + width/2, freq_values, width, label=f'Frequency {noun}', alpha=0.8, color=freq_color)
        ax.set_xlabel('Model')
        ax.set_ylabel(noun)
        ax.set_title(f'{noun} Comparison')
        ax.set_xticks(x)
        ax.set_xticklabels(models)
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Heatmap of the accuracies: one row per model, columns are
    # (magnitude, frequency), color scale fixed to [0, 1]
    mag_acc, freq_acc = panel_values['Accuracy']
    performance_data = [list(pair) for pair in zip(mag_acc, freq_acc)]

    im = ax4.imshow(performance_data, cmap='RdYlGn', aspect='auto', vmin=0, vmax=1)
    ax4.set_xticks([0, 1])
    ax4.set_xticklabels(['Magnitude', 'Frequency'])
    ax4.set_yticks([0, 1])
    ax4.set_yticklabels(['Simple LSTM', 'Attention LSTM'])
    ax4.set_title('Performance Heatmap')

    # Write each cell's value at its center
    for row_idx, row in enumerate(performance_data):
        for col_idx, cell_value in enumerate(row):
            ax4.text(col_idx, row_idx, f'{cell_value:.3f}',
                     ha="center", va="center", color="black", fontweight='bold')

    plt.tight_layout()
    plt.show()

    print("üìä Comparison visualizations created successfully!")

## Save Results to CSV

In [None]:
# Write the side-by-side metrics to a CSV file, then score the two models
# metric by metric and report which one wins more of them.
if results and results.get('metrics'):
    metrics = results['metrics']

    # (display name, metrics key, True when a larger value is better).
    # Losses are better when smaller; accuracies and correlations when larger.
    metric_specs = [
        ('Total Loss', 'total_loss', False),
        ('Magnitude Loss', 'magnitude_loss', False),
        ('Frequency Loss', 'frequency_loss', False),
        ('Magnitude Accuracy', 'magnitude_accuracy', True),
        ('Frequency Accuracy', 'frequency_accuracy', True),
        ('Magnitude Correlation', 'magnitude_corr', True),
        ('Frequency Correlation', 'frequency_corr', True),
    ]

    # Create summary DataFrame
    summary_df = pd.DataFrame({
        'Metric': [name for name, _, _ in metric_specs],
        'Simple_LSTM': [metrics['simple_lstm_metrics'][key] for _, key, _ in metric_specs],
        'Attention_LSTM': [metrics['attention_lstm_metrics'][key] for _, key, _ in metric_specs],
    })

    # Save to CSV (only the three columns above; the derived columns added
    # further down are not part of the file)
    output_file = "anti_overfitting_comparison_results.csv"
    summary_df.to_csv(output_file, index=False)

    print(f"‚úÖ Results saved to: {output_file}")
    print("\nÔøΩÔøΩ Results Summary:")
    print(summary_df.round(4))

    # Calculate differences
    summary_df['Difference'] = summary_df['Attention_LSTM'] - summary_df['Simple_LSTM']

    # Orient every difference so that a positive value always favors
    # Attention LSTM: keep it for higher-is-better metrics, negate it for
    # losses. A raw "Difference < 0" test would credit Simple LSTM whenever
    # Attention LSTM achieved the lower loss.
    higher_is_better = pd.Series(
        [flag for _, _, flag in metric_specs], index=summary_df.index
    )
    attention_advantage = summary_df['Difference'].where(
        higher_is_better, -summary_df['Difference']
    )
    summary_df['Simple_LSTM_Wins'] = attention_advantage < 0
    summary_df['Attention_LSTM_Wins'] = attention_advantage > 0

    print("\nüìä Performance Differences (Attention - Simple):")
    print(summary_df[['Metric', 'Difference']].round(4))

    # Count wins; metrics with identical values count for neither model
    simple_wins = int(summary_df['Simple_LSTM_Wins'].sum())
    attention_wins = int(summary_df['Attention_LSTM_Wins'].sum())
    tied_metrics = len(summary_df) - simple_wins - attention_wins

    print(f"\nüèÜ Final Score:")
    print(f"  Simple LSTM wins: {simple_wins} metrics")
    print(f"  Attention LSTM wins: {attention_wins} metrics")
    if tied_metrics:
        print(f"  Tied: {tied_metrics} metrics")

    if simple_wins > attention_wins:
        print(f"  üéØ OVERALL WINNER: Simple LSTM")
    elif attention_wins > simple_wins:
        print(f"  üéØ OVERALL WINNER: Attention LSTM")
    else:
        print(f"  üéØ TIE: Both models perform equally well")
else:
    print("‚ö†Ô∏è  No metrics available to save.")

## Summary and Conclusions

In [None]:
# Final report: restates the headline accuracies, names a winner (or a tie),
# checks the accuracies against fixed thresholds, and prints recommendations.
if results and results.get('metrics'):
    metrics = results['metrics']

    print("üìã COMPREHENSIVE SUMMARY REPORT")
    print("=" * 60)
    print(f"Configuration: Anti-Overfitting Config")
    print(f"Analysis Type: Overall Model Comparison")
    print()

    # Overall performance
    print("üèÜ OVERALL PERFORMANCE:")
    simple_mag_acc = metrics['simple_lstm_metrics']['magnitude_accuracy']
    simple_freq_acc = metrics['simple_lstm_metrics']['frequency_accuracy']
    attention_mag_acc = metrics['attention_lstm_metrics']['magnitude_accuracy']
    attention_freq_acc = metrics['attention_lstm_metrics']['frequency_accuracy']

    print(f"  Simple LSTM:     {simple_mag_acc:.3f} magnitude, {simple_freq_acc:.3f} frequency")
    print(f"  Attention LSTM:  {attention_mag_acc:.3f} magnitude, {attention_freq_acc:.3f} frequency")
    print()

    # Winner determination: sum of the two accuracies per model. Equal sums
    # are reported as a tie rather than being awarded to either model.
    simple_total = simple_mag_acc + simple_freq_acc
    attention_total = attention_mag_acc + attention_freq_acc

    if attention_total > simple_total:
        print("üéØ WINNER: Attention LSTM")
        print(f"   Reason: Higher combined accuracy ({attention_total:.3f} vs {simple_total:.3f})")
    elif simple_total > attention_total:
        print("üéØ WINNER: Simple LSTM")
        print(f"   Reason: Higher combined accuracy ({simple_total:.3f} vs {attention_total:.3f})")
    else:
        print("üéØ WINNER: Tie")
        print(f"   Reason: Equal combined accuracy ({simple_total:.3f} vs {attention_total:.3f})")

    print()

    # Anti-overfitting effectiveness: the highest of the four accuracies is
    # compared against a fixed 0.9 threshold
    print("ÔøΩÔøΩÔ∏è  ANTI-OVERFITTING EFFECTIVENESS:")
    max_acc = max(simple_mag_acc, simple_freq_acc, attention_mag_acc, attention_freq_acc)
    if max_acc < 0.9:
        print(f"  ‚úÖ Effective: Maximum accuracy is {max_acc:.3f} (below 0.9 threshold)")
    else:
        # This branch also covers a value of exactly 0.9
        print(f"  ‚ö†Ô∏è  Caution: Maximum accuracy is {max_acc:.3f} (at or above 0.9 threshold)")

    # Gap between each model's magnitude and frequency accuracy
    simple_range = abs(simple_mag_acc - simple_freq_acc)
    attention_range = abs(attention_mag_acc - attention_freq_acc)

    print(f"  Simple LSTM accuracy range: {simple_range:.3f}")
    print(f"  Attention LSTM accuracy range: {attention_range:.3f}")

    if simple_range < 0.3 and attention_range < 0.3:
        print("  ‚úÖ Good: Both models show consistent performance across tasks")
    else:
        print("  ‚ö†Ô∏è  Caution: High variance across tasks may indicate instability")

    print()

    # Recommendations: a model is preferred only when its combined accuracy
    # leads by more than 0.1
    print("ÔøΩÔøΩ RECOMMENDATIONS:")
    if attention_total > simple_total + 0.1:
        print("  ‚Ä¢ Attention LSTM shows significant improvement - consider for production")
    elif simple_total > attention_total + 0.1:
        print("  ‚Ä¢ Simple LSTM performs better - simpler model may be sufficient")
    else:
        print("  ‚Ä¢ Both models perform similarly - choose based on computational requirements")

    if max_acc < 0.8:
        print("  ‚Ä¢ Anti-overfitting config is working well - realistic performance achieved")
    else:
        print("  ‚Ä¢ Consider further regularization if accuracy is too high")

    print()
    print("ÔøΩÔøΩ Analysis completed successfully!")
    print("\nNote: This analysis shows overall model performance.")
    print("For per-bin accuracy analysis, you would need to run the trained models")
    print("on test data with spatial binning to get detailed per-bin metrics.")