# Integrator Testbed

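Experiment with the core integrator across all supported simulation types (conducting wall, switching wall,
bunch to bunch). Toggle the legacy comparison when needed, tune core and particle parameters, and inspect the
saved output artefacts (ΔE scatter plots, core/legacy overlay plot, metrics JSON) for debugging or validation.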

In [4]:
from __future__ import annotations

import json
import sys
from pathlib import Path
from typing import Dict, List

import ipywidgets as widgets
import matplotlib.pyplot as plt
from IPython.display import display

# Ensure repository paths are importable when running in VS Code/notebooks.
# Look for the repository root (a directory named "LW_windows") among the
# working directory and its ancestors, nearest first. If none matches, fall
# back to the working directory instead of walking all the way up to the
# filesystem root and putting "/" (plus "/legacy", ...) on sys.path.
_START_DIR = Path.cwd().resolve()
PROJECT_ROOT = next(
    (
        candidate
        for candidate in (_START_DIR, *_START_DIR.parents)
        if candidate.name == "LW_windows"
    ),
    _START_DIR,
)
LEGACY_ROOT = PROJECT_ROOT / "legacy"
VALIDATION_ROOT = PROJECT_ROOT / "examples" / "validation"

# Each entry is inserted at position 0, so the resulting search order is
# validation, legacy, project root.
for _import_root in (PROJECT_ROOT, LEGACY_ROOT, VALIDATION_ROOT):
    if str(_import_root) not in sys.path:
        sys.path.insert(0, str(_import_root))

from core_vs_legacy_benchmark import (
    DEFAULT_DRIVER_PARAMS,
    DEFAULT_RIDER_PARAMS,
    PARTICLE_PARAM_FIELDS,
    SimulationType,
    compute_delta_energy_series,
    run_benchmark,
    summarise_metrics,
)

COLOR_RIDER = "#0072B2"
COLOR_DRIVER = "#D55E00"
COLOR_LEGACY_RIDER = "#56B4E9"
COLOR_LEGACY_DRIVER = "#E69F00"
SCATTER_STYLE = {"s": 140, "alpha": 0.78, "linewidth": 0, "edgecolors": "none"}
TITLE_FONTSIZE = 18
LABEL_FONTSIZE = 16
TICK_FONTSIZE = 13
LEGEND_FONTSIZE = 12
AVAILABLE_DPI_CHOICES = (150, 300, 450, 600)
DEFAULT_PLOT_DPI = 300

plt.style.use("seaborn-v0_8-whitegrid")
plt.rcParams.update(
    {
        "axes.titlesize": TITLE_FONTSIZE,
        "axes.labelsize": LABEL_FONTSIZE,
        "xtick.labelsize": TICK_FONTSIZE,
        "ytick.labelsize": TICK_FONTSIZE,
        "legend.fontsize": LEGEND_FONTSIZE,
    }
)


In [5]:
# Widget captions for the particle parameters; fields without an entry fall
# back to a title-cased version of the field name.
PARAM_LABELS = dict(
    starting_distance="Start z (mm)",
    transv_mom="Transverse mom (amu·mm/ns)",
    starting_Pz="Initial Pz (amu·mm/ns)",
    stripped_ions="Stripped ions",
    m_particle="Mass (amu)",
    transv_dist="Transverse spread (mm)",
    pcount="Particle count",
    charge_sign="Charge sign",
)

# Widget captions for the core (integrator-level) parameters.
CORE_PARAM_LABELS = dict(
    time_step="Time step (ns)",
    wall_z="Wall z (mm)",
    aperture_radius="Aperture radius (mm)",
    mean="Mean separation (mm)",
    cav_spacing="Cavity spacing (mm)",
    z_cutoff="z cutoff (mm)",
)

# One Layout instance per group, shared by every widget that uses it.
PARTICLE_WIDGET_WIDTH = widgets.Layout(width="260px")
CORE_WIDGET_WIDTH = widgets.Layout(width="230px")

# Dropdown caption -> simulation mode; insertion order is the dropdown order.
SIMULATION_TYPE_OPTIONS = {
    "Conducting wall": SimulationType.CONDUCTING_WALL,
    "Switching wall": SimulationType.SWITCHING_WALL,
    "Bunch to bunch": SimulationType.BUNCH_TO_BUNCH,
}

# Initial widget values; insertion order is also the on-screen widget order.
CORE_PARAM_DEFAULTS = dict(
    time_step=2.2e-7,
    wall_z=1e5,
    aperture_radius=1e5,
    mean=1e5,
    cav_spacing=1e5,
    z_cutoff=0.0,
)

# Core parameters each mode consumes; the others are disabled in the UI and
# passed to the benchmark as None.
# NOTE(review): "mean" appears in no mode's set, so its widget is always
# disabled and its value is never forwarded — confirm this is intended.
_CONDUCTING_WALL_PARAMS = {"time_step", "wall_z", "aperture_radius"}
CORE_REQUIRED_PARAMS = {
    SimulationType.CONDUCTING_WALL: _CONDUCTING_WALL_PARAMS,
    SimulationType.SWITCHING_WALL: _CONDUCTING_WALL_PARAMS | {"cav_spacing", "z_cutoff"},
    SimulationType.BUNCH_TO_BUNCH: {"time_step", "aperture_radius"},
}


def _make_particle_widgets(defaults: Dict[str, float | int]) -> Dict[str, widgets.Widget]:
    """Build one input widget per particle parameter, keyed by field name.

    Integer defaults get integer widgets: an ``IntSlider`` for ``pcount`` and
    an ``IntText`` otherwise. Every other default gets a ``FloatText``.

    Args:
        defaults: Default value for every name in ``PARTICLE_PARAM_FIELDS``.
            A missing name raises ``KeyError``.

    Returns:
        Mapping from parameter name to its widget, in ``PARTICLE_PARAM_FIELDS`` order.
    """
    controls: Dict[str, widgets.Widget] = {}
    for name in PARTICLE_PARAM_FIELDS:
        default_value = defaults[name]
        common = {
            "description": PARAM_LABELS.get(name, name.replace("_", " ").title()),
            "layout": PARTICLE_WIDGET_WIDTH,
        }
        if not isinstance(default_value, int):
            controls[name] = widgets.FloatText(value=float(default_value), **common)
        elif name == "pcount":
            # A bounded slider silently clamps its value to ``max``. Widen the
            # range so a default above 256 particles is shown unchanged.
            controls[name] = widgets.IntSlider(
                value=int(default_value),
                min=1,
                max=max(256, int(default_value)),
                step=1,
                continuous_update=False,
                **common,
            )
        else:
            controls[name] = widgets.IntText(value=int(default_value), **common)
    return controls


def _make_core_widgets() -> Dict[str, widgets.Widget]:
    """Create a ``FloatText`` per core parameter, in ``CORE_PARAM_DEFAULTS`` order.

    Each widget starts at its default (coerced to float) and is captioned from
    ``CORE_PARAM_LABELS``, or from the title-cased name when no label exists.
    """
    return {
        param: widgets.FloatText(
            value=float(initial),
            description=CORE_PARAM_LABELS.get(param, param.replace("_", " ").title()),
            layout=CORE_WIDGET_WIDTH,
        )
        for param, initial in CORE_PARAM_DEFAULTS.items()
    }


def _particle_rows(controls: Dict[str, widgets.Widget]) -> List[widgets.Widget]:
    """Lay out the particle widgets two per ``HBox``, in ``PARTICLE_PARAM_FIELDS`` order.

    The final row holds a single widget when the field count is odd.
    """
    ordered = [controls[field] for field in PARTICLE_PARAM_FIELDS]
    per_row = 2
    return [
        widgets.HBox(ordered[start:start + per_row])
        for start in range(0, len(ordered), per_row)
    ]


def _core_rows(controls: Dict[str, widgets.Widget]) -> List[widgets.Widget]:
    """Lay out the core widgets three per ``HBox``, in ``CORE_PARAM_DEFAULTS`` order.

    The final row may be shorter when the parameter count is not a multiple of three.
    """
    ordered = [controls[param] for param in CORE_PARAM_DEFAULTS]
    per_row = 3
    return [
        widgets.HBox(ordered[start:start + per_row])
        for start in range(0, len(ordered), per_row)
    ]


def _collect_particle_values(controls: Dict[str, widgets.Widget]) -> Dict[str, float | int]:
    values: Dict[str, float | int] = {}
    for name, control in controls.items():
        values[name] = control.value
    return values


def _collect_core_values(controls: Dict[str, widgets.Widget]) -> Dict[str, float]:
    values: Dict[str, float] = {}
    for name, control in controls.items():
        values[name] = float(control.value)
    return values


def _build_particle_section(title: str, controls: Dict[str, widgets.Widget]) -> widgets.Accordion:
    """Wrap the particle widgets, laid out by ``_particle_rows``, in a one-pane accordion titled ``title``."""
    section = widgets.Accordion(children=[widgets.VBox(_particle_rows(controls))])
    section.set_title(0, title)
    return section


def _build_core_section(controls: Dict[str, widgets.Widget]) -> widgets.Accordion:
    """Wrap the core widgets, laid out by ``_core_rows``, in a one-pane "Core configuration" accordion."""
    section = widgets.Accordion(children=[widgets.VBox(_core_rows(controls))])
    section.set_title(0, "Core configuration")
    return section


def _required_params_for(sim_type: SimulationType) -> set[str]:
    """Return the core parameter names the given simulation mode consumes.

    Unknown modes yield an empty set. For known modes the set stored in
    ``CORE_REQUIRED_PARAMS`` itself is returned, not a copy, so callers must
    not mutate it.
    """
    try:
        return CORE_REQUIRED_PARAMS[sim_type]
    except KeyError:
        return set()


def _apply_core_param_state(
    controls: Dict[str, widgets.Widget], required_params: set[str]
) -> None:
    for name, control in controls.items():
        control.disabled = name not in required_params
        control.layout.opacity = 1.0 if name in required_params else 0.45


In [6]:
def _toggle(description: str) -> widgets.Checkbox:
    """Return an unchecked checkbox with the given caption."""
    return widgets.Checkbox(value=False, description=description)


# --- Run options ---------------------------------------------------------
steps_widget = widgets.IntSlider(
    value=1000,
    min=10,
    max=20000,
    step=10,
    description="Steps:",
    continuous_update=False,
)
seed_widget = widgets.IntText(value=12345, description="Seed:")
legacy_toggle = _toggle("Include legacy comparison")
simulation_widget = widgets.Dropdown(
    options=list(SIMULATION_TYPE_OPTIONS.items()),
    value=SimulationType.BUNCH_TO_BUNCH,
    description="Core mode:",
)

# --- Output options ------------------------------------------------------
overlay_display_widget = _toggle("Show overlay plot")
overlay_save_widget = _toggle("Save overlay plot")
metrics_save_widget = _toggle("Save metrics JSON")
energy_save_widget = _toggle("Save ΔE plots")
dpi_widget = widgets.Dropdown(
    options=[(f"{dpi} dpi", dpi) for dpi in AVAILABLE_DPI_CHOICES],
    value=DEFAULT_PLOT_DPI,
    description="Plot DPI:",
    layout=widgets.Layout(width="200px"),
)
output_dir_widget = widgets.Text(
    value="test_outputs/testbed_runs",
    description="Output dir:",
    layout=widgets.Layout(width="350px"),
)

# --- Actions and results -------------------------------------------------
run_button = widgets.Button(description="Run integrator", icon="play", button_style="primary")
output_area = widgets.Output()

# --- Parameter panels ----------------------------------------------------
rider_controls = _make_particle_widgets(DEFAULT_RIDER_PARAMS)
driver_controls = _make_particle_widgets(DEFAULT_DRIVER_PARAMS)
core_controls = _make_core_widgets()
rider_section = _build_particle_section("Rider particle", rider_controls)
driver_section = _build_particle_section("Driver particle", driver_controls)
core_section = _build_core_section(core_controls)


def _resolved_output_dir() -> Path:
    """Return the directory typed into ``output_dir_widget``, with ``~`` expanded.

    A blank or whitespace-only entry falls back to ``test_outputs/testbed_runs``.
    Relative paths stay relative (no ``resolve()`` is applied).
    """
    typed = output_dir_widget.value.strip()
    return Path(typed or "test_outputs/testbed_runs").expanduser()


def _save_dir_if_needed(should_create: bool) -> Path:
    """Return the output directory, creating it (with parents) only when ``should_create`` is true."""
    target = _resolved_output_dir()
    if not should_create:
        return target
    target.mkdir(parents=True, exist_ok=True)
    return target


def _update_legacy_controls(change=None) -> None:
    """Enable the legacy-only checkboxes when legacy comparison is on, else disable and clear them.

    Args:
        change: The change mapping passed by ``observe``; its ``"new"`` entry is
            used. ``None`` (now the default, matching ``_update_core_controls``)
            or any non-dict reads ``legacy_toggle.value`` directly.
    """
    if isinstance(change, dict):
        legacy_enabled = bool(change["new"])
    else:
        legacy_enabled = bool(legacy_toggle.value)
    for checkbox in (overlay_display_widget, overlay_save_widget, metrics_save_widget):
        checkbox.disabled = not legacy_enabled
        if not legacy_enabled:
            # Clear stale ticks so a disabled legacy-only option can't stay selected.
            checkbox.value = False


def _update_core_controls(change=None) -> None:
    """Enable only the core widgets required by the selected simulation mode.

    Args:
        change: The change mapping passed by ``observe`` (its ``"new"`` entry is
            the selected mode). ``None`` or any non-dict reads
            ``simulation_widget.value`` instead.
    """
    if isinstance(change, dict):
        selected_mode = change["new"]
    else:
        selected_mode = simulation_widget.value
    _apply_core_param_state(core_controls, _required_params_for(selected_mode))


# Keep the legacy-only checkboxes and the core widgets in sync with the
# widgets that control them, then apply each initial state once. Handlers are
# registered directly; the lambda wrapper around _update_legacy_controls was
# redundant and inconsistent with the simulation_widget registration.
legacy_toggle.observe(_update_legacy_controls, names="value")
_update_legacy_controls({"new": legacy_toggle.value})

simulation_widget.observe(_update_core_controls, names="value")
_update_core_controls()


def _delta_energy_for(payload, role: str):
    """Return ``(delta, z)`` from ``compute_delta_energy_series`` for ``role`` ("rider" or "driver")."""
    return compute_delta_energy_series(
        payload["core"][role],
        payload["initial_states"][role],
        payload["rest_energy_mev"][role],
    )


def _scatter_delta_energy(axis, z_values, delta_values, *, color: str, label: str, title: str) -> None:
    """Draw one ΔE-vs-z scatter panel on ``axis`` with the notebook's shared styling."""
    axis.scatter(z_values, delta_values, color=color, label=label, **SCATTER_STYLE)
    axis.set_title(title)
    axis.set_xlabel("z (mm)")
    # NOTE(review): rest energies arrive under "rest_energy_mev"; confirm that
    # compute_delta_energy_series really returns GeV before trusting this label.
    axis.set_ylabel("ΔE (GeV)")
    axis.grid(True, alpha=0.3)
    axis.legend()
    axis.tick_params(axis="both", which="major", labelsize=TICK_FONTSIZE)


def handle_run(_):
    """Run the integrator with the current widget settings and render the results.

    Registered with ``run_button.on_click`` (the button argument is ignored)
    and also called once with ``None`` when the cell runs. All output goes to
    ``output_area``, which is cleared first. The function:

    * calls ``run_benchmark``,
    * plots rider/driver ΔE against z, optionally saving the figure,
    * prints the metrics summary when any metrics were returned,
    * prints the core parameters used, as JSON.
    """
    with output_area:
        output_area.clear_output()

        legacy_enabled = legacy_toggle.value
        simulation_type = simulation_widget.value
        rider_params = _collect_particle_values(rider_controls)
        driver_params = _collect_particle_values(driver_controls)
        core_params = _collect_core_values(core_controls)
        required_core = _required_params_for(simulation_type)

        def core_param(name: str) -> float | None:
            # Forward a core parameter only when the selected mode uses it.
            return core_params[name] if name in required_core else None

        # Create the output directory only when some artefact will be saved.
        needs_dir = any(
            (overlay_save_widget.value, metrics_save_widget.value, energy_save_widget.value)
        )
        save_dir = _save_dir_if_needed(needs_dir)

        # The metrics JSON and overlay plot only exist when the legacy run happens.
        save_json_path = (
            save_dir / "benchmark_metrics.json"
            if legacy_enabled and metrics_save_widget.value
            else None
        )
        save_overlay_path = (
            save_dir / "core_legacy_overlay.png"
            if legacy_enabled and overlay_save_widget.value
            else None
        )
        save_energy_path = (
            save_dir / "delta_energy_scatter.png" if energy_save_widget.value else None
        )
        overlay_requested = legacy_enabled and (
            overlay_display_widget.value or overlay_save_widget.value
        )
        plot_dpi_value = dpi_widget.value

        metrics, payload = run_benchmark(
            steps=steps_widget.value,
            seed=seed_widget.value,
            rider_params=rider_params,
            driver_params=driver_params,
            legacy_enabled=legacy_enabled,
            simulation_type=simulation_type,
            # time_step is forwarded unconditionally; every mode requires it.
            time_step=core_params["time_step"],
            wall_z=core_param("wall_z"),
            aperture_radius=core_param("aperture_radius"),
            # NOTE(review): "mean" is required by no mode in CORE_REQUIRED_PARAMS,
            # so this is always None — confirm that is intended.
            mean=core_param("mean"),
            cav_spacing=core_param("cav_spacing"),
            z_cutoff=core_param("z_cutoff"),
            save_json=save_json_path,
            save_fig=save_overlay_path,
            show=overlay_display_widget.value and legacy_enabled,
            plot=overlay_requested,
            return_trajectories=True,
            plot_dpi=plot_dpi_value,
        )

        rider_delta, rider_z = _delta_energy_for(payload, "rider")
        driver_delta, driver_z = _delta_energy_for(payload, "driver")

        fig, (rider_axis, driver_axis) = plt.subplots(
            1,
            2,
            figsize=(13, 6),
            constrained_layout=True,
            dpi=plot_dpi_value,
        )
        fig.patch.set_facecolor("white")

        _scatter_delta_energy(
            rider_axis, rider_z, rider_delta,
            color=COLOR_RIDER, label="Rider ΔE", title="Rider ΔE vs z",
        )
        _scatter_delta_energy(
            driver_axis, driver_z, driver_delta,
            color=COLOR_DRIVER, label="Driver ΔE", title="Driver ΔE vs z",
        )

        if save_energy_path is not None:
            fig.savefig(save_energy_path, dpi=plot_dpi_value, bbox_inches="tight")
            print(f"ΔE scatter saved to {save_energy_path}")

        plt.show()

        if metrics:
            print(summarise_metrics(metrics))
        elif not legacy_enabled:
            print("Legacy comparison skipped; no metrics computed.")
        else:
            # Previously this case printed nothing; make the empty result visible.
            print("Legacy comparison ran but returned no metrics.")

        print(
            json.dumps(
                {
                    "required_core_params": {k: core_params[k] for k in required_core},
                    "optional_core_params": {
                        k: core_params[k] for k in core_params if k not in required_core
                    },
                    "plot_dpi": plot_dpi_value,
                },
                indent=2,
                sort_keys=True,
            )
        )


# Assemble the control panel: option rows first, then the output directory,
# the parameter accordions, the run button and the output area that
# handle_run writes into.
_option_rows = [
    widgets.HBox([steps_widget, seed_widget, simulation_widget]),
    widgets.HBox([legacy_toggle, overlay_display_widget, overlay_save_widget]),
    widgets.HBox([metrics_save_widget, energy_save_widget, dpi_widget]),
]
controls_layout = widgets.VBox(
    _option_rows
    + [output_dir_widget, core_section, rider_section, driver_section, run_button, output_area]
)

display(controls_layout)

# Wire the button, then do one run immediately so results show on first execution.
run_button.on_click(handle_run)
handle_run(None)


VBox(children=(HBox(children=(IntSlider(value=1000, continuous_update=False, description='Steps:', max=20000, …