### About this demo

This demo generates zero-padded jet events and saves them in `npz` format. Note that the events are read from an `hdf5` file; alternatively, one can regenerate the data by setting `use_hdf5=False` in `module_data.FatJetEvents`.

In [None]:
import os

import awkward as ak
import lightning as L
import numpy as np
import torch

import module_data

In [None]:
def generate_jet_events(
    channel: str,
    *,
    cut_pt: tuple = (800, 1000),
    subjet_radius: float = 0,
    num_pt_ptcs: int = 8,
    use_hdf5: bool = True,
):
    """Build a `module_data.FatJetEvents` object for one channel.

    The keyword defaults reproduce the settings this demo uses; every value
    is forwarded unchanged to `module_data.FatJetEvents` (see
    `module_data.py` for what each setting controls).

    Args:
        channel: Channel name, e.g. ``"VzToQCD"``.
        cut_pt: Forwarded as ``cut_pt`` (presumably a (min, max) jet-pt
            window -- confirm in `module_data.py`).
        subjet_radius: Forwarded as ``subjet_radius``.
        num_pt_ptcs: Forwarded as ``num_pt_ptcs``.
        use_hdf5: Forwarded as ``use_hdf5``; ``True`` reads the events from
            the `hdf5` file, ``False`` regenerates the data instead.

    Returns:
        The constructed ``module_data.FatJetEvents`` instance.
    """
    jet_events = module_data.FatJetEvents(
        channel=channel,
        cut_pt=cut_pt,
        subjet_radius=subjet_radius,
        num_pt_ptcs=num_pt_ptcs,
        use_hdf5=use_hdf5,
    )

    return jet_events

def generate_uniform_pt_events(jet_events: module_data.FatJetEvents, rnd_seed: int):
    """Sample uniform-pt events, preprocess them and zero-pad to a common length.

    Args:
        jet_events: Event source; its ``generate_uniform_pt_events`` method
            draws the sample and its ``channel`` attribute is used for the
            log line.
        rnd_seed: Seed given to ``L.seed_everything`` right before sampling.

    Returns:
        A NumPy array built from the padded events. Each event is extended
        with zero rows along its first axis up to the largest per-event
        particle count found in ``events["fast_pt"]``.
        NOTE(review): presumably shape (num_events, max_num_ptcs, 3) --
        confirm against ``JetDataModule._preprocess``.
    """
    # Seed every RNG first so the sampling below depends only on `rnd_seed`.
    L.seed_everything(rnd_seed)
    events = jet_events.generate_uniform_pt_events(bin=10, num_bin_data=500)

    # The longest event (counted in `fast_pt` entries) fixes the padded length.
    channel = jet_events.channel
    max_num_ptcs = ak.max(ak.num(events["fast_pt"], axis=1))
    print(f"Channel {channel}: Maximum number of particles = {max_num_ptcs}")

    # Preprocess the jet events (see `module_data.py` for detail). The method
    # is called unbound with `self=None`.
    events = module_data.JetDataModule._preprocess(self=None, events=events)

    # Append zero rows to every event. `new_zeros` matches the event's own
    # dtype/device and trailing feature shape instead of hard-coding a
    # float32 tensor with 3 columns; a new list avoids mutating the
    # container returned by `_preprocess`.
    padded_events = []
    for event in events:
        num_padding = max_num_ptcs - len(event)
        zero_padding = event.new_zeros((num_padding, *event.shape[1:]))
        padded_events.append(torch.cat((event, zero_padding), dim=0).numpy())

    return np.array(padded_events)

In [None]:
channel_list = ["VzToQCD", "VzToZhToVevebb", "VzToTt"]

# The output directory is the same for every seed, so create it once.
npz_dir = "./jet_dataset/padded_npz"
os.makedirs(npz_dir, exist_ok=True)

# Produce one `npz` file per random seed.
for rnd_seed in range(3):
    # Channel name -> padded events array; each entry is saved under its own key.
    npz_dict = {}

    # Generate for different channels.
    for channel in channel_list:
        jet_events = generate_jet_events(channel=channel)
        npz_dict[channel] = generate_uniform_pt_events(
            jet_events=jet_events,
            rnd_seed=rnd_seed,
        )

    # Each `npz` file holds all channels generated with one specific seed.
    npz_name = f"fatjet_{rnd_seed}.npz"
    np.savez(os.path.join(npz_dir, npz_name), **npz_dict)