In [None]:
%load_ext autoreload
%autoreload 2

import logging
from pathlib import Path

import matplotlib.pyplot as plt

import awkward as ak
import numpy as np
import vector
from omegaconf import OmegaConf

import torch
import gabbro
import gabbro.plotting.utils as plot_utils
from gabbro.plotting.feature_plotting import plot_features
from gabbro.utils.arrays import ak_select_and_preprocess
import tqdm

# Logging setup for Jupyter: logging.basicConfig() is a no-op when the root
# logger already has handlers, so the root level is also set explicitly to make
# sure INFO messages are emitted either way.
logger = logging.getLogger()
logging.basicConfig(level=logging.INFO)
logger.setLevel(logging.INFO)
logger.info("Setup complete")

# register vector's awkward behaviors (needed for the "Momentum4D" records used below)
vector.register_awkward()

In [None]:
# every model and tensor in this notebook is placed on the CPU
device = torch.device(type="cpu")

In [None]:
def reinitialize_p4(p4_obj: ak.Array):
    """Rebuild a vector-aware awkward array of four-momenta from a stored p4 record.

    The spatial components always come from the ``x``, ``y`` and ``z`` fields.
    The time-like component depends on which fields are present: ``tau`` is
    used as the mass when available, otherwise ``t`` is used as the energy.
    """
    if "tau" in p4_obj.fields:
        time_like = {"mass": p4_obj.tau}
    else:
        time_like = {"energy": p4_obj.t}
    components = {**time_like, "x": p4_obj.x, "y": p4_obj.y, "z": p4_obj.z}
    return vector.awk(ak.zip(components))

def deltaPhi(phi1, phi2):
    """Return the azimuthal difference ``phi1 - phi2`` wrapped into [-pi, pi].

    The wrap goes through the sine and cosine of the raw difference, so the
    inputs may differ by any multiple of 2*pi. Works element-wise on anything
    the numpy ufuncs accept.
    """
    raw_diff = phi1 - phi2
    sin_diff, cos_diff = np.sin(raw_diff), np.cos(raw_diff)
    return np.arctan2(sin_diff, cos_diff)

def deltaR_etaPhi(eta1, phi1, eta2, phi2):
    """Return the angular distance sqrt(deta^2 + dphi^2) in the (eta, phi) plane.

    The phi difference is wrapped into [-pi, pi] by ``deltaPhi``; the eta
    difference is the absolute separation from ``deltaEta``.
    """
    return np.sqrt(deltaEta(eta1, eta2) ** 2 + deltaPhi(phi1, phi2) ** 2)

def deltaEta(eta1, eta2):
    """Return the absolute pseudorapidity separation |eta1 - eta2| (element-wise)."""
    separation = eta1 - eta2
    return np.abs(separation)

def stack_and_pad_features(cand_features, max_cands):
    """Convert per-candidate features into a dense, zero-padded numpy array.

    Args:
        cand_features: awkward record array; each field is one feature with a
            variable-length list of candidates per entry (presumably per jet).
        max_cands: number of candidate slots per entry; shorter lists are padded
            with 0, longer lists are truncated.

    Returns:
        numpy array of shape (nJets, nFeatures, max_cands), with the features in
        the order of ``cand_features.fields``. NaN and +/-inf entries are replaced by 0.
    """
    per_feature = [
        # Pad/clip to a regular length and fill the padded slots *before* leaving
        # awkward, so the padding is always 0 and never depends on how numpy
        # treats option/masked arrays during stacking.
        ak.to_numpy(ak.fill_none(ak.pad_none(cand_features[feat], max_cands, clip=True), 0))
        for feat in cand_features.fields
    ]
    # stack -> (nJets, max_cands, nFeatures); swap -> (nJets, nFeatures, max_cands)
    cand_features_tensors = np.swapaxes(np.stack(per_feature, axis=-1), 1, 2)
    return np.nan_to_num(cand_features_tensors, nan=0.0, posinf=0.0, neginf=0.0)

# Tokenization with the VQ-VAE

This notebook provides a short example of how to

- load a VQ-VAE model that was trained with this repo
- use the model to encode/tokenize jets (here: reconstructed tau jets from the ml-tau-en-reg ntuples)
- reconstruct/decode the jets from the tokens
- encode the tokenized jets with the pretrained generative backbone

In [None]:
# --- Load the tokenizer model from checkpoint, and also get the feature_dict from the config ---

from gabbro.models.vqvae import VQVAELightning

# this checkpoint is the checkpoint from a tokenization training
ckpt_path = "../../checkpoints/vqvae_8192_tokens/model_ckpt.ckpt"
# map_location restores the tensors directly on `device`, so a checkpoint that
# was saved on a GPU can also be loaded on a CPU-only machine
vqvae_model = VQVAELightning.load_from_checkpoint(ckpt_path, map_location=device).to(device)
vqvae_model.eval()

cfg = OmegaConf.load(Path(ckpt_path).parent / "config.yaml")
pp_dict = OmegaConf.to_container(cfg.data.dataset_kwargs_common.feature_dict)
print("\npp_dict:")
for feat_name, feat_spec in pp_dict.items():
    print(feat_name, feat_spec)

# Get the cuts from the pp_dict. These cuts remove particles during
# preprocessing/tokenization, so the same particles also have to be removed from
# the original jets when comparing them to the tokenized+reconstructed particles.
# A feature without any preprocessing settings (None) simply has no cuts.
pp_dict_cuts = {
    feat_name: {
        criterion: (feat_spec or {}).get(criterion)
        for criterion in ["larger_than", "smaller_than"]
    }
    for feat_name, feat_spec in pp_dict.items()
}

print("\npp_dict_cuts:")
for feat_name, cuts in pp_dict_cuts.items():
    print(feat_name, cuts)

print("\nModel:")
print(vqvae_model)

In [None]:
# Dummy input for a quick check of the VQ-VAE forward pass: 8 jets with 32
# particle slots and 3 features each. The seed makes the input reproducible.
torch.manual_seed(0)
x = torch.randn((8, 32, 3))
# Mask with every entry True, i.e. no slot is masked out.
# (The previous `torch.randn(...).bool()` was all-True as well, since randn is
# almost never exactly 0, but it read like a random mask.)
# NOTE(review): confirm the model treats True as "valid particle".
mask = torch.ones((8, 32), dtype=torch.bool)

In [None]:
# Run the bare VQ-VAE network on the dummy input. Calling the module (instead of
# .forward directly) also runs any registered hooks, and no_grad avoids building
# an autograd graph that is never used here.
with torch.no_grad():
    x_particle_reco, vq_out = vqvae_model.model(x, mask)

In [None]:
# shape of the "q" entry returned by the VQ-VAE forward pass
# (presumably the quantized latent vectors -- TODO confirm)
vq_out["q"].shape

### Tokenize and reconstruct the jets

In [None]:
# Input ntuple with reconstructed tau jets -- adjust this path to your setup.
TAU_DATA_PATH = "/local/joosep/ml-tau-en-reg/ntuples/20240701_lowered_ptcut_merged/z_train.parquet"
# only the first N_TAU_JETS entries are used, to keep this example fast
N_TAU_JETS = 10000

tau_data = ak.from_parquet(TAU_DATA_PATH)[:N_TAU_JETS]

In [None]:
part_p4 = reinitialize_p4(tau_data["reco_cand_p4s"])
jet_p4 = reinitialize_p4(tau_data["reco_jet_p4s"])

# Particle features relative to the jet: pt, eta offset and phi offset (the phi
# offset is wrapped to [-pi, pi] by deltaPhi).
# NOTE(review): the pp_dict cuts are not applied to these original particles.
particle_features = {
    "part_pt": part_p4.pt,
    "part_etarel": part_p4.eta - jet_p4.eta,
    "part_phirel": deltaPhi(part_p4.phi, jet_p4.phi),
}
tau_data_transf = ak.Array(particle_features)

In [None]:
# Number of trainable parameters of the VQ-VAE. p.numel() returns a plain int for
# every parameter; np.prod(p.size()) would give the float 1.0 for a scalar
# parameter (np.prod of an empty shape) and turn the total into a float.
num_trainable_weights = sum(p.numel() for p in vqvae_model.parameters() if p.requires_grad)
num_trainable_weights

In [None]:
# Tokenization and reconstruction: both calls share the same preprocessing
# dict, batch size and padding length.
shared_kwargs = {"pp_dict": pp_dict, "batch_size": 512, "pad_length": 128}

part_features_ak_tokenized = vqvae_model.tokenize_ak_array(ak_arr=tau_data_transf, **shared_kwargs)
# Note: to reconstruct tokens produced by the generative model, first remove the
# start token from the tokenized array and subtract 1 from every token. The
# generative model uses 0 as the start token, so its tokens are shifted by 1
# with respect to the VQ-VAE tokens.
part_features_ak_reco = vqvae_model.reconstruct_ak_tokens(tokens_ak=part_features_ak_tokenized, **shared_kwargs)

In [None]:
# inspect the first 5 tokenized jets and the first 5 reconstructed jets
for header, jets in (
    ("First 5 tokenized jets:", part_features_ak_tokenized),
    ("\nFirst 5 reconstructed jets:", part_features_ak_reco),
):
    print(header)
    for i in range(5):
        print(jets[i])

In [None]:
# Particles per jet dropped during tokenization (number of original particles
# minus number of tokens).
n_removed_particles = ak.num(tau_data_transf.part_pt) - ak.num(part_features_ak_tokenized)

# Integer-valued quantity: use bins centred on 0..10 so that each bin holds
# exactly one value. With edges at 0, 1, ..., 10 the closed last bin would merge
# 9 and 10, and the ticks would sit on bin edges instead of bin centres.
fig, ax = plt.subplots()
ax.hist(n_removed_particles, bins=np.arange(-0.5, 11.0, 1.0))
ax.set_xticks(np.arange(0, 11))
ax.set_yscale("log")
ax.set_xlabel("Nptcls - Ntokens")
ax.set_ylabel("jets / bin");

### Calculate the four-momentum of the reconstructed jets and make comparison plots

In [None]:
def get_p4s_from_part_features(part_features):
    """Build massless ``Momentum4D`` vectors from particle-level features.

    ``part_pt`` is used as pt, and the jet-relative ``part_etarel`` /
    ``part_phirel`` are used as eta / phi, so the vectors live in jet-relative
    (eta, phi) coordinates.
    """
    pt = part_features.part_pt
    components = {
        "pt": pt,
        "eta": part_features.part_etarel,
        "phi": part_features.part_phirel,
        "mass": ak.zeros_like(pt),  # massless particles
    }
    return ak.zip(components, with_name="Momentum4D")

# four-vectors of the original particles and of the particles reconstructed from tokens
p4s_original, p4s_reco = (
    get_p4s_from_part_features(tau_data_transf),
    get_p4s_from_part_features(part_features_ak_reco),
)

In [None]:
# Particle phi relative to the jet, before and after tokenization + reconstruction.
phi_bins = np.linspace(-3, 3, 100)
fig, ax = plt.subplots()
ax.hist(ak.flatten(tau_data_transf.part_phirel), bins=phi_bins, histtype="step", label="Original")
ax.hist(ak.flatten(part_features_ak_reco.part_phirel), bins=phi_bins, histtype="step", label="Reconstructed")
ax.set_xlabel("Particle $\\phi - \\phi_{\\mathrm{jet}}$")
ax.set_ylabel("particles / bin")
ax.legend();

In [None]:
# massless four-vectors (jet-relative eta/phi) of the particles in the second original jet
p4s_original[1]

In [None]:
# the same jet after tokenization + reconstruction
p4s_reco[1]

In [None]:
# Jet four-momenta: sum of the particle four-vectors of each jet.
p4s_jets_original = ak.sum(p4s_original, axis=1)
p4s_jets_reco = ak.sum(p4s_reco, axis=1)

# Inclusive jet-level distributions, original vs. reconstructed.
jet_feature_labels = {
    "pt": plot_utils.DEFAULT_LABELS["jet_pt"],
    "eta": plot_utils.DEFAULT_LABELS["jet_eta"],
    "phi": plot_utils.DEFAULT_LABELS["jet_phi"],
    "mass": plot_utils.DEFAULT_LABELS["jet_mass"],
}
jet_feature_bins = {
    "pt": np.linspace(0, 200, 100),
    "eta": np.linspace(-0.1, 0.1, 100),
    "phi": np.linspace(-0.1, 0.1, 100),
    "mass": np.linspace(0, 10, 100),
}
fig, axarr = plot_features(
    ak_array_dict={"Original jets": p4s_jets_original, "Reconstructed jets": p4s_jets_reco},
    names=jet_feature_labels,
    flatten=False,
    decorate_ax_kwargs={"yscale": 1.7},
    bins_dict=jet_feature_bins,
)

# Resolution: reconstructed minus original jet features (the phi difference is
# wrapped to [-pi, pi] by deltaPhi).
jet_resolution = ak.Array(
    {
        "pt": p4s_jets_reco.pt - p4s_jets_original.pt,
        "eta": p4s_jets_reco.eta - p4s_jets_original.eta,
        "phi": deltaPhi(p4s_jets_reco.phi, p4s_jets_original.phi),
        "mass": p4s_jets_reco.mass - p4s_jets_original.mass,
    }
)
resolution_labels = {
    "pt": "Jet $p_T^{\\text{reco}} - p_T^{\\text{orig}}$",
    "eta": "Jet $\\eta^{\\text{reco}} - \\eta^{\\text{orig}}$",
    "phi": "Jet $\\phi^{\\text{reco}} - \\phi^{\\text{orig}}$",
    "mass": "Jet $m^{\\text{reco}} - m^{\\text{orig}}$",
}
resolution_bins = {
    "pt": np.linspace(-15, 15, 100),
    "eta": np.linspace(-0.05, 0.05, 100),
    "phi": np.linspace(-0.05, 0.05, 100),
    "mass": np.linspace(-5, 5, 100),
}
fig, axarr = plot_features(
    ak_array_dict={"Difference": jet_resolution},
    names=resolution_labels,
    flatten=False,
    decorate_ax_kwargs={"yscale": 1.7},
    bins_dict=resolution_bins,
    colors=["C2"],
)

In [None]:
# Jet phi (in jet-relative coordinates) from the summed particle four-vectors,
# before and after tokenization + reconstruction.
jet_phi_bins = np.linspace(-0.1, 0.1, 100)
fig, ax = plt.subplots()
ax.hist(p4s_jets_original.phi, bins=jet_phi_bins, histtype="step", lw=1, label="Original jets")
ax.hist(p4s_jets_reco.phi, bins=jet_phi_bins, histtype="step", lw=1, label="Reconstructed jets")
ax.set_yscale("log")
ax.set_xlabel("Jet $\\phi$ (jet-relative)")
ax.set_ylabel("jets / bin")
ax.legend();

## Encode the tokenized jets with the pretrained generative backbone

In [None]:
import copy
import math
import torch
from functools import partial
import torch.nn as nn
from gabbro.models.gpt_model import BackboneModel

# checkpoint of the pretrained generative model
backbone_ckpt_path = "../../checkpoints/generative_8192_tokens/OmniJet_generative_model_UnintentionalPinscher_59.ckpt"
# torch.load unpickles arbitrary Python objects: only load checkpoints from a trusted source.
# NOTE(review): recent torch versions default to weights_only=True, which may reject
# checkpoints that store non-tensor objects; pass weights_only=False for this trusted file if needed.
loaded_model = torch.load(backbone_ckpt_path, map_location=device)

In [None]:
# Positional hyperparameters of the backbone; they must match the checkpoint above.
# NOTE(review): presumably (embedding_dim, attention_dropout, vocab_size,
# max_sequence_len, n_heads, n_GPT_blocks) -- confirm against BackboneModel.
backbone_args = (256, 0.0, 8194, 128, 32, 3)
bb_model = BackboneModel(*backbone_args)

In [None]:
# display the backbone architecture
bb_model

In [None]:
# The checkpoint's parameter names carry a leading "module." (presumably from a
# wrapped model). Keep only those entries and strip just the leading prefix:
# str.replace would also delete a "module." occurring later inside a key.
state_dict_prefix = "module."
gpt_state = {
    key[len(state_dict_prefix):]: value
    for key, value in loaded_model["state_dict"].items()
    if key.startswith(state_dict_prefix)
}

In [None]:
# strict loading (torch's default): every key in gpt_state must match a parameter of bb_model
bb_model.load_state_dict(state_dict=gpt_state)

In [None]:
# the first 8 tokenized jets, which form the example batch for the backbone
part_features_ak_tokenized[0:8]

In [None]:
# Pad the example jets to exactly SEQ_LEN tokens, filling padded slots with 0.
# clip=True also truncates longer jets, so ak.to_regular cannot fail on a jet
# with more than SEQ_LEN tokens. The fill value 0 is presumably also a valid
# VQ-VAE code, so it does not by itself identify padding.
N_EXAMPLE_JETS = 8
SEQ_LEN = 32
jets_padded_tokenized = ak.fill_none(
    ak.pad_none(part_features_ak_tokenized[0:N_EXAMPLE_JETS], SEQ_LEN, clip=True), 0
)

In [None]:
# (n_jets, seq_len) LongTensor of token indices
jets_batch = torch.as_tensor(ak.to_numpy(ak.to_regular(jets_padded_tokenized)), dtype=torch.long)

# Padding mask with True at padded positions (same convention as the former
# `jets_batch == 0`). It is built from the real number of tokens per jet, because
# comparing token values with 0 would also flag genuine occurrences of code 0.
# NOTE(review): if BackboneModel masks attention where padding_mask == 0
# (i.e. expects True = valid token), this mask must be inverted -- confirm.
seq_len = jets_batch.shape[1]
n_tokens = np.minimum(ak.to_numpy(ak.num(part_features_ak_tokenized[: len(jets_batch)])), seq_len)
padding_mask = torch.arange(seq_len)[None, :] >= torch.as_tensor(n_tokens)[:, None]

In [None]:
# The tokenization notes say the generative model uses 0 as its start token and
# sees the VQ-VAE tokens shifted by +1, so prepend a start token and shift the
# codes before encoding.
# NOTE(review): the training setup may also append a stop token -- confirm.
START_TOKEN = 0
start_tokens = torch.full((jets_batch.shape[0], 1), START_TOKEN, dtype=jets_batch.dtype)
backbone_input = torch.cat([start_tokens, jets_batch + 1], dim=1)
# The start token is a real position, so it is not marked as padding (padding_mask
# uses True for padded positions). Padded slots become 1 after the shift but stay masked.
# NOTE(review): confirm BackboneModel's padding_mask convention.
start_mask = torch.zeros((padding_mask.shape[0], 1), dtype=padding_mask.dtype)
backbone_mask = torch.cat([start_mask, padding_mask], dim=1)

bb_model.eval()
with torch.no_grad():  # inference only, no autograd graph needed
    encoded_jets = bb_model(backbone_input, padding_mask=backbone_mask)

In [None]:
# shape of the backbone output for the example batch
encoded_jets.shape