# Module 1: Environment Setup and Configuration

This module provides environment initialization and configuration for deep learning model training, ensuring reproducibility and optimal resource utilization.

In [20]:
"""
RNN Model Training Pipeline - Environment Setup and Configuration Module

This module provides essential functionality for initializing and configuring
the deep learning environment, including hardware detection, reproducibility
settings, and resource management.

Author: Tian Gao
Date: 2025-09-16
Version: 1.0.0
License: MIT
"""

# ==================== Standard Library Imports ====================

import copy
import json
import logging
import os
import random
import sys
import time
import warnings
from dataclasses import dataclass, asdict
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import platform


# ==================== Scientific Computing Imports ====================

import numpy as np
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    f1_score,
    matthews_corrcoef,
    precision_score,
    recall_score
)
from sklearn.model_selection import StratifiedKFold

# ==================== PyTorch Deep Learning Imports ====================

import torch
import torch.nn as nn
import torch.nn.functional as F 
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import (
    DataLoader,
    Dataset, 
    Subset,
    TensorDataset,
    WeightedRandomSampler
)

# ==================== Configuration Constants ====================

class Config:
    """Centralized configuration management."""
    
    VERSION = "1.0.0"
    AUTHOR = "Tian Gao"
    LICENSE = "MIT"
    
    TESTED_CONFIG = {
        'python': '3.11.7',
        'pytorch': '2.3.0',
        'numpy': '1.26.4',
        'sklearn': '1.5.0',
        'cuda': '12.1'
    }
    
    VERSION_REQUIREMENTS = {
        'python': ((3, 9, 0), (3, 11, 99)),
        'pytorch': ('2.0.0', '2.3.1'),
        'numpy': ('1.21.0', '1.26.99'),
        'sklearn': ('1.0.0', '1.5.99')
    }
    
    DEFAULT_SEED = 42
    DEFAULT_LOG_LEVEL = 'INFO'
    DEFAULT_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s'
    DEFAULT_DATE_FORMAT = '%Y-%m-%d %H:%M:%S'

# ==================== Optional Dependencies Management ====================

OPTIONAL_DEPS = {}

try:
    import psutil
    OPTIONAL_DEPS['psutil'] = True
except ImportError:
    OPTIONAL_DEPS['psutil'] = False
    warnings.warn(
        "psutil not installed. System resource monitoring unavailable.\n"
        "Install with: pip install psutil>=5.8.0",
        category=ImportWarning,
        stacklevel=2
    )

try:
    from packaging import version
    OPTIONAL_DEPS['packaging'] = True
except ImportError:
    OPTIONAL_DEPS['packaging'] = False
    warnings.warn(
        "packaging not installed. Version comparison features limited.",
        category=ImportWarning,
        stacklevel=2
    )

# Suppress specific PyTorch warnings
warnings.filterwarnings("ignore", message=".*PyTorch is not compiled with NCCL support.*")
warnings.filterwarnings("ignore", message=".*UserWarning: TypedStorage is deprecated.*")

# ==================== Logging Configuration ====================

def setup_logging(
    log_file: Optional[str] = None,
    level: str = 'INFO',
    format_string: Optional[str] = None
) -> logging.Logger:
    """
    Configure comprehensive logging system.
    
    Parameters
    ----------
    log_file : Optional[str]
        Path to log file.
    level : str
        Logging level ('DEBUG', 'INFO', 'WARNING', 'ERROR').
    format_string : Optional[str]
        Custom format string.
    
    Returns
    -------
    logging.Logger
        Configured logger instance.
    """
    if format_string is None:
        format_string = Config.DEFAULT_LOG_FORMAT
    
    logger = logging.getLogger('RNN_Pipeline')
    logger.handlers = []
    
    log_level = getattr(logging, level.upper(), logging.INFO)
    logger.setLevel(log_level)
    
    formatter = logging.Formatter(format_string, Config.DEFAULT_DATE_FORMAT)
    
    # Console handler
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)
    
    # File handler
    if log_file:
        log_path = Path(log_file)
        log_path.parent.mkdir(parents=True, exist_ok=True)
        file_handler = logging.FileHandler(log_file, encoding='utf-8')
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
        logger.info(f"Logging initialized: {log_file}")
    
    return logger

# Initialize default logger
logger = setup_logging(level=Config.DEFAULT_LOG_LEVEL)

# ==================== Version Compatibility Checking ====================

def check_version_compatibility(raise_on_error: bool = False) -> Dict[str, Any]:
    """
    Check version compatibility of installed packages.
    
    Parameters
    ----------
    raise_on_error : bool
        Raise exception on critical incompatibilities.
    
    Returns
    -------
    Dict[str, Any]
        Compatibility report.
    """
    report = {
        'timestamp': datetime.now().isoformat(),
        'status': 'compatible',
        'warnings': [],
        'errors': [],
        'versions': {},
        'recommendations': []
    }
    
    # Check Python
    current_python = sys.version_info[:3]
    min_python, max_python = Config.VERSION_REQUIREMENTS['python']
    
    report['versions']['python'] = {
        'current': f"{current_python[0]}.{current_python[1]}.{current_python[2]}",
        'tested': Config.TESTED_CONFIG['python']
    }
    
    if current_python < min_python or current_python > max_python:
        report['warnings'].append(
            f"Python {report['versions']['python']['current']} outside supported range"
        )
    
    # Check PyTorch
    current_pytorch = torch.__version__.split('+')[0]
    report['versions']['pytorch'] = {
        'current': current_pytorch,
        'tested': Config.TESTED_CONFIG['pytorch']
    }
    
    if OPTIONAL_DEPS.get('packaging', False):
        from packaging import version
        min_pytorch, max_pytorch = Config.VERSION_REQUIREMENTS['pytorch']
        
        if version.parse(current_pytorch) < version.parse(min_pytorch):
            report['errors'].append(f"PyTorch {current_pytorch} below minimum {min_pytorch}")
            report['status'] = 'incompatible'
        elif version.parse(current_pytorch) > version.parse(max_pytorch):
            report['warnings'].append(f"PyTorch {current_pytorch} exceeds tested version")
    
    # Check NumPy
    current_numpy = np.__version__
    report['versions']['numpy'] = {
        'current': current_numpy,
        'tested': Config.TESTED_CONFIG['numpy']
    }
    
    if current_numpy.startswith('2.'):
        report['errors'].append("NumPy 2.0+ has breaking changes")
        report['status'] = 'incompatible'
        report['recommendations'].append("pip install 'numpy<2.0.0'")
    
    # Check CUDA
    if torch.cuda.is_available():
        report['versions']['cuda'] = {
            'current': torch.version.cuda,
            'tested': Config.TESTED_CONFIG['cuda']
        }
    
    if raise_on_error and report['status'] == 'incompatible':
        raise RuntimeError("Critical version incompatibilities: " + 
                         ", ".join(report['errors']))
    
    return report

def print_compatibility_report(report: Dict[str, Any]):
    """Print formatted compatibility report."""
    print("\n" + "="*70)
    print("Version Compatibility Report")
    print("="*70)
    
    print("\nVersions:")
    for package, info in report['versions'].items():
        current = info.get('current', 'N/A')
        tested = info.get('tested', 'N/A')
        print(f"  {package:10s}: {current:15s} (tested: {tested})")
    
    if report['errors']:
        print("\nERRORS:")
        for error in report['errors']:
            print(f"  - {error}")
    
    if report['warnings']:
        print("\nWarnings:")
        for warning in report['warnings']:
            print(f"  - {warning}")
    
    if report['recommendations']:
        print("\nRecommendations:")
        for rec in report['recommendations']:
            print(f"  {rec}")
    
    print("="*70)

# ==================== Environment Detection ====================

def check_environment(detailed: bool = True) -> Dict[str, Any]:
    """
    Perform comprehensive environment verification.
    
    Parameters
    ----------
    detailed : bool
        Include detailed hardware information.
    
    Returns
    -------
    Dict[str, Any]
        Environment configuration.
    """
    env_info = {
        'python_version': f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}",
        'pytorch_version': torch.__version__,
        'numpy_version': np.__version__,
        'cuda_available': torch.cuda.is_available(),
        'cuda_version': torch.version.cuda if torch.cuda.is_available() else None,
        'cudnn_enabled': torch.backends.cudnn.enabled if torch.cuda.is_available() else False,
        'cudnn_version': torch.backends.cudnn.version() if torch.cuda.is_available() else None,
        'gpu_count': torch.cuda.device_count() if torch.cuda.is_available() else 0,
        'gpu_devices': [],
        'cpu_count': os.cpu_count(),
        'optional_features': OPTIONAL_DEPS
    }
    
    if detailed and env_info['cuda_available']:
        for i in range(env_info['gpu_count']):
            props = torch.cuda.get_device_properties(i)
            env_info['gpu_devices'].append({
                'index': i,
                'name': props.name,
                'compute_capability': (props.major, props.minor),
                'total_memory_gb': props.total_memory / (1024**3),
                'multi_processor_count': props.multi_processor_count
            })
    
    # System memory if psutil available
    if OPTIONAL_DEPS.get('psutil', False):
        import psutil
        mem = psutil.virtual_memory()
        env_info['system_memory'] = {
            'total_gb': mem.total / (1024**3),
            'available_gb': mem.available / (1024**3),
            'percent_used': mem.percent
        }
    
    return env_info

def print_environment_info(env_info: Dict[str, Any] = None):
    """Print formatted environment information."""
    if env_info is None:
        env_info = check_environment()
    
    print("\n" + "="*70)
    print("PyTorch Environment Information")
    print("="*70)
    
    print("\nSoftware:")
    print(f"  Python:    {env_info['python_version']}")
    print(f"  PyTorch:   {env_info['pytorch_version']}")
    print(f"  NumPy:     {env_info['numpy_version']}")
    
    print("\nHardware:")
    print(f"  CPUs:      {env_info['cpu_count']}")
    if env_info['cuda_available']:
        print(f"  CUDA:      {env_info['cuda_version']}")
        print(f"  cuDNN:     {env_info['cudnn_version']}")
        print(f"  GPUs:      {env_info['gpu_count']}")
        
        for gpu in env_info['gpu_devices']:
            print(f"\n  GPU {gpu['index']}: {gpu['name']}")
            print(f"    Memory:   {gpu['total_memory_gb']:.2f} GB")
            print(f"    Compute:  {gpu['compute_capability'][0]}.{gpu['compute_capability'][1]}")
    else:
        print("  CUDA:      Not Available")
    
    if 'system_memory' in env_info:
        mem = env_info['system_memory']
        print(f"\nMemory:")
        print(f"  Total:     {mem['total_gb']:.2f} GB")
        print(f"  Available: {mem['available_gb']:.2f} GB")
    
    print("="*70)

# ==================== Reproducibility Configuration ====================

def set_random_seeds(
    seed: int = 42,
    strict_determinism: bool = False,
    warn_performance: bool = True
) -> Dict[str, Any]:
    """
    Configure random seeds for reproducible experiments.
    
    Parameters
    ----------
    seed : int
        Random seed value.
    strict_determinism : bool
        Enable strict deterministic mode (may impact performance).
    warn_performance : bool
        Show performance warning for strict mode.
    
    Returns
    -------
    Dict[str, Any]
        Configuration status.
    """
    if not 0 <= seed <= 2**32 - 1:
        raise ValueError(f"Seed must be between 0 and 2^32-1")
    
    config_status = {
        'seed': seed,
        'strict_determinism': strict_determinism,
        'settings_applied': []
    }
    
    # Python random
    random.seed(seed)
    config_status['settings_applied'].append('python_random')
    
    # NumPy random
    np.random.seed(seed)
    config_status['settings_applied'].append('numpy_random')
    
    # PyTorch
    torch.manual_seed(seed)
    config_status['settings_applied'].append('torch_cpu')
    
    # Environment variable
    os.environ['PYTHONHASHSEED'] = str(seed)
    
    # CUDA
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        config_status['settings_applied'].append('torch_cuda')
        
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        
        if strict_determinism:
            if hasattr(torch, 'use_deterministic_algorithms'):
                try:
                    torch.use_deterministic_algorithms(True)
                    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
                    config_status['settings_applied'].append('strict_determinism')
                    
                    if warn_performance:
                        logger.warning(
                            "Strict determinism enabled. Performance may be impacted."
                        )
                except RuntimeError as e:
                    logger.warning(f"Could not enable strict determinism: {e}")
                    config_status['strict_determinism'] = False
    
    logger.info(f"Random seeds set: seed={seed}, strict={config_status['strict_determinism']}")
    
    return config_status

def verify_reproducibility(iterations: int = 5) -> Dict[str, bool]:
    """Verify reproducibility settings."""
    results = {}
    
    # Test PyTorch
    tensors = []
    for _ in range(iterations):
        torch.manual_seed(42)
        tensors.append(torch.randn(100))
    results['torch'] = all(torch.allclose(tensors[0], t) for t in tensors[1:])
    
    # Test NumPy
    arrays = []
    for _ in range(iterations):
        np.random.seed(42)
        arrays.append(np.random.randn(100))
    results['numpy'] = all(np.allclose(arrays[0], a) for a in arrays[1:])
    
    # Test CUDA if available
    if torch.cuda.is_available():
        cuda_tensors = []
        for _ in range(iterations):
            torch.cuda.manual_seed(42)
            cuda_tensors.append(torch.randn(100, device='cuda'))
        results['cuda'] = all(torch.allclose(cuda_tensors[0], t) for t in cuda_tensors[1:])
    
    return results

# ==================== Device Management ====================

def get_device(
    gpu_id: Optional[int] = None,
    gpu_memory_fraction: Optional[float] = None,
    min_memory_gb: float = 0.0,
    fallback_to_cpu: bool = True,
    verbose: bool = True
) -> torch.device:
    """
    Intelligently select and configure computation device.
    
    Parameters
    ----------
    gpu_id : Optional[int]
        Specific GPU to use. Auto-selects if None.
    gpu_memory_fraction : Optional[float]
        Fraction of GPU memory to allocate (0.0-1.0).
    min_memory_gb : float
        Minimum required GPU memory.
    fallback_to_cpu : bool
        Use CPU if no suitable GPU found.
    verbose : bool
        Print device information.
    
    Returns
    -------
    torch.device
        Selected computation device.
    """
    if not torch.cuda.is_available():
        device = torch.device('cpu')
        if verbose:
            logger.info("CUDA not available. Using CPU.")
        return device
    
    num_gpus = torch.cuda.device_count()
    
    if gpu_id is not None:
        if gpu_id >= num_gpus or gpu_id < 0:
            if fallback_to_cpu:
                logger.warning(f"Invalid gpu_id={gpu_id}. Using CPU.")
                return torch.device('cpu')
            raise ValueError(f"Invalid gpu_id={gpu_id}. Available: 0-{num_gpus-1}")
        device = torch.device(f'cuda:{gpu_id}')
    else:
        # Auto-select GPU with most free memory
        best_gpu = 0
        max_free = 0
        
        for i in range(num_gpus):
            props = torch.cuda.get_device_properties(i)
            free = props.total_memory - torch.cuda.memory_allocated(i)
            
            if free > max_free and props.total_memory / (1024**3) >= min_memory_gb:
                max_free = free
                best_gpu = i
        
        if max_free == 0 and min_memory_gb > 0:
            if fallback_to_cpu:
                logger.warning(f"No GPU with >={min_memory_gb}GB. Using CPU.")
                return torch.device('cpu')
            raise RuntimeError(f"No GPU with sufficient memory")
        
        device = torch.device(f'cuda:{best_gpu}')
    
    # Set memory fraction
    if gpu_memory_fraction and 0.0 < gpu_memory_fraction <= 1.0:
        torch.cuda.set_per_process_memory_fraction(gpu_memory_fraction, device)
        if verbose:
            logger.info(f"GPU memory limited to {gpu_memory_fraction*100:.0f}%")
    
    if verbose:
        idx = int(str(device).split(':')[-1]) if 'cuda' in str(device) else 0
        props = torch.cuda.get_device_properties(idx)
        logger.info(
            f"Using device: {device} ({props.name}, "
            f"{props.total_memory/(1024**3):.2f} GB)"
        )
    
    # Clear cache
    if 'cuda' in str(device):
        torch.cuda.empty_cache()
    
    return device

# ==================== Memory Management ====================

def get_memory_usage(device: torch.device) -> Dict[str, float]:
    """
    Get current memory usage statistics.
    
    Parameters
    ----------
    device : torch.device
        Device to check.
    
    Returns
    -------
    Dict[str, float]
        Memory statistics in GB.
    """
    stats = {}
    
    if device.type == 'cuda':
        idx = device.index if device.index is not None else 0
        
        stats['allocated_gb'] = torch.cuda.memory_allocated(idx) / (1024**3)
        stats['reserved_gb'] = torch.cuda.memory_reserved(idx) / (1024**3)
        stats['total_gb'] = torch.cuda.get_device_properties(idx).total_memory / (1024**3)
        stats['free_gb'] = stats['total_gb'] - stats['allocated_gb']
        stats['utilization_percent'] = (stats['allocated_gb'] / stats['total_gb']) * 100
        
    elif device.type == 'cpu' and OPTIONAL_DEPS.get('psutil', False):
        import psutil
        mem = psutil.virtual_memory()
        
        stats['used_gb'] = mem.used / (1024**3)
        stats['available_gb'] = mem.available / (1024**3)
        stats['total_gb'] = mem.total / (1024**3)
        stats['utilization_percent'] = mem.percent
    
    return stats

def clear_gpu_memory():
    """Clear GPU memory cache."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        logger.info("GPU memory cache cleared")

def estimate_batch_size(
    model: nn.Module,
    input_shape: Tuple[int, ...],
    device: torch.device,
    safety_factor: float = 0.9
) -> int:
    """
    Estimate optimal batch size based on available memory.
    
    Parameters
    ----------
    model : nn.Module
        Model to evaluate.
    input_shape : Tuple[int, ...]
        Shape of single input.
    device : torch.device
        Target device.
    safety_factor : float
        Safety margin.
    
    Returns
    -------
    int
        Recommended batch size.
    """
    if device.type != 'cuda':
        return 32  # Default for CPU
    
    # Get available memory
    stats = get_memory_usage(device)
    available_gb = stats['free_gb'] * safety_factor
    
    # Estimate memory per sample
    model = model.to(device)
    dummy_input = torch.randn(1, *input_shape, device=device)
    
    torch.cuda.reset_peak_memory_stats()
    with torch.no_grad():
        _ = model(dummy_input)
    
    memory_per_sample = torch.cuda.max_memory_allocated() / (1024**3)
    estimated = int(available_gb / memory_per_sample)
    
    # Round to power of 2
    batch_size = 2 ** int(np.log2(max(1, estimated)))
    
    return max(1, batch_size)

# ==================== Model Utilities ====================

def get_model_size(model: nn.Module) -> Dict[str, Any]:
    """
    Calculate model size and parameter count.
    
    Parameters
    ----------
    model : nn.Module
        PyTorch model.
    
    Returns
    -------
    Dict[str, Any]
        Model statistics.
    """
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    param_size = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024**2)
    buffer_size = sum(b.numel() * b.element_size() for b in model.buffers()) / (1024**2)
    
    return {
        'total_parameters': total_params,
        'trainable_parameters': trainable_params,
        'non_trainable_parameters': total_params - trainable_params,
        'model_size_mb': param_size + buffer_size
    }

# ==================== Testing and Verification ====================

def run_diagnostic_tests() -> Dict[str, Any]:
    """Run comprehensive diagnostic tests."""
    results = {
        'timestamp': datetime.now().isoformat(),
        'tests_passed': [],
        'tests_failed': [],
        'warnings': []
    }
    
    # Version compatibility
    try:
        compat = check_version_compatibility()
        if compat['status'] == 'compatible':
            results['tests_passed'].append('version_compatibility')
        else:
            results['tests_failed'].append('version_compatibility')
    except Exception as e:
        results['tests_failed'].append('version_compatibility')
        results['warnings'].append(str(e))
    
    # Device selection
    try:
        device = get_device(verbose=False)
        results['tests_passed'].append('device_selection')
        results['device'] = str(device)
    except Exception as e:
        results['tests_failed'].append('device_selection')
        results['warnings'].append(str(e))
    
    # Reproducibility
    try:
        set_random_seeds(42)
        repro = verify_reproducibility()
        if all(repro.values()):
            results['tests_passed'].append('reproducibility')
        else:
            results['tests_failed'].append('reproducibility')
    except Exception as e:
        results['tests_failed'].append('reproducibility')
    
    # Neural network
    try:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = nn.Sequential(
            nn.Linear(10, 20),
            nn.ReLU(),
            nn.Linear(20, 2)
        ).to(device)
        
        input_tensor = torch.randn(32, 10, device=device)
        output = model(input_tensor)
        
        if output.shape == (32, 2):
            results['tests_passed'].append('neural_network')
        else:
            results['tests_failed'].append('neural_network')
    except Exception as e:
        results['tests_failed'].append('neural_network')
    
    total = len(results['tests_passed']) + len(results['tests_failed'])
    results['success_rate'] = len(results['tests_passed']) / total if total > 0 else 0
    
    return results

# ==================== Main Execution ====================

if __name__ == "__main__":
    """Execute comprehensive environment verification."""
    
    print("\n" + "="*70)
    print("RNN Training Pipeline - Environment Setup")
    print(f"Version: {Config.VERSION}")
    print("="*70)
    
    # Version compatibility
    print("\nChecking version compatibility...")
    compat = check_version_compatibility()
    print_compatibility_report(compat)
    
    if compat['status'] == 'incompatible':
        print("\nCritical incompatibilities detected. Please fix before proceeding.")
        sys.exit(1)
    
    # Environment verification
    print("\nVerifying environment...")
    env_info = check_environment()
    print_environment_info(env_info)
    
    # Reproducibility
    print("\nConfiguring reproducibility...")
    config = set_random_seeds(42)
    print(f"Seeds configured: {', '.join(config['settings_applied'])}")
    
    # Device selection
    print("\nSelecting computation device...")
    device = get_device()
    
    # Run tests
    print("\nRunning diagnostic tests...")
    results = run_diagnostic_tests()
    
    print("\nTest Results:")
    print("-"*40)
    for test in results['tests_passed']:
        print(f"  [PASS] {test}")
    for test in results['tests_failed']:
        print(f"  [FAIL] {test}")
    
    print(f"\nSuccess Rate: {results['success_rate']*100:.1f}%")
    
    if results['success_rate'] == 1.0:
        print("\nEnvironment fully configured and verified.")
    else:
        print("\nEnvironment has issues. Check warnings above.")
    
    print("="*70)


RNN Training Pipeline - Environment Setup
Version: 1.0.0

Checking version compatibility...

Version Compatibility Report

Versions:
  python    : 3.11.7          (tested: 3.11.7)
  pytorch   : 2.3.0           (tested: 2.3.0)
  numpy     : 1.26.4          (tested: 1.26.4)
  cuda      : 12.1            (tested: 12.1)

Verifying environment...

PyTorch Environment Information

Software:
  Python:    3.11.7
  PyTorch:   2.3.0
  NumPy:     1.26.4

Hardware:
  CPUs:      20
  CUDA:      12.1
  cuDNN:     8801
  GPUs:      1

  GPU 0: NVIDIA GeForce RTX 4070 Ti SUPER
    Memory:   15.99 GB
    Compute:  8.9

Memory:
  Total:     47.76 GB
  Available: 18.71 GB

Configuring reproducibility...
2025-09-18 07:31:09 - RNN_Pipeline - INFO - set_random_seeds:444 - Random seeds set: seed=42, strict=False
2025-09-18 07:31:09 - RNN_Pipeline - INFO - set_random_seeds:444 - Random seeds set: seed=42, strict=False
Seeds configured: python_random, numpy_random, torch_cpu, torch_cuda

Selecting computation

# Module 2: Data Exploration and Analysis

This module provides comprehensive analysis of class distributions in binary classification datasets, with specialized support for temporal window-based data organization and imbalance detection. It automatically evaluates dataset characteristics and generates tailored recommendations for handling class imbalance.

In [None]:
"""
Dataset Distribution Analysis Module

This module provides comprehensive tools for analyzing class distributions
in binary classification datasets, with specialized support for temporal
window-based data organization and imbalance detection.

Author: Tian Gao
Date: 2025-09-16
Version: 1.0.0
License: MIT
"""

import os
import logging
import warnings
from pathlib import Path
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union, Any

import numpy as np

# ==================== Configuration ====================

class Config:
    """Module configuration constants."""
    VERSION = "1.0.0"
    AUTHOR = "Tian Gao"
    REQUIRED_NUMPY = "1.21.0"
    
    # Imbalance thresholds
    IMBALANCE_THRESHOLDS = {
        'balanced': 1.5,
        'mild': 3.0,
        'moderate': 10.0,
        'severe': float('inf')
    }
    
    # Severity scores
    SEVERITY_SCORES = {
        'balanced': 0.0,
        'mild': 0.33,
        'moderate': 0.67,
        'severe': 1.0
    }

# Configure logger
logging.basicConfig(level=logging.INFO, format='%(message)s')
logger = logging.getLogger(__name__)

# ==================== Data Classes ====================

@dataclass
class ImbalanceAnalysis:
    """Data class for imbalance analysis results."""
    minority_class: int
    minority_count: int
    majority_count: int
    imbalance_ratio: float
    severity: str
    severity_score: float
    strategy: str
    suggested_weights: Optional[List[float]]
    recommendations: Dict[str, Any]

@dataclass
class WindowAnalysis:
    """Data class for window-specific analysis results."""
    period: str
    total_samples: int
    class_distribution: List[int]
    class_proportions: List[float]
    imbalance_analysis: ImbalanceAnalysis
    data_shapes: Dict[str, Tuple[int, ...]]
    validation_status: Dict[str, Any]

# ==================== Main Analysis Class ====================

class DatasetAnalyzer:
    """
    Comprehensive dataset analyzer for class distribution and imbalance detection.
    
    Attributes:
        base_path: Base directory containing dataset files
        results: Dictionary storing analysis results for each window
    """
    
    def __init__(self, base_path: Union[str, Path] = None):
        """
        Initialize the dataset analyzer.
        
        Parameters
        ----------
        base_path : Union[str, Path], optional
            Path to the base data directory. Priority order:
            1. Explicitly provided path
            2. Environment variable DATA_PATH
            3. Current directory './data'
        """
        if base_path is not None:
            self.base_path = Path(base_path)
            logger.info(f"Using provided path: {self.base_path}")
        elif 'DATA_PATH' in os.environ:
            self.base_path = Path(os.environ['DATA_PATH'])
            logger.info(f"Using environment variable DATA_PATH: {self.base_path}")
        else:
            self.base_path = Path('./data')
            logger.info(f"Using default path: {self.base_path}")
        
        self.results = {}
        
        if not self.base_path.exists():
            warning_msg = (
                f"\n{'='*60}\n"
                f"WARNING: Data directory not found\n"
                f"{'='*60}\n"
                f"Path '{self.base_path.absolute()}' does not exist.\n\n"
                f"Please ensure your data is in the correct location:\n"
                f"  Option 1: Move data to: {self.base_path.absolute()}\n"
                f"  Option 2: Specify path: DatasetAnalyzer(base_path='your/path')\n"
                f"  Option 3: Set environment variable:\n"
                f"    - Linux/Mac: export DATA_PATH='/your/data/path'\n"
                f"    - Windows: set DATA_PATH=C:\\your\\data\\path\n"
                f"{'='*60}\n"
            )
            logger.warning(warning_msg)
            self._suggest_directory_structure()
    
    def _suggest_directory_structure(self):
        """Suggest the expected directory structure."""
        suggestion = (
            "Expected directory structure:\n"
            f"{self.base_path}/\n"
            "├── window-7/\n"
            "│   ├── data_0.npy\n"
            "│   └── data_1.npy\n"
            "├── window-14/\n"
            "│   ├── data_0.npy\n"
            "│   └── data_1.npy\n"
            "└── window-30/\n"
            "    ├── data_0.npy\n"
            "    └── data_1.npy\n"
        )
        print(suggestion)
    
    def validate_data_format(self, data: np.ndarray, filename: str) -> Dict[str, Any]:
        """
        Validate data format and quality.
        
        Parameters
        ----------
        data : np.ndarray
            NumPy array to validate
        filename : str
            Name of the source file
            
        Returns
        -------
        Dict[str, Any]
            Validation results and warnings
        """
        validation = {
            'valid': True,
            'warnings': [],
            'shape': data.shape,
            'dtype': str(data.dtype),
            'has_invalid': False,
            'unique_labels': None
        }
        
        if np.any(~np.isfinite(data)):
            validation['has_invalid'] = True
            validation['warnings'].append(f"Invalid values (NaN/Inf) detected in {filename}")
            validation['valid'] = False
        
        if data.ndim != 2:
            validation['warnings'].append(
                f"Expected 2D array, got {data.ndim}D array in {filename}"
            )
            validation['valid'] = False
        
        if data.shape[1] > 0:
            labels = data[:, -1]
            unique_labels = np.unique(labels)
            validation['unique_labels'] = unique_labels.tolist()
            
            if not set(unique_labels).issubset({0, 1, 0.0, 1.0}):
                validation['warnings'].append(
                    f"Non-binary labels found: {unique_labels.tolist()}"
                )
                validation['valid'] = False
        
        return validation
    
    def analyze_window(self, period: str) -> Optional[WindowAnalysis]:
        """
        Analyze a single time window dataset.
        
        Parameters
        ----------
        period : str
            Time window period in days (e.g., '7', '14', '30')
            
        Returns
        -------
        Optional[WindowAnalysis]
            Analysis results or None if analysis fails
        """
        window_path = self.base_path / f"window-{period}"
        
        try:
            data_0, data_1, validation_status = self._load_and_validate_window_data(window_path)
            
            if data_0 is None or data_1 is None:
                return None
            
            all_data = np.concatenate([data_0, data_1], axis=0)
            all_labels = all_data[:, -1].astype(int)
            
            class_counts = np.bincount(all_labels)
            
            if len(class_counts) < 2:
                logger.warning(f"Window {period}: Only one class found")
                return None
            
            total_samples = len(all_labels)
            class_proportions = (class_counts / total_samples * 100).tolist()
            
            imbalance_analysis = self._analyze_imbalance(class_counts)
            
            analysis = WindowAnalysis(
                period=period,
                total_samples=total_samples,
                class_distribution=class_counts.tolist(),
                class_proportions=class_proportions,
                imbalance_analysis=imbalance_analysis,
                data_shapes={
                    'data_0': data_0.shape,
                    'data_1': data_1.shape
                },
                validation_status=validation_status
            )
            
            return analysis
            
        except Exception as e:
            logger.error(f"Error analyzing window {period}: {str(e)}", exc_info=True)
            return None
    
    def analyze_all_windows(self, 
                          window_periods: List[str] = ['7', '14', '30'],
                          verbose: bool = True) -> Dict[str, Optional[WindowAnalysis]]:
        """
        Analyze all specified time windows.
        
        Parameters
        ----------
        window_periods : List[str]
            List of time window periods to analyze
        verbose : bool
            If True, print detailed progress
            
        Returns
        -------
        Dict[str, Optional[WindowAnalysis]]
            Analysis results for each window
        """
        if verbose:
            print("\n" + "="*70)
            print("Dataset Distribution Analysis")
            print(f"Base Path: {self.base_path.absolute()}")
            print("="*70)
        
        for period in window_periods:
            if verbose:
                print(f"\nAnalyzing window: {period} days")
                print("-"*40)
            
            analysis = self.analyze_window(period)
            
            if analysis:
                if verbose:
                    self._print_analysis(analysis)
                self.results[period] = analysis
            else:
                logger.error(f"Failed to analyze window {period}")
                self.results[period] = None
        
        if verbose:
            self._print_summary()
            self._print_implementation_guide()
            print("\n" + "="*70)
        
        return self.results
    
    def _load_and_validate_window_data(
        self, 
        window_path: Path
    ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Dict[str, Any]]:
        """Load and validate data files for a specific window."""
        data_0_path = window_path / "data_0.npy"
        data_1_path = window_path / "data_1.npy"
        
        validation_status = {'data_0': {}, 'data_1': {}}
        
        if not window_path.exists():
            logger.error(f"Window directory not found: {window_path}")
            return None, None, validation_status
        
        if not data_0_path.exists() or not data_1_path.exists():
            logger.error(f"Required data files not found in {window_path}")
            return None, None, validation_status
        
        try:
            data_0 = np.load(str(data_0_path))
            data_1 = np.load(str(data_1_path))
            
            validation_status['data_0'] = self.validate_data_format(data_0, 'data_0.npy')
            validation_status['data_1'] = self.validate_data_format(data_1, 'data_1.npy')
            
            for file_key in ['data_0', 'data_1']:
                if validation_status[file_key]['warnings']:
                    for warning in validation_status[file_key]['warnings']:
                        logger.warning(f"  {warning}")
            
            return data_0, data_1, validation_status
            
        except Exception as e:
            logger.error(f"Error loading data: {str(e)}")
            return None, None, validation_status
    
    def _analyze_imbalance(self, class_counts: np.ndarray) -> ImbalanceAnalysis:
        """Perform comprehensive imbalance analysis."""
        minority_class = int(np.argmin(class_counts))
        majority_class = int(np.argmax(class_counts))
        minority_count = int(class_counts[minority_class])
        majority_count = int(class_counts[majority_class])
        
        imbalance_ratio = float(majority_count / minority_count if minority_count > 0 else float('inf'))
        
        severity = self._get_severity(imbalance_ratio)
        severity_score = Config.SEVERITY_SCORES[severity]
        
        strategy, weights = self._get_strategy_and_weights(
            severity, imbalance_ratio, class_counts
        )
        
        recommendations = self._generate_recommendations(
            severity, imbalance_ratio, weights
        )
        
        return ImbalanceAnalysis(
            minority_class=minority_class,
            minority_count=minority_count,
            majority_count=majority_count,
            imbalance_ratio=imbalance_ratio,
            severity=severity,
            severity_score=severity_score,
            strategy=strategy,
            suggested_weights=weights,
            recommendations=recommendations
        )
    
    def _get_severity(self, ratio: float) -> str:
        """Determine imbalance severity based on ratio."""
        for severity, threshold in Config.IMBALANCE_THRESHOLDS.items():
            if ratio < threshold:
                return severity
        return 'severe'
    
    def _get_strategy_and_weights(
        self, 
        severity: str, 
        ratio: float, 
        class_counts: np.ndarray
    ) -> Tuple[str, Optional[List[float]]]:
        """Determine strategy and calculate weights."""
        strategies = {
            'balanced': (
                "No special handling required",
                None
            ),
            'mild': (
                "Apply class weight adjustment",
                self._calculate_weights(class_counts, 'inverse')
            ),
            'moderate': (
                "Combine weights with balanced sampling",
                self._calculate_weights(class_counts, 'sqrt')
            ),
            'severe': (
                "Implement comprehensive mitigation",
                self._calculate_weights(class_counts, 'effective')
            )
        }
        
        return strategies.get(severity, ("Unknown", None))
    
    def _calculate_weights(self, class_counts: np.ndarray, method: str = 'inverse') -> List[float]:
        """Calculate class weights using various methods."""
        n_samples = np.sum(class_counts)
        n_classes = len(class_counts)
        
        if method == 'inverse':
            weights = n_samples / (n_classes * class_counts)
            
        elif method == 'sqrt':
            weights = np.sqrt(n_samples / (n_classes * class_counts))
            
        elif method == 'effective':
            beta = 0.999
            effective_num = 1.0 - np.power(beta, class_counts)
            weights = (1.0 - beta) / effective_num
            weights = weights / weights.sum() * n_classes
            
        else:
            weights = np.ones(n_classes)
        
        return weights.tolist()
    
    def _generate_recommendations(
        self, 
        severity: str, 
        ratio: float,
        weights: Optional[List[float]]
    ) -> Dict[str, Any]:
        """Generate detailed recommendations."""
        base_rec = {
            'use_class_weights': severity != 'balanced',
            'class_weights': weights,
            'use_balanced_sampling': severity in ['moderate', 'severe'],
            'use_stratified_splits': True,
            'suggested_loss_function': 'CrossEntropyLoss' if severity != 'severe' else 'FocalLoss',
            'evaluation_metrics': ['F1-score', 'MCC', 'Balanced Accuracy'],
            'min_samples_per_batch': max(2, int(100 / ratio)) if ratio > 1 else 50
        }
        
        if severity == 'mild':
            base_rec['additional_recommendations'] = [
                'Monitor per-class metrics',
                'Use stratified k-fold CV',
                'Select best model based on validation F1-score'
            ]
            
        elif severity == 'moderate':
            base_rec.update({
                'use_ensemble_methods': True,
                'sampling_strategy': 'WeightedRandomSampler',
                'additional_recommendations': [
                    'Balanced batch sampling',
                    'Learning rate scheduling',
                    'Threshold optimization'
                ]
            })
            
        elif severity == 'severe':
            base_rec.update({
                'use_data_augmentation': True,
                'consider_oversampling': 'SMOTE/ADASYN',
                'consider_undersampling': 'Tomek Links',
                'additional_recommendations': [
                    'Cost-sensitive learning',
                    'Focal loss or class-balanced loss',
                    'Ensemble methods',
                    'Anomaly detection approaches'
                ],
                'warning': f'Severe imbalance ({ratio:.1f}:1). Comprehensive strategy required.'
            })
        
        return base_rec
    
    def _print_analysis(self, analysis: WindowAnalysis):
        """Print formatted analysis results."""
        print(f"Total samples: {analysis.total_samples:,}")
        print(f"Class distribution: {analysis.class_distribution}")
        print(f"Class proportions: {', '.join([f'{p:.1f}%' for p in analysis.class_proportions])}")
        print(f"Data shapes: {analysis.data_shapes}")
        
        imb = analysis.imbalance_analysis
        print(f"\nImbalance Analysis:")
        print(f"  Ratio: {imb.imbalance_ratio:.2f}:1")
        print(f"  Severity: {imb.severity.upper()}")
        print(f"  Strategy: {imb.strategy}")
        
        if imb.suggested_weights:
            weights_str = ", ".join([f"{w:.3f}" for w in imb.suggested_weights])
            print(f"  Weights: [{weights_str}]")
    
    def _print_summary(self):
        """Print analysis summary."""
        valid_results = [r for r in self.results.values() if r is not None]
        
        if not valid_results:
            return
        
        print("\n" + "="*70)
        print("Summary")
        print("="*70)
        
        most_imbalanced = max(
            valid_results,
            key=lambda x: x.imbalance_analysis.severity_score
        )
        
        print(f"Windows analyzed: {len(valid_results)}")
        print(f"Most imbalanced: Window {most_imbalanced.period} "
              f"({most_imbalanced.imbalance_analysis.imbalance_ratio:.2f}:1)")
        
        avg_severity = np.mean([
            r.imbalance_analysis.severity_score 
            for r in valid_results
        ])
        
        if avg_severity < 0.3:
            print("Assessment: Well balanced")
        elif avg_severity < 0.6:
            print("Assessment: Moderate imbalance")
        else:
            print("Assessment: Significant imbalance")
    
    def _print_implementation_guide(self):
        """Print PyTorch implementation guide."""
        print("\n" + "="*70)
        print("Implementation Guide")
        print("="*70)
        
        for period, analysis in self.results.items():
            if analysis:
                print(f"\n# Window {period} days:")
                
                rec = analysis.imbalance_analysis.recommendations
                weights = analysis.imbalance_analysis.suggested_weights
                
                if rec['use_class_weights'] and weights:
                    print(f"weights = torch.tensor({weights})")
                    print(f"criterion = nn.{rec['suggested_loss_function']}(weight=weights)")
                
                if rec['use_balanced_sampling']:
                    print("# Use WeightedRandomSampler for balanced batches")
    
    def get_results_as_dict(self) -> Dict:
        """Get results as dictionary for easy access."""
        output = {}
        for period, analysis in self.results.items():
            if analysis:
                output[period] = {
                    'total': analysis.total_samples,
                    'distribution': analysis.class_distribution,
                    'ratio': f"{analysis.imbalance_analysis.imbalance_ratio:.2f}:1",
                    'severity': analysis.imbalance_analysis.severity,
                    'weights': analysis.imbalance_analysis.suggested_weights
                }
        return output

# ==================== Convenience Functions ====================

def quick_analyze(base_path: Union[str, Path] = None, 
                 window_periods: List[str] = ['7', '14', '30']) -> Dict[str, Optional[WindowAnalysis]]:
    """
    Quick analysis function for immediate use.
    
    Parameters
    ----------
    base_path : Union[str, Path], optional
        Path to data directory
    window_periods : List[str]
        Window periods to analyze
    
    Returns
    -------
    Dict[str, Optional[WindowAnalysis]]
        Analysis results
    
    Examples
    --------
    >>> # Quick analysis with default settings
    >>> results = quick_analyze()
    
    >>> # Specify custom path
    >>> results = quick_analyze(base_path="/path/to/data")
    
    >>> # Analyze specific windows only
    >>> results = quick_analyze(window_periods=['7', '30'])
    """
    analyzer = DatasetAnalyzer(base_path="your_project/data")
    # Change to your actual data directory path
    # For example, if your project is at: C:/Users/Tian/Desktop/地磁论文代码运行测试
    # Then change to: analyzer = DatasetAnalyzer(base_path="C:/Users/Tian/Desktop/地磁论文代码运行测试/data")
    results = analyzer.analyze_all_windows(window_periods=window_periods, verbose=True)
    return results

# ==================== Main Execution ====================

if __name__ == "__main__":
    """Module testing and demonstration."""
    
    print("\n" + "="*70)
    print("Dataset Distribution Analysis Module")
    print(f"Version: {Config.VERSION}")
    print("="*70)
    
    print("\nUsage Instructions:")
    print("-" * 40)
    print("1. Direct path specification:")
    print('   analyzer = DatasetAnalyzer(base_path="./your_data_folder")')
    print("\n2. Using environment variable:")
    print("   Linux/Mac: export DATA_PATH=/path/to/data")
    print("   Windows:   set DATA_PATH=C:\\path\\to\\data")
    print("   Then:      analyzer = DatasetAnalyzer()")
    print("\n3. Quick analysis:")
    print("   results = quick_analyze(base_path=\"./data\")")
    print("-" * 40)
    
    print("\nStarting analysis...")
    
    import sys
    if len(sys.argv) > 1:
        data_path = sys.argv[1]
        print(f"Using command line path: {data_path}")
    else:
        data_path = None
    
    results = quick_analyze(base_path=data_path)
    
    if results and any(r is not None for r in results.values()):
        print("\nAnalysis complete!")
    else:
        print("\nNo valid results. Please check your data directory and structure.")


Dataset Distribution Analysis Module
Version: 1.0.0

Usage Instructions:
----------------------------------------
1. Direct path specification:
   analyzer = DatasetAnalyzer(base_path="./your_data_folder")

2. Using environment variable:
   Linux/Mac: export DATA_PATH=/path/to/data
   Windows:   set DATA_PATH=C:\path\to\data
   Then:      analyzer = DatasetAnalyzer()

3. Quick analysis:
   results = quick_analyze(base_path="./data")
----------------------------------------

Starting analysis...
Using command line path: --f=c:\Users\Tian\AppData\Roaming\jupyter\runtime\kernel-v35240284cbf5c642c864751e8e979de0e2f7e7282.json
2025-09-18 07:31:09 - __main__ - INFO - __init__:102 - Using provided path: C:\Users\Tian\Desktop\地磁论文代码运行测试\data

Dataset Distribution Analysis
Base Path: C:\Users\Tian\Desktop\地磁论文代码运行测试\data

Analyzing window: 7 days
----------------------------------------
Total samples: 2,660
Class distribution: [1661, 999]
Class proportions: 62.4%, 37.6%
Data shapes: {'data_0':

# Module 3: Training Pipeline Utilities

This module provides essential utility functions for deep learning model training, including device configuration, data processing, model initialization, and checkpoint management. It supports multi-GPU training and implements various strategies for handling imbalanced datasets.

In [22]:
"""
Training Pipeline Utility Functions

This module provides essential utility functions for deep learning experiments,
including device configuration, data processing, model initialization, and
training utilities with comprehensive multi-GPU support.

Author: Tian Gao
Date: 2025-09-16
Version: 1.0.0
License: MIT
"""

import copy
import json
import logging
import os
import random
import sys
import time
import warnings
from dataclasses import dataclass, asdict
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import platform

import numpy as np
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    f1_score,
    matthews_corrcoef,
    precision_score,
    recall_score
)
from sklearn.model_selection import StratifiedKFold

import torch
import torch.nn as nn
import torch.nn.functional as F 
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import (
    DataLoader,
    Dataset, 
    Subset,
    TensorDataset,
    WeightedRandomSampler
)

# ==================== Module Configuration ====================

class Config:
    """Module configuration constants."""
    VERSION = "1.0.0"
    AUTHOR = "Tian Gao"
    
    # Default directories - can be overridden by environment variables
    DEFAULT_CHECKPOINT_DIR = os.environ.get('CHECKPOINT_DIR', './checkpoints')
    DEFAULT_LOG_DIR = os.environ.get('LOG_DIR', './logs')
    MAX_WORKERS = 16  # Maximum DataLoader workers

# ==================== Setup and Configuration Functions ====================

def setup_logging(
    log_file: str = 'training.log',
    log_level: int = logging.INFO,
    log_dir: Optional[Union[str, Path]] = None,
    include_timestamp: bool = True
) -> logging.Logger:
    """
    Configure comprehensive logging system for training pipeline.
    
    Parameters
    ----------
    log_file : str
        Name of the log file
    log_level : int
        Logging level (logging.DEBUG, INFO, WARNING, ERROR)
    log_dir : Optional[Union[str, Path]]
        Directory for log files (creates if not exists)
    include_timestamp : bool
        Whether to include timestamp in log filename
        
    Returns
    -------
    logging.Logger
        Configured logger instance
    """
    # Handle log directory
    if log_dir:
        log_path = Path(log_dir)
    else:
        log_path = Path(Config.DEFAULT_LOG_DIR)
    log_path.mkdir(parents=True, exist_ok=True)
    
    # Add timestamp to filename if requested
    if include_timestamp:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        log_file_name = f"{Path(log_file).stem}_{timestamp}{Path(log_file).suffix}"
    else:
        log_file_name = log_file
    
    full_log_path = log_path / log_file_name
    
    # Configure logging format
    log_format = '%(asctime)s - %(name)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s'
    date_format = '%Y-%m-%d %H:%M:%S'
    
    # Clear existing handlers
    logger = logging.getLogger()
    logger.handlers = []
    
    # Setup basic configuration
    logging.basicConfig(
        level=log_level,
        format=log_format,
        datefmt=date_format,
        handlers=[
            logging.StreamHandler(sys.stdout),
            logging.FileHandler(full_log_path, encoding='utf-8')
        ]
    )
    
    logger = logging.getLogger(__name__)
    logger.info(f"Logging initialized. Output file: {full_log_path}")
    
    return logger

def setup_device(
    gpu_ids: Optional[List[int]] = None,
    verbose: bool = True
) -> Tuple[torch.device, Optional[List[int]]]:
    """
    Configure optimal device setup for training.
    
    Parameters
    ----------
    gpu_ids : Optional[List[int]]
        Specific GPU IDs to use. If None, uses all available GPUs
    verbose : bool
        Whether to log detailed device information
        
    Returns
    -------
    Tuple[torch.device, Optional[List[int]]]
        (primary_device, gpu_ids_for_dataparallel)
    """
    if not torch.cuda.is_available():
        if verbose:
            logging.info("CUDA not available. Using CPU.")
        return torch.device('cpu'), None
    
    # Get available GPU count
    available_gpus = torch.cuda.device_count()
    
    if gpu_ids is None:
        gpu_ids = list(range(available_gpus))
        if verbose:
            logging.info(f"Auto-detected {available_gpus} GPU(s)")
    else:
        # Validate specified GPU IDs
        invalid_ids = [id for id in gpu_ids if id >= available_gpus]
        if invalid_ids:
            raise ValueError(f"Invalid GPU IDs {invalid_ids}. Available: 0-{available_gpus-1}")
        gpu_ids = [id for id in gpu_ids if id < available_gpus]
    
    # Configure primary device
    primary_device = torch.device(f'cuda:{gpu_ids[0]}')
    torch.cuda.set_device(gpu_ids[0])
    
    if verbose:
        if len(gpu_ids) > 1:
            logging.info(f"Multi-GPU training with devices: {gpu_ids}")
        else:
            logging.info(f"Single GPU training on device {gpu_ids[0]}")
    
    return primary_device, gpu_ids if len(gpu_ids) > 1 else None

def set_seed(seed: int = 42, strict: bool = False) -> None:
    """
    Set random seeds for reproducibility.
    
    Parameters
    ----------
    seed : int
        Random seed value
    strict : bool
        If True, enables strict determinism (may reduce performance)
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    
    if strict:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        if hasattr(torch, 'use_deterministic_algorithms'):
            torch.use_deterministic_algorithms(True)
            os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
    else:
        torch.backends.cudnn.deterministic = False
        torch.backends.cudnn.benchmark = True
    
    logging.info(f"Random seeds set to {seed} (strict={strict})")

# ==================== Data Processing Functions ====================

def calculate_class_weights(
    labels: Union[np.ndarray, torch.Tensor],
    method: str = 'sqrt',
    normalize: bool = True
) -> torch.FloatTensor:
    """
    Calculate class weights for imbalanced datasets.
    
    Parameters
    ----------
    labels : Union[np.ndarray, torch.Tensor]
        Array of class labels
    method : str
        Weight calculation method ('inverse', 'sqrt', 'effective', 'balanced')
    normalize : bool
        Whether to normalize weights to sum to num_classes
        
    Returns
    -------
    torch.FloatTensor
        Class weights for loss function
    """
    # Convert to numpy if needed
    if isinstance(labels, torch.Tensor):
        labels = labels.cpu().numpy()
    
    class_counts = np.bincount(labels.astype(int))
    n_samples = len(labels)
    n_classes = len(class_counts)
    
    # Calculate imbalance ratio
    imbalance_ratio = class_counts.max() / class_counts.min()
    logging.info(f"Class distribution: {class_counts.tolist()}")
    logging.info(f"Imbalance ratio: {imbalance_ratio:.2f}:1")
    
    # Calculate weights based on method
    if method == 'inverse':
        weights = n_samples / (n_classes * class_counts)
    elif method == 'sqrt':
        weights = np.sqrt(n_samples / (n_classes * class_counts))
    elif method == 'effective':
        beta = 0.999
        effective_num = 1.0 - np.power(beta, class_counts)
        weights = (1.0 - beta) / np.array(effective_num)
    elif method == 'balanced':
        # sklearn-style balanced weights
        weights = n_samples / (n_classes * class_counts)
    else:
        raise ValueError(f"Unknown method: {method}. Use 'inverse', 'sqrt', 'effective', or 'balanced'")
    
    # Normalize if requested
    if normalize:
        weights = weights / weights.sum() * n_classes
    
    weights_tensor = torch.FloatTensor(weights)
    logging.info(f"Class weights ({method}): {weights_tensor.tolist()}")
    
    return weights_tensor

def create_balanced_sampler(
    dataset: Union[TensorDataset, Subset],
    oversample: bool = True,
    replacement: bool = True
) -> WeightedRandomSampler:
    """
    Create a weighted sampler for balanced training.
    
    Parameters
    ----------
    dataset : Union[TensorDataset, Subset]
        PyTorch dataset
    oversample : bool
        If True, oversample minority class
    replacement : bool
        Whether to sample with replacement
        
    Returns
    -------
    WeightedRandomSampler
        Configured balanced sampler
    """
    # Extract labels
    if isinstance(dataset, TensorDataset):
        labels = dataset.tensors[1].cpu().numpy()
    elif isinstance(dataset, Subset):
        labels = dataset.dataset.tensors[1][dataset.indices].cpu().numpy()
    else:
        raise TypeError(f"Unsupported dataset type: {type(dataset)}")
    
    # Calculate sample weights
    unique_classes, class_counts = np.unique(labels, return_counts=True)
    class_weights = 1.0 / class_counts
    sample_weights = np.array([class_weights[label] for label in labels])
    
    # Determine number of samples
    num_samples = len(labels) if oversample else int(class_counts.min() * len(unique_classes))
    
    logging.info(f"Balanced sampler: {'Oversampling' if oversample else 'Undersampling'}")
    
    return WeightedRandomSampler(
        weights=sample_weights,
        num_samples=num_samples,
        replacement=replacement
    )

def reshape_data_for_rnn(
    X: Union[np.ndarray, torch.Tensor],
    seq_length: int,
    validate: bool = True
) -> Union[np.ndarray, torch.Tensor]:
    """
    Reshape data for RNN input format.
    
    Parameters
    ----------
    X : Union[np.ndarray, torch.Tensor]
        Input data (batch_size, total_features)
    seq_length : int
        Number of time steps
    validate : bool
        Whether to validate dimensions
        
    Returns
    -------
    Union[np.ndarray, torch.Tensor]
        Reshaped data (batch_size, seq_length, features_per_timestep)
    """
    is_tensor = isinstance(X, torch.Tensor)
    if is_tensor:
        device = X.device
        X_np = X.cpu().numpy() if X.is_cuda else X.numpy()
    else:
        X_np = X
    
    batch_size, total_features = X_np.shape
    
    if validate and total_features % seq_length != 0:
        raise ValueError(f"Total features ({total_features}) must be divisible by seq_length ({seq_length})")
    
    features_per_timestep = total_features // seq_length
    X_reshaped = X_np.reshape(batch_size, seq_length, features_per_timestep)
    
    if is_tensor:
        X_reshaped = torch.from_numpy(X_reshaped).to(device)
    
    logging.info(f"Data reshaped: ({batch_size}, {total_features}) -> ({batch_size}, {seq_length}, {features_per_timestep})")
    
    return X_reshaped

# ==================== Data Loading Functions ====================

def create_data_loaders(
    dataset: Dataset,
    batch_size: int,
    num_gpus: int = 1,
    num_workers: Optional[int] = None,
    pin_memory: bool = True,
    shuffle: bool = True,
    drop_last: bool = False
) -> DataLoader:
    """
    Create standard DataLoader with multi-GPU optimization.
    
    Parameters
    ----------
    dataset : Dataset
        PyTorch Dataset
    batch_size : int
        Base batch size per GPU
    num_gpus : int
        Number of GPUs
    num_workers : Optional[int]
        Worker processes (auto-calculated if None)
    pin_memory : bool
        Pin memory for GPU transfer
    shuffle : bool
        Shuffle data each epoch
    drop_last : bool
        Drop incomplete last batch
        
    Returns
    -------
    DataLoader
        Configured DataLoader
    """
    effective_batch_size = batch_size * max(1, num_gpus)
    
    if num_workers is None:
        num_workers = min(4 * max(1, num_gpus), Config.MAX_WORKERS)
        if len(dataset) < 1000:
            num_workers = min(num_workers, 2)
    
    logging.info(f"DataLoader: batch_size={effective_batch_size}, workers={num_workers}")
    
    return DataLoader(
        dataset,
        batch_size=effective_batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory and torch.cuda.is_available(),
        drop_last=drop_last
    )

def create_data_loaders_balanced(
    dataset: Dataset,
    batch_size: int,
    num_gpus: int = 1,
    is_train: bool = True,
    num_workers: Optional[int] = None,
    pin_memory: bool = True,
    oversample: bool = True
) -> DataLoader:
    """
    Create DataLoader with balanced sampling.
    
    Parameters
    ----------
    dataset : Dataset
        PyTorch Dataset
    batch_size : int
        Base batch size per GPU
    num_gpus : int
        Number of GPUs
    is_train : bool
        If True, use balanced sampling
    num_workers : Optional[int]
        Worker processes
    pin_memory : bool
        Pin memory for GPU transfer
    oversample : bool
        Whether to oversample minority class
        
    Returns
    -------
    DataLoader
        DataLoader with balanced sampling
    """
    effective_batch_size = batch_size * max(1, num_gpus)
    
    if num_workers is None:
        num_workers = min(4 * max(1, num_gpus), Config.MAX_WORKERS)
        if len(dataset) < 1000:
            num_workers = min(num_workers, 2)
    
    if is_train:
        try:
            sampler = create_balanced_sampler(dataset, oversample=oversample)
            
            return DataLoader(
                dataset,
                batch_size=effective_batch_size,
                sampler=sampler,
                num_workers=num_workers,
                pin_memory=pin_memory and torch.cuda.is_available(),
                drop_last=True
            )
        except (TypeError, ValueError) as e:
            logging.warning(f"Cannot create balanced sampler: {e}")
            return create_data_loaders(
                dataset, batch_size, num_gpus, num_workers,
                pin_memory, shuffle=True, drop_last=True
            )
    else:
        return create_data_loaders(
            dataset, batch_size, num_gpus, num_workers,
            pin_memory, shuffle=False, drop_last=False
        )

def create_balanced_dataloader(
    dataset: Union[TensorDataset, Subset],
    batch_size: int,
    is_train: bool = True
) -> DataLoader:
    """
    Create dataloader with balanced sampling for training.
    
    Simple wrapper for backward compatibility.
    """
    if is_train:
        # Extract labels for weighted sampling
        if isinstance(dataset, Subset):
            all_labels = dataset.dataset.tensors[1][dataset.indices].numpy()
        else:
            all_labels = dataset.tensors[1].numpy()
        
        # Calculate sample weights
        class_counts = np.bincount(all_labels.astype(int))
        class_weights = 1.0 / class_counts
        sample_weights = class_weights[all_labels]
        
        # Create weighted sampler
        sampler = WeightedRandomSampler(
            weights=sample_weights,
            num_samples=len(sample_weights),
            replacement=True
        )
        
        return DataLoader(
            dataset,
            batch_size=batch_size,
            sampler=sampler,
            num_workers=0,
            pin_memory=torch.cuda.is_available()
        )
    else:
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=0,
            pin_memory=torch.cuda.is_available()
        )

# ==================== Model Operations ====================

def create_model(
    model_class: type,
    input_size: int,
    hidden_sizes: List[int],
    output_size: int,
    dropout_prob: float,
    device: torch.device,
    device_ids: Optional[List[int]] = None,
    **kwargs
) -> nn.Module:
    """
    Create and initialize model with multi-GPU support.
    
    Parameters
    ----------
    model_class : type
        Model class to instantiate
    input_size : int
        Number of input features
    hidden_sizes : List[int]
        Hidden layer dimensions
    output_size : int
        Number of output classes
    dropout_prob : float
        Dropout probability
    device : torch.device
        Primary device
    device_ids : Optional[List[int]]
        GPU IDs for DataParallel
    **kwargs
        Additional model arguments
        
    Returns
    -------
    nn.Module
        Initialized model
    """
    model = model_class(
        input_size=input_size,
        hidden_sizes=hidden_sizes,
        output_size=output_size,
        dropout_prob=dropout_prob,
        **kwargs
    )
    
    total_params = count_parameters(model, trainable_only=False)
    logging.info(f"Model created: {model_class.__name__}")
    logging.info(f"Parameters: {total_params:,}")
    
    if device_ids and len(device_ids) > 1:
        model = nn.DataParallel(model, device_ids=device_ids)
        logging.info(f"DataParallel enabled: {device_ids}")
    
    return model.to(device)

def accuracy(
    outputs: torch.Tensor, 
    labels: torch.Tensor,
    return_correct: bool = False
) -> Union[float, Tuple[float, int]]:
    """Calculate classification accuracy."""
    with torch.no_grad():
        _, classifications = torch.max(outputs, dim=1)
        correct = (classifications == labels).float()
        acc = correct.mean().item()
        
        if return_correct:
            return acc, int(correct.sum().item())
        return acc

def count_parameters(model: nn.Module, trainable_only: bool = True) -> int:
    """Count model parameters."""
    if trainable_only:
        return sum(p.numel() for p in model.parameters() if p.requires_grad)
    return sum(p.numel() for p in model.parameters())

# ==================== Checkpoint Management ====================

def save_checkpoint(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    epoch: int,
    loss: float,
    checkpoint_dir: Union[str, Path] = None,
    filename: Optional[str] = None,
    additional_info: Optional[Dict[str, Any]] = None
) -> Path:
    """
    Save complete training checkpoint.
    
    Parameters
    ----------
    model : nn.Module
        Model to save
    optimizer : torch.optim.Optimizer
        Optimizer state
    epoch : int
        Current epoch
    loss : float
        Current loss
    checkpoint_dir : Union[str, Path]
        Directory for checkpoints
    filename : Optional[str]
        Custom filename
    additional_info : Optional[Dict[str, Any]]
        Additional data to save
        
    Returns
    -------
    Path
        Path to saved checkpoint
    """
    if checkpoint_dir is None:
        checkpoint_dir = Config.DEFAULT_CHECKPOINT_DIR
    
    checkpoint_dir = Path(checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    
    if filename is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"checkpoint_epoch_{epoch:03d}_{timestamp}.pth"
    
    checkpoint_path = checkpoint_dir / filename
    
    # Handle DataParallel
    if isinstance(model, nn.DataParallel):
        model_state_dict = model.module.state_dict()
    else:
        model_state_dict = model.state_dict()
    
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model_state_dict,
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'timestamp': datetime.now().isoformat()
    }
    
    if additional_info:
        checkpoint.update(additional_info)
    
    torch.save(checkpoint, checkpoint_path)
    logging.info(f"Checkpoint saved: {checkpoint_path}")
    
    return checkpoint_path

def load_checkpoint(
    checkpoint_path: Union[str, Path],
    model: nn.Module,
    optimizer: Optional[torch.optim.Optimizer] = None,
    device: torch.device = torch.device('cpu'),
    strict: bool = True
) -> Tuple[nn.Module, Optional[torch.optim.Optimizer], int, float, Dict[str, Any]]:
    """
    Load model checkpoint.
    
    Parameters
    ----------
    checkpoint_path : Union[str, Path]
        Path to checkpoint
    model : nn.Module
        Model to load into
    optimizer : Optional[torch.optim.Optimizer]
        Optimizer to restore
    device : torch.device
        Device to map to
    strict : bool
        Strict state dict matching
        
    Returns
    -------
    Tuple
        (model, optimizer, epoch, loss, additional_info)
    """
    checkpoint_path = Path(checkpoint_path)
    
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    
    checkpoint = torch.load(checkpoint_path, map_location=device)
    
    # Handle DataParallel
    if isinstance(model, nn.DataParallel):
        model.module.load_state_dict(checkpoint['model_state_dict'], strict=strict)
    else:
        model.load_state_dict(checkpoint['model_state_dict'], strict=strict)
    
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)
    
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    
    additional_info = {
        k: v for k, v in checkpoint.items() 
        if k not in ['epoch', 'model_state_dict', 'optimizer_state_dict', 'loss']
    }
    
    logging.info(f"Checkpoint loaded: {checkpoint_path} (epoch {epoch}, loss {loss:.4f})")
    
    return model, optimizer, epoch, loss, additional_info

# ==================== Metrics Evaluation ====================

def calculate_binary_metrics(
    y_true: np.ndarray,
    y_class: np.ndarray
) -> Dict[str, float]:
    """Calculate comprehensive binary classification metrics."""
    # Core metrics
    f1 = f1_score(y_true, y_class, average='weighted', zero_division=0)
    precision = precision_score(y_true, y_class, average='weighted', zero_division=0)
    recall = recall_score(y_true, y_class, average='weighted', zero_division=0)
    
    # Confusion matrix for specificity
    cm = confusion_matrix(y_true, y_class, labels=[0, 1])
    specificity = 0
    if cm.shape == (2, 2):
        tn, fp, fn, tp = cm.ravel()
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    
    # Matthews Correlation Coefficient
    mcc = matthews_corrcoef(y_true, y_class)
    norm_mcc = (mcc + 1) / 2  # Normalize to [0, 1]
    
    # Per-class accuracy
    class_acc = np.zeros(2)
    for i in range(2):
        mask = (y_true == i)
        if mask.sum() > 0:
            class_acc[i] = (y_class[mask] == i).mean()
    
    return {
        'f1': f1,
        'precision': precision,
        'recall': recall,
        'specificity': specificity,
        'mcc': mcc,
        'norm_mcc': norm_mcc,
        'class_acc': class_acc
    }

# ==================== Epoch Execution ====================

def run_epoch(
    model: nn.Module,
    data_loader: DataLoader,
    device: torch.device,
    loss_fn: nn.Module,
    optimizer: Optional[torch.optim.Optimizer] = None,
    max_grad_norm: Optional[float] = 1.0,
    is_training: bool = False
) -> Dict[str, Any]:
    """Execute single training or validation epoch."""
    if is_training and optimizer is None:
        raise ValueError("Optimizer required for training")
    
    # Set model mode
    model.train() if is_training else model.eval()
    
    # Initialize tracking
    total_loss = 0.0
    all_probs = []
    all_labels = []
    all_classes = []
    num_batches = 0
    
    # Process batches
    with torch.set_grad_enabled(is_training):
        for inputs, labels in data_loader:
            # Skip single-sample batches (BatchNorm issue)
            if inputs.size(0) <= 1:
                continue
            
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            # Forward pass
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            
            # Backward pass (training only)
            if is_training:
                optimizer.zero_grad()
                loss.backward()
                
                # Gradient clipping
                if max_grad_norm:
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(),
                        max_grad_norm
                    )
                
                optimizer.step()
            
            # Store classification outputs
            with torch.no_grad():
                probs = torch.softmax(outputs, dim=1)
                classes = torch.argmax(outputs, dim=1)
                
                all_probs.append(probs.cpu().numpy())
                all_classes.append(classes.cpu().numpy())
                all_labels.append(labels.cpu().numpy())
            
            total_loss += loss.item()
            num_batches += 1
    
    # Handle empty epoch
    if num_batches == 0:
        return {
            'loss': 0.0,
            'acc': 0.0,
            'metrics': {
                'f1': 0.0,
                'precision': 0.0,
                'recall': 0.0,
                'specificity': 0.0,
                'mcc': 0.0,
                'norm_mcc': 0.0,
                'class_acc': np.zeros(2)
            },
            'all_probs': np.array([]),
            'all_classes': np.array([]),
            'all_labels': np.array([])
        }
    
    # Aggregate results
    all_probs = np.concatenate(all_probs)
    all_classes = np.concatenate(all_classes)
    all_labels = np.concatenate(all_labels)
    
    # Calculate metrics
    avg_loss = total_loss / num_batches
    avg_acc = (all_classes == all_labels).mean()
    metrics = calculate_binary_metrics(all_labels, all_classes)
    
    return {
        'loss': avg_loss,
        'acc': avg_acc,
        'metrics': metrics,
        'all_probs': all_probs,
        'all_classes': all_classes,
        'all_labels': all_labels
    }

# Module 4: Core Model Classes

Neural network architectures for seismic geomagnetic signal classification, implementing RNN-based temporal models with F1-weighted ensemble methods.


In [None]:
# ==================== Core Model Classes ====================

"""
Neural Network Model Definitions for Geomagnetic Disturbance Prediction

This module implements RNN-based architectures with ensemble methods
for robust temporal sequence classification.

Author: Tian Gao
Date: 2025/09/16
Version: 1.0.0
License: MIT
"""

# ==================== Module Configuration ====================

__version__ = "1.0.0"
__author__ = "Tian Gao"

# Configure logger
logger = logging.getLogger(__name__)

# ==================== Base Model Class ====================

class BaseModel(nn.Module):
    """
    Abstract base class providing common functionality for all models.
    """
    
    def __init__(self):
        super().__init__()
        self._device = None
        self.config = {}
    
    def count_parameters(self, trainable_only: bool = True) -> int:
        """Count model parameters."""
        if trainable_only:
            return sum(p.numel() for p in self.parameters() if p.requires_grad)
        return sum(p.numel() for p in self.parameters())
    
    def get_device(self) -> torch.device:
        """Get the device of the model."""
        if self._device is None:
            self._device = next(self.parameters()).device
        return self._device
    
    def freeze_layers(self, layers_to_freeze: List[str]) -> None:
        """Freeze specific layers by name pattern."""
        frozen_count = 0
        for name, param in self.named_parameters():
            if any(layer in name for layer in layers_to_freeze):
                param.requires_grad = False
                frozen_count += 1
        logger.info(f"Frozen {frozen_count} layer groups")
    
    def unfreeze_all_layers(self) -> None:
        """Unfreeze all model layers."""
        for param in self.parameters():
            param.requires_grad = True
        logger.info("All layers unfrozen")
    
    def monitor_batch_norm_stats(self) -> Dict[str, Dict[str, float]]:
        """Monitor BatchNorm layer statistics."""
        stats = {}
        for name, module in self.named_modules():
            if isinstance(module, nn.BatchNorm1d):
                stats[name] = {
                    'running_mean': module.running_mean.mean().item(),
                    'running_var': module.running_var.mean().item(),
                    'weight_mean': module.weight.mean().item() if module.weight is not None else None,
                    'bias_mean': module.bias.mean().item() if module.bias is not None else None,
                    'momentum': module.momentum,
                    'num_features': module.num_features
                }
        return stats
    
    def save_checkpoint(
        self, 
        filepath: Union[str, Path],
        optimizer: Optional[torch.optim.Optimizer] = None,
        epoch: Optional[int] = None,
        metrics: Optional[Dict[str, float]] = None,
        model_config: Optional[Dict[str, Any]] = None
    ) -> None:
        """
        Save model checkpoint with configuration.
        
        Args:
            filepath: Save location
            optimizer: Optional optimizer state to save
            epoch: Current epoch number
            metrics: Performance metrics to save
            model_config: Model configuration to save
        """
        filepath = Path(filepath)
        filepath.parent.mkdir(parents=True, exist_ok=True)
        
        checkpoint = {
            'model_state_dict': self.state_dict(),
            'model_class': self.__class__.__name__,
            'model_config': model_config or getattr(self, 'config', {})
        }
        
        if optimizer:
            checkpoint['optimizer_state_dict'] = optimizer.state_dict()
        if epoch is not None:
            checkpoint['epoch'] = epoch
        if metrics:
            checkpoint['metrics'] = metrics
        
        torch.save(checkpoint, filepath)
        logger.info(f"Checkpoint saved to {filepath}")
    
    @classmethod
    def load_checkpoint(
        cls,
        filepath: Union[str, Path],
        map_location: Optional[torch.device] = None,
        **override_config
    ) -> 'BaseModel':
        """
        Load model from checkpoint with configuration.
        
        Args:
            filepath: Checkpoint file path
            map_location: Device mapping location
            override_config: Configuration parameters to override
            
        Returns:
            Loaded model instance
        """
        if map_location is None:
            map_location = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        
        checkpoint = torch.load(filepath, map_location=map_location)
        
        # Get model configuration
        model_config = checkpoint.get('model_config', {})
        model_config.update(override_config)
        
        # Initialize model with configuration
        if model_config:
            model = cls(**model_config)
        else:
            raise ValueError("Model configuration not found in checkpoint. "
                           "Please provide configuration via override_config.")
        
        model.load_state_dict(checkpoint['model_state_dict'])
        logger.info(f"Model loaded from {filepath}")
        return model

# ==================== RNN Classifier with BatchNorm ====================

class RNNClassifier(BaseModel):
    """
    Multi-Layer RNN Classifier with BatchNorm and Dropout.
    
    Architecture:
        Input -> RNN Layers -> BatchNorm -> Dropout -> Linear -> Output
    
    Args:
        input_size: Input features per timestep
        hidden_sizes: List of hidden dimensions for each RNN layer
        output_size: Number of output classes
        dropout_prob: Dropout probability (default: 0.3)
        use_batch_norm: Use batch normalization (default: True)
        bidirectional: Use bidirectional RNN (default: False)
        recurrent_dropout: Dropout for recurrent connections (default: 0.0)
        batch_norm_momentum: Momentum for BatchNorm (default: 0.1)
        batch_norm_eps: Epsilon for BatchNorm (default: 1e-5)
        batch_norm_track_stats: Track running stats in BatchNorm (default: True)
        nonlinearity: RNN activation function ('tanh' or 'relu', default: 'tanh')
    """
    
    def __init__(
        self,
        input_size: int,
        hidden_sizes: List[int],
        output_size: int,
        dropout_prob: float = 0.3,
        use_batch_norm: bool = True,
        bidirectional: bool = False,
        recurrent_dropout: float = 0.0,
        batch_norm_momentum: float = 0.1,
        batch_norm_eps: float = 1e-5,
        batch_norm_track_stats: bool = True,
        nonlinearity: str = 'tanh'
    ):
        super().__init__()
        
        # Validation
        self._validate_inputs(input_size, hidden_sizes, output_size, dropout_prob)
        
        # Store configuration for checkpoint
        self.config = {
            'input_size': input_size,
            'hidden_sizes': hidden_sizes,
            'output_size': output_size,
            'dropout_prob': dropout_prob,
            'use_batch_norm': use_batch_norm,
            'bidirectional': bidirectional,
            'recurrent_dropout': recurrent_dropout,
            'batch_norm_momentum': batch_norm_momentum,
            'batch_norm_eps': batch_norm_eps,
            'batch_norm_track_stats': batch_norm_track_stats,
            'nonlinearity': nonlinearity
        }
        
        # Store configuration attributes
        self.input_size = input_size
        self.hidden_sizes = hidden_sizes
        self.output_size = output_size
        self.dropout_prob = dropout_prob
        self.use_batch_norm = use_batch_norm
        self.bidirectional = bidirectional
        self.recurrent_dropout = recurrent_dropout
        self.num_layers = len(hidden_sizes)
        self.num_directions = 2 if bidirectional else 1
        self.batch_norm_momentum = batch_norm_momentum
        self.batch_norm_eps = batch_norm_eps
        self.batch_norm_track_stats = batch_norm_track_stats
        self.nonlinearity = nonlinearity
        
        # Build layers
        self._build_layers()
        
        # Initialize weights
        self._initialize_weights()
        
        # Log model info
        self._log_model_info()
    
    def _validate_inputs(
        self,
        input_size: int,
        hidden_sizes: List[int],
        output_size: int,
        dropout_prob: float
    ) -> None:
        """Validate model inputs."""
        if not hidden_sizes:
            raise ValueError("hidden_sizes cannot be empty")
        if input_size <= 0:
            raise ValueError(f"input_size must be positive, got {input_size}")
        if output_size <= 0:
            raise ValueError(f"output_size must be positive, got {output_size}")
        if not 0 <= dropout_prob < 1:
            raise ValueError(f"dropout_prob must be in [0, 1), got {dropout_prob}")
    
    def _build_layers(self) -> None:
        """Build model layers with BatchNorm."""
        # RNN layers
        self.rnn_layers = nn.ModuleList()
        self.batch_norms = nn.ModuleList() if self.use_batch_norm else None
        self.dropout_layers = nn.ModuleList()
        
        for i in range(self.num_layers):
            input_dim = self.input_size if i == 0 else self.hidden_sizes[i-1] * self.num_directions
            
            # RNN layer without internal dropout
            rnn = nn.RNN(
                input_size=input_dim,
                hidden_size=self.hidden_sizes[i],
                num_layers=1,
                batch_first=True,
                bidirectional=self.bidirectional,
                nonlinearity=self.nonlinearity
            )
            self.rnn_layers.append(rnn)
            
            # BatchNorm layer
            if self.use_batch_norm:
                bn = nn.BatchNorm1d(
                    self.hidden_sizes[i] * self.num_directions,
                    momentum=self.batch_norm_momentum,
                    eps=self.batch_norm_eps,
                    track_running_stats=self.batch_norm_track_stats
                )
                self.batch_norms.append(bn)
            
            # Dropout layer for recurrent connections
            if i < self.num_layers - 1 and self.recurrent_dropout > 0:
                self.dropout_layers.append(nn.Dropout(self.recurrent_dropout))
            else:
                self.dropout_layers.append(nn.Identity())
        
        # Output layers
        final_hidden = self.hidden_sizes[-1] * self.num_directions
        
        # Final BatchNorm before classifier (optional)
        if self.use_batch_norm:
            self.final_batch_norm = nn.BatchNorm1d(
                final_hidden,
                momentum=self.batch_norm_momentum,
                eps=self.batch_norm_eps,
                track_running_stats=self.batch_norm_track_stats
            )
        else:
            self.final_batch_norm = nn.Identity()
        
        # Dropout
        self.dropout = nn.Dropout(self.dropout_prob)
        
        # Classifier
        self.classifier = nn.Linear(final_hidden, self.output_size)
    
    def _initialize_weights(self) -> None:
        """Initialize weights using best practices for RNNs with BatchNorm."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.RNN):
                for name, param in module.named_parameters():
                    if 'weight_ih' in name:
                        nn.init.xavier_uniform_(param)
                    elif 'weight_hh' in name:
                        nn.init.orthogonal_(param)
                    elif 'bias' in name:
                        nn.init.zeros_(param)
            elif isinstance(module, nn.BatchNorm1d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
    
    def _log_model_info(self) -> None:
        """Log model architecture information."""
        total_params = self.count_parameters()
        logger.info(
            f"RNNClassifier initialized:\n"
            f"  Input size: {self.input_size}\n"
            f"  Hidden sizes: {self.hidden_sizes}\n"
            f"  Output size: {self.output_size}\n"
            f"  Use BatchNorm: {self.use_batch_norm}\n"
            f"  Bidirectional: {self.bidirectional}\n"
            f"  Nonlinearity: {self.nonlinearity}\n"
            f"  Total parameters: {total_params:,}\n"
            f"  Model size: {total_params * 4 / (1024 * 1024):.2f} MB"
        )
    
    def forward(
        self,
        x: torch.Tensor,
        hidden_states: Optional[List[torch.Tensor]] = None
    ) -> torch.Tensor:
        """
        Forward pass through the RNN network.
        
        Args:
            x: Input tensor (batch_size, seq_length, input_size)
            hidden_states: Optional initial hidden states
            
        Returns:
            Output logits (batch_size, output_size)
        """
        batch_size, seq_length, _ = x.size()
        
        # Initialize hidden states if not provided
        if hidden_states is None:
            hidden_states = self._init_hidden_states(batch_size, x.device)
        
        # Process through RNN layers
        for i, rnn_layer in enumerate(self.rnn_layers):
            x, h = rnn_layer(x, hidden_states[i])
            
            # Apply BatchNorm
            if self.use_batch_norm:
                # BatchNorm1d expects (batch, features) or (batch, features, length)
                # We have (batch, seq_len, features), need to reshape
                original_shape = x.shape
                x = x.contiguous().view(-1, original_shape[-1])  # (batch*seq_len, features)
                x = self.batch_norms[i](x)
                x = x.view(original_shape)  # Reshape back to (batch, seq_len, features)
            
            # Apply recurrent dropout between layers
            if i < self.num_layers - 1 and self.training:
                x = self.dropout_layers[i](x)
        
        # Extract last timestep output
        x = x[:, -1, :]  # (batch_size, final_hidden)
        
        # Final BatchNorm and dropout
        x = self.final_batch_norm(x)
        x = self.dropout(x)
        
        # Classification layer
        logits = self.classifier(x)
        
        return logits
    
    def _init_hidden_states(
        self,
        batch_size: int,
        device: torch.device
    ) -> List[torch.Tensor]:
        """Initialize hidden states for all layers."""
        hidden_states = []
        for hidden_size in self.hidden_sizes:
            h0 = torch.zeros(
                self.num_directions,
                batch_size,
                hidden_size,
                device=device
            )
            hidden_states.append(h0)
        return hidden_states
    
    def get_model_info(self) -> Dict[str, Any]:
        """Get model configuration and statistics."""
        info = {
            'model_type': 'RNNClassifier',
            'input_size': self.input_size,
            'hidden_sizes': self.hidden_sizes,
            'output_size': self.output_size,
            'num_layers': self.num_layers,
            'dropout_prob': self.dropout_prob,
            'use_batch_norm': self.use_batch_norm,
            'bidirectional': self.bidirectional,
            'recurrent_dropout': self.recurrent_dropout,
            'nonlinearity': self.nonlinearity,
            'total_parameters': self.count_parameters(),
            'trainable_parameters': self.count_parameters(trainable_only=True),
            'model_size_mb': self.count_parameters() * 4 / (1024 * 1024)
        }
        
        if self.use_batch_norm:
            info.update({
                'batch_norm_momentum': self.batch_norm_momentum,
                'batch_norm_eps': self.batch_norm_eps,
                'batch_norm_track_stats': self.batch_norm_track_stats
            })
        
        return info

# ==================== F1-Weighted Ensemble ====================

class F1WeightedEnsemble(BaseModel):
    """
    F1-Score Weighted Ensemble Model.
    
    Combines multiple models using performance-based weighting.
    
    Args:
        models: List of base models
        weights: F1 scores or custom weights (None for equal weights)
        strategy: Aggregation strategy ('weighted_mean', 'voting', 'max')
        temperature: Temperature for softmax scaling (default: 1.0)
    """
    
    def __init__(
        self,
        models: List[nn.Module],
        weights: Optional[List[float]] = None,
        strategy: str = 'weighted_mean',
        temperature: float = 1.0
    ):
        super().__init__()
        
        # Validation
        self._validate_inputs(models, weights, strategy)
        
        # Store configuration
        self.config = {
            'num_models': len(models),
            'strategy': strategy,
            'temperature': temperature,
            'weights': weights
        }
        
        # Store models and configuration
        self.base_models = nn.ModuleList(models)
        self.num_models = len(models)
        self.strategy = strategy
        self.temperature = temperature
        
        # Setup weights
        self._setup_weights(weights)
        
        # Freeze base models
        self._freeze_base_models()
        
        logger.info(
            f"F1WeightedEnsemble initialized:\n"
            f"  Number of models: {self.num_models}\n"
            f"  Strategy: {strategy}\n"
            f"  Weights: {self.ensemble_weights.tolist()}"
        )
    
    def _validate_inputs(
        self,
        models: List[nn.Module],
        weights: Optional[List[float]],
        strategy: str
    ) -> None:
        """Validate ensemble inputs."""
        if not models:
            raise ValueError("At least one model required")
        
        if weights is not None and len(weights) != len(models):
            raise ValueError(f"Number of weights ({len(weights)}) must match number of models ({len(models)})")
        
        valid_strategies = ['weighted_mean', 'voting', 'max']
        if strategy not in valid_strategies:
            raise ValueError(f"strategy must be one of {valid_strategies}")
    
    def _setup_weights(self, weights: Optional[List[float]]) -> None:
        """Setup and normalize ensemble weights."""
        if weights is None:
            # Equal weights
            weights_tensor = torch.ones(self.num_models) / self.num_models
        else:
            # Normalize provided weights
            weights_tensor = torch.tensor(weights, dtype=torch.float32)
            weights_tensor = F.softmax(weights_tensor, dim=0)
        
        self.register_buffer('ensemble_weights', weights_tensor)
    
    def _freeze_base_models(self) -> None:
        """Freeze all base model parameters."""
        for model in self.base_models:
            for param in model.parameters():
                param.requires_grad = False
    
    def forward(
        self,
        x: torch.Tensor,
        return_all: bool = False
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Forward pass through ensemble.
        
        Important: Ensures all models are in eval mode for proper BatchNorm behavior.
        
        Args:
            x: Input tensor
            return_all: Return individual model predictions
            
        Returns:
            Ensemble predictions, optionally with individual predictions
        """
        predictions = []
        
        # Collect predictions from all models
        for model in self.base_models:
            # Critical: Set model to eval mode for BatchNorm
            model.eval()
            with torch.no_grad():
                logits = model(x)
                probs = F.softmax(logits / self.temperature, dim=-1)
                predictions.append(probs)
        
        # Stack: (num_models, batch_size, num_classes)
        all_predictions = torch.stack(predictions, dim=0)
        
        # Aggregate based on strategy
        if self.strategy == 'weighted_mean':
            weights = self.ensemble_weights.view(-1, 1, 1)
            ensemble_pred = (all_predictions * weights).sum(dim=0)
            
        elif self.strategy == 'voting':
            votes = all_predictions.argmax(dim=-1)
            # Get mode for each sample
            ensemble_pred = []
            for i in range(votes.size(1)):
                values, counts = torch.unique(votes[:, i], return_counts=True)
                mode = values[counts.argmax()]
                ensemble_pred.append(mode)
            ensemble_pred = torch.stack(ensemble_pred)
            # Convert to probabilities
            num_classes = all_predictions.size(-1)
            ensemble_pred = F.one_hot(ensemble_pred, num_classes=num_classes).float()
            
        elif self.strategy == 'max':
            ensemble_pred = all_predictions.max(dim=0).values
        
        if return_all:
            return ensemble_pred, all_predictions
        return ensemble_pred
    
    def predict_with_confidence(
        self,
        x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Make predictions with confidence and uncertainty estimation.
        
        Args:
            x: Input tensor
            
        Returns:
            - Predicted classes (batch_size,)
            - Confidence scores (batch_size,)
            - Prediction variance (batch_size,)
        """
        with torch.no_grad():
            ensemble_pred, all_preds = self.forward(x, return_all=True)
            
            # Get predictions and confidence
            confidence, predicted = ensemble_pred.max(dim=-1)
            
            # Calculate variance across models
            variance = all_preds.var(dim=0).mean(dim=-1)
            
            return predicted, confidence, variance
    
    def get_ensemble_info(self) -> Dict[str, Any]:
        """Get ensemble configuration and statistics."""
        total_params = sum(
            model.count_parameters() if hasattr(model, 'count_parameters') 
            else sum(p.numel() for p in model.parameters())
            for model in self.base_models
        )
        
        return {
            'ensemble_type': 'F1WeightedEnsemble',
            'num_models': self.num_models,
            'strategy': self.strategy,
            'temperature': self.temperature,
            'weights': self.ensemble_weights.tolist(),
            'total_parameters': total_params,
            'ensemble_size_mb': total_params * 4 / (1024 * 1024)
        }

# ==================== Factory Functions ====================

def create_rnn_classifier(config: Dict[str, Any]) -> RNNClassifier:
    """
    Factory function to create RNN model from configuration.
    
    Args:
        config: Model configuration dictionary
        
    Returns:
        Initialized RNNClassifier
    
    Example:
        config = {
            'input_size': 10,
            'hidden_sizes': [128, 64],
            'output_size': 3,
            'use_batch_norm': True,
            'dropout_prob': 0.3,
            'nonlinearity': 'tanh'
        }
        model = create_rnn_classifier(config)
    """
    return RNNClassifier(
        input_size=config['input_size'],
        hidden_sizes=config['hidden_sizes'],
        output_size=config['output_size'],
        dropout_prob=config.get('dropout_prob', 0.3),
        use_batch_norm=config.get('use_batch_norm', True),
        bidirectional=config.get('bidirectional', False),
        recurrent_dropout=config.get('recurrent_dropout', 0.0),
        batch_norm_momentum=config.get('batch_norm_momentum', 0.1),
        batch_norm_eps=config.get('batch_norm_eps', 1e-5),
        batch_norm_track_stats=config.get('batch_norm_track_stats', True),
        nonlinearity=config.get('nonlinearity', 'tanh')
    )

def create_ensemble(
    models: List[nn.Module],
    weights: Optional[List[float]] = None,
    strategy: str = 'weighted_mean'
) -> F1WeightedEnsemble:
    """
    Factory function to create ensemble model.
    
    Args:
        models: List of base models
        weights: Optional F1 scores or weights
        strategy: Aggregation strategy
        
    Returns:
        Initialized F1WeightedEnsemble
    """
    return F1WeightedEnsemble(models, weights, strategy)

# Module 5: Training Pipeline and K-Fold Cross-Validation

This module implements the complete training pipeline with stratified K-fold cross-validation for binary classification. It includes epoch execution, metrics evaluation, and model checkpointing with support for imbalanced datasets and multi-GPU training.

In [None]:
# ==================== Core Training Functions ====================

"""
Core Training Functions for RNN Binary Classification

Implements training pipeline with F1-score optimization and stratified 
K-fold cross-validation for imbalanced datasets.

Author: Tian Gao
Date: 2025-09-16
Version: 1.0.0
License: MIT
"""

# Configure logger
logger = logging.getLogger(__name__)

# ==================== Training Pipeline ====================

def train_eval_rnn(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    loss_fn: nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: lr_scheduler._LRScheduler,
    epochs: int,
    device: torch.device,
    log_epoch: Optional[List[str]] = None,
    max_grad_norm: Optional[float] = 1.0
) -> Tuple:
    """
    Train and evaluate RNN model for complete epochs.
    
    Returns:
        Training results and best model state (based on val F1).
    """
    best_model_state = None
    best_val_f1 = -float('inf')
    best_epoch = 0
    
    logger.info(f"Starting training for {epochs} epochs")
    
    # Store all epoch metrics for plotting
    history = {
        'train_loss': [], 'val_loss': [],
        'train_f1': [], 'val_f1': [],
        'train_acc': [], 'val_acc': [],
        'train_precision': [], 'val_precision': [],
        'train_recall': [], 'val_recall': [],
        'train_specificity': [], 'val_specificity': [],
        'train_mcc': [], 'val_mcc': [],
        'train_norm_mcc': [], 'val_norm_mcc': [],
        'train_class_acc': [], 'val_class_acc': []
    }
    
    for epoch in range(epochs):
        # Training phase
        train_results = run_epoch(
            model, train_loader, device, loss_fn,
            optimizer, max_grad_norm, is_training=True
        )
        
        # Validation phase
        val_results = run_epoch(
            model, val_loader, device, loss_fn,
            optimizer=None, is_training=False
        )
        
        # Store history
        history['train_loss'].append(train_results['loss'])
        history['val_loss'].append(val_results['loss'])
        history['train_f1'].append(train_results['metrics']['f1'])
        history['val_f1'].append(val_results['metrics']['f1'])
        history['train_acc'].append(train_results['acc'])
        history['val_acc'].append(val_results['acc'])
        history['train_precision'].append(train_results['metrics']['precision'])
        history['val_precision'].append(val_results['metrics']['precision'])
        history['train_recall'].append(train_results['metrics']['recall'])
        history['val_recall'].append(val_results['metrics']['recall'])
        history['train_specificity'].append(train_results['metrics']['specificity'])
        history['val_specificity'].append(val_results['metrics']['specificity'])
        history['train_mcc'].append(train_results['metrics']['mcc'])
        history['val_mcc'].append(val_results['metrics']['mcc'])
        history['train_norm_mcc'].append(train_results['metrics']['norm_mcc'])
        history['val_norm_mcc'].append(val_results['metrics']['norm_mcc'])
        history['train_class_acc'].append(train_results['metrics']['class_acc'].tolist())
        history['val_class_acc'].append(val_results['metrics']['class_acc'].tolist())
        
        # Model selection based on F1
        current_f1 = val_results['metrics']['f1']
        if current_f1 > best_val_f1:
            best_val_f1 = current_f1
            best_model_state = copy.deepcopy(model.state_dict())
            best_epoch = epoch + 1
            logger.info(f"New best model at epoch {best_epoch} (F1: {best_val_f1:.4f})")
        
        # Learning rate scheduling
        if isinstance(scheduler, lr_scheduler.ReduceLROnPlateau):
            scheduler.step(current_f1)
        else:
            scheduler.step()
        
        # Logging with all metrics
        current_lr = optimizer.param_groups[0]['lr']
        train_class_acc = train_results['metrics']['class_acc']
        val_class_acc = val_results['metrics']['class_acc']
        
        epoch_log = (
            f"Epoch {epoch+1}/{epochs}, "
            f"Train Loss: {train_results['loss']:.4f}, Train Acc: {train_results['acc']:.4f}, "
            f"Class 0/1: {train_class_acc[0]:.4f}/{train_class_acc[1]:.4f}, "
            f"Train F1: {train_results['metrics']['f1']:.4f}, "
            f"Train Precision: {train_results['metrics']['precision']:.4f}, "
            f"Train Recall: {train_results['metrics']['recall']:.4f}, "
            f"Train Specificity: {train_results['metrics']['specificity']:.4f}, "
            f"Train MCC: {train_results['metrics']['mcc']:.4f}, "
            f"Train Norm MCC: {train_results['metrics']['norm_mcc']:.4f}, "
            f"Val Loss: {val_results['loss']:.4f}, Val Acc: {val_results['acc']:.4f}, "
            f"Class 0/1: {val_class_acc[0]:.4f}/{val_class_acc[1]:.4f}, "
            f"Val F1: {val_results['metrics']['f1']:.4f}, "
            f"Val Precision: {val_results['metrics']['precision']:.4f}, "
            f"Val Recall: {val_results['metrics']['recall']:.4f}, "
            f"Val Specificity: {val_results['metrics']['specificity']:.4f}, "
            f"Val MCC: {val_results['metrics']['mcc']:.4f}, "
            f"Val Norm MCC: {val_results['metrics']['norm_mcc']:.4f}, "
            f"LR: {current_lr:.6f}"
        )
        
        logger.info(epoch_log)
        if log_epoch is not None:
            log_epoch.append(epoch_log)
    
    logger.info(f"Training completed. Best model from epoch {best_epoch}")
    
    # Load best model
    if best_model_state:
        model.load_state_dict(best_model_state)
    
    # Final evaluation
    with torch.no_grad():
        final_train = run_epoch(
            model, train_loader, device, loss_fn, is_training=False
        )
        final_val = run_epoch(
            model, val_loader, device, loss_fn, is_training=False
        )
    
    # Log final comprehensive metrics
    logger.info(f"Final training class accuracy - Class 0: {final_train['metrics']['class_acc'][0]:.4f}, "
                f"Class 1: {final_train['metrics']['class_acc'][1]:.4f}")
    logger.info(f"Final validation class accuracy - Class 0: {final_val['metrics']['class_acc'][0]:.4f}, "
                f"Class 1: {final_val['metrics']['class_acc'][1]:.4f}")
    logger.info(f"Final F1-score - Train: {final_train['metrics']['f1']:.4f}, "
                f"Validation: {final_val['metrics']['f1']:.4f}")
    logger.info(f"Final precision - Train: {final_train['metrics']['precision']:.4f}, "
                f"Validation: {final_val['metrics']['precision']:.4f}")
    logger.info(f"Final recall - Train: {final_train['metrics']['recall']:.4f}, "
                f"Validation: {final_val['metrics']['recall']:.4f}")
    logger.info(f"Final specificity - Train: {final_train['metrics']['specificity']:.4f}, "
                f"Validation: {final_val['metrics']['specificity']:.4f}")
    logger.info(f"Final MCC - Train: {final_train['metrics']['mcc']:.4f}, "
                f"Validation: {final_val['metrics']['mcc']:.4f}")
    logger.info(f"Final normalized MCC - Train: {final_train['metrics']['norm_mcc']:.4f}, "
                f"Validation: {final_val['metrics']['norm_mcc']:.4f}")
    
    return (
        final_train['loss'], final_val['loss'],
        final_train['acc'], final_val['acc'],
        final_train['all_probs'], final_train['all_labels'],
        final_val['all_probs'], final_val['all_labels'],
        best_model_state, history
    )

# ==================== K-Fold Cross-Validation ====================

def kfold_train_eval_rnn(
    model_class: Any,
    dataset: TensorDataset,
    loss_fn: nn.Module,
    optimizer_class: Any,
    optimizer_kwargs: dict,
    scheduler_class: Any,
    scheduler_kwargs: dict,
    epochs: int,
    device: torch.device,
    device_ids: Optional[List[int]] = None,
    num_folds: int = 5,
    output_dir: str = "your_project/results",
    # Change to your actual results directory path
    # For example, if your project is at: C:/Users/Tian/Desktop/地磁论文代码运行测试
    # Then change to: output_dir: str = r"C:\Users\Tian\Desktop\地磁论文代码运行测试\results"
    model_name: str = "RNNModel",
    window_name: str = "7day",
    max_grad_norm: Optional[float] = 1.0,
    input_size: int = 1,
    output_size: int = 2,
    hidden_sizes: Optional[List[int]] = None,
    dropout_prob: float = 0.3,
    batch_size: int = 32,
    random_state: int = 42,
    **kwargs
) -> Dict[str, List]:
    """
    Perform stratified K-fold cross-validation with full epoch training.
    
    Returns:
        Dictionary with cross-validation results and training history.
    """
    if hidden_sizes is None:
        hidden_sizes = [512, 256]
    
    # Setup output directory
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    
    # Define plural forms mapping for metrics
    metric_plurals = {
        'loss': 'losses',
        'acc': 'accs', 
        'f1': 'f1s',
        'precision': 'precisions',
        'recall': 'recalls',
        'specificity': 'specificities',
        'mcc': 'mccs',
        'norm_mcc': 'norm_mccs'
    }
    
    # Initialize results tracking
    results = {
        'train_losses': [], 'val_losses': [],
        'train_accs': [], 'val_accs': [],
        'train_f1s': [], 'val_f1s': [],
        'train_precisions': [], 'val_precisions': [],
        'train_recalls': [], 'val_recalls': [],
        'train_specificities': [], 'val_specificities': [],
        'train_mccs': [], 'val_mccs': [],
        'train_norm_mccs': [], 'val_norm_mccs': [],
        'fold_results': [],
        'fold_histories': [],
        'class_accs': []
    }
    
    # Extract data
    all_data = dataset.tensors[0].numpy()
    all_labels = dataset.tensors[1].numpy()
    
    # Log dataset info
    logger.info(f"\nDataset: {len(all_labels)} samples, {all_data.shape[-1]} features")
    class_counts = np.bincount(all_labels.astype(int))
    imbalance_ratio = max(class_counts) / min(class_counts)
    for i, count in enumerate(class_counts):
        logger.info(f"Class {i}: {count} ({count/len(all_labels):.1%})")
    logger.info(f"Imbalance ratio: {imbalance_ratio:.2f}:1")
    
    # Setup stratified K-fold
    skf = StratifiedKFold(
        n_splits=num_folds,
        shuffle=True,
        random_state=random_state
    )
    
    # Calculate global class weights
    global_weights = calculate_class_weights(all_labels)
    weighted_loss_fn = nn.CrossEntropyLoss(weight=global_weights.to(device))
    
    # Multi-GPU setup
    num_gpus = len(device_ids) if device_ids else 1
    effective_batch_size = batch_size * num_gpus
    
    logger.info(f"\nTraining config: {num_folds} folds, batch_size={effective_batch_size}, "
                f"epochs={epochs} (full training)")
    logger.info(f"Using weighted loss function with class weights: {global_weights.numpy()}")
    
    # Train each fold
    start_time = time.time()
    
    for fold_idx, (train_idx, val_idx) in enumerate(skf.split(all_data, all_labels), 1):
        logger.info(f"\n{'='*60}")
        logger.info(f"FOLD {fold_idx}/{num_folds}")
        logger.info(f"{'='*60}")
        
        fold_start = time.time()
        
        try:
            # Create fold datasets
            train_subset = Subset(dataset, train_idx)
            val_subset = Subset(dataset, val_idx)
            
            # Log fold class distribution
            fold_train_labels = dataset.tensors[1][train_idx].numpy()
            fold_val_labels = dataset.tensors[1][val_idx].numpy()
            train_class_counts = np.bincount(fold_train_labels.astype(int))
            val_class_counts = np.bincount(fold_val_labels.astype(int))
            
            logger.info(f"Training set class distribution: Class 0: {train_class_counts[0]}, "
                       f"Class 1: {train_class_counts[1]}")
            logger.info(f"Validation set class distribution: Class 0: {val_class_counts[0]}, "
                       f"Class 1: {val_class_counts[1]}")
            
            # Create dataloaders
            train_loader = create_balanced_dataloader(
                train_subset, effective_batch_size, is_train=True
            )
            val_loader = create_balanced_dataloader(
                val_subset, effective_batch_size, is_train=False
            )
            
            # Initialize model
            model = model_class(
                input_size=input_size,
                hidden_sizes=hidden_sizes,
                output_size=output_size,
                dropout_prob=dropout_prob,
                **kwargs
            ).to(device)
            
            # Multi-GPU
            if device_ids and len(device_ids) > 1:
                model = nn.DataParallel(model, device_ids=device_ids)
            
            # Setup optimizer and scheduler
            optimizer = optimizer_class(model.parameters(), **optimizer_kwargs)
            
            # Handle scheduler kwargs
            sched_kwargs = scheduler_kwargs.copy()
            if scheduler_class == lr_scheduler.ReduceLROnPlateau:
                # Check PyTorch version
                import torch
                if hasattr(torch, '__version__'):
                    version = torch.__version__.split('.')
                    major_version = int(version[0])
                    if major_version < 2:  # PyTorch 1.x
                        sched_kwargs['verbose'] = False
                    else:  # PyTorch 2.x and above
                        sched_kwargs.pop('verbose', None)
            scheduler = scheduler_class(optimizer, **sched_kwargs)
            
            # Track logs
            epoch_logs = []
            
            # Train fold (full epochs)
            fold_results = train_eval_rnn(
                model, train_loader, val_loader,
                weighted_loss_fn, optimizer, scheduler,
                epochs, device, epoch_logs, max_grad_norm
            )
            
            # Unpack results
            (train_loss, val_loss, train_acc, val_acc,
             train_probs, train_labels_final,
             val_probs, val_labels_final,
             best_state, history) = fold_results
            
            # Save fold outputs
            fold_path = output_dir / f"{model_name}_{window_name}_fold_{fold_idx}"
            
            # Save model
            torch.save(best_state, f"{fold_path}.pth")
            
            # Save classification outputs
            np.save(f"{fold_path}_train_probs.npy", train_probs)
            np.save(f"{fold_path}_train_labels.npy", train_labels_final)
            np.save(f"{fold_path}_test_probs.npy", val_probs)
            np.save(f"{fold_path}_test_labels.npy", val_labels_final)
            
            # Save training history for plotting
            np.save(f"{fold_path}_history.npy", history)
            
            # Save logs
            with open(f"{fold_path}_logs.txt", 'w') as f:
                f.write('\n'.join(epoch_logs))
            
            # Calculate metrics
            train_classes = np.argmax(train_probs, axis=1)
            val_classes = np.argmax(val_probs, axis=1)
            
            train_metrics = calculate_binary_metrics(train_labels_final, train_classes)
            val_metrics = calculate_binary_metrics(val_labels_final, val_classes)
            
            # Store fold result with all metrics
            fold_result = {
                'fold': fold_idx,
                'train_loss': train_loss,
                'val_loss': val_loss,
                'train_acc': train_acc,
                'val_acc': val_acc,
                'train_f1': train_metrics['f1'],
                'val_f1': val_metrics['f1'],
                'train_precision': train_metrics['precision'],
                'val_precision': val_metrics['precision'],
                'train_recall': train_metrics['recall'],
                'val_recall': val_metrics['recall'],
                'train_specificity': train_metrics['specificity'],
                'val_specificity': val_metrics['specificity'],
                'train_mcc': train_metrics['mcc'],
                'val_mcc': val_metrics['mcc'],
                'train_norm_mcc': train_metrics['norm_mcc'],
                'val_norm_mcc': val_metrics['norm_mcc'],
                'train_class_acc': train_metrics['class_acc'].tolist(),
                'val_class_acc': val_metrics['class_acc'].tolist(),
                'model_path': str(fold_path) + '.pth',
                'time_taken': time.time() - fold_start
            }
            
            results['fold_results'].append(fold_result)
            results['fold_histories'].append(history)
            results['class_accs'].append({
                'train': train_metrics['class_acc'].tolist(),
                'val': val_metrics['class_acc'].tolist()
            })
            
            # Update aggregate metrics using plural mapping
            for metric in ['loss', 'acc', 'f1', 'precision', 'recall',
                          'specificity', 'mcc', 'norm_mcc']:
                plural_form = metric_plurals[metric]
                results[f'train_{plural_form}'].append(fold_result[f'train_{metric}'])
                results[f'val_{plural_form}'].append(fold_result[f'val_{metric}'])
            
            fold_time = time.time() - fold_start
            logger.info(f"Fold {fold_idx} completed in {fold_time:.1f}s")
            logger.info(f"Final Val F1: {val_metrics['f1']:.4f}, Val MCC: {val_metrics['mcc']:.4f}")
            logger.info(f"Val Class Acc - Class 0: {val_metrics['class_acc'][0]:.4f}, "
                       f"Class 1: {val_metrics['class_acc'][1]:.4f}")
            
        except Exception as e:
            logger.error(f"Error in fold {fold_idx}: {str(e)}")
            continue
        
        finally:
            # Clear GPU cache
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    
    # Calculate statistics for class-wise accuracy
    if results['class_accs']:
        class0_train_accs = [acc['train'][0] for acc in results['class_accs']]
        class1_train_accs = [acc['train'][1] for acc in results['class_accs']]
        class0_val_accs = [acc['val'][0] for acc in results['class_accs']]
        class1_val_accs = [acc['val'][1] for acc in results['class_accs']]
        
        class0_train_mean = np.mean(class0_train_accs)
        class0_train_std = np.std(class0_train_accs)
        class1_train_mean = np.mean(class1_train_accs)
        class1_train_std = np.std(class1_train_accs)
        class0_val_mean = np.mean(class0_val_accs)
        class0_val_std = np.std(class0_val_accs)
        class1_val_mean = np.mean(class1_val_accs)
        class1_val_std = np.std(class1_val_accs)
    else:
        class0_train_mean = class0_train_std = 0
        class1_train_mean = class1_train_std = 0
        class0_val_mean = class0_val_std = 0
        class1_val_mean = class1_val_std = 0
    
    # Save summary
    total_time = time.time() - start_time
    hours = int(total_time // 3600)
    minutes = int((total_time % 3600) // 60)
    seconds = int(total_time % 60)
    
    summary = {
        'model_name': model_name,
        'window_name': window_name,
        'num_folds': num_folds,
        'epochs': epochs,
        'total_time_seconds': total_time,
        'total_time_formatted': f"{hours}h {minutes}m {seconds}s",
        'config': {
            'hidden_sizes': hidden_sizes,
            'dropout_prob': dropout_prob,
            'batch_size': batch_size,
            'max_grad_norm': max_grad_norm,
            'optimizer': optimizer_class.__name__,
            'optimizer_kwargs': optimizer_kwargs,
            'scheduler': scheduler_class.__name__,
            'scheduler_kwargs': scheduler_kwargs,
            'class_weights': global_weights.tolist(),
            'balanced_sampling': True
        },
        'results': {}
    }
    
    # Calculate statistics with proper plural forms
    for metric in ['loss', 'acc', 'f1', 'precision', 'recall', 'specificity', 'mcc', 'norm_mcc']:
        for phase in ['train', 'val']:
            plural_form = metric_plurals[metric]
            key = f'{phase}_{plural_form}'
            
            if key in results and results[key]:
                values = np.array(results[key])
                summary['results'][f'{phase}_{metric}_mean'] = float(values.mean())
                summary['results'][f'{phase}_{metric}_std'] = float(values.std())
    
    # Add class-wise accuracy statistics
    summary['results']['train_class0_acc_mean'] = float(class0_train_mean)
    summary['results']['train_class0_acc_std'] = float(class0_train_std)
    summary['results']['train_class1_acc_mean'] = float(class1_train_mean)
    summary['results']['train_class1_acc_std'] = float(class1_train_std)
    summary['results']['val_class0_acc_mean'] = float(class0_val_mean)
    summary['results']['val_class0_acc_std'] = float(class0_val_std)
    summary['results']['val_class1_acc_mean'] = float(class1_val_mean)
    summary['results']['val_class1_acc_std'] = float(class1_val_std)
    
    # Add fold details
    summary['fold_results'] = results['fold_results']
    
    # Save summary
    summary_path = output_dir / f"{model_name}_{window_name}_summary.json"
    with open(summary_path, 'w') as f:
        json.dump(summary, f, indent=2)
    
    logger.info(f"\n{'='*60}")
    logger.info("TRAINING COMPLETED")
    logger.info(f"{'='*60}")
    logger.info(f"Total training time: {hours}h {minutes}m {seconds}s")
    logger.info(f"Results saved to {output_dir}")
    
    # Log final comprehensive results
    logger.info("\nValidation Results (Mean ± Std):")
    logger.info(f"  Accuracy: {summary['results'].get('val_acc_mean', 0):.4f} ± "
                f"{summary['results'].get('val_acc_std', 0):.4f}")
    logger.info(f"  F1-Score: {summary['results'].get('val_f1_mean', 0):.4f} ± "
                f"{summary['results'].get('val_f1_std', 0):.4f}")
    logger.info(f"  Precision: {summary['results'].get('val_precision_mean', 0):.4f} ± "
                f"{summary['results'].get('val_precision_std', 0):.4f}")
    logger.info(f"  Recall: {summary['results'].get('val_recall_mean', 0):.4f} ± "
                f"{summary['results'].get('val_recall_std', 0):.4f}")
    logger.info(f"  Specificity: {summary['results'].get('val_specificity_mean', 0):.4f} ± "
                f"{summary['results'].get('val_specificity_std', 0):.4f}")
    logger.info(f"  MCC: {summary['results'].get('val_mcc_mean', 0):.4f} ± "
                f"{summary['results'].get('val_mcc_std', 0):.4f}")
    logger.info(f"  Normalized MCC: {summary['results'].get('val_norm_mcc_mean', 0):.4f} ± "
                f"{summary['results'].get('val_norm_mcc_std', 0):.4f}")
    logger.info(f"  Class 0 Accuracy: {class0_val_mean:.4f} ± {class0_val_std:.4f}")
    logger.info(f"  Class 1 Accuracy: {class1_val_mean:.4f} ± {class1_val_std:.4f}")
    
    return results

# Module 6: Training Orchestration and Ensemble Pipeline

This module implements the complete training pipeline for seismic geomagnetic signal recognition using RNN models with stratified K-fold cross-validation and F1-weighted ensemble strategy.

In [None]:
# ==================== Model Training and Ensemble Pipeline ====================
"""
Model Training and Ensemble Pipeline for Seismic Geomagnetic Signal Classification

This module orchestrates RNN model training with K-fold cross-validation
and F1-weighted ensemble creation for binary classification tasks.

Author: Tian Gao
Date: 2025-09-16
Version: 1.0.0
License: MIT
"""

# ==================== Main Training Orchestrator ====================

def main_rnn(optimize: bool = False) -> None:
    """
    Main training function with stratified K-fold cross-validation.
    
    Orchestrates the complete training pipeline across multiple time windows,
    including data loading, model training, and ensemble creation.
    
    Args:
        optimize: Whether to perform hyperparameter optimization
    """
    # Setup logging and device configuration
    setup_logging()
    device, device_ids = setup_device()
    
    # Configure paths
    base_path = "your_project"
    # Change to your actual project root directory
    # For example, if your project is at: C:/Users/Tian/Desktop/地磁论文代码运行测试
    # Then change to: base_path = r"C:\Users\Tian\Desktop\地磁论文代码运行测试"
    data_base_path = os.path.join(base_path, 'data')
    results_dir = os.path.join(base_path, 'results')
    output_dir = os.path.join(results_dir, 'rnn_models')
    os.makedirs(output_dir, exist_ok=True)
    
    logging.info(f"Starting training pipeline - Device: {device}")
    
    # Time window configuration
    window_mapping = {
        "window-7": {"name": "7day", "seq_length": 7},
        "window-14": {"name": "14day", "seq_length": 14},
        "window-30": {"name": "30day", "seq_length": 30}
    }
    
    try:
        # Training hyperparameters
        training_config = {
            'hidden_sizes': [512, 256],    # RNN hidden dimensions
            'num_layers': 2,                # Number of layers
            'dropout_prob': 0.3,            # Dropout rate
            'learning_rate': 0.001,         # Learning rate
            'batch_size': 32,               # Batch size
            'weight_decay': 0.01,           # L2 regularization
            'max_grad_norm': 1.0,           # Gradient clipping
            'epochs': 100,                  # Training epochs
            'num_folds': 5,                 # K-fold splits
            'nonlinearity': 'tanh'          # RNN activation function
        }
        
        # Optimizer configuration
        optimizer_kwargs = {
            'lr': training_config['learning_rate'],
            'weight_decay': training_config['weight_decay']
        }
        
        # Scheduler configuration
        scheduler_kwargs = {
            'mode': 'min',
            'factor': 0.1,
            'patience': 10,
            'min_lr': 1e-6,
            'verbose': True
        }
        
        num_gpus = len(device_ids) if device_ids else 1
        
        # Train models for each time window
        for window_period in ['7', '14', '30']:
            current_window_name = f"window-{window_period}"
            window_info = window_mapping[current_window_name]
            
            logging.info(f"\n{'='*60}")
            logging.info(f"Processing time window: {window_info['name']}")
            logging.info(f"{'='*60}")
            
            try:
                # Load data
                data_dir = os.path.join(data_base_path, current_window_name)
                data_0 = np.load(os.path.join(data_dir, "data_0.npy"))
                data_1 = np.load(os.path.join(data_dir, "data_1.npy"))
                
                # Combine and prepare data
                data = np.concatenate([data_0, data_1], axis=0)
                X = data[:, :-1]
                Y = data[:, -1].astype(int)
                
                # Log class distribution
                class_counts = np.bincount(Y)
                logging.info(f"Class distribution - Class 0: {class_counts[0]}, Class 1: {class_counts[1]}")
                
                # Calculate class weights
                class_weights = calculate_class_weights(Y)
                logging.info(f"Using sqrt-scaled class weights: {class_weights}")
                weighted_loss = nn.CrossEntropyLoss(weight=class_weights)
                
                # Reshape for RNN input
                X = reshape_data_for_rnn(X, window_info['seq_length'])
                
                # Convert to tensors
                X_tensor = torch.from_numpy(X).type(torch.float32)
                Y_tensor = torch.from_numpy(Y).type(torch.long)
                dataset = TensorDataset(X_tensor, Y_tensor)
                
                # Train with K-fold cross-validation
                results = kfold_train_eval_rnn(
                    model_class=RNNClassifier,
                    dataset=dataset,
                    loss_fn=weighted_loss,
                    optimizer_class=optim.AdamW,
                    optimizer_kwargs=optimizer_kwargs,
                    scheduler_class=optim.lr_scheduler.ReduceLROnPlateau,
                    scheduler_kwargs=scheduler_kwargs,
                    epochs=training_config['epochs'],
                    device=device,
                    device_ids=device_ids,
                    num_folds=training_config['num_folds'],
                    output_dir=output_dir,
                    model_name="RNNModel",
                    window_name=window_info['name'],
                    input_size=X.shape[2],
                    output_size=2,
                    hidden_sizes=training_config['hidden_sizes'],
                    dropout_prob=training_config['dropout_prob'],
                    batch_size=training_config['batch_size'],
                    max_grad_norm=training_config['max_grad_norm'],
                    nonlinearity=training_config['nonlinearity']
                )
                
                # Create F1-weighted ensemble
                perform_model_ensemble(
                    model_name="RNNModel",
                    window_name=window_info['name'],
                    output_dir=output_dir,
                    device=device,
                    device_ids=device_ids,
                    model_class=RNNClassifier,
                    input_size=X.shape[2],
                    hidden_sizes=training_config['hidden_sizes'],
                    output_size=2,
                    dropout_prob=training_config['dropout_prob'],
                    num_folds=training_config['num_folds'],
                    nonlinearity=training_config['nonlinearity']
                )
                
            except Exception as e:
                logging.error(f"Error processing {window_info['name']}: {str(e)}")
                logging.error("Detailed error:", exc_info=True)
                continue
        
        logging.info("\n" + "="*60)
        logging.info("All training completed successfully!")
        logging.info("="*60)
        
    except Exception as e:
        logging.error(f"Critical error in training pipeline: {str(e)}")
        raise


# ==================== F1-Weighted Ensemble System ====================

def perform_model_ensemble(
    model_name: str,
    window_name: str,
    output_dir: str,
    device: torch.device,
    device_ids: Optional[List[int]],
    model_class: Any,
    input_size: int,
    hidden_sizes: List[int],
    output_size: int,
    dropout_prob: float,
    num_folds: int,
    nonlinearity: str = 'tanh'
) -> None:
    """
    Create F1-weighted ensemble from K-fold models.
    
    Args:
        model_name: Model identifier
        window_name: Time window identifier
        output_dir: Output directory
        device: Training device
        device_ids: GPU device IDs
        model_class: Model class
        input_size: Input dimension
        hidden_sizes: RNN hidden sizes
        output_size: Number of classes
        dropout_prob: Dropout probability
        num_folds: Number of folds
        nonlinearity: RNN activation function
    """
    logging.info(f"Starting model ensemble - {model_name}_{window_name}")
    
    try:
        # Initialize storage
        models = []
        model_metrics = {
            'f1_scores': [],
            'precisions': [],
            'recalls': [],
            'specificities': [],
            'mccs': [],
            'norm_mccs': []
        }
        
        # Load each fold model
        for fold in range(1, num_folds + 1):
            # Create model
            model = model_class(
                input_size=input_size,
                hidden_sizes=hidden_sizes,
                output_size=output_size,
                dropout_prob=dropout_prob,
                nonlinearity=nonlinearity
            ).to(device)
            
            # Apply DataParallel
            if device_ids and len(device_ids) > 1:
                model = nn.DataParallel(model, device_ids=device_ids)
            
            # Load weights
            model_path = os.path.join(output_dir, f"{model_name}_{window_name}_fold_{fold}.pth")
            
            if os.path.exists(model_path):
                state_dict = torch.load(model_path, map_location=device)
                
                # Handle DataParallel state dict
                if device_ids and len(device_ids) > 1:
                    if not list(state_dict.keys())[0].startswith('module.'):
                        state_dict = {'module.' + k: v for k, v in state_dict.items()}
                
                model.load_state_dict(state_dict)
                model.eval()
                models.append(model)
                
                # Load validation metrics
                test_probs_path = os.path.join(output_dir, f"{model_name}_{window_name}_fold_{fold}_test_probs.npy")
                test_labels_path = os.path.join(output_dir, f"{model_name}_{window_name}_fold_{fold}_test_labels.npy")
                
                if os.path.exists(test_probs_path) and os.path.exists(test_labels_path):
                    # Calculate metrics
                    test_probs = np.load(test_probs_path)
                    test_labels = np.load(test_labels_path)
                    test_classes = np.argmax(test_probs, axis=1)
                    
                    # Compute metrics
                    fold_f1 = f1_score(test_labels, test_classes, average='weighted')
                    fold_precision = precision_score(test_labels, test_classes, average='weighted', zero_division=0)
                    fold_recall = recall_score(test_labels, test_classes, average='weighted', zero_division=0)
                    
                    # Calculate specificity
                    cm = confusion_matrix(test_labels, test_classes)
                    if cm.shape == (2, 2):
                        tn, fp = cm[0, 0], cm[0, 1]
                        fold_specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
                    else:
                        fold_specificity = 0
                    
                    # MCC metrics
                    fold_mcc = matthews_corrcoef(test_labels, test_classes)
                    fold_norm_mcc = (fold_mcc + 1) / 2
                    
                    # Ensure non-zero F1
                    fold_f1 = max(fold_f1, 0.01)
                    
                    # Store metrics
                    model_metrics['f1_scores'].append(fold_f1)
                    model_metrics['precisions'].append(fold_precision)
                    model_metrics['recalls'].append(fold_recall)
                    model_metrics['specificities'].append(fold_specificity)
                    model_metrics['mccs'].append(fold_mcc)
                    model_metrics['norm_mccs'].append(fold_norm_mcc)
                    
                    # Log metrics
                    logging.info(f"Fold {fold} metrics:")
                    logging.info(f"  F1-score: {fold_f1:.4f}")
                    logging.info(f"  Precision: {fold_precision:.4f}")
                    logging.info(f"  Recall: {fold_recall:.4f}")
                    logging.info(f"  Specificity: {fold_specificity:.4f}")
                    logging.info(f"  MCC: {fold_mcc:.4f}")
                    
                else:
                    # Load from config if available
                    _load_metrics_from_config(
                        output_dir, model_name, window_name, fold, model_metrics
                    )
                
                logging.info(f"Loaded fold {fold} model")
                
            else:
                logging.warning(f"Model not found for fold {fold}: {model_path}")
        
        # Create ensemble
        if models:
            # Calculate F1-based weights
            weights = np.array(model_metrics['f1_scores'])
            weights = weights / weights.sum()
            logging.info(f"F1-based weights: {weights}")
            
            # Create ensemble model
            ensemble_model = F1WeightedEnsemble(
                models=models,
                weights=weights 
            ).to(device)
            
            if device_ids and len(device_ids) > 1:
                ensemble_model = nn.DataParallel(ensemble_model, device_ids=device_ids)
            
            # Save ensemble
            ensemble_path = os.path.join(output_dir, f"{model_name}_{window_name}_ensemble.pth")
            torch.save(ensemble_model.state_dict(), ensemble_path)
            
            # Calculate average metrics
            avg_metrics = {}
            for metric_name, values in model_metrics.items():
                if values:
                    avg_metrics[f'avg_{metric_name}'] = float(np.mean(values))
                    avg_metrics[f'std_{metric_name}'] = float(np.std(values))
            
            # Save configuration
            ensemble_config = {
                'model_name': model_name,
                'window_name': window_name,
                'ensemble_weights': weights.tolist(),
                'ensemble_method': 'f1_weighted',
                'model_params': {
                    'input_size': input_size,
                    'hidden_sizes': hidden_sizes,
                    'output_size': output_size,
                    'dropout_prob': dropout_prob,
                    'nonlinearity': nonlinearity
                },
                'fold_metrics': model_metrics,
                'average_metrics': avg_metrics,
                'num_folds': num_folds
            }
            
            config_path = os.path.join(output_dir, f"{model_name}_{window_name}_ensemble_config.json")
            with open(config_path, 'w', encoding='utf-8') as f:
                json.dump(ensemble_config, f, indent=4, ensure_ascii=False)
            
            # Log results
            logging.info(f"Ensemble model saved: {ensemble_path}")
            logging.info("Average metrics across folds:")
            for metric_name, value in avg_metrics.items():
                if metric_name.startswith('avg_'):
                    display_name = metric_name[4:].replace('_', ' ').title()
                    logging.info(f"  {display_name}: {value:.4f}")
            
        else:
            logging.error("No models available for ensemble")
            
    except Exception as e:
        logging.error(f"Error in ensemble process: {str(e)}")
        logging.error("Detailed error:", exc_info=True)


def _load_metrics_from_config(
    output_dir: str,
    model_name: str,
    window_name: str,
    fold: int,
    model_metrics: Dict[str, List[float]]
) -> None:
    """
    Load metrics from saved configuration file.
    
    Args:
        output_dir: Output directory
        model_name: Model identifier
        window_name: Window identifier
        fold: Fold number
        model_metrics: Metrics dictionary
    """
    config_path = os.path.join(output_dir, f"{model_name}_{window_name}_summary.json")
    
    if os.path.exists(config_path):
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
            fold_results = config.get('fold_details', [])
            fold_result = next((r for r in fold_results if r['fold'] == fold), None)
            
            if fold_result:
                model_metrics['f1_scores'].append(fold_result.get('val_f1', 1.0))
                model_metrics['precisions'].append(fold_result.get('val_precision', 0.0))
                model_metrics['recalls'].append(fold_result.get('val_recall', 0.0))
                model_metrics['specificities'].append(fold_result.get('val_specificity', 0.0))
                model_metrics['mccs'].append(fold_result.get('val_mcc', 0.0))
                model_metrics['norm_mccs'].append(fold_result.get('val_norm_mcc', 0.5))
                logging.info(f"Fold {fold} F1-score (from config): {fold_result.get('val_f1', 1.0):.4f}")
            else:
                _add_default_metrics(model_metrics)
                logging.warning(f"Fold {fold} results not found, using defaults")
    else:
        _add_default_metrics(model_metrics)
        logging.warning(f"Config file not found for fold {fold}, using defaults")


def _add_default_metrics(model_metrics: Dict[str, List[float]]) -> None:
    """
    Add default metric values.
    
    Args:
        model_metrics: Metrics dictionary
    """
    model_metrics['f1_scores'].append(1.0)
    model_metrics['precisions'].append(0.0)
    model_metrics['recalls'].append(0.0)
    model_metrics['specificities'].append(0.0)
    model_metrics['mccs'].append(0.0)
    model_metrics['norm_mccs'].append(0.5)


# ==================== Model Loading Utilities ====================

def load_ensemble_model(
    model_path: str,
    config_path: str,
    device: torch.device,
    model_class: Any,
    device_ids: Optional[List[int]] = None
) -> F1WeightedEnsemble:
    """
    Load saved ensemble model with configuration.
    
    Args:
        model_path: Path to ensemble checkpoint
        config_path: Path to configuration JSON
        device: Device to load model
        model_class: Base model class
        device_ids: GPU device IDs
        
    Returns:
        Loaded F1WeightedEnsemble
    """
    # Load configuration
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)
    
    # Create base models
    base_models = []
    num_folds = config.get('num_folds', len(config['ensemble_weights']))
    model_params = config['model_params']
    
    for _ in range(num_folds):
        model_kwargs = {
            'input_size': model_params['input_size'],
            'hidden_sizes': model_params['hidden_sizes'],
            'output_size': model_params['output_size'],
            'dropout_prob': model_params['dropout_prob']
        }
        
        # Add RNN-specific parameters if present
        if 'nonlinearity' in model_params:
            model_kwargs['nonlinearity'] = model_params['nonlinearity']
        
        model = model_class(**model_kwargs).to(device)
        base_models.append(model)
    
    # Create ensemble
    ensemble_model = F1WeightedEnsemble(
        models=base_models,
        weights=config['ensemble_weights']
    ).to(device)
    
    # Apply DataParallel
    if device_ids and len(device_ids) > 1:
        ensemble_model = nn.DataParallel(ensemble_model, device_ids=device_ids)
    
    # Load weights
    state_dict = torch.load(model_path, map_location=device)
    ensemble_model.load_state_dict(state_dict)
    ensemble_model.eval()
    
    logging.info(f"Loaded ensemble model from {model_path}")
    logging.info(f"Ensemble method: {config.get('ensemble_method', 'unknown')}")
    logging.info(f"Number of base models: {num_folds}")
    
    return ensemble_model


# ==================== Entry Point ====================

if __name__ == "__main__":
    """
    Execute the complete training pipeline.
    """
    main_rnn(optimize=False)

2025-09-18 07:31:11 - __main__ - INFO - setup_logging:127 - Logging initialized. Output file: logs\training_20250918_073111.log
2025-09-18 07:31:11 - root - INFO - setup_device:161 - Auto-detected 1 GPU(s)
2025-09-18 07:31:11 - root - INFO - setup_device:177 - Single GPU training on device 0
2025-09-18 07:31:11 - root - INFO - main_rnn:40 - Starting training pipeline - Device: cuda:0
2025-09-18 07:31:11 - root - INFO - main_rnn:86 - 
2025-09-18 07:31:11 - root - INFO - main_rnn:87 - Processing time window: 7day
2025-09-18 07:31:11 - root - INFO - main_rnn:103 - Class distribution - Class 0: 1661, Class 1: 999
2025-09-18 07:31:11 - root - INFO - calculate_class_weights:245 - Class distribution: [1661, 999]
2025-09-18 07:31:11 - root - INFO - calculate_class_weights:246 - Imbalance ratio: 1.66:1
2025-09-18 07:31:11 - root - INFO - calculate_class_weights:268 - Class weights (sqrt): [0.8735750317573547, 1.12642502784729]
2025-09-18 07:31:11 - root - INFO - main_rnn:107 - Using sqrt-scaled



2025-09-18 07:31:12 - __main__ - INFO - train_eval_rnn:96 - New best model at epoch 1 (F1: 0.6245)
2025-09-18 07:31:12 - __main__ - INFO - train_eval_rnn:130 - Epoch 1/100, Train Loss: 0.8659, Train Acc: 0.5893, Class 0/1: 0.5117/0.6615, Train F1: 0.5869, Train Precision: 0.5888, Train Recall: 0.5893, Train Specificity: 0.5117, Train MCC: 0.1753, Train Norm MCC: 0.5876, Val Loss: 0.6975, Val Acc: 0.6184, Class 0/1: 0.6096/0.6332, Val F1: 0.6245, Val Precision: 0.6445, Val Recall: 0.6184, Val Specificity: 0.6096, Val MCC: 0.2351, Val Norm MCC: 0.6176, LR: 0.001000
2025-09-18 07:31:12 - __main__ - INFO - train_eval_rnn:130 - Epoch 2/100, Train Loss: 0.7338, Train Acc: 0.6118, Class 0/1: 0.5046/0.7253, Train F1: 0.6074, Train Precision: 0.6215, Train Recall: 0.6118, Train Specificity: 0.5046, Train MCC: 0.2353, Train Norm MCC: 0.6176, Val Loss: 0.7773, Val Acc: 0.5996, Class 0/1: 0.6396/0.5327, Val F1: 0.6039, Val Precision: 0.6111, Val Recall: 0.5996, Val Specificity: 0.6396, Val MCC: 0.



2025-09-18 07:32:04 - __main__ - INFO - train_eval_rnn:96 - New best model at epoch 1 (F1: 0.5998)
2025-09-18 07:32:04 - __main__ - INFO - train_eval_rnn:130 - Epoch 1/100, Train Loss: 0.8822, Train Acc: 0.5949, Class 0/1: 0.5439/0.6455, Train F1: 0.5939, Train Precision: 0.5957, Train Recall: 0.5949, Train Specificity: 0.5439, Train MCC: 0.1904, Train Norm MCC: 0.5952, Val Loss: 0.7034, Val Acc: 0.6259, Class 0/1: 0.8193/0.3050, Val F1: 0.5998, Val Precision: 0.6025, Val Recall: 0.6259, Val Specificity: 0.8193, Val MCC: 0.1436, Val Norm MCC: 0.5718, LR: 0.001000
2025-09-18 07:32:05 - __main__ - INFO - train_eval_rnn:130 - Epoch 2/100, Train Loss: 0.7500, Train Acc: 0.6283, Class 0/1: 0.5360/0.7214, Train F1: 0.6251, Train Precision: 0.6334, Train Recall: 0.6283, Train Specificity: 0.5360, Train MCC: 0.2619, Train Norm MCC: 0.6310, Val Loss: 0.7672, Val Acc: 0.5771, Class 0/1: 0.5422/0.6350, Val F1: 0.5834, Val Precision: 0.6151, Val Recall: 0.5771, Val Specificity: 0.5422, Val MCC: 0.



2025-09-18 07:32:59 - __main__ - INFO - train_eval_rnn:96 - New best model at epoch 1 (F1: 0.4528)
2025-09-18 07:32:59 - __main__ - INFO - train_eval_rnn:130 - Epoch 1/100, Train Loss: 0.8776, Train Acc: 0.5874, Class 0/1: 0.4947/0.6772, Train F1: 0.5839, Train Precision: 0.5888, Train Recall: 0.5874, Train Specificity: 0.4947, Train MCC: 0.1749, Train Norm MCC: 0.5874, Val Loss: 0.9135, Val Acc: 0.4887, Class 0/1: 0.2590/0.8700, Val F1: 0.4528, Val Precision: 0.6349, Val Recall: 0.4887, Val Specificity: 0.2590, Val MCC: 0.1533, Val Norm MCC: 0.5767, LR: 0.001000
2025-09-18 07:32:59 - __main__ - INFO - train_eval_rnn:96 - New best model at epoch 2 (F1: 0.6041)
2025-09-18 07:32:59 - __main__ - INFO - train_eval_rnn:130 - Epoch 2/100, Train Loss: 0.7569, Train Acc: 0.6198, Class 0/1: 0.5515/0.6850, Train F1: 0.6180, Train Precision: 0.6204, Train Recall: 0.6198, Train Specificity: 0.5515, Train MCC: 0.2388, Train Norm MCC: 0.6194, Val Loss: 0.6937, Val Acc: 0.6071, Class 0/1: 0.4639/0.84



2025-09-18 07:33:58 - __main__ - INFO - train_eval_rnn:96 - New best model at epoch 1 (F1: 0.3639)
2025-09-18 07:33:58 - __main__ - INFO - train_eval_rnn:130 - Epoch 1/100, Train Loss: 0.8483, Train Acc: 0.5827, Class 0/1: 0.5168/0.6493, Train F1: 0.5809, Train Precision: 0.5846, Train Recall: 0.5827, Train Specificity: 0.5168, Train MCC: 0.1676, Train Norm MCC: 0.5838, Val Loss: 1.1595, Val Acc: 0.4417, Class 0/1: 0.1476/0.9300, Val F1: 0.3639, Val Precision: 0.6345, Val Recall: 0.4417, Val Specificity: 0.1476, Val MCC: 0.1163, Val Norm MCC: 0.5582, LR: 0.001000
2025-09-18 07:33:59 - __main__ - INFO - train_eval_rnn:96 - New best model at epoch 2 (F1: 0.6165)
2025-09-18 07:33:59 - __main__ - INFO - train_eval_rnn:130 - Epoch 2/100, Train Loss: 0.7269, Train Acc: 0.6259, Class 0/1: 0.5484/0.7097, Train F1: 0.6239, Train Precision: 0.6334, Train Recall: 0.6259, Train Specificity: 0.5484, Train MCC: 0.2609, Train Norm MCC: 0.6305, Val Loss: 0.7404, Val Acc: 0.6165, Class 0/1: 0.6928/0.49



2025-09-18 07:35:12 - __main__ - INFO - train_eval_rnn:96 - New best model at epoch 1 (F1: 0.5628)
2025-09-18 07:35:12 - __main__ - INFO - train_eval_rnn:130 - Epoch 1/100, Train Loss: 0.8540, Train Acc: 0.5926, Class 0/1: 0.5132/0.6664, Train F1: 0.5901, Train Precision: 0.5921, Train Recall: 0.5926, Train Specificity: 0.5132, Train MCC: 0.1817, Train Norm MCC: 0.5909, Val Loss: 0.8503, Val Acc: 0.5846, Class 0/1: 0.7651/0.2850, Val F1: 0.5628, Val Precision: 0.5580, Val Recall: 0.5846, Val Specificity: 0.7651, Val MCC: 0.0557, Val Norm MCC: 0.5279, LR: 0.001000
2025-09-18 07:35:12 - __main__ - INFO - train_eval_rnn:96 - New best model at epoch 2 (F1: 0.6088)
2025-09-18 07:35:12 - __main__ - INFO - train_eval_rnn:130 - Epoch 2/100, Train Loss: 0.7889, Train Acc: 0.5851, Class 0/1: 0.4366/0.7313, Train F1: 0.5758, Train Precision: 0.5918, Train Recall: 0.5851, Train Specificity: 0.4366, Train MCC: 0.1758, Train Norm MCC: 0.5879, Val Loss: 0.6974, Val Acc: 0.6034, Class 0/1: 0.5482/0.69



2025-09-18 07:36:33 - __main__ - INFO - train_eval_rnn:96 - New best model at epoch 1 (F1: 0.6481)
2025-09-18 07:36:33 - __main__ - INFO - train_eval_rnn:130 - Epoch 1/100, Train Loss: 0.8689, Train Acc: 0.5571, Class 0/1: 0.4807/0.6381, Train F1: 0.5546, Train Precision: 0.5615, Train Recall: 0.5571, Train Specificity: 0.4807, Train MCC: 0.1202, Train Norm MCC: 0.5601, Val Loss: 0.7370, Val Acc: 0.6414, Class 0/1: 0.7074/0.4909, Val F1: 0.6481, Val Precision: 0.6575, Val Recall: 0.6414, Val Specificity: 0.7074, Val MCC: 0.1911, Val Norm MCC: 0.5955, LR: 0.001000
2025-09-18 07:36:34 - __main__ - INFO - train_eval_rnn:130 - Epoch 2/100, Train Loss: 0.7727, Train Acc: 0.5673, Class 0/1: 0.4033/0.7311, Train F1: 0.5553, Train Precision: 0.5753, Train Recall: 0.5673, Train Specificity: 0.4033, Train MCC: 0.1422, Train Norm MCC: 0.5711, Val Loss: 1.4256, Val Acc: 0.3364, Class 0/1: 0.0745/0.9333, Val F1: 0.2346, Val Precision: 0.5925, Val Recall: 0.3364, Val Specificity: 0.0745, Val MCC: 0.



2025-09-18 07:38:25 - __main__ - INFO - train_eval_rnn:96 - New best model at epoch 1 (F1: 0.4123)
2025-09-18 07:38:25 - __main__ - INFO - train_eval_rnn:130 - Epoch 1/100, Train Loss: 0.8211, Train Acc: 0.5834, Class 0/1: 0.4898/0.6762, Train F1: 0.5798, Train Precision: 0.5859, Train Recall: 0.5834, Train Specificity: 0.4898, Train MCC: 0.1689, Train Norm MCC: 0.5845, Val Loss: 0.9814, Val Acc: 0.4436, Class 0/1: 0.2394/0.9091, Val F1: 0.4123, Val Precision: 0.7007, Val Recall: 0.4436, Val Specificity: 0.2394, Val MCC: 0.1728, Val Norm MCC: 0.5864, LR: 0.001000
2025-09-18 07:38:26 - __main__ - INFO - train_eval_rnn:130 - Epoch 2/100, Train Loss: 0.7792, Train Acc: 0.5696, Class 0/1: 0.4209/0.7234, Train F1: 0.5597, Train Precision: 0.5798, Train Recall: 0.5696, Train Specificity: 0.4209, Train MCC: 0.1513, Train Norm MCC: 0.5756, Val Loss: 1.3496, Val Acc: 0.3660, Class 0/1: 0.1197/0.9273, Val F1: 0.2883, Val Precision: 0.6451, Val Recall: 0.3660, Val Specificity: 0.1197, Val MCC: 0.



2025-09-18 07:40:14 - __main__ - INFO - train_eval_rnn:96 - New best model at epoch 1 (F1: 0.3079)
2025-09-18 07:40:14 - __main__ - INFO - train_eval_rnn:130 - Epoch 1/100, Train Loss: 0.8563, Train Acc: 0.5798, Class 0/1: 0.4785/0.6792, Train F1: 0.5754, Train Precision: 0.5821, Train Recall: 0.5798, Train Specificity: 0.4785, Train MCC: 0.1611, Train Norm MCC: 0.5805, Val Loss: 1.1861, Val Acc: 0.3826, Class 0/1: 0.1330/0.9515, Val F1: 0.3079, Val Precision: 0.6983, Val Recall: 0.3826, Val Specificity: 0.1330, Val MCC: 0.1257, Val Norm MCC: 0.5629, LR: 0.001000
2025-09-18 07:40:15 - __main__ - INFO - train_eval_rnn:96 - New best model at epoch 2 (F1: 0.6117)
2025-09-18 07:40:15 - __main__ - INFO - train_eval_rnn:130 - Epoch 2/100, Train Loss: 0.8056, Train Acc: 0.5460, Class 0/1: 0.3726/0.7291, Train F1: 0.5315, Train Precision: 0.5590, Train Recall: 0.5460, Train Specificity: 0.3726, Train MCC: 0.1087, Train Norm MCC: 0.5544, Val Loss: 0.6773, Val Acc: 0.6562, Class 0/1: 0.8697/0.16



2025-09-18 07:42:06 - __main__ - INFO - train_eval_rnn:96 - New best model at epoch 1 (F1: 0.6544)
2025-09-18 07:42:06 - __main__ - INFO - train_eval_rnn:130 - Epoch 1/100, Train Loss: 0.9132, Train Acc: 0.5599, Class 0/1: 0.4564/0.6585, Train F1: 0.5553, Train Precision: 0.5600, Train Recall: 0.5599, Train Specificity: 0.4564, Train MCC: 0.1174, Train Norm MCC: 0.5587, Val Loss: 0.6522, Val Acc: 0.6488, Class 0/1: 0.7173/0.4940, Val F1: 0.6544, Val Precision: 0.6621, Val Recall: 0.6488, Val Specificity: 0.7173, Val MCC: 0.2047, Val Norm MCC: 0.6023, LR: 0.001000
2025-09-18 07:42:07 - __main__ - INFO - train_eval_rnn:130 - Epoch 2/100, Train Loss: 0.7675, Train Acc: 0.5502, Class 0/1: 0.3187/0.7590, Train F1: 0.5268, Train Precision: 0.5485, Train Recall: 0.5502, Train Specificity: 0.3187, Train MCC: 0.0866, Train Norm MCC: 0.5433, Val Loss: 1.3231, Val Acc: 0.3235, Class 0/1: 0.0507/0.9398, Val F1: 0.2064, Val Precision: 0.5476, Val Recall: 0.3235, Val Specificity: 0.0507, Val MCC: -0



2025-09-18 07:43:58 - __main__ - INFO - train_eval_rnn:96 - New best model at epoch 1 (F1: 0.2710)
2025-09-18 07:43:58 - __main__ - INFO - train_eval_rnn:130 - Epoch 1/100, Train Loss: 0.8833, Train Acc: 0.5536, Class 0/1: 0.4140/0.6872, Train F1: 0.5450, Train Precision: 0.5546, Train Recall: 0.5536, Train Specificity: 0.4140, Train MCC: 0.1052, Train Norm MCC: 0.5526, Val Loss: 1.1313, Val Acc: 0.3611, Class 0/1: 0.1013/0.9515, Val F1: 0.2710, Val Precision: 0.6708, Val Recall: 0.3611, Val Specificity: 0.1013, Val MCC: 0.0872, Val Norm MCC: 0.5436, LR: 0.001000
2025-09-18 07:43:59 - __main__ - INFO - train_eval_rnn:96 - New best model at epoch 2 (F1: 0.5117)
2025-09-18 07:43:59 - __main__ - INFO - train_eval_rnn:130 - Epoch 2/100, Train Loss: 0.7768, Train Acc: 0.5508, Class 0/1: 0.3578/0.7467, Train F1: 0.5333, Train Precision: 0.5617, Train Recall: 0.5508, Train Specificity: 0.3578, Train MCC: 0.1134, Train Norm MCC: 0.5567, Val Loss: 0.7745, Val Acc: 0.5093, Class 0/1: 0.3760/0.81



2025-09-18 07:45:14 - __main__ - INFO - train_eval_rnn:96 - New best model at epoch 1 (F1: 0.2878)
2025-09-18 07:45:14 - __main__ - INFO - train_eval_rnn:130 - Epoch 1/100, Train Loss: 0.8618, Train Acc: 0.5662, Class 0/1: 0.4442/0.6815, Train F1: 0.5599, Train Precision: 0.5666, Train Recall: 0.5662, Train Specificity: 0.4442, Train MCC: 0.1294, Train Norm MCC: 0.5647, Val Loss: 1.4354, Val Acc: 0.3895, Class 0/1: 0.0912/0.9416, Val F1: 0.2878, Val Precision: 0.6082, Val Recall: 0.3895, Val Specificity: 0.0912, Val MCC: 0.0578, Val Norm MCC: 0.5289, LR: 0.001000
2025-09-18 07:45:15 - __main__ - INFO - train_eval_rnn:130 - Epoch 2/100, Train Loss: 0.8385, Train Acc: 0.5428, Class 0/1: 0.3696/0.7009, Train F1: 0.5298, Train Precision: 0.5400, Train Recall: 0.5428, Train Specificity: 0.3696, Train MCC: 0.0747, Train Norm MCC: 0.5374, Val Loss: 0.9795, Val Acc: 0.3827, Class 0/1: 0.0947/0.9156, Val F1: 0.2868, Val Precision: 0.5622, Val Recall: 0.3827, Val Specificity: 0.0947, Val MCC: 0.



2025-09-18 07:46:34 - __main__ - INFO - train_eval_rnn:96 - New best model at epoch 1 (F1: 0.5959)
2025-09-18 07:46:34 - __main__ - INFO - train_eval_rnn:130 - Epoch 1/100, Train Loss: 0.8737, Train Acc: 0.5779, Class 0/1: 0.5123/0.6397, Train F1: 0.5761, Train Precision: 0.5774, Train Recall: 0.5779, Train Specificity: 0.5123, Train MCC: 0.1533, Train Norm MCC: 0.5767, Val Loss: 0.7026, Val Acc: 0.6050, Class 0/1: 0.7394/0.3571, Val F1: 0.5959, Val Precision: 0.5906, Val Recall: 0.6050, Val Specificity: 0.7394, Val MCC: 0.1012, Val Norm MCC: 0.5506, LR: 0.001000
2025-09-18 07:46:34 - __main__ - INFO - train_eval_rnn:130 - Epoch 2/100, Train Loss: 0.7178, Train Acc: 0.6047, Class 0/1: 0.4204/0.7723, Train F1: 0.5915, Train Precision: 0.6098, Train Recall: 0.6047, Train Specificity: 0.4204, Train MCC: 0.2064, Train Norm MCC: 0.6032, Val Loss: 0.8225, Val Acc: 0.5320, Class 0/1: 0.3486/0.8701, Val F1: 0.5178, Val Precision: 0.6871, Val Recall: 0.5320, Val Specificity: 0.3486, Val MCC: 0.



2025-09-18 07:47:53 - __main__ - INFO - train_eval_rnn:96 - New best model at epoch 1 (F1: 0.4851)
2025-09-18 07:47:53 - __main__ - INFO - train_eval_rnn:130 - Epoch 1/100, Train Loss: 0.8517, Train Acc: 0.5881, Class 0/1: 0.4604/0.7040, Train F1: 0.5817, Train Precision: 0.5877, Train Recall: 0.5881, Train Specificity: 0.4604, Train MCC: 0.1697, Train Norm MCC: 0.5849, Val Loss: 0.7255, Val Acc: 0.4977, Class 0/1: 0.3310/0.8052, Val F1: 0.4851, Val Precision: 0.6304, Val Recall: 0.4977, Val Specificity: 0.3310, Val MCC: 0.1443, Val Norm MCC: 0.5722, LR: 0.001000
2025-09-18 07:47:54 - __main__ - INFO - train_eval_rnn:96 - New best model at epoch 2 (F1: 0.5854)
2025-09-18 07:47:54 - __main__ - INFO - train_eval_rnn:130 - Epoch 2/100, Train Loss: 0.7456, Train Acc: 0.5744, Class 0/1: 0.3728/0.7656, Train F1: 0.5570, Train Precision: 0.5815, Train Recall: 0.5744, Train Specificity: 0.3728, Train MCC: 0.1507, Train Norm MCC: 0.5753, Val Loss: 0.6883, Val Acc: 0.6507, Class 0/1: 0.9190/0.15



2025-09-18 07:48:47 - __main__ - INFO - train_eval_rnn:96 - New best model at epoch 1 (F1: 0.6104)
2025-09-18 07:48:47 - __main__ - INFO - train_eval_rnn:130 - Epoch 1/100, Train Loss: 0.9411, Train Acc: 0.5705, Class 0/1: 0.4926/0.6487, Train F1: 0.5678, Train Precision: 0.5725, Train Recall: 0.5705, Train Specificity: 0.4926, Train MCC: 0.1431, Train Norm MCC: 0.5715, Val Loss: 0.7219, Val Acc: 0.6735, Class 0/1: 0.9401/0.1818, Val F1: 0.6104, Val Precision: 0.6593, Val Recall: 0.6735, Val Specificity: 0.9401, Val MCC: 0.1918, Val Norm MCC: 0.5959, LR: 0.001000
2025-09-18 07:48:48 - __main__ - INFO - train_eval_rnn:130 - Epoch 2/100, Train Loss: 0.7905, Train Acc: 0.5682, Class 0/1: 0.3979/0.7292, Train F1: 0.5557, Train Precision: 0.5712, Train Recall: 0.5682, Train Specificity: 0.3979, Train MCC: 0.1348, Train Norm MCC: 0.5674, Val Loss: 0.7028, Val Acc: 0.4795, Class 0/1: 0.3204/0.7727, Val F1: 0.4674, Val Precision: 0.6024, Val Recall: 0.4795, Val Specificity: 0.3204, Val MCC: 0.



2025-09-18 07:49:37 - __main__ - INFO - train_eval_rnn:96 - New best model at epoch 1 (F1: 0.5217)
2025-09-18 07:49:37 - __main__ - INFO - train_eval_rnn:130 - Epoch 1/100, Train Loss: 0.9139, Train Acc: 0.5499, Class 0/1: 0.4696/0.6293, Train F1: 0.5470, Train Precision: 0.5507, Train Recall: 0.5499, Train Specificity: 0.4696, Train MCC: 0.1001, Train Norm MCC: 0.5501, Val Loss: 0.7717, Val Acc: 0.5434, Class 0/1: 0.7254/0.2078, Val F1: 0.5217, Val Precision: 0.5095, Val Recall: 0.5434, Val Specificity: 0.7254, Val MCC: -0.0736, Val Norm MCC: 0.4632, LR: 0.001000
2025-09-18 07:49:37 - __main__ - INFO - train_eval_rnn:130 - Epoch 2/100, Train Loss: 0.8264, Train Acc: 0.5516, Class 0/1: 0.4025/0.7062, Train F1: 0.5412, Train Precision: 0.5602, Train Recall: 0.5516, Train Specificity: 0.4025, Train MCC: 0.1139, Train Norm MCC: 0.5570, Val Loss: 1.8062, Val Acc: 0.3881, Class 0/1: 0.0599/0.9935, Val F1: 0.2604, Val Precision: 0.7405, Val Recall: 0.3881, Val Specificity: 0.0599, Val MCC: 0