In [20]:
#hide
#skip
%config Completer.use_jedi = False
%config IPCompleter.greedy=True
# upgrade fastrl on colab
! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \
                     apt-get install -y xvfb python-opengl > /dev/null 2>&1 
# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS

In [21]:
# hide
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbverbose.showdoc import *
    from nbdev.imports import *
    if not os.environ.get("IN_TEST", None):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON
else:
    # Virutual display is needed for colab
    from pyvirtualdisplay import Display
    display = Display(visible=0, size=(400, 300))
    display.start()

In [22]:
# default_exp fastai.data.pipes

In [7]:
# export
# Python native modules
import os
from inspect import isfunction,ismethod
from typing import *
# Third party libs
from fastcore.all import *
from fastai.torch_basics import *
# from torch.utils.data.dataloader import DataLoader as OrgDataLoader
import torchdata.datapipes as dp
from torch.utils.data.dataloader_experimental import DataLoader2
from fastai.data.transforms import *
# Local modules
from fastrl.fastai.loop import *
from fastrl.fastai.data.load import *

# Basic DataPipes
> Basic datapipes for work with fastrl core API.

In [85]:
#For example, so not exported

from fastai.vision.core import *
from fastai.vision.data import *
from fastai.data.external import *

untar_data(URLs.MNIST_TINY)

Path('/home/fastrl_user/.fastai/data/mnist_tiny')

Load the mnist csv...

In [208]:
pipe = dp.iter.IterableWrapper([str(untar_data(URLs.MNIST_TINY)/'labels.csv')]) # FileOpener really should support Path as well as str
pipe = dp.iter.FileOpener(pipe, mode="b")
pipe = dp.iter.CSVParser(pipe,skip_lines=1)

class AddIdx():
    def __init__(self): self.idx=0
    def __call__(self,file):
        try:     return (self.idx,file)
        finally: self.idx+=1

base_pipe = dp.map.IterToMapConverter(pipe,key_value_fn=AddIdx())
pipe = dp.map.IterToMapConverter(pipe,key_value_fn=AddIdx())
pipe[5],len(base_pipe)

(['train/3/7745.png', '3'], 1408)

Now that we have the csv converted into a map, we want to split it into a training and validation dataset...

In [221]:
import warnings

from collections import deque
from typing import Any, Callable, Iterator, List, Optional, Set, Sized, Tuple, TypeVar, Deque

from torch.utils.data import IterDataPipe, functional_datapipe
from torch.utils.data.datapipes.utils.common import check_lambda_fn
from torch.utils.data._utils.serialization import serialize_fn, deserialize_fn


T_co = TypeVar("T_co", covariant=True)

from collections import deque
from collections.abc import Hashable


class _ChildDataPipe(dp.map.MapDataPipe):
    r"""
    Map Datapipe that is a child of a main DataPipe. The instance of this class
    will pass its instance_id to get the next value from its main DataPipe.

    Args:
        main_datapipe: Main DataPipe with a method 'get_next_element_by_instance(instance_id)'
        instance_id: integer identifier of this instance
    """
    # def __init__(self, main_datapipe, instance_id: int):
    def __init__(self, main_datapipe, instance_id: Hashable):
        
        required_attrs = ["get_next_element_by_instance", "is_instance_started", "is_every_instance_exhausted", "reset"]
        required_ops = [getattr(main_datapipe, attr) for attr in required_attrs]
        if any(not callable(op) for op in required_ops):
            raise NotImplementedError(f"Main Datapipe must have methods {required_attrs} implemented.")
        self.main_datapipe = main_datapipe
        self.instance_id = instance_id

    def __iter__(self):
        if self.main_datapipe.is_instance_started(self.instance_id):  # Only reset if the DataPipe started to read
            if not self.main_datapipe.is_every_instance_exhausted():
                warnings.warn("Some child DataPipes are not exhausted when __iter__ is called. We are resetting "
                              "the buffer and each child DataPipe will read from the start again.", UserWarning)
            self.main_datapipe.reset()
        # We want to separate the code for reset and yield, so that 'reset' exeutes before __next__ is called
        return self.get_generator_by_instance(self.instance_id)

    def __len__(self):
        return len(self.main_datapipe)

    # def get_generator_by_instance(self, instance_id: int):
    def get_generator_by_instance(self, instance_id: Hashable):
        yield from self.main_datapipe.get_next_element_by_instance(self.instance_id)


class _DemultiplexerMapDataPipe(dp.map.MapDataPipe):
    r"""
    Container to hold instance-specific information on behalf of _DemultiplexerMapDataPipe. It tracks
    the state of its child DataPipes, maintains the buffer, classifies and yields the next correct value
    as requested by the child DataPipes.
    """

    def __init__(self, datapipe: dp.map.MapDataPipe[T_co], num_instances: int,
                 classifier_fn: Callable[[T_co], Optional[int]], drop_none: bool, buffer_size: int):
        self.main_datapipe = datapipe
        self._datapipe_iterator: Optional[Iterator[Any]] = None
        self.num_instances = num_instances
        self.buffer_size = buffer_size
        if self.buffer_size < 0:
            warnings.warn(
                "Unlimited buffer size is set for `demux`, "
                "please be aware of OOM at random places",
                UserWarning
            )
        self.current_buffer_usage = 0
        self.child_buffers: List[Deque[T_co]] = [deque() for _ in range(num_instances)]
        self.instance_started: List[bool] = [False] * num_instances
        self.classifier_fn = classifier_fn
        self.drop_none = drop_none
        self.main_datapipe_exhausted = False

    def _find_next(self, instance_id: int) -> T_co:
        while True:
            if self.main_datapipe_exhausted:
                raise StopIteration
            if self._datapipe_iterator is None:
                raise ValueError(
                    "_datapipe_iterator has not been set, likely because this private method is called directly "
                    "without invoking get_next_element_by_instance() first.")
            value = next(self._datapipe_iterator)
            classification = self.classifier_fn(value)
            if classification is None and self.drop_none:
                continue
            if classification is None or classification >= self.num_instances or classification < 0:
                raise ValueError(f"Output of the classification fn should be between 0 and {self.num_instances - 1}. " +
                                 f"{classification} is returned.")
            if classification == instance_id:
                return value
            self.child_buffers[classification].append(value)
            self.current_buffer_usage += 1
            if self.buffer_size >= 0 and self.current_buffer_usage > self.buffer_size:
                raise BufferError(
                    f"DemultiplexerIterDataPipe buffer overflow, buffer size {self.buffer_size} is insufficient.")

    def get_next_element_by_instance(self, instance_id: Hashable):
    # def get_next_element_by_instance(self, instance_id: int):
        if self._datapipe_iterator is None and not self.main_datapipe_exhausted:
            self._datapipe_iterator = iter(self.main_datapipe)
        stop = False
        self.instance_started[instance_id] = True
        while not stop:
            if self.child_buffers[instance_id]:
                self.current_buffer_usage -= 1
                yield self.child_buffers[instance_id].popleft()
            else:
                try:
                    yield self._find_next(instance_id)
                except StopIteration:
                    stop = True
                    self.main_datapipe_exhausted = True
                    self._datapipe_iterator = None

    def is_instance_started(self, instance_id: int) -> bool:
        return self.instance_started[instance_id]

    def is_every_instance_exhausted(self) -> bool:
        return self.main_datapipe_exhausted and all(not child_buffer for child_buffer in self.child_buffers)

    def reset(self):
        self._datapipe_iterator = iter(self.main_datapipe)
        self.current_buffer_usage = 0
        self.child_buffers = [deque() for _ in range(self.num_instances)]
        self.instance_started = [False] * self.num_instances
        self.main_datapipe_exhausted = False

    def __getstate__(self):
        if IterDataPipe.getstate_hook is not None:
            return IterDataPipe.getstate_hook(self)

        serialized_fn_with_method = serialize_fn(self.classifier_fn)
        state = (
            self.main_datapipe,
            self.num_instances,
            self.buffer_size,
            serialized_fn_with_method,
            self.drop_none,
        )
        return state

    def __setstate__(self, state):
        (
            self.main_datapipe,
            self.num_instances,
            self.buffer_size,
            serialized_fn_with_method,
            self.drop_none,
        ) = state
        self.classifier_fn = deserialize_fn(serialized_fn_with_method)
        self._datapipe_iterator = None
        self.current_buffer_usage = 0
        self.child_buffers = [deque() for _ in range(self.num_instances)]
        self.instance_started = [False] * self.num_instances
        self.main_datapipe_exhausted = False
        

class DemultiplexerMapDataPipe(dp.map.MapDataPipe):
    r"""
    Splits the input DataPipe into multiple child DataPipes, using the given
    classification function (functional name: ``demux``). A list of the child DataPipes is returned from this operation.

    Args:
        datapipe: Iterable DataPipe being filtered
        num_instances: number of instances of the DataPipe to create
        classifier_fn: a function that maps values to an integer within the range ``[0, num_instances - 1]`` or ``None``
        drop_none: defaults to ``False``, if ``True``, the function will skip over elements classified as ``None``
        
        Josiah: buffer_size doesn't look like it should be needed(?)
        buffer_size: this defines the maximum number of inputs that the buffer can hold across all child
            DataPipes while waiting for their values to be yielded.
            Defaults to ``1000``. Use ``-1`` for the unlimited buffer.

    Examples:
        >>> from torchdata.datapipes.iter import IterableWrapper
        >>> def odd_or_even(n):
        ...     return n % 2
        >>> source_dp = IterableWrapper(range(5))
        >>> dp1, dp2 = source_dp.demux(num_instances=2, classifier_fn=odd_or_even)
        >>> list(dp1)
        [0, 2, 4]
        >>> list(dp2)
        [1, 3]
        >>> # It can also filter out any element that gets `None` from the `classifier_fn`
        >>> def odd_or_even_no_zero(n):
        ...     return n % 2 if n != 0 else None
        >>> dp1, dp2 = source_dp.demux(num_instances=2, classifier_fn=odd_or_even_no_zero, drop_none=True)
        >>> list(dp1)
        [2, 4]
        >>> list(dp2)
        [1, 3]
    """
    def __new__(cls, datapipe: dp.map.MapDataPipe, num_instances: int,
                classifier_fn: Callable[[T_co], Optional[int]], drop_none: bool = False, buffer_size: int = 1000):
        if num_instances < 1:
            raise ValueError(f"Expected `num_instaces` larger than 0, but {num_instances} is found")

        check_lambda_fn(classifier_fn)

        # When num_instances == 1, demux can be replaced by filter,
        # but keep it as Demultiplexer for the sake of consistency
        # like throwing Error when classification result is out of o range
        container = _DemultiplexerIterDataPipe(datapipe, num_instances, classifier_fn, drop_none, buffer_size)
        return [_ChildDataPipe(container, i) for i in range(num_instances)]

In [222]:
# It can also filter out any element that gets `None` from the `classifier_fn`

def train_valid_splitter(o): 
    print(o)
    return 'train'

dp1, dp2 = DemultiplexerMapDataPipe(pipe,num_instances=2, classifier_fn=train_valid_splitter, drop_none=True)
list(dp1)
[2, 4]
list(dp2)

['train/3/7463.png', '3']


TypeError: '>=' not supported between instances of 'str' and 'int'

In [212]:
train,valid = MultiplexerMapDataPipe(pipe)

assert 0 not in (len(train),len(valid))
assert len(train)+len(valid)==len(base_pipe)


pipe

ValueError: too many values to unpack (expected 2)

In [53]:
# hide
from fastcore.imports import in_colab

# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.export import *
    from nbdev.export2html import *
    from nbverbose.cli import *
    make_readme()
    notebook2script()
    notebook2html()

converting /home/fastrl_user/fastrl/nbs/index.ipynb to README.md
Converted 00_core.ipynb.
Converted 00_nbdev_extension.ipynb.
Converted 02_fastai.exception_test.ipynb.
Converted 02a_fastai.loop.ipynb.
Converted 02a_fastai.loop_initial.ipynb.
Converted 02b_fastai.data.load.ipynb.
Converted 02c_fastai.data.block.ipynb.
Converted 02c_fastai.data.pipes.ipynb.
Converted 03_callback.core.ipynb.
Converted 04_agent.ipynb.
Converted 05_data.test_async.ipynb.
Converted 05a_data.block.ipynb.
Converted 05b_data.gym.ipynb.
Converted 06a_memory.experience_replay.ipynb.
Converted 06f_memory.tensorboard.ipynb.
Converted 10a_agents.dqn.core.ipynb.
Converted 10b_agents.dqn.targets.ipynb.
Converted 10c_agents.dqn.double.ipynb.
Converted 10d_agents.dqn.dueling.ipynb.
Converted 10e_agents.dqn.categorical.ipynb.
Converted 11a_agents.policy_gradient.ppo.ipynb.
Converted 20_test_utils.ipynb.
Converted index.ipynb.
Converted nbdev_template.ipynb.
converting: /home/fastrl_user/fastrl/nbs/02a_fastai.loop.ipynb
c