In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os

from torch.utils.data import Dataset
import awkward as ak

import h5py
from torch.utils.data import random_split

import uproot
import torch



In [3]:
class FileConfig:
    files = {"../data/fourtopvsthree/rootfiles/3tJ_LO_final.root" : 0,
              "../data/fourtopvsthree/rootfiles/3tWm_LO_final.root" : 0,
               "../data/fourtopvsthree/rootfiles/3tWp_LO_final.root" : 0, 
                "../data/fourtopvsthree/rootfiles/4top_2LSS_April18.root" : 1}

PARTICLE_PREFIXES = {
    "jet" : 0,
    "el" : 1,
    "mu" : 2
    
}

class TopClassiferDataSetPrepaper:
    file_config = FileConfig
    train_split_size = 0.7
    def __init__(self):
        (
         particle_data, 
         global_data, 
         src_mask,
         targets
        ) = self.parse_root_file(
             max_particles = 25,
             particle_features = ["_pt", "_eta", "_phi", "_mass"],
             global_features = ["met_met", "met_eta", "met_phi"]
         ) 
        self.data_set = TopMulitplicityClassifierDataSet(particle_data, global_data, src_mask, targets)
        
        self.particle_data = particle_data
        self.global_data = global_data
        self.targets = targets
    
    def parse_root_file(self, max_particles, particle_features, global_features):
        per_file_events, per_file_targets, per_file_globals = [], [], []

        for path, y in self.file_config.files.items():
            reco = uproot.open(path)["Reco;1"]
            blocks = []

            gdict = reco.arrays(global_features, how=dict)
            gstack = ak.concatenate([gdict[name][..., None] for name in global_features], axis=-1)
            per_file_globals.append(gstack)

            for prefix in PARTICLE_PREFIXES.keys():
                feats = [f"{prefix}{fe}" for fe in particle_features]
                bdict = reco.arrays(feats, how=dict)
                base  = ak.concatenate([bdict[name][..., None] for name in feats], axis=-1)

                if prefix in ("el", "mu"):
                    chd = reco.arrays(f"{prefix}_charge", how=dict)
                    ch  = ak.concatenate([chd[k][..., None] for k in chd], axis=-1)
                    print(ch)
                else:
                    ch = ak.zeros_like(base[..., :1])

                if prefix == "jet":
                    btd = reco.arrays(f"{prefix}_btag", how=dict)
                    bt  = ak.concatenate([btd[k][..., None] for k in btd], axis=-1)
                else:
                    bt = ak.zeros_like(base[..., :1])

                blocks.append(ak.concatenate([base, ch, bt], axis=-1))

            events = ak.concatenate(blocks, axis=1)
            per_file_events.append(events)
            per_file_targets.append(torch.full((len(events), 1), int(y), dtype=torch.long))

        global_arr = ak.concatenate(per_file_globals, axis=0)
        global_arr = torch.from_numpy(ak.to_numpy(global_arr).astype(np.float32, copy=False))

        arr = ak.concatenate(per_file_events, axis=0)                  
        arr = ak.pad_none(arr, max_particles, axis=1, clip=True)

        pad_mask_np   = ak.to_numpy(ak.is_none(arr, axis=-1))          
        arr           = ak.fill_none(arr, np.nan)                     
        src_mask = allnan_mask_np = ak.to_numpy(ak.all(np.isnan(arr), axis=-1))   


        dense_np = ak.to_numpy(arr).astype(np.float32, copy=False)
        ## Commented out to save file in h5oy format with nans
        #np.nan_to_num(dense_np, copy=False, nan=-1010, posinf=-1010, neginf=-1010)
        set_array = torch.from_numpy(dense_np)

        target_array = torch.cat(per_file_targets, dim=0)
        return set_array, global_arr, src_mask, target_array

    def split_data_set(self):
        self.train_data, self.val_data,  self.test_data = random_split(self.data_set,  [self.train_split_size, (1 -  self.train_split_size) /2 , (1 - self.train_split_size )/ 2])

    

    def _save_dataset(self, save_file_name: str):
        ## Assumes the data set is already created tbh cba coding in the checks
        with h5py.File(save_file_name, "w") as f:
            particle_fea = f.create_group("particle_features")
            dset = particle_fea.create_dataset("all", data = self.particle_data)

            global_data = f.create_group("global_data")
            glo_dset = global_data.create_dataset("all" , data = self.global_data)

            target_data = f.create_group("targets")
            target_dset = target_data.create_dataset("all", data = self.targets)
        

class TopMulitplicityClassifierDataSet(Dataset):
    ### Torch module for Dataset, allows easy dataloader creation
    def __init__(self, particle_features, global_features, src_mask, target_labels):
        self.particle_features = particle_features
        self.global_features = global_features
        self.src_mask = src_mask
        self.target_labels = target_labels
    def __len__(self):
        return self.particle_features.shape[0]
    def __getitem__(self, idx):
        return {"particle_features": self.particle_features[idx],
               "global_features": self.global_features[idx],
               "src_mask": self.src_mask[idx]}, self.target_labels[idx]

In [4]:
top_data_sets = TopClassiferDataSetPrepaper()

[[[1], [-1]], [], [[1]], [[-1]], [[1]], ..., [[1]], [[1]], [[-1]], [[1]]]
[[], [[-1], [1]], [[1]], [[-1]], [[1]], ..., [[1]], [[-1]], [[-1]], [[-1]]]
[[[-1]], [[-1]], [[-1]], [], [], ..., [[1], ...], [[-1]], [[-1], [1]], [], []]
[[[-1]], [[1]], [[1]], [[-1], [1]], ..., [[-1]], [], [[1], [-1]], [[1], [1]]]
[[[1], [1]], [[-1], [1]], [[-1], [1]], [], ..., [[-1], ...], [[1]], [], [[-1]]]
[[], [], [], [[1], [1]], [[1], [-1]], ..., [], [], [[-1]], [[1], [1]], [[1]]]
[[], [], [[1]], [], [], [[1]], [], ..., [[-1]], [], [], [], [[-1]], [], [[-1]]]
[[[1], [1]], [[1], [1]], [[1]], [[1], ...], ..., [[-1]], [[-1], [-1]], [[-1]]]


In [5]:
top_data_sets._save_dataset("raw_data")