<a href="https://colab.research.google.com/github/ninja-marduk/ml_precipitation_prediction/blob/feature%2Fhybrid-models/models/hybrid_models_ST-HybridWaveStack.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# -*- coding: utf-8 -*-
# Script: Model Training Multi-Ambiente (Local & Colab)
#         con Evaluación, Forecast, Plots & GIFs
# ==============================================================================
import os
import sys
import warnings
import tempfile
import logging
from pathlib import Path

# 0) Suprimir warnings innecesarios
from sklearn.exceptions import ConvergenceWarning
warnings.filterwarnings("ignore", category=ConvergenceWarning)
from cartopy.io import DownloadWarning
warnings.filterwarnings("ignore", category=DownloadWarning)

# 1) Detectar entorno
IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    from google.colab import drive
    drive.mount("/content/drive", force_remount=True)
    BASE_PATH = Path("/content/drive/MyDrive/ml_precipitation_prediction")
    !pip install -q xarray netCDF4 optuna seaborn cartopy ace_tools_open imageio psutil tqdm
else:
    BASE_PATH = Path.cwd()
    for p in [BASE_PATH, *BASE_PATH.parents]:
        if (p/".git").exists():
            BASE_PATH = p
            break

print(f"▶️ Base path: {BASE_PATH}")

# 2) Rutas y logger
MODEL_DIR   = BASE_PATH/"models"/"output"/"trained_models"
MODEL_DIR.mkdir(parents=True, exist_ok=True)
FEATURES_NC = BASE_PATH/"models"/"output"/"features_fusion_branches.nc"
FULL_NC     = BASE_PATH/"data"/"output"/"complete_dataset_with_features_with_clusters_elevation_with_windows.nc"
BOYACA_SHP  = BASE_PATH/"data"/"input"/"shapes"/"MGN_Departamento.shp"
RESULTS_CSV = MODEL_DIR/"training_metrics.csv"
GIF_DIR     = MODEL_DIR/"gifs"
GIF_DIR.mkdir(exist_ok=True)

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)

# 3) Imports
import xarray           as xr
import numpy            as np
import pandas           as pd
import geopandas        as gpd
import imageio.v2       as imageio
import matplotlib.pyplot as plt
import cartopy.crs      as ccrs
import cartopy.feature  as cfeature
import tensorflow       as tf

from sklearn.preprocessing import StandardScaler
from sklearn.metrics       import mean_squared_error, mean_absolute_error, r2_score
from sklearn.linear_model  import Ridge
from joblib                import cpu_count, dump, load, Parallel, delayed
import psutil
from tqdm.auto             import tqdm

# 4) Recursos hardware
CORES     = cpu_count()
AVAIL_RAM = psutil.virtual_memory().available / (1024**3)
gpus      = tf.config.list_physical_devices("GPU")
if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)
    logger.info(f"🖥 GPU detectada: {gpus[0].name}")
else:
    try:
        tf.config.threading.set_inter_op_parallelism_threads(CORES)
        tf.config.threading.set_intra_op_parallelism_threads(CORES)
    except RuntimeError:
        pass
    logger.info(f"⚙ CPU cores: {CORES}, RAM libre: {AVAIL_RAM:.1f} GB")

# decidir n_jobs
if AVAIL_RAM < 2:
    N_JOBS = 1
elif AVAIL_RAM < 8:
    N_JOBS = max(1, CORES//2)
else:
    N_JOBS = max(1, CORES-1)
logger.info(f"🔧 Paralelismo: n_jobs={N_JOBS}")

# 5) Parámetros
INPUT_WINDOW = 60
HORIZON      = 3
METHODS      = ["CEEMDAN","TVFEMD","FUSION"]
BRANCHES     = ["high","medium","low"]
TARGET_VAR   = "total_precipitation"

# 6) Utilitarios
def evaluate_metrics(y_true, y_pred):
    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    mae  = mean_absolute_error(y_true, y_pred)
    mape = np.mean(np.abs((y_true - y_pred)/(y_true + 1e-5))*100)
    r2   = r2_score(y_true, y_pred) if y_true.var()>0 else np.nan
    return rmse, mae, mape, r2

class DataGenerator(tf.keras.utils.Sequence):
    def __init__(self, X, Y, batch_size=16):
        self.X, self.Y   = X.astype(np.float32), Y.astype(np.float32)
        self.batch_size = batch_size
    def __len__(self):
        return int(np.ceil(len(self.X)/self.batch_size))
    def __getitem__(self, idx):
        sl = slice(idx*self.batch_size, (idx+1)*self.batch_size)
        return self.X[sl], self.Y[sl]

def build_lstm(input_shape, horizon, n_cells):
    m = tf.keras.models.Sequential([
        tf.keras.layers.Input(input_shape),
        tf.keras.layers.LSTM(64),
        tf.keras.layers.Dense(horizon * n_cells),
        tf.keras.layers.Reshape((horizon, n_cells))
    ])
    m.compile("adam","mse")
    return m

def build_cnn_lstm(input_shape, horizon, n_cells):
    m = tf.keras.models.Sequential([
        tf.keras.layers.Input((*input_shape,1)),
        tf.keras.layers.Conv2D(32,(3,3),padding="same",activation="relu"),
        tf.keras.layers.Reshape((input_shape[0], input_shape[1]*32)),
        tf.keras.layers.LSTM(64),
        tf.keras.layers.Dense(horizon * n_cells),
        tf.keras.layers.Reshape((horizon, n_cells))
    ])
    m.compile("adam","mse")
    return m

# 7) Cargar datos y shapefile
logger.info("📂 Cargando datasets…")
ds_full = xr.open_dataset(FULL_NC)
ds_feat = xr.open_dataset(FEATURES_NC)

boyaca = gpd.read_file(BOYACA_SHP)
if boyaca.crs is None:
    boyaca.set_crs(epsg=4326, inplace=True)
else:
    boyaca = boyaca.to_crs(epsg=4326)

lat   = ds_full.latitude.values
lon   = ds_full.longitude.values
times = pd.to_datetime(ds_full.time.values)

all_metrics = []

# 8) Bucle principal
for method in METHODS:
    for branch in BRANCHES:
        var = f"{method}_{branch}"
        if var not in ds_feat.data_vars:
            logger.warning(f"⚠ {var} no existe — saltando.")
            continue
        logger.info(f"▶ Procesando {method} – {branch}")

        # extraer arrays
        Xarr = ds_feat[var].values      # (T,ny,nx)
        yarr = ds_full[TARGET_VAR].values
        T, ny, nx = Xarr.shape
        n_cells   = ny * nx

        # ventanas deslizantes
        Xfull = Xarr.reshape(T, n_cells, 1)
        yfull = yarr.reshape(T, n_cells)
        N = T - INPUT_WINDOW - HORIZON + 1
        if N <= 0:
            logger.warning("❌ No hay suficientes pasos de tiempo.")
            continue

        Xs = np.stack([Xfull[i:i+INPUT_WINDOW]
                       for i in range(N)], axis=0)
        ys = np.stack([yfull[i+INPUT_WINDOW:i+INPUT_WINDOW+HORIZON]
                       for i in range(N)], axis=0)

        # split 70/30
        split   = int(0.7 * N)
        X_tr, X_va = Xs[:split], Xs[split:]
        y_tr, y_va = ys[:split], ys[split:]

        # escalado
        sx     = StandardScaler().fit(X_tr.reshape(-1,1))
        sy     = StandardScaler().fit(y_tr.reshape(-1,1))
        X_tr_s = sx.transform(X_tr.reshape(-1,1)).reshape(X_tr.shape)
        X_va_s = sx.transform(X_va.reshape(-1,1)).reshape(X_va.shape)
        y_tr_s = sy.transform(y_tr.reshape(-1,1)).reshape(y_tr.shape)
        y_va_s = sy.transform(y_va.reshape(-1,1)).reshape(y_va.shape)

        # reshape para modelos
        X_tr_m = X_tr_s.reshape(-1, INPUT_WINDOW, n_cells)
        X_va_m = X_va_s.reshape(-1, INPUT_WINDOW, n_cells)

        name  = f"{method}_{branch}"
        ext   = ".keras" if branch!="low" else ".joblib"
        mpath = MODEL_DIR/(name+ext)

        # cargar o entrenar
        if mpath.exists():
            logger.info(f"⏩ Cargando modelo: {mpath.name}")
            if branch!="low":
                model = tf.keras.models.load_model(mpath)
            else:
                model = load(mpath)
        else:
            logger.info("🔨 Entrenando modelo…")
            if branch=="high":
                model = build_lstm((INPUT_WINDOW,n_cells), HORIZON, n_cells)
                gen_tr = DataGenerator(X_tr_m, y_tr_s, batch_size=16)
                gen_va = DataGenerator(X_va_m, y_va_s, batch_size=16)
                cb     = tf.keras.callbacks.EarlyStopping("val_loss", patience=5, restore_best_weights=True)
                model.fit(gen_tr, validation_data=gen_va, epochs=50, callbacks=[cb], verbose=1)
                model.save(mpath)

            elif branch=="medium":
                Xt2 = X_tr_m[...,None]
                Xv2 = X_va_m[...,None]
                model = build_cnn_lstm((INPUT_WINDOW,n_cells), HORIZON, n_cells)
                gen_tr = DataGenerator(Xt2, y_tr_s, batch_size=16)
                gen_va = DataGenerator(Xv2, y_va_s, batch_size=16)
                cb     = tf.keras.callbacks.EarlyStopping("val_loss", patience=5, restore_best_weights=True)
                model.fit(gen_tr, validation_data=gen_va, epochs=50, callbacks=[cb], verbose=1)
                model.save(mpath)

            else:
                # low-branch: Ridge multioutput en forma cerrada
                Xr   = X_tr_m.reshape(split, INPUT_WINDOW * n_cells)
                Yall = y_tr.reshape(split, HORIZON * n_cells)
                logger.info("   • Low-branch: entrenando Ridge multioutput…")
                ridge = Ridge(alpha=1.0, solver="auto")
                ridge.fit(Xr, Yall)
                dump(ridge, mpath)
                model = ridge

        # ——— Evaluación últimas 3 ventanas ———
        dates_eval = times[-HORIZON:]
        frames      = []
        X_eval = X_va_m[-3:]
        if branch=="medium":
            X_eval = X_eval[...,None]

        if branch!="low":
            preds_s = model.predict(X_eval, verbose=0).reshape(3,HORIZON,n_cells)
        else:
            # reshape correcto para low-branch
            Xre     = X_eval.reshape(3, INPUT_WINDOW * n_cells)
            # paralelo con joblib
            preds_block = Parallel(n_jobs=N_JOBS)(
                delayed(model.predict)(Xre[i:i+1]) for i in range(3)
            )
            preds_s = np.vstack(preds_block).reshape(3,HORIZON,n_cells)

        preds = sy.inverse_transform(preds_s.reshape(-1,1)).reshape(3,HORIZON,n_cells)
        true  = sy.inverse_transform(y_va_s[-3:].reshape(-1,1)).reshape(3,HORIZON,n_cells)
        vmax  = preds.max()

        for i in range(3):
            for h in range(HORIZON):
                pm, tm = preds[i,h].reshape(ny,nx), true[i,h].reshape(ny,nx)
                date   = dates_eval[i].strftime("%Y-%m")

                # Predicción
                fig, ax = plt.subplots(figsize=(6,5), subplot_kw={"projection":ccrs.PlateCarree()})
                mesh = ax.pcolormesh(lon, lat, pm, cmap="Blues", vmin=0, vmax=vmax,
                                     transform=ccrs.PlateCarree())
                ax.add_geometries(boyaca.geometry, ccrs.PlateCarree(),
                                  edgecolor="k", facecolor="none", linewidth=1)
                ax.coastlines(); ax.add_feature(cfeature.BORDERS, linestyle=":")
                ax.set_title(f"{name} Eval Pred h={h+1}\n{date}")
                cb = plt.colorbar(mesh, ax=ax, pad=0.02); cb.set_label("Precip (mm)")
                tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
                fig.savefig(tmp.name, dpi=150); plt.close(fig)
                frames.append(imageio.imread(tmp.name)); os.unlink(tmp.name)

                # MAPE
                mape_map = np.clip(np.abs((tm-pm)/(tm+1e-5))*100, 0, 100)
                fig, ax = plt.subplots(figsize=(6,5), subplot_kw={"projection":ccrs.PlateCarree()})
                mesh2 = ax.pcolormesh(lon, lat, mape_map, cmap="Reds", vmin=0, vmax=100,
                                      transform=ccrs.PlateCarree())
                ax.add_geometries(boyaca.geometry, ccrs.PlateCarree(),
                                  edgecolor="k", facecolor="none", linewidth=1)
                ax.coastlines(); ax.add_feature(cfeature.BORDERS, linestyle=":")
                ax.set_title(f"{name} Eval MAPE h={h+1}\n{date}")
                plt.colorbar(mesh2, ax=ax, pad=0.02); plt.close(fig)

                rmse, mae, mape_v, r2 = evaluate_metrics(tm.ravel(), pm.ravel())
                all_metrics.append({
                    "model":   name,
                    "branch":  branch,
                    "horizon": h+1,
                    "type":    "evaluation",
                    "date":    date,
                    "RMSE":    rmse,
                    "MAE":     mae,
                    "MAPE":    mape_v,
                    "R2":      r2
                })

        # guardar GIF de evaluación
        gif_path = GIF_DIR/f"{name}_eval.gif"
        imageio.mimsave(str(gif_path), frames, duration=2.0)
        logger.info(f"💾 GIF guardado: {gif_path.name}")

# 9) Guardar métricas
dfm = pd.DataFrame(all_metrics)
dfm.to_csv(RESULTS_CSV, index=False)
import ace_tools_open as tools
tools.display_dataframe_to_user(name="Training & Eval Metrics", dataframe=dfm)

logger.info("🏁 ¡Todo completado con trazabilidad y paralelización optimizada!")


Mounted at /content/drive
▶️ Base path: /content/drive/MyDrive/ml_precipitation_prediction
Epoch 1/50


  self._warn_if_super_not_called()


[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 66ms/step - loss: 0.8958 - val_loss: 0.6200
Epoch 2/50
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - loss: 0.5954 - val_loss: 0.5236
Epoch 3/50
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 16ms/step - loss: 0.4781 - val_loss: 0.4491
Epoch 4/50
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - loss: 0.3894 - val_loss: 0.4156
Epoch 5/50
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 16ms/step - loss: 0.3196 - val_loss: 0.3794
Epoch 6/50
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 16ms/step - loss: 0.2932 - val_loss: 0.3632
Epoch 7/50
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - loss: 0.2685 - val_loss: 0.3688
Epoch 8/50
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - loss: 0.2307 - val_loss: 0.3456
Epoch 9/50
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m

  self._warn_if_super_not_called()


[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 127ms/step - loss: 0.9350 - val_loss: 0.7156
Epoch 2/50
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 96ms/step - loss: 0.7502 - val_loss: 0.6900
Epoch 3/50
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 92ms/step - loss: 0.7256 - val_loss: 0.6904
Epoch 4/50
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 92ms/step - loss: 0.7000 - val_loss: 0.6931
Epoch 5/50
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 92ms/step - loss: 0.7323 - val_loss: 0.6974
Epoch 6/50
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 91ms/step - loss: 0.7148 - val_loss: 0.6942
Epoch 7/50
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 92ms/step - loss: 0.7093 - val_loss: 0.6920
Epoch 1/50


  self._warn_if_super_not_called()


[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 26ms/step - loss: 0.8979 - val_loss: 0.6066
Epoch 2/50
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step - loss: 0.5947 - val_loss: 0.4798
Epoch 3/50
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step - loss: 0.4103 - val_loss: 0.4215
Epoch 4/50
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 16ms/step - loss: 0.3674 - val_loss: 0.3950
Epoch 5/50
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 16ms/step - loss: 0.3026 - val_loss: 0.3714
Epoch 6/50
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 16ms/step - loss: 0.2662 - val_loss: 0.3490
Epoch 7/50
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - loss: 0.2683 - val_loss: 0.3446
Epoch 8/50
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step - loss: 0.2337 - val_loss: 0.3267
Epoch 9/50
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m



Epoch 1/50


  self._warn_if_super_not_called()


[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 112ms/step - loss: 0.9211 - val_loss: 0.6928
Epoch 2/50
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 90ms/step - loss: 0.7528 - val_loss: 0.6933
Epoch 3/50
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 95ms/step - loss: 0.6962 - val_loss: 0.6889
Epoch 4/50
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 94ms/step - loss: 0.6980 - val_loss: 0.6840
Epoch 5/50
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 90ms/step - loss: 0.6836 - val_loss: 0.6885
Epoch 6/50
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 94ms/step - loss: 0.7327 - val_loss: 0.6833
Epoch 7/50
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 94ms/step - loss: 0.6591 - val_loss: 0.6667
Epoch 8/50
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 94ms/step - loss: 0.6836 - val_loss: 0.6616
Epoch 9/50
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37



Training & Eval Metrics


0
Loading ITables v2.4.0 from the internet...  (need help?)
