In [None]:
import os
import sys

# Add the parent of the current working directory to the import path.
# NOTE(review): presumably this is what makes the project package (imported
# further down as `particle_fm`) importable when the notebook is run from its
# own folder — confirm against the repository layout.
sys.path.append("../")

# IPython magics: render matplotlib figures inline, and load the autoreload
# extension in mode 2 so edited modules are re-imported before each cell runs.
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
import hydra
import numpy as np
import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf

In [None]:
# set env variable DATA_DIR again because of hydra
# NOTE(review): presumably the Hydra configs read DATA_DIR from the process
# environment — confirm against the config files.
from dotenv import load_dotenv

# Pull variables from a .env file (if one is found) into os.environ.
load_dotenv()

# Assigning None into os.environ raises an opaque
# "TypeError: str expected, not NoneType", so check for the variable first and
# fail with a message that says what is actually missing.
if "DATA_DIR" not in os.environ:
    raise RuntimeError(
        "DATA_DIR is not set: define it in the environment or in a .env file "
        "before composing the Hydra config."
    )

# Explicitly re-export the value, as the original cell did.
os.environ["DATA_DIR"] = os.environ["DATA_DIR"]

In [None]:
# Name of the experiment config to run. It is handed to Hydra as the
# `experiment=<name>` override when the training config is composed.
experiment = "experiment.yaml"

In [None]:
# Compose the training config with the selected experiment applied as an
# override. Hydra is only initialised for the duration of the `with` block.
with hydra.initialize(version_base=None, config_path="../configs/"):
    cfg = hydra.compose(
        config_name="train.yaml",
        overrides=[f"experiment={experiment}"],
    )
    # Uncomment to inspect the fully composed config:
    # print(OmegaConf.to_yaml(cfg))

In [None]:
# Build the datamodule and the model from the `data` and `model` sections of
# the composed config.
datamodule = hydra.utils.instantiate(cfg.data)
model = hydra.utils.instantiate(cfg.model)

In [None]:
# Run identifier: used as the sub-directory of ./logs/ for checkpoints and
# logger output, and as the run name passed to the Comet and W&B loggers.
model_name_for_saving = "nb_fm_tops30"

In [None]:
# Run the datamodule's setup step by hand.
# NOTE(review): presumably this is what populates the tensor_*/mask_*/means/stds
# attributes read in the following cell — confirm in the datamodule class.
datamodule.setup()

In [None]:
# Convert each split's data / mask / conditioning attributes of the datamodule
# into NumPy arrays, one split per statement.
test_data, test_mask, test_cond = map(
    np.array,
    (datamodule.tensor_test, datamodule.mask_test, datamodule.tensor_conditioning_test),
)
val_data, val_mask, val_cond = map(
    np.array,
    (datamodule.tensor_val, datamodule.mask_val, datamodule.tensor_conditioning_val),
)
train_data, train_mask, train_cond = map(
    np.array,
    (datamodule.tensor_train, datamodule.mask_train, datamodule.tensor_conditioning_train),
)

# The datamodule's `means` and `stds` attributes, also as arrays.
means, stds = map(np.array, (datamodule.means, datamodule.stds))

In [None]:
# One shape per line, in test / val / train order (data, mask, conditioning).
print(
    *(
        array.shape
        for array in (
            test_data,
            test_mask,
            test_cond,
            val_data,
            val_mask,
            val_cond,
            train_data,
            train_mask,
            train_cond,
        )
    ),
    sep="\n",
)

# Then the `means` and `stds` arrays themselves.
print(means, stds, sep="\n")

In [None]:
# Sum the test mask along its second-to-last axis, then tally how often each
# distinct sum occurs.
unique, counts = np.unique(
    test_mask.sum(axis=-2),
    return_counts=True,
)
# Uncomment to show the (value, count) pairs:
# print(np.asarray((unique, counts)).T)

In [None]:
from particle_fm.callbacks.ema import EMA
from particle_fm.callbacks.jetnet_eval import JetNetEvaluationCallback

In [None]:
from pytorch_lightning.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
    ModelSummary,
    RichProgressBar,
)

# Checkpointing: track "val/loss" (lower is better), keep the single best
# checkpoint plus the most recent one, and store weights only.
checkpoint_callback = ModelCheckpoint(
    dirpath=f"./logs/{model_name_for_saving}/checkpoints",
    monitor="val/loss",
    mode="min",
    save_top_k=1,
    save_last=True,
    save_weights_only=True,
)

# Early stopping on the same metric.
# NOTE(review): this object is created here but is not in the callback list
# given to the Trainer at the end of the notebook — confirm that is intended.
early_stopping = EarlyStopping(
    monitor="val/loss",
    mode="min",
    min_delta=0.0001,
    patience=10,
    verbose=True,
)

# Learning-rate logging once per epoch, and a model summary.
lr_monitor = LearningRateMonitor(logging_interval="epoch")
model_summary = ModelSummary()

# NOTE(review): like `early_stopping`, this is not passed to the Trainer below.
rich_progress_bar = RichProgressBar()

# Disabled JetNet evaluation callback, kept for reference:
# jetnet_eval_callback = JetNetEvaluationCallback(
#    every_n_epochs=3,
#    num_jet_samples=10000,
#    logger=2,
#    log_w_dists=False,
#    image_path="/beegfs/desy/user/ewencedr/deep-learning/logs/comet_logs",
# )

# Project EMA callback, configured with decay 0.9999.
ema = EMA(decay=0.9999)

In [None]:
from pytorch_lightning.loggers import CometLogger, CSVLogger, WandbLogger

# All loggers write beneath ./logs/<model_name_for_saving>/.
csv_logger = CSVLogger(f"./logs/{model_name_for_saving}/csv_logs")

# Comet credentials and identifiers are read from the environment; a variable
# that is not set is passed through as None.
# NOTE(review): this logger is not in the logger list given to the Trainer at
# the end of the notebook — confirm that is intended.
comet_logger = CometLogger(
    api_key=os.environ.get("COMET_API_KEY"),
    rest_api_key=os.environ.get("COMET_REST_API_KEY"),
    workspace=os.environ.get("COMET_WORKSPACE"),
    experiment_key=os.environ.get("COMET_EXPERIMENT_KEY"),
    experiment_name=model_name_for_saving,
    project_name="Flow Matching",
    save_dir=f"./logs/{model_name_for_saving}/comet_logs",
    offline=False,
)

# Weights & Biases run, named after the model.
wandb_logger = WandbLogger(
    name=model_name_for_saving,
    project="Flow Matching",
    save_dir=f"./logs/{model_name_for_saving}",
)

In [None]:
# GPU trainer capped at 10 epochs, logging to the CSV and W&B loggers and
# running the checkpoint, LR-monitor, model-summary and EMA callbacks.
trainer = pl.Trainer(
    accelerator="gpu",
    max_epochs=10,
    logger=[csv_logger, wandb_logger],
    callbacks=[checkpoint_callback, lr_monitor, model_summary, ema],
)

# Set PyTorch's float32 matmul precision to "medium" before training starts.
torch.set_float32_matmul_precision("medium")

trainer.fit(model=model, datamodule=datamodule)