# data

> Samplers, a dataset base class, and a data loader for iterating over data in mini-batches.

In [None]:
#| default_exp data

In [None]:
#| export
import itertools
import multiprocessing as mp
import random
from typing import (
    Iterable,
    Callable,
    Iterator,
    Optional,
    List
)

import minima as mi
from minima import Tensor
from minima import init

In [None]:
#| export
class Sampler:
    """
    Yield the positions ``0 .. len(ds) - 1``, optionally in random order.

    Only the length of ``ds`` is used; its elements are never read.

    Args:
        ds (Iterable[int]): Collection to sample positions from. Must support ``len()``.
        shuffle (bool): Whether to yield the positions in a newly shuffled order
            on every iteration.

    Example:
        >>> x = range(10)
        >>> sampler = Sampler(x, shuffle=True)
        >>> sorted(sampler)
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    """

    def __init__(self, ds: Iterable[int], shuffle: bool = False):
        self.n = len(ds)
        self.shuffle = shuffle

    def __iter__(self) -> Iterator[int]:
        """Return an iterator over a fresh list of the ``n`` positions."""
        indices = list(range(self.n))
        if self.shuffle:
            # Uses the module-level `random` import; shuffles in place, so each
            # call to `__iter__` produces an independent permutation.
            random.shuffle(indices)
        return iter(indices)

In [None]:
#| export
class BatchSampler:
    """
    Group the indices produced by a sampler into lists of ``bs`` indices.

    Args:
        sampler (Iterable[int]): Source of indices, typically a ``Sampler``.
        bs (int): Batch size; must be at least 1.
        drop_last (bool): Whether to drop the last batch if it is smaller than the batch size.

    Raises:
        ValueError: If ``bs`` is smaller than 1.

    Example:
        >>> x = range(10)
        >>> batch_sampler = BatchSampler(x, bs=3)
        >>> list(batch_sampler)
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
    """

    def __init__(self, sampler: Iterable[int], bs: int, drop_last: bool = False):
        # A batch size of zero would make the batching loop below spin forever
        # on empty batches, so reject it up front.
        if bs < 1:
            raise ValueError(f"bs must be a positive integer, got {bs!r}")
        self.sampler = sampler
        self.bs = bs
        self.drop_last = drop_last

    def __iter__(self) -> Iterator[List[int]]:
        """Yield successive lists of at most ``bs`` indices from the sampler."""
        indices = iter(self.sampler)
        while True:
            batch = list(itertools.islice(indices, self.bs))
            if len(batch) == self.bs:
                yield batch
                continue
            # Short (possibly empty) batch: the sampler is exhausted.
            if batch and not self.drop_last:
                yield batch
            return


In [None]:
#| export
class Dataset:
    r"""Abstract base class for a :class:`Dataset`.

    A dataset maps keys to data samples. Subclasses are expected to override:
    `__getitem__`, to fetch the data sample stored under a given key.
    `__len__`, to report how many samples the dataset holds.

    Args:
        transforms (Optional[List]): Callables applied in order by
            `apply_transforms`; ``None`` means no transformation.
    """

    def __init__(self, transforms: Optional[List]=None):
        self.transforms = transforms

    def __getitem__(self, index) -> object:
        """
        Fetch the sample stored at `index`.

        The base class has no data, so this always raises; subclasses
        provide the real lookup.

        Args:
            index: Key of the requested sample.

        Raises:
            NotImplementedError: Always, in the base class.

        Example:
            >>> dataset[0]
            (1, 0)
        """
        raise NotImplementedError

    def __len__(self) -> int:
        """
        Report the number of samples in the dataset.

        The base class has no data, so this always raises; subclasses
        provide the real size.

        Raises:
            NotImplementedError: Always, in the base class.

        Example:
            >>> len(dataset)
            5
        """
        raise NotImplementedError

    def apply_transforms(self, x):
        """
        Run `x` through every configured transform, in order.

        Args:
            x: Value to transform.

        Returns:
            The output of the last transform, or `x` unchanged when
            `transforms` is ``None``.
        """
        if self.transforms is None:
            return x
        for transform in self.transforms:
            x = transform(x)
        return x

In [None]:
#|export
class DataLoader:
    """
    Iterate over a dataset in batches.

    Each item yielded is ``dataset[batch_idxs]``, where ``batch_idxs`` is one
    element produced by the batch sampler. The dataset is therefore indexed
    with a whole batch of indices at once.

    Args:
        dataset (Dataset): The dataset to load from.
        batch_size (int): Number of indices per batch for the default batch sampler.
        shuffle (bool): Whether the default sampler shuffles the indices.
        sampler (Optional[Sampler]): Source of indices; defaults to
            ``Sampler(dataset, shuffle)``.
        batch_sampler (Optional[BatchSampler]): Source of index batches; defaults to
            ``BatchSampler(self.sampler, batch_size, drop_last)``.
        num_workers (int): When non-zero, batches are fetched through a
            ``multiprocessing.Pool`` with this many worker processes.
        collate_fn (Optional[Callable]): If given, called on every fetched batch;
            its return value is yielded in place of the batch.
        drop_last (bool): Whether the default batch sampler drops a trailing
            batch that is smaller than ``batch_size``.

    Example:
        >>> dataloader = DataLoader(dataset, batch_size)
    """

    def __init__(self,
                 dataset: "Dataset",
                 batch_size: int = 1,
                 shuffle: bool = True,
                 sampler: Optional["Sampler"] = None,
                 batch_sampler: Optional["BatchSampler"] = None,
                 num_workers: int = 0,
                 collate_fn: Optional[Callable] = None,
                 drop_last: bool = False):

        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        # Compare against None rather than truthiness so that a caller-supplied
        # sampler that happens to be falsy (e.g. an empty one) is still used.
        if sampler is None:
            sampler = Sampler(dataset, shuffle)
        self.sampler = sampler
        if batch_sampler is None:
            batch_sampler = BatchSampler(self.sampler, batch_size, drop_last)
        self.batch_sampler = batch_sampler
        self.num_workers = num_workers # --> TODO: implement a multiprocessing DataLoader :3
        self.collate_fn = collate_fn
        self.drop_last = drop_last

    def __iter__(self):
        """
        Get an iterator over the DataLoader.

        Yields:
            One batch per element of the batch sampler: ``dataset[batch_idxs]``,
            passed through ``collate_fn`` when one was supplied.

        Example:
            >>> for batch in dataloader:
            >>>     # Process the batch
        """
        if self.num_workers:
            # Pool.map blocks until every batch has been fetched, so the pool
            # can be closed before the first batch is handed to the caller.
            with mp.Pool(self.num_workers) as pool:
                batches = pool.map(self.dataset.__getitem__, iter(self.batch_sampler))
        else:
            batches = (self.dataset[batch_idxs] for batch_idxs in self.batch_sampler)

        if self.collate_fn is None:
            yield from batches
        else:
            yield from map(self.collate_fn, batches)



In [None]:
# Toy data for trying out the classes above; per the output below,
# X has shape (100, 10) and Y has shape (100,).
X = init.rand(100, 10)
# One entry in Y per row of X.
Y = init.randb(X.shape[0])
X.shape, Y.shape

((100, 10), (100,))

In [None]:
class MiDataset(Dataset):
    """
    Dataset that pairs inputs `x` with targets `y`.

    Args:
        x: Inputs; must expose `.shape` and support indexing.
        y: Targets; indexed with the same key as `x`.
    """

    def __init__(self, x, y):
        # Initialise the base class so that `self.transforms` exists and the
        # inherited `apply_transforms` works on instances of this class.
        super().__init__()
        self.x = x
        self.y = y

    def __len__(self) -> int:
        """Return the size of the first axis of `x`."""
        return self.x.shape[0]

    def __getitem__(self, i):
        """Return `(x[i], y[i])`; `i` may be any key that `x` and `y` accept (int, slice, ...)."""
        return self.x[i], self.y[i]

In [None]:
# Wrap the toy tensors in a dataset so they can be indexed together.
ds = MiDataset(X,Y)

In [None]:
# The dataset length is the size of the first axis of X (see MiDataset.__len__).
len(ds)

100

In [None]:
# Slicing is forwarded to both tensors: this returns (X[:10], Y[:10]).
ds[:10]

(minima.Tensor(
 [[0.205512 0.403219 0.692859 0.422495 0.808965 0.576947 0.827043 0.11447  0.082906 0.779861]
  [0.984203 0.686364 0.590332 0.926567 0.376217 0.443898 0.855632 0.818763 0.118218 0.344021]
  [0.413944 0.907886 0.245893 0.007645 0.628516 0.072028 0.258648 0.087356 0.93394  0.076115]
  [0.914312 0.807569 0.319946 0.971284 0.876449 0.825572 0.808259 0.91054  0.071735 0.478746]
  [0.1809   0.188892 0.158353 0.247531 0.989222 0.685215 0.409875 0.906303 0.546848 0.734692]
  [0.363155 0.297981 0.47858  0.133156 0.017859 0.573805 0.880012 0.344674 0.614881 0.933287]
  [0.052164 0.1493   0.30457  0.640711 0.983575 0.842516 0.65946  0.562504 0.762918 0.95061 ]
  [0.660865 0.254197 0.342602 0.618704 0.828643 0.294252 0.425804 0.518007 0.031176 0.9347  ]
  [0.946777 0.235436 0.353485 0.611394 0.109561 0.332844 0.851154 0.676482 0.761387 0.110494]
  [0.265457 0.044131 0.432185 0.443281 0.548712 0.367395 0.36302  0.545991 0.666848 0.653104]]),
 minima.Tensor(
 [ True  True  True  True

## Export

In [None]:
# Run nbdev's export step (presumably writes the `#| export` cells to the
# `data` module named by `default_exp` at the top of this notebook).
import nbdev; nbdev.nbdev_export()