In [1]:
#|hide
#|eval: false
! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \
                     apt-get install -y xvfb python-opengl > /dev/null 2>&1 
# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS

In [2]:
#|hide
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.showdoc import *
    from nbdev.imports import *
    if not os.environ.get("IN_TEST", None):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON
else:
    # Virutual display is needed for colab
    from pyvirtualdisplay import Display
    display = Display(visible=0, size=(400, 300))
    display.start()

In [3]:
#|default_exp pipes.core

In [19]:
#|export
# Python native modules
import os
import logging
import inspect
from typing import Callable,Union
# Third party libs
from fastcore.all import *
import torchdata.datapipes as dp
from torchdata.datapipes import functional_datapipe
from torchdata.dataloader2.graph import find_dps,DataPipeGraph,Type,DataPipe,traverse,_assign_attr,replace_dp
# Local modules

# Pipes Core
> Utilities for templating pipelines

In [20]:
#|export
def find_dp(graph: DataPipeGraph, dp_type: Type[DataPipe]) -> DataPipe:
    pipes = find_dps(graph,dp_type)
    if len(pipes)==1: return pipes[0]
    elif len(pipes)>1:
        found_ids = set([id(pipe) for pipe in pipes])
        if len(found_ids)>1:
            warn(f"""There are {len(pipes)} pipes of type {dp_type}. If this is intended, 
                     please use `find_dps` directly. Returning first instance.""")
        return pipes[0]
    else:
        raise LookupError(f'Unable to find {dp_type} starting at {graph}')
    
find_dp.__doc__ = "Returns a single `DataPipe` as opposed to `find_dps`.\n"+find_dps.__doc__

In [7]:
#|export
def _insert_dp(recv_dp, send_graph: DataPipeGraph, old_dp: DataPipe, new_dp: DataPipe) -> None:
    old_dp_id = id(old_dp)
    for send_id in send_graph:
        if send_id == old_dp_id:
            _assign_attr(recv_dp, old_dp, new_dp, inner_dp=True)
            
            # Replace the last datapipe in new_dp with the old_dp
            final_datapipe = find_dp(traverse(new_dp),PassThroughIterPipe)
            
            _assign_attr(new_dp, final_datapipe, old_dp, inner_dp=True)
            # new_dp.source_datapipe
        else:
            send_dp, sub_send_graph = send_graph[send_id]
            _insert_dp(send_dp, sub_send_graph, old_dp, new_dp)

In [8]:
#|export
def insert_dp(graph: DataPipeGraph, on_datapipe: DataPipe, insert_datapipe: DataPipe) -> DataPipeGraph:
    r"""
    Given the graph of DataPipe generated by ``traverse`` function and the ``on_datapipe`` DataPipe to be reconnected and
    the new ``insert_datapipe`` DataPipe to be inserted after ``on_datapipe``, 
    return the new graph of DataPipe.
    """
    assert len(graph) == 1

    # Check if `on_datapipe` is that the head of the graph
    # If so, we `insert_datapipe`
    if id(on_datapipe) in graph: 
        graph = traverse(insert_datapipe, only_datapipe=True)

    final_datapipe = list(graph.values())[0][0]
    
    for recv_dp, send_graph in graph.values():
        _insert_dp(recv_dp, send_graph, on_datapipe, insert_datapipe)

    return traverse(final_datapipe, only_datapipe=True)


In [9]:
it_pipe = dp.iter.IterableWrapper([1,2,3,4,5,6])
pipe = it_pipe.cycle(count=2)
pipe = pipe.batch(batch_size=2)

In [10]:
#|export
class PassThroughIterPipe(dp.iter.IterDataPipe):
    def __init__(self,source_datapipe): self.source_datapipe = source_datapipe
    def __iter__(self): return (o for o in self.source_datapipe)

In [11]:
new_dp = insert_dp(
    traverse(pipe,only_datapipe=True),
    find_dp(traverse(pipe,only_datapipe=True),dp.iter.Cycler),
    dp.iter.Header(PassThroughIterPipe([]),limit=4)
)

PassThroughIterPipe


In [12]:
traverse(pipe)

{139679188293072: (BatcherIterDataPipe,
  {139679188300368: (HeaderIterDataPipe,
    {139679188292944: (CyclerIterDataPipe,
      {139679188293008: (IterableWrapperIterDataPipe, {})})})})}

In [13]:
new_dp

{139679188293072: (BatcherIterDataPipe,
  {139679188300368: (HeaderIterDataPipe,
    {139679188292944: (CyclerIterDataPipe,
      {139679188293008: (IterableWrapperIterDataPipe, {})})})})}

In [14]:
list(new_dp)

[139679188293072]

In [17]:
#|export
class TypeTransformLoop(dp.map.MapDataPipe):
    def __init__(self,datapipe, type_tfms):
        self.type_tfms,self.datapipe = Pipeline(type_tfms),datapipe
    
    def __getitem__(self, index):
        data = self.datapipe[index]
        return self.type_tfms(data)
            
    def __len__(self): return len(self.datapipe)
    
class ItemTransformLoop(dp.iter.IterDataPipe):
    def __init__(self,source_datapipe, item_tfms:List[Callable]): 
        self.item_tfms,self.source_datapipe = Pipeline(item_tfms),source_datapipe
    
    def __iter__(self):
        for data in self.source_datapipe:
            yield self.item_tfms(data)
    
class BatchTransformLoop(dp.iter.IterDataPipe):
    def __init__(self,source_datapipe, batch_tfms):
        self.batch_tfms,self.source_datapipe = Pipeline(batch_tfms),source_datapipe
    
    def __iter__(self):
        for data in self.source_datapipe:
            yield self.batch_tfms(data)

In [2]:
#|hide
from fastcore.imports import in_colab

# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev import nbdev_export
    nbdev_export()