# Visualize CLD events from Key4hep full simulation and reconstruction

## Imports

In [None]:
# Render matplotlib figures inline in the notebook output (IPython magic).
%matplotlib inline

In [None]:
from pathlib import Path
import uproot
import numpy as np
import math
import pandas
import awkward

import plotly.graph_objects as go
import vector

import matplotlib.pyplot as plt
from scipy.sparse import coo_matrix

In [None]:
# Record the JupyterLab version used to run this notebook (shown as the cell output).
import jupyterlab

jupyterlab.__version__

## Constants

In [None]:
# Physical and plotting constants shared by the cells below.
pion_mass = 0.13957  # GeV; mass hypothesis given to every reconstructed track
B = -2.0  # solenoid field in T (-2 for CLD FCC-ee)
c = 3.0e8  # speed of light, m/s
scale = 1_000  # multiplier applied to helix coordinates (metres) before plotting

## Define functions

In [None]:
# Transverse momentum from the signed track curvature omega, following eq. 12 of
# https://bib-pubdb1.desy.de/record/81214/files/LC-DET-2006-004[1].pdf:
#   pT [GeV/c] ~= a * |Bz [T] / omega [1/mm]|,  with  a = c * 10^(-15) = 3 * 10^(-4)
def track_pt(omega, bfield=B):
    """Return the transverse momentum for curvature ``omega`` in field ``bfield``.

    Evaluates ``3e-4 * |bfield / omega|`` element-wise (works on NumPy arrays);
    the signs of ``omega`` and ``bfield`` do not affect the result.
    """
    field_over_curvature = bfield / omega
    return (3 * 10**-4) * np.abs(field_over_curvature)


def hits_to_features(hit_data, iev, coll, feats):
    """
    Converts hit data into a structured feature array for a specific event and collection.

    Args:
        hit_data (dict-like): Hit data indexable by "<coll>.<feature>" keys (e.g. the awkward
            array returned by ``ev[coll].array()``); each value is indexed by event number.
        iev (int): The index of the event to extract data for.
        coll (str): The name of the hit collection (e.g., "VXDTrackerHits", "VXDEndcapTrackerHits",
            "ECALBarrel", "ECALEndcap", etc.).
        feats (list of str): A list of feature names to extract from the hit data
            (e.g., "position.x", "position.y", "position.z", "energy", "type", etc.).
            For tracker collections the "energy" feature is read from the "eDep" branch
            but is still returned under the name "energy".

    Returns:
        awkward.Array: An Awkward Array containing the extracted features for the specified event
            and collection. The array includes an additional "subdetector" feature, which encodes
            the subdetector type:
            - 0 for ECAL
            - 1 for HCAL
            - 2 for MUON
            - 3 for other collections (e.g. the tracker hit collections).

    Raises:
        ValueError: If ``feats`` is empty, since the number of hits cannot be determined.
    """
    # Tracker hits store their deposited energy in "eDep" instead of "energy".
    is_tracker = "TrackerHit" in coll or "TrackerEndcapHits" in coll

    feat_arr = {}
    for feat in feats:
        branch = "eDep" if is_tracker and feat == "energy" else feat
        feat_arr[feat] = hit_data[coll + "." + branch][iev]

    if not feat_arr:
        raise ValueError("feats must contain at least one feature name")
    # Every requested feature has one entry per hit, so any of them gives the hit
    # count; "type" therefore does not have to be among the requested features.
    num_hits = len(next(iter(feat_arr.values())))

    if coll.startswith("ECAL"):
        subdetector = 0
    elif coll.startswith("HCAL"):
        subdetector = 1
    elif coll.startswith("MUON"):
        subdetector = 2
    else:
        subdetector = 3
    feat_arr["subdetector"] = np.full(num_hits, subdetector, dtype=np.int32)
    return awkward.Array(feat_arr)


def _innermost_hit_radius(track_arr, track_coll, states_arr, states_coll):
    """Return, per track, the transverse radius of the reference point of its AtFirstHit state.

    Track ``i`` owns the track states ``[trackStates_begin[i], trackStates_end[i])`` of
    ``states_arr``; the AtFirstHit location code is 2, see
    https://github.com/key4hep/EDM4hep/blob/fe5a54046a91a7e648d0b588960db7841aebc670/edm4hep.yaml#L220
    """
    begins = track_arr[track_coll + ".trackStates_begin"]
    ends = track_arr[track_coll + ".trackStates_end"]
    # field lookups hoisted out of the per-track loop
    ref_x = states_arr[states_coll + ".referencePoint.x"]
    ref_y = states_arr[states_coll + ".referencePoint.y"]
    location = states_arr[states_coll + ".location"]

    radii = []
    for ibegin, iend in zip(begins, ends):
        # NOTE: np.argmax returns 0 when no state matches, so a track without an
        # AtFirstHit state silently uses its first state instead.
        istate = np.argmax(location[ibegin:iend] == 2)
        radii.append(math.hypot(ref_x[ibegin:iend][istate], ref_y[ibegin:iend][istate]))
    return np.array(radii)


def track_to_features(prop_data, iev, track_coll):
    """
    Extracts track features from the provided property data for a specific event and track collection.

    Args:
        prop_data (awkward array): An awkward array containing track property data. It must hold
            the fields ``track_coll``, ``"_" + track_coll + "_trackStates"`` and
            ``track_coll + "_dQdx"`` (e.g. "SiTracks_Refitted", "_SiTracks_Refitted_trackStates",
            "SiTracks_Refitted_dQdx").
        iev (int): The index of the event to extract data for.
        track_coll (str): The name of the track collection (e.g., "SiTracks_Refitted").

    Returns:
        awkward.Record: A record containing the extracted track features, including:
            - "type", "chi2", "ndf": Basic track properties.
            - "dEdx", "dEdxError": Energy deposition and its error.
            - "radiusOfInnermostHit": Radius of the innermost hit for each track.
            - "tanLambda", "D0", "phi", "omega", "Z0", "time": Track state properties.
            - "pt", "px", "py", "pz", "p": Momentum components and magnitude.
            - "eta": Pseudorapidity (0 when it cannot be computed).
            - "sin_phi", "cos_phi": Sine and cosine of the azimuthal angle.
            - "elemtype": Element type (always 1 for tracks).
            - "q": Charge of the track (+1 or -1, sign of "omega").

    Notes:
        - The "AtFirstHit" state is used to determine the innermost hit radius.
        - The helix parameters are taken from the first track state of each track.
    """
    # dQdx and track states are stored in separate collections named after the track
    # collection (they are no longer defined under e.g. SiTracks_Refitted itself).
    states_coll = "_" + track_coll + "_trackStates"
    dqdx_coll = track_coll + "_dQdx"

    track_arr = prop_data[track_coll][iev]
    track_arr_dQdx = prop_data[dqdx_coll][iev]
    track_arr_trackStates = prop_data[states_coll][iev]

    ret = {feat: track_arr[track_coll + "." + feat] for feat in ("type", "chi2", "ndf")}
    num_tracks = len(ret["type"])

    ret["dEdx"] = track_arr_dQdx[dqdx_coll + ".dQdx.value"]
    ret["dEdxError"] = track_arr_dQdx[dqdx_coll + ".dQdx.error"]

    ret["radiusOfInnermostHit"] = _innermost_hit_radius(track_arr, track_coll, track_arr_trackStates, states_coll)

    # helix parameters at the first track state of each track
    trackstate_idx = track_arr[track_coll + ".trackStates_begin"]
    for k in ["tanLambda", "D0", "phi", "omega", "Z0", "time"]:
        ret[k] = awkward.to_numpy(track_arr_trackStates[states_coll + "." + k][trackstate_idx])

    ret["pt"] = awkward.to_numpy(track_pt(ret["omega"]))
    ret["px"] = np.cos(ret["phi"]) * ret["pt"]
    ret["py"] = np.sin(ret["phi"]) * ret["pt"]
    ret["pz"] = ret["pt"] * ret["tanLambda"]
    ret["p"] = np.sqrt(ret["px"] ** 2 + ret["py"] ** 2 + ret["pz"] ** 2)

    # `where=` without `out=` leaves the masked entries uninitialized, so provide an
    # explicit output: p == 0 gives cos_theta = 0, i.e. theta = pi/2 and eta = 0.
    cos_theta = np.divide(ret["pz"], ret["p"], out=np.zeros_like(ret["p"]), where=ret["p"] > 0)
    theta = np.arccos(cos_theta)
    tt = np.tan(theta / 2.0)
    eta = np.zeros_like(tt)
    valid = tt > 0
    eta[valid] = -np.log(tt[valid])
    ret["eta"] = eta

    ret["sin_phi"] = np.sin(ret["phi"])
    ret["cos_phi"] = np.cos(ret["phi"])

    # track is always type 1
    ret["elemtype"] = 1 * np.ones(num_tracks, dtype=np.float32)

    # charge from the sign of the curvature (0 stays 0)
    ret["q"] = np.sign(ret["omega"])

    return awkward.Record(ret)


def helix_eq(charge, bfield, v, n_points=10, path_length=2.0):
    """Calculate the 3D helical trajectory of a charged particle in a magnetic field.

    The trajectory starts at the origin and is sampled at ``n_points`` equally spaced
    times covering a path length of ``path_length`` metres. A neutral particle (or a
    zero field) has no curvature and is propagated along a straight line instead.

    Args:
        charge (float): The charge of the particle in elementary charge units.
        bfield (float): The magnetic field strength (along z) in Tesla.
        v (vector): A momentum 4-vector object (GeV) providing pt, pz, phi, beta, gamma,
            mass, px, py and mag.
        n_points (int): Number of sampled points (default 10).
        path_length (float): Path length covered by the trajectory in metres (default 2.0).

    Returns:
        tuple: A tuple of three lists (x, y, z), each of length ``n_points``, holding the
        trajectory coordinates in metres multiplied by the module-level ``scale``.
    """
    if charge == 0 or bfield == 0:
        # No curvature: move straight along the momentum direction.
        dist = np.linspace(0, path_length, n_points)
        x = list(scale * dist * v.px / v.mag)
        y = list(scale * dist * v.py / v.mag)
        z = list(scale * dist * v.pz / v.mag)
        return x, y, z

    # Helix radius in metres: R = pT / (0.3 q B) for pT in GeV and B in T.
    R = v.pt / (charge * 0.3 * bfield)
    # Angular frequency per unit of c*t (1/m); its sign sets the rotation sense.
    omega = charge * 0.3 * bfield / (v.gamma * v.mass)
    # s = c * t in metres; the particle travels beta * s, i.e. path_length at the end.
    s = np.linspace(0, path_length / v.beta, n_points)
    phase0 = v.phi - np.pi / 2
    x = list(scale * R * np.cos(omega * s + phase0) - scale * R * np.cos(phase0))
    y = list(scale * R * np.sin(omega * s + phase0) - scale * R * np.sin(phase0))
    z = list(scale * v.pz * s / (v.gamma * v.mass))
    return x, y, z

## Load events from ROOT file

In [None]:
# Open the reconstructed CLD sample (site-specific path; adjust to your setup).
root_files_dir = Path("/mnt/ceph/users/ewulff/data/cld/Dec3/subfolder_0/")
root_file = root_files_dir / "reco_p8_ee_tt_ecm365_10000.root"
# `fi` stays open: later cells also read its "podio_metadata" tree.
fi = uproot.open(root_file)
ev = fi["events"]  # event tree; every collection below is read from it

# which event to pick from the file (used by all following cells)
iev = 2

## Particle IDs and color coding

In [None]:
# Display category for each |PDG id| that survives the MC-particle remapping
# (all other particles are folded into 211 = ch. hadron / 130 = n. hadron).
pdg_dict = dict(
    zip(
        (22, 11, 13, 130, 211),
        ("photon", "electron", "muon", "n. hadron", "ch. hadron"),
    )
)

# Line colour per category; the None key colours the separator entries placed
# between consecutive trajectories.
color_dict = dict(zip(pdg_dict.values(), ("red", "green", "purple", "orange", "blue")))
color_dict[None] = "black"

# Raw hits, tracks and clusters

In [None]:
# Map podio collection names <-> collection IDs; hit references stored in tracks and
# clusters carry a collectionID that is resolved through collectionIDs_reverse.
podio_metadata = fi.get("podio_metadata")
id_table_names = podio_metadata.arrays("events___idTable/m_names")["events___idTable/m_names"][0]
id_table_ids = podio_metadata.arrays("events___idTable/m_collectionIDs")["events___idTable/m_collectionIDs"][0]
collectionIDs = dict(zip(id_table_names, id_table_ids))

collectionIDs_reverse = {coll_id: name for name, coll_id in collectionIDs.items()}

In [None]:
# Debug helper: print the event-tree branches belonging to VertexBarrelCollection.
for branch_name in filter(lambda name: "VertexBarrelCollection" in name, ev.keys()):
    print(branch_name)

In [None]:
# Hit arrays of every event for each tracker / calorimeter collection.
tracker_hit_data = {
    name: ev[name].array()
    for name in (
        "VXDTrackerHits",
        "VXDEndcapTrackerHits",
        "ITrackerHits",
        "OTrackerHits",
        # these need to be added to the keep statements of the next generation
        # "ITrackerEndcapHits",
        # "OTrackerEndcapHits",
    )
}

calo_hit_data = {
    name: ev[name].array()
    for name in ("ECALBarrel", "ECALEndcap", "HCALBarrel", "HCALEndcap", "HCALOther", "MUON")
}

tracker_and_calo_hit_data = dict(tracker_hit_data)
tracker_and_calo_hit_data.update(calo_hit_data)

In [None]:
# tracks
# borrowed from https://github.com/jpata/particleflow/blob/main/mlpf/data/key4hep/postprocessing.py
track_coll = "SiTracks_Refitted"
# feature ordering used by the MLPF postprocessing (kept for reference)
track_feature_order = [
    "elemtype",
    "pt",
    "eta",
    "sin_phi",
    "cos_phi",
    "p",
    "chi2",
    "ndf",
    "dEdx",
    "dEdxError",
    "radiusOfInnermostHit",
    "tanLambda",
    "D0",
    "omega",
    "Z0",
    "time",
]
# track_to_features indexes these arrays with `iev`, so all events are loaded here.
track_prop_data = ev.arrays([track_coll, "_SiTracks_Refitted_trackStates", "SiTracks_Refitted_dQdx"])
track_features = track_to_features(track_prop_data, iev, track_coll)

# clusters: read only entry `iev` of each branch instead of every event in the file
cluster_x = ev["PandoraClusters/PandoraClusters.position.x"].array(entry_start=iev, entry_stop=iev + 1)[0]
cluster_y = ev["PandoraClusters/PandoraClusters.position.y"].array(entry_start=iev, entry_stop=iev + 1)[0]
cluster_z = ev["PandoraClusters/PandoraClusters.position.z"].array(entry_start=iev, entry_stop=iev + 1)[0]
cluster_energy = ev["PandoraClusters/PandoraClusters.energy"].array(entry_start=iev, entry_stop=iev + 1)[0]

In [None]:
def _read_event_branch(branch):
    """Return the content of ``branch`` for event ``iev`` only (a single entry is read)."""
    return ev[branch].array(entry_start=iev, entry_stop=iev + 1)[0]


# hit-to-track associations
hit_beg = _read_event_branch("SiTracks_Refitted/SiTracks_Refitted.trackerHits_begin")
hit_end = _read_event_branch("SiTracks_Refitted/SiTracks_Refitted.trackerHits_end")
trk_hit_idx = _read_event_branch("_SiTracks_Refitted_trackerHits/_SiTracks_Refitted_trackerHits.index")
trk_hit_coll = _read_event_branch("_SiTracks_Refitted_trackerHits/_SiTracks_Refitted_trackerHits.collectionID")

# one entry per hit of each loaded tracker collection; -1 = not used by any track
hit_to_track = {
    k: np.full(len(hits[hits.fields[0]][iev]), -1, dtype=np.int32) for k, hits in tracker_hit_data.items()
}

# Track i owns the hit references [hit_beg[i], hit_end[i]); each reference is a
# (collectionID, index) pair. References into collections that were not loaded
# (e.g. the tracker endcaps) or with unknown IDs are skipped.
for itrk in range(len(track_features["pt"])):
    for ihit in range(hit_beg[itrk], hit_end[itrk]):
        coll = collectionIDs_reverse.get(trk_hit_coll[ihit])
        if coll in hit_to_track:
            hit_to_track[coll][trk_hit_idx[ihit]] = itrk

# hit-to-cluster associations
hit_beg = _read_event_branch("PandoraClusters/PandoraClusters.hits_begin")
hit_end = _read_event_branch("PandoraClusters/PandoraClusters.hits_end")
cls_hit_idx = _read_event_branch("_PandoraClusters_hits/_PandoraClusters_hits.index")
cls_hit_coll = _read_event_branch("_PandoraClusters_hits/_PandoraClusters_hits.collectionID")

# one entry per hit of each loaded calorimeter collection; -1 = not in any cluster
hit_to_cls = {
    k: np.full(len(hits[hits.fields[0]][iev]), -1, dtype=np.int32) for k, hits in calo_hit_data.items()
}

# Same layout as for tracks. The guard mirrors the tracker loop so that a cluster
# referencing a collection not present in calo_hit_data does not raise KeyError.
for icls in range(len(cluster_energy)):
    for ihit in range(hit_beg[icls], hit_end[icls]):
        coll = collectionIDs_reverse.get(cls_hit_coll[ihit])
        if coll in hit_to_cls:
            hit_to_cls[coll][cls_hit_idx[ihit]] = icls

In [None]:
# Merge tracker and calorimeter hits into one per-hit feature table. reco_idx is the
# index of the track (tracker hits) or cluster (calorimeter hits) a hit belongs to.
hit_feats = ["position.x", "position.y", "position.z", "energy", "type"]
hit_feature_matrix = []
for hit_source, reco_lookup in ((tracker_hit_data, hit_to_track), (calo_hit_data, hit_to_cls)):
    for col in sorted(hit_source.keys()):
        hit_features = hits_to_features(hit_source[col], iev, col, hit_feats)
        hit_features["reco_idx"] = reco_lookup[col]
        hit_feature_matrix.append(hit_features)

hit_feature_matrix = awkward.Array(
    {
        field: awkward.concatenate([part[field] for part in hit_feature_matrix])
        for field in hit_feature_matrix[0].fields
    }
)

# Flat pandas table for plotting; energies are multiplied by 1000 before sizing markers.
df = pandas.DataFrame(
    {
        "px": hit_feature_matrix["position.x"].to_numpy(),
        "py": hit_feature_matrix["position.y"].to_numpy(),
        "pz": hit_feature_matrix["position.z"].to_numpy(),
        "energy": 1000 * hit_feature_matrix["energy"].to_numpy(),
        "plotsize": 0.0,
        "subdetector": hit_feature_matrix["subdetector"].to_numpy(),
    }
)

# Marker size scale per subdetector: calorimeters are shrunk, muon/tracker hits enlarged.
for sd_code, divisor in ((0, 5.0), (1, 10.0)):
    selected = df["subdetector"] == sd_code
    df.loc[selected, "plotsize"] = df.loc[selected, "energy"] / divisor
for sd_code in (2, 3):
    selected = df["subdetector"] == sd_code
    df.loc[selected, "plotsize"] = df.loc[selected, "energy"] * 100.0

In [None]:
# Legend label and marker colour per subdetector code produced by hits_to_features:
# 0 = ECAL, 1 = HCAL, 2 = MUON, 3 = tracker.
labels = dict(
    enumerate(
        (
            "Raw ECAL hit",
            "Raw HCAL hit",
            "Raw Muon chamber hit",
            "Raw tracker hit",
        )
    )
)

subdetector_color = dict(enumerate(("steelblue", "green", "orange", "red")))

In [None]:
# tracks: extrapolate every reconstructed track as a helix in the field B and
# concatenate all trajectories into flat coordinate lists, with a None after each
# track so that plotly draws the tracks as separate line segments.
track_px, track_py, track_pz, track_charge = (
    track_features["px"],
    track_features["py"],
    track_features["pz"],
    track_features["q"],
)

# pion mass hypothesis for every track
track_mass = np.zeros_like(track_px) + pion_mass

track_x, track_y, track_z = [], [], []
for irow in range(len(track_px)):

    # convert to vector
    v = vector.obj(px=track_px[irow], py=track_py[irow], pz=track_pz[irow], mass=track_mass[irow])

    x, y, z = helix_eq(track_charge[irow], B, v)
    track_x += x
    track_y += y
    track_z += z

    # separator between consecutive tracks
    track_x += [None]
    track_y += [None]
    track_z += [None]

In [None]:
traces = []

# raw hits: one marker trace per subdetector code (see `labels`)
for subdetector in [0, 1, 2, 3]:

    # Select this subdetector's hits once so that every marker's size belongs to the
    # same hit as its position (the size array must be filtered like x/y/z).
    hits = df[df["subdetector"] == subdetector]
    # plotsize can be 0, giving log(0) = -inf; the clip maps that to the minimum size 1.
    with np.errstate(divide="ignore"):
        marker_size = np.clip(2 + 2 * np.log(hits["plotsize"]), 1, 15)

    trace = go.Scatter3d(
        x=np.clip(hits["px"], -4000, 4000),
        y=np.clip(hits["py"], -4000, 4000),
        z=np.clip(hits["pz"], -4000, 4000),
        mode="markers",
        marker=dict(
            size=marker_size,
            color=subdetector_color[subdetector],
            colorscale="Viridis",
            opacity=0.8,
        ),
        name=labels[subdetector],
    )

    traces.append(trace)

# add the tracks (None entries in track_x/y/z split the line into separate tracks)
trace = go.Scatter3d(
    x=np.array(track_x),
    y=np.array(track_y),
    z=np.array(track_z),
    mode="lines",
    name="Track",
    line=dict(color="red"),
)
traces.append(trace)

# add the clusters; the marker size is set directly from the cluster energy
trace = go.Scatter3d(
    x=cluster_x,
    y=cluster_y,
    z=cluster_z,
    mode="markers",
    marker=dict(
        size=cluster_energy,
        color="blue",
        opacity=0.8,
    ),
    name="ECAL/HCAL clusters",
)
traces.append(trace)

# Customize the axis names
layout = go.Layout(
    scene=dict(
        xaxis=dict(title="", showticklabels=False),
        yaxis=dict(title="", showticklabels=False),
        zaxis=dict(title="", showticklabels=False),
        camera=dict(
            up=dict(x=1, y=0, z=0),  # Sets the orientation of the camera
            center=dict(x=0, y=0, z=0),  # Sets the center point of the plot
            eye=dict(x=0, y=0, z=2.0),  # Sets the position of the camera
        ),
    ),
    legend=dict(x=0.75, y=0.5, font=dict(size=20)),  # https://plotly.com/python/legend/
    showlegend=False,
    width=700,
    height=700,
)

# Create the figure and display the plot
fig = go.Figure(data=traces, layout=layout)

fig.update_traces(
    marker_line_width=0, selector=dict(type="scatter3d")
)  # for plotly to avoid plotting white spots when things overlap
# fig.write_image("pic_tracks_rawcalohits.pdf", width=1000, height=1000, scale=2)

# Particles

In [None]:
# pick final-state particles from pythia (generatorStatus == 1)
def _read_mc_branch(field):
    """Return MCParticles.<field> for event ``iev`` only (a single entry is read)."""
    return ev["MCParticles/MCParticles." + field].array(entry_start=iev, entry_stop=iev + 1)[0]


msk_gen = _read_mc_branch("generatorStatus") == 1
gen_px = _read_mc_branch("momentum.x")[msk_gen]
gen_py = _read_mc_branch("momentum.y")[msk_gen]
gen_pz = _read_mc_branch("momentum.z")[msk_gen]
gen_mass = _read_mc_branch("mass")[msk_gen]
gen_charge = _read_mc_branch("charge")[msk_gen]
gen_pdg = awkward.to_numpy(np.absolute(_read_mc_branch("PDG")[msk_gen]))

# Set all other particles to ch. hadron (211) or n. hadron (130). Any non-zero charge
# counts as charged, so every particle lands in a category present in pdg_dict and
# the lookup below cannot fail (e.g. for |charge| == 2).
# NOTE(review): neutrinos (|PDG| 12/14/16) are neutral and are therefore drawn as
# "n. hadron" - confirm whether they should be dropped instead.
is_charged = awkward.to_numpy(gen_charge) != 0
is_other = ~np.isin(gen_pdg, (11, 13, 22))
gen_pdg[is_other & is_charged] = 211
gen_pdg[is_other & ~is_charged] = 130

# Extrapolate MC particle trajectories: straight lines of length `scale` for neutral
# particles, helices (helix_eq) for charged ones. Each trajectory is followed by a None
# separator in the coordinate lists and in pdg_list.
mc_x, mc_y, mc_z, pdg_list = [], [], [], []
for irow in range(len(gen_px)):

    # convert to vector
    v = vector.obj(px=gen_px[irow], py=gen_py[irow], pz=gen_pz[irow], mass=gen_mass[irow])
    if gen_charge[irow] == 0:
        this_mc_x = [0, np.clip(scale * v.px / v.mag, -4000, 4000)]
        this_mc_y = [0, np.clip(scale * v.py / v.mag, -4000, 4000)]
        this_mc_z = [0, np.clip(scale * v.pz / v.mag, -4000, 4000)]
    else:
        this_mc_x, this_mc_y, this_mc_z = helix_eq(gen_charge[irow], B, v)

    pdg_list += len(this_mc_x) * [pdg_dict[gen_pdg[irow]]] + [None]
    mc_x += this_mc_x + [None]
    mc_y += this_mc_y + [None]
    mc_z += this_mc_z + [None]

In [None]:
# Bundle the per-particle generator arrays into a single awkward record array.
gen_feats = ["px", "py", "pz", "mass", "charge", "pdg"]
gen_columns = (gen_px, gen_py, gen_pz, gen_mass, gen_charge, gen_pdg)
gen_feature_matrix = awkward.Array(
    {name: awkward.Array(column) for name, column in zip(gen_feats, gen_columns)}
)

In [None]:
# Create 3D scatter plot of the particles: a single line trace whose points are coloured
# by particle category; the None separators break the line between particles.

trace = go.Scatter3d(
    x=np.array(mc_x),
    y=np.array(mc_y),
    z=np.array(mc_z),
    mode="lines",
    line=dict(color=[color_dict[p] for p in pdg_list]),
    text=[f"Particle: {p}" for p in pdg_list],  # Add particle labels
    hoverinfo="text",  # Display text on hover
)

# Customize the axis names
layout = go.Layout(
    scene=dict(
        xaxis=dict(title="", showticklabels=False),
        yaxis=dict(title="", showticklabels=False),
        zaxis=dict(title="", showticklabels=False),
        camera=dict(
            up=dict(x=1, y=0, z=0),  # Sets the orientation of the camera
            center=dict(x=0, y=0, z=0),  # Sets the center point of the plot
            eye=dict(x=0, y=0, z=2.0),  # Sets the position of the camera
        ),
    ),
    showlegend=False,
    width=700,
    height=700,
)

# Create the figure and display the plot (update_traces returns the figure, which the
# notebook renders as the cell output)
fig = go.Figure(data=trace, layout=layout)
fig.update_traces(
    marker_line_width=0, selector=dict(type="scatter3d")
)  # for plotly to avoid plotting white spots when things overlap
# fig.write_image("pic_particles.pdf", width=1000, height=1000, scale=2)

## Particles with legend

In [None]:
# Create 3D scatter plot with one trace (and legend entry) per particle type
traces = []
# Legend entries follow the fixed pdg_dict order. The None separators in pdg_list are
# not a particle type and must not produce an (empty) legend entry of their own.
present = set(pdg_list)
unique_particles = [name for name in pdg_dict.values() if name in present]

# Object arrays (they contain the None separators), built once for all traces.
mc_x_arr, mc_y_arr, mc_z_arr = np.array(mc_x), np.array(mc_y), np.array(mc_z)

for particle in unique_particles:

    # Points of this particle type plus the None that terminates each of its
    # trajectories, so that separate tracks are not joined by a line in the plot.
    indices = [
        i
        for i, p in enumerate(pdg_list)
        if p == particle or (p is None and i > 0 and pdg_list[i - 1] == particle)
    ]

    # Create a separate trace for each particle
    traces.append(
        go.Scatter3d(
            x=mc_x_arr[indices],
            y=mc_y_arr[indices],
            z=mc_z_arr[indices],
            mode="lines",
            line=dict(color=color_dict[particle]),  # Assign color for the particle
            name=f"{particle}",  # Add particle name to the legend
            showlegend=True,
        )
    )

# Customize the axis names
layout = go.Layout(
    scene=dict(
        xaxis=dict(title="", showticklabels=False),
        yaxis=dict(title="", showticklabels=False),
        zaxis=dict(title="", showticklabels=False),
        camera=dict(
            up=dict(x=1, y=0, z=0),  # Sets the orientation of the camera
            center=dict(x=0, y=0, z=0),  # Sets the center point of the plot
            eye=dict(x=0, y=0, z=2.0),  # Sets the position of the camera
        ),
    ),
    showlegend=True,
    width=700,
    height=700,
)

# Create the figure and display the plot
fig = go.Figure(data=traces, layout=layout)
fig.update_traces(
    marker_line_width=0, selector=dict(type="scatter3d")
)  # Avoid plotting white spots when things overlap
# fig.write_image("pic_particles_legend.pdf", width=1000, height=1000, scale=2)

## Sanity checks

In [None]:
# check the alignment of extrapolated tracks and their associated hits (x-z view)
n_show = min(9, len(track_px))  # 3x3 grid, but the event may contain fewer tracks
fig, axes = plt.subplots(3, 3, figsize=(6, 6))  # Adjust figsize as needed
axes = axes.ravel()
for itrk in range(n_show):
    ax = axes[itrk]
    v = vector.obj(px=track_px[itrk], py=track_py[itrk], pz=track_pz[itrk], mass=track_mass[itrk])
    # tracker hits (subdetector 3) assigned to this track
    hs = hit_feature_matrix[(hit_feature_matrix["subdetector"] == 3) & (hit_feature_matrix["reco_idx"] == itrk)]
    x, y, z = helix_eq(track_charge[itrk], B, v)
    ax.plot(x, z, label="Track")
    ax.scatter(hs["position.x"], hs["position.z"], color="red", marker=".", label="Tracker hit")
    ax.set_xlim(-2000, 2000)
    ax.set_ylim(-2000, 2000)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_box_aspect(1)
# hide panels without a track
for ax in axes[n_show:]:
    ax.set_visible(False)
if n_show > 0:
    axes[min(2, n_show - 1)].legend(loc="upper right")
plt.tight_layout()

In [None]:
# check the alignment of clusters and their associated hits (x-y view)
n_show = min(9, len(cluster_energy))  # 3x3 grid, but the event may contain fewer clusters
fig, axes = plt.subplots(3, 3, figsize=(6, 6))  # Adjust figsize as needed
axes = axes.ravel()
for icls in range(n_show):
    ax = axes[icls]
    # calorimeter / muon hits (subdetector != 3) assigned to this cluster
    hs = hit_feature_matrix[(hit_feature_matrix["subdetector"] != 3) & (hit_feature_matrix["reco_idx"] == icls)]
    ax.scatter(
        cluster_x[icls],
        cluster_y[icls],
        s=100 * cluster_energy[icls],
        alpha=0.5,
        label="cluster",
    )
    ax.scatter(hs["position.x"], hs["position.y"], color="red", marker=".", label="hit")
    ax.set_xlim(-4000, 4000)
    ax.set_ylim(-4000, 4000)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_box_aspect(1)
# hide panels without a cluster
for ax in axes[n_show:]:
    ax.set_visible(False)
if n_show > 0:
    axes[min(2, n_show - 1)].legend(loc="upper right")
plt.tight_layout()