In [1]:
#!/usr/bin/env python3
"""
Pigeon Lake data preparation — GCS-friendly, binary-ready (storage-client preferred)

Updates in this copy:
- Default CSV / tif_dir / out_dir point at the final_data_kgc2 bucket and use Pigeon Lake naming.
- upload_file_to_gcs now checks that the destination bucket exists and raises a
  clear error if not (instead of surfacing a raw 404 stack trace).
- All other behavior unchanged: falls back to CSV dates if no tifs, produces
  processed_dataset_binary.npz and joblibs named with "binary".

Run example:
python3 pigeon_prepare_binary_simple_gcs.py \
  --csv "gs://final_data_kgc2/Final_data/Pigeon_pixel_80m_binary.csv" \
  --tif_dir "gs://final_data_kgc2/Final_data/Pigeon_80_binary/" \
  --out_dir "gs://final_data_kgc2/models/Pigeon_binary/" \
  --random_state 0

If you prefer to write outputs to a local directory, pass --out_dir /tmp/pigeon_out
"""

import os
import re
import argparse
import numpy as np
import pandas as pd
from datetime import timedelta
import joblib
import sys
import tempfile
import shutil
import io
from pathlib import Path

# IterativeImputer
from sklearn.experimental import enable_iterative_imputer  # noqa: F401
from sklearn.impute import IterativeImputer
from sklearn.linear_model import BayesianRidge
from sklearn.preprocessing import StandardScaler

# Optional: GCS support (gcsfs)
try:
    import gcsfs
except Exception:
    gcsfs = None

# ---------------------------
# Defaults (all three paths live in the final_data_kgc2 bucket)
# ---------------------------
# CSV read by load_csv() when --csv is not supplied.
DEFAULT_CSV = "gs://final_data_kgc2/Final_data/Pigeon_pixel_80m_binary.csv"
# Directory whose .tif file names are scanned for dates when --tif_dir is not supplied.
SAT_TIF_DIR = "gs://final_data_kgc2/Final_data/Pigeon_80m/"
# Destination for the sample CSVs, the joblib artifacts and the .npz dataset.
DEFAULT_OUT_DIR = "gs://final_data_kgc2/models/Pigeon_binary/"

# Columns that build_samples() excludes from the numeric feature set.
META_COLS = ["Date", "X", "Y", "Pixel_ID"]
# Default target column name (overridable with --target_col).
TARGET_COL = "Chl_a"
BLOOM_THRESHOLD_DEFAULT = 20.0  # used only if binary label must be derived

# ---------------------------
# GCS / IO helpers (storage-client preferred)
# ---------------------------
def is_gcs_path(path: str) -> bool:
    """Return True when *path* is a string that begins with the ``gs://`` scheme."""
    if not isinstance(path, str):
        # Non-string inputs (None, Path objects, ...) are never GCS URIs.
        return False
    return path[:5] == "gs://"

def gcs_join(prefix: str, *parts: str) -> str:
    """Join *prefix* and *parts* with exactly one ``/`` between neighbours.

    Trailing slashes on the prefix and on every intermediate part are dropped,
    as are leading slashes on each part; trailing slashes on the final part
    are preserved. Parts are converted with ``str()``.
    """
    joined = prefix.rstrip('/')
    for piece in parts:
        tail = str(piece).lstrip('/')
        joined = f"{joined.rstrip('/')}/{tail}"
    return joined

def upload_file_to_gcs(local_path: str, gs_uri: str, requester_pays_project: str | None = None):
    """Upload ``local_path`` to the ``gs://`` URI ``gs_uri``.

    The google-cloud-storage client is used when it can be imported; otherwise
    the upload falls back to gcsfs (if installed).

    Parameters:
        local_path: path of the local file to upload.
        gs_uri: destination object, ``gs://<bucket>/<object path>``.
        requester_pays_project: optional project id, forwarded as ``project``
            to whichever client is used.
            NOTE(review): this only sets the client's project; whether that is
            sufficient for requester-pays billing should be confirmed against
            the google-cloud-storage / gcsfs documentation.

    Raises:
        ValueError: ``gs_uri`` is not a ``gs://`` path.
        RuntimeError: neither client library is available, or the upload failed.
        FileNotFoundError: the destination bucket lookup returned nothing.
    """
    if not is_gcs_path(gs_uri):
        raise ValueError("gs_uri must be a gs:// path")
    try:
        from google.cloud import storage
    except Exception as e:
        if gcsfs is None:
            raise RuntimeError("Install google-cloud-storage or gcsfs to upload to GCS") from e
        # Best-effort fallback to gcsfs; forward the project the same way
        # load_csv() and list_tif_files() do for their gcsfs paths.
        fs_kwargs = {}
        if requester_pays_project:
            fs_kwargs["project"] = requester_pays_project
        fs = gcsfs.GCSFileSystem(token="google", **fs_kwargs)
        fs.put(local_path, gs_uri)
        return

    client_kwargs = {}
    if requester_pays_project:
        client_kwargs["project"] = requester_pays_project
    client = storage.Client(**client_kwargs)

    _, rest = gs_uri.split("gs://", 1)
    bucket_name, _, blob_path = rest.partition("/")

    # Explicit bucket lookup so a missing bucket yields guidance instead of a raw 404 trace.
    bucket = client.lookup_bucket(bucket_name)
    if bucket is None:
        raise FileNotFoundError(
            f"Destination bucket does not exist: gs://{bucket_name}\n\n"
            "Fixes:\n"
            f" - Use an existing bucket (e.g. --out_dir gs://kgc2-arch/models/Pigeon_binary/),\n"
            " - Or create the bucket in your project with gsutil mb gs://<bucket>/,\n"
            " - Ensure the service account running this job has storage permissions on that bucket."
        )

    blob = bucket.blob(blob_path)
    try:
        blob.upload_from_filename(local_path)
    except Exception as e:
        # Re-raise with source and destination so the failing artifact is obvious.
        raise RuntimeError(f"Upload failed for {local_path} -> {gs_uri}: {e}") from e

def save_local_or_gcs(local_src_path: str, out_dir: str, dest_basename: str | None = None, requester_pays_project: str | None = None) -> str:
    """Deliver ``local_src_path`` to ``out_dir`` and return the destination path.

    Parameters:
        local_src_path: existing local file; it is consumed (moved or deleted).
        out_dir: local directory (created if missing) or ``gs://`` prefix.
        dest_basename: file name at the destination; defaults to the basename
            of ``local_src_path``.
        requester_pays_project: forwarded to ``upload_file_to_gcs``.

    Returns:
        The ``gs://`` URI or local path the file now lives at.
    """
    if dest_basename is None:
        dest_basename = os.path.basename(local_src_path)

    if not is_gcs_path(out_dir):
        os.makedirs(out_dir, exist_ok=True)
        dst = os.path.join(out_dir, dest_basename)
        shutil.move(local_src_path, dst)
        return dst

    gs_dest = gcs_join(out_dir, dest_basename)
    upload_file_to_gcs(local_src_path, gs_dest, requester_pays_project=requester_pays_project)
    try:
        os.remove(local_src_path)
    except OSError:
        # Removing the local copy is best-effort: the upload already succeeded,
        # so a leftover file must not fail the run. Only filesystem errors are
        # ignored; anything else is a programming error and should surface.
        pass
    return gs_dest

# ---------------------------
# Robust CSV loader + listing (storage-client preferred)
# ---------------------------
def _finalize_csv_frame(df: pd.DataFrame) -> pd.DataFrame:
    """Strip whitespace from column names and normalize ``Date`` to midnight timestamps."""
    df.columns = [c.strip() for c in df.columns]
    if "Date" not in df.columns:
        # Same exception type the bare df["Date"] lookup would raise, with a usable message.
        raise KeyError(f"CSV has no 'Date' column (columns: {list(df.columns)})")
    df["Date"] = pd.to_datetime(df["Date"]).dt.normalize()
    return df


def load_csv(csv_path: str, requester_pays_project: str | None = None, use_gcsfs: bool = False) -> pd.DataFrame:
    """Load the input CSV from a local path or from GCS.

    Parameters:
        csv_path: local file path or ``gs://bucket/object`` URI.
        requester_pays_project: optional project id passed as ``project`` to
            the storage client / gcsfs filesystem.
        use_gcsfs: if True, read through gcsfs directly. Otherwise the
            google-cloud-storage client is tried first and gcsfs is used only
            when that library cannot be imported.

    Returns:
        DataFrame with whitespace-stripped column names and a ``Date`` column
        converted to normalized (midnight) timestamps.

    Raises:
        RuntimeError: no usable GCS library is installed.
        FileNotFoundError: the bucket lookup or blob existence check failed.
        KeyError: the CSV has no ``Date`` column.
    """
    print(f"Loading CSV: {csv_path}  (use_gcsfs={use_gcsfs})")
    if not is_gcs_path(csv_path):
        return _finalize_csv_frame(pd.read_csv(csv_path))

    # Try google-cloud-storage client first (preferred)
    if not use_gcsfs:
        try:
            from google.cloud import storage
        except Exception as e:
            if gcsfs is None:
                raise RuntimeError("Install google-cloud-storage or gcsfs to read CSV from GCS") from e
            # storage client unavailable: fall through to the gcsfs branch below
            use_gcsfs = True

    if not use_gcsfs:
        _, rest = csv_path.split("gs://", 1)
        bucket_name, _, blob_path = rest.partition("/")
        client_kwargs = {}
        if requester_pays_project:
            client_kwargs["project"] = requester_pays_project
        client = storage.Client(**client_kwargs)
        bucket = client.lookup_bucket(bucket_name)
        if bucket is None:
            raise FileNotFoundError(f"Bucket not found or access denied: gs://{bucket_name}")
        blob = bucket.blob(blob_path)
        if not blob.exists():
            raise FileNotFoundError(f"CSV blob not found: {csv_path}")
        # The whole object is downloaded into memory before parsing.
        df = _finalize_csv_frame(pd.read_csv(io.BytesIO(blob.download_as_bytes())))
        print("Loaded CSV via google-cloud-storage client")
        return df

    # If we fall through here, attempt gcsfs
    if gcsfs is None:
        raise RuntimeError("gcsfs is not installed; install gcsfs or enable google-cloud-storage")
    try:
        fs_kwargs = {}
        if requester_pays_project:
            fs_kwargs["project"] = requester_pays_project
        fs = gcsfs.GCSFileSystem(token="google", **fs_kwargs)
        with fs.open(csv_path, "rb") as f:
            df = pd.read_csv(f)
        df = _finalize_csv_frame(df)
        print("Loaded CSV via gcsfs")
        return df
    except Exception as e:
        # Log a short hint, then let the original exception propagate unchanged.
        print("gcsfs read failed (and google-cloud-storage fallback wasn't used):", repr(e))
        raise

def list_tif_files(tif_dir: str, requester_pays_project: str | None = None, use_gcsfs: bool = False) -> list:
    """Return a sorted list of ``.tif`` paths found under ``tif_dir``.

    Parameters:
        tif_dir: local directory or ``gs://bucket[/prefix]`` location.
        requester_pays_project: optional project id passed as ``project`` to
            the storage client / gcsfs filesystem.
        use_gcsfs: if True, list through gcsfs; otherwise the
            google-cloud-storage client is preferred and gcsfs is used only
            when that library cannot be imported.

    Returns:
        Sorted list of local paths, or of ``gs://`` URIs for GCS locations.
        The extension match is case-insensitive for the local and storage-client
        paths; the gcsfs path globs ``*.tif``.

    Raises:
        RuntimeError: no usable GCS library is installed.
        FileNotFoundError: the bucket lookup returned nothing.
    """
    if not is_gcs_path(tif_dir):
        return sorted(os.path.join(tif_dir, f) for f in os.listdir(tif_dir) if f.lower().endswith('.tif'))

    # Prefer google-cloud-storage client
    if not use_gcsfs:
        try:
            from google.cloud import storage
        except Exception as e:
            if gcsfs is None:
                raise RuntimeError("Install google-cloud-storage or gcsfs to list GCS tifs") from e
            use_gcsfs = True

    if not use_gcsfs:
        _, rest = tif_dir.split("gs://", 1)
        bucket_name, _, prefix = rest.partition("/")
        client_kwargs = {}
        if requester_pays_project:
            client_kwargs["project"] = requester_pays_project
        client = storage.Client(**client_kwargs)
        bucket = client.lookup_bucket(bucket_name)
        if bucket is None:
            raise FileNotFoundError(f"Bucket not found: gs://{bucket_name}")
        # Normalise to "dir/" so "Pigeon_80m" does not also match "Pigeon_80m_old/...".
        # A bucket-root tif_dir must keep an EMPTY prefix: "/" would match no objects.
        prefix = prefix.rstrip('/')
        if prefix:
            prefix += '/'
        # NOTE(review): listing by prefix presumably also returns objects in nested
        # "sub-folders", whereas the local and gcsfs branches look at one level only — confirm.
        blobs = client.list_blobs(bucket_name, prefix=prefix)
        return sorted(
            "gs://" + bucket_name + "/" + b.name
            for b in blobs
            if b.name.lower().endswith('.tif')
        )

    # fallback to gcsfs
    if gcsfs is None:
        raise RuntimeError("gcsfs is not installed; install gcsfs or enable google-cloud-storage")
    fs_kwargs = {}
    if requester_pays_project:
        fs_kwargs["project"] = requester_pays_project
    fs = gcsfs.GCSFileSystem(token="google", **fs_kwargs)
    pattern = tif_dir.rstrip('/') + '/*.tif'
    # glob results may come back without the scheme; re-add it so callers always get gs:// URIs
    return sorted(f if f.startswith("gs://") else "gs://" + f for f in fs.glob(pattern))

def parse_sat_dates_from_tifs(tif_dir: str, requester_pays_project: str | None = None, use_gcsfs: bool = False) -> list:
    """Extract the unique acquisition dates encoded in the .tif file names under ``tif_dir``.

    Each base name is searched first for ``YYYY-MM-DD`` and then for an
    8-digit ``YYYYMMDD`` run; the first pattern that both matches and parses
    wins. Files whose names yield no parseable date are skipped silently.

    Returns:
        Sorted list of normalized ``pd.Timestamp`` objects (no duplicates).
    """
    tif_files = list_tif_files(tif_dir, requester_pays_project=requester_pays_project, use_gcsfs=use_gcsfs)
    print(f"Found {len(tif_files)} .tif files in {tif_dir}")

    # (pattern, strptime format); None lets pandas infer the ISO format.
    matchers = (
        (re.compile(r'(\d{4}-\d{2}-\d{2})'), None),
        (re.compile(r'(\d{8})'), "%Y%m%d"),
    )

    unique_dates = set()
    for tif_path in tif_files:
        file_name = os.path.basename(tif_path)
        for regex, date_format in matchers:
            hit = regex.search(file_name)
            if hit is None:
                continue
            try:
                stamp = pd.to_datetime(hit.group(1), format=date_format).normalize()
            except Exception:
                # Matched digits that are not a valid date: try the next pattern.
                continue
            unique_dates.add(pd.Timestamp(stamp))
            break

    ordered = sorted(unique_dates)
    print(f"Parsed {len(ordered)} unique satellite dates.")
    return ordered

# ---------------------------
# Data engineering (binary-aware)
# ---------------------------
def filter_to_sat_days(df: pd.DataFrame, sat_dates: list) -> pd.DataFrame:
    """Return a copy of the rows of ``df`` whose ``Date`` is one of ``sat_dates``.

    The row counts before and after filtering are printed.
    """
    n_before = len(df)
    on_sat_day = df["Date"].isin(set(sat_dates))
    df_sat = df.loc[on_sat_day].copy()
    n_after = len(df_sat)
    print(f"Filtered to satellite days: {n_before} → {n_after}")
    return df_sat

def build_samples(df_sat: pd.DataFrame, sat_dates: list, meta_cols: list, target_col: str, hist_days=14, fut_days=10, threshold=BLOOM_THRESHOLD_DEFAULT):
    """Build one forecasting sample per (Pixel_ID, Date) row that has a usable future target.

    For every pixel and every date ``t`` of that pixel:
      * history features are the column means over the pixel's rows dated in
        ``(t - hist_days, t]`` (the window includes ``t`` itself);
      * the target is taken from the pixel's earliest row dated in
        ``(t, t + fut_days]``; the sample is dropped if no such row exists or
        its ``target_col`` value is NaN;
      * the binary label comes from a ``y_bin`` column when present and
        castable to int, otherwise from the target value itself if it is 0/1,
        otherwise from ``target >= threshold``.

    Parameters:
        df_sat: input rows; must contain ``Pixel_ID`` and ``Date`` columns.
        sat_dates: satellite dates (converted and sorted, but not used further
            inside this function).
        meta_cols: column names excluded from the numeric feature set.
        target_col: name of the target column.
        hist_days: length of the history window in days.
        fut_days: length of the future search window in days.
        threshold: cut-off used when the binary label must be derived.

    Returns:
        ``(samples_df, feature_cols)`` where ``feature_cols`` is every column of
        ``samples_df`` except Pixel_ID, date_t, date_target, y_raw and y_bin.
    """
    # Numeric columns that are neither metadata nor target/label become history features.
    excluded = set(meta_cols + [target_col, "y_bin"])
    num_cols = [c for c in df_sat.columns if c not in excluded and pd.api.types.is_numeric_dtype(df_sat[c])]
    # NOTE: the sorted dates are not referenced again below.
    sat_dates = sorted(pd.to_datetime(sat_dates))
    df_indexed = df_sat.set_index(["Pixel_ID", "Date"])
    samples = []
    dropped_no_history = dropped_no_future = 0
    for pid in df_indexed.index.get_level_values("Pixel_ID").unique():
        # All rows of this pixel, indexed by Date only.
        idx = df_indexed.loc[pid]
        pid_dates = sorted(idx.index)
        for t in pid_dates:
            t_ts = pd.Timestamp(t)
            # History window: strictly after (t - hist_days), up to and including t.
            hist_start = t_ts - timedelta(days=hist_days)
            hist_rows = idx.loc[(idx.index > hist_start) & (idx.index <= t_ts)]
            if hist_rows.empty:
                dropped_no_history += 1
                continue
            # Mean of each feature (and of the target, if present) over the window, ignoring NaNs.
            hist_means = hist_rows[num_cols + ([target_col] if target_col in idx.columns else [])].mean(axis=0, skipna=True)
            # Future window: strictly after t, up to and including t + fut_days.
            fut_end = t_ts + timedelta(days=fut_days)
            fut_candidates = [s for s in pid_dates if t_ts < s <= fut_end]
            if not fut_candidates:
                dropped_no_future += 1
                continue
            # Earliest future date of this pixel is the prediction target date.
            s_star = min(fut_candidates)
            # assumes a single row per (Pixel_ID, Date); duplicates would make this lookup
            # return a Series instead of a scalar — TODO confirm upstream de-duplication
            y_raw = idx.loc[s_star, target_col] if target_col in idx.columns else np.nan
            if pd.isna(y_raw):
                # A missing target value is counted together with "no future row".
                dropped_no_future += 1
                continue
            # Label priority: explicit y_bin column -> 0/1-valued target -> threshold on target.
            y_bin = None
            if "y_bin" in idx.columns:
                try:
                    y_bin = int(idx.loc[s_star, "y_bin"])
                except Exception:
                    y_bin = None
            if y_bin is None:
                if np.isin(y_raw, [0, 1]) or (pd.api.types.is_integer_dtype(type(y_raw)) and (y_raw in (0, 1))):
                    try:
                        y_bin = int(y_raw)
                    except Exception:
                        y_bin = int(y_raw >= threshold)
                else:
                    y_bin = int(y_raw >= threshold)
            # Row at time t supplies coordinates and the current target value.
            row_meta = idx.loc[t_ts]
            feature_dict = {
                "Pixel_ID": pid,
                "date_t": t_ts,
                "date_target": s_star,
                "horizon": (s_star - t_ts).days,
                "X": row_meta.get("X", np.nan),
                "Y": row_meta.get("Y", np.nan),
                "chl_a_t": row_meta.get(target_col, np.nan),
                "y_raw": y_raw,
                "y_bin": y_bin
            }
            # History means are stored under a "histmean_<column>" key.
            for c in hist_means.index:
                feature_dict[f"histmean_{c}"] = hist_means[c]
            samples.append(feature_dict)
    print(f"Built {len(samples)} samples (dropped_no_history={dropped_no_history}, dropped_no_future={dropped_no_future})")
    samples_df = pd.DataFrame(samples)
    # Everything except identifiers, dates and targets is treated as a model feature.
    feature_cols = [c for c in samples_df.columns if c not in ["Pixel_ID", "date_t", "date_target", "y_raw", "y_bin"]]
    return samples_df, feature_cols

def strict_drop_and_split(samples_df: pd.DataFrame, feature_cols: list, out_dir: str, requester_pays_project: str | None = None):
    """Drop incomplete samples, split them by calendar year, and save the four CSVs.

    Rows with a NaN in any feature column or in ``y_raw`` are removed. The
    remainder is split on the year of ``date_t``: <= 2022 is train, 2023 is
    validation, 2024 is test. All four tables are written to a scratch
    directory and then delivered to ``out_dir`` via ``save_local_or_gcs``.

    Returns:
        ``(train_df, val_df, test_df)``
    """
    n_before = len(samples_df)
    required_cols = feature_cols + ["y_raw"]
    samples_clean = samples_df.dropna(subset=required_cols).copy()
    print(f"Dropped {n_before - len(samples_clean)} samples due to NaNs in features or y_raw")

    samples_clean["year_t"] = samples_clean["date_t"].dt.year
    year = samples_clean["year_t"]
    train_df = samples_clean[year <= 2022].copy()
    val_df = samples_clean[year == 2023].copy()
    test_df = samples_clean[year == 2024].copy()

    # Insertion order fixes both the write order and the upload order.
    outputs = {
        "samples_all_strict.csv": samples_clean,
        "samples_train.csv": train_df,
        "samples_val.csv": val_df,
        "samples_test.csv": test_df,
    }

    staging_dir = tempfile.mkdtemp(prefix="pigeon_samples_")
    try:
        # Phase 1: write every CSV locally; phase 2: move/upload them one by one.
        for fname, frame in outputs.items():
            frame.to_csv(os.path.join(staging_dir, fname), index=False)
        for fname in outputs:
            staged = os.path.join(staging_dir, fname)
            save_path = save_local_or_gcs(staged, out_dir, dest_basename=fname, requester_pays_project=requester_pays_project)
            print("Saved:", save_path)
    finally:
        shutil.rmtree(staging_dir, ignore_errors=True)

    print(f"Split → train={len(train_df)}, val={len(val_df)}, test={len(test_df)}")
    return train_df, val_df, test_df

def _write_then_save(write_fn, suffix: str, out_dir: str, dest_basename: str, requester_pays_project: str | None = None) -> str:
    """Write an artifact to a temp file via ``write_fn(path)`` and deliver it with save_local_or_gcs."""
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    tmp.close()
    try:
        write_fn(tmp.name)
        return save_local_or_gcs(tmp.name, out_dir, dest_basename=dest_basename, requester_pays_project=requester_pays_project)
    finally:
        # On success the file has normally been moved or removed already; this
        # guarantees the temp file does not linger when dumping or saving fails.
        try:
            os.remove(tmp.name)
        except OSError:
            pass


def fit_imputer_and_scaler(train_df: pd.DataFrame, val_df: pd.DataFrame, test_df: pd.DataFrame, feature_cols: list, out_dir: str, random_state=0, requester_pays_project: str | None = None):
    """Impute and standardize the features, then persist the fitted objects and the dataset.

    Steps:
      1. Fit an ``IterativeImputer`` (BayesianRidge estimator) on the training
         features and apply it to all three splits.
      2. Fit a ``StandardScaler`` on the imputed training columns whose name
         contains ``histmean_`` or equals ``chl_a_t`` and apply it to all splits.
      3. Save the imputer and scaler as joblib files and the arrays as
         ``processed_dataset_binary.npz`` under ``out_dir``.

    Continuous targets are ``log1p(y_raw)``; binary targets are ``y_bin`` as int.
    ``feature_cols`` is copied, so the caller's list is left untouched.

    Returns:
        ``(df_train_imp, df_val_imp, df_test_imp, y_train, y_val, y_test,
        y_train_bin, y_val_bin, y_test_bin)``
    """
    # Work on a copy: appending to the argument would silently alter the caller's list.
    feature_cols = list(feature_cols)
    if "chl_a_t" not in feature_cols:
        feature_cols.append("chl_a_t")
        print("Added chl_a_t to feature list.")
    X_train = train_df[feature_cols].values
    X_val = val_df[feature_cols].values
    X_test = test_df[feature_cols].values

    # The imputer is fitted on the training split only and re-used for val/test.
    imputer = IterativeImputer(estimator=BayesianRidge(), max_iter=100, random_state=random_state)
    X_train_imp = imputer.fit_transform(X_train)
    X_val_imp = imputer.transform(X_val)
    X_test_imp = imputer.transform(X_test)
    imputer_dest = _write_then_save(
        lambda path: joblib.dump(imputer, path),
        ".joblib", out_dir, "iterative_imputer_binary.joblib",
        requester_pays_project=requester_pays_project,
    )
    print("Saved imputer to:", imputer_dest)

    df_train_imp = pd.DataFrame(X_train_imp, columns=feature_cols)
    df_val_imp = pd.DataFrame(X_val_imp, columns=feature_cols)
    df_test_imp = pd.DataFrame(X_test_imp, columns=feature_cols)

    cols_to_standardize = [c for c in feature_cols if ("histmean_" in c) or (c == "chl_a_t")]
    print("Standardizing columns:", cols_to_standardize)
    scaler = StandardScaler()
    if cols_to_standardize:
        # The scaler is likewise fitted on the training split only.
        scaler.fit(df_train_imp[cols_to_standardize])
        for frame in (df_train_imp, df_val_imp, df_test_imp):
            frame[cols_to_standardize] = scaler.transform(frame[cols_to_standardize])
    scaler_dest = _write_then_save(
        lambda path: joblib.dump(scaler, path),
        ".joblib", out_dir, "standard_scaler_binary.joblib",
        requester_pays_project=requester_pays_project,
    )
    print("Saved scaler to:", scaler_dest)

    y_train = np.log1p(train_df["y_raw"].values)
    y_val = np.log1p(val_df["y_raw"].values)
    y_test = np.log1p(test_df["y_raw"].values)
    y_train_bin = train_df["y_bin"].astype(int).values
    y_val_bin = val_df["y_bin"].astype(int).values
    y_test_bin = test_df["y_bin"].astype(int).values

    dest_npz = _write_then_save(
        lambda path: np.savez_compressed(
            path,
            X_train=df_train_imp.values, X_val=df_val_imp.values, X_test=df_test_imp.values,
            y_train=y_train, y_val=y_val, y_test=y_test,
            y_train_bin=y_train_bin, y_val_bin=y_val_bin, y_test_bin=y_test_bin,
            feature_cols=feature_cols,
        ),
        ".npz", out_dir, "processed_dataset_binary.npz",
        requester_pays_project=requester_pays_project,
    )
    print("Saved processed dataset to:", dest_npz)
    return df_train_imp, df_val_imp, df_test_imp, y_train, y_val, y_test, y_train_bin, y_val_bin, y_test_bin

# ---------------------------
# Main
# ---------------------------
def main(args):
    """Run the full data-preparation pipeline for the parsed command-line ``args``.

    Loads the CSV, determines the satellite dates (from .tif file names, or
    from the CSV's own dates when no .tif dates are found), builds samples,
    splits them by year, and fits/saves the imputer, scaler and dataset.
    Exits with status 1 when no dates or no samples are available.
    """
    use_gcsfs = bool(args.use_gcsfs)
    # load CSV first
    df = load_csv(args.csv, requester_pays_project=args.requester_pays_project, use_gcsfs=use_gcsfs)

    # Try to parse sentinel dates from tif files. If none found, fall back to using CSV dates.
    sat_dates = parse_sat_dates_from_tifs(args.tif_dir, requester_pays_project=args.requester_pays_project, use_gcsfs=use_gcsfs)
    if not sat_dates:
        # Fallback: use unique dates in CSV (useful when tifs are not present or data is already pre-aligned)
        csv_unique_dates = sorted(df["Date"].dropna().unique())
        if not csv_unique_dates:
            print("ERROR: No satellite .tif dates found and CSV contains no valid dates. Cannot proceed.", file=sys.stderr)
            sys.exit(1)
        sat_dates = [pd.Timestamp(d).normalize() for d in csv_unique_dates]
        print(f"Warning: No .tif files found in {args.tif_dir}. Falling back to using unique dates from CSV ({len(sat_dates)} dates).")

    df_sat = filter_to_sat_days(df, sat_dates)
    samples_df, feature_cols = build_samples(df_sat, sat_dates, META_COLS, args.target_col, hist_days=args.hist_days, fut_days=args.fut_days, threshold=args.threshold)

    # If building samples yields none, exit with helpful message instead of raising KeyError later
    if samples_df.empty:
        print("ERROR: No samples were built. This can happen when:\n"
              " - the tif_dir path is incorrect or contains no .tif files, or\n"
              " - the CSV and tif dates do not overlap, or\n"
              " - the CSV has missing y_raw values for future horizons.\n"
              "Please verify the tif_dir and csv paths, or re-run data prep with --use_gcsfs to debug gcsfs listing.\n", file=sys.stderr)

        def _ls_hint(path: str) -> str:
            # Suggest a command that can be pasted as-is: gsutil for gs:// URIs, plain ls otherwise.
            tool = "gsutil ls" if is_gcs_path(path) else "ls"
            return f"  {tool} {path}"

        # The previous hint spliced the bucket name after a literal "gs://<bucket>/",
        # producing a path that could never exist; print the real locations instead.
        print("Quick checks you can run:\n"
              f"{_ls_hint(args.tif_dir)}  # list objects under tif_dir\n"
              f"{_ls_hint(args.csv)}  # confirm the CSV exists\n", file=sys.stderr)
        sys.exit(1)

    train_df, val_df, test_df = strict_drop_and_split(samples_df, feature_cols, args.out_dir, requester_pays_project=args.requester_pays_project)
    fit_imputer_and_scaler(train_df, val_df, test_df, feature_cols, args.out_dir, args.random_state, requester_pays_project=args.requester_pays_project)

if __name__ == "__main__":
    # Command-line entry point: every option has a default, so the script also
    # runs with no arguments (e.g. when executed as a notebook cell).
    parser = argparse.ArgumentParser(description="Pigeon Lake Data Preparation (binary-aware, GCS-friendly; storage-client preferred)")
    parser.add_argument("--csv", type=str, default=DEFAULT_CSV, help="Path to CSV (local or gs://)")
    parser.add_argument("--tif_dir", type=str, default=SAT_TIF_DIR, help="Directory of .tif files (local or gs://)")
    parser.add_argument("--out_dir", type=str, default=DEFAULT_OUT_DIR, help="Output directory (local path or gs:// prefix)")
    parser.add_argument("--target_col", type=str, default=TARGET_COL, help="Name of continuous target column in CSV (if present)")
    parser.add_argument("--hist_days", type=int, default=14, help="History window (days)")
    parser.add_argument("--fut_days", type=int, default=10, help="Horizon search window (days)")
    parser.add_argument("--threshold", type=float, default=BLOOM_THRESHOLD_DEFAULT, help="Threshold to derive binary label if needed")
    parser.add_argument("--random_state", type=int, default=0)
    parser.add_argument("--requester_pays_project", type=str, default=None, help="GCP project id to use for requester-pays buckets (optional)")
    parser.add_argument("--use_gcsfs", action="store_true", help="Attempt gcsfs first (may show 401 traces if not authenticated). Default = use storage client.")
    # Jupyter-safe argparse: a notebook kernel has no __file__ and is launched
    # with its own flags, so unknown arguments are tolerated there.
    if "__file__" not in globals():
        # Skip argv[0] (the program name), matching what parse_args() does by default.
        args, _ = parser.parse_known_args(sys.argv[1:])
    else:
        args = parser.parse_args()
    main(args)

Loading CSV: gs://final_data_kgc2/Final_data/Pigeon_pixel_80m_binary.csv  (use_gcsfs=False)
Loaded CSV via google-cloud-storage client
Found 281 .tif files in gs://final_data_kgc2/Final_data/Pigeon_80m/
Parsed 281 unique satellite dates.
Filtered to satellite days: 2364855 → 2364855
Built 2214239 samples (dropped_no_history=0, dropped_no_future=150616)
Dropped 468717 samples due to NaNs in features or y_raw
Saved: gs://final_data_kgc2/models/Pigeon_binary/samples_all_strict.csv
Saved: gs://final_data_kgc2/models/Pigeon_binary/samples_train.csv
Saved: gs://final_data_kgc2/models/Pigeon_binary/samples_val.csv
Saved: gs://final_data_kgc2/models/Pigeon_binary/samples_test.csv
Split → train=948281, val=74421, test=495447
Saved imputer to: gs://final_data_kgc2/models/Pigeon_binary/iterative_imputer_binary.joblib
Standardizing columns: ['chl_a_t', 'histmean_Day', 'histmean_Month', 'histmean_Year', 'histmean_ODO mg/L', 'histmean_ODO % local', 'histmean_Turbidity NTU', 'histmean_Water Temperatu

In [2]:
#!/usr/bin/env python3
"""
Train RandomForest classifier for binary bloom/no-bloom — GCS-ready final variant
with optional tagging of output artifacts to avoid overwriting when running multiple lakes.

This patched version fixes a bug that caused a ValueError when constructing the
per-pixel grid prediction CSV (arrays of different lengths). It also ensures the
"enriched" predictions CSV contains date_target, pixel locations (xi, yi, pixel_lon,
pixel_lat) and observed/predicted values.

Usage is unchanged; outputs are tagged by --tag (default "nak").
"""
from __future__ import annotations
import os
import argparse
import json
import tempfile
import shutil
import time
import io
from pathlib import Path
from typing import Optional, Tuple

import numpy as np
import pandas as pd

from sklearn.model_selection import TimeSeriesSplit, PredefinedSplit, ParameterSampler
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
    precision_score, recall_score, f1_score, roc_auc_score, average_precision_score,
    brier_score_loss, confusion_matrix
)
from sklearn.base import clone
import joblib

# joblib parallel backend
from joblib import parallel_backend

# optional imports
try:
    import gcsfs
except Exception:
    gcsfs = None

try:
    from tqdm.auto import tqdm
except Exception:
    tqdm = None

# Try import google-cloud-storage (preferred)
try:
    from google.cloud import storage
except Exception:
    storage = None

import warnings
warnings.filterwarnings("ignore", message="pkg_resources is deprecated as an API")
warnings.filterwarnings("ignore", category=FutureWarning)

# ---------------------------
# Defaults pointing to your bucket
# ---------------------------
# Processed dataset archive; load_npz() accepts local paths and gs:// URIs.
DEFAULT_NPZ = "gs://final_data_kgc2/models/Pigeon_binary/processed_dataset_binary.npz"
# Directory expected to contain samples_train.csv / samples_val.csv / samples_test.csv.
# NOTE(review): the data-preparation run logged earlier in this notebook saved those
# CSVs under gs://final_data_kgc2/models/Pigeon_binary/ — confirm this default is intended.
DEFAULT_SAMPLES_DIR = "gs://final_data_kgc2/Final_data/"
# Default destination prefix for this script's outputs.
DEFAULT_OUT_DIR = "gs://final_data_kgc2/rf_results/"
# Module-level seed value; its consumers are outside this excerpt.
RANDOM_STATE = 0

# ---------------------------
# GCS / local helpers (storage client preferred)
# ---------------------------
def is_gcs_path(p: str) -> bool:
    """Tell whether *p* is a string that starts with the ``gs://`` scheme."""
    return p.startswith("gs://") if isinstance(p, str) else False

def gcs_join(prefix: str, *parts: str) -> str:
    """Concatenate *prefix* and *parts* with single ``/`` separators.

    Slashes at each junction are collapsed to one; trailing slashes on the
    final part survive. With no parts, the prefix is returned without its
    trailing slashes. Parts are converted with ``str()``.
    """
    if not parts:
        return prefix.rstrip('/')
    path = prefix
    for segment in map(str, parts):
        path = "/".join((path.rstrip('/'), segment.lstrip('/')))
    return path

def upload_file_to_gcs(local_path: str, gs_uri: str, requester_pays_project: str | None = None):
    """Upload a local file to a ``gs://`` URI.

    Uses the google-cloud-storage client when the module-level ``storage``
    import succeeded, otherwise gcsfs when available.

    Parameters:
        local_path: path of the file to upload.
        gs_uri: destination object, ``gs://<bucket>/<object path>``.
        requester_pays_project: optional project id, forwarded as ``project``
            to whichever client performs the upload.
            NOTE(review): this only sets the client's project; whether it is
            sufficient for requester-pays billing should be confirmed against
            the google-cloud-storage / gcsfs documentation.

    Raises:
        ValueError: ``gs_uri`` is not a ``gs://`` path.
        FileNotFoundError: the destination bucket lookup returned nothing.
        RuntimeError: neither google-cloud-storage nor gcsfs is installed.
    """
    if not is_gcs_path(gs_uri):
        raise ValueError("gs_uri must be gs:// path")

    # prefer google-cloud-storage
    if storage is not None:
        client_kwargs = {}
        if requester_pays_project:
            client_kwargs["project"] = requester_pays_project
        client = storage.Client(**client_kwargs)
        _, rest = gs_uri.split("gs://", 1)
        bucket_name, _, blob_path = rest.partition("/")
        bucket = client.lookup_bucket(bucket_name)
        if bucket is None:
            raise FileNotFoundError(f"Destination bucket does not exist or is not accessible: gs://{bucket_name}")
        bucket.blob(blob_path).upload_from_filename(local_path)
        return

    # fallback to gcsfs
    if gcsfs is not None:
        # Forward the project here too, so both code paths honour the argument.
        fs_kwargs = {}
        if requester_pays_project:
            fs_kwargs["project"] = requester_pays_project
        fs = gcsfs.GCSFileSystem(token="google", **fs_kwargs)
        fs.put(local_path, gs_uri)
        return

    raise RuntimeError("No method available to upload to GCS: install google-cloud-storage or gcsfs")

def fetch_gcs_to_local(gcs_path: str, local_path: str, requester_pays_project: str | None = None, use_gcsfs: bool = False):
    """Download the object at ``gcs_path`` to ``local_path`` and return ``local_path``.

    Parameters:
        gcs_path: source object, ``gs://<bucket>/<object path>``.
        local_path: destination file path (overwritten).
        requester_pays_project: optional project id, forwarded as ``project``
            to whichever client performs the download.
        use_gcsfs: if True, skip the storage client and download through gcsfs.

    Raises:
        ValueError: ``gcs_path`` is not a ``gs://`` path.
        FileNotFoundError: bucket lookup or blob existence check failed
            (storage-client path).
        RuntimeError: gcsfs is required but not installed.
    """
    if not is_gcs_path(gcs_path):
        raise ValueError("gcs path required")

    if storage is not None and not use_gcsfs:
        client_kwargs = {}
        if requester_pays_project:
            client_kwargs["project"] = requester_pays_project
        client = storage.Client(**client_kwargs)
        _, rest = gcs_path.split("gs://", 1)
        bucket_name, _, blob_path = rest.partition("/")
        bucket = client.lookup_bucket(bucket_name)
        if bucket is None:
            raise FileNotFoundError(f"Bucket not found or access denied: gs://{bucket_name}")
        blob = bucket.blob(blob_path)
        if not blob.exists():
            raise FileNotFoundError(f"Blob not found: {gcs_path}")
        blob.download_to_filename(local_path)
        return local_path

    # fallback to gcsfs
    if gcsfs is None:
        raise RuntimeError("gcsfs not installed and storage client not available")
    # Forward the project here too, so both code paths honour the argument.
    fs_kwargs = {}
    if requester_pays_project:
        fs_kwargs["project"] = requester_pays_project
    fs = gcsfs.GCSFileSystem(token="google", **fs_kwargs)
    with fs.open(gcs_path, "rb") as src, open(local_path, "wb") as dst:
        shutil.copyfileobj(src, dst)
    return local_path

def load_npz(npz_path: str, requester_pays_project: str | None = None, use_gcsfs: bool = False) -> dict:
    """Load every array of an ``.npz`` archive into a plain dict.

    Parameters:
        npz_path: local path or ``gs://`` URI. GCS objects are first
            downloaded to a temporary file, which is removed afterwards.
        requester_pays_project: forwarded to ``fetch_gcs_to_local``.
        use_gcsfs: forwarded to ``fetch_gcs_to_local``.

    Returns:
        ``{array_name: ndarray}`` for all members of the archive.

    SECURITY NOTE: ``allow_pickle=True`` lets the archive execute arbitrary
    code while object arrays are unpickled; only load archives you trust.
    """
    def _read_all(path: str) -> dict:
        # The context manager closes the underlying zip file; arrays are read
        # into memory inside the block, so the result stays valid afterwards.
        with np.load(path, allow_pickle=True) as archive:
            return {key: archive[key] for key in archive.files}

    if not is_gcs_path(npz_path):
        return _read_all(npz_path)

    tmp = tempfile.NamedTemporaryFile(suffix=".npz", delete=False)
    tmp.close()
    try:
        fetch_gcs_to_local(npz_path, tmp.name, requester_pays_project=requester_pays_project, use_gcsfs=use_gcsfs)
        return _read_all(tmp.name)
    finally:
        # Always remove the downloaded copy, including when the download or load fails.
        try:
            os.unlink(tmp.name)
        except OSError:
            pass

def load_samples_csvs(samples_dir: str, requester_pays_project: str | None = None, use_gcsfs: bool = False):
    """Load ``samples_train.csv``, ``samples_val.csv`` and ``samples_test.csv`` from ``samples_dir``.

    Parameters:
        samples_dir: local directory or ``gs://bucket[/prefix]`` location.
        requester_pays_project: optional project id, forwarded as ``project``
            to whichever GCS client is used.
        use_gcsfs: if True, read through gcsfs instead of the storage client.

    Returns:
        ``(train_df, val_df, test_df)``; ``date_t`` and ``date_target`` are
        parsed as dates in each frame.

    Raises:
        FileNotFoundError: bucket lookup or blob existence check failed
            (storage-client path).
        RuntimeError: gcsfs is required but not installed.
    """
    read_kwargs = {"parse_dates": ['date_t', 'date_target'], "low_memory": False}

    def load_one(path_or_dir, name):
        location = str(path_or_dir)
        if not is_gcs_path(location):
            return pd.read_csv(str(Path(path_or_dir) / name), **read_kwargs)

        # prefer storage client
        if storage is not None and not use_gcsfs:
            client_kwargs = {}
            if requester_pays_project:
                client_kwargs["project"] = requester_pays_project
            client = storage.Client(**client_kwargs)
            _, rest = location.split("gs://", 1)
            bucket_name, _, prefix = rest.partition("/")
            prefix = prefix.rstrip('/')
            # A bucket-root samples_dir has an empty prefix; prepending "/" would
            # address an object literally named "/<name>", which does not exist.
            blob_path = f"{prefix}/{name}" if prefix else name
            bucket = client.lookup_bucket(bucket_name)
            if bucket is None:
                raise FileNotFoundError(f"Bucket not found: gs://{bucket_name}")
            blob = bucket.blob(blob_path)
            if not blob.exists():
                raise FileNotFoundError(f"Samples CSV not found: gs://{bucket_name}/{blob_path}")
            return pd.read_csv(io.BytesIO(blob.download_as_bytes()), **read_kwargs)

        # fallback gcsfs
        if gcsfs is None:
            raise RuntimeError("gcsfs not installed and storage client not available")
        # Forward the project here too, so both code paths honour the argument.
        fs_kwargs = {}
        if requester_pays_project:
            fs_kwargs["project"] = requester_pays_project
        fs = gcsfs.GCSFileSystem(token="google", **fs_kwargs)
        target = location.rstrip('/') + '/' + name
        with fs.open(target, "rb") as f:
            return pd.read_csv(f, **read_kwargs)

    return (
        load_one(samples_dir, "samples_train.csv"),
        load_one(samples_dir, "samples_val.csv"),
        load_one(samples_dir, "samples_test.csv"),
    )

# ---------------------------
# Serialization helpers
# ---------------------------
def _save_json_atomic_local(path: str, data):
    """Write ``data`` as indented JSON to ``path`` via a temp file and ``os.replace``.

    The JSON is first written to ``path + ".tmp"`` and then renamed over
    ``path``, so an existing file at ``path`` is only replaced by a complete
    document. Objects the json module cannot encode are stored as their
    ``repr()``. If writing fails, the temp file is removed and the exception
    propagates; ``path`` is left untouched.
    """
    tmp = path + ".tmp"
    try:
        with open(tmp, "w") as fh:
            json.dump(data, fh, default=repr, indent=2)
        os.replace(tmp, path)
    finally:
        # After a successful replace the temp file no longer exists; on failure
        # this removes the partial file instead of leaving it next to the target.
        try:
            os.remove(tmp)
        except OSError:
            pass

def _to_serializable(obj):
    """Recursively convert NumPy containers and scalars into plain Python values.

    dict values, list/tuple items and ndarray contents are converted
    recursively (tuples and arrays become lists); NumPy integer, floating and
    boolean scalars become ``int``, ``float`` and ``bool``. Anything else is
    returned unchanged. Dict keys are not converted.
    """
    if isinstance(obj, np.ndarray):
        # tolist() yields nested Python lists, which are then walked like any list
        return _to_serializable(obj.tolist())
    if isinstance(obj, dict):
        return {key: _to_serializable(val) for key, val in obj.items()}
    if isinstance(obj, (list, tuple)):
        return list(map(_to_serializable, obj))

    scalar_casts = (
        (np.integer, int),
        (np.floating, float),
        ((np.bool_, bool), bool),
    )
    for scalar_kind, cast in scalar_casts:
        if isinstance(obj, scalar_kind):
            return cast(obj)
    return obj

# ----------------- helper utilities for grid/date/coords -----------------
def create_regular_grid_edges_and_centers(xs, ys, grid_res=400, buffer_frac=0.02):
    """Build a regular grid with square cells covering the given coordinates.

    The NaN-ignoring min/max of ``xs`` and ``ys`` is padded on both sides by
    ``buffer_frac`` of the respective span (a zero span is first widened by
    1e-6). The cell size is the longest padded span divided by
    ``max(2, int(grid_res))``; each axis gets at least two cells and the upper
    bound is moved so that it holds a whole number of cells.

    Returns:
        ``(x_edges, y_edges, x_centers, y_centers)`` as NumPy arrays, where the
        centers are the midpoints of consecutive edges.
    """
    xs = np.asarray(xs)
    ys = np.asarray(ys)

    def padded_extent(values):
        # Lower/upper bound of one axis after widening and padding.
        lo = float(np.nanmin(values))
        hi = float(np.nanmax(values))
        if hi == lo:
            hi = lo + 1e-6
        margin = (hi - lo) * buffer_frac
        return lo - margin, hi + margin

    x_lo, x_hi = padded_extent(xs)
    y_lo, y_hi = padded_extent(ys)
    span_x = x_hi - x_lo
    span_y = y_hi - y_lo

    n_target = max(2, int(grid_res))
    cell = float(max(span_x, span_y)) / float(n_target)
    n_cols = max(2, int(np.ceil(span_x / cell)))
    n_rows = max(2, int(np.ceil(span_y / cell)))

    # The upper bounds are recomputed so each axis spans whole cells.
    x_edges = np.linspace(x_lo, x_lo + n_cols * cell, n_cols + 1)
    y_edges = np.linspace(y_lo, y_lo + n_rows * cell, n_rows + 1)
    x_centers = 0.5 * (x_edges[:-1] + x_edges[1:])
    y_centers = 0.5 * (y_edges[:-1] + y_edges[1:])
    return x_edges, y_edges, x_centers, y_centers

def detect_date_column_simple(df: pd.DataFrame) -> Optional[str]:
    """Pick a likely date column from ``df``.

    Exact names are preferred, in this order: ``date_target``, ``date_t``,
    ``Date``, ``date``. Otherwise the first column whose lowercased label
    contains ``"date"`` or ``"day"`` is returned.

    Args:
        df: Frame whose column labels are inspected (the data is not read).

    Returns:
        The matching column label, or ``None`` when nothing matches.
    """
    for preferred in ("date_target", "date_t", "Date", "date"):
        if preferred in df.columns:
            return preferred
    for col in df.columns:
        # Labels are not guaranteed to be strings (e.g. integer labels), so
        # stringify before the substring test instead of calling col.lower().
        label = str(col).lower()
        if "date" in label or "day" in label:
            return col
    return None

def find_coord_cols(df: pd.DataFrame) -> Tuple[Optional[str], Optional[str]]:
    """Locate longitude/latitude (or x/y) columns by case-insensitive name.

    Candidates are tried in priority order; the first candidate that equals a
    lowercased column label wins.

    Args:
        df: Frame whose column labels are inspected.

    Returns:
        ``(lon_column, lat_column)`` using the original column labels; either
        entry is ``None`` when no candidate matches.
    """
    lon_candidates = ["pixel_lon", "lon", "longitude", "x", "X", "sample_lon", "long"]
    lat_candidates = ["pixel_lat", "lat", "latitude", "y", "Y", "sample_lat"]
    # str() keeps the lookup working for non-string labels (e.g. integers),
    # which have no .lower() method.
    cols_lower = {str(c).lower(): c for c in df.columns}

    def first_match(candidates):
        # Original label of the first candidate present, else None.
        for cand in candidates:
            key = cand.lower()
            if key in cols_lower:
                return cols_lower[key]
        return None

    return first_match(lon_candidates), first_match(lat_candidates)

# ----------------- robust scoring callable -----------------
def scoring_callable(estimator, X, y):
    """Score a fitted estimator, never raising.

    Uses average precision on the positive-class column of ``predict_proba``
    when available, else on ``decision_function`` output; estimators with
    neither are scored with F1 on ``predict`` (``zero_division=0``).

    Returns:
        The score as ``float``, or ``-inf`` if anything raises along the way.
    """
    try:
        if hasattr(estimator, "predict_proba"):
            ranking = estimator.predict_proba(X)[:, 1]
        elif hasattr(estimator, "decision_function"):
            ranking = estimator.decision_function(X)
        else:
            hard_labels = estimator.predict(X)
            return float(f1_score(y, hard_labels, zero_division=0))
        return float(average_precision_score(y, ranking))
    except Exception:
        # Any failure marks the candidate as unusable rather than aborting.
        return float("-inf")

# ----------------- checkpointed randomized search -----------------
def _save_json_atomic(path, data):
    try:
        _save_json_atomic_local(path, data)
    except Exception:
        with open(path, "w") as fh:
            json.dump(data, fh, default=lambda o: repr(o), indent=2)

def checkpointed_random_search(estimator, param_dist, X_train, y_train, cv, n_iter=8, random_state=RANDOM_STATE, checkpoint_path=None, scorer=None, verbose=False, checkpoint_verbose=False):
    """Randomized hyper-parameter search that checkpoints after every candidate.

    Candidates are drawn with ``ParameterSampler``. Each one is fit and scored
    on every split of ``cv``; the mean of the finite fold scores is recorded in
    a JSON checkpoint so that an interrupted search can be resumed. Candidates
    whose ``repr`` is already present in the checkpoint are not re-evaluated,
    but their recorded mean score still competes for "best", so resuming a
    finished (or partially finished) search returns the overall best candidate.

    Args:
        estimator: Unfitted estimator; it is cloned for every candidate.
        param_dist: Parameter space passed to ``ParameterSampler``.
        X_train, y_train: Training data, indexed with the arrays yielded by
            ``cv.split``.
        cv: Splitter object providing ``split(X)``.
        n_iter: Number of sampled candidates.
        random_state: Seed for ``ParameterSampler``.
        checkpoint_path: JSON checkpoint file; defaults to
            ``rf_search_checkpoint.json`` in the system temp directory.
        scorer: Optional callable ``scorer(estimator, X, y)``; when it is
            ``None`` or raises, ``scoring_callable`` is used instead.
        verbose: Print progress information.
        checkpoint_verbose: Print scorer/checkpoint problems even when
            ``verbose`` is off.

    Returns:
        ``(best_estimator, best_params, history)``. ``best_estimator`` is a
        clone with ``best_params`` refit on the full training data (if that
        refit raises, the clone is returned as-is and may be unfitted). Both
        are ``None`` when no candidate achieved a finite score. ``history`` is
        the list of per-candidate records, including those loaded from the
        checkpoint.
    """
    if checkpoint_path is None:
        checkpoint_path = os.path.join(tempfile.gettempdir(), "rf_search_checkpoint.json")

    history = []; tried = set()
    if os.path.exists(checkpoint_path):
        try:
            with open(checkpoint_path, "r") as fh:
                loaded = json.load(fh)
            if not isinstance(loaded, list):
                raise ValueError("checkpoint must contain a JSON list")
            tried = set(h['params_repr'] for h in loaded if 'params_repr' in h)
            history = loaded
            if verbose:
                print(f"Loaded checkpoint with {len(history)} completed candidates from {checkpoint_path}")
        except Exception:
            # Unreadable/corrupt checkpoint: start the search from scratch.
            history = []; tried = set()

    param_list = list(ParameterSampler(param_dist, n_iter=n_iter, random_state=random_state))
    remaining = [p for p in param_list if repr(p) not in tried]
    best_score = -np.inf; best_params = None; best_est = None

    # Seed the running best from candidates already scored in the checkpoint.
    # Parameters are taken from the freshly sampled list (matched by repr) so
    # they are real Python objects rather than JSON round-trips.
    sampled_by_repr = {repr(p): p for p in param_list}
    best_from_checkpoint = False
    for prev in history:
        if not isinstance(prev, dict):
            continue
        prev_key = prev.get("params_repr")
        prev_score = prev.get("mean_score")
        if not isinstance(prev_key, str) or prev_key not in sampled_by_repr:
            continue
        if isinstance(prev_score, bool) or not isinstance(prev_score, (int, float)):
            continue
        if np.isfinite(prev_score) and prev_score > best_score:
            best_score = float(prev_score)
            best_params = sampled_by_repr[prev_key]
            best_from_checkpoint = True

    iterator = enumerate(remaining, start=1)
    if tqdm is not None and verbose:
        iterator = enumerate(tqdm(remaining, desc="RF candidates", ncols=100), start=1)
    for i, params in iterator:
        p_repr = repr(params)
        if verbose:
            print(f"\nCandidate {i}/{len(remaining)}: {params}")
        est = clone(estimator); est.set_params(**params)
        fold_scores = []; fold_times = []
        for fold_idx, (tr_idx, te_idx) in enumerate(cv.split(X_train), start=1):
            t0 = time.time()
            try:
                est.fit(X_train[tr_idx], y_train[tr_idx])
                try:
                    if scorer is not None:
                        score = float(scorer(est, X_train[te_idx], y_train[te_idx]))
                    else:
                        score = float(scoring_callable(est, X_train[te_idx], y_train[te_idx]))
                except TypeError as te:
                    if verbose or checkpoint_verbose:
                        print("Scorer TypeError, falling back. Error:", te)
                    score = float(scoring_callable(est, X_train[te_idx], y_train[te_idx]))
                except Exception as e_other:
                    if verbose or checkpoint_verbose:
                        print("Scorer exception, using fallback:", e_other)
                    score = float(scoring_callable(est, X_train[te_idx], y_train[te_idx]))
            except Exception as e:
                score = float("-inf")
                if verbose:
                    print("  fold error during fit/predict/score:", e)
            elapsed = time.time() - t0
            fold_scores.append(score); fold_times.append(elapsed)
            if verbose:
                print(f"  fold {fold_idx}: score={score if np.isfinite(score) else 'FAILED'}, time={elapsed:.1f}s")

        # Average only the folds that produced a finite score; a candidate with
        # no usable fold is ranked last instead of yielding NaN.
        finite_scores = [s for s in fold_scores if np.isfinite(s)]
        mean_score = float(np.nanmean(finite_scores)) if finite_scores else float("-inf")
        rec = {"params_repr": p_repr, "params": params, "fold_scores": fold_scores, "fold_times": fold_times, "mean_score": mean_score, "timestamp": time.time()}
        history.append(rec)
        try:
            _save_json_atomic(checkpoint_path, history)
            if verbose:
                print("  Saved checkpoint to", checkpoint_path)
        except Exception as e_save:
            # Checkpointing is best-effort; the search itself keeps going.
            if verbose or checkpoint_verbose:
                print("  Failed to save checkpoint:", e_save)
        if mean_score > best_score:
            best_score = mean_score; best_params = params
            best_from_checkpoint = False
            try:
                best_est = clone(estimator); best_est.set_params(**params)
                best_est.fit(X_train, y_train)
                if verbose:
                    print("  -> New best found and refit on full training set.")
            except Exception as e:
                if verbose:
                    print("  -> Refit of best failed:", e)

    # The best candidate was evaluated in a previous run and no new candidate
    # beat it: it still has to be refit, since fitted models are not stored.
    if best_from_checkpoint and best_params is not None:
        try:
            best_est = clone(estimator); best_est.set_params(**best_params)
            best_est.fit(X_train, y_train)
            if verbose:
                print("Best candidate restored from checkpoint and refit on full training set.")
        except Exception as e:
            if verbose:
                print("  -> Refit of best failed:", e)
    return best_est, best_params, history

def run_random_search(estimator, param_dist, X_train, y_train, cv, n_iter=8, n_jobs=1, scorer=None, random_state=RANDOM_STATE, verbose_search=False, checkpoint_path=None):
    """Run a randomized hyper-parameter search, checkpointed when a path is given.

    With a truthy ``checkpoint_path`` the work is delegated to
    ``checkpointed_random_search``. Otherwise scikit-learn's
    ``RandomizedSearchCV`` is used (inside a threading ``parallel_backend``
    when ``n_jobs > 1``).

    Args:
        estimator: Unfitted estimator to tune.
        param_dist: Parameter space to sample from.
        X_train, y_train: Data used for the search.
        cv: Cross-validation splitter.
        n_iter: Number of sampled candidates.
        n_jobs: Parallelism for the ``RandomizedSearchCV`` path.
        scorer: Optional ``scorer(estimator, X, y)`` callable; when ``None``
            the module's ``scoring_callable`` is used on both paths.
        random_state: Seed for candidate sampling.
        verbose_search: Print progress information.
        checkpoint_path: JSON checkpoint file enabling the resumable search.

    Returns:
        ``(best_estimator, best_params, details)`` where ``details`` is the
        history list on the checkpointed path and the fitted
        ``RandomizedSearchCV`` object otherwise.
    """
    if checkpoint_path:
        if verbose_search:
            print("Running checkpointed randomized search (checkpoint_path=%s)" % checkpoint_path)
        return checkpointed_random_search(
            estimator, param_dist, X_train, y_train, cv,
            n_iter=n_iter, random_state=random_state,
            checkpoint_path=checkpoint_path, scorer=scorer, verbose=verbose_search,
        )

    from sklearn.model_selection import RandomizedSearchCV
    # Honour a caller-supplied scorer here too, matching the checkpointed path.
    scoring = scorer if scorer is not None else scoring_callable
    search = RandomizedSearchCV(estimator, param_distributions=param_dist, n_iter=n_iter,
                                scoring=scoring, cv=cv, random_state=random_state,
                                n_jobs=n_jobs, verbose=1 if verbose_search else 0)
    t0 = time.time()
    if n_jobs and n_jobs > 1:
        with parallel_backend('threading', n_jobs=n_jobs):
            search.fit(X_train, y_train)
    else:
        search.fit(X_train, y_train)
    elapsed = time.time() - t0
    print(f"Search finished in {elapsed:.1f}s; best_score={search.best_score_:.4f}")
    return search.best_estimator_, getattr(search, "best_params_", None), search

def evaluate_classification(y_true, y_pred_label, y_pred_prob):
    """Compute binary-classification metrics, tolerating individual failures.

    Each metric group is computed in its own ``try`` block so that one failing
    metric (for example ROC AUC when only one class is present) does not
    discard the others.

    Args:
        y_true: Ground-truth labels.
        y_pred_label: Predicted hard labels.
        y_pred_prob: Predicted scores/probabilities for the positive class.

    Returns:
        Dict with ``precision``, ``recall``, ``f1``, ``average_precision``,
        ``roc_auc``, ``brier`` (floats, NaN on failure) and the confusion
        counts ``tn``, ``fp``, ``fn``, ``tp`` (Python ints, ``None`` on
        failure).
    """
    y_true = np.asarray(y_true)
    y_pred_label = np.asarray(y_pred_label)
    y_pred_prob = np.asarray(y_pred_prob)
    try:
        prec = float(precision_score(y_true, y_pred_label, zero_division=0))
        rec = float(recall_score(y_true, y_pred_label, zero_division=0))
        f1 = float(f1_score(y_true, y_pred_label, zero_division=0))
    except Exception:
        prec = rec = f1 = float("nan")
    try:
        ap = float(average_precision_score(y_true, y_pred_prob))
    except Exception:
        ap = float("nan")
    try:
        roc = float(roc_auc_score(y_true, y_pred_prob))
    except Exception:
        roc = float("nan")
    try:
        brier = float(brier_score_loss(y_true, y_pred_prob))
    except Exception:
        brier = float("nan")
    try:
        # Without explicit labels the matrix is 1x1 when only one class occurs
        # in both arrays, and the 4-way unpack fails. Pin the 0/1 layout when
        # the data is 0/1 coded; otherwise keep the inferred labels.
        observed = np.union1d(y_true, y_pred_label)
        labels = [0, 1] if np.isin(observed, (0, 1)).all() else None
        matrix = confusion_matrix(y_true, y_pred_label, labels=labels)
        tn, fp, fn, tp = (int(v) for v in matrix.ravel())
    except Exception:
        tn = fp = fn = tp = None
    return {
        "precision": prec, "recall": rec, "f1": f1,
        "average_precision": ap, "roc_auc": roc, "brier": brier,
        "tn": tn, "fp": fp, "fn": fn, "tp": tp
    }

# ----------------- main -----------------
def main(args):
    """Tune, train, evaluate and export the binary Random Forest classifier.

    Steps: load the processed NPZ arrays (binary targets come from the NPZ or,
    failing that, from the ``y_bin`` column of the samples CSVs); run the
    randomized search; save the best searched model; report validation
    metrics; retrain on train+val; save the final model, the test predictions
    table and a JSON summary. Artifacts are staged in a temporary directory
    and then uploaded (``gs://`` out_dir) or moved (local out_dir).

    Args:
        args: Parsed CLI namespace. Attributes read: ``npz``, ``samples_dir``,
            ``out_dir``, ``checkpoint_path``, ``requester_pays_project``,
            ``use_gcsfs``, ``max_train_samples``, ``use_predefined_split``,
            ``ts_splits``, ``fast``, ``n_iter``, ``n_jobs``, ``verbose``,
            ``prob_threshold``, ``threshold`` and (optionally) ``tag``.

    Raises:
        RuntimeError: If the NPZ lacks ``X_train`` or no binary targets can be
            found in either the NPZ or the samples CSVs.
    """
    local_tmp_out = tempfile.mkdtemp(prefix="rf_out_")
    try:
        tag = getattr(args, "tag", "")

        def tag_name(base_filename: str) -> str:
            """Insert ``_<tag>`` before the extension; no-op for an empty tag."""
            if not tag:
                return base_filename
            base_noext, ext = os.path.splitext(base_filename)
            return f"{base_noext}_{tag}{ext}"

        def publish(local_path: str, saved_message: str) -> None:
            """Upload (gs://) or move (local) a staged file into ``args.out_dir``."""
            filename = os.path.basename(local_path)
            if is_gcs_path(args.out_dir):
                destination = gcs_join(args.out_dir, filename)
                upload_file_to_gcs(local_path, destination, requester_pays_project=args.requester_pays_project)
            else:
                # shutil.move does not create a missing destination directory.
                if args.out_dir:
                    os.makedirs(args.out_dir, exist_ok=True)
                destination = os.path.join(args.out_dir, filename)
                shutil.move(local_path, destination)
            print(saved_message, destination)

        def load_samples():
            """Load the train/val/test samples CSVs with the CLI's GCS options."""
            return load_samples_csvs(args.samples_dir, requester_pays_project=args.requester_pays_project, use_gcsfs=args.use_gcsfs)

        # NOTE: the default checkpoint lives in the staging directory, which is
        # deleted on exit; pass --checkpoint_path to keep it across runs.
        checkpoint_path = args.checkpoint_path if args.checkpoint_path else os.path.join(local_tmp_out, tag_name("rf_search_checkpoint.json"))

        print("Loading processed arrays from:", args.npz)
        data = load_npz(args.npz, requester_pays_project=args.requester_pays_project, use_gcsfs=args.use_gcsfs)
        if 'X_train' not in data:
            raise RuntimeError("processed_dataset must contain X_train/X_val/X_test arrays (and y_*_bin etc).")

        X_train = data['X_train']; X_val = data['X_val']; X_test = data['X_test']
        feature_cols = list(data['feature_cols']) if 'feature_cols' in data else None

        train_df = val_df = test_df = None
        if 'y_train_bin' in data and 'y_val_bin' in data and 'y_test_bin' in data:
            y_train_bin = np.asarray(data['y_train_bin'])
            y_val_bin = np.asarray(data['y_val_bin'])
            y_test_bin = np.asarray(data['y_test_bin'])
        else:
            print("Binary targets not found in NPZ — attempting to load from samples CSVs.")
            train_df, val_df, test_df = load_samples()
            if 'y_bin' not in train_df.columns:
                raise RuntimeError("y_bin column not found in samples CSVs nor NPZ. Please run data prep to save y_bin.")
            y_train_bin = train_df['y_bin'].astype(int).values
            y_val_bin = val_df['y_bin'].astype(int).values
            y_test_bin = test_df['y_bin'].astype(int).values

        print(f"Shapes: X_train={X_train.shape}, X_val={X_val.shape}, X_test={X_test.shape}")

        # Load sample CSVs for metadata (non-fatal); skipped when they were
        # already loaded above for the target fallback.
        if test_df is None:
            try:
                train_df, val_df, test_df = load_samples()
            except Exception as e:
                print("Samples CSVs unavailable for metadata; continuing without them:", e)
                train_df = val_df = test_df = None

        # Subsample for tuning if requested.
        if args.max_train_samples:
            X_train_sub = X_train[:args.max_train_samples]; y_train_sub = y_train_bin[:args.max_train_samples]
            print("Using subsample for tuning:", X_train_sub.shape)
        else:
            X_train_sub, y_train_sub = X_train, y_train_bin

        # Build the CV object.
        if args.use_predefined_split:
            X_comb = np.vstack([X_train_sub, X_val]); y_comb = np.concatenate([y_train_sub, y_val_bin])
            # -1 keeps a row in training only; 0 puts it in the single test fold.
            test_fold = np.concatenate([np.full(X_train_sub.shape[0], -1), np.zeros(X_val.shape[0], dtype=int)])
            cv_obj = PredefinedSplit(test_fold)
            X_tune, y_tune = X_comb, y_comb
            print("Using PredefinedSplit (train+val) for tuning.")
        else:
            cv_obj = TimeSeriesSplit(n_splits=args.ts_splits)
            X_tune, y_tune = X_train_sub, y_train_sub
            print(f"Using TimeSeriesSplit(n_splits={args.ts_splits})")

        # Parameter distributions.
        if args.fast:
            param_dist = {
                "n_estimators": [100, 200],
                "max_depth": [5, 8],
                "min_samples_leaf": [5, 10],
                "min_samples_split": [5, 10],
                "max_features": ["sqrt", 0.5],
                "class_weight": ["balanced"],
                "bootstrap": [True],
                "max_samples": [0.6, 0.8]
            }
        else:
            param_dist = {
                "n_estimators": [200, 500],
                "max_depth": [5, 8, 10],
                "min_samples_leaf": [3, 5, 10],
                "min_samples_split": [2, 5, 8],
                "max_features": ["sqrt", 0.5, 0.7],
                "class_weight": ["balanced", None],
                "bootstrap": [True],
                "max_samples": [0.5, 0.6, 0.8]
            }

        rf = RandomForestClassifier(random_state=RANDOM_STATE, n_jobs=1)

        best_rf, rf_params, rf_search = run_random_search(
            rf, param_dist, X_tune, y_tune, cv_obj,
            n_iter=args.n_iter, n_jobs=min(args.n_jobs, 8),
            scorer=None,
            random_state=RANDOM_STATE, verbose_search=args.verbose, checkpoint_path=checkpoint_path
        )
        if best_rf is None:
            print("No RF candidate found (search returned None). Exiting.")
            return

        best_initial_local = os.path.join(local_tmp_out, tag_name("rf_best_initial_binary.joblib"))
        joblib.dump(best_rf, best_initial_local)
        publish(best_initial_local, "Saved initial best RF to:")

        print("RF best params:", rf_params)

        # Validate on val.
        try:
            y_val_prob = best_rf.predict_proba(X_val)[:, 1]
            y_val_label = (y_val_prob >= args.prob_threshold).astype(int)
            val_metrics = evaluate_classification(y_val_bin, y_val_label, y_val_prob)
            print("Validation metrics (RF classifier):", val_metrics)
        except Exception as e:
            print("Validation failed:", e)
            val_metrics = None

        # Retrain on train+val and save final.
        X_train_val = np.vstack([X_train, X_val]); y_train_val = np.concatenate([y_train_bin, y_val_bin])
        final = RandomForestClassifier(**best_rf.get_params())
        final.set_params(n_jobs=args.n_jobs, random_state=RANDOM_STATE)
        t0 = time.time(); final.fit(X_train_val, y_train_val); print("Final RF trained in", time.time() - t0, "s")

        final_local = os.path.join(local_tmp_out, tag_name("final_rf_classifier_binary.joblib"))
        joblib.dump(final, final_local)
        publish(final_local, "Saved final RF to:")

        # Test eval and save predictions.
        y_test_prob = final.predict_proba(X_test)[:, 1]
        y_test_label = (y_test_prob >= args.prob_threshold).astype(int)
        test_metrics = evaluate_classification(y_test_bin, y_test_label, y_test_prob)
        print("Test metrics:", test_metrics)

        # Build enriched predictions table with dates and pixel locations.
        # If samples_test.csv is available (test_df), prefer it and augment;
        # otherwise build from the X_test grid.
        try:
            # Reconstruct a canonical grid from X_test for pixel centers/indices.
            # NOTE(review): the substring tokens "x"/"y" match many feature
            # names; confirm the detected columns really hold coordinates.
            lon_idx = None; lat_idx = None
            if feature_cols:
                for j, c in enumerate(feature_cols):
                    if isinstance(c, str):
                        cl = c.lower()
                        if lon_idx is None and any(tok in cl for tok in ("lon","long","longitude","x")):
                            lon_idx = j
                        if lat_idx is None and any(tok in cl for tok in ("lat","latitude","y")):
                            lat_idx = j
            # Fall back to the first two feature columns.
            if lon_idx is None:
                lon_idx = 0
            if lat_idx is None:
                lat_idx = 1
            xs = X_test[:, lon_idx].astype(float)
            ys = X_test[:, lat_idx].astype(float)
            x_edges, y_edges, x_centers, y_centers = create_regular_grid_edges_and_centers(xs, ys, grid_res=400, buffer_frac=0.02)
        except Exception:
            x_edges = y_edges = x_centers = y_centers = None

        if test_df is not None:
            pred_table = test_df.copy()
            # Ensure the date column is parsed.
            date_col = detect_date_column_simple(pred_table)
            if date_col is not None:
                pred_table[date_col] = pd.to_datetime(pred_table[date_col], errors='coerce')
                pred_table['date_target'] = pred_table[date_col]
            else:
                pred_table['date_target'] = pd.NaT
            # Observed label.
            if 'y_raw' in pred_table.columns:
                pred_table['y_true_bin'] = pred_table['y_raw'].apply(lambda v: int(v >= args.threshold) if not pd.isna(v) else np.nan)
            elif 'y_bin' in pred_table.columns:
                pred_table['y_true_bin'] = pred_table['y_bin'].astype(int)
            else:
                pred_table['y_true_bin'] = y_test_bin if len(y_test_bin)==len(pred_table) else np.nan
            # Attach predictions by row order.
            if len(pred_table) == len(y_test_prob):
                pred_table['y_pred_prob'] = y_test_prob
                pred_table['y_pred_label'] = y_test_label
            else:
                # Lengths differ: fill the common prefix, leave the rest NaN.
                pred_table['y_pred_prob'] = np.nan
                pred_table['y_pred_label'] = np.nan
                minlen = min(len(pred_table), len(y_test_prob))
                pred_table.loc[:minlen-1, 'y_pred_prob'] = y_test_prob[:minlen]
                pred_table.loc[:minlen-1, 'y_pred_label'] = y_test_label[:minlen]

            # Detect coordinate columns and compute pixel indices if possible.
            lon_col, lat_col = find_coord_cols(pred_table)
            if lon_col is not None and lat_col is not None and x_edges is not None:
                pred_table['_mx'] = pd.to_numeric(pred_table[lon_col], errors='coerce')
                pred_table['_my'] = pd.to_numeric(pred_table[lat_col], errors='coerce')
                pred_table['xi'] = np.clip(np.searchsorted(x_edges, pred_table['_mx'].values) - 1, 0, len(x_centers)-1).astype(int)
                pred_table['yi'] = np.clip(np.searchsorted(y_edges, pred_table['_my'].values) - 1, 0, len(y_centers)-1).astype(int)
                pred_table['pixel_lon'] = x_centers[pred_table['xi'].astype(int).values]
                pred_table['pixel_lat'] = y_centers[pred_table['yi'].astype(int).values]
            preds_local = os.path.join(local_tmp_out, tag_name("predictions_test_rf_classifier_binary_enriched.csv"))
            # The full table (all sample columns plus the added ones) is written.
            pred_table.to_csv(preds_local, index=False)
        else:
            # Build per-row predictions from X_test (lengths match y_test_prob).
            if x_edges is None or x_centers is None:
                # Minimal fallback without pixel information.
                df_min = pd.DataFrame({"y_true_bin": y_test_bin if len(y_test_bin)==len(y_test_prob) else np.nan, "y_pred_prob": y_test_prob, "y_pred_label": y_test_label})
                preds_local = os.path.join(local_tmp_out, tag_name("predictions_test_rf_classifier_binary_min.csv"))
                df_min.to_csv(preds_local, index=False)
            else:
                xi_per_row = np.clip(np.searchsorted(x_edges, xs) - 1, 0, len(x_centers)-1).astype(int)
                yi_per_row = np.clip(np.searchsorted(y_edges, ys) - 1, 0, len(y_centers)-1).astype(int)
                df_grid = pd.DataFrame({
                    "xi": xi_per_row,
                    "yi": yi_per_row,
                    "pixel_lon": x_centers[xi_per_row],
                    "pixel_lat": y_centers[yi_per_row],
                    "y_pred_prob": y_test_prob,
                    "y_pred_label": y_test_label
                })
                # Attach the true labels only when they align with X_test rows.
                if len(y_test_bin) == len(df_grid):
                    df_grid['y_true_bin'] = y_test_bin
                else:
                    df_grid['y_true_bin'] = np.nan
                preds_local = os.path.join(local_tmp_out, tag_name("predictions_test_rf_classifier_binary_grid.csv"))
                df_grid.to_csv(preds_local, index=False)

        publish(preds_local, "Saved test predictions to:")

        # Save summary JSON (tagged).
        summary = {
            "model": "rf_classifier_binary",
            "val_metrics": val_metrics,
            "test_metrics": test_metrics,
            "best_params": rf_params,
            "n_iter": int(args.n_iter),
            "checkpoint_path": checkpoint_path
        }
        summary_safe = _to_serializable(summary)
        summary_local = os.path.join(local_tmp_out, tag_name("training_summary_rf_classifier_binary.json"))
        with open(summary_local, "w") as fh:
            json.dump(summary_safe, fh, indent=2)
        publish(summary_local, "Saved summary to:")

        print("All outputs saved to", args.out_dir)

    finally:
        # Always drop the staging directory, even after a failure.
        shutil.rmtree(local_tmp_out, ignore_errors=True)

if __name__ == "__main__":
    # Command-line entry point: declare the options, then hand the parsed
    # namespace to main().
    parser = argparse.ArgumentParser()

    # Input / output locations.
    parser.add_argument(
        "--npz", type=str, default=DEFAULT_NPZ,
        help="Processed dataset .npz (must contain X_train/X_val/X_test and y_*_bin)",
    )
    parser.add_argument(
        "--samples_dir", type=str, default=DEFAULT_SAMPLES_DIR,
        help="Folder with samples_train/val/test CSVs (used for metadata & fallback)",
    )
    parser.add_argument(
        "--out_dir", type=str, default=DEFAULT_OUT_DIR,
        help="Output folder (local path or gs://bucket/path/)",
    )

    # Search configuration.
    parser.add_argument("--n_iter", type=int, default=8)
    parser.add_argument("--n_jobs", type=int, default=4)
    parser.add_argument("--ts_splits", type=int, default=5)
    parser.add_argument(
        "--no_predefined_split", action="store_false", dest="use_predefined_split",
        default=True,
        help="Disable using PredefinedSplit and use TimeSeriesSplit instead.",
    )
    parser.add_argument("--max_train_samples", type=int, default=None)
    parser.add_argument("--fast", action="store_true")
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument(
        "--checkpoint_path", type=str, default=None,
        help="Path to JSON checkpoint file to save/resume search (local recommended)",
    )

    # Labelling thresholds.
    parser.add_argument(
        "--prob_threshold", type=float, default=0.5,
        help="Probability threshold to convert probs into labels for evaluation",
    )
    parser.add_argument(
        "--threshold", type=float, default=20.0,
        help="Chl a threshold used to create y_bin (informational)",
    )

    # GCS access and artifact naming.
    parser.add_argument(
        "--use_gcsfs", action="store_true",
        help="Attempt gcsfs for reads if storage client fails (default: prefer storage client)",
    )
    parser.add_argument(
        "--requester_pays_project", type=str, default=None,
        help="GCP project id to use when accessing requester-pays buckets",
    )
    parser.add_argument(
        "--tag", type=str, default="nak",
        help="Optional tag appended to artifact filenames to avoid overwriting (default: nak). Use empty string to disable tagging.",
    )

    # parse_known_args tolerates extra argv entries (e.g. those injected when
    # the script is executed from a notebook kernel).
    args, _unknown = parser.parse_known_args()
    main(args)

Loading processed arrays from: gs://final_data_kgc2/models/Pigeon_binary/processed_dataset_binary.npz
Shapes: X_train=(948281, 39), X_val=(74421, 39), X_test=(495447, 39)
Using PredefinedSplit (train+val) for tuning.
Saved initial best RF to: gs://final_data_kgc2/rf_results/rf_best_initial_binary_pig.joblib
RF best params: {'n_estimators': 500, 'min_samples_split': 2, 'min_samples_leaf': 5, 'max_samples': 0.6, 'max_features': 'sqrt', 'max_depth': 10, 'class_weight': None, 'bootstrap': True}
Validation metrics (RF classifier): {'precision': 0.8175742013443094, 'recall': 0.9084081564152051, 'f1': 0.8606010016694491, 'average_precision': 0.9379559607168465, 'roc_auc': 0.971338969380898, 'brier': 0.06557838334845235, 'tn': 45756, 'fp': 4831, 'fn': 2183, 'tp': 21651}
Final RF trained in 180.4166374206543 s
Saved final RF to: gs://final_data_kgc2/rf_results/final_rf_classifier_binary_pig.joblib
Test metrics: {'precision': 0.31722310337014453, 'recall': 0.44880693538495847, 'f1': 0.3717136969

In [3]:
#!/usr/bin/env python3
"""
Train XGBoost classifier for binary bloom/no-bloom — GCS-ready (Pigeon Lake / pig tag)

This is the original script with two small, safe changes for Pigeon Lake:
- Default NPZ points to the Pigeon dataset.
- Default --tag is now "pig" (was "nak") and the tag is appended to all output filenames
  so runs for different lakes won't overwrite each other's artifacts.

Usage example:
  python3 train_xgb_binary_gcs_tagged_pig.py --tag pig --out_dir gs://final_data_kgc2/rf_results/ ...
"""
from __future__ import annotations
import os
import argparse
import json
import tempfile
import shutil
import time
import io
from pathlib import Path

import numpy as np
import pandas as pd

from sklearn.model_selection import TimeSeriesSplit, RandomizedSearchCV, PredefinedSplit, ParameterSampler
from sklearn.metrics import (
    precision_score, recall_score, f1_score, roc_auc_score, average_precision_score,
    brier_score_loss, confusion_matrix
)
from sklearn.base import clone
import joblib

# joblib parallel backend
from joblib import parallel_backend

# optional imports (GCS + progress)
try:
    import gcsfs
except Exception:
    gcsfs = None

try:
    from tqdm.auto import tqdm
except Exception:
    tqdm = None

# xgboost import
try:
    from xgboost import XGBClassifier
except Exception:
    XGBClassifier = None

# preferred storage client
try:
    from google.cloud import storage
except Exception:
    storage = None

import warnings
warnings.filterwarnings("ignore", message="pkg_resources is deprecated as an API")
warnings.filterwarnings("ignore", category=FutureWarning)

# ---------------------------
# Defaults pointing to your bucket (Pigeon Lake)
# ---------------------------
DEFAULT_NPZ = "gs://final_data_kgc2/models/Pigeon_binary/processed_dataset_binary.npz"
DEFAULT_SAMPLES_DIR = "gs://final_data_kgc2/Final_data/"
DEFAULT_OUT_DIR = "gs://final_data_kgc2/rf_results/"
RANDOM_STATE = 0

# ---------------------------
# Helpers: GCS/local I/O
# ---------------------------
def is_gcs_path(p: str) -> bool:
    """Return True when ``p`` is a string starting with ``gs://``."""
    if not isinstance(p, str):
        return False
    return p.startswith("gs://")

def gcs_join(prefix: str, *parts: str) -> str:
    """Join ``prefix`` and ``parts`` with exactly one ``/`` between segments.

    Trailing slashes on the accumulated path and leading slashes on each part
    are dropped before joining; a trailing slash on the last part is kept.
    """
    joined = prefix.rstrip('/')
    for segment in parts:
        head = joined.rstrip('/')
        tail = str(segment).lstrip('/')
        joined = f"{head}/{tail}"
    return joined

def upload_file_to_gcs(local_path: str, gs_uri: str, requester_pays_project: str | None = None):
    """Copy a local file to a ``gs://`` URI.

    The google-cloud-storage client is used when importable (the bucket is
    looked up first so a missing bucket gives a clear error); otherwise gcsfs
    is used.

    Args:
        local_path: File to upload.
        gs_uri: Destination ``gs://bucket/object`` URI.
        requester_pays_project: If set, passed as ``project`` to
            ``storage.Client``. It is not used on the gcsfs path.

    Raises:
        ValueError: ``gs_uri`` is not a ``gs://`` path.
        FileNotFoundError: The bucket lookup returned nothing.
        RuntimeError: Neither google-cloud-storage nor gcsfs is available.
    """
    if not is_gcs_path(gs_uri):
        raise ValueError("gs_uri must be gs:// path")

    if storage is None:
        # Fallback route through gcsfs.
        if gcsfs is None:
            raise RuntimeError("No method available to upload to GCS: install google-cloud-storage or gcsfs")
        gcsfs.GCSFileSystem(token="google").put(local_path, gs_uri)
        return

    # Preferred route: official storage client.
    client_options = {"project": requester_pays_project} if requester_pays_project else {}
    gcs_client = storage.Client(**client_options)
    remainder = gs_uri.split("gs://", 1)[1]
    bucket_name, _, object_name = remainder.partition("/")
    target_bucket = gcs_client.lookup_bucket(bucket_name)
    if target_bucket is None:
        raise FileNotFoundError(f"Destination bucket does not exist or is not accessible: gs://{bucket_name}")
    target_bucket.blob(object_name).upload_from_filename(local_path)

def fetch_gcs_to_local(gcs_path: str, local_path: str, requester_pays_project: str | None = None, use_gcsfs: bool = False):
    """Download a ``gs://`` object to ``local_path`` and return that path.

    The google-cloud-storage client is used unless it is unavailable or
    ``use_gcsfs`` is set, in which case the object is streamed through gcsfs.

    Args:
        gcs_path: Source ``gs://bucket/object`` URI.
        local_path: Destination file on the local filesystem.
        requester_pays_project: If set, passed as ``project`` to
            ``storage.Client``. It is not used on the gcsfs path.
        use_gcsfs: Force the gcsfs route.

    Raises:
        ValueError: ``gcs_path`` is not a ``gs://`` path.
        FileNotFoundError: Bucket or blob does not exist (storage-client path).
        RuntimeError: The gcsfs route is required but gcsfs is not installed.
    """
    if not is_gcs_path(gcs_path):
        raise ValueError("gcs path required")

    if storage is None or use_gcsfs:
        # gcsfs route: stream the object into the local file.
        if gcsfs is None:
            raise RuntimeError("gcsfs not installed and storage client not available")
        remote_fs = gcsfs.GCSFileSystem(token="google")
        with remote_fs.open(gcs_path, "rb") as remote, open(local_path, "wb") as local:
            shutil.copyfileobj(remote, local)
        return local_path

    # Storage-client route.
    client_options = {"project": requester_pays_project} if requester_pays_project else {}
    gcs_client = storage.Client(**client_options)
    remainder = gcs_path.split("gs://", 1)[1]
    bucket_name, _, object_name = remainder.partition("/")
    source_bucket = gcs_client.lookup_bucket(bucket_name)
    if source_bucket is None:
        raise FileNotFoundError(f"Bucket not found: gs://{bucket_name}")
    source_blob = source_bucket.blob(object_name)
    if not source_blob.exists():
        raise FileNotFoundError(f"Blob not found: {gcs_path}")
    source_blob.download_to_filename(local_path)
    return local_path

def load_npz(npz_path: str, requester_pays_project: str | None = None, use_gcsfs: bool = False) -> dict:
    """Load every array of an ``.npz`` archive (local or ``gs://``) into a dict.

    A ``gs://`` archive is first downloaded to a temporary file with
    ``fetch_gcs_to_local``. The temporary file is removed even when the
    download or the load fails, and the archive handle is closed before
    returning so no file descriptor is left open.

    Args:
        npz_path: Local path or ``gs://`` URI of the archive.
        requester_pays_project: Forwarded to ``fetch_gcs_to_local``.
        use_gcsfs: Forwarded to ``fetch_gcs_to_local``.

    Returns:
        Mapping of array name to array for every entry in the archive.
    """
    def _read_all(path):
        # NOTE(security): allow_pickle=True unpickles object arrays, which can
        # execute arbitrary code. Only load archives from trusted locations.
        with np.load(path, allow_pickle=True) as archive:
            return {key: archive[key] for key in archive.files}

    if not is_gcs_path(npz_path):
        return _read_all(npz_path)

    # delete=False + close(): we only need a unique name the download can write to.
    tmp = tempfile.NamedTemporaryFile(suffix=".npz", delete=False)
    tmp.close()
    try:
        fetch_gcs_to_local(npz_path, tmp.name, requester_pays_project=requester_pays_project, use_gcsfs=use_gcsfs)
        return _read_all(tmp.name)
    finally:
        # Best-effort cleanup; a leftover temp file must not mask the real error.
        try:
            os.unlink(tmp.name)
        except OSError:
            pass

def load_samples_csvs(samples_dir: str, requester_pays_project: str | None = None, use_gcsfs: bool = False):
    """Load samples_train/val/test CSVs (used for saving predictions and metadata).

    Each CSV is read with ``date_t`` and ``date_target`` parsed as dates.

    Args:
        samples_dir: Local directory or ``gs://bucket[/prefix]`` holding
            ``samples_train.csv``, ``samples_val.csv`` and ``samples_test.csv``.
        requester_pays_project: Passed as ``project`` to the storage client
            when set.
        use_gcsfs: Read through gcsfs instead of the storage client.

    Returns:
        ``(train_df, val_df, test_df)``.

    Raises:
        FileNotFoundError: bucket or CSV blob not found (storage-client route).
        RuntimeError: the gcsfs route is needed but gcsfs is not installed.
    """
    names = ("samples_train.csv", "samples_val.csv", "samples_test.csv")
    date_cols = ['date_t', 'date_target']
    location = str(samples_dir)

    if not is_gcs_path(location):
        return tuple(
            pd.read_csv(str(Path(samples_dir) / name), parse_dates=date_cols)
            for name in names
        )

    # prefer storage client
    if storage is not None and not use_gcsfs:
        client_kwargs = {}
        if requester_pays_project:
            client_kwargs["project"] = requester_pays_project
        # One client / bucket lookup serves all three files.
        client = storage.Client(**client_kwargs)
        bucket_name, _, prefix = location.split("gs://", 1)[1].partition("/")
        prefix = prefix.rstrip('/')
        bucket = client.lookup_bucket(bucket_name)
        if bucket is None:
            raise FileNotFoundError(f"Bucket not found: gs://{bucket_name}")
        frames = []
        for name in names:
            # An empty prefix (samples at the bucket root) must not yield "/name".
            blob_path = f"{prefix}/{name}" if prefix else name
            blob = bucket.blob(blob_path)
            if not blob.exists():
                raise FileNotFoundError(f"Samples CSV not found: gs://{bucket_name}/{blob_path}")
            frames.append(pd.read_csv(io.BytesIO(blob.download_as_bytes()), parse_dates=date_cols))
        return tuple(frames)

    # fallback gcsfs
    if gcsfs is None:
        raise RuntimeError("gcsfs not installed and storage client not available")
    fs = gcsfs.GCSFileSystem(token="google")
    frames = []
    for name in names:
        with fs.open(location.rstrip('/') + '/' + name, "rb") as f:
            frames.append(pd.read_csv(f, parse_dates=date_cols))
    return tuple(frames)

# ---------------------------
# Serialization helpers
# ---------------------------
def _save_json_atomic_local(path: str, data):
    """Write ``data`` as JSON to ``path`` via a sibling ``<path>.tmp`` file.

    The JSON is written to the temporary file first and then moved over
    ``path`` with ``os.replace``, so an existing ``path`` is never left
    half-written. Objects ``json`` cannot encode are stored as their ``repr``.
    If writing or replacing fails, the temporary file is removed and the
    original exception is re-raised.
    """
    tmp = path + ".tmp"
    try:
        with open(tmp, "w") as fh:
            json.dump(data, fh, default=repr, indent=2)
        os.replace(tmp, path)
    except BaseException:
        # Do not leave a partial "<path>.tmp" behind; then propagate the error.
        try:
            os.remove(tmp)
        except OSError:
            pass
        raise

def _to_serializable(obj):
    """Recursively convert ``obj`` into plain Python values that ``json`` accepts.

    - dicts are rebuilt with converted values; NumPy-scalar keys are converted
      too (``json`` rejects e.g. ``np.int64`` keys);
    - lists, tuples, sets and frozensets become lists (set order is the set's
      iteration order);
    - NumPy arrays become nested lists;
    - NumPy integer / floating / bool scalars become ``int`` / ``float`` / ``bool``.

    Anything else is returned unchanged.
    """
    if isinstance(obj, dict):
        return {
            (_to_serializable(key) if isinstance(key, np.generic) else key): _to_serializable(value)
            for key, value in obj.items()
        }
    if isinstance(obj, (list, tuple, set, frozenset)):
        return [_to_serializable(item) for item in obj]
    if isinstance(obj, np.ndarray):
        return _to_serializable(obj.tolist())
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, (np.bool_, bool)):
        return bool(obj)
    return obj

# ----------------- robust scoring callable -----------------
def scoring_callable(estimator, X, y):
    """Score ``estimator`` on ``(X, y)`` without ever raising.

    Average precision is computed from ``predict_proba`` (second column) when
    the estimator has it, otherwise from ``decision_function``; estimators
    with neither are scored with F1 on ``predict``. Any exception yields
    ``-inf`` so a failing candidate simply ranks last.
    """
    try:
        if hasattr(estimator, "predict_proba"):
            ranking = estimator.predict_proba(X)[:, 1]
        elif hasattr(estimator, "decision_function"):
            ranking = estimator.decision_function(X)
        else:
            # No continuous output available: fall back to hard labels.
            return float(f1_score(y, estimator.predict(X), zero_division=0))
        return float(average_precision_score(y, ranking))
    except Exception:
        return float("-inf")

# ----------------- checkpointed randomized search -----------------
def _save_json_atomic(path, data):
    """Save ``data`` as JSON at ``path``, atomically when possible.

    The atomic writer is tried first; if it raises for any reason the JSON is
    written straight to ``path`` instead (objects ``json`` cannot encode are
    stored as their ``repr``).
    """
    try:
        _save_json_atomic_local(path, data)
        return
    except Exception:
        # Atomic write failed: fall through to the plain, non-atomic write.
        pass
    with open(path, "w") as out:
        json.dump(data, out, default=lambda o: repr(o), indent=2)

def checkpointed_random_search(estimator, param_dist, X_train, y_train, cv, n_iter=8, random_state=RANDOM_STATE, checkpoint_path=None, scorer=None, verbose=False, checkpoint_verbose=False):
    """Randomized hyper-parameter search that checkpoints every candidate to JSON.

    Candidates are drawn with ``ParameterSampler(param_dist, n_iter, random_state)``.
    Each candidate not yet recorded in the checkpoint is fitted and scored on
    every split of ``cv``; its record is appended to the history and the whole
    history is rewritten to ``checkpoint_path``.

    Selection of the winner:
    - Candidates restored from the checkpoint compete with the ones evaluated
      in this call (matched to the current sample by ``params_repr``), so a
      resumed - or already finished - search still returns its best model.
    - A candidate whose folds all failed gets ``mean_score = -inf`` and is
      never selected. Failed folds are excluded from the mean otherwise.
    - The winner is refit once on the full training data after the loop; if
      that refit fails the next-best candidate is tried.

    Args:
        estimator: Unfitted estimator; cloned for every candidate.
        param_dist: Parameter distributions for ``ParameterSampler``.
        X_train, y_train: Arrays indexable by the integer index arrays from ``cv``.
        cv: Splitter exposing ``split(X, y)``.
        n_iter: Number of sampled candidates.
        random_state: Seed for the sampler.
        checkpoint_path: JSON checkpoint file; defaults to
            ``xgb_search_checkpoint.json`` in the system temp directory.
        scorer: Optional ``scorer(estimator, X, y)``; ``scoring_callable`` is
            used when it is ``None`` or when it raises.
        verbose: Print progress.
        checkpoint_verbose: Also print scorer-fallback / checkpoint warnings
            when ``verbose`` is off.

    Returns:
        ``(best_est, best_params, history)``. ``best_est`` is ``None`` when no
        candidate could be scored and refit; ``best_params`` then holds the
        top-scoring parameters, or ``None`` if nothing was scored.
    """
    if checkpoint_path is None:
        checkpoint_path = os.path.join(tempfile.gettempdir(), "xgb_search_checkpoint.json")

    history = []
    if os.path.exists(checkpoint_path):
        try:
            with open(checkpoint_path, "r") as fh:
                loaded = json.load(fh)
            if not isinstance(loaded, list):
                raise ValueError("checkpoint is not a JSON list")
            history = [h for h in loaded if isinstance(h, dict)]
            if verbose:
                print(f"Loaded checkpoint with {len(history)} completed candidates from {checkpoint_path}")
        except Exception as e:
            # Unreadable checkpoint: start a fresh search rather than abort.
            history = []
            if verbose or checkpoint_verbose:
                print("Could not read checkpoint, starting a fresh search:", e)
    tried = {h["params_repr"] for h in history if isinstance(h.get("params_repr"), str)}

    param_list = list(ParameterSampler(param_dist, n_iter=n_iter, random_state=random_state))
    sampled_by_repr = {repr(p): p for p in param_list}
    remaining = [p for p in param_list if repr(p) not in tried]

    # (mean_score, params) for every scored candidate of the current sample,
    # seeded with the ones already present in the checkpoint.
    ranked = []
    for h in history:
        key = h.get("params_repr")
        if not isinstance(key, str):
            continue
        known_params = sampled_by_repr.get(key)
        prev_score = h.get("mean_score")
        if known_params is not None and isinstance(prev_score, (int, float)) and np.isfinite(prev_score):
            ranked.append((float(prev_score), known_params))
    best_score = max((s for s, _ in ranked), default=-np.inf)

    iterator = enumerate(remaining, start=1)
    if tqdm is not None and verbose:
        iterator = enumerate(tqdm(remaining, desc="XGB candidates", ncols=100), start=1)
    for i, params in iterator:
        p_repr = repr(params)
        if verbose:
            print(f"\nCandidate {i}/{len(remaining)}: {params}")
        est = clone(estimator)
        est.set_params(**params)
        fold_scores = []
        fold_times = []
        for fold_idx, (tr_idx, te_idx) in enumerate(cv.split(X_train, y_train), start=1):
            t0 = time.time()
            try:
                est.fit(X_train[tr_idx], y_train[tr_idx])
                X_te, y_te = X_train[te_idx], y_train[te_idx]
                if scorer is None:
                    score = float(scoring_callable(est, X_te, y_te))
                else:
                    try:
                        score = float(scorer(est, X_te, y_te))
                    except Exception as scorer_err:
                        if verbose or checkpoint_verbose:
                            if isinstance(scorer_err, TypeError):
                                print("Scorer TypeError, falling back. Error:", scorer_err)
                            else:
                                print("Scorer exception, using fallback:", scorer_err)
                        score = float(scoring_callable(est, X_te, y_te))
            except Exception as e:
                score = float("-inf")
                if verbose:
                    print("  fold error during fit/predict/score:", e)
            elapsed = time.time() - t0
            fold_scores.append(score)
            fold_times.append(elapsed)
            if verbose:
                print(f"  fold {fold_idx}: score={score if np.isfinite(score) else 'FAILED'}, time={elapsed:.1f}s")

        # Mean over the folds that produced a finite score; -inf if none did.
        finite_scores = [s for s in fold_scores if np.isfinite(s)]
        mean_score = float(np.mean(finite_scores)) if finite_scores else float("-inf")
        rec = {"params_repr": p_repr, "params": params, "fold_scores": fold_scores, "fold_times": fold_times, "mean_score": mean_score, "timestamp": time.time()}
        history.append(rec)
        try:
            _save_json_atomic(checkpoint_path, history)
            if verbose:
                print("  Saved checkpoint to", checkpoint_path)
        except Exception as e:
            # Checkpointing is best-effort; the search itself continues.
            if verbose or checkpoint_verbose:
                print("  Could not save checkpoint:", e)

        if np.isfinite(mean_score):
            ranked.append((mean_score, params))
            if mean_score > best_score:
                best_score = mean_score
                if verbose:
                    print(f"  -> New best so far (mean_score={mean_score:.4f}).")

    # Refit once, best first; sorted() is stable so ties keep evaluation order.
    best_est = None
    best_params = None
    ordered = sorted(ranked, key=lambda item: item[0], reverse=True)
    for cand_score, cand_params in ordered:
        try:
            candidate = clone(estimator)
            candidate.set_params(**cand_params)
            candidate.fit(X_train, y_train)
        except Exception as e:
            if verbose:
                print("  -> Refit of best failed:", e)
            continue
        best_est, best_params = candidate, cand_params
        if verbose:
            print(f"Refit best candidate (mean_score={cand_score:.4f}) on full training set.")
        break
    if best_est is None and ordered:
        # Nothing could be refit: still report the top-scoring parameters.
        best_params = ordered[0][1]
    return best_est, best_params, history

def run_random_search(estimator, param_dist, X_train, y_train, cv, n_iter=8, n_jobs=1, scorer=None, random_state=RANDOM_STATE, verbose_search=False, checkpoint_path=None):
    """Run a randomized hyper-parameter search, checkpointed when a path is given.

    With a truthy ``checkpoint_path`` the work is delegated to
    ``checkpointed_random_search`` (``n_jobs`` is not used on that path).
    Otherwise a ``RandomizedSearchCV`` is fitted, under a threading backend
    when ``n_jobs > 1``.

    Args:
        scorer: Optional ``scorer(estimator, X, y)``. It is honoured on both
            paths; ``scoring_callable`` is the default when it is ``None``.

    Returns:
        ``(best_estimator, best_params, extra)`` where ``extra`` is the
        checkpoint history list on the checkpointed path and the fitted
        ``RandomizedSearchCV`` object otherwise.
    """
    if checkpoint_path:
        if verbose_search:
            print("Running checkpointed randomized search (checkpoint_path=%s)" % checkpoint_path)
        best_est, best_params, history = checkpointed_random_search(estimator, param_dist, X_train, y_train, cv, n_iter=n_iter, random_state=random_state, checkpoint_path=checkpoint_path, scorer=scorer, verbose=verbose_search)
        return best_est, best_params, history

    # Previously the scorer argument was ignored here and scoring_callable was always used.
    scoring = scorer if scorer is not None else scoring_callable
    search = RandomizedSearchCV(estimator, param_distributions=param_dist, n_iter=n_iter,
                                scoring=scoring, cv=cv, random_state=random_state,
                                n_jobs=n_jobs, verbose=1 if verbose_search else 0)
    t0 = time.time()
    if n_jobs and n_jobs > 1:
        with parallel_backend('threading', n_jobs=n_jobs):
            search.fit(X_train, y_train)
    else:
        search.fit(X_train, y_train)
    elapsed = time.time() - t0
    print(f"Search finished in {elapsed:.1f}s; best_score={search.best_score_:.4f}")
    return search.best_estimator_, getattr(search, "best_params_", None), search

def evaluate_classification(y_true, y_pred_label, y_pred_prob):
    """Compute binary-classification metrics; a failing metric never raises.

    Args:
        y_true: True 0/1 labels.
        y_pred_label: Predicted 0/1 labels.
        y_pred_prob: Predicted scores / probabilities for the positive class.

    Returns:
        Dict with ``precision``, ``recall``, ``f1``, ``average_precision``,
        ``roc_auc``, ``brier`` (floats, NaN when the metric could not be
        computed) and ``tn``, ``fp``, ``fn``, ``tp`` (Python ints, or ``None``
        when the confusion matrix could not be computed).
    """
    y_true = np.asarray(y_true)
    y_pred_label = np.asarray(y_pred_label)
    y_pred_prob = np.asarray(y_pred_prob)
    try:
        prec = float(precision_score(y_true, y_pred_label, zero_division=0))
        rec = float(recall_score(y_true, y_pred_label, zero_division=0))
        f1 = float(f1_score(y_true, y_pred_label, zero_division=0))
    except Exception:
        prec = rec = f1 = float("nan")
    try:
        ap = float(average_precision_score(y_true, y_pred_prob))
    except Exception:
        ap = float("nan")
    try:
        roc = float(roc_auc_score(y_true, y_pred_prob))
    except Exception:
        roc = float("nan")
    try:
        brier = float(brier_score_loss(y_true, y_pred_prob))
    except Exception:
        brier = float("nan")
    try:
        # labels=[0, 1] forces a 2x2 matrix even when only one class occurs;
        # without it the 1x1 result cannot be unpacked into four counts.
        cm = confusion_matrix(y_true, y_pred_label, labels=[0, 1])
        tn, fp, fn, tp = (int(v) for v in cm.ravel())
    except Exception:
        tn = fp = fn = tp = None
    return {
        "precision": prec, "recall": rec, "f1": f1,
        "average_precision": ap, "roc_auc": roc, "brier": brier,
        "tn": tn, "fp": fp, "fn": fn, "tp": tp
    }

# ----------------- main -----------------
def _persist_artifact(local_path, out_dir, label, requester_pays_project=None):
    """Upload (``gs://`` out_dir) or move (local out_dir) one file and report where it went."""
    name = os.path.basename(local_path)
    if is_gcs_path(out_dir):
        dest = gcs_join(out_dir, name)
        upload_file_to_gcs(local_path, dest, requester_pays_project=requester_pays_project)
    else:
        dest = os.path.join(out_dir, name)
        shutil.move(local_path, dest)
    print(f"Saved {label} to:", dest)
    return dest


def main(args):
    """Tune, train, evaluate and persist the binary XGBoost classifier.

    Steps: load the processed arrays from ``args.npz`` (binary labels come from
    the NPZ, or from the ``y_bin`` column of the samples CSVs when the NPZ has
    none); run the randomized search; save the search winner; report
    validation metrics; retrain on train+val; save the final model, the test
    predictions CSV and a JSON summary to ``args.out_dir``. Artifact filenames
    carry ``args.tag`` when it is non-empty.

    Args:
        args: Parsed CLI namespace (npz, samples_dir, out_dir, n_iter, n_jobs,
            ts_splits, use_predefined_split, max_train_samples, fast, verbose,
            checkpoint_path, prob_threshold, threshold, use_gcsfs,
            requester_pays_project and optionally tag).

    Raises:
        RuntimeError: xgboost is unavailable, the NPZ lacks ``X_train``, or no
            binary labels can be found.
    """
    if XGBClassifier is None:
        raise RuntimeError("xgboost.XGBClassifier not available. Please install xgboost (pip install xgboost).")

    # Tag helper to avoid overwriting outputs across lakes/runs
    tag = getattr(args, "tag", "")

    def tag_name(base_filename: str) -> str:
        if not tag:
            return base_filename
        base_noext, ext = os.path.splitext(base_filename)
        return f"{base_noext}_{tag}{ext}"

    def load_frames():
        return load_samples_csvs(args.samples_dir, requester_pays_project=args.requester_pays_project, use_gcsfs=args.use_gcsfs)

    def persist(local_path, label):
        return _persist_artifact(local_path, args.out_dir, label, requester_pays_project=args.requester_pays_project)

    # If checkpoint_path not provided, create a tagged one in tempdir so different tags don't clash
    checkpoint_path = args.checkpoint_path or os.path.join(tempfile.gettempdir(), tag_name("xgb_search_checkpoint.json"))

    # ensure local outdir exists (for local intermediate files)
    if not is_gcs_path(args.out_dir):
        os.makedirs(args.out_dir, exist_ok=True)

    print("Loading processed arrays from:", args.npz)
    data = load_npz(args.npz, requester_pays_project=args.requester_pays_project, use_gcsfs=args.use_gcsfs)

    if 'X_train' not in data:
        raise RuntimeError("processed_dataset must contain X_train/X_val/X_test arrays (and y_train_bin etc).")

    X_train = data['X_train']
    X_val = data['X_val']
    X_test = data['X_test']

    train_df = val_df = test_df = None
    if all(key in data for key in ('y_train_bin', 'y_val_bin', 'y_test_bin')):
        y_train_bin = np.asarray(data['y_train_bin'])
        y_val_bin = np.asarray(data['y_val_bin'])
        y_test_bin = np.asarray(data['y_test_bin'])
    else:
        print("Binary targets not found in NPZ — attempting to load from samples CSVs.")
        train_df, val_df, test_df = load_frames()
        # All three frames are read below, so all three must carry y_bin.
        if any('y_bin' not in frame.columns for frame in (train_df, val_df, test_df)):
            raise RuntimeError("y_bin column not found in samples CSVs nor NPZ. Please run data prep to save y_bin.")
        y_train_bin = train_df['y_bin'].astype(int).values
        y_val_bin = val_df['y_bin'].astype(int).values
        y_test_bin = test_df['y_bin'].astype(int).values

    print(f"Shapes: X_train={X_train.shape}, X_val={X_val.shape}, X_test={X_test.shape}")

    # load sample CSVs for metadata/predictions (non-fatal); reuse the frames
    # already loaded for the labels instead of downloading them a second time
    if test_df is None:
        try:
            train_df, val_df, test_df = load_frames()
        except Exception as e:
            print("Could not load samples CSVs for metadata (continuing without them):", e)
            train_df = val_df = test_df = None

    # respect max_train_samples
    if args.max_train_samples:
        X_train_sub = X_train[:args.max_train_samples]
        y_train_sub = y_train_bin[:args.max_train_samples]
        print("Using subsample for tuning:", X_train_sub.shape)
    else:
        X_train_sub, y_train_sub = X_train, y_train_bin

    # build CV
    if args.use_predefined_split:
        X_tune = np.vstack([X_train_sub, X_val])
        y_tune = np.concatenate([y_train_sub, y_val_bin])
        # -1 = always in training, 0 = the single validation fold
        test_fold = np.concatenate([np.full(X_train_sub.shape[0], -1), np.zeros(X_val.shape[0], dtype=int)])
        cv_obj = PredefinedSplit(test_fold)
        print("Using PredefinedSplit (train+val) for tuning.")
    else:
        cv_obj = TimeSeriesSplit(n_splits=args.ts_splits)
        X_tune, y_tune = X_train_sub, y_train_sub
        print(f"Using TimeSeriesSplit(n_splits={args.ts_splits})")

    # compute imbalance ratio
    try:
        pos = int(np.sum(y_train_sub == 1))
        neg = int(np.sum(y_train_sub == 0))
        imbalance_ratio = float(neg / pos) if pos > 0 else 1.0
    except Exception:
        imbalance_ratio = 1.0

    # build param_dist
    if args.fast:
        param_dist = {
            "n_estimators": [100, 200],
            "max_depth": [3, 5],
            "learning_rate": [0.05, 0.1],
            "subsample": [0.6, 0.8],
            "colsample_bytree": [0.6, 0.8],
            "gamma": [0, 1],
            "reg_lambda": [1, 5],
            "reg_alpha": [0, 1],
            "scale_pos_weight": [1, imbalance_ratio],
            "use_label_encoder": [False],
            "verbosity": [0]
        }
    else:
        param_dist = {
            "n_estimators": [200, 300, 500],
            "max_depth": [3, 5, 8, 10],
            "learning_rate": [0.01, 0.03, 0.05, 0.1],
            "subsample": [0.5, 0.7, 0.9],
            "colsample_bytree": [0.5, 0.7, 0.9],
            "gamma": [0, 0.5, 1.0, 2.0],
            "reg_lambda": [0.5, 1.0, 5.0],
            "reg_alpha": [0.0, 0.5, 1.0],
            "scale_pos_weight": [1, imbalance_ratio],
            "use_label_encoder": [False],
            "verbosity": [0]
        }

    # instantiate XGB
    xgb = XGBClassifier(random_state=RANDOM_STATE, n_jobs=1, objective="binary:logistic", eval_metric="logloss")

    best_xgb, xgb_params, xgb_search = run_random_search(
        xgb, param_dist, X_tune, y_tune, cv_obj,
        n_iter=args.n_iter, n_jobs=min(args.n_jobs, 8),
        scorer=None, random_state=RANDOM_STATE, verbose_search=args.verbose, checkpoint_path=checkpoint_path
    )
    if best_xgb is None:
        print("No XGB candidate found (search returned None). Exiting.")
        return

    # Save locally then upload if needed (tagged filenames)
    local_tmp = tempfile.mkdtemp(prefix="xgb_out_")
    try:
        best_initial_local = os.path.join(local_tmp, tag_name("xgb_best_initial_binary.joblib"))
        joblib.dump(best_xgb, best_initial_local)
        persist(best_initial_local, "initial best XGB")

        print("XGB best params:", xgb_params)

        # Validate on val
        try:
            y_val_prob = best_xgb.predict_proba(X_val)[:, 1]
            y_val_label = (y_val_prob >= args.prob_threshold).astype(int)
            val_metrics = evaluate_classification(y_val_bin, y_val_label, y_val_prob)
            print("Validation metrics (XGB):", val_metrics)
        except Exception as e:
            print("Validation failed:", e)
            val_metrics = None

        # Retrain on train+val
        X_train_val = np.vstack([X_train, X_val])
        y_train_val = np.concatenate([y_train_bin, y_val_bin])
        final = XGBClassifier(**best_xgb.get_params())
        final.set_params(n_jobs=args.n_jobs, random_state=RANDOM_STATE, use_label_encoder=False, verbosity=0)
        t0 = time.time()
        final.fit(X_train_val, y_train_val)
        print("Final XGB trained in", time.time() - t0, "s")

        final_local = os.path.join(local_tmp, tag_name("final_xgb_classifier_binary.joblib"))
        joblib.dump(final, final_local)
        persist(final_local, "final XGB")

        # Test evaluation & predictions
        y_test_prob = final.predict_proba(X_test)[:, 1]
        y_test_label = (y_test_prob >= args.prob_threshold).astype(int)
        test_metrics = evaluate_classification(y_test_bin, y_test_label, y_test_prob)
        print("Test metrics (XGB):", test_metrics)

        # The metadata frame is only usable when it lines up row-for-row with
        # X_test; otherwise write the minimal file instead of crashing here.
        use_test_meta = test_df is not None and len(test_df) == len(y_test_prob)
        if test_df is not None and not use_test_meta:
            print(f"samples_test.csv has {len(test_df)} rows but X_test has {len(y_test_prob)}; writing minimal predictions file instead.")

        if use_test_meta:
            test_out = test_df.copy()
            if 'y_raw' in test_out.columns:
                test_out['y_true_bin'] = test_out['y_raw'].apply(lambda v: int(v >= args.threshold) if not pd.isna(v) else np.nan)
            else:
                test_out['y_true_bin'] = y_test_bin
            test_out['y_pred_prob'] = y_test_prob
            test_out['y_pred_label'] = y_test_label
            preds_local = os.path.join(local_tmp, tag_name("predictions_test_xgb_binary.csv"))
            test_out.to_csv(preds_local, index=False)
        else:
            df_min = pd.DataFrame({"y_true_bin": y_test_bin, "y_pred_prob": y_test_prob, "y_pred_label": y_test_label})
            preds_local = os.path.join(local_tmp, tag_name("predictions_test_xgb_binary_min.csv"))
            df_min.to_csv(preds_local, index=False)
        persist(preds_local, "test predictions")

        # Save summary (tagged)
        summary = {
            "model": "xgb_classifier_binary",
            "val_metrics": val_metrics,
            "test_metrics": test_metrics,
            "best_params": xgb_params,
            "n_iter": int(args.n_iter),
            "checkpoint_path": checkpoint_path
        }
        summary_safe = _to_serializable(summary)
        summary_local = os.path.join(local_tmp, tag_name("training_summary_xgb_classifier_binary.json"))
        with open(summary_local, "w") as fh:
            json.dump(summary_safe, fh, indent=2)
        persist(summary_local, "summary")

        print("All outputs saved to", args.out_dir)

    finally:
        # Scratch directory is disposable; never let its cleanup mask an error.
        shutil.rmtree(local_tmp, ignore_errors=True)

# Script entry point: build the CLI, parse it, and hand the namespace to main().
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Input / output locations (local paths or gs:// URIs).
    parser.add_argument("--npz", type=str, default=DEFAULT_NPZ, help="Processed dataset .npz (must contain X_train/X_val/X_test and y_*_bin)")
    parser.add_argument("--samples_dir", type=str, default=DEFAULT_SAMPLES_DIR, help="Folder with samples_train/val/test CSVs (used for metadata & fallback)")
    parser.add_argument("--out_dir", type=str, default=DEFAULT_OUT_DIR, help="Output folder (local path or gs://bucket/path/)")
    # Search size and parallelism.
    parser.add_argument("--n_iter", type=int, default=16)
    parser.add_argument("--n_jobs", type=int, default=4)
    parser.add_argument("--ts_splits", type=int, default=5)
    # PredefinedSplit (train vs. val) is the default; the flag switches it off.
    parser.add_argument("--no_predefined_split", action="store_false", dest="use_predefined_split", help="Disable using PredefinedSplit and use TimeSeriesSplit instead.")
    parser.set_defaults(use_predefined_split=True)
    parser.add_argument("--max_train_samples", type=int, default=None)
    parser.add_argument("--fast", action="store_true")
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument("--checkpoint_path", type=str, default=None, help="Path to JSON checkpoint file to save/resume search (local recommended)")
    parser.add_argument("--prob_threshold", type=float, default=0.5, help="Probability threshold to convert probs into labels for evaluation")
    parser.add_argument("--threshold", type=float, default=20.0, help="Chl a threshold used to create y_bin (informational)")
    # NOTE(review): the read helpers in this script use gcsfs directly when this
    # flag is set (they test `not use_gcsfs` before trying the storage client),
    # rather than only after a storage-client failure as the help text suggests.
    parser.add_argument("--use_gcsfs", action="store_true", help="Attempt gcsfs for reads if storage client fails (default: prefer storage client)")
    parser.add_argument("--requester_pays_project", type=str, default=None, help="GCP project id to use when accessing requester-pays buckets")
    parser.add_argument("--tag", type=str, default="pig", help="Optional tag appended to artifact filenames to avoid overwriting (default: pig). Use empty string to disable tagging.")
    # parse_known_args() tolerates arguments this parser does not define;
    # the unrecognised ones are collected in _unknown and ignored.
    args, _unknown = parser.parse_known_args()
    main(args)

Loading processed arrays from: gs://final_data_kgc2/models/Pigeon_binary/processed_dataset_binary.npz
Shapes: X_train=(948281, 39), X_val=(74421, 39), X_test=(495447, 39)
Using PredefinedSplit (train+val) for tuning.
Saved initial best XGB to: gs://final_data_kgc2/rf_results/xgb_best_initial_binary_pig.joblib
XGB best params: {'verbosity': 0, 'use_label_encoder': False, 'subsample': 0.5, 'scale_pos_weight': 1.3642127360395713, 'reg_lambda': 1.0, 'reg_alpha': 1.0, 'n_estimators': 500, 'max_depth': 3, 'learning_rate': 0.05, 'gamma': 2.0, 'colsample_bytree': 0.5}
Validation metrics (XGB): {'precision': 0.786209184049532, 'recall': 0.8950658722832928, 'f1': 0.8371134829697064, 'average_precision': 0.8890391218396061, 'roc_auc': 0.9517202854299869, 'brier': 0.08197013808898203, 'tn': 44786, 'fp': 5801, 'fn': 2501, 'tp': 21333}
Final XGB trained in 12.07669472694397 s
Saved final XGB to: gs://final_data_kgc2/rf_results/final_xgb_classifier_binary_pig.joblib
Test metrics (XGB): {'precision': 

In [None]:
#!/usr/bin/env python3
"""
Train a Bayesian logistic classifier for binary bloom/no-bloom — GCS-ready (tagged, non-overwriting)

This is an updated variant of your script that:
- Adds a --tag CLI argument (default "pig").
- Includes the tag in saved filenames (scaler, model, results, checkpoints).
- Writes the tag into checkpoint/history entries and the final results JSON.
- Avoids overwriting previous outputs by embedding a timestamp (and tag) in saved filenames.
- Adds a small atomic JSON saver used by the checkpoint procedure.

Usage (example):
python3 train_bayes_binary_gcs_tagged_pig.py \
  --npz "gs://final_data_kgc2/models/Pigeon_binary/processed_dataset_binary.npz" \
  --samples_dir "gs://final_data_kgc2/Final_data/" \
  --out_dir "gs://final_data_kgc2/bayes_results/" \
  --n_iter 12 --n_jobs 2 --max_iter 5000 --solver saga --scale_features \
  --tag pig --verbose

Notes:
- When running inside Jupyter / IPython the script now uses parse_known_args() to ignore
  Jupyter kernel arguments (e.g. -f /.../kernel-*.json) so it won't raise an argparse error.
- This preserves your original logic but ensures outputs carry the "tag" metadata
  and are written with timestamped filenames to avoid overwrites.
"""
from __future__ import annotations
import os
import argparse
import json
import tempfile
import shutil
import time
import io
from pathlib import Path

import numpy as np
import pandas as pd
import joblib

from sklearn.model_selection import TimeSeriesSplit, PredefinedSplit, ParameterSampler
from sklearn.metrics import (
    precision_score, recall_score, f1_score, roc_auc_score, average_precision_score,
    brier_score_loss, confusion_matrix
)
from sklearn.base import clone
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

# Optional PyMC3 imports
try:
    import pymc3 as pm
    import arviz as az
    PM_AVAILABLE = True
except Exception:
    pm = None
    az = None
    PM_AVAILABLE = False

# joblib parallel backend
from joblib import parallel_backend

# optional imports (gcsfs + progress)
try:
    import gcsfs
except Exception:
    gcsfs = None

try:
    from tqdm.auto import tqdm
except Exception:
    tqdm = None

# preferred storage client
try:
    from google.cloud import storage
except Exception:
    storage = None

import warnings
warnings.filterwarnings("ignore", message="pkg_resources is deprecated as an API")
warnings.filterwarnings("ignore", category=FutureWarning)

# ---------------------------
# Defaults pointing to your bucket (Pigeon)
# ---------------------------
DEFAULT_NPZ = "gs://final_data_kgc2/models/Pigeon_binary/processed_dataset_binary.npz"
DEFAULT_SAMPLES_DIR = "gs://final_data_kgc2/Final_data/"
DEFAULT_OUT_DIR = "gs://final_data_kgc2/bayes_results/"
RANDOM_STATE = 0

# ---------------------------
# Helpers: GCS / local I/O
# ---------------------------
def is_gcs_path(p: str) -> bool:
    """Return True only for strings that begin with the ``gs://`` scheme."""
    if not isinstance(p, str):
        # Path objects, None, etc. are never treated as GCS locations.
        return False
    return p.startswith("gs://")

def gcs_join(prefix: str, *parts: str) -> str:
    """Join ``prefix`` and ``parts`` with exactly one ``/`` between segments.

    Trailing slashes of the prefix and of every intermediate segment are
    dropped, as are leading slashes of each part; a trailing slash on the
    last part is kept. Parts are converted with ``str()``.
    """
    joined = prefix.rstrip('/')
    for segment in parts:
        head = joined.rstrip('/')
        tail = str(segment).lstrip('/')
        joined = f"{head}/{tail}"
    return joined

def upload_file_to_gcs(local_path: str, gs_uri: str, requester_pays_project: str | None = None):
    """Upload ``local_path`` to the object named by ``gs_uri``.

    The google-cloud-storage client is used when importable, otherwise gcsfs.

    Raises:
        ValueError: ``gs_uri`` is not a ``gs://`` path.
        FileNotFoundError: the destination bucket lookup returns nothing.
        RuntimeError: the storage-client upload fails, or neither
            google-cloud-storage nor gcsfs is installed.
    """
    if not is_gcs_path(gs_uri):
        raise ValueError("gs_uri must be a gs:// path")

    if storage is None:
        # No storage client: gcsfs is the only remaining option.
        if gcsfs is None:
            raise RuntimeError("No method available to upload to GCS: install google-cloud-storage or gcsfs")
        gcsfs.GCSFileSystem(token="google").put(local_path, gs_uri)
        return

    client_opts = {"project": requester_pays_project} if requester_pays_project else {}
    gcs_client = storage.Client(**client_opts)
    bucket_name, _, object_name = gs_uri.split("gs://", 1)[1].partition("/")
    # lookup_bucket() yields None instead of raising, which allows a clear message.
    target_bucket = gcs_client.lookup_bucket(bucket_name)
    if target_bucket is None:
        raise FileNotFoundError(f"Destination bucket does not exist or is not accessible: gs://{bucket_name}")
    target_blob = target_bucket.blob(object_name)
    try:
        target_blob.upload_from_filename(local_path)
    except Exception as e:
        raise RuntimeError(f"Upload failed for {local_path} -> {gs_uri}: {e}") from e

def fetch_gcs_to_local(gcs_path: str, local_path: str, requester_pays_project: str | None = None, use_gcsfs: bool = False):
    """Download the object at ``gcs_path`` into ``local_path``; return ``local_path``.

    Uses the google-cloud-storage client unless it is missing or ``use_gcsfs``
    is set, in which case the object is streamed through gcsfs.

    Raises:
        ValueError: ``gcs_path`` is not a ``gs://`` path.
        FileNotFoundError: bucket lookup returns nothing or the blob does not
            exist (storage-client route).
        RuntimeError: gcsfs is required but not installed.
    """
    if not is_gcs_path(gcs_path):
        raise ValueError("gcs_path must be gs://...")

    storage_route = storage is not None and not use_gcsfs
    if not storage_route:
        if gcsfs is None:
            raise RuntimeError("gcsfs not installed and storage client not available")
        remote_fs = gcsfs.GCSFileSystem(token="google")
        with remote_fs.open(gcs_path, "rb") as reader, open(local_path, "wb") as writer:
            shutil.copyfileobj(reader, writer)
        return local_path

    client_opts = {"project": requester_pays_project} if requester_pays_project else {}
    gcs_client = storage.Client(**client_opts)
    bucket_name, _, object_name = gcs_path.split("gs://", 1)[1].partition("/")
    source_bucket = gcs_client.lookup_bucket(bucket_name)
    if source_bucket is None:
        raise FileNotFoundError(f"Bucket not found: gs://{bucket_name}")
    source_blob = source_bucket.blob(object_name)
    if not source_blob.exists():
        raise FileNotFoundError(f"Blob not found: {gcs_path}")
    source_blob.download_to_filename(local_path)
    return local_path

def save_local_or_gcs(local_src_path: str, out_dir: str, dest_basename: str | None = None, requester_pays_project: str | None = None) -> str:
    """
    Deliver a local file into ``out_dir`` and return its final location.

    A ``gs://`` out_dir means upload followed by a best-effort removal of the
    local source; any other out_dir is created if needed and the file is
    moved into it. ``dest_basename`` defaults to the source file's basename.
    Name collisions are not detected here - callers are expected to pass
    distinct names (they use tag+timestamp).
    """
    target_name = os.path.basename(local_src_path) if dest_basename is None else dest_basename

    if not is_gcs_path(out_dir):
        os.makedirs(out_dir, exist_ok=True)
        local_dest = os.path.join(out_dir, target_name)
        shutil.move(local_src_path, local_dest)
        return local_dest

    remote_dest = gcs_join(out_dir, target_name)
    upload_file_to_gcs(local_src_path, remote_dest, requester_pays_project=requester_pays_project)
    try:
        # The upload already succeeded; a leftover local copy is harmless.
        os.remove(local_src_path)
    except Exception:
        pass
    return remote_dest

def load_samples_csvs(samples_dir: str, requester_pays_project: str | None = None, use_gcsfs: bool = False):
    """Load ``samples_train.csv``, ``samples_val.csv`` and ``samples_test.csv``.

    Each CSV is read with ``date_t`` and ``date_target`` parsed as dates.

    Args:
        samples_dir: Local directory or ``gs://bucket[/prefix]`` holding the
            three CSVs.
        requester_pays_project: Passed as ``project`` to the storage client
            when set.
        use_gcsfs: Read through gcsfs instead of the storage client.

    Returns:
        ``(train_df, val_df, test_df)``.

    Raises:
        FileNotFoundError: bucket or CSV blob not found (storage-client route).
        RuntimeError: the gcsfs route is needed but gcsfs is not installed.
    """
    names = ("samples_train.csv", "samples_val.csv", "samples_test.csv")
    date_cols = ['date_t', 'date_target']
    location = str(samples_dir)

    if not is_gcs_path(location):
        return tuple(
            pd.read_csv(str(Path(samples_dir) / name), parse_dates=date_cols)
            for name in names
        )

    if storage is not None and not use_gcsfs:
        client_kwargs = {}
        if requester_pays_project:
            client_kwargs["project"] = requester_pays_project
        # One client / bucket lookup serves all three files.
        client = storage.Client(**client_kwargs)
        bucket_name, _, prefix = location.split("gs://", 1)[1].partition("/")
        prefix = prefix.rstrip('/')
        bucket = client.lookup_bucket(bucket_name)
        if bucket is None:
            raise FileNotFoundError(f"Bucket not found: gs://{bucket_name}")
        frames = []
        for name in names:
            # An empty prefix (samples at the bucket root) must not yield "/name".
            blob_path = f"{prefix}/{name}" if prefix else name
            blob = bucket.blob(blob_path)
            if not blob.exists():
                raise FileNotFoundError(f"Samples CSV not found: gs://{bucket_name}/{blob_path}")
            frames.append(pd.read_csv(io.BytesIO(blob.download_as_bytes()), parse_dates=date_cols))
        return tuple(frames)

    if gcsfs is None:
        raise RuntimeError("gcsfs not installed and storage client not available")
    fs = gcsfs.GCSFileSystem(token="google")
    frames = []
    for name in names:
        with fs.open(location.rstrip('/') + '/' + name, "rb") as f:
            frames.append(pd.read_csv(f, parse_dates=date_cols))
    return tuple(frames)

def load_npz(npz_path: str, requester_pays_project: str | None = None, use_gcsfs: bool = False) -> dict:
    """Load every array of an ``.npz`` archive into a plain dict.

    A ``gs://`` path is first downloaded to a temporary local file via
    ``fetch_gcs_to_local``; the temporary file is removed afterwards, including
    when the download or the load fails.

    Args:
        npz_path: Local path or ``gs://`` URI of the archive.
        requester_pays_project: Forwarded to ``fetch_gcs_to_local``.
        use_gcsfs: Forwarded to ``fetch_gcs_to_local``.

    Returns:
        Dict mapping each archive member name to its array.
    """
    def _read_all(local_path):
        # SECURITY NOTE: allow_pickle=True unpickles object arrays, which can execute
        # arbitrary code. Only load archives from trusted locations.
        # The context manager closes the underlying zip handle once all members are read.
        with np.load(local_path, allow_pickle=True) as archive:
            return {k: archive[k] for k in archive.files}

    if not is_gcs_path(npz_path):
        return _read_all(npz_path)

    tmp = tempfile.NamedTemporaryFile(suffix=".npz", delete=False)
    tmp.close()
    try:
        fetch_gcs_to_local(npz_path, tmp.name, requester_pays_project=requester_pays_project, use_gcsfs=use_gcsfs)
        return _read_all(tmp.name)
    finally:
        # Best-effort cleanup of the downloaded copy.
        try:
            os.unlink(tmp.name)
        except OSError:
            pass

# ---------------------------
# Utility: JSON-safe conversion
# ---------------------------
def _to_serializable(obj):
    """Recursively convert NumPy containers and scalars into JSON-friendly builtins.

    - dict: values are converted recursively; NumPy-scalar keys are unwrapped to
      their Python equivalents (``json.dump`` rejects e.g. ``np.int64`` keys).
    - list / tuple: converted element-wise into a list.
    - ``np.ndarray``: converted via ``tolist()``.
    - NumPy integer / floating / bool scalars: converted to ``int`` / ``float`` / ``bool``.
    - Anything else is returned unchanged.
    """
    if isinstance(obj, dict):
        return {
            (k.item() if isinstance(k, np.generic) else k): _to_serializable(v)
            for k, v in obj.items()
        }
    if isinstance(obj, (list, tuple)):
        return [_to_serializable(v) for v in obj]
    if isinstance(obj, np.ndarray):
        return _to_serializable(obj.tolist())
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, (np.bool_, bool)):
        return bool(obj)
    return obj

def timestamp_str():
    """Return the current UTC time formatted as ``YYYYmmddTHHMMSS``."""
    utc_now = time.gmtime()
    return time.strftime("%Y%m%dT%H%M%S", utc_now)

def _save_json_atomic(path: str, data):
    """Serialize ``data`` as JSON to ``path`` via a temp file plus ``os.replace``.

    The temp file is created in the destination's directory, flushed and fsynced,
    then swapped into place. A leftover temp file is removed on failure.
    """
    scratch = None
    try:
        target_dir = os.path.dirname(path) or "."
        handle, scratch = tempfile.mkstemp(prefix="tmp_json_", dir=target_dir, text=True)
        with os.fdopen(handle, "w") as out:
            payload = _to_serializable(data)
            json.dump(payload, out, indent=2)
            out.flush()
            os.fsync(out.fileno())
        os.replace(scratch, path)
    finally:
        # After a successful replace the temp name no longer exists, so this only
        # fires when something went wrong before the swap.
        leftover = scratch and os.path.exists(scratch)
        if leftover:
            try:
                os.remove(scratch)
            except Exception:
                pass

# ----------------- scoring callable (same) -----------------
def scoring_callable(estimator, X, y):
    """Score a fitted estimator on ``(X, y)``.

    Uses average precision on ``predict_proba`` column 1 when available, otherwise
    on ``decision_function`` output; falls back to F1 on hard predictions.
    Any exception yields ``-inf``.
    """
    try:
        if hasattr(estimator, "predict_proba"):
            ranking = estimator.predict_proba(X)[:, 1]
        elif hasattr(estimator, "decision_function"):
            ranking = estimator.decision_function(X)
        else:
            labels = estimator.predict(X)
            return float(f1_score(y, labels, zero_division=0))
        return float(average_precision_score(y, ranking))
    except Exception:
        return float("-inf")

# ----------------- checkpointed randomized search (estimator-safe) -----------------
def _fit_with_optional_kwargs(est, X, y, fit_kwargs):
    """Fit ``est``; extra fit kwargs go only to BayesLogisticClassifier, with a plain retry on TypeError."""
    if type(est).__name__ == "BayesLogisticClassifier":
        try:
            est.fit(X, y, **fit_kwargs)
        except TypeError:
            est.fit(X, y)
    else:
        est.fit(X, y)
    return est


def checkpointed_random_search(estimator, param_dist, X_train, y_train, cv,
                               n_iter=8, random_state=RANDOM_STATE,
                               checkpoint_path: str | None = None,
                               scorer=None, verbose=False, checkpoint_verbose=False,
                               run_tag: str | None = None, out_dir: str | None = None, requester_pays_project: str | None = None):
    """Randomized hyper-parameter search that checkpoints its history after every candidate.

    Candidates come from ``ParameterSampler(param_dist, n_iter, random_state)``. Keys that are
    valid estimator parameters are applied with ``set_params``; the rest are treated as fit
    kwargs (forwarded only to ``BayesLogisticClassifier``). Each candidate is scored on every
    split from ``cv.split(X_train)``. A failed fold scores ``-inf`` and is excluded from the
    candidate's mean; a candidate whose folds all fail gets a mean of ``-inf``.

    After each candidate the full history is written atomically to ``checkpoint_path`` (a temp
    path is generated when None). If ``out_dir`` is given, a *copy* of the checkpoint is passed
    to ``save_local_or_gcs`` under a tag+timestamp name, so the local checkpoint stays in place.

    Candidates whose ``repr`` already appears in an existing checkpoint are skipped. Their
    scores are not reconsidered when picking the best, so a fully completed checkpoint yields
    ``best_est`` of None.

    Args:
        estimator: Prototype estimator; cloned for every candidate and for each best refit.
        param_dist: Parameter distribution accepted by ``ParameterSampler``.
        X_train, y_train: Arrays indexable by the index arrays produced by ``cv``.
        cv: Splitter exposing ``split(X)``.
        n_iter, random_state: Forwarded to ``ParameterSampler``.
        checkpoint_path: Local JSON checkpoint path.
        scorer: Optional ``scorer(est, X, y)``; ``scoring_callable`` is used if None or on error.
        verbose: Print per-candidate and per-fold progress.
        checkpoint_verbose: Print warnings about failed checkpoint mirroring.
        run_tag: Stored under ``'tag'`` in each record and used in file names.
        out_dir: Optional local directory or gs:// prefix receiving checkpoint copies.
        requester_pays_project: Forwarded to ``save_local_or_gcs``.

    Returns:
        ``(best_est, best_params, history)``. ``best_est`` is the best candidate refit on all of
        ``X_train`` / ``y_train``, or None when nothing scored above ``-inf``.
    """
    if checkpoint_path is None:
        # default local checkpoint in tempdir, include tag+timestamp to avoid collisions
        fname = f"bayes_search_checkpoint_{run_tag or 'run'}_{timestamp_str()}.json"
        checkpoint_path = os.path.join(tempfile.gettempdir(), fname)

    history = []
    tried = set()
    if os.path.exists(checkpoint_path):
        try:
            with open(checkpoint_path, "r") as fh:
                history = json.load(fh)
            tried = {h.get('params_repr') for h in history if 'params_repr' in h}
            if verbose:
                print(f"Loaded checkpoint with {len(history)} completed candidates from {checkpoint_path}")
        except Exception:
            # Unreadable or corrupt checkpoint: start from scratch.
            history = []
            tried = set()

    param_list = list(ParameterSampler(param_dist, n_iter=n_iter, random_state=random_state))
    remaining = [p for p in param_list if repr(p) not in tried]
    best_score = -np.inf
    best_params = None
    best_est = None

    iterator = enumerate(remaining, start=1)
    if tqdm is not None and verbose:
        iterator = enumerate(tqdm(remaining, desc="Bayes candidates", ncols=100), start=1)

    for i, params in iterator:
        p_repr = repr(params)
        if verbose:
            print(f"\nCandidate {i}/{len(remaining)}: {params}")
        est = clone(estimator)
        try:
            valid_keys = set(est.get_params().keys())
        except Exception:
            valid_keys = set()
        # Split sampled values into constructor params and fit-time kwargs.
        ctor_params = {k: v for k, v in params.items() if k in valid_keys}
        fit_kwargs = {k: v for k, v in params.items() if k not in valid_keys}
        if ctor_params:
            try:
                est.set_params(**ctor_params)
            except Exception:
                if verbose:
                    print("Warning: set_params failed for some params; continuing.")

        fold_scores = []
        fold_times = []
        for fold_idx, (tr_idx, te_idx) in enumerate(cv.split(X_train), start=1):
            t0 = time.time()
            try:
                _fit_with_optional_kwargs(est, X_train[tr_idx], y_train[tr_idx], fit_kwargs)
                try:
                    if scorer is not None:
                        score = float(scorer(est, X_train[te_idx], y_train[te_idx]))
                    else:
                        score = float(scoring_callable(est, X_train[te_idx], y_train[te_idx]))
                except Exception:
                    score = float(scoring_callable(est, X_train[te_idx], y_train[te_idx]))
            except Exception as e:
                score = float("-inf")
                if verbose:
                    print("  fold error during fit/predict/score:", e)
            elapsed = time.time() - t0
            fold_scores.append(score)
            fold_times.append(elapsed)
            if verbose:
                print(f"  fold {fold_idx}: score={score if np.isfinite(score) else 'FAILED'}, time={elapsed:.1f}s")

        # Average only folds with a finite score; with none, mark the candidate as failed
        # instead of averaging an empty list (which warns and produces NaN).
        finite_scores = [s for s in fold_scores if np.isfinite(s)]
        mean_score = float(np.nanmean(finite_scores)) if finite_scores else float("-inf")

        rec = {"params_repr": p_repr, "params": params, "fold_scores": fold_scores, "fold_times": fold_times,
               "mean_score": mean_score, "timestamp": time.time()}
        if run_tag is not None:
            rec["tag"] = run_tag
        history.append(rec)

        try:
            _save_json_atomic(path=checkpoint_path, data=history)
            if verbose:
                print("  Saved checkpoint to", checkpoint_path)
            # mirror checkpoint to out_dir if provided (non-overwriting via timestamped dest)
            if out_dir is not None:
                mirror_src = None
                try:
                    dest_name = f"bayes_search_checkpoint_{run_tag or 'run'}_{timestamp_str()}.json"
                    # save_local_or_gcs moves its source for a local out_dir, so hand it a
                    # throwaway copy and keep checkpoint_path intact for resuming.
                    fd, mirror_src = tempfile.mkstemp(prefix="ckpt_mirror_", suffix=".json")
                    os.close(fd)
                    shutil.copyfile(checkpoint_path, mirror_src)
                    save_local_or_gcs(mirror_src, out_dir, dest_basename=dest_name, requester_pays_project=requester_pays_project)
                except Exception as e:
                    if checkpoint_verbose:
                        print("  Warning: failed to mirror checkpoint to out_dir:", e)
                finally:
                    if mirror_src and os.path.exists(mirror_src):
                        try:
                            os.remove(mirror_src)
                        except OSError:
                            pass
        except Exception as e:
            # Checkpointing is best-effort; the search itself continues.
            if verbose or checkpoint_verbose:
                print("  Warning: failed to save checkpoint:", e)

        if mean_score > best_score:
            best_score = mean_score
            best_params = params
            try:
                best_clone = clone(estimator)
                if ctor_params:
                    try:
                        best_clone.set_params(**ctor_params)
                    except Exception:
                        pass
                _fit_with_optional_kwargs(best_clone, X_train, y_train, fit_kwargs)
                best_est = best_clone
                if verbose:
                    print("  -> New best found and refit on full tuning set.")
            except Exception as e:
                if verbose:
                    print("  -> Refit of best failed:", e)
    return best_est, best_params, history

# ----------------- evaluation helpers (same) -----------------
def evaluate_classification(y_true, y_pred_label, y_pred_prob):
    """Compute binary-classification metrics, tolerating per-metric failures.

    Args:
        y_true: True binary labels.
        y_pred_label: Predicted hard labels.
        y_pred_prob: Predicted scores/probabilities for the positive class.

    Returns:
        Dict with ``precision``, ``recall``, ``f1``, ``average_precision``, ``roc_auc`` and
        ``brier`` (floats, NaN when a metric cannot be computed) and the confusion-matrix
        counts ``tn``, ``fp``, ``fn``, ``tp`` (ints, or None when unavailable).
    """
    y_true = np.asarray(y_true)
    y_pred_label = np.asarray(y_pred_label)
    y_pred_prob = np.asarray(y_pred_prob)
    try:
        prec = float(precision_score(y_true, y_pred_label, zero_division=0))
        rec = float(recall_score(y_true, y_pred_label, zero_division=0))
        f1 = float(f1_score(y_true, y_pred_label, zero_division=0))
    except Exception:
        prec = rec = f1 = float("nan")
    try:
        ap = float(average_precision_score(y_true, y_pred_prob))
    except Exception:
        ap = float("nan")
    try:
        roc = float(roc_auc_score(y_true, y_pred_prob))
    except Exception:
        roc = float("nan")
    try:
        brier = float(brier_score_loss(y_true, y_pred_prob))
    except Exception:
        brier = float("nan")
    try:
        # Fix the label set so the matrix is always 2x2; without it, data containing a
        # single class yields a 1x1 matrix and the four-way unpack fails.
        tn, fp, fn, tp = confusion_matrix(y_true, y_pred_label, labels=[0, 1]).ravel()
        tn, fp, fn, tp = int(tn), int(fp), int(fn), int(tp)
    except Exception:
        tn = fp = fn = tp = None
    return {
        "precision": prec, "recall": rec, "f1": f1,
        "average_precision": ap, "roc_auc": roc, "brier": brier,
        "tn": tn, "fp": fp, "fn": fn, "tp": tp
    }

# ----------------- Bayes wrapper (sketch, similar to earlier) -----------------
class BayesLogisticClassifier:
    """Bayesian logistic regression fitted with ADVI, with a minimal estimator interface.

    ``fit`` standardizes the features, builds a PyMC model with Normal priors on the intercept
    and coefficients, runs ``pm.fit(method="advi")`` and stores posterior draws.
    ``predict_proba`` averages the sigmoid of the linear predictor over those draws.

    Args:
        prior_scale: Standard deviation of the Normal priors.
        advi_iters: Number of ADVI iterations.
        n_draws: Number of posterior draws taken from the approximation.
        seed: Random seed passed to ``pm.fit``.
    """

    def __init__(self, prior_scale: float = 1.0, advi_iters: int = 10000, n_draws: int = 400, seed: int = RANDOM_STATE):
        self.prior_scale = float(prior_scale)
        self.advi_iters = int(advi_iters)
        self.n_draws = int(n_draws)
        self.seed = int(seed)
        self.posterior_coef_ = None       # (n_samples, n_features) after fit
        self.posterior_intercept_ = None  # (n_samples,) after fit
        self.feature_names_in_ = None
        self.is_fitted_ = False

    def get_params(self, deep=True):
        """Return the constructor parameters as a dict (``deep`` is accepted but unused)."""
        return {"prior_scale": self.prior_scale, "advi_iters": self.advi_iters, "n_draws": self.n_draws, "seed": self.seed}

    def set_params(self, **params):
        """Set existing attributes from ``params``; unknown names are ignored. Returns self."""
        for k, v in params.items():
            if hasattr(self, k):
                setattr(self, k, v)
        return self

    @staticmethod
    def _samples_first(data_array):
        """Merge (chain, draw) into one sample axis and return values with that axis first."""
        stacked = data_array.stack(draws=("chain", "draw"))
        # xarray appends the stacked dimension LAST, so a coefficient array comes out as
        # (n_features, n_samples). Move the sample axis to the front before any reshape;
        # reshaping directly would mix values of different coefficients across draws.
        return np.moveaxis(np.asarray(stacked.values), -1, 0)

    def fit(self, X, y, feature_names=None, verbose=False):
        """Fit the model on ``X`` (n_samples, n_features) and binary labels ``y``.

        Args:
            X: 2-D feature array.
            y: Binary labels; flattened and cast to int.
            feature_names: Optional names; defaults to ``f0..f{n-1}``.
            verbose: Show the ADVI progress bar.

        Returns:
            self

        Raises:
            RuntimeError: PyMC is unavailable, or posterior samples cannot be extracted.
        """
        if not PM_AVAILABLE:
            raise RuntimeError("PyMC3 not available; cannot fit Bayesian model.")
        X = np.asarray(X, dtype=float)
        y = np.asarray(y, dtype=int).ravel()
        n_samples, n_features = X.shape
        if feature_names is None:
            feature_names = [f"f{i}" for i in range(n_features)]
        self.feature_names_in_ = list(feature_names)
        # standardize; constant columns get std 1.0 to avoid division by zero
        self._x_mean = X.mean(axis=0)
        self._x_std = X.std(axis=0)
        self._x_std[self._x_std == 0.0] = 1.0
        Xs = (X - self._x_mean) / self._x_std
        with pm.Model() as model:
            # NOTE(review): "sigma" is not referenced by the likelihood below; it is kept so
            # the fitted model is unchanged. Confirm whether it can be dropped.
            sigma = pm.HalfNormal("sigma", sigma=self.prior_scale)
            intercept = pm.Normal("intercept", mu=0.0, sigma=self.prior_scale)
            coeffs = pm.Normal("coeffs", mu=0.0, sigma=self.prior_scale, shape=n_features)
            logits = intercept + pm.math.dot(Xs, coeffs)
            y_obs = pm.Bernoulli("y_obs", logit_p=logits, observed=y)
            approx = pm.fit(n=self.advi_iters, method="advi", random_seed=self.seed, progressbar=verbose)
            post_samples = approx.sample(draws=self.n_draws)
        try:
            if hasattr(post_samples, "posterior"):
                ds = post_samples.posterior
                coef_arr = self._samples_first(ds["coeffs"]).reshape(-1, n_features)
                intercept_arr = self._samples_first(ds["intercept"]).reshape(-1)
            elif isinstance(post_samples, dict):
                coef_arr = np.asarray(post_samples["coeffs"])
                intercept_arr = np.asarray(post_samples["intercept"]).reshape(-1)
            else:
                idata = az.from_pymc3(trace=post_samples)
                ds = idata.posterior
                coef_arr = self._samples_first(ds["coeffs"]).reshape(-1, n_features)
                intercept_arr = self._samples_first(ds["intercept"]).reshape(-1)
        except Exception as e:
            raise RuntimeError("Unable to extract posterior samples: " + str(e))
        self.posterior_coef_ = coef_arr
        self.posterior_intercept_ = intercept_arr
        self.is_fitted_ = True
        return self

    def predict_proba(self, X):
        """Return an (n_samples, 2) array of ``[P(y=0), P(y=1)]`` averaged over posterior draws.

        Raises:
            RuntimeError: if called before ``fit``.
        """
        if not getattr(self, "is_fitted_", False):
            raise RuntimeError("Estimator not fitted.")
        X = np.asarray(X, dtype=float)
        Xs = (X - self._x_mean) / self._x_std
        # One column of logits per posterior draw.
        logits = Xs.dot(self.posterior_coef_.T) + self.posterior_intercept_.reshape(1, -1)
        probs = 1.0 / (1.0 + np.exp(-logits))
        p_mean = probs.mean(axis=1)
        return np.vstack([1.0 - p_mean, p_mean]).T

    def predict(self, X, threshold=0.5):
        """Return 0/1 labels: 1 where the positive-class probability is >= ``threshold``."""
        probs = self.predict_proba(X)[:, 1]
        return (probs >= threshold).astype(int)

# ----------------- main -----------------
def main(args):
    """Tune, evaluate and persist a binary classifier from a processed NPZ dataset.

    Steps: load arrays from ``args.npz``; take binary targets from the NPZ or, failing that,
    from the ``y_bin`` column of the samples CSVs in ``args.samples_dir``; optionally
    standard-scale the features when the sklearn estimator is used; run
    ``checkpointed_random_search``; evaluate the selected estimator on the validation and test
    arrays at a 0.5 threshold; write the results JSON, the model joblib and the final search
    checkpoint to ``args.out_dir``.

    With ``--use_predefined_split`` the search is run on train+validation stacked together
    (validation rows form the single test fold), and the winning configuration is then refit on
    the training rows only, so the reported validation metrics stay out-of-sample.

    Args:
        args: Parsed CLI namespace (see the argument parser in the ``__main__`` section).

    Raises:
        RuntimeError: the NPZ lacks ``X_train``, no ``y_bin`` targets are available, or no
            candidate produced a usable estimator.
    """
    requester_pays = getattr(args, "requester_pays_project", None)
    use_gcsfs = getattr(args, "use_gcsfs", False)
    # The Bayesian estimator is used only when PyMC is importable AND it was requested; every
    # decision below (messages, scaling, estimator choice) keys off this single flag.
    use_bayes = bool(PM_AVAILABLE and args.use_bayesian)

    local_tmp = tempfile.mkdtemp(prefix="bayes_out_")
    try:
        if use_bayes:
            print("PyMC3 available: Bayesian path enabled.")
        elif PM_AVAILABLE:
            print("PyMC3 available but --use_bayesian not set: using sklearn LogisticRegression.")
        else:
            print("PyMC3 not available: falling back to sklearn LogisticRegression.")

        print("Loading processed arrays from:", args.npz)
        data = load_npz(args.npz, requester_pays_project=requester_pays, use_gcsfs=use_gcsfs)

        if 'X_train' not in data:
            raise RuntimeError("processed_dataset must contain X_train/X_val/X_test arrays.")

        X_train = data['X_train']
        X_val = data['X_val']
        X_test = data['X_test']

        # Load binary targets
        if 'y_train_bin' in data and 'y_val_bin' in data and 'y_test_bin' in data:
            y_train_bin = np.asarray(data['y_train_bin'])
            y_val_bin = np.asarray(data['y_val_bin'])
            y_test_bin = np.asarray(data['y_test_bin'])
        else:
            print("Binary targets not found in NPZ — attempting to load from samples CSVs.")
            train_df, val_df, test_df = load_samples_csvs(args.samples_dir, requester_pays_project=requester_pays, use_gcsfs=use_gcsfs)
            if 'y_bin' not in train_df.columns:
                raise RuntimeError("y_bin column not found in samples CSVs nor NPZ.")
            y_train_bin = train_df['y_bin'].astype(int).values
            y_val_bin = val_df['y_bin'].astype(int).values
            y_test_bin = test_df['y_bin'].astype(int).values

        print(f"Shapes: X_train={X_train.shape}, X_val={X_val.shape}, X_test={X_test.shape}")

        # optionally scale features whenever the sklearn estimator is used (recommended)
        scaler = None
        if not use_bayes and args.scale_features:
            scaler = StandardScaler()
            scaler.fit(X_train)
            X_train = scaler.transform(X_train)
            X_val = scaler.transform(X_val)
            X_test = scaler.transform(X_test)
            # Save scaler locally then upload to out_dir
            tmp_scaler = tempfile.NamedTemporaryFile(delete=False, suffix=".joblib")
            tmp_scaler.close()
            joblib.dump(scaler, tmp_scaler.name)
            dest_name = f"standard_scaler_bayes_fallback_{args.tag}_{timestamp_str()}.joblib"
            save_local_or_gcs(tmp_scaler.name, args.out_dir, dest_basename=dest_name, requester_pays_project=requester_pays)

        if args.max_train_samples:
            X_train_sub = X_train[:args.max_train_samples]
            y_train_sub = y_train_bin[:args.max_train_samples]
            print("Using subsample for tuning:", X_train_sub.shape)
        else:
            X_train_sub, y_train_sub = X_train, y_train_bin

        # CV
        if args.use_predefined_split:
            # PredefinedSplit yields indices into the stacked train+val arrays, so the search
            # must receive those stacked arrays (not the training arrays alone).
            X_tune = np.vstack([X_train_sub, X_val])
            y_tune = np.concatenate([y_train_sub, y_val_bin[:len(X_val)]])
            test_fold = np.array([-1] * len(X_tune))
            # rows 0..len(X_train_sub)-1 are always-train (-1), the rest form test fold 0
            test_fold[len(X_train_sub):] = 0
            cv = PredefinedSplit(test_fold)
        else:
            X_tune, y_tune = X_train_sub, y_train_sub
            cv = TimeSeriesSplit(n_splits=args.n_splits)

        # Build estimator and parameter distribution
        if use_bayes:
            base_est = BayesLogisticClassifier(prior_scale=args.prior_scale, advi_iters=args.advi_iters, n_draws=args.n_draws, seed=RANDOM_STATE)
            param_dist = {
                "prior_scale": [0.5, 1.0, 2.0]
            }
        else:
            # sklearn fallback
            base_est = LogisticRegression(solver=args.solver, max_iter=args.max_iter, tol=args.tol, class_weight=(None if args.class_weight == "none" else args.class_weight), random_state=RANDOM_STATE)
            param_dist = {
                "C": [1e-3, 1e-2, 1e-1, 1.0, 10.0, 100.0],
                "penalty": ["l2"] if args.solver in ("saga", "lbfgs", "newton-cg") else ["l1", "l2"]
            }

        # Run randomized search with checkpointing; checkpoint saved locally and mirrored to out_dir
        checkpoint_name = f"bayes_search_checkpoint_{args.tag}_{timestamp_str()}.json"
        checkpoint_local_path = os.path.join(tempfile.gettempdir(), checkpoint_name)
        best_est, best_params, history = checkpointed_random_search(
            base_est, param_dist, X_tune, y_tune, cv,
            n_iter=args.n_iter, random_state=RANDOM_STATE, checkpoint_path=checkpoint_local_path,
            scorer=None, verbose=args.verbose, checkpoint_verbose=args.verbose,
            run_tag=args.tag, out_dir=args.out_dir, requester_pays_project=requester_pays
        )

        if best_est is None:
            raise RuntimeError("No estimator succeeded during tuning.")

        if args.use_predefined_split:
            # The search refits its best candidate on train+val; refit on the training rows
            # only so the validation metrics below are not computed on fitted data.
            final_est = clone(base_est)
            final_est.set_params(**best_params)
            final_est.fit(X_train_sub, y_train_sub)
            best_est = final_est

        # Evaluate best on validation and test sets
        y_val_prob = best_est.predict_proba(X_val)[:, 1] if hasattr(best_est, "predict_proba") else best_est.decision_function(X_val)
        y_val_pred = (y_val_prob >= 0.5).astype(int)
        val_metrics = evaluate_classification(y_val_bin, y_val_pred, y_val_prob)

        y_test_prob = best_est.predict_proba(X_test)[:, 1] if hasattr(best_est, "predict_proba") else best_est.decision_function(X_test)
        y_test_pred = (y_test_prob >= 0.5).astype(int)
        test_metrics = evaluate_classification(y_test_bin, y_test_pred, y_test_prob)

        results = {
            "tag": args.tag,
            "timestamp": timestamp_str(),
            "best_params": best_params,
            "val_metrics": val_metrics,
            "test_metrics": test_metrics,
            "history_len": len(history)
        }

        # Save results JSON locally then upload/move to out_dir
        tmp_results = tempfile.NamedTemporaryFile(delete=False, suffix=".json")
        tmp_results.close()
        _save_json_atomic(tmp_results.name, results)
        res_name = f"bayes_results_{args.tag}_{timestamp_str()}.json"
        save_local_or_gcs(tmp_results.name, args.out_dir, dest_basename=res_name, requester_pays_project=requester_pays)

        # Save the final estimator (joblib for sklearn or joblib wrapper for Bayes class)
        tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".joblib")
        tmp_model.close()
        try:
            joblib.dump(best_est, tmp_model.name)
            model_name = f"bayes_model_{args.tag}_{timestamp_str()}.joblib"
            save_local_or_gcs(tmp_model.name, args.out_dir, dest_basename=model_name, requester_pays_project=requester_pays)
        except Exception as e:
            print("Warning: failed to joblib.dump model:", e)
            try:
                os.remove(tmp_model.name)
            except Exception:
                pass

        # Mirror the local checkpoint if exists (we already attempted mirroring during search,
        # but ensure final checkpoint is saved to out_dir)
        if os.path.exists(checkpoint_local_path):
            try:
                dest_name = f"bayes_search_checkpoint_{args.tag}_{timestamp_str()}.json"
                save_local_or_gcs(checkpoint_local_path, args.out_dir, dest_basename=dest_name, requester_pays_project=requester_pays)
            except Exception as e:
                print("Warning: failed to upload final checkpoint:", e)

        print("Finished. Results and model saved with tag:", args.tag)
    finally:
        # Scratch directory is created up front and always removed; nothing in this function
        # currently writes into it.
        try:
            shutil.rmtree(local_tmp)
        except Exception:
            pass

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train Bayesian/logistic binary classifier (GCS-ready, tagged)")

    # Declarative option table: (flag, add_argument keyword arguments), registered in order.
    _cli_options = [
        ("--npz", dict(type=str, default=DEFAULT_NPZ, help="Path to processed dataset .npz")),
        ("--samples_dir", dict(type=str, default=DEFAULT_SAMPLES_DIR, help="Directory/gs:// path with samples CSVs")),
        ("--out_dir", dict(type=str, default=DEFAULT_OUT_DIR, help="Output directory (local or gs://) where results/models are saved")),
        ("--n_iter", dict(type=int, default=8, help="Randomized search iterations")),
        ("--n_jobs", dict(type=int, default=1, help="Parallel jobs for fit (if used)")),
        ("--max_iter", dict(type=int, default=5000, help="Max iterations for sklearn LogisticRegression")),
        ("--solver", dict(type=str, default="saga", help="Solver for sklearn LogisticRegression")),
        ("--tol", dict(type=float, default=1e-4, help="Tolerance for sklearn LogisticRegression")),
        ("--scale_features", dict(action="store_true", help="Whether to StandardScale features for sklearn fallback")),
        ("--class_weight", dict(type=str, default="none", help="class_weight for sklearn ('balanced' or 'none')")),
        ("--verbose", dict(action="store_true", help="Verbose output")),
        ("--use_predefined_split", dict(action="store_true", help="Use predefined split (train+val arrays) for CV")),
        ("--n_splits", dict(type=int, default=5, help="n_splits for TimeSeriesSplit CV")),
        ("--max_train_samples", dict(type=int, default=0, help="If >0, subsample training rows for tuning")),
        ("--use_gcsfs", dict(action="store_true", help="Force use of gcsfs instead of google-cloud-storage client")),
        ("--requester_pays_project", dict(type=str, default=None, help="Project to use for requester pays buckets")),
        # Bayesian specific options
        ("--use_bayesian", dict(action="store_true", help="Attempt to use PyMC3 Bayesian classifier if available")),
        ("--prior_scale", dict(type=float, default=1.0, help="Prior scale for Bayesian classifier")),
        ("--advi_iters", dict(type=int, default=10000, help="ADVI iterations for Bayesian classifier")),
        ("--n_draws", dict(type=int, default=400, help="Posterior draws for Bayesian classifier")),
        # TAG option
        ("--tag", dict(type=str, default="pig", help="Tag to attach to outputs (default: pig)")),
    ]
    for _flag, _spec in _cli_options:
        parser.add_argument(_flag, **_spec)

    # A real script has __file__ defined and gets strict parsing. In interactive environments
    # (Jupyter/IPython) the kernel adds arguments such as '-f /.../kernel-*.json', so unknown
    # arguments are ignored there instead of raising "unrecognized arguments".
    if "__file__" in globals():
        args = parser.parse_args()
    else:
        args, _ = parser.parse_known_args()
    main(args)

PyMC3 not available: falling back to sklearn LogisticRegression.
Loading processed arrays from: gs://final_data_kgc2/models/Pigeon_binary/processed_dataset_binary.npz
Shapes: X_train=(948281, 39), X_val=(74421, 39), X_test=(495447, 39)




Finished. Results and model saved with tag: pig


In [None]:
#!/usr/bin/env python3
"""
ConvLSTM training script (GCS-ready) with additional JSON metrics written for validation & test,
and optional artifact tagging to avoid overwriting outputs when running multiple lakes.

New behavior:
 - Adds a CLI flag --tag (default "pig"). The tag is appended to artifact filenames:
   e.g. predictions_test_convlstm_pig.csv, test_metrics_pig.json, best_model_pig.keras, history_pig.csv, etc.
 - CustomValMetrics accepts a tag and writes per-epoch val CSV/JSON files suffixed with the tag.
 - Test predictions and metrics are written with the tag so multiple runs (different lakes) don't overwrite each other.

Everything else in the original script (GCS handling, Sequence loader, training callbacks) is preserved.
"""
from __future__ import annotations
import os
import sys
import math
import argparse
import tempfile
import shutil
import hashlib
import io
import json
from glob import glob
from types import SimpleNamespace
from collections import OrderedDict
from pathlib import Path
from functools import lru_cache

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, models, callbacks
from tensorflow.keras.utils import Sequence
from tensorflow.keras import mixed_precision

# optional rasterio for .tif reading
try:
    import rasterio
except Exception:
    rasterio = None

# optional PIL fallback for TIFF/other image formats
try:
    from PIL import Image
except Exception:
    Image = None

# optional gcsfs
try:
    import gcsfs
except Exception:
    gcsfs = None

# preferred storage client
try:
    from google.cloud import storage
except Exception:
    storage = None

# Allow explicit service-account credentials for workers (if provided via env var).
# Runs at import time. DEFAULT_GOOGLE_CREDS is populated only when the
# GOOGLE_APPLICATION_CREDENTIALS environment variable names an existing file that
# loads successfully; in every other case (import of google.oauth2 fails, variable
# unset, file missing, load error) it stays None.
DEFAULT_GOOGLE_CREDS = None
try:
    from google.oauth2 import service_account as _sa
    sa_path = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS")
    if sa_path and os.path.exists(sa_path):
        try:
            DEFAULT_GOOGLE_CREDS = _sa.Credentials.from_service_account_file(sa_path)
            print("Loaded service-account credentials from:", sa_path)
        except Exception as e:
            # Key file present but unusable: report and continue without explicit credentials.
            print("Failed to load GOOGLE_APPLICATION_CREDENTIALS:", e)
            DEFAULT_GOOGLE_CREDS = None
except Exception:
    # google.oauth2 not importable (or another unexpected failure): no explicit credentials.
    DEFAULT_GOOGLE_CREDS = None

# sklearn metrics for aggregated validation evaluation and final metrics
try:
    from sklearn.metrics import (
        average_precision_score,
        roc_auc_score,
        precision_recall_fscore_support,
        precision_recall_curve,
    )
except Exception:
    average_precision_score = None
    roc_auc_score = None
    precision_recall_fscore_support = None
    precision_recall_curve = None

# -------------------- GPU / mixed-precision setup --------------------
# Runs at import time. When at least one GPU is visible, memory growth is requested for
# each device (failures are ignored per device) and the global Keras policy is set to
# mixed_float16. Without GPUs nothing is changed. Any error in this section is reported
# as a warning and does not stop the script.
try:
    gpus = tf.config.list_physical_devices("GPU")
    if gpus:
        for g in gpus:
            try:
                tf.config.experimental.set_memory_growth(g, True)
            except Exception:
                pass
        mixed_precision.set_global_policy("mixed_float16")
        print("Mixed precision enabled (policy = mixed_float16). GPUs detected:", gpus)
    else:
        print("No GPUs detected; running in CPU/fp32 mode.")
except Exception as e:
    print("Warning: GPU initialization failed or not available:", e)

# -------------------- Defaults --------------------
# Default configuration values for this script: input/output locations (gs:// URIs),
# sequence/image dimensions and training hyper-parameters.
# NOTE(review): how each key is consumed is not visible in this section — presumably these
# back the CLI argument defaults; confirm against the argument parser.
DEFAULTS = {
    "processed_npz": "gs://final_data_kgc2/models/Pigeon_binary/processed_dataset_binary.npz",
    "samples_dir": "gs://final_data_kgc2/models/Pigeon_binary/",
    "image_dir": "gs://final_data_kgc2/Final_data/Pigeon_80m/",
    "out_dir": "gs://final_data_kgc2/convlstm_results/",
    "seq_len": 14,
    "img_h": 80,
    "img_w": 80,
    "img_ch": 1,
    "batch_size": 16,
    "epochs": 150,
    "patience": 12,
    "use_attention": True,
    "use_env": True,
    "use_mask_channel": True,
    "threshold": 20.0,
    "lr": 1e-4,
    "clipnorm": 1.0,
    "min_valid_pixels": 0,
    "compute_pos_weight": True,
    "pos_weight": None,
    "pad_value": 0.0
}

# -------------------- Helpers --------------------
def is_gcs_path(p: str) -> bool:
    """Return True only when ``p`` is a string that starts with ``gs://``."""
    if not isinstance(p, str):
        return False
    return p.startswith("gs://")

def gcs_join(prefix: str, *parts: str) -> str:
    """Join ``prefix`` and ``parts`` with single '/' separators.

    Trailing slashes of the accumulated path and leading slashes of each part are
    stripped before joining; each part is converted with ``str``.
    """
    joined = prefix.rstrip('/')
    for piece in parts:
        head = joined.rstrip('/')
        tail = str(piece).lstrip('/')
        joined = f"{head}/{tail}"
    return joined

def _hash_path(s: str) -> str:
    """Return the hexadecimal SHA-1 digest of ``s`` encoded as UTF-8."""
    digest = hashlib.sha1()
    digest.update(s.encode("utf-8"))
    return digest.hexdigest()

def parse_date_from_name(name: str):
    """Extract a normalized timestamp from a file name.

    Looks for ``YYYY-MM-DD`` first, then for eight consecutive digits read as
    ``YYYYMMDD``. Raises ValueError when neither pattern is found.
    """
    import re
    iso_match = re.search(r"(\d{4}-\d{2}-\d{2})", name)
    if iso_match is not None:
        return pd.to_datetime(iso_match.group(1)).normalize()
    compact_match = re.search(r"(\d{8})", name)
    if compact_match is not None:
        return pd.to_datetime(compact_match.group(1), format="%Y%m%d").normalize()
    raise ValueError(f"Cannot parse date from filename: {name}")

# -------------------- Metrics helpers --------------------
def safe_roc_auc(y_true, y_score):
    """Return ROC AUC as a float, or None when it cannot be computed.

    None is returned when the sklearn metrics are unavailable, when ``y_true`` holds
    fewer than two distinct values, or when any exception occurs.
    """
    try:
        metrics_missing = average_precision_score is None or roc_auc_score is None
        if metrics_missing or len(np.unique(y_true)) < 2:
            return None
        return float(roc_auc_score(y_true, y_score))
    except Exception:
        return None

def best_threshold_from_pr_curve(y_true, y_score):
    """Return (best_threshold, best_f1, precision_at_best, recall_at_best).

    The threshold maximizing F1 along the precision-recall curve is selected.
    ``(None, 0.0, None, None)`` is returned when the curve function is unavailable,
    when no thresholds are produced, or on any exception.
    """
    no_result = (None, 0.0, None, None)
    try:
        if precision_recall_curve is None:
            return no_result
        prec_all, rec_all, cutoffs = precision_recall_curve(y_true, y_score)
        if cutoffs.size == 0:
            return no_result
        # The curve has one more precision/recall point than thresholds; drop the last.
        prec, rec = prec_all[:-1], rec_all[:-1]
        total = prec + rec
        f1_by_cutoff = np.zeros_like(prec)
        usable = total > 0
        f1_by_cutoff[usable] = 2 * prec[usable] * rec[usable] / total[usable]
        k = int(np.nanargmax(f1_by_cutoff))
        return float(cutoffs[k]), float(f1_by_cutoff[k]), float(prec[k]), float(rec[k])
    except Exception:
        return no_result

def compute_sample_metrics_from_arrays(y_true, y_score):
    """Compute binary-classification metrics for scores against 0/1 labels.

    Reports precision/recall/F1 at a fixed 0.5 threshold, the F1-optimal
    threshold from the PR curve (falling back to the 0.5 metrics when no
    threshold is available), average precision and ROC AUC.

    Args:
        y_true: array-like of labels; entries equal to 1 count as positive.
        y_score: array-like of scores, same length as *y_true*.

    Returns:
        dict with keys ``n``, ``n_pos``, ``precision_at_0.5``,
        ``recall_at_0.5``, ``f1_at_0.5``, ``best_threshold_by_f1``,
        ``precision_at_best``, ``recall_at_best``, ``f1_at_best``,
        ``average_precision`` and ``roc_auc``.  Any metric that cannot be
        computed is None; for empty input every metric is None and the
        counts are 0.
    """
    def _opt_float(value):
        return None if value is None else float(value)

    # Coerce once: with plain lists `y_score >= 0.5` raises (silently turning
    # the 0.5 metrics into None) and `y_true == 1` is a scalar False (n_pos=0).
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)

    if len(y_true) == 0:
        return {
            "n": 0,
            "n_pos": 0,
            "precision_at_0.5": None,
            "recall_at_0.5": None,
            "f1_at_0.5": None,
            "best_threshold_by_f1": None,
            "precision_at_best": None,
            "recall_at_best": None,
            "f1_at_best": None,
            "average_precision": None,
            "roc_auc": None
        }

    # Metrics at the fixed 0.5 threshold (best-effort).
    try:
        y_pred05 = (y_score >= 0.5).astype(int)
        if precision_recall_fscore_support is not None:
            prec05, rec05, f105, _ = precision_recall_fscore_support(y_true, y_pred05, average="binary", zero_division=0)
        else:
            prec05 = rec05 = f105 = None
    except Exception:
        prec05 = rec05 = f105 = None

    # Average precision (best-effort).
    ap = None
    try:
        if average_precision_score is not None:
            ap = float(average_precision_score(y_true, y_score))
    except Exception:
        ap = None

    roc = safe_roc_auc(y_true, y_score)

    # F1-optimal threshold; fall back to the 0.5 metrics when none is found.
    best_th, best_f1, p_best, r_best = best_threshold_from_pr_curve(y_true, y_score)
    if best_th is None:
        best_th = 0.5
        best_f1, p_best, r_best = f105, prec05, rec05

    return {
        "n": int(len(y_true)),
        "n_pos": int(np.sum(y_true == 1)),
        "precision_at_0.5": _opt_float(prec05),
        "recall_at_0.5": _opt_float(rec05),
        "f1_at_0.5": _opt_float(f105),
        "best_threshold_by_f1": _opt_float(best_th),
        "precision_at_best": _opt_float(p_best),
        "recall_at_best": _opt_float(r_best),
        "f1_at_best": _opt_float(best_f1),
        "average_precision": _opt_float(ap),
        "roc_auc": _opt_float(roc)
    }

# -------------------- GCS helpers (atomic download + retry) --------------------
def _make_storage_client(project: str | None = None):
    """Build a ``storage.Client`` for *project*.

    ``DEFAULT_GOOGLE_CREDS`` is passed as ``credentials`` when it is set;
    otherwise the client is created with the project only.

    Raises:
        RuntimeError: if the ``storage`` module is unavailable (None).
    """
    if storage is None:
        raise RuntimeError("google-cloud-storage not installed")
    client_kwargs = {"project": project}
    if DEFAULT_GOOGLE_CREDS is not None:
        client_kwargs["credentials"] = DEFAULT_GOOGLE_CREDS
    return storage.Client(**client_kwargs)

def _make_gcsfs(token=None, project: str | None = None):
    """Build a ``gcsfs.GCSFileSystem`` for *project*.

    An explicit *token* wins; when it is None and ``DEFAULT_GOOGLE_CREDS``
    is set, those credentials are used as the token instead.

    Raises:
        RuntimeError: if the ``gcsfs`` module is unavailable (None).
    """
    if gcsfs is None:
        raise RuntimeError("gcsfs not installed")
    effective_token = token
    if effective_token is None and DEFAULT_GOOGLE_CREDS is not None:
        # The credentials object is handed to gcsfs as its token.
        effective_token = DEFAULT_GOOGLE_CREDS
    return gcsfs.GCSFileSystem(token=effective_token, project=project)

def upload_file_to_gcs(local_path: str, gs_uri: str, requester_pays_project: str | None = None):
    """Upload one local file to ``gs://bucket/object``.

    The storage client is preferred; gcsfs is used only when the storage
    module is unavailable.

    Args:
        local_path: file to upload.
        gs_uri: destination, must name both a bucket and an object.
        requester_pays_project: project passed to whichever client is built.

    Raises:
        ValueError: *gs_uri* is not a gs:// path or lacks a bucket/object name.
        FileNotFoundError: the destination bucket lookup returned nothing.
        RuntimeError: neither storage nor gcsfs is available.
    """
    if not is_gcs_path(gs_uri):
        raise ValueError("gs_uri must be gs:// path")
    _, rest = gs_uri.split("gs://", 1)
    bucket_name, _, blob_path = rest.partition("/")
    # Fail early with a clear message instead of attempting an upload to an
    # empty object name.
    if not bucket_name or not blob_path:
        raise ValueError(f"gs_uri must include a bucket and an object name: {gs_uri}")

    if storage is not None:
        client = _make_storage_client(project=requester_pays_project)
        bucket = client.lookup_bucket(bucket_name)
        if bucket is None:
            raise FileNotFoundError(f"Bucket not found: gs://{bucket_name}")
        bucket.blob(blob_path).upload_from_filename(local_path)
        return
    if gcsfs is not None:
        # Pass the project through, matching the other gcsfs call sites.
        fs = _make_gcsfs(project=requester_pays_project)
        fs.put(local_path, gs_uri)
        return
    raise RuntimeError("Install google-cloud-storage or gcsfs to upload to GCS")

def upload_dir_to_gcs(local_dir: str, gs_prefix: str, requester_pays_project: str | None = None):
    """Recursively upload every file under *local_dir* beneath *gs_prefix*.

    Each file's path relative to *local_dir* becomes its object name suffix.
    """
    for root, _dirs, files in os.walk(local_dir):
        for fn in files:
            local_fp = os.path.join(root, fn)
            rel = os.path.relpath(local_fp, local_dir)
            # Object names use "/" regardless of the local OS separator;
            # this is a no-op where os.sep is already "/".
            rel = rel.replace(os.sep, "/")
            gs_dst = gcs_join(gs_prefix, rel)
            upload_file_to_gcs(local_fp, gs_dst, requester_pays_project=requester_pays_project)

def _remove_quietly(path: str) -> None:
    """Best-effort delete of *path*; never raises."""
    try:
        if os.path.exists(path):
            os.remove(path)
    except OSError:
        pass


def fetch_gcs_to_local(gcs_path: str, local_path: str, requester_pays_project: str | None = None, use_gcsfs: bool = False):
    """Download a gs:// object to *local_path* atomically and return *local_path*.

    The storage client is used unless it is unavailable or *use_gcsfs* is
    true, in which case gcsfs is used.  Data is written to a uniquely named
    temporary file next to *local_path* and moved into place with
    ``os.replace``, so readers never observe a partial file and concurrent
    downloads of the same object cannot overwrite each other's temp file.

    Raises:
        ValueError: *gcs_path* is not a gs:// path.
        FileNotFoundError: bucket or blob does not exist (storage client path).
        RuntimeError: no usable client, a failed sanity check, or an empty
            download.
    """
    if not is_gcs_path(gcs_path):
        raise ValueError("gcs_path must be gs://...")
    _, rest = gcs_path.split("gs://", 1)
    bucket_name, _, blob_path = rest.partition("/")

    if storage is not None and not use_gcsfs:
        # Preferred path: google-cloud-storage client.
        client = _make_storage_client(project=requester_pays_project)
        bucket = client.lookup_bucket(bucket_name)
        if bucket is None:
            raise FileNotFoundError(f"Bucket not found: gs://{bucket_name} (check credentials/project)")
        blob = bucket.blob(blob_path)
        if not blob.exists():
            raise FileNotFoundError(f"Blob not found: {gcs_path}")
        # Quick sanity checks.
        # NOTE(review): the blob handle is built locally via bucket.blob(); if
        # its metadata has not been loaded at this point, size/content_type
        # are None and both checks are skipped — confirm against the client
        # library docs (bucket.get_blob() may be what was intended).
        if blob.size is not None and blob.size < 16:
            raise RuntimeError(f"Blob seems unexpectedly small ({blob.size} bytes). Possible permission/404 response.")
        if blob.content_type and blob.content_type.startswith("text/"):
            raise RuntimeError(f"Blob content_type is '{blob.content_type}' — download may be an HTML error page.")

        def _download(dst_path):
            blob.download_to_filename(dst_path)

        empty_msg = f"Downloaded file for {gcs_path} is zero bytes or missing."
    else:
        # Fallback: gcsfs.
        if gcsfs is None:
            raise RuntimeError("gcsfs not installed and storage client not available")
        fs = _make_gcsfs(project=requester_pays_project)

        def _download(dst_path):
            with fs.open(gcs_path, "rb") as src, open(dst_path, "wb") as dst:
                shutil.copyfileobj(src, dst)

        empty_msg = f"gcsfs download produced zero-length file for {gcs_path}"

    # A bare filename has an empty dirname; os.makedirs("") would raise.
    parent = os.path.dirname(local_path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    # Unique temp file in the destination directory: os.replace stays atomic
    # (same filesystem) and parallel workers do not share a temp name.
    fd, tmp_path = tempfile.mkstemp(
        prefix=os.path.basename(local_path) + ".", suffix=".tmp", dir=parent or "."
    )
    os.close(fd)
    try:
        _download(tmp_path)
        if not os.path.exists(tmp_path) or os.path.getsize(tmp_path) == 0:
            raise RuntimeError(empty_msg)
        os.replace(tmp_path, local_path)
    except Exception:
        # Remove the partial temp file, then surface the original error.
        _remove_quietly(tmp_path)
        raise
    return local_path

# -------------------- Listing image files --------------------
def list_image_files(image_dir: str, requester_pays_project: str | None = None, use_gcsfs: bool = False):
    """Map image dates to file locations found under *image_dir*.

    Only names ending in ``.npy``, ``.tif`` or ``.tiff`` (case-insensitive)
    whose basename contains a parseable date are kept; everything else is
    skipped silently.  When two files share a date, the one seen last wins.

    Args:
        image_dir: local directory or ``gs://bucket/prefix``.
        requester_pays_project: project passed to the GCS client that is built.
        use_gcsfs: use gcsfs even when the storage client is available.

    Returns:
        OrderedDict of ``pd.Timestamp -> path`` (gs:// URIs for GCS input).
    """
    image_exts = (".npy", ".tif", ".tiff")
    file_map = OrderedDict()

    def _register(basename, location):
        """Record *location* under the date parsed from *basename*, if any."""
        if not basename.lower().endswith(image_exts):
            return
        try:
            d = parse_date_from_name(basename)
        except Exception:
            return
        file_map[pd.Timestamp(d)] = location

    # Local directory: direct children only.
    if not is_gcs_path(image_dir):
        for f in sorted(glob(os.path.join(image_dir, "*"))):
            _register(os.path.basename(f), f)
        return file_map

    if storage is not None and not use_gcsfs:
        client = _make_storage_client(project=requester_pays_project)
        _, rest = image_dir.split("gs://", 1)
        bucket_name, _, prefix = rest.partition("/")
        # Normalize to "dir/"; a bucket-root image_dir must give an empty
        # prefix rather than "/", which would match no object.
        prefix = prefix.rstrip('/')
        if prefix:
            prefix += '/'
        for b in client.list_blobs(bucket_name, prefix=prefix):
            _register(os.path.basename(b.name), f"gs://{bucket_name}/{b.name}")
        return file_map

    # gcsfs fallback.
    fs = _make_gcsfs(project=requester_pays_project)
    pattern = image_dir.rstrip('/') + '/*'
    for f in sorted(fs.glob(pattern)):
        # Results that lack the scheme get it restored.
        _register(os.path.basename(f), f if f.startswith("gs://") else "gs://" + f)
    return file_map

# -------------------- Local cache-aware loader --------------------
CACHE_DIR = None
def _ensure_cache_dir():
    global CACHE_DIR
    if CACHE_DIR is None:
        CACHE_DIR = tempfile.mkdtemp(prefix="image_cache_")
    return CACHE_DIR

@lru_cache(maxsize=2048)
def cached_load(path: str, target_h: int | None = None, target_w: int | None = None, requester_pays_project: str | None = None, use_gcsfs: bool = False, pad_value: float = 0.0):
    """Load an image (local path or gs:// URI) as a float32 numpy array.

    GCS objects are downloaded once into the cache directory under a name
    derived from the URI hash; a failed download is retried once with the
    other client (storage <-> gcsfs).  Readers are tried in order: ``.npy``
    fast path, rasterio (band 1, nodata -> NaN), ``np.load`` for mis-named
    .npy files, then PIL.  When *target_h* and *target_w* are both given, a
    2-D result whose shape differs is resized with ``tf.image.resize``.

    Results are memoized by ``lru_cache``: the returned array object is
    shared between callers with the same arguments, so treat it as read-only.

    Args:
        path: local file path or gs:// URI.
        target_h, target_w: desired output size; both None disables resizing.
        requester_pays_project: project passed to the GCS client.
        use_gcsfs: try gcsfs first instead of the storage client.
        pad_value: not used by the loader; it only takes part in the cache key.

    Raises:
        RuntimeError: download failed twice, the file is empty, or no reader
            could decode it (the message carries size/header diagnostics).
    """
    def _fit(arr2d):
        """Resize a 2-D array to (target_h, target_w) when requested; return float32."""
        if target_h is not None and target_w is not None and arr2d.shape != (target_h, target_w):
            arr2d = np.array(tf.image.resize(arr2d[..., np.newaxis], (target_h, target_w)))[..., 0]
        return arr2d.astype("float32")

    cache_dir = _ensure_cache_dir()
    if is_gcs_path(path):
        h = _hash_path(path)
        ext = os.path.splitext(path)[1] or ".npy"
        local_fp = os.path.join(cache_dir, f"{h}{ext}")
        if not os.path.exists(local_fp):
            # First attempt with the requested client.
            try:
                fetch_gcs_to_local(path, local_fp, requester_pays_project=requester_pays_project, use_gcsfs=use_gcsfs)
            except Exception as e:
                # Drop any partial file and retry once with the other client.
                try:
                    if os.path.exists(local_fp):
                        os.remove(local_fp)
                except Exception:
                    pass
                try:
                    fetch_gcs_to_local(path, local_fp, requester_pays_project=requester_pays_project, use_gcsfs=not use_gcsfs)
                except Exception as e2:
                    raise RuntimeError(f"Failed to download {path}: {e} ; retry also failed: {e2}") from e2
        # Ensure the local file is non-empty.
        if not os.path.exists(local_fp) or os.path.getsize(local_fp) == 0:
            try:
                if os.path.exists(local_fp):
                    os.remove(local_fp)
            except Exception:
                pass
            raise RuntimeError(f"Downloaded file for {path} is missing or zero bytes; check blob/permissions.")
    else:
        local_fp = path

    # Fast path: .npy files.
    if local_fp.lower().endswith(".npy"):
        try:
            arr = np.load(local_fp).astype("float32")
        except Exception as e:
            raise RuntimeError(f"Failed to np.load file {local_fp}: {e}") from e
        # Honor target_h/target_w like the other readers do.  Arrays that are
        # not 2-D are returned untouched (callers pick the first channel).
        return _fit(arr) if arr.ndim == 2 else arr

    # Try rasterio first; remember its error and fall back on failure.
    rasterio_err = None
    if rasterio is not None:
        try:
            with rasterio.open(local_fp) as src:
                arr = src.read(1).astype("float32")
                if src.nodata is not None:
                    arr[arr == src.nodata] = np.nan
                return _fit(arr)
        except Exception as e:
            rasterio_err = e

    # Try numpy.load in case the file is a .npy saved without the extension.
    try:
        arr_try = np.load(local_fp)
        if isinstance(arr_try, np.ndarray):
            if arr_try.ndim == 3:
                arr_try = arr_try[..., 0]
            arr = _fit(arr_try.astype("float32"))
            print(f"Warning: rasterio failed to open {local_fp}; loaded with numpy.load instead.")
            return arr
    except Exception:
        pass

    # Try PIL as the last reader.
    if Image is not None:
        try:
            with Image.open(local_fp) as im:
                im_arr = np.array(im)
                if im_arr.ndim == 3:
                    im_arr = im_arr[..., 0]
                arr = _fit(im_arr.astype("float32"))
                print(f"Warning: rasterio failed to open {local_fp}; loaded with PIL.Image instead.")
                return arr
        except Exception:
            pass

    # Every reader failed: gather diagnostics for the error message.
    header = b""
    try:
        with open(local_fp, "rb") as fh:
            header = fh.read(512)
    except Exception:
        pass
    size = os.path.getsize(local_fp) if os.path.exists(local_fp) else None
    hint = ""
    try:
        htxt = header.decode("utf-8", errors="ignore").lower()
        if "<html" in htxt or "access denied" in htxt or "error" in htxt:
            hint = "Downloaded file looks like an HTML error page (AccessDenied/404) — check GCS path and credentials or requester-pays."
    except Exception:
        pass
    err_msg = f"Failed to read image file {local_fp} with rasterio/numpy/PIL. "
    if rasterio_err is not None:
        err_msg += f"rasterio error: {rasterio_err}. "
    err_msg += f"File size: {size}. Header sample (utf8): {repr(header[:200])}. {hint}"
    raise RuntimeError(err_msg)

# -------------------- Sequence generator (threshold-before-resize) --------------------
class ConvLSTMSequence(Sequence):
    """Batch generator yielding ``(X, y, sample_weight)`` for per-pixel binary targets.

    For every sample row, the history is each available image date inside
    ``[date_t - (seq_len - 1) days, date_t]``; missing leading frames are
    filled with ``pad_value``.  The target image is loaded at its native
    resolution, converted to a 0/1 mask (non-zero pixels when
    ``image_binary`` is true, otherwise pixels ``>= threshold_bloom``) and
    only then resized with nearest-neighbour, so labels are never
    interpolated.

    Batch shapes:
        X: ``(B, seq_len, img_h, img_w, img_ch + env_dims + mask_channel)``
        y and sample_weight: ``(B, img_h, img_w, 1)``

    ``params`` must expose ``seq_len``, ``img_h``, ``img_w``, ``img_ch``,
    ``use_env``, ``batch_size`` and ``threshold_bloom``; ``min_valid_pixels``
    (default 0) and ``pos_weight`` (default 1.0) are optional.

    NOTE(review): each history frame contributes a single image channel, so
    ``img_ch`` values other than 1 are presumably unsupported — confirm.
    """

    def __init__(self, samples_df, X_features, image_map, params, use_mask_channel=False, requester_pays_project: str | None = None, use_gcsfs: bool = False, pad_value: float = 0.0, image_binary: bool = False, **kwargs):
        """Index the usable samples.

        A row is skipped when no image falls in its history window, when its
        target date has no image, or when the target has fewer than
        ``min_valid_pixels`` non-NaN pixels (a target that fails to load
        counts as 0 valid pixels).

        Raises:
            ValueError: if no sample survives the filtering.
        """
        super().__init__(**kwargs)
        self.df = samples_df.reset_index(drop=True)
        self.X_feat = X_features
        self.image_map = image_map
        self.dates_available = sorted(list(image_map.keys()))
        self.time_steps = params.seq_len
        self.img_h = params.img_h
        self.img_w = params.img_w
        self.channels = params.img_ch
        self.use_env = params.use_env
        self.bs = params.batch_size
        self.threshold_bloom = params.threshold_bloom
        self.min_valid_pixels = getattr(params, "min_valid_pixels", 0)
        self.pos_weight = getattr(params, "pos_weight", 1.0)
        self.use_mask_channel = use_mask_channel
        self.req_project = requester_pays_project
        self.use_gcsfs = use_gcsfs
        self.pad_value = pad_value
        self.image_binary = bool(image_binary)

        self.entries = []
        skipped = 0
        for i, row in self.df.iterrows():
            date_t = pd.to_datetime(row["date_t"]).normalize()
            history_start = date_t - pd.Timedelta(days=self.time_steps-1)
            history_dates = [d for d in self.dates_available if history_start <= d <= date_t]
            if not history_dates:
                skipped += 1
                continue
            date_target = pd.to_datetime(row["date_target"]).normalize()
            if date_target not in image_map:
                skipped += 1
                continue
            tgt_path = image_map[date_target]
            # Count valid target pixels; a load failure counts as zero.
            try:
                tgt = cached_load(tgt_path, self.img_h, self.img_w, requester_pays_project=self.req_project, use_gcsfs=self.use_gcsfs)
                if tgt.ndim == 3:
                    tgt = tgt[..., 0]
                n_valid = int(np.sum(~np.isnan(tgt)))
            except Exception:
                n_valid = 0
            if n_valid < self.min_valid_pixels:
                skipped += 1
                continue
            sample_y_bin = None
            sample_y_raw = None
            if "y_bin" in row.index:
                try:
                    sample_y_bin = int(row["y_bin"])
                except Exception:
                    sample_y_bin = None
            if "y_raw" in row.index:
                sample_y_raw = row["y_raw"]
            self.entries.append({
                "row_idx": i,
                "history_dates": history_dates,
                "target_path": tgt_path,
                "date_t": date_t,
                "date_target": date_target,
                "sample_y_bin": sample_y_bin,
                "sample_y_raw": sample_y_raw
            })
        if len(self.entries) == 0:
            raise ValueError("No valid samples found; check image_dir, sample dates, and min_valid_pixels.")
        print(f"Initialized ConvLSTMSequence with {len(self.entries)} valid entries (skipped {skipped}).")

    def __len__(self):
        """Number of batches (the last one may be smaller)."""
        return math.ceil(len(self.entries)/self.bs)

    def _load_target(self, target_path):
        """Return ``(y_img, valid_mask)`` as float32 ``(img_h, img_w)`` arrays.

        The label mask is computed at native resolution and then resized with
        nearest-neighbour.  An unreadable or all-NaN target yields all-zero
        labels with an all-zero valid mask.
        """
        blank = np.zeros((self.img_h, self.img_w), dtype="float32")
        try:
            tgt_raw = cached_load(target_path, None, None, requester_pays_project=self.req_project, use_gcsfs=self.use_gcsfs)
        except Exception:
            return blank, blank.copy()

        if tgt_raw.ndim == 3:
            tgt_raw = tgt_raw[..., 0]
        valid_mask_raw = ~np.isnan(tgt_raw)
        if valid_mask_raw.sum() == 0:
            return blank, blank.copy()

        tgtf = tgt_raw.astype("float32")
        if self.image_binary:
            # Any non-zero valid pixel is positive (this covers 1 and 255).
            y_mask_raw = valid_mask_raw & (tgtf != 0.0)
        else:
            y_mask_raw = valid_mask_raw & (tgtf >= float(self.threshold_bloom))

        # Resize label and valid masks to model resolution (nearest neighbour).
        import tensorflow as _tf
        y_mask_raw_f = y_mask_raw.astype("float32")[..., np.newaxis]
        valid_mask_raw_f = valid_mask_raw.astype("float32")[..., np.newaxis]
        try:
            y_mask_resized = _tf.image.resize(y_mask_raw_f, (self.img_h, self.img_w), method="nearest").numpy()[..., 0]
            valid_mask_resized = _tf.image.resize(valid_mask_raw_f, (self.img_h, self.img_w), method="nearest").numpy()[..., 0] > 0.5
        except Exception:
            # Pure-numpy nearest-neighbour fallback via index sampling.
            rh, rw = tgt_raw.shape
            ys = (np.linspace(0, rh-1, self.img_h)).round().astype(int)
            xs = (np.linspace(0, rw-1, self.img_w)).round().astype(int)
            y_mask_resized = y_mask_raw[np.ix_(ys, xs)].astype("float32")
            valid_mask_resized = valid_mask_raw[np.ix_(ys, xs)].astype("bool")

        y_img = np.where(y_mask_resized > 0.5, 1.0, 0.0).astype("float32")
        return y_img, valid_mask_resized.astype("float32")

    def __getitem__(self, idx):
        """Build batch *idx* and return ``(Xb, yb, sample_weight)``.

        ``sample_weight`` is 1 on valid target pixels and 0 elsewhere; when
        ``pos_weight`` is set and differs from 1, positive pixels get
        ``pos_weight - 1`` added.
        """
        batch = self.entries[idx*self.bs:(idx+1)*self.bs]
        B = len(batch)
        env_dims = self.X_feat.shape[1] if (self.use_env and self.X_feat is not None and self.X_feat.ndim>1) else 0
        mask_channel = 1 if self.use_mask_channel else 0
        channels_total = self.channels + env_dims + mask_channel

        Xb = np.full((B, self.time_steps, self.img_h, self.img_w, channels_total), fill_value=self.pad_value, dtype="float32")
        yb = np.full((B, self.img_h, self.img_w, 1), np.nan, dtype="float32")
        sample_weight = np.zeros((B, self.img_h, self.img_w, 1), dtype="float32")

        for i, entry in enumerate(batch):
            # --- history images (NaN replaced by pad_value) + validity masks ---
            history_imgs, history_valid = [], []
            for history_date in entry["history_dates"]:
                if history_date in self.image_map:
                    img = cached_load(self.image_map[history_date], self.img_h, self.img_w, requester_pays_project=self.req_project, use_gcsfs=self.use_gcsfs)
                    if img.ndim == 3:
                        img = img[..., 0]
                    history_imgs.append(np.nan_to_num(img, nan=self.pad_value).astype("float32"))
                    history_valid.append((~np.isnan(img)).astype("float32"))

            # Keep only the most recent frames if the window holds more than
            # time_steps images (e.g. several images per day); otherwise the
            # stacked sequence would not fit the batch tensor.
            if len(history_imgs) > self.time_steps:
                history_imgs = history_imgs[-self.time_steps:]
                history_valid = history_valid[-self.time_steps:]

            # Left-pad short histories so the newest frame is always last.
            pad_len = self.time_steps - len(history_imgs)
            if pad_len > 0:
                pad_img = np.full((self.img_h, self.img_w), fill_value=self.pad_value, dtype="float32")
                pad_mask = np.zeros((self.img_h, self.img_w), dtype="float32")
                history_imgs = [pad_img]*pad_len + history_imgs
                history_valid = [pad_mask]*pad_len + history_valid
            seq_stack = np.stack(history_imgs, axis=0)[..., np.newaxis]

            # Environmental features: the row's vector tiled over H, W and time.
            if self.use_env and env_dims > 0:
                env_vec = self.X_feat[entry["row_idx"]]
                env_maps = np.repeat(env_vec[np.newaxis, np.newaxis, :], self.img_h, axis=0)
                env_maps = np.repeat(env_maps, self.img_w, axis=1)
                env_maps = np.repeat(env_maps[np.newaxis, ...], self.time_steps, axis=0)
                seq_stack = np.concatenate([seq_stack, env_maps], axis=-1)

            # Optional per-frame validity mask channel.
            if self.use_mask_channel:
                mask_seq = np.stack(history_valid, axis=0)[..., np.newaxis]
                seq_stack = np.concatenate([seq_stack, mask_seq], axis=-1)

            Xb[i] = seq_stack

            # --- target: threshold at native resolution, then resize ---
            y_img, sw = self._load_target(entry["target_path"])
            yb[i, ..., 0] = y_img
            if self.pos_weight is not None and self.pos_weight != 1.0:
                pos_mask = (y_img == 1.0).astype("float32")
                sw = sw + pos_mask * (self.pos_weight - 1.0)
            sample_weight[i, ..., 0] = sw

        return Xb, yb, sample_weight

    def on_epoch_end(self):
        """No per-epoch work: entry order is left unchanged."""
        return

# -------------------- NPZ and CSV loaders (support gs://) --------------------
def load_npz(npz_path: str, requester_pays_project: str | None = None, use_gcsfs: bool = False) -> dict:
    """Load every array of an ``.npz`` archive (local or gs://) into a dict.

    A gs:// archive is first downloaded to a temporary file, which is removed
    afterwards whether or not loading succeeds.

    Returns:
        dict mapping each archive key to its array.
    """
    def _read(local_npz):
        # SECURITY: allow_pickle=True unpickles object arrays, which can run
        # arbitrary code — only load archives from trusted locations.
        # The context manager closes the archive once the arrays are read.
        with np.load(local_npz, allow_pickle=True) as data:
            return {k: data[k] for k in data.files}

    if not is_gcs_path(npz_path):
        return _read(npz_path)

    tmp = tempfile.NamedTemporaryFile(suffix=".npz", delete=False)
    tmp.close()
    try:
        fetch_gcs_to_local(npz_path, tmp.name, requester_pays_project=requester_pays_project, use_gcsfs=use_gcsfs)
        return _read(tmp.name)
    finally:
        # Best-effort cleanup, also on download/load failure.
        try:
            os.unlink(tmp.name)
        except OSError:
            pass

def _samples_files_exist(prefix: str, requester_pays_project: str | None = None, use_gcsfs: bool = False) -> bool:
    try:
        for name in ("samples_train.csv", "samples_val.csv", "samples_test.csv"):
            if is_gcs_path(prefix):
                if storage is not None and not use_gcsfs:
                    client = _make_storage_client(project=requester_pays_project)
                    _, rest = prefix.split("gs://", 1)
                    bucket_name, _, prefix_path = rest.partition("/")
                    prefix_path = prefix_path.rstrip('/') + '/'
                    blob = client.bucket(bucket_name).blob(prefix_path + name)
                    if not blob.exists():
                        return False
                else:
                    fs = _make_gcsfs(project=requester_pays_project)
                    target = prefix.rstrip('/') + '/' + name
                    if not fs.exists(target):
                        return False
            else:
                if not os.path.exists(os.path.join(prefix, name)):
                    return False
        return True
    except Exception:
        return False

def load_samples_csvs(samples_dir: str, requester_pays_project: str | None = None, use_gcsfs: bool = False):
    """Read the train/val/test sample CSVs from *samples_dir*.

    *samples_dir* may be a local directory or a gs:// location.  The columns
    ``date_t`` and ``date_target`` are parsed as dates.

    Returns:
        tuple ``(train_df, val_df, test_df)``.

    Raises:
        FileNotFoundError: storage-client path only, when the bucket or one
            of the CSV objects does not exist.
    """
    names = ("samples_train.csv", "samples_val.csv", "samples_test.csv")
    date_cols = ['date_t', 'date_target']
    location = str(samples_dir)

    if not is_gcs_path(location):
        return tuple(
            pd.read_csv(str(Path(samples_dir) / name), parse_dates=date_cols)
            for name in names
        )

    if storage is not None and not use_gcsfs:
        # One client and one bucket lookup serve all three files.
        client = _make_storage_client(project=requester_pays_project)
        _, rest = location.split("gs://", 1)
        bucket_name, _, prefix = rest.partition("/")
        # "dir/" for a sub-path, "" for the bucket root (never a bare "/").
        prefix = prefix.rstrip('/')
        if prefix:
            prefix += '/'
        bucket = client.lookup_bucket(bucket_name)
        if bucket is None:
            raise FileNotFoundError(f"Bucket not found: gs://{bucket_name}")
        frames = []
        for name in names:
            blob_path = prefix + name
            blob = bucket.blob(blob_path)
            if not blob.exists():
                raise FileNotFoundError(f"Samples CSV not found: gs://{bucket_name}/{blob_path}")
            frames.append(pd.read_csv(io.BytesIO(blob.download_as_bytes()), parse_dates=date_cols))
        return tuple(frames)

    fs = _make_gcsfs(project=requester_pays_project)
    base = location.rstrip('/')
    frames = []
    for name in names:
        with fs.open(base + '/' + name, "rb") as f:
            frames.append(pd.read_csv(f, parse_dates=date_cols))
    return tuple(frames)

# -------------------- Focal loss --------------------
def focal_loss(gamma=2.0, alpha=0.25):
    """Build a binary focal-loss function.

    The returned ``loss_fn(y_true, y_pred)`` scales element-wise binary
    cross-entropy by ``(1 - p_t) ** gamma`` and by a class-balance factor
    (``alpha`` where ``y_true`` is 1, ``1 - alpha`` where it is 0), then
    averages over all elements.
    """
    def loss_fn(y_true, y_pred):
        targets = tf.cast(y_true, y_pred.dtype)
        non_targets = 1.0 - targets
        cross_entropy = tf.keras.backend.binary_crossentropy(targets, y_pred)
        # p_t: probability the model assigns to the true class.
        prob_true_class = targets * y_pred + non_targets * (1.0 - y_pred)
        focusing = tf.pow(1.0 - prob_true_class, gamma)
        class_balance = targets * alpha + non_targets * (1.0 - alpha)
        return tf.reduce_mean(class_balance * focusing * cross_entropy)
    return loss_fn

# -------------------- Custom callback & ValStats (unchanged except added metrics write and tag) --------------------
class CustomValMetrics(callbacks.Callback):
    def __init__(self, val_seq: Sequence, local_best_path: str, local_work: str, agg_percentile: float = 90.0, requester_pays_project: str | None = None, upload_fn=None, tag: str = ""):
        super().__init__()
        self.val_seq = val_seq
        self.best_score = -np.inf
        self.local_best_path = local_best_path
        self.local_work = local_work
        self.agg_percentile = float(agg_percentile)
        self.requester_pays_project = requester_pays_project
        self.upload_fn = upload_fn
        # tag to append to epoch artifacts (e.g., "pig")
        self.tag = tag if (tag is not None and len(str(tag))>0) else ""

    def _tagged_name(self, base: str):
        if not self.tag:
            return base
        base_noext, ext = os.path.splitext(base)
        return f"{base_noext}_{self.tag}{ext}"

    def on_epoch_end(self, epoch, logs=None):
        pix_trues = []
        pix_probs = []
        sample_trues = []
        sample_probs_mean = []
        sample_probs_p = []
        sample_rows = []

        for b_idx in range(len(self.val_seq)):
            Xb, yb, sw = self.val_seq[b_idx]
            probs = self.model.predict(Xb, verbose=0)
            for j in range(Xb.shape[0]):
                entry_idx = b_idx * self.val_seq.bs + j
                if entry_idx >= len(self.val_seq.entries):
                    continue
                row_idx = self.val_seq.entries[entry_idx]["row_idx"]
                mask = (sw[j, ..., 0] > 0)
                if mask.sum() == 0:
                    continue
                pix_trues.append(yb[j, ..., 0][mask].reshape(-1))
                pix_probs.append(probs[j, ..., 0][mask].reshape(-1))
                vals = probs[j, ..., 0][mask]
                mean_prob = float(np.nanmean(vals))
                try:
                    p_val = float(np.nanpercentile(vals, self.agg_percentile))
                except Exception:
                    p_val = mean_prob
                sample_label = None
                try:
                    sample_label = int(self.val_seq.df.iloc[row_idx]['y_bin'])
                except Exception:
                    try:
                        sample_label = int(self.val_seq.entries[entry_idx].get("sample_y_bin"))
                    except Exception:
                        sample_label = None
                if sample_label is not None:
                    sample_trues.append(sample_label)
                    sample_probs_mean.append(mean_prob)
                    sample_probs_p.append(p_val)
                    sample_rows.append({
                        "row_idx": int(row_idx),
                        "date_t": str(self.val_seq.entries[entry_idx]["date_t"]),
                        "date_target": str(self.val_seq.entries[entry_idx]["date_target"]),
                        "mean_prob": mean_prob,
                        f"p{int(self.agg_percentile)}_prob": p_val,
                        "label": sample_label
                    })

        if len(pix_trues) > 0:
            y_pix = np.concatenate(pix_trues).ravel()
            p_pix = np.concatenate(pix_probs).ravel()
        else:
            y_pix = np.array([])
            p_pix = np.array([])

        pix_ap = pix_roc = None
        if y_pix.size > 0 and average_precision_score is not None:
            try:
                pix_ap = float(average_precision_score(y_pix, p_pix))
            except Exception:
                pix_ap = None
        if y_pix.size > 0 and roc_auc_score is not None:
            try:
                pix_roc = float(roc_auc_score(y_pix, p_pix))
            except Exception:
                pix_roc = None

        samp_ap_mean = samp_roc_mean = samp_ap_p = samp_roc_p = None
        if len(sample_trues) > 0 and average_precision_score is not None:
            try:
                samp_ap_mean = float(average_precision_score(sample_trues, sample_probs_mean))
            except Exception:
                samp_ap_mean = None
            try:
                samp_ap_p = float(average_precision_score(sample_trues, sample_probs_p))
            except Exception:
                samp_ap_p = None
        if len(sample_trues) > 0 and roc_auc_score is not None:
            try:
                if len(set(sample_trues)) > 1:
                    samp_roc_mean = float(roc_auc_score(sample_trues, sample_probs_mean))
                else:
                    samp_roc_mean = None
            except Exception:
                samp_roc_mean = None
            try:
                if len(set(sample_trues)) > 1:
                    samp_roc_p = float(roc_auc_score(sample_trues, sample_probs_p))
                else:
                    samp_roc_p = None
            except Exception:
                samp_roc_p = None

        n_pix = int(y_pix.size)
        n_pix_pos = int(np.sum(y_pix == 1)) if n_pix > 0 else 0
        n_samp = len(sample_trues)
        n_samp_pos = int(np.sum(np.array(sample_trues) == 1)) if n_samp > 0 else 0

        print(f"Epoch {epoch+1} metrics: pix_ap={pix_ap}, pix_roc={pix_roc} (n_pix={n_pix}, n_pix_pos={n_pix_pos}); "
              f"samp_ap_mean={samp_ap_mean}, samp_ap_p{int(self.agg_percentile)}={samp_ap_p}, samp_roc_mean={samp_roc_mean}, samp_roc_p{int(self.agg_percentile)}={samp_roc_p} "
              f"(n_samp={n_samp}, n_samp_pos={n_samp_pos})")

        # Save sample-level preds CSV for this epoch (existing behaviour) but tagged
        if len(sample_rows) > 0:
            preds_df = pd.DataFrame(sample_rows)
            epoch_preds_base = f"val_sample_preds_epoch{epoch+1}.csv"
            epoch_preds_local = os.path.join(self.local_work, self._tagged_name(epoch_preds_base))
            preds_df.to_csv(epoch_preds_local, index=False)
        else:
            epoch_preds_local = None

        # --- NEW: compute and save validation metrics JSON for this epoch (tagged) ---
        try:
            val_metrics = {}
            # pixel metrics (if present)
            if y_pix.size > 0:
                val_metrics["pixel"] = compute_sample_metrics_from_arrays(y_pix, p_pix)
            else:
                val_metrics["pixel"] = compute_sample_metrics_from_arrays(np.array([]), np.array([]))
            # sample metrics - mean & percentile aggregation
            if len(sample_trues) > 0:
                val_metrics["sample_mean"] = compute_sample_metrics_from_arrays(np.array(sample_trues, dtype=int), np.array(sample_probs_mean, dtype=float))
                val_metrics["sample_p"] = compute_sample_metrics_from_arrays(np.array(sample_trues, dtype=int), np.array(sample_probs_p, dtype=float))
            else:
                val_metrics["sample_mean"] = compute_sample_metrics_from_arrays(np.array([]), np.array([]))
                val_metrics["sample_p"] = compute_sample_metrics_from_arrays(np.array([]), np.array([]))

            val_metrics["n_pix"] = int(n_pix)
            val_metrics["n_pix_pos"] = int(n_pix_pos)
            val_metrics["n_samp"] = int(n_samp)
            val_metrics["n_samp_pos"] = int(n_samp_pos)

            val_metrics_base = f"val_metrics_epoch{epoch+1}.json"
            val_metrics_local = os.path.join(self.local_work, self._tagged_name(val_metrics_base))
            with open(val_metrics_local, "w") as fh:
                json.dump(val_metrics, fh, indent=2)
            # attempt upload via upload_fn
            if self.upload_fn is not None and callable(self.upload_fn):
                try:
                    self.upload_fn(val_metrics_local)
                except Exception as e:
                    print("Upload of val metrics JSON failed:", e)
        except Exception as e:
            print("Failed to compute/save validation metrics JSON for epoch", epoch+1, ":", e)

        sel_score = samp_ap_p if (samp_ap_p is not None) else (samp_ap_mean if (samp_ap_mean is not None) else (pix_ap if (pix_ap is not None) else float("-inf")))

        if sel_score > self.best_score:
            print(f"New best aggregated selection score: {sel_score} (previous {self.best_score}); saving model to {self.local_best_path}")
            self.best_score = sel_score
            try:
                self.model.save(self.local_best_path)
                if self.upload_fn is not None and callable(self.upload_fn):
                    try:
                        self.upload_fn(self.local_best_path)
                    except Exception as e:
                        print("Upload of best model failed:", e)
                    if epoch_preds_local is not None:
                        try:
                            self.upload_fn(epoch_preds_local)
                        except Exception as e:
                            print("Upload of best preds CSV failed:", e)
            except Exception as e:
                print("Failed to save best model locally:", e)

class ValStats(callbacks.Callback):
    """Print label statistics of the validation sequence at the start of every epoch.

    For each batch of ``val_seq`` (unpacked as ``inputs, labels, sample_weights``)
    the callback counts:

    * pixels whose label channel 0 equals 1 and whose weight channel 0 is > 0;
    * samples whose weight channel 0 sums to more than zero over axes 1 and 2.

    A warning block is printed when no positive pixel is found at all.
    """

    def __init__(self, val_seq: Sequence):
        super().__init__()
        self.val_seq = val_seq

    def on_epoch_begin(self, epoch, logs=None):
        """Scan the whole validation sequence and print the aggregated counts."""
        valid_sample_total = 0
        positive_pixel_total = 0

        for batch_idx in range(len(self.val_seq)):
            _inputs, labels, weights = self.val_seq[batch_idx]
            weight_plane = weights[..., 0]

            # positive label AND non-zero weight -> pixel counts as a valid positive
            labelled_positive = labels[..., 0] == 1
            carries_weight = weight_plane > 0
            positive_pixel_total += int(np.sum(labelled_positive & carries_weight))

            # a sample is "valid" when at least some weight survives in its map
            weight_per_sample = np.sum(weight_plane, axis=(1, 2))
            valid_sample_total += int(np.sum(weight_per_sample > 0))

        print(f"Epoch {epoch+1} start: validation valid_samples={valid_sample_total}, total_val_pos_pixels={positive_pixel_total}")

        if positive_pixel_total != 0:
            return
        warning_lines = (
            "WARNING: Validation set contains ZERO positive pixels. This will make ROC/AUC undefined.",
            " - Check sample CSVs and processed .npz (y_bin), or use --image_binary if target files are already binary.",
            " - Consider lowering --threshold or enabling --image_binary if appropriate.",
        )
        for line in warning_lines:
            print(line)

# -------------------- Model builder (same as before) --------------------
def build_convlstm_model(H, W, channels, seq_len=14, use_attention=True, lr=1e-4, clipnorm: float = 0.0, lstm_filters=32):
    """Assemble the ConvLSTM network and return it uncompiled.

    Architecture: ConvLSTM2D (last state only) -> BatchNorm -> optional
    squeeze-and-excitation style channel gate -> Conv2D(32) -> BatchNorm ->
    Dropout(0.3) -> 1x1 Conv2D with sigmoid (forced to float32).

    Args:
        H, W: spatial size of every frame.
        channels: number of channels per frame.
        seq_len: number of frames in the input sequence.
        use_attention: when True, insert the channel-gating branch.
        lr, clipnorm: accepted for call-site compatibility only; this function
            does not compile the model, so neither value is used here.
        lstm_filters: number of ConvLSTM2D filters.

    Returns:
        A Keras ``Model`` mapping ``(seq_len, H, W, channels)`` to an
        ``(H, W, 1)`` sigmoid map.
    """
    sequence_in = layers.Input(shape=(seq_len, H, W, channels), name="convlstm_input")

    features = layers.ConvLSTM2D(
        filters=lstm_filters,
        kernel_size=(3, 3),
        padding="same",
        return_sequences=False,
        activation="relu",
    )(sequence_in)
    features = layers.BatchNormalization()(features)

    if use_attention:
        # Channel gate: global pool -> bottleneck dense -> per-channel sigmoid scale.
        n_filters = int(features.shape[-1])
        bottleneck_units = max(4, n_filters // 8)
        gate = layers.GlobalAveragePooling2D()(features)
        gate = layers.Dense(bottleneck_units, activation="relu")(gate)
        gate = layers.Dense(n_filters, activation="sigmoid")(gate)
        gate = layers.Reshape((1, 1, n_filters))(gate)
        features = layers.multiply([features, gate])

    features = layers.Conv2D(32, (3, 3), padding="same", activation="relu")(features)
    features = layers.BatchNormalization()(features)
    features = layers.Dropout(0.3)(features)

    # dtype="float32" keeps the output layer in full precision regardless of the global policy.
    prob_map = layers.Conv2D(1, (1, 1), padding="same", activation="sigmoid", dtype="float32")(features)
    return models.Model(sequence_in, prob_map)

# -------------------- Main (unchanged except metrics write and tagging at end) --------------------
def main(args):
    """Run the ConvLSTM training / evaluation pipeline for one configuration.

    Steps, in order: load the processed ``.npz`` and the train/val/test sample
    CSVs, list the image files, build the three ``ConvLSTMSequence`` loaders,
    build and compile the model, fit it with the callback stack, save the final
    model, score the test split (pixel-level and sample-level), write the
    predictions CSV and the test metrics JSON, and finally upload (``gs://``
    out_dir) or move (local out_dir) every artifact. All intermediate files are
    written to a temporary working directory that is removed on exit.

    Positive-class weight precedence (lowest to highest):
    ``--pos_weight`` (or 1.0) < ``--compute_pos_weight`` < ``--pos_weight_override``.

    Args:
        args: namespace-like object carrying the CLI options defined in the
            ``__main__`` block (``processed_npz``, ``samples_dir``, ``image_dir``,
            ``out_dir``, model/training hyper-parameters, ``tag`` ...).

    Raises:
        RuntimeError: if no image files are found under ``args.image_dir``.
    """
    local_work = tempfile.mkdtemp(prefix="convlstm_work_")
    try:
        tag = getattr(args, "tag", "")

        # small helper to tag filenames consistently
        def tag_name(base_filename: str) -> str:
            if not tag:
                return base_filename
            base_noext, ext = os.path.splitext(base_filename)
            return f"{base_noext}_{tag}{ext}"

        def _empty_metrics():
            # Metrics payload for "no data"; delegated to the shared helper so the
            # JSON keeps the same shape as the populated case.
            return compute_sample_metrics_from_arrays(np.array([]), np.array([]))

        # Plain attribute bag handed to every ConvLSTMSequence.
        class P: pass
        P.seq_len = args.seq_len
        P.img_h = args.img_h
        P.img_w = args.img_w
        P.img_ch = args.img_ch
        P.batch_size = args.batch_size
        P.use_env = args.use_env
        P.threshold_bloom = args.threshold
        P.min_valid_pixels = args.min_valid_pixels
        P.pos_weight = args.pos_weight if args.pos_weight is not None else 1.0

        requester_pays_project = getattr(args, "requester_pays_project", None)
        use_gcsfs = getattr(args, "use_gcsfs", False)

        print("Loading processed arrays from:", args.processed_npz)
        data = {}
        try:
            data = load_npz(args.processed_npz, requester_pays_project=requester_pays_project, use_gcsfs=use_gcsfs) or {}
        except Exception as e:
            # Best effort: the run can continue without env features / npz labels.
            print("Warning: failed to load processed npz:", e)
            data = {}

        X_train = data.get("X_train")
        X_val = data.get("X_val")
        X_test = data.get("X_test")

        npz_has_ybin = all(k in data for k in ("y_train_bin", "y_val_bin", "y_test_bin"))
        if npz_has_ybin:
            print("processed .npz contains binary targets (y_*_bin) — will use them for pos_weight and diagnostics.")
            y_train_bin_npz = np.asarray(data["y_train_bin"])
            y_val_bin_npz = np.asarray(data["y_val_bin"])
            y_test_bin_npz = np.asarray(data["y_test_bin"])
        else:
            y_train_bin_npz = y_val_bin_npz = y_test_bin_npz = None

        # Fall back to the folder that holds the processed npz when the sample
        # CSVs are not under --samples_dir.
        samples_dir_to_use = args.samples_dir
        if not _samples_files_exist(samples_dir_to_use, requester_pays_project=requester_pays_project, use_gcsfs=use_gcsfs):
            if is_gcs_path(args.processed_npz):
                parent_prefix = os.path.dirname(args.processed_npz)
            else:
                parent_prefix = os.path.dirname(os.path.abspath(args.processed_npz))
            if _samples_files_exist(parent_prefix, requester_pays_project=requester_pays_project, use_gcsfs=use_gcsfs):
                print(f"Warning: samples not found under {samples_dir_to_use}; falling back to {parent_prefix}")
                samples_dir_to_use = parent_prefix

        print("Loading sample CSVs from:", samples_dir_to_use)
        train_df, val_df, test_df = load_samples_csvs(samples_dir_to_use, requester_pays_project=requester_pays_project, use_gcsfs=use_gcsfs)

        print("Listing image files in:", args.image_dir)
        image_map = list_image_files(args.image_dir, requester_pays_project=requester_pays_project, use_gcsfs=use_gcsfs)
        if len(image_map) == 0:
            raise RuntimeError("No image files found under image_dir.")

        # compute pos_weight: prefer npz y_train_bin if present, else use train_df
        if getattr(args, "compute_pos_weight", False):
            computed_pw = None
            if npz_has_ybin:
                pos = int((y_train_bin_npz == 1).sum())
                neg = int((y_train_bin_npz == 0).sum())
                if pos > 0:
                    computed_pw = float(neg / pos)
            elif "y_bin" in train_df.columns:
                pos = int((train_df["y_bin"] == 1).sum())
                neg = int((train_df["y_bin"] == 0).sum())
                if pos > 0:
                    computed_pw = float(neg / pos)
            if computed_pw is not None:
                P.pos_weight = computed_pw

        # FIX: the explicit override must be applied AFTER the computed value,
        # otherwise --compute_pos_weight silently wins over --pos_weight_override.
        pos_weight_override = getattr(args, "pos_weight_override", None)
        if pos_weight_override is not None:
            P.pos_weight = float(pos_weight_override)
        print("Using pos_weight =", P.pos_weight)

        seq_kwargs = dict(
            use_mask_channel=args.use_mask_channel,
            requester_pays_project=requester_pays_project,
            use_gcsfs=use_gcsfs,
            pad_value=args.pad_value,
            image_binary=args.image_binary,
        )

        def _make_seq(df, X):
            # A (len(df), 1) zero matrix stands in when the npz provided no features.
            features = X if X is not None else np.zeros((len(df), 1))
            return ConvLSTMSequence(df, features, image_map, P, **seq_kwargs)

        train_seq = _make_seq(train_df, X_train)
        val_seq = _make_seq(val_df, X_val)
        test_seq = _make_seq(test_df, X_test)

        channels_total = args.img_ch + (X_train.shape[1] if (X_train is not None and args.use_env and X_train.ndim > 1) else 0) + (1 if args.use_mask_channel else 0)
        print("Channels total:", channels_total)

        model = build_convlstm_model(P.img_h, P.img_w, channels_total, seq_len=P.seq_len, use_attention=args.use_attention, lr=args.lr, clipnorm=args.clipnorm)

        from tensorflow.keras.optimizers import legacy as optimizers_legacy
        # FIX: --clipnorm used to be forwarded only to the model builder (which
        # ignores it); apply it to the optimizer so gradient clipping takes effect.
        adam_kwargs = {"learning_rate": args.lr}
        clipnorm = float(getattr(args, "clipnorm", 0.0) or 0.0)
        if clipnorm > 0.0:
            adam_kwargs["clipnorm"] = clipnorm
            print("Gradient clipping enabled: clipnorm =", clipnorm)
        base_opt = optimizers_legacy.Adam(**adam_kwargs)
        policy = mixed_precision.global_policy()
        if policy.name == "mixed_float16":
            opt = mixed_precision.LossScaleOptimizer(base_opt)
            print("Using LossScaleOptimizer for mixed precision (legacy Adam).")
        else:
            opt = base_opt

        if getattr(args, "use_focal_loss", False):
            loss_fn = focal_loss(gamma=2.0, alpha=0.25)
            print("Using focal loss (gamma=2.0, alpha=0.25).")
        else:
            loss_fn = tf.keras.losses.BinaryCrossentropy()

        model.compile(optimizer=opt, loss=loss_fn, metrics=[tf.keras.metrics.BinaryAccuracy(name="acc")])
        model.summary()

        # Tag filenames early so callbacks and checkpointing use consistent names
        local_ckpt = os.path.join(local_work, tag_name("best_model.keras"))
        local_best_agg = os.path.join(local_work, tag_name("best_model_agg.keras"))
        csvlog = os.path.join(local_work, tag_name("history.csv"))
        tb_logdir = os.path.join(local_work, tag_name("tensorboard"))
        os.makedirs(tb_logdir, exist_ok=True)

        def _upload_to_out(local_path):
            # Never raises: failures are reported and training continues.
            try:
                dst = gcs_join(args.out_dir, os.path.basename(local_path))
                upload_file_to_gcs(local_path, dst, requester_pays_project=requester_pays_project)
                print("Uploaded", local_path, "->", dst)
            except Exception as e:
                print("Upload failed for", local_path, ":", e)

        cbs = [
            callbacks.EarlyStopping(monitor="val_loss", mode="min", patience=args.patience, restore_best_weights=True),
            callbacks.ReduceLROnPlateau(monitor="val_loss", mode="min", factor=0.5, patience=max(3, args.patience // 3), min_lr=1e-6, verbose=1),
            callbacks.ModelCheckpoint(local_ckpt, monitor="val_loss", mode="min", save_best_only=True),
            callbacks.CSVLogger(csvlog),
            callbacks.TensorBoard(log_dir=tb_logdir),
            ValStats(val_seq=val_seq),
            CustomValMetrics(val_seq=val_seq, local_best_path=local_best_agg, local_work=local_work, agg_percentile=args.agg_percentile, requester_pays_project=requester_pays_project, upload_fn=_upload_to_out, tag=tag)
        ]

        fit_kwargs = {}
        if args.workers > 0:
            fit_kwargs["workers"] = args.workers
            fit_kwargs["use_multiprocessing"] = args.use_multiprocessing

        model.fit(train_seq, validation_data=val_seq, epochs=args.epochs, callbacks=cbs, verbose=1, **fit_kwargs)
        local_final = os.path.join(local_work, tag_name("final_model.keras"))
        model.save(local_final)

        # --- Test set predictions and metrics collection (no model rerun) ---
        preds = []
        # containers for pixel-level metrics accumulation
        pix_trues = []
        pix_probs = []
        # containers for sample-level aggregation (mean and percentile)
        sample_trues = []
        sample_probs_mean = []
        sample_probs_p = []
        sample_rows = []

        for i in range(len(test_seq)):
            Xb, yb, sw = test_seq[i]
            yprob = model.predict(Xb, verbose=0)
            for j in range(Xb.shape[0]):
                entry_idx = i * test_seq.bs + j
                if entry_idx >= len(test_seq.entries):
                    continue
                entry = test_seq.entries[entry_idx]
                row_idx = entry["row_idx"]
                date_t = entry["date_t"]
                date_target = entry["date_target"]
                # valid mask: use sample_weight > 0 (consistent with validation callback)
                valid_mask = (sw[j, ..., 0] > 0)
                if valid_mask.sum() > 0:
                    vals = yprob[j, ..., 0][valid_mask]
                    # pixel-level accumulation
                    pix_trues.append(yb[j, ..., 0][valid_mask].reshape(-1))
                    pix_probs.append(vals.reshape(-1))
                    # sample-level aggregated metrics
                    mean_prob = float(np.nanmean(vals))
                    try:
                        p_val = float(np.nanpercentile(vals, args.agg_percentile))
                    except Exception:
                        p_val = mean_prob
                else:
                    mean_prob = float("nan")
                    p_val = float("nan")
                preds.append({"row_idx": int(row_idx), "date_t": str(date_t), "date_target": str(date_target), "mean_prob": mean_prob})
                # sample label: prefer the CSV row, fall back to the sequence entry
                sample_label = None
                try:
                    sample_label = int(test_df.iloc[row_idx]['y_bin'])
                except Exception:
                    try:
                        sample_label = int(entry.get("sample_y_bin"))
                    except Exception:
                        sample_label = None
                if sample_label is not None:
                    sample_trues.append(sample_label)
                    sample_probs_mean.append(mean_prob)
                    sample_probs_p.append(p_val)
                    sample_rows.append({
                        "row_idx": int(row_idx),
                        "date_t": str(date_t),
                        "date_target": str(date_target),
                        "mean_prob": mean_prob,
                        f"p{int(args.agg_percentile)}_prob": p_val,
                        "label": sample_label
                    })

        preds_df = pd.DataFrame(preds)
        preds_local = os.path.join(local_work, tag_name("predictions_test_convlstm.csv"))
        preds_df.to_csv(preds_local, index=False)

        # --- compute and save final test metrics JSON from preds_df + test_df (tagged) ---
        # FIX: bind the name up front so the artifact stage below can tell "metrics
        # were never written" apart from an upload failure (previously a NameError).
        metrics_local = None
        try:
            test_metrics = {}

            # Pixel-level metrics
            if len(pix_trues) > 0:
                y_pix = np.concatenate(pix_trues).ravel()
                p_pix = np.concatenate(pix_probs).ravel()
            else:
                y_pix = np.array([])
                p_pix = np.array([])
            test_metrics["pixel"] = compute_sample_metrics_from_arrays(y_pix, p_pix)

            # Sample-level metrics from aggregated per-sample values (mean and percentile)
            if len(sample_trues) > 0:
                y_samp_all = np.array(sample_trues, dtype=int)
                test_metrics["sample_mean"] = compute_sample_metrics_from_arrays(y_samp_all, np.array(sample_probs_mean, dtype=float))
                test_metrics["sample_p"] = compute_sample_metrics_from_arrays(y_samp_all, np.array(sample_probs_p, dtype=float))
            else:
                # fallback: try merging preds_df with test_df using row_idx (useful if you didn't aggregate)
                test_metrics["sample_mean"] = _empty_metrics()
                test_metrics["sample_p"] = _empty_metrics()
                try:
                    merged_test = None
                    if "row_idx" in preds_df.columns:
                        test_idx = test_df.reset_index(drop=True).reset_index().rename(columns={"index": "row_idx"})
                        merged_test = pd.merge(preds_df, test_idx, on="row_idx", how="left", suffixes=("", "_samp"))
                    else:
                        # fallback to date columns join
                        join_keys = [k for k in ("date_t", "date_target") if (k in preds_df.columns) and (k in test_df.columns)]
                        if len(join_keys) > 0:
                            merged_test = pd.merge(preds_df, test_df, on=join_keys, how="left", suffixes=("", "_samp"))
                    if merged_test is not None and "mean_prob" in merged_test.columns and "y_bin" in merged_test.columns:
                        merged_sub = merged_test[["mean_prob", "y_bin"]].dropna().rename(columns={"mean_prob": "prob", "y_bin": "label"})
                        if not merged_sub.empty:
                            y_samp = merged_sub["label"].astype(int).values
                            y_score = merged_sub["prob"].astype(float).values
                            test_metrics["sample_mean"] = compute_sample_metrics_from_arrays(y_samp, y_score)
                            # no percentile column available in this path; duplicate mean results
                            test_metrics["sample_p"] = test_metrics["sample_mean"]
                except Exception:
                    test_metrics["sample_mean"] = _empty_metrics()
                    test_metrics["sample_p"] = _empty_metrics()

            # counts
            test_metrics["n_pix"] = int(y_pix.size)
            test_metrics["n_pix_pos"] = int(np.sum(y_pix == 1)) if y_pix.size > 0 else 0
            test_metrics["n_samp"] = int(len(sample_trues))
            test_metrics["n_samp_pos"] = int(np.sum(np.array(sample_trues) == 1)) if len(sample_trues) > 0 else 0

            metrics_path = os.path.join(local_work, tag_name("test_metrics.json"))
            with open(metrics_path, "w") as fh:
                json.dump(test_metrics, fh, indent=2)
            metrics_local = metrics_path
            print("Test metrics:")
            print(json.dumps(test_metrics, indent=2))
            # upload metrics JSON (helper reports its own failures)
            _upload_to_out(metrics_local)
        except Exception as e:
            print("Failed to compute/save test metrics JSON:", e)

        print("Uploading artifacts to:", args.out_dir)
        if is_gcs_path(args.out_dir):
            file_uploads = [
                (local_final, "final model"),
                (local_ckpt, "ckpt"),
                (preds_local, "preds csv"),
                (metrics_local, "metrics json"),
                (csvlog, "csvlog"),
            ]
            for path, label in file_uploads:
                if path is None:
                    print(f"Skipping upload of {label}: file was not written.")
                    continue
                try:
                    upload_file_to_gcs(path, gcs_join(args.out_dir, os.path.basename(path)), requester_pays_project=requester_pays_project)
                except Exception as e:
                    print(f"Failed to upload {label}:", e)
            try:
                upload_dir_to_gcs(tb_logdir, gcs_join(args.out_dir, "tensorboard"), requester_pays_project=requester_pays_project)
            except Exception as e:
                print("Failed to upload tensorboard logs:", e)
        else:
            try:
                os.makedirs(args.out_dir, exist_ok=True)
            except Exception as e:
                print("Failed to create local output directory:", e)
            file_moves = [
                (local_final, "final model"),
                (local_ckpt, "checkpoint"),
                (preds_local, "preds csv"),
                (metrics_local, "metrics json"),
                (csvlog, "csv log"),
            ]
            for path, label in file_moves:
                if path is None:
                    print(f"Skipping move of {label}: file was not written.")
                    continue
                try:
                    shutil.move(path, os.path.join(args.out_dir, os.path.basename(path)))
                except Exception as e:
                    print(f"Failed to move {label} locally:", e)
            dst_tb = os.path.join(args.out_dir, "tensorboard")
            try:
                # replace any previous run's logs rather than nesting inside them
                if os.path.exists(dst_tb):
                    shutil.rmtree(dst_tb)
                shutil.move(tb_logdir, dst_tb)
            except Exception as e:
                print("Failed to move tensorboard logs locally:", e)

        print("All artifacts uploaded to", args.out_dir)
    finally:
        # Best-effort cleanup of the scratch directory.
        shutil.rmtree(local_work, ignore_errors=True)

# -------------------- CLI --------------------
if __name__ == "__main__":
    # Command-line entry point. In a notebook kernel (ipykernel / Colab) the
    # kernel's own argv entries are tolerated via parse_known_args and the parsed
    # values are layered on top of DEFAULTS.
    parser = argparse.ArgumentParser(add_help=True)

    # --- input / output locations ---
    for _opt, _help in (
        ("processed_npz", "Processed dataset .npz (local or gs://)"),
        ("samples_dir", "Folder with samples_train/val/test CSVs (local or gs://)"),
        ("image_dir", "Folder/prefix with .npy or .tif files (local or gs://)"),
        ("out_dir", "Output directory (local or gs://)"),
    ):
        parser.add_argument(f"--{_opt}", type=str, default=DEFAULTS[_opt], help=_help)

    # --- integer hyper-parameters taken from DEFAULTS ---
    for _opt in ("seq_len", "img_h", "img_w", "img_ch", "batch_size", "epochs", "patience"):
        parser.add_argument(f"--{_opt}", type=int, default=DEFAULTS[_opt])

    # --- architecture switches ---
    for _opt in ("use_attention", "use_env", "use_mask_channel"):
        parser.add_argument(f"--{_opt}", action="store_true", dest=_opt)

    # --- float hyper-parameters taken from DEFAULTS ---
    for _opt in ("threshold", "lr", "clipnorm"):
        parser.add_argument(f"--{_opt}", type=float, default=DEFAULTS[_opt])

    parser.add_argument("--min_valid_pixels", type=int, default=DEFAULTS["min_valid_pixels"])

    # --- class weighting / loss ---
    parser.add_argument("--compute_pos_weight", action="store_true")
    parser.add_argument("--pos_weight", type=float, default=DEFAULTS["pos_weight"])
    parser.add_argument(
        "--pos_weight_override",
        type=float,
        default=None,
        help="Force positive pixel weight (overrides computed pos weight)",
    )
    parser.add_argument("--use_focal_loss", action="store_true", help="Use focal loss instead of BCE")

    # --- aggregation / image handling ---
    parser.add_argument(
        "--agg_percentile",
        type=float,
        default=90.0,
        help="Sample-level aggregation percentile (e.g. 90)",
    )
    parser.add_argument(
        "--pad_value",
        type=float,
        default=DEFAULTS["pad_value"],
        help="Value to use when padding missing historical images (default 0.0)",
    )
    parser.add_argument(
        "--image_binary",
        action="store_true",
        help="Treat target images as already binary (0/1). If set, skip thresholding and treat non-zero pixels as positive.",
    )

    # --- storage access ---
    parser.add_argument(
        "--use_gcsfs",
        action="store_true",
        help="Attempt gcsfs for reads if storage client fails (default: prefer storage client)",
    )
    parser.add_argument(
        "--requester_pays_project",
        type=str,
        default=None,
        help="GCP project id for requester-pays buckets (optional)",
    )

    # --- misc runtime options ---
    parser.add_argument(
        "--scale_env_features",
        action="store_true",
        help="If set, scale env features using StandardScaler (saved to out_dir)",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=4,
        help="Number of worker processes for Sequence data loading (0 disables multiprocessing).",
    )
    parser.add_argument(
        "--use_multiprocessing",
        action="store_true",
        help="If set, use multiprocessing for Sequence data loading (requires workers>0).",
    )
    parser.add_argument(
        "--tag",
        type=str,
        default="pig",
        help="Optional tag to append to artifact filenames to avoid overwriting (default: pig). Use empty string to disable tagging.",
    )

    # Unrecognised argv entries are ignored rather than treated as errors.
    args, _unparsed = parser.parse_known_args()

    running_in_notebook = ("ipykernel" in sys.modules) or ("google.colab" in sys.modules)
    if running_in_notebook:
        # Start from DEFAULTS and let every parsed value that is not None win.
        merged_options = dict(DEFAULTS)
        merged_options.update({key: value for key, value in vars(args).items() if value is not None})
        args = SimpleNamespace(**merged_options)

    main(args)

2025-11-12 17:35:25.920125: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-11-12 17:35:26.473153: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-11-12 17:35:28.456419: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib64:/usr/local/nccl2/lib:/usr/local/cuda/extras/CUPTI/lib64:/usr/lib/

Mixed precision enabled (policy = mixed_float16). GPUs detected: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
Loading processed arrays from: gs://final_data_kgc2/models/Pigeon_binary/processed_dataset_binary.npz
processed .npz contains binary targets (y_*_bin) — will use them for pos_weight and diagnostics.
Loading sample CSVs from: gs://final_data_kgc2/models/Pigeon_binary/
Listing image files in: gs://final_data_kgc2/Final_data/Pigeon_80m/
Using pos_weight = 1.0


2025-11-12 17:35:59.265515: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-11-12 17:35:59.431765: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1613] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 20750 MB memory:  -> device: 0, name: NVIDIA L4, pci bus id: 0000:00:03.0, compute capability: 8.9


Initialized ConvLSTMSequence with 948281 valid entries (skipped 0).
Initialized ConvLSTMSequence with 74421 valid entries (skipped 0).
Initialized ConvLSTMSequence with 495447 valid entries (skipped 0).
Channels total: 1
Using LossScaleOptimizer for mixed precision (legacy Adam).
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 convlstm_input (InputLayer)  [(None, 14, 80, 80, 1)]  0         
                                                                 
 conv_lstm2d (ConvLSTM2D)    (None, 80, 80, 32)        38144     
                                                                 
 batch_normalization (BatchN  (None, 80, 80, 32)       128       
 ormalization)                                                   
                                                                 
 conv2d (Conv2D)             (None, 80, 80, 32)        9248      
                                            

2025-11-12 17:43:15.027265: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:428] Loaded cuDNN version 8900
2025-11-12 17:43:15.733569: I tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:630] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.


Uploaded /var/tmp/convlstm_work_mv9ep3sk/val_metrics_epoch1_pig.json -> gs://final_data_kgc2/convlstm_results/val_metrics_epoch1_pig.json
New best aggregated selection score: 0.26318096254396794 (previous -inf); saving model to /var/tmp/convlstm_work_mv9ep3sk/best_model_agg_pig.keras
Uploaded /var/tmp/convlstm_work_mv9ep3sk/best_model_agg_pig.keras -> gs://final_data_kgc2/convlstm_results/best_model_agg_pig.keras
Uploaded /var/tmp/convlstm_work_mv9ep3sk/val_sample_preds_epoch1_pig.csv -> gs://final_data_kgc2/convlstm_results/val_sample_preds_epoch1_pig.csv
Epoch 2 start: validation valid_samples=74421, total_val_pos_pixels=70150987
Epoch 2/150
Uploaded /var/tmp/convlstm_work_mv9ep3sk/val_metrics_epoch2_pig.json -> gs://final_data_kgc2/convlstm_results/val_metrics_epoch2_pig.json
Epoch 3 start: validation valid_samples=74421, total_val_pos_pixels=70150987
Epoch 3/150
Uploaded /var/tmp/convlstm_work_mv9ep3sk/val_metrics_epoch3_pig.json -> gs://final_data_kgc2/convlstm_results/val_metrics