In [None]:

# Depth levels shared by the later cells (each cell aliases this list as DEPTHS).
# The spacing widens with depth:
#   10 .. 790    every 10
#   800 .. 1980  every 20
#   2000 .. 5500 every 100
TARGET_DEPTHS = [
    *range(10, 800, 10),
    *range(800, 2000, 20),
    *range(2000, 5501, 100),
]

In [None]:
import os
import glob
import math
import warnings
from typing import Tuple, Dict, List
import numpy as np
import pandas as pd
import xarray as xr
from concurrent.futures import ProcessPoolExecutor, as_completed
import multiprocessing

# Optional progress bar (falls back to plain-text percentage if tqdm is not installed)
try:
    from tqdm import tqdm
    HAS_TQDM = True
except Exception:
    HAS_TQDM = False

# ===================== Configurable section =====================

# Input base directory; process_depth() looks for one "<depth>dbar" subfolder per depth level
ALLOXY_BASE = "/data/wang/Result_Data/alldoxy"

# Only process these depth levels (CORA5.2 depth unit is meters; 1 m ≈ 1 dbar).
# TARGET_DEPTHS is defined in the first notebook cell, so that cell must be run before this one.
DEPTHS = TARGET_DEPTHS

# CMEMS data directories (monthly OA_CORA5.2_<YYYYMM>15_fld_TEMP/PSAL.nc files are read from here)
TEMP_DIR = "/data/wang/CMEMS/TEMP"
PSAL_DIR = "/data/wang/CMEMS/PSAL"

# Tolerances
LATLON_TOL_DEG = 0.1     # nearest-neighbor lat/lon tolerance (degrees)
DEPTH_TOL = 1.1          # nearest-neighbor depth tolerance (meters)

# Write strategy
INPLACE_UPDATE = True    # True: read/write the same depthX_TRAIN.csv (in-place update with atomic replace)
OVERWRITE_TRAIN = True   # True: overwrite Temp/Sal; False: only fill Temp/Sal where they are NaN

# Parallelism
MAX_WORKERS = 48  # tune based on I/O / CPU
# Prevent extra BLAS/OMP parallelism from oversubscribing cores.
# setdefault() keeps any value the user already exported in the environment.
# NOTE(review): numpy/pandas are imported above, before these variables are set; such
# variables may only be read when the numeric libraries are first loaded — confirm they
# still take effect here, or set them before the imports.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")

# After all group matches, optionally fill remaining NaN Temp/Sal from original columns (Temperature/Salinity -> Temp/Sal)
FALLBACK_COPY_FROM_ORIGINAL = False

# =====================================================

warnings.filterwarnings("ignore", category=FutureWarning)

# -------------------- Utility functions -------------------- #

def ensure_month_str(m):
    """Return a calendar month as a zero-padded two-character string.

    Accepts anything that converts to an integral number: ints, numpy
    integers, numeric strings ("7", "07") and integer-valued floats such as
    ``1.0`` (which is what pandas yields for a Month column that contains
    missing values).

    Args:
        m: Month value, expected to represent an integer in 1..12.

    Returns:
        The month formatted as "01" .. "12".

    Raises:
        ValueError: If ``m`` is not numeric, not integral (e.g. 7.5, NaN),
            or outside 1..12.
    """
    try:
        as_number = float(m)
    except (TypeError, ValueError):
        raise ValueError(f"Invalid month: {m}") from None
    # float.is_integer() is False for NaN and +/-inf, so those are rejected here too.
    if not as_number.is_integer() or not 1 <= as_number <= 12:
        raise ValueError(f"Invalid month: {m}")
    return f"{int(as_number):02d}"

def get_year_month_strs(year, month) -> Tuple[str, str]:
    """Convert raw Year/Month values into the (year, month) strings used in file names.

    The year is passed through ``int`` and rendered without padding; the month
    is validated and zero-padded by ``ensure_month_str``. Any exception raised
    by either conversion propagates to the caller.
    """
    return str(int(year)), ensure_month_str(month)

def nc_paths_for_year_month(year: str, month: str) -> Tuple[str, str]:
    """Build the TEMP and PSAL netCDF paths for one month.

    File names embed the date as ``<year><month>15``. The paths are only
    constructed here; nothing checks whether the files exist.

    Returns:
        (temperature_path, salinity_path) under TEMP_DIR and PSAL_DIR.
    """
    stamp = f"{year}{month}15"
    return (
        os.path.join(TEMP_DIR, f"OA_CORA5.2_{stamp}_fld_TEMP.nc"),
        os.path.join(PSAL_DIR, f"OA_CORA5.2_{stamp}_fld_PSAL.nc"),
    )

def normalize_lon_to_grid(lon_vals: np.ndarray, grid_min: float, grid_max: float) -> np.ndarray:
    """Shift longitudes by whole turns (360 degrees) into the grid's convention.

    Args:
        lon_vals: Longitudes in degrees (array-like; a new float array is returned,
            the input is never modified).
        grid_min: Smallest longitude on the target grid.
        grid_max: Largest longitude on the target grid.

    Returns:
        Float array with the same shape as ``lon_vals``:
          * grid spanning negative to non-negative values up to 180 -> [-180, 180)
          * grid starting in [0, 360)                               -> [0, 360)
          * any other grid (e.g. entirely negative longitudes)      -> [grid_min, grid_min + 360)
    """
    lons = np.array(lon_vals, dtype=float)  # always a fresh float copy
    # CORA5.2 longitudes are typically in [-180, 180)
    if grid_min < 0 <= grid_max <= 180:
        return ((lons + 180.0) % 360.0) - 180.0
    if 0 <= grid_min < 360:
        return lons % 360.0
    # Longitude is periodic in 360 degrees, not in the width of the grid: wrapping by
    # the grid span would fold far-away points onto the grid (and divide by zero for
    # a single-column grid). Wrap by a full turn anchored at grid_min instead.
    return ((lons - grid_min) % 360.0) + grid_min

def nearest_indices_with_tol(grid: np.ndarray, values: np.ndarray, tol: float) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Find, for every value, the closest node of a monotonic 1-D grid.

    The grid may be ascending or descending (decided by comparing its first
    and last element). ``tol`` is accepted but not applied here; callers
    compare the returned distances against their own tolerance.

    Returns:
        (idx, nearest_val, abs_diff): integer indices into the grid as it was
        passed in, the grid values at those indices, and the absolute distance
        from each input value to its nearest node. When two neighbours are
        equally close, the one with the smaller coordinate value is chosen.
    """
    axis = np.asarray(grid)
    pts = np.asarray(values)

    descending = axis[0] > axis[-1]
    if descending:
        # searchsorted needs an ascending axis; indices are flipped back below
        axis = axis[::-1]
    last = len(axis) - 1

    insert_at = np.searchsorted(axis, pts, side="left")
    lower = np.clip(insert_at - 1, 0, last)
    upper = np.clip(insert_at, 0, last)

    lower_val = axis[lower]
    upper_val = axis[upper]
    take_lower = np.abs(pts - lower_val) <= np.abs(pts - upper_val)

    idx = np.where(take_lower, lower, upper)
    nearest_val = np.where(take_lower, lower_val, upper_val)
    abs_diff = np.abs(pts - nearest_val)

    if descending:
        idx = last - idx
    return idx.astype(int), nearest_val, abs_diff

def open_nc_pair(temp_path: str, psal_path: str) -> Tuple[xr.Dataset, xr.Dataset]:
    """Open the temperature and salinity datasets for one month.

    Args:
        temp_path: Path of the TEMP netCDF file.
        psal_path: Path of the PSAL netCDF file.

    Returns:
        (ds_t, ds_s): both datasets, opened; the caller is responsible for
        closing them.

    Raises:
        FileNotFoundError: If either path does not exist (checked before
            anything is opened).
        Exception: Whatever ``xr.open_dataset`` raises. If the salinity file
            fails to open, the already-opened temperature dataset is closed
            before the error propagates, so no handle is leaked.
    """
    if not os.path.exists(temp_path):
        raise FileNotFoundError(f"Temperature file not found: {temp_path}")
    if not os.path.exists(psal_path):
        raise FileNotFoundError(f"Salinity file not found: {psal_path}")
    # decode_cf=True applies scale_factor / add_offset automatically
    ds_t = xr.open_dataset(temp_path, decode_cf=True)
    try:
        ds_s = xr.open_dataset(psal_path, decode_cf=True)
    except Exception:
        # The caller never receives ds_t in this case, so it must be closed here.
        ds_t.close()
        raise
    return ds_t, ds_s

def pick_depth_da(da: xr.DataArray, target_depth: float, tol: float) -> Tuple[xr.DataArray, float]:
    """Select the depth level nearest to ``target_depth``.

    If the array has a ``time`` dimension, its first time step is used.

    Returns:
        (picked, actual_depth): the selected slice and the depth coordinate
        that was actually chosen.

    Raises:
        ValueError: If the array has no ``depth`` dimension.
        KeyError: If the nearest depth lies further than ``tol`` from the
            target (raised by the selection itself or by the explicit check
            below).
    """
    source = da.isel(time=0) if "time" in da.dims else da
    if "depth" not in source.dims:
        raise ValueError("No 'depth' dimension found in the data array")

    picked = source.sel(depth=target_depth, method="nearest", tolerance=tol)
    actual_depth = float(picked.coords["depth"].values)

    # Explicit re-check of the distance on top of the tolerance passed to .sel()
    too_far = abs(actual_depth - target_depth) > tol
    if too_far:
        raise KeyError(f"Nearest depth {actual_depth} exceeds tolerance {tol} (target {target_depth})")
    return picked, actual_depth

def extract_values_at_points(field2d: xr.DataArray, lat_vals: np.ndarray, lon_vals: np.ndarray,
                             latlon_tol: float) -> np.ndarray:
    """Sample a 2-D lat/lon field at the grid node nearest to each point.

    Args:
        field2d: Field whose dimensions are named ``latitude``/``longitude``
            or ``lat``/``lon`` (in either order).
        lat_vals: Point latitudes (array-like, degrees).
        lon_vals: Point longitudes (array-like, degrees); they are shifted
            into the grid's longitude convention before matching.
        latlon_tol: Maximum allowed distance (degrees) between a point and
            its nearest grid node, applied to latitude and longitude
            independently.

    Returns:
        Float array with one value per point; NaN where the nearest node is
        further away than ``latlon_tol`` in either direction.

    Raises:
        ValueError: If the lat/lon dimension names cannot be identified.
    """
    dims = field2d.dims
    lat_name = next((name for name in ("latitude", "lat") if name in dims), None)
    lon_name = next((name for name in ("longitude", "lon") if name in dims), None)
    if lat_name is None or lon_name is None:
        raise ValueError(f"Cannot identify lat/lon dimension names: {field2d.dims}")

    # Accept lists as well as arrays
    lat_vals = np.asarray(lat_vals, dtype=float)
    lon_vals = np.asarray(lon_vals, dtype=float)

    grid_lats = field2d.coords[lat_name].values
    grid_lons = field2d.coords[lon_name].values

    lon_vals_norm = normalize_lon_to_grid(lon_vals, float(grid_lons.min()), float(grid_lons.max()))
    lat_idx, _, lat_diff = nearest_indices_with_tol(grid_lats, lat_vals, latlon_tol)
    lon_idx, _, lon_diff = nearest_indices_with_tol(grid_lons, lon_vals_norm, latlon_tol)

    ok = (lat_diff <= latlon_tol) & (lon_diff <= latlon_tol)
    out = np.full(len(lat_vals), np.nan, dtype=float)
    if np.any(ok):
        # Force (lat, lon) order so the index pairs below address the right axes
        # even if the file stores the field as (lon, lat).
        values = field2d.transpose(lat_name, lon_name).values
        out[ok] = values[lat_idx[ok], lon_idx[ok]]
    return out

# -------------------- Group-parallel worker -------------------- #

def worker_group(
    year: str,
    month: str,
    depth_target: int,
    lat_list: List[float],
    lon_list: List[float],
    row_index: List[int],
    latlon_tol_deg: float,
    depth_tol: float,
) -> Dict:
    """
    Look up Temp/Sal for every row of one (year, month) group.

    Steps:
      - build the monthly TEMP/PSAL netCDF paths
      - pick the depth level nearest to depth_target (within depth_tol) from TEMP,
        then pick PSAL at the depth that was actually selected for TEMP
      - sample both fields at the grid node nearest to each point (within latlon_tol_deg)

    Returns a dict {'index': row_index, 'Temp': [...], 'Sal': [...], 'msg': '...'}.
    'Temp' and 'Sal' stay all-NaN when a monthly file is missing (msg left empty)
    or when any step raises (msg carries the exception text). Nothing is printed.
    """
    temp_nc, psal_nc = nc_paths_for_year_month(year, month)
    count = len(row_index)
    result = {
        "index": row_index,
        "Temp": [math.nan] * count,
        "Sal": [math.nan] * count,
        "msg": "",
    }

    # Missing monthly nc: silently return NaNs (no per-group printing)
    if not os.path.exists(temp_nc) or not os.path.exists(psal_nc):
        return result

    ds_t = ds_s = None
    try:
        ds_t, ds_s = open_nc_pair(temp_nc, psal_nc)
        temp_field, depth_used = pick_depth_da(ds_t["TEMP"], depth_target, depth_tol)
        sal_field, _ = pick_depth_da(ds_s["PSAL"], depth_used, depth_tol)

        lats = np.asarray(lat_list, dtype=float)
        lons = np.asarray(lon_list, dtype=float)
        temp_vals = extract_values_at_points(temp_field, lats, lons, latlon_tol_deg)
        sal_vals = extract_values_at_points(sal_field, lats, lons, latlon_tol_deg)

        # Assign only after both extractions succeeded, so a failure leaves both columns NaN
        result["Temp"] = temp_vals.tolist()
        result["Sal"] = sal_vals.tolist()
    except Exception as e:
        # Carry error message in result, no per-group printing
        result["msg"] = f"[Exception] {year}-{month}: {e}"
    finally:
        if ds_t is not None:
            ds_t.close()
        if ds_s is not None:
            ds_s.close()
    return result

# -------------------- Per-depth processing (parallel by Year/Month) -------------------- #

def _apply_group_result(df_all: pd.DataFrame, res: Dict, overwrite: bool) -> None:
    """Write one worker_group result into df_all's Temp/Sal columns (in place)."""
    row_labels = np.asarray(res["index"], dtype=int)
    for col in ("Temp", "Sal"):
        new_vals = np.asarray(res[col], dtype=float)
        if overwrite:
            df_all.loc[row_labels, col] = new_vals
        else:
            # Only touch rows whose current value is missing
            still_missing = df_all.loc[row_labels, col].isna().to_numpy()
            if still_missing.any():
                df_all.loc[row_labels[still_missing], col] = new_vals[still_missing]


def _train_row_key(df: pd.DataFrame) -> pd.Series:
    """Build the 'Year-Month-Latitude-Longitude' string key used to match rows between files."""
    return (
        df["Year"].astype(str).str.zfill(4)
        + "-"
        + df["Month"].astype(str).str.zfill(2)
        + "-"
        + df["Latitude"].round(5).astype(str)
        + "-"
        + df["Longitude"].round(5).astype(str)
    )


def _merge_incremental(old: pd.DataFrame, new: pd.DataFrame) -> pd.DataFrame:
    """Fill missing Temp/Sal in `old` from `new` by row key and append rows only present in `new`."""
    old_idxed = old.assign(_key_=_train_row_key(old).values).set_index("_key_", drop=False)
    new_idxed = new.assign(_key_=_train_row_key(new).values).set_index("_key_", drop=False)

    for col in ("Temp", "Sal"):
        if col not in old_idxed.columns:
            old_idxed[col] = np.nan
        src = new_idxed[col]
        # One value per key (first occurrence) so the lookup below is unambiguous
        src = src[~src.index.duplicated(keep="first")]
        fill = (old_idxed[col].isna() & old_idxed.index.isin(src.index)).to_numpy()
        if fill.any():
            # Positional boolean mask + plain array: no label alignment, safe with duplicate keys
            old_idxed.loc[fill, col] = src.reindex(old_idxed.index[fill]).to_numpy()

    only_new = new_idxed.index.difference(old_idxed.index)
    combined = pd.concat([old_idxed, new_idxed.loc[only_new]], axis=0, ignore_index=False)
    return combined.drop(columns=["_key_"], errors="ignore")


def process_depth(depth_target: int):
    """Match Temp/Sal onto the TRAIN rows of one depth level and write the result.

    Workflow:
      1. Read ``<ALLOXY_BASE>/<depth>dbar/depth<depth>_TRAIN.csv`` (in-place mode) or
         concatenate every other ``*_TRAIN.csv`` in that folder.
      2. Group rows by (Year, Month) and run ``worker_group`` for each group in a
         process pool; results are written into the ``Temp`` / ``Sal`` columns.
         With OVERWRITE_TRAIN the columns are overwritten (with NaN where nothing
         was matched); otherwise only rows whose current value is NaN are filled.
      3. Optionally fill remaining NaNs from ``Temperature`` / ``Salinity``.
      4. Write the CSV (atomic replace in in-place mode; incremental merge when
         neither in-place nor overwrite is selected and the output already exists).

    Problems (missing folder/file, unreadable CSV, missing columns) are printed
    and the depth is skipped; error messages returned by workers are summarised
    once after the pool has finished.

    Args:
        depth_target: Depth level; also used to build the folder and file names.
    """
    depth_dir = os.path.join(ALLOXY_BASE, f"{int(depth_target)}dbar")
    print(f"\n===== Processing depth {depth_target} -> dir {depth_dir} =====")

    if not os.path.isdir(depth_dir):
        print(f"[Info] Directory does not exist, skipping depth: {depth_dir}")
        return

    out_name = f"depth{int(depth_target)}_TRAIN.csv"
    out_path = os.path.join(depth_dir, out_name)

    # ——— Read input ———
    if INPLACE_UPDATE:
        if not os.path.exists(out_path):
            print(f"[Info] In-place update requested, but input file not found: {out_path}. Skipping this depth.")
            return
        try:
            df_all = pd.read_csv(out_path)
        except Exception as e:
            print(f"[Error] Failed to read: {out_path} -> {e}")
            return
        src_files = [out_name]
    else:
        src_files = sorted(
            p for p in glob.glob(os.path.join(depth_dir, "*_TRAIN.csv"))
            if os.path.basename(p) != out_name
        )
        if not src_files:
            print(f"[Info] No *_TRAIN.csv found in {depth_dir}, skipping.")
            return
        frames = []
        for p in src_files:
            try:
                df = pd.read_csv(p)
            except Exception as e:
                print(f"[Warning] Failed to read, skipping: {p} -> {e}")
                continue
            frames.append(df)
        if not frames:
            print(f"[Info] No usable *_TRAIN.csv in {depth_dir}, skipping.")
            return
        df_all = pd.concat(frames, ignore_index=True, sort=False)

    # Required columns
    required = {"Year", "Month", "Latitude", "Longitude"}
    if not required.issubset(df_all.columns):
        print(f"[Error] Missing required columns {required}, skipping this depth.")
        return

    # Ensure output columns exist
    if "Temp" not in df_all.columns:
        df_all["Temp"] = np.nan
    if "Sal" not in df_all.columns:
        df_all["Sal"] = np.nan

    # ——— Parallel processing by (Year, Month) ———
    groups = df_all.groupby(["Year", "Month"], dropna=False)
    print(f"Source files: {[os.path.basename(p) for p in src_files]}")
    print(f"Number of groups: {len(groups)} (by Year, Month) -> parallel, MAX_WORKERS={MAX_WORKERS}")

    tasks = []
    worker_msgs = []  # error texts returned by workers, reported once after the pool is done
    # Choose a usable multiprocessing start method (prefer fork)
    try:
        mp_ctx = multiprocessing.get_context("fork")
    except ValueError:
        mp_ctx = multiprocessing.get_context("spawn")

    with ProcessPoolExecutor(max_workers=MAX_WORKERS, mp_context=mp_ctx) as ex:
        for (y_raw, m_raw), g in groups:
            try:
                year, month = get_year_month_strs(y_raw, m_raw)
            except Exception as e:
                print(f"[Warning] Failed to parse Year/Month: ({y_raw}, {m_raw}) -> {e}. Skipping this group.")
                continue
            fut = ex.submit(
                worker_group,
                year, month, int(depth_target),
                g["Latitude"].tolist(),
                g["Longitude"].tolist(),
                g.index.astype(int).tolist(),
                LATLON_TOL_DEG, DEPTH_TOL
            )
            tasks.append(fut)

        total = len(tasks)
        if total == 0:
            print("[Info] No tasks to execute.")
        else:
            # Progress handling: tqdm if available, else a single self-updating percentage line.
            depth_label = int(depth_target)
            pbar = (
                tqdm(total=total, unit="grp", desc=f"Depth {depth_label}", dynamic_ncols=True)
                if HAS_TQDM else None
            )
            if pbar is None:
                print(f"Depth {depth_label} progress: 0/{total} (0%)", end="", flush=True)
            try:
                for done, fut in enumerate(as_completed(tasks), start=1):
                    res = fut.result()
                    _apply_group_result(df_all, res, OVERWRITE_TRAIN)
                    if res.get("msg"):
                        worker_msgs.append(res["msg"])
                    if pbar is not None:
                        pbar.update(1)
                    else:
                        pct = int(done * 100 / total)
                        print(f"\rDepth {depth_label} progress: {done}/{total} ({pct}%)", end="", flush=True)
            finally:
                if pbar is not None:
                    pbar.close()
                else:
                    print()  # newline

    # Per-group errors are kept out of the progress output; report them once here
    # so failed months do not go unnoticed.
    if worker_msgs:
        max_shown = 10
        print(f"[Warning] {len(worker_msgs)} group(s) reported an error; showing up to {max_shown}:")
        for text in worker_msgs[:max_shown]:
            print(f"  {text}")

    # ——— Optional unified fallback fill for remaining NaNs (Temperature->Temp, Salinity->Sal) ———
    if FALLBACK_COPY_FROM_ORIGINAL:
        # Convert to numeric; non-numeric -> NaN (prevents incorrect fills)
        df_all["Temp"] = pd.to_numeric(df_all["Temp"], errors="coerce")
        df_all["Sal"] = pd.to_numeric(df_all["Sal"], errors="coerce")

        # Temp fallback
        if "Temperature" in df_all.columns:
            src_temp = pd.to_numeric(df_all["Temperature"], errors="coerce")
            mask_temp = df_all["Temp"].isna() & src_temp.notna()
            n_temp = int(mask_temp.sum())
            if n_temp > 0:
                df_all.loc[mask_temp, "Temp"] = src_temp[mask_temp]
            print(f"[Fallback] Temp filled for {n_temp} rows (from Temperature).")
        else:
            print("[Info] No 'Temperature' column found; cannot fallback-fill Temp.")

        # Sal fallback
        if "Salinity" in df_all.columns:
            src_sal = pd.to_numeric(df_all["Salinity"], errors="coerce")
            mask_sal = df_all["Sal"].isna() & src_sal.notna()
            n_sal = int(mask_sal.sum())
            if n_sal > 0:
                df_all.loc[mask_sal, "Sal"] = src_sal[mask_sal]
            print(f"[Fallback] Sal filled for {n_sal} rows (from Salinity).")
        else:
            print("[Info] No 'Salinity' column found; cannot fallback-fill Sal.")

    # ——— Safe write-out (atomic replace for in-place update) ———
    if INPLACE_UPDATE:
        tmp_path = out_path + ".__tmp__"
        try:
            df_all.to_csv(tmp_path, index=False)
            os.replace(tmp_path, out_path)   # atomic replace
        except Exception:
            # Do not leave a half-written temp file next to the data
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
        print(f"[Write] In-place update completed: {out_path} (atomic replace)")
    elif OVERWRITE_TRAIN or not os.path.exists(out_path):
        df_all.to_csv(out_path, index=False)
        print(f"[Write] Overwrote {out_path}, total rows={len(df_all)}.")
    else:
        old = pd.read_csv(out_path)
        combined = _merge_incremental(old, df_all)
        combined.to_csv(out_path, index=False)
        print(f"[Write] Incremental update completed: {out_path}, current rows={len(combined)}.")

def main():
    """Run process_depth for every configured depth level, in list order."""
    if DEPTHS:
        print(f"Depth levels to process (in order): {DEPTHS}")
        for depth in DEPTHS:
            process_depth(depth)
        print("\nAll processing completed.")
    else:
        print("[Error] DEPTHS is empty; no depth levels specified.")

if __name__ == "__main__":
    main()

In [None]:
# -*- coding: utf-8 -*-
"""
Batch-add for multiple depth-level CSV files:
- O2_sat (μmol/kg): oxygen solubility at surface equilibrium (0 dbar), column name configurable (COL_O2_SAT)

Configurable temperature/salinity column names (see the configuration section below):
- TempName = 'Temp'
- SalName  = 'Sal'

Priority:
- Use TEOS-10 (gsw) when available: SP + pt0 -> O2sol_SP_pt (μmol/kg)
Fallback:
- If gsw is missing or required fields are unavailable, use Weiss (1970) (ml/L)
  + EOS-80 density to convert to μmol/kg

The script overwrites the original file. If MAKE_BACKUP=True, it writes a *.bak backup in the same directory.
"""

import os
import shutil
import numpy as np
import pandas as pd
from multiprocessing import Pool, cpu_count
from pathlib import Path

# ---------------- Configuration ----------------
DOXY_BASE = Path("/data/wang/Result_Data/alldoxy/")  # data root path; one "<depth>dbar" folder per depth level
# TARGET_DEPTHS comes from the first notebook cell, so that cell must be run before this one.
# Note that this rebinds the DEPTHS name that the previous cell also assigns.
DEPTHS = TARGET_DEPTHS

MAKE_BACKUP = False  # whether to copy each CSV to "<file>.bak" before it is overwritten
N_PROCESSES = 24     # moderate parallelism to avoid I/O contention

# ---------- Output column name (configurable) ----------
COL_O2_SAT = "O2_sat"  # oxygen solubility column name

# ---------- Temperature/salinity column names used for the computation ----------
TempName = "Temp"  # temperature column name
SalName  = "Sal"   # salinity column name

# Optional: original dissolved oxygen column name (kept as 'Oxygen' here)
# NOTE(review): not referenced by the code visible in this cell — confirm it is still needed.
OXYGEN_COL = "Oxygen"

# ---------- Fallback path: Weiss (1970) + EOS-80 density ----------
def _rho_eos80_kg_m3(S, T):
    """EOS-80 density (kg/m^3), approximate at 0 dbar; S=PSU, T=°C (ITS-90)."""
    S = np.asarray(S, float)
    T = np.asarray(T, float)
    rho_w = (999.842594 + 6.793952e-2*T - 9.095290e-3*T**2
             + 1.001685e-4*T**3 - 1.120083e-6*T**4 + 6.536332e-9*T**5)
    A = (0.824493 - 4.0899e-3*T + 7.6438e-5*T**2 - 8.2467e-7*T**3 + 5.3875e-9*T**4)
    B = (-5.72466e-3 + 1.0227e-4*T - 1.6546e-6*T**2)
    C = 4.8314e-4
    return rho_w + A*S + B*(S**1.5) + C*(S**2)

def _o2sol_weiss_ml_per_L(T, S):
    """Weiss (1970) O2 solubility (ml/L); T=°C (ITS-90). Internally converts to IPTS-68 and Kelvin."""
    T = np.asarray(T, float)
    S = np.asarray(S, float)
    Tk = T*1.00024 + 273.15  # ITS-90 -> IPTS-68, then to Kelvin
    A1, A2, A3, A4 = -173.4292, 249.6339, 143.3483, -21.8492
    B1, B2, B3 = -0.033096, 0.014259, -0.0017000
    lnC = (A1 + A2*(100.0/Tk) + A3*np.log(Tk/100.0) + A4*(Tk/100.0)
           + S*(B1 + B2*(Tk/100.0) + B3*(Tk/100.0)**2))
    return np.exp(lnC)

def o2_sat_umolkg_weiss(T, S):
    """Weiss oxygen solubility expressed in μmol/kg.

    The ml/L solubility is multiplied by 44.659 μmol per mL(STP) of O2 and
    divided by the EOS-80 density expressed in kg/L.
    """
    ml_per_litre = _o2sol_weiss_ml_per_L(T, S)
    density = _rho_eos80_kg_m3(S, T)  # kg/m^3
    litres_per_kg = 1000.0 / density
    return ml_per_litre * 44.659 * litres_per_kg  # μmol/kg

# ---------- Preferred path: TEOS-10 (gsw) ----------
def o2_sat_umolkg_teos10(SP, t, p, lon, lat):
    """
    Oxygen solubility (μmol/kg) via the gsw (TEOS-10) package.

    Inputs: Practical Salinity (SP), in-situ temperature t (°C, ITS-90),
    pressure p (dbar), longitude and latitude. All inputs are converted to
    float arrays. Absolute Salinity is derived from SP, the temperature is
    referenced to 0 dbar (pt0), and gsw.O2sol_SP_pt is evaluated on SP and pt0.

    gsw is imported inside the function, so an ImportError surfaces here when
    the package is not installed.
    """
    import gsw
    SP, t, p, lon, lat = (np.asarray(arg, float) for arg in (SP, t, p, lon, lat))
    abs_salinity = gsw.SA_from_SP(SP, p, lon, lat)
    pot_temp_0dbar = gsw.pt0_from_t(abs_salinity, t, p)
    return gsw.O2sol_SP_pt(SP, pot_temp_0dbar)  # μmol/kg

# ---------- Process a single depth ----------
def process_single_depth(depth):
    """
    Add the COL_O2_SAT column to one depth-level TRAIN file.

      Input:  <DOXY_BASE>/{depth}dbar/depth{depth}_TRAIN.csv
      Output: the same file, rewritten with the COL_O2_SAT column added or replaced
              (written to a temporary file first, then swapped in with os.replace);
              optionally a .bak copy of the original is made first.

    Per row, TEOS-10 is used when gsw is importable and Temp, Sal, Pressure,
    Longitude and Latitude are all present; rows that are still NaN afterwards
    but have Temp and Sal are computed with the Weiss fallback. If TEOS-10
    cannot be used (gsw missing, columns missing, or the computation raises),
    the reason is appended to the returned message instead of being dropped.

    Returns: (depth, n_rows, n_teos, n_weiss, msg)
      n_teos counts rows that received a finite TEOS-10 value; n_weiss counts
      rows computed by the Weiss fallback. Read/write problems are reported
      through msg rather than raised.
    """
    dir_path = DOXY_BASE / f"{depth}dbar"
    file_path = dir_path / f"depth{depth}_TRAIN.csv"
    if not file_path.exists():
        return (depth, 0, 0, 0, f"File not found: {file_path}")

    try:
        df = pd.read_csv(file_path)
    except Exception as e:
        return (depth, 0, 0, 0, f"Read failed: {e}")

    n0 = len(df)
    if n0 == 0:
        return (depth, 0, 0, 0, "Empty file, skipped")

    # Validate required columns for both paths (minimum shared set)
    if not {TempName, SalName}.issubset(df.columns):
        missing = {TempName, SalName} - set(df.columns)
        return (depth, n0, 0, 0, f"Missing columns: {missing}")

    O2_sat_arr = np.full(n0, np.nan, dtype=float)
    used_teos = 0
    used_weiss = 0
    teos_note = ""  # why TEOS-10 was not used, if that happened

    has_ts = (df[TempName].notna() & df[SalName].notna()).to_numpy()

    # --- Prefer TEOS-10 when possible (only rows with all required fields present) ---
    teos_cols = ["Pressure", "Longitude", "Latitude"]
    try:
        import gsw  # noqa: F401
    except ImportError:
        teos_note = "TEOS-10 skipped (gsw not installed)"
    else:
        absent = [c for c in teos_cols if c not in df.columns]
        if absent:
            teos_note = f"TEOS-10 skipped (missing columns: {absent})"
        else:
            mask_teos = has_ts & df[teos_cols].notna().all(axis=1).to_numpy()
            if mask_teos.any():
                try:
                    teos_vals = np.asarray(
                        o2_sat_umolkg_teos10(
                            df.loc[mask_teos, SalName].to_numpy(dtype=float),
                            df.loc[mask_teos, TempName].to_numpy(dtype=float),
                            df.loc[mask_teos, "Pressure"].to_numpy(dtype=float),
                            df.loc[mask_teos, "Longitude"].to_numpy(dtype=float),
                            df.loc[mask_teos, "Latitude"].to_numpy(dtype=float),
                        ),
                        dtype=float,
                    )
                    O2_sat_arr[mask_teos] = teos_vals
                    # Rows where TEOS-10 returned NaN are picked up by the Weiss
                    # fallback below, so only finite results are counted here.
                    used_teos = int(np.isfinite(teos_vals).sum())
                except Exception as e:
                    # Keep going with the Weiss fallback, but say why
                    teos_note = f"TEOS-10 failed ({type(e).__name__}: {e})"

    # --- Weiss fallback (rows still NaN but have T/S) ---
    mask_weiss = np.isnan(O2_sat_arr) & has_ts
    if mask_weiss.any():
        O2_sat_arr[mask_weiss] = o2_sat_umolkg_weiss(
            df.loc[mask_weiss, TempName].to_numpy(dtype=float),
            df.loc[mask_weiss, SalName].to_numpy(dtype=float),
        )
        used_weiss = int(mask_weiss.sum())

    # Write back to DataFrame (only the new O2_sat column)
    df[COL_O2_SAT] = O2_sat_arr

    # Diagnostics
    finite = np.isfinite(O2_sat_arr)
    mean_sat = float(np.nanmean(O2_sat_arr)) if finite.any() else np.nan
    std_sat  = float(np.nanstd(O2_sat_arr)) if finite.any() else np.nan
    msg_stats = f"{COL_O2_SAT} finite={int(finite.sum())}/{n0}, mean={mean_sat:.2f}, std={std_sat:.2f} μmol/kg"

    # Backup + overwrite
    try:
        if MAKE_BACKUP:
            bak_path = str(file_path) + ".bak"
            shutil.copy2(file_path, bak_path)
        # Write next to the target and swap it in, so an interrupted write
        # cannot leave a truncated TRAIN file behind.
        tmp_path = file_path.with_name(file_path.name + ".__tmp__")
        try:
            df.to_csv(tmp_path, index=False)
            os.replace(tmp_path, file_path)
        except Exception:
            if tmp_path.exists():
                tmp_path.unlink()
            raise
    except Exception as e:
        return (depth, n0, used_teos, used_weiss, f"Write failed: {e}")

    # Summary message
    route_note = []
    if used_teos > 0:
        route_note.append(f"TEOS-10={used_teos}")
    if used_weiss > 0:
        route_note.append(f"Weiss={used_weiss}")
    if not route_note:
        route_note.append("No valid rows computed")

    summary = f"{', '.join(route_note)} | {msg_stats}"
    if teos_note:
        summary = f"{summary} | {teos_note}"
    return (depth, n0, used_teos, used_weiss, summary)

# ---------- Main ----------
def main():
    """Run process_single_depth for every configured depth and print a summary.

    Depths are processed one by one when there is a single depth or
    N_PROCESSES is 1; otherwise they are distributed over a process pool.
    """
    tasks = list(DEPTHS)
    if not tasks:
        print("DEPTHS is not specified; exiting.")
        return

    print(f"Start processing depth levels: {tasks}")

    # Parallelize across depths (I/O-bound; moderate parallelism is recommended)
    run_serially = len(tasks) == 1 or N_PROCESSES == 1
    if run_serially:
        results = list(map(process_single_depth, tasks))
    else:
        with Pool(processes=N_PROCESSES) as pool:
            results = pool.map(process_single_depth, tasks)

    # Summary
    print("\n--- Batch results ---")
    for depth, n0, _n_teos, _n_weiss, msg in results:
        print(f"[{depth} dbar] rows={n0} | {msg}")

    print("\n✅ All depths finished.")

if __name__ == "__main__":
    main()

In [None]:
# 2026 MLD-only
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Only extract MLD (mlotst) from CMEMS ARMOR3D monthly files and merge into DOXY TRAIN CSVs.

Optimizations:
  - Vectorized nearest-grid matching using np.searchsorted (O(N log M), no per-row argmin)
  - Vectorized extraction of mlotst via NumPy advanced indexing
  - Safer engine fallback for xr.open_dataset
  - Keep multiprocessing but use a moderate process count to avoid I/O thrashing
"""

import os
import re
import logging
from pathlib import Path
from multiprocessing import Pool

import xarray as xr
import pandas as pd
import numpy as np
from tqdm import tqdm

# ---------------------- Configuration ----------------------
# Directory searched for the monthly "dataset-armor-3d-rep-monthly_<YYYYMM>15T1200Z_*.nc" files
CMEMS_DIR = Path("/data/wang/CMEMS/MLDuv/")
# Root of the per-depth "<depth>dbar" folders; log files are written to its "logs" subfolder
DOXY_BASE = Path("/data/wang/Result_Data/alldoxy/")
LATLON_TOLERANCE = 0.4  # degree
# The "<depth>dbar" folder names and "depth<depth>_TRAIN.csv" file names are used unchanged
NPROC = min(16, (os.cpu_count() or 16))          # I/O-bound; too many processes are often slower
IMAP_CHUNKSIZE = 6                                # reduce overhead; tune for your data

# ---------------------- Logging ----------------------
def setup_logger(output_dir: Path) -> logging.Logger:
    """Return the per-process logger, creating its handlers on first use.

    The logger is named after the current PID and writes to
    ``<output_dir>/merge_mld_<pid>.log`` plus a stream handler. Calling this
    again in the same process with the same directory returns the logger
    unchanged (no new file handle is opened). If the logger is currently
    attached to a different log file, its old handlers are removed *and
    closed* before new ones are installed.

    Args:
        output_dir: Directory for the log file; created if missing.

    Returns:
        The configured logger (INFO level, not propagating to the root logger).
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    pid = os.getpid()
    logger = logging.getLogger(f"merge_mld_{pid}")
    logger.setLevel(logging.INFO)
    logger.propagate = False

    log_file = output_dir / f"merge_mld_{pid}.log"
    # FileHandler stores its target as an absolute path in baseFilename
    target = os.path.abspath(str(log_file))

    already_configured = any(
        isinstance(h, logging.FileHandler) and h.baseFilename == target
        for h in logger.handlers
    )
    if already_configured:
        # This function is called once per task; reuse the open handlers
        # instead of opening the log file again each time.
        return logger

    # Replace stale handlers; close them so their file descriptors are released
    for h in list(logger.handlers):
        logger.removeHandler(h)
        h.close()

    fh = logging.FileHandler(log_file)
    fh.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s: %(message)s"))

    ch = logging.StreamHandler()
    ch.setFormatter(logging.Formatter("%(levelname)s: %(message)s"))

    logger.addHandler(fh)
    logger.addHandler(ch)
    return logger

# ---------------------- Utilities: open dataset with engine fallback ----------------------
def open_dataset_safe(nc_path: Path):
    """
    Open a netCDF file, trying several xarray engines in turn.

    Engines are attempted in the order "netcdf4", "h5netcdf", then xarray's
    default choice. Time decoding is disabled for every attempt. The first
    successful dataset is returned; if every attempt fails, the exception from
    the last attempt is raised.
    """
    failure = None
    for engine in ("netcdf4", "h5netcdf", None):
        # None means: let xarray choose the engine (no engine argument passed)
        engine_kwargs = {} if engine is None else {"engine": engine}
        try:
            return xr.open_dataset(nc_path, decode_times=False, **engine_kwargs)
        except Exception as exc:
            failure = exc
    raise failure

# ---------------------- Vectorized nearest index on monotonic 1D axis ----------------------
def nearest_index_1d(sorted_axis: np.ndarray, values: np.ndarray) -> np.ndarray:
    """
    Index of the closest axis node for each value (axis must be increasing).

    A binary search gives the insertion position; the node there and the one
    just before it are compared, and the closer one wins (the earlier node on
    ties). Values outside the axis range map to the first or last node.
    Returns an int64 array shaped like ``values``.
    """
    axis = np.asarray(sorted_axis, dtype=np.float64)
    pts = np.asarray(values, dtype=np.float64)
    top = len(axis) - 1

    right = np.clip(np.searchsorted(axis, pts, side="left"), 0, top)
    left = np.clip(right - 1, 0, top)

    left_is_closer = np.abs(axis[left] - pts) <= np.abs(axis[right] - pts)
    return np.where(left_is_closer, left, right).astype(np.int64)

def normalize_lon_0_360(lon_deg: np.ndarray) -> np.ndarray:
    """Wrap longitudes (degrees) into the half-open range [0, 360)."""
    return np.remainder(np.asarray(lon_deg, dtype=np.float64), 360.0)

def lon_0_360_to_pm180(lon_0_360: np.ndarray) -> np.ndarray:
    """Map longitudes from [0, 360) to (-180, 180]: values above 180 are shifted down by 360."""
    lon = np.asarray(lon_0_360, dtype=np.float64)
    shifted = lon - 360.0
    return np.where(lon > 180.0, shifted, lon)

# ---------------------- Core: process one (Year, Month) group (vectorized) ----------------------
def process_group(year: int, month: int, group_df: pd.DataFrame, target_depth: int, logger: logging.Logger) -> pd.DataFrame:
    """
    Attach an 'MLD' column (from the `mlotst` variable) to one (Year, Month) group.

    Steps:
      - find the monthly file for year/month in CMEMS_DIR (first match in sorted order)
      - match every row to its nearest grid node, working in the grid's own
        longitude convention (-180..180 or 0..360, detected from the axis)
      - keep only matches within LATLON_TOLERANCE in both latitude and longitude
        (longitude distance is measured on the circle, so 359.9 and 0.0 are 0.1 apart)
      - copy the `mlotst` value of each accepted match, rounded to 3 decimals

    Args:
        year, month: Integer year and month of the group.
        group_df: Rows of the group; must contain 'Latitude' and 'Longitude'.
        target_depth: Not used by this function.
        logger: Receives an error record (with traceback) if processing fails.

    Returns:
        A copy of ``group_df`` with an 'MLD' column. The column is all-NaN when
        no file matches or when reading/processing the file raises.
    """
    pattern = f"dataset-armor-3d-rep-monthly_{year}{month:02d}15T1200Z_*.nc"
    # Sorted so the chosen file is deterministic if several match the pattern
    nc_files = sorted(CMEMS_DIR.glob(pattern))
    if not nc_files:
        # If file does not exist, return NaNs
        out = group_df.copy()
        out["MLD"] = np.nan
        return out

    nc_path = nc_files[0]
    try:
        with open_dataset_safe(nc_path) as ds:
            # Coordinate name compatibility
            lat_name = "latitude" if "latitude" in ds.coords else ("lat" if "lat" in ds.coords else None)
            lon_name = "longitude" if "longitude" in ds.coords else ("lon" if "lon" in ds.coords else None)
            if lat_name is None or lon_name is None:
                raise KeyError(f"Cannot find lat/lon coords in {nc_path.name}. coords={list(ds.coords)}")

            if "mlotst" not in ds.variables:
                raise KeyError(f"'mlotst' not found in {nc_path.name}. vars={list(ds.variables)}")

            lat_axis = ds[lat_name].values.astype(np.float64)
            lon_axis = ds[lon_name].values.astype(np.float64)

            # Ensure coordinates are increasing; if not, sort and reorder data accordingly
            lat_inc = np.all(np.diff(lat_axis) > 0)
            lon_inc = np.all(np.diff(lon_axis) > 0)
            da_mld = ds["mlotst"].isel(time=0) if "time" in ds["mlotst"].dims else ds["mlotst"]

            if not lat_inc:
                lat_order = np.argsort(lat_axis)
                lat_axis = lat_axis[lat_order]
                da_mld = da_mld.isel({lat_name: lat_order})

            if not lon_inc:
                lon_order = np.argsort(lon_axis)
                lon_axis = lon_axis[lon_order]
                da_mld = da_mld.isel({lon_name: lon_order})

            # Force (lat, lon) order so the index pairs below address the right axes
            da_mld = da_mld.transpose(lat_name, lon_name)

            # Extract lat/lon arrays from the group (vectorized)
            lats = group_df["Latitude"].to_numpy(dtype=np.float64)
            lons_raw = group_df["Longitude"].to_numpy(dtype=np.float64)

            # Express the observations in the grid's own longitude convention.
            # Always using 0..360 would push every negative longitude past the
            # end of a -180..180 axis and make it fail the tolerance test.
            lons_grid = normalize_lon_0_360(lons_raw)
            if float(lon_axis.min()) < 0.0:
                lons_grid = lon_0_360_to_pm180(lons_grid)

            # Nearest indices (vectorized)
            lat_idx = nearest_index_1d(lat_axis, lats)
            lon_idx = nearest_index_1d(lon_axis, lons_grid)

            lat_gap = np.abs(lat_axis[lat_idx] - lats)
            # Circular distance: independent of which convention either side uses
            lon_gap = np.abs((lon_axis[lon_idx] - lons_grid + 180.0) % 360.0 - 180.0)

            ok = (lat_gap <= LATLON_TOLERANCE) & (lon_gap <= LATLON_TOLERANCE)

            # Vectorized MLD extraction (invalid points stay NaN)
            mld_arr = da_mld.values  # 2D numpy, (lat, lon)
            mld_out = np.full(len(group_df), np.nan, dtype=np.float64)
            if np.any(ok):
                mld_out[ok] = mld_arr[lat_idx[ok], lon_idx[ok]].astype(np.float64)

            out = group_df.copy()
            out["MLD"] = np.round(mld_out, 3)
            return out

    except Exception as e:
        logger.error(f"Failed {year}-{month:02d} file={nc_path.name}: {e}", exc_info=True)
        out = group_df.copy()
        out["MLD"] = np.nan
        return out

# ---------------------- Multiprocessing wrapper ----------------------
def process_group_wrapper(args):
    """Unpack one task tuple and forward it to ``process_group``.

    ``Pool.imap_unordered`` hands each worker a single object per task, so the
    four fields packed by ``process_single_depth`` are spread back out here.

    Args:
        args: ``(year, month, group_df, target_depth)``.

    Returns:
        Whatever ``process_group`` returns for this group.
    """
    yr, mon, frame, depth = args
    # The logger is obtained inside the worker instead of being shipped with the task.
    worker_log = setup_logger(DOXY_BASE / "logs")
    return process_group(yr, mon, frame, depth, worker_log)

# ---------------------- Per-depth processing ----------------------
def process_single_depth(target_depth: int):
    """Compute the MLD column for one depth level and rewrite its training CSV.

    The file ``DOXY_BASE/{target_depth}dbar/depth{target_depth}_TRAIN.csv`` is read,
    split into (Year, Month) groups, each group is processed in a worker pool, and
    the concatenated result is written back to the same path.

    Args:
        target_depth: Depth level used to build the folder and file names.

    Returns:
        None. A missing or empty input is logged as a warning; any other failure is
        logged with its traceback and not re-raised.
    """
    logger = setup_logger(DOXY_BASE / "logs")

    depth_dir = DOXY_BASE / f"{target_depth}dbar"
    # Input and output are the same file: the CSV is updated in place.
    csv_path = depth_dir / f"depth{target_depth}_TRAIN.csv"

    try:
        depth_dir.mkdir(parents=True, exist_ok=True)

        if not csv_path.exists():
            logger.warning(f"Input file missing: {csv_path}")
            return

        df = pd.read_csv(csv_path)
        if df.empty:
            logger.warning(f"Empty file: {csv_path}")
            return

        # Fail early (inside the try, so it is only logged) when key columns are absent.
        required = {"Year", "Month", "Latitude", "Longitude"}
        absent = required - set(df.columns)
        if absent:
            raise ValueError(f"Missing required columns: {absent}")

        # One task per (Year, Month); sort=False keeps groups in first-seen order.
        tasks = []
        for (year, month), group in df.groupby(["Year", "Month"], group_keys=False, sort=False):
            tasks.append((int(year), int(month), group.copy(), target_depth))

        logger.info(f"Depth {target_depth}dbar: groups={len(tasks)}, NPROC={NPROC}")

        processed = []
        bar_label = f"Processing {target_depth}dbar (MLD)"
        # For I/O-bound workloads, too many processes can degrade performance
        with Pool(processes=NPROC) as pool, tqdm(total=len(tasks), desc=bar_label, unit="group") as pbar:
            # imap_unordered yields groups as they finish, so output row order follows completion order.
            for enriched in pool.imap_unordered(process_group_wrapper, tasks, chunksize=IMAP_CHUNKSIZE):
                processed.append(enriched)
                pbar.update(1)

        pd.concat(processed, ignore_index=True).to_csv(csv_path, index=False)
        logger.info(f"Done depth={target_depth}dbar -> {csv_path}")

    except Exception as e:
        logger.error(f"Depth {target_depth}dbar failed: {e}", exc_info=True)

# ---------------------- Main loop ----------------------
def main_process():
    """Create the shared log folder, then run ``process_single_depth`` for every entry of TARGET_DEPTHS in order."""
    (DOXY_BASE / "logs").mkdir(parents=True, exist_ok=True)

    for depth in TARGET_DEPTHS:
        # Cast so folder/file names are built from a plain int.
        process_single_depth(int(depth))

# Entry point of the MLD cell: runs the full batch over all depths when executed as the main module,
# which is the usual case when a notebook cell is run.
if __name__ == "__main__":
    main_process()
    print("Batch processing completed. Please check the output directories for each depth level.")

In [None]:
# alldoxy
import xarray as xr
import pandas as pd
import numpy as np
from pathlib import Path
import logging
from tqdm import tqdm
from multiprocessing import Pool
import os

# ---------------------- Configuration ----------------------
# Folder scanned by build_ccmp_index() for monthly CCMP wind NetCDF files.
CCMP_DIR = Path("/data/wang/NASA/Wind/Monthly/")
# Root containing one "<depth>dbar" folder per depth level and the shared "logs" folder.
DOXY_BASE = Path("/data/wang/Result_Data/alldoxy/")
# Appended to "depth<N>" to name the output CSV; with "_TRAIN" the output path equals the input path.
OUTPUT_SUFFIX = "_TRAIN"
LAT_TOLERANCE = 0.4  # largest accepted |grid latitude - observation latitude|
LON_TOLERANCE = 0.2  # largest accepted |grid longitude - observation longitude| (after % 360)
# Depth levels to process; TARGET_DEPTHS is defined in the first notebook cell.
DEPTHS = TARGET_DEPTHS

POOL_SIZE = 48  # number of parallel processes

# ---------------------- Logging ----------------------
def setup_logger(name, log_file):
    """Return the logger called ``name``, attaching file and console handlers on first use.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        log_file: Path of the file the file handler writes to. Missing parent
            directories are created.

    Returns:
        logging.Logger set to INFO. A logger that already has handlers is returned
        unchanged, so ``log_file`` only takes effect the first time a given
        ``name`` is configured in a process.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)

    if not logger.handlers:
        # FileHandler opens the file immediately and raises if its folder is missing.
        Path(log_file).parent.mkdir(parents=True, exist_ok=True)

        file_handler = logging.FileHandler(log_file)
        file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s: %(message)s'))

        console_handler = logging.StreamHandler()
        console_handler.setFormatter(logging.Formatter('%(levelname)s: %(message)s'))

        logger.addHandler(file_handler)
        logger.addHandler(console_handler)

    return logger

# ---------------------- Core functions ----------------------
def build_ccmp_index():
    """Map ``(year, month)`` to the CCMP monthly-mean wind file found in CCMP_DIR.

    The YYYYMM stamp is read from the fourth underscore-separated token of the
    file stem. A file whose name cannot be parsed is reported on stdout and skipped.

    Returns:
        dict keyed by ``(int year, int month)`` with ``Path`` values.
    """
    wind_files = {}
    for path in CCMP_DIR.glob("CCMP_Wind_Analysis_*_monthly_mean_V03.1_L4.nc"):
        try:
            stamp = path.stem.split("_")[3]
            period = (int(stamp[:4]), int(stamp[4:6]))
            wind_files[period] = path
        except Exception as e:
            print(f"Filename parse error {path}: {str(e)}")
    return wind_files

def find_nearest_idx(array, value):
    """Return the index of the element of ``array`` closest to ``value`` (first one on ties)."""
    distance = np.abs(array - value)
    return distance.argmin()

def process_group(args):
    """Attach CCMP wind columns U, V, W to one (year, month) group of observations.

    Args:
        args: ``((year, month), group, nc_file, depth)``. ``group`` is the DataFrame
            slice for that month, ``nc_file`` the matching CCMP file (or None) and
            ``depth`` is used only to name the log file.

    Returns:
        Copy of ``group`` with columns U, V, W. A row is filled only when the nearest
        grid node lies within LON_TOLERANCE / LAT_TOLERANCE of the observation; all
        other rows keep NaN. Never raises: failures are logged and the partially
        filled frame is returned.

    Note:
        The logger name is per process and setup_logger configures a name only once,
        so inside one worker every group logs to the file of the first group handled.
    """
    (year, month), group, nc_file, depth = args

    logger = setup_logger(
        f"process_{os.getpid()}",
        DOXY_BASE / "logs" / f"depth{depth}_{year}_{month}.log"
    )

    result_df = group.copy()
    result_df['U'] = np.nan
    result_df['V'] = np.nan
    result_df['W'] = np.nan

    try:
        if nc_file is None or not nc_file.exists():
            # logger.warning(f"Wind data missing: {year}-{month}")
            return result_df

        with xr.open_dataset(nc_file) as ds:
            lon_array = ds.longitude.values
            lat_array = ds.latitude.values
            u_data = ds['u'].isel(time=0).values
            v_data = ds['v'].isel(time=0).values
            w_data = ds['w'].isel(time=0).values

            for idx, row in group.iterrows():
                try:
                    # Observations are mapped to [0, 360) before the grid lookup.
                    target_lon = row['Longitude'] % 360
                    target_lat = row['Latitude']

                    lon_idx = find_nearest_idx(lon_array, target_lon)
                    lat_idx = find_nearest_idx(lat_array, target_lat)

                    # Positive "<=" form on purpose: with a NaN coordinate both
                    # comparisons are False and the row is skipped. A negated
                    # "> tolerance" test lets NaN through, and since argmin of an
                    # all-NaN distance array is 0 the row would receive the value
                    # stored at grid index 0.
                    within_tolerance = (
                        abs(lon_array[lon_idx] - target_lon) <= LON_TOLERANCE and
                        abs(lat_array[lat_idx] - target_lat) <= LAT_TOLERANCE
                    )
                    if not within_tolerance:
                        continue

                    result_df.at[idx, 'U'] = round(float(u_data[lat_idx, lon_idx]), 3)
                    result_df.at[idx, 'V'] = round(float(v_data[lat_idx, lon_idx]), 3)
                    result_df.at[idx, 'W'] = round(float(w_data[lat_idx, lon_idx]), 3)

                except Exception as e:
                    # One bad row must not abort the rest of the group.
                    logger.debug(f"Row {idx} processing error: {str(e)}")

        return result_df

    except Exception as e:
        logger.error(f"Group processing failed {year}-{month}: {str(e)}")
        return result_df

# ---------------------- Main workflow ----------------------
def process_single_depth(target_depth, ccmp_index):
    """Add wind columns to the training CSV of one depth level.

    Reads ``DOXY_BASE/{depth}dbar/depth{depth}_TRAIN.csv``, hands each (Year, Month)
    group to ``process_group`` in a worker pool and writes the concatenated result
    to ``depth{depth}{OUTPUT_SUFFIX}.csv`` in the same folder.

    Args:
        target_depth: Depth level used to build folder and file names.
        ccmp_index: Mapping ``(year, month) -> Path`` from ``build_ccmp_index``.

    Returns:
        None. Missing/empty inputs are logged as warnings; any other error is
        logged with its traceback and not re-raised.
    """
    logger = setup_logger(
        f"depth{target_depth}",
        DOXY_BASE / "logs" / f"depth{target_depth}.log"
    )

    try:
        input_csv = DOXY_BASE / f"{target_depth}dbar" / f"depth{target_depth}_TRAIN.csv"
        output_csv = DOXY_BASE / f"{target_depth}dbar" / f"depth{target_depth}{OUTPUT_SUFFIX}.csv"

        if not input_csv.exists():
            logger.warning(f"Input file not found: {input_csv}")
            return
        if os.path.getsize(input_csv) == 0:
            logger.warning(f"Empty file: {input_csv}")
            return

        df = pd.read_csv(input_csv)

        # NOTE(review): groupby skips rows whose Year or Month is NaN (pandas default),
        # so such rows do not appear in the written file - confirm this is intended.
        task_args = [
            ((year, month), group, ccmp_index.get((year, month)), target_depth)
            for (year, month), group in df.groupby(['Year', 'Month'])
        ]

        # A header-only CSV produces no groups; pd.concat([]) would raise, so stop
        # before spawning any worker.
        if not task_args:
            logger.warning(f"No (Year, Month) groups in {input_csv}; nothing written")
            return

        # Never start more workers than there are groups to process.
        with Pool(processes=min(POOL_SIZE, len(task_args))) as pool:
            chunks = list(tqdm(
                pool.imap(process_group, task_args),
                total=len(task_args),
                desc=f"Depth {target_depth}dbar progress"
            ))

        # imap preserves task order, so rows come out grouped by (Year, Month).
        final_df = pd.concat(chunks, ignore_index=True)

        final_df.to_csv(output_csv, index=False)
        logger.info(f"Saved: {output_csv} ({len(final_df)} rows)")

    except Exception as e:
        logger.error(f"Processing error: {str(e)}", exc_info=True)

def main():
    """Run the wind enrichment for every depth in DEPTHS, one depth after another."""
    logs_path = DOXY_BASE / "logs"
    logs_path.mkdir(exist_ok=True)

    run_log = setup_logger("main", logs_path / "main.log")
    run_log.info("Program started")

    # The file index is built once and shared by all depth levels.
    wind_index = build_ccmp_index()
    run_log.info(f"Indexed {len(wind_index)} wind files")

    for level in DEPTHS:
        run_log.info(f"Start processing {level}dbar")
        process_single_depth(level, wind_index)

    run_log.info("All processing completed")

# Entry point of the wind cell: runs main() when executed as the main module,
# which is the usual case when a notebook cell is run.
if __name__ == "__main__":
    main()
    print("Done! Please check the output directory.")

In [None]:
import xarray as xr
import pandas as pd
import numpy as np
from pathlib import Path
import logging
from tqdm import tqdm
from multiprocessing import Pool
import os

# ---------------------- Configuration ----------------------
# Folder scanned by build_aviso_index() for "dt_global_allsat_madt_h_*.nc" files.
AVISO_DIR = Path("/data/wang/AVISO/madt_h/")
# Root containing one "<depth>dbar" folder per depth level and the shared "logs" folder.
DOXY_BASE = Path("/data/wang/Result_Data/alldoxy/")
LAT_TOLERANCE = 0.4  # largest accepted |grid latitude - observation latitude|
LON_TOLERANCE = 0.2  # largest accepted |grid longitude - observation longitude| (after % 360)
# Depth levels to process; TARGET_DEPTHS is defined in the first notebook cell.
DEPTHS = TARGET_DEPTHS

# ---------------------- Logging ----------------------
def setup_logger(name, log_file):
    """Return the logger called ``name``, attaching file and console handlers on first use.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        log_file: Path of the file the file handler writes to. Missing parent
            directories are created.

    Returns:
        logging.Logger set to INFO. If the logger already has handlers it is
        returned as is, so ``log_file`` is ignored on later calls with the same name.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)

    # Avoid adding duplicate handlers
    if not logger.handlers:
        # FileHandler opens the file immediately and raises if its folder is
        # missing; main() asks for logs/main.log before anything creates logs/.
        Path(log_file).parent.mkdir(parents=True, exist_ok=True)

        formatter = logging.Formatter("%(asctime)s - %(levelname)s: %(message)s")

        file_handler = logging.FileHandler(log_file)
        file_handler.setFormatter(formatter)

        console_handler = logging.StreamHandler()
        console_handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s"))

        logger.addHandler(file_handler)
        logger.addHandler(console_handler)

    return logger

# ---------------------- Core utilities ----------------------
def build_aviso_index():
    """Map ``(year, month)`` to the AVISO ADT file found in AVISO_DIR.

    The year is read from characters 1-4 of the second-to-last underscore token of
    the file stem and the month from characters 1-2 of the last token (presumably
    tokens shaped like ``y2005`` / ``m01`` - verify against the real file names).
    Unparseable names are reported on stdout and skipped.

    Returns:
        dict keyed by ``(int year, int month)`` with ``Path`` values.
    """
    lookup = {}
    for path in AVISO_DIR.glob("dt_global_allsat_madt_h_*.nc"):
        try:
            tokens = path.stem.split("_")
            period = (int(tokens[-2][1:5]), int(tokens[-1][1:3]))
            lookup[period] = path
        except Exception as e:
            print(f"Filename parse error {path}: {str(e)}")
    return lookup

def find_nearest_idx(array, value):
    """Locate the entry of ``array`` with the smallest absolute distance to ``value``.

    Returns the index of that entry; on ties the lowest index wins.
    """
    gaps = np.abs(array - value)
    return gaps.argmin()

def process_group(args):
    """Attach the SSH column (variable ``adt``) to one (year, month) group.

    Args:
        args: ``((year, month), group, nc_file, log_file)``. ``group`` is the
            DataFrame slice for that month, ``nc_file`` the matching AVISO file
            (or None) and ``log_file`` the path handed to ``setup_logger``.

    Returns:
        Copy of ``group`` with an ``SSH`` column. A row is filled only when the
        nearest grid node lies within LON_TOLERANCE / LAT_TOLERANCE; all other
        rows keep NaN. Never raises: failures are logged and the partially filled
        frame is returned.
    """
    (year, month), group, nc_file, log_file = args
    logger = setup_logger(f"process_{os.getpid()}", log_file)

    result_df = group.copy()
    result_df["SSH"] = np.nan

    try:
        if nc_file is None or not nc_file.exists():
            # logger.warning(f"SSH data missing: {year}-{month}")
            return result_df

        with xr.open_dataset(nc_file) as ds:
            lon_array = ds.longitude.values
            lat_array = ds.latitude.values
            ssh_data = ds["adt"].isel(time=0).values

            for idx, row in group.iterrows():
                try:
                    # Observations are mapped to [0, 360) before the grid lookup.
                    target_lon = row["Longitude"] % 360
                    target_lat = row["Latitude"]

                    lon_idx = find_nearest_idx(lon_array, target_lon)
                    lat_idx = find_nearest_idx(lat_array, target_lat)

                    # Positive "<=" form on purpose: a NaN coordinate makes both
                    # comparisons False, so the row is skipped. A negated
                    # "> tolerance" test would let NaN pass and, because argmin of
                    # an all-NaN distance array is 0, copy the value at grid index 0.
                    within_tolerance = (
                        abs(lon_array[lon_idx] - target_lon) <= LON_TOLERANCE and
                        abs(lat_array[lat_idx] - target_lat) <= LAT_TOLERANCE
                    )
                    if not within_tolerance:
                        continue

                    result_df.at[idx, "SSH"] = float(ssh_data[lat_idx, lon_idx])

                except Exception as e:
                    # One bad row must not abort the rest of the group.
                    logger.debug(f"Row {idx} processing error: {str(e)}")

        return result_df

    except Exception as e:
        logger.error(f"Group processing failed {year}-{month}: {str(e)}")
        return result_df

# ---------------------- Main processing ----------------------
def process_single_depth(target_depth, aviso_index):
    """Add the SSH column to the training CSV of one depth level, in place.

    Args:
        target_depth: Depth level used to build folder and file names.
        aviso_index: Mapping ``(year, month) -> Path`` from ``build_aviso_index``.

    Returns:
        None. Missing/empty inputs and an empty task list are logged as warnings;
        any other error is logged with its traceback and not re-raised.
    """
    log_dir = DOXY_BASE / "logs"
    log_dir.mkdir(exist_ok=True)
    logger = setup_logger(f"depth{target_depth}", log_dir / f"depth{target_depth}.log")

    try:
        input_csv = DOXY_BASE / f"{target_depth}dbar" / f"depth{target_depth}_TRAIN.csv"

        # Nothing to do without a non-empty input file.
        if not input_csv.exists():
            logger.warning(f"Input file not found: {input_csv}")
            return
        if os.path.getsize(input_csv) == 0:
            logger.warning(f"Empty file: {input_csv}")
            return

        df = pd.read_csv(input_csv)

        # One task per (Year, Month) group, each carrying its own log file path.
        task_args = [
            (
                (year, month),
                group,
                aviso_index.get((year, month)),
                log_dir / f"process_{year}_{month}.log",
            )
            for (year, month), group in df.groupby(["Year", "Month"])
        ]

        if not task_args:
            logger.warning("No valid tasks to process")
            return

        results = []
        progress_label = f"Depth {target_depth} dbar"
        with Pool(processes=48) as pool, tqdm(total=len(task_args), desc=progress_label) as pbar:
            for enriched in pool.imap(process_group, task_args):
                results.append(enriched)
                pbar.update()

        # Rows come back grouped by (Year, Month); the index is rebuilt from 0.
        final_df = pd.concat(results, ignore_index=True)
        final_df["SSH"] = final_df["SSH"].round(3)  # keep 3 decimals for the new variable

        # The input file itself is replaced with the enriched table.
        final_df.to_csv(input_csv, index=False)
        logger.info(f"Updated file: {input_csv} ({len(final_df)} rows)")

    except Exception as e:
        logger.error(f"Processing error: {str(e)}", exc_info=True)

# ---------------------- Driver ----------------------
def main():
    """Run the SSH enrichment for every depth in DEPTHS, one depth after another."""
    log_dir = DOXY_BASE / "logs"
    # The main logger opens logs/main.log immediately, so the folder has to exist
    # before setup_logger is called (process_single_depth creates it only later).
    log_dir.mkdir(parents=True, exist_ok=True)

    # Initialize global logger
    main_logger = setup_logger("main", log_dir / "main.log")

    # Build file index once; it is shared by all depth levels.
    aviso_index = build_aviso_index()
    main_logger.info(f"Indexed {len(aviso_index)} SSH files")

    # Process in depth order
    for depth in DEPTHS:
        main_logger.info(f"Start processing depth: {depth} dbar")
        process_single_depth(depth, aviso_index)

    main_logger.info("All processing finished")

# Entry point of the SSH cell: runs main() when executed as the main module,
# which is the usual case when a notebook cell is run. The per-depth CSVs are rewritten in place.
if __name__ == "__main__":
    main()
    print("Done! Please verify that the original files have been updated.")

In [None]:
import xarray as xr
import pandas as pd
import numpy as np
from pathlib import Path
import logging
from tqdm import tqdm
from multiprocessing import Pool
import os

# ---------------------- Configuration ----------------------
# Folder scanned by build_aviso_index() for "dt_global_allsat_eke_*.nc" files.
aviso_DIR = Path("/data/wang/AVISO/eke/")
# Root containing one "<depth>dbar" folder per depth level and the shared "logs" folder.
DOXY_BASE = Path("/data/wang/Result_Data/alldoxy/")
LAT_TOLERANCE = 0.4  # latitude tolerance (degrees)
LON_TOLERANCE = 0.2  # longitude tolerance (degrees)
# Depth levels to process; TARGET_DEPTHS is defined in the first notebook cell.
DEPTHS = TARGET_DEPTHS

# ---------------------- Logging ----------------------
def setup_logger(name, log_file):
    """Return the logger called ``name``, attaching a file handler on first use.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        log_file: Path of the log file. Missing parent directories are created.

    Returns:
        logging.Logger set to INFO. If the logger already has handlers it is
        returned unchanged and ``log_file`` is ignored.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)

    if not logger.handlers:
        # FileHandler opens the file immediately and raises if its folder is missing.
        Path(log_file).parent.mkdir(parents=True, exist_ok=True)

        file_handler = logging.FileHandler(log_file)
        file_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s: %(message)s"))
        logger.addHandler(file_handler)

    return logger

# ---------------------- Core utilities ----------------------
def build_aviso_index():
    """Map ``(year, month)`` to the EKE NetCDF file found in aviso_DIR.

    Year and month are sliced out of the last two underscore tokens of the file
    stem (characters 1-4 and 1-2 respectively). Files whose name cannot be parsed
    are skipped without any message.

    Returns:
        dict keyed by ``(int year, int month)`` with ``Path`` values.
    """
    found = {}
    for path in aviso_DIR.glob("dt_global_allsat_eke_*.nc"):
        tokens = path.stem.split("_")
        try:
            period = (int(tokens[-2][1:5]), int(tokens[-1][1:3]))
        except Exception:
            continue  # silently skip filename parsing errors
        found[period] = path
    return found

def find_nearest_idx(array, value):
    """Index of the grid value nearest to ``value`` (plain argmin over absolute differences)."""
    offsets = np.abs(array - value)
    return offsets.argmin()

def process_group(args):
    """Attach the EKE column to one (year, month) group of observations.

    Args:
        args: ``((year, month), group, nc_file)``. ``group`` is the DataFrame slice
            for that month and ``nc_file`` the matching EKE file (or None).

    Returns:
        Copy of ``group`` with an ``EKE`` column. A row is filled only when the
        nearest grid node lies within LON_TOLERANCE / LAT_TOLERANCE; all other
        rows keep NaN. Never raises: if reading the file fails, the rows filled so
        far are returned and the failure is reported as a warning.
    """
    (year, month), group, nc_file = args
    result_df = group.copy()
    result_df["EKE"] = np.nan  # initialize as NaN

    if nc_file is None or not nc_file.exists():
        return result_df

    try:
        with xr.open_dataset(nc_file) as ds:
            lon_array = ds.longitude.values
            lat_array = ds.latitude.values
            eke_data = ds["eke"].isel(time=0).values  # assume time dimension is correct

            for idx, row in group.iterrows():
                # Normalize longitude to [0, 360)
                target_lon = row["Longitude"] % 360
                target_lat = row["Latitude"]

                # Find nearest grid point
                lon_idx = find_nearest_idx(lon_array, target_lon)
                lat_idx = find_nearest_idx(lat_array, target_lat)

                # Tolerance check (NaN coordinates fail both comparisons and are skipped)
                if (abs(lon_array[lon_idx] - target_lon) <= LON_TOLERANCE and
                    abs(lat_array[lat_idx] - target_lat) <= LAT_TOLERANCE):
                    result_df.at[idx, "EKE"] = eke_data[lat_idx, lon_idx]

    except Exception as exc:
        # Still best effort, but no longer invisible: an unreadable file or a
        # missing variable would otherwise leave an all-NaN column with no trace.
        logging.getLogger(f"process_{os.getpid()}").warning(
            f"Group processing failed {year}-{month} ({nc_file}): {exc}"
        )

    return result_df

# ---------------------- Main workflow ----------------------
def process_single_depth(target_depth, aviso_index):
    """Add (or refresh) the EKE column of one depth level's training CSV, in place.

    Args:
        target_depth: Depth level used to build folder and file names.
        aviso_index: Mapping ``(year, month) -> Path`` from ``build_aviso_index``.

    Returns:
        None. Missing/empty inputs and an empty task list are logged as warnings;
        any other error is logged with its traceback and not re-raised.
    """
    # Configure logging
    log_file = DOXY_BASE / "logs" / f"depth{target_depth}.log"
    logger = setup_logger(f"depth{target_depth}", log_file)

    try:
        # Build file paths
        depth_dir = DOXY_BASE / f"{target_depth}dbar"
        input_csv = depth_dir / f"depth{target_depth}_TRAIN.csv"  # adjust if naming differs

        # Validate input file
        if not input_csv.exists():
            logger.warning(f"Input file not found: {input_csv}")
            return
        if os.path.getsize(input_csv) == 0:
            logger.warning(f"Empty file: {input_csv}")
            return

        # Load data; a previous EKE column is discarded so the cell can be re-run.
        df = pd.read_csv(input_csv)
        if "EKE" in df.columns:
            df = df.drop(columns=["EKE"])

        # NOTE(review): groupby skips rows whose Year or Month is NaN (pandas default),
        # so such rows do not appear in the rewritten file - confirm this is intended.
        task_args = [
            ((year, month), group, aviso_index.get((year, month)))
            for (year, month), group in df.groupby(["Year", "Month"])
        ]

        # A header-only CSV produces no groups; pd.concat([]) would raise, so stop
        # before spawning any worker and leave the file untouched.
        if not task_args:
            logger.warning(f"No (Year, Month) groups in {input_csv}; nothing written")
            return

        # Never start more workers than there are groups to process.
        with Pool(processes=min(12, len(task_args))) as pool:
            results = []
            with tqdm(total=len(task_args), desc=f"Depth {target_depth}dbar", leave=False) as pbar:
                for res in pool.imap(process_group, task_args):
                    results.append(res)
                    pbar.update()

        # Groups keep their original index labels, so sort_index restores input row order.
        final_df = pd.concat(results).sort_index()
        final_df["EKE"] = final_df["EKE"].round(3)  # keep 3 decimals

        # Overwrite the original file
        final_df.to_csv(input_csv, index=False)
        logger.info(f"Wrote {len(final_df)} rows to {input_csv}")

    except Exception as e:
        logger.error(f"Processing error: {str(e)}", exc_info=True)

def main():
    """Run the EKE enrichment for every depth in DEPTHS, one depth after another."""
    # The file index is built once and shared by all depth levels.
    eke_files = build_aviso_index()

    # Per-depth loggers write into this folder, so it must exist before the loop.
    logs_path = DOXY_BASE / "logs"
    logs_path.mkdir(exist_ok=True)

    for level in DEPTHS:
        process_single_depth(level, eke_files)

# Entry point of the EKE cell: runs main() when executed as the main module,
# which is the usual case when a notebook cell is run.
if __name__ == "__main__":
    main()
    print("Processing completed. Please check the output files.")

In [None]:
import xarray as xr
import pandas as pd
import numpy as np
from pathlib import Path
import logging
from tqdm import tqdm
from multiprocessing import Pool
import os
import re

# ---------------------- Configuration ----------------------
# Folder scanned by build_par_index() for SeaWiFS (9km) and Aqua-MODIS (4km) monthly PAR files.
PAR_DIR = Path("/data/wang/NASA/PAR/")
# Root containing one "<depth>dbar" folder per depth level and the "logs_par" folder.
DOXY_BASE = Path("/data/wang/Result_Data/alldoxy/")
# Largest accepted |grid latitude - observation latitude|.
LAT_TOLERANCE = 0.4
# Largest accepted |grid longitude - observation longitude| (longitudes are compared as given).
LON_TOLERANCE = 0.2
# Depth levels to process; TARGET_DEPTHS is defined in the first notebook cell.
DEPTHS = TARGET_DEPTHS

# ---------------------- Logging ----------------------
def setup_logger(name, log_file):
    """Return the logger called ``name`` writing to both ``log_file`` and the console.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        log_file: ``Path`` of the log file; its parent directories are created.

    Returns:
        logging.Logger set to INFO. Handlers are attached only if the logger has
        none yet, so ``log_file`` is ignored on later calls with the same name.
    """
    log_file.parent.mkdir(parents=True, exist_ok=True)
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    if not logger.handlers:
        file_handler = logging.FileHandler(log_file)
        file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s: %(message)s'))
        console_handler = logging.StreamHandler()
        # The level name appears once per console line ("INFO: message").
        console_handler.setFormatter(logging.Formatter('%(levelname)s: %(message)s'))
        logger.addHandler(file_handler)
        logger.addHandler(console_handler)
    return logger

# ---------------------- Core processing ----------------------
def build_par_index():
    """Index the monthly PAR files in PAR_DIR by ``(year, month)``, one dict per sensor.

    The YYYYMM stamp is taken from the first six digits of the ``.YYYYMMDD_YYYYMMDD``
    part of the file name. SeaWiFS files are kept for 1997-2002 and Aqua-MODIS
    files for 2003-2024; files outside those years or without a stamp are ignored.

    Returns:
        ``(seawifs_index, modis_index)``, each a dict ``(year, month) -> Path``.
    """
    stamp_re = re.compile(r"\.(\d{6})\d{2}_\d{8}")

    def scan(glob_pattern, first_year, last_year, label):
        # Collect the files of one sensor whose year falls inside its era.
        found = {}
        for path in PAR_DIR.glob(glob_pattern):
            try:
                hit = stamp_re.search(path.name)
                if hit is None:
                    continue
                stamp = hit.group(1)
                year, month = int(stamp[:4]), int(stamp[4:6])
                if first_year <= year <= last_year:
                    found[(year, month)] = path
            except Exception as e:
                print(f"{label} index build error: {str(e)}")
        return found

    # Build SEAWIFS index (1997–2002), then MODIS index (2003–2024)
    seawifs_index = scan("SEASTAR_SEAWIFS_GAC.*.L3m.MO.PAR.par.9km.nc", 1997, 2002, "SEAWIFS")
    modis_index = scan("AQUA_MODIS.*.L3m.MO.PAR.par.4km.nc", 2003, 2024, "MODIS")

    return seawifs_index, modis_index

def find_nearest_idx(array, value):
    """Return the index of the entry of the 1D ``array`` nearest to ``value`` (lowest index on ties)."""
    separation = np.abs(array - value)
    return separation.argmin()

def process_group_par(args):
    """Attach the PAR column to one (year, month) group of observations.

    Args:
        args: ``((year, month), group, seawifs_idx, modis_idx, log_file)``. The two
            index dicts map ``(year, month)`` to a file; years before 2003 are
            looked up in ``seawifs_idx``, all others in ``modis_idx``.

    Returns:
        Copy of ``group`` with a ``PAR`` column. A row is filled only when the
        nearest grid node lies within LON_TOLERANCE / LAT_TOLERANCE and the grid
        value is not NaN; all other rows keep NaN. Never raises: failures are
        logged and the partially filled frame is returned.
    """
    (year, month), group, seawifs_idx, modis_idx, log_file = args
    logger = setup_logger(f"process_{os.getpid()}", log_file)

    result_df = group.copy()
    result_df['PAR'] = np.nan

    try:
        # Dynamically choose sensor by year
        source_index = seawifs_idx if year < 2003 else modis_idx
        nc_file = source_index.get((year, month))

        if nc_file is None or not nc_file.exists():
            return result_df

        with xr.open_dataset(nc_file) as ds:
            lons = ds.lon.values.astype(float)
            lats = ds.lat.values.astype(float)
            par_data = ds.par.values

            # Ensure shape is (time, lat, lon)
            if par_data.ndim == 2:
                par_data = par_data[np.newaxis, :, :]

            for idx, row in group.iterrows():
                target_lat = row['Latitude']
                # Longitude is used as given; no 0-360 conversion is applied here.
                target_lon = row['Longitude']

                lon_idx = find_nearest_idx(lons, target_lon)
                lat_idx = find_nearest_idx(lats, target_lat)

                # Positive "<=" form on purpose: a NaN coordinate makes both
                # comparisons False, so the row is skipped. A negated
                # "> tolerance" test would let NaN pass and, because argmin of an
                # all-NaN distance array is 0, copy the value at grid index 0.
                within_tolerance = (
                    abs(lons[lon_idx] - target_lon) <= LON_TOLERANCE and
                    abs(lats[lat_idx] - target_lat) <= LAT_TOLERANCE
                )
                if not within_tolerance:
                    continue

                par_value = par_data[0, lat_idx, lon_idx]
                if not np.isnan(par_value):
                    result_df.at[idx, 'PAR'] = round(par_value, 3)  # keep 3 decimals

        return result_df

    except Exception as e:
        logger.error(f"Group processing failed {year}-{month}: {str(e)}")
        return result_df

# ---------------------- Main workflow ----------------------
def process_single_depth(target_depth, seawifs_index, modis_index):
    """Add the PAR column to the training CSV of one depth level, in place.

    Args:
        target_depth: Depth level used to build folder and file names.
        seawifs_index: ``(year, month) -> Path`` dict for SeaWiFS files.
        modis_index: ``(year, month) -> Path`` dict for MODIS files.

    Returns:
        None. Missing/empty inputs are logged as warnings; any other error is
        logged (message only) and not re-raised.
    """
    par_logs = DOXY_BASE / "logs_par"
    logger = setup_logger(f"depth{target_depth}", par_logs / f"depth{target_depth}.log")

    try:
        input_csv = DOXY_BASE / f"{target_depth}dbar" / f"depth{target_depth}_TRAIN.csv"

        if not input_csv.exists():
            logger.warning(f"Input file not found: {input_csv}")
            return

        df = pd.read_csv(input_csv)
        if df.empty:
            logger.warning(f"Empty dataset: {input_csv}")
            return

        # One task per (Year, Month); both index dicts travel with every task.
        task_args = [
            (
                (year, month),
                group,
                seawifs_index,
                modis_index,
                par_logs / f"par_{year}_{month}.log",
            )
            for (year, month), group in df.groupby(['Year', 'Month'])
        ]

        with Pool(processes=12) as pool:
            results = []
            with tqdm(total=len(task_args), desc=f"Depth {target_depth}dbar") as pbar:
                for enriched in pool.imap(process_group_par, task_args):
                    results.append(enriched)
                    pbar.update(1)

            # Groups keep their original index labels, so sorting restores input row order.
            final_df = pd.concat(results).sort_index()

            # The input file itself is replaced with the enriched table.
            final_df.to_csv(input_csv, index=False)
            logger.info(f"File updated: {input_csv}")

    except Exception as e:
        logger.error(f"Depth-layer processing error: {str(e)}")

def main():
    """Run the PAR enrichment for every depth in DEPTHS, one depth after another."""
    run_log = setup_logger("main_par", DOXY_BASE / "logs_par" / "main.log")

    # Both sensor indices are built once and shared by all depth levels.
    seawifs_files, modis_files = build_par_index()
    run_log.info(f"Index loaded: SEAWIFS({len(seawifs_files)}), MODIS({len(modis_files)})")

    for level in DEPTHS:
        run_log.info(f"Start processing depth: {level}dbar")
        process_single_depth(level, seawifs_files, modis_files)

    run_log.info("All tasks completed")

# Entry point of the PAR cell: runs main() when executed as the main module,
# which is the usual case when a notebook cell is run.
if __name__ == "__main__":
    main()
    print("Done! Please check log files.")

In [None]:
import xarray as xr
import pandas as pd
import numpy as np
from pathlib import Path
import logging
import re
from tqdm import tqdm
from multiprocessing import Pool
import os

# ---------------------- Configuration ----------------------
# Folder scanned by CO2Processor for "cmems_obs-mob_glo_bgc-car_my_irr-i_<YYYYMM>.nc" files.
CMEMS_DIR = Path("/data/wang/CMEMS/pco2/")
# Root containing one "<depth>dbar" folder per depth level and the shared "logs" folder.
DOXY_BASE = Path("/data/wang/Result_Data/alldoxy/")
# Largest accepted |grid latitude - observation latitude|.
LAT_TOLERANCE = 0.4
# Largest accepted |grid longitude - observation longitude| (after conversion to [0, 360)).
LON_TOLERANCE = 0.2
# Depth levels to process; TARGET_DEPTHS is defined in the first notebook cell.
DEPTHS = TARGET_DEPTHS

# Parallel lists: CSV column CO2_VARS[i] is filled from NetCDF variable NC_VARS[i].
CO2_VARS = ["CO2_flux", "pH", "pCO2", "DIC", "Alkalinity"]
NC_VARS = ["fgco2", "ph", "spco2", "tco2", "talk"]

# ---------------------- Logging ----------------------
def setup_logger(name, log_dir):
    """Return the logger called ``name``, logging to ``log_dir/<name>.log`` when possible.

    On first use the log directory is created and a file handler attached. If that
    fails for any reason the problem is printed and a console handler is attached
    instead, so the logger always ends up with exactly one handler.

    Args:
        name: Logger name; also the stem of the log file.
        log_dir: ``Path`` of the directory that receives the log file.

    Returns:
        logging.Logger set to INFO. A logger that already has handlers is returned as is.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)

    if not logger.handlers:
        try:
            log_dir.mkdir(parents=True, exist_ok=True)
            to_file = logging.FileHandler(log_dir / f"{name}.log")
            to_file.setFormatter(
                logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
            )
            logger.addHandler(to_file)
        except PermissionError:
            print(f"Permission denied: {log_dir}. Falling back to console logging.")
        except Exception as e:
            print(f"Logger initialization error: {str(e)}")

        # Still no handler means the file could not be set up: fall back to the console.
        if not logger.handlers:
            to_console = logging.StreamHandler()
            to_console.setFormatter(
                logging.Formatter("%(levelname)s - %(message)s")
            )
            logger.addHandler(to_console)

    return logger

# ---------------------- Core processor ----------------------
class CO2Processor:
    """Fill the CO2_VARS columns of the per-depth training CSVs from monthly CMEMS files.

    An instance holds a logger and a ``(year, month) -> Path`` index of the NetCDF
    files found in CMEMS_DIR. ``process_depth`` is the public entry point; it
    rewrites ``depth<N>_TRAIN.csv`` in place.
    """

    def __init__(self):
        """Create the logger and build the file index; re-raises if indexing fails."""
        self.logger = setup_logger("CO2Processor", DOXY_BASE / "logs")
        try:
            self.file_index = self._build_file_index()
            self.logger.info("CO2 processor initialized successfully")
        except Exception as e:
            self.logger.error(f"Initialization failed: {str(e)}", exc_info=True)
            raise

    def _build_file_index(self):
        """Build a file index dict keyed by (year, month) -> nc file.

        Only files in CMEMS_DIR whose full name matches
        ``cmems_obs-mob_glo_bgc-car_my_irr-i_<YYYYMM>.nc`` are indexed.
        """
        index = {}
        pattern = re.compile(r"cmems_obs-mob_glo_bgc-car_my_irr-i_(\d{6})\.nc$")
        try:
            for nc_file in CMEMS_DIR.glob("*.nc"):
                if not nc_file.is_file():
                    continue
                match = pattern.match(nc_file.name)
                if match:
                    yyyymm = match.group(1)
                    year = int(yyyymm[:4])
                    month = int(yyyymm[4:6])
                    index[(year, month)] = nc_file
                    self.logger.debug(f"Indexed file: {year}-{month:02d}")
        except Exception as e:
            self.logger.error(f"Failed to build file index: {str(e)}")
            raise
        return index

    @staticmethod
    def _period_label(year, month):
        """Format a group key as 'YYYY-MM' for log messages, whatever its numeric dtype."""
        try:
            return f"{int(year)}-{int(month):02d}"
        except (TypeError, ValueError):
            # NaN or non-numeric keys: fall back to their plain text form.
            return f"{year}-{month}"

    def _lon_360(self, lon):
        """Convert longitude to [0, 360); non-numeric input yields NaN."""
        try:
            return lon % 360
        except TypeError:
            return np.nan

    def _find_nearest_idx(self, array, value):
        """Return the nearest index for value in array, or -1 if the lookup fails.

        A NaN ``value`` makes every distance NaN, which ``np.nanargmin`` rejects;
        that case is logged and reported as -1.
        """
        try:
            return np.nanargmin(np.abs(array - value))
        except Exception as e:
            self.logger.warning(f"Coordinate lookup failed: {str(e)}")
            return -1

    def process_depth(self, target_depth):
        """Fill the CO2 columns of one depth level's training CSV and rewrite it in place.

        Args:
            target_depth: Depth level used to build folder and file names.

        Returns:
            None. Read and write failures are logged, not raised.
        """
        input_dir = DOXY_BASE / f"{target_depth}dbar"
        input_csv = input_dir / f"depth{target_depth}_TRAIN.csv"

        if not self._validate_input(input_csv):
            return

        try:
            df = pd.read_csv(input_csv)
            for col in CO2_VARS:
                df[col] = np.nan  # initialize new columns
        except Exception as e:
            self.logger.error(f"Failed to read CSV: {str(e)}")
            return

        # Parallel processing by (year, month) groups
        groups = list(df.groupby(["Year", "Month"]))
        if not groups:
            # Header-only CSV: pd.concat([]) would raise, so leave the file untouched.
            self.logger.warning(f"No (Year, Month) groups in {input_csv}; nothing written")
            return

        with Pool(processes=48) as pool:
            processed_groups = list(
                tqdm(
                    pool.imap(self._process_group_wrapper, groups),
                    total=len(groups),
                    desc=f"Processing {target_depth} dbar",
                    unit="group",
                )
            )

        # Merge results (the wrapper hands failed groups back unmodified, so no rows are lost)
        final_df = pd.concat([pg for pg in processed_groups if pg is not None])

        try:
            # Apply rounding only to the newly added variables
            final_df[CO2_VARS] = final_df[CO2_VARS].round(3)

            # Overwrite the original file (write all columns once)
            final_df.to_csv(input_csv, index=False)
            self.logger.info(f"Successfully overwrote file: {input_csv}")
        except Exception as e:
            self.logger.error(f"Failed to write file: {str(e)}")

    def _validate_input(self, input_csv):
        """Check that the input CSV exists and is non-empty."""
        if not input_csv.exists():
            self.logger.error(f"Input file not found: {input_csv}")
            return False
        if input_csv.stat().st_size == 0:
            self.logger.warning(f"Empty file: {input_csv}")
            return False
        return True

    def _process_group_wrapper(self, args):
        """Run ``_process_group`` on ``(year_month, group_df)`` without letting it raise.

        On failure the error is logged and the group is returned as received, so
        its rows stay in the rewritten CSV with NaN in the CO2 columns. None is
        returned only when ``args`` itself is malformed.
        """
        try:
            return self._process_group(*args)
        except Exception as e:
            self.logger.error(f"Group processing exception: {str(e)}")
            try:
                return args[1]
            except (TypeError, IndexError, KeyError):
                return None

    def _process_group(self, year_month, group_df):
        """Process one (year, month) group and fill CO2-related variables.

        Args:
            year_month: ``(year, month)`` key of the group.
            group_df: Rows of that month; must have Latitude and Longitude columns.

        Returns:
            A copy of ``group_df`` with the CO2_VARS columns filled where the nearest
            grid node is within tolerance, or ``group_df`` itself (unchanged) when
            the NetCDF file is missing or cannot be processed.
        """
        year, month = year_month
        label = self._period_label(year, month)
        nc_file = self.file_index.get((year, month))
        if not nc_file or not nc_file.exists():
            self.logger.warning(f"Missing NC file: {label}")
            return group_df

        try:
            with xr.open_dataset(nc_file) as ds:
                lats = ds.latitude.values.astype("float32")
                lons = ds.longitude.values.astype("float32")

                target_lons = group_df["Longitude"].apply(self._lon_360).values
                target_lats = group_df["Latitude"].values

                lat_indices = np.array([self._find_nearest_idx(lats, lat) for lat in target_lats])
                lon_indices = np.array([self._find_nearest_idx(lons, lon) for lon in target_lons])

                # -1 is the "lookup failed" sentinel; exclude it explicitly instead of
                # relying on it indexing the last grid node.
                valid_mask = (
                    (lat_indices >= 0) &
                    (lon_indices >= 0) &
                    (np.abs(lats[lat_indices] - target_lats) <= LAT_TOLERANCE) &
                    (np.abs(lons[lon_indices] - target_lons) <= LON_TOLERANCE)
                )
                valid_idx = np.where(valid_mask)[0]

                results = np.full((len(group_df), len(CO2_VARS)), np.nan, dtype=np.float32)

                if len(valid_idx) > 0:
                    for var_idx, var in enumerate(NC_VARS):
                        try:
                            # First slice along the leading dimension (presumably time).
                            var_data = ds[var][0].values
                            valid_values = var_data[lat_indices[valid_idx], lon_indices[valid_idx]]
                            results[valid_idx, var_idx] = valid_values
                        except Exception as e:
                            self.logger.error(f"Failed to extract variable {var}: {str(e)}")

                # Work on a copy so the caller's group is never modified in place.
                filled = group_df.copy()
                filled[CO2_VARS] = results
                return filled

        except Exception as e:
            self.logger.error(f"Processing failed {label}: {str(e)}")
            return group_df

# ---------------------- Entry point ----------------------
# Entry point of the CO2 cell: build the processor once, then handle each depth in turn.
if __name__ == "__main__":
    try:
        processor = CO2Processor()
        # Process depths in order; internally parallelize by (year, month) groups
        for depth in tqdm(DEPTHS, desc="Depth processing progress"):
            processor.process_depth(depth)
        print("Batch processing finished!")
    except Exception as e:
        # An exception escaping the loop stops the remaining depths; it is reported on the "Main" logger.
        logging.getLogger("Main").critical(f"Batch processing failed: {str(e)}", exc_info=True)

In [None]:
# Chla
import xarray as xr
import pandas as pd
import numpy as np
from pathlib import Path
import logging
import re
from tqdm import tqdm
from multiprocessing import Pool
import os

# ---------------------- Configuration ----------------------
# NOTE(review): LAT_TOLERANCE, LON_TOLERANCE and DEPTHS are notebook-level names that the CO2
# cell above also reads; running this cell rebinds them for any later re-run of that cell.

# Directory scanned (non-recursively, "*.nc") for the monthly chlorophyll-a NetCDF files.
CHLA_DIR = Path("/data/wang/NASA/Chla/")
# Root of the per-depth folders ("<depth>dbar/depth<depth>_TRAIN.csv"); logs go to "logs_chla" below it.
DOXY_BASE = Path("/data/wang/Result_Data/alldoxy/")
# Largest accepted |grid latitude - sample latitude| for a nearest-cell match
# (presumably degrees -- confirm against the CSV / NetCDF coordinate units).
LAT_TOLERANCE = 0.4
# Largest accepted |grid longitude - sample longitude| for a nearest-cell match (same unit caveat).
LON_TOLERANCE = 0.2
# Depth levels to process; TARGET_DEPTHS is defined in the first cell of the notebook.
DEPTHS = TARGET_DEPTHS
N_PROCESSES = 48  # number of worker processes

# ---------------------- Logging ----------------------
def setup_logger(name, log_dir):
    """Return the logger called ``name``, attaching its handlers on first use.

    On the first call for a given ``name`` (i.e. while the logger has no
    handlers) two handlers are added:

    * a ``FileHandler`` writing to ``<log_dir>/<name>.log`` with timestamps;
    * a ``StreamHandler`` (default stream) with a shorter format.

    Later calls with the same ``name`` return the same logger unchanged, so
    handlers are never duplicated within one process.

    Args:
        name: Logger name; also used as the log file's stem.
        log_dir: Directory for the log file, as ``str``, ``os.PathLike`` or
            ``Path``. It is created (including parents) if missing.

    Returns:
        logging.Logger: The configured logger, level set to ``INFO``.

    Note:
        Nothing here synchronises writes between processes; several
        processes using the same ``name`` each open the same file
        independently.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)

    if not logger.handlers:
        # Coerce so that plain string paths work as well as Path objects.
        log_dir = Path(log_dir)
        log_dir.mkdir(parents=True, exist_ok=True)
        log_file = log_dir / f"{name}.log"

        file_handler = logging.FileHandler(log_file)
        file_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s: %(message)s"))

        console_handler = logging.StreamHandler()
        console_handler.setFormatter(logging.Formatter("%(levelname)s: %(message)s"))

        logger.addHandler(file_handler)
        logger.addHandler(console_handler)

    return logger

# ---------------------- File index builder ----------------------
class ChlaIndexer:
    """Index monthly chlorophyll-a NetCDF files by ``(year, month)``, one dict per sensor.

    Attributes:
        chla_dir: Directory that was scanned.
        seawifs_index: ``{(year, month): Path}`` for SeaWiFS-named files dated 1997-2002.
        modis_index: ``{(year, month): Path}`` for Aqua-MODIS-named files dated 2003-2024.

    A file is stored only when BOTH its filename pattern and its year belong to
    the same era. Previously the bucket was chosen from the year alone, so a
    SeaWiFS file dated 2003+ could replace the MODIS entry for that month (and
    a 2002 MODIS file the SeaWiFS entry), depending on directory listing order.
    """

    # (sensor key, filename regex). Group 1 is the first six digits after the
    # sensor prefix, read as YYYYMM.
    _SENSOR_PATTERNS = (
        ("seawifs", r"SEASTAR_SEAWIFS_GAC\.(\d{6})\d{2}_\d+\.L3m\.MO\.CHL\.chlor_a\.par\.9km\.nc"),
        ("modis", r"AQUA_MODIS\.(\d{6})\d{2}_\d+\.L3m\.MO\.CHL\.chlor_a\.4km\.nc"),
    )

    def __init__(self, chla_dir=None):
        """Scan ``chla_dir`` (default: the module-level ``CHLA_DIR``) and build both indexes."""
        self.chla_dir = Path(CHLA_DIR if chla_dir is None else chla_dir)
        self.seawifs_index = {}  # 1997-2002
        self.modis_index = {}    # 2003-2024
        self._build_index()

    def _classify(self, filename):
        """Return ``(sensor, year, month)`` for a recognised filename, else ``(None, None, None)``."""
        for sensor, pattern in self._SENSOR_PATTERNS:
            match = re.search(pattern, filename)
            if match:
                yyyymm = match.group(1)
                return sensor, int(yyyymm[:4]), int(yyyymm[4:6])
        return None, None, None

    def _parse_filename(self, filename):
        """Return ``(year, month)`` parsed from ``filename``, or ``(None, None)`` if unrecognised."""
        _, year, month = self._classify(filename)
        return year, month

    def _build_index(self):
        """Populate ``seawifs_index`` and ``modis_index`` from the ``*.nc`` files in ``chla_dir``."""
        # sorted(): if two files map to the same key, the winner no longer
        # depends on the filesystem's listing order.
        for nc_file in sorted(self.chla_dir.glob("*.nc")):
            if not nc_file.is_file():
                continue

            sensor, year, month = self._classify(nc_file.name)
            if sensor is None:
                continue

            # Store by sensor AND era; out-of-era files are ignored.
            if sensor == "seawifs" and 1997 <= year <= 2002:
                self.seawifs_index[(year, month)] = nc_file
            elif sensor == "modis" and 2003 <= year <= 2024:
                self.modis_index[(year, month)] = nc_file

# ---------------------- Core processing class ----------------------
class ChlaProcessor:
    """Fill the ``Chla`` column of each depth's TRAIN CSV from monthly chlor_a grids.

    For every (Year, Month) group of rows the matching NetCDF file is looked
    up in ``ChlaIndexer`` (SeaWiFS index for years <= 2002, MODIS otherwise),
    each row is assigned its nearest grid cell, and the cell value is kept
    when the cell lies within ``LAT_TOLERANCE`` / ``LON_TOLERANCE`` of the row.

    NOTE(review): sample longitudes are compared with ``ds.lon`` as-is; no
    0-360 / +-180 conversion happens here -- confirm both use one convention.
    """

    def __init__(self):
        """Create the shared logger and build the NetCDF file index."""
        self.logger = setup_logger("ChlaProcessor", DOXY_BASE / "logs_chla")
        self.indexer = ChlaIndexer()
        self.logger.info(
            f"Index summary: SEAWIFS({len(self.indexer.seawifs_index)}) | MODIS({len(self.indexer.modis_index)})"
        )

    def process_depth(self, target_depth):
        """Process a single depth level and rewrite its TRAIN CSV in place.

        Reads ``<DOXY_BASE>/<depth>dbar/depth<depth>_TRAIN.csv``, processes
        its (Year, Month) groups in a worker pool, and replaces the file with
        the result. Row order and row count of the input are preserved; rows
        whose Year or Month is missing are carried over untouched.

        Errors are logged to the per-depth logger and not raised.
        """
        processor_id = f"{target_depth}dbar_{os.getpid()}"
        logger = setup_logger(processor_id, DOXY_BASE / "logs_chla")

        try:
            input_csv = DOXY_BASE / f"{target_depth}dbar" / f"depth{target_depth}_TRAIN.csv"
            output_csv = input_csv  # overwrite in place

            if not self._validate_input(input_csv, logger):
                return

            df = pd.read_csv(input_csv)
            if df.empty:
                # Nothing to match; avoids pd.concat([]) raising further down.
                logger.warning(f"No data rows, nothing to do: {input_csv}")
                return
            if "Chla" not in df.columns:
                df["Chla"] = np.nan

            task_args, unkeyed = self._build_tasks(target_depth, df)
            if not unkeyed.empty:
                logger.warning(
                    f"{len(unkeyed)} rows have no Year/Month and keep their existing Chla"
                )

            # Parallel processing by (Year, Month)
            results = []
            if task_args:
                n_workers = max(1, min(N_PROCESSES, len(task_args)))
                with Pool(processes=n_workers) as pool:
                    results = list(tqdm(
                        pool.imap(self._process_year_month, task_args),
                        total=len(task_args),
                        desc=f"Depth {target_depth}dbar",
                        unit="group"
                    ))

            # Merge results, put rows back in their input order, and save
            pieces = [r for r in results if r is not None]
            if not unkeyed.empty:
                pieces.append(unkeyed)
            final_df = pd.concat(pieces).sort_index()
            if len(final_df) != len(df):
                # Never overwrite the input with fewer/more rows than it had.
                raise RuntimeError(
                    f"Row count changed from {len(df)} to {len(final_df)}; file left untouched"
                )
            final_df["Chla"] = final_df["Chla"].round(4)  # keep 4 decimals
            self._write_csv_atomically(final_df, output_csv)
            logger.info(f"Overwritten successfully: {output_csv}")

        except Exception as e:
            logger.error(f"Processing failed: {str(e)}", exc_info=True)
        finally:
            # The per-depth logger is not reused; release its log file handle.
            for handler in list(logger.handlers):
                handler.close()
                logger.removeHandler(handler)

    def _build_tasks(self, target_depth, df):
        """Split ``df`` into per-(Year, Month) worker tasks.

        Returns:
            tuple: ``(tasks, unkeyed)`` where ``tasks`` is a list of
            ``(target_depth, (year, month), group_df)`` and ``unkeyed`` holds
            the rows whose Year or Month is NaN (groupby would drop them).
        """
        keyed = df[["Year", "Month"]].notna().all(axis=1)
        groups = df[keyed].groupby(["Year", "Month"]).groups
        tasks = [(target_depth, ym, df.loc[idx]) for ym, idx in groups.items()]
        return tasks, df[~keyed]

    def _write_csv_atomically(self, frame, path):
        """Write ``frame`` to a sibling temp file, then rename it over ``path``."""
        tmp_path = path.with_name(f"{path.name}.{os.getpid()}.tmp")
        try:
            frame.to_csv(tmp_path, index=False)
            os.replace(tmp_path, path)
        finally:
            # Only present here if to_csv/replace failed part-way.
            if tmp_path.exists():
                tmp_path.unlink()

    def _process_year_month(self, args):
        """Process a single (Year, Month) group (runs in a worker process).

        Args:
            args: ``(target_depth, (year, month), group_df)``.

        Returns:
            ``group_df``. When a chlor_a grid was read, its ``Chla`` column is
            replaced (NaN where no cell is within tolerance). When no file or
            variable is available, or on error, it is returned unchanged.
        """
        target_depth, (year, month), group_df = args
        try:
            # Select NetCDF file by era
            if year <= 2002:
                nc_file = self.indexer.seawifs_index.get((year, month))
            else:
                nc_file = self.indexer.modis_index.get((year, month))

            if not nc_file or not nc_file.exists():
                return group_df

            with xr.open_dataset(nc_file) as ds:
                if "chlor_a" not in ds.variables:
                    self.logger.warning(f"No chlor_a variable in {nc_file}")
                    return group_df

                # .values materialises the arrays, so they outlive the dataset
                lats = ds.lat.values.astype(float)
                lons = ds.lon.values.astype(float)
                chla_data = ds["chlor_a"].values

            # Indexed below as [0, lat, lon]: assumes dimension order (lat, lon),
            # optionally preceded by a time axis -- TODO confirm for these files.
            if chla_data.ndim == 2:
                chla_data = chla_data[np.newaxis, :, :]

            target_lats = group_df["Latitude"].to_numpy(dtype=float)
            target_lons = group_df["Longitude"].to_numpy(dtype=float)
            lat_indices, lat_ok = self._match_axis(lats, target_lats, LAT_TOLERANCE)
            lon_indices, lon_ok = self._match_axis(lons, target_lons, LON_TOLERANCE)
            valid_idx = np.flatnonzero(lat_ok & lon_ok)

            # Extract values
            chla_values = np.full(len(group_df), np.nan, dtype=np.float32)
            if len(valid_idx) > 0:
                try:
                    chla_values[valid_idx] = chla_data[0, lat_indices[valid_idx], lon_indices[valid_idx]]
                except IndexError as e:
                    # Same outcome as before (group gets NaN), but no longer silent.
                    self.logger.warning(
                        f"Grid indexing failed for {year}-{month} ({nc_file}): {e}"
                    )

            group_df["Chla"] = chla_values
            return group_df

        except Exception as e:
            self.logger.error(
                f"Chla extraction failed for {year}-{month} at depth {target_depth}: {e}",
                exc_info=True,
            )
            return group_df

    def _match_axis(self, grid, targets, tolerance):
        """Nearest grid index per target plus a within-tolerance mask.

        Non-finite targets (NaN/inf) are marked invalid instead of being
        passed to ``_find_nearest``, where a NaN would make ``nanargmin``
        raise and abort the whole group.

        Returns:
            tuple: ``(indices, within)``; ``indices`` is 0 where ``within`` is
            False because the target was non-finite.
        """
        indices = np.zeros(len(targets), dtype=np.intp)
        within = np.zeros(len(targets), dtype=bool)
        finite = np.isfinite(targets)
        if finite.any():
            usable = targets[finite]
            nearest = np.array([self._find_nearest(grid, v) for v in usable], dtype=np.intp)
            indices[finite] = nearest
            within[finite] = np.abs(grid[nearest] - usable) <= tolerance
        return indices, within

    def _find_nearest(self, array, value):
        """Index of the element of ``array`` closest to ``value``, ignoring NaNs in ``array``."""
        return np.nanargmin(np.abs(array - value))

    def _validate_input(self, input_csv, logger):
        """Return True if ``input_csv`` exists and is non-empty; otherwise log a warning and return False."""
        if not input_csv.exists():
            logger.warning(f"Input file not found: {input_csv}")
            return False
        if input_csv.stat().st_size == 0:
            logger.warning(f"Empty file: {input_csv}")
            return False
        return True

# ---------------------- Main program ----------------------
def main():
    """Run the chlorophyll-a matching for every depth in ``DEPTHS``, one after another.

    A "ChlaMain" logger announces each depth before the shared
    ``ChlaProcessor`` handles it; a completion line is printed at the end.
    """
    log_dir = DOXY_BASE.joinpath("logs_chla")
    run_log = setup_logger("ChlaMain", log_dir)

    # One processor (and therefore one file index) serves all depths.
    chla = ChlaProcessor()
    for depth in DEPTHS:
        run_log.info(f"Start processing depth: {depth}dbar")
        chla.process_depth(depth)

    print("✅ Chlorophyll-a processing completed!")

# Run the Chla pipeline only when this code executes as the top-level module
# (a notebook cell or script), not when it is imported.
if __name__ == "__main__":
    main()