In [None]:
import os
import sys

# Make the repository's `src.*` package importable from this notebook
# (assumes the notebook sits one directory below the repo root — confirm).
sys.path.append("../")

%matplotlib inline
# Re-import edited project modules automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import hydra
import numpy as np
import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf

In [None]:
# set env variable DATA_DIR again because of hydra
# (the composed config presumably reads DATA_DIR from the environment — confirm)
from dotenv import load_dotenv

load_dotenv()
data_dir = os.environ.get("DATA_DIR")
if data_dir is None:
    # os.environ only accepts str values; assigning None would fail with an
    # unhelpful "str expected, not NoneType" TypeError.
    raise RuntimeError("DATA_DIR is not set; define it in the environment or in a .env file.")
os.environ["DATA_DIR"] = data_dir

In [None]:
# plots and metrics
import matplotlib.pyplot as plt

from src.data.components import calculate_all_wasserstein_metrics
from src.utils.data_generation import generate_data
from src.utils.plotting import apply_mpl_styles, plot_data, prepare_data_for_plotting

apply_mpl_styles()

In [None]:
# Experiment config name, passed to hydra below as the `experiment=` override.
experiment = "lhco.yaml"

In [None]:
# load everything from experiment config
# Compose train.yaml with the selected experiment override applied on top.
# Inspect the result with `OmegaConf.to_yaml(cfg)` if needed.
experiment_overrides = [f"experiment={experiment}"]
with hydra.initialize(version_base=None, config_path="../configs/"):
    cfg = hydra.compose(config_name="train.yaml", overrides=experiment_overrides)

In [None]:
# Build the datamodule and a freshly configured model from the composed config.
# The model object is replaced further down by the checkpoint-loaded one.
datamodule = hydra.utils.instantiate(cfg.data)
model = hydra.utils.instantiate(cfg.model)

In [None]:
# Presumably populates the tensor_*/mask_* split attributes read in the next cell.
datamodule.setup()

In [None]:
# Pull particle data, masks and conditioning of every split out as numpy arrays.
test_data, test_mask, test_cond = (
    np.array(datamodule.tensor_test),
    np.array(datamodule.mask_test),
    np.array(datamodule.tensor_conditioning_test),
)
val_data, val_mask, val_cond = (
    np.array(datamodule.tensor_val),
    np.array(datamodule.mask_val),
    np.array(datamodule.tensor_conditioning_val),
)
train_data, train_mask, train_cond = (
    np.array(datamodule.tensor_train),
    np.array(datamodule.mask_train),
    np.array(datamodule.tensor_conditioning_train),
)

In [None]:
# Shape overview of every split; labelled so the nine tuples can be told apart.
for array_name, array in [
    ("test_data", test_data),
    ("test_mask", test_mask),
    ("test_cond", test_cond),
    ("val_data", val_data),
    ("val_mask", val_mask),
    ("val_cond", val_cond),
    ("train_data", train_data),
    ("train_mask", train_mask),
    ("train_cond", train_cond),
]:
    print(f"{array_name:<11} {array.shape}")

In [None]:
# Trained flow-matching checkpoint. The default is a machine-specific path;
# set the LHCO_CKPT environment variable to point elsewhere without editing this cell.
ckpt = os.environ.get(
    "LHCO_CKPT",
    "/beegfs/desy/user/ewencedr/deep-learning/logs/lhco_flow_matching/runs/2023-07-22_00-46-21/checkpoints/last-EMA.ckpt",
)
# load_from_checkpoint is a classmethod returning a new model: call it on the
# class, since recent Lightning versions raise TypeError when called on an instance.
model = type(model).load_from_checkpoint(ckpt)

In [None]:
# Number of jets to generate; also the number of test jets used as reference.
samples = 10000

In [None]:
torch.manual_seed(9999)
data, generation_time = generate_data(
    model,
    num_jet_samples=samples,
    batch_size=2048,
    cond=None,
    variable_set_sizes=True,
    mask=torch.tensor(test_mask[:samples]),
    normalized_data=False,
    means=None,
    stds=None,
    ode_solver="midpoint",
    ode_steps=50,
)

In [None]:
data = data[..., [1, 2, 0]]

In [None]:
# Simulated reference: the first `samples` test jets, with features reordered
# exactly like the generated `data`.
background_data = test_data[:samples][..., [1, 2, 0]]

In [None]:
# Options forwarded to plot_data.
plot_config = dict(
    num_samples=-1,
    plot_jet_features=True,
    plot_w_dists=False,
    plot_efps=False,
    plot_selected_multiplicities=False,
    selected_multiplicities=[10, 20, 30, 40, 50, 100],
    selected_particles=[1, 3, 10],
    plottype="sim_data",
    save_fig=False,
    variable_jet_sizes_plotting=True,
    bins=100,
    close_fig=False,
)
# Subset understood by prepare_data_for_plotting; "plot_efps" is renamed to
# "calculate_efps" there. Values are shared with plot_config, not copied.
plot_prep_config = {
    "calculate_efps": plot_config["plot_efps"],
    "selected_multiplicities": plot_config["selected_multiplicities"],
    "selected_particles": plot_config["selected_particles"],
}

In [None]:
# Derived plotting inputs for the generated jets; `data` is wrapped in an extra
# leading axis before being handed to the helper.
prepared_generated = prepare_data_for_plotting(np.array([data]), **plot_prep_config)
jet_data, efps_values, pt_selected_particles, pt_selected_multiplicities = prepared_generated

In [None]:
# Same preparation for the simulated reference jets (`background_data`).
# NOTE(review): generated data goes in as np.array([data]) but this side as a
# plain list — presumably both are accepted by the helper; confirm.
(
    jet_data_sim,
    efps_sim,
    pt_selected_particles_sim,
    pt_selected_multiplicities_sim,
) = prepare_data_for_plotting(
    [background_data],
    **plot_prep_config,
)
# Strip the one-element outer level for the sim outputs.
# NOTE(review): pt_selected_multiplicities_sim is left wrapped — presumably the
# layout plot_data expects for it; confirm before unwrapping.
jet_data_sim, efps_sim, pt_selected_particles_sim = (
    jet_data_sim[0],
    efps_sim[0],
    pt_selected_particles_sim[0],
)

In [None]:
# Overlay generated (`data`) and simulated (`background_data`) distributions.
generated_batch = np.array([data])
fig = plot_data(
    # generated side
    particle_data=generated_batch,
    jet_data=jet_data,
    efps_values=efps_values,
    pt_selected_particles=pt_selected_particles,
    pt_selected_multiplicities=pt_selected_multiplicities,
    # simulated side
    sim_data=background_data,
    jet_data_sim=jet_data_sim,
    efps_sim=efps_sim,
    pt_selected_particles_sim=pt_selected_particles_sim,
    pt_selected_multiplicities_sim=pt_selected_multiplicities_sim,
    **plot_config,
)

# Test: direct sampling with `model.sample` compared against test data

In [None]:
samples = np.array(model.sample(1000, mask=torch.tensor(test_mask), ode_steps=50))

In [None]:
particle_mulitplicity = np.count_nonzero(samples[:, :, 0], axis=1)
sim_data = np.sum(test_mask.squeeze()[: len(particle_mulitplicity)], axis=-1, dtype=int)
x_min, x_max = np.min(sim_data), np.max(sim_data)
binwidth = 1
bins_pm = range(x_min, x_max + binwidth, binwidth)
plt.hist(sim_data, bins=bins_pm, histtype="stepfilled", label="simulated")
plt.hist(particle_mulitplicity, bins=bins_pm, histtype="step", label="generated")
plt.show()

In [None]:
x_min, x_max = test_data[:, :, 0].min(), test_data[:, :, 0].max()
hist1 = plt.hist(
    test_data[:, :, 0].flatten(),
    bins=100,
    range=[x_min, x_max],
    histtype="stepfilled",
    label="simulated",
)
plt.hist(samples[:, :, 0].flatten(), bins=hist1[1], histtype="step", label="generated")
plt.yscale("log")
plt.show()