In [1]:
from torch.utils.data import Dataset
import h5py
from pathlib import Path
from typing import Union
from tqdm.notebook import tqdm

In [2]:
class TrackMLDataset(Dataset):
    """Map-style PyTorch dataset reading one TrackML event per index from HDF5.

    The HDF5 file must provide a root attribute ``number_of_events`` and four
    groups -- ``hits``, ``cells``, ``particles`` and ``truth``. Each group
    holds the columns of all events concatenated into flat datasets, plus
    ``event_offset`` / ``event_length`` datasets giving, per event, the start
    row and the row count of that event inside the flat columns.

    Indexing returns a dict with keys ``'hits'``, ``'cells'``, ``'particles'``
    and ``'truth'``. Each value is a dict of column name -> rows of that event,
    or ``None`` when the corresponding ``return_*`` flag is False.

    NOTE(review): the file is opened once in ``__init__`` and shared by all
    reads; h5py handles opened before a ``DataLoader`` forks its workers are a
    known source of trouble -- TODO confirm behaviour with ``num_workers > 0``.
    """

    # Columns read from each group, in the order they appear in the output.
    _HIT_FIELDS = ('hit_id', 'x', 'y', 'z', 'volume_id', 'layer_id', 'module_id')
    _CELL_FIELDS = ('hit_id', 'ch0', 'ch1', 'value')
    _PARTICLE_FIELDS = (
        'particle_id', 'particle_type', 'vx', 'vy', 'vz',
        'px', 'py', 'pz', 'q', 'nhits',
    )
    _TRUTH_FIELDS = (
        'hit_id', 'particle_id', 'tx', 'ty', 'tz',
        'tpx', 'tpy', 'tpz', 'weight',
    )

    def __init__(
        self,
        file: Union[str, Path],
        return_hits: bool = True,
        return_cells: bool = True,
        return_particles: bool = True,
        return_truth: bool = True,
    ):
        """Open ``file`` read-only and bind its four top-level groups.

        Args:
            file: Path of the HDF5 file to read.
            return_hits, return_cells, return_particles, return_truth:
                When False, the corresponding entry of every returned event
                is ``None`` and that group's column datasets are never read.
        """
        self.file = h5py.File(file, 'r')
        self.number_of_events = self.file.attrs['number_of_events']

        self.return_hits = return_hits
        self.return_cells = return_cells
        self.return_particles = return_particles
        self.return_truth = return_truth

        self.hits = self.file['hits']
        self.cells = self.file['cells']
        self.particles = self.file['particles']
        self.truth = self.file['truth']

        # Per-group cache of (event_offset array, event_length array,
        # {column name: dataset handle}), filled lazily by _read_event.
        self._event_index_cache = {}

    def close(self):
        """Close the underlying HDF5 file and drop cached dataset handles.

        Tolerates an instance whose ``__init__`` failed before ``self.file``
        was assigned.
        """
        self.__dict__.pop('_event_index_cache', None)
        file = getattr(self, 'file', None)
        if file is not None:
            file.close()

    def __del__(self):
        # Routed through close(): if h5py.File() raised in __init__,
        # self.file does not exist and an unconditional self.file.close()
        # would raise AttributeError during garbage collection.
        self.close()

    def __len__(self):
        # The attribute may come back from h5py as a NumPy integer; return a
        # plain int.
        return int(self.number_of_events)

    def __getitem__(self, idx: int):
        """Return all enabled tables of event ``idx`` (negative indices allowed).

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        n_events = len(self)
        # Explicit bounds check: without it, an instance with every return_*
        # flag False would return a dict of Nones for any index, so Python's
        # sequence iteration protocol would never stop.
        if not -n_events <= idx < n_events:
            raise IndexError(f'event index {idx} out of range for {n_events} events')
        if idx < 0:
            idx += n_events
        return {
            'hits': self._get_hits(idx),
            'cells': self._get_cells(idx),
            'particles': self._get_particles(idx),
            'truth': self._get_truth(idx),
        }

    def _read_event(self, group_name: str, idx: int, fields):
        """Slice the rows of event ``idx`` out of each column in ``fields``.

        ``group_name`` is the attribute holding the group (``'hits'``,
        ``'cells'``, ...). The group's ``event_offset`` / ``event_length``
        datasets are read in full once and kept in memory, and column dataset
        handles are looked up once. Each event then costs one read per column,
        instead of two index reads plus a name lookup per column.
        """
        group = getattr(self, group_name)
        # setdefault keeps this working after close() or on instances built
        # without __init__ (e.g. via __new__ in tests or subclasses).
        cache = self.__dict__.setdefault('_event_index_cache', {})
        entry = cache.get(group_name)
        if entry is None:
            entry = (group['event_offset'][:], group['event_length'][:], {})
            cache[group_name] = entry
        offsets, lengths, columns = entry
        # Python ints, so offset + length cannot overflow a narrow stored dtype.
        start = int(offsets[idx])
        stop = start + int(lengths[idx])
        event = {}
        for field in fields:
            column = columns.get(field)
            if column is None:
                column = columns[field] = group[field]
            event[field] = column[start:stop]
        return event

    def _get_hits(self, idx: int):
        """Return the ``hits`` columns of event ``idx``, or None if disabled."""
        if not self.return_hits:
            return None
        return self._read_event('hits', idx, self._HIT_FIELDS)

    def _get_cells(self, idx: int):
        """Return the ``cells`` columns of event ``idx``, or None if disabled."""
        if not self.return_cells:
            return None
        return self._read_event('cells', idx, self._CELL_FIELDS)

    def _get_particles(self, idx: int):
        """Return the ``particles`` columns of event ``idx``, or None if disabled."""
        if not self.return_particles:
            return None
        return self._read_event('particles', idx, self._PARTICLE_FIELDS)

    def _get_truth(self, idx: int):
        """Return the ``truth`` columns of event ``idx``, or None if disabled."""
        if not self.return_truth:
            return None
        return self._read_event('truth', idx, self._TRUTH_FIELDS)

In [3]:
# Location of the converted TrackML HDF5 file, relative to this notebook.
TRACKML_H5_PATH = Path('../TrackML.hdf5')
dset = TrackMLDataset(TRACKML_H5_PATH)

In [4]:
# Inspect the structure of the first event (dict of per-table column arrays).
first_event = dset[0]
first_event

{'hits': {'hit_id': array([     1,      2,      3, ..., 105311, 105312, 105313], dtype=int32),
  'x': array([ -72.7191,  -33.8991,  -61.3116, ..., -812.388 , -773.    ,
         -971.92  ], dtype=float32),
  'y': array([ -7.75438 ,  -1.94067 ,   0.566296, ..., 106.28    ,  71.169   ,
          50.9279  ], dtype=float32),
  'z': array([-1502.5, -1502.5, -1502.5, ...,  2944.5,  2944.5,  2952.5],
        dtype=float32),
  'volume_id': array([ 7,  7,  7, ..., 18, 18, 18], dtype=int8),
  'layer_id': array([ 2,  2,  2, ..., 12, 12, 12], dtype=int8),
  'module_id': array([ 1,  1,  1, ..., 97, 97, 98], dtype=int16)},
 'cells': {'hit_id': array([     1,      1,      2, ..., 104848, 104849, 104849], dtype=int32),
  'ch0': array([ 129,  129,  146, ..., 1010,  358,  357], dtype=int16),
  'ch1': array([1171, 1170,  906, ...,   14,    2,    2], dtype=int16),
  'value': array([0.0432567, 0.17901  , 0.235661 , ..., 1.       , 1.       ,
         1.       ], dtype=float32)},
 'particles': {'particle_id

In [5]:
%%time
for i in tqdm(range(len(dset))):
    dset[i]

  0%|          | 0/8743 [00:00<?, ?it/s]

CPU times: user 2min 39s, sys: 41.1 s, total: 3min 20s
Wall time: 5min 11s
