<a href="https://colab.research.google.com/github/ninja-marduk/ml_precipitation_prediction/blob/feature%2Fhybrid-models/models/hybrid_models_enconders_layering_w3_ST-HybridWaveStack.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# -*- coding: utf-8 -*-
"""
Entrenamiento Multi‐rama con GRU encoder–decoder y Transformer para low,
validación y forecast parametrizables, meta‐modelo XGBoost (stacking all H=1–3),
paralelización, trazabilidad y límites del departamento de Boyacá.
"""

import sys
from pathlib import Path
import warnings
import logging

# Función para verificar disponibilidad de datos para lags de precipitación
def verify_precipitation_lags(ds, required_lags=None, min_valid_ratio=0.90):
    """
    Check that the precipitation-lag variables hold enough non-NaN data.

    Args:
        ds: Dataset-like object exposing ``data_vars`` (supporting ``in``) and
            ``ds[name].values`` returning a NumPy array, e.g. an xarray Dataset.
        required_lags: Lag variable names that must be present. When None or
            empty, every known lag found in ``ds.data_vars`` is checked.
        min_valid_ratio: Minimum fraction of non-NaN values each lag must have.

    Returns:
        True when every checked lag passes.

    Raises:
        ValueError: If no lag variable is found, a required lag is missing,
            a lag variable has no elements, or its valid-data ratio is below
            ``min_valid_ratio``.
    """
    # Resolved locally because this function is defined before the module-level
    # `logger` is created; getLogger returns the same logger for the same name.
    log = logging.getLogger(__name__)

    # Known precipitation-lag variable names.
    all_possible_lags = [
        "total_precipitation_lag1", "total_precipitation_lag2",
        "total_precipitation_lag3", "total_precipitation_lag4",
        "total_precipitation_lag12", "total_precipitation_lag24",
        "total_precipitation_lag36"
    ]

    # Decide which lags to verify.
    lags_to_check = required_lags if required_lags else [lag for lag in all_possible_lags if lag in ds.data_vars]

    if not lags_to_check:
        raise ValueError("No se encontraron variables de lag de precipitación en el dataset")

    log.info(f"Verificando disponibilidad de datos para {len(lags_to_check)} lags de precipitación")

    for lag in lags_to_check:
        if lag not in ds.data_vars:
            raise ValueError(f"El lag requerido {lag} no está disponible en el dataset")

        lag_data = ds[lag].values
        total_elements = lag_data.size

        # An empty variable used to yield 0/0 = NaN, and NaN < threshold is
        # False, so it passed silently. Treat it as an explicit failure.
        if total_elements == 0:
            raise ValueError(f"El lag {lag} no contiene datos")

        # Fraction of non-NaN values.
        valid_elements = total_elements - int(np.isnan(lag_data).sum())
        valid_ratio = valid_elements / total_elements

        log.info(f"Lag {lag}: {valid_ratio:.2%} de datos válidos ({valid_elements}/{total_elements})")

        if valid_ratio < min_valid_ratio:
            raise ValueError(
                f"Insuficientes datos válidos para {lag}. "
                f"Disponible: {valid_ratio:.2%}, Requerido: {min_valid_ratio:.2%}"
            )

    log.info("✅ Verificación de lags de precipitación completada con éxito")
    return True

# 0) Detect the runtime environment (local checkout vs. Google Colab)
IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    # On Colab: mount Google Drive and use the project folder stored there.
    from google.colab import drive
    drive.mount("/content/drive", force_remount=True)
    BASE_PATH = Path("/content/drive/MyDrive/ml_precipitation_prediction")
    # Notebook shell escape (not plain Python): installs the listed packages.
    # NOTE(review): `cartopy` appears twice in the install command.
    !pip install -q xarray netCDF4 optuna seaborn cartopy xgboost ace_tools_open cartopy
else:
    # Locally: start at the working directory and walk up to the first
    # directory containing `.git`; if none is found the working directory is kept.
    BASE_PATH = Path.cwd()
    for p in [BASE_PATH, *BASE_PATH.parents]:
        if (p/".git").exists():
            BASE_PATH = p
            break
print(f"▶️ Base path: {BASE_PATH}")

# 1) Silence warnings that are not relevant to this run
from sklearn.exceptions import ConvergenceWarning
warnings.filterwarnings("ignore", category=ConvergenceWarning)
from cartopy.io import DownloadWarning
warnings.filterwarnings("ignore", category=DownloadWarning)
import tensorflow as tf
# Set the TensorFlow logger's level to ERROR.
tf.get_logger().setLevel("ERROR")

# 2) Configurable parameters
INPUT_WINDOW    = 48          # input months, chosen from the ACF/PACF analysis
OUTPUT_HORIZON  = 3           # months used for validation and forecast
REF_DATE        = "2025-03"   # reference date, yyyy-mm
MAX_EPOCHS      = 300         # epochs passed to model.fit
PATIENCE_ES     = 30          # patience of the EarlyStopping callback
LR_FACTOR       = 0.5         # factor of the ReduceLROnPlateau callback
LR_PATIENCE     = 10          # patience of the ReduceLROnPlateau callback
DROPOUT         = 0.1         # default dropout of build_gru_ed

# 3) Paths and logger
# Directory for trained models; created if missing.
MODEL_DIR    = BASE_PATH/"models"/"output"/"trained_models"
MODEL_DIR.mkdir(parents=True, exist_ok=True)
# NetCDF inputs: per-branch features and the full dataset.
FEATURES_NC  = BASE_PATH/"models"/"output"/"features_fusion_branches.nc"
FULL_NC      = BASE_PATH/"data"/"output"/"complete_dataset_with_features_with_clusters_elevation_with_windows.nc"
# Department shapefile: use the copy under /mnt/data when it exists,
# otherwise the one inside the project tree.
SHP_USER     = Path("/mnt/data/MGN_Departamento.shp")
BOYACA_SHP   = SHP_USER if SHP_USER.exists() else BASE_PATH/"data"/"input"/"shapes"/"MGN_Departamento.shp"
# Metrics CSV name carries the horizon and the requested reference date.
RESULTS_CSV  = MODEL_DIR/f"metrics_w{OUTPUT_HORIZON}_ref{REF_DATE}.csv"
IMAGE_DIR    = MODEL_DIR/"images"
IMAGE_DIR.mkdir(parents=True, exist_ok=True)

# Module-level logger used by the helpers and the main loop below.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)

# 4) Imports principales
import numpy            as np
import pandas           as pd
import xarray           as xr
import geopandas        as gpd
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
import psutil
from joblib import cpu_count
from scipy.stats import skew

def print_progress(message, level=0, is_start=False, is_end=False):
    """
    Print `message` behind a marker chosen from the indentation level.

    Args:
        message: Text to print.
        level: 0 for top-level messages, 1 for the first indented level,
            any other value for the deepest level.
        is_start: At level 0, use the section-start marker.
        is_end: At level 0, use the section-end marker (only consulted when
            `is_start` is false).
    """
    # `is_start` / `is_end` only matter for level 0.
    if level == 1:
        marker = "   ⚪ "
    elif level != 0:
        marker = "     • "
    elif is_start:
        marker = "🔵 "
    elif is_end:
        marker = "✅ "
    else:
        marker = "➡️ "

    print(f"{marker}{message}")
from tensorflow.keras.layers import (
    Input, GRU, RepeatVector, TimeDistributed, Dense,
    MultiHeadAttention, Add, LayerNormalization, Flatten
)
from tensorflow.keras.models import Model
from tensorflow.keras import callbacks

# 5) Hardware resources
CORES     = cpu_count()
AVAIL_RAM = psutil.virtual_memory().available / (1024**3)  # bytes -> GiB
gpus      = tf.config.list_physical_devices("GPU")
USE_GPU   = bool(gpus)
if USE_GPU:
    # Memory growth is enabled on the first listed GPU only.
    tf.config.experimental.set_memory_growth(gpus[0], True)
    logger.info(f"🖥 GPU disponible: {gpus[0].name}")
else:
    # CPU only: set both TensorFlow thread-pool sizes to the core count.
    tf.config.threading.set_inter_op_parallelism_threads(CORES)
    tf.config.threading.set_intra_op_parallelism_threads(CORES)
    logger.info(f"⚙ CPU cores: {CORES}, RAM libre: {AVAIL_RAM:.1f} GB")

# 6) Modelos y utilitarios
def evaluate_metrics(y_true, y_pred):
    """
    Compute RMSE, MAE, MAPE (%) and R² over pairs where neither value is NaN.

    Args:
        y_true: Observed values; any array-like (list, tuple, ndarray).
        y_pred: Predicted values, same shape as `y_true`.

    Returns:
        Tuple ``(rmse, mae, mape, r2)``. All four are NaN when fewer than 10
        valid pairs remain. `mape` is NaN unless more than 10 targets are
        non-zero; `r2` is NaN when the target variance is at most 1e-10.
    """
    # Coerce so plain Python sequences work with the boolean-mask indexing below.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    # Drop pairs where either side is NaN.
    mask = ~(np.isnan(y_true) | np.isnan(y_pred))
    y_true, y_pred = y_true[mask], y_pred[mask]

    # Require a minimum number of valid pairs.
    if len(y_true) < 10:
        logger.warning(f"Insuficientes datos válidos para calcular métricas: {len(y_true)} < 10")
        return np.nan, np.nan, np.nan, np.nan

    residuals = y_true - y_pred
    rmse = np.sqrt(np.mean(residuals**2))
    mae = np.mean(np.abs(residuals))

    # MAPE uses non-zero targets only; the 1e-5 added to the denominator is
    # kept from the original formula so reported values stay comparable.
    nonzero_mask = y_true != 0
    if np.sum(nonzero_mask) > 10:
        mape = np.mean(np.abs(residuals[nonzero_mask] / (y_true[nonzero_mask] + 1e-5))) * 100
    else:
        mape = np.nan

    # R² is only meaningful when the target actually varies.
    var = np.var(y_true)
    if var > 1e-10:
        r2 = 1 - np.sum(residuals**2) / np.sum((y_true - np.mean(y_true))**2)
    else:
        r2 = np.nan

    return rmse, mae, mape, r2

def check_nans(arr, name="array"):
    """Summarize the NaN content of a NumPy array.

    Args:
        arr: Array to inspect (must expose `dtype` and `size`).
        name: Label copied into the summary.

    Returns:
        Dict with keys ``name``, ``nan_count``, ``total_elements``,
        ``nan_percentage`` and ``has_nans``. Arrays whose dtype is not a
        NumPy number are reported as having no NaNs.
    """
    if np.issubdtype(arr.dtype, np.number):
        n_nan = np.isnan(arr).sum()
        n_total = arr.size
        # Guard the percentage against empty arrays.
        pct = (n_nan / n_total) * 100 if n_total > 0 else 0
        found = n_nan > 0
    else:
        # NaN is not applicable to non-numeric data.
        n_nan, n_total, pct, found = 0, arr.size, 0.0, False

    return {
        "name": name,
        "nan_count": n_nan,
        "total_elements": n_total,
        "nan_percentage": pct,
        "has_nans": found,
    }

def replace_nans(arr, strategy="mean", fill_value=None):
    """Return `arr` with its NaNs replaced according to `strategy`.

    Args:
        arr: NumPy array. It is returned unchanged (same object) when it
            contains no NaN; otherwise a modified copy is returned.
        strategy: One of "mean", "median", "zero", "constant" or
            "interpolate". "interpolate" runs a linear interpolation over the
            flattened array and yields zeros when every element is NaN.
        fill_value: Replacement for "constant" (0.0 when None).

    Raises:
        ValueError: If NaNs are present and `strategy` is not recognized.
    """
    nan_mask = np.isnan(arr)
    if not nan_mask.any():
        return arr

    # Work on a copy so the caller's array is left untouched.
    filled = arr.copy()

    if strategy == "interpolate":
        # One path for every rank: interpolate over the flattened values and
        # hand the result back in the input's shape.
        flat = filled.reshape(-1)
        holes = np.isnan(flat)
        if holes.all():
            return np.zeros_like(filled)
        known = ~holes
        flat[holes] = np.interp(
            np.flatnonzero(holes),
            np.flatnonzero(known),
            flat[known]
        )
        return flat.reshape(filled.shape)

    if strategy == "mean":
        replacement = np.nanmean(arr)
    elif strategy == "median":
        replacement = np.nanmedian(arr)
    elif strategy == "zero":
        replacement = 0.0
    elif strategy == "constant":
        replacement = fill_value if fill_value is not None else 0.0
    else:
        raise ValueError(f"Estrategia '{strategy}' no reconocida")

    filled[nan_mask] = replacement
    return filled

class ScalerNaN:
    """Standardizing scaler whose statistics ignore NaNs.

    `fit` computes per-column mean and population variance (ddof=0) with the
    NaN-aware NumPy reductions. `transform` / `inverse_transform` apply
    ``(X - mean) / scale`` and its inverse by broadcasting; NaN entries stay
    NaN because NaN propagates through the arithmetic.
    """

    def __init__(self):
        # Populated by fit(); None until then.
        self.mean_ = None
        self.var_ = None
        self.scale_ = None

    def fit(self, X):
        """Learn column-wise mean and scale from `X` (rows are samples)."""
        self.mean_ = np.nanmean(X, axis=0)
        # ddof=0 for consistency with StandardScaler.
        self.var_ = np.nanvar(X, axis=0, ddof=0)
        # Near-constant columns get unit variance to avoid dividing by ~0.
        self.var_[self.var_ < 1e-10] = 1.0
        self.scale_ = np.sqrt(self.var_)
        return self

    def transform(self, X):
        """Return a standardized copy of `X`; NaNs are preserved."""
        X = np.asarray(X)
        return self._like_input((X - self.mean_) / self.scale_, X)

    def fit_transform(self, X):
        """Fit on `X` and return its standardized copy."""
        return self.fit(X).transform(X)

    def inverse_transform(self, X):
        """Map standardized values back to the original scale; NaNs are preserved."""
        X = np.asarray(X)
        return self._like_input(X * self.scale_ + self.mean_, X)

    @staticmethod
    def _like_input(values, reference):
        """Cast `values` to the input's dtype when that dtype is floating."""
        # Floating inputs keep their dtype (e.g. float32 stays float32).
        # Integer inputs are returned as floats instead of being truncated.
        if np.issubdtype(reference.dtype, np.floating):
            return values.astype(reference.dtype, copy=False)
        return values

class DataGenerator(tf.keras.utils.Sequence):
    """Keras Sequence that serves consecutive (X, Y) slices as float32 batches."""

    def __init__(self, X, Y, batch_size=32, **kwargs):
        """Store float32 copies of `X` and `Y`; extra kwargs go to the base class."""
        super().__init__(**kwargs)
        features = X.astype(np.float32)
        targets = Y.astype(np.float32)
        self.X = features
        self.Y = targets
        self.batch_size = batch_size

    def __len__(self):
        """Number of batches, counting a final partial batch."""
        n_batches = np.ceil(len(self.X) / self.batch_size)
        return int(n_batches)

    def __getitem__(self, idx):
        """Return the `idx`-th batch as an ``(X_batch, Y_batch)`` pair."""
        start = idx * self.batch_size
        window = slice(start, start + self.batch_size)
        return self.X[window], self.Y[window]

class TrainingProgressCallback(tf.keras.callbacks.Callback):
    """Keras callback that prints one progress line per finished epoch."""

    def __init__(self, model_name, total_epochs):
        """
        Args:
            model_name: Label shown in each progress line.
            total_epochs: Epoch count used as the denominator of the percentage.
        """
        super().__init__()
        self.model_name = model_name
        self.total_epochs = total_epochs
        self.current_epoch = 0

    def on_epoch_begin(self, epoch, logs=None):
        """Remember the index of the epoch that is starting."""
        self.current_epoch = epoch

    def on_epoch_end(self, epoch, logs=None):
        """Print epoch number, percentage and the logged loss / val_loss."""
        # The signature allows logs=None; treat that as "no metrics" rather
        # than failing on `None.get`.
        logs = logs or {}
        loss = logs.get('loss', 0.0)
        val_loss = logs.get('val_loss', 0.0)
        progress = (epoch + 1) / self.total_epochs * 100

        print_progress(
            f"Entrenamiento {self.model_name}: Época {epoch+1}/{self.total_epochs} " +
            f"({progress:.1f}%) - loss: {loss:.4f} - val_loss: {val_loss:.4f}",
            level=2
        )

# GRU-ED y Transformer-ED builders
from tensorflow.keras import backend as K

def build_gru_ed(input_shape, horizon, n_cells, latent=128, dropout=DROPOUT):
    """Build and compile a GRU encoder-decoder.

    Args:
        input_shape: Shape of one input sample, passed to `Input`.
        horizon: Number of output steps (the encoder state is repeated this
            many times before decoding).
        n_cells: Size of the Dense output applied at every output step.
        latent: Units of both GRU layers.
        dropout: Dropout rate of both GRU layers.

    Returns:
        A Keras `Model` compiled with the Adam optimizer and MSE loss.
    """
    encoder_in = Input(shape=input_shape)
    # Encoder: keep only the final GRU state.
    context = GRU(latent, dropout=dropout)(encoder_in)
    # Decoder: feed that state at every output step.
    decoder_in = RepeatVector(horizon)(context)
    decoded = GRU(latent, dropout=dropout, return_sequences=True)(decoder_in)
    y_hat = TimeDistributed(Dense(n_cells))(decoded)

    model = Model(encoder_in, y_hat)
    model.compile(optimizer="adam", loss="mse")
    return model


def build_transformer_ed(input_shape, horizon, n_cells,
                         head_size=64, num_heads=4, ff_dim=256, dropout=0.1):
    """Build and compile a single-block Transformer encoder with a dense head.

    Args:
        input_shape: Shape of one input sample; its last entry is the feature
            size, which the feed-forward block projects back to.
        horizon: Number of output steps.
        n_cells: Number of values predicted per output step.
        head_size: `key_dim` of the attention layer.
        num_heads: Number of attention heads.
        ff_dim: Width of the hidden feed-forward layer.
        dropout: Dropout rate applied inside the attention layer and to the
            feed-forward output. It was previously accepted but never used.

    Returns:
        A Keras `Model` with output shape ``(batch, horizon, n_cells)``,
        compiled with the Adam optimizer and MSE loss.
    """
    # Local import: the file-level Keras layer import does not include these.
    from tensorflow.keras.layers import Dropout, Reshape

    inp = Input(shape=input_shape)

    # Self-attention sub-block with residual connection.
    attn = MultiHeadAttention(num_heads=num_heads, key_dim=head_size,
                              dropout=dropout)(inp, inp)
    x = Add()([inp, attn])
    x = LayerNormalization(epsilon=1e-6)(x)

    # Position-wise feed-forward sub-block with residual connection.
    ff = Dense(ff_dim, activation="relu")(x)
    ff = Dense(input_shape[-1])(ff)
    ff = Dropout(dropout)(ff)
    x = Add()([x, ff])
    x = LayerNormalization(epsilon=1e-6)(x)

    # Dense head over the flattened sequence.
    x = Flatten()(x)
    x = Dense(horizon * n_cells)(x)
    # A Reshape layer replaces the backend reshape op on a symbolic tensor.
    out = Reshape((horizon, n_cells))(x)

    m = Model(inp, out)
    m.compile(optimizer="adam", loss="mse")
    return m


def build_gru_ed_low(input_shape, horizon, n_cells,
                     latent=256, dropout=0.1, use_transformer=True):
    """Build the model for the "low" branch.

    With `use_transformer` true the Transformer builder is tried first
    (ff_dim=512). If it raises `tf.errors.ResourceExhaustedError`, a warning
    is logged and the GRU encoder-decoder is built instead. With
    `use_transformer` false the GRU encoder-decoder is built directly.
    """
    model = None
    if use_transformer:
        try:
            model = build_transformer_ed(input_shape, horizon, n_cells,
                                         head_size=64, num_heads=4,
                                         ff_dim=512, dropout=dropout)
        except tf.errors.ResourceExhaustedError:
            logger.warning("OOM Transformer → usando GRU‐ED para low-branch")

    # Reached when the Transformer was not requested or ran out of memory.
    if model is None:
        model = build_gru_ed(input_shape, horizon, n_cells,
                             latent=latent, dropout=dropout)
    return model


def build_gru_ed_medium_high(input_shape, horizon, n_cells, latent=128, dropout=0.1, use_transformer=True):
    """Build the model for the "medium" and "high" branches.

    With `use_transformer` true the Transformer builder is tried first
    (ff_dim=512). If it raises `tf.errors.ResourceExhaustedError`, a warning
    is logged and the GRU encoder-decoder is built instead. With
    `use_transformer` false the GRU encoder-decoder is built directly.

    The OOM handler now covers the Transformer build only. Previously it also
    wrapped the GRU build, so a GRU out-of-memory error was logged as a
    Transformer failure and the same GRU build was retried.
    """
    if use_transformer:
        try:
            return build_transformer_ed(input_shape, horizon, n_cells,
                                        head_size=64, num_heads=4,
                                        ff_dim=512, dropout=dropout)
        except tf.errors.ResourceExhaustedError:
            logger.warning("OOM Transformer → usando GRU‐ED para medium/high")
    return build_gru_ed(input_shape, horizon, n_cells,
                        latent=latent, dropout=dropout)

# 7) Load the datasets and the shapefile
logger.info("📂 Cargando datasets…")
ds_full = xr.open_dataset(FULL_NC)
ds_feat = xr.open_dataset(FEATURES_NC)
boyaca_gdf = gpd.read_file(BOYACA_SHP)
# Work in EPSG:4326: declare it when the file carries no CRS,
# otherwise reproject to it.
if boyaca_gdf.crs is None:
    boyaca_gdf.set_crs(epsg=4326, inplace=True)
else:
    boyaca_gdf = boyaca_gdf.to_crs(epsg=4326)

# Extract the elevation values and compute the outer limits for the clusters
elev_flat = ds_full['elevation'].values.ravel()
elev_min = np.nanmin(elev_flat)
elev_max = np.nanmax(elev_flat)

# Elevation clusters: the outer bounds come from the dataset min/max, the
# inner bounds are fixed values (the original comment attributes them to a
# provided image).
# NOTE(review): the bands leave gaps (956-959 and 2263-2264), so an elevation
# inside a gap matches none of the masks below — confirm this is intended.
ELEVATION_CLUSTERS = {
    'nivel_1': (elev_min, 956.0),
    'nivel_2': (959.0, 2263.0),
    'nivel_3': (2264.0, elev_max)
}

print_progress(f"Clusters de elevación definidos:", level=1)
print_progress(f"nivel_1: {ELEVATION_CLUSTERS['nivel_1'][0]:.1f} - {ELEVATION_CLUSTERS['nivel_1'][1]:.1f} m", level=2)
print_progress(f"nivel_2: {ELEVATION_CLUSTERS['nivel_2'][0]:.1f} - {ELEVATION_CLUSTERS['nivel_2'][1]:.1f} m", level=2)
print_progress(f"nivel_3: {ELEVATION_CLUSTERS['nivel_3'][0]:.1f} - {ELEVATION_CLUSTERS['nivel_3'][1]:.1f} m", level=2)

# Boolean masks over the flattened elevation values, one per cluster
# (both bounds inclusive).
mask_nivel_1 = (elev_flat >= ELEVATION_CLUSTERS['nivel_1'][0]) & (elev_flat <= ELEVATION_CLUSTERS['nivel_1'][1])
mask_nivel_2 = (elev_flat >= ELEVATION_CLUSTERS['nivel_2'][0]) & (elev_flat <= ELEVATION_CLUSTERS['nivel_2'][1])
mask_nivel_3 = (elev_flat >= ELEVATION_CLUSTERS['nivel_3'][0]) & (elev_flat <= ELEVATION_CLUSTERS['nivel_3'][1])

print_progress(f"Puntos por cluster: nivel_1: {np.sum(mask_nivel_1)}, nivel_2: {np.sum(mask_nivel_2)}, nivel_3: {np.sum(mask_nivel_3)}", level=2)

# Reference month: the requested REF_DATE, clamped to the last month available.
times      = ds_full.time.values.astype("datetime64[M]")
user_ref   = np.datetime64(REF_DATE, "M")
last_avail = times[-1]
ref = last_avail if user_ref>last_avail else user_ref

# One validation label per horizon step, newest first (ref, ref-1, ...).
# This list was hard-coded to three entries although the main loop indexes it
# with range(OUTPUT_HORIZON); it now follows OUTPUT_HORIZON.
val_dates = [str(ref - np.timedelta64(i, 'M')) for i in range(OUTPUT_HORIZON)]
# Forecast labels: the OUTPUT_HORIZON months following the reference month.
fc_dates  = [str(ref + np.timedelta64(i+1, 'M')) for i in range(OUTPUT_HORIZON)]

# Position of the reference month on the time axis; fail with a clear message
# instead of a bare IndexError when it is absent (e.g. before the first month).
ref_matches = np.flatnonzero(times == ref)
if ref_matches.size == 0:
    raise ValueError(f"La fecha de referencia {ref} no existe en el eje temporal del dataset")
idx_ref = int(ref_matches[0])
lat     = ds_full.latitude.values
lon     = ds_full.longitude.values
# Every "<method>_<branch>" combination is looked up in ds_feat by the main loop.
METHODS = ["CEEMDAN","TVFEMD","FUSION"]
BRANCHES= ["high","medium","low"]

# Determine which precipitation-lag variables are present in ds_full
# NOTE(review): this list repeats `all_possible_lags` from
# verify_precipitation_lags; keep the two in sync.
LAG_FEATURES = [
    "total_precipitation_lag1",
    "total_precipitation_lag2",
    "total_precipitation_lag3",
    "total_precipitation_lag4",
    "total_precipitation_lag12",
    "total_precipitation_lag24",
    "total_precipitation_lag36"
]
available_lags = [lag for lag in LAG_FEATURES if lag in ds_full.data_vars]
logger.info(f"📊 Lags de precipitación disponibles: {available_lags}")

# Accumulators filled by the main loop
all_metrics = []   # one dict per (model, horizon, validation/forecast) row
preds_store = {}   # (model name, date) -> predicted grid
true_store  = {}   # (model name, date) -> observed grid (validation only)
histories   = {}   # model name -> Keras training history dict

# Training callbacks, shared by every model trained in the main loop
es_cb = callbacks.EarlyStopping("val_loss", patience=PATIENCE_ES, restore_best_weights=True)
lr_cb = callbacks.ReduceLROnPlateau("val_loss", factor=LR_FACTOR, patience=LR_PATIENCE, min_lr=1e-6)

# 8) Main loop
# For every "<method>_<branch>" variable present in ds_feat: build sliding
# windows, append extra features, scale, train (or load) a model, then store
# validation metrics and forecast grids. Any exception raised for one
# combination is logged and the loop moves on to the next.
print_progress(f"Iniciando procesamiento de {len(METHODS)} métodos × {len(BRANCHES)} branches con manejo robusto de NaNs", is_start=True)
total_combinations = len(METHODS) * len(BRANCHES)
processed = 0

for method in METHODS:
    for branch in BRANCHES:
        processed += 1
        name = f"{method}_{branch}"
        if name not in ds_feat.data_vars:
            print_progress(f"({processed}/{total_combinations}) ⚠️ {name} no existe, salteando...")
            continue

        print_progress(f"({processed}/{total_combinations}) Procesando {name}", is_start=True)
        try:
            # Extract the arrays
            Xarr = ds_feat[name].values            # (T, ny, nx)
            Yarr = ds_full["total_precipitation"].values  # (T, ny, nx)

            # Initial NaN check on input and target
            x_summary = check_nans(Xarr, f"Entrada {name}")
            y_summary = check_nans(Yarr, f"Objetivo {name}")

            if x_summary["has_nans"]:
                print_progress(f"⚠️ Detectados {x_summary['nan_count']} NaNs en entrada {name} ({x_summary['nan_percentage']:.2f}%)", level=1)
                Xarr = replace_nans(Xarr, strategy="interpolate")
                print_progress(f"NaNs reemplazados usando interpolación", level=2)

            if y_summary["has_nans"]:
                print_progress(f"⚠️ Detectados {y_summary['nan_count']} NaNs en objetivo {name} ({y_summary['nan_percentage']:.2f}%)", level=1)
                Yarr = replace_nans(Yarr, strategy="interpolate")
                print_progress(f"NaNs reemplazados usando interpolación", level=2)

            T, ny, nx = Xarr.shape
            n_cells   = ny * nx

            # Flatten the grid: one column per cell
            Xfull = Xarr.reshape(T, n_cells)
            yfull = Yarr.reshape(T, n_cells)

            # Sliding windows: INPUT_WINDOW input steps followed by
            # OUTPUT_HORIZON target steps
            Nw = T - INPUT_WINDOW - OUTPUT_HORIZON + 1
            if Nw <= 0:
                print_progress(f"❌ Ventanas insuficientes para {name}, continuando con el siguiente", level=1)
                continue

            print_progress(f"Generando {Nw} ventanas para {name}", level=1)

            Xs = np.stack([Xfull[i : i+INPUT_WINDOW] for i in range(Nw)], axis=0)
            ys = np.stack([yfull[i+INPUT_WINDOW : i+INPUT_WINDOW+OUTPUT_HORIZON]
                           for i in range(Nw)], axis=0)

            # ========================================================================
            # PRECIPITATION LAG FEATURES
            # ========================================================================
            print_progress(f"Procesando lags de precipitación para {name} de forma robusta", level=1)

            # Extra feature blocks, concatenated to Xs along the last axis below
            features_to_add = []

            # Month-of-year sin/cos, "low" branch only
            if branch == "low":
                months = pd.to_datetime(ds_full.time.values).month.values
                s = np.sin(2 * np.pi * months/12)
                c = np.cos(2 * np.pi * months/12)
                Ss = np.stack([s[i : i+INPUT_WINDOW] for i in range(Nw)], axis=0)
                Cs = np.stack([c[i : i+INPUT_WINDOW] for i in range(Nw)], axis=0)
                # Repeat the per-step value for every cell
                Ss = np.repeat(Ss[:,:,None], n_cells, axis=2)
                Cs = np.repeat(Cs[:,:,None], n_cells, axis=2)
                features_to_add.extend([Ss, Cs])
                logger.info(f"✓ Agregadas características estacionales sin/cos para branch {branch}")

            # Add the available precipitation lags as extra features
            if available_lags:
                logger.info(f"🔄 Agregando {len(available_lags)} lags de precipitación al branch {branch}")
                for lag_var in available_lags:
                    # Load the lag and check it for NaNs
                    lag_data = ds_full[lag_var].values
                    lag_summary = check_nans(lag_data, f"Lag {lag_var}")

                    # NaN handling depends on the NaN percentage:
                    # <5% interpolate, <20% mean, otherwise zeros
                    if lag_summary["has_nans"]:
                        print_progress(f"⚠️ {lag_var}: {lag_summary['nan_count']} NaNs ({lag_summary['nan_percentage']:.2f}%)", level=2)
                        if lag_summary["nan_percentage"] < 5:
                            lag_data = replace_nans(lag_data, strategy="interpolate")
                            print_progress(f"NaNs interpolados en {lag_var}", level=2)
                        elif lag_summary["nan_percentage"] < 20:
                            lag_data = replace_nans(lag_data, strategy="mean")
                            print_progress(f"NaNs reemplazados con media en {lag_var}", level=2)
                        else:
                            lag_data = replace_nans(lag_data, strategy="zero")
                            print_progress(f"⚠️ Demasiados NaNs en {lag_var}, reemplazando con ceros", level=2)

                    lag_full = lag_data.reshape(T, n_cells)
                    lag_windows = np.stack([lag_full[i : i+INPUT_WINDOW] for i in range(Nw)], axis=0)
                    features_to_add.append(lag_windows)
                logger.info(f"✓ Lags procesados robustamente: {available_lags}")

            # Concatenate all feature blocks along the feature axis
            if features_to_add:
                Xs = np.concatenate([Xs] + features_to_add, axis=2)
                n_feats = Xs.shape[2]
                print_progress(f"Estructura de features: {Xs.shape} ({n_feats} features totales)", level=1)
            else:
                n_feats = n_cells
                print_progress(f"Sin features adicionales: {Xs.shape}", level=1)

            # Additional variables from ds_full ("low" branch only)
            # NOTE(review): `add_vars_features` is filled below but is never
            # concatenated into Xs in the code shown here, so these variables
            # do not reach the model — confirm whether that is intended.
            if branch == "low":
                add_vars_features = []
                
                # year
                if 'year' in ds_full.data_vars:
                    print_progress(f"Añadiendo 'year' para {name}", level=2)
                    year_data = ds_full['year'].values
                    # Convert to float when the dtype is not numeric
                    if not np.issubdtype(year_data.dtype, np.number):
                        year_data = year_data.astype(float)
                    year_full = year_data.reshape(T, n_cells) if year_data.ndim > 1 else np.repeat(year_data[:, np.newaxis], n_cells, axis=1)
                    year_windows = np.stack([year_full[i : i+INPUT_WINDOW] for i in range(Nw)], axis=0)
                    add_vars_features.append(year_windows)
                
                # doy_sin, doy_cos (day of year mapped to cyclic coordinates)
                for doy_var in ['doy_sin', 'doy_cos']:
                    if doy_var in ds_full.data_vars:
                        print_progress(f"Añadiendo '{doy_var}' para {name}", level=2)
                        doy_data = ds_full[doy_var].values
                        if not np.issubdtype(doy_data.dtype, np.number):
                            doy_data = doy_data.astype(float)
                        doy_full = doy_data.reshape(T, n_cells) if doy_data.ndim > 1 else np.repeat(doy_data[:, np.newaxis], n_cells, axis=1)
                        doy_windows = np.stack([doy_full[i : i+INPUT_WINDOW] for i in range(Nw)], axis=0)
                        add_vars_features.append(doy_windows)
                
                # Topographic variables (elevation, slope, aspect, cluster_elevation)
                for topo_var in ['elevation', 'slope', 'aspect', 'cluster_elevation']:
                    if topo_var in ds_full.data_vars:
                        print_progress(f"Añadiendo '{topo_var}' para {name}", level=2)
                        try:
                            topo_data = ds_full[topo_var].values
                            
                            # Convert to a numeric type if needed (cluster_elevation is categorical and left as is)
                            if topo_var != 'cluster_elevation' and not np.issubdtype(topo_data.dtype, np.number):
                                topo_data = topo_data.astype(float)
                            
                            # Replace NaNs in the numeric topographic variables
                            if topo_var != 'cluster_elevation':
                                topo_summary = check_nans(topo_data, topo_var)
                                if topo_summary["has_nans"]:
                                    print_progress(f"⚠️ Reemplazando {topo_summary['nan_count']} NaNs en {topo_var}", level=2)
                                    topo_data = replace_nans(topo_data, strategy="mean")
                            
                            # cluster_elevation is categorical and would need one-hot encoding
                            if topo_var == 'cluster_elevation':
                                print_progress(f"Procesando {topo_var} como variable categórica", level=2)
                                continue  # Skipped for now: one-hot encoding is not implemented
                            
                            # Tile the static field over the T time steps, then window it
                            topo_full = topo_data.reshape(n_cells) if topo_data.ndim == 2 else topo_data.ravel()
                            topo_full = np.tile(topo_full, (T, 1))
                            topo_windows = np.stack([topo_full[i : i+INPUT_WINDOW] for i in range(Nw)], axis=0)
                            add_vars_features.append(topo_windows)
                        except Exception as e:
                            print_progress(f"⚠️ Error al procesar {topo_var}: {str(e)}", level=2)
                            continue
            
            # NaN check after feature assembly
            xs_processed_summary = check_nans(Xs, "Features procesados")
            if xs_processed_summary["has_nans"]:
                print_progress(f"⚠️ Aún hay {xs_processed_summary['nan_count']} NaNs después del procesamiento, reemplazando", level=1)
                Xs = replace_nans(Xs, strategy="mean")

            # Scaling
            print_progress("Aplicando escalado robusto de datos", level=1)
            # ScalerNaN computes its statistics ignoring NaNs.
            # NOTE(review): both scalers are fitted on all windows, including
            # the ones later used for validation and forecast.
            scX = ScalerNaN().fit(Xs.reshape(-1, n_feats))
            scY = ScalerNaN().fit(ys.reshape(-1, n_cells))

            Xs_s = scX.transform(Xs.reshape(-1, n_feats)).reshape(Xs.shape)
            ys_s = scY.transform(ys.reshape(-1, n_cells)).reshape(ys.shape)

            # NaN check after scaling
            xs_scaled_summary = check_nans(Xs_s, "Features escalados")
            ys_scaled_summary = check_nans(ys_s, "Objetivos escalados")

            if xs_scaled_summary["has_nans"] or ys_scaled_summary["has_nans"]:
                print_progress("⚠️ Hay NaNs después del escalado, reemplazando con ceros", level=1)
                # Replace any remaining NaNs with zeros
                Xs_s = np.nan_to_num(Xs_s, nan=0.0)
                ys_s = np.nan_to_num(ys_s, nan=0.0)

            # Split around the reference date: k_ref is the window whose last
            # input step is idx_ref (clipped to the valid window range); the
            # OUTPUT_HORIZON windows starting at i0 form the validation set
            # and every window before i0 is used for training.
            k_ref = np.clip(idx_ref - INPUT_WINDOW + 1, 0, Nw-1)
            i0    = np.clip(k_ref - (OUTPUT_HORIZON-1), 0, Nw-OUTPUT_HORIZON)

            X_tr, y_tr = Xs_s[:i0], ys_s[:i0]
            X_va, y_va = Xs_s[i0 : i0+OUTPUT_HORIZON], ys_s[i0 : i0+OUTPUT_HORIZON]

            # Load an existing model or train a new one
            # (a loaded model adds no entry to `histories`)
            model_path = MODEL_DIR/f"{name}_w{OUTPUT_HORIZON}_ref{ref}.keras"
            if model_path.exists():
                print_progress(f"Cargando modelo existente: {model_path.name}", level=1)
                model = tf.keras.models.load_model(str(model_path), compile=False)
                model.compile(optimizer="adam", loss="mse")
            else:
                print_progress(f"Creando nuevo modelo para {name}", level=1)
                if branch == "low":
                    model = build_gru_ed_low((INPUT_WINDOW, n_feats), OUTPUT_HORIZON, n_cells)
                    print_progress(f"Modelo low-branch creado: {model.__class__.__name__}", level=2)
                else:
                    model = build_gru_ed_medium_high((INPUT_WINDOW, n_feats), OUTPUT_HORIZON, n_cells)
                    print_progress(f"Modelo {branch}-branch creado: {model.__class__.__name__}", level=2)

                # Per-model progress callback
                progress_cb = TrainingProgressCallback(name, MAX_EPOCHS)

                # Summary of the training data
                print_progress(f"Entrenando con {len(X_tr)} muestras, validando con {len(X_va)} muestras", level=1)
                print_progress(f"X_train: {X_tr.shape}, y_train: {y_tr.shape}", level=2)
                print_progress(f"X_val: {X_va.shape}, y_val: {y_va.shape}", level=2)

                # Train; progress is reported by progress_cb
                print_progress(f"Iniciando entrenamiento para {name}", level=1)
                hist = model.fit(
                    DataGenerator(X_tr, y_tr),
                    validation_data=DataGenerator(X_va, y_va),
                    epochs=MAX_EPOCHS,
                    callbacks=[es_cb, lr_cb, progress_cb],
                    verbose=0  # built-in verbosity is off because progress_cb reports progress
                )

                print_progress(f"Guardando modelo en {model_path.name}", level=1)
                model.save(str(model_path))
                histories[name] = hist.history

                # Training summary
                print_progress(f"Entrenamiento completado en {len(hist.history['loss'])} épocas", level=1)
                print_progress(f"Loss inicial: {hist.history['loss'][0]:.4f}, Loss final: {hist.history['loss'][-1]:.4f}", level=2)
                print_progress(f"Val-loss inicial: {hist.history['val_loss'][0]:.4f}, Val-loss final: {hist.history['val_loss'][-1]:.4f}", level=2)

            # Validation: for each validation window h, compare the first
            # predicted step (preds[h,0]) against the first target step.
            # The reshape requires X_va to hold exactly OUTPUT_HORIZON windows.
            # NOTE(review): val_dates is ordered newest-first (ref, ref-1, ...)
            # while the validation windows are taken in increasing start
            # index — confirm the date labels match the compared months.
            print_progress(f"Generando predicciones de validación para {name}", level=1)
            preds = model.predict(X_va, verbose=0).reshape(OUTPUT_HORIZON, OUTPUT_HORIZON, n_cells)
            for h in range(OUTPUT_HORIZON):
                date_val = val_dates[h]
                pm_flat  = preds[h,0]
                tm_flat  = y_va[h,0]
                # Back to the original scale and grid shape
                pm = scY.inverse_transform(pm_flat.reshape(1,-1))[0].reshape(ny,nx)
                tm = scY.inverse_transform(tm_flat.reshape(1,-1))[0].reshape(ny,nx)
                rmse, mae, mape, r2 = evaluate_metrics(tm.ravel(), pm.ravel())
                all_metrics.append({
                    "model": name, "branch": branch, "horizon": h+1,
                    "type":"validation", "date": date_val,
                    "RMSE": rmse, "MAE": mae, "MAPE": mape, "R2": r2
                })
                preds_store[(name,date_val)] = pm
                true_store[(name,date_val)]  = tm

            # Forecast: predict from window k_ref and store one grid per
            # horizon step; metric fields are NaN because no observations
            # are compared here.
            print_progress(f"Generando predicciones de forecast para {name}", level=1)
            X_fc = Xs_s[k_ref : k_ref+1]
            fc_s = model.predict(X_fc, verbose=0)[0]
            FC   = scY.inverse_transform(fc_s)
            for h in range(OUTPUT_HORIZON):
                date_fc = fc_dates[h]
                all_metrics.append({
                    "model": name, "branch": branch, "horizon": h+1,
                    "type":"forecast", "date": date_fc,
                    "RMSE": np.nan, "MAE": np.nan, "MAPE": np.nan, "R2": np.nan
                })
                preds_store[(name,date_fc)] = FC[h].reshape(ny,nx)

        except Exception as e:
            # Report the failure (with traceback in the log) and continue
            # with the next combination.
            print_progress(f"‼️ Error en {name}: {str(e)}", level=1)
            logger.exception(f"Error en {name}, continuo…")
            continue

print_progress("Procesamiento de todos los modelos completado", is_end=True)

# 9) Save the metrics and display the table
print_progress("Guardando métricas y generando tablas", is_start=True)
dfm = pd.DataFrame(all_metrics)
dfm.to_csv(RESULTS_CSV, index=False)
import ace_tools_open as tools
import cartopy.crs as ccrs
# The CSV file name uses REF_DATE; the displayed table name uses the
# (possibly clamped) `ref`.
tools.display_dataframe_to_user(name=f"Metrics_w{OUTPUT_HORIZON}_ref{ref}", dataframe=dfm)
print_progress(f"Métricas guardadas en {RESULTS_CSV}", is_end=True)

# 10) Training curves
# One figure per entry of `histories`, overlaying the "loss" and "val_loss"
# series of that entry.
print_progress("Generando curvas de entrenamiento", is_start=True)
for name, hist in histories.items():
    plt.figure(figsize=(6,4))
    plt.plot(hist["loss"],  label="train")
    plt.plot(hist["val_loss"],label="val")
    plt.title(f"Loss curve: {name}")
    # NOTE(review): the y label assumes the models were trained with an MSE loss — confirm.
    plt.xlabel("Epoch"); plt.ylabel("MSE")
    plt.legend(); plt.tight_layout(); plt.show()
print_progress("Visualizaciones de curvas de entrenamiento completadas", is_end=True)

# 10bis) True vs Predicted per branch and horizon
# One scatter figure per (branch, h). Every method that has both a stored
# prediction and a stored truth for val_dates[h-1] is overlaid in the same axes.
for branch in BRANCHES:
    for h in range(1, OUTPUT_HORIZON+1):
        plt.figure(figsize=(5,5))
        valid_data_found = False
        for method in METHODS:
            # Store keys are (model_name, date) with model_name "<method>_<branch>".
            key = f"{method}_{branch}"
            date_val = val_dates[h-1]
            if (key, date_val) in preds_store and (key, date_val) in true_store:
                y_true = true_store[(key, date_val)].ravel()
                y_pred = preds_store[(key, date_val)].ravel()
                plt.scatter(y_true, y_pred, alpha=0.3, s=2, label=method)
                valid_data_found = True

        if valid_data_found:
            # Identity line from 0 up to the larger of the current axis limits.
            lims = [0, max(plt.xlim()[1], plt.ylim()[1])]
            plt.plot(lims, lims, 'k--')
            plt.xlabel("True"); plt.ylabel("Predicted")
            plt.title(f"True vs Pred — {branch}, H={h}")
            plt.legend()  # Only add a legend when something was plotted
        else:
            # Nothing stored for this (branch, h): draw a placeholder instead.
            plt.text(0.5, 0.5, "No hay datos suficientes", ha='center', va='center')
            plt.xlabel("True"); plt.ylabel("Predicted")
            plt.title(f"True vs Pred — {branch}, H={h} (Sin datos)")

        plt.tight_layout(); plt.show()

# 11) 3×3 validation maps, H=1
# For every validation date two figures are produced and saved under IMAGE_DIR:
# a grid of predicted fields and a grid of per-cell absolute percentage error.
# NOTE(review): the 3×3 layout assumes len(BRANCHES) == len(METHODS) == 3 — confirm.
xmin, ymin, xmax, ymax = boyaca_gdf.total_bounds
for date_val in val_dates:
    # Gather every available prediction for this date to derive a colour scale
    # shared by all panels.
    arrs = [preds_store[(f"{m}_{b}",date_val)].ravel()
            for m in METHODS for b in BRANCHES
            if (f"{m}_{b}",date_val) in preds_store]
    if not arrs:
        logger.warning(f"No hay predicciones para {date_val}, salto plot.")
        continue
    vmin, vmax = np.min(arrs), np.max(arrs)
    fig, axs = plt.subplots(3,3, figsize=(12,12), subplot_kw={"projection":ccrs.PlateCarree()})
    fig.suptitle(f"Validación H=1 — {date_val}", fontsize=16)
    # Rows follow BRANCHES, columns follow METHODS.
    for i, b in enumerate(BRANCHES):
        for j, m in enumerate(METHODS):
            ax = axs[i,j]
            ax.set_extent([xmin, xmax, ymin, ymax], ccrs.PlateCarree())
            ax.add_geometries(boyaca_gdf.geometry, ccrs.PlateCarree(),
                              edgecolor="black", facecolor="none", linewidth=1)
            key = (f"{m}_{b}", date_val)
            # Panels without a stored prediction keep only the boundary and title.
            if key in preds_store:
                pcm = ax.pcolormesh(lon, lat, preds_store[key],
                                    vmin=vmin, vmax=vmax,
                                    transform=ccrs.PlateCarree(), cmap="Blues")
            ax.set_title(f"{m}_{b}")
    # The colourbar is built from the last mesh drawn; all meshes share vmin/vmax.
    fig.colorbar(pcm, ax=axs, orientation="horizontal",
                 fraction=0.05, pad=0.04, label="Precipitación (mm)")
    fig.savefig(IMAGE_DIR/f"val_H1_{date_val}.png", dpi=150); plt.show()

    # Absolute percentage error per cell, clipped to [0, 200]; the 1e-5 term
    # keeps the division defined where the true value is zero.
    arrs_mape = [
        np.clip(np.abs((true_store[k] - preds_store[k])/(true_store[k]+1e-5))*100,0,200).ravel()
        for k in preds_store if k[1]==date_val and k in true_store
    ]
    if not arrs_mape: continue
    vmin2, vmax2 = 0, np.max(arrs_mape)
    fig, axs = plt.subplots(3,3, figsize=(12,12), subplot_kw={"projection":ccrs.PlateCarree()})
    fig.suptitle(f"MAPE H=1 — {date_val}", fontsize=16)
    for i, b in enumerate(BRANCHES):
        for j, m in enumerate(METHODS):
            ax = axs[i,j]
            ax.set_extent([xmin, xmax, ymin, ymax], ccrs.PlateCarree())
            ax.add_geometries(boyaca_gdf.geometry, ccrs.PlateCarree(),
                              edgecolor="black", facecolor="none", linewidth=1)
            key = (f"{m}_{b}", date_val)
            if key in preds_store and key in true_store:
                # Same clipped percentage-error expression as used for the colour scale above.
                mmap = np.clip(np.abs((true_store[key] - preds_store[key])/(true_store[key]+1e-5))*100,0,200)
                pcm2 = ax.pcolormesh(lon, lat, mmap,
                                     vmin=vmin2, vmax=vmax2,
                                     transform=ccrs.PlateCarree(), cmap="Reds")
            ax.set_title(f"{m}_{b}")
    fig.colorbar(pcm2, ax=axs, orientation="horizontal",
                 fraction=0.05, pad=0.04, label="MAPE (%)")
    fig.savefig(IMAGE_DIR/f"mape_H1_{date_val}.png", dpi=150); plt.show()

# 13) META‐MODELOS XGB stacking H=1-3 (retraining con 9 features)
from xgboost import XGBRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

# 13.0) Build the full X_meta for each horizon and retrain the models
# For every horizon h a separate XGBRegressor is fitted on the grid cells of
# the validation date val_dates[h-1] and saved under MODEL_DIR.
print_progress("Iniciando meta-modelos XGB de stacking H=1-3", is_start=True)
for h in range(1, OUTPUT_HORIZON+1):
    date = val_dates[h-1]
    print_progress(f"Entrenando meta-modelo XGB para horizonte {h}, fecha {date}", level=1)

    # Extract features (3 preds + elevation stats + slope + aspect)
    print_progress(f"Preparando datos para H={h}", level=2)

    # Make sure all three FUSION predictions exist before proceeding
    missing_keys = []
    for b in ['low', 'medium', 'high']:
        key = (f"FUSION_{b}", date)
        if key not in preds_store:
            missing_keys.append(key)

    if missing_keys:
        print_progress(f"⚠️ No se encontraron predicciones para: {missing_keys}", level=1)
        print_progress(f"Omitiendo entrenamiento para horizonte {h}", level=1)
        continue

    # It is now safe to fetch the predictions
    preds = [preds_store[(f"FUSION_{b}", date)].ravel() for b in ['low','medium','high']]
    elev_flat   = ds_full['elevation'].values.ravel()
    slope_flat  = ds_full['slope'].values.ravel()
    aspect_flat = ds_full['aspect'].values.ravel()
    # Elevation statistics: global scalars broadcast to one constant column each
    mean_e = elev_flat.mean(); std_e = elev_flat.std(); skew_e = skew(elev_flat)
    elev_stats = np.vstack([
        np.full_like(elev_flat, mean_e),
        np.full_like(elev_flat, std_e),
        np.full_like(elev_flat, skew_e)
    ]).T
    # Build X_meta and y_true
    # NOTE(review): this stacks 3 + 3 + 1 + 1 = 8 columns while the model file
    # name says "9features"; the raw elevation (`elev_flat`) is not a column.
    # Confirm whether it was meant to be included.
    X_meta = np.column_stack(preds + [elev_stats, slope_flat, aspect_flat])
    # The target is taken from the FUSION_low entry of true_store.
    y_true = true_store[("FUSION_low", date)].ravel()

    # Report dimensions
    print_progress(f"X_meta shape: {X_meta.shape}, y_true shape: {y_true.shape}", level=2)

    # Train/test split (random 80/20 over grid cells, fixed seed)
    X_tr, X_te, y_tr, y_te = train_test_split(X_meta, y_true, test_size=0.2, random_state=42)
    print_progress(f"Split: train={X_tr.shape[0]} muestras, test={X_te.shape[0]} muestras", level=2)

    # Fit the model on all features
    print_progress(f"Entrenando XGBoost para H={h}", level=2, is_start=True)
    xgb = XGBRegressor(n_estimators=100, learning_rate=0.1, max_depth=5)
    xgb.fit(X_tr, y_tr)

    # Evaluate on the held-out cells
    y_pred = xgb.predict(X_te)
    test_rmse = np.sqrt(mean_squared_error(y_te, y_pred))
    print_progress(f"XGB H={h} entrenado. Test RMSE: {test_rmse:.4f}", level=2, is_end=True)

    # Save the retrained model; later sections reload it from this exact path
    model_path = str(MODEL_DIR/f"xgb_all_H{h}_{ref}_9features.json")
    print_progress(f"Guardando modelo en {model_path}", level=2)
    xgb.save_model(model_path)

print_progress("Meta-modelos XGB entrenados y guardados", is_end=True)

# 13.2) Scatter plots, maps and final metrics
print_progress("Generando visualizaciones y métricas finales", is_start=True)
# One scatter panel per horizon; the panels are filled by the per-horizon
# evaluation loop further down, which indexes `axs_sc[idx_h]`.
fig_sc, axs_sc = plt.subplots(1, OUTPUT_HORIZON, figsize=(6*OUTPUT_HORIZON,5))
# `plt.subplots` returns a bare Axes (not an array) for a 1x1 grid, so wrap it
# to keep that indexing valid when OUTPUT_HORIZON == 1.
if not isinstance(axs_sc, np.ndarray):
    axs_sc = np.array([axs_sc], dtype=object)
# A truncated copy of the per-horizon loop used to live here. It reloaded every
# saved model and rebuilt the feature arrays without using any of them, so it
# was removed; the complete loop below performs the same loading.

def _xgb_predict_chunk(model, chunk, batch_idx=None):
    """
    Predict one chunk of rows, degrading gracefully when the model call fails.

    Order of attempts: ``model.predict(chunk)``; then the underlying booster on
    an ``xgboost.DMatrix``; finally a zero vector so the caller still receives
    an array of the expected length.

    Args:
        model: Object exposing ``predict``; optionally ``get_booster``.
        chunk: 2-D feature block for this batch.
        batch_idx: Batch number used in the error message, or ``None`` when
            the whole input is predicted in a single call.

    Returns:
        Predictions for ``chunk`` (all zeros if both attempts fail).
    """
    try:
        return model.predict(chunk)
    except Exception as e:  # best-effort boundary: report and try the fallback
        where = "" if batch_idx is None else f" batch {batch_idx}"
        print_progress(f"Error en predicción{where}: {str(e)}", level=1)

    try:
        import xgboost
        # The scikit-learn wrapper's ``predict`` expects array-like input, so a
        # DMatrix has to be routed through the native Booster instead.
        booster = model.get_booster() if hasattr(model, "get_booster") else model
        return booster.predict(xgboost.DMatrix(chunk))
    except Exception as e2:
        print_progress(f"Error crítico en predicción: {str(e2)}", level=1)
        # Unrecoverable: return zeros so downstream plotting/metrics still run.
        return np.zeros(chunk.shape[0])


def xgb_predict_full(model, X, batch_size=100000):
    """
    Make predictions with an XGBoost model, handling memory constraints and NaNs.

    NaNs in ``X`` are replaced by the mean of their column (on a copy; the
    caller's array is not modified). Rows are then predicted in consecutive
    batches of at most ``batch_size`` rows and the results concatenated.

    Args:
        model: The XGBoost model (anything exposing ``predict``).
        X: 2-D input features, one row per sample.
        batch_size: Maximum rows per ``predict`` call; lower it when memory is
            tight. Must be a positive integer.

    Returns:
        1-D array with one prediction per row of ``X``. Rows whose batch could
        not be predicted even through the fallback path are filled with zeros.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    # Handle NaNs in input
    nan_mask = np.isnan(X)
    if nan_mask.any():
        print_progress(f"⚠️ Detectados NaNs en entrada de XGB, reemplazando con valores medios", level=2)
        # np.where allocates a new array, leaving the original untouched.
        X = np.where(nan_mask, np.nanmean(X, axis=0), X)

    n_rows = X.shape[0]
    if n_rows <= batch_size:
        # Everything fits in one call (this also covers zero-row input).
        return _xgb_predict_chunk(model, X)

    batch_starts = range(0, n_rows, batch_size)
    return np.concatenate([
        _xgb_predict_chunk(model, X[start:start + batch_size], batch_idx)
        for batch_idx, start in enumerate(batch_starts)
    ])

# Per-horizon evaluation records of the XGB meta-model (one dict per horizon).
meta_metrics_all = []

# 13.2.1) Generate scatter plots and calculate metrics
for idx_h, h in enumerate(range(1, OUTPUT_HORIZON+1)):
    date = val_dates[h-1]
    mdl_path = MODEL_DIR/f"xgb_all_H{h}_{ref}_9features.json"
    if not mdl_path.exists():
        axs_sc[idx_h].text(0.5,0.5,f"No model H={h}",ha='center',va='center')
        continue

    # A model file can outlive the predictions it was trained on (e.g. left
    # over from an earlier run), so check the stores before indexing them.
    fusion_keys = [(f"FUSION_{b}", date) for b in ['low','medium','high']]
    missing_keys = [k for k in fusion_keys if k not in preds_store]
    if missing_keys or fusion_keys[0] not in true_store:
        print_progress(f"⚠️ No se encontraron predicciones para: {missing_keys}", level=1)
        axs_sc[idx_h].text(0.5,0.5,f"No data H={h}",ha='center',va='center')
        continue

    xgb = XGBRegressor(); xgb.load_model(str(mdl_path))
    print_progress(f"Cargado modelo XGB para H={h}", level=1)

    # Collect the base predictions, imputing NaNs with the mean
    preds = [preds_store[k].ravel() for k in fusion_keys]
    for i, b in enumerate(['low','medium','high']):
        pred_summary = check_nans(preds[i], f"Predicción FUSION_{b}")
        if pred_summary["has_nans"]:
            print_progress(f"Reemplazando {pred_summary['nan_count']} NaNs en predicciones de FUSION_{b}", level=2)
            preds[i] = replace_nans(preds[i], strategy="mean")

    # Topographic features, cleaned BEFORE the elevation statistics are taken:
    # a single NaN would otherwise turn mean/std/skew into NaN. This matches
    # the order used by the map section below.
    topo = {}
    for topo_name in ('elevation', 'slope', 'aspect'):
        topo_arr = ds_full[topo_name].values.ravel()
        topo_summary = check_nans(topo_arr, topo_name)
        if topo_summary["has_nans"]:
            print_progress(f"Reemplazando {topo_summary['nan_count']} NaNs en {topo_name}", level=2)
            topo_arr = replace_nans(topo_arr, strategy="mean")
        topo[topo_name] = topo_arr
    elev_flat = topo['elevation']
    slope_flat = topo['slope']
    aspect_flat = topo['aspect']

    mean_e = elev_flat.mean(); std_e = elev_flat.std(); skew_e = skew(elev_flat)
    elev_stats = np.vstack([
        np.full_like(elev_flat, mean_e),
        np.full_like(elev_flat, std_e),
        np.full_like(elev_flat, skew_e)
    ]).T

    # Same column layout as used when the model was trained.
    X_meta = np.column_stack(preds + [elev_stats, slope_flat, aspect_flat])
    ytrue = true_store[("FUSION_low", date)].ravel()

    # Count missing targets before imputing them; counting afterwards would
    # always report 100 % valid data.
    n_true_nans = int(np.isnan(ytrue).sum())
    ytrue_summary = check_nans(ytrue, "Objetivo verdadero")
    if ytrue_summary["has_nans"]:
        print_progress(f"Reemplazando {ytrue_summary['nan_count']} NaNs en objetivo verdadero", level=2)
        ytrue = replace_nans(ytrue, strategy="mean")

    # Robust prediction (batched, NaN-tolerant)
    ypred = xgb_predict_full(xgb, X_meta)

    # Scatter with identity line spanning the joint range of truth and prediction
    ax = axs_sc[idx_h]
    ax.scatter(ytrue, ypred, alpha=0.3, s=2)
    lims = [min(ytrue.min(), ypred.min()), max(ytrue.max(), ypred.max())]
    ax.plot(lims, lims, 'k--')
    ax.set_title(f"XGB H={h} — {date}")
    ax.set_xlabel("True"); ax.set_ylabel("Predicted")

    # Metrics
    rm, ma, maP, r2 = evaluate_metrics(ytrue, ypred)
    meta_metrics_all.append({
        'horizon':h, 'date':date,
        'RMSE':rm, 'MAE':ma, 'MAPE':maP, 'R2':r2,
        'valid_data_pct': 100 - (n_true_nans / len(ytrue) * 100)
    })
plt.tight_layout(); plt.show()

# 13.3) Maps with the retrained model (robust NaN handling)
# For every horizon with a saved model: rebuild the meta-features, predict the
# full grid and draw the prediction next to its absolute percentage error.
# This loop used to appear twice back-to-back with identical bodies, which
# rendered every figure twice; a single pass is kept.

# Plotting grid is the same for every horizon, so build it once.
grid_lon, grid_lat = np.meshgrid(lon, lat)

for h in range(1, OUTPUT_HORIZON+1):
    date = val_dates[h-1]
    mdl_path = MODEL_DIR/f"xgb_all_H{h}_{ref}_9features.json"
    if not mdl_path.exists():
        continue

    # Any failure (missing predictions, shape mismatch, plotting error) is
    # logged and the remaining horizons are still processed.
    try:
        xgb = XGBRegressor(); xgb.load_model(str(mdl_path))

        # Rebuild X_meta, imputing NaNs in the base predictions with the mean
        preds = [preds_store[(f"FUSION_{b}", date)].ravel() for b in ['low','medium','high']]
        for i, b in enumerate(['low','medium','high']):
            if np.isnan(preds[i]).any():
                preds[i] = replace_nans(preds[i], strategy="mean")

        elev_flat = ds_full['elevation'].values.ravel()
        slope_flat = ds_full['slope'].values.ravel()
        aspect_flat = ds_full['aspect'].values.ravel()

        # Impute NaNs in the topographic features before taking statistics
        if np.isnan(elev_flat).any():
            elev_flat = replace_nans(elev_flat, strategy="mean")
        if np.isnan(slope_flat).any():
            slope_flat = replace_nans(slope_flat, strategy="mean")
        if np.isnan(aspect_flat).any():
            aspect_flat = replace_nans(aspect_flat, strategy="mean")

        mean_e = elev_flat.mean(); std_e = elev_flat.std(); skew_e = skew(elev_flat)
        elev_stats = np.vstack([
            np.full_like(elev_flat, mean_e), np.full_like(elev_flat, std_e), np.full_like(elev_flat, skew_e)
        ]).T

        # Same column layout as used when the model was trained.
        X_meta = np.column_stack(preds + [elev_stats, slope_flat, aspect_flat])

        # Predicted field P on the (lat, lon) grid and matching truth T
        P = xgb_predict_full(xgb, X_meta).reshape(len(lat), len(lon))
        T = true_store[("FUSION_low", date)]

        # Absolute percentage error, computed only where both fields are finite;
        # the 1e-5 term keeps the division defined where the truth is zero.
        mask_valid = ~(np.isnan(T) | np.isnan(P))
        M = np.full_like(T, np.nan)
        M[mask_valid] = np.abs((T[mask_valid] - P[mask_valid])/(T[mask_valid] + 1e-5))*100

        # Cells without a valid error are drawn as 0 rather than left blank
        if np.isnan(M).any():
            print_progress(f"Reemplazando NaNs en mapa MAPE para visualización", level=2)
            M = np.nan_to_num(M, nan=0.0)

        fig, axs = plt.subplots(1,2, figsize=(12,5), subplot_kw={'projection':ccrs.PlateCarree()})
        axs[0].set_title(f"Predicción XGB H={h}")
        pcm = axs[0].pcolormesh(grid_lon, grid_lat, P, transform=ccrs.PlateCarree(), cmap='Blues')
        boyaca_gdf.boundary.plot(ax=axs[0], edgecolor='black', transform=ccrs.PlateCarree())
        fig.colorbar(pcm, ax=axs[0], orientation='vertical', label='mm')
        axs[1].set_title(f"MAPE% XGB H={h}")
        # Colour scale capped at the 99th percentile so a few extreme cells do
        # not wash out the rest of the map.
        pcm2 = axs[1].pcolormesh(grid_lon, grid_lat, M, transform=ccrs.PlateCarree(), cmap='Reds', vmin=0, vmax=np.nanpercentile(M,99))
        boyaca_gdf.boundary.plot(ax=axs[1], edgecolor='black', transform=ccrs.PlateCarree())
        fig.colorbar(pcm2, ax=axs[1], orientation='vertical', label='%')
        plt.tight_layout(); plt.show()
    except Exception as e:
        print_progress(f"❌ Error generando mapa para H={h}: {str(e)}", level=1)
        logger.exception(f"Error en visualización de mapa para H={h}")


In [None]:
# 13) Meta-modelo neuronal completo con métricas, mapas y tablas
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoader
from scipy.stats import skew
import pandas as pd
import cartopy.crs as ccrs
from sklearn.metrics import r2_score

# Definición de un modelo neuronal más robusto para secuencias
class TemporalMetaModel(nn.Module):
    """
    Multi-horizon meta-model: feature encoder -> GRU -> self-attention -> one
    regression head per horizon.

    The encoded feature vector is repeated once per horizon to form the GRU
    input sequence, so every horizon sees the same input and differs only by
    its position in the sequence and by its dedicated head.
    """
    def __init__(self, in_dim, hidden_dim=128, num_horizons=3):
        super().__init__()
        self.num_horizons = num_horizons
        fused_dim = hidden_dim*2  # GRU output concatenated with attention output

        # Input features -> hidden representation
        self.feature_encoder = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2)
        )

        # Two-layer GRU run over the horizon axis
        self.gru = nn.GRU(
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            num_layers=2,
            dropout=0.1,
            batch_first=True
        )

        # Self-attention across horizons (expects [seq_len, batch, embed_dim])
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=4,
            dropout=0.1
        )

        # One small MLP head per horizon, each producing a single value
        self.horizon_projections = nn.ModuleList([
            nn.Sequential(
                nn.Linear(fused_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, 1)
            ) for _ in range(num_horizons)
        ])

        # Calibration layer; registered (part of the state dict) but not used
        # by forward().
        self.calibration = nn.Linear(fused_dim, 1)

    def forward(self, x, horizon_idx=None):
        """
        Args:
            x: Input features, one row per sample.
            horizon_idx: If given, only that horizon's head is evaluated.

        Returns:
            [batch, 1] when ``horizon_idx`` is given, else [batch, num_horizons].
        """
        hidden = self.feature_encoder(x)

        # Repeat the encoding along a new horizon axis: [batch, horizons, hidden]
        steps = hidden.unsqueeze(1).expand(-1, self.num_horizons, -1)
        recurrent, _ = self.gru(steps)

        # Attention works sequence-first, so swap axes going in and coming out
        seq_first = recurrent.transpose(0, 1)
        attended, _ = self.attention(seq_first, seq_first, seq_first)
        fused = torch.cat([recurrent, attended.transpose(0, 1)], dim=2)

        if horizon_idx is not None:
            return self.horizon_projections[horizon_idx](fused[:, horizon_idx])

        # All horizons: one column per head
        return torch.cat(
            [head(fused[:, step]) for step, head in enumerate(self.horizon_projections)],
            dim=1
        )

# Función de pérdida personalizada que considera la coherencia temporal
class TemporalCoherenceLoss(nn.Module):
    """
    MSE plus a penalty on mismatched horizon-to-horizon changes.

    The penalty is the mean squared difference between the first differences
    (along the horizon axis) of the prediction and of the target, weighted by
    ``lambda_coherence``. With a single output column only the MSE is returned.
    """
    def __init__(self, lambda_coherence=0.3):
        super().__init__()
        self.mse = nn.MSELoss()
        self.lambda_coherence = lambda_coherence

    def forward(self, pred, target):
        """Return the combined loss for ``pred`` and ``target`` ([batch, horizons])."""
        base = self.mse(pred, target)

        # A single horizon has no consecutive steps to compare
        if pred.shape[1] <= 1:
            return base

        # Difference between predicted and true step-to-step changes
        step_gap = torch.diff(pred, dim=1) - torch.diff(target, dim=1)
        coherence = step_gap.abs().pow(2).mean()
        return base + self.lambda_coherence * coherence

# Función para entrenar el modelo con aprendizaje por curriculum
def _run_curriculum_phase(model, train_loader, val_loader, optimizer, scheduler,
                          criterion, n_epochs, n_horizons, device, history):
    """
    Run one curriculum phase and return the best weights seen during it.

    Args:
        model: Network being trained; called as ``model(x, horizon_idx=0)``
            when ``n_horizons == 1`` and as ``model(x)`` otherwise.
        train_loader, val_loader: Loaders yielding ``(x_batch, y_batch)``.
        optimizer, scheduler, criterion: Shared across phases.
        n_epochs: Number of epochs to run in this phase.
        n_horizons: Number of leading target columns to fit; ``None`` means
            every column.
        device: Device the batches are moved to.
        history: Dict with ``'train_loss'`` / ``'val_loss'`` lists, appended
            to in place (one value per epoch).

    Returns:
        A detached copy of the state dict with the lowest validation loss of
        this phase, or ``None`` if no epoch produced a finite improvement
        (including ``n_epochs == 0``).
    """
    def _forward(x):
        if n_horizons == 1:
            return model(x, horizon_idx=0)
        out = model(x)
        return out if n_horizons is None else out[:, 0:n_horizons]

    def _select(y):
        return y if n_horizons is None else y[:, 0:n_horizons]

    def _epoch_loss(loader, training):
        # Sample-weighted mean loss over the loader; updates weights if training.
        model.train(training)
        total = 0
        with torch.set_grad_enabled(training):
            for x_batch, y_batch in loader:
                x_batch = x_batch.to(device)
                y_sel = _select(y_batch).to(device)
                if training:
                    optimizer.zero_grad()
                loss = criterion(_forward(x_batch), y_sel)
                if training:
                    loss.backward()
                    optimizer.step()
                total += loss.item() * len(x_batch)
        return total / len(loader.dataset)

    best_val_loss = float('inf')
    best_state = None
    for epoch in range(n_epochs):
        train_loss = _epoch_loss(train_loader, training=True)
        val_loss = _epoch_loss(val_loader, training=False)
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)

        # Report learning-rate reductions ourselves: the scheduler's own
        # ``verbose`` flag is deprecated in recent PyTorch releases.
        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(val_loss)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after < lr_before:
            print_progress(f"Tasa de aprendizaje reducida a {lr_after:.2e}", level=2)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # state_dict() hands back references to the live parameters, and a
            # plain dict .copy() keeps those references. Clone every tensor so
            # the snapshot is not overwritten by later optimizer steps.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

        if (epoch + 1) % 5 == 0:
            print_progress(f"Epoch {epoch+1} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}", level=2)

    return best_state


def train_with_curriculum(model, train_loader, val_loader, epochs=50, lr=0.001, device='cpu'):
    """
    Train the model with curriculum learning: first on H=1, then on H=1,2 and
    finally on all horizons.

    Each phase runs ``epochs // 3`` epochs with a shared Adam optimizer,
    ``ReduceLROnPlateau`` scheduler and ``TemporalCoherenceLoss``. After every
    phase the weights with the lowest validation loss of that phase are
    restored; if a phase yields no valid checkpoint, the most recent earlier
    checkpoint is restored instead (or the weights are left as they are when
    none exists yet, e.g. ``epochs < 3``).

    Args:
        model: Network to train (modified in place). It is not moved to
            ``device`` here; the caller is expected to have done so.
        train_loader: Loader of ``(x, y)`` training batches, ``y`` being
            [batch, horizons].
        val_loader: Loader of ``(x, y)`` validation batches.
        epochs: Total epoch budget, split evenly across the three phases.
        lr: Initial Adam learning rate.
        device: Device the batches are moved to.

    Returns:
        ``(model, history)`` where ``history`` holds the per-epoch
        ``'train_loss'`` and ``'val_loss'`` of all phases concatenated.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, factor=0.5, patience=5
    )
    criterion = TemporalCoherenceLoss()

    history = {'train_loss': [], 'val_loss': []}
    epochs_per_phase = epochs // 3

    # (progress banner, number of leading horizons fitted; None = all)
    phases = (
        ("Fase 1: Entrenando en horizonte H=1", 1),
        ("Fase 2: Entrenando en horizontes H=1,2", 2),
        ("Fase 3: Entrenando en todos los horizontes", None),
    )

    best_state_dict = None
    for banner, n_horizons in phases:
        print_progress(banner, level=1)
        phase_best = _run_curriculum_phase(
            model, train_loader, val_loader, optimizer, scheduler, criterion,
            epochs_per_phase, n_horizons, device, history
        )
        if phase_best is not None:
            best_state_dict = phase_best
        # Restore the best weights obtained so far before the next phase
        if best_state_dict is not None:
            model.load_state_dict(best_state_dict)

    return model, history

# Training parameters
device   = 'cuda' if torch.cuda.is_available() else 'cpu'  # use the GPU when one is available
batch_sz = 64
lr       = 5e-4  # Lower learning rate for more stable training
epochs     = 60    # More epochs for the curriculum learning (split evenly over its three phases)

# Funciones de métrica
def rmse(a, b): return np.sqrt(np.mean((a - b)**2))
def mae(a, b):  return np.mean(np.abs(a - b))
def mape(a, b): return np.mean(np.abs((a - b) / (b + 1e-5))) * 100
def r2(a, b):
    # a: predicciones, b: valores reales
    ss_res = np.sum((b - a)**2)
    ss_tot = np.sum((b - np.mean(b))**2)
    return 1 - ss_res/ss_tot if ss_tot != 0 else np.nan

def evaluate(a, b):
    """Compute all regression metrics for predictions `a` against targets `b`.

    Returns:
        A dict with the keys 'RMSE', 'MAE', 'MAPE' and 'R2', each holding the
        result of the matching metric function called as ``fn(a, b)``.
    """
    metric_fns = (
        ('RMSE', rmse),
        ('MAE', mae),
        ('MAPE', mape),
        ('R2', r2),
    )
    return {label: fn(a, b) for label, fn in metric_fns}

# Build the 2-D longitude/latitude grids used for the maps
grid_lon, grid_lat = np.meshgrid(lon, lat)

# Boolean masks over the flattened grid, split by elevation band.
# Comparisons against NaN are False, so a cell with NaN elevation (if any)
# belongs to none of the three masks.
elev_flat = ds_full['elevation'].values.ravel()
mask_low    = elev_flat < 200
mask_mid    = (elev_flat >= 200) & (elev_flat <= 1000)
mask_high   = elev_flat > 1000

# Containers for the collected metrics
global_metrics = []
elev_metrics = []
pct_metrics_list = []
trained_models = {}  # Stores trained models

# The rest of the code stays the same; only the model initialization and training change
# The loop is split into parts so that only the sections that change are shown

# Assemble, for every forecast horizon, the meta-model inputs (the three
# FUSION branch predictions plus static terrain features) and the target.
print_progress("Procesando datos para la construcción del meta-modelo avanzado", level=1)
all_X_meta = []
all_y_horizons = []

# --- Static terrain features ------------------------------------------------
# Elevation, slope and aspect do not depend on the horizon, so they are
# loaded, NaN-checked and summarised once here instead of on every iteration.
elev_flat = ds_full['elevation'].values.ravel()

# check_nans returns a summary dict exposing "has_nans" and "nan_count"
elev_summary = check_nans(elev_flat, "Elevación")
if elev_summary["has_nans"]:
    print_progress(f"Reemplazando {elev_summary['nan_count']} NaNs en elevación", level=2)
    elev_flat = replace_nans(elev_flat, strategy="mean")

# Global elevation statistics, broadcast to one (identical) row per grid cell
mean_e = elev_flat.mean()
std_e = elev_flat.std()
skew_e = skew(elev_flat)
elev_stats = np.vstack([
    np.full_like(elev_flat, mean_e),
    np.full_like(elev_flat, std_e),
    np.full_like(elev_flat, skew_e)
]).T

slope_flat = ds_full['slope'].values.ravel()
aspect_flat = ds_full['aspect'].values.ravel()

slope_summary = check_nans(slope_flat, 'Slope')
if slope_summary["has_nans"]:
    print_progress(f"Reemplazando {slope_summary['nan_count']} NaNs en Slope", level=2)
    slope_flat = replace_nans(slope_flat, strategy="mean")

aspect_summary = check_nans(aspect_flat, 'Aspect')
if aspect_summary["has_nans"]:
    print_progress(f"Reemplazando {aspect_summary['nan_count']} NaNs en Aspect", level=2)
    aspect_flat = replace_nans(aspect_flat, strategy="mean")

# --- Per-horizon data -------------------------------------------------------
for h in range(1, OUTPUT_HORIZON+1):
    date = val_dates[h-1]
    print_progress(f"Recopilando datos para horizonte {h}, fecha {date}", level=1)

    # Flattened predictions of each branch, looked up by (model name, date)
    preds = [preds_store[(f"FUSION_{b}", date)].ravel() for b in ['low','medium','high']]

    # Report and fill NaNs in each branch's predictions
    for i, branch in enumerate(['low', 'medium', 'high']):
        pred_summary = check_nans(preds[i], f"Predicción FUSION_{branch}")
        if pred_summary["has_nans"]:
            print_progress(f"⚠️ {pred_summary['nan_count']} NaNs en predicciones de {branch} ({pred_summary['nan_percentage']:.2f}%)", level=2)
            preds[i] = replace_nans(preds[i], strategy="interpolate")

    # Feature matrix: 3 branch predictions + 3 elevation stats + slope + aspect
    X_meta = np.column_stack(preds + [elev_stats, slope_flat, aspect_flat])
    # NOTE(review): the target is read from the FUSION_low entry only;
    # presumably the ground truth is the same for every branch — confirm.
    y_true = true_store[("FUSION_low", date)].ravel()

    # Report and fill NaNs in the target
    y_true_summary = check_nans(y_true, "Objetivo")
    if y_true_summary["has_nans"]:
        print_progress(f"Reemplazando {y_true_summary['nan_count']} NaNs en objetivo", level=2)
        y_true = replace_nans(y_true, strategy="mean")

    # Last-resort guard: zero out any NaN that survived the per-column fills
    X_meta_summary = check_nans(X_meta, "X_meta final")
    if X_meta_summary["has_nans"]:
        print_progress(f"⚠️ Aún hay {X_meta_summary['nan_count']} NaNs en X_meta, reemplazando", level=2)
        X_meta = np.nan_to_num(X_meta, nan=0.0)

    # Keep the data of every horizon
    all_X_meta.append(X_meta)
    all_y_horizons.append(y_true)

# Merge all the data for the multi-horizon approach
# Every horizon must contribute the same number of samples
n_samples = min([X.shape[0] for X in all_X_meta])
print_progress(f"Construyendo dataset multi-horizonte con {n_samples} muestras por horizonte", level=1)

# Take the first n_samples of each horizon
# NOTE(review): only the first horizon's feature matrix is used as input for
# all horizons; the targets of every horizon become the output columns.
X_meta_unified = all_X_meta[0][:n_samples]  # X of the first horizon is used as the base
y_horizons = np.column_stack([y[:n_samples] for y in all_y_horizons])

# Train/Val split (85% / 15%, fixed seed for reproducibility)
from sklearn.model_selection import train_test_split
X_train, X_val, y_train, y_val = train_test_split(
    X_meta_unified, y_horizons, test_size=0.15, random_state=42
)

print_progress(f"Split de datos: Train={X_train.shape}, Val={X_val.shape}", level=2)

# Wrap the splits as float32 tensors for the DataLoaders
train_dataset = TensorDataset(
    torch.from_numpy(X_train).float(),
    torch.from_numpy(y_train).float()
)
val_dataset = TensorDataset(
    torch.from_numpy(X_val).float(),
    torch.from_numpy(y_val).float()
)

# Only the training loader shuffles; validation keeps dataset order
train_loader = DataLoader(train_dataset, batch_size=batch_sz, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_sz)

# Checkpoint path of the temporal meta-model for this reference run
model_path = MODEL_DIR/f"temporal_meta_model_{ref}.pt"

# Initialise the advanced model, then either load it from disk or train it
print_progress("Entrenando meta-modelo temporal avanzado con curriculum learning", is_start=True)

if model_path.exists():
    # Reuse the checkpoint instead of retraining
    print_progress(f"Cargando modelo existente de {model_path}", level=1)
    model = TemporalMetaModel(in_dim=X_meta_unified.shape[1], num_horizons=OUTPUT_HORIZON).to(device)
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint written on a GPU machine still loads on a CPU-only one.
    model.load_state_dict(torch.load(str(model_path), map_location=device))

    # No training happened in this branch: fill `history` with placeholder
    # values so code that expects the dict still finds its keys.
    history = {
        'train_loss': [0.1] * 10,  # Placeholder
        'val_loss': [0.1] * 10      # Placeholder
    }
else:
    # Train a new model from scratch
    print_progress(f"Iniciando entrenamiento de nuevo modelo temporal", level=1)
    model = TemporalMetaModel(in_dim=X_meta_unified.shape[1], num_horizons=OUTPUT_HORIZON).to(device)
    model, history = train_with_curriculum(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        epochs=epochs,
        lr=lr,
        device=device
    )

    # Persist only the weights (state dict)
    print_progress(f"Guardando modelo en {model_path}", level=1)
    torch.save(model.state_dict(), str(model_path))

    # Plot the training curve; the dashed vertical lines sit at one third and
    # two thirds of `epochs`, labelled as the curriculum phase boundaries.
    plt.figure(figsize=(10, 5))
    plt.plot(history['train_loss'], label='Train Loss')
    plt.plot(history['val_loss'], label='Validation Loss')
    plt.axvline(x=epochs//3, color='red', linestyle='--', alpha=0.7, label='Fase 1 → Fase 2')
    plt.axvline(x=2*epochs//3, color='red', linestyle='--', alpha=0.7, label='Fase 2 → Fase 3')
    plt.xlabel('Época')
    plt.ylabel('Pérdida')
    plt.title('Entrenamiento con Curriculum Learning')
    plt.legend()
    plt.grid(alpha=0.3)
    plt.savefig(IMAGE_DIR/f"temporal_meta_training_{ref}.png", dpi=150)
    plt.show()

print_progress("Evaluando modelo para todos los horizontes", level=1)

# Aggregate evaluation: one forward pass over the whole unified feature matrix.
# NOTE(review): X_meta_unified is the full dataset that the train/val split
# was drawn from, so these metrics include samples seen during training.
model.eval()
with torch.no_grad():
    predictions = model(torch.from_numpy(X_meta_unified).float().to(device))
    predictions = predictions.cpu().numpy()

# Evaluate each horizon and compute its metrics
print_progress("Calculando métricas por horizonte", level=1)

global_metrics = []
for h in range(OUTPUT_HORIZON):
    # Column h of the predictions/targets corresponds to horizon h+1
    preds_h = predictions[:, h]
    true_h = y_horizons[:, h]

    # Global metrics
    metrics = evaluate(preds_h, true_h)
    global_metrics.append({
        'horizon': h+1,
        'date': val_dates[h],
        'RMSE': metrics['RMSE'],
        'MAE': metrics['MAE'],
        'MAPE': metrics['MAPE'],
        'R2': metrics['R2'],
        # Percentage of non-NaN target values
        'valid_data_pct': 100 - (np.isnan(true_h).sum() / len(true_h) * 100)
    })

    # Scatter plot of true vs. predicted values with a 1:1 reference line
    plt.figure(figsize=(6, 6))
    plt.scatter(true_h, preds_h, alpha=0.4, s=3)
    # Shared axis limits spanning both series, ignoring NaNs
    lims = [
        min(np.nanmin(true_h), np.nanmin(preds_h)),
        max(np.nanmax(true_h), np.nanmax(preds_h))
    ]
    plt.plot(lims, lims, 'k--', alpha=0.75)
    plt.xlabel('Verdadero')
    plt.ylabel('Predicción')
    plt.title(f'Modelo Temporal: Horizonte {h+1} (R²={metrics["R2"]:.3f})')
    plt.grid(alpha=0.3)
    plt.savefig(IMAGE_DIR/f"temporal_meta_scatter_h{h+1}_{ref}.png", dpi=150)
    plt.show()

    print_progress(f"Horizonte {h+1}: RMSE={metrics['RMSE']:.2f}, R²={metrics['R2']:.3f}", level=2)

# Build the metrics table, save it as CSV and display it
df_global = pd.DataFrame(global_metrics)
df_global.to_csv(MODEL_DIR/f"temporal_meta_metrics_ref{ref}.csv", index=False)
# NOTE(review): `tools` is not defined in the visible part of the file — confirm it is imported
tools.display_dataframe_to_user(
    name=f"TemporalMeta_Global_metrics_ref{ref}",
    dataframe=df_global
)

# Compare the temporal meta-model against previously saved meta-model metrics
print_progress("Comparando rendimiento entre meta-modelos", level=1)

models_to_compare = []

# DeepMeta (original model): include its metrics if the CSV exists
try:
    deep_meta_path = MODEL_DIR/f"deepmeta_global_metrics_ref{ref}.csv"
    if deep_meta_path.exists():
        df_deep = pd.read_csv(deep_meta_path)
        df_deep['model'] = 'DeepMeta'
        models_to_compare.append(df_deep)
except Exception as e:
    # Best effort: a broken CSV must not abort the comparison, only be reported
    print_progress(f"Error al cargar métricas DeepMeta: {str(e)}", level=2)

# XGBoost meta-model: include its metrics if the CSV exists
try:
    xgb_path = MODEL_DIR/f"xgb_meta_metrics_ref{ref}.csv"
    if xgb_path.exists():
        df_xgb = pd.read_csv(xgb_path)
        df_xgb['model'] = 'XGBoost'
        models_to_compare.append(df_xgb)
except Exception as e:
    print_progress(f"Error al cargar métricas XGBoost: {str(e)}", level=2)

# New temporal model (metrics computed in this run)
df_temporal = df_global.copy()
df_temporal['model'] = 'TemporalMeta'
models_to_compare.append(df_temporal)

# Merge and show the comparison only when there is something to compare against
if len(models_to_compare) > 1:
    df_comparison = pd.concat(models_to_compare, ignore_index=True)
    # One row per horizon, one column per model
    pivot_rmse = df_comparison.pivot_table(
        index='horizon', columns='model', values='RMSE', aggfunc='mean'
    ).reset_index()
    pivot_r2 = df_comparison.pivot_table(
        index='horizon', columns='model', values='R2', aggfunc='mean'
    ).reset_index()

    # Save and display the comparison tables
    pivot_rmse.to_csv(MODEL_DIR/f"metamodels_rmse_comparison_{ref}.csv", index=False)
    pivot_r2.to_csv(MODEL_DIR/f"metamodels_r2_comparison_{ref}.csv", index=False)

    tools.display_dataframe_to_user(
        name=f"MetaModels_RMSE_Comparison_{ref}",
        dataframe=pivot_rmse
    )
    tools.display_dataframe_to_user(
        name=f"MetaModels_R2_Comparison_{ref}",
        dataframe=pivot_r2
    )

    # Grouped bar chart of R² per horizon and model
    plt.figure(figsize=(10, 6))

    models = [col for col in pivot_r2.columns if col != 'horizon']
    x = np.arange(len(pivot_r2))
    width = 0.8 / len(models)

    # The loop variable is deliberately NOT called `model`: that name holds
    # the trained network at module level and must survive this cell.
    for i, model_name in enumerate(models):
        plt.bar(
            # Offset each model's bars so the group is centred on the tick
            x + width * (i - len(models)/2 + 0.5),
            pivot_r2[model_name],
            width=width,
            label=model_name
        )

    plt.xlabel('Horizonte')
    plt.ylabel('R²')
    plt.title('Comparación de R² entre Meta-modelos')
    plt.xticks(x, pivot_r2['horizon'])
    plt.legend()
    plt.grid(alpha=0.3)
    plt.savefig(IMAGE_DIR/f"metamodels_comparison_{ref}.png", dpi=150)
    plt.show()

print_progress("Análisis de meta-modelos completado", is_end=True)

# Para la parte de mapas, usamos el mismo código base pero con el nuevo modelo
# ...resto del código continúa igual...

# Variables Utilizadas por los Modelos

Se utilizan las variables definidas en `ALL_FEATURES`: