# Evaluation of CNN energy reconstruction performance

In [None]:
import hist
import keras
import matplotlib.gridspec as grid_spec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scienceplots  # noqa: F401
import scipy
import sympy
import tensorflow as tf
import tensorflow.keras.backend as K
import uproot
from iminuit import Minuit, cost
from iminuit.cost import LeastSquares
from scipy.stats import linregress
from tensorflow.keras.models import load_model

from plotting import watermark
from preprocessing import reshape_data

In [None]:
# Particle selection: each of the following cells overwrites `particle` and
# `particle_pretty`; only the last cell executed takes effect. Only the pi_plus
# and electron cells also set `filename_test` and `model_file`, so running the
# photon or pi_zero cell on its own leaves those two from an earlier cell (or
# undefined).
particle = "photon"
particle_pretty = r"$\gamma$"

In [None]:
particle = "pi_zero"
particle_pretty = r"$\pi^0$"

In [None]:
particle = "pi_plus"
particle_pretty = r"$\pi^\pm$"
# Presumably uproot "path:object" syntax (object "df" inside the ROOT file).
filename_test = f"df_{particle}_test.root:df"
model_file = "CNN_kanchenjunga_energy_hcal_pi_plus_n100000_e25.keras"

In [None]:
particle = "electron"
particle_pretty = r"$e$"
filename_test = f"df_{particle}_fixz_sat7_test.root:df"
model_file = "CNN_lyskamm_energy_electron_n160000_e100.keras"

In [None]:
# Branch of the test tree used as the regression target.
target = "nu_energy"

# NOTE(review): these labels say "flavour" although the target is an energy;
# they look copied from a flavour-classification notebook and are not used in
# any cell visible here. Confirm before relying on them.
target_pretty = "flavour"
target_LaTeX = "flavour"

In [None]:
plt.style.use('science')

In [None]:
plt.rcParams["font.size"] = 18
plt.rcParams["axes.formatter.limits"] = -5, 4
plt.rcParams["figure.figsize"] = 6, 4
colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

In [None]:
events_test = uproot.open(filename_test)

In [None]:
def event_generator():
    """Yield one ``(X, X_mufilter, |target|)`` tuple per event of ``events_test``.

    Reads the module-level ``events_test`` tree and the ``target`` branch name
    at call time. The two hit arrays are cast to float16 and the target is made
    non-negative with ``np.abs``. Takes no arguments so it can be handed
    directly to ``tf.data.Dataset.from_generator``.
    """
    for batch, _report in events_test.iterate(step_size=1, report=True, library="np"):
        # Cast each branch once per batch rather than once per yielded event.
        hits = batch["X"].astype(np.float16)
        hits_mufilter = batch["X_mufilter"].astype(np.float16)
        ys = np.abs(batch[target])
        # Iterating the arrays walks their first (event) axis in lockstep.
        yield from zip(hits, hits_mufilter, ys)

In [None]:
gen = event_generator()

In [None]:
sample = gen.__next__()

In [None]:
generator_spec_0 = tf.type_spec_from_value(gen.__next__()[0])
generator_spec_1 = tf.type_spec_from_value(gen.__next__()[1])
generator_spec_2 = tf.type_spec_from_value(gen.__next__()[2])

In [None]:
ds_test = (
    tf.data.Dataset.from_generator(
        event_generator,
        output_signature=(
            generator_spec_0,
            generator_spec_1,
            generator_spec_2,
        ),
    )
    .map(reshape_data)
    .apply(tf.data.experimental.assert_cardinality(events_test.num_entries))
)

In [None]:
#y_test = np.log(events_test[target].array(library="np"))
y_test = events_test[target].array(library="np")

In [None]:
batched_ds_test = ds_test.batch(20)

In [None]:
# Image tensors are laid out with the channel axis last.
K.set_image_data_format("channels_last")

In [None]:
# Earlier model checkpoints, kept for reference.
#model = load_model(f"CNN_jannu_energy_{particle}_n80000_e100.keras")

In [None]:
#model = load_model(f"CNN_jannu_energy_photon_n80000_e100.keras")

In [None]:
# SECURITY NOTE: permits deserialising arbitrary Python code (e.g. Lambda
# layers) stored in .keras files. Only load model files from trusted sources.
keras.config.enable_unsafe_deserialization()

In [None]:
#model = load_model()

In [None]:
#model = load_model(f"CNN_jannu_energy_combined_sat7_n160000_e100.keras")

In [None]:
# `model_file` is set by the particle-selection cells.
model = load_model(model_file)

In [None]:
# Keep the network's name for plot filenames. `model` is rebound to a fit
# function further down, so the name must be captured here.
model_name = model.name

In [None]:
# Predictions for the whole test set, in dataset order; presumably shape
# (n_events, 1) since it is ravelled into a column below.
y_pred = model.predict(batched_ds_test)

In [None]:
df = pd.DataFrame({"E_true": y_test, "E_pred": y_pred.ravel()})

In [None]:
scale = df.E_true.mean() / df.E_pred.mean()

In [None]:
shift = ((scale * df.E_pred) - df.E_true).mean()

In [None]:
df["E_corrected"] = (scale * df.E_pred) - shift

In [None]:
df.E_corrected = df.E_pred

In [None]:
res = linregress(df.E_true[df.E_true>150], df.E_pred[df.E_true>150])

In [None]:
res

In [None]:
def energy_correct(E_raw):
    return (E_raw - res.intercept) / res.slope

In [None]:
#df.E_corrected = energy_correct(df.E_pred)

In [None]:
plt.scatter(df.E_true, df.E_corrected)

In [None]:
scale, shift

In [None]:
h_dE = hist.Hist.new.Regular(200, -100, 100, name=r"$dE$").Double()

In [None]:
h_dE.fill(df.E_corrected - df.E_true)

In [None]:
df.E_corrected.max()

In [None]:
df["d_corrected_energy"] = df.E_corrected - df.E_true

### Fit energy resolution

In [None]:
bins_E_reco = 15

In [None]:
h_dE_rel_test_vs_E_rel_pred = (
    hist.Hist.new.Regular(100, (df.d_corrected_energy.min() // 10 * 10), (df.d_corrected_energy.max() // 10 + 1) * 10, name=r"d_corrected_energy")
    .Regular(
        bins_E_reco, 150, 500, name=r"E_true"
    )  # , transform=hist.axis.transform.log)
    .Double()
)

In [None]:
h_dE_rel_test_vs_E_rel_pred.fill(df.d_corrected_energy, df.E_true)

In [None]:
# Raw prediction versus truth, one point per test event.
plt.scatter(df.E_true, df.E_pred, marker='.', s=0.1)
plt.xlim([150,500])
plt.xlabel(r" $E_\mathrm{true}\;[\mathrm{GeV}]$")
plt.ylabel(r"$E_\mathrm{reco}\;[\mathrm{GeV}]$")
plt.savefig("plots/scatter_E_rel_test_vs_E_rel_pred.pdf")
plt.savefig("plots/scatter_E_rel_test_vs_E_rel_pred.png")

In [None]:
# Same for E_corrected. NOTE: the cell assigning df.E_corrected = df.E_pred
# makes this identical to the plot above.
plt.scatter(df.E_true, df.E_corrected, marker='.', s=0.1)
plt.xlim([150,500])
plt.xlabel(r" $E_\mathrm{true}\;[\mathrm{GeV}]$")
plt.ylabel(r"$E_\mathrm{reco, corrected}\;[\mathrm{GeV}]$")
#plt.savefig("plots/scatter_E_rel_test_vs_E_rel_pred.pdf")
#plt.savefig("plots/scatter_E_rel_test_vs_E_rel_pred.png")

In [None]:
# Raw residual E_pred - E_true (x) against E_true (y).
plt.scatter(df.E_pred-df.E_true, df.E_true, marker='.', s=0.1)
plt.ylim([150,500])
plt.xlabel(r" $\Delta E\;[\mathrm{GeV}]$")
plt.ylabel(r"$E_\mathrm{true}\;[\mathrm{GeV}]$")
plt.savefig("plots/scatter_dE_rel_test_vs_E_rel_pred.pdf")
plt.savefig("plots/scatter_dE_rel_test_vs_E_rel_pred.png")

In [None]:
# Disabled cell: the code is kept in a string literal and only gets echoed.
"""
hits_per_strip = []
energies = []"""

In [None]:
# NOTE(review): `wstrips` and `energies` are not used by any visible cell.
wstrips = []
energies = []

In [None]:
# Disabled cell (string literal, only echoed).
"""
plt.scatter(energies, hits_per_strip)
"""

In [None]:
# 2D histogram: residual on x, true energy on y. `watermark` comes from the
# local `plotting` module.
h_dE_rel_test_vs_E_rel_pred.plot()
plt.xlabel(r" $\Delta E\;[\mathrm{GeV}]$")
plt.ylabel(r"$E_\mathrm{true}\;[\mathrm{GeV}]$")
watermark()
plt.savefig("plots/h_dE_rel_test_vs_E_rel_pred.pdf")
plt.savefig("plots/h_dE_rel_test_vs_E_rel_pred.png")

In [None]:
def model(x, mu, sigma):
    """Normal cumulative distribution function N(mu, sigma) evaluated at ``x``.

    Serves as the CDF model of ``iminuit.cost.BinnedNLL`` in the ridge fits.
    NOTE: defining this rebinds the name ``model``, which previously held the
    Keras network.
    """
    distribution = scipy.stats.norm(loc=mu, scale=sigma)
    return distribution.cdf(x)

In [None]:
# Ridge plot of the residual distribution in each E_true slice. Each slice gets
# a binned maximum-likelihood fit with the normal-CDF `model`. The fitted mean
# and width of every valid fit are collected in `mus` / `sigmas`, and the slice
# indices of those fits in `bins`.
gs = grid_spec.GridSpec(bins_E_reco, 1)
fig = plt.figure(figsize=(16, 9))

mus = []
sigmas = []
bins = []

ax_objs = []
for i in range(bins_E_reco):
    # One stacked axes per slice; the negative hspace below makes them overlap.
    ridge_ax = fig.add_subplot(gs[i : i + 1, 0:])
    ax_objs.append(ridge_ax)

    # Residual histogram of this E_true slice.
    h = h_dE_rel_test_vs_E_rel_pred[:, i]
    h.plot(yerr=False, ax=ridge_ax, color=colors[i % len(colors)], histtype="fill")
    entries, edges = h.to_numpy()
    centres = 0.5 * (edges[:-1] + edges[1:])

    if entries.sum() > 0:
        # Start values: weighted mean and standard deviation of the slice, at
        # bin centres (left edges would shift the mean by half a bin).
        average = np.average(centres, weights=entries)
        variance = np.average((centres - average) ** 2, weights=entries)
        # A zero starting width makes the normal CDF undefined; fall back to
        # one bin width in that degenerate case.
        sigma_start = np.sqrt(variance) or (edges[1] - edges[0])
        bnll = cost.BinnedNLL(entries, edges, model)
        m = Minuit(bnll, average, sigma_start)
        m.migrad()
        res = m.hesse()
        if res.valid:
            # Expected counts per bin, drawn at the centres of those bins.
            ridge_ax.plot(
                centres,
                bnll.prediction(res.values),
                color=colors[(i + 3) % len(colors)],
            )
            bins.append(i)
            mus.append(res.params[0])
            sigmas.append(res.params[1])
        else:
            print(res)
    else:
        # np.average raises ZeroDivisionError on an empty slice.
        print(f"E_true slice {i} has no entries; skipping fit")

    # Transparent background so overlapping axes stay visible.
    ridge_ax.patch.set_alpha(0)

    # Remove ticks, tick labels and borders; label each ridge with its slice index.
    ridge_ax.set_yticklabels([])
    if i == bins_E_reco - 1:
        ridge_ax.set_xlabel(r"$\Delta E$", fontsize=16, fontweight="bold")
    else:
        ridge_ax.set_xticklabels([])
        ridge_ax.set_xlabel("")
    ridge_ax.set_ylabel(str(i), rotation=45)
    ridge_ax.set_yticks([])
    ridge_ax.set_xticks([])
    for side in ("top", "right", "left", "bottom"):
        ridge_ax.spines[side].set_visible(False)

gs.update(hspace=-0.7)
plt.tight_layout()

plt.savefig(f"plots/ridge_{particle}_{model_name}.pdf")
plt.savefig(f"plots/ridge_{particle}_{model_name}.png")

In [None]:
# Binning of the E_true axis of the 2D histogram; the centres serve as the
# x-values (and the half-widths as x-uncertainties) of the fits below.
bin_edges = h_dE_rel_test_vs_E_rel_pred.axes[1].edges
bin_centres = 0.5 * (bin_edges[:-1] + bin_edges[1:])
bin_half_widths = np.diff(bin_edges) / 2

In [None]:
def line(x, m, b):
    """Straight line with slope ``m`` and intercept ``b``, evaluated at ``x``.

    The parameter names are significant: Minuit is later given starting
    values by the keywords ``m`` and ``b``.
    """
    return m * x + b

In [None]:
# Relative bias <ΔE>/E_true per fitted slice. The uncertainty combines the
# fitted mean's error with the slice half-width (taken as the E_true error).
mu_values = np.array([mu.value for mu in mus])
mu_errors = np.array([mu.error for mu in mus])
E_centres = bin_centres[bins]
mu_E_over_E = mu_values / E_centres
# sqrt((dmu/E)^2 + (mu*dE/E^2)^2) equals |mu/E| * sqrt((dmu/mu)^2 + (dE/E)^2),
# but does not divide by mu, so it stays finite when a fitted mean is exactly 0.
d_mu_E_over_E = np.hypot(mu_errors, mu_values * bin_half_widths[bins] / E_centres) / E_centres

In [None]:
least_squares = LeastSquares(
    bin_centres[bins], mu_E_over_E, d_mu_E_over_E, line
)

In [None]:
m = Minuit(least_squares, b=-0.35, m=-0.01)  # starting values for m and b

m.migrad()  # finds minimum of least_squares function
res = m.hesse()  # accurately computes uncertainties

In [None]:
res

In [None]:
# Relative bias per E_true slice with the fitted straight line overlaid;
# res.values are in the (m, b) order of `line`'s parameters.
plt.errorbar(
    bin_centres[bins],
    mu_E_over_E,
    xerr=bin_half_widths[bins],
    yerr=d_mu_E_over_E,
    linestyle="",
    label=r"$\left<\Delta E\right>$",
    #color=colors[0],
)
plt.plot(bin_centres[bins], line(bin_centres[bins], *res.values))
#plt.hlines(0, *plt.xlim(), color="red")
plt.ylabel(r"$\frac{\left<\Delta E\right>}{E_\mathrm{true}}$")
plt.xlabel(r"$E_\mathrm{true}\;[\mathrm{GeV}]$")
watermark()
plt.savefig("plots/energy_bias.pdf")
plt.savefig("plots/energy_bias.png")

In [None]:
plt.errorbar(
    bin_centres[bins],
    [mu.value for mu in mus] + bin_centres[bins],
    xerr=bin_half_widths[bins],
    yerr=[mu.error for mu in mus],
    linestyle="",
    label=r"$\left<\Delta E\right>$",
    #color=colors[0],
)

In [None]:
least_squares = LeastSquares(
    bin_centres[bins], [mu.value for mu in mus] + bin_centres[bins], [mu.error for mu in mus], line
)

In [None]:
m = Minuit(least_squares, b=0, m=1)  # starting values for m and b

m.migrad()  # finds minimum of least_squares function
res = m.hesse()  # accurately computes uncertainties

In [None]:
res.params

In [None]:
A, b, c, E = sympy.symbols("A b c E")

In [None]:
f = A + b / sympy.sqrt(E) + c / E

In [None]:
f_lambda = sympy.lambdify((A, b, c, E), f)

In [None]:
def E_model(E, A, b, c):
    return f_lambda(A, b, c, E)

In [None]:
# Relative resolution sigma(ΔE)/E_true per fitted slice.
sigma_values = np.array([sigma.value for sigma in sigmas])
sigma_errors = np.array([sigma.error for sigma in sigmas])
denom = bin_centres[bins]
sigma_E_over_E = sigma_values / denom
# Uncertainty on E_true: the slice half-width. Alternatives that were used:
#error_denom = np.sqrt(bin_half_widths[bins] ** 2 + np.array([mu.error for mu in mus]) ** 2)
#error_denom = np.array([mu.error for mu in mus])
error_denom = bin_half_widths[bins]
# sqrt((ds/E)^2 + (s*dE/E^2)^2): same as the relative-error product form, but
# always non-negative and finite even when a fitted width is 0.
d_sigma_E_over_E = np.hypot(sigma_errors, sigma_values * error_denom / denom) / denom

In [None]:
least_squares = LeastSquares(
    bin_centres[bins], sigma_E_over_E, d_sigma_E_over_E, E_model
)

In [None]:
m = Minuit(least_squares, A=0.1, b=1, c=0)  # starting values for α and β
#m.limits["A"] = (0, None)
#m.limits["b"] = (0, None)
#m.limits["c"] = (0, None)
#m.fixed["c"] = False


m.migrad()  # finds minimum of least_squares function
res = m.hesse()  # accurately computes uncertainties

In [None]:
res

In [None]:
f_pretty = sympy.latex(
    f.subs(
        [
            (A, sympy.Float(res.params[0].value, 1)),
            (b, sympy.Float(res.params[1].value, 2)),
            (c, sympy.Float(res.params[2].value, 2)),

        ]
    )
)

In [None]:
plt.errorbar(
    bin_centres[bins],
    [sigma.value for sigma in sigmas] / bin_centres[bins],
    xerr=bin_half_widths[bins],
    yerr=d_sigma_E_over_E,
    linestyle="",
    label=r"$\sigma\left(\Delta E\right)$",
    color=colors[1],
    fmt="o",
    capsize=3,
)
plt.plot(
    bin_centres[bins],
    E_model(bin_centres[bins], res.params[0].value, res.params[1].value, res.params[2].value),
)
plt.ylabel(r"$\frac{\sigma\left(E\right)}{E_\mathrm{reco}}$")
plt.xlabel(r"$E_\mathrm{true}\;[\mathrm{GeV}]$")
ax = plt.gca()
watermark()
plt.text(
    0.6,
    0.8,
    # rf"$\sqrt{{{res.params[0].value:.3f}^2 + \left(\frac{{{res.params[1].value:.1f}}}{{\sqrt{{E}}}}\right)^2}}$",
    rf"{particle_pretty} energy res.",
    fontsize=14,
    transform=ax.transAxes,
)
plt.text(
    0.6,
    0.7,
    # rf"$\sqrt{{{res.params[0].value:.3f}^2 + \left(\frac{{{res.params[1].value:.1f}}}{{\sqrt{{E}}}}\right)^2}}$",
    rf"${f_pretty}$",
    fontsize=14,
    transform=ax.transAxes,
)
plt.text(
    0.6,
    0.6,
    rf"$A = {res.params[0].value:.2f} \pm {res.params[0].error:.2f}$",
    fontsize=14,
    transform=ax.transAxes,
)
plt.text(
    0.6,
    0.5,
    rf"$b = {res.params[1].value:.1f} \pm {res.params[1].error:.1f}$",
    fontsize=14,
    transform=ax.transAxes,
)
if not m.fixed["c"]:
    plt.text(
        0.6,
        0.4,
        rf"$c = {res.params[2].value:.1f} \pm {res.params[2].error:.1f}$",
        fontsize=14,
        transform=ax.transAxes,
    )
plt.savefig(f"plots/energy_resolution_{particle}_{model_name}_free_abc.pdf")
plt.savefig(f"plots/energy_resolution_{particle}_{model_name}_free_abc.png")

In [None]:
m = Minuit(least_squares, A=0.1, b=1, c=0)  # starting values for α and β
#m.limits["A"] = (0, None)
#m.limits["b"] = (0, None)
#m.limits["c"] = (0, None)
m.fixed["c"] = True


m.migrad()  # finds minimum of least_squares function
res = m.hesse()  # accurately computes uncertainties

In [None]:
res

In [None]:
f_pretty = sympy.latex(
    f.subs(
        [
            (A, sympy.Float(res.params[0].value, 1)),
            (b, sympy.Float(res.params[1].value, 2)),
            (c, sympy.Float(res.params[2].value, 2)),

        ]
    )
)

In [None]:
plt.errorbar(
    bin_centres[bins],
    sigma_E_over_E,
    xerr=bin_half_widths[bins],
    yerr=d_sigma_E_over_E,
    linestyle="",
    label=r"$\sigma\left(\Delta E\right)$",
    color=colors[1],
    fmt="o",
    capsize=3,
)
plt.plot(
    bin_centres[bins],
    E_model(bin_centres[bins], res.params[0].value, res.params[1].value, res.params[2].value),
)
plt.ylabel(r"$\frac{\sigma\left(E\right)}{E_\mathrm{reco}}$")
plt.xlabel(r"$E_\mathrm{true}\;[\mathrm{GeV}]$")
ax = plt.gca()
watermark()
plt.text(
    0.6,
    0.8,
    # rf"$\sqrt{{{res.params[0].value:.3f}^2 + \left(\frac{{{res.params[1].value:.1f}}}{{\sqrt{{E}}}}\right)^2}}$",
    rf"{particle_pretty} energy res.",
    fontsize=14,
    transform=ax.transAxes,
)
plt.text(
    0.6,
    0.7,
    # rf"$\sqrt{{{res.params[0].value:.3f}^2 + \left(\frac{{{res.params[1].value:.1f}}}{{\sqrt{{E}}}}\right)^2}}$",
    rf"${f_pretty}$",
    fontsize=14,
    transform=ax.transAxes,
)
plt.text(
    0.6,
    0.6,
    rf"$A = {res.params[0].value:.2f} \pm {res.params[0].error:.2f}$",
    fontsize=14,
    transform=ax.transAxes,
)
plt.text(
    0.6,
    0.5,
    rf"$b = {res.params[1].value:.1f} \pm {res.params[1].error:.1f}$",
    fontsize=14,
    transform=ax.transAxes,
)
if not m.fixed["c"]:
    plt.text(
        0.6,
        0.4,
        rf"$c = {res.params[2].value:.1f} \pm {res.params[2].error:.1f}$",
        fontsize=14,
        transform=ax.transAxes,
    )
plt.savefig(f"plots/energy_resolution_{particle}_{model_name}_free_ab.pdf")
plt.savefig(f"plots/energy_resolution_{particle}_{model_name}_free_ab.png")

In [None]:
m = Minuit(least_squares, A=0.1, b=1, c=0)  # starting values for α and β
m.limits["A"] = (0, None)
m.limits["b"] = (0, None)
m.limits["c"] = (0, None)
m.fixed["c"] = False


m.migrad()  # finds minimum of least_squares function
res = m.hesse()  # accurately computes uncertainties

In [None]:
f_pretty = sympy.latex(
    f.subs(
        [
            (A, sympy.Float(res.params[0].value, 1)),
            (b, sympy.Float(res.params[1].value, 2)),
            (c, sympy.Float(res.params[2].value, 2)),

        ]
    )
)

In [None]:
plt.errorbar(
    bin_centres[bins],
    sigma_E_over_E,
    xerr=bin_half_widths[bins],
    yerr=d_sigma_E_over_E,
    linestyle="",
    label=r"$\sigma\left(\Delta E\right)$",
    color=colors[1],
    fmt="o",
    capsize=3,
)
plt.plot(
    bin_centres[bins],
    E_model(bin_centres[bins], res.params[0].value, res.params[1].value, res.params[2].value),
)
plt.ylabel(r"$\frac{\sigma\left(E\right)}{E_\mathrm{reco}}$")
plt.xlabel(r"$E_\mathrm{true}\;[\mathrm{GeV}]$")
ax = plt.gca()
watermark()
plt.text(
    0.6,
    0.8,
    # rf"$\sqrt{{{res.params[0].value:.3f}^2 + \left(\frac{{{res.params[1].value:.1f}}}{{\sqrt{{E}}}}\right)^2}}$",
    rf"{particle_pretty} energy res.",
    fontsize=14,
    transform=ax.transAxes,
)
plt.text(
    0.6,
    0.7,
    # rf"$\sqrt{{{res.params[0].value:.3f}^2 + \left(\frac{{{res.params[1].value:.1f}}}{{\sqrt{{E}}}}\right)^2}}$",
    rf"${f_pretty}$",
    fontsize=14,
    transform=ax.transAxes,
)
plt.text(
    0.6,
    0.6,
    rf"$A = {res.params[0].value:.2f} \pm {res.params[0].error:.2f}$",
    fontsize=14,
    transform=ax.transAxes,
)
plt.text(
    0.6,
    0.5,
    rf"$b = {res.params[1].value:.1f} \pm {res.params[1].error:.1f}$",
    fontsize=14,
    transform=ax.transAxes,
)
if not m.fixed["c"]:
    plt.text(
        0.6,
        0.4,
        rf"$c = {res.params[2].value:.1f} \pm {res.params[2].error:.1f}$",
        fontsize=14,
        transform=ax.transAxes,
    )
plt.savefig(f"plots/energy_resolution_{particle}_{model_name}_pos_abc.pdf")
plt.savefig(f"plots/energy_resolution_{particle}_{model_name}_pos_abc.png")

In [None]:
m = Minuit(least_squares, A=0.1, b=1, c=0)  # starting values for α and β
m.limits["A"] = (0, None)
m.limits["b"] = (0, None)
m.limits["c"] = (0, None)
m.fixed["c"] = True


m.migrad()  # finds minimum of least_squares function
res = m.hesse()  # accurately computes uncertainties

In [None]:
f_pretty = sympy.latex(
    f.subs(
        [
            (A, sympy.Float(res.params[0].value, 1)),
            (b, sympy.Float(res.params[1].value, 2)),
            (c, sympy.Float(res.params[2].value, 2)),

        ]
    )
)

In [None]:
plt.errorbar(
    bin_centres[bins],
    sigma_E_over_E,
    xerr=bin_half_widths[bins],
    yerr=d_sigma_E_over_E,
    linestyle="",
    label=r"$\sigma\left(\Delta E\right)$",
    color=colors[1],
    fmt="o",
    capsize=3,
)
plt.plot(
    bin_centres[bins],
    E_model(bin_centres[bins], res.params[0].value, res.params[1].value, res.params[2].value),
)
plt.ylabel(r"$\frac{\sigma\left(E\right)}{E_\mathrm{reco}}$")
plt.xlabel(r"$E_\mathrm{true}\;[\mathrm{GeV}]$")
ax = plt.gca()
watermark()
plt.text(
    0.6,
    0.8,
    # rf"$\sqrt{{{res.params[0].value:.3f}^2 + \left(\frac{{{res.params[1].value:.1f}}}{{\sqrt{{E}}}}\right)^2}}$",
    rf"{particle_pretty} energy res.",
    fontsize=14,
    transform=ax.transAxes,
)
plt.text(
    0.6,
    0.7,
    # rf"$\sqrt{{{res.params[0].value:.3f}^2 + \left(\frac{{{res.params[1].value:.1f}}}{{\sqrt{{E}}}}\right)^2}}$",
    rf"${f_pretty}$",
    fontsize=14,
    transform=ax.transAxes,
)
plt.text(
    0.6,
    0.6,
    rf"$A = {res.params[0].value:.2f} \pm {res.params[0].error:.2f}$",
    fontsize=14,
    transform=ax.transAxes,
)
plt.text(
    0.6,
    0.5,
    rf"$b = {res.params[1].value:.1f} \pm {res.params[1].error:.1f}$",
    fontsize=14,
    transform=ax.transAxes,
)
if not m.fixed["c"]:
    plt.text(
        0.6,
        0.4,
        rf"$c = {res.params[2].value:.1f} \pm {res.params[2].error:.1f}$",
        fontsize=14,
        transform=ax.transAxes,
    )
plt.savefig(f"plots/energy_resolution_{particle}_{model_name}_pos_ab.pdf")
plt.savefig(f"plots/energy_resolution_{particle}_{model_name}_pos_ab.png")

In [None]:
# Resolution points only, without a fitted curve.
plt.errorbar(
    bin_centres[bins],
    sigma_E_over_E,
    xerr=bin_half_widths[bins],
    yerr=d_sigma_E_over_E,
    linestyle="",
    label=r"$\sigma\left(\Delta E\right)$",
    color=colors[1],
    fmt="o",
    capsize=3,
)
# The denominator is the E_true slice centre, not the reconstructed energy.
plt.ylabel(r"$\frac{\sigma\left(\Delta E\right)}{E_\mathrm{true}}$")
plt.xlabel(r"$E_\mathrm{true}\;[\mathrm{GeV}]$")
ax = plt.gca()
watermark()
plt.text(
    0.6,
    0.8,
    rf"{particle_pretty} energy res.",
    fontsize=14,
    transform=ax.transAxes,
)
plt.savefig(f"plots/energy_resolution_{particle}_{model_name}_no_fit.pdf")
plt.savefig(f"plots/energy_resolution_{particle}_{model_name}_no_fit.png")