In [3]:
#hide
#skip
# Tab-completion settings: disable jedi and turn on greedy completion.
%config Completer.use_jedi = False
%config IPCompleter.greedy=True
# upgrade fastrl on colab
# (`/content` presumably exists only on Colab, so the install chain is skipped elsewhere; apt-get output is discarded)
! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \
                     apt-get install -y xvfb python-opengl > /dev/null 2>&1 
# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS

In [4]:
# hide
import os  # used below; do not rely on it leaking out of a star import
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.showdoc import *
    from nbdev.imports import *
    # Outside of test runs (IN_TEST unset), sanity-check that we are in a local Jupyter notebook.
    # IN_NOTEBOOK / IN_COLAB / IN_IPYTHON presumably come from the nbdev star imports above.
    if not os.environ.get("IN_TEST", None):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON
else:
    # A virtual (invisible) display is needed on Colab.
    from pyvirtualdisplay import Display
    # Named `virtual_display` so it does not shadow IPython's built-in `display()` helper.
    virtual_display = Display(visible=0, size=(400, 300))
    virtual_display.start()

In [5]:
# default_exp fastai.data.pipes.demux

In [6]:
# export
# Python native modules
import os
from inspect import isfunction,ismethod
from typing import *
# Third party libs
from fastcore.all import *
from fastai.torch_basics import *
# from torch.utils.data.dataloader import DataLoader as OrgDataLoader
import torchdata.datapipes as dp
from torch.utils.data.dataloader_experimental import DataLoader2
from fastai.data.transforms import *
# Local modules
from fastrl.fastai.loop import *
from fastrl.fastai.data.load import *

# Basic DataPipes
> Basic datapipes for working with the fastrl core API.

In [7]:
# For example only: this cell has no `# export` flag, so nbdev does not export it.

from fastai.vision.core import *
from fastai.vision.data import *
from fastai.data.external import *

# Local path of the extracted MNIST-tiny dataset (see output) alongside its download URL.
untar_data(URLs.MNIST_TINY),URLs.MNIST_TINY

(Path('/home/fastrl_user/.fastai/data/mnist_tiny'),
 'https://s3.amazonaws.com/fast-ai-sample/mnist_tiny.tgz')

Load the MNIST-tiny `labels.csv` file into a DataPipe...

In [8]:
# Iterable pipe over the rows of MNIST-tiny's labels.csv:
# path string -> opened file stream (binary mode) -> parsed CSV rows, skipping the first line (presumably the header).
pipe = dp.iter.IterableWrapper([str(untar_data(URLs.MNIST_TINY)/'labels.csv')]) # FileOpener really should support Path as well as str
pipe = dp.iter.FileOpener(pipe, mode="b")
pipe = dp.iter.CSVParser(pipe,skip_lines=1)

class AddIdx:
    "Stateful `key_value_fn` that pairs each incoming item with a running integer key: 0, 1, 2, ..."
    def __init__(self): self.idx = 0

    def __call__(self, file):
        # Hand out the current key, then advance the counter for the next item.
        key = self.idx
        self.idx = key + 1
        return key, file

# Turn the row stream into a map keyed 0..N-1; each converter gets its own fresh AddIdx counter.
# NOTE(review): `base_pipe` is only referenced in the commented-out line below; looks like leftover scratch — confirm before removing.
base_pipe = dp.map.IterToMapConverter(pipe,key_value_fn=AddIdx())
pipe = dp.map.IterToMapConverter(pipe,key_value_fn=AddIdx())
# pipe[5],len(base_pipe)

In [9]:
# Taking len() makes the converter load the prior DataPipe's data, as the warning printed below notes.
len(pipe)

  "Data from prior DataPipe are loaded to get length of"


1408

Now that the CSV has been converted into a map, we want to split it into training and validation datasets...

In [12]:
# Scratch check: `next` on a fresh iterator returns its first element (output below shows 1).
# NOTE(review): unrelated to the demux pipeline; consider removing this cell.
iterator=iter({1,2,3,4})
next(iterator)

1

In [14]:
# export

from typing import Callable, Dict, Iterable, Optional, TypeVar

from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import MapDataPipe

from torch.utils.data.datapipes.utils.common import check_lambda_fn

T_co = TypeVar("T_co", covariant=True)



@functional_datapipe("demux")
class DemultiplexerMapDataPipe:
    """Split one map-style DataPipe into `num_instances` child map-style DataPipes.

    `classifier_fn` receives each item of `datapipe` and returns the index of the child that
    should own it; with `drop_none=True`, items classified as `None` are skipped.
    `source_index` optionally lists the keys of `datapipe` to visit, in order
    (when omitted, `range(len(datapipe))` is used).

    Calling the class returns a *list* of child pipes sharing one lazily filled classification
    cache, not a `DemultiplexerMapDataPipe` instance.
    """
    # NOTE(review): this class does not subclass MapDataPipe, so `functional_datapipe` may not
    # register a `.demux` method on map pipes — confirm against the installed torch version.
    def __new__(cls, datapipe: MapDataPipe, num_instances: int, classifier_fn: Callable, drop_none: bool = False,
                source_index: Optional[Iterable] = None):
        if num_instances < 1:
            raise ValueError(f"Expected `num_instances` larger than 0, but {num_instances} is found")
        check_lambda_fn(classifier_fn)
        shared = _DemultiplexerMapDataPipe(datapipe, num_instances, classifier_fn, drop_none, source_index)
        children = []
        for instance_id in range(num_instances):
            children.append(_DemultiplexerChildMapDataPipe(shared, instance_id))
        return children


class _DemultiplexerMapDataPipe:
    """Shared, lazily filled classification cache behind the children of `DemultiplexerMapDataPipe`.

    Items of `main_datapipe` are visited in `source_index` order (default `range(len(main_datapipe))`),
    passed to `classifier_fn`, and appended to the bucket of the child they belong to. Items are only
    classified on demand, so reading the first element of one child does not force a full pass.
    """
    def __init__(
        self,
        datapipe: MapDataPipe[T_co],
        num_instances: int,
        classifier_fn: Callable[[T_co], Optional[int]],
        drop_none: bool,
        source_index: Optional[Iterable],
    ):
        self.main_datapipe = datapipe
        self.num_instances = num_instances
        self.classifier_fn = classifier_fn
        self.drop_none = drop_none
        self.iterator = None  # iterator over `source_index`, created on the first classification
        self.exhausted = False  # Once we iterate through `main_datapipe` once, we know all the index mapping
        # One bucket per child; bucket `i` caches the *values* (not source keys) classified as `i`, in visit order.
        self.index_mapping = [[] for _ in range(num_instances)]
        self.source_index = source_index  # if None, assume `main_datapipe` 0-index

    def _classify_next(self):
        """Classify source items until one is stored in a bucket or the source runs out.

        Sets `exhausted` once every source key has been visited.

        Raises:
            ValueError: if `classifier_fn` returns `None` while `drop_none` is False, or a
                value outside `[0, num_instances)`.
        """
        if self.source_index is None:
            self.source_index = range(len(self.main_datapipe))
        if self.iterator is None:
            self.iterator = iter(self.source_index)
        # Loop (rather than recurse) past dropped items so long runs of `None` cannot hit the recursion limit.
        for next_source_idx in self.iterator:
            value = self.main_datapipe[next_source_idx]
            classification = self.classifier_fn(value)
            if classification is None and self.drop_none:
                continue
            # Reject None and out-of-range labels explicitly; a negative label would otherwise be
            # silently routed to a child counted from the end of `index_mapping`.
            if classification is None or not 0 <= classification < self.num_instances:
                raise ValueError(
                    f"Output of the classification fn should be between 0 and {self.num_instances - 1}. "
                    f"{classification} is returned."
                )
            self.index_mapping[classification].append(value)
            return
        self.exhausted = True

    def classify_all(self):
        "Classify every remaining source item, so every bucket holds its final contents."
        while not self.exhausted:
            self._classify_next()

    def get_value(self, instance_id: int, index: int) -> T_co:
        """Return the `index`-th value classified into child `instance_id`, classifying lazily as needed.

        A negative `index` counts from the end of the child's *final* contents, which forces a full pass.

        Raises:
            RuntimeError: if `index` is out of range for that child.
        """
        bucket = self.index_mapping[instance_id]
        if index < 0:
            # A position relative to the end is only meaningful once the child's final length is known.
            self.classify_all()
            index += len(bucket)
            if index < 0:
                raise RuntimeError("Index is out of bound.")
        while not self.exhausted and len(bucket) <= index:
            self._classify_next()
        if len(bucket) > index:
            return bucket[index]
        raise RuntimeError("Index is out of bound.")

    def __len__(self):
        "Length of the source pipe (not of any single child)."
        return len(self.main_datapipe)


class _DemultiplexerChildMapDataPipe(MapDataPipe):
    """One output branch of a demultiplexed map pipe.

    Holds no data itself: indexing, `len` and iteration are all answered by the shared
    `_DemultiplexerMapDataPipe` container, restricted to bucket `instance_id`.
    """
    def __init__(self, main_datapipe: _DemultiplexerMapDataPipe, instance_id: int):
        self.main_datapipe: _DemultiplexerMapDataPipe = main_datapipe
        self.instance_id = instance_id

    def __getitem__(self, index: int):
        container = self.main_datapipe
        return container.get_value(self.instance_id, index)

    def __len__(self):
        container = self.main_datapipe
        # A branch's size is only known after every source item has been classified.
        container.classify_all()
        bucket = container.index_mapping[self.instance_id]
        return len(bucket)

    def __iter__(self):
        # Generator: the length (and hence the full classification pass) is computed on the first `next`.
        total = len(self)
        position = 0
        while position < total:
            yield self[position]
            position += 1

In [19]:
def train_valid_splitter(o):
    "Classify a `[path, label]` CSV row by its top-level folder: 'train' -> 0, 'valid' -> 1, anything else -> None."
    # Returning None (instead of raising KeyError) lets `drop_none=True` skip rows outside the two splits.
    split_ids = {'train': 0, 'valid': 1}
    parts = Path(o[0]).parts
    return split_ids.get(parts[0]) if parts else None

# Split `pipe` into a train child (label 0) and a valid child (label 1); rows classified as None are dropped.
dp1, dp2 = DemultiplexerMapDataPipe(pipe,num_instances=2, classifier_fn=train_valid_splitter, drop_none=True)

In [20]:
# Sanity check: no row was dropped or duplicated by the split.
assert len(dp1)+len(dp2)==len(pipe),f"The demux'd dp1 and dp2 when added together should be the same len as pipe {len(dp1)} + {len(dp2)} = {len(pipe)}"

In [25]:
# Random access into each child: the 61st train row and the 61st valid row.
dp1[60],dp2[60]

(['train/3/9680.png', '3'], ['valid/3/9219.png', '3'])

In [26]:
# Iterating a child yields its rows in classification order; show the first five of each.
list(dp1)[:5],list(dp2)[:5]

([['train/3/7463.png', '3'],
  ['train/3/9829.png', '3'],
  ['train/3/7881.png', '3'],
  ['train/3/8065.png', '3'],
  ['train/3/7046.png', '3']],
 [['valid/3/8430.png', '3'],
  ['valid/3/7946.png', '3'],
  ['valid/3/933.png', '3'],
  ['valid/3/9308.png', '3'],
  ['valid/3/795.png', '3']])

In [30]:
class KSplitter():
    "Round-robin classifier: returns 0, 1, ..., k_splits-1, 0, 1, ... on successive calls, ignoring its arguments."
    k=0  # class-level default; each instance gets its own `k` on its first call
    def __init__(self,k_splits=2):
        # k_splits < 1 would make the counter never wrap (0) or wrap to invalid labels (negative).
        if k_splits < 1: raise ValueError(f"`k_splits` must be at least 1, but {k_splits} was given")
        self.k_splits=k_splits
    def __call__(self,*args):
        current = self.k
        # Modulo keeps the counter in [0, k_splits), so it always produces a valid child index.
        self.k = (current + 1) % self.k_splits
        return current

# Split the train child round-robin into three folds (KSplitter never returns None, so every row lands in a fold).
k1,k2,k3 = DemultiplexerMapDataPipe(dp1,num_instances=3, classifier_fn=KSplitter(k_splits=3), drop_none=True)

In [31]:
# Round-robin assignment keeps fold sizes within one of each other.
len(k1),len(k2),len(k3)

(237, 236, 236)

In [33]:
# hide
from fastcore.imports import in_colab

# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.export import *
    from nbdev.export2html import *
    from nbdev.cli import *
    make_readme()
    notebook2script()

converting /home/fastrl_user/fastrl/nbs/index.ipynb to README.md
Converted 00_core.ipynb.
Converted 00_nbdev_extension.ipynb.
Converted 02_fastai.exception_test.ipynb.
Converted 02a_fastai.loop.ipynb.
Converted 02a_fastai.loop_initial.ipynb.
Converted 02b_fastai.data.load.ipynb.
Converted 02c_fastai.data.block.ipynb.
Converted 02c_fastai.data.pipes.demux.ipynb.
Converted 02c_fastai.data.pipes.mux.ipynb.
Converted 03_callback.core.ipynb.
Converted 04_agent.ipynb.
Converted 05_data.test_async.ipynb.
Converted 05a_data.block.ipynb.
Converted 05b_data.gym.ipynb.
Converted 06a_memory.experience_replay.ipynb.
Converted 06f_memory.tensorboard.ipynb.
Converted 10a_agents.dqn.core.ipynb.
Converted 10b_agents.dqn.targets.ipynb.
Converted 10c_agents.dqn.double.ipynb.
Converted 10d_agents.dqn.dueling.ipynb.
Converted 10e_agents.dqn.categorical.ipynb.
Converted 11a_agents.policy_gradient.ppo.ipynb.
Converted 20_test_utils.ipynb.
Converted index.ipynb.
Converted nbdev_template.ipynb.


In [None]:

# import warnings

# from collections import deque
# from collections.abc import Hashable

# from typing import Any, Callable, Iterator, List, Optional, Set, Sized, Tuple, TypeVar, Deque

# from torch.utils.data import IterDataPipe, functional_datapipe
# from torch.utils.data.datapipes.utils.common import check_lambda_fn
# from torch.utils.data._utils.serialization import serialize_fn, deserialize_fn


# T_co = TypeVar("T_co", covariant=True)



# class _ChildMapDataPipe(dp.map.MapDataPipe):
#     def __init__(self, main_datapipe, instance_id: Hashable):
#         required_attrs = ["get_next_element_by_instance", "is_instance_started", "getitem_by_instance"]
#         required_ops = [getattr(main_datapipe, attr) for attr in required_attrs]
#         if any(not callable(op) for op in required_ops):
#             raise NotImplementedError(f"Main Datapipe must have methods {required_attrs} implemented.")
#         self.main_datapipe = main_datapipe
#         self.instance_id = instance_id

#     def __iter__(self):
#         # These is no concept of exhaustion of the 'main_datapipe'. We only need
#         # to run through it once, then use the cached indexes for querying.
#         return self.get_generator_by_instance(self.instance_id)
    
#     @property
#     def _map(self): 
#         return self.main_datapipe.child_index_buffers[self.instance_id]

#     def __len__(self):
#         if not self.main_datapipe.main_datapipe_exhausted:
#             warnings.warn(
#                 "Data from prior DataPipe are loaded to get length of"
#                 "_ChildMapDataPipe before execution of the pipeline."
#                 "Please consider removing len()."
#             )
#             return len(list(self.get_generator_by_instance(self.instance_id)))
#         # Need to be careful here,  the len of `_ChildMapDataPipe` will be <= len(self.main_datapipe)
#         return len(self.main_datapipe.get_instance_buffer(self.instance_id))

#     def get_generator_by_instance(self, instance_id: Hashable):
#         yield from self.main_datapipe.get_next_element_by_instance(self.instance_id)
        
#     def __getitem__(self, index):
#         "Gets an item from `self.main_datapipe` in `self.instance_id`"
#         return self.main_datapipe.getitem_by_instance(self.instance_id, index)


# class _DemultiplexerMapDataPipe(dp.map.MapDataPipe):
#     def __init__(self, datapipe: dp.map.MapDataPipe[T_co], 
#                  # num_instances: int,
#                  instance_keys: Hashable,
#                  classifier_fn: Callable[[T_co], Optional[int]], drop_none: bool):
#         self.main_datapipe = datapipe
#         self._datapipe_indexer: Optional[Iterator[Any]] = None
#         # self._datapipe_iterator: Optional[Iterator[Any]] = None
#         self.instance_keys = instance_keys
#         # The child buffers will store the indexes separated into their respective
#         # `_ChildMapDataPipe`'s
#         self.child_index_buffers: Dict[set[T_co]] = {k:set() for k in self.instance_keys}
#         self.instance_started: Dict[Hashable,bool] = {k:False for k in instance_keys}
#         self.classifier_fn = classifier_fn
#         self.drop_none = drop_none
#         self.main_datapipe_exhausted = False
        
#     def _setup_datapipe_indexer(self) -> Optional[Iterator[Any]]:
#         # self._datapipe_iterator: Optional[Iterator[Any]] = None
#         # Instead of _datapipe_iterator we have _datapipe_indexer
#         # We need to know how to get the index from the main_datapipe. In order
#         # to do this, we check if it is...
        
#         # NOTE: THIS IS NOT A GOOD SOLUTION SINCE THIS CANT RELY ON A STANDARD
#         # INTERFACE FOR GETTING INDEXES
        
#         # We cash the indexes because we want to be able to have consistent behavior 
#         # when calling __getitem__ on a child pipe. 
#         # What we don't want is the main_datapipe being indexed by `str` but the
#         # child pipes indexing by `int`...
#         if isinstance(self.main_datapipe, dp.map.SequenceWrapper):
#             return iter(range(len(self.main_datapipe)))
#         elif hasattr(self.main_datapipe, '_map'):
#             return iter(self.main_datapipe._map)
#         elif hasattr(self.main_datapipe, 'index_map'):
#             return iter(self.main_datapipe.index_map)
#         else:
#             warnings.warn('data pipe will be indexed by len')
#             return iter(range(len(self.main_datapipe)))
        
#     def get_instance_buffer(self, instance_id: Hashable):
#         return self.child_index_buffers[instance_id]

#     def _find_next(self, instance_id: Hashable) -> T_co:
#         while True:
#             if self.main_datapipe_exhausted:
#                 raise StopIteration
#             if self._datapipe_indexer is None:
#                 raise ValueError(
#                     "_datapipe_indexer has not been set, likely because this private method is called directly "
#                     "without invoking get_next_element_by_instance() first.")
#             index = next(self._datapipe_indexer)
#             value = self.main_datapipe[index]
#             classification = self.classifier_fn(value)
#             if classification is None and self.drop_none:
#                 continue
#             if classification is None or classification not in self.instance_keys:
#                 raise ValueError(f"Output of the classification fn should be a key in {self.instance_keys}. " +
#                                  f"{classification} is returned.")
            
#             if index not in self.child_index_buffers[classification]:
#                 self.child_index_buffers[classification].add(index)

#             if classification == instance_id:
#                 return value,index
            
#     def getitem_by_instance(self, instance_id: Hashable, index: Hashable):
#         # We need to handle the situation where the index is not currently cached.
#         # In this case we still need to build the cache, while still attempting to 
#         # get the value for `index`
        
#         # In this case, `main_datapipe_exhausted` which means we still have some
#         # of the cache to populate possibly.
#         # Josiah: The main_datapipe_exhausted doesnt make sense in this context.
#         if index in self.child_index_buffers[instance_id]:
#             return self.main_datapipe[index]
        
#         if not self.main_datapipe_exhausted:
#             for _ in self.get_next_element_by_instance(instance_id):
#                 if index in self.child_index_buffers[instance_id]:
#                     return self.main_datapipe[index]
        
#         raise IndexError(f'Index {index} not found in {instance_id}')

#     def get_next_element_by_instance(self, instance_id: Hashable):
#         # Josiah: The main_datapipe_exhausted doesnt make sense in this context.
#         if self._datapipe_indexer is None and not self.main_datapipe_exhausted:
#             self._datapipe_indexer = iter(self._setup_datapipe_indexer())
#         stop = False
#         self.instance_started[instance_id] = True
#         instance_next_indexer = None
        
#         while not stop:
#             # We only want to iterate through the indexes once `self._datapipe_indexer` is clear
#             # so that we are "gaurenteed" to go through all the indexes possible for 
#             # instance_id
#             if self.child_index_buffers[instance_id] and self.main_datapipe_exhausted:
#                 instance_next_indexer = self.child_index_buffers[instance_id]
#                 yield from (self.main_datapipe[index] for index in instance_next_indexer)
#                 break
#             else:
#                 try:
#                     value,index = self._find_next(instance_id)
#                     yield value
#                 except StopIteration:
#                     stop = True
#                     self.main_datapipe_exhausted = True
#                     self._datapipe_indexer = None
                    
#     def is_instance_started(self, instance_id: Hashable) -> bool:
#         return self.instance_started[instance_id]

#     def reset(self):
#         self._datapipe_indexer: Optional[Iterator[Any]] = None
#         self.child_index_buffers: Dict[set[T_co]] = {k:set() for k in self.instance_keys}
#         self.instance_started: Dict[Hashable,bool] = {k:False for k in instance_keys}
#         self.main_datapipe_exhausted = False

#     def __getstate__(self):
#         if IterDataPipe.getstate_hook is not None:
#             return IterDataPipe.getstate_hook(self)

#         serialized_fn_with_method = serialize_fn(self.classifier_fn)
#         state = (
#             self.main_datapipe,
#             self.instance_keys,
#             self.buffer_size,
#             serialized_fn_with_method,
#             self.drop_none,
#         )
#         return state

#     def __setstate__(self, state):
#         (
#             self.main_datapipe,
#             self.num_instances,
#             self.buffer_size,
#             serialized_fn_with_method,
#             self.drop_none,
#         ) = state
#         self.classifier_fn = deserialize_fn(serialized_fn_with_method)
#         self._datapipe_indexer: Optional[Iterator[Any]] = None
#         self.child_index_buffers: Dict[set[T_co]] = {k:set() for k in self.instance_keys}
#         self.instance_started: Dict[Hashable,bool] = {k:False for k in instance_keys}
#         self.main_datapipe_exhausted = False

# class DemultiplexerMapDataPipe(dp.map.MapDataPipe):
#     def __new__(cls, datapipe: dp.map.MapDataPipe, instance_keys: List[Hashable],
#                 classifier_fn: Callable[[T_co], Optional[int]], drop_none: bool = False):
#         if not isinstance(datapipe, dp.map.MapDataPipe):
#             raise TypeError(f"DemultiplexerMapDataPipe can only apply on MapDataPipe, but found {type(datapipe)}")
#         if not instance_keys:
#             raise ValueError(f"Expected `instance_keys` larger than 0, but {instance_keys} is found")

#         check_lambda_fn(classifier_fn)

#         container = _DemultiplexerMapDataPipe(datapipe, instance_keys, classifier_fn, drop_none)
#         return [_ChildMapDataPipe(container, k) for k in instance_keys]
