"""
Baseline BEACON workflow with RLJ Prior for Bayesian optimisation of molecular PES minima.

Designed for reuse across molecules by changing only:
1. The input .xyz structures.
2. The prior configuration block (No Prior, RLJ, or MACE).

Run metadata (e.g., run ID, log file, evaluation count) is handled automatically
via command-line arguments and logged during execution.
"""


# In [None]:
#!/usr/bin/env python3

from __future__ import annotations

import argparse, os, sys, logging, random
import numpy as np
from numpy.random import SeedSequence
from scipy.spatial.distance import pdist

from ase.io import read
from ase.calculators.gaussian import Gaussian
from ase.data import covalent_radii, atomic_numbers
from ase.parallel import world

import gpatom.beacon.beacon as beacon
from gpatom.beacon.str_gen import RandomBranch
from gpatom.gpfp.fingerprint import FingerPrint
from gpatom.gpfp.gp import GaussianProcess
from gpatom.gpfp.prior import CalculatorPrior, RepulsivePotential
from gpatom.gpfp.atoms_gp_interface import Model
from gpatom.gpfp.avishart_hpfitter import HpFitterConstantRatioParallel, calculate_all_distances
from hpfitter.pdistributions.normal import Normal_prior

# Command-line interface
# Arguments are parsed at module level, so `args` is available to every
# section below (logging, RNG seeding, calculator label, BEACON setup).
parser = argparse.ArgumentParser(description="BEACON HF run for NH3 with clean GP/Prior logging.")
# Integer run identifier: stamped on every log line, mixed into the RNG seed
# and embedded in the Gaussian calculator label.
parser.add_argument("--run-id", type=int, default=0)
# Path of the file written by the "beacon_run" logger.
parser.add_argument("--log-file", type=str, default="beacon_run.log")
# Forwarded unchanged to beacon.BEACON(ndft=...).
parser.add_argument("--ndft", type=int, default=100, help="number of outer HF evaluations")
# When set, the GP/prior CSV also records the first `ninit` HF calls.
parser.add_argument("--include-initial", action="store_true",
                    help="also log the initial ninit seed HF calls (else only the ndft outer steps)")
args = parser.parse_args()

# Logging setup
# Every record is stamped with the run id (injected by the LoggerAdapter at
# the end of this section).
#
# Only MPI rank 0 owns real handlers. If every rank opened the same file with
# mode="w", each process would write from its own file offset and overwrite
# the output of the other ranks. Ranks > 0 get a NullHandler so that their
# log calls are accepted and silently dropped.
from logging import LoggerAdapter, getLogger, FileHandler, StreamHandler, Formatter, NullHandler
if world.rank == 0 and os.path.exists(args.log_file):
    os.remove(args.log_file)  # start every run from a fresh file
world.barrier()  # keep the ranks in step before any of them starts logging

_LOG_FORMAT = "%(asctime)s [run=%(run_id)s] %(message)s"

_base = getLogger("beacon_run")
_base.handlers.clear()  # avoid duplicated handlers if this section is re-run
_base.setLevel(logging.INFO)

if world.rank == 0:
    # File output plus a mirror on stdout, both with the same layout.
    fh = FileHandler(args.log_file, mode="w")
    fh.setFormatter(Formatter(_LOG_FORMAT))
    _base.addHandler(fh)

    sh = StreamHandler(sys.stdout)
    sh.setFormatter(Formatter(_LOG_FORMAT))
    _base.addHandler(sh)
else:
    # `fh` stays defined on every rank; it simply discards records here.
    fh = NullHandler()
    _base.addHandler(fh)

log = LoggerAdapter(_base, extra={"run_id": args.run_id})

# Random number generator setup
# One 32-bit seed is derived from a fixed base entropy and the
# (run id, MPI rank) pair, so repeating a run id on the same rank reproduces
# the same seed.
# NOTE(review): the seed differs between ranks — confirm that BEACON does not
# rely on identical random streams on all ranks.
BASE_SEED = 77381
ss = SeedSequence(entropy=BASE_SEED, spawn_key=(int(args.run_id), int(world.rank)))
rng_seed = int(ss.generate_state(1, dtype=np.uint32)[0])

# NOTE(review): PYTHONHASHSEED is read at interpreter start-up, so setting it
# here presumably only affects child Python processes, not this one — confirm
# that this is the intent.
os.environ["PYTHONHASHSEED"] = str(rng_seed)
# Seed the global NumPy and stdlib generators as well as the explicit
# RandomState that is handed to the structure generators below.
np.random.seed(rng_seed)
random.seed(rng_seed)
rng = np.random.RandomState(rng_seed)

if world.rank == 0:
    log.info(f"[INFO] Run-ID: {args.run_id} | MPI rank: {world.rank} | RNG seed: {rng_seed}")

# Load seed structures (NH3)
# To reuse the workflow for another molecule, point these two reads at its
# .xyz files.
atoms0 = read("odd_ammonia_tshape.xyz")
atoms1 = read("NH3_triang.xyz")
for at in (atoms0, atoms1):
    # Centre each molecule with a vacuum padding of 6.0 (ASE length units,
    # presumably Angstrom) and switch periodic boundary conditions off.
    at.center(vacuum=6.0)
    at.pbc = False
# Both structures are handed to BEACON as `init_atoms`.
structure = [atoms0, atoms1]

# Hooked Gaussian to log each true HF call
# Module-level callback slot; main() installs the recorder here before the run.
OUTER_LOG_CB = None  # set in main()


class HookedGaussian(Gaussian):
    """Gaussian calculator that notifies ``OUTER_LOG_CB`` before each calculation."""

    def calculate(self, atoms=None, properties=("energy",), system_changes=None):
        """Invoke the logging callback (if any), then run the real calculation.

        The callback is skipped when none is installed or when ``atoms`` is
        None. Arguments are forwarded unchanged to ``Gaussian.calculate`` and
        its return value is passed back to the caller.
        """
        callback = OUTER_LOG_CB
        if callback is not None and atoms is not None:
            callback(atoms)
        return super().calculate(
            atoms=atoms,
            properties=properties,
            system_changes=system_changes,
        )

# Gaussian calculator factory
def make_gaussian_calc() -> Gaussian:
    """Return a new HF/3-21G ``HookedGaussian`` calculator.

    The label embeds the run id and the MPI rank, so every (run, rank) pair
    uses its own label.
    """
    run_label = f"nh3_run{args.run_id}_{world.rank}"
    settings = dict(method="HF", basis="3-21G", scf="xqc", label=run_label)
    return HookedGaussian(**settings)


# BEACON receives the factory itself, not a calculator instance.
calc_if = beacon.CalculatorInterface(calc=make_gaussian_calc)

# Structure generators
# Two RandomBranch generators, one per seed structure; both draw from the
# shared, seeded `rng` created above.
sgen = RandomBranch(atoms0, rng=rng)
rgen = RandomBranch(atoms1, rng=rng)
# Combines the two generators into the initial-structure source for BEACON.
# NOTE(review): the meaning of nrattle / rattlestrength / nbest / realfmax is
# defined by gpatom's InitatomsGenerator — check its documentation before
# tuning these values.
initgen = beacon.InitatomsGenerator(
    sgen=sgen, rgen=rgen,
    nrattle=3, rattlestrength=0.25, nbest=3, realfmax=0.05, rng=rng
)

# GP and prior setup
# This is the block to swap when switching between the No-Prior, RLJ and MACE
# variants.
# Cut-off for the repulsive prior: 90 % of the summed H and N covalent radii.
# The element pair is hard-coded for NH3 and must be adapted for other molecules.
rp_rc = 0.90 * (covalent_radii[atomic_numbers['H']] + covalent_radii[atomic_numbers['N']])
# RLJ prior: a repulsive 'LJ' potential wrapped as the GP prior.
prior = CalculatorPrior(RepulsivePotential('LJ', prefactor=1.0, rc=rp_rc, constant=0.0))

# Fingerprint settings, the GP with its initial hyperparameters, and the model
# that ties the two together. use_forces=True is passed to the GP.
fp = FingerPrint(fp_args=dict(r_cutoff=8.0, a_cutoff=4.0, aweight=1.0), calc_strain=False)
gp = GaussianProcess(prior=prior, hp=dict(scale=50.0, weight=1.0, noise=1e-3), use_forces=True)
model = Model(gp=gp, fp=fp)

# Hyperparameter fitter
# Normal prior for the GP length scale, centred on the initial value
# (scale=50.0) used when the GP is constructed.
# NOTE(review): confirm whether the fitter applies this prior to the scale
# itself or to a transformed (e.g. logarithmic) value.
scale_prior = Normal_prior(mu=50.0, std=2.0)
def prior_method(E, _):
    """Return the mean of the energies ``E`` as a float, or 0.0 when ``E`` is empty.

    The second positional argument is accepted for interface compatibility
    and ignored.
    """
    if len(E):
        return float(np.mean(E))
    return 0.0
def scale_bounds_method(fps):
    """Return ``[lower, upper]`` bounds for the GP length scale.

    With fewer than two fingerprints no distance can be formed, so the fixed
    fallback ``[0.1, 15.0]`` is returned. Otherwise the lower bound is the
    median of the second array returned by ``calculate_all_distances``
    (presumably the nearest-neighbour distances) and the upper bound is ten
    times the maximum of the first array.
    """
    if len(fps) >= 2:
        all_dists, nearest = calculate_all_distances(fps)
        lower = float(np.median(nearest))
        upper = 10.0 * float(np.max(all_dists))
        return [lower, upper]
    return [0.1, 15.0]

class SafeHP(HpFitterConstantRatioParallel):
    """Hyperparameter fitter that logs the fitted values after every fit.

    Fitting itself is delegated to ``HpFitterConstantRatioParallel``. After
    each fit the current ``scale`` / ``weight`` / ``noise`` values are logged,
    and the smallest stored training target is reported on a best-effort basis.
    """

    def __init__(self, *args, **kwargs):
        # Set before the parent constructor runs so the attribute always exists.
        self._model_ref = None
        super().__init__(*args, **kwargs)

    def set_model(self, model_obj):
        """Remember the model object this fitter is used with."""
        self._model_ref = model_obj

    @staticmethod
    def _best_target(gp_obj):
        """Return the smallest stored training target, or None if there is none.

        The attributes ``y`` and ``Y`` are tried in that order, and the first
        one holding at least one value wins. The checks are explicit because
        truth-testing a multi-element NumPy array (``a or b``) raises
        ValueError.
        """
        for name in ("y", "Y"):
            targets = getattr(gp_obj, name, None)
            if targets is None:
                continue
            values = np.asarray(targets).reshape(-1)
            if values.size:
                return float(values.min())
        return None

    def fit(self, gp_obj):
        """Fit the hyperparameters of ``gp_obj`` and log the outcome.

        Returns whatever the parent ``fit`` returns, so wrapping the fitter
        does not hide a result from the caller.
        """
        # Counted before fitting; `fps` is preferred, `X` is the fallback.
        n_train = len(getattr(gp_obj, "fps", getattr(gp_obj, "X", [])))
        result = super().fit(gp_obj)
        log.info(f"[HP] Train pts: {n_train:3d} | scale={gp_obj.hp['scale']:5.1f} "
                 f"| weight={gp_obj.hp['weight']:5.1f} | noise={gp_obj.hp['noise']:.2e}")
        # Reporting the best value is purely diagnostic: never let it abort a
        # fit, but do not hide the reason when it fails either.
        # NOTE(review): if the stored targets also contain force components
        # (use_forces=True), this minimum is not a pure energy — confirm.
        try:
            best = self._best_target(gp_obj)
        except Exception as exc:
            log.warning(f"[BEST] Could not determine best-so-far energy: {exc}")
        else:
            if best is not None:
                log.info(f"[BEST] Best-so-far energy: {best:.8f}")
        return result


hp_optimizer = SafeHP(scale_prior=scale_prior,
                      prior_method=prior_method,
                      scale_bounds_method=scale_bounds_method)

# Covalent bond checker
class CovalentChecker:
    """Reject structures in which any atom pair is closer than a covalent limit.

    The limit for a pair is ``factor`` times the sum of the two covalent radii.
    """

    def __init__(self, factor: float = 0.90):
        self.factor = factor

    def _too_close(self, atoms):
        """Return whether any interatomic distance is below its pair limit."""
        numbers = atoms.get_atomic_numbers()
        distances = pdist(atoms.get_positions())
        # Upper-triangle index pairs (i < j) in row-major order, which is the
        # order in which pdist returns its condensed distances.
        first, second = np.triu_indices(len(numbers), k=1)
        limits = np.array([
            self.factor * (covalent_radii[numbers[a]] + covalent_radii[numbers[b]])
            for a, b in zip(first, second)
        ])
        return (distances < limits).any()

    def check(self, atoms, *_, **__):
        """Return ``(accepted, reason)``; extra arguments are ignored."""
        if self._too_close(atoms):
            return (False, "short bond")
        return (True, "accepted")

    # Alias: the same function is exposed under both names.
    accept = check

    def check_fingerprint_distances(self, *_, **__):
        """Always accept; no fingerprint-distance filtering is applied."""
        return True


checker = CovalentChecker()

# BEACON setup
# Surrogate-surface optimiser and acquisition function used by the BO loop.
surropt = beacon.SurrogateOptimizer(fmax=0.05, relax_steps=100)
acq = beacon.LowerBound(kappa=2.0)
# Wire together everything built above. `ndft` comes from the command line;
# `ninit` equals the number of seed structures passed as `init_atoms`.
bo = beacon.BEACON(
    calculator=calc_if, initatomsgen=initgen, init_atoms=structure,
    model=model, ninit=2, ndft=args.ndft, nsur=3, surropt=surropt, acq=acq,
    checker=checker, hp_optimizer=hp_optimizer
)
# Hand the model held by BEACON to the hyperparameter fitter.
hp_optimizer.set_model(bo.model)

# Main execution
def main():
    """Run the BEACON optimisation and save per-HF-call GP/prior diagnostics.

    A callback is installed in ``OUTER_LOG_CB`` so that every call of
    ``HookedGaussian.calculate`` records the GP prediction, its variance and
    the prior energy for the structure about to be evaluated. After the run,
    rank 0 writes these rows to ``gp_prior_vs_step.csv`` and logs a summary.
    """
    global OUTER_LOG_CB

    # One entry per recorded HF calculator call (outer step).
    gp_vals, gp_vars, prior_vals, steps = [], [], [], []

    ninit = int(getattr(bo, "ninit", 0))
    hf_calls_seen = 0

    def _first_float(x, default=0.0):
        """Return the first element of ``x`` as a float, or ``default`` on failure."""
        try:
            arr = np.asarray(x, dtype=float)
            return float(arr if arr.ndim == 0 else arr.flat[0])
        except Exception:
            return float(default)

    def _best_target(gp_obj):
        """Return the smallest stored training target, or None if there is none.

        ``y`` is tried before ``Y``. The checks are explicit because
        truth-testing a multi-element NumPy array (``a or b``) raises
        ValueError.
        """
        for name in ("y", "Y"):
            targets = getattr(gp_obj, name, None)
            if targets is None:
                continue
            values = np.asarray(targets).reshape(-1)
            if values.size:
                return float(values.min())
        return None

    def record_true_hf(atoms, *a, **k):
        """Record the GP prediction and prior energy for one HF call."""
        nonlocal hf_calls_seen

        fobj = fp.get(atoms)
        # Some predict() signatures take the variance flag positionally.
        try:
            f, V = model.gp.predict(fobj, get_variance=True)
        except TypeError:
            f, V = model.gp.predict(fobj, True)

        E = _first_float(f, 0.0)
        varE = max(_first_float(V, 0.0), 0.0)  # clamp negative variances to zero

        # Prior energy
        pcalc = model.gp.prior.calculator
        pcalc.calculate(atoms)
        prior_e = float(pcalc.get_potential_energy(atoms))

        # Log one row per HF call; skip initial seeds unless specified
        include = args.include_initial or (hf_calls_seen >= ninit)
        if include and world.rank == 0:
            steps.append(len(steps))
            gp_vals.append(E)
            gp_vars.append(varE)
            prior_vals.append(prior_e)

        # Counted on every rank, whether or not the row was recorded.
        hf_calls_seen += 1

    # Attach callback for HookedGaussian
    OUTER_LOG_CB = record_true_hf

    # Run optimisation
    bo.run()

    # Everything below is reporting and happens on rank 0 only.
    if world.rank != 0:
        return
    log.info("BEACON finished.")

    # Save CSV output
    try:
        n = len(steps)
        arr = np.column_stack([
            np.asarray(steps, dtype=int),
            np.asarray(gp_vals, dtype=float),
            np.asarray(gp_vars, dtype=float),
            np.asarray(prior_vals, dtype=float),
        ])
        np.savetxt("gp_prior_vs_step.csv", arr, delimiter=",",
                   header="Step,GP,GP_var,Prior", comments="", fmt="%.10g")
        log.info(f"[GPPRIOR] Saved {n} rows to gp_prior_vs_step.csv "
                 f"(include_initial={bool(args.include_initial)}, ninit={ninit})")
    except Exception as e:
        # A failed save must not hide the summary below, but it is an error.
        log.error(f"[GPPRIOR] Save failed: {e}")

    # Summary
    # NOTE(review): if the stored targets also contain force components
    # (use_forces=True), the reported minimum is not a pure energy — confirm.
    try:
        best = _best_target(model.gp)
        if best is not None:
            num_train = len(getattr(model, "data", []))
            log.info(f"[SUMMARY] Training points: {num_train} | Final best: {best:.8f}")
    except Exception as exc:
        # Best-effort summary: report why it is missing instead of failing silently.
        log.warning(f"[SUMMARY] Could not compute final summary: {exc}")


if __name__ == "__main__":
    main()


""" No-prior block (for NH3, PF3Cl2, C3H6O) -- replaces the section "GP and prior setup". """

# In [11]:
# No Prior
# Variant of the "GP and prior setup" section: the RLJ/Calculator prior is
# removed and a constant zero prior is used instead. Everything else
# (fingerprint, GP hyperparameters, model) matches the RLJ configuration.
# ConstantPrior is not part of the script's top-level imports, so it is
# imported here to keep this block self-contained.
from gpatom.gpfp.prior import ConstantPrior

prior = ConstantPrior(constant=0.0)

fp = FingerPrint(fp_args=dict(r_cutoff=8.0, a_cutoff=4.0, aweight=1.0), calc_strain=False)
gp = GaussianProcess(prior=prior, hp=dict(scale=50.0, weight=1.0, noise=1e-3), use_forces=True)
model = Model(gp=gp, fp=fp)



# In [None]:
# MACE Prior
# Variant of the "GP and prior setup" section: a MACE-OFF calculator is
# wrapped as the GP prior. Everything else (fingerprint, GP hyperparameters,
# model) matches the RLJ configuration.
# mace_off is not part of the script's top-level imports, so it is imported
# here to keep this block self-contained.
from mace.calculators import mace_off

# The model size/path can be overridden through the MACE_MODEL environment
# variable; "large" is the default.
mace_model = os.environ.get("MACE_MODEL", "large")
mace_calc = mace_off(model=mace_model, device="cpu")
prior = CalculatorPrior(mace_calc)

fp = FingerPrint(fp_args=dict(r_cutoff=8.0, a_cutoff=4.0, aweight=1.0),
                 calc_strain=False)
gp = GaussianProcess(prior=prior,
                     hp=dict(scale=50.0, weight=1.0, noise=1e-3),
                     use_forces=True)
model = Model(gp=gp, fp=fp)
