# Core vs Legacy Two-Particle Benchmark

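This notebook mirrors `core_vs_legacy_two_particle_benchmark.py` so you can run the legacy-vs-core
comparison interactively inside VS Code (or any Jupyter front end). The cells below set up the
import paths, define the benchmark helpers, and build a small widget panel for choosing the step
count, seed, and which outputs to show or save. The last cell also runs the benchmark once with
the default settings, so metrics and plots appear immediately.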

In [1]:
from __future__ import annotations

import copy

import json

import sys

from pathlib import Path

from typing import Dict, Iterable, List, Tuple



import ipywidgets as widgets

import matplotlib.pyplot as plt

import numpy as np

from IPython.display import display



# Ensure repository paths are importable when running in VS Code/notebooks.
#
# Walk up from the current working directory (including the directory itself)
# looking for the "LW_windows" checkout. If no such ancestor exists, fall back
# to the working directory. The previous loop kept walking until it hit the
# filesystem root (whose ``name`` is empty) and then silently placed "/" and
# "/legacy" on sys.path, which made the later project imports fail with a
# confusing ModuleNotFoundError.
_cwd = Path.cwd().resolve()
PROJECT_ROOT = next(
    (candidate for candidate in (_cwd, *_cwd.parents) if candidate.name == "LW_windows"),
    _cwd,
)

if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

# The legacy directory is also put on sys.path directly, presumably so that
# legacy modules importing each other by bare module name keep working
# (NOTE(review): confirm against the legacy package layout).
LEGACY_ROOT = PROJECT_ROOT / "legacy"

if str(LEGACY_ROOT) not in sys.path:
    sys.path.insert(0, str(LEGACY_ROOT))



from core.trajectory_integrator import SimulationType, retarded_integrator

from legacy.bunch_inits import init_bunch  # type: ignore

from legacy.covariant_integrator_library import (  # type: ignore

    retarded_integrator as legacy_retarded_integrator,

)

In [2]:
# State keys compared between the legacy and core runs by ``compute_metrics``
# (one metric triple per key, for both the rider and the driver).
FIELDS_TO_TRACK: Tuple[str, ...] = (
    "x",
    "y",
    "z",
    "Px",
    "Py",
    "Pz",
    "Pt",
    "gamma",
)

def _normalize_state(state: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    """Return a new dict mapping each key of *state* to an array with ``ndim >= 1``.

    ``_extract_series`` reads ``state[field][0]`` from every snapshot, so each
    value must be indexable. The rules are:

    * ndarrays with at least one dimension are passed through unchanged
      (same object, dtype preserved, no copy);
    * 0-d ndarrays are promoted to shape ``(1,)``. Previously they were passed
      through as-is, and indexing them with ``[0]`` later raised ``IndexError``;
    * Python/NumPy scalars become 1-element float arrays;
    * anything else goes through ``np.asarray(..., dtype=float)``.
    """
    normalized: Dict[str, np.ndarray] = {}
    for key, value in state.items():
        if isinstance(value, np.ndarray):
            # Keep dtype and identity; only a 0-d array is reshaped to (1,).
            normalized[key] = np.atleast_1d(value)
        else:
            # Scalars -> shape (1,); sequences keep their shape; all as float.
            normalized[key] = np.atleast_1d(np.asarray(value, dtype=float))
    return normalized





def _convert_legacy_state(state: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    """Adapt a legacy bunch state into the all-array form passed to the core integrator.

    The particle count is taken from ``len(state["x"])``. Values are then
    handled as follows:

    * existing ndarrays are reused as-is (not copied);
    * non-array ``q``, ``m`` and ``char_time`` values are broadcast to one
      float per particle;
    * every other non-array value goes through ``np.asarray(..., dtype=float)``.
      A bare scalar therefore becomes a 0-d array.
    """
    n_particles = len(state["x"])
    per_particle_keys = {"q", "m", "char_time"}

    def _coerce(name, raw):
        if isinstance(raw, np.ndarray):
            return raw
        if name in per_particle_keys:
            return np.full(n_particles, raw, dtype=float)
        return np.asarray(raw, dtype=float)

    return {name: _coerce(name, raw) for name, raw in state.items()}





def _extract_series(states: Iterable[Dict[str, np.ndarray]], field: str) -> np.ndarray:
    """Collect element 0 (the first particle) of ``field`` from each snapshot.

    ``states`` is consumed once, in order. The result is a float array with
    one entry per snapshot; an empty iterable gives an empty array.
    """
    first_particle_values = [snapshot[field][0] for snapshot in states]
    return np.array(first_particle_values, dtype=float)





def prepare_two_particle_demo(
    seed: int,
) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray], float, float]:
    """Build the rider and driver bunches used by the benchmark.

    NumPy's global RNG is seeded with *seed*, then ``init_bunch`` is called for
    the rider and afterwards for the driver. Repeated calls with the same seed
    should therefore give identical initial states, assuming ``init_bunch``
    only draws from the global ``np.random`` stream (not visible here; TODO
    confirm).

    Returns:
        ``(rider_state, driver_state, rider_rest_mev, driver_rest_mev)``. The
        last two are the second values returned by ``init_bunch``, cast to
        ``float`` (presumably rest energies in MeV, per the names).
    """
    rider_mass = 1.007319468  # ``m_particle`` of the rider bunch
    driver_mass = 207.2  # ``m_particle`` of the driver bunch

    np.random.seed(seed)

    rider_state, rider_rest = init_bunch(
        starting_distance=1e-6,
        transv_mom=0.0,
        starting_Pz=1.01e6,
        stripped_ions=1.0,
        m_particle=rider_mass,
        transv_dist=2e-4,
        pcount=5,
        charge_sign=-1.0,
    )
    # The driver's starting_Pz is the rider's 1.01e6 with the sign flipped and
    # scaled by rider_mass / driver_mass.
    driver_state, driver_rest = init_bunch(
        starting_distance=1000.0,
        transv_mom=0.0,
        starting_Pz=-1.01e6 / driver_mass * rider_mass,
        stripped_ions=54.0,
        m_particle=driver_mass,
        transv_dist=2e-4 - 8e-2,
        pcount=5,
        charge_sign=1.0,
    )

    return rider_state, driver_state, float(rider_rest), float(driver_rest)





def run_legacy_integrator(
    rider_state: Dict[str, np.ndarray],
    driver_state: Dict[str, np.ndarray],
    steps: int,
) -> Tuple[List[Dict[str, np.ndarray]], List[Dict[str, np.ndarray]]]:
    """Run the legacy retarded integrator and normalize every recorded snapshot.

    The legacy entry point is called positionally. The local names below
    mirror, position by position, the keyword arguments that
    ``run_core_integrator`` passes to the core integrator. That mapping is
    inferred from the matching values; NOTE(review): confirm it against
    ``legacy.covariant_integrator_library.retarded_integrator``.

    The input states are handed to the legacy code directly (no copy here);
    ``run_benchmark`` passes deep copies.

    Returns:
        ``(rider_snapshots, driver_snapshots)``, each passed through
        ``_normalize_state``.
    """
    h_step = 2.2e-7
    wall_z = 1e5
    aperture_radius = 1e5
    sim_type_code = 2  # presumably the legacy code for bunch-to-bunch; TODO confirm
    mean = 1e5
    cav_spacing = 1e5
    z_cutoff = 0.0

    rider_history, driver_history = legacy_retarded_integrator(
        steps,
        h_step,
        wall_z,
        aperture_radius,
        sim_type_code,
        rider_state,
        driver_state,
        mean,
        cav_spacing,
        z_cutoff,
    )

    rider_snapshots = [_normalize_state(snapshot) for snapshot in rider_history]
    driver_snapshots = [_normalize_state(snapshot) for snapshot in driver_history]
    return rider_snapshots, driver_snapshots





def run_core_integrator(
    rider_state: Dict[str, np.ndarray],
    driver_state: Dict[str, np.ndarray],
    steps: int,
) -> Tuple[List[Dict[str, np.ndarray]], List[Dict[str, np.ndarray]]]:
    """Run the core retarded integrator in bunch-to-bunch mode.

    Both input states are deep-copied before conversion.
    ``_convert_legacy_state`` reuses existing ndarrays by reference, so without
    the copy the core run could share arrays with the caller's dicts.

    Returns:
        ``(rider_snapshots, driver_snapshots)``, each passed through
        ``_normalize_state``.
    """
    core_rider_init = _convert_legacy_state(copy.deepcopy(rider_state))
    core_driver_init = _convert_legacy_state(copy.deepcopy(driver_state))

    rider_history, driver_history = retarded_integrator(
        steps=steps,
        h_step=2.2e-7,
        wall_z=1e5,
        aperture_radius=1e5,
        sim_type=SimulationType.BUNCH_TO_BUNCH,
        init_rider=core_rider_init,
        init_driver=core_driver_init,
        mean=1e5,
        cav_spacing=1e5,
        z_cutoff=0.0,
    )

    rider_snapshots = [_normalize_state(snapshot) for snapshot in rider_history]
    driver_snapshots = [_normalize_state(snapshot) for snapshot in driver_history]
    return rider_snapshots, driver_snapshots





def compute_metrics(
    legacy: Tuple[List[Dict[str, np.ndarray]], List[Dict[str, np.ndarray]]],
    core: Tuple[List[Dict[str, np.ndarray]], List[Dict[str, np.ndarray]]],
) -> Dict[str, Dict[str, float]]:
    """Compare core against legacy trajectories, field by field.

    For the rider (index 0) and driver (index 1) of each ``(rider, driver)``
    pair, and for every key in ``FIELDS_TO_TRACK``, the per-step series from
    ``_extract_series`` is compared as ``diff = core - legacy``. The result is
    summarised as:

    * ``<field>_max_abs``: max ``|diff|`` over all steps;
    * ``<field>_max_rel_pct``: max ``|diff| / max(|legacy|, 1e-12)``, in
      percent (the floor avoids division by zero);
    * ``<field>_final_abs``: the *signed* ``diff`` at the last step.

    Returns:
        ``{"rider": {...}, "driver": {...}}`` with plain ``float`` values.

    Raises:
        ValueError: if the legacy and core series for a species/field differ
            in length, or are empty.
    """
    metrics: Dict[str, Dict[str, float]] = {}
    for label, legacy_states, core_states in (
        ("rider", legacy[0], core[0]),
        ("driver", legacy[1], core[1]),
    ):
        summary: Dict[str, float] = {}
        for field in FIELDS_TO_TRACK:
            legacy_series = _extract_series(legacy_states, field)
            core_series = _extract_series(core_states, field)

            # Without this check, mismatched lengths either fail with an
            # opaque broadcast error or, if one side has a single step,
            # silently broadcast and produce meaningless metrics.
            if legacy_series.shape != core_series.shape:
                raise ValueError(
                    f"{label} trajectories differ in length for field {field!r}: "
                    f"legacy has {legacy_series.shape[0]} steps, "
                    f"core has {core_series.shape[0]}"
                )
            if legacy_series.size == 0:
                raise ValueError(
                    f"{label} trajectories are empty; cannot compare field {field!r}"
                )

            diff = core_series - legacy_series
            abs_diff = np.abs(diff)
            rel = abs_diff / np.maximum(np.abs(legacy_series), 1e-12)

            summary[f"{field}_max_abs"] = float(np.max(abs_diff))
            summary[f"{field}_max_rel_pct"] = float(np.max(rel) * 100.0)
            summary[f"{field}_final_abs"] = float(diff[-1])
        metrics[label] = summary
    return metrics





def summarise_metrics(metrics: Dict[str, Dict[str, float]]) -> str:
    """Format a short human-readable report from ``compute_metrics`` output.

    Only ``z``, ``Pt`` and ``gamma`` are reported, first for the rider and
    then for the driver. Each line shows the max absolute difference and the
    max relative difference (percent), both in ``.3e`` notation.
    """
    report = ["Benchmark summary (relative to legacy trajectories):"]
    for label in ("rider", "driver"):
        report.append(f"- {label.capitalize()}")
        for field in ("z", "Pt", "gamma"):
            entry = metrics[label]
            max_abs = entry[f"{field}_max_abs"]
            max_rel = entry[f"{field}_max_rel_pct"]
            report.append(
                f"  {label:<6s} {field:<3s} : max |Δ| = {max_abs:.3e}, "
                f"max rel = {max_rel:.3e}%"
            )
    return "\n".join(report)





def export_metrics(metrics: Dict[str, Dict[str, float]], destination: Path) -> None:
    """Serialise *metrics* to *destination* as indented, key-sorted JSON.

    Missing parent directories are created first. An existing file is
    overwritten, and the output is UTF-8 encoded.
    """
    target_dir = destination.parent
    target_dir.mkdir(exist_ok=True, parents=True)
    with open(destination, mode="w", encoding="utf-8") as handle:
        json.dump(metrics, handle, sort_keys=True, indent=2)





def plot_results(
    legacy: Tuple[List[Dict[str, np.ndarray]], List[Dict[str, np.ndarray]]],
    core: Tuple[List[Dict[str, np.ndarray]], List[Dict[str, np.ndarray]]],
    *,
    save_path: Path | None = None,
    show: bool = False,
) -> None:
    """Overlay core and legacy z trajectories and plot their γ difference.

    Top panel: first-particle ``z`` per step for rider and driver (legacy
    dashed, core solid). Bottom panel: ``Δγ = core − legacy`` per step.

    Args:
        legacy: ``(rider_states, driver_states)`` from the legacy run.
        core: ``(rider_states, driver_states)`` from the core run.
        save_path: if given, the figure is written there at 300 dpi. Missing
            parent directories are created first, as ``export_metrics`` does
            for the JSON file. A plain ``str`` path is accepted too.
        show: call ``plt.show()`` before the figure is closed.

    The step axis is built from the legacy rider trajectory length, so all four
    trajectories are assumed to have that same length.

    The figure is always closed before returning, even if saving or showing
    raises, so repeated calls do not leave figures open in pyplot.
    """
    steps_axis = np.arange(len(legacy[0]))
    legacy_rider, legacy_driver = legacy
    core_rider, core_driver = core

    legacy_rider_z = _extract_series(legacy_rider, "z")
    core_rider_z = _extract_series(core_rider, "z")
    legacy_driver_z = _extract_series(legacy_driver, "z")
    core_driver_z = _extract_series(core_driver, "z")

    rider_gamma_diff = _extract_series(core_rider, "gamma") - _extract_series(legacy_rider, "gamma")
    driver_gamma_diff = _extract_series(core_driver, "gamma") - _extract_series(legacy_driver, "gamma")

    fig, axes = plt.subplots(2, 1, figsize=(10, 8), constrained_layout=True)
    try:
        # Plot order fixes the colour cycle; keep legacy/core pairs adjacent.
        axes[0].plot(steps_axis, legacy_rider_z, "--", label="Legacy rider")
        axes[0].plot(steps_axis, core_rider_z, label="Core rider")
        axes[0].plot(steps_axis, legacy_driver_z, "--", label="Legacy driver")
        axes[0].plot(steps_axis, core_driver_z, label="Core driver")
        axes[0].set_title("Trajectory overlap (z position)")
        axes[0].set_xlabel("Step")
        axes[0].set_ylabel("z (mm)")
        axes[0].grid(True, alpha=0.3)
        axes[0].legend()

        axes[1].plot(steps_axis, rider_gamma_diff, label="Δγ rider")
        axes[1].plot(steps_axis, driver_gamma_diff, label="Δγ driver")
        axes[1].set_title("Gamma difference (core − legacy)")
        axes[1].set_xlabel("Step")
        axes[1].set_ylabel("Δγ")
        axes[1].grid(True, alpha=0.3)
        axes[1].legend()

        if save_path is not None:
            save_path = Path(save_path)
            # savefig does not create directories itself.
            save_path.parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(save_path, dpi=300, bbox_inches="tight")
        if show:
            plt.show()
    finally:
        plt.close(fig)





def compute_delta_energy_series(
    states: List[Dict[str, np.ndarray]],
    initial_state: Dict[str, np.ndarray],
    rest_energy_mev: float,
) -> Tuple[np.ndarray, np.ndarray]:
    """Return the first particle's energy change and ``z`` at every step.

    ``ΔE = (γ_step − γ_0) · E_rest``, where ``γ_0`` is
    ``initial_state["gamma"][0]`` and ``E_rest`` is *rest_energy_mev*
    converted to GeV (× 1e-3).

    Returns:
        ``(delta_energy_gev, z_series)``, one entry per snapshot in *states*.
    """
    gammas = _extract_series(states, "gamma")
    gamma0 = float(initial_state["gamma"][0])
    delta_energy_gev = (gammas - gamma0) * (rest_energy_mev * 1e-3)
    return delta_energy_gev, _extract_series(states, "z")





def run_benchmark(
    *,
    steps: int = 40,
    seed: int = 12345,
    save_json: Path | None = None,
    save_fig: Path | None = None,
    show: bool = False,
    plot: bool = True,
    return_trajectories: bool = False,
):
    """Run legacy and core integrators on the same initial bunches and compare them.

    Args:
        steps: number of integration steps passed to both integrators.
        seed: seed passed to ``prepare_two_particle_demo``.
        save_json: if given, the metrics are written there via ``export_metrics``.
        save_fig: overlay-plot destination. It is only honoured when *plot* is
            True; otherwise no figure is produced at all.
        show: forwarded to ``plot_results``.
        plot: whether to call ``plot_results``.
        return_trajectories: also return the raw trajectories, initial
            states and rest energies.

    Returns:
        The ``compute_metrics`` dict, or ``(metrics, payload)`` when
        *return_trajectories* is True. ``payload`` has the keys ``"legacy"``,
        ``"core"``, ``"initial_states"`` and ``"rest_energy_mev"``, each mapping
        ``"rider"``/``"driver"`` to the corresponding value.
    """
    rider_state, driver_state, rider_rest_mev, driver_rest_mev = prepare_two_particle_demo(seed)

    # Snapshot the initial states before any integrator can touch them.
    initial_states = {
        "rider": _normalize_state(copy.deepcopy(rider_state)),
        "driver": _normalize_state(copy.deepcopy(driver_state)),
    }

    # Each integrator gets its own independent copies of the inputs.
    legacy_results = run_legacy_integrator(
        copy.deepcopy(rider_state), copy.deepcopy(driver_state), steps
    )
    core_results = run_core_integrator(
        copy.deepcopy(rider_state), copy.deepcopy(driver_state), steps
    )

    metrics = compute_metrics(legacy_results, core_results)

    if save_json is not None:
        export_metrics(metrics, save_json)
        print(f"Metrics written to {save_json}")

    if plot:
        plot_results(legacy_results, core_results, save_path=save_fig, show=show)
        if save_fig is not None:
            print(f"Plot written to {save_fig}")

    if not return_trajectories:
        return metrics

    legacy_rider, legacy_driver = legacy_results
    core_rider, core_driver = core_results
    payload = {
        "legacy": {"rider": legacy_rider, "driver": legacy_driver},
        "core": {"rider": core_rider, "driver": core_driver},
        "initial_states": initial_states,
        "rest_energy_mev": {"rider": rider_rest_mev, "driver": driver_rest_mev},
    }
    return metrics, payload


In [3]:
# --- Control-panel widgets ---------------------------------------------------
# Integration step count handed to ``run_benchmark`` (10..10000, steps of 10).
steps_widget = widgets.IntSlider(
    value=1000, min=10, max=10000, step=10,
    description="Steps:", continuous_update=False,
)
# Seed forwarded to ``run_benchmark``.
seed_widget = widgets.IntText(value=12345, description="Seed:")

# Display / save toggles; all start unchecked.
overlay_display_widget = widgets.Checkbox(value=False, description="Show overlay plot")
overlay_save_widget = widgets.Checkbox(value=False, description="Save overlay plot")
metrics_save_widget = widgets.Checkbox(value=False, description="Save metrics JSON")
energy_save_widget = widgets.Checkbox(value=False, description="Save ΔE plots")

# Directory for every saved artefact; a relative path is interpreted relative
# to the current working directory.
output_dir_widget = widgets.Text(
    value="test_outputs/notebook_runs",
    description="Output dir:",
    layout=widgets.Layout(width="350px"),
)
run_button = widgets.Button(description="Run benchmark", icon="play", button_style="success")
# Receives everything ``handle_run`` prints or plots.
output_area = widgets.Output()


def _resolved_output_dir() -> Path:
    """Return the output directory currently typed into ``output_dir_widget``.

    Surrounding whitespace is stripped, and an empty entry falls back to
    ``test_outputs/notebook_runs``. ``~`` is expanded; otherwise the path is
    returned as typed (it is not made absolute).
    """
    text = output_dir_widget.value.strip()
    if not text:
        text = "test_outputs/notebook_runs"
    return Path(text).expanduser()


def _save_dir_if_needed(should_create: bool) -> Path:
    """Return the configured output directory, creating it when *should_create* is True.

    The path is returned either way, so callers can build file names from it
    even when nothing will be written.
    """
    target = _resolved_output_dir()
    if not should_create:
        return target
    target.mkdir(parents=True, exist_ok=True)
    return target


def handle_run(_):
    """Run the benchmark with the current widget settings and display the results.

    Registered as the ``run_button`` click handler, and also called once with
    ``None`` when the panel is first shown; the argument is ignored.

    All printed text and figures go to ``output_area``, which is cleared at the
    start of each run. The metrics JSON, overlay plot and ΔE scatter are
    written only when their boxes are ticked. They go into the directory from
    ``output_dir_widget``, which is created only in that case.

    The ΔE-vs-z scatter is always displayed and uses the *core* trajectories
    only. Its figure is closed after display, even if saving fails, so
    repeated clicks do not accumulate open pyplot figures. This matches how
    ``plot_results`` handles the overlay figure.
    """
    with output_area:
        output_area.clear_output()

        # Only create the output directory when something will be written.
        needs_dir = any(
            (overlay_save_widget.value, metrics_save_widget.value, energy_save_widget.value)
        )
        save_dir = _save_dir_if_needed(needs_dir)

        save_json_path = (
            save_dir / "benchmark_metrics.json" if metrics_save_widget.value else None
        )
        save_overlay_path = (
            save_dir / "core_legacy_overlay.png" if overlay_save_widget.value else None
        )
        save_energy_path = (
            save_dir / "delta_energy_scatter.png" if energy_save_widget.value else None
        )

        metrics, payload = run_benchmark(
            steps=steps_widget.value,
            seed=seed_widget.value,
            save_json=save_json_path,
            save_fig=save_overlay_path,
            show=overlay_display_widget.value,
            # The overlay has to be rendered whenever it is shown OR saved.
            plot=overlay_display_widget.value or overlay_save_widget.value,
            return_trajectories=True,
        )

        # (title prefix, colour, ΔE series, z series) per species.
        panels = []
        for species, colour in (("rider", "#1f77b4"), ("driver", "#ff7f0e")):
            delta, z_series = compute_delta_energy_series(
                payload["core"][species],
                payload["initial_states"][species],
                payload["rest_energy_mev"][species],
            )
            panels.append((species.capitalize(), colour, delta, z_series))

        fig, axes = plt.subplots(1, 2, figsize=(13, 6), constrained_layout=True)
        try:
            for ax, (title, colour, delta, z_series) in zip(axes, panels):
                ax.scatter(z_series, delta, s=30, alpha=0.85, color=colour)
                ax.set_title(f"{title} ΔE vs z")
                ax.set_xlabel("z (mm)")
                ax.set_ylabel("ΔE (GeV)")
                ax.grid(True, alpha=0.3)

            if save_energy_path is not None:
                fig.savefig(save_energy_path, dpi=300, bbox_inches="tight")
                print(f"ΔE scatter saved to {save_energy_path}")

            plt.show()
        finally:
            plt.close(fig)

        print(summarise_metrics(metrics))


# Lay out the panel top to bottom: run parameters, overlay toggles, save
# toggles, output directory, run button, then the captured output area.
control_panel = widgets.VBox(
    [
        widgets.HBox([steps_widget, seed_widget]),
        widgets.HBox([overlay_display_widget, overlay_save_widget]),
        widgets.HBox([metrics_save_widget, energy_save_widget]),
        output_dir_widget,
        run_button,
        output_area,
    ]
)
display(control_panel)

# Re-run on every click, and do one run right away with the default settings.
run_button.on_click(handle_run)
handle_run(None)

VBox(children=(HBox(children=(IntSlider(value=1000, continuous_update=False, description='Steps:', max=10000, …