In [4]:
%load_ext jupyter_black
from __future__ import annotations

The jupyter_black extension is already loaded. To reload it, use:
  %reload_ext jupyter_black


In [5]:
import abc
import bisect
import itertools

from typing import Generic, TypeVar, Iterable, Iterator, Final

from mesoscaler.generic import NamedAndSized
from mesoscaler.utils import acc_size

_T = TypeVar("_T")

In [6]:
class Dataset(NamedAndSized, Generic[_T], abc.ABC):
    """Abstract base for datasets whose samples are fetched by integer index.

    Subclasses must implement ``__getitem__``. Two datasets can be joined
    with ``+``, which wraps both operands in a ``ConcatDataset``.
    """

    @abc.abstractmethod
    def __getitem__(self, index: int) -> _T:
        """Return the sample stored at position ``index``."""

    def __add__(self, other: Dataset[_T]) -> ConcatDataset[_T]:
        """Return a ``ConcatDataset`` holding ``self`` followed by ``other``."""
        operands = [self, other]
        return ConcatDataset(operands)


class ConcatDataset(Dataset[_T]):
    """Several index-addressable datasets exposed through one index space.

    ``accumulated_sizes`` is built from ``acc_size`` (defined outside this
    file). The lookup below treats it as an ascending list of running totals,
    one entry per member dataset - TODO confirm against ``mesoscaler.utils``.
    """

    def __init__(self, data: Iterable[Dataset[_T]]) -> None:
        """Store the member datasets and pre-compute their accumulated sizes.

        Raises:
            ValueError: if ``data`` is empty, or if any member is an
                ``IterableDataset``.
        """
        super().__init__()
        members = list(data)
        self.data = members
        if len(members) == 0:
            raise ValueError("datasets should not be an empty iterable")
        if any(isinstance(member, IterableDataset) for member in members):
            raise ValueError("ConcatDataset does not support IterableDataset")

        self.accumulated_sizes = list(acc_size(members))

    def __len__(self) -> int:
        """Return the last accumulated size, used as the total length."""
        boundaries = self.accumulated_sizes
        return boundaries[-1]

    def __getitem__(self, idx: int) -> _T:
        """Return the sample at ``idx``, counting across all member datasets.

        Negative indices count back from the end.

        Raises:
            ValueError: if a negative ``idx`` reaches past the first sample.
        """
        if idx < 0:
            total = len(self)
            if -idx > total:
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = total + idx

        boundaries = self.accumulated_sizes
        # Position of the first boundary strictly greater than idx, i.e. the
        # member dataset that owns this index.
        ds_idx = bisect.bisect_right(boundaries, idx)
        # Every member except the first starts where the previous one ended.
        offset = boundaries[ds_idx - 1] if ds_idx > 0 else 0

        return self.data[ds_idx][idx - offset]


# =====================================================================================================================
class IterableDataset(NamedAndSized, Iterable[_T], abc.ABC):
    """Abstract base for datasets that are consumed by iteration.

    Subclasses must implement ``__iter__``. Two iterable datasets can be
    joined with ``+``, which wraps both operands in a ``ChainDataset``.
    """

    @abc.abstractmethod
    def __iter__(self) -> Iterator[_T]:
        """Return an iterator over the samples of this dataset."""

    def __add__(self, other: IterableDataset[_T]) -> ChainDataset[_T]:
        """Return a ``ChainDataset`` holding ``self`` followed by ``other``."""
        operands = [self, other]
        return ChainDataset(operands)


class ChainDataset(IterableDataset[_T]):
    """Several iterable datasets exposed as one, iterated back to back."""

    def __init__(self, datasets: Iterable[IterableDataset[_T]]) -> None:
        """Store the member datasets.

        Args:
            datasets: the iterable datasets to chain, in iteration order.
                Any iterable is accepted, including a one-shot generator.
        """
        super().__init__()
        # Materialise the input once. Both __len__ and __iter__ walk
        # ``self.data``, so keeping a one-shot iterator here would leave the
        # chain empty after the first len() call or the first pass over it.
        self.data: Final[list[IterableDataset[_T]]] = list(datasets)

    def __iter__(self) -> Iterator[_T]:
        """Yield every sample of every member dataset, in order."""
        return itertools.chain.from_iterable(self.data)

    def __len__(self) -> int:
        """Return the sum of the member datasets' lengths."""
        return sum(map(len, self.data))