# In [None]:
"""
Flooding methodology
---------------------------------------------------------
This script implements a continent-by-continent flooding assessment consistent with our
stated method.

Upstream (not in this script):
- GRDC daily discharge screening (1920–2024) with ≥9 continuous years and <5% missing.
- GEV fitting on annual maxima to get historical 10‑year return levels (Q10) + 95% CIs
  per station. The resulting table is passed in via --historical-q10.

In this script:
1) Match each GRDC station (geo_x, geo_y) to the **nearest climate grid cell** using
   a Haversine match, for both historical and future monthly files.
2) Convert units (K→°C; kg m⁻² s⁻¹→mm/month), then compute **snowpack** and **PET**
   exactly as defined:
   - Snowpack proxy = monthly precip when **tasmin < 0 °C**, else 0 (binary by month).
   - PET = Thornthwaite monthly PET using **mean monthly T**, **gauge latitude**,
     **daylength by month**, and **days in month**:
       PET_m = 16 * (10*T/I)^a * (L/12) * (N/30), T>0, I>0.
3) Aggregate monthly variables to **period means per station**.
4) **Train** (mode=train) regressors per continent on **log(Q10)** with optional
   interactions; pick best by test RMSE (RandomForest / GradientBoosting / optional XGBoost).
   Save model + feature list per continent.
5) **Predict** (mode=predict) future Q10 (same features), and compute **CF = Q10_future/Q10_hist**.

Inputs (CSV schemas expected; column names configurable via flags):
- --historical-q10 : station metadata + historical Q10
    required cols: Station_ID, continent, basin_area_km2, Q10_hist, geo_x, geo_y
- --historical-climate / --future-climate : monthly gridded climate
    required cols: lat, lon, month(1..12), pr, tas, tasmax, tasmin

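Outputs (under --output-dir):
- models/<Continent>.pkl                         (best model)
- models/features_<Continent>.json               (feature list used)
- summary_best_by_continent.csv                  (RMSE/R²/AdjR² of the best model per continent)
- summary_all_models_by_continent.csv            (RMSE/R²/AdjR² of every candidate model)
- <Continent>_future_q10_predictions.csv         (e.g. Africa_future_q10_predictions.csv;
                                                  Station_ID, Q10_hist, Q10_future_pred, CF)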

"""

from __future__ import annotations
import argparse
import json
import math
import os
import pickle
from typing import Dict, List, Optional, Sequence, Tuple

import numpy as np
import pandas as pd
from tqdm import tqdm

from sklearn.model_selection import train_test_split, RandomizedSearchCV, KFold
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.metrics import mean_squared_error, r2_score

try:
    from xgboost import XGBRegressor
    _HAVE_XGB = True
except Exception:
    _HAVE_XGB = False

# ---------------------------
# Constants & small helpers
# ---------------------------
# Calendar days per month, keyed by month number 1..12. February is fixed at
# 28 days, so leap years are not distinguished anywhere this table is used.
_DAYS_IN_MONTH = {1:31, 2:28, 3:31, 4:30, 5:31, 6:30, 7:31, 8:31, 9:30, 10:31, 11:30, 12:31}


def _ensure_columns(df: pd.DataFrame, cols: Sequence[str], label: str) -> None:
    """Raise ``ValueError`` unless every name in *cols* is a column of *df*.

    *label* identifies the table in the error message. The missing names are
    reported in the order they appear in *cols*.
    """
    present = df.columns
    missing = []
    for col in cols:
        if col not in present:
            missing.append(col)
    if not missing:
        return
    raise ValueError(f"{label} missing columns: {missing}")


def haversine(lat1, lon1, lat2, lon2):
    """Return the great-circle distance in kilometres between two points.

    All four arguments are in decimal degrees and may be scalars or NumPy
    arrays; the usual NumPy broadcasting rules apply, so a single point can be
    compared against an array of points in one call.

    The result is computed with the haversine formula on a sphere of radius
    6371 km.
    """
    R = 6371.0  # sphere radius in km
    lat1_rad, lon1_rad = np.radians(lat1), np.radians(lon1)
    lat2_rad, lon2_rad = np.radians(lat2), np.radians(lon2)
    dlat = lat2_rad - lat1_rad
    dlon = lon2_rad - lon1_rad
    a = np.sin(dlat/2)**2 + np.cos(lat1_rad)*np.cos(lat2_rad)*np.sin(dlon/2)**2
    # Floating-point rounding can leave ``a`` marginally outside [0, 1] for
    # (near-)antipodal points; sqrt(1 - a) would then be NaN, and a NaN
    # distance is picked by np.argmin in a nearest-point search.
    a = np.clip(a, 0.0, 1.0)
    c = 2*np.arctan2(np.sqrt(a), np.sqrt(1-a))
    return R*c


def match_stations_to_grid(stations: pd.DataFrame, climate_grid: pd.DataFrame,
                           id_col: str, xcol: str, ycol: str,
                           lat_col: str = "lat", lon_col: str = "lon") -> pd.DataFrame:
    """Pair every station with its closest climate grid point.

    Station longitude is read from *xcol* and latitude from *ycol*; grid
    coordinates are read from *lat_col* / *lon_col*. Distances come from
    ``haversine`` and the closest point is the ``np.argmin`` of those
    distances. A tqdm progress bar is shown while stations are processed.

    Returns a frame with columns ``[id_col, lat_col, lon_col]``, one row per
    row of *stations*, holding the coordinates of the selected grid point.

    Raises ``ValueError`` (via ``_ensure_columns``) when a required column is
    absent from either input.
    """
    _ensure_columns(stations, [id_col, xcol, ycol], "stations")
    _ensure_columns(climate_grid, [lat_col, lon_col], "climate_grid")

    # Each distinct grid point only needs to be considered once.
    cells = climate_grid[[lat_col, lon_col]].drop_duplicates().reset_index(drop=True)
    cell_lats = cells[lat_col].values
    cell_lons = cells[lon_col].values

    nearest = []
    station_records = stations[[id_col, xcol, ycol]].itertuples(index=False)
    for record in tqdm(station_records, total=len(stations)):
        station_id, station_lon, station_lat = record
        distances = haversine(station_lat, station_lon, cell_lats, cell_lons)
        best = int(np.argmin(distances))
        nearest.append((station_id, cells.at[best, lat_col], cells.at[best, lon_col]))

    return pd.DataFrame(nearest, columns=[id_col, lat_col, lon_col])


# ---------------------------
# Unit conversions
# ---------------------------

def to_celsius_inplace(df: pd.DataFrame, cols: Sequence[str]) -> None:
    """Convert the listed columns of *df* from kelvin to °C, modifying *df*.

    Each column is cast to float and shifted by -273.15. Nothing is returned.
    """
    kelvin_offset = 273.15
    for name in cols:
        as_float = df[name].astype(float)
        df[name] = as_float - kelvin_offset


def pr_to_mm_month_inplace(df: pd.DataFrame, pr_col: str, month_col: str) -> None:
    """Add a ``pr_mm_month`` column to *df* computed from a per-second rate.

    The value in *pr_col* is multiplied by 86400 seconds per day and by the
    number of days in the row's month (from ``_DAYS_IN_MONTH``). *month_col*
    is cast to int in place as a side effect, before validation.

    Raises ``ValueError`` if any month falls outside 1..12.
    """
    months = df[month_col].astype(int)
    df[month_col] = months

    valid = months.between(1, 12)
    if not valid.all():
        bad = months[~valid].unique().tolist()
        raise ValueError(f"Month must be 1..12, found: {bad}")

    days_per_month = months.map(_DAYS_IN_MONTH).astype(float)
    rate = df[pr_col].astype(float)
    df["pr_mm_month"] = rate * 86400.0 * days_per_month


# ---------------------------
# Thornthwaite PET
# ---------------------------

def _monthly_day_of_year_midpoint(month: int) -> int:
    """Return the day-of-year at the middle of *month* (1..12), rounded to an int.

    Uses the month lengths in ``_DAYS_IN_MONTH``.
    """
    days_before = sum(_DAYS_IN_MONTH[m] for m in range(1, month))
    midpoint = days_before + _DAYS_IN_MONTH[month] / 2.0
    return int(round(midpoint))


def _daylength_hours(latitude_deg: float, month: int) -> float:
    """Return the daylength in hours at *latitude_deg* for the middle of *month*.

    The solar declination is approximated as
    ``23.45° * sin(360/365 * (284 + n))`` with ``n`` the mid-month day of
    year, and the daylength is ``24/pi * arccos(-tan(lat) * tan(decl))``.
    The arccos argument is clipped to [-1, 1], which yields 0 h or 24 h where
    the unclipped value would be out of range.

    Raises ``ValueError`` when the latitude is not within [-90, 90].
    """
    lat = float(latitude_deg)
    in_range = -90.0 <= lat <= 90.0
    if not in_range:
        raise ValueError(f"Latitude out of bounds [-90,90]: {lat}")

    day_of_year = _monthly_day_of_year_midpoint(month)
    year_angle = np.deg2rad((360.0/365.0)*(284 + day_of_year))
    declination = np.deg2rad(23.45 * np.sin(year_angle))
    latitude = np.deg2rad(lat)

    cos_hour_angle = np.clip(-np.tan(latitude) * np.tan(declination), -1.0, 1.0)
    hour_angle = np.arccos(cos_hour_angle)
    return float((24.0/np.pi) * hour_angle)


def _thornthwaite_I(monthly_tas_C: pd.Series) -> float:
    """Return the heat index ``sum((T/5)**1.514)`` over the entries with T > 0.

    Entries that are not strictly positive are ignored; if none remain the
    result is 0.0.
    """
    is_warm = monthly_tas_C > 0.0
    if not is_warm.any():
        return 0.0
    warm = monthly_tas_C[is_warm]
    contributions = (warm / 5.0) ** 1.514
    return float(np.sum(contributions))


def _thornthwaite_a(I: float) -> float:
    """Return the exponent ``a`` as a cubic polynomial in the heat index *I*.

    ``a = 6.75e-7*I**3 - 7.71e-5*I**2 + 1.79e-2*I + 0.49``
    """
    cubic = 6.75e-7 * I**3
    quadratic = 7.71e-5 * I**2
    linear = 1.79e-2 * I
    return cubic - quadratic + linear + 0.49


def _pet_month(T_C: float, I: float, lat_deg: float, month: int) -> float:
    """Return PET for one month as ``16 * (10*T/I)**a * (L/12) * (N/30)``.

    ``a`` comes from ``_thornthwaite_a(I)``, ``L`` is the daylength in hours
    from ``_daylength_hours`` and ``N`` is the number of days in the month.
    The result is 0.0 when *T_C* <= 0 or *I* <= 0.
    """
    if T_C <= 0.0 or I <= 0.0:
        return 0.0

    month_no = int(month)
    exponent = _thornthwaite_a(I)
    daylength = _daylength_hours(lat_deg, month_no)
    n_days = _DAYS_IN_MONTH[month_no]

    pet = 16.0 * ((10.0*T_C/I)**exponent) * (daylength/12.0) * (n_days/30.0)
    return float(pet)


def add_pet_snowpack_monthly(df: pd.DataFrame,
                             id_col: str,
                             lat_col: str,
                             month_col: str,
                             tas_col: str,
                             tasmin_col: str,
                             pr_col_mm_month: str) -> pd.DataFrame:
    """Return a copy of *df* with heat index, PET and snowpack columns added.

    Columns added:
    - ``I``: heat index per *id_col*, computed by ``_thornthwaite_I`` from the
      mean of *tas_col* per (*id_col*, *month_col*); missing values become 0.0.
    - ``PET_mm``: per-row value of ``_pet_month`` using the row's *tas_col*,
      ``I``, *lat_col* and *month_col*.
    - ``snow_pack_mm``: *pr_col_mm_month* where *tasmin_col* < 0, else 0.0.

    The input frame is not modified. The returned frame has a fresh index
    because of the internal merge.

    Raises ``ValueError`` (via ``_ensure_columns``) if a required column is
    missing.
    """
    df = df.copy()
    _ensure_columns(df, [id_col, lat_col, month_col, tas_col, tasmin_col, pr_col_mm_month], "monthly climate")

    # Heat index I per station using mean monthly tas
    t_monthly = df.groupby([id_col, month_col], as_index=False)[tas_col].mean()
    I_per_id = (
        t_monthly.groupby(id_col, as_index=False)
        .agg(I=(tas_col, _thornthwaite_I))
    )
    df = df.merge(I_per_id, on=id_col, how="left")
    df["I"] = df["I"].fillna(0.0)

    # PET: iterate the four needed columns directly instead of a row-wise
    # DataFrame.apply, which builds a Series per row and does not return a
    # Series when the frame has no rows.
    pet_values = [
        _pet_month(float(tas), float(heat_index), float(lat), int(month))
        for tas, heat_index, lat, month in zip(df[tas_col], df["I"], df[lat_col], df[month_col])
    ]
    df["PET_mm"] = pd.Series(pet_values, index=df.index, dtype=float)

    # Snowpack proxy: the month's precipitation counts only when tasmin < 0.
    df["snow_pack_mm"] = np.where(df[tasmin_col] < 0.0, df[pr_col_mm_month], 0.0)

    return df


def aggregate_period_means(df: pd.DataFrame, id_col: str) -> pd.DataFrame:
    """Average the monthly climate variables per *id_col*.

    The columns ``pr_mm_month``, ``tas_C``, ``tasmax_C``, ``tasmin_C``,
    ``PET_mm`` and ``snow_pack_mm`` are averaged within each group and
    returned with a ``_mean`` suffix, alongside *id_col* as a regular column.
    """
    variables = ("pr_mm_month", "tas_C", "tasmax_C", "tasmin_C", "PET_mm", "snow_pack_mm")
    how = {name: "mean" for name in variables}
    new_names = {name: f"{name}_mean" for name in variables}

    per_station = df.groupby(id_col).agg(how)
    per_station = per_station.rename(columns=new_names)
    return per_station.reset_index()


# ---------------------------
# Feature matrix
# ---------------------------

def build_features(agg_df: pd.DataFrame, stations_df: pd.DataFrame, id_col: str,
                   min_area_km2: float,
                   interactions: List[Tuple[str, str]]) -> Tuple[pd.DataFrame, List[str]]:
    """Join station attributes onto period means and list the model features.

    *agg_df* is inner-joined with the ``continent``, ``basin_area_km2`` and
    ``Q10_hist`` columns of *stations_df* on *id_col*; rows whose basin area
    is below *min_area_km2* are dropped. For every ``(a, b)`` in
    *interactions* a product column named ``a__x__b`` is added.

    Returns ``(frame, feature_names)`` where *feature_names* holds the seven
    base features followed by the interaction columns in the order given.

    Raises ``ValueError`` if *stations_df* lacks a required column or an
    interaction refers to a column that is not in the joined frame.
    """
    station_cols = [id_col, "continent", "basin_area_km2", "Q10_hist"]
    _ensure_columns(stations_df, station_cols, "historical Q10 table")

    joined = agg_df.merge(stations_df[station_cols], on=id_col, how="inner")
    large_enough = joined["basin_area_km2"] >= float(min_area_km2)
    joined = joined[large_enough].copy()

    feat_cols = [
        "pr_mm_month_mean",
        "tas_C_mean",
        "tasmax_C_mean",
        "tasmin_C_mean",
        "PET_mm_mean",
        "snow_pack_mm_mean",
        "basin_area_km2",
    ]

    for left, right in interactions:
        if not (left in joined.columns and right in joined.columns):
            raise ValueError(f"Missing columns for interaction {left} x {right}")
        product_name = f"{left}__x__{right}"
        joined[product_name] = joined[left].astype(float) * joined[right].astype(float)
        feat_cols.append(product_name)

    return joined, feat_cols


# ---------------------------
# Modeling (per continent)
# ---------------------------

def fit_models_per_continent(df_all: pd.DataFrame, feature_cols: List[str], random_state: int = 42):
    """Tune, evaluate and select a log(Q10) regressor for each continent.

    For every continent in *df_all*, the rows with a usable ``Q10_hist`` are
    split 75/25 into train and test sets. RandomForest, GradientBoosting and
    (if xgboost is importable) XGBoost are each tuned with a randomized search
    under 5-fold cross-validation on the training set, then scored on the test
    set. The candidate with the lowest test RMSE is kept.

    Parameters:
        df_all: frame containing ``continent``, ``Q10_hist`` and *feature_cols*.
        feature_cols: names of the predictor columns.
        random_state: seed for the split, the CV folds, the search and the
            estimators.

    Returns:
        ``(best_models, best_by, summary)`` where
        - ``best_models`` maps continent -> dict with keys ``name``,
          ``estimator``, ``rmse``, ``r2``, ``adjr2``;
        - ``best_by`` holds the best summary row per continent plus a
          ``chosen_model`` column;
        - ``summary`` holds one row per (continent, candidate model).
        Continents that are skipped do not appear in any of the three.
    """
    cv_splits = 5
    summary_cols = ["continent", "model", "test_RMSE_logQ10", "test_R2", "test_AdjR2", "features"]

    continents = sorted(df_all["continent"].dropna().unique().tolist())
    best_models: Dict[str, dict] = {}
    rows = []

    for cont in continents:
        d = df_all[df_all["continent"] == cont].dropna(subset=["Q10_hist"]).copy()

        # log(Q10) is undefined for Q10 <= 0; such rows would put -inf/NaN in
        # the target and make every fit fail.
        positive = d["Q10_hist"].astype(float) > 0.0
        if not positive.all():
            print(f"[warn] {cont}: dropping {int((~positive).sum())} station(s) with non-positive Q10_hist.")
            d = d[positive].copy()

        if d.empty:
            continue

        too_few = (f"[warn] {cont}: only {len(d)} usable station(s); too few for a hold-out "
                   f"split and {cv_splits}-fold CV, skipping.")
        if len(d) < 2:
            print(too_few)
            continue

        y = np.log(d["Q10_hist"].astype(float))
        X = d[feature_cols].astype(float)
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=random_state)
        if len(X_train) < cv_splits:
            # KFold needs at least as many training rows as folds.
            print(too_few)
            continue

        models = []
        rf = RandomForestRegressor(random_state=random_state, n_jobs=-1)
        rf_grid = {"n_estimators": [300, 600, 900], "max_depth": [None, 12, 24],
                   "min_samples_split": [2, 5, 10], "min_samples_leaf": [1, 2, 4]}
        models.append(("RandomForest", rf, rf_grid))

        gb = GradientBoostingRegressor(random_state=random_state)
        gb_grid = {"n_estimators": [300, 600, 900], "learning_rate": [0.03, 0.05, 0.1],
                   "max_depth": [2, 3, 4], "subsample": [0.7, 0.9, 1.0]}
        models.append(("GradientBoosting", gb, gb_grid))

        if _HAVE_XGB:
            xgb = XGBRegressor(random_state=random_state, tree_method="hist")
            xgb_grid = {"n_estimators": [400, 800, 1200], "max_depth": [4, 6, 8],
                        "learning_rate": [0.03, 0.06, 0.1], "subsample": [0.7, 0.9, 1.0],
                        "colsample_bytree": [0.6, 0.8, 1.0]}
            models.append(("XGBoost", xgb, xgb_grid))

        best = None
        for name, est, grid in models:
            cv = KFold(n_splits=cv_splits, shuffle=True, random_state=random_state)
            # Cap the number of sampled candidates at the size of the grid,
            # i.e. the product (not the sum) of the option counts.
            n_combinations = int(np.prod([len(v) for v in grid.values()]))
            n_iter = min(25, n_combinations)
            search = RandomizedSearchCV(est, grid, n_iter=n_iter, cv=cv,
                                        scoring="neg_root_mean_squared_error",
                                        n_jobs=-1, random_state=random_state)
            search.fit(X_train, y_train)
            pred = search.predict(X_test)
            rmse = float(math.sqrt(mean_squared_error(y_test, pred)))
            r2 = float(r2_score(y_test, pred))
            n = len(y_test)
            p = X_test.shape[1]
            # The denominator is floored at 1 so small test sets cannot divide by <= 0.
            adjr2 = float(1 - (1 - r2) * (n - 1) / max(1, (n - p - 1)))
            rows.append({"continent": cont, "model": name, "test_RMSE_logQ10": rmse,
                         "test_R2": r2, "test_AdjR2": adjr2, "features": ",".join(feature_cols)})
            if (best is None) or (rmse < best["rmse"]):
                best = {"name": name, "estimator": search.best_estimator_,
                        "rmse": rmse, "r2": r2, "adjr2": adjr2}
        best_models[cont] = best

    # Explicit columns keep the frame sortable even when no continent produced
    # a row (a column-less frame makes sort_values raise KeyError).
    summary = pd.DataFrame(rows, columns=summary_cols)
    best_by = summary.sort_values(["continent", "test_RMSE_logQ10"]).groupby("continent").head(1).reset_index(drop=True)
    best_by["chosen_model"] = best_by["model"]
    return best_models, best_by, summary


# ---------------------------
# Orchestration
# ---------------------------

def load_station_table(path: str, id_col: str, xcol: str, ycol: str) -> pd.DataFrame:
    """Read the station / historical-Q10 CSV at *path* and check its columns.

    The file must contain *id_col*, ``continent``, ``basin_area_km2``,
    ``Q10_hist``, *xcol* and *ycol*; otherwise ``ValueError`` is raised via
    ``_ensure_columns``. The frame is returned as read.
    """
    stations = pd.read_csv(path)
    required = [id_col, "continent", "basin_area_km2", "Q10_hist", xcol, ycol]
    _ensure_columns(stations, required, "historical-q10")
    return stations


def prepare_monthly(climate_csv: str, mapping: pd.DataFrame, id_col: str,
                    lat_col: str, lon_col: str, month_col: str,
                    temps_already_celsius: bool,
                    pr_already_mm_month: bool) -> pd.DataFrame:
    """Build the per-station monthly climate table for one climate CSV.

    Steps:
    1. Read *climate_csv* and check it has *lat_col*, *lon_col*, *month_col*,
       ``pr``, ``tas``, ``tasmax`` and ``tasmin``.
    2. Convert temperatures to °C unless *temps_already_celsius*, then rename
       them to ``tas_C`` / ``tasmax_C`` / ``tasmin_C``.
    3. Produce ``pr_mm_month``: by conversion from ``pr``, or by renaming
       ``pr`` when *pr_already_mm_month* is set.
    4. Left-join *mapping* (station -> grid point) to the climate rows on
       (*lat_col*, *lon_col*).
    5. Compute PET and snowpack with ``add_pet_snowpack_monthly``. A ``geo_y``
       column in the joined table is renamed to ``station_lat`` and used as
       the latitude; otherwise *lat_col* is used.

    Returns the columns *id_col*, *month_col*, ``tas_C``, ``tasmax_C``,
    ``tasmin_C``, ``pr_mm_month``, ``PET_mm`` and ``snow_pack_mm``.
    """
    clim = pd.read_csv(climate_csv)
    _ensure_columns(clim, [lat_col, lon_col, month_col, "pr", "tas", "tasmax", "tasmin"], "monthly climate")

    # Temperatures: convert only if needed, rename in either case.
    temp_cols = ["tas", "tasmax", "tasmin"]
    if not temps_already_celsius:
        to_celsius_inplace(clim, temp_cols)
    clim = clim.rename(columns={name: f"{name}_C" for name in temp_cols})

    # Precipitation
    if pr_already_mm_month:
        clim = clim.rename(columns={"pr": "pr_mm_month"})
    else:
        pr_to_mm_month_inplace(clim, pr_col="pr", month_col=month_col)

    # Attach station ids via the matched (lat, lon) grid point.
    joined = mapping.merge(clim, on=[lat_col, lon_col], how="left")
    _ensure_columns(joined, [id_col, month_col, "tas_C", "tasmin_C", "pr_mm_month"], "station-climate join")

    # Prefer the station's own latitude for PET when the mapping carries it.
    if "geo_y" in joined.columns:
        joined = joined.rename(columns={"geo_y": "station_lat"})
    if "station_lat" in joined.columns:
        lat_for_pet = "station_lat"
    else:
        lat_for_pet = lat_col

    enriched = add_pet_snowpack_monthly(
        joined,
        id_col=id_col,
        lat_col=lat_for_pet,
        month_col=month_col,
        tas_col="tas_C",
        tasmin_col="tasmin_C",
        pr_col_mm_month="pr_mm_month",
    )

    output_cols = [id_col, month_col, "tas_C", "tasmax_C", "tasmin_C", "pr_mm_month", "PET_mm", "snow_pack_mm"]
    return enriched[output_cols]


def main(argv: Optional[List[str]] = None) -> int:
    """Command-line entry point for the flooding pipeline.

    ``train`` mode matches stations to the historical climate grid, builds
    period-mean features, fits models per continent and writes, under
    ``--output-dir``: ``models/<continent>.pkl``,
    ``models/features_<continent>.json``, ``summary_best_by_continent.csv``
    and ``summary_all_models_by_continent.csv``.

    ``predict`` mode matches stations to the future climate grid, loads the
    saved model and feature list per continent and writes
    ``<continent>_future_q10_predictions.csv`` with the predicted Q10 and the
    change factor ``CF = Q10_future_pred / Q10_hist``.

    Parameters:
        argv: argument list to parse; ``None`` lets argparse read ``sys.argv``.

    Returns:
        0 on completion.

    Raises:
        SystemExit: if the climate file required by the chosen mode is not
            given, or ``--continent`` matches no station (argparse also exits
            on invalid arguments).
    """
    ap = argparse.ArgumentParser(description="Continent-by-continent flooding pipeline (training & projections)")
    ap.add_argument("mode", choices=["train", "predict"], help="train: fit models on historical; predict: apply models to future")
    ap.add_argument("--historical-q10", required=True, help="CSV with Station_ID, continent, basin_area_km2, Q10_hist, geo_x, geo_y")
    ap.add_argument("--historical-climate", help="Monthly climate CSV for baseline period (required for train)")
    ap.add_argument("--future-climate", help="Monthly climate CSV for future period (required for predict)")
    ap.add_argument("--continent", default=None, help="Optional filter to a single continent (e.g., Africa)")

    # Column names to match your files
    ap.add_argument("--id-col", default="Station_ID")
    ap.add_argument("--xcol", default="geo_x")
    ap.add_argument("--ycol", default="geo_y")
    ap.add_argument("--lat-col", default="lat")
    ap.add_argument("--lon-col", default="lon")
    ap.add_argument("--month-col", default="month")

    ap.add_argument("--temps-already-celsius", action="store_true")
    ap.add_argument("--pr-already-mm-month", action="store_true")
    ap.add_argument("--min-area-km2", type=float, default=500.0)
    ap.add_argument("--output-dir", required=True)
    ap.add_argument("--random-state", type=int, default=42)
    ap.add_argument("--interactions-json", default=None, help="JSON string or path: {\"Africa\":[[\"pr_mm_month_mean\",\"tas_C_mean\"]], ...}")

    args = ap.parse_args(argv)
    os.makedirs(args.output_dir, exist_ok=True)
    models_dir = os.path.join(args.output_dir, "models")
    os.makedirs(models_dir, exist_ok=True)

    # Load stations/Q10 and optional continent filter
    stations = load_station_table(args.historical_q10, args.id_col, args.xcol, args.ycol)
    if args.continent:
        stations = stations[stations["continent"] == args.continent].copy()
        if stations.empty:
            raise SystemExit(f"No stations for continent={args.continent}")

    # Build interactions per continent; the flag is either a path to a JSON
    # file or the JSON text itself.
    per_cont_interactions: Dict[str, List[Tuple[str, str]]] = {}
    if args.interactions_json:
        raw = args.interactions_json
        if os.path.exists(raw):
            with open(raw, "r", encoding="utf-8") as f:
                raw = f.read()
        js = json.loads(raw)
        for k, v in js.items():
            per_cont_interactions[k] = [tuple(p) for p in v]

    # TRAIN MODE
    if args.mode == "train":
        if not args.historical_climate:
            raise SystemExit("--historical-climate is required in train mode")
        # Collocation mapping using the historical grid
        hist_grid = pd.read_csv(args.historical_climate, usecols=[args.lat_col, args.lon_col]).drop_duplicates()
        mapping = match_stations_to_grid(stations, hist_grid, args.id_col, args.xcol, args.ycol, args.lat_col, args.lon_col)
        # Carry the station latitude along as "geo_y" so PET uses the gauge
        # latitude rather than the grid-cell latitude.
        mapping = mapping.merge(stations[[args.id_col, args.ycol]], on=args.id_col, how="left").rename(columns={args.ycol: "geo_y"})

        hist_monthly = prepare_monthly(args.historical_climate, mapping, args.id_col, args.lat_col, args.lon_col, args.month_col,
                                       args.temps_already_celsius, args.pr_already_mm_month)
        hist_agg = aggregate_period_means(hist_monthly, args.id_col)

        # Build features per continent and fit models
        best_models = {}
        best_rows = []
        all_rows = []
        for cont in sorted(stations["continent"].dropna().unique()):
            inters = per_cont_interactions.get(cont, [])
            merged, feat_cols = build_features(hist_agg, stations[stations["continent"] == cont], args.id_col, args.min_area_km2, inters)
            if merged.empty:
                continue
            models, best_by, summary = fit_models_per_continent(merged, feat_cols, random_state=args.random_state)
            if cont in models and models[cont] is not None:
                best_models[cont] = models[cont]
                # save model + features
                with open(os.path.join(models_dir, f"{cont}.pkl"), "wb") as f:
                    pickle.dump(models[cont]["estimator"], f)
                with open(os.path.join(models_dir, f"features_{cont}.json"), "w", encoding="utf-8") as f:
                    json.dump(feat_cols, f, ensure_ascii=False, indent=2)
            best_rows.append(best_by)
            all_rows.append(summary)

        if best_rows:
            pd.concat(best_rows, ignore_index=True).to_csv(os.path.join(args.output_dir, "summary_best_by_continent.csv"), index=False)
        if all_rows:
            pd.concat(all_rows, ignore_index=True).to_csv(os.path.join(args.output_dir, "summary_all_models_by_continent.csv"), index=False)
        print("Training complete.")

    # PREDICT MODE
    else:
        if not args.future_climate:
            raise SystemExit("--future-climate is required in predict mode")
        # Build mapping from the FUTURE grid (so we pick the nearest available point in that file)
        fut_grid = pd.read_csv(args.future_climate, usecols=[args.lat_col, args.lon_col]).drop_duplicates()
        mapping = match_stations_to_grid(stations, fut_grid, args.id_col, args.xcol, args.ycol, args.lat_col, args.lon_col)
        mapping = mapping.merge(stations[[args.id_col, args.ycol]], on=args.id_col, how="left").rename(columns={args.ycol: "geo_y"})

        fut_monthly = prepare_monthly(args.future_climate, mapping, args.id_col, args.lat_col, args.lon_col, args.month_col,
                                      args.temps_already_celsius, args.pr_already_mm_month)
        fut_agg = aggregate_period_means(fut_monthly, args.id_col)

        continents = sorted(stations["continent"].dropna().unique().tolist())
        for cont in continents:
            model_path = os.path.join(models_dir, f"{cont}.pkl")
            feats_path = os.path.join(models_dir, f"features_{cont}.json")
            if not (os.path.exists(model_path) and os.path.exists(feats_path)):
                print(f"[warn] Missing model or features for {cont}; skipping predictions.")
                continue
            # NOTE: unpickling executes code embedded in the file; only load
            # model files written by this pipeline's train mode.
            with open(model_path, "rb") as f:
                est = pickle.load(f)
            with open(feats_path, "r", encoding="utf-8") as f:
                feat_cols = json.load(f)

            # Recover the interaction pairs from the saved feature names
            # ("a__x__b") so the feature matrix always matches what the model
            # was trained on, even if --interactions-json is omitted or
            # differs from the training run.
            inters = [tuple(name.split("__x__", 1)) for name in feat_cols if "__x__" in name]
            # For predictions we still need basin_area_km2 & Q10_hist to compute CF
            merged_fut, _ = build_features(fut_agg, stations[stations["continent"] == cont], args.id_col, args.min_area_km2, inters)
            if merged_fut.empty:
                continue

            Xf = merged_fut[feat_cols].astype(float)
            logQ10_future = est.predict(Xf)
            # The models are fitted on log(Q10), so exponentiate to get Q10.
            Q10_future = np.exp(logQ10_future)
            out = merged_fut[[args.id_col]].copy()
            out["Q10_hist"] = merged_fut["Q10_hist"].values
            out["Q10_future_pred"] = Q10_future
            out["CF"] = out["Q10_future_pred"] / out["Q10_hist"]
            out.to_csv(os.path.join(args.output_dir, f"{cont}_future_q10_predictions.csv"), index=False)
        print("Prediction complete.")

    return 0


# Script entry point: run main() and exit with its return value as the status.
if __name__ == "__main__":
    raise SystemExit(main())