In [None]:
import os
import sys

# Make the parent directory importable so the `src.*` modules used later resolve
# (assumes the notebook lives one level below the repo root — TODO confirm).
sys.path.append("../")

%matplotlib inline
# Re-import edited modules (e.g. src.utils.plotting) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import hydra
import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf

In [None]:
# Load variables from a local .env file, then write DATA_DIR back into os.environ
# explicitly (the original author notes this is needed "because of hydra";
# presumably the configs interpolate it from the environment — TODO confirm).
from dotenv import load_dotenv

load_dotenv()
data_dir = os.environ.get("DATA_DIR")
if data_dir is None:
    # Assigning None into os.environ raises an opaque TypeError; fail clearly instead.
    raise RuntimeError("DATA_DIR is not set; define it in your .env file or environment.")
os.environ["DATA_DIR"] = data_dir

In [None]:
# Experiment config applied as an `experiment=` override on train.yaml in the next cell
# (presumably configs/experiment/fm_tops.yaml — TODO confirm).
experiment = "fm_tops.yaml"

In [None]:
# Compose train.yaml from ../configs/ with the selected experiment override,
# then echo the composed config so the run's settings appear in the notebook output.
with hydra.initialize(config_path="../configs/", version_base=None):
    cfg = hydra.compose(
        config_name="train.yaml",
        overrides=[f"experiment={experiment}"],
    )
    print(OmegaConf.to_yaml(cfg))

In [None]:
# Build the datamodule and the model from the `data` and `model` nodes of the
# composed config via hydra.utils.instantiate.
datamodule = hydra.utils.instantiate(cfg.data)
model = hydra.utils.instantiate(cfg.model)

In [None]:
from pytorch_lightning.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
    ModelSummary,
    RichProgressBar,
)

# Keep the single best checkpoint (lowest val/loss) plus the most recent one;
# only model weights are stored.
checkpoint_callback = ModelCheckpoint(
    monitor="val/loss",
    mode="min",
    save_top_k=1,
    save_last=True,
    save_weights_only=True,
)

# Stop when val/loss has not improved by more than 1e-4 for 10 validation checks.
early_stopping = EarlyStopping(
    monitor="val/loss",
    mode="min",
    min_delta=0.0001,
    patience=10,
    verbose=True,
)

# Log the learning rate per step rather than per epoch.
lr_monitor = LearningRateMonitor(logging_interval="step")

# Model summary table and a rich-formatted progress bar.
model_summary = ModelSummary()
rich_progress_bar = RichProgressBar()

In [None]:
# Allow reduced-precision internal float32 matmuls (per the torch docs, "medium"
# permits bfloat16); set before the Trainer is built so it applies to the whole run.
torch.set_float32_matmul_precision("medium")

# Register every callback constructed in the previous cell. Previously only
# early_stopping and lr_monitor were passed, so checkpoint_callback (val/loss best +
# last), model_summary and rich_progress_bar were silently unused.
# accelerator="auto" still picks a GPU when available but no longer crashes on CPU-only hosts.
trainer = pl.Trainer(
    max_epochs=5,
    callbacks=[
        checkpoint_callback,
        early_stopping,
        lr_monitor,
        model_summary,
        rich_progress_bar,
    ],
    accelerator="auto",
)

In [None]:
# Train on the datamodule's loaders. cfg.get("ckpt_path") is None unless the config
# names a checkpoint, in which case training resumes from it.
trainer.fit(
    model,
    datamodule=datamodule,
    ckpt_path=cfg.get("ckpt_path"),
)

# Evaluation

In [None]:
import matplotlib.pyplot as plt
import numpy as np

from src.utils.plotting import apply_mpl_styles, create_and_plot_data, plot_single_jets

# Apply the matplotlib style settings defined in src.utils.plotting before any figure is drawn.
apply_mpl_styles()

In [None]:
# Convert the held-out split to NumPy arrays for plotting.
# NOTE(review): assumes tensor_test / mask_test exist (populated during the datamodule's
# setup, which trainer.fit triggers) and are CPU-resident — TODO confirm.
test_data, test_mask = np.array(datamodule.tensor_test), np.array(datamodule.mask_test)

## Histograms

In [None]:
# Where figures are saved: override with the PLOTS_DIR environment variable instead of
# editing the notebook (the default is the original absolute path).
plots_dir = os.environ.get("PLOTS_DIR", "/home/ewencedr/deep-learning/logs/plots/")

# Generate jets with the model(s) and plot histograms against the test set.
# np.array(test_data) passes a copy, so test_data stays intact even if the helper
# modifies its input.
# NOTE(review): the same `model` is passed twice as "FM" and "FM2", so both curves
# come from one model — presumably a placeholder for a real comparison.
fig, data, generation_times = create_and_plot_data(
    np.array(test_data),
    [model, model],
    "fm_tops_nb",
    labels=["FM", "FM2"],
    mask=test_mask,
    num_jet_samples=2000,
    normalised_data=[False, False],
    save_folder=plots_dir,
)

## Simulated Data

In [None]:
# Save location is configurable via PLOTS_DIR; the default is the original absolute path.
plots_dir = os.environ.get("PLOTS_DIR", "/home/ewencedr/deep-learning/logs/plots/")

# Plot individual jets from the test set.
fig = plot_single_jets(test_data, save_folder=plots_dir)
plt.show()

## Generated Data

In [None]:
# Save location is configurable via PLOTS_DIR; the default is the original absolute path.
plots_dir = os.environ.get("PLOTS_DIR", "/home/ewencedr/deep-learning/logs/plots/")

# Plot individual jets from the first model's generated samples (data[0]),
# saved under the name "gen_jets".
fig = plot_single_jets(data[0], save_folder=plots_dir, save_name="gen_jets")
plt.show()