In [4]:
import time
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import mplhep as hep

import os
import pickle as pkl
import sys
sys.path.insert(0,'..')
from models import ParticleNet
import torch
import torch.nn as nn
from torch_geometric.loader import DataListLoader, DataLoader
from torch_geometric.data import Data, Batch

import numpy as np
import matplotlib.pyplot as plt
from itertools import cycle

from sklearn import svm, datasets
from sklearn.metrics import roc_curve, auc
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize
from sklearn.multiclass import OneVsRestClassifier
from sklearn.metrics import roc_auc_score
%matplotlib inline

import mplhep as hep
plt.style.use(hep.style.CMS)
plt.rcParams.update({'font.size': 20})

import awkward as ak
import fastjet
import vector

In [8]:
# Location of the top-tagging test set (HDF5 file, read below with key "table").
# Set the TOPTAGGING_TEST_H5 environment variable to point at a local copy.
datapath = os.environ.get("TOPTAGGING_TEST_H5", "/xai4hepvol/toptagging/test/raw/test.h5")

In [10]:

# Number of exclusive subjets each jet is declustered into.
N_SUBJETS = 3

# Legend labels (matplotlib mathtext) per clustering algorithm.
# Raw strings: "\m" is not a valid Python escape and warns on Python >= 3.12.
# The doubled braces are literal characters here (these are not f-strings).
LABEL = {
    fastjet.kt_algorithm: r"$k_{{\mathrm{{T}}}}$",
    fastjet.antikt_algorithm: r"anti-$k_{{\mathrm{{T}}}}$",
    fastjet.cambridge_algorithm: "CA",
}

# Algorithm used to recluster each jet's constituents into subjets.
JET_ALGO = fastjet.kt_algorithm

# Half-open row range [START_ROW, STOP_ROW) of jets to read from the HDF5 table.
START_ROW, STOP_ROW = 1000, 1002

df = pd.read_hdf(datapath, key="table", start=START_ROW, stop=STOP_ROW)
print(f"{df['is_signal_new'].values=}")


def _col_list(prefix, max_particles=200):
    return ["%s_%d" % (prefix, i) for i in range(max_particles)]


# Register vector's awkward behaviours so records named "Momentum4D"
# get kinematic properties/methods (e.g. pt, deltaeta, deltaphi).
vector.register_awkward()


def _jagged_column(prefix):
    """Load the ``<prefix>_*`` columns of ``df`` as an awkward array with a var-length inner axis."""
    dense = ak.from_numpy(df[_col_list(prefix)].values)
    return ak.from_regular(dense, axis=-1)


px, py, pz, e = (_jagged_column(component) for component in ("PX", "PY", "PZ", "E"))
# Only constituent slots with positive energy are kept downstream
# (presumably the rest are zero-padding -- confirm against the dataset spec).
mask = e > 0


df['is_signal_new'].values=array([1, 1])


In [None]:

def _constituent_array(subjet_idx):
    """Zip the kept (``e > 0``) constituents of every jet into ``Momentum4D`` records.

    Parameters
    ----------
    subjet_idx
        Per-constituent subjet assignment; must have the same jagged structure
        as ``px[mask]`` (one entry per kept constituent).

    Returns
    -------
    ak.Array
        Records with fields ``px``, ``py``, ``pz``, ``E``, ``mask``,
        ``particle_idx`` (position within the jet) and ``subjet_idx``.
    """
    return ak.zip(
        {
            "px": px[mask],
            "py": py[mask],
            "pz": pz[mask],
            "E": e[mask],
            "mask": mask[mask],
            "particle_idx": ak.local_index(px[mask]),
            "subjet_idx": subjet_idx,
        },
        with_name="Momentum4D",
    )


def _four_momentum(pj):
    """Return ``(px, py, pz, E)`` of a ``fastjet.PseudoJet`` as a hashable key."""
    return (pj.px(), pj.py(), pj.pz(), pj.E())


# Start with every constituent unassigned (subjet_idx == -1).
array = _constituent_array(ak.zeros_like(px[mask], dtype=int) - 1)

# One list of PseudoJets per jet, in the same constituent order as `array`,
# so list positions line up with `particle_idx`.
pseudojets = [
    [
        fastjet.PseudoJet(particle.px, particle.py, particle.pz, particle.E)
        for particle in jet
    ]
    for jet in array
]
print(f"{len(pseudojets)=}")
print(f"{len(pseudojets[0])=}")

JET_R = 0.8  # clustering radius passed to the jet definition
jetdef = fastjet.JetDefinition(JET_ALGO, JET_R)

subjet_indices = []  # subjet_indices[ijet][isubjet] -> constituent positions
mapping = array.subjet_idx.to_list()  # per-jet subjet assignment, filled below
for ijet, pseudojet in enumerate(pseudojets):
    cluster = fastjet.ClusterSequence(pseudojet, jetdef)

    # Recluster: the analysis expects all constituents to form exactly one jet.
    jets = cluster.inclusive_jets()
    print(f"{len(jets)=}")
    assert len(jets) == 1

    # Decluster into N_SUBJETS exclusive subjets, ordered by descending pt.
    subjets = cluster.exclusive_subjets(jets[0], N_SUBJETS)
    print(f"{len(subjets)=}")
    assert len(subjets) == N_SUBJETS
    subjets = sorted(subjets, key=lambda sj: sj.pt(), reverse=True)

    # Subjet constituents are identified by exact four-momentum equality with
    # the input PseudoJets. Index the inputs by four-momentum once so each
    # lookup is O(1) instead of a scan over the whole jet; a list per key
    # keeps every position if two inputs share the same four-momentum.
    positions = {}
    for idx, jet_const in enumerate(pseudojet):
        positions.setdefault(_four_momentum(jet_const), []).append(idx)

    local_mapping = np.array(mapping[ijet])
    subjet_indices.append([])
    for subjet_idx, subjet in enumerate(subjets):
        indices = [
            idx
            for subjet_const in subjet.constituents()
            for idx in positions.get(_four_momentum(subjet_const), [])
        ]
        subjet_indices[-1].append(indices)
        print(indices)
        local_mapping[indices] = subjet_idx
    # Store a plain list so ak.Array below builds from ordinary Python ints.
    mapping[ijet] = local_mapping.tolist()

# Rebuild the record array with the subjet assignment filled in.
array = _constituent_array(ak.Array(mapping))

import matplotlib.pyplot as plt
import numpy as np
import matplotlib.colors as colors

# Colour map per subjet, leading-pt subjet first. Cycled so every subjet is
# drawn even when N_SUBJETS exceeds the number of maps (plain zip would drop it).
SUBJET_CMAPS = ["Blues", "Reds", "Greens", "Purples"]
# Log colour-scale range for the constituent pT fraction.
DPT_VMIN, DPT_VMAX = 0.001, 0.1

for i, arr in enumerate(array):
    fig, ax = plt.subplots()

    # Jet four-vector: sum over all constituents stored for this jet.
    jet_vector = vector.obj(
        px=ak.sum(arr.px, axis=-1),
        py=ak.sum(arr.py, axis=-1),
        pz=ak.sum(arr.pz, axis=-1),
        E=ak.sum(arr.E, axis=-1),
    )
    # Subjet four-vectors: sums over the constituents assigned to each subjet.
    subjet_vectors = [
        vector.obj(
            px=ak.sum(arr.px[arr.subjet_idx == j], axis=-1),
            py=ak.sum(arr.py[arr.subjet_idx == j], axis=-1),
            pz=ak.sum(arr.pz[arr.subjet_idx == j], axis=-1),
            E=ak.sum(arr.E[arr.subjet_idx == j], axis=-1),
        )
        for j in range(N_SUBJETS)
    ]

    # Constituent positions relative to the jet axis and their pT fraction.
    deta = arr.deltaeta(jet_vector)
    dphi = arr.deltaphi(jet_vector)
    dpt = arr.pt / jet_vector.pt

    for j, cmap in zip(range(N_SUBJETS), cycle(SUBJET_CMAPS)):
        in_subjet = arr.subjet_idx == j
        ax.scatter(
            deta[in_subjet],
            dphi[in_subjet],
            c=dpt[in_subjet],
            s=10,
            norm=colors.LogNorm(vmin=DPT_VMIN, vmax=DPT_VMAX),
            cmap=cmap,
            # rf-string: "\m" stays literal for mathtext; "{{"/"}}" are f-string brace escapes.
            label=rf"{LABEL[JET_ALGO]} Subjet {j} $p_{{\mathrm{{T}}}}$={subjet_vectors[j].pt:.0f}",
        )
    # Empty greyscale scatter used only as the colorbar's mappable, so the bar
    # is drawn in neutral grey rather than one subjet's colour map.
    grey_scale = ax.scatter(
        [], [], c=[], cmap="Greys", norm=colors.LogNorm(vmin=DPT_VMIN, vmax=DPT_VMAX)
    )
    cbar = fig.colorbar(grey_scale, ax=ax)
    cbar.set_label(r"$p_{\mathrm{T}} / p_{\mathrm{T}}^{\mathrm{jet}}$")
    ax.set_xlabel(r"$\Delta\eta$")
    ax.set_ylabel(r"$\Delta\phi$")
    ax.set_ylim(-1.2, 1.2)
    ax.set_xlim(-1.2, 1.2)
    ax.legend(title="Top jet" if df["is_signal_new"].values[i] == 1 else "QCD jet")
    fig.savefig(f"etaphi_{i}.pdf")
    fig.savefig(f"etaphi_{i}.png")
    # Display this figure now; with the inline backend this also closes it,
    # so looping over many jets does not accumulate open figures.
    plt.show()