In [None]:
import os
import sys

sys.path.append("../")

%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
import hydra
import numpy as np
import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf

In [None]:
# set env variable DATA_DIR again because of hydra
from dotenv import load_dotenv

load_dotenv()
# os.environ only accepts str values, so assigning the result of .get() raised an
# opaque TypeError when DATA_DIR was missing. Fail early with an actionable message.
if "DATA_DIR" not in os.environ:
    raise RuntimeError(
        "DATA_DIR is not set; define it in the environment or in a .env file before composing the config."
    )
os.environ["DATA_DIR"] = os.environ["DATA_DIR"]

In [None]:
# Names of the two experiment configs to compare; each one is passed to hydra
# as an `experiment=` override when the training config is composed.
experiment = "fm_tops.yaml"
experiment_jedi = "epic_jedi.yaml"

In [None]:
# load everything from experiment config
# Compose train.yaml from ../configs/ with the `experiment` override applied.
# The torch RNG is pinned to 123 before composing and re-seeded from a fresh,
# non-deterministic seed afterwards so the fixed seed does not leak into later cells.
torch.manual_seed(123)
with hydra.initialize(version_base=None, config_path="../configs/"):
    cfg = hydra.compose(config_name="train.yaml", overrides=[f"experiment={experiment}"])
    # print(OmegaConf.to_yaml(cfg))
torch.manual_seed(torch.seed())

In [None]:
# load everything from experiment config
# Same composition as for `cfg`, but with the `experiment_jedi` override.
# The torch RNG is pinned to 123 before composing and re-seeded afterwards.
torch.manual_seed(123)
with hydra.initialize(version_base=None, config_path="../configs/"):
    cfg_jedi = hydra.compose(config_name="train.yaml", overrides=[f"experiment={experiment_jedi}"])
    # print(OmegaConf.to_yaml(cfg_jedi))
torch.manual_seed(torch.seed())

In [None]:
# Instantiate the datamodule described by cfg.data and run its setup under a
# fixed seed (123); the RNG is re-seeded non-deterministically afterwards.
torch.manual_seed(123)
datamodule = hydra.utils.instantiate(cfg.data)
# NOTE: setup() is called without a stage here, whereas the JEDI datamodule is set up with stage="fit".
datamodule.setup()
torch.manual_seed(torch.seed())

In [None]:
# Instantiate the datamodule described by cfg_jedi.data and set it up for the
# "fit" stage under the same fixed seed (123) used for the other datamodule.
torch.manual_seed(123)
datamodule_jedi = hydra.utils.instantiate(cfg_jedi.data)
datamodule_jedi.setup(stage="fit")
torch.manual_seed(torch.seed())

## Compare the data shuffling

In [None]:
# Copy the per-split tensors, masks and conditioning tensors of `datamodule`
# into NumPy arrays, together with its stored means and standard deviations.
test_data = np.array(datamodule.tensor_test)
test_mask = np.array(datamodule.mask_test)
test_cond = np.array(datamodule.tensor_conditioning_test)
val_data = np.array(datamodule.tensor_val)
val_mask = np.array(datamodule.mask_val)
val_cond = np.array(datamodule.tensor_conditioning_val)
train_data = np.array(datamodule.tensor_train)
train_mask = np.array(datamodule.mask_train)
train_cond = np.array(datamodule.tensor_conditioning_train)
# NOTE: `means` is rebound later in this notebook by the masked-statistics cell.
means = np.array(datamodule.means)
stds = np.array(datamodule.stds)

In [None]:
print("test_data.shape", test_data.shape)
print("val_data.shape", val_data.shape)
print("train_data.shape", train_data.shape)

In [None]:
print(datamodule.hparams.normalize)

In [None]:
# Collect the first element of every training batch of `datamodule_jedi`.
# The loader is iterated under a fixed seed so its shuffling can be compared
# with the loader of the other datamodule.
torch.manual_seed(123)
data_jedi = [batch[0].numpy() for batch in datamodule_jedi.train_dataloader()]
# Join along the batch axis. Stacking with np.array(...) and reshaping gives the
# same result for equally sized batches, but fails when the last batch is smaller.
data_jedi = np.concatenate(data_jedi, axis=0)
print("data_jedi.shape", data_jedi.shape)
torch.manual_seed(torch.seed())

In [None]:
# Collect the first element of every training batch of `datamodule`, using the
# same fixed seed as for the JEDI loader so both shuffles are comparable.
torch.manual_seed(123)
data = [batch[0].numpy() for batch in datamodule.train_dataloader()]
# Join along the batch axis. Stacking with np.array(...) and reshaping gives the
# same result for equally sized batches, but fails when the last batch is smaller.
data = np.concatenate(data, axis=0)
print("data.shape", data.shape)
torch.manual_seed(torch.seed())

In [None]:
print(data_jedi[:10, 0])

In [None]:
print(data[:10, 0])

In [None]:
# Element-wise comparison of the two collected training sets: shapes first,
# then the difference for the first particle of the first ten jets, then a
# tolerance-based equality check.
print(data.shape)
print(data_jedi.shape)
diff = data - data_jedi
print(diff.shape)
print(diff[:10, 0])
# np.allclose defaults are rtol=1e-05, atol=1e-08, equal_nan=False
print(np.allclose(data, data_jedi))

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Overlay the distribution of the feature at index 2 (presumably ptrel — TODO confirm)
# for both datasets, re-using the bin edges of the first histogram for the second.
feature_idx = 2
hist = plt.hist(data[..., feature_idx].ravel(), histtype="stepfilled")
bin_edges = hist[1]
plt.hist(data_jedi[..., feature_idx].ravel(), bins=bin_edges, histtype="step")
plt.yscale("log")

In [None]:
# Minimal demonstration of boolean-mask indexing on a 1-D tensor:
# the mask keeps the last five of the nine entries.
test_array9 = torch.tensor(list(range(1, 10)))
test_mask9 = torch.tensor([False] * 4 + [True] * 5)
print(type(test_array9))
print(test_mask9)
print(test_array9[test_mask9])

In [None]:
# Wrap the NumPy test mask in a torch tensor and show the shapes of mask and data.
# NOTE: `t_mask` is not referenced again in the cells that follow.
t_mask = torch.tensor(test_mask)
print(test_mask.shape)
print(test_data.shape)

In [None]:
# Per-feature mean and variance over the particles selected by the test mask.
print(test_mask.shape)
print(test_data.shape)
# test_data / test_mask are NumPy arrays when this cell first runs, and NumPy
# arrays have no .clone(). torch.as_tensor accepts arrays and tensors alike, so
# the cell also survives being re-run; .clone() keeps the original copy semantics.
test_data = torch.as_tensor(test_data).clone()
test_mask = torch.as_tensor(test_mask).clone() == 1
print(test_mask[:2])
# Drop size-1 axes so the mask can index the leading axes of test_data
# (assumes the mask is (num_jets, num_particles, 1) — TODO confirm).
tm = test_mask.squeeze()
print(test_data.shape)
# Boolean indexing flattens the selected particles into (num_selected, num_features).
tt = test_data[tm]
print(tt.shape)
# torch.var_mean returns (variance, mean) in that order; the names were swapped before.
var, means = torch.var_mean(tt, dim=0, keepdim=True)
print(means.shape)
print(var.shape)

## One Hot Test

In [None]:
from jetnet.datasets import JetNet
from sklearn.preprocessing import OneHotEncoder

In [None]:
# function to one hot encode the jet type and leave the rest of the features as is
def OneHotEncodeType(x: np.ndarray):
    """One-hot encode the jet type of gluon/top jet data and keep the other features.

    Args:
        x (np.ndarray): jet features with the jet type in the first entry of the
            last axis, followed by the remaining jet features.

    Returns:
        np.ndarray: array with the same leading shape as ``x`` whose last axis holds
        the two one-hot columns followed by the remaining features.
    """
    # This transform is used with jet_type ["g", "t"], whose type codes are 0 (gluon)
    # and 2 (top). With categories [0, 1] the encoder rejects every top jet as an
    # unknown category (OneHotEncoder's default handle_unknown is "error").
    enc = OneHotEncoder(categories=[[0, 2]])
    type_encoded = enc.fit_transform(x[..., 0].reshape(-1, 1)).toarray()
    # One row per encoded jet; the number of remaining features is inferred
    # instead of being hard-coded to 3.
    other_features = x[..., 1:].reshape(type_encoded.shape[0], -1)
    return np.concatenate((type_encoded, other_features), axis=-1).reshape(*x.shape[:-1], -1)

In [None]:
# Keyword arguments for the JetNet dataset constructor (gluon and top jets,
# jet type one-hot encoded through the jet transform).
data_args = {
    # "jet_type": ["g", "q", "t", "w", "z"],  # all five jet types
    "jet_type": ["g", "t"],  # gluon and top quark jets
    # NOTE(review): absolute, user-specific path — consider deriving it from DATA_DIR
    "data_dir": "/beegfs/desy/user/ewencedr/data/jetnet/",
    # these are the default particle features, written here to be explicit
    "particle_features": ["etarel", "phirel", "ptrel", "mask"],
    "num_particles": 150,  # number of particles requested per jet (the earlier "10 highest pT" note was stale)
    "jet_features": ["type", "pt", "eta", "mass"],
    # we don't want to normalise the 'mask' feature so we set that to False
    # "particle_normalisation": FeaturewiseLinear(
    #    normal=True, normalise_features=[True, True, True, False]
    # ),
    # pass our function as a transform to be applied to the jet features
    "jet_transform": OneHotEncodeType,
}

In [None]:
# Build JetNet datasets from data_args: explicit "train" and "valid" splits,
# plus one dataset that relies on the library's default split.
jets_train = JetNet(**data_args, split="train")
jets_valid = JetNet(**data_args, split="valid")
jets = JetNet(**data_args)

In [None]:
jets_train

In [None]:
# Inspect the first training sample: indexing the dataset yields a
# (particle_features, jet_features) pair, printed next to the configured feature names.
particle_features, jet_features = jets_train[0]
print(f"Particle features ({data_args['particle_features']}):\n\t{particle_features}")
print(f"\nJet features ({data_args['jet_features']}):\n\t{jet_features}")

In [None]:
# Keyword arguments for JetNet.getData: four jet types, all splits, and no
# jet transform (the one-hot encoding is applied separately afterwards).
data_args_np = {
    "jet_type": ["g", "q", "t", "z"],  # gluon, quark, top and Z jets
    # "jet_type": ["z"],  # Z jets only
    # NOTE(review): absolute, user-specific path — consider deriving it from DATA_DIR
    "data_dir": "/beegfs/desy/user/ewencedr/data/jetnet/",
    # these are the default particle features, written here to be explicit
    "particle_features": ["etarel", "phirel", "ptrel", "mask"],
    "num_particles": 150,  # number of particles requested per jet (the earlier "10 highest pT" note was stale)
    "jet_features": ["type", "pt", "eta", "mass"],
    # we don't want to normalise the 'mask' feature so we set that to False
    # "particle_normalisation": FeaturewiseLinear(
    #    normal=True, normalise_features=[True, True, True, False]
    # ),
    # pass our function as a transform to be applied to the jet features
    # "jet_transform": OneHotEncodeType,
    "split": "all",
}

In [None]:
# Load particle-level and jet-level data directly via JetNet.getData
# (presumably NumPy arrays, as the variable names suggest — TODO confirm).
particle_data_np, jet_data_np = JetNet.getData(**data_args_np)

In [None]:
print(particle_data_np.shape)
print(jet_data_np[:10])

In [None]:
# gluon: 0
# quark: 1
# top: 2
# w: 3
# z: 4

In [None]:
# Integer code of each jet type as it appears in the jet data.
type_dict = {"g": 0, "q": 1, "t": 2, "w": 3, "z": 4}
# Codes of the jet types that were actually requested, in request order.
# The loop variable used to be named `type`, which shadowed the builtin for the
# rest of the session and broke re-running cells that call type(...).
categories = [type_dict[jet_type] for jet_type in data_args_np["jet_type"]]
print(categories)

In [None]:
def OneHotEncodeTypeNp(x: np.ndarray, categories: list = [[0, 1, 2, 3, 4]]):
    """One hot encode the jet type and leave the rest of the features as is.

    Note: the one-hot position is given by the position of a value in ``categories``,
    not by the value itself, e.g. categories [[0, 3]] encodes 0 as [1, 0] and 3 as [0, 1].

    Args:
        x (np.ndarray): jet data whose last axis holds the jet type in the first
            entry, followed by the remaining jet features.
        categories (list, optional): Nested list (one inner list, for the single type
            column) with the type values to encode. Defaults to [[0, 1, 2, 3, 4]].
            The default list is never mutated here.

    Returns:
        np.ndarray: array with the same leading shape as ``x`` whose last axis has
        length ``len(categories[0])`` plus the number of remaining features.
    """
    enc = OneHotEncoder(categories=categories)
    type_encoded = enc.fit_transform(x[..., 0].reshape(-1, 1)).toarray()
    # One row per encoded jet; the number of remaining features is inferred
    # instead of being hard-coded to 3 (pt, eta, mass).
    other_features = x[..., 1:].reshape(type_encoded.shape[0], -1)
    return np.concatenate((type_encoded, other_features), axis=-1).reshape(*x.shape[:-1], -1)

In [None]:
jet_data_one_hot = OneHotEncodeTypeNp(jet_data_np, categories=[categories])

In [None]:
print(jet_data_one_hot.shape)
print(jet_data_one_hot[:10])

In [None]:
# Flags selecting which jet-level quantities are kept as conditioning columns
# when the column index list is built.
conditioning_type = True
conditioning_pt = False
conditioning_eta = False
conditioning_mass = True

In [None]:
# Build the indices of the columns to keep from the one-hot encoded jet data.
# Column layout: [type one-hot (one_hot_len columns), pt, eta, mass].
one_hot_len = len(categories)
print(one_hot_len)
column_groups = [
    (conditioning_type, np.arange(one_hot_len)),
    (conditioning_pt, np.arange(one_hot_len, one_hot_len + 1)),
    (conditioning_eta, np.arange(one_hot_len + 1, one_hot_len + 2)),
    (conditioning_mass, np.arange(one_hot_len + 2, one_hot_len + 3)),
]
selected_groups = [columns for enabled, columns in column_groups if enabled]
# np.concatenate raises on an empty list, so with every flag disabled fall back
# to an empty integer index, which selects zero columns.
keep_col = np.concatenate(selected_groups) if selected_groups else np.array([], dtype=int)
print(keep_col)

In [None]:
# what happens if no conditioning is used?
# Keep only the selected conditioning columns along the last axis.
jet_data_final = jet_data_one_hot[..., keep_col]
print(jet_data_final.shape)

In [None]:
print(jet_data_one_hot.shape)
print(jet_data_one_hot[:, [0, 1, 2, 5]].shape)

In [None]:
# Sanity check of fancy column indexing: ten identical rows 0..9, then pick
# columns 0, 1, 2 and 5 from every row.
test_array = np.repeat(np.arange(10)[np.newaxis, :], 10, axis=0)
print(test_array.shape)
print(test_array)
picked_columns = [0, 1, 2, 5]
print(test_array[:, picked_columns])

In [None]:
# NOTE: this rebinds `data`, which earlier held the collected training batches,
# to the test tensor of `datamodule`; the plotting and SWD cells below use it.
data = np.array(datamodule.tensor_test)

In [None]:
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from tqdm import tqdm

In [None]:
print(data.shape)

In [None]:
# Overlay the constituents of the first jets of `data` in a single eta-phi
# scatter plot; marker size scales with the absolute value of feature 2.
color: str = "#E2001A"  # was a 1-tuple because of a stray trailing comma
mask_data = np.ma.masked_where(
    data[:, :, 0] == 0,
    data[:, :, 0],
)
mask = np.expand_dims(mask_data, axis=-1)

fig = plt.figure(figsize=(5, 5))
gs = GridSpec(1, 1)
ax = fig.add_subplot(gs[0])
# Plot at most 1000 jets, without indexing past the end of a smaller dataset.
for idx in tqdm(range(min(1000, len(data)))):
    x_plot = data[idx, :, :2]
    s_plot = np.abs(data[idx, :, 2])
    # NOTE(review): `mask` holds feature-0 values (masked where they equal 0), not a
    # 0/1 mask, so this appears to zero the marker size of every particle whose
    # feature 0 is negative — confirm this is intended.
    s_plot[mask[idx, :, 0] < 0.0] = 0.0

    ax.scatter(*x_plot.T, s=50 * s_plot, color=color, alpha=0.5)

ax.set_xlabel(r"$\eta$")
ax.set_ylabel(r"$\phi$")

ax.set_xlim(-0.3, 0.3)
ax.set_ylim(-0.3, 0.3)
plt.show()

In [None]:
def plot_single_jets(
    data: np.ndarray,
    color: str = "#E2001A",
    save_folder: str = "logs/",
    save_name: str = "sim_jets",
) -> plt.Figure:
    """Create a 4x4 grid of scatter plots, one randomly drawn jet per panel, and save it.

    Jets are drawn with ``np.random.randint``, so the same jet can appear more than once.

    Args:
        data (np.ndarray): Data to plot; indexed as ``data[jet, particle, feature]`` with
            features 0 and 1 used as plot coordinates and feature 2 as marker size.
        color (str, optional): Color of plotted point cloud. Defaults to "#E2001A".
        save_folder (str, optional): Path to folder where the plot is saved. It is created
            if it does not exist. Defaults to "logs/".
        save_name (str, optional): File name (without extension) for saving the plot.
            Defaults to "sim_jets".

    Returns:
        plt.Figure: The created figure, which is also written to
        ``<save_folder>/<save_name>.png``.
    """
    mask_data = np.ma.masked_where(
        data[:, :, 0] == 0,
        data[:, :, 0],
    )
    mask = np.expand_dims(mask_data, axis=-1)
    fig = plt.figure(figsize=(16, 16))
    gs = GridSpec(4, 4)

    for i in tqdm(range(16)):
        ax = fig.add_subplot(gs[i])

        idx = np.random.randint(len(data))
        x_plot = data[idx, :, :2]
        s_plot = np.abs(data[idx, :, 2])
        # NOTE(review): `mask` holds feature-0 values (masked where they equal 0), not a
        # 0/1 mask, so this appears to zero the marker size of every particle whose
        # feature 0 is negative — confirm this is intended.
        s_plot[mask[idx, :, 0] < 0.0] = 0.0

        ax.scatter(*x_plot.T, s=5000 * s_plot, color=color, alpha=0.5)

        ax.set_xlabel(r"$\eta$")
        ax.set_ylabel(r"$\phi$")

        ax.set_xlim(-0.3, 0.3)
        ax.set_ylim(-0.3, 0.3)

    plt.tight_layout()

    # Join folder and file name properly (works with or without a trailing slash)
    # and make sure the target folder exists before saving.
    if save_folder:
        os.makedirs(save_folder, exist_ok=True)
    plt.savefig(os.path.join(save_folder, f"{save_name}.png"), bbox_inches="tight")
    return fig

## Test Sliced Wasserstein Distance

In [None]:
def swd(
    data: torch.Tensor, preds: torch.Tensor, n_proj: int = 1024, verbose: bool = False
) -> torch.Tensor:
    """Sliced Wasserstein distance between two batches of point clouds.

    Both inputs are projected onto ``n_proj`` random unit directions, the projections
    are sorted along the point axis, and the mean squared difference of the sorted
    projections is returned. All points are treated equally (no mask is applied), and
    both inputs must have the same number of points for the element-wise difference.
    Inspired by https://github.com/apple/ml-cvpr2019-swd/blob/master/swd.py#L45

    Args:
        data (torch.Tensor) [batch, n_points, feats]: Ground truth.
        preds (torch.Tensor) [batch, n_points, feats]: Predictions.
        n_proj (int, optional): number of random 1d projections. Defaults to 1024.
        verbose (bool, optional): print the intermediate tensor shapes. Defaults to False.

    Returns:
        wdist (torch.Tensor): scalar tensor with the mean squared difference of the
        sorted projections.

    Raises:
        ValueError: if ``data`` is not 3-dimensional.
    """
    if data.dim() != 3:
        raise ValueError(
            f"data must have shape [batch, n_points, feats], got {tuple(data.shape)}"
        )
    n_feats = data.shape[-1]
    data, preds = data.float(), preds.float()

    # Random directions, one per column, normalised to unit length.
    proj = torch.randn(n_feats, n_proj, device=data.device)  # [feats, l]
    proj = proj * torch.rsqrt(torch.sum(torch.square(proj), 0, keepdim=True))

    # matmul broadcasts the 2-D projection matrix over the batch dimension,
    # so no explicit expand to [batch, feats, l] is needed.
    p1 = torch.matmul(data, proj)  # [batch, n_points, l]
    p2 = torch.matmul(preds, proj)  # [batch, n_points, l]
    if verbose:
        print(f"proj: {proj.shape}")
        print(f"p1: {p1.shape}")
        print(f"p2: {p2.shape}")

    # Sort along the point axis so points are matched by rank in each projection.
    p1, _ = torch.sort(p1, 1, descending=True)
    p2, _ = torch.sort(p2, 1, descending=True)
    wdist = torch.mean(torch.square(p1 - p2))  # MSE
    return wdist

In [None]:
# Two inputs for the SWD check: the data itself and a copy shifted by a
# constant 0.1 in every entry.
tensor1 = torch.tensor(data)
tensor2 = torch.tensor(data) + 0.1
print(tensor1.shape)

In [None]:
swd_ = swd(tensor1, tensor2)
print(swd_)

In [None]:
# Random boolean mask of shape (2, 5): an entry is True when its uniform
# sample exceeds 0.5.
mask_test = torch.rand(2, 5) > 0.5
print(mask_test)
mask_test.shape

In [None]:
# Index a random (2, 5, 3) tensor with the (2, 5) boolean mask: the mask selects
# along the first two axes, so the result holds one row of 3 values per True entry.
test_tensor = torch.rand(2, 5, 3)
print(test_tensor)
print(test_tensor.shape)
print(mask_test.shape)
masked_test = test_tensor[mask_test]
print(masked_test.shape)
print(masked_test)

In [None]:
# Zero out the entries that the mask rejects. mask_test is (2, 5) while test_tensor
# is (2, 5, 3), so the mask needs a trailing axis to broadcast over the feature
# dimension; the bare product fails with a shape mismatch.
mskd = test_tensor * mask_test.unsqueeze(-1)
print(mskd)

In [None]:
# Plain mean over every entry of the masked tensor, zeroed entries included.
print(mskd.mean())
# Sum of all entries divided by the number of True mask entries.
# NOTE(review): this divides by the number of selected points, not by the number of
# selected values (points x features), so it is not directly comparable to the
# mean above — confirm which normalisation is intended.
print(mskd.sum() / mask_test.sum())