In [22]:
import os
import glob
import random
import numpy as np
import awkward as ak
# from ml4cc.tools.data import io
from omegaconf import DictConfig

In [31]:
def save_array_to_file(data: ak.Array, output_path: str) -> None:
    print(f"Saving {len(data)} processed entries to {output_path}")
    ak.to_parquet(data, output_path, row_group_size=1024)


def save_processed_data(arrays: ak.Array, path: str, dataset: str = "") -> None:
    """
    Parameters:
        arrays: ak.Array
            The awkward array to be saved into file
        path: str
            The path of the output file
        dataset: str
            [default: ""] The dataset to be used for the training, either train, val or test

    Returns:
        None
    """
    input_dir = os.path.dirname(path)
    output_dir = os.path.join(input_dir, "preprocessed_filtered", dataset)
    filename = os.path.basename(path)
    output_path = os.path.join(output_dir, filename)
    os.makedirs(output_dir, exist_ok=True)
    save_array_to_file(data=arrays, output_path=output_path)


def train_val_test_split(arrays: ak.Array, test_also: bool = True) -> tuple:
    """Splits the data into train, val and test sets.

    Parameters:
        arrays : ak.Array
            The array to be split into train, val and test sets

    Returns:
        train_indices : list
            The indices of the train set
        test_indices : list
            The indices of the test set
        val_indices : list
            The indices of the val set
    """
    total_len = len(arrays)
    indices = list(range(total_len))
    random.seed(42)
    random.shuffle(indices)
    if test_also:
        num_train_rows = int(np.ceil(total_len * 0.7))
        num_val_rows = int(np.ceil(total_len * 0.1))
        train_indices = indices[:num_train_rows]
        val_indices = indices[num_train_rows : num_train_rows + num_val_rows]
        test_indices = indices[num_train_rows + num_val_rows :]
    else:
        frac_train_raw = 0.7
        frac_val_raw = 0.1
        frac_train = frac_train_raw / (frac_train_raw + frac_val_raw)
        frac_val = frac_val_raw / (frac_train_raw + frac_val_raw)
        train_indices = indices[: int(np.ceil(total_len * frac_train))]
        val_indices = indices[int(np.ceil(total_len * frac_train)) : int(np.ceil(total_len * (frac_train + frac_val)))]
        test_indices = None
    return train_indices, test_indices, val_indices


def save_train_val_test_data(
    arrays, path, train_indices, val_indices, test_indices) -> None:
    """Saves the split data into train, val and test sets.

    Parameters:
        arrays : ak.Array
            The array to be split into train, val and test sets
        path : str
            The path of the output file
        train_indices : list
            The indices of the train set
        val_indices : list
            The indices of the val set
        test_indices : list
            The indices of the test set


    Returns:
        None
    """
    train_array = arrays[train_indices]
    save_processed_data(train_array, path, dataset="train")
    val_array = arrays[val_indices]
    save_processed_data(val_array, path, dataset="val")
    if test_indices is not None:
        test_array = arrays[test_indices]
        save_processed_data(test_array, path, dataset="test")

def process_all(input_dir):
    for path in glob.glob(os.path.join(input_dir, "*")):
        arrays = ak.from_parquet(path)
        flat_arrays = ak.flatten(arrays.TrainingJet)
        train_indices, val_indices, test_indices = train_val_test_split(arrays=flat_arrays)
        save_train_val_test_data(flat_arrays, path, train_indices, val_indices, test_indices)
        
def process_all_filtered(input_dir):
    for path in glob.glob(os.path.join(input_dir, "*")):
        arrays = ak.from_parquet(path)
        print("processing file named ", path)
        filtered_arrays = arrays.TrainingJet
        train_indices, val_indices, test_indices = train_val_test_split(arrays=filtered_arrays)
        save_train_val_test_data(filtered_arrays, path, train_indices, val_indices, test_indices)

In [32]:
INPUT_DIR = "/home/norman/vbf-tagger/vbf_tagger/data/22pre/"
for dataset_dir in glob.glob(os.path.join(INPUT_DIR, "*")):
    process_all_filtered(dataset_dir)

processing file named  /home/norman/vbf-tagger/vbf_tagger/data/22pre/hh_ggf/hh_ggf_hbb_htt_kl5_kt1_powheg_events_0.parquet
Saving 10207 processed entries to /home/norman/vbf-tagger/vbf_tagger/data/22pre/hh_ggf/preprocessed_filtered/train/hh_ggf_hbb_htt_kl5_kt1_powheg_events_0.parquet
Saving 2915 processed entries to /home/norman/vbf-tagger/vbf_tagger/data/22pre/hh_ggf/preprocessed_filtered/val/hh_ggf_hbb_htt_kl5_kt1_powheg_events_0.parquet
Saving 1459 processed entries to /home/norman/vbf-tagger/vbf_tagger/data/22pre/hh_ggf/preprocessed_filtered/test/hh_ggf_hbb_htt_kl5_kt1_powheg_events_0.parquet
processing file named  /home/norman/vbf-tagger/vbf_tagger/data/22pre/hh_ggf/hh_ggf_hbb_htt_kl2p45_kt1_powheg_events_1.parquet
Saving 30786 processed entries to /home/norman/vbf-tagger/vbf_tagger/data/22pre/hh_ggf/preprocessed_filtered/train/hh_ggf_hbb_htt_kl2p45_kt1_powheg_events_1.parquet
Saving 8796 processed entries to /home/norman/vbf-tagger/vbf_tagger/data/22pre/hh_ggf/preprocessed_filter

In [8]:
INPUT_DIR2 = "/home/norman/vbf-tagger/vbf_tagger/data/22pre/hh_vbf/hh_vbf_hbb_htt_kv1_k2v0_kl1_madgraph_events_0.parquet"

In [15]:
arrays2 = ak.from_parquet(INPUT_DIR2)
arrays2 = arrays2.TrainingJet

In [19]:
arrays2[0]