In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Monthly grid aggregation for Oxygen + representativeness error proxy (SIGMA_rep)
Python 3.9 + Jupyter compatible.

Implemented requirements:
- Process TARGET_DEPTHS under: /data/wang/Result_Data/alldoxy/{dep}dbar/*TRAIN.csv
  (also tolerates "{dep}dabr")
- Output ONLY ONE file per depth:
    {dep}dbar/depth{dep}_TRAIN.csv
- Original *TRAIN.csv are NOT deleted; they are renamed to:
    <stem>__ORIG.csv  (collision-safe)
- If output file already exists, rename it aside first:
    depth{dep}_TRAIN__ORIG_<timestamp>.csv

Binning:
- Longitude: wrap to [-180,180), then floor-bin to 0.5° anchored at -180
  output with ONE decimal (string): -180.0, -179.5, ..., 179.5
- Latitude: nearest center mapping to lat_centers.txt
  output EXACT strings from lat_centers.txt

Aggregation per (Date=YYYY-MM-15, Year, Month, Latitude, Longitude, Pressure):
- Oxygen: median(Oxygen)
- Oxygen_MAD: median(|Oxygen - median(Oxygen)|)
- sigma_rep (SIGMA_rep): 1.4826 * Oxygen_MAD, written as a string with trailing zeros trimmed
- sigma_interp: median(sigma_interp)
- n_obs: count
- Source: Source of observation closest to median Oxygen (tie -> first)
- Source_fraction:
    - if n_obs == 1 => empty string ""
    - else JSON string of per-source COUNTS sorted by count desc then name asc

Notes:
- Source assumed always present.
- Oxygen assumed already valid (0,600).
"""

import gc
import re
import json
from pathlib import Path
from typing import Dict, List
from datetime import datetime

import numpy as np
import pandas as pd


# =========================
# User configuration
# =========================
# Root folder that contains one sub-folder per depth (see find_depth_folders).
ROOT_DIR = Path("/data/wang/Result_Data/alldoxy")
# Text file with one latitude center per non-empty line (parsed by read_lat_centers).
LAT_CENTERS_PATH = ROOT_DIR / "lat_centers.txt"

# Depths to process; each value is matched against the numeric prefix of a depth folder name.
TARGET_DEPTHS = [
    1
]

# Lon bins definition
LON_START = -180.0
LON_END_EXCL = 180.0
LON_RES = 0.5
# Left edges of the longitude bins (the end value is exclusive, so 180.0 itself is not an edge).
LON_BINS = np.arange(LON_START, LON_END_EXCL, LON_RES, dtype=np.float64)  # -180 ... 179.5

# Minimal columns required from TRAIN.csv
USECOLS = ["Date", "Latitude", "Longitude", "Oxygen", "Source", "sigma_interp"]
# Encodings attempted in this order by safe_read_csv.
ENCODINGS_TO_TRY = ("utf-8", "utf-8-sig", "latin1")
# Passed straight to pandas.read_csv(low_memory=...).
LOW_MEMORY = False

# If RAM is tight, set chunksize (e.g., 1_000_000). None = read whole file.
CHUNKSIZE = None  # e.g., 1_000_000


# =========================
# Helpers
# =========================
def read_lat_centers(path: Path) -> Dict[str, np.ndarray]:
    """Load the latitude bin centers, keeping each entry's literal text.

    Blank lines are ignored. The result is ordered by ascending numeric value.

    Returns:
        A dict with two parallel arrays:
        ``"centers_float"`` (float64 values) and ``"centers_str"`` (the
        stripped text of each line, dtype object).

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        ValueError: if the file has no non-blank line (or a line is not a number).
    """
    if not path.exists():
        raise FileNotFoundError(f"lat_centers.txt not found: {path}")

    labels: List[str] = []
    numbers: List[float] = []
    with path.open("r", encoding="utf-8") as handle:
        for raw in handle:
            text = raw.strip()
            if text:
                labels.append(text)
                numbers.append(float(text))

    if not labels:
        raise ValueError(f"lat_centers.txt is empty: {path}")

    values = np.array(numbers, dtype=np.float64)
    names = np.array(labels, dtype=object)

    # Re-order both arrays together only when the file is not already ascending.
    already_ascending = bool(np.all(values[1:] >= values[:-1]))
    if not already_ascending:
        order = np.argsort(values)
        values, names = values[order], names[order]

    return {"centers_float": values, "centers_str": names}


def safe_read_csv(path: Path, usecols: List[str]):
    """Call ``pandas.read_csv`` with each encoding in ENCODINGS_TO_TRY until one works.

    Returns whatever ``read_csv`` returns for the module-level CHUNKSIZE
    (a DataFrame when CHUNKSIZE is None, otherwise a chunk reader).

    Raises:
        RuntimeError: if every encoding attempt raised; the message carries
            the last underlying error.
    """
    # NOTE(review): with a non-None CHUNKSIZE the returned reader is presumably
    # lazy, so decode errors may surface later during iteration, outside this
    # retry loop -- confirm before relying on the fallback in chunked mode.
    shared_kwargs = dict(usecols=usecols, low_memory=LOW_MEMORY, chunksize=CHUNKSIZE)
    failure = None
    for encoding in ENCODINGS_TO_TRY:
        try:
            loaded = pd.read_csv(path, encoding=encoding, **shared_kwargs)
        except Exception as exc:
            failure = exc
        else:
            return loaded
    raise RuntimeError(f"Failed to read {path}. Last error: {repr(failure)}")


def wrap_lon_to_180(lon: np.ndarray) -> np.ndarray:
    """Wrap longitudes into the half-open interval [-180, 180)."""
    shifted = lon.astype(np.float64, copy=False) + 180.0
    wrapped = np.mod(shifted, 360.0) - 180.0
    # Guard against a result of exactly +180 so the interval stays half-open.
    return np.where(wrapped == 180.0, -180.0, wrapped)


def lon_to_bin_floor(lon_wrapped: np.ndarray) -> np.ndarray:
    """Snap wrapped longitudes down to the left edge of their LON_RES-wide bin.

    Inputs are clipped to [-180, 180) first, and the bin number is clipped to
    the valid range of LON_BINS, so the result is always one of the bin edges.
    """
    just_below_180 = np.nextafter(180.0, -np.inf)
    clipped = np.clip(lon_wrapped.astype(np.float64, copy=False), -180.0, just_below_180)

    bin_no = np.floor((clipped - LON_START) / LON_RES).astype(np.int64)
    last_bin = len(LON_BINS) - 1
    bin_no = np.clip(bin_no, 0, last_bin)

    edges = LON_START + bin_no * LON_RES
    return edges.astype(np.float64)


def lat_to_nearest_center_index(lat: np.ndarray, centers_float: np.ndarray) -> np.ndarray:
    """Return, for each latitude, the index of the nearest entry in ``centers_float``.

    ``centers_float`` must be ascending. Latitudes outside its range are
    clipped to the first/last center. An exact tie between two neighbouring
    centers resolves to the lower (left) one.
    """
    last = len(centers_float) - 1
    values = np.clip(lat.astype(np.float64, copy=False), centers_float[0], centers_float[-1])

    # First center that is >= the value, and the neighbour just below it.
    upper = np.clip(np.searchsorted(centers_float, values, side="left"), 0, last)
    lower = np.maximum(upper - 1, 0)

    dist_lower = np.abs(values - centers_float[lower])
    dist_upper = np.abs(values - centers_float[upper])

    nearest = np.where(dist_lower <= dist_upper, lower, upper)
    return nearest.astype(np.int64)


def find_depth_folders(root: Path, target_depths: List[int]) -> Dict[int, Path]:
    """Map each requested depth to its folder directly under ``root``.

    Folder names of the form ``{dep}dbar`` are accepted, as is the known typo
    ``{dep}dabr``. When both spellings exist for one depth, the correctly
    spelled ``dbar`` folder wins.

    Args:
        root: directory whose immediate sub-directories are scanned.
        target_depths: depths (integers) to look for.

    Returns:
        ``{depth: folder}`` for every requested depth that has a folder;
        depths without a folder are simply absent.
    """
    # The previous pattern "dba[rr]" was a one-character class and therefore
    # only matched "dbar"; spell out both accepted suffixes explicitly.
    pat = re.compile(r"^(\d+)d(?:ba|ab)r$")
    wanted = set(target_depths)

    out: Dict[int, Path] = {}
    for p in root.iterdir():
        if not p.is_dir():
            continue
        m = pat.match(p.name)
        if not m:
            continue
        dep = int(m.group(1))
        if dep not in wanted:
            continue
        # First hit is kept unless a later candidate is the canonical spelling.
        if dep not in out or p.name.endswith("dbar"):
            out[dep] = p
    return out


def list_train_files(depth_dir: Path) -> List[Path]:
    """Return the regular files in ``depth_dir`` matching ``*TRAIN.csv``, sorted."""
    matches = [entry for entry in depth_dir.glob("*TRAIN.csv") if entry.is_file()]
    matches.sort()
    return matches


def make_date_ym15(year: pd.Series, month: pd.Series) -> pd.Series:
    """Build ``YYYY-MM-15`` date strings (zero-padded) from year and month Series."""
    yyyy = year.astype(np.int32).astype(str).str.zfill(4)
    mm = month.astype(np.int32).astype(str).str.zfill(2)
    return yyyy + "-" + mm + "-15"


def fmt_sigma_rep(x: float) -> str:
    """Format a number with up to 12 decimals, dropping trailing zeros and a bare dot.

    Missing values (NaN/None) become the empty string.
    """
    if pd.isna(x):
        return ""
    fixed = format(float(x), ".12f")
    trimmed = fixed.rstrip("0").rstrip(".")
    return trimmed or "0"


def fmt_lon_one_decimal(x: float) -> str:
    """Format a longitude with exactly one decimal; missing values become ""."""
    return "" if pd.isna(x) else format(float(x), ".1f")


def source_fraction_json_count(src_series: pd.Series) -> str:
    """Serialize per-source observation counts as a compact JSON object.

    Source names are stripped of surrounding whitespace, missing values are
    ignored, and entries are ordered by count (descending) then name
    (ascending). Returns ``"{}"`` when nothing is counted.
    """
    counts = src_series.astype("string").str.strip().value_counts(dropna=True)
    if counts.sum() == 0:
        return "{}"

    ranked = sorted(
        ((str(name), int(n)) for name, n in counts.items()),
        key=lambda pair: (-pair[1], pair[0]),
    )
    return json.dumps(dict(ranked), ensure_ascii=False, separators=(",", ":"))


def normalize_chunk(df: pd.DataFrame, lat_centers: Dict[str, np.ndarray], depth_dbar: int) -> pd.DataFrame:
    """Validate one raw chunk and attach its grid-cell keys.

    Rows are kept only when Date parses and Latitude, Longitude, Oxygen and
    Source are all present (sigma_interp may be missing). Longitude is wrapped
    and floor-binned, latitude is mapped to the nearest configured center.

    Args:
        df: raw rows containing the USECOLS columns.
        lat_centers: output of ``read_lat_centers``.
        depth_dbar: value written to the ``Pressure`` column.

    Returns:
        A DataFrame on a fresh 0..n-1 index with columns Date, Year, Month,
        Latitude (str), Longitude (str), Pressure, Oxygen, sigma_interp, Source.
        Empty (columns only) when no row survives the filter.
    """
    out_cols = ["Date","Year","Month","Latitude","Longitude","Pressure","Oxygen","sigma_interp","Source"]

    dt = pd.to_datetime(df["Date"], errors="coerce")
    year = dt.dt.year
    month = dt.dt.month

    lat = pd.to_numeric(df["Latitude"], errors="coerce")
    lon = pd.to_numeric(df["Longitude"], errors="coerce")
    oxy = pd.to_numeric(df["Oxygen"], errors="coerce")
    sigi = pd.to_numeric(df["sigma_interp"], errors="coerce")
    src = df["Source"]

    keep = year.notna() & month.notna() & lat.notna() & lon.notna() & oxy.notna() & src.notna()
    if not keep.any():
        return pd.DataFrame(columns=out_cols)

    # Everything below is positional (plain arrays / fresh RangeIndex). The
    # previous version mixed Series carrying the filtered source index with
    # Series on a new 0..n-1 index; the DataFrame constructor aligns on index
    # labels, so dropped rows or later chunks misplaced Latitude/Longitude.
    year_v = year[keep].to_numpy().astype(np.int16)
    month_v = month[keep].to_numpy().astype(np.int8)
    lat_v = lat[keep].to_numpy(dtype=np.float64)
    lon_v = lon[keep].to_numpy(dtype=np.float64)
    oxy_v = oxy[keep].to_numpy(dtype=np.float64)
    sig_v = sigi[keep].to_numpy(dtype=np.float64)
    src_v = src[keep].astype("string").reset_index(drop=True)

    lon_bin = lon_to_bin_floor(wrap_lon_to_180(lon_v))

    centers_float = lat_centers["centers_float"]
    centers_str = lat_centers["centers_str"]
    lat_idx = lat_to_nearest_center_index(lat_v, centers_float)
    lat_bin_str = centers_str[lat_idx]

    out = pd.DataFrame(
        {
            "Year": year_v,
            "Month": month_v,
            "Latitude": pd.Series(lat_bin_str, dtype="string"),
            "Longitude": pd.Series([fmt_lon_one_decimal(v) for v in lon_bin], dtype="string"),
            "Pressure": int(depth_dbar),
            "Oxygen": oxy_v,
            "sigma_interp": sig_v,
            "Source": src_v,
        }
    )
    out["Date"] = make_date_ym15(out["Year"], out["Month"])
    return out[out_cols]


def load_and_normalize_all(train_files: List[Path], lat_centers: Dict[str, np.ndarray], depth_dbar: int) -> pd.DataFrame:
    """Read every TRAIN file, normalize it, and stack the results into one frame.

    Each file is read with ``safe_read_csv`` (whole or chunked, depending on
    CHUNKSIZE) and passed through ``normalize_chunk``. Returns an empty frame
    with the normalized column set when there is nothing to concatenate.
    """
    empty_cols = ["Date","Year","Month","Latitude","Longitude","Pressure","Oxygen","sigma_interp","Source"]
    normalized = []

    for csv_path in train_files:
        loaded = safe_read_csv(csv_path, USECOLS)

        if CHUNKSIZE is not None:
            # Chunked mode: `loaded` is an iterator of DataFrames.
            for piece in loaded:  # type: ignore
                normalized.append(normalize_chunk(piece, lat_centers, depth_dbar))
                del piece
        else:
            # Whole-file mode: `loaded` is the DataFrame itself.
            normalized.append(normalize_chunk(loaded, lat_centers, depth_dbar))  # type: ignore
        del loaded
        gc.collect()

        print(f"    [READ] {csv_path.name} -> parts={len(normalized)}")

    if not normalized:
        return pd.DataFrame(columns=empty_cols)

    combined = pd.concat(normalized, ignore_index=True)
    normalized.clear()
    gc.collect()
    return combined


def backup_original_train_files(train_files: List[Path]) -> None:
    """Rename each input file to ``<stem>__ORIG.csv`` in the same folder.

    If that backup name is already taken, the existing backup is first moved
    to ``<stem>__ORIG__OLD_<timestamp>.csv``. Each rename is reported on stdout.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    for source in train_files:
        target = source.with_name(f"{source.stem}__ORIG.csv")
        if target.exists():
            # Keep the earlier backup under a timestamped name instead of overwriting it.
            displaced = source.with_name(f"{source.stem}__ORIG__OLD_{stamp}.csv")
            target.rename(displaced)
            print(f"    [MOVE] {target.name} -> {displaced.name}")
        source.rename(target)
        print(f"    [BACKUP] {source.name} -> {target.name}")


def safe_backup_if_exists(path: Path) -> None:
    """Move an existing file aside to ``<stem>__ORIG_<timestamp><suffix>``; no-op if absent."""
    if path.exists():
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        aside = path.with_name(f"{path.stem}__ORIG_{stamp}{path.suffix}")
        path.rename(aside)
        print(f"  [MOVE] Existing output {path.name} -> {aside.name}")


def aggregate(df: pd.DataFrame) -> pd.DataFrame:
    """Collapse normalized observations into one row per grid cell and month.

    Rows are grouped by (Date, Year, Month, Latitude, Longitude, Pressure).
    For every group the output carries:
      - Oxygen: median of Oxygen
      - Oxygen_MAD: median of |Oxygen - group median|
      - sigma_rep: 1.4826 * Oxygen_MAD, formatted as a string by fmt_sigma_rep
      - sigma_interp: median of sigma_interp
      - n_obs: number of rows in the group
      - Source: Source of the row whose Oxygen is closest to the group median
      - Source_fraction: JSON of per-source counts, or "" when n_obs is not > 1

    Args:
        df: frame with the columns produced by normalize_chunk.

    Returns:
        The aggregated frame with the columns listed in ``out_cols``, sorted by
        Year, Month, then numeric Latitude and Longitude. An empty input yields
        an empty frame with those columns.
    """
    out_cols = [
        "Date","Year","Month","Latitude","Longitude","Pressure",
        "Oxygen","Oxygen_MAD","sigma_rep","sigma_interp","n_obs","Source","Source_fraction"
    ]
    if df.empty:
        return pd.DataFrame(columns=out_cols)

    keys = ["Date","Year","Month","Latitude","Longitude","Pressure"]
    g = df.groupby(keys, sort=False, observed=True)

    # Group median broadcast back onto every row, then each row's absolute deviation from it
    med_per_row = g["Oxygen"].transform("median")
    abs_dev = (df["Oxygen"] - med_per_row).abs()

    # aggregated median oxygen
    oxy_med = g["Oxygen"].median().rename("Oxygen").reset_index()

    # MAD: median of the absolute deviations within each group.
    # abs_dev is a bare Series, so it is grouped by the key columns taken from df;
    # the column names are then set explicitly to the key names.
    mad = abs_dev.groupby([df[k] for k in keys], sort=False).median().rename("Oxygen_MAD").reset_index()
    mad.columns = keys + ["Oxygen_MAD"]

    # sigma_interp median
    sig_med = g["sigma_interp"].median().rename("sigma_interp").reset_index()

    # n_obs
    n_obs = g.size().rename("n_obs").reset_index()

    # Source closest to median: idxmin yields one row label of df per group
    idx = abs_dev.groupby([df[k] for k in keys], sort=False).idxmin()
    src_pick = df.loc[idx, keys + ["Source"]].drop_duplicates(subset=keys, keep="first")

    # Source_fraction JSON (counts)
    src_frac = g["Source"].apply(source_fraction_json_count).rename("Source_fraction").reset_index()

    # Merge all per-group tables onto the median-oxygen table by the group keys
    agg = (
        oxy_med.merge(mad, on=keys, how="left")
               .merge(sig_med, on=keys, how="left")
               .merge(n_obs, on=keys, how="left")
               .merge(src_pick, on=keys, how="left")
               .merge(src_frac, on=keys, how="left")
    )

    # If n_obs == 1 => Source_fraction = ""
    agg["Source_fraction"] = agg["Source_fraction"].where(agg["n_obs"].astype("int64") > 1, "")

    # sigma_rep formatted string
    sigma_rep_num = (1.4826 * agg["Oxygen_MAD"].astype(np.float64)).to_numpy()
    agg["sigma_rep"] = [fmt_sigma_rep(v) for v in sigma_rep_num]

    # Sort (lat/lon are strings, so temporary numeric copies are used as sort keys)
    agg["_Lat_sort"] = pd.to_numeric(agg["Latitude"], errors="coerce")
    agg["_Lon_sort"] = pd.to_numeric(agg["Longitude"], errors="coerce")
    agg = (
        agg[out_cols + ["_Lat_sort","_Lon_sort"]]
        .sort_values(["Year","Month","_Lat_sort","_Lon_sort"], kind="mergesort")
        .drop(columns=["_Lat_sort","_Lon_sort"])
    )

    gc.collect()
    return agg


# =========================
# Main
# =========================
def run():
    """Aggregate every configured depth folder into ``depth{dep}_TRAIN.csv``.

    For each depth in TARGET_DEPTHS: read and normalize all ``*TRAIN.csv``
    files, aggregate them, then rename the inputs to ``*__ORIG.csv`` and
    publish the aggregated file. Depths without a folder or without input
    files are skipped with a message.

    The aggregated table is written to a temporary file in the same folder
    BEFORE any input is renamed, so a failed write (disk full, permissions)
    leaves the original ``*TRAIN.csv`` files untouched.

    Raises:
        RuntimeError: if none of the target depth folders exist under ROOT_DIR.
    """
    print("[INFO] Loading lat centers...")
    lat_centers = read_lat_centers(LAT_CENTERS_PATH)
    print(f"[INFO] lat_centers: n={lat_centers['centers_float'].size} | "
          f"range=({lat_centers['centers_float'].min():.4f}, {lat_centers['centers_float'].max():.4f})")
    print(f"[INFO] lon_bins: n={LON_BINS.size} | range=({LON_BINS.min():.1f}, {LON_BINS.max():.1f}) step={LON_RES}")

    depth_dirs = find_depth_folders(ROOT_DIR, TARGET_DEPTHS)
    if not depth_dirs:
        raise RuntimeError(f"No target depth folders found under {ROOT_DIR} for given TARGET_DEPTHS (n={len(TARGET_DEPTHS)})")

    for dep in TARGET_DEPTHS:
        ddir = depth_dirs.get(dep)
        if ddir is None:
            print(f"[SKIP] Missing folder for depth {dep} (searched dbar/dabr).")
            continue

        train_files = list_train_files(ddir)
        if not train_files:
            print(f"[SKIP] No *TRAIN.csv in {ddir}.")
            continue

        print(f"\n[DEPTH {dep} dbar] folder={ddir} | TRAIN files={len(train_files)}")

        df_all = load_and_normalize_all(train_files, lat_centers, dep)
        print(f"  [INFO] Rows after minimal validity filter: {len(df_all):,}")

        agg_df = aggregate(df_all)
        print(f"  [INFO] Aggregated cells (with obs): {len(agg_df):,}")

        out_path = ddir / f"depth{dep}_TRAIN.csv"
        # Hidden temp name ending in ".tmp" so it can never match "*TRAIN.csv".
        tmp_path = out_path.with_name(f".{out_path.name}.tmp")
        try:
            # 1) write the result first; nothing on disk has been renamed yet
            agg_df.to_csv(tmp_path, index=False)

            # 2) backup originals
            backup_original_train_files(train_files)

            # 3) move any remaining file at the output path aside, then publish
            safe_backup_if_exists(out_path)
            tmp_path.replace(out_path)
        finally:
            # Only still present when one of the steps above failed.
            if tmp_path.exists():
                tmp_path.unlink()
        print(f"  [OK] Wrote aggregated file: {out_path}")

        del df_all, agg_df
        gc.collect()

    print("\n[DONE]")


# ---- Execute in Jupyter ----
# NOTE: this call runs the whole pipeline as soon as the cell (or module) is executed;
# run() renames the input *TRAIN.csv files on disk via backup_original_train_files().
run()


In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Add per-record observation uncertainty SIGMA_obs based on Source
Python 3.9 + Jupyter compatible.

Targets:
  /data/wang/Result_Data/alldoxy/{1dbar,10dbar}/*TRAIN.csv
  (also tolerates folder typo like "{depth}dabr")

Rule (Source -> SIGMA_obs):
  Argo            -> 1.5
  OSDCTD          -> 1.5
  CCHDO_Bottle    -> 1.0
  CCHDO_CTD       -> 1.5
  GLODAPV2 2022   -> 1.0
  OceanSITES      -> 2.0
  Geotraces IDP   -> 1.0

Behavior (NO backup):
- For each matched *TRAIN.csv:
    1) Read original (chunked)
    2) Add/overwrite column "sigma_obs" (SIGMA_obs)
    3) Write to a temp file in the SAME directory
    4) Atomically replace original file via os.replace(temp, original)

Notes:
- Preserves ALL original columns; only adds "sigma_obs" (SIGMA_obs).
- "sigma_obs" is inserted immediately after the "Source" column; a file without a "Source" column raises ValueError.
- Unknown Source -> sigma_obs = NaN; prints the most frequent unknown sources.
- Output encoding is UTF-8 (pandas default).
"""

import os
import gc
import re
from pathlib import Path
from typing import Dict, List, Iterable
from datetime import datetime

import numpy as np
import pandas as pd


# =========================
# User configuration
# =========================
# Root folder that contains one sub-folder per depth (see find_depth_folders).
ROOT_DIR = Path("/data/wang/Result_Data/alldoxy")
TARGET_DEPTHS = [		1
]   # change if needed

# Rows per chunk handed to pandas.read_csv by safe_read_csv_iter.
CHUNKSIZE = 1_000_000     # set None to disable chunking
# Encodings attempted in this order when reading a CSV.
ENCODINGS_TO_TRY = ("utf-8", "utf-8-sig", "latin1")

# Source -> SIGMA_obs mapping
# Keys are written in their human-readable form here; they are normalized with
# _norm_source() when SIGMA_OBS_MAP is built, so lookups are case/spacing-insensitive.
SIGMA_OBS_MAP_RAW: Dict[str, float] = {
    "Argo": 1.5,
    "OSDCTD": 1.5,
    "CCHDO_Bottle": 1.0,
    "CCHDO_CTD": 1.5,
    "GLODAPV2 2022": 1.0,
    "OceanSITES": 2.0,
    "Geotraces IDP": 1.0,
}


# =========================
# Helpers
# =========================
def _norm_source(s: str) -> str:
    """Normalize Source for mapping (strip, collapse spaces, case-insensitive)."""
    if s is None:
        return ""
    x = str(s).strip()
    x = re.sub(r"\s+", " ", x)
    return x.upper()


# Lookup table keyed by the normalized source name (trimmed, whitespace-collapsed,
# upper-cased), so it matches _norm_source() applied to the "Source" column.
SIGMA_OBS_MAP: Dict[str, float] = {_norm_source(k): float(v) for k, v in SIGMA_OBS_MAP_RAW.items()}


def find_depth_folders(root: Path, target_depths: List[int]) -> Dict[int, Path]:
    """Accept both '{dep}dbar' and typo '{dep}dabr'. Prefer 'dbar' if both exist.

    Args:
        root: directory whose immediate sub-directories are scanned.
        target_depths: depths (integers) to look for.

    Returns:
        ``{depth: folder}`` for every requested depth that has a folder;
        depths without a folder are simply absent.
    """
    # The previous pattern "dba[rr]" was a one-character class and therefore
    # only matched "dbar"; spell out both accepted suffixes explicitly.
    pat = re.compile(r"^(\d+)d(?:ba|ab)r$")
    wanted = set(target_depths)

    out: Dict[int, Path] = {}
    for p in root.iterdir():
        if not p.is_dir():
            continue
        m = pat.match(p.name)
        if not m:
            continue
        dep = int(m.group(1))
        if dep not in wanted:
            continue
        # First hit is kept unless a later candidate is the canonical spelling.
        if dep not in out or p.name.endswith("dbar"):
            out[dep] = p
    return out


def list_train_files(depth_dir: Path) -> List[Path]:
    """Return the regular files in ``depth_dir`` whose names match ``*TRAIN.csv``, sorted."""
    candidates = depth_dir.glob("*TRAIN.csv")
    return sorted(filter(lambda entry: entry.is_file(), candidates))


# Sentinel for "use the module-level CHUNKSIZE": None cannot be used because
# None is itself a meaningful chunksize (read the whole file at once).
_USE_MODULE_CHUNKSIZE = object()


def safe_read_csv_iter(path: Path, encodings=None, chunksize=_USE_MODULE_CHUNKSIZE) -> Iterable[pd.DataFrame]:
    """Yield df chunks; reads ALL columns (preserve original attrs).

    The encoding is chosen BEFORE any chunk is yielded, by decoding the whole
    file once with each candidate. The previous implementation retried the
    next encoding from inside the generator, so a decode error in a late chunk
    restarted the file after earlier chunks had already been handed to (and
    written by) the caller, duplicating rows in the output.

    Args:
        path: CSV file to read.
        encodings: candidate encodings in priority order; defaults to the
            module-level ENCODINGS_TO_TRY.
        chunksize: rows per yielded DataFrame, or None to yield the whole file
            as a single DataFrame; defaults to the module-level CHUNKSIZE.

    Raises:
        RuntimeError: if the file cannot be opened or none of the encodings
            decodes it. Parser errors unrelated to encoding propagate unchanged.
    """
    if encodings is None:
        encodings = ENCODINGS_TO_TRY
    if chunksize is _USE_MODULE_CHUNKSIZE:
        chunksize = CHUNKSIZE

    chosen = None
    last_err = None
    for enc in encodings:
        try:
            # Streaming decode: validates the encoding without holding the file in memory.
            with open(path, "r", encoding=enc, newline="") as fh:
                while fh.read(1 << 20):
                    pass
        except (UnicodeError, LookupError) as e:
            last_err = e
            continue
        except OSError as e:
            # Missing/unreadable file: another encoding cannot help.
            raise RuntimeError(f"Failed to read {path}. Last error: {repr(e)}") from e
        chosen = enc
        break

    if chosen is None:
        raise RuntimeError(f"Failed to read {path}. Last error: {repr(last_err)}")

    if chunksize is None:
        yield pd.read_csv(path, encoding=chosen, low_memory=False)
    else:
        for chunk in pd.read_csv(path, encoding=chosen, low_memory=False, chunksize=chunksize):
            yield chunk


def insert_after_source(cols: List[str], new_col: str) -> List[str]:
    """Return the column order with ``new_col`` placed right after 'Source'.

    If ``new_col`` is already listed, ``cols`` is returned unchanged. If there
    is no 'Source' entry, ``new_col`` goes at the end.
    """
    if new_col in cols:
        return cols
    try:
        split_at = cols.index("Source") + 1
    except ValueError:
        # No 'Source' column: fall back to appending.
        split_at = len(cols)
    return cols[:split_at] + [new_col] + cols[split_at:]


def add_sigma_obs_inplace(df: pd.DataFrame) -> None:
    """Add or overwrite the ``sigma_obs`` column of ``df`` in place.

    Each row's Source is normalized with ``_norm_source`` and looked up in
    SIGMA_OBS_MAP; sources without an entry get NaN.

    Raises:
        ValueError: if ``df`` has no 'Source' column.
    """
    if "Source" not in df.columns:
        raise ValueError("Missing required column 'Source'")

    normalized = df["Source"].astype("string").map(_norm_source)
    looked_up = normalized.map(SIGMA_OBS_MAP)
    df["sigma_obs"] = looked_up.astype("float64")


def atomic_replace_csv(original: Path, write_chunks_fn) -> None:
    """Replace ``original`` with freshly written content via a sibling temp file.

    ``write_chunks_fn(temp_path)`` must write the complete CSV to ``temp_path``.
    The temp file lives in the same directory as ``original`` and is moved
    over it with ``os.replace``. If anything fails before the swap, the temp
    file is removed (best effort) and the exception propagates.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    scratch = original.with_name(f".{original.name}.tmp_{os.getpid()}_{stamp}")

    def _flush_best_effort() -> None:
        # Durability is a nice-to-have here: any failure is deliberately ignored.
        try:
            with open(scratch, "rb") as handle:
                os.fsync(handle.fileno())
        except Exception:
            pass

    def _discard_leftover() -> None:
        # The scratch file only still exists when the swap did not happen.
        if not scratch.exists():
            return
        try:
            scratch.unlink()
        except Exception:
            pass

    try:
        write_chunks_fn(scratch)
        _flush_best_effort()
        os.replace(scratch, original)  # atomic on same filesystem
    finally:
        _discard_leftover()


# =========================
# Main processing
# =========================
def process_one_file(fp: Path) -> None:
    """Rewrite one CSV in place with a ``sigma_obs`` column next to 'Source'.

    The file is streamed chunk by chunk into a temp file, which then replaces
    the original through ``atomic_replace_csv``. Afterwards the row count is
    printed, followed by up to 15 of the most frequent unmapped sources.
    """
    print(f"  -> Processing: {fp}")

    unknown_sources: Dict[str, int] = {}
    total_rows = 0

    def _write_to_temp(tmp_path: Path) -> None:
        nonlocal total_rows, unknown_sources

        for position, chunk in enumerate(safe_read_csv_iter(fp)):
            is_first = position == 0
            total_rows += int(len(chunk))

            add_sigma_obs_inplace(chunk)

            # Tally normalized names of the sources that did not map to a value.
            normalized = chunk["Source"].astype("string").map(lambda x: _norm_source(x))
            unmapped = normalized[chunk["sigma_obs"].isna()].value_counts(dropna=True)
            for label, count in unmapped.items():
                name = str(label)
                unknown_sources[name] = unknown_sources.get(name, 0) + int(count)

            # Drop sigma_obs from the order, then re-insert it right after 'Source'.
            others = [col for col in chunk.columns if col != "sigma_obs"]
            ordered = insert_after_source(others, "sigma_obs")
            chunk = chunk[ordered]

            # The first chunk creates the file with a header; later chunks append rows only.
            chunk.to_csv(tmp_path, index=False, mode="w" if is_first else "a", header=is_first)

            del chunk
            gc.collect()

    atomic_replace_csv(fp, _write_to_temp)

    print(f"     [OK] Replaced: {fp.name} | rows={total_rows:,}")
    if unknown_sources:
        ranked = sorted(unknown_sources.items(), key=lambda item: item[1], reverse=True)
        msg = ", ".join(f"{name}:{count}" for name, count in ranked[:15])
        print(f"     [WARN] Unknown Source -> sigma_obs=NaN (top): {msg}")


def run():
    """Add ``sigma_obs`` to every ``*TRAIN.csv`` of each depth in TARGET_DEPTHS.

    Depths without a folder, and folders without matching files, are skipped
    with a warning.

    Raises:
        RuntimeError: if none of the target depth folders exist under ROOT_DIR.
    """
    depth_dirs = find_depth_folders(ROOT_DIR, TARGET_DEPTHS)
    if not depth_dirs:
        raise RuntimeError(f"No target depth folders found under {ROOT_DIR} for {TARGET_DEPTHS}")

    for dep in TARGET_DEPTHS:
        ddir = depth_dirs.get(dep)
        if ddir is None:
            print(f"[WARN] Missing folder for depth {dep} (searched dbar/dabr). Skip.")
            continue

        train_files = list_train_files(ddir)
        if train_files:
            print(f"\n[DEPTH {dep}] folder={ddir} | TRAIN files={len(train_files)}")
            for csv_path in train_files:
                process_one_file(csv_path)
        else:
            print(f"[WARN] No *TRAIN.csv in {ddir}. Skip.")

    print("\n[DONE]")


# ---- Execute in Jupyter ----
# NOTE: this call runs immediately when the cell (or module) is executed and
# replaces each matched *TRAIN.csv in place (see process_one_file / atomic_replace_csv).
run()


In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Map SIGMA_rep > 3 for Month==1 from:
  /data/wang/Result_Data/alldoxy/1dbar/depth1_TRAIN.csv

Default: grid-binned mean for values > 3 only.
Optional: scatter (subsample) for quick look.
"""

from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import cartopy.crs as ccrs
import cartopy.feature as cfeature


# =========================
# User config
# =========================
# Aggregated CSV to plot, and the PNG written next to it by main().
CSV_PATH = Path("/data/wang/Result_Data/alldoxy/1dbar/depth1_TRAIN.csv")
OUT_PNG  = CSV_PATH.with_name("depth1_TRAIN_Month01_SIGMA_rep_gt3_map.png")

# Column names read from CSV_PATH.
COL_LAT  = "Latitude"
COL_LON  = "Longitude"
COL_MON  = "Month"
COL_VAL  = "sigma_rep"

THRESH = 3.0  # only plot values > THRESH

# ---- Gridding (recommended) ----
# Cell width (longitude) and height (latitude), in degrees, of the plotting grid.
DX_DEG = 1.0
DY_DEG = 1.0

# ---- Chunked read (for big files) ----
# Rows per chunk used by grid_mean_map_gt.
CHUNKSIZE = 2_000_000

# ---- Plot control ----
# True: plot individual points (scatter_map_gt); False: plot the binned mean (grid_mean_map_gt).
USE_SCATTER = False
# Upper bound on plotted points; larger selections are randomly subsampled.
SCATTER_MAX_N = 300_000


def wrap_lon(lon):
    """Wrap longitudes (scalar or array-like) onto the -180..180 convention."""
    shifted = np.asarray(lon, dtype="float64") + 180.0
    return np.mod(shifted, 360.0) - 180.0


def grid_mean_map_gt(csv_path: Path, thresh: float):
    """Bin Month==1 values above ``thresh`` onto a DX_DEG x DY_DEG grid and average them.

    The CSV is streamed in chunks of CHUNKSIZE rows; per-cell sums and counts
    are accumulated across chunks.

    Returns:
        ``(lon_edges, lat_edges, mean_grid, cnt_grid)`` where the grids have
        shape (n_lat_cells, n_lon_cells) and ``mean_grid`` is NaN in cells
        that received no value.
    """
    lon_edges = np.arange(-180.0, 180.0 + DX_DEG, DX_DEG)
    lat_edges = np.arange(-90.0,   90.0 + DY_DEG, DY_DEG)
    edges = [lat_edges, lon_edges]

    grid_shape = (lat_edges.size - 1, lon_edges.size - 1)
    sum_grid = np.zeros(grid_shape, dtype="float64")
    cnt_grid = np.zeros(grid_shape, dtype="int64")

    value_cols = [COL_LAT, COL_LON, COL_VAL]
    reader = pd.read_csv(
        csv_path,
        usecols=[COL_MON, COL_LAT, COL_LON, COL_VAL],
        chunksize=CHUNKSIZE,
        low_memory=False,
    )

    for chunk in reader:
        # Month==1
        is_january = pd.to_numeric(chunk[COL_MON], errors="coerce") == 1
        if not is_january.any():
            continue

        # Coerce the three value columns to numbers and drop rows missing any of them.
        sub = chunk.loc[is_january, value_cols].copy()
        for col in value_cols:
            sub[col] = pd.to_numeric(sub[col], errors="coerce")
        sub = sub.dropna(subset=value_cols)
        if sub.empty:
            continue

        lat = sub[COL_LAT].to_numpy(dtype="float64", copy=False)
        lon = wrap_lon(sub[COL_LON].to_numpy(dtype="float64", copy=False))
        val = sub[COL_VAL].to_numpy(dtype="float64", copy=False)

        # keep only in-range positions with a finite val > thresh
        in_range = (lat >= -90) & (lat <= 90) & (lon >= -180) & (lon < 180)
        keep = in_range & np.isfinite(val) & (val > thresh)
        if not keep.any():
            continue

        lat, lon, val = lat[keep], lon[keep], val[keep]

        weighted, _, _ = np.histogram2d(lat, lon, bins=edges, weights=val)
        counted, _, _ = np.histogram2d(lat, lon, bins=edges)

        sum_grid += weighted
        cnt_grid += counted.astype("int64")

    mean_grid = np.full_like(sum_grid, np.nan, dtype="float64")
    filled = cnt_grid > 0
    mean_grid[filled] = sum_grid[filled] / cnt_grid[filled]

    return lon_edges, lat_edges, mean_grid, cnt_grid


def scatter_map_gt(csv_path: Path, thresh: float):
    """Return (lon, lat, val) arrays of Month==1 points whose value exceeds ``thresh``.

    The whole CSV is read at once. If more than SCATTER_MAX_N points remain,
    a fixed-seed random subsample of that size is returned.

    Raises:
        RuntimeError: if no Month==1 row has numeric lat, lon and value.
    """
    value_cols = [COL_LAT, COL_LON, COL_VAL]
    df = pd.read_csv(csv_path, usecols=[COL_MON, COL_LAT, COL_LON, COL_VAL], low_memory=False)

    month = pd.to_numeric(df[COL_MON], errors="coerce")
    df = df.loc[month == 1, value_cols].copy()

    for col in value_cols:
        df[col] = pd.to_numeric(df[col], errors="coerce")
    df = df.dropna(subset=value_cols)
    if df.empty:
        raise RuntimeError("No valid rows for Month==1 after cleaning.")

    lon = wrap_lon(df[COL_LON].to_numpy(dtype="float64", copy=False))
    lat = df[COL_LAT].to_numpy(dtype="float64", copy=False)
    val = df[COL_VAL].to_numpy(dtype="float64", copy=False)

    in_range = (lat >= -90) & (lat <= 90) & (lon >= -180) & (lon < 180)
    keep = in_range & np.isfinite(val) & (val > thresh)
    lon, lat, val = lon[keep], lat[keep], val[keep]

    if lon.size > SCATTER_MAX_N:
        # Seeded generator so repeated runs draw the same subsample.
        rng = np.random.default_rng(42)
        picked = rng.choice(lon.size, size=SCATTER_MAX_N, replace=False)
        lon, lat, val = lon[picked], lat[picked], val[picked]

    return lon, lat, val


def main():
    """Draw the Month==1, value > THRESH map of COL_VAL and save it to OUT_PNG.

    Depending on USE_SCATTER the data layer is either a scatter of individual
    points or a pcolormesh of the gridded mean. Land and coastlines are drawn
    underneath (lower zorder) on a global PlateCarree axis.
    """
    fig = plt.figure(figsize=(12, 5.5))
    ax = plt.axes(projection=ccrs.PlateCarree())
    ax.set_global()

    # Base map layers; the data layer below uses zorder=3 so it is drawn on top.
    ax.add_feature(cfeature.LAND, zorder=1, facecolor="0.9", edgecolor="0.5", linewidth=0.3)
    ax.add_feature(cfeature.COASTLINE, zorder=2, linewidth=0.4)

    if USE_SCATTER:
        # Point mode: one marker per (possibly subsampled) observation.
        lon, lat, val = scatter_map_gt(CSV_PATH, THRESH)
        sc = ax.scatter(
            lon, lat, c=val, s=3, transform=ccrs.PlateCarree(),
            linewidths=0, alpha=0.6, zorder=3
        )
        cb = plt.colorbar(sc, ax=ax, orientation="vertical", pad=0.02, shrink=0.92)
        cb.set_label(f"{COL_VAL} (> {THRESH:g})")
        ax.set_title(f"{COL_VAL} (Month=1, >{THRESH:g}) | scatter (n={lon.size})")

    else:
        # Grid mode: mean of the selected values per cell; cnt_grid is returned but not plotted.
        lon_edges, lat_edges, mean_grid, cnt_grid = grid_mean_map_gt(CSV_PATH, THRESH)

        pm = ax.pcolormesh(
            lon_edges, lat_edges, mean_grid,
            transform=ccrs.PlateCarree(),
            shading="auto", zorder=3
        )
        cb = plt.colorbar(pm, ax=ax, orientation="vertical", pad=0.02, shrink=0.92)
        cb.set_label(f"{COL_VAL} (> {THRESH:g}) | binned mean")
        ax.set_title(f"{COL_VAL} (Month=1, >{THRESH:g}) | gridded mean ({DX_DEG:.2f}°×{DY_DEG:.2f}°)")

    # Labelled graticule on the left and bottom edges only.
    gl = ax.gridlines(draw_labels=True, linewidth=0.3, alpha=0.5)
    gl.top_labels = False
    gl.right_labels = False

    plt.tight_layout()
    fig.savefig(OUT_PNG, dpi=300)
    print(f"[OK] Saved figure: {OUT_PNG}")


# Render and save the figure only when this code runs as the main module.
if __name__ == "__main__":
    main()


In [None]:
import os
import re
from pathlib import Path
from typing import Optional

import numpy as np
import pandas as pd
import xarray as xr


# -------------------------
# User configuration
# -------------------------
# Root folder of the per-depth CSV directories, and the biome NetCDF file.
ROOT_DIR = Path("/data/wang/Result_Data/alldoxy")
NC_PATH  = Path("/data/wang/Merage_Biomes_0p5deg.nc")

# Depths to process
depths = [
        1,10,20,30,40,50,60,70,80,90,100,
        110,120,130,140,150,160,170,180,190,200,210,220,230,240,250,260,
        270,280,290,300,310,320,330,340,350,360,370,380,390,400,410,420,
        430,440,450,460,470,480,490,500,510,520,530,540,550,560,570,580,
        590,600,610,620,630,640,650,660,670,680,690,700,710,720,730,740,
        750,760,770,780,790,800,820,840,860,880,900,920,940,960,980,1000,
        1020,1040,1060,1080,1100,1120,1140,1160,1180,1200,1220,1240,1260,
        1280,1300,1320,1340,1360,1380,1400,1420,1440,1460,1480,1500,1520,
        1540,1560,1580,1600,1620,1640,1660,1680,1700,1720,1740,1760,1780,
        1800,1820,1840,1860,1880,1900,1920,1940,1960,1980,2000,2100,2200,
        2300,2400,2500,2600,2700,2800,2900,3000,3100,3200,3300,3400,3500,
        3600,3700,3800,3900,4000,4100,4200,4300,4400,4500,4600,4700,4800,4900,
        5000,5100,5200,5300,5400,5500
]

# Input CSV filename pattern inside each "{dep}dbar" directory
PATTERN = "*TRAIN.csv"   # e.g., "*TRAIN.csv" or "*NoAgg.csv", etc.

# Column names expected in the input CSV files.
LAT_COL = "Latitude"
LON_COL = "Longitude"
MONTH_COL = "Month"  # required for SOM_Zone (1..12)

# SOM labels
# Directory of per-depth label files; SOM_FILE_FMT is filled with dep=<depth>
# by find_som_file_for_depth().
SOM_LABEL_DIR = Path("/data/wang/Result_Data/Province_SOM/labels")
SOM_FILE_FMT  = "SOM_province_labels_depth{dep}m.nc"
# When True and no exact-depth file exists, find_som_file_for_depth() falls back
# to the label file whose depth is numerically closest.
ALLOW_NEAREST_DEPTH_FILE = True
SOM_INVALID_VALUE = -1

# Output controls
ATOMIC_OVERWRITE = True         # True: atomic overwrite original file; False: write to a new file
OUT_SUFFIX = "_withSOM"         # output filename suffix


# -------------------------
# Helper functions
# -------------------------
def normalize_lon(lon: np.ndarray, lon_nc: np.ndarray) -> np.ndarray:
    """Convert input longitudes to the convention used by the NC grid.

    The grid convention is inferred from the range of ``lon_nc``: a grid that
    starts at (about) 0 and exceeds 180 is treated as 0..360, a grid with
    negative values not exceeding 180 as -180..180. If neither applies the
    longitudes are returned as float64 without wrapping.
    """
    values = lon.astype(np.float64)
    grid = np.asarray(lon_nc, dtype=np.float64)
    grid_min = float(np.nanmin(grid))
    grid_max = float(np.nanmax(grid))

    if grid_min >= -1e-6 and grid_max > 180.0:
        # 0..360 grid
        wrapped = np.mod(values, 360.0)
        negative = wrapped < 0
        wrapped[negative] += 360.0
        return wrapped

    if grid_min < 0.0 and grid_max <= 180.0 + 1e-6:
        # -180..180 grid
        return np.mod(values + 180.0, 360.0) - 180.0

    return values


def build_grid_mapper(lon_nc: np.ndarray, lat_nc: np.ndarray):
    """Create a function that maps (lon, lat) points to nearest grid indices.

    The grid is described by its first coordinate and the median spacing of
    ``lon_nc`` / ``lat_nc``.

    Returns:
        ``(to_index, meta)``. ``to_index(lon_in, lat_in)`` returns
        ``(lon_idx, lat_idx, valid)``, where ``valid`` marks points whose
        rounded index is inside the grid and whose distance to that grid
        center is within half a step (plus 1e-6). ``meta`` holds the grid
        origin, steps, sizes and tolerances.

    Raises:
        ValueError: if either coordinate array is not strictly increasing.
    """
    lon_nc = np.asarray(lon_nc, dtype=np.float64)
    lat_nc = np.asarray(lat_nc, dtype=np.float64)

    lon_gaps = np.diff(lon_nc)
    lat_gaps = np.diff(lat_nc)
    strictly_increasing = np.all(lon_gaps > 0) and np.all(lat_gaps > 0)
    if not strictly_increasing:
        raise ValueError("NC lon/lat must be strictly increasing. Please check the file.")

    lon_step = float(np.nanmedian(lon_gaps))
    lat_step = float(np.nanmedian(lat_gaps))
    lon0 = float(lon_nc[0])
    lat0 = float(lat_nc[0])
    nlon = lon_nc.size
    nlat = lat_nc.size

    lon_tol = 0.5 * lon_step + 1e-6
    lat_tol = 0.5 * lat_step + 1e-6

    def to_index(lon_in: np.ndarray, lat_in: np.ndarray):
        lon_pts = np.asarray(lon_in, dtype=np.float64)
        lat_pts = np.asarray(lat_in, dtype=np.float64)

        # Nearest grid index along each axis.
        lon_idx = np.rint((lon_pts - lon0) / lon_step).astype(np.int64)
        lat_idx = np.rint((lat_pts - lat0) / lat_step).astype(np.int64)

        inside_lon = (lon_idx >= 0) & (lon_idx < nlon)
        inside_lat = (lat_idx >= 0) & (lat_idx < nlat)

        # Distance from each point to the center of the cell it was assigned to.
        lon_off = np.abs(lon_pts - (lon0 + lon_idx.astype(np.float64) * lon_step))
        lat_off = np.abs(lat_pts - (lat0 + lat_idx.astype(np.float64) * lat_step))
        near_center = (lon_off <= lon_tol) & (lat_off <= lat_tol)

        valid = (inside_lon & inside_lat) & near_center
        return lon_idx, lat_idx, valid

    meta = {
        "lon0": lon0, "lat0": lat0,
        "lon_step": lon_step, "lat_step": lat_step,
        "nlon": nlon, "nlat": nlat,
        "lon_tol": lon_tol, "lat_tol": lat_tol,
    }
    return to_index, meta


def extract_by_indices_2d(var_lon_lat: np.ndarray,
                          lon_idx: np.ndarray,
                          lat_idx: np.ndarray,
                          valid: np.ndarray) -> np.ndarray:
    """Look up values of a (nlon, nlat) array at index pairs.

    Returns a float64 array shaped like ``lon_idx``; entries where ``valid``
    is False are NaN.
    """
    _, nlat = var_lon_lat.shape
    result = np.full(lon_idx.shape, np.nan, dtype=np.float64)
    if not valid.any():
        return result

    # Row-major offset of (lon, lat) in the flattened array.
    offsets = lon_idx[valid] * nlat + lat_idx[valid]
    result[valid] = var_lon_lat.reshape(-1)[offsets]
    return result


def extract_by_indices_3d_month_lat_lon(var_mll: np.ndarray,
                                       mon_idx: np.ndarray,
                                       lat_idx: np.ndarray,
                                       lon_idx: np.ndarray,
                                       valid_xy: np.ndarray,
                                       valid_mon: np.ndarray) -> np.ndarray:
    """
    Look up values of a 3D (month, nlat, nlon) array at per-point indices.

    A point is read only when both ``valid_xy`` and ``valid_mon`` are True for
    it; every other point is returned as NaN. The result is float64 with the
    same shape as ``mon_idx``.
    """
    result = np.full(mon_idx.shape, np.nan, dtype=np.float64)
    usable = valid_xy & valid_mon

    if usable.any():
        n_lat, n_lon = var_mll.shape[1], var_mll.shape[2]

        month = mon_idx[usable].astype(np.int64, copy=False)
        row = lat_idx[usable].astype(np.int64, copy=False)
        col = lon_idx[usable].astype(np.int64, copy=False)

        # Row-major offset into the flattened array (month, then lat, then lon).
        offsets = (month * (n_lat * n_lon)) + (row * n_lon) + col
        result[usable] = var_mll.reshape(-1)[offsets]

    return result


def parse_depth_from_filename(p: Path) -> Optional[int]:
    """
    Return the integer depth embedded in a file name ending in 'depth{dep}m.nc'.

    Only the final path component is inspected. Returns None when the name
    does not end with that pattern.
    """
    match = re.search(r"depth(\d+)m\.nc$", p.name)
    if match is None:
        return None
    return int(match.group(1))


def find_som_file_for_depth(dep: int) -> Optional[Path]:
    """
    Locate the SOM label NetCDF file to use for depth ``dep``.

    The exact file name (SOM_FILE_FMT formatted with ``dep``) wins when it
    exists. Otherwise, and only if ALLOW_NEAREST_DEPTH_FILE is truthy, the
    label file whose parsed depth is closest to ``dep`` is returned; on equal
    distance the smaller depth is preferred. Returns None when nothing fits.
    """
    preferred = SOM_LABEL_DIR / SOM_FILE_FMT.format(dep=dep)
    if preferred.exists():
        return preferred
    if not ALLOW_NEAREST_DEPTH_FILE:
        return None

    best_key = None
    best_file = None
    for nc_file in SOM_LABEL_DIR.glob("SOM_province_labels_depth*m.nc"):
        file_depth = parse_depth_from_filename(nc_file)
        if file_depth is None:
            continue
        # Rank by distance first, then by depth; the strict '<' keeps the
        # first file seen when two candidates rank equally.
        key = (abs(file_depth - dep), file_depth)
        if best_key is None or key < best_key:
            best_key = key
            best_file = nc_file
    return best_file


def load_meanbiomes_nc(nc_path: Path):
    """
    Load MeanBiomes/ExcludeMask and the lon/lat axes from a NetCDF file.

    Parameters
    ----------
    nc_path : Path
        NetCDF file that must contain the variables "MeanBiomes",
        "ExcludeMask", "lon" and "lat".

    Returns
    -------
    tuple
        (mean_biomes, exclude_mask, lon_nc, lat_nc) as arrays taken from the
        variables' ``.values``. Both 2D fields are checked to have shape
        (len(lon), len(lat)).

    Raises
    ------
    FileNotFoundError
        If ``nc_path`` does not exist.
    KeyError
        If a required variable is missing.
    ValueError
        If the fields are not 2D, differ in shape, or do not match the axes.
    """
    if not nc_path.exists():
        raise FileNotFoundError(f"NC file not found: {nc_path}")

    # The context manager closes the dataset on every exit path, including a
    # failure while reading the arrays.
    with xr.open_dataset(nc_path) as ds:
        for v in ["MeanBiomes", "ExcludeMask", "lon", "lat"]:
            if v not in ds.variables:
                raise KeyError(f"Variable '{v}' not found in {nc_path}")

        # Read the arrays while the file is still open.
        mean_biomes = ds["MeanBiomes"].values  # (nlon, nlat)
        exclude_mask = ds["ExcludeMask"].values
        lon_nc = ds["lon"].values
        lat_nc = ds["lat"].values

    if mean_biomes.ndim != 2 or exclude_mask.ndim != 2:
        raise ValueError("MeanBiomes / ExcludeMask must be 2D arrays.")
    if mean_biomes.shape != exclude_mask.shape:
        raise ValueError("MeanBiomes and ExcludeMask shapes do not match.")
    if mean_biomes.shape != (lon_nc.size, lat_nc.size):
        raise ValueError("2D variable shape must match (len(lon), len(lat)).")

    return mean_biomes, exclude_mask, lon_nc, lat_nc


def load_som_nc_for_depth(dep: int):
    """
    Load SOM province labels and coordinate variables for the given depth.

    Parameters
    ----------
    dep : int
        Depth used to select the SOM label file via find_som_file_for_depth.

    Returns
    -------
    tuple
        (som_nc, province_id, lon_som, lat_som, ts_depth, lon_conv).
        ``province_id`` is checked to be 3D with shape (12, len(lat), len(lon)).
        ``ts_depth`` is the "ts_depth_m" file attribute as float (NaN when
        absent) and ``lon_conv`` the "lon_convention" attribute ("unknown"
        when absent). When no SOM file is available for ``dep``, all six
        entries are None.

    Raises
    ------
    FileNotFoundError
        If SOM_LABEL_DIR does not exist.
    KeyError
        If a required variable is missing from the file.
    ValueError
        If ``province_id`` has an unexpected shape.
    """
    if not SOM_LABEL_DIR.exists():
        raise FileNotFoundError(f"SOM label directory not found: {SOM_LABEL_DIR}")

    som_nc = find_som_file_for_depth(dep)
    if som_nc is None:
        # Same arity as the success path, so callers can always unpack six
        # values and test `som_nc is None`.
        return None, None, None, None, None, None

    # The context manager closes the dataset on every exit path, including a
    # failure while reading the arrays or attributes.
    with xr.open_dataset(som_nc) as ds:
        for v in ["province_id", "lon", "lat", "month"]:
            if v not in ds.variables:
                raise KeyError(f"Variable '{v}' not found in {som_nc}")

        province_id = ds["province_id"].values  # expected shape (12, nlat, nlon)
        lon_som = ds["lon"].values
        lat_som = ds["lat"].values
        ts_depth = float(ds.attrs.get("ts_depth_m", np.nan))
        lon_conv = ds.attrs.get("lon_convention", "unknown")

    if province_id.ndim != 3 or province_id.shape[0] != 12:
        raise ValueError(f"province_id must be 3D with the first dimension = 12 months; got {province_id.shape}")
    if province_id.shape[1] != lat_som.size or province_id.shape[2] != lon_som.size:
        raise ValueError("province_id shape does not match (lat, lon).")

    return som_nc, province_id, lon_som, lat_som, ts_depth, lon_conv


def process_one_csv_three_fields(csv_path: Path,
                                out_path: Path,
                                mean_biomes: np.ndarray,
                                exclude_mask: np.ndarray,
                                lon_mb: np.ndarray,
                                lat_mb: np.ndarray,
                                province_id: Optional[np.ndarray],
                                lon_som: Optional[np.ndarray],
                                lat_som: Optional[np.ndarray]) -> None:
    """
    Annotate one CSV with Zone0, ExcludeMask and SOM_Zone and write it out.

    The whole file is read at once. Any existing Zone0 / ExcludeMask /
    SOM_Zone columns are dropped and recomputed:
    - Zone0 and ExcludeMask come from ``mean_biomes`` / ``exclude_mask`` at the
      grid cell each row's longitude/latitude maps to.
    - SOM_Zone comes from ``province_id`` at (month, lat, lon); it is all <NA>
      when ``province_id`` is None. Values that are negative or equal to
      SOM_INVALID_VALUE are stored as <NA>.
    Rows that do not map to a grid cell get <NA>. All three columns are
    written as pandas nullable "Int64".

    Raises KeyError when SOM labels are given but the month column is missing.
    """

    def nullable_int_column(int_values, keep):
        # Entries where `keep` is False become <NA> in an "Int64" column.
        return pd.Series(pd.array(np.where(keep, int_values, pd.NA), dtype="Int64"))

    has_som = province_id is not None

    # Both mappers are built before anything is printed or read.
    to_idx_mb, meta_mb = build_grid_mapper(lon_mb, lat_mb)
    if has_som:
        to_idx_som, meta_som = build_grid_mapper(lon_som, lat_som)
    else:
        to_idx_som, meta_som = None, None

    print(f"\n[FILE] {csv_path}")
    print(f"  MeanBiomes grid: lon0={meta_mb['lon0']}, lat0={meta_mb['lat0']}, "
          f"dlon={meta_mb['lon_step']}, dlat={meta_mb['lat_step']}")
    if has_som:
        print(f"  SOM grid: lon0={meta_som['lon0']}, lat0={meta_som['lat0']}, "
              f"dlon={meta_som['lon_step']}, dlat={meta_som['lat_step']}")
    else:
        print("  SOM grid: [skipped] no SOM file for this depth (SOM_Zone will be NA)")

    # Load the complete table and discard results of any previous run.
    table = pd.read_csv(csv_path)
    stale = [name for name in ("Zone0", "ExcludeMask", "SOM_Zone") if name in table.columns]
    table.drop(columns=stale, inplace=True)

    # Unparseable coordinates become NaN.
    lat = pd.to_numeric(table[LAT_COL], errors="coerce").to_numpy(np.float64)
    lon_raw = pd.to_numeric(table[LON_COL], errors="coerce").to_numpy(np.float64)

    # ---------- MeanBiomes / ExcludeMask ----------
    lon_idx_mb, lat_idx_mb, valid_mb = to_idx_mb(normalize_lon(lon_raw, lon_mb), lat)

    zone_vals = extract_by_indices_2d(mean_biomes, lon_idx_mb, lat_idx_mb, valid_mb)
    zone_int = np.rint(zone_vals).astype(np.int64, casting="unsafe", copy=False)
    excl_vals = extract_by_indices_2d(exclude_mask, lon_idx_mb, lat_idx_mb, valid_mb)
    excl_int = np.rint(excl_vals).astype(np.int64, casting="unsafe", copy=False)

    # Non-finite lookups (unmapped rows or NaN grid cells) are masked to <NA>.
    table["Zone0"] = nullable_int_column(zone_int, np.isfinite(zone_vals))
    table["ExcludeMask"] = nullable_int_column(excl_int, np.isfinite(excl_vals))

    # ---------- SOM provinces (optional) ----------
    if has_som:
        if MONTH_COL not in table.columns:
            raise KeyError(f"Missing '{MONTH_COL}' in {csv_path.name} (SOM_Zone depends on month).")

        lon_idx_som, lat_idx_som, valid_som_xy = to_idx_som(normalize_lon(lon_raw, lon_som), lat)

        # Months are 1..12 in the CSV and 0..11 along the first array axis.
        mon_raw = pd.to_numeric(table[MONTH_COL], errors="coerce").to_numpy(np.float64)
        mon_int = np.rint(mon_raw).astype(np.int64, casting="unsafe", copy=False)
        valid_mon = np.isfinite(mon_raw) & (mon_int >= 1) & (mon_int <= 12)
        mon_idx = (mon_int - 1).astype(np.int64, copy=False)

        som_vals = extract_by_indices_3d_month_lat_lon(
            province_id, mon_idx, lat_idx_som, lon_idx_som, valid_som_xy, valid_mon
        )
        som_int = np.rint(som_vals).astype(np.int64, casting="unsafe", copy=False)

        # Keep only finite, non-negative labels that are not the fill value.
        som_valid = np.isfinite(som_vals) & (som_int >= 0) & (som_int != SOM_INVALID_VALUE)
        table["SOM_Zone"] = nullable_int_column(som_int, som_valid)
    else:
        table["SOM_Zone"] = pd.Series(pd.array(np.full(len(table), pd.NA), dtype="Int64"))

    table.to_csv(out_path, index=False)

    print(f"[INFO] Completed: {csv_path.name} -> {out_path.name}")


# -------------------------
# Main
# -------------------------
# Fail fast when the data root or the SOM label directory is missing.
if not ROOT_DIR.exists():
    raise FileNotFoundError(f"ROOT_DIR not found: {ROOT_DIR}")
if not SOM_LABEL_DIR.exists():
    raise FileNotFoundError(f"SOM label directory not found: {SOM_LABEL_DIR}")

# The MeanBiomes/ExcludeMask grid is loaded once and reused for every depth.
print("[OK] Loading MeanBiomes/ExcludeMask NC...")
mean_biomes, exclude_mask, lon_mb, lat_mb = load_meanbiomes_nc(NC_PATH)
print("[OK] MeanBiomes NC loaded.")
print("  MeanBiomes:", mean_biomes.shape, "ExcludeMask:", exclude_mask.shape)

for dep in depths:
    # Only the "{dep}dbar" spelling of the depth folder is looked up here.
    dep_dir = ROOT_DIR / f"{dep}dbar"
    if not dep_dir.exists():
        print(f"\n[SKIP] Depth directory not found: {dep_dir}")
        continue

    files = sorted(dep_dir.glob(PATTERN))
    if not files:
        print(f"\n[SKIP] No '{PATTERN}' files found in {dep_dir}")
        continue

    # SOM labels are loaded once per depth. A missing SOM file is tolerated:
    # som_nc is None and SOM_Zone is written as NA for this depth.
    # NOTE(review): this line unpacks six values, so the "no SOM file" return
    # of load_som_nc_for_depth must also be a 6-tuple or a ValueError is
    # raised here.
    som_nc, province_id, lon_som, lat_som, ts_depth, lon_conv = load_som_nc_for_depth(dep)
    if som_nc is None:
        print(f"\n[DEPTH] {dep} dbar | files={len(files)} | SOM: [missing] -> SOM_Zone will be NA")
    else:
        print(f"\n[DEPTH] {dep} dbar | files={len(files)} | SOM={som_nc.name} | ts_depth_m={ts_depth} | lon_convention={lon_conv}")

    for csv_path in files:
        # Output is written next to the input as <stem><OUT_SUFFIX><suffix>.
        # NOTE(review): if that name also matches PATTERN, a re-run would pick
        # up earlier outputs as inputs - confirm OUT_SUFFIX against PATTERN.
        out_path = csv_path.with_name(csv_path.stem + OUT_SUFFIX + csv_path.suffix)

        process_one_csv_three_fields(
            csv_path=csv_path,
            out_path=out_path,
            mean_biomes=mean_biomes,
            exclude_mask=exclude_mask,
            lon_mb=lon_mb,
            lat_mb=lat_mb,
            province_id=province_id,
            lon_som=lon_som,
            lat_som=lat_som
        )

print("\nAll done.")

In [None]:
# sigma_rep: if value equals 0, set it to empty (blank)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Batch edit for *TRAIN.csv under:
  /data/wang/Result_Data/alldoxy/{depth}dbar or {depth}dabr/

Requirement:
- For target depths in TARGET_DEPTHS, in each *TRAIN.csv:
    if sigma_rep == 0  -> set to empty (blank)
- All other fields remain unchanged (read/write as strings to minimize reformatting).
- Edit "in place" via safe-write (temp file) then atomic replace.

Python 3.9 + Jupyter compatible.
"""

import os
import re
from pathlib import Path
import pandas as pd


# =========================
# User configuration
# =========================
# Root folder that contains one sub-folder per depth.
ROOT_DIR = Path("/data/wang/Result_Data/alldoxy")
# Depths to process; each value is turned into a "{depth}dbar" (or "{depth}dabr")
# folder name under ROOT_DIR.
TARGET_DEPTHS = [
    10,20,30,40,50,60,70,80,90,100,
    110,120,130,140,150,160,170,180,190,200,210,220,230,240,250,260,
    270,280,290,300,310,320,330,340,350,360,370,380,390,400,410,420,
    430,440,450,460,470,480,490,500,510,520,530,540,550,560,570,580,
    590,600,610,620,630,640,650,660,670,680,690,700,710,720,730,740,
    750,760,770,780,790,800,820,840,860,880,900,920,940,960,980,1000,
    1020,1040,1060,1080,1100,1120,1140,1160,1180,1200,1220,1240,1260,
    1280,1300,1320,1340,1360,1380,1400,1420,1440,1460,1480,1500,1520,
    1540,1560,1580,1600,1620,1640,1660,1680,1700,1720,1740,1760,1780,
    1800,1820,1840,1860,1880,1900,1920,1940,1960,1980,2000,2100,2200,
    2300,2400,2500,2600,2700,2800,2900,3000,3100,3200,3300,3400,3500,
    3600,3700,3800,3900,4000,4100,4200,4300,4400,4500,4600,4700,4800,4900,
    5000,5100,5200,5300,5400,5500
]
# Glob applied inside each depth folder to select the files to edit.
FILE_GLOB = "*TRAIN.csv"

# Rows per chunk when streaming a CSV; adjust based on memory / I/O.
CHUNK_SIZE = 600_000

# Name of the column whose zero values are blanked.
# NOTE(review): the lookup is an exact, case-sensitive header match, while the
# aggregation step's documentation spells this quantity "SIGMA_rep" - confirm
# the header actually present in the *TRAIN.csv files.
COL = "sigma_rep"


# =========================
# Helpers
# =========================
# Decimal spellings of numeric zero: optional sign, a mantissa made only of
# zeros with at least one digit ("0", "0.", ".0", "000.000"), and an optional
# exponent of any value, because 0 * 10**k is still zero.
ZERO_RE = re.compile(r"^[+-]?(?:0+\.?0*|\.0+)(?:[eE][+-]?[0-9]+)?$")

def is_zero_string(x: str) -> bool:
    """
    Decide whether a string represents numeric zero (common CSV encodings).

    Surrounding whitespace is ignored. None, the empty string and anything
    that is not a plain decimal spelling of zero return False.

    Examples treated as zero:
      '0', '0.0', '0.00', '+0', '-0', '0e0', '0E+00', '000.000',
      '0.', '.0', '0e5', '0.0E-3'
    Examples NOT treated as zero:
      '', '.', '0.1', '1e0', 'nan', '0x0'
    """
    if x is None:
        return False
    s = str(x).strip()
    if s == "":
        return False
    return ZERO_RE.fullmatch(s) is not None


def process_one_csv_inplace(csv_path: Path) -> None:
    """
    Blank out zero values of column COL in one CSV, editing the file in place.

    The file is streamed in chunks of CHUNK_SIZE rows with every field read as
    a string (no NaN conversion) to preserve formatting as much as possible.
    The result is written to a sibling "<name>.tmp" file which then replaces
    the original via os.replace.

    Behavior on special cases:
    - Missing file: a [SKIP] line is printed and nothing is done.
    - Column COL absent: a [WARN] line is printed and the file is left
      untouched, since there is nothing to edit.
    - Any failure (including opening or parsing the file): the temp file is
      removed, an [ERROR] line is printed and the function returns normally so
      a batch run can continue with the next file.
    """
    if not csv_path.exists():
        print(f"[SKIP] Missing: {csv_path}")
        return

    tmp_path = csv_path.with_suffix(csv_path.suffix + ".tmp")

    wrote_header = False
    total_rows = 0
    changed = 0
    cols_ref = None

    try:
        # Opened inside the try so that an unreadable or empty file is reported
        # like any other per-file failure instead of aborting the batch.
        # Read as strings and do NOT auto-convert empty strings to NaN.
        reader = pd.read_csv(
            csv_path,
            chunksize=CHUNK_SIZE,
            dtype=str,
            keep_default_na=False,
            na_values=[],
            low_memory=False,
        )
        try:
            for chunk in reader:
                if cols_ref is None:
                    cols_ref = list(chunk.columns)
                    if COL not in cols_ref:
                        # Nothing to edit: do not rewrite the file at all.
                        print(f"[WARN] {csv_path} | column '{COL}' not found; file left unchanged")
                        return

                zero_mask = chunk[COL].astype(str).map(is_zero_string)
                if zero_mask.any():
                    changed += int(zero_mask.sum())
                    chunk.loc[zero_mask, COL] = ""  # blank

                total_rows += len(chunk)

                chunk.to_csv(
                    tmp_path,
                    mode="w" if not wrote_header else "a",
                    index=False,
                    header=(not wrote_header),
                    columns=cols_ref,   # preserve original column order
                    encoding="utf-8",
                )
                wrote_header = True
        finally:
            # Release the input handle before the file is replaced, and also
            # when iteration stops early or fails.
            reader.close()

        # Atomic replace
        os.replace(tmp_path, csv_path)

        print(f"[OK] {csv_path} | rows={total_rows:,} | {COL} zero->blank: {changed:,}")

    except Exception as e:
        # Deliberate best-effort boundary: report the failure and keep the
        # original file; never leave a stale temp file behind.
        try:
            tmp_path.unlink(missing_ok=True)
        except OSError as cleanup_err:
            print(f"[WARN] Could not remove temp file {tmp_path}: {cleanup_err}")
        print(f"[ERROR] {csv_path} | {type(e).__name__}: {e}")


def find_depth_dirs(root: Path, depth: int):
    """
    List the existing depth folders for ``depth`` under ``root``.

    Both the regular '{depth}dbar' name and the occasional '{depth}dabr'
    misspelling are checked, in that order; only entries that are directories
    are returned.
    """
    found = []
    for unit in ("dbar", "dabr"):
        folder = root / f"{depth}{unit}"
        if folder.is_dir():
            found.append(folder)
    return found


# =========================
# Main
# =========================
def main():
    """
    Collect every FILE_GLOB match in the depth folders of TARGET_DEPTHS and
    run process_one_csv_inplace on each, in collection order.

    Depths without a folder and folders without matching files are reported
    with a [WARN] line and skipped. Raises FileNotFoundError when ROOT_DIR is
    not a directory.
    """
    if not ROOT_DIR.is_dir():
        raise FileNotFoundError(f"ROOT_DIR not found: {ROOT_DIR}")

    all_files = []
    for dep in TARGET_DEPTHS:
        folders = find_depth_dirs(ROOT_DIR, dep)
        if folders:
            for d in folders:
                matched = sorted(d.glob(FILE_GLOB))
                if not matched:
                    print(f"[WARN] No files matched '{FILE_GLOB}' in {d}")
                all_files.extend(matched)
        else:
            print(f"[WARN] Depth folder not found for {dep}: tried {dep}dbar / {dep}dabr")

    if all_files:
        print(f"[INFO] Found {len(all_files)} file(s) to process.")
        for target in all_files:
            process_one_csv_inplace(target)

        print("[DONE] All processing finished.")
    else:
        print("[DONE] No target files found.")


# Run the batch when this cell/script is executed directly. Jupyter cells also
# run with __name__ == "__main__", so behavior there is unchanged; merely
# importing this code no longer starts the in-place edit of the CSV files.
if __name__ == "__main__":
    main()