# In [None]:
"""
BEACON median three-panel analysis (GP, length scale, and FP distance).

Applies to all molecular systems. Only the fingerprint minima references
(minima_fps.csv) need to be updated per molecule to match its minima
fingerprints. All other components—data parsing, GP processing, length-scale
extraction, and plotting—remain identical for reproducibility.
"""

from __future__ import annotations
import re
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator

# Configuration
# Folder that holds one sub-folder per run; defaults to the working directory.
PRIOR_DIR = Path.cwd()
# Name of the per-run log file; its presence marks a folder as a run folder.
LOGNAME = "beacon_run.log"
# When true, the lowest "Init structure" energy is prepended to each
# best-so-far series as iteration 0 (see parse_energy_log).
INCLUDE_INIT_MIN = True

# Run discovery and log parsing
def immediate_run_dirs(base: Path):
    """Return the direct sub-directories of *base* that contain LOGNAME.

    The result is sorted by path. Only immediate children are inspected;
    nested folders are not searched.

    The former ``run_*`` / ``seed_*`` glob fallback was removed: it applied
    the same "contains LOGNAME" test to a subset of the same children, so it
    could never find a folder the primary scan had missed.
    """
    return [
        p for p in sorted(base.iterdir())
        if p.is_dir() and (p / LOGNAME).exists()
    ]

# Collect the run folders under PRIOR_DIR. If there are none, stop with an
# error that lists what the directory actually contains.
run_dirs = immediate_run_dirs(PRIOR_DIR)
if not run_dirs:
    kids = [entry.name for entry in sorted(PRIOR_DIR.iterdir())]
    _no_runs_msg = (
        f"No runs detected under: {PRIOR_DIR}\n"
        f"Expected subfolders with '{LOGNAME}'.\n"
        f"Found: {kids}"
    )
    raise RuntimeError(_no_runs_msg)

# Show at most the first six folder names, with an ellipsis when truncated.
_preview = [d.name for d in run_dirs][:6]
_more = ' ...' if len(run_dirs) > 6 else ''
print(f"[FOUND] {len(run_dirs)} run folder(s): {_preview}{_more}")

# Log-line patterns: run header (run id + RNG seed), initial-structure
# energies, and best-so-far energies.
rx_hdr = re.compile(r"\[INFO\]\s+Run-ID:\s+(?P<rid>\d+).+RNG seed:\s+(?P<seed>\d+)")
rx_init = re.compile(r"Init structure\s+\d+\s+Energy:\s+([+-]?\d+(?:\.\d+)?(?:[eE][+-]?\d+)?)")
rx_best = re.compile(r"\[BEST\]\s+Best-so-far energy:\s+([+-]?\d+(?:\.\d+)?(?:[eE][+-]?\d+)?)")


def parse_energy_log(log_path: Path):
    """Extract the run header and best-so-far energy series from one log.

    Every line is tested against all three patterns. Each ``[BEST]`` line
    becomes one ``(iteration, energy)`` pair, with iterations numbered from 1
    in order of appearance. If a header appears more than once, the last one
    wins.

    Returns ``None`` when the log contains neither ``[BEST]`` nor
    "Init structure" lines. Otherwise returns a dict with keys ``run``,
    ``seed`` (both ``None`` if no header was found) and ``series``. When
    INCLUDE_INIT_MIN is set and initial energies exist, the series starts
    with ``(0, min(initial energies))``.
    """
    run_id = None
    seed = None
    init_energies = []
    best_energies = []

    text = log_path.read_text(encoding="utf-8", errors="ignore")
    for line in text.splitlines():
        header = rx_hdr.search(line)
        if header is not None:
            run_id, seed = int(header.group("rid")), int(header.group("seed"))
        init_hit = rx_init.search(line)
        if init_hit is not None:
            init_energies.append(float(init_hit.group(1)))
        best_hit = rx_best.search(line)
        if best_hit is not None:
            best_energies.append(float(best_hit.group(1)))

    if not (best_energies or init_energies):
        return None

    series = list(enumerate(best_energies, start=1))
    if INCLUDE_INIT_MIN and init_energies:
        series.insert(0, (0, float(np.min(init_energies))))
    return {"run": run_id, "seed": seed, "series": series}

# Parse every run log. Each record also remembers the folder it came from,
# so the median run can be mapped back to its directory without guessing
# from folder names.
runs_meta = []
for d in run_dirs:
    rec = parse_energy_log(d / LOGNAME)
    if not rec:
        continue
    if rec["run"] is None:
        # No Run-ID header in the log: label the run by its folder-name
        # suffix when numeric, otherwise by the full folder name.
        suf = d.name.split("_")[-1]
        rec["run"] = int(suf) if suf.isdigit() else d.name
    rec["dir"] = d
    runs_meta.append(rec)
if not runs_meta:
    raise RuntimeError(f"Logs found but unparsable: ensure '{LOGNAME}' has [INFO]/[BEST] lines.")

# Long table (one row per run and iteration) -> wide table (iter x run).
rows = [
    {"run": rec["run"], "iter": it, "best_E": e}
    for rec in runs_meta
    for it, e in rec["series"]
]
df_long = pd.DataFrame(rows).sort_values(["run", "iter"])
piv = df_long.pivot(index="iter", columns="run", values="best_E").sort_index()

# Energies relative to the lowest value seen in any run, in meV.
global_min_eV = float(np.nanmin(piv.values))
delta_mev = (piv - global_min_eV) * 1000.0
delta_mev = delta_mev.clip(lower=0)
median_mev = delta_mev.median(axis=1)
abs_dev = (delta_mev.sub(median_mev, axis=0)).abs()

# The "median run" is the run whose trajectory has the smallest summed
# absolute deviation from the per-iteration median.
# NOTE(review): sum() skips NaN, so a run with fewer logged iterations
# accumulates less deviation -- confirm all runs have the same length.
_median_label = abs_dev.sum(axis=0).idxmin()
# Run labels may be folder names (strings, see above); only numeric labels
# are normalised to a plain int.
median_run_id = _median_label if isinstance(_median_label, str) else int(_median_label)

prior_tag = PRIOR_DIR.name
parent_name = PRIOR_DIR.parent.name
mol = parent_name.replace("project_root_", "") if parent_name.startswith("project_root_") else parent_name
label = f"{mol} - {prior_tag} - run {median_run_id}"

# Use the folder whose log produced the median trajectory.
run_dir = next((rec["dir"] for rec in runs_meta if rec["run"] == median_run_id), None)
if run_dir is None:
    raise RuntimeError(f"Could not map median run {median_run_id!r} back to a run folder.")
print(f"[Energy] median trajectory = {run_dir.name} | global min = {global_min_eV:.6f} eV")

# Clean GP CSV
# Per-run input file and the name of the cleaned copy written beside it.
CSV_IN = "gp_prior_vs_step.csv"
CSV_OUT = "gp_prior_vs_step.cleaned.csv"
# STEP_EPS: how far a Step value may sit from an integer before a warning
# is printed. VAL_ATOL: absolute tolerance used when comparing consecutive
# rows for duplicated values.
STEP_EPS = 1e-9
VAL_ATOL = 1e-12
# Column names as found in the raw CSV -> names used by the rest of the script.
RENAME = {
    "GP": "E_gp_pred_eV",
    "GP_var": "Var_E_gp_pred_eV2",
    "Prior": "E_prior_eV",
    "Step": "Step",
}

def _load_structured(csv_path: Path):
    A = np.genfromtxt(csv_path, delimiter=",", names=True, dtype=float, encoding=None)
    if A.size == 0:
        raise SystemExit(f"[ERR] empty CSV: {csv_path}")
    if A.ndim == 0:
        A = np.array([tuple(A.tolist())], dtype=A.dtype)
    return A

def _relabel(A):
    """Rename the fields of structured array *A* according to RENAME.

    Fields without an entry in RENAME keep their name. If nothing would
    change, *A* itself is returned. Otherwise a new all-float structured
    array with the renamed fields is returned.
    """
    current = list(A.dtype.names or [])
    target = [RENAME.get(name, name) for name in current]
    if target == current:
        return A
    renamed = np.zeros(A.shape, dtype=[(name, float) for name in target])
    for src_name, dst_name in zip(current, target):
        renamed[dst_name] = A[src_name].astype(float)
    return renamed

def _ensure_cols(A):
    names = list(A.dtype.names or []); need=[]
    if "Var_E_gp_pred_eV2" not in names:
        need.append(("Var_E_gp_pred_eV2", float, np.nan))
    if "E_prior_eV" not in names:
        need.append(("E_prior_eV", float, np.nan))
    if not need:
        return A
    B = np.zeros(A.shape, dtype=A.dtype.descr + [(n,t) for n,t,_ in need])
    for n in A.dtype.names:
        B[n] = A[n]
    for n,_,val in need:
        B[n][:] = val
    return B

def _add_gp_value_column(A):
    if "GP_value_eV" in (A.dtype.names or []):
        return A
    if "E_gp_pred_eV" not in A.dtype.names or "E_prior_eV" not in A.dtype.names:
        return A
    gp_value = A["E_gp_pred_eV"].astype(float) + A["E_prior_eV"].astype(float)
    B = np.zeros(A.shape, dtype=A.dtype.descr + [("GP_value_eV", float)])
    for n in A.dtype.names:
        B[n] = A[n]
    B["GP_value_eV"] = gp_value
    return B

def _dedup_by_step(A, keep="last"):
    """Keep one row per integer Step value, ordered by Step.

    Step values are rounded to the nearest integer; a warning is printed if
    any value is further than STEP_EPS from an integer. With
    ``keep="last"`` the last row of each step survives; any other value
    keeps the first.

    Returns:
        ``(rows, n_removed)``: the surviving rows and how many were dropped.

    Raises:
        SystemExit: if the surviving rows still contain a repeated step.
    """
    raw_steps = A["Step"].astype(float)
    step_ix = np.rint(raw_steps).astype(np.int64)
    drift = np.abs(raw_steps - step_ix)
    if np.any(drift > STEP_EPS):
        print(f"[WARN] non-integer Step values (max off={drift.max():.3g}); rounding applied.")

    # Map each integer step to the row that represents it: later rows
    # overwrite earlier ones for "last", first writer wins otherwise.
    chosen = {}
    for row, key in enumerate(step_ix):
        if keep == "last" or key not in chosen:
            chosen[key] = row

    keep_idx = np.array([chosen[key] for key in sorted(chosen)], dtype=int)
    deduped = A[keep_idx]

    # Sanity check on the result.
    keys = np.rint(deduped["Step"]).astype(np.int64)
    if np.unique(keys).size != keys.size:
        raise SystemExit("[ERR] de-dup by Step failed.")
    return deduped, (len(A) - len(deduped))

def _drop_consecutive_value_dups(A, atol=VAL_ATOL):
    Egp = A["E_gp_pred_eV"].astype(float)
    Var = A["Var_E_gp_pred_eV2"].astype(float)
    Epri = A["E_prior_eV"].astype(float)
    keep = [0]
    for i in range(1, len(A)):
        same = (
            np.isclose(Egp[i], Egp[i-1], atol=atol) and
            np.isclose(Var[i], Var[i-1], atol=atol) and
            ((np.isnan(Epri[i]) and np.isnan(Epri[i-1])) or np.isclose(Epri[i], Epri[i-1], atol=atol))
        )
        if not same:
            keep.append(i)
    keep = np.array(keep, dtype=int)
    return A[keep], (len(A) - len(keep))

def _reindex_step(A):
    B = A.copy()
    B["Step"] = np.arange(len(B), dtype=float)
    return B

def _save_structured(path: Path, A):
    header = ",".join(A.dtype.names)
    np.savetxt(path, A, delimiter=",", header=header, comments="", fmt="%.10g")

# Clean the median run's GP CSV: rename columns, pad missing ones, add the
# combined GP value, de-duplicate, renumber steps, and write the result next
# to the input file.
gp_clean_csv = None
src = run_dir / CSV_IN
if not src.exists():
    print(f"[GP] WARNING: {CSV_IN} missing in {run_dir}")
else:
    A = _add_gp_value_column(_ensure_cols(_relabel(_load_structured(src))))
    n0 = len(A)

    # Report repeated steps before removing them.
    uniq0 = np.unique(np.rint(A["Step"]).astype(np.int64)).size
    if uniq0 < n0:
        print(f"[INFO] {run_dir.name}: {n0-uniq0} duplicate Step row(s) detected.")

    A1, rem1 = _dedup_by_step(A, keep="last")
    if rem1:
        print(f"[DE-DUP step] {run_dir.name}: removed {rem1} rows (kept last).")

    A2, rem2 = _drop_consecutive_value_dups(A1, atol=VAL_ATOL)
    if rem2:
        print(f"[DE-DUP vals]  {run_dir.name}: removed {rem2} consecutive duplicates.")

    A3 = _add_gp_value_column(_reindex_step(A2))
    out_path = run_dir / CSV_OUT
    _save_structured(out_path, A3)
    gp_clean_csv = out_path
    print(f"[CLEANED] {run_dir.name}: {n0} -> {len(A3)} rows written to {out_path.name}")

# Read the cleaned file back so the plot uses exactly what was written.
gp_df = None
if gp_clean_csv is not None and gp_clean_csv.exists():
    A = _load_structured(gp_clean_csv)
    gp_df = pd.DataFrame({name: A[name] for name in A.dtype.names})

# GP length scale
# Three log formats are recognised. A and B carry an explicit training-point
# count, which is used as the iteration; C is a bare "scale=" fallback whose
# iteration is its position among the run's matched lines.
rx_hp_A = re.compile(r"\[HP\]\s*Training points:\s*(\d+),\s*GP length scale:\s*([+\-]?\d+(?:\.\d+)?(?:[eE][+\-]?\d+)?)")
rx_hp_B = re.compile(r"\[HP\]\s*Train\s*pts:\s*(\d+)\s*\|\s*scale=\s*([+\-]?\d+(?:\.\d+)?(?:[eE][+\-]?\d+)?)")
rx_hp_C = re.compile(r"scale\s*=\s*([+\-]?\d+(?:\.\d+)?(?:[eE][+\-]?\d+)?)")

hp_rows = []
for d in run_dirs:
    # Label the run the same way the energy parsing does, so the rows can be
    # matched against median_run_id: Run-ID from the log header if present,
    # else the numeric folder suffix, else the full folder name.
    suf = d.name.split("_")[-1]
    fallback_rid = int(suf) if suf.isdigit() else d.name
    header_rid = None

    lp = d / LOGNAME
    run_rows = []
    for line in lp.read_text(encoding="utf-8", errors="ignore").splitlines():
        hdr = rx_hdr.search(line)
        if hdr:
            header_rid = int(hdr.group("rid"))
        m = rx_hp_A.search(line) or rx_hp_B.search(line)
        if m:
            run_rows.append({"iter": int(m.group(1)), "scale": float(m.group(2))})
            continue
        m2 = rx_hp_C.search(line)
        if m2:
            # Running count per run; avoids rescanning every collected row.
            run_rows.append({"iter": len(run_rows) + 1, "scale": float(m2.group(1))})

    rid = header_rid if header_rid is not None else fallback_rid
    hp_rows.extend({"run": rid, **row} for row in run_rows)

ls_df = None
if hp_rows:
    hp_long = pd.DataFrame(hp_rows).sort_values(["run", "iter"])
    ls_df = hp_long[hp_long["run"].astype(str) == str(median_run_id)][["iter", "scale"]].rename(columns={"iter": "Step"})
else:
    print("[HP] WARNING: no [HP] lines matched; Panel B will display a notice.")

# Fingerprint data
# Per-step fingerprints of the median run; fp_df stays None when the file is absent.
fp_csv = run_dir / "fp_per_run.csv"
fp_df = pd.read_csv(fp_csv) if fp_csv.exists() else None
if fp_df is None:
    print(f"[FP] WARNING: {fp_csv.name} missing; Panel C may show a notice.")

# Plotting
# Global matplotlib settings, grouped by purpose and applied in one update.
_output_rc = {
    "figure.dpi": 200,
    "savefig.dpi": 600,
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
    "svg.fonttype": "none",
}
_font_rc = {
    "font.size": 18,
    "axes.titlesize": 22,
    "axes.labelsize": 20,
    "xtick.labelsize": 18,
    "ytick.labelsize": 18,
    "legend.fontsize": 18,
}
_axes_rc = {
    "axes.spines.top": False,
    "axes.spines.right": False,
    "axes.linewidth": 1.8,
    "lines.linewidth": 2.6,
}
plt.rcParams.update({**_output_rc, **_font_rc, **_axes_rc})

# Line colours for the GP curve and the prior curve in panel A.
COLOR_GP = "tab:blue"
COLOR_PRIOR = "tab:green"

# Three stacked panels (A, B, C) laid out with constrained layout.
fig, axs = plt.subplots(nrows=3, ncols=1, figsize=(16, 18), constrained_layout=True)
fig.set_constrained_layout_pads(w_pad=0.3, wspace=0.3, h_pad=0.35, hspace=0.35)

def set_beacon_axis(ax, xmax=100, major_step=10, minor_step=5):
    """Apply the shared x-axis styling used by all three panels.

    Args:
        ax: the matplotlib axes to style.
        xmax: upper x-limit; the axis always starts at 0. The default of 100
            keeps the previous fixed range. Pass a larger value for longer
            runs, otherwise data beyond step 100 is cut off.
        major_step: spacing of major x ticks.
        minor_step: spacing of minor x ticks.
    """
    ax.set_xlim(0, xmax)
    ax.xaxis.set_major_locator(MultipleLocator(major_step))
    ax.xaxis.set_minor_locator(MultipleLocator(minor_step))
    ax.set_xlabel("BEACON step")
    ax.tick_params(axis="both", which="major", length=9, width=1.8)
    ax.tick_params(axis="both", which="minor", length=5, width=1.4)
    ax.grid(True, alpha=0.25)

# A) GP value + Prior (meV)
# GP value with a +/- one-sigma band on the left axis (eV); prior on a twin
# right axis (meV). A notice is shown instead when no GP table is available.
axA = axs[0]
axA.set_title("A  GP value and prior vs optimisation step (median run)")
if gp_df is None or "Step" not in gp_df.columns:
    axA.text(0.5, 0.5, "GP CSV not found", ha="center", va="center", transform=axA.transAxes)
    set_beacon_axis(axA)
else:
    axA2 = axA.twinx()
    gp_cols = set(gp_df.columns)
    step = np.rint(gp_df["Step"].values).astype(int)

    # Combined GP value: use the stored column if present, otherwise rebuild
    # it from prediction + prior (a missing prior counts as zero).
    if "GP_value_eV" in gp_cols:
        Emod = gp_df["GP_value_eV"].astype(float).values
    elif "E_prior_eV" in gp_cols:
        Emod = gp_df["E_gp_pred_eV"].astype(float).values + gp_df["E_prior_eV"].astype(float).values
    else:
        Emod = gp_df["E_gp_pred_eV"].astype(float).values + 0.0

    if "Var_E_gp_pred_eV2" in gp_cols:
        Var = gp_df["Var_E_gp_pred_eV2"].astype(float).values
    else:
        Var = np.full_like(Emod, np.nan)
    if "E_prior_eV" in gp_cols:
        Epri = gp_df["E_prior_eV"].astype(float).values
    else:
        Epri = np.full_like(Emod, np.nan)
    # Negative variances are clipped to zero before taking the square root.
    sigma = np.sqrt(np.clip(Var, 0.0, None))

    axA.plot(step, Emod, "-", color=COLOR_GP, label="GP value")
    if np.isfinite(sigma).any():
        axA.fill_between(step, Emod - sigma, Emod + sigma, alpha=0.25, color=COLOR_GP, label="GP ±1σ")
    if np.isfinite(Epri).any():
        axA2.plot(step, 1000.0 * Epri, "--", color=COLOR_PRIOR, label="Prior (meV)")

    # Colour each y-axis like the curve it belongs to.
    axA.set_ylabel("Energy (eV)", color=COLOR_GP)
    axA.tick_params(axis='y', labelcolor=COLOR_GP)
    axA2.set_ylabel("Prior (meV)", color=COLOR_PRIOR)
    axA2.tick_params(axis='y', labelcolor=COLOR_PRIOR)
    set_beacon_axis(axA)

    # One legend covering both the left and the twin axis.
    _handles_left, _labels_left = axA.get_legend_handles_labels()
    _handles_right, _labels_right = axA2.get_legend_handles_labels()
    lines = list(_handles_left) + list(_handles_right)
    labels = list(_labels_left) + list(_labels_right)
    if lines:
        leg = axA.legend(lines, labels, loc="upper right", framealpha=0.92, frameon=True)
        leg.get_frame().set_linewidth(1.4)

# B) GP length scale
# Length scale of the median run against step, or a notice when nothing was parsed.
axB = axs[1]
_have_scale = ls_df is not None and not ls_df.empty
if not _have_scale:
    axB.text(0.5, 0.5, "No [HP] length-scale data parsed", ha="center", va="center", transform=axB.transAxes)
else:
    axB.plot(ls_df["Step"], ls_df["scale"])
    axB.set_ylabel("GP length scale")
axB.set_title("B  GP length scale (median run)")
set_beacon_axis(axB)

# C) FP distance vs step (median run)
# L2 distance between the run's fingerprint at each step and every reference
# minimum, plus the distance to the nearest minimum.
axC = axs[2]
plotted_any = False

def _find_minima_csv():
    """Return the first existing minima_fps.csv among the known locations, else None."""
    candidates = (
        PRIOR_DIR / "minima_fps.csv",
        Path.cwd() / "minima_fps.csv",
        Path("/mnt/data/minima_fps.csv"),
    )
    return next((p for p in candidates if p.exists()), None)

if fp_df is not None:
    fp_cols = [c for c in fp_df.columns if c.startswith("FP_")]
    step_col = next((c for c in ("Step", "step", "STEP") if c in fp_df.columns), None)
    mcsv = _find_minima_csv() if (fp_cols and step_col is not None) else None
    if mcsv is not None:
        steps = np.rint(fp_df[step_col].to_numpy()).astype(int)
        FP_run = fp_df[fp_cols].to_numpy(dtype=float)

        mins = pd.read_csv(mcsv)
        mcols = [c for c in mins.columns if c.startswith("FP_")]
        # When both files use the same FP_* names, align the minima columns
        # to the run's column order so components are compared like-for-like.
        # Otherwise fall back to positional matching.
        if set(mcols) == set(fp_cols):
            mcols = list(fp_cols)

        # An empty minima table has nothing to compare against, so it is
        # treated like a missing file.
        if mcols and len(mcols) == len(fp_cols) and len(mins) > 0:
            M = mins[mcols].to_numpy(dtype=float)
            if "min_name" in mins.columns:
                names = mins["min_name"].astype(str).tolist()
            else:
                # No name column: label the minima by row number.
                names = [f"minimum {j}" for j in range(len(M))]

            # D[j, k] = distance between minimum j and the fingerprint at step k.
            D = np.linalg.norm(FP_run[None, :, :] - M[:, None, :], axis=2)
            d_nearest = D.min(axis=0)

            for j, nm in enumerate(names):
                axC.plot(steps, D[j], alpha=0.9, label=f"Distance to {nm}")
            axC.plot(steps, d_nearest, lw=3.0, ls="--", color="crimson", alpha=0.95,
                     label="Closest to any minimum")
            plotted_any = True

if not plotted_any:
    axC.text(0.5, 0.5, "No minima refs or FP distances available", ha="center", va="center", transform=axC.transAxes)

axC.set_ylabel("FP distance (L2)")
axC.set_title("C  FP distance vs step (median run)")
set_beacon_axis(axC)

# Legend for Panel C
# Only build the legend when the panel has labelled curves; otherwise
# matplotlib warns about missing labelled artists and adds an empty legend.
legC = None
if axC.get_legend_handles_labels()[0]:
    legC = axC.legend(loc="center left", bbox_to_anchor=(1.02, 0.5), framealpha=0.92, frameon=True)
    legC.get_frame().set_linewidth(1.4)

# Save
# Write the figure as PDF and as 600-dpi PNG into PRIOR_DIR, show it, and
# report the output paths.
fig.suptitle(label, y=0.995, fontsize=24)
out_base = f"{mol}_{prior_tag}_median_three_panel_gp_fp"
_pdf_path = PRIOR_DIR / f"{out_base}.pdf"
_png_path = PRIOR_DIR / f"{out_base}.png"
fig.savefig(_pdf_path, bbox_inches="tight", facecolor="white")
fig.savefig(_png_path, dpi=600, bbox_inches="tight", facecolor="white")
plt.show()

print(f"[DONE] median run: {median_run_id}\nSaved:\n  {_pdf_path}\n  {_png_path}")
