# Data processing for CLD

## Imports

In [None]:
from pathlib import Path
import uproot
import numpy as np
import pandas
import awkward as ak
from tqdm import tqdm
import pickle
import os
import lmdb

import plotly.graph_objects as go
import vector

import matplotlib.pyplot as plt
from scipy.sparse import coo_matrix

## Function definitions
These functions are defined in `data_processing/cld_processing` and are imported from there below.

In [None]:
# Constants
pion_mass = 0.13957  # charged-pion mass hypothesis assigned to tracks below; presumably GeV — TODO confirm units
B = -2.0  # magnetic field in T (-2 for CLD FCC-ee)
c = 3e8  # speed of light in m/s
scale = 1000  # multiplier applied to helix coordinates in helix_eq; presumably m -> mm to match hit positions — TODO confirm

# append path
# Make the particlemind repository importable so `data_processing.cld_processing` resolves below.
# NOTE(review): absolute, user-specific cluster path; adjust when running elsewhere.
import sys

sys.path.append(str(Path("/mnt/ceph/users/ewulff/particlemind")))

from data_processing.cld_processing import (
    get_event_data,
    gen_to_features,
    track_to_features,
    cluster_to_features,
    process_calo_hit_data,
    process_tracker_hit_data,
    create_track_to_hit_coo_matrix,
    create_cluster_to_hit_coo_matrix,
    genparticle_track_adj,
    create_genparticle_to_genparticle_coo_matrix,
    create_genparticle_to_genparticle_coo_matrix2,
)

## Main

Data to extract:
- [x] hits (all tracker and calo hit features in the event)
- [x] genparticles (all genparticle features in the event, noting that the genparticles are in the form of a decay tree)
- [x] genparticle_to_genparticle (sparse association matrix for the genparticle decay tree)
- [x] hits_to_genparticles (sparse association matrix)
- [x] tracks (all track features in the event)
- [x] tracks_to_hits (sparse association matrix between tracks and tracker hits)
- [x] clusters (all cluster features in the event)
- [x] clusters_to_hits (sparse association matrix between clusters and calo hits)

Additional data to extract:
- [x] genparticle_to_track (sparse association matrix for the genparticle to track associations)

In [None]:
# Open a single reconstructed CLD ROOT file and its "events" tree
root_files_dir = Path("/mnt/ceph/users/ewulff/data/cld/Dec3/subfolder_0/")
root_file = root_files_dir / "reco_p8_ee_tt_ecm365_10000.root"
fi = uproot.open(root_file)
ev = fi["events"]

# which event to pick from the file
iev = 2

# Map podio collection names to their numeric collection IDs (taken from the first entry of the
# podio_metadata tree); the hit/association helpers below need this mapping.
collectionIDs = {
    k: v
    for k, v in zip(
        fi.get("podio_metadata").arrays("events___idTable/m_names")["events___idTable/m_names"][0],
        fi.get("podio_metadata").arrays("events___idTable/m_collectionIDs")["events___idTable/m_collectionIDs"][0],
    )
}

event_data = get_event_data(ev)

# Extract genparticle, track, and cluster features
gen_features = gen_to_features(event_data, iev)
track_features = track_to_features(event_data, iev)
cluster_features = cluster_to_features(
    event_data, iev, cluster_features=["position.x", "position.y", "position.z", "energy", "type"]
)

# Process calorimeter hit data
calo_hit_features, genparticle_to_calo_hit_matrix, calo_hit_idx_local_to_global = process_calo_hit_data(
    event_data, iev, collectionIDs
)

# Process tracker hit data
tracker_hit_features, genparticle_to_tracker_hit_matrix, tracker_hit_idx_local_to_global = process_tracker_hit_data(
    event_data, iev, collectionIDs
)

# Create the track-to-trackerhit adjacency matrix
track_to_tracker_hit_matrix, tracker_hit_idx_local_to_global_2 = create_track_to_hit_coo_matrix(
    event_data, iev, collectionIDs
)
# Both helpers must index tracker hits the same way for the matrices to be combinable.
# NOTE(review): assumes the mappings compare with `==` to a single bool (e.g. dicts/lists) — TODO confirm
assert (
    tracker_hit_idx_local_to_global == tracker_hit_idx_local_to_global_2
), "Local to global tracker hit index mapping mismatch!"

# Create the cluster-to-clusterhit adjacency matrix
cluster_to_cluster_hit_matrix, calo_hit_idx_local_to_global_2 = create_cluster_to_hit_coo_matrix(
    event_data, iev, collectionIDs
)
# Same consistency requirement for the calorimeter-hit indexing
assert (
    calo_hit_idx_local_to_global == calo_hit_idx_local_to_global_2
), "Local to global calorimeter hit index mapping mismatch!"

# Create the genparticle-to-track adjacency matrix
gp_to_track_matrix = genparticle_track_adj(event_data, iev)

# Create the genparticle-to-genparticle adjacency matrix
gp_to_gp = create_genparticle_to_genparticle_coo_matrix(event_data, iev)

# Build the same matrix with the alternative implementation to cross-check the two gp-to-gp methods
gp_to_gp2 = create_genparticle_to_genparticle_coo_matrix2(event_data, iev)

# Get the number of genparticles in the event
# (reads the whole MCParticles.momentum.x branch only to take this event's length)
n_gp = len(ev["MCParticles.momentum.x"].array()[iev])

# create coo matrices (rows: parent index, cols: daughter index) from the COOs in gp_to_gp and gp_to_gp2
coo_matrix_gp_to_gp = coo_matrix(
    (gp_to_gp["weight"], (gp_to_gp["parent_idx"], gp_to_gp["daughter_idx"])), shape=(n_gp, n_gp)
)
coo_matrix_gp_to_gp2 = coo_matrix(
    (gp_to_gp2["weight"], (gp_to_gp2["parent_idx"], gp_to_gp2["daughter_idx"])), shape=(n_gp, n_gp)
)
# Check if the two dense gp_to_gp matrices are equal
assert (coo_matrix_gp_to_gp.todense() == coo_matrix_gp_to_gp2.todense()).all()
# Define the output file path
# NOTE(review): output_file is not used in this cell; the save functions further down write Parquet/LMDB instead.
output_file = Path("extracted_features.hdf5")

In [None]:
# Inspect the container types returned by the feature extractors
type(gen_features), type(track_features), type(cluster_features)

In [None]:
# Inspect the type of a single feature column from each extractor
type(gen_features["px"]), type(track_features["px"]), type(cluster_features["position.x"])

In [None]:
# Inspect the container types of the calorimeter and tracker hit features
type(calo_hit_features), type(tracker_hit_features)

In [None]:
# Inspect the type of the hit-energy column for both hit collections
type(calo_hit_features["energy"]), type(tracker_hit_features["energy"])

In [None]:
# Inspect the container types of the COO association matrices
(
    type(genparticle_to_calo_hit_matrix),
    type(genparticle_to_tracker_hit_matrix),
    type(track_to_tracker_hit_matrix),
    type(cluster_to_cluster_hit_matrix),
    type(gp_to_gp),
)

In [None]:
# Inspect the type of one index column of each COO association matrix
(
    type(genparticle_to_calo_hit_matrix["hit_idx"]),
    type(genparticle_to_tracker_hit_matrix["gen_idx"]),
    type(track_to_tracker_hit_matrix["track_idx"]),
    type(cluster_to_cluster_hit_matrix["cluster_idx"]),
    type(gp_to_gp["parent_idx"]),
)

## Plots and sanity checks

In [None]:
# Combine data from calo_hit_features and tracker_hit_features into a single dataframe
# (row order: all calorimeter hits first, then all tracker hits)
df = pandas.DataFrame()

# Extract data from calo_hit_features and tracker_hit_features using numpy.concatenate
df["px"] = np.concatenate([calo_hit_features["position.x"], tracker_hit_features["position.x"]])
df["py"] = np.concatenate([calo_hit_features["position.y"], tracker_hit_features["position.y"]])
df["pz"] = np.concatenate([calo_hit_features["position.z"], tracker_hit_features["position.z"]])
# Energies scaled by 1000 (presumably GeV -> MeV — TODO confirm units)
df["energy"] = np.concatenate([1000 * calo_hit_features["energy"], 1000 * tracker_hit_features["energy"]])
df["plotsize"] = 0.0
df["subdetector"] = np.concatenate([calo_hit_features["subdetector"], tracker_hit_features["subdetector"]])

# Calculate plotsize based on subdetector values: per-subdetector factors bring the very
# different energy ranges to comparable marker sizes in the 3D event display
df.loc[df["subdetector"] == 0, "plotsize"] = df.loc[df["subdetector"] == 0, "energy"] / 5.0
df.loc[df["subdetector"] == 1, "plotsize"] = df.loc[df["subdetector"] == 1, "energy"] / 10.0
df.loc[df["subdetector"] == 2, "plotsize"] = df.loc[df["subdetector"] == 2, "energy"] * 100.0
df.loc[df["subdetector"] == 3, "plotsize"] = df.loc[df["subdetector"] == 3, "energy"] * 100.0

In [None]:
# tracks
def helix_eq(charge, bfield, v):
    """Sample the helical path of a charged particle in a uniform magnetic field.

    The path starts at the origin and is evaluated at 10 evenly spaced times in
    ``[0, 2 / (c * v.beta)]``. Coordinates are multiplied by the module-level
    ``scale`` factor; ``c`` is the module-level speed of light.

    Args:
        charge (float): Particle charge in elementary-charge units. Must be non-zero,
            otherwise the radius is a division by zero.
        bfield (float): Signed magnetic field strength in Tesla.
        v (vector): ``vector`` object exposing ``pt``, ``pz``, ``phi``, ``beta``,
            ``gamma`` and ``mass``.

    Returns:
        tuple: Three lists ``(x, y, z)`` with the sampled coordinates.
    """
    # 0.3 * q * B: presumably the usual pT[GeV] = 0.3 q B[T] R[m] conversion factor
    qb = charge * 0.3 * bfield
    radius = v.pt / qb
    ang_freq = qb / (v.gamma * v.mass)

    times = np.linspace(0, 2 / (c * v.beta), 10)
    # Phase along the circle; the -pi/2 offset places the circle centre perpendicular to phi
    phase = ang_freq * c * times + v.phi - np.pi / 2
    start = v.phi - np.pi / 2

    # Subtract the t=0 point so every trajectory starts at the origin
    xs = list(scale * radius * np.cos(phase) - scale * radius * np.cos(start))
    ys = list(scale * radius * np.sin(phase) - scale * radius * np.sin(start))
    zs = list(scale * v.pz * c * times / (v.gamma * v.mass))
    return xs, ys, zs


# Unpack the reconstructed-track kinematics and charge
track_px, track_py, track_pz, track_charge = (
    track_features["px"],
    track_features["py"],
    track_features["pz"],
    track_features["q"],
)

# Every track is given the charged-pion mass hypothesis
track_mass = np.zeros_like(track_px) + pion_mass

# Concatenate the helix samples of all tracks into flat coordinate lists for one plotly line trace
track_x, track_y, track_z = [], [], []
for irow in range(len(track_px)):

    # convert to vector
    v = vector.obj(px=track_px[irow], py=track_py[irow], pz=track_pz[irow], mass=track_mass[irow])

    x, y, z = helix_eq(track_charge[irow], B, v)
    track_x += x
    track_y += y
    track_z += z

    track_x += [None]  # Add None to separate tracks in the plot
    track_y += [None]
    track_z += [None]

In [None]:
# Legend label and marker colour for each hit "subdetector" code used in df
labels = {0: "Raw ECAL hit", 1: "Raw HCAL hit", 2: "Raw Muon chamber hit", 3: "Raw tracker hit"}

subdetector_color = {0: "steelblue", 1: "green", 2: "orange", 3: "red"}  # ECAL  # HCAL  # MUON  # Tracker

In [None]:
traces = []

# raw hits: one trace per subdetector
for subdetector in [0, 1, 2, 3]:
    # Select this subdetector's rows once so that positions and marker sizes stay aligned.
    # Plotly pairs marker.size entries with points by position, so the size array must be
    # restricted to the same rows as x/y/z (it previously spanned every hit in df).
    sub_mask = df["subdetector"] == subdetector
    sub_df = df[sub_mask]

    trace = go.Scatter3d(
        x=np.clip(sub_df["px"], -4000, 4000),
        y=np.clip(sub_df["py"], -4000, 4000),
        z=np.clip(sub_df["pz"], -4000, 4000),
        mode="markers",
        marker=dict(
            # log-scaled marker size; a zero plotsize gives log(0) = -inf, which is clipped to 1
            size=np.clip(2 + 2 * np.log(sub_df["plotsize"]), 1, 15),
            color=subdetector_color[subdetector],
            colorscale="Viridis",
            opacity=0.8,
        ),
        name=labels[subdetector],
    )
    traces.append(trace)

# add the tracks (single line trace; None entries break the line between tracks)
trace = go.Scatter3d(
    x=np.array(track_x),
    y=np.array(track_y),
    z=np.array(track_z),
    mode="lines",
    name="Track",
    line=dict(color="red"),
)
traces.append(trace)

# add the clusters (marker size proportional to cluster energy)
trace = go.Scatter3d(
    x=cluster_features["position.x"],
    y=cluster_features["position.y"],
    z=cluster_features["position.z"],
    mode="markers",
    marker=dict(
        size=cluster_features["energy"],
        color="blue",
        opacity=0.8,
    ),
    name="ECAL/HCAL clusters",
)
traces.append(trace)

# Customize the axis names
layout = go.Layout(
    scene=dict(
        xaxis=dict(title="", showticklabels=False),
        yaxis=dict(title="", showticklabels=False),
        zaxis=dict(title="", showticklabels=False),
        camera=dict(
            up=dict(x=1, y=0, z=0),  # Sets the orientation of the camera
            center=dict(x=0, y=0, z=0),  # Sets the center point of the plot
            eye=dict(x=0, y=0, z=2.0),  # Sets the position of the camera
        ),
    ),
    legend=dict(x=0.8, y=0.5, font=dict(size=16)),  # https://plotly.com/python/legend/
    showlegend=True,
    width=700,
    height=700,
)

# Create the figure and display the plot
fig = go.Figure(data=traces, layout=layout)
fig.update_traces(
    marker_line_width=0, selector=dict(type="scatter3d")
)  # for plotly to avoid plotting white spots when things overlap
# fig.write_image("pic_tracks_rawcalohits.pdf", width=1000, height=1000, scale=2)

In [None]:
# Display name for each (absolute) PDG code used after the re-labelling below
pdg_dict = {
    22: "photon",  # photon
    11: "electron",  # electron
    13: "muon",  # muon
    130: "n. hadron",  # neutral hadron
    211: "ch. hadron",  # charged hadron
}

color_dict = {
    "photon": "red",  # photon
    "electron": "green",  # electron
    "muon": "purple",  # muon
    "n. hadron": "orange",  # neutral hadron
    "ch. hadron": "blue",  # charged hadron
    None: "black",  # placeholder for skipped element
}

# Extract relevant data from gen_features
gen_px = gen_features["px"]
gen_py = gen_features["py"]
gen_pz = gen_features["pz"]
gen_mass = gen_features["mass"]
gen_charge = gen_features["charge"]
gen_pdg = ak.to_numpy(np.absolute(gen_features["PDG"]))

# Set all other particles to ch.had or n.had
# (after this, every entry of gen_pdg is one of the keys of pdg_dict)
gen_pdg[(gen_pdg != 13) & (gen_pdg != 11) & (gen_pdg != 22) & ak.to_numpy(np.abs(gen_charge) > 0)] = (
    211  # when not filtering genstatus==1, charge can be between 0 and 1
)
gen_pdg[(gen_pdg != 13) & (gen_pdg != 11) & (gen_pdg != 22) & ak.to_numpy(np.abs(gen_charge) == 0)] = 130

# Extrapolate MC particle trajectories into flat coordinate lists (None separates particles);
# pdg_list holds the particle-type label of every coordinate entry
mc_x = []
mc_y = []
mc_z = []
pdg_list = []
for irow in range(len(gen_px)):

    # Convert to vector
    v = vector.obj(px=gen_px[irow], py=gen_py[irow], pz=gen_pz[irow], mass=gen_mass[irow])
    if gen_charge[irow] == 0:
        # Neutral: straight segment from the origin along the momentum direction, scaled by `scale`
        # (assumes v.mag is the 3-momentum magnitude; a zero-momentum particle would give NaN — TODO confirm)
        this_mc_x = [0, np.clip(scale * v.px / v.mag, -4000, 4000)]
        this_mc_y = [0, np.clip(scale * v.py / v.mag, -4000, 4000)]
        this_mc_z = [0, np.clip(scale * v.pz / v.mag, -4000, 4000)]
    else:
        # Charged: helix in the solenoid field
        x, y, z = helix_eq(gen_charge[irow], B, v)
        this_mc_x = x
        this_mc_y = y
        this_mc_z = z

    pdg_list += len(this_mc_x) * [pdg_dict[gen_pdg[irow]]]

    mc_x += this_mc_x
    mc_y += this_mc_y
    mc_z += this_mc_z

    mc_x += [None]
    mc_y += [None]
    mc_z += [None]
    pdg_list += [None]

In [None]:
# Create 3D scatter plot with one trace per particle type
traces = []
# None only marks the end of a trajectory in pdg_list and is not a particle type. Drop it so no
# empty "None" trace appears in the legend. Sort so the legend order is reproducible, because
# set iteration order of strings changes between Python sessions.
unique_particles = sorted(set(pdg_list) - {None})

# Convert the coordinate lists once instead of once per particle type
mc_x_arr = np.array(mc_x)
mc_y_arr = np.array(mc_y)
mc_z_arr = np.array(mc_z)

for particle in unique_particles:

    # Indices of this particle type's points, plus the None that ends each trajectory so
    # plotly does not join consecutive trajectories with a line. Every trajectory is
    # followed by a None in pdg_list, so i + 1 is always a valid index here.
    indices = []
    for i, p in enumerate(pdg_list):
        if p != particle:
            continue
        indices.append(i)
        if pdg_list[i + 1] is None:
            indices.append(i + 1)

    # Create a separate trace for each particle type
    traces.append(
        go.Scatter3d(
            x=mc_x_arr[indices],
            y=mc_y_arr[indices],
            z=mc_z_arr[indices],
            mode="lines",
            line=dict(color=color_dict[particle]),  # Assign color for the particle
            name=f"{particle}",  # Add particle name to the legend
            showlegend=True,
        )
    )

# Customize the axis names
layout = go.Layout(
    scene=dict(
        xaxis=dict(title="", showticklabels=False),
        yaxis=dict(title="", showticklabels=False),
        zaxis=dict(title="", showticklabels=False),
        camera=dict(
            up=dict(x=1, y=0, z=0),  # Sets the orientation of the camera
            center=dict(x=0, y=0, z=0),  # Sets the center point of the plot
            eye=dict(x=0, y=0, z=2.0),  # Sets the position of the camera
        ),
    ),
    showlegend=True,
    width=700,
    height=700,
)

# Create the figure and display the plot
fig = go.Figure(data=traces, layout=layout)
fig.update_traces(
    marker_line_width=0, selector=dict(type="scatter3d")
)  # Avoid plotting white spots when things overlap
# fig.write_image("pic_particles_legend.pdf", width=1000, height=1000, scale=2)

In [None]:
# check the alignment of extrapolated tracks and their associated hits (x-z projection)
fig, axes = plt.subplots(3, 3, figsize=(6, 6))  # Adjust figsize as needed
axes = axes.ravel()

# One track per panel. The event may contain fewer tracks than panels; in that case the
# remaining panels are left unused instead of raising an IndexError.
n_tracks_to_show = min(len(axes), len(track_px))

for itrk in range(n_tracks_to_show):
    plt.sca(axes[itrk])
    v = vector.obj(px=track_px[itrk], py=track_py[itrk], pz=track_pz[itrk], mass=track_mass[itrk])

    # Get the tracker-hit indices associated with the current track.
    # track_to_tracker_hit_matrix stores COO entries as parallel "track_idx" / "hit_idx" columns.
    hit_indices = track_to_tracker_hit_matrix["hit_idx"][track_to_tracker_hit_matrix["track_idx"] == itrk]

    # Extract the corresponding hit positions
    hs_x = tracker_hit_features["position.x"][hit_indices]
    hs_z = tracker_hit_features["position.z"][hit_indices]

    # Calculate the helix trajectory
    x, y, z = helix_eq(track_charge[itrk], B, v)

    # Plot the track and associated hits in the x-z plane
    plt.plot(x, z, label="Track")
    plt.scatter(hs_x, hs_z, color="red", marker=".", label="Tracker hit")
    plt.xlim(-2000, 2000)
    plt.ylim(-2000, 2000)
    plt.xticks([])
    plt.yticks([])
    axes[itrk].set_box_aspect(1)

axes[2].legend(loc="upper right")
plt.tight_layout()

In [None]:
# Check the alignment of clusters and their associated hits (x-y projection)
fig, axes = plt.subplots(3, 3, figsize=(6, 6))  # Adjust figsize as needed
axes = axes.ravel()

# One cluster per panel. The event may contain fewer clusters than panels; in that case the
# remaining panels are left unused instead of raising an IndexError.
n_clusters_to_show = min(len(axes), len(cluster_features["energy"]))

for icls in range(n_clusters_to_show):
    plt.sca(axes[icls])

    # Get the calo-hit indices associated with the current cluster.
    # cluster_to_cluster_hit_matrix stores COO entries as parallel "cluster_idx" / "hit_idx" columns.
    hit_indices = cluster_to_cluster_hit_matrix["hit_idx"][cluster_to_cluster_hit_matrix["cluster_idx"] == icls]

    # Extract the corresponding hit positions
    hs_x = calo_hit_features["position.x"][hit_indices]
    hs_y = calo_hit_features["position.y"][hit_indices]

    # Plot the cluster (marker area proportional to energy) and associated hits
    plt.scatter(
        cluster_features["position.x"][icls],
        cluster_features["position.y"][icls],
        s=100 * cluster_features["energy"][icls],
        alpha=0.5,
        label="cluster",
    )
    plt.scatter(hs_x, hs_y, color="red", marker=".", label="hit")
    plt.xlim(-4000, 4000)
    plt.ylim(-4000, 4000)
    plt.xticks([])
    plt.yticks([])
    axes[icls].set_box_aspect(1)

axes[2].legend(loc="upper right")
plt.tight_layout()

## Genparticle to calorimeter hit associations

In [None]:
# Build a dense genparticle x calo-hit weight matrix from the COO triplets returned by
# process_calo_hit_data (rows: genparticle index, cols: calo-hit index, values: weights)
rows = genparticle_to_calo_hit_matrix["gen_idx"]
cols = genparticle_to_calo_hit_matrix["hit_idx"]
weights = genparticle_to_calo_hit_matrix["weight"]

# Size the matrix from the event content, not only from the largest index that has an
# association. Genparticles and hits without any association then still get an all-zero
# row or column, so "Number of genparticles" below reports n_gp. The max() also covers
# indices beyond those counts and events with no associations at all (np.max([]) raises).
n_rows = max(n_gp, (int(np.max(rows)) + 1) if len(rows) else 0)
n_cols = max(len(calo_hit_features["energy"]), (int(np.max(cols)) + 1) if len(cols) else 0)

# create dense matrix (.todense() returns an np.matrix, which the following cells rely on)
gp_to_calo_hit_matrix = coo_matrix((weights, (rows, cols)), shape=(n_rows, n_cols)).todense()

In [None]:
# Number of COO entries vs. number of positive entries in the dense matrix
# (these differ if duplicate (gen, hit) pairs were summed or some weights are <= 0)
len(weights), np.sum(gp_to_calo_hit_matrix > 0)

### How many genparticles leave at least one calorimeter hit? (Genparticles have not been filtered on generatorStatus yet.)

In [None]:
# Indices of genparticles (rows) whose summed association weight is positive.
# The row sums of an np.matrix are an (n, 1) matrix, so np.where(...)[0] gives the row indices.
non_zero_rows = np.where(np.sum(gp_to_calo_hit_matrix, axis=1) > 0)[0]

# Create a new matrix with the extracted rows
gp_to_calo_hit_matrix_non_zero = gp_to_calo_hit_matrix[non_zero_rows, :]
gp_to_calo_hit_matrix.shape, gp_to_calo_hit_matrix_non_zero.shape

In [None]:
# Total rows of the dense matrix vs. rows with a positive summed weight
print(f"Number of genparticles: {gp_to_calo_hit_matrix.shape[0]}")
print(f"Number of genparticles with > 0 hits: {len(non_zero_rows)}")

### How many genparticles leave more than 1 hit?

In [None]:
# Count rows with more than 1 element with weight > 0
# (i.e. genparticles associated with more than one calorimeter hit)
rows_with_multiple_elements = np.sum(np.sum(gp_to_calo_hit_matrix > 0, axis=1) > 1)
print(f"Number of genparticles leaving more than 1 hit: {rows_with_multiple_elements}")

### How many hits are associated with more than one genparticle?

In [None]:
# Count cols with more than one element > 0 (hits linked to more than one genparticle)
cols_with_multiple_elements = np.sum(np.sum(gp_to_calo_hit_matrix > 0, axis=0) > 1)

# Locate those columns. The per-column counts of an np.matrix form a (1, n_hits) matrix, so
# np.where returns (row_idx, col_idx) and the hit (column) indices are in element [1].
multiple_gp_per_hit_mask = np.where(np.sum(gp_to_calo_hit_matrix > 0, axis=0) > 1)

# Sub-matrix: all genparticles x only the hits linked to several genparticles
multiple_gp_per_hit = gp_to_calo_hit_matrix[:, multiple_gp_per_hit_mask[1]]
multiple_gp_per_hit.shape

In [None]:
# The same quantity computed two ways (sub-matrix width vs. direct count); the two should agree
print(f"Number of hits with multiple genparticles: {multiple_gp_per_hit.shape[1]}")
print(f"Number of hits with multiple genparticles: {cols_with_multiple_elements}")

### Of the hits associated with multiple genparticles, how many genparticles is each hit associated with?

In [None]:
# Number of linked genparticles (positive weights) for each multiply-associated hit
np.sum(multiple_gp_per_hit > 0, axis=0)

In [None]:
# For each hit linked to several genparticles, list the linked genparticle indices.
# i_hit indexes the columns of multiple_gp_per_hit; multiple_gp_per_hit_mask[1] maps it back to
# the hit's column in gp_to_calo_hit_matrix, so the printed number is the real hit index.
for i_hit in range(multiple_gp_per_hit.shape[1]):
    # The column is an (n_gp, 1) np.matrix, so np.where(...)[0] gives the genparticle (row) indices
    gp_link_mask = multiple_gp_per_hit[:, i_hit] > 0
    hit_index = multiple_gp_per_hit_mask[1][i_hit]
    print(f"Hit {hit_index} is linked to genparticles: {np.where(gp_link_mask)[0]}")

## Save processed data

### Save to parquet

In [None]:
def process_root_files_to_parquet(input_dir, output_dir, max_root_files=None):
    """
    Process ROOT files and save extracted features and matrices into Parquet files using ak arrays.

    Every ``*.root`` file found recursively under ``input_dir`` is written to
    ``output_dir / f"{root_file.stem}.parquet"``. Each top-level field of that Parquet file holds
    one entry per event. Files whose output already exists are skipped and do not count towards
    ``max_root_files``. Files that raise while being processed are reported and skipped.

    Args:
        input_dir (str or Path): Directory containing ROOT files.
        output_dir (str or Path): Directory to save Parquet files.
        max_root_files (int, optional): Maximum number of ROOT files to process. Defaults to None.

    Returns:
        None
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    root_counter = 0
    # Sorted so files are processed in a deterministic order
    root_file_list = sorted(Path(input_dir).rglob("*.root"))
    # Progress-bar length: number of files that can actually be processed
    if max_root_files is None:
        total_files_to_process = len(root_file_list)
    else:
        total_files_to_process = min(max_root_files, len(root_file_list))

    for root_file in tqdm(root_file_list, desc="Processing ROOT files", total=total_files_to_process):
        if max_root_files is not None and root_counter >= max_root_files:
            print(f"Reached max_root_files limit: {max_root_files}. Stopping processing.")
            break

        output_file = output_dir / f"{root_file.stem}.parquet"
        if output_file.exists():
            # Skip only this file; a `return` here would silently abandon all remaining files.
            print(f"Output file {output_file} already exists. Skipping processing.")
            continue

        try:
            # Context manager closes the ROOT file once its events have been written out
            with uproot.open(root_file) as fi:
                # Map podio collection names to their numeric collection IDs
                collectionIDs = {
                    k: v
                    for k, v in zip(
                        fi.get("podio_metadata").arrays("events___idTable/m_names")["events___idTable/m_names"][0],
                        fi.get("podio_metadata").arrays("events___idTable/m_collectionIDs")[
                            "events___idTable/m_collectionIDs"
                        ][0],
                    )
                }
                ev = fi["events"]
                event_data = get_event_data(ev)

                # Combine all events in the current ROOT file
                combined_data_dict = {
                    "gen_features": [],
                    "track_features": [],
                    "cluster_features": [],
                    "calo_hit_features": [],
                    "tracker_hit_features": [],
                    "genparticle_to_calo_hit_matrix": [],
                    "genparticle_to_tracker_hit_matrix": [],
                    "track_to_tracker_hit_matrix": [],
                    "cluster_to_cluster_hit_matrix": [],
                    "gp_to_track_matrix": [],
                    "gp_to_gp": [],
                }

                # One iteration per TTree entry (event); num_entries avoids reading a branch just to count
                for iev in range(ev.num_entries):
                    # Extract features and adjacency matrices for the current event
                    gen_features = gen_to_features(event_data, iev)
                    track_features = track_to_features(event_data, iev)
                    cluster_features = cluster_to_features(
                        event_data, iev, cluster_features=["position.x", "position.y", "position.z", "energy", "type"]
                    )
                    calo_hit_features, genparticle_to_calo_hit_matrix, _ = process_calo_hit_data(
                        event_data, iev, collectionIDs
                    )
                    tracker_hit_features, genparticle_to_tracker_hit_matrix, _ = process_tracker_hit_data(
                        event_data, iev, collectionIDs
                    )
                    track_to_tracker_hit_matrix, _ = create_track_to_hit_coo_matrix(event_data, iev, collectionIDs)
                    cluster_to_cluster_hit_matrix, _ = create_cluster_to_hit_coo_matrix(event_data, iev, collectionIDs)
                    gp_to_track_matrix = genparticle_track_adj(event_data, iev)
                    gp_to_gp = create_genparticle_to_genparticle_coo_matrix(event_data, iev)

                    # Append the event data to the combined data dictionary
                    combined_data_dict["gen_features"].append(gen_features)
                    combined_data_dict["track_features"].append(track_features)
                    combined_data_dict["cluster_features"].append(cluster_features)
                    combined_data_dict["calo_hit_features"].append(calo_hit_features)
                    combined_data_dict["tracker_hit_features"].append(tracker_hit_features)
                    combined_data_dict["genparticle_to_calo_hit_matrix"].append(genparticle_to_calo_hit_matrix)
                    combined_data_dict["genparticle_to_tracker_hit_matrix"].append(genparticle_to_tracker_hit_matrix)
                    combined_data_dict["track_to_tracker_hit_matrix"].append(track_to_tracker_hit_matrix)
                    combined_data_dict["cluster_to_cluster_hit_matrix"].append(cluster_to_cluster_hit_matrix)
                    combined_data_dict["gp_to_track_matrix"].append(gp_to_track_matrix)
                    combined_data_dict["gp_to_gp"].append(gp_to_gp)

                # Convert lists to ak arrays
                for key in combined_data_dict:
                    combined_data_dict[key] = ak.Array(combined_data_dict[key])

                # Save the combined data into a single Parquet file
                ak.to_parquet(combined_data_dict, output_file)

            print(f"Saved combined data for {root_file} to {output_file}")
            root_counter += 1

        except Exception as e:
            # Best-effort batch job: report the failing file and continue with the next one
            print(f"Error processing {root_file}: {e}")

    print(f"Finished processing {root_counter} ROOT files.")

In [None]:
# process_root_files_to_parquet("/mnt/ceph/users/ewulff/data/cld/", "/mnt/ceph/users/ewulff/data/cld/processed/parquet", max_root_files=2)

In [None]:
# Load the first Parquet file found in the output directory and list its top-level fields.
# NOTE(review): next() raises StopIteration if there are no .parquet files; glob order is filesystem-dependent.
event_data1 = ak.from_parquet(next(Path("/mnt/ceph/users/ewulff/data/cld/processed/parquet/").glob("*.parquet")))
event_data1.fields

### LMDB

In [None]:
def dumps(obj):
    """
    Pickle ``obj`` into bytes suitable for storing as an LMDB value.

    Pickle protocol 5 (available from Python 3.8) is used.

    Returns:
        bytes: the serialized object.
    """
    serialized = pickle.dumps(obj, protocol=5)
    return serialized


def process_root_files_to_lmdb(input_dir, output, max_root_files=None):
    """
    Process ROOT files and save extracted features into an LMDB database.

    Each event is stored as one pickled dict under the ASCII key ``str(event_index)``. Events are
    numbered consecutively across all files, which are visited in sorted order. At the end two
    metadata entries are written: ``b"__keys__"`` (pickled list of all event keys) and
    ``b"__len__"`` (pickled event count). Files that raise while being processed are reported
    and skipped.

    Args:
        input_dir (str or Path): Directory searched recursively for ``*.root`` files.
        output (str or Path): LMDB location (``~`` is expanded). If it is an existing directory,
            the database files are created inside it; otherwise it is used as the database file path.
        max_root_files (int, optional): Maximum number of ROOT files to process. Defaults to None.

    Returns:
        None
    """
    lmdb_path = os.path.expanduser(output)
    isdir = os.path.isdir(lmdb_path)

    print("Generate LMDB to %s" % lmdb_path)
    # map_size caps the database size (2 * 2**40 bytes = 2 TiB)
    db = lmdb.open(lmdb_path, subdir=isdir, map_size=1099511627776 * 2, readonly=False, meminit=False, map_async=True)

    txn = db.begin(write=True)
    event_counter = 0
    root_counter = 0

    # Sorted so event numbering (the LMDB keys) is reproducible across runs/filesystems
    for root_file in tqdm(sorted(Path(input_dir).rglob("*.root")), desc="Processing ROOT files"):
        if max_root_files is not None and root_counter >= max_root_files:
            print(f"Reached max_root_files limit: {max_root_files}. Stopping processing.")
            break
        try:
            # Context manager closes the ROOT file once all its events are stored
            with uproot.open(root_file) as fi:
                # Map podio collection names to their numeric collection IDs
                collectionIDs = {
                    k: v
                    for k, v in zip(
                        fi.get("podio_metadata").arrays("events___idTable/m_names")["events___idTable/m_names"][0],
                        fi.get("podio_metadata").arrays("events___idTable/m_collectionIDs")[
                            "events___idTable/m_collectionIDs"
                        ][0],
                    )
                }
                ev = fi["events"]
                event_data = get_event_data(ev)
                # One iteration per TTree entry (event); num_entries avoids reading a branch just to count
                for iev in range(ev.num_entries):
                    # Extract features and adjacency matrices for the current event
                    gen_features = gen_to_features(event_data, iev)
                    track_features = track_to_features(event_data, iev)
                    cluster_features = cluster_to_features(
                        event_data, iev, cluster_features=["position.x", "position.y", "position.z", "energy", "type"]
                    )
                    calo_hit_features, genparticle_to_calo_hit_matrix, _ = process_calo_hit_data(
                        event_data, iev, collectionIDs
                    )
                    tracker_hit_features, genparticle_to_tracker_hit_matrix, _ = process_tracker_hit_data(
                        event_data, iev, collectionIDs
                    )
                    track_to_tracker_hit_matrix, _ = create_track_to_hit_coo_matrix(event_data, iev, collectionIDs)
                    cluster_to_cluster_hit_matrix, _ = create_cluster_to_hit_coo_matrix(event_data, iev, collectionIDs)
                    gp_to_track_matrix = genparticle_track_adj(event_data, iev)
                    gp_to_gp = create_genparticle_to_genparticle_coo_matrix(event_data, iev)

                    # Save the event data into LMDB
                    txn.put(
                        "{}".format(event_counter).encode("ascii"),
                        dumps(
                            {
                                "gen_features": gen_features,
                                "track_features": track_features,
                                "cluster_features": cluster_features,
                                "calo_hit_features": calo_hit_features,
                                "tracker_hit_features": tracker_hit_features,
                                "genparticle_to_calo_hit_matrix": genparticle_to_calo_hit_matrix,
                                "genparticle_to_tracker_hit_matrix": genparticle_to_tracker_hit_matrix,
                                "track_to_tracker_hit_matrix": track_to_tracker_hit_matrix,
                                "cluster_to_cluster_hit_matrix": cluster_to_cluster_hit_matrix,
                                "gp_to_track_matrix": gp_to_track_matrix,
                                "gp_to_gp": gp_to_gp,
                            }
                        ),
                    )

                    # Commit every 100 events to bound the size of the open write transaction
                    if event_counter % 100 == 0:
                        print(f"[{event_counter}] events processed")
                        txn.commit()
                        txn = db.begin(write=True)

                    event_counter += 1

            root_counter += 1

        except Exception as e:
            # Best-effort batch job: report the failing file and continue with the next one
            print(f"Error processing {root_file}: {e}")

    # Finish iterating through all events
    txn.commit()
    # Keys are contiguous 0..event_counter-1, matching the put() calls above
    keys = ["{}".format(k).encode("ascii") for k in range(event_counter)]
    with db.begin(write=True) as txn:
        txn.put(b"__keys__", dumps(keys))
        txn.put(b"__len__", dumps(len(keys)))

    print("Flushing database ...")
    db.sync()
    db.close()

In [None]:
# process_root_files_to_lmdb(
#     input_dir="/mnt/ceph/users/ewulff/data/cld/", output="/mnt/ceph/users/ewulff/data/cld/processed/lmdb", max_root_files=2
# )

In [None]:
# read the lmdb database
def read_full_lmdb_database(lmdb_path):
    """
    Load every event stored by ``process_root_files_to_lmdb`` into memory.

    Args:
        lmdb_path (str or Path): LMDB directory or file path; ``~`` is expanded.

    Returns:
        dict: ``{i: event_dict}`` for each key listed in the ``__keys__`` entry, in that order.

    Raises:
        KeyError: If the database has no ``__keys__`` entry.

    Note:
        Values are unpickled, so only read databases from a trusted source.
    """
    lmdb_path = os.path.expanduser(lmdb_path)
    db = lmdb.open(lmdb_path, subdir=os.path.isdir(lmdb_path), readonly=True, lock=False)

    try:
        with db.begin() as txn:
            raw_keys = txn.get(b"__keys__")
            if raw_keys is None:
                # pickle.loads(None) would otherwise fail with an unhelpful TypeError
                raise KeyError(f"No '__keys__' entry found in LMDB database {lmdb_path}")
            keys = pickle.loads(raw_keys)
            data = {ii: pickle.loads(txn.get(key)) for ii, key in enumerate(keys)}
    finally:
        # Close the environment even if reading or unpickling fails
        db.close()
    return data


# lmdb_data = read_full_lmdb_database("/mnt/ceph/users/ewulff/data/cld/processed/lmdb")

# ML task
- Set up clustering (and tracking?) as instance segmentation problems
- Instance segmentation labels are based on the hit-to-genparticle associations stored in
    - genparticle_to_calo_hit_matrix (and genparticle_to_tracker_hit_matrix)

In [None]:
# Load the first Parquet file found (glob order is filesystem-dependent; next() raises StopIteration if none)
event_data1 = ak.from_parquet(next(Path("/mnt/ceph/users/ewulff/data/cld/processed/parquet/").glob("*.parquet")))

# Extract genparticle_to_calo_hit_matrix from event_data1
# This matrix contains the mapping of genparticles to calorimeter hits in a COO format
event_i = 2
gen_idx = event_data1["genparticle_to_calo_hit_matrix"][event_i]["gen_idx"].to_numpy()
hit_idx = event_data1["genparticle_to_calo_hit_matrix"][event_i]["hit_idx"].to_numpy()
weights = event_data1["genparticle_to_calo_hit_matrix"][event_i][
    "weight"
].to_numpy()  # only contains the non-zero weights

# create dense coo matrix (disabled; `rows`/`cols` here would correspond to gen_idx/hit_idx)
# gp_to_calo_hit_matrix = coo_matrix((weights, (rows, cols)), shape=(np.max(rows) + 1, np.max(cols) + 1)).todense()

# Create instance segmentation-like labels for each hit
# Each hit is classified as belonging to one genparticle based on the highest weight


def get_hit_labels(hit_idx, gen_idx, weights, num_hits=None):
    """
    Assign each hit the index of the genparticle that contributes the highest weight to it.

    The three input arrays are the parallel COO entries of a genparticle-to-hit
    association matrix: entry ``i`` says that genparticle ``gen_idx[i]`` contributes
    ``weights[i]`` to hit ``hit_idx[i]``. A hit may appear in several entries (several
    contributing genparticles); the entry with the largest weight wins, and on ties the
    entry listed first wins.

    Parameters:
        hit_idx (np.ndarray): Hit index of each association entry.
        gen_idx (np.ndarray): Genparticle index of each association entry.
        weights (np.ndarray): Weight of each association entry.
        num_hits (int, optional): Length of the returned label array. Defaults to
            ``max(hit_idx) + 1`` (0 for empty input). Pass the total number of hits in
            the event so that hits with no association, including any after the largest
            associated index, are labelled -1.

    Returns:
        np.ndarray: Integer label per hit: the winning genparticle index, or -1 for hits
        that have no association entry.

    Raises:
        ValueError: If the input arrays differ in length, or ``num_hits`` is smaller than
            ``max(hit_idx) + 1``.
    """
    hit_idx = np.asarray(hit_idx, dtype=int)
    gen_idx = np.asarray(gen_idx, dtype=int)
    weights = np.asarray(weights, dtype=float)
    if not (len(hit_idx) == len(gen_idx) == len(weights)):
        raise ValueError(
            f"hit_idx, gen_idx and weights must have equal lengths, got "
            f"{len(hit_idx)}, {len(gen_idx)} and {len(weights)}"
        )

    if num_hits is None:
        num_hits = int(hit_idx.max()) + 1 if hit_idx.size else 0
    elif hit_idx.size and hit_idx.max() >= num_hits:
        raise ValueError(f"num_hits={num_hits} is too small for hit index {hit_idx.max()}")

    # Default label is -1 (unclassified)
    hit_labels = np.full(num_hits, -1, dtype=int)
    if hit_idx.size == 0:
        return hit_labels

    # Order entries by hit index, then by descending weight (lexsort's LAST key is the
    # primary one). lexsort is stable, so equal weights keep their original order.
    order = np.lexsort((-weights, hit_idx))
    # The first entry of each hit group is therefore its highest-weight contribution.
    _, first_in_group = np.unique(hit_idx[order], return_index=True)
    best = order[first_in_group]
    hit_labels[hit_idx[best]] = gen_idx[best]

    return hit_labels


# Per-hit genparticle labels for the selected event (-1 marks hits without an association)
hit_labels = get_hit_labels(hit_idx=hit_idx, gen_idx=gen_idx, weights=weights)
hit_labels

In [None]:
# Mock data for testing
def create_mock_event_data():
    """
    Build a small three-event awkward array that mimics the parquet layout.

    Each event carries a ``genparticle_to_calo_hit_matrix`` record holding the
    parallel ``gen_idx``, ``hit_idx`` and ``weight`` arrays.
    """
    # One (gen_idx, hit_idx, weight) triple of lists per mock event
    mock_events = (
        ([0, 1, 0, 2], [0, 1, 2, 3], [0.5, 0.8, 0.3, 0.9]),
        ([1, 2, 1, 0], [0, 1, 2, 3], [0.6, 0.7, 0.4, 0.2]),
        ([0, 2, 1, 2], [0, 1, 2, 3], [0.9, 0.5, 0.7, 0.8]),
    )
    associations = [
        {"gen_idx": np.array(gens), "hit_idx": np.array(hits), "weight": np.array(wts)}
        for gens, hits, wts in mock_events
    ]
    return ak.Array({"genparticle_to_calo_hit_matrix": associations})


# Test function
def test_hit_labels():
    """
    Check get_hit_labels on the mock events and on hits shared by several genparticles.

    In the mock events each hit has exactly one contributing genparticle, so on their
    own they cannot distinguish "highest weight wins" from "first entry wins". The
    second check uses hits with several contributors to pin down the max-weight rule.
    """
    event_data1 = create_mock_event_data()
    event_i = 2
    assoc = event_data1["genparticle_to_calo_hit_matrix"][event_i]
    gen_idx = assoc["gen_idx"].to_numpy()
    hit_idx = assoc["hit_idx"].to_numpy()
    weights = assoc["weight"].to_numpy()

    hit_labels = get_hit_labels(hit_idx, gen_idx, weights)

    # Expected labels based on mock data
    expected_labels = np.array([0, 2, 1, 2])

    assert np.array_equal(hit_labels, expected_labels), f"Expected {expected_labels}, but got {hit_labels}"

    # Hits 0 and 2 each receive contributions from two genparticles; the larger weight
    # must win even when it is not listed first (hit 0). Hit 1 has no association
    # entry and must stay unlabelled (-1).
    shared_hit_idx = np.array([0, 0, 2, 2, 3])
    shared_gen_idx = np.array([4, 7, 1, 5, 3])
    shared_weights = np.array([0.2, 0.9, 0.6, 0.1, 0.5])
    shared_labels = get_hit_labels(shared_hit_idx, shared_gen_idx, shared_weights)
    expected_shared = np.array([7, -1, 1, 3])

    assert np.array_equal(shared_labels, expected_shared), f"Expected {expected_shared}, but got {shared_labels}"


# Run the test
test_hit_labels()
print("Test passed!")

In [None]:
# Association arrays and calorimeter hit features for the selected event
assoc = event_data1["genparticle_to_calo_hit_matrix"][event_i]
gen_idx = assoc["gen_idx"].to_numpy()
hit_idx = assoc["hit_idx"].to_numpy()
weights = assoc["weight"].to_numpy()  # only contains the non-zero weights
calo_hit_features = event_data1["calo_hit_features"][event_i]

# Assign labels to hits based on the genparticle index with the highest weight
hit_labels = get_hit_labels(hit_idx, gen_idx, weights)

# Extract calorimeter hit positions (x, y, z)
calo_hit_positions = np.column_stack(
    (
        calo_hit_features["position.x"].to_numpy(),
        calo_hit_features["position.y"].to_numpy(),
        calo_hit_features["position.z"].to_numpy(),
    )
)

# hit_labels only extends to the largest associated hit index, so calo hits after the
# last associated one would have no label and the boolean masks below would not line up
# with calo_hit_positions. Label such hits -1 (unassociated).
n_hits = len(calo_hit_positions)
if len(hit_labels) < n_hits:
    hit_labels = np.concatenate((hit_labels, np.full(n_hits - len(hit_labels), -1, dtype=hit_labels.dtype)))
elif len(hit_labels) > n_hits:
    raise ValueError(
        f"Associations reference hit index {len(hit_labels) - 1}, but event {event_i} has only {n_hits} calo hits"
    )

# Assign unique colors to each genparticle ID
unique_ids = np.unique(hit_labels)
colors = plt.cm.tab10(np.linspace(0, 1, len(unique_ids)))
color_map = {
    gen_id: f"rgba({int(color[0]*255)}, {int(color[1]*255)}, {int(color[2]*255)}, {color[3]})"
    for gen_id, color in zip(unique_ids, colors)
}

# Create one trace per genparticle ID
traces = []
for gen_id in unique_ids:
    mask = hit_labels == gen_id  # hits belonging to the current genparticle ID
    traces.append(
        go.Scatter3d(
            x=calo_hit_positions[mask, 0],
            y=calo_hit_positions[mask, 1],
            z=calo_hit_positions[mask, 2],
            mode="markers",
            marker=dict(size=3, color=color_map[gen_id]),
            name=f"gp {gen_id}",
        )
    )

# Customize the axis names
layout = go.Layout(
    scene=dict(
        xaxis=dict(title="X"),
        yaxis=dict(title="Y"),
        zaxis=dict(title="Z"),
        camera=dict(
            up=dict(x=1, y=0, z=0),  # Sets the orientation of the camera
            center=dict(x=0, y=0, z=0),  # Sets the center point of the plot
            eye=dict(x=0, y=0, z=2.1),  # Sets the position of the camera
        ),
    ),
    showlegend=False,
    width=700,
    height=700,
    title="Calorimeter hits colored by genparticle",
)

# Create the figure and display the plot
fig = go.Figure(data=traces, layout=layout)
fig.show()

In [None]:
def count_genparticles(event_data, iev, threshold=1):
    """
    Summarise how many genparticles leave calorimeter hits in one event.

    Builds the genparticle-to-calo-hit association matrix for event ``iev`` from its COO
    entries and prints:
    the number of genparticle rows (``max(gen_idx) + 1``), the number of rows whose
    summed weight is positive, and the number of genparticles associated with more than
    ``threshold`` distinct hits with positive weight.

    Args:
        event_data: Event collection indexable as
            ``event_data["genparticle_to_calo_hit_matrix"][iev][field]`` for the fields
            ``"gen_idx"``, ``"hit_idx"`` and ``"weight"``, each supporting ``.to_numpy()``.
        iev (int): Index of the event to inspect.
        threshold (int): A genparticle is counted (and returned) when it has strictly
            more than this many hits. Defaults to 1.

    Returns:
        numpy.ndarray: Indices of genparticles associated with more than ``threshold`` hits.
    """
    assoc = event_data["genparticle_to_calo_hit_matrix"][iev]
    gen_idx = np.asarray(assoc["gen_idx"].to_numpy(), dtype=int)
    hit_idx = np.asarray(assoc["hit_idx"].to_numpy(), dtype=int)
    weights = np.asarray(assoc["weight"].to_numpy(), dtype=float)  # only contains the non-zero weights

    n_gen = int(gen_idx.max()) + 1 if gen_idx.size else 0
    n_hits = int(hit_idx.max()) + 1 if hit_idx.size else 0

    if gen_idx.size:
        # Stay sparse instead of densifying an (n_gen x n_hits) matrix; summing repeated
        # (genparticle, hit) entries matches what todense() would produce.
        gp_to_calo_hit_matrix = coo_matrix((weights, (gen_idx, hit_idx)), shape=(n_gen, n_hits))
        gp_to_calo_hit_matrix.sum_duplicates()
        rows, data = gp_to_calo_hit_matrix.row, gp_to_calo_hit_matrix.data
    else:
        rows, data = gen_idx, weights

    # Per-genparticle summed weight and number of distinct hits with positive weight
    row_sums = np.bincount(rows, weights=data, minlength=n_gen)
    hits_per_gen = np.bincount(rows[data > 0], minlength=n_gen)

    print(f"Number of genparticles: {n_gen}")
    # "> 0 hits" here means the genparticle's summed weight is positive
    print(f"Number of genparticles with > 0 hits: {np.count_nonzero(row_sums > 0)}")
    print(f"Number of genparticles with > {threshold} hits: {np.count_nonzero(hits_per_gen > threshold)}")

    return np.flatnonzero(hits_per_gen > threshold)


# Genparticle hit-count summary for event 2 of the loaded parquet file
count_genparticles(event_data=event_data1, iev=2)