In [20]:
import glob
import logging
import random
from typing import Optional

import awkward as ak
import lightning as L
import numpy as np
import torch
import vector
from torch.utils.data import DataLoader, IterableDataset
import uproot

from data_processing.cld_processing import (
    get_event_data,
    gen_to_features,
    track_to_features,
    cluster_to_features,
    process_calo_hit_data,
    process_tracker_hit_data,
    create_track_to_hit_coo_matrix,
    create_cluster_to_hit_coo_matrix,
    genparticle_track_adj,
    create_genparticle_to_genparticle_coo_matrix,
    create_genparticle_to_genparticle_coo_matrix2,
    get_calo_hit_data
)

#logger = get_pylogger(__name__)

vector.register_awkward()


In [51]:
class CustomIterableDataset(IterableDataset):
    """Custom IterableDataset that loads data from multiple files."""

    def __init__(
        self,
        files_list: list,
        n_files_at_once: int = None,
        max_n_files_per_type: int = None,
        shuffle_files: bool = True,
        shuffle_data: bool = True,
        seed: int = 4697,
        seed_shuffle_data: int = 3838,
        pad_length: int = 128,
        logger_name: str = "CustomIterableDataset",
        feature_dict: dict = None,
        token_reco_cfg: dict = None,
        token_id_cfg: dict = None,
        load_only_once: bool = False,
        shuffle_only_once: bool = False,
        random_seed_for_per_file_shuffling: int = None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        files_list : list
            list with the file names for each type. e.g. a dict like
            {"tbqq": ["tbqq_0.root", ...], "qcd": ["qcd_0.root", ...], ...}.
        n_files_at_once : int, optional
            Number of files to load at once. If None, one file per files_dict key
            is loaded.
        max_n_files_per_type : int, optional
            Maximum number of files to use per type. If None, all files are used.
            Can be used to use e.g. always the first file from the sorted list of files
            in validation.
        shuffle_files : bool, optional
            Whether to shuffle the list of files.
        shuffle_data : bool, optional
            Whether to shuffle the data after loading.
        seed : int, optional
            Random seed.
        seed_shuffle_data : int, optional
            Random seed for shuffling the data. This is useful if you want to shuffle
            the data in the same way for different datasets (e.g. train and val).
            The default value is 3838.
        pad_length : int, optional
            Maximum number of particles per jet. If a jet has more particles, the
            first pad_length particles are used, the rest is discarded.
        logger_name : str, optional
            Name of the logger.
        feature_dict : dict, optional
            Dictionary with the features to load. The keys are the names of the features
            and the values are the preprocessing parameters passed to the
            `ak_select_and_preprocess` function.
        token_reco_cfg : dict, optional
            Dictionary with the configuration to reconstruct the tokenized jetclass files.
            If None, this is not used.
        token_id_cfg : dict, optional
            Dictionary with the tokenization configuration, this is to be used when the
            token-id data is to be loaded. If None, this is ignored.
        load_only_once : bool, optional
            If True, the data is loaded only once and then returned in the same order
            in each iteration. NOTE: this is only useful if the whole dataset fits into
            memory. If the dataset is too large, this will lead to a memory error.
        shuffle_only_once : bool, optional
            If True, the data is shuffled only once and then returned in the same order
            in each iteration. NOTE: this should only be used for val/test.
        random_seed_for_per_file_shuffling : int, optional
            Random seed for shuffling the jets within a file. This is useful if you want
            to only load a subset of the jets from a file and want to choose different
            jets in different training runs.
            If load_only_once is False, this is ignored.

        **kwargs
            Additional keyword arguments.

        """
        if feature_dict is None:
            raise ValueError("feature_dict must be provided.")
     
            
        self.logger = logging.getLogger(logger_name)
        self.logger.info(f"Using seed {seed}")
        self.pad_length = pad_length
        self.shuffle_data = shuffle_data
        self.shuffle_files = shuffle_files
        self.feature_dict = feature_dict
        self.particle_features_list = [feat for feat in self.feature_dict.keys()]
        self.seed_shuffle_data = seed_shuffle_data
        self.load_only_once = load_only_once
        self.shuffle_only_once = shuffle_only_once
        self.data_shuffled = False
        self.random_seed_for_per_file_shuffling = random_seed_for_per_file_shuffling

        if self.random_seed_for_per_file_shuffling is not None:
            if not self.load_only_once:
                self.logger.warning(
                    "random_seed_for_per_file_shuffling is only used if load_only_once is True."
                )
                self.random_seed_for_per_file_shuffling = None
            else:
                self.logger.info(
                    f"Using random seed {self.random_seed_for_per_file_shuffling} for per-file shuffling."
                )

        self.logger.info(f"Using the following particle features: {self.particle_features_list}")
        self.logger.info(f"pad_length {self.pad_length} for the number of particles per jet.")
        self.logger.info(f"shuffle_data={self.shuffle_data}")
        self.logger.info(f"shuffle_files={self.shuffle_files}")
       
        self.logger.info("Using the following features:")
        for feat, params in self.feature_dict.items():
            self.logger.info(f"- {feat}: {params}")
        self.file_list = []
        for file in files_list:
            self.file_list.extend(sorted(list(glob.glob(file))))

        for file in self.file_list:
            self.logger.info(f" - {file}")

        if self.load_only_once:
            logger.warning(
                "load_only_once is True. This means that there will only be the initial data loading."
            )

        # if not specified how many files to use at once, use one file per jet_type
        if n_files_at_once is None:
            self.n_files_at_once = len(self.files_dict)
        else:
            if n_files_at_once > len(self.file_list):
                self.logger.warning(
                    f"n_files_at_once={n_files_at_once} is larger than the number of files in the"
                    f" dataset ({len(self.file_list)})."
                )
                self.logger.warning(f"Setting n_files_at_once to {len(self.file_list)}.")
                self.n_files_at_once = len(self.file_list)
            else:
                self.n_files_at_once = n_files_at_once

        self.logger.info(f"Will load {self.n_files_at_once} files at a time and combine them.")

        self.file_indices = np.array([0, self.n_files_at_once])
        self.file_iterations = len(self.file_list) // self.n_files_at_once
        if self.load_only_once:
            self.file_iterations = 1

        self.current_part_data = None
        self.current_part_mask = None
        self.token_reco_cfg = token_reco_cfg
        self.token_id_cfg = token_id_cfg

    def get_data(self):
        """Returns a generator (i.e. iterator) that goes over the current files list and returns
        batches of the corresponding data."""
        # Iterate over jet_type
        self.logger.debug("\n>>> __iter__ called\n")
        self.file_indices = np.array([0, self.n_files_at_once])

        # shuffle the file list
        if self.shuffle_files:
            self.logger.info(">>> Shuffling files")
            random.shuffle(self.file_list)
            # self.logger.info(">>> self.file_list:")
            # for filename in self.file_list:
            #     self.logger.info(f" - {filename}")

        # Iterate over files
        for j in range(self.file_iterations):
            self.logger.debug(20 * "-")
            # Increment file index if not first iteration
            if j > 0:
                self.logger.info(">>> Incrementing file index")
                self.file_indices += self.n_files_at_once

            # stop the iteration if self.file_indices[1] is larger than the number of files
            # FIXME: this means that the last batch of files (in case the number of files is not
            # divisible by self.n_files_at_once) is not used --> fix this
            # but if shuffling is used, this should not be a problem
            if self.file_indices[1] <= len(self.file_list):
                self.load_next_files()

                # loop over the current data
                for i in range(len(self.current_part_data)):
                    yield {
                        "part_features": self.current_part_data[i],
                        "part_mask": self.current_part_mask[i],
                        "jet_type_labels_one_hot": self.current_jet_type_labels_one_hot[i],
                        "jet_type_labels": torch.argmax(self.current_jet_type_labels_one_hot[i]),
                    }

    def __iter__(self):
        """returns an iterable which represents an iterator that iterates over the dataset."""
        return iter(self.get_data())

    def load_next_files(self):
        if self.load_only_once:
            if self.current_part_data is not None:
                self.logger.warning("Data has already been loaded. Will not load again.")
                self.shuffle_current_data()
                return
        self.part_data_list = []
        self.mask_data_list = []

        self.current_files = self.file_list[self.file_indices[0] : self.file_indices[1]]
        self.logger.info(f">>> Loading next files - self.file_indices={self.file_indices}")
        if self.load_only_once:
            self.logger.warning("Loading data only once. Will not load again.")
            self.logger.warning("--> This will be the data for all iterations.")
        for i_file, filename in enumerate(self.current_files):
            self.logger.info(f"{i_file+1} / {len(self.current_files)} : {filename}")

            """
            if self.token_reco_cfg is not None:
                gpu_available = torch.cuda.is_available()
                _, ak_x_particles, ak_jet_type_labels = reconstruct_jetclass_file(
                    filename_in=filename,
                    model_ckpt_path=self.token_reco_cfg["ckpt_file"],
                    config_path=self.token_reco_cfg["config_file"],
                    start_token_included=self.token_reco_cfg["start_token_included"],
                    end_token_included=self.token_reco_cfg["end_token_included"],
                    shift_tokens_by_minus_one=self.token_reco_cfg["shift_tokens_by_minus_one"],
                    device="cuda" if gpu_available else "cpu",
                    return_labels=True,
                )
                self.logger.info("Calculating additional kinematic features.")
                ak_x_particles = calc_additional_kinematic_features(ak_x_particles)

            elif self.token_id_cfg is not None:
                ak_x_particles, ak_jet_type_labels = read_tokenized_jetclass_file(
                    filename,
                    labels=self.labels_to_load,
                    remove_start_token=self.token_id_cfg.get("remove_start_token", False),
                    remove_end_token=self.token_id_cfg.get("remove_end_token", False),
                    shift_tokens_minus_one=self.token_id_cfg.get("shift_tokens_minus_one", False),
                    n_load=self.n_jets_per_file,
                    random_seed=self.random_seed_for_per_file_shuffling,
                )
                ak_x_particles = ak.Array(
                    {
                        "part_token_id": ak_x_particles["part_token_id"],
                        "part_token_id_without_last": ak_x_particles["part_token_id"][:, :-1],
                        "part_token_id_without_first": ak_x_particles["part_token_id"][:, 1:],
                    }
                )
            else:
            """
                # read the data from the file
            # can add jet features, labels, and p4s here''

            fi = uproot.open(filename)
            ev = fi["events"]
            event_data = get_event_data(ev)

            all_event_data = []

            collectionIDs = {
                    k: v
                    for k, v in zip(
                        fi.get("podio_metadata").arrays("events___idTable/m_names")["events___idTable/m_names"][0],
                        fi.get("podio_metadata").arrays("events___idTable/m_collectionIDs")["events___idTable/m_collectionIDs"][0],
                    )
                    }

            for iev in range(2):#len(ev["MCParticles.momentum.x"].array())):
            
                calo_hit_features, genparticle_to_calo_hit_matrix, calo_hit_idx_local_to_global = process_calo_hit_data(
                    event_data, iev, collectionIDs
                )
                all_event_data.append(calo_hit_features)

            all_event_data = ak.Array(all_event_data)

            # pad data
            padded_data = ak.fill_none(ak.pad_none(all_event_data, self.pad_length, axis=1, clip=True), value=0)
            if len(all_event_data.fields) >= 1:
                 mask = ak.ones_like(all_event_data[all_event_data.fields[0]], dtype="bool")
            else:
                mask = ak.ones_like(all_event_data, dtype="bool")
            mask = ak.fill_none(ak.pad_none(mask, self.pad_length, axis=1, clip=True), False)


            # convert to numpy
            np_padded_data = ak.to_numpy(
                    np.stack(
                        [ak.to_numpy(ak.values_astype(padded_data[name], "float32")) for name in self.particle_features_list],
                        axis=1,
                    )
                )
            
            np_mask = ak.to_numpy(mask)

            # add the data to the lists
            self.part_data_list.append(torch.tensor(np_padded_data))
            self.mask_data_list.append(torch.tensor(np_mask, dtype=torch.bool))
        
        # concatenate the data from all files
        self.current_part_data = torch.cat(self.part_data_list, dim=0)
        self.current_part_mask = torch.cat(self.mask_data_list, dim=0)
        self.shuffle_current_data()
        self.logger.info(
            f">>> Data loaded. (self.current_part_data.shape = {self.current_part_data.shape})"
        )

    def shuffle_current_data(self):
        # shuffle the data
        if self.shuffle_only_once and self.data_shuffled:
            self.logger.info("Data has already been shuffled. Will not shuffle again.")
            return
        if self.shuffle_data:
            rng = np.random.default_rng()
            if self.seed_shuffle_data is not None:
                self.logger.info(f"Shuffling data with seed {self.seed_shuffle_data}")
                rng = np.random.default_rng(self.seed_shuffle_data)
            perm = rng.permutation(len(self.current_part_data))
            self.current_part_data = self.current_part_data[perm]
            self.current_part_mask = self.current_part_mask[perm]
            self.data_shuffled = True

In [52]:


data_dir = "/pscratch/sd/r/rmastand/particlemind/data/p8_ee_tt_ecm365_rootfiles/"


files_list = [f"{data_dir}/reco_p8_ee_tt_ecm365_63743.root", f"{data_dir}/reco_p8_ee_tt_ecm365_63733.root"]


feature_dict = { "type": {"multiply_by": 0.3, "subtract_by": 2.7, "func": np.log, "inv_func": np.exp},
     "energy": {"multiply_by": 4},
     "position.x": {"multiply_by": 4},
               "position.y": {"multiply_by": 4},
               "position.z": {"multiply_by": 4}}


test_dataset = CustomIterableDataset(
        files_list = files_list,
        n_files_at_once = 1,
        feature_dict = feature_dict,
    pad_length = 100000,
        )

test_dataloader = DataLoader(test_dataset, batch_size=3)

iter_dl = iter(test_dataloader)
batch = next(iter_dl)
print(batch)
"""
    n_jets_per_file: int = None,
    max_n_files_per_type: int = None,
    shuffle_files: bool = True,
    shuffle_data: bool = True,
    seed: int = 4697,
    seed_shuffle_data: int = 3838,
    pad_length: int = 128,
    logger_name: str = "CustomIterableDataset",
    feature_dict: dict = None,
    labels_to_load: list = None,
    token_reco_cfg: dict = None,
    token_id_cfg: dict = None,
    load_only_once: bool = False,
    shuffle_only_once: bool = False,
    random_seed_for_per_file_shuffling: int = None,,
"""



[21010, 31010, 41010, 51010, 61010, ..., 12032, 22032, 32032, 42032, 52032]
[0.0266, 0.0917, 0.251, 0.209, ..., 0.00106, 9.51e-05, 7.98e-05, 6.17e-05]
[2.16e+03, 2.17e+03, 2.17e+03, 2.18e+03, ..., -2.95e+03, -3.09e+03, -3.24e+03]
[35.7, 35.7, 35.7, 35.7, 35.7, ..., -2.33e+03, -2.44e+03, -2.59e+03, -2.7e+03]
[1.29e+03, 1.29e+03, 1.29e+03, 1.29e+03, ..., 4.69e+03, 4.93e+03, 5.17e+03]
[1010, 11010, 21010, 31010, 41010, ..., 54021, 44021, 64021, 64021, 34021]
[0.00764, 0.00662, 0.00763, 0.00654, ..., 0.0155, 0.0521, 0.032, 0.0355]
[1.77e+03, 1.78e+03, 1.78e+03, 1.79e+03, 1.79e+03, ..., 182, 235, 240, 283, 132]
[-1.23e+03, -1.23e+03, -1.24e+03, ..., -2.58e+03, -2.65e+03, -2.52e+03]
[-1.84e+03, -1.85e+03, -1.85e+03, ..., -2.53e+03, -2.53e+03, -2.45e+03]
torch.Size([2, 5, 100000]) torch.Size([2, 100000])


AttributeError: 'CustomIterableDataset' object has no attribute 'current_jet_type_labels_one_hot'