# LHCO Cathode Generation Pipeline
Once the particle-level models and the jet-feature models have been trained, the final step is to run the full generation pipeline end to end, which is what this notebook does.

## Imports

In [None]:
import os
import sys

# Make the parent directory importable so the `src.*` imports below resolve
# (presumably the repository root when run from the notebooks folder — confirm).
sys.path.append("../")

# IPython magics: render figures inline and automatically reload edited project
# modules before each cell is executed.
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
from os.path import join

import energyflow as ef
import h5py
import hydra
import numpy as np
import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf
from sklearn.neighbors import KernelDensity

In [None]:
# plots and metrics
import matplotlib.pyplot as plt

from src.data.components import calculate_all_wasserstein_metrics, normalize_tensor
from src.utils.data_generation import generate_data
from src.utils.plotting import apply_mpl_styles, plot_data, prepare_data_for_plotting

# Apply the project's matplotlib style settings (from src.utils.plotting) to all following figures.
apply_mpl_styles()

In [None]:
# set env variable DATA_DIR again because of hydra
from dotenv import load_dotenv

load_dotenv()
os.environ["DATA_DIR"] = os.environ.get("DATA_DIR")

# Generate mjj samples
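We fit a kernel density estimate (KDE) to the dijet invariant mass (mjj) distribution of the background sample in the sidebands (SB) and the signal region (SR). We then sample from the KDE and keep only the samples that fall in the signal region, which gives new mjj values to condition the generation on.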

In [None]:
# Background jets (the file name suggests relative coordinates). Read the whole
# "jet_data" dataset into memory; the file is closed when the block exits.
path = "/beegfs/desy/user/ewencedr/data/lhco/final_data/processed_data_background_rel.h5"
with h5py.File(path, mode="r") as f:
    jets = f["jet_data"][()]

In [None]:
# Convert jet kinematics to four-momenta with energyflow, which expects [pt, y, phi, m]
# columns and returns [E, px, py, pz]. NOTE(review): assumes `jets` is laid out that way — confirm.
p4_jets = ef.p4s_from_ptyphims(jets)

In [None]:
# Dijet system: add the four-momenta of the first two jets and take the invariant mass.
sum_p4 = np.add(p4_jets[:, 0], p4_jets[:, 1])
mjj = ef.ms_from_p4s(sum_p4)

In [None]:
# Region definitions in mjj (bounds are exclusive, so an event exactly on an edge is in neither):
#   sidebands (SB):     2300 < mjj < 3300  or  3700 < mjj < 5000
#   signal region (SR): 3300 < mjj < 3700
in_lower_sb = (mjj > 2300) & (mjj < 3300)
in_upper_sb = (mjj > 3700) & (mjj < 5000)
args_to_keep = in_lower_sb | in_upper_sb
args_to_keep_sr = (mjj > 3300) & (mjj < 3700)

In [None]:
# Select mjj values per region; SB and SR are also combined for the KDE fit.
mjj_sb, mjj_sr = mjj[args_to_keep], mjj[args_to_keep_sr]
args_to_keep_sb_sr = np.logical_or(args_to_keep, args_to_keep_sr)
mjj_sb_sr = mjj[args_to_keep_sb_sr]

In [None]:
# Full mjj spectrum with the sideband and signal-region selections overlaid.
hist = plt.hist(
    mjj,
    bins=np.arange(1e3, 9.5e3, 0.1e3),
    histtype="stepfilled",
    label="mjj",
    alpha=0.5,
)
shared_edges = hist[1]
for region_values, region_label in ((mjj_sb, "mjj SB"), (mjj_sr, "mjj SR")):
    plt.hist(region_values, bins=shared_edges, histtype="step", label=region_label)

plt.legend()
plt.yscale("log")
plt.show()

In [None]:
# Full mjj spectrum with the combined SB + SR selection overlaid.
mjj_edges = np.arange(1e3, 9.5e3, 0.1e3)
hist = plt.hist(mjj, bins=mjj_edges, histtype="stepfilled", label="mjj", alpha=0.5)
plt.hist(mjj_sb_sr, bins=hist[1], histtype="step", label="mjj SB SR")

plt.legend()
plt.yscale("log")
plt.show()

### Fit KDE on SR and SB

In [None]:
# Fit a 1D Gaussian KDE to the SB + SR mjj values and draw as many samples as there are data points.
# NOTE(review): mjj is in GeV, so bandwidth=0.001 adds almost no smoothing and sampling is
# effectively a bootstrap resample of the data — confirm this is intended.
kde_model_sb_sr = KernelDensity(kernel="gaussian", bandwidth=0.001)
kde_model_sb_sr.fit(mjj_sb_sr.reshape(-1, 1))

# Seed the sampler so the conditioning masses (and everything generated from them) are
# reproducible, consistent with the torch.manual_seed(9999) used for the generative models.
samples_sb_sr = kde_model_sb_sr.sample(len(mjj_sb_sr), random_state=9999)

In [None]:
# Compare the SB + SR truth with the KDE samples on a finer (50 GeV) binning.
truth_edges = np.arange(1e3, 9.5e3, 0.05e3)
hist = plt.hist(mjj_sb_sr, bins=truth_edges, histtype="stepfilled", label="Truth", alpha=0.5)
plt.hist(samples_sb_sr, bins=hist[1], histtype="step", label="KDE samples")
plt.xlabel("mjj [GeV]")
plt.ylim(1e-1, 1e5)
plt.legend(frameon=False)
plt.yscale("log")
plt.show()

### Keep only the KDE samples in the SR

In [None]:
# Keep only the KDE samples that fall strictly inside the signal region (3300, 3700).
args_to_keep_sr_samples = np.logical_and(samples_sb_sr > 3300, samples_sb_sr < 3700)
# The mask has the same (n, 1) shape as the samples, so boolean indexing yields a flat 1D array.
mjj_samples_sr = samples_sb_sr[args_to_keep_sr_samples]

In [None]:
# Compare the true SR mjj distribution with the KDE samples that landed in the SR.
sr_edges = np.arange(1e3, 9.5e3, 0.1e3)
hist = plt.hist(mjj_sr, bins=sr_edges, histtype="stepfilled", label="Truth", alpha=0.5)
plt.hist(mjj_samples_sr, bins=hist[1], histtype="step", label="KDE samples")
plt.ylim(1e-1, 1e5)
plt.xlabel("mjj [GeV]")
plt.legend(frameon=False)
plt.yscale("log")
plt.show()

# Generate from Jet Feature Model

### Load model

In [None]:
# Hydra experiment override for the jet-feature model (resolved relative to ../configs/).
experiment = "/lhco/lhco_jet_features.yaml"

In [None]:
# load everything from experiment config
with hydra.initialize(version_base=None, config_path="../configs/"):
    cfg = hydra.compose(config_name="train.yaml", overrides=[f"experiment={experiment}"])
    # print(OmegaConf.to_yaml(cfg))

In [None]:
# Datamodule plus a freshly instantiated model; the model is replaced via
# load_from_checkpoint further down.
datamodule, model = (
    hydra.utils.instantiate(cfg.data),
    hydra.utils.instantiate(cfg.model),
)

In [None]:
# Prepare the datamodule so its SR tensors and normalization statistics (read below) exist.
datamodule.setup()

In [None]:
# Signal-region splits (data + conditioning) and normalization statistics as numpy arrays.
(
    test_data_sr,
    test_cond_sr,
    val_data_sr,
    val_cond_sr,
    train_data_sr,
    train_cond_sr,
) = [
    np.array(split)
    for split in (
        datamodule.tensor_test_sr,
        datamodule.tensor_conditioning_test_sr,
        datamodule.tensor_val_sr,
        datamodule.tensor_conditioning_val_sr,
        datamodule.tensor_train_sr,
        datamodule.tensor_conditioning_train_sr,
    )
]
means, stds, means_cond, stds_cond = [
    np.array(stat)
    for stat in (datamodule.means, datamodule.stds, datamodule.cond_means, datamodule.cond_stds)
]

### Load checkpoint

In [None]:
ckpt = (
    "/beegfs/desy/user/ewencedr/deep-learning/logs/lhco jet features with particle"
    " multiplicity/runs/2023-08-16_14-58-31/checkpoints/last-EMA.ckpt"
)
# load_from_checkpoint is a classmethod that returns a new model; call it on the class, since
# recent Lightning versions raise a TypeError when it is invoked on an instance.
model = type(model).load_from_checkpoint(ckpt)

### Generate Data

In [None]:
# Number of jets to generate with the jet-feature model.
n_samples = 20_000

In [None]:
# normalize conditioning variables
normalized_cond = normalize_tensor(
    torch.tensor(mjj_samples_sr, dtype=torch.float).clone().unsqueeze(-1),
    means_cond,
    stds_cond,
    datamodule.hparams.normalize_sigma,
)

In [None]:
# Sample jet features conditioned on the normalized SR mjj values (midpoint ODE solver,
# 100 steps); the torch seed is fixed right before generation.
jet_feature_generation_kwargs = dict(
    num_jet_samples=n_samples,
    batch_size=2048,
    cond=normalized_cond[:n_samples],
    normalized_data=datamodule.hparams.normalize,
    means=datamodule.means,
    stds=datamodule.stds,
    ode_solver="midpoint",
    ode_steps=100,
)
torch.manual_seed(9999)
data_jet_feature, generation_time = generate_data(model, **jet_feature_generation_kwargs)

In [None]:
# Feature index -> axis label for the 10 jet-level features
# (two jets x [pT, eta, phi, mass, particle multiplicity]).
label_map = {
    "0": r"${p_T}_1$",
    "1": r"$\eta_1$",
    "2": r"$\phi_1$",
    "3": r"$m_1$",
    "4": "Particle Multiplicity 1",
    "5": r"${p_T}_2$",
    "6": r"$\eta_2$",
    "7": r"$\phi_2$",
    "8": r"$m_2$",
    "9": "Particle Multiplicity 2",
}
fig, axs = plt.subplots(2, 5, figsize=(25, 11))
for index, ax in enumerate(axs.reshape(-1)):
    reference = test_data_sr[:n_samples, index]
    generated = data_jet_feature[:n_samples, index]
    # Common x-range covering both the reference and the generated distribution.
    x_min = min(np.min(reference), np.min(generated))
    x_max = max(np.max(reference), np.max(generated))
    if index == 4 or index == 9:
        # Particle multiplicities: one bin per integer with edges on half-integers, so every
        # integer sits in the centre of its own bin. Integer edges built from
        # int(x_min)..int(x_max) would merge the two largest values into the final (closed)
        # bin and drop generated values above int(x_max), since `range` is ignored when
        # `bins` is a sequence of edges.
        bin_width = 1
        bins = np.arange(np.floor(x_min) - 0.5, np.ceil(x_max) + 0.5 + bin_width, bin_width)
    else:
        bins = 100
    hist1 = ax.hist(
        reference,
        bins=bins,
        label="test data",  # test_data_sr is the test split, not the training split
        range=[x_min, x_max],
        alpha=0.5,
    )
    ax.hist(generated, bins=hist1[1], label="generated", histtype="step")
    ax.set_xlabel(f"{label_map[str(index)]}")
    ax.set_yscale("log")
    if index == 2 or index == 7:
        ax.legend(frameon=False)
        ax.set_ylim(1e-1, 1e6)
plt.tight_layout()
plt.suptitle("Signal Region", fontsize=30)
fig.subplots_adjust(top=0.93)
plt.show()

# Particle Feature Model

### Load Models

In [None]:
# Hydra experiment overrides for the two particle-level models; the x model is later
# conditioned on jet-1 features and the y model on jet-2 features.
experiment_x, experiment_y = "/lhco/lhco_x_jet.yaml", "/lhco/lhco_y_jet.yaml"

In [None]:
# load everything from experiment config
with hydra.initialize(version_base=None, config_path="../configs/"):
    cfg_x = hydra.compose(config_name="train.yaml", overrides=[f"experiment={experiment_x}"])
    # print(OmegaConf.to_yaml(cfg_x))

In [None]:
# load everything from experiment config
with hydra.initialize(version_base=None, config_path="../configs/"):
    cfg_y = hydra.compose(config_name="train.yaml", overrides=[f"experiment={experiment_y}"])
    # print(OmegaConf.to_yaml(cfg_y))

In [None]:
# Datamodule and freshly instantiated model for the x jets; the model is replaced via
# load_from_checkpoint further down.
datamodule_x, model_x = (
    hydra.utils.instantiate(cfg_x.data),
    hydra.utils.instantiate(cfg_x.model),
)

In [None]:
# Datamodule and freshly instantiated model for the y jets; the model is replaced via
# load_from_checkpoint further down.
datamodule_y, model_y = (
    hydra.utils.instantiate(cfg_y.data),
    hydra.utils.instantiate(cfg_y.model),
)

In [None]:
# Prepare the x datamodule; its conditioning statistics and tensors are used below.
datamodule_x.setup()

In [None]:
# Prepare the y datamodule; its conditioning statistics and tensors are used below.
datamodule_y.setup()

### Load checkpoint

In [None]:
ckpt_x = (
    "/beegfs/desy/user/ewencedr/deep-learning/logs/lhco x jet mass new cut"
    " interpolate/runs/2023-08-16_02-00-37/checkpoints/last-EMA.ckpt"
)
# load_from_checkpoint is a classmethod that returns a new model; call it on the class, since
# recent Lightning versions raise a TypeError when it is invoked on an instance.
model_x = type(model_x).load_from_checkpoint(ckpt_x)

In [None]:
ckpt_y = (
    "/beegfs/desy/user/ewencedr/deep-learning/logs/lhco y jet mass new cut"
    " interpolate/runs/2023-08-16_03-49-00/checkpoints/last-EMA.ckpt"
)
# load_from_checkpoint is a classmethod that returns a new model; call it on the class, since
# recent Lightning versions raise a TypeError when it is invoked on an instance.
model_y = type(model_y).load_from_checkpoint(ckpt_y)

### Generate Data

In [None]:
# Number of jets to generate with each particle-level model.
n_samples_x = n_samples_y = 20_000

In [None]:
# normalize conditioning variables
cond_x = data_jet_feature[:, 0:4]
normalized_cond_x = normalize_tensor(
    torch.tensor(cond_x, dtype=torch.float32).clone(),
    datamodule_x.cond_means,
    datamodule_x.cond_stds,
    datamodule_x.hparams.normalize_sigma,
)

In [None]:
# normalize conditioning variables
cond_y = data_jet_feature[:, 5:9]
normalized_cond_y = normalize_tensor(
    torch.tensor(cond_y, dtype=torch.float32).clone(),
    datamodule_y.cond_means,
    datamodule_y.cond_stds,
    datamodule_y.hparams.normalize_sigma,
)

In [None]:
# Generated particle multiplicity of jet 1 (feature index 4); rounded with np.rint when building the mask.
mask_x_ints = data_jet_feature[:, 4]

In [None]:
# Generated particle multiplicity of jet 2 (feature index 9); rounded with np.rint when building the mask.
mask_y_ints = data_jet_feature[:, 9]

In [None]:
# Particle mask for the x jets: row i has ones in the first targets_x[i] particle slots and
# zeros elsewhere; final shape (n_jets, n_classes_x, 1).
# NOTE(review): assumes tensor_test is (jets, particle slots, features), so shape[1] is the
# maximum number of particles per jet — confirm.
n_classes_x = datamodule_x.tensor_test.shape[1]
# Generated multiplicities are rounded and clipped to the valid range [0, n_classes_x].
# Without the clip, a full jet (n_classes_x particles) or anything larger raises an
# IndexError, and a negative value silently wraps around via numpy negative indexing.
targets_x = np.clip(np.rint(mask_x_ints).astype(int), 0, n_classes_x)
mask_x = np.expand_dims(
    (np.arange(n_classes_x)[np.newaxis, :] < targets_x[:, np.newaxis]).astype(np.float64),
    axis=-1,
)

In [None]:
# Particle mask for the y jets: row i has ones in the first targets_y[i] particle slots and
# zeros elsewhere; final shape (n_jets, n_classes_y, 1).
# NOTE(review): assumes tensor_test is (jets, particle slots, features), so shape[1] is the
# maximum number of particles per jet — confirm.
n_classes_y = datamodule_y.tensor_test.shape[1]
# Generated multiplicities are rounded and clipped to the valid range [0, n_classes_y].
# Without the clip, a full jet (n_classes_y particles) or anything larger raises an
# IndexError, and a negative value silently wraps around via numpy negative indexing.
targets_y = np.clip(np.rint(mask_y_ints).astype(int), 0, n_classes_y)
mask_y = np.expand_dims(
    (np.arange(n_classes_y)[np.newaxis, :] < targets_y[:, np.newaxis]).astype(np.float64),
    axis=-1,
)

In [None]:
# Quick shape check of the normalized jet-1 conditioning tensor.
print(normalized_cond_x.shape)

In [None]:
# Generate x-jet particles conditioned on the normalized jet-1 features and masked to the
# generated multiplicities (midpoint ODE solver, 100 steps); seed fixed right before generation.
x_generation_kwargs = dict(
    num_jet_samples=n_samples_x,
    batch_size=2048,
    cond=normalized_cond_x[:n_samples_x],
    variable_set_sizes=datamodule_x.hparams.variable_jet_sizes,
    mask=torch.tensor(mask_x, dtype=torch.int64),
    normalized_data=datamodule_x.hparams.normalize,
    means=datamodule_x.means,
    stds=datamodule_x.stds,
    ode_solver="midpoint",
    ode_steps=100,
)
torch.manual_seed(9999)
data_x, generation_time_x = generate_data(model_x, **x_generation_kwargs)

In [None]:
# Generate y-jet particles conditioned on the normalized jet-2 features and masked to the
# generated multiplicities (midpoint ODE solver, 100 steps); seed fixed right before generation.
y_generation_kwargs = dict(
    num_jet_samples=n_samples_y,
    batch_size=2048,
    cond=normalized_cond_y[:n_samples_y],
    variable_set_sizes=datamodule_y.hparams.variable_jet_sizes,
    mask=torch.tensor(mask_y, dtype=torch.int64),
    normalized_data=datamodule_y.hparams.normalize,
    means=datamodule_y.means,
    stds=datamodule_y.stds,
    ode_solver="midpoint",
    ode_steps=100,
)
torch.manual_seed(9999)
data_y, generation_time_y = generate_data(model_y, **y_generation_kwargs)

### Evaluation

### Plots

In [None]:
# Evaluate the y-jet particle model against SR test jets from its datamodule, using as many
# reference jets as were generated.
data = data_y
n_generated = len(data)
background_data = np.array(datamodule_y.tensor_test_sr[:n_generated])

In [None]:
# Wasserstein distances between the reference (background) and generated particle data.
# NOTE(review): num_eval_samples=50_000 exceeds the jets available here (presumably 20_000
# in both `data` and the truncated `background_data`); confirm the helper handles this.
w_dists = calculate_all_wasserstein_metrics(
    background_data,
    data,
    num_batches=40,
    num_eval_samples=50_000,
)

In [None]:
# Print the shapes of the generated and reference arrays as a quick sanity check.
for arr in (data, background_data):
    print(arr.shape)

In [None]:
# Options forwarded to plot_data.
plot_config = dict(
    num_samples=-1,
    plot_jet_features=True,
    plot_w_dists=False,
    plot_efps=True,
    plot_selected_multiplicities=False,
    selected_multiplicities=[10, 20, 30, 40, 50, 100],
    selected_particles=[1, 5, 20],
    plottype="sim_data",
    save_fig=False,
    variable_jet_sizes_plotting=True,
    bins=100,
    close_fig=False,
)
# Subset of the options passed to prepare_data_for_plotting, which names the EFP switch
# `calculate_efps` instead of `plot_efps`.
prep_key_names = {
    "plot_efps": "calculate_efps",
    "selected_particles": "selected_particles",
    "selected_multiplicities": "selected_multiplicities",
}
plot_prep_config = {
    prep_key_names[key]: value for key, value in plot_config.items() if key in prep_key_names
}

In [None]:
# Derived plotting quantities (jet features, EFPs, selected-particle pTs) for the generated data.
prepared_generated = prepare_data_for_plotting(np.array([data]), **plot_prep_config)
(
    jet_data,
    efps_values,
    pt_selected_particles,
    pt_selected_multiplicities,
) = prepared_generated

In [None]:
# Same quantities for the reference data; a single dataset is passed in, so unwrap its entry.
(
    jet_data_sim,
    efps_sim,
    pt_selected_particles_sim,
    pt_selected_multiplicities_sim,
) = prepare_data_for_plotting([background_data], **plot_prep_config)
jet_data_sim = jet_data_sim[0]
efps_sim = efps_sim[0]
pt_selected_particles_sim = pt_selected_particles_sim[0]
# NOTE(review): pt_selected_multiplicities_sim is not unwrapped like the others — confirm
# plot_data expects it in this form (presumably only used when plot_selected_multiplicities is True).

In [None]:
# Overlay generated (particle_data) and reference (sim_data) distributions using the
# quantities prepared above and the options in plot_config.
fig = plot_data(
    particle_data=np.array([data]),
    jet_data=jet_data,
    efps_values=efps_values,
    pt_selected_particles=pt_selected_particles,
    pt_selected_multiplicities=pt_selected_multiplicities,
    sim_data=background_data,
    jet_data_sim=jet_data_sim,
    efps_sim=efps_sim,
    pt_selected_particles_sim=pt_selected_particles_sim,
    pt_selected_multiplicities_sim=pt_selected_multiplicities_sim,
    **plot_config,
)

### Back to non-relative coordinates

In [None]:
# Shapes of the jet-1 conditioning features and the generated x-jet particles.
for arr in (cond_x, data_x):
    print(arr.shape)