In [None]:
import networkx as nx
import pickle
import numpy as np
import uproot
import matplotlib.pyplot as plt

In [None]:
# Load the pickled graph object from disk; the context manager closes the file handle
# as soon as unpickling is done.
# NOTE(review): pickle.load can execute arbitrary code — only use this on trusted files.
with open("../../graphs/cleanup_g_0.pkl", "rb") as graph_file:
    g = pickle.load(graph_file)

In [None]:
# Export the loaded graph to GraphML. The module is imported as `nx`, so the call
# must go through that alias (the bare name `networkx` is not bound in this notebook).
nx.write_graphml(g, "cleanup.graphml")

In [None]:
# Print every weakly connected component of `g` that contains more than one node.
# `sub_graphs` is the iterator returned by networkx and is consumed by this loop.
sub_graphs = nx.weakly_connected_components(g)
for sg in sub_graphs:
    if len(sg) <= 1:
        # single-node components are not interesting here
        continue
    print(sg)

In [None]:
# Look up what is stored on node ("cp", 58) of the loaded graph; as the cell's
# last expression, the result is shown as the cell output.
g.nodes[("cp", 58)]

In [None]:
# Materialise the successors of node ("cp", 47) into a list and display it.
sucs = [*g.successors(("cp", 47))]
sucs

In [None]:
# Show the "typ", "pt", "eta" and "phi" attributes of node ("elem", 329) as one tuple.
tuple(g.nodes[("elem", 329)][key] for key in ("typ", "pt", "eta", "phi"))

In [None]:
# Open the ROOT ntuple and fetch the object stored under "pfana/pftree".
# NOTE(review): the file object returned by uproot.open is never closed explicitly;
# `tt` is read from later via tt.arrays(...), so confirm before wrapping this in `with`.
tt = uproot.open("../../graphs/pfntuple_singleele.root")["pfana/pftree"]

In [None]:
# Names of the branches to read from the tree, grouped by their name prefix.
# The order is kept exactly as it is passed to tt.arrays.
branch_names = [
    # caloparticle_*
    "caloparticle_tid",
    "caloparticle_pid",
    "caloparticle_energy",
    "caloparticle_pt",
    "caloparticle_eta",
    "caloparticle_to_element",
    "caloparticle_to_element_cmp",
    "caloparticle_idx_trackingparticle",
    # trackingparticle_*
    "trackingparticle_to_element",
    "trackingparticle_to_element_cmp",
    "trackingparticle_tid",
    "trackingparticle_pid",
    "trackingparticle_energy",
    "trackingparticle_pt",
    "trackingparticle_eta",
    # simtrack_*
    "simtrack_tid",
    "simtrack_parent_tid",
    "simtrack_pdgid",
    "simtrack_energy",
    "simtrack_gpidx",
    # calohit_* / trkhit_*
    "calohit_tid",
    "calohit_energy",
    "trkhit_tid",
    "trkhit_energy",
    # gen_*
    "gen_pdgid",
    "gen_energy",
    # element_*
    "element_type",
    "element_energy",
]

# Read all requested branches in one call; later cells index the result per event.
arr = tt.arrays(branch_names)

In [None]:
# Build a directed graph of the simulated-track history for a single event.
#   nodes: ("st", tid)   one per entry of simtrack_tid
#          ("gp", index) one per generator particle referenced by a simtrack
#   edges: gen particle -> simtrack, and parent simtrack -> child simtrack
iev = 13
event = arr[iev]  # the same event record feeds every loop below, so index it once

# Value of simtrack_parent_tid for which no parent edge is added.
# NOTE(review): presumably an unsigned 32-bit "no parent" marker — confirm against the ntuplizer.
NO_PARENT_TID = 2**32 - 1

simtrack_g = nx.DiGraph()

for tid, pid, energy, gpidx in zip(
    event["simtrack_tid"],
    event["simtrack_pdgid"],
    event["simtrack_energy"],
    event["simtrack_gpidx"],
):
    simtrack_g.add_node(("st", tid), pid=abs(pid), typ="st", energy=energy, ecalo=0, etrk=0, is_cp=0, is_tp=0)
    if gpidx != -1:
        # The gen arrays are indexed with gpidx - 1.
        # NOTE(review): presumably simtrack_gpidx is 1-based; a value of 0 would wrap to index -1 — confirm.
        gen_idx = gpidx - 1
        gen_energy = event["gen_energy"][gen_idx]

        # Gen particles get ecalo/etrk equal to their own energy; "typ" is set so that
        # every node kind in this graph carries a type tag.
        simtrack_g.add_node(
            ("gp", gen_idx),
            pid=abs(event["gen_pdgid"][gen_idx]),
            typ="gp",
            energy=gen_energy,
            ecalo=gen_energy,
            etrk=gen_energy,
        )
        simtrack_g.add_edge(("gp", gen_idx), ("st", tid))

# Flag the simtracks whose tid also appears as a caloparticle / tracking particle.
# A tid that is not already a node raises KeyError here (unlike the hit loops below,
# which skip unknown tids).
for tid in event["caloparticle_tid"]:
    simtrack_g.nodes[("st", tid)]["is_cp"] = 1

for tid in event["trackingparticle_tid"]:
    simtrack_g.nodes[("st", tid)]["is_tp"] = 1

# Connect each simtrack to its parent.
for tid, parent_tid in zip(event["simtrack_tid"], event["simtrack_parent_tid"]):
    if parent_tid != NO_PARENT_TID:
        # NOTE(review): add_edge creates ("st", parent_tid) without any attributes if
        # that tid was not among simtrack_tid; such a node would break the plotting cell.
        simtrack_g.add_edge(("st", parent_tid), ("st", tid))

# Accumulate calorimeter and tracker hit energies on the simtrack that produced them.
for tid, hit_energy in zip(event["calohit_tid"], event["calohit_energy"]):
    if ("st", tid) in simtrack_g.nodes:
        simtrack_g.nodes[("st", tid)]["ecalo"] += hit_energy

for tid, hit_energy in zip(event["trkhit_tid"], event["trkhit_energy"]):
    if ("st", tid) in simtrack_g.nodes:
        simtrack_g.nodes[("st", tid)]["etrk"] += hit_energy

In [None]:
# List the weakly connected components of the simtrack graph that have more than one
# node: component index, its gen-particle nodes, its simtrack nodes, and its size.
# simtrack_g only ever contains ("gp", ...) and ("st", ...) nodes, so the first column
# filters on "gp"; a "cp" filter can never match there and always prints an empty list.
# NOTE(review): if caloparticle-flagged simtracks were intended instead, filter on the
# node attribute is_cp rather than on the node key.
sub_graphs = list(nx.weakly_connected_components(simtrack_g))
for isg, sg in enumerate(sub_graphs):
    if len(sg) > 1:
        print(isg, [n for n in sg if n[0] == "gp"], [n for n in sg if n[0] == "st"], len(sg))

In [None]:
def color_node(g, n):
    """Choose a color name for node ``n`` of graph ``g``.

    The node kind is the first item of the node key:

    * ``"gp"`` nodes are always ``"blue"``.
    * ``"st"`` nodes are colored from their ``is_cp`` / ``is_tp`` attributes
      (read from ``g.nodes[n]``): both set -> ``"gray"``, only ``is_cp`` ->
      ``"red"``, only ``is_tp`` -> ``"cyan"``, neither -> ``"green"``.

    Raises:
        Exception: with ``n`` as its argument, for any other node kind.
    """
    kind = n[0]
    if kind == "gp":
        return "blue"

    if kind == "st":
        attrs = g.nodes[n]
        is_cp = attrs["is_cp"]
        is_tp = attrs["is_tp"]
        if is_cp:
            return "gray" if is_tp else "red"
        return "cyan" if is_tp else "green"

    # unknown node kind: surface the offending key to the caller
    raise Exception(n)
        
def label_node(g, n):
    """Return the ``"pid"`` value stored under key ``n`` of ``g``, formatted as text.

    ``g`` is indexed directly as ``g[n]["pid"]``, so it has to be a mapping from
    node key to attribute dict (the plotting cell passes ``ssg.nodes``), not the
    graph object itself.
    """
    return f"{g[n]['pid']}"

In [None]:
# Draw the first weakly connected component of the simtrack graph.
ssg = simtrack_g.subgraph(sub_graphs[0])

# Per-node visual attributes, all in ssg.nodes iteration order:
#   color  - from color_node (node kind and is_cp / is_tp flags)
#   size   - 5 + node energy
#   alpha1 - ecalo / energy, clipped to [0.2, 1.0] (used for drawing below)
#   alpha2 - etrk / energy, clipped to [0.2, 1.0] (computed but not used by this cell)
node_color = [color_node(ssg, node) for node in ssg.nodes]
node_size = [5 + attrs["energy"] for _, attrs in ssg.nodes(data=True)]
alpha1 = [np.clip(attrs["ecalo"] / attrs["energy"], 0.2, 1.0) for _, attrs in ssg.nodes(data=True)]
alpha2 = [np.clip(attrs["etrk"] / attrs["energy"], 0.2, 1.0) for _, attrs in ssg.nodes(data=True)]
labels = {node: label_node(ssg.nodes, node) for node in ssg.nodes}

# Node positions come from graphviz "dot" via the pygraphviz bridge.
pos = nx.nx_agraph.graphviz_layout(ssg, prog="dot")

fig = plt.figure(figsize=(15, 10))
# Trailing semicolons suppress the return-value repr in the notebook output.
nx.draw_networkx_nodes(ssg, pos, node_color=node_color, node_size=node_size, alpha=alpha1);
nx.draw_networkx_edges(ssg, pos);
nx.draw_networkx_labels(ssg, pos, labels=labels, font_size=8);

In [None]:
# Build a directed graph linking caloparticles to detector elements for event `iev`.
#   nodes: ("elem", index) for every element whose type is not excluded below
#          ("cp", index)   for every caloparticle
#   edges: ("cp", i) -> ("elem", j) with weight attribute "w"
cp_g = nx.DiGraph()
event = arr[iev]  # index the event record once instead of on every access

# Elements of type 2, 3 or 7 are left out of the graph.
# NOTE(review): the meaning of these type codes is not visible in this notebook — confirm.
for ielem, (typ, energy) in enumerate(zip(event["element_type"], event["element_energy"])):
    if typ not in [2, 3, 7]:
        cp_g.add_node(("elem", ielem), typ=typ, energy=energy)

# Association tables are the same for every caloparticle, so fetch them before the loop.
# ".first" is compared against the particle index, ".second" yields the element indices,
# and the parallel "_cmp" branch supplies the value used for the edge weight.
cp_to_elem = event["caloparticle_to_element"]
cp_assoc_particle = cp_to_elem["caloparticle_to_element.first"]
cp_assoc_elem = cp_to_elem["caloparticle_to_element.second"]
cp_assoc_cmp = event["caloparticle_to_element_cmp"]

tp_to_elem = event["trackingparticle_to_element"]
tp_assoc_particle = tp_to_elem["trackingparticle_to_element.first"]
tp_assoc_elem = tp_to_elem["trackingparticle_to_element.second"]
tp_assoc_cmp = event["trackingparticle_to_element_cmp"]

caloparticles = zip(
    event["caloparticle_tid"],
    event["caloparticle_pid"],
    event["caloparticle_energy"],
    event["caloparticle_idx_trackingparticle"],
)
for icp, (_tid, pid, energy, itp) in enumerate(caloparticles):
    cp_g.add_node(("cp", icp), pid=abs(pid), typ="cp", energy=energy)

    # Elements associated directly to this caloparticle: weight is the cmp value.
    msk = cp_assoc_particle == icp
    for elem_idx, cmp in zip(cp_assoc_elem[msk], cp_assoc_cmp[msk]):
        if ("elem", elem_idx) in cp_g.nodes:
            cp_g.add_edge(("cp", icp), ("elem", elem_idx), w=cmp)

    # Elements associated through the matching tracking particle (-1 means none).
    if itp != -1:
        msk = tp_assoc_particle == itp
        for elem_idx, cmp in zip(tp_assoc_elem[msk], tp_assoc_cmp[msk]):
            if ("elem", elem_idx) in cp_g.nodes:
                # Weight is cmp scaled by the caloparticle energy; if the same edge was
                # already added above, add_edge overwrites its "w".
                cp_g.add_edge(("cp", icp), ("elem", elem_idx), w=cmp * energy)

In [None]:
# Split cp_g into weakly connected components, print the multi-node ones
# (index, caloparticle nodes, size) and remember their indices for plotting.
sub_graphs = list(nx.weakly_connected_components(cp_g))
subgraph_indices = []
for isg, sg in enumerate(sub_graphs):
    if len(sg) <= 1:
        # isolated nodes carry no association information
        continue
    print(isg, [n for n in sg if n[0] == "cp"], len(sg))
    subgraph_indices.append(isg)

In [None]:
def label_node_cp(nodes, n):
    """Return the text label for node ``n`` of the caloparticle/element graph.

    ``nodes`` maps a node key to its attribute dict. The first item of the key
    selects which attribute is shown: ``"typ"`` for ``"elem"`` nodes and
    ``"pid"`` for ``"cp"`` nodes. Any other node kind gets an empty label.
    """
    kind = n[0]
    if kind == "elem":
        attr_name = "typ"
    elif kind == "cp":
        attr_name = "pid"
    else:
        return ""
    return f"{nodes[n][attr_name]}"

In [None]:
# Draw the first multi-node weakly connected component of the caloparticle/element graph.
ssg = cp_g.subgraph(sub_graphs[subgraph_indices[0]])

# Edge styling derived from the "w" edge attribute, in ssg.edges iteration order.
# arrowsize is computed but not passed to any draw call in this cell (arrows are disabled).
edge_widths = [np.clip(attrs["w"] / 10, 0.01, 10) for _, _, attrs in ssg.edges(data=True)]
arrowsize = [np.clip(attrs["w"] / 5, 0.01, 100) for _, _, attrs in ssg.edges(data=True)]

# Node styling: label from label_node_cp, size is 5 + node energy.
labels = {node: label_node_cp(ssg.nodes, node) for node in ssg.nodes}
node_size = [5 + attrs["energy"] for _, attrs in ssg.nodes(data=True)]

# Node positions come from graphviz "dot" via the pygraphviz bridge.
pos = nx.nx_agraph.graphviz_layout(ssg, prog="dot")

fig = plt.figure(figsize=(15, 10))
# Trailing semicolons suppress the return-value repr in the notebook output.
nx.draw_networkx_nodes(ssg, pos, node_size=node_size);

nx.draw_networkx_edges(ssg, pos, width=edge_widths, arrows=False);

nx.draw_networkx_labels(ssg, pos, labels=labels, font_size=12);