# In [None]:
import os
import sys
import math
import time
import pickle
import subprocess
from functools import partial

import numpy as np
import pandas as pd

from skopt import gp_minimize
from skopt.space import Real
import matplotlib.pyplot as plt

try:
    import sixtracklib as stl
except Exception as e:
    print("Error importing sixtracklib:", e)
    print("Please install sixtracklib (pip install sixtracklib) and re-run.")
    raise

try:
    from pymadx import Madx
except Exception:
    Madx = None
    
try:
    import pysixtrack as ps
except Exception:
    ps = None

try:
    import xtrack as xt
    import xobjects as xo
except Exception:
    xt = None
    xo = None


# --- Input lattice and output location --------------------------------------
# MAD-X file to convert and the name of the sequence defined inside it.
MADX_SEQ_PATH = "/mnt/data/SIS18RING.SEQ"
SEQUENCE_NAME = "sis18ring"
# Every artefact (pickles, CSV history, plots) goes here; the directory is
# created eagerly when this module is executed.
OUT_DIR = "bayes_runs_full_lattice"
os.makedirs(OUT_DIR, exist_ok=True)

# --- Tracking / extraction model ---------------------------------------------
# NOTE(review): no units are stated anywhere; lengths/apertures look like
# metres and emittances like m*rad -- confirm.
SIM_CONFIG = dict(
    n_particles=15000,             # particles launched per simulation
    rev_rf_ramp=1000,              # turns tracked before the KO kick is switched on
    revmax=10000,                  # turns tracked with the KO kick active
    bit_len=10,                    # turns per KO phase-keying bit
    len_sep=1.5,                   # septum length (drift from entrance to exit)
    aperx_sep_en=-0.055,           # septum wire x position at the entrance
    aperx_sep_ex=-0.0685,          # septum wire x position at the exit
    sep_wire_thick=100.0 / 1.0e6,  # septum wire thickness
    x_offset_sep=0.1,              # not read by the code visible in this file
    circ=216.72,                   # ring circumference
    beta_fin=0.9,                  # relativistic beta of the beam
    epsx_rms_fin=150.0e-6 / 4.0,   # horizontal rms emittance
    epsy_rms_fin=50.0e-6 / 4.0,    # vertical rms emittance
    dpmax=5.0e-4,                  # scales the longitudinal momentum spread
    betx0=12.0, bety0=8.0,         # Twiss beta at the start point
    alfx0=0.0, alfy0=0.0,          # Twiss alpha at the start point
    xco0=0.0, pxco0=0.0, yco0=0.0, pyco0=0.0,  # closed-orbit offsets
    ptco0=0.0,                     # longitudinal closed-orbit offset
    verbose=True,                  # print per-turn progress lines
)

# Revolution period: circumference / (c * beta), with c = 299792458 m/s.
t_rev = SIM_CONFIG["circ"] / (299792458.0 * SIM_CONFIG["beta_fin"])
# Total number of tracked turns (ramp turns + KO-active turns).
revtot = SIM_CONFIG["rev_rf_ramp"] + SIM_CONFIG["revmax"]

# --- Bayesian-optimisation search space -------------------------------------
# (name, lower bound, upper bound); objective_vector reads points in this order.
space = [
    Real(lower, upper, name=label)
    for label, lower, upper in (
        ("initial_tune_x", 4.29, 4.35),
        ("central_tune_x", 0.0, 5.0),
        ("essep_bump", 0.0, 0.03),
        ("mgsep_bump", 0.0, 0.03),
        ("phi_deg", 0.0, 90.0),
        ("k2la", 0.0, 0.1),
    )
]
OPT_CONFIG = dict(n_calls=40, n_initial_points=8, random_state=12345)

def convert_seq_to_sixtracklib_via_pysixtrack(madx_seq_path, sequence_name):
    """Convert a MAD-X sequence into SixTrackLib elements through pysixtrack.

    The MAD-X file is loaded with the module-level ``Madx`` class, turned into a
    ``pysixtrack.Line`` with ``Line.from_madx_sequence`` and then copied element
    by element into a ``sixtracklib`` element buffer.

    Parameters
    ----------
    madx_seq_path : str
        MAD-X file passed to ``madx.call``.
    sequence_name : str
        Name of the sequence inside the MAD-X session.

    Returns
    -------
    tuple
        ``(elems, meta)`` with ``meta == {"type": "pysixtrack", "idx_map": ...}``.
        Exactly one sixtracklib element is appended per source element, so
        ``idx_map`` is the identity mapping.

    Raises
    ------
    RuntimeError
        If a dependency is missing, the MAD-X -> pysixtrack conversion fails,
        or the pysixtrack -> sixtracklib mapping fails (original error chained).
    """
    if ps is None:
        raise RuntimeError("pysixtrack not installed (pip install pysixtrack)")

    if Madx is None:
        raise RuntimeError("pymadx not available; install with pip install pymadx (or use system madx)")

    # NOTE(review): Madx(stdout=False), .call() and .sequence[...] are the
    # cpymad API; confirm the module-level ``Madx`` import provides them.
    madx = Madx(stdout=False)
    print("Reading MAD-X sequence via pymadx:", madx_seq_path)
    madx.call(madx_seq_path)

    try:
        line = ps.Line.from_madx_sequence(madx.sequence[sequence_name])
    except Exception as err:
        raise RuntimeError(f"pysixtrack conversion failed: {err}") from err

    try:
        # ``Line.elements`` is pysixtrack's public element list; the private
        # ``_internal_list`` is only a fallback for Line-like objects lacking it.
        source_elements = getattr(line, "elements", None)
        if source_elements is None:
            source_elements = line._internal_list

        elems = stl.Elements()
        idx_map = {}
        for i, el in enumerate(source_elements):
            kind = type(el).__name__.lower()
            length = getattr(el, "length", 0.0)
            if kind.startswith("drift"):
                elems.append_drift(length)
            elif kind.startswith("multipole"):
                # Thin multipoles carry integrated strengths in ``knl``;
                # MAD-X quadrupoles/sextupoles presumably arrive in this form.
                # Without this branch they fell through to the drift fallback
                # and all focusing was lost. Skew terms (ksl) are not copied.
                elems.append_multipole(length, knl=[float(k) for k in el.knl])
            elif kind.startswith("quad"):
                elems.append_multipole(length, knl=[0.0, getattr(el, "k1", 0.0)])
            elif kind.startswith("sext"):
                elems.append_multipole(length, knl=[0.0, 0.0, getattr(el, "k2", 0.0)])
            elif hasattr(el, "length"):
                elems.append_drift(el.length)
            else:
                # Zero-strength placeholder keeps indices aligned with the line.
                elems.append_multipole(0.0, knl=[0.0])
            idx_map[i] = i
    except Exception as err:
        raise RuntimeError(f"Failed to map pysixtrack Line to sixtracklib elements: {err}") from err
    return elems, {"type": "pysixtrack", "idx_map": idx_map}

def convert_seq_to_sixtracklib_via_xtrack(madx_seq_path, sequence_name):
    """Convert a MAD-X sequence into SixTrackLib elements through xtrack/xsuite.

    The MAD-X file is loaded with the module-level ``Madx`` class and turned
    into an ``xtrack.Line``. ``Line.from_madx_sequence`` is tried first and
    ``Line.from_madx`` second. The line's elements are then copied into a
    ``sixtracklib`` element buffer.

    Parameters
    ----------
    madx_seq_path : str
        MAD-X file passed to ``madx.call``.
    sequence_name : str
        Name of the sequence inside the MAD-X session.

    Returns
    -------
    tuple
        ``(elems, meta)`` with ``meta == {"type": "xtrack", "idx_map": ...}``;
        one element is appended per source element (identity ``idx_map``).

    Raises
    ------
    RuntimeError
        If a dependency is missing or either conversion stage fails
        (original error chained).
    """
    if xt is None:
        raise RuntimeError("xtrack (xsuite) is not installed (pip install xtrack xsuite)")

    if Madx is None:
        raise RuntimeError("pymadx not installed; needed for MAD-X parsing")

    madx = Madx(stdout=False)
    madx.call(madx_seq_path)

    try:
        xline = xt.Line.from_madx_sequence(madx.sequence[sequence_name])
    except Exception as first_err:
        try:
            xline = xt.Line.from_madx(madx, sequence_name)
        except Exception as second_err:
            raise RuntimeError(f"xtrack conversion failed: {first_err} / {second_err}") from second_err

    try:
        elems = stl.Elements()
        idx_map = {}
        # ``Line.elements`` is xtrack's public element sequence (the former
        # ``xline._context._elements`` does not exist on a Line).
        for i, el in enumerate(xline.elements):
            typ = type(el).__name__.lower()
            length = getattr(el, "length", 0.0)
            if "drift" in typ:
                elems.append_drift(length)
            elif "quad" in typ:
                k1 = getattr(el, "k1", getattr(el, "K", 0.0))
                elems.append_multipole(length, knl=[0.0, k1])
            elif "sext" in typ:
                elems.append_multipole(length, knl=[0.0, 0.0, getattr(el, "k2", 0.0)])
            elif typ == "multipole":
                # Thin multipoles carry integrated strengths in ``knl``; they
                # previously fell through to the drift fallback. Skew terms
                # (ksl) are not copied.
                elems.append_multipole(length, knl=[float(k) for k in el.knl])
            else:
                elems.append_drift(length)
            idx_map[i] = i
    except Exception as err:
        raise RuntimeError(f"Failed mapping xtrack Line to sixtracklib elements: {err}") from err
    return elems, {"type": "xtrack", "idx_map": idx_map}

def convert_seq_to_sixtracklib_fallback(madx_seq_path):
    """Build a minimal placeholder lattice; it is NOT the real machine.

    ``madx_seq_path`` is accepted only for signature symmetry with the other
    converters and is not read. Element layout (index: element)::

        0  drift, length 1.0
        1  kicker: append_kick(0, 0), or a zero thin multipole if that fails
        2  zero multipole (knl of length 4)   -> meta["mp1"]
        3  zero multipole                      -> meta["mp2"]
        4  zero multipole                      -> meta["mp3"]
        5  drift, length 1.0

    Returns
    -------
    tuple
        ``(elems, meta)``. The indices stored in ``meta`` point at the elements
        above; the previous values (0..3) were off by one because of the
        leading drift. ``meta["using_append_kick"]`` tells which kicker form
        was built.
    """
    elems = stl.Elements()
    elems.append_drift(1.0)                        # index 0
    try:
        elems.append_kick(0.0, 0.0)                # index 1
        using_kick = True
    except Exception:
        # Best effort: builds without append_kick get a zero thin multipole.
        elems.append_multipole(0.0, knl=[0.0])     # index 1
        using_kick = False
    for _ in range(3):                             # indices 2, 3, 4
        elems.append_multipole(0.0, knl=[0.0, 0.0, 0.0, 0.0])
    elems.append_drift(1.0)                        # index 5
    return elems, {
        "type": "fallback",
        "kicker": 1,
        "mp1": 2,
        "mp2": 3,
        "mp3": 4,
        "using_append_kick": using_kick,
    }

def convert_seq_to_sixtracklib(madx_seq_path, sequence_name):
    """Convert a MAD-X sequence, trying each available converter in turn.

    Order: pysixtrack, then xtrack/xsuite, then the minimal skeleton fallback
    (which does not reproduce the machine).

    Returns
    -------
    tuple
        ``(elems, meta)`` from the first converter that succeeded.
        ``meta["conversion_errors"]`` is a list of ``(converter, message)``
        tuples for every converter that failed before it. When the skeleton
        fallback was used, it also contains ``("fallback", ...)``. This lets
        callers detect a placeholder lattice; the errors were previously
        collected and then discarded.
    """
    errors = []

    try:
        print("Attempting conversion via pysixtrack...")
        elems, meta = convert_seq_to_sixtracklib_via_pysixtrack(madx_seq_path, sequence_name)
    except Exception as exc:
        errors.append(("pysixtrack", str(exc)))
        print("pysixtrack conversion failed:", exc)
    else:
        print("Conversion via pysixtrack succeeded.")
        meta["conversion_errors"] = errors
        return elems, meta

    try:
        print("Attempting conversion via xtrack/xsuite...")
        elems, meta = convert_seq_to_sixtracklib_via_xtrack(madx_seq_path, sequence_name)
    except Exception as exc:
        errors.append(("xtrack", str(exc)))
        print("xtrack conversion failed:", exc)
    else:
        print("Conversion via xtrack succeeded.")
        meta["conversion_errors"] = errors
        return elems, meta

    print("Falling back to minimal skeleton lattice (non-element-accurate).")
    elems, meta = convert_seq_to_sixtracklib_fallback(madx_seq_path)
    errors.append(("fallback", "used minimal skeleton"))
    meta["conversion_errors"] = errors
    return elems, meta

def sample_initial_particles(n_particles, sim_cfg, seed=None):
    """Draw a Gaussian particle ensemble matched to the start-point Twiss.

    Parameters
    ----------
    n_particles : int
        Number of particles (rows) to draw; 0 gives an empty ``(0, 6)`` array.
    sim_cfg : dict
        Reads ``betx0``, ``bety0``, ``alfx0``, ``alfy0``, ``epsx_rms_fin``,
        ``epsy_rms_fin``, ``dpmax``, ``beta_fin``, ``circ`` and the offsets
        ``xco0``, ``pxco0``, ``yco0``, ``pyco0``, ``ptco0``.
    seed : None, int or numpy.random.Generator, optional
        Forwarded to ``numpy.random.default_rng``. The default ``None`` draws
        fresh OS entropy, giving a different ensemble on every call (the
        previous, only behaviour). Pass a value to make runs reproducible.

    Returns
    -------
    numpy.ndarray
        Shape ``(n_particles, 6)``; columns ``x, px, y, py, ct, delta``.
    """
    rng = np.random.default_rng(seed)

    betx0, bety0 = sim_cfg["betx0"], sim_cfg["bety0"]
    alfx0, alfy0 = sim_cfg["alfx0"], sim_cfg["alfy0"]
    epsx_rms, epsy_rms = sim_cfg["epsx_rms_fin"], sim_cfg["epsy_rms_fin"]

    xrms = math.sqrt(epsx_rms * betx0)
    pxrms = math.sqrt(epsx_rms / betx0)
    yrms = math.sqrt(epsy_rms * bety0)
    pyrms = math.sqrt(epsy_rms / bety0)
    ptrms = 0.5 * sim_cfg["dpmax"] * sim_cfg["beta_fin"]

    # Draw order is fixed so that a given seed always yields the same ensemble.
    # The -alpha/beta * x term adds the x-px correlation of a tilted ellipse.
    xs = xrms * rng.normal(size=n_particles)
    pxs = pxrms * rng.normal(size=n_particles) - (alfx0 / betx0) * xs
    ys = yrms * rng.normal(size=n_particles)
    pys = pyrms * rng.normal(size=n_particles) - (alfy0 / bety0) * ys

    # Uniform longitudinal spread of circ / beta / 5 centred on zero.
    ct = sim_cfg["circ"] * (0.5 - rng.random(size=n_particles)) / sim_cfg["beta_fin"] / 5.0
    delta = ptrms * rng.normal(size=n_particles)

    # Closed-orbit offsets are added after the correlation terms are formed.
    xs += sim_cfg["xco0"]
    pxs += sim_cfg["pxco0"]
    ys += sim_cfg["yco0"]
    pys += sim_cfg["pyco0"]
    delta += sim_cfg["ptco0"]

    return np.column_stack([xs, pxs, ys, pys, ct, delta])

def _push_particles_to_device(parts):
    """Copy host-side particle arrays to the tracking backend (best effort).

    Tries ``update_from_host()`` then ``sync()``. If neither works the failure
    is ignored on purpose; presumably a host-only build needs no copy (confirm).
    """
    try:
        parts.update_from_host()
    except Exception:
        try:
            parts.sync()
        except Exception:
            pass


def _pull_particles_to_host(parts):
    """Copy tracked particle arrays back to the host (best effort).

    Tries ``update_to_host()`` then ``sync_from_device()``; failures are
    ignored on purpose, mirroring ``_push_particles_to_device``.
    """
    try:
        parts.update_to_host()
    except Exception:
        try:
            parts.sync_from_device()
        except Exception:
            pass


def _classify_septum(x, px, sim_cfg):
    """Split particles into septum losses, extracted particles and ring losses.

    Assumes the extraction channel lies at x below the entrance wire position
    ``sim_cfg["aperx_sep_en"]`` (both configured septum apertures are
    negative). TODO confirm against the machine layout.

    Returns three mutually exclusive boolean arrays aligned with ``x``:
    ``(sep_loss, extracted, ring_loss)``.
    """
    x_wire_en = sim_cfg["aperx_sep_en"]
    x_wire_ex = sim_cfg["aperx_sep_ex"]
    thick = sim_cfg["sep_wire_thick"]

    # Straight-line drift over the septum length gives the exit position.
    x_exit = x + sim_cfg["len_sep"] * px

    # Entrance: x in [x_wire_en - thick, x_wire_en] lands on the wire itself.
    hit_wire = (x - x_wire_en) * (x - x_wire_en + thick) <= 0.0
    # Beyond the wire the particle is inside the septum channel.
    in_channel = x < x_wire_en - thick
    # The exit-wire test only concerns channel particles. Applied to all
    # particles, it flagged the whole circulating beam (x ~ 0) as lost.
    hit_exit = in_channel & (x_exit - x_wire_ex + thick >= 0.0)

    sep_loss = hit_wire | hit_exit
    extracted = in_channel & ~hit_exit
    # Anything else that blew up (|x| > 1.0) or became NaN.
    ring_loss = ~sep_loss & ~extracted & ((np.abs(x) > 1.0) | np.isnan(x))
    return sep_loss, extracted, ring_loss


def run_simulation_on_converted_lattice(elems, meta, params, sim_cfg):
    """Track an ensemble with a phase-keyed KO kick and count septum extraction.

    Each turn the kick is applied to the horizontal momentum of surviving
    particles. The lattice is then tracked for one turn, and survivors are
    classified at the septum with ``_classify_septum``. Particles that are
    extracted or lost are masked out of further bookkeeping; they stay in the
    tracker buffer.

    Parameters
    ----------
    elems :
        sixtracklib element buffer produced by one of the converters.
    meta : dict
        Converter metadata; not used by the tracking loop.
    params : dict
        Optimisation point. Only ``params["central_tune_x"]`` is used: the KO
        phase advances by ``2*pi*central_tune_x`` per turn.
        ``initial_tune_x``, ``essep_bump``, ``mgsep_bump``, ``phi_deg`` and
        ``k2la`` are accepted but not yet applied to the lattice (TODO).
    sim_cfg : dict
        Simulation configuration (see ``SIM_CONFIG``).

    Returns
    -------
    dict
        ``pnum_tot``: particles launched. It is counted once per particle,
        not once per turn, so ``pnum_ex / pnum_tot`` is an extraction
        efficiency.
        ``pnum_ex``: extracted particles.
        ``ploss_sep``: particles lost on the septum.
        ``ploss_ring``: other losses (|x| > 1.0 or NaN).
        ``elapsed_s``: wall time of the turn loop.
    """
    tracker = stl.Tracker(elements=elems)

    n_particles = sim_cfg["n_particles"]
    particles_np = sample_initial_particles(n_particles, sim_cfg)

    parts = stl.ParticlesSet(n_particles)
    parts.x = particles_np[:, 0].astype(np.float64)
    parts.px = particles_np[:, 1].astype(np.float64)
    parts.y = particles_np[:, 2].astype(np.float64)
    parts.py = particles_np[:, 3].astype(np.float64)
    try:
        parts.ct = particles_np[:, 4].astype(np.float64)
        parts.delta = particles_np[:, 5].astype(np.float64)
    except Exception:
        # Builds without ct/delta attributes use t/ptau instead.
        parts.t = particles_np[:, 4].astype(np.float64)
        parts.ptau = particles_np[:, 5].astype(np.float64)
    _push_particles_to_device(parts)

    rev_rf_ramp = sim_cfg["rev_rf_ramp"]
    revmax = sim_cfg["revmax"]
    revtot = revmax + rev_rf_ramp
    bit_len = sim_cfg["bit_len"]
    report_every = max(1, revmax // 50)

    # Binary phase keying: every ``bit_len`` KO turns the kick phase switches
    # to 0 or pi, taken from a fixed-seed table.
    rng = np.random.default_rng(123)
    n_bit_max = int(np.floor(1.5 * revmax / max(1, bit_len)) + 1)
    phi_ko_table = np.where(rng.random(n_bit_max) < 0.5, 0.0, math.pi)

    kokick_a = 2.0e-5
    central_tune_x = params["central_tune_x"]

    ploss_sep = 0
    ploss_ring = 0
    pnum_ex = 0
    alive = np.ones(n_particles, dtype=bool)

    tt = 0
    n_bit = 0
    phi0 = phi_ko_table[0]
    kophase = 0.0

    start_time = time.time()
    for turn in range(1, revtot + 1):
        if not alive.any():
            break

        # KO excitation is off during the first ``rev_rf_ramp`` turns.
        rev = turn - rev_rf_ramp
        if rev > 0:
            tt += 1
            if tt >= bit_len:
                tt = 0
                n_bit = min(n_bit + 1, len(phi_ko_table) - 1)
                phi0 = phi_ko_table[n_bit]
            kophase += central_tune_x * 2.0 * math.pi
            kokick_t = kokick_a * math.sin(kophase + phi0)
        else:
            kokick_t = 0.0

        if kokick_t != 0.0:
            # Apply the kick; it used to be computed but never used, which made
            # the objective independent of the optimised parameters.
            px_host = np.array(parts.px, dtype=np.float64)
            px_host[alive] += kokick_t
            parts.px = px_host
            _push_particles_to_device(parts)

        tracker.track(parts, 1)
        _pull_particles_to_host(parts)

        alive_idxs = np.flatnonzero(alive)
        x_alive = np.asarray(parts.x, dtype=np.float64)[alive_idxs]
        px_alive = np.asarray(parts.px, dtype=np.float64)[alive_idxs]

        sep_loss, extracted, ring_loss = _classify_septum(x_alive, px_alive, sim_cfg)
        n_sep_losses = int(np.count_nonzero(sep_loss))
        n_extracted = int(np.count_nonzero(extracted))
        n_ring_loss = int(np.count_nonzero(ring_loss))

        ploss_sep += n_sep_losses
        pnum_ex += n_extracted
        ploss_ring += n_ring_loss

        alive[alive_idxs[sep_loss | extracted | ring_loss]] = False

        if sim_cfg["verbose"] and turn % report_every == 0:
            print(f"turn {turn:6d} | alive {alive.sum():6d} | extracted this turn {n_extracted} | sep_loss {n_sep_losses} | ring_loss {n_ring_loss}")

    elapsed = time.time() - start_time
    return {
        "pnum_tot": int(n_particles),
        "pnum_ex": int(pnum_ex),
        "ploss_sep": int(ploss_sep),
        "ploss_ring": int(ploss_ring),
        "elapsed_s": elapsed,
    }

def objective_vector(x, elems, meta, sim_cfg):
    """Objective for ``gp_minimize``: percentage of particles NOT extracted.

    ``x`` is one point of the search space, ordered as
    ``initial_tune_x, central_tune_x, essep_bump, mgsep_bump, phi_deg, k2la``.
    Each call runs a full tracking simulation and appends one row (parameters
    plus the simulation summary) to ``OUT_DIR/bayes_iter_history.csv``. The
    header is written only when the file is created. The call returns
    ``100 * (1 - pnum_ex / pnum_tot)``.

    A degenerate run with ``pnum_tot <= 0`` returns the penalty ``1e6`` and
    writes no CSV row.
    """
    names = ("initial_tune_x", "central_tune_x", "essep_bump",
             "mgsep_bump", "phi_deg", "k2la")
    params = {name: float(x[pos]) for pos, name in enumerate(names)}

    summary = run_simulation_on_converted_lattice(elems, meta, params, sim_cfg)
    extracted, total = summary["pnum_ex"], summary["pnum_tot"]
    if total <= 0:
        return 1e6

    extracted_fraction = float(extracted) / float(total)
    obj = 100.0 * (1.0 - extracted_fraction)

    record = pd.DataFrame([{**params, **summary}])
    history_path = os.path.join(OUT_DIR, "bayes_iter_history.csv")
    is_new_file = not os.path.exists(history_path)
    if is_new_file:
        record.to_csv(history_path, index=False, mode="w")
    else:
        record.to_csv(history_path, index=False, mode="a", header=False)

    print(f"Eval params {params} => obj {obj:.6f}")
    return obj

def _pickle_best_effort(obj, path, what):
    """Pickle *obj* to *path*; on a pickling error warn and return False.

    Only pickling errors are swallowed (C-extension buffers or embedded
    callables may be unpicklable). OS errors from opening the file still
    propagate. A partially written file is removed on failure.
    """
    try:
        with open(path, "wb") as fh:
            pickle.dump(obj, fh)
    except (pickle.PicklingError, TypeError, AttributeError) as err:
        print(f"Warning: could not pickle {what} to {path}: {err}")
        try:
            os.remove(path)
        except OSError:
            pass
        return False
    return True


def main():
    """Convert the lattice, sanity-check tracking, then run the optimisation.

    Writes into ``OUT_DIR``:
    - ``converted_elements.pkl``, when the elements are picklable;
    - ``gp_result.pkl``, holding the full ``gp_minimize`` result, or a plain
      trace dict if the result cannot be pickled;
    - ``convergence.png``.
    ``objective_vector`` additionally appends ``bayes_iter_history.csv``.
    """
    print("Starting conversion of MAD-X sequence to SixTrackLib elements...")
    elems, meta = convert_seq_to_sixtracklib(MADX_SEQ_PATH, SEQUENCE_NAME)
    converted_path = os.path.join(OUT_DIR, "converted_elements.pkl")
    if _pickle_best_effort((elems, meta), converted_path, "converted elements"):
        print("Converted elements saved to:", converted_path)
    print("Converter meta:", meta)

    print("Sanity-check tracking (small test)...")
    # Use a shrunk copy of SIM_CONFIG so the check really is small. Passing
    # SIM_CONFIG itself ran a complete production-size simulation here.
    sanity_cfg = dict(
        SIM_CONFIG,
        n_particles=min(200, SIM_CONFIG["n_particles"]),
        rev_rf_ramp=min(20, SIM_CONFIG["rev_rf_ramp"]),
        revmax=min(200, SIM_CONFIG["revmax"]),
    )
    sanity_params = {
        "initial_tune_x": (4.29 + 4.35) / 2.0,
        "central_tune_x": 0.5,
        "essep_bump": 0.0, "mgsep_bump": 0.0,
        "phi_deg": 0.0, "k2la": 0.0,
    }
    try:
        test_summary = run_simulation_on_converted_lattice(elems, meta, sanity_params, sanity_cfg)
        print("Sanity test returned:", test_summary)
    except Exception as e:
        print("Sanity check failed. This often means the converter produced elements in a layout incompatible with the simple run loop.")
        print("Exception:", e)

    print("Starting Bayesian optimization (this will run many tracking simulations...")
    func = partial(objective_vector, elems=elems, meta=meta, sim_cfg=SIM_CONFIG)
    res = gp_minimize(func=func, dimensions=space,
                      acq_func="EI", n_calls=OPT_CONFIG["n_calls"],
                      n_initial_points=OPT_CONFIG["n_initial_points"],
                      random_state=OPT_CONFIG["random_state"], verbose=True)

    gp_path = os.path.join(OUT_DIR, "gp_result.pkl")
    if not _pickle_best_effort(res, gp_path, "gp_minimize result"):
        # The full result may embed the objective (a partial holding
        # ``elems``); keep at least the optimisation trace.
        _pickle_best_effort(
            {
                "x": list(res.x),
                "fun": res.fun,
                "x_iters": [list(point) for point in res.x_iters],
                "func_vals": np.asarray(res.func_vals),
            },
            gp_path,
            "gp_minimize summary",
        )

    best_params = dict(zip([s.name for s in space], res.x))
    print("Best params:", best_params)
    print("Best objective:", res.fun)

    best_so_far = np.minimum.accumulate(res.func_vals)
    plt.figure(figsize=(6, 4))
    plt.plot(best_so_far, marker="o")
    plt.xlabel("Iteration")
    plt.ylabel("Best objective so far")
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(os.path.join(OUT_DIR, "convergence.png"))
    plt.close()
    print("Saved optimization outputs in", OUT_DIR)


if __name__ == "__main__":
    main()