In [None]:
import os
import sys

sys.path.append("../")

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
import hydra
import numpy as np
import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf

In [None]:
# Load variables from a local .env file into os.environ before Hydra composes the
# configs (presumably they read DATA_DIR via `${oc.env:DATA_DIR}` — confirm in ../configs/).
from dotenv import load_dotenv

load_dotenv()
# Re-assigning os.environ.get("DATA_DIR") to itself was a no-op when set and raised an
# opaque TypeError (None is not a str) when unset; fail with an actionable message instead.
data_dir = os.environ.get("DATA_DIR")
if data_dir is None:
    raise RuntimeError(
        "DATA_DIR is not set; add it to your .env file or export it before running this notebook."
    )

In [None]:
# Hydra experiment config applied as the `experiment=` override when composing train.yaml.
experiment: str = "fm_tops.yaml"

In [None]:
# Compose train.yaml from ../configs/ with the selected experiment applied as an override.
experiment_overrides = [f"experiment={experiment}"]
with hydra.initialize(version_base=None, config_path="../configs/"):
    cfg = hydra.compose(config_name="train.yaml", overrides=experiment_overrides)

In [None]:
# hydra.utils.instantiate builds the object named by the `data` node's `_target_`;
# setup() is then called without a stage argument (presumably preparing all splits).
datamodule = hydra.utils.instantiate(cfg["data"])
datamodule.setup()

In [None]:
# Copy the test split into a NumPy array; later cells index it as
# data[jet, particle, feature] (assumed layout — TODO confirm against the datamodule).
test_split = datamodule.tensor_test
data = np.array(test_split)

In [None]:
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from tqdm import tqdm

In [None]:
# The plotting cells index data[jet, particle, feature] and read features 0-2,
# so fail fast here if the array does not have that layout.
assert data.ndim == 3 and data.shape[-1] >= 3, f"unexpected data shape {data.shape}"
data.shape

In [None]:
# Overlay the first (up to) 1000 test jets in a single (eta, phi) panel; marker area is
# proportional to |feature 2| (presumably the relative pT — confirm).
color = "#E2001A"  # a single matplotlib color string
n_overlay = min(1000, len(data))  # avoids IndexError when the split has fewer than 1000 jets

overlay_jets = data[:n_overlay]
# Hide zero-padded particle slots (feature 0 exactly 0). Comparing the masked array with
# `< 0.0` instead hid every particle with negative eta.
overlay_sizes = np.where(overlay_jets[:, :, 0] == 0, 0.0, np.abs(overlay_jets[:, :, 2]))
overlay_coords = overlay_jets[:, :, :2].reshape(-1, 2)

# One scatter call for all particles instead of one call (and one artist) per jet.
fig, ax = plt.subplots(figsize=(5, 5))
ax.scatter(*overlay_coords.T, s=50 * overlay_sizes.ravel(), color=color, alpha=0.5)

ax.set_title(f"Overlay of {n_overlay} test jets")
ax.set_xlabel(r"$\eta$")
ax.set_ylabel(r"$\phi$")
ax.set_xlim(-0.3, 0.3)
ax.set_ylim(-0.3, 0.3)
plt.show()

In [None]:
def plot_single_jets(
    data: np.ndarray,
    color: str = "#E2001A",
    save_folder: str = "logs/",
    save_name: str = "sim_jets",
    rng: "np.random.Generator | None" = None,
) -> plt.Figure:
    """Create a 4x4 grid of scatter plots of 16 randomly selected jets and save it as PNG.

    Each panel plots one jet's particles in the (eta, phi) plane with marker area
    proportional to |feature 2|. Jets are drawn uniformly *with* replacement, so the
    same jet can appear in more than one panel.

    Args:
        data (np.ndarray): Array indexed as ``data[jet, particle, feature]``. Features 0 and 1
            are plotted as eta/phi, and feature 2 sets the marker size (presumably relative pT —
            confirm against the datamodule). Particle slots whose feature 0 is exactly 0 are
            treated as zero-padding and hidden.
        color (str, optional): Color of plotted point cloud. Defaults to "#E2001A".
        save_folder (str, optional): Folder where the plot is saved; created if missing.
            Defaults to "logs/".
        save_name (str, optional): File name (without extension) for saving the plot.
            Defaults to "sim_jets".
        rng (np.random.Generator, optional): Generator used to pick the jets, for reproducible
            figures. If None, the global ``np.random`` state is used (as before).

    Returns:
        plt.Figure: The created figure, also written to ``<save_folder>/<save_name>.png``.

    Raises:
        ValueError: If ``data`` contains no jets.
    """
    import os

    n_jets = len(data)
    if n_jets == 0:
        raise ValueError("data contains no jets to plot")

    if rng is None:
        # Same draw sequence as the global-state version, so np.random.seed still applies.
        jet_indices = [np.random.randint(n_jets) for _ in range(16)]
    else:
        jet_indices = rng.integers(n_jets, size=16)

    fig, axes = plt.subplots(4, 4, figsize=(16, 16))
    for ax, idx in zip(axes.flat, jet_indices):
        jet = data[idx]
        # Zero-padded slots (feature 0 exactly 0) get zero marker size; every real
        # particle, including those with negative eta, stays visible.
        sizes = np.where(jet[:, 0] == 0, 0.0, np.abs(jet[:, 2]))
        ax.scatter(*jet[:, :2].T, s=5000 * sizes, color=color, alpha=0.5)

        ax.set_xlabel(r"$\eta$")
        ax.set_ylabel(r"$\phi$")

        ax.set_xlim(-0.3, 0.3)
        ax.set_ylim(-0.3, 0.3)

    fig.tight_layout()

    # os.path.join tolerates a save_folder without a trailing slash; "" means the cwd.
    if save_folder:
        os.makedirs(save_folder, exist_ok=True)
    fig.savefig(os.path.join(save_folder, f"{save_name}.png"), bbox_inches="tight")
    return fig