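# In [None]:  (Jupyter cell marker left over from the notebook export; kept as a comment so the file parses as Python)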
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Final training script (Python 3.9 + Jupyter compatible)

Fixes:
1) Density weighting bin-key now uses int64 encoding (no numpy string concat).
2) RFECV.fit(sample_weight=...) is NOT supported in some sklearn versions.
   -> fallback to RFECV without sample_weight, while keeping weights in:
      - permutation importance (weighted)
      - Optuna CV objective (weighted)
      - final training (weighted)

NEW (this revision):
3) For each depth, after training all zone models, build "overall test set" predictions by
   loading each zone model + metadata and predicting for that zone's test rows, then output:
      - test_with_pred.csv   (ALL test rows, with Pred filled where zone model exists)
      - overall_test_metrics.txt  (overall test RMSE/MAE/R2 + coverage)
      - (optional) test_panel.png if ENABLE_PLOTTING

User request (this turn):
- Explicitly "stitch into an overall/global test set and output overall skill"
- Keep everything else unchanged.
"""

import os, gc, warnings
warnings.filterwarnings("ignore")

# ============================ Config ============================

# Master switches for optional pipeline stages.
ENABLE_PLOTTING = True            # gates the rcParams setup and the figure-producing helpers
COMPUTE_SHAP = False              # NOTE(review): not referenced in the visible part of this file
ENABLE_FEATURE_SELECTION = True   # run permutation importance + RFECV before hyper-parameter search

# ---- Inverse-density weighting (ON) ----
ENABLE_DENSITY_WEIGHTING = True
WEIGHT_GRID_RESOLUTION_DEG = 5.0     # 5°
WEIGHT_TIME_RESOLUTION_YR  = 10.0    # 10-year bins

# Optional: extra penalty for transition period (kept but default OFF)
ENABLE_TRANSITION_PERIOD_PENALTY = False
PENALTY_YEAR_START = 1985
PENALTY_YEAR_END   = 2000
PENALTY_FACTOR     = 0.5             # multiplies the weight of rows inside [START, END]

# ---- Permutation importance (CV) ----
PERM_N_SPLITS   = 5
PERM_N_REPEATS  = 3
PERM_TOPK       = None               # None -> the number of kept features is derived automatically
PERM_MAX_VAL_SAMPLES = 12000         # validation folds larger than this are randomly subsampled

# ---- RFECV ----
RFECV_STEP      = 1
RFECV_SCORING   = "neg_root_mean_squared_error"
USE_ADAPTIVE_RFECV_MIN = True
RFECV_MIN_RATIO = 0.40
RFECV_MIN_ABS   = 6
RFECV_MIN_FEATS = 8
ALWAYS_KEEP = []                     # feature names that bypass elimination when present
# ---- Native thread caps (avoid BLAS/OMP oversubscription) ----
# These variables are exported BEFORE numpy/pandas are imported: BLAS/OpenMP
# runtimes typically read them when their shared libraries are first loaded,
# so assigning them after the import does not shrink thread pools that were
# already created.
# NOTE(review): inside a Jupyter kernel numpy may already be loaded by an
# earlier cell, in which case these assignments cannot take effect either.
os.environ["OMP_NUM_THREADS"]        = "1"
os.environ["MKL_NUM_THREADS"]        = "1"
os.environ["OPENBLAS_NUM_THREADS"]   = "1"
os.environ["NUMEXPR_NUM_THREADS"]    = "1"
os.environ["VECLIB_MAXIMUM_THREADS"] = "1"

import numpy as np
import pandas as pd

# ---- Optuna / parallelism ----
# Each value can be overridden through the environment variable named in the call.
N_TRIALS_OPTUNA = int(os.environ.get("N_TRIALS_OPTUNA", "32"))
RANDOM_SEED = 42
N_ZONE_WORKERS = int(os.environ.get("ZONE_WORKERS", "4"))

# CatBoost thread budget: the (capped) core count divided among zone workers,
# with a floor of 4 threads per model unless CB_N_THREADS overrides it.
MAX_CORES = 24
total_cores = min(os.cpu_count() or 16, MAX_CORES)
_default_cb_threads = max(4, total_cores // max(1, N_ZONE_WORKERS))
N_THREADS = int(os.environ.get("CB_N_THREADS", str(_default_cb_threads)))

import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize

from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.linear_model import LinearRegression
from sklearn.feature_selection import RFECV

from joblib import Parallel, delayed
import joblib

from catboost import CatBoostRegressor, Pool
import optuna
from optuna.samplers import TPESampler
from math import sqrt

import seaborn as sns

# ============================ Run depth list ============================

depthlist = ["1900","1800"]

SEASONS = ["Spring", "Summer", "Autumn", "Winter", "NewYear"]

# ============================ Time setup ============================

# First year of the record; overridable through the YEARSTART environment variable.
yearstart = int(os.environ.get("YEARSTART", "1960"))

# Hold-out (test) years keyed by the start year they belong to.
_TEST_YEARS_MAP = {
    1960: [1961,1970,1984,1993,2003,2012,2020],
}

# Look the start year up with a real fallback.  A `.get(key, map[key])` default
# is evaluated eagerly and raises KeyError for any unknown start year, which
# defeats the purpose of having a default at all.
if yearstart in _TEST_YEARS_MAP:
    TEST_YEARS = _TEST_YEARS_MAP[yearstart]
else:
    print(f"[TimeSetup] No test-year list defined for yearstart={yearstart}; "
          f"falling back to the 1960 list.")
    TEST_YEARS = _TEST_YEARS_MAP[1960]

# ============================ Feature config ============================

# Column names used throughout the pipeline.
TempName  = "Temp"
SalName   = "Sal"
Satname   = "O2_sat"
timename  = "Year"
Zonename  = "Zone0"

# Features offered to every zone model (auxiliary ones are added per depth).
CORE_FEATURES = [
    timename,
    "month_sin", "month_cos",
    "Latitude", #"lon_cos20", "lon_sin110",
    TempName, SalName, "SOM_Zone",
]


ENFORCE_VALID_FEATURES = True

# ============================ Paths ============================

ROOT_DIR = "/data/wang/Result_Data/alldoxy"
yearstart_tag = str(yearstart)   # NOTE(review): not used in the visible part of this file
# Plain string literal: the previous f-string had no placeholders.
OUTPUT_DIR = os.path.join("/data/wang/Result_Data", "models_ML")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ============================ Sentinel / Range checks ============================

# Exact values treated as "missing" markers when validating feature rows.
SENTINEL_VALUES = [
    -9999, -9999.0, -999, -32767, 32767,
    1e20, -1e20, 9.96921e36, -9.96921e36
]

# Inclusive (low, high) bounds per feature; None leaves that side unbounded.
FEATURE_RANGES = {
    "Latitude": (-90, 90),
    "Longitude": (-180, 180),
    TempName: (-2, 45),
    SalName: (0, 45),
    "Chla": (0, None),
    "pCO2": (0, None),
}

DESIRED_SHAP_FEATURES = [TempName, SalName, "U", "V", "SSH", "pCO2", "DIC", "Chla"]

TARGET = "Oxygen"       # column the models are trained to predict
ZONE   = Zonename
TIME   = "Year"
MIN_SAMPLES = 300       # zones with fewer training rows are skipped

# ============================ Plot style ============================

if ENABLE_PLOTTING:
    # Shared matplotlib defaults, installed once at import time and only when
    # plotting is switched on.
    mpl.rcParams.update([
        # fonts
        ("font.family", "DejaVu Sans"),
        ("font.size", 10.5),
        # axes frame and inward-pointing ticks
        ("axes.linewidth", 1.1),
        ("xtick.direction", "in"),
        ("ytick.direction", "in"),
        ("xtick.major.size", 4),
        ("ytick.major.size", 4),
        # output
        ("savefig.dpi", 340),
        ("figure.figsize", (6.4, 4.8)),
    ])

# ============================ Helpers ============================

# Per-depth state, (re)assigned by adjust_features_and_merge().
AUX_FEATURES = []
MERGE_ZONES_1_TO_12 = False
MERGE_MAP = {}

def adjust_features_and_merge(depth):
    """
    Set the module-level feature and zone-merge configuration for one depth.

    `depth` is converted with int().  The auxiliary feature list is always the
    same; zone merging is switched on only when the depth lies outside
    (0, 3000], in which case zones 1..5 are all mapped to zone 1.
    The chosen settings are printed.
    """
    global AUX_FEATURES, MERGE_ZONES_1_TO_12, MERGE_MAP

    depth_int = int(depth)

    AUX_FEATURES = [Satname, "MLD","U","V","SSH","EKE","PAR","CO2_flux","pH","pCO2","DIC","Alkalinity","Chla"]

    # Merging is enabled for every depth that is NOT within (0, 3000].
    MERGE_ZONES_1_TO_12 = not (0 < depth_int <= 3000)
    MERGE_MAP = {zone_id: 1 for zone_id in (1, 2, 3, 4, 5)} if MERGE_ZONES_1_TO_12 else {}

    print(f"[Depth {depth_int}] AUX_FEATURES={AUX_FEATURES}")
    print(f"[Depth {depth_int}] MERGE_ZONES_1_TO_12={MERGE_ZONES_1_TO_12}")

def compute_chla_offset(x):
    """
    Return the additive offset used before taking log10 of chlorophyll.

    The offset is 1% of the 5th percentile of the finite, strictly positive
    values in `x`, floored at 1e-6.  When no such values exist, 1e-6 is returned.
    """
    values = np.asarray(x, float)
    positive = values[np.isfinite(values) & (values > 0)]
    if positive.size:
        return max(1e-6, 0.01 * np.nanpercentile(positive, 5))
    return 1e-6

def apply_chla_log_inplace(df, offset):
    """
    Overwrite df["Chla"] with log10(max(Chla, 0) + offset), in place.

    Non-finite entries become NaN before the transform.  Frames without a
    "Chla" column are left untouched.  Returns None.
    """
    if "Chla" not in df.columns:
        return
    chla = np.asarray(df["Chla"], float)
    chla = np.where(np.isfinite(chla), chla, np.nan)   # inf/-inf -> NaN
    clipped = np.maximum(chla, 0)                      # negatives are treated as zero
    df["Chla"] = np.log10(clipped + float(offset))

def transform_features(df_part, feature_cols, chla_offset):
    """
    Return a copy of df_part[feature_cols] ready to be fed to a model.

    If the selection contains "Chla", that column is log10-transformed with
    `chla_offset` (1e-6 when the offset is None).  The input frame is not modified.
    """
    features = df_part[feature_cols].copy()
    if "Chla" not in features.columns:
        return features
    offset = 1e-6 if chla_offset is None else chla_offset
    apply_chla_log_inplace(features, offset)
    return features

def calibration_stats(y_true, y_pred):
    """
    Summarise how predictions line up against observations.

    Pairs where either value is non-finite are dropped.  A linear regression of
    prediction on observation gives `slope` and `intercept`; `bias` is the mean
    of (pred - obs) and `r2` is the coefficient of determination.  Any failure
    while computing these yields a dict with all four entries set to NaN.
    """
    obs = np.asarray(y_true, float)
    pred = np.asarray(y_pred, float)
    usable = np.isfinite(obs) & np.isfinite(pred)
    obs, pred = obs[usable], pred[usable]

    try:
        fit = LinearRegression().fit(obs.reshape(-1, 1), pred)
        slope = float(fit.coef_[0])
        intercept = float(fit.intercept_)
        bias = float(np.nanmean(pred - obs))
        r2 = float(r2_score(obs, pred))
    except Exception:
        # Best effort: callers receive NaNs rather than an exception.
        return {"slope": np.nan, "intercept": np.nan, "bias": np.nan, "r2": np.nan}
    return {"slope": slope, "intercept": intercept, "bias": bias, "r2": r2}

def enforce_valid_feature_rows(df, features, sentinel_values=None, feature_ranges=None):
    """
    Drop rows whose feature values are unusable.

    Only features that are actual columns of `df` are checked.  Those columns
    are coerced to numeric (unparseable entries become NaN) on a copy of `df`,
    and a row is kept only if, across all checked columns, it
      * is finite everywhere,
      * equals none of `sentinel_values`, and
      * lies inside the inclusive (low, high) bounds given in `feature_ranges`
        (None on either side means unbounded; non tuple/list entries impose no bound).

    Returns the filtered copy.  When `features` is empty or none of them is a
    column, the original `df` object is returned unchanged.
    """
    if not features:
        return df
    present = [name for name in features if name in df.columns]
    if not present:
        return df

    sentinels = list(sentinel_values or [])
    bounds_by_feature = dict(feature_ranges or {})

    cleaned = df.copy()
    for name in present:
        cleaned[name] = pd.to_numeric(cleaned[name], errors="coerce")

    values = cleaned[present].to_numpy(dtype=float, copy=False)
    keep = np.isfinite(values).all(axis=1)

    if sentinels:
        sentinel_array = np.asarray([float(v) for v in sentinels], dtype=float)
        keep &= ~np.isin(values, sentinel_array).any(axis=1)

    for name in present:
        bounds = bounds_by_feature.get(name, None)
        if bounds is None:
            continue
        low, high = bounds if isinstance(bounds, (tuple, list)) else (None, None)
        column = cleaned[name].to_numpy(dtype=float, copy=False)
        in_range = np.isfinite(column)
        if low is not None:
            in_range &= (column >= float(low))
        if high is not None:
            in_range &= (column <= float(high))
        keep &= in_range

    total = len(cleaned)
    kept = int(keep.sum())
    dropped = int(total - kept)
    if dropped > 0:
        print(f"[ValidFilter] kept {kept}/{total}; dropped {dropped} rows due to invalid feature(s).")
    return cleaned.loc[keep].copy()

def plot_test_panel(y_true, y_pred, out_path,
                    xlab=r"Predicted DO ($\mu mol\ kg^{-1}$)",
                    ylab=r"Observed DO ($\mu mol\ kg^{-1}$)"):
    """
    Save a two-panel test figure: parity hexbin (left) and residual histogram (right).

    Parameters
    ----------
    y_true, y_pred : array-like
        Observed and predicted values; pairs with a non-finite member are dropped.
    out_path : str
        Destination passed to ``Figure.savefig``.
    xlab, ylab : str
        Axis labels of the parity panel.

    Does nothing when plotting is disabled or when no finite pair remains.
    Skewness and kurtosis are reported as NaN if scipy cannot be used.
    """
    if not ENABLE_PLOTTING:
        return

    y_true = np.asarray(y_true, float)
    y_pred = np.asarray(y_pred, float)
    mask = np.isfinite(y_true) & np.isfinite(y_pred)
    y_true = y_true[mask]
    y_pred = y_pred[mask]
    if y_true.size == 0:
        return

    resid = y_true - y_pred

    # scipy is optional: fall back to NaN annotations if it is missing or fails.
    try:
        import scipy.stats as st
        skew = float(st.skew(resid, nan_policy="omit"))
        kurt = float(st.kurtosis(resid, nan_policy="omit", fisher=False))
    except Exception:
        skew, kurt = np.nan, np.nan

    rmse = float(np.sqrt(mean_squared_error(y_true, y_pred)))
    mae  = float(mean_absolute_error(y_true, y_pred))
    r2   = float(r2_score(y_true, y_pred))

    fig = plt.figure(figsize=(12.8, 5.4))
    ax1 = fig.add_subplot(1, 2, 1)
    ax2 = fig.add_subplot(1, 2, 2)

    # ---- parity hexbin ----
    hb = ax1.hexbin(y_pred, y_true, gridsize=150, mincnt=1, linewidths=0)
    arr = hb.get_array()
    if arr.size:
        # Cap the colour scale at the 99th percentile so a few dense cells
        # do not wash out the rest of the panel.
        vmax = float(np.nanpercentile(arr, 99.0))
        hb.set_clim(0, vmax)

    # 1:1 reference line from 0 up to the smaller of the two maxima.
    dmax = float(min(np.nanmax(y_pred), np.nanmax(y_true)))
    dmin = 0.0
    ax1.plot([dmin, dmax], [dmin, dmax], ls="--", lw=1.2, color="0.55")
    ax1.set_xlabel(xlab)
    ax1.set_ylabel(ylab)
    ax1.grid(alpha=0.22, lw=0.5)
    ax1.text(0.03, 0.95,
             f"R2={r2:.3f}\nMAE={mae:.3f}\nRMSE={rmse:.3f}",
             transform=ax1.transAxes, va="top", ha="left")

    cb = fig.colorbar(hb, ax=ax1, pad=0.01)
    cb.set_label("Counts")

    # ---- residual histogram ----
    # Histogram window: the 0.1-99.9 percentile span of the residuals, falling
    # back to the full min/max and finally to [-1, 1] when that span is degenerate.
    try:
        lo = float(np.nanpercentile(resid, 0.1))
        hi = float(np.nanpercentile(resid, 99.9))
        if (not np.isfinite(lo)) or (not np.isfinite(hi)) or lo >= hi:
            raise ValueError
    except Exception:
        lo = float(np.nanmin(resid)) if np.isfinite(resid).any() else -1.0
        hi = float(np.nanmax(resid)) if np.isfinite(resid).any() else 1.0
        if (not np.isfinite(lo)) or (not np.isfinite(hi)) or lo >= hi:
            lo, hi = -1.0, 1.0

    # The window is now actually applied; previously lo/hi were computed but
    # never handed to hist(), so extreme outliers stretched the bins.
    ax2.hist(resid, bins=80, range=(lo, hi), alpha=0.9, edgecolor="none")
    ax2.set_xlabel(r"Residual (Obs - Pred) ($\mu mol\ kg^{-1}$)")
    ax2.set_ylabel("Counts")
    ax2.grid(alpha=0.20, lw=0.5)
    ax2.text(0.03, 0.95,
             f"skew={skew:+.3f}\nkurt={kurt:.3f}",
             transform=ax2.transAxes, va="top", ha="left")

    fig.tight_layout()
    fig.savefig(out_path, bbox_inches="tight")
    plt.close(fig)

# ============================ Inverse-density weights (FIXED) ============================

def build_density_sample_weight(df: pd.DataFrame) -> np.ndarray:
    """
    Inverse-density sample weights over space-time bins.

    Rows are binned by latitude, longitude (wrapped to [-180, 180)) and year
    using the configured grid and time resolutions.  Each row with a finite
    key gets w = 1/sqrt(number of rows in its bin); rows with a non-finite
    key keep weight 1.  If the transition-period penalty is enabled, weights
    of rows inside the penalty years are multiplied by the penalty factor.
    The result is rescaled to mean 1.

    Uniform weights are returned when weighting is disabled, when a required
    column is missing, or when no row has a finite key.  An empty frame gives
    an empty array.
    """
    n_rows = len(df)
    if n_rows == 0:
        return np.array([], dtype=float)

    weights = np.ones(n_rows, dtype=float)
    if not ENABLE_DENSITY_WEIGHTING:
        return weights

    if not {"Latitude", "Longitude", "Year"}.issubset(df.columns):
        print("[DensityWeight] Missing Latitude/Longitude/Year. Fallback to uniform weights.")
        return weights

    cell_deg = float(WEIGHT_GRID_RESOLUTION_DEG)
    span_yr = float(WEIGHT_TIME_RESOLUTION_YR)

    def _numeric(column):
        # Coerce to float; anything unparseable becomes NaN.
        return pd.to_numeric(df[column], errors="coerce").to_numpy(dtype=float, copy=False)

    lat = _numeric("Latitude")
    lon = _numeric("Longitude")
    yr = _numeric("Year")

    usable = np.isfinite(lat) & np.isfinite(lon) & np.isfinite(yr)
    if not np.any(usable):
        print("[DensityWeight] All Latitude/Longitude/Year invalid. Fallback to uniform weights.")
        return weights

    rows = np.where(usable)[0]

    # ensure lon in [-180,180)
    lon_wrapped = ((lon[rows] + 180.0) % 360.0) - 180.0

    lat_bin = np.floor(lat[rows] / cell_deg).astype(np.int64)
    lon_bin = np.floor(lon_wrapped / cell_deg).astype(np.int64)
    time_bin = np.floor(yr[rows] / span_yr).astype(np.int64)

    # Shift every axis to start at zero, then fold the three bin indices into
    # a single int64 key (row-major over time, lat, lon).
    t_idx = (time_bin - int(time_bin.min())).astype(np.int64)
    lat_idx = (lat_bin - int(lat_bin.min())).astype(np.int64)
    lon_idx = (lon_bin - int(lon_bin.min())).astype(np.int64)

    n_lat = int(lat_idx.max()) + 1
    n_lon = int(lon_idx.max()) + 1
    key = (t_idx * np.int64(n_lat) + lat_idx) * np.int64(n_lon) + lon_idx

    _, inverse, bin_counts = np.unique(key, return_inverse=True, return_counts=True)
    rows_in_bin = bin_counts.astype(float)[inverse]

    bin_weights = 1.0 / np.sqrt(np.maximum(rows_in_bin, 1.0))

    if ENABLE_TRANSITION_PERIOD_PENALTY:
        years_ok = yr[rows]
        in_window = (years_ok >= float(PENALTY_YEAR_START)) & (years_ok <= float(PENALTY_YEAR_END))
        if np.any(in_window):
            bin_weights[in_window] *= float(PENALTY_FACTOR)

    weights[rows] = bin_weights

    # Normalise to mean 1, guarding against a degenerate mean.
    mean_w = float(np.nanmean(weights)) if np.isfinite(weights).any() else 1.0
    if (not np.isfinite(mean_w)) or mean_w <= 0:
        mean_w = 1.0
    return weights / mean_w

# ============================ Decade-block KFold ============================

class DecadeBlockKFold:
    """
    Cross-validation splitter that keeps whole decades together.

    `groups` passed to :meth:`split` are years.  Each row is assigned to the
    decade ``floor(year / 10) * 10``; rows with a non-finite year share one
    pseudo-decade.  The distinct decades are shuffled (seeded) and dealt into
    at most `n_splits` folds, so a validation fold never shares a decade with
    its training fold.

    When `groups` is missing, has the wrong length, or contains fewer than two
    distinct decades, the splitter falls back to a seeded shuffled K-fold over
    row indices.

    :meth:`get_n_splits` reports the number of folds :meth:`split` yields for
    the same arguments.
    """

    # Pseudo-decade assigned to rows whose year is NaN/inf.
    _MISSING_DECADE = -999999

    def __init__(self, n_splits=5, random_state=42):
        if n_splits < 2:
            raise ValueError("n_splits must be >= 2")
        self.n_splits = int(n_splits)
        self.random_state = int(random_state)

    def _decades(self, groups):
        """Map years to decade labels without casting NaN/inf to integers."""
        years = np.asarray(groups, dtype=float)
        decade = np.full(years.shape, self._MISSING_DECADE, dtype=np.int64)
        finite = np.isfinite(years)
        decade[finite] = (np.floor(years[finite] / 10.0) * 10).astype(np.int64)
        return decade

    def _n_random_folds(self, n):
        """Number of folds used by the shuffled fallback for `n` rows."""
        return min(self.n_splits, max(2, n))

    def _random_splits(self, n):
        """Seeded shuffled K-fold over ``range(n)`` (fallback path)."""
        rng = np.random.RandomState(self.random_state)
        idx = np.arange(n)
        rng.shuffle(idx)
        for fold in np.array_split(idx, self._n_random_folds(n)):
            va = np.sort(fold)
            tr = np.sort(np.setdiff1d(idx, va, assume_unique=False))
            yield tr, va

    def split(self, X, y=None, groups=None):
        """Yield ``(train_idx, valid_idx)`` integer index arrays."""
        n = len(X)
        if groups is None or len(groups) != n:
            yield from self._random_splits(n)
            return

        decade = self._decades(groups)
        uniq_dec = np.unique(decade)
        if len(uniq_dec) < 2:
            # A single decade cannot be blocked; use the shuffled fallback.
            yield from self._random_splits(n)
            return

        rng = np.random.RandomState(self.random_state)
        rng.shuffle(uniq_dec)
        k_eff = min(self.n_splits, len(uniq_dec))

        for dec_va in np.array_split(uniq_dec, k_eff):
            va_mask = np.isin(decade, dec_va)
            yield np.where(~va_mask)[0], np.where(va_mask)[0]

    def get_n_splits(self, X=None, y=None, groups=None):
        """
        Return the number of folds `split` produces for the same arguments.

        Without `X` the row count is unknown; in the fallback case the
        configured `n_splits` (no groups) or the fold count for ``len(groups)``
        rows (single decade) is reported.
        """
        n = len(X) if X is not None else None
        if groups is None or (n is not None and len(groups) != n):
            return self.n_splits if n is None else self._n_random_folds(n)

        decade = self._decades(groups)
        n_decades = len(np.unique(decade))
        if n_decades < 2:
            return self._n_random_folds(len(decade) if n is None else n)
        return min(self.n_splits, n_decades)

# ============================ Model factories ============================

def make_cb_regressor_for_fs():
    """
    Build the fixed-configuration CatBoost regressor used during feature
    selection (permutation importance and RFECV).

    Every call returns a fresh, unfitted estimator with identical settings.
    """
    fs_settings = dict(
        iterations=700,
        depth=8,
        learning_rate=0.06,
        loss_function="RMSE",
        eval_metric="RMSE",
        random_seed=RANDOM_SEED,
        thread_count=N_THREADS,
        verbose=False,
    )
    return CatBoostRegressor(**fs_settings)

# ============================ Weighted permutation importance (CV) ============================

def _weighted_mse(y_true, y_pred, w):
    """
    Weighted mean squared error: sum(w * (pred - true)^2) / sum(w).

    Triples with any non-finite member are ignored.  Returns NaN when nothing
    usable remains or when the weight sum is non-positive or non-finite.
    """
    obs, pred, wts = (np.asarray(a, float) for a in (y_true, y_pred, w))

    valid = np.isfinite(obs) & np.isfinite(pred) & np.isfinite(wts)
    if not np.any(valid):
        return np.nan

    obs, pred, wts = obs[valid], pred[valid], wts[valid]
    weight_sum = np.sum(wts)
    if weight_sum <= 0 or not np.isfinite(weight_sum):
        return np.nan

    err = pred - obs
    return float(np.sum(wts * err * err) / weight_sum)

def permutation_importance_cv_weighted(X_df, y, years, sample_weight,
                                      n_splits=5, n_repeats=3, max_val_samples=12000):
    """
    Weighted permutation importance averaged over decade-block CV folds.

    For every fold a feature-selection CatBoost model is fitted on the
    training part (with sample weights).  On the validation part each feature
    column is shuffled `n_repeats` times; the importance of a feature in that
    fold is the mean increase in weighted MSE, with negative increases
    clipped to zero.

    Parameters
    ----------
    X_df : pandas.DataFrame
        Feature matrix; column names become the index of the result.
    y : array-like
        Target values aligned with `X_df`.
    years : array-like
        Year of each row, used as CV groups.
    sample_weight : array-like
        Per-row weights used for fitting and for scoring.
    n_splits, n_repeats : int
        CV folds requested and shuffles per feature.
    max_val_samples : int
        Validation folds larger than this are randomly subsampled.

    Returns
    -------
    pandas.Series
        Importance per feature, sorted in descending order.  The average is
        taken over the folds that were actually scored; folds whose baseline
        weighted MSE is not finite are skipped and do not dilute the result.
    """
    cols = list(X_df.columns)
    p = len(cols)
    scores = np.zeros(p, dtype=float)

    cv = DecadeBlockKFold(n_splits=n_splits, random_state=RANDOM_SEED)

    y = np.asarray(y)
    sw_all = np.asarray(sample_weight, float)
    rng_global = np.random.RandomState(RANDOM_SEED)

    n_rep = int(n_repeats)
    folds_scored = 0

    for tr_idx, va_idx in cv.split(X_df, y, groups=years):
        X_tr = X_df.iloc[tr_idx].values
        y_tr = y[tr_idx]
        sw_tr = sw_all[tr_idx]

        if len(va_idx) > int(max_val_samples):
            va_idx = rng_global.choice(va_idx, size=int(max_val_samples), replace=False)

        X_va0 = X_df.iloc[va_idx].values
        y_va = y[va_idx]
        sw_va = sw_all[va_idx]

        est = make_cb_regressor_for_fs()
        est.fit(X_tr, y_tr, sample_weight=sw_tr)

        base_pred = est.predict(X_va0)
        base_mse = _weighted_mse(y_va, base_pred, sw_va)
        if not np.isfinite(base_mse):
            continue
        folds_scored += 1

        # One working copy per fold: only the column under test is overwritten
        # and then restored, instead of copying the whole matrix for every
        # feature and repeat.
        X_perm = X_va0.copy()
        for j in range(p):
            original_col = X_va0[:, j]
            inc = 0.0
            for _ in range(n_rep):
                perm_idx = rng_global.permutation(X_perm.shape[0])
                X_perm[:, j] = original_col[perm_idx]
                pred = est.predict(X_perm)
                mse_perm = _weighted_mse(y_va, pred, sw_va)
                if np.isfinite(mse_perm):
                    inc += max(0.0, mse_perm - base_mse)
            X_perm[:, j] = original_col
            scores[j] += inc / max(1, n_rep)

    # Average over the folds that contributed, not over the nominal fold count.
    scores /= float(max(1, folds_scored))
    return pd.Series(scores, index=cols).sort_values(ascending=False)

# ============================ RFECV with decade-block CV (FIXED) ============================

def rfecv_with_decade_blocks(X_df, y, years, sample_weight,
                            always_keep=None,
                            min_feats=8, step=1, scoring="neg_root_mean_squared_error"):
    """
    RFECV with decade-block CV.

    Features listed in `always_keep` are excluded from elimination and always
    returned; RFECV runs on the remaining ("core") columns only.

    IMPORTANT compatibility note:
    - Some sklearn versions do NOT allow passing sample_weight into RFECV.fit().
      Depending on the version this surfaces as a TypeError (unknown keyword)
      or as a ValueError (extra fit parameters rejected because metadata
      routing is not enabled).
    - We therefore try weighted RFECV first; on either error we fall back to
      unweighted RFECV.  A genuine data error simply re-raises from the retry.
      (Other stages still keep weights: permutation importance / Optuna CV / final training.)

    Parameters
    ----------
    X_df : pandas.DataFrame
        Candidate features.
    y, years, sample_weight : array-like
        Target, CV groups (years) and per-row weights aligned with `X_df`.
    always_keep : list of str, optional
        Features that bypass elimination.
    min_feats : int
        Requested minimum number of features (combined with the module-level
        RFECV minimum settings).
    step, scoring :
        Forwarded to ``RFECV``.

    Returns:
      final_features, rfecv_object, info_dict
      (`rfecv_object` is None when there is nothing to eliminate.)
    """
    always_keep = list(always_keep) if always_keep else []

    cols_all = list(X_df.columns)
    p_total = len(cols_all)

    keep_in = [c for c in cols_all if c not in always_keep]
    X_core = X_df[keep_in].copy()
    p_core = len(keep_in)

    # Minimum number of features to retain overall.
    if USE_ADAPTIVE_RFECV_MIN:
        min_total = max(RFECV_MIN_ABS, int(np.ceil(p_total * RFECV_MIN_RATIO)))
        min_total = min(min_total, max(1, p_total - 1))
        min_total = min(min_total, max(1, min_feats, RFECV_MIN_FEATS))
    else:
        min_total = min(max(1, min_feats), max(1, p_total - 1))

    # Translate to the core columns (always-kept features already count).
    min_core = max(1, min_total - len(always_keep))
    min_core = min(min_core, max(1, p_core - 1))

    info = {
        "p_total": int(p_total),
        "p_core": int(p_core),
        "min_total_used": int(min_total),
        "min_core_used": int(min_core),
        "policy": ("adaptive" if USE_ADAPTIVE_RFECV_MIN else "fixed"),
        "ratio": float(RFECV_MIN_RATIO),
        "abs_min": int(RFECV_MIN_ABS),
        "fixed_min": int(RFECV_MIN_FEATS),
        "rfecv_weighted_supported": None
    }

    # Nothing to eliminate: keep everything.
    if p_core <= 1 or min_core >= p_core:
        final = list(dict.fromkeys(always_keep + keep_in))
        return final, None, info

    est = make_cb_regressor_for_fs()
    cv = DecadeBlockKFold(n_splits=PERM_N_SPLITS, random_state=RANDOM_SEED)

    rfecv = RFECV(
        estimator=est,
        step=step,
        cv=cv,
        scoring=scoring,
        min_features_to_select=min_core,
        n_jobs=1
    )

    # --- try weighted RFECV; fall back if sklearn doesn't support it ---
    try:
        rfecv.fit(
            X_core.values, y,
            groups=np.asarray(years),
            sample_weight=np.asarray(sample_weight, float)
        )
        info["rfecv_weighted_supported"] = True
    except (TypeError, ValueError):
        print("[RFECV] The current scikit-learn version does not support RFECV.fit(sample_weight=...). "
              "Falling back to unweighted RFECV (other stages still use weights: PI/OptunaCV/final training).")
        rfecv.fit(
            X_core.values, y,
            groups=np.asarray(years)
        )
        info["rfecv_weighted_supported"] = False

    # Keep plain Python strings (not numpy string scalars) in the result.
    selected_core = [name for name, kept in zip(keep_in, rfecv.support_) if kept]
    final = list(dict.fromkeys(always_keep + selected_core))
    return final, rfecv, info

# ============================ Optuna objective (decade-block CV, weighted) ============================

def objective(trial, X, y, years, sample_weights):
    """
    Optuna objective: mean weighted RMSE over decade-block CV folds.

    A CatBoost regressor with trial-suggested hyper-parameters is fitted on
    each training fold (with sample weights) using the validation fold as the
    weighted eval set.  The fold score is the weighted RMSE on that validation
    fold, or the plain RMSE if the validation weights do not sum to a positive
    finite number.

    Side effects on `trial`: user attributes "best_iters" (best iteration per
    fold) and "median_best_iter".

    Returns the mean fold score, or NaN when no fold was produced.
    """
    X = np.asarray(X)
    y = np.asarray(y)
    years = np.asarray(years)
    sw_all = np.asarray(sample_weights, float)

    # Search space first (suggestion order is kept fixed), constants after.
    params = {
        "iterations": trial.suggest_int("iterations", 200, 2500),
        "learning_rate": trial.suggest_float("learning_rate", 0.02, 0.08, log=True),
        "depth": trial.suggest_int("depth",6, 10),
        "l2_leaf_reg": trial.suggest_int("l2_leaf_reg", 5, 20),
        "bagging_temperature": trial.suggest_float("bagging_temperature", 0.1, 0.6),
        "bootstrap_type": "Bayesian",
        "loss_function": "RMSE",
        "eval_metric": "RMSE",
        "random_seed": RANDOM_SEED,
        "od_type": "Iter",
        "od_wait": 50,
        "use_best_model": True,
        "verbose": False
    }
    max_iterations = int(params["iterations"])

    splitter = DecadeBlockKFold(n_splits=5, random_state=RANDOM_SEED)
    fold_scores = []
    fold_best_iters = []

    for train_rows, valid_rows in splitter.split(X, y, groups=years):
        y_valid = y[valid_rows]
        w_valid = sw_all[valid_rows]

        valid_pool = Pool(X[valid_rows], y_valid, weight=w_valid)
        booster = CatBoostRegressor(**params, thread_count=N_THREADS)
        booster.fit(X[train_rows], y[train_rows],
                    sample_weight=sw_all[train_rows], eval_set=valid_pool)

        err = booster.predict(valid_pool) - y_valid
        weight_sum = np.sum(w_valid)
        if weight_sum > 0 and np.isfinite(weight_sum):
            score = float(np.sqrt(np.sum(w_valid * err * err) / weight_sum))
        else:
            # Unusable weights: fall back to the unweighted RMSE.
            score = float(np.sqrt(np.mean(err * err)))
        fold_scores.append(score)

        try:
            fold_best_iters.append(int(booster.get_best_iteration()))
        except Exception:
            # No best iteration available: record the configured maximum.
            fold_best_iters.append(max_iterations)

    median_iter = int(np.median(fold_best_iters)) if fold_best_iters else max_iterations
    trial.set_user_attr("best_iters", fold_best_iters)
    trial.set_user_attr("median_best_iter", median_iter)

    if not fold_scores:
        return np.nan
    return float(np.mean(fold_scores))

# ============================ Zone heatmap (unchanged) ============================

def plot_zone_feature_importance_heatmap(zone_importances,
                                         out_png,
                                         out_csv_matrix,
                                         out_csv_sum,
                                         feature_order=None,
                                         annotate=True):
    """
    Export per-zone feature importances as CSV tables and a heatmap.

    Parameters
    ----------
    zone_importances : list of dict
        One entry per zone with keys "feat_names" (list of str) and "fi"
        (importances aligned with "feat_names"); "zone_id" is used for the
        column label ("?" when absent).
    out_png, out_csv_matrix, out_csv_sum : str
        Output paths for the figure, the normalised feature-by-zone matrix and
        the per-feature summary (sum of normalised importance, number of zones
        in which the feature is present).
    feature_order : list of str, optional
        Row order; features not present in any zone are dropped.  By default
        rows follow first appearance across zones.
    annotate : bool
        Accepted for backward compatibility; currently not used.

    Importances are normalised per zone so that each column sums to 1.
    Features absent from a zone stay NaN and are masked in the figure.
    Nothing is written when there are no zones or no features; the figure is
    skipped when plotting is disabled.
    """
    if len(zone_importances) == 0:
        return

    zone_labels = [str(z.get("zone_id", "?")) for z in zone_importances]
    feat_sets = [set(z["feat_names"]) for z in zone_importances]
    all_features = set().union(*feat_sets)

    if feature_order is not None:
        features = [f for f in feature_order if f in all_features]
    else:
        # Order of first appearance across zones.
        seen = set()
        features = []
        for z in zone_importances:
            for f in z["feat_names"]:
                if f not in seen:
                    seen.add(f)
                    features.append(f)

    if len(features) == 0:
        return

    F = len(features)
    Z = len(zone_importances)
    M = np.full((F, Z), np.nan, dtype=float)

    for j, z in enumerate(zone_importances):
        name_to_val = dict(zip(z["feat_names"], np.asarray(z["fi"], float)))
        for i, feat in enumerate(features):
            val = name_to_val.get(feat, np.nan)
            M[i, j] = float(val) if np.isfinite(val) else np.nan

    # Column-wise normalisation; zones whose importances sum to <= 0 stay NaN.
    col_sums = np.nansum(M, axis=0)
    M_norm = np.full_like(M, np.nan, dtype=float)
    for j in range(Z):
        denom = col_sums[j]
        if np.isfinite(denom) and denom > 0:
            M_norm[:, j] = M[:, j] / denom

    zone_columns = [f"Zone_{lbl}" for lbl in zone_labels]
    df_matrix = pd.DataFrame(M_norm, index=features, columns=zone_columns)
    df_matrix.to_csv(out_csv_matrix)

    row_sum = np.nansum(M_norm, axis=1)
    presence = np.sum(np.isfinite(M), axis=1).astype(int)
    df_sum = pd.DataFrame({"Feature": features, "SumImportance": row_sum, "PresentCount": presence}).set_index("Feature")
    df_sum.to_csv(out_csv_sum)

    if not ENABLE_PLOTTING:
        return

    cmap = sns.color_palette("inferno", as_cmap=True)
    try:
        cmap.set_bad("#f0f0f0")
    except Exception:
        pass

    fig = plt.figure(figsize=(14, 6))
    ax0 = fig.add_subplot(111)

    data = df_matrix.values.astype(float)
    mask = ~np.isfinite(data)
    finite_vals = data[np.isfinite(data)]
    vmax = float(np.nanmax(finite_vals)) if finite_vals.size else 1.0
    if not np.isfinite(vmax) or vmax <= 0:
        vmax = 1.0
    norm = Normalize(vmin=0.0, vmax=vmax)

    # Pass explicit tick labels: a bare ndarray would otherwise be drawn with
    # integer indices, leaving rows (features) and columns (zones) unidentified.
    sns.heatmap(data, ax=ax0, mask=mask, cmap=cmap, norm=norm, cbar=True,
                cbar_kws={"label": "Normalized importance per zone"}, linewidths=0,
                xticklabels=zone_columns, yticklabels=list(features))

    fig.savefig(out_png, bbox_inches="tight")
    plt.close(fig)

# ============================ Single-zone training ============================

def run_single_zone(depth_out, depth, zone, train_df, test_df):
    """
    Train, evaluate and persist the model for one zone at one depth.

    Steps, in order:
      1. Select the zone's rows from `train_df` / `test_df` via "Zone_Merged".
      2. Build the candidate feature matrix (Chla log-transformed with an
         offset computed from the training rows).
      3. Optionally select features: weighted permutation importance ->
         top-k candidates -> RFECV with decade-block CV.
      4. Tune CatBoost hyper-parameters with Optuna.
      5. Refit on all training rows of the zone with the tuned parameters and
         the clipped median best-iteration count.
      6. Write `model.cbm`, `metadata.pkl` and `zone_summary.txt` into
         `<depth_out>/zone_<zone>/`.

    Parameters
    ----------
    depth_out : str
        Output directory of the current depth.
    depth :
        Depth label; only stored in the metric rows and the summary file.
    zone :
        Zone identifier matched against the "Zone_Merged" column.
    train_df, test_df : pandas.DataFrame
        Train / test tables containing "Zone_Merged", the target and time
        columns and (a subset of) the candidate features.  A "sample_weight"
        column in `train_df` is used when present.

    Returns
    -------
    tuple
        ``(metrics_rows, None, zone_importance)`` where `metrics_rows` holds
        ``[depth, zone, set_name, mse, rmse, mae, r2]`` lists for "train" (and
        "test" when test rows exist) and `zone_importance` is a dict with
        keys "feat_names", "fi", "weight", "zone_id".
        ``([], None, None)`` when the zone has fewer than MIN_SAMPLES training
        rows or no candidate feature is available.
    """
    # The zone directory is created even if the zone is skipped below.
    odir = os.path.join(depth_out, f"zone_{zone}")
    os.makedirs(odir, exist_ok=True)

    tr = train_df[train_df["Zone_Merged"] == zone].reset_index(drop=True)
    te = test_df[test_df["Zone_Merged"] == zone].reset_index(drop=True)

    if len(tr) < MIN_SAMPLES:
        return [], None, None

    # Uniform weights when the caller did not attach a sample_weight column.
    sw = tr["sample_weight"].values if "sample_weight" in tr.columns else np.ones(len(tr), dtype=float)

    candidates = [c for c in (CORE_FEATURES + AUX_FEATURES) if c in tr.columns]
    if len(candidates) == 0:
        return [], None, None

    # The Chla offset comes from the training rows only and is reused for test rows.
    chla_offset = None
    if "Chla" in tr.columns:
        chla_offset = compute_chla_offset(tr["Chla"].values)

    X_tr_df = transform_features(tr, candidates, chla_offset)
    y_tr = tr[TARGET].values
    years = tr[TIME].values

    always_keep_present = [f for f in ALWAYS_KEEP if f in X_tr_df.columns]

    pi_score = None
    rfecv_info = None

    if ENABLE_FEATURE_SELECTION:
        # Stage 1: rank all candidates by weighted permutation importance.
        pi_score = permutation_importance_cv_weighted(
            X_tr_df, y_tr, years, sw,
            n_splits=PERM_N_SPLITS,
            n_repeats=PERM_N_REPEATS,
            max_val_samples=PERM_MAX_VAL_SAMPLES
        )

        # Keep the top-k ranked features: max(10, 2*sqrt(p)) unless PERM_TOPK is set.
        if PERM_TOPK is None:
            k_auto = max(10, int(2 * sqrt(len(pi_score))))
            top_feats = list(pi_score.index[:k_auto])
        else:
            top_feats = list(pi_score.index[:min(int(PERM_TOPK), len(pi_score))])

        rfecv_candidates = list(dict.fromkeys(top_feats + always_keep_present))
        X_rfecv = X_tr_df[rfecv_candidates].copy()

        # Stage 2: recursive elimination on the shortlisted features.
        selected_feats, rfecv_obj, rfecv_info = rfecv_with_decade_blocks(
            X_rfecv, y_tr, years, sw,
            always_keep=always_keep_present,
            min_feats=max(RFECV_MIN_FEATS, len(always_keep_present)),
            step=RFECV_STEP,
            scoring=RFECV_SCORING
        )

        # Final feature order follows the permutation-importance ranking;
        # always-kept features missing from it are placed in front.
        selected_feats_sorted = [f for f in pi_score.index if f in selected_feats]
        for f in always_keep_present:
            if f not in selected_feats_sorted:
                selected_feats_sorted = [f] + selected_feats_sorted
    else:
        selected_feats_sorted = [f for f in candidates if f in X_tr_df.columns]
        for f in always_keep_present:
            if f not in selected_feats_sorted:
                selected_feats_sorted.append(f)

    X_tr = X_tr_df[selected_feats_sorted].values

    if not te.empty:
        X_te_df = transform_features(te, candidates, chla_offset)
        # Same column order as training; a column absent from the test frame is filled with NaN.
        X_te_df = X_te_df.reindex(columns=selected_feats_sorted, fill_value=np.nan)
        X_te = X_te_df.values
    else:
        X_te = None

    # Hyper-parameter search; each trial is scored by `objective`.
    study = optuna.create_study(direction="minimize", sampler=TPESampler(seed=RANDOM_SEED))
    study.optimize(lambda trial: objective(trial, X_tr, y_tr, years, sw),
                   n_trials=N_TRIALS_OPTUNA, show_progress_bar=False)

    best_params = study.best_trial.params.copy()
    best_iters = study.best_trial.user_attrs.get("best_iters", [])
    median_best_iter = int(np.median(best_iters)) if len(best_iters) > 0 else int(best_params.get("iterations", 800))

    best_params.update({
        "bootstrap_type": "Bayesian",
        "loss_function": "RMSE",
        "eval_metric": "RMSE",
        "random_seed": RANDOM_SEED
    })

    # The final model has no eval set, so early-stopping options are removed and
    # the iteration count is fixed to the per-fold median, clipped to [200, 2000].
    final_iters = int(np.clip(median_best_iter, 200, 2000))
    final_params = {**best_params, "iterations": final_iters, "verbose": False}
    for k in ["od_type", "od_wait", "use_best_model"]:
        final_params.pop(k, None)

    model = CatBoostRegressor(**final_params, thread_count=N_THREADS)
    model.fit(X_tr, y_tr, sample_weight=sw)
    model.save_model(os.path.join(odir, "model.cbm"))

    # Everything needed to rebuild the feature matrix and reproduce the run.
    metadata = {
        "features": selected_feats_sorted,
        "all_candidates": candidates,
        "always_keep": always_keep_present,
        "feature_selection_enabled": bool(ENABLE_FEATURE_SELECTION),
        "chla_offset": float(chla_offset) if chla_offset is not None else None,
        "params": model.get_params(),
        "n_train": int(len(tr)),
        "n_test": int(len(te)),
        "test_years": TEST_YEARS,
        "cv_best_iterations_per_fold": best_iters,
        "final_iterations": int(final_iters),
        "perm_importance": (pi_score.to_dict() if pi_score is not None else {}),
        "rfecv_info": rfecv_info,
        "density_weighting": {
            "enabled": bool(ENABLE_DENSITY_WEIGHTING),
            "grid_deg": float(WEIGHT_GRID_RESOLUTION_DEG),
            "time_yr": float(WEIGHT_TIME_RESOLUTION_YR),
            "formula": "1/sqrt(count), normalized mean=1"
        }
    }
    joblib.dump(metadata, os.path.join(odir, "metadata.pkl"))

    # Unweighted train metrics.
    y_tr_pred = model.predict(X_tr)
    metrics_rows = []

    mse = mean_squared_error(y_tr, y_tr_pred)
    rmse = float(np.sqrt(mse))
    mae = float(mean_absolute_error(y_tr, y_tr_pred))
    r2 = float(r2_score(y_tr, y_tr_pred))
    metrics_rows.append([depth, zone, "train", mse, rmse, mae, r2])

    calib_train = calibration_stats(y_tr, y_tr_pred)
    calib_test = {}

    # Unweighted test metrics, only when the zone has test rows.
    if X_te is not None and len(te) > 0:
        y_te = te[TARGET].values
        y_te_pred = model.predict(X_te)

        mse2 = mean_squared_error(y_te, y_te_pred)
        rmse2 = float(np.sqrt(mse2))
        mae2 = float(mean_absolute_error(y_te, y_te_pred))
        r22 = float(r2_score(y_te, y_te_pred))
        metrics_rows.append([depth, zone, "test", mse2, rmse2, mae2, r22])

        calib_test = calibration_stats(y_te, y_te_pred)

    # summary
    # Statistics over all completed Optuna trials (not only the best one).
    valid_trials = [t for t in study.trials if (t.state == optuna.trial.TrialState.COMPLETE and t.value is not None)]
    cv_rmses = [t.value for t in valid_trials]
    cv_mean, cv_std = (float(np.mean(cv_rmses)), float(np.std(cv_rmses))) if len(cv_rmses) > 0 else (np.nan, np.nan)

    # Human-readable report for this zone.
    with open(os.path.join(odir, "zone_summary.txt"), "w") as f:
        f.write(f"Depth: {depth}   Zone: {zone}\n")
        f.write(f"Train samples: {len(tr)}   Test samples: {len(te)}\n")
        f.write(f"Yearstart: {yearstart}\n")
        f.write(f"Test years: {TEST_YEARS}\n\n")
        f.write("[Density weighting]\n")
        f.write(f"  enabled: {ENABLE_DENSITY_WEIGHTING}\n")
        f.write(f"  grid_deg: {WEIGHT_GRID_RESOLUTION_DEG}\n")
        f.write(f"  time_bin_years: {WEIGHT_TIME_RESOLUTION_YR}\n")
        f.write("  w = 1/sqrt(count), normalized to mean=1\n\n")

        f.write("[Feature Selection]\n")
        f.write(f"  enabled: {ENABLE_FEATURE_SELECTION}\n")
        f.write(f"  candidate_features: {candidates}\n")
        f.write(f"  always_keep_present: {always_keep_present}\n")
        if ENABLE_FEATURE_SELECTION and (pi_score is not None):
            f.write(f"  permutation_importance_top: {list(pi_score.index[:min(30, len(pi_score))])}\n")
            f.write(f"  rfecv_selected: {selected_feats_sorted}\n")
            if rfecv_info is not None:
                f.write(f"  rfecv_weighted_supported: {rfecv_info.get('rfecv_weighted_supported')}\n")
                f.write(f"  rfecv_min_features_policy: {rfecv_info.get('policy')} "
                        f"(ratio={rfecv_info.get('ratio')}, abs_min={rfecv_info.get('abs_min')}, fixed_min={rfecv_info.get('fixed_min')})\n")
        else:
            f.write("  selection_skipped: used all available candidate features\n")
            f.write(f"  used_features: {selected_feats_sorted}\n")
        if chla_offset is not None:
            f.write(f"  chla_log10_offset: {chla_offset:.6g}\n")

        f.write("\n[HPO / Optuna]\n")
        f.write(f"  best_trial_value (weighted decade-block CV RMSE): {study.best_value:.6f}\n")
        f.write(f"  best_trial_params: {study.best_trial.params}\n")
        f.write(f"  completed_trials: {len(valid_trials)} / {len(study.trials)}\n")
        f.write(f"  best_iterations_per_fold: {best_iters}\n")
        f.write(f"  median_best_iter: {median_best_iter}\n")
        f.write(f"  final_training_iterations: {final_iters}\n")
        f.write(f"  CV RMSE stats over trials: mean={cv_mean:.6f}, std={cv_std:.6f}\n\n")

        f.write("[Metrics]\n")
        f.write("  Set      MSE        RMSE        MAE         R2\n")
        for row in metrics_rows:
            _, _, nm, mse_, rmse_, mae_, r2_ = row
            f.write(f"  {nm:<6} {mse_:10.6f} {rmse_:10.6f} {mae_:10.6f} {r2_:10.6f}\n")

        # Formats one calibration dict; missing keys are rendered as NaN.
        def _fmt_c(cal):
            return (f"slope={cal.get('slope', np.nan):.6f}, "
                    f"intercept={cal.get('intercept', np.nan):.6f}, "
                    f"bias={cal.get('bias', np.nan):.6f} (Pred-Obs), "
                    f"R2={cal.get('r2', np.nan):.6f}")
        f.write("\n[Calibration]\n")
        f.write(f"  train: {_fmt_c(calib_train)}\n")
        if calib_test:
            f.write(f"  test : {_fmt_c(calib_test)}\n")

    # Feature importances of the final model, returned for the cross-zone heatmap.
    fi_zone = np.asarray(model.get_feature_importance(), float)
    zh = {"feat_names": selected_feats_sorted, "fi": fi_zone, "weight": len(tr), "zone_id": zone}

    # Release the model and matrices before the next zone is processed.
    del model, X_tr, X_te
    plt.close("all")
    gc.collect()

    return metrics_rows, None, zh

# ============================ Per-depth pipeline ============================

def _predict_overall_test_for_depth(depth_out: str, test_df: pd.DataFrame) -> None:
    """
    Stitch per-zone test predictions into one table for this depth and score it.

    For every zone id found in ``test_df["Zone_Merged"]`` the function loads
    ``<depth_out>/zone_<id>/metadata.pkl`` and ``model.cbm``, rebuilds the feature
    matrix with ``transform_features`` (stored candidates, stored feature order),
    predicts that zone's rows and writes the result into a ``Pred`` column.

    Outputs written under ``depth_out``:
      - test_with_pred.csv        (ALL test rows; Pred stays NaN where no model ran)
      - overall_test_metrics.txt  (coverage counts, RMSE/MAE/R2, calibration)
      - test_panel.png            (only if ENABLE_PLOTTING and at least one scored row)

    Notes:
    - Metrics use only rows where both the target and Pred are finite.
    - A zone is counted as skipped (never fatal) when its model or metadata file
      is missing, unreadable, or incomplete, or when feature building, model
      loading, or prediction raises.
    - Returns None; returns early without writing anything when ``test_df`` is
      None/empty or no numeric zone id can be parsed.
    """
    if test_df is None or len(test_df) == 0:
        return

    test_df_total = test_df.copy()
    test_df_total["Pred"] = np.nan

    # Parse the zone column once; it feeds both the zone list and every row mask.
    zone_codes = pd.to_numeric(test_df_total["Zone_Merged"], errors="coerce")
    zones_in_test = sorted(zone_codes.dropna().astype(int).unique().tolist())
    if len(zones_in_test) == 0:
        return

    n_skipped_zones = 0

    for zone in zones_in_test:
        odir = os.path.join(depth_out, f"zone_{zone}")
        meta_path = os.path.join(odir, "metadata.pkl")
        model_path = os.path.join(odir, "model.cbm")
        if not (os.path.exists(meta_path) and os.path.exists(model_path)):
            n_skipped_zones += 1
            continue

        try:
            meta = joblib.load(meta_path)
        except Exception as exc:
            print(f"[WARN] zone {zone}: cannot read metadata ({exc}); zone skipped.")
            n_skipped_zones += 1
            continue

        feat_order = meta.get("features", [])
        candidates = meta.get("all_candidates", [])
        chla_offset = meta.get("chla_offset", None)

        if not feat_order or not candidates:
            n_skipped_zones += 1
            continue

        # Plain boolean mask: rows whose zone id is missing/unparsable compare as False.
        mask_te = (zone_codes == zone).fillna(False).to_numpy(dtype=bool)
        if not mask_te.any():
            continue

        df_part = test_df_total.loc[mask_te].copy()

        # Build X in the same way as training, using stored candidates and feature order
        try:
            Xf = transform_features(df_part, candidates, chla_offset)
            Xf = Xf.reindex(columns=feat_order, fill_value=np.nan)
            X = Xf.values
        except Exception as exc:
            print(f"[WARN] zone {zone}: feature construction failed ({exc}); zone skipped.")
            n_skipped_zones += 1
            continue

        # Model loading lives inside the try so one unreadable model file is
        # counted as a skipped zone instead of aborting the whole stitch.
        model = None
        try:
            model = CatBoostRegressor()
            model.load_model(model_path)
            test_df_total.loc[mask_te, "Pred"] = model.predict(X)
        except Exception as exc:
            print(f"[WARN] zone {zone}: model load/predict failed ({exc}); zone skipped.")
            n_skipped_zones += 1
        finally:
            del model
            gc.collect()

    # ---- Output stitched global test set ----
    test_csv = os.path.join(depth_out, "test_with_pred.csv")
    test_df_total.to_csv(test_csv, index=False)

    # ---- Overall metrics (rows with finite Pred + target) ----
    y_true = pd.to_numeric(test_df_total[TARGET], errors="coerce").to_numpy(dtype=float, copy=False)
    y_pred = pd.to_numeric(test_df_total["Pred"], errors="coerce").to_numpy(dtype=float, copy=False)
    m = np.isfinite(y_true) & np.isfinite(y_pred)

    n_total = int(len(test_df_total))
    n_eval = int(np.sum(m))

    metrics_path = os.path.join(depth_out, "overall_test_metrics.txt")
    with open(metrics_path, "w") as f:
        f.write("Overall test performance across all zone models (this depth)\n")
        f.write(f"  test_rows_total: {n_total}\n")
        f.write(f"  test_rows_with_pred: {n_eval}\n")
        f.write(f"  test_rows_missing_pred: {n_total - n_eval}\n")
        f.write(f"  zones_in_test: {len(zones_in_test)}\n")
        f.write(f"  zones_missing_model_or_meta_or_failed: {n_skipped_zones}\n\n")

        if n_eval <= 0:
            # The coverage header is still written; only the metric sections are omitted.
            f.write("No valid predictions available to compute metrics.\n")
            return

        yt = y_true[m]
        yp = y_pred[m]

        mse = float(mean_squared_error(yt, yp))
        rmse = float(np.sqrt(mse))
        mae = float(mean_absolute_error(yt, yp))
        r2 = float(r2_score(yt, yp))

        f.write("[Metrics]\n")
        f.write(f"  RMSE: {rmse:.6f}\n")
        f.write(f"  MAE : {mae:.6f}\n")
        f.write(f"  R2  : {r2:.6f}\n\n")

        cal = calibration_stats(yt, yp)
        f.write("[Calibration]\n")
        f.write(f"  slope={cal.get('slope', np.nan):.6f}, "
                f"intercept={cal.get('intercept', np.nan):.6f}, "
                f"bias={cal.get('bias', np.nan):.6f} (Pred-Obs), "
                f"R2={cal.get('r2', np.nan):.6f}\n")

    if ENABLE_PLOTTING and n_eval > 0:
        plot_test_panel(
            y_true[m], y_pred[m],
            out_path=os.path.join(depth_out, "test_panel.png")
        )

def process_depth(depth):
    """
    Run the whole training pipeline for one depth level.

    Steps: read ``<ROOT_DIR>/<depth>dbar/depth<depth>_TRAIN.csv``, filter rows
    (start year, zone 1..20, optional valid-feature filter), apply MERGE_MAP to
    get ``Zone_Merged``, split train/test by TEST_YEARS, attach density sample
    weights to the training rows, train every zone in parallel through
    ``run_single_zone``, then write the per-depth summaries (metrics table,
    train counts, feature-importance heatmap) and the stitched test-set outputs.

    Returns None. Prints a message and returns early when the CSV is missing.
    """
    global MERGE_MAP

    src_csv = os.path.join(ROOT_DIR, f"{depth}dbar", f"depth{depth}_TRAIN.csv")
    if not os.path.exists(src_csv):
        print(f"[Depth {depth}] Missing: {src_csv}")
        return

    data = pd.read_csv(src_csv, parse_dates=["Date"], low_memory=False)

    # Keep only rows at or after the configured first year.
    data = data[data[TIME] >= yearstart].copy()

    # Zone ids become nullable integers; anything outside 1..20 is dropped.
    data[ZONE] = pd.to_numeric(data[ZONE], errors="coerce").astype("Int64")
    zone_in_range = data[ZONE].between(1, 20, inclusive="both")
    data = data[zone_in_range].copy()

    def _merged_zone(z):
        # Missing ids pass through unchanged; known ids go through MERGE_MAP.
        if pd.notna(z):
            zi = int(z)
            return MERGE_MAP.get(zi, zi)
        return z

    data["Zone_Merged"] = data[ZONE].apply(_merged_zone)

    if ENFORCE_VALID_FEATURES:
        required = [Satname, timename, "Latitude", TempName, SalName]
        features_to_check = [c for c in required if c in data.columns]
        n_before = len(data)
        data = enforce_valid_feature_rows(
            data,
            features=features_to_check,
            sentinel_values=SENTINEL_VALUES,
            feature_ranges=FEATURE_RANGES
        )
        n_after = len(data)
        print(f"[Depth {depth}] Valid feature filter: kept {n_after}/{n_before}.")

    # Rows from TEST_YEARS form the held-out set; everything else is training data.
    held_out = data[TIME].isin(TEST_YEARS)
    train_df = data[~held_out].reset_index(drop=True).copy()
    test_df = data[held_out].reset_index(drop=True).copy()

    train_df["sample_weight"] = build_density_sample_weight(train_df)

    if len(train_df) > 0:
        sw = train_df["sample_weight"].values
        print(f"[Depth {depth}] sample_weight built. "
              f"stats: min={np.min(sw):.4f} max={np.max(sw):.4f} mean={np.mean(sw):.4f} "
              f"(grid={WEIGHT_GRID_RESOLUTION_DEG}°, time={WEIGHT_TIME_RESOLUTION_YR}yr)")

    depth_out = os.path.join(OUTPUT_DIR, f"depth_{depth}")
    os.makedirs(depth_out, exist_ok=True)

    zone_ids_sorted = sorted(train_df["Zone_Merged"].dropna().unique().astype(int))

    # One job per merged zone.
    zone_jobs = (
        delayed(run_single_zone)(depth_out, depth, int(zone), train_df, test_df)
        for zone in zone_ids_sorted
    )
    results = Parallel(
        n_jobs=N_ZONE_WORKERS,
        backend="loky",
        prefer="processes",
        pre_dispatch="n_jobs"
    )(zone_jobs)

    # Flatten per-zone metric rows and keep only non-empty importance records.
    overall_metrics = [row for rows, _, _ in results for row in (rows or [])]
    zone_importances = [zh for _, _, zh in results if zh]

    if overall_metrics:
        metrics_table = pd.DataFrame(
            overall_metrics,
            columns=["Depth", Zonename, "Set", "MSE", "RMSE", "MAE", "R2"]
        )
        metrics_table.to_csv(os.path.join(depth_out, "all_zones_metrics.csv"), index=False)

    zone_counts = train_df.groupby("Zone_Merged").size().reset_index(name="Train_Sample_Count")
    zone_counts.to_csv(os.path.join(depth_out, "zone_train_sample_counts.csv"), index=False)

    plot_zone_feature_importance_heatmap(
        zone_importances,
        out_png=os.path.join(depth_out, "zone_feature_importance_heatmap.png"),
        out_csv_matrix=os.path.join(depth_out, "zone_feature_importance_matrix.csv"),
        out_csv_sum=os.path.join(depth_out, "zone_feature_importance_sum.csv"),
        feature_order=CORE_FEATURES + AUX_FEATURES,
        annotate=True
    )

    # Overall/global test-set prediction + overall metrics across all zones.
    _predict_overall_test_for_depth(depth_out, test_df)

    print(f"[Depth {depth}] Done. Outputs in: {depth_out}")

# ============================ Hard memory cleanup ============================

def _hard_memory_cleanup():
    try:
        plt.close("all")
    except Exception:
        pass
    try:
        from joblib.externals.loky import get_reusable_executor
        get_reusable_executor().shutdown(wait=True, kill_workers=True)
    except Exception:
        pass
    for _ in range(2):
        gc.collect()
    try:
        import ctypes
        ctypes.CDLL("libc.so.6").malloc_trim(0)
    except Exception:
        pass

# ============================ Main ============================

if __name__ == "__main__":
    # Script entry point: echo the effective configuration, then train every
    # depth listed in `depthlist`.
    print(f"[Config] yearstart={yearstart}, TEST_YEARS={TEST_YEARS}")
    print(f"[Config] DensityWeighting={ENABLE_DENSITY_WEIGHTING} "
          f"(grid={WEIGHT_GRID_RESOLUTION_DEG}°, time={WEIGHT_TIME_RESOLUTION_YR}yr)")
    print(f"[Config] CV = DecadeBlockKFold (10-year blocks)")

    for d in depthlist:
        # adjust_features_and_merge is defined outside this chunk; presumably it
        # sets the depth-specific feature list / zone merge map that
        # process_depth reads from module globals -- TODO confirm.
        adjust_features_and_merge(d)
        process_depth(d)

    print("All depths done.")
    # Runs once after all depths (not per depth): closes figures, stops loky
    # workers, collects garbage and trims the libc heap.
    _hard_memory_cleanup()

In [None]:
# -*- coding: utf-8 -*-
import os, re, glob, warnings, argparse
warnings.filterwarnings("ignore")

# ========= Threads & parallelism (set BEFORE importing numeric libs) =========
# Thread budget for inference; override with the PRED_THREADS environment variable.
DEFAULT_THREADS = 16
THREADS = int(os.getenv("PRED_THREADS", str(DEFAULT_THREADS)))
# setdefault only fills in variables that are not already set, so values
# exported by the user (or by an earlier cell) are left untouched.
os.environ.setdefault("OMP_NUM_THREADS", str(THREADS))
os.environ.setdefault("MKL_NUM_THREADS", str(THREADS))
os.environ.setdefault("NUMEXPR_NUM_THREADS", str(THREADS))

# ========= Numeric/data & ML =========
import numpy as np
import pandas as pd
from functools import lru_cache
from catboost import CatBoostRegressor

# Optional dependencies
try:
    import lightgbm as lgb
except Exception:
    lgb = None
try:
    import xgboost as xgb
except Exception:
    xgb = None

import joblib
from joblib import Parallel, delayed

# ============== Optional dependencies for remap ==============
try:
    from netCDF4 import Dataset
except Exception:
    Dataset = None
try:
    from sklearn.neighbors import BallTree
    HAVE_SKLEARN = True
except Exception:
    BallTree = None
    HAVE_SKLEARN = False
try:
    from threadpoolctl import threadpool_limits
except Exception:
    threadpool_limits = None


# =============================================================================
# Configuration you must adapt to your environment (must match training outputs)
# =============================================================================

# Prediction input root directory (one subdir per depth, e.g. "1dbar", "10dbar"...)
INPUT_ROOT   = "/data/wang/Result_Data/allnodoxy"

# NOTE(review): OUTPUT_DIR, depthlist, Zonename, TempName, SalName, Satname,
# timename and SEASONS are not defined in this cell. They are presumably left in
# the notebook namespace by the training cell above; running this cell on its
# own raises NameError -- confirm the intended execution order.
MODELS_ROOT  = OUTPUT_DIR
DEPTH_LIST   = depthlist

# Filled with str.format-style fields `depth` and `season`.
CSV_NAME_TEMPLATE = "depth{depth}_{season}_TRAIN.csv"

# The self-assignments below are no-ops when the name already exists; they only
# make the dependency on the training-side names explicit.
Zonename = Zonename
TempName = TempName
SalName  = SalName
O2SAT_NAME = Satname
TIMENAME = timename

SEASONS = SEASONS

# Columns to keep when writing *_with_pred.csv (adjust if needed)
KEEP_COLS = ["Year", "Month", "Latitude", "Longitude", Zonename, "Oxygen"]

# Whether to delete the original CSV after writing the new CSV
DELETE_ORIGINAL = False


# Remap switch (inference-time boundary fusion)
REMAP_ENABLED_GLOBAL = True
# Smoothing distance in km; override with the REMAP_SMOOTH_KM environment variable.
REMAP_SMOOTH_KM = float(os.getenv("REMAP_SMOOTH_KM", "300.0"))
# Never more than 16 remap workers, regardless of THREADS.
REMAP_WORKERS   = min(THREADS, 16)

# MeanBiomes and LoS mask paths
BIOMES_NC       = os.getenv("BIOMES_NC", "/data/wang/Merage_Biomes_0p5deg.nc")
REMAP_MASK_NC   = os.getenv("REMAP_MASK_NC", "/data/wang/mask_lineofsight.nc")

# Sphere radius (km) used by haversine_km to turn angular distance into km.
EARTH_R_KM = 6371.0


# Whether to perform seam leveling (fit per-zone bias then smooth within region)
SEAM_LEVELING   = False
SEAM_K          = 10
SEAM_MAX_DIST_KM= 200.0
SEAM_RIDGE      = 1e-3

# Remap debugging
REMAP_DEBUG      = False
REMAP_TRACE_N    = 8
REMAP_TRACE_SEED = 42
REMAP_DIAG_CSV   = False


# Candidate aliases for temperature/salinity columns in CSV
# (consumed by normalize_temp_sal_columns when neither the canonical
# "Temperature"/"Salinity" column nor the target column is present)
TEMP_ALIASES = ["Temp", "Temperature"]
SAL_ALIASES  = ["Sal", "Salinity"]

# =============================================================================
# Global state for feature/merge policy (mirrors training-side structure)
# =============================================================================
# Both are overwritten by adjust_features_and_mergepred(depth).
AUX_FEATURES = []
MERGE_MAP = {}

# =============================================================================
# Zone merging policy (inference-side; keep consistent with training policy)
# =============================================================================
def adjust_features_and_mergepred(depth):
    """
    Set the inference-side feature/merge policy for one depth.

    Mirrors the structure of the training-side ``adjust_features_and_merge``:

    - AUX_FEATURES is reset to the full auxiliary feature list (the same list
      for every depth in this implementation).
    - MERGE_ZONES_1_TO_12 is False for 0 < depth <= 3000 and True otherwise.
    - MERGE_MAP is rebuilt: empty when zones are not merged, otherwise zones
      1..5 all map to zone 1.

    Updates the three module globals and prints the resulting policy.

    Returns:
        (MERGE_MAP, remap_enabled) where remap_enabled is False exactly when
        zones are merged.
    """
    global AUX_FEATURES, MERGE_ZONES_1_TO_12, MERGE_MAP

    depth_int = int(depth)
    AUX_FEATURES = [Satname, "U", "V", "SSH", "pCO2", "DIC", "Chla", "MLD", "pH", "CO2_flux", "Alkalinity"]

    # Keep consistent with training: only depths outside (0, 3000] merge zones.
    MERGE_ZONES_1_TO_12 = not (0 < depth_int <= 3000)

    if MERGE_ZONES_1_TO_12:
        # Deep ocean: collapse zones 1..5 into zone 1; remap is switched off.
        MERGE_MAP = dict.fromkeys((1, 2, 3, 4, 5), 1)
    else:
        MERGE_MAP = {}
    remap_enabled = not MERGE_ZONES_1_TO_12

    print(
        f"[POLICY] depth={depth_int} | AUX_FEATURES={AUX_FEATURES} | "
        f"MERGE_ZONES_1_TO_12={MERGE_ZONES_1_TO_12} | MERGE_MAP={MERGE_MAP} | "
        f"remap_enabled={remap_enabled}"
    )
    return MERGE_MAP, remap_enabled


# =============================================================================
# Utilities: file/depth discovery
# =============================================================================
def parse_depth_from_dir(dirname):
    """
    Extract the depth digits from a ``<N>dbar`` directory path.

    Args:
        dirname: Path whose last component is expected to end in ``<digits>dbar``.
            A trailing path separator is tolerated.

    Returns:
        The digits as a string (e.g. ``"10"``), or None when the last path
        component does not end in ``<digits>dbar``.
    """
    # normpath drops trailing separators; without it basename("/x/10dbar/") is ""
    # and a valid depth directory would be reported as None.
    base = os.path.basename(os.path.normpath(dirname))
    m = re.search(r"(\d+)dbar$", base)
    return m.group(1) if m else None

def find_depth_dirs(root):
    """
    List the sub-directories of ``root`` whose name ends in ``<digits>dbar``.

    Returns the paths as produced by ``glob`` (``root`` joined with the entry
    name), in glob order; plain files and non-matching directories are ignored.
    """
    depth_suffix = re.compile(r"\d+dbar$")
    matches = []
    for entry in glob.glob(os.path.join(root, "*")):
        if not os.path.isdir(entry):
            continue
        if depth_suffix.search(os.path.basename(entry)):
            matches.append(entry)
    return matches

def list_depths_to_run(input_root, depth_list_cfg=None, depth_list_cli=None):
    """
    Decide which depths to process.

    Source priority: a non-empty ``depth_list_cli``, then a non-empty
    ``depth_list_cfg``, otherwise every ``<N>dbar`` directory discovered under
    ``input_root``. Entries are converted to strings, de-duplicated and sorted
    by the integer formed from their digits. Depths whose
    ``<input_root>/<depth>dbar`` directory does not exist are dropped with a
    warning.

    Returns:
        List of depth strings, in ascending numeric order.
    """
    explicit = None
    for source in (depth_list_cli, depth_list_cfg):
        if source and len(source) > 0:
            explicit = source
            break

    if explicit is not None:
        candidates = [str(d) for d in explicit]
    else:
        # Nothing requested explicitly: fall back to what exists on disk.
        parsed = (parse_depth_from_dir(d) for d in find_depth_dirs(input_root))
        candidates = [dep for dep in parsed if dep]

    def _numeric_key(label):
        return int(re.sub(r"\D", "", label))

    runnable = []
    for dep in sorted(set(candidates), key=_numeric_key):
        depth_dir = os.path.join(input_root, f"{dep}dbar")
        if not os.path.isdir(depth_dir):
            print(f"[WARN] Depth directory does not exist (skipped): {depth_dir}")
            continue
        runnable.append(dep)
    return runnable


# =============================================================================
# Utilities: column normalization + base feature compatibility
# =============================================================================
def normalize_temp_sal_columns(df: pd.DataFrame,
                               temp_target: str = TempName,
                               sal_target: str = SalName,
                               temp_aliases=None,
                               sal_aliases=None) -> pd.DataFrame:
    """
    Rename temperature/salinity columns to the names the models expect.

    For each variable: if the canonical column ("Temperature" / "Salinity") is
    present and the target column is not, the canonical column is renamed.
    If neither the canonical nor the target column is present, the first alias
    found in ``df`` is renamed instead (aliases default to TEMP_ALIASES /
    SAL_ALIASES).

    Returns the renamed frame (a new object from ``DataFrame.rename``) or the
    input frame unchanged when nothing needs renaming. Prints the applied
    mapping when one exists.
    """
    present = df.columns
    variables = (
        ("Temperature", temp_target, temp_aliases, TEMP_ALIASES),
        ("Salinity", sal_target, sal_aliases, SAL_ALIASES),
    )

    renames = {}

    # Pass 1: canonical column names.
    for canonical, target, _, _ in variables:
        if canonical in present and target not in present:
            renames[canonical] = target

    # Pass 2: aliases, only when neither canonical nor target column exists.
    for canonical, target, aliases, default_aliases in variables:
        if canonical in present or target in present:
            continue
        pool = default_aliases if aliases is None else aliases
        for alias in pool:
            if alias in present:
                # setdefault: never override a rename already chosen for this column.
                renames.setdefault(alias, target)
                break

    if not renames:
        return df

    df = df.rename(columns=renames)
    print(f"  -> Mapped column names: {renames}")
    return df

def wrap_lon180(lon_deg):
    """
    Wrap longitudes in degrees into the interval (-180, 180].

    -180 is reported as +180. Scalar input returns a Python scalar; array-like
    input returns a NumPy array.
    """
    shifted = np.mod(np.asarray(lon_deg) + 180.0, 360.0) - 180.0
    wrapped = np.where(shifted == -180.0, 180.0, shifted)
    if np.ndim(wrapped) == 0:
        return wrapped.item()
    return wrapped


# =============================================================================
# Model discovery & loading (consistent with training output structure)
# =============================================================================
def detect_model_file(zdir):
    """
    Find the model file to use inside a zone directory.

    Known file names are probed in priority order (CatBoost first, then
    LightGBM sklearn pickle, LightGBM booster files, XGBoost booster JSON,
    XGBoost sklearn pickle).

    Returns:
        (model_type, model_path, all_found) for the highest-priority file that
        exists, where ``all_found`` lists every hit as ``"type:path"``.
        ``(None, None, [])`` when the directory holds no known model file.
    """
    priority = (
        ("catboost", "model.cbm"),
        ("lgbm_sklearn", "model_lgbm.pkl"),
        ("lgbm_booster", "model.lgb"),
        ("lgbm_booster", "model.txt"),
        ("xgb_booster", "model.json"),
        ("xgb_sklearn", "model_xgb.pkl"),
    )

    hits = []
    for kind, filename in priority:
        path = os.path.join(zdir, filename)
        if os.path.exists(path):
            hits.append((kind, path))

    labels = [f"{kind}:{path}" for kind, path in hits]
    if not hits:
        return None, None, labels

    best_kind, best_path = hits[0]
    return best_kind, best_path, labels

@lru_cache(maxsize=256)
def load_zone_cached(depth_str, zone_id):
    """
    Load (and memoize) the model and metadata of one depth/zone.

    Looks in ``<MODELS_ROOT>/depth_<depth_str>/zone_<zone_id>``. Results are
    cached per ``(depth_str, zone_id)``, including a None result.

    Returns:
        None when the zone directory, its ``metadata.pkl`` or any known model
        file is missing; otherwise a dict with the metadata, the loaded model,
        its type and path, and the feature-related metadata entries.

    Raises:
        RuntimeError: if the model file cannot be loaded (unknown type, missing
            optional library, or a loader error).
    """
    zone_dir = os.path.join(MODELS_ROOT, f"depth_{depth_str}", f"zone_{zone_id}")
    meta_file = os.path.join(zone_dir, "metadata.pkl")
    if not os.path.exists(zone_dir) or not os.path.exists(meta_file):
        return None

    mtype, mpath, all_found = detect_model_file(zone_dir)
    if mtype is None:
        print(f"    [WARN] No usable model file found in {zone_dir}.")
        return None
    if len(all_found) > 1:
        print(f"    [INFO] Multiple model files found; using priority: {mtype} -> {mpath}; all: {all_found}")

    meta = joblib.load(meta_file)

    def _quietly(action):
        # Thread-count tuning is optional: ignore models that do not support it.
        try:
            action()
        except Exception:
            pass

    try:
        if mtype == "catboost":
            model = CatBoostRegressor()
            model.load_model(mpath)

        elif mtype in ("lgbm_sklearn", "lgbm_booster"):
            if lgb is None:
                raise RuntimeError("lightgbm is not installed (pip install lightgbm).")
            if mtype == "lgbm_sklearn":
                model = joblib.load(mpath)
                _quietly(lambda: model.set_params(n_jobs=THREADS))
            else:
                model = lgb.Booster(model_file=mpath)
                _quietly(lambda: model.reset_parameter({"num_threads": THREADS}))

        elif mtype in ("xgb_sklearn", "xgb_booster"):
            if xgb is None:
                raise RuntimeError("xgboost is not installed (pip install xgboost).")
            if mtype == "xgb_sklearn":
                model = joblib.load(mpath)
                _quietly(lambda: model.set_params(n_jobs=THREADS))
            else:
                model = xgb.Booster()
                model.load_model(mpath)
                _quietly(lambda: model.set_param({"nthread": THREADS}))

        else:
            raise RuntimeError(f"Unknown model type: {mtype}")

    except Exception as e:
        raise RuntimeError(f"Failed to load model ({mtype}): {e}")

    bundle = {
        "meta": meta,
        "model_type": mtype,
        "model": model,
        "model_path": mpath,
    }
    for key in ("features", "core_cols", "aux_pool", "selected_aux"):
        bundle[key] = meta.get(key, [])
    bundle["chla_offset"] = meta.get("chla_offset", None)
    return bundle


# =============================================================================
# Feature matrix construction (compat: fill missing weight features if needed)
# =============================================================================
def apply_chla_log(arr_like, offset):
    """
    Return ``log10(max(x, 0) + offset)`` element-wise.

    Non-finite inputs (NaN, +/-inf) yield NaN, negative inputs are treated as 0,
    and ``offset=None`` falls back to 1e-6.
    """
    shift = 1e-6 if offset is None else offset
    values = np.asarray(arr_like, float)
    finite_only = np.where(np.isfinite(values), values, np.nan)
    return np.log10(np.maximum(0, finite_only) + shift)

def _maybe_fill_weight_feature(df_part: pd.DataFrame, col: str):
    """
    Inference compatibility:
    If training included a density-weight feature but inference CSV lacks it,
    fill it with constant 1.0 so the model can still run.
    """
    if col in df_part.columns:
        return
    if col.lower() in {"w_density", "density_weight", "w_den", "w_dens", "w_density_5deg10yr"}:
        df_part[col] = 1.0

def make_feature_matrix(df_part, zone_meta, return_names=False):
    """
    Build the float feature matrix for one zone, in the model's column order.

    The column order comes from ``zone_meta["features"]``; when that is empty it
    falls back to ``core_cols`` followed by ``selected_aux``. The zone id columns
    (``Zonename`` and "Zone_Merged") are never used as features. Features missing
    from ``df_part`` stay NaN, "Chla" is log-transformed with the stored offset,
    and all other columns are coerced to numeric.

    Side effect: missing density-weight feature columns are added to ``df_part``
    in place (see ``_maybe_fill_weight_feature``).

    Returns:
        The ``(n_rows, n_features)`` array, or ``(array, feature_names)`` when
        ``return_names`` is True.
    """
    feat_order = zone_meta.get("features", [])
    if not feat_order:
        fallback = list(zone_meta.get("core_cols", []) or []) + list(zone_meta.get("selected_aux", []) or [])
        feat_order = [name for name in fallback if name]

    # Zone ids are hard region labels, not model inputs.
    excluded = {Zonename, "Zone_Merged"}
    feat_order = [name for name in feat_order if name not in excluded]

    # Compatibility: fill possible "weight feature columns"
    for name in feat_order:
        _maybe_fill_weight_feature(df_part, name)

    available = df_part.columns
    chla_offset = zone_meta.get("chla_offset", None)
    Xf = np.full((len(df_part), len(feat_order)), np.nan, dtype=float)

    for j, name in enumerate(feat_order):
        if name not in available:
            continue  # column stays NaN
        if name == "Chla":
            Xf[:, j] = apply_chla_log(df_part["Chla"], chla_offset)
        else:
            Xf[:, j] = pd.to_numeric(df_part[name], errors="coerce").to_numpy(dtype=float, copy=False)

    return (Xf, feat_order) if return_names else Xf

def predict_zone_block(df_part, zone_meta):
    """
    Predict one block of rows with the model stored in ``zone_meta``.

    ``zone_meta`` is the dict produced by ``load_zone_cached``; its
    ``model_type`` selects the backend-specific predict call.

    Returns:
        1-D float array of predictions.

    Raises:
        RuntimeError: for an unknown model type or any backend failure
            (the original error text is embedded in the message).
    """
    mtype = zone_meta["model_type"]
    model = zone_meta["model"]
    X, feat_names = make_feature_matrix(df_part, zone_meta, return_names=True)

    try:
        if mtype == "catboost":
            yhat = model.predict(X, thread_count=THREADS)
        elif mtype in ("lgbm_sklearn", "lgbm_booster"):
            # Some predict() signatures reject num_threads; retry without it.
            try:
                yhat = model.predict(X, num_threads=THREADS)
            except TypeError:
                yhat = model.predict(X)
        elif mtype == "xgb_sklearn":
            yhat = model.predict(X)
        elif mtype == "xgb_booster":
            dmat = xgb.DMatrix(
                X, feature_names=feat_names,
                nthread=THREADS, enable_categorical=False, missing=np.nan
            )
            yhat = model.predict(dmat, validate_features=False)
        else:
            raise RuntimeError(f"Unknown model type: {mtype}")
        return np.asarray(yhat, float)
    except Exception as e:
        raise RuntimeError(f"{mtype} prediction failed: {e}")


def apply_zone_merge(df, zone_col=Zonename, merge_map=None, out_col="Zone_Merged"):
    """
    Add the merged-zone column to ``df`` (in place) and return ``df``.

    ``zone_col`` is coerced to nullable integers and every value found in
    ``merge_map`` is replaced by its mapped zone. Map entries whose key or value
    cannot be converted to int are ignored; ``merge_map=None`` means no merging.
    """
    raw_map = {} if merge_map is None else merge_map

    int_map = {}
    for src, dst in raw_map.items():
        try:
            dst_i = int(dst)
            src_i = int(src)
        except Exception:
            continue  # skip entries that are not integer-like
        int_map[src_i] = dst_i

    zone_ids = pd.to_numeric(df[zone_col], errors="coerce").astype("Int64")
    df[out_col] = zone_ids.replace(int_map)
    return df


# =============================================================================
# Remap: MeanBiomes reading, adjacency, and shared-boundary distances
# =============================================================================
@lru_cache(maxsize=4)
def load_biomes(nc_path):
    """
    Read the biome grid from a NetCDF file (cached per path).

    Reads the variables ``lat``, ``lon`` and ``MeanBiomes`` as float64 and
    transposes ``MeanBiomes`` so the returned grid is indexed ``[lat, lon]``.

    Returns:
        (lat, lon, Z)

    Raises:
        RuntimeError: if netCDF4 is not installed.

    Note: the result is memoized, so every caller receives the same array
    objects -- copy before modifying them in place.
    """
    if Dataset is None:
        raise RuntimeError("netCDF4 is not installed; cannot load MeanBiomes (pip install netCDF4).")
    # Context manager: the file handle is released even if a variable is missing
    # or a read fails.
    with Dataset(nc_path, "r") as ds:
        lon = ds["lon"][:].astype(np.float64)          # (360,) per original author note
        lat = ds["lat"][:].astype(np.float64)          # (180,) per original author note
        Z   = ds["MeanBiomes"][:].astype(np.float64)   # stored as (lon, lat)
    return lat, lon, Z.T  # -> (lat, lon)

def apply_merge_map_to_grid(Z, merge_map):
    """
    Relabel zone ids on a grid according to ``merge_map``.

    Every cell is looked up once against the ORIGINAL grid values, so a chained
    map such as ``{1: 2, 2: 3}`` sends zone 1 to 2 (not on to 3). This matches
    the one-lookup-per-value merging applied to the tabular zone column.

    Args:
        Z: Array of zone ids.
        merge_map: Mapping old id -> new id. Entries whose key or value cannot
            be converted to int are ignored.

    Returns:
        ``Z`` itself when ``merge_map`` is empty/None, otherwise a relabelled
        copy (the input array is never modified).
    """
    if not merge_map:
        return Z
    Zm = Z.copy()
    for old, new in merge_map.items():
        try:
            old_i = int(old)
            new_i = int(new)
        except (TypeError, ValueError, OverflowError):
            continue
        # Compare against the untouched input so cells rewritten by an earlier
        # entry are not rewritten again by a later one.
        mask = (Z == old_i)
        if np.any(mask):
            Zm[mask] = new_i
    return Zm

def compute_pair_adjacency_and_shared_boundaries(lat, lon, Z, region_ids):
    """
    Find which regions touch each other on the grid and collect the grid cells
    along each shared boundary.

    Args:
        lat: 1-D latitude axis in degrees (indexes the first axis of ``Z``).
        lon: 1-D longitude axis in degrees (indexes the second axis of ``Z``).
        Z: 2-D grid of region ids, indexed ``[lat, lon]``.
        region_ids: Iterable of region ids to consider; cells holding any other
            value are ignored.

    Returns:
        (adj, pair_btrees, pair_bcoords)
        - adj: dict region id -> set of neighbouring region ids.
        - pair_bcoords: dict keyed by an ordered pair ``(a, b)`` -> array of
          ``[lat_rad, lon_rad]`` rows for boundary cells. For key ``(a, b)`` the
          stored cells are the ones holding region ``b`` (the second id).
          Longitudes are wrapped with ``wrap_lon180`` before conversion to radians.
        - pair_btrees: same keys -> haversine ``BallTree`` over those points, or
          None when scikit-learn is unavailable or the pair has fewer than 2 points.
    """
    nlat, nlon = Z.shape
    region_ids = set(region_ids)
    # True where the cell belongs to one of the requested regions.
    valid = np.isin(Z, list(region_ids))

    adj = {rid: set() for rid in region_ids}
    # (region_a, region_b) -> list of (lat_index, lon_index) boundary cells
    pair_coords = {}

    def _add_pair(j, i, ilat, ilon):
        key = (int(j), int(i))
        pair_coords.setdefault(key, []).append((int(ilat), int(ilon)))

    # Horizontal neighbors
    # Compare each longitude column with the column to its east.
    for jj in range(nlon - 1):
        a = Z[:, jj]
        b = Z[:, jj + 1]
        m = (a != b) & valid[:, jj] & valid[:, jj + 1]
        if not np.any(m):
            continue
        rows = np.where(m)[0]
        for ii in rows:
            ra = int(a[ii]); rb = int(b[ii])
            if (ra not in region_ids) or (rb not in region_ids):
                continue
            adj[ra].add(rb); adj[rb].add(ra)
            _add_pair(ra, rb, ii, jj + 1)  # i-side boundary points
            _add_pair(rb, ra, ii, jj)

    # Vertical neighbors
    # Compare each latitude row with the next row.
    for ii in range(nlat - 1):
        a = Z[ii, :]
        b = Z[ii + 1, :]
        m = (a != b) & valid[ii, :] & valid[ii + 1, :]
        if not np.any(m):
            continue
        cols = np.where(m)[0]
        for jj in cols:
            ra = int(a[jj]); rb = int(b[jj])
            if (ra not in region_ids) or (rb not in region_ids):
                continue
            adj[ra].add(rb); adj[rb].add(ra)
            _add_pair(ra, rb, ii + 1, jj)
            _add_pair(rb, ra, ii, jj)

    # Wrap-around neighbors (first longitude column against the last one)
    a = Z[:, 0]
    b = Z[:, -1]
    m = (a != b) & valid[:, 0] & valid[:, -1]
    if np.any(m):
        rows = np.where(m)[0]
        for ii in rows:
            ra = int(a[ii]); rb = int(b[ii])
            if (ra not in region_ids) or (rb not in region_ids):
                continue
            adj[ra].add(rb); adj[rb].add(ra)
            _add_pair(ra, rb, ii, nlon - 1)
            _add_pair(rb, ra, ii, 0)

    # Convert the grid axes to radians once; boundary cells index into these.
    lat_rad = np.deg2rad(lat.astype(np.float64))
    lon180  = wrap_lon180(lon.astype(np.float64))
    lon_rad = np.deg2rad(lon180.astype(np.float64))

    pair_bcoords = {}
    pair_btrees  = {}

    for key, ij_list in pair_coords.items():
        if not ij_list:
            pair_bcoords[key] = None
            pair_btrees[key]  = None
            continue
        ii = np.fromiter((p[0] for p in ij_list), dtype=np.int64, count=len(ij_list))
        jj = np.fromiter((p[1] for p in ij_list), dtype=np.int64, count=len(ij_list))
        # One [lat, lon] row (radians) per boundary cell.
        coords = np.c_[lat_rad[ii], lon_rad[jj]]
        pair_bcoords[key] = coords
        # A tree is only built when scikit-learn is present and there are >= 2 points.
        if HAVE_SKLEARN and coords.shape[0] >= 2:
            pair_btrees[key] = BallTree(coords, metric="haversine")
        else:
            pair_btrees[key] = None

    return adj, pair_btrees, pair_bcoords

def nearest_index_nonuniform(axis, values):
    """
    Index of the nearest ``axis`` entry for each value, for a monotonic axis
    with arbitrary (non-uniform) spacing.

    Args:
        axis: 1-D monotonic coordinate array, ascending or descending.
        values: Scalar or array of query coordinates.

    Returns:
        int64 array of indices into ``axis`` (same shape as ``values``), or a
        Python int for scalar input. Queries outside the axis range map to the
        nearest end. On an exact tie the neighbour with the smaller coordinate
        wins.
    """
    axis = np.asarray(axis)
    values = np.asarray(values)

    # searchsorted needs ascending order; search a reversed view of a descending
    # axis (e.g. latitude stored north to south) and map the indices back after.
    descending = axis.size > 1 and bool(axis[0] > axis[-1])
    asc = axis[::-1] if descending else axis

    idx = np.searchsorted(asc, values, side="left")
    lower = np.clip(idx - 1, 0, asc.size - 1)
    upper = np.clip(idx,     0, asc.size - 1)
    choose_upper = np.abs(asc[upper] - values) < np.abs(asc[lower] - values)
    out = np.where(choose_upper, upper, lower).astype(np.int64)

    if descending:
        out = (axis.size - 1) - out
    return out if values.ndim > 0 else int(out)

def map_lon_to_grid(lon_vals_deg, grid_lons_deg):
    """
    Express longitudes in the same convention as the grid axis.

    If every (non-NaN) grid longitude lies in [0, 360] the values are reduced
    modulo 360; otherwise they are wrapped with ``wrap_lon180``.
    Scalar input returns a float, array input returns an array.
    """
    lon_vals_deg = np.asarray(lon_vals_deg)

    grid_min = np.nanmin(grid_lons_deg)
    grid_max = np.nanmax(grid_lons_deg)
    grid_is_0_360 = (grid_min >= 0.0) and (grid_max <= 360.0)

    mapped = np.mod(lon_vals_deg, 360.0) if grid_is_0_360 else wrap_lon180(lon_vals_deg)

    if lon_vals_deg.ndim > 0:
        return mapped
    return float(mapped)

def _load_mask_indices(mask_nc_path, lat_vals, lon_vals):
    if (mask_nc_path is None) or (not os.path.isfile(mask_nc_path)):
        raise RuntimeError(f"Mask file not found: {mask_nc_path}")
    with Dataset(mask_nc_path, "r") as ds:
        mlat = ds["lat"][:].astype(np.float64)
        mlon = ds["lon"][:].astype(np.float64)
        mmask = ds["remap_mask_los"][:].astype(bool)

    ii = nearest_index_nonuniform(mlat, lat_vals.astype(np.float64))
    lon_vals_grid = map_lon_to_grid(lon_vals.astype(np.float64), mlon)
    jj = nearest_index_nonuniform(mlon, lon_vals_grid)
    allow = mmask[ii, jj]
    return allow

def haversine_km(lat1, lon1, lat2, lon2):
    """
    Great-circle distance in kilometres between two points on a sphere of
    radius EARTH_R_KM.

    All four arguments are angles in RADIANS (scalars or broadcastable arrays).
    """
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    a = np.sin(dlat/2.0)**2 + np.cos(lat1)*np.cos(lat2)*np.sin(dlon/2.0)**2
    # Rounding (notably with float32 inputs) can push `a` marginally outside
    # [0, 1] near antipodal points, which would make arcsin return NaN.
    a = np.clip(a, 0.0, 1.0)
    return 2.0 * EARTH_R_KM * np.arcsin(np.sqrt(a))

def remap_smooth_zoneaware_inplace(
    df, nc_path,
    depth_str,
    smooth_km=250.0,
    workers=8,
    mask_nc_path=None,
    merge_map=None,
    debug=False,
    trace_n=0,
    trace_seed=42,
    diag_csv=False
):
    """Blend each sample's oxygen value with neighbor-zone predictions near zone boundaries.

    Works in place on ``df``: writes the columns ``RegionID`` (zone id as nullable
    Int64) and ``Oxygen_remap`` and returns ``None``.

    For every sample that (a) belongs to a usable region, (b) has non-null
    oxygen/latitude/longitude and (c) is allowed by the mask file, and for every
    region adjacent to the sample's own region, the distance ``d`` (km) from the
    sample to the shared boundary is computed.  If ``d <= smooth_km`` the
    neighbor zone's model prediction is mixed in with weight
    ``((smooth_km - d) / smooth_km) ** 2``; the sample's own value always has
    weight 1.  ``Oxygen_remap`` is the weighted mean.  All other rows keep the
    numeric copy of ``Oxygen``.

    Parameters:
        df: DataFrame with columns ``Latitude``, ``Longitude``, ``Oxygen`` and the
            zone column named by the module-level ``Zonename``.
        nc_path: path handed to ``load_biomes`` to obtain the zone grid.
        depth_str: depth key handed to ``load_zone_cached`` for neighbor models.
        smooth_km: width of the blending band; ``None`` or ``<= 0`` disables
            blending (``Oxygen_remap`` is then just the numeric ``Oxygen``).
        workers: number of joblib threads; ``<= 1`` runs sequentially.
        mask_nc_path: NetCDF mask file; required when blending is active.
        merge_map: zone merge mapping applied to the grid via
            ``apply_merge_map_to_grid``.
        debug: print summary statistics.
        trace_n, trace_seed, diag_csv: accepted but not referenced in this body.

    Raises:
        RuntimeError: if a required column is missing, or if ``mask_nc_path`` is
            ``None`` (or the mask file does not exist) while blending is active.

    NOTE(review): results are written back with ``oxy_out[idx_global] = vals``
    where ``idx_global`` holds index *labels*; this is only correct when ``df``
    has a default 0..n-1 index -- confirm against callers.
    """
    if smooth_km is None or smooth_km <= 0:
        # Smoothing disabled: only expose the two output columns.
        # NOTE(review): df.get(...) yields None when the zone column is absent;
        # unclear whether the subsequent .astype("Int64") works then -- confirm.
        df["RegionID"]     = pd.to_numeric(df.get(Zonename, None), errors="coerce").astype("Int64")
        df["Oxygen_remap"] = pd.to_numeric(df["Oxygen"], errors="coerce")
        return

    needed = ["Latitude", "Longitude", "Oxygen", Zonename]
    for c in needed:
        if c not in df.columns:
            raise RuntimeError(f"Remap requires column: {c}")

    # Default output: every row starts as its own (un-blended) prediction.
    df["RegionID"]     = pd.to_numeric(df[Zonename], errors="coerce").astype("Int64")
    df["Oxygen_remap"] = pd.to_numeric(df["Oxygen"], errors="coerce")

    lat_g, lon_g, Z = load_biomes(nc_path)
    Z = apply_merge_map_to_grid(Z, merge_map)

    # Positive, integer-valued zone ids present on the grid.
    regions_from_grid = set(
        int(v) for v in np.unique(Z[np.isfinite(Z)])
        if (v > 0) and (float(int(v)) == float(v))
    )
    regions_from_data = set(
        int(v) for v in df["RegionID"].dropna().unique()
        if v > 0
    )
    # Prefer regions present in both; fall back to data-only, then grid-only.
    region_ids = sorted((regions_from_grid & regions_from_data) or regions_from_data or regions_from_grid)
    if not region_ids:
        return

    df_valid = df[
        df["RegionID"].isin(region_ids)
        & df["Oxygen_remap"].notna()
        & df["Latitude"].notna()
        & df["Longitude"].notna()
    ].copy()
    if df_valid.empty:
        return

    df_valid["Latitude"]  = df_valid["Latitude"].astype("float32")
    df_valid["Longitude"] = wrap_lon180(df_valid["Longitude"].astype("float32").values)
    df_valid["Oxygen"]    = df_valid["Oxygen_remap"].astype("float32")
    df_valid["RegionID"]  = df_valid["RegionID"].astype(int)

    if mask_nc_path is None:
        raise RuntimeError("mask_nc_path must be provided (LoS mask NetCDF).")

    # Boolean per-sample flag: True where the mask permits smoothing.
    allow_valid = _load_mask_indices(
        mask_nc_path,
        df_valid["Latitude"].values,
        df_valid["Longitude"].values
    )

    n_valid = len(df_valid)
    n_allow = int(allow_valid.sum())
    if debug:
        print(f"[REMAP] Samples inside mask: {n_allow}/{n_valid} ({(n_allow/max(1,n_valid)):.1%})")
    if not np.any(allow_valid):
        return

    adj, pair_btrees, pair_bcoords = compute_pair_adjacency_and_shared_boundaries(lat_g, lon_g, Z, region_ids)

    # Flat numpy views of the valid subset, shared (read-only) by all workers.
    valid_lat   = df_valid["Latitude"].astype(np.float64).to_numpy()
    valid_lon   = df_valid["Longitude"].astype(np.float64).to_numpy()
    valid_oxy   = df_valid["Oxygen"].astype(np.float64).to_numpy()
    valid_rid   = df_valid["RegionID"].to_numpy()
    valid_index = df_valid.index.to_numpy()

    # Per-call memo of neighbor zone metadata (on top of load_zone_cached).
    zone_meta_cache = {}
    def _get_zone_meta(zid: int):
        zmeta = zone_meta_cache.get(zid, None)
        if zmeta is None:
            zmeta = load_zone_cached(depth_str, zid)
            zone_meta_cache[zid] = zmeta
        return zmeta

    # Diagnostics accumulated across regions; only consumed by the debug report.
    contrib_counts = []
    wsum_list = []
    dist_all = []
    w_all = []
    n_smoothed = 0

    def process_one_region(rid_j: int):
        # NOTE(review): n_smoothed is incremented from worker threads without a
        # lock when workers > 1; it only feeds the debug printout.
        nonlocal n_smoothed

        # Positions (within the valid subset) of mask-allowed samples of region j.
        rel = np.where((valid_rid == rid_j) & allow_valid)[0]
        if rel.size == 0:
            return (np.empty(0, dtype=int), np.empty(0, dtype=float))

        idx_global = valid_index[rel]
        lat_rad = np.deg2rad(valid_lat[rel])
        lon_rad = np.deg2rad(valid_lon[rel])

        # Weighted-mean accumulators: the sample's own value enters with weight 1.
        y_sum = valid_oxy[rel].astype(np.float64).copy()
        w_sum = np.ones(rel.size, dtype=np.float64)

        Xi = np.c_[lat_rad, lon_rad]
        neighs = sorted(adj.get(rid_j, set()))
        per_point_contrib_n = np.zeros(rel.size, dtype=int)

        for rid_i in neighs:
            key = (int(rid_j), int(rid_i))
            coords_ji = pair_bcoords.get(key, None)
            if coords_ji is None or (isinstance(coords_ji, np.ndarray) and coords_ji.size == 0):
                continue

            tree_ji = pair_btrees.get(key, None)
            if tree_ji is not None:
                # Tree query returns distances in radians; convert to km.
                dd, _ = tree_ji.query(Xi, k=1)
                d_km = dd[:, 0] * EARTH_R_KM
            else:
                # No tree for this pair: brute-force nearest boundary point,
                # processed in chunks of 4096 samples to bound memory.
                latb = coords_ji[:, 0]
                lonb = coords_ji[:, 1]
                d_km = np.empty(Xi.shape[0], dtype=np.float64)
                for s in range(0, Xi.shape[0], 4096):
                    e = min(s + 4096, Xi.shape[0])
                    d = haversine_km(
                        lat_rad[s:e, None], lon_rad[s:e, None],
                        latb[None, :],      lonb[None, :]
                    )
                    d_km[s:e] = d.min(axis=1)

            in_band = d_km <= smooth_km
            if not in_band.any():
                continue

            # Quadratic taper: weight 1 on the boundary, 0 at distance smooth_km.
            wi = ((smooth_km - d_km[in_band]) / max(smooth_km, 1e-6)) ** 2
            rows_sel = idx_global[in_band]

            zmeta_i = _get_zone_meta(int(rid_i))
            if zmeta_i is None:
                continue

            try:
                # Predict the in-band samples with the *neighbor* zone's model.
                y_hat_i = predict_zone_block(df.loc[rows_sel], zmeta_i)
            except Exception as e:
                print(f"    [WARN] Neighbor zone {rid_i} prediction failed (skipped): {e}")
                continue

            y_sum[in_band] += wi * y_hat_i
            w_sum[in_band] += wi
            per_point_contrib_n[in_band] += 1

            dist_all.extend(d_km[in_band].tolist())
            w_all.extend(wi.tolist())

        contrib_counts.extend(per_point_contrib_n.tolist())
        n_smoothed += int(np.sum(per_point_contrib_n > 0))
        wsum_list.extend(w_sum.tolist())

        out = y_sum / np.where(w_sum == 0.0, 1.0, w_sum)
        return (idx_global, out)

    def _run_parallel():
        # One task per region; threads are used so workers share df and the caches.
        if workers > 1:
            return Parallel(n_jobs=workers, prefer="threads")(
                delayed(process_one_region)(rid_j) for rid_j in region_ids
            )
        else:
            return [process_one_region(rid_j) for rid_j in region_ids]

    # When threadpoolctl is available, cap native thread pools to 1 while the
    # region tasks run.
    if threadpool_limits is not None:
        with threadpool_limits(limits=1):
            results = _run_parallel()
    else:
        results = _run_parallel()

    # Scatter blended values back into the full-length column.
    # NOTE(review): idx_global are index labels used as positions (see docstring).
    oxy_out = df["Oxygen_remap"].astype(np.float64).values.copy()
    for idx_global, vals in results:
        if idx_global.size:
            oxy_out[idx_global] = vals
    df["Oxygen_remap"] = oxy_out

    if debug:
        if len(contrib_counts) == 0:
            print("[REMAP] No masked points entered the smoothing band; output equals in-zone predictions.")
        else:
            cc = np.array(contrib_counts, int)
            ws = np.array(wsum_list, float)
            def _q(a, q):
                return float(np.percentile(a, q)) if a.size else float('nan')
            print(f"[REMAP] Samples with neighbor contributions: {n_smoothed}/{n_allow} ({(n_smoothed/max(1,n_allow)):.1%})")
            print(f"[REMAP] Neighbor-count per sample mean/median/p95 = {cc.mean():.2f} / {_q(cc,50):.2f} / {_q(cc,95):.2f}")
            print(f"[REMAP] w_sum mean/median/p95/min/max = {ws.mean():.3f} / {_q(ws,50):.3f} / {_q(ws,95):.3f} / {ws.min():.3f} / {ws.max():.3f}")
            if dist_all:
                da = np.array(dist_all, float); wa = np.array(w_all, float)
                print(f"[REMAP] Neighbor distance d(km) mean/median/p95 = {da.mean():.1f} / {_q(da,50):.1f} / {_q(da,95):.1f}")
                print(f"[REMAP] Weight w mean/median/p95 = {wa.mean():.4f} / {_q(wa,50):.4f} / {_q(wa,95):.4f}")


# =============================================================================
# Writer: chunked output to *_with_pred.csv
# =============================================================================
def save_new_csv(orig_path,
                 df,
                 delete_original=False,
                 keep_cols=None,
                 chunk_size=1_000_000):
    """Write ``df`` next to ``orig_path`` as ``<stem>_with_pred.csv`` in chunks.

    Parameters:
        orig_path: path of the input CSV; only used to derive the output name
            and, optionally, to delete the input afterwards.
        df: DataFrame to write.
        delete_original: remove ``orig_path`` after the output has been written.
        keep_cols: columns to write, in this order; missing ones are reported and
            ignored.  ``None``/empty means "all columns of ``df``".
        chunk_size: number of rows written per ``to_csv`` call (must be > 0).

    Behaviour:
        * Rows with NaN in any written column other than ``Oxygen`` /
          ``Oxygen_remap`` are dropped.
        * ``Oxygen`` / ``Oxygen_remap`` are rounded to 3 decimals, values that
          round to 0 become 0.001, NaN becomes an empty field.
        * An empty ``df`` produces a header-only file, so the output always
          exists before the original may be deleted.

    Returns:
        The output path.

    Raises:
        ValueError: if ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        # A non-positive step would write nothing (or crash in range()) while the
        # original could still be deleted below.
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size!r}")

    out_path = os.path.splitext(orig_path)[0] + "_with_pred.csv"

    if keep_cols is None or len(keep_cols) == 0:
        cols_final = list(df.columns)
    else:
        missing = [c for c in keep_cols if c not in df.columns]
        if missing:
            print(f"  -> [WARN] The following requested output columns are missing and will be ignored: {missing}")
        cols_final = [c for c in keep_cols if c in df.columns]

    # Oxygen columns may legitimately be NaN (no model for a zone); every other
    # written column must be present for a row to be kept.
    na_critical_cols = [c for c in cols_final if c not in ("Oxygen", "Oxygen_remap")]

    if os.path.exists(out_path):
        try:
            os.remove(out_path)
        except Exception as e:
            print(f"  -> Cannot remove existing target file {out_path}: {e}")

    def _fmt3(v):
        # Fixed 3-decimal text; NaN is written as an empty CSV field.
        if np.isnan(v):
            return ""
        return f"{float(v):.3f}"

    n_total = len(df)
    first_chunk = True

    for start in range(0, n_total, chunk_size):
        end = min(start + chunk_size, n_total)
        chunk = df.iloc[start:end][cols_final].copy()

        if na_critical_cols:
            before = len(chunk)
            chunk = chunk.dropna(subset=na_critical_cols)
            after = len(chunk)
            if after < before:
                print(f"  -> [INFO] chunk {start}:{end} dropped {before - after} rows containing NaN in required columns")

        for oxy_col in ["Oxygen", "Oxygen_remap"]:
            if oxy_col in chunk.columns:
                vals = pd.to_numeric(chunk[oxy_col], errors="coerce").to_numpy(dtype=float)
                vals = np.round(vals, 3)
                zero_mask = np.isfinite(vals) & (vals == 0.0)
                if np.any(zero_mask):
                    vals[zero_mask] = 0.001
                chunk[oxy_col] = [_fmt3(v) for v in vals]

        chunk.to_csv(
            out_path,
            mode='w' if first_chunk else 'a',
            header=first_chunk,
            index=False
        )
        first_chunk = False
        del chunk

    if first_chunk:
        # Empty input: the loop never ran.  Still emit a header-only file so the
        # message below is true and deleting the original cannot lose data.
        df.iloc[0:0][cols_final].to_csv(out_path, mode='w', header=True, index=False)

    print(f"  -> Wrote new file: {out_path}")

    if delete_original:
        try:
            os.remove(orig_path)
            print(f"  -> Deleted original file: {orig_path}")
        except Exception as e:
            print(f"  -> Failed to delete original file {orig_path}: {e}")

    return out_path


# =============================================================================
# Fast CSV reader (optional pyarrow)
# =============================================================================
def read_csv_fast(csv_path):
    """Read a CSV into a DataFrame, preferring the pyarrow engine.

    If the pyarrow-backed read fails for any reason, the file is read again
    with pandas' default engine.
    """
    try:
        frame = pd.read_csv(csv_path, engine="pyarrow")  # pandas>=2.0
    except Exception:
        frame = pd.read_csv(csv_path)
    return frame


# =============================================================================
# Main pipeline
# =============================================================================
def main():
    """Run zone-wise oxygen prediction (and optional remap smoothing) for every depth/season CSV.

    For each depth returned by ``list_depths_to_run`` and each season CSV found
    under ``INPUT_ROOT/<depth>dbar``:

    1. read and normalize the CSV, apply the zone merge map if one is configured;
    2. predict ``Oxygen`` per zone with the cached zone model (rows whose zone has
       no model, or whose prediction fails, stay NaN);
    3. round to 3 decimals and replace exact zeros by 0.001;
    4. if remap is enabled for this depth, add ``RegionID`` / ``Oxygen_remap`` via
       ``remap_smooth_zoneaware_inplace``;
    5. write ``*_with_pred.csv`` via ``save_new_csv`` (optionally deleting the input).

    Command-line flags: ``--delete``, ``--depths``, ``--mask``.  Problems with a
    single file or zone are printed as warnings and that item is skipped.
    """
    parser = argparse.ArgumentParser(
        description="Zone-only oxygen prediction (cached models) + optional remap smoothing (LoS-mask limited)."
    )
    parser.add_argument("--delete", action="store_true",
                        help="Delete original CSV after writing the new file.")
    parser.add_argument("--depths", nargs="*",
                        help="Override DEPTH_LIST in config, e.g. --depths 1 10 50 100")
    parser.add_argument("--mask", type=str, default=REMAP_MASK_NC,
                        help="LoS mask NetCDF path (must contain variable remap_mask_los).")
    # parse_known_args ignores unrecognized arguments instead of exiting
    # (presumably so the script also runs inside a Jupyter kernel).
    args, _ = parser.parse_known_args()

    # The CLI flag and the config constant are OR-ed: either one enables deletion.
    delete_flag = args.delete or DELETE_ORIGINAL
    mask_path   = args.mask

    depths = list_depths_to_run(INPUT_ROOT, DEPTH_LIST, args.depths)
    if not depths:
        print(f"[ERROR] No depth directories found under root: {INPUT_ROOT}")
        return

    print(f"CWD = {os.getcwd()}")
    print(f"THREADS={THREADS}")
    print(f"INPUT_ROOT={INPUT_ROOT}")
    print(f"MODELS_ROOT={MODELS_ROOT}")
    print(f"Zonename={Zonename}, TempName={TempName}, SalName={SalName}, O2SAT_NAME={O2SAT_NAME}, TIMENAME={TIMENAME}")
    print("SEASONS:", SEASONS)
    print("Depths:", depths)

    for depth in depths:
        # Per-depth policy: which zones are merged and whether remap may run.
        merge_map, remap_enabled_depth = adjust_features_and_mergepred(depth)
        remap_enabled = bool(REMAP_ENABLED_GLOBAL and remap_enabled_depth)

        if remap_enabled:
            print(f"[REMAP] enabled: S={REMAP_SMOOTH_KM} km, workers={REMAP_WORKERS}, mask={mask_path}")
        else:
            print(f"[REMAP] disabled for depth={depth} (policy-driven).")

        ddir = os.path.join(INPUT_ROOT, f"{depth}dbar")
        print(f"\n[Depth {depth}] dir: {ddir}")

        # Collect the season files that actually exist for this depth.
        csv_files = []
        for season in SEASONS:
            candidate = os.path.join(ddir, CSV_NAME_TEMPLATE.format(depth=depth, season=season))
            if os.path.isfile(candidate):
                csv_files.append(candidate)
            else:
                print(f"  [INFO] Missing {os.path.basename(candidate)} ({depth}dbar); skipping this season.")

        if not csv_files:
            print(f"  [INFO] No input CSV found for {depth}dbar; skipping this depth.")
            continue

        model_depth_dir = os.path.join(MODELS_ROOT, f"depth_{depth}")
        if not os.path.isdir(model_depth_dir):
            print(f"  [WARN] Model directory not found: {model_depth_dir}; skipping all files for this depth.")
            continue

        for csv_path in csv_files:
            try:
                df = read_csv_fast(csv_path)
                df["source_csv"] = csv_path
                df = normalize_temp_sal_columns(df, TempName, SalName)
            except Exception as e:
                print(f"  [WARN] Failed to read/preprocess {csv_path}: {e}")
                continue

            # Zone merge: write Zone_Merged and make zone_col point to it
            if isinstance(merge_map, dict) and len(merge_map) > 0:
                if Zonename in df.columns:
                    df = apply_zone_merge(df, zone_col=Zonename, merge_map=merge_map, out_col="Zone_Merged")
                    zone_col = "Zone_Merged"
                else:
                    zone_col = Zonename
            else:
                # NOTE(review): both branches of this conditional yield Zonename.
                zone_col = Zonename if Zonename in df.columns else Zonename

            if zone_col not in df.columns:
                print(f"  [SKIP] {os.path.basename(csv_path)} missing zone column {zone_col}/{Zonename}; cannot predict.")
                continue

            # Expose zone column consistently as Zonename (for KEEP_COLS and remap)
            if zone_col != Zonename:
                df[Zonename] = df[zone_col]

            # Minimal required columns
            required_min = ["Latitude", "Longitude", Zonename, O2SAT_NAME]
            missing_min = [c for c in required_min if c not in df.columns]
            if missing_min:
                print(f"  [SKIP] {os.path.basename(csv_path)} missing minimal required columns: {missing_min}")
                continue

            print(f"  File: {os.path.basename(csv_path)}  n={len(df)}  zone_col={zone_col}")

            # --- Predict per zone ---
            y_pred = np.full(len(df), np.nan, dtype=float)
            groups = df.groupby(df[zone_col]).groups

            for zval, idx in groups.items():
                try:
                    zid = int(zval)
                except Exception:
                    print(f"    [WARN] Zone value {zval} is not an integer; skipping (kept as NaN).")
                    continue

                zmeta = load_zone_cached(depth, zid)
                if zmeta is None:
                    print(f"    [WARN] Model not found depth_{depth}/zone_{zid}; zone samples kept as NaN.")
                    continue

                sub = df.loc[idx]
                try:
                    # NOTE(review): idx holds index labels from groupby but is used
                    # to index the positional array y_pred; this assumes df has a
                    # default 0..n-1 index -- confirm.
                    y_pred[idx] = predict_zone_block(sub, zmeta)
                except Exception as e:
                    print(f"    [WARN] zone {zid} prediction failed (NaN): {e}")

            df["Oxygen"] = np.asarray(y_pred, float)

            # Round and avoid 0.000
            y_out = np.asarray(df["Oxygen"].to_numpy(), float)
            y_out = np.round(y_out, 3)
            zero_mask = np.isfinite(y_out) & (y_out == 0.0)
            if zero_mask.any():
                y_out[zero_mask] = 0.001
                print(f"    [INFO] {int(zero_mask.sum())} values rounded to 0.000; replaced by 0.001.")
            df["Oxygen"] = y_out.astype("float32")

            # --- Remap fusion (mask-limited) ---
            if remap_enabled:
                try:
                    remap_smooth_zoneaware_inplace(
                        df,
                        nc_path=BIOMES_NC,
                        depth_str=depth,
                        smooth_km=REMAP_SMOOTH_KM,
                        workers=REMAP_WORKERS,
                        mask_nc_path=mask_path,
                        merge_map=merge_map,
                        debug=REMAP_DEBUG,
                        trace_n=REMAP_TRACE_N,
                        trace_seed=REMAP_TRACE_SEED,
                        diag_csv=REMAP_DIAG_CSV
                    )

                    # Same rounding / zero-avoidance rule as for Oxygen above.
                    if "Oxygen_remap" in df.columns:
                        yr = np.asarray(df["Oxygen_remap"].to_numpy(), float)
                        yr = np.round(yr, 3)
                        zrm = np.isfinite(yr) & (yr == 0.0)
                        if zrm.any():
                            yr[zrm] = 0.001
                            print(f"    [INFO] After remap, {int(zrm.sum())} values are 0.000; replaced by 0.001.")
                        df["Oxygen_remap"] = yr.astype("float32")

                except Exception as e:
                    # Best effort: a remap failure must not prevent writing Oxygen.
                    print(f"  [WARN] Remap fusion failed: {e} (skipping remap; keeping Oxygen).")

            # Output columns
            keep_cols_out = KEEP_COLS.copy()
            if remap_enabled:
                for c in ["RegionID", "Oxygen_remap"]:
                    if c not in keep_cols_out:
                        keep_cols_out.append(c)

            save_new_csv(
                csv_path,
                df,
                delete_original=delete_flag,
                keep_cols=keep_cols_out,
                chunk_size=1_000_000
            )

    print("\nAll done.")

# Script entry point: run the pipeline only when this code is executed directly.
if __name__ == "__main__":
    main()

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
make_monthly_do_netcdf_cmemsgrid_mask_smooth_fill_floor_NOOBS.py

Fixed pipeline (NO observation residual blending):
- Output grid axes are strictly identical to the CMEMS TEMP grid (latitude/longitude read from TEMP_NC).
- For each (year, month, depth):
    1) Bin scattered predictions onto the target grid;
    2) Build the Mallow mask (TEMP-valid & latitude cut & optional shallow coastal exclusion);
    3) Keep the background prediction field only where (Mallow & prediction coverage);
    4) Apply Gaussian smoothing on the masked field (preserve footprint; do not expand into NaNs);
    5) Apply a minimum-value clamp (floor) to all valid values;
    6) Write into the corresponding monthly NetCDF file (time/depth unlimited; depth inserted/overwritten in ascending order).

- Variable OXY is stored as int16 with scale_factor/add_offset packing (handled by netCDF4).
"""

import os
import argparse
import numpy as np
import pandas as pd
from netCDF4 import Dataset
from datetime import datetime, timezone
from concurrent.futures import ProcessPoolExecutor, as_completed
import multiprocessing as mp

# ================= Optional: BallTree (accelerate haversine kNN) =================
try:
    from sklearn.neighbors import BallTree
    HAVE_SKLEARN = True
except Exception:
    BallTree = None
    HAVE_SKLEARN = False

# ===================== Default paths (override via CLI) =====================
PRED_ROOT = "/data/wang/Result_Data/allnodoxy"
OUT_DIR   = "/data/wang/Result_Data/models_ML/Datasets"

# NOTE: depthlist must exist in your runtime environment (as in your original setup).
# It is not defined in this cell; a NameError is raised here if it is missing.
DEPTH_LIST_DEFAULT = depthlist

# Season name -> set of calendar months it covers ("NewYear" covers all twelve).
SEASONS_DEFAULT = ["Spring", "Summer", "Autumn", "Winter", "NewYear"]
SEASON_MONTHS = {
    "Spring": {3, 4, 5},
    "Summer": {6, 7, 8},
    "Autumn": {9, 10, 11},
    "Winter": {12, 1, 2},
    "NewYear": {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
}

# ===================== Mask / smoothing / interpolation parameters =====================
TEMP_NC_DEFAULT                  = "/data/wang/CMEMS/TEMP/OA_CORA5.2_19600115_fld_TEMP.nc"
MASK_NC_DEFAULT                  = "/data/wang/Mask_File/land_mask_0m.nc"    # 1=land/shallow; 0=other (lon:0..360)
# The offshore mask is applied only to depths <= this value (metres).
SHALLOW_MASK_APPLY_MAX_M_DEFAULT = 0.0
# NOTE(review): despite the name, build_pred_mask_for_grid masks out grid rows
# whose latitude is *below* this value, i.e. it acts as a southern cutoff.
ARCTIC_CUT_LAT_DEFAULT           = -77.0105
# Tolerances used when matching target grid axes / depths against the TEMP file.
LL_TOL_DEG_DEFAULT               = 1e-3
DEPTH_TOL_DEFAULT                = 1e-3

# Smoothing: fixed Gaussian on
GAUSS_RADIUS_KM_DEFAULT    = 50.0
MIN_SUPPORT_FRAC_DEFAULT   = 0.5

FILL_KERNEL_DEFAULT        = "gauss"          # 'idw' | 'gauss'
FILL_KMAX_DEFAULT          = 8
BASE_RADIUS_KM_DEFAULT     = 60.0
MAX_RADIUS_KM_DEFAULT      = 120.0
FILL_POWER_DEFAULT         = 2

# Minimum-value clamp
MIN_VALUE_FLOOR_DEFAULT    = 0.001

# ================= Metadata (CF-1.8 + ACDD-1.3) =================
TITLE        = "GEOXYGEN v1.1"
KEYWORDS     = "Ocean dissolved oxygen, Machine learning, 0.5deg, Long time series"
INSTITUTION  = "Fudan University"
CREATOR_NAME = "Wang et al."
PROJECT      = "24ZR1404500"
ANALYSIS     = "DO_"
SOURCE = (
    "Multi-source in situ dissolved oxygen observations (CCHDO, GLODAP, GEOTRACES IDP2021, "
    "OceanSITES, and OSD/CTD + Argo internally consistent calibrated product) and machine-learning reconstruction"
)
LICENSE      = "Creative Commons Attribution 4.0 International (CC BY 4.0)"
NAMING_AUTH  = "cn.edu.fudan"
STANDARD_NAME_VOC = "CF Standard Name Table"
REFS         = "Documentation forthcoming"
# Reference epoch for the NetCDF time axis ("days since ...").
REF_DATE     = "1950-01-01T00:00:00Z"
CONVENTIONS  = "CF-1.8, ACDD-1.3"

# ================= Constants =================
DEPTH_MATCH_TOL = 1e-4
EARTH_R_KM = 6371.0

# ================= OXY packing parameters (CORA-like) =================
# Packed as int16: physical value = packed * OXY_SCALE_FACTOR + OXY_ADD_OFFSET.
OXY_SCALE_FACTOR = 0.02           # 1 LSB = 0.02 µmol/kg
OXY_ADD_OFFSET   = 0.0
OXY_VALID_MIN    = 0
OXY_VALID_MAX    = 20000          # -> 400 µmol/kg
OXY_FILL_VALUE   = np.int16(32767)

# ================= Utility functions =================
def normalize_lon_to_m180_180(lon_arr: np.ndarray) -> np.ndarray:
    """Wrap longitudes (degrees) into the half-open interval [-180, 180)."""
    shifted = lon_arr + 180.0
    wrapped = shifted % 360.0
    return wrapped - 180.0

def lon_to_0_360(lon_deg):
    """Wrap longitudes (degrees) into the half-open interval [0, 360).

    Accepts a scalar or any array-like; the result is a float64 array of the
    same shape (0-d for scalar input).  NaN stays NaN.
    """
    x = np.asarray(lon_deg, np.float64)
    out = np.mod(x, 360.0)
    # np.mod with a positive divisor never returns a negative value, but for a
    # tiny negative input the result can round up to exactly 360.0; fold that
    # back to 0 so the interval stays half-open.
    return np.where(out >= 360.0, 0.0, out)

def days_since_ref(dt_utc: datetime, ref=REF_DATE) -> float:
    """Return the (fractional) number of days from ``ref`` to ``dt_utc``.

    ``ref`` is an ISO-8601 string; a trailing ``Z`` is rewritten to ``+00:00``
    before parsing with ``datetime.fromisoformat``.
    """
    iso_ref = ref.replace("Z", "+00:00")
    elapsed = dt_utc - datetime.fromisoformat(iso_ref)
    return elapsed.total_seconds() / 86400.0

def mid_month_dt(y: int, m: int) -> datetime:
    """Return the UTC datetime for 00:00 on the 15th of the given year/month."""
    year, month = int(y), int(m)
    return datetime(year=year, month=month, day=15, tzinfo=timezone.utc)

def nearest_index_nonuniform(axis: np.ndarray, values: np.ndarray) -> np.ndarray:
    """Return, per value, the index of the closest entry of ``axis`` (int64).

    ``axis`` is probed with ``np.searchsorted`` and is therefore expected to be
    sorted ascending.  On an exact tie the lower index is returned.
    """
    upper_bound = axis.size - 1
    insert_at = np.searchsorted(axis, values, side="left")
    below = np.clip(insert_at - 1, 0, upper_bound)
    above = np.clip(insert_at, 0, upper_bound)
    above_is_closer = np.abs(axis[above] - values) < np.abs(axis[below] - values)
    return np.where(above_is_closer, above, below).astype(np.int64)

def bincount_mean(i: np.ndarray, j: np.ndarray, val: np.ndarray, ny: int, nx: int):
    """Average scattered values onto an ``(ny, nx)`` grid.

    ``i``/``j`` are the row/column index of each sample.  Non-finite samples
    are ignored.  Returns ``(mean, count)`` where ``mean`` is float32 with NaN
    in cells that received no finite sample and ``count`` is int32.
    """
    n_cells = ny * nx
    finite = np.isfinite(val)
    cell = (i * nx + j)[finite]  # flattened cell id of every usable sample

    counts = np.bincount(cell, minlength=n_cells).reshape(ny, nx)
    if not finite.any():
        return np.full((ny, nx), np.nan, dtype=np.float32), counts.astype(np.int32)

    sums = np.bincount(
        cell, weights=val[finite].astype(np.float64), minlength=n_cells
    ).reshape(ny, nx)

    empty = counts == 0
    with np.errstate(invalid="ignore", divide="ignore"):
        mean = sums / np.where(empty, 1, counts)
    mean[empty] = np.nan
    return mean.astype(np.float32), counts.astype(np.int32)

# ================= Gaussian smoothing (fix 0*NaN contamination) =================
def _gaussian_kernel1d(sigma_gp, truncate=3.0):
    """Build a normalized 1-D Gaussian kernel.

    ``sigma_gp`` is the standard deviation in grid points (floored at 1e-6);
    the kernel spans ``ceil(truncate * sigma)`` points on each side, at least one.
    """
    sigma = float(max(1e-6, sigma_gp))
    reach = int(max(1, np.ceil(truncate * sigma)))
    offsets = np.arange(-reach, reach + 1, dtype=np.float64)
    weights = np.exp(-0.5 * (offsets / sigma) ** 2)
    # The small epsilon keeps the division safe should all weights underflow.
    weights /= weights.sum() + 1e-12
    return weights

def _km_to_grid_sigma(radius_km: float, ddeg: float):
    """Convert a radius in km to a sigma in grid points for spacing ``ddeg`` degrees.

    Uses 111 km per degree, never returns less than 0.5, and returns 1.0 when
    the spacing is zero.
    """
    if ddeg == 0:
        return 1.0
    radius_deg = radius_km / 111.0
    sigma_points = radius_deg / abs(ddeg)
    return max(0.5, sigma_points)

def _reflect_indices(idx, N):
    """Mirror out-of-range indices back into ``[0, N-1]``.

    Negative indices are reflected at the lower edge first, then indices
    ``>= N`` are reflected at the upper edge; anything still out of range after
    one reflection per side is clipped.  The input array is not modified.
    """
    folded = np.where(idx < 0, -idx - 1, idx)
    folded = np.where(folded >= N, 2 * N - folded - 1, folded)
    return np.clip(folded, 0, N - 1)

def _norm_conv_line_wrap(x, mask01, kernel):
    """Mask-normalized 1-D convolution with periodic (wrap-around) boundaries.

    ``mask01`` marks usable samples with values > 0; masked samples contribute
    nothing (they are zeroed before weighting, so NaN cannot leak in).

    Returns ``(smoothed, support)`` where ``support`` is the kernel weight that
    landed on usable samples and ``smoothed`` is NaN wherever support is 0.
    """
    half = kernel.size // 2
    numer = np.zeros_like(x, float)
    support = np.zeros_like(x, float)

    for tap, weight in enumerate(kernel):
        offset = tap - half
        rolled_mask = np.roll(mask01, offset)
        rolled_vals = np.where(rolled_mask > 0, np.roll(x, offset), 0.0)
        numer += weight * rolled_vals
        support += weight * rolled_mask

    smoothed = np.where(support > 0, numer / np.maximum(support, 1e-12), np.nan)
    return smoothed, support

def _norm_conv_line_reflect(x, mask01, kernel):
    """Mask-normalized 1-D convolution with reflected boundaries.

    Same contract as the wrap-around variant, except that samples beyond either
    end are taken from the mirrored position given by ``_reflect_indices``.

    Returns ``(smoothed, support)``; ``smoothed`` is NaN where support is 0.
    """
    length = x.shape[0]
    half = kernel.size // 2
    positions = np.arange(length, dtype=np.int64)

    numer = np.zeros_like(x, float)
    support = np.zeros_like(x, float)

    for tap, weight in enumerate(kernel):
        src = _reflect_indices(positions + (tap - half), length)
        src_mask = mask01[src]
        # Zero masked samples so NaN values never enter the weighted sum.
        src_vals = np.where(src_mask > 0, x[src], 0.0)
        numer += weight * src_vals
        support += weight * src_mask

    smoothed = np.where(support > 0, numer / np.maximum(support, 1e-12), np.nan)
    return smoothed, support

def gaussian_smooth_preserve(V: np.ndarray, lats: np.ndarray, lons: np.ndarray,
                             radius_km: float, min_support_frac: float = 0.5) -> np.ndarray:
    """Separable Gaussian smoothing that never changes the field's footprint.

    The field is smoothed first along longitude (periodic, kernel width scaled
    by 1/cos(lat) with cos floored at 0.1) and then along latitude (reflected
    edges).  Cells that are NaN on input stay NaN; cells that are finite on
    input keep their pre-pass value whenever the kernel support on valid
    neighbours is below ``min_support_frac`` of the kernel sum or the smoothed
    value is NaN.

    Returns a float32 array with the same shape as ``V``.
    """
    field = V.astype(np.float64)
    n_lat, n_lon = field.shape
    footprint = np.isfinite(field).astype(np.float64)

    def _restore(smoothed, support, before, inside, kernel):
        # Inside the footprint: fall back to the pre-pass value when support is
        # too weak or smoothing produced NaN.  Outside the footprint: NaN.
        supported = support >= (min_support_frac * kernel.sum())
        fallback = (~supported) | np.isnan(smoothed)
        return np.where(inside, np.where(fallback, before, smoothed), np.nan)

    # ---- Pass 1: along longitude, row by row (wrap-around) ----
    lon_pass = field.copy()
    step_lon = float(np.nanmedian(np.diff(lons))) if n_lon > 1 else 1.0
    for row in range(n_lat):
        cos_lat = max(0.1, abs(np.cos(np.deg2rad(float(lats[row])))))
        sigma_row = _km_to_grid_sigma(radius_km / max(cos_lat, 1e-6), step_lon)
        kernel_row = _gaussian_kernel1d(sigma_row, 3.0)
        line = lon_pass[row, :]
        smoothed, support = _norm_conv_line_wrap(line, footprint[row, :], kernel_row)
        lon_pass[row, :] = _restore(smoothed, support, line, footprint[row, :] > 0, kernel_row)

    # ---- Pass 2: along latitude, column by column (reflected edges) ----
    valid_after_lon = np.isfinite(lon_pass).astype(np.float64)
    step_lat = float(np.nanmedian(np.diff(lats))) if n_lat > 1 else 1.0
    kernel_col = _gaussian_kernel1d(_km_to_grid_sigma(radius_km, step_lat), 3.0)
    result = lon_pass.copy()
    for col in range(n_lon):
        line = lon_pass[:, col]
        smoothed, support = _norm_conv_line_reflect(line, valid_after_lon[:, col], kernel_col)
        result[:, col] = _restore(smoothed, support, line, footprint[:, col] > 0, kernel_col)

    result = np.where(footprint > 0, result, np.nan)
    return result.astype(np.float32)

# ================= TEMP / offshore mask cache and mapping =================
# Single-entry caches keyed by file path: load_temp_cache / load_offshore_mask
# replace these dicts wholesale and skip reloading when "path" already matches.
# _TEMP_CACHE holds the TEMP grid axes and a boolean validity mask (mask3d).
_TEMP_CACHE = {"path": None, "lats": None, "lons": None, "deps": None, "mask3d": None}
# _MASK_CACHE holds the offshore mask, its axes, and their sorted copies together
# with the permutation ("*_inv") produced by _as_sorted_axis_and_index.
_MASK_CACHE = {"path": None, "lat": None, "lon": None, "mask2d": None,
               "lat_sorted": None, "lat_inv": None, "lon_sorted": None, "lon_inv": None}

def _as_sorted_axis_and_index(axis):
    """Sort an axis and return both directions of the permutation.

    Returns ``(sorted_axis, order, inverse)`` where ``sorted_axis == axis[order]``
    (``order`` maps sorted position -> original index) and
    ``inverse[original_index]`` is that element's position in ``sorted_axis``.
    """
    values = np.asarray(axis, np.float64)
    order = np.argsort(values)
    inverse = np.empty_like(order)
    inverse[order] = np.arange(order.size)
    sorted_values = values[order]
    return sorted_values, order, inverse

def nearest_index_on_sorted_axis(sorted_axis, values):
    """Return, per value, the index of the closest entry of an ascending axis.

    ``values`` may be any array-like (converted to float64).  On an exact tie
    the lower index is returned.  Result dtype is int64.
    """
    targets = np.asarray(values, np.float64)
    top = sorted_axis.size - 1
    slot = np.searchsorted(sorted_axis, targets, side="left")
    lower = np.clip(slot - 1, 0, top)
    upper = np.clip(slot, 0, top)
    gap_upper = np.abs(sorted_axis[upper] - targets)
    gap_lower = np.abs(sorted_axis[lower] - targets)
    return np.where(gap_upper < gap_lower, upper, lower).astype(np.int64)

def match_axis_indices_subset(axis_full, axis_sub, tol_deg):
    """Locate every coordinate of ``axis_sub`` on ``axis_full``.

    Returns an int64 array ``idx`` (same length and order as ``axis_sub``) such
    that ``axis_full[idx[k]]`` is the entry of ``axis_full`` nearest to
    ``axis_sub[k]``.  ``axis_full`` may be in any order.

    Raises:
        RuntimeError: if any nearest match is farther than ``tol_deg`` away.

    Fix note: the previous implementation translated positions on the *sorted*
    axis back to the original axis with the original->sorted permutation
    instead of the sorted->original one (``argsort`` order).  That only works
    when the two coincide (ascending or descending axes) and returned wrong
    indices for any other ordering, e.g. a longitude axis that wraps at 180.
    """
    full = np.asarray(axis_full, np.float64)
    sub = np.asarray(axis_sub, np.float64)

    order_full = np.argsort(full)          # sorted position -> original index
    full_sorted = full[order_full]

    # Nearest neighbour of each requested coordinate on the sorted axis
    # (ties resolve to the lower sorted position).
    top = full_sorted.size - 1
    slot = np.searchsorted(full_sorted, sub, side="left")
    lower = np.clip(slot - 1, 0, top)
    upper = np.clip(slot, 0, top)
    upper_closer = np.abs(full_sorted[upper] - sub) < np.abs(full_sorted[lower] - sub)
    nearest = np.where(upper_closer, upper, lower)

    diff = np.abs(full_sorted[nearest] - sub)
    if diff.size > 0 and np.nanmax(diff) > float(tol_deg):
        raise RuntimeError(f"Axis matching failed: max |Δ|={np.nanmax(diff):.6f} > tol={tol_deg}")

    return order_full[nearest].astype(np.int64)

def load_temp_cache(temp_nc_path):
    """Populate ``_TEMP_CACHE`` from a TEMP NetCDF file (no-op if already cached).

    Reads ``latitude``, ``longitude``, ``depth`` and the first time step of
    ``TEMP``; ``mask3d`` is True wherever that TEMP slice is finite after masked
    cells are filled with NaN.
    """
    global _TEMP_CACHE
    already_loaded = _TEMP_CACHE["path"] == temp_nc_path
    if already_loaded:
        return

    with Dataset(temp_nc_path, "r") as src:
        try:
            src.set_auto_maskandscale(True)
        except Exception:
            pass
        fresh = {
            "path": temp_nc_path,
            "lats": src["latitude"][:].astype(np.float64),
            "lons": src["longitude"][:].astype(np.float64),
            "deps": src["depth"][:].astype(np.float64),
        }
        first_step = src["TEMP"][0, :, :, :]   # (depth, lat, lon) masked
        filled = np.array(first_step.filled(np.nan), dtype=np.float32)
        fresh["mask3d"] = np.isfinite(filled)

    _TEMP_CACHE = fresh

def depth_index_match(temp_deps, depth_m, tol=DEPTH_TOL_DEFAULT):
    """Return the index of the entry of ``temp_deps`` nearest to ``depth_m``.

    Raises:
        RuntimeError: if the nearest level differs from ``depth_m`` by more
            than ``tol``.
    """
    target = float(depth_m)
    idx = int(np.argmin(np.abs(temp_deps - target)))
    gap = abs(float(temp_deps[idx]) - target)
    if gap > float(tol):
        raise RuntimeError(f"No matching depth: {depth_m} m (nearest={temp_deps[idx]} m)")
    return idx

def load_offshore_mask(mask_nc_path):
    """Populate ``_MASK_CACHE`` from an offshore-mask NetCDF file (no-op if cached).

    Reads ``lat``, ``lon`` and ``offshore_mask`` (stored as uint8) and keeps the
    sorted axes plus the permutation returned by ``_as_sorted_axis_and_index``.
    """
    global _MASK_CACHE
    if _MASK_CACHE["path"] == mask_nc_path:
        return

    with Dataset(mask_nc_path, "r") as src:
        mask_lat = src["lat"][:].astype(np.float64)
        mask_lon = src["lon"][:].astype(np.float64)
        mask_2d = src["offshore_mask"][:].astype(np.uint8)  # 1=land/shallow; 0=other

    sorted_lat, _, inv_lat = _as_sorted_axis_and_index(mask_lat)
    sorted_lon, _, inv_lon = _as_sorted_axis_and_index(mask_lon)

    _MASK_CACHE = {
        "path": mask_nc_path,
        "lat": mask_lat,
        "lon": mask_lon,
        "mask2d": mask_2d,
        "lat_sorted": sorted_lat,
        "lat_inv": inv_lat,
        "lon_sorted": sorted_lon,
        "lon_inv": inv_lon,
    }

def map_offshore_mask_to_grid(mask_nc_path, do_lats, do_lons):
    """Sample the offshore mask at every node of the target grid (nearest cell).

    ``do_lats`` / ``do_lons`` are the target grid axes; longitudes are converted
    to 0..360 before matching against the mask's longitude axis.  Returns a
    uint8 array of shape ``(len(do_lats), len(do_lons))``.

    Fix note: nearest-neighbour positions are found on the *sorted* mask axes,
    so they must be translated with the sorted->original permutation.  The cache
    only stores the original->sorted one (``*_inv``); its argsort is the
    permutation needed here.  Using ``*_inv`` directly is only correct when the
    mask axes are ascending or descending.
    """
    load_offshore_mask(mask_nc_path)
    lat_sorted = _MASK_CACHE["lat_sorted"]
    lon_sorted = _MASK_CACHE["lon_sorted"]
    M2         = _MASK_CACHE["mask2d"]

    # sorted position -> original index on the mask axes
    lat_order = np.argsort(_MASK_CACHE["lat_inv"])
    lon_order = np.argsort(_MASK_CACHE["lon_inv"])

    idx_lat_sorted = nearest_index_on_sorted_axis(lat_sorted, do_lats)
    idx_lon_sorted = nearest_index_on_sorted_axis(lon_sorted, lon_to_0_360(do_lons))
    lat_idx = lat_order[idx_lat_sorted]
    lon_idx = lon_order[idx_lon_sorted]
    mapped = M2[np.ix_(lat_idx, lon_idx)].astype(np.uint8)
    return mapped

def build_pred_mask_for_grid(depth_m: float, target_lats: np.ndarray, target_lons: np.ndarray,
                             temp_nc_path: str,
                             mask_nc_path: str,
                             shallow_mask_max_m: float,
                             arctic_cut_lat: float,
                             ll_tol_deg: float,
                             depth_tol: float) -> np.ndarray:
    """Build the boolean "allowed" mask for one depth level on the target grid.

    A cell is allowed when TEMP is valid there at the matching depth level, its
    latitude is not below ``arctic_cut_lat``, and -- only for depths
    ``<= shallow_mask_max_m`` -- the offshore mask value is 0.

    Raises:
        RuntimeError: propagated from axis or depth matching when the target
            grid or depth cannot be located in the TEMP file within tolerance.
    """
    load_temp_cache(temp_nc_path)
    temp_lats = _TEMP_CACHE["lats"]
    temp_lons = _TEMP_CACHE["lons"]
    temp_deps = _TEMP_CACHE["deps"]
    temp_valid = _TEMP_CACHE["mask3d"]

    rows = match_axis_indices_subset(temp_lats, target_lats, tol_deg=ll_tol_deg)
    cols = match_axis_indices_subset(temp_lons, target_lons, tol_deg=ll_tol_deg)
    level = depth_index_match(temp_deps, float(depth_m), tol=depth_tol)

    # TEMP validity at this level, re-indexed onto the target grid.
    allowed = temp_valid[level, :, :][np.ix_(rows, cols)].copy()

    # Rows south of the cutoff latitude are excluded entirely.
    below_cut = target_lats < float(arctic_cut_lat)
    allowed[below_cut, :] = False

    if float(depth_m) <= float(shallow_mask_max_m):
        offshore = map_offshore_mask_to_grid(mask_nc_path, target_lats, target_lons)
        allowed = allowed & (offshore == 0)

    return allowed.astype(bool)

# ================= NetCDF: create / write (time/depth unlimited) =================
def ensure_month_file(path_nc: str, lats: np.ndarray, lons: np.ndarray, year: int, month: int):
    """Create the monthly NetCDF file if it is missing and return its lat/lon axes.

    If ``path_nc`` does not exist, a NETCDF4 file is created with unlimited
    ``time`` and ``depth`` dimensions, fixed ``latitude``/``longitude``
    dimensions, a compressed int16 ``OXY`` variable (scale/offset packed), a
    single mid-month time stamp and the global metadata attributes. If the
    file already exists it is only opened read-only.

    Args:
        path_nc: Output file path.
        lats: 1-D latitude axis written to a newly created file.
        lons: 1-D longitude axis written to a newly created file.
        year: Year of the monthly file.
        month: Month of the monthly file.

    Returns:
        ``(latitudes, longitudes)`` as float32: the supplied axes when the file
        was just created, otherwise the axes read from the existing file.

    Raises:
        Exception: Any error raised while creating the file is re-raised after
            the partially written file has been removed.
    """
    if os.path.exists(path_nc):
        with Dataset(path_nc, "r") as ds:
            return ds["latitude"][:].astype(np.float32), ds["longitude"][:].astype(np.float32)

    # os.makedirs("") raises, so only create a parent when the path has one.
    parent_dir = os.path.dirname(path_nc)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    try:
        # The context manager guarantees the handle is closed even when one of
        # the attribute/variable writes below fails.
        with Dataset(path_nc, "w", format="NETCDF4") as ds:
            Ny, Nx = len(lats), len(lons)
            ds.createDimension("time", None)
            ds.createDimension("depth", None)
            ds.createDimension("latitude",  Ny)
            ds.createDimension("longitude", Nx)

            vtime  = ds.createVariable("time", "f4", ("time",))
            vdepth = ds.createVariable("depth", "f4", ("depth",))
            vlat   = ds.createVariable("latitude",  "f4", ("latitude",))
            vlon   = ds.createVariable("longitude", "f4", ("longitude",))
            vlat[:] = lats.astype(np.float32)
            vlon[:] = lons.astype(np.float32)

            vtime.standard_name = "time"
            vtime.units = f"days since {REF_DATE}"
            vtime.calendar = "gregorian"
            vtime.axis = "T"

            vdepth.standard_name = "depth"
            vdepth.units = "m"
            vdepth.positive = "down"
            vdepth.axis = "Z"

            vlat.standard_name = "latitude"
            vlat.units = "degree_north"
            vlat.axis = "Y"

            vlon.standard_name = "longitude"
            vlon.units = "degree_east"
            vlon.axis = "X"

            # Chunk sizes must be >= 1 even for a degenerate (empty) axis.
            chunks_lat = min(200, Ny if Ny > 0 else 1)
            chunks_lon = min(200, Nx if Nx > 0 else 1)

            voxy = ds.createVariable(
                "OXY", "i2", ("time", "depth", "latitude", "longitude"),
                zlib=True, complevel=4, shuffle=True, fill_value=OXY_FILL_VALUE,
                chunksizes=(1, 8, max(1, chunks_lat), max(1, chunks_lon))
            )
            voxy.long_name     = "Dissolved Oxygen"
            voxy.standard_name = "mole_concentration_of_dissolved_molecular_oxygen_in_sea_water"
            voxy.units         = "umol kg-1"
            voxy.scale_factor  = np.float32(OXY_SCALE_FACTOR)
            voxy.add_offset    = np.float32(OXY_ADD_OFFSET)
            voxy.valid_min     = np.int16(OXY_VALID_MIN)
            voxy.valid_max     = np.int16(OXY_VALID_MAX)
            voxy.missing_value = OXY_FILL_VALUE

            # One time record per file: the middle of the month.
            vtime[0:1] = np.float32(days_since_ref(mid_month_dt(year, month)))

            ds.Conventions              = CONVENTIONS
            ds.title                    = TITLE
            ds.keywords                 = KEYWORDS
            ds.institution              = INSTITUTION
            ds.creator_name             = CREATOR_NAME
            ds.project                  = PROJECT
            ds.project_name             = PROJECT
            ds.analysis_name            = ANALYSIS
            ds.source                   = SOURCE
            ds.license                  = LICENSE
            ds.naming_authority         = NAMING_AUTH
            ds.standard_name_vocabulary = STANDARD_NAME_VOC
            ds.references               = REFS
            ds.history                  = datetime.utcnow().strftime("%Y%m%dT%H%M%SZ") + " : Creation"
            ds.id                       = f"DO_{year:04d}{month:02d}"
            ds.dataset_id               = f"DO_{year:04d}{month:02d}"
            ds.cdm_data_type            = "Grid"
            ds.featureType              = "grid"
            ds.time_coverage_start      = f"{year:04d}-{month:02d}-15T00:00:00Z"
            ds.time_coverage_end        = f"{year:04d}-{month:02d}-15T00:00:00Z"
            ds.geospatial_lat_min       = float(np.nanmin(lats))
            ds.geospatial_lat_max       = float(np.nanmax(lats))
            ds.geospatial_lon_min       = float(np.nanmin(lons))
            ds.geospatial_lon_max       = float(np.nanmax(lons))
            if len(lats) > 1:
                ds.geospatial_lat_resolution = float(np.nanmedian(np.abs(np.diff(lats))))
            if len(lons) > 1:
                ds.geospatial_lon_resolution = float(np.nanmedian(np.abs(np.diff(lons))))
            ds.comment = (
                "Monthly mean DO; CMEMS grid; "
                "mask(Mallow)→smooth(finite)→fill(Mallow NaNs)→floor(min)→write."
            )
    except Exception:
        # A half-initialised file would satisfy os.path.exists() on the next
        # call and then fail later (e.g. missing variables), so remove it.
        try:
            os.remove(path_nc)
        except OSError:
            pass
        raise

    return lats.astype(np.float32), lons.astype(np.float32)

def insert_or_overwrite_depth_sorted(ds: Dataset, depth_m: float, data2d: np.ndarray):
    """Write one 2-D slice into ``OXY`` at ``depth_m``, keeping ``depth`` ascending.

    Three cases are handled on the open dataset ``ds`` (time index 0 only):

    * the depth axis is empty -> the slice becomes the first level;
    * a level within ``DEPTH_MATCH_TOL`` of ``depth_m`` exists -> it is overwritten;
    * otherwise the slice is appended, or inserted after shifting all deeper
      levels one position towards the end.

    Args:
        ds: Dataset opened for writing, containing ``depth`` and ``OXY``.
        depth_m: Depth of the slice.
        data2d: Field of shape ``(latitude, longitude)``; non-finite values are
            written as masked.

    Raises:
        RuntimeError: If ``data2d`` does not match the file's grid shape.
    """
    vdep = ds["depth"]
    voxy = ds["OXY"]
    Ny = ds.dimensions["latitude"].size
    Nx = ds.dimensions["longitude"].size

    # np.asarray converts only when needed. np.array(..., copy=False) raises
    # under NumPy >= 2 whenever a copy/cast is required (e.g. float64 input).
    data2d = np.asarray(data2d, dtype=np.float32)
    if data2d.shape != (Ny, Nx):
        raise RuntimeError("data2d does not match the file grid shape.")

    data2d_ma = np.ma.masked_invalid(data2d)
    slab_4d = data2d_ma[np.newaxis, np.newaxis, :, :]

    # Case 1: no depth level yet.
    if vdep.size == 0:
        vdep[0:1] = float(depth_m)
        voxy[0:1, 0:1, :, :] = slab_4d
        return

    # Case 2: the level already exists -> overwrite in place.
    depths = vdep[:].astype(float)
    hit = np.where(np.abs(depths - float(depth_m)) < DEPTH_MATCH_TOL)[0]
    if hit.size > 0:
        idx = int(hit[0])
        voxy[0:1, idx:idx+1, :, :] = slab_4d
        return

    # Case 3: new level. Position that keeps the depth axis sorted.
    idx = int(np.searchsorted(depths, float(depth_m), side="left"))
    N = vdep.size
    if idx >= N:
        # Deeper than everything stored so far: plain append.
        vdep[N:N+1] = float(depth_m)
        voxy[0:1, N:N+1, :, :] = slab_4d
        return

    # Grow the unlimited depth dimension by one, then shift levels idx..N-1
    # one slot deeper, working from the end so nothing is overwritten early.
    vdep[N:N+1] = depths[-1]
    for d in range(N-1, idx-1, -1):
        vdep[d+1] = vdep[d]
        slab = voxy[0, d, :, :].copy()
        voxy[0, d+1, :, :] = slab
    vdep[idx] = float(depth_m)
    voxy[0:1, idx:idx+1, :, :] = slab_4d

def write_or_append_depth_sorted(path_nc, year, month, depth_m, lats_target, lons_target, field2d_target):
    """Open an existing monthly file and store one depth slice in it.

    The file's single time stamp is set (if absent) or verified against the
    mid-month date of ``year``/``month``, the axis lengths are compared with
    the target axes, the slice is written via
    ``insert_or_overwrite_depth_sorted`` and a ``last_write_stats`` global
    attribute summarising the slice is updated.

    Raises:
        RuntimeError: If the stored time stamp or the axis lengths do not match.
    """
    with Dataset(path_nc, "r+") as nc:
        time_var = nc["time"]
        expected_time = days_since_ref(mid_month_dt(year, month))
        if time_var.size == 0:
            time_var[0:1] = np.float32(expected_time)
        elif abs(float(time_var[0]) - float(expected_time)) > 1e-3:
            raise RuntimeError(f"{path_nc}: time[0] does not match {year}-{month:02d}.")

        # Only the axis lengths are compared here, not the coordinate values.
        file_lat = nc["latitude"][:]
        file_lon = nc["longitude"][:]
        same_lat_len = file_lat.shape[0] == lats_target.shape[0]
        same_lon_len = file_lon.shape[0] == lons_target.shape[0]
        if not (same_lat_len and same_lon_len):
            raise RuntimeError("Field axes do not match the file axes (CMEMS grid must not change).")

        slice_f32 = np.array(field2d_target, dtype=np.float32, copy=True)
        insert_or_overwrite_depth_sorted(nc, float(depth_m), slice_f32)

        if np.isfinite(field2d_target).any():
            lo = float(np.nanmin(field2d_target))
            hi = float(np.nanmax(field2d_target))
            stats = f"{year}-{month:02d} depth={depth_m:g}m min={lo:.3f} max={hi:.3f}"
        else:
            stats = f"{year}-{month:02d} depth={depth_m:g}m ALL_NaN"
        nc.setncattr("last_write_stats", stats)

# ================= CSV reader (predictions) =================
def read_pred_csv(csv_path: str):
    """Load a prediction CSV and list the (year, month) pairs it contains.

    The oxygen column is ``Oxygen_remap`` when present, otherwise ``Oxygen``;
    in the returned frame it is always named ``Oxygen``. Rows with NaN/inf in
    any used column or with a latitude outside [-90, 90] are dropped, and
    longitudes are passed through ``normalize_lon_to_m180_180``.

    Returns:
        ``(months, df, oxy_col)`` where ``months`` is a sorted list of
        ``(year, month)`` tuples, ``df`` the cleaned DataFrame and ``oxy_col``
        the name of the oxygen column as found in the file. If the file is
        missing or lacks required columns, ``([], None, None)`` is returned.
    """
    if not os.path.isfile(csv_path):
        return [], None, None

    # Read the header only, to validate the columns cheaply.
    available = set(pd.read_csv(csv_path, nrows=0).columns)
    oxy_col = next((name for name in ("Oxygen_remap", "Oxygen") if name in available), None)
    coord_cols = ("Latitude", "Longitude", "Year", "Month")
    if oxy_col is None or any(name not in available for name in coord_cols):
        print(f"[SKIP-PRED] {os.path.basename(csv_path)} is missing required columns.")
        return [], None, None

    wanted = ["Year", "Month", "Latitude", "Longitude", oxy_col]
    frame = pd.read_csv(csv_path, usecols=wanted)
    frame = frame.replace([np.inf, -np.inf], np.nan)
    frame = frame.dropna(subset=wanted)
    if oxy_col != "Oxygen":
        frame = frame.rename(columns={oxy_col: "Oxygen"})

    frame["Longitude"] = normalize_lon_to_m180_180(frame["Longitude"].astype(float).values)
    lat_in_range = frame["Latitude"].between(-90, 90, inclusive="both")
    frame = frame[(lat_in_range)]
    for key in ("Year", "Month"):
        frame[key] = frame[key].astype(int)

    months = sorted({(y, m) for y, m in zip(frame["Year"].tolist(), frame["Month"].tolist())})
    return months, frame, oxy_col

# ================= Neighbor interpolation =================
def _deg2rad(x):
    """Return ``x`` (degrees) converted to radians as float64."""
    degrees = np.asarray(x, dtype=np.float64)
    return np.radians(degrees)

def points_idw_or_gauss(src_lat, src_lon, src_val, tgt_lat, tgt_lon,
                        kmax=12, radius_km=None, kernel="idw", power=2):
    """Interpolate scattered values to target points from their nearest neighbours.

    Great-circle (haversine) distances are used. Neighbours are found with a
    ``BallTree`` when ``HAVE_SKLEARN`` is true and at least two finite sources
    exist, otherwise with a brute-force distance matrix.

    Args:
        src_lat, src_lon, src_val: Source coordinates (degrees) and values;
            sources with non-finite values are ignored.
        tgt_lat, tgt_lon: Target coordinates (degrees).
        kmax: Maximum number of neighbours per target.
        radius_km: If given, neighbours farther than this get zero weight.
        kernel: ``"gauss"`` for Gaussian weights, anything else for IDW.
        power: IDW exponent (at least 1 is used).

    Returns:
        float32 array with one value per target; NaN where no neighbour has a
        non-zero weight, and everywhere if there are no finite sources.
    """
    src_lat, src_lon, src_val, tgt_lat, tgt_lon = (
        np.asarray(arr, np.float64)
        for arr in (src_lat, src_lon, src_val, tgt_lat, tgt_lon)
    )

    usable = np.isfinite(src_val)
    if tgt_lat.size == 0 or not usable.any():
        return np.full(tgt_lat.shape, np.nan, dtype=np.float32)
    src_lat, src_lon, src_val = src_lat[usable], src_lon[usable], src_val[usable]

    # (n, 2) arrays of [lat, lon] in radians, as expected by the haversine metric.
    src_rad = np.column_stack((_deg2rad(src_lat), _deg2rad(src_lon)))
    tgt_rad = np.column_stack((_deg2rad(tgt_lat), _deg2rad(tgt_lon)))
    n_src = src_rad.shape[0]
    k = max(1, min(int(kmax), n_src))

    if HAVE_SKLEARN and (n_src >= 2):
        ang_dist, nbr_idx = BallTree(src_rad, metric="haversine").query(tgt_rad, k=k)
        d_km = ang_dist * EARTH_R_KM
    else:
        # Brute force: full target-by-source haversine distance matrix.
        tgt_phi = tgt_rad[:, 0, None]
        src_phi = src_rad[None, :, 0]
        half_dlat = (tgt_phi - src_phi) / 2.0
        half_dlon = (tgt_rad[:, 1, None] - src_rad[None, :, 1]) / 2.0
        hav = (
            np.sin(half_dlat) ** 2 +
            np.cos(tgt_phi) * np.cos(src_phi) *
            np.sin(half_dlon) ** 2
        )
        d_all = 2.0 * EARTH_R_KM * np.arcsin(np.sqrt(hav))
        k = min(k, d_all.shape[1])
        # Pick the k smallest per row, then order those k by distance.
        candidates = np.argpartition(d_all, k - 1, axis=1)[:, :k]
        by_distance = np.take_along_axis(d_all, candidates, axis=1).argsort(axis=1)
        nbr_idx = np.take_along_axis(candidates, by_distance, axis=1)
        d_km = np.take_along_axis(d_all, nbr_idx, axis=1)
    neigh_vals = src_val[nbr_idx]

    if str(kernel).lower() == "gauss":
        # Length scale: half the radius, or half the median neighbour distance.
        length_km = radius_km if radius_km else float(np.nanmedian(d_km))
        sigma = max(length_km / 2.0, 1e-3)
        weights = np.exp(-(d_km / sigma) ** 2)
    else:
        weights = 1.0 / (d_km + 1e-6) ** max(1, int(power))

    if radius_km is not None:
        weights = np.where(d_km <= float(radius_km), weights, 0.0)

    weight_sum = weights.sum(axis=1)
    unsupported = (weight_sum == 0.0)
    # Divide by 1 where the sum is zero to avoid 0/0; those rows become NaN below.
    pred = (weights * neigh_vals).sum(axis=1) / np.where(unsupported, 1.0, weight_sum)
    pred[unsupported] = np.nan
    return pred.astype(np.float32)

def staged_fill(src_lat, src_lon, src_val, tgt_lat, tgt_lon,
                kmax, base_radius_km, max_radius_km, kernel, power):
    """Interpolate to the targets in up to two passes with a growing radius.

    The first pass uses ``base_radius_km``. Targets still non-finite are
    retried once with ``min(2 * base_radius_km, max_radius_km)``, provided
    that radius is actually larger than the first one.

    Returns:
        The array returned by the first pass, with second-pass values filled
        into the positions that were non-finite.
    """
    first_radius = float(base_radius_km)
    second_radius = float(min(2.0 * first_radius, max_radius_km))
    common = dict(kmax=kmax, kernel=kernel, power=power)

    filled = points_idw_or_gauss(src_lat, src_lon, src_val, tgt_lat, tgt_lon,
                                 radius_km=first_radius, **common)

    unresolved = ~np.isfinite(filled)
    retry_is_wider = second_radius > first_radius + 1e-6
    if not (unresolved.any() and retry_is_wider):
        return filled

    filled[unresolved] = points_idw_or_gauss(src_lat, src_lon, src_val,
                                             tgt_lat[unresolved], tgt_lon[unresolved],
                                             radius_km=second_radius, **common)
    return filled

# ================= Floor clamp =================
def apply_floor(A: np.ndarray, floor_val: float) -> np.ndarray:
    """Return a copy of ``A`` with every finite value raised to at least ``floor_val``.

    NaN and infinite entries are left as they are; the input is not modified.
    """
    clamped = A.copy()
    finite = np.isfinite(clamped)
    if not finite.any():
        return clamped
    lower_bound = float(floor_val)
    clamped[finite] = np.maximum(clamped[finite], lower_bound)
    return clamped

# ================= Child process: handle one (year, month) =================
def _process_one_month_task(task):
    """Grid, mask, smooth, fill and write one (year, month, depth) slice.

    ``task`` is a plain dict carrying the prediction points (``pred_lat``,
    ``pred_lon``, ``pred_oxy``), the target month/depth, the output directory
    and all tuning parameters. The steps are:

    1. bin the points onto the cached temperature grid (cell mean);
    2. build the allowed-cell mask (Mallow);
    3. keep binned values only inside Mallow, apply the floor;
    4. Gaussian smoothing, floor again;
    5. fill cells inside Mallow that are still non-finite (two-stage radius);
    6. blank everything outside Mallow, floor again;
    7. write the slice into the monthly NetCDF file.

    Returns:
        Dict with the keys ``yy``, ``mm``, ``depth``, ``out``, ``vmin``,
        ``vmax``, ``count_valid`` and ``grid``.

    Raises:
        RuntimeError: If an existing output file has axes that differ from the
            cached grid axes.
    """
    yy, mm = task['yy'], task['mm']
    depth_m = float(task['depth_m'])
    floor = float(task['min_value_floor'])

    plat = task['pred_lat'].astype(np.float64)
    plon = normalize_lon_to_m180_180(task['pred_lon'].astype(np.float64))
    poxy = task['pred_oxy'].astype(np.float64)

    load_temp_cache(task["temp_nc"])
    tlats = _TEMP_CACHE["lats"].astype(np.float32)
    tlons = _TEMP_CACHE["lons"].astype(np.float32)

    out_path = os.path.join(task['out_dir'], f"GLOBAL_DO_{yy:04d}{mm:02d}15_0p5deg_v1.nc")

    if os.path.exists(out_path):
        # Reuse the axes stored in the file, after checking they match the cache.
        lats_nc, lons_nc = ensure_month_file(out_path, tlats, tlons, yy, mm)
        axis_pairs = ((lats_nc, tlats), (lons_nc, tlons))
        if any(np.max(np.abs(from_file - cached)) > 1e-6 for from_file, cached in axis_pairs):
            raise RuntimeError(f"{out_path}: file axes do not match CMEMS axes.")
        target_lats = lats_nc.astype(np.float32)
        target_lons = lons_nc.astype(np.float32)
    else:
        target_lats, target_lons = tlats, tlons
        ensure_month_file(out_path, target_lats, target_lons, yy, mm)

    Ny, Nx = len(target_lats), len(target_lons)

    # 1) Bin to grid (mean)
    row_idx = nearest_index_nonuniform(target_lats, plat)
    col_idx = nearest_index_nonuniform(target_lons, plon)
    binned_mean, bin_count = bincount_mean(row_idx, col_idx, poxy, Ny, Nx)
    has_points = (bin_count > 0)

    # 2) Mallow mask
    Mallow = build_pred_mask_for_grid(
        depth_m=depth_m,
        target_lats=target_lats.astype(np.float64),
        target_lons=target_lons.astype(np.float64),
        temp_nc_path=task["temp_nc"],
        mask_nc_path=task["mask_nc"],
        shallow_mask_max_m=task["shallow_mask_max_m"],
        arctic_cut_lat=task["arctic_cut_lat"],
        ll_tol_deg=task["ll_tol_deg"],
        depth_tol=task["depth_tol"]
    )

    # 3) Masked background prediction field
    field = apply_floor(
        np.where(Mallow & has_points, binned_mean, np.nan).astype(np.float32),
        floor,
    )

    # 4) Gaussian smoothing (finite-only; footprint-preserving)
    field = apply_floor(
        gaussian_smooth_preserve(
            field, target_lats, target_lons,
            radius_km=task["gauss_radius_km"],
            min_support_frac=task["gauss_min_support"]
        ),
        floor,
    )

    # 5) Fill remaining NaNs inside Mallow (two-stage radius)
    have_value = np.isfinite(field)
    to_fill = Mallow & ~have_value
    if np.any(to_fill) and np.any(have_value):
        src_i, src_j = np.where(have_value)
        tgt_i, tgt_j = np.where(to_fill)

        # Cell (i, j) sits at (target_lats[i], target_lons[j]), so the axes can
        # be indexed directly instead of materialising full coordinate grids.
        estimate = staged_fill(
            target_lats[src_i], target_lons[src_j], field[src_i, src_j],
            target_lats[tgt_i], target_lons[tgt_j],
            kmax=task["fill_kmax"],
            base_radius_km=task["fill_base_radius_km"],
            max_radius_km=task["fill_max_radius_km"],
            kernel=task["fill_kernel"],
            power=task["fill_power"]
        )
        got = np.isfinite(estimate)
        if np.any(got):
            field[tgt_i[got], tgt_j[got]] = estimate[got]

    # 6) Final: valid only inside Mallow + floor
    V_out = apply_floor(np.where(Mallow, field, np.nan).astype(np.float32), floor)

    # 7) Write NetCDF
    write_or_append_depth_sorted(out_path, yy, mm, float(depth_m),
                                 target_lats, target_lons, V_out)

    finite_out = np.isfinite(V_out)
    if finite_out.any():
        vmin = float(np.nanmin(V_out))
        vmax = float(np.nanmax(V_out))
    else:
        vmin = vmax = np.nan

    return {
        "yy": yy, "mm": mm, "depth": depth_m,
        "out": out_path, "vmin": vmin, "vmax": vmax,
        "count_valid": int(finite_out.sum()),
        "grid": (Ny, Nx)
    }

# ================= Main =================
def main():
    """Command-line entry point: build monthly DO NetCDF files from prediction CSVs.

    For every requested depth and season the matching prediction CSV is read,
    split by (year, month), and each month is processed by
    ``_process_one_month_task`` in a process pool. Failures of individual
    months are reported and do not stop the remaining work. A summary line is
    printed at the end.

    Unknown command-line arguments are ignored (``parse_known_args``).
    """
    parser = argparse.ArgumentParser(
        description="Monthly DO NetCDF maker on CMEMS grid: mask(Mallow)→smooth→fill→floor→write (NO OBS).",
        allow_abbrev=False
    )
    parser.add_argument("--pred_root", default=PRED_ROOT)
    parser.add_argument("--out",       default=OUT_DIR)
    # NOTE: with required=True argparse errors out when --depths is absent, so
    # the default below is never used.
    parser.add_argument("--depths", nargs="*", default=DEPTH_LIST_DEFAULT, required=True,
                        help="Depth list must be provided explicitly, e.g.: --depths 10 20 50 100")
    parser.add_argument("--seasons", nargs="*", default=SEASONS_DEFAULT)

    parser.add_argument("--workers", type=int, default=min(24, os.cpu_count() or 1))

    parser.add_argument("--temp_nc", default=TEMP_NC_DEFAULT)
    parser.add_argument("--mask_nc", default=MASK_NC_DEFAULT)
    parser.add_argument("--shallow_mask_max_m", type=float, default=SHALLOW_MASK_APPLY_MAX_M_DEFAULT)
    parser.add_argument("--arctic_cut_lat", type=float, default=ARCTIC_CUT_LAT_DEFAULT)
    parser.add_argument("--ll_tol_deg", type=float, default=LL_TOL_DEG_DEFAULT)
    parser.add_argument("--depth_tol", type=float, default=DEPTH_TOL_DEFAULT)

    parser.add_argument("--gauss_radius_km", type=float, default=GAUSS_RADIUS_KM_DEFAULT)
    parser.add_argument("--gauss_min_support", type=float, default=MIN_SUPPORT_FRAC_DEFAULT)

    parser.add_argument("--fill_kernel", type=str, default=FILL_KERNEL_DEFAULT, choices=["idw", "gauss"])
    parser.add_argument("--fill_kmax", type=int, default=FILL_KMAX_DEFAULT)
    parser.add_argument("--fill_base_radius_km", type=float, default=BASE_RADIUS_KM_DEFAULT)
    parser.add_argument("--fill_max_radius_km", type=float, default=MAX_RADIUS_KM_DEFAULT)
    parser.add_argument("--fill_power", type=int, default=FILL_POWER_DEFAULT)

    parser.add_argument("--min_value_floor", type=float, default=MIN_VALUE_FLOOR_DEFAULT)

    args, _ = parser.parse_known_args()

    pred_root = str(args.pred_root)
    out_dir   = str(args.out)
    depths    = [str(d) for d in (args.depths or [])]
    seasons   = list(args.seasons or [])
    workers   = max(1, int(args.workers))

    temp_nc_path       = str(args.temp_nc)
    mask_nc_path       = str(args.mask_nc)
    shallow_mask_max_m = float(args.shallow_mask_max_m)
    arctic_cut_lat     = float(args.arctic_cut_lat)
    ll_tol_deg         = float(args.ll_tol_deg)
    depth_tol          = float(args.depth_tol)

    gauss_radius_km    = float(args.gauss_radius_km)
    gauss_min_support  = float(args.gauss_min_support)

    fill_kernel        = str(args.fill_kernel).lower()
    fill_kmax          = int(args.fill_kmax)
    fill_base_radius_km= float(args.fill_base_radius_km)
    fill_max_radius_km = float(args.fill_max_radius_km)
    fill_power         = int(args.fill_power)

    min_value_floor    = float(args.min_value_floor)

    os.makedirs(out_dir, exist_ok=True)

    # Prefer "fork" where the platform offers it; mp.get_context("fork")
    # raises ValueError elsewhere (e.g. Windows), so fall back to the
    # platform default start method in that case.
    mp_context = mp.get_context("fork") if "fork" in mp.get_all_start_methods() else None

    # Warm up TEMP cache
    load_temp_cache(temp_nc_path)

    # Keep your original month window logic (adjust as needed for your paper version)
    def _ym(y, m): return int(y) * 12 + int(m)
    START_YM = 1960 * 12 + 1
    END_YM   = 2030 * 12 + 12

    total_written = 0

    for depth in depths:
        pred_dir = os.path.join(pred_root, f"{depth}dbar")
        if not os.path.isdir(pred_dir):
            print(f"[WARN] Prediction directory not found: {pred_dir} (skip this depth)")
            continue

        for season in seasons:
            if season not in SEASON_MONTHS:
                print(f"[WARN] Unknown season name: {season} (skip)")
                continue

            # First existing candidate wins: remapped predictions are preferred.
            base_pred = os.path.join(pred_dir, f"depth{depth}_{season}_TRAIN")
            pred_candidates = [
                base_pred + "_with_pred_remap.csv",
                base_pred + "_with_pred.csv",
                base_pred + ".csv",
            ]
            pred_csv = next((p for p in pred_candidates if os.path.isfile(p)), None)
            if not pred_csv:
                print(f"[INFO] depth={depth} season={season}: no prediction file found; skip.")
                continue

            print(f"\n[PROC] depth={depth}m season={season}")
            print(f"  - pred: {os.path.basename(pred_csv)}")

            months_pred, df_pred, _ = read_pred_csv(pred_csv)

            # months_pred is empty (and df_pred is None) when the CSV was
            # rejected, so the check below also guards the groupby.
            months = [(y, m) for (y, m) in sorted(set(months_pred))
                      if START_YM <= _ym(y, m) <= END_YM]
            if not months:
                print("  - No months within the selection window for this season; skip.")
                continue

            pred_groups = {(y, m): g for (y, m), g in df_pred.groupby(["Year", "Month"], sort=False)}

            tasks = []
            for (yy, mm) in months:
                g_pred = pred_groups.get((yy, mm), None)
                if g_pred is None or g_pred.empty:
                    continue
                tasks.append({
                    "yy": int(yy), "mm": int(mm),
                    "depth_m": float(depth),
                    "out_dir": out_dir,
                    "pred_lat": g_pred["Latitude"].to_numpy(np.float64, copy=False),
                    "pred_lon": g_pred["Longitude"].to_numpy(np.float64, copy=False),
                    "pred_oxy": g_pred["Oxygen"].to_numpy(np.float64, copy=False),

                    "temp_nc": temp_nc_path,
                    "mask_nc": mask_nc_path,
                    "shallow_mask_max_m": shallow_mask_max_m,
                    "arctic_cut_lat": arctic_cut_lat,
                    "ll_tol_deg": ll_tol_deg,
                    "depth_tol": depth_tol,

                    "gauss_radius_km": gauss_radius_km,
                    "gauss_min_support": gauss_min_support,

                    "fill_kernel": fill_kernel,
                    "fill_kmax": fill_kmax,
                    "fill_base_radius_km": fill_base_radius_km,
                    "fill_max_radius_km": fill_max_radius_km,
                    "fill_power": fill_power,

                    "min_value_floor": min_value_floor,
                })

            if not tasks:
                print("  - No writable months for this season.")
                continue

            executor_kwargs = {"max_workers": workers}
            if mp_context is not None:
                executor_kwargs["mp_context"] = mp_context

            # Each task targets a different monthly file, so the workers of one
            # pool never write to the same NetCDF file.
            with ProcessPoolExecutor(**executor_kwargs) as ex:
                futs = {ex.submit(_process_one_month_task, t): (t["yy"], t["mm"]) for t in tasks}
                for fut in as_completed(futs):
                    yy_, mm_ = futs[fut]
                    try:
                        info = fut.result()
                        total_written += 1
                        print(
                            f"    [OK] {info['yy']}-{info['mm']:02d} "
                            f"depth={info['depth']:g}m -> {os.path.basename(info['out'])} "
                            f"grid={info['grid']} out[min,max]=[{info['vmin']:.3f},{info['vmax']:.3f}] "
                            f"valid={info['count_valid']}"
                        )
                    except Exception as e:
                        # Best effort: report the failed month and keep going.
                        print(f"    [ERR] {yy_}-{mm_:02d} depth={depth}m : {e}")

    if total_written == 0:
        print("\n[DONE] No writable data found (check input paths/naming/columns).")
    else:
        print(f"\n[DONE] Successfully wrote/updated {total_written} (month, depth) slices.")

if __name__ == "__main__":
    # Limit BLAS thread counts (avoid oversubscription under multiprocessing).
    # setdefault keeps any value the user already exported.
    # NOTE(review): numpy is imported near the top of the file, before these
    # variables are set; confirm that the threading libraries still honour them.
    for _thread_env_name in (
        "OMP_NUM_THREADS",
        "OPENBLAS_NUM_THREADS",
        "MKL_NUM_THREADS",
        "VECLIB_MAXIMUM_THREADS",
        "NUMEXPR_NUM_THREADS",
    ):
        os.environ.setdefault(_thread_env_name, "1")
    main()