In [None]:
%matplotlib inline

In [None]:
import pandas as pd
import json
import glob
import matplotlib.pyplot as plt
import numpy as np

import sklearn
import sklearn.metrics
import matplotlib
import scipy
import mplhep
import os
import awkward

import vector
import fastjet
import awkward as ak

import pandas
import boost_histogram as bh
import itertools
import mplhep

mplhep.set_style(mplhep.styles.CMS)

In [None]:
import sys

sys.path += ["../mlpf/plotting/"]

import plot_utils
from plot_utils import pid_to_text

from plot_utils import cms_label, sample_label
from plot_utils import ELEM_LABELS_CMS, ELEM_NAMES_CMS
from plot_utils import CLASS_LABELS_CMS, CLASS_NAMES_CMS

In [None]:
def to_bh(data, bins, cumulative=False):
    """Fill a 1D boost-histogram with ``data`` using the variable-width edges ``bins``.

    With ``cumulative=True`` every bin content is replaced by the total
    in-range count minus the running sum up to and including that bin.
    """
    hist = bh.Histogram(bh.axis.Variable(bins))
    hist.fill(data)
    if not cumulative:
        return hist

    total = np.sum(hist.values())
    hist[:] = total - np.cumsum(hist)
    return hist


def loss_plot(train, test, margin=0.05, smoothing=False):
    """Plot training and validation loss curves versus epoch on a new figure.

    Args:
        train: sequence of per-epoch training losses.
        test: sequence of per-epoch validation losses.
        margin: the y range is set to ``test[-1] * (1 -/+ margin)``.
        smoothing: if true, the raw curves are drawn faintly and a 5-epoch
            moving average of each is overlaid in the same colour.

    Returns:
        The matplotlib axes that were drawn on.
    """
    plt.figure()
    ax = plt.axes()

    alpha = 0.2 if smoothing else 1.0
    l0 = None if smoothing else "train"
    l1 = None if smoothing else "test"
    p0 = plt.plot(train, alpha=alpha, label=l0)
    p1 = plt.plot(test, alpha=alpha, label=l1)

    if smoothing:
        window = 5
        kernel = np.ones(window) / window
        # mode="valid" drops (window - 1) points, so the i-th average covers epochs
        # i .. i + window - 1. Shift the x coordinates to the centre of that span,
        # otherwise the smooth curve is displaced to the left of the raw one.
        offset = (window - 1) / 2
        for raw_line, series, label in ((p0, train, "train"), (p1, test, "test")):
            smooth = np.convolve(series, kernel, mode="valid")
            plt.plot(np.arange(len(smooth)) + offset, smooth, color=raw_line[0].get_color(), lw=2, label=label)

    # Zoom on the final validation loss
    plt.ylim(test[-1] * (1.0 - margin), test[-1] * (1.0 + margin))
    plt.legend(loc=3, frameon=False)
    plt.xlabel("epoch")
    cms_label(ax)
    return ax


def med_iqr(arr):
    """Return ``(median, interquartile range)`` of ``arr``, i.e. ``(p50, p75 - p25)``."""
    lower, median, upper = (np.percentile(arr, q) for q in (25, 50, 75))
    return median, upper - lower


def flatten(arr):
    """Merge all leading axes of ``arr`` into one, keeping the last axis as is."""
    last_dim = arr.shape[-1]
    return arr.reshape(-1, last_dim)


def get_distribution(prefix, bins, var):
    """Histogram ``var`` separately for each particle id, in a fixed pid order.

    Reads the module-level ``yvals_f`` (in this notebook it is only assigned by
    the ``backend == "pyg"`` loading cell). ``prefix`` selects which set of
    values is used, e.g. the keys ``prefix + "_cls_id"`` and ``prefix + "_" + var``.

    Returns:
        A list of boost-histograms, one per pid in ``[13, 11, 22, 1, 2, 130, 211]``.
    """

    def _hist_for(pid):
        # Class index corresponding to this particle id
        cls_index = CLASS_LABELS_CMS.index(pid)
        selected = yvals_f[prefix + "_cls_id"] == cls_index
        hist = bh.Histogram(bh.axis.Variable(bins))
        hist.fill(yvals_f[prefix + "_" + var][selected].flatten())
        return hist

    return [_hist_for(pid) for pid in [13, 11, 22, 1, 2, 130, 211]]


def binom_error(n_sig, n_tot):
    """
    Standard deviation of a ratio n_sig / n_tot (an efficiency nSig/nTrueSig or a
    purity nSig / (nSig + nBckgrd)), following http://arxiv.org/abs/physics/0701199 .

    Entries with ``n_tot <= 0`` get an uncertainty of 0.
    """
    second_moment = (n_sig + 1) * (n_sig + 2) / ((n_tot + 2) * (n_tot + 3))
    squared_mean = (n_sig + 1) ** 2 / ((n_tot + 2) ** 2)
    return np.sqrt(np.where(n_tot > 0, second_moment - squared_mean, 0))


def reso_plot(pid, var, bins, ptcl_name):
    """Draw the reco/gen ratio of ``var`` for PF and MLPF on a new log-y figure.

    Reads the module-level ``yvals`` and ``physics_process``. Only entries whose
    gen class id equals ``pid`` and for which both the PF candidate and the MLPF
    prediction have a non-zero class id are used. The legend shows the median
    (M) and interquartile range (IQR) of each ratio.

    Args:
        pid: value compared against ``yvals["gen_cls_id"]``.
        var: name of the quantity, used to build the ``gen_``/``cand_``/``pred_`` keys.
            An x label is only set for ``"pt"`` and ``"eta"``.
        bins: histogram bin edges; also define the x range.
        ptcl_name: text handed to ``sample_label`` after the physics process.
    """

    fig = plt.figure()
    ax = plt.axes()

    selected = (yvals["gen_cls_id"] == pid) & (yvals["cand_cls_id"] != 0) & (yvals["pred_cls_id"] != 0)
    vals_gen, vals_cand, vals_mlpf = (
        awkward.flatten(yvals[source + var][selected]) for source in ("gen_", "cand_", "pred_")
    )

    ratios = (
        ("PF, M={:.2f}, IQR={:.2f}", vals_cand / vals_gen),
        ("MLPF, M={:.2f}, IQR={:.2f}", vals_mlpf / vals_gen),
    )
    for label_fmt, ratio in ratios:
        plt.hist(ratio, bins=bins, histtype="step", lw=2, label=label_fmt.format(*med_iqr(ratio)))
    plt.yscale("log")

    xlabels = {
        "pt": r"$p_\mathrm{T,reco} / p_\mathrm{T,gen}$",
        "eta": r"$\eta_\mathrm{reco} / \eta_\mathrm{gen}$",
    }
    if var in xlabels:
        plt.xlabel(xlabels[var])
    plt.ylabel("Number of particles / bin")

    cms_label(ax)
    sample_label(ax, physics_process, ptcl_name)
    plt.xlim(min(bins), max(bins))
    plt.legend(loc=(0.4, 0.7))


def plot_eff_and_fake_rate(icls=1, ivar=4, ielem=1, bins=np.linspace(-3, 6, 100), xlabel="PFElement log[E/GeV]", log=True):
    """Plot distributions, efficiency and fake rate of one class versus one input feature.

    Reads the module-level ``X``, ``yvals``, ``physics_process`` and ``outpath``.
    Three figures are created and saved under ``outpath``:

    * ``distr_icls{icls}_ivar{ivar}.pdf``: the binned feature for all selected
      input elements and for those assigned to ``icls`` by PF, MLPF and the truth.
    * ``eff_icls{icls}_ivar{ivar}.pdf``: fraction of elements with gen class
      ``icls`` that PF / MLPF also assign to ``icls``.
    * ``fake_icls{icls}_ivar{ivar}.pdf``: fraction of elements assigned to
      ``icls`` by PF / MLPF whose gen class is different.

    Args:
        icls: class index compared against the ``*_cls_id`` arrays.
        ivar: index into the last axis of ``X`` giving the quantity to bin.
        ielem: value of ``X[:, :, 0]`` selecting which input elements are used
            (presumably the PFElement type code -- TODO confirm).
        bins: bin edges shared by all histograms.
        xlabel: x-axis label of all three figures.
        log: if true, the x axis is logarithmic.
    """

    values = X[:, :, ivar]

    # Selections shared by several histograms; computed once instead of per histogram.
    is_elem = X[:, :, 0] == ielem
    gen_is_cls = yvals["gen_cls_id"] == icls
    gen_not_cls = yvals["gen_cls_id"] != icls
    pred_is_cls = yvals["pred_cls_id"] == icls
    cand_is_cls = yvals["cand_cls_id"] == icls

    def _count(mask):
        # Histogram of the binned quantity for the elements passing ``mask``.
        hist = bh.Histogram(bh.axis.Variable(bins))
        hist.fill(awkward.flatten(values[mask]))
        return hist

    def _ratio(num, den):
        # Per-bin num/den with binomial variances. Bins with an empty denominator
        # stay NaN (0/0) as before, but no longer emit divide RuntimeWarnings.
        ratio = bh.Histogram(bh.axis.Variable(bins), storage=bh.storage.Weight())
        with np.errstate(divide="ignore", invalid="ignore"):
            ratio.values()[:] = num.values() / den.values()
        ratio.variances()[:] = binom_error(num.values(), den.values()) ** 2
        return ratio

    hist_X = _count(is_elem)
    hist_gen = _count(gen_is_cls & is_elem)
    hist_pred = _count(pred_is_cls & is_elem)
    hist_cand = _count(cand_is_cls & is_elem)

    eff_mlpf = _ratio(_count(gen_is_cls & pred_is_cls & is_elem), hist_gen)
    eff_pf = _ratio(_count(gen_is_cls & cand_is_cls & is_elem), hist_gen)
    fake_pf = _ratio(_count(gen_not_cls & cand_is_cls & is_elem), hist_cand)
    fake_mlpf = _ratio(_count(gen_not_cls & pred_is_cls & is_elem), hist_pred)

    # Figure 1: raw distributions
    plt.figure()
    ax = plt.axes()
    mplhep.histplot(hist_X, label="all PFElements", color="black")
    mplhep.histplot(hist_cand, label="with PF")
    mplhep.histplot(hist_pred, label="with MLPF reco")
    mplhep.histplot(hist_gen, label="with MLPF truth")
    plt.ylabel("Number of PFElements / bin")
    plt.xlabel(xlabel)
    cms_label(ax)
    plt.yscale("log")
    sample_label(ax, physics_process, ", " + CLASS_NAMES_CMS[icls])
    if log:
        plt.xscale("log")
    plt.legend(loc=(0.6, 0.65))
    plt.ylim(10, 20 * np.max(hist_X.values()))
    plt.xlim(min(bins), max(bins))
    plt.savefig("{}/distr_icls{}_ivar{}.pdf".format(outpath, icls, ivar), bbox_inches="tight")

    def _rate_figure(shared_ax, hist_pf, hist_mlpf, ylabel, stem):
        # New figure comparing a PF and an MLPF rate; the x axis is shared with
        # ``shared_ax`` exactly as in the original per-figure code.
        plt.figure()
        rate_ax = plt.axes(sharex=shared_ax)
        mplhep.histplot(hist_pf, label="PF")
        mplhep.histplot(hist_mlpf, label="MLPF")
        plt.ylim(0, 1.5)
        plt.ylabel(ylabel)
        plt.xlabel(xlabel)
        cms_label(rate_ax)
        sample_label(rate_ax, physics_process, ", " + CLASS_NAMES_CMS[icls])
        if log:
            plt.xscale("log")
        plt.legend(loc=(0.75, 0.7))
        plt.xlim(min(bins), max(bins))
        plt.savefig("{}/{}_icls{}_ivar{}.pdf".format(outpath, stem, icls, ivar), bbox_inches="tight")
        return rate_ax

    # Figures 2 and 3: efficiency and fake rate
    ax = _rate_figure(ax, eff_pf, eff_mlpf, "Efficiency", "eff")
    _rate_figure(ax, fake_pf, fake_mlpf, "Fake rate", "fake")

    # mplhep.histplot(fake, bins=hist_gen[1], label="fake rate", color="red")


#     plt.legend(frameon=False)
#     plt.ylim(0,1.4)
#     plt.xlabel(xlabel)
#     plt.ylabel("Fraction of particles / bin")

In [None]:
# These can be overridden from the command line using `papermill cms-mlpf.ipynb -p path new/path/...`
backend = "tf"
sample = "cms_pf_qcd_high_pt"

# Physics-process string shown on the plots for each supported sample name.
PHYSICS_PROCESSES = {
    "cms_pf_ttbar": "TTbar_14TeV_TuneCUETP8M1_cfi",
    "cms_pf_ztt": "ZTT_All_hadronic_14TeV_TuneCUETP8M1_cfi",
    "cms_pf_qcd": "QCDForPF_14TeV_TuneCUETP8M1_cfi",
    "cms_pf_qcd_high_pt": "QCD_Pt_3000_7000_14TeV_TuneCUETP8M1_cfi",
}

# Fail here with a clear message instead of leaving `physics_process` undefined
# and hitting a NameError in the first plotting cell.
if sample not in PHYSICS_PROCESSES:
    raise ValueError("unknown sample {!r}, expected one of {}".format(sample, sorted(PHYSICS_PROCESSES)))
physics_process = PHYSICS_PROCESSES[sample]

path = "../experiments/cms-transformer_20221114_182159_902630.gpu0.local/evaluation/epoch_14/{}/".format(sample)
PAPERMILL_OUTPUT_PATH = path

In [None]:
# Directory the figures are written to: if the papermill output path points at an
# existing file, use the directory containing it, otherwise use the path itself.
_output_is_file = os.path.isfile(PAPERMILL_OUTPUT_PATH)
outpath = os.path.dirname(PAPERMILL_OUTPUT_PATH) if _output_is_file else PAPERMILL_OUTPUT_PATH
print("params", path, outpath)

# Load the predictions

In [None]:
if backend == "tf":

    # Load all parquet files from the eval_model step.
    # Sorted so that the event order does not depend on the filesystem listing order.
    pred_files = sorted(glob.glob(path + "/pred_batch*.parquet"))
    if len(pred_files) == 0:
        raise FileNotFoundError("no pred_batch*.parquet files found in {}".format(path))
    data = awkward.concatenate([awkward.from_parquet(fi) for fi in pred_files])

    # Get the inputs, flatten across the file dimension
    X = awkward.flatten(data["inputs"], axis=1)

    # Per-particle outputs, keyed as "<gen|cand|pred>_<field>"
    yvals = {}
    for typ in ["gen", "cand", "pred"]:
        for k in data["particles"][typ].fields:
            yvals["{}_{}".format(typ, k)] = awkward.flatten(data["particles"][typ][k], axis=1)

    for typ in ["gen", "cand", "pred"]:

        # Get the classification output as a class ID
        yvals[typ + "_cls_id"] = np.argmax(yvals[typ + "_cls"], axis=-1)

        # Compute phi, px, py
        yvals[typ + "_phi"] = np.arctan2(yvals[typ + "_sin_phi"], yvals[typ + "_cos_phi"])
        yvals[typ + "_px"] = yvals[typ + "_pt"] * yvals[typ + "_cos_phi"]
        yvals[typ + "_py"] = yvals[typ + "_pt"] * yvals[typ + "_sin_phi"]

        # Get the jet vectors
        jetvec = vector.arr(data["jets"][typ])
        for k in ["pt", "eta", "phi", "energy"]:
            yvals["jets_{}_{}".format(typ, k)] = awkward.flatten(awkward.flatten(getattr(jetvec, k)))

    # Matched jets: column 0 is the gen-jet pT, column 1 the matched reco-jet pT
    for reco in ["cand", "pred"]:
        matched = data["matched_jets"]["gen_to_" + reco]
        yvals["jets_pt_gen_to_" + reco] = np.stack(
            [
                awkward.flatten(awkward.flatten(vector.arr(matched["gen_jet"]).pt)),
                awkward.flatten(awkward.flatten(vector.arr(matched[reco + "_jet"]).pt)),
            ],
            axis=-1,
        )

In [None]:
if backend == "pyg":
    import torch

    # NOTE(review): this replaces the notebook-level `path` parameter, which a later
    # cell also uses to locate the training history -- confirm that is intended.
    path = "./preds/"

    # NOTE(review): torch.load unpickles its input; only use files from a trusted source.
    X, X_f, msk_X_f, yvals, yvals_f = (
        torch.load(f"{path}/{fname}")
        for fname in (
            "post_processed_Xs.pt",
            "post_processed_X_f.pt",
            "post_processed_msk_X_f.pt",
            "post_processed_yvals.pt",
            "post_processed_yvals_f.pt",
        )
    )

# Make plots

### Full distribution plots for each class

In [None]:
# One 2x2 figure (pt, energy, eta, phi) per class index. Index 0 is used to mean
# "every particle with a non-zero class id" and is labelled "all".
for icls in range(0, 8):
    base_w, base_h = mplhep.styles.CMS["figure.figsize"][0], mplhep.styles.CMS["figure.figsize"][1]
    fig, axs = plt.subplots(2, 2, figsize=(2 * base_w, 2 * base_h))
    cls_name = CLASS_NAMES_CMS[icls] if icls > 0 else "all"

    for ax, ivar in zip(axs.flatten(), ["pt", "energy", "eta", "phi"]):

        plt.sca(ax)

        # Select the same class from the truth, the PF candidates and the MLPF prediction
        selected = []
        for source in ["gen", "cand", "pred"]:
            cls_ids = yvals[source + "_cls_id"]
            keep = (cls_ids != 0) if icls == 0 else (cls_ids == icls)
            selected.append(awkward.flatten(yvals[source + "_" + ivar][keep]))
        vals_true, vals_pf, vals_pred = selected

        # pt and energy use fixed log-spaced bins, the rest span the range of the truth values
        log = ivar in ("pt", "energy")
        if log:
            b = np.logspace(-3, 4, 61)
        else:
            b = np.linspace(np.min(vals_true), np.max(vals_true), 41)

        for vals, style in [
            (vals_true, {"label": "gen", "color": "black"}),
            (vals_pf, {"label": "PF"}),
            (vals_pred, {"label": "MLPF"}),
        ]:
            plt.hist(vals, bins=b, histtype="step", lw=2, **style)
        plt.legend(loc=(0.75, 0.75))

        # Upper limit taken before switching to log scale, then padded by a decade
        ylim = ax.get_ylim()

        plt.xlabel("{} {}".format(cls_name, ivar))

        plt.yscale("log")
        plt.ylim(10, 10 * ylim[1])

        if log:
            plt.xscale("log")
        cms_label(ax)

    plt.tight_layout()
    plt.savefig("{}/distribution_icls{}.pdf".format(outpath, icls), bbox_inches="tight")

### Plot of the neutral cluster classification output

In [None]:
# Flat table of all input elements whose first feature equals 5, with their
# input energy/eta and the PF, truth and MLPF outputs.
msk = X[:, :, 0] == 5

_columns = {"X_energy": X[msk, 4], "X_eta": X[msk, 2]}
for _source in ["cand", "gen", "pred"]:
    for _field in ["energy", "cls_id"]:
        _columns["{}_{}".format(_source, _field)] = yvals["{}_{}".format(_source, _field)][msk]
# Raw MLPF classification outputs for the first three classes
for _cls_idx in range(3):
    _columns["pred_cls{}".format(_cls_idx)] = yvals["pred_cls"][msk, _cls_idx]

df = pandas.DataFrame({name: awkward.to_numpy(awkward.flatten(arr)) for name, arr in _columns.items()})

In [None]:
# Class-2 classifier output in three input-energy ranges, one panel per range,
# split by whether the truth class is 0 ("no true particle") or 2 ("true n.had.").
b = np.linspace(0, 1, 100)
plt.figure(figsize=(15, 15))

energy_ranges = [
    (df["X_energy"] < 1, "PFElement E < 1 GeV"),
    ((df["X_energy"] > 1) & (df["X_energy"] < 10), "1 < PFElement E < 10 GeV"),
    ((df["X_energy"] > 10) & (df["X_energy"] < 100), "10 < PFElement E < 100 GeV"),
]

for irow, (msk, range_text) in enumerate(energy_ranges, start=1):
    ax = plt.subplot(3, 1, irow)
    plt.xlim(0, 1)
    for true_cls, color, label in [(0, "red", "no true particle"), (2, "blue", "true n.had.")]:
        plt.hist(
            df["pred_cls2"][(df["gen_cls_id"] == true_cls) & msk],
            bins=b,
            histtype="step",
            lw=2,
            color=color,
            label=label,
        )
    plt.yscale("log")
    # Only the first panel carries the legend
    if irow == 1:
        plt.legend(loc=4)
    ax.text(0.01, 0.7, range_text, transform=ax.transAxes)
    plt.ylabel("PFElements / bin")
    plt.xlabel("Classification output for neutral hadron")
    plt.ylim(1, 1e7)
    cms_label(ax, y=0.9)
    sample_label(ax, physics_process, y=0.8)

plt.tight_layout()

plt.savefig("{}/clsout_ielem5_icls2.pdf".format(outpath), bbox_inches="tight")

In [None]:
# Keep only the entries with a non-zero class id, separately for the truth (gen),
# the PF candidates (cand) and the MLPF prediction (pred).
_kinematics = ["pt", "eta", "phi", "energy", "cls_id"]

_gen_keep = yvals["gen_cls_id"] != 0
gen_pt, gen_eta, gen_phi, gen_e, gen_cls_id = (yvals["gen_" + name][_gen_keep] for name in _kinematics)

_cand_keep = yvals["cand_cls_id"] != 0
cand_pt, cand_eta, cand_phi, cand_e, cand_cls_id = (yvals["cand_" + name][_cand_keep] for name in _kinematics)

_pred_keep = yvals["pred_cls_id"] != 0
pred_pt, pred_eta, pred_phi, pred_e, pred_cls_id = (yvals["pred_" + name][_pred_keep] for name in _kinematics)

In [None]:
# Particle pT spectra for PF, MLPF and the truth (top) and their ratio to the truth (bottom)
b = np.logspace(-1, 4, 101)

f, (a0, a1) = plt.subplots(2, 1, gridspec_kw={"height_ratios": [3, 1]}, sharex=True)

h0, h1, h2 = (
    to_bh(ak.flatten(pt[cls_id != 0]), b)
    for pt, cls_id in [(cand_pt, cand_cls_id), (pred_pt, pred_cls_id), (gen_pt, gen_cls_id)]
)

plt.sca(a0)
for hist, label in [(h0, "PF"), (h1, "MLPF"), (h2, "MLPF truth")]:
    mplhep.histplot(hist, histtype="step", lw=2, label=label)
plt.xscale("log")
plt.yscale("log")
plt.legend(frameon=False)
plt.ylabel("number of particles / bin")

plt.sca(a1)
for hist in [h0, h1, h2]:
    mplhep.histplot(hist / h2, histtype="step", lw=2)
plt.ylabel("reco / truth")
plt.xlabel("particle $p_T$ [GeV]")

In [None]:
# Particle eta distributions for PF, MLPF and the truth (top) and their ratio to the truth (bottom)
b = np.linspace(-6, 6, 41)

f, (a0, a1) = plt.subplots(2, 1, gridspec_kw={"height_ratios": [3, 1]}, sharex=True)

plt.sca(a0)

h0 = to_bh(ak.flatten(cand_eta[cand_cls_id != 0]), b)
h1 = to_bh(ak.flatten(pred_eta[pred_cls_id != 0]), b)
h2 = to_bh(ak.flatten(gen_eta[gen_cls_id != 0]), b)

mplhep.histplot(h0, histtype="step", lw=2, label="PF")
mplhep.histplot(h1, histtype="step", lw=2, label="MLPF")
mplhep.histplot(h2, histtype="step", lw=2, label="MLPF truth")
plt.legend(frameon=False)
# Same y label as the particle pT comparison
plt.ylabel("number of particles / bin")

plt.sca(a1)
mplhep.histplot(h0 / h2, histtype="step", lw=2)
mplhep.histplot(h1 / h2, histtype="step", lw=2)
mplhep.histplot(h2 / h2, histtype="step", lw=2)
plt.ylabel("reco / truth")
# Raw string: "\e" is not a valid escape sequence in a normal string literal
plt.xlabel(r"particle $\eta$")
plt.ylim(0, 2)

In [None]:
# Stacked MLPF particle pT spectrum, one filled histogram per particle id
fig = plt.figure(figsize=(10, 10))
ax = plt.axes()
b = np.logspace(-2, 4, 101)
hs = []
pids = [1, 2, 11, 13, 22, 130, 211]

# pyplot.get_cmap is used because plt.cm.get_cmap is deprecated and absent from recent matplotlib
colors = plt.get_cmap("tab20c", len(pids))
labels = []
for pid in pids[::-1]:
    pid_idx = CLASS_LABELS_CMS.index(pid)
    pt_pid = ak.flatten(pred_pt[pred_cls_id == pid_idx])
    hs.append(np.histogram(pt_pid, bins=b))
    labels.append(CLASS_NAMES_CMS[pid_idx])
mplhep.histplot(hs, stack=True, histtype="fill", label=labels, color=colors.colors)
plt.xscale("log")

plt.ylim(0, 2e6)
# useMathText is passed through the public API instead of setting the private formatter attribute
plt.ticklabel_format(style="sci", axis="y", scilimits=(0, 0), useMathText=True)

plt.legend(ncol=1, loc=(0.7, 0.4))
plt.xlabel("$p_T$ [GeV]")
plt.ylabel("Number of particles / bin")
cms_label(ax)
sample_label(ax, physics_process, ", MLPF")
plt.xlim(10**-2, 10**4)
plt.savefig(outpath + "/mlpf_pt.pdf", bbox_inches="tight")

In [None]:
# Stacked MLPF particle eta distribution, one filled histogram per particle id.
# Relies on `pids` defined in the stacked pT cell above.
fig = plt.figure(figsize=(10, 10))
ax = plt.axes()
b = np.linspace(-6, 6, 41)
hs = []

# pyplot.get_cmap is used because plt.cm.get_cmap is deprecated and absent from recent matplotlib
colors = plt.get_cmap("tab20c", len(pids))
labels = []
for pid in pids[::-1]:
    pid_idx = CLASS_LABELS_CMS.index(pid)
    eta_pid = ak.flatten(pred_eta[pred_cls_id == pid_idx])
    hs.append(np.histogram(eta_pid, bins=b))
    labels.append(CLASS_NAMES_CMS[pid_idx])
mplhep.histplot(hs, stack=True, histtype="fill", label=labels, color=colors.colors)
plt.ylim(0, 1e6)
# useMathText is passed through the public API instead of setting the private formatter attribute
plt.ticklabel_format(style="sci", axis="y", scilimits=(0, 0), useMathText=True)

plt.legend(ncol=3, loc=(0.2, 0.65))
# Raw string: "\e" is not a valid escape sequence in a normal string literal
plt.xlabel(r"$\eta$")
plt.ylabel("Number of particles / bin")
cms_label(ax)
sample_label(ax, physics_process, ", MLPF")
plt.xlim(-6, 6)
plt.savefig(outpath + "/mlpf_eta.pdf", bbox_inches="tight")

In [None]:
# Jet pT spectra of gen jets, PF jets and MLPF jets
b = np.logspace(0.5, 4, 100)

plt.figure()
ax = plt.axes()
cms_label(ax)
sample_label(ax, physics_process)

for key, label in [("jets_gen_pt", "genjet"), ("jets_cand_pt", "PF jet"), ("jets_pred_pt", "MLPF jet")]:
    plt.hist(yvals[key], bins=b, histtype="step", lw=2, label=label)
plt.xscale("log")
plt.yscale("log")
plt.ylim(1, 1e6)
plt.legend(loc=(0.6, 0.7))
plt.xlabel("jet $p_T$ [GeV]")
plt.ylabel("Number of jets")
plt.savefig("{}/jets.pdf".format(outpath), bbox_inches="tight")

In [None]:
# Jet eta distributions of gen jets, PF jets and MLPF jets
b = np.linspace(-7, 7, 41)

plt.figure()
ax = plt.axes()
cms_label(ax)
sample_label(ax, physics_process)

plt.hist(yvals["jets_gen_eta"], bins=b, histtype="step", lw=2, label="genjet")
plt.hist(yvals["jets_cand_eta"], bins=b, histtype="step", lw=2, label="PF jet")
plt.hist(yvals["jets_pred_eta"], bins=b, histtype="step", lw=2, label="MLPF jet")
plt.legend(loc=(0.6, 0.7))
# Axis labels were missing; they follow the wording of the jet pT figure
plt.xlabel(r"jet $\eta$")
plt.ylabel("Number of jets")
plt.savefig("{}/jets_eta.pdf".format(outpath), bbox_inches="tight")

In [None]:
# Matched-jet pT response (column 1 / column 0 of the matched-jet arrays) for PF
# and MLPF, with median (M) and interquartile range (IQR) in the legend
b = np.linspace(-2, 15, 101)

fig = plt.figure()
ax = plt.axes()
for key, label_fmt in [
    ("jets_pt_gen_to_cand", r"PF (M={:.2f}, IQR={:.2f})"),
    ("jets_pt_gen_to_pred", r"MLPF (M={:.2f}, IQR={:.2f})"),
]:
    matched_pt = yvals[key]
    vals = matched_pt[:, 1] / matched_pt[:, 0]
    p = med_iqr(vals)
    plt.hist(vals, bins=b, histtype="step", lw=2, label=label_fmt.format(p[0], p[1]))

plt.yscale("log")
plt.ylim(1, 1e7)
cms_label(ax)
sample_label(ax, physics_process)
plt.legend(loc=(0.4, 0.7))
plt.xlabel(r"jet $\frac{p_{\mathrm{T,reco}}}{p_{T,\mathrm{gen}}}$")
plt.savefig("{}/jetres.pdf".format(outpath), bbox_inches="tight")

In [None]:
# Number of input elements per event whose first feature is non-zero
n_elems_per_event = np.sum(X[:, :, 0] != 0, axis=1)

plt.figure()
ax = plt.axes()
plt.hist(n_elems_per_event, bins=100)
# Dashed reference line at 6400 elements
plt.axvline(6400, ls="--", color="black")
plt.xlabel("number of input PFElements")
plt.ylabel("number of events / bin")
cms_label(ax)
sample_label(ax, physics_process)

In [None]:
# Missing transverse momentum per event: the magnitude of the vector sum of the
# transverse momenta of all particles with a non-zero class id,
#   MET = sqrt((sum px)^2 + (sum py)^2).
# Summing px**2 + py**2 per particle gives sqrt(sum pT^2) instead, which is not MET.
# NOTE(review): the binning of the MET histogram further down may need revisiting
# for the values produced here.
px = yvals["gen_px"][yvals["gen_cls_id"] != 0]
py = yvals["gen_py"][yvals["gen_cls_id"] != 0]
gen_met = np.sqrt(awkward.sum(px, axis=1) ** 2 + awkward.sum(py, axis=1) ** 2)

px = yvals["cand_px"][yvals["cand_cls_id"] != 0]
py = yvals["cand_py"][yvals["cand_cls_id"] != 0]
cand_met = np.sqrt(awkward.sum(px, axis=1) ** 2 + awkward.sum(py, axis=1) ** 2)

px = yvals["pred_px"][yvals["pred_cls_id"] != 0]
py = yvals["pred_py"][yvals["pred_cls_id"] != 0]
pred_met = np.sqrt(awkward.sum(px, axis=1) ** 2 + awkward.sum(py, axis=1) ** 2)

In [None]:
# MET distributions for PF, MLPF and the truth
fig = plt.figure()
ax = plt.axes()

# Bins span 1 GeV - 10 TeV. Starting at 1 TeV silently dropped every event with
# a smaller MET, because values outside the bin range are not histogrammed.
b = np.logspace(0, 4, 100)
plt.hist(cand_met, bins=b, histtype="step", lw=2, label="PF")
plt.hist(pred_met, bins=b, histtype="step", lw=2, label="MLPF")
plt.hist(gen_met, bins=b, histtype="step", lw=2, label="gen")
plt.yscale("log")
plt.xscale("log")
plt.legend(loc=(0.75, 0.7))
cms_label(ax)
sample_label(ax, physics_process)
plt.ylim(1, 1e3)
plt.xlabel("MET [GeV]")
plt.ylabel("Number of events")
plt.savefig("{}/met.pdf".format(outpath), bbox_inches="tight")

In [None]:
# MET response (reco / gen) for PF and MLPF, with median (M) and interquartile
# range (IQR) in the legend
fig = plt.figure()
ax = plt.axes()
b = np.linspace(0, 2, 101)
vals_a = cand_met / gen_met
vals_b = pred_met / gen_met

for ratio, label_fmt in [
    (vals_a, "PF, $(M={:.2f}, IQR={:.2f})$"),
    (vals_b, "MLPF, $(M={:.2f}, IQR={:.2f})$"),
]:
    p = med_iqr(ratio)
    plt.hist(ratio, bins=b, histtype="step", lw=2, label=label_fmt.format(p[0], p[1]))

cms_label(ax)
sample_label(ax, physics_process)
plt.ylim(1, 1e3)
plt.legend(loc=(0.35, 0.7))
plt.xlabel(r"$\frac{\mathrm{MET}_{\mathrm{reco}}}{\mathrm{MET}_{\mathrm{gen}}}$")
plt.ylabel("Number of events / bin")
plt.savefig("{}/metres.pdf".format(outpath), bbox_inches="tight")

In [None]:
# Per-event scalar sum of pT: reconstructed (PF and MLPF) versus gen
fig = plt.figure()
ax = plt.axes()

gen_sum_pt = awkward.sum(yvals["gen_pt"], axis=1)
plt.scatter(gen_sum_pt, awkward.sum(yvals["cand_pt"], axis=1), alpha=0.5, label="PF")
plt.scatter(gen_sum_pt, awkward.sum(yvals["pred_pt"], axis=1), alpha=0.5, label="MLPF")
plt.legend(loc=4)
cms_label(ax)
sample_label(ax, physics_process)
# Raw strings: "\s" is not a valid escape sequence in a normal string literal
plt.xlabel(r"Gen $\sum p_T$ [GeV]")
plt.ylabel(r"Reconstructed $\sum p_T$ [GeV]")

plt.savefig("{}/sum_pt.pdf".format(outpath), bbox_inches="tight")

### Resolution plots

In [None]:
# pT reco/gen ratio for class index 1, labelled "ch.had."
reso_plot(pid=1, var="pt", bins=np.linspace(0, 15, 100), ptcl_name=", ch.had.")
plt.ylim(bottom=1, top=1e9)
plt.savefig("{}/pt_res_ch_had.pdf".format(outpath), bbox_inches="tight")

In [None]:
# pT reco/gen ratio for class index 2, labelled "n.had."
reso_plot(pid=2, var="pt", bins=np.linspace(0, 100, 100), ptcl_name=", n.had.")
plt.ylim(bottom=1, top=1e9)
plt.savefig("{}/pt_res_n_had.pdf".format(outpath), bbox_inches="tight")

In [None]:
# pT reco/gen ratio for class index 3, labelled "HFHAD"
reso_plot(pid=3, var="pt", bins=np.linspace(0, 100, 100), ptcl_name=", HFHAD")
plt.ylim(bottom=1, top=1e9)
plt.savefig("{}/pt_res_hfhad.pdf".format(outpath), bbox_inches="tight")

In [None]:
# pT reco/gen ratio for class index 4, labelled "HFEM"
reso_plot(pid=4, var="pt", bins=np.linspace(0, 100, 100), ptcl_name=", HFEM")
plt.ylim(bottom=1, top=1e9)
plt.savefig("{}/pt_res_hfem.pdf".format(outpath), bbox_inches="tight")

In [None]:
# pT reco/gen ratio for class index 5 (photon label).
# Raw string: "\g" is not a valid escape sequence in a normal string literal.
reso_plot(5, "pt", np.linspace(0, 50, 100), r", $\gamma$")
plt.ylim(1, 1e9)
plt.savefig("{}/pt_res_gamma.pdf".format(outpath), bbox_inches="tight")

In [None]:
# pT reco/gen ratio for class index 6 (electron label).
# Raw string: "\p" is not a valid escape sequence in a normal string literal.
reso_plot(6, "pt", np.linspace(0, 10, 100), r", $e^\pm$")
plt.ylim(1, 1e9)
plt.savefig("{}/pt_res_ele.pdf".format(outpath), bbox_inches="tight")

In [None]:
# pT reco/gen ratio for class index 7 (muon label).
# Raw string: "\m" and "\p" are not valid escape sequences in a normal string literal.
reso_plot(7, "pt", np.linspace(0, 5, 100), r", $\mu^\pm$")
plt.ylim(1, 1e9)
plt.savefig("{}/pt_res_mu.pdf".format(outpath), bbox_inches="tight")

In [None]:
# eta reco/gen ratio for class index 1, labelled "ch.had."
reso_plot(pid=1, var="eta", bins=np.linspace(-50, 50, 100), ptcl_name=", ch.had.")
plt.ylim(bottom=1, top=1e10)
plt.savefig("{}/eta_res_ch_had.pdf".format(outpath), bbox_inches="tight")

In [None]:
# eta reco/gen ratio for class index 2, labelled "n.had."
reso_plot(pid=2, var="eta", bins=np.linspace(-50, 50, 100), ptcl_name=", n.had.")
plt.ylim(bottom=1, top=1e10)
plt.savefig("{}/eta_res_n_had.pdf".format(outpath), bbox_inches="tight")

In [None]:
# eta reco/gen ratio for class index 3, labelled "HFHAD"
reso_plot(pid=3, var="eta", bins=np.linspace(-5, 5, 100), ptcl_name=", HFHAD")
plt.ylim(bottom=1, top=1e10)
plt.savefig("{}/eta_res_hfhad.pdf".format(outpath), bbox_inches="tight")

In [None]:
# eta reco/gen ratio for class index 4, labelled "HFEM"
reso_plot(pid=4, var="eta", bins=np.linspace(-5, 5, 100), ptcl_name=", HFEM")
plt.ylim(bottom=1, top=1e10)
plt.savefig("{}/eta_res_hfem.pdf".format(outpath), bbox_inches="tight")

In [None]:
# eta reco/gen ratio for class index 5 (photon label).
# Raw string: "\g" is not a valid escape sequence in a normal string literal.
reso_plot(5, "eta", np.linspace(-10, 10, 100), r", $\gamma$")
plt.ylim(1, 1e10)
plt.savefig("{}/eta_res_gamma.pdf".format(outpath), bbox_inches="tight")

In [None]:
# eta reco/gen ratio for class index 6 (electron label).
# Raw string: "\p" is not a valid escape sequence in a normal string literal.
reso_plot(6, "eta", np.linspace(-10, 10, 100), r", $e^\pm$")
plt.ylim(1, 1e10)
plt.savefig("{}/eta_res_ele.pdf".format(outpath), bbox_inches="tight")

In [None]:
# eta reco/gen ratio for class index 7 (muon label).
# Raw string: "\m" and "\p" are not valid escape sequences in a normal string literal.
reso_plot(7, "eta", np.linspace(-10, 10, 100), r", $\mu^\pm$")
plt.ylim(1, 1e10)
plt.savefig("{}/eta_res_mu.pdf".format(outpath), bbox_inches="tight")

### Efficiencies and fake rates

In [None]:
# Efficiency and fake rate for class index 1 versus input feature 1 of the elements
# with X[:, :, 0] == 1, in 40 log-spaced bins from 0.1 to 100.
plot_eff_and_fake_rate(icls=1, ivar=1, ielem=1, bins=np.logspace(-1, 2, 41), xlabel="track $p_T$ [GeV]", log=True)

In [None]:
# Efficiency and fake rate for class index 2 versus input feature 4 of the elements
# with X[:, :, 0] == 5, in 40 log-spaced bins from 1 to 1000.
plot_eff_and_fake_rate(icls=2, ivar=4, ielem=5, bins=np.logspace(0, 3, 41), xlabel="calorimeter cluster E [GeV]", log=True)

In [None]:
# Efficiency and fake rate for class index 3 versus input feature 4 of the elements
# with X[:, :, 0] == 9, in 40 log-spaced bins from 1 to 1000.
plot_eff_and_fake_rate(icls=3, ivar=4, ielem=9, bins=np.logspace(0, 3, 41), xlabel="PFElement E [GeV]", log=True)

In [None]:
# Efficiency and fake rate for class index 4 versus input feature 4 of the elements
# with X[:, :, 0] == 8, in 40 log-spaced bins from 1 to 1000.
plot_eff_and_fake_rate(icls=4, ivar=4, ielem=8, bins=np.logspace(0, 3, 41), xlabel="PFElement E [GeV]", log=True)

In [None]:
# Efficiency and fake rate for class index 5 versus input feature 4 of the elements
# with X[:, :, 0] == 4, in 40 log-spaced bins from 0.1 to 10000.
plot_eff_and_fake_rate(icls=5, ivar=4, ielem=4, bins=np.logspace(-1, 4, 41), xlabel="PFElement E [GeV]", log=True)

In [None]:
# Efficiency and fake rate for class index 6 versus input feature 1 of the elements
# with X[:, :, 0] == 6. The other calls in this notebook label feature 1 as pT and
# feature 4 as E, so the axis label says pT here as well.
# NOTE(review): if energy was the intended quantity, ivar should be 4 instead -- confirm.
plot_eff_and_fake_rate(icls=6, ivar=1, ielem=6, bins=np.logspace(0, 2, 41), xlabel="PFElement $p_T$ [GeV]", log=True)

In [None]:
# Efficiency and fake rate for class index 7 versus input feature 1 of the elements
# with X[:, :, 0] == 1, in 40 log-spaced bins from 1 to 100.
plot_eff_and_fake_rate(icls=7, ivar=1, ielem=1, bins=np.logspace(0, 2, 41), xlabel="PFElement $p_T$ [GeV]", log=True)

### Training details

In [None]:
def load_history(path, min_epoch=None, max_epoch=None):
    """Load per-epoch training history JSON files into a DataFrame.

    Args:
        path: glob pattern matching the history files. The epoch number is
            parsed from the text between the last ``_`` and the following ``.``
            of each file path, e.g. ``history_12.json`` -> epoch 12.
        min_epoch: first epoch to include; defaults to the smallest epoch found.
        max_epoch: last epoch to include (inclusive); defaults to the largest
            epoch found.

    Returns:
        pandas.DataFrame with one row per epoch from ``min_epoch`` to ``max_epoch``.

    Raises:
        ValueError: if no file matches ``path``.
        KeyError: if an epoch inside the requested range has no file.
    """
    ret = {}
    for fi in glob.glob(path):
        # Context manager so the file handle is closed right after parsing
        with open(fi) as fobj:
            data = json.load(fobj)
        epoch = int(fi.split("_")[-1].split(".")[0])
        ret[epoch] = data

    if not ret:
        raise ValueError("no history files matched {!r}".format(path))

    # Compare against None so that an explicit epoch 0 is honoured
    if max_epoch is None:
        max_epoch = max(ret.keys())
    if min_epoch is None:
        min_epoch = min(ret.keys())

    return pandas.DataFrame([ret[i] for i in range(min_epoch, max_epoch + 1)])

In [None]:
# Training history files are looked up three directory levels above `path`, in "history/"
history = load_history(path + "/../../../history/history_*.json")

In [None]:
# Total training and validation loss versus epoch
p0 = loss_plot(train=history["loss"].values, test=history["val_loss"].values, margin=0.5)
plt.ylabel("Total loss")
plt.savefig("{}/loss.pdf".format(outpath), bbox_inches="tight")

In [None]:
# Classification part of the loss versus epoch
p0 = loss_plot(train=history["cls_loss"].values, test=history["val_cls_loss"].values, margin=0.5)
plt.ylabel("Multiclassification loss")
plt.savefig("{}/cls_loss.pdf".format(outpath), bbox_inches="tight")

In [None]:
# Regression loss: sum of the per-target loss terms, for training and validation
_reg_terms = ["energy", "pt", "eta", "sin_phi", "cos_phi", "charge"]
reg_loss = sum(history["{}_loss".format(term)].values for term in _reg_terms)
val_reg_loss = sum(history["val_{}_loss".format(term)].values for term in _reg_terms)
p0 = loss_plot(train=reg_loss, test=val_reg_loss, margin=0.2)
plt.ylabel("Regression loss")
plt.savefig("{}/reg_loss.pdf".format(outpath), bbox_inches="tight")

In [None]:
# The event-level loss is only plotted when the history contains it
_event_terms = ["pt_e_eta_phi"]
if "pt_e_eta_phi_loss" in history.keys():
    reg_loss = sum(history["{}_loss".format(term)].values for term in _event_terms)
    val_reg_loss = sum(history["val_{}_loss".format(term)].values for term in _event_terms)
    p0 = loss_plot(train=reg_loss, test=val_reg_loss, margin=0.1)
    plt.ylabel("Event loss")
    plt.savefig("{}/event_loss.pdf".format(outpath), bbox_inches="tight")

In [None]:
# Row-normalised confusion matrix of truth class versus MLPF class, using only
# the input elements whose first feature is non-zero
fig = plt.figure(figsize=(12, 12))
ax = plt.axes()

_nonzero_elem = X[:, :, 0] != 0
_true_ids = awkward.flatten(yvals["gen_cls_id"][_nonzero_elem])
_reco_ids = awkward.flatten(yvals["pred_cls_id"][X[:, :, 0] != 0])
cm_norm = sklearn.metrics.confusion_matrix(
    _true_ids, _reco_ids, labels=range(0, len(CLASS_LABELS_CMS)), normalize="true"
)

plt.imshow(cm_norm, cmap="Blues", origin="lower")
plt.colorbar()

# Cells darker than this threshold get white text so the number stays readable
thresh = cm_norm.max() / 1.5
for i, j in np.ndindex(*cm_norm.shape):
    text_color = "white" if cm_norm[i, j] > thresh else "black"
    plt.text(j, i, "{:0.2f}".format(cm_norm[i, j]), horizontalalignment="center", color=text_color, fontsize=12)

cms_label(ax, y=1.01)
_ticks = range(len(CLASS_NAMES_CMS))
plt.xticks(_ticks, CLASS_NAMES_CMS, rotation=45)
plt.yticks(_ticks, CLASS_NAMES_CMS)
plt.xlabel("MLPF candidate ID")
plt.ylabel("Truth ID")
plt.savefig("{}/cm_normed.pdf".format(outpath), bbox_inches="tight")

In [None]:
# Row-normalised confusion matrix of truth class versus PF candidate class, using
# only the input elements whose first feature is non-zero
fig = plt.figure(figsize=(12, 12))
ax = plt.axes()

_nonzero_elem = X[:, :, 0] != 0
_true_ids = awkward.flatten(yvals["gen_cls_id"][_nonzero_elem])
_reco_ids = awkward.flatten(yvals["cand_cls_id"][X[:, :, 0] != 0])
cm_norm = sklearn.metrics.confusion_matrix(
    _true_ids, _reco_ids, labels=range(0, len(CLASS_LABELS_CMS)), normalize="true"
)

plt.imshow(cm_norm, cmap="Blues", origin="lower")
plt.colorbar()

# Cells darker than this threshold get white text so the number stays readable
thresh = cm_norm.max() / 1.5
for i, j in np.ndindex(*cm_norm.shape):
    text_color = "white" if cm_norm[i, j] > thresh else "black"
    plt.text(j, i, "{:0.2f}".format(cm_norm[i, j]), horizontalalignment="center", color=text_color, fontsize=12)

cms_label(ax, y=1.01)
_ticks = range(len(CLASS_NAMES_CMS))
plt.xticks(_ticks, CLASS_NAMES_CMS, rotation=45)
plt.yticks(_ticks, CLASS_NAMES_CMS)
plt.xlabel("PF candidate ID")
plt.ylabel("Truth ID")
plt.savefig("{}/cm_normed_pf.pdf".format(outpath), bbox_inches="tight")