# Shared LSTM Performance Analysis Notebook

This notebook provides a comprehensive analysis of the Shared LSTM earthquake forecasting model's performance, including:
- Training history visualization
- Accuracy metrics per spatial bin
- Zero frequency bin analysis by year
- Model predictions vs actual values
- Spatial performance distribution

In [None]:
import sys
import os
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import json
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')  # keep notebook output free of library warnings

# Put the project's ``src`` directory on the import path.
sys.path.append(str(Path.cwd().joinpath("src")))

# Global plotting defaults shared by every figure in this notebook.
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")
plt.rcParams.update({'figure.figsize': (12, 8), 'font.size': 12})

In [None]:
# Import required modules.
# NOTE(review): in the cells visible here only SharedLSTMModel is referenced
# (by load_and_evaluate_model); the remaining imports are currently unused.
from src.models.shared_lstm_model import SharedLSTMModel, WeightedEarthquakeLoss
from src.models.shared_lstm_trainer import SharedLSTMTrainer
from src.models.enhanced_shared_processor import EnhancedSharedDataset
from src.preprocessing.earthquake_processor import EarthquakeProcessor
from src.binning.quadtree import extract_leaf_bounds, count_events_in_bin

print("All modules imported successfully!")

## 1. Load Trained Model and Data

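First, load the configuration, check that the trained model checkpoint exists, and load the processed earthquake data. (The model itself is loaded later with `load_and_evaluate_model`.)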

In [None]:
# Configuration: point these at your own artifacts.
MODEL_PATH = "models/shared_lstm_best.pth"  # Update with your model path
DATA_PATH = "data/processed_earthquakes.csv"  # Update with your data path
CONFIG_PATH = "config_example.json"  # Update with your config path

# Read the JSON config when present; otherwise continue with an empty dict.
if not os.path.exists(CONFIG_PATH):
    print(f"Config file not found at {CONFIG_PATH}")
    config = {}
else:
    with open(CONFIG_PATH, 'r') as config_file:
        config = json.load(config_file)
    print("Configuration loaded:")
    print(json.dumps(config, indent=2))

# Report whether the trained checkpoint is where MODEL_PATH says it is.
if not os.path.exists(MODEL_PATH):
    print(f"\nModel not found at: {MODEL_PATH}")
    print("Please update MODEL_PATH with the correct path to your trained model")
else:
    print(f"\nModel found at: {MODEL_PATH}")

In [None]:
# Load processed earthquake data.
# ``df`` is only defined when the file exists; later cells test for it with
# ``'df' in locals()`` before using it.
if os.path.exists(DATA_PATH):
    df = pd.read_csv(DATA_PATH)
    print(f"Data loaded: {len(df)} records")
    print(f"Columns: {list(df.columns)}")
    # Not every processed file has a 'date' column; indexing it blindly would
    # raise KeyError and abort the cell before the sample below is shown.
    if 'date' in df.columns:
        print(f"Date range: {df['date'].min()} to {df['date'].max()}")
    else:
        print("Date column not found")

    # Display sample data (``display`` is provided by IPython/Jupyter).
    print("\nSample data:")
    display(df.head())
else:
    print(f"Data file not found at {DATA_PATH}")
    print("Please update DATA_PATH with the correct path to your processed data")

## 2. Model Performance Analysis

Let's analyze the model's performance across different metrics.

In [None]:
def load_and_evaluate_model(model_path, test_loader=None, device='auto'):
    """Load a trained SharedLSTMModel checkpoint and prepare it for inference.

    Despite the name, no evaluation is performed yet: ``test_loader`` is
    accepted (and optional) so existing callers keep working, but it is not
    used. Run the returned model over your own loader to obtain predictions.

    Args:
        model_path: Path to a checkpoint written with ``torch.save``. It is
            read as a dict holding ``'model_state_dict'`` and, optionally,
            ``'input_size'``, ``'hidden_size'`` and ``'num_layers'``
            (defaults used when absent: 64, 32, 2).
        test_loader: Currently unused.
        device: ``'auto'`` to pick CUDA when available (else CPU), or any
            value accepted by ``torch.device`` (e.g. ``'cpu'``, ``'cuda:0'``,
            or a ``torch.device`` instance).

    Returns:
        Tuple ``(model, checkpoint)``: the model moved to ``device`` and set
        to eval mode, and the raw checkpoint dict.

    Raises:
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
    """
    # Only compare against 'auto' for strings, so torch.device inputs are
    # passed straight through to torch.device below.
    if isinstance(device, str) and device == 'auto':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(model_path, map_location=device)

    # Architecture hyper-parameters; defaults apply when the checkpoint does
    # not record them.
    input_size = checkpoint.get('input_size', 64)
    hidden_size = checkpoint.get('hidden_size', 32)
    num_layers = checkpoint.get('num_layers', 2)

    if 'model_state_dict' not in checkpoint:
        raise KeyError(
            "Checkpoint has no 'model_state_dict' entry; "
            f"available keys: {list(checkpoint)}"
        )

    model = SharedLSTMModel(
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=num_layers
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    print(f"Model loaded successfully on {device}")
    print(f"Architecture: {input_size} → {hidden_size} (layers: {num_layers})")

    return model, checkpoint

# Note: evaluation still needs a test_loader built from your data; pass it in
# once available (it is currently accepted but not used).

## 3. Training History Visualization

If the checkpoint contains a training history, visualize the training progress.

In [None]:
# (grid position, title, y-label, train-history key, validation-history key)
_TRAINING_PANELS = (
    ((0, 0), 'Total Loss', 'Loss', 'train_losses', 'val_losses'),
    ((0, 1), 'Magnitude Loss (MSE)', 'Loss', 'train_magnitude_losses', 'val_magnitude_losses'),
    ((0, 2), 'Frequency Loss (Poisson NLL)', 'Loss', 'train_frequency_losses', 'val_frequency_losses'),
    ((1, 0), 'Magnitude MAE', 'MAE', 'train_magnitude_mae', 'val_magnitude_mae'),
    ((1, 1), 'Frequency MAE', 'MAE', 'train_frequency_mae', 'val_frequency_mae'),
)


def _style_history_axis(ax, title, ylabel):
    """Apply the shared title/label/legend/grid styling to one history panel."""
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True, alpha=0.3)


def plot_training_history(checkpoint):
    """Plot train/validation curves stored in a training checkpoint.

    Reads ``checkpoint['training_history']`` (a mapping of metric name to a
    per-epoch sequence) and draws a 2x3 grid: total, magnitude and frequency
    losses, magnitude and frequency MAE, and the validation loss with
    ``checkpoint['best_val_loss']`` marked. A panel whose train/validation
    keys are missing is hidden instead of being left as an empty axis.

    Args:
        checkpoint: Mapping as returned by ``torch.load``.

    Returns:
        None. Prints a message and returns early when there is no
        ``'training_history'`` entry.
    """
    if 'training_history' not in checkpoint:
        print("No training history found in checkpoint")
        return

    history = checkpoint['training_history']
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))

    for (row, col), title, ylabel, train_key, val_key in _TRAINING_PANELS:
        ax = axes[row, col]
        if train_key in history and val_key in history:
            ax.plot(history[train_key], label='Train', linewidth=2)
            ax.plot(history[val_key], label='Validation', linewidth=2)
            _style_history_axis(ax, title, ylabel)
        else:
            ax.set_visible(False)

    # Best validation loss. Test against None so that a legitimate 0.0 is still
    # drawn; overlay the validation curve so the reference line has context.
    ax = axes[1, 2]
    best_val_loss = checkpoint.get('best_val_loss', None)
    if best_val_loss is not None:
        best_val_loss = float(best_val_loss)  # also accepts 0-d tensors / numpy scalars
        if 'val_losses' in history:
            ax.plot(history['val_losses'], label='Validation', linewidth=2)
        ax.axhline(y=best_val_loss, color='red', linestyle='--',
                   label=f'Best Val Loss: {best_val_loss:.4f}', linewidth=2)
        _style_history_axis(ax, 'Best Validation Loss', 'Loss')
    else:
        ax.set_visible(False)

    plt.tight_layout()
    plt.show()

# Example usage (uncomment when you have a checkpoint):
# plot_training_history(checkpoint)

## 4. Bin-wise Performance Analysis

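This section summarizes the data in each spatial bin: record counts, magnitude statistics, and the share of records with zero frequency. Model predictions are not used here; prediction-based accuracy metrics are covered in Section 6.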

In [None]:
def analyze_bin_performance(df, model, bin_column='bin_id'):
    """Summarize the records in each spatial bin.

    No predictions are involved: ``model`` is accepted for interface
    compatibility but is not used.

    Args:
        df: Records with a ``bin_column`` column and, optionally, ``magnitude``
            and ``frequency`` columns.
        model: Unused.
        bin_column: Name of the column holding the bin identifier.

    Returns:
        DataFrame with one row per non-null bin id, in order of first
        appearance, and columns ``bin_id``, ``total_events`` (number of
        rows), ``mean/std/min/max_magnitude``, ``mean_frequency``,
        ``zero_frequency_count`` and ``zero_frequency_ratio``. Magnitude or
        frequency statistics are NaN when that column is absent. An empty
        ``df`` yields an empty frame with the same columns. Returns None
        (after printing the available columns) if ``bin_column`` is missing.
    """
    if bin_column not in df.columns:
        print(f"Column '{bin_column}' not found in data. Available columns: {list(df.columns)}")
        return None

    # One grouped pass instead of re-filtering the whole frame per bin.
    # sort=False keeps first-appearance order; observed=True skips unused
    # categories; NaN bin ids are dropped (as the per-bin equality filter did).
    grouped = df.groupby(bin_column, sort=False, observed=True)
    result = grouped.size().rename('total_events').to_frame()

    if 'magnitude' in df.columns:
        magnitude = grouped['magnitude']
        result['mean_magnitude'] = magnitude.mean()
        result['std_magnitude'] = magnitude.std()
        result['min_magnitude'] = magnitude.min()
        result['max_magnitude'] = magnitude.max()
    else:
        for name in ('mean_magnitude', 'std_magnitude', 'min_magnitude', 'max_magnitude'):
            result[name] = np.nan

    if 'frequency' in df.columns:
        zero_counts = df['frequency'].eq(0).groupby(
            df[bin_column], sort=False, observed=True).sum()
        result['mean_frequency'] = grouped['frequency'].mean()
        result['zero_frequency_count'] = zero_counts
        result['zero_frequency_ratio'] = zero_counts / result['total_events']
    else:
        for name in ('mean_frequency', 'zero_frequency_count', 'zero_frequency_ratio'):
            result[name] = np.nan

    columns = ['bin_id', 'total_events', 'mean_magnitude', 'std_magnitude',
               'min_magnitude', 'max_magnitude', 'mean_frequency',
               'zero_frequency_count', 'zero_frequency_ratio']
    return result.rename_axis('bin_id').reset_index()[columns]

# Analyze bin performance
if 'df' in locals():
    bin_performance = analyze_bin_performance(df, None, 'bin_id')
    if bin_performance is not None:
        print("Bin Performance Summary:")
        display(bin_performance.head(10))
        
        # Summary statistics
        print("\nSummary Statistics:")
        print(f"Total bins: {len(bin_performance)}")
        print(f"Bins with zero frequency: {len(bin_performance[bin_performance['zero_frequency_ratio'] > 0])}")
        print(f"Average events per bin: {bin_performance['total_events'].mean():.2f}")
else:
    print("Data not loaded yet. Please run the data loading cell first.")

## 5. Zero Frequency Bin Analysis by Year

This section identifies bins with zero frequency and analyzes their temporal distribution.

In [None]:
def analyze_zero_frequency_bins(df, date_column='date', bin_column='bin_id', frequency_column='frequency'):
    """Count zero-frequency records per (year, bin) and per year.

    ``df`` is left unmodified: the date column is parsed into a local Series
    rather than written back, and no ``year`` column is added to the caller's
    frame.

    Args:
        df: Records containing ``date_column`` (datetime or parseable by
            ``pd.to_datetime``), ``bin_column`` and ``frequency_column``.
        date_column: Name of the date column.
        bin_column: Name of the bin identifier column.
        frequency_column: Name of the column compared against 0.

    Returns:
        ``(yearly_zero_freq, yearly_summary)`` where ``yearly_zero_freq`` has
        columns ``year``, ``bin_column`` and ``zero_freq_count``, and
        ``yearly_summary`` is indexed by year with ``zero_freq_count`` (sum)
        and ``unique_bins_with_zero_freq``. Returns None (after printing a
        message) if a required column is missing or no record has frequency 0.
    """
    required_columns = [date_column, bin_column, frequency_column]
    missing_columns = [col for col in required_columns if col not in df.columns]

    if missing_columns:
        print(f"Missing columns: {missing_columns}")
        print(f"Available columns: {list(df.columns)}")
        return None

    # Parse dates locally instead of overwriting the caller's column.
    dates = df[date_column]
    if not pd.api.types.is_datetime64_any_dtype(dates):
        dates = pd.to_datetime(dates)

    zero_mask = df[frequency_column] == 0
    if not zero_mask.any():
        print("No zero frequency events found.")
        return None

    zero_freq_data = pd.DataFrame({
        'year': dates[zero_mask].dt.year.to_numpy(),
        bin_column: df.loc[zero_mask, bin_column].to_numpy(),
    })

    # Per (year, bin) counts, then roll up to one row per year.
    yearly_zero_freq = zero_freq_data.groupby(['year', bin_column]).size().reset_index(name='zero_freq_count')
    yearly_summary = yearly_zero_freq.groupby('year').agg({
        'zero_freq_count': 'sum',
        bin_column: 'nunique'
    }).rename(columns={bin_column: 'unique_bins_with_zero_freq'})

    return yearly_zero_freq, yearly_summary

# Analyze zero frequency bins
if 'df' in locals():
    zero_freq_analysis = analyze_zero_frequency_bins(df)
    
    if zero_freq_analysis is not None:
        from matplotlib.ticker import MaxNLocator

        yearly_zero_freq, yearly_summary = zero_freq_analysis
        
        print("Zero Frequency Analysis by Year:")
        print("=" * 50)
        
        # Display yearly summary
        print("\nYearly Summary:")
        display(yearly_summary)
        
        # Display detailed breakdown
        print("\nDetailed Breakdown by Year and Bin:")
        display(yearly_zero_freq.head(20))
        
        # Plot zero frequency trends
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
        
        # Plot 1: Total zero frequency events per year
        ax1.plot(yearly_summary.index, yearly_summary['zero_freq_count'], 
                marker='o', linewidth=2, markersize=8)
        ax1.set_title('Zero Frequency Events per Year', fontsize=14, fontweight='bold')
        ax1.set_xlabel('Year')
        ax1.set_ylabel('Count of Zero Frequency Events')
        ax1.grid(True, alpha=0.3)
        
        # Plot 2: Unique bins with zero frequency per year
        ax2.plot(yearly_summary.index, yearly_summary['unique_bins_with_zero_freq'], 
                marker='s', linewidth=2, markersize=8, color='orange')
        ax2.set_title('Unique Bins with Zero Frequency per Year', fontsize=14, fontweight='bold')
        ax2.set_xlabel('Year')
        ax2.set_ylabel('Number of Unique Bins')
        ax2.grid(True, alpha=0.3)
        
        # Years are integers; keep matplotlib from placing ticks at e.g. 2020.5.
        for ax in (ax1, ax2):
            ax.xaxis.set_major_locator(MaxNLocator(integer=True))
        
        plt.tight_layout()
        plt.show()
else:
    print("Data not loaded yet. Please run the data loading cell first.")

## 6. Accuracy Metrics per Bin

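This section defines helpers for calculating and visualizing accuracy metrics for each spatial bin. Both require model predictions, which must be generated separately (see Section 10).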

In [None]:
def calculate_bin_accuracy_metrics(df, predictions_df, bin_column='bin_id'):
    """Placeholder for per-bin accuracy metrics (predictions vs actual values).

    Nothing is computed yet and every argument is ignored. The function
    prints a reminder that predictions are needed and returns a description
    of the metrics it is meant to produce.

    Args:
        df: Unused.
        predictions_df: Unused.
        bin_column: Unused.

    Returns:
        dict mapping each planned metric name to a short description.
    """
    reminders = (
        "Note: This function requires model predictions to calculate accuracy metrics.",
        "You'll need to run the model on your test data to get predictions.",
    )
    for reminder in reminders:
        print(reminder)

    planned_metrics = [
        ('magnitude_mae', 'Mean Absolute Error for magnitude predictions'),
        ('magnitude_rmse', 'Root Mean Square Error for magnitude predictions'),
        ('magnitude_correlation', 'Correlation between predicted and actual magnitudes'),
        ('frequency_mae', 'Mean Absolute Error for frequency predictions'),
        ('frequency_rmse', 'Root Mean Square Error for frequency predictions'),
        ('frequency_correlation', 'Correlation between predicted and actual frequencies'),
        ('overall_accuracy', 'Combined accuracy score'),
    ]
    return dict(planned_metrics)

def plot_bin_accuracy_heatmap(accuracy_df, metric='magnitude_mae', bin_column='bin_id'):
    """Plot one per-bin accuracy metric as a bar chart (one bar per bin).

    Despite the name this draws bars, not a 2-D heatmap: bins have no grid
    layout here. Previously-random sample data has been replaced by the
    actual ``accuracy_df[metric]`` values.

    Args:
        accuracy_df: DataFrame with one row per bin and a column named
            ``metric``.
        metric: Column to plot, e.g. ``'magnitude_mae'``.
        bin_column: Column used for x tick labels; the row index is used
            when it is absent.

    Returns:
        None. Prints a message and returns without plotting when there is
        no data, the input is not a DataFrame, or ``metric`` is missing.
    """
    if accuracy_df is None or len(accuracy_df) == 0:
        print("No accuracy data available for plotting.")
        return
    if not isinstance(accuracy_df, pd.DataFrame) or metric not in accuracy_df.columns:
        print(f"Metric '{metric}' not found in accuracy data.")
        return

    values = accuracy_df[metric].to_numpy()
    labels = accuracy_df[bin_column] if bin_column in accuracy_df.columns else accuracy_df.index
    positions = np.arange(len(values))
    pretty_metric = metric.replace('_', ' ').title()

    plt.figure(figsize=(12, 8))
    plt.bar(positions, values, alpha=0.7, color='skyblue')
    plt.title(f'{pretty_metric} Across Bins', fontsize=16, fontweight='bold')
    plt.xlabel('Bin ID')
    plt.ylabel(pretty_metric)
    plt.grid(True, alpha=0.3)
    plt.xticks(positions, [str(label) for label in labels], rotation=45)
    plt.tight_layout()
    plt.show()

print("Accuracy analysis functions defined. Run with actual predictions when available.")

## 7. Spatial Distribution Analysis

Analyze the spatial distribution of earthquake events and spatial bins across different regions. (Mapping model performance by region additionally requires predictions.)

In [None]:
def analyze_spatial_distribution(df, lat_column='latitude', lon_column='longitude', bin_column='bin_id'):
    """Plot where events fall and how records are spread across bins.

    Draws a 2x2 figure:
      * event locations, colored by ``magnitude`` when that column exists;
      * per-bin centroids (mean lon/lat), colored by the number of records;
      * a histogram of ``magnitude`` (hidden if the column is absent);
      * a histogram of ``frequency`` (hidden if the column is absent).

    Args:
        df: Records with coordinate and bin columns.
        lat_column: Latitude column name.
        lon_column: Longitude column name.
        bin_column: Bin identifier column name.

    Returns:
        None. Returns early (after printing the missing names) when any of
        the coordinate/bin columns is absent.
    """
    required_columns = [lat_column, lon_column, bin_column]
    missing_columns = [col for col in required_columns if col not in df.columns]

    if missing_columns:
        print(f"Missing columns: {missing_columns}")
        print(f"Available columns: {list(df.columns)}")
        return None

    _, axes = plt.subplots(2, 2, figsize=(16, 12))

    # Plot 1: event locations. Without magnitudes use one flat color; random
    # colors would suggest structure that is not in the data.
    ax = axes[0, 0]
    if 'magnitude' in df.columns:
        events = ax.scatter(df[lon_column], df[lat_column], c=df['magnitude'],
                            alpha=0.6, s=20, cmap='viridis')
        plt.colorbar(events, ax=ax, label='Magnitude')
    else:
        ax.scatter(df[lon_column], df[lat_column], alpha=0.6, s=20)
    ax.set_title('Earthquake Event Distribution', fontsize=14, fontweight='bold')
    ax.set_xlabel('Longitude')
    ax.set_ylabel('Latitude')
    ax.grid(True, alpha=0.3)

    # Plot 2: bin centroids. Rows are counted with 'size' so this works even
    # without a magnitude column (the old agg dict raised KeyError there).
    bin_centers = df.groupby(bin_column).agg(
        center_lon=(lon_column, 'mean'),
        center_lat=(lat_column, 'mean'),
        event_count=(lon_column, 'size'),
    )
    ax = axes[0, 1]
    centers = ax.scatter(bin_centers['center_lon'], bin_centers['center_lat'],
                         c=bin_centers['event_count'], s=50, alpha=0.7, cmap='plasma')
    ax.set_title('Bin Centroids (Color = Event Count)', fontsize=14, fontweight='bold')
    ax.set_xlabel('Longitude')
    ax.set_ylabel('Latitude')
    ax.grid(True, alpha=0.3)
    plt.colorbar(centers, ax=ax, label='Event Count')

    # Plots 3 and 4: value histograms; NaNs are dropped so the bin range is finite.
    histograms = (
        (axes[1, 0], 'magnitude', 'skyblue', 'Magnitude Distribution', 'Magnitude'),
        (axes[1, 1], 'frequency', 'lightcoral', 'Frequency Distribution', 'Frequency'),
    )
    for ax, column, color, title, xlabel in histograms:
        if column not in df.columns:
            ax.set_visible(False)
            continue
        ax.hist(df[column].dropna(), bins=30, alpha=0.7, color=color, edgecolor='black')
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.set_xlabel(xlabel)
        ax.set_ylabel('Count')
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

# Analyze spatial distribution
if 'df' in locals():
    analyze_spatial_distribution(df)
else:
    print("Data not loaded yet. Please run the data loading cell first.")

## 8. Model Predictions Analysis

This section will analyze the model's predictions when available.

In [None]:
def _plot_prediction_diagnostics(actual, predicted, residuals, mae, rmse, correlation):
    """Draw the 2x2 diagnostic figure for one flattened set of predictions."""
    _, axes = plt.subplots(2, 2, figsize=(16, 12))

    # Plot 1: predictions vs actual, with the y = x reference line.
    ax = axes[0, 0]
    ax.scatter(actual, predicted, alpha=0.6)
    ax.plot([actual.min(), actual.max()], [actual.min(), actual.max()], 'r--', linewidth=2)
    ax.set_title('Predictions vs Actual Values', fontsize=14, fontweight='bold')
    ax.set_xlabel('Actual Values')
    ax.set_ylabel('Predicted Values')
    ax.grid(True, alpha=0.3)

    # Plot 2: residuals against the actual values.
    ax = axes[0, 1]
    ax.scatter(actual, residuals, alpha=0.6)
    ax.axhline(y=0, color='r', linestyle='--', linewidth=2)
    ax.set_title('Residuals vs Actual Values', fontsize=14, fontweight='bold')
    ax.set_xlabel('Actual Values')
    ax.set_ylabel('Residuals')
    ax.grid(True, alpha=0.3)

    # Plot 3: residual histogram (its labels belong on this axis, axes[1, 0]).
    ax = axes[1, 0]
    ax.hist(residuals, bins=30, alpha=0.7, color='lightgreen', edgecolor='black')
    ax.set_title('Residuals Distribution', fontsize=14, fontweight='bold')
    ax.set_xlabel('Residuals')
    ax.set_ylabel('Frequency')
    ax.grid(True, alpha=0.3)

    # Plot 4: text summary of the scalar metrics.
    ax = axes[1, 1]
    ax.text(0.1, 0.8, f'MAE: {mae:.4f}', fontsize=12, transform=ax.transAxes)
    ax.text(0.1, 0.7, f'RMSE: {rmse:.4f}', fontsize=12, transform=ax.transAxes)
    ax.text(0.1, 0.6, f'Correlation: {correlation:.4f}', fontsize=12, transform=ax.transAxes)
    ax.set_title('Performance Metrics', fontsize=14, fontweight='bold')
    ax.axis('off')

    plt.tight_layout()
    plt.show()


def analyze_model_predictions(actual_values, predicted_values, metadata=None, *, plot=True):
    """Compare predictions with actual values: MAE, RMSE, correlation, plots.

    Inputs may be lists, NumPy arrays or CPU tensors. Both are flattened, so
    an ``(n, 1)`` model output paired with ``(n,)`` targets is compared
    element-wise instead of silently broadcasting into an ``(n, n)`` grid.

    Args:
        actual_values: Ground-truth values.
        predicted_values: Model predictions, same number of elements.
        metadata: Currently unused.
        plot: When False, only compute and return the metrics (no figure).

    Returns:
        dict with float ``'mae'``, ``'rmse'`` and ``'correlation'`` (NaN if
        either input is constant), or None when an input is None or empty.

    Raises:
        ValueError: If the two inputs contain different numbers of elements.
    """
    if actual_values is None or predicted_values is None:
        print("No prediction data available for analysis.")
        print("Please run the model on your test data first.")
        return None

    actual = np.asarray(actual_values, dtype=float).ravel()
    predicted = np.asarray(predicted_values, dtype=float).ravel()
    if actual.size != predicted.size:
        raise ValueError(
            f"actual_values and predicted_values differ in size "
            f"({actual.size} vs {predicted.size})"
        )
    if actual.size == 0:
        print("No prediction data available for analysis.")
        return None

    residuals = predicted - actual
    mae = float(np.mean(np.abs(residuals)))
    rmse = float(np.sqrt(np.mean(residuals ** 2)))
    correlation = float(np.corrcoef(predicted, actual)[0, 1])

    if plot:
        _plot_prediction_diagnostics(actual, predicted, residuals, mae, rmse, correlation)

    return {
        'mae': mae,
        'rmse': rmse,
        'correlation': correlation
    }

print("Model predictions analysis function defined. Run with actual predictions when available.")

## 9. Summary and Recommendations

This section provides a summary of the analysis and recommendations for model improvement.

In [None]:
def generate_performance_summary(df, bin_performance=None, zero_freq_analysis=None):
    """Print a plain-text report of the data and any analyses computed so far.

    Args:
        df: Loaded records (may be None). Uses ``date``, ``magnitude`` and
            ``frequency`` columns when present.
        bin_performance: Output of ``analyze_bin_performance``, or None.
        zero_freq_analysis: ``(yearly_zero_freq, yearly_summary)`` tuple from
            ``analyze_zero_frequency_bins``, or None.

    Returns:
        None; everything is written to stdout.
    """
    rule = "=" * 60
    print(rule)
    print("SHARED LSTM PERFORMANCE ANALYSIS SUMMARY")
    print(rule)

    if df is not None:
        columns = df.columns
        print("\n📊 DATA OVERVIEW:")
        print(f"   • Total records: {len(df):,}")
        if 'date' in columns:
            print(f"   • Date range: {df['date'].min()} to {df['date'].max()}")
        else:
            print("   • Date column not found")
        print(f"   • Columns: {len(columns)}")

        for column, label in (('magnitude', 'Magnitude'), ('frequency', 'Frequency')):
            if column in columns:
                series = df[column]
                print(f"   • {label} range: {series.min():.2f} to {series.max():.2f}")

    if bin_performance is not None:
        with_zero = bin_performance[bin_performance['zero_frequency_ratio'] > 0]
        print("\n🗺️  SPATIAL BIN ANALYSIS:")
        print(f"   • Total bins: {len(bin_performance):,}")
        print(f"   • Average events per bin: {bin_performance['total_events'].mean():.2f}")
        print(f"   • Bins with zero frequency: {len(with_zero):,}")

    if zero_freq_analysis is not None:
        per_year_bin, per_year = zero_freq_analysis
        print("\n🔍 ZERO FREQUENCY ANALYSIS:")
        print(f"   • Total zero frequency events: {per_year_bin['zero_freq_count'].sum():,}")
        print(f"   • Years with zero frequency: {len(per_year):,}")
        if len(per_year) > 0:
            yearly_counts = per_year['zero_freq_count']
            peak_year = yearly_counts.idxmax()
            peak_count = yearly_counts.max()
            print(f"   • Year with most zero frequency: {peak_year} ({peak_count:,} events)")

    print("\n📈 RECOMMENDATIONS:")
    recommendations = (
        "Run the model on test data to get actual predictions",
        "Calculate accuracy metrics per bin using predictions vs actual values",
        "Analyze spatial patterns in model performance",
        "Identify bins with poor performance for targeted improvement",
        "Consider temporal patterns in zero frequency bins",
    )
    for number, text in enumerate(recommendations, start=1):
        print(f"   {number}. {text}")

    print("\n" + rule)

# Generate summary
if 'df' not in locals():
    print("Data not loaded yet. Please run the data loading cell first.")
else:
    generate_performance_summary(df,
                                 locals().get('bin_performance'),
                                 locals().get('zero_freq_analysis'))

## 10. Next Steps

To complete the performance analysis, you'll need to:

1. **Load your trained model** - Update the `MODEL_PATH` variable with the correct path
2. **Prepare test data** - Ensure you have a test dataset with the same format as training data
3. **Generate predictions** - Run the model on test data to get predictions
4. **Calculate accuracy metrics** - Use the prediction functions to analyze performance
5. **Customize visualizations** - Modify plots based on your specific needs

The notebook provides a comprehensive framework for analyzing:
- Training history and convergence
- Spatial bin performance
- Zero frequency patterns by year
- Model prediction accuracy
- Spatial distribution of performance

Run each section sequentially to build a complete understanding of your model's performance!