<a href="https://colab.research.google.com/github/ninja-marduk/ml_precipitation_prediction/blob/feature%2Fhybrid-models/models/hybrid_models_ST-HybridWaveStack.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Notebook: Model Training Multi-Ambiente (Local & Colab)
# ======================================================

# 1) Entorno: Colab vs Local
import os, sys, logging
from pathlib import Path

IN_COLAB = 'google.colab' in sys.modules
if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')
    BASE_PATH = Path('/content/drive/MyDrive/ml_precipitation_prediction')
    # Instalar dependencias necesarias
    !pip install -r requirements.txt
    !pip install xarray netCDF4 optuna seaborn cartopy ace_tools_open
else:
    BASE_PATH = Path.cwd()
    for p in [BASE_PATH, *BASE_PATH.parents]:
        if (p/'.git').exists(): BASE_PATH = p; break

print(f"▶️ Usando ruta base: {BASE_PATH}")

# 2) Rutas y Logger
model_dir    = BASE_PATH/'models'/'output'/'trained_models'
model_dir.mkdir(parents=True, exist_ok=True)
features_nc  = BASE_PATH/'models'/'output'/'features_fusion_branches.nc'
full_nc      = BASE_PATH/'data'/'output'/'complete_dataset_with_features_with_clusters_elevation_with_windows.nc'
RESULTS_CSV  = model_dir/'training_metrics.csv'

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger()

# 3) Imports principales
import xarray as xr
import numpy   as np
import pandas  as pd
import tensorflow as tf
from sklearn.preprocessing    import StandardScaler, LabelEncoder
from sklearn.metrics          import mean_squared_error, mean_absolute_error, r2_score
from sklearn.linear_model     import HuberRegressor
import matplotlib.pyplot       as plt
import cartopy.crs             as ccrs
import cartopy.feature         as cfeature
import traceback

# 4) Parámetros globales
INPUT_WINDOW = 60
HORIZON      = 3
TIME_VARS    = ['year','month']  # se añadirán automáticamente

# 5) Funciones auxiliares
def evaluate_metrics(y_true, y_pred):
    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    mae  = mean_absolute_error(y_true, y_pred)
    mape = np.mean(np.abs((y_true - y_pred)/(y_true + 1e-5))*100)
    r2   = r2_score(y_true, y_pred) if np.var(y_true)>0 else np.nan
    return rmse, mae, mape, r2

class DataGenerator(tf.keras.utils.Sequence):
    """Genera batches para Keras sin cargar todo en memoria."""
    def __init__(self, X, Y, batch_size=16):
        self.X, self.Y = X.astype(np.float32), Y.astype(np.float32)
        self.batch_size = batch_size
    def __len__(self): return int(np.ceil(len(self.X)/self.batch_size))
    def __getitem__(self, i):
        sl = slice(i*self.batch_size, (i+1)*self.batch_size)
        return self.X[sl], self.Y[sl]

def build_lstm(input_shape, horizon, n_cells):
    m = tf.keras.models.Sequential([
        tf.keras.layers.Input(input_shape),
        tf.keras.layers.LSTM(64),
        tf.keras.layers.Dense(horizon * n_cells),
        tf.keras.layers.Reshape((horizon, n_cells))
    ])
    m.compile('adam','mse')
    return m

def build_cnn_lstm(input_shape, horizon, n_cells):
    m = tf.keras.models.Sequential([
        tf.keras.layers.Input((*input_shape,1)),
        tf.keras.layers.Conv2D(32,(3,3),padding='same',activation='relu'),
        tf.keras.layers.Reshape((input_shape[0], input_shape[1]*32)),
        tf.keras.layers.LSTM(64),
        tf.keras.layers.Dense(horizon * n_cells),
        tf.keras.layers.Reshape((horizon, n_cells))
    ])
    m.compile('adam','mse')
    return m

def build_elm():
    # Usamos HuberRegressor como proxy de regresión robusta
    return HuberRegressor()

def plot_map(data, title, date_label, cmap='Blues', vmin=None, vmax=None):
    fig = plt.figure(figsize=(6,5))
    ax  = plt.axes(projection=ccrs.PlateCarree())
    mesh = ax.pcolormesh(lon, lat, data, cmap=cmap, vmin=vmin, vmax=vmax,
                         transform=ccrs.PlateCarree(), shading='nearest')
    ax.coastlines(); ax.add_feature(cfeature.BORDERS, linestyle=':')
    ax.add_feature(cfeature.LAND, facecolor='lightgray')
    plt.colorbar(mesh, ax=ax, orientation='vertical', label=title)
    ax.set_title(f"{title}\n{date_label}")
    ax.gridlines(draw_labels=True)
    plt.show()

# 6) Carga de datos
logger.info("📂 Cargando datos...")
ds_full     = xr.open_dataset(full_nc)
ds_features = xr.open_dataset(features_nc)
lat = ds_full.latitude.values
lon = ds_full.longitude.values
n_cells = len(lat)*len(lon)

# 7) Preparar series de tiempo y variables de tiempo
time = ds_full.time.values
years = ds_full['year'].values
months= ds_full['month'].values

# 8) Modelos y branches
methods = ['CEEMDAN','TVFEMD','FUSION']
branches = ['high','medium','low']
TARGET_VAR = 'total_precipitation'

metrics = []

for method in methods:
    for branch in branches:
        varname = f"{method}_{branch}"
        if varname not in ds_features:
            logger.warning(f"⚠️ No existe variable {varname} en features.")
            continue
        logger.info(f"\n▶️ Entrenando: {method} - {branch}")

        # --- Extraer X,y ---
        Xarr = ds_features[varname].values  # (T,lat,lon)
        yarr = ds_full[TARGET_VAR].values   # (T,lat,lon)

        T,_,_= Xarr.shape
        X = Xarr.reshape(T, n_cells)
        y = yarr.reshape(T, n_cells)

        # añadir time vars como features
        tv = np.stack([years, months], axis=1)  # (T,2)
        # repetimos por celda
        tv_rep = np.repeat(tv[:,None,:], n_cells, axis=1)  # (T, cells,2)
        X_full = np.concatenate([X[...,None], tv_rep], axis=2)  # (T, cells, 1+2)

        # filtrar nan
        mask = ~np.isnan(y)
        Xf = X_full[mask]
        yf = y[mask]

        # --- Secuencias de entrenamiento ---
        seq_X, seq_y = [], []
        for i in range(len(Xf)-INPUT_WINDOW-HORIZON):
            seq_X.append(Xf[i:i+INPUT_WINDOW])
            seq_y.append(yf[i+INPUT_WINDOW:i+INPUT_WINDOW+HORIZON])
        seq_X = np.array(seq_X)  # (N,window,cells,feat)
        seq_y = np.array(seq_y)  # (N,horizon,cells)

        if len(seq_X)==0:
            logger.warning("❌ Secuencias vacías, omito.")
            continue

        # split 70/30
        split = int(0.7*len(seq_X))
        X_tr, X_va = seq_X[:split], seq_X[split:]
        y_tr, y_va = seq_y[:split], seq_y[split:]

        # escalado
        sx = StandardScaler()
        sy = StandardScaler()
        X_tr_flat = X_tr.reshape(-1, X_tr.shape[-1])
        X_va_flat = X_va.reshape(-1, X_va.shape[-1])
        X_tr_s = sx.fit_transform(X_tr_flat).reshape(X_tr.shape)
        X_va_s = sx.transform(X_va_flat).reshape(X_va.shape)

        y_tr_flat = y_tr.reshape(-1,1)
        y_va_flat = y_va.reshape(-1,1)
        y_tr_s = sy.fit_transform(y_tr_flat).reshape(y_tr.shape)
        y_va_s = sy.transform(y_va_flat).reshape(y_va.shape)

        # aplanamiento para Keras
        flat_dim = n_cells * X_tr.shape[-1]
        X_tr_m = X_tr_s.reshape(-1, INPUT_WINDOW, flat_dim)
        X_va_m = X_va_s.reshape(-1, INPUT_WINDOW, flat_dim)
        y_tr_m = y_tr_s.reshape(-1, HORIZON, n_cells)
        y_va_m = y_va_s.reshape(-1, HORIZON, n_cells)

        # 9) Construir modelo
        model_name = f"{method}_{branch}"
        try:
            if branch=='high':
                model = build_lstm((INPUT_WINDOW,flat_dim), HORIZON, n_cells)
            elif branch=='medium':
                model = build_cnn_lstm((INPUT_WINDOW,flat_dim), HORIZON, n_cells)
                # re-dimensionar entrada
                X_tr_m = X_tr_m[...,None]
                X_va_m = X_va_m[...,None]
            else:  # low
                model = build_elm()
        except Exception:
            logger.error("❌ Error al construir modelo", exc_info=True)
            continue

        # 10) Entrenamiento
        try:
            if branch in ['high','medium']:
                gen_tr = DataGenerator(X_tr_m, y_tr_m, batch_size=16)
                gen_va = DataGenerator(X_va_m, y_va_m, batch_size=16)
                cb = tf.keras.callbacks.EarlyStopping('val_loss',patience=5,restore_best_weights=True)
                hist = model.fit(gen_tr, validation_data=gen_va, epochs=50,
                                 verbose=1, callbacks=[cb])
            else:
                # ELM / regresión robusta: aplanar horizon*cell
                Xr = X_tr_m.reshape(-1, INPUT_WINDOW*flat_dim)
                yr = y_tr_m.reshape(-1, n_cells)
                model.fit(Xr, yr)
            # guardar
            save_path = model_dir/f"{model_name}.h5"
            if branch in ['high','medium']:
                model.save(save_path, include_optimizer=False)
            else:
                import joblib
                joblib.dump(model, save_path.with_suffix('.joblib'))
            logger.info(f"💾 Modelo guardado: {save_path.name}")
        except Exception:
            logger.error("❌ Error en entrenamiento", exc_info=True)
            continue

        # 11) Evaluación histórica (últimas 3)
        try:
            # tomar últimas 3 de validation
            Xe = X_va_m[-3:]
            ye = y_va_m[-3:]
            if branch!='low': ye_hat_s = model.predict(Xe,verbose=0)
            else:
                Xe_flat = Xe.reshape(-1, INPUT_WINDOW*flat_dim)
                ye_hat_flat = model.predict(Xe_flat)
                ye_hat_s = ye_hat_flat.reshape(-1, HORIZON, n_cells)
            ye_hat   = sy.inverse_transform(ye_hat_s.reshape(-1,1)).reshape(ye.shape)
            ye_true  = sy.inverse_transform(ye.reshape(-1,1)).reshape(ye.shape)

            # mapas y métricas
            for h in range(HORIZON):
                true_map = ye_true[:,h].reshape(-1,n_cells)[-1].reshape(len(lat),len(lon))
                pred_map = ye_hat[:,h].reshape(-1,n_cells)[-1].reshape(len(lat),len(lon))
                mape_map = np.clip(np.abs((true_map-pred_map)/(true_map+1e-5))*100,0,100)
                date_label = pd.to_datetime(time[-(HORIZON-h)]).strftime("%Y-%m")
                plot_map(pred_map, f"{model_name} Eval Pred h={h+1}", date_label)
                plot_map(mape_map, f"{model_name} Eval MAPE h={h+1}", date_label,'Reds',0,100)
                rmse, mae, mape, r2 = evaluate_metrics(true_map.ravel(), pred_map.ravel())
                metrics.append({
                    'model':model_name,'branch':branch,'type':'evaluation',
                    'horizon':h+1,'date':date_label,
                    'RMSE':rmse,'MAE':mae,'MAPE':mape,'R2':r2
                })
        except Exception:
            logger.error("❌ Error en evaluación histórica", exc_info=True)

        # 12) Forecast 3 meses siguientes
        try:
            # usar última ventana completa
            Xf = seq_X[-1:]  # shape (1,window,cells,feat)
            Xf_s = sx.transform(Xf.reshape(-1,Xf.shape[-1])).reshape(Xf.shape)
            Xf_m = Xf_s.reshape(1,INPUT_WINDOW,flat_dim) if branch!='medium' else Xf_s.reshape(1,INPUT_WINDOW,flat_dim,1)
            if branch!='low': yf_hat_s = model.predict(Xf_m,verbose=0)
            else:
                xf_flat = Xf_m.reshape(-1,INPUT_WINDOW*flat_dim)
                yh_flat = model.predict(xf_flat)
                yf_hat_s = yh_flat.reshape(1,HORIZON,n_cells)
            yf_hat = sy.inverse_transform(yf_hat_s.reshape(-1,1)).reshape(HORIZON,n_cells)
            # fechas forecast
            last_time = pd.to_datetime(time[-1])
            fc_dates = pd.date_range(last_time+pd.DateOffset(months=1), periods=HORIZON, freq='MS')
            for h in range(HORIZON):
                pm = yf_hat[h].reshape(len(lat),len(lon))
                date_label = fc_dates[h].strftime("%Y-%m")
                plot_map(pm, f"{model_name} Forecast h={h+1}", date_label)
                metrics.append({
                    'model':model_name,'branch':branch,'type':'forecast',
                    'horizon':h+1,'date':date_label,
                    'RMSE':np.nan,'MAE':np.nan,'MAPE':np.nan,'R2':np.nan
                })
        except Exception:
            logger.error("❌ Error en forecast", exc_info=True)

# 13) Guardar métricas
dfm = pd.DataFrame(metrics)
dfm.to_csv(RESULTS_CSV, index=False)
import ace_tools_open as tools
tools.display_dataframe_to_user(name="Training & Eval Metrics", dataframe=dfm)

logger.info("🏁 Entrenamiento, evaluación y forecast completos.")


Mounted at /content/drive
Collecting netCDF4
  Downloading netCDF4-1.7.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.8 kB)
Collecting PyEMD
  Downloading pyemd-1.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.6 kB)
Collecting emd
  Downloading emd-0.8.1-py3-none-any.whl.metadata (5.0 kB)
Collecting cartopy
  Downloading Cartopy-0.24.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.9 kB)
Collecting EMD-signal
  Downloading EMD_signal-1.6.4-py3-none-any.whl.metadata (8.9 kB)
Collecting cftime (from netCDF4)
  Downloading cftime-1.6.4.post1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.7 kB)
Collecting dcor (from emd)
  Downloading dcor-0.6-py3-none-any.whl.metadata (6.2 kB)
Collecting sparse (from emd)
  Downloading sparse-0.16.0-py2.py3-none-any.whl.metadata (5.3 kB)
Collecting pathos>=0.2.1 (from EMD-signal)
  Downloading pathos-0.3.4-py3-none-any.whl.metadata (11 kB)
Collecting ppft

  frames.append(imageio.imread(tmp.name))
  frames.append(imageio.imread(tmp.name))
  frames.append(imageio.imread(tmp.name))


Epoch 1/50


  self._warn_if_super_not_called()


[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m326s[0m 15s/step - loss: 52328.7305 - val_loss: 46610.0312
Epoch 2/50
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m324s[0m 16s/step - loss: 51222.5312 - val_loss: 46316.1836
Epoch 3/50
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m323s[0m 15s/step - loss: 50097.7812 - val_loss: 46025.3906
Epoch 4/50
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m326s[0m 16s/step - loss: 51148.0586 - val_loss: 45737.1406
Epoch 5/50
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m319s[0m 15s/step - loss: 50336.8281 - val_loss: 45450.9258
Epoch 6/50
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m325s[0m 16s/step - loss: 51018.7539 - val_loss: 45155.8906
Epoch 7/50
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m399s[0m 16s/step - loss: 47123.3281 - val_loss: 44860.5195
Epoch 8/50
[1m21/21[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m323s[0m 15s/step - loss: 48327.4688 - val_loss

  frames.append(imageio.imread(tmp.name))
  frames.append(imageio.imread(tmp.name))
  frames.append(imageio.imread(tmp.name))
STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
  self.n_iter_ = _check_optimize_result("lbfgs", opt_res, self.max_iter)
STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
  self.n_iter_ = _check_optimize_result("lbfgs", opt_res, self.max_iter)
STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
  self.n_iter_ = _check_optimize_result("lbfgs", opt_res, self.max_iter)
STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale 

KeyboardInterrupt: 