In [None]:
%matplotlib inline
import sklearn
import sklearn.metrics
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import pandas
import mplhep
import pickle
import awkward
import glob
import bz2
import os
import tqdm
import fastjet
import vector
import uproot
from pathlib import Path

import pickle
from functools import reduce
import mplhep
import boost_histogram as bh
import bz2

mplhep.style.use("CMS")

import sys
sys.path += ["../../mlpf/"]
sys.path += ["../../mlpf/plotting/"]

import plot_utils
from plot_utils import ELEM_LABELS_CMS, ELEM_NAMES_CMS
from plot_utils import CLASS_LABELS_CMS, CLASS_NAMES_CMS, EVALUATION_DATASET_NAMES
from plot_utils import cms_label, sample_label
from plot_utils import pid_to_text


In [None]:
# Text anchor positions for plot annotations, as (x, y) fractions of the axes.
# sample_label_coords is used with transform=ax.transAxes by sample_label;
# the jet_label_* pairs are presumably meant for the same coordinate system.
jet_label_coords = 0.01, 0.84
jet_label_coords_single = 0.01, 0.88
sample_label_coords = 0.01, 0.97
# Raw string: "\e" is not a valid Python escape sequence (SyntaxWarning since 3.12).
# The r-prefix yields the identical value and hands the backslash to mathtext verbatim.
jet_label_ak4 = r"AK4 jets, $|\eta|<2.5$"

In [None]:
def save_img(outfile, epoch=None, cp_dir=None, comet_experiment=None):
    """Save the current matplotlib figure as a PNG plus a PDF copy under ``cp_dir``.

    Nothing is written when ``cp_dir`` is falsy, so saving can be disabled by
    passing ``cp_dir=None`` (the default).

    Args:
        outfile: file name relative to ``cp_dir``, expected to end in ``.png``.
            The PDF copy gets the same name with the final suffix swapped to ``.pdf``.
        epoch: optional epoch number. When given, it is logged to Comet as
            ``step=epoch - 1``; when None, no step is passed.
        cp_dir: output directory (``pathlib.Path`` or ``str``); falsy disables saving.
        comet_experiment: optional Comet experiment; if truthy, the PNG is logged
            via its ``log_image`` method.
    """
    if not cp_dir:
        return

    # Path() lets callers pass either a str or a Path for cp_dir.
    image_path = Path(cp_dir) / outfile
    plt.savefig(str(image_path), dpi=100, bbox_inches="tight")
    # Swap only the final suffix. str.replace would also rewrite ".png" appearing in a
    # directory name, and would overwrite the PNG itself if outfile had no ".png".
    plt.savefig(str(image_path.with_suffix(".pdf")), bbox_inches="tight")

    if comet_experiment:
        # epoch defaults to None; computing None - 1 would raise TypeError.
        step = None if epoch is None else epoch - 1
        comet_experiment.log_image(str(image_path), step=step)

In [None]:
def add_results(d0, d1):
    """Merge two result dicts into a new dict, summing values whose key appears in both.

    Keys found in only one input are carried over as-is (the same object, not a copy).
    For shared keys the value is ``d0[key] + d1[key]``, in that order.
    Neither input is modified.
    """
    merged = dict(d0)
    for key, value in d1.items():
        if key in d0:
            merged[key] = d0[key] + value
        else:
            merged[key] = value
    return merged

In [None]:
# Location of the per-job validation pickles. An earlier run used
# "/local/joosep/mlpf/cms/20250508_cmssw_15_0_5_d3c6d1/validation_plots/out*.pkl".
RESULTS_GLOB = "../../out*.pkl"

# Load every per-job result dict and merge them with add_results (shared keys are summed).
# Sort the file list so the merge order is reproducible; glob.glob returns arbitrary order.
# NOTE: pickle.load can execute arbitrary code -- only point RESULTS_GLOB at trusted files.
files = []
for fn in sorted(glob.glob(RESULTS_GLOB)):
    with open(fn, "rb") as fh:
        files.append(pickle.load(fh))
# Fail early with a clear message instead of a KeyError in a later plotting cell.
assert files, f"no result pickles matched {RESULTS_GLOB!r}"
ret = reduce(add_results, files, {})

# Distinct sample prefixes of the result keys: the first two path components of
# ordinary keys, and the first three of keys under "combined/".
sample_keys = sorted({"/".join(k.split("/")[0:2]) for k in ret.keys() if not k.startswith("combined")})
sample_keys_combined = sorted({"/".join(k.split("/")[0:3]) for k in ret.keys() if k.startswith("combined")})

In [None]:
# Show the distinct non-combined sample prefixes (first two path components of the result keys).
sample_keys

In [None]:
# Show the distinct "combined/..." sample prefixes (first three path components of the result keys).
sample_keys_combined

In [None]:
# for k in sorted(ret.keys()):
#     print(k)

In [None]:
# Map sample prefixes (without any leading "combined/") to the dataset keys used
# to look up display names in plot_utils.EVALUATION_DATASET_NAMES.
sample_labels = {
    "nopu/TTbar_14TeV_TuneCUETP8M1_cfi": "cms_pf_ttbar_nopu",
    "nopu/QCDForPF_14TeV_TuneCUETP8M1_cfi": "cms_pf_qcd_nopu",
    "nopu/ZTT_All_hadronic_14TeV_TuneCUETP8M1_cfi": "cms_pf_ztt_nopu",
    "pu55to75/TTbar_14TeV_TuneCUETP8M1_cfi": "cms_pf_ttbar",
    "pu55to75/QCDForPF_14TeV_TuneCUETP8M1_cfi": "cms_pf_qcd",
    "pu55to75/ZTT_All_hadronic_14TeV_TuneCUETP8M1_cfi": "cms_pf_ztt"
}

def sample_label(ax, sample, additional_text=""):
    """Write the dataset display name (plus optional extra text) in the top-left of ``ax``.

    NOTE: this definition shadows ``sample_label`` imported from ``plot_utils`` at the
    top of the notebook; later calls in this notebook resolve to this version.

    Args:
        ax: matplotlib Axes to annotate; the text is placed at ``sample_label_coords``
            in axes-fraction coordinates.
        sample: sample prefix such as "nopu/TTbar_14TeV_TuneCUETP8M1_cfi", optionally
            with a "combined/" prefix, which is stripped before the lookup.
        additional_text: optional extra line printed below the dataset name.

    Raises:
        KeyError: if the stripped sample prefix is not a key of ``sample_labels``.
    """
    dataset_key = sample_labels[sample.replace("combined/", "")]
    text = EVALUATION_DATASET_NAMES[dataset_key]
    if additional_text:
        # Avoid a dangling trailing newline when there is no extra text.
        text += "\n" + additional_text
    # Draw on ``ax`` itself: plt.text targets the *current* axes, which need not be ``ax``,
    # even though the position is expressed in ``ax``'s coordinate system.
    ax.text(sample_label_coords[0], sample_label_coords[1], text, ha="left", va="top", transform=ax.transAxes)

In [None]:
# NOTE(review): this re-displays sample_keys_combined, which an earlier cell already shows;
# nothing in between changes it, so this cell looks redundant and could be removed.
sample_keys_combined

In [None]:
#for sample in sample_keys_combined:
for sample in ['combined/nopu/TTbar_14TeV_TuneCUETP8M1_cfi']:
    plt.figure()
    ax = plt.axes()
    mplhep.histplot(ret[f"{sample}/particles_pt_pythia"], label="Pythia")
    mplhep.histplot(ret[f"{sample}/particles_pt_cand"], label="PF")
    #mplhep.histplot(ret[f"{sample}/particles_pt_caloparticle"], label="CaloParticle")
    #mplhep.histplot(ret[f"{sample}/particles_pt_target"], label="Target")
    #mplhep.histplot(ret[f"{sample}/particles_pt_target_pumask"], label="Target, PU mask", ls="--")
    plt.xscale("log")
    plt.yscale("log")
    plt.legend(loc=1)
    #plt.ylim(1, 1e8)
    cms_label(ax)
    sample_label(ax, sample)
    plt.xlabel("particle " + plot_utils.labels["pt"])
    plt.ylabel("Count")
    save_img("{}_particles_pt.png".format(sample.replace("/", "_")), cp_dir=Path("./"))
    plt.show()

In [None]:
# Per-PDG-id particle pT spectra for every combined sample:
# Pythia, PF candidates, target, and target with the PU mask applied.
# (A CaloParticle overlay, particle_<pid>_pt_caloparticle, is disabled.)
for pid in (11, 22, 211, 130):
    for sample in sample_keys_combined:
        plt.figure()
        ax = plt.axes()
        pid_prefix = f"{sample}/particle_{pid}_pt"
        mplhep.histplot(ret[pid_prefix + "_pythia"], label="Pythia", ax=ax)
        mplhep.histplot(ret[pid_prefix + "_cand"], label="PF", ax=ax)
        mplhep.histplot(ret[pid_prefix + "_target"], label="Target", ax=ax)
        mplhep.histplot(ret[pid_prefix + "_target_pumask"], label="Target, PU mask", ls="--", ax=ax)
        ax.set_xscale("log")
        ax.set_yscale("log")
        ax.legend(loc=1)
        ax.set_ylim(1, 1e7)
        cms_label(ax)
        sample_label(ax, sample, str(pid))
        ax.set_xlabel("particle " + plot_utils.labels["pt"])
        ax.set_ylabel("Count")
        save_img(f"{sample.replace('/', '_')}_particle_{pid}_pt.png", cp_dir=Path("./"))
        plt.show()

In [None]:
# Jet pT spectra per sample: "Pythia" (genjet), PF candidates, target,
# and target with the PU mask applied.
for sample in sample_keys:
    plt.figure()
    ax = plt.axes()
    jet_prefix = f"{sample}/jets_pt"
    mplhep.histplot(ret[f"{jet_prefix}_genjet"], label="Pythia", ax=ax)
    mplhep.histplot(ret[f"{jet_prefix}_cand"], label="PF", ax=ax)
    mplhep.histplot(ret[f"{jet_prefix}_target"], label="Target", ax=ax)
    mplhep.histplot(ret[f"{jet_prefix}_target_pumask"], label="Target, PU mask", ls="--", ax=ax)
    ax.set_xscale("log")
    ax.legend()
    cms_label(ax)
    sample_label(ax, sample, "AK4 jets")
    ax.set_yscale("log")
    ax.set_ylim(1, 1e8)
    ax.set_xlabel("jet " + plot_utils.labels["pt"])
    ax.set_ylabel("Count")
    save_img(f"{sample.replace('/', '_')}_jet_pt.png", cp_dir=Path("./"))
    plt.show()

In [None]:
# Jet pT response histograms (jets_pt_ratio_*), merging every `rebin` bins into one.
# (A CaloParticle overlay, jets_pt_ratio_caloparticle, is disabled.)
rebin = 5
for sample in sample_keys:
    plt.figure()
    ax = plt.axes()
    ratio_prefix = f"{sample}/jets_pt_ratio"
    # Unlabelled all-zero histogram: presumably drawn only to consume the first colour of
    # the cycle, so PF/Target keep the colours they have in the plots where "Pythia" is
    # drawn first. NOTE(review): confirm; passing an explicit color= would be clearer.
    mplhep.histplot(0.0 * ret[f"{ratio_prefix}_cand"][bh.rebin(rebin)], yerr=False, ax=ax)
    mplhep.histplot(ret[f"{ratio_prefix}_cand"][bh.rebin(rebin)], yerr=False, label="PF", ax=ax)
    mplhep.histplot(ret[f"{ratio_prefix}_target"][bh.rebin(rebin)], yerr=False, label="Target", ax=ax)
    mplhep.histplot(ret[f"{ratio_prefix}_target_pumask"][bh.rebin(rebin)], yerr=False, label="Target, PU mask", ls="--", ax=ax)
    ax.legend()
    cms_label(ax)
    sample_label(ax, sample, "AK4 jets")
    ax.set_yscale("log")
    ax.set_ylim(1, 1e8)
    ax.set_xlabel(plot_utils.labels["pt_response"])
    ax.set_ylabel("Count")
    save_img(f"{sample.replace('/', '_')}_jet_response.png", cp_dir=Path("./"))
    plt.show()

In [None]:
# MET histograms per sample. A rebin factor of 1 keeps the stored binning.
rebin = 1
# (histogram-name suffix, legend label, extra histplot kwargs), drawn in this order.
met_overlays = (
    ("pythia", "Pythia", {}),
    ("cand", "PF", {}),
    ("target", "Target", {}),
    ("target_pumask", "Target, PU mask", {"ls": "--"}),
)
for sample in sample_keys:
    plt.figure()
    ax = plt.axes()
    for source, legend_label, extra_kwargs in met_overlays:
        mplhep.histplot(ret[f"{sample}/met_{source}"][bh.rebin(rebin)], yerr=False, label=legend_label, ax=ax, **extra_kwargs)
    ax.legend(loc=1)
    ax.set_yscale("log")
    ax.set_xscale("log")
    cms_label(ax)
    sample_label(ax, sample)
    ax.set_ylim(1, 1e8)
    ax.set_xlabel(plot_utils.labels["met"])
    ax.set_ylabel("Count")
    save_img(f"{sample.replace('/', '_')}_met.png", cp_dir=Path("./"))
    plt.show()

In [None]:
# List PDFs in the working directory; save_img writes a PDF copy of every figure here,
# since all calls in this notebook use cp_dir=Path("./").
!ls *.pdf