In [None]:
import awkward
import numpy as np
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
import particle

In [None]:
from particle import Particle

In [None]:
# Quick check of the PDG-ID lookup: resolve ID 311 and display the particle's name.
Particle.from_pdgid(311).name

In [None]:
# Read the particle records from Parquet into an awkward array (indexed by event further down).
particles = awkward.from_parquet("../data/particles.parquet")

In [None]:
# Read the calorimeter-hit records from Parquet into an awkward array (indexed by event further down).
calohits = awkward.from_parquet("../data/calohits.parquet")

In [None]:
# Display the first particle entry to see which fields are available.
particles[0]

In [None]:
# Display the first calorimeter-hit entry to see which fields are available.
calohits[0]

In [None]:
import networkx as nx
def gen_to_graph(gen_features):
    """Build a directed parent -> child graph from one event's particle record.

    Every entry of ``gen_features["particle_id"]`` becomes a node. For each
    particle an edge ``parent_id -> particle_id`` is added, but only when the
    parent is itself one of the nodes; parents missing from the record are
    silently ignored.
    """
    particle_ids = gen_features["particle_id"]
    parent_ids = gen_features["parent_id"]

    graph = nx.DiGraph()
    graph.add_nodes_from(particle_ids)
    # All children are nodes already, so only the parent needs to be checked.
    graph.add_edges_from(
        (parent, child)
        for child, parent in zip(particle_ids, parent_ids)
        if parent in graph
    )
    return graph


def color_node(st):
    """Return the node colour for a flag: "blue" when `st` is truthy, "red" otherwise."""
    return "blue" if st else "red"

In [None]:
# Select a single event; the cells below work on this event only.
iev = 2  # index of the event to inspect
gen_features = particles[iev]  # particle record of that event
calo_features = calohits[iev]  # calorimeter-hit record of that event
g = gen_to_graph(gen_features)  # parent -> child graph of the event's particles

In [None]:
# Lookup tables keyed by particle_id for the selected event.

# particle_id -> parent_id
map_to_parent = dict(zip(gen_features["particle_id"], gen_features["parent_id"]))

# particle_id -> value of the "primary" flag
map_to_primary = dict(zip(gen_features["particle_id"], gen_features["primary"]))

# particle_id -> position of that particle inside gen_features
map_to_idx = {pid: position for position, pid in enumerate(gen_features["particle_id"])}

In [None]:
# Flatten the per-hit contribution lists of the selected event into
# COO-style triplets (hit position, contributing particle id, energy).
contrib_particle_ids = calohits[iev]["contrib_particle_ids"]
contrib_energies = calohits[iev]["contrib_energies"]

# Hit position, repeated once for every contribution listed for that hit.
n_contribs_per_hit = awkward.count(contrib_particle_ids, axis=1)
genparticle_to_hit_matrix_coo0 = np.repeat(np.arange(len(contrib_particle_ids)), n_contribs_per_hit)

# Contributing particle id and the matching energy, in the same flattened order.
genparticle_to_hit_matrix_coo1 = awkward.flatten(contrib_particle_ids)
genparticle_to_hit_matrix_w = awkward.flatten(contrib_energies)

In [None]:
# Show the number of graph nodes next to the number of particles in the event
# (presumably expected to be equal - one node per particle).
len(g.nodes), len(gen_features["pdg_id"])

In [None]:
# Draw the particle graph of the selected event and save it to graph.pdf.

# Per-node styling. These lists follow gen_features order, which is also the
# order in which gen_to_graph inserted the nodes (assuming unique particle ids).
node_color = [color_node(st) for st in gen_features["primary"]]
node_size = [np.clip(10 * e, 1, 100) for e in gen_features["energy"]]

# Particles that contribute to at least one calorimeter hit are drawn opaque,
# all others faded. The ids are materialised as a set once so every membership
# test is O(1) instead of a scan over the whole flattened contribution array.
particles_with_hits = set(awkward.to_list(genparticle_to_hit_matrix_coo1))
alpha = [1.0 if n in particles_with_hits else 0.2 for n in g.nodes]

# Label every node with the particle name looked up from its PDG id.
labels = {n: "{}".format(Particle.from_pdgid(pid).name) for n, pid in zip(g.nodes, gen_features["pdg_id"])}
pos = nx.nx_agraph.graphviz_layout(g, prog="circo")
fig = plt.figure(figsize=(20, 20))
nx.draw_networkx_nodes(g, pos, node_color=node_color, node_size=node_size, alpha=alpha)
nx.draw_networkx_edges(
    g,
    pos,
    arrowsize=1,
    width=0.5,
    alpha=0.2,
    node_size=node_size,
)
nx.draw_networkx_labels(g, pos, labels=labels, font_size=2)
plt.savefig("graph.pdf")

In [None]:
def get_hit_labels(hit_idx, gen_idx, weights, max_hits=None):
    """Label each hit with the particle that contributes the largest weight to it.

    Parameters
    ----------
    hit_idx, gen_idx, weights : equal-length sequences
        COO-style triplets: entry ``i`` states that particle ``gen_idx[i]``
        contributes ``weights[i]`` to hit ``hit_idx[i]``.
    max_hits : int, optional
        Length of the returned array. When omitted (or falsy) it is inferred as
        ``max(hit_idx) + 1``, or ``0`` when there are no contributions at all.

    Returns
    -------
    numpy.ndarray of int, shape ``(max_hits,)``
        The winning particle for every hit. Hits without any contribution keep
        the label ``-1``. On equal weights the first contribution seen wins.
    """
    if not max_hits:
        # np.max raises on an empty sequence, so handle "no contributions" explicitly.
        max_hits = int(np.max(hit_idx)) + 1 if len(hit_idx) else 0
    hit_labels = np.full(max_hits, -1, dtype=int)  # -1 means "unassigned"
    best_weight = {}  # hit index -> largest weight seen so far

    for h_idx, g_idx, weight in zip(hit_idx, gen_idx, weights):
        # "First contribution for this hit" is decided by the weight table, not
        # by the -1 sentinel, so the label value itself never affects the choice.
        if h_idx not in best_weight or weight > best_weight[h_idx]:
            hit_labels[h_idx] = g_idx
            best_weight[h_idx] = weight

    return hit_labels

def get_hit_labels_p(hit_idx, gen_idx, weights, map_to_parent, map_to_primary, max_hits=None):
    """Label each hit with the primary ancestor of its dominant contributing particle.

    First the particle with the largest weight is chosen for every hit; then
    the parent chain of that particle is followed upwards until a particle
    flagged as primary is reached.

    Parameters
    ----------
    hit_idx, gen_idx, weights : equal-length sequences
        COO-style triplets: entry ``i`` states that particle ``gen_idx[i]``
        contributes ``weights[i]`` to hit ``hit_idx[i]``.
    map_to_parent : mapping
        particle id -> id of its parent.
    map_to_primary : mapping
        particle id -> truthy when the particle is a primary.
    max_hits : int, optional
        Length of the returned array. When omitted (or falsy) it is inferred as
        ``max(hit_idx) + 1``, or ``0`` when there are no contributions at all.

    Returns
    -------
    numpy.ndarray of int, shape ``(max_hits,)``
        Primary ancestor per hit. Hits without any contribution keep ``-1``.
        If the chain cannot be followed to a primary (unknown particle, missing
        parent, or a cycle), the last particle that was reached is used.
    """
    if not max_hits:
        # np.max raises on an empty sequence, so handle "no contributions" explicitly.
        max_hits = int(np.max(hit_idx)) + 1 if len(hit_idx) else 0
    hit_labels = np.full(max_hits, -1, dtype=int)  # -1 means "unassigned"
    best_weight = {}  # hit index -> largest weight seen so far

    for h_idx, g_idx, weight in zip(hit_idx, gen_idx, weights):
        if h_idx not in best_weight or weight > best_weight[h_idx]:
            hit_labels[h_idx] = g_idx
            best_weight[h_idx] = weight

    # Walk up the particle tree to the first primary particle.
    hit_labels_p = hit_labels.copy()
    for h_idx, g_idx in enumerate(hit_labels.tolist()):
        if g_idx == -1:
            # No contribution recorded for this hit: there is no particle to look up.
            continue
        seen = {g_idx}
        while g_idx in map_to_primary and not map_to_primary[g_idx]:
            parent = map_to_parent.get(g_idx)
            # Stop instead of raising / looping forever on a broken or cyclic chain.
            if parent not in map_to_primary or parent in seen:
                break
            seen.add(parent)
            g_idx = parent
        hit_labels_p[h_idx] = g_idx

    return hit_labels_p

In [None]:
# One label per calorimeter hit of the selected event.
max_hits = len(calo_features["x"])

# Label = particle with the largest contribution to the hit.
hit_labels = get_hit_labels(
    genparticle_to_hit_matrix_coo0,
    genparticle_to_hit_matrix_coo1,
    genparticle_to_hit_matrix_w,
    max_hits=max_hits,
)

# Label = primary ancestor of that particle.
hit_labels_p = get_hit_labels_p(
    genparticle_to_hit_matrix_coo0,
    genparticle_to_hit_matrix_coo1,
    genparticle_to_hit_matrix_w,
    map_to_parent,
    map_to_primary,
    max_hits=max_hits,
)

In [None]:
# Stack the x, y, z coordinate fields and transpose, giving one row per hit
# with the three coordinates as columns.
stacked_xyz = np.stack([calo_features[axis] for axis in ("x", "y", "z")])
calo_hit_positions = np.array(awkward.to_numpy(stacked_xyz)).T

In [None]:
# Number of calorimeter hits in the selected event.
len(calo_features["x"])

In [None]:
import plotly.graph_objects as go
import random

def plot_calo_hits_colored_by_genparticle(
    hit_energies, hit_labels, calo_hit_positions, title="Calorimeter hits colored by genparticle", seed=None
):
    """Show an interactive 3D scatter of calorimeter hits, one random colour per label.

    Parameters
    ----------
    hit_energies : array-like
        Per-hit energies. Currently not used (the marker size is fixed); the
        parameter is kept so existing calls keep working.
    hit_labels : numpy.ndarray
        Label of every hit, aligned with the rows of ``calo_hit_positions``.
        Hits labelled ``-1`` are not drawn.
    calo_hit_positions : numpy.ndarray
        Two-dimensional array with one row per hit; columns 0, 1 and 2 are
        plotted on the X, Y and Z axes.
    title : str
        Figure title.
    seed : int, optional
        When given, colours are drawn from a private generator seeded with this
        value, so the same labels always get the same colours. When omitted the
        module-level ``random`` generator is used.

    Returns
    -------
    None
        The figure is displayed with ``fig.show()``.
    """
    rng = random if seed is None else random.Random(seed)
    unique_ids = np.unique(hit_labels)

    def random_color():
        """Generate a random opaque color in RGBA format."""
        return f"rgba({rng.randint(0, 255)}, {rng.randint(0, 255)}, {rng.randint(0, 255)}, 1)"

    random_color_map = {gen_id: random_color() for gen_id in unique_ids}

    # One trace per label; -1 marks hits without a label and is skipped.
    traces = []
    for gen_id in unique_ids:
        if gen_id == -1:
            continue
        mask = hit_labels == gen_id  # hits belonging to the current label
        traces.append(
            go.Scatter3d(
                x=calo_hit_positions[mask, 0],
                y=calo_hit_positions[mask, 1],
                z=calo_hit_positions[mask, 2],
                mode="markers",
                marker=dict(size=2, color=random_color_map[gen_id]),
                name=f"gp {gen_id}",
            )
        )

    # Axis names and initial camera placement.
    layout = go.Layout(
        scene=dict(
            xaxis=dict(title="X"),
            yaxis=dict(title="Y"),
            zaxis=dict(title="Z"),
            camera=dict(
                up=dict(x=1, y=0, z=0),  # Sets the orientation of the camera
                center=dict(x=0, y=0, z=0),  # Sets the center point of the plot
                eye=dict(x=0, y=0, z=2.1),  # Sets the position of the camera
            ),
        ),
        showlegend=False,
        width=700,
        height=700,
        title=title,
    )

    # Create the figure and display the plot
    fig = go.Figure(data=traces, layout=layout)
    fig.show()

In [None]:
# Display the hits of the selected event, coloured by the primary-level label of each hit.
plot_calo_hits_colored_by_genparticle(calo_features["total_energy"], hit_labels_p, calo_hit_positions)

In [None]:
# For every label in hit_labels_p, compare the summed energy of its hits with
# the energy stored for that particle in gen_features.
ehits = []
egen = []
for h in np.unique(hit_labels_p):
    if h == -1:
        # -1 is the "unassigned" default of the hit-labelling helpers; there is
        # no particle to look up for it in map_to_idx.
        continue
    msk = hit_labels_p == h
    energy_hits = np.sum(calo_features["total_energy"][msk])
    energy_gen = gen_features["energy"][map_to_idx[h]]
    ehits.append(energy_hits)
    egen.append(energy_gen)
ehits = np.array(ehits)
egen = np.array(egen)

In [None]:
# Summed hit energy versus particle energy for every label, on log-log axes.
plt.figure(figsize=(5, 5))
plt.scatter(ehits, egen)

plt.xscale("log")
plt.yscale("log")
plt.xlim(1e-3, 1e3)
plt.ylim(1e-3, 1e3)
# Raw strings: "\s" inside a normal string literal is an invalid escape sequence.
plt.xlabel(r"$\sum E_{hits}$")
plt.ylabel(r"$E_{gen}^{primary}$");