In [None]:
import warnings
from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn.metrics as metrics
from matplotlib.lines import Line2D
from scipy.stats import binned_statistic

# Note: These imports may need to be updated based on your hepattn package structure
# from hepformer.tracking.eval.evaluate import eval_events as eval_events_tracking
# from hepformer.tracking.eval.hit_eval import load_events as eval_events_filtering
# from hepformer.tracking.eval.plots import binned, profile_plot

# Placeholder imports - update these based on your actual hepattn structure
from hepattn.tracking.eval.evaluate import eval_events as eval_events_tracking
from hepattn.tracking.eval.hit_eval import load_events as eval_events_filtering
from hepattn.tracking.eval.plots import binned, profile_plot

# Suppress every warning for the rest of the session (applies globally, not just to plotting).
warnings.filterwarnings(action="ignore")

# Global matplotlib defaults used by all figures in this notebook.
# Note: "text.usetex" makes matplotlib render text through an external LaTeX installation.
plt.rcParams.update(
    {
        "figure.dpi": 400,
        "text.usetex": True,
        "font.family": "serif",
        "figure.constrained_layout.use": True,
    }
)

In [None]:
# Setup
# Directory that the figures produced by this notebook are written to.
# The assignment below used to be preceded by an alternative location that was
# immediately overwritten (dead code); it is kept here for reference only:
# out_dir = Path("/share/rcifdata/maxhart/hepformer-paper-plots/hepformer/hepformer/tracking/eval/plots")
out_dir = Path("/share/rcifdata/svanstroud/hepformer/hepformer/tracking/eval/plots/new/")
# parents=True also creates any missing intermediate directories; without it
# mkdir raises FileNotFoundError when the parent of "new/" does not exist yet.
out_dir.mkdir(parents=True, exist_ok=True)

# Matplotlib colour name used for each training label.
training_colours = dict(
    [
        ("600 MeV", "mediumvioletred"),
        ("750 MeV", "cornflowerblue"),
        # ("1 GeV", "mediumseagreen"),  # |eta| < 2.5
        ("1 GeV", "mediumseagreen"),  # |eta| < 4.0
    ]
)

# Bin edges per quantity; each entry is a 1-D array of monotonically increasing edges.
qty_bins = {
    "pt": np.asarray([0.6, 0.75, 1.0, 1.5, 2, 3, 4, 6, 10]),
    # "eta": np.array([-2.5, -2, -1.5, -1, -0.5, 0, 0.5, 1, 1.5, 2, 2.5]),
    # 21 edges from -5 to 5 in steps of 0.5 (all exactly representable as floats).
    "eta": np.linspace(-5.0, 5.0, 21),
    "phi": np.asarray([-3.14, -2.36, -1.57, -0.79, 0, 0.79, 1.57, 2.36, 3.14]),
    # Integer edges on purpose: this array keeps an integer dtype.
    "vz": np.asarray([-100, -50, -20, -10, 0, 10, 20, 50, 100]),
}

# LaTeX math-mode symbol for each quantity (raw strings keep the backslashes literal).
qty_symbols = {
    "pt": r"p_\mathrm{T}",
    "eta": r"\eta",
    "phi": r"\phi",
    "vz": "v_z",
}

# Unit suffix for axis labels; empty string where no unit is shown.
qty_units = dict(pt="[GeV]", eta="", phi="", vz="[mm]")

In [None]:
# Hit filter plots
# Evaluation output file (.h5) for each training label. Adjacent string literals
# are concatenated by Python, so every value is a single absolute path.
filtering_fnames = {
    "600 MeV": (
        "/share/rcifdata/svanstroud/hepformer/hepformer/tracking/logs/"
        "HC-final-0.6GeV_20241007-T092114/ckpts/"
        "epoch=029-val_loss=0.09947__test_test.h5"
    ),
    # NOTE(review): this file ends in "__test_train.h5" while the other entries end in
    # "__test_test.h5" - confirm this is the intended evaluation file.
    "750 MeV": (
        "/share/rcifdata/svanstroud/hepformer/hepformer/tracking/logs/"
        "HC-final-0.75GeV_20241007-T092015/ckpts/"
        "epoch=029-val_loss=0.09307__test_train.h5"
    ),
    # "1 GeV": "/share/rcifdata/svanstroud/hepformer/hepformer/tracking/logs/HC-final-1GeV_20241007-T092359/ckpts/epoch=029-val_loss=0.09906__test_test.h5",
    "1 GeV": (
        "/share/rcifdata/svanstroud/hepformer/hepformer/tracking/logs/"
        "HC-final-1GeV-eta5_20250303-T193944/ckpts/"
        "epoch=028-val_loss=0.13227__test_test.h5"
    ),
}

# Results keyed by training label; stays empty while the loading loop below is disabled.
filtering_results = dict()
# Number of events requested from each file by the (currently disabled) loading loop.
num_events = 1
# Comment out until the correct imports are fixed
# for name, fname in filtering_fnames.items():
#     hits, parts = eval_events_filtering(fname, num_events=num_events, hit_cut=0.1)
#     filtering_results[name] = (hits, parts)

print("Cell ready to run once eval_events_filtering is properly imported")

In [None]:
# Accumulates one summary row (dict) per model; stays empty while the loop below is disabled.
filtering_table = list()

# Hand-entered per-model metadata that is merged into each summary row.
# NOTE(review): "9999M" for the 1 GeV model looks like a placeholder parameter count - confirm.
rows = {
    model: {"Layers": n_layers, "Param. Count": n_params, "Inference Time [ms]": time_ms}
    for model, n_layers, n_params, time_ms in (
        ("600 MeV", 12, "5.4M", 37.8),
        ("750 MeV", 12, "5.4M", 37.5),
        # ("1 GeV", 8, "5.4M", 25.2),
        ("1 GeV", 12, "9999M", 28.7),  # this is with compile callback but not sure the others are
    )
}

# Comment out until filtering_results is available
# for name, (hits, parts) in filtering_results.items():
#     row = {"Model": name} | rows[name]
#
#     hit_eff = (hits.pred & hits.tgt).sum() / hits.tgt.sum()
#     hit_pur_post = (hits.pred & hits.tgt).sum() / hits.pred.sum()
#     hit_pur_pre = hits["tgt"].sum() / len(hits)
#
#     # se_recall = (recall * (1 - recall) / hits.tgt.sum()) ** 0.5
#     # se_precision = (precision * (1 - precision) / hits.pred.sum()) ** 0.5
#     # pre_count, _,  _ = binned_statistic(parts["pt"], parts["reconstructable_pre"], statistic="count", bins=bins)
#     # post_count, _, _ = binned_statistic(parts["pt"], parts["reconstructable_post"], statistic="count", bins=bins)
#
#     eff_perf = parts["reconstructable_post"].sum() / parts["reconstructable_pre"].sum()
#     eff_perf_hipt = parts["reconstructable_post"][parts["pt"] > 0.9].sum() / parts["reconstructable_pre"][parts["pt"] > 0.9].sum()
#
#     row["Hit Efficiency"] = 100 * hit_eff
#     row["Hit Purity (Pre)"] = 100 * hit_pur_pre
#     row["Hit Purity (Post)"] = 100 * hit_pur_post
#     row["\\varepsilon^\\mathrm{perfect}"] = 100 * eff_perf
#     row["\\varepsilon^\\mathrm{perfect}_{p_\\mathrm{T} \\geq 0.9}"] = 100 * eff_perf_hipt
#
#     fpr, tpr, thresholds = metrics.roc_curve(hits["tgt"], hits["prob"])
#     row["AUC"] = metrics.auc(fpr, tpr)
#
#     filtering_table.append(row)
#
# filtering_table = pd.DataFrame.from_dict(filtering_table)
# filtering_table

print("Cell ready to run once filtering_results is available")

In [None]:
# main hit filter performance plots
# Line2D (used for the legend handles below) is imported in the top-of-notebook
# import cell together with the other dependencies, not in this cell.
#
# Review notes for when the code below is re-enabled:
# - NOTE(review): `bins = np.linspace(0, 1, 24)` is reassigned to qty_bins["pt"]
#   inside the loop before it is read - it looks unused; confirm and drop.
# - NOTE(review): `ys` is post_count / pre_count, but `ys_err` divides by
#   pt_bin_count (a count of all particles in the bin) rather than pre_count -
#   confirm which denominator is intended for the binomial error.
# - NOTE(review): `pt_value` is assigned but never read in this cell.

# Comment out until filtering_results is available
# fig, ax = plt.subplots(nrows=1, ncols=2, constrained_layout=True)
# fig.set_size_inches(10, 3)
# bins = np.linspace(0, 1, 24)

# for training, (hits, parts) in filtering_results.items():
#     colour = training_colours[training]
#     prec, recall, threshold = metrics.precision_recall_curve(hits["tgt"], hits["prob"])
#     idx = np.argmin(np.abs(threshold - 0.1))

#     ax[0].plot(recall, prec, color=colour, label=training)
#     ax[0].scatter(recall[idx], prec[idx], color=colour)

#     bins = qty_bins["pt"]
#     b = (bins[:-1] + bins[1:]) / 2
#     b_err = (bins[1:] - bins[:-1]) / 2

#     pt_bin_count, _, _ = binned_statistic(parts["pt"], parts["reconstructable_post"], statistic="count", bins=bins)
#     post_count, _, _ = binned_statistic(parts["pt"], parts["reconstructable_post"], statistic="sum", bins=bins)
#     pre_count, _, _ = binned_statistic(parts["pt"], parts["reconstructable_pre"], statistic="sum", bins=bins)

#     ys = post_count / pre_count
#     ys_err = np.sqrt(ys * (1 - ys) / pt_bin_count)
#     label = training
#     ax[1].errorbar(b, ys, yerr=None, xerr=b_err, color=colour, fmt=".", label=label, marker="s", markersize=2.0)
#     ax[1].errorbar(b, ys, yerr=ys_err, xerr=None, color=colour, fmt=".", marker="none", capsize=5.0, markeredgewidth=1.0)
#     pt_value = {"600 MeV": 0.6, "750 MeV": 0.75, "1 GeV": 1.0, "1 GeV eta 4": 1.0}[training]

# ax[0].set_xlim(0.9, 1.01)
# ax[0].set_ylim(0.3, 1.01)
# ax[1].set_ylim(0.97, 1.0)
# ax[0].set_ylabel("Hit purity")
# ax[0].set_xlabel("Hit efficiency")
# ax[1].set_ylabel("Reconstructable particles")
# ax[1].set_xlabel(r"Particle $p_\mathrm{T}$ [GeV]")
# ax[0].grid(zorder=0, alpha=0.25, linestyle="--")
# ax[1].grid(zorder=0, alpha=0.25, linestyle="--")
# ax[0].legend(frameon=False)
# legend_elements = [Line2D([0], [0], color=training_colours[training], label=training) for training in filtering_results]
# ax[1].legend(handles=legend_elements, frameon=False, loc="lower left")

# fig.savefig(out_dir / "filter_response.pdf")
# fig.show()

print("Plotting cell ready to run once filtering_results is available")

In [None]:
# Test cell to verify environment is working
print("Jupyter environment is working!")
print(f"NumPy version: {np.__version__}")
print(f"Pandas version: {pd.__version__}")
# The version string lives on the top-level matplotlib package; the pyplot
# submodule has no __version__ attribute, so plt.__version__ raises AttributeError.
print(f"Matplotlib version: {matplotlib.__version__}")

# Simple plot test: draw one sine curve to confirm figures render.
fig, ax = plt.subplots()
x = np.linspace(0, 10, 100)
y = np.sin(x)
ax.plot(x, y)
ax.set_title("Test Plot")
plt.show()