In [1]:
! ls ../data/FCC-CLD/gun/*.root    # TODO: confirm which CLD detector version produced these files

../data/FCC-CLD/gun/out_reco_edm4hep_REC.edm4hep.root
../data/FCC-CLD/gun/out_sim_edm4hep.root


In [2]:
import math

import numpy as np
import uproot
import awkward as ak

import networkx as nx
import matplotlib.pyplot as plt

import plotly.graph_objects as go

import vector

vector.register_awkward()

In [3]:
# Reconstructed (REC) EDM4hep output of the CLD particle-gun sample.
PATH = "../data/FCC-CLD/gun/out_reco_edm4hep_REC.edm4hep.root"

f = uproot.open(PATH)
f.keys()  # top-level objects; "events" is listed with two cycles (;4 and ;3)

['events;4',
 'events;3',
 'configuration_metadata;1',
 'metadata;1',
 'podio_metadata;1']

In [4]:
ev = f["events"]
metadata = f["metadata"]

# Software versions PODIO recorded when this file was written.
for label, key in (
    ("EDM4HEP version", "edm4hep___Version"),
    ("PodioBuild version", "PodioBuildVersion"),
):
    print(f"{label}: {f[f'podio_metadata/{key}'].array()}")
print(f"# of events: {ev.num_entries}")

EDM4HEP version: [{major: 0, minor: 99, patch: 1}]
PodioBuild version: [{major: 1, minor: 1, patch: 0}]
# of events: 100


In [5]:
# Branches of the PODIO metadata tree: collection-ID tables and collection type info per category.
f["podio_metadata"].keys()

['metadata___idTable',
 'metadata___idTable/m_collectionIDs',
 'metadata___idTable/m_names',
 'metadata___CollectionTypeInfo',
 'metadata___CollectionTypeInfo/metadata___CollectionTypeInfo._3',
 'metadata___CollectionTypeInfo/metadata___CollectionTypeInfo._2',
 'metadata___CollectionTypeInfo/metadata___CollectionTypeInfo._1',
 'metadata___CollectionTypeInfo/metadata___CollectionTypeInfo._0',
 'configuration_metadata___idTable',
 'configuration_metadata___idTable/m_collectionIDs',
 'configuration_metadata___idTable/m_names',
 'configuration_metadata___CollectionTypeInfo',
 'configuration_metadata___CollectionTypeInfo/configuration_metadata___CollectionTypeInfo._3',
 'configuration_metadata___CollectionTypeInfo/configuration_metadata___CollectionTypeInfo._2',
 'configuration_metadata___CollectionTypeInfo/configuration_metadata___CollectionTypeInfo._1',
 'configuration_metadata___CollectionTypeInfo/configuration_metadata___CollectionTypeInfo._0',
 'events___idTable',
 'events___idTable/m_

In [6]:
# Flat branches of the MCParticles collection. Parent/daughter links are given as
# begin/end ranges into the separate `_MCParticles_parents` / `_MCParticles_daughters` index branches.
ev["MCParticles"].keys()

['MCParticles.PDG',
 'MCParticles.generatorStatus',
 'MCParticles.simulatorStatus',
 'MCParticles.charge',
 'MCParticles.time',
 'MCParticles.mass',
 'MCParticles.vertex.x',
 'MCParticles.vertex.y',
 'MCParticles.vertex.z',
 'MCParticles.endpoint.x',
 'MCParticles.endpoint.y',
 'MCParticles.endpoint.z',
 'MCParticles.momentum.x',
 'MCParticles.momentum.y',
 'MCParticles.momentum.z',
 'MCParticles.momentumAtEndpoint.x',
 'MCParticles.momentumAtEndpoint.y',
 'MCParticles.momentumAtEndpoint.z',
 'MCParticles.spin.x',
 'MCParticles.spin.y',
 'MCParticles.spin.z',
 'MCParticles.colorFlow.a',
 'MCParticles.colorFlow.b',
 'MCParticles.parents_begin',
 'MCParticles.parents_end',
 'MCParticles.daughters_begin',
 'MCParticles.daughters_end']

# Event-by-event inspection of MC particles

In [8]:
# Event index inspected throughout the rest of the notebook.
# (awkward, numpy and vector are imported, and vector registered with awkward,
# in the imports cell at the top; re-importing here only hides that dependency.)
iev = 0

def build_relation_lists_from_offsets(
    begin,
    end,
    relation_index,
    n_items,
    name="relation",
):
    """
    Build per-item relation lists from EDM-style begin/end offsets.

    Item ``i`` is related to ``relation_index[begin[i]:end[i]]``. The ranges do
    not have to be contiguous or ordered; each one is gathered independently.

    Parameters
    ----------
    begin : array-like (n_items)
        Begin offsets into the flat relation index buffer.
    end : array-like (n_items)
        End offsets (exclusive) into the flat relation index buffer.
    relation_index : array-like (n_links)
        Flat buffer containing indices of related items.
    n_items : int
        Number of items the relations refer to (e.g. number of MCParticles).
    name : str
        Name of the relation (used for error messages).

    Returns
    -------
    relations : awkward.Array
        Jagged array of length n_items, where relations[i] is
        a list of indices related to item i.

    Raises
    ------
    ValueError
        If begin/end differ in shape or do not have length ``n_items``; if any
        begin offset is negative or any end < begin; if a non-empty range runs
        past the end of ``relation_index``; or if a related index is outside
        ``[0, n_items)``.
    """
    # Widen to signed 64-bit so that `end - begin` and the range checks below
    # cannot wrap around if the offsets arrive as unsigned integers.
    begin = np.asarray(begin).astype(np.int64)
    end = np.asarray(end).astype(np.int64)
    index_buffer = np.asarray(relation_index)

    if begin.shape != end.shape:
        raise ValueError(
            f"{name}: begin/end shape mismatch: {begin.shape} vs {end.shape}"
        )

    if len(begin) != n_items:
        raise ValueError(
            f"{name}: expected begin/end of length {n_items}, got {len(begin)}"
        )

    if np.any(begin < 0):
        raise ValueError(f"{name}: found negative begin offset (corrupt offsets)")

    if np.any(end < begin):
        raise ValueError(f"{name}: found end < begin (corrupt offsets)")

    # Number of relations per item
    counts = end - begin
    n_links = int(counts.sum())

    # Plain slicing would silently truncate a range that runs past the buffer,
    # so check explicitly. Empty ranges are allowed to point anywhere.
    nonempty = counts > 0
    if np.any(nonempty):
        max_end = int(end[nonempty].max())
        if max_end > len(index_buffer):
            raise ValueError(
                f"{name}: inconsistent offsets: "
                f"max(end)={max_end} exceeds relation buffer len={len(index_buffer)}"
            )

    # Gather every range in one vectorized step. Link k belongs to some item j;
    # its position in the buffer is begin[j] plus k's offset within item j's range.
    item_start_in_flat = np.cumsum(counts) - counts
    positions = np.repeat(begin, counts) + (
        np.arange(n_links) - np.repeat(item_start_in_flat, counts)
    )
    flat = index_buffer[positions]

    relations = ak.unflatten(flat, counts)

    # Final safety: indices must point to valid items
    if flat.size > 0:
        min_idx = int(flat.min())
        max_idx = int(flat.max())
        if min_idx < 0 or max_idx >= n_items:
            raise ValueError(
                f"{name}: relation index out of range "
                f"(min={min_idx}, max={max_idx}, n_items={n_items})"
            )

    return relations

def get_mc_particles(ev, iev):
    """
    Build the MC particle record for a single event.

    Parameters
    ----------
    ev : uproot TTree
        The PODIO ``events`` tree.
    iev : int
        Event index; negative values count from the end, as with list indexing.

    Returns
    -------
    mc : awkward.Array
        One record per MC particle with the raw EDM4hep fields, plus
          - mc['parents']   : jagged list of parent indices per particle
          - mc['daughters'] : jagged list of daughter indices per particle
          - derived pt / eta / phi / energy (computed with ``vector``).

    Raises
    ------
    IndexError
        If ``iev`` is outside the range of events stored in ``ev``.
    """
    n_events = ev.num_entries
    entry = int(iev)
    if entry < 0:
        # negative indices count from the end, as with list/awkward indexing
        entry += n_events
    if not 0 <= entry < n_events:
        raise IndexError(f"event index {iev} out of range for {n_events} events")

    def read(branch):
        # Read only this event's entry: `ev[branch].array()[iev]` would read
        # the branch for every event in the file just to keep one of them.
        return ev[branch].array(entry_start=entry, entry_stop=entry + 1)[0]

    # --- base per-particle fields (for ONE event) ---
    mc = ak.zip({
        "pdg": read("MCParticles.PDG"),
        "genstatus": read("MCParticles.generatorStatus"),
        "simstatus": read("MCParticles.simulatorStatus"),
        "charge": read("MCParticles.charge"),
        "time": read("MCParticles.time"),

        "px": read("MCParticles.momentum.x"),
        "py": read("MCParticles.momentum.y"),
        "pz": read("MCParticles.momentum.z"),
        "mass": read("MCParticles.mass"),

        "vx": read("MCParticles.vertex.x"),
        "vy": read("MCParticles.vertex.y"),
        "vz": read("MCParticles.vertex.z"),

        "endx": read("MCParticles.endpoint.x"),
        "endy": read("MCParticles.endpoint.y"),
        "endz": read("MCParticles.endpoint.z"),

        "endpx": read("MCParticles.momentumAtEndpoint.x"),
        "endpy": read("MCParticles.momentumAtEndpoint.y"),
        "endpz": read("MCParticles.momentumAtEndpoint.z"),

        "spinx": read("MCParticles.spin.x"),
        "spiny": read("MCParticles.spin.y"),
        "spinz": read("MCParticles.spin.z"),

        "color_a": read("MCParticles.colorFlow.a"),
        "color_b": read("MCParticles.colorFlow.b"),
    })

    nMC = len(mc)

    # --- relation pointers (ONE event): begin/end ranges into the index buffers ---
    pb = read("MCParticles.parents_begin")
    pe = read("MCParticles.parents_end")
    db = read("MCParticles.daughters_begin")
    de = read("MCParticles.daughters_end")

    parents_index = read("_MCParticles_parents/_MCParticles_parents.index")
    daughters_index = read("_MCParticles_daughters/_MCParticles_daughters.index")

    parents = build_relation_lists_from_offsets(
        pb, pe, parents_index, nMC, name="parents"
    )

    daughters = build_relation_lists_from_offsets(
        db, de, daughters_index, nMC, name="daughters"
    )

    # Attach (this avoids ak.zip broadcasting issues)
    mc = ak.with_field(mc, parents, "parents")
    mc = ak.with_field(mc, daughters, "daughters")

    # --- derived kinematics via vector (requires vector.register_awkward()) ---
    p4 = ak.zip({"px": mc.px, "py": mc.py, "pz": mc.pz, "mass": mc.mass}, with_name="Momentum4D")
    mc = ak.with_field(mc, p4.pt, "pt")
    mc = ak.with_field(mc, p4.eta, "eta")
    mc = ak.with_field(mc, p4.phi, "phi")
    mc = ak.with_field(mc, p4.energy, "energy")

    return mc

mc = get_mc_particles(ev, iev)

# `mc` holds one record per MC particle, so its length is the particle count.
nMC = len(mc)
nMC

50

In [9]:
# One summary line per MC particle (the PID is printed without its sign).
mc_genstatus, mc_pdg, mc_pt, mc_simstatus = (
    mc["genstatus"], mc["pdg"], mc["pt"], mc["simstatus"]
)
for iMC in range(nMC):
    print(
        f"{iMC}, GenStatus: {mc_genstatus[iMC]}, PID: {abs(mc_pdg[iMC])}, "
        f"pt: {mc_pt[iMC].round(3)} - (SimStatus: {mc_simstatus[iMC]})"
    )

0, GenStatus: 3, PID: 11, pt: 1.875 - (SimStatus: 0)
1, GenStatus: 3, PID: 11, pt: 1.875 - (SimStatus: 0)
2, GenStatus: 1, PID: 321, pt: 0.713 - (SimStatus: 83886080)
3, GenStatus: 1, PID: 11, pt: 0.59 - (SimStatus: 83886080)
4, GenStatus: 1, PID: 11, pt: 24.444 - (SimStatus: 83886080)
5, GenStatus: 1, PID: 22, pt: 2.34 - (SimStatus: 83886080)
6, GenStatus: 1, PID: 11, pt: 2.058 - (SimStatus: 150994944)
7, GenStatus: 1, PID: 13, pt: 17.762 - (SimStatus: 33554432)
8, GenStatus: 1, PID: 13, pt: 14.862 - (SimStatus: 33554432)
9, GenStatus: 1, PID: 11, pt: 0.902 - (SimStatus: 83886080)
10, GenStatus: 1, PID: 2212, pt: 28.706 - (SimStatus: 83886080)
11, GenStatus: 1, PID: 22, pt: 11.092 - (SimStatus: 83886080)
12, GenStatus: 1, PID: 321, pt: 0.748 - (SimStatus: 83886080)
13, GenStatus: 0, PID: 11, pt: 0.001 - (SimStatus: 1493172224)
14, GenStatus: 0, PID: 11, pt: 0.002 - (SimStatus: 1493172224)
15, GenStatus: 0, PID: 11, pt: 0.002 - (SimStatus: 1426063360)
16, GenStatus: 0, PID: 11, pt: 0.0

In [10]:
# Parent and daughter indices of every MC particle in the event.
mc_parents, mc_daughters = mc["parents"], mc["daughters"]
for iMC in range(nMC):
    print(f"{iMC}, parents: {mc_parents[iMC]}, daughters: {mc_daughters[iMC]}")

0, parents: [], daughters: [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
1, parents: [], daughters: [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
2, parents: [0, 1], daughters: [46, 47, 48, 49]
3, parents: [0, 1], daughters: [43, 44, 45]
4, parents: [0, 1], daughters: [40, 41, 42]
5, parents: [0, 1], daughters: []
6, parents: [0, 1], daughters: [28, 29, 30, 31, 32, 33, 34, 35, 36]
7, parents: [0, 1], daughters: [27]
8, parents: [0, 1], daughters: []
9, parents: [0, 1], daughters: [24, 25, 26]
10, parents: [0, 1], daughters: [21, 22]
11, parents: [0, 1], daughters: [16, 17, 18, 19, 20]
12, parents: [0, 1], daughters: [13, 14, 15]
13, parents: [12], daughters: []
14, parents: [12], daughters: []
15, parents: [12], daughters: []
16, parents: [11], daughters: []
17, parents: [11], daughters: []
18, parents: [11], daughters: []
19, parents: [11], daughters: []
20, parents: [11], daughters: []
21, parents: [10], daughters: [23]
22, parents: [10], daughters: []
23, parents: [21], daughters: []
24, parents: [9]

# Visualize the MC particle graph and the 3D event display

In [18]:
def mc_to_graph(mc):
    """
    Build a directed parent -> daughter graph for one event.

    Parameters
    ----------
    mc : awkward.Array
        Single-event MC record array; ``mc['daughters'][i]`` is the list of
        daughter indices of particle ``i``.

    Returns
    -------
    networkx.DiGraph
        Nodes are the particle indices ``0 .. len(mc)-1`` (every particle is a
        node, even if it has no links); an edge ``p -> c`` means ``c`` is a
        daughter of ``p``. Self-references are ignored.
    """
    g = nx.DiGraph()
    g.add_nodes_from(range(len(mc)))

    # Convert the jagged daughters array to Python lists once; indexing the
    # awkward array element by element inside the loop is much slower.
    daughters = ak.to_list(mc["daughters"])

    g.add_edges_from(
        (parent, int(child))
        for parent, children in enumerate(daughters)
        for child in children
        if int(child) != parent
    )

    return g

def color_node(genstatus):
    """Return the marker colour for a generator status (0 red, 1 blue, 2 green, else gray)."""
    palette = ((0, "red"), (1, "blue"), (2, "green"))
    for status, colour in palette:
        if genstatus == status:
            return colour
    return "gray"

In [168]:
def _prune_mc_graph_by_energy(g, energy, min_energy):
    """
    Return the subgraph reachable from the roots through particles with energy >= min_energy.

    A node that fails the cut is dropped together with its whole subtree, and a
    daughter is only traversed once all of its parents have been kept.
    """
    roots = [n for n in g.nodes() if g.in_degree(n) == 0]
    keep = set()

    stack = list(roots)
    while stack:
        u = stack.pop()

        # only keep/traverse if this node passes the cut
        if energy[u] < min_energy:
            continue

        keep.add(u)
        for v in g.successors(u):
            # require that ALL parents of v are kept, so a daughter with a
            # rejected (or not yet accepted) parent is not pushed from here
            if all(p in keep for p in g.predecessors(v)):
                stack.append(v)

    return g.subgraph(keep).copy()


def _delta_phi(phi1, phi2):
    """Azimuthal difference phi1 - phi2 wrapped into [-pi, pi)."""
    return (phi1 - phi2 + np.pi) % (2 * np.pi) - np.pi


def _mc_node_hovertext(i, cols, parents):
    """HTML hover text for MC particle i: kinematics plus ΔR to each of its parents."""
    s = (
        f"idx: {i}<br>"
        f"PDG: {int(cols['pdg'][i])}<br>"
        f"genstatus: {int(cols['genstatus'][i])}<br>"
        # charge may be fractional (e.g. quarks); int() would truncate it to 0
        f"charge: {float(cols['charge'][i]):.3g}<br>"
        f"E: {float(cols['energy'][i]):.4g}<br>"
        f"pT: {float(cols['pt'][i]):.4g}<br>"
        f"eta: {float(cols['eta'][i]):.4g}<br>"
        f"phi: {float(cols['phi'][i]):.4g}"
    )

    if len(parents[i]) == 0:
        return s + "<br>ΔR(parent): —"

    eta_i = float(cols["eta"][i])
    phi_i = float(cols["phi"][i])

    drs = []
    lines = []
    for p in parents[i]:
        p = int(p)
        d_eta = eta_i - float(cols["eta"][p])
        d_phi = _delta_phi(phi_i, float(cols["phi"][p]))
        dr = float(np.hypot(d_eta, d_phi))
        drs.append(dr)
        lines.append(f"p={p} (PDG {int(cols['pdg'][p])}): ΔR={dr:.3g}")

    s += f"<br>ΔR(parent) min: {min(drs):.3g}"
    return s + "<br>" + "<br>".join(lines)


def draw_mc_graph_interactive(mc, min_energy=None):
    """
    Draw the MC parent -> daughter graph of one event as an interactive Plotly figure.

    Nodes are laid out hierarchically with Graphviz ``dot``. Marker size grows
    with pT, colour encodes generator status (see ``color_node``), the label is
    the PDG code, and hovering shows kinematics plus ΔR to each parent.

    Parameters
    ----------
    mc : awkward.Array
        Single-event MC record array (as built by ``get_mc_particles``).
    min_energy : float, optional
        If given, drop every particle with energy below this value together
        with its whole subtree (a daughter is kept only if all its parents are).
    """
    g = mc_to_graph(mc)

    # --- prune graph: if a node fails cut, drop it and its whole subtree ---
    if min_energy is not None:
        energy = np.array(mc["energy"], dtype=float)
        g = _prune_mc_graph_by_energy(g, energy, min_energy)

    # Layout: DOT (hierarchical), via networkx's pygraphviz interface
    pos = nx.nx_agraph.graphviz_layout(g, prog="dot")

    # --- Edge trace: all edges as one polyline, separated by None gaps ---
    edge_x = []
    edge_y = []
    for u, v in g.edges():
        (x0, y0), (x1, y1) = pos[u], pos[v]
        edge_x += [x0, x1, None]
        edge_y += [y0, y1, None]

    edge_trace = go.Scatter(
        x=edge_x, y=edge_y,
        mode="lines",
        line=dict(width=1),
        hoverinfo="skip",
        opacity=0.25,
    )

    # --- Node positions ---
    nodes = list(g.nodes())
    x = np.array([pos[n][0] for n in nodes], dtype=float)
    y = np.array([pos[n][1] for n in nodes], dtype=float)

    # Pull the needed columns out of awkward once; per-element awkward
    # indexing inside the loops below is slow.
    cols = {
        key: np.asarray(mc[key])
        for key in ("pdg", "genstatus", "charge", "energy", "pt", "eta", "phi")
    }
    parents = ak.to_list(mc["parents"])

    # Per-node attributes must be indexed by node LABEL: after pruning the
    # labels are a subset of particle indices, not 0..len(nodes)-1.
    node_pt = cols["pt"].astype(float)[nodes]
    node_size = np.clip(30 * (node_pt + 1e-3) ** 0.6, 8, 70)  # plotly uses marker size (not area)

    node_color = [color_node(int(cols["genstatus"][i])) for i in nodes]
    hovertext = [_mc_node_hovertext(i, cols, parents) for i in nodes]

    # PDG label shown on plot
    text = [f"{int(cols['pdg'][i])}" for i in nodes]

    node_trace = go.Scatter(
        x=x, y=y,
        mode="markers+text",
        text=text,
        textposition="middle center",
        hovertext=hovertext,
        hoverinfo="text",
        textfont=dict(
            size=20,
            color="black",
            family="Arial",
        ),
        marker=dict(
            size=node_size,
            color=node_color,
            line=dict(width=1, color="white"),
            opacity=0.65,
        ),
    )

    fig = go.Figure(data=[edge_trace, node_trace])
    fig.update_layout(
        title="MC graph (hover for details)",
        showlegend=False,
        hovermode="closest",
        xaxis=dict(visible=False),
        yaxis=dict(visible=False),
        margin=dict(l=20, r=20, t=40, b=20),
    )
    fig.show()

In [169]:
import plotly.io as pio

# Render Plotly figures in a separate browser tab
# (set pio.renderers.default = "iframe" to embed them in the notebook instead).
pio.renderers.default = "browser"

draw_mc_graph_interactive(mc, min_energy=0.5)

In [136]:
# import numpy as np
# import plotly.graph_objects as go

# def plot_mc_xyz_3d(mc, which="vertex", min_energy=None, min_pt=None):
#     """
#     Plot MC particles as points in 3D (x,y,z).

#     which: "vertex" or "endpoint"
#     min_energy/min_pt: optional cuts
#     """

#     if which == "vertex":
#         x = np.array(mc["vx"], dtype=float)
#         y = np.array(mc["vy"], dtype=float)
#         z = np.array(mc["vz"], dtype=float)
#         title = "MC particles: production vertex (x,y,z)"
#     elif which == "endpoint":
#         x = np.array(mc["endx"], dtype=float)
#         y = np.array(mc["endy"], dtype=float)
#         z = np.array(mc["endz"], dtype=float)
#         title = "MC particles: endpoint (x,y,z)"
#     else:
#         raise ValueError("which must be 'vertex' or 'endpoint'")

#     pt = np.array(mc["pt"], dtype=float)
#     E  = np.array(mc["energy"], dtype=float)
#     pdg = np.array(mc["pdg"], dtype=int)
#     gen = np.array(mc["genstatus"], dtype=int)

#     # mask
#     mask = np.isfinite(x) & np.isfinite(y) & np.isfinite(z)
#     if min_energy is not None:
#         mask &= (E >= float(min_energy))
#     if min_pt is not None:
#         mask &= (pt >= float(min_pt))

#     idx = np.where(mask)[0]
#     if len(idx) == 0:
#         print("No particles pass the selection.")
#         return

#     # marker size ~ pT (compressed)
#     size = np.clip(4 + 10 * (pt[idx] + 1e-3) ** 0.45, 4, 18)

#     # color by genstatus (simple & quick)
#     # plotly wants numeric colors if using colorscale; we can just pass genstatus as color
#     color = gen[idx]

#     hovertext = [
#         f"idx: {i}"
#         f"<br>PDG: {int(pdg[i])}"
#         f"<br>genstatus: {int(gen[i])}"
#         f"<br>pT: {float(pt[i]):.4g}"
#         f"<br>E: {float(E[i]):.4g}"
#         f"<br>x,y,z: ({float(x[i]):.4g}, {float(y[i]):.4g}, {float(z[i]):.4g})"
#         for i in idx
#     ]

#     fig = go.Figure(
#         data=go.Scatter3d(
#             x=x[idx], y=y[idx], z=z[idx],
#             mode="markers",
#             marker=dict(
#                 size=size,
#                 color=color,
#                 opacity=0.75,
#             ),
#             hovertext=hovertext,
#             hoverinfo="text",
#         )
#     )

#     fig.update_layout(
#         title=title,
#         scene=dict(
#             xaxis_title="x",
#             yaxis_title="y",
#             zaxis_title="z",
#             camera=dict(
#                 eye=dict(x=0.0, y=0.0, z=2.5),
#                 up=dict(x=0.0, y=1.0, z=0.0),
#             ),
#             aspectmode="data",  # preserve geometry
#         ),
#         margin=dict(l=0, r=0, t=40, b=0),
#         showlegend=False,       
#     )

#     fig.show()
# plot_mc_xyz_3d(mc, which="vertex", min_energy=0.5)

In [137]:
# def plot_calo_hits_3d(ev, iev, clip_mm=8000, log_size=True):
#     # Pick the collections you want
#     calo_cols = {
#         "ECALBarrel": "blue",
#         "ECALEndcap": "blue",
#         "HCALBarrel": "orange",
#         "HCALEndcap": "orange",
#         "HCALOther": "orange",
#         "MUON": "purple",
#     }

#     traces = []

#     for name, color in calo_cols.items():
#         if name not in ev.keys():
#             continue

#         # EDM4hep calorimeter hits usually have fields like:
#         # position.x, position.y, position.z, energy
#         # (Awkward record arrays)
#         x = np.asarray(ev[f"{name}.position.x"].array()[iev])
#         y = np.asarray(ev[f"{name}.position.y"].array()[iev])
#         z = np.asarray(ev[f"{name}.position.z"].array()[iev])
#         e = np.asarray(ev[f"{name}.energy"].array()[iev])
        
#         # basic cleaning
#         m = np.isfinite(x) & np.isfinite(y) & np.isfinite(z) & np.isfinite(e) & (e > 0)
#         x, y, z, e = x[m], y[m], z[m], e[m]

#         if len(e) == 0:
#             continue

#         # marker size ~ energy (log or linear)
#         if log_size:
#             size = np.clip(2 + 4 * np.log10(e / np.max(e) + 1e-6), 1, 10)
#         else:
#             size = np.clip(1 + 30 * (e / np.max(e)), 1, 12)

#         size += 5
        
#         traces.append(
#             go.Scatter3d(
#                 x=np.clip(x, -clip_mm, clip_mm),
#                 y=np.clip(y, -clip_mm, clip_mm),
#                 z=np.clip(z, -clip_mm, clip_mm),
#                 mode="markers",
#                 marker=dict(size=size, color=color, opacity=1),
#                 name=name,
#                 hovertext=[f"{name}<br>E={val:.3g}" for val in e],
#                 hoverinfo="text",
#             )
#         )

#     fig = go.Figure(data=traces)
#     fig.update_layout(
#         title=f"Calorimeter hits (event {iev})",
#         scene=dict(
#             xaxis_title="x",
#             yaxis_title="y",
#             zaxis_title="z",
#             aspectmode="data",
#             camera=dict(
#                 eye=dict(x=0.0, y=0.0, z=2.5),
#                 up=dict(x=0.0, y=1.0, z=0.0),
#             ),
#         ),
#         margin=dict(l=0, r=0, t=40, b=0),
#         legend=dict(itemsizing="constant"),
#     )
#     fig.show()
    
# plot_calo_hits_3d(ev, iev, clip_mm=1000000)    

In [173]:
import numpy as np
import plotly.graph_objects as go

def plot_mc_xyz_3d(
    mc,
    which="vertex",
    min_MCenergy=None,
    min_MCpt=None,
    msk_MCpid=None,
    msk_GenStatus=None,
    ev=None,
    iev=None,
    plot_hits=False,
    clip_mm=8000,
    log_hit_size=True,
):
    """
    Plot MC particles in 3D (x,y,z). Optionally overlay calorimeter hits.

    Parameters
    ----------
    mc : awkward.Array
        Single-event MC record array (as built by ``get_mc_particles``).
    which : {"vertex", "endpoint"}
        Plot each particle at its production vertex or at its endpoint.
    min_MCenergy, min_MCpt : float, optional
        Keep only MC particles with energy / pT at or above these values.
    msk_MCpid : int or iterable of int, optional
        Keep only MC particles whose (signed) PDG code is one of these values.
    msk_GenStatus : int, optional
        Keep only MC particles with this generator status.
    ev, iev : uproot TTree, int
        Events tree and event index; required if ``plot_hits=True``.
    plot_hits : bool
        Overlay ECAL/HCAL/MUON hits on same display.
    clip_mm : float
        Hit coordinates are clipped to [-clip_mm, clip_mm].
    log_hit_size : bool
        Use log scaling for hit marker sizes.

    Returns
    -------
    None
        Shows the figure; prints a message and returns early if no MC
        particle passes the selection.

    Raises
    ------
    ValueError
        If ``which`` is invalid, or ``plot_hits`` is set without ``ev``/``iev``.
    """

    if which == "vertex":
        x = np.array(mc["vx"], dtype=float)
        y = np.array(mc["vy"], dtype=float)
        z = np.array(mc["vz"], dtype=float)
        title = "MC particles: production vertex (x,y,z)"
    elif which == "endpoint":
        x = np.array(mc["endx"], dtype=float)
        y = np.array(mc["endy"], dtype=float)
        z = np.array(mc["endz"], dtype=float)
        title = "MC particles: endpoint (x,y,z)"
    else:
        raise ValueError("which must be 'vertex' or 'endpoint'")

    pt  = np.array(mc["pt"], dtype=float)
    E   = np.array(mc["energy"], dtype=float)
    pdg = np.array(mc["pdg"], dtype=int)
    gen = np.array(mc["genstatus"], dtype=int)

    # mask MC
    mask = np.isfinite(x) & np.isfinite(y) & np.isfinite(z)
    if min_MCenergy is not None:
        mask &= (E >= float(min_MCenergy))
    if min_MCpt is not None:
        mask &= (pt >= float(min_MCpt))
    if msk_MCpid is not None:
        # a scalar behaves like `pdg == msk_MCpid`; an iterable selects any of its codes
        mask &= np.isin(pdg, np.atleast_1d(msk_MCpid))
    if msk_GenStatus is not None:
        mask &= (gen == int(msk_GenStatus))

    idx = np.where(mask)[0]
    if len(idx) == 0:
        print("No particles pass the selection.")
        return

    # MC marker size ~ pT (compressed)
    mc_size = np.clip(4 + 10 * (pt[idx] + 1e-3) ** 0.45, 4, 18)

    mc_hovertext = [
        f"idx: {i}"
        f"<br>PDG: {int(pdg[i])}"
        f"<br>genstatus: {int(gen[i])}"
        f"<br>pT: {float(pt[i]):.4g}"
        f"<br>E: {float(E[i]):.4g}"
        f"<br>x,y,z: ({float(x[i]):.4g}, {float(y[i]):.4g}, {float(z[i]):.4g})"
        for i in idx
    ]

    traces = []

    # --- optional hits overlay ---
    if plot_hits:
        if ev is None or iev is None:
            raise ValueError("To plot hits, pass ev=<events> and iev=<event index>.")
        calo_cols = {
            "ECALBarrel": "blue",
            "ECALEndcap": "blue",
            "HCALBarrel": "orange",
            "HCALEndcap": "orange",
            "HCALOther": "orange",
            "MUON": "purple",
        }

        # Read only the requested entry (not every event) for each hit branch;
        # negative iev counts from the end, as with list indexing.
        entry = int(iev)
        if entry < 0:
            entry += ev.num_entries

        def read_hits(branch):
            return np.asarray(ev[branch].array(entry_start=entry, entry_stop=entry + 1)[0])

        available = set(ev.keys())
        for name, color in calo_cols.items():
            if name not in available:
                continue

            hx = read_hits(f"{name}.position.x")
            hy = read_hits(f"{name}.position.y")
            hz = read_hits(f"{name}.position.z")
            he = read_hits(f"{name}.energy")

            m = np.isfinite(hx) & np.isfinite(hy) & np.isfinite(hz) & np.isfinite(he) & (he > 0.0)
            hx, hy, hz, he = hx[m], hy[m], hz[m], he[m]

            if len(he) == 0:
                continue

            if log_hit_size:
                hsize = np.clip(2 + 4 * np.log10(he / np.max(he) + 1e-6), 1, 10)
            else:
                hsize = np.clip(1 + 30 * (he / np.max(he)), 1, 12)

            hsize = hsize + 5  # constant offset so small hits stay visible

            traces.append(
                go.Scatter3d(
                    x=np.clip(hx, -clip_mm, clip_mm),
                    y=np.clip(hy, -clip_mm, clip_mm),
                    z=np.clip(hz, -clip_mm, clip_mm),
                    mode="markers",
                    marker=dict(size=hsize, color=color, opacity=1.0),
                    name=name,
                    hovertext=[f"{name}<br>E={val:.3g}" for val in he],
                    hoverinfo="text",
                )
            )

        title = title + f" + calo hits (event {iev})"

    # --- MC particles as X markers (distinguishable from overlapping hits) ---
    traces.append(
        go.Scatter3d(
            x=x[idx], y=y[idx], z=z[idx],
            mode="markers",
            marker=dict(
                size=mc_size*0.8,
                color="red",
                opacity=0.8,
                symbol="x",
                line=dict(width=2),   # makes the X readable
            ),
            name="MC particles",
            hovertext=mc_hovertext,
            hoverinfo="text",
        )
    )

    fig = go.Figure(data=traces)

    fig.update_layout(
        title=title,
        scene=dict(
            xaxis_title="x",
            yaxis_title="y",
            zaxis_title="z",
            aspectmode="data",
            camera=dict(
                eye=dict(x=0.0, y=0.0, z=2.5),
                up=dict(x=0.0, y=1.0, z=0.0),
            ),
        ),
        margin=dict(l=0, r=0, t=40, b=0),
        showlegend=True,
        legend=dict(itemsizing="constant"),
    )

    fig.show()

In [174]:
plot_mc_xyz_3d(
    mc,
    which="endpoint",
    ev=ev,
    iev=iev,
    plot_hits=True,
    msk_GenStatus=1,
    min_MCenergy=0.5,
    # msk_MCpid=13,
)

In [175]:
# All MC particle endpoints (no selection) with the calorimeter hits overlaid.
plot_mc_xyz_3d(
    mc,
    which="endpoint",
    plot_hits=True,
    ev=ev,
    iev=iev,
)