# Module 1: Training and Validation Metrics Analysis

This module provides comprehensive performance visualization for deep learning models (GRU, LSTM, RNN, MLP, Transformer) across different temporal windows (7-day, 14-day, 30-day) for seismic-geomagnetic signal recognition.  

In [None]:
"""
Model Training Performance Visualization with Corrected Legend Positioning
Supports Chinese paths using pathlib for cross-platform compatibility
Legend boxes aligned with plot borders for professional appearance
"""

import os
import re
import json
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any

# ================== Global Configuration ==================

# Root directory that contains one results sub-folder per model family.
# Replace the placeholder with your actual results directory. For example,
# if your project lives at C:/Users/Tian/Desktop/地磁论文代码运行测试, use:
#     BASE_DIR = Path(r"C:\Users\Tian\Desktop\地磁论文代码运行测试\results")
BASE_DIR = Path(r"your_project/results")

# Full model names, built from the architecture keys used throughout the plots
MODEL_NAMES = [
    f"{arch}Model" for arch in ("GRU", "LSTM", "RNN", "MLP", "Transformer")
]

# Each model's files live in "<lower-cased architecture>_models"
MODEL_FOLDERS = {
    name: f"{name[:-len('Model')].lower()}_models" for name in MODEL_NAMES
}

# Time windows: identifiers used in file names, and their display form
TIME_WINDOWS = [f"{days}day" for days in (7, 14, 30)]
WINDOW_DISPLAY = {window: window.replace("day", "-day") for window in TIME_WINDOWS}

# One fixed color per architecture so curves are comparable across figures
COLORS = dict(
    GRU='#E41A1C',          # Red
    LSTM='#377EB8',         # Blue
    RNN='#984EA3',          # Purple
    MLP='#4DAF4A',          # Green
    Transformer='#FF7F00',  # Orange
)

# Matplotlib defaults for the training-curve figures: bold Arial text, large
# fonts, thick lines, no grid, and square black-edged legend frames.
_TRAINING_FIGURE_STYLE = {
    # Fonts
    'font.family': 'Arial',
    'font.weight': 'bold',
    'axes.labelsize': 32,
    'axes.titlesize': 34,
    'legend.fontsize': 21,
    'xtick.labelsize': 26,
    'ytick.labelsize': 26,
    # Axes and figure
    'axes.grid': False,
    'figure.figsize': (18, 16),
    # Legend frame
    'legend.frameon': True,
    'legend.fancybox': False,
    'legend.edgecolor': 'black',
    # Line weights
    'axes.linewidth': 1.5,
    'lines.linewidth': 4.5,
}
plt.rcParams.update(_TRAINING_FIGURE_STYLE)


def get_model_key(model_name: str) -> str:
    """
    Derive the short architecture key from a full model name.

    Every occurrence of the substring "Model" is removed and surrounding
    whitespace is trimmed.

    Args:
        model_name: Full model name (e.g., 'GRUModel')

    Returns:
        Model key (e.g., 'GRU')
    """
    pieces = model_name.split("Model")
    return "".join(pieces).strip()


def read_ensemble_weights(model_name: str, window_name: str) -> List[float]:
    """
    Read ensemble weights from the configuration file.

    The weights are read from the "ensemble_weights" entry of
    ``<model_name>_<window_name>_ensemble_config.json`` inside the model's
    folder and rescaled so that they sum to 1.

    Any problem with the file (missing, unreadable, invalid JSON, wrong
    structure, non-numeric or non-positive-sum weights) results in the
    equal-weight default instead of an exception, so one bad config cannot
    abort the whole run.

    Args:
        model_name: Name of the model (a key of MODEL_FOLDERS)
        window_name: Time window identifier

    Returns:
        List of normalized ensemble weights, defaults to equal weights
        for 5 folds if no usable weights are found
    """
    folder_path = BASE_DIR / MODEL_FOLDERS[model_name]
    config_path = folder_path / f"{model_name}_{window_name}_ensemble_config.json"

    # Default equal weights for 5 folds
    default_weights = [0.2] * 5

    if not config_path.exists():
        print(f"Config not found: {config_path.name}, using equal weights")
        return default_weights

    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError
        print(f"Error reading config {config_path.name}: {e}")
        return default_weights

    # The JSON root may legally be a list/number/string; only a dict is usable
    raw_weights = config.get("ensemble_weights") if isinstance(config, dict) else None
    if not raw_weights:
        return default_weights

    if not isinstance(raw_weights, (list, tuple)):
        print(f"Error reading config {config_path.name}: "
              f"'ensemble_weights' is not a list")
        return default_weights

    try:
        weights = [float(w) for w in raw_weights]
    except (TypeError, ValueError) as e:
        print(f"Error reading config {config_path.name}: {e}")
        return default_weights

    # Normalize weights to sum to 1; a NaN/inf/non-positive total is unusable
    total = sum(weights)
    if not np.isfinite(total) or total <= 0:
        return default_weights

    return [w / total for w in weights]


def parse_training_log(log_file: Path) -> Dict[str, List[float]]:
    """
    Parse training log file to extract epoch metrics.

    Only lines that start with 'Epoch' are considered, and a line is used
    only if it contains all five fields: 'Epoch <n>/', 'Train Loss:',
    'Train Acc:', 'Val Loss:' and 'Val Acc:'. Numeric values may be plain
    decimals or scientific notation (e.g. '1.5e-05'), optionally signed.

    Args:
        log_file: Path to the log file

    Returns:
        Dictionary with the keys 'epoch', 'train_loss', 'val_loss',
        'train_acc' and 'val_acc', each mapping to a list with one entry
        per parsed line. All lists are empty if the file is missing or
        cannot be read.
    """
    data = {
        'epoch': [],
        'train_loss': [],
        'val_loss': [],
        'train_acc': [],
        'val_acc': []
    }

    if not log_file.exists():
        print(f"Log file not found: {log_file.name}")
        return data

    # A well-formed float: digits with at most one dot, optional sign and
    # optional exponent. A bare character class such as [\d.]+ would cut
    # '1.5e-05' down to '1.5' and would capture a trailing sentence dot.
    number = r'([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)'
    float_fields = ('train_loss', 'train_acc', 'val_loss', 'val_acc')
    patterns = {
        'epoch': re.compile(r'Epoch\s+(\d+)/'),
        'train_loss': re.compile(r'Train Loss:\s+' + number),
        'train_acc': re.compile(r'Train Acc:\s+' + number),
        'val_loss': re.compile(r'Val Loss:\s+' + number),
        'val_acc': re.compile(r'Val Acc:\s+' + number),
    }

    try:
        with open(log_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()
    except (OSError, ValueError) as e:
        # ValueError covers UnicodeDecodeError for non-UTF-8 files
        print(f"Error parsing log file {log_file.name}: {e}")
        return data

    for line in lines:
        if not line.startswith('Epoch'):
            continue

        matches = {key: pattern.search(line) for key, pattern in patterns.items()}

        # Only add data if all metrics are found, so the lists stay aligned
        if not all(matches.values()):
            continue

        data['epoch'].append(int(matches['epoch'].group(1)))
        for key in float_fields:
            data[key].append(float(matches[key].group(1)))

    return data


def find_fold_logs(model_name: str, window_name: str) -> List[Path]:
    """
    Find all fold log files for a specific model and time window.

    Files are matched by ``<model_name>_<window_name>_fold_*_logs.txt`` in
    the model's folder and ordered by their numeric fold index, so that
    'fold_10' comes after 'fold_2'. Callers pair the result positionally
    with per-fold weights, which a plain lexicographic sort would
    misalign once there are ten or more folds. Files whose fold part is
    not a number are placed last, ordered by name.

    Args:
        model_name: Name of the model (a key of MODEL_FOLDERS)
        window_name: Time window identifier

    Returns:
        List of log file paths sorted by fold number; empty if the folder
        or the files do not exist
    """
    folder_path = BASE_DIR / MODEL_FOLDERS[model_name]

    if not folder_path.exists():
        print(f"Model folder does not exist: {folder_path}")
        return []

    # Pattern for fold log files
    pattern = f"{model_name}_{window_name}_fold_*_logs.txt"
    fold_number = re.compile(r"_fold_(\d+)_logs\.txt$")

    def fold_order(path: Path) -> Tuple[int, int, str]:
        """Sort key: numeric fold index first, non-numeric names last."""
        match = fold_number.search(path.name)
        if match:
            return (0, int(match.group(1)), path.name)
        return (1, 0, path.name)

    log_files = sorted(folder_path.glob(pattern), key=fold_order)

    if not log_files:
        print(f"No log files found matching pattern: {pattern}")
        # List a few available files to help diagnose naming mismatches
        available_logs = list(folder_path.glob("*_logs.txt"))
        if available_logs:
            print(f"Available log files: {[f.name for f in available_logs[:3]]}")
    else:
        print(f"Found {len(log_files)} log file(s) for {model_name} {window_name}")

    return log_files


def calculate_weighted_statistics(
    data_array: np.ndarray, 
    weights: List[float]
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Calculate weighted mean and standard error for ensemble data.

    For every epoch (column) the weighted mean is ``sum(w * x)`` and the
    weighted variance is ``sum(w * (x - mean) ** 2)``; both formulas use the
    weights as given, so they are only meaningful when the weights sum
    to 1. The reported error is ``sqrt(variance) / sqrt(n_folds)``.

    Args:
        data_array: 2D array (folds x epochs)
        weights: Weights for each fold, one per row of data_array

    Returns:
        Tuple of (weighted means, standard errors), each of length n_epochs

    Raises:
        ValueError: If data_array is not 2D or the number of weights does
            not equal the number of folds. Without this check a single
            weight would silently broadcast over all folds.
    """
    values = np.asarray(data_array, dtype=float)
    if values.ndim != 2:
        raise ValueError(
            f"data_array must be 2D (folds x epochs), got {values.ndim}D"
        )

    n_folds = values.shape[0]
    fold_weights = np.asarray(weights, dtype=float)
    if fold_weights.shape != (n_folds,):
        raise ValueError(
            f"Expected {n_folds} weights (one per fold), "
            f"got shape {fold_weights.shape}"
        )

    # Column vector so each fold's weight applies across all of its epochs
    column_weights = fold_weights[:, np.newaxis]

    means = np.sum(column_weights * values, axis=0)
    weighted_var = np.sum(column_weights * (values - means) ** 2, axis=0)
    errors = np.sqrt(weighted_var) / np.sqrt(n_folds)

    return means, errors


def combine_fold_results(model_name: str, window_name: str) -> Optional[Dict[str, np.ndarray]]:
    """
    Combine results from multiple folds using weighted averaging.

    Each fold log is paired with the ensemble weight at the same position.
    A fold whose log yields no epochs is dropped together with its weight,
    so the remaining weights stay attached to the folds they belong to.
    The surviving weights are renormalized to sum to 1, and all folds are
    truncated to the shortest fold's epoch count before averaging.

    Args:
        model_name: Name of the model
        window_name: Time window identifier

    Returns:
        Dictionary containing 'epoch' (1..min_epochs) plus '<metric>_mean'
        and '<metric>_err' arrays for train_loss, val_loss, train_acc and
        val_acc, or None if no data available
    """
    weights = read_ensemble_weights(model_name, window_name)
    log_files = find_fold_logs(model_name, window_name)

    if not log_files:
        print(f"No log files found for {model_name} {window_name}")
        return None

    # Fewer weights than folds cannot be paired meaningfully
    if len(weights) < len(log_files):
        print(f"Only {len(weights)} ensemble weight(s) for {len(log_files)} "
              f"fold(s) of {model_name} {window_name}, using equal weights")
        weights = [1.0] * len(log_files)

    # Parse all fold data, keeping each valid fold's own weight
    folds_data = []
    fold_weights = []
    for log_file, weight in zip(log_files, weights):
        parsed = parse_training_log(log_file)
        if parsed['epoch']:
            folds_data.append(parsed)
            fold_weights.append(weight)

    if not folds_data:
        print(f"No valid data found in log files for {model_name} {window_name}")
        return None

    n_folds = len(folds_data)

    # Renormalize over the folds actually used
    weight_sum = sum(fold_weights)
    if weight_sum <= 0:
        used_weights = [1.0 / n_folds] * n_folds
    else:
        used_weights = [w / weight_sum for w in fold_weights]

    # Folds may have stopped at different epochs; compare the common prefix
    min_epochs = min(len(fd['epoch']) for fd in folds_data)

    # One (folds x epochs) array per metric
    metric_names = ('train_loss', 'val_loss', 'train_acc', 'val_acc')
    metrics = {
        name: np.array([fd[name][:min_epochs] for fd in folds_data], dtype=float)
        for name in metric_names
    }

    # Calculate weighted statistics for each metric
    combined = {'epoch': np.arange(1, min_epochs + 1)}

    for metric_name, metric_data in metrics.items():
        mean, error = calculate_weighted_statistics(metric_data, used_weights)
        combined[f'{metric_name}_mean'] = mean
        combined[f'{metric_name}_err'] = error

    return combined


def create_performance_plots(window_name: str, results_dict: Dict[str, Dict]) -> None:
    """
    Create 2x2 subplot of training and validation metrics with transparent legend background.

    The four panels show training accuracy, validation accuracy, training
    loss and validation loss, with one curve per model. Legend entries
    carry the final-epoch value and are sorted by final validation
    accuracy (descending). The figure is written to
    ``BASE_DIR/performance_visualization/Performance_Analysis``; if that
    location cannot be written, it is saved under the system temporary
    directory instead. The figure is always closed, even if saving fails.

    Args:
        window_name: Time window identifier
        results_dict: Dictionary mapping model names to their combined
            results (the '<metric>_mean' arrays and 'epoch' are read)
    """
    import tempfile  # only needed for the fallback save location

    window_display = WINDOW_DISPLAY.get(window_name, window_name)

    # Create figure with 2x2 subplots
    fig, axes = plt.subplots(2, 2, figsize=(25, 18))

    try:
        # Configure subplot properties
        subplot_config = [
            {'ax': axes[0, 0], 'title': 'Training Accuracy', 'metric': 'train_acc', 'style': '-'},
            {'ax': axes[0, 1], 'title': 'Validation Accuracy', 'metric': 'val_acc', 'style': '--'},
            {'ax': axes[1, 0], 'title': 'Training Loss', 'metric': 'train_loss', 'style': '-'},
            {'ax': axes[1, 1], 'title': 'Validation Loss', 'metric': 'val_loss', 'style': '--'}
        ]

        # Legend corner and anchor per panel; the anchor sits on the axes
        # border. The 30-day validation-loss legend is raised to a
        # hand-tuned height.
        legend_anchor = {
            'train_acc': ('upper right', (1.0, 1.0)),
            'val_acc': ('upper right', (1.0, 1.0)),
            'train_loss': ('lower right', (1.0, 0.0)),
            'val_loss': ('lower right',
                         (1.0, 0.35) if window_name == "30day" else (1.0, 0.0)),
        }

        # Collect data ranges for dynamic y-axis scaling
        data_ranges = {
            'train_acc': [], 'val_acc': [],
            'train_loss': [], 'val_loss': []
        }

        for data in results_dict.values():
            for metric in data_ranges:
                data_ranges[metric].extend(data[f'{metric}_mean'])

        # Set up each subplot
        for config in subplot_config:
            ax = config['ax']
            ax.set_xlabel('Epoch', fontweight='bold')
            ax.set_ylabel(config['title'], fontweight='bold')
            ax.tick_params(direction='in', width=1.5)

            # Set dynamic y-axis range with 10% padding
            metric = config['metric']
            if data_ranges[metric]:
                data_min = min(data_ranges[metric])
                data_max = max(data_ranges[metric])
                data_range = data_max - data_min
                if data_range > 0:
                    padding = data_range * 0.1
                else:
                    # All values identical: pad explicitly rather than
                    # requesting equal lower and upper limits
                    padding = max(abs(data_max) * 0.1, 0.01)
                ax.set_ylim(data_min - padding, data_max + padding)

            # Configure spines
            for spine in ax.spines.values():
                spine.set_visible(True)
                spine.set_linewidth(1.5)

        # Track final validation accuracy for legend sorting
        model_performance = {}

        # Plot data for each model
        for model_name, data in results_dict.items():
            model_key = get_model_key(model_name)
            color = COLORS.get(model_key, '#000000')

            epochs = data['epoch']

            # Extract final epoch values for legend labels
            final_values = {
                metric: data[f'{metric}_mean'][-1]
                for metric in ('train_acc', 'val_acc', 'train_loss', 'val_loss')
            }

            model_performance[model_key] = final_values['val_acc']

            # Plot on each subplot
            for config in subplot_config:
                metric = config['metric']

                # Create label with final value
                label = f"{model_key} ({final_values[metric]:.3f})"

                config['ax'].plot(epochs, data[f'{metric}_mean'],
                                  color=color,
                                  linestyle=config['style'],
                                  linewidth=4.5,
                                  label=label)

        # Sort legends by validation accuracy (descending) and position at plot borders
        for config in subplot_config:
            ax = config['ax']
            handles, labels = ax.get_legend_handles_labels()

            if not handles:
                continue

            # The label starts with the model key, followed by "(value)"
            model_names = [label.split()[0] for label in labels]
            sorted_indices = sorted(range(len(model_names)),
                                    key=lambda idx: model_performance.get(model_names[idx], 0),
                                    reverse=True)

            loc, bbox = legend_anchor[config['metric']]

            # Create sorted legend aligned with plot border
            legend = ax.legend([handles[idx] for idx in sorted_indices],
                               [labels[idx] for idx in sorted_indices],
                               loc=loc,
                               bbox_to_anchor=bbox,
                               frameon=True,
                               borderaxespad=0.0)  # No padding from border

            # Style legend with transparent background and visible border
            frame = legend.get_frame()
            frame.set_edgecolor('black')
            frame.set_linewidth(1.0)
            frame.set_facecolor('none')  # Transparent background
            frame.set_alpha(1.0)  # Full opacity for border

        # Adjust layout
        fig.subplots_adjust(left=0.06, right=0.94, bottom=0.06, top=0.94,
                            wspace=0.25, hspace=0.25)

        # Save figure
        file_name = f"Model_Performance_{window_display}.png"
        save_dir = BASE_DIR / "performance_visualization" / "Performance_Analysis"

        try:
            save_dir.mkdir(parents=True, exist_ok=True)
            fig.savefig(save_dir / file_name, dpi=600, bbox_inches='tight')
            print(f"Successfully saved: {file_name}")
        except OSError as e:
            print(f"Error saving figure: {e}")
            # Try alternative save location that exists on every platform
            alt_dir = Path(tempfile.gettempdir()) / "results"
            alt_dir.mkdir(parents=True, exist_ok=True)
            alt_file = alt_dir / file_name
            fig.savefig(alt_file, dpi=600, bbox_inches='tight')
            print(f"Saved to alternative location: {alt_file}")
    finally:
        # Release the figure even when plotting or saving raised
        plt.close(fig)


def main():
    """
    Generate one performance figure per time window.

    Prints a banner and the configuration, reports how many log files each
    model folder contains, then combines the fold results of every model
    for each window and plots them. Ends with a summary of the generated
    files. Returns early if BASE_DIR does not exist.
    """
    banner = "=" * 70
    rule = "-" * 70

    print(banner)
    print("Model Training Performance Visualization")
    print(banner)

    # Nothing can be done without the results directory
    if not BASE_DIR.exists():
        print(f"ERROR: Base directory does not exist: {BASE_DIR}")
        print("Please check the path configuration.")
        return

    print(f"Base directory: {BASE_DIR}")
    print(f"Models to process: {', '.join(MODEL_NAMES)}")
    print(f"Time windows: {', '.join(TIME_WINDOWS)}")
    print(rule)

    # Report which model folders are present and how many logs they hold
    print("\nChecking model folders:")
    for model_name, folder_name in MODEL_FOLDERS.items():
        candidate = BASE_DIR / folder_name
        if not candidate.exists():
            print(f"  ✗ {model_name:20} folder not found")
            continue
        log_count = sum(1 for _ in candidate.glob("*_logs.txt"))
        print(f"  ✓ {model_name:20} {log_count} log files found")

    print(rule)

    # One figure per time window
    generated_files = []

    for window in TIME_WINDOWS:
        window_label = WINDOW_DISPLAY[window]
        print(f"\nProcessing {window_label} window:")

        # Combined fold results of every model that has usable data
        results = {}

        for model in MODEL_NAMES:
            print(f"  Processing {model}...", end=" ")
            combined_data = combine_fold_results(model, window)

            if combined_data is None:
                print("✗ (no data)")
            else:
                results[model] = combined_data
                print("✓")

        if not results:
            print(f"  No valid data for {window_label} window, skipping plot.")
            continue

        print(f"  Generating plot for {len(results)} models...")
        create_performance_plots(window, results)
        generated_files.append(f"Model_Performance_{window_label}.png")

    # Final summary
    print("\n" + banner)
    if not generated_files:
        print("WARNING: No visualizations were generated.")
        print("Please check that log files exist and are properly formatted.")
    else:
        save_dir = BASE_DIR / "performance_visualization" / "Performance_Analysis"
        print(f"SUCCESS: Generated {len(generated_files)} visualization(s)")
        print(f"Output directory: {save_dir}")
        for filename in generated_files:
            print(f"  - {filename}")
    print(banner)


if __name__ == "__main__":
    main()

Model Training Performance Visualization
Base directory: C:\Users\Tian\Desktop\地磁论文代码运行测试\results
Models to process: GRUModel, LSTMModel, RNNModel, MLPModel, TransformerModel
Time windows: 7day, 14day, 30day
----------------------------------------------------------------------

Checking model folders:
  ✓ GRUModel             15 log files found
  ✓ LSTMModel            15 log files found
  ✓ RNNModel             15 log files found
  ✓ MLPModel             15 log files found
  ✓ TransformerModel     15 log files found
----------------------------------------------------------------------

Processing 7-day window:
  Processing GRUModel... Found 5 log file(s) for GRUModel 7day
✓
  Processing LSTMModel... Found 5 log file(s) for LSTMModel 7day
✓
  Processing RNNModel... Found 5 log file(s) for RNNModel 7day
✓
  Processing MLPModel... Found 5 log file(s) for MLPModel 7day
✓
  Processing TransformerModel... Found 5 log file(s) for TransformerModel 7day
✓
  Generating plot for 5 models...
Su

# Module 2: Model Performance Comparison Across Time Windows

This module provides comprehensive comparative analysis and visualization of model performance across different temporal windows (7-day, 14-day, 30-day) for seismic-geomagnetic signal recognition. It generates radar charts and scatter plots to evaluate model effectiveness, stability, and trade-offs.

In [None]:
"""
Model Performance Comparison with Radar Charts and Scatter Plots
Supports Chinese paths using pathlib for cross-platform compatibility
Generates comprehensive performance visualizations for multiple models
"""

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import json
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any
from scipy import stats

# Matplotlib defaults for the comparison figures: bold Arial text, large
# labels, no grid, and frameless legends.
_COMPARISON_FIGURE_STYLE = {
    # Fonts
    'font.family': 'Arial',
    'font.weight': 'bold',
    'axes.labelweight': 'bold',
    'axes.labelsize': 26,
    'axes.titlesize': 26,
    'legend.fontsize': 16,
    'xtick.labelsize': 22,
    'ytick.labelsize': 22,
    # Axes and legend
    'axes.grid': False,
    'legend.frameon': False,
}
plt.rcParams.update(_COMPARISON_FIGURE_STYLE)

# One fixed color per model so every figure uses the same encoding
COLORS = dict(
    GRU='#E41A1C',          # Red
    LSTM='#377EB8',         # Blue
    MLP='#4DAF4A',          # Green
    RNN='#984EA3',          # Purple
    Transformer='#FF7F00',  # Orange
)

# One marker shape per model for the scatter plots
MARKERS = dict(GRU='o', LSTM='s', MLP='^', RNN='v', Transformer='X')

# Display name of each metric -> key used in the extracted metrics dicts
METRIC_MAPPING = dict(zip(
    ('F1 Score', 'Precision', 'Recall', 'Specificity', 'Norm MCC'),
    ('f1', 'precision', 'recall', 'specificity', 'norm_mcc'),
))

# Each model's files live in "<lower-cased model name>_models"
MODEL_FOLDERS = {
    model: f"{model.lower()}_models"
    for model in ("GRU", "LSTM", "MLP", "RNN", "Transformer")
}


def format_tick_label(value: float) -> str:
    """
    Render a tick value with two decimals, never as "-0.00".

    Values whose magnitude is below 0.005 are shown as "0.00" regardless
    of sign; everything else uses plain two-decimal formatting.

    Args:
        value: The numeric value to format

    Returns:
        Formatted string representation
    """
    is_effectively_zero = abs(value) < 0.005
    return "0.00" if is_effectively_zero else f"{value:.2f}"


def load_model_configurations(BASE_DIR: Path) -> Dict[str, Dict]:
    """
    Read every available ensemble configuration file below BASE_DIR.

    For each model whose folder exists, the files
    ``<model>Model_<window>_ensemble_config.json`` for the 7-, 14- and
    30-day windows are loaded. Missing files are skipped silently; files
    that fail to load are reported and skipped.

    Args:
        BASE_DIR: Base directory containing model folders

    Returns:
        Dictionary containing model configurations organized by model and
        window. A model with an existing folder but no loadable file maps
        to an empty dict; a model without a folder is absent.
    """
    all_model_data: Dict[str, Dict] = {}

    for model_name in ("GRU", "LSTM", "MLP", "RNN", "Transformer"):
        folder_name = MODEL_FOLDERS.get(model_name)
        if not folder_name:
            continue

        model_path = BASE_DIR / folder_name
        if not model_path.exists():
            print(f"Model folder not found: {model_path}")
            continue

        per_window = all_model_data.setdefault(model_name, {})

        for window_period in ("7day", "14day", "30day"):
            config_file = model_path / f"{model_name}Model_{window_period}_ensemble_config.json"
            if not config_file.exists():
                continue

            try:
                with open(config_file, 'r', encoding='utf-8') as f:
                    per_window[window_period] = json.load(f)
                print(f"Loaded: {config_file.name}")
            except Exception as e:
                print(f"Error loading {config_file.name}: {e}")

    return all_model_data


def extract_performance_metrics(model_data: Dict) -> Dict[str, Dict]:
    """
    Pull the averaged metrics and their spreads out of raw configurations.

    For every model/window whose configuration has an 'average_metrics'
    entry, the mean values are read from the 'avg_*' keys (default 0 when
    absent) and the standard deviations from the 'std_*' keys (default
    0.01 when absent). Windows without 'average_metrics' map to an empty
    dict.

    Args:
        model_data: Raw model configuration data

    Returns:
        Organized performance metrics by model and window
    """
    # (output key, plural suffix used in the 'avg_' / 'std_' source keys)
    metric_sources = (
        ('f1', 'f1_scores'),
        ('precision', 'precisions'),
        ('recall', 'recalls'),
        ('specificity', 'specificities'),
        ('mcc', 'mccs'),
        ('norm_mcc', 'norm_mccs'),
    )

    metrics = {}

    for model_name, windows in model_data.items():
        per_model = metrics[model_name] = {}

        for window_period, config in windows.items():
            extracted = per_model[window_period] = {}

            if 'average_metrics' not in config:
                continue
            averages = config['average_metrics']

            # Mean values first, then standard deviations
            for key, suffix in metric_sources:
                extracted[key] = averages.get(f'avg_{suffix}', 0)
            for key, suffix in metric_sources:
                extracted[f'{key}_std'] = averages.get(f'std_{suffix}', 0.01)

    return metrics


def create_radar_chart(window_data: Dict, metric_names: List[str], output_file: Path) -> None:
    """
    Draw one closed polygon per model on a polar axis and save the figure.

    Each metric gets one spoke; spokes start at the top and run clockwise.
    The radial axis is fixed to [0, 1]. A metric missing for a model is
    plotted as 0.0.

    Args:
        window_data: Performance data for all models in a specific window
        metric_names: List of metric names to display
        output_file: Path to save the output figure
    """
    fig = plt.figure(figsize=(10, 10), dpi=300)
    ax = fig.add_subplot(111, projection='polar')

    models = list(window_data.keys())

    # One spoke per metric; repeat the first angle to close each polygon
    spoke_angles = np.linspace(0, 2 * np.pi, len(metric_names), endpoint=False).tolist()
    closed_angles = spoke_angles + spoke_angles[:1]

    # Orientation and background
    ax.set_theta_offset(np.pi / 2)
    ax.set_theta_direction(-1)
    ax.set_facecolor("#ffffff")

    # Grid and radial scale
    ax.grid(True, linestyle='-', linewidth=0.5, color='gray', alpha=0.3)
    ax.set_ylim(0, 1)
    ax.set_rticks([0.2, 0.4, 0.6, 0.8, 1.0])
    ax.tick_params(axis='y', labelsize=22)

    # Keep the spokes but suppress the default angular labels
    ax.set_xticks(spoke_angles)
    ax.set_xticklabels([])

    # Metric names are drawn manually; two of them are nudged sideways
    angle_nudges = {'Norm MCC': -0.1, 'Precision': 0.1}
    for angle, metric in zip(spoke_angles, metric_names):
        ax.text(angle + angle_nudges.get(metric, 0.0), 1.15, metric,
                ha='center', va='center', fontsize=26, fontweight='bold')

    # One polygon per model
    for model in models:
        model_metrics = window_data[model]
        values = [
            model_metrics.get(
                METRIC_MAPPING.get(metric, metric.lower().replace(' ', '_')),
                0.0,
            )
            for metric in metric_names
        ]

        ax.plot(closed_angles, values + values[:1],
                color=COLORS.get(model, '#000000'), linewidth=2.5,
                marker='o', markersize=10,
                label=model)

    plt.tight_layout()
    plt.savefig(output_file, dpi=300, bbox_inches='tight', facecolor='white')
    plt.close(fig)
    print(f"Generated: {output_file.name}")


def _padded_axis_limits(values: List[float], margin_ratio: float = 0.05,
                        fallback_margin: float = 0.05) -> Tuple[float, float]:
    """Return (low, high) around values, padded by a share of their span."""
    low, high = min(values), max(values)
    margin = (high - low) * margin_ratio
    if margin <= 0:
        # All values identical: a zero margin would collapse the axis
        margin = fallback_margin
    return low - margin, high + margin


def create_scatter_plot(window_data: Dict, x_metric: str, y_metric: str, 
                       models: List[str], output_file: Path) -> None:
    """
    Create a scatter plot comparing two metrics across models.

    Every model that has both metrics is drawn as one hollow marker. Both
    axes get five evenly spaced ticks spanning the data plus a 5% margin.
    A dashed y = x reference line is added unless either metric is a
    standard deviation (its name ends in '_std'), where such a line has
    no meaning. Models missing from window_data are skipped. The figure
    is closed even if saving fails.

    Args:
        window_data: Performance data for all models
        x_metric: Metric for x-axis
        y_metric: Metric for y-axis
        models: List of model names
        output_file: Path to save the output figure
    """
    fig, ax = plt.subplots(figsize=(8, 8), dpi=300)

    try:
        # Collect (model, x, y) for every model that has both metrics
        points = []
        for model in models:
            model_metrics = window_data.get(model, {})
            if x_metric in model_metrics and y_metric in model_metrics:
                points.append((model, model_metrics[x_metric], model_metrics[y_metric]))

        # Plot scatter points
        for model, x_value, y_value in points:
            ax.scatter(x_value, y_value,
                       s=400, c='none',
                       marker=MARKERS.get(model, 'o'),
                       edgecolors=COLORS.get(model, 'black'),
                       linewidth=2.5, label=model)

        # Format axis labels
        x_label = x_metric.replace('_', ' ').title().replace('F 1', 'F1')
        y_label = y_metric.replace('_', ' ').title().replace('F 1', 'F1')

        ax.set_xlabel(x_label, fontsize=26, fontweight='bold')
        ax.set_ylabel(y_label, fontsize=26, fontweight='bold')

        # Set axis ticks with proper formatting
        if points:
            x_low, x_high = _padded_axis_limits([x for _, x, _ in points])
            y_low, y_high = _padded_axis_limits([y for _, _, y in points])

            x_ticks = np.linspace(x_low, x_high, 5)
            ax.set_xticks(x_ticks)
            ax.set_xticklabels([format_tick_label(tick) for tick in x_ticks])

            y_ticks = np.linspace(y_low, y_high, 5)
            ax.set_yticks(y_ticks)
            ax.set_yticklabels([format_tick_label(tick) for tick in y_ticks])

            ax.set_xlim(x_low, x_high)
            ax.set_ylim(y_low, y_high)

        # Configure spines
        for spine in ax.spines.values():
            spine.set_visible(True)
            spine.set_linewidth(1.5)
            spine.set_edgecolor('black')

        # Add diagonal reference line only when both axes show mean metrics
        if not x_metric.endswith('_std') and not y_metric.endswith('_std'):
            lims = [
                np.min([ax.get_xlim(), ax.get_ylim()]),
                np.max([ax.get_xlim(), ax.get_ylim()]),
            ]
            ax.plot(lims, lims, 'k--', alpha=0.3, linewidth=1)

        plt.tight_layout()
        plt.savefig(output_file, dpi=300, bbox_inches='tight', facecolor='white')
        print(f"Generated: {output_file.name}")
    finally:
        # Release the figure even when plotting or saving raised
        plt.close(fig)


def _padded_axis_limits(values: List[float], pad_fraction: float = 0.05) -> tuple:
    """Return ``(low, high)`` limits that pad the range of *values* on both sides."""
    low, high = min(values), max(values)
    span = high - low
    if span > 0:
        pad = span * pad_fraction
    else:
        # Degenerate range (a single point, or all values equal): pad relative
        # to the magnitude instead, so the limits never collapse to low == high.
        pad = abs(low) * pad_fraction or pad_fraction
    return low - pad, high + pad


def create_stability_plot(window_data: Dict, metric: str, models: List[str], 
                         output_file: Path) -> None:
    """
    Create a stability scatter plot showing mean vs standard deviation.

    One hollow marker is drawn per model at (mean, std), where the mean is read
    from ``window_data[model][metric]`` and the standard deviation from
    ``window_data[model][f"{metric}_std"]``. Models that are missing from
    ``window_data`` or lack either key are skipped.

    Both axes get five evenly spaced ticks spanning the data range plus a 5%
    margin. If all values on an axis are identical (e.g. only one model is
    available), the margin is derived from the value's magnitude so that the
    axis limits and ticks stay distinct.

    Args:
        window_data: Performance data for all models
        metric: Base metric name
        models: List of model names
        output_file: Path to save the output figure
    """
    fig, ax = plt.subplots(figsize=(8, 8), dpi=300)

    # Collect (model, mean, std) for every model that provides both values
    std_key = f"{metric}_std"
    points = []
    for model in models:
        stats = window_data.get(model, {})
        if metric in stats and std_key in stats:
            points.append((model, stats[metric], stats[std_key]))

    # Plot scatter points (hollow markers, colored edge per model)
    for model, mean_value, std_value in points:
        ax.scatter(mean_value, std_value,
                   s=400, c='none',
                   marker=MARKERS.get(model, 'o'),
                   edgecolors=COLORS.get(model, 'black'),
                   linewidth=2.5, label=model)

    # Format labels
    metric_name = metric.replace('_', ' ').title().replace('F 1', 'F1')
    ax.set_xlabel(f'Mean {metric_name}', fontsize=26, fontweight='bold')
    ax.set_ylabel(f'{metric_name} S.D.', fontsize=26, fontweight='bold')

    # Set axis ticks with proper formatting
    if points:
        x_low, x_high = _padded_axis_limits([mean for _, mean, _ in points])
        x_ticks = np.linspace(x_low, x_high, 5)
        ax.set_xticks(x_ticks)
        ax.set_xticklabels([format_tick_label(tick) for tick in x_ticks])

        y_low, y_high = _padded_axis_limits([std for _, _, std in points])
        y_ticks = np.linspace(y_low, y_high, 5)
        ax.set_yticks(y_ticks)
        ax.set_yticklabels([format_tick_label(tick) for tick in y_ticks])

        ax.set_xlim(x_low, x_high)
        ax.set_ylim(y_low, y_high)

    # Configure spines
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(1.5)
        spine.set_edgecolor('black')

    plt.tight_layout()
    plt.savefig(output_file, dpi=300, bbox_inches='tight', facecolor='white')
    plt.close(fig)
    print(f"Generated: {output_file.name}")


def create_legend_figures(output_dir: Path) -> None:
    """
    Render two standalone legend images, one entry per model in COLORS.

    ``legend_radar.png`` uses colored lines with round markers (radar charts);
    ``legend_scatter.png`` uses hollow, model-specific markers taken from
    MARKERS (scatter plots). Both are written into *output_dir*.

    Args:
        output_dir: Directory to save the legend figures
    """
    def _render_legend(make_handle, filename):
        # A wide, short, axis-free canvas that holds nothing but the legend
        fig, ax = plt.subplots(figsize=(12, 1.5), dpi=300)
        ax.axis('off')

        handles = [make_handle(model, color) for model, color in COLORS.items()]
        ax.legend(handles=handles, loc='center', ncol=5, fontsize=20,
                  frameon=False, columnspacing=2)

        plt.tight_layout()
        target = output_dir / filename
        plt.savefig(target, dpi=300, bbox_inches='tight', facecolor='white')
        plt.close(fig)
        print(f"Generated: {target.name}")

    def _line_handle(model, color):
        # Line style entry (radar charts)
        return plt.Line2D([0], [0], color=color, linewidth=3,
                          marker='o', markersize=12, label=model)

    def _marker_handle(model, color):
        # Marker-only entry (scatter plots): hollow face, colored edge
        return plt.Line2D([0], [0], marker=MARKERS[model], color='w',
                          markerfacecolor='none', markersize=18,
                          markeredgecolor=color, markeredgewidth=2.5,
                          linestyle='None', label=model)

    _render_legend(_line_handle, 'legend_radar.png')
    _render_legend(_marker_handle, 'legend_scatter.png')


def main():
    """
    Main execution function to generate all visualization figures.

    Loads the model configurations found under the results directory, extracts
    per-window performance metrics, then requests one radar chart per time
    window, three scatter plots per time window (stability, correlation,
    trade-off) and two legend figures. Windows without any model data are
    skipped, and the final summary reports how many figures were actually
    requested rather than assuming the full set.
    """
    # Configure paths using pathlib
    BASE_DIR = Path(r"your_project/results")
    # Change to your actual results directory, for example:
    #   BASE_DIR = Path(r"C:\Users\<name>\Desktop\<project>\results")
    # Non-ASCII (e.g. Chinese) folder names are fine.
    output_dir = BASE_DIR / "performance_visualization" / "Model_Comparison_Analysis"

    # Create output directory
    output_dir.mkdir(parents=True, exist_ok=True)

    print("=" * 70)
    print("Model Performance Comparison Visualization")
    print("=" * 70)
    print(f"Base directory: {BASE_DIR}")
    print(f"Output directory: {output_dir}")
    print("-" * 70)

    # Load and process data
    print("\nLoading model configurations...")
    model_data = load_model_configurations(BASE_DIR)

    if not model_data:
        print("ERROR: No model data found. Please check the base directory.")
        return

    print("\nExtracting performance metrics...")
    metrics = extract_performance_metrics(model_data)

    # Configuration
    radar_metrics = ['F1 Score', 'Precision', 'Recall', 'Specificity', 'Norm MCC']
    window_periods = ["7day", "14day", "30day"]

    def metrics_for(window):
        # Restrict the {model: {window: stats}} mapping to a single window
        return {
            model: model_windows[window]
            for model, model_windows in metrics.items()
            if window in model_windows
        }

    # Number of figures requested from the plotting helpers (for the summary)
    generated = 0

    print("\nGenerating visualizations...")
    print("-" * 70)

    # Generate radar charts (up to 3: one per window that has data)
    print("\nRadar Charts:")
    for idx, window in enumerate(window_periods, 1):
        window_metrics = metrics_for(window)
        if window_metrics:
            output_file = output_dir / f'radar_chart_{idx}_{window}.png'
            create_radar_chart(window_metrics, radar_metrics, output_file)
            generated += 1

    # Scatter plot configurations
    scatter_configs = [
        ('f1', 'f1_std', 'stability'),        # F1 stability analysis
        ('norm_mcc', 'f1', 'correlation'),    # Norm MCC vs F1 correlation
        ('precision', 'recall', 'tradeoff')   # Precision-Recall tradeoff
    ]

    # Generate scatter plots (up to 9: 3 types x 3 windows)
    print("\nScatter Plots:")
    plot_number = 4  # Start from 4 (radar charts are 1-3)

    for window in window_periods:
        window_metrics = metrics_for(window)
        if not window_metrics:
            continue

        available_models = list(window_metrics)

        for x_metric, y_metric, plot_type in scatter_configs:
            output_file = output_dir / f'scatter_{plot_number}_{window}_{plot_type}.png'

            if plot_type == 'stability':
                # The stability plot derives its y-axis from f"{x_metric}_std"
                create_stability_plot(window_metrics, x_metric, available_models, output_file)
            else:
                create_scatter_plot(window_metrics, x_metric, y_metric,
                                    available_models, output_file)

            plot_number += 1
            generated += 1

    # Generate legend figures (radar legend + scatter legend)
    print("\nLegend Figures:")
    create_legend_figures(output_dir)
    generated += 2

    # Summary
    print("\n" + "=" * 70)
    print(f"SUCCESS: Generated {generated} visualization figure(s)")
    print(f"Output location: {output_dir}")
    print("=" * 70)


# Run the comparison-visualization pipeline only when this code is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()

Model Performance Comparison Visualization
Base directory: C:\Users\Tian\Desktop\地磁论文代码运行测试\results
Output directory: C:\Users\Tian\Desktop\地磁论文代码运行测试\results\performance_visualization\Model_Comparison_Analysis
----------------------------------------------------------------------

Loading model configurations...
Loaded: GRUModel_7day_ensemble_config.json
Loaded: GRUModel_14day_ensemble_config.json
Loaded: GRUModel_30day_ensemble_config.json
Loaded: LSTMModel_7day_ensemble_config.json
Loaded: LSTMModel_14day_ensemble_config.json
Loaded: LSTMModel_30day_ensemble_config.json
Loaded: MLPModel_7day_ensemble_config.json
Loaded: MLPModel_14day_ensemble_config.json
Loaded: MLPModel_30day_ensemble_config.json
Loaded: RNNModel_7day_ensemble_config.json
Loaded: RNNModel_14day_ensemble_config.json
Loaded: RNNModel_30day_ensemble_config.json
Loaded: TransformerModel_7day_ensemble_config.json
Loaded: TransformerModel_14day_ensemble_config.json
Loaded: TransformerModel_30day_ensemble_config.json

Ex

# Module 3: ROC Curve Analysis and Visualization

This module generates comparative ROC curves for multiple deep learning models (GRU, LSTM, MLP, RNN, Transformer) across different temporal windows (7-day, 14-day, 30-day) for seismic-geomagnetic signal recognition.

In [None]:
"""
Multi-Model ROC Curve Analysis with Professional Visualization
Supports Chinese paths using pathlib for cross-platform compatibility
Legend boxes with transparent background and visible borders
"""

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, roc_auc_score
from scipy.interpolate import interp1d
from pathlib import Path
from typing import Dict, List, Tuple, Optional

# ================== Global Configuration ==================

# Base directory using pathlib for better path handling.
# Change this to your actual results directory before running, for example:
#   BASE_DIR = Path(r"C:\Users\<name>\Desktop\<project>\results")
# Non-ASCII (e.g. Chinese) folder names are fine. Per-fold files are looked up
# under BASE_DIR / "<model>_models" (see load_fold_data).
BASE_DIR = Path(r"your_project/results")

# Model configurations: names are used to build folder/file names, to index
# COLORS, and they fix the order in which curves are drawn.
MODELS = ["GRU", "LSTM", "MLP", "RNN", "Transformer"]

# Time window configurations: identifiers as they appear in file names, and
# the human-readable form used in console output.
TIME_WINDOWS = ["7day", "14day", "30day"]
WINDOW_DISPLAY = {"7day": "7-day", "14day": "14-day", "30day": "30-day"}

# Number of cross-validation folds (folds are numbered 1..NUM_FOLDS)
NUM_FOLDS = 5

# Color scheme for different models (scientific publication standard)
COLORS = {
    'GRU': '#E41A1C',         # Red
    'LSTM': '#377EB8',        # Blue
    'MLP': '#4DAF4A',         # Green
    'RNN': '#984EA3',         # Purple
    'Transformer': '#FF7F00'  # Orange
}

# Configure matplotlib for scientific journal style.
# NOTE: this mutates the process-wide rcParams, so it affects every figure
# created after this point. 'figure.figsize' is only a default; the ROC figure
# passes its own figsize explicitly.
plt.rcParams.update({
    'font.family': 'Arial',
    'font.weight': 'bold',
    'axes.labelsize': 24,
    'axes.titlesize': 24,
    'axes.labelweight': 'bold',
    'axes.titleweight': 'bold',
    'legend.fontsize': 13.5,
    'xtick.labelsize': 20,
    'ytick.labelsize': 20,
    'axes.grid': False,
    'figure.figsize': (18, 6),
    'legend.frameon': True,
    'legend.fancybox': False,
    'legend.edgecolor': 'black',
    'axes.linewidth': 1.0,
})

# ================== Data Processing Functions ==================

def load_fold_data(model: str, window: str, fold: int, base_dir: Path) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
    """
    Read the saved test labels and predicted probabilities of one CV fold.

    Files are expected at
    ``<base_dir>/<model lowercased>_models/<model>Model_<window>_fold_<fold>_test_{labels,probs}.npy``.
    A two-column probability array is reduced to its second column (the
    positive class); any other shape is returned unchanged.

    Args:
        model: Model name
        window: Time window identifier
        fold: Fold number (1-based)
        base_dir: Base directory path

    Returns:
        Tuple of (labels, probabilities) or (None, None) if the model
        directory or either file is missing, or loading fails
    """
    missing = (None, None)

    model_dir = base_dir / f"{model.lower()}_models"
    if not model_dir.exists():
        print(f"Directory not found: {model_dir}")
        return missing

    prefix = f"{model}Model_{window}_fold_{fold}"
    labels_file = model_dir / f"{prefix}_test_labels.npy"
    probs_file = model_dir / f"{prefix}_test_probs.npy"

    if not (labels_file.exists() and probs_file.exists()):
        print(f"Missing data: {model} - {window} - Fold {fold}")
        return missing

    try:
        labels = np.load(labels_file)
        probs = np.load(probs_file)

        # Keep only the positive-class column when both class columns are stored
        two_column = probs.ndim == 2 and probs.shape[1] == 2
        pos_probs = probs[:, 1] if two_column else probs

        return labels, pos_probs

    except Exception as e:
        print(f"Error loading data for {model} - {window} - Fold {fold}: {e}")
        return missing


def calculate_fold_roc(labels: np.ndarray, probabilities: np.ndarray) -> Tuple[np.ndarray, np.ndarray, float]:
    """
    Compute the ROC points and AUC of one fold.

    The curve returned by ``roc_curve`` is thinned so that every FPR value
    appears exactly once, which makes it usable as the x-grid of an
    interpolator.

    Args:
        labels: True labels
        probabilities: Predicted probabilities

    Returns:
        Tuple of (FPR, TPR, AUC score), with FPR sorted and free of duplicates
    """
    fpr, tpr, _thresholds = roc_curve(labels, probabilities)
    auc_score = roc_auc_score(labels, probabilities)

    # np.unique reports the index of the FIRST occurrence of each distinct FPR,
    # so only the first point of every run of equal FPR values is kept.
    # NOTE(review): for a vertical ROC segment that is presumably the point
    # with the lowest TPR - confirm this is the intended representative.
    _, first_hits = np.unique(fpr, return_index=True)

    return fpr[first_hits], tpr[first_hits], auc_score


def interpolate_roc_curves(fprs: List[np.ndarray], tprs: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Resample several ROC curves onto a shared FPR grid and aggregate them.

    The grid has 100 evenly spaced points on [0, 1]. Each curve is linearly
    interpolated; grid points below a curve's first FPR map to 0 and points
    above its last FPR map to 1. Curves with fewer than two points are ignored.

    Args:
        fprs: List of FPR arrays
        tprs: List of TPR arrays

    Returns:
        Tuple of (mean FPR, mean TPR, std TPR); the TPR arrays are all zeros
        when no curve could be interpolated
    """
    grid = np.linspace(0, 1, 100)

    resampled = [
        interp1d(curve_fpr, curve_tpr, kind='linear',
                 bounds_error=False, fill_value=(0, 1))(grid)
        for curve_fpr, curve_tpr in zip(fprs, tprs)
        if len(curve_fpr) > 1
    ]

    if len(resampled) == 0:
        return grid, np.zeros_like(grid), np.zeros_like(grid)

    return grid, np.mean(resampled, axis=0), np.std(resampled, axis=0)


def process_model_roc(model: str, window: str, base_dir: Path) -> Optional[Dict]:
    """
    Aggregate the per-fold ROC curves of one model/window combination.

    Folds 1..NUM_FOLDS are loaded one by one; folds whose data cannot be
    loaded, or whose ROC curve has fewer than two points, are left out.

    Args:
        model: Model name
        window: Time window identifier
        base_dir: Base directory path

    Returns:
        Dictionary with keys "mean_fpr", "mean_tpr", "std_tpr", "mean_auc"
        and "std_auc", or None if no fold produced a usable curve
    """
    # One (fpr, tpr, auc) triple per usable fold
    usable = []

    for fold in range(1, NUM_FOLDS + 1):
        labels, probs = load_fold_data(model, window, fold, base_dir)
        if labels is None or probs is None:
            continue

        curve = calculate_fold_roc(labels, probs)
        if len(curve[0]) > 1:
            usable.append(curve)

    if len(usable) == 0:
        print(f"Insufficient data for {model} in {window}")
        return None

    fold_fprs = [fpr for fpr, _, _ in usable]
    fold_tprs = [tpr for _, tpr, _ in usable]
    fold_aucs = [score for _, _, score in usable]

    # Average the curves on a common FPR grid
    mean_fpr, mean_tpr, std_tpr = interpolate_roc_curves(fold_fprs, fold_tprs)

    stats = {}
    stats["mean_fpr"] = mean_fpr
    stats["mean_tpr"] = mean_tpr
    stats["std_tpr"] = std_tpr
    stats["mean_auc"] = np.mean(fold_aucs)
    stats["std_auc"] = np.std(fold_aucs)
    return stats


# ================== Visualization Functions ==================

def create_roc_subplot(ax: plt.Axes, window: str, model_data: Dict[str, Dict]) -> None:
    """
    Draw the mean ROC curve of every available model for one time window.

    Models are drawn in MODELS order, followed by the dashed chance diagonal.
    The legend lists the models by descending mean AUC with the 'Random'
    entry last, anchored to the lower-right corner of the axes with a
    transparent face and a black border.

    Args:
        ax: Matplotlib axes object
        window: Time window identifier
        model_data: Dictionary containing ROC data for all models
    """
    # Show all four borders with a uniform width
    for border in ax.spines.values():
        border.set_visible(True)
        border.set_linewidth(1.0)

    # Mean AUC of each plotted model; drives the legend order further down
    model_aucs = {}

    for model in MODELS:
        if model not in model_data or window not in model_data[model]:
            continue

        stats = model_data[model][window]
        model_aucs[model] = stats["mean_auc"]

        ax.plot(
            stats["mean_fpr"],
            stats["mean_tpr"],
            lw=2,
            color=COLORS[model],
            label=f'{model} ({stats["mean_auc"]:.2f})'
        )

    # Chance-level reference
    ax.plot([0, 1], [0, 1], 'k--', linewidth=1, label='Random')

    # Axis range, labels and inward ticks
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate', fontweight='bold')
    ax.set_ylabel('True Positive Rate', fontweight='bold')
    ax.tick_params(direction='in')

    handles, labels = ax.get_legend_handles_labels()
    if not handles:
        return

    # Model entries by descending AUC. The model name is the first
    # whitespace-separated token of its label; sorted() is stable, so ties
    # keep their drawing order.
    legend_order = sorted(
        (pos for pos, text in enumerate(labels) if text != 'Random'),
        key=lambda pos: model_aucs.get(labels[pos].split()[0], 0),
        reverse=True,
    )

    # 'Random' always goes last (at most one entry is kept)
    random_positions = [pos for pos, text in enumerate(labels) if text == 'Random']
    legend_order.extend(random_positions[-1:])

    legend = ax.legend(
        [handles[pos] for pos in legend_order],
        [labels[pos] for pos in legend_order],
        loc='lower right',
        bbox_to_anchor=(1.0, 0.0),
        frameon=True,
        borderaxespad=0.0
    )

    # Black outline around a see-through legend box
    box = legend.get_frame()
    box.set_edgecolor('black')
    box.set_linewidth(1.0)
    box.set_facecolor('none')  # Transparent background


def create_multi_model_roc_figure(model_data: Dict[str, Dict]) -> plt.Figure:
    """
    Build the three-panel ROC figure, one panel per entry of TIME_WINDOWS.

    Each panel is tagged with a lowercase letter ('a', 'b', 'c') placed in
    figure coordinates just left of and above the panel.

    Args:
        model_data: Dictionary containing ROC data for all models

    Returns:
        Matplotlib figure object
    """
    fig, axes = plt.subplots(1, 3, figsize=(17, 5))

    # (x, y) anchors of the panel letters, in figure-fraction coordinates
    letter_anchors = [(0.01, 0.95), (0.35, 0.95), (0.69, 0.95)]
    label_fontsize = 20

    for idx, window in enumerate(TIME_WINDOWS):
        anchor_x, anchor_y = letter_anchors[idx]

        # 97 is ord('a'): panels are lettered a, b, c in order
        fig.text(
            anchor_x,
            anchor_y,
            chr(97 + idx),
            fontsize=label_fontsize,
            fontweight='bold',
            va='bottom',
            ha='left'
        )

        create_roc_subplot(axes[idx], window, model_data)

    # Fixed margins and horizontal gap between the panels
    plt.subplots_adjust(wspace=0.30, left=0.05, right=0.98, bottom=0.15, top=0.95)

    return fig


# ================== Main Execution Function ==================

def main():
    """
    Main execution function for generating ROC curve analysis.

    Collects fold-averaged ROC statistics for every model in MODELS and every
    window in TIME_WINDOWS, renders the multi-panel figure and saves it to
    ``BASE_DIR/performance_visualization/ROC_analysis``. If the output
    directory cannot be created or the file cannot be written, the figure is
    saved under the system temporary directory instead. The figure is closed
    whether or not saving succeeds.
    """
    # Local import: only needed for the fallback save location
    import tempfile

    print("=" * 70)
    print("Multi-Model ROC Curve Analysis")
    print("=" * 70)

    # Verify base directory exists
    if not BASE_DIR.exists():
        print(f"ERROR: Base directory does not exist: {BASE_DIR}")
        print("Please check the path configuration.")
        return

    print(f"Base directory: {BASE_DIR}")
    print(f"Models to process: {', '.join(MODELS)}")
    print(f"Time windows: {', '.join(TIME_WINDOWS)}")
    print("-" * 70)

    # {model: {window: roc statistics}}
    model_roc_data = {}

    # Process each model
    for model in MODELS:
        print(f"\nProcessing {model} model:")
        model_roc_data[model] = {}

        for window in TIME_WINDOWS:
            print(f"  {WINDOW_DISPLAY[window]} window...", end=" ")

            roc_stats = process_model_roc(model, window, BASE_DIR)

            if roc_stats is not None:
                model_roc_data[model][window] = roc_stats
                print(f"✓ (AUC: {roc_stats['mean_auc']:.3f})")
            else:
                print("✗ (insufficient data)")

    print("\n" + "-" * 70)

    # A model counts as valid when at least one of its windows produced data
    valid_models = sum(1 for windows in model_roc_data.values() if windows)

    if valid_models == 0:
        print("ERROR: No valid ROC data found for any model.")
        return

    print(f"Successfully processed data for {valid_models} model(s)")
    print("Generating ROC curve visualization...")

    fig = create_multi_model_roc_figure(model_roc_data)

    file_name = "multi_model_roc_curves.png"
    save_dir = BASE_DIR / "performance_visualization" / "ROC_analysis"

    try:
        try:
            # Directory creation is inside the protected region so that an
            # unwritable BASE_DIR also triggers the fallback below.
            save_dir.mkdir(parents=True, exist_ok=True)
            output_file = save_dir / file_name
            fig.savefig(output_file, dpi=600, bbox_inches='tight')
            print(f"\nSUCCESS: Figure saved to {output_file}")
        except Exception as e:
            print(f"\nERROR saving figure: {e}")
            # Platform-independent fallback (a hard-coded drive path would
            # only be meaningful on Windows)
            alt_dir = Path(tempfile.gettempdir()) / "results"
            alt_dir.mkdir(parents=True, exist_ok=True)
            alt_file = alt_dir / file_name
            fig.savefig(alt_file, dpi=600, bbox_inches='tight')
            print(f"Saved to alternative location: {alt_file}")
    finally:
        # Close figure to free memory, even if both save attempts failed
        plt.close(fig)

    print("=" * 70)


# Run the ROC analysis only when this code is executed directly, not when it
# is imported.
if __name__ == "__main__":
    main()

Multi-Model ROC Curve Analysis
Base directory: C:\Users\Tian\Desktop\地磁论文代码运行测试\results
Models to process: GRU, LSTM, MLP, RNN, Transformer
Time windows: 7day, 14day, 30day
----------------------------------------------------------------------

Processing GRU model:
  7-day window... ✓ (AUC: 0.977)
  14-day window... ✓ (AUC: 0.951)
  30-day window... ✓ (AUC: 0.954)

Processing LSTM model:
  7-day window... ✓ (AUC: 0.983)
  14-day window... ✓ (AUC: 0.958)
  30-day window... ✓ (AUC: 0.958)

Processing MLP model:
  7-day window... ✓ (AUC: 0.628)
  14-day window... ✓ (AUC: 0.711)
  30-day window... ✓ (AUC: 0.769)

Processing RNN model:
  7-day window... ✓ (AUC: 0.964)
  14-day window... ✓ (AUC: 0.844)
  30-day window... ✓ (AUC: 0.806)

Processing Transformer model:
  7-day window... ✓ (AUC: 0.489)
  14-day window... ✓ (AUC: 0.530)
  30-day window... ✓ (AUC: 0.581)

----------------------------------------------------------------------
Successfully processed data for 5 model(s)
Generating R

# Module 4: Confusion Matrix Analysis with Cross-Validation Averaging

This module generates individual confusion matrices for multiple deep learning models (LSTM, GRU, RNN, MLP, Transformer) across different temporal windows (7-day, 14-day, 30-day), using averaged results from 5-fold cross-validation for robust performance evaluation of seismic-geomagnetic signal classification.

In [None]:
"""
Multi-Model Confusion Matrix Analysis with Individual Plots
Supports Chinese paths using pathlib for cross-platform compatibility
Generates averaged confusion matrices from cross-validation folds
"""

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
from sklearn.metrics import confusion_matrix
from pathlib import Path
from typing import Dict, List, Tuple, Optional

# ================== Global Configuration ==================

# Base directory using pathlib for better path handling.
# Change this to your actual results directory before running, for example:
#   BASE_DIR = Path(r"C:\Users\<name>\Desktop\<project>\results")
# Non-ASCII (e.g. Chinese) folder names are fine. Per-fold files are looked up
# under BASE_DIR / "<model>_models" (see load_fold_confusion_matrix).
BASE_DIR = Path(r"your_project/results")

# Model configurations: names are used to build folder/file names and
# plot titles.
MODELS = ["LSTM", "GRU", "RNN", "MLP", "Transformer"]

# Time window configurations: identifiers as they appear in file names, and
# the human-readable form used in titles, output file names and console output.
TIME_WINDOWS = ["7day", "14day", "30day"]
WINDOW_DISPLAY = {"7day": "7-day", "14day": "14-day", "30day": "30-day"}

# Number of cross-validation folds (folds are numbered 1..NUM_FOLDS)
NUM_FOLDS = 5

# Configure matplotlib for scientific journal style.
# NOTE: this mutates the process-wide rcParams, so it affects every figure
# created after this point.
plt.rcParams.update({
    'font.family': 'Arial',
    'font.weight': 'bold',
    'axes.labelsize': 50,
    'axes.titlesize': 52,
    'legend.fontsize': 50,
    'xtick.labelsize': 50,
    'ytick.labelsize': 50,
    'axes.grid': False,
    'figure.dpi': 100,
})

# Figure configuration: size in inches of each confusion-matrix figure, and
# the axes rectangle (in figure fractions) handed to fig.add_axes.
FIG_SIZE = (10, 10)
MARGINS = [0.16, 0.132, 0.80, 0.80]  # [left, bottom, width, height]

# Heatmap configuration: keyword arguments forwarded to sns.heatmap.
# Annotations and the colorbar are switched off because cell values are drawn
# manually (see TEXT_CONFIG) and the colorbar is rendered as a separate figure.
HEATMAP_CONFIG = {
    'cmap': 'coolwarm',
    'linewidths': 0.5,
    'linecolor': 'gray',
    'square': True,
    'cbar': False,
    'annot': False
}

# Text annotation configuration: keyword arguments for the ax.text call that
# writes the rounded count into each heatmap cell.
TEXT_CONFIG = {
    'ha': 'center',
    'va': 'center',
    'color': 'black',
    'fontsize': 90,
    'fontweight': 'bold',
    'fontfamily': 'Arial'
}

# ================== Data Loading Functions ==================

def load_fold_confusion_matrix(model: str, window: str, fold: int, base_dir: Path) -> Optional[np.ndarray]:
    """
    Load test data and calculate confusion matrix for a specific fold.

    Files are expected at
    ``<base_dir>/<model lowercased>_models/<model>Model_<window>_fold_<fold>_test_{probs,labels}.npy``.

    The matrix is always computed for the two classes 0 and 1, so it is 2x2
    even when a fold happens to contain (or predict) only one class. This
    keeps the per-fold matrices stackable when they are averaged.

    Args:
        model: Model name (e.g., 'LSTM')
        window: Time window identifier (e.g., '7day')
        fold: Fold number (1-based)
        base_dir: Base directory path

    Returns:
        2x2 confusion matrix (rows: actual class, columns: predicted class),
        or None if a file is missing or loading/processing fails
    """
    model_dir = base_dir / f"{model.lower()}_models"

    # File paths for test data (validation set)
    probs_file = model_dir / f"{model}Model_{window}_fold_{fold}_test_probs.npy"
    labels_file = model_dir / f"{model}Model_{window}_fold_{fold}_test_labels.npy"

    if not probs_file.exists() or not labels_file.exists():
        return None

    try:
        # Load probabilities and labels
        probs = np.load(probs_file)
        labels = np.load(labels_file)

        # Convert probabilities to predictions
        if probs.ndim == 2 and probs.shape[1] > 1:
            # One column per class: pick the most probable class
            preds = np.argmax(probs, axis=1)
        else:
            # Single score per sample (1-D or a single column).
            # NOTE(review): assumes the score is the positive-class
            # probability; the strict '>' mirrors argmax, where a tie goes
            # to class 0 - confirm against the training code.
            preds = (np.ravel(probs) > 0.5).astype(int)

        # Fix the label set so single-class folds still yield a 2x2 matrix
        return confusion_matrix(labels, preds, labels=[0, 1])

    except Exception as e:
        print(f"Error loading fold {fold} for {model} - {window}: {e}")
        return None


def calculate_averaged_confusion_matrix(model: str, window: str, base_dir: Path) -> Optional[np.ndarray]:
    """
    Average the confusion matrices of folds 1..NUM_FOLDS element-wise.

    Folds that cannot be loaded are reported on the console and excluded
    from the average.

    Args:
        model: Model name
        window: Time window identifier
        base_dir: Base directory path

    Returns:
        Averaged confusion matrix or None if no fold could be loaded
    """
    window_label = WINDOW_DISPLAY[window]
    loaded = []

    for fold in range(1, NUM_FOLDS + 1):
        fold_cm = load_fold_confusion_matrix(model, window, fold, base_dir)
        if fold_cm is None:
            print(f"  Missing data: {model} - {window_label} - Fold {fold}")
            continue
        loaded.append(fold_cm)

    if len(loaded) == 0:
        print(f"  No valid data for {model} - {window_label}")
        return None

    # Element-wise mean over the folds that were found
    averaged = np.mean(loaded, axis=0)

    print(f"  {window_label}: {len(loaded)}/{NUM_FOLDS} folds averaged")

    return averaged


# ================== Global Maximum Calculation ==================

def find_global_maximum(base_dir: Path) -> float:
    """
    Return the largest cell value over all averaged confusion matrices.

    Every (model, window) combination of MODELS x TIME_WINDOWS is averaged
    here; combinations without data are ignored. The result is 0 when
    nothing could be loaded.

    Args:
        base_dir: Base directory path

    Returns:
        Global maximum value
    """
    print("Calculating global maximum for color scale...")

    averaged = (
        calculate_averaged_confusion_matrix(model, window, base_dir)
        for model in MODELS
        for window in TIME_WINDOWS
    )

    # Seed with 0 so the result is defined even when every matrix is missing
    max_value = max([0] + [cm.max() for cm in averaged if cm is not None])

    print(f"Global maximum value: {max_value:.1f}")
    return max_value


# ================== Visualization Functions ==================

def create_individual_confusion_matrix(
    model: str, 
    window: str, 
    avg_cm: np.ndarray,
    output_dir: Path,
    vmax: Optional[float] = None
) -> None:
    """
    Create and save an individual confusion matrix plot.

    The heatmap is colored from the unrounded averages, while the number
    written into each cell is the average rounded to the nearest integer.
    Class tick labels and the axis title are shown on the y-axis only for the
    '7day' window and on the x-axis only for the 'Transformer' model
    (presumably the left column / bottom row when the individual images are
    assembled into a grid - verify against the figure layout).

    Args:
        model: Model name
        window: Time window identifier
        avg_cm: Averaged confusion matrix
        output_dir: Output directory path
        vmax: Optional upper bound of the color scale. When given, colors are
            mapped on the fixed range [0, vmax], so plots produced with the
            same vmax are directly comparable and match a colorbar normalized
            to that range. When None (default), seaborn infers the range from
            this matrix alone.
    """
    # Create figure with fixed size and a single axes at the configured margins
    fig = plt.figure(figsize=FIG_SIZE, dpi=100)
    ax = fig.add_axes(MARGINS)

    # Round values for display
    avg_cm_int = np.round(avg_cm).astype(int)

    # Determine axis labels (only show for specific positions)
    show_x_labels = model == "Transformer"
    show_y_labels = window == "7day"
    class_names = ["Class 0", "Class 1"]
    blank_names = ["", ""]

    # Merge the shared heatmap settings with the optional fixed color range
    heatmap_kwargs = dict(HEATMAP_CONFIG)
    if vmax is not None:
        heatmap_kwargs.update(vmin=0, vmax=vmax)

    # Create heatmap
    sns.heatmap(
        avg_cm,
        ax=ax,
        xticklabels=class_names if show_x_labels else blank_names,
        yticklabels=class_names if show_y_labels else blank_names,
        **heatmap_kwargs
    )

    # Manually add text annotations (heatmap cell (i, j) is centered at
    # x = j + 0.5, y = i + 0.5)
    for i in range(avg_cm_int.shape[0]):
        for j in range(avg_cm_int.shape[1]):
            ax.text(j + 0.5, i + 0.5, str(avg_cm_int[i, j]), **TEXT_CONFIG)

    # Force a render before the labels are adjusted
    fig.canvas.draw()

    # Set axis labels conditionally
    if show_y_labels:
        ax.set_ylabel('Actual Label', fontweight='bold', fontsize=50, labelpad=5)
    else:
        ax.set_ylabel('')

    if show_x_labels:
        ax.set_xlabel('Predicted Label', fontweight='bold', fontsize=50, labelpad=5)
    else:
        ax.set_xlabel('')

    # Set tick parameters
    ax.tick_params(axis='both', labelsize=50)

    # Add title
    ax.set_title(f"{model} - {WINDOW_DISPLAY[window]}", fontweight='bold', fontsize=52, pad=10)

    # Ensure borders are visible
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(1.0)

    # Save with the fixed figure dimensions (bbox_inches=None: no cropping),
    # so every image has the same pixel size
    output_file = output_dir / f"cm_{model}_{WINDOW_DISPLAY[window]}.png"
    fig.savefig(output_file, dpi=300, bbox_inches=None)
    print(f"    Saved: {output_file.name}")
    plt.close(fig)


def create_horizontal_colorbar(max_value: float, output_dir: Path) -> None:
    """
    Render a standalone horizontal 'coolwarm' colorbar and save it as
    ``horizontal_colorbar.png`` in *output_dir*.

    The color range runs from 0 to ``max(350, ceil(max_value))`` and ticks
    are placed every 50 units starting at 0.

    Args:
        max_value: Maximum value for color scale
        output_dir: Output directory path
    """
    fig_colorbar = plt.figure(figsize=(12, 2), dpi=100)

    # The upper end of the scale never drops below 350
    scale_top = max(350, int(np.ceil(max_value)))

    # 0, 50, ..., 350, continuing in steps of 50 to the first multiple of 50
    # at or above scale_top when the scale extends past 350
    ticks = np.arange(0, scale_top + 50, 50)

    tick_text = [str(int(tick)) for tick in ticks]

    # Thin strip in the middle of the figure that hosts the colorbar
    ax_cbar = fig_colorbar.add_axes([0.1, 0.4, 0.8, 0.3])
    cb = mpl.colorbar.ColorbarBase(
        ax_cbar,
        cmap=plt.cm.coolwarm,
        norm=mpl.colors.Normalize(vmin=0, vmax=scale_top),
        orientation='horizontal',
        ticks=ticks
    )

    # Tick labels and caption
    cb.ax.set_xticklabels(tick_text, fontsize=50, fontweight='bold')
    cb.ax.tick_params(length=6, width=2)
    cb.set_label('Number of Samples', fontsize=52, fontweight='bold', labelpad=15)

    # Write to disk, cropped tightly around the drawn content
    colorbar_file = output_dir / "horizontal_colorbar.png"
    fig_colorbar.savefig(colorbar_file, dpi=300, bbox_inches='tight')
    print(f"\nSaved colorbar: {colorbar_file.name}")
    plt.close(fig_colorbar)


# ================== Main Execution Function ==================

def main():
    """
    Main execution function for generating individual confusion matrix plots.

    For every model in ``MODELS`` and every window in ``TIME_WINDOWS`` the
    averaged confusion matrix is requested from
    ``calculate_averaged_confusion_matrix``; when one is returned it is
    plotted into ``BASE_DIR/performance_visualization/Individual_Confusion_Matrices``.
    If at least one plot was produced, a shared horizontal colorbar is saved
    too. Progress and a final summary are printed to stdout.

    Returns early (after printing an error) when ``BASE_DIR`` does not exist.
    """
    banner = "=" * 80
    rule = "-" * 80

    print(banner)
    print("Individual Confusion Matrix Analysis (Averaged from Cross-Validation)")
    print(banner)

    # Verify base directory exists
    if not BASE_DIR.exists():
        print(f"ERROR: Base directory does not exist: {BASE_DIR}")
        print("Please check the path configuration.")
        return

    print(f"Base directory: {BASE_DIR}")
    print(f"Models: {', '.join(MODELS)}")
    print(f"Time windows: {', '.join(WINDOW_DISPLAY[w] for w in TIME_WINDOWS)}")
    print("Output: Individual confusion matrix plots")
    print(rule)

    # Create output directory
    output_dir = BASE_DIR / "performance_visualization" / "Individual_Confusion_Matrices"
    output_dir.mkdir(parents=True, exist_ok=True)

    # Find global maximum for consistent color scale
    max_value = find_global_maximum(BASE_DIR)
    print(rule)

    total_plots = len(MODELS) * len(TIME_WINDOWS)

    # Names of the plots written during THIS run. Tracking them explicitly
    # keeps the summary from listing stale files left by earlier runs.
    generated_files = []

    for model in MODELS:
        print(f"\nProcessing {model} model:")

        # Check if model directory exists
        model_dir = BASE_DIR / f"{model.lower()}_models"
        if not model_dir.exists():
            print(f"  WARNING: Model directory not found: {model_dir}")
            continue

        for window in TIME_WINDOWS:
            # Calculate averaged confusion matrix
            avg_cm = calculate_averaged_confusion_matrix(model, window, BASE_DIR)
            if avg_cm is None:
                print("    Skipped: No data available")
                continue

            # Create individual plot
            create_individual_confusion_matrix(model, window, avg_cm, output_dir)
            # Same filename pattern the plotting function saves under
            generated_files.append(f"cm_{model}_{WINDOW_DISPLAY[window]}.png")

    # The colorbar is only useful when there is at least one plot to go with it
    if generated_files:
        create_horizontal_colorbar(max_value, output_dir)

    # Final summary
    print("\n" + banner)
    if generated_files:
        print(f"SUCCESS: Generated {len(generated_files)}/{total_plots} confusion matrix plots")
        print(f"Output directory: {output_dir}")
        print("\nGenerated files:")

        # List generated files
        for filename in generated_files:
            print(f"  - {filename}")

        print("  - horizontal_colorbar.png")
    else:
        print("WARNING: No visualizations were generated.")
        print("Please check that test result files exist.")

    print(banner)

# Run the confusion-matrix analysis only when this code is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

Individual Confusion Matrix Analysis (Averaged from Cross-Validation)
Base directory: C:\Users\Tian\Desktop\地磁论文代码运行测试\results
Models: LSTM, GRU, RNN, MLP, Transformer
Time windows: 7-day, 14-day, 30-day
Output: Individual confusion matrix plots
--------------------------------------------------------------------------------
Calculating global maximum for color scale...
  7-day: 5/5 folds averaged
  14-day: 5/5 folds averaged
  30-day: 5/5 folds averaged
  7-day: 5/5 folds averaged
  14-day: 5/5 folds averaged
  30-day: 5/5 folds averaged
  7-day: 5/5 folds averaged
  14-day: 5/5 folds averaged
  30-day: 5/5 folds averaged
  7-day: 5/5 folds averaged
  14-day: 5/5 folds averaged
  30-day: 5/5 folds averaged
  7-day: 5/5 folds averaged
  14-day: 5/5 folds averaged
  30-day: 5/5 folds averaged
Global maximum value: 374.0
--------------------------------------------------------------------------------

Processing LSTM model:
  7-day: 5/5 folds averaged
    Saved: cm_LSTM_7-day.png
  14-da