In [None]:
#|hide
#|eval: false
! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \
                     apt-get install -y xvfb python-opengl > /dev/null 2>&1 
# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS

In [None]:
#|hide
#|eval: false
import os
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.showdoc import *
    from nbdev.imports import *
    # `IN_TEST` is presumably set by nbdev's test runner; only sanity-check the
    # interactive environment when it is absent.
    if not os.environ.get("IN_TEST", None):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON
else:
    # Virtual display is needed for colab.
    # Bound to `virtual_display` (not `display`) so IPython's `display()` helper stays usable.
    from pyvirtualdisplay import Display
    virtual_display = Display(visible=0, size=(400, 300))
    virtual_display.start()

In [None]:
#|default_exp pipes.map.mux
from nbdev.showdoc import *

In [None]:
#|export
# Python native modules
import os
from inspect import isfunction,ismethod
from itertools import chain, zip_longest
from typing import Callable, Dict, Iterable, Optional, TypeVar
# Third party libs
from fastcore.all import *
from fastai.torch_basics import *
# from torch.utils.data.dataloader import DataLoader as OrgDataLoader
import torchdata.datapipes as dp
from torchdata.datapipes import functional_datapipe
from torchdata.dataloader2.graph import find_dps,DataPipeGraph,Type,DataPipe,MapDataPipe,IterDataPipe
from torchdata.dataloader2.dataloader2 import DataLoader2
# Local modules

# Multiplexer
> MapDataPipe for combining multiple map-style datapipes into a single one by interleaving their elements

In [None]:
#|export
T_co = TypeVar("T_co", covariant=True)

@functional_datapipe("mux")
class MultiplexerMapDataPipe(MapDataPipe[T_co]):
    """Yields one element at a time from each of the input Map DataPipes (functional name: ``mux``). As in,
    one element from the 1st input DataPipe, then one element from the 2nd DataPipe in the next iteration,
    and so on. An exhausted DataPipe is simply skipped, so the result continues until every input
    DataPipe is exhausted and its length is the sum of the input lengths.

    Args:
        *datapipes: map-style DataPipes to interleave.
        dp_index_map: optional mapping from a DataPipe to the keys to query from it, in order.
            DataPipes not present in the mapping are assumed to be 0-indexed (``0 .. len(dp)-1``).
    """
    def __init__(self, *datapipes, dp_index_map: Optional[Dict[MapDataPipe, Iterable]] = None):
        self.datapipes = datapipes
        self.dp_index_map = dp_index_map if dp_index_map else {}
        self.length: Optional[int] = None
        # Cache of already-resolved positions: new index -> (dp_num, key within that datapipe).
        self.index_map = {}
        # Create a generator that yields (index, (dp_num, old_index)) in sequential order.
        # `zip_longest` pads exhausted datapipes with None; those are filtered out so the
        # combined indices stay contiguous.
        # NOTE(review): `key_gen` holds a generator, which the standard `pickle` module cannot
        # serialize — confirm this pipe is never shipped to worker processes.
        indices = (self._add_dp_num(i, dp) for i, dp in enumerate(datapipes))
        dp_id_and_key_tuples = chain.from_iterable(zip_longest(*indices))
        self.key_gen = enumerate(e for e in dp_id_and_key_tuples if e is not None)

    def _add_dp_num(self, dp_num: int, dp: MapDataPipe):
        "Lazily yield `(dp_num, key)` for every key to query from `dp`, in query order."
        # Assume 0-index for all DataPipes unless alternate indices are defined in `self.dp_index_map`
        dp_indices = self.dp_index_map[dp] if dp in self.dp_index_map else range(len(dp))
        for idx in dp_indices:
            yield dp_num, idx

    def __getitem__(self, index):
        """Return the element at combined position `index`.

        Raises RuntimeError if `index` is outside `[0, len(self))` or if the underlying
        DataPipe raises KeyError for the key it was asked for.
        """
        if not 0 <= index < len(self):
            raise RuntimeError(f"Index {index} is out of bound for Multiplexer.")
        # Advance the shared key generator until `index` is resolved, caching *every* position
        # passed on the way. Caching only the requested index would make a later lookup of a
        # smaller index resume from the generator's current position and return the wrong element.
        while index not in self.index_map:
            new_index, dp_num_and_key = next(self.key_gen)
            self.index_map[new_index] = dp_num_and_key
        dp_num, old_key = self.index_map[index]
        try:
            return self.datapipes[dp_num][old_key]
        except KeyError as err:
            raise RuntimeError(
                f"Incorrect key is given to MapDataPipe {dp_num} in Multiplexer, likely because "
                f"that DataPipe is not 0-index but alternate indices are not given."
            ) from err

    def __iter__(self):
        "Iterate over all elements in combined (interleaved) order."
        for i in range(len(self)):
            yield self[i]

    def __len__(self):
        "Total number of elements: the sum of the input DataPipe lengths (computed once, then cached)."
        if self.length is None:
            self.length = sum(len(dp) for dp in self.datapipes)
        return self.length

For example, we can have a datapipe `a` that contains the numbers `0` to `9`,
and `b`, a map from letters to integers.

`dp_index_map` tells the multiplexer which keys to query, and in what order, for any datapipe that is not 0-indexed (here `b`, whose keys are letters).

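The result alternates between the two datapipes: one number from `a`, then the value stored under the next letter key of `b` (`100` for `'a'`, `200` for `'b'`, …). Once `b` runs out of keys, the remaining numbers from `a` follow in order.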

In [None]:
a = dp.map.SequenceWrapper(range(10))
b = dp.map.SequenceWrapper(dict(a=100, b=200, c=300, d=400))
# `b` is keyed by letters rather than 0..n-1, so tell `mux` which keys to query from it and in what order
datapipe = a.mux(b, dp_index_map={b: list('abcd')})
test_eq(list(datapipe), [0, 100, 1, 200, 2, 300, 3, 400, *range(4, 10)])
list(datapipe)

[0, 100, 1, 200, 2, 300, 3, 400, 4, 5, 6, 7, 8, 9]

In the second example we take the numbers `0` to `11` and split them equally into 3 datapipes by their remainder mod 3...

In [None]:
from fastrl.pipes.map.demux import DemultiplexerMapDataPipe

In [None]:
a = dp.map.SequenceWrapper(range(12))

def split_three_way(a):
    "Classifier for `demux`: route each element to output number `a % 3`."
    return a % 3

k1, k2, k3 = a.demux(num_instances=3, classifier_fn=split_three_way)
for pipe in (k1, k2, k3):
    test_eq(len(pipe), 4)
tuple(len(pipe) for pipe in (k1, k2, k3))

(4, 4, 4)

We can then recombine them into a single datapipe; interleaving the three pipes restores the original order `0` to `11`...

In [None]:
split_pipes = (k1, k2, k3)
combined_pipes = MultiplexerMapDataPipe(*split_pipes)
test_eq(list(combined_pipes), list(range(12)))
list(combined_pipes)

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]

In [None]:
#|hide
#|eval: false
from fastcore.imports import in_colab

# Colab still pins tornado<6, so nbdev is only imported (and the module exported) outside Colab
if not in_colab():
    import nbdev
    nbdev.nbdev_export()