In [2]:
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors
import numpy as np
import uproot
import logging

import acts

from itertools import cycle

import awkward as ak

from gnn4itk_tools.detector_plotter import DetectorPlotter



In [3]:
# Truth-matching table for the GNN+CKF run, restricted to event 0.
match_df = (
    pd.read_csv("../tmp/no_threshold_2/performance_gnn_plus_ckf.csv", dtype={"particle_id": np.uint64})
    .loc[lambda frame: frame["event"] == 0]
    .copy()
)

In [4]:
# Initial particles of event 0 that also appear in the matching table.
particles_tree = uproot.open("../tmp/simdata/particles_initial.root:particles")
particles = ak.to_dataframe(particles_tree.arrays(), how="inner").reset_index(drop=True)
keep = (particles["event_id"] == 0) & particles["particle_id"].isin(match_df["particle_id"])
particles = particles.loc[keep].copy()
particles.shape

(887, 22)

In [5]:
# Copy each particle's "matched" flag over from the matching table.
matched_by_particle = dict(zip(match_df.particle_id, match_df.matched))
particles["matched"] = particles["particle_id"].map(matched_by_particle)
# particles was filtered to ids present in match_df, so no flag may be missing.
assert particles["matched"].notna().all()
particles.head(2)

Unnamed: 0,event_id,particle_id,particle_type,process,vx,vy,vz,vt,px,py,...,eta,phi,pt,p,vertex_primary,vertex_secondary,particle,generation,sub_particle,matched
2,0,4503599677702144,-211,0,-0.006878,0.006443,26.299278,0.0,1.618853,-0.872867,...,-1.899954,-0.494505,1.83918,6.28555,1,0,3,0,0,1
3,0,4503599694479360,321,0,-0.006878,0.006443,26.299278,0.0,-0.340527,-1.401685,...,-3.243134,-1.80912,1.442456,18.5016,1,0,4,0,0,1


In [6]:
# Simulated hits of event 0 with the tt < 25.0 time cut applied.
hits = uproot.open("../tmp/simdata/hits.root:hits").arrays(library="pd")
hits = hits[ (hits.event_id == 0) & (hits.tt < 25.0) ].copy()
# NOTE(review): hit_id is the row position *after* the time cut. It only lines up with
# the hit_id column of the measurement-simhit map if the digitization input was cut
# and re-indexed the same way -- TODO confirm against the reconstruction config.
hits["hit_id"] = np.arange(len(hits))

simhit_map = pd.read_csv("../tmp/no_threshold_2/digi/event000000000-measurement-simhit-map.csv")

# Fail loudly (instead of with a bare KeyError below) if the map points at hits that
# are not in `hits`, e.g. because the time cut / renumbering above does not match.
unknown_hits = ~simhit_map.hit_id.isin(hits.hit_id)
assert not unknown_hits.any(), f"{int(unknown_hits.sum())} simhit-map rows reference hit_ids missing from `hits`"

# If a measurement is listed with several simhits, dict() keeps the last one.
measId_to_hitID = dict(zip(simhit_map.measurement_id, simhit_map.hit_id))
hitId_to_particleId = dict(zip(hits.hit_id, hits.particle_id))
measId_to_particleId = { m: hitId_to_particleId[ measId_to_hitID[ m ] ] for m in simhit_map.measurement_id }

def process_prototracks(tracks, hits_df=None, meas_to_hit=None):
    """Attach simhit truth information to a prototrack table.

    Parameters
    ----------
    tracks : pd.DataFrame
        Prototrack rows with a ``measurementId`` column. Not modified.
    hits_df : pd.DataFrame, optional
        Hit table with ``hit_id``, ``tx``, ``ty``, ``tz``, ``geometry_id`` and
        ``particle_id`` columns. Defaults to the notebook-level ``hits``.
    meas_to_hit : Mapping, optional
        Measurement id -> hit id. Defaults to the notebook-level ``measId_to_hitID``.

    Returns
    -------
    pd.DataFrame
        A copy of ``tracks`` with ``hit_id``, ``tx``, ``ty``, ``tz``,
        ``geometry_id`` and ``particle_id`` columns added. Rows whose
        measurement or hit is unknown get NaN in the added columns.
    """
    # Resolve defaults at call time so later reassignments of the globals are picked up.
    if hits_df is None:
        hits_df = hits
    if meas_to_hit is None:
        meas_to_hit = measId_to_hitID

    # Index the hit table once instead of building one dict per column; keeping the
    # last duplicate hit_id reproduces dict(zip(...)) "last wins" semantics.
    hit_lookup = hits_df.drop_duplicates("hit_id", keep="last").set_index("hit_id")

    # Work on a copy so the caller's frame is left untouched.
    out = tracks.copy()
    out["hit_id"] = out["measurementId"].map(meas_to_hit)
    for col in ("tx", "ty", "tz", "geometry_id", "particle_id"):
        out[col] = out["hit_id"].map(hit_lookup[col])
    return out

In [7]:
prototracks = process_prototracks(
    pd.read_csv("../tmp/no_threshold_2/gnn_plus_ckf/event000000000-prototracks.csv")
)
# One DataFrame per trackId (groupby's default sort gives ascending trackId order).
gnn_prototracks = [track_hits for _track_id, track_hits in prototracks.groupby("trackId")]
print("GNN prototracks:", len(gnn_prototracks))

GNN prototracks: 2819


In [28]:
# Largest prototracks first. sorted(..., reverse=True) is stable, so equal-length tracks
# keep their existing order and re-running this cell is idempotent; the previous
# reversed(sorted(...)) flipped the order of ties on every run.
gnn_prototracks = sorted(gnn_prototracks, key=len, reverse=True)

In [9]:
graph = pd.read_csv("../tmp/no_threshold_2/gnn_plus_ckf/event000000000-exatrkx-graph.csv")
spacepoints = pd.read_csv("../tmp/no_threshold_2/digi/event000000000-spacepoint.csv")

# measurement_id -> coordinate, one lookup per axis, shared by both edge endpoints.
sp_coord = {axis: dict(zip(spacepoints.measurement_id, spacepoints[axis])) for axis in ("x", "y", "z")}

endpoint_columns = {
    "edge0": ["x0", "y0", "z0"],
    "edge1": ["x1", "y1", "z1"],
}
for edge_col, coord_cols in endpoint_columns.items():
    for coord_col in coord_cols:
        graph[coord_col] = graph[edge_col].map(sp_coord[coord_col[0]])

# Truth particle of each endpoint (via measurement -> simhit -> particle).
graph["particle0"] = graph["edge0"].map(measId_to_particleId)
graph["particle1"] = graph["edge1"].map(measId_to_particleId)

# Transverse radius of each endpoint.
graph["r0"] = np.hypot(graph["x0"], graph["y0"])
graph["r1"] = np.hypot(graph["x1"], graph["y1"])

# An edge is "good" when both endpoints map to the same particle (NaN compares unequal).
graph["good"] = graph["particle0"].eq(graph["particle1"])

graph.head(3)

Unnamed: 0,edge0,edge1,weight,x0,y0,z0,x1,y1,z1,particle0,particle1,r0,r1,good
0,5,2340,0.999589,98.324997,5.225,-1515.59998,85.474998,5.275,-1315.59998,774619136327155712,774619136327155712,98.463728,85.637614,True
1,13,2348,0.999716,63.224998,-3.725,-1515.59998,54.775002,-3.575001,-1315.59998,607985950970085376,607985950970085376,63.334636,54.891542,True
2,14,83,0.999962,63.825001,7.225,-1515.59998,63.886929,7.20568,-1516.80005,607985949997006848,607985949997006848,64.232635,64.292002,True


In [10]:
# Fraction of edges whose endpoints share a truth particle. The vectorized mean avoids
# a Python-level loop over the column and yields NaN (not ZeroDivisionError) when empty.
graph["good"].mean()

0.993388967161591

In [1]:
import networkx as nx

In [21]:
g = nx.DiGraph()

# Orient each edge from the smaller-radius endpoint to the larger one: r0 < r1 keeps
# edge0 -> edge1, anything else (ties, NaN) becomes edge1 -> edge0.
inner_first = (graph["r0"] < graph["r1"]).to_numpy()
edge0 = graph["edge0"].to_numpy()
edge1 = graph["edge1"].to_numpy()
src = np.where(inner_first, edge0, edge1)
dst = np.where(inner_first, edge1, edge0)

# Edge attribute `color`: green for "good" (same-particle) edges, red otherwise.
colors = np.where(graph["good"].to_numpy(), "green", "red")

# Vectorized replacement for the row-wise iterrows loop; int()/float()/str() keep the
# node keys and attributes as plain Python scalars, as the row-wise version produced.
g.add_edges_from(
    (int(s), int(d), {"color": str(c), "weight": float(w)})
    for s, d, c, w in zip(src, dst, colors, graph["weight"].to_numpy())
)

In [22]:
# Every weakly connected component with at least 3 nodes is a candidate track.
tracks = [
    g.subgraph(component_nodes)
    for component_nodes in nx.weakly_connected_components(g)
    if len(component_nodes) >= 3
]
len(tracks)

2972

In [25]:
# Largest components first. sorted(..., reverse=True) is stable, so equal-size tracks
# keep their existing order and re-running this cell is idempotent; the previous
# reversed(sorted(...)) flipped the order of ties on every run.
tracks = sorted(tracks, key=len, reverse=True)

In [30]:
# Node count of the second-largest connected-component track.
tracks[1].number_of_nodes()

19

In [31]:
# Hit count of the second-largest GNN prototrack.
gnn_prototracks[1].shape[0]

19