# Truth matching

In [None]:
import ROOT

In [None]:
import graphviz

In [None]:
from particle import Particle, ParticleNotFound

In [None]:
def make_graph_MC(tracks):
    parents_map = {i: mother_id for i, track in enumerate(tracks) if (mother_id:=track.GetMotherId(), mother_id != -2)}
    #print(parents_map)
    level1 = {}
    for track, mother in parents_map.items():
        if mother not in level1:
            level1[mother] = [track]
        else:
            level1[mother].append(track)
    #print(level1)
    dot = graphviz.Digraph(comment='Vertex')
    dot.attr(rankdir="LR")
    dot.node("-1")
    for mother, track_ids in level1.items():
        for track in track_ids:
            pdgid = tracks[track].GetPdgCode()
            try:
                particle = Particle.from_pdgid(pdgid)
            except ParticleNotFound as e:
                print(e)
                #print(pdgid)
                assert len(str(pdgid)) == 10
                #I = pdgid % 10
                A = pdgid // 10 % 1000
                Z = pdgid // 10000 % 1000
                particle = Particle.from_nucleus_info(a=A, z=Z)
                # TODO how to deal with 1000390981?
                
            style = 'solid' if particle.charge else 'dashed'
            #print(particle.name, particle.charge)
            dot.node(str(track), '', shape='point')
            dot.edge(str(mother), str(track), style=style, label=f"{track}: {particle.name}")
    #print(dot)
    dot=dot.unflatten(stagger=10)
    return dot

In [None]:
def make_graph_reco(vertices):
    """Build a graphviz digraph of reconstructed vertices and their tracks.

    Each vertex becomes a subgraph with a node labelled by its id and one
    edge per attached track. A truth-matched track (MC id >= 0) is drawn as
    the node ``track_<mc_id>``, so a track shared by several vertices shows
    up as a single node. A track with a negative MC id is not truth-matched
    (``find_MC_track`` returns -1 in that case). It still gets the label
    ``track_<mc_id>``, but its node name is unique to its vertex and track
    slot. Otherwise all unmatched tracks of all vertices would merge into
    one ``track_-1`` node.

    Args:
        vertices: iterable of reconstructed vertices exposing ``getId()``,
            ``getNTracks()`` and ``getParameters(i).getTrack().getMcTrackId()``.

    Returns:
        ``graphviz.Digraph`` with one subgraph per vertex, laid out left to right.
    """
    top = graphviz.Digraph(comment='Vertex')
    top.attr(rankdir="LR")
    for vertex in vertices:
        vertex_id = vertex.getId()
        sub = graphviz.Digraph(comment='Vertex')
        sub.attr(rankdir="LR")
        sub.node(str(vertex_id), str(vertex_id))
        for i in range(vertex.getNTracks()):
            mc_id = vertex.getParameters(i).getTrack().getMcTrackId()
            label = f"track_{mc_id}"
            if mc_id >= 0:
                node_name = label
            else:
                # Unmatched track: keep it private to this vertex/slot.
                node_name = f"{vertex_id}_track_{i}"
            sub.node(node_name, label)
            sub.edge(str(vertex_id), node_name)
        top.subgraph(sub)
    return top

In [None]:
from vertex_analysis import find_MC_track

# NOTE: the string below is never executed. It is an archived reference copy
# of find_MC_track; the version actually used is imported from
# vertex_analysis above. It may have drifted from that version — TODO
# confirm, or delete. (This copy also relies on collections.Counter, which
# the notebook does not import.)
"""
def find_MC_track(track, event):
    link = event.Digi_TargetClusterHits2MCPoints[0]
    points = track.getPoints()
    track_ids = []
    for p in points:
        digi_hit = event.Digi_advTargetClusters[p.getRawMeasurement().getHitId()]
        wlist = link.wList(p.getRawMeasurement().getDetId())
        for index, weight in wlist:
            point = event.AdvTargetPoint[index]
            track_id = point.GetTrackID()
            if track_id == -2:
                continue
            track_ids.append(track_id)
    most_common_track, count = Counter(track_ids).most_common(1)[0]
    if count >= len(points) * 0.7:
        # truth match if ≥ 70 % of hits are related to a single MCTrack
        return most_common_track
    return -1
    # TODO check for ghosts/clones?
    # add to track_fit.py or separate script?
    # LHCb truth match if ≥ 70 % of hits are related to a single MCTrack
    # Ghost rate: fraction of tracks not truth matched
"""

In [None]:
from vertex_analysis import match_vertex


In [None]:
from vertex_analysis import find_true_vertex

# NOTE: the string below is never executed. It is an archived reference copy
# of find_true_vertex; the version actually used is imported from
# vertex_analysis above. It may have drifted from that version — TODO
# confirm, or delete.
"""
def find_true_vertex(track, event):
    id = track.getMcTrackId()
    if id >= 0:
        print(id)
        mc_track = event.MCTrack[id]
        true_vertex = ROOT.TVector3()
        mc_track.GetStartVertex(true_vertex)
        return true_vertex
    return None
"""

In [None]:
# TODO: define the criteria for a track to count as "in acceptance"

In [None]:
# Open the input ROOT file in read mode. Judging by its name it holds the
# digitised, tracked and vertexed numu sample — confirm.
f = ROOT.TFile.Open("numu_dig_selected_PR_tracked_vertexed.root", "read")

In [None]:
# Event tree stored in the file under the key "cbmsim" (attribute access on the TFile).
tree = f.cbmsim

In [None]:
# Process the first six events. For each event: build the MC truth graph,
# truth-match every fitted track, print the true start vertex of each matched
# track, print the matching result for each RAVE vertex, and build the
# reconstructed-vertex graph.
for i, event in enumerate(tree):
    if i > 5:
        break
    # MC truth graph; call display(dot) to render it in the notebook.
    dot = make_graph_MC(event.MCTrack)
    for track in event.genfit_tracks:
        # Store the truth match on the track; make_graph_reco reads it back
        # via getMcTrackId().
        track.setMcTrackId(find_MC_track(track, event))
        true_vertex = find_true_vertex(track, event)
        if true_vertex is not None:
            print(true_vertex.X(), true_vertex.Y(), true_vertex.Z())
    for vertex in event.RAVE_vertices:
        print(match_vertex(vertex, event))
    # Reconstructed-vertex graph; call display(dot) to render it in the notebook.
    dot = make_graph_reco(event.RAVE_vertices)