In [None]:
import os
import sys

sys.path.append("../")

%matplotlib inline
%load_ext autoreload
%autoreload 2

%matplotlib inline
%config InlineBackend.figure_format='retina'

import hydra
import numpy as np
import pytorch_lightning as pl
import torch

# set env variable DATA_DIR again because of hydra
from dotenv import load_dotenv
from omegaconf import OmegaConf

# Read variables from a `.env` file (if any) and re-export DATA_DIR.
load_dotenv()
data_dir = os.environ.get("DATA_DIR")
if data_dir is None:
    # os.environ only accepts strings: assigning None would fail with an
    # unhelpful "str expected, not NoneType" TypeError, so fail explicitly.
    raise RuntimeError(
        "Environment variable DATA_DIR is not set. "
        "Define it in your environment or in a .env file before running this notebook."
    )
os.environ["DATA_DIR"] = data_dir

import logging
from copy import deepcopy

import cplt

# plots and metrics
import matplotlib.pyplot as plt

from src.data.components import (
    calculate_all_wasserstein_metrics,
    inverse_normalize_tensor,
    normalize_tensor,
)
from src.data.components.utils import calculate_jet_features
from src.utils.data_generation import generate_data
from src.utils.plotting import (
    apply_mpl_styles,
    create_and_plot_data,
    plot_particle_features,
    plot_single_jets,
)

# logging for the notebook: put the root logger at INFO level and emit a test message
logging.getLogger().setLevel(logging.INFO)
logger = logging.getLogger()
logging.info("test")

# apply the project's matplotlib styles
apply_mpl_styles()

In [None]:
# Path to the run directory of the model that should be evaluated.
# run_dir = "/beegfs/desy/user/birkjosc/epic-fm/logs/jetclass_cond_jettype/runs/2023-08-07_11-56-01"
run_dir = "/beegfs/desy/user/birkjosc/epic-fm/logs/jetclass_cond_jettype/runs/2023-08-16_19-21-45/"
cfg_backup_file = f"{run_dir}/config.yaml"

# -----------------------------------------------------------
# Backward compatibility: if the run directory has no `config.yaml` yet,
# compose the config from the experiment config and write it to the run directory.
experiment = "jetclass_cond.yaml"
model_name_for_saving = "nb_fm_tops_jetclass"
with hydra.initialize(version_base=None, config_path="../configs/"):
    if not os.path.exists(cfg_backup_file):
        # compose everything from the experiment config
        experiment_overrides = [f"experiment={experiment}"]
        cfg = hydra.compose(config_name="train.yaml", overrides=experiment_overrides)
        print(f"saving config file as {cfg_backup_file}")
        with open(cfg_backup_file, "w") as f:
            OmegaConf.save(cfg, f)
    else:
        print("config file already exists --> loading from run directory")
# -----------------------------------------------------------

# Always read the config back from the run directory
# (safer in terms of reproducing results).
cfg = OmegaConf.load(cfg_backup_file)
print(type(cfg))
print(OmegaConf.to_yaml(cfg))

# Build the datamodule from the stored config and limit the number of jets used.
datamodule = hydra.utils.instantiate(cfg.data)
datamodule.hparams.number_of_used_jets = 100_000
# set remove_etadiff_tails=False when checking the pT_jet distribution calculated from particle pT
# datamodule.hparams.remove_etadiff_tails = False

# Instantiate the model from the config, then restore its weights from the checkpoint.
model = hydra.utils.instantiate(cfg.model)
ckpt = f"{run_dir}/checkpoints/last-EMA.ckpt"
# `load_from_checkpoint` is a classmethod: call it on the class rather than on
# the instance (newer Lightning releases refuse the instance call). For older
# releases this is equivalent to the previous `model.load_from_checkpoint(ckpt)`.
model = type(model).load_from_checkpoint(ckpt)
datamodule.setup()

# ------------------------------------------------
# Convert the datamodule tensors to numpy arrays and print their shapes
# to check that they are what we expect.
test_data, test_mask, test_cond = (
    np.array(datamodule.tensor_test),
    np.array(datamodule.mask_test),
    np.array(datamodule.tensor_conditioning_test),
)
val_data, val_mask, val_cond = (
    np.array(datamodule.tensor_val),
    np.array(datamodule.mask_val),
    np.array(datamodule.tensor_conditioning_val),
)
train_data, train_mask, train_cond = (
    np.array(datamodule.tensor_train),
    np.array(datamodule.mask_train),
    np.array(datamodule.tensor_conditioning_train),
)
means = np.array(datamodule.means)
stds = np.array(datamodule.stds)

# shapes in the order test -> val -> train, each as (data, mask, conditioning)
for split_array in (
    test_data,
    test_mask,
    test_cond,
    val_data,
    val_mask,
    val_cond,
    train_data,
    train_mask,
    train_cond,
):
    print(split_array.shape)
print(means)
print(stds)

In [None]:
# optional: increase the size of the test data for better statistics
FACTOR_REPEAT_MASK_COND = 1  # factor by which the selected test data is repeated
NUMER_OF_GENERATED_JETS = 2_000

# take the first NUMER_OF_GENERATED_JETS jets (choose between test and val here)
_first_jets = slice(0, NUMER_OF_GENERATED_JETS)
mask_real, data_real, cond_real = (
    split_array[_first_jets] for split_array in (test_mask, test_data, test_cond)
)

# repeat every jet FACTOR_REPEAT_MASK_COND times along the jet axis for better statistics
big_mask_real, big_data_real, big_cond_real = (
    np.repeat(selected, FACTOR_REPEAT_MASK_COND, axis=0)
    for selected in (mask_real, data_real, cond_real)
)

In [None]:
# Generate jets with the model, passing the (repeated) conditioning and masks
# of the selected real jets.
generation_options = dict(
    num_jet_samples=FACTOR_REPEAT_MASK_COND * len(mask_real),
    batch_size=1000,
    cond=torch.tensor(big_cond_real),
    variable_set_sizes=True,
    mask=torch.tensor(big_mask_real),
    normalized_data=True,
    means=means,
    stds=stds,
    ode_solver="midpoint",
    ode_steps=200,
)
data_generated, generation_time = generate_data(model, **generation_options)

In [None]:
# list the contents of the run directory (IPython substitutes `{run_dir}` before running the shell command)
!ls -l "{run_dir}"

In [None]:
# Plot the particle features of the real and the generated jets.
# Note that `mask_real` is passed for both mask arguments.
particle_feature_plot_options = dict(
    feature_names=datamodule.names_particle_features,
    plot_path="part_features.pdf",
    also_png=True,
)
plot_particle_features(
    data_real, data_generated, mask_real, mask_real, **particle_feature_plot_options
)

In [None]:
# show the names of the conditioning, jet-level and particle-level features
for feature_names in (
    datamodule.names_conditioning,
    datamodule.names_jet_features,
    datamodule.names_particle_features,
):
    print(feature_names)

In [None]:
# plot the generated features and compare sim. data to gen. data
n_features = data_real.shape[-1]
plot_cols = 3
# ceil division: no superfluous empty row when n_features is a multiple of plot_cols
plot_rows = (n_features + plot_cols - 1) // plot_cols
fig, ax = plt.subplots(plot_rows, plot_cols, figsize=(12, 3 * plot_rows))
ax = ax.flatten()
hist_kwargs = dict(density=True)

# Particle-level validity masks. The generated jets were produced from the
# repeated masks (`big_mask_real`), so those have to be used for them; with
# FACTOR_REPEAT_MASK_COND == 1 both masks are identical.
valid_particles_sim = mask_real[:, :, 0] != 0
valid_particles_gen = big_mask_real[:, :, 0] != 0

for i in range(n_features):
    values_sim = data_real[:, :, i][valid_particles_sim].flatten()
    values_gen = data_generated[:, :, i][valid_particles_gen].flatten()
    # common bin edges so that both histograms are directly comparable
    _, bin_edges = np.histogram(np.concatenate([values_sim, values_gen]), bins=100)
    hist_kwargs["bins"] = bin_edges
    ax[i].hist(values_sim, **hist_kwargs, label="Sim. data", alpha=0.5)
    ax[i].hist(
        values_gen,
        label="Gen. data",
        histtype="step",
        **hist_kwargs,
    )
    ax[i].set_yscale("log")
    ax[i].set_xlabel(datamodule.names_particle_features[i])

# hide axes of the grid that are not used by any feature
for unused_ax in ax[n_features:]:
    unused_ax.set_visible(False)

ax[2].legend(frameon=False)
fig.tight_layout()

In [None]:
# Wasserstein distances between simulated and generated data,
# using only the first three particle features.
wasserstein_options = dict(
    num_eval_samples=len(data_real),
    num_batches=FACTOR_REPEAT_MASK_COND,
    calculate_efps=True,
    use_masks=False,
)
w_dists_big = calculate_all_wasserstein_metrics(
    data_real[..., :3], data_generated[..., :3], None, None, **wasserstein_options
)

# report mean +- std for each of the metrics
print(f"W-Dist m: {w_dists_big['w1m_mean']:4.3E} +- {w_dists_big['w1m_std']:4.3E}")
print(f"W-Dist p: {w_dists_big['w1p_mean']:4.3E} +- {w_dists_big['w1p_std']:4.3E}")
print(f"W-Dist efp: {w_dists_big['w1efp_mean']:4.3E} +- {w_dists_big['w1efp_std']:4.3E}")

In [None]:
# w_dists_big_non_averaged = calculate_all_wasserstein_metrics(
#     data_real[..., :3],
#     data_generated[..., :3],
#     None,
#     None,
#     num_eval_samples=len(data_real),
#     num_batches=factor,
#     calculate_efps=True,
#     use_masks=False,
# )
# w_dists_big_non_averaged

In [None]:
# Crosscheck plots:
# - pT_particle / pT_jet (as in dataset)
# - pT_particle when rescaled with jet pT
# - pT_jet when calculated from constituents
#


# cplt.utils.set_mpl_colours()
# # cplt.utils.reset_mpl_colours()

# fig, ax = plt.subplots(1, 3, figsize=(15, 4))
# hist_kwargs = dict(bins=100, histtype="step")

# # make copy of particle features
# particle_features = deepcopy(data_real)

# # re-scale particle pt with jet pt
# particle_features[:, :, 2] *= np.repeat(
#     spectator_real[:, np.newaxis, :], mask_real.shape[1], axis=1
# )[:, :, 0]

# # calculate jet features (both with pT_rel and pT)
# jet_features_rel = calculate_jet_features(data_real)  # pT_rel
# jet_features_calculated = calculate_jet_features(particle_features)  # pT

# # Note: the jet pt which is calculated from the constituent pt does not
# #       yield exactly the same distribution if the etadiff tails are removed!
# #       the distributions should match though when using all constituents.
# ax[0].hist(data_real[:, :, 2][mask_real[:, :, 0] != 0].flatten(), **hist_kwargs)
# ax[0].set_xlabel("$p_T^{particle} / p_T^{jet}$")
# ax[1].hist(particle_features[:, :, 2][mask_real[:, :, 0] != 0].flatten(), **hist_kwargs)
# ax[1].set_xlabel("$p_T^{particle}$")
# ax[0].set_yscale("log")
# ax[1].set_yscale("log")
# ax[2].hist(jet_features_calculated[:, 0], **hist_kwargs, label="Calculated from $p_T^{particle}$")
# ax[2].hist(cond_real[:, 0], **hist_kwargs, label="Original value", ls="--")
# ax[2].legend(frameon=False)
# ax[2].set_xlabel("$p_T^{jet}$")
# fig.tight_layout()
# plt.show()

# # - jet mass calculated from rescaled pT_particle and eta_rel, phi_rel
# # - jet mass calculated from pT_rel, eta_rel, phi_rel

# fig, ax = plt.subplots(1, 2, figsize=(13, 5))
# hist_kwargs = dict(histtype="step", density=True, linewidth=2)

# import yaml

# # load labels from labels.yaml
# with open("../configs/plotting/labels.yaml", "r") as f:
#     labels = yaml.load(f, Loader=yaml.SafeLoader)
#     latex_labels = labels["latex_labels"]
#     print(latex_labels)


# for i, conditioning_variable in enumerate(datamodule.names_conditioning):
#     # print(jet_type)
#     if "jet_type" not in conditioning_variable:
#         continue
#     mask = cond_real[:, i] == 1
#     jet_type = conditioning_variable.split("jet_type_")[-1]
#     hist_kwargs["linestyle"] = (
#         "solid"
#         if i < len(cplt.utils.get_good_colours())
#         else cplt.utils.get_good_linestyles("densely dotted")
#     )
#     ax[0].hist(
#         jet_features_calculated[:, 3][mask],
#         label=latex_labels[jet_type],
#         bins=np.linspace(0, 300, 60),
#         **hist_kwargs,
#     )
#     ax[0].set_xlabel("$m_\\mathrm{jet}$ (using $p_\\mathrm{T}^\\mathrm{particle}$)")
#     ax[0].set_ylabel("Normalized")
#     ax[1].hist(
#         jet_features_rel[:, 3][mask],
#         label=latex_labels[jet_type],
#         bins=np.linspace(0, 0.6, 60),
#         **hist_kwargs,
#     )
#     ax[1].set_xlabel(
#         "$m_\\mathrm{jet}$ (using $p_\\mathrm{T}^\\mathrm{particle} /"
#         " p_\\mathrm{T}^\\mathrm{jet}$)"
#     )
#     ax[1].set_ylabel("Normalized")
# ax[0].legend(frameon=False)
# fig.tight_layout()
# fig.savefig("jet_mass_comparison.pdf", bbox_inches="tight")
# plt.show()

In [None]:
# Per jet type: distribution of the number of constituents (left) and of the
# first particle feature, labelled eta_rel (right), for the real jets.
fig, ax = plt.subplots(1, 2, figsize=(13, 5))
hist_kwargs = dict(histtype="step", density=True, linewidth=2)

import yaml

# load labels from labels.yaml
with open("../configs/plotting/labels.yaml", "r") as f:
    labels = yaml.load(f, Loader=yaml.SafeLoader)
    latex_labels = labels["latex_labels"]
    print(latex_labels)

# number of constituents per jet, counted as entries with non-zero feature 2
n_particles_real = np.sum(data_real[:, :, 2] != 0, axis=1)
print(n_particles_real)

# particle-level validity mask, shape (n_jets, n_constituents)
is_valid_particle = mask_real[:, :, 0] != 0

for i, conditioning_variable in enumerate(datamodule.names_conditioning):
    if "jet_type" not in conditioning_variable:
        continue
    mask_this_jet_type = cond_real[:, i] == 1
    jet_type = conditioning_variable.split("jet_type_")[-1]
    hist_kwargs["linestyle"] = (
        "solid"
        if i < len(cplt.utils.get_good_colours())
        else cplt.utils.get_good_linestyles("densely dotted")
    )
    ax[0].hist(
        n_particles_real[mask_this_jet_type],
        label=latex_labels[jet_type],
        bins=np.linspace(-5.5, 120.5, 127),
        **hist_kwargs,
    )
    # Broadcast the jet-level selection over the constituent axis instead of
    # repeating it a hard-coded 128 times, so any number of constituents works.
    mask_this_jet_type_and_isvalid = is_valid_particle & mask_this_jet_type[:, np.newaxis]
    ax[1].hist(
        data_real[:, :, 0][mask_this_jet_type_and_isvalid].flatten(),
        label=latex_labels[jet_type],
        bins=np.linspace(-1.1, 1.1, 100),
        **hist_kwargs,
    )

# axis cosmetics do not depend on the jet type, so they are set once
ax[0].set_xlabel("Number of jet constituents")
ax[0].set_ylabel("Normalized")
ax[1].set_xlabel("$\\eta^\\mathrm{rel}$")
ax[1].set_ylabel("Normalized")
ax[1].set_yscale("log")
ax[0].legend(frameon=False)
fig.tight_layout()
fig.savefig("num_constituents_and_etarel.pdf", bbox_inches="tight")
plt.show()

In [None]:
# plot the particle features and the (relative) jet mass for each jet type
# individually and compare between generated and real jets

# calculate jet features
jet_features_real = calculate_jet_features(data_real)
jet_features_generated = calculate_jet_features(data_generated)

# Collect the selections first, so that the figure gets exactly one row per
# selection (no hard-coded row count, no gaps for non-jet-type conditioning).
# Each entry: (title, jet-level mask for real jets, jet-level mask for generated jets).
# The generated jets were produced from the repeated conditioning/masks
# (`big_cond_real`, `big_mask_real`); with FACTOR_REPEAT_MASK_COND == 1 these
# are identical to the ones of the real jets.
jet_selections = [
    (
        "All jets types",
        np.ones(len(cond_real), dtype=bool),
        np.ones(len(big_cond_real), dtype=bool),
    )
]
for i, conditioning_variable in enumerate(datamodule.names_conditioning):
    if "jet_type" not in conditioning_variable:
        continue
    jet_selections.append(
        (
            conditioning_variable.split("jet_type_")[-1],
            cond_real[:, i] == 1,
            big_cond_real[:, i] == 1,
        )
    )

# four columns: three particle features + jet mass
n_rows = len(jet_selections)
fig, ax = plt.subplots(n_rows, 4, figsize=(18, 2.75 * n_rows), squeeze=False)
hist_kwargs = dict(bins=100, density=True)

# particle-level validity masks, shape (n_jets, n_constituents)
valid_particles_real = mask_real[:, :, 0] != 0
valid_particles_generated = big_mask_real[:, :, 0] != 0

for row, (jet_type, mask, mask_generated) in enumerate(jet_selections):
    # combine jet-level selection and particle validity via broadcasting
    mask_particle_level = valid_particles_real & mask[:, np.newaxis]
    mask_particle_level_generated = valid_particles_generated & mask_generated[:, np.newaxis]

    ax[row, 0].set_title(jet_type)
    # particle features; generated data re-uses the bin edges of the real data
    for j in range(3):
        _, bin_edges, _ = ax[row, j].hist(
            data_real[:, :, j][mask_particle_level],
            **hist_kwargs,
            label="Sim. data",
            histtype="stepfilled",
            alpha=0.5,
        )
        ax[row, j].hist(
            data_generated[:, :, j][mask_particle_level_generated],
            bins=bin_edges,
            density=True,
            label="Gen. data",
            histtype="step",
        )
        ax[row, j].set_yscale("log")
    ax[row, 0].legend(frameon=False)
    ax[row, 0].set_xlabel("$\\eta_\\mathrm{rel}$")
    ax[row, 1].set_xlabel("$\\phi_\\mathrm{rel}$")
    ax[row, 2].set_xlabel("$p_\\mathrm{T}^\\mathrm{rel}$")
    # jet mass
    ax[row, 3].hist(
        jet_features_real[:, 3][mask],
        **hist_kwargs,
        label="Sim. data",
        histtype="stepfilled",
        alpha=0.5,
    )
    ax[row, 3].hist(
        jet_features_generated[:, 3][mask_generated],
        **hist_kwargs,
        label="Gen. data",
        histtype="step",
    )
    ax[row, 3].set_xlabel("$m_\\mathrm{jet}$ (using $p_\\mathrm{T}^\\mathrm{rel}$)")
fig.tight_layout()
plt.show()

In [None]:
# Investigate the relative jet mass a bit.
# Important note: the relative jet mass only depends on the direction and the
# relative momentum of the constituents, not on their absolute momentum.
# Two jets with different momenta but the same constituents (in terms of
# direction and relative pT) will therefore have the same relative jet mass.


def _with_scaled_pt(constituents, factor):
    """Return a copy of `constituents` with the third feature (pt_rel) multiplied by `factor`."""
    scaled = deepcopy(constituents)
    scaled[:, :, 2] *= factor
    return scaled


# two artificial jets with two constituents each;
# constituent coordinates: (eta_rel, phi_rel, pt_rel)
jet_constituents_artificial = np.array(
    [
        [[-1, 0, 0.5], [1, 0, 0.5]],
        [[-0.5, 0, 0.5], [0.5, 0, 0.5]],
    ]
)
jet_features_artificial = calculate_jet_features(jet_constituents_artificial)
print(jet_features_artificial)

# same jets with pt_rel scaled by 60
jet_constituents_artificial60 = _with_scaled_pt(jet_constituents_artificial, 60)
jet_features_artificial60 = calculate_jet_features(jet_constituents_artificial60)
print(jet_features_artificial60)

# same jets with pt_rel scaled by 100
jet_constituents_artificial100 = _with_scaled_pt(jet_constituents_artificial, 100)
jet_features_artificial100 = calculate_jet_features(jet_constituents_artificial100)

print(jet_features_artificial100)

In [None]:
# Scratch cell: how boolean masks index numpy arrays.

# 3D example: two "jets" with two constituents and three features each
arr = np.array(
    [
        [[-1, 0, 0.5], [1, 0, 0.5]],
        [[-0.5, 0, 0.5], [0.5, 0, 0.5]],
    ]
)
print(arr)
# a boolean mask along the first axis keeps only the first entry
mask = np.array([True, False])
print()
print(arr[mask])

# comparing one feature yields a 2D boolean array
print(np.greater(arr[..., 0], 0).shape)

# 2D example: select the rows whose first column is larger than 2
arr = np.array([[1, 2, 3], [4, 5, 6], [1, 2, 3]])
mask = arr[:, 0] > 2
print(mask.shape)
arr[mask]