# Data processing for CLD

## Imports

In [None]:
from pathlib import Path
import uproot
import numpy as np
import math
import pandas
import awkward
import h5py
from tqdm import tqdm
import pickle
import os
import lmdb

import plotly.graph_objects as go
import vector

import matplotlib.pyplot as plt
from scipy.sparse import coo_matrix

## Function definitions
These are defined in `data_processing/cld_processing` but are reproduced here for clarity.

In [None]:
# Constants used by the processing functions below
pion_mass = 0.13957  # charged-pion mass, presumably in GeV
B = -2.0  # magnetic field in T (-2 for CLD FCC-ee); default bfield of track_pt
c = 3.0e8  # speed of light in m/s
scale = 1_000  # NOTE(review): not referenced by the functions shown here; meaning to be confirmed


def get_sitrack_links(ev):
    """Read the SiTracksMCTruthLink branches: link weights plus both link endpoints.

    Args:
        ev: Object exposing ``arrays(list_of_branch_names)``, e.g. an uproot TTree.

    Returns:
        The result of ``ev.arrays`` for the link branches.
    """
    branches = [
        "SiTracksMCTruthLink.weight",
        "_SiTracksMCTruthLink_to/_SiTracksMCTruthLink_to.collectionID",
        "_SiTracksMCTruthLink_to/_SiTracksMCTruthLink_to.index",
        "_SiTracksMCTruthLink_from/_SiTracksMCTruthLink_from.collectionID",
        "_SiTracksMCTruthLink_from/_SiTracksMCTruthLink_from.index",
    ]
    return ev.arrays(branches)


def get_tracker_hits_begin_end(ev):
    """Read the SiTracks_Refitted -> tracker-hit relation branches.

    Args:
        ev: Object exposing ``arrays(list_of_branch_names)``, e.g. an uproot TTree.

    Returns:
        The result of ``ev.arrays`` for the per-track hit range (begin/end) and the
        referenced hit index/collectionID branches.
    """
    branches = [
        "SiTracks_Refitted/SiTracks_Refitted.trackerHits_begin",
        "SiTracks_Refitted/SiTracks_Refitted.trackerHits_end",
        "_SiTracks_Refitted_trackerHits/_SiTracks_Refitted_trackerHits.index",
        "_SiTracks_Refitted_trackerHits/_SiTracks_Refitted_trackerHits.collectionID",
    ]
    return ev.arrays(branches)


def get_cluster_hits_begin_end(ev):
    """Read the PandoraClusters -> calorimeter-hit relation branches.

    Args:
        ev: Object exposing ``arrays(list_of_branch_names)``, e.g. an uproot TTree.

    Returns:
        The result of ``ev.arrays`` for the per-cluster hit range (begin/end) and the
        referenced hit index/collectionID branches.
    """
    branches = [
        "PandoraClusters/PandoraClusters.hits_begin",
        "PandoraClusters/PandoraClusters.hits_end",
        "_PandoraClusters_hits/_PandoraClusters_hits.index",
        "_PandoraClusters_hits/_PandoraClusters_hits.collectionID",
    ]
    return ev.arrays(branches)


def get_calohit_links(ev):
    """Read the CalohitMCTruthLink branches: link weights plus both link endpoints.

    Args:
        ev: Object exposing ``arrays(list_of_branch_names)``, e.g. an uproot TTree.

    Returns:
        The result of ``ev.arrays`` for the link branches.
    """
    branches = [
        "CalohitMCTruthLink.weight",
        "_CalohitMCTruthLink_to/_CalohitMCTruthLink_to.collectionID",
        "_CalohitMCTruthLink_to/_CalohitMCTruthLink_to.index",
        "_CalohitMCTruthLink_from/_CalohitMCTruthLink_from.collectionID",
        "_CalohitMCTruthLink_from/_CalohitMCTruthLink_from.index",
    ]
    return ev.arrays(branches)


def get_calo_hit_data(ev):
    """Read the calorimeter (ECAL/HCAL) and MUON hit collections.

    Args:
        ev: Object exposing ``arrays(list_of_branch_names)``, e.g. an uproot TTree.

    Returns:
        The result of ``ev.arrays`` for the six hit collections.
    """
    collections = [
        "ECALBarrel",
        "ECALEndcap",
        "HCALBarrel",
        "HCALEndcap",
        "HCALOther",
        "MUON",
    ]
    return ev.arrays(collections)


def get_tracker_hit_data(ev):
    """Read the tracker hit collections (vertex detector, inner and outer tracker).

    Args:
        ev: Object exposing ``arrays(list_of_branch_names)``, e.g. an uproot TTree.

    Returns:
        The result of ``ev.arrays`` for the tracker hit collections.
    """
    collections = [
        "VXDTrackerHits",
        "VXDEndcapTrackerHits",
        "ITrackerHits",
        "OTrackerHits",
        # These need to be added to the keep statements of the next generation
        # "ITrackerEndcapHits",
        # "OTrackerEndcapHits",
    ]
    return ev.arrays(collections)


def get_track_data(ev):
    """Read the refitted SiTracks together with their dQdx and track-state collections.

    Args:
        ev: Object exposing ``arrays(list_of_branch_names)``, e.g. an uproot TTree.

    Returns:
        The result of ``ev.arrays`` for the track-related collections.
    """
    return ev.arrays(
        [
            "SiTracks_Refitted",
            "SiTracks_Refitted_dQdx",
            "_SiTracks_Refitted_trackStates",
        ]
    )


def get_cluster_data(ev):
    """Read the PandoraClusters collection and its hit-relation branch.

    Args:
        ev: Object exposing ``arrays(list_of_branch_names)``, e.g. an uproot TTree.

    Returns:
        The result of ``ev.arrays`` for the cluster collections.
    """
    return ev.arrays(["PandoraClusters", "_PandoraClusters_hits"])


def get_gen_data(ev):
    """Read the MCParticles collection.

    Args:
        ev: Object exposing ``arrays(list_of_branch_names)``, e.g. an uproot TTree.

    Returns:
        The result of ``ev.arrays(["MCParticles"])``.
    """
    return ev.arrays(["MCParticles"])


def get_event_data(ev):
    """
    Read, in one ``ev.arrays`` call, every branch used by the processing functions.

    The branch list is the union of the lists used by the individual ``get_*`` helpers,
    plus the MCParticles parent/daughter relation branches.

    Args:
        ev: Object exposing ``arrays(list_of_branch_names)``, e.g. an uproot TTree.

    Returns:
        Whatever ``ev.arrays`` returns for these branches (for an uproot TTree this is
        presumably an awkward Array with one field per branch — not a plain dict).
    """
    sitrack_links = [
        "SiTracksMCTruthLink.weight",
        "_SiTracksMCTruthLink_to/_SiTracksMCTruthLink_to.collectionID",
        "_SiTracksMCTruthLink_to/_SiTracksMCTruthLink_to.index",
        "_SiTracksMCTruthLink_from/_SiTracksMCTruthLink_from.collectionID",
        "_SiTracksMCTruthLink_from/_SiTracksMCTruthLink_from.index",
    ]
    track_hit_relations = [
        "SiTracks_Refitted/SiTracks_Refitted.trackerHits_begin",
        "SiTracks_Refitted/SiTracks_Refitted.trackerHits_end",
        "_SiTracks_Refitted_trackerHits/_SiTracks_Refitted_trackerHits.index",
        "_SiTracks_Refitted_trackerHits/_SiTracks_Refitted_trackerHits.collectionID",
    ]
    cluster_hit_relations = [
        "PandoraClusters/PandoraClusters.hits_begin",
        "PandoraClusters/PandoraClusters.hits_end",
        "_PandoraClusters_hits/_PandoraClusters_hits.index",
        "_PandoraClusters_hits/_PandoraClusters_hits.collectionID",
    ]
    calohit_links = [
        "CalohitMCTruthLink.weight",
        "_CalohitMCTruthLink_to/_CalohitMCTruthLink_to.collectionID",
        "_CalohitMCTruthLink_to/_CalohitMCTruthLink_to.index",
        "_CalohitMCTruthLink_from/_CalohitMCTruthLink_from.collectionID",
        "_CalohitMCTruthLink_from/_CalohitMCTruthLink_from.index",
    ]
    calo_hits = [
        "ECALBarrel",
        "ECALEndcap",
        "HCALBarrel",
        "HCALEndcap",
        "HCALOther",
        "MUON",
    ]
    tracker_hits = [
        "VXDTrackerHits",
        "VXDEndcapTrackerHits",
        "ITrackerHits",
        "OTrackerHits",
    ]
    tracks = [
        "SiTracks_Refitted",
        "SiTracks_Refitted_dQdx",
        "_SiTracks_Refitted_trackStates",
    ]
    clusters = [
        "PandoraClusters",
        "_PandoraClusters_hits",
    ]
    gen_particles = [
        "MCParticles",
        "MCParticles.parents_begin",
        "MCParticles.parents_end",
        "_MCParticles_parents/_MCParticles_parents.index",
        "MCParticles.daughters_begin",
        "MCParticles.daughters_end",
        "_MCParticles_daughters/_MCParticles_daughters.index",
    ]
    all_branches = (
        sitrack_links
        + track_hit_relations
        + cluster_hit_relations
        + calohit_links
        + calo_hits
        + tracker_hits
        + tracks
        + clusters
        + gen_particles
    )
    return ev.arrays(all_branches)


def hits_to_features(hit_data, iev, coll, feats):
    """
    Converts hit data into a structured feature array for a specific event and collection.

    Args:
        hit_data (dict): A dictionary containing hit data, where keys are strings representing
            collection and feature names (``coll + "." + feature``), and values are per-event
            arrays of feature data.
        iev (int): The index of the event to extract data for.
        coll (str): The name of the hit collection (e.g., "VXDTrackerHits", "VXDEndcapTrackerHits",
            "ECALBarrel", "ECALEndcap", etc.).
        feats (list of str): A list of feature names to extract from the hit data
            (e.g., "position.x", "position.y", "position.z", "energy", "type", etc.).
            For tracker collections the "energy" feature is read from the "eDep" branch.

    Returns:
        awkward.Array: An Awkward Array containing the extracted features for the specified event
            and collection. The array includes an additional "subdetector" feature (int32), which
            encodes the subdetector type:
            - 0 for ECAL
            - 1 for HCAL
            - 2 for MUON
            - 3 for other collections.
    """
    # Tracker hits store eDep instead of energy, so map the requested output name
    # to the branch that actually exists for this collection.
    is_tracker = "TrackerHit" in coll or "TrackerEndcapHits" in coll
    new_feats = [(feat, "eDep" if (is_tracker and feat == "energy") else feat) for feat in feats]

    feat_arr = {f1: hit_data[coll + "." + f2][iev] for f1, f2 in new_feats}

    # Count hits from whichever feature was extracted, so "type" need not be requested;
    # fall back to the position branch when no features were requested at all.
    if feat_arr:
        num_hits = len(next(iter(feat_arr.values())))
    else:
        num_hits = len(hit_data[coll + ".position.x"][iev])

    if coll.startswith("ECAL"):
        subdetector_code = 0
    elif coll.startswith("HCAL"):
        subdetector_code = 1
    elif coll.startswith("MUON"):
        subdetector_code = 2
    else:
        subdetector_code = 3
    feat_arr["subdetector"] = np.full(num_hits, subdetector_code, dtype=np.int32)

    return awkward.Array(feat_arr)


def genparticle_track_adj(event_data, iev):
    """
    Return the genparticle-track truth links of one event as NumPy COO components.

    Args:
        event_data: Event data containing the ``SiTracksMCTruthLink`` branches.
        iev: The index of the event to extract data for.

    Returns:
        tuple: ``(gen_indices, track_indices, weights)`` — the link ``_to`` index
        (genparticle), the link ``_from`` index (track) and the link weight, each
        converted with ``awkward.to_numpy``.
    """
    branch_names = (
        "_SiTracksMCTruthLink_to/_SiTracksMCTruthLink_to.index",
        "_SiTracksMCTruthLink_from/_SiTracksMCTruthLink_from.index",
        "SiTracksMCTruthLink.weight",
    )
    return tuple(awkward.to_numpy(event_data[name][iev]) for name in branch_names)


def produce_gp_to_track(ev, iev, num_genparticles, num_tracks):
    """
    Produces the genparticle-to-track adjacency matrix (gp_to_track) in dense format.

    Args:
        ev: Event data passed through to ``genparticle_track_adj``.
        iev: The event index.
        num_genparticles: Total number of genparticles (rows).
        num_tracks: Total number of tracks (columns).

    Returns:
        gp_to_track: A dense adjacency matrix of shape ``(num_genparticles, num_tracks)``
            whose entry ``[g, t]`` is the link weight between genparticle ``g`` and track ``t``
            (duplicate links are summed by ``coo_matrix``). It is an ``np.matrix`` when links
            exist and an all-zero ``np.ndarray`` otherwise; the zero fallback keeps at least
            one column, so an event without tracks yields shape ``(num_genparticles, 1)``.
    """
    gen_idx, trk_idx, weights = genparticle_track_adj(ev, iev)

    if len(gen_idx) > 0:
        gp_to_track = coo_matrix(
            (weights, (gen_idx, trk_idx)),
            shape=(num_genparticles, num_tracks),
        ).todense()
    else:
        # Match the column count of the linked case so downstream per-track indexing works;
        # the width is floored at 1 to keep a non-empty second axis when there are no tracks.
        gp_to_track = np.zeros((num_genparticles, max(num_tracks, 1)))

    return gp_to_track


def create_global_to_local_mapping(hit_data, iev, collectionIDs):
    """
    Number all hits of one event globally and map between global and local indices.

    Collections are visited in sorted field-name order; within a collection hits keep
    their local order. The hit count of a collection is the length of its
    ``<name>.position.x`` branch for event ``iev``.

    Args:
        hit_data: Record-like object with a ``fields`` attribute listing collection names.
        iev: The index of the event.
        collectionIDs: Mapping from collection name to collection ID.

    Returns:
        tuple: ``(global_to_local, local_to_global)`` where ``global_to_local`` maps
        ``global_index -> (collectionID, local_index)`` and ``local_to_global`` is its inverse.
    """
    local_pairs = []
    for name in sorted(hit_data.fields):
        coll_id = collectionIDs[name]
        n_hits = len(hit_data[name][name + ".position.x"][iev])
        local_pairs.extend((coll_id, local_idx) for local_idx in range(n_hits))

    global_to_local = dict(enumerate(local_pairs))
    local_to_global = {pair: global_idx for global_idx, pair in global_to_local.items()}
    return global_to_local, local_to_global


def create_hit_feature_matrix(hit_data, iev, feats):
    """
    Extract ``feats`` from every hit collection of one event and concatenate them.

    Collections are processed in sorted field-name order via ``hits_to_features``; each
    output feature (including the added "subdetector") is the ``awkward.concatenate`` of
    the per-collection arrays. Raises IndexError if ``hit_data`` has no fields.

    Returns:
        dict: feature name -> concatenated awkward array over all collections.
    """
    per_collection = [hits_to_features(hit_data[name], iev, name, feats) for name in sorted(hit_data.fields)]
    feature_names = per_collection[0].fields
    return {name: awkward.concatenate([chunk[name] for chunk in per_collection]) for name in feature_names}


# Combine the two functions above into one to reduce the number of loops over hit data.
def create_hit_feature_matrix_and_mapping(hit_data, iev, collectionIDs, feats):
    """
    Build the hit feature matrix and the global/local hit index mappings in one pass.

    Collections are visited in sorted field-name order. Each collection contributes its
    ``hits_to_features`` output and one global index per entry of its ``.position.x`` branch.

    Returns:
        tuple: ``(features, global_to_local, local_to_global)`` where ``features`` maps each
        feature name to a NumPy array concatenated over collections, ``global_to_local`` maps
        ``global_index -> (collectionID, local_index)`` and ``local_to_global`` is its inverse.
    """
    chunks = []
    global_to_local = {}
    for name in sorted(hit_data.fields):
        coll_id = collectionIDs[name]
        chunks.append(hits_to_features(hit_data[name], iev, name, feats))
        n_hits = len(hit_data[name][name + ".position.x"][iev])
        for local_idx in range(n_hits):
            # Global indices are consecutive, so the current dict size is the next index.
            global_to_local[len(global_to_local)] = (coll_id, local_idx)

    local_to_global = {pair: global_idx for global_idx, pair in global_to_local.items()}

    features = {}
    for key in chunks[0].fields:
        features[key] = np.concatenate([chunk[key].to_numpy() for chunk in chunks])

    return features, global_to_local, local_to_global


# TODO: check correctness of this function
def create_track_to_hit_coo_matrix(event_data, iev, collectionIDs):
    """
    Creates the COO matrix indices and weights for the relationship between tracks and tracker hits.

    Args:
        event_data: The event data containing the tracker hit collections and the
            ``SiTracks_Refitted`` hit-relation branches.
        iev: The index of the event to extract data for.
        collectionIDs: A dictionary mapping collection names to their IDs.

    Returns:
        tuple: ``((rows, cols, weights), hit_idx_local_to_global)`` where
            - ``rows`` (int64): track indices,
            - ``cols`` (int64): global tracker-hit indices,
            - ``weights`` (float64): association weights, all 1.0,
            - ``hit_idx_local_to_global``: maps ``(collectionID, local index)`` to the global hit index.
        Hits referenced by a track that belong to a collection not listed below
        (e.g. the tracker endcap collections) are skipped.
    """
    # Extract tracker hit data
    tracker_hit_data = event_data[
        [
            "VXDTrackerHits",
            "VXDEndcapTrackerHits",
            "ITrackerHits",
            "OTrackerHits",
            # "ITrackerEndcapHits",
            # "OTrackerEndcapHits",
        ]
    ]

    # Hits of track i are trk_hit_idx/trk_hit_coll[hit_beg[i]:hit_end[i]] (end is exclusive).
    hit_beg = event_data["SiTracks_Refitted/SiTracks_Refitted.trackerHits_begin"][iev]
    hit_end = event_data["SiTracks_Refitted/SiTracks_Refitted.trackerHits_end"][iev]
    trk_hit_idx = event_data["_SiTracks_Refitted_trackerHits/_SiTracks_Refitted_trackerHits.index"][iev]
    trk_hit_coll = event_data["_SiTracks_Refitted_trackerHits/_SiTracks_Refitted_trackerHits.collectionID"][iev]

    # Create a mapping from local (collection, index) pairs to global hit indices
    _, hit_idx_local_to_global = create_global_to_local_mapping(tracker_hit_data, iev, collectionIDs)

    rows = []
    cols = []
    for track_idx, (beg, end) in enumerate(zip(hit_beg, hit_end)):
        for ihit in range(beg, end):
            global_hit_idx = hit_idx_local_to_global.get((trk_hit_coll[ihit], trk_hit_idx[ihit]))
            if global_hit_idx is None:
                continue
            rows.append(track_idx)
            cols.append(global_hit_idx)

    return (
        (
            # Explicit dtypes keep index arrays integer even for events without associations.
            np.array(rows, dtype=np.int64),
            np.array(cols, dtype=np.int64),
            np.ones(len(rows), dtype=np.float64),  # weight is 1.0 for all associations
        ),
        hit_idx_local_to_global,
    )


# TODO: check correctness of this function
def create_cluster_to_hit_coo_matrix(event_data, iev, collectionIDs):
    """
    Creates the COO matrix indices and weights for the relationship between clusters and calorimeter hits.

    Args:
        event_data: The event data containing the calorimeter hit collections and the
            ``PandoraClusters`` hit-relation branches.
        iev: The index of the event to extract data for.
        collectionIDs: A dictionary mapping collection names to their IDs.

    Returns:
        tuple: ``((rows, cols, weights), calo_hit_idx_local_to_global)`` where
            - ``rows`` (int64): cluster indices,
            - ``cols`` (int64): global calorimeter-hit indices,
            - ``weights`` (float64): association weights, all 1.0,
            - ``calo_hit_idx_local_to_global``: maps ``(collectionID, local index)`` to the global hit index.
        Hits referenced by a cluster that belong to a collection not listed below are skipped.
    """
    # Extract calorimeter hit data
    calo_hit_data = event_data[
        [
            "ECALBarrel",
            "ECALEndcap",
            "HCALBarrel",
            "HCALEndcap",
            "HCALOther",
            "MUON",
        ]
    ]

    # Hits of cluster i are cluster_hit_idx/coll[begin[i]:end[i]] (end is exclusive).
    cluster_hit_begin = event_data["PandoraClusters/PandoraClusters.hits_begin"][iev]
    cluster_hit_end = event_data["PandoraClusters/PandoraClusters.hits_end"][iev]
    cluster_hit_idx = event_data["_PandoraClusters_hits/_PandoraClusters_hits.index"][iev]
    cluster_hit_coll = event_data["_PandoraClusters_hits/_PandoraClusters_hits.collectionID"][iev]

    _, calo_hit_idx_local_to_global = create_global_to_local_mapping(calo_hit_data, iev, collectionIDs)

    rows = []
    cols = []
    for cluster_idx, (beg, end) in enumerate(zip(cluster_hit_begin, cluster_hit_end)):
        for ihit in range(beg, end):
            global_hit_idx = calo_hit_idx_local_to_global.get((cluster_hit_coll[ihit], cluster_hit_idx[ihit]))
            if global_hit_idx is None:
                continue
            rows.append(cluster_idx)
            cols.append(global_hit_idx)

    return (
        (
            # Explicit dtypes keep index arrays integer even for events without associations.
            np.array(rows, dtype=np.int64),
            np.array(cols, dtype=np.int64),
            np.ones(len(rows), dtype=np.float64),  # weight is 1.0 for all associations
        ),
        calo_hit_idx_local_to_global,
    )


def process_calo_hit_data(event_data, iev, collectionIDs):
    """
    Build calorimeter-hit features and the genparticle-to-calohit association for one event.

    Args:
        event_data: The event data containing the calorimeter hit collections and the
            ``CalohitMCTruthLink`` branches.
        iev: The index of the event to extract data for.
        collectionIDs: A dictionary mapping collection names to their IDs.

    Returns:
        tuple: ``(hit_features, (gen_rows, hit_cols, weights), hit_idx_local_to_global)`` where
            ``gen_rows`` (int64) are genparticle indices (the link ``_to`` index), ``hit_cols``
            (int64) are global calo-hit indices and ``weights`` are the link weights.
            Links whose calo hit is not in one of the collections listed below are skipped.
    """
    feats = ["type", "energy", "position.x", "position.y", "position.z"]

    calo_hit_data = event_data[
        [
            "ECALBarrel",
            "ECALEndcap",
            "HCALBarrel",
            "HCALEndcap",
            "HCALOther",
            "MUON",
        ]
    ]

    # Create a mapping from global hit indices to local (collection, index) pairs and hit feature matrix
    hit_features, _, hit_idx_local_to_global = create_hit_feature_matrix_and_mapping(
        calo_hit_data, iev, collectionIDs, feats
    )

    # "_from" side of the link is the calo hit, "_to" side is the genparticle.
    calohit_to_gen_weight = event_data["CalohitMCTruthLink.weight"][iev]
    calohit_to_gen_calo_colid = event_data["_CalohitMCTruthLink_from/_CalohitMCTruthLink_from.collectionID"][iev]
    calohit_to_gen_calo_idx = event_data["_CalohitMCTruthLink_from/_CalohitMCTruthLink_from.index"][iev]
    calohit_to_gen_gen_idx = event_data["_CalohitMCTruthLink_to/_CalohitMCTruthLink_to.index"][iev]

    gen_rows = []
    hit_cols = []
    weights = []
    for calo_colid, calo_idx, gen_idx, weight in zip(
        calohit_to_gen_calo_colid,
        calohit_to_gen_calo_idx,
        calohit_to_gen_gen_idx,
        calohit_to_gen_weight,
    ):
        global_hit_idx = hit_idx_local_to_global.get((calo_colid, calo_idx))
        if global_hit_idx is None:
            # The link points at a hit outside the collections processed above; skip it
            # instead of raising KeyError.
            continue
        gen_rows.append(gen_idx)
        hit_cols.append(global_hit_idx)
        weights.append(weight)

    return (
        hit_features,
        (
            # Explicit dtypes keep index arrays integer even for events without links.
            np.array(gen_rows, dtype=np.int64),
            np.array(hit_cols, dtype=np.int64),
            np.array(weights),
        ),
        hit_idx_local_to_global,
    )


def process_tracker_hit_data(event_data, iev, collectionIDs):
    """
    Build tracker-hit features and a genparticle-to-tracker-hit association for one event.

    Tracker hits inherit the genparticle-to-track truth links: for every track link with
    positive weight, every hit of that track that belongs to one of the tracker collections
    listed below is linked to the genparticle with the track link's weight.

    Args:
        event_data: The event data containing the tracker hit collections, the
            ``SiTracks_Refitted`` hit-relation branches and the ``SiTracksMCTruthLink`` branches.
        iev: The index of the event to extract data for.
        collectionIDs: A dictionary mapping collection names to their IDs.

    Returns:
        tuple: ``(hit_feature_matrix, (gen_rows, hit_cols, weights), hit_idx_local_to_global)``
            where ``gen_rows`` (int64) are genparticle indices and ``hit_cols`` (int64) are
            global tracker-hit indices.
    """
    feats = ["type", "energy", "position.x", "position.y", "position.z"]

    tracker_hit_data = event_data[
        [
            "VXDTrackerHits",
            "VXDEndcapTrackerHits",
            "ITrackerHits",
            "OTrackerHits",
            # "ITrackerEndcapHits",
            # "OTrackerEndcapHits",
        ]
    ]

    # Create a mapping from global hit indices to local (collection, index) pairs and hit feature matrix
    hit_feature_matrix, _, hit_idx_local_to_global = create_hit_feature_matrix_and_mapping(
        tracker_hit_data, iev, collectionIDs, feats
    )

    # Hits of track i are trk_hit_idx/trk_hit_coll[hit_beg[i]:hit_end[i]] (end is exclusive).
    hit_beg = event_data["SiTracks_Refitted/SiTracks_Refitted.trackerHits_begin"][iev]
    hit_end = event_data["SiTracks_Refitted/SiTracks_Refitted.trackerHits_end"][iev]
    trk_hit_idx = event_data["_SiTracks_Refitted_trackerHits/_SiTracks_Refitted_trackerHits.index"][iev]
    trk_hit_coll = event_data["_SiTracks_Refitted_trackerHits/_SiTracks_Refitted_trackerHits.collectionID"][iev]

    gen_indices, track_indices, track_weights = genparticle_track_adj(event_data, iev)

    gen_rows = []
    hit_cols = []
    weights = []
    for gen_idx, track_idx, weight in zip(gen_indices, track_indices, track_weights):
        if not weight > 0:  # only positive-weight links (also skips NaN)
            continue
        for ihit in range(hit_beg[track_idx], hit_end[track_idx]):
            global_hit_idx = hit_idx_local_to_global.get((trk_hit_coll[ihit], trk_hit_idx[ihit]))
            if global_hit_idx is None:
                # Hit belongs to a collection not processed above (e.g. tracker endcaps).
                continue
            gen_rows.append(gen_idx)
            hit_cols.append(global_hit_idx)
            weights.append(weight)

    return (
        hit_feature_matrix,  # Tracker hit feature matrix
        (
            # Explicit dtypes keep index arrays integer even for events without links.
            np.array(gen_rows, dtype=np.int64),
            np.array(hit_cols, dtype=np.int64),
            np.array(weights),
        ),
        hit_idx_local_to_global,
    )


def gen_to_features(event_data, iev):
    """
    Extract generator-level (MCParticles) features for one event.

    The "MCParticles." prefix is stripped from the field names; pt, eta, phi and energy
    are derived from a ``vector`` four-vector built from (mass, momentum.x/y/z).

    Args:
        event_data: Event data containing the ``MCParticles`` collection.
        iev: The index of the event to extract data for.

    Returns:
        dict: feature name -> NumPy array (converted with ``awkward.to_numpy``), with keys in
        the order PDG, generatorStatus, charge, pt, eta, phi, sin_phi, cos_phi, energy,
        simulatorStatus, px, py, pz, mass.
    """
    prefix = "MCParticles" + "."
    record = event_data["MCParticles"][iev]
    cols = {name.replace(prefix, ""): record[name] for name in record.fields}

    four_momentum = vector.awk(
        awkward.zip(
            {
                "mass": cols["mass"],
                "x": cols["momentum.x"],
                "y": cols["momentum.y"],
                "z": cols["momentum.z"],
            }
        )
    )
    phi = four_momentum.phi

    # NOTE: key order is part of the output contract; keep it stable.
    features = {
        "PDG": cols["PDG"],
        "generatorStatus": cols["generatorStatus"],
        "charge": cols["charge"],
        "pt": four_momentum.pt,
        "eta": four_momentum.eta,
        "phi": phi,
        "sin_phi": np.sin(phi),
        "cos_phi": np.cos(phi),
        "energy": four_momentum.energy,
        "simulatorStatus": cols["simulatorStatus"],
        "px": cols["momentum.x"],
        "py": cols["momentum.y"],
        "pz": cols["momentum.z"],
        "mass": cols["mass"],
    }

    return {key: awkward.to_numpy(values) for key, values in features.items()}


# From https://bib-pubdb1.desy.de/record/81214/files/LC-DET-2006-004[1].pdf, eq12
# pT (in GeV/c) ≈ a [mm/s] * |Bz (in T) / omega (1/mm)|
# a = c * 10^(-15) = 3*10^(-4)
def track_pt(omega, bfield=B):
    """Return the transverse momentum implied by the track curvature ``omega``.

    Args:
        omega: Track curvature, per the reference above in 1/mm (scalar or NumPy array).
        bfield: Longitudinal magnetic field in T; defaults to the module constant ``B``.

    Returns:
        ``3 * 10**-4 * |bfield / omega|``.
    """
    coefficient = 3 * 10**-4
    curvature_ratio = np.abs(bfield / omega)
    return coefficient * curvature_ratio


def track_to_features(event_data, iev):
    """
    Extracts track features from the ``SiTracks_Refitted`` collection for a specific event.

    Args:
        event_data (awkward.Array): The event data containing the ``SiTracks_Refitted``,
            ``SiTracks_Refitted_dQdx`` and ``_SiTracks_Refitted_trackStates`` collections.
        iev (int): The index of the event to extract data for.

    Returns:
        dict: The extracted track features, including:
            - "type", "chi2", "ndf": Basic track properties.
            - "dEdx", "dEdxError": Energy deposition and its error.
            - "radiusOfInnermostHit": Radius of the innermost hit for each track.
            - "tanLambda", "D0", "phi", "omega", "Z0", "time": Track state properties.
            - "pt", "px", "py", "pz", "p": Momentum components and magnitude.
            - "eta": Pseudorapidity (0 for tracks with p == 0).
            - "sin_phi", "cos_phi": Sine and cosine of the azimuthal angle.
            - "elemtype": Element type (always 1 for tracks).
            - "q": Charge of the track (+1 or -1; 0 if omega is 0).

    Notes:
        - radiusOfInnermostHit uses the reference point of the track state with
          location == 2 ("AtFirstHit"); if a track has no such state, ``np.argmax`` falls
          back to its first state.
        - The kinematic properties are read from the first track state of each track
          (index ``trackStates_begin``).
        - The charge is set from the sign of the "omega" parameter.
    """
    track_coll = "SiTracks_Refitted"
    states_coll = "_SiTracks_Refitted_trackStates"
    track_arr = event_data[track_coll][iev]
    track_arr_dQdx = event_data["SiTracks_Refitted_dQdx"][iev]
    track_arr_trackStates = event_data[states_coll][iev]

    feats_from_track = ["type", "chi2", "ndf"]
    ret = {feat: track_arr[track_coll + "." + feat] for feat in feats_from_track}

    ret["dEdx"] = track_arr_dQdx["SiTracks_Refitted_dQdx.dQdx.value"]
    ret["dEdxError"] = track_arr_dQdx["SiTracks_Refitted_dQdx.dQdx.error"]

    # build the radiusOfInnermostHit variable
    num_tracks = len(ret["dEdx"])
    states_begin = track_arr[track_coll + ".trackStates_begin"]
    states_end = track_arr[track_coll + ".trackStates_end"]
    ref_x_all = track_arr_trackStates[states_coll + ".referencePoint.x"]
    ref_y_all = track_arr_trackStates[states_coll + ".referencePoint.y"]
    location_all = track_arr_trackStates[states_coll + ".location"]

    innermost_radius = []
    for itrack in range(num_tracks):
        # select the track states corresponding to itrack and pick the state AtFirstHit
        # https://github.com/key4hep/EDM4hep/blob/fe5a54046a91a7e648d0b588960db7841aebc670/edm4hep.yaml#L220
        ibegin = states_begin[itrack]
        iend = states_end[itrack]

        refX = ref_x_all[ibegin:iend]
        refY = ref_y_all[ibegin:iend]
        location = location_all[ibegin:iend]

        istate = np.argmax(location == 2)  # 2 refers to AtFirstHit

        innermost_radius.append(math.sqrt(refX[istate] ** 2 + refY[istate] ** 2))

    ret["radiusOfInnermostHit"] = np.array(innermost_radius)

    # get the index of the first track state
    trackstate_idx = event_data[track_coll][track_coll + ".trackStates_begin"][iev]
    # get the properties of the track at the first track state
    for k in ["tanLambda", "D0", "phi", "omega", "Z0", "time"]:
        ret[k] = awkward.to_numpy(event_data[states_coll][states_coll + "." + k][iev][trackstate_idx])

    ret["pt"] = awkward.to_numpy(track_pt(ret["omega"]))
    ret["px"] = awkward.to_numpy(np.cos(ret["phi"])) * ret["pt"]
    ret["py"] = awkward.to_numpy(np.sin(ret["phi"])) * ret["pt"]
    ret["pz"] = ret["pt"] * ret["tanLambda"]
    ret["p"] = np.sqrt(ret["px"] ** 2 + ret["py"] ** 2 + ret["pz"] ** 2)

    # Ufuncs called with `where=` leave unselected outputs uninitialised unless `out=` is
    # given, so pre-fill the outputs with zeros: p == 0 then gives cos_theta = 0 (eta = 0).
    cos_theta = np.zeros_like(ret["p"], dtype=np.result_type(ret["pz"], ret["p"]))
    np.divide(ret["pz"], ret["p"], out=cos_theta, where=ret["p"] > 0)
    # Guard against |cos_theta| slightly above 1 from rounding, which would make arccos NaN.
    cos_theta = np.clip(cos_theta, -1.0, 1.0)
    theta = np.arccos(cos_theta)
    tt = np.tan(theta / 2.0)
    log_tt = np.zeros_like(tt)
    np.log(tt, out=log_tt, where=tt > 0)
    eta = -log_tt
    eta[tt <= 0] = 0.0
    ret["eta"] = eta

    ret["sin_phi"] = np.sin(ret["phi"])
    ret["cos_phi"] = np.cos(ret["phi"])

    # track is always type 1
    ret["elemtype"] = 1 * np.ones(num_tracks, dtype=np.float32)

    ret["q"] = ret["omega"].copy()
    ret["q"][ret["q"] > 0] = 1
    ret["q"][ret["q"] < 0] = -1

    return ret


def cluster_to_features(event_data, iev, cluster_features=("position.x", "position.y", "position.z", "energy")):
    """
    Extracts cluster features for a specific event.

    Args:
        event_data: The event data containing the ``PandoraClusters`` collection.
        iev: The index of the event to extract data for.
        cluster_features (sequence of str): Cluster feature names to extract (without the
            "PandoraClusters." prefix). Default is
            ("position.x", "position.y", "position.z", "energy"); any list or tuple works.

    Returns:
        dict: feature name -> NumPy array of that feature for the specified event.

    Raises:
        ValueError: If a specified feature is not found in the event data.
    """
    cluster_data = event_data["PandoraClusters"]
    available = cluster_data.fields

    # Validate every requested feature before extracting anything.
    for feat in cluster_features:
        if f"PandoraClusters.{feat}" not in available:
            raise ValueError(f"Feature {feat} not found in PandoraClusters.")

    return {f"{feat}": awkward.to_numpy(cluster_data[f"PandoraClusters.{feat}"][iev]) for feat in cluster_features}


def create_genparticle_to_genparticle_coo_matrix(event_data, iev):
    """
    Creates the COO matrix indices and weights for the relationship between genparticles
    based on parent-daughter associations, walking the daughter relations of each particle.

    Args:
        event_data: The event data containing the MCParticles daughter relation branches.
        iev: The index of the event to extract data for.

    Returns:
        tuple: A tuple containing three arrays:
            - Row indices (parent genparticle indices, int64).
            - Column indices (daughter genparticle indices, int64).
            - Weights (association weights between parent and daughter genparticles, all 1.0, float64).
    """
    # Daughters of particle i are daughter_indices[daughters_begin[i]:daughters_end[i]].
    daughters_begin = event_data["MCParticles.daughters_begin"][iev]
    daughters_end = event_data["MCParticles.daughters_end"][iev]
    daughter_indices = event_data["_MCParticles_daughters/_MCParticles_daughters.index"][iev]

    coo_rows = []
    coo_cols = []
    for parent_idx, (beg, end) in enumerate(zip(daughters_begin, daughters_end)):
        for idaughter in range(beg, end):
            coo_rows.append(parent_idx)
            coo_cols.append(daughter_indices[idaughter])

    return (
        # Explicit dtypes keep index arrays integer even when there are no associations.
        np.array(coo_rows, dtype=np.int64),
        np.array(coo_cols, dtype=np.int64),
        np.ones(len(coo_rows), dtype=np.float64),
    )


def create_genparticle_to_genparticle_coo_matrix2(event_data, iev):
    """
    Creates the COO matrix indices and weights for the relationship between genparticles
    based on parent-daughter associations, read from each genparticle's parent links.

    This builds the same parent -> daughter matrix as
    create_genparticle_to_genparticle_coo_matrix, but walks the parent links instead of
    the daughter links. It is useful as a cross-check of the two link directions.

    Args:
        event_data: Mapping from branch name to per-event arrays. Must provide
            "MCParticles.parents_begin", "MCParticles.parents_end" and
            "_MCParticles_parents/_MCParticles_parents.index".
        iev: The index of the event to extract data for.

    Returns:
        tuple: A tuple containing three arrays:
            - Row indices (parent genparticle indices), int64.
            - Column indices (daughter genparticle indices), int64.
            - Weights (association weights between parent and daughter genparticles, all 1.0), float64.
        The dtypes are fixed explicitly so that an event without any parent link still yields
        integer index arrays instead of np.array([])'s default float64.
    """
    parents_begin = event_data["MCParticles.parents_begin"][iev]
    parents_end = event_data["MCParticles.parents_end"][iev]
    parent_indices = event_data["_MCParticles_parents/_MCParticles_parents.index"][iev]

    coo_rows = []
    coo_cols = []

    # Genparticle d's parents are the entries [begin[d], end[d]) of the flat parent-index list
    for daughter_idx, (beg, end) in enumerate(zip(parents_begin, parents_end)):
        for iparent in range(beg, end):
            coo_rows.append(parent_indices[iparent])
            coo_cols.append(daughter_idx)

    return (
        np.array(coo_rows, dtype=np.int64),
        np.array(coo_cols, dtype=np.int64),
        np.ones(len(coo_rows), dtype=np.float64),  # every parent-daughter link has weight 1.0
    )

## Main

Data to extract:
- [x] hits (all tracker and calo hit features in the event)
- [x] genparticles (all genparticle features in the event, noting that the genparticles are in the form of a decay tree)
- [x] genparticle_to_genparticle (sparse association matrix for the genparticle decay tree)
- [x] hits_to_genparticles (sparse association matrix)
- [x] tracks (all track features in the event)
- [x] tracks_to_hits (sparse association matrix between tracks and tracker hits)
- [x] clusters (all cluster features in the event)
- [x] clusters_to_hits (sparse association matrix between clusters and calo hits)

Additional data to extract:
- [x] genparticle_to_track (sparse association matrix between genparticles and tracks)

In [None]:
root_files_dir = Path("/mnt/ceph/users/ewulff/data/cld/Dec3/subfolder_0/")
root_file = root_files_dir / "reco_p8_ee_tt_ecm365_10000.root"
fi = uproot.open(root_file)
ev = fi["events"]

# which event to pick from the file
iev = 2

# Map each podio collection name to its collection ID (both metadata branches are read in one call)
id_table = fi.get("podio_metadata").arrays(["events___idTable/m_names", "events___idTable/m_collectionIDs"])
collectionIDs = dict(
    zip(id_table["events___idTable/m_names"][0], id_table["events___idTable/m_collectionIDs"][0])
)

event_data = get_event_data(ev)

# Extract genparticle, track, and cluster features
gen_features = gen_to_features(event_data, iev)
track_features = track_to_features(event_data, iev)
cluster_features = cluster_to_features(
    event_data, iev, cluster_features=["position.x", "position.y", "position.z", "energy", "type"]
)

# Process calorimeter hit data
calo_hit_features, genparticle_to_calo_hit_matrix, calo_hit_idx_local_to_global = process_calo_hit_data(
    event_data, iev, collectionIDs
)

# Process tracker hit data
tracker_hit_features, genparticle_to_tracker_hit_matrix, tracker_hit_idx_local_to_global = process_tracker_hit_data(
    event_data, iev, collectionIDs
)

# Create the track-to-trackerhit adjacency matrix
track_to_tracker_hit_matrix, tracker_hit_idx_local_to_global_2 = create_track_to_hit_coo_matrix(
    event_data, iev, collectionIDs
)
assert (
    tracker_hit_idx_local_to_global == tracker_hit_idx_local_to_global_2
), "Local to global tracker hit index mapping mismatch!"

# Create the cluster-to-clusterhit adjacency matrix
cluster_to_cluster_hit_matrix, calo_hit_idx_local_to_global_2 = create_cluster_to_hit_coo_matrix(
    event_data, iev, collectionIDs
)
assert (
    calo_hit_idx_local_to_global == calo_hit_idx_local_to_global_2
), "Local to global calorimeter hit index mapping mismatch!"

# Create the genparticle-to-track adjacency matrix
gp_to_track_matrix = genparticle_track_adj(event_data, iev)

# Create the genparticle-to-genparticle adjacency matrix from the daughter links
gp_to_gp = create_genparticle_to_genparticle_coo_matrix(event_data, iev)

# Build the same matrix from the parent links to check consistency between the two methods
gp_to_gp2 = create_genparticle_to_genparticle_coo_matrix2(event_data, iev)

# Number of genparticles in the event. It is taken from the already-loaded event_data (the
# daughters_begin branch has one entry per genparticle) instead of re-reading a branch from disk.
n_gp = len(event_data["MCParticles.daughters_begin"][iev])

# Create COO matrices from the index/weight triplets in gp_to_gp and gp_to_gp2
coo_matrix_gp_to_gp = coo_matrix((gp_to_gp[2], (gp_to_gp[0], gp_to_gp[1])), shape=(n_gp, n_gp))
coo_matrix_gp_to_gp2 = coo_matrix((gp_to_gp2[2], (gp_to_gp2[0], gp_to_gp2[1])), shape=(n_gp, n_gp))
# The two matrices must agree entry by entry (duplicate entries are summed, as todense() would do).
# Checking the sparse difference avoids materialising two dense n_gp x n_gp matrices.
assert (
    coo_matrix_gp_to_gp - coo_matrix_gp_to_gp2
).count_nonzero() == 0, "Genparticle-to-genparticle matrices from daughter and parent links differ!"
# Define the output file path
output_file = Path("extracted_features.hdf5")

In [None]:
# Container type returned by each feature extractor
tuple(map(type, (gen_features, track_features, cluster_features)))

In [None]:
# Array type stored inside each feature container
tuple(type(arr) for arr in (gen_features["px"], track_features["px"], cluster_features["position.x"]))

In [None]:
# Container type of the calorimeter and tracker hit features
tuple(map(type, (calo_hit_features, tracker_hit_features)))

In [None]:
# Array type of the hit energies for both hit collections
tuple(type(hits["energy"]) for hits in (calo_hit_features, tracker_hit_features))

In [None]:
# Outer type of each adjacency-matrix triplet
tuple(
    map(
        type,
        (
            genparticle_to_calo_hit_matrix,
            genparticle_to_tracker_hit_matrix,
            track_to_tracker_hit_matrix,
            cluster_to_cluster_hit_matrix,
            gp_to_gp,
        ),
    )
)

In [None]:
# Type of the first component (row indices) of each adjacency-matrix triplet
tuple(
    type(matrix[0])
    for matrix in (
        genparticle_to_calo_hit_matrix,
        genparticle_to_tracker_hit_matrix,
        track_to_tracker_hit_matrix,
        cluster_to_cluster_hit_matrix,
        gp_to_gp,
    )
)

In [None]:
# Scalar type of the first row index in each adjacency-matrix triplet
tuple(
    type(matrix[0][0])
    for matrix in (
        genparticle_to_calo_hit_matrix,
        genparticle_to_tracker_hit_matrix,
        track_to_tracker_hit_matrix,
        cluster_to_cluster_hit_matrix,
        gp_to_gp,
    )
)

## Sanity checks

In [None]:
# Stack calorimeter hits followed by tracker hits into one table used for the event display
df = pandas.DataFrame()

for col_name, feat_name in (("px", "position.x"), ("py", "position.y"), ("pz", "position.z")):
    df[col_name] = np.concatenate([calo_hit_features[feat_name], tracker_hit_features[feat_name]])

# Each collection is scaled by 1000 before concatenation (same order of operations as the dtypes may differ)
df["energy"] = np.concatenate([1000 * calo_hit_features["energy"], 1000 * tracker_hit_features["energy"]])
df["plotsize"] = 0.0
df["subdetector"] = np.concatenate([calo_hit_features["subdetector"], tracker_hit_features["subdetector"]])

# Marker size derived from the hit energy, with a different scale per subdetector code;
# codes not listed here keep plotsize 0.0
plotsize_from_energy = {
    0: lambda energy: energy / 5.0,  # ECAL
    1: lambda energy: energy / 10.0,  # HCAL
    2: lambda energy: energy * 100.0,  # muon chambers
    3: lambda energy: energy * 100.0,  # tracker
}
for det_code, to_plotsize in plotsize_from_energy.items():
    in_det = df["subdetector"] == det_code
    df.loc[in_det, "plotsize"] = to_plotsize(df.loc[in_det, "energy"])

In [None]:
# tracks
def helix_eq(charge, bfield, v, n_points=10, path_length=2.0):
    """Sample the helical trajectory of a charged particle that starts at the origin.

    The particle is propagated in a uniform magnetic field along z. The trajectory is
    sampled at ``n_points`` equally spaced times spanning ``path_length / (c * beta)``,
    i.e. a total path length of ``path_length`` metres (``c`` is in m/s).

    Args:
        charge (float): Particle charge in units of e. Must be non-zero, since the bending
            radius divides by it. Neutral particles need a straight-line treatment instead.
        bfield (float): Magnetic field strength in tesla; its sign sets the bending direction.
        v (vector object): Momentum four-vector, e.g. ``vector.obj(px=..., py=..., pz=..., mass=...)``.
            The function reads ``pt``, ``phi``, ``pz``, ``mass``, ``gamma`` and ``beta``.
            Momenta are presumably in GeV (the 0.3 factor assumes it). TODO: confirm.
        n_points (int): Number of sampled points along the trajectory (default 10).
        path_length (float): Path length covered by the samples, in metres (default 2.0).

    Returns:
        tuple: Three lists (x, y, z) with the sampled coordinates, multiplied by the
        module-level ``scale`` (presumably metres -> millimetres).
    """
    # Signed bending radius R = pT / (0.3 q B); the sign of R carries the bending direction
    R = v.pt / (charge * 0.3 * bfield)
    omega = charge * 0.3 * bfield / (v.gamma * v.mass)
    t_values = np.linspace(0, path_length / (c * v.beta), n_points)
    # Phase of the particle on the circle; the offset -pi/2 puts the circle centre
    # perpendicular to the initial transverse momentum direction phi
    phase = omega * c * t_values + v.phi - np.pi / 2
    x = list(scale * R * np.cos(phase) - scale * R * np.cos(v.phi - np.pi / 2))
    y = list(scale * R * np.sin(phase) - scale * R * np.sin(v.phi - np.pi / 2))
    # Longitudinal motion is uniform: z = beta_z * c * t
    z = list(scale * v.pz * c * t_values / (v.gamma * v.mass))
    return x, y, z


# Per-track kinematics and charge; every track is drawn with the charged-pion mass hypothesis
track_px, track_py, track_pz, track_charge = (track_features[key] for key in ("px", "py", "pz", "q"))
track_mass = np.zeros_like(track_px) + pion_mass

# Helix points of all tracks concatenated; a trailing None after each track breaks the plotly line
track_x, track_y, track_z = [], [], []
for irow in range(len(track_px)):
    v = vector.obj(px=track_px[irow], py=track_py[irow], pz=track_pz[irow], mass=track_mass[irow])
    x, y, z = helix_eq(track_charge[irow], B, v)
    for accumulated, new_points in ((track_x, x), (track_y, y), (track_z, z)):
        accumulated.extend(new_points + [None])

In [None]:
# Legend label and marker colour for each subdetector code used in the hit table
labels = dict(enumerate(["Raw ECAL hit", "Raw HCAL hit", "Raw Muon chamber hit", "Raw tracker hit"]))

subdetector_color = dict(
    enumerate(
        [
            "steelblue",  # ECAL
            "green",  # HCAL
            "orange",  # MUON
            "red",  # Tracker
        ]
    )
)

In [None]:
traces = []

# raw hits: one trace per subdetector
for subdetector in [0, 1, 2, 3]:
    # Select this subdetector's rows once so that positions and marker sizes refer to the same hits
    hits = df[df["subdetector"] == subdetector]
    # log(0) = -inf is expected for zero-size hits and is clipped up to the minimum marker size
    with np.errstate(divide="ignore"):
        marker_size = np.clip(2 + 2 * np.log(hits["plotsize"]), 1, 15)

    trace = go.Scatter3d(
        x=np.clip(hits["px"], -4000, 4000),
        y=np.clip(hits["py"], -4000, 4000),
        z=np.clip(hits["pz"], -4000, 4000),
        mode="markers",
        marker=dict(
            size=marker_size,
            color=subdetector_color[subdetector],
            opacity=0.8,
        ),
        name=labels[subdetector],
    )
    traces.append(trace)

# add the tracks
trace = go.Scatter3d(
    x=np.array(track_x),
    y=np.array(track_y),
    z=np.array(track_z),
    mode="lines",
    name="Track",
    line=dict(color="red"),
)
traces.append(trace)

# add the clusters
trace = go.Scatter3d(
    x=cluster_features["position.x"],
    y=cluster_features["position.y"],
    z=cluster_features["position.z"],
    mode="markers",
    marker=dict(
        size=cluster_features["energy"],
        color="blue",
        opacity=0.8,
    ),
    name="ECAL/HCAL clusters",
)
traces.append(trace)

# Customize the axis names
layout = go.Layout(
    scene=dict(
        xaxis=dict(title="", showticklabels=False),
        yaxis=dict(title="", showticklabels=False),
        zaxis=dict(title="", showticklabels=False),
        camera=dict(
            up=dict(x=1, y=0, z=0),  # Sets the orientation of the camera
            center=dict(x=0, y=0, z=0),  # Sets the center point of the plot
            eye=dict(x=0, y=0, z=2.0),  # Sets the position of the camera
        ),
    ),
    legend=dict(x=0.8, y=0.5, font=dict(size=16)),  # https://plotly.com/python/legend/
    showlegend=True,
    width=700,
    height=700,
)

# Create the figure and display the plot
fig = go.Figure(data=traces, layout=layout)
fig.update_traces(
    marker_line_width=0, selector=dict(type="scatter3d")
)  # for plotly to avoid plotting white spots when things overlap
# fig.write_image("pic_tracks_rawcalohits.pdf", width=1000, height=1000, scale=2)

In [None]:
# Coarse particle class per |PDG id|, used for the legend
pdg_dict = {
    22: "photon",
    11: "electron",
    13: "muon",
    130: "n. hadron",  # neutral hadron
    211: "ch. hadron",  # charged hadron
}

# Line colour per particle class; the None key covers the separator points between polylines
color_dict = {
    "photon": "red",
    "electron": "green",
    "muon": "purple",
    "n. hadron": "orange",
    "ch. hadron": "blue",
    None: "black",
}

# Kinematics, charge and |PDG id| of the generator particles
gen_px, gen_py, gen_pz = (gen_features[key] for key in ("px", "py", "pz"))
gen_mass, gen_charge = gen_features["mass"], gen_features["charge"]
gen_pdg = awkward.to_numpy(np.absolute(gen_features["PDG"]))

# Anything that is not a muon, electron or photon is relabelled as a generic charged (211)
# or neutral (130) hadron. Without a genstatus==1 filter the charge can be fractional,
# hence the |q| > 0 test.
is_other_species = (gen_pdg != 13) & (gen_pdg != 11) & (gen_pdg != 22)
gen_pdg[is_other_species & awkward.to_numpy(np.abs(gen_charge) > 0)] = 211
gen_pdg[is_other_species & awkward.to_numpy(np.abs(gen_charge) == 0)] = 130

# One polyline per generator particle: a helix for charged particles and a straight segment
# along the momentum direction for neutral ones. Each polyline is followed by a None separator,
# and pdg_list carries the class label of every point (None for the separators).
mc_x, mc_y, mc_z, pdg_list = [], [], [], []
for irow in range(len(gen_px)):
    v = vector.obj(px=gen_px[irow], py=gen_py[irow], pz=gen_pz[irow], mass=gen_mass[irow])
    if gen_charge[irow] != 0:
        x, y, z = helix_eq(gen_charge[irow], B, v)
        this_mc_x, this_mc_y, this_mc_z = x, y, z
    else:
        this_mc_x, this_mc_y, this_mc_z = (
            [0, np.clip(scale * component / v.mag, -4000, 4000)] for component in (v.px, v.py, v.pz)
        )

    pdg_list.extend([pdg_dict[gen_pdg[irow]]] * len(this_mc_x))
    for coords, points in ((mc_x, this_mc_x), (mc_y, this_mc_y), (mc_z, this_mc_z)):
        coords.extend(points)
        coords.append(None)
    pdg_list.append(None)

In [None]:
# Create 3D scatter plot with one trace per particle class
traces = []
# Particle classes in order of first appearance. None only separates polylines, so it must not
# get a (necessarily empty) trace and "None" legend entry of its own. A list also keeps the legend
# order reproducible, unlike iterating a set of strings.
unique_particles = list(dict.fromkeys(p for p in pdg_list if p is not None))

# Convert the point lists once rather than once per particle class
# (object arrays, because of the None separators)
mc_x_arr, mc_y_arr, mc_z_arr = np.array(mc_x), np.array(mc_y), np.array(mc_z)

for particle in unique_particles:
    # Points of this class plus the None that terminates each of its polylines,
    # so plotly does not join consecutive tracks
    indices = [
        i
        for i, p in enumerate(pdg_list)
        if p == particle or (p is None and i > 0 and pdg_list[i - 1] == particle)
    ]

    # Create a separate trace for each particle class
    traces.append(
        go.Scatter3d(
            x=mc_x_arr[indices],
            y=mc_y_arr[indices],
            z=mc_z_arr[indices],
            mode="lines",
            line=dict(color=color_dict[particle]),  # Assign color for the particle
            name=f"{particle}",  # Add particle name to the legend
            showlegend=True,
        )
    )

# Customize the axis names
layout = go.Layout(
    scene=dict(
        xaxis=dict(title="", showticklabels=False),
        yaxis=dict(title="", showticklabels=False),
        zaxis=dict(title="", showticklabels=False),
        camera=dict(
            up=dict(x=1, y=0, z=0),  # Sets the orientation of the camera
            center=dict(x=0, y=0, z=0),  # Sets the center point of the plot
            eye=dict(x=0, y=0, z=2.0),  # Sets the position of the camera
        ),
    ),
    showlegend=True,
    width=700,
    height=700,
)

# Create the figure and display the plot
fig = go.Figure(data=traces, layout=layout)
fig.update_traces(
    marker_line_width=0, selector=dict(type="scatter3d")
)  # Avoid plotting white spots when things overlap
# fig.write_image("pic_particles_legend.pdf", width=1000, height=1000, scale=2)

In [None]:
# check the alignment of extrapolated tracks and their associated hits
fig, axes = plt.subplots(3, 3, figsize=(6, 6))  # Adjust figsize as needed
axes = axes.ravel()

# One panel per track, at most one per subplot, so events with fewer than 9 tracks do not raise IndexError
n_panels = min(len(axes), len(track_px))

for itrk in range(n_panels):
    plt.sca(axes[itrk])
    v = vector.obj(px=track_px[itrk], py=track_py[itrk], pz=track_pz[itrk], mass=track_mass[itrk])

    # Get the global hit indices associated with the current track
    # track_to_tracker_hit_matrix is a tuple where the first element contains the track indices and the second element contains the hit indices
    hit_indices = track_to_tracker_hit_matrix[1][track_to_tracker_hit_matrix[0] == itrk]

    # Extract the corresponding hit positions
    hs_x = tracker_hit_features["position.x"][hit_indices]
    hs_z = tracker_hit_features["position.z"][hit_indices]

    # Calculate the helix trajectory
    x, y, z = helix_eq(track_charge[itrk], B, v)

    # Plot the track and associated hits
    plt.plot(x, z, label="Track")
    plt.scatter(hs_x, hs_z, color="red", marker=".", label="Tracker hit")
    plt.xlim(-2000, 2000)
    plt.ylim(-2000, 2000)
    plt.xticks([])
    plt.yticks([])
    axes[itrk].set_box_aspect(1)

# Hide the panels that have no track to show
for ax in axes[n_panels:]:
    ax.set_axis_off()

if n_panels > 0:
    # Legend on the top-right panel when it is used, otherwise on the last filled panel
    axes[min(2, n_panels - 1)].legend(loc="upper right")
plt.tight_layout()

In [None]:
# Check the alignment of clusters and their associated hits
fig, axes = plt.subplots(3, 3, figsize=(6, 6))  # Adjust figsize as needed
axes = axes.ravel()

# One panel per cluster, at most one per subplot, so events with fewer than 9 clusters do not raise IndexError
n_panels = min(len(axes), len(cluster_features["energy"]))

for icls in range(n_panels):
    plt.sca(axes[icls])

    # Get the global hit indices associated with the current cluster
    # cluster_to_cluster_hit_matrix is a tuple where the first element contains the cluster indices and the second element contains the hit indices
    hit_indices = cluster_to_cluster_hit_matrix[1][cluster_to_cluster_hit_matrix[0] == icls]

    # Extract the corresponding hit positions
    hs_x = calo_hit_features["position.x"][hit_indices]
    hs_y = calo_hit_features["position.y"][hit_indices]

    # Plot the cluster and associated hits
    plt.scatter(
        cluster_features["position.x"][icls],
        cluster_features["position.y"][icls],
        s=100 * cluster_features["energy"][icls],
        alpha=0.5,
        label="cluster",
    )
    plt.scatter(hs_x, hs_y, color="red", marker=".", label="hit")
    plt.xlim(-4000, 4000)
    plt.ylim(-4000, 4000)
    plt.xticks([])
    plt.yticks([])
    axes[icls].set_box_aspect(1)

# Hide the panels that have no cluster to show
for ax in axes[n_panels:]:
    ax.set_axis_off()

if n_panels > 0:
    # Legend on the top-right panel when it is used, otherwise on the last filled panel
    axes[min(2, n_panels - 1)].legend(loc="upper right")
plt.tight_layout()

## Save processed data

### HDF5

In [None]:
def save_event_to_hdf5(
    output_file,
    iev,
    gen_features,
    track_features,
    cluster_features,
    calo_hit_features,
    tracker_hit_features,
    genparticle_to_calo_hit_matrix,
    genparticle_to_tracker_hit_matrix,
    track_to_tracker_hit_matrix,
    cluster_to_cluster_hit_matrix,
    gp_to_track_matrix,
    gp_to_gp,
    mode="w",
):
    """Write one event's features and sparse adjacency matrices into group ``event_{iev}`` of an HDF5 file.

    Feature dictionaries are stored as ``<name>/<feature>`` datasets. Each adjacency matrix
    (a (rows, cols, weights) triplet) is stored as ``<name>/rows``, ``<name>/cols`` and
    ``<name>/weights``.

    Args:
        output_file: Path of the HDF5 file to write.
        iev: Event index, used in the group name.
        gen_features, track_features, cluster_features: Dicts of per-object feature arrays.
        calo_hit_features, tracker_hit_features: Dicts of per-hit feature arrays
            (converted with ``awkward.to_numpy`` before writing).
        genparticle_to_calo_hit_matrix, genparticle_to_tracker_hit_matrix,
        track_to_tracker_hit_matrix, cluster_to_cluster_hit_matrix, gp_to_track_matrix,
        gp_to_gp: Adjacency matrices as (rows, cols, weights) sequences.
        mode: h5py file mode. The default "w" truncates an existing file, so repeated calls keep
            only the last event. Pass "a" to accumulate several events in the same file.

    Raises:
        ValueError: (from h5py) if ``event_{iev}`` already exists in the file when appending.
    """
    # Written as given
    plain_feature_sets = {
        "gen_features": gen_features,
        "track_features": track_features,
        "cluster_features": cluster_features,
    }
    # Hit features are converted with awkward.to_numpy first
    hit_feature_sets = {
        "calo_hit_features": calo_hit_features,
        "tracker_hit_features": tracker_hit_features,
    }
    matrices = {
        "genparticle_to_calo_hit_matrix": genparticle_to_calo_hit_matrix,
        "genparticle_to_tracker_hit_matrix": genparticle_to_tracker_hit_matrix,
        "track_to_tracker_hit_matrix": track_to_tracker_hit_matrix,
        "cluster_to_cluster_hit_matrix": cluster_to_cluster_hit_matrix,
        "gp_to_track_matrix": gp_to_track_matrix,
        "gp_to_gp": gp_to_gp,
    }

    with h5py.File(output_file, mode) as h5f:
        event_group = h5f.create_group(f"event_{iev}")

        for set_name, features in plain_feature_sets.items():
            for key, value in features.items():
                event_group.create_dataset(f"{set_name}/{key}", data=value)

        for set_name, features in hit_feature_sets.items():
            for key, value in features.items():
                event_group.create_dataset(f"{set_name}/{key}", data=awkward.to_numpy(value))

        for matrix_name, matrix in matrices.items():
            for position, suffix in enumerate(("rows", "cols", "weights")):
                event_group.create_dataset(f"{matrix_name}/{suffix}", data=matrix[position])

In [None]:
output_dir = Path("/mnt/ceph/users/ewulff/data/cld/hdf5")


def _read_podio_collection_ids(root_file_handle):
    """Return the {collection name: collection ID} table from an open podio ROOT file's metadata tree."""
    id_table = root_file_handle.get("podio_metadata").arrays(
        ["events___idTable/m_names", "events___idTable/m_collectionIDs"]
    )
    return dict(zip(id_table["events___idTable/m_names"][0], id_table["events___idTable/m_collectionIDs"][0]))


def _extract_event_products(event_data, iev, collection_ids):
    """Compute every feature dict and adjacency matrix stored for one event.

    Returns:
        tuple: (plain_features, hit_features, matrices). Each is a dict keyed by HDF5 sub-group name;
        hit features are later converted with awkward.to_numpy, plain features are written as given.
    """
    gen_features = gen_to_features(event_data, iev)
    track_features = track_to_features(event_data, iev)
    cluster_features = cluster_to_features(
        event_data, iev, cluster_features=["position.x", "position.y", "position.z", "energy", "type"]
    )
    calo_hit_features, genparticle_to_calo_hit_matrix, _ = process_calo_hit_data(event_data, iev, collection_ids)
    tracker_hit_features, genparticle_to_tracker_hit_matrix, _ = process_tracker_hit_data(
        event_data, iev, collection_ids
    )
    track_to_tracker_hit_matrix, _ = create_track_to_hit_coo_matrix(event_data, iev, collection_ids)
    cluster_to_cluster_hit_matrix, _ = create_cluster_to_hit_coo_matrix(event_data, iev, collection_ids)

    plain_features = {
        "gen_features": gen_features,
        "track_features": track_features,
        "cluster_features": cluster_features,
    }
    hit_features = {
        "calo_hit_features": calo_hit_features,
        "tracker_hit_features": tracker_hit_features,
    }
    matrices = {
        "genparticle_to_calo_hit_matrix": genparticle_to_calo_hit_matrix,
        "genparticle_to_tracker_hit_matrix": genparticle_to_tracker_hit_matrix,
        "track_to_tracker_hit_matrix": track_to_tracker_hit_matrix,
        "cluster_to_cluster_hit_matrix": cluster_to_cluster_hit_matrix,
        "gp_to_track_matrix": genparticle_track_adj(event_data, iev),
        "gp_to_gp": create_genparticle_to_genparticle_coo_matrix(event_data, iev),
    }
    return plain_features, hit_features, matrices


def _write_event_group(h5f, group_name, plain_features, hit_features, matrices):
    """Write one event into a new group of h5f, removing the group again if any write fails."""
    event_group = h5f.create_group(group_name)
    try:
        for set_name, features in plain_features.items():
            for key, value in features.items():
                event_group.create_dataset(f"{set_name}/{key}", data=value)
        for set_name, features in hit_features.items():
            for key, value in features.items():
                event_group.create_dataset(f"{set_name}/{key}", data=awkward.to_numpy(value))
        for matrix_name, matrix in matrices.items():
            for position, suffix in enumerate(("rows", "cols", "weights")):
                event_group.create_dataset(f"{matrix_name}/{suffix}", data=matrix[position])
    except Exception:
        # Never leave a half-written event behind; readers would see a partial event
        del h5f[group_name]
        raise


def process_root_files(data_dir, output_dir, events_per_hdf5=100, max_events=None, max_root_files=None):
    """Convert every event of every ROOT file under data_dir into HDF5 files of events_per_hdf5 events.

    Events are numbered globally and stored as groups ``event_{n}`` in files named
    ``events_{first}_to_{first + events_per_hdf5 - 1}.hdf5`` inside output_dir.

    Args:
        data_dir: Directory searched recursively for ``*.root`` files (processed in sorted order).
        output_dir: Destination directory; created if missing.
        events_per_hdf5: Number of events per output file.
        max_events: Stop after this many events in total (None = no limit).
        max_root_files: Stop after this many successfully processed ROOT files (None = no limit).

    A ROOT file that raises is reported and skipped. Events it wrote before the error stay in
    the output, but a failing event is never left half-written.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    # Sorted so the global event numbering and the split into files are reproducible
    root_files = sorted(Path(data_dir).rglob("*.root"))

    event_counter = 0
    root_counter = 0
    h5f = None  # current output chunk; opened lazily so no empty trailing file is created

    try:
        for root_file in tqdm(root_files, desc="Processing ROOT files", total=max_root_files or len(root_files)):
            try:
                with uproot.open(root_file) as fi:
                    # Collection IDs come from this file's own metadata, not from another file
                    collection_ids = _read_podio_collection_ids(fi)
                    ev = fi["events"]
                    event_data = get_event_data(ev)
                    for iev in range(ev.num_entries):
                        # Extract first, so a failing event does not create an output group at all
                        plain_features, hit_features, matrices = _extract_event_products(
                            event_data, iev, collection_ids
                        )

                        if h5f is None:
                            output_file = output_dir / Path(
                                f"events_{event_counter}_to_{event_counter + events_per_hdf5 - 1}.hdf5"
                            )
                            h5f = h5py.File(output_file, "w")
                            print(f"Created new HDF5 file: {output_file}")

                        _write_event_group(h5f, f"event_{event_counter}", plain_features, hit_features, matrices)
                        event_counter += 1

                        if max_events is not None and event_counter >= max_events:
                            print(f"Reached max_events limit: {max_events}. Stopping processing.")
                            return

                        # Close the full chunk; the next event opens a new file
                        if event_counter % events_per_hdf5 == 0:
                            h5f.close()
                            h5f = None
            except Exception as e:
                print(f"Error processing {root_file}: {e}")
                continue

            print(f"Processed: {root_file}")
            root_counter += 1

            if max_root_files is not None and root_counter >= max_root_files:
                print(f"Reached max_root_files limit: {max_root_files}. Stopping processing.")
                return
    finally:
        # Close the last (possibly partial) HDF5 file on every exit path
        if h5f is not None:
            h5f.close()

In [None]:
# process_root_files(data_dir="/mnt/ceph/users/ewulff/data/cld/", output_dir=output_dir, events_per_hdf5=50, max_events=None, max_root_files=2)

In [None]:
output_dir = Path("/mnt/ceph/users/ewulff/data/cld/hdf5")


def merge_root_files_to_hdf5(data_dir, output_dir, files_per_hdf5=10):
    """Write all events of the ROOT files under data_dir into HDF5 files of files_per_hdf5 ROOT files each.

    Events are stored as groups ``event_{file_counter}_{iev}`` in ``merged_features_{k}.hdf5``
    inside output_dir. file_counter counts successfully processed ROOT files.

    Args:
        data_dir: Directory (str or Path) searched recursively for ``*.root`` files, in sorted order.
        output_dir: Destination directory; created if missing.
        files_per_hdf5: Number of successfully processed ROOT files per output file.

    A ROOT file that raises is reported and skipped, and any groups already written for it are
    removed. The next file can then reuse the ``event_{file_counter}_*`` names without collisions.
    """

    def read_collection_ids(root_file_handle):
        # {collection name: collection ID} from this file's podio metadata tree
        id_table = root_file_handle.get("podio_metadata").arrays(
            ["events___idTable/m_names", "events___idTable/m_collectionIDs"]
        )
        return dict(
            zip(id_table["events___idTable/m_names"][0], id_table["events___idTable/m_collectionIDs"][0])
        )

    def event_datasets(event_data, iev, collection_ids):
        # (dataset path, array) pairs for one event, in write order
        gen_features = gen_to_features(event_data, iev)
        track_features = track_to_features(event_data, iev)
        cluster_features = cluster_to_features(
            event_data, iev, cluster_features=["position.x", "position.y", "position.z", "energy", "type"]
        )
        calo_hit_features, genparticle_to_calo_hit_matrix, _ = process_calo_hit_data(
            event_data, iev, collection_ids
        )
        tracker_hit_features, genparticle_to_tracker_hit_matrix, _ = process_tracker_hit_data(
            event_data, iev, collection_ids
        )
        track_to_tracker_hit_matrix, _ = create_track_to_hit_coo_matrix(event_data, iev, collection_ids)
        cluster_to_cluster_hit_matrix, _ = create_cluster_to_hit_coo_matrix(event_data, iev, collection_ids)
        gp_to_track_matrix = genparticle_track_adj(event_data, iev)
        gp_to_gp = create_genparticle_to_genparticle_coo_matrix(event_data, iev)

        datasets = []
        for set_name, features in (
            ("gen_features", gen_features),
            ("track_features", track_features),
            ("cluster_features", cluster_features),
        ):
            datasets += [(f"{set_name}/{key}", value) for key, value in features.items()]
        for set_name, features in (
            ("calo_hit_features", calo_hit_features),
            ("tracker_hit_features", tracker_hit_features),
        ):
            datasets += [(f"{set_name}/{key}", awkward.to_numpy(value)) for key, value in features.items()]
        for matrix_name, matrix in (
            ("genparticle_to_calo_hit_matrix", genparticle_to_calo_hit_matrix),
            ("genparticle_to_tracker_hit_matrix", genparticle_to_tracker_hit_matrix),
            ("track_to_tracker_hit_matrix", track_to_tracker_hit_matrix),
            ("cluster_to_cluster_hit_matrix", cluster_to_cluster_hit_matrix),
            ("gp_to_track_matrix", gp_to_track_matrix),
            ("gp_to_gp", gp_to_gp),
        ):
            datasets += [
                (f"{matrix_name}/{suffix}", matrix[position])
                for position, suffix in enumerate(("rows", "cols", "weights"))
            ]
        return datasets

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    file_counter = 0
    hdf5_counter = 0
    h5f = None  # current output file; opened lazily so no empty trailing file is created

    try:
        for root_file in tqdm(sorted(Path(data_dir).rglob("*.root")), desc="Processing ROOT files"):
            written_groups = []  # groups written for this ROOT file, rolled back if it fails
            try:
                with uproot.open(root_file) as fi:
                    collection_ids = read_collection_ids(fi)
                    ev = fi["events"]
                    # The feature extractors expect the loaded event data, not the raw TTree
                    event_data = get_event_data(ev)
                    for iev in range(ev.num_entries):
                        datasets = event_datasets(event_data, iev, collection_ids)

                        if h5f is None:
                            output_file = output_dir / Path(f"merged_features_{hdf5_counter}.hdf5")
                            h5f = h5py.File(output_file, "w")
                            print(f"Created new HDF5 file: {output_file}")

                        group_name = f"event_{file_counter}_{iev}"
                        event_group = h5f.create_group(group_name)
                        written_groups.append(group_name)
                        for path, data in datasets:
                            event_group.create_dataset(path, data=data)
            except Exception as e:
                print(f"Error processing {root_file}: {e}")
                # Drop this file's partial output so the next file can reuse its group names
                if h5f is not None:
                    for group_name in written_groups:
                        if group_name in h5f:
                            del h5f[group_name]
                continue

            file_counter += 1
            print(f"Processed: {root_file}")

            # Close the full output file; the next event opens a new one
            if file_counter % files_per_hdf5 == 0 and h5f is not None:
                h5f.close()
                h5f = None
                hdf5_counter += 1
    finally:
        # Close the last HDF5 file on every exit path
        if h5f is not None:
            h5f.close()

In [None]:
# Read the data back into Python variables
def read_event_from_hdf5(file, iev):
    """Load a single event back from an HDF5 file into in-memory arrays.

    Args:
        file: Path of the HDF5 file to open (read-only).
        iev: Event index; data are taken from the group ``event_{iev}``.

    Returns:
        tuple: An 11-tuple. First come five dicts mapping feature name to the
        array read from the corresponding dataset, in this order: generator
        particles, tracks, clusters, calorimeter hits, tracker hits. Then come
        six ``(rows, cols, weights)`` array triples, in this order:
        genparticle->calo hit, genparticle->tracker hit, track->tracker hit,
        cluster->cluster hit, genparticle->track, genparticle->genparticle.

    Raises:
        KeyError: If the event group or one of the expected sub-groups is absent.
    """
    feature_groups = (
        "gen_features",
        "track_features",
        "cluster_features",
        "calo_hit_features",
        "tracker_hit_features",
    )
    matrix_groups = (
        "genparticle_to_calo_hit_matrix",
        "genparticle_to_tracker_hit_matrix",
        "track_to_tracker_hit_matrix",
        "cluster_to_cluster_hit_matrix",
        "gp_to_track_matrix",
        "gp_to_gp",
    )
    coo_parts = ("rows", "cols", "weights")

    with h5py.File(file, "r") as h5f:
        event = h5f[f"event_{iev}"]

        # Each feature group holds one dataset per feature; [:] reads it fully into memory.
        features = [
            {name: event[f"{group}/{name}"][:] for name in event[group].keys()}
            for group in feature_groups
        ]

        # Each adjacency matrix is stored in COO form as three parallel datasets.
        matrices = [
            tuple(event[f"{group}/{part}"][:] for part in coo_parts)
            for group in matrix_groups
        ]

    return (*features, *matrices)

In [None]:
# output_file = Path("extracted_features.hdf5")
# save_event_to_hdf5(
#     output_file,
#     iev,
#     gen_features,
#     track_features,
#     cluster_features,
#     calo_hit_features,
#     tracker_hit_features,
#     genparticle_to_calo_hit_matrix,
#     genparticle_to_tracker_hit_matrix,
#     track_to_tracker_hit_matrix,
#     cluster_to_cluster_hit_matrix,
#     gp_to_track_matrix,
#     gp_to_gp
# )

In [None]:
# Read the data back from the HDF5 file
# (
#     gen_features,
#     track_features,
#     cluster_features,
#     calo_hit_features,
#     tracker_hit_features,
#     genparticle_to_calo_hit_matrix,
#     genparticle_to_tracker_hit_matrix,
#     track_to_tracker_hit_matrix,
#     cluster_to_cluster_hit_matrix,
#     gp_to_track_matrix,
#     gp_to_gp,
# ) = read_event_from_hdf5(output_file, iev)

In [None]:
# Alternative input: CaloChallenge photon dataset.
# data_dir = "/mnt/ceph/users/ewulff/particlemind/data/zenodo/calohit_challenge"
# file_path = Path(data_dir) / "dataset_1_photons_1.hdf5"

# HDF5 file with extracted CLD event features, to be inspected with print_h5_contents.
data_dir = "/mnt/ceph/users/ewulff/particlemind/notebooks"
file_path = Path(data_dir).joinpath("extracted_features.hdf5")


def print_h5_contents(file_path):
    """
    Print one line per group and dataset found in an HDF5 file.

    Datasets are reported with their shape and dtype; groups by name only.

    Args:
        file_path (str or Path): Location of the HDF5 file.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
    """
    if not Path(file_path).exists():
        raise FileNotFoundError(f"The file {file_path} does not exist.")

    def describe(name, obj):
        # Called by visititems for every object below the root; returning None
        # (implicitly) lets the traversal continue.
        if isinstance(obj, h5py.Dataset):
            print(f"Dataset: {name}, Shape: {obj.shape}, Type: {obj.dtype}")
        elif isinstance(obj, h5py.Group):
            print(f"Group: {name}")

    with h5py.File(file_path, "r") as handle:
        print("Contents of the HDF5 file:")
        handle.visititems(describe)


# print_h5_contents(file_path)

### LMDB

In [None]:
def dumps(obj):
    """
    Pickle ``obj`` using pickle protocol 5 (requires Python >= 3.8).

    Args:
        obj: Any picklable object.

    Returns:
        bytes: The pickled representation of ``obj``.
    """
    payload = pickle.dumps(obj, protocol=5)
    return payload


def _read_collection_ids(root_dir):
    """Return a ``{collection name: collection ID}`` dict from the ``podio_metadata`` tree of an open ROOT file."""
    metadata = root_dir.get("podio_metadata")
    names = metadata.arrays("events___idTable/m_names")["events___idTable/m_names"][0]
    ids = metadata.arrays("events___idTable/m_collectionIDs")["events___idTable/m_collectionIDs"][0]
    return dict(zip(names, ids))


def _extract_event_record(event_data, iev, collectionIDs):
    """Build the dict of feature arrays and adjacency matrices that is stored for event ``iev``."""
    gen_features = gen_to_features(event_data, iev)
    track_features = track_to_features(event_data, iev)
    cluster_features = cluster_to_features(
        event_data, iev, cluster_features=["position.x", "position.y", "position.z", "energy", "type"]
    )
    calo_hit_features, genparticle_to_calo_hit_matrix, _ = process_calo_hit_data(event_data, iev, collectionIDs)
    tracker_hit_features, genparticle_to_tracker_hit_matrix, _ = process_tracker_hit_data(
        event_data, iev, collectionIDs
    )
    track_to_tracker_hit_matrix, _ = create_track_to_hit_coo_matrix(event_data, iev, collectionIDs)
    cluster_to_cluster_hit_matrix, _ = create_cluster_to_hit_coo_matrix(event_data, iev, collectionIDs)
    gp_to_track_matrix = genparticle_track_adj(event_data, iev)
    gp_to_gp = create_genparticle_to_genparticle_coo_matrix(event_data, iev)

    return {
        "gen_features": gen_features,
        "track_features": track_features,
        "cluster_features": cluster_features,
        "calo_hit_features": calo_hit_features,
        "tracker_hit_features": tracker_hit_features,
        "genparticle_to_calo_hit_matrix": genparticle_to_calo_hit_matrix,
        "genparticle_to_tracker_hit_matrix": genparticle_to_tracker_hit_matrix,
        "track_to_tracker_hit_matrix": track_to_tracker_hit_matrix,
        "cluster_to_cluster_hit_matrix": cluster_to_cluster_hit_matrix,
        "gp_to_track_matrix": gp_to_track_matrix,
        "gp_to_gp": gp_to_gp,
    }


def process_root_files_to_lmdb(input_dir, output, max_root_files=None):
    """
    Process ROOT files and save extracted features into an LMDB database.

    Every ``*.root`` file found recursively under ``input_dir`` is processed in
    sorted path order, so event keys are reproducible between runs. For each file:

    - the podio collection-ID table is read from ``podio_metadata``;
    - each entry of the ``events`` tree becomes one record: a dict of feature
      arrays and COO adjacency matrices, pickled with ``dumps``;
    - the record is stored under the ASCII key ``str(event_index)``.

    Event indices run contiguously across files. Once all files are processed,
    two bookkeeping entries are written: ``b"__keys__"`` (a pickled list of all
    event keys) and ``b"__len__"`` (the pickled event count).

    A file that raises during processing is reported and skipped. Events from
    that file that were already written before the error stay in the database,
    and the file does not count towards ``max_root_files``.

    Args:
        input_dir (str or Path): Directory searched recursively for ROOT files.
        output (str or Path): Path of the LMDB database; ``~`` is expanded. If it
            is an existing directory, the database files are created inside it
            (``subdir=True``); otherwise the path itself is used as the data
            file (``subdir=False``).
        max_root_files (int, optional): Stop after this many successfully
            processed ROOT files. Defaults to None (no limit).

    Returns:
        None
    """
    lmdb_path = os.path.expanduser(output)
    isdir = os.path.isdir(lmdb_path)

    print("Generate LMDB to %s" % lmdb_path)
    # map_size = 2 * 2**40 bytes (2 TiB) upper bound on the database size.
    # NOTE(review): per py-lmdb docs map_async only has an effect with writemap=True — confirm intent.
    db = lmdb.open(lmdb_path, subdir=isdir, map_size=1099511627776 * 2, readonly=False, meminit=False, map_async=True)

    try:
        txn = db.begin(write=True)
        event_counter = 0
        root_counter = 0

        root_files = sorted(Path(input_dir).rglob("*.root"))
        for root_file in tqdm(root_files, desc="Processing ROOT files"):
            if max_root_files is not None and root_counter >= max_root_files:
                print(f"Reached max_root_files limit: {max_root_files}. Stopping processing.")
                break
            try:
                # Context manager so the ROOT file handle is released even on errors.
                with uproot.open(root_file) as fi:
                    collectionIDs = _read_collection_ids(fi)
                    ev = fi["events"]
                    event_data = get_event_data(ev)
                    # num_entries avoids reading a whole branch just to count events.
                    for iev in range(ev.num_entries):
                        record = _extract_event_record(event_data, iev, collectionIDs)
                        txn.put(f"{event_counter}".encode("ascii"), dumps(record))
                        event_counter += 1

                        # Commit every 100 events so one write transaction does not keep growing.
                        if event_counter % 100 == 0:
                            print(f"[{event_counter}] events processed")
                            txn.commit()
                            txn = db.begin(write=True)

                root_counter += 1

            except Exception as e:
                print(f"Error processing {root_file}: {e}")

        # Finish iterating through all events
        txn.commit()
        keys = [f"{k}".encode("ascii") for k in range(event_counter)]
        with db.begin(write=True) as txn:
            txn.put(b"__keys__", dumps(keys))
            txn.put(b"__len__", dumps(len(keys)))

        print("Flushing database ...")
        db.sync()
    finally:
        # Always release the environment, even if an error escapes the loop.
        db.close()

In [None]:
# process_root_files_to_lmdb(
#     input_dir="/mnt/ceph/users/ewulff/data/cld/", output="/mnt/ceph/users/ewulff/data/cld/lmdb", max_root_files=2
# )

In [None]:
# read the lmdb database
def read_full_lmdb_database(lmdb_path):
    """
    Load every event stored in an LMDB database into memory.

    The event keys are read from the pickled ``b"__keys__"`` entry. Each event
    value is then unpickled and placed in the result under its position in that
    key list.

    Args:
        lmdb_path (str or Path): Path of the LMDB database; ``~`` is expanded.
            An existing directory is opened with ``subdir=True``, otherwise
            the path is treated as the data file.

    Returns:
        dict: ``{index: event_dict}`` for ``index`` in ``range(len(keys))``.

    Raises:
        KeyError: If ``__keys__`` or any key it lists is missing from the database.
    """
    lmdb_path = os.path.expanduser(lmdb_path)
    # lock=False: this reader takes no LMDB locks; assumes no concurrent writer — confirm.
    db = lmdb.open(lmdb_path, subdir=os.path.isdir(lmdb_path), readonly=True, lock=False)

    try:
        with db.begin() as txn:
            raw_keys = txn.get(b"__keys__")
            if raw_keys is None:
                raise KeyError(f"LMDB database {lmdb_path} has no '__keys__' entry")
            # NOTE: pickle.loads executes arbitrary code; only open trusted databases.
            keys = pickle.loads(raw_keys)

            data = {}
            for ii, key in enumerate(keys):
                raw = txn.get(key)
                if raw is None:
                    raise KeyError(f"Key {key!r} listed in '__keys__' is missing from {lmdb_path}")
                data[ii] = pickle.loads(raw)
    finally:
        # Release the environment even if reading fails part-way.
        db.close()

    return data


# Load the whole CLD LMDB database into memory as {event index: event dict}.
lmdb_data = read_full_lmdb_database(
    "/mnt/ceph/users/ewulff/data/cld/lmdb"
)

In [None]:
# List the field names stored for each event, taken from the first event.
for k in lmdb_data[0]:
    print(k)