In [None]:
import os
import sys

sys.path.append("../")

%matplotlib inline
%load_ext autoreload
%autoreload 2

%matplotlib inline
%config InlineBackend.figure_format='retina'

import hydra
import numpy as np
import pytorch_lightning as pl
import torch

# set env variable DATA_DIR again because of hydra
from dotenv import load_dotenv
from omegaconf import OmegaConf

# Load variables defined in a .env file into the process environment.
load_dotenv()
# Re-export DATA_DIR explicitly (presumably so the hydra-composed config can
# read it). Fail early with a clear message when it is missing: assigning
# None to os.environ would otherwise raise an opaque
# "str expected, not NoneType" TypeError.
_data_dir = os.environ.get("DATA_DIR")
if _data_dir is None:
    raise RuntimeError(
        "DATA_DIR is not set; define it in the environment or in a .env file."
    )
os.environ["DATA_DIR"] = _data_dir

# plots and metrics
import matplotlib.pyplot as plt

from src.data.components import calculate_all_wasserstein_metrics
from src.utils.data_generation import generate_data
from src.utils.plotting import apply_mpl_styles, create_and_plot_data, plot_single_jets

# Call the plotting-style helper imported from src.utils.plotting
# (presumably sets the project's matplotlib defaults — confirm in that module).
apply_mpl_styles()

In [None]:
# Experiment file composed on top of train.yaml; `model_name_for_saving` is
# presumably the tag used when results are written out (not used in this cell).
experiment = "jetclass_cond.yaml"
model_name_for_saving = "nb_fm_tops_jetclass"

# Compose the training config with the experiment override applied.
with hydra.initialize(version_base=None, config_path="../configs/"):
    cfg = hydra.compose(
        config_name="train.yaml",
        overrides=["experiment=" + experiment],
    )
# Dump the composed config as YAML so the notebook records what was loaded.
print(OmegaConf.to_yaml(cfg))

# Build the datamodule and the model from their config nodes.
datamodule = hydra.utils.instantiate(cfg.data)
# Set remove_etadiff_tails=False when checking the pT_jet distribution
# calculated from particle pT:
# datamodule.hparams.remove_etadiff_tails = False
model = hydra.utils.instantiate(cfg.model)

# Prepare the datamodule's splits (fills the tensors read in the next cell).
datamodule.setup()

In [None]:
# Convert the datamodule's tensors to numpy arrays, one
# (data, mask, conditioning) triple per split.
test_data, test_mask, test_cond = (
    np.array(datamodule.tensor_test),
    np.array(datamodule.mask_test),
    np.array(datamodule.tensor_conditioning_test),
)
val_data, val_mask, val_cond = (
    np.array(datamodule.tensor_val),
    np.array(datamodule.mask_val),
    np.array(datamodule.tensor_conditioning_val),
)
train_data, train_mask, train_cond = (
    np.array(datamodule.tensor_train),
    np.array(datamodule.mask_train),
    np.array(datamodule.tensor_conditioning_train),
)
# Statistics stored on the datamodule; they are handed to generate_data later.
means = np.array(datamodule.means)
stds = np.array(datamodule.stds)

# One shape per line, ordered test -> val -> train (data, mask, cond each).
print(
    *(
        arr.shape
        for arr in (
            test_data,
            test_mask,
            test_cond,
            val_data,
            val_mask,
            val_cond,
            train_data,
            train_mask,
            train_cond,
        )
    ),
    sep="\n",
)
print(means)
print(stds)

In [None]:
# Checkpoint holding the trained weights (the "last-EMA" file of the run).
ckpt = "/beegfs/desy/user/birkjosc/epic-fm/logs/jetclass_flow_matching/runs/2023-07-27_23-46-46/checkpoints/last-EMA.ckpt"
# `load_from_checkpoint` is a classmethod that builds and returns a new module.
# Call it on the class rather than on the instance: newer Lightning releases
# raise when it is invoked on an instance, while the class-level call works on
# both old and new versions.
model = type(model).load_from_checkpoint(ckpt)

In [None]:
# Oversampling factor: each real jet is repeated `factor` times below.
factor = 1

# Choose between the test and the validation split as the "real" reference.
mask_real, data_real, cond_real = test_mask, test_data, test_cond

# Increase the sample size for better statistics (repeat along the jet axis).
big_mask_real, big_data_real, big_cond_real = (
    np.repeat(arr, factor, axis=0) for arr in (mask_real, data_real, cond_real)
)

In [None]:
# Sample `factor` generated jets per real jet, reusing the (repeated) real
# masks and conditioning so both samples are built from the same inputs.
data_generated, generation_time = generate_data(
    model,
    # sample size and batching
    num_jet_samples=len(mask_real) * factor,
    batch_size=1000,
    # conditioning and per-jet masks taken from the real sample
    variable_set_sizes=True,
    mask=torch.tensor(big_mask_real),
    cond=torch.tensor(big_cond_real),
    # statistics taken from the datamodule
    normalized_data=False,
    means=means,
    stds=stds,
    # ODE integration settings
    ode_solver="midpoint",
    ode_steps=200,
)

In [None]:
# Compare the distributions of the first three particle features between the
# real and the generated sample (log-scaled, normalised histograms).
fig, ax = plt.subplots(1, 3, figsize=(12, 4))
hist_kwargs = dict(bins=100, alpha=0.5, density=True)

# Select entries whose mask is non-zero (presumably the non-padded particles).
# The generated sample holds `factor * len(mask_real)` jets and was produced
# with `big_mask_real`, so it must be indexed with that mask; indexing it with
# `mask_real` only works for factor == 1 and fails with a shape mismatch
# otherwise.
real_particles = mask_real[:, :, 0] != 0
generated_particles = big_mask_real[:, :, 0] != 0

for i in range(3):
    # Boolean indexing already yields a flat array of the selected values.
    ax[i].hist(data_real[:, :, i][real_particles], **hist_kwargs, label="real")
    ax[i].hist(
        data_generated[:, :, i][generated_particles],
        **hist_kwargs,
        label="generated",
        histtype="step",
    )
    ax[i].set_xlabel(f"particle feature {i}")
    ax[i].set_yscale("log")
ax[2].legend(frameon=False)
fig.tight_layout()

In [None]:
# Wasserstein metrics between the real and the generated sample, restricted to
# the first three particle features and evaluated without masks.
w_dists_big = calculate_all_wasserstein_metrics(
    data_real[..., :3],
    data_generated[..., :3],
    None,
    None,
    use_masks=False,
    calculate_efps=True,
    num_batches=factor,
    num_eval_samples=len(data_real),
)

# Report mean +- std for each of the three metrics.
for metric in ("m", "p", "efp"):
    print(
        f"W-Dist {metric}: {w_dists_big[f'w1{metric}_mean']:4.3E}"
        f" +- {w_dists_big[f'w1{metric}_std']:4.3E}"
    )

In [None]:
# NOTE(review): this call passes exactly the same arguments as the one that
# fills `w_dists_big`, so nothing here asks for non-averaged output despite
# the variable name — check whether an additional option was meant to be set.
w_dists_big_non_averaged = calculate_all_wasserstein_metrics(
    data_real[..., :3],
    data_generated[..., :3],
    None,
    None,
    use_masks=False,
    calculate_efps=True,
    num_batches=factor,
    num_eval_samples=len(data_real),
)
# Display the returned result as the cell output.
w_dists_big_non_averaged

In [None]:
# Copy the per-jet conditioning along a new axis 1 so that it lines up with
# axis 1 of the mask: (n_jets, n_cond) -> (n_jets, mask_real.shape[1], n_cond).
n_slots = mask_real.shape[1]
cond_real_repeat = np.repeat(cond_real[:, None, :], n_slots, axis=1)
cond_real_repeat.shape

In [None]:
# Cross-check plots
#   figure 1:
#     - pT_particle / pT_jet (as stored in the dataset)
#     - pT_particle after rescaling with the jet pT
#     - pT_jet recomputed from the constituents vs. the original value
#   figure 2:
#     - jet mass computed from the rescaled pT_particle, eta_rel, phi_rel
#     - jet mass computed from pT_rel, eta_rel, phi_rel

from copy import deepcopy

from src.data.components.utils import calculate_jet_features

hist_kwargs = dict(bins=100, histtype="step")

# Work on a copy so `data_real` keeps the relative pT, then scale the particle
# pT (feature index 2) by the jet pT (conditioning feature 0).
particle_features = deepcopy(data_real)
particle_features[:, :, 2] *= cond_real_repeat[:, :, 0]

# Jet features from both particle representations.
jet_features_rel = calculate_jet_features(data_real)  # based on pT_rel
jet_features = calculate_jet_features(particle_features)  # based on rescaled pT

# Entries whose mask is non-zero.
is_particle = mask_real[:, :, 0] != 0

# --- figure 1: particle pT (relative / rescaled) and jet pT ------------------
fig, ax = plt.subplots(1, 3, figsize=(15, 4))
pt_panels = (
    (ax[0], data_real[:, :, 2], "$p_T^{particle} / p_T^{jet}$"),
    (ax[1], particle_features[:, :, 2], "$p_T^{particle}$"),
)
for panel, pt_values, pt_xlabel in pt_panels:
    panel.hist(pt_values[is_particle].flatten(), **hist_kwargs)
    panel.set_xlabel(pt_xlabel)
    panel.set_yscale("log")

# Note: the jet pT recomputed from the constituent pT does not reproduce the
#       original distribution exactly when the etadiff tails are removed;
#       with all constituents the two distributions should match.
ax[2].hist(jet_features[:, 0], label="Calculated from $p_T^{particle}$", **hist_kwargs)
ax[2].hist(cond_real[:, 0], label="Original value", ls="--", **hist_kwargs)
ax[2].set_xlabel("$p_T^{jet}$")
ax[2].legend(frameon=False)
fig.tight_layout()
plt.show()

# --- figure 2: jet mass from both representations ----------------------------
fig, ax = plt.subplots(1, 2, figsize=(15, 4))
ax[0].hist(
    jet_features[:, 3],
    label="Calculated from $p_T^{particle}$",
    **hist_kwargs,
)
ax[1].hist(
    jet_features_rel[:, 3],
    label="Calculated from $p_T^{particle} / p_T^{jet}$",
    **hist_kwargs,
)
ax[0].set_xlabel("$m_{jet}$ - using $p_T^{particle}$")
ax[1].set_xlabel("$m_{jet}$ - using $p_T^{particle} / p_T^{jet}$")
fig.tight_layout()
plt.show()