In [None]:
import os
import sys
import warnings

# make modules one directory up (e.g. `src.*`) importable; assumes the notebook is run
# from a subdirectory of the repository — TODO confirm
sys.path.append("../")

# render figures inline and automatically reload edited project modules before each cell
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
import hydra
import numpy as np
import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf

In [None]:
import awkward as ak
import energyflow as ef
import fastjet as fj
import h5py
import vector

In [None]:
# Load environment variables from the .env file. Hydra relies on DATA_DIR being present
# in the environment, so fail early with a clear message if it is still missing.
from dotenv import load_dotenv

load_dotenv()
if "DATA_DIR" not in os.environ:
    raise RuntimeError(
        "DATA_DIR is not set; define it in a .env file or export it before starting the kernel"
    )

In [None]:
# plots and metrics
import matplotlib.pyplot as plt

# from src.data.components import calculate_all_wasserstein_metrics
from src.utils.data_generation import generate_data
from src.utils.plot.lhco_plotting import plot_unprocessed_data_lhco
from src.utils.plotting import apply_mpl_styles, plot_data, prepare_data_for_plotting

# apply the project's matplotlib style (from src.utils.plotting) to all following figures
apply_mpl_styles()

In [None]:
# Hydra `experiment=` override applied on top of train.yaml in the next cell
experiment = "lhco.yaml"

In [None]:
# Compose the training config with the selected experiment override.
# Hydra's compose() has to run while the initialize() context is active.
with hydra.initialize(config_path="../configs/", version_base=None):
    cfg = hydra.compose(
        config_name="train.yaml",
        overrides=[f"experiment={experiment}"],
    )

In [None]:
# build the datamodule and the model from the `data` / `model` sections of the config;
# the model's freshly initialised weights are replaced by the checkpoint loaded further down
datamodule = hydra.utils.instantiate(cfg.data)
model = hydra.utils.instantiate(cfg.model)

In [None]:
# prepare the data splits; the tensor_* / mask_* attributes read in the next cell are
# presumably populated here — confirm in the datamodule implementation
datamodule.setup()

In [None]:
# Convert each split (features, mask, conditioning) and the datamodule statistics
# (presumably the normalisation means/stds) to numpy arrays.
test_data, test_mask, test_cond = map(
    np.array,
    (datamodule.tensor_test, datamodule.mask_test, datamodule.tensor_conditioning_test),
)
val_data, val_mask, val_cond = map(
    np.array,
    (datamodule.tensor_val, datamodule.mask_val, datamodule.tensor_conditioning_val),
)
train_data, train_mask, train_cond = map(
    np.array,
    (datamodule.tensor_train, datamodule.mask_train, datamodule.tensor_conditioning_train),
)
means, stds = map(np.array, (datamodule.means, datamodule.stds))

In [None]:
# sanity check: one shape per line for every split, then the statistics themselves
print(
    *(
        split_array.shape
        for split_array in (
            test_data, test_mask, test_cond,
            val_data, val_mask, val_cond,
            train_data, train_mask, train_cond,
        )
    ),
    sep="\n",
)
print(means)
print(stds)

In [None]:
# Trained checkpoint (presumably the EMA weights, per the file name). Absolute cluster
# path — adjust it for your machine.
ckpt = (
    "/beegfs/desy/user/ewencedr/deep-learning/logs/lhco unprocessed data"
    " norm/runs/2023-08-01_15-23-47/checkpoints/last-EMA.ckpt"
)
# load_from_checkpoint is a classmethod that returns a *new* model. Call it on the class:
# newer Lightning releases raise an error when it is called on an instance.
model = type(model).load_from_checkpoint(ckpt)

In [None]:
# number of events to generate; also the number of test masks / test events used below
samples = 10000

In [None]:
# fixed seed so the generated sample is reproducible
torch.manual_seed(9999)
# Generate `samples` events. Set sizes follow the first `samples` test masks, and the
# datamodule statistics are passed for the normalised-data handling in generate_data.
data, generation_time = generate_data(
    model,
    num_jet_samples=samples,
    mask=torch.tensor(test_mask[:samples]),
    variable_set_sizes=True,
    cond=None,
    normalized_data=True,
    means=means,
    stds=stds,
    ode_solver="midpoint",
    ode_steps=50,
    batch_size=2048,
)

In [None]:
# reorder the feature axis so that columns 0/1/2 are what the awkward conversion below
# reads as pt/eta/phi
# NOTE(review): test_data comes straight from the datamodule — confirm it is in the same
# (un)normalised units as the generated `data`
data_full = data[..., [2, 0, 1]]
background_data = test_data[..., [2, 0, 1]]

In [None]:
# compare generated constituents with the first `samples` test events
# (the generated sample is wrapped in an outer axis, presumably one entry per model to overlay)
fig = plot_unprocessed_data_lhco(
    particle_data=np.array([data_full]),
    sim_data=background_data[:samples],
    plottype="",
    save_fig=False,
)

In [None]:
# to awkward array: append a zero mass column to the (pt, eta, phi) features
# (np.zeros is float64, so lower-precision input is upcast by the concatenation)
zrs = np.zeros((data_full.shape[0], data_full.shape[1], 1))
data_with_mass = np.concatenate((data_full, zrs), axis=2)
awkward_data = ak.from_numpy(data_with_mass)

In [None]:
# Interpret the four columns as (pt, eta, phi, mass) and attach vector's Momentum4D behaviour.
vector.register_awkward()
unmasked_data = ak.zip(
    {
        field: awkward_data[:, :, column]
        for column, field in enumerate(("pt", "eta", "phi", "mass"))
    },
    with_name="Momentum4D",
)
print(unmasked_data.type)

In [None]:
# remove the zero-padded entries: ak.mask turns pt == 0 particles into None and
# ak.drop_none removes them, leaving ragged per-event particle lists
# NOTE(review): this rebinds `data`, which previously held the generated numpy sample
data = ak.drop_none(ak.mask(unmasked_data, unmasked_data.pt != 0))
print(data.type)

In [None]:
# anti-kt jet algorithm with jet radius R = 1.0
jetdef = fj.JetDefinition(fj.antikt_algorithm, 1.0)

In [None]:
# cluster the particles of every event (fastjet's awkward-array interface)
cluster = fj.ClusterSequence(data, jetdef)

In [None]:
# get all inclusive jets per event (no ptmin passed) and the constituents of each jet
jets_out = cluster.inclusive_jets()
consts_out = cluster.constituents()

In [None]:
# define a function to sort ak.Array by pt
def sort_by_pt(data: ak.Array, ascending: bool = False, return_indices: bool = False):
    """Sort an array by pt along its innermost axis.

    Args:
        data (ak.Array): array that should be sorted by pt. It must expose a ``pt``
            attribute, e.g. via pt/eta/phi fields or a registered ``vector`` behaviour.
        ascending (bool, optional): If True, the first value in each sorted group will be
            smallest; if False, the order is from largest to smallest. Defaults to False.
        return_indices (bool, optional): If True, the indices of the sorted array are
            returned as well. Defaults to False.

    Returns:
        ak.Array: sorted array
        ak.Array (optional): indices of the sorted array (only if ``return_indices``)

    Raises:
        AttributeError: if ``data`` has no ``pt`` attribute.
    """
    # Look up pt for every input, not only exact ak.Array instances. Otherwise pt would
    # be unbound (UnboundLocalError) for other inputs.
    try:
        pt = data.pt
    except AttributeError:
        raise AttributeError(
            "Needs either correct coordinates or embedded vector backend"
        ) from None
    order = ak.argsort(pt, axis=-1, ascending=ascending)
    if return_indices:
        return data[order], order
    return data[order]

In [None]:
# sort the jets of each event by pt (descending) and apply the same permutation to the
# per-jet constituent lists so they stay aligned with their jets; then sort the
# constituents inside each jet by pt (sort_by_pt works on the innermost axis)
jets_sorted, idxs = sort_by_pt(jets_out, return_indices=True)
consts_sorted_jets = consts_out[idxs]
consts_sorted = sort_by_pt(consts_sorted_jets)

In [None]:
# only take the 2 highest-pt jets per event
# NOTE(review): events with fewer than two jets keep fewer entries; the jet axis would then
# be ragged and the later ak.to_numpy conversions would fail — confirm every event has >= 2 jets
jets_awk = jets_sorted[:, :2]
consts_awk = consts_sorted[:, :2]

In [None]:
# largest constituent multiplicity among the two leading jets of the generated events
max_consts_generated = int(ak.max(ak.num(consts_awk, axis=-1)))
print(max_consts_generated)
# Fixed padding length for the numpy conversion below. Presumably chosen to match the
# reference dataset's constituent axis — TODO confirm against particle_data.shape.
max_consts = 279
if max_consts_generated > max_consts:
    # pad_none(..., clip=True) in the next cell silently drops constituents beyond max_consts
    warnings.warn(
        f"jets contain up to {max_consts_generated} constituents, but padding clips them "
        f"to {max_consts}; the lowest-pt constituents beyond that are dropped"
    )

In [None]:
# pad (or clip) every jet's constituent list to exactly max_consts entries, filling the gap
# with zero Momentum4D records so the constituent axis becomes regular
zero_padding = ak.zip({"pt": 0.0, "eta": 0.0, "phi": 0.0, "mass": 0.0}, with_name="Momentum4D")
padded_consts = ak.fill_none(
    ak.pad_none(consts_awk, max_consts, clip=True, axis=-1), zero_padding, axis=-1
)
print(padded_consts.type)

In [None]:
# Back to numpy: stack (pt, eta, phi) of the padded constituents along a new last axis
# (the mass field is unpacked but not used).
pt, eta, phi, mass = ak.unzip(padded_consts)
pt_np, eta_np, phi_np = (ak.to_numpy(component) for component in (pt, eta, phi))
consts = np.stack((pt_np, eta_np, phi_np), axis=-1)
print(consts.shape)

In [None]:
# Constituent mask: 1 for real constituents (pt > 0), 0 for zero padding, with a trailing
# singleton feature axis.
mask = (consts[..., 0] > 0).astype(int)[..., np.newaxis]
print(mask.shape)

In [None]:
# Kinematics (pt, eta, phi, m) of the two leading clustered jets as one numpy array.
jets_pt_np, jets_eta_np, jets_phi_np, jets_m_np = (
    ak.to_numpy(getattr(jets_awk, component)) for component in ("pt", "eta", "phi", "m")
)
jets = np.stack((jets_pt_np, jets_eta_np, jets_phi_np, jets_m_np), axis=-1)
print(jets.shape)

## calculate jet features

In [None]:
def get_jet_data(consts: np.ndarray) -> np.ndarray:
    """Compute jet kinematics by summing the four-momenta of the constituents.

    Only the first three feature columns (pt, y, phi) are used, so energyflow treats every
    constituent as massless; zero-pt padding entries add zero four-momentum.

    Args:
        consts (np.ndarray): constituent data with features (pt, y, phi) on the last axis;
            constituents are summed over the second-to-last axis.

    Returns:
        np.ndarray: jet data (pt, y, phi, m); phi is expressed relative to phi_ref=0.
    """
    constituent_p4s = ef.p4s_from_ptyphims(consts[..., :3])
    jet_p4s = np.sum(constituent_p4s, axis=-2)
    return ef.ptyphims_from_p4s(jet_p4s, phi_ref=0)

In [None]:
# split the two leading jets (axis 1; X = leading) and rebuild their kinematics from constituents
x_consts, y_consts = consts[:, 0], consts[:, 1]
x_jets, y_jets = get_jet_data(x_consts), get_jet_data(y_consts)
print(x_jets.shape)
print(x_consts.shape)

# Compare to originally clustered data

## load jet features

In [None]:
# reference (originally clustered) background sample; absolute cluster path — adjust locally
path = "/beegfs/desy/user/ewencedr/data/lhco/final_data/processed_data_background.h5"

In [None]:
# read the reference jets, constituents and mask fully into memory, then close the file
with h5py.File(path, mode="r") as f:
    jet_data, particle_data, mask_ref = (
        f[dataset][:] for dataset in ("jet_data", "constituents", "mask")
    )

In [None]:
# shapes of the reference jet and constituent arrays, one per line
print(jet_data.shape, particle_data.shape, sep="\n")

In [None]:
# compare against as many reference events as were generated, so both samples match in size
n_samp = samples

In [None]:
# first n_samp reference events, split into their two jets along axis 1 (X = index 0, Y = index 1)
x_jets_ref, y_jets_ref = jet_data[:n_samp, 0], jet_data[:n_samp, 1]
x_consts_ref, y_consts_ref = particle_data[:n_samp, 0], particle_data[:n_samp, 1]

## X-Jets

In [None]:
# X-jet comparison: constituents of the leading (index 0) jet, generated vs. reference,
# reordered from (pt, eta, phi) to (eta, phi, pt) for the plotting helpers (assumes the
# reference constituents use the same (pt, eta, phi) column order — TODO confirm).
# NOTE(review): rebinds `data` / `background_data` used by earlier cells.
data = x_consts[..., [1, 2, 0]]
background_data = x_consts_ref[..., [1, 2, 0]]

In [None]:
# options forwarded to plot_data
plot_config = dict(
    num_samples=-1,
    plot_jet_features=True,
    plot_w_dists=False,
    plot_efps=False,
    plot_selected_multiplicities=False,
    selected_multiplicities=[10, 20, 30, 40, 50, 100],
    selected_particles=[1, 3, 10],
    plottype="",
    save_fig=False,
    variable_jet_sizes_plotting=False,
    bins=100,
    close_fig=False,
    labels=["test"],
)
# subset for prepare_data_for_plotting, which calls the EFP switch `calculate_efps`
plot_prep_config = {
    "calculate_efps": plot_config["plot_efps"],
    "selected_multiplicities": plot_config["selected_multiplicities"],
    "selected_particles": plot_config["selected_particles"],
}

In [None]:
# plotting inputs for the generated constituents (wrapped in an outer per-sample axis)
# NOTE(review): this rebinds `jet_data`, which held the reference jets read from HDF5;
# re-running the reference-split cell after this one would slice the wrong array
(
    jet_data,
    efps_values,
    pt_selected_particles,
    pt_selected_multiplicities,
) = prepare_data_for_plotting(
    np.array([data]),
    **plot_prep_config,
)

In [None]:
# same preparation for the reference constituents; the results come back with an outer
# per-sample level, so the single entry is unwrapped below
# NOTE(review): the generated sample is passed as np.array([...]) above but as a plain list
# here, and pt_selected_multiplicities_sim is not unwrapped like the others — confirm both
# are what plot_data expects
(
    jet_data_sim,
    efps_sim,
    pt_selected_particles_sim,
    pt_selected_multiplicities_sim,
) = prepare_data_for_plotting(
    [background_data],
    **plot_prep_config,
)
jet_data_sim, efps_sim, pt_selected_particles_sim = (
    jet_data_sim[0],
    efps_sim[0],
    pt_selected_particles_sim[0],
)

In [None]:
# overlay generated (particle_data / jet_data / ...) and reference (sim_data / *_sim)
# distributions using the options in plot_config
fig = plot_data(
    particle_data=np.array([data]),
    sim_data=background_data,
    jet_data_sim=jet_data_sim,
    jet_data=jet_data,
    efps_sim=efps_sim,
    efps_values=efps_values,
    pt_selected_particles=pt_selected_particles,
    pt_selected_multiplicities=pt_selected_multiplicities,
    pt_selected_particles_sim=pt_selected_particles_sim,
    pt_selected_multiplicities_sim=pt_selected_multiplicities_sim,
    **plot_config,
)

## Y-Jets