In [18]:
from __future__ import annotations

import sys
from pathlib import Path
from typing import List, Dict, Any
import numpy as np
import pandas as pd
import xarray as xr
import tensorflow as tf
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import matplotlib.pyplot as plt
import geopandas as gpd
import imageio.v2 as imageio

from tensorflow.keras.layers import (
    Input, ConvLSTM2D, GRU, Flatten, RepeatVector, Reshape, TimeDistributed,
    Dense, MultiHeadAttention, Add, LayerNormalization
)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

## ╭─────────────────────────── Rutas ──────────────────────────╮
# ▶️ Path configuration
IN_COLAB = 'google.colab' in sys.modules
if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)
    BASE_PATH = Path('/content/drive/MyDrive/ml_precipitation_prediction')
    # Instalar dependencias necesarias
    !pip install -r requirements.txt
    !pip install xarray netCDF4 optuna matplotlib seaborn lightgbm xgboost scikit-learn ace_tools_open cartopy geopandas
else:
    BASE_PATH = Path.cwd()
    for p in [BASE_PATH, *BASE_PATH.parents]:
        if (p / '.git').exists():
            BASE_PATH = p; break

import cartopy.crs as ccrs
print('BASE_PATH =', BASE_PATH)

# Dataset paths
DATA_DIR = BASE_PATH/'data'/'output'
MODEL_OUTPUT_DIR = BASE_PATH/'models'/'output'
MODEL_DIR = BASE_PATH/'models'/'output'/'HybridLSTMModels'
MODEL_DIR.mkdir(parents=True, exist_ok=True)
MODEL_INPUT_DIR = BASE_PATH/'data'/'input'/'shapes'
MODEL_INPUT_DIR.mkdir(parents=True, exist_ok=True)
IMAGE_DIR = MODEL_DIR/'images'
IMAGE_DIR.mkdir(exist_ok=True)
FULL_NC = DATA_DIR/'complete_dataset_with_features_with_clusters_elevation_windows_imfs_with_onehot_elevation.nc'
FULL_NC_CLEAN = DATA_DIR/'complete_dataset_with_features_with_clusters_elevation_windows_imfs_with_onehot_elevation_clean.nc'
dept_gdf = gpd.read_file(MODEL_INPUT_DIR/'MGN_Departamento.shp')

BASE_MODEL_DIR = MODEL_DIR
GIF_DIR        = MODEL_DIR / "gifs"
GIF_DIR.mkdir(exist_ok=True)

# ╭──────────────────────── Dataset & Shapes ──────────────────╮
ds          = xr.open_dataset(FULL_NC)
# ╰────────────────────────────────────────────────────────────╯

LAG_VARS = ['total_precipitation_lag1',
            'total_precipitation_lag2',
            'total_precipitation_lag12']

# ============================================================
print("\n📊  Resumen global de NaNs")
print("─"*55)
for var in LAG_VARS:
    arr    = ds[var].values
    total  = arr.size
    n_nans = int(np.isnan(arr).sum())
    print(f"{var:<28}: {n_nans:>8,} / {total:,}  ({n_nans/total:6.2%})")

# ============================================================
print("\n🕒  Fechas con NaNs por variable")
print("─"*55)
for var in LAG_VARS:
    arr         = ds[var].values
    nan_per_ts  = np.isnan(arr).reshape(len(ds.time), -1).sum(axis=1)
    if nan_per_ts.sum() == 0:
        print(f"{var}: sin NaNs ✔️")
        continue

    df_nan = (pd
              .DataFrame({"time": pd.to_datetime(ds.time.values),
                          "na_cells": nan_per_ts})
              .query("na_cells > 0"))

    # primeras 3 y últimas 3 fechas con NaNs
    head = df_nan.head(3).to_string(index=False)
    tail = df_nan.tail(3).to_string(index=False)
    last = df_nan["time"].iloc[-1].strftime("%Y-%m")

    print(f"\n{var}")
    print(head)
    if len(df_nan) > 6:
        print("   …")
    print(tail)
    print(f"   ⇢  última fecha con NaNs: {last}")

# ============================================================
# Primera fecha en la que las TRES variables están 100 % limpias
# ------------------------------------------------------------
def last_nan_index(var: str) -> int:
    """Índice del último timestamp que contiene al menos un NaN en `var`."""
    nan_per_ts = np.isnan(ds[var].values).reshape(len(ds.time), -1).sum(axis=1)
    idxs       = np.where(nan_per_ts > 0)[0]
    return idxs[-1] if len(idxs) else -1

last_nan_any = max(last_nan_index(v) for v in LAG_VARS)
first_clean  = pd.to_datetime(ds.time.values[last_nan_any + 1])

print("\nPrimera fecha 100 % libre de NaNs en TODOS los lags:",
      first_clean.strftime("%Y-%m"))

ds_clean = ds.sel(time=~(ds['time.year'] == 1981))   # descarta TODO 1981

print("🔎  Timestamps antes :", len(ds.time))
print("🔎  Timestamps después:", len(ds_clean.time))

# 3) Guarda nuevo archivo NetCDF
ds_clean.to_netcdf(FULL_NC_CLEAN, mode='w')
print(f"💾  Dataset sin 1981 guardado en {FULL_NC_CLEAN}")

# 4) (-- opcional --)  verifica que ya no queden NaNs en los lags
LAG_VARS = ['total_precipitation_lag1',
            'total_precipitation_lag2',
            'total_precipitation_lag12']

print("\n📊  NaNs restantes tras quitar 1981")
print("─"*50)
for var in LAG_VARS:
    n_nan = int(np.isnan(ds_clean[var].values).sum())
    print(f"{var:<28}: {n_nan:,} NaNs")


BASE_PATH = /Users/riperez/Conda/anaconda3/envs/precipitation_prediction/github.com/ml_precipitation_prediction

📊  Resumen global de NaNs
───────────────────────────────────────────────────────
total_precipitation_lag1    :    3,965 / 2,101,450  ( 0.19%)
total_precipitation_lag2    :    7,930 / 2,101,450  ( 0.38%)
total_precipitation_lag12   :   47,580 / 2,101,450  ( 2.26%)

🕒  Fechas con NaNs por variable
───────────────────────────────────────────────────────

total_precipitation_lag1
      time  na_cells
1981-01-01      3965
      time  na_cells
1981-01-01      3965
   ⇢  última fecha con NaNs: 1981-01

total_precipitation_lag2
      time  na_cells
1981-01-01      3965
1981-02-01      3965
      time  na_cells
1981-01-01      3965
1981-02-01      3965
   ⇢  última fecha con NaNs: 1981-02

total_precipitation_lag12
      time  na_cells
1981-01-01      3965
1981-02-01      3965
1981-03-01      3965
   …
      time  na_cells
1981-10-01      3965
1981-11-01      3965
1981-12-01      39

In [None]:
from __future__ import annotations

import sys
from pathlib import Path
from typing import List, Dict, Any

import numpy as np
import pandas as pd
import xarray as xr
import tensorflow as tf
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import matplotlib.pyplot as plt
import geopandas as gpd
import imageio.v2 as imageio

from tensorflow.keras.layers import (
    Input, ConvLSTM2D, GRU, Flatten, RepeatVector, Reshape, TimeDistributed,
    Dense, MultiHeadAttention, Add, LayerNormalization
)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

## ╭─────────────────────────── Rutas ──────────────────────────╮
# ▶️ Path configuration
IN_COLAB = 'google.colab' in sys.modules
if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)
    BASE_PATH = Path('/content/drive/MyDrive/ml_precipitation_prediction')
    # Instalar dependencias necesarias
    !pip install -r requirements.txt
    !pip install xarray netCDF4 optuna matplotlib seaborn lightgbm xgboost scikit-learn ace_tools_open cartopy geopandas
else:
    BASE_PATH = Path.cwd()
    for p in [BASE_PATH, *BASE_PATH.parents]:
        if (p / '.git').exists():
            BASE_PATH = p; break

import cartopy.crs as ccrs
print('BASE_PATH =', BASE_PATH)

# Dataset paths
DATA_DIR = BASE_PATH/'data'/'output'
MODEL_OUTPUT_DIR = BASE_PATH/'models'/'output'
MODEL_DIR = BASE_PATH/'models'/'output'/'HybridLSTMModels'
MODEL_DIR.mkdir(parents=True, exist_ok=True)
MODEL_INPUT_DIR = BASE_PATH/'data'/'input'/'shapes'
MODEL_INPUT_DIR.mkdir(parents=True, exist_ok=True)
IMAGE_DIR = MODEL_DIR/'images'
IMAGE_DIR.mkdir(exist_ok=True)
FULL_NC = DATA_DIR/'complete_dataset_with_features_with_clusters_elevation_windows_imfs_with_onehot_elevation_clean.nc'
dept_gdf = gpd.read_file(MODEL_INPUT_DIR/'MGN_Departamento.shp')

BASE_MODEL_DIR = MODEL_DIR
GIF_DIR        = MODEL_DIR / "gifs"
GIF_DIR.mkdir(exist_ok=True)

# ╰────────────────────────────────────────────────────────────╯

# ╭──────────────────────── Dataset & Shapes ──────────────────╮
ds          = xr.open_dataset(FULL_NC)
lat, lon    = len(ds.latitude), len(ds.longitude)
cells       = lat * lon
# ╰────────────────────────────────────────────────────────────╯

# ╭──────────────────── Hyper‑parámetros globales ─────────────╮
INPUT_WINDOW   = 60
HORIZON        = 3
TARGET_VAR     = 'total_precipitation'
EPOCHS         = 12
BATCH_SIZE     = 4           # tamaño pequeño → menor RAM GPU
PATIENCE       = 10
LR             = 1e-3
# ╰────────────────────────────────────────────────────────────╯


# ╭────────────────────── Modelo base ConvLSTM ────────────────╮

def _build_convlstm_ed(*,input_window: int,output_horizon: int,spatial_height: int,spatial_width: int,n_features: int,n_filters: int = 64,n_heads: int = 4,use_attention: bool = True,lr: float = LR) -> Model:
    """Construye un Encoder‑Decoder ConvLSTM.

    Si `use_attention=False` se omite la capa Multi‑Head Attention.
    La salida es `(B, T_out, H, W, 1)`.
    """
    inputs = Input(shape=(input_window, spatial_height, spatial_width, n_features), name="enc_input")

    # ── Encoder ────────────────────────────────────────────
    x = ConvLSTM2D(n_filters,   (3, 3), padding='same', return_sequences=True,  name="enc_lstm_1")(inputs)
    x = ConvLSTM2D(n_filters//2,(3, 3), padding='same', return_sequences=False, name="enc_lstm_2")(x)

    # ── Flatten + contexto temporal ────────────────────────
    flat = Flatten(name="flatten_spatial")(x)
    ctx  = RepeatVector(output_horizon, name="context")(flat)  # (B, T_out, H·W·C)

    # ── Decoder GRU (temporal) ─────────────────────────────
    dec = GRU(2*n_filters, return_sequences=True, name="dec_gru")(ctx)

    if use_attention:
        attn = MultiHeadAttention(num_heads=n_heads, key_dim=n_filters, dropout=0.1, name="mha")(dec, dec)
        dec  = LayerNormalization(name="mha_norm")(Add(name="mha_add")([dec, attn]))

    # ── Proyección + reshape a grid ───────────────────────
    proj = TimeDistributed(Dense(spatial_height*spatial_width, activation='linear'), name="dense_proj")(dec)
    out  = Reshape((output_horizon, spatial_height, spatial_width, 1), name="reshape_out")(proj)

    model = Model(inputs, out, name="ConvLSTM_ED_Attn" if use_attention else "ConvLSTM_ED")
    model.compile(optimizer=Adam(lr), loss='mse')
    return model

# Factories ---------------------------------------------------

def factory_no_attn(**kw):
    return _build_convlstm_ed(use_attention=False, **kw)

def factory_attn(**kw):
    return _build_convlstm_ed(use_attention=True, **kw)
# ╰────────────────────────────────────────────────────────────╯

# ╭────────────────────────── Métricas ────────────────────────╮

def evaluate(y_true: np.ndarray, y_pred: np.ndarray):
    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    mae  = mean_absolute_error(y_true, y_pred)
    mape = np.mean(np.abs((y_true - y_pred) / (y_true + 1e-5))) * 100
    r2   = r2_score(y_true, y_pred)
    return rmse, mae, mape, r2
# ╰────────────────────────────────────────────────────────────╯

# ╭──────────────────────── Quick‑plot ────────────────────────╮

def quick_plot(ax, data, cmap, title, date_label, vmin=None, vmax=None):
    mesh = ax.pcolormesh(ds.longitude, ds.latitude, data, cmap=cmap, shading='nearest', vmin=vmin, vmax=vmax, transform=ccrs.PlateCarree())
    ax.coastlines(); ax.add_geometries(dept_gdf.geometry, ccrs.PlateCarree(), edgecolor='black', facecolor='none', linewidth=1)
    gl = ax.gridlines(draw_labels=True); gl.top_labels=False; gl.right_labels=False
    ax.set_title(f"{title}\n{date_label}", pad=12)
    return mesh
# ╰────────────────────────────────────────────────────────────╯

# ╭────────────────────── Experiments & Folds ─────────────────╮
# ▸ Solo mostramos los tres primeros niveles; añade los demás igual
BASE_FEATURES = [
    'year','month','month_sin','month_cos','doy_sin','doy_cos',
    'max_daily_precipitation','min_daily_precipitation','daily_precipitation_std',
    'elevation','slope','aspect'
]
ELEV_CLUSTER = ['elev_high','elev_med','elev_low']
KCE_FEATURES = BASE_FEATURES + ELEV_CLUSTER
PAFC_FEATURES= KCE_FEATURES + ['total_precipitation_lag1','total_precipitation_lag2','total_precipitation_lag12']

FOLDS = {'F1': {'year': 2018,'active': True}}

EXPERIMENTS: Dict[str, Dict[str, Any]] = {
    'ConvLSTM-ED': {
        'active': True,
        'feature_list': BASE_FEATURES,
        'builder': factory_attn, #factory_no_attn,
        'n_filters': 64,
        'n_heads'  : 4
    },
    'ConvLSTM-ED-KCE': {
        'active': True,
        'feature_list': KCE_FEATURES,
        'builder': factory_attn,
        'n_filters': 64,
        'n_heads'  : 4,
    },
    'ConvLSTM-ED-KCE-PAFC': {
        'active': True,
        'feature_list': PAFC_FEATURES,
        'builder': factory_attn,
        'n_filters': 96,
        'n_heads'  : 6,
    },
}
# ╰────────────────────────────────────────────────────────────╯

# ╭──────────────────── Ventanas deslizadas ───────────────────╮

def make_windows(mask:np.ndarray, allow_past_context:bool)->tuple[np.ndarray,np.ndarray]:
    """Genera ventanas **descartando** las que contienen NaNs.  # 🔸 NEW"""
    seq_X, seq_y = [], []
    lim = len(mask) - INPUT_WINDOW - HORIZON + 1
    for start in range(lim):
        end_w = start + INPUT_WINDOW; end_y = end_w + HORIZON
        if allow_past_context:
            if not mask[end_w:end_y].all():
                continue
        else:
            if not mask[start:end_y].all():
                continue
        Xw = Xarr[start:end_w]; yw = yarr[end_w:end_y]
        if np.isnan(Xw).any() or np.isnan(yw).any():
            continue  # 🔸 NEW — descarta ventana con NaNs
        seq_X.append(Xw); seq_y.append(yw)
    return np.array(seq_X), np.array(seq_y)


# ╭────────────────── Bucle principal de entrenamiento ────────╮
RESULTS: List[Dict[str, Any]] = []

# 🔸 NEW helper ------------------------------------------------

def _impute_nans(a:np.ndarray, per_feature_mean:np.ndarray|None=None, is_target:bool=False)->np.ndarray:
    """Imputa NaNs restantes (seguridad extra)."""
    if not np.isnan(a).any():
        return a
    if is_target:
        a[np.isnan(a)] = 0.0  # 🔸 NEW – 0 para y
        return a
    if per_feature_mean is None:
        raise ValueError('per_feature_mean required for imputing X')
    flat = a.reshape(-1, a.shape[-1])
    nan_idx = np.isnan(flat)
    for f in range(a.shape[-1]):
        flat[nan_idx[:,f], f] = per_feature_mean[f]  # 🔸 NEW
    return flat.reshape(a.shape)
# ╰────────────────────────────────────────────────────────────╯

def run_all_experiments():
    times = pd.to_datetime(ds.time.values)
    total = sum(e['active'] for e in EXPERIMENTS.values()) * sum(f['active'] for f in FOLDS.values())
    cnt   = 0

    for exp_name, exp_cfg in EXPERIMENTS.items():
        if not exp_cfg['active']:
            continue
        vars_     = exp_cfg['feature_list']
        builder   = exp_cfg['builder']      # fábrica específica
        n_filters = exp_cfg.get('n_filters',64)
        n_heads   = exp_cfg.get('n_heads',4)

        # ─ Pre‑load features por experimento ─────────────────────
        global Xarr, yarr
        Xarr = ds[vars_].to_array().transpose('time','latitude','longitude','variable').values.astype(np.float32)
        yarr = ds[TARGET_VAR].values.astype(np.float32)
        feats = Xarr.shape[-1]

        for fold_name, fold_cfg in FOLDS.items():
            if not fold_cfg['active']:
                continue
            cnt += 1
            year_val = fold_cfg['year']
            print(f"\n▶️  [{cnt}/{total}] {exp_name} – {fold_name} (val={year_val})")

            mask_val = times.year == year_val
            mask_tr  = ~mask_val
            if mask_val.sum() < HORIZON:
                print("⚠️ Año sin pasos suficientes → skip"); continue

            X_tr, y_tr = make_windows(mask_tr,  allow_past_context=False)
            X_va, y_va = make_windows(mask_val, allow_past_context=True)
            print(f"Ventanas train: {len(X_tr)} · val: {len(X_va)}")
            if len(X_tr)==0 or len(X_va)==0:
                print("⚠️ Sin ventanas válidas → skip"); continue

            # 🔸 NEW — Imputación de seguridad
            feat_mean = np.nanmean(X_tr.reshape(-1,feats),axis=0)
            X_tr = _impute_nans(X_tr,feat_mean); X_va=_impute_nans(X_va,feat_mean)
            y_tr = _impute_nans(y_tr,is_target=True); y_va=_impute_nans(y_va,is_target=True)
            
            # ─ Scaling (fit solo en train) ─────────────────────
            sx = StandardScaler().fit(X_tr.reshape(-1, feats))
            sy = StandardScaler().fit(y_tr.reshape(-1, 1))
            X_tr_sc = sx.transform(X_tr.reshape(-1, feats)).reshape(X_tr.shape)
            X_va_sc = sx.transform(X_va.reshape(-1, feats)).reshape(X_va.shape)
            y_tr_sc = sy.transform(y_tr.reshape(-1, 1)).reshape(y_tr.shape)[..., None]
            y_va_sc = sy.transform(y_va.reshape(-1, 1)).reshape(y_va.shape)[..., None]

            # ─ Build & train model (factory) ───────────────────
            tag        = f"{exp_name.replace('+','_')}_{fold_name}"
            model_path = BASE_MODEL_DIR / f"{tag}.keras"
            if model_path.exists():
                print(f"⏩ {tag} ya existe → skip"); continue

            model = builder(
                input_window=INPUT_WINDOW,
                output_horizon=HORIZON,
                spatial_height=lat,
                spatial_width=lon,
                n_features=feats,
                n_filters=n_filters,
                n_heads=n_heads,
                lr=LR
            )

            es   = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=PATIENCE, restore_best_weights=True)
            hist = model.fit(X_tr_sc, y_tr_sc, validation_data=(X_va_sc, y_va_sc), epochs=EPOCHS, batch_size=BATCH_SIZE, callbacks=[es], verbose=1)

            # ─ Evaluación ─────────────────────────────────────
            y_hat_sc = model.predict(X_va_sc, verbose=0)
            y_hat    = sy.inverse_transform(y_hat_sc.reshape(-1,1)).reshape(y_hat_sc.shape)
            y_true   = sy.inverse_transform(y_va_sc.reshape(-1,1)).reshape(y_va_sc.shape)

            rmse, mae, mape, r2 = evaluate(y_true.ravel(), y_hat.ravel())
            RESULTS.append(dict(experiment=exp_name, fold=fold_name, RMSE=rmse, MAE=mae, MAPE=mape, R2=r2, epochs=len(hist.history['loss'])))

            # ─ Guardado artefactos ────────────────────────────
            model.save(model_path)
            plt.figure(); plt.plot(hist.history['loss'], label='train'); plt.plot(hist.history['val_loss'], label='val'); plt.legend(); plt.title(tag); plt.savefig(IMAGE_DIR/f"{tag}.png"); plt.close()

            _generate_gif(y_true[0], y_hat[0], tag)
            print(f"✅ Guardado {model_path.name}")

    # ─ Métricas globales ────────────────────────────────────
    df = pd.DataFrame(RESULTS)
    out_csv = BASE_MODEL_DIR / "metrics_experiments_folds.csv"
    df.to_csv(out_csv, index=False)
    print(f"\n📑 Tabla de métricas en {out_csv}")
# ╰────────────────────────────────────────────────────────────╯

# ╭──────────────────── Generador de GIF ──────────────────────╮

def _generate_gif(y_true_sample, y_pred_sample, tag):
    pcm_min, pcm_max = 0, np.max(y_pred_sample)
    frames = []
    for h in range(HORIZON):
        pmap = y_pred_sample[h, ..., 0]
        fig, ax = plt.subplots(1,1, figsize=(6,5), subplot_kw={'projection':ccrs.PlateCarree()})
        mesh = ax.pcolormesh(ds.longitude, ds.latitude, pmap, cmap='Blues', shading='nearest', vmin=pcm_min, vmax=pcm_max, transform=ccrs.PlateCarree())
        ax.coastlines(); ax.gridlines(draw_labels=True)
        ax.set_title(f"{tag} – H{h+1}")
        fig.colorbar(mesh, ax=ax, fraction=0.046, pad=0.04)
        tmp = GIF_DIR/f"tmp_{tag}_h{h}.png"
        fig.savefig(tmp, bbox_inches='tight'); plt.close(fig)
        frames.append(imageio.imread(tmp)); tmp.unlink(missing_ok=True)
    gif_path = GIF_DIR/f"{tag}.gif"
    imageio.mimsave(gif_path, frames, fps=0.5)
    print(f"💾 GIF {gif_path.name} listo")
# ╰────────────────────────────────────────────────────────────╯

# ╭────────────────────── Bucle principal ─────────────────────╮
run_all_experiments()
# ╰────────────────────────────────────────────────────────────╯


In [13]:

# 📈 **Evaluador para salidas espaciales ConvLSTM**

# ───────── Imports ──────────
from pathlib import Path
import numpy as np, pandas as pd, xarray as xr, tensorflow as tf
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import matplotlib.pyplot as plt, geopandas as gpd, imageio.v2 as imageio
import sys
from typing import List, Dict, Any

# ───────── Paths & Constantes ─────────
# ▶️ Path configuration
IN_COLAB = 'google.colab' in sys.modules
if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)
    BASE_PATH = Path('/content/drive/MyDrive/ml_precipitation_prediction')
    # Instalar dependencias necesarias
    !pip install -r requirements.txt
    !pip install xarray netCDF4 optuna matplotlib seaborn lightgbm xgboost scikit-learn ace_tools_open cartopy geopandas
else:
    BASE_PATH = Path.cwd()
    for p in [BASE_PATH, *BASE_PATH.parents]:
        if (p / '.git').exists():
            BASE_PATH = p; break

import cartopy.crs as ccrs
print('BASE_PATH =', BASE_PATH)

# Dataset paths
DATA_DIR = BASE_PATH/'data'/'output'
MODEL_OUTPUT_DIR = BASE_PATH/'models'/'output'
MODEL_DIR = BASE_PATH/'models'/'output'/'HybridLSTMModels'
MODEL_DIR.mkdir(parents=True, exist_ok=True)
MODEL_INPUT_DIR = BASE_PATH/'data'/'input'/'shapes'
MODEL_INPUT_DIR.mkdir(parents=True, exist_ok=True)
IMAGE_DIR = MODEL_DIR/'images'
IMAGE_DIR.mkdir(exist_ok=True)
FULL_NC = DATA_DIR/'complete_dataset_with_features_with_clusters_elevation_windows_imfs_with_onehot_elevation_clean.nc'
departamentos = gpd.read_file(MODEL_INPUT_DIR/'MGN_Departamento.shp')

BASE_MODEL_DIR = MODEL_DIR
GIF_DIR        = MODEL_DIR / "gifs"
GIF_DIR.mkdir(exist_ok=True)

# ╭──────────────────── Hyper‑parámetros globales ─────────────╮
INPUT_WINDOW   = 60
HORIZON        = 3
TARGET_VAR     = 'total_precipitation'
EPOCHS         = 12
BATCH_SIZE     = 4           # tamaño pequeño → menor RAM GPU
PATIENCE       = 10
LR             = 1e-3
# ╰────────────────────────────────────────────────────────────╯


# ───────── Dataset & shapes ─────────
ds = xr.open_dataset(FULL_NC); lat,lon=len(ds.latitude),len(ds.longitude)

#╭────────────────────── Experiments & Folds ─────────────────╮

# Factories ---------------------------------------------------

def factory_no_attn(**kw):
    return _build_convlstm_ed(use_attention=False, **kw)

def factory_attn(**kw):
    return _build_convlstm_ed(use_attention=True, **kw)
# ╰────────────────────────────────────────────────────────────╯

# ▸ Solo mostramos los tres primeros niveles; añade los demás igual
BASE_FEATURES = [
    'year','month','month_sin','month_cos','doy_sin','doy_cos',
    'max_daily_precipitation','min_daily_precipitation','daily_precipitation_std',
    'elevation','slope','aspect'
]
ELEV_CLUSTER = ['elev_high','elev_med','elev_low']
KCE_FEATURES = BASE_FEATURES + ELEV_CLUSTER
PAFC_FEATURES= KCE_FEATURES + ['total_precipitation_lag1','total_precipitation_lag2','total_precipitation_lag12']

FOLDS = {'F1': {'year': 2018,'active': True}}

EXPERIMENTS: Dict[str, Dict[str, Any]] = {
    'ConvLSTM-ED': {
        'active': True,
        'feature_list': BASE_FEATURES,
        'builder': factory_attn, #factory_no_attn,
        'n_filters': 64,
        'n_heads'  : 4
    },
    'ConvLSTM-ED-KCE': {
        'active': True,
        'feature_list': KCE_FEATURES,
        'builder': factory_attn,
        'n_filters': 64,
        'n_heads'  : 4,
    },
    'ConvLSTM-ED-KCE-PAFC': {
        'active': True,
        'feature_list': PAFC_FEATURES,
        'builder': factory_attn,
        'n_filters': 96,
        'n_heads'  : 6,
    },
}
# ╰────────────────────────────────────────────────────────────╯


def quick_plot(ax,data,cmap,title,date_label,vmin=None,vmax=None):
    mesh=ax.pcolormesh(ds.longitude,ds.latitude,data,cmap=cmap,shading='nearest',vmin=vmin,vmax=vmax,transform=ccrs.PlateCarree())
    ax.coastlines(); ax.add_geometries(departamentos.geometry,ccrs.PlateCarree(),edgecolor='black',facecolor='none',linewidth=1)
    gl=ax.gridlines(draw_labels=True); gl.top_labels=False; gl.right_labels=False
    ax.set_title(f"{title}\n{date_label}",pad=10); return mesh

# ───────── Recuperamos diccionario EXPERIMENTS (del bloque de entrenamiento) ─────────
from typing import Dict
EXPERIMENTS:Dict[str,Dict[str,Any]] = {
    'ConvLSTM-ED':              {'feature_list': "+".join(BASE_FEATURES).split("+")},
    'ConvLSTM-ED-KCE':          {'feature_list': "+".join(KCE_FEATURES).split("+")},
    'ConvLSTM-ED-KCE-PAFC':     {'feature_list': "+".join(PAFC_FEATURES).split("+")},
    # otros experimentos
}

# ———————————————————— Evaluación ————————————————————
all_metrics=[]; times=pd.to_datetime(ds.time.values)
for mpath in sorted(BASE_MODEL_DIR.glob("*.keras")):
    tag   = mpath.stem                        # p.ej. ConvLSTM-ED_F1
    parts = tag.split("_")
    fold  = parts[-1]                         # F1
    exp_token = "_".join(parts[:-1])
    exp_name  = exp_token.replace("_","+")  # vuelve al nombre original con +
    if exp_name not in EXPERIMENTS:
        print("⚠️ Exp no encontrado para",tag); continue
    feats = EXPERIMENTS[exp_name]['feature_list']
    print(f"\n🔍 Evaluando {tag} …")

    # — Extracción de arrays —
    Xarr = ds[feats].to_array().transpose('time','latitude','longitude','variable').values.astype(np.float32)
    yarr = ds[TARGET_VAR].values.astype(np.float32)
    T,_,_,F = Xarr.shape
    Xfull = Xarr; yfull=yarr  # mantenemos (T,H,W,F)

    # ventana final (idéntica lógica del cuaderno original)
    start=T-INPUT_WINDOW-HORIZON; end_w=start+INPUT_WINDOW; end_y=end_w+HORIZON
    X_eval = Xfull[start:end_w]                 # (60,H,W,F)
    y_eval = yfull[end_w:end_y]                 # (3,H,W)

    # — Scalers (fit incremental) —
    sx,sy = StandardScaler(),StandardScaler()
    for t in range(T):
        sx.partial_fit(Xfull[t].reshape(-1,F))
        sy.partial_fit(yfull[t].reshape(-1,1))
    Xe_sc = sx.transform(X_eval.reshape(-1,F)).reshape(1,INPUT_WINDOW,lat,lon,F)
    ye_sc = sy.transform(y_eval.reshape(-1,1)).reshape(1,HORIZON,lat,lon,1)

    # — Carga modelo y predicción —
    model=tf.keras.models.load_model(mpath,compile=False)
    yhat_sc=model.predict(Xe_sc,verbose=0)      # (1,3,H,W,1)
    yhat   = sy.inverse_transform(yhat_sc.reshape(-1,1)).reshape(HORIZON,lat,lon)
    ytrue  = y_eval

    # — Métricas por horizonte —
    for h in range(HORIZON):
        yt = ytrue[h].ravel()
        yp = yhat[h].ravel()

        # ---------- filtro NaN / ±∞ ----------
        mask = np.isfinite(yt) & np.isfinite(yp)
        if mask.sum() == 0:          # ventana vacía → se ignora
            print(f"   · h={h+1}: todos los valores son NaN/Inf → skip")
            continue
        yt, yp = yt[mask], yp[mask]
        # -------------------------------------

        rmse = np.sqrt(mean_squared_error(yt, yp))
        mae  = mean_absolute_error(yt, yp)
        mape = np.mean(np.abs((yt - yp) / (yt + 1e-5))) * 100
        r2   = r2_score(yt, yp)

        all_metrics.append(dict(
            model      = tag,
            experiment = exp_name,
            fold       = fold,
            horizon    = h + 1,
            RMSE       = rmse,
            MAE        = mae,
            MAPE       = mape,
            R2         = r2
        ))

    # — Figura Real vs Pred vs MAPE —
    fig,axes=plt.subplots(HORIZON,3,figsize=(14,4*HORIZON),subplot_kw={'projection':ccrs.PlateCarree()})
    dates=pd.date_range(times[end_w],periods=HORIZON,freq='MS')
    vmin=0; vmax=max(yhat.max(),ytrue.max())
    for h in range(HORIZON):
        quick_plot(axes[h,0],ytrue[h],'Blues',f"Real h={h+1}",dates[h].strftime('%Y-%m'),vmin,vmax)
        quick_plot(axes[h,1],yhat [h],'Blues',f"Pred h={h+1}",dates[h].strftime('%Y-%m'),vmin,vmax)
        err=np.clip(np.abs((ytrue[h]-yhat[h])/(ytrue[h]+1e-5))*100,0,100)
        quick_plot(axes[h,2],err,'Reds',f"MAPE% h={h+1}",dates[h].strftime('%Y-%m'),0,100)
    fig.suptitle(f"{tag}  — Eval final ventana",fontsize=16); fig.tight_layout();
    fig.savefig(BASE_MODEL_DIR/f"fig_{tag}.png"); plt.close(fig)

    # — GIF —
    frames=[]; pcm_min,pcm_max=0,yhat.max()
    for h in range(HORIZON):
        figg,ax=plt.subplots(1,1,figsize=(6,5),subplot_kw={'projection':ccrs.PlateCarree()})
        m=ax.pcolormesh(ds.longitude,ds.latitude,yhat[h],cmap='Blues',shading='nearest',vmin=pcm_min,vmax=pcm_max,transform=ccrs.PlateCarree())
        ax.coastlines(); ax.set_title(f"{tag} – H{h+1}"); figg.colorbar(m,ax=ax,fraction=0.046,pad=0.04)
        tmp=GIF_DIR/f"tmp_{tag}_{h}.png"; figg.savefig(tmp,bbox_inches='tight'); plt.close(figg)
        frames.append(imageio.imread(tmp)); tmp.unlink(missing_ok=True)
    imageio.mimsave(GIF_DIR/f"{tag}.gif",frames,fps=0.5)
    print("💾 GIF",f"{tag}.gif","creado")

# ——— Guardar tabla ———
pd.DataFrame(all_metrics).to_csv(BASE_MODEL_DIR/'metrics_eval.csv',index=False)
print("📑 Métricas guardadas en",BASE_MODEL_DIR/'metrics_eval.csv')


BASE_PATH = /Users/riperez/Conda/anaconda3/envs/precipitation_prediction/github.com/ml_precipitation_prediction

🔍 Evaluando ConvLSTM-ED-KCE-PAFC_F1 …


  updated_mean = (last_sum + new_sum) / updated_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= correction**2 / new_sample_count
  T = new_sum / new_sample_count
  new_unnormalized_variance -= co

   · h=1: todos los valores son NaN/Inf → skip
   · h=2: todos los valores son NaN/Inf → skip
   · h=3: todos los valores son NaN/Inf → skip
💾 GIF ConvLSTM-ED-KCE-PAFC_F1.gif creado

🔍 Evaluando ConvLSTM-ED-KCE_F1 …
💾 GIF ConvLSTM-ED-KCE_F1.gif creado

🔍 Evaluando ConvLSTM-ED_F1 …
💾 GIF ConvLSTM-ED_F1.gif creado
📑 Métricas guardadas en /Users/riperez/Conda/anaconda3/envs/precipitation_prediction/github.com/ml_precipitation_prediction/models/output/HybridLSTMModels/metrics_eval.csv
