In [None]:
import os
import sys

sys.path.append("../")

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
import hydra
import numpy as np
import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf

In [None]:
from src.models.zuko.utils import odeint

In [None]:
# Toy linear system used to try out the ODE integrator: dx/dt = t * x @ A.
A = torch.randn(3, 3)


def f(t, x, y):
    """Right-hand side of the toy ODE.

    Args:
        t: Time value, multiplied into ``x``.
        x: State tensor; its last dimension must match the first dimension of ``A``.
        y: Accepted but not used.

    Returns:
        ``(t * x) @ A`` using the notebook-level matrix ``A``.
    """
    return torch.matmul(t * x, A)


# Column-vector state of shape (3, 1) and an extra random vector `cond`.
x = torch.randn(3)[:, None]
print(A, x, sep="\n")
cond = torch.randn(3)
print(cond)

In [None]:
# Record the shape and element count of every row of `x`; `unpack` reads these
# two globals to split a flat tensor back into pieces.
shapes = [row.shape for row in x]
sizes = [row.numel() for row in x]
print(shapes, sizes, sep="\n")

In [None]:
def pack(x):
    """Flatten every tensor obtained by iterating ``x`` and join them into one 1-D tensor."""
    flat_parts = []
    for part in x:
        flat_parts.append(part.flatten())
    return torch.cat(flat_parts)


def unpack(x):
    """Split the flat tensor ``x`` into chunks and restore each chunk's shape.

    Chunk lengths come from the notebook-level ``sizes`` list and target shapes
    from the notebook-level ``shapes`` list; the two are paired in order.
    """
    chunks = x.split(sizes)
    restored = []
    for chunk, shape in zip(chunks, shapes):
        restored.append(chunk.reshape(shape))
    return restored


def g(t, x):
    """Evaluate ``f`` on the unpacked pieces of the flat state ``x`` and re-pack the result.

    NOTE(review): ``f`` is declared as ``f(t, x, y)``, so ``unpack(x)`` has to
    yield exactly two tensors for this call to succeed -- confirm against the
    ``shapes``/``sizes`` in use.
    """
    parts = unpack(x)
    derivative = f(t, *parts)
    return pack(derivative)

In [None]:
# Rebind `x` to its flattened form, then show the flat tensor and its unpacked
# pieces to check the round trip.
x = pack(x)
print(x, unpack(x), sep="\n")

In [None]:
# Integrate the toy system with the project's `odeint` helper and print the result.
# NOTE(review): odeint's signature is not visible in this notebook; the call looks
# like (function, initial state, None, 0.0, 1.0) -- confirm what the `None` slot
# means and how the third parameter `y` of `f` is supplied.
x1 = odeint(f, x, None, 0.0, 1.0)
print(x1)

In [None]:
# Load variables from a .env file and re-export DATA_DIR into the process
# environment again because of hydra (the config is composed further down).
from dotenv import load_dotenv

load_dotenv()
# Fail with an actionable message instead of the opaque
# "TypeError: str expected, not NoneType" that assigning None to os.environ raises.
if "DATA_DIR" not in os.environ:
    raise RuntimeError(
        "DATA_DIR is not set: export it in the shell or define it in a .env file "
        "before running this notebook."
    )
os.environ["DATA_DIR"] = os.environ["DATA_DIR"]

In [None]:
# Experiment config name; it is passed to hydra as the `experiment=` override when the config is composed.
experiment = "fm_tops.yaml"

In [None]:
# Compose the training config (train.yaml) with the chosen experiment override
# applied, and dump the resulting config as YAML for inspection.
with hydra.initialize(version_base=None, config_path="../configs/"):
    cfg = hydra.compose(
        config_name="train.yaml",
        overrides=["experiment=" + experiment],
    )
    print(OmegaConf.to_yaml(cfg))

In [None]:
# Build the datamodule and the model from the `data` and `model` sections of the composed config.
datamodule = hydra.utils.instantiate(cfg.data)
model = hydra.utils.instantiate(cfg.model)

In [None]:
# Run name; it is used to build the ./logs/<name>/... directories for checkpoints and loggers and as the logger run name.
model_name_for_saving = "nb_fm_tops30"

In [None]:
# Run the datamodule's setup step.
# Presumably this populates the tensor_test / mask_test / means / stds attributes read afterwards -- confirm.
datamodule.setup()

In [None]:
# Copy the test tensors, test mask and normalisation statistics out of the
# datamodule as NumPy arrays.
test_data, test_mask, means, stds = (
    np.array(getattr(datamodule, attr_name))
    for attr_name in ("tensor_test", "mask_test", "means", "stds")
)

In [None]:
# Quick sanity check: array shapes followed by the normalisation statistics.
print(test_data.shape, test_mask.shape, means, stds, sep="\n")

In [None]:
from src.callbacks.jetnet_eval import JetNetEvaluationCallback

In [None]:
from pytorch_lightning.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
    ModelSummary,
    RichProgressBar,
)

# Checkpointing: keep the single best checkpoint by "val/loss" plus the latest
# one, storing weights only, under the run's log directory.
checkpoint_callback = ModelCheckpoint(
    monitor="val/loss",
    mode="min",
    save_top_k=1,
    save_last=True,
    save_weights_only=True,
    dirpath=f"./logs/{model_name_for_saving}/checkpoints",
)
# NOTE(review): early_stopping, rich_progress_bar and jetnet_eval_callback are
# constructed here but are not in the callbacks list handed to pl.Trainer in this
# notebook -- confirm whether they should be added or removed.
early_stopping = EarlyStopping(
    monitor="val/loss", mode="min", patience=10, verbose=True, min_delta=0.0001
)
# Log the learning rate once per epoch.
lr_monitor = LearningRateMonitor(logging_interval="epoch")
model_summary = ModelSummary()
rich_progress_bar = RichProgressBar()

# NOTE(review): image_path is an absolute, user-specific cluster path, so it will
# not exist on other machines; the meaning of `logger=2` is not visible from this
# notebook -- confirm both against JetNetEvaluationCallback.
jetnet_eval_callback = JetNetEvaluationCallback(
    every_n_epochs=3,
    num_jet_samples=10000,
    logger=2,
    log_w_dists=False,
    image_path="/beegfs/desy/user/ewencedr/deep-learning/logs/comet_logs",
)

In [None]:
from pytorch_lightning.loggers import CometLogger, CSVLogger, WandbLogger

# All loggers write below ./logs/<model_name_for_saving>/.
csv_logger = CSVLogger(f"./logs/{model_name_for_saving}/csv_logs")
# Comet credentials and workspace are read from environment variables; any that
# are unset are passed as None.
# NOTE(review): comet_logger is created with offline=False but is not in the
# logger list handed to pl.Trainer in this notebook -- confirm it is still needed.
comet_logger = CometLogger(
    api_key=os.environ.get("COMET_API_KEY"),
    workspace=os.environ.get("COMET_WORKSPACE"),  # Optional
    save_dir=f"./logs/{model_name_for_saving}/comet_logs",  # Optional
    project_name="Flow Matching",  # Optional
    rest_api_key=os.environ.get("COMET_REST_API_KEY"),  # Optional
    experiment_key=os.environ.get("COMET_EXPERIMENT_KEY"),  # Optional
    experiment_name=model_name_for_saving,  # Optional
    offline=False,
)
wandb_logger = WandbLogger(
    project="Flow Matching", name=model_name_for_saving, save_dir=f"./logs/{model_name_for_saving}"
)

In [None]:
# Switch the model to evaluation mode.
# NOTE(review): this runs immediately before trainer.fit(...); fit presumably manages train/eval mode itself, so confirm this call is intentional.
model.eval()

In [None]:
# Train for at most 10 epochs on a GPU with checkpointing, LR monitoring and a
# model summary, logging to the CSV and Weights & Biases loggers.
trainer = pl.Trainer(
    max_epochs=10,
    callbacks=[checkpoint_callback, lr_monitor, model_summary],
    logger=[csv_logger, wandb_logger],
    accelerator="gpu",
)
# Set PyTorch's float32 matmul precision to "medium" before fitting starts.
torch.set_float32_matmul_precision("medium")
trainer.fit(
    model=model,
    datamodule=datamodule,
)

In [None]:
# Restore trained weights from a checkpoint file written by an earlier run.
# ckpt = "/beegfs/desy/user/ewencedr/deep-learning/logs/transformer/runs/2023-03-30_13-07-29/checkpoints/epoch_1861_loss_0.126.ckpt"
ckpt = "./logs/nb_fm_tops30/checkpoints/last-v17.ckpt"  # mass conditioning
# load_from_checkpoint is a classmethod that builds and returns a new module.
# Call it on the class rather than on the instance: PyTorch Lightning 2.x rejects
# instance calls, and class-level calls work on older releases as well.
model = type(model).load_from_checkpoint(ckpt)

In [None]:
from src.data.components.utils import jet_masses

# Compute jet masses for the test set with the project helper and append a
# trailing axis; `masses` is later passed to model.sample(...) and as `cond` to
# create_and_plot_data(...).
# Presumably one mass per jet -- confirm against jet_masses.
print(test_data.shape)
masses = jet_masses(torch.tensor(test_data)).unsqueeze(-1)
print(masses.shape)
print(masses[:100].shape)

In [None]:
# Put the model in eval mode on the GPU and draw 100 samples without tracking
# gradients, passing the first 100 masses as the second argument (presumably the
# conditioning input -- confirm against model.sample). The result is moved to
# the CPU and converted to a NumPy array.
model.eval().cuda()
with torch.no_grad():
    x_samples = model.sample(100, masses[:100]).cpu().numpy()

In [None]:
# Shape of the generated sample array.
print(x_samples.shape)

# Evaluation

In [None]:
import matplotlib.pyplot as plt

from src.utils.plotting import apply_mpl_styles, create_and_plot_data, plot_single_jets

# Apply the plotting style helper imported from src.utils.plotting before any figures are drawn.
apply_mpl_styles()

## Histograms

In [None]:
# Generate samples from the model and plot them against the simulated test data
# using the project helper; outputs are written under ./logs/nb_plots/.
# NOTE(review): test_data is already a NumPy array, so the np.array(...) wrapper
# looks redundant; normalised_data has two entries while a single model is
# passed -- confirm against create_and_plot_data.
fig, data, generation_times = create_and_plot_data(
    np.array(test_data),
    [model],
    cond=masses,
    save_name="fm_tops_nb",
    labels=["FM"],
    mask=test_mask,
    num_jet_samples=2000,
    normalised_data=[False, False],
    means=means,
    stds=stds,
    save_folder="./logs/nb_plots/",
    plottype="sim_data",
    plot_jet_features=False,
    plot_w_dists=True,
    plot_selected_multiplicities=False,
    selected_multiplicities=[1, 3, 5, 10, 20, 30],
)

## Simulated Data

In [None]:
# Plot individual jets from the simulated test set; output goes to ./logs/nb_plots/.
fig = plot_single_jets(test_data, save_folder="./logs/nb_plots/")
plt.show()

## Generated Data

In [None]:
# Plot individual jets from the first entry of `data` returned by
# create_and_plot_data (presumably the samples of the single model passed in --
# confirm), saved under the name "gen_jets".
fig = plot_single_jets(
    data[0], save_folder="./logs/nb_plots/", save_name="gen_jets", color="#0271BB"
)
plt.show()

In [None]:
# Evaluate log_prob of the model's first flow on the first five test jets (as
# float32 on the GPU), without tracking gradients.
# NOTE(review): assumes the model exposes a `flows` sequence whose entries have
# a log_prob method -- confirm for the flow-matching model loaded from the checkpoint.
model.cuda()
with torch.no_grad():
    log_p = model.flows[0].log_prob(torch.tensor(test_data[:5]).float().cuda())
print(log_p)

In [None]:
# Put the model in evaluation mode (as the last expression this also displays the model).
model.eval()

In [None]:
# Draw 100 samples on the GPU without tracking gradients and convert them to NumPy.
# NOTE(review): this overwrites the `x_samples` produced by the earlier
# model.sample(100, masses[:100]) call, and here no second argument is passed --
# confirm that sampling without the masses is intended.
model.eval().cuda()

with torch.no_grad():
    x_samples = model.sample(100).cpu().numpy()

In [None]:
# NOTE(review): jet_masses is already imported from the same module in an earlier cell; this import is a duplicate.
from src.data.components.utils import jet_masses

# Compute the masses of the generated samples with the project helper and print them.
print(f"x_samples shape: {x_samples.shape}")
mass = jet_masses(torch.tensor(x_samples))
print(f"mass shape: {mass.shape}")
print(mass)

In [None]:
# Plot individual jets from the freshly generated `x_samples`.
# NOTE(review): save_folder and save_name are identical to the earlier
# plot_single_jets(data[0], ...) call, so this presumably overwrites that file -- confirm.
fig = plot_single_jets(
    x_samples, save_folder="./logs/nb_plots/", save_name="gen_jets", color="#0271BB"
)
plt.show()

In [None]:
import energyflow as ef
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec


def jet_masses(jets_ary):
    """Return the mass of each jet, computed with energyflow.

    The constituents in ``jets_ary`` are converted to four-vectors with
    ``ef.p4s_from_ptyphims``, summed over axis 1, and the mass of each summed
    four-vector is returned.

    NOTE(review): this definition shadows the ``jet_masses`` imported from
    ``src.data.components.utils`` earlier in the notebook.
    """
    constituent_p4s = ef.p4s_from_ptyphims(jets_ary)
    return ef.ms_from_p4s(constituent_p4s.sum(axis=1))


def jet_ys(jets_ary):
    """Return ``ef.ys_from_p4s`` of each jet's summed four-vector.

    Constituents are converted with ``ef.p4s_from_ptyphims`` and summed over
    axis 1 before the conversion.
    """
    constituent_p4s = ef.p4s_from_ptyphims(jets_ary)
    return ef.ys_from_p4s(constituent_p4s.sum(axis=1))


def jet_etas(jets_ary):
    """Return ``ef.etas_from_p4s`` of each jet's summed four-vector.

    Constituents are converted with ``ef.p4s_from_ptyphims`` and summed over
    axis 1 before the conversion.
    """
    constituent_p4s = ef.p4s_from_ptyphims(jets_ary)
    return ef.etas_from_p4s(constituent_p4s.sum(axis=1))


def jet_phis(jets_ary):
    """Return ``ef.phis_from_p4s`` of each jet's summed four-vector, with ``phi_ref=0``.

    Constituents are converted with ``ef.p4s_from_ptyphims`` and summed over
    axis 1 before the conversion.
    """
    constituent_p4s = ef.p4s_from_ptyphims(jets_ary)
    return ef.phis_from_p4s(constituent_p4s.sum(axis=1), phi_ref=0)

In [None]:
def _plot_feature_hist(ax, gen_flat, sim_flat, i_feat, bins, xlabel):
    """Overlay generated and simulated histograms of one constituent feature.

    Args:
        ax: Matplotlib axes to draw on.
        gen_flat: 2-D array of generated constituents (rows) by features (columns).
        sim_flat: 2-D array of simulated constituents with the same column layout.
        i_feat: Column index of the feature to histogram.
        bins: Bin edges shared by both histograms.
        xlabel: Label for the x-axis.
    """
    ax.hist(
        gen_flat[:, i_feat],
        histtype="step",
        bins=bins,
        density=True,
        lw=2,
        ls="--",
        alpha=0.7,
        label="Gen",
    )
    # Only the simulated sample is filtered: entries exactly equal to zero are
    # dropped (presumably zero-padded constituents -- confirm).
    sim_feat = sim_flat[:, i_feat]
    sim_feat = sim_feat[sim_feat != 0.0]
    ax.hist(sim_feat, histtype="step", density=True, bins=bins, lw=2, alpha=0.7, label="Sim")

    ax.set_xlabel(xlabel)
    ax.get_yaxis().set_ticklabels([])
    ax.set_yscale("log")
    ax.legend()


# Four-panel comparison of generated vs. simulated data: three constituent
# feature histograms followed by the jet-mass histogram.
x = torch.tensor(test_data)
fig = plt.figure(figsize=(20, 4))
gs = GridSpec(1, 4)

# Merge the first two axes once so each panel can slice a feature column.
gen_flat = np.concatenate(x_samples)
sim_flat = np.concatenate(x.numpy())
print(gen_flat.shape)

# (feature column, bin edges, axis label) for the three per-constituent panels.
feature_panels = [
    (0, np.linspace(-0.5, 0.5, 50), r"$\eta^\mathrm{rel}$"),
    (1, np.linspace(-0.5, 0.5, 50), r"$\phi^\mathrm{rel}$"),
    (2, np.linspace(-0.1, 0.5, 100), r"$p_\mathrm{T}^\mathrm{rel}$"),
]
for panel, (i_feat, bins, xlabel) in enumerate(feature_panels):
    ax = fig.add_subplot(gs[panel])
    _plot_feature_hist(ax, gen_flat, sim_flat, i_feat, bins, xlabel)

# Jet-mass panel. The last axis is reordered from columns (0, 1, 2) to
# (2, 0, 1) before calling jet_masses, i.e. pT first given the labels used above.
ax = fig.add_subplot(gs[3])
bins = np.linspace(0.0, 0.3, 100)

gen_jet_mass = jet_masses(x_samples[:, :, [2, 0, 1]])
ax.hist(
    gen_jet_mass, histtype="step", bins=bins, density=True, lw=2, ls="--", alpha=0.7, label="Gen"
)

sim_jet_mass = jet_masses(x.numpy()[:, :, [2, 0, 1]])
ax.hist(sim_jet_mass, histtype="step", bins=bins, density=True, lw=2, alpha=0.7, label="Sim")

ax.set_xlabel(r"Jet mass")
ax.set_yscale("log")
ax.legend()


plt.tight_layout()