In [1]:
#|hide
#|eval: false
! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \
                     apt-get install -y xvfb python-opengl > /dev/null 2>&1 
# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS

In [2]:
#|hide
#|eval: false
import os
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.showdoc import *
    from nbdev.imports import *
    # The environment assertions are skipped when `IN_TEST` is set
    # (presumably by the nbdev test runner — TODO confirm).
    if not os.environ.get("IN_TEST", None):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON
else:
    # Virtual display is needed for colab
    from pyvirtualdisplay import Display
    # Deliberately not named `display`, which would shadow IPython's `display()`
    # helper for every later cell in the notebook.
    virtual_display = Display(visible=0, size=(400, 300))
    virtual_display.start()

In [3]:
#|default_exp pipes.core

In [169]:
#|export
# Python native modules
import os
import logging
import inspect
from typing import Callable,Union,TypeVar
# Third party libs
from fastcore.all import *
import torchdata.datapipes as dp
from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe
from torchdata.datapipes.map import MapDataPipe
from torchdata.dataloader2.graph import find_dps,DataPipeGraph,Type,DataPipe,traverse,_assign_attr,replace_dp
# Local modules

# Pipes Core
> Utilities for templating pipelines

In [5]:
#|export
def find_dp(graph: DataPipeGraph, dp_type: Type[DataPipe]) -> DataPipe:
    # The public docstring is assigned below from `find_dps.__doc__`.
    # graph:   a `DataPipeGraph`, e.g. the result of `traverse(pipe)`.
    # dp_type: the `DataPipe` class to search for.
    # Returns the first match from `find_dps`. Raises `LookupError` when there is
    # no match; warns (and still returns the first match) when several distinct
    # instances match.
    pipes = find_dps(graph,dp_type)
    if not pipes:
        raise LookupError(f'Unable to find {dp_type} starting at {graph}')
    # Entries sharing an `id` are the same object and therefore not ambiguous,
    # so only distinct instances are counted.
    n_instances = len({id(pipe) for pipe in pipes})
    if n_instances>1:
        # stacklevel=2 attributes the warning to the caller rather than to this helper.
        warn(f"There are {n_instances} pipes of type {dp_type}. If this is intended, "
             "please use `find_dps` directly. Returning first instance.",
             stacklevel=2)
    return pipes[0]

# `or ""` guards against `find_dps.__doc__` being None (e.g. under `python -OO`).
find_dp.__doc__ = "Returns a single `DataPipe` as opposed to `find_dps`.\n"+(find_dps.__doc__ or "")

For example, if we have a pipeline such as:

In [160]:
class Template(dp.iter.IterDataPipe):
    "Minimal pass-through pipe that re-yields every item of `source_datapipe`."
    def __init__(self, source_datapipe=None):
        self.source_datapipe = source_datapipe

    def __iter__(self):
        return (item for item in self.source_datapipe)

# Marker subclasses: they behave exactly like `Template` and differ only by type,
# which is what `find_dp` matches on. `F` is never added to the pipeline below.
class A(Template): pass
class B(Template): pass
class C(Template): pass
class D(Template): pass
class E(Template): pass
class F(Template): pass

# Chain A -> B -> C -> D -> D -> E; the innermost call wraps the source range.
# `D` is applied twice on purpose so the graph holds two distinct `D` instances.
pipe = E(D(D(C(B(A(range(10)))))))
list(pipe)

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

We can grab the instance of `C` in the middle of the graph via...

In [154]:
find_dp(graph=traverse(pipe), dp_type=C)

C

If the pipe doesn't exist, a `LookupError` is raised...

In [158]:
# `F` is not part of `pipe`, so the lookup is expected to fail.
with ExceptionExpected(LookupError):
    find_dp(graph=traverse(pipe), dp_type=F)

And if there are multiple instances of the same type, you will be warned to use `find_dps` instead...

In [161]:
find_dp(graph=traverse(pipe), dp_type=D)

                     please use `find_dps` directly. Returning first instance.
  if __name__ == "__main__":


D

In [164]:
#|hide
#|eval: false
from fastcore.imports import in_colab

# Regenerate the library modules from the notebooks. Skipped on Colab, which
# still requires tornado<6, so nbdev is only imported when running elsewhere.
if not in_colab():
    import nbdev
    nbdev.nbdev_export()