# In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Gaia distance pipeline with QC/adoption:
- Resolves coords (from input RA/Dec if present, else SIMBAD with Bayer/Flamsteed variants)
- Pulls Bailer-Jones (I/352) distances + quantiles
- Pulls DR3 (I/355) astrometry; optional EDR3 (I/350) parallax
- Applies ZPT (piecewise by G & BP-RP), RUWE inflation (with floor & micro-sys)
- Smartly picks EDR3 vs DR3 parallax (min sigma_total with penalties)
- EDSD posterior (adaptive grid) → p16/p50/p84 and ± error
- Computes parallax SNR and QC flags; AUTO-ADOPTS final distance using thresholds
- Outputs:
    gaia_full_with_ruwe.csv
    gaia_distances_new.csv
    gaia_distances_new_with_ci.csv
    gaia_distances_summary.csv  (includes adopted_* and quality_note)
"""

import time
import re
import numpy as np
import pandas as pd

from astroquery.simbad import Simbad
from astroquery.vizier import Vizier
from astropy.coordinates import SkyCoord, BarycentricTrueEcliptic, Angle
import astropy.units as u

try:
    from tqdm.auto import tqdm
except Exception:
    def tqdm(x, **kwargs): return x

# =================== CONFIG ===================

# Excel workbook holding the object list, and the sheet to read from it.
# A falsy SHEET makes the loader call pandas.read_excel without sheet_name.
FILE = "obj_list.xlsx"
SHEET = None

# VizieR catalogue identifiers queried by the pipeline.
CAT_BJ   = "I/352/gedr3dis"   # Bailer-Jones 2021 (Gaia EDR3 distances)
CAT_EDR3 = "I/350/gaiaedr3"   # EDR3 parallax by Source
CAT_DR3  = "I/355/gaiadr3"    # DR3 main

# Cone-search radii (arcsec) used around the resolved position.
SEARCH_RADIUS_ARCSEC = 8.0     # BJ cone
DR3_FALLBACK_ARCSEC  = 5.0     # DR3 cone fallback

# EDSD prior scale (can be replaced by a map)
# Returned by L_of_sky() when no finite Bailer-Jones distance is available.
L_EDSD_PC_DEFAULT = 1350.0

# RUWE inflation (heuristics akin to El-Badry 2025)
# Coefficients consumed by ruwe_inflation_factor().
E_ALPHA = 2.77
E_F0    = 3.73
E_BETA  = 0.065
E_GAMMA = -0.056
SIGMA_FLOOR_MAS = 0.02   # minimal random floor (mas)
SIGMA_SYS_MAS   = 0.015  # micro systematic (mas), added in quadrature

# Zero-point (piecewise; fallback constants)
# The constants are used by apply_zpt() when G, BP or RP is not finite;
# the zero-point is subtracted from the catalogue parallax.
USE_ZPT = True
ZPT_CONST_DR3  = -0.017
ZPT_CONST_EDR3 = -0.017

# Fallback for missing parallax error
# When enabled, a candidate with a non-finite e_Plx is kept and assigned
# DEFAULT_PLX_ERR_MAS instead of being discarded.
USE_ERR_FALLBACK     = True
DEFAULT_PLX_ERR_MAS  = 0.5

# Network retries
# with_retries() makes up to MAX_TRIES attempts and sleeps
# SLEEP_BASE * 2**i seconds after failed attempt i.
MAX_TRIES = 3
SLEEP_BASE = 0.6

# Verbosity
# Rows whose DataFrame index is below this value print diagnostics.
VERBOSE_FIRST_N = 3

# Adoption thresholds (tune if needed)
SNR_MIN = 2.5              # minimal SNR for adopting EDSD_new
MAX_REL_ERR_NEW = 0.40     # max relative error for EDSD_new
MAX_REL_ERR_OLD = 0.60     # max relative error for BJ_old
MAX_DR3_SEP_ARCSEC = 2.0   # max acceptable DR3 match separation
REQUIRE_GOOD_QC = True     # require good Gaia QC (no dup, good Nper, etc.)

# ================== INPUT ==================
# Load the object list; a falsy SHEET falls back to read_excel's default sheet.
df = pd.read_excel(FILE, sheet_name=SHEET) if SHEET else pd.read_excel(FILE)
# Header names become stripped strings so the column detection below can
# compare them case-insensitively.
df.columns = [str(c).strip() for c in df.columns]
print("Columns:", df.columns.tolist())

def pick_col(cands):
    """Return the first column of the global ``df`` named like a candidate.

    The comparison ignores case.  Columns are scanned in DataFrame order,
    so the left-most matching column wins; ``None`` means no column matched.
    """
    wanted = {cand.lower() for cand in cands}
    return next((col for col in df.columns if col.lower() in wanted), None)

# Map the logical fields onto whatever the spreadsheet calls them.
# Each *_COL is the actual column name, or None when no candidate matched.
NAME_COL  = pick_col(["name","object","obj","star","id"])
HD_COL    = pick_col(["HD","hd"])
BS_COL    = pick_col(["BS","HR","hr","bs"])
RA_COL    = pick_col(["ra","ra_deg","RA","RA_deg"])
DEC_COL   = pick_col(["dec","dec_deg","DE","DE_deg","decl","decl_deg"])
print("Detected:", {"name": NAME_COL, "HD": HD_COL, "BS/HR": BS_COL, "RA": RA_COL, "DEC": DEC_COL})

# =============== NAME NORMALIZATION ===============
# Capitalised Greek-letter spelling (full name or 3-letter abbreviation)
# -> lower-case full letter name.  Looked up by bayer_variants().
GREEK_MAP = {
    "Alpha":"alpha","Alp":"alpha","Beta":"beta","Bet":"beta","Gamma":"gamma","Gam":"gamma",
    "Delta":"delta","Del":"delta","Epsilon":"epsilon","Eps":"epsilon","Zeta":"zeta","Zet":"zeta",
    "Eta":"eta","Theta":"theta","The":"theta","Iota":"iota","Iot":"iota","Kappa":"kappa","Kap":"kappa",
    "Lambda":"lambda","Lam":"lambda","Mu":"mu","Nu":"nu","Xi":"xi","Omicron":"omicron","Omi":"omicron",
    "Pi":"pi","Rho":"rho","Sigma":"sigma","Sig":"sigma","Tau":"tau","Upsilon":"upsilon","Ups":"upsilon",
    "Phi":"phi","Chi":"chi","Psi":"psi","Omega":"omega","Ome":"omega",
}
# Lower-case full letter name -> short abbreviation used for the second
# name variant produced by bayer_variants().
GREEK_ABBR = {
    "alpha":"alf","beta":"bet","gamma":"gam","delta":"del","epsilon":"eps","zeta":"zet",
    "eta":"eta","theta":"the","iota":"iot","kappa":"kap","lambda":"lam","mu":"mu","nu":"nu",
    "xi":"xi","omicron":"omi","pi":"pi","rho":"rho","sigma":"sig","tau":"tau","upsilon":"ups",
    "phi":"phi","chi":"chi","psi":"psi","omega":"ome"
}
# Three-letter constellation codes.  Every entry currently maps to itself and
# unknown codes are passed through unchanged by bayer_variants(), so this
# table is a hook for future normalisation rather than an active filter.
IAU3 = {
    "And":"And","Ant":"Ant","Aps":"Aps","Aql":"Aql","Aqr":"Aqr","Ara":"Ara","Ari":"Ari",
    "Aur":"Aur","Boo":"Boo","CMa":"CMa","CMi":"CMi","Cas":"Cas","Cen":"Cen","Cep":"Cep",
    "Cyg":"Cyg","Gem":"Gem","Ori":"Ori","Per":"Per","Tri":"Tri","UMa":"UMa","UMi":"UMi",
    # ... extend if needed
}

def normalize_bd_like(s: str) -> str:
    """Tidy the whitespace of a BD/CD/CPD designation.

    Three regex passes run in order on the stripped input:
    the catalogue prefix is glued to the declination sign, a BD zone and
    running number are separated by one space, and every remaining
    whitespace run is collapsed to a single space.
    """
    rewrite_rules = (
        (r'^(BD|CD|CPD)\s*([+-])\s*', r'\1\2', re.IGNORECASE),
        (r'^(BD[+-]\d+)\s+(\d+)$', r'\1 \2', re.IGNORECASE),
        (r'\s+', ' ', 0),
    )
    result = s.strip()
    for pattern, replacement, flags in rewrite_rules:
        result = re.sub(pattern, replacement, result, flags=flags)
    return result

def bayer_variants(s: str):
    """Return alternative spellings of a Bayer designation.

    Accepts "<greek><digits><Con>" with or without spaces (for example
    "AlpCen", "Alp Cen", "Pi1 Ori") and returns a set holding the
    full-letter form ("alpha Cen") and the abbreviated form ("alf Cen").
    Anything that does not look like a Bayer name gives an empty set.

    Two-letter prefixes are accepted so that Mu, Nu, Xi and Pi (all present
    in GREEK_MAP) are reachable.  To avoid inventing Bayer names for
    unrelated designations, a two-letter prefix is honoured only when it is
    a known Greek letter and is not written fully upper-case (two upper-case
    letters is the variable-star convention, e.g. "NU Ori" versus "nu Ori").
    """
    out = set()
    t = s.strip()
    m = re.match(r'^([A-Za-z]{2,7})(\d{0,2})([A-Z][A-Za-z]{2})$', t) \
        or re.match(r'^([A-Za-z]{2,7})\s*(\d{0,2})\s+([A-Z][A-Za-z]{2})$', t)
    if not m:
        return out
    greek_raw, digit, const = m.groups()
    key = greek_raw.capitalize()
    if len(greek_raw) == 2 and (greek_raw.isupper() or key not in GREEK_MAP):
        # Not a Greek letter: behave exactly as the stricter 3+ letter rule did.
        return out
    # Unknown 3+ letter prefixes fall back to their lower-cased spelling.
    greek_full = GREEK_MAP.get(key, greek_raw.lower())
    abbr = GREEK_ABBR.get(greek_full, greek_full[:3])
    const_std = IAU3.get(const, const)
    out.update({f"{greek_full}{digit} {const_std}".strip(),
                f"{abbr}{digit} {const_std}".strip()})
    return out

def flamsteed_variants(s: str):
    """Return spellings of a Flamsteed designation such as "61 Cyg".

    For "<1-3 digits> <Con>" the result holds the stripped input and the
    same name with the number glued to the constellation ("61Cyg").
    Any other input yields an empty set.
    """
    text = s.strip()
    hit = re.match(r'^(\d{1,3})\s+([A-Z][A-Za-z]{2})$', text)
    if hit is None:
        return set()
    number, constellation = hit.groups()
    return {text, f"{number}{constellation}"}

def clean_name(name):
    """Return a stripped object name, or ``None`` for a missing value.

    Names starting with BD, CD or CPD (any case) are additionally passed
    through normalize_bd_like().
    """
    if pd.isna(name):
        return None
    text = str(name).strip()
    looks_like_dm = re.match(r'^(BD|CD|CPD)\b', text, flags=re.IGNORECASE) is not None
    return normalize_bd_like(text) if looks_like_dm else text

def is_int_like(x):
    """Convert *x* to an ``int`` by way of ``float``, or return ``None``.

    Despite the name this returns the converted integer (fractions are
    truncated), not a boolean.  Any conversion failure gives ``None``.
    """
    try:
        text = str(x).strip()
        number = float(text)
        return int(number)
    except Exception:
        return None

def build_candidates(row):
    """Build the list of identifiers to try in SIMBAD for one input row.

    Order is deterministic: the name exactly as typed comes first, then its
    cleaned form, then Bayer/Flamsteed spellings (sorted), BD spacing
    variants, an optional "V* " variable-star form, and finally "HD n" and
    "HR n" built from the catalogue-number columns.  Duplicates are dropped
    and identifiers shorter than three characters are discarded.

    The previous implementation collected names in a ``set``, so the query
    order (and hence which identifier ended up in ``resolved_as``) could
    differ from run to run.
    """
    ordered = {}  # dict keys serve as an insertion-ordered set

    def add(*names):
        for n in names:
            ordered.setdefault(n, None)

    if NAME_COL and not pd.isna(row.get(NAME_COL)):
        raw = str(row[NAME_COL]).strip()
        base = clean_name(raw)
        add(raw)
        if base:
            add(base)
        # Variants are derived from the raw/base spellings only (snapshot).
        for x in list(ordered):
            add(*sorted(bayer_variants(x)))
            add(*sorted(flamsteed_variants(x)))
        # Offer both "BD+12 345" and "BD +12 345" style spacings.
        for c in list(ordered):
            if "BD " in c:
                add(c.replace("BD ", "BD"))
            if "BD+" in c:
                add(c.replace("BD+", "BD +"))
        if base and re.match(r'^V\d{1,4}\s+[A-Z][a-z]{2}$', base):
            add("V* " + base)

    if HD_COL and not pd.isna(row.get(HD_COL)):
        num = is_int_like(row[HD_COL])
        if num is not None:
            add(f"HD {num}")
    if BS_COL and not pd.isna(row.get(BS_COL)):
        num = is_int_like(row[BS_COL])
        if num is not None:
            add(f"HR {num}")

    return [c for c in ordered if c and len(c) >= 3]

# =============== PARSERS ===============
def parse_ra(val):
    """Parse a right ascension cell into degrees, or ``None`` on failure.

    Text containing ":" or a space is read as sexagesimal hours and
    converted to degrees; anything else is taken as a plain number and
    returned unchanged.  Missing values and unparsable text give ``None``.
    """
    if pd.isna(val):
        return None
    text = str(val).strip()
    sexagesimal = (":" in text) or (" " in text)
    try:
        if not sexagesimal:
            return float(text)
        return Angle(text, unit=u.hourangle).to(u.deg).value
    except Exception:
        return None

def parse_dec(val):
    """Parse a declination cell into degrees, or ``None`` on failure.

    Text containing ":" or a space is read as sexagesimal degrees; anything
    else is taken as a plain number.  Missing values and unparsable text
    give ``None``.
    """
    if pd.isna(val):
        return None
    text = str(val).strip()
    sexagesimal = (":" in text) or (" " in text)
    try:
        if not sexagesimal:
            return float(text)
        return Angle(text, unit=u.deg).to(u.deg).value
    except Exception:
        return None

# ===================== NET HELPERS (retries) =====================
def with_retries(fn, *args, **kwargs):
    """Call ``fn(*args, **kwargs)``, retrying on any exception.

    Up to MAX_TRIES attempts are made.  After failed attempt ``i`` the
    function sleeps ``SLEEP_BASE * 2**i`` seconds; the exception from the
    final attempt is re-raised to the caller.
    """
    final_attempt = MAX_TRIES - 1
    for attempt in range(MAX_TRIES):
        try:
            result = fn(*args, **kwargs)
        except Exception:
            if attempt == final_attempt:
                raise
            time.sleep(SLEEP_BASE * (2**attempt))
        else:
            return result

# ===================== SIMBAD =====================
# Module-level SIMBAD setup shared by every Simbad.query_object call below.
Simbad.TIMEOUT = 60
# Request coordinate columns; _coords_from_simbad_table() looks them up
# case-insensitively under the names "ra" and "dec".
Simbad.add_votable_fields("ra","dec")

def _coords_from_simbad_table(res):
    cmap = {c.lower(): c for c in res.colnames}
    if 'ra' in cmap and 'dec' in cmap:
        ra_val = res[cmap['ra']][0]; dec_val = res[cmap['dec']][0]
        try:
            if hasattr(ra_val, 'unit') and hasattr(dec_val, 'unit'):
                coord = SkyCoord(ra_val, dec_val)
                return coord.ra.deg, coord.dec.deg
        except Exception:
            pass
        if isinstance(ra_val,(str,bytes)) or isinstance(dec_val,(str,bytes)):
            coord = SkyCoord(str(ra_val), str(dec_val), unit=(u.hourangle, u.deg))
            return coord.ra.deg, coord.dec.deg
        return float(ra_val), float(dec_val)
    raise KeyError(f"SIMBAD columns not found: {res.colnames}")

def find_coords(row, verbose=False):
    """Resolve sky coordinates for one input row.

    Coordinates from the spreadsheet are preferred when both RA and Dec
    parse successfully.  Otherwise every identifier from build_candidates()
    is tried in SIMBAD until one returns finite coordinates.

    Returns:
        ``(ra_deg, dec_deg, resolved_name, resolver)`` where ``resolver`` is
        "INPUT" (``resolved_name`` is then ``None``) or "SIMBAD";
        ``(None, None, None, None)`` when nothing could be resolved.
    """
    # prefer RA/Dec from input
    if RA_COL and DEC_COL:
        ra_d = parse_ra(row.get(RA_COL))
        dec_d = parse_dec(row.get(DEC_COL))
        if ra_d is not None and dec_d is not None:
            if verbose:
                print(f"RA/Dec from input → {ra_d:.6f}, {dec_d:.6f}")
            return ra_d, dec_d, None, "INPUT"

    # otherwise resolve by names
    for cand in build_candidates(row):
        try:
            res = with_retries(Simbad.query_object, cand)
            if res is None or len(res) == 0:
                # The skip must not depend on ``verbose``; previously the
                # ``continue`` was part of the one-line ``if verbose:`` body.
                if verbose:
                    print("SIMBAD miss:", cand)
                continue
            ra, dec = _coords_from_simbad_table(res)
            if not (np.isfinite(ra) and np.isfinite(dec)):
                # Object known to SIMBAD but without usable coordinates.
                if verbose:
                    print("SIMBAD miss:", cand)
                continue
            if verbose:
                print(f"SIMBAD hit: {cand} → {ra:.6f}, {dec:.6f}")
            return ra, dec, cand, "SIMBAD"
        except Exception as e:
            # Best effort: a failing identifier must not abort the row.
            if verbose:
                print("SIMBAD error", cand, e)
    return None, None, None, None

# ==================== VizieR ====================
# Row cap set on the Vizier class before the query objects are created.
Vizier.ROW_LIMIT = 200

# Bailer-Jones distances: point estimate plus lower (b_) / upper (B_) bounds
# for the geometric (rgeo) and photogeometric (rpgeo) solutions.
viz_bj   = Vizier(columns=[
    "Source","RA_ICRS","DE_ICRS",
    "rgeo","b_rgeo","B_rgeo",
    "rpgeo","b_rpgeo","B_rpgeo",
    "Flag"
])

# EDR3: only the parallax and its error, looked up by Source id.
viz_edr3 = Vizier(columns=["Source","Plx","e_Plx"])

# DR3: astrometry, photometry and quality indicators read by get_dr3_info().
# Both VizieR-style and Gaia-archive-style names are requested; the reader
# falls back from one spelling to the other.
viz_dr3  = Vizier(columns=[
    "Source","RUWE","Gmag","Plx","e_Plx",
    "BPmag","RPmag","phot_bp_mean_mag","phot_rp_mean_mag",
    "nu_eff_used_in_astrometry","pseudocolour",
    "amax","Nper","Solved","Dup",
    "RA_ICRS","DE_ICRS"
])

def _to_deg(col):
    """Return *col* as an astropy quantity in degrees.

    Values that already carry a unit are converted; dimensionless values
    are interpreted as degrees.  If building the quantity fails, the values
    are coerced to a float array and tagged as degrees.
    """
    try:
        angle = u.Quantity(col)
        unitless = angle.unit == u.dimensionless_unscaled
        return (angle * u.deg if unitless else angle).to(u.deg)
    except Exception:
        values = np.array(col, dtype=float)
        return values * u.deg

def _f(x):
    try:
        if hasattr(x, "mask") and getattr(x, "mask", False): return np.nan
        return float(x)
    except Exception:
        return np.nan

# ============ BJ distances ============
def get_bj_distance(ra_deg, dec_deg, radius_arcsec=SEARCH_RADIUS_ARCSEC, verbose=False):
    """Look up the Bailer-Jones distance nearest to a sky position.

    Runs a cone search of ``radius_arcsec`` in CAT_BJ around the position.
    When the result carries RA_ICRS/DE_ICRS the closest row is used,
    otherwise the first row.  The photogeometric distance (rpgeo) is
    preferred and the geometric one (rgeo) is the fallback.

    Returns:
        A dict with keys ``distance``, ``p16``, ``p84``, ``err_sym``,
        ``used_model``, ``source_id``, ``flag``, ``bj_ra_deg``,
        ``bj_dec_deg`` and ``bj_sep_arcsec``; or ``None`` when there is no
        match or any exception is raised along the way.
    """
    try:
        pos = SkyCoord(ra=ra_deg*u.deg, dec=dec_deg*u.deg, frame='icrs')
        res = with_retries(viz_bj.query_region, pos, radius=radius_arcsec*u.arcsec, catalog=CAT_BJ)
        if not res or len(res[0])==0:
            if verbose: print("BJ: no matches")
            return None
        tab = res[0]
        # Position of the matched row and its offset stay NaN unless the
        # table provides coordinates.
        bj_ra = bj_dec = np.nan
        sep_as = np.nan
        if "RA_ICRS" in tab.colnames and "DE_ICRS" in tab.colnames:
            tpos = SkyCoord(ra=_to_deg(tab["RA_ICRS"]), dec=_to_deg(tab["DE_ICRS"]), frame='icrs')
            # nearest catalogue row to the requested position
            idx = int(np.argmin(pos.separation(tpos).arcsec))
            row = tab[idx]
            sep_as = float(pos.separation(tpos[idx]).arcsec)
            bj_ra  = float(tpos[idx].ra.deg)
            bj_dec = float(tpos[idx].dec.deg)
        else:
            row = tab[0]

        def pack(prefix):
            # Returns (value, lower bound, upper bound, symmetric error) for
            # the column family `prefix`, `b_<prefix>`, `B_<prefix>`.
            val = _f(row[prefix]) if prefix in row.colnames else np.nan
            lo  = _f(row["b_"+prefix]) if ("b_"+prefix) in row.colnames else np.nan
            hi  = _f(row["B_"+prefix]) if ("B_"+prefix) in row.colnames else np.nan
            err_sym = np.nan
            if np.isfinite(lo) and np.isfinite(hi):
                # half-width of the [lo, hi] interval
                err_sym = 0.5*(hi - lo)
            else:
                # only one bound available: use the one-sided offset(s)
                err_plus  = (hi - val) if np.isfinite(hi) and np.isfinite(val) else np.nan
                err_minus = (val - lo) if np.isfinite(lo) and np.isfinite(val) else np.nan
                if np.isfinite(err_plus) or np.isfinite(err_minus):
                    err_sym = np.nanmean([err_plus, err_minus])
            return val, lo, hi, err_sym

        # Prefer the photogeometric estimate when it is present and finite.
        have_rp = ("rpgeo" in row.colnames) and np.isfinite(_f(row["rpgeo"]))
        if have_rp:
            dist, p16, p84, err_sym = pack("rpgeo"); used = "rpgeo"
        else:
            dist, p16, p84, err_sym = pack("rgeo");  used = "rgeo"

        source_id = int(row["Source"]) if "Source" in row.colnames else None
        flag_str  = str(row["Flag"]) if "Flag" in row.colnames else ""
        if verbose:
            print(f"BJ {used}: D={dist} [{p16},{p84}] (±~{err_sym}) src={source_id} flag={flag_str} sep={sep_as:.2f}\"")
        return dict(distance=dist, p16=p16, p84=p84, err_sym=err_sym,
                    used_model=used, source_id=source_id, flag=flag_str,
                    bj_ra_deg=bj_ra, bj_dec_deg=bj_dec, bj_sep_arcsec=sep_as)
    except Exception as e:
        # Any failure (network, parsing, masked Source id) is reported as
        # "no distance" so the caller can carry on with the next object.
        if verbose: print("BJ error:", e)
        return None

# ================= EDR3 parallax =================
def get_edr3_parallax(source_id, verbose=False):
    """Fetch ``(Plx, e_Plx)`` from CAT_EDR3 for a Gaia source id.

    Returns ``(nan, nan)`` when ``source_id`` is ``None``, when the query
    yields no rows, or when the query raises.
    """
    missing = (np.nan, np.nan)
    if source_id is None:
        return missing
    try:
        tables = with_retries(viz_edr3.query_constraints, catalog=CAT_EDR3, Source=str(source_id))
        if tables and len(tables[0]) > 0:
            first = tables[0][0]
            return _f(first["Plx"]), _f(first["e_Plx"])
    except Exception as e:
        if verbose:
            print("EDR3 plx error:", e)
    return missing

# ================= DR3 info =================
def _get(row, *names):
    for n in names:
        try:
            return _f(row[n])
        except Exception:
            continue
    return np.nan

def get_dr3_info(source_id, ra_deg=None, dec_deg=None, bj_ra_deg=None, bj_dec_deg=None, verbose=False):
    """Collect Gaia DR3 astrometry, photometry and quality fields for a star.

    Lookup strategy:
      1. by ``source_id`` in CAT_DR3, when one is given;
      2. otherwise (or if that fails) a DR3_FALLBACK_ARCSEC cone search,
         first around the Bailer-Jones position, then around the resolved
         position, taking the nearest row.

    ``dr3_sep_arcsec`` is the separation between the DR3 row and the
    reference position (Bailer-Jones position if finite, else the resolved
    position; NaN when neither is available).

    Returns:
        A dict with the keys ``ruwe``, ``gmag``, ``bp``, ``rp``, ``nu_eff``,
        ``pseudocolour``, ``dr3_parallax``, ``dr3_parallax_err``,
        ``astrometric_sigma5d_max``, ``visibility_periods_used``,
        ``astrometric_params_solved``, ``duplicated_source``,
        ``dr3_ra_deg``, ``dr3_dec_deg`` and ``dr3_sep_arcsec``.  All numeric
        fields are NaN (and ``duplicated_source`` False) when nothing is
        found.
    """
    def usable(value):
        # ``np.isfinite(None)`` raises, and None is the documented default.
        return (value is not None) and bool(np.isfinite(value))

    def pack(row, pos_ref=None):
        dup_val = _get(row, "Dup")
        ra_dr3  = _get(row, "RA_ICRS")
        de_dr3  = _get(row, "DE_ICRS")
        sep_as  = np.nan
        if np.isfinite(ra_dr3) and np.isfinite(de_dr3) and pos_ref is not None:
            sep_as = float(pos_ref.separation(SkyCoord(ra=ra_dr3*u.deg, dec=de_dr3*u.deg)).arcsec)
        # Each field is read under its VizieR name first, then the
        # Gaia-archive spelling.
        return dict(
            ruwe=_get(row, "RUWE","ruwe"),
            gmag=_get(row, "Gmag","phot_g_mean_mag"),
            bp=_get(row, "BPmag","phot_bp_mean_mag"),
            rp=_get(row, "RPmag","phot_rp_mean_mag"),
            nu_eff=_get(row, "nu_eff_used_in_astrometry"),
            pseudocolour=_get(row, "pseudocolour"),
            dr3_parallax=_get(row, "Plx","parallax"),
            dr3_parallax_err=_get(row, "e_Plx","e_parallax"),
            astrometric_sigma5d_max=_get(row, "amax","astrometric_sigma5d_max"),
            visibility_periods_used=_get(row, "Nper","visibility_periods_used"),
            astrometric_params_solved=_get(row, "Solved","astrometric_params_solved"),
            duplicated_source=(bool(int(dup_val)) if np.isfinite(dup_val) else False),
            dr3_ra_deg=ra_dr3, dr3_dec_deg=de_dr3,
            dr3_sep_arcsec=sep_as
        )

    if source_id is not None:
        try:
            res = with_retries(viz_dr3.query_constraints, catalog=CAT_DR3, Source=str(source_id))
            if res and len(res[0])>0:
                row = res[0][0]
                pos_ref = None
                if usable(bj_ra_deg) and usable(bj_dec_deg):
                    pos_ref = SkyCoord(ra=bj_ra_deg*u.deg, dec=bj_dec_deg*u.deg)
                elif usable(ra_deg) and usable(dec_deg):
                    # Uses the dec_deg parameter; the earlier code referenced
                    # an undefined local ``dec`` and only worked by reading
                    # the main loop's global of that name.
                    pos_ref = SkyCoord(ra=ra_deg*u.deg, dec=dec_deg*u.deg)
                return pack(row, pos_ref=pos_ref)
        except Exception as e:
            if verbose: print("DR3 by Source error:", e)

    for (r0, d0, label) in [(bj_ra_deg, bj_dec_deg, "BJ"), (ra_deg, dec_deg, "SIMBAD")]:
        if usable(r0) and usable(d0):
            try:
                pos = SkyCoord(ra=r0*u.deg, dec=d0*u.deg, frame='icrs')
                res = with_retries(viz_dr3.query_region, pos, radius=DR3_FALLBACK_ARCSEC*u.arcsec, catalog=CAT_DR3)
                if res and len(res[0])>0:
                    tab = res[0]
                    if "RA_ICRS" in tab.colnames and "DE_ICRS" in tab.colnames:
                        tpos = SkyCoord(ra=_to_deg(tab["RA_ICRS"]), dec=_to_deg(tab["DE_ICRS"]), frame='icrs')
                        idx = int(np.argmin(pos.separation(tpos).arcsec))
                        row = tab[idx]
                    else:
                        row = tab[0]
                    out = pack(row, pos_ref=pos)
                    if verbose: print(f"DR3 cone@{label}: sep={out['dr3_sep_arcsec']:.2f}\"")
                    return out
            except Exception as e:
                if verbose: print("DR3 by pos error:", e)

    return dict(
        ruwe=np.nan, gmag=np.nan, bp=np.nan, rp=np.nan,
        nu_eff=np.nan, pseudocolour=np.nan,
        dr3_parallax=np.nan, dr3_parallax_err=np.nan,
        astrometric_sigma5d_max=np.nan, visibility_periods_used=np.nan,
        astrometric_params_solved=np.nan, duplicated_source=False,
        dr3_ra_deg=np.nan, dr3_dec_deg=np.nan, dr3_sep_arcsec=np.nan
    )

# ================= RUWE inflation =================
def sigma_eta_mas_from_G(g):
    """Return the tabulated noise term (mas) for G magnitude *g*.

    The value is linearly interpolated on a fixed (G, sigma) table and held
    constant outside the tabulated range; a non-finite *g* gives 0.1.
    """
    if not np.isfinite(g):
        return 0.1
    g_nodes = np.array([6, 10, 12, 13, 14, 16, 18, 19, 20], float)
    sigma_nodes = np.array([0.02, 0.03, 0.04, 0.05, 0.06, 0.10, 0.20, 0.30, 0.50], float)
    # np.interp returns the first/last table value outside the node range,
    # which is the clamping this function is specified to perform.
    return float(np.interp(g, g_nodes, sigma_nodes))

def ruwe_inflation_factor(ruwe, parallax_mas, gmag=np.nan):
    """Return the multiplicative parallax-error inflation for a RUWE value.

    The factor is 1 for a non-finite RUWE or RUWE <= 1.  Above that it
    rises from 1 towards a ceiling
    ``E_F0 * (parallax/10)**E_BETA * (sigma_eta/0.1)**E_GAMMA`` following
    ``1 - exp(-E_ALPHA * (RUWE - 1))``.  The result is never below 1.
    """
    if (not np.isfinite(ruwe)) or ruwe <= 1.0:
        return 1.0

    # Substitute neutral defaults for unusable inputs.
    plx = parallax_mas
    if not (np.isfinite(plx) and plx != 0):
        plx = 0.1
    noise = sigma_eta_mas_from_G(gmag)
    if not (np.isfinite(noise) and noise > 0):
        noise = 0.1

    plx_term = (max(plx, 0.01) / 10.0) ** E_BETA
    noise_term = (noise / 0.1) ** E_GAMMA
    ceiling = E_F0 * plx_term * noise_term
    if not (np.isfinite(ceiling) and ceiling >= 1.0):
        ceiling = 1.0

    saturation = 1.0 - np.exp(-E_ALPHA * (ruwe - 1.0))
    factor = 1.0 + (ceiling - 1.0) * saturation
    if not np.isfinite(factor):
        return 1.0
    return float(max(factor, 1.0))

# ================= Zero-point =================
def apply_zpt(parallax_mas, catalog="DR3", gmag=None, bp=None, rp=None, nu_eff=None):
    """Return the parallax with the zero-point offset removed.

    The parallax is returned unchanged when USE_ZPT is off or the parallax
    is missing/non-finite.  When G, BP and RP are all finite a piecewise
    zero-point in G with a linear (BP-RP) colour term is used; otherwise the
    per-catalogue constant (ZPT_CONST_EDR3 / ZPT_CONST_DR3) applies.

    ``gmag``, ``bp`` and ``rp`` may be ``None`` (the defaults) or NaN.
    ``nu_eff`` is accepted for interface compatibility and is not used.
    """
    if not USE_ZPT or parallax_mas is None or not np.isfinite(parallax_mas):
        return parallax_mas

    def usable(value):
        # None is the documented default and must not reach np.isfinite.
        return (value is not None) and bool(np.isfinite(value))

    zpt = ZPT_CONST_EDR3 if (catalog.upper()=="EDR3") else ZPT_CONST_DR3
    if usable(gmag) and usable(bp) and usable(rp):
        color = bp - rp
        if gmag < 13:
            zpt = -0.010 + 0.002*(color-1.0)
        elif gmag < 17:
            zpt = -0.017 + 0.001*(color-1.0)
        else:
            zpt = -0.025 + 0.002*(color-1.0)
    return float(parallax_mas - zpt)

# ================= Sky & prior =================
def ecliptic_lat_deg(ra_deg, dec_deg):
    """Return the ecliptic latitude (degrees) of an ICRS position.

    The position is transformed to the BarycentricTrueEcliptic frame; NaN
    is returned if the transformation raises.
    """
    try:
        icrs = SkyCoord(ra=ra_deg*u.deg, dec=dec_deg*u.deg, frame='icrs')
        latitude = icrs.transform_to(BarycentricTrueEcliptic()).lat
        return float(latitude.to(u.deg).value)
    except Exception:
        return np.nan

def L_of_sky(ra_deg, dec_deg, bj_p50=None):
    """Return the EDSD prior length scale (pc) to use for a star.

    With a finite Bailer-Jones distance the scale is ``0.8 * bj_p50``
    clipped to [300, 2500] pc (a mild adaptation to tame fat tails);
    otherwise L_EDSD_PC_DEFAULT.  ``ra_deg`` and ``dec_deg`` are accepted
    for a future sky-dependent map and are not used at present.
    """
    # ``bj_p50`` defaults to None, which np.isfinite cannot handle.
    if bj_p50 is not None and np.isfinite(bj_p50):
        return float(np.clip(0.8*bj_p50, 300.0, 2500.0))
    return L_EDSD_PC_DEFAULT

# ================= Posterior (EDSD) =================
def _adaptive_r_bounds(parallax_mas, sigma_mas, L_pc, n_sigma=8.0):
    """Choose the distance interval (pc) on which the posterior is gridded.

    For a positive parallax with a valid error the interval starts from
    ``[r0/k, r0*k]`` with ``r0 = 1000/parallax`` and
    ``k = max(5, 8*sigma/parallax)``, limited to [1, 100000] pc.  It is then
    intersected with the distances whose true parallax lies within
    ``n_sigma`` sigma of the measurement; outside that band the Gaussian
    likelihood is below exp(-n_sigma**2 / 2) of its peak, so the posterior
    mass lost is negligible.  Without this tightening a high-S/N parallax
    leaves the whole posterior inside one or two grid cells.

    Without a usable positive parallax the interval is ``[1, 5*L_pc]``
    (capped at 100000 pc).  The upper bound is forced above the lower one.
    """
    if np.isfinite(parallax_mas) and np.isfinite(sigma_mas) and sigma_mas>0 and parallax_mas>0:
        r0 = 1000.0 / parallax_mas
        k  = max(5.0, 8.0 * (sigma_mas / max(parallax_mas, 1e-6)))
        rmin = max(1.0, r0 / k)
        rmax = min(100000.0, r0 * k)
        # Largest plausible parallax -> smallest plausible distance.
        rmin = max(rmin, 1000.0 / (parallax_mas + n_sigma * sigma_mas))
        # Smallest plausible parallax bounds the distance only if positive.
        plx_low = parallax_mas - n_sigma * sigma_mas
        if plx_low > 0:
            rmax = min(rmax, 1000.0 / plx_low)
    else:
        rmin = 1.0
        rmax = min(100000.0, 5.0*L_pc)
    if rmax <= rmin: rmax = rmin * 1.5
    return rmin, rmax

def distance_posterior_quantiles(parallax_mas, sigma_mas, L_pc, qs=(0.16,0.50,0.84), N=5000):
    """Return distance quantiles (pc) of the EDSD posterior.

    The posterior is the product of a Gaussian parallax likelihood and the
    prior ``r**2 * exp(-r / L_pc)``, evaluated on ``N`` equally spaced
    distances between the bounds from _adaptive_r_bounds().  Quantiles are
    read off the cumulative sum by linear interpolation.

    Args:
        parallax_mas: zero-point corrected parallax in mas (may be negative).
        sigma_mas: total parallax uncertainty in mas; must be > 0.
        L_pc: prior length scale in pc; must be > 0.
        qs: quantile levels in (0, 1).
        N: number of grid nodes.

    Returns:
        A list of floats, one per entry in ``qs``; all NaN when an input is
        non-finite or ``sigma_mas``/``L_pc`` is not positive.
    """
    if not (np.isfinite(parallax_mas) and np.isfinite(sigma_mas) and sigma_mas>0 and np.isfinite(L_pc) and L_pc>0):
        return [np.nan for _ in qs]
    rmin, rmax = _adaptive_r_bounds(parallax_mas, sigma_mas, L_pc)
    r = np.linspace(rmin, rmax, int(N))
    pi_true = 1000.0 / r
    log_like  = -0.5*((parallax_mas - pi_true)/sigma_mas)**2
    log_prior = 2*np.log(r) - r/L_pc
    log_post  = log_like + log_prior
    # Subtract the maximum before exponentiating to avoid underflow to 0.
    m = np.nanmax(log_post)
    post = np.exp(log_post - m)
    cdf = np.cumsum(post)
    cdf /= cdf[-1]
    return [float(np.interp(q, cdf, r)) for q in qs]

# ================= SMART CATALOG CHOICE =================
def choose_parallax_candidate(edr3_plx, edr3_eplx, dr3_plx, dr3_eplx, ruwe, gmag, bp, rp,
                              dup, nper, solved, g_context):
    """Pick the EDR3 or DR3 parallax with the smaller penalised uncertainty.

    For each catalogue with a finite parallax the zero-point is removed,
    the error is inflated by the RUWE factor, floored at SIGMA_FLOOR_MAS
    and combined in quadrature with SIGMA_SYS_MAS.  The score is that total
    multiplied by penalties for a duplicated source, fewer than 8
    visibility periods and ``solved`` below 31.  EDR3 is listed first, so
    it wins ties.

    Returns:
        ``(catalog, corrected_parallax, sigma_total)`` or ``None`` when no
        catalogue has a usable parallax.  As a side effect ``g_context``
        receives ``ruwe_inflation_f`` and ``parallax_err_mas_raw`` of the
        winning candidate.
    """
    sources = (("EDR3", edr3_plx, edr3_eplx), ("DR3", dr3_plx, dr3_eplx))
    scored = []
    for catalog, plx, eplx in sources:
        if not np.isfinite(plx):
            continue
        if not (np.isfinite(eplx) or USE_ERR_FALLBACK):
            continue
        raw_sigma = eplx if np.isfinite(eplx) else DEFAULT_PLX_ERR_MAS
        corrected = apply_zpt(plx, catalog=catalog, gmag=gmag, bp=bp, rp=rp)
        inflation = ruwe_inflation_factor(ruwe, corrected, gmag)
        inflated = max(inflation * raw_sigma, SIGMA_FLOOR_MAS)
        total = float(np.hypot(inflated, SIGMA_SYS_MAS))

        weight = 1.0
        if dup:
            weight *= 1.5
        if np.isfinite(nper) and nper < 8:
            weight *= 1.3
        if np.isfinite(solved) and solved < 31:
            weight *= 1.3

        scored.append((total * weight, catalog, corrected, total, inflation, raw_sigma))

    if not scored:
        return None

    # Stable sort on the score alone keeps the listing order for ties.
    winner = sorted(scored, key=lambda item: item[0])[0]
    _, chosen_catalog, chosen_plx, chosen_sigma, chosen_f, chosen_raw = winner
    g_context["ruwe_inflation_f"] = chosen_f
    g_context["parallax_err_mas_raw"] = chosen_raw
    return chosen_catalog, chosen_plx, chosen_sigma

# ================= ADOPTION LOGIC =================
def decide_adopted(out):
    """Choose which distance to adopt for one object.

    Reads the new (EDSD) and old (Bailer-Jones) distances with their
    errors, the parallax SNR, the QC flags and the DR3 match separation
    from ``out``.

    Decision order:
      1. the new distance, when SNR >= SNR_MIN, QC passes (or is not
         required) and its relative error is <= MAX_REL_ERR_NEW;
      2. the old distance, when its relative error is <= MAX_REL_ERR_OLD;
      3. whichever finite estimate has the smaller relative error;
      4. nothing.

    Returns:
        ``(value, error, origin, quality_note)``.  ``origin`` is
        "EDSD_new", "BJ_old" or ``None``; ``quality_note`` is "ok" or a
        comma-separated, sorted list of issue tags.
    """
    def relative_error(err, val):
        usable = np.isfinite(err) and np.isfinite(val) and val > 0
        return (err / val) if usable else np.inf

    def joined(tags):
        return ",".join(sorted(set(tags)))

    new_val = out.get("distance_new_pc")
    new_err = out.get("distance_new_err_pc")
    new_rel = relative_error(new_err, new_val)

    old_val = out.get("distance_old_pc")
    old_err = out.get("distance_old_err_pc")
    old_rel = relative_error(old_err, old_val)

    snr = out.get("parallax_snr")
    sep = out.get("dr3_sep_arcsec", np.nan)

    # QC problems, keyed by the tag that is reported for each of them.
    qc_problems = {
        "high_RUWE": bool(out.get("ruwe_flag", False)),
        "low_Nper": bool(out.get("low_nper_flag", False)),
        "duplicated": bool(out.get("dup_flag", False)),
        "large_sep": (np.isfinite(sep) and sep > MAX_DR3_SEP_ARCSEC),
    }

    notes = []
    if not np.isfinite(snr) or snr < SNR_MIN:
        notes.append("low_SNR")
    notes.extend(tag for tag, bad in qc_problems.items() if bad)
    if not np.isfinite(new_rel):
        notes.append("new_no_CI")
    if not np.isfinite(old_rel):
        notes.append("old_no_CI")

    qc_ok = (not REQUIRE_GOOD_QC) or not any(qc_problems.values())

    new_usable = np.isfinite(new_val) and np.isfinite(new_err)
    old_usable = np.isfinite(old_val) and np.isfinite(old_err)

    # 1) the new estimate, if it clears every threshold
    if new_usable and np.isfinite(snr):
        if (snr >= SNR_MIN) and qc_ok and (new_rel <= MAX_REL_ERR_NEW):
            return new_val, new_err, "EDSD_new", (joined(notes) if notes else "ok")

    # 2) the old estimate, if precise enough
    if old_usable and (old_rel <= MAX_REL_ERR_OLD):
        notes.append("fallback_BJ")
        return old_val, old_err, "BJ_old", joined(notes)

    # 3) otherwise the more precise of whatever is available
    options = []
    if new_usable:
        options.append(("EDSD_new", new_val, new_err, new_rel))
    if old_usable:
        options.append(("BJ_old", old_val, old_err, old_rel))
    if options:
        origin, value, error, _ = sorted(options, key=lambda item: item[3])[0]
        notes.append("min_rel_err")
        return value, error, origin, joined(notes)

    # 4) nothing to adopt
    notes.append("no_adoptable_value")
    return np.nan, np.nan, None, joined(notes)

# ================= MAIN =================
# One output dict per input row; rows that fail early carry only the fields
# gathered up to that point plus a "status" explaining why.
rows = []
for idx, row in tqdm(list(df.iterrows()), total=len(df)):
    # Diagnostics are printed only for the first VERBOSE_FIRST_N index values.
    verbose = idx < VERBOSE_FIRST_N

    ra, dec, used_name, resolver = find_coords(row, verbose=verbose)
    out = {
        "name": row.get(NAME_COL) if NAME_COL and not pd.isna(row.get(NAME_COL)) else used_name,
        "resolved_as": used_name, "resolver": resolver,
        "ra_deg": ra, "dec_deg": dec,
    }

    # No position: nothing else can be queried for this object.
    if ra is None:
        out.update({"status":"no_coords"})
        rows.append(out); continue

    out["ecl_lat_deg"] = ecliptic_lat_deg(ra, dec)

    # BJ (old)
    bj = get_bj_distance(ra, dec, verbose=verbose)
    if bj is None:
        out.update({"status":"no_gaia_bj"})
        rows.append(out); continue

    bj_p50 = bj["distance"]
    bj_p16 = bj.get("p16", np.nan)
    bj_p84 = bj.get("p84", np.nan)
    # Symmetric error: half the p16-p84 width when both bounds exist,
    # otherwise whatever one-sided estimate get_bj_distance() produced.
    if np.isfinite(bj_p16) and np.isfinite(bj_p84):
        bj_err = 0.5*(bj_p84 - bj_p16)
    else:
        bj_err = bj.get("err_sym", np.nan)

    out.update({
        "status":"ok",
        "distance_old_pc": bj_p50,
        "distance_old_p16_pc": bj_p16,
        "distance_old_p84_pc": bj_p84,
        "distance_old_err_pc": bj_err,
        "bj_used_model": bj["used_model"],
        "gaia_source_id": bj["source_id"],
        "bj_flag": bj["flag"],
        "bj_ra_deg": bj.get("bj_ra_deg", np.nan),
        "bj_dec_deg": bj.get("bj_dec_deg", np.nan),
        "bj_sep_arcsec": bj.get("bj_sep_arcsec", np.nan),
    })

    # EDR3 parallax (optional)
    plx_edr, eplx_edr = (np.nan, np.nan)
    if bj["source_id"] is not None:
        plx_edr, eplx_edr = get_edr3_parallax(bj["source_id"], verbose=verbose)
    out.update({"edr3_parallax": plx_edr, "edr3_parallax_err": eplx_edr})

    # DR3 info
    dr3 = get_dr3_info(
        bj["source_id"], ra_deg=ra, dec_deg=dec,
        bj_ra_deg=bj.get("bj_ra_deg"), bj_dec_deg=bj.get("bj_dec_deg"),
        verbose=verbose
    )
    out.update(dr3)

    # QC flags
    # ruwe_flag / low_nper_flag / dup_flag are True when something is wrong;
    # solved_flag is True when astrometric_params_solved >= 31.
    # Missing (NaN) inputs never raise a flag.
    out["ruwe_flag"]     = (out["ruwe"] > 1.4) if np.isfinite(out["ruwe"]) else False
    out["low_nper_flag"] = (out["visibility_periods_used"] < 8) if np.isfinite(out["visibility_periods_used"]) else False
    out["solved_flag"]   = (out["astrometric_params_solved"] >= 31) if np.isfinite(out["astrometric_params_solved"]) else False
    out["dup_flag"]      = bool(out.get("duplicated_source", False))

    # choose parallax
    # ctx receives the RUWE inflation factor and raw error of the winner.
    ctx = {}
    choice = choose_parallax_candidate(
        out.get("edr3_parallax"), out.get("edr3_parallax_err"),
        out.get("dr3_parallax"),  out.get("dr3_parallax_err"),
        out.get("ruwe"), out.get("gmag"), out.get("bp"), out.get("rp"),
        out.get("duplicated_source"), out.get("visibility_periods_used"), out.get("astrometric_params_solved"),
        ctx
    )

    # No usable parallax in either catalogue: the new-distance fields are
    # filled with NaN and adoption can only fall back to Bailer-Jones.
    if choice is None:
        out.update({
            "posterior_catalog": None,
            "parallax_mas_raw": np.nan,
            "parallax_mas_corr": np.nan,
            "parallax_err_mas_raw": np.nan,
            "parallax_sigma_total_mas": np.nan,
            "ruwe_inflation_f": np.nan,
            "L_pc_used": L_of_sky(ra, dec, bj_p50=out.get("distance_old_pc")),
            "distance_new_p16_pc": np.nan,
            "distance_new_pc": np.nan,
            "distance_new_p84_pc": np.nan,
            "distance_new_err_pc": np.nan,
            "parallax_snr": np.nan
        })
        # adoption
        adopt_v, adopt_e, origin, note = decide_adopted(out)
        out["adopted_distance_pc"] = adopt_v
        out["adopted_distance_err_pc"] = adopt_e
        out["adopted_origin"] = origin
        out["quality_note"] = note

        rows.append(out); continue

    use_catalog, parallax_corr, sigma_total = choice
    # Raw (uncorrected) parallax of whichever catalogue was selected.
    parallax_mas_raw = out["edr3_parallax"] if use_catalog=="EDR3" else out["dr3_parallax"]

    # Prior scale (mildly adapted by BJ)
    L_pc = L_of_sky(ra, dec, bj_p50=out.get("distance_old_pc"))

    # New distance (quantiles)
    p16, p50, p84 = distance_posterior_quantiles(parallax_corr, sigma_total, L_pc=L_pc, qs=(0.16,0.50,0.84))
    new_err = 0.5*(p84 - p16) if (np.isfinite(p16) and np.isfinite(p84)) else np.nan
    # SNR uses the corrected parallax and the total (inflated) uncertainty.
    parallax_snr = (parallax_corr / sigma_total) if (np.isfinite(parallax_corr) and np.isfinite(sigma_total) and sigma_total>0) else np.nan

    out.update({
        "posterior_catalog": use_catalog,
        "parallax_mas_raw": parallax_mas_raw,
        "parallax_mas_corr": parallax_corr,
        "parallax_err_mas_raw": ctx.get("parallax_err_mas_raw"),
        "parallax_sigma_total_mas": sigma_total,
        "ruwe_inflation_f": ctx.get("ruwe_inflation_f"),
        "L_pc_used": L_pc,
        "distance_new_p16_pc": p16,
        "distance_new_pc": p50,
        "distance_new_p84_pc": p84,
        "distance_new_err_pc": new_err,
        "parallax_snr": parallax_snr
    })

    # Adoption
    adopt_v, adopt_e, origin, note = decide_adopted(out)
    out["adopted_distance_pc"] = adopt_v
    out["adopted_distance_err_pc"] = adopt_e
    out["adopted_origin"] = origin
    out["quality_note"] = note

    rows.append(out)

full = pd.DataFrame(rows)

# ---------- save full debug table ----------
# Preferred column order; columns that no row produced are dropped here.
cols_order = [
    "name","resolved_as","resolver","status",
    "ra_deg","dec_deg","ecl_lat_deg",
    "gaia_source_id",
    "bj_ra_deg","bj_dec_deg","bj_sep_arcsec",
    "distance_old_pc","distance_old_p16_pc","distance_old_p84_pc","distance_old_err_pc","bj_used_model","bj_flag",
    "posterior_catalog",
    "ruwe","ruwe_flag","gmag","bp","rp","nu_eff","pseudocolour",
    "dr3_parallax","dr3_parallax_err","dr3_ra_deg","dr3_dec_deg","dr3_sep_arcsec",
    "edr3_parallax","edr3_parallax_err",
    "parallax_mas_raw","parallax_mas_corr","parallax_err_mas_raw",
    "ruwe_inflation_f","parallax_sigma_total_mas","parallax_snr",
    "astrometric_sigma5d_max","visibility_periods_used","low_nper_flag",
    "astrometric_params_solved","solved_flag","duplicated_source","dup_flag",
    "L_pc_used",
    "distance_new_p16_pc","distance_new_pc","distance_new_p84_pc","distance_new_err_pc",
    "adopted_distance_pc","adopted_distance_err_pc","adopted_origin","quality_note"
]
cols_order = [c for c in cols_order if c in full.columns]
full = full[cols_order]

full.to_csv("gaia_full_with_ruwe.csv", index=False)
print("Saved → gaia_full_with_ruwe.csv")

# The compact views below use reindex() rather than plain column selection:
# if every object failed before a given stage, those columns do not exist in
# `full`, and reindex() writes them as empty columns instead of raising
# KeyError after all the network work has been done.

# ---------- save compact 3-col view ----------
out3 = full.reindex(columns=["name","distance_old_pc","distance_new_pc"])
out3.to_csv("gaia_distances_new.csv", index=False)
print("Saved → gaia_distances_new.csv")

# ---------- save compact with CIs ----------
out3ci = full.reindex(columns=[
    "name","distance_old_pc","distance_old_p16_pc","distance_old_p84_pc",
    "distance_new_p16_pc","distance_new_pc","distance_new_p84_pc"
])
out3ci.to_csv("gaia_distances_new_with_ci.csv", index=False)
print("Saved → gaia_distances_new_with_ci.csv")

# ---------- summary with ± errors & adoption ----------
summary = full.reindex(columns=[
    "name",
    "distance_old_pc","distance_old_err_pc",
    "distance_new_pc","distance_new_err_pc",
    "adopted_distance_pc","adopted_distance_err_pc","adopted_origin",
    "parallax_snr","quality_note"
])
# One name for both the file and the message, so they cannot drift apart
# (the file used to be written as gaia_distances_summary_1.csv while the
# message and the module docstring announced gaia_distances_summary.csv).
summary_path = "gaia_distances_summary.csv"
summary.to_csv(summary_path, index=False)
print(f"Saved → {summary_path}")

print(summary.head(12))