In [8]:
import torch
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import h5py

# PyTorch Dataset for HDF5 files, based on the lazy-loading approach described at https://vict0rs.ch/2021/06/15/pytorch-h5/


class hdf5Dataset(Dataset):
    """Map-style dataset that reads (x, y) pairs from two datasets in one HDF5 file.

    The file is opened lazily, on the first item access, rather than in
    ``__init__``. The intent is that each DataLoader worker process opens its
    own handle instead of sharing one created in the parent process.

    Args:
        h5_path: Path of the HDF5 file to read.
        x_name: Key of the dataset holding the inputs.
        y_name: Key of the dataset holding the targets.
    """

    def __init__(self, h5_path, x_name, y_name):
        super().__init__()
        self.h5_path = h5_path
        self.x_name = x_name
        self.y_name = y_name
        # Long-lived read handle; created on first use by the `data` property.
        self._data = None
        # Number of samples, cached after the first `len()` call.
        self._length = None

    @property
    def data(self):
        """Return the open read-only file handle, opening it on first access."""
        if self._data is None:
            self._data = h5py.File(self.h5_path, "r")
        return self._data

    def __getitem__(self, index):
        """Return ``(x[index], y[index])`` read from the two configured datasets."""
        handle = self.data
        return handle[self.x_name][index], handle[self.y_name][index]

    def __len__(self):
        """Return the length of the ``x_name`` dataset.

        The value is cached until ``close()`` is called. When no handle is
        open yet, the length is read through a short-lived handle that is
        closed immediately, so calling ``len()`` in the parent process (as
        DataLoader samplers do) does not leave a file handle open there.
        """
        if self._length is None:
            if self._data is not None:
                self._length = len(self._data[self.x_name])
            else:
                with h5py.File(self.h5_path, "r") as hf:
                    self._length = len(hf[self.x_name])
        return self._length

    def close(self):
        """Close the file handle if one is open and forget the cached length.

        Safe to call repeatedly; the handle is reopened on the next item access.
        """
        if self._data is not None:
            self._data.close()
            self._data = None
        self._length = None

    def __getstate__(self):
        """Return the instance state without the open file handle.

        h5py file objects cannot be pickled, so the handle is dropped from the
        copied state; the unpickled dataset reopens the file lazily.
        """
        state = self.__dict__.copy()
        state["_data"] = None
        return state

In [3]:
# Build a small random fixture file to exercise the dataset class.
N = 500
x_name, y_name = "train_x", "train_y"
x_npy = np.random.rand(N, 3, 256, 256)
y_npy = np.random.rand(N, 5)
with h5py.File("mytestfile.hdf5", "a") as hf:
    # Each array is written only when its key is not already in the file,
    # so re-running the cell leaves existing datasets untouched.
    for name, values in ((x_name, x_npy), (y_name, y_npy)):
        if name in hf.keys():
            continue
        # First axis is resizable (maxshape None); the other axes are fixed.
        dset = hf.create_dataset(
            name,
            values.shape,
            maxshape=(None,) + values.shape[1:],
            dtype='f',
            chunks=True,
        )
        # Fill the last len(values) rows, which here is the whole dataset.
        dset[-values.shape[0]:] = values
        

In [9]:
# Smoke test: pull one batch through a DataLoader and show the batch shapes.
loader = DataLoader(
    hdf5Dataset("mytestfile.hdf5", "train_x", "train_y"),
    batch_size=5,
)
batches = iter(loader)
images, labels = next(batches)
print(images.shape, labels.shape, sep="\n")



torch.Size([5, 3, 256, 256])
torch.Size([5, 5])
