In [None]:
import os
import sys

# Make the repository root importable so project modules (e.g. `src.utils.plotting`)
# resolve; assumes the notebook is run from a direct subdirectory of the repo — TODO confirm.
sys.path.append("../")

%matplotlib inline
# Automatically reload edited modules before executing each cell, so changes under
# the project source tree are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import hydra
import numpy as np
import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf

In [None]:
# Load variables from a local .env file into os.environ. The Hydra configs composed
# below presumably read DATA_DIR from the environment (TODO confirm in ../configs/).
from dotenv import load_dotenv

load_dotenv()

# Fail early with an actionable message. Re-assigning os.environ.get("DATA_DIR")
# directly raised an opaque "TypeError: str expected, not NoneType" when the variable
# was missing. When it is set, the re-assignment is a harmless no-op kept for parity
# with the original cell.
data_dir = os.environ.get("DATA_DIR")
if data_dir is None:
    raise RuntimeError(
        "DATA_DIR is not set; define it in your shell or in a .env file before running this notebook."
    )
os.environ["DATA_DIR"] = data_dir

In [None]:
# Name of the config selected from Hydra's `experiment` group via the `experiment=<name>` override.
experiment: str = "fm_tops.yaml"

In [None]:
# Compose the `train.yaml` root config with the selected experiment override, then
# echo the composed config so the run settings are visible in the notebook output.
experiment_overrides = [f"experiment={experiment}"]
with hydra.initialize(config_path="../configs/", version_base=None):
    cfg = hydra.compose(overrides=experiment_overrides, config_name="train.yaml")
print(OmegaConf.to_yaml(cfg))

In [None]:
# Seed Python/NumPy/torch RNGs *before* instantiating anything, so model weight
# initialisation (and any data shuffling) is reproducible across kernel restarts.
# Seeding happens only when the composed config provides a non-null `seed`
# (TODO confirm train.yaml defines it); otherwise behaviour is unchanged.
seed = cfg.get("seed")
if seed is not None:
    pl.seed_everything(seed, workers=True)

# Build the datamodule and model from their Hydra `_target_` config nodes.
datamodule = hydra.utils.instantiate(cfg.data)
model = hydra.utils.instantiate(cfg.model)

In [None]:
# Run name reused for the checkpoint, CSV, Comet and W&B log directories/experiment names.
model_name_for_saving: str = "nb_fm_tops30"

In [None]:
# Prepare the datamodule. The next cells read `tensor_test`, `mask_test`, `means` and
# `stds` from it, which are presumably populated by setup() — run this cell first.
datamodule.setup()

In [None]:
# Copy the test split, its mask and the per-feature means/stds (presumably the
# normalisation statistics) off the datamodule as NumPy arrays for the evaluation cells.
test_data, test_mask, means, stds = (
    np.array(getattr(datamodule, attr_name))
    for attr_name in ("tensor_test", "mask_test", "means", "stds")
)

In [None]:
# Quick sanity check: shapes of the test data and mask, then the means/stds values.
for item in (test_data.shape, test_mask.shape, means, stds):
    print(item)

In [None]:
# from src.callbacks.jetnet_eval import JetNetEvaluationCallback

In [None]:
from pytorch_lightning.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
    ModelSummary,
    RichProgressBar,
)

# Checkpointing and early stopping both track the same validation metric.
_monitored = dict(monitor="val/loss", mode="min")

# Keep the best (lowest val/loss) checkpoint plus the last one; model weights only.
checkpoint_callback = ModelCheckpoint(
    dirpath=f"./logs/{model_name_for_saving}/checkpoints",
    save_top_k=1,
    save_last=True,
    save_weights_only=True,
    **_monitored,
)
# Stop after 10 epochs without at least a 1e-4 improvement in val/loss.
early_stopping = EarlyStopping(patience=10, min_delta=0.0001, verbose=True, **_monitored)
lr_monitor = LearningRateMonitor(logging_interval="epoch")
model_summary = ModelSummary()
rich_progress_bar = RichProgressBar()

# jetnet_eval_callback = JetNetEvaluationCallback(every_n_epochs=3,test_particle_data=test_data,test_mask=test_mask,means=means,stds=stds,num_jet_samples=3000,w_dists_batches=2,selected_particles=[1, 2, 5],normalised_data=False)
# jetnet_eval_callback = JetNetEvaluationCallback(every_n_epochs=3,datamodule=datamodule, num_jet_samples=3000)

In [None]:
from pytorch_lightning.loggers import CometLogger, CSVLogger, WandbLogger

# Every logger writes below one per-run directory.
run_log_dir = f"./logs/{model_name_for_saving}"

csv_logger = CSVLogger(f"{run_log_dir}/csv_logs")

# Comet credentials/identifiers come from the environment (e.g. the .env file loaded earlier).
# NOTE(review): comet_logger is constructed with offline=False, but the Trainer cell only
# receives csv_logger and wandb_logger — confirm whether Comet logging is still wanted.
comet_logger = CometLogger(
    api_key=os.environ.get("COMET_API_KEY"),
    rest_api_key=os.environ.get("COMET_REST_API_KEY"),  # Optional
    workspace=os.environ.get("COMET_WORKSPACE"),  # Optional
    experiment_key=os.environ.get("COMET_EXPERIMENT_KEY"),  # Optional
    project_name="Flow Matching",  # Optional
    experiment_name=model_name_for_saving,  # Optional
    save_dir=f"{run_log_dir}/comet_logs",  # Optional
    offline=False,
)

wandb_logger = WandbLogger(project="Flow Matching", name=model_name_for_saving, save_dir=run_log_dir)

In [None]:
# Switch the module to evaluation mode (equivalent to model.eval()).
# NOTE(review): Lightning's fit loop puts the model back into training mode, so this
# presumably has no effect on the training run below — confirm it is still needed.
model.train(False)

In [None]:
# NOTE(review): early_stopping, rich_progress_bar and comet_logger are created in earlier
# cells but are not passed to the Trainer here — confirm that is intentional.
trainer_callbacks = [checkpoint_callback, lr_monitor, model_summary]
trainer_loggers = [csv_logger, wandb_logger]

trainer = pl.Trainer(
    accelerator="gpu",
    max_epochs=30,
    logger=trainer_loggers,
    callbacks=trainer_callbacks,
)
# Allow float32 matmuls to use reduced internal precision ("medium") for speed on
# supporting GPUs.
torch.set_float32_matmul_precision("medium")
trainer.fit(model=model, datamodule=datamodule)

In [None]:
# ckpt = "/beegfs/desy/user/ewencedr/deep-learning/logs/epic2000/runs/2023-03-30_00-03-29/checkpoints/epoch_1127_loss_0.126.ckpt"
# model = model.load_from_checkpoint(ckpt)

# Evaluation

In [None]:
from matplotlib import pyplot as plt

from src.utils.plotting import apply_mpl_styles, create_and_plot_data, plot_single_jets

# Apply the project's matplotlib style settings to every figure drawn below.
apply_mpl_styles()

## Histograms

In [None]:
# Presumably samples jets from each model in the list, compares them against the test data
# and saves the comparison plots. `data` appears to hold one generated array per model
# (`data[0]` is plotted below).
fig, data, generation_times = create_and_plot_data(
    np.array(test_data),  # a copy, so test_data is safe if the helper modifies its input
    [model],
    "fm_tops_nb",
    labels=["FM"],
    # inputs describing the reference data
    mask=test_mask,
    means=means,
    stds=stds,
    # NOTE(review): two flags are passed for a single model — confirm the expected length.
    normalised_data=[False, False],
    # sampling and output location
    num_jet_samples=10000,
    save_folder="./logs/nb_plots/",
    plottype="",
    # which plots to produce
    plot_jet_features=False,
    plot_w_dists=True,
    plot_selected_multiplicities=False,
    selected_multiplicities=[1, 3, 5, 10, 20, 30],
)

## Simulated Data

In [None]:
# Plot individual jets from the simulated (test) data and save them to the plot folder.
fig = plot_single_jets(
    test_data,
    save_folder="./logs/nb_plots/",
)
plt.show()

## Generated Data

In [None]:
# Same single-jet plot for the first model's generated jets, saved as "gen_jets" in blue.
generated_jets = data[0]
fig = plot_single_jets(generated_jets, save_folder="./logs/nb_plots/", save_name="gen_jets", color="#0271BB")
plt.show()

In [None]:
# Log-likelihood of the first 5 test samples under the model's first flow.
# Switch to eval mode first: after trainer.fit the module may still be in training
# mode, which would make layers such as dropout/batch-norm behave stochastically and
# give non-reproducible log-probabilities.
model.cuda()
model.eval()
with torch.no_grad():
    # from_numpy avoids an extra host-side copy; .float() and .cuda() produce new tensors,
    # so test_data itself is never modified.
    test_batch = torch.from_numpy(test_data[:5]).float().cuda()
    log_p = model.flows[0].log_prob(test_batch)
print(log_p)

In [None]:
# Leave the module in evaluation mode (equivalent to model.eval()); the module is echoed as cell output.
model.train(False)