In [None]:
import pickle
import random

import matplotlib
import matplotlib.pyplot as plt
import mpl_toolkits.axisartist as axisartist
import numpy as np
import pandas as pd
import scienceplots
from icecream import ic
from tqdm import tqdm

from dataset import EuXFELCurrentDataset, EuXFELLPSDataset
from legacy import SupervisedCurrentProfileInference, SupervisedLPSInference
from nils.reconstruction_module_after_diss import master_recon
from nils.simulate_spectrometer_signal import get_crisp_signal
from utils import current2formfactor

In [2]:
# Apply the SciencePlots "science" + "ieee" styles (provided by the `scienceplots`
# package imported at the top) to every figure produced in this notebook.
plt.style.use(["science", "ieee"])

In [None]:
# Make the LaTeX packages used in axis labels available: siunitx for \unit{...} and
# bm for bold math. Each package is only appended if it is not already present, so
# re-running this cell does not keep growing the preamble (plain `+=` was not
# idempotent under re-execution).
for _latex_package in (r"\usepackage{siunitx}", r"\usepackage{bm}"):
    if _latex_package not in matplotlib.rcParams["text.latex.preamble"]:
        matplotlib.rcParams["text.latex.preamble"] += " " + _latex_package
matplotlib.rcParams["text.latex.preamble"]

In [None]:
# Restore the trained current-profile and LPS models from their checkpoints and put
# them in inference mode. `.eval()` returns the module itself, so each model can be
# switched to eval mode in the same expression that loads it (this matters for layers
# such as dropout/batch-norm, if the models contain any).
current_model = SupervisedCurrentProfileInference.load_from_checkpoint(
    "virtual-diagnostics-euxfel-current-legacy/og6sdbm0/checkpoints/epoch=76-step=29876.ckpt"
).eval()
lps_model = SupervisedLPSInference.load_from_checkpoint(
    "virtual-diagnostics-euxfel-lps-legacy/mptc9vmu/checkpoints/epoch=95-step=35136.ckpt"
).eval()

print("Loaded models!")

In [5]:
# Load the scalers for both models (per the file names, fitted on the training split).
# NOTE: pickle.load can execute arbitrary code; only load these files from a trusted
# source.
with open("data/zihan/train_scalers_current.pkl", "rb") as f:
    current_scalers = pickle.load(f)
with open("data/zihan/train_scalers_lps.pkl", "rb") as f:
    lps_scalers = pickle.load(f)

# Normalised test splits, scaled with the scalers loaded above so that model inputs
# and targets live in the same space the models were fitted in.
current_dataset_test = EuXFELCurrentDataset(
    stage="test",
    normalize=True,
    rf_scaler=current_scalers["rf"],
    formfactor_scaler=current_scalers["formfactor"],
    current_scaler=current_scalers["current"],
    bunch_length_scaler=current_scalers["bunch_length"],
)
lps_dataset_test = EuXFELLPSDataset(
    stage="test",
    normalize=True,
    rf_scaler=lps_scalers["rf"],
    formfactor_scaler=lps_scalers["formfactor"],
    lps_image_scaler=lps_scalers["lps_image"],
    lps_range_scaler=lps_scalers["lps_range"],
)

In [None]:
# Example-prediction figure for the dissertation: one column per test sample, showing
# the predicted LPS (row 0), the ground-truth LPS (row 1) and the true vs. predicted
# current profile (row 2).
random.seed("IPAC")  # fixed seed so the same test samples are drawn on every run
# random.seed(24)
num_examples = 4
correction_scale = 1.023  # Corrects for Matplotlib making figure slightly too small
cmap = "viridis"
# LPS pixels below this value are set to NaN, which imshow draws with the colormap's
# "bad" colour (transparent by default).
cmap_cutoff = 0
# NOTE: random.choices draws with replacement, so a sample could appear twice. Kept
# as-is because the column order below was hand-tuned for this particular draw.
samples = random.choices(range(len(lps_dataset_test)), k=num_examples)

# Swap and filter plot columns around (hand-picked order; assumes num_examples >= 4)
samples[0], samples[3] = samples[3], samples[0]
samples[1], samples[3] = samples[3], samples[1]
samples = samples[:3]
samples[1], samples[2] = samples[2], samples[1]

ic(
    plt.rcParams["xtick.labelsize"],
    plt.rcParams["font.size"],
    plt.rcParams["axes.linewidth"],
)

# 426.79 / 72.27 presumably converts the LaTeX \textwidth (in TeX pt) to inches.
fig, axs = plt.subplots(
    3,
    len(samples),
    figsize=(
        426.79 / 72.27 * correction_scale,
        426.79 / 72.27 * 0.60 * correction_scale,
    ),
)

for i, sample in enumerate(samples):
    # Both test sets are indexed with the same `sample`, which assumes they are aligned
    # sample-for-sample — TODO confirm.
    (_, _), (current_profile, bunch_length) = current_dataset_test[sample]
    (rf_settings, formfactor), (lps_image, lps_range) = lps_dataset_test[sample]

    # Add a leading batch dimension for the models.
    x_rf = rf_settings.reshape((1, 5))
    x_formfactor = formfactor.reshape((1, 240))

    y_hat_current_profile, y_hat_bunch_length = current_model(x_rf, x_formfactor)
    y_hat_lps_image, y_hat_lps_range = lps_model(x_rf, x_formfactor)

    # Undo the target scaling to get predictions in physical units.
    predicted_current_profile = current_scalers["current"].inverse_transform(
        y_hat_current_profile.detach().numpy()
    )
    predicted_bunch_length = current_scalers["bunch_length"].inverse_transform(
        y_hat_bunch_length.detach().numpy()
    )
    predicted_lps_image = (
        lps_scalers["lps_image"]
        .inverse_transform(y_hat_lps_image.detach().numpy().reshape(1, 300 * 300))
        .reshape(1, 300, 300)
    )
    predicted_lps_range = lps_scalers["lps_range"].inverse_transform(
        y_hat_lps_range.detach().numpy()
    )

    # Longitudinal coordinate grids (300 points, centred on 0) for the current profiles.
    real_ss = np.linspace(
        -current_dataset_test.bunch_lengths[sample][0] / 2,
        current_dataset_test.bunch_lengths[sample][0] / 2,
        num=300,
    )
    predicted_ss = np.linspace(
        -predicted_bunch_length[0][0] / 2, predicted_bunch_length[0][0] / 2, num=300
    )

    # nils_frequencies, nils_formfactors, nils_formfactor_noise, nils_detlim = (
    #     get_crisp_signal(
    #         real_ss,
    #         current_dataset_test.current_profiles[sample],
    #         n_shots=10,
    #         which_set="both",
    #     )
    # )
    # recon_time, recon_current, _, _, _, _ = master_recon(
    #     nils_frequencies,
    #     nils_formfactors,
    #     nils_formfactor_noise,
    #     nils_detlim,
    #     1e-10,
    #     method="KKstart",
    #     channels_to_remove=[],
    #     show_plots=False,
    # )

    # Row 0: predicted LPS. Transpose + flipud puts image axis 0 along x (s, scaled by
    # 1e6 for the micrometre label) and axis 1 along y (energy deviation, scaled by
    # 1e2 for the percent label).
    predicted_lps_image[predicted_lps_image < cmap_cutoff] = np.nan
    axs[0, i].imshow(
        np.flipud(predicted_lps_image[0].transpose()),
        vmin=0,
        aspect="auto",
        cmap=cmap,
        extent=(
            -predicted_lps_range[0, 0] / 2 * 1e6,
            predicted_lps_range[0, 0] / 2 * 1e6,
            -predicted_lps_range[0, 1] / 2 * 1e2,
            predicted_lps_range[0, 1] / 2 * 1e2,
        ),
    )

    # Row 1: ground-truth LPS (copied so the dataset array is not mutated by masking).
    real_lps_image = lps_dataset_test.lps_images[sample].copy()
    real_lps_image[real_lps_image < cmap_cutoff] = np.nan
    axs[1, i].imshow(
        np.flipud(real_lps_image.transpose()),
        vmin=0,
        aspect="auto",
        cmap=cmap,
        extent=(
            -lps_dataset_test.lps_ranges[sample, 0] / 2 * 1e6,
            lps_dataset_test.lps_ranges[sample, 0] / 2 * 1e6,
            -lps_dataset_test.lps_ranges[sample, 1] / 2 * 1e2,
            lps_dataset_test.lps_ranges[sample, 1] / 2 * 1e2,
        ),
    )

    # Row 2: current profiles (scaled by 1e-3 for the kA label).
    axs[2, i].plot(
        real_ss * 1e6,
        current_dataset_test.current_profiles[sample] * 1e-3,
        label="Ground truth",
    )
    axs[2, i].plot(
        predicted_ss * 1e6, predicted_current_profile[0] * 1e-3, label="Reconstruction"
    )
    # plt.plot(recon_time, recon_current)
    axs[2, i].set_xlabel(r"$s$ (\unit{\micro\meter})")

last_col = len(samples) - 1

axs[0, 0].set_ylabel(r"$\delta_E$ (\unit{\percent})")
axs[1, 0].set_ylabel(r"$\delta_E$ (\unit{\percent})")
axs[2, 0].set_ylabel(r"$I$ (\unit{\kilo\ampere})")

# Row captions in the bottom-right corner of the last column. The transform uses the
# same axes the text is added to, instead of relying on the loop variable `i`
# leaking out of the loop above.
axs[0, last_col].text(
    0.95,
    0.1,
    "Reconstruction",
    color="white",
    horizontalalignment="right",
    transform=axs[0, last_col].transAxes,
)
axs[1, last_col].text(
    0.95,
    0.1,
    "Ground truth",
    color="white",
    horizontalalignment="right",
    transform=axs[1, last_col].transAxes,
)
axs[2, last_col].legend()

# Share the current axis across every column (not just a hard-coded 1 and 2) so the
# profiles are directly comparable.
for ax in axs[2, 1:]:
    ax.sharey(axs[2, 0])

# Only the bottom row keeps its s tick labels.
for ax in axs[:2].flat:
    ax.set_xticklabels([])

plt.tight_layout()
# plt.subplots_adjust(wspace=0.4, hspace=0.3)
fig.savefig("figures/dissertation_prediction_examples.pdf")
plt.show()

In [7]:
# from pathlib import Path

# import ocelot as oc
# from ocelot.gui.accelerator import show_e_beam
# import pandas as pd

# plt.style.use("default")

# df = pd.read_pickle("data/zihan/data_20220905_test.pkl")

# filename = Path(df.iloc[sample]["file"])
# particles = oc.load_particle_array("data/zihan/particles" / filename)
# show_e_beam(particles)

In [8]:
# plt.style.use(["science", "no-latex"])

In [None]:
# For neural network flowchart in dissertation: minimal, axis-free sketches of the
# formfactor input and the ground-truth LPS image of the last example sample.

plt.figure(figsize=(90 / 72 * 2, 59 / 72 * 2))
plt.plot(lps_dataset_test.formfactors[samples[-1]], color="#3700CC")
# The curve is only a pictogram: hide the right/top frame lines and all ticks.
plt.gca().spines[["right", "top"]].set_visible(False)
plt.gca().set(xticks=[], yticks=[])
# plt.xlabel(r"Frequency (\unit{\tera\hertz})")
# plt.ylabel(r"$\bm{x}_{\unit{\tera\hertz}} \mathrm{(|F|)}$")
plt.tight_layout()
plt.savefig("figures/dissertation_flowchart_formfactor.svg", dpi=300)
plt.show()

# Same orientation as the LPS panels elsewhere (transpose, then flip vertically).
plt.imshow(
    np.flipud(lps_dataset_test.lps_images[samples[-1]].T),
    vmin=0,
    interpolation="nearest",
)
plt.gca().set_axis_off()
plt.tight_layout()
plt.savefig("figures/dissertation_flowchart_lps_image.svg")
plt.show()

In [None]:
# For EuXFEL VD overview in dissertation
# Draws a tiny, tick-free sketch of the last example sample's formfactor and saves it
# at a fixed size. The ic() calls report the figure size at each stage in inches,
# inches * 72 (points) and inches * 2.54 (centimetres) so the size can be checked.

ic(plt.rcParams["axes.linewidth"], plt.rcParams["figure.dpi"])

# Empirical size correction (magic number; presumably compensates for the placed
# figure coming out smaller than requested — confirm). 59.8 / 72 turns a size given
# in points (72 per inch) into inches.
fix_scale = 1.28
plt.figure(figsize=(59.8 / 72 * fix_scale, 59.8 * 0.73 / 72 * fix_scale))
ic(
    "After figure is created",
    plt.gcf().get_size_inches(),
    plt.gcf().get_size_inches() * 72,
    plt.gcf().get_size_inches() * 2.54,
)
plt.plot(lps_dataset_test.formfactors[samples[-1]], color="#3700CC", linewidth=0.5)
# plt.gca().spines["right"].set_visible(False)
# plt.gca().spines["top"].set_visible(False)
plt.xticks([])
# plt.xlabel(r"Frequency (\unit{\tera\hertz})")
plt.yticks([])
# plt.ylabel(r"$\bm{x}_{\unit{\tera\hertz}} \mathrm{(|F|)}$")

plt.tight_layout()
ic(
    "Just before saving",
    plt.gcf().get_size_inches(),
    plt.gcf().get_size_inches() * 72,
    plt.gcf().get_size_inches() * 2.54,
)
# bbox_inches=None keeps the exact figure size in the PDF.
# NOTE(review): pad_inches only takes effect with bbox_inches="tight", so it is
# ignored here.
plt.savefig(
    "figures/dissertation_euxfel_vd_overview_formfactor.pdf",
    bbox_inches=None,
    pad_inches=0.25 / 72,
)
plt.show()

## Quantitative analysis


In [None]:
# Run the LPS model over the whole test set. For every sample, record the loss on the
# normalised targets the dataset yields (normalize=True), the predicted bunch length
# (first entry of the de-normalised LPS range) and the de-normalised 300x300 LPS image.
losses, predicted_bunch_lengths, predicted_lps_images = [], [], []
for (rf_settings, formfactor), (true_lps_image, true_lps_range) in tqdm(
    lps_dataset_test
):
    y_hat_lps_image, y_hat_lps_range = lps_model(
        rf_settings.reshape((1, 5)), formfactor.reshape((1, 240))
    )

    # Loss: image term + range term, both computed with the model's own criteria.
    lps_image_loss = lps_model.lps_image_criterion(
        y_hat_lps_image, true_lps_image.reshape((1, 300, 300))
    ).detach()
    lps_range_loss = lps_model.lps_range_criterion(
        y_hat_lps_range, true_lps_range.reshape((1, 2))
    ).detach()
    losses.append((lps_image_loss + lps_range_loss).item())

    # Undo the target scaling to obtain predictions in physical units.
    flat_predicted_image = y_hat_lps_image.detach().numpy().reshape(1, 300 * 300)
    predicted_lps_image = (
        lps_scalers["lps_image"].inverse_transform(flat_predicted_image).reshape(300, 300)
    )
    predicted_lps_range = (
        lps_scalers["lps_range"]
        .inverse_transform(y_hat_lps_range.detach().numpy())
        .reshape(2)
    )
    predicted_bunch_lengths.append(predicted_lps_range[0])
    predicted_lps_images.append(predicted_lps_image)

losses = np.array(losses)
predicted_bunch_lengths = np.array(predicted_bunch_lengths)
predicted_lps_images = np.array(predicted_lps_images)

In [12]:
# Per-sample absolute errors of the LPS model.
# "current" is the sum over all LPS pixels; "peak current" is the maximum of the LPS
# summed over axis 2 (i.e. the maximum of its projection onto axis 1).
bunch_length_errors = np.abs(
    predicted_bunch_lengths - lps_dataset_test.lps_ranges[:, 0]
)

true_currents, predicted_currents = (
    np.sum(images, axis=(1, 2))
    for images in (lps_dataset_test.lps_images, predicted_lps_images)
)
current_errors = np.abs(predicted_currents - true_currents)

true_peak_currents, predicted_peak_currents = (
    np.max(np.sum(images, axis=2), axis=1)
    for images in (lps_dataset_test.lps_images, predicted_lps_images)
)
peak_current_errors = np.abs(predicted_peak_currents - true_peak_currents)

In [None]:
# Median, mean and standard deviation of the true and predicted per-sample LPS pixel
# sums ("currents") over the test set. ic() echoes each argument's source text, so the
# expressions are kept exactly as written; assigning to `_` stops the cell from also
# displaying the returned tuple.
_ = ic(
    np.median(true_currents),
    true_currents.mean(),
    true_currents.std(),
    np.median(predicted_currents),
    predicted_currents.mean(),
    predicted_currents.std(),
)

In [None]:
# Quick visual sanity check of one arbitrarily chosen ground-truth LPS image,
# followed by its total pixel sum as the cell's displayed value.
i = 324
plt.imshow(lps_dataset_test.lps_images[i])
plt.show()

np.sum(lps_dataset_test.lps_images[i])

In [None]:
# Shape of the stacked ground-truth LPS image array of the test set.
lps_dataset_test.lps_images.shape

In [None]:
# Shape of the per-sample LPS pixel sums computed in the LPS metrics cell.
true_currents.shape

In [None]:
# Raw per-sample LPS pixel sums (displays the whole array).
true_currents

In [None]:
# Scatter each per-sample LPS-model metric (rows) against each target / input quantity
# (columns) to see where on the test set the model does well or badly.
# x-axis labels are only set on the bottom row; the original also set a stray
# "Bunch length" x label on row 2, which rendered in the middle of the figure.
scatter_alpha = 0.025

rf_setting_labels = [
    r"Chirp (\unit{\hertz\per\second})",
    r"Curv (\unit{\per\second})",
    r"Skew (\unit{\per\second})",
    r"Chirp L1 (\unit{\hertz})",
    r"Chirp L2 (\unit{\hertz})",
]
# Columns: (x values, x-axis label, marker colour, log-scaled x-axis?).
# C1 marks LPS target quantities, C2 marks RF input settings.
x_quantities = [
    (lps_dataset_test.lps_ranges[:, 0], r"Bunch length (\unit{\meter})", "C1", True),
    (lps_dataset_test.lps_ranges[:, 1], r"Energy spread (\unit{\percent})", "C1", True),
] + [
    (lps_dataset_test.rf_settings[:, k], label, "C2", False)
    for k, label in enumerate(rf_setting_labels)
]
# Rows: (y values, y-axis label, log-scaled y-axis?).
y_quantities = [
    (losses, "Loss", True),
    (bunch_length_errors, r"Bunch length error (\unit{\meter})", False),
    (current_errors, r"Current error (\unit{\ampere})", False),
    (peak_current_errors, r"Peak current error (\unit{\ampere})", False),
]

fig, axs = plt.subplots(
    len(y_quantities),
    len(x_quantities),
    sharex="col",
    sharey="row",
    figsize=(426.79 / 72.27 * 5, 426.79 / 72.27 * 5 * 0.48),
)

for row, (y_values, y_label, y_log) in enumerate(y_quantities):
    for col, (x_values, _, color, _) in enumerate(x_quantities):
        axs[row, col].scatter(x_values, y_values, alpha=scatter_alpha, color=color)
    axs[row, 0].set_ylabel(y_label)
    if y_log:
        axs[row, 0].set_yscale("log")  # sharey="row" applies it to the whole row

# Bottom row only: x labels, plus log scales that sharex="col" applies to the column.
for col, (_, x_label, _, x_log) in enumerate(x_quantities):
    axs[-1, col].set_xlabel(x_label)
    if x_log:
        axs[-1, col].set_xscale("log")

plt.tight_layout()
plt.show()

In [None]:
# Predicted vs. true total current on log-log axes; the C1 line is the identity
# y = x, i.e. perfect agreement.
plt.scatter(true_currents, predicted_currents, s=1, alpha=0.05)
plt.plot(true_currents, true_currents, color="C1")
plt.xscale("log")
plt.yscale("log")
plt.gca().set(
    xlabel=r"True current (\unit{\ampere})",
    ylabel=r"Predicted current (\unit{\ampere})",
)
plt.show()

In [20]:
# Relative (fractional) errors: each absolute error divided by the matching true value.
normalized_bunch_length_errors = np.divide(
    bunch_length_errors, lps_dataset_test.lps_ranges[:, 0]
)
normalized_current_errors = np.divide(current_errors, true_currents)
normalized_peak_current_errors = np.divide(peak_current_errors, true_peak_currents)

In [None]:
# Same metric-vs-quantity scatter grid for the LPS model, with the error rows shown as
# relative errors in percent. The peak-current row is a percentage too; the original
# labelled it in amperes. As in the absolute grid, x labels are only set on the bottom
# row; the original had a stray label on row 2.
scatter_alpha = 0.025

rf_setting_labels = [
    r"Chirp (\unit{\hertz\per\second})",
    r"Curv (\unit{\per\second})",
    r"Skew (\unit{\per\second})",
    r"Chirp L1 (\unit{\hertz})",
    r"Chirp L2 (\unit{\hertz})",
]
# Columns: (x values, x-axis label, marker colour, log-scaled x-axis?).
# C1 marks LPS target quantities, C2 marks RF input settings.
x_quantities = [
    (lps_dataset_test.lps_ranges[:, 0], r"Bunch length (\unit{\meter})", "C1", True),
    (lps_dataset_test.lps_ranges[:, 1], r"Energy spread (\unit{\percent})", "C1", True),
] + [
    (lps_dataset_test.rf_settings[:, k], label, "C2", False)
    for k, label in enumerate(rf_setting_labels)
]
# Rows: (y values, y-axis label, log-scaled y-axis?).
y_quantities = [
    (losses, "Loss", True),
    (
        normalized_bunch_length_errors * 100,
        r"Bunch length error (\unit{\percent})",
        False,
    ),
    (normalized_current_errors * 100, r"Current error (\unit{\percent})", False),
    (
        normalized_peak_current_errors * 100,
        r"Peak current error (\unit{\percent})",
        False,
    ),
]

fig, axs = plt.subplots(
    len(y_quantities),
    len(x_quantities),
    sharex="col",
    sharey="row",
    figsize=(426.79 / 72.27 * 5, 426.79 / 72.27 * 5 * 0.48),
)

for row, (y_values, y_label, y_log) in enumerate(y_quantities):
    for col, (x_values, _, color, _) in enumerate(x_quantities):
        axs[row, col].scatter(x_values, y_values, alpha=scatter_alpha, color=color)
    axs[row, 0].set_ylabel(y_label)
    if y_log:
        axs[row, 0].set_yscale("log")  # sharey="row" applies it to the whole row

# Bottom row only: x labels, plus log scales that sharex="col" applies to the column.
for col, (_, x_label, _, x_log) in enumerate(x_quantities):
    axs[-1, col].set_xlabel(x_label)
    if x_log:
        axs[-1, col].set_xscale("log")

plt.tight_layout()
plt.show()

In [None]:
# Run the current-profile model over the whole test set. For every sample, record the
# loss on the normalised targets (normalize=True), the de-normalised predicted bunch
# length (shape (1,)) and the de-normalised 300-point predicted current profile.
losses, predicted_bunch_lengths, predicted_current_profiles = [], [], []
for (rf_settings, formfactor), (true_current_profile, true_bunch_length) in tqdm(
    current_dataset_test
):
    y_hat_current_profile, y_hat_bunch_length = current_model(
        rf_settings.reshape((1, 5)), formfactor.reshape((1, 240))
    )

    # Loss: profile term + bunch-length term, using the model's own criteria.
    current_profile_loss = current_model.current_criterion(
        y_hat_current_profile, true_current_profile.reshape((1, 300))
    ).detach()
    bunch_length_loss = current_model.length_criterion(
        y_hat_bunch_length, true_bunch_length.reshape((1, 1))
    ).detach()
    losses.append((current_profile_loss + bunch_length_loss).item())

    # Undo the target scaling to obtain predictions in physical units.
    predicted_bunch_lengths.append(
        current_scalers["bunch_length"]
        .inverse_transform(y_hat_bunch_length.detach().numpy())
        .reshape(1)
    )
    predicted_current_profiles.append(
        current_scalers["current"]
        .inverse_transform(y_hat_current_profile.detach().numpy())
        .reshape(300)
    )

losses = np.array(losses)
predicted_bunch_lengths = np.array(predicted_bunch_lengths)
predicted_current_profiles = np.array(predicted_current_profiles)

In [23]:
# Per-sample absolute errors of the current-profile model. "current" is the sum over
# the 300-point profile and "peak current" its maximum.
# NOTE(review): predicted_bunch_lengths is (N, 1); bunch_lengths is presumably (N, 1)
# as well (it is indexed as bunch_lengths[sample][0] in the example figure). Confirm
# this, otherwise the subtraction would broadcast to (N, N).
bunch_length_errors = np.abs(
    predicted_bunch_lengths - current_dataset_test.bunch_lengths
)

true_currents, predicted_currents = (
    np.sum(profiles, axis=1)
    for profiles in (current_dataset_test.current_profiles, predicted_current_profiles)
)
current_errors = np.abs(predicted_currents - true_currents)

true_peak_currents, predicted_peak_currents = (
    np.max(profiles, axis=1)
    for profiles in (current_dataset_test.current_profiles, predicted_current_profiles)
)
peak_current_errors = np.abs(predicted_peak_currents - true_peak_currents)

In [None]:
# Median, mean and standard deviation of the true and predicted per-sample current
# profile sums over the test set. ic() echoes each argument's source text, so the
# expressions are kept exactly as written; assigning to `_` stops the cell from also
# displaying the returned tuple.
_ = ic(
    np.median(true_currents),
    true_currents.mean(),
    true_currents.std(),
    np.median(predicted_currents),
    predicted_currents.mean(),
    predicted_currents.std(),
)

In [None]:
# Metric-vs-quantity scatter grid for the current-profile model. The current-model
# metrics are not plotted against energy spread (those scatters were commented out),
# so that column is omitted. The original left it as an empty, yet labelled and
# log-scaled, panel. x labels are only set on the bottom row; the original had a stray
# label on row 2.
scatter_alpha = 0.025

rf_setting_labels = [
    r"Chirp (\unit{\hertz\per\second})",
    r"Curv (\unit{\per\second})",
    r"Skew (\unit{\per\second})",
    r"Chirp L1 (\unit{\hertz})",
    r"Chirp L2 (\unit{\hertz})",
]
# Columns: (x values, x-axis label, marker colour, log-scaled x-axis?).
# C1 marks target quantities, C2 marks RF input settings.
x_quantities = [
    (current_dataset_test.bunch_lengths, r"Bunch length (\unit{\meter})", "C1", True),
] + [
    (current_dataset_test.rf_settings[:, k], label, "C2", False)
    for k, label in enumerate(rf_setting_labels)
]
# Rows: (y values, y-axis label, log-scaled y-axis?).
y_quantities = [
    (losses, "Loss", True),
    (bunch_length_errors, r"Bunch length error (\unit{\meter})", False),
    (current_errors, r"Current error (\unit{\ampere})", False),
    (peak_current_errors, r"Peak current error (\unit{\ampere})", False),
]

fig, axs = plt.subplots(
    len(y_quantities),
    len(x_quantities),
    sharex="col",
    sharey="row",
    figsize=(426.79 / 72.27 * 5, 426.79 / 72.27 * 5 * 0.48),
)

for row, (y_values, y_label, y_log) in enumerate(y_quantities):
    for col, (x_values, _, color, _) in enumerate(x_quantities):
        axs[row, col].scatter(x_values, y_values, alpha=scatter_alpha, color=color)
    axs[row, 0].set_ylabel(y_label)
    if y_log:
        axs[row, 0].set_yscale("log")  # sharey="row" applies it to the whole row

# Bottom row only: x labels, plus log scales that sharex="col" applies to the column.
for col, (_, x_label, _, x_log) in enumerate(x_quantities):
    axs[-1, col].set_xlabel(x_label)
    if x_log:
        axs[-1, col].set_xscale("log")

plt.tight_layout()
plt.show()

In [None]:
# Predicted vs. true total current (profile sums) on log-log axes; the C1 line is the
# identity y = x, i.e. perfect agreement.
plt.scatter(true_currents, predicted_currents, s=1, alpha=0.05)
plt.plot(true_currents, true_currents, color="C1")
plt.xscale("log")
plt.yscale("log")
plt.gca().set(
    xlabel=r"True current (\unit{\ampere})",
    ylabel=r"Predicted current (\unit{\ampere})",
)
plt.show()

In [27]:
# Relative (fractional) errors: each absolute error divided by the matching true value.
normalized_bunch_length_errors = np.divide(
    bunch_length_errors, current_dataset_test.bunch_lengths
)
normalized_current_errors = np.divide(current_errors, true_currents)
normalized_peak_current_errors = np.divide(peak_current_errors, true_peak_currents)

In [None]:
# Scatter grid of the current-model errors against the inputs of each test
# sample. Rows: loss and the three relative errors. Columns: bunch length,
# energy spread (empty placeholder) and the five RF settings.
scatter_alpha = 0.025

# One entry per row: (y-values, y-label, y-scale or None for linear).
# The relative errors are multiplied by 100, so they are plotted in percent.
error_rows = [
    (losses, "Loss", "log"),
    (
        normalized_bunch_length_errors * 100,
        r"Bunch length error (\unit{\percent})",
        None,
    ),
    (
        normalized_current_errors * 100,
        r"Current error (\unit{\percent})",
        None,
    ),
    (
        normalized_peak_current_errors * 100,
        r"Peak current error (\unit{\percent})",  # was mislabelled as amperes
        None,
    ),
]

# One entry per column: (x-values or None, x-label, x-scale or None, colour).
# NOTE(review): the energy-spread column is left empty because its scatter used
# lps_dataset_test, which is presumably not aligned sample-by-sample with
# current_dataset_test; the column is kept so the figure layout is unchanged.
parameter_columns = [
    (
        current_dataset_test.bunch_lengths,
        r"Bunch length (\unit{\meter})",
        "log",
        "C1",
    ),
    (None, r"Energy spread (\unit{\percent})", "log", "C1"),
    (
        current_dataset_test.rf_settings[:, 0],
        r"Chirp (\unit{\hertz\per\second})",
        None,
        "C2",
    ),
    (
        current_dataset_test.rf_settings[:, 1],
        r"Curv (\unit{\per\second})",
        None,
        "C2",
    ),
    (
        current_dataset_test.rf_settings[:, 2],
        r"Skew (\unit{\per\second})",
        None,
        "C2",
    ),
    (
        current_dataset_test.rf_settings[:, 3],
        r"Chirp L1 (\unit{\hertz})",
        None,
        "C2",
    ),
    (
        current_dataset_test.rf_settings[:, 4],
        r"Chirp L2 (\unit{\hertz})",
        None,
        "C2",
    ),
]

# 426.79 pt / 72.27 pt per inch: presumably the LaTeX text width, in inches.
text_width_in = 426.79 / 72.27

fig, axs = plt.subplots(
    len(error_rows),
    len(parameter_columns),
    sharex="col",
    sharey="row",
    figsize=(text_width_in * 5, text_width_in * 5 * 0.48),
)

last_row = len(error_rows) - 1
for row, (y_values, y_label, y_scale) in enumerate(error_rows):
    for col, (x_values, x_label, x_scale, color) in enumerate(parameter_columns):
        panel = axs[row, col]
        if x_values is not None:
            panel.scatter(x_values, y_values, alpha=scatter_alpha, color=color)
        # x-axes are shared per column, so label/scale the bottom row only.
        if row == last_row:
            panel.set_xlabel(x_label)
            if x_scale is not None:
                panel.set_xscale(x_scale)
    # y-axes are shared per row, so label/scale the first column only.
    axs[row, 0].set_ylabel(y_label)
    if y_scale is not None:
        axs[row, 0].set_yscale(y_scale)

plt.tight_layout()
plt.show()