# Dataset Index Generation
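Generates train, validation and test index splits for the IWCD mPMT short-tank e/mu/gamma dataset. Events flagged by `veto` are dropped, and the split is made by ROOT file so that no file contributes events to more than one split.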

In [1]:
import sys
import os
import h5py
from collections import Counter
from progressbar import *
import re
import numpy as np
from scipy import signal
import matplotlib
#from watchmal.testing.repeating_classifier_training_utils import *
from functools import reduce

# Add the parent directories to sys.path so project modules can be imported.
# Each entry is added only if absent, so re-running this cell (common with
# autoreload) does not keep growing sys.path with duplicates.
par_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))

for extra_path in (par_dir, "../..", ".."):
    if extra_path not in sys.path:
        sys.path.append(extra_path)


# IPython configuration: reload imported modules automatically before each
# cell executes (%autoreload 2), and render matplotlib figures inline.
%load_ext autoreload
%matplotlib inline
%autoreload 2

from IPython.display import HTML

In [2]:
original_data_path = "/fast_scratch/WatChMaL/data/IWCD_mPMT_Short_emg_E0to1000MeV_digihits.h5"

# Read only the labels. The context manager closes this handle; without it the
# handle leaks once `f` is rebound by the full load cell below.
with h5py.File(original_data_path, "r") as f:
    labels = np.array(f['labels'])

## Load dataset

In [3]:
# Load all events (every split, not only test) from the h5 file
original_data_path = "/fast_scratch/WatChMaL/data/IWCD_mPMT_Short_emg_E0to1000MeV_digihits.h5"
f = h5py.File(original_data_path, "r")

print(f.keys())

hdf5_hit_pmt    = f["hit_pmt"]
hdf5_hit_time   = f["hit_time"]
hdf5_hit_charge = f["hit_charge"]


def memmap_hdf5_dataset(file_path, dataset):
    """Return a read-only ``np.memmap`` over the raw bytes of an HDF5 dataset.

    Parameters
    ----------
    file_path : str
        Path of the HDF5 file that contains ``dataset``.
    dataset : h5py.Dataset
        Dataset to map; its shape and dtype are reused for the memmap.

    Raises
    ------
    ValueError
        If the dataset has no contiguous byte offset in the file.
        ``DatasetID.get_offset()`` returns None for chunked (e.g. compressed),
        compact or unallocated datasets, and such data cannot be memory-mapped
        directly. Failing here avoids passing None on to ``np.memmap``.
    """
    offset = dataset.id.get_offset()
    if offset is None:
        raise ValueError(
            f"Dataset {dataset.name!r} in {file_path} has no contiguous storage "
            "offset (chunked/compact/unallocated) and cannot be memory-mapped."
        )
    return np.memmap(file_path, mode="r", shape=dataset.shape,
                     offset=offset, dtype=dataset.dtype)


hit_pmt    = memmap_hdf5_dataset(original_data_path, hdf5_hit_pmt)
hit_time   = memmap_hdf5_dataset(original_data_path, hdf5_hit_time)
hit_charge = memmap_hdf5_dataset(original_data_path, hdf5_hit_charge)

angles     = np.array(f['angles'])
energies   = np.array(f['energies'])
positions  = np.array(f['positions'])
labels     = np.array(f['labels'])
root_files = np.array(f['root_files'])

veto = np.array(f['veto'])
veto2 = np.array(f['veto2'])

<KeysViewHDF5 ['angles', 'energies', 'event_hits_index', 'event_ids', 'hit_charge', 'hit_pmt', 'hit_time', 'labels', 'positions', 'root_files', 'veto', 'veto2']>


In [4]:
# One index per event in the file. np.arange builds this directly; it avoids
# np.array(range(...)), which converts a Python range element by element.
indices = np.arange(len(labels))

In [5]:
# Filter indices based on vetos. Only `veto` is applied here; the combined
# veto/veto2 variant is left commented out for reference.
#overall_veto = np.logical_or(veto, veto2)
overall_veto = veto

# Compute the keep-mask once. np.logical_not always yields a boolean mask,
# whereas bitwise np.invert would turn integer-typed 0/1 flags into -1/-2
# and silently switch to fancy indexing.
keep_mask = np.logical_not(overall_veto)

filtered_indices    = indices[keep_mask]
filtered_labels     = labels[keep_mask]
filtered_root_files = root_files[keep_mask]

In [6]:
# Smallest surviving event index. ndarray.min() is vectorised, whereas the
# builtin min() would iterate millions of numpy scalars in Python.
print(filtered_indices.min())

0


In [7]:
# Show which veto values occur among the first 50 events.
first_veto_flags = overall_veto[:50]
print(set(first_veto_flags))

{False, True}


In [8]:
# Map every ROOT file name to the list of its surviving (non-vetoed) event
# indices. Keys follow first-appearance order in `root_files`, and files whose
# events were all vetoed still get an (empty) entry.
file_dict = {source_name: [] for source_name in dict.fromkeys(root_files)}
print("Dict set")

for event_index, source_name in zip(filtered_indices, filtered_root_files):
    file_dict[source_name].append(event_index)
print("Done")

Dict set
Done


In [9]:
# Get files associated with each particle type.
# The *_root_file_set names hold ordered lists (first-appearance order), not sets.
def _files_for_label(label_code):
    """Return (surviving indices with this label, their unique ROOT files in order)."""
    label_indices = filtered_indices[filtered_labels == label_code]
    return label_indices, list(dict.fromkeys(root_files[label_indices]))


gamma_indices, gamma_root_file_set = _files_for_label(0)
e_indices, e_root_file_set = _files_for_label(1)
mu_indices, mu_root_file_set = _files_for_label(2)

for particle_files in (e_root_file_set, mu_root_file_set, gamma_root_file_set):
    print(len(particle_files))

3000
1000
3000


In [10]:
# Define indices retrieval function
def get_indices_for_files(file_names, file_index=None):
    """Collect the event indices that belong to the given ROOT files.

    Parameters
    ----------
    file_names : iterable
        ROOT file names (keys of ``file_index``), in the order their indices
        should appear in the result.
    file_index : dict, optional
        Mapping of file name -> list of event indices. Defaults to the
        notebook-level ``file_dict``.

    Returns
    -------
    numpy.ndarray
        1-D int64 array of the concatenated indices. It is integer-typed even
        when nothing is collected, so it can be concatenated with other index
        arrays and used for indexing.

    Raises
    ------
    KeyError
        If a name in ``file_names`` is not a key of ``file_index``.
    """
    if file_index is None:
        file_index = file_dict
    all_indices = []
    for file_name in file_names:
        all_indices.extend(file_index[file_name])
    # Without an explicit dtype an empty list becomes a float64 array. That
    # would upcast concatenated splits to float and break indexing with them.
    return np.array(all_indices, dtype=np.int64)
        

In [11]:
# Split the muon ROOT files by position in the first-appearance ordered list:
# files [0, 400) -> test, [400, 500) -> validation, [500, end) -> train.
# NOTE(review): the same absolute file counts are used for every particle type,
# so with 1000 muon files vs 3000 e/gamma files the test/val fractions differ
# between classes -- confirm this is intended.
test_end, val_end = 400, 500
mu_test_files  = mu_root_file_set[:test_end]
mu_val_files   = mu_root_file_set[test_end:val_end]
mu_train_files = mu_root_file_set[val_end:]

mu_test_set  = get_indices_for_files(mu_test_files)
mu_val_set   = get_indices_for_files(mu_val_files)
mu_train_set = get_indices_for_files(mu_train_files)

print(mu_test_set)

[5900828 5900829 5900831 ... 7064959 7064962 7064965]


In [12]:
# Split the gamma ROOT files by position in the first-appearance ordered list:
# files [0, 400) -> test, [400, 500) -> validation, [500, end) -> train.
test_end, val_end = 400, 500
gamma_test_files  = gamma_root_file_set[:test_end]
gamma_val_files   = gamma_root_file_set[test_end:val_end]
gamma_train_files = gamma_root_file_set[val_end:]

gamma_test_set  = get_indices_for_files(gamma_test_files)
gamma_val_set   = get_indices_for_files(gamma_val_files)
gamma_train_set = get_indices_for_files(gamma_train_files)

print(gamma_test_set)

[2944634 2944635 2944637 ... 4127008 4127009 4127010]


In [13]:
# Split the electron ROOT files by position in the first-appearance ordered list:
# files [0, 400) -> test, [400, 500) -> validation, [500, end) -> train.
test_end, val_end = 400, 500
e_test_files  = e_root_file_set[:test_end]
e_val_files   = e_root_file_set[test_end:val_end]
e_train_files = e_root_file_set[val_end:]

e_test_set  = get_indices_for_files(e_test_files)
e_val_set   = get_indices_for_files(e_val_files)
e_train_set = get_indices_for_files(e_train_files)

print(e_test_set)

[      0       1       2 ... 1177963 1177964 1177965]


In [14]:
# First and last ROOT files assigned to the electron test split.
for boundary_file in (e_test_files[0], e_test_files[-1]):
    print(boundary_file)

b'/localscratch/prouse.56905527.0/WCSim/e-/E0to1000MeV/unif-pos-R400-y300cm/4pi-dir/IWCD_mPMT_Short_e-_E0to1000MeV_unif-pos-R400-y300cm_4pi-dir_3000evts_0.root'
b'/localscratch/prouse.56922977.0/WCSim/e-/E0to1000MeV/unif-pos-R400-y300cm/4pi-dir/IWCD_mPMT_Short_e-_E0to1000MeV_unif-pos-R400-y300cm_4pi-dir_3000evts_399.root'


In [15]:
# Number of training files per particle type (mu, gamma, e).
for train_files in (mu_train_files, gamma_train_files, e_train_files):
    print(len(train_files))

500
2500
2500


In [16]:
# Verify that each particle's splits contain only that particle's label
# (label codes as used when selecting files: gamma = 0, e = 1, mu = 2).
all_e_indices = np.concatenate((e_test_set, e_val_set, e_train_set))
all_gamma_indices = np.concatenate((gamma_test_set, gamma_val_set, gamma_train_set))
all_mu_indices = np.concatenate((mu_test_set, mu_val_set, mu_train_set))

for split_indices, expected_label in ((all_e_indices, 1), (all_gamma_indices, 0), (all_mu_indices, 2)):
    # np.unique is vectorised; converting its few values to a set keeps the
    # printed form (e.g. {1}) without building a set of millions of scalars.
    found_labels = set(np.unique(labels[split_indices]).tolist())
    print(found_labels)
    assert found_labels == {expected_label}, \
        f"expected only label {expected_label}, found {found_labels}"

{1}
{0}
{2}


In [17]:
# Verify that every non-vetoed event is assigned to exactly one split.
# len(labels) includes vetoed events, so it is printed for reference only.
all_collected_indices = np.concatenate((e_test_set, e_val_set, e_train_set, gamma_test_set, gamma_val_set, gamma_train_set, mu_test_set, mu_val_set, mu_train_set))

print(len(labels))
print(len(all_collected_indices))
# np.unique avoids building a Python set of millions of numpy scalars.
print(len(np.unique(all_collected_indices)))

# filtered_indices is increasing (a boolean mask applied to an increasing
# range). Equality with the sorted collection therefore rules out both
# duplicates and unassigned events.
assert np.array_equal(np.sort(all_collected_indices), filtered_indices), \
    "collected split indices do not exactly match the non-vetoed events"

20613195
17415143
17415143


## 3-class split (e / mu / gamma)

In [18]:
# Combine the per-particle splits into 3-class (e, mu, gamma) index arrays.
train_idxs, val_idxs, test_idxs = (
    np.concatenate(split_parts)
    for split_parts in (
        (e_train_set, mu_train_set, gamma_train_set),
        (e_val_set, mu_val_set, gamma_val_set),
        (e_test_set, mu_test_set, gamma_test_set),
    )
)

In [19]:
# Save the 3-class split indices.
# NOTE: train_idxs/val_idxs/test_idxs are reassigned by the 2-class cell below,
# so re-run the cell above before re-running this one.
three_class_path = './short_dataset_data/IWCD_mPMT_Short_3_class_emg_9M_OD_veto_idxs.npz'
# np.savez does not create missing directories.
os.makedirs(os.path.dirname(three_class_path), exist_ok=True)
np.savez(three_class_path, train_idxs=train_idxs, val_idxs=val_idxs, test_idxs=test_idxs)

## 2-class split (e / gamma)

In [20]:
# Combine the electron and gamma splits into 2-class index arrays.
split_pairs = {
    "train": (e_train_set, gamma_train_set),
    "val": (e_val_set, gamma_val_set),
    "test": (e_test_set, gamma_test_set),
}
train_idxs = np.concatenate(split_pairs["train"])
val_idxs   = np.concatenate(split_pairs["val"])
test_idxs  = np.concatenate(split_pairs["test"])

In [21]:
# Save the 2-class (e / gamma) split indices built by the cell above.
two_class_path = './short_dataset_data/IWCD_mPMT_Short_2_class_eg_9M_OD_veto_idxs.npz'
# np.savez does not create missing directories.
os.makedirs(os.path.dirname(two_class_path), exist_ok=True)
np.savez(two_class_path, train_idxs=train_idxs, val_idxs=val_idxs, test_idxs=test_idxs)