In [None]:
import os
import sys

sys.path.append("../")

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
import hydra
import numpy as np
import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf

In [None]:
# from src.models.zuko.utils import odeint

In [None]:
# Load variables from a local .env file into the process environment and make
# sure DATA_DIR is exported before the hydra config is composed below.
from dotenv import load_dotenv

load_dotenv()
# os.environ only accepts str values: assigning the None returned by
# os.environ.get() for a missing variable would raise an opaque TypeError.
# Fail early with an actionable message instead.
if os.environ.get("DATA_DIR") is None:
    raise RuntimeError(
        "DATA_DIR is not set; define it in the environment or in a .env file."
    )
os.environ["DATA_DIR"] = os.environ["DATA_DIR"]

In [None]:
experiment = "fm_tops.yaml"

In [None]:
# load everything from experiment config
# Compose `train.yaml` from ../configs/ with the `experiment` override chosen
# in the `experiment` variable, then print the composed config as YAML.
with hydra.initialize(version_base=None, config_path="../configs/"):
    cfg = hydra.compose(config_name="train.yaml", overrides=[f"experiment={experiment}"])
    print(OmegaConf.to_yaml(cfg))

In [None]:
# Build the datamodule and the model from the `data` and `model` sections of
# the composed hydra config.
datamodule = hydra.utils.instantiate(cfg.data)
model = hydra.utils.instantiate(cfg.model)

In [None]:
model_name_for_saving = "nb_fm_tops30"

In [None]:
# Run the datamodule's setup step; the following cells read `tensor_test`,
# `mask_test`, `means` and `stds` from it.
datamodule.setup()

In [None]:
# Copy the test split, its mask and the normalisation statistics out of the
# datamodule as NumPy arrays for the evaluation cells below.
test_data = np.array(datamodule.tensor_test)
test_mask = np.array(datamodule.mask_test)
means = np.array(datamodule.means)
stds = np.array(datamodule.stds)

In [None]:
# Sanity check: shapes of the test data / mask and the normalisation statistics.
print(test_data.shape)
print(test_mask.shape)
print(means)
print(stds)

In [None]:
# Sum the mask over its second-to-last axis (presumably the particle axis, which
# would give the number of real constituents per jet - confirm) and print a
# two-column table of (summed value, number of occurrences).
unique, counts = np.unique(test_mask.sum(axis=-2), return_counts=True)
print(np.stack((unique, counts), axis=1))

In [None]:
from src.callbacks.ema import EMA
from src.callbacks.jetnet_eval import JetNetEvaluationCallback

In [None]:
from pytorch_lightning.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
    ModelSummary,
    RichProgressBar,
)

# Keep the single best checkpoint according to "val/loss" (lower is better) in
# addition to the most recent one; only the weights are stored, under
# ./logs/<model_name_for_saving>/checkpoints.
checkpoint_callback = ModelCheckpoint(
    monitor="val/loss",
    mode="min",
    save_top_k=1,
    save_last=True,
    save_weights_only=True,
    dirpath=f"./logs/{model_name_for_saving}/checkpoints",
)
# Stop when "val/loss" has not improved by at least 1e-4 for 10 checks.
# NOTE(review): early_stopping, rich_progress_bar and jetnet_eval_callback are
# constructed in this cell but are not in the callbacks list given to
# pl.Trainer in the training cell of this notebook.
early_stopping = EarlyStopping(
    monitor="val/loss", mode="min", patience=10, verbose=True, min_delta=0.0001
)
# Log the learning rate once per epoch.
lr_monitor = LearningRateMonitor(logging_interval="epoch")
model_summary = ModelSummary()
rich_progress_bar = RichProgressBar()

# Project-specific evaluation callback (src.callbacks.jetnet_eval).
# NOTE(review): image_path is an absolute, machine-specific path, and the
# meaning of logger=2 is defined by the callback - confirm against its source.
jetnet_eval_callback = JetNetEvaluationCallback(
    every_n_epochs=3,
    num_jet_samples=10000,
    logger=2,
    log_w_dists=False,
    image_path="/beegfs/desy/user/ewencedr/deep-learning/logs/comet_logs",
)

# Project-specific EMA callback (src.callbacks.ema) with decay 0.9999.
ema = EMA(
    decay=0.9999,
)

In [None]:
# Scratch check: turn a length-3 vector into a (3, 3) tensor by adding a
# trailing axis and repeating each entry 3 times along it.
tensor = torch.tensor([1, 2, 3]).unsqueeze(-1).repeat_interleave(3, dim=-1)
print(tensor)

In [None]:
from pytorch_lightning.loggers import CometLogger, CSVLogger, WandbLogger

# All loggers write below ./logs/<model_name_for_saving>/.
csv_logger = CSVLogger(f"./logs/{model_name_for_saving}/csv_logs")
# Comet credentials are read from the environment rather than hard-coded.
# NOTE(review): comet_logger is constructed here but only csv_logger and
# wandb_logger are passed to pl.Trainer in the training cell of this notebook.
comet_logger = CometLogger(
    api_key=os.environ.get("COMET_API_KEY"),
    workspace=os.environ.get("COMET_WORKSPACE"),  # Optional
    save_dir=f"./logs/{model_name_for_saving}/comet_logs",  # Optional
    project_name="Flow Matching",  # Optional
    rest_api_key=os.environ.get("COMET_REST_API_KEY"),  # Optional
    experiment_key=os.environ.get("COMET_EXPERIMENT_KEY"),  # Optional
    experiment_name=model_name_for_saving,  # Optional
    offline=False,
)
wandb_logger = WandbLogger(
    project="Flow Matching", name=model_name_for_saving, save_dir=f"./logs/{model_name_for_saving}"
)

In [None]:
# NOTE(review): this puts the model in eval mode right before trainer.fit in the
# next cell; presumably Lightning switches it back to train mode - confirm.
model.eval()

In [None]:
# Train for at most 30 epochs on a GPU with checkpointing, LR monitoring, model
# summary and EMA callbacks, logging to the CSV and Weights & Biases loggers.
trainer = pl.Trainer(
    max_epochs=30,
    callbacks=[checkpoint_callback, lr_monitor, model_summary, ema],
    logger=[csv_logger, wandb_logger],
    accelerator="gpu",
)
# Select the "medium" float32 matmul precision setting before fitting.
torch.set_float32_matmul_precision("medium")
trainer.fit(
    model=model,
    datamodule=datamodule,
)

In [None]:
# Checkpoints used previously, kept for reference:
# ckpt = "/beegfs/desy/user/ewencedr/deep-learning/logs/transformer/runs/2023-03-30_13-07-29/checkpoints/epoch_1861_loss_0.126.ckpt"
# ckpt = "./logs/nb_fm_tops30/checkpoints/last-v17.ckpt"  # mass conditioning
# ckpt = "/beegfs/desy/user/ewencedr/deep-learning/logs/SWD/runs/2023-04-18_14-51-12/checkpoints/epoch_1752_loss_0.02817.ckpt"
# ckpt = "/beegfs/desy/user/ewencedr/deep-learning/logs/normalize t local 5000epochs/runs/2023-05-14_03-00-43/checkpoints/epoch_3503_loss_5.96382.ckpt"
# ckpt = "/beegfs/desy/user/ewencedr/deep-learning/logs/150 particles sigma 1e-6/runs/2023-05-26_18-47-53/checkpoints/epoch_9540_w1p_0.00500000.ckpt"
# Checkpoint that is actually loaded (the line above used to be assigned and
# then immediately overwritten by this one).
ckpt = "/beegfs/desy/user/ewencedr/deep-learning/logs/fm_tops-150-1/runs/2023-06-01_15-35-44/checkpoints/last.ckpt"
# load_from_checkpoint is a classmethod that builds a new module from the
# checkpoint. Call it on the class: recent Lightning releases refuse to run it
# on an instance, and going through type(model) works on old and new versions.
model = type(model).load_from_checkpoint(ckpt)

In [None]:
from src.data.components.utils import jet_masses

# Compute jet masses for the test set with the project helper and add a
# trailing axis; the first 100 are used for sampling in the next cell.
# NOTE(review): a later cell defines its own module-level `jet_masses`, which
# shadows this import if cells are executed out of order.
print(test_data.shape)
masses = jet_masses(torch.tensor(test_data)).unsqueeze(-1)
print(masses.shape)
print(masses[:100].shape)

In [None]:
# Switch to eval mode on the GPU and draw 100 samples without tracking
# gradients. The first 100 test-jet masses are passed as the second positional
# argument (presumably the conditioning input - confirm against model.sample).
model.eval().cuda()
with torch.no_grad():
    x_samples = model.sample(100, masses[:100]).cpu().numpy()

In [None]:
print(x_samples.shape)

# Evaluation

In [None]:
import matplotlib.pyplot as plt

from src.utils.plotting import apply_mpl_styles, create_and_plot_data, plot_single_jets

apply_mpl_styles()

In [None]:
print(type(model))

## Histograms

In [None]:
# Generate jets with the model and plot them against the test data using the
# project helper. Returns the figure, the generated data (indexed per model as
# data[0] below) and generation times.
# Figures are written below ./logs/nb_plots/ under the name "fm_tops_nb"
# (presumably - confirm how the helper combines save_folder and save_name).
fig, data, generation_times = create_and_plot_data(
    np.array(test_data),
    [model],
    cond=None,
    save_name="fm_tops_nb",
    labels=["FM"],
    mask=test_mask,
    num_jet_samples=10000,
    batch_size=1000,
    variable_set_sizes=True,
    normalized_data=[True, True],
    means=means,
    stds=stds,
    save_folder="./logs/nb_plots/",
    plottype="sim_data",
    plot_jet_features=True,
    plot_w_dists=True,
    plot_selected_multiplicities=True,
    selected_multiplicities=[1, 3, 5, 10, 20, 30],
    ode_solver="midpoint",
    ode_steps=100,
)

In [None]:
# print(data[0][0])

In [None]:
print(data)

In [None]:
from src.data.components import calculate_all_wasserstein_metrics

In [None]:
print(data[0].shape)
print(test_mask.shape)
print(test_data.shape)

In [None]:
test_data.shape

In [None]:
particle_data = data[0]
# Integer mask with a trailing singleton axis: 1 where the first feature of a
# generated particle is non-zero, 0 where it is exactly zero.
mask_data = (particle_data[..., 0] != 0).astype(int)
# print(mask_data)
mask_data = mask_data[..., np.newaxis]
# print(np.count_nonzero(mask_data, axis=-2))
# print(np.count_nonzero(test_data[:len(particle_data),:, :3], axis=-2))

In [None]:
# print(particle_data[0])
print(particle_data[:, :, :3])

In [None]:
# Scratch comparison of two ways to flag non-empty generated particles:
#   zeros1    - the norm of the first three features is non-zero
#   mask_data - the first feature alone is non-zero
# NOTE(review): this rebinds `mask_data` to a boolean array without the
# trailing axis created in the previous mask cell.
zeros1 = ~(np.linalg.norm(particle_data[:, :, :3], axis=-1) == 0)
print(zeros1)
print(zeros1.shape)
mask_data = ~(particle_data[..., 0] == 0)
print(mask_data)
print(mask_data.shape)
# Number of entries on which both definitions agree.
print(np.count_nonzero((mask_data == zeros1).astype(int)))
# Entries flagged by the first-feature test but not by the norm test.
print(~zeros1 * mask_data)
# Reminder of how bool * bool behaves in Python.
print(True * True)
print(False * False)
print(True * False)

In [None]:
print(int(10004 // 5))

In [None]:
# Rebuild the integer mask of the generated jets (1 where the first feature is
# non-zero, with a trailing singleton axis) and compute Wasserstein metrics
# between the first len(particle_data) test jets and the generated jets.
particle_data = data[0]
mask_data = (particle_data[..., 0] == 0).astype(int)
# print(mask_data.shape)
# print(test_mask[:len(particle_data),:,0].astype(bool))
mask_data = np.expand_dims(mask_data, axis=-1)
mask_data = 1 - mask_data
print(mask_data.shape)
print(mask_data)
# Single batch containing every generated jet.
w_dists_1b = calculate_all_wasserstein_metrics(
    test_data[: len(particle_data), :, :3],
    particle_data,
    test_mask[: len(particle_data)],
    mask_data,
    num_eval_samples=len(particle_data),
    num_batches=1,
    calculate_efps=True,
    use_masks=True,
)
# Five batches, each one fifth of the generated jets in size.
w_dists = calculate_all_wasserstein_metrics(
    test_data[: len(particle_data), :, :3],
    particle_data,
    test_mask[: len(particle_data)],
    mask_data,
    num_eval_samples=int(len(particle_data) / 5),
    num_batches=5,
    calculate_efps=True,
    use_masks=True,
)
print(w_dists_1b)
print(w_dists)
# Values noted from earlier runs (which metric they refer to is not recorded here):
# 0.0029 using mask
# 0.0004 using zero exclude without
# 0.00018 without zero exclude

## Simulated Data

In [None]:
# Plot individual jets from the simulated test data with the project helper.
fig = plot_single_jets(test_data, save_folder="./logs/nb_plots/")
plt.show()

## Generated Data

In [None]:
# Plot individual generated jets (data[0]) in blue under the name "gen_jets".
fig = plot_single_jets(
    data[0], save_folder="./logs/nb_plots/", save_name="gen_jets", color="#0271BB"
)
plt.show()

In [None]:
# Evaluate log_prob of the model's first flow on the first five test jets,
# on the GPU and without tracking gradients.
# NOTE(review): assumes the loaded model exposes `flows[0].log_prob` - confirm.
model.cuda()
with torch.no_grad():
    log_p = model.flows[0].log_prob(torch.tensor(test_data[:5]).float().cuda())
print(log_p)

In [None]:
model.eval()

In [None]:
print(type(test_mask))

In [None]:
model.eval().cuda()

# Draw 100 samples without tracking gradients. Only the first 100 test masks
# are passed so that the mask batch matches the requested sample count and the
# `test_mask[:100]` slice applied to `x_samples` in the masking cell below
# (the full test-set mask used to be passed here).
with torch.no_grad():
    x_samples = model.sample(100, mask=torch.tensor(test_mask[:100])).cpu().numpy()

In [None]:
print(np.repeat(test_mask[:100], 3, axis=-1).shape)

In [None]:
# Zero out generated particles outside the first 100 test masks; the mask is
# repeated 3 times along its last axis before the element-wise product.
masked_samples = x_samples * np.repeat(test_mask[:100], 3, axis=-1)

In [None]:
# For the first jet: non-zero mask entries, then non-zero first-feature entries
# of the raw and of the masked samples.
print(np.count_nonzero(test_mask[0]))
print(np.count_nonzero(x_samples[0, :, 0]))
print(np.count_nonzero(masked_samples[0, :, 0]))

In [None]:
from src.data.components.utils import jet_masses

# Jet masses of the generated samples via the project helper (torch input).
# NOTE(review): a later cell defines its own module-level `jet_masses`, which
# shadows this import if cells are executed out of order.
print(f"x_samples shape: {x_samples.shape}")
mass = jet_masses(torch.tensor(x_samples))
print(f"mass shape: {mass.shape}")
print(mass)

In [None]:
# Plot the masked generated jets (blue) and the same number of simulated test
# jets (red) with the project helper.
fig = plot_single_jets(
    masked_samples, save_folder="./logs/nb_plots/", save_name="gen_jets", color="#0271BB"
)
# Both calls used to pass save_name="gen_jets". Assuming the helper writes to
# save_folder/save_name (confirm against src.utils.plotting), the simulated
# figure would overwrite the generated one, so it gets its own name.
fig = plot_single_jets(
    test_data[: masked_samples.shape[0]],
    save_folder="./logs/nb_plots/",
    save_name="sim_jets",
    color="r",
)
plt.show()

In [None]:
import energyflow as ef
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec


def jet_masses(jets_ary):
    """Return the mass of each jet from its constituents.

    The constituents are converted with ``ef.p4s_from_ptyphims``, summed over
    axis 1 and passed to ``ef.ms_from_p4s``.

    Assumes ``jets_ary`` is laid out as (jets, particles, features) with the
    features in the order energyflow expects - TODO confirm.
    """
    total_p4 = ef.p4s_from_ptyphims(jets_ary).sum(axis=1)
    return ef.ms_from_p4s(total_p4)


def jet_ys(jets_ary):
    """Return the ``ef.ys_from_p4s`` value of each jet from its constituents.

    The constituents are converted with ``ef.p4s_from_ptyphims`` and summed
    over axis 1 before the conversion.

    Assumes ``jets_ary`` is laid out as (jets, particles, features) with the
    features in the order energyflow expects - TODO confirm.
    """
    total_p4 = ef.p4s_from_ptyphims(jets_ary).sum(axis=1)
    return ef.ys_from_p4s(total_p4)


def jet_etas(jets_ary):
    """Return the ``ef.etas_from_p4s`` value of each jet from its constituents.

    The constituents are converted with ``ef.p4s_from_ptyphims`` and summed
    over axis 1 before the conversion.

    Assumes ``jets_ary`` is laid out as (jets, particles, features) with the
    features in the order energyflow expects - TODO confirm.
    """
    total_p4 = ef.p4s_from_ptyphims(jets_ary).sum(axis=1)
    return ef.etas_from_p4s(total_p4)


def jet_phis(jets_ary):
    """Return the ``ef.phis_from_p4s`` value of each jet from its constituents.

    The constituents are converted with ``ef.p4s_from_ptyphims`` and summed
    over axis 1; the conversion is called with ``phi_ref=0``.

    Assumes ``jets_ary`` is laid out as (jets, particles, features) with the
    features in the order energyflow expects - TODO confirm.
    """
    total_p4 = ef.p4s_from_ptyphims(jets_ary).sum(axis=1)
    return ef.phis_from_p4s(total_p4, phi_ref=0)

In [None]:
print(x_samples.shape)
# Two equivalent ways of flattening (jets, particles, features) into
# (jets * particles, features). The bare `np.concatenate()` call that used to
# sit here raised a TypeError (missing argument) and stopped "Run All".
print(np.concatenate(x_samples).shape)
print(np.reshape(x_samples, (-1, x_samples.shape[-1])).shape)

In [None]:
# Compare generated ("Gen") and simulated ("Sim") distributions in one 1x4
# figure: three per-particle feature panels followed by a jet-mass panel.
x = torch.tensor(test_data)
fig = plt.figure(figsize=(20, 4))
gs = GridSpec(1, 4)

# Flatten (jets, particles, features) -> (jets * particles, features) once
# instead of re-concatenating for every panel.
gen_particles = np.concatenate(x_samples)
sim_particles = np.concatenate(np.array(x))
print(gen_particles.shape)

# Keyword arguments shared by every histogram in the figure.
hist_style = dict(histtype="step", density=True, lw=2, alpha=0.7)

# One panel per particle feature: (feature index, bin edges, x-axis label).
feature_panels = [
    (0, np.linspace(-0.5, 0.5, 50), r"$\eta^\mathrm{rel}$"),
    (1, np.linspace(-0.5, 0.5, 50), r"$\phi^\mathrm{rel}$"),
    (2, np.linspace(-0.1, 0.5, 100), r"$p_\mathrm{T}^\mathrm{rel}$"),
]

for i_feat, bins, xlabel in feature_panels:
    ax = fig.add_subplot(gs[i_feat])

    ax.hist(gen_particles[:, i_feat], bins=bins, ls="--", label="Gen", **hist_style)

    # Exact zeros are dropped from the simulated values only (presumably
    # zero-padded particles - confirm); generated values are plotted as-is.
    sim_values = sim_particles[:, i_feat]
    sim_values = sim_values[sim_values != 0.0]
    ax.hist(sim_values, bins=bins, label="Sim", **hist_style)

    ax.set_xlabel(xlabel)
    ax.get_yaxis().set_ticklabels([])
    ax.set_yscale("log")
    ax.legend()

# Jet-mass panel. The features are reordered from (0, 1, 2) to (2, 0, 1)
# before being handed to jet_masses, which calls ef.p4s_from_ptyphims.
ax = fig.add_subplot(gs[3])
bins = np.linspace(0.0, 0.3, 100)

for particles, label, extra_style in (
    (x_samples, "Gen", {"ls": "--"}),
    (x.numpy(), "Sim", {}),
):
    jet_mass = jet_masses(particles[:, :, [2, 0, 1]])
    ax.hist(jet_mass, bins=bins, label=label, **hist_style, **extra_style)

ax.set_xlabel(r"Jet mass")
ax.set_yscale("log")
ax.legend()


plt.tight_layout()