
# Spatiotemporal Precipitation Prediction with TensorFlow/Keras
**48→12 Experimental Plan Implementation**

A complete TensorFlow/Keras implementation of the precipitation prediction experimental plan. This system:
- Trains & validates 5 neural architectures across 5 temporal folds
- Uses 48-month input windows to predict 12-month precipitation horizons
- Automatically adapts to GPU or CPU environments
- Implements memory-efficient data loading and processing
- Provides robust checkpointing and experiment tracking

# Sistema TensorFlow/Keras para Precipitación Espacio-temporal

Implementación optimizada utilizando TensorFlow y Keras que soporta ejecución en GPU o CPU con procesamiento por lotes y modelo MVP.

In [None]:
# ▶️ Environment Setup and Core Configuration
import os
import sys
import warnings
import gc
import time
import pickle
import numpy as np
import pandas as pd
from pathlib import Path

# ▶️ Path configuration (Colab vs Local)
# Selects BASE_PATH: the mounted Google Drive project folder when running in
# Colab; otherwise the nearest ancestor of the working directory that contains
# a `.git` entry (or the working directory itself when none is found).
from pathlib import Path
# Colab is detected by the `google.colab` module already being loaded.
IN_COLAB = 'google.colab' in sys.modules
if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)
    BASE_PATH = Path('/content/drive/MyDrive/ml_precipitation_prediction')
    print("Installing required dependencies for Colab environment...")
    # The `!pip` lines are IPython shell escapes: they only work inside a
    # notebook kernel and are not valid plain-Python syntax.
    !pip install -q torch torchvision torchaudio torchmetrics
    !pip install -q xarray netCDF4 zarr dask
    !pip install -q pandas numpy scipy scikit-learn
    !pip install -q matplotlib seaborn cartopy
    !pip install -q geopandas rasterio
    !pip install -q pytorch-lightning
    !pip install -q optuna psutil tqdm 
    !pip install -q tensorflow tensorflow-probability
    !pip install -q PyYAML h5py
    !pip install -q ace_tools_open
    print("✅ Dependencies installed successfully")
else:
    BASE_PATH = Path.cwd()
    # climb to project root if inside subfolder
    for p in [BASE_PATH, *BASE_PATH.parents]:
        if (p / '.git').exists():
            BASE_PATH = p; break
    # The names below are only defined on this local branch (not in Colab).
    # NOTE(review): BATCH_SIZE, INPUT_WINDOW and HORIZON are reassigned
    # unconditionally further down in this file (32 / 48 / 12), so the reduced
    # values set here do not survive to training — confirm which is intended.
    DEBUG_MODE = True
    SAFE_LOCAL_MODE = True
    BATCH_SIZE = 8
    NUM_WORKERS = 0
    INPUT_WINDOW = 24  # Reduced
    HORIZON = 6        # Reduced
print('BASE_PATH =', BASE_PATH)

# centralised dataset / model paths
# All locations are derived from BASE_PATH. MODEL_DIR and IMAGE_DIR are
# created on the spot if they do not exist; the two NetCDF paths are only
# composed here, not opened or checked for existence.
DATA_DIR      = BASE_PATH/'data'/'output'
MODEL_DIR     = BASE_PATH/'models'/'output'/'trained_models'; MODEL_DIR.mkdir(parents=True, exist_ok=True)
IMAGE_DIR     = MODEL_DIR/'images'; IMAGE_DIR.mkdir(exist_ok=True)
FEATURES_NC   = BASE_PATH/'models'/'output'/'features_fusion_branches.nc'
FULL_NC       = DATA_DIR/'complete_dataset_with_features_with_clusters_elevation_with_windows.nc'
print('Using FULL_NC  :', FULL_NC)
print('Using FEATURES :', FEATURES_NC)

# 1. Disable CDN access to prevent widget errors
# These are process-wide environment variables; they only influence libraries
# that read them after this point.
os.environ['JUPYTER_DISABLE_MATHJAX'] = '1'  # Disable MathJax (uses CDN)
os.environ['TQDM_DISABLE'] = '1'  # Avoid tqdm widgets that might use CDN
os.environ['MPLBACKEND'] = 'Agg'  # Use non-interactive backend for matplotlib

# Ignore warnings related to widgets and CDN
# (the pattern is a regex matched against the warning message text)
warnings.filterwarnings('ignore', message=".*widget.*|.*CDN.*|.*SSL.*")

# 2. Configure memory limit to avoid OOM
# Best effort: any failure (e.g. the `resource` module being unavailable, or
# the limit being rejected) is reported with a single generic message below.
try:
    import resource
    # Soft limit of 12GB (adjust according to available memory)
    # The current hard limit is preserved; the current soft limit is read but
    # not used afterwards.
    soft, hard = resource.getrlimit(resource.RLIMIT_AS)
    mem_limit = 12 * (1024**3)  # 12GB in bytes
    # NOTE(review): RLIMIT_AS caps the process address space, not resident
    # memory — confirm 12GB is compatible with the GPU runtime in use.
    resource.setrlimit(resource.RLIMIT_AS, (mem_limit, hard))
    print(f"✅ Memory limit set: 12GB")
except Exception:
    print("⚠️ Could not set memory limit")
# 3. Function to free memory (use it when you notice slowdowns)
def clean_memory():
    """Free memory held by Python garbage, the CUDA cache and open figures.

    Runs one garbage-collection pass and then, only when the optional
    packages can be imported, empties PyTorch's CUDA cache and closes every
    matplotlib figure. Progress is reported on stdout; nothing is returned.
    """
    import gc as _gc

    print("🧹 Cleaning memory...")

    # Step 1: collect unreachable Python objects
    _gc.collect()

    # Step 2 (optional): PyTorch CUDA cache — skipped when torch is missing
    try:
        import torch as _torch
        cuda_present = _torch.cuda.is_available()
        if cuda_present:
            _torch.cuda.empty_cache()
            print("  ✓ GPU cache released")
    except ImportError:
        pass

    # Step 3 (optional): matplotlib figures — skipped when matplotlib is missing
    try:
        from matplotlib import pyplot
        pyplot.close('all')
        print("  ✓ Figures closed")
    except ImportError:
        pass

    print("✅ Memory released")

def train_with_history(model, train_loader, val_loader, epochs=100, patience=15,
                      lr=1e-3, weight_decay=1e-4, fold=None, exp_name=None):
    """
    Train a model with a custom loop, early stopping and per-epoch history.

    The loss is the mean squared error averaged over the batch, plus the sum
    of ``model.losses`` when the model defines regularization losses. The
    weights of the epoch with the lowest validation loss are restored before
    returning.

    Args:
        model: The Keras model to train (called as ``model(x, training=...)``)
        train_loader: Iterable yielding ``(x_batch, y_batch)`` training pairs
        val_loader: Iterable yielding ``(x_batch, y_batch)`` validation pairs
        epochs: Maximum number of epochs
        patience: Number of epochs without validation-loss improvement after
            which training stops
        lr: Learning rate
        weight_decay: Decoupled weight decay. Applied through AdamW when the
            installed TensorFlow provides it; a falsy value (0/None) selects
            plain Adam
        fold: Current fold (for logging)
        exp_name: Experiment name (for logging)

    Returns:
        model: Trained model (best version)
        history: Dict of per-epoch lists: 'train_loss', 'val_loss',
            'train_rmse', 'val_rmse', 'lr', 'time_per_epoch'
        best_rmse: Validation RMSE of the best epoch (``inf`` if the
            validation loss never improved)
    """
    import time

    import tensorflow as tf

    # Get device info (only used for logging)
    if hasattr(model, 'device'):
        device = model.device
    else:
        device = 'GPU' if tf.config.list_physical_devices('GPU') else 'CPU'

    print(f"Training on {device}")
    print(f"Experiment: {exp_name}, Fold: {fold}")

    # Define optimizer. `weight_decay` used to be accepted but silently
    # ignored; it is now honoured via AdamW whenever that class exists.
    adamw_cls = getattr(tf.keras.optimizers, 'AdamW', None)
    if weight_decay and adamw_cls is not None:
        optimizer = adamw_cls(learning_rate=lr, weight_decay=weight_decay)
    else:
        if weight_decay:
            print("⚠️ AdamW not available in this TensorFlow version; weight_decay is ignored")
        optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

    # NOTE(review): the setup code in this file enables the global
    # 'mixed_float16' policy when a GPU is present, while this loop applies no
    # loss scaling — confirm whether a LossScaleOptimizer is required.

    # Initialize history tracking
    history = {
        'train_loss': [],
        'val_loss': [],
        'train_rmse': [],
        'val_rmse': [],
        'lr': [],
        'time_per_epoch': []
    }

    # Early stopping variables
    best_val_loss = float('inf')
    best_epoch = -1  # -1 means "no improvement seen yet"
    best_weights = None
    best_rmse = float('inf')

    # Training loop
    for epoch in range(epochs):
        start_time = time.time()

        # Fresh metric objects every epoch so no state leaks between epochs
        train_loss = tf.keras.metrics.Mean()
        train_rmse = tf.keras.metrics.RootMeanSquaredError()

        for x_batch, y_batch in train_loader:
            with tf.GradientTape() as tape:
                # Forward pass
                y_pred = model(x_batch, training=True)

                # Calculate loss
                loss = tf.keras.losses.mean_squared_error(y_batch, y_pred)
                loss = tf.reduce_mean(loss)

                # Add regularization loss if model has regularization
                if hasattr(model, 'losses') and model.losses:
                    reg_loss = tf.reduce_sum(model.losses)
                    loss += reg_loss

            # Backpropagation
            grads = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

            # Update metrics (the tracked loss includes the regularization term)
            train_loss.update_state(loss)
            train_rmse.update_state(y_batch, y_pred)

        # Validation metrics
        val_loss = tf.keras.metrics.Mean()
        val_rmse = tf.keras.metrics.RootMeanSquaredError()

        # Validation loop (no regularization term, inference mode)
        for x_batch, y_batch in val_loader:
            y_pred = model(x_batch, training=False)
            v_loss = tf.reduce_mean(tf.keras.losses.mean_squared_error(y_batch, y_pred))
            val_loss.update_state(v_loss)
            val_rmse.update_state(y_batch, y_pred)

        # Calculate time per epoch
        time_per_epoch = time.time() - start_time

        # Update history
        history['train_loss'].append(float(train_loss.result()))
        history['val_loss'].append(float(val_loss.result()))
        history['train_rmse'].append(float(train_rmse.result()))
        history['val_rmse'].append(float(val_rmse.result()))
        history['lr'].append(float(lr))
        history['time_per_epoch'].append(float(time_per_epoch))

        # Print epoch results
        print(f"Epoch {epoch+1}/{epochs} - {time_per_epoch:.2f}s - "
              f"loss: {train_loss.result():.4f} - rmse: {train_rmse.result():.4f} - "
              f"val_loss: {val_loss.result():.4f} - val_rmse: {val_rmse.result():.4f}")

        # Check for improvement
        current_val_loss = float(val_loss.result())
        current_val_rmse = float(val_rmse.result())

        if current_val_loss < best_val_loss:
            print(f"Validation loss improved from {best_val_loss:.4f} to {current_val_loss:.4f}")
            best_val_loss = current_val_loss
            best_rmse = current_val_rmse
            best_epoch = epoch
            # Save best weights
            best_weights = model.get_weights()

        # Early stopping
        if epoch - best_epoch >= patience:
            print(f"Early stopping at epoch {epoch+1}. No improvement in the last {patience} epochs.")
            break

    # Load best weights
    if best_weights is not None:
        model.set_weights(best_weights)
        print(f"Restored model from best epoch {best_epoch+1} with val_rmse = {best_rmse:.4f}")

    # Return model, history, and best RMSE
    return model, history, best_rmse

print("✅ Anti-blocking configuration successfully applied")
print("💡 Use clean_memory() if you notice the notebook slowing down")
# ▶️ Memory monitor and safe execution
import gc
import time
import pickle
from pathlib import Path

# Create directory for checkpoints (relative to the current working directory)
CHECKPOINT_DIR = Path('./checkpoints')
CHECKPOINT_DIR.mkdir(exist_ok=True, parents=True)

# Define MVP mode - set to True for minimal viable product run (faster execution)
# When active, experiments run only on the most recent fold (F1)
MVP_MODE = True

# Initialize global tracking variables
ALL_HISTORIES = {}   # exp_name -> {fold -> training history}
RESULTS = []         # one {'exp', 'fold', 'rmse'} dict per finished run

# The banner must be an f-string: without the prefix the `{...}` expressions
# below were printed verbatim instead of being evaluated.
print(f"""
🔄 CHECKPOINT SYSTEM ACTIVE

The notebook uses a robust checkpoint system that allows recovery from crashes:
- Each of the {'5' if not MVP_MODE else '1'} experiments ({len(EXPERIMENTS) if 'EXPERIMENTS' in globals() else '5'} architectures × {'5' if not MVP_MODE else '1'} folds) is saved individually
- Training automatically resumes from the last saved checkpoint
- Perfect for long-running experiments that might be interrupted
""")

class SafeExecution:
    """
    Robust execution system to protect against crashes during training

    Features:
    - Automatic saving of trained weights and metrics
    - Recovery from previous checkpoints if training was interrupted
    - Memory cleanup before each experiment

    Checkpoints are pickled 3-tuples ``(weights, history, best_rmse)`` stored
    as ``<exp_name>_<fold>_result.pkl`` inside ``CHECKPOINT_DIR``.

    Note: Ideal for long-running notebooks with multiple experiments
    """

    @staticmethod
    def save_checkpoint(data, name):
        """
        Pickle ``data`` to ``CHECKPOINT_DIR/<name>.pkl``.

        Returns:
            True on success, False if anything went wrong (the error is printed).
        """
        try:
            path = CHECKPOINT_DIR / f"{name}.pkl"
            with open(path, 'wb') as f:
                pickle.dump(data, f)
            print(f"✅ Checkpoint saved: {path}")
            return True
        except Exception as e:
            # Deliberately best-effort: a failed save must not abort training
            print(f"❌ Error saving checkpoint: {e}")
            return False

    @staticmethod
    def load_checkpoint(name):
        """
        Load ``CHECKPOINT_DIR/<name>.pkl``.

        Returns:
            The unpickled object, or None when the file does not exist or
            cannot be read (the error is printed).
        """
        try:
            path = CHECKPOINT_DIR / f"{name}.pkl"
            if not path.exists():
                return None

            # SECURITY: pickle executes arbitrary code on load — only open
            # checkpoint files that this notebook wrote itself.
            with open(path, 'rb') as f:
                data = pickle.load(f)
            print(f"✅ Checkpoint loaded: {path}")
            return data
        except Exception as e:
            print(f"❌ Error loading checkpoint: {e}")
            return None

    @staticmethod
    def _register_result(exp_name, fold, history, best_rmse):
        """Record one finished experiment/fold in the global trackers, if defined."""
        if 'RESULTS' in globals():
            RESULTS.append({
                'exp': exp_name,
                'fold': fold,
                'rmse': best_rmse
            })

        if 'ALL_HISTORIES' in globals():
            ALL_HISTORIES.setdefault(exp_name, {})[fold] = history

    @staticmethod
    def run_experiment(exp_name, fold=None):
        """
        Run an experiment safely with automatic checkpoint recovery

        Args:
            exp_name: Name of the experiment to run (must be in EXPERIMENTS)
            fold: Specific fold to run, or None to run all folds

        Note:
            MVP_MODE only affects the default: with ``fold=None`` it restricts
            the run to fold F1. An explicitly requested fold is always
            honoured as long as it exists in FOLDS.
        """
        if exp_name not in EXPERIMENTS:
            print(f"❌ Experiment '{exp_name}' does not exist")
            return

        # Determine folds to run
        if fold:
            # An explicit fold is used only if it is a known fold
            folds_to_run = [fold] if fold in FOLDS else []
        else:
            # No fold given: every fold, or only F1 when MVP_MODE is active
            folds_to_run = ['F1'] if MVP_MODE else list(FOLDS.keys())

        if not folds_to_run:
            if MVP_MODE and fold not in FOLDS:
                print(f"❌ Fold '{fold}' not available in MVP_MODE (only {list(FOLDS.keys())} available)")
            else:
                print(f"❌ Invalid fold '{fold}'")
            return

        print(f"🔄 Running experiment {exp_name} on folds: {', '.join(folds_to_run)}")

        for current_fold in folds_to_run:
            # Checkpoint name for this experiment/fold
            checkpoint_name = f"{exp_name}_{current_fold}_result"

            # Reuse a previous result when one exists
            checkpoint_data = SafeExecution.load_checkpoint(checkpoint_name)
            if checkpoint_data is not None:
                # First element (weights, or a model in older checkpoints) is
                # not needed to report the stored result.
                _, history, best_rmse = checkpoint_data
                print(f"✅ Using previous result: RMSE = {best_rmse:.4f}")
                SafeExecution._register_result(exp_name, current_fold, history, best_rmse)
                continue

            # No checkpoint: prepare data and model
            try:
                # Free memory before starting
                clean_memory()

                # Get configuration and build dataloaders
                print(f"🔄 Preparing data for fold {current_fold}")
                cfg = EXPERIMENTS[exp_name]
                val_year = FOLDS[current_fold]

                # Use reduced batch size for greater stability
                batch_size = max(8, BATCH_SIZE // 2)  # Half the original batch size, minimum 8
                train_loader, val_loader, in_dim = TFPrecipitationDataset.build_dataloaders(val_year, cfg['use_lags'], batch_size)

                # Slightly stronger dropout for folds F4 and F5
                dropout = 0.25 if current_fold in ['F4', 'F5'] else 0.20

                # Create model - TensorFlow models don't use .to(DEVICE)
                model = MODEL_FACTORY[cfg['model']](in_dim, dropout=dropout)
            except Exception as e:
                print(f"❌ Error in experiment {exp_name}, fold {current_fold}: {e}")
                continue

            # Train model with error handling
            print(f"🔄 Training {exp_name} on fold {current_fold}")
            try:
                model, history, best_rmse = train_with_history(
                    model, train_loader, val_loader,
                    epochs=60, patience=20,
                    lr=1e-3, weight_decay=1e-4,
                    fold=current_fold, exp_name=exp_name
                )

                # Save checkpoint. The weights (a list of arrays) are stored
                # instead of the live model object: pickling a subclassed
                # Keras model is not dependable across TensorFlow/Keras
                # versions, and a failed save would silently lose the run.
                SafeExecution.save_checkpoint(
                    (model.get_weights(), history, best_rmse),
                    checkpoint_name
                )

                SafeExecution._register_result(exp_name, current_fold, history, best_rmse)

                print(f"✅ Training completed: RMSE = {best_rmse:.4f}")

            except Exception as e:
                print(f"❌ Error in training: {e}")

        print(f"✅ Experiment {exp_name} completed")

# Function to display saved results
def show_results():
    """
    Display a table (experiment × fold → RMSE) of every saved result.

    Scans ``CHECKPOINT_DIR`` for ``*_result.pkl`` files, loads each one via
    ``SafeExecution.load_checkpoint`` and pivots the stored RMSE values.
    Files that cannot be parsed or unpacked are reported and skipped.
    """
    import pandas as pd

    # Search for results in checkpoints
    results = []

    for file in CHECKPOINT_DIR.glob("*_result.pkl"):
        try:
            exp, fold = _split_result_stem(file.stem)
        except ValueError as e:
            print(f"⚠️ Skipping {file.name}: {e}")
            continue

        checkpoint = SafeExecution.load_checkpoint(file.stem)
        if checkpoint is None:
            continue

        try:
            _, _, rmse = checkpoint
        except (TypeError, ValueError):
            print(f"⚠️ Skipping {file.name}: unexpected checkpoint layout")
            continue

        results.append({
            'exp': exp,
            'fold': fold,
            'rmse': rmse
        })

    if results:
        df = pd.DataFrame(results)
        table = df.pivot(index='exp', columns='fold', values='rmse')
        # `display` only exists inside IPython/Jupyter; fall back to print
        try:
            display(table)
        except NameError:
            print(table)

        # Show progress
        total = len(EXPERIMENTS) * len(FOLDS)
        completed = len(results)

        if total:
            print(f"\n📊 Progress: {completed}/{total} ({completed/total:.1%})")
        else:
            # Nothing configured: avoid dividing by zero
            print(f"\n📊 Progress: {completed}/{total}")

        if MVP_MODE:
            print(f"\n🚀 MVP Mode: Only showing results for fold F1 (most recent data)")
    else:
        print("❌ No saved results found")


def _split_result_stem(stem):
    """Split a '<exp>_<fold>_result' file stem into ``(exp, fold)``.

    The fold is taken as the last underscore-separated token before the
    '_result' suffix, so experiment names may themselves contain underscores.

    Raises:
        ValueError: if the stem does not follow the expected pattern.
    """
    suffix = '_result'
    if not stem.endswith(suffix):
        raise ValueError(f"'{stem}' does not end with '{suffix}'")
    exp, sep, fold = stem[:-len(suffix)].rpartition('_')
    if not (sep and exp and fold):
        raise ValueError(f"'{stem}' is not of the form '<exp>_<fold>{suffix}'")
    return exp, fold


# Usage banner. It must be an f-string so that the "Current mode" line is
# evaluated; without the prefix the raw `{...}` expression was printed.
print(f"""✅ Safe execution system activated

To run experiments safely:

  1. SafeExecution.run_experiment('GRU-ED', fold='F1')  # A specific fold
  2. SafeExecution.run_experiment('GRU-ED')             # All folds (in MVP mode: only F1)
  3. show_results()                                     # View saved results

Results are automatically saved and can be recovered
if the kernel dies during execution.

Current mode: {'🚀 MVP (F1 only)' if MVP_MODE else '📊 FULL (all folds)'}
""")


# ▶️ Environment setup (standard-library imports; TensorFlow is configured in the next section)
import sys, os, logging, warnings, json
from pathlib import Path
import platform, multiprocessing




# ▶️ Configuración de TensorFlow y detección de GPU/CPU
import os
import sys
import gc
import warnings
import time
import numpy as np
import pandas as pd
from pathlib import Path

# Configuración para TensorFlow
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # Reducir mensajes de log
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'  # Crecimiento gradual de memoria

# Importar TensorFlow y Keras con manejo de errores
try:
    import tensorflow as tf
    from tensorflow import keras
    print(f"TensorFlow versión: {tf.__version__}")
    print(f"Keras versión: {keras.__version__}")
except ImportError:
    print("Instalando TensorFlow...")
    import sys
    !{sys.executable} -m pip install tensorflow
    import tensorflow as tf
    from tensorflow import keras
    print(f"✅ TensorFlow instalado: {tf.__version__}")

print("\n" + "="*50)
print("📊 DETECCIÓN DE ENTORNO TF/KERAS")
print("="*50)

# Detección de GPU
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    # Configurar crecimiento de memoria
    for gpu in gpus:
        try:
            tf.config.experimental.set_memory_growth(gpu, True)
            print(f"GPU detectada: {gpu.name}")
            
            # Activar precisión mixta
            mixed_precision = True
            if mixed_precision:
                policy = keras.mixed_precision.Policy('mixed_float16')
                keras.mixed_precision.set_global_policy(policy)
                print("✅ Precisión mixta activada (mixed_float16)")
            
            # Activar XLA para optimizar rendimiento
            tf.config.optimizer.set_jit(True)
            print("✅ Compilación XLA activada")
            
            # Obtener información de memoria
            try:
                gpu_info = tf.config.experimental.get_memory_info('GPU:0')
                memory_mb = gpu_info['current'] / (1024 * 1024)
                print(f"Memoria inicial asignada: {memory_mb:.2f} MB")
            except:
                pass
                
        except RuntimeError as e:
            print(f"⚠️ Error al configurar GPU: {e}")
    
    print(f"🔥 Sistema funcionará con GPU: {len(gpus)} disponible(s)")
else:
    print("❌ No se detectaron GPUs - Usando CPU")
    
    # Optimización para CPU
    try:
        import multiprocessing
        num_threads = multiprocessing.cpu_count()
        tf.config.threading.set_inter_op_parallelism_threads(num_threads)
        tf.config.threading.set_intra_op_parallelism_threads(num_threads)
        print(f"CPU optimizada con {num_threads} threads")
    except:
        pass

print("\n✅ TensorFlow configurado correctamente")
print("="*50)

# -----------------------------------------------------------------------------
# SECTION 1: ENVIRONMENT DETECTION AND CONFIGURATION
# -----------------------------------------------------------------------------

# TensorFlow setup with memory growth and reduced verbosity
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # Reduce TF log messages
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'  # Enable gradual memory growth

# Import TensorFlow/Keras with error handling
try:
    import tensorflow as tf
    from tensorflow import keras
    print(f"TensorFlow version: {tf.__version__}")
    print(f"Keras version: {keras.__version__}")
except ImportError:
    print("Installing TensorFlow...")
    import sys
    !{sys.executable} -m pip install tensorflow
    import tensorflow as tf
    from tensorflow import keras
    print(f"✅ TensorFlow successfully installed: {tf.__version__}")

print("\n" + "="*50)
print("📊 TF/KERAS ENVIRONMENT DETECTION")
print("="*50)

# GPU detection and configuration
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    # Configure memory growth to prevent OOM errors
    for gpu in gpus:
        try:
            tf.config.experimental.set_memory_growth(gpu, True)
            print(f"GPU detected: {gpu.name}")
            
            # Enable mixed precision for better performance
            mixed_precision = True
            if mixed_precision:
                policy = keras.mixed_precision.Policy('mixed_float16')
                keras.mixed_precision.set_global_policy(policy)
                print("✅ Mixed precision enabled (mixed_float16)")
            
            # Enable XLA compilation for optimization
            tf.config.optimizer.set_jit(True)
            print("✅ XLA compilation enabled")
            
            # Get initial memory allocation info
            try:
                gpu_info = tf.config.experimental.get_memory_info('GPU:0')
                memory_mb = gpu_info['current'] / (1024 * 1024)
                print(f"Initial memory allocation: {memory_mb:.2f} MB")
            except:
                pass
                
        except RuntimeError as e:
            print(f"⚠️ Error configuring GPU: {e}")
    
    print(f"🔥 System will run on GPU: {len(gpus)} available")
else:
    print("❌ No GPUs detected - Using CPU")
    
    # CPU optimization when no GPU is available
    try:
        import multiprocessing
        num_threads = multiprocessing.cpu_count()
        tf.config.threading.set_inter_op_parallelism_threads(num_threads)
        tf.config.threading.set_intra_op_parallelism_threads(num_threads)
        print(f"CPU optimized with {num_threads} threads")
    except:
        pass

print("\n✅ TensorFlow configured correctly")
print("="*50)

# Define device for TensorFlow
DEVICE = "GPU" if tf.config.list_physical_devices('GPU') else "CPU"
print(f"Using device: {DEVICE}")

# Model classes that MODEL_FACTORY maps configuration keys to
class BaseGRUModel(tf.keras.Model):
    """Common base for the GRU models: stores constructor settings only.

    Attributes set on the instance: ``input_dim``, ``dropout`` (the rate as
    passed in) and ``device`` (the module-level DEVICE label).
    """

    def __init__(self, input_dim, dropout=0.2):
        super(BaseGRUModel, self).__init__()
        self.device = DEVICE
        self.dropout = dropout
        self.input_dim = input_dim

    def call(self, inputs, training=False):
        """Placeholder forward pass: hands the inputs back unchanged."""
        passthrough = inputs
        return passthrough

class GRUEncoderDecoder(BaseGRUModel):
    """Baseline GRU stack: GRU(128) → Dropout → GRU(64) → Dense(1).

    Both GRU layers use ``return_sequences=True``, so the Dense head is
    applied at every time step of the sequence it receives.
    """

    def __init__(self, input_dim, dropout=0.2):
        super().__init__(input_dim, dropout)
        # Simple GRU encoder-decoder
        self.encoder = tf.keras.layers.GRU(128, return_sequences=True)
        # Created once here instead of on every call(): sub-layers belong in
        # __init__ so the model tracks them and no new layer object is
        # allocated per forward pass.
        self.dropout_layer = tf.keras.layers.Dropout(dropout)
        self.decoder = tf.keras.layers.GRU(64, return_sequences=True)
        self.output_layer = tf.keras.layers.Dense(1)
    
    def call(self, inputs, training=False):
        """Run the forward pass; dropout is active only when ``training`` is true."""
        x = self.encoder(inputs, training=training)
        x = self.dropout_layer(x, training=training)
        x = self.decoder(x, training=training)
        return self.output_layer(x)

class GRUEncoderDecoderPAFC(BaseGRUModel):
    """GRU stack used for the explicit-lags experiment.

    Layer layout is GRU(128) → Dropout → GRU(64) → Dense(1), identical to the
    baseline; any difference comes from the inputs it is fed.
    """

    def __init__(self, input_dim, dropout=0.2):
        super().__init__(input_dim, dropout)
        # GRU with explicit lags
        self.encoder = tf.keras.layers.GRU(128, return_sequences=True)
        # Created once here instead of on every call(): sub-layers belong in
        # __init__ so the model tracks them and no new layer object is
        # allocated per forward pass.
        self.dropout_layer = tf.keras.layers.Dropout(dropout)
        self.decoder = tf.keras.layers.GRU(64, return_sequences=True)
        self.output_layer = tf.keras.layers.Dense(1)
    
    def call(self, inputs, training=False):
        """Run the forward pass; dropout is active only when ``training`` is true."""
        x = self.encoder(inputs, training=training)
        x = self.dropout_layer(x, training=training)
        x = self.decoder(x, training=training)
        return self.output_layer(x)

class AEFusionGRU(BaseGRUModel):
    """GRU stack used for the autoencoder-fusion experiment.

    Layer layout is GRU(128) → Dropout → GRU(64) → Dense(1); no separate
    autoencoder component is defined in this class.
    """

    def __init__(self, input_dim, dropout=0.2):
        super().__init__(input_dim, dropout)
        # Autoencoder fusion model
        self.encoder = tf.keras.layers.GRU(128, return_sequences=True)
        # Created once here instead of on every call(): sub-layers belong in
        # __init__ so the model tracks them and no new layer object is
        # allocated per forward pass.
        self.dropout_layer = tf.keras.layers.Dropout(dropout)
        self.decoder = tf.keras.layers.GRU(64, return_sequences=True)
        self.output_layer = tf.keras.layers.Dense(1)
    
    def call(self, inputs, training=False):
        """Run the forward pass; dropout is active only when ``training`` is true."""
        x = self.encoder(inputs, training=training)
        x = self.dropout_layer(x, training=training)
        x = self.decoder(x, training=training)
        return self.output_layer(x)

class AEFusionGRUAttention(BaseGRUModel):
    """GRU stack reserved for the attention variant.

    Layer layout is GRU(128) → Dropout → GRU(64) → Dense(1). An ``Attention``
    layer is instantiated but not yet applied in ``call``.
    """

    def __init__(self, input_dim, dropout=0.2):
        super().__init__(input_dim, dropout)
        # With attention mechanism
        self.encoder = tf.keras.layers.GRU(128, return_sequences=True)
        # Created once here instead of on every call(): sub-layers belong in
        # __init__ so the model tracks them and no new layer object is
        # allocated per forward pass.
        self.dropout_layer = tf.keras.layers.Dropout(dropout)
        self.attention = tf.keras.layers.Attention()
        self.decoder = tf.keras.layers.GRU(64, return_sequences=True)
        self.output_layer = tf.keras.layers.Dense(1)
    
    def call(self, inputs, training=False):
        """Run the forward pass; dropout is active only when ``training`` is true."""
        x = self.encoder(inputs, training=training)
        x = self.dropout_layer(x, training=training)
        # Attention mechanism would go here
        x = self.decoder(x, training=training)
        return self.output_layer(x)

class AEFusionGRUAttentionMask(BaseGRUModel):
    """GRU stack reserved for the causal-attention variant.

    Layer layout is GRU(128) → Dropout → GRU(64) → Dense(1). An ``Attention``
    layer is instantiated but not yet applied in ``call``.
    """

    def __init__(self, input_dim, dropout=0.2):
        super().__init__(input_dim, dropout)
        # With causal attention
        self.encoder = tf.keras.layers.GRU(128, return_sequences=True)
        # Created once here instead of on every call(): sub-layers belong in
        # __init__ so the model tracks them and no new layer object is
        # allocated per forward pass.
        self.dropout_layer = tf.keras.layers.Dropout(dropout)
        self.attention = tf.keras.layers.Attention()
        self.decoder = tf.keras.layers.GRU(64, return_sequences=True)
        self.output_layer = tf.keras.layers.Dense(1)
    
    def call(self, inputs, training=False):
        """Run the forward pass; dropout is active only when ``training`` is true."""
        x = self.encoder(inputs, training=training)
        x = self.dropout_layer(x, training=training)
        # Causal attention mechanism would go here
        x = self.decoder(x, training=training)
        return self.output_layer(x)

# Registry of model constructors, keyed by the model identifier stored under
# the 'model' key of each EXPERIMENTS entry (not by the experiment name).
# SafeExecution.run_experiment calls the selected constructor as
# `MODEL_FACTORY[cfg['model']](in_dim, dropout=dropout)`.
MODEL_FACTORY = {
    'gru_ed': GRUEncoderDecoder,
    'gru_ed_pafc': GRUEncoderDecoderPAFC,
    'ae_fusion_gru': AEFusionGRU,
    'ae_fusion_gru_t': AEFusionGRUAttention,
    'ae_fusion_gru_t_mask': AEFusionGRUAttentionMask
}

# -----------------------------------------------------------------------------
# SECTION 2: EXPERIMENTAL PLAN CONFIGURATION
# -----------------------------------------------------------------------------

# Define temporal partitioning for the 5-fold blocked CV according to the experimental plan
# Each value is the validation year; SafeExecution.run_experiment passes it to
# the dataloader builder as `val_year`.
FOLDS = {
    'F1': 2024,  # Validation: 2024-01 → 2024-12, Training: 2020-01 → 2023-12 (Recent drift)
    'F2': 2023,  # Validation: 2023-01 → 2023-12, Training: 2019-01 → 2022-12 (El Niño 2019-20)
    'F3': 2022,  # Validation: 2022-01 → 2022-12, Training: 2018-01 → 2021-12 (Extended La Niña)
    'F4': 2000,  # Validation: 2000-01 → 2000-12, Training: 1996-01 → 1999-12 (Historic episode)
    'F5': 1990   # Validation: 1990-01 → 1990-12, Training: 1986-01 → 1989-12 (Pre-satellite control)
}

# Input/output window configuration according to experimental plan
# NOTE(review): these assignments are unconditional, so they replace the
# reduced BATCH_SIZE / INPUT_WINDOW / HORIZON values that the local (non-Colab)
# branch sets at the top of this file — confirm which values local runs should use.
INPUT_WINDOW = 48  # 4 years of monthly data as input
HORIZON = 12       # 1 year prediction horizon
BATCH_SIZE = 32    # Default batch size; SafeExecution.run_experiment halves it (minimum 8)

# Define experiments according to the 5-architecture plan
# 'model' selects the constructor in MODEL_FACTORY; 'use_lags' is forwarded to
# the dataloader builder.
EXPERIMENTS = {
    'GRU-ED': {'model': 'gru_ed', 'use_lags': False},              # Baseline GRU encoder-decoder
    'GRU-ED-PAFC': {'model': 'gru_ed_pafc', 'use_lags': True},     # GRU ED with explicit lags
    'AE-FUSION-GRU-ED-PAFC': {'model': 'ae_fusion_gru', 'use_lags': True},  # Autoencoder fusion
    'AE-FUSION-GRU-ED-PAFC-T': {'model': 'ae_fusion_gru_t', 'use_lags': True},  # With attention
    'AE-FUSION-GRU-ED-PAFC-T-TopoMask': {'model': 'ae_fusion_gru_t_mask', 'use_lags': True}  # Causal attention
}

# Features configuration according to the experimental plan
# NOTE(review): FULL_FEATURES and BASE_FEATURES use different names for what
# look like the same quantities (e.g. 'lag_1' vs 'total_precipitation_lag1');
# verify both against the variable names actually present in the dataset.
FULL_FEATURES = [
    'precip_hist', 'lag_1', 'lag_2', 'lag_12',  # Precipitation and lags
    'month_sin', 'month_cos', 'doy_sin', 'doy_cos',  # Temporal encodings
    'elevation', 'slope', 'roughness', 'curvature', 'aspect',  # Topographic features
    'alt_cluster', 'ceemdan_imf1', 'ceemdan_imf2', 'ceemdan_imf3',  # Clustering and IMFs
    'tvfemd_imf1', 'tvfemd_imf2', 'tvfemd_imf3'
]

BASE_FEATURES = [
    'total_precipitation',  # Main precipitation variable
    'total_precipitation_lag1', 'total_precipitation_lag2', 'total_precipitation_lag12',  # Key lags
    'month_sin', 'month_cos', 'doy_sin', 'doy_cos',  # Temporal encodings
    'elevation', 'slope', 'aspect',  # Essential topographic features
    'cluster_elevation'  # Cluster information
]

# -----------------------------------------------------------------------------
# SECTION 3: DEVICE MANAGEMENT AND MEMORY OPTIMIZATION
# -----------------------------------------------------------------------------

class TFDeviceManager:
    """
    TensorFlow device manager that selects GPU or CPU execution and provides
    helpers for memory configuration, cleanup and inspection.

    Device configuration runs once, from the constructor. Any failure while
    detecting or configuring devices is reported on stdout and the manager
    falls back to CPU instead of raising.
    """

    # Capacity (MB) assumed for the memory limit when the real GPU capacity
    # cannot be queried; this is the value that used to be hard-coded.
    _FALLBACK_GPU_MEMORY_MB = 10240

    def __init__(self, prefer_gpu=True, force_cpu=False, memory_fraction=0.85,
                 mixed_precision=True):
        """
        Initialize the device manager and configure the available devices.

        Args:
            prefer_gpu: If True, use GPU if available
            force_cpu: If True, force CPU usage (overrides prefer_gpu)
            memory_fraction: GPU memory fraction to use (0-1); values below 1
                translate into a hard per-GPU memory limit
            mixed_precision: Enable mixed precision if available
        """
        self.prefer_gpu = prefer_gpu and not force_cpu
        self.force_cpu = force_cpu
        self.memory_fraction = memory_fraction
        self.mixed_precision = mixed_precision

        # Device state, filled in by _configure_devices()
        self.device_name = "CPU"
        self.using_gpu = False
        self.using_mixed_precision = False
        self.gpu_devices = []
        self.cpu_devices = []
        self.memory_allocated = 0

        self._configure_devices()

    @staticmethod
    def _query_total_gpu_memory_mb():
        """
        Ask nvidia-smi for the total memory of every GPU.

        Returns:
            List of totals in MB, one entry per GPU reported by nvidia-smi,
            or an empty list when the tool is missing, fails, times out or
            prints something that is not an integer.
        """
        import subprocess
        try:
            result = subprocess.run(
                ['nvidia-smi', '--query-gpu=memory.total',
                 '--format=csv,nounits,noheader'],
                stdout=subprocess.PIPE, stderr=subprocess.DEVNULL,
                check=True, timeout=10)
            return [int(value) for value in result.stdout.decode('utf-8').split()]
        except (OSError, subprocess.SubprocessError, ValueError):
            return []

    def _configure_devices(self):
        """Detect the available devices and configure GPU usage when requested."""
        try:
            self.gpu_devices = tf.config.list_physical_devices('GPU')
            self.cpu_devices = tf.config.list_physical_devices('CPU')

            print(f"GPUs disponibles: {len(self.gpu_devices)}")

            gpu_available = len(self.gpu_devices) > 0
            use_gpu = gpu_available and self.prefer_gpu and not self.force_cpu

            if not use_gpu:
                reason = "no disponible" if not gpu_available else "desactivada por configuración"
                print(f"📋 Usando CPU (GPU {reason})")
                self.using_gpu = False
                self.device_name = "CPU"
                return

            try:
                self._configure_gpu()
            except Exception as e:
                print(f"❌ Error configurando GPU: {e}")
                print("⚠️ Fallback a CPU")
                # A half-applied GPU setup must not leave float16 compute
                # enabled for a run that is reported as CPU-only.
                self._reset_mixed_precision()
                self.using_gpu = False
                self.device_name = "CPU"

        except Exception as e:
            print(f"❌ Error detectando dispositivos: {e}")
            self.using_gpu = False
            self.device_name = "CPU"

    def _configure_gpu(self):
        """Apply memory, mixed-precision and XLA settings to the detected GPUs."""
        limit_memory = self.memory_fraction < 1.0

        if limit_memory:
            # Use the smallest reported capacity for every GPU so the limit is
            # valid regardless of how nvidia-smi indices map to TensorFlow
            # devices; fall back to the former fixed capacity when unknown.
            totals_mb = self._query_total_gpu_memory_mb()
            capacity_mb = min(totals_mb) if totals_mb else self._FALLBACK_GPU_MEMORY_MB
            mem_limit = int(self.memory_fraction * capacity_mb)

        for gpu in self.gpu_devices:
            # Memory growth and a hard memory limit are alternative
            # strategies, so exactly one is configured per device.
            # NOTE(review): setting both is believed to make a second
            # manager instance fail inside set_memory_growth - confirm
            # against the TensorFlow version in use.
            if limit_memory:
                gpu_config = tf.config.LogicalDeviceConfiguration(
                    memory_limit=mem_limit)
                tf.config.set_logical_device_configuration(gpu, [gpu_config])
                print(f"Límite de memoria establecido: {mem_limit} MB")
            else:
                tf.config.experimental.set_memory_growth(gpu, True)

        if self.mixed_precision:
            policy = keras.mixed_precision.Policy('mixed_float16')
            keras.mixed_precision.set_global_policy(policy)
            self.using_mixed_precision = True
            print("✅ Precisión mixta activada")

        # Enable XLA JIT compilation
        tf.config.optimizer.set_jit(True)

        self.using_gpu = True
        self.device_name = f"GPU:{self.gpu_devices[0].name.split(':')[-1]}"
        print(f"✅ Usando GPU: {tf.test.gpu_device_name()}")

    def _reset_mixed_precision(self):
        """Restore the float32 global policy if mixed precision was activated."""
        if not self.using_mixed_precision:
            return
        try:
            keras.mixed_precision.set_global_policy('float32')
        except Exception as e:
            print(f"⚠️ No se pudo restablecer la política de precisión: {e}")
        self.using_mixed_precision = False

    def clear_memory(self):
        """Release Python garbage and reset the Keras/TensorFlow session state."""
        gc.collect()

        keras.backend.clear_session()

        if self.using_gpu:
            # NOTE(review): this repeats the session reset through tf.keras;
            # kept because `keras` and `tf.keras` may be different modules.
            tf.keras.backend.clear_session()
            gc.collect()
            print("🧹 Memoria GPU liberada")

    def get_memory_status(self):
        """
        Report current GPU memory usage.

        Returns:
            Dict with "allocated_gb" (memory currently in use on 'GPU:0') and
            "total_gb" (capacity of the first GPU listed by nvidia-smi, or
            None when it cannot be determined). Both are 0 when the GPU is
            not in use or the query to TensorFlow fails.
        """
        if not self.using_gpu:
            return {"allocated_gb": 0, "total_gb": 0}

        try:
            gpu_info = tf.config.experimental.get_memory_info('GPU:0')
            current_bytes = gpu_info['current']
        except Exception:
            return {"allocated_gb": 0, "total_gb": 0}

        totals_mb = self._query_total_gpu_memory_mb()
        total_bytes = totals_mb[0] * 1024 * 1024 if totals_mb else None

        return {
            "allocated_gb": current_bytes / (1024**3),
            "total_gb": total_bytes / (1024**3) if total_bytes else None
        }

    def is_gpu(self):
        """Return True when the manager configured GPU execution."""
        return self.using_gpu

# -----------------------------------------------------------------------------
# SECTION 4: EFFICIENT DATA LOADING AND PREPROCESSING
# -----------------------------------------------------------------------------

class TFPrecipitationDataset:
    """
    TensorFlow dataset helpers for precipitation data: tf.data pipeline
    creation and sliding-window construction over multi-channel arrays.
    """

    @staticmethod
    def create_tf_dataset(x_data, y_data, batch_size=32, shuffle=True, prefetch=True,
                         cache=False, drop_remainder=False):
        """
        Create an optimized tf.data.Dataset from NumPy arrays

        Args:
            x_data: Input data (numpy.ndarray)
            y_data: Target data (numpy.ndarray)
            batch_size: Batch size
            shuffle: If True, shuffle the data
            prefetch: If True, preload the next batch
            cache: If True, keep data in memory
            drop_remainder: If True, drop incomplete final batch

        Returns:
            Optimized tf.data.Dataset
        """
        dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data))

        # Caching comes before shuffling so the cached content is the raw
        # sample order and every epoch can still be reshuffled.
        if cache:
            dataset = dataset.cache()

        if shuffle:
            # Shuffle buffer: the full dataset, capped at 10000 samples
            buffer_size = min(len(x_data), 10000)
            dataset = dataset.shuffle(buffer_size, reshuffle_each_iteration=True)

        dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)

        if prefetch:
            dataset = dataset.prefetch(tf.data.AUTOTUNE)

        return dataset

    @staticmethod
    def _clean_array(arr, feature_name):
        """
        Return a float32 array with NaN and infinite values replaced.

        NaNs are replaced with the mean of the non-NaN values (0 when there is
        none); infinities are replaced with the largest/smallest finite value.

        Args:
            arr: Input array (numpy.ndarray or something convertible to one)
            feature_name: Name of the feature (for logging)

        Returns:
            Cleaned numpy array of dtype float32

        Raises:
            ValueError, TypeError: if arr is None or cannot be handled as a
                numeric array; the caller substitutes zeros in that case.
        """
        if arr is None:
            raise ValueError("array is None")
        if not isinstance(arr, np.ndarray):
            arr = np.array(arr, dtype=np.float32)

        nan_mask = np.isnan(arr)
        if nan_mask.any():
            # Skip nanmean on an all-NaN array: it would only warn and yield NaN
            mean_val = 0 if nan_mask.all() else np.nanmean(arr)
            if np.isnan(mean_val):
                mean_val = 0
            print(f"⚠️ NaN values found in '{feature_name}', replacing with mean={mean_val:.4f}")
            arr = np.nan_to_num(arr, nan=mean_val)

        finite_mask = np.isfinite(arr)
        if not finite_mask.all():
            print(f"⚠️ Infinite values found in '{feature_name}', replacing with finite values")
            has_finite = finite_mask.any()
            arr = np.nan_to_num(
                arr,
                posinf=arr[finite_mask].max() if has_finite else 1,
                neginf=arr[finite_mask].min() if has_finite else -1)

        return arr.astype(np.float32, copy=False)

    @staticmethod
    def _load_feature_stack(ds, features):
        """
        Extract the requested variables from the dataset and stack them on a
        new last axis.

        Features that are missing or cannot be processed become zero arrays.
        The common shape is taken from the loaded array with the most
        dimensions, so a missing first feature no longer forces a placeholder
        shape onto the real data. Arrays whose shape equals the common shape
        without its first axis are repeated along that axis.

        Args:
            ds: Opened dataset exposing `data_vars` and item access by name
            features: List of variable names, in stacking order

        Returns:
            Tuple of (stacked float32 array, number of features)
        """
        fallback_shape = (12, 20, 20)

        loaded = []
        for feature in features:
            arr = None
            if feature in ds.data_vars:
                try:
                    arr = TFPrecipitationDataset._clean_array(ds[feature].values, feature)
                except Exception as e:
                    print(f"❌ Error procesando feature '{feature}': {e}")
                    print(f"⚠️ Usando array de ceros para '{feature}'")
            else:
                print(f"⚠️ Característica '{feature}' no encontrada, usando ceros")
            loaded.append(arr)

        real_arrays = [arr for arr in loaded if arr is not None]
        if real_arrays:
            target_shape = max(real_arrays, key=lambda a: a.ndim).shape
        else:
            target_shape = fallback_shape

        shapes = [arr.shape for arr in real_arrays]
        if len(set(shapes)) > 1:
            print(f"⚠️ Distintas dimensiones en features: {shapes}")

        feature_arrays = []
        for i, arr in enumerate(loaded):
            if arr is None:
                arr = np.zeros(target_shape, dtype=np.float32)
            elif arr.shape != target_shape:
                if arr.ndim >= 1 and arr.shape == target_shape[1:]:
                    # Static field: repeat it for every step of the first axis
                    arr = np.repeat(arr[np.newaxis, ...], target_shape[0], axis=0)
                else:
                    print(f"  Redimensionando feature {i} de {arr.shape} a {target_shape}")
                    arr = np.zeros(target_shape, dtype=np.float32)
            feature_arrays.append(arr)

        try:
            X = np.stack(feature_arrays, axis=-1)
            num_features = X.shape[-1]
            print(f"✅ Arrays apilados correctamente, shape = {X.shape}, features = {num_features}")
        except Exception as e:
            print(f"❌ Error al apilar arrays: {e}")
            print("⚠️ Creando datos artificiales de emergencia")
            X = np.zeros((*fallback_shape, len(feature_arrays)), dtype=np.float32)
            num_features = X.shape[-1]

        return X, num_features

    @staticmethod
    def _build_windows(X, train_mask, val_mask, input_window, horizon):
        """
        Slide an (input_window + horizon) window along the first axis of X.

        A window goes to training when every one of its steps is flagged in
        train_mask, and to validation when every target step is flagged in
        val_mask. Targets are channel 0 of X (kept as a size-1 last axis).

        Returns:
            Tuple of lists (x_train, y_train, x_val, y_val)
        """
        span = input_window + horizon
        # Only visit windows that both X and the masks fully cover
        n_steps = min(len(X), len(train_mask), len(val_mask))

        x_train, y_train, x_val, y_val = [], [], [], []
        for start in range(n_steps - span + 1):
            split = start + input_window
            end = start + span

            in_train = bool(np.all(train_mask[start:end]))
            in_val = bool(np.all(val_mask[split:end]))
            if not (in_train or in_val):
                continue

            x_window = X[start:split]
            y_window = X[split:end, :, :, 0:1]  # precipitation channel only

            if np.isnan(x_window).any() or np.isnan(y_window).any():
                x_window = np.nan_to_num(x_window, nan=0.0)
                y_window = np.nan_to_num(y_window, nan=0.0)

            if in_train:
                x_train.append(x_window)
                y_train.append(y_window)
            if in_val:
                x_val.append(x_window)
                y_val.append(y_window)

        return x_train, y_train, x_val, y_val

    @staticmethod
    def build_dataloaders(val_year, use_lags=True, batch_size=32):
        """
        Build TensorFlow dataloaders with robust error handling for NaNs and type issues

        The dataset path is read from the module global DATASET_PATH; when it
        is not defined it is set from FULL_NC if that global exists, otherwise
        to './data/complete_dataset.nc'. Apart from a failure to open the
        dataset, problems are reported on stdout and replaced with zero-filled
        placeholder data rather than raised.

        Args:
            val_year: Validation year
            use_lags: If True, include lag variables
            batch_size: Batch size

        Returns:
            Tuple of (train_dataset, val_dataset, num_features)

        Raises:
            Exception: whatever importing xarray or opening the dataset raises.
        """
        global DATASET_PATH
        if 'DATASET_PATH' not in globals():
            if 'FULL_NC' in globals():
                DATASET_PATH = str(FULL_NC)
            else:
                DATASET_PATH = './data/complete_dataset.nc'

        print(f"Cargando datos desde {DATASET_PATH}")
        try:
            import xarray as xr
            ds = xr.open_dataset(DATASET_PATH)
        except Exception as e:
            print(f"❌ Error cargando dataset: {e}")
            raise

        # Select the feature list according to use_lags
        if use_lags:
            if 'FULL_FEATURES' in globals():
                features = FULL_FEATURES
            else:
                features = ['total_precipitation', 'total_precipitation_lag1',
                            'total_precipitation_lag2', 'total_precipitation_lag12',
                            'month_sin', 'month_cos', 'doy_sin', 'doy_cos',
                            'elevation', 'slope', 'roughness', 'curvature', 'aspect',
                            'cluster_elevation']
        else:
            if 'BASE_FEATURES' in globals():
                # Drop the lag variables from the base set
                features = [f for f in BASE_FEATURES if 'lag' not in f]
            else:
                features = ['total_precipitation',
                            'month_sin', 'month_cos', 'doy_sin', 'doy_cos',
                            'elevation', 'slope', 'aspect']

        # Training years: the 4 years preceding the validation year
        train_years = list(range(val_year - 4, val_year))

        try:
            try:
                years = pd.to_datetime(ds.time.values).year
            except Exception as e:
                print(f"⚠️ Error procesando fechas: {e}")
                # Placeholder years: first half training, second half validation
                num_samples = len(ds['time']) if 'time' in ds.dims else 48
                years = np.zeros(num_samples, dtype=np.int32)
                years[:num_samples//2] = val_year - 1
                years[num_samples//2:] = val_year

            train_mask = np.isin(years, train_years)
            val_mask = np.asarray(years == val_year)

            print(f"Años entrenamiento: {train_years}, Año validación: {val_year}")
            print(f"Muestras entrenamiento: {train_mask.sum()}, Muestras validación: {val_mask.sum()}")

            if train_mask.sum() == 0:
                print("⚠️ No hay datos de entrenamiento, creando datos artificiales")
                train_mask[:len(train_mask)//2] = True
            if val_mask.sum() == 0:
                print("⚠️ No hay datos de validación, creando datos artificiales")
                val_mask[len(val_mask)//2:] = True

            X, num_features = TFPrecipitationDataset._load_feature_stack(ds, features)
        finally:
            # Every array needed below is already in memory
            ds.close()

        # Window sizes from the module globals, with defaults
        input_window = globals().get('INPUT_WINDOW', 48)
        horizon = globals().get('HORIZON', 12)

        if input_window <= 0 or horizon <= 0:
            print(f"⚠️ Valores inválidos: INPUT_WINDOW={input_window}, HORIZON={horizon}")
            input_window = max(1, input_window)
            horizon = max(1, horizon)

        # NOTE(review): a training window needs input_window + horizon
        # consecutive steps inside train_mask, which spans only 4 calendar
        # years. If the time axis is monthly (48 steps) and the window is
        # 48 + 12, no training window can exist and the zero-filled
        # placeholder below is always used - confirm the intended split.
        try:
            (X_train_windows, Y_train_windows,
             X_val_windows, Y_val_windows) = TFPrecipitationDataset._build_windows(
                X, train_mask, val_mask, input_window, horizon)
        except Exception as e:
            print(f"❌ Error general generando ventanas: {e}")
            X_train_windows, Y_train_windows = [], []
            X_val_windows, Y_val_windows = [], []

        sample_shape = X.shape[1:]  # per-step shape, including the feature axis

        if not X_train_windows or not Y_train_windows:
            print("❌ No se pudieron generar ventanas de entrenamiento")
            X_train_windows = [np.zeros((input_window, *sample_shape), dtype=np.float32)]
            Y_train_windows = [np.zeros((horizon, *sample_shape[:-1], 1), dtype=np.float32)]

        if not X_val_windows or not Y_val_windows:
            print("❌ No se pudieron generar ventanas de validación")
            X_val_windows = [np.zeros((input_window, *sample_shape), dtype=np.float32)]
            Y_val_windows = [np.zeros((horizon, *sample_shape[:-1], 1), dtype=np.float32)]

        try:
            stacked = {}
            for arr_name, windows in (("X_train", X_train_windows), ("Y_train", Y_train_windows),
                                      ("X_val", X_val_windows), ("Y_val", Y_val_windows)):
                arr = np.array(windows)
                if np.isnan(arr).any():
                    print(f"⚠️ NaNs detectados en {arr_name} después de conversión, reemplazando con 0")
                    arr = np.nan_to_num(arr, nan=0.0)
                # Store the cleaned array so the replacement actually reaches
                # the data handed to TensorFlow
                stacked[arr_name] = arr
            X_train, Y_train = stacked["X_train"], stacked["Y_train"]
            X_val, Y_val = stacked["X_val"], stacked["Y_val"]
        except Exception as e:
            print(f"❌ Error convirtiendo listas a arrays: {e}")
            X_train = np.zeros((1, input_window, *sample_shape), dtype=np.float32)
            Y_train = np.zeros((1, horizon, *sample_shape[:-1], 1), dtype=np.float32)
            X_val = np.zeros((1, input_window, *sample_shape), dtype=np.float32)
            Y_val = np.zeros((1, horizon, *sample_shape[:-1], 1), dtype=np.float32)

        print(f"Forma de datos entrenamiento: X={X_train.shape}, Y={Y_train.shape}")
        print(f"Forma de datos validación: X={X_val.shape}, Y={Y_val.shape}")

        try:
            train_dataset = TFPrecipitationDataset.create_tf_dataset(
                X_train, Y_train, batch_size=batch_size, shuffle=True)

            val_dataset = TFPrecipitationDataset.create_tf_dataset(
                X_val, Y_val, batch_size=batch_size, shuffle=False)

            return train_dataset, val_dataset, num_features

        except Exception as e:
            print(f"❌ Error creando TensorFlow datasets: {e}")
            print("🆘 Creando datasets artificiales de emergencia")

            # Small placeholder arrays to avoid running out of memory
            reduced_features = min(num_features, 5)
            X_dummy = np.zeros((10, input_window, 10, 10, reduced_features), dtype=np.float32)
            Y_dummy = np.zeros((10, horizon, 10, 10, 1), dtype=np.float32)

            dummy_train = tf.data.Dataset.from_tensor_slices((X_dummy, Y_dummy)).batch(batch_size)
            dummy_val = tf.data.Dataset.from_tensor_slices((X_dummy, Y_dummy)).batch(batch_size)

            return dummy_train, dummy_val, reduced_features
def build_spatial_dataloaders(val_year, use_lags=True, batch_size=32, flatten_spatial=False):
    """
    Build dataloaders for 3D spatial data (time, height, width) with robust
    error handling.

    The dataset path is read from the module global DATASET_PATH; when it is
    not defined it is set from FULL_NC if that global exists, otherwise to
    './data/complete_dataset.nc'. Features that are missing or cannot be
    processed are replaced with zeros, and zero-filled placeholder windows
    are produced when no real window can be built.

    Args:
        val_year: Validation year
        use_lags: If True, include lag variables
        batch_size: Batch size
        flatten_spatial: If True, flatten the spatial dimensions for 1D models

    Returns:
        Tuple of (train_dataset, val_dataset, num_features), where
        num_features is the size of the last axis of the model input: the
        number of stacked features, or height*width*features when
        flatten_spatial is True.

    Raises:
        Exception: whatever importing xarray or opening the dataset raises.
    """
    global DATASET_PATH
    if 'DATASET_PATH' not in globals():
        if 'FULL_NC' in globals():
            DATASET_PATH = str(FULL_NC)
        else:
            DATASET_PATH = './data/complete_dataset.nc'

    print(f"Cargando datos espaciales desde {DATASET_PATH}")
    try:
        import xarray as xr
        ds = xr.open_dataset(DATASET_PATH)
    except Exception as e:
        print(f"❌ Error cargando dataset: {e}")
        raise

    # Select the feature list according to use_lags
    if use_lags:
        if 'FULL_FEATURES' in globals():
            feature_names = FULL_FEATURES
        else:
            feature_names = ['total_precipitation', 'total_precipitation_lag1',
                             'total_precipitation_lag2', 'total_precipitation_lag12',
                             'month_sin', 'month_cos', 'doy_sin', 'doy_cos',
                             'elevation', 'slope', 'roughness', 'curvature', 'aspect',
                             'cluster_elevation']
    else:
        if 'BASE_FEATURES' in globals():
            # Drop the lag variables from the base set
            feature_names = [f for f in BASE_FEATURES if 'lag' not in f]
        else:
            feature_names = ['total_precipitation',
                             'month_sin', 'month_cos', 'doy_sin', 'doy_cos',
                             'elevation', 'slope', 'aspect']

    # Training years: the 4 years preceding the validation year
    train_years = list(range(val_year - 4, val_year))

    try:
        try:
            times = pd.to_datetime(ds.time.values)
            years = times.year
            n_time = len(times)
        except Exception as e:
            print(f"⚠️ Error procesando fechas: {e}")
            # Placeholder years: first half training, second half validation.
            # n_time is kept separately because `times` does not exist here.
            n_time = len(ds['time']) if 'time' in ds.dims else 48
            years = np.zeros(n_time, dtype=np.int32)
            years[:n_time // 2] = val_year - 1
            years[n_time // 2:] = val_year

        train_mask = np.isin(years, train_years)
        val_mask = np.asarray(years == val_year)

        train_samples = np.sum(train_mask)
        val_samples = np.sum(val_mask)
        print(f"Muestras de entrenamiento: {train_samples}, Validación: {val_samples}")

        # Load every feature; None marks one that must be replaced with zeros
        loaded = []
        print(f"Procesando {len(feature_names)} características...")
        for feature in feature_names:
            arr = None
            if feature in ds.data_vars:
                try:
                    arr = ds[feature].values

                    if arr.ndim == 2:
                        print(f"⚠️ Característica espacial (no temporal): {feature}")
                        # Static 2D field: repeat it for every time step
                        arr = np.repeat(arr[np.newaxis, :, :], n_time, axis=0)

                    if arr.dtype != object and np.isnan(arr).any():
                        mean_val = np.nanmean(arr)
                        if np.isnan(mean_val):
                            mean_val = 0
                        arr = np.nan_to_num(arr, nan=mean_val)

                    arr = arr.astype(np.float32, copy=False)
                    print(f"✓ {feature}: forma {arr.shape}, tipo {arr.dtype}")
                except Exception as e:
                    print(f"❌ Error procesando {feature}: {e}")
                    arr = None
            else:
                print(f"⚠️ Característica no encontrada: {feature}, usando ceros")
            loaded.append(arr)
    finally:
        # Every array needed below is already in memory
        ds.close()

    # Common shape: the first feature that was actually loaded, so a missing
    # first feature does not impose the placeholder grid on the real data.
    target_shape = next((arr.shape for arr in loaded if arr is not None),
                        (n_time, 61, 65))

    feature_arrays = []
    for i, arr in enumerate(loaded):
        if arr is None:
            arr = np.zeros(target_shape, dtype=np.float32)
        elif arr.shape != target_shape:
            print(f"⚠️ Redimensionando feature {i} de {arr.shape} a {target_shape}")
            arr = np.zeros(target_shape, dtype=np.float32)
        feature_arrays.append(arr)

    # Stack features on the last axis
    X = np.stack(feature_arrays, axis=-1)
    print(f"Dataset completo: {X.shape}")

    # Window sizes from the module globals, with defaults
    input_window = globals().get('INPUT_WINDOW', 48)
    horizon = globals().get('HORIZON', 12)
    span = input_window + horizon

    train_windows_x = []
    train_windows_y = []
    val_windows_x = []
    val_windows_y = []

    print(f"Creando ventanas con {input_window} pasos de entrada y {horizon} de horizonte...")

    # Only visit windows that both X and the masks fully cover
    n_steps = min(len(X), len(train_mask), len(val_mask))
    n_windows = n_steps - span + 1

    for start in range(n_windows):
        if start % 50 == 0:  # progress every 50 windows
            print(f"Procesando ventana {start}/{n_windows}")

        split = start + input_window
        end = start + span

        # Training: inputs and targets must all lie in the training years
        in_train = bool(np.all(train_mask[start:end]))
        # Validation: only the targets must lie in the validation year
        in_val = bool(np.all(val_mask[split:end]))
        if not (in_train or in_val):
            continue

        x_window = X[start:split]
        y_window = X[split:end, :, :, 0:1]  # precipitation channel only

        if in_train:
            train_windows_x.append(x_window)
            train_windows_y.append(y_window)
        if in_val:
            val_windows_x.append(x_window)
            val_windows_y.append(y_window)

    spatial_dim = X.shape[1:3]  # (height, width)
    n_channels = X.shape[-1]

    # Placeholder windows when nothing could be built; inputs and targets are
    # always appended together so both lists keep the same length.
    if len(train_windows_x) == 0:
        print("❌ No se pudieron generar ventanas de entrenamiento, creando datos artificiales")
        train_windows_y = []
        for _ in range(10):
            train_windows_x.append(np.zeros((input_window, *spatial_dim, n_channels), dtype=np.float32))
            train_windows_y.append(np.zeros((horizon, *spatial_dim, 1), dtype=np.float32))
    if len(val_windows_x) == 0:
        print("❌ No se pudieron generar ventanas de validación, creando datos artificiales")
        val_windows_y = []
        for _ in range(5):
            val_windows_x.append(np.zeros((input_window, *spatial_dim, n_channels), dtype=np.float32))
            val_windows_y.append(np.zeros((horizon, *spatial_dim, 1), dtype=np.float32))

    # Convert the window lists to arrays
    X_train = np.stack(train_windows_x)
    Y_train = np.stack(train_windows_y)
    X_val = np.stack(val_windows_x)
    Y_val = np.stack(val_windows_y)

    print(f"Formas finales: X_train={X_train.shape}, Y_train={Y_train.shape}")
    print(f"               X_val={X_val.shape}, Y_val={Y_val.shape}")

    # Without flattening, the feature count is the size of the stacked axis
    num_features = n_channels

    # Flatten the spatial dimensions on request (for 1D models such as GRU)
    if flatten_spatial:
        print("🔄 Aplanando dimensiones espaciales para compatibilidad con modelos 1D")
        # X: [batch, time, height, width, features] -> [batch, time, height*width*features]
        n_train, time_steps, height, width, n_feat = X_train.shape
        n_val = X_val.shape[0]
        X_train = X_train.reshape(n_train, time_steps, height * width * n_feat)
        X_val = X_val.reshape(n_val, time_steps, height * width * n_feat)

        # Y: [batch, horizon, height, width, 1] -> [batch, horizon, height*width]
        Y_train = Y_train.reshape(n_train, horizon, height * width)
        Y_val = Y_val.reshape(n_val, horizon, height * width)

        num_features = height * width * n_feat

        print(f"Datos después de aplanar - X_train: {X_train.shape}, Y_train: {Y_train.shape}")
        print(f"                         - X_val: {X_val.shape}, Y_val: {Y_val.shape}")

    # Build the TF datasets
    train_dataset = tf.data.Dataset.from_tensor_slices((X_train, Y_train)).batch(batch_size).prefetch(tf.data.AUTOTUNE)
    val_dataset = tf.data.Dataset.from_tensor_slices((X_val, Y_val)).batch(batch_size).prefetch(tf.data.AUTOTUNE)

    return train_dataset, val_dataset, num_features

In [None]:
# ▶️ Train the baseline model (GRU-ED) on the most recent fold (F1)
print("="*80)
print("🚀 INICIANDO ENTRENAMIENTO DEL MODELO GRU-ED (FOLD F1)")
print("="*80)

# 1. Echo the configuration this cell depends on; a missing name fails here,
#    before any data is loaded or any model is built.
print(f"Ventana de entrada: {INPUT_WINDOW} meses")
print(f"Horizonte de predicción: {HORIZON} meses")
print(f"Experimentos disponibles: {list(EXPERIMENTS.keys())}")
print(f"Folds disponibles: {FOLDS}")
print(f"Modo MVP: {'Activo (solo fold F1)' if MVP_MODE else 'Inactivo (todos los folds)'}")

# 2. Select the GPU/CPU device
device_mgr = TFDeviceManager(prefer_gpu=True)
print(f"Dispositivo seleccionado: {device_mgr.device_name}")

# 3. Run the GRU-ED experiment on fold F1
experiment_name = 'GRU-ED'
fold_name = 'F1'
val_year = FOLDS[fold_name]


def create_gru_ed_model(input_dim, input_length, output_length, hidden_units=128,
                        dropout_rate=0.2, num_layers=2, output_dim=1):
    """Build and compile a GRU encoder-decoder for sequence-to-sequence regression.

    The encoder compresses the input window into its final hidden state. That
    state is repeated ``output_length`` times and decoded by a GRU stack whose
    layers are all initialised with the encoder state, so the prediction has
    one time step per forecast month instead of one per input month.

    Args:
        input_dim: Number of features per input time step.
        input_length: Number of input time steps.
        output_length: Number of time steps to predict.
        hidden_units: Width of every GRU layer.
        dropout_rate: Input dropout applied to all but the last GRU of each stack.
        num_layers: Number of GRU layers in the encoder and in the decoder.
        output_dim: Number of values predicted per output time step.

    Returns:
        A compiled Keras model mapping ``(batch, input_length, input_dim)`` to
        ``(batch, output_length, output_dim)``.
    """
    from tensorflow.keras import layers, Model

    inputs = layers.Input(shape=(input_length, input_dim))

    # Encoder: intermediate layers keep the full sequence, the last one only
    # hands over its final state.
    encoded = inputs
    for _ in range(num_layers - 1):
        encoded = layers.GRU(hidden_units, return_sequences=True,
                             dropout=dropout_rate)(encoded)
    _, state_h = layers.GRU(hidden_units, return_state=True)(encoded)

    # Decoder: one copy of the context vector per forecast step.
    decoded = layers.RepeatVector(output_length)(state_h)
    for _ in range(num_layers - 1):
        decoded = layers.GRU(hidden_units, return_sequences=True,
                             dropout=dropout_rate)(decoded, initial_state=state_h)
    decoded = layers.GRU(hidden_units, return_sequences=True)(
        decoded, initial_state=state_h)

    # Per-step linear projection to the target width.
    outputs = layers.TimeDistributed(
        layers.Dense(output_dim, activation='linear'))(decoded)

    model = Model(inputs, outputs)
    # 'mse' is also tracked as a metric because the RMSE below is read from
    # history['val_mse'].
    model.compile(optimizer='adam', loss='mse', metrics=['mse'])
    return model


try:
    # Reuse a previous result for this experiment/fold when one exists
    checkpoint_name = f"{experiment_name}_{fold_name}_result"
    checkpoint_data = SafeExecution.load_checkpoint(checkpoint_name)

    if checkpoint_data:
        _, _, best_rmse = checkpoint_data
        print(f"✅ Modelo ya entrenado anteriormente. RMSE: {best_rmse:.4f}")
    else:
        print(f"🔄 Preparando datos para el fold {fold_name} (validación: {val_year})")
        # Smaller batches on CPU
        batch_size = 8 if device_mgr.is_gpu() else 4
        train_dataset, val_dataset, in_dim = TFPrecipitationDataset.build_dataloaders(
            val_year, EXPERIMENTS[experiment_name]['use_lags'], batch_size)

        # The model must emit exactly what the targets contain, so the output
        # width is read from the dataset instead of being guessed.
        # NOTE(review): assumes targets are batched as (batch, horizon, cells),
        # i.e. the loader flattens the spatial dimensions - confirm.
        out_dim = train_dataset.element_spec[1].shape[-1]
        if out_dim is None:
            raise ValueError("No se pudo determinar la dimensión de salida del dataset")

        model_config = {
            'input_dim': in_dim,
            'input_length': INPUT_WINDOW,
            'output_length': HORIZON,
            'hidden_units': 128,
            'num_layers': 2,
            'dropout_rate': 0.20,  # 0.20 for F1-F3, 0.25 for F4-F5
            'output_dim': int(out_dim),
        }

        print(f"🧠 Creando modelo {experiment_name}")
        model = create_gru_ed_model(**model_config)

        print(f"🏋️ Entrenando modelo {experiment_name} en fold {fold_name}")
        early_stop = tf.keras.callbacks.EarlyStopping(
            monitor='val_loss', patience=20, restore_best_weights=True, verbose=1
        )
        history = model.fit(
            train_dataset,
            validation_data=val_dataset,
            epochs=60,
            callbacks=[early_stop],
            verbose=1
        ).history

        # Best validation RMSE over all epochs
        val_mse = history['val_mse']
        best_epoch = np.argmin(val_mse)
        best_rmse = float(np.sqrt(val_mse[best_epoch]))

        print(f"💾 Guardando resultados")
        SafeExecution.save_checkpoint(
            (model, history, best_rmse),
            checkpoint_name
        )

        # Release device memory held by this run
        device_mgr.clear_memory()

    # Summary of this run
    print("\n📊 RESULTADOS DEL EXPERIMENTO:")
    print(f"Modelo: {experiment_name}")
    print(f"Fold: {fold_name} (Validación: {val_year})")
    print(f"RMSE: {best_rmse:.4f}")

except Exception as e:
    # Notebook boundary: report and keep going so the results table is still shown
    print(f"❌ Error durante la ejecución: {str(e)}")

# 4. Table with every result available so far
print("\n📋 RESUMEN DE TODOS LOS EXPERIMENTOS EJECUTADOS:")
show_results()

# Visualización de Resultados de Predicción

Los gráficos a continuación muestran la evolución del error durante el entrenamiento y ejemplos de predicciones del modelo GRU-ED comparadas con los valores reales de precipitación.

In [None]:
# ▶️ Visualizar resultados del entrenamiento
import matplotlib.pyplot as plt
import numpy as np

def plot_training_history(exp_name, fold):
    """Plot the learning curves stored in the checkpoint of one experiment/fold.

    Loads ``"{exp_name}_{fold}_result"`` through ``SafeExecution`` and draws
    two side-by-side panels: loss and RMSE (square root of the stored MSE),
    each with the validation curve when the history contains it. Any failure
    is reported with a printed message instead of being raised.

    Args:
        exp_name: Experiment name used in the checkpoint name.
        fold: Fold identifier used in the checkpoint name.
    """
    try:
        stored = SafeExecution.load_checkpoint(f"{exp_name}_{fold}_result")
        if not stored:
            print(f"❌ No se encontró historial para {exp_name} en fold {fold}")
            return

        _, history, best_rmse = stored
        _, (loss_ax, rmse_ax) = plt.subplots(1, 2, figsize=(14, 5))

        # Left panel: loss
        loss_ax.plot(history['loss'], label='Train')
        if 'val_loss' in history:
            loss_ax.plot(history['val_loss'], label='Validation')
        loss_ax.set_xlabel('Epochs')
        loss_ax.set_ylabel('Loss')
        loss_ax.set_title(f'{exp_name} - Fold {fold} - Loss')
        loss_ax.legend()
        loss_ax.grid(True, linestyle='--', alpha=0.6)

        # Right panel: RMSE derived from the tracked MSE metric
        rmse_ax.plot(np.sqrt(history['mse']), label='Train RMSE')
        if 'val_mse' in history:
            rmse_ax.plot(np.sqrt(history['val_mse']), label='Validation RMSE')
        rmse_ax.set_xlabel('Epochs')
        rmse_ax.set_ylabel('RMSE')
        rmse_ax.set_title(f'{exp_name} - Fold {fold} - RMSE')
        rmse_ax.legend()
        rmse_ax.grid(True, linestyle='--', alpha=0.6)

        plt.tight_layout()
        plt.show()

        print(f"Mejor RMSE en validación: {best_rmse:.4f}")

    except Exception as e:
        print(f"❌ Error al visualizar historial: {e}")

# Show the learning curves of the GRU-ED model on fold F1
print("📊 Curvas de aprendizaje para GRU-ED (Fold F1)")
plot_training_history('GRU-ED', 'F1')

# If more than one model has been trained, plot a comparison
try:
    # Collect the RMSE of every stored result
    results = []
    for file in CHECKPOINT_DIR.glob("*_result.pkl"):
        try:
            # Split from the right so experiment names containing '_' stay intact:
            # "<exp>_<fold>_result" -> (exp, fold, "result")
            exp, fold, _ = file.stem.rsplit('_', 2)
            checkpoint = SafeExecution.load_checkpoint(f"{exp}_{fold}_result")
            if checkpoint:
                _, _, rmse = checkpoint
                results.append({
                    'exp': exp,
                    'fold': fold,
                    'rmse': rmse
                })
        except Exception as err:
            # One unreadable checkpoint must not hide the others, but say so
            print(f"⚠️ Checkpoint omitido ({file.name}): {err}")
            continue

    # Only worth plotting when at least two different models are present
    if len({r['exp'] for r in results}) > 1:
        plt.figure(figsize=(10, 6))

        # One line per experiment, folds on the x axis
        import pandas as pd
        df = pd.DataFrame(results)
        exp_groups = df.groupby('exp')

        for exp_name, group in exp_groups:
            group = group.sort_values('fold')
            plt.plot(group['fold'], group['rmse'], 'o-', label=exp_name)

        plt.title('Comparación de RMSE por Modelo y Fold')
        plt.xlabel('Fold')
        plt.ylabel('RMSE')
        plt.grid(True, linestyle='--', alpha=0.6)
        plt.legend()
        plt.tight_layout()
        plt.show()
except Exception as e:
    print(f"No se pudo generar comparativa: {e}")

In [None]:
# ▶️ Definir función para ejecutar experimentos con manejo robusto de errores
def _record_experiment_result(results, exp_name, fold, history, best_rmse):
    """Append one (experiment, fold) outcome to ``results`` and to the optional global registries."""
    entry = {'exp': exp_name, 'fold': fold, 'rmse': best_rmse}
    results.append(entry)

    # RESULTS / ALL_HISTORIES are notebook-level registries; they are only
    # updated when an earlier cell has created them.
    if 'RESULTS' in globals():
        RESULTS.append(dict(entry))
    if 'ALL_HISTORIES' in globals():
        ALL_HISTORIES.setdefault(exp_name, {})[fold] = history


def run_tf_experiments(folds_to_run=None, force_cpu=False, fail_safe=True):
    """
    Run every configured experiment on the selected folds, reusing checkpoints.

    For each (experiment, fold) pair a stored ``"<exp>_<fold>_result"``
    checkpoint is reused when present; otherwise the data is prepared, the
    model is trained and the result is checkpointed. A failure in one pair is
    printed and does not stop the remaining pairs.

    Args:
        folds_to_run: Fold names to run. ``None`` means ``['F1']`` in MVP mode
            and every fold otherwise. Unknown fold names are ignored, and in
            MVP mode only ``'F1'`` is ever kept.
        force_cpu: Currently only changes the device shown in the log line;
            no device placement is done here.
        fail_safe: Accepted for compatibility; error handling is always on.

    Returns:
        List of ``{'exp', 'fold', 'rmse'}`` dicts, one per pair that has a
        result (reused or newly trained). Empty when the required globals are
        missing or nothing succeeded.
    """
    results = []

    # The experiment and fold tables are defined by earlier notebook cells
    if 'EXPERIMENTS' not in globals() or 'FOLDS' not in globals():
        print("❌ Variables globales EXPERIMENTS y/o FOLDS no definidas")
        return results

    # Resolve the folds, keeping the selection consistent with MVP_MODE
    if folds_to_run is None:
        folds_to_run = ['F1'] if MVP_MODE else list(FOLDS.keys())
    else:
        valid_folds = [f for f in folds_to_run if f in FOLDS]
        if MVP_MODE:
            # MVP mode only allows F1, even when other folds were requested
            folds_to_run = ['F1'] if 'F1' in valid_folds else []
            if not folds_to_run and valid_folds:
                print("⚠️ MVP_MODE activo: Solo se permite el fold F1")
        else:
            folds_to_run = valid_folds

    print(f"🔄 Ejecutando experimentos en {'CPU' if force_cpu else 'GPU/CPU'}")
    print(f"   Folds: {', '.join(folds_to_run)}")
    print(f"   Modo MVP: {'✅ Activo (solo F1)' if MVP_MODE else '❌ Inactivo'}")

    # MVP mode restricts the folds, not the models: every experiment runs in
    # both modes (all models on F1, or all models on all folds).
    exps_to_run = list(EXPERIMENTS)

    for exp_name in exps_to_run:
        for fold in folds_to_run:
            checkpoint_name = f"{exp_name}_{fold}_result"

            # Reuse a previous result when one exists
            checkpoint_data = SafeExecution.load_checkpoint(checkpoint_name)
            if checkpoint_data:
                _, history, best_rmse = checkpoint_data
                print(f"✅ Using previous result: RMSE = {best_rmse:.4f}")
                _record_experiment_result(results, exp_name, fold, history, best_rmse)
                continue

            # No checkpoint: prepare the data and train
            try:
                # Free memory before starting
                clean_memory()

                print(f"🔄 Preparing data for fold {fold}")
                cfg = EXPERIMENTS[exp_name]
                val_year = FOLDS[fold]

                # Half the configured batch size, but never below 8
                batch_size = max(8, BATCH_SIZE // 2)
                train_loader, val_loader, in_dim = TFPrecipitationDataset.build_dataloaders(
                    val_year, cfg['use_lags'], batch_size)

                # Slightly stronger dropout on folds F4/F5
                dropout = 0.25 if fold in ['F4', 'F5'] else 0.20

                model = MODEL_FACTORY[cfg['model']](in_dim, dropout=dropout)

                print(f"🔄 Training {exp_name} on fold {fold}")
                try:
                    model, history, best_rmse = train_with_history(
                        model, train_loader, val_loader,
                        epochs=60, patience=20,
                        lr=1e-3, weight_decay=1e-4,
                        fold=fold, exp_name=exp_name
                    )

                    SafeExecution.save_checkpoint(
                        (model, history, best_rmse),
                        checkpoint_name
                    )

                    _record_experiment_result(results, exp_name, fold, history, best_rmse)
                    print(f"✅ Training completed: RMSE = {best_rmse:.4f}")

                except Exception as e:
                    print(f"❌ Error in training: {e}")

            except Exception as e:
                print(f"❌ Error in experiment {exp_name}, fold {fold}: {e}")
                continue

        print(f"✅ Experiment {exp_name} completed")

    return results

# ▶️ Run all experiments in robust mode
print("=" * 80)
print("🚀 INICIANDO EJECUCIÓN DE TODOS LOS EXPERIMENTOS EN MODO ROBUSTO")
print("=" * 80)

try:
    # Switch to True to request CPU execution (e.g. to avoid memory problems)
    force_cpu = False

    # folds_to_run=None lets the MVP mode decide which folds are executed
    results = run_tf_experiments(folds_to_run=None, force_cpu=force_cpu, fail_safe=True)

    if not results:
        print("❌ No se obtuvieron resultados")
    else:
        import pandas as pd
        from IPython.display import display

        df = pd.DataFrame(results)
        print("\n📊 RESULTADOS DE LA EJECUCIÓN EN MODO ROBUSTO:")
        display(df)

except Exception as err:
    print(f"❌ Error no capturado: {err}")

    # Detailed diagnosis for anything the runner itself did not handle
    import traceback
    print("\n🔍 Detalles del error:")
    traceback.print_exc()

    print("\n💡 Recomendación: Ejecute con estos parámetros para máxima compatibilidad:")
    print("   run_tf_experiments(['F1'], force_cpu=True, fail_safe=True)")

In [None]:
# ▶️ Control del Modo MVP (Minimum Viable Product)
# Esta celda permite activar/desactivar fácilmente el modo MVP

def set_mvp_mode(active=True):
    """
    Switch the global MVP (minimum viable product) training mode.

    Args:
        active: When truthy only the most recent fold (F1) is trained;
                otherwise all folds are trained.

    Returns:
        The new value of the global ``MVP_MODE``.
    """
    global MVP_MODE
    previous = MVP_MODE
    MVP_MODE = active

    # Pick the three status fragments for the mode that is now in effect
    if MVP_MODE:
        state = '✅ ACTIVADO'
        folds_msg = 'Solo se entrenará el fold F1 (más reciente, año 2024)'
        models_msg = 'Se ejecutarán los 5 modelos en ese fold'
    else:
        state = '❌ DESACTIVADO'
        folds_msg = 'Se entrenan todos los folds (F1-F5)'
        models_msg = 'Se ejecutarán los 5 modelos en todos los folds'

    print(f"Modo MVP: {state}")
    print(f"  • {folds_msg}")
    print(f"  • {models_msg}")

    # Only mention a transition when the value actually changed
    if previous != MVP_MODE:
        print(f"⚠️ El modo ha cambiado de {previous} a {MVP_MODE}")

    return MVP_MODE

# By default the notebook starts in MVP mode (True).
# To run every experiment on every fold, execute:
# set_mvp_mode(False)

# Report the current mode; MVP_MODE is only read here, not assigned.
print(f"Estado actual: Modo MVP {'✅ ACTIVO' if MVP_MODE else '❌ INACTIVO'}")

In [None]:
# ▶️ Verify the one-hot encoding of cluster_elevation
print(
    "=" * 80,
    "🔍 VERIFICACIÓN DE CODIFICACIÓN ONE-HOT PARA CLUSTER_ELEVATION",
    "=" * 80,
    sep="\n",
)

import xarray as xr
import numpy as np
import tensorflow as tf

def robust_array_processing(data_array, feature_name):
    """
    One-hot encode an integer-typed array that looks categorical.

    An array is encoded only when its dtype is an integer type, it has at most
    20 distinct values and it is 2-D or 3-D. In every other case, and on any
    error, the input object itself is returned unchanged.

    Args:
        data_array (numpy.ndarray): The input data array
        feature_name (str): Name of the feature, used in the printed messages

    Returns:
        numpy.ndarray: float32 array with one extra trailing axis holding the
        one-hot channels (ordered by sorted unique value), or ``data_array``
        itself when no encoding is applied.
    """
    try:
        # Non-integer data is never treated as categorical
        if not np.issubdtype(data_array.dtype, np.integer):
            print(f"Feature {feature_name} with dtype {data_array.dtype} doesn't need one-hot encoding.")
            return data_array

        categories = np.unique(data_array)
        num_categories = len(categories)

        # More than 20 distinct values is not considered categorical
        if num_categories > 20:
            print(f"⚠️ Too many unique values ({num_categories}) for one-hot encoding {feature_name}. Returning original.")
            return data_array

        print(f"Applying one-hot encoding to {feature_name} with {num_categories} categories")

        # Only 2-D and 3-D inputs are supported
        if data_array.ndim not in (2, 3):
            print(f"⚠️ Unexpected shape for {feature_name}: {data_array.shape}. Skipping one-hot encoding.")
            return data_array

        one_hot = np.zeros((*data_array.shape, num_categories), dtype=np.float32)
        for channel, category in enumerate(categories):
            # Mark, in this category's channel, every cell holding the category
            one_hot[data_array == category, channel] = 1.0
        return one_hot

    except Exception as e:
        print(f"❌ Error processing {feature_name}: {e}")
        print("Returning original array without processing")
        return data_array

# Función para mostrar una muestra de datos categóricos y su codificación one-hot
def verify_one_hot_encoding(dataset_path, feature_name='cluster_elevation'):
    """Print a sample of a categorical variable together with its one-hot encoding.

    The variable is read from the dataset at ``dataset_path``, a 10x10 corner
    of it is shown (first time step for 3-D variables), and
    ``robust_array_processing`` is applied. When the encoding added an axis,
    one example cell per unique value is printed next to its one-hot vector.
    Errors are printed with a traceback instead of being raised.

    Args:
        dataset_path: Path of the dataset opened with ``xr.open_dataset``.
        feature_name: Name of the data variable to inspect.
    """
    try:
        print(f"Cargando dataset desde {dataset_path}...")

        # The values are copied out inside the context manager so the file
        # handle is released on every path, including the early return.
        with xr.open_dataset(dataset_path) as ds:
            if feature_name not in ds.data_vars:
                print(f"❌ La característica '{feature_name}' no está en el dataset")
                return
            raw_data = ds[feature_name].values

        print(f"Datos cargados: {feature_name}, forma {raw_data.shape}, tipo {raw_data.dtype}")

        # 2-D variables are used as they are; anything else is treated as
        # (time, lat, lon) and inspected at the first time step.
        is_2d = raw_data.ndim == 2
        raw_slice = raw_data if is_2d else raw_data[0]

        print("\n📊 MUESTRA DE DATOS ORIGINALES:")
        print(raw_slice[:10, :10])

        unique_values = np.unique(raw_data)
        print(f"\nValores únicos encontrados ({len(unique_values)}): {unique_values}")

        print("\n🔄 Aplicando one-hot encoding...")
        encoded_data = robust_array_processing(raw_data, feature_name)

        # An extra axis is the sign that the encoding was applied
        if encoded_data.ndim > raw_data.ndim:
            print(f"✅ One-hot encoding aplicado correctamente: forma {encoded_data.shape}")

            print("\n📊 EJEMPLOS DE CODIFICACIÓN ONE-HOT:")
            print(f"{'Valor original':<15} | {'Codificación one-hot'}")
            print("-"*40)

            encoded_slice = encoded_data if is_2d else encoded_data[0]
            for val in unique_values:
                # First cell of the inspected slice holding this value, if any
                coords = np.where(raw_slice == val)
                if len(coords[0]) > 0:
                    y, x = coords[0][0], coords[1][0]
                    original = raw_slice[y, x]
                    # Plain Python floats print the same on every numpy version
                    encoding = encoded_slice[y, x].tolist()
                    print(f"{original:<15} | {encoding}")

            print("\n✅ Codificación one-hot verificada correctamente")
        else:
            print(f"⚠️ No se realizó one-hot encoding: forma {encoded_data.shape}")

    except Exception as e:
        print(f"❌ Error verificando one-hot encoding: {e}")
        import traceback
        traceback.print_exc()

# Use the notebook-wide dataset path when defined, otherwise the default file
dataset_path = globals().get('FULL_NC', './data/complete_dataset.nc')
verify_one_hot_encoding(dataset_path)

def build_compatible_dataloaders(val_year, use_lags=True, batch_size=32, flatten_spatial=True):
    """
    Build simplified TensorFlow datasets for training and validation.

    Every data variable except ``'precipitation'`` is passed through
    ``robust_array_processing`` (which may one-hot encode it), flattened to
    one row per time step and concatenated into ``X``. Time steps whose year
    equals ``val_year`` form the validation split; all others form the
    training split.

    Args:
        val_year (int): Year to use for validation
        use_lags (bool): Accepted for compatibility; lag handling is not
            implemented here
        batch_size (int): Batch size of both datasets
        flatten_spatial (bool): Accepted for compatibility; features are
            always flattened per time step

    Returns:
        tuple: (train_dataset, val_dataset, num_features), where
        ``num_features`` counts one channel per variable, or one per category
        for one-hot encoded variables.
    """
    dataset_path = FULL_NC if 'FULL_NC' in globals() else './data/complete_dataset.nc'

    # Everything needed is copied to numpy inside the context manager so the
    # file handle is always released.
    with xr.open_dataset(dataset_path) as ds:
        n_time = len(ds['time'])
        years = ds['time'].dt.year.values
        y = ds['precipitation'].values

        flat_features = []
        num_features = 0
        for var_name in ds.data_vars:
            if var_name == 'precipitation':  # Exclude target
                continue
            variable = ds[var_name]
            processed = robust_array_processing(variable.values, var_name)

            # An extra trailing axis means the variable was one-hot encoded
            num_features += processed.shape[-1] if processed.ndim > variable.ndim else 1

            if 'time' in variable.dims:
                # NOTE(review): assumes 'time' is the leading dimension - confirm
                flat_features.append(processed.reshape(n_time, -1))
            else:
                # Static variables have no time axis: repeat them for every step
                flat_features.append(
                    np.broadcast_to(processed.reshape(1, -1), (n_time, processed.size)))

    # Handle lag features if requested
    if use_lags:
        # Implementation for lag features would go here
        pass

    # Plain numpy boolean masks for the year-based split
    train_mask = years != val_year
    val_mask = years == val_year

    X = np.concatenate(flat_features, axis=1)

    X_train, y_train = X[train_mask], y[train_mask]
    X_val, y_val = X[val_mask], y[val_mask]

    train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train)).batch(batch_size)
    val_dataset = tf.data.Dataset.from_tensor_slices((X_val, y_val)).batch(batch_size)

    return train_dataset, val_dataset, num_features

# Modificar TFPrecipitationDataset para integrar normalización
class NormalizedTFPrecipitationDataset(TFPrecipitationDataset):
    """TFPrecipitationDataset variant that also hands back a fitted normalizer."""

    @staticmethod
    def build_dataloaders(val_year, use_lags=True, batch_size=32, normalize=True):
        """
        Build the base dataloaders and, optionally, fit a normalizer on the dataset.

        The loaders themselves are the ones produced by
        ``TFPrecipitationDataset.build_dataloaders``; they are returned
        untouched. The normalizer is only fitted and returned so callers can
        use it later (e.g. to de-normalize during evaluation/visualisation).

        Args:
            val_year: Validation year
            use_lags: Whether to use lag variables
            batch_size: Batch size
            normalize: Whether to fit and return a normalizer

        Returns:
            Tuple of (train_loader, val_loader, num_features, normalizer);
            ``normalizer`` is None when ``normalize`` is false or fitting failed.
        """
        base_loaders = TFPrecipitationDataset.build_dataloaders(
            val_year, use_lags, batch_size)
        train_loader, val_loader, num_features = base_loaders

        fitted = None
        if normalize:
            # Simplified version: the already-built loaders are not modified;
            # applying the normalization would require doing it at build time.
            try:
                source_path = globals().get('FULL_NC', './data/complete_dataset.nc')
                ds = xr.open_dataset(source_path)
                candidate = DataNormalizer()
                candidate.fit(ds)
                print("✅ Normalizador ajustado a datos (para uso posterior)")
                fitted = candidate
            except Exception as e:
                print(f"❌ Error creando normalizador: {e}")

        return train_loader, val_loader, num_features, fitted

In [None]:
# ▶️ Implementación de AdamW y One-Cycle Scheduler
import tensorflow as tf
import numpy as np
import math

class OneCycleLR(tf.keras.callbacks.Callback):
    """
    One-Cycle learning-rate scheduler.

    Implements the one-cycle policy from "Super-Convergence: Very Fast Training
    of Neural Networks Using Large Learning Rates", updated once per batch:

    - the LR rises from ``min_lr = max_lr / div_factor`` to ``max_lr`` during
      the first ``pct_start`` fraction of all training steps;
    - it then falls from ``max_lr`` to ``final_lr = min_lr / 1000`` over the
      remaining steps;
    - the momentum (``beta_1`` for Adam-like optimizers) moves inversely,
      0.95 -> 0.85 -> 0.95.

    The LR and momentum applied at every batch are kept in ``self.history``.
    """

    def __init__(self, max_lr, steps_per_epoch, epochs, div_factor=25.,
                 pct_start=0.3, anneal_strategy='cos', three_phase=False):
        """
        Initialise the scheduler.

        Args:
            max_lr: Peak learning rate
            steps_per_epoch: Number of batches per epoch
            epochs: Total number of epochs
            div_factor: ``max_lr / div_factor`` is the starting learning rate
            pct_start: Fraction of all steps spent increasing the learning rate
            anneal_strategy: 'cos' for cosine interpolation; any other value
                selects linear interpolation
            three_phase: Stored but not used by the schedule
        """
        super().__init__()
        self.max_lr = max_lr
        self.steps_per_epoch = steps_per_epoch
        self.total_epochs = epochs
        self.div_factor = div_factor
        self.min_lr = max_lr / div_factor
        self.final_lr = self.min_lr / 1000
        self.pct_start = pct_start
        self.anneal_strategy = anneal_strategy
        self.three_phase = three_phase

        # Length of the whole cycle in batches; at least 1 so the phase
        # lengths below are always well defined.
        self.total_steps = max(1, int(steps_per_epoch) * int(epochs))

        # Phase boundaries
        self.step_size_up = int(self.total_steps * self.pct_start)
        self.step_size_down = self.total_steps - self.step_size_up

        # Momentum bounds (defaults)
        self.initial_momentum = 0.95
        self.final_momentum = 0.85

        # Step counter and per-batch record
        self.step_counter = 0
        self.history = {'lr': [], 'momentum': []}

    def _annealing_cos(self, start, end, pct):
        """Cosine interpolation from ``start`` (pct=0) to ``end`` (pct=1)."""
        cos_out = math.cos(math.pi * pct) + 1
        return end + (start - end) / 2.0 * cos_out

    def _annealing_linear(self, start, end, pct):
        """Linear interpolation from ``start`` (pct=0) to ``end`` (pct=1)."""
        return (end - start) * pct + start

    @staticmethod
    def _set_hyperparameter(optimizer, names, value):
        """Set the first attribute in ``names`` found on ``optimizer``; return True on success.

        Variable-like attributes are updated in place with ``assign``; plain
        Python values are replaced with ``setattr``.
        NOTE(review): replacing a plain float may have no effect once the
        training step has been traced - confirm for the optimizer in use.
        """
        for name in names:
            current = getattr(optimizer, name, None)
            if current is None:
                continue
            try:
                if hasattr(current, 'assign'):
                    current.assign(value)
                else:
                    setattr(optimizer, name, value)
                return True
            except (AttributeError, TypeError, ValueError):
                # e.g. a learning-rate schedule object that cannot be overwritten
                continue
        return False

    def get_lr_momentum(self):
        """Return the (lr, momentum) pair for the current step."""
        anneal = (self._annealing_cos if self.anneal_strategy == 'cos'
                  else self._annealing_linear)

        # Clamp so extra batches beyond the planned cycle hold the final values
        step = min(self.step_counter, self.total_steps)

        if self.step_size_up > 0 and step <= self.step_size_up:
            # Rising phase
            percent = step / self.step_size_up
            lr = anneal(self.min_lr, self.max_lr, percent)
            momentum = anneal(self.initial_momentum, self.final_momentum, percent)
        else:
            # Falling phase (also used from step 0 when there is no rising phase)
            if self.step_size_down > 0:
                percent = (step - self.step_size_up) / self.step_size_down
            else:
                percent = 1.0
            lr = anneal(self.max_lr, self.final_lr, percent)
            momentum = anneal(self.final_momentum, self.initial_momentum, percent)

        return lr, momentum

    def on_train_batch_begin(self, batch, logs=None):
        """Apply the scheduled LR and momentum before each batch."""
        lr, momentum = self.get_lr_momentum()
        optimizer = self.model.optimizer

        # 'learning_rate' is the current attribute name, 'lr' the legacy alias
        self._set_hyperparameter(optimizer, ('learning_rate', 'lr'), lr)

        # beta_1 for Adam/AdamW, momentum for SGD
        if hasattr(optimizer, 'beta_1'):
            self._set_hyperparameter(optimizer, ('beta_1',), momentum)
        elif hasattr(optimizer, 'momentum'):
            self._set_hyperparameter(optimizer, ('momentum',), momentum)

        self.step_counter += 1

        self.history['lr'].append(lr)
        self.history['momentum'].append(momentum)

    def on_epoch_end(self, epoch, logs=None):
        """Add the current LR and momentum to the epoch logs and print them every 5 epochs."""
        lr, momentum = self.get_lr_momentum()

        if logs is not None:
            logs['lr'] = lr
            logs['momentum'] = momentum

        if epoch % 5 == 0:
            print(f"Epoch {epoch}: lr={lr:.6f}, momentum={momentum:.6f}")

# Función para obtener un optimizador AdamW con weight decay configurable
def get_adamw_optimizer(learning_rate=1e-3, weight_decay=1e-4):
    """
    Build an AdamW optimizer, degrading gracefully on older TensorFlow builds.

    The AdamW class is looked up, in order, under ``tf.keras.optimizers``
    (main API, TF >= 2.11), ``tf.keras.optimizers.experimental`` and
    ``tf.keras.optimizers.legacy``. If none of them provides it, a plain Adam
    optimizer is returned and a warning is printed; in that case
    ``weight_decay`` is NOT applied.

    Args:
        learning_rate: Initial learning rate.
        weight_decay: Weight-decay coefficient passed to AdamW.

    Returns:
        A configured AdamW optimizer, or Adam when AdamW is unavailable.
    """
    # Hyper-parameters shared by AdamW and the Adam fallback.
    adam_kwargs = dict(
        learning_rate=learning_rate,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-7,
    )

    optimizers = tf.keras.optimizers
    candidate_namespaces = (
        optimizers,
        getattr(optimizers, 'experimental', None),
        getattr(optimizers, 'legacy', None),
    )
    for namespace in candidate_namespaces:
        # getattr with a default keeps the probing free of try/except and
        # also tolerates a missing sub-namespace (namespace is None).
        adamw_cls = getattr(namespace, 'AdamW', None)
        if adamw_cls is not None:
            return adamw_cls(weight_decay=weight_decay, **adam_kwargs)

    # Last resort: no AdamW anywhere. The warning states plainly that the
    # weight-decay term is dropped, since nothing re-adds it here.
    print("⚠️ AdamW no disponible, usando Adam estándar (weight_decay ignorado)")
    return optimizers.Adam(**adam_kwargs)

# Función para entrenar con AdamW + One-Cycle
def train_with_advanced_optim(model, train_loader, val_loader, epochs=60, patience=20,
                            max_lr=1e-3, weight_decay=1e-4, fold=None, exp_name=None):
    """
    Train a model with the AdamW optimizer and the One-Cycle LR scheduler.

    The model is compiled with the horizon-weighted Huber loss, trained with
    early stopping on ``val_loss`` (best weights restored) and the per-batch
    learning-rate / momentum trace of the scheduler is merged into the
    returned history.

    Args:
        model: Keras model to train (compiled inside this function).
        train_loader: Training data accepted by ``model.fit``.
        val_loader: Validation data accepted by ``model.fit``.
        epochs: Maximum number of epochs.
        patience: Early-stopping patience, in epochs.
        max_lr: Peak learning rate of the one-cycle schedule.
        weight_decay: Weight-decay coefficient for AdamW.
        fold: Current fold (used only in the log message).
        exp_name: Experiment name (used only in the log message).

    Returns:
        model: The trained model.
        history: ``History.history`` dict extended with ``'lr'`` and
            ``'momentum'`` (per-batch lists).
        best_rmse: Square root of the lowest ``val_mse`` seen.
    """
    import tensorflow as tf
    import numpy as np
    import time

    # Ratio between the peak LR and the starting LR of the cycle; used both
    # for the optimizer's initial LR and for the scheduler.
    div_factor = 25.

    print(f"🔄 Entrenando {exp_name} (fold {fold}) con AdamW y One-Cycle")

    # Number of batches per epoch. Prefer len(): it avoids a full pass over
    # the data and does not consume one-shot iterables. Loaders without a
    # usable length fall back to counting by iteration.
    try:
        steps_per_epoch = len(train_loader)
    except TypeError:
        steps_per_epoch = sum(1 for _ in train_loader)

    # If nothing could be counted, use a reasonable default.
    if steps_per_epoch == 0:
        steps_per_epoch = 10
        print(f"⚠️ No se pudo determinar steps_per_epoch, usando {steps_per_epoch}")

    # AdamW starts at the bottom of the cycle (max_lr / div_factor).
    optimizer = get_adamw_optimizer(learning_rate=max_lr / div_factor,
                                    weight_decay=weight_decay)

    model.compile(
        optimizer=optimizer,
        loss=huber_loss_with_horizon_weight,
        metrics=['mse', 'mae']
    )

    # One-Cycle LR schedule, updated every batch.
    onecycle = OneCycleLR(
        max_lr=max_lr,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        div_factor=div_factor,
        pct_start=0.3
    )

    # Early stopping on validation loss, restoring the best weights.
    early_stop = tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=patience,
        restore_best_weights=True,
        verbose=1
    )

    callbacks = [onecycle, early_stop]

    start_time = time.time()
    history = model.fit(
        train_loader,
        validation_data=val_loader,
        epochs=epochs,
        callbacks=callbacks,
        verbose=1
    ).history

    # Attach the scheduler's per-batch trace (replaces the per-epoch values
    # the callback wrote into the logs).
    history['lr'] = onecycle.history['lr']
    history['momentum'] = onecycle.history['momentum']

    # Best RMSE = sqrt of the minimum validation MSE over all epochs.
    val_mse = history['val_mse']
    best_epoch = np.argmin(val_mse)
    best_rmse = np.sqrt(val_mse[best_epoch])

    total_time = time.time() - start_time
    print(f"✅ Entrenamiento completado en {total_time:.2f}s")
    print(f"   Mejor RMSE: {best_rmse:.4f} (época {best_epoch+1})")

    return model, history, best_rmse

# Cell status message, printed once the optimizer helper and training
# function above have been defined.
print("✅ AdamW y One-Cycle Scheduler implementados según plan experimental")

In [None]:
# ▶️ Implementación de función de pérdida Huber con ponderación por horizonte
import tensorflow as tf
import numpy as np

def huber_loss(y_true, y_pred, delta=1.0):
    """
    Mean Huber loss between targets and predictions.

    Residuals with magnitude up to ``delta`` are penalised quadratically,
    the part beyond ``delta`` linearly.

    Args:
        y_true: Tensor of ground-truth values.
        y_pred: Tensor of predictions.
        delta: Magnitude at which the penalty switches from L2 to L1.

    Returns:
        Scalar tensor: the Huber loss averaged over all elements.
    """
    residual_magnitude = tf.abs(y_true - y_pred)
    # Portion of each residual inside the quadratic zone ...
    within_delta = tf.minimum(residual_magnitude, delta)
    # ... and whatever exceeds it, penalised linearly.
    beyond_delta = residual_magnitude - within_delta
    per_element = 0.5 * tf.square(within_delta) + delta * beyond_delta
    return tf.reduce_mean(per_element)

def huber_loss_with_horizon_weight(y_true, y_pred, delta=1.0):
    """
    Huber loss with per-horizon weighting.

    Later forecast steps receive a larger weight according to
    w_h = 1 + h/12, with h = 1..horizon taken along axis 1.

    Args:
        y_true: Tensor of ground-truth values, [batch, horizon, ...].
        y_pred: Tensor of predictions, [batch, horizon, ...].
        delta: Huber transition point between the quadratic and linear zones.

    Returns:
        Scalar tensor: mean of the weighted element-wise Huber loss.
    """
    # Horizon length is read dynamically from axis 1.
    horizon = tf.shape(y_true)[1]

    # Weights w_h = 1 + h/12 for h = 1..horizon.
    h_indices = tf.range(1, horizon + 1, dtype=tf.float32)
    weights = 1.0 + h_indices / 12.0

    # Reshape [horizon] -> [horizon, 1, 1, ...] (one trailing axis per
    # dimension after the horizon axis) so that right-aligned broadcasting
    # lines the weights up with axis 1 of the loss tensor.
    trailing_axes = len(y_true.shape) - 2
    weights = tf.reshape(weights, [-1] + [1] * trailing_axes)

    # Element-wise Huber loss.
    abs_error = tf.abs(y_true - y_pred)
    quadratic = tf.minimum(abs_error, delta)
    linear = abs_error - quadratic
    unweighted_loss = 0.5 * tf.square(quadratic) + delta * linear

    # Match dtypes so a non-float32 loss tensor does not fail on the multiply.
    weights = tf.cast(weights, unweighted_loss.dtype)

    return tf.reduce_mean(unweighted_loss * weights)

def rmse_by_horizon(y_true, y_pred):
    """
    Compute one RMSE value per forecast step.

    Args:
        y_true: Tensor of ground-truth values, [batch, horizon, ...].
        y_pred: Tensor of predictions, [batch, horizon, ...].

    Returns:
        List with one scalar RMSE tensor per index of axis 1.
    """
    # The horizon length comes from the static shape of axis 1.
    n_steps = y_true.shape[1]

    # For each step: slice axis 1, average the squared error, take the root.
    return [
        tf.sqrt(tf.reduce_mean(tf.square(y_true[:, step] - y_pred[:, step])))
        for step in range(n_steps)
    ]

# Clase de métrica personalizada para TensorFlow
class RMSEByHorizon(tf.keras.metrics.Metric):
    """
    Custom metric that tracks RMSE separately for each forecast step.

    For every horizon index the metric accumulates the per-batch mean squared
    error and the number of batches seen. ``result()`` reports the mean of
    the per-horizon RMSE values; ``get_horizon_results()`` returns them
    individually.
    """

    def __init__(self, horizon=12, name='rmse_by_horizon', **kwargs):
        """
        Args:
            horizon: Number of forecast steps to track (axis 1 of the inputs).
            name: Metric name shown in the Keras logs.
            **kwargs: Forwarded to ``tf.keras.metrics.Metric``.
        """
        super(RMSEByHorizon, self).__init__(name=name, **kwargs)
        self.horizon = horizon
        # One accumulator pair (sum of batch MSEs, batch count) per horizon.
        self.errors = [self.add_weight(name=f'horizon_{h+1}', initializer='zeros')
                      for h in range(horizon)]
        self.counts = [self.add_weight(name=f'count_{h+1}', initializer='zeros')
                      for h in range(horizon)]

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate the batch MSE of every horizon step.

        ``sample_weight`` is accepted for API compatibility but not used.
        """
        # Never index past the data; when the static horizon is unknown
        # (None) fall back to the configured one.
        static_horizon = y_true.shape[1]
        if static_horizon is None:
            horizon = self.horizon
        else:
            horizon = min(self.horizon, static_horizon)

        for h in range(horizon):
            mse = tf.reduce_mean(tf.square(y_true[:, h] - y_pred[:, h]))
            self.errors[h].assign_add(mse)
            self.counts[h].assign_add(1.0)

    def get_horizon_results(self):
        """Return a list with the RMSE of each horizon step.

        ``divide_no_nan`` yields 0 for horizons that have not been updated
        yet, and works on symbolic tensors, unlike a Python ``if`` on the
        counter.
        """
        return [tf.sqrt(tf.math.divide_no_nan(error, count))
                for error, count in zip(self.errors, self.counts)]

    def result(self):
        """Return the mean of the per-horizon RMSE values."""
        return tf.reduce_mean(self.get_horizon_results())

    def reset_state(self):
        """Zero every accumulator."""
        for h in range(self.horizon):
            self.errors[h].assign(0.0)
            self.counts[h].assign(0.0)

    def get_config(self):
        """Include ``horizon`` so the metric can be re-created from config."""
        config = super().get_config()
        config['horizon'] = self.horizon
        return config

# Ejemplo de compilación de modelo con estas funciones
def compile_model_with_huber(model):
    """Compile ``model`` with AdamW and the horizon-weighted Huber loss.

    Tracks MSE, MAE and the 12-step ``RMSEByHorizon`` metric, and returns
    the same model instance.
    """
    adamw = get_adamw_optimizer(learning_rate=1e-3/25, weight_decay=1e-4)
    tracked_metrics = ['mse', 'mae', RMSEByHorizon(horizon=12)]

    model.compile(optimizer=adamw,
                  loss=huber_loss_with_horizon_weight,
                  metrics=tracked_metrics)
    return model

# Cell status message, printed once the loss and metric definitions above
# have been executed.
print("✅ Función de pérdida Huber con ponderación por horizonte implementada")

In [None]:
# ▶️ Implementación de Teacher-Forcing con Cosine Decay
import tensorflow as tf
import numpy as np
import math

class TeacherForcingController:
    """
    Teacher-forcing probability schedule with cosine decay (0.70 -> 0.30).

    Teacher forcing feeds the ground-truth output (instead of the model's own
    prediction) as the next-step input during training.

    This controller:
    - starts at ``start_prob``,
    - decays towards ``end_prob`` along half a cosine period,
    - spreads the decay over ``total_epochs`` epochs.
    """

    def __init__(self, start_prob=0.70, end_prob=0.30, total_epochs=60):
        """
        Args:
            start_prob: Probability of teacher forcing at epoch 0.
            end_prob: Probability of teacher forcing at ``total_epochs``.
            total_epochs: Number of epochs over which the decay runs. A value
                <= 0 means the decay is considered already finished.
        """
        self.start_prob = start_prob
        self.end_prob = end_prob
        self.total_epochs = total_epochs
        self.current_epoch = 0
        self.current_prob = start_prob

    def update_epoch(self, epoch=None):
        """
        Move the schedule to ``epoch`` and recompute the probability.

        Args:
            epoch: New epoch index; when None the current epoch is
                incremented by one.

        Returns:
            The updated teacher-forcing probability.
        """
        requested = self.current_epoch + 1 if epoch is None else epoch

        # Clamp into [0, total_epochs] so the probability never leaves the
        # [end_prob, start_prob] segment of the cosine curve.
        upper_bound = max(self.total_epochs, 0)
        self.current_epoch = max(0, min(requested, upper_bound))

        # Normalised progress in [0, 1]; a non-positive total_epochs would
        # divide by zero, so it is treated as "decay finished".
        if self.total_epochs > 0:
            progress = self.current_epoch / self.total_epochs
        else:
            progress = 1.0

        # Half-cosine interpolation from start_prob to end_prob.
        cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
        self.current_prob = self.end_prob + (self.start_prob - self.end_prob) * cosine_decay

        return self.current_prob

    def use_teacher_forcing(self):
        """
        Draw whether to use teacher forcing for the current iteration.

        Returns:
            True with probability ``current_prob``, False otherwise.
        """
        return np.random.random() < self.current_prob

    def get_current_probability(self):
        """
        Returns:
            The current teacher-forcing probability.
        """
        return self.current_prob


# Implementación de una capa de decoder con Teacher-Forcing para GRU
class TeacherForcingGRUDecoder(tf.keras.layers.Layer):
    """
    GRU decoder layer with optional teacher forcing.

    The GRU input is the decoder input concatenated with one extra feature
    channel carrying the previous-step target:

    - teacher forcing ON: the channel holds the ground-truth target shifted
      one step to the right (zero at the first step);
    - teacher forcing OFF (or inference): the channel is all zeros.

    Keeping the channel in both modes gives the GRU the same input width in
    both cases, so a single set of weights serves training and inference.
    Because the full target sequence is known up front, the teacher-forced
    pass is a single GRU call rather than a step-by-step Python loop.

    NOTE(review): the teacher-forcing decision is drawn in Python from the
    controller. Inside a compiled (non-eager) train step it is presumably
    evaluated only when the function is traced - confirm, or compile the
    model with ``run_eagerly=True`` to get a fresh draw per batch.
    """

    def __init__(self, units, tf_controller=None, return_sequences=True, dropout=0.2, **kwargs):
        """
        Args:
            units: Dimensionality of the GRU layer.
            tf_controller: Teacher-forcing controller; a default
                ``TeacherForcingController`` is created when omitted.
            return_sequences: Whether the GRU returns the full sequence.
            dropout: Dropout rate used by the GRU and on its outputs.
            **kwargs: Extra keyword arguments, forwarded both to the base
                ``Layer`` and to the inner GRU.
        """
        super(TeacherForcingGRUDecoder, self).__init__(**kwargs)
        self.gru = tf.keras.layers.GRU(units, return_sequences=return_sequences,
                                      dropout=dropout, **kwargs)
        self.tf_controller = tf_controller if tf_controller is not None else TeacherForcingController()
        self.dense = tf.keras.layers.Dense(1)  # projection to the output space
        self.supports_masking = True
        self.dropout_layer = tf.keras.layers.Dropout(dropout)

    def _previous_targets(self, inputs, targets):
        """Return targets shifted one step right, or None if incompatible.

        The result has shape [batch, time, 1]. Targets are usable only when
        they are [batch, time] or [batch, time, 1] and their static time
        length matches that of ``inputs``.
        """
        prev = tf.cast(tf.convert_to_tensor(targets), inputs.dtype)
        if prev.shape.rank == 2:
            prev = tf.expand_dims(prev, axis=-1)

        compatible = (
            prev.shape.rank == 3
            and prev.shape[-1] == 1
            and prev.shape[1] == inputs.shape[1]
        )
        if not compatible:
            return None

        # Drop the last step and pad a zero in front: position t then holds
        # the target of step t-1.
        return tf.pad(prev[:, :-1, :], [[0, 0], [1, 0], [0, 0]])

    def call(self, inputs, targets=None, initial_state=None, training=False):
        """
        Run the decoder.

        Args:
            inputs: Decoder inputs (encoder outputs), [batch, time, features].
            targets: Ground-truth targets used for teacher forcing (optional).
            initial_state: Initial state for the GRU.
            training: True when running in training mode.

        Returns:
            Decoder outputs after dropout and the Dense(1) projection.
        """
        # Default feedback channel: zeros (no teacher information).
        feedback = tf.zeros_like(inputs[:, :, :1])

        # Teacher forcing only during training, only with targets available,
        # and only when the controller's random draw says so.
        if training and targets is not None and self.tf_controller.use_teacher_forcing():
            shifted = self._previous_targets(inputs, targets)
            if shifted is not None:
                feedback = shifted

        gru_inputs = tf.concat([inputs, feedback], axis=-1)
        hidden = self.gru(gru_inputs, initial_state=initial_state, training=training)
        hidden = self.dropout_layer(hidden, training=training)
        return self.dense(hidden)

    def get_config(self):
        """Serialization config (the controller itself is not serialized)."""
        config = super().get_config()
        config.update({
            'units': self.gru.units,
            'return_sequences': self.gru.return_sequences,
            'dropout': self.dropout_layer.rate
        })
        return config


# Ejemplo de modelo GRU Encoder-Decoder con Teacher-Forcing
class GRUEncoderDecoderWithTF(tf.keras.Model):
    """
    GRU encoder-decoder model with teacher forcing.

    The teacher-forcing probability follows the experimental plan: cosine
    decay from 0.70 to 0.30 over ``total_epochs``. ``fit`` registers an
    epoch-end hook so the schedule actually advances during training.

    NOTE(review): the controller holds plain Python state. Inside a compiled
    (non-eager) train step its value is presumably captured at trace time -
    confirm, or compile with ``run_eagerly=True``.
    """

    def __init__(self, input_dim, hidden_units=128, dropout=0.2, total_epochs=60):
        """
        Args:
            input_dim: Input dimensionality. Accepted so the constructor
                matches the model-factory signature; not used in this class.
            hidden_units: Number of GRU units in encoder and decoder.
            dropout: Dropout rate.
            total_epochs: Epochs over which the teacher-forcing probability
                decays.
        """
        super(GRUEncoderDecoderWithTF, self).__init__()

        # Teacher-forcing schedule shared with the decoder.
        self.tf_controller = TeacherForcingController(
            start_prob=0.70,
            end_prob=0.30,
            total_epochs=total_epochs
        )

        # Model architecture.
        self.encoder_gru = tf.keras.layers.GRU(
            hidden_units, return_sequences=True, return_state=True)

        self.decoder_gru = TeacherForcingGRUDecoder(
            hidden_units, tf_controller=self.tf_controller)

        self.output_layer = tf.keras.layers.Dense(1)
        self.dropout = dropout
        self.dropout_layer = tf.keras.layers.Dropout(dropout)

    def call(self, inputs, targets=None, training=False):
        """
        Forward pass.

        Args:
            inputs: Input tensor [batch, time_steps, features].
            targets: Ground-truth targets (enable teacher forcing).
            training: True when running in training mode.

        Returns:
            Model predictions.
        """
        # Record the teacher-forcing probability in use (training only).
        if training:
            self.current_tf_prob = self.tf_controller.get_current_probability()
            tf.summary.scalar('teacher_forcing_prob', self.current_tf_prob)

        # Encoder.
        encoder_outputs, encoder_state = self.encoder_gru(inputs, training=training)
        encoder_outputs = self.dropout_layer(encoder_outputs, training=training)

        # Decoder, with optional teacher forcing.
        decoder_outputs = self.decoder_gru(
            encoder_outputs, targets=targets,
            initial_state=encoder_state, training=training)

        # Final projection.
        outputs = self.output_layer(decoder_outputs)

        return outputs

    def train_step(self, data):
        """
        Custom training step that passes the targets into the forward pass.

        Args:
            data: Tuple (x, y) of inputs and targets.

        Returns:
            Dict of metric results plus ``'teacher_forcing_prob'``.
        """
        x, y = data

        with tf.GradientTape() as tape:
            # Forward pass with teacher forcing enabled.
            predictions = self(x, targets=y, training=True)
            loss = self.compiled_loss(y, predictions, regularization_losses=self.losses)

        # Back-propagate and update the weights.
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        self.compiled_metrics.update_state(y, predictions)

        # Report the current teacher-forcing probability with the metrics.
        metrics_results = {m.name: m.result() for m in self.metrics}
        metrics_results['teacher_forcing_prob'] = self.tf_controller.get_current_probability()

        return metrics_results

    def fit(self, *args, **kwargs):
        """
        Same as ``tf.keras.Model.fit`` but also wires the schedule update.

        A ``LambdaCallback`` calling ``self.on_epoch_end`` is appended to the
        user's callbacks. Without it nothing invokes ``on_epoch_end`` and the
        teacher-forcing probability would stay at its initial value.
        """
        # `callbacks` is the sixth positional parameter of Model.fit; only
        # inject when it was not passed positionally and is absent or a
        # plain list/tuple.
        if len(args) < 6:
            user_callbacks = kwargs.get('callbacks')
            if user_callbacks is None or isinstance(user_callbacks, (list, tuple)):
                schedule_hook = tf.keras.callbacks.LambdaCallback(
                    on_epoch_end=self.on_epoch_end)
                kwargs['callbacks'] = [*(user_callbacks or ()), schedule_hook]
        return super().fit(*args, **kwargs)

    def on_epoch_end(self, epoch, logs=None):
        """
        Epoch-end hook: advance the teacher-forcing schedule.

        Args:
            epoch: Zero-based index of the epoch that just finished.
            logs: Optional logs dict; receives ``'teacher_forcing_prob'``.
        """
        # The finished epoch has index `epoch`, so the probability stored
        # here is the one the NEXT epoch (index epoch + 1) will train with.
        updated_prob = self.tf_controller.update_epoch(epoch + 1)

        if logs is not None:
            logs['teacher_forcing_prob'] = updated_prob

        # Print the change every 5 epochs.
        if epoch % 5 == 0:
            print(f"Época {epoch}: probabilidad de Teacher-Forcing = {updated_prob:.4f}")

# Actualizar MODEL_FACTORY para incluir los nuevos modelos avanzados
def update_model_factory():
    """
    Register the advanced architectures in the global ``MODEL_FACTORY``.

    Each entry is a builder taking ``(input_dim, dropout=0.2)``. Model classes
    are resolved only when a builder is called, not when it is registered.
    """
    def _build_gru_ed(input_dim, dropout=0.2):
        return GRUEncoderDecoderWithTF(input_dim, dropout=dropout)

    def _build_gru_ed_pafc(input_dim, dropout=0.2):
        return GRUEncoderDecoderWithTF(input_dim, dropout=dropout)

    def _build_ae_fusion_gru(input_dim, dropout=0.2):
        return AEFusionGRUModel(input_dim, dropout=dropout)

    def _build_ae_fusion_gru_t(input_dim, dropout=0.2):
        return AEFusionGRUAttention(input_dim, dropout=dropout)

    def _build_ae_fusion_gru_t_mask(input_dim, dropout=0.2):
        return AEFusionGRUAttentionMask(input_dim, dropout=dropout)

    # Mutating the existing dict in place; no rebinding of the global name.
    MODEL_FACTORY.update({
        'gru_ed': _build_gru_ed,
        'gru_ed_pafc': _build_gru_ed_pafc,
        'ae_fusion_gru': _build_ae_fusion_gru,
        'ae_fusion_gru_t': _build_ae_fusion_gru_t,
        'ae_fusion_gru_t_mask': _build_ae_fusion_gru_t_mask,
    })
    print("✅ MODEL_FACTORY actualizado con arquitecturas avanzadas")

print("✅ Arquitecturas avanzadas implementadas según plan experimental")
print("   Incluyendo Conv3D-AE con bottleneck de 64 dims")

# Demo: build the 3D convolutional autoencoder and print its architecture.
# This is informational only; a failure here is reported but not fatal.
try:
    # Build the model
    input_shape = (48, 61, 65, 3)  # (time, height, width, channels)
    model = Conv3DAutoencoder(input_shape)

    # Show the architecture
    print("\n🧠 ARQUITECTURA DEL AUTOENCODER CONVOLUCIONAL 3D:")
    print("="*80)
    print("ENCODER:")
    model.encoder.summary(line_length=100)
    print("\nDECODER:")
    model.decoder.summary(line_length=100)
except Exception as e:
    print(f"❌ Error al mostrar arquitectura: {e}")
    import traceback
    traceback.print_exc()

# Register the advanced models in MODEL_FACTORY. Done in its own guarded
# step so that a problem while printing the summary above does not leave
# the factory without the new entries.
try:
    update_model_factory()
except Exception as e:
    print(f"❌ Error al actualizar MODEL_FACTORY: {e}")
    import traceback
    traceback.print_exc()

In [None]:
# ▶️ Implementación de normalización con RobustScaler y StandardScaler
import numpy as np
import xarray as xr
from sklearn.preprocessing import RobustScaler, StandardScaler

class DataNormalizer:
    """
    Data normalization following the experimental plan:
    - per-cell RobustScaler for precipitation (more resistant to outliers)
    - one StandardScaler shared by the topographic variables and one shared
      by the temporal variables

    The fitted scalers are kept on the instance so the inverse transform can
    be applied during evaluation and visualization.
    """

    def __init__(self):
        """Create the normalizer with unfitted scalers."""
        # Per-cell precipitation scalers: (lat_idx, lon_idx) -> RobustScaler
        self.precip_scalers = {}

        # Global scalers, one per variable family
        self.topo_scaler = StandardScaler()       # topographic variables
        self.temporal_scaler = StandardScaler()   # temporal variables
        self.other_scaler = StandardScaler()      # other variables (not used by fit/transform)

        # Variable names belonging to each family
        self.topo_features = ['elevation', 'slope', 'roughness', 'curvature', 'aspect']
        self.temporal_features = ['month_sin', 'month_cos', 'doy_sin', 'doy_cos']
        self.precipitation_vars = ['total_precipitation', 'precip_hist',
                                  'total_precipitation_lag1', 'total_precipitation_lag2', 'total_precipitation_lag12',
                                  'lag_1', 'lag_2', 'lag_12']
        self.is_fitted = False

    def fit(self, ds):
        """
        Fit the scalers on the dataset.

        Args:
            ds: xarray Dataset holding the variables.
        """
        print("🔢 Ajustando escaladores de normalización...")

        # 1. Per-cell RobustScaler, fitted on the first precipitation
        #    variable present in the dataset.
        precip_var = next((var for var in self.precipitation_vars if var in ds.data_vars), None)
        if precip_var is not None:
            print("  ▶️ Ajustando RobustScaler por celda para precipitación...")
            precip_data = ds[precip_var].values

            if precip_data.ndim == 3:  # [time, lat, lon]
                _, n_lats, n_lons = precip_data.shape

                # Start from a clean slate so a re-fit on a different grid
                # cannot leave scalers from the previous fit behind.
                self.precip_scalers = {}

                for lat_idx, lon_idx in np.ndindex(n_lats, n_lons):
                    cell_data = precip_data[:, lat_idx, lon_idx].reshape(-1, 1)

                    # Cells that are entirely NaN get no scaler.
                    if np.all(np.isnan(cell_data)):
                        continue

                    # NaNs are replaced by 0 before fitting.
                    # NOTE(review): zero-filling influences the fitted
                    # median/quantiles of cells with gaps - confirm this is
                    # the intended imputation.
                    cell_data_clean = np.nan_to_num(cell_data, nan=0.0)

                    scaler = RobustScaler(quantile_range=(10.0, 90.0))
                    scaler.fit(cell_data_clean)
                    self.precip_scalers[(lat_idx, lon_idx)] = scaler

            print(f"  ✓ RobustScaler ajustado para {len(self.precip_scalers)} celdas")

        # 2. One StandardScaler over all topographic variables stacked into
        #    a single column.
        topo_arrays = []
        for var in self.topo_features:
            if var not in ds.data_vars:
                continue
            data = ds[var].values

            if data.ndim == 2:  # [lat, lon]
                flat_data = data.reshape(-1, 1)
            elif data.ndim == 3:  # [time, lat, lon]
                # Only the first time slice is used for fitting.
                flat_data = data[0].reshape(-1, 1)
            else:
                continue

            topo_arrays.append(flat_data)

        if topo_arrays:
            topo_data = np.concatenate(topo_arrays, axis=0)
            topo_data_clean = np.nan_to_num(topo_data, nan=0.0)
            self.topo_scaler.fit(topo_data_clean)
            print("  ✓ StandardScaler ajustado para variables topográficas")

        # 3. One StandardScaler over all temporal variables.
        temporal_arrays = []
        for var in self.temporal_features:
            if var not in ds.data_vars:
                continue
            data = ds[var].values

            if data.ndim == 3:
                # Use the time series of the first spatial cell.
                flat_data = data[:, 0, 0].reshape(-1, 1)
            else:
                flat_data = data.reshape(-1, 1)

            temporal_arrays.append(flat_data)

        if temporal_arrays:
            temporal_data = np.concatenate(temporal_arrays, axis=0)
            self.temporal_scaler.fit(temporal_data)
            print("  ✓ StandardScaler ajustado para variables temporales")

        self.is_fitted = True
        print("✅ Normalización configurada correctamente")

    def transform(self, ds):
        """
        Apply the normalization to the dataset.

        If the normalizer has not been fitted yet, it is fitted on ``ds``
        first.

        Args:
            ds: xarray Dataset holding the variables.

        Returns:
            A copy of the dataset with the known variables normalized.
        """
        if not self.is_fitted:
            print("⚠️ El normalizador no está ajustado. Llamando a fit() primero...")
            self.fit(ds)

        # Work on a copy so the caller's dataset is not modified.
        ds_norm = ds.copy()

        # 1. Precipitation variables: per-cell scaling.
        for var in self.precipitation_vars:
            if var not in ds.data_vars:
                continue
            data = ds[var].values.copy()

            if data.ndim != 3:  # only [time, lat, lon] is handled
                continue
            _, n_lats, n_lons = data.shape

            for lat_idx, lon_idx in np.ndindex(n_lats, n_lons):
                scaler = self.precip_scalers.get((lat_idx, lon_idx))
                if scaler is None:
                    # No scaler for this cell: values are left untouched.
                    continue

                # Extract the cell series and normalize it (NaNs -> 0 first).
                cell_data = data[:, lat_idx, lon_idx].reshape(-1, 1)
                cell_data_clean = np.nan_to_num(cell_data, nan=0.0)
                data[:, lat_idx, lon_idx] = scaler.transform(cell_data_clean).flatten()

            ds_norm[var].values = data
            print(f"  ✓ Variable {var} normalizada con RobustScaler por celda")

        # 2. Topographic variables: shared scaler, applied element-wise.
        for var in self.topo_features:
            if var not in ds.data_vars:
                continue
            data = ds[var].values.copy()
            data_shape = data.shape

            flat_data_clean = np.nan_to_num(data.reshape(-1, 1), nan=0.0)
            normalized = self.topo_scaler.transform(flat_data_clean)
            ds_norm[var].values = normalized.reshape(data_shape)
            print(f"  ✓ Variable topográfica {var} normalizada con StandardScaler")

        # 3. Temporal variables: shared scaler, applied element-wise. The
        #    scaler has a single feature, so scaling every element gives the
        #    same values as scaling one series per time step and broadcasting
        #    it, without requiring the field to be spatially constant.
        for var in self.temporal_features:
            if var not in ds.data_vars:
                continue
            data = ds[var].values.copy()
            data_shape = data.shape

            normalized = self.temporal_scaler.transform(data.reshape(-1, 1)).reshape(data_shape)
            if data.ndim == 3:  # [time, lat, lon]
                # Write into the copy so the variable keeps its dtype.
                data[...] = normalized
                ds_norm[var].values = data
            else:
                ds_norm[var].values = normalized

            print(f"  ✓ Variable temporal {var} normalizada con StandardScaler")

        print("✅ Normalización aplicada correctamente")
        return ds_norm

    def inverse_transform_precip(self, data, lat_idx=None, lon_idx=None):
        """
        Undo the precipitation normalization.

        Args:
            data: Normalized data.
            lat_idx, lon_idx: Cell indices; when both are None, ``data`` is
                expected to have shape [time, lat, lon].

        Returns:
            De-normalized data. Cells without a fitted scaler (and any input
            that matches neither case) are returned unchanged.
        """
        if not self.is_fitted:
            print("⚠️ El normalizador no está ajustado, no se puede desnormalizar")
            return data

        # Case 1: full 3D field [time, lat, lon]
        if lat_idx is None and lon_idx is None and data.ndim == 3:
            # Start from a copy so cells without a scaler keep their values
            # (mirrors transform(), which leaves those cells untouched).
            result = np.array(data, copy=True)
            _, n_lats, n_lons = data.shape

            for i_lat, i_lon in np.ndindex(n_lats, n_lons):
                scaler = self.precip_scalers.get((i_lat, i_lon))
                if scaler is None:
                    continue
                cell_data = data[:, i_lat, i_lon].reshape(-1, 1)
                result[:, i_lat, i_lon] = scaler.inverse_transform(cell_data).flatten()

            return result

        # Case 2: a single, explicitly addressed cell
        if lat_idx is not None and lon_idx is not None:
            scaler = self.precip_scalers.get((lat_idx, lon_idx))
            if scaler is not None:
                # The scaler needs 2D input.
                reshaped_data = data.reshape(-1, 1) if data.ndim == 1 else data
                return scaler.inverse_transform(reshaped_data)

        # Fallback: return the data as received.
        return data

# Integrar con los dataloaders existentes
def apply_normalizers_to_dataloaders(train_loader, val_loader, dataset_path=None):
    """
    Fit a DataNormalizer on the dataset and return it alongside the loaders.

    NOTE: the loaders themselves are currently returned unchanged; wiring the
    fitted normalizer into them is still pending (it depends on the
    DataLoader structure).

    Args:
        train_loader: Training DataLoader (passed through).
        val_loader: Validation DataLoader (passed through).
        dataset_path: Path to the NetCDF file used to fit the normalizer.
            Defaults to the global ``FULL_NC`` when defined, otherwise
            ``./data/complete_dataset.nc``.

    Returns:
        Tuple ``(train_loader, val_loader, normalizer)``. ``normalizer`` is
        None when loading or fitting fails (the error is printed, not raised).
    """
    # Fall back to the notebook-wide dataset path when none is given
    if dataset_path is None:
        dataset_path = FULL_NC if 'FULL_NC' in globals() else './data/complete_dataset.nc'

    try:
        print(f"📊 Cargando dataset para normalización desde {dataset_path}")
        normalizer = DataNormalizer()

        # The dataset is only needed while fitting; the context manager
        # guarantees the file handle is released even if fit() raises.
        with xr.open_dataset(dataset_path) as ds:
            normalizer.fit(ds)

        # TODO: apply the normalizer to the loaders once their structure is fixed

        return train_loader, val_loader, normalizer
    except Exception as e:
        # Best-effort by design: report and let the caller continue unnormalized
        print(f"❌ Error aplicando normalización: {e}")
        import traceback
        traceback.print_exc()
        return train_loader, val_loader, None

# Ejemplo de uso
def normalize_dataset(dataset_path):
    """
    Load a NetCDF dataset, fit a DataNormalizer on it and normalize it.

    For a few reference variables (when present) the min/max/mean/std before
    and after normalization are printed as a quick sanity check.

    Args:
        dataset_path: Path to the dataset's NetCDF file.

    Returns:
        Tuple ``(normalized_dataset, fitted_normalizer)``, or ``(None, None)``
        if anything fails (the error and traceback are printed).
    """
    def _summary(values):
        # NaN-aware descriptive statistics formatted for the report lines
        return (f"min={np.nanmin(values):.4f}, max={np.nanmax(values):.4f}, "
                f"mean={np.nanmean(values):.4f}, std={np.nanstd(values):.4f}")

    try:
        print(f"📊 Cargando dataset desde {dataset_path}")
        raw_ds = xr.open_dataset(dataset_path)

        # Fit on the raw data, then produce the normalized copy
        fitted = DataNormalizer()
        fitted.fit(raw_ds)
        scaled_ds = fitted.transform(raw_ds)

        # Before/after comparison for the variables of interest
        for name in ('total_precipitation', 'elevation', 'month_sin'):
            if name not in raw_ds.data_vars:
                continue

            before = raw_ds[name].values
            after = scaled_ds[name].values

            print(f"\n📊 Estadísticas para {name}:")
            print(f"  Original: {_summary(before)}")
            print(f"  Normalizado: {_summary(after)}")

        return scaled_ds, fitted

    except Exception as exc:
        print(f"❌ Error al normalizar dataset: {exc}")
        import traceback
        traceback.print_exc()
        return None, None

# Sanity-check the normalization on the project dataset (only when FULL_NC is defined)
if 'FULL_NC' in globals():
    # Banner: blank line, rule, heading, rule
    print("\n" + "="*80, "🔢 VERIFICACIÓN DE NORMALIZACIÓN DE DATOS", "="*80, sep="\n")
    normalized_ds, normalizer = normalize_dataset(FULL_NC)
    print("="*80)

# Visualizaciones Avanzadas de Métricas

## Análisis de Rendimiento por Cluster de Altitud y Comparación Temporal

Esta sección implementa las visualizaciones avanzadas según el plan experimental:
1. **Box-plots por cluster de altitud** (low/mid/high): Muestra cómo varía el error de predicción según la elevación
2. **Mapas de sesgo medio** comparando folds históricos vs recientes: Visualiza cambios en los patrones espaciales de error entre épocas históricas (F4, F5) y recientes (F1, F2, F3)

In [None]:
# ▶️ Implementación de visualizaciones box-plots y mapas de sesgo
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import xarray as xr
from matplotlib.colors import TwoSlopeNorm
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from scipy.stats import ttest_ind, mannwhitneyu

# Global plotting defaults for the figures in this section
plt.rcParams['figure.figsize'] = (12, 8)
sns.set_style("whitegrid")
# Grid settings are applied after the seaborn style so they take effect last
plt.rcParams.update({'axes.grid': True, 'grid.alpha': 0.3})

class AdvancedMetricsVisualizer:
    """
    Visualizador avanzado de métricas para modelos de precipitación
    
    Implementa las visualizaciones especificadas en el plan experimental:
    1. Box-plots por cluster de altitud (low/mid/high)
    2. Mapas de sesgo medio comparando folds históricos vs recientes
    """
    
    def __init__(self, dataset_path=None, model_results_path=None, normalizer=None):
        """
        Inicializa el visualizador
        
        Args:
            dataset_path: Ruta al dataset NetCDF
            model_results_path: Ruta a resultados del modelo (opcional)
            normalizer: Objeto DataNormalizer para desnormalizar datos
        """
        self.dataset_path = dataset_path
        self.model_results_path = model_results_path
        self.normalizer = normalizer
        self.ds = None
        self.results = {}
        self.cluster_names = {1: 'Low', 2: 'Mid', 3: 'High'}
        self.cluster_colors = {1: '#3498db', 2: '#2ecc71', 3: '#e74c3c'}
        
        # Definir folds históricos y recientes
        self.historical_folds = ['F4', 'F5']  # 1990, 2000
        self.recent_folds = ['F1', 'F2', 'F3']  # 2022, 2023, 2024
        
        # Cargar dataset si se proporciona ruta
        if dataset_path is not None:
            self.load_dataset(dataset_path)
            
        # Cargar resultados de modelos si se proporciona ruta
        if model_results_path is not None:
            self.load_model_results(model_results_path)
    
    def load_dataset(self, dataset_path=None):
        """Carga el dataset NetCDF"""
        if dataset_path is not None:
            self.dataset_path = dataset_path
            
        if self.dataset_path is None:
            print("❌ No se ha especificado ruta del dataset")
            return False
        
        try:
            print(f"📊 Cargando dataset desde {self.dataset_path}")
            self.ds = xr.open_dataset(self.dataset_path)
            
            # Verificar si tenemos cluster_elevation
            if 'cluster_elevation' not in self.ds:
                print("⚠️ Variable 'cluster_elevation' no encontrada en el dataset")
                # Intentar crear clusters artificiales para demostración
                self.create_demo_clusters()
            else:
                print(f"✅ Dataset cargado correctamente, forma: {self.ds.dims}")
            
            return True
        except Exception as e:
            print(f"❌ Error cargando dataset: {e}")
            import traceback
            traceback.print_exc()
            return False
    
    def create_demo_clusters(self):
        """Crea clusters artificiales para demostración si no existen en el dataset"""
        if self.ds is None or 'elevation' not in self.ds:
            print("❌ No se puede crear clusters artificiales (dataset no cargado o sin elevation)")
            return False
        
        try:
            # Obtener elevación y crear clusters basados en percentiles
            elevation = self.ds['elevation'].values
            
            # Calcular percentiles para crear 3 clusters (low, mid, high)
            low_threshold = np.nanpercentile(elevation, 33)
            high_threshold = np.nanpercentile(elevation, 66)
            
            # Crear array de clusters
            clusters = np.zeros_like(elevation, dtype=np.int32)
            clusters[elevation < low_threshold] = 1  # Low
            clusters[(elevation >= low_threshold) & (elevation < high_threshold)] = 2  # Mid
            clusters[elevation >= high_threshold] = 3  # High
            
            # Agregar al dataset
            self.ds['cluster_elevation'] = (('lat', 'lon'), clusters)
            
            print("✅ Clusters artificiales creados basados en percentiles de elevación")
            # Mostrar estadísticas
            for cluster_id, name in self.cluster_names.items():
                mask = (clusters == cluster_id)
                if np.any(mask):
                    mean_elev = np.nanmean(elevation[mask])
                    count = np.sum(mask)
                    print(f"  Cluster {name}: {count} celdas, elevación media: {mean_elev:.1f}m")
            
            return True
        except Exception as e:
            print(f"❌ Error creando clusters artificiales: {e}")
            return False
    
    def load_model_results(self, model_results_path=None):
        """Carga resultados de modelos guardados"""
        # Esta función carga los resultados de CheckpointDir, por ejemplo.
        # Por simplificación, aquí solo configuramos algunos resultados de ejemplo
        
        # Estructura de resultados de ejemplo:
        # self.results = {
        #     'experiment_name': {
        #         'fold': {
        #             'predictions': array(...),
        #             'actuals': array(...),
        #             'rmse': float,
        #             'bias': array(...)
        #         }
        #     }
        # }
        
        try:
            # Intentar cargar resultados desde CHECKPOINT_DIR
            import pickle
            import glob
            from pathlib import Path
            
            if model_results_path is None:
                # Usar CHECKPOINT_DIR global si está definido
                if 'CHECKPOINT_DIR' in globals():
                    checkpoint_dir = globals()['CHECKPOINT_DIR']
                else:
                    checkpoint_dir = Path('./checkpoints')
            else:
                checkpoint_dir = Path(model_results_path)
            
            if not checkpoint_dir.exists():
                print(f"⚠️ Directorio de checkpoints no encontrado: {checkpoint_dir}")
                return False
            
            print(f"📂 Buscando resultados en: {checkpoint_dir}")
            
            # Buscar archivos de resultados
            result_files = list(checkpoint_dir.glob("*_result.pkl"))
            if not result_files:
                print("⚠️ No se encontraron archivos de resultados")
                return False
            
            print(f"✅ Encontrados {len(result_files)} archivos de resultados")
            
            # Cargar cada archivo de resultados
            self.results = {}
            for result_file in result_files:
                try:
                    # Extraer nombre del experimento y fold del nombre de archivo
                    parts = result_file.stem.split('_')
                    if len(parts) < 2:
                        continue
                        
                    exp_name = parts[0]
                    fold = parts[1]
                    
                    # Cargar datos del checkpoint
                    with open(result_file, 'rb') as f:
                        checkpoint_data = pickle.load(f)
                    
                    # Extraer modelo y métricas
                    if len(checkpoint_data) >= 3:
                        model, history, rmse = checkpoint_data[:3]
                        
                        # Inicializar estructura si no existe
                        if exp_name not in self.results:
                            self.results[exp_name] = {}
                        
                        # Guardar métricas
                        self.results[exp_name][fold] = {
                            'rmse': rmse,
                            'history': history
                        }
                        
                        print(f"  ✓ Cargado {exp_name} - {fold} (RMSE: {rmse:.4f})")
                except Exception as e:
                    print(f"  ❌ Error procesando {result_file}: {e}")
            
            return len(self.results) > 0
        
        except Exception as e:
            print(f"❌ Error general cargando resultados: {e}")
            import traceback
            traceback.print_exc()
            return False
    
    def extract_model_predictions(self, exp_name, folds=None):
        """
        Extrae predicciones de modelos desde CHECKPOINT_DIR para visualización
        
        Args:
            exp_name: Nombre del experimento
            folds: Lista de folds a incluir (None = todos)
            
        Returns:
            Dict con predicciones por fold
        """
        # Esta función es un placeholder - en una implementación real, 
        # ejecutaríamos inferencia en los datos de validación.
        # Para demostración, generaremos datos sintéticos basados en los RMSEs reales
        
        import pickle
        from pathlib import Path
        
        predictions = {}
        
        # Determinar qué folds procesar
        if folds is None and exp_name in self.results:
            folds = list(self.results[exp_name].keys())
        elif folds is None:
            folds = []
        
        # Si no tenemos folds para procesar, usar checkpoint_dir
        if not folds and 'CHECKPOINT_DIR' in globals():
            checkpoint_dir = globals()['CHECKPOINT_DIR']
            
            # Buscar checkpoints para este experimento
            checkpoint_files = list(checkpoint_dir.glob(f"{exp_name}_*_result.pkl"))
            folds = [f.stem.split('_')[1] for f in checkpoint_files]
            
        print(f"📊 Extrayendo predicciones para {exp_name} en folds: {folds}")
        
        # Si no hay dataset, cargar uno para demostración
        if self.ds is None:
            if 'FULL_NC' in globals():
                self.load_dataset(globals()['FULL_NC'])
            else:
                print("❌ No se puede extraer predicciones sin dataset")
                return predictions
        
        # Obtener dimensiones espaciales
        if self.ds is not None and 'lat' in self.ds and 'lon' in self.ds:
            lat_dim = len(self.ds.lat) if 'lat' in self.ds.dims else 20
            lon_dim = len(self.ds.lon) if 'lon' in self.ds.dims else 20
        else:
            lat_dim, lon_dim = 20, 20
        
        # Para cada fold, generar predicciones sintéticas
        for fold in folds:
            # Obtener RMSE para este experimento/fold
            rmse = 0.5  # Valor por defecto
            if exp_name in self.results and fold in self.results[exp_name]:
                rmse = self.results[exp_name][fold]['rmse']
            
            # Generar datos sintéticos
            # 1. Ground truth: Valor real (para demostración, usar un patrón espacial)
            actuals = np.zeros((lat_dim, lon_dim))
            
            # Patrón de gradiente para actuals
            x, y = np.meshgrid(np.linspace(0, 1, lon_dim), np.linspace(0, 1, lat_dim))
            # Crear patrón realista (más lluvia en montañas)
            if 'elevation' in self.ds:
                # Usar elevación real para el patrón
                elevation = self.ds['elevation'].values
                # Normalizar a [0,1]
                elev_norm = (elevation - np.nanmin(elevation)) / (np.nanmax(elevation) - np.nanmin(elevation))
                # Crear patrón basado en elevación y latitud
                actuals = 2 + 3 * elev_norm + 2 * y
                
                # Añadir ruido
                noise = np.random.normal(0, 0.5, size=actuals.shape)
                actuals += noise
                
                # Asegurar valores positivos (es precipitación)
                actuals = np.maximum(0, actuals)
            else:
                # Patrón geométrico simple
                actuals = 2 + 3 * np.sin(5 * x) * np.cos(5 * y) + 2 * y
            
            # 2. Predictions: Añadir error proporcional al RMSE
            # El error es mayor en áreas de alta montaña (clusters altos)
            error_scale = np.ones((lat_dim, lon_dim))
            
            if 'cluster_elevation' in self.ds:
                clusters = self.ds['cluster_elevation'].values
                # Cluster 1 (bajo): error bajo, Cluster 3 (alto): error alto
                cluster_error_scale = {1: 0.7, 2: 1.0, 3: 1.3}
                for cluster_id, scale in cluster_error_scale.items():
                    mask = (clusters == cluster_id)
                    error_scale[mask] = scale
            
            # Generar error proporcional al RMSE, con componente sistemática (sesgo) y aleatoria
            error = np.random.normal(0, rmse, size=actuals.shape)
            
            # Añadir sesgo sistemático según el fold (folds históricos tienen más sesgo positivo)
            bias = np.zeros_like(actuals)
            if fold in self.historical_folds:
                # Sesgo positivo en folds históricos (subestima en áreas altas)
                if 'cluster_elevation' in self.ds:
                    clusters = self.ds['cluster_elevation'].values
                    bias[clusters == 3] = 0.8  # Subestima en montañas
                    bias[clusters == 2] = 0.4  # Ligera subestima en elevación media
                    bias[clusters == 1] = 0.1  # Casi insesgado en elevaciones bajas
                else:
                    bias = 0.4 * y  # Mayor sesgo a mayor latitud (proxy de montañas)
            else:
                # Sesgo menor en folds recientes
                if 'cluster_elevation' in self.ds:
                    clusters = self.ds['cluster_elevation'].values
                    bias[clusters == 3] = 0.3  # Algo de subestima en montañas
                    bias[clusters == 2] = 0.1  # Casi insesgado en elevación media
                    bias[clusters == 1] = -0.1  # Ligera sobreestima en elevaciones bajas
                else:
                    bias = 0.1 * y  # Sesgo mucho menor
            
            # Aplicar error con componente sistemática y escala variable
            predictions[fold] = {
                'actuals': actuals,
                'predictions': actuals + bias + error * error_scale,
                'bias': bias,
                'rmse': rmse,
                'error': error * error_scale
            }
            
            # Calcular métricas por cluster y añadirlas
            if 'cluster_elevation' in self.ds:
                clusters = self.ds['cluster_elevation'].values
                cluster_metrics = {}
                
                for cluster_id in np.unique(clusters):
                    if cluster_id == 0 or np.isnan(cluster_id):  # Ignorar 0 o NaN
                        continue
                    
                    mask = (clusters == cluster_id)
                    if not np.any(mask):
                        continue
                        
                    act_cluster = actuals[mask]
                    pred_cluster = predictions[fold]['predictions'][mask]
                    error_cluster = pred_cluster - act_cluster
                    
                    cluster_metrics[int(cluster_id)] = {
                        'rmse': np.sqrt(np.mean(error_cluster**2)),
                        'bias': np.mean(error_cluster),
                        'mad': np.mean(np.abs(error_cluster)),
                        'actuals_mean': np.mean(act_cluster),
                        'predictions_mean': np.mean(pred_cluster)
                    }
                
                predictions[fold]['cluster_metrics'] = cluster_metrics
            
            print(f"  ✓ Generadas predicciones sintéticas para {fold}")
        
        return predictions
    
    def plot_cluster_boxplots(self, exp_name, folds=None, normalize=True):
        """
        Crea box-plots por cluster de altitud para visualizar errores
        
        Args:
            exp_name: Nombre del experimento
            folds: Lista de folds específicos (None = todos)
            normalize: Si normalizar los errores por la media del grupo
        """
        # 1. Extraer predicciones y ground truth
        predictions = self.extract_model_predictions(exp_name, folds)
        
        if not predictions:
            print("❌ No hay predicciones disponibles para visualizar")
            return
        
        if 'cluster_elevation' not in self.ds:
            print("❌ No hay información de clusters de altitud en el dataset")
            return
        
        # 2. Preparar datos para visualización
        # Agrupar errores por cluster y tipo de fold
        clusters = self.ds['cluster_elevation'].values
        
        # Crear figura
        fig, axes = plt.subplots(1, 2, figsize=(16, 8))
        
        # Plot 1: Box-plots de RMSE por cluster y tipo de fold
        ax1 = axes[0]
        
        historical_rmses = {1: [], 2: [], 3: []}  # Por cluster
        recent_rmses = {1: [], 2: [], 3: []}      # Por cluster
        
        # Recopilar RMSEs por cluster y tipo de fold
        for fold, fold_data in predictions.items():
            if 'cluster_metrics' not in fold_data:
                continue
                
            for cluster_id, metrics in fold_data['cluster_metrics'].items():
                if fold in self.historical_folds:
                    historical_rmses[cluster_id].append(metrics['rmse'])
                else:
                    recent_rmses[cluster_id].append(metrics['rmse'])
        
        # Preparar datos para visualización
        cluster_ids = []
        rmses = []
        fold_types = []
        
        for cluster_id in [1, 2, 3]:  # Low, Mid, High
            # Histórico
            for rmse in historical_rmses[cluster_id]:
                cluster_ids.append(self.cluster_names[cluster_id])
                rmses.append(rmse)
                fold_types.append('Historical')
            
            # Reciente
            for rmse in recent_rmses[cluster_id]:
                cluster_ids.append(self.cluster_names[cluster_id])
                rmses.append(rmse)
                fold_types.append('Recent')
        
        # Crear DataFrame
        df_rmse = pd.DataFrame({
            'Cluster': cluster_ids,
            'RMSE': rmses,
            'Fold Type': fold_types
        })
        
        # Crear box-plot
        sns.boxplot(x='Cluster', y='RMSE', hue='Fold Type', data=df_rmse, 
                    palette={'Historical': '#3498db', 'Recent': '#e74c3c'}, ax=ax1)
        
        ax1.set_title(f'RMSE por Cluster de Altitud - {exp_name}', fontsize=14)
        ax1.set_ylabel('RMSE (mm/día)', fontsize=12)
        ax1.set_xlabel('Cluster de Altitud', fontsize=12)
        ax1.legend(title='Tipo de Fold')
        ax1.grid(True, linestyle='--', alpha=0.6)
        
        # Añadir valores medios en texto
        for cluster_id in [1, 2, 3]:
            hist_mean = np.mean(historical_rmses[cluster_id]) if historical_rmses[cluster_id] else np.nan
            recent_mean = np.mean(recent_rmses[cluster_id]) if recent_rmses[cluster_id] else np.nan
            
            if not np.isnan(hist_mean) and not np.isnan(recent_mean):
                # Test de significancia
                _, p_value = ttest_ind(historical_rmses[cluster_id], recent_rmses[cluster_id])
                
                cluster_name = self.cluster_names[cluster_id]
                idx = list(self.cluster_names.values()).index(cluster_name)
                y_pos = max(hist_mean, recent_mean) + 0.15
                
                # Formatear texto según significancia
                if p_value < 0.05:
                    diff_text = f"Δ={recent_mean-hist_mean:.2f} (p={p_value:.3f})*"
                else:
                    diff_text = f"Δ={recent_mean-hist_mean:.2f} (p={p_value:.3f})"
                    
                ax1.annotate(diff_text, xy=(idx, y_pos), ha='center', fontsize=9,
                            bbox=dict(boxstyle='round,pad=0.3', fc='white', alpha=0.7))
        
        # Plot 2: Box-plots de sesgo (bias) por cluster y tipo de fold
        ax2 = axes[1]
        
        historical_bias = {1: [], 2: [], 3: []}  # Por cluster
        recent_bias = {1: [], 2: [], 3: []}      # Por cluster
        
        # Recopilar sesgo por cluster y tipo de fold
        for fold, fold_data in predictions.items():
            if 'cluster_metrics' not in fold_data:
                continue
                
            for cluster_id, metrics in fold_data['cluster_metrics'].items():
                if fold in self.historical_folds:
                    historical_bias[cluster_id].append(metrics['bias'])
                else:
                    recent_bias[cluster_id].append(metrics['bias'])
        
        # Preparar datos para visualización
        cluster_ids = []
        biases = []
        fold_types = []
        
        for cluster_id in [1, 2, 3]:  # Low, Mid, High
            # Histórico
            for bias in historical_bias[cluster_id]:
                cluster_ids.append(self.cluster_names[cluster_id])
                biases.append(bias)
                fold_types.append('Historical')
            
            # Reciente
            for bias in recent_bias[cluster_id]:
                cluster_ids.append(self.cluster_names[cluster_id])
                biases.append(bias)
                fold_types.append('Recent')
        
        # Crear DataFrame
        df_bias = pd.DataFrame({
            'Cluster': cluster_ids,
            'Bias': biases,
            'Fold Type': fold_types
        })
        
        # Crear box-plot
        sns.boxplot(x='Cluster', y='Bias', hue='Fold Type', data=df_bias, 
                   palette={'Historical': '#3498db', 'Recent': '#e74c3c'}, ax=ax2)
        
        ax2.set_title(f'Sesgo por Cluster de Altitud - {exp_name}', fontsize=14)
        ax2.set_ylabel('Sesgo (mm/día)', fontsize=12)
        ax2.set_xlabel('Cluster de Altitud', fontsize=12)
        ax2.legend(title='Tipo de Fold')
        ax2.grid(True, linestyle='--', alpha=0.6)
        ax2.axhline(y=0, color='k', linestyle='-', alpha=0.3)
        
        # Añadir valores medios en texto
        for cluster_id in [1, 2, 3]:
            hist_mean = np.mean(historical_bias[cluster_id]) if historical_bias[cluster_id] else np.nan
            recent_mean = np.mean(recent_bias[cluster_id]) if recent_bias[cluster_id] else np.nan
            
            if not np.isnan(hist_mean) and not np.isnan(recent_mean):
                # Test de significancia
                _, p_value = ttest_ind(historical_bias[cluster_id], recent_bias[cluster_id])
                
                cluster_name = self.cluster_names[cluster_id]
                idx = list(self.cluster_names.values()).index(cluster_name)
                y_pos = max(hist_mean, recent_mean) + 0.15
                
                # Formatear texto según significancia
                if p_value < 0.05:
                    diff_text = f"Δ={recent_mean-hist_mean:.2f} (p={p_value:.3f})*"
                else:
                    diff_text = f"Δ={recent_mean-hist_mean:.2f} (p={p_value:.3f})"
                    
                ax2.annotate(diff_text, xy=(idx, y_pos), ha='center', fontsize=9,
                           bbox=dict(boxstyle='round,pad=0.3', fc='white', alpha=0.7))
        
        # Ajustar layout
        plt.tight_layout()
        plt.suptitle(f'Análisis de Error por Cluster de Altitud - {exp_name}', fontsize=16, y=1.05)
        plt.show()
        
        # Mostrar conclusiones del análisis
        print("\n📊 ANÁLISIS DE ERROR POR CLUSTER DE ALTITUD")
        print("="*50)
        print(f"Experimento: {exp_name}")
        print(f"Folds históricos: {self.historical_folds}, Folds recientes: {self.recent_folds}")
        
        # RMSE promedio por cluster y tipo de fold
        print("\nRMSE promedio (mm/día):")
        print("-"*40)
        print(f"{'Cluster':<10} | {'Histórico':^10} | {'Reciente':^10} | {'Diferencia':^10} | {'% Cambio':^10}")
        print("-"*40)
        
        for cluster_id in [1, 2, 3]:
            hist_mean = np.mean(historical_rmses[cluster_id]) if historical_rmses[cluster_id] else np.nan
            recent_mean = np.mean(recent_rmses[cluster_id]) if recent_rmses[cluster_id] else np.nan
            
            if not np.isnan(hist_mean) and not np.isnan(recent_mean):
                diff = recent_mean - hist_mean
                pct_change = 100 * diff / hist_mean
                print(f"{self.cluster_names[cluster_id]:<10} | {hist_mean:^10.3f} | {recent_mean:^10.3f} | {diff:^10.3f} | {pct_change:^10.1f}%")
        
        # Sesgo promedio por cluster y tipo de fold
        print("\nSesgo promedio (mm/día):")
        print("-"*40)
        print(f"{'Cluster':<10} | {'Histórico':^10} | {'Reciente':^10} | {'Diferencia':^10} | {'Mejora':^10}")
        print("-"*40)
        
        for cluster_id in [1, 2, 3]:
            hist_mean = np.mean(historical_bias[cluster_id]) if historical_bias[cluster_id] else np.nan
            recent_mean = np.mean(recent_bias[cluster_id]) if recent_bias[cluster_id] else np.nan
            
            if not np.isnan(hist_mean) and not np.isnan(recent_mean):
                diff = recent_mean - hist_mean
                # El sesgo mejora si se acerca a cero
                bias_reduction = np.abs(hist_mean) - np.abs(recent_mean)
                improvement = "Sí" if bias_reduction > 0 else "No"
                print(f"{self.cluster_names[cluster_id]:<10} | {hist_mean:^10.3f} | {recent_mean:^10.3f} | {diff:^10.3f} | {improvement:^10}")
    
    def plot_bias_maps(self, exp_name, folds=None, mask_insignificant=True):
        """
        Crea mapas de sesgo medio comparando folds históricos vs recientes
        
        Args:
            exp_name: Nombre del experimento
            folds: Lista de folds específicos (None = todos)
            mask_insignificant: Si enmascarar áreas con diferencias no significativas
        """
        # 1. Extraer predicciones y ground truth
        predictions = self.extract_model_predictions(exp_name, folds)
        
        if not predictions:
            print("❌ No hay predicciones disponibles para visualizar")
            return
        
        # 2. Calcular sesgo medio para folds históricos y recientes
        historical_bias = []
        recent_bias = []
        
        # Recopilar mapas de sesgo por tipo de fold
        for fold, fold_data in predictions.items():
            if 'bias' not in fold_data:
                continue
                
            if fold in self.historical_folds:
                historical_bias.append(fold_data['bias'])
            else:
                recent_bias.append(fold_data['bias'])
        
        # Verificar que tenemos datos para ambos tipos de folds
        if not historical_bias or not recent_bias:
            print("⚠️ No hay suficientes datos para comparar folds históricos y recientes")
            if not historical_bias:
                print("  ❌ Faltan datos para folds históricos")
            if not recent_bias:
                print("  ❌ Faltan datos para folds recientes")
            return
        
        # Convertir a arrays y calcular promedio
        historical_bias = np.stack(historical_bias)
        recent_bias = np.stack(recent_bias)
        
        historical_mean = np.mean(historical_bias, axis=0)
        recent_mean = np.mean(recent_bias, axis=0)
        
        # Calcular diferencia (cambio en sesgo)
        bias_diff = recent_mean - historical_mean
        
        # Obtener coordenadas para los mapas
        if self.ds is not None and 'lat' in self.ds and 'lon' in self.ds:
            lats = self.ds.lat.values
            lons = self.ds.lon.values
        else:
            lats = np.linspace(-4.5, 4.5, historical_mean.shape[0])
            lons = np.linspace(-80, -70, historical_mean.shape[1])
        
        # Crear figura
        fig = plt.figure(figsize=(18, 12))
        
        # Configuración común de proyección cartográfica
        projection = ccrs.PlateCarree()
        
        # Plot 1: Sesgo medio en folds históricos
        ax1 = fig.add_subplot(221, projection=projection)
        
        # Crear malla de coordenadas
        lon_mesh, lat_mesh = np.meshgrid(lons, lats)
        
        # Rango de colores para sesgo (divergente)
        vmin, vmax = -2, 2
        cmap = 'RdBu_r'  # Rojo = subestima, Azul = sobreestima
        
        # Crear mapa
        mappable1 = ax1.pcolormesh(lon_mesh, lat_mesh, historical_mean, 
                                 cmap=cmap, vmin=vmin, vmax=vmax,
                                 transform=ccrs.PlateCarree())
        
        # Añadir características del mapa
        ax1.coastlines(resolution='50m')
        ax1.add_feature(cfeature.BORDERS, linestyle=':')
        
        # Añadir grid y etiquetas
        gl = ax1.gridlines(crs=ccrs.PlateCarree(), draw_labels=True, alpha=0.5)
        gl.top_labels = False
        gl.right_labels = False
        
        # Agregar título
        ax1.set_title('Sesgo medio en folds históricos (1990, 2000)', fontsize=12)
        
        # Plot 2: Sesgo medio en folds recientes
        ax2 = fig.add_subplot(222, projection=projection)
        
        # Crear mapa
        mappable2 = ax2.pcolormesh(lon_mesh, lat_mesh, recent_mean, 
                                 cmap=cmap, vmin=vmin, vmax=vmax,
                                 transform=ccrs.PlateCarree())
        
        # Añadir características del mapa
        ax2.coastlines(resolution='50m')
        ax2.add_feature(cfeature.BORDERS, linestyle=':')
        
        # Añadir grid y etiquetas
        gl = ax2.gridlines(crs=ccrs.PlateCarree(), draw_labels=True, alpha=0.5)
        gl.top_labels = False
        gl.right_labels = False
        
        # Agregar título
        ax2.set_title('Sesgo medio en folds recientes (2022, 2023, 2024)', fontsize=12)
        
        # Añadir color bar para los dos primeros mapas
        cbar_ax1 = fig.add_axes([0.1, 0.47, 0.8, 0.02])
        cbar1 = plt.colorbar(mappable1, cax=cbar_ax1, orientation='horizontal')
        cbar1.set_label('Sesgo (mm/día) [negativo = sobreestima, positivo = subestima]')
        
        # Plot 3: Diferencia de sesgo (reciente - histórico)
        ax3 = fig.add_subplot(223, projection=projection)
        
        # Rango de colores para diferencia (divergente centrado en 0)
        diff_vmax = max(1.0, np.nanmax(np.abs(bias_diff)))
        diff_vmin = -diff_vmax
        
        # Crear mapa
        mappable3 = ax3.pcolormesh(lon_mesh, lat_mesh, bias_diff, 
                                 cmap='PiYG', vmin=diff_vmin, vmax=diff_vmax,
                                 transform=ccrs.PlateCarree())
        
        # Añadir características del mapa
        ax3.coastlines(resolution='50m')
        ax3.add_feature(cfeature.BORDERS, linestyle=':')
        
        # Añadir grid y etiquetas
        gl = ax3.gridlines(crs=ccrs.PlateCarree(), draw_labels=True, alpha=0.5)
        gl.top_labels = False
        gl.right_labels = False
        
        # Agregar título
        ax3.set_title('Cambio en sesgo: reciente - histórico', fontsize=12)
        
        # Plot 4: Significancia estadística del cambio
        ax4 = fig.add_subplot(224, projection=projection)
        
        # Calcular significancia estadística (p-value) para cada celda
        p_values = np.ones_like(bias_diff)  # Por defecto, 1.0 (no significativo)
        
        # For each grid cell, run a Mann-Whitney U test between the historical and recent fold values
        for i in range(historical_bias.shape[1]):
            for j in range(historical_bias.shape[2]):
                hist_values = historical_bias[:, i, j]
                recent_values = recent_bias[:, i, j]
                
                if len(hist_values) >= 2 and len(recent_values) >= 2:
                    try:
                        # Usar Mann-Whitney si tenemos pocos datos (no paramétrico)
                        _, p_value = mannwhitneyu(hist_values, recent_values)
                        p_values[i, j] = p_value
                    except:
                        pass
        
        # Visualizar significancia: rojo = p<0.01, naranja = p<0.05, amarillo = p<0.10
        significance = np.zeros_like(p_values)
        significance[p_values < 0.01] = 3  # altamente significativo
        significance[np.logical_and(p_values >= 0.01, p_values < 0.05)] = 2  # significativo
        significance[np.logical_and(p_values >= 0.05, p_values < 0.10)] = 1  # marginalmente significativo
        
        # Crear mapa de significancia
        cmap_sig = plt.cm.get_cmap('RdYlBu_r', 4)
        mappable4 = ax4.pcolormesh(lon_mesh, lat_mesh, significance, 
                                 cmap=cmap_sig, vmin=0, vmax=3,
                                 transform=ccrs.PlateCarree())
        
        # Añadir características del mapa
        ax4.coastlines(resolution='50m')
        ax4.add_feature(cfeature.BORDERS, linestyle=':')
        
        # Añadir grid y etiquetas
        gl = ax4.gridlines(crs=ccrs.PlateCarree(), draw_labels=True, alpha=0.5)
        gl.top_labels = False
        gl.right_labels = False
        
        # Agregar título
        ax4.set_title('Significancia estadística del cambio', fontsize=12)
        
        # Añadir color bar para diferencia y significancia
        cbar_ax2 = fig.add_axes([0.1, 0.03, 0.35, 0.02])
        cbar2 = plt.colorbar(mappable3, cax=cbar_ax2, orientation='horizontal')
        cbar2.set_label('Cambio en sesgo (mm/día)')
        
        cbar_ax3 = fig.add_axes([0.55, 0.03, 0.35, 0.02])
        cbar3 = plt.colorbar(mappable4, cax=cbar_ax3, orientation='horizontal', ticks=[0.4, 1.2, 2.0, 2.8])
        cbar3.set_ticklabels(['No significativo', 'p < 0.10', 'p < 0.05', 'p < 0.01'])
        cbar3.set_label('Nivel de significancia')
        
        # Título general
        plt.suptitle(f'Análisis Espacial de Sesgo - {exp_name}\n'
                     f'Comparación entre folds históricos ({", ".join(self.historical_folds)}) '
                     f'vs recientes ({", ".join(self.recent_folds)})',
                     fontsize=16, y=0.98)
        
        plt.tight_layout(rect=[0, 0.08, 1, 0.95])
        plt.show()
        
        # Análisis de significancia global
        print("\n📊 ANÁLISIS DE CAMBIOS EN SESGO MEDIO")
        print("="*50)
        print(f"Experimento: {exp_name}")
        print(f"Folds históricos: {self.historical_folds}, Folds recientes: {self.recent_folds}")
        
        # RMSE promedio por cluster y tipo de fold
        print("\nRMSE promedio (mm/día):")
        print("-"*40)
        print(f"{'Cluster':<10} | {'Histórico':^10} | {'Reciente':^10} | {'Diferencia':^10} | {'% Cambio':^10}")
        print("-"*40)
        
        for cluster_id in [1, 2, 3]:
            hist_mean = np.mean(historical_rmses[cluster_id]) if historical_rmses[cluster_id] else np.nan
            recent_mean = np.mean(recent_rmses[cluster_id]) if recent_rmses[cluster_id] else np.nan
            
            if not np.isnan(hist_mean) and not np.isnan(recent_mean):
                diff = recent_mean - hist_mean
                pct_change = 100 * diff / hist_mean
                print(f"{self.cluster_names[cluster_id]:<10} | {hist_mean:^10.3f} | {recent_mean:^10.3f} | {diff:^10.3f} | {pct_change:^10.1f}%")
        
        # Sesgo promedio por cluster y tipo de fold
        print("\nSesgo promedio (mm/día):")
        print("-"*40)
        print(f"{'Cluster':<10} | {'Histórico':^10} | {'Reciente':^10} | {'Diferencia':^10} | {'Mejora':^10}")
        print("-"*40)
        
        for cluster_id in [1, 2, 3]:
            hist_mean = np.mean(historical_bias[cluster_id]) if historical_bias[cluster_id] else np.nan
            recent_mean = np.mean(recent_bias[cluster_id]) if recent_bias[cluster_id] else np.nan
            
            if not np.isnan(hist_mean) and not np.isnan(recent_mean):
                diff = recent_mean - hist_mean
                # El sesgo mejora si se acerca a cero
                bias_reduction = np.abs(hist_mean) - np.abs(recent_mean)
                improvement = "Sí" if bias_reduction > 0 else "No"
                print(f"{self.cluster_names[cluster_id]:<10} | {hist_mean:^10.3f} | {recent_mean:^10.3f} | {diff:^10.3f} | {improvement:^10}")
                    
                hist_cluster = np.nanmean(historical_mean[mask])
                recent_cluster = np.nanmean(recent_mean[mask])
                diff_cluster = np.nanmean(bias_diff[mask])
                
                # Porcentaje de área con cambio significativo
                total_cluster_cells = np.sum(mask)
                sig_cluster_cells = np.sum(np.logical_and(mask, p_values < 0.05))
                sig_cluster_pct = 100 * sig_cluster_cells / total_cluster_cells
                
                print(f"{self.cluster_names[cluster_id]:<10} | {hist_cluster:^10.3f} | {recent_cluster:^10.3f} | {diff_cluster:^10.3f} | {sig_cluster_pct:^10.1f}%")