In [None]:
import os

import awkward
import fastjet
import matplotlib.pyplot as plt
import numpy as np
import tensorflow_datasets as tfds
import vector

In [None]:
# Copied from mlpf/heptfds/cms_pf/cms_utils.py

# Lookup table from class index to particle label; the position in this list is
# the class index.
# NOTE(review): most entries look like PDG particle IDs — confirm the meaning of
# 0, 1 and 2 against cms_utils.py.
CLASS_LABELS_CMS = [
    0,
    211,
    130,
    1,
    2,
    22,
    11,
    13,
]

# Names of the per-particle target features, in order.
# NOTE(review): presumably this is also the column order of the "ygen" arrays
# stored in the dataset — verify against the dataset builder.
Y_FEATURES = "typ_idx charge pt eta sin_phi cos_phi e ispu".split()

In [None]:
# Directory holding the prepared TFDS datasets. The TFDS_DATA_DIR environment
# variable overrides it, so the notebook can run on machines where the default
# (user-specific) scratch path does not exist.
TFDS_DATA_DIR = os.environ.get("TFDS_DATA_DIR", "/scratch/persistent/joosep/tensorflow_datasets/")

# Open the "cms_pf_ttbar" dataset and expose its train split as a data source
# that can be indexed event by event.
builder = tfds.builder("cms_pf_ttbar", data_dir=TFDS_DATA_DIR)
ds_train = builder.as_data_source(split="train")

In [None]:
# Collect the generator-level jets and generator particles of the first events
# of the train split into two awkward arrays (one entry per event).

# Upper bound on the number of events to read; capped by the dataset size below
# so a smaller dataset does not raise an IndexError.
N_EVENTS = 100

# Column positions of the per-particle features, looked up by name instead of
# hard-coded integers.
# NOTE(review): assumes the columns of el["ygen"] follow the Y_FEATURES order —
# confirm against the dataset builder.
ycol = {name: idx for idx, name in enumerate(Y_FEATURES)}

# Class index -> label lookup table; built once because it does not depend on the event.
class_labels = np.array(CLASS_LABELS_CMS)

genjets_per_event = []
genparticles_per_event = []

# loop over some events in the dataset
for iev in range(min(N_EVENTS, len(ds_train))):
    el = ds_train[iev]
    print(len(el["X"]), el.keys())

    # genjet columns are read as (pt, eta, phi, e)
    jets_raw = el["genjets"]
    genjets = vector.awk(
        awkward.zip(
            {
                "pt": jets_raw[:, 0],
                "eta": jets_raw[:, 1],
                "phi": jets_raw[:, 2],
                "e": jets_raw[:, 3],
            }
        )
    )

    # keep only the rows whose class index is non-zero
    ygen = el["ygen"]
    gp = ygen[ygen[:, ycol["typ_idx"]] != 0]

    # sin(phi), cos(phi) -> phi
    gp_phi = np.arctan2(gp[:, ycol["sin_phi"]], gp[:, ycol["cos_phi"]])
    gp_p4 = vector.awk(
        awkward.zip(
            {
                "pt": gp[:, ycol["pt"]],
                "eta": gp[:, ycol["eta"]],
                "phi": gp_phi,
                "e": gp[:, ycol["e"]],
            }
        )
    )
    # translate the class index of each particle into its label
    gp_pid = class_labels[gp[:, ycol["typ_idx"]].astype(np.int64)]
    gp_ispu = gp[:, ycol["ispu"]]

    genjets_per_event.append(genjets)
    genparticles_per_event.append(
        awkward.Record(
            {
                "pid": gp_pid,
                "p4": gp_p4,
                "ispu": gp_ispu,
            }
        )
    )

all_genjets = awkward.from_iter(genjets_per_event)
all_genparticles = awkward.from_iter(genparticles_per_event)

In [None]:
# Build a vector-aware four-momentum array for all generator particles from the
# rho/eta/phi/t fields stored in all_genparticles.p4.
# NOTE(review): presumably needed because awkward.from_iter dropped the vector
# behaviour of the per-event arrays — confirm.
stored_p4 = all_genparticles.p4
p4_fields = {
    "pt": stored_p4.rho,
    "eta": stored_p4.eta,
    "phi": stored_p4.phi,
    "e": stored_p4.t,
}
p4 = vector.awk(awkward.zip(p4_fields))

In [None]:
# Cluster the generator particles with the anti-kt algorithm (radius parameter 0.4),
# once from all particles and once from the particles with ispu == 0 only.
# Jets below min_pt=10 are dropped in both cases.
jetdef = fastjet.JetDefinition(fastjet.antikt_algorithm, 0.4)
p4_xyzt = p4.to_xyzt()

# jets from every generator particle
cluster = fastjet.ClusterSequence(p4_xyzt, jetdef)
jets = cluster.inclusive_jets(min_pt=10)

# jets from the ispu == 0 particles only; `cluster` is rebound to this second sequence
ispu_is_zero = all_genparticles.ispu == 0
cluster = fastjet.ClusterSequence(p4_xyzt[ispu_is_zero], jetdef)
jets_nopu = cluster.inclusive_jets(min_pt=10)

In [None]:
# Compare the pt spectrum of the genjets stored in the dataset with the jets
# reclustered from generator particles (all particles, and ispu == 0 only).
b = np.linspace(10, 100, 100)  # shared bin edges so the three histograms are comparable

# (per-event jet pt, legend label); `rho` is the transverse-momentum field of all_genjets
spectra = [
    (all_genjets.rho, "genjets"),
    (jets.pt, "all gp jets"),
    (jets_nopu.pt, "ispu=0 gp jets"),
]

fig, ax = plt.subplots()
for jet_pt, label in spectra:
    # flatten the per-event lists into a single list of jets
    ax.hist(awkward.flatten(jet_pt), bins=b, histtype="step", label=label)

# NOTE(review): units are not stated in this notebook — add them to the label once confirmed
ax.set_xlabel("jet pt")
ax.set_ylabel("number of jets")
ax.legend();