# In [None]:
# ───────────────────────── IMPORTS ─────────────────────────
from __future__ import annotations
from pathlib import Path
import sys, os, gc, warnings
import numpy as np, pandas as pd, xarray as xr
import tensorflow as tf
from tensorflow.keras.layers import (
    Input, Conv2D, ConvLSTM2D, SimpleRNN, Flatten, Dense, Reshape,
    Lambda, Permute
)

# ── ConvGRU2D: Verificar disponibilidad ───────────────────────────
try:
    from tensorflow.keras.layers import ConvGRU2D  # TF ≥ 2.8
    HAS_CONVGRU = True
    print("✅ ConvGRU2D nativo disponible")
except ImportError:
    HAS_CONVGRU = False
    print("⚠️ ConvGRU2D no disponible. Se usará ConvLSTM2D como alternativa.")

from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import matplotlib.pyplot as plt, seaborn as sns, geopandas as gpd, imageio.v2 as imageio
warnings.filterwarnings('ignore')
plt.style.use('seaborn-v0_8-whitegrid'); sns.set_context('notebook')

# ───────────────────────── ENVIRONMENT / GPU ─────────────────────────
## ╭─────────────────────────── Paths ──────────────────────────╮
# ▶️ Path configuration
# Resolve BASE_PATH: a fixed Google Drive folder when running inside Colab,
# otherwise the nearest ancestor of the working directory that holds a `.git`
# entry (falling back to the working directory itself when none is found).
IN_COLAB = 'google.colab' in sys.modules
if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)
    BASE_PATH = Path('/content/drive/MyDrive/ml_precipitation_prediction')
    # Install the required dependencies (IPython magics: only valid in a notebook cell).
    # NOTE(review): geopandas, seaborn, etc. are already imported above, before these
    # installs run — confirm the install order is the intended one.
    %pip install -r requirements.txt
    %pip install xarray netCDF4 optuna matplotlib seaborn lightgbm xgboost scikit-learn ace_tools_open cartopy geopandas
else:
    BASE_PATH = Path.cwd()
    # Walk from the working directory upwards and stop at the first repository root.
    for p in [BASE_PATH, *BASE_PATH.parents]:
        if (p / '.git').exists():
            BASE_PATH = p; break

# Imported after the optional Colab install step above.
import cartopy.crs as ccrs

# Enable on-demand GPU memory growth on every visible GPU (helps avoid OOM).
for g in tf.config.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(g, True)

# ───────────────────────── PATHS & CONST ─────────────────────────
# Input dataset (NetCDF) and output directory, both rooted at BASE_PATH.
DATA_FILE = BASE_PATH.joinpath(
    'data',
    'output',
    'complete_dataset_with_features_with_clusters_elevation_windows_imfs_with_onehot_elevation_clean.nc',
)
OUT_ROOT = BASE_PATH.joinpath('models', 'output', 'Spatial_CONVRNN')
OUT_ROOT.mkdir(parents=True, exist_ok=True)

# Department boundaries, drawn on top of every map by `quick_plot`.
SHAPE_DIR = BASE_PATH.joinpath('data', 'input', 'shapes')
DEPT_GDF = gpd.read_file(SHAPE_DIR.joinpath('MGN_Departamento.shp'))

# Windowing
INPUT_WINDOW = 60  # consecutive time steps fed to a model per sample
HORIZON = 3        # consecutive time steps predicted per sample

# Training
EPOCHS = 50    # upper bound on epochs passed to `fit`
BATCH = 4      # batch size passed to `fit`
LR = 1e-3      # Adam learning rate
PATIENCE = 6   # EarlyStopping patience on `val_loss`

# ───────────────────────── FEATURE SETS ─────────────────────────
# Each experiment's feature list extends the previous one: BASIC → KCE → PAFC.
BASE_FEATS = [
    'year', 'month', 'month_sin', 'month_cos', 'doy_sin', 'doy_cos',
    'max_daily_precipitation', 'min_daily_precipitation', 'daily_precipitation_std',
    'elevation', 'slope', 'aspect',
]
ELEV_CLUSTER = ['elev_high', 'elev_med', 'elev_low']

# KCE = base features + one-hot elevation cluster columns.
KCE_FEATS = [*BASE_FEATS, *ELEV_CLUSTER]

# PAFC = KCE + lagged total-precipitation columns.
PAFC_FEATS = [
    *KCE_FEATS,
    'total_precipitation_lag1',
    'total_precipitation_lag2',
    'total_precipitation_lag12',
]

# Experiment name → ordered list of dataset variables used as model inputs.
EXPERIMENTS = {
    'BASIC': BASE_FEATS,
    'KCE': KCE_FEATS,
    'PAFC': PAFC_FEATS,
}

# ───────────────────────── DATASET ─────────────────────────
# Opened once at module level. `ds`, `lat` and `lon` are read as globals by the
# plotting helper, the model factories and the training loop in this file.
ds = xr.open_dataset(DATA_FILE)
lat = len(ds.latitude)
lon = len(ds.longitude)
print(f"Dataset → time={len(ds.time)}, lat={lat}, lon={lon}")

# ───────────────────────── HELPERS ─────────────────────────

def windowed_arrays(X:np.ndarray, y:np.ndarray):
    seq_X, seq_y = [], []
    T = len(X)
    for start in range(T-INPUT_WINDOW-HORIZON+1):
        end_w = start+INPUT_WINDOW; end_y=end_w+HORIZON
        Xw, yw = X[start:end_w], y[end_w:end_y]
        if np.isnan(Xw).any() or np.isnan(yw).any():
            continue
        seq_X.append(Xw); seq_y.append(yw)
    return np.asarray(seq_X,dtype=np.float32), np.asarray(seq_y,dtype=np.float32)

def quick_plot(ax,data,cmap,title,vmin=None,vmax=None):
    """Render ``data`` on ``ax`` as a colour mesh over the dataset's lon/lat grid.

    Coastlines, the department outlines from ``DEPT_GDF`` and an unlabelled
    dashed grid are drawn on top, and ``title`` is set on the axes.

    Returns:
        The mesh object returned by ``ax.pcolormesh``.
    """
    plate_carree = ccrs.PlateCarree()

    mesh = ax.pcolormesh(
        ds.longitude,
        ds.latitude,
        data,
        cmap=cmap,
        shading='nearest',
        vmin=vmin,
        vmax=vmax,
        transform=plate_carree,
    )

    ax.coastlines()
    ax.add_geometries(
        DEPT_GDF.geometry,
        ccrs.PlateCarree(),
        edgecolor='black',
        facecolor='none',
        linewidth=1,
    )

    # Grid labels are left off to avoid a cartopy bug.
    ax.gridlines(draw_labels=False, linewidth=.5, linestyle='--', alpha=.4)
    ax.set_title(title, fontsize=9)
    return mesh

# ───────────────────────── LIGHTWEIGHT HEAD ─────────────────────────

def _spatial_head(x):
    """Project a spatial feature map to one prediction map per horizon step.

    Takes a ``(B, lat, lon, C)`` tensor (the output of the last recurrent
    convolution in the model factories) and returns ``(B, HORIZON, lat, lon, 1)``.

    Only built-in layers are used. The transpose and the trailing channel axis
    were previously `Lambda` layers wrapping Python lambdas, which Keras cannot
    rebuild from a saved model without unsafe deserialization; `Permute` and
    `Reshape` produce the same tensors and serialize cleanly. Layer names are
    kept unchanged.
    """
    # 1) 1×1 convolution emitting HORIZON maps (one per step): (B, lat, lon, H)
    x = Conv2D(
        HORIZON,
        (1, 1),
        padding="same",
        activation="linear",
        name="head_conv1x1",
    )(x)

    # 2) Move the horizon axis in front of the spatial axes: (B, H, lat, lon).
    #    Permute indices are 1-based and exclude the batch axis.
    x = Permute((3, 1, 2), name="head_transpose")(x)

    # 3) Append a singleton channel axis: (B, H, lat, lon, 1)
    x = Reshape((HORIZON, lat, lon, 1), name="head_expand_dim")(x)
    return x

# ───────────────────────── MODEL FACTORIES ─────────────────────────

def build_conv_lstm(n_feats:int):
    """Build the 'ConvLSTM' model: two stacked ConvLSTM2D layers plus the 1×1 spatial head.

    Args:
        n_feats: Number of input feature channels per grid cell.

    Returns:
        An uncompiled Keras ``Model`` taking ``(INPUT_WINDOW, lat, lon, n_feats)`` inputs.
    """
    inputs = Input(shape=(INPUT_WINDOW, lat, lon, n_feats))

    # The first layer keeps the full sequence; the second collapses it to its last state.
    hidden = ConvLSTM2D(
        filters=32, kernel_size=(3, 3), padding='same', return_sequences=True
    )(inputs)
    hidden = ConvLSTM2D(
        filters=16, kernel_size=(3, 3), padding='same', return_sequences=False
    )(hidden)

    outputs = _spatial_head(hidden)
    return Model(inputs, outputs, name='ConvLSTM')

def build_conv_gru(n_feats: int):
    """Build a two-layer convolutional GRU model, or a ConvLSTM2D stand-in.

    When ``HAS_CONVGRU`` is true the recurrent layers are ``ConvGRU2D`` and the
    model is named "ConvGRU"; otherwise ``ConvLSTM2D`` is used, a notice is
    printed, and the model is named "ConvGRU_alt".

    Args:
        n_feats: Number of input feature channels per grid cell.

    Returns:
        An uncompiled Keras ``Model`` taking ``(INPUT_WINDOW, lat, lon, n_feats)`` inputs.
    """
    inputs = Input(shape=(INPUT_WINDOW, lat, lon, n_feats))

    # Pick the recurrent layer class once; both branches share the same topology.
    if HAS_CONVGRU:
        recurrent_layer = ConvGRU2D
        model_name = "ConvGRU"
    else:
        print("  → Usando ConvLSTM2D como alternativa a ConvGRU2D")
        recurrent_layer = ConvLSTM2D
        model_name = "ConvGRU_alt"

    hidden = recurrent_layer(32, (3, 3), padding="same", return_sequences=True)(inputs)
    hidden = recurrent_layer(16, (3, 3), padding="same", return_sequences=False)(hidden)

    return Model(inputs, _spatial_head(hidden), name=model_name)

def build_conv_rnn(n_feats:int):
    """Build the 'ConvRNN' baseline: a SimpleRNN over per-step flattened grids.

    Each time step's ``(lat, lon, n_feats)`` grid is flattened into one feature
    vector, a SimpleRNN consumes the resulting sequence, and a dense layer
    expands its final state into ``HORIZON`` full-grid maps.

    Args:
        n_feats: Number of input feature channels per grid cell.

    Returns:
        An uncompiled Keras ``Model`` mapping ``(INPUT_WINDOW, lat, lon, n_feats)``
        inputs to ``(HORIZON, lat, lon, 1)`` outputs.
    """
    inp = Input(shape=(INPUT_WINDOW, lat, lon, n_feats))
    # SimpleRNN needs a 3-D (batch, time, features) input. A plain Flatten() would
    # also merge the time axis and make the layer reject its 2-D input, so only
    # the spatial and feature axes are collapsed here.
    x = Reshape((INPUT_WINDOW, lat * lon * n_feats))(inp)
    x = SimpleRNN(64, activation='tanh')(x)
    x = Dense(lat * lon * HORIZON)(x)
    out = Reshape((HORIZON, lat, lon, 1))(x)
    return Model(inp, out, name='ConvRNN')

# Model name → builder; every entry is trained once per experiment.
# NOTE(review): build_conv_gru is defined in this file but not registered here — confirm intent.
MODELS = {'ConvLSTM': build_conv_lstm, 'ConvRNN': build_conv_rnn}

# ───────────────────────── TRAIN + EVAL LOOP ─────────────────────────
# For each feature set: build windows, standardise with training-only statistics,
# train every model from scratch, then plot and score the last validation window.
results = []
for exp, feat_list in EXPERIMENTS.items():
    print(f"\n=== Experimento {exp} ({len(feat_list)} feats) ===")
    # Predictors stacked as (time, latitude, longitude, variable); the target gets a
    # trailing singleton channel axis.
    Xarr = ds[feat_list].to_array().transpose('time','latitude','longitude','variable').values.astype(np.float32)
    yarr = ds['total_precipitation'].values.astype(np.float32)[..., None]
    X, y = windowed_arrays(Xarr, yarr)

    # Chronological 80/20 split (no shuffling).
    split = int(0.8 * len(X))
    if split == 0:
        # Fewer than two valid windows: nothing to fit the scalers or the models on.
        print(f"  ⚠️ Ventanas insuficientes para {exp}; se omite el experimento")
        continue

    # Scalers are fitted on the training windows only, then applied to all windows.
    sx = StandardScaler().fit(X[:split].reshape(-1, len(feat_list)))
    sy = StandardScaler().fit(y[:split].reshape(-1, 1))
    X_sc = sx.transform(X.reshape(-1, len(feat_list))).reshape(X.shape)
    y_sc = sy.transform(y.reshape(-1, 1)).reshape(y.shape)
    X_tr, X_va = X_sc[:split], X_sc[split:]
    y_tr, y_va = y_sc[:split], y_sc[split:]

    OUT_EXP = OUT_ROOT/exp
    OUT_EXP.mkdir(exist_ok=True)

    for mdl_name, builder in MODELS.items():
        print(f"→ {mdl_name}")
        model_path = OUT_EXP/f"{mdl_name.lower()}_best.keras"
        if model_path.exists():
            model_path.unlink()  # drop the checkpoint left by a previous run

        try:
            # Always build and train from scratch.
            model = builder(n_feats=len(feat_list))
            model.compile(tf.keras.optimizers.Adam(LR), 'mse')
            cb = [EarlyStopping('val_loss', patience=PATIENCE, restore_best_weights=True),
                  ModelCheckpoint(model_path, save_best_only=True)]
            model.fit(X_tr, y_tr, validation_data=(X_va, y_va),
                      epochs=EPOCHS, batch_size=BATCH, callbacks=cb, verbose=0)

            # ─ Predict the last validation window, back in original units ─
            y_hat_sc = model.predict(X_va[-1:], verbose=0)
            y_hat = sy.inverse_transform(y_hat_sc.reshape(-1, 1)).reshape(HORIZON, lat, lon)
            y_true = sy.inverse_transform(y_va[-1:].reshape(-1, 1)).reshape(HORIZON, lat, lon)

            # ─ Maps & GIF ─
            vmin, vmax = 0, max(y_true.max(), y_hat.max())
            frames = []
            # NOTE(review): labels assume the last window targets the final HORIZON time
            # steps of `ds` (not the case if that window was dropped for NaNs) and that
            # the series is monthly (freq='MS') — confirm.
            dates = pd.date_range(ds.time.values[-HORIZON], periods=HORIZON, freq='MS')
            for h in range(HORIZON):
                # Absolute percentage error per cell, clipped to [0, 100].
                err = np.clip(np.abs((y_true[h]-y_hat[h])/(y_true[h]+1e-5))*100, 0, 100)
                fig, axs = plt.subplots(1, 3, figsize=(12, 4), subplot_kw={'projection': ccrs.PlateCarree()})
                quick_plot(axs[0], y_true[h], 'Blues', f"Real h={h+1}", vmin, vmax)
                quick_plot(axs[1], y_hat[h], 'Blues', f"{mdl_name} h={h+1}", vmin, vmax)
                quick_plot(axs[2], err, 'Reds', f"MAPE% h={h+1}", 0, 100)
                fig.suptitle(f"{mdl_name} – {exp} – {dates[h].strftime('%Y-%m')}")
                png = OUT_EXP/f"{mdl_name}_{h+1}.png"
                fig.savefig(png, bbox_inches='tight')
                plt.close(fig)
                frames.append(imageio.imread(png))
            imageio.mimsave(OUT_EXP/f"{mdl_name}.gif", frames, fps=0.5)

            # ─ Metrics (one row per horizon step, last validation window only) ─
            for h in range(HORIZON):
                results.append({
                    'Experiment': exp, 'Model': mdl_name, 'H': h+1,
                    'RMSE': np.sqrt(mean_squared_error(y_true[h].ravel(), y_hat[h].ravel())),
                    'MAE': mean_absolute_error(y_true[h].ravel(), y_hat[h].ravel()),
                    'R2': r2_score(y_true[h].ravel(), y_hat[h].ravel())
                })

        except Exception as e:
            # Best effort: one failing model must not abort the remaining runs.
            print(f"  ⚠️ Error en {mdl_name}: {str(e)}")
            print(f"  → Saltando {mdl_name} para {exp}")
        finally:
            # Release Keras state after every attempt, including failed ones, so a
            # crashed model does not leak its graph/memory into the next run.
            tf.keras.backend.clear_session()
            gc.collect()

# ───────────────────────── FINAL CSV ─────────────────────────
res_df = pd.DataFrame(results)
res_df.to_csv(OUT_ROOT/'metrics_spatial.csv', index=False)
print("\n📑 Metrics saved →", OUT_ROOT/'metrics_spatial.csv')
