<a href="https://colab.research.google.com/github/ninja-marduk/ml_precipitation_prediction/blob/feature%2Fhybrid-models/models/hybrid_models_enconders_layering_w3_ST-HybridWaveStack.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# -*- coding: utf-8 -*-
"""
Multi-branch training with a GRU encoder–decoder and a Transformer for the
low branch, parameterizable validation and forecast, a U-Net + ConvLSTM
meta-model (stacking low, medium, high), parallelism, traceability and the
boundaries of the department of Boyacá.
"""

import sys
from pathlib import Path
import warnings
import logging

# 0) Detect the runtime environment (local / Colab).
IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    from google.colab import drive
    # Mount Google Drive; on Colab the project lives under MyDrive.
    drive.mount("/content/drive", force_remount=True)
    BASE_PATH = Path("/content/drive/MyDrive/ml_precipitation_prediction")
    # IPython shell escape: valid in a notebook cell, not in a plain .py file.
    !pip install -q xarray netCDF4 optuna seaborn cartopy ace_tools_open
else:
    # Use the nearest ancestor of the working directory (including itself)
    # that contains a .git entry; if none does, keep the working directory.
    BASE_PATH = Path.cwd()
    for p in [BASE_PATH, *BASE_PATH.parents]:
        if (p/".git").exists():
            BASE_PATH = p
            break
print(f"▶️ Base path: {BASE_PATH}")

# 1) Silence irrelevant warnings (sklearn convergence, cartopy downloads)
#    and lower TensorFlow's logger to ERROR.
from sklearn.exceptions import ConvergenceWarning
warnings.filterwarnings("ignore", category=ConvergenceWarning)
from cartopy.io import DownloadWarning
warnings.filterwarnings("ignore", category=DownloadWarning)
import tensorflow as tf
tf.get_logger().setLevel("ERROR")

# 2) Configurable parameters.
INPUT_WINDOW    = 60          # input months per sample
OUTPUT_HORIZON  = 3           # months of validation and of forecast
REF_DATE        = "2025-03"   # reference month yyyy-mm (later clamped to the last available month)
MAX_EPOCHS      = 300
PATIENCE_ES     = 30          # EarlyStopping patience (epochs)
LR_FACTOR       = 0.5         # ReduceLROnPlateau multiplicative factor
LR_PATIENCE     = 10          # ReduceLROnPlateau patience (epochs)
DROPOUT         = 0.1
BATCH_SIZE      = 16

# 3) Paths and logger.
MODEL_DIR    = BASE_PATH/"models"/"output"/"trained_models"
MODEL_DIR.mkdir(parents=True, exist_ok=True)
FEATURES_NC  = BASE_PATH/"models"/"output"/"features_fusion_branches.nc"
FULL_NC      = BASE_PATH/"data"/"output"/"complete_dataset_with_features_with_clusters_elevation_with_windows.nc"
SHP_USER     = Path("/mnt/data/MGN_Departamento.shp")
# Prefer the shapefile under /mnt/data when present, else the repository copy.
BOYACA_SHP   = SHP_USER if SHP_USER.exists() else BASE_PATH/"data"/"input"/"shapes"/"MGN_Departamento.shp"
# NOTE(review): named after REF_DATE, not the possibly clamped effective reference month.
RESULTS_CSV  = MODEL_DIR/f"metrics_w{OUTPUT_HORIZON}_ref{REF_DATE}.csv"
IMAGE_DIR    = MODEL_DIR/"images"
IMAGE_DIR.mkdir(exist_ok=True)

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)

# 4) Imports principales
import numpy            as np
import pandas           as pd
import xarray           as xr
import geopandas        as gpd
import matplotlib.pyplot as plt
import cartopy.crs      as ccrs
from sklearn.preprocessing import StandardScaler
import psutil
from joblib import cpu_count
from tensorflow.keras.layers import (
    Input, GRU, RepeatVector, TimeDistributed, Dense,
    MultiHeadAttention, Add, LayerNormalization, Flatten,
    Conv2D, ConvLSTM2D, MaxPooling2D, UpSampling2D, Concatenate, BatchNormalization
)
from tensorflow.keras.models import Model
from tensorflow.keras import callbacks

# 5) Hardware resources: enable GPU memory growth, or size the CPU thread pools.
CORES     = cpu_count()
AVAIL_RAM = psutil.virtual_memory().available / (1024**3)  # GiB
gpus      = tf.config.list_physical_devices("GPU")
USE_GPU   = bool(gpus)
try:
    if USE_GPU:
        # Memory growth must be configured consistently on every visible GPU,
        # not only the first one.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    else:
        tf.config.threading.set_inter_op_parallelism_threads(CORES)
        tf.config.threading.set_intra_op_parallelism_threads(CORES)
except RuntimeError as err:
    # TensorFlow raises RuntimeError once the runtime is initialized (e.g. when
    # this cell is re-run); the configuration already in place is kept.
    logger.warning(f"Configuración de hardware no aplicada: {err}")
if USE_GPU:
    logger.info(f"🖥 GPU disponible: {gpus[0].name}")
else:
    logger.info(f"⚙ CPU cores: {CORES}, RAM libre: {AVAIL_RAM:.1f} GB")

# 6) Models and utilities.
def evaluate_metrics(y_true, y_pred):
    """Return ``(rmse, mae, mape, r2)`` comparing *y_true* with *y_pred*.

    Inputs may be any array-likes of matching shape; every statistic is
    computed over all elements.  MAPE is in percent and divides by
    ``y_true + 1e-5``, so it is only meaningful for non-negative targets and
    grows very large where ``y_true`` is close to 0.  R² is ``nan`` when
    ``y_true`` is constant (zero total variance) instead of dividing by zero.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    err = y_true - y_pred
    rmse = float(np.sqrt(np.mean(err ** 2)))
    mae = float(np.mean(np.abs(err)))
    mape = float(np.mean(np.abs(err / (y_true + 1e-5))) * 100)
    ss_res = np.sum(err ** 2)
    ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)
    r2 = float(1 - ss_res / ss_tot) if ss_tot > 0 else float("nan")
    return rmse, mae, mape, r2

class DataGenerator(tf.keras.utils.Sequence):
    """Serve ``(X, Y)`` as consecutive, unshuffled float32 mini-batches.

    The final batch is shorter when ``len(X)`` is not a multiple of
    *batch_size*.  Extra keyword arguments are forwarded to
    ``keras.utils.Sequence``.
    """

    def __init__(self, X, Y, batch_size=BATCH_SIZE, **kwargs):
        super().__init__(**kwargs)
        self.X = X.astype(np.float32)
        self.Y = Y.astype(np.float32)
        self.batch_size = batch_size

    def __len__(self):
        # Ceiling division: one extra (partial) batch for any remainder.
        n_full, remainder = divmod(len(self.X), self.batch_size)
        return n_full + (1 if remainder else 0)

    def __getitem__(self, idx):
        start = idx * self.batch_size
        stop = start + self.batch_size
        return self.X[start:stop], self.Y[start:stop]

def build_gru_ed(input_shape, horizon, n_cells, latent=128, dropout=DROPOUT):
    """Build and compile a GRU encoder–decoder.

    The encoder GRU compresses the ``input_shape`` sequence into a *latent*
    vector.  That vector is repeated *horizon* times and decoded by a second
    GRU, and a time-distributed Dense maps every step to *n_cells* values.
    Output shape: ``(batch, horizon, n_cells)``; compiled with Adam + MSE.
    """
    encoder = GRU(latent, dropout=dropout)
    decoder = GRU(latent, dropout=dropout, return_sequences=True)
    head = TimeDistributed(Dense(n_cells))

    seq_in = Input(shape=input_shape)
    hidden = decoder(RepeatVector(horizon)(encoder(seq_in)))
    model = Model(seq_in, head(hidden))
    model.compile(optimizer="adam", loss="mse")
    return model

def build_transformer_ed(input_shape, horizon, n_cells,
                         head_size=64, num_heads=4, ff_dim=256, dropout=0.1):
    """Build and compile a one-block Transformer encoder with a dense head.

    Parameters
    ----------
    input_shape : tuple
        ``(timesteps, features)`` of one input sequence.
    horizon, n_cells : int
        The output has shape ``(batch, horizon, n_cells)``.
    head_size : int
        ``key_dim`` of each attention head.
    num_heads, ff_dim : int
        Attention heads and hidden width of the feed-forward sub-layer.
    dropout : float
        Applied to the attention weights and to the feed-forward output.

    Compiled with Adam + MSE.
    """
    from tensorflow.keras.layers import Dropout, Reshape

    inp = Input(shape=input_shape)
    attn = MultiHeadAttention(num_heads=num_heads, key_dim=head_size, dropout=dropout)(inp, inp)
    x = Add()([inp, attn])
    x = LayerNormalization(epsilon=1e-6)(x)
    ff = Dense(ff_dim, activation="relu")(x)
    ff = Dense(input_shape[-1])(ff)
    ff = Dropout(dropout)(ff)
    x = Add()([x, ff])
    x = LayerNormalization(epsilon=1e-6)(x)
    # NOTE(review): Flatten -> Dense has timesteps * features * horizon * n_cells
    # weights, which can be very large for wide grids.
    x = Flatten()(x)
    x = Dense(horizon * n_cells)(x)
    # A Reshape layer (not tf.reshape) keeps the graph valid on Keras 3 and
    # serializable to .keras files.
    out = Reshape((horizon, n_cells))(x)
    m = Model(inp, out)
    m.compile(optimizer="adam", loss="mse")
    return m

def build_gru_ed_low(input_shape, horizon, n_cells,
                     latent=256, dropout=0.1, use_transformer=True):
    """Build the model for the low-frequency branch.

    With *use_transformer* true, the Transformer encoder (64-dim heads,
    4 heads, 512-wide feed-forward) is tried first.  If building it raises
    ``tf.errors.ResourceExhaustedError``, a warning is logged and the GRU
    encoder–decoder is returned instead.  Any other error propagates.
    """
    if not use_transformer:
        return build_gru_ed(input_shape, horizon, n_cells, latent=latent, dropout=dropout)
    try:
        return build_transformer_ed(
            input_shape, horizon, n_cells,
            head_size=64, num_heads=4, ff_dim=512, dropout=dropout,
        )
    except tf.errors.ResourceExhaustedError:
        logger.warning("OOM Transformer → usando GRU‐ED para low-branch")
    return build_gru_ed(input_shape, horizon, n_cells, latent=latent, dropout=dropout)

# NEW: U-Net + ConvLSTM meta-model for spatio-temporal prediction.
# NOTE(review): build_unet_convlstm_meta is defined again further down in this
# file (default channels=1); that later definition is the one bound when the
# meta-model is built.
def build_unet_convlstm_meta(input_shape, output_horizon, ny, nx, channels=3):
    """Build and compile a U-Net + ConvLSTM stacking meta-model.

    Parameters
    ----------
    input_shape : tuple
        ``(num_models, ny, nx, channels)``: one map per base model.
    output_horizon : int
        Number of output maps (time steps) to predict.
    ny, nx : int
        Spatial grid size.  They need not be multiples of 4: the grid is
        zero-padded for the two 2× poolings and cropped back at the end.
    channels : int
        Channels per base-model map.

    Returns
    -------
    A compiled Keras model mapping ``(batch, num_models, ny, nx, channels)``
    to ``(batch, output_horizon, ny, nx, 1)``.
    """
    from tensorflow.keras.layers import Cropping2D, Permute, Reshape, ZeroPadding2D

    n_models = input_shape[0]
    pad_y, pad_x = (-ny) % 4, (-nx) % 4
    padding = ((0, pad_y), (0, pad_x))

    def conv_block(t, filters, per_step=False):
        # Two (3×3 conv -> batch-norm) pairs; per_step applies the convs to
        # every time step of a 5-D sequence tensor.
        for _ in range(2):
            conv = Conv2D(filters, 3, activation="relu", padding="same")
            t = (TimeDistributed(conv) if per_step else conv)(t)
            t = BatchNormalization()(t)
        return t

    def repeat_in_time(t):
        # (batch, h, w, c) -> (batch, output_horizon, h, w, c)
        h, w, c = t.shape[1:]
        t = Reshape((1, h, w, c))(t)
        return t if output_horizon == 1 else Concatenate(axis=1)([t] * output_horizon)

    inputs = Input(shape=input_shape)  # (num_models, ny, nx, channels)
    # Stack the base-model maps as channels.  A raw reshape of
    # (num_models, ny, nx, channels) into (ny, nx, num_models * channels)
    # would mix pixels across maps, so move the model axis next to channels first.
    x = Permute((2, 3, 1, 4))(inputs)
    x = Reshape((ny, nx, n_models * channels))(x)
    if pad_y or pad_x:
        x = ZeroPadding2D(padding)(x)

    # Encoder.
    conv1 = conv_block(x, 64)
    pool1 = MaxPooling2D(2)(conv1)
    conv2 = conv_block(pool1, 128)
    pool2 = MaxPooling2D(2)(conv2)

    # Bottleneck: unroll the encoded state over output_horizon steps so the
    # ConvLSTM yields one feature map per predicted step.
    seq = ConvLSTM2D(64, kernel_size=3, padding="same", return_sequences=True)(repeat_in_time(pool2))

    # Decoder shared across time steps, with U-Net skip connections.
    up1 = TimeDistributed(UpSampling2D(2))(seq)
    conv3 = conv_block(Concatenate()([up1, repeat_in_time(conv2)]), 128, per_step=True)
    up2 = TimeDistributed(UpSampling2D(2))(conv3)
    conv4 = conv_block(Concatenate()([up2, repeat_in_time(conv1)]), 64, per_step=True)

    outputs = TimeDistributed(Conv2D(1, 1, activation="linear"))(conv4)
    if pad_y or pad_x:
        outputs = TimeDistributed(Cropping2D(padding))(outputs)
    # outputs: (batch, output_horizon, ny, nx, 1)

    model = Model(inputs, outputs)
    model.compile(optimizer="adam", loss="mse")
    return model


# 7) Load the datasets and the Boyacá shapefile.
logger.info("📂 Cargando datasets…")
ds_full = xr.open_dataset(FULL_NC)
ds_feat = xr.open_dataset(FEATURES_NC)
boyaca_gdf = gpd.read_file(BOYACA_SHP)
if boyaca_gdf.crs is None:
    # No CRS recorded in the file: declare WGS84 (EPSG:4326).
    boyaca_gdf.set_crs(epsg=4326, inplace=True)
else:
    boyaca_gdf = boyaca_gdf.to_crs(epsg=4326)

# Monthly time axis and effective reference month (REF_DATE clamped to the
# last available month).
times      = ds_full.time.values.astype("datetime64[M]")
user_ref   = np.datetime64(REF_DATE, "M")
last_avail = times[-1]
if user_ref > last_avail:
    ref = last_avail
    logger.info(f"REF_DATE fuera de rango; usando último mes: {ref}")
else:
    ref = user_ref

# Explicit month labels ("yyyy-mm"):
#   validation = ref, ref-1, …, ref-(OUTPUT_HORIZON-1)  (newest first)
#   forecast   = ref+1, …, ref+OUTPUT_HORIZON
val_dates = [str(ref - np.timedelta64(h, "M")) for h in range(OUTPUT_HORIZON)]
fc_dates  = [str(ref + np.timedelta64(h, "M")) for h in range(1, OUTPUT_HORIZON + 1)]

ref_matches = np.flatnonzero(times == ref)
if ref_matches.size == 0:
    raise ValueError(
        f"REF_DATE {ref} no aparece en el eje temporal del dataset "
        f"({times[0]} … {times[-1]})"
    )
idx_ref = int(ref_matches[0])
lat     = ds_full.latitude.values
lon     = ds_full.longitude.values
METHODS   = ["CEEMDAN","TVFEMD","FUSION"]
BRANCHES  = ["high","medium","low"]

# Results accumulated by the training loop, keyed by (model_name, "yyyy-mm").
all_metrics = []
preds_store = {}
true_store  = {}
histories   = {}

# Callbacks shared by every fit() call.
es_cb = callbacks.EarlyStopping("val_loss", patience=PATIENCE_ES, restore_best_weights=True)
lr_cb = callbacks.ReduceLROnPlateau("val_loss", factor=LR_FACTOR, patience=LR_PATIENCE, min_lr=1e-6)

# 8) Train the base models (high, medium, low) for every decomposition method.
#
# Temporal layout (positions on the monthly axis; idx_ref is the month `ref`):
#   * validation block: the OUTPUT_HORIZON months ending at `ref` (the months
#     in val_dates), predicted in one multi-step shot from the INPUT_WINDOW
#     months preceding the block (origin window k_val);
#   * training windows: every window whose targets all fall before the block,
#     so no validation month is ever a training target;
#   * Keras validation (early stopping / LR schedule): the windows after the
#     training set up to and including k_val;
#   * forecast: the OUTPUT_HORIZON months after `ref`, predicted from the
#     INPUT_WINDOW months ending at `ref`.
# NOTE: .keras files already cached in MODEL_DIR under the same name are
# reused as-is; delete them if they were fit with a different split/scaler.


def _sliding_windows(arr, starts, length):
    """Stack ``arr[s : s + length]`` for every start ``s`` along a new axis 0."""
    return np.stack([arr[s : s + length] for s in starts], axis=0)


def _branch_features(x_cells, branch, n_cells, time_values):
    """Return the per-month model input matrix ``(T, n_feats)`` for *branch*.

    Non-low branches use the cell values as-is.  The low branch appends the
    sine and cosine of the calendar month, each repeated across *n_cells*
    columns, so ``n_feats = 3 * n_cells``.
    """
    if branch != "low":
        return x_cells
    months = pd.to_datetime(time_values).month.values
    s = np.repeat(np.sin(2 * np.pi * months / 12)[:, None], n_cells, axis=1)
    c = np.repeat(np.cos(2 * np.pi * months / 12)[:, None], n_cells, axis=1)
    return np.concatenate([x_cells, s, c], axis=1)


for method in METHODS:
    for branch in BRANCHES:
        name = f"{method}_{branch}"
        if name not in ds_feat.data_vars:
            logger.warning(f"⚠ {name} no existe, salto.")
            continue
        logger.info(f"▶ Procesando {name}")
        try:
            Xarr = ds_feat[name].values
            Yarr = ds_full["total_precipitation"].values
            T, ny, nx = Xarr.shape
            n_cells = ny * nx

            feats = _branch_features(Xarr.reshape(T, n_cells), branch, n_cells, ds_full.time.values)
            yfull = Yarr.reshape(T, n_cells)
            n_feats = feats.shape[1]

            # Origin window of the validation block: its targets are exactly
            # months idx_ref - OUTPUT_HORIZON + 1 … idx_ref.
            k_val = idx_ref - INPUT_WINDOW - OUTPUT_HORIZON + 1
            # Training windows 0 … n_tr-1 have their last target at or before
            # month k_val + INPUT_WINDOW - 1, i.e. strictly before the block.
            n_tr = k_val - OUTPUT_HORIZON + 1
            if n_tr <= 0:
                logger.warning("❌ Ventanas insuficientes.")
                continue

            tr_starts = range(n_tr)
            va_starts = range(n_tr, k_val + 1)
            X_tr = _sliding_windows(feats, tr_starts, INPUT_WINDOW)
            y_tr = _sliding_windows(yfull, [k + INPUT_WINDOW for k in tr_starts], OUTPUT_HORIZON)
            X_va = _sliding_windows(feats, va_starts, INPUT_WINDOW)
            y_va = _sliding_windows(yfull, [k + INPUT_WINDOW for k in va_starts], OUTPUT_HORIZON)

            # Scalers are fit on training windows only (no look-ahead leakage).
            scX = StandardScaler().fit(X_tr.reshape(-1, n_feats))
            scY = StandardScaler().fit(y_tr.reshape(-1, n_cells))
            X_tr_s = scX.transform(X_tr.reshape(-1, n_feats)).reshape(X_tr.shape)
            y_tr_s = scY.transform(y_tr.reshape(-1, n_cells)).reshape(y_tr.shape)
            X_va_s = scX.transform(X_va.reshape(-1, n_feats)).reshape(X_va.shape)
            y_va_s = scY.transform(y_va.reshape(-1, n_cells)).reshape(y_va.shape)

            model_path = MODEL_DIR / f"{name}_w{OUTPUT_HORIZON}_ref{ref}.keras"
            if model_path.exists():
                model = tf.keras.models.load_model(str(model_path), compile=False)
                model.compile(optimizer="adam", loss="mse")
                logger.info(f"⏩ Cargado modelo: {model_path.name}")
            else:
                builder = build_gru_ed_low if branch == "low" else build_gru_ed
                model = builder((INPUT_WINDOW, n_feats), OUTPUT_HORIZON, n_cells)
                hist = model.fit(
                    DataGenerator(X_tr_s, y_tr_s),
                    validation_data=DataGenerator(X_va_s, y_va_s),
                    epochs=MAX_EPOCHS,
                    callbacks=[es_cb, lr_cb],
                    verbose=1,
                )
                model.save(str(model_path))
                histories[name] = hist.history

            # One multi-step prediction from origin window k_val (the last
            # Keras-validation window): step j is month k_val + INPUT_WINDOW + j,
            # reported as horizon j + 1 under that month's own label.
            pred_val = scY.inverse_transform(model.predict(X_va_s[-1:], verbose=0)[0])
            for j in range(OUTPUT_HORIZON):
                t_idx = k_val + INPUT_WINDOW + j
                date_val = str(times[t_idx])
                pm = pred_val[j].reshape(ny, nx)
                tm = yfull[t_idx].reshape(ny, nx)
                rmse, mae, mape, r2 = evaluate_metrics(tm.ravel(), pm.ravel())
                all_metrics.append(
                    {
                        "model": name,
                        "branch": branch,
                        "horizon": j + 1,
                        "type": "validation",
                        "date": date_val,
                        "RMSE": rmse,
                        "MAE": mae,
                        "MAPE": mape,
                        "R2": r2,
                    }
                )
                preds_store[(name, date_val)] = pm
                true_store[(name, date_val)] = tm

            # Forecast origin: the INPUT_WINDOW months ending at ref.
            x_fc = feats[idx_ref - INPUT_WINDOW + 1 : idx_ref + 1]
            X_fc = scX.transform(x_fc).reshape(1, INPUT_WINDOW, n_feats)
            FC = scY.inverse_transform(model.predict(X_fc, verbose=0)[0])
            for h in range(OUTPUT_HORIZON):
                date_fc = fc_dates[h]
                all_metrics.append(
                    {
                        "model": name,
                        "branch": branch,
                        "horizon": h + 1,
                        "type": "forecast",
                        "date": date_fc,
                        "RMSE": np.nan,
                        "MAE": np.nan,
                        "MAPE": np.nan,
                        "R2": np.nan,
                    }
                )
                preds_store[(name, date_fc)] = FC[h].reshape(ny, nx)

        except Exception:
            # Notebook-level boundary: log the full traceback and move on to
            # the next model instead of aborting the whole run.
            logger.exception(f"‼ Error en {name}, continuo…")
            continue

# 9) Save the metrics table to CSV and display it.
dfm = pd.DataFrame(all_metrics)
dfm.to_csv(RESULTS_CSV, index=False)
import ace_tools_open as tools

# Show the same table through ace_tools_open, labelled with horizon and ref month.
tools.display_dataframe_to_user(name=f"Metrics_w{OUTPUT_HORIZON}_ref{ref}", dataframe=dfm)

# 10) Training curves for the models fitted in this run (models loaded from
#     disk record no history).
for name, hist in histories.items():
    fig, ax = plt.subplots(figsize=(6, 4))
    for series_key, series_label in (("loss", "train"), ("val_loss", "val")):
        ax.plot(hist[series_key], label=series_label)
    ax.set_title(f"Loss curve: {name}")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("MSE")
    ax.legend()
    fig.tight_layout()
    plt.show()

# 10bis) True vs predicted scatter, one figure per branch and validation month
#        (all methods of the branch overlaid).  Each figure is labelled with
#        the month it shows; figures with nothing to plot are skipped.
for branch in BRANCHES:
    for date_val in val_dates:
        fig, ax = plt.subplots(figsize=(5, 5))
        n_series = 0
        for method in METHODS:
            key = (f"{method}_{branch}", date_val)
            if key in preds_store and key in true_store:
                ax.scatter(
                    true_store[key].ravel(), preds_store[key].ravel(),
                    alpha=0.3, s=2, label=method,
                )
                n_series += 1
        if n_series == 0:
            plt.close(fig)
            continue
        lims = [0, max(ax.get_xlim()[1], ax.get_ylim()[1])]
        ax.plot(lims, lims, "k--")  # identity line
        ax.set_xlabel("True")
        ax.set_ylabel("Predicted")
        ax.set_title(f"True vs Pred — {branch}, {date_val}")
        ax.legend()
        fig.tight_layout()
        plt.show()

# 11) Map grids (rows = branch, columns = method) for every validation month:
#     predicted precipitation and per-cell MAPE clipped to 0–200 %.
xmin, ymin, xmax, ymax = boyaca_gdf.total_bounds


def _plot_model_grid(fields, title, cmap, vmin, vmax, cbar_label, out_path):
    """Draw ``{model_name: 2-D map}`` on a BRANCHES × METHODS grid of maps.

    Panels for models missing from *fields* show only the Boyacá outline.
    The colorbar is added only when at least one panel was drawn.  The figure
    is saved to *out_path* at 150 dpi and then shown.
    """
    fig, axs = plt.subplots(
        len(BRANCHES), len(METHODS), figsize=(12, 12),
        subplot_kw={"projection": ccrs.PlateCarree()}, squeeze=False,
    )
    fig.suptitle(title, fontsize=16)
    mesh = None
    for i, b in enumerate(BRANCHES):
        for j, m in enumerate(METHODS):
            ax = axs[i, j]
            ax.set_extent([xmin, xmax, ymin, ymax], ccrs.PlateCarree())
            ax.add_geometries(boyaca_gdf.geometry, ccrs.PlateCarree(), edgecolor="black", facecolor="none", linewidth=1)
            field = fields.get(f"{m}_{b}")
            if field is not None:
                mesh = ax.pcolormesh(
                    lon, lat, field, vmin=vmin, vmax=vmax, transform=ccrs.PlateCarree(), cmap=cmap
                )
            ax.set_title(f"{m}_{b}")
    if mesh is not None:
        fig.colorbar(mesh, ax=axs, orientation="horizontal", fraction=0.05, pad=0.04, label=cbar_label)
    fig.savefig(out_path, dpi=150)
    plt.show()


for date_val in val_dates:
    preds_d = {
        f"{m}_{b}": preds_store[(f"{m}_{b}", date_val)]
        for m in METHODS
        for b in BRANCHES
        if (f"{m}_{b}", date_val) in preds_store
    }
    if not preds_d:
        logger.warning(f"No hay predicciones para {date_val}, salto plot.")
        continue
    # nan-aware limits so a single NaN cell does not invalidate the colour scale.
    all_vals = np.concatenate([p.ravel() for p in preds_d.values()])
    # Output file names keep their val_H1_/mape_H1_ prefixes for compatibility.
    _plot_model_grid(
        preds_d, f"Validación — {date_val}", "Blues",
        np.nanmin(all_vals), np.nanmax(all_vals), "Precipitación (mm)",
        IMAGE_DIR / f"val_H1_{date_val}.png",
    )

    mape_d = {}
    for model_name, pm in preds_d.items():
        tm = true_store.get((model_name, date_val))
        if tm is not None:
            mape_d[model_name] = np.clip(np.abs((tm - pm) / (tm + 1e-5)) * 100, 0, 200)
    if not mape_d:
        continue
    all_mape = np.concatenate([a.ravel() for a in mape_d.values()])
    _plot_model_grid(
        mape_d, f"MAPE — {date_val}", "Reds", 0, np.nanmax(all_mape), "MAPE (%)",
        IMAGE_DIR / f"mape_H1_{date_val}.png",
    )

# --- Meta-modelo U-Net + ConvLSTM ---

import tensorflow.keras.backend as K
from sklearn.model_selection import train_test_split

def prepare_meta_dataset(preds_store, true_store, val_dates, METHODS, BRANCHES, ny, nx):
    """Assemble the stacking dataset for the meta-model.

    One sample is produced per date in *val_dates* for which every
    ``{method}_{branch}`` base model has a prediction and at least one of
    them has a stored ground-truth map.

    Returns
    -------
    X : ndarray ``(samples, len(BRANCHES) * len(METHODS), ny, nx, 1)``
        Base-model maps in branch-major, then method order.
    y : ndarray ``(samples, 1, ny, nx, 1)``
        Observed map for the same date (a single time step).

    Returns ``(None, None)`` when no date is complete.  *ny* and *nx* are
    kept for interface compatibility; map shapes come from the stored arrays.
    """
    model_names = [f"{method}_{branch}" for branch in BRANCHES for method in METHODS]
    X_meta, y_meta = [], []
    for date_val in val_dates:
        missing = [n for n in model_names if (n, date_val) not in preds_store]
        if missing:
            logger.warning(f"⚠ Falta {(missing[0], date_val)} para meta dataset.")
            continue
        # All base models are trained against the same total_precipitation
        # field, so take the truth from whichever model stored it.
        truth = next(
            (true_store[(n, date_val)] for n in model_names if (n, date_val) in true_store),
            None,
        )
        if truth is None:
            logger.warning(f"⚠ Falta verdad observada para {date_val} en meta dataset.")
            continue
        X_meta.append(np.stack([preds_store[(n, date_val)] for n in model_names])[..., None])
        y_meta.append(truth[None, ..., None])
    if not X_meta:
        logger.warning("⚠ No hay muestras para meta-modelo.")
        return None, None
    return np.stack(X_meta), np.stack(y_meta)


def build_unet_convlstm_meta(input_shape, output_horizon, ny, nx, channels=1):
    """Build and compile a U-Net + ConvLSTM stacking meta-model.

    Parameters
    ----------
    input_shape : tuple
        ``(num_models, ny, nx, channels)``: one map per base model.
    output_horizon : int
        Number of output maps (time steps) to predict.
    ny, nx : int
        Spatial grid size.  They need not be multiples of 4: the grid is
        zero-padded for the two 2× poolings and cropped back at the end.
    channels : int
        Channels per base-model map.

    Returns
    -------
    A compiled Keras model mapping ``(batch, num_models, ny, nx, channels)``
    to ``(batch, output_horizon, ny, nx, 1)``.
    """
    from tensorflow.keras.layers import Cropping2D, Permute, Reshape, ZeroPadding2D

    n_models = input_shape[0]
    pad_y, pad_x = (-ny) % 4, (-nx) % 4
    padding = ((0, pad_y), (0, pad_x))

    def conv_block(t, filters, per_step=False):
        # Two (3×3 conv -> batch-norm) pairs; per_step applies the convs to
        # every time step of a 5-D sequence tensor.
        for _ in range(2):
            conv = Conv2D(filters, 3, activation="relu", padding="same")
            t = (TimeDistributed(conv) if per_step else conv)(t)
            t = BatchNormalization()(t)
        return t

    def repeat_in_time(t):
        # (batch, h, w, c) -> (batch, output_horizon, h, w, c)
        h, w, c = t.shape[1:]
        t = Reshape((1, h, w, c))(t)
        return t if output_horizon == 1 else Concatenate(axis=1)([t] * output_horizon)

    inputs = Input(shape=input_shape)  # (num_models, ny, nx, channels)
    # Stack the base-model maps as channels.  A raw reshape of
    # (num_models, ny, nx, channels) into (ny, nx, num_models * channels)
    # would mix pixels across maps, so move the model axis next to channels first.
    x = Permute((2, 3, 1, 4))(inputs)
    x = Reshape((ny, nx, n_models * channels))(x)
    if pad_y or pad_x:
        x = ZeroPadding2D(padding)(x)

    # Encoder.
    conv1 = conv_block(x, 64)
    pool1 = MaxPooling2D(2)(conv1)
    conv2 = conv_block(pool1, 128)
    pool2 = MaxPooling2D(2)(conv2)

    # Bottleneck: unroll the encoded state over output_horizon steps so the
    # ConvLSTM yields one feature map per predicted step.
    seq = ConvLSTM2D(64, kernel_size=3, padding="same", return_sequences=True)(repeat_in_time(pool2))

    # Decoder shared across time steps, with U-Net skip connections.
    up1 = TimeDistributed(UpSampling2D(2))(seq)
    conv3 = conv_block(Concatenate()([up1, repeat_in_time(conv2)]), 128, per_step=True)
    up2 = TimeDistributed(UpSampling2D(2))(conv3)
    conv4 = conv_block(Concatenate()([up2, repeat_in_time(conv1)]), 64, per_step=True)

    outputs = TimeDistributed(Conv2D(1, 1, activation="linear"))(conv4)
    if pad_y or pad_x:
        outputs = TimeDistributed(Cropping2D(padding))(outputs)
    # outputs: (batch, output_horizon, ny, nx, 1)

    model = Model(inputs, outputs)
    model.compile(optimizer="adam", loss="mse")
    return model

# Build the meta-model dataset: one sample per validation month for which
# every base model produced a prediction.
ny, nx = len(lat), len(lon)
X_meta, y_meta = prepare_meta_dataset(preds_store, true_store, val_dates, METHODS, BRANCHES, ny, nx)

# train_test_split needs at least two samples to put one on each side.
if X_meta is not None and y_meta is not None and len(X_meta) >= 2:
    # The meta-model emits as many steps as each target holds (axis 1 of
    # y_meta), which is not necessarily OUTPUT_HORIZON.
    meta_horizon = y_meta.shape[1]
    Xtr, Xte, ytr, yte = train_test_split(X_meta, y_meta, test_size=0.2, random_state=42)

    meta_model = build_unet_convlstm_meta(
        input_shape=X_meta.shape[1:], output_horizon=meta_horizon, ny=ny, nx=nx, channels=1
    )

    hist_meta = meta_model.fit(
        Xtr,
        ytr,
        validation_data=(Xte, yte),
        epochs=MAX_EPOCHS,
        batch_size=BATCH_SIZE,
        callbacks=[es_cb, lr_cb],
        verbose=1,
    )

    # Save the meta-model.
    meta_model_path = MODEL_DIR / f"meta_model_unet_convlstm_{ref}.keras"
    meta_model.save(str(meta_model_path))

    # Overall evaluation on the held-out samples.
    yhat_meta = meta_model.predict(Xte, verbose=0)
    yte_flat = yte.reshape(-1, ny * nx)
    yhat_flat = yhat_meta.reshape(-1, ny * nx)
    rmse_meta, mae_meta, mape_meta, r2_meta = evaluate_metrics(yte_flat, yhat_flat)
    logger.info(f"Meta-modelo U-Net + ConvLSTM RMSE: {rmse_meta:.3f}")

    # Per-step scatter plot and metrics table.
    meta_metrics = []
    for h in range(meta_horizon):
        y_true_h = yte[:, h, :, :, 0].ravel()
        y_pred_h = yhat_meta[:, h, :, :, 0].ravel()

        plt.figure(figsize=(5, 5))
        plt.scatter(y_true_h, y_pred_h, alpha=0.3, s=2)
        both = np.concatenate([y_true_h, y_pred_h])
        lims = [np.nanmin(both), np.nanmax(both)]
        plt.plot(lims, lims, "k--")  # identity line
        plt.xlabel("True")
        plt.ylabel("Predicted")
        plt.title(f"Meta-model True vs Predicted — Horizon {h+1}")
        plt.tight_layout()
        plt.show()

        rmse, mae, mape, r2 = evaluate_metrics(y_true_h, y_pred_h)
        meta_metrics.append(
            {"horizon": h + 1, "RMSE": rmse, "MAE": mae, "MAPE": mape, "R2": r2}
        )
    df_meta = pd.DataFrame(meta_metrics)
    display(df_meta)  # IPython display (notebook context)
    df_meta.to_csv(MODEL_DIR / f"meta_metrics_unet_convlstm_w{OUTPUT_HORIZON}_ref{ref}.csv", index=False)

else:
    logger.warning("⚠ No hay muestras para entrenar el meta-modelo.")

logger.info("🏁 Proceso completo.")


▶️ Base path: /Users/riperez/Conda/anaconda3/envs/precipitation_prediction/github.com/ml_precipitation_prediction


2025-05-19 23:11:28,831 INFO ⚙ CPU cores: 10, RAM libre: 2.7 GB
2025-05-19 23:11:28,833 INFO 📂 Cargando datasets…
2025-05-19 23:11:30,185 INFO REF_DATE fuera de rango; usando último mes: 2025-02
2025-05-19 23:11:30,187 INFO ▶ Procesando CEEMDAN_high


Epoch 1/300
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 93ms/step - loss: 0.9856 - val_loss: 0.8132 - learning_rate: 0.0010
Epoch 2/300
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 69ms/step - loss: 0.6379 - val_loss: 0.7613 - learning_rate: 0.0010
Epoch 3/300
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 95ms/step - loss: 0.5127 - val_loss: 0.6941 - learning_rate: 0.0010
Epoch 4/300
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 86ms/step - loss: 0.4712 - val_loss: 0.7173 - learning_rate: 0.0010
Epoch 5/300
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 96ms/step - loss: 0.4667 - val_loss: 0.7911 - learning_rate: 0.0010
Epoch 6/300
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 77ms/step - loss: 0.4334 - val_loss: 0.6176 - learning_rate: 0.0010
Epoch 7/300
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 102ms/step - loss: 0.4242 - val_loss: 0.7683 - learning_rate

2025-05-19 23:13:53,063 INFO ▶ Procesando CEEMDAN_medium


Epoch 1/300
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 106ms/step - loss: 1.0133 - val_loss: 1.4045 - learning_rate: 0.0010
Epoch 2/300
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 79ms/step - loss: 0.8065 - val_loss: 0.8336 - learning_rate: 0.0010
Epoch 3/300
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 77ms/step - loss: 0.7130 - val_loss: 1.4256 - learning_rate: 0.0010
Epoch 4/300
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 86ms/step - loss: 0.6434 - val_loss: 1.2344 - learning_rate: 0.0010
Epoch 5/300
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 74ms/step - loss: 0.5437 - val_loss: 1.2226 - learning_rate: 0.0010
Epoch 6/300
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 84ms/step - loss: 0.4841 - val_loss: 1.3690 - learning_rate: 0.0010
Epoch 7/300
[1m30/30[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 80ms/step - loss: 0.4579 - val_loss: 2.1632 - learning_rate

2025-05-19 23:15:18,911 INFO ▶ Procesando CEEMDAN_low


: 