In [3]:
# ============================================================
# SUITE UNIFICADA (CHIRPS + MODIS PET) — SPI + SPEI + Análisis
#   • SPI (k=1,3,6,12) sobre PR mensual CHIRPS
#   • SPEI (k=1,3,6,12) usando balance hídrico (CHIRPS − MODIS PET)
#   • Validación SPEI (CSIC/GEE vs Local)
#   • SPEI(Local) vs ENSO mensual (ONI)
#   • FIRMS VIIRS 375 m: export mensual (grilla CHIRPS) + visualización CSV
#   • Métricas globales (punto-biserial + AUC/ROC)
#   • Correlación geográficamente ponderada (GWSS) — unificado (ENSO opcional)
#   • Grilla 1:1 alineada a CHIRPS (~0.05° ≈ 5 km)
#   • Interfaz Gradio (2 pestañas)
# ============================================================

# =================== 0) INSTALACIÓN =========================
!pip -q install "gradio==5.35.0" pandas numpy scipy matplotlib \
                 earthengine-api geemap==0.30.2 rpy2 rasterio \
                 geopandas shapely pyproj fiona scikit-learn

# =================== 1) IMPORTS =============================
import os, re, io, time, shutil, pathlib, traceback, itertools, warnings, datetime as dt
import numpy as np
import pandas as pd
import gradio as gr
import matplotlib as mpl
import matplotlib.pyplot as plt
import math
from matplotlib import colors
from matplotlib.colors import TwoSlopeNorm
from dataclasses import dataclass
from pathlib import Path
from shapely.ops import unary_union

# GIS / EE / Raster
import ee, geemap, geopandas as gpd, rasterio
from rasterio.features import rasterize
from rasterio.transform import from_origin

# R (rpy2)
import rpy2.robjects as ro

# ML/Stats
from scipy.stats import pointbiserialr
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.neighbors import BallTree

warnings.filterwarnings("ignore")

# =================== 1b) FONT (optional Palatino) =========
import matplotlib.font_manager as fm
# Palatino TTF expected on the mounted Google Drive (used only if the file exists).
PALATINO_PATH = "/content/drive/My Drive/fonts/palatino-linotype.ttf"
# Best effort: in Colab, authenticate the user and mount Google Drive at /content/drive.
# Outside Colab (or on any failure) this step is silently skipped.
try:
    from google.colab import auth as gcolab_auth, drive
    gcolab_auth.authenticate_user()
    drive.mount("/content/drive")
except Exception:
    pass

# Pick the serif font name: register the Drive TTF when present (needs the mount
# above), otherwise request the "Palatino Linotype" family by name; any error falls
# back to the generic "serif".
try:
    if os.path.exists(PALATINO_PATH):
        fm.fontManager.addfont(PALATINO_PATH)
        PAL = fm.FontProperties(fname=PALATINO_PATH).get_name()
    else:
        PAL = "Palatino Linotype"
except Exception:
    PAL = "serif"

# Global matplotlib style for every figure produced below.
mpl.rcParams.update({
    "font.family":  "serif",
    "font.serif":   [PAL],
    "axes.titleweight": "semibold",
    "axes.labelsize": 12,
    "axes.titlesize": 13,
    "xtick.labelsize": 11,
    "ytick.labelsize": 11,
    "figure.dpi": 120,
    "mathtext.fontset": "stix",
})

# =================== 1c) EE INIT ============================
# Initialise Earth Engine against the configured Cloud project; if that fails for any
# reason, retry with the default project of the active credentials.
PROJECT_ID = "ee-example"  # set your Cloud project ID here
try:
    ee.Initialize(project=PROJECT_ID)
except Exception:
    ee.Initialize()

# =================== 2) CONSTANTS AND DIRS ===================
# Output folders on the mounted Google Drive; created at import time if missing.
DRIVE_DIR_SPI  = "/content/drive/My Drive/SPI_NIFT_outputs"
DRIVE_DIR_SPEI = "/content/drive/My Drive/SPEI_outputs_CHIRPS_MODIS"
os.makedirs(DRIVE_DIR_SPI, exist_ok=True)
os.makedirs(DRIVE_DIR_SPEI, exist_ok=True)

CSV_SPI_PR     = "Grid5k_Mean_Prec.csv"            # monthly PR per grid cell (CHIRPS)
CSV_SPEI_BAL   = "Grid5k_WaterBalance_PRmPET.csv"  # monthly (PR − PET) per grid cell (CHIRPS − MODIS)
FIRMS_COLLECTION = 'FIRMS'                         # FIRMS VIIRS 375 m collection ID (account-specific per original note)
EXPORT_FOLDER    = 'SPEI_outputs_CHIRPS_MODIS'     # Drive folder name for exports

# =================== 3) REGIÓN + GRILLA (CHIRPS) =============
def _get_regions():
    """Return the study-area features and their dissolved geometry.

    Returns
    -------
    (ee.FeatureCollection, ee.Geometry)
        GAUL 2015 level-1 features whose ADM1_NAME is Cundinamarca or Boyaca, and
        the geometry of their union.
    """
    wanted = ["Cundinamarca", "Boyaca"]
    gaul_level1 = ee.FeatureCollection("FAO/GAUL/2015/level1")
    departments = gaul_level1.filter(ee.Filter.inList("ADM1_NAME", wanted))
    dissolved = departments.union().geometry()
    return departments, dissolved

def _chirps_proj():
    """Return the projection of the first image of the CHIRPS daily collection."""
    first_day = ee.ImageCollection("UCSB-CHG/CHIRPS/DAILY").first()
    return first_day.projection()

def _make_grid(region):
    """Build a vector grid whose cells are the native CHIRPS pixels covering *region*.

    Each cell gets a ``cell_id`` string "<LON>_<LAT>": the cell centroid longitude
    and latitude (centroid maxError 1 m) multiplied by 1e4, rounded, and formatted
    with '%08d'. This id is the join key used by every per-cell CSV.
    """
    def _tag_cell(cell):
        xy = cell.geometry().centroid(1).coordinates()
        lon_code = ee.Number(xy.get(0)).multiply(1e4).round()
        lat_code = ee.Number(xy.get(1)).multiply(1e4).round()
        key = lon_code.format('%08d').cat('_').cat(lat_code.format('%08d'))
        return cell.set('cell_id', key)

    cells = region.coveringGrid(_chirps_proj())
    return cells.map(_tag_cell)

def _chirps_deg_res():
    """Return the CHIRPS pixel size as (|x scale|, |y scale|).

    Fetches the 6-element affine ``transform`` of the CHIRPS projection from Earth
    Engine (one server round-trip) and keeps its x- and y-scale terms.
    """
    proj_info = _chirps_proj().getInfo()
    x_scale, _shear_x, _shift_x, _shear_y, y_scale, _shift_y = proj_info['transform']
    return abs(x_scale), abs(y_scale)

# =================== 4) COLECCIONES MENSUALES =================
def _chirps_monthly_ic(start_date, end_date):
    """Build a monthly CHIRPS precipitation ImageCollection (single band 'pr').

    Images are produced for monthly steps starting at ``start_date`` up to the month
    containing ``end_date``. Each image is the sum of the daily CHIRPS 'precipitation'
    band over [step_start, step_start + 1 month), with ``system:time_start`` set to
    the step start.
    """
    first = ee.Date(start_date)
    stop = ee.Date(end_date).advance(1, 'month')
    last_offset = stop.difference(first, 'month').subtract(1).int()
    chirps_daily = ee.ImageCollection("UCSB-CHG/CHIRPS/DAILY")

    def _month_total(offset):
        month_start = first.advance(offset, 'month')
        month_end = month_start.advance(1, 'month')
        total = chirps_daily.filterDate(month_start, month_end).select('precipitation').sum()
        return total.rename('pr').set('system:time_start', month_start.millis())

    offsets = ee.List.sequence(0, last_offset)
    return ee.ImageCollection.fromImages(offsets.map(_month_total))

_MODIS_PET_ID = None


def _pick_modis_pet_id():
    """Return (and cache) the first MOD16A2* collection ID that Earth Engine can open.

    Candidates are tried in order (061 gap-filled, 061, 006 gap-filled, 006); a
    candidate is accepted when ``first().getInfo()`` does not raise. The choice is
    cached in the module-level ``_MODIS_PET_ID``.

    Raises ValueError when none of the candidates can be opened.
    """
    global _MODIS_PET_ID
    if not _MODIS_PET_ID:
        candidates = ("MODIS/061/MOD16A2GF", "MODIS/061/MOD16A2",
                      "MODIS/006/MOD16A2GF", "MODIS/006/MOD16A2")
        for collection_id in candidates:
            try:
                ee.ImageCollection(collection_id).first().getInfo()
            except Exception:
                continue
            _MODIS_PET_ID = collection_id
            break
        else:
            raise ValueError("No se pudo abrir MOD16A2* en EE.")
    return _MODIS_PET_ID

def _modis_pet_monthly_ic(start_date, end_date, scale_factor=0.1, valid_max=32700):
    """Build a monthly MODIS PET ImageCollection (single band 'pet').

    Uses the collection chosen by ``_pick_modis_pet_id``. For each monthly step
    starting at ``start_date`` (up to the month containing ``end_date``), the 8-day
    'PET' composites dated inside [step_start, step_start + 1 month) are summed and
    multiplied by ``scale_factor`` (0.1 turns the stored integers into kg/m² ≈ mm).

    Parameters
    ----------
    valid_max : int or None
        Raw PET values above this are masked *before* summing. MOD16A2 encodes
        non-vegetated pixels (water, urban, barren, snow/ice, wetland, unclassified,
        fill) as 32761–32767; without the mask they would be added as thousands of
        mm of PET. Pass ``None`` to disable masking.

    Notes
    -----
    Each composite is attributed entirely to the month of its start date, so
    composites straddling a month boundary bias individual monthly totals.
    NOTE(review): months with no composite at all (e.g. before the MODIS record)
    presumably yield band-less images and fail at ``rename`` — confirm date ranges.
    """
    cid   = _pick_modis_pet_id()
    start = ee.Date(start_date)
    end   = ee.Date(end_date).advance(1, 'month')
    n_months = end.difference(start, 'month').subtract(1).int()
    months = ee.List.sequence(0, n_months)
    ic_pet = ee.ImageCollection(cid).select('PET')  # stored as 0.1 kg/m² per 8-day composite

    def _mask_fill(img):
        # Drop the MOD16 land-cover/fill codes (> valid_max) pixel by pixel.
        return img.updateMask(img.lte(valid_max))

    def _month_img(m):
        mstart = start.advance(m, 'month')
        mend   = mstart.advance(1, 'month')
        composites = ic_pet.filterDate(mstart, mend)
        if valid_max is not None:
            composites = composites.map(_mask_fill)
        img = composites.sum().multiply(scale_factor)
        return img.rename('pet').set('system:time_start', mstart.millis())

    return ee.ImageCollection.fromImages(months.map(_month_img))

def _stack_ic(ic, band):
    """Stack a time-series collection into one multi-band image.

    Each image's ``band`` is renamed to its date as "yyyyMM", then ``toBands`` merges
    all images. ``toBands`` also prefixes each band with the image ID (presumably
    "<index>_yyyyMM" for collections built with ``fromImages`` — confirm).
    """
    def _name_by_month(img):
        return img.select([0], [img.date().format("yyyyMM")])

    only_band = ic.select(band)
    renamed = only_band.map(_name_by_month)
    return renamed.toBands()

# =================== 5) EXPORTS A DRIVE =======================
def export_chirps_pr(start_date, end_date):
    """SPI step 1: start a Drive export of monthly CHIRPS precipitation per grid cell.

    The monthly CHIRPS stack is averaged over every CHIRPS grid cell (CHIRPS scale and
    projection) and exported as ``CSV_SPI_PR``. The export runs asynchronously as an
    Earth Engine task.

    Returns a status string; exceptions are reported in it rather than raised.
    """
    try:
        first_day = pd.to_datetime(start_date).to_pydatetime()
        last_day = pd.to_datetime(end_date).to_pydatetime()
        if first_day > last_day:
            return "❌ La fecha inicial es posterior a la final."

        _regions, roi = _get_regions()
        cells = _make_grid(roi)
        monthly_stack = _stack_ic(_chirps_monthly_ic(start_date, end_date), 'pr')
        native = _chirps_proj()
        per_cell = monthly_stack.reduceRegions(
            cells, ee.Reducer.mean(), scale=native.nominalScale(), crs=native)

        task = ee.batch.Export.table.toDrive(
            collection=per_cell,
            description="Export_PR_CHIRPS",
            fileNamePrefix=CSV_SPI_PR[:-4],
            fileFormat="CSV",
        )
        task.start()
        return ("⏳ Exportación iniciada (CHIRPS mensual).\n"
                f"Al finalizar tendrás **{CSV_SPI_PR}** en tu Drive.")
    except Exception:
        return f"❌ Error exportando PR:\n{traceback.format_exc()}"

def export_pr_minus_pet(start_date, end_date):
    """SPEI step 1: start a Drive export of monthly (PR_mean − PET_mean) per grid cell.

    PR is the monthly CHIRPS stack averaged over each CHIRPS grid cell at CHIRPS
    scale; PET is the monthly MODIS stack averaged over the same cell at MODIS scale.
    The two per-cell tables are inner-joined on ``cell_id`` and, for every month band
    present in both stacks, PR − PET is stored as a property named after the band.
    The export (``CSV_SPEI_BAL``) runs asynchronously as an Earth Engine task.

    Returns a status string; exceptions are reported in it rather than raised.
    """
    try:
        sd, ed = map(lambda s: pd.to_datetime(s).to_pydatetime(), (start_date, end_date))
        if sd > ed: return "❌ La fecha inicial es posterior a la final."
        _, ROI  = _get_regions(); grid_fc = _make_grid(ROI)
        ic_pr  = _chirps_monthly_ic(start_date, end_date)          # 'pr'
        ic_pet = _modis_pet_monthly_ic(start_date, end_date, 0.1)  # 'pet'
        pr_img  = _stack_ic(ic_pr,  'pr');  pet_img = _stack_ic(ic_pet, 'pet')

        # Month bands present in both stacks (the balance is only defined there).
        pr_bands  = pr_img.bandNames()
        pet_bands = pet_img.bandNames()
        common_bands = ee.List(pr_bands.map(lambda b: ee.Algorithms.If(pet_bands.contains(b), b, None))).removeAll([None])

        chirps_proj  = _chirps_proj();    chirps_scale = chirps_proj.nominalScale()
        modis_proj   = ee.ImageCollection(_pick_modis_pet_id()).first().projection()
        modis_scale  = modis_proj.nominalScale()

        # Per-cell means, each computed in its own native projection/scale.
        fc_pr = pr_img.reduceRegions(grid_fc, ee.Reducer.mean(), scale=chirps_scale, crs=chirps_proj, tileScale=2)
        fc_pet = pet_img.reduceRegions(grid_fc, ee.Reducer.mean(), scale=modis_scale,  crs=modis_proj,  tileScale=2)

        # Pair PR and PET features of the same cell.
        joined = ee.Join.inner().apply(primary=fc_pr, secondary=fc_pet,
                                       condition=ee.Filter.equals(leftField='cell_id', rightField='cell_id'))

        def _to_balance(jfeat):
            # Build a geometry-less feature: cell_id + one PR−PET property per common band.
            left  = ee.Feature(jfeat.get('primary'))
            right = ee.Feature(jfeat.get('secondary'))
            ldict, rdict = ee.Dictionary(left.toDictionary()), ee.Dictionary(right.toDictionary())
            # NOTE(review): if ee.List.map drops null results, a band missing on either
            # side makes `diffs` shorter than `common_bands` and fromLists would fail —
            # confirm behaviour for cells with fully masked PET.
            diffs = common_bands.map(lambda b: ee.Algorithms.If(
                                      ldict.contains(b),
                                      ee.Algorithms.If(rdict.contains(b),
                                                       ee.Number(ldict.get(b)).subtract(ee.Number(rdict.get(b))), None),
                                      None))
            diff_dict = ee.Dictionary.fromLists(common_bands, diffs)
            return ee.Feature(None, {'cell_id': left.get('cell_id')}).set(diff_dict)

        fc_bal = ee.FeatureCollection(joined.map(_to_balance))

        ee.batch.Export.table.toDrive(
            collection     = fc_bal,
            description    = "Export_WBAL_CHIRPS_MODIS_mean_minus_mean",
            fileNamePrefix = CSV_SPEI_BAL[:-4],
            fileFormat     = "CSV"
        ).start()

        return ("⏳ Exportación iniciada (PR_mean − PET_mean mensual).\n"
                f"Al finalizar tendrás **{CSV_SPEI_BAL}** en tu Drive.")
    except Exception:
        return f"❌ Error exportando (PR−PET):\n{traceback.format_exc()}"

# =================== 6) SPI y SPEI con R =====================
def compute_spi_from_csv(k_list_str="1,3,6,12"):
    """SPI step 2: read monthly precipitation (``CSV_SPI_PR``) and compute SPI_k as long CSVs.

    Reads ``/content/drive/My Drive/{CSV_SPI_PR}`` (one row per grid cell with a
    ``cell_id`` column plus one column per month whose name contains YYYYMM). It then
    runs ``SPEI::spi`` (Gamma) in R for each k in ``k_list_str`` (comma-separated
    integers). Finally it writes ``SPI_{k}_month.csv`` (columns cell_id, date, spi) to
    the working directory and copies it to ``DRIVE_DIR_SPI``.

    Returns a status message (✔️ on success, ❌ on failure); errors are not raised.
    """
    try:
        csv_path = f"/content/drive/My Drive/{CSV_SPI_PR}"
        if not Path(csv_path).exists():
            return "❌ CSV no encontrado; completa antes el Paso 1 (PR)."

        # Parse k up front. Otherwise an empty list would run R over nothing and still
        # report success with no output files.
        k_vals = [int(k) for k in str(k_list_str).split(',') if k.strip().isdigit()]
        if not k_vals:
            return "❌ k_list_str vacío o inválido."

        df_raw = pd.read_csv(csv_path)
        df_raw.columns = [str(c).strip() for c in df_raw.columns]

        # Map every column holding a plausible YYYYMM token to exactly "YYYYMM".
        rename_map = {}
        for c in df_raw.columns:
            m = re.search(r'(?<!\d)(\d{4})(\d{2})(?!\d)', c)
            if m:
                y, mo = int(m.group(1)), int(m.group(2))
                if 1900 <= y <= 2100 and 1 <= mo <= 12:
                    rename_map[c] = f"{y}{mo:02d}"
        if not rename_map:
            return "❌ El CSV no trae columnas con patrón YYYYMM."
        band_cols = sorted(set(rename_map.values()))
        if len(band_cols) != len(rename_map):
            # Two source columns for one month would desynchronise the time axis.
            return "❌ Hay columnas duplicadas para un mismo mes YYYYMM."
        if "cell_id" not in df_raw.columns:
            return "❌ Falta la columna 'cell_id' en el CSV de PR."
        df_raw = df_raw.rename(columns=rename_map)

        cell_ids  = df_raw["cell_id"].astype(str).tolist()
        n_cells   = len(cell_ids)
        idx_map   = pd.DataFrame({"idx": range(n_cells), "cell_id": cell_ids})
        idx_map.to_csv("idx2cellid.csv", index=False)

        # R input: rows = months (chronological), columns = cells.
        matrix = df_raw[band_cols].to_numpy().T
        pd.DataFrame(matrix, columns=range(n_cells)).to_csv("SPI_Input.csv", index=False)
        first_band  = band_cols[0]; start_year, start_month = int(first_band[:4]), int(first_band[4:6])
        r_code = f'''
        if (!requireNamespace("SPEI", quietly=TRUE))
          install.packages("SPEI", repos="https://cran.rstudio.com");
        if (!requireNamespace("zoo", quietly=TRUE))
          install.packages("zoo", repos="https://cran.rstudio.com");
        if (!requireNamespace("data.table", quietly=TRUE))
          install.packages("data.table", repos="https://cran.rstudio.com");
        library(SPEI); library(zoo); library(data.table);
        mat <- as.matrix(read.csv("SPI_Input.csv", header=TRUE))
        tsm <- ts(mat, start=c({start_year},{start_month}), frequency=12)
        for (k in c({','.join(map(str,k_vals))})) {{
          spi_k <- spi(tsm, k, distribution="Gamma", na.rm=TRUE)
          m_out <- t(as.matrix(spi_k$fitted)); m_out[!is.finite(m_out)] <- NA
          fwrite(as.data.table(m_out), file=sprintf("SPI_%d_month_raw.csv", k),
                 sep=",", quote=FALSE, col.names=FALSE)
        }}
        '''
        ro.r(r_code)

        # Raw R output is cells × months; reshape to long (cell_id, date, spi).
        date_map = [f"{b[:4]}-{b[4:]}" for b in band_cols]
        out_files = []
        for k in k_vals:
            raw_path = f"SPI_{k}_month_raw.csv"
            if not Path(raw_path).exists():
                continue
            mat = pd.read_csv(raw_path, header=None).values
            df  = pd.DataFrame(mat, columns=date_map[:mat.shape[1]])
            df.insert(0, "idx", range(n_cells))
            df = (df.merge(idx_map, on="idx").drop(columns="idx")
                    .melt(id_vars="cell_id", var_name="date", value_name="spi"))
            out_csv = f"SPI_{k}_month.csv"
            df.to_csv(out_csv, index=False); shutil.copy(out_csv, DRIVE_DIR_SPI); out_files.append(out_csv)
        return ("✔️ SPI calculado y exportado:\n‣ " + "\n‣ ".join(out_files) +
                f"\nArchivos en {DRIVE_DIR_SPI}")
    except Exception as e:
        return f"❌ Error SPI: {e}\n{traceback.format_exc()}"

def compute_spei_from_csv(k_list_str="1,3,6,12"):
    """
    SPEI step 2: read the monthly (PR − PET) balance (CSV_SPEI_BAL) and compute SPEI_k as long CSVs.

    For each k in ``k_list_str`` (comma-separated integers) this runs ``SPEI::spei`` in R
    on the time×site matrix. It writes ``SPEI_{k}_month.csv`` (columns cell_id, date, spei)
    to the working directory and copies it to DRIVE_DIR_SPEI. It also writes the fitted
    log-logistic parameters (xi, alpha, kappa) per calendar month (m01..m12) to
    ./SPEI_params/. Finally it calls export_spei_params_geotiffs to write the 36-band
    GeoTIFFs to Drive.

    Returns a status message; failures are reported in the message, not raised.
    """
    try:
        # 1) Locate the water-balance CSV in Drive
        csv_path = f"/content/drive/My Drive/{CSV_SPEI_BAL}"
        if not Path(csv_path).exists():
            return "❌ CSV no encontrado; completa antes el Paso 1 (PR−PET)."

        # 2) Read the CSV and normalise month column names to YYYYMM
        df_raw = pd.read_csv(csv_path)
        df_raw.columns = [str(c).strip() for c in df_raw.columns]

        # Columns shaped "dddddd_dddddd" keep their first part; otherwise the first
        # plausible YYYYMM substring of each column name is used.
        band_cols_raw = [c for c in df_raw.columns if re.fullmatch(r"\d{6}_\d{6}", str(c))]
        rename_map = {c: c.split("_")[0] for c in band_cols_raw} if band_cols_raw else {}
        if not rename_map:
            for c in df_raw.columns:
                m = re.search(r'(?<!\d)(\d{4})(\d{2})(?!\d)', str(c))
                if m:
                    y, mo = int(m.group(1)), int(m.group(2))
                    if 1900 <= y <= 2100 and 1 <= mo <= 12:
                        rename_map[c] = f"{y}{mo:02d}"
        df_raw = df_raw.rename(columns=rename_map)
        if "cell_id" not in df_raw.columns:
            return "❌ Falta la columna 'cell_id' en el CSV de balance."
        band_cols = sorted({c for c in df_raw.columns if re.fullmatch(r"\d{6}", str(c))})
        if not band_cols:
            return "❌ No se detectaron columnas YYYYMM en el CSV de balance."

        # 3) Cell index map (also read later by export_spei_params_geotiffs) and the
        #    time×site matrix handed to R
        cell_ids = df_raw["cell_id"].astype(str).tolist()
        n_cells  = len(cell_ids)
        idx_map  = pd.DataFrame({"idx": range(n_cells), "cell_id": cell_ids})
        idx_map.to_csv("SPEI_idx2cellid.csv", index=False)

        matrix = df_raw[band_cols].to_numpy().T  # rows → time, columns → sites
        pd.DataFrame(matrix, columns=range(n_cells)).to_csv("SPEI_Input.csv", index=False)

        first_band  = band_cols[0]
        start_year, start_month = int(first_band[:4]), int(first_band[4:6])

        # 4) Parse k
        k_vals = [int(k) for k in str(k_list_str).split(',') if str(k).strip().isdigit()]
        if not k_vals:
            return "❌ k_list_str vacío o inválido."

        # 5) Run R: SPEI plus export of the per-month log-logistic coefficients.
        # NOTE(review): get_param_mat takes the month axis to be the first dimension
        # equal to ts_freq (12). With exactly 12 sites the site and month axes can be
        # confused — confirm if very small grids are used.
        r_code = f'''
        if (!requireNamespace("SPEI", quietly=TRUE))
          install.packages("SPEI", repos="https://cran.rstudio.com");
        if (!requireNamespace("zoo", quietly=TRUE))
          install.packages("zoo", repos="https://cran.rstudio.com");
        if (!requireNamespace("data.table", quietly=TRUE))
          install.packages("data.table", repos="https://cran.rstudio.com");
        library(SPEI); library(zoo); library(data.table);

        mat <- as.matrix(read.csv("SPEI_Input.csv", header=TRUE))
        tsm <- ts(mat, start=c({start_year},{start_month}), frequency=12)
        ts_freq <- frequency(tsm)

        dir.create("SPEI_params", showWarnings=FALSE)

        get_param_mat <- function(coe, p, ts_freq) {{
          # Devuelve matriz sitios × ts_freq (m01..m12) para 'p', robusto a dims.
          if (is.null(coe)) return(NULL)

          # Array 3D
          if (is.array(coe) && length(dim(coe)) == 3) {{
            dims <- dim(coe)
            par_dim <- 1L
            month_dim <- if (dims[2] == ts_freq) 2L else if (dims[3] == ts_freq) 3L else {{
              w <- which(dims == 12L); if (length(w)) w[[1]] else 3L
            }}
            site_dim <- setdiff(1:3, c(par_dim, month_dim))[1]

            p_idx <- if (!is.null(dimnames(coe)) && !is.null(dimnames(coe)[[par_dim]])) {{
              which(dimnames(coe)[[par_dim]] == p)
            }} else if (p == "xi") 1L else if (p == "alpha") 2L else 3L
            if (length(p_idx) == 0) p_idx <- 1L

            a <- switch(par_dim,
              `1` = coe[p_idx,,, drop=FALSE],
              `2` = coe[,p_idx,, drop=FALSE],
              `3` = coe[,,p_idx, drop=FALSE]
            )

            adims <- dim(a)
            amonth <- which(adims == ts_freq)[1]; if (is.na(amonth)) amonth <- 3L
            asite  <- setdiff(1:3, c(1L, amonth))[1]
            a <- aperm(a, c(asite, amonth, setdiff(1:3, c(asite, amonth))))
            m <- drop(a)
            if (is.null(dim(m))) m <- matrix(m, nrow=1, byrow=TRUE)
            colnames(m) <- sprintf("m%02d", seq_len(ncol(m)))
            return(m)
          }}

          # Matriz 2D
          if (is.matrix(coe)) {{
            rnames <- rownames(coe); cnames <- colnames(coe)
            if (nrow(coe) == 3L) {{
              pid <- if (!is.null(rnames)) which(rnames == p) else {{
                if (p == "xi") 1L else if (p == "alpha") 2L else 3L }}
              v <- coe[pid, , drop=TRUE]
              m <- matrix(as.numeric(v), nrow=1)
            }} else if (ncol(coe) == 3L) {{
              pid <- if (!is.null(cnames)) which(cnames == p) else {{
                if (p == "xi") 1L else if (p == "alpha") 2L else 3L }}
              v <- coe[, pid, drop=TRUE]
              m <- matrix(as.numeric(v), nrow=1, byrow=TRUE)
            }} else {{
              v <- as.numeric(coe); if (length(v) %% ts_freq != 0L) v <- v[seq_len(ts_freq)]
              m <- matrix(v, nrow=1, byrow=TRUE)
            }}
            colnames(m) <- sprintf("m%02d", seq_len(ncol(m)))
            return(m)
          }}

          # Vector
          v <- as.numeric(coe); if (length(v) %% ts_freq != 0L) v <- v[seq_len(ts_freq)]
          m <- matrix(v, nrow=1, byrow=TRUE)
          colnames(m) <- sprintf("m%02d", seq_len(ncol(m)))
          return(m)
        }}

        for (k in c({','.join(map(str, k_vals))})) {{
          spei_k <- spei(tsm, k, na.rm=TRUE)
          m_out <- t(as.matrix(spei_k$fitted)); m_out[!is.finite(m_out)] <- NA
          fwrite(as.data.table(m_out), file=sprintf("SPEI_%d_month_raw.csv", k),
                 sep=",", quote=FALSE, col.names=FALSE)

          coe <- spei_k$coefficients
          if (!is.null(coe)) {{
             xi_mat    <- get_param_mat(coe, "xi",    ts_freq)
             alpha_mat <- get_param_mat(coe, "alpha", ts_freq)
             kappa_mat <- get_param_mat(coe, "kappa", ts_freq)

             if (!is.null(xi_mat))    fwrite(as.data.table(xi_mat),    file=sprintf("SPEI_params/SPEI_params_k%d_xi.csv",    k), sep=",", quote=FALSE, col.names=TRUE)
             if (!is.null(alpha_mat)) fwrite(as.data.table(alpha_mat), file=sprintf("SPEI_params/SPEI_params_k%d_alpha.csv", k), sep=",", quote=FALSE, col.names=TRUE)
             if (!is.null(kappa_mat)) fwrite(as.data.table(kappa_mat), file=sprintf("SPEI_params/SPEI_params_k%d_kappa.csv", k), sep=",", quote=FALSE, col.names=TRUE)
          }}
        }}
        '''
        ro.r(r_code)

        # 6) Convert the site×time SPEI matrices to long CSVs and copy them to Drive
        date_map = [f"{b[:4]}-{b[4:]}" for b in band_cols]
        out_files = []
        for k in k_vals:
            raw_path = f"SPEI_{k}_month_raw.csv"
            if not Path(raw_path).exists():
                continue
            mat = pd.read_csv(raw_path, header=None).values  # sites × time
            df  = pd.DataFrame(mat, columns=date_map[:mat.shape[1]])
            df.insert(0, "idx", range(n_cells))
            df = (df.merge(idx_map, on="idx").drop(columns="idx")
                    .melt(id_vars="cell_id", var_name="date", value_name="spei"))
            out_csv = f"SPEI_{k}_month.csv"
            df.to_csv(out_csv, index=False)
            shutil.copy(out_csv, os.path.join(DRIVE_DIR_SPEI, os.path.basename(out_csv)))
            out_files.append(out_csv)

        # 7) Write the 36-band parameter GeoTIFFs to Drive via export_spei_params_geotiffs.
        #    Only a missing definition is handled here; any other error from that call
        #    propagates to the outer handler below.
        tiff_msg = ""
        try:
            tiff_msg = export_spei_params_geotiffs(k_list_str)
        except NameError:
            tiff_msg = "⚠️ Define export_spei_params_geotiffs() antes de llamar a compute_spei_from_csv (ejecuta la celda donde está definida)."

        # 8) Final message. The SPEI_params check only tests that the directory exists,
        #    so it may also reflect files left by an earlier run.
        msg = ("✔️ SPEI calculado y exportado:\n‣ " + "\n‣ ".join(out_files)
               + f"\nCarpeta CSV: {DRIVE_DIR_SPEI}")
        if Path("SPEI_params").exists():
            msg += "\n✔️ Parámetros (xi, alpha, kappa) por mes guardados en ./SPEI_params/"
        if tiff_msg:
            msg += f"\n{tiff_msg}"
        return msg

    except Exception as e:
        return f"❌ Error SPEI: {e}\n{traceback.format_exc()}"

# =================== 7) RASTERIZAR / PLOT ====================
def _aligned_transform(bounds):
    """Return ``(transform, width, height)`` for a north-up raster at CHIRPS resolution.

    ``bounds`` is ``(xmin, ymin, xmax, ymax)``. Width and height are the rounded
    number of CHIRPS pixels spanning the bounds; the origin is the upper-left corner.
    """
    west, south, east, north = bounds
    dx, dy = _chirps_deg_res()
    n_cols = int(np.round((east - west) / dx))
    n_rows = int(np.round((north - south) / dy))
    return from_origin(west, north, dx, dy), n_cols, n_rows

def rasterize_grid_csv(csv_long, target_date, out_tif, value_col, nodata_val=-9999.0):
    """Burn one date of a long per-cell CSV onto the CHIRPS-aligned grid as a GeoTIFF.

    Rows of ``csv_long`` with ``date == target_date`` are left-joined onto the grid
    cells by ``cell_id``. Cells without a value, and pixels outside every cell, get
    ``nodata_val``. Output: single-band float32 EPSG:4326 GeoTIFF at ``out_tif``.

    Raises ValueError when ``target_date`` has no rows in the CSV.
    """
    selected = pd.read_csv(csv_long).query("date == @target_date")
    if selected.empty:
        raise ValueError(f"Fecha {target_date} no encontrada en {Path(csv_long).name}.")

    _regions, roi = _get_regions()
    cells = geemap.ee_to_gdf(_make_grid(roi)).merge(selected, on='cell_id', how='left')
    burn_values = cells[value_col].astype('float32').fillna(nodata_val)

    transform, width, height = _aligned_transform(cells.total_bounds)
    grid = rasterize(zip(cells.geometry, burn_values),
                     out_shape=(height, width), transform=transform,
                     fill=nodata_val, dtype="float32")

    profile = dict(driver="GTiff", height=height, width=width, count=1,
                   dtype="float32", crs="EPSG:4326", transform=transform,
                   nodata=nodata_val)
    with rasterio.open(out_tif, "w", **profile) as dst:
        dst.write(grid, 1)

def plot_index_map(csv_dir, stem, k, date_str, value_col, title, drive_out_dir):
    """Rasterize one date of ``{csv_dir}/{stem}_{k}_month.csv``, copy it to Drive and plot it.

    The GeoTIFF is written to ``/content/{stem}{k}_{YYYYMM}.tif`` by
    ``rasterize_grid_csv`` and copied into ``drive_out_dir``. SPI/SPEI columns use a
    diverging "RdBu" map centred on 0 over [-3, 3]; any other column uses "jet".
    The department boundaries are overlaid on the map.

    Returns
    -------
    (matplotlib.figure.Figure | None, str)
        The map and a status message. The figure is ``None`` when the CSV is missing.
        On any other failure, a figure showing the error text is returned; errors are
        not raised.

    Figures are closed in pyplot before returning, so repeated UI calls do not
    accumulate open figures; the returned Figure object can still be rendered.
    """
    try:
        csv_long = f"{csv_dir}/{stem}_{k}_month.csv"
        # Cheap local check first, before any Earth Engine round-trip.
        if not Path(csv_long).exists():
            return None, f"Archivo no encontrado: {Path(csv_long).name}"
        region_fc, _ = _get_regions()
        gdf_regions  = geemap.ee_to_gdf(region_fc)

        tif_path = f"/content/{stem}{k}_{date_str.replace('-', '')}.tif"
        rasterize_grid_csv(csv_long, date_str, tif_path, value_col)
        shutil.copy(tif_path, os.path.join(drive_out_dir, os.path.basename(tif_path)))

        with rasterio.open(tif_path) as src:
            data = np.ma.masked_equal(src.read(1), src.nodata)
            extent = [src.bounds.left, src.bounds.right, src.bounds.bottom, src.bounds.top]

        is_index = value_col.lower() in ("spi", "spei")
        norm = TwoSlopeNorm(vmin=-3, vcenter=0, vmax=3) if is_index else None
        fig, ax = plt.subplots(figsize=(8, 6))
        im = ax.imshow(data, cmap="RdBu" if is_index else "jet", norm=norm, extent=extent)
        gpd.GeoSeries(unary_union(gdf_regions.boundary.geometry)).plot(
            ax=ax, edgecolor="black", linewidth=0.7, facecolor="none")
        ax.set_xlabel("Longitud"); ax.set_ylabel("Latitud"); ax.set_title(title)
        plt.colorbar(im, ax=ax, label=value_col.upper())
        fig.tight_layout()
        plt.close(fig)  # release pyplot's reference; the Figure stays usable
        return fig, "✔️ GeoTIFF generado"
    except Exception as e:
        fig, ax = plt.subplots()
        ax.axis('off')
        ax.text(0.5, 0.5, str(e), ha='center', va='center', color='red')
        plt.close(fig)
        return fig, f"❌ {e}"

def export_spei_params_geotiffs(k_list_str="1,3,6,12",
                                params_dir="SPEI_params",
                                out_dir=DRIVE_DIR_SPEI,
                                fname_fmt="SPEI_Params_TS{k}.tif",
                                nodata_val=-9999.0):
    """
    Write one 36-band GeoTIFF (12 months × 3 parameters) of SPEI log-logistic
    coefficients per accumulation scale k, and copy it to ``out_dir``.

    Band order is xi_01..xi_12, alpha_01..alpha_12, kappa_01..kappa_12. Each band also
    carries its name as the GDAL band description. The raster uses the CHIRPS-aligned
    grid over the Cundinamarca+Boyacá ROI, in EPSG:4326 and float32, with
    ``nodata_val`` for missing cells.

    Inputs (working directory):
      - ``SPEI_idx2cellid.csv``: idx → cell_id map written by compute_spei_from_csv.
      - ``{params_dir}/SPEI_params_k{k}_{xi|alpha|kappa}.csv``: row i belongs to idx i,
        columns m01..m12.

    Returns a status message listing written files and skipped k values; missing
    inputs are reported in the message rather than raised.
    """
    # Validate cheap local inputs before any Earth Engine round-trip.
    k_vals = [int(k) for k in str(k_list_str).split(",") if k.strip().isdigit()]
    if not k_vals:
        return "❌ k_list_str vacío o inválido."
    idx_path = Path("SPEI_idx2cellid.csv")
    if not idx_path.exists():
        return "❌ Falta SPEI_idx2cellid.csv; ejecuta antes compute_spei_from_csv."
    idx_map = pd.read_csv(idx_path, dtype={"cell_id": str})

    months = [f"{m:02d}" for m in range(1, 13)]
    param_names = ("xi", "alpha", "kappa")
    band_names = [f"{p}_{mm}" for p in param_names for mm in months]

    # CHIRPS-aligned grid geometry and raster georeferencing (shared by every k).
    _, ROI = _get_regions()
    gdf_base = geemap.ee_to_gdf(_make_grid(ROI))
    transform, width, height = _aligned_transform(gdf_base.total_bounds)
    profile = {
        "driver": "GTiff",
        "height": height,
        "width": width,
        "count": len(band_names),
        "dtype": "float32",
        "crs": "EPSG:4326",
        "transform": transform,
        "nodata": nodata_val,
        "tiled": True,
        "compress": "LZW",
    }

    Path(out_dir).mkdir(parents=True, exist_ok=True)

    results = []
    for k in k_vals:
        param_files = {p: os.path.join(params_dir, f"SPEI_params_k{k}_{p}.csv") for p in param_names}
        if not all(Path(f).exists() for f in param_files.values()):
            results.append(f"⚠️ Faltan CSV de parámetros (k={k})"); continue

        # One row per cell (idx_map order) with the 36 band columns.
        df = idx_map.copy()
        for p_name, f_path in param_files.items():
            df_p = pd.read_csv(f_path, dtype=float)
            # Force m01..m12 names when the file has 12 unnamed/other columns.
            if df_p.shape[1] == 12 and not all(str(c).startswith("m") for c in df_p.columns):
                df_p.columns = [f"m{mm}" for mm in months]
            # Row i of the parameter CSV belongs to idx i; idx without a row become NaN.
            df_p.index = range(len(df_p))
            aligned = df_p.reindex(idx_map["idx"].to_numpy())
            for mm in months:
                col = f"m{mm}"
                df[f"{p_name}_{mm}"] = (pd.to_numeric(aligned[col], errors="coerce").to_numpy()
                                        if col in aligned.columns else np.nan)

        gdf = gdf_base.merge(df, on="cell_id", how="left")

        # Rasterize each row's position once, then fill every band by lookup. This
        # gives the same pixels as burning each band separately, with 1 rasterization
        # instead of 36.
        cell_pos = rasterize(
            ((geom, pos) for pos, geom in enumerate(gdf.geometry)),
            out_shape=(height, width), transform=transform,
            fill=-1, dtype="int32",
        )
        covered = cell_pos >= 0

        tif_local = f"/content/{fname_fmt.format(k=k)}"
        with rasterio.open(tif_local, "w", **profile) as dst:
            for bi, bname in enumerate(band_names, start=1):
                vals = (pd.to_numeric(gdf[bname], errors="coerce")
                        .astype("float32").fillna(nodata_val).to_numpy())
                band = np.full((height, width), nodata_val, dtype="float32")
                band[covered] = vals[cell_pos[covered]]
                dst.write(band, bi)
                dst.set_band_description(bi, bname)
            # Metadata
            dst.update_tags(
                creator="Hidro-Suite (SPEI params)",
                description=f"Coeficientes log-logística SPEI TS-{k} (xi, alpha, kappa) por mes",
                timeseries_k=k,
                bands=",".join(band_names)
            )

        out_path = os.path.join(out_dir, os.path.basename(tif_local))
        shutil.copy(tif_local, out_path)
        results.append(os.path.basename(out_path))

    ok = [r for r in results if r.lower().endswith(".tif")]
    extra = [r for r in results if not r.lower().endswith(".tif")]
    msg = ("✔️ GeoTIFFs de parámetros guardados:\n‣ " + "\n‣ ".join(ok)) if ok else "⚠️ No se generaron TIFF."
    if extra:
        msg += ("\n\nObservaciones:\n- " + "\n- ".join(extra))
    msg += f"\nCarpeta: {out_dir}"
    return msg

# =================== 8) NIFT (sobre SPI) =====================
# Dry SPI classes as (upper, lower) bounds: a value s belongs to a class when
# lower < s <= upper (see the percentage computation in _compute_nift_single).
_DRY_CLASSES = {"mild": (0.0, -1.0), "moderate": (-1.0, -1.5), "severe": (-1.5, -2.0), "extreme": (-2.0, -np.inf)}
# NIFT weights for parameters P1..P10 labelled "Brasil Neto"; the values sum to 1.0.
DEFAULT_WEIGHTS_BRASIL_NETO = {
    "P1": 0.125, "P2": 0.125, "P3": 0.009, "P4": 0.034, "P5": 0.071,
    "P6": 0.136, "P7": 0.100, "P8": 0.075, "P9": 0.075, "P10": 0.250
}
# Equal weights (0.10 each) for P1..P10.
DEFAULT_WEIGHTS_UNIFORME = {f"P{i}": 0.10 for i in range(1, 11)}
# Parameters presumably combined with a negative sign in the NIFT score (usage not
# visible in this section — TODO confirm).
_NEGATIVE_PARAMS = ["P7", "P10"]
def _sanitize_weights(w: dict, renormalize: bool = True) -> dict:
    """Return a clean weight dict with exactly the keys P1..P10, in that order.

    Missing keys default to 0.0 and extra keys are dropped. Values are converted
    with ``float``; non-finite or negative values become 0.0. When ``renormalize``
    is true and the total is positive, the weights are rescaled to sum to 1.
    """
    cleaned = {}
    for i in range(1, 11):
        key = f"P{i}"
        cleaned[key] = float(w.get(key, 0.0))
    for key, value in cleaned.items():
        cleaned[key] = value if (np.isfinite(value) and value >= 0) else 0.0
    total = sum(cleaned.values())
    if not (renormalize and total > 0):
        return cleaned
    return {key: value / total for key, value in cleaned.items()}

def _detect_events(series, run_min=3):
    """Find dry events in an SPI-like series.

    A run is a maximal stretch of consecutive finite values <= 0; NaN/inf and
    positive values end a run. Runs with at least ``run_min`` members are returned as
    ``(start_index, length, sum_of_values)`` tuples, in order of occurrence.
    """
    found = []
    current = []
    first = None

    def _close_run():
        # Record the run accumulated so far if it is long enough.
        if len(current) >= run_min:
            found.append((first, len(current), np.sum(current)))

    for pos, value in enumerate(series):
        if not (np.isfinite(value) and value <= 0):
            _close_run()
            current, first = [], None
            continue
        if not current:
            first = pos
        current.append(value)
    _close_run()
    return found

def _sen_slope(y, x):
    """Sen's (Theil–Sen) slope estimator.

    Returns the median of the pairwise slopes (y[j] - y[i]) / (x[j] - x[i]) over all
    i < j. Pairs with equal x are skipped, as their slope is undefined. Returns NaN
    when fewer than two points are given or no pair has distinct x. A NaN in ``y``
    (or ``x``) propagates to the result.
    """
    y = np.asarray(y, dtype=float)
    x = np.asarray(x, dtype=float)
    if len(y) < 2:
        return np.nan
    # All i < j index pairs at once (same pairs as itertools.combinations).
    i, j = np.triu_indices(len(y), k=1)
    dx = x[j] - x[i]
    distinct = dx != 0
    if not distinct.any():
        return np.nan
    return np.median((y[j] - y[i])[distinct] / dx[distinct])

def _compute_nift_single(cell_df, run_min=3):
    """
    Compute NIFT parameters P1..P9 for a single grid cell.

    ``cell_df`` holds one cell's rows with an ``spi`` column (the caller sorts
    them by date). Non-finite SPI values are treated as gaps.

    Returns a dict with:
      P1     number of dry events (see _detect_events with ``run_min``),
      P2     mean event intensity |sum(SPI)| / duration,
      P3..P6 % of valid months in the mild / moderate / severe / extreme class,
      P7     Sen slope of SPI against month position,
      P8     Sen slope of event duration against event start,
      P9     Sen slope of event severity against event start.
    A cell without any finite SPI gets zeros for P1..P6 and NaN for P7..P9.
    """
    spi_all = cell_df["spi"].astype(float).values
    finite = np.isfinite(spi_all)
    n_valid = finite.sum()
    if n_valid == 0:
        return dict(P1=0, P2=0, P3=0, P4=0, P5=0, P6=0, P7=np.nan, P8=np.nan, P9=np.nan)

    spi = spi_all[finite]
    share = {}
    for label, (upper, lower) in _DRY_CLASSES.items():
        hits = np.logical_and(spi <= upper, spi > lower).sum()
        share[label] = 100.0 * hits / n_valid

    # Events are searched on the full series so NaN gaps split runs and the
    # start positions stay in the original month indexing.
    events = _detect_events(spi_all, run_min)
    # SPI trend uses the original month positions of the valid values.
    P7 = _sen_slope(spi, np.flatnonzero(finite).astype(float))

    if events:
        starts, lengths, sums = (np.array(column, float) for column in zip(*events))
        severities = np.abs(sums)
        P1 = float(len(events))
        P2 = float(np.mean(severities / lengths))
        P8 = _sen_slope(lengths, starts)
        P9 = _sen_slope(severities, starts)
    else:
        P1 = P2 = 0.0
        P8 = P9 = np.nan

    return dict(P1=P1, P2=P2,
                P3=share['mild'], P4=share['moderate'],
                P5=share['severe'], P6=share['extreme'],
                P7=P7, P8=P8, P9=P9)

def _prepare_mean_precip(start="1985-01-01", end="2024-12-31"):
    """
    Mean annual CHIRPS precipitation per grid cell (NIFT parameter P10).

    Sums the monthly ``pr`` band between ``start`` and ``end`` and divides by
    the period length in years, computed as (end - start).days / 365.25. That
    is slightly under 40 for the defaults. The result is averaged over each
    grid cell at the CHIRPS nominal scale and projection.

    Returns:
        DataFrame with columns ``cell_id`` and ``prec``.
    """
    monthly_pr = _chirps_monthly_ic(start, end).select('pr')
    span_years = (pd.to_datetime(end) - pd.to_datetime(start)).days / 365.25
    annual_mean = monthly_pr.sum().divide(span_years)

    proj = _chirps_proj()
    scale = proj.nominalScale()
    _, roi = _get_regions()
    cells = _make_grid(roi)

    reduced = annual_mean.reduceRegions(cells, ee.Reducer.mean(), scale=scale, crs=proj)
    table = geemap.ee_to_gdf(reduced).rename(columns={'mean': 'prec'})
    return table[['cell_id', 'prec']]

def compute_nift(k_list_str="1,3,6,12", run_min=3, weights: dict | None = None, renormalize=True):
    """
    Compute the NIFT index for each SPI timescale k, with weights editable from the UI.

    For each k in ``k_list_str`` the function:
      1. reads ``{DRIVE_DIR_SPI}/SPI_{k}_month.csv``, which needs ``cell_id``,
         ``date`` and ``spi`` columns;
      2. derives P1..P9 per cell (_compute_nift_single);
      3. adds P10, the mean annual CHIRPS precipitation;
      4. min–max normalizes every parameter to 0..1, inverting P7 and P10;
      5. combines them as NIFT = 100 * sum_j w_j * Pjn.

    Outputs per k:
      • NIFT_k{k}.csv in the working directory, copied to DRIVE_DIR_SPI;
      • NIFT_weights_k{k}.json in DRIVE_DIR_SPI, with the weights used, k and
        run_min;
      • NIFT_Param_{param}_TS{k}.tif for each of P1..P10, P1n..P10n and NIFT,
        rasterized on the CHIRPS-aligned grid (EPSG:4326), written under
        /content and copied to DRIVE_DIR_SPI.

    Args:
        k_list_str: comma-separated SPI timescales, e.g. "1,3,6,12".
        run_min: minimum run length passed to _detect_events.
        weights: dict keyed 'P1'..'P10'. Defaults to DEFAULT_WEIGHTS_BRASIL_NETO.
        renormalize: rescale the weights so they sum to 1.

    Returns:
        Multi-line status message, one line per k. Errors are reported in the
        message rather than raised.
    """
    import json

    try:
        # --- CHIRPS-aligned raster geometry shared by every k ---
        _, ROI = _get_regions()
        grid_fc = _make_grid(ROI)
        gdf_base = geemap.ee_to_gdf(grid_fc)
        bounds = gdf_base.total_bounds
        res_x, res_y = _chirps_deg_res()
        width = int(np.round((bounds[2] - bounds[0]) / res_x))
        height = int(np.round((bounds[3] - bounds[1]) / res_y))
        transform = from_origin(bounds[0], bounds[3], res_x, res_y)
        nodata_val = -9999.0

        # 1) weights
        if weights is None:
            weights = DEFAULT_WEIGHTS_BRASIL_NETO.copy()
        weights = _sanitize_weights(weights, renormalize=renormalize)

        # 2) timescales
        k_vals = [int(k) for k in str(k_list_str).split(',') if str(k).strip().isdigit()]
        if not k_vals:
            return "❌ No se proporcionó ningún k válido."

        cols_raw = [f"P{i}" for i in range(1, 11)]
        cols_n = [f"{c}n" for c in cols_raw]
        w_vec = np.array([weights[c] for c in cols_raw], float)

        # P10 does not depend on k. Fetch it from Earth Engine at most once, and
        # only when at least one SPI table exists.
        pr_df = None

        resultados = []
        for k in k_vals:
            csv_long = f"{DRIVE_DIR_SPI}/SPI_{k}_month.csv"
            if not Path(csv_long).exists():
                resultados.append(f"⚠️ SPI-{k} no encontrado")
                continue

            # 3) parameters P1..P9 per cell
            df = pd.read_csv(csv_long, dtype={'cell_id': str})
            out = []
            for cid, grp in df.groupby('cell_id'):
                param = _compute_nift_single(grp.sort_values('date'), run_min)
                param['cell_id'] = cid
                out.append(param)
            nift_df = pd.DataFrame(out)

            # 4) P10 = mean annual precipitation. cell_id is cast to str so it
            #    joins with the SPI table, which is read with cell_id as str.
            if pr_df is None:
                pr_df = _prepare_mean_precip()
                pr_df = pr_df.assign(cell_id=pr_df['cell_id'].astype(str))
            nift_df = nift_df.merge(pr_df, on='cell_id', how='left').rename(columns={'prec': 'P10'})

            # 5) min–max normalization to 0..1, inverted for P7 and P10
            for col in cols_raw:
                vals = pd.to_numeric(nift_df[col], errors="coerce")
                if vals.notna().sum() == 0:
                    nift_df[f"{col}n"] = 0.0
                    continue
                mx, mn = np.nanmax(vals), np.nanmin(vals)
                if np.isclose(mx, mn, atol=1e-12):
                    nift_df[f"{col}n"] = 0.0
                elif col in _NEGATIVE_PARAMS:
                    nift_df[f"{col}n"] = (mx - vals) / (mx - mn)
                else:
                    nift_df[f"{col}n"] = (vals - mn) / (mx - mn)

            # 6) NIFT = 100 * sum_j w_j * Pjn
            # NOTE(review): a NaN in any Pjn propagates to NIFT through the dot
            # product, even when its weight is 0. For example, P8/P9 are NaN
            # when a cell has fewer than 2 events. Such cells end up as nodata
            # in the rasters; confirm this is intended.
            nift_df['NIFT'] = nift_df[cols_n].to_numpy().dot(w_vec) * 100.0

            # 7) metadata
            nift_df['date'] = df['date'].max()

            # 8) save the CSV and the weight preset used
            out_csv = f"NIFT_k{k}.csv"
            nift_df.to_csv(out_csv, index=False)
            shutil.copy(out_csv, os.path.join(DRIVE_DIR_SPI, os.path.basename(out_csv)))
            with open(os.path.join(DRIVE_DIR_SPI, f"NIFT_weights_k{k}.json"), "w") as f:
                json.dump({"weights": weights, "k": k, "run_min": run_min}, f, indent=2)

            # 9) one GeoTIFF per parameter. Failures are reported, not raised.
            tiff_msg = ""
            try:
                # Same str join key as the CSV side of the merge.
                gdf_cells = gdf_base.assign(cell_id=gdf_base["cell_id"].astype(str))
                profile = {
                    "driver": "GTiff", "height": height, "width": width, "count": 1,
                    "dtype": "float32", "crs": "EPSG:4326", "transform": transform,
                    "nodata": nodata_val, "tiled": True, "compress": "LZW",
                }
                n_tiffs_saved = 0
                for param in cols_raw + cols_n + ['NIFT']:
                    if param not in nift_df.columns:
                        continue

                    gdf = gdf_cells.merge(nift_df[["cell_id", param]], on="cell_id", how="left")
                    vals = pd.to_numeric(gdf[param], errors="coerce").astype("float32").fillna(nodata_val)

                    arr = rasterize(
                        ((geom, val) for geom, val in zip(gdf.geometry, vals)),
                        out_shape=(height, width), transform=transform,
                        fill=nodata_val, dtype="float32",
                    )

                    # Write locally first, then copy to Drive.
                    tif_local = f"/content/NIFT_Param_{param}_TS{k}.tif"
                    with rasterio.open(tif_local, "w", **profile) as dst:
                        dst.write(arr, 1)
                        dst.update_tags(description=f"NIFT Parameter {param} for Timescale {k}")

                    out_path = os.path.join(DRIVE_DIR_SPI, os.path.basename(tif_local))
                    shutil.copy(tif_local, out_path)
                    n_tiffs_saved += 1

                if n_tiffs_saved > 0:
                    tiff_msg = f"\n✔️ {n_tiffs_saved} GeoTIFFs de parámetros guardados en {DRIVE_DIR_SPI}."

            except Exception as e_tiff:
                tiff_msg = f"\n⚠️ Error al guardar GeoTIFFs de parámetros para k={k}: {e_tiff}"

            msg = f"✔️ NIFT-{k} calculado (preset w sum={sum(weights.values()):.3f}) → {out_csv}"
            msg += tiff_msg
            resultados.append(msg)

        return "\n".join(resultados)
    except Exception as e:
        return f"❌ Error NIFT: {e}"

def export_nift_geotiffs(k_list_str="1,3,6,12", out_dir=DRIVE_DIR_SPI,
                         fname_fmt="NIFT_TS{k}.tif", nodata_val=-9999.0):
    """
    Export one single-band GeoTIFF per timescale k with the final NIFT (0–100) per cell.

    Reads ``{DRIVE_DIR_SPI}/NIFT_k{k}.csv``, which must contain ``cell_id`` and
    ``NIFT``. The values are rasterized on the CHIRPS-aligned grid
    (EPSG:4326), written to ``/content/<fname_fmt>`` and copied into
    ``out_dir``, which is created if needed. Cells without a value get
    ``nodata_val``.

    Returns:
        Status message listing the written files plus notes for skipped k.
        Errors are returned with their traceback instead of being raised.
    """
    try:
        # CHIRPS-aligned grid and raster geometry
        _, roi = _get_regions()
        cells = geemap.ee_to_gdf(_make_grid(roi))

        xmin, ymin, xmax, ymax = cells.total_bounds
        res_x, res_y = _chirps_deg_res()
        raster_w = int(np.round((xmax - xmin) / res_x))
        raster_h = int(np.round((ymax - ymin) / res_y))
        geo_tf = from_origin(xmin, ymax, res_x, res_y)

        Path(out_dir).mkdir(parents=True, exist_ok=True)

        ks = [int(tok) for tok in str(k_list_str).split(",") if str(tok).strip().isdigit()]
        if not ks:
            return "❌ k_list_str vacío o inválido."

        profile = dict(
            driver="GTiff", height=raster_h, width=raster_w, count=1,
            dtype="float32", crs="EPSG:4326", transform=geo_tf,
            nodata=nodata_val, tiled=True, compress="LZW",
        )

        report = []
        for k in ks:
            src = Path(f"{DRIVE_DIR_SPI}/NIFT_k{k}.csv")
            if not src.exists():
                report.append(f"⚠️ Falta {src.name}")
                continue

            table = pd.read_csv(src, dtype={"cell_id": str})
            if "NIFT" not in table.columns:
                report.append(f"❌ {src.name} sin columna 'NIFT'")
                continue

            merged = cells.merge(table[["cell_id", "NIFT"]], on="cell_id", how="left")
            nift_vals = pd.to_numeric(merged["NIFT"], errors="coerce").astype("float32").fillna(nodata_val)
            band = rasterize(
                zip(merged.geometry, nift_vals),
                out_shape=(raster_h, raster_w),
                transform=geo_tf,
                fill=nodata_val,
                dtype="float32",
            )

            # Temporary copy under /content, then copy to the destination folder.
            tmp_tif = f"/content/{fname_fmt.format(k=k)}"
            with rasterio.open(tmp_tif, "w", **profile) as dst:
                dst.write(band, 1)
                dst.update_tags(creator="Hidro-Suite (NIFT)",
                                description=f"NIFT TS-{k} (0–100), alineado a CHIRPS",
                                timeseries_k=k)

            dest = os.path.join(out_dir, os.path.basename(tmp_tif))
            shutil.copy(tmp_tif, dest)
            report.append(os.path.basename(dest))

        written = [r for r in report if r.lower().endswith(".tif")]
        notes = [r for r in report if not r.lower().endswith(".tif")]
        message = ("✔️ GeoTIFFs guardados:\n‣ " + "\n‣ ".join(written)) if written else "⚠️ No se generaron TIFF."
        if notes:
            message += "\n\nObservaciones:\n- " + "\n- ".join(notes)
        return message + f"\nCarpeta: {out_dir}"

    except Exception as e:
        return f"❌ Error exportando NIFT a GeoTIFF: {e}\n{traceback.format_exc()}"

from PIL import Image
def save_png_trim(fig, out_path, dpi=600, bg='white', thr=252, pad_px=0):
    """
    Save ``fig`` as a high-resolution PNG cropped to its non-white content.

    The figure is rendered in memory with ``bbox_inches="tight"`` and ``bg``
    as face/edge color. Transparent pixels are then treated as pure white. The
    image is cropped to the bounding box of pixels where any RGB channel is
    below ``thr``, padded by ``pad_px`` pixels and clipped to the image. The
    crop test always compares against white, whatever ``bg`` is. If no such
    pixel exists, the image is saved uncropped.
    """
    import io
    import numpy as _np

    with io.BytesIO() as raw:
        fig.savefig(raw, format="png", dpi=int(dpi),
                    facecolor=bg, edgecolor=bg,
                    bbox_inches="tight", pad_inches=0.01)
        raw.seek(0)
        # convert() materializes the pixels, so the buffer may be closed afterwards.
        image = Image.open(raw).convert("RGBA")

    pixels = _np.array(image)  # H x W x 4
    opaque_rgb = _np.where(pixels[..., 3:4] == 0, 255, pixels[..., :3])

    content = (opaque_rgb < thr).any(axis=2)
    if content.any():
        rows = _np.flatnonzero(content.any(axis=1))
        cols = _np.flatnonzero(content.any(axis=0))
        top = max(0, rows[0] - pad_px)
        left = max(0, cols[0] - pad_px)
        bottom = min(image.height, rows[-1] + 1 + pad_px)
        right = min(image.width, cols[-1] + 1 + pad_px)
        image = image.crop((left, top, right, bottom))
    image.save(out_path, format="PNG", optimize=True)


# --- Conjuntos de parámetros para figuras compactas ---
_SPI_GROUPS = {
    "G1": ["P1","P2"],              # dinámica de eventos
    "G2": ["P3","P4","P5","P6"],    # % mild/moderate/severe/extreme
    "G3": ["P7","P8","P9"],         # tendencias (Sen)
    "G4": ["P10"],                  # precipitación media anual
}

def _spi_plot_maps(k_list_str="1,3,6,12", use_normalized=False):
    """
    Plot NIFT parameter maps (P1..P10 × k) and the final NIFT map for each k.

    Reads ``{DRIVE_DIR_SPI}/NIFT_k{k}.csv`` once per k, rasterizes the values on
    the CHIRPS-aligned grid and draws two figures:
      • fig_params: one row per parameter (P1..P10, or P1n..P10n when
        ``use_normalized``) and one column per k. Each row has its own
        colorbar, whose range is shared by all k (fixed 0..1 when normalized).
      • fig_nift: one panel per k with NIFT on a fixed 0..100 scale.

    Both figures are saved as PNG in DRIVE_DIR_SPI: NIFT_P1P10_full_TS_{tag}.png,
    NIFT_maps_TS_{tag}.png and the legacy copy NIFT_all_maps_TS_{tag}.png.

    Returns:
        (fig_params, fig_nift, status_message). On error the same error figure
        is returned twice with a "❌ ..." message.
    """
    try:
        k_vals = [int(k) for k in k_list_str.split(",") if k.strip().isdigit()]
        if not k_vals: raise ValueError("No se especificó ningún k válido.")

        # Read every NIFT table once; it is reused for norms and drawing.
        tables = {}
        for k in k_vals:
            csv_path = f"{DRIVE_DIR_SPI}/NIFT_k{k}.csv"
            if not Path(csv_path).exists():
                raise FileNotFoundError(f"{csv_path} no encontrado.")
            tables[k] = pd.read_csv(csv_path, dtype={"cell_id": str})

        # CHIRPS grid and region borders (a single _get_regions call)
        region_fc, ROI = _get_regions()
        gdf_base = geemap.ee_to_gdf(_make_grid(ROI))
        gdf_regions = geemap.ee_to_gdf(region_fc)
        single_border = unary_union(gdf_regions.boundary.geometry)

        bounds = gdf_base.total_bounds
        res_x, res_y = _chirps_deg_res()
        width = int(np.round((bounds[2] - bounds[0]) / res_x))
        height = int(np.round((bounds[3] - bounds[1]) / res_y))
        transform = from_origin(bounds[0], bounds[3], res_x, res_y)
        extent = [bounds[0], bounds[2], bounds[1], bounds[3]]

        def _masked_raster(df_k, col):
            """Rasterize column ``col`` of ``df_k`` on the grid, masking cells without data."""
            gdf = gdf_base.merge(df_k[["cell_id", col]], on="cell_id", how="left")
            vals = pd.to_numeric(gdf[col], errors="coerce").astype("float32").fillna(-9999.0)
            arr = rasterize(((geom, val) for geom, val in zip(gdf.geometry, vals)),
                            out_shape=(height, width), transform=transform,
                            fill=-9999.0, dtype="float32")
            return np.ma.masked_equal(arr, -9999.0)

        # --- P1–P10 figure ---
        var_params = [f"P{i}" for i in range(1, 11)]
        suffix = "n" if use_normalized else ""
        n_rows, n_cols = len(var_params), len(k_vals)
        fig_params, ax_p = plt.subplots(n_rows, n_cols,
                                        figsize=(3.1*n_cols, 2.4*n_rows),
                                        sharex=True, sharey=True, dpi=300)
        if n_cols == 1:
            ax_p = np.expand_dims(ax_p, 1)

        # 1) One normalization per row, shared by every k
        row_norms = []
        for var in var_params:
            colname = var + suffix
            if use_normalized:
                row_norms.append(colors.Normalize(vmin=0.0, vmax=1.0))
                continue
            chunks = [pd.to_numeric(tables[k][colname], errors="coerce").values
                      for k in k_vals if colname in tables[k].columns]
            all_vals = np.concatenate(chunks) if chunks else np.array([0.0])
            vmin = float(np.nanmin(all_vals)) if np.isfinite(all_vals).any() else 0.0
            vmax = float(np.nanmax(all_vals)) if np.isfinite(all_vals).any() else 1.0
            if np.isclose(vmin, vmax, atol=1e-12):
                vmin, vmax = 0.0, 1.0
            row_norms.append(colors.Normalize(vmin=vmin, vmax=vmax))

        # 2) Draw each cell with its row normalization
        for c, k in enumerate(k_vals):
            for r, var in enumerate(var_params):
                ax = ax_p[r, c]
                ax.imshow(_masked_raster(tables[k], var + suffix),
                          cmap="jet", norm=row_norms[r], extent=extent)
                gpd.GeoSeries(single_border).plot(ax=ax, edgecolor="black", linewidth=0.7, facecolor="none")
                if r == 0:
                    ax.set_title(f"TS-{k}", fontsize=9, pad=2)
                if c == 0:
                    ax.text(-0.05, 0.5, var, transform=ax.transAxes, ha="right", va="center", fontsize=9)
                ax.axis("off")

        # 3) One colorbar per row, consistent across all k
        for r in range(n_rows):
            sm = mpl.cm.ScalarMappable(norm=row_norms[r], cmap="jet")
            sm.set_array([])
            fig_params.colorbar(sm, ax=ax_p[r, -1], fraction=0.04, pad=0.02)

        fig_params.tight_layout()

        # --- NIFT figure, one panel per k ---
        n_cols_nift = 2 if len(k_vals) <= 4 else 3
        n_rows_nift = int(np.ceil(len(k_vals) / n_cols_nift))
        fig_nift, ax_n = plt.subplots(n_rows_nift, n_cols_nift,
                                      figsize=(4*n_cols_nift + 0.5, 3*n_rows_nift),
                                      sharex=True, sharey=True, dpi=300)
        ax_n = ax_n.ravel()
        im_nift = None

        for i, k in enumerate(k_vals):
            ax = ax_n[i]
            im_nift = ax.imshow(_masked_raster(tables[k], "NIFT"), cmap="jet",
                                norm=colors.Normalize(vmin=0, vmax=100), extent=extent)
            gdf_regions.boundary.plot(ax=ax, edgecolor="black", linewidth=0.7, facecolor="none")
            ax.set_title(f"TS-{k}", fontsize=9, pad=2); ax.axis("off")

        for ax in ax_n[len(k_vals):]:
            ax.axis("off")
        fig_nift.tight_layout(rect=[0, 0.02, 0.86, 1])
        cax = fig_nift.add_axes([0.88, 0.20, 0.025, 0.65])
        fig_nift.colorbar(im_nift, cax=cax, label="NIFT (0–100)")

        # --- Save only after both figures are fully drawn ---
        tag = "_".join(map(str, k_vals)) + ("_norm" if use_normalized else "")
        out_params = os.path.join(DRIVE_DIR_SPI, f"NIFT_P1P10_full_TS_{tag}.png")
        out_maps = os.path.join(DRIVE_DIR_SPI, f"NIFT_maps_TS_{tag}.png")
        out_maps_legacy = os.path.join(DRIVE_DIR_SPI, f"NIFT_all_maps_TS_{tag}.png")
        save_png_trim(fig_params, out_params, dpi=600)
        save_png_trim(fig_nift, out_maps, dpi=600)
        # Keep publishing the legacy file name, now with the finished NIFT figure.
        shutil.copy(out_maps, out_maps_legacy)

        return fig_params, fig_nift, "✔️ Mapas NIFT y P1–P10 generados (escalas por fila unificadas)."
    except Exception as e:
        fig_err, ax_err = plt.subplots(); ax_err.axis("off")
        ax_err.text(0.5, 0.5, str(e), ha="center", va="center", color="red")
        return fig_err, fig_err, f"❌ {e}"


def _spi_plot_maps_grouped(k_list_str="1,3,6,12", use_normalized=False):
    """
    Build four compact NIFT parameter figures, one per group G1..G4 in _SPI_GROUPS.

    Each figure has one row per parameter and one column per timescale k
    (titled "TS-k"), plus a per-row colorbar whose range is shared by all k
    (fixed 0..1 when ``use_normalized``). Each figure is also saved as
    ``NIFT_{name}_TS_{tag}.png`` in DRIVE_DIR_SPI.

    Returns:
        (fig_G1, fig_G2, fig_G3, fig_G4, status_message). On error the same
        error figure is returned four times with a "❌ ..." message.
    """
    try:
        k_vals = [int(k) for k in k_list_str.split(",") if k.strip().isdigit()]
        if not k_vals: raise ValueError("No se especificó ningún k válido.")

        # Read every NIFT table once. Previously each was re-read for every
        # row norm and every map cell of every group.
        tables = {}
        for k in k_vals:
            csv_path = f"{DRIVE_DIR_SPI}/NIFT_k{k}.csv"
            if not Path(csv_path).exists():
                raise FileNotFoundError(f"{csv_path} no encontrado.")
            tables[k] = pd.read_csv(csv_path, dtype={"cell_id": str})

        region_fc, ROI = _get_regions()
        gdf_base = geemap.ee_to_gdf(_make_grid(ROI))
        gdf_regions = geemap.ee_to_gdf(region_fc)
        single_border = unary_union(gdf_regions.boundary.geometry)

        bounds = gdf_base.total_bounds
        res_x, res_y = _chirps_deg_res()
        width = int(np.round((bounds[2] - bounds[0]) / res_x))
        height = int(np.round((bounds[3] - bounds[1]) / res_y))
        transform = from_origin(bounds[0], bounds[3], res_x, res_y)
        extent = [bounds[0], bounds[2], bounds[1], bounds[3]]
        suffix = "n" if use_normalized else ""

        # Layout constants, in inches or figure fractions
        TITLE_FZ, ROW_LABEL_FZ, CB_TICK_FZ = 16, 16, 12
        MAP_W_IN, MAP_H_IN = 2.6, 2.2
        LABEL_W_IN, CB_W_IN = 0.85, 0.35
        TITLE_PAD, TITLE_Y = 3.0, 0.985
        WSPACE, HSPACE = 0.004, 0.004
        LEFT_M, RIGHT_M, TOP_M, BOT_M = 0.015, 0.90, 0.94, 0.06
        TOP_M_SINGLE = 0.88
        CB_HEIGHT_FRAC, CB_INSET_W_FRAC = 0.82, 0.75
        from matplotlib import gridspec

        def _masked_raster(df_k, colname):
            """Rasterize ``colname`` of ``df_k`` on the grid, masking cells without data."""
            gdf = gdf_base.merge(df_k[["cell_id", colname]], on="cell_id", how="left")
            vals = pd.to_numeric(gdf[colname], errors="coerce").astype("float32").fillna(-9999.0)
            arr = rasterize(((geom, val) for geom, val in zip(gdf.geometry, vals)),
                            out_shape=(height, width), transform=transform,
                            fill=-9999.0, dtype="float32")
            return np.ma.masked_equal(arr, -9999.0)

        def _row_norm(colname):
            """Color normalization for one parameter row, shared by every k."""
            if use_normalized:
                return colors.Normalize(vmin=0, vmax=1)
            all_vals = np.concatenate([
                pd.to_numeric(tables[k][colname], errors="coerce").astype(float).values
                for k in k_vals
            ])
            vmin = float(np.nanmin(all_vals)) if np.isfinite(all_vals).any() else 0.0
            vmax = float(np.nanmax(all_vals)) if np.isfinite(all_vals).any() else 1.0
            if np.isclose(vmin, vmax, atol=1e-12): vmin, vmax = 0.0, 1.0
            return colors.Normalize(vmin=vmin, vmax=vmax)

        def _build_group(var_list):
            """One figure: label column | one map column per k | colorbar column."""
            n_rows, n_cols = len(var_list), len(k_vals)
            fig_w = LABEL_W_IN + n_cols*MAP_W_IN + CB_W_IN
            fig_h = n_rows*MAP_H_IN
            fig = plt.figure(figsize=(fig_w, fig_h), dpi=300)
            fig.subplots_adjust(left=LEFT_M, right=RIGHT_M,
                                top=(TOP_M if n_rows > 1 else TOP_M_SINGLE), bottom=BOT_M)
            width_ratios = [LABEL_W_IN/MAP_W_IN] + [1.0]*n_cols + [CB_W_IN/MAP_W_IN]
            gs = gridspec.GridSpec(nrows=n_rows, ncols=n_cols+2, figure=fig,
                                   width_ratios=width_ratios, height_ratios=[1.0]*n_rows,
                                   wspace=WSPACE, hspace=HSPACE)

            row_norms = [_row_norm(var + suffix) for var in var_list]

            for r, var in enumerate(var_list):
                colname = var + suffix
                ax_label = fig.add_subplot(gs[r, 0]); ax_label.axis("off")
                ax_label.text(0.98, 0.5, var, transform=ax_label.transAxes,
                              ha="right", va="center", fontsize=ROW_LABEL_FZ)

                for c, k in enumerate(k_vals):
                    ax = fig.add_subplot(gs[r, c+1])
                    ax.imshow(_masked_raster(tables[k], colname), cmap="jet",
                              norm=row_norms[r], extent=extent)
                    gpd.GeoSeries(single_border).plot(ax=ax, edgecolor="black", linewidth=0.7, facecolor="none")
                    if r == 0: ax.set_title(f"TS-{k}", fontsize=TITLE_FZ, pad=TITLE_PAD, y=TITLE_Y)
                    ax.axis("off")

                cax_cell = fig.add_subplot(gs[r, n_cols+1]); cax_cell.axis("off")
                cax = cax_cell.inset_axes([(1-CB_INSET_W_FRAC)/2, (1-CB_HEIGHT_FRAC)/2,
                                           CB_INSET_W_FRAC, CB_HEIGHT_FRAC])
                sm = mpl.cm.ScalarMappable(norm=row_norms[r], cmap="jet"); sm.set_array([])
                cb = fig.colorbar(sm, cax=cax, orientation="vertical")
                cb.ax.tick_params(labelsize=CB_TICK_FZ, length=2, pad=2)
            return fig

        figs = [_build_group(_SPI_GROUPS[key]) for key in ["G1", "G2", "G3", "G4"]]

        tag = "_".join([x.strip() for x in k_list_str.split(",") if x.strip().isdigit()]) + ("_norm" if use_normalized else "")
        names = ["G1_dynEventos", "G2_porcentajes", "G3_tendencias", "G4_precip"]
        for fig, name in zip(figs, names):
            outp = os.path.join(DRIVE_DIR_SPI, f"NIFT_{name}_TS_{tag}.png")
            save_png_trim(fig, outp, dpi=600)

        return figs[0], figs[1], figs[2], figs[3], "✔️ Figuras G1–G4 generadas."
    except Exception as e:
        fig, ax = plt.subplots(); ax.axis('off'); ax.text(0.5, 0.5, str(e), ha="center", va="center", color="red")
        return fig, fig, fig, fig, f"❌ {e}"

# =================== 9) VALIDACIÓN SPEI (CSIC vs Local) =====
def _gee_mean_series(k, start='1958-01-01', end='2023-01-01'):
    """
    ROI-mean SPEI-k time series from the GEE dataset CSIC/SPEI/2_10.

    Each image in [start, end) is reduced to its spatial mean over the ROI
    (10 km scale). Months whose mean is null are dropped.

    Returns:
        Float Series of the means, indexed by month date and sorted by date.
    """
    band_name = f"SPEI_{int(k):02d}_month"
    collection = ee.ImageCollection('CSIC/SPEI/2_10').select(band_name).filterDate(start, end)
    roi = _get_regions()[1]

    def _image_to_feature(image):
        roi_mean = image.reduceRegion(ee.Reducer.mean(), roi, 10000, maxPixels=1e13).get(band_name)
        return ee.Feature(None, {'date': image.date().format('YYYY-MM'), 'mean': roi_mean})

    features = ee.FeatureCollection(collection.map(_image_to_feature)).filter(ee.Filter.notNull(['mean']))
    table = geemap.ee_to_gdf(features).drop(columns='geometry', errors='ignore')
    table['date'] = pd.to_datetime(table['date'])
    return table.sort_values('date').set_index('date')['mean'].astype(float)

def _csv_mean_series(k):
    """
    Spatial mean of the local SPEI-k per date.

    Reads ``{DRIVE_DIR_SPEI}/SPEI_{k}_month.csv`` (needs ``date`` and ``spei``
    columns) and averages ``spei`` over all rows sharing a date.

    Returns:
        Float Series indexed by date, sorted ascending.

    Raises:
        FileNotFoundError: if the CSV does not exist.
    """
    csv_file = f'{DRIVE_DIR_SPEI}/SPEI_{k}_month.csv'
    if pathlib.Path(csv_file).exists():
        table = pd.read_csv(csv_file, parse_dates=['date'])
        return table.groupby('date')['spei'].mean().sort_index().astype(float)
    raise FileNotFoundError(f'CSV {csv_file} no encontrado.')

def _plot_validation_series(k_list_str="1,3,6,12", start_date="", end_date=""):
    """
    Plot the CSIC/GEE ROI-mean SPEI-k against the local SPEI-k, one panel per k.

    Args:
        k_list_str: comma-separated SPEI timescales, e.g. "1,3,6,12".
        start_date, end_date: optional inclusive date bounds; "" or None
            means unbounded.

    Returns:
        (figure, status_message). Errors are drawn on the figure and reported
        as "❌ ..." instead of being raised.
    """
    try:
        k_vals = [int(x) for x in str(k_list_str).split(',') if str(x).strip().isdigit()]
        # Same guard and message as the other SPEI plotting functions; without
        # it an empty list fails later inside plt.subplots(0, 1).
        if not k_vals:
            raise ValueError("k_list_str vacío o inválido.")
        start_date = (start_date or "").strip()
        end_date = (end_date or "").strip()
        sd = pd.to_datetime(start_date) if start_date else None
        ed = pd.to_datetime(end_date) if end_date else None

        n = len(k_vals)
        fig, axs = plt.subplots(n, 1, figsize=(10, 3*n), sharex=True)
        if n == 1: axs = [axs]
        for ax, k in zip(axs, k_vals):
            ser_gee = _gee_mean_series(k)
            ser_loc = _csv_mean_series(k)
            df = pd.concat({'GEE': ser_gee, 'Local': ser_loc}, axis=1).dropna(how='all')
            if sd is not None: df = df.loc[sd:]
            if ed is not None: df = df.loc[:ed]
            if df.empty:
                ax.text(0.5, 0.5, f'SPEI-{k}: sin datos', ha='center', va='center')
                ax.axis('off'); continue
            df['GEE'].plot(ax=ax, lw=1, label='GEE')
            df['Local'].plot(ax=ax, lw=1, ls='--', label='Local')
            ax.set_ylabel(f'SPEI-{k}'); ax.grid(ls=':')
            ax.legend(); ax.set_title(f'k={k}')
        fig.tight_layout()
        return fig, "✔️ Curvas generadas."
    except Exception as e:
        fig, ax = plt.subplots(); ax.axis('off'); ax.text(0.5, 0.5, f'Error: {e}', ha='center', va='center', color='red')
        return fig, f"❌ {e}"

# =================== 10) ENSO (ONI) ==========================
def _read_enso_monthly(csv_path: str):
    if not os.path.isfile(csv_path):
        raise FileNotFoundError(f"No se encontró ENSO.csv en: {csv_path}")
    raw = pd.read_csv(csv_path, dtype=str)
    if 'Year' not in raw.columns:
        raw = pd.read_csv(csv_path, dtype=str, sep=';')
    raw = raw[raw['Year'].str.fullmatch(r'\d{4}', na=False)].copy()
    raw['Year'] = raw['Year'].astype(int)
    seasons = ['DJF','JFM','FMA','MAM','AMJ','MJJ','JJA','JAS','ASO','SON','OND','NDJ']
    missing = [c for c in seasons if c not in raw.columns]
    if missing: raise ValueError(f"ENSO.csv sin columnas: {missing}")
    for c in seasons: raw[c] = pd.to_numeric(raw[c].str.strip(), errors='coerce')
    season_to_month = {'DJF':1,'JFM':2,'FMA':3,'MAM':4,'AMJ':5,'MJJ':6,'JJA':7,'JAS':8,'ASO':9,'SON':10,'OND':11,'NDJ':12}
    records = []
    for _, r in raw.iterrows():
        y = int(r['Year'])
        for s, m in season_to_month.items():
            val = r[s]
            if pd.notna(val):
                records.append((pd.Timestamp(year=y, month=m, day=1), float(val)))
    ser = pd.Series(dict(records)).sort_index(); ser.name = 'ENSO'; return ser


AXIS_LABEL_FONTSIZE = 15
TICK_LABEL_FONTSIZE = 13
LEGEND_FONTSIZE     = 13
METRICS_FONTSIZE    = 13
TITLE_FONTSIZE      = 16
CB_LABEL_FONTSIZE   = 13
CB_TICK_FONTSIZE    = 12


def _plot_spei_vs_enso(k_list_str="1,3,6,12", start_date="", end_date="", enso_csv_path=None,
                       enso_lag=-1, invert_axis=True, show_colorbar=False):
    """
    Plot the ROI-mean local SPEI-k against the monthly ENSO (ONI) series, one panel per k.

    SPEI is drawn as a grey line on the left axis. ENSO is drawn on a twin
    right axis as a line colored by its value (RdBu_r, centred at 0), scaled
    symmetrically to at least ±2. Each panel shows Pearson r, Spearman rho and
    n, computed over the months where both series exist.

    Args:
        k_list_str: comma-separated SPEI timescales.
        start_date, end_date: optional inclusive date bounds ("" = unbounded).
            Must be strings, since ``.strip()`` is called on them.
        enso_csv_path: ONI CSV path; defaults to ``{DRIVE_DIR_SPEI}/ENSO.csv``.
        enso_lag: months to shift ENSO. The index is shifted by -enso_lag month
            starts, so the default -1 pairs SPEI at month t with ENSO at t-1
            (ENSO leading).
        invert_axis: draw the ENSO axis inverted (positive values downward).
        show_colorbar: add an ENSO colorbar to each panel.

    Returns:
        (figure, status_message). Errors are drawn on the figure and reported
        as "❌ ..." instead of being raised.
    """
    import matplotlib.dates as mdates
    from matplotlib.collections import LineCollection
    try:
        if enso_csv_path is None:
            enso_csv_path = f"{DRIVE_DIR_SPEI}/ENSO.csv"
        enso = _read_enso_monthly(enso_csv_path)
        # Shifting with freq moves the index, not the values: lag -1 → +1 month.
        if int(enso_lag) != 0:
            enso = enso.shift(-int(enso_lag), freq="MS")
        if start_date.strip(): enso = enso[enso.index >= pd.to_datetime(start_date)]
        if end_date.strip():   enso = enso[enso.index <= pd.to_datetime(end_date)]

        k_vals = [int(x) for x in k_list_str.split(',') if str(x).strip().isdigit()]
        if not k_vals: raise ValueError("k_list_str vacío o inválido.")
        n = len(k_vals)

        fig, axs = plt.subplots(n, 1, figsize=(11, 3.4*n), sharex=True)
        if n == 1: axs = [axs]

        for ax, k in zip(axs, k_vals):
            spei = _csv_mean_series(k).astype(float)
            if start_date.strip(): spei = spei[spei.index >= pd.to_datetime(start_date)]
            if end_date.strip():   spei = spei[spei.index <= pd.to_datetime(end_date)]

            # Keep only months present in both series.
            df = pd.concat({"SPEI": spei, "ENSO": enso.astype(float)}, axis=1).dropna()
            if df.empty:
                ax.axis('off'); ax.text(0.5,0.5,f'k={k}: sin solapamiento', ha='center', va='center', fontsize=METRICS_FONTSIZE); continue

            r_pearson  = df["SPEI"].corr(df["ENSO"], method="pearson")
            r_spearman = df["SPEI"].corr(df["ENSO"], method="spearman")

            # SPEI series (grey line). Its label is not shown: no legend is drawn.
            ax.plot(df.index, df["SPEI"], lw=0.9, color="0.25", label="SPEI (Local)")
            ax.set_ylabel(f"SPEI-{k}", fontsize=AXIS_LABEL_FONTSIZE)
            ax.tick_params(axis="both", labelsize=TICK_LABEL_FONTSIZE)
            ax.grid(ls=":", alpha=0.7)

            # ENSO as a value-colored line on the secondary axis: one segment
            # per consecutive pair of months, colored by the segment midpoint.
            x = mdates.date2num(df.index.to_pydatetime())
            y = df["ENSO"].values.astype(float)
            points = np.array([x, y]).T.reshape(-1, 1, 2)
            segments = np.concatenate([points[:-1], points[1:]], axis=1)

            # Symmetric color/axis range, never narrower than ±2.
            max_abs = float(np.nanmax(np.abs(y))); max_abs = max(max_abs, 2.0)
            norm = colors.TwoSlopeNorm(vmin=-max_abs, vcenter=0.0, vmax=max_abs)
            cmap = plt.get_cmap("RdBu_r")

            lc = LineCollection(segments, cmap=cmap, norm=norm, linewidth=1.5, zorder=3)
            lc.set_array((y[:-1] + y[1:]) / 2.0)

            ax2 = ax.twinx()
            ax2.add_collection(lc)
            # NOTE(review): ax2 shares x with ax, and all rows share x
            # (sharex=True), so this limit likely applies to every panel and the
            # last k wins. Confirm this is intended.
            ax2.set_xlim(x.min(), x.max())
            if invert_axis:
                ax2.set_ylim(max_abs, -max_abs)
            else:
                ax2.set_ylim(-max_abs, max_abs)
            ax2.set_ylabel("ENSO (ONI)", fontsize=AXIS_LABEL_FONTSIZE)
            ax2.tick_params(axis="both", labelsize=TICK_LABEL_FONTSIZE)

            if show_colorbar:
                cbar = fig.colorbar(lc, ax=ax2, fraction=0.035, pad=0.02)
                cbar.set_label("ENSO (ONI)", fontsize=CB_LABEL_FONTSIZE)
                cbar.ax.tick_params(labelsize=CB_TICK_FONTSIZE)
                cbar.set_ticks([-2, -1, 0, 1, 2])

            # Metrics box in the top-left corner of the panel
            ax.text(0.02, 0.98,
                    f"r={r_pearson:.3f} · ρ={r_spearman:.3f} · n={len(df)} · lag_ENSO={int(enso_lag)}",
                    transform=ax.transAxes, ha='left', va='top',
                    bbox=dict(facecolor='white', edgecolor='0.6', alpha=0.85, pad=3),
                    fontsize=METRICS_FONTSIZE, zorder=10)

            ax.set_title(f"SPEI vs ENSO (k={k})", fontsize=TITLE_FONTSIZE, pad=4)

        axs[-1].set_xlabel("Fecha", fontsize=AXIS_LABEL_FONTSIZE)
        fig.tight_layout()
        return fig, "✔️ Comparación SPEI–ENSO generada."
    except Exception as e:
        fig, ax = plt.subplots(); ax.axis('off')
        ax.text(0.5, 0.5, f'Error: {e}', ha='center', va='center', color='red', fontsize=METRICS_FONTSIZE)
        return fig, f"❌ {e}"

def _plot_spei_vs_spei_clc(k_list_str="1,3,6,12", start_date="", end_date=""):
    """
    Validate the local SPEI against the CSIC SPEI, one panel per timescale.

    Each panel overlays the CSIC series (solid line) and the local series
    (dashed line) on their common non-NaN dates. It is annotated with
    Pearson r, Spearman rho and the number of overlapping samples.

    Parameters
    ----------
    k_list_str : str
        Comma-separated SPEI timescales (e.g. "1,3,6,12"). Tokens that are
        not plain digits after stripping are ignored.
    start_date, end_date : str
        Optional inclusive bounds, parsed with ``pd.to_datetime``. A blank
        string disables that bound.

    Returns
    -------
    (matplotlib.figure.Figure, str)
        The figure and a status message. Any exception instead yields a
        placeholder figure showing the error and a "❌ ..." message.
    """
    try:
        scales = []
        for token in k_list_str.split(','):
            if str(token).strip().isdigit():
                scales.append(int(token))
        if not scales:
            raise ValueError("k_list_str vacío o inválido.")

        n_rows = len(scales)
        fig, axes = plt.subplots(n_rows, 1, figsize=(11, 3.2*n_rows), sharex=True)
        if n_rows == 1:
            # plt.subplots returns a bare Axes (not an array) for a single row.
            axes = [axes]

        for ax, k in zip(axes, scales):
            ser_csic = _gee_mean_series(k)    # CSIC / GEE reference
            ser_local = _csv_mean_series(k)   # locally computed SPEI
            pair = pd.concat({"CSIC": ser_csic.astype(float),
                              "Local": ser_local.astype(float)}, axis=1).dropna()

            if start_date.strip():
                pair = pair[pair.index >= pd.to_datetime(start_date)]
            if end_date.strip():
                pair = pair[pair.index <= pd.to_datetime(end_date)]

            if pair.empty:
                ax.axis('off')
                ax.text(0.5, 0.5, f'k={k}: sin solapamiento', ha='center', va='center', fontsize=METRICS_FONTSIZE)
                continue

            local_vals, csic_vals = pair["Local"], pair["CSIC"]
            pearson = local_vals.corr(csic_vals, method="pearson")
            spearman = local_vals.corr(csic_vals, method="spearman")

            ax.plot(pair.index, csic_vals, lw=1.2, color="C0", label="SPEI (CSIC)")
            ax.plot(pair.index, local_vals, lw=1.2, color="C1", ls="--", label="SPEI (Local)")

            ax.set_title(f"Validación SPEI Local vs CSIC (k={k})", fontsize=TITLE_FONTSIZE, pad=4)
            ax.set_ylabel(f"SPEI-{k}", fontsize=AXIS_LABEL_FONTSIZE)
            ax.tick_params(axis="both", labelsize=TICK_LABEL_FONTSIZE)
            ax.grid(ls=":", alpha=0.7)
            ax.legend(loc="upper left", fontsize=LEGEND_FONTSIZE)

            metrics_box = dict(facecolor='white', edgecolor='0.6', alpha=0.85, pad=3)
            ax.text(0.98, 0.98,
                    f"r={pearson:.3f} · ρ={spearman:.3f} · n={len(pair)}",
                    transform=ax.transAxes, ha="right", va="top",
                    bbox=metrics_box, fontsize=METRICS_FONTSIZE, zorder=10)

        axes[-1].set_xlabel("Fecha", fontsize=AXIS_LABEL_FONTSIZE)
        fig.tight_layout()
        return fig, "✔️ Validación (línea punteada) generada."
    except Exception as e:
        fig, ax = plt.subplots()
        ax.axis('off')
        ax.text(0.5, 0.5, f'Error: {e}', ha='center', va='center', color='red', fontsize=METRICS_FONTSIZE)
        return fig, f"❌ {e}"

# =================== 11) FIRMS (VIIRS 375 m) =================
def _firms_monthly_fc(start_date, end_date, conf_thr=80):
    """
    Build an ee.FeatureCollection of monthly fire presence per grid cell.

    Month windows start at ``start_date`` and advance one month at a time.
    For each window, every FIRMS image is masked to
    ``confidence >= conf_thr``, flagged where ``T21 > 0``, max-composited
    (unmasked to 0) and reduced with ``max`` over the grid from
    ``_make_grid`` at 375 m. Only cells whose max is 1 are kept. Each output
    feature carries just ``cell_id``, ``date`` ("YYYY-MM" of the window
    start) and ``fires`` (always 1 after the final filter).

    Raises:
        ValueError: if ``start_date`` is not strictly before ``end_date``.
            This check costs two client-side ``getInfo()`` round-trips.

    NOTE(review): the number of windows is ``ed.difference(sd, 'month')``
    cast with ``.int()``. If ``difference`` returns fractional months, that
    presumably truncates, so a trailing partial month is dropped (e.g.
    2024-01-01 -> 2024-12-31 would cover Jan..Nov only). Passing an
    exclusive first-of-month end date (2025-01-01) avoids that. Confirm
    which convention callers use. Behaviour for spans shorter than one
    month (sequence(0, -1)) is also unverified.
    """
    sd, ed = ee.Date(start_date), ee.Date(end_date)
    if sd.millis().getInfo() >= ed.millis().getInfo():
        raise ValueError("start_date debe ser < end_date")
    grid_fc  = _make_grid(_get_regions()[1])
    n_months = ed.difference(sd, 'month').int()
    # Mapped over an ee.List below, so Earth Engine traces this Python function
    # once to build the server-side computation; n is a server-side ee.Number.
    def _per_month(n):
        d0 = sd.advance(n, 'month'); d1 = d0.advance(1, 'month')
        # 1 where any confident detection with T21 > 0 occurred in [d0, d1), 0 elsewhere.
        presence = (ee.ImageCollection(FIRMS_COLLECTION)
                      .filterDate(d0, d1).filterBounds(_get_regions()[1])
                      .map(lambda im: im.updateMask(im.select('confidence').gte(conf_thr)).select('T21').gt(0))
                      .max().unmask(0).rename('fires').toByte())
        # A single-band image reduced with Reducer.max() yields a 'max' property per cell;
        # copy it into 'fires', stamp the month, and keep only burned cells.
        month_fc = (presence.reduceRegions(collection=grid_fc, reducer=ee.Reducer.max(), scale=375, crs='EPSG:4326')
                      .map(lambda f: (f.set({'fires': ee.Number(f.get('max')).int(), 'date' : d0.format('YYYY-MM')})
                                         .select(['cell_id','date','fires'])))
                      .filter(ee.Filter.eq('fires', 1)))
        return month_fc
    # One FeatureCollection per month index 0..n_months-1, flattened into a single table.
    return ee.FeatureCollection(ee.List.sequence(0, n_months.subtract(1)).map(_per_month)).flatten()

def _export_firms_monthly(start_date, end_date, conf_thr=80):
    """
    Start a Google Drive CSV export of the monthly FIRMS presence table.

    The task name and file prefix are ``FIRMS_MONTH_<start>_<end>``, built
    from the dates with hyphens removed. The file is written to
    ``EXPORT_FOLDER`` with the columns cell_id, date and fires.

    Returns a status string instead of raising. On success it is a
    "⏳ Exportando ..." message (the task has only been started, not
    finished). On any exception it is "❌ <error>".
    """
    try:
        table = _firms_monthly_fc(start_date, end_date, conf_thr)
        stamp = f"{start_date.replace('-','')}_{end_date.replace('-','')}"
        task_name = f'FIRMS_MONTH_{stamp}'
        task = ee.batch.Export.table.toDrive(
            collection=table.filter(ee.Filter.gt('fires', 0)),
            description=task_name,
            folder=EXPORT_FOLDER,
            fileNamePrefix=task_name,
            fileFormat='CSV',
            selectors=['cell_id', 'date', 'fires'],
        )
        task.start()
    except Exception as e:
        return f"❌ {e}"
    else:
        return f"⏳ Exportando FIRMS ({stamp}) — confianza ≥{conf_thr}."

def _csv_firms_monthly_df(csv_path: str) -> pd.DataFrame:
    if not os.path.isfile(csv_path):
        raise FileNotFoundError(f"CSV no encontrado: {csv_path}")
    df = pd.read_csv(csv_path)
    req = {'cell_id', 'date', 'fires'}
    if not req.issubset(df.columns):
        raise ValueError(f"El CSV debe tener columnas {req}")
    df['date']  = pd.to_datetime(df['date'])
    df['fires'] = df['fires'].fillna(0).astype(int)
    return df[['cell_id', 'date', 'fires']]

def rasterize_fires(csv_path, start_month, end_month, out_tif, nodata_val=-9999.0):
    """
    Write a binary fire-presence GeoTIFF (1 = fire, 0 = no fire) aligned to CHIRPS.

    A grid cell is 1 when the monthly CSV has at least one row for it with
    ``fires > 0`` and a month inside [start_month, end_month]. Every other
    cell of the grid from ``_make_grid`` is 0. Raster pixels not covered by
    any grid polygon get ``nodata_val``.

    Parameters
    ----------
    csv_path : str
        Monthly FIRMS CSV with at least ``cell_id``, ``date`` and ``fires``.
    start_month, end_month : str
        Inclusive month bounds, anything ``pd.Period(..., 'M')`` accepts
        (e.g. "YYYY-MM").
    out_tif : str
        Output GeoTIFF path (overwritten). Written as float32, EPSG:4326.
    nodata_val : float
        Fill value for pixels outside the grid.

    Raises
    ------
    ValueError
        If no row in the month range reports a fire.
    """
    df = pd.read_csv(csv_path, parse_dates=['date'])
    df['month'] = df['date'].dt.to_period('M')
    sm, em = pd.Period(start_month, 'M'), pd.Period(end_month, 'M')
    in_range = (df['month'] >= sm) & (df['month'] <= em)
    # Rows with fires == 0/NaN are not detections. Counting them as presence would
    # contradict "1=presencia" and _prepare_gw_input_unified, which keeps fires > 0.
    has_fire = pd.to_numeric(df['fires'], errors='coerce').fillna(0) > 0
    burned_cells = df.loc[in_range & has_fire, 'cell_id'].drop_duplicates()
    if burned_cells.empty:
        raise ValueError("El rango seleccionado no contiene incendios.")
    pres_df = pd.DataFrame({'cell_id': burned_cells.to_numpy(), 'presence': 1})

    _, ROI  = _get_regions(); grid_fc = _make_grid(ROI)
    # Left-merge onto the full grid so every cell exists; unmatched cells become 0.
    gdf     = geemap.ee_to_gdf(grid_fc).merge(pres_df, on='cell_id', how='left')
    gdf['presence'] = gdf['presence'].fillna(0).astype('float32')

    # Pixel size from the CHIRPS projection (presumably degrees, per the helper name).
    res_x, res_y = _chirps_deg_res()
    xmin, ymin, xmax, ymax = gdf.total_bounds
    width  = int(np.round((xmax - xmin) / res_x)); height = int(np.round((ymax - ymin) / res_y))
    transform = from_origin(xmin, ymax, res_x, res_y)
    arr = rasterize(((geom, val) for geom, val in zip(gdf.geometry, gdf.presence)),
                    out_shape=(height, width), transform=transform, fill=nodata_val, dtype='float32')
    with rasterio.open(out_tif, 'w', driver='GTiff', height=height, width=width, count=1, dtype='float32',
                       crs='EPSG:4326', transform=transform, nodata=nodata_val) as dst:
        dst.write(arr, 1)

def _download_fires_gee_tif(start_month, end_month, out_tif, conf_thr=80):
    """
    Compute fire presence in Earth Engine and download it on the CHIRPS grid.

    Covers every FIRMS image from the first day of ``start_month`` up to,
    but excluding, the first day of the month after ``end_month``. Pixels
    are masked to ``confidence >= conf_thr``, flagged where ``T21 > 0`` and
    max-composited (unmasked to 0) at the native FIRMS projection. The
    result is then aggregated with ``max`` into the CHIRPS projection and
    downloaded as a single-band byte GeoTIFF to ``out_tif`` over the
    region from ``_get_regions()[1]``.
    """
    window_start = ee.Date(f"{start_month}-01")
    window_end = ee.Date(f"{end_month}-01").advance(1, 'month')   # exclusive bound

    firms = ee.ImageCollection(FIRMS_COLLECTION)
    native_proj = firms.first().select('T21').projection()
    chirps_proj = _chirps_proj()

    def _confident_burn(img):
        # Keep confident pixels only, then flag T21 > 0 as burning.
        return img.updateMask(img.select('confidence').gte(conf_thr)).select('T21').gt(0)

    # reduceResolution needs a default projection, which the composite lacks.
    fine = (firms.filterDate(window_start, window_end)
                 .filterBounds(_get_regions()[1])
                 .map(_confident_burn)
                 .max()
                 .unmask(0)
                 .setDefaultProjection(native_proj))
    coarse = (fine.reduceResolution(ee.Reducer.max(), bestEffort=True, maxPixels=256)
                  .reproject(chirps_proj)
                  .toByte())

    geemap.ee_export_image(
        coarse,
        filename=out_tif,
        scale=chirps_proj.nominalScale().getInfo(),
        region=_get_regions()[1],
        file_per_band=False,
    )

def _plot_fires(csv_path, start_month, end_month, conf_thr=80):
    """
    Show two fire-presence rasters side by side, both on the CHIRPS grid.

    Left: raster built locally from the monthly CSV (``rasterize_fires``).
    Right: GeoTIFF computed and downloaded from Earth Engine
    (``_download_fires_gee_tif``). White = 0 (no fire), red = 1 (fire).
    Nodata pixels are masked.

    Parameters
    ----------
    csv_path : str
        Monthly FIRMS CSV (cell_id, date, fires).
    start_month, end_month : str
        Inclusive month range, "YYYY-MM".
    conf_thr : int, optional
        FIRMS confidence threshold forwarded to the Earth Engine download.
        The default of 80 is the value previously hard-coded.

    Returns
    -------
    (matplotlib.figure.Figure, str)
        The figure and a status message. On any error, a placeholder figure
        with the message and a "❌ ..." status are returned instead.
    """
    from matplotlib import colors as mcolors
    fig = None
    try:
        tag = f"{start_month.replace('-','')}_{end_month.replace('-','')}"
        local_tif = f"/content/FIRES_LOCAL_{tag}.tif"
        gee_tif   = f"/content/FIRES_GEE_{tag}.tif"
        rasterize_fires(csv_path, start_month, end_month, local_tif)

        # Remove any earlier download so a failed export cannot silently reuse a stale file.
        if os.path.exists(gee_tif):
            os.remove(gee_tif)
        _download_fires_gee_tif(start_month, end_month, gee_tif, conf_thr=conf_thr)
        # geemap may report download failures by printing rather than raising; check explicitly.
        if not os.path.isfile(gee_tif):
            raise FileNotFoundError(f"No se generó el GeoTIFF de GEE: {gee_tif}")

        # Binary palette shared by both panels: 0 -> white, 1 -> red.
        cmap = mcolors.ListedColormap(['white', 'red'])
        norm = mcolors.BoundaryNorm([-0.5, 0.5, 1.5], cmap.N)

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6), sharex=True, sharey=True)
        for ax, tif, title in zip((ax1, ax2), (local_tif, gee_tif), ("CSV → raster", "GEE directo")):
            with rasterio.open(tif) as src:
                data = src.read(1)
                nodata = src.nodata
                extent = [src.bounds.left, src.bounds.right, src.bounds.bottom, src.bounds.top]
            # A NaN nodata value never compares equal, so it needs isnan as well.
            if nodata is None or np.isnan(nodata):
                mask = np.isnan(data)
            else:
                mask = data == nodata
            ax.imshow(np.ma.masked_where(mask, data), cmap=cmap, norm=norm, extent=extent)
            ax.set_title(title); ax.set_xlabel("Longitud"); ax.set_ylabel("Latitud"); ax.grid(ls=":")
        fig.suptitle(f"Incendios {start_month} → {end_month}"); fig.tight_layout()
        return fig, "✔️ Comparación generada"
    except Exception as e:
        if fig is not None:
            plt.close(fig)  # do not leak the half-drawn comparison figure
        fig, ax = plt.subplots(); ax.axis('off'); ax.text(0.5, 0.5, str(e), ha='center', va='center', color='red')
        return fig, f"❌ {e}"

# =================== 12) MÉTRICAS GLOBALES ===================
def correlate_spei_firms_csv(k: int, lag: int, firms_csv_path: str, fire_min: int = 1):
    """
    Deprecated placeholder kept only so existing callers keep working.

    All arguments are ignored. Returns the same placeholder figure three
    times (one per former output) plus a status message telling the user to
    run the GWSS analysis instead.
    """
    def _notice(message, status):
        # Single blank figure carrying a red message, reused for all three outputs.
        fig, ax = plt.subplots()
        ax.axis('off')
        ax.text(0.5, 0.5, message, ha='center', va='center', color='red')
        return fig, fig, fig, status

    try:
        return _notice("OBSOLETO: Usar análisis GWSS (basado en distancia)", "Función obsoleta. Usar GWSS.")
    except Exception as e:
        return _notice(str(e), f"❌ {e}")

# =================== 13) GWSS — UNIFICADO (SIMPLIFICADO) ====================

def _parse_years_str(years_str: str):
    years = set()
    if not years_str or not str(years_str).strip(): return []
    parts = re.split(r'[,\s]+', years_str.strip())
    for p in parts:
        if not p: continue
        if '-' in p:
            a, b = p.split('-', 1); a, b = int(a), int(b); years.update(range(min(a, b), max(a, b) + 1))
        else:
            years.add(int(p))
    return sorted(years)

def _cellid_to_lonlat(cell_id: str):
    lon_str, lat_str = str(cell_id).split('_'); return int(lon_str)/1e4, int(lat_str)/1e4

#
# LA FUNCIÓN _apply_riskset_and_sampling HA SIDO ELIMINADA
#

def _months_by_enso_phase(enso_csv_path, phase="nino", thr=0.5, enso_lag=0):
    ser = _read_enso_monthly(enso_csv_path).astype(float)
    if int(enso_lag) != 0: ser = ser.shift(-int(enso_lag), freq="MS")
    if phase == "nino":   mask = ser >= thr
    elif phase == "nina": mask = ser <= -thr
    elif phase == "neutral": mask = (ser.abs() < thr)
    else: raise ValueError("phase: 'nino'|'nina'|'neutral'")
    months = pd.Series(mask[mask.index.notna() & mask.notna()])
    return set(months[months].index.to_period("M"))

def _prepare_gw_input_unified(k: int, lag: int, firms_csv_path: str, years_str: str,
                              phase: str, enso_csv_path: str, enso_thr: float, enso_lag: int) -> int:
    """
    Build the cell-month GWSS input using distance to the nearest fire.

    For every (cell, month) row with finite values:
      - X (``spei``): the cell's monthly SPEI, read from
        ``{DRIVE_DIR_SPEI}/SPEI_{k}_month.csv``. Its date is shifted forward
        by ``lag`` months, so SPEI at month t is paired with fires at
        month t + lag.
      - Y (``fire``): closeness to the nearest fire cell in that same month,
        ``1 / (distance_km + 1)`` with haversine distances. It is 1.0 for a
        cell that burned itself, and 0.0 for every cell in a month with no
        fire in any SPEI cell. Fires in cells absent from the SPEI table
        are ignored.

    Month filters apply to the lag-shifted SPEI months:
      - ``years_str`` (parsed by _parse_years_str).
      - ENSO phase, when ``phase`` is "nino", "nina" or "neutral"; any
        other value (e.g. "all") disables this filter. The ENSO CSV
        defaults to ``{DRIVE_DIR_SPEI}/ENSO.csv`` when ``enso_csv_path``
        is empty.

    Side effects (current working directory):
      - ``gw_input.csv``: cell_id, lon, lat, spei, fire (one row per
        cell-month).
      - ``gw_sites.csv``: cell_id, lon, lat (one row per SPEI cell left
        after filtering).

    Returns:
        Number of rows written to gw_input.csv.

    Raises:
        FileNotFoundError: if the SPEI CSV does not exist.
        ValueError: if the filters leave no months, SPEI is empty, the
            merged table is empty, or ``fire`` has fewer than two distinct
            values.
    """
    from sklearn.neighbors import BallTree
    import numpy as np

    # --- 1. Monthly SPEI (apply the lag, then convert dates to 'M' periods) ---
    spei_path = f"{DRIVE_DIR_SPEI}/SPEI_{k}_month.csv"
    if not os.path.isfile(spei_path):
        raise FileNotFoundError(f"No existe {spei_path}")
    spei = pd.read_csv(spei_path, parse_dates=['date'], dtype={'cell_id': str})
    spei['date'] = (spei['date'] + pd.DateOffset(months=int(lag))).dt.to_period('M')

    # --- 2. Normalised monthly FIRMS ---
    fires = _csv_firms_monthly_df(firms_csv_path) # validates the schema
    fires['date'] = fires['date'].dt.to_period('M')

    # --- 3. Year and ENSO filters (build the set of months to keep) ---
    # None means "no filter requested"; an empty set means "filters matched nothing".
    months_keep: set[pd.Period] | None = None
    years = _parse_years_str(years_str)

    all_spei_months = set(spei['date'].unique())

    if years:
        months_years = {m for m in all_spei_months if m.year in years}
        months_keep = months_years if months_years else set()

    if phase in ("nino", "nina", "neutral"):
        enso_csv_eff = enso_csv_path or f"{DRIVE_DIR_SPEI}/ENSO.csv"
        months_phase = _months_by_enso_phase(enso_csv_eff, phase=phase, thr=float(enso_thr), enso_lag=int(enso_lag))
        # Intersect with the year filter when both are active.
        months_keep = months_phase if months_keep is None else (months_keep & months_phase)

    if months_keep is not None:
        if len(months_keep) == 0:
            raise ValueError("Tras filtros de años/ENSO no quedan meses.")
        spei  = spei[spei['date'].isin(months_keep)]
        fires = fires[fires['date'].isin(months_keep)]

    if spei.empty:
        raise ValueError("SPEI vacío tras filtros.")

    # --- 4. Distance (closeness) logic ---

    # 4a. Master cell list (every cell that has SPEI data). Coordinates are decoded
    #     from cell_id; BallTree's haversine metric expects [lat, lon] in radians.
    all_cells = pd.DataFrame(spei['cell_id'].unique(), columns=['cell_id'])
    coords_lonlat = pd.DataFrame(
        all_cells['cell_id'].apply(_cellid_to_lonlat).tolist(),
        index=all_cells.index,
        columns=['lon', 'lat']
    )
    all_cells = all_cells.join(coords_lonlat)
    coords_all_rad = np.deg2rad(all_cells[['lat', 'lon']].values.astype(float))
    # cell_id -> row position in all_cells / coords_all_rad
    all_cells_idx_map = pd.Series(index=all_cells['cell_id'], data=np.arange(len(all_cells)))

    # 4b. Fire cells PER MONTH, as row positions; fire cells without SPEI map to NaN and are dropped.
    fire_events_by_month = fires[fires['fires'] > 0][['date', 'cell_id']].drop_duplicates()
    fire_events_by_month['idx'] = fire_events_by_month['cell_id'].map(all_cells_idx_map)
    fire_idxs_by_month = fire_events_by_month.dropna(subset=['idx']).groupby('date')['idx'].apply(list)

    # 4c. Compute distances ONLY for months that have fires
    monthly_results = []
    R_km = 6371.0088  # mean Earth radius (km); converts haversine radians to km

    valid_months_with_fires = set(spei['date'].unique()) & set(fire_idxs_by_month.index)

    for month in valid_months_with_fires:
        fire_indices = fire_idxs_by_month[month]
        fire_indices_int = [int(i) for i in fire_indices if pd.notna(i)]
        if not fire_indices_int:
            continue

        fire_coords_rad = coords_all_rad[fire_indices_int]

        # Nearest burned cell of this month for every master cell.
        tree = BallTree(fire_coords_rad, metric='haversine')
        distances_rad, _ = tree.query(coords_all_rad, k=1)
        distances_km = distances_rad.flatten() * R_km

        closeness = 1.0 / (distances_km + 1.0)

        month_df = all_cells[['cell_id']].copy()
        month_df['date'] = month
        month_df['fire'] = closeness # Y = closeness (1.0 = on fire, ~0.0 = far away)

        monthly_results.append(month_df)

    if monthly_results:
        df_dist = pd.concat(monthly_results)
    else:
        df_dist = pd.DataFrame(columns=['cell_id', 'date', 'fire'])

    # 5. Join SPEI, coordinates and distances

    # 5a. Join SPEI with the master cell list (to get lon/lat)
    df = spei.merge(all_cells, on='cell_id', how='left')

    # 5b. Left-join the distance results.
    #      Months WITHOUT fires get NaN in 'fire'.
    df = df.merge(df_dist, on=['cell_id', 'date'], how='left')

    # 5c. Fill the NaNs (months without fires):
    #     - 'fire' (closeness) = 0.0 (infinite distance)
    df['fire'] = df['fire'].fillna(0.0)

    # 5d. Drop rows with invalid SPEI or coordinates
    df = df[np.isfinite(df['spei']) & np.isfinite(df['fire']) & np.isfinite(df['lon'])]

    if df.empty:
         raise ValueError("El dataframe final está vacío tras unir SPEI y distancias.")

    # --- 6. Balancing and risk-set sampling were removed ---

    if df['fire'].nunique() < 2:
        raise ValueError("La variable 'fire' (closeness) no tiene varianza. (Probablemente 0 incendios en el período)")

    # --- 7. Save the input (all rows) + sites (one per cell) ---
    df[['cell_id', 'lon', 'lat', 'spei', 'fire']].to_csv("gw_input.csv", index=False)
    # Save ALL sites (cells)
    all_cells[['cell_id', 'lon', 'lat']].to_csv("gw_sites.csv", index=False)

    return len(df)


def _run_gwss_gwmodel(bw_neighbors: int = 120, kernel: str = "bisquare", use_spearman: bool = False):
    """
    Run GWmodel::gwss in R (via rpy2) and return local SPEI–fire correlations per site.

    Inputs are read from the current working directory:
      - ``gw_input.csv``: cell-month rows with spei, fire, lon, lat.
      - ``gw_sites.csv``: summary locations; when absent, the input rows
        are used as sites.
    The R packages GWmodel and sp are installed on first use.

    Bandwidth: adaptive, with ``min(bw_neighbors, n_rows - 1)`` neighbours.
    Cell-month rows stack on the same coordinates. When one coordinate
    repeats at least that many times, a fixed bandwidth (km) is used
    instead: the median distance from each site to its k-th nearest unique
    location, with k ≈ 0.6·bw and at least 5.

    Correlation column:
      - With ``use_spearman``, the local Spearman column is preferred.
        GWmodel names it ``Spearman_rho_<a>.<b>``; the legacy ``SCorr_*``
        names are kept as fallbacks.
      - Otherwise the Pearson ``Corr_<a>.<b>`` column is preferred.
      - In both cases the other family is the last resort.

    Parameters
    ----------
    bw_neighbors : int
        Requested adaptive bandwidth (number of neighbours).
    kernel : str
        GWmodel kernel: gaussian, exponential, bisquare, tricube or boxcar.
    use_spearman : bool
        Prefer local Spearman over local Pearson correlation.

    Returns
    -------
    pandas.DataFrame
        One row per site with columns: cell_id, corr, mean_spei, mean_fire,
        sd_spei, sd_fire, bw_mode, bw_used, neff (rows within the bandwidth
        of the site), corr_name, corr_type, lon, lat. R first writes it to
        ``gw_corr_out.csv``.

    Raises
    ------
    ValueError
        If ``kernel`` is not a GWmodel kernel. R-side errors propagate as
        rpy2 exceptions.
    """
    # Validate before interpolating into the R source: any other string would be
    # spliced verbatim into the R program.
    valid_kernels = ("gaussian", "exponential", "bisquare", "tricube", "boxcar")
    if kernel not in valid_kernels:
        raise ValueError(f"kernel inválido: {kernel!r}. Opciones: {', '.join(valid_kernels)}")

    # Candidate correlation columns in priority order (body of an R character vector).
    # gwss names local Spearman columns 'Spearman_rho_<a>.<b>'; listing them first is
    # what makes use_spearman actually select Spearman instead of Pearson 'Corr_*'.
    spearman_cols = (
        "'Spearman_rho_spei.fire','Spearman_rho_fire.spei',"
        "'SCorr_spei.fire','SCorr_fire.spei','SCorr_spei_fire','SCorr_fire_spei'"
    )
    pearson_cols = "'Corr_spei.fire','Corr_fire.spei','Corr_spei_fire','Corr_fire_spei'"
    if use_spearman:
        search_order = f"{spearman_cols},{pearson_cols}"
    else:
        search_order = f"{pearson_cols},{spearman_cols}"

    r_code = f"""
    if (!requireNamespace('GWmodel', quietly=TRUE))
        install.packages('GWmodel', repos='https://cran.rstudio.com');
    if (!requireNamespace('sp', quietly=TRUE))
        install.packages('sp', repos='https://cran.rstudio.com');
    suppressPackageStartupMessages({{ library(GWmodel); library(sp) }})

    df <- read.csv('gw_input.csv', stringsAsFactors=FALSE)
    sites <- if (file.exists('gw_sites.csv')) read.csv('gw_sites.csv', stringsAsFactors=FALSE) else df[,c('cell_id','lon','lat')]

    req <- c('spei','fire','lon','lat')
    if (!all(req %in% names(df))) stop('gw_input.csv sin columnas requeridas.')
    df$spei <- as.numeric(df$spei); df$fire <- as.numeric(df$fire)
    df$lon  <- as.numeric(df$lon);  df$lat  <- as.numeric(df$lat)
    df <- df[is.finite(df$spei) & is.finite(df$fire) & is.finite(df$lon) & is.finite(df$lat), , drop=FALSE]
    if (nrow(df) < 5) stop('Muy pocas filas en gw_input.csv (n < 5).')

    sites$lon <- as.numeric(sites$lon); sites$lat <- as.numeric(sites$lat)
    sites <- sites[is.finite(sites$lon) & is.finite(sites$lat), , drop=FALSE]
    if (nrow(sites) < 1) stop('gw_sites.csv sin sitios válidos.')

    coordinates(df)    <- ~ lon + lat
    proj4string(df)    <- CRS('+proj=longlat +datum=WGS84 +no_defs')
    coordinates(sites) <- ~ lon + lat
    proj4string(sites) <- CRS('+proj=longlat +datum=WGS84 +no_defs')

    bw_req <- {int(bw_neighbors)}
    bw_adapt <- min(bw_req, nrow(df) - 1L)
    if (bw_adapt < 2L) stop('BW adaptativo demasiado pequeño (bw < 2).')

    qflag <- FALSE

    key_df <- paste0(round(coordinates(df)[,1], 10L), '|', round(coordinates(df)[,2], 10L))
    max_dups <- max(table(key_df))
    use_adaptive <- TRUE
    bw_fixed <- NA_real_

    if (is.finite(max_dups) && max_dups >= bw_adapt) {{
      use_adaptive <- FALSE
      uid <- !duplicated(coordinates(df))
      df_unique <- df[uid,]
      if (nrow(df_unique) < 3) stop('Muy pocas ubicaciones únicas para BW fijo.')

      k_unique <- max(5L, min(round(bw_adapt * 0.6), nrow(df_unique) - 1L))
      dmat_sites_unique <- spDists(sites, df_unique, longlat=TRUE)
      kth <- apply(dmat_sites_unique, 1, function(x) {{
        xs <- sort(x, partial=k_unique)
        xs[k_unique]
      }})
      bw_fixed <- suppressWarnings(median(kth[is.finite(kth) & kth > 0], na.rm=TRUE))
      if (!is.finite(bw_fixed) || bw_fixed <= 0) {{
        bw_fixed <- suppressWarnings(max(kth[is.finite(kth)], na.rm=TRUE))
      }}
      if (!is.finite(bw_fixed) || bw_fixed <= 0) stop('No se pudo estimar BW fijo > 0.')
    }}

    res <- if (use_adaptive) {{
      gwss(
        data          = df,
        summary.locat = sites,
        vars          = c('spei','fire'),
        kernel        = '{kernel}',
        adaptive      = TRUE,
        bw            = bw_adapt,
        longlat       = TRUE,
        quantile      = qflag
      )
    }} else {{
      gwss(
        data          = df,
        summary.locat = sites,
        vars          = c('spei','fire'),
        kernel        = '{kernel}',
        adaptive      = FALSE,
        bw            = bw_fixed,
        longlat       = TRUE,
        quantile      = qflag
      )
    }}

    S  <- res$SDF@data
    cn <- names(S)

    pick_corr <- function(cn) {{
      # ================== CAMBIO 2: Se inyecta el orden de búsqueda ==================
      cand <- c({search_order})
      # ============================ FIN DEL CAMBIO 2 ===============================
      hit <- cand[cand %in% cn]
      if (length(hit) > 0) return(hit[1])
      g <- grep('S?Corr.*(spei).*(fire)', cn, ignore.case=TRUE, value=TRUE)
      if (length(g) > 0) return(g[1])
      stop('No se encontró columna de correlación local (Corr/SCorr).')
    }}
    corr_col <- pick_corr(cn)

    get_sd <- function(varname) {{
      sd_col <- cn[grepl('(LSD|\\\\bSD\\\\b|Std)', cn, ignore.case=TRUE) & grepl(varname, cn, ignore.case=TRUE)]
      if (length(sd_col) > 0) return(S[[sd_col[1]]])
      var_col <- cn[grepl('Var', cn, ignore.case=TRUE) & grepl(varname, cn, ignore.case=TRUE)]
      if (length(var_col) > 0) return(sqrt(pmax(0, S[[var_col[1]]])))
      rep(NA_real_, nrow(S))
    }}
    pick_mean <- function(varname) {{
      m_col <- cn[grepl('Mean|Ave|Avg|LM', cn, ignore.case=TRUE) & grepl(varname, cn, ignore.case=TRUE)]
      if (length(m_col) > 0) S[[m_col[1]]] else rep(NA_real_, nrow(S))
    }}

    dmat_sites_all <- spDists(sites, df, longlat=TRUE)
    if (use_adaptive) {{
      neff <- apply(dmat_sites_all, 1, function(d) {{
        thr <- sort(d, partial=bw_adapt)[bw_adapt]
        sum(is.finite(d) & d <= thr)
      }})
    }} else {{
      neff <- apply(dmat_sites_all, 1, function(d) sum(is.finite(d) & d <= bw_fixed))
    }}

    out <- data.frame(
      cell_id   = sites$cell_id,
      corr      = S[[corr_col]],
      mean_spei = pick_mean('spei'),
      mean_fire = pick_mean('fire'),
      sd_spei   = get_sd('spei'),
      sd_fire   = get_sd('fire'),
      bw_mode   = rep(if (use_adaptive) 'adaptive' else 'fixed', nrow(S)),
      bw_used   = rep(if (use_adaptive) bw_adapt else bw_fixed, nrow(S)),
      neff      = as.numeric(neff),
      corr_name = rep(corr_col, nrow(S)),
      corr_type = rep(if (grepl('^(SCorr|Spearman)', corr_col)) 'Spearman' else 'Pearson', nrow(S))
    )
    write.csv(out, 'gw_corr_out.csv', row.names=FALSE)
    """
    ro.r(r_code)

    # Attach lon/lat back onto the per-site results written by R.
    out = pd.read_csv("gw_corr_out.csv")
    if os.path.exists("gw_sites.csv"):
        sites = pd.read_csv("gw_sites.csv")
    else:
        sites = pd.read_csv("gw_input.csv")[['cell_id','lon','lat']].drop_duplicates()
    out = out.merge(sites[['cell_id','lon','lat']], on='cell_id', how='left')
    return out

def _bh_fdr(pvals: np.ndarray) -> np.ndarray:
    p = np.asarray(pvals, float); m = np.sum(~np.isnan(p))
    order = np.argsort(np.where(np.isnan(p), np.inf, p))
    ranks = np.empty_like(order); ranks[order] = np.arange(1, len(p)+1)
    q = p * m / ranks; q_sorted = np.minimum.accumulate(q[order][::-1])[::-1]
    q_final = np.empty_like(q); q_final[order] = q_sorted
    return np.clip(q_final, 0, 1)

def _compute_gw_post_metrics(df_corr: pd.DataFrame) -> pd.DataFrame:
    """
    Add significance and slope columns to a GWSS local-correlation table.

    Uses ``corr`` (required), plus ``neff`` / ``bw_used`` and ``sd_spei`` /
    ``sd_fire`` when present. The input frame is not modified.

    Columns added or overwritten:
      neff       ``neff`` if finite, else ``bw_used``, else 30; floored at 5.
      pval       Two-sided p-value of t = r * sqrt((neff - 2) / (1 - r^2)).
                 The denominator is floored at 1e-12, and max(neff - 2, 1)
                 degrees of freedom are used.
      qval       Benjamini–Hochberg q-values of ``pval`` (via _bh_fdr).
      sig_fdr05  qval <= 0.05.
      slope      Slope of fire on spei implied by r and the local SDs,
                 r * sd_fire / sd_spei, clipped to [-1, 1]. It is NaN when
                 non-finite or when the SD columns are absent.
    """
    from scipy.stats import t as student_t

    out = df_corr.copy()

    # Effective sample size, falling back to the bandwidth and then to 30.
    neff_in = pd.to_numeric(out.get('neff', np.nan), errors='coerce')
    bw_in = pd.to_numeric(out.get('bw_used', np.nan), errors='coerce')
    out['neff'] = np.where(np.isfinite(neff_in), neff_in, bw_in).astype(float)
    out['neff'] = out['neff'].fillna(30).clip(lower=5)

    r = pd.to_numeric(out['corr'], errors='coerce').astype(float)
    dof = out['neff'] - 2.0
    ratio = np.clip(dof, 0.0, None) / np.clip(1.0 - r**2, 1e-12, None)
    t_stat = r * np.sqrt(ratio)
    p_values = 2.0 * student_t.sf(np.abs(np.asarray(t_stat, float)),
                                  np.asarray(dof.clip(lower=1), float))

    out['pval'] = p_values
    out['qval'] = _bh_fdr(p_values)
    out['sig_fdr05'] = (out['qval'] <= 0.05)

    if not {'sd_spei', 'sd_fire'}.issubset(out.columns):
        out['slope'] = np.nan
        return out

    sd_spei = pd.to_numeric(out['sd_spei'], errors='coerce')
    sd_fire = pd.to_numeric(out['sd_fire'], errors='coerce')
    with np.errstate(divide='ignore', invalid='ignore'):
        slope = r * (sd_fire / sd_spei)
    slope[~np.isfinite(slope)] = np.nan
    out['slope'] = np.clip(slope, -1.0, 1.0)
    return out

def gw_correlation_yearly_panels(
    k: int,
    lag: int,
    firms_csv_path: str,
    start_year: int = 2014,
    end_year: int   = 2024,
    bw_neighbors: int = 480,
    kernel: str = "bisquare",
    use_spearman: bool = False,
    nan_strategy: str = "mask",         # "mask" (omite NaN) o "impute0" (rellena 0)
    enso_csv_path: str | None = None,
    enso_phase: str = "all",           # "all","nino","nina","neutral"
    enso_thr: float = 0.5,
    enso_lag: int = 0,
    ncols: int = 4,                    # más filas que columnas
    dot_size: float = 9.0,             # tamaño del punto en scatter
    out_dir: str = None                # por defecto: DRIVE_DIR_SPEI
):
    """
    Ejecuta GWSS (lógica de distancia) año a año (start_year..end_year).
    Ya no usa 'spei_agg', 'fire_metric', 'riskset' o 'balanceo'.
    """
    try:
        import numpy as _np
        import matplotlib.pyplot as _plt
        from matplotlib.colors import TwoSlopeNorm
        import matplotlib as _mpl

        if out_dir is None:
            out_dir = DRIVE_DIR_SPEI
        Path(out_dir).mkdir(parents=True, exist_ok=True)

        years = list(range(int(start_year), int(end_year) + 1))
        results_per_year = []
        kept_years, skipped_years = [], []

        enso_csv_eff = enso_csv_path or f"{DRIVE_DIR_SPEI}/ENSO.csv"

        for y in years:
            if enso_phase != "all":
                try:
                    months_phase = _months_by_enso_phase(
                        enso_csv_eff, phase=enso_phase, thr=float(enso_thr), enso_lag=int(enso_lag)
                    )
                    has_month_in_year = any(getattr(p, "year", None) == int(y) for p in months_phase)
                    if not has_month_in_year:
                        skipped_years.append(y)
                        continue
                except Exception:
                    pass

            try:
                _ = _prepare_gw_input_unified(
                    k=int(k),
                    lag=int(lag),
                    firms_csv_path=firms_csv_path,
                    years_str=str(y),
                    phase=enso_phase,
                    enso_csv_path=(enso_csv_eff if enso_phase != "all" else enso_csv_eff),
                    enso_thr=float(enso_thr),
                    enso_lag=int(enso_lag),
                )
            except ValueError as ve:
                msg = str(ve).lower()
                if ("no quedan meses" in msg) or ("spei vacío" in msg) or ("spei vacio" in msg) or ("tras filtrar" in msg) or ("no se encontraron meses" in msg) or ("dataframe final está vacío" in msg) or ("no tiene varianza" in msg):
                    skipped_years.append(y)
                    continue
                else:
                    raise

            corr_df = _run_gwss_gwmodel(bw_neighbors=int(bw_neighbors), kernel=kernel, use_spearman=bool(use_spearman))
            dfm = _compute_gw_post_metrics(corr_df).copy()
            if dfm is None or len(dfm) == 0:
                skipped_years.append(y)
                continue
            dfm["year"] = y
            results_per_year.append(dfm)
            kept_years.append(y)

        if not results_per_year:
            raise ValueError("No se obtuvieron resultados: todos los años fueron omitidos por filtros ENSO/datos.")

        region_fc, _ = _get_regions()
        gdf_regions  = geemap.ee_to_gdf(region_fc)
        bounds = gdf_regions.total_bounds
        xmin, ymin, xmax, ymax = bounds

        n_years = len(kept_years)
        ncols   = max(1, int(ncols))
        nrows   = int(_np.ceil(n_years / ncols))
        corr_norm  = TwoSlopeNorm(vmin=-1, vcenter=0.0, vmax=1)
        slope_norm = TwoSlopeNorm(vmin=-0.5, vcenter=0.0, vmax=0.5)

        # ---------- Panel 1: Correlación local ----------
        fig_corr, axs = _plt.subplots(nrows, ncols, figsize=(3.2*ncols, 2.6*nrows), sharex=True, sharey=True, dpi=300)
        axs = _np.atleast_2d(axs)
        last_im = None
        for i, y in enumerate(kept_years):
            r = i // ncols; c = i % ncols
            ax = axs[r, c]
            dfy = next(d for d in results_per_year if int(d["year"].iloc[0]) == y)
            corr_vals = _np.asarray(dfy["corr"], float)
            if nan_strategy == "impute0":
                corr_vals = _np.where(_np.isfinite(corr_vals), corr_vals, 0.0)
                mask = _np.ones_like(corr_vals, bool)
            else:
                mask = _np.isfinite(corr_vals)
            sc = ax.scatter(dfy.loc[mask, "lon"].astype(float),
                            dfy.loc[mask, "lat"].astype(float),
                            c=corr_vals[mask], s=dot_size, cmap="RdGy", norm=corr_norm, linewidths=0, zorder=2)
            last_im = sc
            gpd.GeoSeries(unary_union(gdf_regions.boundary.geometry)).plot(ax=ax, edgecolor="black", linewidth=0.7, facecolor="none", zorder=5)
            ax.set_title(f"Año {y}", fontsize=10, pad=2); ax.set_xlim(xmin, xmax); ax.set_ylim(ymin, ymax); ax.axis("off")
        for j in range(n_years, nrows*ncols):
            axs[j//ncols, j % ncols].axis('off')
        _plt.tight_layout(rect=[0, 0, 0.88, 1])
        cax = fig_corr.add_axes([0.90, 0.15, 0.025, 0.7])
        fig_corr.colorbar(last_im, cax=cax, label="Correlación local (GW)")

        # ---------- Panel 2: Significancia local ----------
        fig_sig, axs2 = _plt.subplots(nrows, ncols, figsize=(3.2*ncols, 2.6*nrows), sharex=True, sharey=True, dpi=300)
        axs2 = _np.atleast_2d(axs2)
        all_logq_vals = []
        for dfy in results_per_year:
            q = np.asarray(dfy["qval"], float)
            logq = -np.log10(np.clip(q, 1e-12, 1.0))
            all_logq_vals.append(logq[np.isfinite(logq)])
        if all_logq_vals:
            concatenated_vals = np.concatenate(all_logq_vals)
            vmin = np.min(concatenated_vals) if concatenated_vals.size > 0 else 0
            vmax = np.max(concatenated_vals) if concatenated_vals.size > 0 else 1
            norm_sig = plt.Normalize(vmin=vmin, vmax=vmax)
        else:
            norm_sig = plt.Normalize(vmin=0, vmax=1)
        last_im2 = None
        for i, y in enumerate(kept_years):
            r = i // ncols; c = i % ncols
            ax = axs2[r, c]
            dfy = next(d for d in results_per_year if int(d["year"].iloc[0]) == y)
            q = _np.asarray(dfy["qval"], float)
            logq = -_np.log10(_np.clip(q, 1e-12, 1.0))
            if nan_strategy == "impute0":
                logq = _np.where(_np.isfinite(logq), logq, 0.0)
                mask = _np.ones_like(logq, bool)
            else:
                mask = _np.isfinite(logq)
            sc = ax.scatter(dfy.loc[mask, "lon"].astype(float),
                            dfy.loc[mask, "lat"].astype(float),
                            c=logq[mask], s=dot_size, cmap="viridis", linewidths=0, zorder=2, norm=norm_sig)
            last_im2 = sc
            sig_mask = (dfy["sig_fdr05"].fillna(False).values) & mask
            ax.scatter(dfy.loc[sig_mask, "lon"].astype(float),
                       dfy.loc[sig_mask, "lat"].astype(float), facecolors="none",
                       edgecolors="k", s=dot_size*3.0, linewidths=0.6, zorder=3)
            gpd.GeoSeries(unary_union(gdf_regions.boundary.geometry)).plot(ax=ax, edgecolor="black", linewidth=0.7, facecolor="none", zorder=5)
            ax.set_title(f"Año {y}", fontsize=10, pad=2); ax.set_xlim(xmin, xmax); ax.set_ylim(ymin, ymax); ax.axis("off")
        for j in range(n_years, nrows*ncols):
            axs2[j//ncols, j % ncols].axis('off')
        _plt.tight_layout(rect=[0, 0, 0.88, 1])
        cax2 = fig_sig.add_axes([0.90, 0.15, 0.025, 0.7])
        fig_sig.colorbar(last_im2, cax=cax2, label="-log10(q)")

        # ---------- Panel 3: Pendiente local ----------
        fig_slope, axs3 = _plt.subplots(nrows, ncols, figsize=(3.2*ncols, 2.6*nrows), sharex=True, sharey=True, dpi=300)
        axs3 = _np.atleast_2d(axs3)
        last_im3 = None
        for i, y in enumerate(kept_years):
            r = i // ncols; c = i % ncols
            ax = axs3[r, c]
            dfy = next(d for d in results_per_year if int(d["year"].iloc[0]) == y)
            slope = _np.asarray(dfy["slope"], float)
            has_any = _np.isfinite(slope).any()
            if has_any:
                if nan_strategy == "impute0":
                    slope = _np.where(_np.isfinite(slope), slope, 0.0)
                    mask = _np.ones_like(slope, bool)
                else:
                    mask = _np.isfinite(slope)
                sc = ax.scatter(dfy.loc[mask, "lon"].astype(float),
                                dfy.loc[mask, "lat"].astype(float),
                                c=slope[mask], s=dot_size, cmap="PiYG", norm=slope_norm, linewidths=0, zorder=2)
                last_im3 = sc
            else:
                ax.text(0.5, 0.5, "Sin SD locales", ha="center", va="center", fontsize=9)
            gpd.GeoSeries(unary_union(gdf_regions.boundary.geometry)).plot(ax=ax, edgecolor="black", linewidth=0.7, facecolor="none", zorder=5)
            ax.set_title(f"Año {y}", fontsize=10, pad=2); ax.set_xlim(xmin, xmax); ax.set_ylim(ymin, ymax); ax.axis("off")
        for j in range(n_years, nrows*ncols):
            axs3[j//ncols, j % ncols].axis('off')
        _plt.tight_layout(rect=[0, 0, 0.88, 1])
        if last_im3 is not None:
            cax3 = fig_slope.add_axes([0.90, 0.15, 0.025, 0.7])
            fig_slope.colorbar(last_im3, cax=cax3, label="Pendiente local y|x")

        # ---------- Panel 4: Resumen (ECDF + Histo) ----------
        fig_sum, (ax_ecdf, ax_hist) = _plt.subplots(1, 2, figsize=(10, 4), dpi=300, sharex=False)
        if len(kept_years) <= 10: cmap = _mpl.cm.get_cmap("tab10", len(kept_years))
        else: cmap = _mpl.cm.get_cmap("tab20", len(kept_years))
        bins = _np.linspace(-1.0, 1.0, 41)
        for i, y in enumerate(kept_years):
            dfy  = next(d for d in results_per_year if int(d["year"].iloc[0]) == y)
            vals = _np.asarray(dfy["corr"], float)
            vals = vals[_np.isfinite(vals)]
            if vals.size == 0: continue
            color = cmap(i)
            v = _np.sort(vals)
            ecdf = _np.arange(1, len(v)+1)/len(v)
            ax_ecdf.plot(v, ecdf, lw=1.8, alpha=0.95, color=color, label=str(y))
            h, edges = _np.histogram(vals, bins=bins, density=True)
            centers  = 0.5*(edges[:-1]+edges[1:])
            ax_hist.plot(centers, h, lw=1.4, alpha=0.9, color=color)
        ax_ecdf.axvline(0.0, color="k", lw=0.8, ls="--")
        ax_ecdf.set_title("ECDF correlación (años superpuestos)")
        ax_ecdf.set_xlabel("Correlación local"); ax_ecdf.set_ylabel("Proporción")
        ax_ecdf.grid(ls=":", alpha=0.5)
        ax_hist.set_title("Histograma densidad (superpuestos)")
        ax_hist.set_xlabel("Correlación local"); ax_hist.set_ylabel("Densidad")
        ax_hist.grid(ls=":", alpha=0.5)
        handles, labels = ax_ecdf.get_legend_handles_labels()
        import math
        ncols_leg = max(1, math.ceil(len(handles) / 2))
        fig_sum.subplots_adjust(bottom=0.25)
        leg = fig_sum.legend(
            handles=handles, labels=labels, title="Año",
            loc="upper center", ncol=ncols_leg, frameon=False,
            bbox_to_anchor=(0.5, 0.18), fontsize=9, title_fontsize=10,
            columnspacing=1.2, handlelength=2.6,
        )

        # 5) Guardado + mensaje con años omitidos
        tag = f"TS{k}_lag{lag}_{start_year}_{end_year}"
        f_corr = os.path.join(out_dir, f"GWSS_YearPanel_corr_{tag}.png")
        f_sig  = os.path.join(out_dir, f"GWSS_YearPanel_sig_{tag}.png")
        f_slo  = os.path.join(out_dir, f"GWSS_YearPanel_slope_{tag}.png")
        f_sum  = os.path.join(out_dir, f"GWSS_YearPanel_summary_{tag}.png")
        save_png_trim(fig_corr, f_corr, dpi=600)
        save_png_trim(fig_sig,  f_sig,  dpi=600)
        save_png_trim(fig_slope,f_slo,  dpi=600)
        save_png_trim(fig_sum,  f_sum,  dpi=600)
        omit_txt = f"\nOmitidos por ENSO/datos: {', '.join(map(str, skipped_years))}" if skipped_years else ""
        msg = (f"✔️ Paneles anuales guardados:\n"
               f"‣ {os.path.basename(f_corr)}\n"
               f"‣ {os.path.basename(f_sig)}\n"
               f"‣ {os.path.basename(f_slo)}\n"
               f"‣ {os.path.basename(f_sum)}\n"
               f"Años usados: {', '.join(map(str, kept_years))}{omit_txt}\n"
               f"Carpeta: {out_dir}")
        return fig_corr, fig_sig, fig_slope, fig_sum, msg

    except Exception as e:
        import traceback as _tb
        fig_err, ax_err = plt.subplots(figsize=(6, 4), dpi=150)
        ax_err.axis("off"); ax_err.text(0.5, 0.5, f"{e}", ha="center", va="center", color="red")
        return fig_err, fig_err, fig_err, fig_err, f"❌ {e}\n{_tb.format_exc()}"

from matplotlib.colors import TwoSlopeNorm
def _plot_gw_corr(df_corr, title="Correlación GW (SPEI vs Incendios)", nan_strategy: str = "mask", draw_borders: bool = True):
    """Scatter-map the local GW correlation (column 'corr') of each cell.

    Args:
        df_corr: table with 'lon', 'lat' and 'corr' columns (copied, never mutated).
        title: axes title.
        nan_strategy: "impute0" replaces missing correlations with 0.0; any other
            value drops rows whose 'corr' is missing.
        draw_borders: when True, overlay the region outlines from _get_regions()
            and zoom the axes to their bounding box.

    Returns:
        A matplotlib Figure. When no row survives the NaN handling, a text-only
        placeholder figure ("Sin celdas válidas") is returned.
    """
    data = df_corr.copy()
    if nan_strategy == "impute0":
        data['corr'] = data['corr'].astype(float).fillna(0.0)
    else:
        data = data.dropna(subset=['corr'])

    if data.empty:
        placeholder, ax = plt.subplots()
        ax.axis('off')
        ax.text(0.5, 0.5, "Sin celdas válidas", ha='center', va='center')
        return placeholder

    fig, ax = plt.subplots(figsize=(8, 7))
    # Diverging scale centred on 0 and fixed to ±0.5 so maps are comparable.
    diverging = TwoSlopeNorm(vmin=-0.5, vcenter=0.0, vmax=0.5)
    points = ax.scatter(
        data['lon'].astype(float),
        data['lat'].astype(float),
        c=data['corr'].astype(float),
        s=9, cmap='RdGy', norm=diverging, linewidths=0, zorder=2,
    )

    if draw_borders:
        regions_fc, _unused_roi = _get_regions()
        regions_gdf = geemap.ee_to_gdf(regions_fc)
        outline = gpd.GeoSeries(unary_union(regions_gdf.boundary.geometry))
        outline.plot(ax=ax, edgecolor="black", linewidth=0.7, facecolor="none", zorder=5)
        x_lo, y_lo, x_hi, y_hi = regions_gdf.total_bounds
        ax.set_xlim(x_lo, x_hi)
        ax.set_ylim(y_lo, y_hi)

    colorbar = plt.colorbar(points, ax=ax)
    colorbar.set_label("ρ local (GW)")
    ax.grid(ls=":")
    ax.set_xlabel("Longitud")
    ax.set_ylabel("Latitud")
    ax.set_title(title)
    fig.tight_layout()
    return fig

def _plot_gw_significance_map(df, title="Significancia local (FDR)", draw_borders: bool = True):
    """Map local significance as -log10(q) and ring the FDR<=0.05 cells in black.

    Args:
        df: table with 'lon', 'lat', 'qval' and 'sig_fdr05' columns (copied).
        title: axes title.
        draw_borders: when True (default, as before), overlay the region outlines
            from _get_regions() and zoom to their bounding box. Mirrors the same
            option of _plot_gw_corr.

    Returns:
        A matplotlib Figure; a text-only placeholder ("Sin datos") when df is empty.
    """
    dfp = df.copy()
    if dfp.empty:
        fig, ax = plt.subplots()
        ax.axis('off')
        ax.text(0.5, 0.5, 'Sin datos', ha='center', va='center')
        return fig

    lon = dfp['lon'].astype(float)
    lat = dfp['lat'].astype(float)
    # Clip q to [1e-12, 1] so -log10(q) stays within [0, 12]; NaN q stays NaN.
    logq = -np.log10(np.clip(dfp['qval'].astype(float), 1e-12, 1))
    # Explicit bool cast: fillna(False) on an object column is not guaranteed to
    # return a boolean dtype (pandas deprecated that implicit downcast).
    sig = dfp['sig_fdr05'].fillna(False).astype(bool).to_numpy()

    fig, ax = plt.subplots(figsize=(8, 7))
    sc = ax.scatter(lon, lat, c=logq, s=10, cmap='viridis', zorder=2)
    ax.scatter(lon[sig], lat[sig], facecolors='none', edgecolors='k', s=30, linewidths=0.6, zorder=3)

    if draw_borders:
        region_fc, _ = _get_regions()
        gdf_regions = geemap.ee_to_gdf(region_fc)
        gpd.GeoSeries(unary_union(gdf_regions.boundary.geometry)).plot(
            ax=ax, edgecolor='black', linewidth=0.7, facecolor='none', zorder=5)
        xmin, ymin, xmax, ymax = gdf_regions.total_bounds
        ax.set_xlim(xmin, xmax)
        ax.set_ylim(ymin, ymax)

    cb = plt.colorbar(sc, ax=ax)
    cb.set_label("−log10(q)")
    ax.set_xlabel("Longitud")
    ax.set_ylabel("Latitud")
    ax.set_title(title)
    ax.grid(ls=':')
    fig.tight_layout()
    return fig

def _plot_gw_slope_map(df, title="Sensibilidad local (pendiente y~x)"):
    """Map the local GW slope ('slope' column) of every cell.

    Colors use a diverging PiYG scale fixed to [-0.5, 0.5] around 0; values
    outside that range saturate at the colormap ends. Region outlines from
    _get_regions() are overlaid and the view is clipped to their bounding box.

    Returns:
        A matplotlib Figure; a text-only placeholder when no row has a slope.
    """
    frame = df.copy()
    has_slope = frame['slope'].notna().any()
    if not has_slope:
        placeholder, ax = plt.subplots()
        ax.axis('off')
        ax.text(0.5, 0.5, 'Sin SD locales → no hay pendiente', ha='center')
        return placeholder

    fig, ax = plt.subplots(figsize=(8, 7))
    slope_norm = TwoSlopeNorm(vmin=-0.5, vcenter=0.0, vmax=0.5)
    points = ax.scatter(frame['lon'], frame['lat'], c=frame['slope'],
                        s=10, cmap='PiYG', norm=slope_norm, zorder=2)

    regions_fc, _ = _get_regions()
    regions_gdf = geemap.ee_to_gdf(regions_fc)
    outline = gpd.GeoSeries(unary_union(regions_gdf.boundary.geometry))
    outline.plot(ax=ax, edgecolor='black', linewidth=0.7, facecolor='none', zorder=5)
    lon_lo, lat_lo, lon_hi, lat_hi = regions_gdf.total_bounds
    ax.set_xlim(lon_lo, lon_hi)
    ax.set_ylim(lat_lo, lat_hi)

    cbar = plt.colorbar(points, ax=ax)
    cbar.set_label("Pendiente local (≈ Δprob/ΔSPEI)")
    for setter, text in ((ax.set_xlabel, "Longitud"),
                         (ax.set_ylabel, "Latitud"),
                         (ax.set_title, title)):
        setter(text)
    ax.grid(ls=':')
    fig.tight_layout()
    return fig

def _plot_corr_hist_ecdf(df, title="Distribución de correlaciones locales"):
    """Plot the histogram (all vs FDR-significant) and the ECDF of local correlations.

    Args:
        df: table with 'corr' and 'sig_fdr05' columns.
        title: title of the histogram panel.

    Returns:
        A matplotlib Figure with two panels: left, histogram over [-1, 1] in 40
        bins plus a dashed median line; right, the empirical CDF. Only finite
        correlations are used; when none exist, the median line is omitted.
    """
    vals = df['corr'].astype(float).to_numpy()
    # Explicit bool cast: NumPy refuses object-dtype arrays as index masks, and
    # fillna(False) does not guarantee a bool dtype.
    sig = df['sig_fdr05'].fillna(False).astype(bool).to_numpy()
    finite = np.isfinite(vals)
    vals_all = vals[finite]
    vals_sig = vals[finite & sig]

    fig = plt.figure(figsize=(10, 5))
    ax1 = fig.add_subplot(1, 2, 1)
    bins = np.linspace(-1, 1, 41)
    ax1.hist(vals_all, bins=bins, color='0.7', edgecolor='0.3', label='Todas')
    if vals_sig.size > 0:
        ax1.hist(vals_sig, bins=bins, alpha=0.8, color='tab:blue', edgecolor='black', label='FDR ≤ 0.05')
    ax1.axvline(0, color='k', lw=1)
    ax1.set_xlabel("ρ local")
    ax1.set_ylabel("Celdas")
    ax1.grid(ls=':')
    if vals_all.size > 0:
        med = float(np.median(vals_all))
        ax1.axvline(med, color='tab:red', lw=1.5, ls='--', label=f"Mediana={med:.3f}")
    ax1.legend()
    ax1.set_title(title)

    ax2 = fig.add_subplot(1, 2, 2)
    v_sorted = np.sort(vals_all)
    # max(..., 1) keeps the division well-defined when there are no finite values.
    y = np.arange(1, v_sorted.size + 1) / max(v_sorted.size, 1)
    ax2.plot(v_sorted, y, lw=2)
    ax2.set_xlabel("ρ local")
    ax2.set_ylabel("ECDF")
    ax2.grid(ls=':')
    fig.tight_layout()
    return fig

def _gw_summary_text(dfm, bw_neighbors: int) -> str:
    """Build the one-line GWSS summary (cells, bandwidth, ρ stats, FDR counts, slope).

    Every statistic degrades to 'nan' / 'N/D' on empty or all-NaN columns instead
    of raising. The bandwidth falls back to the requested `bw_neighbors` when
    'bw_used' is missing or has no finite value.
    """
    n_all = len(dfm)
    sig = dfm['sig_fdr05'].fillna(False).astype(bool)
    corr = dfm['corr'].astype(float)
    n_sig = int(sig.sum())
    frac_sig = n_sig / n_all if n_all else 0

    finite_corr = corr[np.isfinite(corr)]
    if finite_corr.size:
        med_r = float(np.median(finite_corr))
        iqr_r = np.percentile(finite_corr, [25, 75])
    else:
        med_r = np.nan
        iqr_r = [np.nan, np.nan]
    neg_sig = int((sig & (corr < 0)).sum())
    pos_sig = int((sig & (corr > 0)).sum())

    slope = dfm['slope'].astype(float)
    finite_slope = slope[np.isfinite(slope)]
    if finite_slope.size:
        med_slope = float(np.median(finite_slope))
        p2, p98 = np.percentile(finite_slope, [2, 98])
        slope_txt = f"pendiente_mediana={med_slope:.3f} · p2–p98=[{p2:.3f}, {p98:.3f}]"
    else:
        slope_txt = "pendiente: N/D"

    # int(NaN) raises, so only take the median of finite bandwidths.
    bw_series = dfm['bw_used'].astype(float) if 'bw_used' in dfm else pd.Series([bw_neighbors], dtype=float)
    bw_finite = bw_series[np.isfinite(bw_series)]
    bw_val = int(np.median(bw_finite)) if bw_finite.size else int(bw_neighbors)

    return (f"celdas={n_all} · bw(adapt)≈{bw_val} · "
            f"ρ_mediana={med_r:.3f} · IQR=[{iqr_r[0]:.3f},{iqr_r[1]:.3f}] · "
            f"FDR≤0.05: {n_sig} ({frac_sig:.1%}) → neg={neg_sig}, pos={pos_sig} · {slope_txt}")


def gw_correlation_unified_plus(k: int, lag: int, firms_csv_path: str, years_str: str,
                                bw_neighbors: int, kernel: str,
                                use_spearman: bool, nan_strategy: str,
                                enso_csv_path: str, enso_phase: str, enso_thr: float, enso_lag: int):
    """Run the unified GWSS pipeline (SPEI vs fires) and build figures + summary.

    Steps:
      1. _prepare_gw_input_unified: build the GW input with the year/ENSO filters.
      2. _run_gwss_gwmodel: local correlations in R (GWmodel).
      3. _compute_gw_post_metrics: p/q values and FDR flags.
      4. Four figures (correlation, significance, slope, histogram/ECDF) plus text.

    Returns:
        (fig_corr, fig_sig, fig_sens, fig_hist, text_sum, status_msg). On any
        failure, one red-text error figure repeated four times, "❌ <error>" as
        the summary, and the error plus traceback as the status.
    """
    try:
        # Simplified input preparation (years / ENSO filters)
        _prepare_gw_input_unified(int(k), int(lag), firms_csv_path, years_str,
                                  enso_phase, enso_csv_path, float(enso_thr), int(enso_lag))

        # GWSS analysis in R
        corr_df = _run_gwss_gwmodel(bw_neighbors=int(bw_neighbors), kernel=kernel, use_spearman=bool(use_spearman))

        # Significance metrics
        dfm = _compute_gw_post_metrics(corr_df)

        period_txt = years_str if (years_str or "").strip() else 'todos'
        phase_txt = f" · ENSO={enso_phase.upper()}" if enso_phase != 'all' else ""
        fig_corr = _plot_gw_corr(dfm, title=f"Correlación GW — SPEI-{k} (lag {lag}), años={period_txt}{phase_txt}",
                                 nan_strategy=nan_strategy)
        fig_sig = _plot_gw_significance_map(dfm, title="Significancia local (FDR 5%)")
        fig_sens = _plot_gw_slope_map(dfm, title="Sensibilidad local (pendiente ≈ Δprob/ΔSPEI)")
        fig_hist = _plot_corr_hist_ecdf(dfm)

        text_sum = _gw_summary_text(dfm, int(bw_neighbors))

        status_msg = (f"✔️ GWSS OK (Distancia) · periodo={period_txt} · "
                      f"k={k} · lag={lag} · kernel={kernel} · "
                      f"{'Spearman' if use_spearman else 'Pearson'} · "
                      f"ENSO_phase={enso_phase} thr={enso_thr} lag_ENSO={enso_lag}")

        return fig_corr, fig_sig, fig_sens, fig_hist, text_sum, status_msg
    except Exception as e:
        import traceback as _tb
        fig, ax = plt.subplots(); ax.axis('off'); ax.text(0.5, 0.5, str(e), ha='center', va='center', color='red')
        return fig, fig, fig, fig, f"❌ {e}", f"❌ {e}\n{_tb.format_exc()}"

# ========= GWSS: explicación paso a paso para UNA celda (VERSIÓN SIMPLIFICADA) =========
def explain_gwss_single_cell(
    k: int, lag: int, firms_csv_path: str, years_str: str,
    bw_neighbors: int, kernel: str,
    use_spearman: bool,
    enso_phase: str = "all", enso_csv_path: str = None,
    enso_thr: float = 0.5, enso_lag: int = 0,
    cell_id: str = "random", seed: int = 13,
):
    """Explain the GWSS result of a single focal cell with a 6-panel figure (A–F).

    Efficient, SIMPLIFIED version: it no longer uses 'spei_agg', 'fire_metric',
    'riskset' or balancing.

    Pipeline:
      1. Build the global GW input (/content/gw_input.csv) with the filters.
      2. Run the global GWSS in R and compute p/q values.
      3. Pick the focal cell from /content/gw_sites.csv (random or `cell_id`).
      4. Rebuild its kernel-weighted neighborhood in Python for the figure and
         write the neighbors to a CSV under /content.

    Returns:
        dict with 'mensaje', 'cell_id', 'paso1_datos_entrada', 'paso2_vecinos',
        'paso4_correlacion', 'paso5_significancia', 'figura', 'csv_vecinos';
        or {'error': <message>} on any failure.

    NOTE(review): `seed` is accepted but not used; the "random" pick below uses
    an unseeded generator, so it is not reproducible.
    """
    try:
        from matplotlib.gridspec import GridSpec
        from scipy.stats import t as tdist
        from sklearn.neighbors import BallTree

        # 1. Prepare the global input data (writes /content/gw_input.csv)
        _prepare_gw_input_unified(
            int(k), int(lag), firms_csv_path, years_str, enso_phase, enso_csv_path,
            float(enso_thr), int(enso_lag)
        )
        df_input = pd.read_csv("/content/gw_input.csv", dtype={"cell_id": str})
        if df_input.empty:
            return {"error": "El DataFrame de entrada está vacío después de aplicar los filtros."}

        # Full site list: the focal cell is chosen from it, independently of R's output
        df_sites = pd.read_csv("/content/gw_sites.csv", dtype={"cell_id": str})
        if df_sites.empty:
            return {"error": "No se encontró gw_sites.csv"}

        # 2. Global GWSS analysis in R
        df_results_raw = _run_gwss_gwmodel(
            bw_neighbors=int(bw_neighbors), kernel=kernel, use_spearman=bool(use_spearman)
        )

        # 3. Significance metrics (p-value, q-value)
        df_results_full = _compute_gw_post_metrics(df_results_raw)

        # 4. Focal cell selection (from the full site list)
        unique_cells = df_sites["cell_id"].unique()
        if cell_id == "random":
            rng = np.random.default_rng()
            cell_id = rng.choice(unique_cells)
        elif cell_id not in unique_cells:
            return {"error": f"El cell_id '{cell_id}' no se encontró en la grilla base (gw_sites.csv)."}

        # 5. FINAL results for the focal cell
        if cell_id not in df_results_full["cell_id"].values:
            return {"error": f"El cell_id '{cell_id}' no tiene resultados en el análisis GWSS (probablemente no tiene datos SPEI válidos)."}
        focal_results = df_results_full[df_results_full["cell_id"] == cell_id].iloc[0]
        rho_local = focal_results['corr']
        p_local = focal_results['pval']
        q_local = focal_results['qval']
        t_local = focal_results.get('tstat', np.nan)
        # NaN is truthy, so a missing FDR flag must be mapped to False explicitly.
        raw_sig = focal_results['sig_fdr05']
        signif_fdr05 = bool(raw_sig) if pd.notna(raw_sig) else False

        # neff may be NaN in the R output and int(NaN) raises, so all later uses
        # go through these guarded values (0 means "unknown").
        neff = float(focal_results["neff"])
        neff_ok = bool(np.isfinite(neff))
        n_eff_pairs = int(neff) if neff_ok else 0
        df_local = int(max(n_eff_pairs - 2, 1)) if neff_ok else 0

        # 6. INPUT data for the visualization
        focal_row_site = df_sites[df_sites["cell_id"] == cell_id].iloc[0]
        lon_focal, lat_focal = focal_row_site["lon"], focal_row_site["lat"]

        df_focal_data = df_input[df_input["cell_id"] == cell_id]
        if df_focal_data.empty:
            focal_spei_val = np.nan
            focal_fire_val = np.nan
        else:
            # Period mean of the focal cell's rows
            focal_spei_val = df_focal_data["spei"].mean()
            focal_fire_val = df_focal_data["fire"].mean()

        # 6.b) Neighbors and weights
        # NOTE(review): df_input may hold several rows per cell (the focal values
        # are averaged above); if so these "neighbors" are rows, not distinct
        # cells — confirm against the R side.
        coords = df_input[["lon", "lat"]].to_numpy(float)
        tree = BallTree(coords, metric="euclidean")

        # Index of the DATA POINT closest to the focal site
        focal_coord_array = np.array([[lon_focal, lat_focal]])
        _, idx_of_nearest_point = tree.query(focal_coord_array, k=1)
        idx_query_point = idx_of_nearest_point[0][0]

        bw = min(int(bw_neighbors), len(coords))
        dists_deg, idxs = tree.query([coords[idx_query_point]], k=bw)
        idxs = idxs[0]
        # Degrees -> km with a flat 111 km/degree factor (no cos(lat) correction)
        dists_km = dists_deg[0] * 111.0
        d_max_km = float(dists_km[-1])

        def compute_kernel(d, h, kernel_type):
            """Normalized kernel weights (sum ≈ 1) for distances d and bandwidth h."""
            u = d / max(h, 1e-12)
            name = kernel_type.lower()
            if name == "bisquare":
                w = np.where(u < 1.0, (1.0 - u**2) ** 2, 0.0)
            elif name == "gaussian":
                w = np.exp(-0.5 * u**2)
            elif name == "exponential":
                w = np.exp(-u)
            elif name == "tricube":
                w = np.where(u < 1.0, (1.0 - u**3) ** 3, 0.0)
            elif name == "boxcar":
                w = np.where(u < 1.0, 1.0, 0.0)
            else:
                raise ValueError(f"Kernel desconocido: {kernel_type}")
            s = w.sum()
            return w / (s + 1e-12)

        weights = compute_kernel(dists_km, d_max_km, kernel)

        df_neighbors = df_input.iloc[idxs].copy()
        df_neighbors["dist_km"] = dists_km
        df_neighbors["weight"] = weights

        X_neighbors = df_neighbors["spei"].astype(float).values
        Y_neighbors = df_neighbors["fire"].astype(float).values
        mask = np.isfinite(X_neighbors) & np.isfinite(Y_neighbors)
        Xc, Yc, Wc = X_neighbors[mask], Y_neighbors[mask], weights[mask]

        # Rebuild t from rho when R did not report it: t = r*sqrt(n-2)/sqrt(1-r^2)
        if not np.isfinite(t_local) and np.isfinite(rho_local) and neff_ok and neff >= 3:
            num = rho_local * np.sqrt(max(n_eff_pairs - 2, 1))
            den = np.sqrt(max(1.0 - rho_local**2, 1e-12))
            t_local = float(num / den)

        # 7) MULTI-PANEL FIGURE (A–F)
        fig = plt.figure(figsize=(16, 10), dpi=120)
        gs = GridSpec(3, 3, figure=fig, hspace=0.35, wspace=0.3)

        # Panel A: neighbor map
        ax1 = fig.add_subplot(gs[0:2, 0:2])
        ax1.set_aspect('equal')
        sc = ax1.scatter(
            df_neighbors["lon"], df_neighbors["lat"],
            c=df_neighbors["weight"], cmap="viridis",
            s=30 + 300 * df_neighbors["weight"], alpha=0.75,
            edgecolor="k", linewidth=0.4
        )
        ax1.scatter([lon_focal], [lat_focal], c="crimson", s=300, marker="*",
                    edgecolor="white", linewidth=2, zorder=10, label="Celda focal")
        plt.colorbar(sc, ax=ax1, label="Peso (kernel)", shrink=0.8)
        ax1.set_title(f"A. Vecinos adaptativos (bw={bw}, {kernel})", fontsize=12, fontweight="bold")
        ax1.set_xlabel("Longitud"); ax1.set_ylabel("Latitud")
        ax1.legend(loc="upper right", fontsize=9); ax1.grid(ls=":", alpha=0.3)

        # Panel B: weighted scatter
        ax2 = fig.add_subplot(gs[0, 2])
        ax2.scatter(Xc, Yc, c=Wc, cmap="plasma", s=30, alpha=0.7, edgecolor="k", linewidth=0.3)
        ax2.scatter([focal_spei_val], [focal_fire_val], c="red",
                    s=150, marker="*", edgecolor="white", linewidth=1.5, zorder=10, label="Focal (prom.)")
        method = "spearman" if use_spearman else "pearson"
        ax2.set_title(f"B. Correlación ({method})", fontsize=11, fontweight="bold")
        ax2.set_xlabel(f"SPEI-{k} (mensual, lag={lag})"); ax2.set_ylabel(f"Incendios (Cercanía)")
        ax2.text(0.05, 0.95, f"ρ = {rho_local:.3f}", transform=ax2.transAxes, fontsize=13, fontweight="bold",
                 va="top", bbox=dict(boxstyle="round", facecolor="yellow", alpha=0.7))
        ax2.legend(fontsize=8); ax2.grid(ls=":", alpha=0.3)
        # min()/max() raise on empty arrays: only set limits when some pair is finite.
        if Yc.size:
            y_min = max(0.0, Yc.min())
            y_max = min(1.0, Yc.max())
            padding = (y_max - y_min) * 0.1
            ax2.set_ylim(y_min - padding - 0.01, y_max + padding + 0.01)

        # Panel C: weight histogram
        ax3 = fig.add_subplot(gs[1, 2])
        ax3.hist(df_neighbors["weight"].values, bins=30, color="steelblue", alpha=0.7, edgecolor="k")
        if np.isfinite(focal_fire_val):  # only when the focal cell has input rows
            ax3.axvline(df_neighbors["weight"].values[0], color="crimson", ls="--", lw=2, label="Vecino más cercano")
        ax3.set_title("C. Distribución de pesos", fontsize=11, fontweight="bold")
        ax3.set_xlabel("Peso"); ax3.set_ylabel("Frecuencia")
        ax3.legend(fontsize=9); ax3.grid(axis="y", ls=":", alpha=0.3)

        # Panel D: distance vs weight
        ax4 = fig.add_subplot(gs[2, 0])
        ax4.scatter(dists_km, weights, s=20, alpha=0.6, c="teal", edgecolor="k", linewidth=0.3)
        ax4.axhline(weights[0], color="crimson", ls="--", lw=1.5, alpha=0.7)
        ax4.axvline(d_max_km, color="orange", ls="--", lw=1.5, alpha=0.7, label=f"Bandwidth={d_max_km:.1f} km")
        ax4.set_title(f"D. Kernel: {kernel}", fontsize=11, fontweight="bold")
        ax4.set_xlabel("Distancia (km)"); ax4.set_ylabel("Peso")
        ax4.legend(fontsize=8); ax4.grid(ls=":", alpha=0.3)

        # Panel E: Student-t test (skipped when df is unknown)
        ax5 = fig.add_subplot(gs[2, 1])
        if df_local > 0 and np.isfinite(df_local):
            x_t = np.linspace(-6, 6, 200)
            y_t = tdist.pdf(x_t, df_local)
            ax5.plot(x_t, y_t, 'k-', lw=2, label=f"t({df_local} df)")
            if np.isfinite(t_local):
                ax5.axvline(t_local, color="red", lw=2.5, label=f"t obs = {t_local:.2f}")
            t_crit = float(tdist.ppf(0.975, df_local))
            ax5.axvline(t_crit, color="orange", ls="--", lw=1.5, alpha=0.7)
            ax5.axvline(-t_crit, color="orange", ls="--", lw=1.5, alpha=0.7, label=f"t crit ±{t_crit:.2f}")
            ax5.fill_between(x_t[x_t > t_crit], 0, y_t[x_t > t_crit], color="red", alpha=0.2)
            ax5.fill_between(x_t[x_t < -t_crit], 0, y_t[x_t < -t_crit], color="red", alpha=0.2)
        ax5.set_title("E. Test t-Student", fontsize=11, fontweight="bold")
        ax5.set_xlabel("Estadístico t"); ax5.set_ylabel("Densidad")
        ax5.legend(fontsize=7, loc="upper right"); ax5.grid(ls=":", alpha=0.3)

        # Panel F: summary
        ax6 = fig.add_subplot(gs[2, 2]); ax6.axis("off")
        summary_text = f"""
SIGNIFICANCIA

{method.upper()}: ρ={rho_local:.3f}
n_eff={n_eff_pairs:,} celdas

t={t_local:.3f}, df={df_local}
p={p_local:.2e}
q(FDR)={q_local:.2e}

FDR≤0.05: {'✓ SÍ' if signif_fdr05 else '✗ NO'}
"""
        ax6.text(0.5, 0.5, summary_text.strip(), transform=ax6.transAxes,
                 ha="center", va="center", fontsize=9, family="monospace",
                 bbox=dict(boxstyle="round,pad=1",
                           facecolor=("lightgreen" if signif_fdr05 else "lightyellow"),
                           edgecolor="black", linewidth=2, alpha=0.9))
        ax6.set_title("F. Resumen", fontsize=11, fontweight="bold")

        fig.suptitle(f"GWSS Celda {cell_id} · TS-{k} lag={lag} · ({lon_focal:.4f}°, {lat_focal:.4f}°)",
                     fontsize=14, fontweight="bold", y=0.98)

        csv_out_path = f"/content/GWSS_neighbors_{cell_id}_TS{k}_lag{lag}_bw{bw}_{kernel}.csv"
        df_neighbors[["cell_id", "lon", "lat", "dist_km", "weight", "spei", "fire"]].to_csv(csv_out_path, index=False)

        return {
            "mensaje": "Reporte GWSS completo generado",
            "cell_id": cell_id,
            "paso1_datos_entrada": {
                "lon": lon_focal, "lat": lat_focal,
                "spei_value": float(focal_spei_val),
                "fire_value": float(focal_fire_val),
            },
            "paso2_vecinos": {"bw_usado": bw, "distancia_max_km": d_max_km},
            "paso4_correlacion": {
                "metodo": method, "rho_local": float(rho_local),
                "n_eff_pairs": n_eff_pairs,
                "t_stat": float(t_local), "df": int(df_local), "p_value": float(p_local),
            },
            "paso5_significancia": {
                "q_value": float(q_local), "significativo_fdr05": bool(signif_fdr05),
                "interpretacion": "Correlación significativa" if signif_fdr05 else "No significativa",
            },
            "figura": fig,
            "csv_vecinos": csv_out_path,
        }

    except Exception as e:
        import traceback
        return {"error": f"Error: {str(e)}\n\n{traceback.format_exc()}"}

# ========= Wrapper para UI Gradio =========
def explain_gwss_ui(k, lag, firms_csv, years, bw, kernel,
                    spearman, enso_phase, enso_csv,
                    enso_thr, enso_lag, cell_id_input):
    """Gradio wrapper around explain_gwss_single_cell.

    Converts the UI values to the expected types. An empty or missing
    `cell_id_input` means "random". The ENSO CSV is forwarded only when a phase
    other than "all" is selected.

    Returns:
        (figure, text). On success: the 6-panel figure and a formatted report.
        On failure: a red error figure and a "❌ ..." message.
    """
    def _error_figure(message, fontsize):
        """Text-only figure that shows `message` in red monospace."""
        fig_err, ax = plt.subplots(figsize=(10, 8)); ax.axis("off")
        ax.text(0.05, 0.95, message, ha="left", va="top", color="red",
                fontsize=fontsize, family="monospace",
                bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.8))
        return fig_err

    try:
        # Gradio may hand over None for untouched textboxes; treat it as empty.
        cell_text = str(cell_id_input or "").strip()
        years_text = years or ""
        res = explain_gwss_single_cell(
            int(k), int(lag), firms_csv, years,
            int(bw), kernel, bool(spearman),
            enso_phase, enso_csv if enso_phase != "all" else None,
            float(enso_thr), int(enso_lag),
            cell_id=cell_text if cell_text else "random"
        )

        if "error" in res:
            return _error_figure(f"❌ Error:\n\n{res['error']}", 8), f"❌ {res['error']}"

        texto = f"""
GWSS — Análisis de celda única (lógica de distancia)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
CONFIGURACIÓN
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Celda: {res['cell_id']} ({res['paso1_datos_entrada']['lon']:.4f}°, {res['paso1_datos_entrada']['lat']:.4f}°)
Periodo: {years if years_text.strip() else "Todos los años"}
SPEI: k={k}, lag={lag} (mensual)
Incendios: Métrica de Cercanía (1 / (dist_km + 1))
Kernel: {kernel}, BW={res['paso2_vecinos']['bw_usado']} vecinos ({res['paso2_vecinos']['distancia_max_km']:.1f} km)
ENSO: {enso_phase if enso_phase != 'all' else 'sin filtro'}
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
DATOS DE ENTRADA (promedio de la celda focal en el período)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
SPEI (promedio): {res['paso1_datos_entrada']['spei_value']:.3f}
Incendios (Cercanía, promedio): {res['paso1_datos_entrada']['fire_value']:.3f}
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
RESULTADOS
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Método: {res['paso4_correlacion']['metodo'].upper()}
ρ (correlación local): {res['paso4_correlacion']['rho_local']:.4f}
n_eff (celdas vecinas): {res['paso4_correlacion']['n_eff_pairs']:,}
t: {res['paso4_correlacion']['t_stat']:.3f}, df: {res['paso4_correlacion']['df']}
p-value: {res['paso4_correlacion']['p_value']:.2e}
q-value (FDR): {res['paso5_significancia']['q_value']:.3e}
Significativo FDR≤0.05: {'✓ SÍ' if res['paso5_significancia']['significativo_fdr05'] else '✗ NO'}
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
La figura muestra 6 paneles explicativos. CSV de vecinos en: {res['csv_vecinos']}
"""
        return res["figura"], texto.strip()

    except Exception as e:
        import traceback
        return _error_figure(f"❌ Error en UI:\n\n{traceback.format_exc()}", 7), f"❌ {str(e)}"

# =================== 14) UI GRADIO ===========================
with gr.Blocks(title="Suite CHIRPS + MODIS — SPI + SPEI + Análisis") as demo:
    gr.Markdown("""
### **Hidro-Suite (CHIRPS + MODIS)** — SPI + SPEI + ENSO + FIRMS + GWSS


**Requisitos**: EE, Google Drive, R (paquete **SPEI**) y **GWmodel** (R) para GWSS.
**Nota**: Todo queda alineado a la grilla **CHIRPS** (0.05°). Archivos largos `SPI_k_month.csv` y `SPEI_k_month.csv` se guardan en sus carpetas de Drive.
""")

    with gr.Tabs():
        # -------------------- TAB 1: SPI + NIFT --------------------
        with gr.TabItem("SPI + NIFT (CHIRPS)"):
            with gr.Row():
                sd_spi = gr.Textbox("1985-01-01", label="Inicio (AAAA-MM-DD)")
                ed_spi = gr.Textbox("2024-12-31", label="Fin (AAAA-MM-DD)")
            btn_e1 = gr.Button("Paso 1 → Exportar PR (CHIRPS)")
            out_e1 = gr.Textbox(label="Estado exportación PR")

            k_spi  = gr.Textbox("1,3,6,12", label="k (meses)")
            btn_e2 = gr.Button("Paso 2 → Calcular SPI (R)")
            out_e2 = gr.Textbox(label="Estado SPI")

            with gr.Row():
                k_plot  = gr.Dropdown(["1","3","6","12"], value="3", label="k")
                date_sp = gr.Textbox("2021-01", label="Mes (AAAA-MM)")
            btn_e3 = gr.Button("Paso 3 → Rasterizar + Plot SPI")
            fig_spi = gr.Plot(label="Mapa SPI")
            msg_sp  = gr.Textbox(label="Estado SPI map")

            gr.Markdown("### Índice **NIFT** y parámetros P1–P10")

            k_nift = gr.Textbox("1,3,6,12", label="k para NIFT")

            with gr.Accordion("Ponderaciones NIFT (por defecto: Brasil Neto)", open=False):
                preset_w = gr.Dropdown(
                    ["Brasil Neto (2024)", "Uniforme (10% c/u)", "Personalizado"],
                    value="Brasil Neto (2024)", label="Preset"
                )
                with gr.Row():
                    wP1 = gr.Slider(0.0, 1.0, value=0.125, step=0.001, label="P1 (N)")
                    wP2 = gr.Slider(0.0, 1.0, value=0.125, step=0.001, label="P2 (MDI)")
                    wP3 = gr.Slider(0.0, 1.0, value=0.009, step=0.001, label="P3 (%Mild)")
                    wP4 = gr.Slider(0.0, 1.0, value=0.034, step=0.001, label="P4 (%Mod)")
                    wP5 = gr.Slider(0.0, 1.0, value=0.071, step=0.001, label="P5 (%Sev)")
                with gr.Row():
                    wP6 = gr.Slider(0.0, 1.0, value=0.136, step=0.001, label="P6 (%Ext)")
                    wP7 = gr.Slider(0.0, 1.0, value=0.100, step=0.001, label="P7 (BT) NEG")
                    wP8 = gr.Slider(0.0, 1.0, value=0.075, step=0.001, label="P8 (DT)")
                    wP9 = gr.Slider(0.0, 1.0, value=0.075, step=0.001, label="P9 (ST)")
                    wP10= gr.Slider(0.0, 1.0, value=0.250, step=0.001, label="P10 (Prec) NEG")

                renorm = gr.Checkbox(True, label="Renormalizar a suma=1.0")
                sum_box = gr.Markdown("**Suma actual:** 1.000")

            def _preset_vals(name: str):
                if name.startswith("Brasil"): w = DEFAULT_WEIGHTS_BRASIL_NETO
                elif name.startswith("Uniforme"): w = DEFAULT_WEIGHTS_UNIFORME
                else: return None
                return [w[f"P{i}"] for i in range(1, 11)]
            def _weights_from_ui(preset, p1,p2,p3,p4,p5,p6,p7,p8,p9,p10):
                if preset.startswith("Brasil"): return DEFAULT_WEIGHTS_BRASIL_NETO
                if preset.startswith("Uniforme"): return DEFAULT_WEIGHTS_UNIFORME
                return {f"P{i}": v for i, v in enumerate([p1,p2,p3,p4,p5,p6,p7,p8,p9,p10], start=1)}
            def _sum_text(p1,p2,p3,p4,p5,p6,p7,p8,p9,p10):
                s = float(p1+p2+p3+p4+p5+p6+p7+p8+p9+p10); return f"**Suma actual:** {s:.3f}"
            def _apply_preset(preset):
                vals = _preset_vals(preset);
                if vals is None: return (gr.update(),)*10
                return tuple(gr.update(value=v) for v in vals)
            preset_w.change(_apply_preset, [preset_w], [wP1,wP2,wP3,wP4,wP5,wP6,wP7,wP8,wP9,wP10])
            for s in (wP1,wP2,wP3,wP4,wP5,wP6,wP7,wP8,wP9,wP10):
                s.change(_sum_text, [wP1,wP2,wP3,wP4,wP5,wP6,wP7,wP8,wP9,wP10], [sum_box])
            def _do_nift_with_weights(ks, preset, p1,p2,p3,p4,p5,p6,p7,p8,p9,p10, renormalize):
                w = _weights_from_ui(preset, p1,p2,p3,p4,p5,p6,p7,p8,p9,p10)
                return compute_nift(ks, run_min=3, weights=w, renormalize=bool(renormalize))
            btn_e4 = gr.Button("Paso 4 → Calcular NIFT")
            out_e4 = gr.Textbox(label="Estado NIFT")
            def _exp_pr(a, b):
                """Step-1 handler (SPI tab): delegate to ``export_chirps_pr``.

                ``a``/``b`` are the raw textbox values wired from ``sd_spi``/``ed_spi``
                (presumably start/end date strings — confirm against export_chirps_pr).
                """
                start_value, end_value = a, b
                return export_chirps_pr(start_value, end_value)
            def _do_spi(ks):
                """Step-2 handler (SPI tab): compute SPI for the k list typed in ``k_spi``."""
                result = compute_spi_from_csv(ks)
                return result
            def _plot_spi(k, d):
                """Step-3 handler (SPI tab): render the SPI-k map for month ``d``.

                Reads from and writes to ``DRIVE_DIR_SPI``. The result is wired to
                ``[fig_spi, msg_spi]``, presumably a (figure, status message) pair.
                """
                title = f"SPI-{k} · {d}"
                return plot_index_map(DRIVE_DIR_SPI, "SPI", k, d, "spi", title, DRIVE_DIR_SPI)
            btn_e1.click(_exp_pr,  [sd_spi, ed_spi], out_e1)
            btn_e2.click(_do_spi,  [k_spi], out_e2)
            btn_e3.click(_plot_spi, [k_plot, date_sp], [fig_spi, msg_sp])
            btn_e4.click(_do_nift_with_weights, inputs=[k_nift, preset_w, wP1,wP2,wP3,wP4,wP5,wP6,wP7,wP8,wP9,wP10, renorm], outputs=[out_e4])
            gr.Markdown("### Mapas NIFT y parámetros P1–P10")
            with gr.Row():
                k_plot_all_spi = gr.Textbox("1,3,6,12", label="k para mapas (coma)")
                norm_chk_spi   = gr.Checkbox(False, label="Normalizar P1–P10")
            plot_btn_spi  = gr.Button("Paso 5A → Mapas NIFT + P1–P10")
            param_plot_spi = gr.Plot(label="Parámetros P1–P10 (completo)")
            nift_plot_spi  = gr.Plot(label="Mapas NIFT")
            plot_msg_spi   = gr.Textbox(label="Estado mapas (SPI/NIFT)")
            gr.Markdown("### Figuras compactas (4 imágenes: conjuntos G1–G4)")
            plot_grp_btn_spi = gr.Button("Paso 5B → Figuras compactas (G1–G4)")
            with gr.Row():
                g1_plot_spi = gr.Plot(label="Grupo 1"); g2_plot_spi = gr.Plot(label="Grupo 2")
            with gr.Row():
                g3_plot_spi = gr.Plot(label="Grupo 3"); g4_plot_spi = gr.Plot(label="Grupo 4")
            plot_grp_msg_spi = gr.Textbox(label="Estado grupos (SPI/NIFT)")
            plot_btn_spi.click(_spi_plot_maps, inputs=[k_plot_all_spi, norm_chk_spi], outputs=[param_plot_spi, nift_plot_spi, plot_msg_spi])
            plot_grp_btn_spi.click(_spi_plot_maps_grouped, inputs=[k_plot_all_spi, norm_chk_spi], outputs=[g1_plot_spi, g2_plot_spi, g3_plot_spi, g4_plot_spi, plot_grp_msg_spi])
            gr.Markdown("### Exportar NIFT a GeoTIFF (uno por k)")
            btn_nift_tif  = gr.Button("Paso 5C → Exportar NIFT a GeoTIFFs")
            out_nift_tif  = gr.Textbox(label="Estado export GeoTIFF (NIFT)")
            btn_nift_tif.click(lambda ks: export_nift_geotiffs(ks, DRIVE_DIR_SPI), inputs=[k_plot_all_spi], outputs=[out_nift_tif])

        # -------------------- TAB 2: SPEI + Análisis --------------
        with gr.TabItem("SPEI (CHIRPS − MODIS) + Análisis"):
            with gr.Row():
                sd_sp = gr.Textbox("2001-01-01", label="Inicio (AAAA-MM-DD)")
                ed_sp = gr.Textbox("2024-12-31", label="Fin (AAAA-MM-DD)")
            btn_s1 = gr.Button("Paso 1 → Exportar PR − PET (mensual)")
            out_s1 = gr.Textbox(label="Estado exportación balance")

            k_spei = gr.Textbox("1,3,6,12", label="k (meses)")
            btn_s2 = gr.Button("Paso 2 → Calcular SPEI (R)")
            out_s2 = gr.Textbox(label="Estado SPEI")

            with gr.Row():
                k_plot2  = gr.Dropdown(["1","3","6","12"], value="3", label="k")
                date_sp2 = gr.Textbox("2021-01", label="Mes (AAAA-MM)")
            btn_s3 = gr.Button("Paso 3 → Rasterizar + Plot SPEI")
            fig_spei = gr.Plot(label="Mapa SPEI")
            msg_spei = gr.Textbox(label="Estado SPEI map")

            gr.Markdown("#### Validación SPEI (GEE CSIC vs Local)")
            val_k_list = gr.Textbox("1,3,6,12", label="k (coma)")
            with gr.Row():
                val_sd = gr.Textbox("1985-01-01", label="Inicio gráfico (opcional)")
                val_ed = gr.Textbox("2024-12-31", label="Fin gráfico (opcional)")
            val_btn  = gr.Button("Comparar series")
            val_plot = gr.Plot(label="GEE vs Local")
            val_msg  = gr.Textbox(label="Estado validación")

            gr.Markdown("#### Comparar SPEI (Local) vs ENSO mensual (ONI)")
            with gr.Row():
                enso_csv_in = gr.Textbox(f"{DRIVE_DIR_SPEI}/ENSO.csv", label="Ruta ENSO.csv (ONI 3-meses)")
                enso_lag_dd = gr.Dropdown([str(i) for i in range(-3,4)], value="-1", label="Lag ENSO (meses, −3…+3)")
            with gr.Row():
                enso_sd = gr.Textbox("1985-01-01", label="Inicio gráfico (opcional)")
                enso_ed = gr.Textbox("2024-12-31", label="Fin gráfico (opcional)")
            enso_btn    = gr.Button("Comparar con ENSO mensual")
            enso_plot   = gr.Plot(label="SPEI (Local) vs ENSO (ONI)")
            enso_msg    = gr.Textbox(label="Estado ENSO")

            gr.Markdown("#### FIRMS mensual (VIIRS 375 m)")
            with gr.Row():
                f_start = gr.Textbox("2014-01-01", label="Inicio (AAAA-MM-DD)")
                f_end   = gr.Textbox("2024-12-31", label="Fin (AAAA-MM-DD)")
                firms_btn = gr.Button("Exportar FIRMS → CSV (grilla CHIRPS)")
            firms_out = gr.Textbox(label="Estado FIRMS (export)")

            gr.Markdown("**Visualizar incendios (presencia) desde CSV**")
            with gr.Row():
                fires_csv_map   = gr.Textbox(f"{DRIVE_DIR_SPEI}/FIRMS_MONTH_20140101_20241231.csv", label="Ruta CSV FIRMS_MONTH…")
            with gr.Row():
                fires_start_map = gr.Textbox("2014-01", label="Mes inicio (AAAA-MM)")
                fires_end_map   = gr.Textbox("2024-12", label="Mes fin (AAAA-MM)")
            fires_map_btn  = gr.Button("Mostrar mapa (CSV vs GEE)")
            fires_map_plot = gr.Plot(label="Mapa incendios (comparación)")
            fires_map_msg  = gr.Textbox(label="Estado mapa")

            # ---- GWSS PERÍODO COMPLETO ----
            gr.Markdown("#### Correlación GWSS para Período Completo — Lógica de Distancia")
            gr.Markdown("Genera un mapa de correlación local para todo el período seleccionado (ej. 2014-2024).")
            with gr.Row():
                gw_firms_csv = gr.Textbox(value=f"{DRIVE_DIR_SPEI}/FIRMS_MONTH_20140101_20241231.csv", label="Ruta CSV FIRMS_MONTH…")
            with gr.Row():
                gw_k      = gr.Dropdown(["1","3","6","12"], value="3", label="k SPEI")
                gw_lag    = gr.Dropdown(["0","1"], value="0", label="lag SPEI (meses)")
                gw_years  = gr.Textbox(value="2014-2024", label="Período: AAAA-AAAA | AAAA | AAAA,AAAA (vacío=todo)")
            with gr.Row():
                gw_bw         = gr.Slider(minimum=10, maximum=1500, value=1080, step=10, label="BW adaptativo (nº obs. celda-mes)")
                gw_kernel     = gr.Dropdown(["bisquare","gaussian","exponential","tricube","boxcar"], value="bisquare", label="Kernel")
                gw_spearman   = gr.Checkbox(value=False, label="Usar Spearman (rango)")
                gw_nan_strategy = gr.Dropdown(["mask","impute0"], value="mask", label="NaN locales: 'mask' o 'impute0'")

            gr.Markdown("**Filtro ENSO opcional dentro del periodo:**")
            with gr.Row():
                gw_enso_csv = gr.Textbox(f"{DRIVE_DIR_SPEI}/ENSO.csv", label="Ruta ENSO.csv")
                gw_enso_phase = gr.Dropdown(["all","nino","neutral","nina"], value="all", label="Fase ENSO (all = sin filtro)")
                gw_enso_thr   = gr.Slider(0.3, 1.0, value=0.5, step=0.1, label="Umbral |ONI| (fase ENSO)")
                gw_enso_lag   = gr.Dropdown([str(i) for i in range(-3,4)], value="0", label="Lag ENSO (meses, −3…+3)")

            gw_btn   = gr.Button("Calcular GWSS para el Período Seleccionado")
            gw_plot  = gr.Plot(label="Mapa de correlación GW (ρ local) - Período Completo")
            gw_msg   = gr.Textbox(label="Estado / Resumen breve")
            gw_sig_plot  = gr.Plot(label="Significancia local (−log10 q) - Período Completo")
            gw_sens_plot = gr.Plot(label="Sensibilidad local (pendiente) - Período Completo")
            gw_hist_plot = gr.Plot(label="Distribución de ρ locales (hist + ECDF) - Período Completo")
            gw_stats_box = gr.Textbox(label="Resumen estadístico (mapa GW)", lines=4)

            # --- Wiring SPEI tab ---
            def _exp_bal(a, b):
                """Step-1 handler (SPEI tab): delegate to ``export_pr_minus_pet``.

                ``a``/``b`` come from the "Inicio"/"Fin" textboxes (``sd_sp``/``ed_sp``),
                formatted AAAA-MM-DD according to their labels.
                """
                start_value, end_value = a, b
                return export_pr_minus_pet(start_value, end_value)
            def _do_spei(ks):
                """Step-2 handler (SPEI tab): compute SPEI for the comma-separated k list in ``k_spei``."""
                result = compute_spei_from_csv(ks)
                return result
            def _plot_spei(k, d):
                """Step-3 handler (SPEI tab): render the SPEI-k map for month ``d``.

                ``k`` is the dropdown string ("1","3","6","12") and ``d`` the
                AAAA-MM textbox value. Reads from and writes to ``DRIVE_DIR_SPEI``.
                """
                title = f"SPEI-{k} · {d}"
                return plot_index_map(DRIVE_DIR_SPEI, "SPEI", k, d, "spei", title, DRIVE_DIR_SPEI)
            btn_s1.click(_exp_bal,  [sd_sp, ed_sp], out_s1)
            btn_s2.click(_do_spei,  [k_spei], out_s2)
            btn_s3.click(_plot_spei, [k_plot2, date_sp2], [fig_spei, msg_spei])
            val_btn.click(lambda ks, sd, ed: _plot_spei_vs_spei_clc(ks, sd, ed), inputs=[val_k_list, val_sd, val_ed], outputs=[val_plot, val_msg])
            enso_btn.click(lambda ks, sd, ed, p, lag: _plot_spei_vs_enso(ks, sd, ed, p, int(lag)), inputs=[val_k_list, enso_sd, enso_ed, enso_csv_in, enso_lag_dd], outputs=[enso_plot, enso_msg])
            firms_btn.click(_export_firms_monthly, [f_start, f_end], firms_out)
            fires_map_btn.click(_plot_fires, inputs=[fires_csv_map, fires_start_map, fires_end_map], outputs=[fires_map_plot, fires_map_msg])

            # Full-period GWSS run. The lambda receives the raw widget values in the
            # order listed in `inputs` (Dropdowns yield strings, Sliders numbers, the
            # Checkbox a bool) and coerces them with int/float/bool. It then calls
            # gw_correlation_unified_plus positionally. Note that k and lag are
            # passed first, before the FIRMS CSV path. Six outputs are wired back.
            gw_btn.click(
                lambda fcsv, k, lag, yrs, bw, ker, sp, ns, epath, phase, thr, elag:
                    gw_correlation_unified_plus( # Calls the same function as before (entry point unchanged)
                        int(k), int(lag), fcsv, yrs, int(bw), ker, bool(sp),
                        ns, epath, phase, float(thr), int(elag)
                    ),
                inputs=[gw_firms_csv, gw_k, gw_lag, gw_years,
                        gw_bw, gw_kernel, gw_spearman, gw_nan_strategy,
                        gw_enso_csv, gw_enso_phase, gw_enso_thr, gw_enso_lag],
                outputs=[gw_plot, gw_sig_plot, gw_sens_plot, gw_hist_plot, gw_stats_box, gw_msg]
            )

            # ---- Panel anual ----
            gr.Markdown("#### Panel anual GWSS (Año por Año) — 4 figuras compactas")
            with gr.Row():
                yrs_start = gr.Number(value=2014, label="Año inicio")
                yrs_end   = gr.Number(value=2024, label="Año fin")
                ncols_v   = gr.Slider(minimum=1, maximum=4, value=4, step=1, label="Columnas del panel")
            bw_panel_anual = gr.Slider(minimum=10, maximum=500, value=120, step=5, label="BW Anual (nº obs. celda-mes)")
            btn_yearpan = gr.Button("Generar Panel Anual (Año por Año)")
            yr_corr = gr.Plot(label="Panel: Correlación local por año")
            yr_sig  = gr.Plot(label="Panel: Significancia local por año")
            yr_slo  = gr.Plot(label="Panel: Pendiente local por año")
            yr_sum  = gr.Plot(label="Panel: Resumen (ECDF + Histo)")
            yr_msg  = gr.Textbox(label="Estado GWSS anual")
            def _run_yearpanels_ui(
                firms_csv: str, k: "int | str", lag: "int | str",
                years_start: "int | float", years_end: "int | float", ncols: "int | float",
                bw_neighbors: "int | float", kernel: str, use_spearman: bool,
                nan_strategy: str,
                enso_csv: str, enso_phase: str, enso_thr: float, enso_lag: "int | str",
            ):
                """Run the year-by-year GWSS panels from the raw UI inputs.

                Gradio passes Dropdown values as strings (k, lag, enso_lag) and
                Number/Slider values as numbers (years, ncols, bandwidth, threshold),
                so every numeric input is coerced with ``int``/``float`` here.

                A reversed year range (start > end) is swapped rather than passed on.
                An empty or whitespace-only ENSO path falls back to
                ``{DRIVE_DIR_SPEI}/ENSO.csv``.

                Returns whatever ``gw_correlation_yearly_panels`` returns. The click
                handler maps it to five outputs (corr, sig, slope, summary, message).
                """
                first_year, last_year = int(years_start), int(years_end)
                if first_year > last_year:
                    # Tolerate a reversed range instead of passing it downstream inverted.
                    first_year, last_year = last_year, first_year
                enso_path = (enso_csv or "").strip() or f"{DRIVE_DIR_SPEI}/ENSO.csv"
                return gw_correlation_yearly_panels(
                    k=int(k), lag=int(lag), firms_csv_path=firms_csv,
                    start_year=first_year, end_year=last_year,
                    bw_neighbors=int(bw_neighbors), kernel=kernel, use_spearman=bool(use_spearman),
                    nan_strategy=nan_strategy,
                    enso_csv_path=enso_path,
                    enso_phase=enso_phase, enso_thr=float(enso_thr), enso_lag=int(enso_lag),
                    ncols=int(ncols)
                )
            btn_yearpan.click(
                _run_yearpanels_ui,
                inputs=[
                    gw_firms_csv, gw_k, gw_lag,
                    yrs_start, yrs_end, ncols_v,
                    bw_panel_anual,
                    gw_kernel, gw_spearman,
                    gw_nan_strategy,
                    gw_enso_csv, gw_enso_phase, gw_enso_thr, gw_enso_lag,
                ],
                outputs=[yr_corr, yr_sig, yr_slo, yr_sum, yr_msg]
            )

        # -------------------- TAB 3: GWSS Ejemplo Celda Única ----
        with gr.TabItem("GWSS — Celda ejemplo"):
            gr.Markdown("""
#### Ejemplo paso a paso GWSS para UNA celda (Lógica de Distancia)
Genera un reporte detallado (ρ, vecinos, pesos, t-value, p-value, q-value) para una celda aleatoria o indicada.
""")
            with gr.Row():
                k_ex = gr.Dropdown(["1","3","6","12"], value="3", label="k (meses)")
                lag_ex = gr.Slider(-6, 6, value=1, step=1, label="lag SPEI")
                years_ex = gr.Textbox("2015-2023", label="Años (AAAA-AAAA)")
            with gr.Row():
                bw_ex = gr.Slider(10, 1500, value=400, step=10, label="BW (vecinos celda-mes)")
                kernel_ex = gr.Dropdown(["bisquare","gaussian","exponential","tricube","boxcar"], value="exponential", label="Kernel")
                spearman_ex = gr.Checkbox(True, label="Spearman")
            gr.Markdown("**Filtros ENSO:**")
            with gr.Row():
                enso_phase_ex = gr.Dropdown(["all","nino","neutral","nina"], value="all", label="Fase ENSO")
                enso_csv_ex = gr.Textbox(f"{DRIVE_DIR_SPEI}/ENSO.csv", label="ENSO.csv")
                enso_thr_ex = gr.Slider(0.3, 1.0, value=0.5, step=0.1, label="Umbral |ONI|")
                enso_lag_ex = gr.Dropdown([str(i) for i in range(-3,4)], value="0", label="Lag ENSO")
            firms_csv_ex = gr.Textbox(f"{DRIVE_DIR_SPEI}/FIRMS_MONTH_20140101_20241231.csv", label="FIRMS mensual (CSV)")
            cell_id_ex = gr.Textbox("", label="cell_id (vacío=aleatorio)")
            btn_explain = gr.Button("Ejecutar análisis (celda única)")
            fig_explain = gr.Plot(label="Análisis GWSS completo (6 paneles)")
            txt_explain = gr.Textbox(label="Reporte detallado", lines=15)
            gr.Markdown("""
**Interpretación:**
- **ρ (rho)**: coeficiente de correlación local (Pearson o Spearman)
- **t**: estadístico t de Student para probar H₀: ρ=0
- **df**: grados de libertad efectivos (n_eff - 2)
- **p**: p-value bilateral
- **q**: q-value ajustado por FDR (Benjamini-Hochberg)
- **FDR≤0.05**: si es significativa tras corrección por comparaciones múltiples
""")
            btn_explain.click(
                explain_gwss_ui,
                inputs=[k_ex, lag_ex, firms_csv_ex, years_ex,
                        bw_ex, kernel_ex, spearman_ex,
                        enso_phase_ex, enso_csv_ex, enso_thr_ex, enso_lag_ex,
                        cell_id_ex],
                outputs=[fig_explain, txt_explain]
            )
    gr.Markdown(
        "> **Tips**\n"
        "> • MODIS **PET** proviene de *MOD16A2* (8-días); se agrega a **mm/mes** aplicando escala 0.1.\n"
        "> • Alineación 1:1 a CHIRPS: sin reproyección en rasterización local.\n"
        "> • `SPI_k_month.csv` y `SPEI_k_month.csv` quedan en tus carpetas de Drive.\n"
        "> • La pestaña **GWSS — Celda ejemplo** genera reportes detallados para una única celda."
    )

demo.launch(share=True, debug=True)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://cb2afe488ded0dda60.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://cb2afe488ded0dda60.gradio.live


