## Current ARCSIX Campaign Size Distribution Merge R1 Version

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from __future__ import annotations

import sys
from pathlib import Path
import traceback
from datetime import datetime, timezone

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from netCDF4 import Dataset  # kept because your pipeline imports it elsewhere

BASE_DIR = Path.cwd().parent
sys.path.append(str(BASE_DIR / "src"))

from merge_production import (  # noqa: E402
    load_aufi_oneday_v2,
    split_frames,
    plot_period_totals,
    make_filtered_specs_v2,
    run_joint_optimization_v2,
    plot_history,
    make_consensus_merged_spec_v2,
    plot_sizedist_all,
    write_day_netcdf_v2,
    chunk_is_incloud,
    filter_chunk_by_inlet_flag,
)
from sizedist_utils import remap_dndlog_by_edges_any  # noqa: E402
from ict_utils import read_inlet_flag, read_microphysical  # noqa: E402


# =============================================================================
# SETTINGS
# =============================================================================
# Flight days handled by this batch run, as ISO date strings.
dates = [
    # May 2024
    "2024-05-28",
    "2024-05-30",
    "2024-05-31",
    # June 2024
    "2024-06-03",
    "2024-06-05",
    "2024-06-06",
    "2024-06-07",
    "2024-06-10",
    "2024-06-11",
    "2024-06-13",
    # July 2024
    "2024-07-25",
    "2024-07-29",
    "2024-07-30",
    # August 2024
    "2024-08-01",
    "2024-08-02",
    "2024-08-07",
    "2024-08-08",
    "2024-08-09",
    "2024-08-15",
]

# Input tree: one sub-directory per data product under a common root.
DATA_DIR = Path("/Volumes/Hailstone Data/Research Data/ARCSIX_P3B")
aps_dir = DATA_DIR / "LARGE-APS"
uhsas_dir = DATA_DIR / "PUTLS-UHSAS"
fims_dir = DATA_DIR / "FIMS"
pops_dir = DATA_DIR / "PUTLS-POPS"
inlet_dir = DATA_DIR / "LARGE-InletFlag"
micro_dir = DATA_DIR / "LARGE-MICROPHYSICAL"

# Output tree; created up front so both log files can be opened in append mode.
OUTPUT_DIR = Path("/Users/C832577250/Output/arcsix_sizedist_merge_batch_v3")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

LOG_FILE = OUTPUT_DIR / "output_log.txt"
ERROR_LOG = OUTPUT_DIR / "error_log.txt"

# Append a run header (UTC timestamp, directories, date list) to the run log.
ts = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S %Z")
with LOG_FILE.open("a") as f:
    f.writelines(
        [
            "ARCSIX Aerosol Size Distribution Merge Log - VERSION 3 (4-Instrument)\n",
            f"Generated: {ts} by Bo Chen\n\n",
            f"DATA_DIR:      {DATA_DIR}\n",
            f"OUTPUT_DIR:    {OUTPUT_DIR}\n",
            f"DATES:         {dates}\n\n",
        ]
    )

MERGE_PER = 5                 # minutes per merge chunk
FIMS_LAG  = 10                # seconds (shift FIMS earlier)
INCLOUD_PAD_S = 10            # seconds around inlet_flag to mark a chunk in-cloud
MIN_SAMPLES_PER_INST = 50     # minimum non-zero average bin count per instrument

# --- overlap requirement ---
MIN_OVERLAP_S = 120           # minimum overlap among all instruments (seconds)
# Overlap is evaluated at 1-second resolution. The lowercase "s" alias is used
# because the uppercase "S" alias is deprecated since pandas 2.2.
OVERLAP_FREQ  = "1s"

# ---------- OPTIMIZATION SETTINGS ----------------------------------------------------
# Moment used by the joint fit; the run log documents the codes as N=0, S=2, V=3.
MOMENT = "V"
# Cost space of the fit ("linear or log cost" in the run log).
SPACE = "linear"
# Weight of the cross-instrument consistency term.
PAIR_W = 1.0

# Search ranges handed to the optimizer: UHSAS real refractive index and
# APS particle density in kg/m3 (as labelled in the run log).
BOUNDS_UHSAS = [(1.3, 1.8)]
BOUNDS_APS = [(950.0, 2000.0)]

# Per-instrument diameter limits in nm. None is forwarded unchanged to
# run_joint_optimization_v2 (presumably meaning "no limit" -- TODO confirm).
UHSAS_XMIN = UHSAS_XMAX = None
FIMS_XMIN, FIMS_XMAX = 10, 500
POPS_XMIN = POPS_XMAX = None

# Directory passed to the optimizer as lut_dir.
LUT_DIR = BASE_DIR / "lut"

# ---------------- GLOBAL BINNING (IDENTICAL ACROSS ALL DAYS) ------------------------
# 100 bins => 101 edges
# Number of output bins; the shared grid therefore has FINE_BIN + 1 edges.
FINE_BIN = 100

# Diameter span (nm) covered by the shared output grid.
GLOBAL_XMIN_NM, GLOBAL_XMAX_NM = 10.0, 5000.0

# Log-spaced bin edges reused unchanged for every processed day.
GLOBAL_EDGES = np.logspace(
    start=np.log10(GLOBAL_XMIN_NM),
    stop=np.log10(GLOBAL_XMAX_NM),
    num=FINE_BIN + 1,
).astype(float)

# ---------- CONSENSUS SETTINGS ---------------------------------------------------
# Per-instrument weights used by the consensus merge.
ALPHA_FIMS, ALPHA_UHSAS = 1.0, 1.5
ALPHA_POPS, ALPHA_APS = 0.2, 2.0
# Tikhonov smoothing parameter and consensus penalty parameter.
LAMBDA_TIK = 1e-5
C_PUNISH = 0.5
# Space in which the instruments are combined.
COMBINATION_SPACE = "log10"


# =============================================================================
# OVERLAP QC
# =============================================================================
def overlap_seconds_all_instruments(
    chunk: dict[str, pd.DataFrame],
    t_start: pd.Timestamp,
    t_end: pd.Timestamp,
    instruments: tuple[str, ...] = ("APS", "UHSAS", "FIMS", "POPS"),
    freq: str = "1s",
) -> int:
    """
    Count the ``freq``-wide bins in [t_start, t_end) holding >=1 record from ALL instruments.

    This is strict: a bin only counts when every instrument has at least one
    record inside that same bin.

    Parameters
    ----------
    chunk : mapping of instrument name -> time-indexed DataFrame.
    t_start, t_end : window bounds; the window is end-exclusive.
    instruments : names that must all be present in ``chunk``.
    freq : pandas frequency string defining the bin width (default one second).

    Returns
    -------
    int
        Number of bins shared by all instruments. 0 when a bound is None, the
        window is empty or reversed, or any instrument is missing / has no
        rows inside the window.
    """
    if t_start is None or t_end is None or t_end <= t_start:
        return 0

    # Fixed grid starting exactly at t_start, end-exclusive.
    grid = pd.date_range(t_start, t_end, freq=freq, inclusive="left")
    if len(grid) == 0:
        return 0

    all_mask = np.ones(len(grid), dtype=bool)

    for name in instruments:
        df = chunk.get(name)
        if df is None or df.empty:
            return 0

        dfi = df.loc[t_start:t_end]
        if dfi.empty:
            return 0

        # Anchor the resample bins at the grid origin. With the default origin
        # the bin labels are aligned to midnight, so a t_start carrying a
        # sub-bin fraction (e.g. 12:00:00.5) would never match the grid labels
        # and the overlap would always come out as 0.
        pres = dfi.resample(freq, origin=grid[0]).size() > 0
        all_mask &= pres.reindex(grid, fill_value=False).to_numpy()

        # No common bin left: the remaining instruments cannot add any back.
        if not all_mask.any():
            return 0

    return int(all_mask.sum())


# =============================================================================
# MAIN LOOP
# =============================================================================
for a_date in dates:
    # Write the complete configuration ahead of this day's chunk entries so
    # each day section of the run log is self-describing.
    with LOG_FILE.open("a") as f:
        f.write("----------------------------------------------------------\n")
        f.write(f"Merging {a_date} every {MERGE_PER} minutes\n")
        f.write("------------------------ SETTINGS ------------------------\n")
        # --- Time & Sampling ---
        f.write(f"FIMS_LAG: {FIMS_LAG}s  # shift FIMS earlier\n")
        f.write(f"INCLOUD_PAD_S: {INCLOUD_PAD_S}s  # cloud-flag buffer\n")
        f.write(f"MIN_SAMPLES_PER_INST: {MIN_SAMPLES_PER_INST}  # min avg count to process chunk\n")
        f.write(f"MERGE_PER: {MERGE_PER} min  # chunk duration\n")

        # --- Overlap QC ---
        f.write("\n# OVERLAP QC\n")
        f.write(f"MIN_OVERLAP_S: {MIN_OVERLAP_S}  # required overlap seconds among all instruments\n")
        f.write(f"OVERLAP_FREQ: {OVERLAP_FREQ}  # resampling frequency to compute overlap\n")

        # --- Optimization Physics ---
        f.write("\n# OPTIMIZATION (Joint Fit)\n")
        f.write(f"MOMENT: {MOMENT}  # (N=0, S=2, V=3)\n")
        f.write(f"SPACE: {SPACE}  # linear or log cost\n")
        f.write(f"PAIR_W: {PAIR_W}  # cross-instrument consistency weight\n")
        f.write(f"BOUNDS_UHSAS (m): {BOUNDS_UHSAS}  # real refractive index range\n")
        f.write(f"BOUNDS_APS (rho): {BOUNDS_APS}  # density range kg/m3\n")

        # --- Instrument Range Filters ---
        f.write("\n# INSTRUMENT X-LIMITS (nm)\n")
        f.write(f"UHSAS_RANGE: {UHSAS_XMIN} to {UHSAS_XMAX}\n")
        f.write(f"FIMS_RANGE:  {FIMS_XMIN} to {FIMS_XMAX}\n")
        f.write(f"POPS_RANGE:  {POPS_XMIN} to {POPS_XMAX}\n")

        # --- Global Output Bins ---
        f.write("\n# GLOBAL OUTPUT BINS (IDENTICAL ACROSS ALL DAYS)\n")
        f.write(f"FINE_BIN: {FINE_BIN}\n")
        f.write(f"GLOBAL_XMIN_NM: {GLOBAL_XMIN_NM}\n")
        f.write(f"GLOBAL_XMAX_NM: {GLOBAL_XMAX_NM}\n")

        # --- Consensus & Tikhonov ---
        f.write("\n# MERGING & CONSENSUS\n")
        f.write(f"LAMBDA_TIK: {LAMBDA_TIK}  # Tikhonov smoothing parameter\n")
        f.write(f"C_PUNISH: {C_PUNISH}  # consensus penalty parameter\n")
        f.write(f"COMBINATION_SPACE: {COMBINATION_SPACE}  # linear or log combination\n")
        f.write(f"ALPHA_FIMS:  {ALPHA_FIMS}  # weight for FIMS in consensus\n")
        f.write(f"ALPHA_UHSAS: {ALPHA_UHSAS}  # weight for UHSAS in consensus\n")
        f.write(f"ALPHA_POPS:  {ALPHA_POPS}  # weight for POPS in consensus\n")
        f.write(f"ALPHA_APS:   {ALPHA_APS}  # weight for APS in consensus\n")
        f.write("----------------------------------------------------------\n\n")

    # Per-day output layout: one folder per date with three plot sub-folders.
    day_dir = OUTPUT_DIR / a_date
    day_dir.mkdir(parents=True, exist_ok=True)
    totals_dir = day_dir / "time_series"
    opt_dir    = day_dir / "loss_curve"
    plots_dir  = day_dir / "merge_plots"
    totals_dir.mkdir(exist_ok=True)
    opt_dir.mkdir(exist_ok=True)
    plots_dir.mkdir(exist_ok=True)

    # IMPORTANT: use global edges for ALL days (not per-day)
    common_edges = GLOBAL_EDGES

    # Per-day accumulators; one entry is appended per successfully merged chunk.
    day_fims_algn    = []
    day_uhsas_algn   = []
    day_pops_algn    = []
    day_aps_algn     = []
    day_merged       = []
    day_times_start  = []
    day_times_end    = []
    day_incloud      = []
    # Native instrument edges, captured from the first chunk that passes QC.
    orig_APS_edges   = None
    orig_UHSAS_edges = None
    orig_POPS_edges  = None
    orig_FIMS_edges  = None
    day_n_fit        = []
    day_n_pops_fit   = []
    day_rho_fit      = []
    day_best_cost    = []

    filtered_frames = load_aufi_oneday_v2(a_date, aps_dir, uhsas_dir, fims_dir, pops_dir)

    # shift FIMS earlier by FIMS_LAG seconds (on a copy, so the loaded frame is untouched)
    if "FIMS" in filtered_frames and not filtered_frames["FIMS"].empty:
        filtered_frames["FIMS"] = filtered_frames["FIMS"].copy()
        filtered_frames["FIMS"].index = (
            filtered_frames["FIMS"].index - pd.Timedelta(seconds=FIMS_LAG)
        )

    inlet_flag = read_inlet_flag(inlet_dir, start=a_date, end=None, prefix="ARCSIX")
    micro      = read_microphysical(micro_dir, start=a_date, end=None, prefix="ARCSIX")
    cpc_total  = pd.to_numeric(micro.get("CNgt10nm"), errors="coerce")

    split_filtered_frames = split_frames(filtered_frames, MERGE_PER * 60)

    for i, a_chunk in enumerate(split_filtered_frames):
        # Bind the window bounds before entering the try block so the error
        # handler can always report them and never shows a previous chunk's values.
        t_start = None
        t_end = None
        try:
            # Chunk window = earliest first / latest last timestamp over the non-empty frames.
            times   = [t for df in a_chunk.values() if len(df) for t in (df.index[0], df.index[-1])]
            t_start = min(times) if times else None
            t_end   = max(times) if times else None

            with LOG_FILE.open("a") as f:
                f.write(f"\tsizedist {i:03d}: {t_start} -> {t_end}\n")

            # Skip empty windows before anything slices or queries with the bounds.
            if t_start is None or t_end is None:
                with LOG_FILE.open("a") as f:
                    f.write(f"\t[SKIP] chunk {i:03d} empty window\n")
                continue

            inlet_chunk     = inlet_flag.loc[t_start:t_end]
            cpc_total_chunk = cpc_total.loc[t_start:t_end]

            # In-cloud flag is evaluated on the full inlet_flag series, before
            # the flagged samples are removed from the chunk below.
            inc_flag = chunk_is_incloud(inlet_flag, t_start, t_end, tol_s=INCLOUD_PAD_S)

            # ---------------------------------------------------
            # FILTER OUT inlet-flagged time (±INCLOUD_PAD_S seconds)
            # ---------------------------------------------------
            a_chunk = filter_chunk_by_inlet_flag(
                chunk=a_chunk,
                inlet_flag=inlet_flag,
                t_start=t_start,
                t_end=t_end,
                pad_s=INCLOUD_PAD_S,
            )

            if all((df is None or df.empty) for df in a_chunk.values()):
                with LOG_FILE.open("a") as f:
                    f.write(f"\t[SKIP] chunk {i:03d} all data removed by inlet_flag filter\n")
                continue

            # ---------------------------------------------------
            # Enforce >= MIN_OVERLAP_S of overlap among ALL instruments
            # ---------------------------------------------------
            ov_s = overlap_seconds_all_instruments(
                chunk=a_chunk,
                t_start=t_start,
                t_end=t_end,
                instruments=("APS", "UHSAS", "FIMS", "POPS"),
                freq=OVERLAP_FREQ,
            )
            if ov_s < MIN_OVERLAP_S:
                with LOG_FILE.open("a") as f:
                    f.write(
                        f"\t[SKIP] chunk {i:03d} insufficient overlap: "
                        f"{ov_s:d}s < {MIN_OVERLAP_S:d}s (freq={OVERLAP_FREQ})\n"
                    )
                continue

            # 1) time series plot
            fig1, _ = plot_period_totals(
                a_chunk,
                title=f"{a_date} sizedist {i:03d}",
                inlet_flag=inlet_flag,
                gauss_win=10,
                gauss_std=2,
                cpc_total=cpc_total_chunk,
                t_start=t_start,
                t_end=t_end,
            )
            fig1.savefig(totals_dir / f"sizedist_{i:03d}_totals.png", dpi=150)
            plt.close(fig1)

            # 2) mean specs
            specs, line_kwargs, fill_kwargs, bin_counts = make_filtered_specs_v2(
                a_chunk,
                a_chunk.get("APS",   pd.DataFrame()),
                a_chunk.get("UHSAS", pd.DataFrame()),
                a_chunk.get("FIMS",  pd.DataFrame()),
                a_chunk.get("POPS",  pd.DataFrame()),
                LOG_FILE,
            )

            if ("APS" not in specs) or ("UHSAS" not in specs) or ("FIMS" not in specs) or ("POPS" not in specs):
                with LOG_FILE.open("a") as f:
                    f.write(f"\t[SKIP] chunk {i:03d} missing instrument(s)\n")
                continue

            # Gate on the mean of the non-zero per-bin counts of each instrument;
            # instruments without a bin_counts entry are not gated.
            low_data_reason = None
            for _name in ("APS", "UHSAS", "FIMS", "POPS"):
                if _name in bin_counts and len(bin_counts[_name]) > 0:
                    arr = np.asarray(bin_counts[_name], int)
                    nz = arr[arr > 0]
                    nz_avg = float(nz.mean()) if nz.size > 0 else 0.0
                    if nz_avg < MIN_SAMPLES_PER_INST:
                        low_data_reason = f"{_name} nonzero_avg={nz_avg:.1f} < {MIN_SAMPLES_PER_INST}"
                        break
            if low_data_reason is not None:
                with LOG_FILE.open("a") as f:
                    f.write(f"\t[SKIP] chunk {i:03d} low data: {low_data_reason}\n")
                continue

            # save original instrument edges once
            if orig_APS_edges is None:
                orig_APS_edges = specs["APS"][1]
            if orig_UHSAS_edges is None:
                orig_UHSAS_edges = specs["UHSAS"][1]
            if orig_FIMS_edges is None:
                orig_FIMS_edges = specs["FIMS"][1]
            if orig_POPS_edges is None:
                orig_POPS_edges = specs["POPS"][1]

            # 3) optimization
            specs_opt, line_kwargs_opt, fill_kwargs_opt, opt_res = run_joint_optimization_v2(
                specs,
                line_kwargs,
                fill_kwargs,
                moment=MOMENT,
                space=SPACE,
                pair_w=PAIR_W,
                uhsas_bounds=BOUNDS_UHSAS,
                aps_bounds=BOUNDS_APS,
                uhsas_xmin=UHSAS_XMIN,
                uhsas_xmax=UHSAS_XMAX,
                fims_xmin=FIMS_XMIN,
                fims_xmax=FIMS_XMAX,
                pops_xmin=POPS_XMIN,
                pops_xmax=POPS_XMAX,
                lut_dir=LUT_DIR,
            )

            # 4) loss curve
            fig_h, ax_h = plot_history(opt_res["hist"])
            ax_h.set_title(f"opt hist {a_date} {i:03d}")
            fig_h.savefig(opt_dir / f"sizedist_{i:03d}_opt_hist.png", dpi=150)
            plt.close(fig_h)

            # 5) log
            with LOG_FILE.open("a") as f:
                f.write(
                    f"\t\tUHSAS n_fit = {opt_res['n_fit']:.4f}, "
                    f"POPS n_fit = {opt_res['n_pops_fit']:.4f}, "
                    f"APS rho_fit = {opt_res['rho_fit']:.1f} kg/m^3, "
                    f"cost = {opt_res['best_cost']:.6g}\n\n"
                )

            # 6) Consensus merged (native edges; then remap to GLOBAL_EDGES below)
            # These labels are used as keys into specs_opt; presumably they must
            # match the key format produced by run_joint_optimization_v2 -- TODO confirm.
            uh_label = f"UHSAS fit (n={opt_res['n_fit']:.3f})"
            po_label = f"POPS fit (n={opt_res['n_pops_fit']:.3f})"
            ap_label = f"APS fit (ρ={opt_res['rho_fit']*0.001:.3f} g/cm$^3$)"

            tik_specs, tik_lines, tik_fills, tik_diag = make_consensus_merged_spec_v2(
                e_fims_sel=specs_opt["FIMS_applied"][1],
                y_fims_sel=specs_opt["FIMS_applied"][2],
                e_uhsas_fit=specs_opt[uh_label][1],
                y_uhsas_fit=specs_opt[uh_label][2],
                e_pops_fit=specs_opt[po_label][1],
                y_pops_fit=specs_opt[po_label][2],
                e_aps_fit=specs_opt[ap_label][1],
                y_aps_fit=specs_opt[ap_label][2],
                lam=LAMBDA_TIK,
                n_points=FINE_BIN,  # 100 bins native output of consensus step
                alpha_fims=ALPHA_FIMS,
                alpha_uhsas=ALPHA_UHSAS,
                alpha_pops=ALPHA_POPS,
                alpha_aps=ALPHA_APS,
                c_punish=C_PUNISH,
                data_space=COMBINATION_SPACE,
            )

            specs_opt.update(tik_specs)
            line_kwargs_opt.update(tik_lines)
            fill_kwargs_opt.update(tik_fills)

            # 7) plots
            (figN, axN), (figV, axV), _ = plot_sizedist_all(
                specs=specs_opt,
                merged_spec=tik_specs,
                line_kwargs=line_kwargs_opt,
                merged_line_kwargs=tik_lines,
                fill_kwargs=fill_kwargs_opt,
                merged_fill_kwargs=tik_fills,
                inlet_flag=inlet_chunk,
                d_str=a_date,
            )
            figN.savefig(plots_dir / f"{a_date}_chunk{i:03d}_dNdlogDp.png", dpi=200)
            plt.close(figN)
            figV.savefig(plots_dir / f"{a_date}_chunk{i:03d}_dVdlogDp.png", dpi=200)
            plt.close(figV)

            # 8) collect + rebin onto GLOBAL_EDGES (IDENTICAL ACROSS ALL DAYS)
            # The merged spectrum is taken from the first entry of tik_specs.
            tik_name  = next(iter(tik_specs.keys()))
            tik_edges = tik_specs[tik_name][1]
            tik_vals  = tik_specs[tik_name][2]

            fims_on_common   = remap_dndlog_by_edges_any(
                specs_opt["FIMS_applied"][1], common_edges, specs_opt["FIMS_applied"][2]
            )
            uhsas_on_common  = remap_dndlog_by_edges_any(
                specs_opt[uh_label][1], common_edges, specs_opt[uh_label][2]
            )
            pops_on_common   = remap_dndlog_by_edges_any(
                specs_opt[po_label][1], common_edges, specs_opt[po_label][2]
            )
            aps_on_common    = remap_dndlog_by_edges_any(
                specs_opt[ap_label][1], common_edges, specs_opt[ap_label][2]
            )
            merged_on_common = remap_dndlog_by_edges_any(
                tik_edges, common_edges, tik_vals
            )

            # Append everything together at the very end so a failure earlier in
            # the chunk cannot leave the per-day lists with unequal lengths.
            day_fims_algn.append(fims_on_common)
            day_uhsas_algn.append(uhsas_on_common)
            day_pops_algn.append(pops_on_common)
            day_aps_algn.append(aps_on_common)
            day_merged.append(merged_on_common)
            day_times_start.append(t_start)
            day_times_end.append(t_end)
            day_incloud.append(inc_flag)
            day_n_fit.append(opt_res["n_fit"])
            day_n_pops_fit.append(opt_res["n_pops_fit"])
            day_rho_fit.append(opt_res["rho_fit"])
            day_best_cost.append(opt_res["best_cost"])

        except Exception as e:
            # Best-effort batch: record the failure and move on to the next chunk.
            # Close any figure opened before the failure so figures do not pile
            # up over a long run.
            plt.close("all")
            err_ts = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S %Z")
            with ERROR_LOG.open("a") as ef:
                ef.write(f"[{err_ts}] ERROR on date {a_date} chunk {i:03d}\n")
                ef.write(f"t_start={t_start}, t_end={t_end}\n")
                ef.write(f"{type(e).__name__}: {e}\n")
                ef.write(traceback.format_exc())
                ef.write("\n")
            with LOG_FILE.open("a") as f:
                f.write(f"\t[ERROR] chunk {i:03d} failed, see error_log.txt\n")

    # write per-day (edges identical across ALL days)
    if len(day_merged) > 0:
        write_day_netcdf_v2(
            day_dir,
            a_date,
            day_fine_edges=np.asarray(common_edges, float),  # GLOBAL_EDGES
            day_fims_algn=np.asarray(day_fims_algn, float),
            day_uhsas_algn=np.asarray(day_uhsas_algn, float),
            day_pops_algn=np.asarray(day_pops_algn, float),
            day_aps_algn=np.asarray(day_aps_algn, float),
            day_fine_vals=np.asarray(day_merged, float),
            day_times_start=day_times_start,
            day_times_end=day_times_end,
            day_incloud_flag=np.asarray(day_incloud, int),
            day_n_fit=np.asarray(day_n_fit, float),
            day_n_pops_fit=np.asarray(day_n_pops_fit, float),
            day_rho_fit=np.asarray(day_rho_fit, float),
            day_best_cost=np.asarray(day_best_cost, float),
            orig_APS_edges=np.asarray(orig_APS_edges, float),
            orig_UHSAS_edges=np.asarray(orig_UHSAS_edges, float),
            orig_POPS_edges=np.asarray(orig_POPS_edges, float),
            orig_FIMS_edges=np.asarray(orig_FIMS_edges, float),
        )
    else:
        with LOG_FILE.open("a") as f:
            f.write(f"[WARN] {a_date}: no valid chunks to write\n")

## Current ARCSIX Campaign Size Distribution Merge R1 QC

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
QC plots for ARCSIX size-distribution merge output NetCDF files (NO flag filtering)
+ write QC-flagged NetCDF copies with added warning flags.
"""

import sys
from pathlib import Path
from datetime import datetime, timedelta

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from netCDF4 import Dataset

# -------------------------------------------------------------------
# PATH SETTINGS
# -------------------------------------------------------------------
# Root of the merge output that this QC pass reads from and writes into.
BASE_DIR = Path("/Users/C832577250/Output/arcsix_sizedist_merge_batch_v3")

# Campaign data root and the folder passed to read_microphysical.
DATA_DIR = Path("/Volumes/Hailstone Data/Research Data/ARCSIX_P3B")
micro_dir = DATA_DIR / "LARGE-MICROPHYSICAL"

# Output folders for the QC figures and the QC-flagged NetCDF copies.
QC_DIR = BASE_DIR / "qc_plots"
QC_NC_DIR = BASE_DIR / "qc_flagged_nc"
for _qc_out_dir in (QC_DIR, QC_NC_DIR):
    _qc_out_dir.mkdir(parents=True, exist_ok=True)

REPO_BASE = Path.cwd().parent
sys.path.append(str(REPO_BASE / "src"))
from ict_utils import read_microphysical  # noqa: E402


# -------------------------------------------------------------------
# CONSTANTS / THRESHOLDS
# -------------------------------------------------------------------
# Cost threshold drawn on the optimization-cost histogram.
HIGH_COST_THRESH = 0.2

# Lower diameter cutoff (nm) applied when integrating totals for comparison.
DP_CUTOFF_NM = 10.0

# Robust-outlier multipliers: K_SIGMA_WARN is for the WARNING flag,
# K_SIGMA_DROP is for deleting extreme chunks in the qc_flagged_nc output.
K_SIGMA_WARN, K_SIGMA_DROP = 10.0, 20.0

# Hard fail if fewer points than this are available to define the bounds.
MIN_POINTS_FOR_ROBUST = 50

# Axis range of the scatter plot.
PLOT_LO, PLOT_HI = 0.0, 8000.0

# Styling of the threshold line (cost histogram only).
THRESH_COLOR, THRESH_LW, THRESH_LS = "orange", 2.0, "-"


# -------------------------------------------------------------------
# SMALL HELPERS
# -------------------------------------------------------------------
def _read_var_as_array(ds, name: str) -> np.ndarray:
    """Read a variable from a Dataset and convert _FillValue to NaN."""
    var = ds.variables[name]
    arr = np.array(var[:], dtype=float)

    fv = getattr(var, "_FillValue", None)
    if fv is not None:
        arr[arr == fv] = np.nan
    arr[~np.isfinite(arr)] = np.nan
    return arr


def _parse_base_time(ds: Dataset) -> datetime:
    base_iso = ds.getncattr("base_time_iso")
    if base_iso.endswith("Z"):
        base_iso = base_iso[:-1]
    if base_iso.endswith("+00:00") or base_iso.endswith("-00:00"):
        base_iso = base_iso[:-6]
    base_dt = datetime.fromisoformat(base_iso)
    if base_dt.tzinfo is not None:
        base_dt = base_dt.replace(tzinfo=None)
    return base_dt


def _ensure_naive_datetime_index(idx: pd.Index) -> pd.DatetimeIndex:
    if not isinstance(idx, pd.DatetimeIndex):
        raise TypeError("Expected a DatetimeIndex for CPC time axis.")
    if idx.tz is not None:
        idx = idx.tz_convert("UTC").tz_localize(None)
    return idx


def _add_flag_var(dst_ds: Dataset, name: str, dims: tuple, data: np.ndarray, long_name: str, comment: str):
    """Add int8 flag variable with 0/1."""
    if name in dst_ds.variables:
        raise RuntimeError(f"Destination file already has variable {name} (refusing to overwrite).")

    v = dst_ds.createVariable(name, "i1", dims)
    v[:] = data.astype(np.int8)
    v.setncattr("long_name", long_name)
    v.setncattr("units", "1")
    v.setncattr("flag_values", np.array([0, 1], dtype=np.int8))
    v.setncattr("flag_meanings", "ok warning")
    v.setncattr("comment", comment)


def _copy_netcdf_subset_by_dim(
    src: Dataset,
    dst_path: Path,
    *,
    subset_dim: str,
    keep_idx: np.ndarray,
):
    """
    Create dst NETCDF4 file and copy everything from src, but SUBSET along subset_dim for any variable
    that contains that dimension. All other dimensions/variables copied as-is.

    Parameters
    ----------
    src : open source dataset to copy from.
    dst_path : destination file; an existing file at this path is replaced.
    subset_dim : name of the dimension to subset; must exist in ``src``.
    keep_idx : integer positions to keep along ``subset_dim``. A boolean mask
        is also accepted and converted to positions.

    Returns: dst Dataset (open, caller must close).

    Raises
    ------
    ValueError
        If ``subset_dim`` is not a dimension of ``src``.

    If copying fails part-way, the destination dataset is closed and the
    partial file is removed before the exception is re-raised.
    """
    # Validate before touching the destination: an unknown dimension would
    # otherwise produce a full, un-subsetted copy without any warning.
    if subset_dim not in src.dimensions:
        raise ValueError(f"subset_dim {subset_dim!r} is not a dimension of the source dataset.")

    keep_idx = np.asarray(keep_idx)
    if keep_idx.dtype == bool:
        # np.take would read a mask as indices 0/1, so convert it to positions.
        keep_idx = np.flatnonzero(keep_idx)
    keep_n = int(keep_idx.size)

    if dst_path.exists():
        dst_path.unlink()

    dst = Dataset(dst_path, mode="w", format="NETCDF4")
    try:
        # global attrs
        for attr in src.ncattrs():
            dst.setncattr(attr, src.getncattr(attr))

        # dims: the subset dimension gets the reduced length; unlimited
        # dimensions stay unlimited, all others keep their length.
        for dname, dim in src.dimensions.items():
            if dname == subset_dim:
                dst.createDimension(dname, keep_n)
            else:
                dst.createDimension(dname, (len(dim) if not dim.isunlimited() else None))

        # vars
        for vname, svar in src.variables.items():
            # _FillValue can only be set when the variable is created.
            fill_value = getattr(svar, "_FillValue", None)
            if fill_value is not None:
                dvar = dst.createVariable(vname, svar.dtype, svar.dimensions, fill_value=fill_value)
            else:
                dvar = dst.createVariable(vname, svar.dtype, svar.dimensions)

            for attr in svar.ncattrs():
                if attr == "_FillValue":
                    continue
                dvar.setncattr(attr, svar.getncattr(attr))

            data = svar[:]
            if subset_dim in svar.dimensions:
                axis = svar.dimensions.index(subset_dim)
                data = np.take(data, keep_idx, axis=axis)
            dvar[:] = data
    except BaseException:
        # Do not leak the open handle or leave a half-written file behind.
        dst.close()
        if dst_path.exists():
            dst_path.unlink()
        raise

    return dst


# -------------------------------------------------------------------
# INTEGRATION FOR >DP_CUTOFF_NM
# -------------------------------------------------------------------
def _gt_cutoff_weights_from_edges(fine_edges_nm: np.ndarray, cutoff_nm: float) -> np.ndarray:
    edges = np.asarray(fine_edges_nm, dtype=float)
    if edges.ndim != 1 or edges.size < 2:
        raise ValueError("fine_edges_nm must be a 1D array with length >= 2.")
    if not np.all(np.isfinite(edges)):
        raise ValueError("fine_edges_nm contains non-finite values.")
    if not np.all(np.diff(edges) > 0):
        raise ValueError("fine_edges_nm must be strictly increasing.")
    if not np.isfinite(cutoff_nm) or cutoff_nm <= 0:
        raise ValueError("cutoff_nm must be a positive finite number.")

    nb = edges.size - 1
    w = np.zeros(nb, dtype=float)

    if cutoff_nm <= edges[0]:
        w[:] = 1.0
        return w
    if cutoff_nm >= edges[-1]:
        return w

    loge = np.log10(edges)
    i = np.searchsorted(edges, cutoff_nm, side="right") - 1
    if i < 0:
        w[:] = 1.0
        return w
    if i >= nb:
        return w

    if i + 1 < nb:
        w[i + 1:] = 1.0

    lo = loge[i]
    hi = loge[i + 1]
    lc = np.log10(cutoff_nm)
    frac = (hi - lc) / (hi - lo)
    w[i] = float(np.clip(frac, 0.0, 1.0))
    return w


def integrate_dNdlogDp_gt_cutoff(dNdlogDp: np.ndarray, fine_edges_nm: np.ndarray, cutoff_nm: float) -> np.ndarray:
    """
    Integrate dN/dlog10Dp over Dp>cutoff:
      N_gt = Σ dNdlogDp * dlog10Dp * weight

    Parameters
    ----------
    dNdlogDp : 2-D array (chunk, fine_bin).
    fine_edges_nm : bin edges; must imply the same number of bins as ``dNdlogDp``.
    cutoff_nm : lower diameter limit of the integral.

    Returns
    -------
    np.ndarray
        One total per row. NaN bins are ignored in the sum, but a row whose
        bins above the cutoff are ALL NaN yields NaN instead of 0, so a chunk
        with no data is not mistaken for a measured total of zero.

    Raises
    ------
    ValueError
        If ``dNdlogDp`` is not 2-D (or the edges/cutoff are rejected by the weight helper).
    RuntimeError
        If the bin counts of ``dNdlogDp`` and ``fine_edges_nm`` disagree.
    """
    A = np.asarray(dNdlogDp, dtype=float)
    if A.ndim != 2:
        raise ValueError("dNdlogDp must be 2D (chunk, fine_bin).")

    edges = np.asarray(fine_edges_nm, dtype=float)
    # Compute the weights first: the helper validates the edges before they
    # are passed through log10 below.
    w = _gt_cutoff_weights_from_edges(edges, cutoff_nm)

    dlog10 = np.diff(np.log10(edges))
    if A.shape[1] != dlog10.size:
        raise RuntimeError(
            f"Shape mismatch: dNdlogDp has {A.shape[1]} bins but fine_edges_nm implies {dlog10.size} bins."
        )

    contrib = A * (dlog10 * w)[None, :]
    totals = np.nansum(contrib, axis=1)

    # np.nansum returns 0 for an all-NaN slice; restore NaN for rows that have
    # no data at all in the bins that actually contribute to the integral.
    counted = w > 0.0
    if counted.any():
        totals[np.isnan(A[:, counted]).all(axis=1)] = np.nan
    return totals


# -------------------------------------------------------------------
# ROBUST BOUNDS (LINEAR RESIDUAL metric)
# -------------------------------------------------------------------
def compute_robust_linear_bounds(
    merged: np.ndarray,
    cpc: np.ndarray,
    *,
    k_sigma: float,
    min_points: int,
):
    """
    Robust acceptance band for the residual r = merged - cpc.

    Only pairs where both values are finite are used. The band is
    median(r) ± k_sigma * sigma with sigma = 1.4826 * MAD(r).

    Returns (r_low, r_high, median(r), sigma) as floats.

    Raises RuntimeError when no pair is valid, when fewer than ``min_points``
    pairs are valid, or when sigma is non-finite or not positive.
    """
    merged_arr = np.asarray(merged, float)
    cpc_arr = np.asarray(cpc, float)

    usable = np.isfinite(cpc_arr) & np.isfinite(merged_arr)
    if not usable.any():
        raise RuntimeError("No valid merged/cpc points for linear bounds.")

    resid = merged_arr[usable] - cpc_arr[usable]
    if resid.size < int(min_points):
        raise RuntimeError(f"Not enough valid points for robust bounds: {resid.size} < {min_points}")

    center = float(np.median(resid))
    mad = float(np.median(np.abs(resid - center)))
    sigma = 1.4826 * mad
    if not (np.isfinite(sigma) and sigma > 0):
        raise RuntimeError(f"Invalid robust sigma: sigma={sigma} (mad={mad})")

    lower = center - k_sigma * sigma
    upper = center + k_sigma * sigma
    return float(lower), float(upper), center, float(sigma)


def linear_resid_and_flag(
    merged: np.ndarray,
    cpc: np.ndarray,
    *,
    r_low: float,
    r_high: float,
):
    """
    Residual r = merged - cpc and an out-of-band flag for each point.

    Returns (r, flag): r is NaN wherever either input is non-finite; flag is
    True only for valid points whose residual lies outside [r_low, r_high]
    (values exactly on a bound are not flagged).
    """
    cpc_arr = np.asarray(cpc, float)
    merged_arr = np.asarray(merged, float)

    valid = np.isfinite(cpc_arr) & np.isfinite(merged_arr)

    # Start from all-NaN and only compute the difference at valid positions.
    resid = np.full_like(cpc_arr, np.nan, dtype=float)
    np.subtract(merged_arr, cpc_arr, out=resid, where=valid)

    out_of_band = (resid < r_low) | (resid > r_high)
    return resid, valid & out_of_band


# -------------------------------------------------------------------
# GATHER EVERYTHING IN ONE PASS (FOR PLOTS/CSV/BOUNDS)
# -------------------------------------------------------------------
def gather_all_chunks(base_dir: Path):
    """
    Collect per-chunk QC quantities from every merged NetCDF file under ``base_dir``.

    Files matching ``**/*_sizedist_merged_v2.nc`` are processed in sorted
    order. For each chunk the optimization cost, the retrieved UHSAS and POPS
    n_fit values, the retrieved APS density, the median of the CNgt10nm series
    over the chunk's time window and the merged total above DP_CUTOFF_NM are
    gathered. Non-finite values are stored as NaN.

    Returns a tuple of six 1-D float arrays, in that order:
    (cost, uhsas_n, pops_n, rho, cpc_median, merged_total_gt10).

    Raises FileNotFoundError when no file matches, and RuntimeError when a
    file lacks a required variable or its per-chunk arrays disagree in length.
    """
    nc_files = sorted(base_dir.glob("**/*_sizedist_merged_v2.nc"))
    if not nc_files:
        raise FileNotFoundError(f"No *_sizedist_merged_v2.nc files found under {base_dir}")

    print(f"Found {len(nc_files)} NetCDF files.")

    # Listed in the order the variables are read below.
    required_vars = (
        "optimization_best_cost",
        "retrieved_uhsas_n_fit",
        "retrieved_pops_n_fit",
        "retrieved_aps_density",
        "merged_dNdlogDp",
        "fine_edges_nm",
        "time_start_since_base_s",
        "time_end_since_base_s",
    )

    # One list of per-file arrays per returned quantity, in return order.
    field_order = ("cost", "uhsas_n", "pops_n", "rho", "cpc_median", "merged_total_gt10")
    collected = {field: [] for field in field_order}

    # CPC series per date string, so each date is read only once.
    cpc_by_date = {}

    for nc_path in nc_files:
        date_str = nc_path.stem.split("_")[0]
        print(f"\nProcessing {nc_path} (date {date_str})")

        cpc_series = cpc_by_date.get(date_str)
        if cpc_series is None:
            micro = read_microphysical(micro_dir, start=date_str, end=None, prefix="ARCSIX")
            cpc_series = pd.to_numeric(micro.get("CNgt10nm"), errors="coerce")
            cpc_series.index = _ensure_naive_datetime_index(cpc_series.index)
            cpc_by_date[date_str] = cpc_series

        with Dataset(nc_path, mode="r") as ds:
            absent = next((v for v in required_vars if v not in ds.variables), None)
            if absent is not None:
                raise RuntimeError(f"{nc_path} missing {absent}")

            (
                cost,
                uhsas_n,
                pops_n,
                rho,
                merged_dNdlogDp,
                fine_edges_nm,
                time_start_s,
                time_end_s,
            ) = (_read_var_as_array(ds, v) for v in required_vars)

            base_dt = _parse_base_time(ds)

            n_chunks = uhsas_n.size
            if pops_n.size != n_chunks or rho.size != n_chunks or cost.size != n_chunks:
                raise RuntimeError(f"{nc_path}: chunk length mismatch among cost/uhsas/pops/rho")
            if merged_dNdlogDp.shape[0] != n_chunks:
                raise RuntimeError(f"{nc_path}: merged_dNdlogDp chunk mismatch")
            if time_start_s.size != n_chunks or time_end_s.size != n_chunks:
                raise RuntimeError(f"{nc_path}: time_start/end chunk mismatch")

            merged_total_gt10 = integrate_dNdlogDp_gt_cutoff(merged_dNdlogDp, fine_edges_nm, DP_CUTOFF_NM)

            # Median CPC over each chunk's window; stays NaN when the window
            # bounds are not finite or no CPC sample falls inside the window.
            cpc_median = np.full(n_chunks, np.nan, dtype=float)
            for j, (start_s, end_s) in enumerate(zip(time_start_s, time_end_s)):
                if not (np.isfinite(start_s) and np.isfinite(end_s)):
                    continue
                t0 = base_dt + timedelta(seconds=float(start_s))
                t1 = base_dt + timedelta(seconds=float(end_s))
                window = cpc_series.loc[t0:t1]
                if window.size:
                    cpc_median[j] = float(window.median())

        per_file = (cost, uhsas_n, pops_n, rho, cpc_median, merged_total_gt10)
        for field, values in zip(field_order, per_file):
            flat = np.ravel(values)
            collected[field].append(np.where(np.isfinite(flat), flat, np.nan))

    return tuple(np.concatenate(collected[field]).astype(float) for field in field_order)


# -------------------------------------------------------------------
# PLOTS
# -------------------------------------------------------------------
def plot_cost_hist(cost: np.ndarray, out_png: Path):
    """Save a histogram of optimization costs with the high-cost threshold marked.

    Only finite values inside [0, 1] are histogrammed (50 equal-width bins);
    anything else is left out of the plot.  A vertical line is drawn at
    HIGH_COST_THRESH even when there is nothing to histogram.

    Parameters
    ----------
    cost
        Per-chunk optimization cost values (converted to a float array).
    out_png
        Destination image path; written at 200 dpi.
    """
    finite_costs = np.asarray(cost, float)
    finite_costs = finite_costs[np.isfinite(finite_costs)]
    inside_unit_interval = (finite_costs >= 0.0) & (finite_costs <= 1.0)
    shown = finite_costs[inside_unit_interval]

    fig, ax = plt.subplots(figsize=(6, 4))
    if shown.size:
        # 51 edges -> 50 bins of width 0.02 across [0, 1]
        ax.hist(shown, bins=np.linspace(0.0, 1.0, 51), edgecolor="black", alpha=0.7)

    # Threshold marker uses the shared module-level line style settings.
    ax.axvline(HIGH_COST_THRESH, linestyle=THRESH_LS, linewidth=THRESH_LW, color=THRESH_COLOR)
    ax.set_xlabel("optimization_best_cost")
    ax.set_ylabel("Count")
    ax.set_title(f"warning_high_cost: cost > {HIGH_COST_THRESH}")

    fig.tight_layout()
    fig.savefig(out_png, dpi=200)
    plt.close(fig)


def plot_simple_hist(data: np.ndarray, out_png: Path, xlabel: str, title: str, bins=60):
    """Save a plain histogram of the finite entries of ``data``.

    Parameters
    ----------
    data
        Values to histogram (converted to a float array); non-finite entries
        are ignored.
    out_png
        Destination image path; written at 200 dpi.
    xlabel, title
        Axis label and figure title.
    bins
        Passed straight to the histogram call.

    An empty (but labelled) figure is still written when no finite values
    are present.
    """
    values = np.asarray(data, float)
    finite_values = values[np.isfinite(values)]

    fig, ax = plt.subplots(figsize=(6, 4))
    if finite_values.size:
        ax.hist(finite_values, bins=bins, edgecolor="black", alpha=0.7)

    ax.set_xlabel(xlabel)
    ax.set_ylabel("Count")
    ax.set_title(title)

    fig.tight_layout()
    fig.savefig(out_png, dpi=200)
    plt.close(fig)


def plot_scatter_merged_vs_cpc_flagged_linear(
    cpc: np.ndarray,
    merged: np.ndarray,
    flag: np.ndarray,
    *,
    r_low: float,
    r_high: float,
    out_png: Path,
):
    """Scatter merged totals against CPC medians, highlighting flagged chunks.

    Points where both coordinates are finite are split by ``flag``: unflagged
    points are drawn as small dots ("OK"), flagged points as larger crosses.
    Three reference lines are overlaid across [PLOT_LO, PLOT_HI]:

      y = x            (1:1)
      y = x + r_low    (lower residual bound)
      y = x + r_high   (upper residual bound)

    Parameters
    ----------
    cpc
        x values (CPC medians), converted to float.
    merged
        y values (merged totals), converted to float.
    flag
        Per-point warning flag, converted to bool.
    r_low, r_high
        Additive offsets of the two bound lines relative to the 1:1 line.
    out_png
        Destination image path; written at 200 dpi.
    """
    xs = np.asarray(cpc, float)
    ys = np.asarray(merged, float)
    flagged = np.asarray(flag, bool)

    plottable = np.isfinite(xs) & np.isfinite(ys)
    clean_pts = plottable & (~flagged)
    warn_pts = plottable & flagged

    fig, ax = plt.subplots(figsize=(6.1, 5.5))

    # Draw unflagged points first so the flagged crosses sit on top.
    if np.any(clean_pts):
        ax.scatter(xs[clean_pts], ys[clean_pts], s=10, alpha=0.6, label="OK")
    if np.any(warn_pts):
        ax.scatter(
            xs[warn_pts],
            ys[warn_pts],
            s=28,
            alpha=0.9,
            marker="x",
            label="warning_merged_gt10_diff_from_cpc=1",
        )

    axis_min, axis_max = PLOT_LO, PLOT_HI
    grid = np.linspace(axis_min, axis_max, 600)

    ax.plot([axis_min, axis_max], [axis_min, axis_max], linestyle="dashed", linewidth=1.2, color="k", label="1:1")
    ax.plot(grid, grid + r_low, linestyle="dotted", linewidth=1.2, color="k", label="bound_low")
    ax.plot(grid, grid + r_high, linestyle="dotted", linewidth=1.2, color="k", label="bound_high")

    ax.set_xlabel("CPC10nm median (CNgt10nm) (#/cm$^3$)")
    ax.set_ylabel(f"Merged total(>{DP_CUTOFF_NM:g} nm) (#/cm$^3$)")
    ax.set_xlim(axis_min, axis_max)
    ax.set_ylim(axis_min, axis_max)
    ax.set_title(f"warning_merged_gt10_diff_from_cpc: robust linear residual outlier (K={K_SIGMA_WARN:g})")
    ax.legend(loc="best")

    fig.tight_layout()
    fig.savefig(out_png, dpi=200)
    plt.close(fig)


# -------------------------------------------------------------------
# QC NETCDF WRITER (FILTER EXTREME CPC MISMATCH CHUNKS, THEN ADD FLAGS)
# -------------------------------------------------------------------
def write_qc_flagged_nc_files(
    base_dir: Path,
    out_dir: Path,
    *,
    # WARNING bounds
    r_low_warn: float,
    r_high_warn: float,
    # robust center/scale (same used to derive DROP bounds)
    r_med: float,
    sigma_r: float,
):
    """Write filtered, QC-flagged copies of every merged size-distribution file.

    For each ``*_sizedist_merged_v2.nc`` found recursively under ``base_dir``:

    1. integrate the merged distribution above ``DP_CUTOFF_NM`` and compute the
       per-chunk median of the ``CNgt10nm`` series over each chunk's time window;
    2. obtain residuals and the WARNING flag from ``linear_resid_and_flag``
       using ``[r_low_warn, r_high_warn]``;
    3. drop chunks whose finite residual lies outside
       ``r_med ± K_SIGMA_DROP * sigma_r`` (chunks without a finite residual
       are kept);
    4. copy the file, subset along the chunk dimension, to
       ``out_dir / <source file name>`` and add the two flag variables
       ``warning_high_cost`` and ``warning_merged_gt10_diff_from_cpc`` for the
       kept chunks.

    Parameters
    ----------
    base_dir
        Root directory searched recursively for input files.
    out_dir
        Directory receiving the copies; created if it does not exist.
    r_low_warn, r_high_warn
        Residual bounds passed to ``linear_resid_and_flag`` for the WARNING flag.
    r_med, sigma_r
        Robust centre and scale of the residual; used to build the DROP bounds
        and recorded in the flag variable's comment attribute.

    Raises
    ------
    FileNotFoundError
        If no input file is available to process.
    RuntimeError
        If a file lacks a required variable, has inconsistent chunk lengths,
        has no ``CNgt10nm`` data for its date, or would lose every chunk to
        the DROP filter.
    """
    r_low_drop = r_med - K_SIGMA_DROP * sigma_r
    r_high_drop = r_med + K_SIGMA_DROP * sigma_r

    # Outputs keep the source file name.  If out_dir sits under base_dir, the
    # recursive glob also returns copies written by a previous run, whose
    # destination would be the file itself; those must not be reprocessed.
    jobs = []
    for src_path in sorted(base_dir.glob("**/*_sizedist_merged_v2.nc")):
        dst_path = out_dir / src_path.name
        if src_path.resolve() == dst_path.resolve():
            print(f"[QC-FLAG+FILTER] skipping {src_path}: source and destination are the same file")
            continue
        jobs.append((src_path, dst_path))

    if not jobs:
        raise FileNotFoundError(f"No *_sizedist_merged_v2.nc files found under {base_dir}")

    out_dir.mkdir(parents=True, exist_ok=True)

    print(f"\nWill write filtered+flagged copies for {len(jobs)} NetCDF files into:\n  {out_dir}\n")
    print(
        f"[CPC WARNING BOUNDS]  K_WARN={K_SIGMA_WARN:g}:  r in [{r_low_warn:.6g}, {r_high_warn:.6g}]"
    )
    print(
        f"[CPC DROP BOUNDS]     K_DROP={K_SIGMA_DROP:g}: r in [{r_low_drop:.6g}, {r_high_drop:.6g}]"
    )

    required_vars = (
        "optimization_best_cost",
        "retrieved_uhsas_n_fit",
        "merged_dNdlogDp",
        "fine_edges_nm",
        "time_start_since_base_s",
        "time_end_since_base_s",
    )

    micro_cache = {}  # date string -> CNgt10nm series, so each day is read once
    total_dropped = 0
    total_kept = 0

    for src_path, dst_path in jobs:
        print(f"\n[QC-FLAG+FILTER] {src_path} -> {dst_path}")

        with Dataset(src_path, mode="r") as src:
            for v in required_vars:
                if v not in src.variables:
                    raise RuntimeError(f"{src_path} missing variable {v}")

            cost = _read_var_as_array(src, "optimization_best_cost")

            # The (single) dimension of retrieved_uhsas_n_fit names the chunk axis.
            uhsas_var = src.variables["retrieved_uhsas_n_fit"]
            dims = uhsas_var.dimensions
            if len(dims) != 1:
                raise RuntimeError(f"{src_path}: expected retrieved_uhsas_n_fit to be 1D, got dims={dims}")
            chunk_dim = dims[0]
            n_chunks = int(cost.size)

            # warning_high_cost (on original)
            high_cost_mask = np.isfinite(cost) & (cost > HIGH_COST_THRESH)

            # merged/CPC residuals (on original)
            merged_dNdlogDp = _read_var_as_array(src, "merged_dNdlogDp")
            fine_edges_nm = _read_var_as_array(src, "fine_edges_nm")
            time_start_s = _read_var_as_array(src, "time_start_since_base_s")
            time_end_s = _read_var_as_array(src, "time_end_since_base_s")
            base_dt = _parse_base_time(src)

            if merged_dNdlogDp.shape[0] != n_chunks:
                raise RuntimeError(f"{src_path}: merged_dNdlogDp chunk dimension mismatch")
            # Fail with a clear message instead of an IndexError (or silently
            # ignoring trailing entries) in the per-chunk loop below.
            if time_start_s.size != n_chunks or time_end_s.size != n_chunks:
                raise RuntimeError(f"{src_path}: time_start/end chunk mismatch")

            merged_total_gt10 = integrate_dNdlogDp_gt_cutoff(merged_dNdlogDp, fine_edges_nm, DP_CUTOFF_NM)

            # The leading "_"-separated token of the file stem is used as the date.
            date_str = src_path.stem.split("_")[0]
            if date_str not in micro_cache:
                micro = read_microphysical(micro_dir, start=date_str, end=None, prefix="ARCSIX")
                cpc_raw = micro.get("CNgt10nm")
                if cpc_raw is None:
                    raise RuntimeError(f"{src_path}: microphysical data for {date_str} has no 'CNgt10nm' field")
                cpc_series = pd.to_numeric(cpc_raw, errors="coerce")
                cpc_series.index = _ensure_naive_datetime_index(cpc_series.index)
                micro_cache[date_str] = cpc_series
            else:
                cpc_series = micro_cache[date_str]

            # Median CPC concentration inside each chunk's [start, end] window;
            # chunks with non-finite times or no CPC samples stay NaN.
            cpc_median = np.full(n_chunks, np.nan, dtype=float)
            for j in range(n_chunks):
                if not (np.isfinite(time_start_s[j]) and np.isfinite(time_end_s[j])):
                    continue
                t0 = base_dt + timedelta(seconds=float(time_start_s[j]))
                t1 = base_dt + timedelta(seconds=float(time_end_s[j]))
                cpc_chunk = cpc_series.loc[t0:t1]
                if cpc_chunk.size:
                    cpc_median[j] = float(cpc_chunk.median())

            # residuals and WARNING flag (K_SIGMA_WARN bounds) on original
            r_full, warn_flag = linear_resid_and_flag(
                merged_total_gt10,
                cpc_median,
                r_low=r_low_warn,
                r_high=r_high_warn,
            )

            # DROP mask (K_SIGMA_DROP bounds) on original: only where residual is finite
            ok = np.isfinite(r_full)
            drop_mask = ok & ((r_full < r_low_drop) | (r_full > r_high_drop))

            keep_mask = ~drop_mask
            keep_idx = np.flatnonzero(keep_mask).astype(int)

            n_drop = int(np.sum(drop_mask))
            n_keep = int(keep_idx.size)
            total_dropped += n_drop
            total_kept += n_keep

            print(f"    chunks: total={n_chunks}  drop_extreme(K={K_SIGMA_DROP:g})={n_drop}  keep={n_keep}")
            if n_keep <= 0:
                raise RuntimeError(f"{src_path}: all chunks were dropped by extreme CPC filter; refusing to write empty file.")

            # subset flags to kept
            high_cost_keep = high_cost_mask[keep_idx]
            warn_keep = warn_flag[keep_idx]

            # create subsetted copy; always close it, even if adding a flag fails
            dst = _copy_netcdf_subset_by_dim(src, dst_path, subset_dim=chunk_dim, keep_idx=keep_idx)
            try:
                # add ONLY kept flags to dst
                _add_flag_var(
                    dst,
                    "warning_high_cost",
                    (chunk_dim,),
                    high_cost_keep.astype(np.int8),
                    long_name="QC warning: optimization cost exceeds threshold",
                    comment=f"1 if optimization_best_cost > {HIGH_COST_THRESH}, else 0",
                )
                _add_flag_var(
                    dst,
                    "warning_merged_gt10_diff_from_cpc",
                    (chunk_dim,),
                    warn_keep.astype(np.int8),
                    long_name=f"QC warning: merged total(>{DP_CUTOFF_NM:g} nm) vs CPC is an outlier (linear residual)",
                    comment=(
                        f"r = merged_total_gt{DP_CUTOFF_NM:g}nm - CPC_median. "
                        f"Robust bounds: r within median±{K_SIGMA_WARN:g}*(1.4826*MAD(r)). "
                        f"Computed over ALL chunks: r_med={r_med}, sigma_r={sigma_r}, "
                        f"r_low_warn={r_low_warn}, r_high_warn={r_high_warn}. "
                        f"Flag=1 if r outside [r_low_warn,r_high_warn]. "
                        f"NOTE: extreme chunks were removed using K_DROP={K_SIGMA_DROP:g} before writing this file."
                    ),
                )
            finally:
                dst.close()

            print(
                f"    warnings (on kept): high_cost={int(np.nansum(high_cost_keep))}, "
                f"merged_linear_outlier={int(np.nansum(warn_keep))}"
            )

    print("\nDone writing QC-flagged NetCDF copies.")
    print(f"[TOTAL] dropped_extreme={total_dropped}  kept={total_kept}")


# -------------------------------------------------------------------
# SUMMARY STATS (only kept warnings; computed over ALL chunks)
# -------------------------------------------------------------------
def print_warning_summary(flag_dict: dict[str, np.ndarray]):
    """Print per-flag and combined warning counts with percentages.

    Parameters
    ----------
    flag_dict
        Maps a warning name to its per-chunk mask.  Every mask must have the
        same ``size``; each is interpreted as boolean when counted.

    Raises
    ------
    ValueError
        If ``flag_dict`` has no entries.
    RuntimeError
        If the masks do not all have the same size.
    """
    names = list(flag_dict)
    if not names:
        raise ValueError("flag_dict is empty")

    # The first mask defines the expected number of chunks.
    n_total = int(flag_dict[names[0]].size)
    if any(int(flag_dict[name].size) != n_total for name in names[1:]):
        raise RuntimeError("Flag masks have inconsistent lengths")

    def _percent(count: int) -> float:
        # NaN (rendered as "nan") when there are no chunks at all.
        return (100.0 * count / n_total) if n_total > 0 else np.nan

    print("\n=== WARNING SUMMARY (per-chunk; ALL chunks, K_WARN applied) ===")
    print(f"Total chunks: {n_total}")

    flagged_by_any = np.zeros(n_total, dtype=bool)
    for name in names:
        mask = np.asarray(flag_dict[name], bool)
        n_flagged = int(np.count_nonzero(mask))
        flagged_by_any |= mask
        print(f"  {name}: {n_flagged} ({_percent(n_flagged):.2f}%)")

    n_any = int(np.count_nonzero(flagged_by_any))
    n_clean = int(n_total - n_any)
    print(f"  any_warning: {n_any} ({_percent(n_any):.2f}%)")
    print(f"  warning_free: {n_clean} ({_percent(n_clean):.2f}%)")


# -------------------------------------------------------------------
# MAIN
# -------------------------------------------------------------------
# Driver: collect per-chunk quantities from the merged files under BASE_DIR,
# derive the two QC warning masks, write diagnostic plots and a CSV into
# QC_DIR, then write filtered + flagged NetCDF copies into QC_NC_DIR.
# The six arrays below are unpacked in the order gather_all_chunks returns them.
(
    all_cost,
    all_uhsas_n,
    all_pops_n,
    all_rho,
    all_cpc_median,
    all_merged_total_gt10,
) = gather_all_chunks(BASE_DIR)

# warning_high_cost mask: True where the cost is finite and exceeds HIGH_COST_THRESH
mask_cost = np.isfinite(all_cost) & (all_cost > HIGH_COST_THRESH)

# Robust WARNING bounds for the merged-vs-CPC linear residual, computed with
# k_sigma=K_SIGMA_WARN. Also returns the centre (r_med) and scale (sigma_r);
# the log message below reports sigma_r as 1.4826*MAD.
# (The original note gave K=10 -- the authoritative value is K_SIGMA_WARN.)
r_low_warn, r_high_warn, r_med, sigma_r = compute_robust_linear_bounds(
    merged=all_merged_total_gt10,
    cpc=all_cpc_median,
    k_sigma=K_SIGMA_WARN,
    min_points=MIN_POINTS_FOR_ROBUST,
)

# Per-chunk residual and WARNING mask evaluated against those bounds
r_metric, mask_merged_vs_cpc_warn = linear_resid_and_flag(
    all_merged_total_gt10,
    all_cpc_median,
    r_low=r_low_warn,
    r_high=r_high_warn,
)

# Drop bounds, used here only for the log message; write_qc_flagged_nc_files
# derives the same bounds itself from r_med and sigma_r.
r_low_drop = r_med - K_SIGMA_DROP * sigma_r
r_high_drop = r_med + K_SIGMA_DROP * sigma_r

print(
    f"\n[MERGED/CPC LINEAR] r_med={r_med:.6g}, sigma_r={sigma_r:.6g} (sigma=1.4826*MAD), "
    f"K_WARN={K_SIGMA_WARN:g}, K_DROP={K_SIGMA_DROP:g}"
)
print(
    f"[WARNING BOUNDS] r_low_warn={r_low_warn:.6g}, r_high_warn={r_high_warn:.6g}"
)
print(
    f"[DROP BOUNDS]    r_low_drop={r_low_drop:.6g}, r_high_drop={r_high_drop:.6g}"
)

# "valid" = chunks with a finite residual
valid = np.isfinite(r_metric)
print(f"[MERGED/CPC WARNING FLAG COUNT] {int(np.nansum(mask_merged_vs_cpc_warn))} flagged / {int(np.sum(valid))} valid")

# PLOTS (all written into QC_DIR)
plot_cost_hist(all_cost, QC_DIR / "hist_optimization_best_cost_with_thresh.png")

# Distribution histograms for the retrieved parameters; no threshold lines on these
plot_simple_hist(all_uhsas_n, QC_DIR / "hist_uhsas_n.png", xlabel="retrieved_uhsas_n_fit", title="UHSAS n distribution")
plot_simple_hist(all_pops_n, QC_DIR / "hist_pops_n.png", xlabel="retrieved_pops_n_fit", title="POPS n distribution")
plot_simple_hist(all_rho, QC_DIR / "hist_aps_rho.png", xlabel="retrieved_aps_density (kg m$^{-3}$)", title="APS density distribution")

# Merged total vs CPC median, with WARNING-flagged chunks highlighted
plot_scatter_merged_vs_cpc_flagged_linear(
    cpc=all_cpc_median,
    merged=all_merged_total_gt10,
    flag=mask_merged_vs_cpc_warn,
    r_low=r_low_warn,
    r_high=r_high_warn,
    out_png=QC_DIR / "scatter_merged_gt10nm_vs_cpc_flagged.png",
)

# SUMMARY STATS for the two WARNING masks, counted over all gathered chunks
# (the K_SIGMA_DROP filter has not been applied at this point)
print_warning_summary(
    {
        "warning_high_cost": mask_cost,
        "warning_merged_gt10_diff_from_cpc": mask_merged_vs_cpc_warn,
    }
)

# CSV: one row per chunk with the retrieved quantities, the residual and both
# WARNING flags stored as 0/1
out_csv = QC_DIR / "per_chunk_qc_with_flags.csv"
df = pd.DataFrame(
    {
        "optimization_best_cost": all_cost,
        "retrieved_uhsas_n_fit": all_uhsas_n,
        "retrieved_pops_n_fit": all_pops_n,
        "retrieved_aps_density": all_rho,
        "CPC_median_CNgt10nm": all_cpc_median,
        f"MERGED_total_gt{int(DP_CUTOFF_NM)}nm": all_merged_total_gt10,
        "linear_residual_merged_minus_cpc": r_metric,
        "warning_high_cost": mask_cost.astype(int),
        "warning_merged_gt10_diff_from_cpc": mask_merged_vs_cpc_warn.astype(int),
    }
)
df.to_csv(out_csv, index=False)
print(f"\nWrote per-chunk QC table to {out_csv}")

# Write QC-flagged NetCDF copies into QC_NC_DIR:
# - remove extreme CPC-mismatch chunks using the K_SIGMA_DROP bounds
# - then add warning_high_cost + warning_merged_gt10_diff_from_cpc
#   (K_SIGMA_WARN bounds) for the remaining chunks
# (The original note gave K_DROP=20 and K_WARN=10 -- confirm against the constants.)
write_qc_flagged_nc_files(
    BASE_DIR,
    QC_NC_DIR,
    r_low_warn=r_low_warn,
    r_high_warn=r_high_warn,
    r_med=r_med,
    sigma_r=sigma_r,
)