In [None]:
%matplotlib inline
import sklearn
import sklearn.metrics
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import pandas
import mplhep
import pickle
import awkward
import glob
import bz2
import os
import tqdm
import fastjet
import vector
import uproot
from pathlib import Path

import pickle
from functools import reduce
import mplhep
import boost_histogram as bh
import bz2

mplhep.style.use("CMS")

import sys
sys.path += ["../../mlpf/"]
sys.path += ["../../mlpf/plotting/"]

import plot_utils
from plot_utils import ELEM_LABELS_CMS, ELEM_NAMES_CMS
from plot_utils import CLASS_LABELS_CMS, CLASS_NAMES_CMS, EVALUATION_DATASET_NAMES
from plot_utils import sample_label
from plot_utils import pid_to_text


In [None]:
# Global plot styling shared by all figures in this notebook.
matplotlib.rcParams['axes.labelsize'] = 35
legend_fontsize = 30
sample_label_fontsize = 30
addtext_fontsize = 25

# (x, y) anchors in axes-fraction coordinates for annotation text.
jet_label_coords = 0.02, 0.82         # main panel of the two-panel (spectrum + ratio) figures
jet_label_coords_single = 0.02, 0.86  # single-panel figures
sample_label_coords = 0.02, 0.96
jet_label_ak4 = "AK4 ref. jets, $p_T$ > 3 GeV"
# Raw string: "\e" is not a valid Python escape sequence and would emit a SyntaxWarning.
particle_label = r"$p_T$ > 0.5 GeV, $|\eta|$ < 5"

# Pythia and target curves use the first two colours of the active style's colour cycle.
default_cycler = plt.rcParams['axes.prop_cycle']
_cycle_props = list(default_cycler)
pythia_color = _cycle_props[0]["color"]
target_color = _cycle_props[1]["color"]

pythia_linestyle = "--"
target_linestyle = "-"

In [None]:
def add_results(d0, d1):
    """Merge two result dicts, combining the values of shared keys.

    A key present in only one input is carried over with its value unchanged.
    A key present in both maps to ``d0[key] + d1[key]``.
    Neither input dict is modified; a new dict is returned.
    """
    merged = dict(d0)
    for key, value in d1.items():
        if key in merged:
            merged[key] = merged[key] + value
        else:
            merged[key] = value
    return merged

In [None]:
#RESULTS_GLOB = "/local/joosep/mlpf/cms/20250508_cmssw_15_0_5_d3c6d1/validation_plots/out*.pkl"
RESULTS_GLOB = "../../out*.pkl"


def _load_pickle(path):
    """Load one pickled result file and close the file handle afterwards.

    NOTE: pickle.load can execute arbitrary code; only point RESULTS_GLOB at trusted files.
    """
    with open(path, "rb") as fi:
        return pickle.load(fi)


# Sort so the merge order is reproducible (glob order is filesystem-dependent).
result_paths = sorted(glob.glob(RESULTS_GLOB))
assert result_paths, f"no result pickles matched {RESULTS_GLOB!r}"
files = [_load_pickle(fn) for fn in result_paths]
# Merge the per-file dicts; entries sharing a key are summed by add_results.
ret = reduce(add_results, files, {})

# Sample prefixes of the histogram keys: the first two "/"-separated components for
# per-sample entries, and the first three ("combined/...") for the combined entries.
sample_keys = sorted({"/".join(k.split("/")[0:2]) for k in ret if not k.startswith("combined")})
sample_keys_combined = sorted({"/".join(k.split("/")[0:3]) for k in ret if k.startswith("combined")})

In [None]:
# for k in sorted(ret.keys()):
#     print(k)

In [None]:
# Map "<pileup>/<process>" sample keys to the dataset identifiers used as keys of
# plot_utils.EVALUATION_DATASET_NAMES (which supplies the printed name).
sample_labels = {
    "nopu/TTbar_14TeV_TuneCUETP8M1_cfi": "cms_pf_ttbar_nopu",
    "nopu/QCDForPF_14TeV_TuneCUETP8M1_cfi": "cms_pf_qcd_nopu",
    "nopu/ZTT_All_hadronic_14TeV_TuneCUETP8M1_cfi": "cms_pf_ztt_nopu",
    "pu55to75/TTbar_14TeV_TuneCUETP8M1_cfi": "cms_pf_ttbar",
    "pu55to75/QCDForPF_14TeV_TuneCUETP8M1_cfi": "cms_pf_qcd",
    "pu55to75/ZTT_All_hadronic_14TeV_TuneCUETP8M1_cfi": "cms_pf_ztt"
}

def sample_label(ax, sample, additional_text=""):
    """Write the sample's display name (plus optional extra text) in the top-left corner of `ax`.

    Parameters
    ----------
    ax : matplotlib Axes to annotate.
    sample : sample key, optionally prefixed with "combined/".
    additional_text : text placed on the line below the sample name.

    Samples missing from `sample_labels` or EVALUATION_DATASET_NAMES are labelled
    with their raw key instead of raising KeyError.
    """
    key = sample.replace("combined/", "")
    try:
        name = EVALUATION_DATASET_NAMES[sample_labels[key]]
    except KeyError:
        name = key
    # Draw on `ax` itself rather than pyplot's current axes, so the text belongs to
    # the same axes whose coordinate system (transAxes) positions it.
    ax.text(sample_label_coords[0], sample_label_coords[1], name + "\n" + additional_text, ha="left", va="top", transform=ax.transAxes)

In [None]:
# 2D MET comparison: Pythia (x) vs. pileup-masked target (y), one figure per combined sample.
for sample in sample_keys_combined:
    fig = plt.figure(figsize=(8, 8))
    ax = plt.axes()
    # Rebin by a factor 2 along both axes; the explicit ax keeps the plot on this axes
    # even though the colorbar adds a second axes to the figure.
    r = mplhep.hist2dplot(
        ret[f"{sample}/met_pythia_vs_target_pumask"][bh.rebin(2), bh.rebin(2)],
        cmap="jet",
        norm=matplotlib.colors.LogNorm(),
        ax=ax,
    )
    ax.set_xscale("log")
    ax.set_yscale("log")
    mplhep.cms.label("", data=False, com=14, year='Run 3', ax=ax)
    sample_label(ax, sample)
    ax.set_xlim(1, 1e3)
    ax.set_ylim(1, 1e3)
    # Diagonal y = x: where target MET equals Pythia MET.
    ax.plot([1, 1e3], [1, 1e3], color="black", ls="--")
    ax.set_xlabel("Pythia " + plot_utils.labels["met"])
    ax.set_ylabel("Target " + plot_utils.labels["met"])
    r.cbar.set_label("Count")
    fig.savefig("{}_particle_met_2d.pdf".format(sample.replace("/", "_")), bbox_inches="tight")
    plt.show()

In [None]:
# Particle pT spectra, Pythia vs. pileup-masked target, with a target/Pythia ratio panel.
for sample in sample_keys_combined:
    f, (a0, a1) = plt.subplots(2, 1, gridspec_kw={"height_ratios": [3, 1]}, sharex=True)
    h_pythia = ret[f"{sample}/particles_pt_pythia"]
    h_target = ret[f"{sample}/particles_pt_target_pumask"]

    # Main panel: absolute spectra.
    mplhep.histplot(h_pythia, label="Pythia", lw=2, color=pythia_color, ls=pythia_linestyle, ax=a0)
    mplhep.histplot(h_target, label="Target", lw=2, color=target_color, ls=target_linestyle, ax=a0)
    a0.set_xscale("log")
    a0.set_yscale("log")
    a0.legend(loc=(0.65, 0.7), fontsize=legend_fontsize)
    a0.set_ylim(1, 1e8)
    mplhep.cms.label("", data=False, com=14, year='Run 3', ax=a0)
    sample_label(a0, sample)
    a0.text(jet_label_coords[0], jet_label_coords[1], particle_label, transform=a0.transAxes, fontsize=addtext_fontsize)
    a0.set_ylabel("Count")

    # Ratio panel: per-bin target / Pythia (the x axis is shared with the main panel).
    mplhep.histplot(h_target / h_pythia, lw=2, color=target_color, ls=target_linestyle, ax=a1)
    a1.set_ylim(0, 2)
    a1.set_xlim(0.5, 1000)
    a1.set_ylabel("Tgt. / Pythia")
    a1.axhline(1.0, color="black", ls="--")
    a1.set_xlabel("Particle " + plot_utils.labels["pt"])
    f.savefig("{}_particles_pt.pdf".format(sample.replace("/", "_")), bbox_inches="tight")
    plt.show()

In [None]:
# for pid in [11, 22, 211, 130]:
#     for sample in sample_keys_combined:
#         plt.figure()
#         ax = plt.axes()
#         mplhep.histplot(ret[f"{sample}/particle_{pid}_pt_pythia"], label="Pythia")
#         #mplhep.histplot(ret[f"{sample}/particle_{pid}_pt_cand"], label="PF")
#         # mplhep.histplot(ret[f"{sample}/particle_{pid}_pt_caloparticle"], label="CaloParticle")
#         #mplhep.histplot(ret[f"{sample}/particle_{pid}_pt_target"], label="Target")
#         mplhep.histplot(ret[f"{sample}/particle_{pid}_pt_target_pumask"], label="Target", ls="--")
#         plt.xscale("log")
#         plt.yscale("log")
#         plt.legend(loc=(0.65, 0.7), fontsize=legend_fontsize)
#         plt.ylim(1, 1e7)
#         cms_label(ax)
#         sample_label(ax, sample, str(pid))
#         plt.xlabel("Particle " + plot_utils.labels["pt"])
#         plt.ylabel("Count")
#         save_img("{}_particle_{}_pt.png".format(sample.replace("/", "_"), pid), cp_dir=Path("./"))
#         plt.show()

In [None]:
# AK4 jet pT spectra (Pythia gen-jets vs. pileup-masked target) with a target/Pythia ratio panel.
# The commented-out variants plotted the jets_pt_cand ("PF") and jets_pt_target ("Target") keys.
for sample in sample_keys:
    f, (a0, a1) = plt.subplots(2, 1, gridspec_kw={"height_ratios": [3, 1]}, sharex=True)
    h_pythia = ret[f"{sample}/jets_pt_genjet"]
    h_target = ret[f"{sample}/jets_pt_target_pumask"]

    # Main panel: absolute spectra.
    mplhep.histplot(h_pythia, label="Pythia", lw=2, color=pythia_color, ls=pythia_linestyle, ax=a0)
    mplhep.histplot(h_target, label="Target", lw=2, color=target_color, ls=target_linestyle, ax=a0)
    a0.set_xscale("log")
    a0.legend(fontsize=legend_fontsize)
    mplhep.cms.label("", data=False, com=14, year='Run 3', ax=a0)
    sample_label(a0, sample)
    a0.text(jet_label_coords[0], jet_label_coords[1], jet_label_ak4, transform=a0.transAxes, fontsize=addtext_fontsize)
    a0.set_yscale("log")
    a0.set_ylabel("Count")
    a0.set_ylim(1, 1e8)

    # Ratio panel: per-bin target / Pythia (the x axis is shared with the main panel).
    mplhep.histplot(h_target / h_pythia, lw=2, color=target_color, ls=target_linestyle, ax=a1)
    a1.set_ylim(0, 2)
    a1.set_xlim(3, 2000)
    a1.set_ylabel("Tgt. / Pythia")
    a1.axhline(1.0, color="black", ls="--")
    a1.set_xlabel("Jet " + plot_utils.labels["pt"])
    f.savefig("{}_jet_pt.pdf".format(sample.replace("/", "_")), bbox_inches="tight")
    plt.show()

In [None]:
# Jet pT response of the pileup-masked target, one single-panel figure per sample.
rebin = 1
for sample in sample_keys:
    fig = plt.figure()
    ax = plt.axes()
    mplhep.histplot(ret[f"{sample}/jets_pt_ratio_target_pumask"][bh.rebin(rebin)], yerr=False, label="Target", lw=2, color=target_color, ls=target_linestyle, ax=ax)
    ax.legend(fontsize=legend_fontsize)
    mplhep.cms.label("", data=False, com=14, year='Run 3', ax=ax)
    sample_label(ax, sample)
    # Annotate this figure's own axes. Referencing an axes object from another cell
    # would draw on a figure that has already been shown and saved.
    ax.text(jet_label_coords_single[0], jet_label_coords_single[1], jet_label_ak4, transform=ax.transAxes, fontsize=addtext_fontsize)
    ax.set_yscale("log")
    ax.set_ylim(1, 1e8)
    ax.set_xlabel("Jet " + plot_utils.labels["pt_response"])
    ax.set_ylabel("Count")
    fig.savefig("{}_jet_response.pdf".format(sample.replace("/", "_")), bbox_inches="tight")
    plt.show()

In [None]:
# MET distributions (Pythia vs. pileup-masked target) with a target/Pythia ratio panel.
rebin = 1
for sample in sample_keys:
    f, (a0, a1) = plt.subplots(2, 1, gridspec_kw={"height_ratios": [3, 1]}, sharex=True)
    # Rebin once and reuse the result in both panels.
    h_pythia = ret[f"{sample}/met_pythia"][bh.rebin(rebin)]
    h_target = ret[f"{sample}/met_target_pumask"][bh.rebin(rebin)]

    # Main panel: absolute distributions.
    mplhep.histplot(h_pythia, yerr=False, label="Pythia", lw=2, color=pythia_color, ls=pythia_linestyle, ax=a0)
    mplhep.histplot(h_target, yerr=False, label="Target", lw=2, color=target_color, ls=target_linestyle, ax=a0)
    a0.legend(loc=(0.65, 0.7), fontsize=legend_fontsize)
    a0.set_yscale("log")
    a0.set_xscale("log")
    mplhep.cms.label("", data=False, com=14, year='Run 3', ax=a0)
    sample_label(a0, sample)
    a0.set_ylim(1, 1e8)
    a0.set_ylabel("Count")

    # Ratio panel: per-bin target / Pythia (the x axis is shared with the main panel).
    mplhep.histplot(h_target / h_pythia, lw=2, color=target_color, ls=target_linestyle, ax=a1)
    a1.set_ylim(0, 2)
    a1.set_ylabel("Tgt. / Pythia")
    a1.axhline(1.0, color="black", ls="--")
    a1.set_xlabel(plot_utils.labels["met"])

    f.savefig("{}_met.pdf".format(sample.replace("/", "_")), bbox_inches="tight")
    plt.show()