In [1]:
# -*- coding: utf-8 -*-
"""
TopoRain-Net: entrenamiento y evaluación de modelos base (LSTM, GRU, MLP, XGB)
y meta-modelo MLP multisalida sobre features_fusion_branches + lags + topografía.
Genera métricas, scatter, mapas y tablas (global, por elevación, por percentiles).
"""

import warnings, logging
from pathlib import Path

import numpy            as np
import pandas           as pd
import xarray           as xr
import geopandas        as gpd
import matplotlib.pyplot as plt
import cartopy.crs      as ccrs

from sklearn.preprocessing import StandardScaler, OneHotEncoder, LabelEncoder
from sklearn.metrics        import mean_squared_error, mean_absolute_error, r2_score
from xgboost                import XGBRegressor

import tensorflow as tf
from tensorflow.keras.models    import Sequential
from tensorflow.keras.layers    import Input, Dense, LSTM, GRU, Flatten, Reshape, Dropout
from tensorflow.keras.callbacks import EarlyStopping

import ace_tools_open as tools

# -----------------------------------------------------------------------------
# Configuration and paths
# -----------------------------------------------------------------------------
# Configure the root logger: INFO level, records prefixed with timestamp and level.
logging.basicConfig(level=logging.INFO,
                    format="%(asctime)s %(levelname)s %(message)s")
# Module-level logger used by the rest of the script.
logger = logging.getLogger(__name__)

# Configuración del entorno (compatible con Colab y local)
import os
import sys
from pathlib import Path
import shutil
import time
import psutil

# Detectar si estamos en Google Colab
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')
    # Si estamos en Colab, clonar el repositorio
    !git clone https://github.com/ninja-marduk/ml_precipitation_prediction.git
    %cd ml_precipitation_prediction
    # Instalar dependencias necesarias
    !pip install -r requirements.txt
    !pip install xarray netCDF4 optuna matplotlib seaborn lightgbm xgboost scikit-learn
    BASE_PATH = '/content/drive/MyDrive/ml_precipitation_prediction'
else:
    # Si estamos en local, usar la ruta actual
    if '/models' in os.getcwd():
        BASE_PATH = Path('..')
    else:
        BASE_PATH = Path('.')

BASE = Path(BASE_PATH)
print(f"Entorno configurado. Usando ruta base: {BASE}")

# Input datasets and output directory, all resolved relative to BASE.
_models_out = BASE.joinpath("models", "output")
FULL_NC = BASE.joinpath(
    "data", "output",
    "complete_dataset_with_features_with_clusters_elevation_with_windows.nc",
)
FUSION_NC = _models_out / "features_fusion_branches.nc"
TRAINED_DIR = _models_out / "trained_models"
# Create the output directory (and any missing parents) up front.
TRAINED_DIR.mkdir(parents=True, exist_ok=True)

# Number of past time steps fed to the models / number of steps predicted.
INPUT_WINDOW, OUTPUT_HORIZON = 60, 3

# -----------------------------------------------------------------------------
# 1) Data loading
# -----------------------------------------------------------------------------
logger.info("Cargando datasets…")
ds_full = xr.open_dataset(FULL_NC)
ds_fuse = xr.open_dataset(FUSION_NC)

# Target field: total precipitation, (T, ny, nx).
prec = ds_full["total_precipitation"].values

# Lagged predictors: every data variable whose name contains "_lag",
# in sorted name order, stacked on a trailing axis -> (T, ny, nx, n_lags).
lags = sorted(name for name in ds_full.data_vars if "_lag" in name)
da_lags = np.stack([ds_full[name].values for name in lags], axis=-1)

# Fused branch features, stacked on a trailing axis -> (T, ny, nx, 3).
branches = ["FUSION_high", "FUSION_medium", "FUSION_low"]
da_br = np.stack([ds_fuse[name].values for name in branches], axis=-1)

# Static per-cell descriptors, flattened to one value per grid cell.
elev = ds_full["elevation"].values.ravel()
slope = ds_full["slope"].values.ravel()

# Cluster labels may be stored as text or as numbers; the type of the first
# element decides which path is taken.
cluster_values = ds_full["cluster_elevation"].values.ravel()
if not isinstance(cluster_values[0], (str, np.str_)):
    # Already numeric: just cast to int.
    cluster = cluster_values.astype(int)
else:
    # Text labels: map them to integer codes.
    le = LabelEncoder()
    cluster = le.fit_transform(cluster_values)
    logger.info(f"Clusters codificados de texto a números: {dict(zip(le.classes_, range(len(le.classes_))))}")

# Grid and time dimensions.
lat = ds_full.latitude.values
lon = ds_full.longitude.values
ny = lat.shape[0]
nx = lon.shape[0]
cells = ny * nx
T = prec.shape[0]

logger.info(f"Dimensiones: T={T}, ny={ny}, nx={nx}, cells={cells}")
logger.info(f"Shapes: prec={prec.shape}, da_br={da_br.shape}, da_lags={da_lags.shape}")

# -----------------------------------------------------------------------------
# 2) Sliding windows (built in chunks)
# -----------------------------------------------------------------------------
import gc

logger.info("Armando ventanas deslizantes con procesamiento por chunks…")

# Number of windows materialised per batch; tune to the available memory.
CHUNK_SIZE = 50

# Total number of (input window, target horizon) pairs the series allows.
n_windows = T - INPUT_WINDOW - OUTPUT_HORIZON + 1
if n_windows <= 0:
    raise ValueError(
        f"Serie demasiado corta: T={T} no permite ventanas con "
        f"INPUT_WINDOW={INPUT_WINDOW} y OUTPUT_HORIZON={OUTPUT_HORIZON}"
    )

# Build the per-time-step feature matrix ONCE instead of re-concatenating it
# for every window. Per cell the layout is [branches..., lags...], flattened
# to (T, cells*F) -- the same layout the per-window code produced.
n_feat = da_br.shape[-1] + len(lags)
feat_all = np.concatenate(
    [da_br.reshape(T, cells, da_br.shape[-1]),
     da_lags.reshape(T, cells, len(lags))],
    axis=-1,
).reshape(T, cells * n_feat)                                  # (T, cells*F)
prec_flat = prec.reshape(T, cells)                            # (T, cells)

# A window is discarded if any of its inputs or targets is NaN. Flagging the
# time steps once lets invalid windows be skipped without building them.
feat_nan_t = np.isnan(feat_all).any(axis=1)                   # (T,)
targ_nan_t = np.isnan(prec_flat).any(axis=1)                  # (T,)

Xw, Yw = [], []
for chunk_start in range(0, n_windows, CHUNK_SIZE):
    chunk_end = min(chunk_start + CHUNK_SIZE, n_windows)
    logger.info(f"Procesando chunk de ventanas {chunk_start} a {chunk_end-1} de {n_windows}")

    # Start indices in this chunk whose inputs and targets are all finite.
    valid = [
        i for i in range(chunk_start, chunk_end)
        if not feat_nan_t[i:i + INPUT_WINDOW].any()
        and not targ_nan_t[i + INPUT_WINDOW:i + INPUT_WINDOW + OUTPUT_HORIZON].any()
    ]
    if valid:
        # Inputs: (n_valid, W, cells*F); targets: (n_valid, H, cells).
        Xw.append(np.stack([feat_all[i:i + INPUT_WINDOW] for i in valid]))
        Yw.append(np.stack(
            [prec_flat[i + INPUT_WINDOW:i + INPUT_WINDOW + OUTPUT_HORIZON] for i in valid]
        ))

    # Release per-chunk temporaries before the next batch.
    gc.collect()

if not Xw:
    # Fail here with a clear message instead of an obscure reshape error later.
    raise ValueError("No hay ventanas válidas: todas contienen NaN en features o targets")

# Concatenate all chunks.
X = np.concatenate(Xw, axis=0)   # (N, W, cells*F)
Y = np.concatenate(Yw, axis=0)   # (N, H, cells)
N = len(X)

# The per-chunk copies are no longer needed; drop them so they do not sit in
# memory next to X and Y.
Xw.clear()
Yw.clear()
del feat_all, prec_flat
gc.collect()

logger.info(f"Ventanas válidas totales: {N}")

# Optional: persist the windows to disk for later reuse.
# np.save(BASE/"models"/"output"/"ventanas_X.npy", X)
# np.save(BASE/"models"/"output"/"ventanas_Y.npy", Y)

# -----------------------------------------------------------------------------
# 3) Scaling + one-hot cluster encoding, 4) chronological train/val split
# -----------------------------------------------------------------------------
logger.info("Escalado de features y codificación de cluster…")

# The split index is computed first so the scaler can be fit on the training
# windows only. Fitting on all of X would leak validation statistics into
# the training features.
split = int(0.7 * N)
n_cols = X.shape[-1]

scX = StandardScaler()
scX.fit(X[:split].reshape(-1, n_cols))
Xf = scX.transform(X.reshape(-1, n_cols)).reshape(X.shape)

# Static topography + cluster descriptors (one row per cell).
try:
    # Newer scikit-learn releases use `sparse_output`.
    ohe = OneHotEncoder(sparse_output=False)
except TypeError:
    # Older scikit-learn releases only know `sparse`.
    ohe = OneHotEncoder(sparse=False)

c_ohe = ohe.fit_transform(cluster.reshape(-1, 1))                        # (cells, n_clusters)
topo = np.hstack([elev.reshape(-1, 1), slope.reshape(-1, 1), c_ohe])     # (cells, 2+n_clusters)
logger.info(f"Forma de matriz de topografía+cluster: {topo.shape}")

# First 70% of the windows for training, the rest for validation.
# NOTE(review): adjacent windows overlap in time, so windows on either side
# of the boundary share time steps; add a gap if strict separation is needed.
X_tr, X_va = Xf[:split], Xf[split:]
Y_tr, Y_va = Y[:split], Y[split:]
logger.info(f"Split train={len(X_tr)}, val={len(X_va)}")

# -----------------------------------------------------------------------------
# 5) Entrenamiento de modelos base
# -----------------------------------------------------------------------------
def build_ts_model(kind):
    """Build and compile a base model mapping an input window to all cells.

    The network takes inputs of shape ``(INPUT_WINDOW, X_tr.shape[-1])``,
    applies one hidden block chosen by ``kind`` and ends with a dense layer
    reshaped to ``(OUTPUT_HORIZON, cells)``. It is compiled with the Adam
    optimizer and MSE loss.

    Args:
        kind: ``"LSTM"`` (LSTM with 64 units), ``"GRU"`` (GRU with 64 units)
            or ``"MLP"`` (Flatten followed by a 128-unit ReLU dense layer).

    Returns:
        The compiled ``Sequential`` model.

    Raises:
        ValueError: If ``kind`` is not one of the supported names.
    """
    # Fail early with a clear message; without a hidden block the final
    # Reshape would fail later with a confusing size mismatch.
    if kind not in ("LSTM", "GRU", "MLP"):
        raise ValueError(f"Unknown model kind {kind!r}; expected 'LSTM', 'GRU' or 'MLP'")

    m = Sequential([Input(shape=(INPUT_WINDOW, X_tr.shape[-1]))])
    if kind == "LSTM":
        m.add(LSTM(64))
    elif kind == "GRU":
        m.add(GRU(64))
    else:  # "MLP"
        m.add(Flatten())
        m.add(Dense(128, activation="relu"))
    # One output per (horizon step, cell), reshaped to (H, cells).
    m.add(Dense(OUTPUT_HORIZON * cells))
    m.add(Reshape((OUTPUT_HORIZON, cells)))
    m.compile("adam", "mse")
    return m

# Names of all base models, and containers for their outputs:
#   pred_va[name] -> validation predictions, (Nv, H, cells)
#   pred_fc[name] -> prediction for the last available window, (H, cells)
BASES = ["LSTM", "GRU", "MLP", "XGB"]
pred_va, pred_fc = {}, {}

# Keras-based models (XGB is trained separately).
for kind in ("LSTM", "GRU", "MLP"):
    logger.info(f"Entrenando {kind}…")
    net = build_ts_model(kind)
    # A fresh early-stopping callback per model: stop after 10 epochs without
    # val_loss improvement and roll back to the best weights.
    early_stop = EarlyStopping(
        monitor="val_loss", patience=10, restore_best_weights=True
    )
    net.fit(
        X_tr,
        Y_tr,
        validation_data=(X_va, Y_va),
        epochs=100,
        batch_size=16,
        callbacks=[early_stop],
        verbose=0,
    )
    pred_va[kind] = net.predict(X_va)
    # Predict on the last scaled window only and drop the batch axis.
    last_window = Xf[-1:]
    pred_fc[kind] = net.predict(last_window)[0]

# XGBoost, one model per forecast horizon.
#
# Each window is one sample: its (W, cells*F) input is flattened to a single
# feature row, and the target is the (cells,) field for that horizon. The
# previous code reshaped X to N*W rows while raveling Y to N*cells values,
# so the row and label counts could never match.
# NOTE(review): multi-output targets (2-D y) need a reasonably recent XGBoost;
# with this many features/targets training can be slow -- consider
# tree_method="hist" with multi_strategy="multi_output_tree" if available.
X_tr_flat = X_tr.reshape(len(X_tr), -1)        # (Ntr, W*cells*F)
X_va_flat = X_va.reshape(len(X_va), -1)        # (Nv,  W*cells*F)
X_fc_flat = Xf[-1:].reshape(1, -1)             # (1,   W*cells*F)

xgb_va, xgb_fc = [], []
for h in range(OUTPUT_HORIZON):
    logger.info(f"Entrenando XGB H={h+1}…")
    xgb = XGBRegressor(n_estimators=100, max_depth=5, verbosity=0)
    # Fit on all cells of horizon h at once: y is (Ntr, cells).
    xgb.fit(X_tr_flat, Y_tr[:, h, :])
    xgb_va.append(xgb.predict(X_va_flat).reshape(-1, cells))   # (Nv, cells)
    xgb_fc.append(xgb.predict(X_fc_flat).reshape(cells))       # (cells,)

# Stack per-horizon outputs to match the Keras models' shapes.
pred_va["XGB"] = np.stack(xgb_va, axis=1)      # (Nv, H, cells)
pred_fc["XGB"] = np.stack(xgb_fc, axis=0)      # (H, cells)

# -----------------------------------------------------------------------------
# 6) Base-model evaluation
# -----------------------------------------------------------------------------
# One row per (model, horizon), with metrics pooled over all validation
# windows and all cells.
rows = []
for b in BASES:
    pv = pred_va[b]
    for h in range(OUTPUT_HORIZON):
        yt = Y_va[:, h, :].ravel()
        yp = pv[:, h, :].ravel()
        rmse = np.sqrt(mean_squared_error(yt, yp))
        mae = mean_absolute_error(yt, yp)
        # The small constant keeps the division finite where the truth is 0.
        mape = np.mean(np.abs((yt - yp) / (yt + 1e-5))) * 100
        r2 = r2_score(yt, yp)
        rows.append(
            dict(model=b, horizon=h + 1, RMSE=rmse, MAE=mae, MAPE=mape, R2=r2)
        )
df_base = pd.DataFrame(rows)
tools.display_dataframe_to_user("Base_models_metrics", df_base)

# Scatter plots and maps for the base models.
grid_lon, grid_lat = np.meshgrid(lon, lat)
for b in BASES:
    pv = pred_va[b]
    for h in range(OUTPUT_HORIZON):
        # Scatter of true vs. predicted values over all windows and cells.
        yt, yp = Y_va[:, h, :].ravel(), pv[:, h, :].ravel()
        plt.figure(figsize=(4, 4))
        plt.scatter(yt, yp, s=2, alpha=0.3)
        # The 1:1 reference line must span the joint range of both axes
        # (previously it ran from yt.min() to yp.max()).
        mn = min(yt.min(), yp.min())
        mx = max(yt.max(), yp.max())
        plt.plot([mn, mx], [mn, mx], 'k--')
        plt.title(f"{b} True vs Pred H={h+1}")
        plt.show()

        # Maps for the FIRST validation window only: prediction and its
        # absolute percentage error per cell.
        pm = pv[0, h].reshape(ny, nx)
        tm = Y_va[0, h].reshape(ny, nx)
        err = np.abs((tm - pm) / (tm + 1e-5)) * 100
        fig, ax = plt.subplots(1, 2, figsize=(10, 4),
                               subplot_kw={"projection": ccrs.PlateCarree()})
        ax[0].pcolormesh(grid_lon, grid_lat, pm,
                         transform=ccrs.PlateCarree(), cmap="Blues")
        ax[0].set_title(f"{b} Pred H={h+1}")
        # The colour scale is clipped to 0-100 %.
        ax[1].pcolormesh(grid_lon, grid_lat, err, transform=ccrs.PlateCarree(),
                         cmap="Reds", vmin=0, vmax=100)
        ax[1].set_title(f"{b} MAPE% H={h+1}")
        plt.show()

# -----------------------------------------------------------------------------
# 7) Meta-model inputs
# -----------------------------------------------------------------------------
logger.info("Armando características para meta-modelo…")
n_va = len(X_va)
# Static descriptors flattened once (previously recomputed per sample).
# NOTE(review): this vector is identical for every sample, so it adds no
# sample-varying signal to the meta-model input.
topo_flat = topo.flatten()                                   # (cells*topo_dim,)

# Stack all base predictions along the horizon axis -> (Nv, B*H, cells), then
# lay each sample out cell-major: for every cell its B*H predictions. This is
# the same ordering as vstack(...).T.flatten() applied sample by sample.
stack_va = np.concatenate([pred_va[b] for b in BASES], axis=1)
vec_va = stack_va.transpose(0, 2, 1).reshape(
    n_va, stack_va.shape[1] * stack_va.shape[2]
)                                                            # (Nv, cells*B*H)
X_meta_va = np.hstack([vec_va, np.tile(topo_flat, (n_va, 1))])   # (Nv, ...)

# Meta-model input for the forecast, built with the same layout.
stack_fc = np.vstack([pred_fc[b] for b in BASES])            # (B*H, cells)
Xm_fc = np.hstack([stack_fc.T.flatten(), topo_flat])[None, :]    # (1, ...)

# Meta-model target: true fields flattened per sample.
Y_meta_va = Y_va.reshape(len(X_va), -1)                      # (Nv, H*cells)

# -----------------------------------------------------------------------------
# 8) Meta-model (MLP) training
# -----------------------------------------------------------------------------
logger.info("Entrenando meta-modelo…")
meta_layers = [
    Input(shape=(X_meta_va.shape[-1],)),
    Dense(256, activation="relu"),
    Dropout(0.4),
    Dense(128, activation="relu"),
    # Linear output: one unit per (horizon step, cell).
    Dense(Y_meta_va.shape[-1]),
]
meta = Sequential(meta_layers)
meta.compile(optimizer="adam", loss="mse")

# Early stopping watches the loss on the validation_split hold-out.
meta_stop = EarlyStopping(
    monitor="val_loss", patience=20, restore_best_weights=True
)
meta.fit(
    X_meta_va,
    Y_meta_va,
    validation_split=0.2,
    epochs=200,
    batch_size=32,
    callbacks=[meta_stop],
    verbose=0,
)

# Meta-model predictions.
# NOTE(review): X_meta_va is the same data the meta-model was just fit on, so
# metrics computed from P_meta_va are largely in-sample.
P_meta_va = meta.predict(X_meta_va)      # (Nv, H*cells)
P_meta_fc = meta.predict(Xm_fc)          # (1, H*cells)

# NOTE: from here on Y_meta_va holds the meta-model PREDICTIONS (it is rebound
# from the flattened targets); the evaluation code reads it as predictions.
Y_meta_va = P_meta_va.reshape(-1, OUTPUT_HORIZON, cells)
Y_meta_fc = P_meta_fc.reshape(OUTPUT_HORIZON, cells)

# -----------------------------------------------------------------------------
# 9) Meta-model evaluation
# -----------------------------------------------------------------------------
# One row per horizon, pooled over all validation windows and cells.
rows = []
for h in range(OUTPUT_HORIZON):
    yt = Y_va[:, h, :].ravel()
    yp = Y_meta_va[:, h, :].ravel()
    rmse = np.sqrt(mean_squared_error(yt, yp))
    mae = mean_absolute_error(yt, yp)
    # The small constant keeps the division finite where the truth is 0.
    mape = np.mean(np.abs((yt - yp) / (yt + 1e-5))) * 100
    r2 = r2_score(yt, yp)
    rows.append(dict(horizon=h + 1, RMSE=rmse, MAE=mae, MAPE=mape, R2=r2))
df_meta = pd.DataFrame(rows)
tools.display_dataframe_to_user("Meta_model_metrics", df_meta)

# Scatter plots and maps for the meta-model.
for h in range(OUTPUT_HORIZON):
    # Scatter of true vs. predicted values over all windows and cells.
    yt, yp = Y_va[:, h, :].ravel(), Y_meta_va[:, h, :].ravel()
    plt.figure(figsize=(4, 4))
    plt.scatter(yt, yp, s=2, alpha=0.3)
    # The 1:1 reference line must span the joint range of both axes
    # (previously it ran from yt.min() to yp.max()).
    mn = min(yt.min(), yp.min())
    mx = max(yt.max(), yp.max())
    plt.plot([mn, mx], [mn, mx], 'k--')
    plt.title(f"Meta True vs Pred H={h+1}")
    plt.show()

    # Maps for the FIRST validation window only: prediction and its absolute
    # percentage error per cell.
    pm = Y_meta_va[0, h].reshape(ny, nx)
    tm = Y_va[0, h].reshape(ny, nx)
    err = np.abs((tm - pm) / (tm + 1e-5)) * 100
    fig, ax = plt.subplots(1, 2, figsize=(10, 4),
                           subplot_kw={"projection": ccrs.PlateCarree()})
    ax[0].pcolormesh(grid_lon, grid_lat, pm,
                     transform=ccrs.PlateCarree(), cmap="Blues")
    ax[0].set_title(f"Meta Pred H={h+1}")
    # The colour scale is clipped to 0-100 %.
    ax[1].pcolormesh(grid_lon, grid_lat, err, transform=ccrs.PlateCarree(),
                     cmap="Reds", vmin=0, vmax=100)
    ax[1].set_title(f"Meta MAPE% H={h+1}")
    plt.show()

# -----------------------------------------------------------------------------
# 10) Metrics broken down by elevation band and by target percentile
# -----------------------------------------------------------------------------
logger.info("Calculando métricas por elevación y percentiles…")


def _subset_metrics(yt, yp):
    """Return RMSE/MAE/MAPE/R2 for one subset, NaN where a metric is undefined.

    An empty subset yields NaN for every metric; R2 is NaN for fewer than
    two points. This keeps empty bands/bins from raising inside sklearn.
    """
    if yt.size == 0:
        return {"RMSE": np.nan, "MAE": np.nan, "MAPE": np.nan, "R2": np.nan}
    return {
        "RMSE": np.sqrt(mean_squared_error(yt, yp)),
        "MAE": mean_absolute_error(yt, yp),
        "MAPE": np.mean(np.abs((yt - yp) / (yt + 1e-5))) * 100,
        "R2": r2_score(yt, yp) if yt.size > 1 else np.nan,
    }


# Per-cell elevation masks, each of length `cells`.
mask_low = elev < 200
mask_mid = (elev >= 200) & (elev <= 1000)
mask_high = elev > 1000
regions = list(zip(["<200m", "200-1000m", ">1000m"],
                   [mask_low, mask_mid, mask_high]))

edges = [0, 25, 50, 75, 100]
n_bins = len(edges) - 1

elev_rows = []
pct_rows = []
for h in range(OUTPUT_HORIZON):
    yt_h = Y_va[:, h, :]          # (Nv, cells)
    yp_h = Y_meta_va[:, h, :]     # (Nv, cells)

    # Elevation bands: the masks are per CELL, so they must index the cell
    # axis. Applying them to the raveled (Nv*cells,) vector fails with a
    # boolean-index length mismatch.
    for name, m in regions:
        elev_rows.append({
            "horizon": h + 1, "region": name,
            **_subset_metrics(yt_h[:, m].ravel(), yp_h[:, m].ravel()),
        })

    # Percentile bins of the true values, pooled over windows and cells.
    yt, yp = yt_h.ravel(), yp_h.ravel()
    pcts = np.percentile(yt, edges)
    for i in range(n_bins):
        lo, hi = pcts[i], pcts[i + 1]
        # Bins are [lo, hi), except the last, which is closed on the right so
        # the maximum value is not dropped.
        upper = (yt <= hi) if i == n_bins - 1 else (yt < hi)
        idx = (yt >= lo) & upper
        pct_rows.append({
            "horizon": h + 1,
            "pct_range": f"{edges[i]}-{edges[i+1]}%",
            **_subset_metrics(yt[idx], yp[idx]),
        })

df_elev = pd.DataFrame(elev_rows)
df_pct = pd.DataFrame(pct_rows)
tools.display_dataframe_to_user("Meta_by_elevation", df_elev)
tools.display_dataframe_to_user("Meta_by_percentile", df_pct)

logger.info("🎉 Proceso completado con éxito.")

2025-05-25 22:10:06,847 INFO Cargando datasets…


Entorno configurado. Usando ruta base: ..


2025-05-25 22:10:08,310 INFO Clusters codificados de texto a números: {'high': 0, 'low': 1, 'medium': 2}
2025-05-25 22:10:08,310 INFO Dimensiones: T=530, ny=61, nx=65, cells=3965
2025-05-25 22:10:08,311 INFO Shapes: prec=(530, 61, 65), da_br=(530, 61, 65, 3), da_lags=(530, 61, 65, 7)
2025-05-25 22:10:08,311 INFO Armando ventanas deslizantes con procesamiento por chunks…
2025-05-25 22:10:08,311 INFO Procesando chunk de ventanas 0 a 49 de 468
2025-05-25 22:10:08,310 INFO Dimensiones: T=530, ny=61, nx=65, cells=3965
2025-05-25 22:10:08,311 INFO Shapes: prec=(530, 61, 65), da_br=(530, 61, 65, 3), da_lags=(530, 61, 65, 7)
2025-05-25 22:10:08,311 INFO Armando ventanas deslizantes con procesamiento por chunks…
2025-05-25 22:10:08,311 INFO Procesando chunk de ventanas 0 a 49 de 468
2025-05-25 22:10:09,424 INFO Procesando chunk de ventanas 50 a 99 de 468
2025-05-25 22:10:09,424 INFO Procesando chunk de ventanas 50 a 99 de 468
2025-05-25 22:10:10,485 INFO Procesando chunk de ventanas 100 a 149 d

[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 167ms/step
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 167ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 51ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 51ms/step


2025-05-25 22:29:08,792 INFO Entrenando GRU…


[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 139ms/step
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 139ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 56ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 56ms/step


2025-05-25 22:41:34,828 INFO Entrenando MLP…


: 