<a href="https://colab.research.google.com/github/ninja-marduk/ml_precipitation_prediction/blob/feature%2Fhybrid-models/models/hybrid_models_ST-HybridWaveStack.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
# -*- coding: utf-8 -*-
# Script: Precip Training w/ Spatial AE + GRU Seq2Seq (robust evaluation)
# =============================================================================

import os, sys, tempfile, logging, warnings
from pathlib import Path

# 0) Suppress unhelpful warnings
from sklearn.exceptions import ConvergenceWarning
warnings.filterwarnings("ignore", category=ConvergenceWarning)
from cartopy.io import DownloadWarning
warnings.filterwarnings("ignore", category=DownloadWarning)

# 1) Detect Colab vs Local
IN_COLAB = "google.colab" in sys.modules
if IN_COLAB:
    from google.colab import drive
    drive.mount("/content/drive", force_remount=True)
    BASE_PATH = Path("/content/drive/MyDrive/ml_precipitation_prediction")
    !pip install -q xarray netCDF4 optuna seaborn cartopy ace_tools_open
else:
    BASE_PATH = Path.cwd()
    for p in [BASE_PATH, *BASE_PATH.parents]:
        if (p/".git").exists():
            BASE_PATH = p; break

print(f"▶️ Base path: {BASE_PATH}")

# 2) Paths and logger
# Output locations: trained models, the metrics CSV and the GIF folder all
# live under MODEL_DIR, which is created (with parents) if missing.
MODEL_DIR = BASE_PATH.joinpath("models", "output", "trained_models")
MODEL_DIR.mkdir(parents=True, exist_ok=True)
RESULTS_CSV = MODEL_DIR.joinpath("training_metrics.csv")
GIF_DIR = MODEL_DIR.joinpath("gifs")
GIF_DIR.mkdir(exist_ok=True)

# Input artefacts opened later in the script.
FEATURES_NC = BASE_PATH.joinpath("models", "output", "features_fusion_branches.nc")
FULL_NC = BASE_PATH.joinpath(
    "data",
    "output",
    "complete_dataset_with_features_with_clusters_elevation_with_windows.nc",
)
BOYACA_SHP = BASE_PATH.joinpath("data", "input", "shapes", "MGN_Departamento.shp")

# Root logging config plus the named logger used throughout the script.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("precip")

# 3) Imports
import xarray           as xr
import numpy            as np
import pandas           as pd
import geopandas        as gpd
import imageio.v2       as imageio
import matplotlib.pyplot as plt
import cartopy.crs      as ccrs
import cartopy.feature  as cfeature
import tensorflow       as tf

from sklearn.preprocessing import StandardScaler
from sklearn.metrics       import mean_squared_error, mean_absolute_error, r2_score

# 4) Hardware configuration
#    Picks the TensorFlow device string (DEVICE) and tunes runtime settings.
import psutil
# os.cpu_count() may return None; fall back to a single thread in that case
# so the threading setters below always receive an integer.
CORES     = os.cpu_count() or 1
AVAIL_RAM = psutil.virtual_memory().available / (1024**3)  # bytes -> GiB
gpus      = tf.config.list_physical_devices("GPU")
if gpus:
    try:
        # Memory growth has to be configured identically on every visible
        # GPU, so apply it to all of them rather than only the first one.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as exc:
        # Raised when the devices were already initialised (e.g. the cell is
        # re-run in a live kernel); the existing setting stays in effect.
        logger.warning(f"⚠ Could not enable GPU memory growth: {exc}")
    DEVICE = "/GPU:0"
    logger.info(f"🖥 GPU detected: {gpus[0].name}")
else:
    DEVICE = "/CPU:0"
    try:
        tf.config.threading.set_inter_op_parallelism_threads(CORES)
        tf.config.threading.set_intra_op_parallelism_threads(CORES)
    except RuntimeError as exc:
        # Best effort: keep going with TensorFlow's current thread settings,
        # but leave a trace instead of failing silently.
        logger.warning(f"⚠ Could not set TensorFlow thread counts: {exc}")
    logger.info(f"⚙ CPU cores: {CORES}, RAM free: {AVAIL_RAM:.1f} GB")

# 5) Constants
INPUT_WINDOW = 60  # time steps per input sequence (GRU input length and window slice size)
HORIZON      = 3   # future time steps predicted per window (decoder output length)
METHODS      = ["CEEMDAN","TVFEMD","FUSION"]  # prefixes of the feature variables looked up in ds_feat
BRANCHES     = ["high","medium","low"]  # suffixes; variables are named f"{method}_{branch}"
TARGET_VAR   = "total_precipitation"  # variable of ds_full used as the prediction target

# 6) Utility functions
def evaluate_metrics(y_true, y_pred):
    """Compute RMSE, MAE, MAPE and R² for a set of predictions.

    Args:
        y_true: Observed values (array-like).
        y_pred: Predicted values, same shape as ``y_true``.

    Returns:
        Tuple ``(rmse, mae, mape, r2)``. ``mape`` is expressed in percent.
        ``r2`` is ``nan`` when ``y_true`` has zero variance, where the score
        is not meaningful.
    """
    # Accept plain sequences as well as arrays for the element-wise MAPE math.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    mae  = mean_absolute_error(y_true, y_pred)
    # The epsilon is added to |y_true| so it can never cancel against a small
    # negative target and produce a zero denominator. For non-negative
    # targets this is numerically identical to dividing by (y_true + 1e-5).
    mape = np.mean(np.abs(y_true - y_pred) / (np.abs(y_true) + 1e-5) * 100)
    r2   = r2_score(y_true, y_pred) if np.var(y_true) > 0 else np.nan
    return rmse, mae, mape, r2

def plot_map(arr, title, date_label):
    """Draw one 2-D field on a PlateCarree map and return the figure.

    The field is plotted against the module-level ``lon``/``lat`` axes, with
    the ``boyaca`` geometries outlined on top. ``title`` is used both as the
    first line of the axes title and as the colorbar label; ``date_label``
    forms the second title line. The caller is responsible for closing the
    returned figure.
    """
    projection = ccrs.PlateCarree()
    fig, ax = plt.subplots(
        figsize=(6, 5),
        subplot_kw={"projection": projection},
    )

    # Gridded field.
    mesh = ax.pcolormesh(
        lon, lat, arr,
        cmap="viridis",
        shading="nearest",
        transform=projection,
    )

    # Overlays: region outline, coastlines and country borders.
    ax.add_geometries(
        boyaca.geometry, projection,
        edgecolor="k", facecolor="none",
    )
    ax.coastlines()
    ax.add_feature(cfeature.BORDERS, linestyle=":")

    ax.set_title(f"{title}\n{date_label}")
    colorbar = plt.colorbar(mesh, ax=ax, pad=0.02)
    colorbar.set_label(title)
    return fig

# 7) Load datasets
logger.info("📂 Loading CHIRPS & features")
ds_full = xr.open_dataset(FULL_NC)
ds_feat = xr.open_dataset(FEATURES_NC)

# Region outline: reproject to EPSG:4326 when a CRS is present, otherwise
# just declare the coordinates as EPSG:4326.
boyaca = gpd.read_file(BOYACA_SHP)
if boyaca.crs:
    boyaca = boyaca.to_crs(epsg=4326)
else:
    boyaca = boyaca.set_crs(epsg=4326)

# Coordinate axes and the derived grid sizes used by the rest of the script.
lat = ds_full["latitude"].values
lon = ds_full["longitude"].values
times = pd.to_datetime(ds_full["time"].values)
T = len(times)
ny = len(lat)
nx = len(lon)
n_cells = ny * nx

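# 8) Train spatial autoencoder once
#    (same as before; omitted here for brevity)
# ... your AE training here ...
# NOTE: this step must leave two names defined, `encoder` and `scaler_feat`.
#       Neither is created anywhere else in this cell, and process_branch()
#       below uses both.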

# 9) GRU seq2seq builder
def build_gru_precip(latent_dim):
    """Build and compile the GRU encoder-decoder used for precipitation.

    Input shape is ``(INPUT_WINDOW, latent_dim)``. A 128-unit GRU encodes the
    sequence into one vector, which is repeated ``HORIZON`` times and decoded
    by a second 128-unit GRU; a time-distributed dense layer then emits
    ``n_cells`` values per output step. The model is compiled with the Adam
    optimizer and MSE loss.
    """
    layers = tf.keras.layers

    # Encoder: sequence -> single state vector.
    seq_in = layers.Input(shape=(INPUT_WINDOW, latent_dim))
    encoded = layers.GRU(128)(seq_in)

    # Decoder: repeat the state once per forecast step and unroll.
    repeated = layers.RepeatVector(HORIZON)(encoded)
    decoded = layers.GRU(128, return_sequences=True)(repeated)
    per_step_head = layers.Dense(n_cells)
    seq_out = layers.TimeDistributed(per_step_head)(decoded)

    model = tf.keras.Model(inputs=seq_in, outputs=seq_out)
    model.compile(optimizer="adam", loss="mse")
    return model

# 10) Branch processor
def _figure_to_frame(fig):
    """Rasterise *fig* to an image array via a temporary PNG and close it."""
    tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    # Only the path is needed. Closing the handle right away avoids leaking
    # one open file per frame and lets savefig reopen the path itself.
    tmp.close()
    try:
        fig.savefig(tmp.name, dpi=120)
        return imageio.imread(tmp.name)
    finally:
        # Always release the figure and the temp file, even if rendering or
        # reading back fails.
        plt.close(fig)
        os.unlink(tmp.name)


def process_branch(method, branch):
    """Train (or load) and evaluate the GRU model for one feature branch.

    The feature variable ``f"{method}_{branch}"`` of ``ds_feat`` is scaled
    with ``scaler_feat``, encoded with ``encoder`` and cut into sliding
    windows of ``INPUT_WINDOW`` steps; the targets are the following
    ``HORIZON`` steps of standardised ``ds_full[TARGET_VAR]``. The first 70 %
    of the windows are used for training, the rest are held out. A model
    already saved as ``MODEL_DIR/<var>.keras`` is loaded instead of being
    retrained. Up to the last three held-out windows are predicted, mapped
    into a GIF under ``GIF_DIR`` and scored with ``evaluate_metrics``.

    Args:
        method: Decomposition method name (first part of the variable name).
        branch: Branch name (second part of the variable name).

    Returns:
        A list with one dict per evaluated (window, horizon step) holding the
        keys ``model``, ``branch``, ``horizon``, ``type``, ``date``, ``RMSE``,
        ``MAE``, ``MAPE`` and ``R2``. The list is empty when the variable is
        missing, there are too few windows, or prediction fails.
    """
    var = f"{method}_{branch}"
    if var not in ds_feat.data_vars:
        logger.warning(f"⚠ Missing {var}, skipping.")
        return []

    logger.info(f"▶ Processing {var}")
    # 10.1 Encode to (T,latent_dim)
    feats = ds_feat[var].values.reshape(T, -1)
    feats_scaled = scaler_feat.transform(feats)
    latd = encoder.predict(feats_scaled, verbose=0)  # (T,latent_dim)

    # 10.2 Precip target
    # NOTE(review): scaler_y is fitted on the whole series, including the
    # held-out period; confirm this is intended before comparing scores.
    prec = ds_full[TARGET_VAR].values.reshape(T, -1)
    scaler_y = StandardScaler().fit(prec)
    ys       = scaler_y.transform(prec)  # (T,n_cells)

    # 10.3 Build windows
    N = T - INPUT_WINDOW - HORIZON + 1
    if N <= 1:
        logger.warning(f"Not enough windows ({N}) for {var}, skipping.")
        return []
    Xw = np.stack([latd[i:i+INPUT_WINDOW] for i in range(N)], axis=0)
    yw = np.stack([ys[i+INPUT_WINDOW:i+INPUT_WINDOW+HORIZON] for i in range(N)], axis=0)
    split = int(0.7 * N)
    X_tr, X_va = Xw[:split], Xw[split:]
    y_tr, y_va = yw[:split], yw[split:]

    # 10.4 Train or load GRU model
    mpath = MODEL_DIR/f"{var}.keras"
    if mpath.exists():
        model = tf.keras.models.load_model(mpath)
        logger.info(f"⏩ Loaded model {mpath.name}")
    else:
        with tf.device(DEVICE):
            model = build_gru_precip(latd.shape[1])
            model.fit(X_tr, y_tr,
                      epochs=30, batch_size=16,
                      validation_split=0.1,
                      callbacks=[tf.keras.callbacks.EarlyStopping("val_loss",patience=5,restore_best_weights=True)],
                      verbose=1)
            model.save(mpath)
            logger.info(f"   ● Saved model {mpath.name}")

    # 10.5 Robust evaluation of up to last 3 windows
    k = min(3, len(X_va))
    X_eval = X_va[-k:]
    try:
        preds_s = model.predict(X_eval, verbose=0)
        if preds_s.ndim != 3 or preds_s.shape[1] != HORIZON or preds_s.shape[2] != n_cells:
            raise ValueError(f"Unexpected pred shape {preds_s.shape}")
    except Exception as e:
        # A cached model with a different shape, or any runtime failure,
        # skips this branch instead of aborting the whole run.
        logger.error(f"❌ Predict failed for {var}: {e}")
        return []

    preds = scaler_y.inverse_transform(preds_s.reshape(-1,n_cells)).reshape(k,HORIZON,n_cells)
    true  = scaler_y.inverse_transform(y_va[-k:].reshape(-1,n_cells)).reshape(k,HORIZON,n_cells)

    # The evaluated windows are the LAST k of all N windows, so the first one
    # starts at index N - k of the time axis (not at `split`).
    first_eval = N - k

    results, frames = [], []
    for i in range(k):
        for h in range(HORIZON):
            pm = preds[i,h].reshape(ny,nx)
            tm = true[i,h].reshape(ny,nx)
            # Time step actually being predicted: window start, plus the
            # input length, plus the horizon offset.
            date = times[first_eval + i + INPUT_WINDOW + h].strftime("%Y-%m")

            fig = plot_map(pm, f"{var} Pred h={h+1}", date)
            frames.append(_figure_to_frame(fig))

            rmse, mae, mape, r2 = evaluate_metrics(tm.ravel(), pm.ravel())
            results.append({
                "model": var,
                "branch": branch,
                "horizon": h+1,
                "type": "evaluation",
                "date": date,
                "RMSE": rmse,
                "MAE": mae,
                "MAPE": mape,
                "R2": r2
            })

    gif_path = GIF_DIR/f"{var}_eval.gif"
    imageio.mimsave(str(gif_path), frames, duration=1.5)
    logger.info(f"   ✓ {var} done, GIF→{gif_path.name}")
    return results

# 11) Run over all branches
#     process_branch returns an empty list when a branch is skipped or fails,
#     so those branches are tracked to keep the final log message honest.
all_metrics = []
skipped = []
for m in METHODS:
    for b in BRANCHES:
        branch_metrics = process_branch(m, b)
        if not branch_metrics:
            skipped.append(f"{m}_{b}")
        all_metrics.extend(branch_metrics)

# 12) Save & display
dfm = pd.DataFrame(all_metrics)
if dfm.empty:
    # Do not replace an existing metrics file with an empty one.
    logger.warning(f"⚠ No metrics produced; {RESULTS_CSV.name} was not written.")
else:
    dfm.to_csv(RESULTS_CSV, index=False)
import ace_tools_open as tools
tools.display_dataframe_to_user(name="Precip Metrics", dataframe=dfm)

if skipped:
    logger.warning(f"⚠ Finished with {len(skipped)} branch(es) without results: {', '.join(skipped)}")
else:
    logger.info("🏁 All branches processed successfully.")


2025-05-19 12:39:59,352 INFO ⚙ CPU cores: 10, RAM free: 3.0 GB
2025-05-19 12:39:59,353 INFO 📂 Loading CHIRPS & features
2025-05-19 12:39:59,449 INFO ▶ Processing CEEMDAN_high


▶️ Base path: /Users/riperez/Conda/anaconda3/envs/precipitation_prediction/github.com/ml_precipitation_prediction
Epoch 1/30
[1m19/19[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 39ms/step - loss: 0.8929 - val_loss: 0.9586
Epoch 2/30
[1m19/19[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 27ms/step - loss: 0.5824 - val_loss: 0.8297
Epoch 3/30
[1m19/19[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 27ms/step - loss: 0.4503 - val_loss: 0.7250
Epoch 4/30
[1m19/19[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 27ms/step - loss: 0.3626 - val_loss: 0.6569
Epoch 5/30
[1m19/19[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 29ms/step - loss: 0.2913 - val_loss: 0.6455
Epoch 6/30
[1m19/19[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 28ms/step - loss: 0.2703 - val_loss: 0.5996
Epoch 7/30
[1m19/19[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 29ms/step - loss: 0.2602 - val_loss: 0.6210
Epoch 8/30
[1m19/19[0m [32m━━━━━━━━━━━━━━━━━━━━[0

2025-05-19 12:40:12,996 INFO    ● Saved model CEEMDAN_high.keras
2025-05-19 12:40:14,524 INFO    ✓ CEEMDAN_high done, GIF→CEEMDAN_high_eval.gif
2025-05-19 12:40:14,524 INFO ▶ Processing CEEMDAN_medium


Epoch 1/30
[1m19/19[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 38ms/step - loss: 0.9530 - val_loss: 1.2026
Epoch 2/30
[1m19/19[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 26ms/step - loss: 0.7789 - val_loss: 1.0949
Epoch 3/30
[1m19/19[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 39ms/step - loss: 0.6017 - val_loss: 1.1116
Epoch 4/30
[1m19/19[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 29ms/step - loss: 0.4892 - val_loss: 1.0841
Epoch 5/30
[1m19/19[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 30ms/step - loss: 0.3537 - val_loss: 0.8827
Epoch 6/30
[1m19/19[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 29ms/step - loss: 0.3169 - val_loss: 0.9562
Epoch 7/30
[1m19/19[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 28ms/step - loss: 0.2951 - val_loss: 1.0076
Epoch 8/30
[1m19/19[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 28ms/step - loss: 0.2679 - val_loss: 0.9314
Epoch 9/30
[1m19/19[0m [32m━━━━━━━━━━━━━━━━━━

2025-05-19 12:40:21,875 INFO    ● Saved model CEEMDAN_medium.keras
2025-05-19 12:40:22,954 INFO    ✓ CEEMDAN_medium done, GIF→CEEMDAN_medium_eval.gif
2025-05-19 12:40:22,954 INFO ▶ Processing CEEMDAN_low
2025-05-19 12:40:23,103 INFO ⏩ Loaded model CEEMDAN_low.keras
2025-05-19 12:40:24,138 INFO    ✓ CEEMDAN_low done, GIF→CEEMDAN_low_eval.gif
2025-05-19 12:40:24,138 INFO ▶ Processing TVFEMD_high
2025-05-19 12:40:24,267 INFO ⏩ Loaded model TVFEMD_high.keras
2025-05-19 12:40:25,296 INFO    ✓ TVFEMD_high done, GIF→TVFEMD_high_eval.gif
2025-05-19 12:40:25,297 INFO ▶ Processing TVFEMD_medium
2025-05-19 12:40:25,424 INFO ⏩ Loaded model TVFEMD_medium.keras
2025-05-19 12:40:26,461 INFO    ✓ TVFEMD_medium done, GIF→TVFEMD_medium_eval.gif
2025-05-19 12:40:26,462 INFO ▶ Processing TVFEMD_low
2025-05-19 12:40:26,589 INFO ⏩ Loaded model TVFEMD_low.keras
2025-05-19 12:40:27,887 INFO    ✓ TVFEMD_low done, GIF→TVFEMD_low_eval.gif
2025-05-19 12:40:27,888 INFO ▶ Processing FUSION_high
2025-05-19 12:40:28,

Precip Metrics


0
Loading ITables v2.4.0 from the internet...  (need help?)


2025-05-19 12:40:31,433 INFO 🏁 All branches processed successfully.
