In [1]:
#hide
#skip
%config Completer.use_jedi = False
%config IPCompleter.greedy=True
# upgrade fastrl on colab
! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \
                     apt-get install -y xvfb python-opengl > /dev/null 2>&1 
# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS

In [2]:
# hide
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.showdoc import *
    from nbdev.imports import *
    # NOTE(review): `os` is not imported explicitly in this cell; presumably it is
    # provided by one of the star-imports above — confirm.
    # Sanity-check the environment flags unless the IN_TEST env var is set.
    if not os.environ.get("IN_TEST", None):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON
else:
    # Virtual display is needed for colab
    from pyvirtualdisplay import Display
    # NOTE(review): this rebinds the name `display` in the notebook namespace.
    display = Display(visible=0, size=(400, 300))
    display.start()

In [3]:
# default_exp fastai.data.pipes.core

In [4]:
# export
# Python native modules
import os
import logging
import inspect
# Third party libs
from fastcore.all import *
import torchdata.datapipes as dp
from torch.utils.data.graph import traverse
from torchdata.datapipes import functional_datapipe
# Local modules


_logger = logging.getLogger()

# Pipes Core
> Callback + DataPipe support for highly flexible looping.

In [5]:
# export
_allowed_hook_params = ['before','after','not_under']

class Callback():
    "Base class whose public methods are hooks when every one of their parameters is named in `_allowed_hook_params`."
    @property
    def name(self):
        "Name of the `Callback`, camel-cased and with '*Callback*' removed"
        return class2attr(self, 'Callback')
    
    def hooks(self):
        """Return the bound hook methods defined directly on this instance's class.

        A member counts as a hook when it is public (no leading underscore),
        callable, takes at least one parameter besides `self`, and all of its
        parameter names appear in `_allowed_hook_params`. Only members found in
        `self.__class__.__dict__` are considered, so inherited methods are not
        reported.

        Raises:
            ValueError: if called with a class instead of an instance.
        """
        if inspect.isclass(self): raise ValueError(f'{self} needs to be instantiated!')
        hooks = []
        for k,raw in self.__class__.__dict__.items():
            if k.startswith('_'): continue
            # Properties must not be evaluated here, and their values have no
            # signature to inspect, so they can never be hooks.
            if isinstance(raw,property): continue
            member = getattr(self,k)
            # Plain class attributes (ints, strings, ...) would make
            # `inspect.signature` raise a TypeError.
            if not callable(member): continue
            params = list(inspect.signature(member).parameters)
            # An empty parameter list (e.g. a helper taking only `self`) is not a hook.
            if params and all(p in _allowed_hook_params for p in params): hooks.append(member)
        return hooks

In [6]:
# export
def filter_call_on_cbs(obj, cbs):
    "Return, as a tuple, the callbacks in `cbs` whose `call_on` collection contains the exact class of `obj`."
    obj_cls = obj.__class__
    kept = []
    for cb in cbs:
        if obj_cls in cb.call_on: kept.append(cb)
    return tuple(kept)

In [7]:
# export
def filter_exclude_under_cbs(
    pipe:Union[dp.map.MapDataPipe,dp.iter.IterDataPipe], 
    cbs:List[Callback]
):
    "Drop every callback whose `exclude_under` contains the class of `pipe` or of any pipe reached by recursing through `traverse`."
    pipe_cls = pipe.__class__
    remaining = tuple(filter(lambda cb: pipe_cls not in cb.exclude_under, cbs))
    # `only_datapipe=True` keeps the traversal from descending into non-datapipe objects.
    for children in traverse(pipe,only_datapipe=True).values():
        for child in children.keys():
            remaining = filter_exclude_under_cbs(child,remaining)
    return remaining

In [8]:
# export
def find_pipes(
    pipe:Union[dp.map.MapDataPipe,dp.iter.IterDataPipe],
    fn,
    pipe_list=None
):
    """Collect `pipe` and every pipe reached by recursing through `traverse` for which `fn` is truthy.

    Matches are appended to `pipe_list` (a fresh list when `None` is passed),
    starting with `pipe` itself, and the same list is returned.
    """
    if pipe_list is None: pipe_list = []
    is_datapipe = issubclass(pipe.__class__,(dp.map.MapDataPipe,dp.iter.IterDataPipe))
    if is_datapipe and fn(pipe): pipe_list.append(pipe)
    # `only_datapipe=True` keeps the traversal from descending into non-datapipe objects.
    for children in traverse(pipe,only_datapipe=True).values():
        for child in children.keys(): find_pipes(child,fn,pipe_list)
    return pipe_list

A callback does a couple of things:

    - inserts data pipes at different points in a pipeline
    - maintains business logic state
    - allows publishing that state to another object.

In [9]:
class OneAdder(dp.iter.IterDataPipe):
    "Demo pipe: prints each upstream item, then yields it incremented by 1."
    def __init__(self,source_datapipe):
        self.source_datapipe = source_datapipe

    def __iter__(self):
        for item in self.source_datapipe:
            print('adding',item)
            yield item+1

class PointZeroFiveAdder(dp.iter.IterDataPipe):
    "Demo pipe: yields each upstream item plus 0.05."
    def __init__(self,source_datapipe):
        self.source_datapipe = source_datapipe

    def __iter__(self):
        for item in self.source_datapipe:
            shifted = item+0.05
            yield shifted

class PointZeroOne(dp.iter.IterDataPipe):
    "Demo pipe: yields each upstream item plus 0.01."
    def __init__(self,source_datapipe):
        self.source_datapipe = source_datapipe

    def __iter__(self):
        for item in self.source_datapipe:
            shifted = item+0.01
            yield shifted
        
class Point5Adder(dp.iter.IterDataPipe):
    "Demo pipe for batched input: yields each batch as a list with 0.5 added to every element."
    def __init__(self,source_datapipe):
        self.source_datapipe = source_datapipe

    def __iter__(self):
        for batch in self.source_datapipe:
            # Only exact `list` / `dp.DataChunk` types are accepted, not subclasses.
            if type(batch) in [list,dp.DataChunk]:
                yield [f+0.5 for f in batch]
            else:
                raise AssertionError(f'This goes after batcher, these should be lists but instead are {type(batch)}')

class Point1Adder(dp.iter.IterDataPipe):
    "Demo pipe for batched input: yields each batch as a list with 0.1 added to every element."
    def __init__(self,source_datapipe):
        self.source_datapipe = source_datapipe

    def __iter__(self):
        for batch in self.source_datapipe:
            # Only exact `list` / `dp.DataChunk` types are accepted, not subclasses.
            if type(batch) in [list,dp.DataChunk]:
                yield [f+0.1 for f in batch]
            else:
                raise AssertionError(f'This goes after batcher, these should be lists but instead are {type(batch)}')
        
class TestCallback(Callback):
    "Demo callback: each hook's `before`/`after` default names the pipe class it targets, and its return value lists the pipe classes to insert."
    
    # Takes no parameters besides `self`, so `Callback.hooks` does not report it.
    def not_a_hook(self): return 'this should not be processed as a hook!'
    
    # Targets `before=dp.iter.Batcher`.
    def add_one(self,before=dp.iter.Batcher,after=None,not_under=None) -> List[dp.iter.IterDataPipe]:
        return L(OneAdder)
    # Both of the following target `after=dp.iter.IterableWrapper`.
    def add_point_zero_five(self,before=None,after=dp.iter.IterableWrapper,not_under=None) -> List[dp.iter.IterDataPipe]:
        return L(PointZeroFiveAdder)
    def add_point_zero_1(self,before=None,after=dp.iter.IterableWrapper,not_under=None) -> List[dp.iter.IterDataPipe]:
        return L(PointZeroOne)
    
    # Targets `after=dp.iter.Batcher` and returns two pipe classes from a single hook.
    def add_point5_and_point1_batch(self,before=None,after=dp.iter.Batcher,not_under=None) -> List[dp.iter.IterDataPipe]:
        return L(Point5Adder,Point1Adder)
    

In [10]:
cb = TestCallback()
cb.hooks()

[<bound method TestCallback.add_one of <__main__.TestCallback object at 0x7f008c9ad250>>,
 <bound method TestCallback.add_point_zero_five of <__main__.TestCallback object at 0x7f008c9ad250>>,
 <bound method TestCallback.add_point_zero_1 of <__main__.TestCallback object at 0x7f008c9ad250>>,
 <bound method TestCallback.add_point5_and_point1_batch of <__main__.TestCallback object at 0x7f008c9ad250>>]

In [22]:
# export
# Extend both datapipe base classes with a `callbacks` attribute, class-name based
# `__repr__`/`__str__`, and the `add_cbs_before` / `add_cbs_after` helpers.
# fastcore's `@patch` attaches each function to the class given by the `self` annotation.
for _pipe in [dp.map.MapDataPipe,dp.iter.IterDataPipe]:
    # Class-level default: every instance shares this object unless it assigns its own.
    _pipe.callbacks = L()
    
    @patch
    def __repr__(self:_pipe):
        "Use `repr_hook` when set; otherwise the class qualname, followed by `callbacks` when non-empty."
        if self.repr_hook is not None:
            return self.repr_hook(self)
        # Instead of showing <torch. ... .MapperMapDataPipe object at 0x.....>, return the class name
        str_rep = str(self.__class__.__qualname__)
        if self.callbacks: return str_rep + str(self.callbacks)
        return str_rep

    @patch
    def __str__(self:_pipe):
        "Use `str_hook` when set; otherwise the class qualname, followed by `callbacks` when non-empty."
        if self.str_hook is not None:
            return self.str_hook(self)
        # Instead of showing <torch. ... .MapperMapDataPipe object at 0x.....>, return the class name
        str_rep = str(self.__class__.__qualname__)
        if self.callbacks: return str_rep + str(self.callbacks)
        return str_rep
    
    @patch
    def add_cbs_before(self:_pipe,cbs):
        "Run every hook of every callback in `cbs` through `add_hooks_before` with this pipe as `base_pipe`."
        pipe = self
        if cbs is None or len(cbs)==0: return pipe
        
        for cb in cbs:
            for hook in cb.hooks():
                pipe = add_hooks_before(pipe,hook,base_pipe=self)
        # A `PassThroughIterPipe` is unwrapped before returning.
        if pipe.__class__==PassThroughIterPipe: return pipe.source_datapipe
        return pipe
    
    @patch
    def add_cbs_after(self:_pipe,cbs):
        "Run every hook of every callback in `cbs` through `add_hooks_before` using the 'after' event and this pipe's upstream object as `base_pipe`."
        pipe = self
        if cbs is None or len(cbs)==0: return pipe
        # `fld` (the attribute name) is not used below.
        after_pipe,fld = after_pipes(pipe)
        
        for cb in cbs:
            for hook in cb.hooks():
                # In this instance, we want to add the hook if the event is `after_pipe`
                # So if after_pipe->pipe,
                # we add `hook` before `pipe` which ends up also being after `after_pipe`
                # So the result is: `after_pipe->hook_results->pipe`
                pipe = add_hooks_before(pipe,hook,base_pipe=after_pipe,event_key='after')
        # A `PassThroughIterPipe` is unwrapped before returning.
        if pipe.__class__==PassThroughIterPipe: return pipe.source_datapipe
        return pipe

In [24]:
# export
def after_pipes(dp):
    "Return `(value, attr_name)` for the first known upstream attribute present on `dp`, or `(None,None)` when it has none."
    # Checked in priority order; the first attribute that exists wins.
    for attr in ('iterable','datapipe','source_datapipe','main_datapipe','datapipes'):
        if hasattr(dp,attr): return getattr(dp,attr),attr
    return None,None


In [23]:
# export
_supported_pipe_attrs = ['iterable','datapipe','source_datapipe','main_datapipe','datapipes']

def add_hooks_before(dp,cb_hook,base_pipe=None,event_key='before'):
    """Given `dp`, attach the pipes returned by `cb_hook` directly upstream of it.

    The hook's parameter defaults describe when it applies: the default of
    `event_key` ('before' or 'after') lists the pipe class(es) that
    `base_pipe.__class__` must equal, and `not_under` lists pipe class(es) that
    block the hook when found by `find_pipes(dp, ...)`, i.e. on `dp` itself or
    farther up the pipeline.

    Args:
        dp: pipe whose upstream attribute (first match in `_supported_pipe_attrs`) is rewired.
        cb_hook: bound callback method; calling it returns the pipe classes to insert.
        base_pipe: object whose class is compared against the hook's event default.
        event_key: name of the hook parameter holding the target class(es).

    Returns:
        `dp`, mutated in place when the hook applied.

    Raises:
        ValueError: if the hook applies but `dp` has none of `_supported_pipe_attrs`.
    """
    events = {k:v.default for k,v in inspect.signature(cb_hook).parameters.items()}

    # `.get` so that a hook declaring only some of the allowed parameters does not raise KeyError.
    blocked = events.get('not_under')
    if blocked is not None:
        for not_under_pipe in L(blocked):
            # Skip the hook when a blocking pipe IS present. Hooks name classes, so match on
            # the class; identity is kept as well for callers that pass a pipe instance.
            if find_pipes(dp,lambda o:o.__class__ is not_under_pipe or o is not_under_pipe):
                return dp
    targets = events.get(event_key)
    if targets is not None:
        for pipe in L(targets):
            if pipe==base_pipe.__class__:
                # Each returned class wraps whatever is currently upstream of `dp`,
                # so later classes end up closer to `dp`.
                for cb_dp in cb_hook():
                    for attr in _supported_pipe_attrs:
                        if not hasattr(dp,attr): continue
                        upstream = getattr(dp,attr)
                        if attr=='datapipes': setattr(dp,attr,tuple(cb_dp(_dp) for _dp in upstream))
                        else:                 setattr(dp,attr,cb_dp(upstream))
                        break
                    else:
                        raise ValueError(f'Given {cb_hook}, tried adding {cb_dp} to {dp} (base pipe: {base_pipe}) '
                                         f'but it has none of the expected attrs: {_supported_pipe_attrs}')
    return dp


> Note: In order for hooks to work correctly, the base pipeline has to already be constructed.

In [25]:
class TestCallback(Callback):
    "Demo callback with several hooks per target; this definition replaces the `TestCallback` defined earlier in the notebook."
    
    # Takes no parameters besides `self`, so `Callback.hooks` does not report it.
    def not_a_hook(self): return 'this should not be processed as a hook!'
    
    # Two separate hooks that both target `before=dp.iter.Batcher`.
    def add_one(self,before=dp.iter.Batcher,after=None,not_under=None) -> List[dp.iter.IterDataPipe]:
        return L(OneAdder)
    def add_one2(self,before=dp.iter.Batcher,after=None,not_under=None) -> List[dp.iter.IterDataPipe]:
        return L(PointZeroOne)

    # Two separate hooks that both target `after=dp.iter.IterableWrapper`.
    def add_point_zero_five(self,before=None,after=dp.iter.IterableWrapper,not_under=None) -> List[dp.iter.IterDataPipe]:
        return L(PointZeroFiveAdder)    
    def add_point_zero_1(self,before=None,after=dp.iter.IterableWrapper,not_under=None) -> List[dp.iter.IterDataPipe]:
        return L(PointZeroOne)
    # Two hooks targeting `after=dp.iter.Batcher`, each returning two pipe classes.
    def add_point5_and_point1_batch(self,before=None,after=dp.iter.Batcher,not_under=None) -> List[dp.iter.IterDataPipe]:
        return L(Point5Adder,Point1Adder)
    def add_point5_and_point1_batch_2(self,before=None,after=dp.iter.Batcher,not_under=None) -> List[dp.iter.IterDataPipe]:
        return L(Point5Adder,Point1Adder)
    

In [26]:
_logger.setLevel('INFO')

In [27]:
# export
class PassThroughIterPipe(dp.iter.IterDataPipe):
    "Pipe that yields the items of `source_datapipe` unchanged."
    def __init__(self,source_datapipe): self.source_datapipe = source_datapipe
    # Returns a generator that re-yields every upstream item as-is.
    def __iter__(self): return (o for o in self.source_datapipe)

Before alg:
    
    new_before_pipe = some_pipe
    specific_pipe_cls = SomePipe
    
    for pipe in all_pipes:
        if specific_pipe_cls==pipe:
            pipe.attr = new_before_pipe(pipe.attr) 

This will insert `new_before_pipe` between two existing pipes if `pipe` is the pipe we want to insert `new_before_pipe` before (determined by `specific_pipe_cls`). So when `pipe` is iterated through, `new_before_pipe` will execute before `pipe`, but after `pipe.attr`.

After alg:
    
    new_before_pipe = some_pipe
    specific_pipe_cls = SomePipe
    for pipe in all_pipes:
        for attr in pipe.attr:
            if specific_pipe_cls==attr:
                pipe.attr = new_before_pipe(pipe.attr) 

This will insert `new_before_pipe` between two existing pipes if an `attr` in `pipe.attr` is the pipe we want to insert `new_before_pipe` after (determined by `specific_pipe_cls`). So when `pipe` is iterated through, `new_before_pipe` will execute after `attr`.



In [31]:
# export
def add_cbs_to_pipes(pipe,cbs):
    """Apply the hooks of `cbs` to every pipe found from `pipe`: first the 'after' hooks, then the 'before' hooks.

    Each pass wraps the current head in a `PassThroughIterPipe`, lists the pipes
    with `find_pipes`, and visits them in reverse order of that list. `pipe` is
    rebound to each call's result, so the value returned comes from the last pipe visited.
    """
    after_targets = find_pipes(PassThroughIterPipe(pipe),lambda o:True)
    for stage in reversed(after_targets):
        pipe = stage.add_cbs_after(cbs)

    # Re-scan: the first pass may have inserted new pipes.
    before_targets = find_pipes(PassThroughIterPipe(pipe),lambda o:True)
    for stage in reversed(before_targets):
        pipe = stage.add_cbs_before(cbs)
    return pipe

In [32]:
cb = TestCallback()

pipe = dp.iter.IterableWrapper([1,2,3,4,5,6])
pipe = dp.iter.Batcher(pipe,batch_size=2)
# The base pipeline is fully constructed...
# for _pipe in reversed(find_pipes(pipe,lambda o:True)): pipe = _pipe.add_cbs_before(L(cb))

pipe = add_cbs_to_pipes(pipe,L(cb))

list(pipe),pipe

adding 1.06
adding 2.0599999999999996
adding 3.0599999999999996
adding 4.06
adding 5.06
adding 6.06


([[3.27, 4.27],
  [5.269999999999999, 6.269999999999999],
  [7.269999999999999, 8.269999999999998]],
 Point1Adder)

In [33]:
find_pipes(pipe,lambda o:True)

[Point1Adder,
 Point5Adder,
 Point1Adder,
 Point5Adder,
 BatcherIterDataPipe,
 PointZeroOne,
 OneAdder,
 PointZeroOne,
 PointZeroFiveAdder,
 IterableWrapperIterDataPipe]

In [34]:
# hide
from fastcore.imports import in_colab

# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.export import *
    from nbdev.export2html import *
    from nbdev.cli import *
    make_readme()
    notebook2script(silent=True)

converting /home/fastrl_user/fastrl/nbs/index.ipynb to README.md
