In [None]:
"""
BEACON post-analysis and plotting template.

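Applies to any molecular system. For each molecule, update the reference minima
files in the `minima.xyz/` directory (optionally also `EXPECTED_MINIMA` and
`MINIMA_COLOR_MAP`, which set the legend order and colours of those minima).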
All other logic—parsing, normalisation, success curves, and plotting—stays the same.
"""

from __future__ import annotations
import re
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib

# Use the non-interactive Agg backend unless an IPython/Jupyter shell is active.
# A missing IPython package (or any failure while probing it) also means Agg.
try:
    from IPython import get_ipython
    _HAS_IPYTHON_SHELL = get_ipython() is not None
except Exception:
    _HAS_IPYTHON_SHELL = False

if not _HAS_IPYTHON_SHELL:
    matplotlib.use("Agg")

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.ticker import PercentFormatter, MultipleLocator

# Export settings: single-panel size [inch], raster resolution, mosaic toggle
PANEL_W_IN = 6.5
PANEL_H_IN = 4.6
DPI = 600
SAVE_MOSAIC = False

# Line widths and grid transparency used by the panels
MEAN_LW = 3.0     # ensemble-mean curve
MEDIAN_LW = 2.6   # highlighted "median trajectory" run
RUN_LW = 1.9      # every other run
GRID_ALPHA = 0.28

import matplotlib as mpl

# Publication rc style: TrueType (type 42) fonts in PDF/PS, SVG text kept as
# text, larger fonts, no top/right spines.
_PUBLICATION_RC = {
    "savefig.dpi": DPI,
    "figure.dpi": 150,
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
    "svg.fonttype": "none",
    "font.size": 13,
    "axes.titlesize": 14,
    "axes.labelsize": 13,
    "xtick.labelsize": 12,
    "ytick.labelsize": 12,
    "legend.fontsize": 11,
    "axes.spines.top": False,
    "axes.spines.right": False,
    "axes.linewidth": 1.2,
    "xtick.major.size": 4,
    "ytick.major.size": 4,
}
mpl.rcParams.update(_PUBLICATION_RC)

# Paths and labels
PARENT = Path(".")           # directory holding the run_* folders (run_1, run_2, ...)
LOGNAME = "beacon_run.log"   # log file name inside each run_* folder

_PROJECT_ROOT_PREFIX = "project_root_"


def infer_labels(base: Path) -> tuple[str, str]:
    """Derive a display label and an output-file prefix from the directory *base*.

    The molecule name comes from the nearest directory (``base`` itself
    included) whose name starts with ``project_root_``; only that leading
    prefix is removed. The last component of ``base`` is the sub-label.

    Returns:
        ``(system_label, out_prefix)``. For ``.../project_root_NH3/set_A`` this is
        ``("NH3 — set_A", "NH3_set_A_beacon")``. Without a ``project_root_*``
        ancestor it is ``("set_A", "set_A_beacon")``.
    """
    cwd = base.resolve()
    prior = cwd.name
    mol = None
    for p in (cwd, *cwd.parents):
        if p.name.startswith(_PROJECT_ROOT_PREFIX):
            # Slice off only the leading prefix; str.replace would also delete
            # any later "project_root_" that is part of the molecule name.
            mol = p.name[len(_PROJECT_ROOT_PREFIX):]
            break
    system = f"{mol} — {prior}" if mol else prior
    prefix = f"{mol}_{prior}_beacon" if mol else f"{prior}_beacon"
    return system, prefix


SYSTEM_LABEL, OUT_PREFIX = infer_labels(PARENT)

# ---- Config parameters -----------------------------------------------------
INCLUDE_INIT_MIN = True   # prepend the best initial-structure energy as step 0
EARLY_STEPS = 15          # width of the grey "early phase" band in panel A
LEGEND_FONT_SIZE = 11

# ---- Success target --------------------------------------------------------
BEACON_BAND_MEV = (10.0, 200.0)
TARGET_MEV = 10.0         # a run counts as successful once ΔE <= TARGET_MEV
ADD_SUCCESS_CI = True     # draw a 95% Wilson score band around the success curve
EXPECTED = 100            # nominal number of BEACON steps (used for panel A's x-range)

# ---- Known minima (XYZ files) ----------------------------------------------
SHOW_MINIMA_LINES = True  # intended to toggle the dashed known-minima lines (panels A/B)
MINIMA_DIR = PARENT.joinpath("minima.xyz")   # molecule-specific minima xyz files
MINIMA_GLOB = "*.xyz"

# Legend order: these labels come first, matched case-insensitively to file stems
EXPECTED_MINIMA = ["NH3_min1"]
MINIMA_COLOR_MAP = {"NH3_min1": "#17becf"}
FALLBACK_MINIMA_COLORS = [
    "#e377c2", "#8c564b", "#2ca02c", "#d62728", "#9467bd", "#ff7f0e",
]

# Draw minima dashed line at 0
# NOTE(review): these two are not referenced by the visible plotting code.
Y_EPS_SYMLOG_BASE = 0.0
Y_EPS_LINEAR_BASE = 0.0

# ---- Log-line regexes ------------------------------------------------------
# Signed decimal with optional fraction and exponent, e.g. -1520.36 or 1.2e-3.
_FLOAT_RE = r"([+-]?\d+(?:\.\d+)?(?:[eE][+-]?\d+)?)"

rx_hdr = re.compile(r"\[INFO\]\s+Run-ID:\s+(?P<rid>\d+).+RNG seed:\s+(?P<seed>\d+)")
rx_init = re.compile(r"Init structure\s+\d+\s+Energy:\s+" + _FLOAT_RE)
rx_best = re.compile(r"\[BEST\]\s+Best-so-far energy:\s+" + _FLOAT_RE)

# Hartree -> eV conversion factor
HARTREE2EV = 27.211386024367243

# Energy tags recognised in XYZ comment lines:
#   Properties=... energy=-1520.360084833535 ...   (unitless)
#   Properties=... energy=-1520.36 eV ...
#   E = -55.87 Hartree
energy_prop_pat = re.compile(r"\benergy\s*=\s*([-\d.+Ee]+)\s*(eV)?\b", re.IGNORECASE)
energy_hartree_pat = re.compile(r"\bE\s*=\s*([-\d.+Ee]+)\s*Hartree\b", re.IGNORECASE)

# Helpers
def parse_log(path: Path):
    """Parse one BEACON log into a best-so-far energy trace.

    Scans *path* (decoded as UTF-8, undecodable bytes dropped) line by line for:
      * the ``[INFO] Run-ID: <id> ... RNG seed: <seed>`` header (last match wins);
      * ``Init structure <i> Energy: <E>`` lines (initial-structure energies);
      * ``[BEST] Best-so-far energy: <E>`` lines, numbered 1, 2, ... in order.
    When INCLUDE_INIT_MIN is true and initial energies were found, the lowest
    of them is prepended to the series as step 0.

    If no header supplies a run ID, the ID falls back to ``<N>`` taken from the
    parent directory name ``run_<N>``. For any other directory name it stays None.

    Args:
        path: log file, e.g. ``run_<N>/beacon_run.log``.

    Returns:
        ``dict(run=<int | None>, seed=<int | None>, series=[(step, E), ...])``,
        or None when the file does not exist or contains neither initial nor
        best-so-far energies.
    """
    if not path.exists():
        return None
    run_id = seed = None
    init_E, best_series, it = [], [], 0
    for line in path.read_text(encoding="utf-8", errors="ignore").splitlines():
        m = rx_hdr.search(line)
        if m:
            run_id = int(m.group("rid"))
            seed = int(m.group("seed"))
        m = rx_init.search(line)
        if m:
            init_E.append(float(m.group(1)))
        m = rx_best.search(line)
        if m:
            it += 1
            best_series.append((it, float(m.group(1))))
    if not best_series and not init_E:
        return None
    if run_id is None:
        # Header missing or unparsable: use the directory index so that several
        # such runs do not all collapse onto one ``None`` column downstream.
        m = re.fullmatch(r"run_(\d+)", path.parent.name)
        if m:
            run_id = int(m.group(1))
    if INCLUDE_INIT_MIN and init_E:
        best_series = [(0, float(np.min(init_E)))] + best_series
    return dict(run=run_id, seed=seed, series=best_series)

def read_minima_energies_xyz(folder: Path, pattern: str = "*.xyz", max_lines: int | None = 200):
    """
    Return list[(stem_label, energy_eV)] for the XYZ files in *folder*.

    Files matching *pattern* are visited in sorted order. For each file, the
    first ``max_lines + 1`` lines are searched for energy tags (201 lines by
    default). Pass ``max_lines=None`` to scan the whole file, e.g. for long
    multi-frame files. The lowest energy found is reported for that file.
    Files without any recognised tag are skipped with a warning. A missing
    *folder* yields an empty list.

    Supports:
      - 'Properties=... energy=<val>'   (unitless or 'eV' -> treated as eV)
      - 'E = <val> Hartree'             (converted to eV)
    A line carrying a parsable Hartree tag is not also checked for 'energy='.
    """
    out = []
    if not folder.exists():
        return out
    for f in sorted(folder.glob(pattern)):
        E_list = []
        # Explicit UTF-8 (as parse_log uses) instead of the platform default.
        with open(f, "r", encoding="utf-8", errors="replace") as fh:
            for i, line in enumerate(fh):
                if max_lines is not None and i > max_lines:
                    break
                mH = energy_hartree_pat.search(line)
                if mH:
                    try:
                        E_list.append(float(mH.group(1)) * HARTREE2EV)
                        continue
                    except ValueError:
                        pass  # malformed number: still try the energy= tag below
                mE = energy_prop_pat.search(line)
                if mE:
                    try:
                        E_list.append(float(mE.group(1)))  # assume eV if unitless
                    except ValueError:
                        pass  # captured token such as "E" or "1.2.3" is not a float
        if E_list:
            out.append((f.stem, float(np.min(E_list))))
        else:
            print(f"[WARN] No energy found in {f.name} (expected 'energy=...' or 'E = ... Hartree').")
    return out

def assign_minima_colors(minima_pairs):
    """Order the known minima for plotting and attach a colour to each.

    Labels from EXPECTED_MINIMA come first, in that order. Each is matched
    case-insensitively against the stems in *minima_pairs* (the first match
    wins) and coloured from MINIMA_COLOR_MAP, or black if unmapped. Labels
    without a matching stem are skipped. All remaining pairs follow, sorted
    case-insensitively by stem and coloured by cycling through
    FALLBACK_MINIMA_COLORS.

    Args:
        minima_pairs: list of ``(stem, energy_eV)`` tuples.

    Returns:
        list of dicts with keys ``label``, ``E_eV`` and ``color``.
    """
    first_by_key = {}
    for stem, energy in minima_pairs:
        first_by_key.setdefault(stem.lower(), (stem, energy))

    ordered = []
    claimed = set()
    for label in EXPECTED_MINIMA:
        hit = first_by_key.get(label.lower())
        if hit is None:
            continue
        ordered.append({"label": label, "E_eV": hit[1],
                        "color": MINIMA_COLOR_MAP.get(label, "#000000")})
        claimed.add(hit[0])

    leftovers = sorted(
        ((stem, energy) for stem, energy in minima_pairs if stem not in claimed),
        key=lambda pair: pair[0].lower(),
    )
    palette = FALLBACK_MINIMA_COLORS
    ordered.extend(
        {"label": stem, "E_eV": energy, "color": palette[k % len(palette)]}
        for k, (stem, energy) in enumerate(leftovers)
    )
    return ordered

# Collect runs: one record per run_* directory whose log parses.
runs_meta = []
_run_dir_names = []  # parallel to runs_meta; only used in error messages
for d in sorted(PARENT.glob("run_*")):
    rec = parse_log(d / LOGNAME)
    if rec:
        runs_meta.append(rec)
        _run_dir_names.append(d.name)
if not runs_meta:
    raise RuntimeError(f"No runs found (expected run_*/{LOGNAME}).")

# pivot() below needs one unique, non-None run ID per run. Fail here with a
# clear message instead of pandas' "duplicate entries" error or a later
# TypeError when None is sorted against ints.
_missing_ids = [name for name, rec in zip(_run_dir_names, runs_meta) if rec["run"] is None]
if _missing_ids:
    raise RuntimeError(
        f"No Run-ID found for {_missing_ids} (check the '[INFO] Run-ID:' header in {LOGNAME})."
    )
_dirs_by_id = {}
for name, rec in zip(_run_dir_names, runs_meta):
    _dirs_by_id.setdefault(rec["run"], []).append(name)
_duplicate_ids = {rid: names for rid, names in _dirs_by_id.items() if len(names) > 1}
if _duplicate_ids:
    raise RuntimeError(f"Duplicate Run-IDs across run directories: {_duplicate_ids}")

# Long table (run, iter, best_E) -> wide table: one column per run, one row per step
rows = [
    {"run": rec["run"], "iter": it, "best_E": e}
    for rec in runs_meta
    for it, e in rec["series"]
]
df_long = pd.DataFrame(rows).sort_values(["run", "iter"])
piv = df_long.pivot(index="iter", columns="run", values="best_E").sort_index()

# Log-derived minimum
log_min_eV = float(np.nanmin(piv.values))

# True global min from the FP pipeline: 'energy=' tags, ignoring 'free_energy='
energy_pat_fp = re.compile(r"(?<!free_)energy=([-\d.+Ee]+)\b")


def _run_label(name: str):
    """Return ``N`` for names ending in ``run_<N>`` (``run_3``, ``./run_3``), else *name*."""
    m = re.search(r"run_(\d+)$", name)
    return int(m.group(1)) if m else name


def parse_global_min_info(parent: Path):
    """Read ``parent/GLOBAL_MIN.info`` (``key=value`` lines: run, step, energy_eV).

    Surrounding whitespace on each line and around keys/values is ignored.

    Returns:
        ``dict(run_label, step, energy)``, or None when the file is missing or
        lacks a ``run=`` line or an integer ``step=`` line. ``run_label`` is an
        int for ``run_<N>`` names, otherwise the raw string. ``energy`` is
        None without an ``energy_eV=`` line and NaN when its value is not a
        number.
    """
    path = parent / "GLOBAL_MIN.info"
    if not path.exists():
        return None
    run_name = step = energy = None
    for raw in path.read_text(encoding="utf-8", errors="ignore").splitlines():
        key, sep, value = raw.strip().partition("=")
        if not sep:
            continue
        key, value = key.strip(), value.strip()
        if key == "run":
            run_name = value
        elif key == "step":
            try:
                step = int(value)
            except ValueError:
                pass
        elif key == "energy_eV":
            try:
                energy = float(value)
            except ValueError:
                energy = np.nan
    if run_name is None or step is None:
        return None
    return dict(run_label=_run_label(run_name), step=step, energy=energy)


def scan_exact100_min(parent: Path):
    """Find the lowest ``energy=`` value in ``parent/run_*/structures_dft.exact100.xyz``.

    Only the first matching tag per line is read. ``step`` is the 0-based index
    of that line among all energy-tagged lines of its file. Unparsable values
    are kept as NaN so indices stay aligned.

    Returns:
        ``dict(run_label, step, energy)`` for the overall minimum, or None when
        no run has a finite energy.
    """
    best = None
    for rd in sorted(parent.glob("run_*")):
        xyz = rd / "structures_dft.exact100.xyz"
        if not xyz.exists():
            continue
        energies = []
        with open(xyz, "r", encoding="utf-8", errors="replace") as fh:
            for line in fh:
                m = energy_pat_fp.search(line)
                if m:
                    try:
                        energies.append(float(m.group(1)))
                    except ValueError:
                        energies.append(np.nan)
        if not energies:
            continue
        arr = np.asarray(energies, dtype=float)
        if not np.isfinite(arr).any():
            continue
        j = int(np.nanargmin(arr))
        e = float(arr[j])
        if best is None or e < best["energy"]:
            best = dict(run_label=_run_label(rd.name), step=j, energy=e)
    return best

def _energy_or_nan(info):
    """Return ``info['energy']`` as a float, or NaN when *info* or its energy is missing."""
    if not info or info.get("energy") is None:
        return np.nan
    return float(info["energy"])


# FP-pipeline minimum: GLOBAL_MIN.info first. Fall back to scanning the exact100
# XYZ files when the info file is absent OR present without a usable energy.
min_true = parse_global_min_info(PARENT)
if not np.isfinite(_energy_or_nan(min_true)):
    min_true = scan_exact100_min(PARENT) or min_true
fp_min_eV = _energy_or_nan(min_true)

# minima.xyz
raw_minima = read_minima_energies_xyz(MINIMA_DIR, MINIMA_GLOB)
minima_ordered = assign_minima_colors(raw_minima)
minima_dir_min_eV = float(np.min([E for _, E in raw_minima])) if raw_minima else np.nan

# ΔE reference: the LOWEST finite energy among minima.xyz, the FP pipeline and
# the logs. Because the log minimum is a candidate, ΔE >= 0 for every trace.
_ref_sources = [("minima.xyz", minima_dir_min_eV), ("FP min", fp_min_eV), ("log min", log_min_eV)]
candidates = [E for _, E in _ref_sources if np.isfinite(E)]
if not candidates:
    raise RuntimeError("No valid reference energy found (minima.xyz, FP pipeline, or logs).")
REF_E = float(np.min(candidates))
_ref_source = next((name for name, E in _ref_sources if np.isfinite(E) and REF_E == E), "log min")

# Diagnostics
print("[CHECK] Reference candidates (eV): "
      f"minima_dir={minima_dir_min_eV if np.isfinite(minima_dir_min_eV) else 'nan'}, "
      f"fp_min={fp_min_eV if np.isfinite(fp_min_eV) else 'nan'}, "
      f"log_min={log_min_eV:.6f}")
print(f"[INFO] ΔE reference (REF_E): {REF_E:.6f} eV ({_ref_source})")

# REF_E itself can never exceed the log minimum (it is a min over a set that
# contains it). The useful check is whether the runs went BELOW the known
# minima, i.e. minima.xyz / the FP pipeline do not hold the true global minimum.
_known_min_eV = [E for E in (minima_dir_min_eV, fp_min_eV) if np.isfinite(E)]
if _known_min_eV and min(_known_min_eV) > log_min_eV + 1e-4:
    print(f"[WARN] Known-minimum reference ({min(_known_min_eV):.6f}) is higher than log min "
          f"({log_min_eV:.6f}); ΔE is measured from the log min. Check sources.")

# ΔE (meV) of each run's best-so-far energy relative to REF_E; negatives clipped to 0.
delta_mev = ((piv - REF_E) * 1000.0).clip(lower=0)

# Per-step ensemble statistics (NaNs from shorter runs are skipped)
mean_mev    = delta_mev.mean(axis=1)
median_mev  = delta_mev.median(axis=1)
q25_mev     = delta_mev.quantile(0.25, axis=1)
q75_mev     = delta_mev.quantile(0.75, axis=1)

run_ids = sorted(list(delta_mev.columns))
n_runs  = len(run_ids)
run_list_str = ", ".join(map(str, run_ids))

# "Median trajectory" run: closest to the per-step median curve, scored by the
# MEAN absolute deviation over the steps the run actually has. A NaN-skipping
# sum would favour runs with fewer steps simply because they have fewer terms.
abs_dev = delta_mev.sub(median_mev, axis=0).abs()
median_run_id = int(abs_dev.mean(axis=0).idxmin())

# Known minima ΔE lines (vs REF_E). Minima within 1e-9 meV of REF_E are flagged
# as zero and drawn at y = 0.
minima_lines_mev = []
zero_count = 0
for m in minima_ordered:
    d_mev = max((m["E_eV"] - REF_E) * 1000.0, 0.0)
    is_zero = d_mev <= 1e-9
    zero_count += int(is_zero)
    y_plot = 0.0 if is_zero else d_mev
    minima_lines_mev.append({**m, "d_mev": d_mev, "y_sym": y_plot, "y_lin": y_plot, "is_zero": is_zero})

# Style helpers
def _run_palette(n: int):
    """Return an RGBA array (one row per colour) for *n* runs.

    Up to 10 runs: the 10 tab10 colours in order. Up to 20 runs: tab20.
    More runs: evenly spaced samples of the continuous 'turbo' map.
    (Sampling tab10 with linspace for n > 10 gave neighbouring runs the same
    colour, e.g. runs 0 and 1 both got colour 0.)
    """
    if n <= 10:
        return plt.cm.tab10(np.arange(10))
    if n <= 20:
        return plt.cm.tab20(np.arange(n))
    return plt.cm.turbo(np.linspace(0.0, 1.0, n))


colors = _run_palette(n_runs)
eps = 1e-12  # floor for the IQR band edges


def panel_header(ax, letter: str, title: str):
    """Replace the axes title with a bold panel letter and a plain title, left-aligned above the axes."""
    ax.set_title("")
    y = 1.02  # axes-fraction height, just above the top of the axes
    ax.text(0.00, y, letter, transform=ax.transAxes,
            fontsize=14, fontweight="bold", va="bottom", ha="left")
    ax.text(0.055, y, title, transform=ax.transAxes,
            fontsize=14, va="bottom", ha="left")

# Success prerequisites: a run counts as successful from the first step at which
# its ΔE is <= TARGET_MEV, and stays successful afterwards (cumulative max).
hit_target = delta_mev <= TARGET_MEV
succ_cum = hit_target.cummax(axis=0)
n_success = succ_cum.sum(axis=1)
succ_frac = n_success / n_runs


def _wilson_band(successes, trials, z=1.96):
    """Wilson score interval for a binomial proportion, clipped to [0, 1]; returns (lower, upper)."""
    p = successes / trials
    denom = 1.0 + (z**2) / trials
    spread = z * np.sqrt((p * (1 - p) / trials) + (z**2) / (4 * trials**2)) / denom
    mid = (p + (z**2) / (2 * trials)) / denom
    return np.clip(mid - spread, 0, 1), np.clip(mid + spread, 0, 1)


if ADD_SUCCESS_CI:
    lower, upper = _wilson_band(n_success.values.astype(float), float(n_runs))
else:
    # No band requested: collapse it onto the success curve itself.
    lower = succ_frac.values.copy()
    upper = succ_frac.values.copy()

# PANEL A (ΔE meV, symlog)
figA, axA = plt.subplots(figsize=(PANEL_W_IN, PANEL_H_IN), constrained_layout=True)
# One line per run; the "median trajectory" run gets a thicker line.
for j, rid in enumerate(run_ids):
    lw = MEDIAN_LW if rid == median_run_id else RUN_LW
    color = colors[j % len(colors)]
    axA.plot(delta_mev.index, delta_mev[rid], lw=lw, alpha=0.9, color=color, label=f"run {rid}")
axA.plot(mean_mev.index, mean_mev.values, color="black", lw=MEAN_LW, zorder=5, label="mean")
axA.fill_between(mean_mev.index, np.maximum(q25_mev, eps), np.maximum(q75_mev, eps),
                 color="#F58518", alpha=0.18, label="IQR (25–75%)")
# Grey band over the early phase (up to EARLY_STEPS)
axA.axvspan(delta_mev.index.min(), min(EARLY_STEPS, delta_mev.index.max()), color="gray", alpha=0.10, lw=0)
axA.set_yscale("symlog", linthresh=1e-3)
axA.axhline(0.0, ls=":", lw=1.4, color="0.6")

# Known minima: dashed line plus a right-aligned label at each minimum's ΔE
minima_handles_A = []
if SHOW_MINIMA_LINES:
    for m in minima_lines_mev:
        y = m["y_sym"]; color = m["color"]
        axA.axhline(y, ls="--", lw=1.6, color=color, alpha=0.95)
        label_text = f"{m['label']}" if not m["is_zero"] else f"{m['label']} (≈0 meV)"
        axA.text(0.99, y, label_text, va="center", ha="right", fontsize=10, color=color,
                 transform=axA.get_yaxis_transform())
        minima_handles_A.append(Line2D([0],[0], ls="--", lw=1.8, color=color, label=m["label"]))

# Show at least EXPECTED steps, but never clip runs that went further. Tick
# spacing grows with the range so labels do not crowd (10/5 up to 120 steps).
_x_max_A = max(EXPECTED, int(delta_mev.index.max()))
_major_A = 10 * max(1, int(np.ceil(_x_max_A / 120)))
axA.set_xlim(0, _x_max_A)
axA.xaxis.set_major_locator(MultipleLocator(_major_A))
axA.xaxis.set_minor_locator(MultipleLocator(_major_A / 2))
axA.set_xlabel("BEACON step")
axA.set_ylabel("ΔE to global min [meV]")
axA.grid(True, which="both", alpha=GRID_ALPHA)
panel_header(axA, "A", "ΔE to global minimum [symlog]")

h_runs_A, _ = axA.get_legend_handles_labels()
axA.legend(handles=h_runs_A + minima_handles_A, loc="center left",
           bbox_to_anchor=(1.02, 0.5), frameon=False, ncol=1)

figA.savefig(f"{OUT_PREFIX}_panelA_delta_meV.pdf", bbox_inches="tight", pad_inches=0.02)
figA.savefig(f"{OUT_PREFIX}_panelA_delta_meV.png", dpi=DPI, bbox_inches="tight", pad_inches=0.02)

# PANEL B (ΔE meV, linear): ensemble mean + IQR only, no per-run lines
figB, axB = plt.subplots(figsize=(PANEL_W_IN, PANEL_H_IN), constrained_layout=True)
axB.plot(mean_mev.index, mean_mev.values, color="black", lw=MEAN_LW, label="mean")
axB.fill_between(mean_mev.index, np.maximum(q25_mev, eps), np.maximum(q75_mev, eps),
                 color="#F58518", alpha=0.18, label="IQR (25–75%)")
axB.axhline(0.0, ls=":", lw=1.4, color="0.6")

# Known minima: dashed line plus a right-aligned label (honours SHOW_MINIMA_LINES)
minima_handles_B = []
if SHOW_MINIMA_LINES:
    for m in minima_lines_mev:
        y = m["y_lin"]; color = m["color"]
        axB.axhline(y, ls="--", lw=1.6, color=color, alpha=0.95)
        label_text = f"{m['label']}" if not m["is_zero"] else f"{m['label']} (≈0 meV)"
        axB.text(0.99, y, label_text, va="center", ha="right", fontsize=10, color=color,
                 transform=axB.get_yaxis_transform())
        minima_handles_B.append(Line2D([0],[0], ls="--", lw=1.8, color=color, label=m["label"]))

axB.set_xlabel("BEACON step")
axB.set_ylabel("ΔE to global min [meV]")
axB.grid(True, alpha=GRID_ALPHA)
panel_header(axB, "B", "ΔE to global minimum [linear]")

h_runs_B, _ = axB.get_legend_handles_labels()
axB.legend(handles=h_runs_B + minima_handles_B, loc="center left",
           bbox_to_anchor=(1.02, 0.5), frameon=False)

figB.savefig(f"{OUT_PREFIX}_panelB_delta_meV_linear.pdf", bbox_inches="tight", pad_inches=0.02)
figB.savefig(f"{OUT_PREFIX}_panelB_delta_meV_linear.png", dpi=DPI, bbox_inches="tight", pad_inches=0.02)

# PANEL C: cumulative fraction of runs that have reached the ΔE target
figC, axC = plt.subplots(figsize=(PANEL_W_IN, PANEL_H_IN), constrained_layout=True)
target_text = f"ΔE ≤ {TARGET_MEV:g} meV"
if ADD_SUCCESS_CI:
    axC.fill_between(delta_mev.index, lower, upper, alpha=0.18, label="95% confidence band")
axC.plot(delta_mev.index, succ_frac.values, lw=MEAN_LW - 0.4, label=target_text)
axC.yaxis.set_major_formatter(PercentFormatter(1.0))
axC.set_ylabel("runs meeting target (cumulative, %)")
axC.set_xlabel("BEACON step")
axC.axhline(0.5, ls=":", color="0.35")


def _first_step_reaching(frac, level):
    """Return the first step label at which *frac* >= *level*, or None if it never does."""
    hits = np.flatnonzero(frac.values >= level)
    return int(frac.index[hits[0]]) if hits.size > 0 else None


median_step = _first_step_reaching(succ_frac, 0.5)   # half of the runs succeeded
step_all = _first_step_reaching(succ_frac, 1.0)      # every run succeeded

if median_step is not None:
    axC.axvline(median_step, ls=":", color="0.35")
if step_all is not None:
    axC.axvline(step_all, ls="--", color="0.25")
    axC.text(step_all, 0.02, f"100% by step {step_all}", rotation=90,
             va="bottom", ha="right", fontsize=11, color="0.25")

axC.set_ylim(0, 1.02)
axC.grid(True, alpha=0.3)
panel_header(axC, "C", f"Success rate (target: {target_text})")
axC.legend(loc="lower right", frameon=False)

for _ext, _extra in (("pdf", {}), ("png", {"dpi": DPI})):
    figC.savefig(f"{OUT_PREFIX}_panelC_success.{_ext}", bbox_inches="tight", pad_inches=0.02, **_extra)

# Console summary: system, runs, loaded minima and their ΔE, saved files
loaded_labels = [entry["label"] for entry in minima_ordered]
minima_mev = [
    (entry["label"], round(max((entry["E_eV"] - REF_E) * 1000.0, 0.0), 6))
    for entry in minima_ordered
]
_saved_panels = ", ".join(
    f"{OUT_PREFIX}_{stem}.(pdf/png)"
    for stem in ("panelA_delta_meV", "panelB_delta_meV_linear", "panelC_success")
)
for _msg in (
    f"System: {SYSTEM_LABEL}",
    f"Runs: {n_runs} (IDs: {run_list_str})",
    f"[INFO] Loaded minima files: {len(minima_ordered)} -> {loaded_labels}",
    f"[INFO] Minima ΔE (meV): {minima_mev}",
    f"Saved single panels: {_saved_panels}",
):
    print(_msg)
