# In [3]:
import bisect
import os
from itertools import accumulate
from typing import Any, Callable, Sequence, List

import h5py
import torch
from pytorch_lightning import LightningDataModule, cli
from pytorch_lightning.cli import SaveConfigCallback
from torch.utils.data import DataLoader, Dataset

# In [4]:
class H5PYDataset(Dataset):
    """Map-style dataset that concatenates the HDF5 files found in a directory.

    Every matching file contributes ``file[key].shape[0]`` samples. The files
    are ordered by sorted path and indexed as one contiguous sequence. A
    global index is mapped to ``(file, local index)`` through a prefix-sum
    table of the per-file lengths.

    Args:
        path: Directory to scan (non-recursively) for HDF5 files.
        key: Name of the HDF5 dataset read from each file. Defaults to
            ``'data'``.
        extensions: Filename suffix(es) that mark a file as HDF5. This may be
            a single string or a sequence of strings. Defaults to
            ``('.hdf5',)``.

    Raises:
        ValueError: If ``path`` is not a directory.
    """

    def __init__(
        self,
        path: str,
        *,
        key: str = 'data',
        extensions: Sequence[str] = ('.hdf5',),
    ):
        if not os.path.isdir(path):
            raise ValueError("Path should be a directory")

        # Accept a bare string such as '.h5'. Calling tuple() on a str would
        # split it into single characters.
        if isinstance(extensions, str):
            extensions = (extensions,)
        suffixes = tuple(extensions)

        # Only regular files qualify. A sub-directory whose name ends in a
        # matching suffix would otherwise make h5py.File fail below.
        self.paths = sorted(
            os.path.join(path, name)
            for name in os.listdir(path)
            if name.endswith(suffixes) and os.path.isfile(os.path.join(path, name))
        )
        # Must be set before _get_file_length, which reads it.
        self.key = key
        self.lengths = [self._get_file_length(file_path) for file_path in self.paths]
        # cumulative_lengths[i] is the global index of the first sample in
        # file i. The last entry is the total sample count.
        self.cumulative_lengths = self._compute_cumulative_lengths(self.lengths)

    def _get_file_length(self, path):
        """Return the size of the first axis of ``self.key`` in the file at ``path``."""
        with h5py.File(path, 'r') as file:
            return file[self.key].shape[0]

    def _compute_cumulative_lengths(self, lengths):
        """Return the prefix sums of ``lengths`` with a leading 0.

        The result has ``len(lengths) + 1`` entries.
        For example, ``[5, 0, 5]`` becomes ``[0, 5, 5, 10]``.
        """
        return list(accumulate(lengths, initial=0))

    def __len__(self):
        """Return the total number of samples across all files."""
        return self.cumulative_lengths[-1]

    def _load_data(self, path, local_index):
        """Return ``file[self.key][local_index]`` from the HDF5 file at ``path``.

        The file is opened and closed on every call, and no handle is kept
        between calls.

        NOTE(review): this costs one file open per sample. It is presumably
        intentional, to avoid sharing h5py handles across DataLoader worker
        processes. Confirm before caching handles.
        """
        with h5py.File(path, 'r') as file:
            return file[self.key][local_index]

    def __getitem__(self, global_index: int):
        """Return the sample at ``global_index`` in the concatenated sequence.

        Raises:
            IndexError: If ``global_index`` is negative or ``>= len(self)``.
                Unlike ``list``, negative indices are not wrapped around.
        """
        total = len(self)
        if global_index < 0 or global_index >= total:
            raise IndexError(f"Index {global_index} out of bounds for dataset of length {total}")

        file_index = self._find_file_index(global_index)
        local_index = global_index - self.cumulative_lengths[file_index]
        return self._load_data(self.paths[file_index], local_index)

    def _find_file_index(self, global_index):
        """Return the index of the file that holds ``global_index``.

        ``bisect_right`` finds the last prefix-sum entry ``<= global_index``.
        When some files are empty, their prefix sums repeat. This choice then
        skips those empty files and lands on the file that actually contains
        the sample. The caller must check bounds first.
        """
        return bisect.bisect_right(self.cumulative_lengths, global_index) - 1

    def teardown(self):
        """No-op hook.

        No HDF5 handles are held between calls, because each access opens
        and closes its own file, so there is nothing to release.
        """
        pass