In [None]:
#|hide
#|eval: false
# Bootstrap dependencies only when `/content` exists (presumably the Colab working directory — see NOTE below).
# The `&&` chain short-circuits, so nothing is installed when that path is missing.
# Only the `apt-get` output is redirected to /dev/null; `pip` runs with `-qq` instead.
! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \
                     apt-get install -y xvfb python-opengl > /dev/null 2>&1 
# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS

In [None]:
#|hide
#|eval: false
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.showdoc import *
    from nbdev.imports import *
    # NOTE(review): `os` and the IN_* flags are not imported explicitly in this cell;
    # presumably they are provided by the star imports above — confirm.
    # The sanity checks are skipped whenever the IN_TEST environment variable is set to a non-empty value.
    if not os.environ.get("IN_TEST", None):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON
else:
    # Virtual display is needed for colab
    from pyvirtualdisplay import Display
    # NOTE(review): the name `display` shadows IPython's `display` helper for the rest of the
    # session — confirm that nothing later relies on the IPython function.
    display = Display(visible=0, size=(400, 300))
    display.start()

In [None]:
#|default_exp pipes.map.demux
from nbdev.showdoc import *

In [None]:
#|export
# Python native modules
import os
from inspect import isfunction,ismethod
from typing import Callable, Dict, Iterable, Optional, TypeVar
# Third party libs
from fastcore.all import *
from fastrl.torch_core import *
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
import torchdata.datapipes as dp
from torchdata.datapipes import functional_datapipe
from torchdata.dataloader2.graph import find_dps,DataPipeGraph,Type,DataPipe,MapDataPipe,IterDataPipe
from torchdata.dataloader2.dataloader2 import DataLoader2
# Local modules

# DeMultiplexer
> A MapDataPipe that splits a single datapipe into multiple child datapipes

In [None]:
#|export
# Covariant type variable used to parameterize the `MapDataPipe[...]` base of the datapipe classes below.
T_co = TypeVar("T_co", covariant=True)

@functional_datapipe("demux")
class DemultiplexerMapDataPipe(MapDataPipe[T_co]):
    # Factory-style entry point: `__new__` never produces an instance of this class,
    # it hands back the list of child datapipes instead.
    def __new__(cls, datapipe: MapDataPipe, num_instances: int, classifier_fn: Callable, drop_none: bool = False,
                source_index: Optional[Iterable] = None):
        """Create one shared container and return a list with one child datapipe per instance id.

        Args:
            datapipe: source map-style datapipe to split.
            num_instances: number of child datapipes to create; must be at least 1.
            classifier_fn: callable handed to the shared container to pick a child for each element.
            drop_none: forwarded to the shared container.
            source_index: forwarded to the shared container.

        Returns:
            A list of `num_instances` child datapipes, ordered by instance id.

        Raises:
            ValueError: if `num_instances` is smaller than 1.
        """
        if num_instances < 1:
            raise ValueError(f"Expected `num_instances` larger than 0, but {num_instances} is found")
        # Vet the classifier with torch's helper before building anything
        # (the helper's name suggests a picklability check — confirm in torch).
        _check_unpickable_fn(classifier_fn)
        shared = _DemultiplexerMapDataPipe(datapipe, num_instances, classifier_fn, drop_none, source_index)
        children = []
        for instance_id in range(num_instances):
            # Every child points at the same container, so classification work is shared.
            children.append(_DemultiplexerChildMapDataPipe(shared, instance_id))
        return children

class _DemultiplexerMapDataPipe(MapDataPipe[T_co]):
    """Shared container that lazily sorts the elements of a source datapipe into per-instance buckets.

    Elements are read from `main_datapipe` in `source_index` order, passed to `classifier_fn`,
    and the element itself is cached in the bucket selected by the returned instance id.

    Args:
        datapipe: source map-style datapipe to split.
        num_instances: number of buckets to create.
        classifier_fn: maps an element to the id of its bucket, or to `None`.
        drop_none: when True, elements classified as `None` are skipped.
        source_index: indices to read from `datapipe`, in order; when `None`,
            `range(len(datapipe))` is used.
    """

    def __init__(
        self,
        datapipe: MapDataPipe[T_co],
        num_instances: int,
        classifier_fn: Callable[[T_co], Optional[int]],
        drop_none: bool,
        source_index: Optional[Iterable],
    ):
        self.main_datapipe = datapipe
        self.num_instances = num_instances
        self.classifier_fn = classifier_fn
        self.drop_none = drop_none
        self.iterator = None  # created lazily on the first classification
        self.exhausted = False  # Once we iterate through `main_datapipe` once, we know all the index mapping
        self.index_mapping = [[] for _ in range(num_instances)]  # one bucket of classified elements per instance
        self.source_index = source_index  # if None, assume `main_datapipe` 0-index

    def _classify_next(self):
        """Classify source elements until one is stored in a bucket or the source runs out.

        Sets `exhausted` to True once `source_index` has been fully consumed.

        Raises:
            ValueError: if `classifier_fn` returns an id outside `[0, num_instances)`, or
                returns `None` while `drop_none` is False.
        """
        if self.source_index is None:
            self.source_index = range(len(self.main_datapipe))
        if self.iterator is None:
            self.iterator = iter(self.source_index)
        # Loop (rather than recurse) over dropped elements so that a long run of `None`
        # classifications cannot exhaust the interpreter's recursion limit.
        while True:
            try:
                next_source_idx = next(self.iterator)
            except StopIteration:
                self.exhausted = True
                return
            value = self.main_datapipe[next_source_idx]
            classification = self.classifier_fn(value)
            if classification is None and self.drop_none:
                continue
            # Reject invalid ids explicitly: a negative id would otherwise be accepted by list
            # indexing and silently land in the wrong bucket.
            if classification is None or not 0 <= classification < self.num_instances:
                raise ValueError(
                    f"Output of the classification fn should be between 0 and {self.num_instances - 1}. "
                    f"{classification} is returned."
                )
            self.index_mapping[classification].append(value)
            return

    def classify_all(self):
        """Classify every remaining source element."""
        while not self.exhausted:
            self._classify_next()

    def get_value(self, instance_id: int, index: int) -> T_co:
        """Return the `index`-th element of bucket `instance_id`, classifying lazily as needed.

        Negative indices count from the end of the bucket, which requires classifying the
        whole source first.

        Raises:
            RuntimeError: if the bucket has no element at `index`.
        """
        bucket = self.index_mapping[instance_id]
        if index < 0:
            # The end of a bucket is only known once every source element has been classified.
            self.classify_all()
            if -index <= len(bucket):
                return bucket[index]
            raise RuntimeError("Index is out of bound.")
        while not self.exhausted and len(bucket) <= index:
            self._classify_next()
        if len(bucket) > index:
            return bucket[index]
        raise RuntimeError("Index is out of bound.")

    def __len__(self):
        """Return the length of the wrapped source datapipe (not of any single bucket)."""
        return len(self.main_datapipe)

class _DemultiplexerChildMapDataPipe(MapDataPipe[T_co]):
    """One output of the demultiplexer: a view over a single bucket of the shared container."""

    def __init__(self, main_datapipe: _DemultiplexerMapDataPipe, instance_id: int):
        # `main_datapipe` is the shared container; `instance_id` selects this child's bucket in it.
        self.main_datapipe: _DemultiplexerMapDataPipe = main_datapipe
        self.instance_id = instance_id

    def __getitem__(self, index: int):
        """Return the element at position `index` of this child's bucket."""
        container = self.main_datapipe
        return container.get_value(self.instance_id, index)

    def __len__(self):
        """Return the number of elements classified into this child's bucket."""
        container = self.main_datapipe
        # You have to read through the entirety of the source to know the final bucket size.
        container.classify_all()
        bucket = container.index_mapping[self.instance_id]
        return len(bucket)

    def __iter__(self):
        """Yield this child's elements in bucket order."""
        yield from (self[position] for position in range(len(self)))
            
# Attach the user-facing docstring to the public class after its definition
# (the class body itself carries no docstring).
DemultiplexerMapDataPipe.__doc__ = """Splits the input DataPipe into multiple child DataPipes, using the given
    classification function (functional name: ``demux``). A list of the child DataPipes is returned from this operation.
"""

For example, we can use a classifier function that splits a map-style datapipe of numbers into even and odd elements:

In [None]:
def odd_or_even(n):
    """Classify `n` by parity: the remainder of `n` divided by 2 (0 for even, 1 for odd integers)."""
    parity = n % 2
    return parity
# Wrap the numbers 0..4 in a map-style datapipe and split it into two children with `odd_or_even`.
source_dp = dp.map.SequenceWrapper(range(5))
dp1, dp2 = source_dp.demux(classifier_fn=odd_or_even, num_instances=2)
# Walk the first child (classification 0) by position.
[dp1[idx] for idx in range(len(dp1))]

[0, 2, 4]

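> Note: The datapipes returned by `DemultiplexerMapDataPipe` are "re-indexed": each child is indexed from 0 in the order its elements were classified, regardless of where those elements sat in the source datapipe.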

In [None]:
# Walk the second child (classification 1) by position.
[dp2[idx] for idx in range(len(dp2))]

[1, 3]

It can also filter out any element for which `classifier_fn` returns `None`, provided `drop_none=True` is passed.

In [None]:
def odd_or_even_no_zero(n):
    """Classify `n` by parity (`n % 2`), except that zero is mapped to `None`."""
    if n != 0:
        return n % 2
    return None

# Same split as before, but with `drop_none=True` so the element classified as `None` (zero) is skipped.
source_dp = dp.map.SequenceWrapper(range(5))
dp1, dp2 = source_dp.demux(classifier_fn=odd_or_even_no_zero, num_instances=2, drop_none=True)
# Walk the first child (classification 0) by position.
[dp1[idx] for idx in range(len(dp1))]

[2, 4]

In [None]:
# Walk the second child (classification 1) by position.
[dp2[idx] for idx in range(len(dp2))]

[1, 3]

In [None]:
#|hide
#|eval: false
from fastcore.imports import in_colab

# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    # Outside Colab, import and invoke nbdev's `nbdev_export`
    # (presumably writes the `#|export` cells to the module named by `#|default_exp` — confirm in the nbdev docs).
    from nbdev import nbdev_export
    nbdev_export()