# Dataset Index Generation
Generates the event indices for the train, validation and test sets of an HDF5 dataset, split by source ROOT file, and saves them to an `.npz` file.

In [1]:
import h5py
import re
import numpy as np

## Options

In [2]:
# define which files of events are included in test / validation / training split
# the same split is done for each particle type
# the start/stop values are slice bounds into each label's naturally-sorted list of files,
# so each split begins where the previous one ends
test_files_start = 0
test_files_stop = test_files_start+400 # first 400 files are for test set
val_files_start = test_files_stop
val_files_stop = val_files_start+100 # next 100 files are for validation set
train_files_start = val_files_stop
train_files_stop = None # all remaining files are for training set
# train_files_stop = train_files_start+500 # next 500 files are for training set

# define which particle labels to include 0=gamma 1=electron 2=muon 3=pi0
# NOTE: the output file name in the save cell is chosen by hand and must match this selection
#labels = (0, 1, 2, 3) # 4 class
labels = (1,) # 1 class electrons
#labels = (2,) # 1 class muons
#labels = (1,2,) # 2 class e & mu

## Load dataset

In [3]:
# Open the HDF5 dataset read-only. The handle `f` is intentionally left open: the cut cell
# below still reads "event_hits_index", "hit_pmt" and "veto" from it.
data_path = "/fast_scratch/WatChMaL/data/WCTE/WCTE_e-_1M_mu-_1M_0to1.5GeV.h5"
f = h5py.File(data_path, "r")
# particle label of every event, loaded fully into memory
event_labels = np.array(f['labels'])
# source file name of every event, converted to str so it can be sorted and used as a dict key
# NOTE(review): presumably stored as bytes in the HDF5 file -- confirm the conversion yields clean paths
root_files = np.array(f['root_files']).astype(str)

### Define a cut to choose which events to keep

In [4]:
# Number of hits per event: each entry of event_hits_index is the offset of an event's first
# hit, so consecutive differences give the hit counts; the total number of hits is appended
# as the end offset of the last event.
event_hits_index = np.array(f["event_hits_index"])
total_hits = f["hit_pmt"].shape[0]
nhits = np.diff(event_hits_index, append=total_hits)

# veto flag: removes events with particles that escape tank, based on truth info but not entirely correct
veto = np.array(f['veto'])

enough_hits = nhits > 10      # keep only events with more than 10 hits
fully_contained = veto == 0   # keep only FC events
cut = enough_hits & fully_contained

## Find the files of each label and indices of each file

In [5]:
def atoi(text):
    """Convert a chunk of text to ``int`` if it consists only of decimal digits.

    Parameters
    ----------
    text : str
        One piece of a filename, as produced by splitting on runs of digits.

    Returns
    -------
    int or str
        The integer value for an all-decimal chunk, otherwise ``text`` unchanged
        (including the empty string).
    """
    # isdecimal() rather than isdigit(): isdigit() also accepts characters such as
    # superscript digits, which int() cannot parse and would raise ValueError on.
    return int(text) if text.isdecimal() else text


# Sort by only the basename of the file, with natural sorting of numbers in the filename
def natural_keys(text):
    """Build a sort key that orders filenames by their basename with numbers compared numerically.

    Parameters
    ----------
    text : str
        File path using '/' as separator; only the part after the last '/' is used.

    Returns
    -------
    list
        Alternating non-digit strings and integers, starting with a (possibly empty) string,
        so keys of different files always compare str-to-str and int-to-int.
    """
    basename = text.split('/')[-1]
    # the capturing group keeps the digit runs in the split result
    return [atoi(chunk) for chunk in re.split(r'(\d+)', basename)]

In [6]:
# files that contain events of each requested label, in natural filename order
files_in_labels = {label: sorted(set(root_files[event_labels==label]), key=natural_keys) for label in labels}

# Event indices belonging to each file.
# A stable argsort groups the event indices by file name (keeping them in increasing order
# within each file); splitting at the cumulative counts then yields one index array per file.
# Unlike np.arange(first_index, first_index+count), this stays correct even if the events of
# a file are not stored contiguously in the dataset, and gives the same result when they are.
file_order = np.argsort(root_files, kind="stable")
unique_files, file_counts = np.unique(root_files, return_counts=True)
idxs_in_files = dict(zip(unique_files, np.split(file_order, np.cumsum(file_counts)[:-1])))

In [7]:
# Report, for each label, the number of files and the number of events before any cut.
# The loop variables deliberately avoid the name `f`: that is the open h5py file handle,
# and rebinding it here would break re-running the cells that read from the file.
for label, label_files in files_in_labels.items():
    n_indices = sum(len(idxs_in_files[name]) for name in label_files)
    print("label", label,"has", len(label_files),"files and ", n_indices, "indices")

label 1 has 1000 files and  1000000 indices


In [8]:
# keep, for every file, only the event indices that pass the cut
selected_idxs_in_files = {}
for file_name, file_idxs in idxs_in_files.items():
    passes_cut = cut[file_idxs]  # boolean mask aligned with this file's indices
    selected_idxs_in_files[file_name] = file_idxs[passes_cut]

In [9]:
# Report, for each label, the number of files and the number of events that pass the cut.
# The loop variables deliberately avoid the name `f`: that is the open h5py file handle,
# and rebinding it here would break re-running the cells that read from the file.
for label, label_files in files_in_labels.items():
    n_selected = sum(len(selected_idxs_in_files[name]) for name in label_files)
    print("label", label,"has", len(label_files),"files and ", n_selected, "selected indices")

label 1 has 1000 files and  127791 selected indices


## Create the splits

In [10]:
# slice of each label's (naturally sorted) file list that goes into each split
split_ranges = {"test_idxs":  slice(test_files_start, test_files_stop),
                "val_idxs":   slice(val_files_start, val_files_stop),
                "train_idxs": slice(train_files_start, train_files_stop)}
# files of every requested label in each split, label by label
split_files = {split: [name for label in labels for name in files_in_labels[label][file_range]]
               for split, file_range in split_ranges.items()}
# Selected event indices of each split, as one integer array per split.
# An empty split is given an explicit integer dtype: an empty Python list would be converted
# to float64 by NumPy, which would both promote the concatenated indices to float and be
# saved as a float array that cannot be used for indexing.
split_idxs = {split: (np.concatenate([selected_idxs_in_files[name] for name in names])
                      if names else np.empty(0, dtype=np.intp))
              for split, names in split_files.items()}

In [11]:
# summarise how many files and selected events ended up in each split
for split_name, files_of_split in split_files.items():
    n_files = len(files_of_split)
    n_events = len(split_idxs[split_name])
    print(split_name,"has", n_files,"files and", n_events,"indices")

test_idxs has 400 files and 50986 indices
val_idxs has 100 files and 12805 indices
train_idxs has 500 files and 64000 indices


In [12]:
# Verify that all events are uniquely accounted for
all_indices = np.concatenate(list(split_idxs.values()))
n_unique = len(np.unique(all_indices))
# Events of the selected labels that pass the cut. This is the number the splits should add
# up to if they cover every file of those labels; the plain cut count printed below includes
# events of every label in the dataset and is therefore not directly comparable.
n_selected_labels = np.count_nonzero(cut & np.isin(event_labels, labels))
print(len(event_labels))       # all events in the dataset, every label
print(len(event_labels[cut]))  # events passing the cut, every label
print(len(all_indices))        # events assigned to the splits
print(n_unique)                # distinct events assigned to the splits
print("events of the selected labels passing the cut:", n_selected_labels)
# an index appearing twice would mean the same event sits in two splits (or twice in one)
assert n_unique == len(all_indices), "an event index appears more than once across the splits"

2000000
257207
127791
127791


## Save file

In [13]:
# Alternative outputs for other label selections / cuts (enable the one matching the options above):
#np.savez('/fast_scratch/WatChMaL/data/WCTE/index_lists/WCTE_e-_1M_mu-_1M_0to1.5GeV_1class_e-.npz', **split_idxs)
#np.savez('/fast_scratch/WatChMaL/data/WCTE/index_lists/WCTE_e-_1M_mu-_1M_0to1.5GeV_1class_mu-.npz', **split_idxs)
#np.savez('/fast_scratch/WatChMaL/data/WCTE/index_lists/WCTE_e-_1M_mu-_1M_0to1.5GeV_2class.npz', **split_idxs)
# Write one array per split ("test_idxs", "val_idxs", "train_idxs") into a single .npz archive.
index_file = '/fast_scratch/WatChMaL/data/WCTE/index_lists/WCTE_e-_1M_mu-_1M_0to1.5GeV_1class_e-_FC.npz'
np.savez(index_file, **split_idxs)