In [None]:
#|hide
# Notebook setup helper from fastrl's test utilities; presumably prepares the
# environment for running this notebook (TODO confirm what it configures).
from fastrl.test_utils import initialize_notebook
initialize_notebook()

In [None]:
#|default_exp dataloading.core

In [None]:
#|export
# Python native modules
from typing import Tuple,Union,List
# Third party libs
import torchdata.datapipes as dp
from torchdata.dataloader2 import MultiProcessingReadingService,DataLoader2
from fastcore.all import delegates
# Local modules

# Dataloading Core
> Basic utilities for creating dataloaders from RL datapipes.

In [None]:
#|export
@delegates(MultiProcessingReadingService)
def dataloaders(
    # A single iterable datapipe, or a tuple/list of them, to generate dataloaders from.
    pipes:Union[Tuple[dp.iter.IterDataPipe,...],List[dp.iter.IterDataPipe],dp.iter.IterDataPipe],
    # Concat the dataloaders together. Takes precedence over `do_multiplex`.
    do_concat:bool = False,
    # Multiplex the dataloaders (yield one item from each in turn). Ignored if `do_concat` is True.
    do_multiplex:bool = False,
    # Number of workers the dataloaders should run in. 0 means no `MultiProcessingReadingService`.
    num_workers: int = 0,
    # Forwarded to `MultiProcessingReadingService`; only used when `num_workers > 0`.
    **kwargs
) -> Union[dp.iter.IterDataPipe,List[dp.iter.IterDataPipe]]:
    """Function that creates dataloaders based on `pipes` with different ways of combining them.

    Each pipe is wrapped in its own `DataLoader2`, which is then wrapped in an
    `IterableWrapper` so it can be composed like any other `IterDataPipe`.

    Returns:
        A `Concater` over all dataloaders if `do_concat`, otherwise a `Multiplexer`
        if `do_multiplex`, otherwise the list of wrapped dataloaders (one per pipe).
    """
    # Accept a bare datapipe as well as any tuple/list of datapipes; a list used to be
    # mistaken for a single datapipe and handed straight to `DataLoader2`.
    if not isinstance(pipes,(tuple,list)):
        pipes = (pipes,)

    dls = []
    for pipe in pipes:
        # A fresh reading service per DataLoader2; `None` means no worker processes.
        reading_service = (
            MultiProcessingReadingService(num_workers=num_workers, **kwargs)
            if num_workers > 0 else None
        )
        dl = DataLoader2(datapipe=pipe, reading_service=reading_service)
        # Wrap so the result is an IterDataPipe (composable with Concater/Multiplexer);
        # deepcopy=False keeps using this DataLoader2 object rather than a deep copy of it.
        dls.append(dp.iter.IterableWrapper(dl,deepcopy=False))

    #TODO(josiahls): Not sure if this is needed tbh.. Might be better to just
    # return dls, and have the user wrap them if they want. Then they can do more complex stuff.
    if do_concat:
        return dp.iter.Concater(*dls)
    if do_multiplex:
        return dp.iter.Multiplexer(*dls)
    return dls


In [None]:
from fastcore.test import test_eq

In [None]:
# Sample data: two small, independent source pipes
source_a = dp.iter.IterableWrapper([1, 2, 3])
source_b = dp.iter.IterableWrapper([4, 5, 6])

# A bare IterDataPipe is treated as a 1-tuple, giving a one-element list of wrapped dataloaders
single_dls = dataloaders(source_a)
test_eq(len(single_dls), 1)
assert isinstance(single_dls[0], dp.iter.IterableWrapper)
test_eq(list(single_dls[0]), [1, 2, 3])

# A tuple of pipes with no combining flags gives one dataloader per pipe, in order
pair_dls = dataloaders((source_a, source_b))
test_eq(len(pair_dls), 2)
for wrapped_dl, expected_items in zip(pair_dls, ([1, 2, 3], [4, 5, 6])):
    test_eq(list(wrapped_dl), expected_items)

# Each combining flag returns a single datapipe of the matching type and ordering
for flag_name, expected_type, expected_items in (
    ("do_concat", dp.iter.Concater, [1, 2, 3, 4, 5, 6]),
    ("do_multiplex", dp.iter.Multiplexer, [1, 4, 2, 5, 3, 6]),
):
    combined_dl = dataloaders((source_a, source_b), **{flag_name: True})
    assert isinstance(combined_dl, expected_type)
    test_eq(list(combined_dl), expected_items)

In [None]:
#|hide
#|eval: false
# Export this notebook's `#|export` cells into the library module named by
# `#|default_exp` above; not evaluated when the notebook is run as a test.
!nbdev_export