In [None]:
import os
import sys

# Add the parent of the current working directory to the import path
# (assumed to be the project root, so that `from src import ...` works).
sys.path.append("../")

# Render matplotlib figures inline, then enable the autoreload extension and
# reload all edited project modules before each cell executes.
# (Magics must stay on their own lines; trailing comments would be parsed as arguments.)
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
from typing import List, Optional, Tuple

import hydra
import pyrootutils
import pytorch_lightning as pl
import torch
from omegaconf import DictConfig, OmegaConf, open_dict
from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer
from pytorch_lightning.loggers import Logger

from src import utils

In [None]:
# The Hydra config needs DATA_DIR in the environment: load it from `.env` and
# fail early with a clear message if it is still missing. `os.environ` only
# accepts strings, so re-assigning `os.environ.get(...)` would raise an opaque
# TypeError when the variable is unset (and is a no-op when it is set).
from dotenv import load_dotenv

load_dotenv()
if "DATA_DIR" not in os.environ:
    raise RuntimeError(
        "DATA_DIR is not set: define it in a .env file or in the environment "
        "before composing the Hydra config."
    )

In [None]:
# Name of the experiment config passed to Hydra as the `experiment=` override.
experiment: str = "fm_tops.yaml"

In [None]:
# Compose `train.yaml` with the selected experiment layered on top, then print
# the resulting config as YAML.
hydra_overrides = [f"experiment={experiment}"]
with hydra.initialize(version_base=None, config_path="../configs/"):
    cfg = hydra.compose(config_name="train.yaml", overrides=hydra_overrides)
    print(OmegaConf.to_yaml(cfg))

In [None]:
# Build the datamodule and the model from their respective config nodes
# (data first, then model).
datamodule, model = (
    hydra.utils.instantiate(cfg.data),
    hydra.utils.instantiate(cfg.model),
)

In [None]:
# Show the `callbacks` section of the composed config for reference
# (interpolations are left unresolved). The callbacks themselves are
# constructed manually in the next cell.
test = OmegaConf.to_yaml(cfg.callbacks, resolve=False)
print(test)

In [None]:
from pytorch_lightning.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
    ModelSummary,
    RichProgressBar,
)

# Checkpointing and early stopping both track the validation loss (lower is better).
checkpoint_callback = ModelCheckpoint(
    monitor="val/loss",
    mode="min",
    save_top_k=1,  # keep only the single best checkpoint...
    save_last=True,  # ...plus the most recent one
    save_weights_only=True,
)
early_stopping = EarlyStopping(
    monitor="val/loss",
    mode="min",
    patience=10,
    min_delta=0.0001,  # improvements smaller than this do not reset patience
    verbose=True,
)

# Log the learning rate every optimizer step; summary and progress-bar helpers.
lr_monitor = LearningRateMonitor(logging_interval="step")
model_summary, rich_progress_bar = ModelSummary(), RichProgressBar()

In [None]:
# Global setting: "medium" lets float32 matmuls use faster reduced-precision
# internals on supported GPUs. Set it before the Trainer is built.
torch.set_float32_matmul_precision("medium")

# Use the callbacks constructed in the previous cell. With an empty list,
# checkpointing, early stopping, LR monitoring and the rich progress bar are
# all silently inactive.
# NOTE(review): ModelCheckpoint/EarlyStopping monitor "val/loss" — confirm the
# model logs that key. Also, patience=10 exceeds max_epochs=5, so with one
# validation run per epoch the patience criterion will not be reached.
callbacks = [
    checkpoint_callback,
    early_stopping,
    lr_monitor,
    model_summary,
    rich_progress_bar,
]
trainer = pl.Trainer(max_epochs=5, callbacks=callbacks, accelerator="gpu")

In [None]:
# Train the model on the datamodule. `ckpt_path` is read from the config and is
# None when the key is absent.
trainer.fit(model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))

In [None]:
# Switch to inference mode before sampling: after `trainer.fit` the module may
# still be in training mode, in which case any dropout / batch-norm layers
# would not behave as they do at inference time.
model.eval()

# Move the model to the GPU when one is available (CPU otherwise) and display
# the device it now lives on.
model.to("cuda" if torch.cuda.is_available() else "cpu").device

In [None]:
# Generate 1000 samples (jets) with the trained model, detach them from the
# autograd graph and bring them back to host memory as a NumPy array.
x_samples = (
    model.sample(1000)
    .detach()
    .cpu()
    .numpy()
)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import tqdm
from matplotlib.gridspec import GridSpec

In [None]:
# Event displays for the first 16 generated jets: each particle is drawn in the
# (eta, phi) plane (feature columns 0 and 1) with marker area proportional to
# |feature 2| (pT_rel).
fig = plt.figure(figsize=(16, 16))
gs = GridSpec(4, 4)

for idx in range(16):
    jet = x_samples[idx]
    ax = fig.add_subplot(gs[idx])
    ax.scatter(jet[:, 0], jet[:, 1], s=1000 * np.abs(jet[:, 2]), alpha=0.5)

    ax.set_xlabel(r"$\eta$")
    ax.set_ylabel(r"$\phi$")
    ax.set_xlim(-0.3, 0.3)
    ax.set_ylim(-0.3, 0.3)

fig.suptitle("Gen. jets", fontsize=20, y=1.01)
fig.tight_layout()

In [None]:
import energyflow as ef


def _summed_jet_p4s(jets_ary):
    """Return one four-momentum per jet by summing its particles' four-momenta.

    Each particle is converted with `ef.p4s_from_ptyphims` (so the last axis of
    `jets_ary` is presumably ordered (pT, y, phi[, m]) — confirm at call sites),
    and the per-particle vectors are summed over axis 1.
    """
    particle_p4s = ef.p4s_from_ptyphims(jets_ary)
    return particle_p4s.sum(axis=1)


def jet_masses(jets_ary):
    """Invariant mass of every jet in `jets_ary`."""
    return ef.ms_from_p4s(_summed_jet_p4s(jets_ary))


def jet_ys(jets_ary):
    """Rapidity of every jet in `jets_ary`."""
    return ef.ys_from_p4s(_summed_jet_p4s(jets_ary))


def jet_etas(jets_ary):
    """Pseudorapidity of every jet in `jets_ary`."""
    return ef.etas_from_p4s(_summed_jet_p4s(jets_ary))


def jet_phis(jets_ary):
    """Azimuthal angle of every jet in `jets_ary`, computed with `phi_ref=0`."""
    return ef.phis_from_p4s(_summed_jet_p4s(jets_ary), phi_ref=0)

In [None]:
# Reference ("Sim") jets taken from the datamodule's `tensor_test` attribute;
# they are compared against the generated samples below.
x = datamodule.tensor_test
print(f"{x.shape}")

In [None]:
# Compare generated ("Gen") and simulated test ("Sim") jets: per-particle
# distributions of eta_rel, phi_rel and pT_rel (feature columns 0, 1, 2) plus
# the jet-mass distribution.


def _flatten_particles(jets):
    """Stack all particles of all jets into one (n_particles, n_features) array.

    Rows that are entirely zero are dropped as padding. The same rule is applied
    to generated and simulated jets, so every panel compares the same particle
    sets (a real particle with one exact-zero coordinate is kept).
    NOTE(review): assumes padded particles are stored as all-zero rows —
    confirm against the datamodule preprocessing.
    """
    jets = np.asarray(jets)
    particles = jets.reshape(-1, jets.shape[-1])
    return particles[np.any(particles != 0.0, axis=1)]


def _to_ptyphi(jets):
    """Reorder columns (eta_rel, phi_rel, pT_rel) into (pT, y, phi) for energyflow.

    eta is used in place of rapidity y, which is exact only for massless particles.
    """
    return np.asarray(jets)[..., [2, 0, 1]]


def _plot_feature(ax, i_feat, bins, xlabel, gen_particles, sim_particles):
    """Overlay normalized Gen (dashed) and Sim step histograms of one particle feature."""
    common = dict(histtype="step", bins=bins, density=True, lw=2, alpha=0.7)
    ax.hist(gen_particles[:, i_feat], ls="--", label="Gen", **common)
    ax.hist(sim_particles[:, i_feat], label="Sim", **common)
    ax.set_xlabel(xlabel)
    ax.get_yaxis().set_ticklabels([])
    ax.set_yscale("log")
    ax.legend()


gen_particles = _flatten_particles(x_samples)
sim_particles = _flatten_particles(x)

# (feature column, histogram bins, axis label) for the three particle-level panels.
feature_panels = [
    (0, np.linspace(-0.5, 0.5, 50), r"$\eta^\mathrm{rel}$"),
    (1, np.linspace(-0.5, 0.5, 50), r"$\phi^\mathrm{rel}$"),
    (2, np.linspace(-0.1, 0.5, 100), r"$p_\mathrm{T}^\mathrm{rel}$"),
]

fig = plt.figure(figsize=(20, 4))
gs = GridSpec(1, 4)

for i_panel, (i_feat, bins, xlabel) in enumerate(feature_panels):
    _plot_feature(
        fig.add_subplot(gs[i_panel]), i_feat, bins, xlabel, gen_particles, sim_particles
    )

# Jet-level panel: mass of the summed particle four-momenta.
ax = fig.add_subplot(gs[3])
bins = np.linspace(0.0, 0.3, 100)
ax.hist(
    jet_masses(_to_ptyphi(x_samples)),
    histtype="step",
    bins=bins,
    density=True,
    lw=2,
    ls="--",
    alpha=0.7,
    label="Gen",
)
ax.hist(
    jet_masses(_to_ptyphi(x)),
    histtype="step",
    bins=bins,
    density=True,
    lw=2,
    alpha=0.7,
    label="Sim",
)
ax.set_xlabel(r"Jet mass")
ax.set_yscale("log")
ax.legend()

plt.tight_layout()