In [1]:
#hide
#skip
%config Completer.use_jedi = False
%config IPCompleter.greedy=True
# upgrade fastrl on colab
! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \
                     apt-get install -y xvfb python-opengl > /dev/null 2>&1 
# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS

In [2]:
# hide
import os
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.showdoc import *
    from nbdev.imports import *
    # Skip the notebook-environment sanity checks when running under the test runner.
    if not os.environ.get("IN_TEST", None):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON
else:
    # Virtual display is needed for colab (there is no physical screen to render to).
    # Bound to `virtual_display` rather than `display` so IPython's built-in `display()`
    # stays usable in later cells.
    from pyvirtualdisplay import Display
    virtual_display = Display(visible=0, size=(400, 300))
    virtual_display.start()

In [3]:
# default_exp fastai.data.pipes.core

In [4]:
# export
# Python native modules
import os
import logging
# Third party libs
from fastcore.all import *
import torchdata.datapipes as dp
from torch.utils.data.graph import traverse
from torchdata.datapipes import functional_datapipe
# Local modules


_logger = logging.getLogger()

# Pipes Core
> Callback + DataPipe support for highly flexible looping.

> Notes: about revising this one more time...

I think that `Callback.call_on` might be too limited. We want a callback that can add datapipes at different points in the loop.

I think we could revise it to look something like this:
    
```python

class Recorder(Callback):
    def initialize(self,on=None,before=Epocher,after=None):
        class initialized_recorder_pipe(dp.iter.DataPipe):
            def __iter__(self):
                self.mbar = master_bar(self.learn.epochs)
                yield from self.source_datapipe
        return initialized_recorder_pipe
        
    def accumulate(self,on=None,before=None,after=Predictor):
        class accumulated_recorder_pipe(dp.iter.DataPipe):
            def __iter__(self):
                for o in self.source_datapipe:
                    self.smooth_loss.append(self.learn.loss)
                    yield o   
        return accumulated_recorder_pipe
        
    def batch_update(self,on=Batcher,before=None,after=None):
        class batch_update_recorder_pipe(dp.iter.DataPipe):
            def __iter__(self):
                for o in self.source_datapipe:
                    self.pbar.update(self.learn.nbatch)
                    yield o   
        return batch_update_recorder_pipe
        
    def epoch_update(self,on=Epocher,before=None,after=None):
        class epoch_update_recorder_pipe(dp.iter.DataPipe):
            def __iter__(self):
                for o in self.source_datapipe:
                    self.mbar.update(self.learn.epoch)
                    yield o   
        return epoch_update_recorder_pipe

```

In [5]:
# export
class Callback():
    "A list of data pipes that have an associated job."
    # Class-level defaults; subclasses override them (e.g. `call_on=L(Trainer)`).
    call_on = L()            # exact pipe classes that `add_cbs` attaches this callback to
    exclude_under = L()      # pipe classes that, when found at or upstream of a pipe, suppress this callback
    do_copy = False          # NOTE(review): not read anywhere in this notebook section
    immediate_parents = L()  # shared default only; `set_parents` gives each object its own copy
    root_parent = None       # NOTE(review): not read anywhere in this notebook section
    pipes = L()              # datapipe classes that `add_cbs` wraps around the matched pipe, in order

    @property
    def name(self):
        "Name of the `Callback`, snake-cased via `class2attr` and with '*Callback*' removed"
        return class2attr(self, 'Callback')

    def init_pipes(self):
        "Hook for subclasses; does nothing by default."
        pass

    def set_parents(self,immediate_parent):
        "Record `immediate_parent` (ignored when `None`) in this callback's `immediate_parents`."
        if immediate_parent is None: return
        # `immediate_parents` starts out as one `L` shared through the whole class hierarchy.
        # Appending to it directly would leak the parent into every other callback, so copy
        # it onto this object on first write.
        if 'immediate_parents' not in vars(self):
            self.immediate_parents = L(list(self.immediate_parents))
        self.immediate_parents.append(immediate_parent)

In [6]:
# export
dp.map.MapDataPipe.callbacks = L()
dp.iter.IterDataPipe.callbacks = L()

def _pipe_repr(pipe, hook):
    "Shared body of the patched `__repr__`/`__str__`: defer to `hook` if set, else the class name plus any `callbacks`."
    if hook is not None:
        return hook(pipe)
    # Instead of showing <torch. ... .SomeDataPipe object at 0x.....>, return the class name
    str_rep = str(pipe.__class__.__qualname__)
    if pipe.callbacks: return str_rep + str(pipe.callbacks)
    return str_rep

# Same rendering for map- and iter-style datapipes; only the hook attribute differs
# between `repr` (`repr_hook`) and `str` (`str_hook`).
@patch
def __repr__(self:dp.map.MapDataPipe): return _pipe_repr(self, self.repr_hook)

@patch
def __str__(self:dp.map.MapDataPipe): return _pipe_repr(self, self.str_hook)

@patch
def __repr__(self:dp.iter.IterDataPipe): return _pipe_repr(self, self.repr_hook)

@patch
def __str__(self:dp.iter.IterDataPipe): return _pipe_repr(self, self.str_hook)

In [7]:
# export
def filter_call_on_cbs(obj, cbs):
    "Return, as a tuple, the callbacks in `cbs` whose `call_on` contains the exact class of `obj` (subclasses do not match)."
    matches = []
    for cb in cbs:
        if obj.__class__ in cb.call_on:
            matches.append(cb)
    return tuple(matches)

In [8]:
# export
def filter_exclude_under_cbs(
    pipe:Union[dp.map.MapDataPipe,dp.iter.IterDataPipe], 
    cbs:List[Callback]
):
    """Drop every callback whose `exclude_under` lists the class of `pipe` or of any datapipe
    reachable from it in `traverse(pipe, only_datapipe=True)`.

    Returns the surviving callbacks as a tuple, in their original order.
    """
    def _prune(graph, keep):
        # `graph` maps each datapipe to a dict of the datapipes it reads from.
        for node, sources in graph.items():
            if not keep: break  # nothing left to filter out
            keep = tuple(cb for cb in keep if node.__class__ not in cb.exclude_under)
            keep = _prune(sources, keep)
        return keep
    # Traverse once and walk the resulting graph. Calling `traverse` again on every child
    # would rebuild each sub-graph once per level of depth.
    return _prune(traverse(pipe,only_datapipe=True), tuple(cbs))

Below is a simple example of a custom training setup that demonstrates a naive implementation. We will be re-defining these classes later.

In [9]:
from fastrl.fastai.data.pipes.map.demux import *
from fastrl.fastai.data.pipes.map.mux import *

class Iterationer(dp.iter.IterDataPipe):
    "Demo stage: emits the placeholder value `0.5` once for every item pulled from `source_datapipe`."
    def __init__(self, source_datapipe, **kwargs) -> None:
        self.source_datapipe, self.kwargs = source_datapipe, kwargs  # kwargs kept but unused

    def __iter__(self) -> Iterator[T_co]:
        for _ in self.source_datapipe:
            yield 0.5

class Trainer(dp.iter.IterDataPipe):
    "Demo training stage: passes every element of `source_datapipe` through unchanged."
    def __init__(self, source_datapipe, **kwargs) -> None:
        self.source_datapipe, self.kwargs = source_datapipe, kwargs  # kwargs kept but unused

    def __iter__(self) -> Iterator[T_co]:
        for item in self.source_datapipe:
            yield item
                
class Validater(dp.iter.IterDataPipe):
    "Demo validation stage: passes every element of `source_datapipe` through unchanged."
    def __init__(self, source_datapipe, **kwargs) -> None:
        self.source_datapipe, self.kwargs = source_datapipe, kwargs  # kwargs kept but unused

    def __iter__(self) -> Iterator[T_co]:
        for item in self.source_datapipe:
            yield item
               
class Epocher(dp.iter.IterDataPipe):
    "Demo stage: zips a tuple of datapipes together, yielding one tuple of items per step."
    def __init__(self, source_datapipes:tuple, **kwargs) -> None:
        # fastcore assertion: `type(source_datapipes)` must equal `tuple` exactly
        test_eq(type(source_datapipes), tuple)
        self.source_datapipes, self.kwargs = source_datapipes, kwargs

    def __iter__(self) -> Iterator[T_co]:
        yield from zip(*self.source_datapipes)
                
class TrainerShower(dp.iter.IterDataPipe):
    "Pass-through demo pipe that prints `on_trainer` before yielding each element."
    def __init__(self, source_datapipe, **kwargs) -> None:
        # extra keyword arguments are accepted but ignored
        self.source_datapipe = source_datapipe

    def __iter__(self):
        for item in self.source_datapipe:
            print('on_trainer')
            yield item
            
class IterShower(dp.iter.IterDataPipe):
    "Pass-through demo pipe: prints `show init` when built and `on_iter` before yielding each element."
    def __init__(self, source_datapipe, **kwargs) -> None:
        print('show init')  # announce construction; extra kwargs are ignored
        self.source_datapipe = source_datapipe

    def __iter__(self):
        for item in self.source_datapipe:
            print('on_iter')
            yield item
            
class TrainerCallback(Callback):
    "Demo callback: wraps `Trainer` pipes in `TrainerShower`, unless a `Validater` is at or upstream of the pipe."
    call_on, exclude_under, pipes = L(Trainer), L(Validater), L(TrainerShower)
    
class IterCallback(Callback):
    "Demo callback: wraps `Iterationer` pipes in `IterShower`, unless a `Validater` is at or upstream of the pipe."
    call_on, exclude_under, pipes = L(Iterationer), L(Validater), L(IterShower)

class Fitter(dp.iter.IterDataPipe):
    """Demo fitting loop: a train (`Trainer`) and a valid (`Validater`) branch, each batched by 2
    and fed through `Iterationer`, zipped together. `cbs` is offered to every pipe via `add_cbs`."""
    def __init__(self,iterable,cbs):
        def _hook(pipe): return pipe.add_cbs(cbs)
        # Build both branches one stage at a time: construct the stage, then attach callbacks.
        stage_pipes = [_hook(stage(iterable)) for stage in (Trainer,Validater)]
        batch_pipes = [dp.iter.Batcher(p,batch_size=2) for p in stage_pipes]
        batch_pipes = [_hook(p) for p in batch_pipes]
        iter_pipes = [Iterationer(p) for p in batch_pipes]
        iter_pipes = [_hook(p) for p in iter_pipes]
        self.source_datapipe = _hook(dp.iter.Zipper(*iter_pipes))

    def __iter__(self):
        for zipped in self.source_datapipe:
            yield zipped
        

In [10]:
# export
def add_cbs(self,cbs):
    """Wrap the datapipe `self` in the `pipes` of every callback in `cbs` that applies to it.

    A callback applies when the exact class of `self` is in its `call_on` and the class of
    neither `self` nor any datapipe upstream of it is in its `exclude_under`. Each class in
    `cb.pipes` is called with the current pipe as its only argument, so later entries wrap
    earlier ones. Returns `self` unchanged when `cbs` is `None`/empty or nothing applies.

    Attached to both `MapDataPipe` and `IterDataPipe` below via `patch_to`, since one
    function serves both classes.
    """
    if cbs is None or len(cbs)==0: return self
    cbs = filter_call_on_cbs(self,cbs)
    # Skip the graph-traversing exclusion check when no callback targets this pipe at all.
    if not cbs: return self
    cbs = filter_exclude_under_cbs(self,cbs)
    pipe = self
    for cb in cbs:
        # `pipe_cls` rather than `dp`: a local `dp` would shadow the module-level
        # `torchdata.datapipes as dp` alias throughout this function.
        for pipe_cls in cb.pipes: pipe = pipe_cls(pipe)
    return pipe


patch_to(dp.map.MapDataPipe)(add_cbs)
patch_to(dp.iter.IterDataPipe)(add_cbs)
        

In [11]:
from pprint import pprint
# Build the demo pipeline with both callbacks and print its datapipe graph.
# Only the train branch gets the callbacks: `TrainerShower` wraps the train `Trainer` and
# `IterShower` wraps the train `Iterationer`. The valid branch (under `Validater`) is left untouched.
base_pipe = Fitter([1,2,3,4,5,6],[TrainerCallback,IterCallback])
pprint(traverse(base_pipe))

show init
{Fitter: {ZipperIterDataPipe: {IterShower: {Iterationer: {BatcherIterDataPipe: {TrainerShower: {Trainer: {}}}}},
                               Iterationer: {BatcherIterDataPipe: {Validater: {}}}}}}


In [12]:
# Run the whole pipeline. 'on_trainer' prints once per training element and 'on_iter' once
# per batch of 2. Each zipped step is (0.5, 0.5), one value from each branch's `Iterationer`.
list(base_pipe)

on_trainer
on_trainer
on_iter
on_trainer
on_trainer
on_iter
on_trainer
on_trainer
on_iter


[(0.5, 0.5), (0.5, 0.5), (0.5, 0.5)]

In [13]:
# hide
from fastcore.imports import in_colab

# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.export import *
    from nbdev.export2html import *
    from nbdev.cli import *
    # Regenerate README.md from index.ipynb, then write the notebooks' `# export` cells
    # out to the library modules (nbdev).
    make_readme()
    notebook2script(silent=True)

converting /home/fastrl_user/fastrl/nbs/index.ipynb to README.md
