In [None]:
import pandas as pd
import json
import glob
import tqdm
import matplotlib.pyplot as plt
import numpy as np

import sklearn
import sklearn.metrics
import matplotlib
import scipy
import mplhep
import os
import awkward

import vector
import fastjet
import awkward as ak

import pandas
import boost_histogram as bh
import itertools
import mplhep

# Apply mplhep's CMS style sheet to every figure created in this notebook.
mplhep.set_style(mplhep.styles.CMS)

In [None]:
import sys

# Make the plotting helpers importable. The entry is a relative path, so it is
# resolved against the kernel's current working directory.
sys.path += ["../../mlpf/plotting/"]

import plot_utils
from plot_utils import pid_to_text, load_eval_data, compute_jet_ratio

In [None]:
# Directory prefix used by the plotting functions below when they save their PDFs.
outpath = "./"

In [None]:
def sum_overflow_into_last_bin(all_values):
    """Fold the flow entries into the neighbouring regular bins.

    Despite the name, both ends are handled: the first entry (underflow) is added
    to the first regular bin and the last entry (overflow) to the last regular bin.

    Args:
        all_values: bin contents laid out as ``[underflow, bin_0, ..., bin_n, overflow]``.

    Returns:
        A new sequence with only the regular bins, edge bins updated as described.
        ``all_values`` itself is not modified.
    """
    # Slicing a NumPy array returns a view; copy it so the in-place updates below
    # do not write through into the caller's array.
    values = all_values[1:-1].copy()
    values[-1] = values[-1] + all_values[-1]
    values[0] = values[0] + all_values[0]
    return values


def to_bh(data, bins, cumulative=False):
    """Histogram ``data`` with variable-width ``bins`` and fold the flow bins into the edges.

    Args:
        data: values to histogram.
        bins: bin edges passed to ``bh.axis.Variable``.
        cumulative: when true, each regular bin is first replaced by the total
            in-range count minus the running sum up to and including that bin.

    Returns:
        The filled ``bh.Histogram`` whose regular bins were rewritten by
        ``sum_overflow_into_last_bin``.
    """
    hist = bh.Histogram(bh.axis.Variable(bins))
    hist.fill(data)

    if cumulative:
        in_range_total = np.sum(hist.values())
        hist[:] = in_range_total - np.cumsum(hist)

    contents_with_flow = hist.values(flow=True)[:]
    hist[:] = sum_overflow_into_last_bin(contents_with_flow)
    return hist


def loss_plot(train, test, margin=0.05, smoothing=False, smoothing_window=5):
    """Plot training and test loss curves against the epoch index.

    Args:
        train: sequence of training-loss values, one per epoch.
        test: sequence of test-loss values, one per epoch.
        margin: the y-range is set to ``test[-1] * (1 -/+ margin)``.
        smoothing: when true, the raw curves are drawn faintly and a moving
            average of each curve is overlaid and used for the legend.
        smoothing_window: number of epochs in the moving-average window.
    """
    plt.figure()
    plt.axes()

    # With smoothing the raw curves are only a faint backdrop and carry no legend entry.
    raw_alpha = 0.2 if smoothing else 1.0
    raw_train_label = None if smoothing else "train"
    raw_test_label = None if smoothing else "test"
    train_lines = plt.plot(train, alpha=raw_alpha, label=raw_train_label)
    test_lines = plt.plot(test, alpha=raw_alpha, label=raw_test_label)

    if smoothing:
        kernel = np.ones(smoothing_window) / smoothing_window
        # mode="valid" drops (window - 1) points; shift the x positions so every average
        # is drawn at the centre of its window instead of at the window start.
        centre_offset = (smoothing_window - 1) / 2
        for series, raw_line, label in ((train, train_lines[0], "train"), (test, test_lines[0], "test")):
            smoothed = np.convolve(series, kernel, mode="valid")
            epochs = np.arange(len(smoothed)) + centre_offset
            plt.plot(epochs, smoothed, color=raw_line.get_color(), lw=2, label=label)

    plt.ylim(test[-1] * (1.0 - margin), test[-1] * (1.0 + margin))
    plt.legend(loc=3, frameon=False)
    plt.xlabel("epoch")
    # No cms_label(ax) call here: that name is neither defined nor imported in this
    # notebook (it raised NameError), and the other plotting functions have it disabled too.


def med_iqr(arr):
    """Return ``(median, interquartile range)`` of ``arr``.

    The interquartile range is the 75th percentile minus the 25th percentile.
    """
    lower_quartile, median, upper_quartile = (np.percentile(arr, q) for q in (25, 50, 75))
    return median, upper_quartile - lower_quartile


def flatten(arr):
    """Collapse all leading axes of ``arr`` into one, keeping the last axis as is."""
    last_dim = arr.shape[-1]
    return arr.reshape((-1, last_dim))


def get_distribution(prefix, bins, var):
    """Histogram one variable separately for each particle class.

    Args:
        prefix: key prefix selecting the collection in ``yvals_f``; the class ids
            are read from ``prefix + "_cls_id"``.
        bins: bin edges passed to ``bh.axis.Variable``.
        var: variable name; values are read from ``prefix + "_" + var``.

    Returns:
        A list of ``bh.Histogram`` objects, one per PDG id in the order
        13, 11, 22, 1, 2, 130, 211.
    """
    # NOTE(review): CLASS_LABELS_CMS and yvals_f are read as globals but are not
    # defined in the visible part of this notebook - confirm where they come from.

    def _hist_for(pid):
        class_index = CLASS_LABELS_CMS.index(pid)
        selected = yvals_f[prefix + "_cls_id"] == class_index
        hist = bh.Histogram(bh.axis.Variable(bins))
        hist.fill(yvals_f[prefix + "_" + var][selected].flatten())
        return hist

    return [_hist_for(pid) for pid in (13, 11, 22, 1, 2, 130, 211)]


def binom_error(n_sig, n_tot):
    """Standard deviation of an efficiency or purity estimate.

    For an efficiency = nSig / nTrueSig or a purity = nSig / (nSig + nBckgrd), the
    standard deviation is computed following http://arxiv.org/abs/physics/0701199 .

    Args:
        n_sig: number of selected (signal) entries, scalar or array.
        n_tot: total number of entries, same shape as ``n_sig``.

    Returns:
        The standard deviation; 0 wherever ``n_tot`` is not positive.
    """
    second_moment = (n_sig + 1) * (n_sig + 2) / ((n_tot + 2) * (n_tot + 3))
    squared_mean = (n_sig + 1) ** 2 / ((n_tot + 2) ** 2)
    variance = np.where(n_tot > 0, second_moment - squared_mean, 0)
    return np.sqrt(variance)


def reso_plot(pid, var, bins, legtitle, xlabel):
    """Plot the reco/gen response of one particle class for PF and MLPF and save it as PDF.

    Reads the notebook globals ``X``, ``yvals`` and ``outpath``.

    Args:
        pid: class id to select via ``gen_cls_id``. 1 selects events with exactly one
            element of type 1, 2 or 3 select events with exactly one element of type 2.
        var: variable suffix, read from ``gen_<var>``, ``cand_<var>`` and ``pred_<var>``.
        bins: histogram bin edges; also define the x-range.
        legtitle: legend title.
        xlabel: x-axis label.

    Raises:
        ValueError: if ``pid`` is not 1, 2 or 3.

    The figure is written to ``{outpath}/res_icls{pid}_ivar{ivar}.pdf`` where ``ivar`` is 1
    for ``"pt"``, 5 for ``"energy"`` and the variable name itself otherwise.
    """
    # choose only events with a single track or cluster
    # this makes the definition of the single particle reconstruction efficiency straightforward
    if pid == 2 or pid == 3:
        msk_ev = ak.sum(X[:, :, 0] == 2, axis=1) == 1
    elif pid == 1:
        msk_ev = ak.sum(X[:, :, 0] == 1, axis=1) == 1
    else:
        # previously this fell through and failed later with an unbound msk_ev
        raise ValueError("reso_plot supports pid 1, 2 or 3, got {!r}".format(pid))

    plt.figure()
    plt.axes()

    # keep elements of the requested gen class that both PF and MLPF reconstructed (class != 0)
    msk = (yvals["gen_cls_id"][msk_ev] == pid) & (yvals["cand_cls_id"][msk_ev] != 0) & (yvals["pred_cls_id"][msk_ev] != 0)
    vals_gen = awkward.flatten(yvals["gen_" + var][msk_ev][msk])
    vals_cand = awkward.flatten(yvals["cand_" + var][msk_ev][msk])
    vals_mlpf = awkward.flatten(yvals["pred_" + var][msk_ev][msk])

    reso_pf = vals_cand / vals_gen
    reso_mlpf = vals_mlpf / vals_gen
    plt.hist(reso_pf, bins=bins, histtype="step", lw=2, label="PF (M={:.2f}, IQR={:.2f})".format(*med_iqr(reso_pf)))
    plt.hist(reso_mlpf, bins=bins, histtype="step", lw=2, label="MLPF (M={:.2f}, IQR={:.2f})".format(*med_iqr(reso_mlpf)))
    plt.yscale("log")
    plt.xlabel(xlabel)
    plt.ylabel("Number of particles / bin")
    plt.xlim(min(bins), max(bins))
    plt.legend(loc="best", title=legtitle)
    plt.ylim(1, 1e9)

    # File names keep the feature-index convention for pt and energy; any other variable
    # (e.g. "eta") falls back to its own name instead of failing on an unbound ivar.
    ivar = {"pt": 1, "energy": 5}.get(var, var)
    plt.savefig("{}/res_icls{}_ivar{}.pdf".format(outpath, pid, ivar), bbox_inches="tight")


def _ratio_hist(hist_num, hist_den, bins):
    """Return a weighted histogram of hist_num / hist_den with variances from binom_error."""
    ratio = bh.Histogram(bh.axis.Variable(bins), storage=bh.storage.Weight())
    # Bins with an empty denominator end up as NaN/inf exactly as before; only the
    # floating-point warnings are silenced.
    with np.errstate(divide="ignore", invalid="ignore"):
        ratio.values()[:] = hist_num.values() / hist_den.values()
    ratio.variances()[:] = binom_error(hist_num.values(), hist_den.values()) ** 2
    return ratio


def _plot_pf_vs_mlpf(hist_pf, hist_mlpf, sharex, ylim, ylabel, xlabel, log, legtitle, bins, filename):
    """Draw the PF and MLPF histograms on a new figure, save it to filename and return its axes."""
    plt.figure()
    ax = plt.axes(sharex=sharex)
    mplhep.histplot(hist_pf, label="PF")
    mplhep.histplot(hist_mlpf, label="MLPF")
    plt.ylim(ylim[0], ylim[1])
    plt.ylabel(ylabel)
    plt.xlabel(xlabel)
    if log:
        plt.xscale("log")
    plt.legend(loc="best", title=legtitle)
    plt.xlim(min(bins), max(bins))
    plt.savefig(filename, bbox_inches="tight")
    return ax


def plot_eff_and_fake_rate(icls=1, ivar=4, ielem=1, bins=np.linspace(-3, 6, 100), xlabel="PFElement log[E/GeV]", log=True, ylim_eff=(0.0, 1.5), ylim_fake=(0.0, 1.5), legtitle=""):
    """Plot element distributions, efficiency and fake rate for one particle class.

    Reads the notebook globals ``X``, ``yvals`` and ``outpath`` and writes three PDFs:
    ``distr_icls{icls}_ivar{ivar}.pdf``, ``eff_icls{icls}_ivar{ivar}.pdf`` and
    ``fake_icls{icls}_ivar{ivar}.pdf``.

    Efficiency is the fraction of elements with gen class ``icls`` whose reconstructed
    class (PF candidate or MLPF prediction) is also ``icls``. Fake rate is the fraction
    of elements reconstructed as ``icls`` whose gen class is not ``icls``.

    Args:
        icls: class id compared against ``gen_cls_id``, ``pred_cls_id`` and ``cand_cls_id``.
        ivar: index into the last axis of ``X`` giving the histogrammed feature.
        ielem: element type (value of ``X[..., 0]``) to keep. For 2, events with exactly
            one type-2 element are used; otherwise events with exactly one type-1 element.
        bins: bin edges shared by all histograms; also define the x-range.
        xlabel: x-axis label.
        log: use a logarithmic x-axis.
        ylim_eff: y-range of the efficiency plot.
        ylim_fake: y-range of the fake-rate plot.
        legtitle: legend title of the efficiency and fake-rate plots.
    """
    # choose only events with a single track or cluster
    # this makes the definition of the single particle reconstruction efficiency straightforward
    required_type = 2 if ielem == 2 else 1
    msk_ev = ak.sum(X[:, :, 0] == required_type, axis=1) == 1
    print(np.sum(msk_ev))

    # apply the event selection once and reuse it below
    X_ev = X[msk_ev]
    values = X_ev[:, :, ivar]

    msk_X = X_ev[:, :, 0] == ielem
    if ielem == 1:
        # filter out tracks that are badly reconstructed: X[..., 1] must be within 100% of gen_pt
        msk_X = msk_X & (np.abs(X_ev[:, :, 1] / yvals["gen_pt"][msk_ev] - 1) < 1.0)

    gen_cls = yvals["gen_cls_id"][msk_ev]
    msk_gen = gen_cls == icls
    msk_nogen = gen_cls != icls
    msk_pred = yvals["pred_cls_id"][msk_ev] == icls
    msk_cand = yvals["cand_cls_id"][msk_ev] == icls

    def _fill(mask):
        hist = bh.Histogram(bh.axis.Variable(bins))
        hist.fill(awkward.flatten(values[mask]))
        return hist

    hist_X = _fill(msk_X)
    hist_gen = _fill(msk_gen & msk_X)
    hist_pred = _fill(msk_pred & msk_X)
    hist_cand = _fill(msk_cand & msk_X)

    # Genparticle exists, reco particle exists
    hist_gen_pred = _fill(msk_gen & msk_pred & msk_X)
    hist_gen_cand = _fill(msk_gen & msk_cand & msk_X)

    # Genparticle does not exist, reco particle exists
    hist_pred_fake = _fill(msk_nogen & msk_pred & msk_X)
    hist_cand_fake = _fill(msk_nogen & msk_cand & msk_X)

    eff_mlpf = _ratio_hist(hist_gen_pred, hist_gen, bins)
    eff_pf = _ratio_hist(hist_gen_cand, hist_gen, bins)
    fake_pf = _ratio_hist(hist_cand_fake, hist_cand, bins)
    fake_mlpf = _ratio_hist(hist_pred_fake, hist_pred, bins)

    # Figure 1: raw element distributions
    plt.figure()
    ax = plt.axes()
    mplhep.histplot(hist_X, label="all PFElements", color="black")
    mplhep.histplot(hist_cand, label="with PF")
    mplhep.histplot(hist_pred, label="with MLPF reco")
    mplhep.histplot(hist_gen, label="with MLPF truth")
    plt.ylabel("Number of PFElements / bin")
    plt.xlabel(xlabel)
    plt.yscale("log")
    if log:
        plt.xscale("log")
    plt.legend(loc=(0.6, 0.65))
    plt.ylim(10, 20 * np.max(hist_X.values()))
    plt.xlim(min(bins), max(bins))
    plt.savefig("{}/distr_icls{}_ivar{}.pdf".format(outpath, icls, ivar), bbox_inches="tight")

    # Figure 2: efficiency (x-axis shared with the distribution figure)
    ax = _plot_pf_vs_mlpf(
        eff_pf, eff_mlpf, ax, ylim_eff, "Efficiency", xlabel, log, legtitle, bins,
        "{}/eff_icls{}_ivar{}.pdf".format(outpath, icls, ivar),
    )

    # Figure 3: fake rate (x-axis shared with the efficiency figure)
    _plot_pf_vs_mlpf(
        fake_pf, fake_mlpf, ax, ylim_fake, "Fake rate", xlabel, log, legtitle, bins,
        "{}/fake_icls{}_ivar{}.pdf".format(outpath, icls, ivar),
    )

    # mplhep.histplot(fake, bins=hist_gen[1], label="fake rate", color="red")


#     plt.legend(frameon=False)
#     plt.ylim(0,1.4)
#     plt.xlabel(xlabel)
#     plt.ylabel("Fraction of particles / bin")

In [None]:
# Shell escape: list the PDFs in the working directory. In top-to-bottom order this cell
# runs before any plotting call below, so it shows files left over from earlier runs.
!ls *.pdf

In [None]:
# Second argument handed to load_eval_data below; -1 presumably means "all files" - TODO confirm in plot_utils.
nfiles = -1

In [None]:
# Load the clic_edm_single_pi_pf evaluation files. This (re)binds the globals `yvals` and `X`
# that reso_plot and plot_eff_and_fake_rate read, so the plotting cells that follow depend on it.
yvals, X, _ = load_eval_data("../../mlpf-clic-2023-results/clusters_best_tuned_gnn_clic_v130/evaluation/epoch_96/clic_edm_single_pi_pf/*/*.parquet", nfiles)

In [None]:
# Class 1 ("charged hadrons"): efficiency and fake rate vs. feature 1 of type-1 elements,
# in 20 log-spaced bins between 1 and 100.
plot_eff_and_fake_rate(icls=1, ivar=1, ielem=1, bins=np.logspace(0, 2, 21), xlabel="track $p_T$ [GeV]", log=True, ylim_eff=(0.6, 1.1), ylim_fake=(0, 0.01), legtitle="charged hadrons")

In [None]:
# Raw string for the label: "\m" is not a valid escape sequence and triggers a
# SyntaxWarning on recent Python versions. The resulting text is unchanged.
reso_plot(1, "pt", np.linspace(0, 2, 201), "charged hadrons", r"$p_{T,\mathrm{reco}}/p_{T,\mathrm{gen}}$")
# Only affects the inline figure: reso_plot has already written the PDF at this point.
plt.ylim(1, 1e8);

In [None]:
# Load the clic_edm_single_gamma_pf evaluation files, replacing the globals `yvals` and `X`
# from the previous sample. The plotting cells that follow read these globals.
yvals, X, _ = load_eval_data("../../mlpf-clic-2023-results/clusters_best_tuned_gnn_clic_v130/evaluation/epoch_96/clic_edm_single_gamma_pf/*/*.parquet", nfiles)

In [None]:
# Class 3 (photons): efficiency and fake rate vs. feature 5 of type-2 elements.
# The legend title is a raw string because "\g" is not a valid escape sequence.
plot_eff_and_fake_rate(icls=3, ivar=5, ielem=2, bins=np.logspace(0, 2, 21), xlabel="cluster $E$ [GeV]", log=True, ylim_eff=(0.8, 1.1), ylim_fake=(0, 0.02), legtitle=r"$\gamma$")

In [None]:
# Raw strings for the labels: "\g" and "\m" are not valid escape sequences and trigger a
# SyntaxWarning on recent Python versions. The resulting text is unchanged.
reso_plot(3, "energy", np.linspace(0, 2, 201), r"$\gamma$", r"$E_{\mathrm{reco}}/E_{\mathrm{gen}}$")
# Only affects the inline figure: reso_plot has already written the PDF at this point.
plt.ylim(1, 1e8);

In [None]:
# Load the clic_edm_single_kaon0l_pf evaluation files, replacing the globals `yvals` and `X`
# from the previous sample. The plotting cells that follow read these globals.
yvals, X, _ = load_eval_data("../../mlpf-clic-2023-results/clusters_best_tuned_gnn_clic_v130/evaluation/epoch_96/clic_edm_single_kaon0l_pf/*/*.parquet", nfiles)

In [None]:
# Class 2 ("neutral hadrons"): efficiency and fake rate vs. feature 5 of type-2 elements,
# in 20 log-spaced bins between 1 and 100.
plot_eff_and_fake_rate(icls=2, ivar=5, ielem=2, bins=np.logspace(0, 2, 21), xlabel="cluster $E$ [GeV]", log=True, ylim_eff=(0.5, 1.1), ylim_fake=(0, 0.05), legtitle="neutral hadrons")

In [None]:
# Raw string for the label: "\m" is not a valid escape sequence and triggers a
# SyntaxWarning on recent Python versions. The resulting text is unchanged.
reso_plot(2, "energy", np.linspace(0, 6, 201), "neutral hadrons", r"$E_{\mathrm{reco}}/E_{\mathrm{gen}}$")
# Only affects the inline figure: reso_plot has already written the PDF at this point.
plt.ylim(1, 1e8);