In [14]:
#|hide
#|eval: false
! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \
                     apt-get install -y xvfb python-opengl > /dev/null 2>&1 
# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS

In [15]:
#|hide
#|eval: false
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.showdoc import *
    from nbdev.imports import *
    if not os.environ.get("IN_TEST", None):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON
else:
    # Virutual display is needed for colab
    from pyvirtualdisplay import Display
    display = Display(visible=0, size=(400, 300))
    display.start()

In [16]:
#|default_exp pipes.core

In [17]:
#|export
# Python native modules
import os
import logging
import inspect
from typing import Callable,Union,TypeVar,Optional
# Third party libs
from fastcore.all import *
import torchdata.datapipes as dp
from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe
from torchdata.datapipes.map import MapDataPipe
from torchdata.dataloader2.graph import find_dps,DataPipeGraph,Type,DataPipe,traverse,_assign_attr,replace_dp
# Local modules

# Pipes Core
> Utilities for templating pipelines

In [18]:
#|export
def find_dp(graph: DataPipeGraph, dp_type: Type[DataPipe]) -> DataPipe:
    pipes = find_dps(graph,dp_type)
    if len(pipes)==1: return pipes[0]
    elif len(pipes)>1:
        found_ids = set([id(pipe) for pipe in pipes])
        if len(found_ids)>1:
            warn(f"""There are {len(pipes)} pipes of type {dp_type}. If this is intended, 
                     please use `find_dps` directly. Returning first instance.""")
        return pipes[0]
    else:
        raise LookupError(f'Unable to find {dp_type} starting at {graph}')
    
find_dp.__doc__ = "Returns a single `DataPipe` as opposed to `find_dps`.\n"+find_dps.__doc__

For example if we have a pipeline such as:

In [19]:
class Template(dp.iter.IterDataPipe):
    def __init__(self,source_datapipe=None): self.source_datapipe = source_datapipe
    def __iter__(self): return (o for o in self.source_datapipe)

class A(Template):pass
class B(Template):pass
class C(Template):pass
class D(Template):pass
class E(Template):pass
class F(Template):pass

pipe = A(range(10))
pipe = B(pipe)
pipe = C(pipe)
pipe = D(pipe)
pipe = D(pipe)
pipe = E(pipe)
list(pipe)

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

We can grab the instance `C` in the middle of the graph via...

In [20]:
find_dp(traverse(pipe),C)

C

If the pipe doesnt exist a `LookupError` gets raised...

In [21]:
with ExceptionExpected(LookupError):
    find_dp(traverse(pipe),F)

And if there are multiple instances of the same time, you will be warned to use `find_dps` instead...

In [22]:
find_dp(traverse(pipe),D)

                     please use `find_dps` directly. Returning first instance.
  if __name__ == "__main__":


D

In [125]:
#|export
DataPipeAugmentationFn = Callable[[DataPipe],Optional[DataPipe]]

DataPipeAugmentationFn.__doc__ = f"""`DataPipeAugmentationFn` must take in a `DataPipe` and either output a `DataPipe` or `None`. This function should perform some operation on the graph
such as replacing, removing, inserting `DataPipe`'s and `DataGraph`s. Below is an example that replaces a `dp.iter.Batcher` datapipe with a `dp.iter.Filter`

Type: `{DataPipeAugmentationFn.__str__().replace(DataPipe.__str__(),'DataPipe').replace(Optional[DataPipe].__str__(),'Optional[DataPipe]').replace('typing.','')}`
"""

In [127]:
show_doc(DataPipeAugmentationFn)

---

### Callable

>      Callable (*args, **kwargs)

`DataPipeAugmentationFn` must take in a `DataPipe` and either output a `DataPipe` or `None`. This function should perform some operation on the graph
such as replacing, removing, inserting `DataPipe`'s and `DataGraph`s. Below is an example that replaces a `dp.iter.Batcher` datapipe with a `dp.iter.Filter`

Type: `Callable[[DataPipe], Optional[DataPipe]]`

In [103]:
def test_replace(pipe:DataPipe) -> DataPipe:
    graph = replace_dp(
        traverse(pipe),
        find_dp(traverse(pipe),dp.iter.Batcher),
        dp.iter.Filter(find_dp(traverse(pipe),dp.iter.IterableWrapper),filter_fn=lambda o:o%2==0)
    )
    return list(graph.values())[0][0]

In [111]:
#|export
def apply_dp_augmentation_fns(pipe:DataPipe,dp_augmentation_fns:Tuple[DataPipeAugmentationFn],debug:bool=False):
    "Given a `pipe`, run `dp_augmentation_fns` other the pipeline"
    for fn in dp_augmentation_fns:
        if debug: print(f'Running fn: {fn} given current pipe: \n\t{traverse(pipe)}')
        result = fn(pipe)
        if result is not None: pipe = result
    return pipe

Given a simple pipeline below...

In [108]:
pipe = dp.iter.IterableWrapper(range(10))
pipe = pipe.batch(2)
pipe = pipe.cycle(2)
pipe = pipe.header(limit=10)
traverse(pipe)

{139738656070352: (HeaderIterDataPipe,
  {139738656070096: (CyclerIterDataPipe,
    {139738656070224: (BatcherIterDataPipe,
      {139738656069968: (IterableWrapperIterDataPipe, {})})})})}

We want to run `test_replace` over the pipeline which will replace the `pipe.batch` with a `dp.iter.Filter`...

In [109]:
new_dp = apply_dp_augmentation_fns(pipe,(test_replace,))
test_eq(list(new_dp),[0, 2, 4, 6, 8, 0, 2, 4, 6, 8])
traverse(new_dp)


  "The length of this HeaderIterDataPipe is inferred to be equal to its limit."


{139738656070352: (HeaderIterDataPipe,
  {139738656070096: (CyclerIterDataPipe,
    {139738655387920: (FilterIterDataPipe,
      {139738656069968: (IterableWrapperIterDataPipe, {})})})})}

In [128]:
#|hide
#|eval: false
from fastcore.imports import in_colab

# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev import nbdev_export
    nbdev_export()