# TopoRain-Net

## Level 0: Advanced Preprocessing (simplified to start)
We will use precipitation, elevation, and elevation clusters.
The CEEMDAN/TVF-EMD fusion is omitted for now to keep Level 1 simple. It could be added later as an extra feature for the meta-model.

## Level 1: Topo-UNet Encoder with Elevation Attention (simplified)
## Level 2: Bidirectional ConvLSTM
## Level 3: Multi-Model Base Prediction (reduced for simplicity)
We will focus on one or two deep learning base models (e.g., a simple GRU or an MLP applied to the BiConvLSTM output) to demonstrate the pipeline. RF/XGBoost can be added later as benchmarks.
## Level 4: Simple Meta-Model (MLP)
It takes the outputs of the base models and possibly some global features (such as the mean elevation cluster of the window, or temporal features).
No "dynamic Meta-Attention" for now.
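
The notes do not say which temporal features the meta-model would receive. As one hypothetical option, the calendar month of the forecast target can be encoded cyclically so that December and January land close together. The helper names `month_features` and `cyclic_distance`, and the sine/cosine encoding itself, are assumptions for illustration and not part of the original pipeline.

```python
import math
from typing import Tuple


def month_features(month: int) -> Tuple[float, float]:
    """Encode a calendar month (1-12) as a point on the unit circle."""
    if not 1 <= month <= 12:
        raise ValueError("month must be in 1..12")
    angle = 2.0 * math.pi * month / 12.0
    return math.sin(angle), math.cos(angle)


def cyclic_distance(m1: int, m2: int) -> float:
    """Euclidean distance between two encoded months."""
    s1, c1 = month_features(m1)
    s2, c2 = month_features(m2)
    return math.hypot(s1 - s2, c1 - c2)


if __name__ == "__main__":
    # December is as close to January as January is to February.
    print(cyclic_distance(12, 1), cyclic_distance(1, 2), cyclic_distance(6, 1))
```

The two values per month could be appended to the meta-model input alongside the base-model predictions; whether that helps is something to test, not a given.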

Important considerations before starting:

- Data: Make sure FULL_NC contains total_precipitation, elevation, and clusters_elevation. The current FEATURES_NC is not used directly in this first simplified phase, since CEEMDAN/TVF-EMD is left out of the main input for now.
- UNet and ConvLSTM complexity: We start with moderately sized architectures.
- Elevation Attention: We implement a simple version (multiplication, or concatenation followed by a convolution).

Below is the skeleton of the adapted code. It is a starting point and will need refinement, debugging, and experimentation.
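
The multiplicative variant of Elevation Attention can be traced by hand. A 1x1 convolution applied to a single-channel elevation map reduces to one weight and one bias per feature channel, so the gate for channel c at a pixel is sigmoid(w_c * elev + b_c). The sketch below reproduces that arithmetic in pure Python; the helper name `elevation_gate` and all numeric values are made up for illustration.

```python
import math
from typing import List

Grid = List[List[float]]


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def elevation_gate(features: List[Grid], elevation: Grid,
                   weights: List[float], biases: List[float]) -> List[Grid]:
    """Gate a (C, H, W) feature volume with a normalised (H, W) elevation map.

    Mirrors a 1x1 convolution from 1 to C channels, followed by a sigmoid
    and an element-wise product with the features.
    """
    gated = []
    for c, channel in enumerate(features):
        gated.append([
            [value * sigmoid(weights[c] * elevation[i][j] + biases[c])
             for j, value in enumerate(row)]
            for i, row in enumerate(channel)
        ])
    return gated


# Illustrative values only: 2 channels on a 2x2 grid, elevation scaled to [0, 1].
features = [[[1.0, 2.0], [3.0, 4.0]],
            [[1.0, 1.0], [1.0, 1.0]]]
elevation = [[0.0, 0.5], [0.5, 1.0]]
gated = elevation_gate(features, elevation, weights=[0.0, 4.0], biases=[0.0, -2.0])
```

With a zero weight and bias (channel 0) every pixel is scaled by 0.5 regardless of elevation; with a positive weight (channel 1) higher pixels pass more of the feature through.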

## Key Points of This Prioritized Version

Main feature extractor: The TopoRainFeatureExtractor (which contains TopoUNetEncoder and BiDirectionalConvLSTM) is the core. It is trained first, attempting to predict precipitation directly, which forces the extractor to learn relevant features.

Freezing: Once trained, the feature_extractor is frozen.

Simple base model: A SimpleBaseModel (an MLP, named SimpleBaseModelMLP in the code) is trained on the flattened outputs of the frozen feature_extractor. This MLP predicts precipitation at the original resolution. (This is the step with the heaviest simplification; a decoder or more sophisticated base models would do better.)

Generating predictions for the meta-model: The SimpleBaseModel predictions on the validation set are saved.

Meta-model: UNetConvLSTMMeta is reused (a simpler MLP is also an option) to take the SimpleBaseModel predictions and refine them.
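
To see why the flattened MLP is the heaviest simplification, it helps to count its parameters against the 1x1 convolution head used by the feature extractor's own predictor. The sketch below is plain arithmetic. The channel sizes follow the notebook configuration (128 hidden channels per direction, so 256 feature channels; hidden_dim=256; horizon 3), while the 64x64 grid is a hypothetical size chosen for illustration.

```python
def linear_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


def mlp_head_params(c: int, h_red: int, w_red: int, hidden: int,
                    t_out: int, h_out: int, w_out: int) -> int:
    """Two-layer MLP: flatten(C, H_red, W_red) -> hidden -> (T_out * H * W)."""
    return (linear_params(c * h_red * w_red, hidden)
            + linear_params(hidden, t_out * h_out * w_out))


def conv1x1_head_params(c: int, t_out: int) -> int:
    """1x1 convolution from C feature channels to T_out output maps."""
    return c * t_out + t_out


# Hypothetical grid size; the encoder halves H and W twice, hence the // 4.
H, W = 64, 64
mlp = mlp_head_params(c=256, h_red=H // 4, w_red=W // 4, hidden=256,
                      t_out=3, h_out=H, w_out=W)
conv = conv1x1_head_params(c=256, t_out=3)
print(f"MLP head: {mlp:,} parameters; 1x1 conv head: {conv:,} parameters")
```

The MLP's size grows with the product of the grid dimensions, whereas the convolutional head is independent of them, which is one reason a small decoder is listed among the later improvements.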

Simplifications:

Only one base model (MLP) after the feature extractor.

The SimpleBaseModel maps features of shape (C, H_red, W_red) to (T_out, H_orig, W_orig) in a very simple way.

The "dynamic Meta-Attention" is omitted.

The CEEMDAN/TVF-EMD fusion is left out of the main input to the TopoUNet.

Data handling for the meta-model is direct: it uses the base MLP predictions on the extractor's validation set.
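
The hand-off to the meta-model relies on a stable mapping from (validation sample, horizon step) to a saved prediction file. A minimal sketch of that bookkeeping follows, assuming the validation loader is not shuffled. The file-name pattern matches the one used in the notebook; the function name `build_prediction_index` is hypothetical, and no files are written here.

```python
from pathlib import Path
from typing import Dict, Tuple


def build_prediction_index(num_samples: int, batch_size: int, horizon: int,
                           out_dir: Path) -> Dict[Tuple[int, int], Path]:
    """Map (validation sample index, horizon index) to its prediction file."""
    index: Dict[Tuple[int, int], Path] = {}
    num_batches = (num_samples + batch_size - 1) // batch_size
    for batch_idx in range(num_batches):
        # The last batch may hold fewer than batch_size samples.
        items_in_batch = min(batch_size, num_samples - batch_idx * batch_size)
        for i in range(items_in_batch):
            sample_idx = batch_idx * batch_size + i
            for h in range(horizon):
                index[(sample_idx, h)] = out_dir / f"pred_val{sample_idx}_h{h + 1}.npy"
    return index


index = build_prediction_index(num_samples=10, batch_size=8, horizon=3,
                               out_dir=Path("MLP_base_preds"))
print(len(index), index[(9, 2)].name)
```

Because the index is derived from the batch position, shuffling the validation loader would silently misalign predictions and targets.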

Next steps after running this:

Evaluate the feature extractor: How well does it predict precipitation on its own? Are its losses meaningful?

Evaluate the simple base model: Does it improve on the direct output of the feature extractor?

Evaluate the meta-model: Does it bring a further improvement?

Iterate:

Improve the prediction head of the FeatureExtractorPredictor (e.g., add a small UNet decoder).

Try more sophisticated base models at Level 3.

Experiment with different architectures for the meta-model (Level 4), including a simpler MLP if UNetConvLSTMMeta turns out to be too much.

Reintroduce the CEEMDAN/TVF-EMD features, either as additional channels to the TopoUNet or as direct inputs to the meta-model.

Refine the ElevationAttention.
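
The three evaluation questions above all ask whether one stage beats the previous one. One way to put a number on that is the relative RMSE reduction, computed on the same validation samples and in the same units. The sketch below uses toy values; the function `relative_improvement` is a hypothetical helper, not part of the notebook.

```python
import math
from typing import Sequence


def rmse(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Root mean squared error of two equally long sequences."""
    if len(y_true) != len(y_pred) or len(y_true) == 0:
        raise ValueError("inputs must be non-empty and of equal length")
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true))


def relative_improvement(y_true: Sequence[float],
                         y_reference: Sequence[float],
                         y_candidate: Sequence[float]) -> float:
    """Fraction by which the candidate lowers RMSE versus the reference.

    Positive means the candidate is better; negative means it is worse.
    """
    ref = rmse(y_true, y_reference)
    if ref == 0.0:
        raise ValueError("reference RMSE is zero; relative improvement is undefined")
    return 1.0 - rmse(y_true, y_candidate) / ref


# Toy numbers only: the later stage halves the error of the earlier one.
y_true = [1.0, 2.0, 3.0]
fe_pred = [1.0, 2.0, 5.0]
mlp_pred = [1.0, 2.0, 4.0]
gain = relative_improvement(y_true, fe_pred, mlp_pred)
print(f"Relative RMSE improvement: {gain:.2%}")
```

A value near zero or below for a later stage suggests that stage is not yet earning its added complexity.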

# %% [markdown]
# # TopoRain-Net (Prioritized Version) - Notebook

# %% [markdown]
# ## 0) Environment setup, paths, and dependencies
# (Same setup cell as the last corrected version, making sure PyTorch is imported correctly
# and that the paths BASE_PATH, DATA_OUTPUT, MODELS_OUTPUT, etc. are well defined)
# %%
import sys
import os
import logging
from pathlib import Path
import joblib
import datetime

print("Python version:", sys.version)
print("Current working directory at script start:", Path.cwd())

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
IN_COLAB = "google.colab" in sys.modules
logger.info(f"Ejecutando en Colab: {IN_COLAB}")
if IN_COLAB: logger.warning("RECUERDA: Reiniciar el entorno de ejecución si hay errores de importación de PyTorch.")

if IN_COLAB:
    from google.colab import drive
    try:
        drive.mount('/content/drive', force_remount=True); logger.info("Drive mounted.")
    except Exception as e:
        logger.error(f"Drive mount failed: {e}"); sys.exit("Drive mount failed.")
    EXPECTED_PROJECT_PATH_ON_DRIVE = Path('/content/drive/MyDrive/ml_precipitation_prediction') # AJUSTA ESTA RUTA
    if EXPECTED_PROJECT_PATH_ON_DRIVE.is_dir() and (EXPECTED_PROJECT_PATH_ON_DRIVE / '.git').is_dir():
        BASE_PATH = EXPECTED_PROJECT_PATH_ON_DRIVE
    else:
        logger.warning(f"Proyecto no encontrado en Drive o no es repo git: {EXPECTED_PROJECT_PATH_ON_DRIVE}. Clonando a temporal.")
        CLONE_DIR = Path('/content/toporain_net_project_temp');
        if CLONE_DIR.exists(): 
          !rm -rf {CLONE_DIR}
        !git clone https://github.com/ninja-marduk/ml_precipitation_prediction.git {CLONE_DIR} # Reemplaza con tu repo si es diferente
        if not CLONE_DIR.is_dir(): sys.exit("Git clone failed.")
        BASE_PATH = CLONE_DIR
    logger.info(f"BASE_PATH: {BASE_PATH}");
    if Path('torch.py').exists(): sys.exit(f"FATAL: 'torch.py' encontrado en {Path.cwd()}. Renómbralo y reinicia.")
    !pip install -q xarray netCDF4 matplotlib seaborn scikit-learn cartopy geopandas joblib # optuna lightgbm xgboost omitidos por ahora
    !pip uninstall -y torch torchvision torchaudio fastai
    !pip install -q torch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 --index-url https://download.pytorch.org/whl/cu118
else: # Local
    BASE_PATH = Path.cwd() # Current working directory (an absolute right-hand operand makes '..' / Path.cwd() resolve to this anyway); adjust if needed
    if (BASE_PATH / 'torch.py').exists(): sys.exit("FATAL: 'torch.py' en el directorio del proyecto.")
logger.info(f"Directorio de trabajo actual: {Path.cwd()}")

DATA_OUTPUT = BASE_PATH.parent / 'data' / 'output'
MODELS_OUTPUT = BASE_PATH / 'models' / 'output_toporain' # Nuevo directorio para modelos
PREDS_DIR = MODELS_OUTPUT / 'base_model_predictions_toporain'
for p in [DATA_OUTPUT, MODELS_OUTPUT, PREDS_DIR]: p.mkdir(parents=True, exist_ok=True)

# Data files (make sure these exist and contain the required variables)
# FULL_NC must have: total_precipitation, elevation, clusters_elevation (and optionally slope, aspect, etc. if you add them)
FULL_NC = DATA_OUTPUT / 'complete_dataset_with_features_with_clusters_elevation_with_windows.nc' # This is the crucial one
SHP_PATH = BASE_PATH / 'data' / 'input' / 'shapes' / 'MGN_Departamento.shp'
if not FULL_NC.exists(): logger.error(f"CRITICAL: {FULL_NC} not found!"); sys.exit("Main data file is missing.")
if not SHP_PATH.exists(): logger.warning(f"{SHP_PATH} not found.")

# Parámetros generales
INPUT_WINDOW   = 12 
OUTPUT_HORIZON = 3  
BATCH_SIZE     = 8   
MAX_EPOCHS     = 100 # Reducido para pruebas iniciales, aumentar para entrenamiento real
PATIENCE       = 15
LR_FEATURE_EXTRACTOR = 1e-4
LR_BASE_MODELS = 1e-3
LR_META_MODEL = 1e-3

logger.info("Intentando importar PyTorch DESPUÉS de la configuración del entorno...")
import torch
try:
    if torch.cuda.is_available():
        DEVICE = torch.device("cuda")
        logger.info(f"PyTorch importado exitosamente: versión {torch.__version__}")
        logger.info(f"Usando dispositivo: {DEVICE}.")
        logger.info(f"Nombre del dispositivo CUDA: {torch.cuda.get_device_name(0)}")
        cuda_cap_val_check = torch.cuda.get_device_capability(0)
        logger.info(f"Capacidad CUDA: {cuda_cap_val_check[0]}.{cuda_cap_val_check[1]}")
    else:
        DEVICE = torch.device("cpu")
        logger.info(f"PyTorch importado exitosamente: versión {torch.__version__}")
        logger.warning(f"CUDA no disponible. Usando dispositivo: {DEVICE}. El entrenamiento será muy lento.")
except AttributeError as e_attr:
    logger.error(f"AttributeError durante la importación de PyTorch o configuración del dispositivo: {e_attr}")
    logger.error("Esto a menudo significa que la instalación de PyTorch sigue corrupta o hay un conflicto.")
    logger.error("Asegúrate de haber 'Reiniciado el Entorno de Ejecución' y que no exista un archivo 'torch.py' en tu proyecto.")
    sys.exit("Falló la inicialización de PyTorch.")
except Exception as e_gen:
    logger.error(f"Ocurrió un error inesperado durante la importación de PyTorch o configuración del dispositivo: {e_gen}")
    sys.exit("Falló la inicialización de PyTorch.")

# %% [markdown]
# ## 1) Additional imports and helper functions
# %%
import numpy as np
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.model_selection import train_test_split
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import xarray as xr
import torch.nn.functional as F # Import F para F.interpolate

def evaluate_metrics_np(y_true, y_pred):
    if not isinstance(y_true, np.ndarray): y_true = np.array(y_true)
    if not isinstance(y_pred, np.ndarray): y_pred = np.array(y_pred)
    if y_true.shape != y_pred.shape:
        min_len = min(y_true.size, y_pred.size); y_true=y_true.flatten()[:min_len]; y_pred=y_pred.flatten()[:min_len]
    valid_indices = ~np.isnan(y_true) & ~np.isnan(y_pred)
    y_true_clean, y_pred_clean = y_true[valid_indices], y_pred[valid_indices]
    if y_true_clean.size == 0: return np.nan, np.nan, np.nan, np.nan 
    rmse = np.sqrt(np.mean((y_true_clean - y_pred_clean)**2))
    mae = np.mean(np.abs(y_true_clean - y_pred_clean))
    mape = np.mean(np.abs((y_true_clean - y_pred_clean)/(y_true_clean + 1e-6))) * 100 
    y_true_var = np.sum((y_true_clean - np.mean(y_true_clean))**2)
    r2 = np.nan if y_true_var < 1e-7 else 1 - np.sum((y_true_clean - y_pred_clean)**2) / y_true_var 
    return rmse, mae, mape, r2

def plot_training_history(history, title="Training and Validation Loss"):
    plt.figure(figsize=(8,5))
    plot_train = "train_loss" in history and history["train_loss"] and any(not np.isnan(x) for x in history["train_loss"] if isinstance(x, (int,float)))
    plot_val = "val_loss" in history and history["val_loss"] and any(not np.isnan(x) for x in history["val_loss"] if isinstance(x, (int,float)))
    if plot_train: plt.plot(history["train_loss"], label="Train Loss")
    if plot_val: plt.plot(history["val_loss"], label="Validation Loss")
    plt.title(title); plt.xlabel("Epoch"); plt.ylabel("Loss")
    if plot_train or plot_val: plt.legend()
    plt.grid(True); plt.show()

# %% [markdown]
# ## Level 1 & 2: Topo-UNet Encoder with Elevation Attention & BiConvLSTM

# %%
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), # bias=False si usas BN
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
    def forward(self, x): return self.double_conv(x)

class ElevationAttention(nn.Module):
    def __init__(self, feature_channels, elevation_channels=1):
        super().__init__()
        self.attention_conv = nn.Conv2d(elevation_channels, feature_channels, kernel_size=1) 
    def forward(self, features, elevation_map_norm):
        elevation_map_norm = elevation_map_norm.to(features.device, features.dtype)
        if elevation_map_norm.shape[2:] != features.shape[2:]:
            elevation_map_norm = F.interpolate(elevation_map_norm, size=features.shape[2:], mode='bilinear', align_corners=False)
        attention_map = torch.sigmoid(self.attention_conv(elevation_map_norm))
        return features * attention_map

class TopoUNetEncoder(nn.Module):
    def __init__(self, in_channels_feat, out_channels_convlstm, base_unet_channels=16): # in_channels_static_topo no se usa si solo se pasa elev_map_norm
        super().__init__()
        self.inc = DoubleConv(in_channels_feat, base_unet_channels)
        self.down1 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(base_unet_channels, base_unet_channels*2))
        self.down2 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(base_unet_channels*2, base_unet_channels*4))
        self.attention_stage2 = ElevationAttention(base_unet_channels*4, 1) # Asume 1 canal para elev_map_norm
        self.final_conv = nn.Conv2d(base_unet_channels*4, out_channels_convlstm, kernel_size=1)

    def forward(self, x_temporal_feat, elevation_map_norm_static):
        x1 = self.inc(x_temporal_feat); x2 = self.down1(x1); x3 = self.down2(x2)
        x3_att = self.attention_stage2(x3, elevation_map_norm_static)
        return self.final_conv(x3_att)

class ConvLSTMCell(nn.Module): # Re-definida aquí por si acaso
    def __init__(self, in_chan, hid_chan, kernel_size=3):
        super().__init__()
        padding = kernel_size // 2
        self.conv = nn.Conv2d(in_chan + hid_chan, 4 * hid_chan, kernel_size, padding=padding)
        self.hid_chan = hid_chan
    def forward(self, x, h, c):
        combined = torch.cat([x, h], dim=1); conv_out = self.conv(combined)
        i, f, o, g = torch.split(conv_out, self.hid_chan, dim=1)
        i=torch.sigmoid(i); f=torch.sigmoid(f); o=torch.sigmoid(o); g=torch.tanh(g)
        c_next = f*c + i*g; h_next = o*torch.tanh(c_next)
        return h_next, c_next

class ConvLSTM(nn.Module): # Re-definida
    def __init__(self, in_chan, hid_chan, kernel_size=3, num_layers=1):
        super().__init__()
        self.cells = nn.ModuleList([ConvLSTMCell(in_chan if i==0 else hid_chan, hid_chan, kernel_size) for i in range(num_layers)])
    def forward(self, x): # x: (B, T, C_in, H, W)
        B, T, _, H_lstm, W_lstm = x.shape
        h_s = [torch.zeros(B,cell.hid_chan,H_lstm,W_lstm,device=x.device) for cell in self.cells]
        c_s = [torch.zeros_like(h) for h in h_s]
        out_seq = []
        for t_step in range(T):
            x_t = x[:,t_step]; curr_in = x_t
            for l_idx, cell_l in enumerate(self.cells):
                h_s[l_idx],c_s[l_idx] = cell_l(curr_in,h_s[l_idx],c_s[l_idx]); curr_in=h_s[l_idx]
            out_seq.append(h_s[-1])
        return torch.stack(out_seq, dim=1)

class BiDirectionalConvLSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim, kernel_size=3, num_layers=1):
        super().__init__()
        self.forward_convlstm = ConvLSTM(input_dim, hidden_dim, kernel_size, num_layers)
        self.backward_convlstm = ConvLSTM(input_dim, hidden_dim, kernel_size, num_layers)
        self.output_channels = hidden_dim * 2 
    def forward(self, x_sequence):
        f_out = self.forward_convlstm(x_sequence)
        b_out_rev = self.backward_convlstm(torch.flip(x_sequence, dims=[1]))
        return torch.cat((f_out, torch.flip(b_out_rev, dims=[1])), dim=2)

class TopoRainFeatureExtractor(nn.Module):
    def __init__(self, precip_channels=1, unet_base_channels=16, convlstm_in_channels=32, convlstm_hid_channels=64):
        super().__init__()
        # static_topo_channels no es un parámetro directo del encoder ahora, solo se pasa elev_map_norm
        self.topo_unet_encoder = TopoUNetEncoder(precip_channels, convlstm_in_channels, unet_base_channels)
        self.bi_conv_lstm = BiDirectionalConvLSTM(convlstm_in_channels, convlstm_hid_channels)
        self.final_feature_channels = self.bi_conv_lstm.output_channels

    def forward(self, x_precip_seq, x_elevation_static_norm): # x_cluster_static omitido por ahora
        B, T_in, H, W = x_precip_seq.shape
        x_precip_seq_ch = x_precip_seq.unsqueeze(2) # (B, T_in, 1, H, W)
        unet_outputs_seq = [self.topo_unet_encoder(x_precip_seq_ch[:,t], x_elevation_static_norm) for t in range(T_in)]
        unet_outputs_tensor = torch.stack(unet_outputs_seq, dim=1)
        biconvlstm_output_seq = self.bi_conv_lstm(unet_outputs_tensor)
        return biconvlstm_output_seq[:, -1] # (B, C_biconvlstm_out, H_red, W_red)

# %% [markdown]
# ## Data Preparation for TopoRain-Net
# %%
class TopoRainDataset(Dataset):
    def __init__(self, ds_full_xr, input_window, output_horizon, target_variable='total_precipitation', 
                 static_features=['elevation']): # Solo elevación por ahora
        self.ds = ds_full_xr; self.input_window = input_window; self.output_horizon = output_horizon
        self.target_var = target_variable; self.static_vars = static_features
        self.times_ds = self.ds.time.values
        self.num_samples = len(self.times_ds) - input_window - output_horizon + 1
        if self.num_samples <= 0: logger.warning(f"TopoRainDataset: No samples ({self.num_samples})")

        self.elevation_data = self.ds['elevation'].values
        self.ele_scaler = MinMaxScaler(); self.elevation_norm = self.ele_scaler.fit_transform(self.elevation_data.reshape(-1,1)).reshape(self.elevation_data.shape)
        
        precip_vals = self.ds[self.target_var].values.reshape(-1,1)
        self.precip_scaler = StandardScaler().fit(precip_vals)
        logger.info("TopoRainDataset: Precip scaler ajustado (global a este subset de ds).")

    def __len__(self): return self.num_samples

    def __getitem__(self, idx):
        precip_in_np = self.ds[self.target_var].isel(time=slice(idx,idx+self.input_window)).values
        precip_in_s = self.precip_scaler.transform(precip_in_np.reshape(-1,1)).reshape(precip_in_np.shape)
        precip_out_np = self.ds[self.target_var].isel(time=slice(idx+self.input_window,idx+self.input_window+self.output_horizon)).values
        precip_out_s = self.precip_scaler.transform(precip_out_np.reshape(-1,1)).reshape(precip_out_np.shape)
        elev_norm_item = self.elevation_norm[np.newaxis,:,:] # (1,H,W)
        # Placeholder cluster map: the training loops unpack four items per batch
        # (precip_in, elev_norm, cluster_static, target), but clusters are not used yet.
        cluster_item = np.zeros_like(elev_norm_item, dtype=np.float32) # (1,H,W)
        # Returns: precip_in_seq, elev_norm_static, cluster_static (placeholder), target_precip_out_seq
        return (torch.from_numpy(precip_in_s).float(), 
                torch.from_numpy(elev_norm_item).float(), 
                torch.from_numpy(cluster_item).float(),
                torch.from_numpy(precip_out_s).float())

# Carga de datos y split
ds_full, ds_feat_legacy, gdf, times_global, REF_val_end_date, _ = load_and_preprocess_data_main() # ds_feat_legacy no se usa aquí
NY_GLOBAL, NX_GLOBAL = ds_full.latitude.size, ds_full.longitude.size

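# A month-unit np.timedelta64 cannot be combined with a nanosecond-resolution datetime64,
# so the month offset is applied through pandas and converted back for the comparison.
last_train_month_date = (pd.Timestamp(REF_val_end_date) - pd.DateOffset(months=OUTPUT_HORIZON)).to_datetime64()
idx_last_train_month_times = np.where(times_global == last_train_month_date)[0]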

if not idx_last_train_month_times.size:
    logger.warning(f"REF logic error: idx_last_train_month not found. Fallback to 80/20 split on `times_global` indices.")
    num_time_steps = len(times_global)
    # Ensure enough data for at least one sample in train and val after windowing
    min_data_needed = (INPUT_WINDOW + OUTPUT_HORIZON -1) + 1 + (INPUT_WINDOW + OUTPUT_HORIZON -1)
    if num_time_steps < min_data_needed:
        logger.error(f"Dataset too short ({num_time_steps} steps) for 80/20 split and windowing. Needs at least {min_data_needed}. Halting.")
        sys.exit("Dataset too short for robust splitting.")
    
    # Number of samples the full series can produce
    num_possible_samples = num_time_steps - (INPUT_WINDOW + OUTPUT_HORIZON) + 1
    num_train_samples = int(0.8 * num_possible_samples)
    if num_train_samples == 0 and num_possible_samples > 0: num_train_samples = 1 # At least one train sample
    
    # Sample k starts at times_global[k] and spans INPUT_WINDOW + OUTPUT_HORIZON time steps.
    # Training samples run from k=0 to k=num_train_samples-1, so the last one ends at index
    # (num_train_samples - 1) + INPUT_WINDOW + OUTPUT_HORIZON - 1.
    # The slice end passed to isel is exclusive, hence one past that index.
    idx_split_point_for_isel = (num_train_samples - 1) + INPUT_WINDOW + OUTPUT_HORIZON # Exclusive index for slice
    
    if idx_split_point_for_isel <= 0 or idx_split_point_for_isel >= len(times_global):
        logger.error(f"Invalid 80/20 split index {idx_split_point_for_isel}. Halting.")
        sys.exit("Split error.")
    logger.info(f"Using 80/20 split. Training uses time steps up to index {idx_split_point_for_isel-1}.")
else:
    idx_split_point_for_isel = idx_last_train_month_times[0] + 1 
    logger.info(f"Using REF. Training uses time steps up to index {idx_split_point_for_isel-1}.")

ds_train_xr = ds_full.isel(time=slice(None, idx_split_point_for_isel))
ds_val_xr = ds_full.isel(time=slice(idx_split_point_for_isel, None))

# pd.Timestamp is used for formatting because pd.Period requires an explicit freq for datetime64 values.
logger.info(f"Train Xarray time range: {pd.Timestamp(ds_train_xr.time.values[0]).strftime('%Y-%m')} to {pd.Timestamp(ds_train_xr.time.values[-1]).strftime('%Y-%m')}")
if len(ds_val_xr.time) > 0:
    logger.info(f"Val Xarray time range: {pd.Timestamp(ds_val_xr.time.values[0]).strftime('%Y-%m')} to {pd.Timestamp(ds_val_xr.time.values[-1]).strftime('%Y-%m')}")
else: logger.warning("Validation Xarray (ds_val_xr) is empty.")

train_dataset_topo = TopoRainDataset(ds_train_xr, INPUT_WINDOW, OUTPUT_HORIZON)
val_dataset_topo = TopoRainDataset(ds_val_xr, INPUT_WINDOW, OUTPUT_HORIZON)

if len(train_dataset_topo)==0: sys.exit("TopoRainDataset (train) vacío. Revisar lógica de split o datos.")
if len(val_dataset_topo)==0: logger.warning("TopoRainDataset (val) vacío. Entrenamiento sin validación.")

dl_settings = {"batch_size":BATCH_SIZE, "num_workers":0, "pin_memory":DEVICE.type=="cuda"} # num_workers=0 for easier debugging
train_loader_topo = DataLoader(train_dataset_topo, shuffle=True, **dl_settings)
val_loader_topo = DataLoader(val_dataset_topo, shuffle=False, **dl_settings) if len(val_dataset_topo)>0 else []


# %% [markdown]
# ## Training the Feature Extractor (TopoUNet + BiConvLSTM)

# %%
class FeatureExtractorPredictor(nn.Module):
    def __init__(self, fe_model, bilstm_out_ch, output_horizon, H_orig, W_orig): # Pass H_orig, W_orig
        super().__init__(); self.feature_extractor = fe_model
        self.final_pred_conv = nn.Conv2d(bilstm_out_ch, output_horizon, kernel_size=1)
        self.H_orig, self.W_orig = H_orig, W_orig # Store original dimensions for upsampling
        logger.info("FeatureExtractorPredictor will upsample to original H,W.")

    def forward(self, precip_seq, elev_static, cluster_static_unused): # cluster_static_unused for API consistency
        extracted_feats = self.feature_extractor(precip_seq, elev_static) # No cluster passed to FE
        preds_reduced = self.final_pred_conv(extracted_feats) # (B, T_out, H_red, W_red)
        # Upsample to original H, W
        return F.interpolate(preds_reduced, size=(self.H_orig, self.W_orig), mode='bilinear', align_corners=False)

precip_ch_fe=1; unet_base_ch_fe=32; convlstm_in_ch_fe=64; convlstm_hid_ch_fe=128
feature_extractor_inst = TopoRainFeatureExtractor(precip_ch_fe, unet_base_ch_fe, convlstm_in_ch_fe, convlstm_hid_ch_fe)
full_predictor_inst = FeatureExtractorPredictor(
    feature_extractor_inst, feature_extractor_inst.final_feature_channels, 
    OUTPUT_HORIZON, NY_GLOBAL, NX_GLOBAL # Pass original H, W
).to(DEVICE)

optimizer_fe_adam = optim.Adam(full_predictor_inst.parameters(), lr=LR_FEATURE_EXTRACTOR)
criterion_fe_mse = nn.MSELoss()
logger.info("Entrenando TopoRainFeatureExtractor con cabezal predictor...")
fe_hist_data = {"train_loss":[], "val_loss":[]}
best_fe_val_loss = float('inf'); patience_fe_count = 0

for ep_fe in range(MAX_EPOCHS // 2): # Reduced epochs for FE training
    full_predictor_inst.train(); tr_losses_fe_ep = []
    for b_idx, (p_in, e_norm, c_stat, p_out_tgt) in enumerate(train_loader_topo): # c_stat no se usa en FE
        p_in,e_norm,p_out_tgt = p_in.to(DEVICE),e_norm.to(DEVICE),p_out_tgt.to(DEVICE)
        optimizer_fe_adam.zero_grad()
        preds_fe = full_predictor_inst(p_in, e_norm, c_stat) # c_stat es dummy aquí
        loss_fe = criterion_fe_mse(preds_fe, p_out_tgt)
        loss_fe.backward(); optimizer_fe_adam.step(); tr_losses_fe_ep.append(loss_fe.item())
        if b_idx % 20 == 0: logger.debug(f"FE Tr Ep {ep_fe}, B {b_idx}, L: {loss_fe.item():.4f}")
    avg_tr_l_fe = np.mean(tr_losses_fe_ep) if tr_losses_fe_ep else np.nan; fe_hist_data["train_loss"].append(avg_tr_l_fe)

    full_predictor_inst.eval(); val_losses_fe_ep = []
    if val_loader_topo:
        with torch.no_grad():
            for p_in_v,e_norm_v,c_stat_v,p_out_tgt_v in val_loader_topo:
                p_in_v,e_norm_v,p_out_tgt_v = p_in_v.to(DEVICE),e_norm_v.to(DEVICE),p_out_tgt_v.to(DEVICE)
                preds_fe_v = full_predictor_inst(p_in_v,e_norm_v,c_stat_v)
                val_losses_fe_ep.append(criterion_fe_mse(preds_fe_v, p_out_tgt_v).item())
    avg_val_l_fe = np.mean(val_losses_fe_ep) if val_losses_fe_ep else np.nan; fe_hist_data["val_loss"].append(avg_val_l_fe)
    logger.info(f"FE Ep {ep_fe+1} - TrL: {avg_tr_l_fe:.4f} - VaL: {avg_val_l_fe:.4f}")

    current_stop_loss_fe = avg_val_l_fe if not np.isnan(avg_val_l_fe) else avg_tr_l_fe
    if not np.isnan(current_stop_loss_fe) and current_stop_loss_fe < best_fe_val_loss:
        best_fe_val_loss=current_stop_loss_fe; patience_fe_count=0
        torch.save(feature_extractor_inst.state_dict(), MODELS_OUTPUT/"best_feature_extractor.pt")
        logger.info(f"FE guardado ep {ep_fe+1}, ValL: {best_fe_val_loss:.4f}")
    elif not np.isnan(current_stop_loss_fe):
        patience_fe_count+=1
        if patience_fe_count>=PATIENCE: logger.info("ES para FE."); break
    elif np.isnan(current_stop_loss_fe) and ep_fe > PATIENCE: # No val loss, stop after patience epochs
        logger.info("Parando FE (sin val_loss, paciencia en épocas)."); torch.save(feature_extractor_inst.state_dict(), MODELS_OUTPUT/"best_feature_extractor.pt"); break
plot_training_history(fe_hist_data, "Pérdida Entrenamiento Feature Extractor")
if (MODELS_OUTPUT/"best_feature_extractor.pt").exists():
    feature_extractor_inst.load_state_dict(torch.load(MODELS_OUTPUT/"best_feature_extractor.pt",map_location=DEVICE)); logger.info("Mejor FE cargado.")
else: logger.warning("No se encontró best_feature_extractor.pt.")


# %% [markdown]
# ## Level 3: Base Models (Simplified)
# %%
class SimpleBaseModelMLP(nn.Module): # Renamed for clarity
    def __init__(self, C_in, H_in, W_in, hidden_dim, output_horizon, H_out, W_out):
        super().__init__(); self.output_horizon,self.H_out,self.W_out = output_horizon,H_out,W_out
        input_flat_dim = C_in * H_in * W_in
        self.fc1 = nn.Linear(input_flat_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_horizon * H_out * W_out)
    def forward(self, x_feat): # x_feat: (B, C, H_red, W_red)
        B = x_feat.shape[0]; x_flat = x_feat.view(B, -1)
        x = self.relu(self.fc1(x_flat)); preds_flat = self.fc2(x)
        return preds_flat.view(B, self.output_horizon, self.H_out, self.W_out)

for param in feature_extractor_inst.parameters(): param.requires_grad = False
feature_extractor_inst.eval() 

# Dims de salida del FE: feature_extractor_inst.final_feature_channels, ny_reduced_calc, nx_reduced_calc
# (asumiendo que topo_unet_encoder reduce H,W por 4)
C_fe_out = feature_extractor_inst.final_feature_channels
H_fe_red = NY_GLOBAL // 4 
W_fe_red = NX_GLOBAL // 4

base_model_mlp_inst = SimpleBaseModelMLP(C_fe_out, H_fe_red, W_fe_red, hidden_dim=256, 
                                   output_horizon=OUTPUT_HORIZON, H_out=NY_GLOBAL, W_out=NX_GLOBAL).to(DEVICE)
optimizer_bm_adam = optim.Adam(base_model_mlp_inst.parameters(), lr=LR_BASE_MODELS)
criterion_bm_mse = nn.MSELoss()

logger.info("Entrenando SimpleBaseModelMLP...")
bm_hist_data = {"train_loss":[], "val_loss":[]}
best_bm_val_loss = float('inf'); patience_bm_count = 0
base_model_preds_for_meta_dict = {} 
preds_bm_dir = PREDS_DIR / "MLP_base_preds"; preds_bm_dir.mkdir(parents=True,exist_ok=True)

# Extraer features una vez para los DataLoaders del Base Model
X_tr_ext_bm, y_tr_tgt_bm = (None,None); X_val_ext_bm, y_val_tgt_bm = (None,None)
dl_tr_ext_bm, dl_val_ext_bm = [],[]

if len(train_dataset_topo)>0:
    logger.info("Extrayendo features (train) para Base Model...");
    X_tr_ext_bm, y_tr_tgt_bm = get_extracted_features(train_loader_topo, feature_extractor_inst, DEVICE)
    ds_tr_ext_bm = torch.utils.data.TensorDataset(X_tr_ext_bm, y_tr_tgt_bm)
    dl_tr_ext_bm = DataLoader(ds_tr_ext_bm, batch_size=BATCH_SIZE, shuffle=True, **{"num_workers":dl_settings["num_workers"], "pin_memory":dl_settings["pin_memory"]})

if len(val_dataset_topo)>0:
    logger.info("Extrayendo features (val) para Base Model y Meta Model...");
    X_val_ext_bm, y_val_tgt_bm = get_extracted_features(val_loader_topo, feature_extractor_inst, DEVICE)
    ds_val_ext_bm = torch.utils.data.TensorDataset(X_val_ext_bm, y_val_tgt_bm)
    dl_val_ext_bm = DataLoader(ds_val_ext_bm, batch_size=BATCH_SIZE, **{"num_workers":dl_settings["num_workers"], "pin_memory":dl_settings["pin_memory"]})

for ep_bm in range(MAX_EPOCHS // 2):
    base_model_mlp_inst.train(); tr_losses_bm_ep = []
    if not dl_tr_ext_bm: logger.warning("dl_tr_ext_bm vacío, saltando época BM."); break
    for x_e, y_t_bm in dl_tr_ext_bm:
        x_e,y_t_bm = x_e.to(DEVICE),y_t_bm.to(DEVICE); optimizer_bm_adam.zero_grad()
        p_bm = base_model_mlp_inst(x_e); l_bm = criterion_bm_mse(p_bm, y_t_bm)
        l_bm.backward(); optimizer_bm_adam.step(); tr_losses_bm_ep.append(l_bm.item())
    avg_tr_l_bm = np.mean(tr_losses_bm_ep) if tr_losses_bm_ep else np.nan; bm_hist_data["train_loss"].append(avg_tr_l_bm)
    
    # Validation pass (loss only; predictions are exported once after training).
    base_model_mlp_inst.eval()
    val_losses_bm_ep = []
    if dl_val_ext_bm:
        with torch.no_grad():
            for x_e_v, y_t_bm_v in dl_val_ext_bm:
                x_e_v, y_t_bm_v = x_e_v.to(DEVICE), y_t_bm_v.to(DEVICE)
                p_bm_v = base_model_mlp_inst(x_e_v)
                val_losses_bm_ep.append(criterion_bm_mse(p_bm_v, y_t_bm_v).item())
    avg_val_l_bm = np.mean(val_losses_bm_ep) if val_losses_bm_ep else np.nan
    bm_hist_data["val_loss"].append(avg_val_l_bm)
    logger.info(f"BM Ep {ep_bm+1} - TrL: {avg_tr_l_bm:.4f} - VaL: {avg_val_l_bm:.4f}")

    # Early stopping on the validation loss, falling back to the training loss when there is no validation loss.
    current_stop_loss_bm = avg_val_l_bm if not np.isnan(avg_val_l_bm) else avg_tr_l_bm
    if not np.isnan(current_stop_loss_bm) and current_stop_loss_bm < best_bm_val_loss:
        best_bm_val_loss = current_stop_loss_bm
        patience_bm_count = 0
        torch.save(base_model_mlp_inst.state_dict(), MODELS_OUTPUT/"best_mlp_base_model.pt")
    elif not np.isnan(current_stop_loss_bm):
        patience_bm_count += 1
        if patience_bm_count >= PATIENCE:
            logger.info("Early stopping for MLP BM.")
            break
    elif np.isnan(current_stop_loss_bm) and ep_bm > PATIENCE:
        logger.info("Stopping BM training (no val_loss).")
        torch.save(base_model_mlp_inst.state_dict(), MODELS_OUTPUT/"best_mlp_base_model.pt")
        break
plot_training_history(bm_hist_data, "MLP Base Model Training Loss")
if (MODELS_OUTPUT/"best_mlp_base_model.pt").exists():
    base_model_mlp_inst.load_state_dict(torch.load(MODELS_OUTPUT/"best_mlp_base_model.pt", map_location=DEVICE, weights_only=True))
    logger.info("Best MLP BM loaded.")

# Export the validation predictions once, from the restored best weights, so the meta-model
# inputs match the saved checkpoint instead of whatever the last epoch produced.
base_model_mlp_inst.eval()
if dl_val_ext_bm:
    with torch.no_grad():
        for b_idx_v_bm, (x_e_v, _) in enumerate(dl_val_ext_bm):
            p_bm_v_cpu = base_model_mlp_inst(x_e_v.to(DEVICE)).cpu().numpy()
            for i_batch in range(p_bm_v_cpu.shape[0]):
                # The loader is not shuffled, so this recovers the sample's position in the validation set.
                orig_val_idx = b_idx_v_bm*BATCH_SIZE + i_batch
                for h_idx_m in range(OUTPUT_HORIZON):
                    fpath_m = preds_bm_dir/f"pred_val{orig_val_idx}_h{h_idx_m+1}.npy"
                    np.save(fpath_m, p_bm_v_cpu[i_batch, h_idx_m])
                    base_model_preds_for_meta_dict[(orig_val_idx, h_idx_m)] = str(fpath_m)
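
# %% [markdown]
# Sketch only, not part of the original notebook: the early-stopping bookkeeping used in the
# loop above (best loss, patience counter, NaN handling) could be factored into a small helper.
# `EarlyStopper` and its method names are hypothetical.
# %%
import math


class EarlyStopper:
    """Tracks the best loss seen so far and counts epochs without improvement."""

    def __init__(self, patience):
        self.patience = patience
        self.best = math.inf
        self.bad_epochs = 0

    def step(self, loss):
        """Return (improved, should_stop) for this epoch's loss."""
        if math.isnan(loss):
            # A NaN loss is ignored here: neither an improvement nor a strike.
            # (The loop above additionally stops once ep_bm > PATIENCE in that case.)
            return False, False
        if loss < self.best:
            self.best = loss
            self.bad_epochs = 0
            return True, False
        self.bad_epochs += 1
        return False, self.bad_epochs >= self.patience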

# %% [markdown]
# ## Level 4: Simple Meta-Model (using UNetConvLSTMMeta)
# %%
X_meta_paths_final, y_meta_paths_final = [], []
true_tgt_meta_dir = PREDS_DIR/"true_targets_for_meta_toporain"; true_tgt_meta_dir.mkdir(parents=True,exist_ok=True)

meta_model_was_trainable = False # Flag

if y_val_tgt_bm is not None and base_model_preds_for_meta_dict:  # y_val_tgt_bm holds the true targets of the validation set
    num_val_s_meta = y_val_tgt_bm.shape[0]
    for val_s_idx_m in range(num_val_s_meta):
        paths_h_X, paths_h_y = [], []
        valid_sample_for_meta = True
        for h_m_idx in range(OUTPUT_HORIZON):
            pred_path_key_m = (val_s_idx_m, h_m_idx)
            if pred_path_key_m in base_model_preds_for_meta_dict:
                paths_h_X.append([base_model_preds_for_meta_dict[pred_path_key_m]])  # one-element list: a single base model
            else:
                valid_sample_for_meta = False
                logger.warning(f"Missing base prediction for meta: val_s {val_s_idx_m}, h {h_m_idx}")
                break

        if valid_sample_for_meta:
            for h_m_idx in range(OUTPUT_HORIZON):
                # .cpu() is a no-op for CPU tensors and keeps this safe if the targets live on the GPU.
                true_map_m_tgt = y_val_tgt_bm[val_s_idx_m, h_m_idx].detach().cpu().numpy()
                fpath_true_m = true_tgt_meta_dir/f"true_val{val_s_idx_m}_h{h_m_idx+1}.npy"
                np.save(fpath_true_m, true_map_m_tgt)
                paths_h_y.append(str(fpath_true_m))
            X_meta_paths_final.append(paths_h_X)
            y_meta_paths_final.append(paths_h_y)

if not X_meta_paths_final:
    logger.error("Could not prepare data for the meta-model. Skipping.")
else:
    meta_ds_tr = MetaDataset(X_meta_paths_final, y_meta_paths_final, ny_expected=NY_GLOBAL, nx_expected=NX_GLOBAL)
    # For simplicity, the same set is used for training and evaluation here. Ideally it would be split,
    # since metrics computed on the training samples are optimistic.
    meta_dl_tr = DataLoader(meta_ds_tr, shuffle=True, **dl_settings)
    meta_dl_ev = DataLoader(meta_ds_tr, shuffle=False, **dl_settings)

    meta_model_unet_inst = UNetConvLSTMMeta(num_models=1, ny=NY_GLOBAL, nx=NX_GLOBAL, output_horizon=OUTPUT_HORIZON, channels=1).to(DEVICE)
    opt_meta_adam_unet = optim.Adam(meta_model_unet_inst.parameters(), lr=LR_META_MODEL)
    crit_meta_mse_unet = nn.MSELoss()
    logger.info("Training Meta-Model (UNetConvLSTMMeta)...")

    meta_model_unet_inst, meta_hist_data = train_meta_model(meta_model_unet_inst, opt_meta_adam_unet, crit_meta_mse_unet,
                                                            meta_dl_tr, meta_dl_ev, max_epochs=MAX_EPOCHS//2, patience=PATIENCE,
                                                            save_path=MODELS_OUTPUT/"best_toporain_meta.pt")
    plot_training_history(meta_hist_data, "TopoRain Meta-Model Training Loss")
    torch.save(meta_model_unet_inst.state_dict(), MODELS_OUTPUT/"final_toporain_meta.pt")
    meta_model_was_trainable = True  # set flag
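
# %% [markdown]
# Sketch only, not part of the original notebook: the cell above trains and evaluates the
# meta-model on the same samples. One possible alternative, assuming the validation windows are
# stored in chronological order, is to hold out the last fraction of samples for evaluation
# before building the datasets. `split_meta_paths` and `eval_fraction` are hypothetical names.
# %%
def split_meta_paths(x_paths, y_paths, eval_fraction=0.2):
    """Split parallel lists of sample paths into (train, eval) parts without shuffling."""
    if len(x_paths) != len(y_paths):
        raise ValueError("x_paths and y_paths must have the same length")
    n_eval = int(round(len(x_paths) * eval_fraction))
    # Keep at least one sample on each side when there are two or more samples.
    if len(x_paths) >= 2:
        n_eval = min(max(n_eval, 1), len(x_paths) - 1)
    cut = len(x_paths) - n_eval
    return (x_paths[:cut], y_paths[:cut]), (x_paths[cut:], y_paths[cut:])


# Possible usage (left commented out so the original flow is unchanged):
# (X_tr, y_tr), (X_ev, y_ev) = split_meta_paths(X_meta_paths_final, y_meta_paths_final)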

# %% [markdown]
# ## Level 5: Output and Evaluation (Simplified)
# %%
if meta_model_was_trainable and 'meta_model_unet_inst' in locals() and meta_dl_ev:
    logger.info("Evaluating TopoRain Meta-Model...")
    meta_model_unet_inst.eval()
    all_true_m_eval, all_pred_m_eval = [], []
    # Assumes train_dataset_topo.precip_scaler is available and is the scaler applied to the targets.
    scaler_final_eval = train_dataset_topo.precip_scaler
    with torch.no_grad():
        for X_m_b, y_m_b in meta_dl_ev:
            X_m_b = X_m_b.to(DEVICE)  # targets stay on the CPU; they are de-scaled with NumPy below
            y_p_m_b = meta_model_unet_inst(X_m_b)  # (B,T,H,W,C)

            # De-scale predictions and targets.
            # Original shape (B,T,H,W,C=1): squeeze C -> (B,T,H,W) -> reshape (-1,1) for the scaler.
            y_p_m_b_rescaled = scaler_final_eval.inverse_transform(
                y_p_m_b.squeeze(-1).cpu().numpy().reshape(-1, 1)
            ).reshape(y_p_m_b.shape[:-1])  # (B,T,H,W)
            y_m_b_rescaled = scaler_final_eval.inverse_transform(
                y_m_b.squeeze(-1).cpu().numpy().reshape(-1, 1)
            ).reshape(y_m_b.shape[:-1])

            all_true_m_eval.append(y_m_b_rescaled.reshape(-1))
            all_pred_m_eval.append(y_p_m_b_rescaled.reshape(-1))

    if all_true_m_eval:
        y_true_flat_m_final = np.concatenate(all_true_m_eval)
        y_pred_flat_m_final = np.concatenate(all_pred_m_eval)
        rmse_f, mae_f, mape_f, r2_f = evaluate_metrics_np(y_true_flat_m_final, y_pred_flat_m_final)
        logger.info(f"TopoRain Meta-Model metrics (DE-SCALED): RMSE:{rmse_f:.4f}, MAE:{mae_f:.4f}, MAPE:{mape_f:.2f}%, R2:{r2_f:.4f}")
    else:
        logger.warning("No predictions were generated to evaluate the meta-model.")
else:
    logger.warning("Meta-model not trained or evaluation data unavailable. Skipping final evaluation.")
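
# %% [markdown]
# Sketch only, not part of the original notebook: the de-scaling step above flattens each map to
# a single column, applies the scaler, and restores the original shape. Assuming the scaler
# follows the scikit-learn `inverse_transform` convention for a single feature, that pattern
# could be wrapped as below. `inverse_transform_maps` is a hypothetical name.
# %%
import numpy as np


def inverse_transform_maps(scaler, maps):
    """Apply a single-feature scaler's inverse_transform to an array of any shape."""
    maps = np.asarray(maps)
    flat = maps.reshape(-1, 1)  # (N, 1): one column, as a single-feature scaler expects
    return np.asarray(scaler.inverse_transform(flat)).reshape(maps.shape)
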
logger.info("END OF THE PRIORITIZED TOPORAIN-NET NOTEBOOK.")


2025-05-22 19:11:04,445 - INFO - Ejecutando en Colab: False
2025-05-22 19:11:04,447 - INFO - Directorio de trabajo actual: /Users/riperez/Conda/anaconda3/envs/precipitation_prediction/github.com/ml_precipitation_prediction/models
2025-05-22 19:11:04,448 - INFO - Intentando importar PyTorch DESPUÉS de la configuración del entorno...
2025-05-22 19:11:04,449 - INFO - PyTorch importado exitosamente: versión 2.5.1


Python version: 3.12.10 | packaged by conda-forge | (main, Apr 10 2025, 22:19:24) [Clang 18.1.8 ]
Current working directory at script start: /Users/riperez/Conda/anaconda3/envs/precipitation_prediction/github.com/ml_precipitation_prediction/models


NameError: name 'load_and_preprocess_data_main' is not defined
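
The `NameError` above means Python reached a call to `load_and_preprocess_data_main` before any `def` with that name had executed in the running session. The snippet below is a minimal sketch of that behavior, not the notebook's code; `load_data_stub` and `run_pipeline` are hypothetical stand-ins.

```python
def run_pipeline():
    # The name load_data_stub is looked up when this call runs, not when run_pipeline is defined.
    return load_data_stub()


try:
    run_pipeline()
    failed_before_definition = False
except NameError as err:
    failed_before_definition = True
    print(f"before definition: {err}")


def load_data_stub():
    # Hypothetical stand-in for the notebook's data loader.
    return "data loaded"


print(f"after definition: {run_pipeline()}")
```

If the notebook defines the function in an earlier cell, running that cell first (or restarting and running all cells in order) is a likely fix. If the definition is missing from the file, it would need to be restored.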