In [2]:
import jammy_flows
import h5py
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import math

In [23]:
class HDF5PhotonTable(torch.utils.data.IterableDataset):
    """Iterable dataset that streams photon tables out of an HDF5 file.

    Every member of the file's ``photon_tables`` group is one record. Members
    are addressed by their position in the *lexicographically* sorted list of
    member names (so ``dataset_10`` sorts before ``dataset_2``), and the
    half-open index range ``[start, end)`` is iterated.

    Each step yields a tuple ``(attrs, data)``: ``attrs`` is a plain ``dict``
    copy of the member's HDF5 attributes and ``data`` is the member's full
    contents as returned by ``member[:]``.

    When iterated inside a ``DataLoader`` worker, the index range is split
    into contiguous chunks of ``ceil((end - start) / num_workers)`` indices and
    each worker only yields its own chunk.

    Args:
        filename: Path of the HDF5 file; it is opened read-only.
        start: Index of the first member to yield.
        end: Index one past the last member to yield; clamped to the number of
            members in ``photon_tables``.

    Raises:
        AssertionError: If ``end`` is not strictly greater than ``start``.
    """

    def __init__(self, filename, start, end):
        super().__init__()
        assert end > start, "end must be greater than start"
        # Open explicitly read-only so that constructing the dataset can never
        # create or modify the file, whatever the installed h5py's default is.
        # NOTE(review): the handle is opened in the constructing process and is
        # then used from DataLoader workers; h5py handles are generally not
        # safe to share across processes -- confirm for num_workers > 0.
        self._hdl = h5py.File(filename, "r")
        self.size = len(self._hdl["photon_tables"])
        self.start = start
        self.end = min(self.size, end)

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # Single-process data loading: iterate the whole configured range.
            iter_start, iter_end = self.start, self.end
        else:
            # Worker process: take this worker's contiguous share of the range.
            per_worker = int(
                math.ceil((self.end - self.start) / float(worker_info.num_workers))
            )
            iter_start = self.start + worker_info.id * per_worker
            iter_end = min(iter_start + per_worker, self.end)

        # Resolve the group once instead of on every yielded item.
        tables = self._hdl["photon_tables"]
        # Sorting fixes the index -> name mapping so every worker agrees on it.
        names = sorted(tables.keys())

        for index in range(iter_start, iter_end):
            table = tables[names[index]]
            yield dict(table.attrs), table[:]

    def __del__(self):
        # __init__ may have failed before the handle was assigned.
        if hasattr(self, "_hdl"):
            self._hdl.close()
        
        

In [36]:
# Smoke test: build a dataset over member indices 2..9 of the photon table
# file and print only the first (attrs, data) pair it yields.
for d in HDF5PhotonTable("../assets/photon_table.hd5", 2, 10):
    print(d)
    break  # one record is enough to inspect the yielded structure

({'dir_phi': 1.7847153148493689, 'dir_theta': 2.369524989194152, 'distance': 140.7464599609375, 'energy': 261.867183066023, 'pos_phi': 0.8196239488360239, 'pos_theta': 0.472755575708884}, array([], shape=(2, 0), dtype=float32))


In [39]:
# Open the file read-only and print the repr of one member of the
# "photon_tables" group; the handle is closed when the with-block exits.
with h5py.File('../assets/photon_table.hd5', 'r') as f:
    photon_tables = f["photon_tables"]
    print(photon_tables["dataset_4"])


<HDF5 dataset "dataset_4": shape (2, 0), type "<f4">


ValueError: Invalid location identifier (invalid location identifier)

<KeysViewHDF5 ['dataset_1', 'dataset_10', 'dataset_11', 'dataset_12', 'dataset_2', 'dataset_3', 'dataset_4', 'dataset_5', 'dataset_6', 'dataset_7', 'dataset_8', 'dataset_9']>

In [6]:
# NOTE(review): a handle named `f` bound by a `with h5py.File(...)` block is
# already closed when that block exits. This cell's execution count (In [6])
# is lower than the cells above, so it was presumably run earlier against a
# different handle -- confirm whether it is still needed.
f.close()