In [None]:
%matplotlib inline

In [None]:
import sklearn
import sklearn.metrics
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import pandas
import mplhep
import pickle
import awkward
import glob
import bz2
import os
import tqdm
import fastjet
import vector
import networkx as nx

mplhep.style.use("CMS")

In [None]:
def map_pdgid_to_candid(pdgid):
    """Collapse a PDG id onto a candidate class id.

    The sign of ``pdgid`` is ignored. 22, 11 and 13 are returned unchanged,
    211, 321 and 2212 are mapped to 211 (charged hadron) and every other id
    is mapped to 130 (neutral hadron).
    """
    abs_id = abs(pdgid)

    if abs_id in (211, 321, 2212):
        # charged hadron
        candid = 211
    elif abs_id in (22, 11, 13):
        candid = abs_id
    else:
        # neutral hadron
        candid = 130
    return candid

In [None]:
import sys
sys.path += ["../../mlpf/"]

import jet_utils

sys.path += ["../../mlpf/plotting/"]

from plot_utils import ELEM_LABELS_CMS, ELEM_NAMES_CMS
from plot_utils import CLASS_LABELS_CMS, CLASS_NAMES_CMS
from plot_utils import cms_label, sample_label
from plot_utils import pid_to_text

In [None]:
!ls /media/joosep/data/20240823_simcluster/nopu/

In [None]:
# Sample directory (relative to the dataset root) and number of raw files to read.
sample = "SingleGammaFlatPt1To1000_pythia8_cfi/"

maxfiles = 1

# Every figure saved by this notebook is written below this directory.
plot_outpath = "cms-simvalidation/{}/".format(sample)
# exist_ok replaces the separate isdir check and keeps the cell safe to re-run
os.makedirs(plot_outpath, exist_ok=True)

In [None]:
# Sort the glob result so that `maxfiles` always selects the same files;
# glob itself returns files in arbitrary order.
raw_files = sorted(
    glob.glob("/media/joosep/data/20240823_simcluster/nopu/{}/raw/*.pkl.bz2".format(sample))
)[:maxfiles]

# Each file unpickles to a list of events; concatenate them into one flat list.
# NOTE(review): pickle.load can execute arbitrary code - only read trusted files.
pickle_data = []
for f in tqdm.tqdm(raw_files):
    # the context manager closes the compressed file handle after reading
    with bz2.BZ2File(f, "r") as fi:
        pickle_data.extend(pickle.load(fi))

# compute phi from sin_phi, cos_phi for the ytarget and ycand
for event in pickle_data:
    for coll in ["ytarget", "ycand"]:
        event[coll] = pandas.DataFrame(event[coll])
        event[coll]["phi"] = np.arctan2(event[coll]["sin_phi"], event[coll]["cos_phi"])

In [None]:
# Two views of every collection, keyed as arrs_*[collection][feature]:
#   arrs_awk  - per-event arrays with zero-id entries removed
#   arrs_flat - per-event arrays with every entry kept, so that Xelem, ytarget
#               and ycand can be compared index by index
arrs_awk = {}
arrs_flat = {}

# (collection, name of the field whose zero value marks an empty entry)
for coll, id_feat in [("Xelem", "typ"), ("ytarget", "pid"), ("ycand", "pid")]:
    awk_feats = {}
    flat_feats = {}
    for feat in [id_feat, "pt", "eta", "phi", "energy"]:
        awk_feats[feat] = awkward.from_regular(
            [np.array(p[coll][feat][p[coll][id_feat] != 0].tolist()) for p in pickle_data]
        )
        flat_feats[feat] = awkward.from_regular([np.array(p[coll][feat].tolist()) for p in pickle_data])
    arrs_awk[coll] = awk_feats
    arrs_flat[coll] = flat_feats

# the "ispu" field is only read from the target collection
arrs_awk["ytarget"]["ispu"] = awkward.from_regular(
    [np.array(ev["ytarget"]["ispu"][ev["ytarget"]["pid"] != 0].tolist()) for ev in pickle_data]
)
arrs_flat["ytarget"]["ispu"] = awkward.from_regular(
    [np.array(ev["ytarget"]["ispu"].tolist()) for ev in pickle_data]
)

# Optional generator-level (pythia) particles. Column 0 of the per-event array
# is the particle id; the columns are read in the order listed below.
if len(pickle_data) > 0 and "pythia" in pickle_data[0].keys():
    arrs_flat["pythia"] = {}

    # Per-event mask that keeps everything except |pid| 12, 14, 16 (neutrinos).
    # It does not depend on the feature, so it is computed once per event.
    visible_masks = [~np.isin(np.abs(p["pythia"][:, 0]), [12, 14, 16]) for p in pickle_data]

    for ifeat, feat in enumerate(["pid", "pt", "eta", "phi", "energy"]):
        arr = [
            np.array(p["pythia"][:, ifeat][mask].tolist())
            for p, mask in zip(pickle_data, visible_masks)
        ]
        arrs_flat["pythia"][feat] = awkward.from_regular(arr)


# Generator-level MET: element [0, 0] of each event's "genmet" array.
genmet_cmssw = np.array([event["genmet"][0, 0] for event in pickle_data])

# Generator-level jets: the last axis is read as (pt, eta, phi, energy).
genjet_raw = awkward.from_iter([event["genjet"] for event in pickle_data])
genjet_fields = {
    name: genjet_raw[:, :, icol] for icol, name in enumerate(["pt", "eta", "phi", "energy"])
}
genjet_cmssw = vector.awk(awkward.zip(genjet_fields))

# Missing transverse momentum per event: magnitude of the vector sum of the
# particle transverse momenta. The components have to be summed over the
# particles *before* squaring; summing px**2 + py**2 per particle would give
# sqrt(sum pt**2), which is not comparable to genmet_cmssw.
ytarget_met = np.sqrt(
    awkward.sum(arrs_awk["ytarget"]["pt"] * np.cos(arrs_awk["ytarget"]["phi"]), axis=1) ** 2
    + awkward.sum(arrs_awk["ytarget"]["pt"] * np.sin(arrs_awk["ytarget"]["phi"]), axis=1) ** 2
)

ycand_met = np.sqrt(
    awkward.sum(arrs_awk["ycand"]["pt"] * np.cos(arrs_awk["ycand"]["phi"]), axis=1) ** 2
    + awkward.sum(arrs_awk["ycand"]["pt"] * np.sin(arrs_awk["ycand"]["phi"]), axis=1) ** 2
)

In [None]:
# Event display for a single event: PF elements are drawn with the default
# marker, target particles as crosses. Marker area follows pt, colour follows
# the element type / particle id.
iev = 15

plt.figure(figsize=(5, 5))
for coll_name, color_feat, extra_kwargs in [
    ("Xelem", "typ", {}),
    ("ytarget", "pid", {"marker": "x"}),
]:
    event_view = arrs_awk[coll_name]
    plt.scatter(
        event_view["eta"][iev],
        event_view["phi"][iev],
        s=event_view["pt"][iev],
        alpha=0.4,
        c=event_view[color_feat][iev],
        **extra_kwargs
    )

plt.xlim(-5, 5)
plt.ylim(-4, 4)

In [None]:
# Table for event `iev`: each row holds one PF element together with the
# target particle stored at the same index (arrs_flat keeps all entries, so
# the indices of Xelem and ytarget line up).
df = pandas.DataFrame()
df["Xelem_energy"] = arrs_flat["Xelem"]["energy"][iev]
df["Xelem_eta"] = arrs_flat["Xelem"]["eta"][iev]
df["Xelem_phi"] = arrs_flat["Xelem"]["phi"][iev]
df["Xelem_typ"] = [int(x) for x in arrs_flat["Xelem"]["typ"][iev]]

df["ytarget_energy"] = arrs_flat["ytarget"]["energy"][iev]
df["ytarget_eta"] = arrs_flat["ytarget"]["eta"][iev]
# this column was previously a second, redundant assignment of ytarget_eta
df["ytarget_phi"] = arrs_flat["ytarget"]["phi"][iev]
df["ytarget_typ"] = [int(x) for x in arrs_flat["ytarget"]["pid"][iev]]

# keep the ten most energetic elements with energy above 2
df = df[df["Xelem_energy"] > 2].sort_values("Xelem_energy", ascending=False).head(10)

In [None]:
# Generator-level (pythia) particles of event `iev`, one row per particle.
df2 = pandas.DataFrame()
for column in ["energy", "eta", "phi", "pid"]:
    df2[column] = arrs_flat["pythia"][column][iev]
df2

In [None]:
# Display the per-element table of event `iev` built above.
df

In [None]:
# Directed graph gen -> tgt -> elem for the selected event.
# Nodes are (kind, index) tuples carrying "energy" and "pid" attributes.
graph = nx.DiGraph()

gen_energy, gen_pid = df2["energy"], df2["pid"]
for igen in range(len(df2)):
    graph.add_node(("gen", igen), energy=gen_energy.iloc[igen], pid=gen_pid.iloc[igen])

for itgt in range(len(df)):
    elem_node = ("elem", itgt)
    graph.add_node(elem_node, energy=df["Xelem_energy"].iloc[itgt], pid=df["Xelem_typ"].iloc[itgt])

    # elements without an associated target particle stay unconnected
    if df["ytarget_typ"].iloc[itgt] == 0:
        continue

    tgt_node = ("tgt", itgt)
    graph.add_node(tgt_node, energy=df["ytarget_energy"].iloc[itgt], pid=df["ytarget_typ"].iloc[itgt])
    # NOTE(review): every target is attached to generator particle 0 - confirm
    # this is intended for samples with more than one generator particle.
    graph.add_edge(("gen", 0), tgt_node)
    graph.add_edge(tgt_node, elem_node)

elem_nodes = [node for node in graph.nodes if node[0] == "elem"]

# Record that all element nodes belong to one "same rank" group; the
# information is stored as a graph attribute.
same_rank_subgraph = {'rank': 'same', 'nodes': elem_nodes}
subgraphs = [same_rank_subgraph]
graph.graph['subgraphs'] = subgraphs

In [None]:
def color_node(n):
    """Return the drawing colour of graph node ``n``, chosen by its kind ``n[0]``.

    "gen" gives "red", "tgt" gives "blue" and "elem" gives "green";
    any other kind gives ``None``.
    """
    kind = n[0]
    for known_kind, colour in (("gen", "red"), ("tgt", "blue"), ("elem", "green")):
        if kind == known_kind:
            return colour
    return None

In [None]:
# Colour encodes the node kind, marker size is the node's "energy" attribute.
node_color = list(map(color_node, graph.nodes))
node_size = [attrs["energy"] for _node, attrs in graph.nodes(data=True)]

# node positions from the graphviz "dot" layout program
pos = nx.nx_agraph.graphviz_layout(graph, prog="dot")
fig = plt.figure(figsize=(20, 10))
nx.draw_networkx_nodes(graph, pos, node_color=node_color, node_size=node_size)
# node_size is passed to the edge drawing as well
nx.draw_networkx_edges(graph, pos, arrowsize=1, width=0.5, alpha=0.2, node_size=node_size);

In [None]:
def _cluster_jets(particles, radius=0.4, min_pt=3):
    """Cluster particles into anti-kt jets and return the inclusive jets.

    ``particles`` is a mapping that provides per-event "pt", "eta", "phi" and
    "energy" arrays; other keys are ignored.
    """
    vec = vector.awk(
        awkward.zip({name: particles[name] for name in ("pt", "eta", "phi", "energy")})
    )
    jetdef = fastjet.JetDefinition(fastjet.antikt_algorithm, radius)
    cluster = fastjet.ClusterSequence(vec.to_xyzt(), jetdef)
    return cluster.inclusive_jets(min_pt=min_pt)


# Jet collections keyed by name; "cmssw" holds the generator jets read from the input.
jets_coll = {}
jets_coll["cmssw"] = genjet_cmssw

for coll in ["ytarget", "ycand"]:
    jets_coll[coll] = _cluster_jets(arrs_awk[coll])

# target particles with ispu < 1 only
not_pu = arrs_awk["ytarget"]["ispu"] < 1
jets_coll["ytarget_nopu"] = _cluster_jets(
    {name: arrs_awk["ytarget"][name][not_pu] for name in ("pt", "eta", "phi", "energy")}
)

In [None]:
# Match the generator jets to each of the other jet collections, passing 0.1
# as the last argument of the matching helper each time.
match_param = 0.1
cmssw_to_ytarget, cmssw_to_ytarget_nopu, cmssw_to_ycand = (
    jet_utils.match_two_jet_collections(jets_coll, "cmssw", other_coll, match_param)
    for other_coll in ("ytarget", "ytarget_nopu", "ycand")
)

In [None]:
# Total number of "cmssw" entries in each matching result, summed over events.
for matched in (cmssw_to_ytarget, cmssw_to_ytarget_nopu, cmssw_to_ycand):
    print(len(awkward.flatten(matched["cmssw"])))

In [None]:
# Ratio of target-particle energy to PF-element energy, one histogram per
# element type code, restricted to elements that have a target particle.
plt.figure()

ratio_bins = np.logspace(-4, 4, 500)
for elem_typ in (1, 4, 5, 6):
    msk = (arrs_flat["Xelem"]["typ"] == elem_typ) & (arrs_flat["ytarget"]["pid"] != 0)
    target_energy = awkward.flatten(arrs_flat["ytarget"]["energy"][msk])
    elem_energy = awkward.flatten(arrs_flat["Xelem"]["energy"][msk])
    plt.hist(target_energy / elem_energy, bins=ratio_bins, histtype="step")

plt.xscale("log")
plt.yscale("log")
# reference line at a ratio of one
plt.axvline(1.0, color="black", ls="--", lw=0.5)
plt.show()

In [None]:
# Jet response: pt of the matched jet divided by pt of the generator jet.
plt.figure()
ax = plt.axes()
b = np.logspace(-2, 2, 400)

# (jet collection, matching result, legend label)
for jet_name, matched, legend_label in [
    ("ytarget", cmssw_to_ytarget, "MLPF target"),
    ("ycand", cmssw_to_ycand, "PF"),
]:
    matched_pt = jets_coll[jet_name][matched[jet_name]].pt
    genjet_pt = jets_coll["cmssw"][matched["cmssw"]].pt
    plt.hist(awkward.flatten(matched_pt / genjet_pt), bins=b, histtype="step", lw=1, label=legend_label)

plt.xscale("log")
plt.yscale("log")
plt.xlabel("jet $p_T$ / genjet $p_T$")
cms_label(ax)
# NOTE(review): the label key is hard-coded and does not follow `sample` - confirm.
sample_label(ax, "cms_pf_ttbar_nopu")
# extra headroom above the histograms
plt.ylim(top=ax.get_ylim()[1]*1.5)
plt.legend(loc=3)
plt.axvline(1.0, color="black", ls="--", lw=0.5)
plt.show()

In [None]:
# Overlay of the generator-level MET and the values derived from the target
# and PF particles, in that order.
plt.figure(figsize=(10, 10))
b = np.logspace(1, 3, 101)
for met_values in (genmet_cmssw, ytarget_met, ycand_met):
    plt.hist(met_values, bins=b, histtype="step", lw=2)
plt.xscale("log")
plt.yscale("log")
plt.show()

In [None]:
# MET divided by generator-level MET, for events with generator-level MET below 1.
plt.figure(figsize=(10, 10))
b = np.logspace(1, 4, 61)
low_genmet = genmet_cmssw < 1
for reco_met in (ytarget_met, ycand_met):
    plt.hist((reco_met / genmet_cmssw)[low_genmet], bins=b, histtype="step", lw=2)
plt.xscale("log")
plt.show()

In [None]:
# MET divided by generator-level MET, for events with generator-level MET above 1.
plt.figure(figsize=(10, 10))
b = np.logspace(-1, 2, 101)
high_genmet = genmet_cmssw > 1
for reco_met in (ytarget_met, ycand_met):
    plt.hist((reco_met / genmet_cmssw)[high_genmet], bins=b, histtype="step", lw=2)
plt.xscale("log")
plt.show()

In [None]:
# Number of PF elements with non-zero typ in each event.
plt.figure()
n_elems_per_event = list(map(len, arrs_awk["Xelem"]["typ"]))
plt.hist(n_elems_per_event, bins=100)
plt.show()

In [None]:
# Particle pt spectra: pythia particles, target particles and PF candidates.
if "pythia" in arrs_flat.keys():
    fig = plt.figure()
    ax = plt.axes()
    b = np.logspace(-1, 4, 101)
    plt.hist(awkward.flatten(arrs_flat["pythia"]["pt"]), bins=b, histtype="step", lw=2, label="Pythia")
    plt.hist(awkward.flatten(arrs_awk["ytarget"]["pt"]), bins=b, histtype="step", lw=2, label="MLPF truth")
    plt.hist(awkward.flatten(arrs_awk["ycand"]["pt"]), bins=b, histtype="step", lw=2, label="PF")
    plt.xscale("log")
    plt.yscale("log")
    plt.xlabel("particle $p_T$ [GeV]")
    plt.ylabel("Number of particles")
    plt.legend(loc=6)
    cms_label(ax)
    plt.ylim(1, 1e5)
    # Save before show(): with the inline backend show() closes the figure,
    # so a savefig issued afterwards writes an empty page.
    plt.savefig(plot_outpath + "all_pt.pdf", bbox_inches="tight")
    plt.show()

In [None]:
# Per-event energy sum: pythia particles, target particles and PF candidates.
if "pythia" in arrs_flat.keys():
    fig = plt.figure()
    ax = plt.axes()
    b = np.logspace(1, 5, 101)
    plt.hist(awkward.sum(arrs_flat["pythia"]["energy"], axis=1), bins=b, histtype="step", lw=2, label="Pythia")
    plt.hist(awkward.sum(arrs_awk["ytarget"]["energy"], axis=1), bins=b, histtype="step", lw=2, label="MLPF truth")
    plt.hist(awkward.sum(arrs_awk["ycand"]["energy"], axis=1), bins=b, histtype="step", lw=2, label="PF")
    plt.xscale("log")
    plt.yscale("log")
    # raw string: "\s" is not a valid escape sequence in a normal string
    plt.xlabel(r"event $\sum E$ [GeV]")
    plt.ylabel("Number of events")
    plt.legend(loc=6)
    cms_label(ax)
    plt.ylim(1, 1e3)
    # Save before show(): with the inline backend show() closes the figure,
    # so a savefig issued afterwards writes an empty page.
    plt.savefig(plot_outpath + "all_sume.pdf", bbox_inches="tight")
    plt.show()

In [None]:
# Per-event pt sum of PF elements, target particles and PF candidates.
plt.figure()
b = np.linspace(0, 10000, 101)
for coll_name in ("Xelem", "ytarget", "ycand"):
    plt.hist(awkward.sum(arrs_awk[coll_name]["pt"], axis=1), bins=b, histtype="step", lw=1)
plt.yscale("log")
plt.show()

In [None]:
# Per-event energy sum of PF elements, target particles and PF candidates.
plt.figure()
b = np.linspace(0, 1e5, 100)
for coll_name in ("Xelem", "ytarget", "ycand"):
    plt.hist(awkward.sum(arrs_awk[coll_name]["energy"], axis=1), bins=b, histtype="step", lw=2)
plt.yscale("log")
plt.show()

In [None]:
# Per-event energy sum: PF candidates (y) against target particles (x).
plt.figure(figsize=(12, 10))
ax = plt.axes()
b = np.logspace(3, 5, 101)
plt.hist2d(
    awkward.to_numpy(awkward.sum(arrs_awk["ytarget"]["energy"], axis=1)),
    awkward.to_numpy(awkward.sum(arrs_awk["ycand"]["energy"], axis=1)),
    bins=(b, b),
    cmap="hot_r",
    norm=matplotlib.colors.Normalize(vmin=0),
)
# diagonal: equal sums
plt.plot([1e3, 1e5], [1e3, 1e5], color="black", ls="--")
plt.colorbar()
plt.xscale("log")
plt.yscale("log")
# raw strings: "\s" is not a valid escape sequence in a normal string
plt.xlabel(r"MLPF truth event $\sum E$ [GeV]")
plt.ylabel(r"PF event $\sum E$ [GeV]")

# Save before show(): with the inline backend show() closes the figure,
# so a savefig issued afterwards writes an empty page.
plt.savefig(plot_outpath + "pf_vs_truth_sume.pdf", bbox_inches="tight")
plt.show()

In [None]:
def met(pt, phi):
    """Return the missing transverse momentum of each event.

    Args:
        pt: per-event array of particle transverse momenta (reduced over axis 1).
        phi: array of azimuthal angles in radians with the same structure as ``pt``.

    Returns:
        One value per event: the magnitude of the vector sum of the particle
        transverse momenta.
    """
    # Sum the components over the particles first and square afterwards.
    # Squaring per particle would give sqrt(sum pt**2) instead.
    sum_px = awkward.sum(pt * np.cos(phi), axis=1)
    sum_py = awkward.sum(pt * np.sin(phi), axis=1)
    return np.sqrt(sum_px**2 + sum_py**2)

In [None]:
# Per-event MET: PF candidates (y) against target particles (x).
plt.figure(figsize=(12, 10))
ax = plt.axes()
b = np.logspace(1, 3, 100)
plt.hist2d(
    awkward.to_numpy(met(arrs_awk["ytarget"]["pt"], arrs_awk["ytarget"]["phi"])),
    awkward.to_numpy(met(arrs_awk["ycand"]["pt"], arrs_awk["ycand"]["phi"])),
    bins=(b, b),
    cmap="hot_r",
    norm=matplotlib.colors.Normalize(vmin=0),
)
# diagonal: equal values
plt.plot([1e1, 1e6], [1e1, 1e6], color="black", ls="--")
plt.colorbar()
plt.xscale("log")
plt.yscale("log")
plt.xlabel("MLPF truth MET [GeV]")
plt.ylabel("PF MET [GeV]")

# Save before show(): with the inline backend show() closes the figure,
# so a savefig issued afterwards writes an empty page.
plt.savefig(plot_outpath + "pf_vs_truth_met.pdf", bbox_inches="tight")
plt.show()

In [None]:
def _plot_truth_vs_element(truth, elem, bins, log_axes, xlabel, ylabel, outfile):
    """Draw, save and show one 2D histogram of truth values (x) against PF element values (y).

    ``bins`` is used on both axes and its first and last edges define the
    diagonal reference line. With ``log_axes`` both axes are logarithmic;
    otherwise the axis limits are set to the bin range. The figure is written
    to ``plot_outpath + outfile``.
    """
    lo, hi = bins[0], bins[-1]
    plt.figure(figsize=(12, 10))
    ax = plt.axes()
    plt.hist2d(
        truth,
        elem,
        bins=(bins, bins),
        cmap="hot_r",
        norm=matplotlib.colors.Normalize(vmin=0),
    )
    plt.plot([lo, hi], [lo, hi], ls="--", color="black")
    if log_axes:
        plt.xscale("log")
        plt.yscale("log")
    else:
        plt.xlim(lo, hi)
        plt.ylim(lo, hi)
    cbar = plt.colorbar(label="number of particles / bin")
    cbar.formatter.set_powerlimits((0, 0))
    cbar.formatter.set_useMathText(True)

    cms_label(ax)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    # Save before show(): with the inline backend show() closes the figure,
    # so a savefig issued afterwards writes an empty page.
    plt.savefig(plot_outpath + outfile, bbox_inches="tight")
    plt.show()


# Compare each target particle with the PF element stored at the same index.
# pid == 0 selects every entry that has a target particle; any other value
# selects the entries whose target pid equals it.
for pid in [0, 211, 130, 22]:
    if pid == 0:
        msk = arrs_flat["ytarget"]["pid"] != pid
    else:
        msk = arrs_flat["ytarget"]["pid"] == pid
    print(np.sum(msk))

    # (feature, bin edges, log axes, x label, y label, output file pattern)
    # raw strings: "\e" and "\p" are not valid escape sequences in normal strings
    for feat, feat_bins, log_axes, xlabel, ylabel, outfile in [
        ("eta", np.linspace(-7, 7, 100), False, r"Truth $\eta$", r"PFElement $\eta$", "truth_vs_pfelement_eta_{}.pdf"),
        ("phi", np.linspace(-4, 4, 100), False, r"MLPF truth $\phi$", r"PFElement $\phi$", "truth_vs_pfelement_phi_{}.pdf"),
        ("energy", np.logspace(0, 3, 100), True, "Truth $E$", "PFElement $E$ [GeV]", "truth_vs_pf_e_{}.pdf"),
    ]:
        data1 = awkward.to_numpy(awkward.flatten(arrs_flat["Xelem"][feat][msk]))
        data2 = awkward.to_numpy(awkward.flatten(arrs_flat["ytarget"][feat][msk]))
        _plot_truth_vs_element(data2, data1, feat_bins, log_axes, xlabel, ylabel, outfile.format(pid))

In [None]:
# Flattened (all events concatenated) element types and particle ids.
Xelem_typ_f = np.array(awkward.flatten(arrs_flat["Xelem"]["typ"]))
ygen_typ_f = np.array(awkward.flatten(arrs_flat["ytarget"]["pid"]))
ycand_typ_f = np.array(awkward.flatten(arrs_flat["ycand"]["pid"]))

# Replace each particle id by its position in CLASS_LABELS_CMS.
# Ids that are not in the list keep the initial index 0.
ygen_typ_id = np.zeros(len(ygen_typ_f), dtype=np.int32)
ycand_typ_id = np.zeros(len(ycand_typ_f), dtype=np.int32)
for i, class_label in enumerate(CLASS_LABELS_CMS):
    ygen_typ_id[ygen_typ_f == class_label] = i
    ycand_typ_id[ycand_typ_f == class_label] = i

In [None]:
# Distinct PF element type codes and how often each one occurs.
np.unique(Xelem_typ_f, return_counts=True)

In [None]:
# Target class indices (positions in CLASS_LABELS_CMS) found on elements of
# type code 9, with their counts.
# NOTE(review): the meaning of code 9 is not visible here - presumably an entry of ELEM_LABELS_CMS.
np.unique(ygen_typ_id[Xelem_typ_f == 9], return_counts=True)

In [None]:
# Element type (rows) against particle class index (columns), once for the
# target particles and once for the PF candidates.
plt.figure(figsize=(15, 10))
for ipanel, (typ_id, panel_title) in enumerate(
    [(ygen_typ_id, "MLPF truth"), (ycand_typ_id, "PF")], start=1
):
    plt.subplot(1, 2, ipanel)
    cm = sklearn.metrics.confusion_matrix(
        Xelem_typ_f,
        typ_id,
        labels=range(0, 13),
    )
    plt.imshow(cm, cmap="Blues", norm=matplotlib.colors.LogNorm(), origin="lower")
    plt.colorbar()
    plt.xticks(range(len(CLASS_NAMES_CMS)), CLASS_NAMES_CMS, rotation=45)
    plt.yticks(range(len(ELEM_NAMES_CMS)), ELEM_NAMES_CMS)
    # show only the columns that correspond to a particle class
    plt.xlim(-0.5, len(CLASS_NAMES_CMS) - 0.5)
    plt.title(panel_title)

plt.tight_layout()
# Save before show(): with the inline backend show() closes the figure,
# so a savefig issued afterwards writes an empty page.
plt.savefig(plot_outpath + "primary_element.pdf", bbox_inches="tight")
plt.show()

In [None]:
# Particle pt spectrum: PF candidates against target particles.
fig = plt.figure(figsize=(10, 10))
ax = plt.axes()

b = np.logspace(-2, 4, 101)
h = np.histogram(awkward.to_numpy(awkward.flatten(arrs_awk["ycand"]["pt"])), bins=b)
mplhep.histplot(h, histtype="step", label="PF")

h = np.histogram(awkward.to_numpy(awkward.flatten(arrs_awk["ytarget"]["pt"])), bins=b)
mplhep.histplot(h, histtype="step", label="MLPF truth")

plt.xscale("log")
plt.legend(ncol=1, loc=(0.6, 0.5))

cms_label(ax)

plt.xlabel("$p_T$ [GeV]")
plt.ylabel("Number of particles")
# Save before show(): with the inline backend show() closes the figure,
# so a savefig issued afterwards writes an empty page.
plt.savefig(plot_outpath + "pf_vs_truth_pt.pdf", bbox_inches="tight")
plt.show()

In [None]:
# Particle eta distribution: PF candidates against target particles.
fig = plt.figure(figsize=(10, 10))
ax = plt.axes()

b = np.linspace(-6, 6, 101)

h = np.histogram(awkward.to_numpy(awkward.flatten(arrs_awk["ycand"]["eta"])), bins=b)
mplhep.histplot(h, histtype="step", label="PF")

h = np.histogram(awkward.to_numpy(awkward.flatten(arrs_awk["ytarget"]["eta"])), bins=b)
mplhep.histplot(h, histtype="step", label="MLPF truth")

plt.legend(ncol=1, loc=(0.68, 0.75))

cms_label(ax)

# raw string: "\e" is not a valid escape sequence in a normal string
plt.xlabel(r"particle $\eta$")
plt.ylabel("Number of particles")
# Save before show(): with the inline backend show() closes the figure,
# so a savefig issued afterwards writes an empty page.
plt.savefig(plot_outpath + "pf_vs_truth_eta.pdf", bbox_inches="tight")
plt.show()

In [None]:
# Stacked pt spectrum of the target particles, one layer per particle id.
fig = plt.figure(figsize=(10, 10))
ax = plt.axes()
b = np.logspace(-2, 4, 101)
hs = []
pids = sorted(np.unique(awkward.flatten(arrs_awk["ytarget"]["pid"])).tolist())
# pyplot.get_cmap replaces matplotlib.cm.get_cmap, which was removed in matplotlib 3.9
colors = plt.get_cmap("tab20c", len(pids))
labels = []
for pid in pids[::-1]:
    pt_pid = awkward.to_numpy(awkward.flatten(arrs_awk["ytarget"]["pt"][arrs_awk["ytarget"]["pid"] == pid]))
    hs.append(np.histogram(pt_pid, bins=b))
    labels.append(CLASS_NAMES_CMS[CLASS_LABELS_CMS.index(pid)])
mplhep.histplot(hs, stack=True, histtype="fill", label=labels, color=colors.colors)
plt.xscale("log")
# headroom above the tallest stacked bin
plt.ylim(0, 1.2 * np.sum([h[0] for h in hs], axis=0).max())
if sample == "TTbar_14TeV_TuneCUETP8M1_cfi":
    plt.ylim(0, 1e5)

# the public keyword replaces writing to the formatter's private _useMathText attribute
plt.ticklabel_format(style="sci", axis="y", scilimits=(0, 0), useMathText=True)

plt.legend(ncol=1, loc=(0.1, 0.4))
plt.xlabel("particle $p_T$ [GeV]")
plt.ylabel("Number of particles / bin")
cms_label(ax)
plt.xlim(10**-2, 10**4)
# Save before show(): with the inline backend show() closes the figure,
# so a savefig issued afterwards writes an empty page.
plt.savefig(plot_outpath + "truth_pt.pdf", bbox_inches="tight")
plt.show()

In [None]:
# Stacked eta distribution of the target particles, one layer per particle id.
fig = plt.figure(figsize=(10, 10))
ax = plt.axes()
b = np.linspace(-6, 6, 41)
hs = []
pids = sorted(np.unique(awkward.flatten(arrs_awk["ytarget"]["pid"])).tolist())
# pyplot.get_cmap replaces matplotlib.cm.get_cmap, which was removed in matplotlib 3.9
colors = plt.get_cmap("tab20c", len(pids))
labels = []
for pid in pids[::-1]:
    eta_pid = awkward.to_numpy(awkward.flatten(arrs_awk["ytarget"]["eta"][arrs_awk["ytarget"]["pid"] == pid]))
    hs.append(np.histogram(eta_pid, bins=b))
    labels.append(CLASS_NAMES_CMS[CLASS_LABELS_CMS.index(pid)])
mplhep.histplot(hs, stack=True, histtype="fill", label=labels, color=colors.colors)
# headroom above the tallest stacked bin
plt.ylim(0, 1.5 * np.sum([h[0] for h in hs], axis=0).max())
if sample == "TTbar_14TeV_TuneCUETP8M1_cfi":
    plt.ylim(0, 1e5)
# the public keyword replaces writing to the formatter's private _useMathText attribute
plt.ticklabel_format(style="sci", axis="y", scilimits=(0, 0), useMathText=True)

plt.legend(ncol=3, loc=(0.2, 0.65))
# raw string: "\e" is not a valid escape sequence in a normal string
plt.xlabel(r"particle $\eta$")
plt.ylabel("Number of particles / bin")
cms_label(ax)
plt.xlim(-6, 6)
# Save before show(): with the inline backend show() closes the figure,
# so a savefig issued afterwards writes an empty page.
plt.savefig(plot_outpath + "truth_eta.pdf", bbox_inches="tight")
plt.show()

In [None]:
# Stacked pt spectrum of the PF candidates, one layer per particle id.
fig = plt.figure(figsize=(10, 10))
ax = plt.axes()
b = np.logspace(-2, 4, 101)
hs = []
pids = sorted(np.unique(awkward.flatten(arrs_awk["ycand"]["pid"])).tolist())
# pyplot.get_cmap replaces matplotlib.cm.get_cmap, which was removed in matplotlib 3.9
colors = plt.get_cmap("tab20c", len(pids))
labels = []
for pid in pids[::-1]:
    pt_pid = awkward.to_numpy(awkward.flatten(arrs_awk["ycand"]["pt"][arrs_awk["ycand"]["pid"] == pid]))
    hs.append(np.histogram(pt_pid, bins=b))
    labels.append(CLASS_NAMES_CMS[CLASS_LABELS_CMS.index(pid)])
mplhep.histplot(hs, stack=True, histtype="fill", label=labels, color=colors.colors)
plt.xscale("log")
# headroom above the tallest stacked bin
plt.ylim(0, 1.2 * np.sum([h[0] for h in hs], axis=0).max())
if sample == "TTbar_14TeV_TuneCUETP8M1_cfi":
    plt.ylim(0, 1e5)
# the public keyword replaces writing to the formatter's private _useMathText attribute
plt.ticklabel_format(style="sci", axis="y", scilimits=(0, 0), useMathText=True)

plt.legend(ncol=1, loc=(0.7, 0.4))
plt.xlabel("particle $p_T$ [GeV]")
plt.ylabel("Number of particles / bin")
cms_label(ax)
plt.xlim(10**-2, 10**4)
# Save before show(): with the inline backend show() closes the figure,
# so a savefig issued afterwards writes an empty page.
plt.savefig(plot_outpath + "pf_pt.pdf", bbox_inches="tight")
plt.show()

In [None]:
# Stacked eta distribution of the PF candidates, one layer per particle id.
fig = plt.figure(figsize=(10, 10))
ax = plt.axes()
b = np.linspace(-6, 6, 41)
hs = []
pids = sorted(np.unique(awkward.flatten(arrs_awk["ycand"]["pid"])).tolist())
# pyplot.get_cmap replaces matplotlib.cm.get_cmap, which was removed in matplotlib 3.9
colors = plt.get_cmap("tab20c", len(pids))
labels = []
for pid in pids[::-1]:
    eta_pid = awkward.to_numpy(awkward.flatten(arrs_awk["ycand"]["eta"][arrs_awk["ycand"]["pid"] == pid]))
    hs.append(np.histogram(eta_pid, bins=b))
    labels.append(CLASS_NAMES_CMS[CLASS_LABELS_CMS.index(pid)])
mplhep.histplot(hs, stack=True, histtype="fill", label=labels, color=colors.colors)
# headroom above the tallest stacked bin
plt.ylim(0, 1.5 * np.sum([h[0] for h in hs], axis=0).max())
if sample == "TTbar_14TeV_TuneCUETP8M1_cfi":
    plt.ylim(0, 1e5)
# the public keyword replaces writing to the formatter's private _useMathText attribute
plt.ticklabel_format(style="sci", axis="y", scilimits=(0, 0), useMathText=True)

plt.legend(ncol=3, loc=(0.2, 0.65))
# raw string: "\e" is not a valid escape sequence in a normal string
plt.xlabel(r"particle $\eta$")
plt.ylabel("Number of particles / bin")
cms_label(ax)
plt.xlim(-6, 6)
# Save before show(): with the inline backend show() closes the figure,
# so a savefig issued afterwards writes an empty page.
plt.savefig(plot_outpath + "pf_eta.pdf", bbox_inches="tight")
plt.show()

In [None]:
# Display the per-feature pythia arrays. Unlike the neighbouring cells this
# access is not guarded, so it raises KeyError when the input has no "pythia" entry.
arrs_flat["pythia"]

In [None]:
# Stacked pt spectrum of the pythia particles, one layer per particle id.
if "pythia" in arrs_flat.keys():
    fig = plt.figure(figsize=(10, 10))
    ax = plt.axes()
    b = np.logspace(-2, 5, 101)
    hs = []
    pids = sorted(np.unique(awkward.flatten(arrs_flat["pythia"]["pid"])).tolist())
    # pyplot.get_cmap replaces matplotlib.cm.get_cmap, which was removed in matplotlib 3.9
    colors = plt.get_cmap("tab20c", len(pids))
    labels = []
    for pid in pids[::-1]:
        pt_pid = awkward.to_numpy(awkward.flatten(arrs_flat["pythia"]["pt"][arrs_flat["pythia"]["pid"] == pid]))
        hs.append(np.histogram(pt_pid, bins=b))
        # the legend shows the integer particle id itself
        labels.append(int(pid))
    mplhep.histplot(hs, stack=True, histtype="fill", label=labels, color=colors.colors)
    plt.xscale("log")
    plt.legend(ncol=1, loc=6)
    plt.xlabel("$p_T$ [GeV]")
    plt.ylabel("Number of particles")
    cms_label(ax)
    # Save before show(): with the inline backend show() closes the figure,
    # so a savefig issued afterwards writes an empty page.
    plt.savefig(plot_outpath + "pythia_pt.pdf", bbox_inches="tight")
    plt.show()

In [None]:
# Stacked eta distribution of the pythia particles, one layer per particle id.
if "pythia" in arrs_flat.keys():
    fig = plt.figure(figsize=(10, 10))
    ax = plt.axes()
    b = np.linspace(-6, 6, 101)
    hs = []
    pids = sorted(np.unique(awkward.flatten(arrs_flat["pythia"]["pid"])).tolist())
    # pyplot.get_cmap replaces matplotlib.cm.get_cmap, which was removed in matplotlib 3.9
    colors = plt.get_cmap("tab20c", len(pids))
    labels = []
    for pid in pids[::-1]:
        eta_pid = awkward.to_numpy(awkward.flatten(arrs_flat["pythia"]["eta"][arrs_flat["pythia"]["pid"] == pid]))
        hs.append(np.histogram(eta_pid, bins=b))
        # the legend shows the integer particle id itself
        labels.append(int(pid))
    mplhep.histplot(hs, stack=True, histtype="fill", label=labels, color=colors.colors)
    plt.legend(ncol=1, loc=6)
    # raw string: "\e" is not a valid escape sequence in a normal string
    plt.xlabel(r"$\eta$")
    plt.ylabel("Number of particles")
    cms_label(ax)
    # Save before show(): with the inline backend show() closes the figure,
    # so a savefig issued afterwards writes an empty page.
    plt.savefig(plot_outpath + "pythia_eta.pdf", bbox_inches="tight")
    plt.show()

In [None]:
# pt spectrum per particle id: PF candidates against target particles.
b = np.logspace(-2, 4, 100)
for pid in [1, 2, 11, 13, 22, 130, 211]:
    plt.figure()
    ax = plt.axes()
    plt.hist(
        awkward.to_numpy(awkward.flatten(arrs_awk["ycand"]["pt"][arrs_awk["ycand"]["pid"] == pid])),
        bins=b, histtype="step", lw=2, label="PF"
    )
    plt.hist(
        awkward.to_numpy(awkward.flatten(arrs_awk["ytarget"]["pt"][arrs_awk["ytarget"]["pid"] == pid])),
        bins=b,
        histtype="step",
        lw=2,
        label="MLPF truth",
    )
    plt.yscale("log")
    plt.xscale("log")
    plt.title(CLASS_NAMES_CMS[CLASS_LABELS_CMS.index(pid)])
    plt.legend(ncol=1, loc=(0.68, 0.8))
    plt.xlabel("$p_T$ [GeV]")
    cms_label(ax)
    # Save before show(): with the inline backend show() closes the figure,
    # so a savefig issued afterwards writes an empty page.
    plt.savefig(plot_outpath + "pid{}_pt.pdf".format(pid), bbox_inches="tight")
    plt.show()

In [None]:
# Energy-weighted eta distribution per particle id: PF candidates against target particles.
b = np.linspace(-6, 6, 100)
for pid in [1, 2, 11, 13, 22, 130, 211]:
    plt.figure()
    ax = plt.axes()
    plt.hist(
        awkward.flatten(arrs_awk["ycand"]["eta"][arrs_awk["ycand"]["pid"] == pid]),
        weights=awkward.flatten(arrs_awk["ycand"]["energy"][arrs_awk["ycand"]["pid"] == pid]),
        bins=b, histtype="step", lw=2, label="PF"
    )
    plt.hist(
        awkward.flatten(arrs_awk["ytarget"]["eta"][arrs_awk["ytarget"]["pid"] == pid]),
        weights=awkward.flatten(arrs_awk["ytarget"]["energy"][arrs_awk["ytarget"]["pid"] == pid]),
        bins=b,
        histtype="step",
        lw=2,
        label="MLPF truth",
    )
    plt.title(CLASS_NAMES_CMS[CLASS_LABELS_CMS.index(pid)])
    plt.legend(ncol=1, loc=(0.68, 0.8))
    # raw string: "\e" is not a valid escape sequence in a normal string
    plt.xlabel(r"particle $\eta$")
    # Save before show(): with the inline backend show() closes the figure,
    # so a savefig issued afterwards writes an empty page.
    plt.savefig(plot_outpath + "pid{}_eta.pdf".format(pid), bbox_inches="tight")
    plt.show()

In [None]:
# Per-event energy sum: target particles (y) against pythia particles (x).
if "pythia" in arrs_flat.keys():
    fig = plt.figure(figsize=(12, 10))
    ax = plt.axes()

    b = np.logspace(1, 6, 100)
    plt.hist2d(
        awkward.to_numpy(awkward.sum(arrs_flat["pythia"]["energy"], axis=1)),
        awkward.to_numpy(awkward.sum(arrs_flat["ytarget"]["energy"], axis=1)),
        bins=(b, b),
        cmap="hot_r",
        norm=matplotlib.colors.Normalize(vmin=0),
    )
    # diagonal: equal sums
    plt.plot([1e1, 1e6], [1e1, 1e6], color="black", ls="--")
    plt.colorbar(label="events / bin")
    cms_label(ax)
    plt.xscale("log")
    plt.yscale("log")
    # raw strings: "\s" is not a valid escape sequence in a normal string
    plt.xlabel(r"Pythia $\sum E$ [GeV]")
    plt.ylabel(r"MLPF truth $\sum E$ [GeV]")
    # Save before show(): with the inline backend show() closes the figure,
    # so a savefig issued afterwards writes an empty page.
    plt.savefig(plot_outpath + "pythia_vs_mlpf_sume.pdf", bbox_inches="tight")
    plt.show()

In [None]:
# Per-event energy sum: PF candidates (y) against pythia particles (x).
if "pythia" in arrs_flat.keys():
    fig = plt.figure(figsize=(12, 10))
    ax = plt.axes()

    b = np.logspace(1, 6, 100)
    plt.hist2d(
        awkward.to_numpy(awkward.sum(arrs_flat["pythia"]["energy"], axis=1)),
        awkward.to_numpy(awkward.sum(arrs_flat["ycand"]["energy"], axis=1)),
        bins=(b, b),
        cmap="hot_r",
        norm=matplotlib.colors.Normalize(vmin=0),
    )
    # diagonal: equal sums
    plt.plot([1e1, 1e6], [1e1, 1e6], color="black", ls="--")
    plt.colorbar(label="events / bin")
    cms_label(ax)
    plt.xscale("log")
    plt.yscale("log")
    # raw strings: "\s" is not a valid escape sequence in a normal string
    plt.xlabel(r"Pythia $\sum E$ [GeV]")
    plt.ylabel(r"PF $\sum E$ [GeV]")
    # Save before show(): with the inline backend show() closes the figure,
    # so a savefig issued afterwards writes an empty page.
    plt.savefig(plot_outpath + "pythia_vs_pf_sume.pdf", bbox_inches="tight")
    plt.show()

In [None]:
# Restrict to PF elements of type code 1 and collect, flattened over events,
# the target pid, the PF candidate pid and the element pt / eta at the same index.
is_typ1 = arrs_flat["Xelem"]["typ"] == 1
gen_pid = awkward.flatten(arrs_flat["ytarget"]["pid"][is_typ1])
cand_pid = awkward.flatten(arrs_flat["ycand"]["pid"][is_typ1])
track_pt = awkward.flatten(arrs_flat["Xelem"]["pt"][is_typ1])
track_eta = awkward.flatten(arrs_flat["Xelem"]["eta"][is_typ1])

In [None]:
# pt spectrum of the type-1 elements selected above.
plt.figure()
track_pt_bins = np.logspace(-2, 3, 100)
plt.hist(track_pt, bins=track_pt_bins)
plt.xscale("log")
plt.show()

In [None]:
def midpoints(x):
    """Return the centre of every interval between consecutive entries of ``x``.

    The result has one entry fewer than ``x``.
    """
    half_widths = 0.5 * np.diff(x)
    return half_widths + x[:-1]

In [None]:
# Fraction of selected elements that carry a target particle (non-zero gen_pid)
# or a PF candidate (non-zero cand_pid), in bins of element pt.
bins = np.logspace(-1, 3, 20)
fracs_gen = []
fracs_cand = []

for b0, b1 in zip(bins[:-1], bins[1:]):
    msk = (track_pt >= b0) & (track_pt < b1)
    n_in_bin = np.sum(msk)
    fracs_gen.append(np.sum(gen_pid[msk]!=0) / n_in_bin)
    fracs_cand.append(np.sum(cand_pid[msk]!=0) / n_in_bin)

plt.figure()
bin_centers = midpoints(bins)
plt.plot(bin_centers, fracs_gen, marker="o", label="target")
plt.plot(bin_centers, fracs_cand, marker="o", label="PF")
plt.xscale("log")
plt.legend(loc="best")
plt.show()

In [None]:
# Same fractions in bins of element eta, restricted to elements with pt above 1.
bins = np.linspace(-4, 4, 20)
fracs_gen = []
fracs_cand = []

for b0, b1 in zip(bins[:-1], bins[1:]):
    msk = (track_eta >= b0) & (track_eta < b1) & (track_pt>1)
    n_in_bin = np.sum(msk)
    fracs_gen.append(np.sum(gen_pid[msk]!=0) / n_in_bin)
    fracs_cand.append(np.sum(cand_pid[msk]!=0) / n_in_bin)

plt.figure()
bin_centers = midpoints(bins)
plt.plot(bin_centers, fracs_gen, marker="o", label="target")
plt.plot(bin_centers, fracs_cand, marker="o", label="PF")
plt.legend(loc="best")
plt.show()