In [None]:
# -*- coding: utf-8 -*-
"""
Batch clean + dedup per-depth CSVs under:
  /data/wang/Result_Data/alldoxy/{depth}dbar/*.csv

Outputs (per folder):
  depthX.csv  ->  depthX_TRAIN.csv   (same folder)

IMPORTANT (carried requirements)
- Preserve existing field: sigma_interp
- Output column order forces sigma_interp placed right after Source:
    Date,Time,Pressure,Latitude,Longitude,Temperature,Salinity,Oxygen,Source,sigma_interp
- sigma_interp NOT used in dedup comparisons; carried from surviving record
- Drop rows where sigma_interp > 3 (sigma_interp NA kept)
- Oxygen QC strictly positive: Oxygen > 0 and < 600 (0 not kept)
- If output file exists, delete it before writing

NEW (this version)
- Final output rounding:
    Latitude/Longitude: round to 4 decimals (max 4)
    Temperature/Salinity/Oxygen: round to 2 decimals (max 2)

CHANGE (requested)
- Argo and OceanSITES do NOT participate in cross-source dedup (Rule D).
  i.e., they will NOT be linked into cross-source clusters with any other sources,
  and will never be dropped by cross-source dedup.
  Implementation: when building cross-source edges, skip any pair where either side
  has SourceLabel in {"Argo","OceanSITES"}.

ADDED (requested)
- Switch to enable/disable cross-source dedup (Rule D). Default OFF.

Logs written to:
  /data/wang/Result_Data/alldoxy/_logs/
    dedup_file_summary.csv
    dedup_source_summary.csv
    crosssrc_cluster_size_hist.csv

Python: 3.9
"""

import os, re, math
from pathlib import Path
from collections import defaultdict, Counter

import numpy as np
import pandas as pd
from tqdm import tqdm

# =========================
# Config
# =========================

# Root folder holding one "{depth}dbar" sub-folder per target depth.
ROOT_DIR = Path("/data/wang/Result_Data/alldoxy")
# Summary logs are written here; the folder is created on import/run if missing.
LOG_DIR  = ROOT_DIR / "_logs"
LOG_DIR.mkdir(parents=True, exist_ok=True)

# Encoding used for every CSV this script writes (UTF-8 with BOM).
WRITE_ENCODING = "utf-8-sig"

# Hard-QC bounds. Latitude is checked inclusively; Oxygen is checked with
# strict inequalities (O2_MIN < Oxygen < O2_MAX); Pressure must be >= PRESSURE_MIN.
LAT_MIN, LAT_MAX = -90.0, 90.0
O2_MIN, O2_MAX   = 0.0, 600.0
PRESSURE_MIN     = 0.0

SIGMA_MAX = 3.0  # drop sigma_interp > 3 (NA kept)

# Cross-source (Rule D) match thresholds: two records are linked when their
# haversine distance is <= R_KM and their time difference is <= DT_HOURS.
R_KM     = 1.0
DT_HOURS = 24.0

# Candidate bucketing (coarse prefilter; exact check uses haversine + dt)
# Bucket sizes consumed by build_buckets(): hours per time bucket and degrees
# per latitude/longitude bucket.
TIME_BUCKET_HOURS = 6
SPACE_BUCKET_DEG  = 0.1

EXCLUDE_FROM_PRESSURE_COLLAPSE = {"Argo", "OSDCTD"}  # SourceLabel exclusions
EXCLUDE_FROM_CROSSSRC = {"Argo", "OceanSITES"}       # SourceLabel exclusions for cross-source dedup

# >>> Cross-source dedup (Rule D) switch (DEFAULT OFF) <<<
ENABLE_CROSS_SOURCE_DEDUP = False

# Rank per normalized SourceLabel. Cross-source clusters are sorted by this
# rank in ascending order, so the lowest number wins the cluster.
# Labels not listed here fall back to the "Other" rank.
SOURCE_PRIORITY = {
    "OSDCTD": 0,
    "GLODAPV2 2022": 1,
    "CCHDO_Bottle": 2,
    "CCHDO_CTD": 3,
    "Argo": 4,
    "Geotraces IDP": 5,
    "OceanSITES": 6,
    "Other": 99,
}

# Output column order (force sigma_interp after Source)
# Columns missing from the data are created as NA by ensure_out_columns().
OUT_COLS_ORDER = [
    "Date", "Time", "Pressure", "Latitude", "Longitude",
    "Temperature", "Salinity", "Oxygen", "Source", "sigma_interp"
]

# Rounding for final output
# The lat/lon precision is also used to build the Lat_key/Lon_key grouping keys.
OUT_LATLON_DECIMALS = 4
OUT_TSO2_DECIMALS   = 2


# =========================
# Helpers
# =========================

def normalize_source_label(src) -> str:
    """
    Map a raw Source value onto one of the canonical source labels.

    Matching is case-insensitive and done on the stripped string form of
    `src`. Rules are tried in order and the first hit wins; anything that is
    None, a float NaN, or matches no rule is labelled "Other".
    """
    is_missing = src is None or (isinstance(src, float) and np.isnan(src))
    if is_missing:
        return "Other"

    text = str(src).strip().upper()

    # Ordered (predicate, label) pairs -- order matters because several
    # keywords can occur in the same source string.
    rules = (
        (lambda t: "OSDCTD" in t or t in {"OSD", "CTD"}, "OSDCTD"),
        (lambda t: "GLODAP" in t, "GLODAPV2 2022"),
        (lambda t: "CCHDO" in t and "BOTTLE" in t, "CCHDO_Bottle"),
        (lambda t: "CCHDO" in t and "CTD" in t, "CCHDO_CTD"),
        (lambda t: "OCEANSITES" in t, "OceanSITES"),
        (lambda t: "GEOTRACES" in t or "IDP" in t, "Geotraces IDP"),
        (lambda t: "ARGO" in t, "Argo"),
    )
    for matches, label in rules:
        if matches(text):
            return label
    return "Other"


def normalize_lon_to_180(lon: pd.Series) -> pd.Series:
    """
    Coerce longitudes to numbers and wrap them into the half-open range
    [-180, 180). Values that cannot be parsed become NaN.
    """
    numeric = pd.to_numeric(lon, errors="coerce")
    shifted = numeric + 180.0
    wrapped = shifted % 360.0
    return wrapped - 180.0


def parse_date_time(df: pd.DataFrame) -> pd.DataFrame:
    """
    Standardize:
      Date -> YYYY-MM-DD
      Time -> HH:MM (NA if missing or 00:00)
    Adds:
      dt (datetime64[ns]) for time-difference calculations

    The frame is modified in place (columns Date, Time and dt are
    overwritten/added, and a Time column is created if absent) and the same
    object is returned. Rows whose Date cannot be parsed end up with NA in
    Date and NaT in dt.
    """
    # Guarantee a Time column so the rest of the function can treat it uniformly.
    if "Time" not in df.columns:
        df["Time"] = pd.NA

    date_raw = df["Date"].astype("string")
    time_raw = df["Time"].astype("string")

    # A ':' inside the Date text is taken as the sign that it carries a time part.
    has_time_in_date = date_raw.str.contains(":", regex=False, na=False)
    dt_from_date = pd.to_datetime(date_raw, errors="coerce")

    # Time column: keep the first 5 characters, blank -> NA, "00:00" -> NA.
    time_norm = time_raw.str.slice(0, 5)
    time_norm = time_norm.mask(time_norm.isna() | (time_norm.str.strip() == ""), pd.NA)
    time_norm = time_norm.mask(time_norm == "00:00", pd.NA)

    # Time embedded in Date: only used where the Date text contained ':',
    # and again "00:00" is treated as "no time available".
    time_from_date = dt_from_date.dt.strftime("%H:%M").astype("string")
    time_from_date = time_from_date.mask(~has_time_in_date, pd.NA)
    time_from_date = time_from_date.mask(time_from_date == "00:00", pd.NA)

    # The explicit Time column takes precedence over the time found in Date.
    time_final = time_norm.fillna(time_from_date)
    date_final = dt_from_date.dt.strftime("%Y-%m-%d").astype("string")

    # For the combined timestamp a missing time is treated as midnight.
    time_for_dt = time_final.fillna("00:00")
    dt_calc = pd.to_datetime(date_final + " " + time_for_dt, errors="coerce")

    df["Date"] = date_final
    df["Time"] = time_final
    df["dt"]   = dt_calc
    return df


def haversine_km(lat1, lon1, lat2, lon2) -> float:
    """
    Great-circle distance in kilometres between two points.

    Args:
        lat1, lon1: latitude/longitude of the first point, in degrees.
        lat2, lon2: latitude/longitude of the second point, in degrees.

    Returns:
        Distance in km on a sphere of radius 6371 km.
    """
    R = 6371.0
    phi1 = math.radians(lat1)
    phi2 = math.radians(lat2)
    dphi = math.radians(lat2 - lat1)
    dlmb = math.radians(lon2 - lon1)
    a = math.sin(dphi / 2) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(dlmb / 2) ** 2
    # Floating-point rounding can push `a` marginally outside [0, 1] for
    # (near-)antipodal points, which would make sqrt/asin raise ValueError.
    a = min(1.0, max(0.0, a))
    return 2 * R * math.asin(math.sqrt(a))


def build_buckets(df: pd.DataFrame) -> None:
    """
    Add coarse integer bucket columns used as a candidate prefilter.

    Writes three int64 columns into `df` in place:
      t_bucket   - floor(hours since epoch / TIME_BUCKET_HOURS), from df["dt"]
      lat_bucket - floor(Latitude  / SPACE_BUCKET_DEG)
      lon_bucket - floor(Longitude / SPACE_BUCKET_DEG)

    Returns nothing.
    """
    # Series.view() is deprecated in recent pandas; go through NumPy to get
    # the nanosecond-since-epoch integers instead.
    dt_ns = df["dt"].astype("datetime64[ns]").to_numpy().astype(np.int64)
    t_hours = (dt_ns // 10**9) / 3600.0
    df["t_bucket"] = np.floor(t_hours / TIME_BUCKET_HOURS).astype(np.int64)
    df["lat_bucket"] = np.floor(df["Latitude"].astype(np.float64) / SPACE_BUCKET_DEG).astype(np.int64)
    df["lon_bucket"] = np.floor(df["Longitude"].astype(np.float64) / SPACE_BUCKET_DEG).astype(np.int64)


def uf_init(n: int):
    """
    Create a union-find (disjoint-set) structure over the integers 0..n-1.

    Returns:
      (find, union, parent) where `find(x)` returns the representative of
      x's set, `union(a, b)` merges the sets of a and b, and `parent` is the
      underlying parent array (each element starts as its own parent).
    """
    parent = np.arange(n, dtype=np.int64)
    rank = np.zeros(n, dtype=np.int8)

    def find(x: int) -> int:
        # Walk up to the root, re-pointing each visited node to its grandparent.
        while True:
            up = parent[x]
            if up == x:
                return x
            parent[x] = parent[up]
            x = parent[x]

    def union(a: int, b: int) -> None:
        root_a = find(a)
        root_b = find(b)
        if root_a == root_b:
            return
        if rank[root_a] < rank[root_b]:
            parent[root_a] = root_b
            return
        # root_a becomes the root; on a rank tie its rank grows by one.
        if rank[root_a] == rank[root_b]:
            rank[root_a] += 1
        parent[root_b] = root_a

    return find, union, parent


def get_target_depth_from_folder(folder: Path) -> float:
    """
    Read the target depth from a folder name ending in "<digits>dbar".

    Returns the depth as a float, or NaN when the name does not match.
    """
    found = re.search(r"(\d+)dbar$", folder.name)
    return float(found.group(1)) if found else np.nan


def pick_input_csv_in_folder(folder: Path) -> Path:
    """
    Choose the input CSV of a depth folder.

    Files ending in "_TRAIN.csv" (previous outputs) are ignored. With several
    candidates, a single file whose name starts with "depth" (case-insensitive)
    is preferred; otherwise the first candidate in sorted order is returned.

    Raises:
      FileNotFoundError: if the folder holds no candidate CSV.
    """
    candidates = sorted(
        path for path in folder.glob("*.csv") if not path.name.endswith("_TRAIN.csv")
    )
    if not candidates:
        raise FileNotFoundError(f"No csv found in {folder}")
    if len(candidates) == 1:
        return candidates[0]

    depth_named = [path for path in candidates if path.name.lower().startswith("depth")]
    return depth_named[0] if len(depth_named) == 1 else candidates[0]


def out_train_path(in_csv: Path) -> Path:
    """Return the sibling output path "<stem>_TRAIN.csv" for an input CSV."""
    train_name = f"{in_csv.stem}_TRAIN.csv"
    return in_csv.with_name(train_name)


def ensure_out_columns(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of `df` restricted to OUT_COLS_ORDER, in that order.

    Output columns that are absent are first added to `df` itself (filled
    with NA), so the input frame gains those columns as a side effect.
    """
    absent = [name for name in OUT_COLS_ORDER if name not in df.columns]
    for name in absent:
        df[name] = pd.NA
    ordered = df[OUT_COLS_ORDER]
    return ordered.copy()


def finalize_rounding(df: pd.DataFrame) -> pd.DataFrame:
    """
    Round the output columns to their final precision, in place.

      - Latitude/Longitude: OUT_LATLON_DECIMALS decimals
      - Temperature/Salinity/Oxygen: OUT_TSO2_DECIMALS decimals

    Columns are coerced to numeric first (unparseable values become NaN);
    columns that are not present are skipped. An empty frame is returned
    untouched. The same frame object is returned.
    """
    if df.empty:
        return df

    decimals_by_column = {
        "Latitude": OUT_LATLON_DECIMALS,
        "Longitude": OUT_LATLON_DECIMALS,
        "Temperature": OUT_TSO2_DECIMALS,
        "Salinity": OUT_TSO2_DECIMALS,
        "Oxygen": OUT_TSO2_DECIMALS,
    }
    for name, ndigits in decimals_by_column.items():
        if name not in df.columns:
            continue
        as_number = pd.to_numeric(df[name], errors="coerce")
        df[name] = as_number.round(ndigits)

    return df


# =========================
# Per-file processing
# =========================

def process_depth_file(csv_path: Path, target_depth: float):
    """
    Clean and de-duplicate one per-depth CSV.

    Pipeline:
      A) Hard QC: required fields present, latitude/longitude/pressure/oxygen
         in range, sigma_interp either NA or <= SIGMA_MAX.
      B) Same-source exact duplicates: keep the first row per
         (Source, Date, Time, Pressure, lat/lon key, Oxygen).
      C) Same-source multi-pressure collapse (only when target_depth is
         finite; labels in EXCLUDE_FROM_PRESSURE_COLLAPSE are skipped): per
         (Source, Date, Time, lat/lon key) keep the row whose Pressure is
         closest to target_depth.
      D) Optional cross-source near-duplicate removal (only when
         ENABLE_CROSS_SOURCE_DEDUP is true): records from different Sources
         within R_KM and DT_HOURS of each other are clustered and one winner
         per cluster is kept. Labels in EXCLUDE_FROM_CROSSSRC never take part.

    Args:
      csv_path: input CSV to read.
      target_depth: nominal depth of the folder in dbar; NaN disables step C
        and the pressure tie-break of step D.

    Returns:
      df_out: cleaned+dedup df (includes sigma_interp preserved)
      stats: dict
      per_label_drop: dict for drop counts by SourceLabel for each category
      cluster_hist: Counter cluster size distribution (cross-source)

    Raises:
      ValueError: if a required input column is missing.
    """
    df = pd.read_csv(csv_path, low_memory=False)
    n_raw = len(df)

    need = ["Date", "Latitude", "Longitude", "Pressure", "Oxygen", "Source", "sigma_interp"]
    for c in need:
        if c not in df.columns:
            raise ValueError(f"Missing column '{c}' in {csv_path}")

    # Parse/standardize Date & Time
    df = parse_date_time(df)

    # Numeric conversions + lon normalization
    df["Latitude"]  = pd.to_numeric(df["Latitude"], errors="coerce")
    df["Longitude"] = normalize_lon_to_180(df["Longitude"])
    df["Pressure"]  = pd.to_numeric(df["Pressure"], errors="coerce")
    df["Oxygen"]    = pd.to_numeric(df["Oxygen"], errors="coerce")

    # Keep sigma_interp (numeric if possible; NA allowed)
    df["sigma_interp"] = pd.to_numeric(df["sigma_interp"], errors="coerce")

    # Hard QC (sigma_interp NA kept, but if non-NA must be <= SIGMA_MAX)
    mask = (
        df["Date"].notna() &
        df["dt"].notna() &
        df["Latitude"].notna() &
        df["Longitude"].notna() &
        df["Pressure"].notna() &
        df["Oxygen"].notna() &
        df["Source"].notna()
    )
    mask &= df["Latitude"].between(LAT_MIN, LAT_MAX, inclusive="both")
    mask &= (df["Longitude"] >= -180.0) & (df["Longitude"] < 180.0)
    mask &= df["Pressure"] >= PRESSURE_MIN
    mask &= (df["Oxygen"] > O2_MIN) & (df["Oxygen"] < O2_MAX)
    mask &= (df["sigma_interp"].isna() | (df["sigma_interp"] <= SIGMA_MAX))

    n_qc_drop = int((~mask).sum())
    df = df.loc[mask].copy()

    if df.empty:
        # Nothing survived QC: return an empty, correctly-shaped output.
        stats = dict(
            n_raw=int(n_raw),
            n_qc_drop=int(n_qc_drop),
            n_exactdup_drop=0,
            n_presscollapse_drop=0,
            n_crosssrc_drop=0,
            n_final=0,
            qc_drop_ratio=float(n_qc_drop / n_raw) if n_raw else 0.0,
            exactdup_drop_ratio=0.0,
            presscollapse_drop_ratio=0.0,
            crosssrc_drop_ratio=0.0,
            final_ratio=0.0,
            crosssrc_enabled=bool(ENABLE_CROSS_SOURCE_DEDUP),
        )
        df_out = ensure_out_columns(df)
        df_out = finalize_rounding(df_out)
        return df_out, stats, {}, Counter()

    # Source label & rank
    df["SourceLabel"] = df["Source"].map(normalize_source_label)
    df["src_rank"] = df["SourceLabel"].map(lambda x: SOURCE_PRIORITY.get(x, SOURCE_PRIORITY["Other"])).astype(np.int64)

    # Grouping keys (use 4-decimal lat/lon keys to align with output precision)
    df["Lat_key"] = df["Latitude"].round(OUT_LATLON_DECIMALS)
    df["Lon_key"] = df["Longitude"].round(OUT_LATLON_DECIMALS)
    df["Time_key"] = df["Time"].astype("string").fillna("")

    # -------------------------
    # B) same-source exact duplicates
    # -------------------------
    df = df.sort_values(
        ["Source", "Date", "Time_key", "Pressure", "Lat_key", "Lon_key", "Oxygen"],
        kind="mergesort"
    ).copy()

    dup_mask = df.duplicated(
        subset=["Source", "Date", "Time_key", "Pressure", "Lat_key", "Lon_key", "Oxygen"],
        keep="first"
    )
    n_exactdup_drop = int(dup_mask.sum())
    df = df.loc[~dup_mask].copy()

    # -------------------------
    # C) same-source multi-pressure collapse (exclude Argo & OSDCTD)
    # -------------------------
    per_label_drop = defaultdict(int)

    if np.isfinite(target_depth):
        can_collapse = ~df["SourceLabel"].isin(EXCLUDE_FROM_PRESSURE_COLLAPSE)

        df["grp_key"] = (
            df["Source"].astype("string") + "||" +
            df["Date"].astype("string") + "||" +
            df["Time_key"].astype("string") + "||" +
            df["Lat_key"].astype("string") + "||" +
            df["Lon_key"].astype("string")
        )

        eligible = df.loc[can_collapse].copy()
        if not eligible.empty:
            eligible["p_dist"] = (eligible["Pressure"] - target_depth).abs()
            eligible = eligible.sort_values(
                ["grp_key", "p_dist", "dt"],
                ascending=[True, True, True],
                kind="mergesort"
            )
            # First row of each group = closest pressure to the target depth.
            keep_idx = eligible.groupby("grp_key", sort=False).head(1).index
            drop_idx = eligible.index.difference(keep_idx)

            n_presscollapse_drop = int(len(drop_idx))
            if n_presscollapse_drop > 0:
                dropped = df.loc[drop_idx, "SourceLabel"].value_counts()
                for k, v in dropped.to_dict().items():
                    per_label_drop[f"presscollapse_drop::{k}"] += int(v)

            df = df.drop(index=drop_idx).copy()
        else:
            n_presscollapse_drop = 0

        df.drop(columns=["grp_key"], inplace=True, errors="ignore")
    else:
        n_presscollapse_drop = 0

    # -------------------------
    # D) cross-source near-duplicates (OPTIONAL; default OFF)
    # -------------------------
    if (not ENABLE_CROSS_SOURCE_DEDUP) or (len(df) <= 1):
        n_crosssrc_drop = 0
        cluster_hist = Counter({1: int(len(df))}) if len(df) else Counter()
        df_out = df.copy()
    else:
        df = df.reset_index(drop=True).copy()
        # Series.view() is deprecated in recent pandas; convert via NumPy.
        df["dt_ns"] = df["dt"].astype("datetime64[ns]").to_numpy().astype(np.int64)

        # meta completeness for tie-break
        meta_score = np.zeros(len(df), dtype=np.int8)
        if "Temperature" in df.columns:
            meta_score += pd.to_numeric(df["Temperature"], errors="coerce").notna().to_numpy(dtype=np.int8)
        if "Salinity" in df.columns:
            meta_score += pd.to_numeric(df["Salinity"], errors="coerce").notna().to_numpy(dtype=np.int8)
        df["meta_score"] = meta_score

        build_buckets(df)

        bucket = defaultdict(list)
        for i, key in enumerate(zip(df["t_bucket"].to_numpy(),
                                    df["lat_bucket"].to_numpy(),
                                    df["lon_bucket"].to_numpy())):
            bucket[(int(key[0]), int(key[1]), int(key[2]))].append(i)

        find, union, _ = uf_init(len(df))
        dt_max_sec = DT_HOURS * 3600.0

        # The prefilter must be at least as wide as the exact thresholds,
        # otherwise valid pairs are never compared. Two timestamps that are
        # <= DT_HOURS apart can sit up to ceil(DT_HOURS / TIME_BUCKET_HOURS)
        # time buckets apart, so search that many buckets on each side.
        t_span = max(1, int(math.ceil(DT_HOURS / TIME_BUCKET_HOURS)))
        # Longitude buckets wrap around at +/-180 degrees: the easternmost
        # bucket is adjacent to the westernmost one.
        n_lon_buckets = int(round(360.0 / SPACE_BUCKET_DEG))
        # NOTE(review): one longitude bucket on each side may still be narrower
        # than R_KM very close to the poles -- confirm whether that matters.

        # Plain arrays: much cheaper than per-element df.at[...] in the loops.
        t_bucket_arr = df["t_bucket"].to_numpy()
        lat_bucket_arr = df["lat_bucket"].to_numpy()
        lon_bucket_arr = df["lon_bucket"].to_numpy()
        src_arr = df["Source"].to_numpy()
        lbl_arr = df["SourceLabel"].to_numpy()
        dt_ns_arr = df["dt_ns"].to_numpy()
        lat_arr = df["Latitude"].to_numpy(dtype=float)
        lon_arr = df["Longitude"].to_numpy(dtype=float)

        # Create cross-source edges
        for i in range(len(df)):
            # Excluded labels do not participate
            if lbl_arr[i] in EXCLUDE_FROM_CROSSSRC:
                continue

            tb = int(t_bucket_arr[i])
            lb = int(lat_bucket_arr[i])
            ob = int(lon_bucket_arr[i])
            src_i = src_arr[i]

            # Direct neighbours plus their wrapped counterparts; keys that do
            # not exist simply yield no candidates.
            lon_neighbours = (
                ob - 1, ob, ob + 1,
                ob - 1 + n_lon_buckets, ob + 1 - n_lon_buckets,
            )

            for dtb in range(tb - t_span, tb + t_span + 1):
                for dlb in (lb - 1, lb, lb + 1):
                    for dob in lon_neighbours:
                        for j in bucket.get((dtb, dlb, dob), ()):
                            # Each unordered pair is examined once (j < i).
                            if j >= i:
                                continue
                            if src_i == src_arr[j]:
                                continue
                            if lbl_arr[j] in EXCLUDE_FROM_CROSSSRC:
                                continue

                            # time filter
                            dts = abs((int(dt_ns_arr[i]) - int(dt_ns_arr[j])) / 1e9)
                            if dts > dt_max_sec:
                                continue

                            # distance filter
                            dist = haversine_km(
                                float(lat_arr[i]), float(lon_arr[i]),
                                float(lat_arr[j]), float(lon_arr[j])
                            )
                            if dist > R_KM:
                                continue

                            union(i, j)

        # Build clusters
        clusters = defaultdict(list)
        for i in range(len(df)):
            clusters[find(i)].append(i)

        cluster_hist = Counter()
        keep_mask = np.zeros(len(df), dtype=bool)
        n_crosssrc_drop = 0

        for root, idxs in clusters.items():
            if len(idxs) == 1:
                keep_mask[idxs[0]] = True
                cluster_hist[1] += 1
                continue

            # If this cluster contains only ONE unique Source, do NOT dedup by cross-source rule
            srcs = df.loc[idxs, "Source"].astype("string")
            if srcs.nunique(dropna=False) <= 1:
                for k in idxs:
                    keep_mask[k] = True
                cluster_hist[len(idxs)] += 1
                continue

            cluster_hist[len(idxs)] += 1

            sub = df.loc[idxs].copy()
            sub["p_dist"] = (sub["Pressure"] - target_depth).abs() if np.isfinite(target_depth) else 0.0

            # Winner: best source rank, then most T/S metadata, then pressure
            # closest to the target depth, then earliest timestamp.
            sub = sub.sort_values(
                ["src_rank", "meta_score", "p_dist", "dt_ns"],
                ascending=[True, False, True, True],
                kind="mergesort"
            )
            winner = int(sub.index[0])
            keep_mask[winner] = True

            dropped_idxs = [k for k in idxs if k != winner]
            n_crosssrc_drop += len(dropped_idxs)

            dropped_lbl = df.loc[dropped_idxs, "SourceLabel"].value_counts()
            for k, v in dropped_lbl.to_dict().items():
                per_label_drop[f"crosssrc_drop::{k}"] += int(v)

        df_out = df.loc[keep_mask].copy()
        df_out.drop(
            columns=["dt_ns", "t_bucket", "lat_bucket", "lon_bucket", "meta_score"],
            inplace=True,
            errors="ignore"
        )

    # Final cleanup
    df_out.drop(columns=["Lat_key", "Lon_key", "Time_key"], inplace=True, errors="ignore")

    # Force output order (sigma_interp after Source)
    df_out = ensure_out_columns(df_out)

    # Apply rounding at the very end (output constraint)
    df_out = finalize_rounding(df_out)

    # Each ratio is relative to the rows that were still present when that
    # step started; max(1, ...) guards against division by zero.
    stats = dict(
        n_raw=int(n_raw),
        n_qc_drop=int(n_qc_drop),
        n_exactdup_drop=int(n_exactdup_drop),
        n_presscollapse_drop=int(n_presscollapse_drop),
        n_crosssrc_drop=int(n_crosssrc_drop),
        n_final=int(len(df_out)),
        qc_drop_ratio=float(n_qc_drop / n_raw) if n_raw else 0.0,
        exactdup_drop_ratio=float(n_exactdup_drop / max(1, (n_raw - n_qc_drop))),
        presscollapse_drop_ratio=float(n_presscollapse_drop / max(1, (n_raw - n_qc_drop - n_exactdup_drop))),
        crosssrc_drop_ratio=float(n_crosssrc_drop / max(1, (n_raw - n_qc_drop - n_exactdup_drop - n_presscollapse_drop))),
        final_ratio=float(len(df_out) / n_raw) if n_raw else 0.0,
        crosssrc_enabled=bool(ENABLE_CROSS_SOURCE_DEDUP),
    )
    return df_out, stats, dict(per_label_drop), cluster_hist


# =========================
# Batch run
# =========================

# Every "<digits>dbar" directory directly under ROOT_DIR is one unit of work.
depth_folders = sorted(
    entry for entry in ROOT_DIR.iterdir()
    if entry.is_dir() and re.match(r"^\d+dbar$", entry.name)
)
print(f"Found {len(depth_folders)} depth folders under {ROOT_DIR}")
print(f"Cross-source dedup enabled: {ENABLE_CROSS_SOURCE_DEDUP}")

file_rows = []                       # one summary row (or error row) per folder
source_rows = defaultdict(Counter)   # drop category -> {depth tag: dropped rows}
cluster_hist_all = Counter()         # cluster size -> number of clusters, all folders

for folder in tqdm(depth_folders, desc="Process folders", unit="folder"):
    target_depth = get_target_depth_from_folder(folder)

    # Locate the input file; a folder without one is logged and skipped.
    try:
        in_csv = pick_input_csv_in_folder(folder)
    except Exception as e:
        file_rows.append({"folder": str(folder), "error": f"input_csv_not_found: {e}"})
        continue

    out_csv = out_train_path(in_csv)

    # Any processing failure is recorded in the file summary instead of
    # aborting the whole batch.
    try:
        df_out, stats, per_label_drop, cl_hist = process_depth_file(in_csv, target_depth)
    except Exception as e:
        file_rows.append(dict(
            folder=str(folder),
            input_csv=str(in_csv),
            output_csv=str(out_csv),
            depth_dbar=target_depth,
            error=str(e),
        ))
        continue

    # If output exists, delete it first
    try:
        if out_csv.exists():
            out_csv.unlink()
    except Exception:
        if os.path.exists(str(out_csv)):
            os.remove(str(out_csv))

    # Write output
    df_out.to_csv(out_csv, index=False, encoding=WRITE_ENCODING)

    has_depth = np.isfinite(target_depth)
    row = dict(
        folder=str(folder),
        depth_dbar=int(target_depth) if has_depth else None,
        input_csv=str(in_csv),
        output_csv=str(out_csv),
    )
    row.update(stats)
    file_rows.append(row)

    depth_tag = f"depth_{int(target_depth)}dbar" if has_depth else "depth_nan"
    for k, v in per_label_drop.items():
        source_rows[k][depth_tag] += int(v)

    cluster_hist_all.update(cl_hist)

# =========================
# Write logs (robust to empty)
# =========================

df_file = pd.DataFrame(file_rows)
df_file.to_csv(LOG_DIR / "dedup_file_summary.csv", index=False, encoding=WRITE_ENCODING)

# --- dedup_source_summary.csv: total dropped rows per drop category ---
src_sum = [
    {"drop_category": category, "total_dropped": int(sum(per_depth.values()))}
    for category, per_depth in source_rows.items()
]

if src_sum:
    df_src = pd.DataFrame(src_sum).sort_values(["drop_category"])
else:
    # Keep the header even when nothing was dropped.
    df_src = pd.DataFrame(columns=["drop_category", "total_dropped"])

df_src.to_csv(LOG_DIR / "dedup_source_summary.csv", index=False, encoding=WRITE_ENCODING)

# --- crosssrc_cluster_size_hist.csv (robust) ---
if cluster_hist_all:
    df_hist = pd.DataFrame(
        [{"cluster_size": int(size), "n_clusters": int(count)}
         for size, count in sorted(cluster_hist_all.items())]
    )
else:
    df_hist = pd.DataFrame(columns=["cluster_size", "n_clusters"])
df_hist.to_csv(LOG_DIR / "crosssrc_cluster_size_hist.csv", index=False, encoding=WRITE_ENCODING)

print("\n[DONE]")
print("Outputs written as *_TRAIN.csv in each depth folder (existing outputs overwritten).")
print(f"Logs written to: {LOG_DIR}")
for log_name in ("dedup_file_summary.csv", "dedup_source_summary.csv", "crosssrc_cluster_size_hist.csv"):
    print(f"  - {LOG_DIR / log_name}")
print("Output columns enforced:", ",".join(OUT_COLS_ORDER))
print(f"Output rounding: Lat/Lon={OUT_LATLON_DECIMALS} decimals; T/S/O2={OUT_TSO2_DECIMALS} decimals")
print("Cross-source dedup EXCLUDED labels:", ",".join(sorted(EXCLUDE_FROM_CROSSSRC)))
print("Cross-source dedup ENABLED:", ENABLE_CROSS_SOURCE_DEDUP)


In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Compute oxygen saturation percent Sat(%) and DROP Sat>=120 for depth>200 dbar
for *TRAIN.csv under:
  /data/wang/Result_Data/alldoxy/{depth}dbar/

Sat definition:
  Sat(%) = 100 * Oxygen / O2_sat_umolkg
O2_sat_umolkg computed by:
  - TEOS-10 (gsw) preferred: SP + pt0 -> gsw.O2sol_SP_pt (μmol/kg)
  - fallback: Weiss(1970) (ml/L) + EOS-80 density -> μmol/kg

BEHAVIOR (this version):
- Always compute/overwrite column "Sat" (percent) when possible.
- For folders with depth_dbar > 200:
    DROP rows where Sat >= 120 (only if Sat is finite).
  (depth <= 200: keep all rows; still compute Sat)
- Keep rows with Sat=NA (due to missing T/S or missing O2sat), never drop by Sat.
- Write per-file logs + aggregated summary logs into:
    /data/wang/Result_Data/alldoxy/_logs/
      sat_qc_drop_log.csv
      sat_qc_drop_log_summary.csv

NOTE:
- This script edits files in-place (atomic replace).
"""

import re
import numpy as np
import pandas as pd
from pathlib import Path

# =========================
# Config
# =========================
# Root folder holding one "{depth}dbar" sub-folder per target depth.
ROOT = Path("/data/wang/Result_Data/alldoxy")
# Per-file and summary logs go here; the folder is created if missing.
LOG_DIR = ROOT / "_logs"
LOG_DIR.mkdir(parents=True, exist_ok=True)

# Process all these depth folders (you can trim if you only want a subset)
# Each entry d is looked up as ROOT / "{d}dbar"; missing folders are logged.
TARGET_DEPTHS = [
    1,10,20,30,40,50,60,70,80,90,100,
    110,120,130,140,150,160,170,180,190,200,210,220,230,240,250,260,
    270,280,290,300,310,320,330,340,350,360,370,380,390,400,410,420,
    430,440,450,460,470,480,490,500,510,520,530,540,550,560,570,580,
    590,600,610,620,630,640,650,660,670,680,690,700,710,720,730,740,
    750,760,770,780,790,800,820,840,860,880,900,920,940,960,980,1000,
    1020,1040,1060,1080,1100,1120,1140,1160,1180,1200,1220,1240,1260,
    1280,1300,1320,1340,1360,1380,1400,1420,1440,1460,1480,1500,1520,
    1540,1560,1580,1600,1620,1640,1660,1680,1700,1720,1740,1760,1780,
    1800,1820,1840,1860,1880,1900,1920,1940,1960,1980,2000,2100,2200,
    2300,2400,2500,2600,2700,2800,2900,3000,3100,3200,3300,3400,3500,
    3600,3700,3800,3900,4000,4100,4200,4300,4400,4500,4600,4700,4800,4900,
    5000,5100,5200,5300,5400,5500
]

# Names of the CSV columns read (T, S, O2) and written (Sat) by this script.
TEMP_COL = "Temperature"  # °C
SAL_COL  = "Salinity"     # PSU/SP
OXY_COL  = "Oxygen"       # μmol/kg
SAT_COL  = "Sat"          # %

# Encoding used for every CSV this script writes (UTF-8 with BOM).
WRITE_ENCODING = "utf-8-sig"
SAT_DROP_THR = 120.0      # drop if Sat >= 120 (depth>200 only)

# =========================
# Fallback: Weiss(1970) + EOS-80 rho
# =========================
def _rho_eos80_kg_m3(S, T):
    """EOS-80 density (kg/m^3) near 0 dbar; S=PSU, T=°C (ITS-90)."""
    S = np.asarray(S, float); T = np.asarray(T, float)
    rho_w = (999.842594 + 6.793952e-2*T - 9.095290e-3*T**2
             + 1.001685e-4*T**3 - 1.120083e-6*T**4 + 6.536332e-9*T**5)
    A = (0.824493 - 4.0899e-3*T + 7.6438e-5*T**2 - 8.2467e-7*T**3 + 5.3875e-9*T**4)
    B = (-5.72466e-3 + 1.0227e-4*T - 1.6546e-6*T**2)
    C = 4.8314e-4
    return rho_w + A*S + B*(S**1.5) + C*(S**2)

def _o2sol_weiss_ml_per_L(T, S):
    """Weiss (1970) O2 solubility (ml/L), T=°C (ITS-90) with IPTS-68 conversion inside."""
    T = np.asarray(T, float); S = np.asarray(S, float)
    Tk = T*1.00024 + 273.15  # ITS-90 -> IPTS-68 -> Kelvin
    A1, A2, A3, A4 = -173.4292, 249.6339, 143.3483, -21.8492
    B1, B2, B3 = -0.033096, 0.014259, -0.0017000
    lnC = (A1 + A2*(100.0/Tk) + A3*np.log(Tk/100.0) + A4*(Tk/100.0)
           + S*(B1 + B2*(Tk/100.0) + B3*(Tk/100.0)**2))
    return np.exp(lnC)

def o2_sat_umolkg_weiss(T, S):
    """
    Saturation oxygen concentration in μmol/kg from the Weiss solubility.

    The ml/L solubility is multiplied by 44.659 and divided by the EOS-80
    density (expressed in kg/L) to obtain a per-kilogram value.
    """
    density_kg_m3 = _rho_eos80_kg_m3(S, T)
    solubility_ml_l = _o2sol_weiss_ml_per_L(T, S)
    litres_per_kg = 1000.0 / density_kg_m3
    return solubility_ml_l * 44.659 * litres_per_kg

# =========================
# TEOS-10 (gsw) path
# =========================
def o2_sat_umolkg_teos10(SP, t, p, lon, lat):
    """
    Saturation oxygen concentration (μmol/kg) via the gsw (TEOS-10) package.

    Inputs are practical salinity SP, in-situ temperature t, pressure p
    (dbar), longitude and latitude. Absolute salinity and the potential
    temperature referenced to 0 dbar are derived first, then passed to
    gsw.O2sol_SP_pt together with SP.

    gsw is imported lazily, so ImportError is raised here when the package
    is not installed.
    """
    import gsw

    sp_arr, t_arr, p_arr, lon_arr, lat_arr = (
        np.asarray(value, float) for value in (SP, t, p, lon, lat)
    )
    abs_salinity = gsw.SA_from_SP(sp_arr, p_arr, lon_arr, lat_arr)
    pot_temp0 = gsw.pt0_from_t(abs_salinity, t_arr, p_arr)
    return gsw.O2sol_SP_pt(sp_arr, pot_temp0)  # μmol/kg

# =========================
# Helpers
# =========================
def atomic_write_csv(df: pd.DataFrame, path: Path, encoding=WRITE_ENCODING):
    """
    Write `df` to `path` via a temporary sibling file.

    The frame is written (without index) to "<path>.tmp" and that file is
    then renamed over `path`, so `path` is only replaced once the write has
    finished. If writing or renaming fails, the temporary file is removed
    and the original exception is re-raised.
    """
    tmp = path.with_suffix(path.suffix + ".tmp")
    try:
        df.to_csv(tmp, index=False, encoding=encoding)
        tmp.replace(path)
    except BaseException:
        # Do not leave a half-written "*.tmp" behind; cleanup is best-effort
        # and must not mask the original error.
        try:
            tmp.unlink()
        except OSError:
            pass
        raise

def parse_depth_from_folder(folder: Path):
    """
    Return the integer depth encoded in a folder named "<digits>dbar",
    or None when the folder name has a different shape.
    """
    hit = re.match(r"^(\d+)dbar$", folder.name)
    if hit is None:
        return None
    return int(hit.group(1))

def find_train_csvs(depth_folder: Path):
    """Return the sorted list of files in `depth_folder` matching "*TRAIN.csv"."""
    matches = list(depth_folder.glob("*TRAIN.csv"))
    matches.sort()
    return matches

def compute_sat_percent(df: pd.DataFrame):
    """
    Compute Sat(%) = 100 * Oxygen / O2_sat_umolkg.
    If Temperature/Salinity missing -> Sat stays NaN.
    TEOS-10 preferred; fallback Weiss.

    A column that is absent from `df` is treated as all-NaN. Rows that lack
    a finite Pressure/Longitude/Latitude, or for which the TEOS-10 path
    fails or yields a non-finite value, use the Weiss fallback.

    Returns:
      sat_percent (float array)
      used_teos (bool array)  - rows whose O2sat came from TEOS-10
      used_weiss (bool array) - rows whose O2sat came from the Weiss fallback
      o2sat (float array)
    """
    n = len(df)

    def _numeric_column(name):
        """Column `name` as a float array; all-NaN when the column is absent."""
        if name not in df.columns:
            return np.full(n, np.nan, dtype=float)
        return pd.to_numeric(df[name], errors="coerce").to_numpy(dtype=float, na_value=np.nan)

    sat_percent = np.full(n, np.nan, dtype=float)
    o2sat = np.full(n, np.nan, dtype=float)
    used_teos = np.zeros(n, dtype=bool)
    used_weiss = np.zeros(n, dtype=bool)

    T = _numeric_column(TEMP_COL)
    S = _numeric_column(SAL_COL)
    O = _numeric_column(OXY_COL)

    m_ts = np.isfinite(T) & np.isfinite(S)
    if not m_ts.any():
        return sat_percent, used_teos, used_weiss, o2sat

    pres = _numeric_column("Pressure")
    lon = _numeric_column("Longitude")
    lat = _numeric_column("Latitude")

    # TEOS-10
    try:
        m_teos = m_ts & np.isfinite(pres) & np.isfinite(lon) & np.isfinite(lat)
        if m_teos.any():
            o2sat[m_teos] = o2_sat_umolkg_teos10(S[m_teos], T[m_teos], pres[m_teos], lon[m_teos], lat[m_teos])
            # Only rows that actually received a finite TEOS-10 value count as
            # TEOS rows; the rest are picked up by the fallback below.
            used_teos = m_teos & np.isfinite(o2sat)
    except Exception:
        # Deliberate best-effort: if gsw is not installed or the TEOS-10
        # computation fails, every row falls through to the Weiss path.
        pass

    # Weiss fallback
    m_weiss = m_ts & ~np.isfinite(o2sat)
    if m_weiss.any():
        o2sat[m_weiss] = o2_sat_umolkg_weiss(T[m_weiss], S[m_weiss])
        used_weiss[m_weiss] = True

    # Sat(%)
    m_sat = np.isfinite(O) & np.isfinite(o2sat) & (o2sat > 0)
    sat_percent[m_sat] = 100.0 * (O[m_sat] / o2sat[m_sat])

    return sat_percent, used_teos, used_weiss, o2sat

def process_one_csv(csv_path: Path, depth_dbar: int):
    """
    Add/overwrite the Sat column of one CSV in place and apply the Sat QC.

    When the temperature, salinity or oxygen column is absent, Sat is set
    to NA and nothing is dropped. Otherwise Sat is computed and, for
    depth_dbar > 200 only, rows with a finite Sat >= SAT_DROP_THR are
    removed. The file is rewritten through atomic_write_csv either way.

    Returns:
      dict with the per-file log fields (row counts, Sat counts, number of
      rows computed via TEOS-10 / Weiss, and a note).
    """
    frame = pd.read_csv(csv_path, low_memory=False)
    rows_in = int(len(frame))

    def _log_row(rows_out, sat_finite, sat_ge120, dropped, teos_used, weiss_used, note):
        # Key order defines the column order of the resulting log CSV.
        return {
            "file": str(csv_path),
            "depth_dbar": depth_dbar,
            "rows_in": rows_in,
            "rows_out": rows_out,
            "sat_finite": sat_finite,
            "sat_ge120": sat_ge120,
            "dropped_sat_ge120": dropped,
            "teos_used": teos_used,
            "weiss_used": weiss_used,
            "note": note,
        }

    # default Sat NA if missing columns
    if any(col not in frame.columns for col in (TEMP_COL, SAL_COL, OXY_COL)):
        frame[SAT_COL] = pd.NA
        rows_out = int(len(frame))
        atomic_write_csv(frame, csv_path)
        return _log_row(
            rows_out, 0, 0, 0, 0, 0,
            f"missing required cols ({TEMP_COL}/{SAL_COL}/{OXY_COL}); Sat=NA; no drop",
        )

    sat, used_teos, used_weiss, _ = compute_sat_percent(frame)
    frame[SAT_COL] = sat  # overwrite/add

    finite = np.isfinite(sat)
    over_threshold = finite & (sat >= SAT_DROP_THR)
    n_finite = int(finite.sum())
    n_over = int(over_threshold.sum())

    is_deep = depth_dbar is not None and depth_dbar > 200
    if is_deep:
        dropped = int(over_threshold.sum())
        frame = frame.loc[~over_threshold].copy()
        note = "ok"
    else:
        dropped = 0
        note = "ok (depth<=200: no drop)"

    rows_out = int(len(frame))
    atomic_write_csv(frame, csv_path)

    return _log_row(
        rows_out, n_finite, n_over, dropped,
        int(used_teos.sum()), int(used_weiss.sum()), note,
    )

def main():
    """
    Run the Sat QC over every depth in TARGET_DEPTHS.

    For each depth folder, every "*TRAIN.csv" is processed in place. A
    per-file log and a two-row summary (all files / depth > 200 only) are
    written to LOG_DIR, and the deep-group totals are printed.
    """

    def _placeholder(depth, note):
        # Log row for a depth that produced no file to process.
        return {
            "file": "",
            "depth_dbar": depth,
            "rows_in": 0,
            "rows_out": 0,
            "sat_finite": 0,
            "sat_ge120": 0,
            "dropped_sat_ge120": 0,
            "teos_used": 0,
            "weiss_used": 0,
            "note": note,
        }

    logs = []

    for d in TARGET_DEPTHS:
        folder = ROOT / f"{d}dbar"
        if not folder.exists():
            logs.append(_placeholder(d, f"folder_not_found:{folder}"))
            continue

        depth_val = parse_depth_from_folder(folder)
        files = find_train_csvs(folder)
        if not files:
            logs.append(_placeholder(depth_val, f"no_train_csv:{folder}"))
            continue

        logs.extend(process_one_csv(fp, depth_val) for fp in files)

    df_log = pd.DataFrame(logs)
    log_path = LOG_DIR / "sat_qc_drop_log.csv"
    df_log.to_csv(log_path, index=False, encoding=WRITE_ENCODING)

    # summary: only rows that refer to an actual file take part
    has_file = df_log["file"].astype(str).str.len() > 0
    df_files = df_log[has_file].copy()
    is_deep = df_files["depth_dbar"].astype(float) > 200
    df_deep = df_files[is_deep].copy()

    def summarize(df_in: pd.DataFrame, tag: str):
        if df_in.empty:
            n_files = rows_in_total = rows_out_total = 0
            dropped_total = sat_finite_total = sat_ge120_total = 0
            drop_ratio = np.nan
            ge120_ratio = np.nan
        else:
            n_files = int(len(df_in))
            rows_in_total = int(df_in["rows_in"].sum())
            rows_out_total = int(df_in["rows_out"].sum())
            dropped_total = int(df_in["dropped_sat_ge120"].sum())
            sat_finite_total = int(df_in["sat_finite"].sum())
            sat_ge120_total = int(df_in["sat_ge120"].sum())
            # Ratios are undefined (NaN) when their denominator is zero.
            drop_ratio = (dropped_total / rows_in_total) if rows_in_total else np.nan
            ge120_ratio = (sat_ge120_total / sat_finite_total) if sat_finite_total else np.nan
        return {
            "group": tag,
            "n_files": n_files,
            "rows_in_total": rows_in_total,
            "rows_out_total": rows_out_total,
            "dropped_total": dropped_total,
            "sat_finite_total": sat_finite_total,
            "sat_ge120_total": sat_ge120_total,
            "drop_ratio_over_rows_in": drop_ratio,
            "ge120_ratio_over_sat_finite": ge120_ratio,
        }

    groups = ((df_files, "ALL_TARGET_FILES"), (df_deep, "DEPTH_GT_200_ONLY"))
    df_sum = pd.DataFrame([summarize(part, tag) for part, tag in groups])
    sum_path = LOG_DIR / "sat_qc_drop_log_summary.csv"
    df_sum.to_csv(sum_path, index=False, encoding=WRITE_ENCODING)

    deep = df_sum.loc[df_sum["group"] == "DEPTH_GT_200_ONLY"].iloc[0].to_dict()

    print("[DONE]")
    print(f"  Per-file log : {log_path}")
    print(f"  Summary log  : {sum_path}")
    print("  ---- DEPTH>200 (drop Sat>=120) ----")
    print(f"  files={deep['n_files']}, rows_in_total={deep['rows_in_total']}, rows_out_total={deep['rows_out_total']}")
    print(f"  dropped_total={deep['dropped_total']} (ratio={deep['drop_ratio_over_rows_in']})")
    print(f"  sat_ge120_total={deep['sat_ge120_total']}, sat_finite_total={deep['sat_finite_total']}, "
          f"ge120_ratio_over_sat_finite={deep['ge120_ratio_over_sat_finite']}")

if __name__ == "__main__":
    main()
