In [584]:
import json
import math
import os
import pickle as pkl
import time
from typing import Dict, List

import awkward as ak
import fastjet
import matplotlib
import matplotlib.pyplot as plt
import mplhep as hep
import numpy as np
import sklearn
import sklearn.metrics
import torch
import tqdm
import vector
from torch_geometric.data import Batch, Data

plt.style.use(hep.style.CMS)
plt.rcParams.update({"font.size": 20})

In [2]:
%load_ext autoreload
%autoreload 2

In [383]:
# import relevant functions from mlpf.pyg
import sys
sys.path.append("/home/jovyan/particleflow/mlpf/")
import pyg
sys.path.append("/home/jovyan/particleflow/mlpf/pyg/")
import utils

from PFDataset import PFDataset, PFDataLoader, Collater

from pyg.mlpf import MLPF
from pyg.utils import X_FEATURES, Y_FEATURES, unpack_predictions, unpack_target
from jet_utils import match_two_jet_collections

In [4]:
# define the global base device
world_size = 1
if torch.cuda.device_count():
    rank = 0
    device = torch.device("cuda:0")
    print(f"Will use {torch.cuda.get_device_name(device)}")
else:
    rank = "cpu"
    device = "cpu"
    print("Will use cpu")

Will use NVIDIA A100-SXM4-80GB


# Load the pre-trained MLPF model

In [10]:
def load_checkpoint(checkpoint, model, optimizer=None):
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model.module.load_state_dict(checkpoint["model_state_dict"])
    else:
        model.load_state_dict(checkpoint["model_state_dict"])
    if optimizer:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        return model, optimizer
    else:
        return model
    
    
loaddir = "/pfvol/experiments/MLPF_clic_A100_1gpu_pyg-clic_20240322_233518_004447"

with open(f"{loaddir}/model_kwargs.pkl", "rb") as f:
    mlpf_kwargs = pkl.load(f)

mlpf_kwargs["attention_type"] = "flash"

mlpf = MLPF(**mlpf_kwargs).to(torch.device(rank))
checkpoint = torch.load(f"{loaddir}/best_weights.pth", map_location=torch.device(rank))

mlpf = load_checkpoint(checkpoint, mlpf)
mlpf.eval()

print(mlpf)    

# CLIC dataset

In [6]:
! ls /pfvol/tensorflow_datasets/

clic_edm_qq_pf		   cms_pf_qcd_high_pt	   cms_pf_single_proton
clic_edm_ttbar_pf	   cms_pf_single_electron  cms_pf_single_tau
clic_edm_ttbar_pu10_pf	   cms_pf_single_gamma	   cms_pf_sms_t1tttt
clic_edm_ww_fullhad_pf	   cms_pf_single_mu	   cms_pf_ttbar
clic_edm_zh_tautau_pf	   cms_pf_single_neutron   cms_pf_ztt
cms_pf_multi_particle_gun  cms_pf_single_pi	   delphes_qcd_pf
cms_pf_qcd		   cms_pf_single_pi0	   delphes_ttbar_pf


In [8]:
# we can see the 17th features here (recall type is 1 for tracks and 2 for clusters)
X_FEATURES["clic"]

['type',
 'pt | et',
 'eta',
 'sin_phi',
 'cos_phi',
 'p | energy',
 'chi2 | position.x',
 'ndf | position.y',
 'dEdx | position.z',
 'dEdxError | iTheta',
 'radiusOfInnermostHit | energy_ecal',
 'tanLambda | energy_hcal',
 'D0 | energy_other',
 'omega | num_hits',
 'Z0 | sigma_x',
 'time | sigma_y',
 'Null | sigma_z']

In [9]:
# we can see the 8 gen features per pf element here (notice the jet_index which may be useful)
Y_FEATURES

['cls_id', 'charge', 'pt', 'eta', 'sin_phi', 'cos_phi', 'energy']

# Get the dataset (Events)

In [384]:
data_dir = "/home/jovyan/particleflow/tensorflow_datasets/"
sample = "clic_edm_ttbar_pf"

dataset_train = PFDataset(data_dir, f"{sample}:1.5.0", "train", num_samples=10_000)

batch_size = 100
pad_3d = True
train_loader = PFDataLoader(dataset_train.ds,
                               batch_size=batch_size,
                               collate_fn=Collater(["X", "ygen", "ycand"], pad_3d=pad_3d),
                              )

In [385]:
for batch in train_loader:
    batch = batch.to(rank, non_blocking=True)
    break
print(batch.X.shape)    

torch.Size([100, 246, 17])


# Pre-processing (Events -> Jets)

In [388]:
############################### set up forward hooks to retrive the latent representations of MLPF
latent_reps = {}
def get_activations(name):
    def hook(mlpf, input, output):
        latent_reps[name] = output.detach()

    return hook

mlpf.conv_reg[0].dropout.register_forward_hook(get_activations("conv_reg0"))
mlpf.conv_reg[1].dropout.register_forward_hook(get_activations("conv_reg1"))
mlpf.conv_reg[2].dropout.register_forward_hook(get_activations("conv_reg2"))
mlpf.nn_id.register_forward_hook(get_activations("nn_id"))    
###############################

def get_latent_reps(batch, latent_reps):
    for layer in latent_reps:
        if "conv" in layer:
            latent_reps[layer] *= batch.mask.unsqueeze(-1)

    latentX = torch.cat(
        [
            batch.X.to(rank),
            latent_reps["conv_reg0"],
            latent_reps["conv_reg1"],
            latent_reps["conv_reg2"],
            latent_reps["nn_id"],
        ],
        axis=-1,
    )
    return latentX

In [542]:
sample_to_lab = {
    "clic_edm_ttbar_pf": 1,
    "clic_edm_qq_pf": 0,   
}

jetdef = fastjet.JetDefinition(fastjet.ee_genkt_algorithm, 0.7, -1.0)
jet_ptcut = 15.0
jet_match_dr = 0.1
    
for i, batch in enumerate(train_loader):
    # initilize - will save on disk at the end of the loop
    jet_dataset = []

    # run the MLPF model in inference mode to get the MLPF cands / latent representations    
    batch = batch.to(rank, non_blocking=True)
    with torch.no_grad():
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
            ymlpf = mlpf(batch.X, batch.mask)
    ymlpf = unpack_predictions(ymlpf)
    
    # get the latent representations
    ymlpf["latentX"] = get_latent_reps(batch, latent_reps)

    for k, v in ymlpf.items():
        ymlpf[k] = v.detach().cpu()
    
    jets_coll = {}
    ####################### get the reco jet collection
    vec = vector.awk(
        ak.zip(
            {
                "pt": ymlpf["p4"][:, :, 0].to("cpu"),
                "eta": ymlpf["p4"][:, :, 1].to("cpu"),
                "phi": ymlpf["p4"][:, :, 2].to("cpu"),
                "e": ymlpf["p4"][:, :, 3].to("cpu"),
            }
        )
    )
    cluster = fastjet.ClusterSequence(vec.to_xyzt(), jetdef)
    jets_coll["reco"] = cluster.inclusive_jets(min_pt=jet_ptcut)
    #######################    
    
    ####################### get the gen jet collection
    ygen = unpack_target(batch.ygen)    
    vec = vector.awk(
        ak.zip(
            {
                "pt": ygen["p4"][:, :, 0].to("cpu"),
                "eta": ygen["p4"][:, :, 1].to("cpu"),
                "phi": ygen["p4"][:, :, 2].to("cpu"),
                "e": ygen["p4"][:, :, 3].to("cpu"),
            }
        )
    )
    cluster = fastjet.ClusterSequence(vec.to_xyzt(), jetdef)
    jets_coll["gen"] = cluster.inclusive_jets(min_pt=jet_ptcut)
    #######################
    
    matched_jets = match_two_jet_collections(jets_coll, "reco", "gen", jet_match_dr)
    
    # get the constituents to mask the MLPF candidates and build the input for the downstream
    genptcl_to_genjet_index = cluster.constituent_index(min_pt=jet_ptcut)    
    
    # build the big jet list
    for iev in tqdm.tqdm(range(len(matched_jets["gen"]))):
        
        num_matched_jets = len(matched_jets["gen"][iev])   # number of gen jets matched to reco
        
        for ijet in range(num_matched_jets):   
            igenjet = matched_jets["gen"][iev][ijet]
            irecojet = matched_jets["reco"][iev][ijet]

            if len(genptcl_to_genjet_index[iev][igenjet])<3:   # don't save jets with very few particles
                continue

            # build a mask tensor that will select the particles that belong to the gen jet
            msk_indices = genptcl_to_genjet_index[iev][igenjet].to_numpy()
            PADDIM = 256            

            jet_dataset += [

                dict(
                    # Target for jet tagging
                    gen_jet_label=torch.tensor(sample_to_lab[sample]).unsqueeze(0),
                    
                    # Target for jet p4 regression                    
                    gen_jet_pt=torch.tensor(jets_coll["gen"][iev][igenjet].pt).unsqueeze(0),
                    gen_jet_eta=torch.tensor(jets_coll["gen"][iev][igenjet].eta).unsqueeze(0),
                    gen_jet_phi=torch.tensor(jets_coll["gen"][iev][igenjet].phi).unsqueeze(0),
                    gen_jet_energy=torch.tensor(jets_coll["gen"][iev][igenjet].energy).unsqueeze(0),

                    # could be part of the target
                    reco_jet_pt=torch.tensor(jets_coll["reco"][iev][irecojet].pt).unsqueeze(0),
                    reco_jet_eta=torch.tensor(jets_coll["reco"][iev][irecojet].eta).unsqueeze(0),
                    reco_jet_phi=torch.tensor(jets_coll["reco"][iev][irecojet].phi).unsqueeze(0),
                    reco_jet_energy=torch.tensor(jets_coll["reco"][iev][irecojet].energy).unsqueeze(0),

                    # Input
                    mlpfcands_momentum=torch.nn.functional.pad(
                        ymlpf["momentum"][iev][msk_indices],
                        (0,0,0,PADDIM-ymlpf["momentum"][iev][msk_indices].shape[0]),
                        value=0,
                    ),
                    mlpfcands_pid=torch.nn.functional.pad(
                        ymlpf["cls_id_onehot"][iev][msk_indices],
                        (0,0,0,PADDIM-ymlpf["cls_id_onehot"][iev][msk_indices].shape[0]),
                        value=0,
                    ),
                    mlpfcands_charge=torch.nn.functional.pad(
                        ymlpf["charge"][iev][msk_indices],
                        (0,0,0,PADDIM-ymlpf["charge"][iev][msk_indices].shape[0]),
                        value=0,
                    ),                    
                    mlpfcands_latentX=torch.nn.functional.pad(
                        ymlpf["latentX"][iev][msk_indices],
                        (0,0,0,PADDIM-ymlpf["latentX"][iev][msk_indices].shape[0]),
                        value=0,
                    )
                )
            ]    

#             break  # per jet
#         break   # per event

    torch.save(jet_dataset, f"/pfvol/jetdataset/{sample}/train/{i}.pt")

    break    # per batch

100%|████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:39<00:00,  2.50it/s]


In [543]:
jet_dataset[0].keys()

dict_keys(['gen_jet_label', 'gen_jet_pt', 'gen_jet_eta', 'gen_jet_phi', 'gen_jet_energy', 'reco_jet_pt', 'reco_jet_eta', 'reco_jet_phi', 'reco_jet_energy', 'mlpfcands_momentum', 'mlpfcands_pid', 'mlpfcands_charge', 'mlpfcands_latentX'])

# Load the dataset

In [552]:
! du -sh /pfvol/jetdataset/clic_edm_ttbar_pf/train/0.pt

271M	/pfvol/jetdataset/clic_edm_ttbar_pf/train/0.pt


In [546]:
# load one of the train files
jet_dataset = torch.load("/pfvol/jetdataset/clic_edm_ttbar_pf/train/0.pt")

In [547]:
jet_dataset[0].keys()

dict_keys(['gen_jet_label', 'gen_jet_pt', 'gen_jet_eta', 'gen_jet_phi', 'gen_jet_energy', 'reco_jet_pt', 'reco_jet_eta', 'reco_jet_phi', 'reco_jet_energy', 'mlpfcands_momentum', 'mlpfcands_pid', 'mlpfcands_charge', 'mlpfcands_latentX'])

In [548]:
jetloader = torch.utils.data.dataloader.DataLoader(jet_dataset, batch_size=10)

In [549]:
for batch in jetloader:
    break

In [550]:
batch["gen_jet_pt"]

tensor([[22.4181],
        [23.0416],
        [27.6135],
        [62.2044],
        [21.0416],
        [52.7170],
        [62.5815],
        [56.3452],
        [80.0007],
        [29.5156]], dtype=torch.float64)

In [551]:
batch["reco_jet_pt"]

tensor([[23.6038],
        [24.1894],
        [34.3942],
        [66.5854],
        [21.3737],
        [59.5791],
        [70.3020],
        [59.1744],
        [82.2939],
        [30.3395]], dtype=torch.float64)

# Setup the downstream task

In [573]:
import torch.nn as nn

def ffn(input_dim, output_dim, width, act, dropout):
    return nn.Sequential(
        nn.Linear(input_dim, width),
        act(),
        torch.nn.LayerNorm(width),
        nn.Dropout(dropout),
        nn.Linear(width, output_dim),
    )

class JetRegressor(nn.Module):
    def __init__(
        self,
        input_dim=14,
        embedding_dim=64,
        output_dim=1,
        width=256,
        dropout=0,
    ):
        super(JetRegressor, self).__init__()

        """
        Takes as input either (1) the MLPF candidates OR (2) the latent representations of the MLPF candidates,
        and runs an MLP to predict an output per jet: "ptcorr"; which will enter the loss as follows:
            pred_jetpt = ptcorr * reco_pt

            LOSS = Huber(true_jetpt, pred_jetpt)

        """

        self.act = nn.ELU
        self.nn1 = ffn(input_dim, embedding_dim, width, self.act, dropout)
        self.nn2 = ffn(embedding_dim, output_dim, width, self.act, dropout)

    # @torch.compile
    def forward(self, X):

        embeddings = self.nn1(X)
        
        pooled_embeddings = embeddings.sum(axis=1)   # recall ~ [Batch, Particles, Features]

        return self.nn2(pooled_embeddings)

In [574]:
run_with_latentX = True

if run_with_latentX:
    input_dim = 791
else:
    input_dim = 14    
    
model = JetRegressor(input_dim).to(rank)
model.train()

JetRegressor(
  (nn1): Sequential(
    (0): Linear(in_features=791, out_features=256, bias=True)
    (1): ELU(alpha=1.0)
    (2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (3): Dropout(p=0, inplace=False)
    (4): Linear(in_features=256, out_features=64, bias=True)
  )
  (nn2): Sequential(
    (0): Linear(in_features=64, out_features=256, bias=True)
    (1): ELU(alpha=1.0)
    (2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (3): Dropout(p=0, inplace=False)
    (4): Linear(in_features=256, out_features=1, bias=True)
  )
)

In [585]:
for batch in jetloader:

    if run_with_latentX:
        X = batch["mlpfcands_latentX"].to(rank)
    else:
        X = torch.cat([batch["mlpfcands_momentum"], batch["mlpfcands_pid"], batch["mlpfcands_charge"]], axis=-1).to(rank)
    
    ptcorr = model(X).cpu()
    
    target = torch.log(batch["gen_jet_pt"] / batch["reco_jet_pt"])

    loss = torch.nn.functional.huber_loss(target, ptcorr)
    
    break
print(loss)

tensor(0.0240, dtype=torch.float64, grad_fn=<HuberLossBackward0>)
