# Core vs Legacy Benchmark

Interactively configure the two-particle benchmark implemented in `core_vs_legacy_benchmark.py`. Toggle the legacy
comparison, adjust particle parameters, and capture plots or metrics directly from this notebook.

In [1]:
from __future__ import annotations


import sys

from pathlib import Path

from typing import Dict, List



import ipywidgets as widgets

import matplotlib.pyplot as plt


from IPython.display import display



# Ensure repository paths are importable when running in VS Code/notebooks.
#
# Start from the working directory and look for the closest ancestor (the
# working directory included) named "LW_windows". If no such ancestor exists,
# keep the working directory itself as the root rather than walking all the
# way up to the filesystem root, which would put "/" (and "/legacy", ...) on
# sys.path.
PROJECT_ROOT = Path.cwd().resolve()
for _candidate in (PROJECT_ROOT, *PROJECT_ROOT.parents):
    if _candidate.name == "LW_windows":
        PROJECT_ROOT = _candidate
        break

LEGACY_ROOT = PROJECT_ROOT / "legacy"
VALIDATION_ROOT = PROJECT_ROOT / "examples" / "validation"

# Each directory is pushed to the front of sys.path unless already present, so
# the directory inserted last (VALIDATION_ROOT) is searched first.
for _import_dir in (PROJECT_ROOT, LEGACY_ROOT, VALIDATION_ROOT):
    if str(_import_dir) not in sys.path:
        sys.path.insert(0, str(_import_dir))



from core_vs_legacy_benchmark import (

    DEFAULT_DRIVER_PARAMS,

    DEFAULT_RIDER_PARAMS,

    PARTICLE_PARAM_FIELDS,

    SimulationType,

    compute_delta_energy_series,

    run_benchmark,

    summarise_metrics,

)

In [2]:
# Widget description text for each particle parameter name. Parameter names
# that are missing here fall back to a title-cased version of the name when
# the widgets are built.
PARAM_LABELS = {
    "starting_distance": "Start z (mm)",
    "transv_mom": "Transverse mom (amu·mm/ns)",
    "starting_Pz": "Initial Pz (amu·mm/ns)",
    "stripped_ions": "Stripped ions",
    "m_particle": "Mass (amu)",
    "transv_dist": "Transverse spread (mm)",
    "pcount": "Particle count",
    "charge_sign": "Charge sign",
}

# Shared layout object so every particle parameter control gets the same width.
PARTICLE_WIDGET_WIDTH = widgets.Layout(width="260px")

# Dropdown label -> SimulationType member offered by the "Core mode" selector.
SIMULATION_TYPE_OPTIONS = {
    "Conducting wall": SimulationType.CONDUCTING_WALL,
    "Switching wall": SimulationType.SWITCHING_WALL,
    "Bunch to bunch": SimulationType.BUNCH_TO_BUNCH,
}


def _make_particle_widgets(defaults: Dict[str, float | int]) -> Dict[str, widgets.Widget]:
    """Build one input widget per entry of ``PARTICLE_PARAM_FIELDS``.

    Args:
        defaults: Initial value for every parameter name; the value's type
            selects the widget type.

    Returns:
        Mapping of parameter name to its widget, in ``PARTICLE_PARAM_FIELDS``
        order.

    Raises:
        KeyError: If ``defaults`` lacks one of the parameter names.
    """

    def _widget_for(name: str) -> widgets.Widget:
        initial = defaults[name]
        label = PARAM_LABELS.get(name, name.replace("_", " ").title())

        # Anything that is not an int is edited as a free-form float.
        if not isinstance(initial, int):
            return widgets.FloatText(
                value=float(initial),
                description=label,
                layout=PARTICLE_WIDGET_WIDTH,
            )

        # Integer parameters other than the particle count use a plain text box.
        if name != "pcount":
            return widgets.IntText(
                value=int(initial),
                description=label,
                layout=PARTICLE_WIDGET_WIDTH,
            )

        # The particle count gets a bounded slider (1..128).
        return widgets.IntSlider(
            value=int(initial),
            min=1,
            max=128,
            step=1,
            description=label,
            continuous_update=False,
            layout=PARTICLE_WIDGET_WIDTH,
        )

    return {name: _widget_for(name) for name in PARTICLE_PARAM_FIELDS}


def _particle_rows(controls: Dict[str, widgets.Widget]) -> List[widgets.Widget]:
    """Arrange the particle widgets two per row.

    Widgets are taken in ``PARTICLE_PARAM_FIELDS`` order; an odd widget count
    leaves a final row holding a single widget.

    Args:
        controls: Mapping of parameter name to widget.

    Returns:
        One ``HBox`` per row.
    """
    ordered = [controls[name] for name in PARTICLE_PARAM_FIELDS]
    return [
        widgets.HBox(ordered[start:start + 2])
        for start in range(0, len(ordered), 2)
    ]


def _collect_particle_values(controls: Dict[str, widgets.Widget]) -> Dict[str, float | int]:
    """Read the current ``value`` of every control.

    Args:
        controls: Mapping of parameter name to widget.

    Returns:
        Mapping of parameter name to that widget's current value, in the same
        key order as ``controls``.
    """
    return {name: widget.value for name, widget in controls.items()}


def _build_particle_section(title: str, controls: Dict[str, widgets.Widget]) -> widgets.Accordion:
    """Wrap the particle widgets in a single-panel accordion.

    Args:
        title: Heading shown on the accordion panel.
        controls: Mapping of parameter name to widget.

    Returns:
        An ``Accordion`` whose only child is a ``VBox`` of the widget rows.
    """
    panel = widgets.VBox(_particle_rows(controls))
    section = widgets.Accordion(children=[panel])
    section.set_title(0, title)
    return section

In [None]:
# --- Run-level controls -----------------------------------------------------
# Value passed to run_benchmark as ``steps``.
steps_widget = widgets.IntSlider(
    value=1000,
    min=10,
    max=10000,
    step=10,
    description="Steps:",
    continuous_update=False,
)
# Value passed to run_benchmark as ``seed``.
seed_widget = widgets.IntText(value=12345, description="Seed:")
# Master switch for the legacy comparison; the overlay and metrics options
# below are disabled whenever this is unchecked.
legacy_toggle = widgets.Checkbox(value=True, description="Include legacy comparison")
# Selects the SimulationType passed to run_benchmark as ``simulation_type``.
simulation_widget = widgets.Dropdown(
    options=[(label, value) for label, value in SIMULATION_TYPE_OPTIONS.items()],
    value=SimulationType.BUNCH_TO_BUNCH,
    description="Core mode:",
)

# --- Output options ---------------------------------------------------------
overlay_display_widget = widgets.Checkbox(
    value=False, description="Show overlay plot"
)
overlay_save_widget = widgets.Checkbox(
    value=False, description="Save overlay plot"
)
metrics_save_widget = widgets.Checkbox(
    value=False, description="Save metrics JSON"
)
# Unlike the three options above, this one stays available without the
# legacy comparison.
energy_save_widget = widgets.Checkbox(
    value=False, description="Save ΔE plots"
)
# Directory that receives any saved plots/metrics.
output_dir_widget = widgets.Text(
    value="test_outputs/notebook_runs",
    description="Output dir:",
    layout=widgets.Layout(width="350px"),
)
run_button = widgets.Button(
    description="Run benchmark",
    icon="play",
    button_style="success",
)
# Captures everything handle_run prints or plots.
output_area = widgets.Output()

# --- Per-particle parameter controls ----------------------------------------
rider_controls = _make_particle_widgets(DEFAULT_RIDER_PARAMS)
driver_controls = _make_particle_widgets(DEFAULT_DRIVER_PARAMS)
rider_section = _build_particle_section("Rider particle", rider_controls)
driver_section = _build_particle_section("Driver particle", driver_controls)


def _resolved_output_dir() -> Path:
    """Return the output directory typed into ``output_dir_widget``.

    Surrounding whitespace is stripped; an empty field falls back to
    ``test_outputs/notebook_runs``. A leading ``~`` is expanded. The directory
    is not created here.
    """
    text = output_dir_widget.value.strip()
    if not text:
        text = "test_outputs/notebook_runs"
    return Path(text).expanduser()


def _save_dir_if_needed(should_create: bool) -> Path:
    """Return the configured output directory, creating it on request.

    Args:
        should_create: When true, the directory (and any missing parents) is
            created; an already existing directory is not an error.

    Returns:
        The output directory path, whether or not it was created.
    """
    target = _resolved_output_dir()
    if not should_create:
        return target
    target.mkdir(parents=True, exist_ok=True)
    return target


def _update_legacy_controls(change) -> None:
    """Enable or disable the options that depend on the legacy comparison.

    Args:
        change: Change notification carrying the new toggle state under
            ``"new"``. Any non-dict argument makes the function read the
            state from ``legacy_toggle`` instead.
    """
    if isinstance(change, dict):
        enabled = bool(change["new"])
    else:
        enabled = bool(legacy_toggle.value)

    locked = not enabled
    for dependent in (overlay_display_widget, overlay_save_widget, metrics_save_widget):
        dependent.disabled = locked
        if locked:
            # A disabled option must not stay ticked from an earlier run.
            dependent.value = False


# Keep the dependent checkboxes in sync with the toggle, then apply the
# toggle's initial state once.
legacy_toggle.observe(_update_legacy_controls, names="value")
_update_legacy_controls({"new": legacy_toggle.value})


def handle_run(_):
    """Run the benchmark with the current widget settings and show the results.

    Button-click callback; the argument is ignored. All output (plots and
    printed text) is routed into ``output_area``, which is cleared first.

    Steps:
        1. Read the widget state and, if anything is to be saved, create the
           output directory.
        2. Call ``run_benchmark`` with the collected settings.
        3. Scatter-plot the rider and driver ΔE series side by side,
           optionally saving the figure.
        4. Print the metrics summary, or a note that the legacy comparison
           was skipped.
    """
    with output_area:
        output_area.clear_output(wait=True)

        legacy_enabled = legacy_toggle.value
        rider_params = _collect_particle_values(rider_controls)
        driver_params = _collect_particle_values(driver_controls)

        wants_overlay_file = overlay_save_widget.value
        wants_metrics_file = metrics_save_widget.value
        wants_energy_file = energy_save_widget.value
        save_dir = _save_dir_if_needed(
            wants_overlay_file or wants_metrics_file or wants_energy_file
        )

        # The metrics and overlay files only exist when the legacy comparison
        # runs; the ΔE scatter is independent of it.
        save_json_path = None
        save_overlay_path = None
        if legacy_enabled:
            if wants_metrics_file:
                save_json_path = save_dir / "benchmark_metrics.json"
            if wants_overlay_file:
                save_overlay_path = save_dir / "core_legacy_overlay.png"
        save_energy_path = None
        if wants_energy_file:
            save_energy_path = save_dir / "delta_energy_scatter.png"

        show_overlay = overlay_display_widget.value and legacy_enabled
        overlay_requested = legacy_enabled and (
            overlay_display_widget.value or wants_overlay_file
        )

        metrics, payload = run_benchmark(
            steps=steps_widget.value,
            seed=seed_widget.value,
            rider_params=rider_params,
            driver_params=driver_params,
            legacy_enabled=legacy_enabled,
            simulation_type=simulation_widget.value,
            save_json=save_json_path,
            save_fig=save_overlay_path,
            show=show_overlay,
            plot=overlay_requested,
            return_trajectories=True,
        )

        # One (title, colour, z, ΔE) entry per particle, rider first.
        panels = []
        for role, title, colour in (
            ("rider", "Rider ΔE vs z", "#1f77b4"),
            ("driver", "Driver ΔE vs z", "#ff7f0e"),
        ):
            delta, z_values = compute_delta_energy_series(
                payload["core"][role],
                payload["initial_states"][role],
                payload["rest_energy_mev"][role],
            )
            panels.append((title, colour, z_values, delta))

        fig, axes = plt.subplots(1, 2, figsize=(13, 6), constrained_layout=True)

        # NOTE(review): the axis label says GeV while the payload key is
        # "rest_energy_mev" — confirm the units returned by
        # compute_delta_energy_series.
        for axis, (title, colour, z_values, delta) in zip(axes, panels):
            axis.scatter(z_values, delta, s=30, alpha=0.85, color=colour)
            axis.set_title(title)
            axis.set_xlabel("z (mm)")
            axis.set_ylabel("ΔE (GeV)")
            axis.grid(True, alpha=0.3)

        if save_energy_path is not None:
            fig.savefig(save_energy_path, dpi=300, bbox_inches="tight")
            print(f"ΔE scatter saved to {save_energy_path}")

        plt.show()

        if metrics:
            print(summarise_metrics(metrics))
        elif not legacy_enabled:
            print("Legacy comparison skipped; no metrics computed.")


# Stack the whole UI vertically: run settings, legacy/overlay options, save
# options, output directory, the two particle sections, the run button, and
# finally the area that receives the run's output.
controls_layout = widgets.VBox(
    [
        widgets.HBox([steps_widget, seed_widget, simulation_widget]),
        widgets.HBox([legacy_toggle, overlay_display_widget, overlay_save_widget]),
        widgets.HBox([metrics_save_widget, energy_save_widget]),
        output_dir_widget,
        rider_section,
        driver_section,
        run_button,
        output_area,
    ]
)

# Wire the button to the benchmark runner, then render the UI in the cell.
run_button.on_click(handle_run)
display(controls_layout)

VBox(children=(HBox(children=(IntSlider(value=1000, continuous_update=False, description='Steps:', max=10000, …