# Evaluation of Training Runs on the JetNet Dataset (Single Jet Type)

## Imports, data and checkpoint loading

### Imports

In [None]:
import os
import sys

# Make the parent directory importable so the `src.` package imports in the
# following cells resolve (assumes the notebook lives in a direct
# subdirectory of the repository root — TODO confirm).
sys.path.append("../")

# IPython magics: draw figures inline and reload changed modules before
# each cell runs, so edits to `src/` are picked up without a kernel restart.
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
import hydra
import numpy as np
import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf

In [None]:
# Load variables from a `.env` file (if present) and make sure DATA_DIR is
# present in os.environ. It is needed "because of hydra"; presumably the
# configs interpolate it from the environment — TODO confirm.
from dotenv import load_dotenv

load_dotenv()
data_dir = os.environ.get("DATA_DIR")
if data_dir is None:
    # Assigning None to os.environ would raise an opaque TypeError; fail
    # with a message that says what is actually missing instead.
    raise RuntimeError(
        "DATA_DIR is not set: define it in the environment or in a .env file"
    )
os.environ["DATA_DIR"] = data_dir

In [None]:
# plots and metrics
import matplotlib.pyplot as plt

from src.data.components import calculate_all_wasserstein_metrics
from src.utils.data_generation import generate_data
from src.utils.plotting import apply_mpl_styles, plot_data, prepare_data_for_plotting

# Apply the matplotlib styles defined in src.utils.plotting to every figure
# produced later in this notebook.
apply_mpl_styles()

### Load model and datamodule from the selected experiment

In [None]:
# Experiment config applied as the hydra `experiment=` override in the next cell.
experiment: str = "fm_tops_slurm.yaml"

In [None]:
# load everything from experiment config
with hydra.initialize(version_base=None, config_path="../configs/"):
    cfg = hydra.compose(config_name="train.yaml", overrides=[f"experiment={experiment}"])
    # print(OmegaConf.to_yaml(cfg))

In [None]:
# Build the datamodule and model objects from their config nodes.
instantiate = hydra.utils.instantiate
datamodule = instantiate(cfg.data)
model = instantiate(cfg.model)

In [None]:
# Run the datamodule's setup before reading its split attributes
# (tensor_test, mask_test, means, stds, ...) in the next cell — presumably
# those are populated here; TODO confirm.
datamodule.setup()

In [None]:
# Copy every split (particle data, mask, conditioning) and the means/stds
# statistics out of the datamodule as numpy arrays, one split per statement.
test_data, test_mask, test_cond = (
    np.array(datamodule.tensor_test),
    np.array(datamodule.mask_test),
    np.array(datamodule.tensor_conditioning_test),
)
val_data, val_mask, val_cond = (
    np.array(datamodule.tensor_val),
    np.array(datamodule.mask_val),
    np.array(datamodule.tensor_conditioning_val),
)
train_data, train_mask, train_cond = (
    np.array(datamodule.tensor_train),
    np.array(datamodule.mask_train),
    np.array(datamodule.tensor_conditioning_train),
)
means, stds = np.array(datamodule.means), np.array(datamodule.stds)

In [None]:
# Sanity check: shapes of every split, then the means/stds values themselves.
for split_array in (
    test_data,
    test_mask,
    test_cond,
    val_data,
    val_mask,
    val_cond,
    train_data,
    train_mask,
    train_cond,
):
    print(split_array.shape)
for statistic in (means, stds):
    print(statistic)

### Load checkpoint

In [None]:
# Checkpoint file to evaluate (absolute path on the cluster file system).
ckpt = "/beegfs/desy/user/ewencedr/deep-learning/logs/150 t ptmasscond onlymetrics/runs/2023-06-14_18-45-03/checkpoints/epoch_9812_w1m_0.00019521-EMA.ckpt"
# `load_from_checkpoint` is a classmethod that returns a NEW model. Call it on
# the class, not on the instance: recent Lightning releases raise a TypeError
# for instance calls. Rebind `model` to the restored object.
model = type(model).load_from_checkpoint(ckpt)

## Data generation

In [None]:
dataset = "test"  # one of "test" or "val"
# Negative values are read as a multiple of the dataset length,
# e.g. -2 -> 2 * len(dataset).
num_samples = -1

In [None]:
# Pick the reference ("background") split that generated jets are compared to.
available_splits = {
    "test": (test_data, test_mask, test_cond),
    "val": (val_data, val_mask, val_cond),
}
if dataset not in available_splits:
    raise ValueError("Choose from test and val")
background_data, background_mask, background_cond = available_splits[dataset]
print(background_data.shape)

In [None]:
# Resize the background set to the requested number of samples.
# Generation below draws exactly len(background_data) jets, so data, mask and
# conditioning are always resized together:
#   negative -> repeat every background jet abs(num_samples) times,
#   positive -> keep only the first `num_samples` background jets
#               (previously a positive value was silently ignored),
#   zero     -> use the background set unchanged.
if num_samples < 0:
    factor = abs(num_samples)
    num_samples = len(background_data) * factor
    background_data = np.repeat(background_data, factor, axis=0)
    background_mask = np.repeat(background_mask, factor, axis=0)
    background_cond = np.repeat(background_cond, factor, axis=0)
elif num_samples > 0:
    if num_samples > len(background_data):
        raise ValueError(
            f"num_samples={num_samples} exceeds the {len(background_data)} available "
            "background jets; use a negative value to repeat the dataset"
        )
    background_data = background_data[:num_samples]
    background_mask = background_mask[:num_samples]
    background_cond = background_cond[:num_samples]
print(background_data.shape)

In [None]:
# Seed torch's RNG before sampling so the generated jets are repeatable.
torch.manual_seed(9999)
# One generated jet per background jet. The background masks and
# conditioning are handed to the generator (presumably so the generated jets
# reuse the background set sizes and conditions — see generate_data).
generation_kwargs = dict(
    num_jet_samples=len(background_data),
    batch_size=1000,
    cond=torch.tensor(background_cond),
    variable_set_sizes=True,
    mask=torch.tensor(background_mask),
    normalized_data=True,
    means=means,
    stds=stds,
    ode_solver="midpoint",
    ode_steps=200,
)
data, generation_time = generate_data(model, **generation_kwargs)

In [None]:
# Report total and per-jet wall-clock generation time and the output shape.
n_background_jets = len(background_data)
time_per_jet = generation_time / n_background_jets
print(f"Generation time: {generation_time:.2f}s")
print(f"Generation time per jet: {time_per_jet:.5f}s")
print(data.shape)

#### Save generated data to an `.npy` file

In [None]:
# Write the generated samples into the run directory, i.e. two levels above
# the checkpoint file (<run>/checkpoints/<file>.ckpt -> <run>/).
# os.path avoids the old split/join logic, which produced a root-anchored
# "/generated_data.npy" for short relative checkpoint paths.
path = os.path.dirname(os.path.dirname(ckpt))
file_name = "generated_data.npy"
full_path = os.path.join(path, file_name)
print(full_path)

In [None]:
# Persist the generated jets so the evaluation section below can reload them
# without rerunning generation.
np.save(full_path, data)

## Evaluation

### Load data

In [None]:
# Locate the generated samples in the run directory, i.e. two levels above
# the checkpoint file (<run>/checkpoints/<file>.ckpt -> <run>/).
# os.path avoids split/join string logic, which produced a root-anchored
# "/generated_data.npy" for short relative checkpoint paths.
path = os.path.dirname(os.path.dirname(ckpt))
file_name = "generated_data.npy"
full_path = os.path.join(path, file_name)
print(full_path)

In [None]:
# Reload the generated jets from disk; this rebinds `data`, so the evaluation
# can run without regenerating.
data = np.load(full_path)

### Wasserstein distances

In [None]:
# Wasserstein-1 distances between background and generated jets, reported as
# mean +- std over 40 batches of 10k samples.
w_dists = calculate_all_wasserstein_metrics(
    background_data,
    data,
    num_eval_samples=10_000,
    num_batches=40,
)

for label, metric in (("m", "w1m"), ("p", "w1p"), ("efp", "w1efp")):
    mean_value = w_dists[metric + "_mean"]
    std_value = w_dists[metric + "_std"]
    print(f"W-Dist {label}: {mean_value:4.3E} +- {std_value:4.3E}")

### Plots

In [None]:
# Options forwarded to plot_data in the plotting cell below.
plot_config = {
    "num_samples": -1,
    "plot_jet_features": True,
    "plot_w_dists": False,
    "plot_efps": False,
    "plot_selected_multiplicities": True,
    "selected_multiplicities": [10, 20, 30, 40, 50, 100],
    "selected_particles": [1, 3, 10],
    "plottype": "sim_data",
    "save_fig": False,
    "variable_jet_sizes_plotting": True,
    "bins": 100,
    "close_fig": False,
}
# prepare_data_for_plotting takes only a subset of these options, and names
# the EFP switch `calculate_efps` rather than `plot_efps`.
_prep_key_map = {
    "plot_efps": "calculate_efps",
    "selected_multiplicities": "selected_multiplicities",
    "selected_particles": "selected_particles",
}
plot_prep_config = {
    prep_key: plot_config[plot_key] for plot_key, prep_key in _prep_key_map.items()
}

In [None]:
# Plotting preprocessing for the generated jets; the single generated sample
# is wrapped in a leading axis of length 1.
generated_batch = np.array([data])
prepared_generated = prepare_data_for_plotting(generated_batch, **plot_prep_config)
(
    jet_data,
    efps_values,
    pt_selected_particles,
    pt_selected_multiplicities,
) = prepared_generated

In [None]:
# The same preprocessing for the reference (background) jets, passed as a
# one-element list. The first entry of three of the outputs is kept.
prepared_sim = prepare_data_for_plotting([background_data], **plot_prep_config)
(
    jet_data_sim,
    efps_sim,
    pt_selected_particles_sim,
    pt_selected_multiplicities_sim,
) = prepared_sim
jet_data_sim = jet_data_sim[0]
efps_sim = efps_sim[0]
pt_selected_particles_sim = pt_selected_particles_sim[0]
# NOTE(review): pt_selected_multiplicities_sim is left un-indexed, unlike the
# three outputs above — confirm plot_data expects it in this form.

# Reference particles with the mask appended along the last axis.
sim_data = np.concatenate([background_data, background_mask], axis=-1)

In [None]:
# Gather the generated and reference inputs for plot_data. The display
# options in plot_config are passed alongside them as keyword arguments.
plot_inputs = {
    "particle_data": np.array([data]),
    "sim_data": sim_data,
    "jet_data_sim": jet_data_sim,
    "jet_data": jet_data,
    "efps_sim": efps_sim,
    "efps_values": efps_values,
    "pt_selected_particles": pt_selected_particles,
    "pt_selected_multiplicities": pt_selected_multiplicities,
    "pt_selected_particles_sim": pt_selected_particles_sim,
    "pt_selected_multiplicities_sim": pt_selected_multiplicities_sim,
}
fig = plot_data(**plot_inputs, **plot_config)