In [None]:
import os

import tensorflow_datasets as tfds
import vector
import awkward
import numpy as np
import fastjet
import tqdm

import matplotlib.pyplot as plt

In [None]:
# Target-side constants, copied from mlpf/heptfds/cms_pf/cms_utils.py.

# Class index -> particle id. ds_to_awk maps the "typ_idx" target column through
# this list and drops rows whose id is 0. 211/130/22/11/13 look like PDG-style ids
# (charged hadron, neutral hadron, photon, electron, muon); 1 and 2 are presumably
# HF-specific classes -- TODO confirm against cms_utils.py.
CLASS_LABELS_CMS = [
    0,    # index 0
    211,  # index 1
    130,  # index 2
    1,    # index 3
    2,    # index 4
    22,   # index 5
    11,   # index 6
    13,   # index 7
]

# Column layout of the per-particle target matrix ("ytarget"), in order.
Y_FEATURES = ["typ_idx", "charge", "pt", "eta", "sin_phi", "cos_phi", "e", "ispu"]

In [None]:
# Dataset name:version, and the TFDS directory it is read from. Set the
# TFDS_DATA_DIR environment variable to point at another location; the default
# is the original cluster path, so behaviour is unchanged when it is unset.
ds_string = "cms_pf_qcd/10:2.6.0"
DATA_DIR = os.environ.get("TFDS_DATA_DIR", "/scratch/persistent/joosep/tensorflow_datasets/")
builder = tfds.builder(ds_string, data_dir=DATA_DIR)
# Random-access sources: events are fetched by index (ds[i]) in ds_to_awk.
ds_train = builder.as_data_source(split="train")
ds_test = builder.as_data_source(split="test")

In [None]:
def ds_to_awk(ds, max_events=1000):
    """Convert the target gen-particles of a TFDS data source into a jagged awkward array.

    For each of the first ``max_events`` events, the ``"ytarget"`` matrix (one row
    per particle, columns ordered as in ``Y_FEATURES``) is turned into records with
    fields ``pid``, ``p4`` (a ``vector`` 4-vector built from pt, eta, phi, e) and
    ``ispu``. Rows whose class index maps to pid 0 in ``CLASS_LABELS_CMS`` are dropped.

    Args:
        ds: random-access data source (e.g. from ``builder.as_data_source``) whose
            elements are dicts holding a 2D ``"ytarget"`` array.
        max_events: number of events read from the start of ``ds``, clamped to
            ``len(ds)``. ``None`` reads the whole data source. Defaults to 1000,
            the value that used to be hard-coded.

    Returns:
        awkward.Array: one variable-length list of particle records per event.

    Raises:
        ValueError: if no events are selected (empty ``ds`` or ``max_events <= 0``).
    """
    n_events = len(ds) if max_events is None else min(max_events, len(ds))
    if n_events <= 0:
        raise ValueError("ds_to_awk: no events selected (empty data source or max_events <= 0)")

    # Look up target columns by name rather than by hard-coded position.
    col = {name: i for i, name in enumerate(Y_FEATURES)}
    pid_lookup = np.array(CLASS_LABELS_CMS)

    # Only "ytarget" is needed. Read it per event and build the awkward structure
    # once at the end instead of once per event.
    ytargets = [ds[iev]["ytarget"] for iev in tqdm.tqdm(range(n_events))]
    counts = [len(yt) for yt in ytargets]
    gp = np.concatenate(ytargets, axis=0)

    # The target stores sin(phi), cos(phi); arctan2(sin, cos) recovers phi.
    gp_phi = np.arctan2(gp[:, col["sin_phi"]], gp[:, col["cos_phi"]])
    gp_p4 = vector.awk(
        awkward.zip(
            {
                "pt": gp[:, col["pt"]],
                "eta": gp[:, col["eta"]],
                "phi": gp_phi,
                "e": gp[:, col["e"]],
            }
        )
    )
    gp_flat = awkward.Array(
        {
            "pid": pid_lookup[gp[:, col["typ_idx"]].astype(np.int64)],
            "p4": gp_p4,
            "ispu": gp[:, col["ispu"]],
        }
    )

    genparticles = awkward.unflatten(gp_flat, counts=counts)
    # pid 0 rows are not real particles for this analysis; drop them.
    return genparticles[genparticles["pid"] != 0]

In [None]:
# Apply the same gen-particle conversion to the train and test splits, in that order.
genparticles_train, genparticles_test = (ds_to_awk(split_ds) for split_ds in (ds_train, ds_test))

In [None]:
# Normalised distribution of the per-particle "ispu" value (plotted as a PU
# fraction in [0, 1]) for the train and test gen particles.
pu_bins = np.linspace(0, 1, 100)
fig, ax = plt.subplots()
for split_label, split_gps in [("train", genparticles_train), ("test", genparticles_test)]:
    ax.hist(awkward.flatten(split_gps["ispu"]), bins=pu_bins, density=True, label=split_label, histtype="step")
ax.set_yscale("log")
ax.set_xlabel("PU frac")
ax.set_ylabel("density")
ax.set_title(ds_string)
ax.legend(loc="best");

In [None]:
# Standalone pt/eta/phi/e 4-vectors of the training gen particles, used as
# fastjet input. For vector 4-vectors, rho is the transverse magnitude (pt) and
# t is the time-like component (energy).
p4 = vector.awk(
    awkward.zip(
        dict(
            pt=genparticles_train.p4.rho,
            eta=genparticles_train.p4.eta,
            phi=genparticles_train.p4.phi,
            e=genparticles_train.p4.t,
        )
    )
)

In [None]:
# Anti-kt clustering of the training gen particles: once with all particles and
# once keeping only particles with ispu == 0.
JET_R = 0.4       # anti-kt distance parameter
JET_MIN_PT = 5    # minimum jet pt passed to inclusive_jets

jetdef = fastjet.JetDefinition(fastjet.antikt_algorithm, JET_R)
p4_xyzt = p4.to_xyzt()  # convert once and reuse for both clusterings

cluster_all = fastjet.ClusterSequence(p4_xyzt, jetdef)
jets = cluster_all.inclusive_jets(min_pt=JET_MIN_PT)

# The mask is built from genparticles_train, the same array p4 was made from,
# so it lines up with p4_xyzt element by element.
cluster_nopu = fastjet.ClusterSequence(p4_xyzt[genparticles_train.ispu == 0], jetdef)
jets_nopu = cluster_nopu.inclusive_jets(min_pt=JET_MIN_PT)

In [None]:
# Jet pt spectra for both clusterings, with 100 log-spaced bin edges from 1 to 1000.
pt_bins = np.logspace(0, 3, 100)
fig, ax = plt.subplots()
ax.hist(awkward.flatten(jets.pt), bins=pt_bins, histtype="step", label="all gp jets")
ax.hist(awkward.flatten(jets_nopu.pt), bins=pt_bins, histtype="step", label="ispu=0 gp jets")
ax.legend()
ax.set_yscale("log")
ax.set_xscale("log")
ax.set_xlabel("jet pt")
ax.set_ylabel("number of jets")
ax.set_title(ds_string);