In [None]:
# ╭────────────────────── IMPORTS Y CONFIGURACIÓN ──────────────────╮
from __future__ import annotations

import sys
from pathlib import Path
from typing import List, Dict, Any, Tuple, Optional
import numpy as np
import pandas as pd
import xarray as xr
import tensorflow as tf
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import matplotlib.pyplot as plt
import geopandas as gpd
import imageio.v2 as imageio
from IPython.display import display, Image as IPImage, HTML
import warnings
warnings.filterwarnings('ignore')

from tensorflow.keras.layers import (
    Input, ConvLSTM2D, GRU, Flatten, RepeatVector, Reshape,
    TimeDistributed, Dense, MultiHeadAttention, Add,
    LayerNormalization, Embedding, Concatenate, Lambda
)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

# Matplotlib defaults: on-screen DPI, saved-figure DPI and base font size.
plt.rcParams['figure.dpi'] = 100
plt.rcParams['savefig.dpi'] = 150
plt.rcParams['font.size'] = 10

# ╭─────────────────────────── PATHS ──────────────────────────╮
# Resolve BASE_PATH: a fixed Google Drive folder when running inside Colab,
# otherwise the nearest directory (cwd or one of its parents) that contains
# a `.git` entry, falling back to the cwd itself when none is found.
IN_COLAB = 'google.colab' in sys.modules
if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)
    BASE_PATH = Path('/content/drive/MyDrive/ml_precipitation_prediction')
    # NOTE(review): this install runs after the imports at the top of the cell
    # have already executed; only the cartopy import below is deferred until
    # after it — confirm the other packages are preinstalled in the runtime.
    %pip install -q xarray netCDF4 matplotlib seaborn scikit-learn cartopy geopandas imageio
else:
    BASE_PATH = Path.cwd()
    for p in [BASE_PATH, *BASE_PATH.parents]:
        if (p / '.git').exists():
            BASE_PATH = p
            break

# Imported here (not with the other imports) so it runs after the Colab install.
import cartopy.crs as ccrs
print(f'📁 BASE_PATH = {BASE_PATH}')

# Directory layout, all relative to BASE_PATH.
DATA_DIR = BASE_PATH / 'data' / 'output'
MODEL_DIR = BASE_PATH / 'models' / 'output' / 'HybridLSTMModels'
MODEL_INPUT_DIR = BASE_PATH / 'data' / 'input' / 'shapes'
IMAGE_DIR = MODEL_DIR / 'images'
GIF_DIR = MODEL_DIR / 'gifs'

# Create the output directories if they do not exist yet
# (input directories are expected to be present already).
for dir_path in [MODEL_DIR, IMAGE_DIR, GIF_DIR]:
    dir_path.mkdir(parents=True, exist_ok=True)

# Data files: the cleaned NetCDF dataset used by the rest of the notebook.
FULL_NC_CLEAN = DATA_DIR / 'complete_dataset_with_features_with_clusters_elevation_windows_imfs_with_onehot_elevation_clean.nc'

# ╭──────────────────── HIPERPARÁMETROS ────────────────────╮
class Config:
    """Central place for window sizes, training settings and feature sets."""

    # --- Sequence geometry and prediction target ---
    TARGET_VAR = 'total_precipitation'
    INPUT_WINDOW = 60   # number of past time steps fed to the encoder
    HORIZON = 3         # number of future time steps predicted

    # --- Training settings ---
    LR = 1e-3
    BATCH_SIZE = 16
    EPOCHS = 50
    PATIENCE = 40       # early-stopping patience, in epochs

    # --- Feature groups ---
    # Lagged copies of the target variable.
    LAG_VARS = [
        'total_precipitation_lag1',
        'total_precipitation_lag2',
        'total_precipitation_lag12',
    ]
    # One-hot elevation cluster indicators.
    ELEV_CLUSTER = ['elev_high', 'elev_med', 'elev_low']
    # Calendar, daily-precipitation statistics and terrain variables.
    BASE_FEATURES = [
        'year', 'month', 'month_sin', 'month_cos', 'doy_sin', 'doy_cos',
        'max_daily_precipitation', 'min_daily_precipitation', 'daily_precipitation_std',
        'elevation', 'slope', 'aspect',
    ]
    # Base features plus elevation clusters.
    KCE_FEATURES = [*BASE_FEATURES, *ELEV_CLUSTER]
    # KCE features plus the lagged target variables.
    PAFC_FEATURES = [*KCE_FEATURES, *LAG_VARS]

print("✅ Configuración cargada")
# ╰────────────────────────────────────────────────────────────╯


In [None]:
# ╭────────────────────── CARGA DE DATOS ──────────────────╮
print("📊 Cargando dataset...")

# Reuse the cleaned dataset when it already exists on disk; otherwise build
# it from the original file and save it for later runs.
if not FULL_NC_CLEAN.exists():
    print("⚠️ Archivo limpio no encontrado. Procesando dataset original...")
    FULL_NC = DATA_DIR / 'complete_dataset_with_features_with_clusters_elevation_windows_imfs_with_onehot_elevation.nc'
    ds_raw = xr.open_dataset(FULL_NC)
    
    # Report, per lag variable, how many values are NaN (count and share).
    print("\n📊 Resumen de NaNs en variables lag:")
    print("─" * 55)
    for var in Config.LAG_VARS:
        arr = ds_raw[var].values
        total = arr.size
        n_nans = int(np.isnan(arr).sum())
        print(f"{var:<28}: {n_nans:>8,} / {total:,} ({n_nans/total:6.2%})")
    
    # Drop every timestamp whose year is 1981 (per the original note, that
    # year has many NaNs).
    ds_clean = ds_raw.sel(time=~(ds_raw['time.year'] == 1981))
    print(f"\n🔄 Timestamps: {len(ds_raw.time)} → {len(ds_clean.time)} (removido 1981)")
    
    # Persist the cleaned dataset, overwriting any existing file.
    ds_clean.to_netcdf(FULL_NC_CLEAN, mode='w')
    print(f"💾 Dataset limpio guardado en {FULL_NC_CLEAN}")
    ds = ds_clean
else:
    ds = xr.open_dataset(FULL_NC_CLEAN)
    print(f"✅ Dataset cargado desde {FULL_NC_CLEAN}")

# Load the department boundaries shapefile (drawn on maps by quick_plot).
dept_gdf = gpd.read_file(MODEL_INPUT_DIR / 'MGN_Departamento.shp')

# Grid dimensions: number of latitude and longitude points.
lat, lon = len(ds.latitude), len(ds.longitude)
print(f"\n📐 Dimensiones: {lat} x {lon} = {lat * lon} celdas")
print(f"📅 Período: {ds.time.values[0]} a {ds.time.values[-1]}")
# ╰────────────────────────────────────────────────────────────╯


In [None]:
# ╭────────────────────── FUNCIONES UTILITARIAS ──────────────────╮

def evaluate_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]:
    """Return RMSE, MAE, MAPE (in percent) and R² for a pair of arrays.

    The MAPE denominator is ``y_true + 1e-5`` so that exact zeros in
    ``y_true`` do not divide by zero.
    """
    residual = y_true - y_pred
    return {
        'RMSE': np.sqrt(mean_squared_error(y_true, y_pred)),
        'MAE': mean_absolute_error(y_true, y_pred),
        'MAPE': np.mean(np.abs(residual / (y_true + 1e-5))) * 100,
        'R2': r2_score(y_true, y_pred),
    }


def quick_plot(ax, data, cmap, title, date_label, vmin=None, vmax=None, show_departments=True):
    """Draw ``data`` as a colour mesh on ``ax`` and return the mesh.

    The mesh is plotted against the longitude/latitude coordinates of the
    global dataset ``ds``. Coastlines and labelled gridlines (left/bottom
    labels only) are always added; department borders from the global
    ``dept_gdf`` are added when ``show_departments`` is true and
    ``dept_gdf`` is not ``None``. The title is ``title`` with
    ``date_label`` on a second line.
    """
    mesh = ax.pcolormesh(
        ds.longitude, ds.latitude, data,
        transform=ccrs.PlateCarree(),
        shading='nearest',
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
    )
    ax.coastlines()

    draw_borders = show_departments and dept_gdf is not None
    if draw_borders:
        ax.add_geometries(
            dept_gdf.geometry, ccrs.PlateCarree(),
            facecolor='none', edgecolor='black', linewidth=1,
        )

    # Keep tick labels on the left and bottom edges only.
    grid = ax.gridlines(draw_labels=True)
    grid.right_labels = False
    grid.top_labels = False

    ax.set_title(f"{title}\n{date_label}", pad=12)
    return mesh


def generate_and_display_gif(y_true_sample, y_pred_sample, tag, show_in_notebook=True):
    """Render one map per forecast step of ``y_pred_sample`` and save them as a GIF.

    Args:
        y_true_sample: Accepted for call-site symmetry; not used by this
            function (only the predictions are drawn).
        y_pred_sample: Array indexed as ``[h, ..., 0]`` for each horizon step
            ``h`` in ``range(Config.HORIZON)``; each slice is passed to
            ``quick_plot`` as a 2-D map.
        tag: Used as the plot title and as the stem of the GIF file name.
        show_in_notebook: When true, display a heading and the GIF inline.

    Returns:
        Path of the written GIF (``GIF_DIR / f"{tag}.gif"``).
    """
    # Shared colour scale across frames so steps are visually comparable.
    pcm_min, pcm_max = 0, np.max(y_pred_sample)
    frames = []

    for h in range(Config.HORIZON):
        pmap = y_pred_sample[h, ..., 0]
        fig, ax = plt.subplots(1, 1, figsize=(8, 6),
                               subplot_kw={'projection': ccrs.PlateCarree()})

        # Each frame goes through a temporary PNG that is read back as an array.
        tmp = GIF_DIR / f"tmp_{tag}_h{h}.png"
        try:
            try:
                mesh = quick_plot(ax, pmap, 'Blues', f"{tag}", f"Horizonte {h+1}",
                                  vmin=pcm_min, vmax=pcm_max)
                fig.colorbar(mesh, ax=ax, fraction=0.046, pad=0.04, label='Precipitación (mm)')
                fig.savefig(tmp, bbox_inches='tight', dpi=100)
            finally:
                # Always release the figure, even if plotting or saving failed.
                plt.close(fig)

            frames.append(imageio.imread(tmp))
        finally:
            # Never leave temporary frames behind, even on failure.
            tmp.unlink(missing_ok=True)

    # Write the GIF.
    gif_path = GIF_DIR / f"{tag}.gif"
    imageio.mimsave(gif_path, frames, fps=0.5)
    print(f"💾 GIF guardado: {gif_path.name}")

    # Show it inline in the notebook.
    if show_in_notebook:
        display(HTML(f'<h4>🎬 {tag}</h4>'))
        display(IPImage(filename=str(gif_path)))

    return gif_path


def plot_training_history(history, tag, show_in_notebook=True):
    """Plot train/validation loss curves, save them as a PNG and return its path.

    ``history`` must expose ``history.history['loss']`` and
    ``history.history['val_loss']``. The image is written to
    ``IMAGE_DIR / f"{tag}_history.png"``. The figure is shown when
    ``show_in_notebook`` is true and closed otherwise.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    curves = (('loss', 'Train Loss'), ('val_loss', 'Validation Loss'))
    for key, label in curves:
        ax.plot(history.history[key], label=label, linewidth=2)

    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss (MSE)')
    ax.set_title(f'Training History - {tag}')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Save before showing/closing so the file always reflects the full figure.
    img_path = IMAGE_DIR / f"{tag}_history.png"
    fig.savefig(img_path, bbox_inches='tight', dpi=150)

    finish = plt.show if show_in_notebook else plt.close
    finish()

    return img_path


def make_windows(mask: np.ndarray, Xarr: np.ndarray, yarr: np.ndarray, 
                allow_past_context: bool) -> Tuple[np.ndarray, np.ndarray]:
    """Build sliding (input, target) windows along the first axis.

    Each window takes ``Config.INPUT_WINDOW`` steps of ``Xarr`` as input and
    the following ``Config.HORIZON`` steps of ``yarr`` as target.

    Args:
        mask: Boolean array over time marking the steps that belong to the split.
        Xarr: Feature array, time on the first axis.
        yarr: Target array, time on the first axis.
        allow_past_context: If true, only the target steps must be inside
            ``mask`` (the input may come from outside it); if false, the
            whole input+target span must be inside ``mask``.

    Returns:
        ``(X, y)`` arrays stacking the kept windows. Windows containing any
        NaN in the input or the target are discarded.
    """
    inputs, targets = [], []
    n_starts = len(mask) - Config.INPUT_WINDOW - Config.HORIZON + 1

    for first in range(n_starts):
        split = first + Config.INPUT_WINDOW
        last = split + Config.HORIZON

        # Span that has to lie entirely inside the split's mask.
        check_from = split if allow_past_context else first
        if not mask[check_from:last].all():
            continue

        x_win = Xarr[first:split]
        y_win = yarr[split:last]

        has_nan = np.isnan(x_win).any() or np.isnan(y_win).any()
        if has_nan:
            continue

        inputs.append(x_win)
        targets.append(y_win)

    return np.array(inputs), np.array(targets)


def impute_nans(a: np.ndarray, per_feature_mean: Optional[np.ndarray] = None, 
                is_target: bool = False) -> np.ndarray:
    """Fill remaining NaNs (extra safety net) without modifying the input.

    Args:
        a: Array to clean. For features, the last axis is the feature axis.
        per_feature_mean: One fill value per feature (length ``a.shape[-1]``);
            required when ``is_target`` is false and ``a`` contains NaNs.
        is_target: If true, NaNs are replaced by ``0.0`` and
            ``per_feature_mean`` is ignored.

    Returns:
        ``a`` itself when it contains no NaNs; otherwise a filled copy with
        the same shape and dtype. The caller's array is never mutated.

    Raises:
        ValueError: If features need imputing and ``per_feature_mean`` is
            missing, or if it cannot be broadcast along the last axis.
    """
    nan_mask = np.isnan(a)
    if not nan_mask.any():
        return a

    # Work on a copy: the caller's array must not change behind its back.
    filled = a.copy()

    if is_target:
        filled[nan_mask] = 0.0
        return filled

    if per_feature_mean is None:
        raise ValueError('per_feature_mean required for imputing X')

    # Broadcast the per-feature values over every leading axis, then copy them
    # into the NaN positions only (assignment keeps ``filled``'s dtype).
    means = np.broadcast_to(np.asarray(per_feature_mean), a.shape)
    filled[nan_mask] = means[nan_mask]
    return filled

print("✅ Funciones utilitarias definidas")
# ╰────────────────────────────────────────────────────────────╯


In [None]:
# ╭────────────────────── DEFINICIÓN DEL MODELO ──────────────────╮

@tf.keras.utils.register_keras_serializable()
def tile_step_emb(batch_ref, step_emb_tab):
    """Repeat the embedding table ``step_emb_tab`` once per batch element.

    When ``batch_ref`` is a ``tf.TensorShape``/``tf.TensorSpec`` the function
    only returns the resulting shape (leading entry of ``batch_ref`` followed
    by the two table dimensions). Otherwise it returns the table tiled along
    a new leading axis whose length is the first dimension of ``batch_ref``.
    """
    shape_like = (tf.TensorShape, tf.TensorSpec)
    if isinstance(batch_ref, shape_like):
        leading = batch_ref[0]
        n_rows, n_cols = step_emb_tab.shape[0], step_emb_tab.shape[1]
        return tf.TensorShape([leading, n_rows, n_cols])

    batch_size = tf.shape(batch_ref)[0]
    table = tf.expand_dims(step_emb_tab, 0)
    multiples = [batch_size, 1, 1]
    return tf.tile(table, multiples)


def build_convlstm_ed(
    *,
    input_window: int,
    output_horizon: int,
    spatial_height: int,
    spatial_width: int,
    n_features: int,
    n_filters: int = 64,
    n_heads: int = 4,
    use_attention: bool = True,
    use_positional_emb: bool = True,
    lr: float = 1e-3
) -> Model:
    """Build and compile a ConvLSTM encoder / GRU decoder forecasting model.

    Encoder: two stacked ``ConvLSTM2D`` layers (``n_filters`` then
    ``n_filters // 2`` filters, 3x3 kernels). The final encoder state is
    flattened and repeated ``output_horizon`` times as decoder context.

    Decoder: a GRU with ``2 * n_filters`` units, optionally followed by
    multi-head self-attention with a residual connection and layer
    normalisation. Each decoder step is projected to
    ``spatial_height * spatial_width`` values and reshaped to
    ``(output_horizon, spatial_height, spatial_width, 1)``.

    Args:
        input_window: Number of input time steps.
        output_horizon: Number of predicted time steps.
        spatial_height: Grid height.
        spatial_width: Grid width.
        n_features: Number of channels per grid cell in the input.
        n_filters: Base filter count (also the step-embedding width).
        n_heads: Number of attention heads.
        use_attention: Add the self-attention block after the GRU.
        use_positional_emb: Add a second integer input ``step_ids`` whose
            embedding is concatenated to the context; the model then takes
            ``[enc_input, step_ids]`` instead of a single input.
        lr: Adam learning rate.

    Returns:
        The compiled model (MSE loss, Adam optimizer).
    """
    kernel = (3, 3)

    # ──────────────── Encoder ────────────────
    enc_inputs = Input(
        shape=(input_window, spatial_height, spatial_width, n_features),
        name="enc_input",
    )
    encoded = ConvLSTM2D(
        n_filters, kernel, padding='same',
        return_sequences=True, name="enc_lstm_1",
    )(enc_inputs)
    encoded = ConvLSTM2D(
        n_filters // 2, kernel, padding='same',
        return_sequences=False, name="enc_lstm_2",
    )(encoded)

    # ── Flatten the spatial state and repeat it once per output step ──
    flattened = Flatten(name="flatten_spatial")(encoded)
    context = RepeatVector(output_horizon, name="context")(flattened)

    # ── Optional learned embedding of the output step index ──
    if not use_positional_emb:
        decoder_input = context
        model_inputs = enc_inputs
    else:
        step_ids_input = Input(shape=(output_horizon,), dtype=tf.int32, name="step_ids")
        embedding = Embedding(output_horizon, n_filters, name="step_embedding")
        step_vectors = embedding(step_ids_input)
        decoder_input = Concatenate(name="dec_concat")([context, step_vectors])
        model_inputs = [enc_inputs, step_ids_input]

    # ─────────────── Decoder ───────────────
    decoded = GRU(2 * n_filters, return_sequences=True, name="dec_gru")(decoder_input)

    # ─────── Self-attention with residual + norm (optional) ───────
    if use_attention:
        attended = MultiHeadAttention(
            num_heads=n_heads, key_dim=n_filters,
            dropout=0.1, name="mha",
        )(decoded, decoded)
        decoded = Add(name="mha_residual")([decoded, attended])
        decoded = LayerNormalization(name="mha_norm")(decoded)

    # ───────────── Projection back onto the grid ─────────────
    projected = TimeDistributed(
        Dense(spatial_height * spatial_width, activation='linear'),
        name="dense_proj",
    )(decoded)
    outputs = Reshape(
        (output_horizon, spatial_height, spatial_width, 1),
        name="reshape_out",
    )(projected)

    # Model name encodes which optional blocks are enabled.
    model_name = "ConvLSTM_ED"
    if use_attention:
        model_name += "_Attn"
    if use_positional_emb:
        model_name += "_PE"

    model = Model(model_inputs, outputs, name=model_name)
    model.compile(optimizer=Adam(lr), loss='mse')
    return model


# Factories para diferentes configuraciones
def factory_no_attn(**kw):
    """Build the encoder-decoder model with the attention block disabled.

    All keyword arguments are forwarded to ``build_convlstm_ed``.
    """
    model = build_convlstm_ed(use_attention=False, **kw)
    return model

def factory_attn(**kw):
    """Build the encoder-decoder model with the attention block enabled.

    All keyword arguments are forwarded to ``build_convlstm_ed``.
    """
    model = build_convlstm_ed(use_attention=True, **kw)
    return model

print("✅ Arquitectura del modelo definida")
# ╰────────────────────────────────────────────────────────────╯


In [None]:
# ╭────────────────────── CONFIGURACIÓN DE EXPERIMENTOS ──────────────────╮

# Validation folds: each entry names the year held out for validation.
# Entries with 'active': False are skipped by run_experiments.
FOLDS = {
    'F1': {'year': 2018, 'active': True}
}

# Experiment definitions, keyed by experiment name. Each entry gives:
#   'active'       – whether run_experiments should run it,
#   'feature_list' – dataset variables used as input channels,
#   'builder'      – callable that returns a compiled model,
#   'n_filters' / 'n_heads' – size arguments forwarded to the builder.
EXPERIMENTS = {
    'ConvLSTM-ED': {
        'active': True,
        'feature_list': Config.BASE_FEATURES,
        'builder': factory_attn,
        'n_filters': 64,
        'n_heads': 4
    },
    'ConvLSTM-ED-KCE': {
        'active': True,
        'feature_list': Config.KCE_FEATURES,
        'builder': factory_attn,
        'n_filters': 64,
        'n_heads': 4,
    },
    'ConvLSTM-ED-KCE-PAFC': {
        'active': True,
        'feature_list': Config.PAFC_FEATURES,
        'builder': factory_attn,
        'n_filters': 96,
        'n_heads': 6,
    },
}

# Print a one-line summary (name and feature count) per active experiment.
print("📋 Experimentos configurados:")
for exp_name, exp_cfg in EXPERIMENTS.items():
    if exp_cfg['active']:
        print(f"  • {exp_name}: {len(exp_cfg['feature_list'])} features")
# ╰────────────────────────────────────────────────────────────╯


In [None]:
# ╭────────────────────── ENTRENAMIENTO Y EVALUACIÓN ──────────────────╮

def run_experiments(show_visualizations=True):
    """Train and evaluate every active (experiment, fold) combination.

    For each pair the function builds sliding windows, fits scalers on the
    training windows, trains the model with early stopping, saves it under
    ``MODEL_DIR / f"{tag}.keras"``, plots the training history, computes
    validation metrics and writes a GIF of one predicted sample. A pair
    whose model file already exists is skipped before any data preparation.

    Reads the notebook globals ``ds``, ``lat``, ``lon``, ``EXPERIMENTS``,
    ``FOLDS``, ``Config`` and ``MODEL_DIR``.

    Args:
        show_visualizations: Forwarded to the plotting helpers to decide
            whether figures/GIFs are displayed inline.

    Returns:
        List with one metrics dict per trained model (keys ``RMSE``,
        ``MAE``, ``MAPE``, ``R2``, ``experiment``, ``fold``, ``epochs``).
        When non-empty it is also written to
        ``MODEL_DIR / "metrics_experiments.csv"``.
    """
    times = pd.to_datetime(ds.time.values)
    results = []

    # Number of active (experiment, fold) pairs; used only for progress output.
    total_exp = sum(e['active'] for e in EXPERIMENTS.values()) * sum(f['active'] for f in FOLDS.values())
    exp_count = 0

    for exp_name, exp_cfg in EXPERIMENTS.items():
        if not exp_cfg['active']:
            continue

        # Experiment settings.
        feature_list = exp_cfg['feature_list']
        builder = exp_cfg['builder']
        n_filters = exp_cfg.get('n_filters', 64)
        n_heads = exp_cfg.get('n_heads', 4)

        # Load features as (time, latitude, longitude, variable) and the target.
        print(f"\n{'='*60}")
        print(f"🔬 Experimento: {exp_name}")
        print(f"{'='*60}")

        Xarr = ds[feature_list].to_array().transpose('time', 'latitude', 'longitude', 'variable').values.astype(np.float32)
        yarr = ds[Config.TARGET_VAR].values.astype(np.float32)
        n_features = Xarr.shape[-1]

        for fold_name, fold_cfg in FOLDS.items():
            if not fold_cfg['active']:
                continue

            exp_count += 1
            year_val = fold_cfg['year']

            print(f"\n▶️ [{exp_count}/{total_exp}] {exp_name} - {fold_name} (val={year_val})")

            # Skip already-trained models *before* the expensive window
            # generation and scaling below.
            tag = f"{exp_name.replace('+', '_')}_{fold_name}"
            model_path = MODEL_DIR / f"{tag}.keras"

            if model_path.exists():
                print(f"⏩ Modelo {tag} ya existe → skip")
                continue

            # Temporal masks: the validation year vs. everything else.
            mask_val = times.year == year_val
            mask_tr = ~mask_val

            if mask_val.sum() < Config.HORIZON:
                print("⚠️ Año sin suficientes datos → skip")
                continue

            # Sliding windows. Training windows must lie fully outside the
            # validation year; validation windows only need their targets
            # inside it (their inputs may reach into earlier data).
            X_tr, y_tr = make_windows(mask_tr, Xarr, yarr, allow_past_context=False)
            X_va, y_va = make_windows(mask_val, Xarr, yarr, allow_past_context=True)

            print(f"📊 Ventanas - Train: {len(X_tr)}, Val: {len(X_va)}")

            if len(X_tr) == 0 or len(X_va) == 0:
                print("⚠️ Sin ventanas válidas → skip")
                continue

            # NaN imputation, using per-feature means of the training windows.
            feat_mean = np.nanmean(X_tr.reshape(-1, n_features), axis=0)
            X_tr = impute_nans(X_tr, feat_mean)
            X_va = impute_nans(X_va, feat_mean)
            y_tr = impute_nans(y_tr, is_target=True)
            y_va = impute_nans(y_va, is_target=True)

            # Scaling: both scalers are fit on training data only.
            sx = StandardScaler().fit(X_tr.reshape(-1, n_features))
            sy = StandardScaler().fit(y_tr.reshape(-1, 1))

            X_tr_sc = sx.transform(X_tr.reshape(-1, n_features)).reshape(X_tr.shape)
            X_va_sc = sx.transform(X_va.reshape(-1, n_features)).reshape(X_va.shape)
            # Targets get a trailing channel axis to match the model output.
            y_tr_sc = sy.transform(y_tr.reshape(-1, 1)).reshape(y_tr.shape)[..., None]
            y_va_sc = sy.transform(y_va.reshape(-1, 1)).reshape(y_va.shape)[..., None]

            # Build the model.
            model = builder(
                input_window=Config.INPUT_WINDOW,
                output_horizon=Config.HORIZON,
                spatial_height=lat,
                spatial_width=lon,
                n_features=n_features,
                n_filters=n_filters,
                n_heads=n_heads,
                lr=Config.LR
            )

            # Models with a positional embedding take a second input holding
            # the output step indices 0..HORIZON-1 for every sample.
            uses_pe = len(model.inputs) > 1

            if uses_pe:
                step_ids_train = np.tile(np.arange(Config.HORIZON), (len(X_tr_sc), 1))
                step_ids_val = np.tile(np.arange(Config.HORIZON), (len(X_va_sc), 1))
                X_train_input = [X_tr_sc, step_ids_train]
                X_val_input = [X_va_sc, step_ids_val]
            else:
                X_train_input = X_tr_sc
                X_val_input = X_va_sc

            # Train.
            print("🏃 Entrenando modelo...")
            es = EarlyStopping(monitor='val_loss', patience=Config.PATIENCE, restore_best_weights=True)

            history = model.fit(
                X_train_input, y_tr_sc,
                validation_data=(X_val_input, y_va_sc),
                epochs=Config.EPOCHS,
                batch_size=Config.BATCH_SIZE,
                callbacks=[es],
                verbose=1
            )

            # Save the model.
            model.save(model_path)
            print(f"💾 Modelo guardado: {model_path.name}")

            # Plot the training history.
            plot_training_history(history, tag, show_in_notebook=show_visualizations)

            # Predict on the validation windows and undo the target scaling.
            y_hat_sc = model.predict(X_val_input, verbose=0)

            y_hat = sy.inverse_transform(y_hat_sc.reshape(-1, 1)).reshape(y_hat_sc.shape)
            y_true = sy.inverse_transform(y_va_sc.reshape(-1, 1)).reshape(y_va_sc.shape)

            # Metrics over all validation samples, steps and cells.
            metrics = evaluate_metrics(y_true.ravel(), y_hat.ravel())
            metrics.update({
                'experiment': exp_name,
                'fold': fold_name,
                'epochs': len(history.history['loss'])
            })
            results.append(metrics)

            print(f"\n📈 Métricas: RMSE={metrics['RMSE']:.3f}, MAE={metrics['MAE']:.3f}, "
                  f"MAPE={metrics['MAPE']:.1f}%, R²={metrics['R2']:.3f}")

            # Sanity check on the first validation sample: the maps for the
            # different horizon steps should not all be identical.
            print("\n🔍 Verificación de predicciones por horizonte:")
            predictions_vary = False

            for h in range(Config.HORIZON):
                pred_h = y_hat[0, h, ..., 0]
                stats = {
                    'min': pred_h.min(),
                    'max': pred_h.max(),
                    'mean': pred_h.mean(),
                    'std': pred_h.std()
                }
                print(f"  H{h+1}: min={stats['min']:.3f}, max={stats['max']:.3f}, "
                      f"mean={stats['mean']:.3f}, std={stats['std']:.3f}")

                if h > 0:
                    diff = np.abs(y_hat[0, h] - y_hat[0, 0]).mean()
                    if diff > 0.001:
                        predictions_vary = True

            if not predictions_vary and Config.HORIZON > 1:
                print("⚠️ ADVERTENCIA: Las predicciones parecen idénticas entre horizontes")
            else:
                print("✅ Las predicciones varían correctamente entre horizontes")

            # GIF of one validation sample (index 10, or the last one if
            # there are fewer than 11 samples).
            last_idx = min(len(y_hat) - 1, 10)
            generate_and_display_gif(y_true[last_idx], y_hat[last_idx], tag, 
                                   show_in_notebook=show_visualizations)

    # Persist and summarise the collected metrics.
    if results:
        df_results = pd.DataFrame(results)
        csv_path = MODEL_DIR / "metrics_experiments.csv"
        df_results.to_csv(csv_path, index=False)
        print(f"\n📊 Métricas guardadas en: {csv_path}")

        print("\n📋 Resumen de resultados:")
        display(df_results[['experiment', 'fold', 'RMSE', 'MAE', 'MAPE', 'R2']].round(3))

    return results

# Ejecutar experimentos
print("🚀 Iniciando experimentos...\n")
results = run_experiments(show_visualizations=True)
# ╰────────────────────────────────────────────────────────────╯


In [None]:
# ╭────────────────────── EVALUACIÓN DE MODELOS GUARDADOS ──────────────────╮

def evaluate_saved_models(show_visualizations=True):
    """Evaluate every ``*.keras`` model in ``MODEL_DIR`` on the last window.

    For each model file the experiment name and fold are parsed from the
    file stem, the final ``Config.INPUT_WINDOW`` steps of the dataset are
    used as input, and the prediction is compared against the last
    ``Config.HORIZON`` steps, one metrics row per horizon step.

    Reads the notebook globals ``ds``, ``lat``, ``lon``, ``EXPERIMENTS``,
    ``Config`` and ``MODEL_DIR``.

    Args:
        show_visualizations: When true, plot real / predicted / error maps
            per horizon step, save them to ``MODEL_DIR / f"eval_{tag}.png"``
            and generate a GIF of the predictions.

    Returns:
        DataFrame with one row per (model, horizon). It is empty when no
        model could be evaluated; otherwise it is also written to
        ``MODEL_DIR / 'metrics_evaluation.csv'``.
    """
    print("\n🔍 Evaluando modelos guardados...\n")

    all_metrics = []
    times = pd.to_datetime(ds.time.values)

    # Custom objects needed to deserialize the saved models.
    custom_objects = {'tile_step_emb': tile_step_emb}

    for model_path in sorted(MODEL_DIR.glob("*.keras")):
        # File stems look like "<experiment>_<fold>"; experiment names use
        # '-' where the stem may contain '_'.
        tag = model_path.stem
        parts = tag.split("_")
        fold = parts[-1]
        exp_name = "_".join(parts[:-1]).replace("_", "-")

        if exp_name not in EXPERIMENTS:
            print(f"⚠️ Experimento no encontrado para {tag}")
            continue

        print(f"\n📊 Evaluando: {tag}")

        # Load features as (time, latitude, longitude, variable) and the target.
        feature_list = EXPERIMENTS[exp_name]['feature_list']
        Xarr = ds[feature_list].to_array().transpose('time', 'latitude', 'longitude', 'variable').values.astype(np.float32)
        yarr = ds[Config.TARGET_VAR].values.astype(np.float32)
        T, _, _, F = Xarr.shape

        # Final window: inputs end where the last HORIZON target steps begin.
        start = T - Config.INPUT_WINDOW - Config.HORIZON
        end_w = start + Config.INPUT_WINDOW
        end_y = end_w + Config.HORIZON

        X_eval = Xarr[start:end_w]
        y_eval = yarr[end_w:end_y]

        # Scaling.
        # NOTE(review): these scalers are fit on the full arrays, whereas
        # run_experiments fits them on the training windows only, so the
        # statistics differ from those the model was trained with. Persisting
        # the training scalers next to the model would remove the mismatch.
        sx = StandardScaler().fit(Xarr.reshape(-1, F))
        sy = StandardScaler().fit(yarr.reshape(-1, 1))

        Xe_sc = sx.transform(X_eval.reshape(-1, F)).reshape(1, Config.INPUT_WINDOW, lat, lon, F)

        # Load the model (inference only, so no need to compile).
        model = tf.keras.models.load_model(model_path, compile=False, custom_objects=custom_objects)

        # Predict; models with a positional embedding need the step indices.
        uses_pe = len(model.inputs) > 1
        if uses_pe:
            step_ids_eval = np.tile(np.arange(Config.HORIZON), (1, 1))
            yhat_sc = model.predict([Xe_sc, step_ids_eval], verbose=0)
        else:
            yhat_sc = model.predict(Xe_sc, verbose=0)

        yhat = sy.inverse_transform(yhat_sc.reshape(-1, 1)).reshape(Config.HORIZON, lat, lon)

        # Metrics per horizon step.
        for h in range(Config.HORIZON):
            yt = y_eval[h].ravel()
            yp = yhat[h].ravel()

            # Keep only cells where both truth and prediction are finite.
            mask = np.isfinite(yt) & np.isfinite(yp)
            if mask.sum() == 0:
                continue

            yt, yp = yt[mask], yp[mask]

            metrics = evaluate_metrics(yt, yp)
            metrics.update({
                'model': tag,
                'experiment': exp_name,
                'fold': fold,
                'horizon': h + 1
            })
            all_metrics.append(metrics)

        # Comparative maps: real, predicted and percentage error per step.
        if show_visualizations:
            fig, axes = plt.subplots(Config.HORIZON, 3, figsize=(15, 5 * Config.HORIZON),
                                   subplot_kw={'projection': ccrs.PlateCarree()})

            if Config.HORIZON == 1:
                axes = axes.reshape(1, -1)

            # Label each row with the dataset's own timestamp for that step
            # instead of synthesizing month-start dates.
            dates = times[end_w:end_y]
            vmin, vmax = 0, max(yhat.max(), y_eval.max())

            for h in range(Config.HORIZON):
                # Real
                quick_plot(axes[h, 0], y_eval[h], 'Blues', f"Real H{h+1}",
                          dates[h].strftime('%Y-%m'), vmin, vmax)

                # Prediction
                quick_plot(axes[h, 1], yhat[h], 'Blues', f"Pred H{h+1}",
                          dates[h].strftime('%Y-%m'), vmin, vmax)

                # Absolute percentage error, clipped to [0, 100].
                err = np.clip(np.abs((y_eval[h] - yhat[h]) / (y_eval[h] + 1e-5)) * 100, 0, 100)
                quick_plot(axes[h, 2], err, 'Reds', f"MAPE% H{h+1}",
                          dates[h].strftime('%Y-%m'), 0, 100)

            fig.suptitle(f"{tag} - Evaluación Final", fontsize=16)
            fig.tight_layout()

            # Save and show.
            eval_img_path = MODEL_DIR / f"eval_{tag}.png"
            fig.savefig(eval_img_path, dpi=150, bbox_inches='tight')
            plt.show()

            # GIF of the predictions (a trailing channel axis is added because
            # the GIF helper indexes its input as [h, ..., 0]).
            generate_and_display_gif(y_eval, yhat[:, :, :, np.newaxis], f"eval_{tag}",
                                   show_in_notebook=True)

    # Always build the frame so the return value is defined even when no
    # model was found or evaluated.
    df_metrics = pd.DataFrame(all_metrics)

    if all_metrics:
        csv_path = MODEL_DIR / 'metrics_evaluation.csv'
        df_metrics.to_csv(csv_path, index=False)
        print(f"\n📊 Métricas de evaluación guardadas en: {csv_path}")

        # Per-model summary: mean of each metric across horizon steps.
        print("\n📋 Resumen de evaluación:")
        summary = df_metrics.groupby('model')[['RMSE', 'MAE', 'MAPE', 'R2']].mean().round(3)
        display(summary)

    return df_metrics

# Ejecutar evaluación
df_evaluation = evaluate_saved_models(show_visualizations=True)
# ╰────────────────────────────────────────────────────────────╯
