In [None]:
import os

import tensorflow_datasets as tfds
import vector
import awkward
import numpy as np
import fastjet
import tqdm

import matplotlib.pyplot as plt

In [None]:
# Copied from mlpf/heptfds/cms_pf/cms_utils.py.

# Class index -> particle-id code. The event loop maps the integer "typ_idx"
# column of the target array through this list to get each particle's pid;
# index 0 maps to pid 0.
CLASS_LABELS_CMS = [0, 211, 130, 1, 2, 22, 11, 13]

# Column names of the per-particle target array, in column order
# (e.g. sin_phi is column 4 and cos_phi is column 5).
Y_FEATURES = ["typ_idx", "charge", "pt", "eta", "sin_phi", "cos_phi", "e", "ispu"]

In [None]:
# Dataset name:version to load; also used as the title of the plots in this notebook.
ds_string = "cms_pf_ttbar:2.2.0"

# Root directory of the prepared TFDS datasets. Set the TFDS_DATA_DIR environment
# variable to point elsewhere instead of editing a hard-coded personal path.
data_dir = os.environ.get("TFDS_DATA_DIR", "/scratch/persistent/joosep/tensorflow_datasets/")

builder = tfds.builder(ds_string, data_dir=data_dir)
# Indexable view of the train split: events are read below as ds_train[i].
ds_train = builder.as_data_source(split="train")

In [None]:
# Number of events to read from the train split. Clamped to the split size so a
# smaller split does not raise IndexError.
N_EVENTS = min(5000, len(ds_train))

# Class index ("typ_idx" column) -> pid lookup table; built once instead of per event.
class_labels = np.array(CLASS_LABELS_CMS)

# Per-event containers; the next cell turns them into jagged awkward arrays.
all_genjets = []
all_genparticles_awk = []
all_Xs = []

for iev in tqdm.tqdm(range(N_EVENTS)):
    el = ds_train[iev]

    # genjet columns used here: 0 = pt, 1 = eta, 2 = phi, 3 = e
    genjets = vector.awk(awkward.zip({
        "pt": el["genjets"][:, 0],
        "eta": el["genjets"][:, 1],
        "phi": el["genjets"][:, 2],
        "e": el["genjets"][:, 3],
    }))

    # target ("ygen") columns are indexed following the order in Y_FEATURES
    genparticles = el["ygen"]

    # phi is stored as (sin_phi, cos_phi) in columns 4 and 5; arctan2(sin, cos) recovers it
    gp_phi = np.arctan2(genparticles[:, 4], genparticles[:, 5])
    genparticles_p4 = vector.awk(awkward.zip({
        "pt": genparticles[:, 2],
        "eta": genparticles[:, 3],
        "phi": gp_phi,
        "e": genparticles[:, 6],
    }))
    gp_ispu = genparticles[:, 7]
    gp_pid = class_labels[genparticles[:, 0].astype(np.int64)]
    genparticles_awk = awkward.Array({
        "pid": gp_pid,
        "p4": genparticles_p4,
        "ispu": gp_ispu,
    })

    all_Xs.append(el["X"])
    all_genjets.append(genjets)
    all_genparticles_awk.append(genparticles_awk)

In [None]:
# Turn the per-event lists into jagged awkward arrays (outer axis = event):
# concatenate every event into one flat array, then split it back into events
# using the per-event lengths recorded first.
# NOTE: this rebinds the list names to awkward arrays, so this cell is not
# idempotent: run it exactly once after the event loop.
n_elems_per_event = [len(x) for x in all_Xs]
n_genjets_per_event = [len(jets_ev) for jets_ev in all_genjets]
n_genparticles_per_event = [len(gp_ev) for gp_ev in all_genparticles_awk]

flat_Xs = awkward.from_numpy(np.concatenate(all_Xs, axis=0))
all_Xs = awkward.unflatten(flat_Xs, counts=n_elems_per_event)
all_genjets = awkward.unflatten(awkward.concatenate(all_genjets), counts=n_genjets_per_event)
all_genparticles_awk = awkward.unflatten(
    awkward.concatenate(all_genparticles_awk), counts=n_genparticles_per_event
)

# Keep only gen particles with a non-zero pid (class index 0 maps to pid 0;
# presumably "no gen particle" for that element — TODO confirm).
all_genparticles_no0 = all_genparticles_awk[all_genparticles_awk["pid"] != 0]

In [None]:
# Momentum four-vectors (pt, eta, phi, E) of the non-zero-pid gen particles, used as
# the fastjet input below. On the stored vectors, rho is the transverse momentum and
# t is the energy.
gp_p4 = all_genparticles_no0.p4
p4 = vector.awk(awkward.zip({
    "pt": gp_p4.rho,
    "eta": gp_p4.eta,
    "phi": gp_p4.phi,
    "e": gp_p4.t,
}))

In [None]:
# anti-kT jets with R = 0.4, keeping jets above JET_MIN_PT (units as stored in the dataset).
JET_MIN_PT = 5
jetdef = fastjet.JetDefinition(fastjet.antikt_algorithm, 0.4)


def cluster_inclusive_jets(particles_xyzt, jet_def=jetdef, min_pt=JET_MIN_PT):
    """Cluster per-event (px, py, pz, E) particle vectors and return the inclusive jets.

    particles_xyzt: jagged array of Lorentz vectors in x/y/z/t form (one list per event).
    Returns the jets with pt >= min_pt, as produced by fastjet's inclusive_jets.
    """
    cluster = fastjet.ClusterSequence(particles_xyzt, jet_def)
    return cluster.inclusive_jets(min_pt=min_pt)


# Convert once and reuse for both clusterings.
p4_xyzt = p4.to_xyzt()

# Jets from all non-zero-pid gen particles.
jets = cluster_inclusive_jets(p4_xyzt)
# Jets from the same particles restricted to ispu == 0.
jets_nopu = cluster_inclusive_jets(p4_xyzt[all_genparticles_no0.ispu == 0])

In [None]:
# Jet-pt spectra: the dataset's stored genjets vs. jets re-clustered from gen particles
# (all non-zero-pid particles, and only those with ispu == 0).
pt_bins = np.logspace(0, 3, 100)

fig, ax = plt.subplots()
for jet_pt_values, series_label in [
    (all_genjets.rho, "genjets"),
    (jets.pt, "all gp jets"),
    (jets_nopu.pt, "ispu=0 gp jets"),
]:
    ax.hist(awkward.flatten(jet_pt_values), bins=pt_bins, histtype="step", label=series_label)
ax.legend()
ax.set_yscale("log")
ax.set_xscale("log")
ax.set_xlabel("jet pt")
ax.set_ylabel("number of jets")
ax.set_title(ds_string)

In [None]:
# Elements whose X feature 0 equals 1 (presumably the track element type — TODO confirm).
is_track = all_Xs[:, :, 0] == 1
# Gen-level target assigned to each of those elements.
gen_for_tracks = all_genparticles_awk[is_track]

# Flat arrays (all events concatenated), one entry per track element, in matching order.
gen_pid = awkward.flatten(gen_for_tracks["pid"])
gen_pt = awkward.flatten(gen_for_tracks["p4"].rho)
# X feature 1 of each track element (presumably the track pt — TODO confirm).
track_pt = awkward.flatten(all_Xs[is_track][:, :, 1])

In [None]:
# Track pt vs. pt of the assigned gen particle, only for tracks with a non-zero gen pid.
log_bins = np.logspace(-1, 2, 100)
has_gen = gen_pid != 0

fig, ax = plt.subplots(figsize=(6, 6))
ax.hist2d(
    awkward.to_numpy(track_pt[has_gen]),
    awkward.to_numpy(gen_pt[has_gen]),
    bins=log_bins,
)
# y = x reference line over the plotted range
ax.plot([0.1, 100], [0.1, 100], color="black", ls="--", lw=1.0)
ax.set_xscale("log")
ax.set_yscale("log")
ax.set_xlabel("track pt")
ax.set_ylabel("gen pt")
ax.set_title(ds_string)

In [None]:
# Fraction of tracks whose assigned gen particle has a non-zero pid, in log-spaced track-pt bins.
bins = np.logspace(-1, 3, 20)

track_pt_np = awkward.to_numpy(track_pt)
has_gen_np = awkward.to_numpy(gen_pid != 0)

# Per-bin counts. np.histogram bins are half-open [b0, b1), except the last bin,
# which also includes bins[-1].
n_tracks, _ = np.histogram(track_pt_np, bins=bins)
n_tracks_gen, _ = np.histogram(track_pt_np[has_gen_np], bins=bins)

# Bins with no tracks get NaN (drawn as a gap) instead of a 0/0 RuntimeWarning.
fracs_gen = np.divide(
    n_tracks_gen,
    n_tracks,
    out=np.full(n_tracks.shape, np.nan),
    where=n_tracks > 0,
)

# Geometric bin centres: the visual midpoint of each bin on a log x axis.
bin_centers = np.sqrt(bins[:-1] * bins[1:])

fig, ax = plt.subplots()
ax.plot(bin_centers, fracs_gen, marker="o")
ax.set_xscale("log")
ax.set_xlabel("track pT")
ax.set_ylabel("fraction of tracks matched to gen")
ax.set_ylim(0, 1)
ax.set_title(ds_string);