
# Spatiotemporal Precipitation Prediction
**5 × 5 Experiments Notebook**  
Train & validate five model configurations across five temporal folds (48 m train → 12 m val). Designed to run **locally or on Google Colab** — auto-detects GPU/CPU and adapts parallelism.

In [1]:
# ▶️ Environment setup (PyTorch + TF + XGBoost)
import sys, subprocess, importlib, os, multiprocessing, logging, warnings, json
basic_pkgs = ["torch","torchvision","torchaudio","pytorch-lightning",
              "xarray","netcdf4","scikit-learn","tqdm","xgboost",
              "tensorflow","geopandas","cartopy","torchmetrics","pytorch_lightning"]
def _install(pkg):
    if importlib.util.find_spec(pkg) is None:
        print(f"Installing {pkg} ...")
        subprocess.check_call([sys.executable,"-m","pip","install","-q",pkg])
# Ensure every required package is present before the imports below.
for pkg_name in basic_pkgs:
    _install(pkg_name)

import torch, xarray as xr, numpy as np, pandas as pd, pytorch_lightning as pl
import tensorflow as tf, geopandas as gpd, cartopy.crs as ccrs
from sklearn.preprocessing import RobustScaler, StandardScaler
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import OneCycleLR
from torch import nn
from tqdm.auto import tqdm
# ▶️ Plotting imports for learning curves and prediction maps
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.gridspec import GridSpec
import numpy as np
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from sklearn.metrics import mean_absolute_percentage_error
from pathlib import Path

# Compute resources: prefer CUDA when available.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
N_GPU = torch.cuda.device_count()
CPU_CORES = multiprocessing.cpu_count()
# Half of the CPU cores feed DataLoader workers, but never fewer than one.
NUM_WORKERS = max(1, CPU_CORES // 2)
print(f"✅ Torch device: {DEVICE} | GPUs: {N_GPU} | CPU cores: {CPU_CORES} | Workers: {NUM_WORKERS}")


# ▶️ Project root resolution: Google Drive folder on Colab, otherwise the
# nearest ancestor of the working directory that contains a .git entry.
from pathlib import Path

IN_COLAB = 'google.colab' in sys.modules
if not IN_COLAB:
    BASE_PATH = Path.cwd()
    # First match wins (the cwd itself counts); fall back to the cwd.
    BASE_PATH = next(
        (cand for cand in (BASE_PATH, *BASE_PATH.parents) if (cand / '.git').exists()),
        BASE_PATH,
    )
else:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)
    BASE_PATH = Path('/content/drive/MyDrive/ml_precipitation_prediction')
print('BASE_PATH =', BASE_PATH)

# Centralised data / model / image paths, all derived from BASE_PATH.
# MODEL_DIR and IMAGE_DIR are created eagerly so later saves cannot fail on a missing folder.
DATA_DIR = BASE_PATH / 'data' / 'output'

MODEL_DIR = BASE_PATH / 'models' / 'output' / 'trained_models'
MODEL_DIR.mkdir(parents=True, exist_ok=True)

IMAGE_DIR = MODEL_DIR / 'images'
IMAGE_DIR.mkdir(exist_ok=True)

FEATURES_NC = BASE_PATH / 'models' / 'output' / 'features_fusion_branches.nc'
FULL_NC = DATA_DIR / 'complete_dataset_with_features_with_clusters_elevation_with_windows.nc'
PRECIP_NC = DATA_DIR / 'precip_topo.nc'  # input for the PyTorch pipeline below

print('Using FULL_NC  :', FULL_NC)
print('Using FEATURES :', FEATURES_NC)



# Fold id -> validation calendar year.
FOLDS = dict(F1=2024, F2=2023, F3=2022, F4=2000, F5=1990)

# Experiment name -> model-factory key and whether lag features are requested.
EXPERIMENTS = {
    name: {'model': model_key, 'use_lags': use_lags}
    for name, model_key, use_lags in (
        ('GRU-ED', 'gru_ed', False),
        ('GRU-ED-PAFC', 'gru_ed', True),
        ('AE-FUSION-GRU-ED-PAFC', 'ae_fusion_gru', True),
        ('AE-FUSION-GRU-ED-PAFC-T', 'ae_fusion_gru_t', True),
        ('AE-FUSION-GRU-ED-PAFC-T-TopoMask', 'ae_fusion_gru_t_mask', True),
    )
}

# ▶️ Feature-name lists
FULL_FEATURES = [
    # precipitation history and lags
    'precip_hist', 'lag_1', 'lag_2', 'lag_12',
    # cyclic calendar encodings
    'month_sin', 'month_cos', 'doy_sin', 'doy_cos',
    # topography
    'elevation', 'slope', 'roughness', 'curvature', 'aspect',
    # elevation cluster and signal-decomposition components
    'alt_cluster', 'ceemdan_imf1', 'ceemdan_imf2', 'ceemdan_imf3',
    'tvfemd_imf1', 'tvfemd_imf2', 'tvfemd_imf3',
]

BASE_FEATURES = [
    'precip_hist', 'lag_1', 'lag_2', 'lag_12',
    'month_sin', 'month_cos', 'doy_sin', 'doy_cos',
    'elevation', 'slope', 'roughness', 'curvature',
    'alt_cluster',
]


# ▶️ Helper functions
import pandas as pd, numpy as np
def add_time_encodings(ds: xr.Dataset):
    """Attach cyclic calendar features along the ``time`` dimension of *ds*.

    Adds ``month_sin``/``month_cos`` (period 12 months) and ``doy_sin``/``doy_cos``
    (period 365.25 days). The variables are assigned onto *ds* itself, which is
    also returned for chaining.
    """
    stamps = pd.to_datetime(ds['time'].values)
    cycles = (
        ('month', stamps.month, 12),
        ('doy', stamps.dayofyear, 365.25),
    )
    for prefix, values, period in cycles:
        angle = 2 * np.pi * values / period
        ds[f'{prefix}_sin'] = ('time', np.sin(angle))
        ds[f'{prefix}_cos'] = ('time', np.cos(angle))
    return ds

# ▶️ Shared logger: INFO level, "HH:MM:SS [LEVEL] message" lines.
logging.basicConfig(
    format='%(asctime)s [%(levelname)s] %(message)s',
    datefmt='%H:%M:%S',
    level=logging.INFO,
)
logger = logging.getLogger('precip')

def print_progress(msg, level=0, is_start=False, is_end=False):
    """Print *msg* with a glyph/indent prefix that marks its nesting level.

    Level 0 uses a blue circle when ``is_start``, a check mark when ``is_end``,
    and a right arrow otherwise. Level 1 is indented with a white circle and
    level 2 with a bullet. Any other level prints *msg* without a prefix.
    (The glyphs were previously stored as mis-decoded UTF-8 and printed as garbage.)
    """
    if level == 0:
        if is_start:
            prefix = '\U0001F535 '      # 🔵
        elif is_end:
            prefix = '\u2705 '          # ✅
        else:
            prefix = '\u27A1\uFE0F '    # ➡️
    else:
        prefix = {1: '  \u26AA ',       # ⚪
                  2: '    \u2022 '}.get(level, '')  # •
    print(f'{prefix}{msg}')

# Settings shared by the PyTorch pipeline below (dataset path: PRECIP_NC).
# FOLDS (fold id -> validation year) is defined once in the experiment
# configuration above; redefining an identical copy here only invited drift.
DATASET_PATH = str(PRECIP_NC)
INPUT_WINDOW = 48  # encoder history length in time steps (monthly data assumed)
HORIZON = 12       # decoder forecast length in time steps
BATCH_SIZE = 32
print_progress('⚠️   PyTorch quick baseline section trimmed for brevity — insert from earlier if desired', level=1)

# ▶️ Verify precipitation lags utility
def verify_precipitation_lags(ds, required_lags=None, min_valid_ratio=0.9):
    """Check that lagged-precipitation variables in *ds* are sufficiently populated.

    Args:
        ds: dataset-like object supporting ``name in ds.data_vars`` and
            ``ds[name].values`` (e.g. an ``xarray.Dataset``).
        required_lags: explicit variable names to check. When falsy, every
            ``total_precipitation_lag{1,2,3,4,12,24,36}`` present in
            ``ds.data_vars`` is checked.
        min_valid_ratio: minimum fraction of non-NaN values each lag must have.

    Raises:
        ValueError: if no lag variables are found, a required lag is missing
            (previously a bare KeyError), a lag is empty (previously
            ZeroDivisionError), or a lag falls below *min_valid_ratio*.
    """
    # Same logger object as the module-level `logger` (both are 'precip').
    log = logging.getLogger('precip')
    all_possible = [f"total_precipitation_lag{i}" for i in (1, 2, 3, 4, 12, 24, 36)]
    lags = list(required_lags) if required_lags else [name for name in all_possible if name in ds.data_vars]
    if not lags:
        raise ValueError('No lag variables found.')
    for lag in lags:
        try:
            arr = np.asarray(ds[lag].values)
        except KeyError as err:
            raise ValueError(f'Lag variable {lag!r} not found in dataset') from err
        if arr.size == 0:
            raise ValueError(f'{lag} is empty')
        ratio = np.count_nonzero(~np.isnan(arr)) / arr.size
        log.info(f'{lag}: {ratio:.1%} valid')
        if ratio < min_valid_ratio:
            raise ValueError(f'{lag} has only {ratio:.1%} valid data (<{min_valid_ratio})')
    log.info('Lag verification ✅')

# ▶️ NaN-robust scaling utils
def check_nans(arr, name='array'):
    """Summarise the NaN content of *arr*.

    Returns:
        dict with keys ``name``, ``nan`` (NaN count), ``total`` (element count),
        ``pct`` (NaN percentage; 0.0 for an empty array instead of a NaN from
        0/0) and ``has`` (True if any NaN). Counts are plain Python ints.
    """
    values = np.asarray(arr)
    nan_cnt = int(np.isnan(values).sum())
    tot = int(values.size)
    pct = nan_cnt / tot * 100 if tot else 0.0
    return {'name': name, 'nan': nan_cnt, 'total': tot, 'pct': pct, 'has': nan_cnt > 0}

def replace_nans(arr, strategy='mean'):
    """Return *arr* with NaNs filled; the input is never modified.

    Args:
        arr: array-like of numbers.
        strategy: ``'mean'`` or ``'median'`` fill with the NaN-ignoring
            statistic of the whole array. Any other value delegates to
            ``np.nan_to_num`` (NaN -> 0, ±inf -> large finite values).

    Returns:
        *arr* itself when it contains no NaN, otherwise a filled copy.
        An all-NaN input is filled with 0.0: its mean/median is itself NaN,
        which previously left every value NaN.
    """
    if not np.isnan(arr).any():
        return arr
    out = np.array(arr, copy=True)  # also accepts lists, unlike arr.copy()[mask]
    nan_mask = np.isnan(out)
    if strategy in ('mean', 'median'):
        if nan_mask.all():
            fill = 0.0
        else:
            fill = np.nanmean(out) if strategy == 'mean' else np.nanmedian(out)
        out[nan_mask] = fill
    else:
        out = np.nan_to_num(out)
    return out

class ScalerNaN:
    """Column-wise z-score scaler whose statistics ignore NaNs.

    Fitted attributes:
        mean_: per-column NaN-ignoring mean along axis 0 (0 for all-NaN columns).
        scale_: per-column NaN-ignoring standard deviation; columns whose
            variance is < 1e-9 (constant) or NaN (all-NaN) get scale 1.
    NaN inputs stay NaN after ``transform``.
    """

    def fit(self, X):
        """Compute ``mean_`` and ``scale_`` along axis 0 of *X* and return self."""
        X = np.asarray(X)
        with warnings.catch_warnings():
            # All-NaN columns make nanmean/nanvar warn; they are handled below.
            warnings.simplefilter('ignore', RuntimeWarning)
            # ndmin=1: for 1-D X these reductions return numpy scalars, on which
            # the masked assignments below raised TypeError.
            mean = np.array(np.nanmean(X, 0), ndmin=1)
            var = np.array(np.nanvar(X, 0), ndmin=1)
        mean[np.isnan(mean)] = 0
        var[np.isnan(var) | (var < 1e-9)] = 1
        self.mean_ = mean
        self.scale_ = np.sqrt(var)
        return self

    def transform(self, X):
        """Standardise *X* with the fitted statistics."""
        return (X - self.mean_) / self.scale_

    def fit_transform(self, X):
        """Fit on *X*, then return its standardised version."""
        self.fit(X)
        return self.transform(X)

    def inverse_transform(self, X):
        """Map standardised values back to the original units."""
        return X * self.scale_ + self.mean_

# ▶️ Dataset & DataLoader builder
class PrecipDataset(Dataset):
    """Per-pixel sliding-window samples for sequence-to-sequence forecasting.

    Item ``i`` is built from ``idx_list[i] = (t, y, x)``: the input window covers
    time indices ``[t - input_window, t)`` and the target covers
    ``[t, t + horizon)`` at grid cell ``(y, x)``.

    Returns ``(X, target)`` where ``X`` is ``float32`` of shape
    ``(input_window, n_channels)``:
      * channel 0 — ``precip`` transformed by ``sc_p``;
      * one channel per name in ``features`` that exists in the window, transformed
        by ``sc_x``. Names starting with ``lag`` and ``precip_hist`` are skipped.
        Features with a single value in the window (static fields such as
        topography) are repeated across all time steps.
    Previously all features were concatenated into one flat 1-D vector, which
    the batch-first GRU models cannot consume as a sequence.

    ``target`` is the raw (unscaled) ``precip`` of length ``horizon``.
    NOTE(review): inputs are scaled but targets are not, and the decoder feeds
    its own outputs / teacher-forced targets back next to scaled inputs —
    confirm the intended target scaling.
    """

    def __init__(self, ds, idx_list, input_window, horizon,
                 sc_p, sc_x, features):
        self.ds = ds
        self.idx = idx_list
        self.w = input_window
        self.h = horizon
        self.scp = sc_p
        self.scx = sc_x
        self.features = features

    def __len__(self):
        return len(self.idx)

    def __getitem__(self, i):
        t, y, x = self.idx[i]
        win = self.ds.isel(time=slice(t - self.w, t), y=y, x=x)
        tgt = self.ds.isel(time=slice(t, t + self.h), y=y, x=x)['precip'].values.astype(np.float32)

        columns = [self.scp.transform(np.asarray(win['precip'].values).reshape(-1, 1)).flatten()]
        for var in self.features:
            if var.startswith('lag') or var == 'precip_hist':  # precip already is channel 0
                continue
            if var in win:
                arr = self.scx.transform(np.asarray(win[var].values).reshape(-1, 1)).flatten()
                columns.append(self._to_window(arr, var))

        X = np.stack(columns, axis=1).astype(np.float32)  # (time, channel)
        return torch.tensor(X), torch.tensor(tgt)

    def _to_window(self, arr, var):
        """Return *arr* as one value per input time step (broadcast single values)."""
        if arr.size == self.w:
            return arr
        if arr.size == 1:
            return np.full(self.w, arr[0], dtype=arr.dtype)
        raise ValueError(f"Feature {var!r} has {arr.size} values in the window; expected 1 or {self.w}")

def build_dataloaders(val_year, use_lags, batch_size=BATCH_SIZE):
    """Build train/validation DataLoaders for one temporal fold.

    The training period is 1 Jan ``val_year-4`` .. 31 Dec ``val_year-1`` and the
    validation period is calendar year ``val_year``. A window starting at time
    index ``t`` belongs to a period when its last target step
    (``t + HORIZON - 1``) falls inside that period.

    Scalers are fit on training-period values only: a RobustScaler for
    ``precip`` and a single StandardScaler pooled over the auxiliary predictors.

    Args:
        val_year: validation calendar year.
        use_lags: keep ``lag_*`` names in the feature list (see NOTE below).
        batch_size: DataLoader batch size.

    Returns:
        ``(train_loader, val_loader, in_dim)``. ``in_dim`` counts the feature
        columns PrecipDataset builds: scaled precip plus every non-lag,
        non-``precip_hist`` feature present in the dataset. (``len(feats)`` was
        returned before, which also counted lag names that never become inputs.)

    Raises:
        ValueError: if none of the auxiliary predictor variables are in the dataset.
    """
    # Kept open: the datasets below read from it on demand.
    ds = xr.open_dataset(DATASET_PATH)
    ds = add_time_encodings(ds)

    train_start = np.datetime64(f'{val_year-4}-01-01')
    train_end = np.datetime64(f'{val_year-1}-12-31')
    val_start = np.datetime64(f'{val_year}-01-01')
    val_end = np.datetime64(f'{val_year}-12-31')

    train_mask = (ds['time'] >= train_start) & (ds['time'] <= train_end)
    val_mask = (ds['time'] >= val_start) & (ds['time'] <= val_end)

    sc_p = RobustScaler().fit(ds['precip'].where(train_mask).values.reshape(-1, 1))

    aux_vars = ['month_sin', 'month_cos', 'doy_sin', 'doy_cos', 'elevation',
                'slope', 'roughness', 'curvature', 'alt_cluster']
    train_values = [ds[var].where(train_mask).values.flatten()
                    for var in aux_vars if var in ds.data_vars]
    if not train_values:
        raise ValueError(f'None of the predictor variables {aux_vars} exist in {DATASET_PATH}')
    # NOTE(review): one scaler pooled over all predictors lets large-valued
    # variables (e.g. elevation) dominate the mean/std; per-variable scalers
    # would need a matching change in PrecipDataset.
    sc_x = StandardScaler().fit(np.concatenate(train_values).reshape(-1, 1))

    n_time = ds.sizes['time']
    n_y, n_x = ds.sizes['y'], ds.sizes['x']

    def make_idx(mask):
        """All (t, y, x) whose last target step t + HORIZON - 1 lies inside *mask*."""
        in_period = np.asarray(mask)
        # t needs INPUT_WINDOW steps of history and HORIZON target steps:
        # INPUT_WINDOW <= t <= n_time - HORIZON (inclusive, hence the +1 that
        # was missing and dropped the final window, e.g. Dec of the last year).
        return [(t, y, x)
                for t in range(INPUT_WINDOW, n_time - HORIZON + 1)
                if in_period[t + HORIZON - 1]
                for y in range(n_y)
                for x in range(n_x)]

    train_idx = make_idx(train_mask)
    val_idx = make_idx(val_mask)

    feats = BASE_FEATURES.copy()
    if not use_lags:
        feats = [f for f in feats if not f.startswith('lag')]
    # NOTE(review): PrecipDataset skips every 'lag*' name, so use_lags does not
    # currently change the model inputs — TODO wire the real lag variables in.
    in_dim = 1 + sum(1 for f in feats
                     if not (f.startswith('lag') or f == 'precip_hist') and f in ds)

    train_ds = PrecipDataset(ds, train_idx, INPUT_WINDOW, HORIZON, sc_p, sc_x, feats)
    val_ds = PrecipDataset(ds, val_idx, INPUT_WINDOW, HORIZON, sc_p, sc_x, feats)

    # Pinned host memory only speeds up host->GPU copies; on CPU it just warns.
    pin = DEVICE.type == 'cuda'
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=NUM_WORKERS, pin_memory=pin)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
                            num_workers=NUM_WORKERS, pin_memory=pin)
    return train_loader, val_loader, in_dim


# ▶️ Model definitions
class GRUEncoderDecoder(nn.Module):
    """Sequence-to-sequence GRU: encode an input window, decode ``horizon`` steps one at a time.

    Args:
        input_dim: input channels per time step; channel 0 seeds the decoder.
        hidden_size: GRU hidden size (encoder and decoder).
        num_layers: stacked GRU layers (encoder and decoder).
        dropout: inter-layer GRU dropout (only meaningful when num_layers > 1).
        horizon: number of decoded steps.
    """

    def __init__(self, input_dim, hidden_size=128, num_layers=2, dropout=0.2, horizon=HORIZON):
        super().__init__()
        # nn.GRU applies dropout only between stacked layers; with a single layer
        # it has no effect and only triggers a UserWarning.
        rnn_dropout = dropout if num_layers > 1 else 0.0
        self.enc = nn.GRU(input_dim, hidden_size, num_layers,
                          batch_first=True, dropout=rnn_dropout)
        self.dec = nn.GRU(1, hidden_size, num_layers,
                          batch_first=True, dropout=rnn_dropout)
        self.fc = nn.Linear(hidden_size, 1)
        self.hor = horizon

    def forward(self, x, teacher_forcing_ratio=0.5, y=None):
        """Return predictions of shape ``(batch, horizon)``.

        Args:
            x: ``(batch, seq, input_dim)`` input window (batch_first GRU).
            teacher_forcing_ratio: per-step probability of feeding the true
                value instead of the prediction (training mode with ``y`` only).
            y: optional ``(batch, horizon)`` targets used for teacher forcing.
        """
        _, h = self.enc(x)
        dec_in = x[:, -1:, 0:1]  # last observed value of channel 0: (batch, 1, 1)
        outs = []
        for t in range(self.hor):
            o, h = self.dec(dec_in, h)
            pred = self.fc(o.squeeze(1))  # (batch, 1)
            outs.append(pred)
            if self.training and y is not None and torch.rand(1).item() < teacher_forcing_ratio:
                dec_in = y[:, t:t+1].unsqueeze(-1)
            else:
                dec_in = pred.unsqueeze(1)
        # Concatenate along dim 1 -> (batch, horizon), matching the target layout.
        # torch.stack gave (batch, horizon, 1): the Huber loss then broadcast it
        # against (batch, horizon) targets and torchmetrics' MSE raised.
        return torch.cat(outs, dim=1)

# Implementation aligned with documentation
class Conv3DAutoEncoder(nn.Module):
    """3-D convolutional encoder mapping a volume to a ``bottleneck_dim`` vector.

    Only the encoder half is defined here (there is no decoder path). Input is
    ``(batch, in_channels, D, H, W)`` with ``(D, H, W)`` equal to ``input_size``.
    The final Linear input size is derived from it: each of the two
    ``MaxPool3d(2)`` stages floor-halves every spatial dimension.

    Args:
        in_channels: input channels.
        bottleneck_dim: size of the output embedding.
        input_size: spatial size, either an int (cube) or a ``(D, H, W)`` tuple.
            The default 24 gives the 64*6*6*6 Linear input used so far.

    Raises:
        ValueError: if ``input_size`` is not 3-D or pools down to zero.
    """

    def __init__(self, in_channels=3, bottleneck_dim=64, input_size=24):
        super().__init__()
        dims = (input_size,) * 3 if isinstance(input_size, int) else tuple(input_size)
        if len(dims) != 3:
            raise ValueError(f"input_size must be an int or a 3-tuple, got {input_size!r}")
        pooled = [(d // 2) // 2 for d in dims]
        if min(pooled) < 1:
            raise ValueError(f"input_size {dims} is too small for two 2x max-pool stages")
        flat_dim = 64 * pooled[0] * pooled[1] * pooled[2]
        self.encoder = nn.Sequential(
            nn.Conv3d(in_channels, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2),
            nn.Conv3d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2),
            nn.Conv3d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(flat_dim, bottleneck_dim),
        )

    def forward(self, x):
        """Encode *x* into a ``(batch, bottleneck_dim)`` tensor."""
        return self.encoder(x)

class AEFusionGRU(nn.Module):
    """GRU encoder-decoder whose inputs are widened with a 64-d autoencoder embedding.

    NOTE(review): placeholder — ``self.ae`` is constructed but never called; the
    embedding is all zeros, so this currently equals a GRUEncoderDecoder fed the
    real ``input_dim`` channels plus 64 constant-zero channels.
    """

    def __init__(self, input_dim, hidden_size=128, num_layers=2, dropout=0.2, horizon=HORIZON):
        super().__init__()
        ae_dim = 64
        self.ae = Conv3DAutoEncoder(in_channels=3, bottleneck_dim=ae_dim)
        # Backbone sees the original channels followed by the embedding channels.
        self.backbone = GRUEncoderDecoder(input_dim + ae_dim, hidden_size, num_layers, dropout, horizon)

    def forward(self, x, teacher_forcing_ratio=0.5, y=None):
        """Append the (zero) embedding to every time step of *x* and run the backbone."""
        batch, seq_len = x.size(0), x.size(1)
        embedding = torch.zeros((batch, 64), device=x.device)
        tiled = embedding.unsqueeze(1).expand(-1, seq_len, -1)
        fused = torch.cat([x, tiled], dim=2)
        return self.backbone(fused, teacher_forcing_ratio, y)

class MultiHeadAttentionLayer(nn.Module):
    """Self-attention block: multi-head attention, residual connection, then LayerNorm.

    NOTE(review): ``nn.MultiheadAttention`` defaults to ``batch_first=False``, so
    inputs are read as ``(seq, batch, hidden_dim)``; the GRUs in this file are
    batch_first — confirm the layout when wiring this layer in.
    """

    def __init__(self, hidden_dim, n_heads=4, dropout=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(hidden_dim, n_heads, dropout=dropout)
        # PyTorch's layer-norm module is nn.LayerNorm; nn.LayerNormalization does
        # not exist and made every attention model fail at construction.
        self.norm = nn.LayerNorm(hidden_dim)

    def forward(self, x, mask=None):
        """Return ``LayerNorm(x + SelfAttention(x))``; *mask* is passed as ``attn_mask``."""
        attn_out, _ = self.attention(x, x, x, attn_mask=mask)
        return self.norm(x + attn_out)

class AEFusionGRUT(AEFusionGRU):
    """AEFusionGRU plus a multi-head self-attention layer sized to ``hidden_size``.

    NOTE(review): placeholder — ``self.attention`` is built but ``forward`` does
    not use it yet, so outputs are identical to AEFusionGRU.
    """

    def __init__(self, input_dim, hidden_size=128, num_layers=2, dropout=0.2, horizon=HORIZON):
        super().__init__(input_dim, hidden_size=hidden_size, num_layers=num_layers,
                         dropout=dropout, horizon=horizon)
        self.attention = MultiHeadAttentionLayer(hidden_size, n_heads=4, dropout=dropout)

    def forward(self, x, teacher_forcing_ratio=0.5, y=None):
        """Delegate to the parent forward (attention not wired in yet)."""
        return super().forward(x, teacher_forcing_ratio=teacher_forcing_ratio, y=y)

class AEFusionGRUTMask(AEFusionGRUT):
    """AEFusionGRUT variant intended to use a causal attention mask.

    NOTE(review): placeholder — no mask is applied yet; behaviour is identical
    to AEFusionGRUT.
    """

    def __init__(self, input_dim, hidden_size=128, num_layers=2, dropout=0.2, horizon=HORIZON):
        super().__init__(input_dim, hidden_size=hidden_size, num_layers=num_layers,
                         dropout=dropout, horizon=horizon)

    def forward(self, x, teacher_forcing_ratio=0.5, y=None):
        """Delegate to the parent forward (causal masking not implemented yet)."""
        return super().forward(x, teacher_forcing_ratio=teacher_forcing_ratio, y=y)

# Registry: EXPERIMENTS 'model' key -> model class (called as cls(in_dim, dropout=...)).
MODEL_FACTORY = dict(
    gru_ed=GRUEncoderDecoder,
    ae_fusion_gru=AEFusionGRU,
    ae_fusion_gru_t=AEFusionGRUT,
    ae_fusion_gru_t_mask=AEFusionGRUTMask,
)


# ▶️ Training utilities
from torchmetrics.functional import mean_squared_error
def huber_weighted(preds, target):
    """Mean Huber loss with a weight that grows linearly with the forecast step.

    Step ``h`` (1-based, along dim 1 of *target*) is weighted by ``1 + h/12``,
    so later horizons count more.

    If *preds* has the same number of elements as *target* but a different shape
    (e.g. decoder output ``(B, H, 1)`` vs target ``(B, H)``), it is reshaped to
    *target*'s shape. Otherwise the two broadcast to ``(B, H, H)`` and the loss
    silently compared every step with every other step.

    Raises:
        ValueError: if *preds* and *target* hold different numbers of elements.
    """
    if preds.shape != target.shape:
        if preds.numel() != target.numel():
            raise ValueError(f"preds {tuple(preds.shape)} incompatible with target {tuple(target.shape)}")
        preds = preds.reshape(target.shape)
    h = torch.arange(1, target.size(1) + 1, device=preds.device, dtype=preds.dtype)
    weights = 1 + h / 12.0
    # Align weights with dim 1 whatever trailing dims the target has.
    weights = weights.view(1, -1, *([1] * (target.dim() - 2)))
    loss = torch.nn.functional.huber_loss(preds, target, reduction='none')
    return (loss * weights).mean()

def train_one_epoch(model, loader, opt, tf_ratio, scheduler=None):
    """Run one optimisation pass over *loader* and return the mean batch loss.

    Uses the horizon-weighted Huber loss with the given teacher-forcing ratio.
    When *scheduler* is given it is stepped after every batch.
    """
    model.train()
    batch_losses = []
    for inputs, targets in loader:
        inputs = inputs.to(DEVICE)
        targets = targets.to(DEVICE)
        output = model(inputs, teacher_forcing_ratio=tf_ratio, y=targets)
        batch_loss = huber_weighted(output, targets)
        opt.zero_grad()
        batch_loss.backward()
        opt.step()
        if scheduler:
            scheduler.step()
        batch_losses.append(batch_loss.item())
    return np.mean(batch_losses)

@torch.no_grad()
def evaluate(model, loader):
    """Return the RMSE of *model* over every target element in *loader*.

    Runs in eval mode without teacher forcing. Squared errors are summed over
    the whole loader before the square root, so the result is the true dataset
    RMSE. The previous mean of per-batch RMSEs over-weighted a short final batch
    and tends to underestimate. Returns NaN for an empty loader.
    """
    model.eval()
    sq_err_sum = 0.0
    count = 0
    for X, y in loader:
        X, y = X.to(DEVICE), y.to(DEVICE)
        preds = model(X, teacher_forcing_ratio=0.0)
        # Decoder output may be (B, H, 1) while targets are (B, H).
        preds = preds.reshape(y.shape)
        sq_err_sum += torch.sum((preds.double() - y.double()) ** 2).item()
        count += y.numel()
    return float(np.sqrt(sq_err_sum / count)) if count else float('nan')

# Training loop that records per-epoch history for learning curves
def _align_to_target(preds, target):
    """Reshape *preds* to *target*'s shape when both hold the same number of elements."""
    if preds.shape != target.shape and preds.numel() == target.numel():
        return preds.reshape(target.shape)
    return preds


def _teacher_forcing_ratio(epoch, epochs):
    """Linear decay from 0.7 at epoch 1 to 0.3 at the last epoch (0.7 for a single epoch)."""
    if epochs <= 1:
        return 0.7
    return 0.7 - (epoch - 1) * (0.4) / (epochs - 1)


def _train_epoch_with_scheduler(model, loader, optimizer, scheduler, tf_ratio):
    """One training pass stepping *scheduler* after every batch; returns the mean batch loss."""
    model.train()
    losses = []
    for X, y in loader:
        X, y = X.to(DEVICE), y.to(DEVICE)
        optimizer.zero_grad()
        preds = _align_to_target(model(X, teacher_forcing_ratio=tf_ratio, y=y), y)
        loss = huber_weighted(preds, y)
        loss.backward()
        optimizer.step()
        scheduler.step()
        losses.append(loss.item())
    return float(np.mean(losses)) if losses else float('nan')


@torch.no_grad()
def _validate_epoch(model, loader):
    """Return (mean batch Huber loss, RMSE pooled over all target elements), no teacher forcing."""
    model.eval()
    losses = []
    sq_err_sum, n_elems = 0.0, 0
    for X, y in loader:
        X, y = X.to(DEVICE), y.to(DEVICE)
        preds = _align_to_target(model(X, teacher_forcing_ratio=0), y)
        losses.append(huber_weighted(preds, y).item())
        sq_err_sum += torch.sum((preds.double() - y.double()) ** 2).item()
        n_elems += y.numel()
    mean_loss = float(np.mean(losses)) if losses else float('nan')
    rmse = float(np.sqrt(sq_err_sum / n_elems)) if n_elems else float('nan')
    return mean_loss, rmse


def train_with_history(model, train_loader, val_loader, epochs=60, patience=20,
                       lr=1e-3, weight_decay=1e-4, fold='', exp_name=''):
    """Train *model* with AdamW + OneCycleLR, recording history for learning curves.

    Per epoch: teacher forcing decays linearly 0.7 -> 0.3, the scheduler steps
    every batch, and validation loss/RMSE are computed without teacher forcing.
    The weights with the lowest validation RMSE are checkpointed and restored
    at the end. Early stopping triggers after *patience* epochs without at
    least a 1 % RMSE improvement.

    Args:
        model: PyTorch model to train (already on DEVICE).
        train_loader: training DataLoader.
        val_loader: validation DataLoader.
        epochs: maximum number of epochs.
        patience: epochs without a >=1 % improvement before stopping.
        lr: peak learning rate for OneCycleLR.
        weight_decay: AdamW weight decay.
        fold: fold id (logs and file names).
        exp_name: experiment name (logs and file names).

    Returns:
        tuple: ``(model, history, best_rmse)``. *history* holds per-epoch lists
        ``train_loss``, ``val_loss``, ``val_rmse``, ``learning_rate`` and
        ``teacher_forcing``. If no epoch produced a finite improvement (e.g.
        NaN RMSE), the final weights are kept instead of crashing on
        ``load_state_dict(None)``.

    Side effects:
        Saves learning curves via ``plot_learning_curves`` and the state dict to
        ``MODEL_DIR / f"{exp_name}_{fold}_model.pt"``.
    """
    print_progress(f"Iniciando entrenamiento de {exp_name} en fold {fold}", is_start=True)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = OneCycleLR(optimizer, max_lr=lr, total_steps=epochs * len(train_loader),
                           pct_start=0.3, anneal_strategy='cos')

    history = {
        'train_loss': [],
        'val_loss': [],
        'val_rmse': [],
        'learning_rate': [],
        'teacher_forcing': []
    }

    best_rmse = float('inf')
    best_model_state = None
    reference_rmse = float('inf')  # last RMSE that reset the patience counter
    counter = 0

    for epoch in range(1, epochs + 1):
        tf_ratio = _teacher_forcing_ratio(epoch, epochs)
        history['teacher_forcing'].append(tf_ratio)

        epoch_train_loss = _train_epoch_with_scheduler(model, train_loader, optimizer, scheduler, tf_ratio)
        current_lr = scheduler.get_last_lr()[0]
        history['learning_rate'].append(current_lr)

        epoch_val_loss, epoch_val_rmse = _validate_epoch(model, val_loader)
        history['train_loss'].append(epoch_train_loss)
        history['val_loss'].append(epoch_val_loss)
        history['val_rmse'].append(epoch_val_rmse)

        print(f"Época {epoch}/{epochs} - Train loss: {epoch_train_loss:.4f} - Val RMSE: {epoch_val_rmse:.4f} - LR: {current_lr:.6f}")

        # Checkpoint on ANY improvement so the restored model really is the best one.
        if epoch_val_rmse < best_rmse:
            best_rmse = epoch_val_rmse
            best_model_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            print_progress(f"Época {epoch}: Nuevo mejor modelo con RMSE {best_rmse:.4f}", level=1)

        # Early stopping: a gain below 1 % (ΔRMSE < 1 %) counts as no progress.
        if epoch_val_rmse < reference_rmse * 0.99:
            reference_rmse = epoch_val_rmse
            counter = 0
        else:
            counter += 1

        if counter >= patience:
            print_progress(f"Early stopping en época {epoch}", level=1)
            break

    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    else:
        print_progress("Sin mejora de RMSE de validación; se conservan los pesos de la última época", level=1)

    plot_learning_curves(history, exp_name, fold)

    print_progress(f"Entrenamiento de {exp_name} en fold {fold} completado. Mejor RMSE: {best_rmse:.4f}", is_end=True)

    torch.save(model.state_dict(), MODEL_DIR / f"{exp_name}_{fold}_model.pt")

    return model, history, best_rmse

def plot_learning_curves(history, exp_name, fold):
    """Save a 2x2 diagnostics figure for one training run.

    Panels:
      1. training and validation loss;
      2. validation RMSE, with the best finite epoch marked;
      3. learning rate (log scale);
      4. teacher-forcing ratio.
    Each panel is plotted against the 1-based epoch number. The figure is saved to
    ``IMAGE_DIR/learning_curves/{exp_name}_{fold}_learning_curves.png``.

    Args:
        history: dict of per-epoch lists (``train_loss``, ``val_loss``,
            ``val_rmse``, ``learning_rate``, ``teacher_forcing``); missing or
            empty series are skipped.
        exp_name: experiment name (title and file name).
        fold: fold id (title and file name).
    """
    curves_dir = IMAGE_DIR / "learning_curves"
    curves_dir.mkdir(exist_ok=True, parents=True)
    out_path = curves_dir / f'{exp_name}_{fold}_learning_curves.png'

    def series(key):
        return list(history.get(key) or [])

    def epochs_of(values):
        # One entry per epoch, and training epochs are numbered from 1.
        return range(1, len(values) + 1)

    fig = plt.figure(figsize=(16, 12))
    gs = GridSpec(2, 2, figure=fig)

    # 1. Training and validation loss
    ax1 = fig.add_subplot(gs[0, 0])
    train_loss, val_loss = series('train_loss'), series('val_loss')
    if train_loss:
        ax1.plot(epochs_of(train_loss), train_loss, label='Entrenamiento', color='#3498db', linewidth=2)
    if val_loss:
        ax1.plot(epochs_of(val_loss), val_loss, label='Validación', color='#e74c3c', linewidth=2)
    ax1.set_title('Pérdida durante entrenamiento', fontsize=14)
    ax1.set_xlabel('Época', fontsize=12)
    ax1.set_ylabel('Pérdida', fontsize=12)
    ax1.grid(alpha=0.3)
    if train_loss or val_loss:
        ax1.legend(fontsize=12)

    # 2. Validation RMSE with the best epoch highlighted
    ax2 = fig.add_subplot(gs[0, 1])
    val_rmse = series('val_rmse')
    if val_rmse:
        ax2.plot(epochs_of(val_rmse), val_rmse, color='#9b59b6', linewidth=2)
        rmse_arr = np.asarray(val_rmse, dtype=float)
        if np.isfinite(rmse_arr).any():
            best_idx = int(np.nanargmin(rmse_arr))  # NaN-safe, unlike list.index(min(...))
            best = rmse_arr[best_idx]
            ax2.scatter(best_idx + 1, best, c='red', s=100, zorder=10, label=f'Mejor: {best:.4f}')
            ax2.legend(fontsize=12)
    ax2.set_title('RMSE de validación', fontsize=14)
    ax2.set_xlabel('Época', fontsize=12)
    ax2.set_ylabel('RMSE', fontsize=12)
    ax2.grid(alpha=0.3)

    # 3. Learning rate
    ax3 = fig.add_subplot(gs[1, 0])
    lrs = series('learning_rate')
    if lrs:
        ax3.plot(epochs_of(lrs), lrs, color='#2ecc71', linewidth=2)
        ax3.set_title('Tasa de aprendizaje (OneCycleLR)', fontsize=14)
        ax3.set_xlabel('Época', fontsize=12)
        ax3.set_ylabel('Learning Rate', fontsize=12)
        ax3.set_yscale('log')
        ax3.grid(alpha=0.3)

    # 4. Teacher-forcing ratio
    ax4 = fig.add_subplot(gs[1, 1])
    tf_ratios = series('teacher_forcing')
    if tf_ratios:
        ax4.plot(epochs_of(tf_ratios), tf_ratios, color='#f39c12', linewidth=2)
        ax4.set_title('Teacher Forcing Ratio (0.7 → 0.3)', fontsize=14)
        ax4.set_xlabel('Época', fontsize=12)
        ax4.set_ylabel('Teacher Forcing', fontsize=12)
        ax4.set_ylim(0, 1)
        ax4.grid(alpha=0.3)

    plt.suptitle(f'{exp_name} - Fold {fold}', fontsize=16)
    plt.tight_layout(rect=[0, 0, 1, 0.97])
    plt.savefig(out_path, dpi=100, bbox_inches='tight')
    plt.close(fig)

    print_progress(f"Curvas de aprendizaje guardadas en: {out_path}", level=1)

# ▶️ Main experiment loop: train every experiment on every fold, keeping the
# per-fold histories and trained models for later inspection.
RESULTS = []
ALL_HISTORIES = {}
ALL_MODELS = {}

# Folder for aggregated metrics
metrics_dir = MODEL_DIR / "metrics"
metrics_dir.mkdir(exist_ok=True, parents=True)

for exp_name, cfg in EXPERIMENTS.items():
    print_progress(f"Ejecutando experimento: {exp_name}", is_start=True)
    exp_histories = {}
    exp_models = {}
    exp_metrics = []

    for fold, val_year in FOLDS.items():
        print_progress(f"Procesando fold {fold} (validación: {val_year})", level=1)

        train_loader, val_loader, in_dim = build_dataloaders(val_year, cfg['use_lags'])

        # Fold-dependent dropout: 0.25 for F4 and F5, 0.20 for the rest.
        dropout = 0.25 if fold in ('F4', 'F5') else 0.20
        print_progress(f"Usando dropout={dropout} para fold {fold}", level=2)

        model = MODEL_FACTORY[cfg['model']](in_dim, dropout=dropout).to(DEVICE)
        model, history, best_rmse = train_with_history(
            model, train_loader, val_loader,
            epochs=60, patience=20,
            lr=1e-3, weight_decay=1e-4,
            fold=fold, exp_name=exp_name
        )

        RESULTS.append({
            'exp': exp_name,
            'fold': fold,
            'rmse': best_rmse
        })
        exp_histories[fold] = history
        exp_models[fold] = model

        # TODO: call visualize_predictions(model, xr.open_dataset(DATASET_PATH),
        # val_year, exp_name, fold) here once prepare_grid_data is implemented.
        # visualize_predictions is defined further down in this cell, so it must
        # be moved above this loop first. (The former `try: pass / except` wrapper
        # was dead code and would have hidden that NameError.)

    ALL_HISTORIES[exp_name] = exp_histories
    ALL_MODELS[exp_name] = exp_models

    print_progress(f"Experimento {exp_name} completado", is_end=True)

# ▶️ Results table (experiment x fold RMSE) and grouped bar chart.
# The previous chart code set titles/labels but never plotted any data, so it
# saved an empty figure and warned about a legend without artists.
df = pd.DataFrame(RESULTS)
if df.empty:
    print_progress("Sin resultados que resumir", level=1)
else:
    pivot_table = df.pivot(index='exp', columns='fold', values='rmse')
    print_progress("Resumen de resultados RMSE:", is_start=True)
    try:
        display(pivot_table)  # rich display, available in IPython/Jupyter only
    except NameError:
        print(pivot_table.to_string())

    fig, ax = plt.subplots(figsize=(12, 6))
    pivot_table.plot(kind='bar', ax=ax)
    ax.set_title('Comparación de RMSE por experimento y fold', fontsize=14)
    ax.set_xlabel('Experimento')
    ax.set_ylabel('RMSE')
    ax.tick_params(axis='x', labelrotation=45)
    ax.grid(alpha=0.3)
    ax.legend(title='Fold')
    fig.tight_layout()
    fig.savefig(IMAGE_DIR / "experiment_comparison.png", dpi=100)
    plt.show()

def _to_horizon_grid(values, n_lat, n_lon):
    """Reshape one HORIZON-step forecast per grid cell ((lat, lon) row-major) to (HORIZON, n_lat, n_lon)."""
    return np.asarray(values).reshape(n_lat, n_lon, HORIZON).transpose(2, 0, 1)


def _mape_color_max(values):
    """95th percentile of the finite entries of *values*, capped at 100 (100 if none are finite)."""
    finite = values[np.isfinite(values)]
    return min(100, float(np.percentile(finite, 95))) if finite.size else 100


def _save_map_pair(suptitle, left_data, left_title, right_data, right_title,
                   lon2d, lat2d, out_path, left_vmin=None, left_vmax=None, right_vmax=None):
    """Save a two-panel PlateCarree figure: left 'Blues' in mm, right 'Reds' in % (vmin 0)."""
    fig = plt.figure(figsize=(18, 10))
    plt.suptitle(suptitle, fontsize=16)
    panels = (
        (1, left_data, left_title, 'Blues', left_vmin, left_vmax, 'mm'),
        (2, right_data, right_title, 'Reds', 0, right_vmax, '%'),
    )
    for position, data, title, cmap, vmin, vmax, unit in panels:
        ax = plt.subplot(1, 2, position, projection=ccrs.PlateCarree())
        ax.set_title(title)
        mesh = ax.pcolormesh(lon2d, lat2d, data, cmap=cmap, vmin=vmin, vmax=vmax,
                             transform=ccrs.PlateCarree())
        ax.coastlines(resolution='10m')
        ax.add_feature(cfeature.BORDERS, linestyle=':')
        gl = ax.gridlines(draw_labels=True, linewidth=0.5)
        gl.top_labels = False
        gl.right_labels = False
        plt.colorbar(mesh, ax=ax, shrink=0.7, label=unit)
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    fig.savefig(out_path, dpi=120, bbox_inches='tight')
    plt.close(fig)


def visualize_predictions(model, dataset, val_year, exp_name, fold, scalers=None):
    """Generate monthly prediction and MAPE maps plus an RMSE-by-horizon chart.

    Inputs and targets come from ``prepare_grid_data``. It is expected to yield
    one sample per grid cell in row-major (latitude, longitude) order, each with
    ``HORIZON`` target steps.

    Args:
        model: trained model (called as ``model(inputs)``).
        dataset: xarray dataset exposing ``latitude`` and ``longitude``.
        val_year: validation year (titles and ``prepare_grid_data``).
        exp_name: experiment name (titles and output folder).
        fold: fold id (titles and output folder).
        scalers: optional ``(sc_p, sc_x)``. When given, predictions are mapped
            back element-wise with ``sc_p.inverse_transform`` (sc_p is fit on a
            single column). NOTE(review): PrecipDataset trains on unscaled
            targets, so only pass this for models that predict in scaled space.

    Returns:
        ``(preds, true_vals, mape_values)``, each ``(HORIZON, n_lat, n_lon)``.
        MAPE (%) is NaN where the observed value is <= 0.1, so those cells are
        excluded from averages instead of counting as 0 % error.

    Side effects:
        Writes PNGs to ``IMAGE_DIR / f"{exp_name}_{fold}_maps"`` and sets
        ``plt.rcParams['figure.figsize']``.
    """
    print_progress(f"Generando visualizaciones para {exp_name}, fold {fold}", is_start=True)

    vis_dir = IMAGE_DIR / f"{exp_name}_{fold}_maps"
    vis_dir.mkdir(exist_ok=True, parents=True)

    month_names = ['Ene', 'Feb', 'Mar', 'Abr', 'May', 'Jun', 'Jul', 'Ago', 'Sep', 'Oct', 'Nov', 'Dic']

    lats = dataset.latitude.values
    lons = dataset.longitude.values
    n_lat, n_lon = len(lats), len(lons)
    lon2d, lat2d = np.meshgrid(lons, lats)

    plt.rcParams['figure.figsize'] = (20, 10)

    print_progress(f"Generando predicciones", level=1)
    input_tensor, target_tensor = prepare_grid_data(dataset, val_year, INPUT_WINDOW, HORIZON)

    with torch.no_grad():
        model.eval()
        preds = model(input_tensor.to(DEVICE)).cpu().numpy()

    if scalers:
        sc_p, _ = scalers
        # sc_p was fit on one column, so invert element-wise (not HORIZON columns).
        preds = sc_p.inverse_transform(preds.reshape(-1, 1))

    # Reshape unconditionally; previously this only happened when scalers were given.
    preds = _to_horizon_grid(preds, n_lat, n_lon)
    true_vals = _to_horizon_grid(target_tensor.numpy(), n_lat, n_lon)

    # MAPE (%) only where the observed value exceeds 0.1 (avoid division by ~0).
    mape_values = np.full((HORIZON, n_lat, n_lon), np.nan)
    valid = true_vals > 0.1
    mape_values[valid] = np.abs((preds[valid] - true_vals[valid]) / true_vals[valid]) * 100

    # Colour limits shared by every monthly map.
    vmin_pred = np.nanpercentile(true_vals, 1)
    vmax_pred = np.nanpercentile(true_vals, 99)
    vmax_mape = _mape_color_max(mape_values)

    print_progress(f"Generando mapas mensuales", level=1)
    for m in range(HORIZON):
        month = month_names[m % 12]
        _save_map_pair(
            f"{exp_name} - {fold} - {month} {val_year}",
            preds[m], "Precipitación Predicha (mm)",
            mape_values[m], "Error MAPE (%)",
            lon2d, lat2d, vis_dir / f"map_{month}.png",
            left_vmin=vmin_pred, left_vmax=vmax_pred, right_vmax=vmax_mape,
        )

    print_progress(f"Generando mapa resumen", level=1)
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', RuntimeWarning)  # cells that are NaN at every step
        avg_pred = np.nanmean(preds, axis=0)
        avg_mape = np.nanmean(mape_values, axis=0)
    _save_map_pair(
        f"{exp_name} - {fold} - Promedio Anual {val_year}",
        avg_pred, "Precipitación Media Anual (mm)",
        avg_mape, "MAPE Promedio (%)",
        lon2d, lat2d, vis_dir / "map_annual_summary.png",
        right_vmax=_mape_color_max(avg_mape),
    )

    # RMSE per forecast step (1..HORIZON)
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', RuntimeWarning)
        rmse_by_horizon = [np.sqrt(np.nanmean((preds[h] - true_vals[h]) ** 2)) for h in range(HORIZON)]

    fig = plt.figure(figsize=(10, 6))
    plt.plot(range(1, HORIZON + 1), rmse_by_horizon, marker='o', linewidth=2)
    plt.title(f"{exp_name} - {fold} - RMSE por Horizonte", fontsize=14)
    plt.xlabel('Horizonte de Predicción (meses)', fontsize=12)
    plt.ylabel('RMSE', fontsize=12)
    plt.grid(alpha=0.3)
    plt.xticks(range(1, HORIZON + 1))
    plt.tight_layout()
    fig.savefig(vis_dir / "rmse_by_horizon.png", dpi=120)
    plt.close(fig)

    print_progress(f"Visualizaciones guardadas en {vis_dir}", is_end=True)
    return preds, true_vals, mape_values

# Helper that prepares grid-shaped inputs/targets for visualize_predictions
def prepare_grid_data(dataset, val_year, input_window, horizon):
    """Build model inputs and targets for every grid cell of the validation year.

    Contract expected by ``visualize_predictions``: return ``(inputs, targets)``
    tensors with one sample per grid cell in row-major (latitude, longitude)
    order. ``inputs`` holds ``input_window`` time steps per sample and
    ``targets`` holds ``horizon`` values per sample.

    Not implemented for this dataset yet. It fails fast rather than returning
    dummy zero tensors, which would silently produce meaningless maps or a
    confusing reshape error downstream.

    Raises:
        NotImplementedError: always, until implemented for the dataset layout.
    """
    raise NotImplementedError(
        "prepare_grid_data needs a dataset-specific implementation "
        f"(val_year={val_year}, input_window={input_window}, horizon={horizon})"
    )

Installing pytorch-lightning ...
Installing netcdf4 ...
Installing netcdf4 ...
Installing scikit-learn ...
Installing scikit-learn ...


: 