In [None]:
#|hide
# Project-specific notebook setup helper from `fastrl.test_utils`.
# NOTE(review): its exact effect is not visible here — presumably configures the environment for running this notebook; confirm.
from fastrl.test_utils import initialize_notebook
initialize_notebook()

In [None]:
#|default_exp pipes.core

In [None]:
#|export
# Python native modules
import os
import logging
import inspect
from typing import Callable,Union,TypeVar,Optional,Type,List,Tuple
# Third party libs
import torchdata.datapipes as dp
from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe
from torchdata.datapipes.map import MapDataPipe
from torchdata.dataloader2.graph import DataPipe, DataPipeGraph,find_dps,traverse_dps,list_dps
# Local modules

In [None]:
from nbdev.showdoc import show_doc
from fastcore.all import ExceptionExpected,test_eq

# Pipes Core
> Utilities for templating pipelines

In [None]:
#|export
def find_dps(
        graph: DataPipeGraph,
        dp_type: Type[DataPipe],
        include_subclasses:bool=False
    ) -> List[DataPipe]:
    r"""
    Collect every DataPipe in ``graph`` whose type matches ``dp_type``.

    ``graph`` is the nested mapping produced by ``traverse_dps``: each value is a
    ``(datapipe, sub_graph)`` pair, where ``sub_graph`` has the same shape. Nodes are
    visited depth-first, each pipe before its sources, siblings in mapping order.
    By default only exact type matches are returned; with ``include_subclasses=True``
    instances of subclasses of ``dp_type`` match as well. Duplicates are not removed.

    This definition shadows the ``find_dps`` imported from ``torchdata.dataloader2.graph``.
    """
    matches: List[DataPipe] = []
    # Explicit stack instead of recursion. Children are pushed in reverse so they are
    # popped in their original order, reproducing a recursive pre-order walk.
    pending = list(reversed(list(graph.values())))
    while pending:
        pipe, sources = pending.pop()
        pipe_type = type(pipe)
        # Exact matching compares types with `is` rather than `isinstance`; the original
        # author notes `isinstance` is buggy here.
        if pipe_type is dp_type or (include_subclasses and issubclass(pipe_type, dp_type)):
            matches.append(pipe)
        pending.extend(reversed(list(sources.values())))
    return matches

In [None]:
#|export
def find_dp(
        # A graph created from the `traverse_dps` function
        graph: DataPipeGraph,
        # The `DataPipe` class to search for
        dp_type: Type[DataPipe],
        # If True, instances of subclasses of `dp_type` also match; otherwise only exact type matches
        include_subclasses:bool=False
    ) -> DataPipe:
    pipes = find_dps(graph,dp_type,include_subclasses)
    if not pipes:
        raise LookupError(f'Unable to find {dp_type} starting at {graph}')
    # `find_dps` does not de-duplicate, so the same instance may appear more than once;
    # only warn when genuinely different instances were found.
    n_distinct = len({id(pipe) for pipe in pipes})
    if n_distinct>1:
        # `logging.warn` is a deprecated alias; use `logging.warning`.
        logging.warning(
            "There are %s pipes of type %s. If this is intended, "
            "please use `find_dps` directly. Returning first instance.",
            n_distinct,dp_type
        )
    return pipes[0]

# Guard against `find_dps.__doc__` being None (e.g. when running under `python -OO`).
find_dp.__doc__ = (
    "Returns a single `DataPipe` as opposed to `find_dps`. Raises `LookupError` if none\n"
    "is found and logs a warning if several distinct instances are found.\n"
    + (find_dps.__doc__ or "")
)

For example, given a pipeline such as:

In [None]:
class Template(dp.iter.IterDataPipe):
    "Pass-through `IterDataPipe` used to build toy pipelines for the examples below."
    def __init__(self,source_datapipe=None):
        self.source_datapipe = source_datapipe
    def __iter__(self):
        return (o for o in self.source_datapipe)

# Distinct subclasses so each stage of the toy pipeline has its own type.
class A(Template):pass
class B(Template):pass
class C(Template):pass
class D(Template):pass
class E(Template):pass
class F(Template):pass  # never added to the pipeline; used to demo a failed lookup

# Wrap `range(10)` in each stage in turn. `D` is applied twice on purpose, to show how
# duplicate types are handled.
pipe = range(10)
for pipe_cls in (A, B, C, D, D, E):
    pipe = pipe_cls(pipe)
list(pipe)

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

We can grab the instance `C` in the middle of the graph via...

In [None]:
# Exactly one `C` exists in the pipeline, so it is returned without a warning
find_dp(traverse_dps(pipe),C)

C

If no pipe of the requested type exists, a `LookupError` is raised...

In [None]:
# `F` was never added to the pipeline, so the lookup fails
with ExceptionExpected(LookupError):
    find_dp(traverse_dps(pipe),F)

And if there are multiple instances of the same type, a warning suggests using `find_dps` instead, and the first instance found is returned...

In [None]:
# `D` was applied twice, so a warning is logged and the first match (the one closest to the head) is returned
find_dp(traverse_dps(pipe),D)

  logging.warn("""There are %s pipes of type %s. If this is intended,
                     please use `find_dps` directly. Returning first instance.


D

If we search for `dp.iter.IterDataPipe` itself, we get nothing, because by default only exact type matches are returned...

In [None]:
# No pipe has exactly the type `IterDataPipe`; subclasses are excluded by default
find_dps(traverse_dps(pipe),dp.iter.IterDataPipe)

[]

However we can include subclasses in our search...

In [None]:
# Every pipe subclasses `IterDataPipe`, so all are returned, depth-first from the head of the graph
find_dps(traverse_dps(pipe),dp.iter.IterDataPipe,include_subclasses=True)

[E, D, D, C, B, A]

In [None]:
#|export
class DataPipeAugmentationFn(Callable[[DataPipe],Optional[DataPipe]]):
    """`DataPipeAugmentationFn` must take in a `DataPipe` and either output a `DataPipe` or `None`.

    The function should perform some operation on the pipeline graph, such as replacing,
    removing, or inserting `DataPipe`s and `DataPipeGraph`s. `apply_dp_augmentation_fns`
    keeps the current pipe when a function returns `None`. For example, a function can
    replace a `dp.iter.Batcher` in the pipeline with a `dp.iter.Filter`.
    """

In [None]:
# Render the documentation for `DataPipeAugmentationFn`
show_doc(DataPipeAugmentationFn)

---

### DataPipeAugmentationFn

>      DataPipeAugmentationFn (*args, **kwds)

`DataPipeAugmentationFn` must take in a `DataPipe` and either output a `DataPipe` or `None`. This function should perform some operation on the graph
such as replacing, removing, inserting `DataPipe`'s and `DataGraph`s. Below is an example that replaces a `dp.iter.Batcher` datapipe with a `dp.iter.Filter`

In [None]:
from torchdata.dataloader2.graph import replace_dp

In [None]:

def iseven(i):
    "Return True if `i` is even."
    return i%2==0

def test_replace(pipe:DataPipe) -> DataPipe:
    "Example `DataPipeAugmentationFn`: replace the `dp.iter.Batcher` in `pipe` with a `dp.iter.Filter` that keeps even items."
    # Traverse the pipeline once and reuse the graph for every lookup.
    graph = traverse_dps(pipe)
    batcher = find_dp(graph,dp.iter.Batcher)
    # Build the replacement on top of the pipeline's `IterableWrapper` source.
    source = find_dp(graph,dp.iter.IterableWrapper)
    graph = replace_dp(graph,batcher,dp.iter.Filter(source,filter_fn=iseven))
    # The graph's single top-level entry holds the head `DataPipe` of the pipeline.
    return next(iter(graph.values()))[0]

In [None]:
#|export
def apply_dp_augmentation_fns(
        pipe:DataPipe,
        # Any number of functions; `Tuple[X, ...]` (not `Tuple[X]`, which means exactly one)
        dp_augmentation_fns:Optional[Tuple[DataPipeAugmentationFn, ...]],
        debug:bool=False
    ) -> DataPipe:
    """Given a `pipe`, run each of `dp_augmentation_fns` over the pipeline, in order.

    Each function receives the current pipe. If it returns anything other than `None`,
    that result becomes the pipe passed to the next function; a `None` result keeps the
    current pipe. If `dp_augmentation_fns` is `None`, `pipe` is returned unchanged.
    With `debug=True`, each function and the current pipe's graph are printed before it runs.
    """
    if dp_augmentation_fns is None: return pipe
    for fn in dp_augmentation_fns:
        if debug: print(f'Running fn: {fn} given current pipe: \n\t{traverse_dps(pipe)}')
        result = fn(pipe)
        if result is not None: pipe = result
    return pipe

Given a simple pipeline below...

In [None]:
# Numbers 0-9, batched in pairs, cycled twice, capped at 10 items
pipe = (
    dp.iter.IterableWrapper(range(10))
    .batch(2)
    .cycle(2)
    .header(limit=10)
)
traverse_dps(pipe)

{140129912465104: (HeaderIterDataPipe,
  {140129912465440: (CyclerIterDataPipe,
    {140129912463952: (BatcherIterDataPipe,
      {140129912464816: (IterableWrapperIterDataPipe, {})})})})}

We run `test_replace` over the pipeline, which replaces the `Batcher` created by `pipe.batch` with a `dp.iter.Filter` that keeps only even numbers...

In [None]:
augmented_pipe = apply_dp_augmentation_fns(pipe,(test_replace,))
# Only even numbers pass the new `Filter`; `cycle(2)` repeats them and `header(limit=10)` caps the output
test_eq(list(augmented_pipe),[0, 2, 4, 6, 8, 0, 2, 4, 6, 8])
traverse_dps(augmented_pipe)


  warn(


{140129912465104: (HeaderIterDataPipe,
  {140129912465440: (CyclerIterDataPipe,
    {140129911061280: (FilterIterDataPipe,
      {140129912464816: (IterableWrapperIterDataPipe, {})})})})}

In [None]:
#|hide
#|eval: false
from fastcore.imports import in_colab

# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to.
# Outside Colab, nbdev is imported lazily and `nbdev_export` writes this notebook's
# `#|export` cells out to the library module.
if not in_colab():
    from nbdev import nbdev_export
    nbdev_export()

Note nbdev2 no longer supports nbdev1 syntax. Run `nbdev_migrate` to upgrade.
See https://nbdev.fast.ai/getting_started.html for more information.
  warn(f"Notebook '{nbname}' uses `#|export` without `#|default_exp` cell.\n"
