<a href="https://colab.research.google.com/github/ninja-marduk/ml_precipitation_prediction/blob/feature%2Fhybrid-models/models/hybrid_models_enconders_layering_w3_ST-HybridWaveStack.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Meta-Modelo Convolucional: ElevClusConvPrecipMetaNet

Este modelo implementa una arquitectura convolucional avanzada para la predicción espaciotemporal de precipitación con horizonte de 12 meses.

## Arquitectura

1. **Entradas**:
   - Mapas de predicción de los modelos base (ConvBiGRU-AE y ConvLSTM-AE) para cada horizonte (12 meses)
   - Información de elevación y clusters para condicionamiento (FiLM)

2. **Reducción Temprana de Canales**:
   - Reduce los 24 canales (2 modelos × 12 horizontes) a 16 para optimizar memoria
   - Aplica Conv2D(1×1) para mezclar información sin perder resolución espacial

3. **Bloques Residuales Multiescala**:
   - Bloques depthwise-separables con distintas dilataciones (1,2,4)
   - Captura patrones a diferentes escalas espaciales sin incrementar parámetros

4. **Atención Espacial por Cluster**:
   - FiLM (Feature-wise Linear Modulation): γ_cluster ⊗ F + β_cluster
   - Adapta el comportamiento según el régimen orográfico

5. **U-Net Compacto**:
   - Arquitectura de encoder-decoder con skip connections
   - Solo 2 niveles de downsampling para preservar detalle

6. **Agrupamiento de Horizontes**:
   - Conv3D para procesar conjuntamente la dimensión temporal de horizontes
   - Permite aprender relaciones entre meses consecutivos

7. **Salida Multi-Horizonte**:
   - Genera los 12 mapas refinados de predicción

8. **Estrategias Memory-Friendly**:
   - Mixed precision (float16)
   - Gradient checkpointing
   - Acumulación de gradientes
   - Entrenamiento por etapas

In [None]:
# Predicción Espaciotemporal de Precipitación Mensual - Notebook Completo

# 0) Configuración del entorno, rutas y dependencias
import sys
import os
import logging
from pathlib import Path
import joblib  # Para persistir scalers
from datetime import datetime

# Root logging setup: INFO and above, formatted as "<time> - <level> - <message>".
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)


def log_and_print(msg):
    """Record ``msg`` on the module logger at INFO level, then echo it with ``print``."""
    # Logger first, stdout second.
    for emit in (logger.info, print):
        emit(msg)

# Detect the runtime environment (Colab or local): Colab is assumed when the
# "google.colab" module is already loaded in this interpreter.
IN_COLAB = "google.colab" in sys.modules
log_and_print(f"Ejecutando en Colab: {IN_COLAB}")

# Resolve BASE_PATH, the root from which the data/ and models/ paths below are built
desired_repo = 'ml_precipitation_prediction'
if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)
    # BASE_PATH points at the folder on the mounted Drive, while the working
    # directory is switched to the checkout in the current directory (cloned
    # just below if it is not already there).
    BASE_PATH = Path('/content/drive/MyDrive') / desired_repo
    if not Path(desired_repo).exists():
        log_and_print("Clonando repositorio...")
        get_ipython().system('git clone https://github.com/ninja-marduk/ml_precipitation_prediction.git')
    os.chdir(desired_repo)
    get_ipython().system('pip install -q xarray netCDF4 optuna matplotlib seaborn lightgbm xgboost scikit-learn ace_tools_open cartopy geopandas joblib')
else:
    # Simpler installation without PyEMD
    %pip install -q xarray netCDF4 scikit-image
    current = Path.cwd()
    # Walk from the current directory up through its parents; the first one
    # holding a .git directory, requirements.txt or README.md becomes the root.
    for p in [current] + list(current.parents):
        if (p / '.git').is_dir() or (p / 'requirements.txt').is_file() or (p / 'README.md').is_file():
            BASE_PATH = p
            break
    else:
        # for/else: no marker was found anywhere up the tree, use the cwd itself
        BASE_PATH = current
    log_and_print(f"Ejecutando en local. Base path: {BASE_PATH}")

# Data and model paths, all derived from BASE_PATH
DATA_OUTPUT   = BASE_PATH / 'data' / 'output'
MODELS_OUTPUT = BASE_PATH / 'models' / 'output'
PREDS_DIR     = MODELS_OUTPUT / 'base_model_predictions'
SHP_PATH      = BASE_PATH / 'data' / 'input' / 'shapes' / 'MGN_Departamento.shp'

# Create the output directories (no error if they already exist)
MODELS_OUTPUT.mkdir(parents=True, exist_ok=True)
PREDS_DIR.mkdir(parents=True, exist_ok=True)

# General parameters
# INPUT_WINDOW / OUTPUT_HORIZON are the default window sizes (in time steps) of
# prepare_train_val_data.
# NOTE(review): train_and_compare_models hard-codes batch size 16, lr 0.001,
# 100 epochs and patience 20 instead of reading BATCH_SIZE / LR / MAX_EPOCHS /
# PATIENCE — confirm which values are intended.
INPUT_WINDOW   = 60
OUTPUT_HORIZON = 12  # 12 months
BATCH_SIZE     = 16
MAX_EPOCHS     = 300
PATIENCE       = 50
LR             = 1e-3

import torch
# Use the GPU when CUDA is available, otherwise fall back to the CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
log_and_print(f"Usando dispositivo: {DEVICE}")


# 1) Imports adicionales y utilidades
import numpy as np
import pandas as pd
import xarray as xr
import geopandas as gpd
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature  # Añadido para características cartográficas
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.metrics import mean_absolute_percentage_error

import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler # Mixed precision training
from torch.utils.checkpoint import checkpoint # Gradient checkpointing
from torch.optim.lr_scheduler import ReduceLROnPlateau  # Añadido para learning rate adaptativo
from IPython.display import clear_output  # Para actualizar gráficos durante entrenamiento
import pywt
from scipy.signal import hilbert
from skimage.restoration import denoise_wavelet

# 2) Carga y preprocesamiento de datos - Versión simplificada

def load_and_preprocess_data():
    """Open the input datasets and the shapefile, and locate the reference month.

    Reads three NetCDF files (the complete dataset under DATA_OUTPUT and the
    CEEMDAN / TVF-EMD feature files under MODELS_OUTPUT) and the shapefile at
    SHP_PATH, which is normalised to EPSG:4326. The datasets are returned
    open, so closing them is the caller's responsibility.

    Returns:
        tuple: ``(ds_full, ds_ceemdan, ds_tvfemd, gdf, times, REF, idx_ref)``
        where ``times`` is the time coordinate of ``ds_full`` cast to month
        resolution, ``REF`` is the reference month (2024-02) and ``idx_ref``
        is the position of ``REF`` inside ``times``.

    Raises:
        IndexError: if the reference month is not present in ``ds_full.time``.
    """
    log_and_print("Cargando datos...")

    # Complete dataset
    ds_full = xr.open_dataset(DATA_OUTPUT / 'complete_dataset_with_features_with_clusters_elevation_with_windows.nc')
    log_and_print(f"Dataset completo cargado, dims: {ds_full.dims}")

    # Decomposition features, each loaded from its own file
    ds_ceemdan = xr.open_dataset(MODELS_OUTPUT / 'features_CEEMDAN.nc')
    log_and_print(f"Dataset CEEMDAN cargado, dims: {ds_ceemdan.dims}")
    log_and_print(f"Variables disponibles en CEEMDAN: {list(ds_ceemdan.data_vars.keys())}")

    ds_tvfemd = xr.open_dataset(MODELS_OUTPUT / 'features_TVFEMD.nc')
    log_and_print(f"Dataset TVF-EMD cargado, dims: {ds_tvfemd.dims}")
    log_and_print(f"Variables disponibles en TVF-EMD: {list(ds_tvfemd.data_vars.keys())}")

    # Shapefile used for the map visualisations; force it to EPSG:4326
    gdf = gpd.read_file(SHP_PATH)
    if gdf.crs is None:
        gdf = gdf.set_crs(epsg=4326)
    elif gdf.crs.to_epsg() != 4326:
        gdf = gdf.to_crs(epsg=4326)
    log_and_print("Shapefile cargado y CRS validado.")

    # Temporal information, truncated to month resolution so it compares with REF
    times = ds_full.time.values.astype('datetime64[M]')
    REF = np.datetime64('2024-02','M')
    matches = np.where(times == REF)[0]
    if matches.size == 0:
        # Same exception type the bare [0][0] lookup used to raise, but with a
        # message that says which month is missing and what the data covers.
        covered = f"{times[0]} .. {times[-1]}" if times.size else "eje temporal vacío"
        raise IndexError(
            f"La fecha de referencia {REF} no está en el eje temporal del dataset ({covered})"
        )
    idx_ref = int(matches[0])
    log_and_print(f"Referencia (REF) = {REF}, index={idx_ref}")

    return ds_full, ds_ceemdan, ds_tvfemd, gdf, times, REF, idx_ref

# Run the loading/preprocessing step; the results are kept as module-level globals
ds_full, ds_ceemdan, ds_tvfemd, gdf, times, REF, idx_ref = load_and_preprocess_data()

# Preparar datos para entrenamiento y evaluación
def prepare_train_val_data(ds_full, input_window=INPUT_WINDOW, output_horizon=OUTPUT_HORIZON):
    """Build sliding-window training and validation arrays from the full dataset.

    The target is the first of ``precipitacion`` / ``total_precipitation``
    found in ``ds_full``. Inputs are the ``cluster`` and ``elevation``
    variables (2-D fields are repeated along time) stacked as channels; when
    neither exists, precipitation itself is used as a single channel.

    Args:
        ds_full: Dataset exposing ``data_vars`` and the variables as attributes.
        input_window: Number of time steps in each input window.
        output_horizon: Number of time steps in each target window.

    Returns:
        X_train, y_train, X_val, y_val: the first 80% of the windows (in time
        order, no shuffling) for training and the remainder for validation.
        X windows are (input_window, channels, H, W); y windows are
        (output_horizon, H, W).

    Raises:
        ValueError: if no precipitation variable exists, the target is not 3-D,
            or the series is too short to form a single window.
    """
    print("Preparando datos para entrenamiento y validación...")

    # Target variable: take the first candidate present in the dataset
    target_candidates = (
        ('precipitacion', "Using 'precipitacion' as target variable"),
        ('total_precipitation', "Using 'total_precipitation' as target variable"),
    )
    for var_name, notice in target_candidates:
        if var_name in ds_full.data_vars:
            precip = getattr(ds_full, var_name).values
            print(notice)
            break
    else:
        raise ValueError("Could not find precipitation variable in dataset. Available variables: " + 
                        str(list(ds_full.data_vars.keys())))

    # The target must be (time, lat, lon)
    if precip.ndim != 3:
        raise ValueError(f"Expected precipitation data to have 3 dimensions (time, lat, lon), got {precip.shape}")

    n_times, height, width = precip.shape
    print(f"Precipitation data shape: {precip.shape}")

    # Static conditioning fields, each broadcast to every time step when 2-D
    static_candidates = (
        ('cluster', "Añadida variable 'cluster', shape:"),
        ('elevation', "Añadida variable 'elevation', shape:"),
    )
    features = []
    for var_name, notice in static_candidates:
        if var_name not in ds_full.data_vars:
            continue
        layer = getattr(ds_full, var_name).values
        if layer.ndim == 2:
            layer = np.repeat(layer[np.newaxis, :, :], n_times, axis=0)
        features.append(layer)
        print(notice, layer.shape)

    if features:
        # Stack the collected fields along a new channel axis
        features_array = np.stack(features, axis=1)
        print(f"Shape de features combinadas: {features_array.shape}")
    else:
        # Nothing to condition on: fall back to precipitation as one channel
        print("No se encontraron features específicas, usando precipitación histórica")
        features_array = precip.reshape(n_times, 1, height, width)
        print(f"Shape de features: {features_array.shape}")

    # Sliding windows: inputs cover [s, s + input_window), targets the
    # output_horizon steps that follow immediately after.
    window_starts = range(n_times - input_window - output_horizon + 1)
    X = [features_array[s:s + input_window].copy() for s in window_starts]
    y = [precip[s + input_window:s + input_window + output_horizon].copy() for s in window_starts]

    if not (X and y):
        raise ValueError("No se pudieron crear ventanas válidas. Verifique los datos de entrada.")

    # np.stack (rather than np.array) guarantees homogeneous arrays
    X, y = np.stack(X), np.stack(y)

    print(f"Datos preparados - X shape: {X.shape}, y shape: {y.shape}")

    # Chronological 80/20 train/validation split
    cut = int(0.8 * len(X))
    X_train, y_train = X[:cut], y[:cut]
    X_val, y_val = X[cut:], y[cut:]

    print(f"Train shapes - X: {X_train.shape}, y: {y_train.shape}")
    print(f"Val shapes - X: {X_val.shape}, y: {y_val.shape}")

    return X_train, y_train, X_val, y_val

def _train_and_persist(model_name, config, train_loader, val_loader, **train_kwargs):
    """Train ``config['class']`` from scratch, store its history in ``config`` and save it.

    Extra keyword arguments are forwarded to ``train_hybrid_model``. Returns
    the trained model.
    """
    model = config['class'](**config['params']).to(DEVICE)

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
    # NOTE(review): `verbose` is deprecated in recent PyTorch releases — confirm
    # against the installed version before upgrading.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10, verbose=True)

    result = train_hybrid_model(
        name=model_name,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        epochs=100,
        patience=20,
        optimizer=optimizer,
        criterion=criterion,
        scheduler=scheduler,
        **train_kwargs
    )
    # The two original call sites disagreed on whether the trainer returns a
    # learning-rate history as a fourth element; accept both shapes.
    model, train_losses, val_losses, *extra = result

    config['train_losses'] = train_losses
    config['val_losses'] = val_losses

    save_data = {
        'model_state_dict': model.state_dict(),
        'train_losses': train_losses,
        'val_losses': val_losses
    }
    if extra:
        save_data['learning_rates'] = extra[0]
    torch.save(save_data, config['path'])
    return model


def _align_outputs_and_targets(outputs, targets):
    """Crop/reshape one batch of predictions and targets towards a common shape.

    Returns the adjusted ``(outputs, targets)`` pair. A warning is printed if
    the shapes still differ after every adjustment.
    """
    # Drop a singleton channel axis from 5-D predictions
    if outputs.ndim == 5 and outputs.shape[2] == 1:
        outputs = outputs.squeeze(2)

    # 5-D targets against 4-D predictions: reduce the targets' third axis
    if targets.ndim == 5 and outputs.ndim == 4:
        if outputs.shape[1] == targets.shape[1]:
            targets = targets[:, :, 0, :, :]
        else:
            # Take the "diagonal" entry targets[:, i, i], falling back to the
            # last index once i runs past the third axis.
            picked = []
            for i in range(min(targets.shape[1], outputs.shape[1])):
                j = i if i < targets.shape[2] else -1
                picked.append(targets[:, i, j])
            targets = torch.stack(picked, dim=1)

    # Truncate both to the shorter sequence length
    if outputs.ndim == 4 and targets.ndim == 4:
        common_len = min(outputs.shape[1], targets.shape[1])
        outputs = outputs[:, :common_len]
        targets = targets[:, :common_len]

    # Same rank but different sizes: crop every axis to the common minimum
    if outputs.shape != targets.shape and outputs.ndim == targets.ndim:
        crop = tuple(slice(0, min(o, t)) for o, t in zip(outputs.shape, targets.shape))
        outputs = outputs[crop]
        targets = targets[crop]

    # Last resort when the shapes still disagree
    if outputs.shape != targets.shape:
        print(f"⚠️ Advertencia: No se pudo ajustar dimensiones: outputs={outputs.shape}, targets={targets.shape}")
        if outputs.ndim > 2 and targets.ndim > 2:
            # Keep only the first element of the first two axes
            outputs = outputs[0:1, 0:1]
            targets = targets[0:1, 0:1]
            print(f"   Usando formas simplificadas: outputs={outputs.shape}, targets={targets.shape}")

    return outputs, targets


def _compute_regression_metrics(all_targets, all_preds):
    """Return MAE, RMSE, MAPE (%), Pearson r and R² over the flattened arrays.

    Positions where either array is NaN are ignored; MAPE additionally skips
    targets equal to zero and is NaN when no non-zero target remains.
    """
    flat_targets = all_targets.flatten()
    flat_preds = all_preds.flatten()

    assert flat_targets.shape == flat_preds.shape, f"Error: Las dimensiones siguen siendo diferentes: {flat_targets.shape} vs {flat_preds.shape}"

    # Discard NaNs on either side
    valid = ~np.isnan(flat_targets) & ~np.isnan(flat_preds)
    flat_targets = flat_targets[valid]
    flat_preds = flat_preds[valid]

    mae = mean_absolute_error(flat_targets, flat_preds)
    rmse = np.sqrt(mean_squared_error(flat_targets, flat_preds))
    r2 = r2_score(flat_targets, flat_preds)
    corr = np.corrcoef(flat_targets, flat_preds)[0, 1]

    # MAPE only over non-zero targets to avoid division by zero
    nonzero = flat_targets != 0
    if nonzero.any():
        mape = np.mean(np.abs((flat_targets[nonzero] - flat_preds[nonzero]) / flat_targets[nonzero])) * 100
    else:
        mape = np.nan

    return {
        'MAE': mae,
        'RMSE': rmse,
        'MAPE (%)': mape,
        'r': corr,
        'R²': r2
    }


# Train and compare models
def train_and_compare_models(X_train, y_train, X_val, y_val, force_retrain=False):
    """Train (or load) the four comparison models and evaluate them on validation data.

    For each of SimpleConvGRU, SimpleConvLSTM, ConvBiGRU-AE and ConvBiLSTM-AE a
    saved checkpoint under ``MODELS_OUTPUT / 'comparison_models'`` is reused
    when present, unless ``force_retrain`` is set; otherwise (or when loading
    fails) the model is trained from scratch and saved.

    Args:
        X_train: Training inputs.
        y_train: Training targets.
        X_val: Validation inputs.
        y_val: Validation targets.
        force_retrain: Retrain even if a saved model exists.

    Returns:
        dict, DataLoader: per-model configuration dictionaries, each extended
        with ``'model'``, ``'metrics'``, ``'train_losses'`` and
        ``'val_losses'``, plus the validation DataLoader.
    """
    print("\n" + "="*70)
    print("ENTRENAMIENTO Y COMPARACIÓN DE MODELOS SIMPLES Y HÍBRIDOS")
    print("="*70)

    # Directory where the models are saved
    models_dir = MODELS_OUTPUT / 'comparison_models'
    models_dir.mkdir(exist_ok=True, parents=True)

    # Common parameters
    # NOTE(review): prepare_train_val_data returns X as (N, window, channels, H, W),
    # so shape[1] is the window length rather than the channel count — confirm
    # against what PrecipitationDataset and the models expect.
    input_channels = X_train.shape[1] if len(X_train.shape) > 3 else 1
    hidden_dim = 128
    output_channels = 1
    seq_length = OUTPUT_HORIZON

    # Spatial shape of the targets, with a default when y is not at least 3-D
    if len(y_train.shape) >= 3:
        target_shape = y_train.shape[-2:]
    else:
        target_shape = (61, 65)

    print(f"Configuración común: input_channels={input_channels}, hidden_dim={hidden_dim}, output_channels={output_channels}")
    print(f"Secuencia: {seq_length}, target_shape: {target_shape}")

    common_params = {
        'input_channels': input_channels,
        'hidden_dim': hidden_dim,
        'output_channels': output_channels,
        'seq_length': seq_length,
        'target_shape': target_shape
    }

    def _entry(model_class, filename, **extra_params):
        # Fresh dict/list objects per model so histories are never shared
        return {
            'class': model_class,
            'path': models_dir / filename,
            'params': {**common_params, **extra_params},
            'metrics': {},
            'train_losses': [],
            'val_losses': []
        }

    # Models to train, with their save paths
    models = {
        'SimpleConvGRU': _entry(SimpleConvGRU, 'simple_convgru.pth'),
        'SimpleConvLSTM': _entry(SimpleConvLSTM, 'simple_convlstm.pth'),
        'ConvBiGRU-AE': _entry(ConvBiGRU_AE, 'convbigru_ae.pth', num_layers=3),
        'ConvBiLSTM-AE': _entry(ConvBiLSTM_AE, 'convbilstm_ae.pth', num_layers=3)
    }

    # Datasets and dataloaders
    train_dataset = PrecipitationDataset(X_train, y_train, seq_length)
    val_dataset = PrecipitationDataset(X_val, y_val, seq_length)

    batch_size = 16
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    for model_name, config in models.items():
        print(f"\n{'-'*50}")
        print(f"PROCESANDO MODELO: {model_name}")
        print(f"{'-'*50}")

        if config['path'].exists() and not force_retrain:
            print(f"Modelo encontrado en {config['path']}, cargando...")
            try:
                model = config['class'](**config['params']).to(DEVICE)
                # Named saved_state so it does not shadow the `checkpoint`
                # function imported from torch.utils.checkpoint.
                saved_state = torch.load(config['path'], map_location=DEVICE)

                if isinstance(saved_state, dict) and 'model_state_dict' in saved_state:
                    model.load_state_dict(saved_state['model_state_dict'])
                    config['train_losses'] = saved_state.get('train_losses', [])
                    config['val_losses'] = saved_state.get('val_losses', [])
                else:
                    # Old format: the file holds only the state_dict
                    model.load_state_dict(saved_state)

                print(f"✅ Modelo {model_name} cargado correctamente")
            except Exception as e:
                # Best effort: an unreadable or incompatible checkpoint falls
                # back to training from scratch instead of aborting the run.
                print(f"❌ Error al cargar modelo: {str(e)}")
                print("Entrenando el modelo desde cero...")
                model = _train_and_persist(model_name, config, train_loader, val_loader)
        else:
            print(f"Entrenando {model_name} desde cero...")
            model = _train_and_persist(
                model_name, config, train_loader, val_loader, generate_plots=True
            )

        # Evaluate the model on the validation set
        print(f"Evaluando {model_name} con métricas adicionales...")
        model.eval()

        all_targets = []
        all_preds = []

        with torch.no_grad():
            for inputs, targets in val_loader:
                outputs = model(inputs.to(DEVICE))
                outputs, targets = _align_outputs_and_targets(outputs, targets)

                # Collect as numpy for the global metrics
                all_preds.append(outputs.cpu().numpy())
                all_targets.append(targets.numpy())

        # Concatenate every batch
        try:
            all_targets = np.concatenate(all_targets, axis=0)
            all_preds = np.concatenate(all_preds, axis=0)
        except ValueError as e:
            print(f"❌ Error al concatenar: {str(e)}")
            print(f"Formas de datos: {[t.shape for t in all_targets]}")
            print(f"Formas de predicciones: {[p.shape for p in all_preds]}")
            # Use only the first batch to avoid concatenation errors
            all_targets = all_targets[0]
            all_preds = all_preds[0]

        print(f"Forma final: all_targets={all_targets.shape}, all_preds={all_preds.shape}")

        # Crop both arrays to their common leading shape before flattening
        if all_targets.shape != all_preds.shape:
            slices = tuple(
                slice(0, min(t_dim, p_dim))
                for t_dim, p_dim in zip(all_targets.shape, all_preds.shape)
            )
            all_targets = all_targets[slices]
            all_preds = all_preds[slices]
            print(f"⚠️ Ajustadas dimensiones a forma común: {all_targets.shape}")

        config['metrics'] = _compute_regression_metrics(all_targets, all_preds)

        # Keep the model itself in the dictionary
        config['model'] = model

        metrics = config['metrics']
        print(f"\nMétricas para {model_name}:")
        print(f"  MAE: {metrics['MAE']:.4f}")
        print(f"  RMSE: {metrics['RMSE']:.4f}")
        print(f"  MAPE: {metrics['MAPE (%)']:.2f}%")
        print(f"  r (correlación): {metrics['r']:.4f}")
        print(f"  R²: {metrics['R²']:.4f}")

    # Return every model and its results, plus val_loader for the visualisations
    return models, val_loader

# Función para generar visualizaciones comparativas
def visualize_model_comparisons(models_dict, val_loader=None):
    """
    Genera visualizaciones comparativas para todos los modelos
    
    Args:
        models_dict: Diccionario con modelos y sus métricas
        val_loader: DataLoader de validación para generar predicciones
    
    Returns:
        Path: Ruta al directorio con las visualizaciones
    """
    print("\n" + "="*70)
    print("GENERANDO VISUALIZACIONES COMPARATIVAS")
    print("="*70)
    
    # Crear directorio para visualizaciones
    vis_dir = MODELS_OUTPUT / 'visualization' / 'comparisons'
    vis_dir.mkdir(exist_ok=True, parents=True)
    
    # 1. Tabla comparativa de métricas
    model_names = list(models_dict.keys())
    metrics = ['MAE', 'RMSE', 'MAPE (%)', 'r', 'R²']
    
    # Crear DataFrame para tabla
    metrics_data = {metric: [] for metric in metrics}
    metrics_data['Modelo'] = model_names
    
    for model_name in model_names:
        for metric in metrics:
            value = models_dict[model_name]['metrics'].get(metric, np.nan)
            metrics_data[metric].append(value)
    
    metrics_df = pd.DataFrame(metrics_data)
    
    # Mostrar tabla
    print("\n📊 TABLA COMPARATIVA DE MÉTRICAS")
    print(metrics_df.to_string(index=False, float_format=lambda x: f"{x:.4f}"))
    
    # Guardar como CSV
    metrics_csv_path = vis_dir / 'metrics_comparison.csv'
    metrics_df.to_csv(metrics_csv_path, index=False, float_format='%.4f')
    print(f"\nTabla guardada en {metrics_csv_path}")
    
    # 2. Gráfico de barras comparativo de métricas
    plt.figure(figsize=(15, 8))
    
    # Iterar sobre las métricas
    for i, metric in enumerate(metrics):
        plt.subplot(2, 3, i+1)
        values = [models_dict[model]['metrics'].get(metric, np.nan) for model in model_names]
        
        # Crear gráfico de barras
        bars = plt.bar(model_names, values)
        
        # Añadir etiquetas de valor sobre cada barra
        for bar, value in zip(bars, values):
            height = bar.get_height()
            plt.text(bar.get_x() + bar.get_width()/2., height,
                    f'{value:.4f}',
                    ha='center', va='bottom', rotation=0, fontsize=9)
        
        plt.title(f'Comparación de {metric}')
        plt.xticks(rotation=45, ha='right')
        plt.grid(alpha=0.3, axis='y')
        
        # Para métricas donde menor es mejor, destacar el mejor modelo
        if metric in ['MAE', 'RMSE', 'MAPE (%)']:
            best_idx = np.nanargmin(values)
            bars[best_idx].set_color('green')
        else:  # Para métricas donde mayor es mejor
            best_idx = np.nanargmax(values)
            bars[best_idx].set_color('green')
    
    plt.tight_layout()
    metrics_plot_path = vis_dir / 'metrics_comparison.png'
    plt.savefig(metrics_plot_path, dpi=300, bbox_inches='tight')
    print(f"Gráfico de métricas guardado en {metrics_plot_path}")
    plt.show()
    
    # 3. Curvas de aprendizaje comparativas (mejorado con más información)
    plt.figure(figsize=(20, 16))
    
    # 3.1 Pérdidas de entrenamiento
    plt.subplot(2, 2, 1)
    for model_name in model_names:
        train_losses = models_dict[model_name].get('train_losses', [])
        if train_losses:
            plt.plot(train_losses, label=model_name)
    
    plt.title('Pérdidas de Entrenamiento')
    plt.xlabel('Época')
    plt.ylabel('Pérdida')
    plt.grid(alpha=0.3)
    plt.legend()
    
    # 3.2 Pérdidas de validación
    plt.subplot(2, 2, 2)
    for model_name in model_names:
        val_losses = models_dict[model_name].get('val_losses', [])
        if val_losses:
            plt.plot(val_losses, label=model_name)
    
    plt.title('Pérdidas de Validación')
    plt.xlabel('Época')
    plt.ylabel('Pérdida')
    plt.grid(alpha=0.3)
    plt.legend()
    
    # 3.3 Pérdidas de ambas en escala logarítmica
    plt.subplot(2, 2, 3)
    for model_name in model_names:
        train_losses = models_dict[model_name].get('train_losses', [])
        val_losses = models_dict[model_name].get('val_losses', [])
        if train_losses:
            plt.semilogy(train_losses, linestyle='-', label=f'{model_name} (Train)')
        if val_losses:
            plt.semilogy(val_losses, linestyle='--', label=f'{model_name} (Val)')
    
    plt.title('Pérdidas de Entrenamiento y Validación (Log)')
    plt.xlabel('Época')
    plt.ylabel('Pérdida (log)')
    plt.grid(alpha=0.3)
    plt.legend()
    
    # 3.4 Comparación de convergencia (normalizada)
    plt.subplot(2, 2, 4)
    for model_name in model_names:
        val_losses = models_dict[model_name].get('val_losses', [])
        if val_losses and len(val_losses) > 1:
            # Normalizar pérdidas para comparar velocidad de convergencia
            norm_losses = [(loss - min(val_losses)) / (max(val_losses) - min(val_losses) + 1e-10) 
                          for loss in val_losses]
            plt.plot(norm_losses, label=model_name)
    
    plt.title('Convergencia Normalizada (Validación)')
    plt.xlabel('Época')
    plt.ylabel('Pérdida Normalizada')
    plt.grid(alpha=0.3)
    plt.legend()
    
    plt.tight_layout()
    learning_plot_path = vis_dir / 'learning_curves.png'
    plt.savefig(learning_plot_path, dpi=300, bbox_inches='tight')
    print(f"Gráfico de curvas de aprendizaje guardado en {learning_plot_path}")
    plt.show()
    
    # 4. Si hay un DataLoader de validación, generar ejemplos de predicciones
    if val_loader is not None and next(iter(models_dict.values())).get('model') is not None:
        try:
            # Obtener algunas muestras para visualizar
            print("\nGenerando ejemplos de predicciones...")
            
            # Tomar un batch como ejemplo
            sample_inputs, sample_targets = next(iter(val_loader))
            
            # Configurar figura para mostrar predicciones
            fig = plt.figure(figsize=(20, 5 * len(model_names)))
            
            # Generar predicciones con cada modelo
            for i, model_name in enumerate(model_names):
                model = models_dict[model_name].get('model')
                if model is None:
                    continue
                    
                # Modo evaluación
                model.eval()
                
                # Generar predicciones
                with torch.no_grad():
                    outputs = model(sample_inputs.to(next(model.parameters()).device))
                
                # Seleccionar un ejemplo (primer item del batch)
                if len(outputs.shape) == 5:  # [batch, seq, channel, H, W]
                    pred = outputs[0, 0, 0].cpu().numpy()
                elif len(outputs.shape) == 4:  # [batch, seq, H, W]
                    pred = outputs[0, 0].cpu().numpy()
                else:
                    pred = outputs[0].cpu().numpy()
                
                # Obtener target correspondiente
                if len(sample_targets.shape) == 5:  # [batch, seq, channel, H, W]
                    target = sample_targets[0, 0, 0].numpy()
                elif len(sample_targets.shape) == 4:  # [batch, seq, H, W]
                    target = sample_targets[0, 0].numpy()
                else:
                    target = sample_targets[0].numpy()
                
                # Mostrar predicción vs real
                plt.subplot(len(model_names), 3, i*3+1)
                plt.imshow(target, cmap='viridis')
                plt.colorbar(fraction=0.046, pad=0.04)
                plt.title(f'{model_name} - Valor Real')
                
                plt.subplot(len(model_names), 3, i*3+2)
                plt.imshow(pred, cmap='viridis')
                plt.colorbar(fraction=0.046, pad=0.04)
                plt.title(f'{model_name} - Predicción')
                
                plt.subplot(len(model_names), 3, i*3+3)
                diff = target - pred
                plt.imshow(diff, cmap='RdBu_r')
                plt.colorbar(fraction=0.046, pad=0.04)
                plt.title(f'{model_name} - Diferencia')
            
            plt.tight_layout()
            pred_plot_path = vis_dir / 'prediction_examples.png'
            plt.savefig(pred_plot_path, dpi=300, bbox_inches='tight')
            print(f"Ejemplos de predicción guardados en {pred_plot_path}")
            plt.show()
            
        except Exception as e:
            print(f"Error al generar ejemplos de predicción: {str(e)}")
    
    return vis_dir

# Data verification followed by the end-to-end run (prepare data -> train -> visualise).
# Relies on names defined elsewhere in the file: ds_full, log_and_print,
# prepare_train_val_data, train_and_compare_models, visualize_model_comparisons,
# SimpleConvGRU and DEVICE.
try:
    log_and_print("Verificando datos cargados...")
    
    # The dataset must expose a precipitation variable under one of these two names
    if 'precipitacion' not in ds_full.data_vars and 'total_precipitation' not in ds_full.data_vars:
        raise ValueError(f"No se encontró variable de precipitación. Variables disponibles: {list(ds_full.data_vars.keys())}")
    
    # Require at least two dimensions in the dataset
    if len(ds_full.dims) < 2:
        raise ValueError(f"Dataset con dimensiones insuficientes: {ds_full.dims}")
    
    # Log the first five values of the first data variable as a quick sanity check.
    # NOTE(review): `.values` materialises the whole variable in memory just to
    # show five numbers — confirm this is acceptable for the dataset size.
    sample_var = next(iter(ds_full.data_vars))
    log_and_print(f"Ejemplo de datos - Variable '{sample_var}', primeros valores: {ds_full[sample_var].values.flatten()[:5]}")
    
    log_and_print("Los datos superaron las verificaciones iniciales.")

    # Data preparation and training, with their own exception handling so that a
    # failure here triggers the reduced-sample diagnostic below instead of the
    # outer handler.
    try:
        # If the data is too large, a sample can be used for testing.
        # Uncomment these lines to use a sample when memory is a problem:
        # use_sample = True
        # if use_sample:
        #    log_and_print("Usando muestra de datos para pruebas...")
        #    ds_full = ds_full.isel(time=slice(0, min(200, len(ds_full.time))))
        log_and_print("Preparando datos para entrenamiento...")
        X_train, y_train, X_val, y_val = prepare_train_val_data(ds_full)
        
        log_and_print(f"Datos preparados exitosamente. Formas: X_train={X_train.shape}, y_train={y_train.shape}")
        
        # Train and compare the models on the prepared data
        log_and_print("Iniciando entrenamiento de modelos...")
        comparison_models, val_loader = train_and_compare_models(X_train, y_train, X_val, y_val)
        
        # Generate the comparison plots
        log_and_print("Generando visualizaciones comparativas...")
        vis_dir = visualize_model_comparisons(comparison_models, val_loader)
        log_and_print(f"Visualizaciones guardadas en {vis_dir}")
        
    except Exception as e:
        # This import binds `traceback` at module scope, so the nested handler
        # further down can use it without importing again.
        import traceback
        log_and_print(f"ERROR en la preparación de datos o entrenamiento: {str(e)}")
        log_and_print(traceback.format_exc())
        
        # Fall back to a simplified diagnostic run on a small sample
        log_and_print("Intentando ejecución con una muestra reducida...")
        try:
            # Use at most the first 100 time steps for the diagnostic
            time_sample = min(100, len(ds_full.time))
            log_and_print(f"Usando muestra de {time_sample} puntos temporales")
            ds_sample = ds_full.isel(time=slice(0, time_sample))
            
            X_train, y_train, X_val, y_val = prepare_train_val_data(ds_sample, input_window=12, output_horizon=3)
            log_and_print("Muestra de datos preparada correctamente. Probando modelo básico...")
            
            # Derive the model configuration from the prepared arrays, with
            # hard-coded defaults when the arrays have too few dimensions
            input_channels = X_train.shape[1] if len(X_train.shape) > 3 else 1
            target_shape = y_train.shape[-2:] if len(y_train.shape) >= 3 else (61, 65)
            
            # Only instantiate a small model: this path checks that construction
            # succeeds, it does not train or run the model.
            test_model = SimpleConvGRU(
                input_channels=input_channels,
                hidden_dim=64,  # Reduced for testing
                output_channels=1,
                seq_length=3,  # Reduced for testing; matches output_horizon=3 above
                target_shape=target_shape
            ).to(DEVICE)
            
            log_and_print("Modelo creado correctamente. Verificación completa.")
            
        except Exception as nested_e:
            # Even the reduced run failed: log the error plus troubleshooting hints
            log_and_print(f"ERROR también en la ejecución simplificada: {str(nested_e)}")
            log_and_print(traceback.format_exc())
            log_and_print("Recomendaciones para solucionar:")
            log_and_print("1. Verificar la estructura del dataset y variables disponibles")
            log_and_print("2. Revisar espacio en disco y memoria disponible")
            log_and_print("3. Comprobar que todos los archivos necesarios están presentes")
    
except Exception as outer_e:
    # Failures in the verification checks themselves end up here; the error is
    # logged and execution of the rest of the file continues.
    import traceback
    log_and_print(f"ERROR en verificación de datos: {str(outer_e)}")
    log_and_print(traceback.format_exc())
    
# Mantenemos sólo las funciones de preprocesamiento que aún se necesitan
def calculate_afc(signal, lags=(1, 3, 6)):
    """Compute the raw (unnormalised) lagged autocorrelation of a 1-D signal.

    For every lag the signal is paired with a copy of itself shifted by
    ``lag`` samples and the overlapping parts are correlated with
    ``np.correlate(..., mode='valid')``. Because both operands have the same
    length, each result is a one-element array holding the lagged dot product.

    Args:
        signal: 1-D sequence or array of samples.
        lags: Iterable of integer lags. ``0`` is accepted and yields the
            signal's dot product with itself.

    Returns:
        list: One ``np.ndarray`` per lag, in the order given by ``lags``.

    Raises:
        ValueError: Propagated from ``np.correlate`` when a lag leaves no
            overlapping samples (lag >= len(signal)).
    """
    def _lagged_pair(lag):
        # `signal[:-0]` is an empty slice, so lag 0 must be handled explicitly.
        trailing = signal[:-lag] if lag != 0 else signal
        return signal[lag:], trailing

    return [
        np.correlate(*_lagged_pair(lag), mode='valid')
        for lag in lags
    ]

# Funciones de preprocesamiento adicionales
def wavelet_denoise(data, wavelet='db4', level=3):
    """Denoise ``data`` with wavelet shrinkage (soft thresholding, BayesShrink).

    Args:
        data: Array to denoise; treated as single-channel (``channel_axis=None``).
        wavelet: Name of the wavelet family passed to ``denoise_wavelet``.
        level: Number of wavelet decomposition levels. It is forwarded as
            ``wavelet_levels`` so the argument actually takes effect
            (previously it was accepted but never used).

    Returns:
        The denoised array as returned by ``denoise_wavelet``.

    Note:
        ``denoise_wavelet`` is expected to be imported elsewhere in the file
        (presumably from ``skimage.restoration`` — confirm against the imports).
    """
    try:
        # Modern scikit-image API (0.19+), which understands `channel_axis`
        return denoise_wavelet(
            data,
            wavelet=wavelet,
            mode='soft',
            method='BayesShrink',
            wavelet_levels=level,
            channel_axis=None
        )
    except Exception as e:
        print(f"Error en wavelet denoising: {str(e)}. Intentando método alternativo.")
        # Best-effort retry with only the basic parameters (e.g. for releases
        # that do not know `channel_axis`), still honouring the requested level.
        return denoise_wavelet(data, wavelet=wavelet, wavelet_levels=level)

# 3) Definición de Datasets PyTorch
class PrecipitationDataset(Dataset):
    """Sliding-window dataset over aligned input and target arrays.

    Item ``i`` is the pair ``(data[i:i+seq_length], target[i:i+seq_length])``.
    When the input window has more than three dimensions and its second
    dimension has size 1, that singleton dimension is squeezed out.

    Args:
        data: NumPy array of inputs, indexed by time along axis 0.
        target: NumPy array of targets, indexed by time along axis 0.
        seq_length: Number of consecutive time steps per item.
    """
    def __init__(self, data, target, seq_length):
        self.data = torch.from_numpy(data).float()
        self.target = torch.from_numpy(target).float()
        self.seq_length = seq_length
        # Per-step target shape (everything after the time axis);
        # presumably (height, width) — confirm against the caller.
        self.target_shape = target.shape[1:]
        
        print(f"Dataset inicializado - Forma de datos: {self.data.shape}")
        print(f"Dataset inicializado - Forma de targets: {self.target.shape}")
        print(f"Forma objetivo almacenada: {self.target_shape}")

    def __len__(self):
        # Only windows fully covered by BOTH arrays are valid, and the count
        # must never be negative (len() raises on negative values).
        usable_steps = min(len(self.data), len(self.target))
        return max(0, usable_steps - self.seq_length + 1)

    def __getitem__(self, idx):
        n_windows = len(self)
        position = idx + n_windows if idx < 0 else idx
        # Slicing never fails, so without this check out-of-range indices would
        # silently return truncated windows and plain iteration would never stop.
        if not 0 <= position < n_windows:
            raise IndexError(f"Índice {idx} fuera de rango para {n_windows} secuencias")

        # Extract the window for inputs and labels
        inputs = self.data[position:position + self.seq_length]
        labels = self.target[position:position + self.seq_length]
        
        # Debug output: shapes are printed only for the first item
        if position == 0:
            print(f"Ejemplo de entrada - forma original: {inputs.shape}")
            
        # Drop the redundant singleton dimension (single window/feature)
        if len(inputs.shape) > 3 and inputs.shape[1] == 1:
            inputs = inputs.squeeze(1)
            
        if position == 0:
            print(f"Ejemplo de entrada - forma final: {inputs.shape}")
            print(f"Ejemplo de etiqueta: {labels.shape}")
            
        return inputs, labels

# 4) Modelos Híbridos: ConvBiGRU-AE y ConvLSTM-AE
class ConvBiGRU_AE(nn.Module):
    """Convolutional encoder + bidirectional GRU + transposed-conv decoder.

    Each time step is encoded independently by two conv/InstanceNorm/ReLU
    stages and spatially averaged into a feature vector; the resulting
    sequence goes through a bidirectional GRU, and the GRU output at the last
    time step is broadcast over the spatial grid and decoded into
    ``seq_length * output_channels`` maps, which are finally resized to
    ``target_shape``.

    Args:
        input_channels: Accepted for signature compatibility; not used, since
            the first convolution is fixed to a single input channel.
        hidden_dim: Number of feature channels / GRU hidden size.
        num_layers: Number of stacked GRU layers.
        output_channels: Channels per predicted step.
        seq_length: Number of predicted steps in the output.
        target_shape: Default ``(height, width)`` of the output maps.
        kernel_size: Kernel size shared by all (transposed) convolutions.
        padding: Padding shared by all (transposed) convolutions.
    """
    def __init__(self, input_channels, hidden_dim, num_layers, output_channels, seq_length, target_shape=(61, 65), kernel_size=3, padding=1):
        super().__init__()
        self.seq_length = seq_length
        self.hidden_dim = hidden_dim
        self.output_channels = output_channels
        self.target_shape = target_shape

        conv_kwargs = dict(kernel_size=kernel_size, padding=padding)

        # Encoder: applied to one single-channel frame at a time
        self.conv1 = nn.Conv2d(1, hidden_dim, **conv_kwargs)
        self.norm1 = nn.InstanceNorm2d(hidden_dim)
        self.conv2 = nn.Conv2d(hidden_dim, hidden_dim, **conv_kwargs)
        self.norm2 = nn.InstanceNorm2d(hidden_dim)

        # Temporal model over the per-frame feature vectors
        self.gru = nn.GRU(hidden_dim, hidden_dim, num_layers=num_layers, bidirectional=True, batch_first=True)

        # Decoder: 2*hidden_dim comes from the two GRU directions
        self.deconv1 = nn.ConvTranspose2d(hidden_dim * 2, hidden_dim, **conv_kwargs)
        self.dnorm1 = nn.InstanceNorm2d(hidden_dim)
        self.deconv2 = nn.ConvTranspose2d(hidden_dim, output_channels * seq_length, **conv_kwargs)

    def forward(self, x, target_shape=None):
        """Predict ``seq_length`` maps from an input sequence.

        Args:
            x: Tensor of rank 3 ``[batch, H, W]``, rank 4 ``[batch, seq, H, W]``,
                rank 5 ``[batch, seq, channel, H, W]`` or rank 6
                ``[batch, seq, window, channel, H, W]``. For ranks 5 and 6 only
                the first channel (and first window) is used.
            target_shape: Optional ``(height, width)`` overriding the default.

        Returns:
            Tensor ``[batch, seq_length, output_channels, target_H, target_W]``.

        Raises:
            ValueError: If the rank of ``x`` is not one of 3, 4, 5 or 6.
        """
        out_size = self.target_shape if target_shape is None else target_shape
        n_batch = x.size(0)

        # Bring every supported layout to [batch, seq, H, W]
        rank = x.dim()
        if rank == 6:
            # Merging (channel, window) and taking entry 0 is the same as
            # selecting window 0 and channel 0 directly.
            x = x[:, :, 0, 0]
        elif rank == 5:
            x = x[:, :, 0]
        elif rank == 3:
            x = x.unsqueeze(1)
        elif rank != 4:
            raise ValueError(f"Forma de entrada no soportada: {x.shape}")
        _, n_steps, height, width = x.shape

        def encode_frame(frame):
            # [batch, H, W] -> [batch, hidden_dim] via conv stack + spatial mean
            feat = frame.unsqueeze(1)
            feat = F.relu(self.norm1(self.conv1(feat)))
            feat = F.relu(self.norm2(self.conv2(feat)))
            return feat.mean(dim=(2, 3))

        # [batch, seq, hidden_dim]
        sequence = torch.stack([encode_frame(x[:, step]) for step in range(n_steps)], dim=1)

        rnn_out, _ = self.gru(sequence)  # [batch, seq, 2*hidden_dim]

        # NOTE(review): at the last time step the reverse direction has only
        # processed the final element of the sequence — confirm this summary
        # is what the model is meant to decode from.
        summary = rnn_out[:, -1, :].view(n_batch, self.hidden_dim * 2, 1, 1)
        grid = F.interpolate(summary, size=(height, width), mode='nearest')

        # Decode to [batch, seq_length*output_channels, H, W]
        grid = F.relu(self.dnorm1(self.deconv1(grid)))
        grid = self.deconv2(grid)

        # Fold the horizon axis into the batch so the bilinear resize can run
        # on ordinary 4-D input, then unfold it again.
        per_step = grid.view(n_batch * self.seq_length, self.output_channels, height, width)
        resized = F.interpolate(per_step, size=out_size, mode='bilinear', align_corners=False)
        return resized.view(n_batch, self.seq_length, self.output_channels, out_size[0], out_size[1])

class ConvBiLSTM_AE(nn.Module):
    """Convolutional encoder + bidirectional LSTM + transposed-conv decoder.

    Each time step is encoded independently by two conv/InstanceNorm/ReLU
    stages and spatially averaged into a feature vector; the resulting
    sequence goes through a bidirectional LSTM, and the LSTM output at the
    last time step is broadcast over the spatial grid and decoded into
    ``seq_length * output_channels`` maps, which are finally resized to
    ``target_shape``.

    Args:
        input_channels: Accepted for signature compatibility; not used, since
            the first convolution is fixed to a single input channel.
        hidden_dim: Number of feature channels / LSTM hidden size.
        num_layers: Number of stacked LSTM layers.
        output_channels: Channels per predicted step.
        seq_length: Number of predicted steps in the output.
        target_shape: Default ``(height, width)`` of the output maps.
        kernel_size: Kernel size shared by all (transposed) convolutions.
        padding: Padding shared by all (transposed) convolutions.
    """
    def __init__(self, input_channels, hidden_dim, num_layers, output_channels, seq_length, target_shape=(61, 65), kernel_size=3, padding=1):
        super().__init__()
        self.seq_length = seq_length
        self.hidden_dim = hidden_dim
        self.output_channels = output_channels
        self.target_shape = target_shape

        conv_kwargs = dict(kernel_size=kernel_size, padding=padding)

        # Encoder: applied to one single-channel frame at a time
        self.conv1 = nn.Conv2d(1, hidden_dim, **conv_kwargs)
        self.norm1 = nn.InstanceNorm2d(hidden_dim)
        self.conv2 = nn.Conv2d(hidden_dim, hidden_dim, **conv_kwargs)
        self.norm2 = nn.InstanceNorm2d(hidden_dim)

        # Temporal model over the per-frame feature vectors
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers=num_layers, bidirectional=True, batch_first=True)

        # Decoder: 2*hidden_dim comes from the two LSTM directions
        self.deconv1 = nn.ConvTranspose2d(hidden_dim * 2, hidden_dim, **conv_kwargs)
        self.dnorm1 = nn.InstanceNorm2d(hidden_dim)
        self.deconv2 = nn.ConvTranspose2d(hidden_dim, output_channels * seq_length, **conv_kwargs)

    def forward(self, x, target_shape=None):
        """Predict ``seq_length`` maps from an input sequence.

        Args:
            x: Tensor of rank 3 ``[batch, H, W]``, rank 4 ``[batch, seq, H, W]``,
                rank 5 ``[batch, seq, channel, H, W]`` or rank 6
                ``[batch, seq, window, channel, H, W]``. For ranks 5 and 6 only
                the first channel (and first window) is used.
            target_shape: Optional ``(height, width)`` overriding the default.

        Returns:
            Tensor ``[batch, seq_length, output_channels, target_H, target_W]``.

        Raises:
            ValueError: If the rank of ``x`` is not one of 3, 4, 5 or 6.
        """
        out_size = self.target_shape if target_shape is None else target_shape
        n_batch = x.size(0)

        # Bring every supported layout to [batch, seq, H, W]
        rank = x.dim()
        if rank == 6:
            # Merging (channel, window) and taking entry 0 is the same as
            # selecting window 0 and channel 0 directly.
            x = x[:, :, 0, 0]
        elif rank == 5:
            x = x[:, :, 0]
        elif rank == 3:
            x = x.unsqueeze(1)
        elif rank != 4:
            raise ValueError(f"Forma de entrada no soportada: {x.shape}")
        _, n_steps, height, width = x.shape

        def encode_frame(frame):
            # [batch, H, W] -> [batch, hidden_dim] via conv stack + spatial mean
            feat = frame.unsqueeze(1)
            feat = F.relu(self.norm1(self.conv1(feat)))
            feat = F.relu(self.norm2(self.conv2(feat)))
            return feat.mean(dim=(2, 3))

        # [batch, seq, hidden_dim]
        sequence = torch.stack([encode_frame(x[:, step]) for step in range(n_steps)], dim=1)

        rnn_out, _ = self.lstm(sequence)  # [batch, seq, 2*hidden_dim]

        # NOTE(review): at the last time step the reverse direction has only
        # processed the final element of the sequence — confirm this summary
        # is what the model is meant to decode from.
        summary = rnn_out[:, -1, :].view(n_batch, self.hidden_dim * 2, 1, 1)
        grid = F.interpolate(summary, size=(height, width), mode='nearest')

        # Decode to [batch, seq_length*output_channels, H, W]
        grid = F.relu(self.dnorm1(self.deconv1(grid)))
        grid = self.deconv2(grid)

        # Fold the horizon axis into the batch so the bilinear resize can run
        # on ordinary 4-D input, then unfold it again.
        per_step = grid.view(n_batch * self.seq_length, self.output_channels, height, width)
        resized = F.interpolate(per_step, size=out_size, mode='bilinear', align_corners=False)
        return resized.view(n_batch, self.seq_length, self.output_channels, out_size[0], out_size[1])

# 4.5) Modelos ConvGRU y ConvLSTM simples para comparación
class SimpleConvGRU(nn.Module):
    """Simple ConvGRU baseline used for comparison with the hybrid models.

    Every time step goes through two conv/BatchNorm/ReLU stages, is flattened
    and fed to a two-layer GRU. The GRU output at the last step is broadcast
    to ``target_shape`` and refined by convolutions into
    ``seq_length * output_channels`` prediction maps.

    Args:
        input_channels: Channels of each frame given to the first convolution.
            For 6-D input this must equal ``window * channels``.
        hidden_dim: Base channel width; the GRU hidden size is ``4 * hidden_dim``.
        output_channels: Channels per predicted step.
        seq_length: Number of predicted steps in the output.
        target_shape: ``(height, width)`` of the output maps. It also sizes the
            GRU input (``hidden_dim * height * width``), so the input frames
            must have this same spatial size.
    """
    def __init__(self, input_channels, hidden_dim, output_channels, seq_length=12, target_shape=(61, 65)):
        super().__init__()
        self.input_channels = input_channels
        self.hidden_dim = hidden_dim
        self.output_channels = output_channels
        self.seq_length = seq_length
        self.target_shape = target_shape
        
        # Per-frame convolutional feature extractor
        self.conv1 = nn.Conv2d(input_channels, hidden_dim // 2, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(hidden_dim // 2)
        self.conv2 = nn.Conv2d(hidden_dim // 2, hidden_dim, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(hidden_dim)
        
        # Recurrent layer over the flattened frame features
        self.gru = nn.GRU(
            input_size=hidden_dim * target_shape[0] * target_shape[1], 
            hidden_size=hidden_dim * 4, 
            num_layers=2, 
            batch_first=True
        )
        
        # Output refinement layers
        self.upconv1 = nn.Conv2d(hidden_dim * 4, hidden_dim * 2, kernel_size=3, padding=1)
        self.upbn1 = nn.BatchNorm2d(hidden_dim * 2)
        self.upconv2 = nn.Conv2d(hidden_dim * 2, hidden_dim, kernel_size=3, padding=1)
        self.upbn2 = nn.BatchNorm2d(hidden_dim)
        
        # Final 1x1 convolution emitting every prediction horizon at once
        self.out_conv = nn.Conv2d(hidden_dim, output_channels * seq_length, kernel_size=1)

    @staticmethod
    def _as_frame_sequence(x):
        """Return ``x`` laid out as ``[batch, seq, channels, H, W]``."""
        rank = x.dim()
        if rank == 6:
            # [batch, seq, window, channel, H, W] -> merge (channel, window)
            batch, steps, window, channels, height, width = x.shape
            return x.permute(0, 1, 3, 2, 4, 5).reshape(batch, steps, channels * window, height, width)
        if rank == 5:
            return x
        if rank == 4:
            # No sequence axis: treat the input as a one-step sequence
            return x.unsqueeze(1)
        raise ValueError(f"Forma de entrada no soportada: {x.shape}")
        
    def forward(self, x, target_shape=None):
        """Predict ``seq_length`` maps.

        Args:
            x: Input tensor in one of these layouts:
               - [batch_size, seq_len, input_window, channels, height, width]
               - [batch_size, seq_len, channels, height, width]
               - [batch_size, channels, height, width]
            target_shape: Optional ``(height, width)`` of the output maps;
                defaults to the value given at construction.

        Returns:
            Tensor ``[batch, seq_length, output_channels, *target_shape]``.

        Raises:
            ValueError: If ``x`` has a rank other than 4, 5 or 6.
        """
        if target_shape is None:
            target_shape = self.target_shape

        frames = self._as_frame_sequence(x)
        batch_size = frames.size(0)

        # Frames are processed one step at a time on purpose: BatchNorm then
        # normalises (and updates its running statistics) per time step.
        step_features = []
        for t in range(frames.size(1)):
            feat = F.relu(self.bn1(self.conv1(frames[:, t])))
            feat = F.relu(self.bn2(self.conv2(feat)))
            step_features.append(feat.view(batch_size, -1))  # [batch, hidden_dim*H*W]

        # [batch, seq_len, features] -> GRU -> output at the last step
        gru_out, _ = self.gru(torch.stack(step_features, dim=1))
        last_out = gru_out[:, -1]  # [batch, hidden_dim*4]
        
        # Broadcast the summary vector over the output grid
        last_out = last_out.view(batch_size, self.hidden_dim * 4, 1, 1)
        last_out = F.interpolate(last_out, size=target_shape, mode='bilinear', align_corners=False)
        
        # Refinement layers
        out = F.relu(self.upbn1(self.upconv1(last_out)))
        out = F.relu(self.upbn2(self.upconv2(out)))
        out = self.out_conv(out)
        
        # Split the channel axis into [seq_length, output_channels]
        out = out.view(batch_size, self.seq_length, self.output_channels, *target_shape)
        
        return out
        
class SimpleConvLSTM(nn.Module):
    """Simple ConvLSTM baseline used for comparison with the hybrid models.

    Every time step goes through two conv/BatchNorm/ReLU stages, is flattened
    and fed to a two-layer LSTM. The LSTM output at the last step is broadcast
    to ``target_shape`` and refined by convolutions into
    ``seq_length * output_channels`` prediction maps.

    Args:
        input_channels: Channels of each frame given to the first convolution.
            For 6-D input this must equal ``window * channels``.
        hidden_dim: Base channel width; the LSTM hidden size is ``4 * hidden_dim``.
        output_channels: Channels per predicted step.
        seq_length: Number of predicted steps in the output.
        target_shape: ``(height, width)`` of the output maps. It also sizes the
            LSTM input (``hidden_dim * height * width``), so the input frames
            must have this same spatial size.
    """
    def __init__(self, input_channels, hidden_dim, output_channels, seq_length=12, target_shape=(61, 65)):
        super().__init__()
        self.input_channels = input_channels
        self.hidden_dim = hidden_dim
        self.output_channels = output_channels
        self.seq_length = seq_length
        self.target_shape = target_shape
        
        # Per-frame convolutional feature extractor
        self.conv1 = nn.Conv2d(input_channels, hidden_dim // 2, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(hidden_dim // 2)
        self.conv2 = nn.Conv2d(hidden_dim // 2, hidden_dim, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(hidden_dim)
        
        # Recurrent layer over the flattened frame features
        self.lstm = nn.LSTM(
            input_size=hidden_dim * target_shape[0] * target_shape[1], 
            hidden_size=hidden_dim * 4, 
            num_layers=2, 
            batch_first=True
        )
        
        # Output refinement layers
        self.upconv1 = nn.Conv2d(hidden_dim * 4, hidden_dim * 2, kernel_size=3, padding=1)
        self.upbn1 = nn.BatchNorm2d(hidden_dim * 2)
        self.upconv2 = nn.Conv2d(hidden_dim * 2, hidden_dim, kernel_size=3, padding=1)
        self.upbn2 = nn.BatchNorm2d(hidden_dim)
        
        # Final 1x1 convolution emitting every prediction horizon at once
        self.out_conv = nn.Conv2d(hidden_dim, output_channels * seq_length, kernel_size=1)

    @staticmethod
    def _as_frame_sequence(x):
        """Return ``x`` laid out as ``[batch, seq, channels, H, W]``."""
        rank = x.dim()
        if rank == 6:
            # [batch, seq, window, channel, H, W] -> merge (channel, window)
            batch, steps, window, channels, height, width = x.shape
            return x.permute(0, 1, 3, 2, 4, 5).reshape(batch, steps, channels * window, height, width)
        if rank == 5:
            return x
        if rank == 4:
            # No sequence axis: treat the input as a one-step sequence
            return x.unsqueeze(1)
        raise ValueError(f"Forma de entrada no soportada: {x.shape}")
        
    def forward(self, x, target_shape=None):
        """Predict ``seq_length`` maps.

        Args:
            x: Input tensor in one of these layouts:
               - [batch_size, seq_len, input_window, channels, height, width]
               - [batch_size, seq_len, channels, height, width]
               - [batch_size, channels, height, width]
            target_shape: Optional ``(height, width)`` of the output maps;
                defaults to the value given at construction.

        Returns:
            Tensor ``[batch, seq_length, output_channels, *target_shape]``.

        Raises:
            ValueError: If ``x`` has a rank other than 4, 5 or 6.
        """
        if target_shape is None:
            target_shape = self.target_shape

        frames = self._as_frame_sequence(x)
        batch_size = frames.size(0)

        # Frames are processed one step at a time on purpose: BatchNorm then
        # normalises (and updates its running statistics) per time step.
        step_features = []
        for t in range(frames.size(1)):
            feat = F.relu(self.bn1(self.conv1(frames[:, t])))
            feat = F.relu(self.bn2(self.conv2(feat)))
            step_features.append(feat.view(batch_size, -1))  # [batch, hidden_dim*H*W]

        # [batch, seq_len, features] -> LSTM -> output at the last step
        lstm_out, _ = self.lstm(torch.stack(step_features, dim=1))
        last_out = lstm_out[:, -1]  # [batch, hidden_dim*4]
        
        # Broadcast the summary vector over the output grid
        last_out = last_out.view(batch_size, self.hidden_dim * 4, 1, 1)
        last_out = F.interpolate(last_out, size=target_shape, mode='bilinear', align_corners=False)
        
        # Refinement layers
        out = F.relu(self.upbn1(self.upconv1(last_out)))
        out = F.relu(self.upbn2(self.upconv2(out)))
        out = self.out_conv(out)
        
        # Split the channel axis into [seq_length, output_channels]
        out = out.view(batch_size, self.seq_length, self.output_channels, *target_shape)
        
        return out

def _fit_and_checkpoint(model_name, config, train_loader, val_loader):
    """Train a fresh instance of ``config['class']`` and save it to ``config['path']``.

    This is the single training path shared by "no checkpoint found" and
    "checkpoint could not be loaded". Previously the two call sites unpacked a
    different number of values from ``train_hybrid_model`` (3 vs. 4), so one of
    them could never succeed.

    Args:
        model_name: Display name forwarded to ``train_hybrid_model``.
        config: Model entry with ``'class'``, ``'params'`` and ``'path'`` keys;
            its ``'train_losses'`` / ``'val_losses'`` entries are overwritten.
        train_loader: DataLoader used for training.
        val_loader: DataLoader used for validation during training.

    Returns:
        The trained model as returned by ``train_hybrid_model``.
    """
    model = config['class'](**config['params']).to(DEVICE)

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10, verbose=True)

    # Star-unpacking accepts both (model, train, val) and
    # (model, train, val, learning_rates) return tuples.
    model, train_losses, val_losses, *extra = train_hybrid_model(
        name=model_name,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        epochs=100,
        patience=20,
        optimizer=optimizer,
        criterion=criterion,
        scheduler=scheduler,
        generate_plots=True
    )
    learning_rates = extra[0] if extra else []

    # Keep the loss history on the config so it can be plotted later
    config['train_losses'] = train_losses
    config['val_losses'] = val_losses

    torch.save(
        {
            'model_state_dict': model.state_dict(),
            'train_losses': train_losses,
            'val_losses': val_losses,
            'learning_rates': learning_rates
        },
        config['path']
    )
    return model


def _align_outputs_and_targets(outputs, targets):
    """Bring a batch of predictions and targets to a common shape.

    The steps are applied in order and each one only fires when its shape
    condition holds:

    1. squeeze a singleton axis 2 from 5-D outputs;
    2. reduce 5-D targets to 4-D when outputs are 4-D;
    3. truncate axis 1 of both tensors to the shorter length;
    4. crop every axis to the element-wise minimum when ranks match;
    5. as a last resort keep only ``[0:1, 0:1]`` of each tensor.

    Args:
        outputs: Model predictions (torch tensor).
        targets: Ground-truth batch (torch tensor).

    Returns:
        tuple: ``(outputs, targets)`` after alignment. The shapes may still
        differ if step 5 could not reconcile them.
    """
    # 1) Drop the singleton axis 2 from 5-D outputs
    if outputs.ndim == 5 and outputs.shape[2] == 1:
        outputs = outputs.squeeze(2)

    # 2) 5-D targets against 4-D outputs
    if targets.ndim == 5 and outputs.ndim == 4:
        if outputs.shape[1] == targets.shape[1]:
            targets = targets[:, :, 0, :, :]
        else:
            # Take the "diagonal" (step i, index i); fall back to the last
            # index once i runs past axis 2.
            steps = min(targets.shape[1], outputs.shape[1])
            targets = torch.stack(
                [
                    targets[:, i, i] if i < targets.shape[2] else targets[:, i, -1]
                    for i in range(steps)
                ],
                dim=1
            )

    # 3) Truncate the longer sequence (slicing the shorter one is a no-op)
    if outputs.ndim == 4 and targets.ndim == 4:
        common_len = min(outputs.shape[1], targets.shape[1])
        outputs = outputs[:, :common_len]
        targets = targets[:, :common_len]

    # 4) Same rank but different sizes: crop to the common hyper-rectangle
    if outputs.shape != targets.shape and outputs.ndim == targets.ndim:
        window = tuple(slice(0, min(o, t)) for o, t in zip(outputs.shape, targets.shape))
        outputs = outputs[window]
        targets = targets[window]

    # 5) Still incompatible: warn and keep only the first element
    if outputs.shape != targets.shape:
        print(f"⚠️ Advertencia: No se pudo ajustar dimensiones: outputs={outputs.shape}, targets={targets.shape}")
        if outputs.ndim > 2 and targets.ndim > 2:
            outputs = outputs[0:1, 0:1]
            targets = targets[0:1, 0:1]
            print(f"   Usando formas simplificadas: outputs={outputs.shape}, targets={targets.shape}")

    return outputs, targets


def _regression_metrics(flat_targets, flat_preds):
    """Compute MAE, RMSE, MAPE, Pearson r and R² on two flat, equally sized arrays.

    Positions where either array is NaN are discarded first. MAPE is computed
    only over non-zero targets and is NaN when there are none.

    Args:
        flat_targets: 1-D numpy array of ground-truth values.
        flat_preds: 1-D numpy array of predictions, same length.

    Returns:
        dict: Metrics keyed by ``'MAE'``, ``'RMSE'``, ``'MAPE (%)'``, ``'r'``, ``'R²'``.
    """
    valid = ~np.isnan(flat_targets) & ~np.isnan(flat_preds)
    flat_targets = flat_targets[valid]
    flat_preds = flat_preds[valid]

    mae = mean_absolute_error(flat_targets, flat_preds)
    rmse = np.sqrt(mean_squared_error(flat_targets, flat_preds))
    r2 = r2_score(flat_targets, flat_preds)
    corr = np.corrcoef(flat_targets, flat_preds)[0, 1]

    # MAPE is undefined where the target is zero
    nonzero = flat_targets != 0
    if nonzero.any():
        mape = np.mean(
            np.abs((flat_targets[nonzero] - flat_preds[nonzero]) / flat_targets[nonzero])
        ) * 100
    else:
        mape = np.nan

    return {
        'MAE': mae,
        'RMSE': rmse,
        'MAPE (%)': mape,
        'r': corr,
        'R²': r2
    }


# Train, evaluate and compare every available model
def train_and_compare_models(X_train, y_train, X_val, y_val, force_retrain=False):
    """
    Train (or load) and evaluate the simple and hybrid models on the same data.

    For each model the function loads a checkpoint from
    ``MODELS_OUTPUT / 'comparison_models'`` when one exists and
    ``force_retrain`` is False; otherwise, or when loading fails, the model is
    trained from scratch and saved. Every model is then evaluated on the
    validation loader.

    Args:
        X_train: Training inputs
        y_train: Training targets
        X_val: Validation inputs
        y_val: Validation targets
        force_retrain: Retrain even if a saved checkpoint exists

    Returns:
        dict, DataLoader: Per-model dictionary (config, loss history,
        ``'metrics'`` and the trained ``'model'``) and the validation DataLoader
    """
    print("\n" + "="*70)
    print("ENTRENAMIENTO Y COMPARACIÓN DE MODELOS SIMPLES Y HÍBRIDOS")
    print("="*70)

    # Directory where checkpoints are stored
    models_dir = MODELS_OUTPUT / 'comparison_models'
    models_dir.mkdir(exist_ok=True, parents=True)

    # Shared hyper-parameters
    # NOTE(review): this treats axis 1 of X_train as the channel axis; windows
    # built by prepare_train_val_data put the time window on axis 1 — confirm
    # against PrecipitationDataset and the model classes.
    input_channels = X_train.shape[1] if len(X_train.shape) > 3 else 1
    hidden_dim = 128
    output_channels = 1
    seq_length = OUTPUT_HORIZON

    # Spatial size taken from the targets when they carry one
    if len(y_train.shape) >= 3:
        target_shape = y_train.shape[-2:]
    else:
        target_shape = (61, 65)  # default values

    print(f"Configuración común: input_channels={input_channels}, hidden_dim={hidden_dim}, output_channels={output_channels}")
    print(f"Secuencia: {seq_length}, target_shape: {target_shape}")

    def _entry(model_class, filename, **extra_params):
        # One registry entry; extra_params carries model-specific kwargs.
        params = {
            'input_channels': input_channels,
            'hidden_dim': hidden_dim,
            'output_channels': output_channels,
            'seq_length': seq_length,
            'target_shape': target_shape
        }
        params.update(extra_params)
        return {
            'class': model_class,
            'path': models_dir / filename,
            'params': params,
            'metrics': {},
            'train_losses': [],
            'val_losses': []
        }

    # Models to process, with their checkpoint paths
    models = {
        'SimpleConvGRU': _entry(SimpleConvGRU, 'simple_convgru.pth'),
        'SimpleConvLSTM': _entry(SimpleConvLSTM, 'simple_convlstm.pth'),
        'ConvBiGRU-AE': _entry(ConvBiGRU_AE, 'convbigru_ae.pth', num_layers=3),
        'ConvBiLSTM-AE': _entry(ConvBiLSTM_AE, 'convbilstm_ae.pth', num_layers=3)
    }

    # Datasets and dataloaders
    train_dataset = PrecipitationDataset(X_train, y_train, seq_length)
    val_dataset = PrecipitationDataset(X_val, y_val, seq_length)

    batch_size = 16
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    for model_name, config in models.items():
        print(f"\n{'-'*50}")
        print(f"PROCESANDO MODELO: {model_name}")
        print(f"{'-'*50}")

        if config['path'].exists() and not force_retrain:
            print(f"Modelo encontrado en {config['path']}, cargando...")
            try:
                model = config['class'](**config['params']).to(DEVICE)
                checkpoint = torch.load(config['path'], map_location=DEVICE)

                if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                    model.load_state_dict(checkpoint['model_state_dict'])
                    config['train_losses'] = checkpoint.get('train_losses', [])
                    config['val_losses'] = checkpoint.get('val_losses', [])
                else:
                    # Legacy format: the file holds only the state_dict
                    model.load_state_dict(checkpoint)

                print(f"✅ Modelo {model_name} cargado correctamente")
            except Exception as e:
                # Deliberate best-effort: any load failure falls back to
                # training from scratch.
                print(f"❌ Error al cargar modelo: {str(e)}")
                print("Entrenando el modelo desde cero...")
                model = _fit_and_checkpoint(model_name, config, train_loader, val_loader)
        else:
            print(f"Entrenando {model_name} desde cero...")
            model = _fit_and_checkpoint(model_name, config, train_loader, val_loader)

        # Evaluate on the validation set
        print(f"Evaluando {model_name} con métricas adicionales...")
        model.eval()

        all_targets = []
        all_preds = []

        with torch.no_grad():
            for inputs, targets in val_loader:
                outputs = model(inputs.to(DEVICE))
                outputs, targets = _align_outputs_and_targets(outputs, targets)

                # Store as numpy for the global metrics
                all_preds.append(outputs.cpu().numpy())
                all_targets.append(targets.cpu().numpy())

        # Concatenate all batches
        try:
            all_targets = np.concatenate(all_targets, axis=0)
            all_preds = np.concatenate(all_preds, axis=0)
        except ValueError as e:
            print(f"❌ Error al concatenar: {str(e)}")
            print(f"Formas de datos: {[t.shape for t in all_targets]}")
            print(f"Formas de predicciones: {[p.shape for p in all_preds]}")
            # Fall back to the first batch only
            all_targets = all_targets[0]
            all_preds = all_preds[0]

        print(f"Forma final: all_targets={all_targets.shape}, all_preds={all_preds.shape}")

        # Crop both arrays to their common leading shape before flattening
        if all_targets.shape != all_preds.shape:
            common_shape = [min(t, p) for t, p in zip(all_targets.shape, all_preds.shape)]
            slices = tuple(slice(0, dim) for dim in common_shape)
            all_targets = all_targets[slices]
            all_preds = all_preds[slices]
            print(f"⚠️ Ajustadas dimensiones a forma común: {all_targets.shape}")

        flat_targets = all_targets.flatten()
        flat_preds = all_preds.flatten()

        # Explicit check instead of `assert`, which is stripped under -O
        if flat_targets.shape != flat_preds.shape:
            raise ValueError(
                f"Error: Las dimensiones siguen siendo diferentes: {flat_targets.shape} vs {flat_preds.shape}"
            )

        config['metrics'] = _regression_metrics(flat_targets, flat_preds)

        # Keep the model so callers can generate predictions later
        config['model'] = model

        metrics = config['metrics']
        print(f"\nMétricas para {model_name}:")
        print(f"  MAE: {metrics['MAE']:.4f}")
        print(f"  RMSE: {metrics['RMSE']:.4f}")
        print(f"  MAPE: {metrics['MAPE (%)']:.2f}%")
        print(f"  r (correlación): {metrics['r']:.4f}")
        print(f"  R²: {metrics['R²']:.4f}")

    # Return the models/results and the val_loader for the visualizations
    return models, val_loader

# Build the comparative visualizations
def visualize_model_comparisons(models_dict, val_loader=None):
    """
    Generate comparative tables and plots for all models.

    Writes ``metrics_comparison.csv``, ``metrics_comparison.png`` and
    ``learning_curves.png`` under ``MODELS_OUTPUT / 'visualization' /
    'comparisons'``. When ``val_loader`` is given and at least one entry holds
    a ``'model'``, it also writes ``prediction_examples.png``.

    Args:
        models_dict: Mapping of model name to an entry with a ``'metrics'``
            dict and optional ``'train_losses'``, ``'val_losses'`` and ``'model'``
        val_loader: Validation DataLoader used to generate example predictions

    Returns:
        Path: Directory containing the visualizations
    """
    print("\n" + "="*70)
    print("GENERANDO VISUALIZACIONES COMPARATIVAS")
    print("="*70)

    # Output directory
    vis_dir = MODELS_OUTPUT / 'visualization' / 'comparisons'
    vis_dir.mkdir(exist_ok=True, parents=True)

    # 1. Metrics comparison table
    model_names = list(models_dict.keys())
    metrics = ['MAE', 'RMSE', 'MAPE (%)', 'r', 'R²']

    metrics_data = {
        metric: [models_dict[name]['metrics'].get(metric, np.nan) for name in model_names]
        for metric in metrics
    }
    metrics_data['Modelo'] = model_names

    metrics_df = pd.DataFrame(metrics_data)

    print("\n📊 TABLA COMPARATIVA DE MÉTRICAS")
    print(metrics_df.to_string(index=False, float_format=lambda x: f"{x:.4f}"))

    metrics_csv_path = vis_dir / 'metrics_comparison.csv'
    metrics_df.to_csv(metrics_csv_path, index=False, float_format='%.4f')
    print(f"\nTabla guardada en {metrics_csv_path}")

    # 2. Bar chart per metric
    plt.figure(figsize=(15, 8))

    for i, metric in enumerate(metrics):
        plt.subplot(2, 3, i+1)
        values = metrics_data[metric]

        bars = plt.bar(model_names, values)

        # Value label above each bar
        for bar, value in zip(bars, values):
            height = bar.get_height()
            plt.text(bar.get_x() + bar.get_width()/2., height,
                    f'{value:.4f}',
                    ha='center', va='bottom', rotation=0, fontsize=9)

        plt.title(f'Comparación de {metric}')
        plt.xticks(rotation=45, ha='right')
        plt.grid(alpha=0.3, axis='y')

        # Highlight the best model. nanargmin/nanargmax raise ValueError on an
        # empty or all-NaN input, so skip the highlight in that case.
        numeric_values = np.asarray(values, dtype=float)
        if numeric_values.size and not np.all(np.isnan(numeric_values)):
            if metric in ['MAE', 'RMSE', 'MAPE (%)']:
                best_idx = np.nanargmin(numeric_values)  # lower is better
            else:
                best_idx = np.nanargmax(numeric_values)  # higher is better
            bars[best_idx].set_color('green')

    plt.tight_layout()
    metrics_plot_path = vis_dir / 'metrics_comparison.png'
    plt.savefig(metrics_plot_path, dpi=300, bbox_inches='tight')
    print(f"Gráfico de métricas guardado en {metrics_plot_path}")
    plt.show()

    # 3. Learning curves
    plt.figure(figsize=(20, 16))

    # 3.1 Training losses
    plt.subplot(2, 2, 1)
    for model_name in model_names:
        train_losses = models_dict[model_name].get('train_losses', [])
        if train_losses:
            plt.plot(train_losses, label=model_name)

    plt.title('Pérdidas de Entrenamiento')
    plt.xlabel('Época')
    plt.ylabel('Pérdida')
    plt.grid(alpha=0.3)
    plt.legend()

    # 3.2 Validation losses
    plt.subplot(2, 2, 2)
    for model_name in model_names:
        val_losses = models_dict[model_name].get('val_losses', [])
        if val_losses:
            plt.plot(val_losses, label=model_name)

    plt.title('Pérdidas de Validación')
    plt.xlabel('Época')
    plt.ylabel('Pérdida')
    plt.grid(alpha=0.3)
    plt.legend()

    # 3.3 Both losses on a logarithmic scale
    plt.subplot(2, 2, 3)
    for model_name in model_names:
        train_losses = models_dict[model_name].get('train_losses', [])
        val_losses = models_dict[model_name].get('val_losses', [])
        if train_losses:
            plt.semilogy(train_losses, linestyle='-', label=f'{model_name} (Train)')
        if val_losses:
            plt.semilogy(val_losses, linestyle='--', label=f'{model_name} (Val)')

    plt.title('Pérdidas de Entrenamiento y Validación (Log)')
    plt.xlabel('Época')
    plt.ylabel('Pérdida (log)')
    plt.grid(alpha=0.3)
    plt.legend()

    # 3.4 Min-max normalized validation loss, to compare convergence speed
    plt.subplot(2, 2, 4)
    for model_name in model_names:
        val_losses = models_dict[model_name].get('val_losses', [])
        if val_losses and len(val_losses) > 1:
            lowest = min(val_losses)
            spread = max(val_losses) - lowest + 1e-10  # epsilon avoids 0/0 on flat curves
            norm_losses = [(loss - lowest) / spread for loss in val_losses]
            plt.plot(norm_losses, label=model_name)

    plt.title('Convergencia Normalizada (Validación)')
    plt.xlabel('Época')
    plt.ylabel('Pérdida Normalizada')
    plt.grid(alpha=0.3)
    plt.legend()

    plt.tight_layout()
    learning_plot_path = vis_dir / 'learning_curves.png'
    plt.savefig(learning_plot_path, dpi=300, bbox_inches='tight')
    print(f"Gráfico de curvas de aprendizaje guardado en {learning_plot_path}")
    plt.show()

    # 4. Example predictions, when a loader and at least one model are available.
    # any() is safe on an empty dict and does not depend on which entry is first.
    has_model = any(entry.get('model') is not None for entry in models_dict.values())
    if val_loader is not None and has_model:
        try:
            print("\nGenerando ejemplos de predicciones...")

            # Use one batch as the example
            sample_inputs, sample_targets = next(iter(val_loader))

            # One row of three panels per model
            plt.figure(figsize=(20, 5 * len(model_names)))

            for i, model_name in enumerate(model_names):
                model = models_dict[model_name].get('model')
                if model is None:
                    continue

                model.eval()

                # Run the batch on whatever device the model lives on
                with torch.no_grad():
                    outputs = model(sample_inputs.to(next(model.parameters()).device))

                # First item of the batch (and first step/channel when present)
                if len(outputs.shape) == 5:  # [batch, seq, channel, H, W]
                    pred = outputs[0, 0, 0].cpu().numpy()
                elif len(outputs.shape) == 4:  # [batch, seq, H, W]
                    pred = outputs[0, 0].cpu().numpy()
                else:
                    pred = outputs[0].cpu().numpy()

                # Matching target
                if len(sample_targets.shape) == 5:  # [batch, seq, channel, H, W]
                    target = sample_targets[0, 0, 0].numpy()
                elif len(sample_targets.shape) == 4:  # [batch, seq, H, W]
                    target = sample_targets[0, 0].numpy()
                else:
                    target = sample_targets[0].numpy()

                # Ground truth / prediction / difference
                plt.subplot(len(model_names), 3, i*3+1)
                plt.imshow(target, cmap='viridis')
                plt.colorbar(fraction=0.046, pad=0.04)
                plt.title(f'{model_name} - Valor Real')

                plt.subplot(len(model_names), 3, i*3+2)
                plt.imshow(pred, cmap='viridis')
                plt.colorbar(fraction=0.046, pad=0.04)
                plt.title(f'{model_name} - Predicción')

                plt.subplot(len(model_names), 3, i*3+3)
                diff = target - pred
                plt.imshow(diff, cmap='RdBu_r')
                plt.colorbar(fraction=0.046, pad=0.04)
                plt.title(f'{model_name} - Diferencia')

            plt.tight_layout()
            pred_plot_path = vis_dir / 'prediction_examples.png'
            plt.savefig(pred_plot_path, dpi=300, bbox_inches='tight')
            print(f"Ejemplos de predicción guardados en {pred_plot_path}")
            plt.show()

        except Exception as e:
            # Best-effort: a failed example plot must not lose the other outputs
            print(f"Error al generar ejemplos de predicción: {str(e)}")

    return vis_dir


# Main-execution helpers (the functions above are already defined).
# Prepare the data for training and evaluation.
def prepare_train_val_data(ds_full, input_window=INPUT_WINDOW, output_horizon=OUTPUT_HORIZON):
    """
    Build sliding-window training and validation arrays from the full dataset.

    The target is the ``precipitacion`` variable, or ``total_precipitation``
    when the former is absent. Inputs are the ``cluster`` and ``elevation``
    variables stacked on a channel axis (2-D variables are repeated along
    time); when neither exists, precipitation itself is used as a single
    input channel. Windows are split 80/20 in order, without shuffling.

    Args:
        ds_full: Dataset exposing ``data_vars`` and attribute access to its variables
        input_window: Number of time steps in each input window
        output_horizon: Number of time steps in each target window

    Returns:
        X_train, y_train, X_val, y_val: Training and validation arrays

    Raises:
        ValueError: If no precipitation variable is found, if it is not 3-D,
            or if the series is too short to build a single window.
    """
    print("Preparando datos para entrenamiento y validación...")

    # Target variable: first candidate name present in the dataset wins
    for target_name in ('precipitacion', 'total_precipitation'):
        if target_name in ds_full.data_vars:
            precip = getattr(ds_full, target_name).values
            print(f"Using '{target_name}' as target variable")
            break
    else:
        raise ValueError("Could not find precipitation variable in dataset. Available variables: " +
                        str(list(ds_full.data_vars.keys())))

    # The target must be (time, lat, lon)
    if precip.ndim != 3:
        raise ValueError(f"Expected precipitation data to have 3 dimensions (time, lat, lon), got {precip.shape}")

    n_times, height, width = precip.shape
    print(f"Precipitation data shape: {precip.shape}")

    # Optional input features; 2-D (lat, lon) fields are repeated over time
    features = []
    for feature_name in ('cluster', 'elevation'):
        if feature_name not in ds_full.data_vars:
            continue
        layer = getattr(ds_full, feature_name).values
        if layer.ndim == 2:
            layer = np.repeat(layer[np.newaxis, :, :], n_times, axis=0)
        features.append(layer)
        print(f"Añadida variable '{feature_name}', shape:", layer.shape)

    if features:
        # Stack the features along a new channel axis
        features_array = np.stack(features, axis=1)
        print(f"Shape de features combinadas: {features_array.shape}")
    else:
        # No explicit features: use precipitation as a single channel
        print("No se encontraron features específicas, usando precipitación histórica")
        features_array = precip.reshape(n_times, 1, height, width)
        print(f"Shape de features: {features_array.shape}")

    # Sliding windows: inputs cover [s, s + input_window), targets the
    # following output_horizon steps. A non-positive count gives no windows.
    window_starts = range(n_times - input_window - output_horizon + 1)
    X = [features_array[s:s + input_window].copy() for s in window_starts]
    y = [
        precip[s + input_window:s + input_window + output_horizon].copy()
        for s in window_starts
    ]

    if not X or not y:
        raise ValueError("No se pudieron crear ventanas válidas. Verifique los datos de entrada.")

    # np.stack (rather than np.array) guarantees homogeneous arrays
    X = np.stack(X)
    y = np.stack(y)

    print(f"Datos preparados - X shape: {X.shape}, y shape: {y.shape}")

    # Ordered 80/20 train/validation split
    split_idx = int(0.8 * len(X))
    X_train, y_train = X[:split_idx], y[:split_idx]
    X_val, y_val = X[split_idx:], y[split_idx:]

    print(f"Train shapes - X: {X_train.shape}, y: {y_train.shape}")
    print(f"Val shapes - X: {X_val.shape}, y: {y_val.shape}")

    return X_train, y_train, X_val, y_val

# Función para entrenar modelos híbridos
def train_hybrid_model(name, model, train_loader, val_loader, epochs=100, patience=20,
                       optimizer=None, criterion=None, scheduler=None, generate_plots=True):
    """
    Train a model with early stopping and return it with its training history.

    Each epoch runs one pass over ``train_loader`` (with gradient updates) and
    one pass over ``val_loader`` (under ``torch.no_grad()``). Before the loss
    is computed, model outputs and targets are reconciled to a common shape by
    ``_align_outputs_and_targets``. The weights with the lowest validation
    loss are restored before returning.

    Args:
        name: Label used in log messages and passed to the plotting helper.
        model: ``torch.nn.Module`` to train; batches are moved to the device
            of its first parameter (CPU if it has no parameters).
        train_loader: Iterable yielding ``(inputs, targets)`` tensor pairs.
        val_loader: Iterable yielding ``(inputs, targets)`` tensor pairs.
        epochs: Maximum number of epochs.
        patience: Number of consecutive epochs without validation improvement
            after which training stops early.
        optimizer: Optimizer to use; defaults to Adam with ``lr=1e-3``.
        criterion: Loss callable; defaults to ``nn.MSELoss()``.
        scheduler: LR scheduler; defaults to ``ReduceLROnPlateau`` with
            ``factor=0.5`` and ``patience=10``. A ``ReduceLROnPlateau``
            instance is stepped with the validation loss, any other scheduler
            is stepped without arguments.
        generate_plots: If True, call ``visualize_training_metrics`` at the end.

    Returns:
        Tuple ``(model, train_losses, val_losses, learning_rates)`` where the
        three lists hold one entry per completed epoch.

    Raises:
        ValueError: If ``train_loader`` or ``val_loader`` yields no batches.
    """
    print(f"\n{'='*30}")
    print(f"TRAINING {name}")
    print(f"{'='*30}")

    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    # Import tqdm if available
    try:
        from tqdm import tqdm
    except ImportError:
        # Minimal stand-in: print the description once and iterate normally
        def tqdm(iterable, **kwargs):
            print(kwargs.get('desc', ''))
            return iterable

    # Use default optimizer and loss function if not provided
    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        print("Using default Adam optimizer with lr=1e-3")

    if criterion is None:
        criterion = nn.MSELoss()
        print("Using default MSELoss criterion")

    using_default_scheduler = scheduler is None
    if using_default_scheduler:
        # The `verbose` flag is deprecated in recent PyTorch releases, so LR
        # reductions are reported explicitly after each scheduler step below.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=10
        )
        print("Using default ReduceLROnPlateau scheduler")

    # Only ReduceLROnPlateau takes the monitored metric in step(); passing a
    # loss to other schedulers would be interpreted as an epoch index.
    steps_on_metric = isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)

    # Metrics tracking
    train_losses = []
    val_losses = []
    learning_rates = []
    best_val_loss = float('inf')
    epochs_no_improve = 0
    best_model_state = None

    # Report sequence lengths of one sample batch (informational only)
    sample_batch = next(iter(train_loader), None)
    if sample_batch is None:
        raise ValueError(f"train_loader for '{name}' yielded no batches")
    sample_inputs, sample_targets = sample_batch
    input_seq_len = sample_inputs.shape[1] if len(sample_inputs.shape) >= 2 else 1
    target_seq_len = sample_targets.shape[1] if len(sample_targets.shape) >= 2 else 1
    print(f"INPUT SEQ LEN: {input_seq_len}, TARGET SEQ LEN: {target_seq_len}")

    for epoch in range(epochs):
        # Track the learning rate in effect during this epoch
        current_lr = optimizer.param_groups[0]['lr']
        learning_rates.append(current_lr)

        # ---- Training phase ----
        model.train()
        batch_losses = []

        for inputs, targets in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train] LR: {current_lr:.1e}", leave=False):
            # Debug shapes for first batch in initial epoch
            if epoch == 0 and not batch_losses:
                print(f"Train batch shapes - inputs: {inputs.shape}, targets: {targets.shape}")

            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            outputs, targets = _align_outputs_and_targets(outputs, targets, verbose=True)

            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            batch_losses.append(loss.item())

        if not batch_losses:
            raise ValueError(f"train_loader for '{name}' yielded no batches")
        train_loss = sum(batch_losses) / len(batch_losses)
        train_losses.append(train_loss)

        # ---- Validation phase ----
        model.eval()
        val_batch_losses = []

        with torch.no_grad():
            for inputs, targets in tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} [Val]", leave=False):
                if epoch == 0 and not val_batch_losses:
                    print(f"Val batch shapes - inputs: {inputs.shape}, targets: {targets.shape}")

                inputs = inputs.to(device)
                targets = targets.to(device)

                outputs = model(inputs)
                # Same adjustments as in training, without per-batch logging
                outputs, targets = _align_outputs_and_targets(outputs, targets, verbose=False)

                loss = criterion(outputs, targets)
                val_batch_losses.append(loss.item())

        if not val_batch_losses:
            raise ValueError(f"val_loader for '{name}' yielded no batches")
        val_loss = sum(val_batch_losses) / len(val_batch_losses)
        val_losses.append(val_loss)

        # Update learning rate
        if steps_on_metric:
            scheduler.step(val_loss)
        else:
            scheduler.step()

        if using_default_scheduler:
            new_lr = optimizer.param_groups[0]['lr']
            if new_lr != current_lr:
                print(f"Epoch {epoch+1}: reducing learning rate to {new_lr:.4e}.")

        # Print progress
        if epoch % 5 == 0 or epoch == epochs - 1:
            print(f"{name} - Epoch {epoch+1}/{epochs}: Train Loss: {train_loss:.6f}, Val Loss: {val_loss:.6f}, LR: {current_lr:.1e}")

        # Early stopping check
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            # Clone so later optimizer steps do not overwrite the snapshot
            best_model_state = {k: v.clone() for k, v in model.state_dict().items()}
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"Early stopping triggered after {epoch+1} epochs!")
                break

    # Restore best model state
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    # Generate detailed training metrics visualization if requested
    if generate_plots:
        visualize_training_metrics(name, train_losses, val_losses, learning_rates)

    print(f"Training completed for {name}. Best validation loss: {best_val_loss:.6f}")
    return model, train_losses, val_losses, learning_rates


def _align_outputs_and_targets(outputs, targets, verbose=False):
    """
    Reconcile model outputs and targets to a common shape before the loss.

    The adjustments are applied in this order:

    1. A singleton third axis on 5D outputs is squeezed away.
    2. 5D targets paired with 4D outputs are reduced to 4D: index 0 of the
       third axis is taken when the second-axis lengths match, otherwise the
       "diagonal" ``targets[:, i, i]`` (clamped to the last index) is used.
    3. For 4D/4D pairs, longer outputs are trimmed along axis 1 and shorter
       outputs are padded by repeating their last step.
    4. 3D targets, or 4D targets with a singleton axis 1, are repeated along
       axis 1 to match the outputs.
    5. If shapes still differ, both tensors are cropped to the element-wise
       minimum over their leading common axes (best effort).

    Args:
        outputs: Tensor returned by the model.
        targets: Target tensor for the same batch.
        verbose: If True, print a message for each adjustment performed.

    Returns:
        Tuple ``(outputs, targets)`` after the adjustments.
    """
    # [batch, seq, 1, H, W] -> [batch, seq, H, W]
    if len(outputs.shape) == 5 and outputs.shape[2] == 1:
        outputs = outputs.squeeze(2)

    if len(targets.shape) == 5 and len(outputs.shape) == 4:
        original_shape = targets.shape
        if outputs.shape[1] == targets.shape[1]:
            # Keep the first entry of the third axis for every step
            targets = targets[:, :, 0, :, :]
            if verbose:
                print(f"Converted targets from shape {original_shape} by removing redundant dim")
        else:
            # Take targets[:, i, i]; once i runs past the third axis, reuse
            # its last index instead.
            n_steps = min(original_shape[1], outputs.shape[1])
            last_idx = original_shape[2] - 1
            targets = torch.stack(
                [targets[:, i, min(i, last_idx)] for i in range(n_steps)], dim=1
            )
            if verbose:
                print(f"Converted targets using diagonal selection: {targets.shape}")

    # Adjust sequence length when output_seq_len != target_seq_len
    if len(outputs.shape) == 4 and len(targets.shape) == 4:
        output_seq_len = outputs.shape[1]
        target_seq_len = targets.shape[1]

        if output_seq_len > target_seq_len:
            # Model produced a longer sequence: keep only the leading steps
            outputs = outputs[:, :target_seq_len]
            if verbose:
                print(f"Trimming outputs from {output_seq_len} to {target_seq_len}")
        elif output_seq_len < target_seq_len:
            # Model produced a shorter sequence: repeat its last step
            missing = target_seq_len - output_seq_len
            tail = outputs[:, -1:].repeat(1, missing, 1, 1)
            outputs = torch.cat([outputs, tail], dim=1)
            if verbose:
                print(f"Extending outputs from {output_seq_len} to {target_seq_len}")

    # Handle different target formats
    if len(targets.shape) == 3 and len(outputs.shape) == 4:
        # targets: [batch, H, W], outputs: [batch, seq, H, W]
        targets = targets.unsqueeze(1).repeat(1, outputs.shape[1], 1, 1)
    elif len(targets.shape) == 4 and len(outputs.shape) == 4 and targets.shape[1] == 1:
        targets = targets.repeat(1, outputs.shape[1], 1, 1)

    # Best-effort crop when shapes are still incompatible
    if outputs.shape != targets.shape:
        if verbose:
            print(f"⚠️ Warning: Shapes still incompatible after adjustments:")
            print(f"   outputs.shape: {outputs.shape}, targets.shape: {targets.shape}")

        try:
            # zip() stops at the shorter rank, so only common axes are cropped
            common = tuple(
                slice(0, min(o, t)) for o, t in zip(outputs.shape, targets.shape)
            )
            outputs = outputs[common]
            targets = targets[common]
        except (IndexError, RuntimeError, TypeError):
            # Last resort: flatten and trim both to the same bounded length
            n_keep = min(outputs.numel(), targets.numel(), 10000)
            outputs = outputs.flatten()[:n_keep].unsqueeze(0)
            targets = targets.flatten()[:n_keep].unsqueeze(0)

    return outputs, targets

# Nueva función para visualizar métricas de entrenamiento detalladas
def visualize_training_metrics(model_name, train_losses, val_losses, learning_rates):
    """
    Plot, save and display a 2x2 summary of a model's training history.

    The four panels show: loss curves on a linear scale, the same curves on a
    log scale, the per-epoch gap ``val - train``, and the learning rate on a
    log scale. The figure is written as a PNG under
    ``MODELS_OUTPUT / 'visualization' / 'training_metrics'`` (created if
    missing), shown with ``plt.show()``, and then closed.

    Args:
        model_name: Name of the model; used in panel titles and the file name.
        train_losses: List of training losses per epoch.
        val_losses: List of validation losses per epoch.
        learning_rates: List of learning rates per epoch.
    """
    fig = plt.figure(figsize=(16, 14))
    try:
        epochs = range(1, len(train_losses) + 1)

        # 1. Training and validation losses
        plt.subplot(2, 2, 1)
        plt.plot(epochs, train_losses, 'b-', label='Training loss')
        plt.plot(epochs, val_losses, 'r-', label='Validation loss')
        plt.title(f'{model_name} - Loss Curves')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True, alpha=0.3)

        # 2. Log scale losses - better for visualizing small changes
        plt.subplot(2, 2, 2)
        plt.semilogy(epochs, train_losses, 'b-', label='Training loss')
        plt.semilogy(epochs, val_losses, 'r-', label='Validation loss')
        plt.title(f'{model_name} - Loss Curves (Log Scale)')
        plt.xlabel('Epochs')
        plt.ylabel('Log Loss')
        plt.legend()
        plt.grid(True, alpha=0.3)

        # 3. Loss difference (for overfitting detection)
        plt.subplot(2, 2, 3)
        loss_diff = [val - train for train, val in zip(train_losses, val_losses)]
        plt.plot(epochs, loss_diff, 'g-')
        plt.fill_between(epochs, 0, loss_diff, alpha=0.3, color='g')
        plt.title(f'{model_name} - Validation-Training Loss Gap')
        plt.xlabel('Epochs')
        plt.ylabel('Difference (Val - Train)')
        plt.axhline(y=0, color='r', linestyle='--', alpha=0.3)
        plt.grid(True, alpha=0.3)

        # 4. Learning rate over time
        plt.subplot(2, 2, 4)
        plt.semilogy(epochs, learning_rates, 'c-')
        plt.title(f'{model_name} - Learning Rate')
        plt.xlabel('Epochs')
        plt.ylabel('Learning Rate (log scale)')
        plt.grid(True, alpha=0.3)

        plt.tight_layout()

        # Save the plot
        save_dir = MODELS_OUTPUT / 'visualization' / 'training_metrics'
        save_dir.mkdir(exist_ok=True, parents=True)
        output_path = save_dir / f'{model_name}_training_metrics.png'
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Training metrics visualization saved to {output_path}")
        plt.show()
    finally:
        # Release the figure so repeated calls (one per trained model) do not
        # accumulate open figures.
        plt.close(fig)