# Integrator Testbed

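Experiment with the core integrator across all supported simulation types (conducting wall, switching wall, and bunch to bunch). Toggle the legacy
comparison when needed, tune rider/driver particle and core parameters, and inspect the output artefacts (ΔE scatter plots, metrics JSON, core-vs-legacy overlay) for debugging or validation.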

In [8]:
from __future__ import annotations



import html

import io

import json

import sys

from pathlib import Path

from typing import Dict, List, Optional



import ipywidgets as widgets

import matplotlib.pyplot as plt

from IPython.display import display



# Ensure repository paths are importable when running in VS Code/notebooks.
# Walk up from the working directory to the nearest "LW_windows" checkout root.
# If the notebook is launched outside such a tree, fall back to the working
# directory rather than ending at the filesystem root (which would put "/",
# "/legacy" and "/examples/validation" on sys.path).
_CWD = Path.cwd().resolve()
PROJECT_ROOT = next(
    (candidate for candidate in (_CWD, *_CWD.parents) if candidate.name == "LW_windows"),
    _CWD,
)
LEGACY_ROOT = PROJECT_ROOT / "legacy"
VALIDATION_ROOT = PROJECT_ROOT / "examples" / "validation"

# Each root is inserted at index 0 in this order, so the final precedence is
# VALIDATION_ROOT, then LEGACY_ROOT, then PROJECT_ROOT.
for _import_root in (PROJECT_ROOT, LEGACY_ROOT, VALIDATION_ROOT):
    if str(_import_root) not in sys.path:
        sys.path.insert(0, str(_import_root))



from core_vs_legacy_benchmark import (

    DEFAULT_DRIVER_PARAMS,

    DEFAULT_RIDER_PARAMS,

    PARTICLE_PARAM_FIELDS,

    SimulationType,

    compute_delta_energy_series,

    run_benchmark,

    summarise_metrics,

)



# Series colours: hex codes taken from the Okabe–Ito colour-blind-safe palette.
COLOR_RIDER, COLOR_LEGACY_RIDER = "#0072B2", "#56B4E9"
COLOR_DRIVER, COLOR_LEGACY_DRIVER = "#D55E00", "#E69F00"

# Keyword arguments splatted into every ``Axes.scatter`` call.
SCATTER_STYLE = dict(s=140, alpha=0.78, linewidth=0, edgecolors="none")

# Font sizes (points) for titles, axis labels, tick labels and legends.
TITLE_FONTSIZE, LABEL_FONTSIZE, TICK_FONTSIZE, LEGEND_FONTSIZE = 18, 16, 13, 12

# Resolutions offered by the DPI dropdown, and the one preselected.
AVAILABLE_DPI_CHOICES = (150, 300, 450, 600)
DEFAULT_PLOT_DPI = 300



# Global matplotlib look for every figure this notebook draws: a white-grid base
# style, then the font sizes defined above layered on top of it.
_RC_FONT_OVERRIDES = {
    "axes.titlesize": TITLE_FONTSIZE,
    "axes.labelsize": LABEL_FONTSIZE,
    "xtick.labelsize": TICK_FONTSIZE,
    "ytick.labelsize": TICK_FONTSIZE,
    "legend.fontsize": LEGEND_FONTSIZE,
}
plt.style.use("seaborn-v0_8-whitegrid")
plt.rcParams.update(_RC_FONT_OVERRIDES)


In [9]:
# Widget descriptions for particle parameters, looked up by parameter name
# (names without an entry fall back to a title-cased version of the name).
PARAM_LABELS = dict(
    starting_distance="Start z (mm)",
    transv_mom="Transverse momentum (amu·mm/ns)",
    starting_Pz="Initial Pz (amu·mm/ns)",
    stripped_ions="Stripped ions",
    m_particle="Mass (amu)",
    transv_dist="Transverse spread (mm)",
    pcount="Particle count",
    charge_sign="Charge sign",
)

# Widget descriptions for the core (integrator geometry/timing) parameters.
CORE_PARAM_LABELS = dict(
    time_step="Time step (ns)",
    wall_z="Wall z (mm)",
    aperture_radius="Aperture radius (mm)",
    mean="Mean separation (mm)",
    cav_spacing="Cavity spacing (mm)",
    z_cutoff="z cutoff (mm)",
)

# Shared layout/style objects for the parameter widgets.
PARTICLE_WIDGET_WIDTH = widgets.Layout(width="320px")
CORE_WIDGET_WIDTH = widgets.Layout(width="280px")
PARTICLE_WIDGET_STYLE = dict(description_width="175px")
CORE_WIDGET_STYLE = dict(description_width="175px")
ROW_LAYOUT = widgets.Layout(gap="14px")

# Dropdown label -> simulation mode passed to run_benchmark.
SIMULATION_TYPE_OPTIONS = {
    "Conducting wall": SimulationType.CONDUCTING_WALL,
    "Switching wall": SimulationType.SWITCHING_WALL,
    "Bunch to bunch": SimulationType.BUNCH_TO_BUNCH,
}

# Species presets: the fields each preset overwrites in the particle controls.
# "custom" maps to None so selecting it leaves the controls untouched.
# NOTE(review): the proton entry is the bare-proton mass, but the lead/gold
# entries look like neutral-atom masses even though stripped_ions marks them as
# fully stripped — confirm which mass convention the integrator expects.
SPECIES_PRESETS: Dict[str, Optional[Dict[str, float]]] = {
    "custom": None,
    "electron": dict(m_particle=5.48579909070e-4, charge_sign=-1.0, stripped_ions=1.0),
    "positron": dict(m_particle=5.48579909070e-4, charge_sign=1.0, stripped_ions=1.0),
    "proton": dict(m_particle=1.007276466621, charge_sign=1.0, stripped_ions=1.0),
    "antiproton": dict(m_particle=1.007276466621, charge_sign=-1.0, stripped_ions=1.0),
    "lead": dict(m_particle=207.9766521, charge_sign=1.0, stripped_ions=82.0),
    "gold": dict(m_particle=196.9665687, charge_sign=1.0, stripped_ions=79.0),
}

# (label, key) pairs for the species dropdowns; keys index SPECIES_PRESETS.
SPECIES_DROPDOWN_OPTIONS = [
    ("Custom / manual", "custom"),
    ("Electron (e⁻)", "electron"),
    ("Positron (e⁺)", "positron"),
    ("Proton (p⁺)", "proton"),
    ("Antiproton (p̄⁻)", "antiproton"),
    ("Lead ion (Pb⁸²⁺)", "lead"),
    ("Gold ion (Au⁷⁹⁺)", "gold"),
]

# Initial values for the core parameter widgets (insertion order = display order).
CORE_PARAM_DEFAULTS = dict(
    time_step=2.2e-7,
    wall_z=1e5,
    aperture_radius=1e5,
    mean=1e5,
    cav_spacing=1e5,
    z_cutoff=0.0,
)

# Core parameters each simulation mode actually consumes; the rest are
# disabled in the UI and passed to run_benchmark as None.
CORE_REQUIRED_PARAMS = {
    SimulationType.CONDUCTING_WALL: {"time_step", "wall_z", "aperture_radius"},
    SimulationType.SWITCHING_WALL: {"time_step", "wall_z", "aperture_radius", "cav_spacing", "z_cutoff"},
    SimulationType.BUNCH_TO_BUNCH: {"time_step", "aperture_radius"},
}


def _make_particle_widgets(defaults: Dict[str, float | int]) -> Dict[str, widgets.Widget]:
    """Create one input widget per particle parameter, pre-filled from *defaults*.

    Integer defaults become integer widgets (an ``IntSlider`` for ``pcount``,
    an ``IntText`` otherwise); every other default becomes a ``FloatText``.

    Args:
        defaults: Default value for each name in ``PARTICLE_PARAM_FIELDS``.

    Returns:
        Widgets keyed by parameter name, in ``PARTICLE_PARAM_FIELDS`` order.

    Raises:
        KeyError: If *defaults* has no entry for one of the fields.
    """
    controls: Dict[str, widgets.Widget] = {}
    for name in PARTICLE_PARAM_FIELDS:
        default_value = defaults[name]
        description = PARAM_LABELS.get(name, name.replace("_", " ").title())
        if isinstance(default_value, int):
            if name == "pcount":
                start_count = int(default_value)
                control = widgets.IntSlider(
                    value=start_count,
                    min=1,
                    # Bounded widgets clamp ``value`` into [min, max] without
                    # warning; stretch the upper bound so a larger default
                    # particle count is not silently reduced to 256.
                    max=max(256, start_count),
                    step=1,
                    description=description,
                    continuous_update=False,
                    layout=PARTICLE_WIDGET_WIDTH,
                    style=PARTICLE_WIDGET_STYLE,
                )
            else:
                control = widgets.IntText(
                    value=int(default_value),
                    description=description,
                    layout=PARTICLE_WIDGET_WIDTH,
                    style=PARTICLE_WIDGET_STYLE,
                )
        else:
            control = widgets.FloatText(
                value=float(default_value),
                description=description,
                layout=PARTICLE_WIDGET_WIDTH,
                style=PARTICLE_WIDGET_STYLE,
            )
        controls[name] = control
    return controls


def _make_core_widgets() -> Dict[str, widgets.Widget]:
    """Build a ``FloatText`` for every core parameter, seeded from ``CORE_PARAM_DEFAULTS``.

    Returns:
        Widgets keyed by parameter name, in ``CORE_PARAM_DEFAULTS`` order.
    """
    return {
        param: widgets.FloatText(
            value=float(initial),
            description=CORE_PARAM_LABELS.get(param, param.replace("_", " ").title()),
            layout=CORE_WIDGET_WIDTH,
            style=CORE_WIDGET_STYLE,
        )
        for param, initial in CORE_PARAM_DEFAULTS.items()
    }


def _apply_species_preset(
    controls: Dict[str, widgets.Widget], preset_key: str
) -> None:
    """Copy the fields of the species preset *preset_key* into *controls*.

    Unknown keys and the ``"custom"`` preset (``None``) leave the controls
    unchanged. Preset fields without a matching control are skipped. Values are
    coerced to ``int`` (rounded) when the target control currently holds an
    int, and to ``float`` otherwise.
    """
    preset = SPECIES_PRESETS.get(preset_key) or {}
    for field_name, preset_value in preset.items():
        widget = controls.get(field_name)
        if widget is not None:
            holds_int = isinstance(widget.value, int)
            widget.value = int(round(preset_value)) if holds_int else float(preset_value)


def _particle_rows(controls: Dict[str, widgets.Widget]) -> List[widgets.Widget]:
    """Arrange particle controls two per ``HBox``, in ``PARTICLE_PARAM_FIELDS`` order.

    A trailing odd control gets a row of its own.
    """
    ordered = [controls[field] for field in PARTICLE_PARAM_FIELDS]
    return [
        widgets.HBox(ordered[start:start + 2], layout=ROW_LAYOUT)
        for start in range(0, len(ordered), 2)
    ]


def _core_rows(controls: Dict[str, widgets.Widget]) -> List[widgets.Widget]:
    """Arrange core controls three per ``HBox``, in ``CORE_PARAM_DEFAULTS`` order.

    The final row holds whatever controls remain (one to three).
    """
    ordered = [controls[param] for param in CORE_PARAM_DEFAULTS]
    return [
        widgets.HBox(ordered[start:start + 3], layout=ROW_LAYOUT)
        for start in range(0, len(ordered), 3)
    ]


def _collect_particle_values(controls: Dict[str, widgets.Widget]) -> Dict[str, float | int]:
    values: Dict[str, float | int] = {}
    for name, control in controls.items():
        values[name] = control.value
    return values


def _collect_core_values(controls: Dict[str, widgets.Widget]) -> Dict[str, float]:
    values: Dict[str, float] = {}
    for name, control in controls.items():
        values[name] = float(control.value)
    return values


def _build_particle_section(
    title: str, controls: Dict[str, widgets.Widget], preset_widget: widgets.Widget | None
) -> widgets.Accordion:
    """Wrap a particle's controls in a single-pane accordion titled *title*.

    When *preset_widget* is given it is placed above the parameter rows.
    """
    header = [] if preset_widget is None else [preset_widget]
    section = widgets.Accordion(children=[widgets.VBox(header + _particle_rows(controls))])
    section.set_title(0, title)
    return section


def _build_core_section(controls: Dict[str, widgets.Widget]) -> widgets.Accordion:
    """Wrap the core-parameter rows in a single-pane "Core configuration" accordion."""
    body = widgets.VBox(_core_rows(controls))
    section = widgets.Accordion(children=[body])
    section.set_title(0, "Core configuration")
    return section


def _strip_driver_summary(summary: str) -> str:
    lines: List[str] = []
    skipping = False
    for line in summary.splitlines():
        if line.startswith("- Driver"):
            skipping = True
            continue
        if skipping:
            if line.startswith("  driver"):
                continue
            skipping = False
        lines.append(line)
    return "\n".join(lines)


def _required_params_for(sim_type: SimulationType) -> set[str]:
    """Return the core parameter names *sim_type* consumes (empty set for unknown modes).

    The returned set is the shared entry in ``CORE_REQUIRED_PARAMS``; callers
    should treat it as read-only.
    """
    try:
        return CORE_REQUIRED_PARAMS[sim_type]
    except KeyError:
        return set()


def _apply_core_param_state(
    controls: Dict[str, widgets.Widget], required_params: set[str]
) -> None:
    for name, control in controls.items():
        control.disabled = name not in required_params
        control.layout.opacity = 1.0 if name in required_params else 0.45


In [10]:
# --- Run-level controls -------------------------------------------------------
steps_widget = widgets.IntSlider(
    value=1000, min=10, max=20000, step=10,
    description="Steps:", continuous_update=False,
)
seed_widget = widgets.IntText(value=12345, description="Seed:")
legacy_toggle = widgets.Checkbox(value=False, description="Include legacy comparison")
simulation_widget = widgets.Dropdown(
    options=list(SIMULATION_TYPE_OPTIONS.items()),
    value=SimulationType.BUNCH_TO_BUNCH,
    description="Core mode:",
)

# --- Artefact toggles (all start unchecked) ----------------------------------
overlay_display_widget, overlay_save_widget, metrics_save_widget, energy_save_widget = (
    widgets.Checkbox(value=False, description=label)
    for label in (
        "Show overlay plot",
        "Save overlay plot",
        "Save metrics JSON",
        "Save ΔE plots",
    )
)

# --- Output settings ---------------------------------------------------------
dpi_widget = widgets.Dropdown(
    options=[(f"{dpi} dpi", dpi) for dpi in AVAILABLE_DPI_CHOICES],
    value=DEFAULT_PLOT_DPI,
    description="Plot DPI:",
    layout=widgets.Layout(width="220px"),
    style=CORE_WIDGET_STYLE,
)
output_dir_widget = widgets.Text(
    value="test_outputs/testbed_runs",
    description="Output dir:",
    layout=widgets.Layout(width="360px"),
    style=CORE_WIDGET_STYLE,
)

# --- Action + result area ----------------------------------------------------
run_button = widgets.Button(description="Run integrator", icon="play", button_style="primary")
output_area = widgets.Output()



# Parameter widgets for both particles (rider first) and the core settings.
rider_controls, driver_controls = (
    _make_particle_widgets(particle_defaults)
    for particle_defaults in (DEFAULT_RIDER_PARAMS, DEFAULT_DRIVER_PARAMS)
)
core_controls = _make_core_widgets()

# One species-preset dropdown per particle, both starting on "custom".
rider_species_widget, driver_species_widget = (
    widgets.Dropdown(
        options=SPECIES_DROPDOWN_OPTIONS,
        value="custom",
        description=dropdown_label,
        layout=PARTICLE_WIDGET_WIDTH,
        style=PARTICLE_WIDGET_STYLE,
    )
    for dropdown_label in ("Rider species:", "Driver species:")
)

# Collapsible sections shown in the main control panel.
rider_section, driver_section = (
    _build_particle_section(section_title, section_controls, section_preset)
    for section_title, section_controls, section_preset in (
        ("Rider particle", rider_controls, rider_species_widget),
        ("Driver particle", driver_controls, driver_species_widget),
    )
)
core_section = _build_core_section(core_controls)



def _resolved_output_dir() -> Path:
    """Return the output directory typed by the user, with ``~`` expanded.

    A blank (or whitespace-only) entry falls back to ``test_outputs/testbed_runs``.
    """
    typed = output_dir_widget.value.strip()
    if not typed:
        typed = "test_outputs/testbed_runs"
    return Path(typed).expanduser()



def _save_dir_if_needed(should_create: bool) -> Path:
    """Return the resolved output directory, creating it (with parents) when *should_create*."""
    target = _resolved_output_dir()
    if not should_create:
        return target
    target.mkdir(parents=True, exist_ok=True)
    return target



def _update_legacy_controls(change) -> None:
    """Enable the legacy-only artefact checkboxes exactly when the legacy comparison is on.

    *change* is a traitlets change mapping (``change["new"]`` is read); any
    other value falls back to ``legacy_toggle.value``. Turning the comparison
    off also unchecks the boxes so no stale legacy artefact is requested.
    """
    if isinstance(change, dict):
        legacy_enabled = bool(change["new"])
    else:
        legacy_enabled = bool(legacy_toggle.value)
    legacy_only_boxes = (overlay_display_widget, overlay_save_widget, metrics_save_widget)
    for box in legacy_only_boxes:
        box.disabled = not legacy_enabled
        if not legacy_enabled:
            box.value = False



def _supports_driver(sim_type: SimulationType) -> bool:
    """Return True when *sim_type* simulates a driver bunch (bunch-to-bunch mode only)."""
    driver_mode = SimulationType.BUNCH_TO_BUNCH
    return sim_type == driver_mode



def _apply_driver_visibility(sim_value: SimulationType) -> None:
    """Show the driver section only for modes that have a driver; hide and collapse it otherwise."""
    if _supports_driver(sim_value):
        driver_section.layout.display = ""
        return
    driver_section.layout.display = "none"
    driver_section.selected_index = None  # collapse so it reopens closed when shown again



def _handle_species_change(
    change, dropdown: widgets.Dropdown, controls: Dict[str, widgets.Widget]
) -> None:
    """Apply the species preset selected in *dropdown* to *controls*.

    Args:
        change: Normally the traitlets change mapping delivered by ``observe``
            (``"new"``/``"old"`` are read). Any other value, e.g. ``None`` for a
            manual call, falls back to ``dropdown.value`` with no previous value.
            The ``"old"`` lookup is guarded too, so such calls do not raise
            ``AttributeError``.
        dropdown: The species dropdown that triggered the change.
        controls: Particle widgets that receive the preset values.
    """
    if isinstance(change, dict):
        new_value = change.get("new")
        old_value = change.get("old")
    else:
        new_value = dropdown.value
        old_value = None
    if not new_value or new_value == old_value:
        return
    _apply_species_preset(controls, str(new_value))


def _on_rider_species(change) -> None:
    """Observer for the rider species dropdown: push its preset into the rider controls."""
    _handle_species_change(change, rider_species_widget, rider_controls)


def _on_driver_species(change) -> None:
    """Observer for the driver species dropdown: push its preset into the driver controls."""
    _handle_species_change(change, driver_species_widget, driver_controls)



# Species dropdowns push their presets into the matching particle controls.
for _species_dropdown, _species_handler in (
    (rider_species_widget, _on_rider_species),
    (driver_species_widget, _on_driver_species),
):
    _species_dropdown.observe(_species_handler, names="value")

# The legacy toggle gates the legacy-only artefact checkboxes; sync once now so
# the initial widget state matches the toggle.
legacy_toggle.observe(_update_legacy_controls, names="value")
_update_legacy_controls({"new": legacy_toggle.value})



def _update_core_controls(change=None) -> None:
    """Sync core-parameter enablement and driver-section visibility with the selected mode.

    *change* is a traitlets change mapping (``change["new"]`` is the new mode);
    any other value, including the default ``None``, reads ``simulation_widget.value``.
    """
    if isinstance(change, dict):
        sim_value = change["new"]
    else:
        sim_value = simulation_widget.value
    _apply_core_param_state(core_controls, _required_params_for(sim_value))
    _apply_driver_visibility(sim_value)


# Re-sync whenever the mode changes, and once now for the initial mode.
simulation_widget.observe(_update_core_controls, names="value")
_update_core_controls()



def _scatter_delta_energy(axis, z_values, delta_values, color: str, role: str) -> None:
    """Draw one ΔE-vs-z scatter panel for *role* (``"Rider"`` or ``"Driver"``) on *axis*."""
    axis.scatter(z_values, delta_values, color=color, label=f"{role} ΔE", **SCATTER_STYLE)
    axis.set_title(f"{role} ΔE vs z")
    axis.set_xlabel("z (mm)")
    # NOTE(review): the label says GeV while the rest energy handed to
    # compute_delta_energy_series is keyed as MeV — confirm the returned unit.
    axis.set_ylabel("ΔE (GeV)")
    axis.grid(True, alpha=0.3)
    axis.legend()


def _render_delta_energy_figure(
    payload: Dict,
    supports_driver: bool,
    plot_dpi: int,
    save_path: Optional[Path],
    log_messages: List[str],
) -> widgets.Image:
    """Plot ΔE vs z for the rider (and the driver when supported) and return it as a PNG widget.

    Args:
        payload: Trajectory payload from ``run_benchmark(..., return_trajectories=True)``;
            its ``"core"``, ``"initial_states"`` and ``"rest_energy_mev"`` entries are read.
        supports_driver: Whether to draw a second panel for the driver.
        plot_dpi: Resolution for the figure, the saved file and the inline PNG.
        save_path: File to also save the figure to, or ``None`` to skip saving.
        log_messages: Run log; status messages are appended in place.

    Returns:
        An ``Image`` widget holding the rendered PNG.

    The pyplot figure is closed in ``finally`` so a failing ``savefig`` (e.g. an
    unwritable output directory) does not leave figures open across runs.
    """
    rider_delta, rider_z = compute_delta_energy_series(
        payload["core"]["rider"],
        payload["initial_states"]["rider"],
        payload["rest_energy_mev"]["rider"],
    )
    axis_count = 2 if supports_driver else 1
    fig_width = 13 if supports_driver else 7.5
    fig, axes = plt.subplots(
        1,
        axis_count,
        figsize=(fig_width, 6),
        constrained_layout=True,
        dpi=plot_dpi,
    )
    try:
        if axis_count == 1:
            axes = [axes]  # plt.subplots returns a bare Axes for a 1x1 grid
        fig.patch.set_facecolor("white")

        _scatter_delta_energy(axes[0], rider_z, rider_delta, COLOR_RIDER, "Rider")
        if supports_driver:
            driver_delta, driver_z = compute_delta_energy_series(
                payload["core"]["driver"],
                payload["initial_states"]["driver"],
                payload["rest_energy_mev"]["driver"],
            )
            _scatter_delta_energy(axes[1], driver_z, driver_delta, COLOR_DRIVER, "Driver")
        else:
            log_messages.append("Driver outputs suppressed for wall-mode simulations.")

        for axis in axes:
            axis.tick_params(axis="both", which="major", labelsize=TICK_FONTSIZE)

        if save_path is not None:
            fig.savefig(save_path, dpi=plot_dpi, bbox_inches="tight")
            log_messages.append(f"ΔE scatter saved to {save_path}")

        png_buffer = io.BytesIO()
        fig.savefig(png_buffer, format="png", dpi=plot_dpi, bbox_inches="tight")
    finally:
        plt.close(fig)

    return widgets.Image(
        value=png_buffer.getvalue(),
        format="png",
        layout=widgets.Layout(max_width="100%"),
    )


def _build_info_accordion(
    metrics, supports_driver: bool, config_payload: Dict, log_messages: List[str]
) -> widgets.Accordion:
    """Assemble the "Metrics summary" / "Run configuration" / "Run messages" accordion.

    The metrics and configuration panes are always present; the messages pane
    only when *log_messages* is non-empty. The first pane is opened.
    """
    sections: List[tuple[str, widgets.Widget]] = []

    if metrics:
        summary_text = summarise_metrics(metrics)
        if not supports_driver:
            summary_text = _strip_driver_summary(summary_text)
        metrics_html = f"<pre style='margin:0'>{html.escape(summary_text)}</pre>"
    else:
        metrics_html = "<p style='margin:0'>Legacy comparison skipped; no metrics computed.</p>"
    sections.append(("Metrics summary", widgets.HTML(value=metrics_html)))

    config_json = json.dumps(config_payload, indent=2, sort_keys=True)
    sections.append(
        (
            "Run configuration",
            widgets.HTML(value=f"<pre style='margin:0'>{html.escape(config_json)}</pre>"),
        )
    )

    if log_messages:
        messages_html = "<br>".join(html.escape(message) for message in log_messages)
        sections.append(
            ("Run messages", widgets.HTML(value=f"<p style='margin:0'>{messages_html}</p>"))
        )

    accordion = widgets.Accordion(children=[pane for _, pane in sections])
    with accordion.hold_trait_notifications():
        for index, (title, _) in enumerate(sections):
            accordion.set_title(index, title)
        accordion.selected_index = 0  # at least two panes always exist
    return accordion


def handle_run(_):
    """Run the integrator with the current widget settings and render results into ``output_area``.

    Registered as the ``run_button`` click callback (the button argument is
    ignored) and also called once with ``None`` to render an initial run.
    """
    with output_area:
        output_area.clear_output(wait=True)

        legacy_enabled = legacy_toggle.value
        sim_type_value = simulation_widget.value
        supports_driver = _supports_driver(sim_type_value)

        rider_params = _collect_particle_values(rider_controls)
        # Wall modes have no driver bunch, so no driver parameters are sent.
        driver_params = _collect_particle_values(driver_controls) if supports_driver else None
        core_params = _collect_core_values(core_controls)
        required_core = _required_params_for(sim_type_value)

        # Only create the output directory when some artefact will be written.
        needs_dir = any(
            (
                overlay_save_widget.value,
                metrics_save_widget.value,
                energy_save_widget.value,
            )
        )
        save_dir = _save_dir_if_needed(needs_dir)

        # Metrics and overlay artefacts only exist when the legacy run is included.
        save_json_path = (
            save_dir / "benchmark_metrics.json"
            if legacy_enabled and metrics_save_widget.value
            else None
        )
        save_overlay_path = (
            save_dir / "core_legacy_overlay.png"
            if legacy_enabled and overlay_save_widget.value
            else None
        )
        save_energy_path = (
            save_dir / "delta_energy_scatter.png" if energy_save_widget.value else None
        )
        overlay_requested = legacy_enabled and (
            overlay_display_widget.value or overlay_save_widget.value
        )
        plot_dpi_value = dpi_widget.value

        # time_step is always passed; the other core parameters are passed as
        # None unless the selected mode uses them.
        gated_core_params = {
            name: core_params[name] if name in required_core else None
            for name in ("wall_z", "aperture_radius", "mean", "cav_spacing", "z_cutoff")
        }

        log_messages: List[str] = []
        metrics, payload = run_benchmark(
            steps=steps_widget.value,
            seed=seed_widget.value,
            rider_params=rider_params,
            driver_params=driver_params,
            legacy_enabled=legacy_enabled,
            simulation_type=sim_type_value,
            time_step=core_params["time_step"],
            **gated_core_params,
            save_json=save_json_path,
            save_fig=save_overlay_path,
            show=overlay_display_widget.value and legacy_enabled,
            plot=overlay_requested,
            return_trajectories=True,
            plot_dpi=plot_dpi_value,
            log_messages=log_messages,
        )

        figure_widget = _render_delta_energy_figure(
            payload, supports_driver, plot_dpi_value, save_energy_path, log_messages
        )

        # "optional_core_params" holds the values the selected mode ignores.
        config_payload = {
            "simulation_type": sim_type_value.name,
            "required_core_params": {k: core_params[k] for k in required_core},
            "optional_core_params": {
                k: v for k, v in core_params.items() if k not in required_core
            },
            "plot_dpi": plot_dpi_value,
            "steps": steps_widget.value,
            "seed": seed_widget.value,
            "legacy_enabled": legacy_enabled,
        }
        info_accordion = _build_info_accordion(
            metrics, supports_driver, config_payload, log_messages
        )

        display(widgets.VBox([figure_widget, info_accordion]))



# Three toolbar rows of run-level controls, followed by the output directory,
# the parameter sections, the run button and the result area.
_toolbar_rows = [
    widgets.HBox(row_children, layout=ROW_LAYOUT)
    for row_children in (
        [steps_widget, seed_widget, simulation_widget],
        [legacy_toggle, overlay_display_widget, overlay_save_widget],
        [metrics_save_widget, energy_save_widget, dpi_widget],
    )
]
controls_layout = widgets.VBox(
    [
        *_toolbar_rows,
        output_dir_widget,
        core_section,
        rider_section,
        driver_section,
        run_button,
        output_area,
    ]
)

display(controls_layout)

# Wire the button, then render one run with the default settings immediately.
run_button.on_click(handle_run)
handle_run(None)




VBox(children=(HBox(children=(IntSlider(value=1000, continuous_update=False, description='Steps:', max=20000, …