In [None]:
import math
import os
import sys
import time
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

try:
    import xtrack as xt
    import xpart as xp
    import xobjects as xo
except Exception as e:
    print("ERROR: xtrack/xpart/xobjects import failed. Make sure xSuite is installed.")
    print("Import error:", e)
    sys.exit(1)

try:
    import optuna
except Exception as e:
    print("ERROR: optuna import failed. Install optuna to run optimization.")
    print("Import error:", e)
    sys.exit(1)

In [None]:
# --- Lattice input (MAD-X files) -------------------------------------------
MAD_SEQUENCE_FILE = "SIS18RING.SEQ"
MAD_STRENGTHS_FILE = "SIS18_cryocatchers.str"

# --- Ion species and beam energy -------------------------------------------
A = 197                    # mass number
Z = 65                     # charge state
energy_per_u = 400e6       # kinetic energy per nucleon, eV/u
mass_u = 931.494e6         # atomic mass unit, eV/c^2

# --- Simulation size ----------------------------------------------------------
N_PARTICLES = 10000        # number of macroparticles
N_TURNS = 10000            # number of turns to track

# --- Output files -------------------------------------------------------------
OPTUNA_LOG_CSV = "optuna_log.csv"
FINAL_COORDS_CSV = "final_coordinates.csv"

# --- Septum position (particles with x below this value are septum candidates)
x_septum_mm = -55.0             # mm
x_septum = x_septum_mm * 1e-3   # same position in metres

# --- Initial working point ----------------------------------------------------
Qx_initial = 4.331
Qy_initial = 3.29
chrom_x_initial = -1.0

# --- Optimisation search bounds, written as (min, max) pairs ------------------
OCT_MIN, OCT_MAX = 0.0, 5.0
RFKO_MIN_KHZ, RFKO_MAX_KHZ = 10.0, 80.0
ELECTRO_BUMP_MIN, ELECTRO_BUMP_MAX = -25.0, -5.0
MAG_BUMP_MIN, MAG_BUMP_MAX = 0.0, 15.0
SEXT_AMP_MIN, SEXT_AMP_MAX = 0.0, 1.0
SEXT_PHASE_MIN_DEG, SEXT_PHASE_MAX_DEG = 0.0, 180.0
SEPTUM_DEF_MIN_DEG, SEPTUM_DEF_MAX_DEG = -5.0, 0.0
CHROM_X_MIN, CHROM_X_MAX = -5.0, 5.0

# --- Reproducibility: seed NumPy's and Python's global RNGs -------------------
SEED = 12345
np.random.seed(SEED)
random.seed(SEED)

In [None]:
def load_line(seq_file=MAD_SEQUENCE_FILE, str_file=MAD_STRENGTHS_FILE, seq_name="SIS18RING"):
    """Load a MAD-X lattice and convert it into an xtrack ``Line`` with a built tracker.

    ``seq_file`` is called first, then ``str_file``. The sequence called
    ``seq_name`` is used when MAD-X defines it. The comparison is
    case-insensitive because MAD-X normalises names to lower case. Otherwise
    the first sequence MAD-X reports is used.

    Args:
        seq_file: MAD-X file defining the sequence.
        str_file: MAD-X file with element strengths, loaded after ``seq_file``.
        seq_name: preferred sequence name (default ``"SIS18RING"``).

    Returns:
        The converted ``xt.Line``, with ``build_tracker()`` already called.

    Raises:
        RuntimeError: if no sequence is defined after loading both files.
    """
    # NOTE(review): ``xt.Madx`` is kept from the original code. The MAD-X wrapper
    # normally used together with ``xt.Line.from_madx_sequence`` is
    # ``cpymad.madx.Madx``. Confirm this name exists in the installed xsuite.
    mad = xt.Madx()
    mad.call(seq_file)
    mad.call(str_file)

    # ``mad.sequence`` is a mapping of sequence names. Its ``__dict__`` holds the
    # wrapper's own attributes, not the sequences, so iterate the mapping instead.
    available = list(mad.sequence)
    if not available:
        raise RuntimeError("No sequence found in MAD-X after loading files.")
    by_lower = {str(name).lower(): name for name in available}
    chosen = by_lower.get(seq_name.lower(), available[0])

    # Expand the sequence so its element list can be read by the xtrack converter.
    mad.use(sequence=chosen)
    line = xt.Line.from_madx_sequence(mad.sequence[chosen])
    line.build_tracker()
    print(f"Loaded lattice sequence '{chosen}' with {len(line.elements)} elements.")
    return line

def create_particle_reference():
    """Build the xpart reference particle for the configured ion beam.

    Rest energy is approximated as ``A * mass_u``. The total energy adds the
    kinetic energy ``A * energy_per_u``. The reference momentum then follows
    from ``(p c)^2 = E_tot^2 - (m c^2)^2``. All quantities are in eV.

    Returns:
        ``xp.Particles`` built with ``mass0``, ``q0 = Z`` and ``p0c``.
    """
    rest_energy = A * mass_u
    total_energy = rest_energy + A * energy_per_u
    momentum_ev = math.sqrt(total_energy * total_energy - rest_energy * rest_energy)
    return xp.Particles(mass0=rest_energy, q0=Z, p0c=momentum_ev)


def generate_initial_beam(line, particle_ref, n_particles=N_PARTICLES,
                          nemitt_x=5e-6, nemitt_y=5e-6, sigma_z=1.0,
                          total_intensity_particles=1e9):
    """Generate a Gaussian macroparticle beam matched to the optics of ``line``.

    Args:
        line: xtrack ``Line`` (tracker built) whose optics the beam is matched to.
        particle_ref: reference particle passed through to xpart.
        n_particles: number of macroparticles.
        nemitt_x: normalised horizontal emittance [m].
        nemitt_y: normalised vertical emittance [m].
        sigma_z: r.m.s. bunch length [m]. ``xp.generate_matched_gaussian_bunch``
            has no default for this argument, so the previous call without it
            failed with a TypeError. The 1.0 m default is a placeholder.
            TODO: set it from the actual RF / bunch configuration.
        total_intensity_particles: physical intensity passed through to xpart.

    Returns:
        The particle set returned by ``xp.generate_matched_gaussian_bunch``.
    """
    # NOTE(review): longitudinal matching normally needs RF cavities in the line.
    # Confirm the SIS18 model provides them, or switch to a coasting-beam generator.
    return xp.generate_matched_gaussian_bunch(
        num_particles=n_particles,
        total_intensity_particles=total_intensity_particles,
        nemitt_x=nemitt_x,
        nemitt_y=nemitt_y,
        sigma_z=sigma_z,
        line=line,
        particle_ref=particle_ref,
    )

def resonance_factor(f_center_khz, target_khz=30.0, width_khz=8.0):
    """Gaussian weight of the RF-KO centre frequency around a target frequency.

    Returns ``exp(-z**2 / 2)`` with ``z = (f_center_khz - target_khz) / width_khz``.
    The weight is 1.0 exactly on target and ``exp(-0.5)`` one width away.
    All frequencies are in kHz.
    """
    z = (f_center_khz - target_khz) / width_khz
    exponent = -0.5 * z * z
    return math.exp(exponent)

def apply_rfko_kick(beam, kick_amp, f_center_khz, turn):
    """Add a noisy, slowly modulated horizontal angle kick to ``beam.px`` in place.

    The kick is ``kick_amp * envelope * (1 + 0.1 * noise)``. Here ``envelope``
    is ``0.5 * (1 + sin(2*pi*turn/500))`` and ``noise`` is drawn per particle
    from ``N(0, 0.1)`` using NumPy's global RNG. Particles whose ``x`` is
    non-finite (NaN marks removed particles) receive no kick.

    ``f_center_khz`` is accepted but not used by this model.
    """
    phase = 2.0 * math.pi * (turn / 500.0)
    envelope = 0.5 * (1.0 + math.sin(phase))
    jitter = np.random.normal(loc=0.0, scale=0.1, size=beam.x.size)
    alive = np.isfinite(beam.x)
    kick = np.zeros_like(beam.px)
    kick[alive] = kick_amp * envelope * (1.0 + 0.1 * jitter[alive])
    beam.px += kick

def evaluate_septum_extraction(beam, params_state, septum_x=None, angle_limit=0.05,
                               base_prob=0.30, max_prob=0.99):
    """Randomly pick the circulating particles that are extracted at the septum this turn.

    A particle is *eligible* when all three conditions hold:
    - its ``x`` and ``px`` are finite (NaN marks removed particles);
    - it lies beyond the septum, ``x < septum_x``;
    - its angle is small enough, ``|px| < angle_limit``.

    Every eligible particle is then extracted independently with one common
    probability::

        prob = base_prob * resonance_factor(f_center, target, width)
               * (0.5 + 0.5 * tanh(kick_amp * 1e4))
               * (1 + 0.05|chrom_x|) * (1 + 0.02|elec_bump|) * (1 + 0.02|mag_bump|)
               * (1 + 0.1|sext_amp|) * (1 + 0.1|septum_deflection_rad|)

    The result is clipped to ``[0, max_prob]``. ``sext_phase_deg`` does not
    enter this model.

    Args:
        beam: object with array attributes ``x`` and ``px``.
        params_state: dict with keys ``f_center_khz``, ``target_khz``,
            ``width_khz``, ``kick_amp``, ``chrom_x``, ``elec_bump``,
            ``mag_bump``, ``sext_amp`` and ``septum_deflection_rad``.
        septum_x: septum position in metres. ``None`` uses the module-level
            ``x_septum``.
        angle_limit: acceptance on ``|px|``.
        base_prob: probability scale before the multiplicative factors.
        max_prob: upper clip for the extraction probability.

    Returns:
        Boolean numpy array, True for particles extracted this turn.
    """
    if septum_x is None:
        septum_x = x_septum

    # The original code read an undefined name ``x`` here (NameError).
    x = np.asarray(beam.x)
    px = np.asarray(beam.px)
    active = np.isfinite(x) & np.isfinite(px)
    # NaN entries are excluded by ``active``. Silence the comparison warnings they cause.
    with np.errstate(invalid="ignore"):
        eligible = active & (x < septum_x) & (np.abs(px) < angle_limit)

    res = resonance_factor(params_state['f_center_khz'],
                           target_khz=params_state['target_khz'],
                           width_khz=params_state['width_khz'])
    kick_factor = math.tanh(params_state['kick_amp'] * 1e4)
    chrom_factor = 1.0 + 0.05 * abs(params_state['chrom_x'])
    elec_bump_factor = 1.0 + 0.02 * abs(params_state['elec_bump'])
    mag_bump_factor = 1.0 + 0.02 * abs(params_state['mag_bump'])
    sext_factor = 1.0 + 0.1 * abs(params_state['sext_amp'])
    septum_defl_factor = 1.0 + 0.1 * abs(params_state['septum_deflection_rad'])

    prob = (base_prob * res * (0.5 + 0.5 * kick_factor) * chrom_factor
            * elec_bump_factor * mag_bump_factor * sext_factor * septum_defl_factor)
    prob = min(max(prob, 0.0), max_prob)

    # One uniform draw per particle, including ineligible ones. This keeps the
    # global RNG stream the same from turn to turn.
    rnd = np.random.rand(x.size)
    return (rnd < prob) & eligible

In [None]:
def _retire_particles(beam, idx):
    """Set the six phase-space coordinates of particles ``idx`` to NaN.

    NaN coordinates are this model's marker for "no longer circulating". The
    ``np.isfinite`` checks in the kick and extraction steps skip such particles.
    """
    for coord in ("x", "px", "y", "py", "zeta", "ptau"):
        getattr(beam, coord)[idx] = np.nan


def track_with_params(line, beam_in, params, n_turns=N_TURNS, verbose=False, add_octupoles=False, oct_strength=1.0):
    """Track a copy of ``beam_in`` turn by turn with the RF-KO extraction model.

    Each turn does the following, in order:
    1. apply the RF-KO kick;
    2. track one turn through ``line``;
    3. mark newly lost particles;
    4. draw septum extractions (optionally thinned by the octupole model);
    5. every 500 turns, retire a small random background fraction.

    Removed particles get NaN coordinates. Tracking stops early once every
    particle is lost or extracted.

    Args:
        line: xtrack ``Line`` with a built tracker.
        beam_in: xpart particles. It is copied and not modified.
        params: dict with keys ``f_center_khz``, ``kick_amp``, ``chrom_x``,
            ``elec_bump``, ``mag_bump``, ``sext_amp``, ``sext_phase_deg`` and
            ``septum_deflection_rad``, plus optional ``target_khz`` (default
            30.0) and ``width_khz`` (default 8.0).
        n_turns: maximum number of turns.
        verbose: print a message when all particles are processed early.
        add_octupoles: if True, each extraction candidate is accepted with
            probability ``0.5 * (1 + 0.05 * oct_strength)``.
        oct_strength: strength used by that acceptance model.

    Returns:
        dict with ``n_extracted``, ``n_lost``, ``n_remaining`` (ints),
        ``final_beam``, ``extracted_mask``, ``lost_mask`` (bool arrays) and
        ``turns_used`` (0 when ``n_turns <= 0``).
    """
    beam = xp.Particles.from_dict(beam_in.to_dict())
    oct_boost = 1.0 + 0.05 * oct_strength if add_octupoles else 1.0

    params_state = {
        'f_center_khz': params['f_center_khz'],
        'target_khz': params.get('target_khz', 30.0),
        'width_khz': params.get('width_khz', 8.0),
        'kick_amp': params['kick_amp'],
        'chrom_x': params['chrom_x'],
        'elec_bump': params['elec_bump'],
        'mag_bump': params['mag_bump'],
        'sext_amp': params['sext_amp'],
        'sext_phase_deg': params['sext_phase_deg'],
        'septum_deflection_rad': params['septum_deflection_rad']
    }

    n_particles = beam.x.size
    extracted_mask_global = np.zeros(n_particles, dtype=bool)
    lost_mask_global = np.zeros(n_particles, dtype=bool)
    # Initialised so the result is well-defined even when the loop never runs.
    turns_used = 0

    # NOTE(review): the masks are index-based and assume xtrack does not reorder
    # particles between single-turn calls. Confirm this, or key on particle_id.
    for turn in range(n_turns):
        turns_used = turn + 1
        apply_rfko_kick(beam, params['kick_amp'], params['f_center_khz'], turn)
        line.track(beam, num_turns=1)

        coords_ok = (np.isfinite(beam.x) & np.isfinite(beam.px)
                     & np.isfinite(beam.y) & np.isfinite(beam.py))
        # xtrack flags aperture losses with state <= 0 and leaves the coordinates
        # finite, so a NaN check alone would never count them as lost.
        alive = coords_ok & (np.asarray(beam.state) > 0)
        lost_now = ~alive & ~extracted_mask_global
        if lost_now.any():
            idx_lost = np.flatnonzero(lost_now)
            lost_mask_global[idx_lost] = True
            _retire_particles(beam, idx_lost)

        extracted_now = evaluate_septum_extraction(beam, params_state)

        if add_octupoles and extracted_now.any():
            # Octupole model: keep each candidate with probability 0.5 * oct_boost.
            idx_ex = np.flatnonzero(extracted_now)
            accepted = np.random.rand(idx_ex.size) < (0.5 * oct_boost)
            extracted_now = np.zeros_like(extracted_now, dtype=bool)
            extracted_now[idx_ex[accepted]] = True

        extracted_now = extracted_now & ~extracted_mask_global & ~lost_mask_global

        if extracted_now.any():
            idx_extracted = np.flatnonzero(extracted_now)
            extracted_mask_global[idx_extracted] = True
            _retire_particles(beam, idx_extracted)

        if turn > 0 and turn % 500 == 0:
            # Background loss: retire 0.02 % (at least one) of the circulating particles.
            active_idx = np.flatnonzero(~(extracted_mask_global | lost_mask_global))
            if active_idx.size > 0:
                n_bg = max(1, int(0.0002 * active_idx.size))
                sel = np.random.choice(active_idx, size=n_bg, replace=False)
                lost_mask_global[sel] = True
                _retire_particles(beam, sel)

        if (extracted_mask_global | lost_mask_global).all():
            if verbose:
                print(f"All particles processed by turn {turn}")
            break

    return {
        'n_extracted': int(extracted_mask_global.sum()),
        'n_lost': int(lost_mask_global.sum()),
        'n_remaining': int((~(extracted_mask_global | lost_mask_global)).sum()),
        'final_beam': beam,
        'extracted_mask': extracted_mask_global,
        'lost_mask': lost_mask_global,
        'turns_used': turns_used
    }

In [None]:
def objective_optuna(trial, line, particle_ref, initial_beam):
    """Optuna objective: percentage of macroparticles that were NOT extracted.

    The function works in five steps:
    1. sample one set of machine settings from ``trial``, in a fixed order;
    2. track ``initial_beam`` through ``line`` for ``N_TURNS`` turns without
       octupoles;
    3. store the particle counts as trial user attributes;
    4. append a row to the CSV trial log;
    5. return ``100 * (1 - N_extracted / N_total)`` (lower is better).

    ``particle_ref`` is accepted for call-site compatibility but is not used.
    """
    suggest = trial.suggest_float
    f_center_khz = suggest("f_center_khz", RFKO_MIN_KHZ, RFKO_MAX_KHZ)
    chrom_x = suggest("chrom_x", CHROM_X_MIN, CHROM_X_MAX)
    elec_bump = suggest("elec_bump", ELECTRO_BUMP_MIN, ELECTRO_BUMP_MAX)
    mag_bump = suggest("mag_bump", MAG_BUMP_MIN, MAG_BUMP_MAX)
    sext_amp = suggest("sext_amp", SEXT_AMP_MIN, SEXT_AMP_MAX)
    sext_phase_deg = suggest("sext_phase_deg", SEXT_PHASE_MIN_DEG, SEXT_PHASE_MAX_DEG)
    septum_deflection_deg = suggest("septum_deflection_deg", SEPTUM_DEF_MIN_DEG, SEPTUM_DEF_MAX_DEG)
    kick_amp = suggest("kick_amp", 1e-7, 5e-5, log=True)

    params = dict(
        f_center_khz=f_center_khz,
        kick_amp=kick_amp,
        chrom_x=chrom_x,
        elec_bump=elec_bump,
        mag_bump=mag_bump,
        sext_amp=sext_amp,
        sext_phase_deg=sext_phase_deg,
        septum_deflection_rad=math.radians(septum_deflection_deg),
    )

    outcome = track_with_params(line, initial_beam, params, n_turns=N_TURNS, verbose=False, add_octupoles=False)

    n_total = initial_beam.x.size
    n_extracted = outcome['n_extracted']
    objective_value = 100.0 * (1.0 - (float(n_extracted) / float(n_total)))

    for key in ("n_extracted", "n_lost", "n_remaining", "turns_used"):
        trial.set_user_attr(key, outcome[key])

    # Column order of the CSV log follows the insertion order below.
    log_row = {'trial_number': trial.number, 'objective': objective_value}
    log_row.update({key: outcome[key] for key in ("n_extracted", "n_lost", "n_remaining")})
    log_row.update(
        f_center_khz=f_center_khz,
        kick_amp=kick_amp,
        chrom_x=chrom_x,
        elec_bump=elec_bump,
        mag_bump=mag_bump,
        sext_amp=sext_amp,
        sext_phase_deg=sext_phase_deg,
        septum_deflection_deg=septum_deflection_deg,
    )
    log_row['turns_used'] = outcome['turns_used']
    write_log_row(log_row)

    return objective_value

def write_log_row(row, filename=OPTUNA_LOG_CSV):
    """Write one record (a dict) to a CSV log.

    If ``filename`` does not exist yet, the file is created with a header
    line. Otherwise the row is appended without repeating the header.
    """
    is_new_file = not os.path.exists(filename)
    pd.DataFrame([row]).to_csv(
        filename,
        index=False,
        mode='w' if is_new_file else 'a',
        header=is_new_file,
    )


def insert_octupoles(line, strength, positions=None):
    """Insert thin octupole multipoles named ``OCT_<i>`` into ``line`` and rebuild its tracker.

    Args:
        line: xtrack ``Line`` to modify in place. NOTE(review): xtrack refuses
            to insert into a line whose tracker is already built, so pass an
            unbuilt line (e.g. a fresh copy).
        strength: integrated normal octupole strength, stored as ``knl[3]``.
        positions: element indices, in the line as passed in, at which the
            octupoles are inserted. ``positions[i]`` receives ``OCT_<i>``. The
            default is 25 %, 50 % and 75 % of the element count.

    Returns:
        The same ``line`` object, with its tracker rebuilt.
    """
    n_elems = len(line.elements)
    if positions is None:
        positions = [int(n_elems * 0.25), int(n_elems * 0.5), int(n_elems * 0.75)]

    # Insert from the highest index downwards. Each insertion then leaves the
    # indices still to be used unchanged, so the positions refer to the
    # original line.
    for i, pos in sorted(enumerate(positions), key=lambda item: item[1], reverse=True):
        knl = [0.0, 0.0, 0.0, float(strength)]
        # ``ksl`` must match ``knl`` in length. Elements carry no ``name``;
        # the name is given to ``insert_element``, which requires it.
        oct_elem = xt.Multipole(knl=knl, ksl=[0.0] * len(knl))
        line.insert_element(name=f"OCT_{i}", element=oct_elem, index=pos)

    line.build_tracker()
    return line

In [None]:
def main():
    """End-to-end driver: optimise the extraction settings, then re-track with the best ones.

    Steps:
      1. load the lattice, build the reference particle and an initial beam;
      2. delete any previous trial log, then run a 50-trial TPE study seeded
         with ``SEED`` that minimises ``objective_optuna``;
      3. re-track with the best parameters and save the final coordinates to
         ``FINAL_COORDS_CSV``;
      4. repeat the tracking on a copy of the lattice with octupoles inserted,
         using a freshly generated beam, and print a summary for both runs.
    """

    def report(label, outcome):
        """Print the particle bookkeeping and the objective for one tracking run."""
        n_extracted = outcome['n_extracted']
        print(f"\nSummary ({label}):")
        print("  N_total =", N_PARTICLES)
        print("  N_extracted =", n_extracted)
        print("  N_lost =", outcome['n_lost'])
        print("  N_remaining =", outcome['n_remaining'])
        objective = 100.0 * (1.0 - (float(n_extracted) / float(N_PARTICLES)))
        print("  Objective (100*(1 - Next/Ntot)) =", objective)

    line = load_line()
    particle_ref = create_particle_reference()
    initial_beam = generate_initial_beam(line, particle_ref, n_particles=N_PARTICLES)
    print(f"Initial beam created with {initial_beam.x.size} particles.")

    # Start from an empty trial log; the objective appends one row per trial.
    if os.path.exists(OPTUNA_LOG_CSV):
        os.remove(OPTUNA_LOG_CSV)

    sampler = optuna.samplers.TPESampler(seed=SEED)
    study = optuna.create_study(direction="minimize", sampler=sampler)

    N_TRIALS = 50
    print(f"Running optimization for {N_TRIALS} trials. Note: each trial tracks {N_TURNS} turns for {N_PARTICLES} particles.")
    t_start = time.time()
    study.optimize(lambda trial: objective_optuna(trial, line, particle_ref, initial_beam), n_trials=N_TRIALS)
    minutes = (time.time() - t_start) / 60.0
    print(f"Optimization completed in {minutes:.2f} minutes.")

    best = study.best_trial
    print("Best trial number:", best.number)
    print("Best objective (100*(1 - Next/Ntot)):", best.value)
    print("Best parameters:")
    for name, value in best.params.items():
        print(f"  {name} : {value}")

    best_params = best.params
    plain_keys = ('f_center_khz', 'kick_amp', 'chrom_x', 'elec_bump',
                  'mag_bump', 'sext_amp', 'sext_phase_deg')
    params_final = {key: float(best_params[key]) for key in plain_keys}
    # The study samples the septum deflection in degrees; tracking takes radians.
    params_final['septum_deflection_rad'] = math.radians(float(best_params['septum_deflection_deg']))

    print("\nRunning final tracking with best parameters (no octupoles) to write final coordinate file...")
    final_beam_res = track_with_params(line, initial_beam, params_final, n_turns=N_TURNS, verbose=True, add_octupoles=False)
    final_beam = final_beam_res['final_beam']

    # CSV column name -> particle attribute.
    column_sources = (('x', 'x'), ('px', 'px'), ('y', 'y'), ('py', 'py'), ('t', 'zeta'), ('pt', 'ptau'))
    coords_df = pd.DataFrame({column: np.array(getattr(final_beam, attr)).flatten()
                              for column, attr in column_sources})
    coords_df.to_csv(FINAL_COORDS_CSV, index=False)
    print(f"Final coordinates saved to {FINAL_COORDS_CSV}")

    report("no octupoles", final_beam_res)

    line_with_oct = line.copy()
    oct_strength_to_use = 2.0
    insert_octupoles(line_with_oct, oct_strength_to_use)

    initial_beam_oct = generate_initial_beam(line_with_oct, particle_ref, n_particles=N_PARTICLES)
    final_oct_res = track_with_params(line_with_oct, initial_beam_oct, params_final, n_turns=N_TURNS, verbose=True, add_octupoles=True, oct_strength=oct_strength_to_use)

    report("with octupoles", final_oct_res)


if __name__ == "__main__":
    main()