In [3]:
#hide
#skip
# Turn off jedi-based completion and enable greedy completion for this session.
%config Completer.use_jedi = False
%config IPCompleter.greedy=True
# Upgrade fastrl on colab: the install chain below only runs when /content exists,
# and the apt-get output is discarded.
! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \
                     apt-get install -y xvfb python-opengl > /dev/null 2>&1 
# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS

In [4]:
# hide
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.showdoc import *
    from nbdev.imports import *
    # NOTE(review): `os` is used here before this notebook's own `import os` cell runs;
    # presumably it is provided by the star imports above — confirm.
    # The environment sanity checks only run when IN_TEST is unset or empty.
    if not os.environ.get("IN_TEST", None):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON
else:
    # Virtual display is needed for colab
    from pyvirtualdisplay import Display
    # NOTE(review): this rebinds the name `display` in the notebook namespace.
    display = Display(visible=0, size=(400, 300))
    display.start()

In [5]:
# default_exp fastai.data.pipes.mux

In [6]:
# export
# Python native modules
import os
from inspect import isfunction,ismethod
from typing import *
# Third party libs
from fastcore.all import *
from fastai.torch_basics import *
# from torch.utils.data.dataloader import DataLoader as OrgDataLoader
import torchdata.datapipes as dp
from torch.utils.data.dataloader_experimental import DataLoader2
from fastai.data.transforms import *
# Local modules
from fastrl.fastai.loop import *
from fastrl.fastai.data.load import *

# Basic DataPipes - Multiplexer
> Basic datapipes for working with the fastrl core API.

In [7]:
# For example purposes only, so not exported

from fastai.vision.core import *
from fastai.vision.data import *
from fastai.data.external import *

untar_data(URLs.MNIST_TINY),URLs.MNIST_TINY

(Path('/home/fastrl_user/.fastai/data/mnist_tiny'),
 'https://s3.amazonaws.com/fast-ai-sample/mnist_tiny.tgz')

Load the mnist csv...

In [8]:
# Build the source pipe: wrap the path to labels.csv, open the file, then parse it as CSV.
pipe = dp.iter.IterableWrapper([str(untar_data(URLs.MNIST_TINY)/'labels.csv')]) # FileOpener really should support Path as well as str
pipe = dp.iter.FileOpener(pipe, mode="b")
# skip_lines=1 presumably drops the CSV header row — TODO confirm
pipe = dp.iter.CSVParser(pipe,skip_lines=1)

class AddIdx():
    "Callable that pairs each item with a running integer index, starting at 0."
    def __init__(self):
        self.idx = 0

    def __call__(self, file):
        "Return `(index, file)` for the current index, then advance the counter by one."
        current = self.idx
        self.idx = current + 1
        return (current, file)

# Convert the iterable pipe into map-style pipes, using AddIdx as the key/value function.
# Each converter gets its own AddIdx instance, so each counter starts at 0.
# NOTE(review): `base_pipe` is only referenced by the commented-out line below in the
# cells visible here — confirm whether it is still needed.
base_pipe = dp.map.IterToMapConverter(pipe,key_value_fn=AddIdx())
pipe = dp.map.IterToMapConverter(pipe,key_value_fn=AddIdx())
# pipe[5],len(base_pipe)

In [9]:
len(pipe)

  "Data from prior DataPipe are loaded to get length of"


1408

Now that we have the csv converted into a map, we want to split it into a training and validation dataset...

In [10]:
from fastrl.fastai.data.pipes.demux import *

In [11]:
def train_valid_splitter(o):
    "Classify a record by the first component of the path stored in `o[0]`."
    # e.g. a row whose first field is 'train/3/7463.png' is classified as 'train'.
    path_parts = Path(o[0]).parts
    return path_parts[0]

# Demux `pipe` into two child pipes, using `train_valid_splitter` as the classifier.
# NOTE(review): presumably the children come back in `instance_keys` order
# (dp1 -> 'train', dp2 -> 'valid') — confirm against DemultiplexerMapDataPipe.
dp1, dp2 = DemultiplexerMapDataPipe(pipe,instance_keys=['train','valid'], classifier_fn=train_valid_splitter, drop_none=True)

In [12]:
assert len(dp1)+len(dp2)==len(pipe),f"The demux'd dp1 and dp2 when added together should be the same len as pipe {len(dp1)} + {len(dp2)} = {len(pipe)}"

  "Data from prior DataPipe are loaded to get length of"


In [13]:
dp1[0],dp2[1000]

(['train/3/7463.png', '3'], ['valid/3/9614.png', '3'])

In [14]:
list(dp1)[:5],list(dp2)[:5]

([['train/3/7463.png', '3'],
  ['train/3/9829.png', '3'],
  ['train/3/7881.png', '3'],
  ['train/3/8065.png', '3'],
  ['train/3/7046.png', '3']],
 [['valid/3/8430.png', '3'],
  ['valid/3/7946.png', '3'],
  ['valid/3/933.png', '3'],
  ['valid/3/9308.png', '3'],
  ['valid/3/795.png', '3']])

In [15]:
class KSplitter():
    "Cycling classifier: each call returns the current counter, which wraps to 0 on reaching `k_splits`."
    k=0
    def __init__(self,k_splits=2):
        self.k_splits = k_splits

    def __call__(self,*args):
        "Return the current split number (ignoring `args`), then step to the next one."
        current = self.k
        following = current + 1
        # Wrap around once the counter reaches `k_splits`.
        self.k = 0 if following == self.k_splits else following
        return current

# Split dp1 into three child pipes, classifying elements with a KSplitter that cycles 0, 1, 2.
k1,k2,k3 = DemultiplexerMapDataPipe(dp1,instance_keys=[0,1,2], classifier_fn=KSplitter(k_splits=3), drop_none=True)

In [16]:
# export
# NOTE(review): this call builds the `functional_datapipe('mux')` decorator object but
# never applies it (there is no leading `@`), so this line does not register a `.mux`
# functional form for the class below — confirm whether that is intended.
dp.functional_datapipe('mux')
class MultiplexerMapDataPipe(dp.map.MapDataPipe):
    """Map-style datapipe that interleaves several `MapDataPipe`s.

    Iteration takes one element from each input datapipe in turn, skipping inputs
    that are already exhausted, until every input is exhausted. Each index that is
    yielded is recorded in `self._map` together with the position of the datapipe
    it came from, so `__getitem__` can route an index to the right input.

    NOTE: `_map` is a plain dict keyed by index, so if two inputs share an index
    the most recently iterated one overwrites the earlier entry.

    Args:
        *datapipes: the `MapDataPipe`s to combine.

    Raises:
        ValueError: if any input is not a `dp.map.MapDataPipe`.
    """
    def __init__(self, *datapipes):
        self.datapipes = datapipes
        # The torchdata datapipes module is imported as `dp`, so the loop variable
        # is named `_dp` to avoid shadowing it.
        for _dp in self.datapipes:
            if not isinstance(_dp,dp.map.MapDataPipe):
                dp_types=[type(o) for o in self.datapipes]
                raise ValueError(f'Passed in datapipes need to be MapDataPipes, got {dp_types}')

        # Cached total length: None = not computed yet, -1 = no valid length.
        self.length: Optional[int] = None
        # index -> position in `self.datapipes` of the datapipe that produced it.
        self._map = {}

    def _setup_datapipe_indexer(self, datapipe) -> Iterator[Any]:
        """Return an iterator over the indexes that `datapipe` can be queried with.

        The lookup order is: `SequenceWrapper` (indexed by position), an object
        with a `_map` attribute, an object with an `index_map` attribute, and
        finally a positional fallback that emits a warning.
        """
        # NOTE: THIS IS NOT A GOOD SOLUTION SINCE THIS CANT RELY ON A STANDARD
        # INTERFACE FOR GETTING INDEXES

        # We cache the indexes because we want to be able to have consistent behavior
        # when calling __getitem__ on a child pipe.
        # What we don't want is the main_datapipe being indexed by `str` but the
        # child pipes indexing by `int`...
        if isinstance(datapipe, dp.map.SequenceWrapper):
            return iter(range(len(datapipe)))
        elif hasattr(datapipe, '_map'):
            return iter(datapipe._map)
        elif hasattr(datapipe, 'index_map'):
            return iter(datapipe.index_map)
        else:
            warnings.warn('data pipe will be indexed by len')
            return iter(range(len(datapipe)))

    def __iter__(self):
        """Yield values from the input datapipes in turn until all are exhausted.

        As a side effect, every yielded index is recorded in `self._map`.
        """
        iterators = [self._setup_datapipe_indexer(x) for x in self.datapipes]
        finished: Set[int] = set()
        while len(finished) < len(iterators):
            for i in range(len(iterators)):
                if i in finished: continue
                # Only the index iterator running dry means "this datapipe is done".
                # The lookup and the yield stay outside the `try` so that an error
                # raised by a child datapipe is not mistaken for exhaustion.
                try:
                    index = next(iterators[i])
                except StopIteration:
                    finished.add(i)
                    continue
                value = self.datapipes[i][index]
                # self._map will track which index is associated with
                # which datapipe...
                self._map[index] = i
                yield value

    def __getitem__(self, index) -> T_co:
        """Return the value stored under `index` in whichever input datapipe holds it.

        Raises:
            IndexError: if a full iteration never produces `index`.
        """
        if index in self._map:
            # self._map[index] -> the datapipe to getitem at, then pass index to
            # get the value
            return self.datapipes[self._map[index]][index]

        # Remember that iter(self) adds index to self._map. So as we iter,
        # we can check if index has been found in a datapipe, and once found, return
        # that value.
        for value in self:
            if index in self._map: return value

        raise IndexError(f'Unable to find {index} in the datapipes')

    def __len__(self):
        """Return the summed length of the input datapipes (computed once, then cached).

        Raises:
            TypeError: if any input datapipe is not `Sized`.
        """
        if self.length is None:
            if all(isinstance(_dp, Sized) for _dp in self.datapipes):
                self.length = sum(len(_dp) for _dp in self.datapipes)
            else:
                self.length = -1
        if self.length == -1:
            raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
        return self.length


In [17]:
len(k1),len(k2),len(k3)

(237, 236, 236)

In [18]:
combined_pipes=MultiplexerMapDataPipe(k1,k2,k3)

In [19]:
list(combined_pipes)[::100]

[['train/3/7463.png', '3'],
 ['train/7/954.png', '7'],
 ['train/7/8302.png', '7'],
 ['train/3/8014.png', '3'],
 ['train/3/8447.png', '3'],
 ['train/3/7225.png', '3'],
 ['train/7/773.png', '7'],
 ['train/7/8571.png', '7']]

In [20]:
len(combined_pipes)

709

In [21]:
combined_pipes[5]

['train/3/7745.png', '3']

In [22]:
list(combined_pipes)[::100]

[['train/3/7463.png', '3'],
 ['train/7/954.png', '7'],
 ['train/7/8302.png', '7'],
 ['train/3/8014.png', '3'],
 ['train/3/8447.png', '3'],
 ['train/3/7225.png', '3'],
 ['train/7/773.png', '7'],
 ['train/7/8571.png', '7']]

In [24]:
# hide
from fastcore.imports import in_colab

# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.export import *
    from nbdev.export2html import *
    from nbdev.cli import *
    # Regenerate README.md from index.ipynb, then convert the project's notebooks
    # (the captured output of this cell lists each converted notebook).
    make_readme()
    notebook2script()

converting /home/fastrl_user/fastrl/nbs/index.ipynb to README.md
Converted 00_core.ipynb.
Converted 00_nbdev_extension.ipynb.
Converted 02_fastai.exception_test.ipynb.
Converted 02a_fastai.loop.ipynb.
Converted 02a_fastai.loop_initial.ipynb.
Converted 02b_fastai.data.load.ipynb.
Converted 02c_fastai.data.block.ipynb.
Converted 02c_fastai.data.pipes.demux.ipynb.
Converted 02c_fastai.data.pipes.mux.ipynb.
Converted 03_callback.core.ipynb.
Converted 04_agent.ipynb.
Converted 05_data.test_async.ipynb.
Converted 05a_data.block.ipynb.
Converted 05b_data.gym.ipynb.
Converted 06a_memory.experience_replay.ipynb.
Converted 06f_memory.tensorboard.ipynb.
Converted 10a_agents.dqn.core.ipynb.
Converted 10b_agents.dqn.targets.ipynb.
Converted 10c_agents.dqn.double.ipynb.
Converted 10d_agents.dqn.dueling.ipynb.
Converted 10e_agents.dqn.categorical.ipynb.
Converted 11a_agents.policy_gradient.ppo.ipynb.
Converted 20_test_utils.ipynb.
Converted index.ipynb.
Converted nbdev_template.ipynb.
