In [1]:
#hide
#skip
# Turn off jedi-based completion and turn on greedy completion for this IPython session
%config Completer.use_jedi = False
%config IPCompleter.greedy=True
# upgrade fastrl on colab
# The `[ -e /content ]` test gates everything after it: the installs only run when /content exists
! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \
                     apt-get install -y xvfb python-opengl > /dev/null 2>&1 
# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS

In [2]:
# hide
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.showdoc import *
    from nbdev.imports import *
    # NOTE(review): `os` is not imported in this cell; presumably it arrives via the star imports above — confirm
    # The environment sanity checks are skipped when the IN_TEST environment variable is set
    if not os.environ.get("IN_TEST", None):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON
else:
    # Virtual display is needed for colab
    from pyvirtualdisplay import Display
    display = Display(visible=0, size=(400, 300))
    display.start()

In [3]:
# default_exp fastai.data.loop.core

In [4]:
# export
# Python native modules
import os
import logging
# Third party libs
from fastcore.all import *
import torchdata.datapipes as dp
from torch.utils.data.graph import traverse
from torchdata.datapipes import functional_datapipe
# Local modules


# `getLogger()` without a name returns the root logger
_logger = logging.getLogger()

# Loop
> Customizable loop API for fastrl.

So we need loops within loops, possibly. Or do they need to be loops? Maybe they just need to 
be sections? Do we even need sections? I wonder if we can leverage torch data as more 
of a base... I want to see how far we can get with this...

We also need to think about whether this is always iterable or whether we can mix and match
map vs. iter...






In [5]:
# export
class Callback():
    "A list of data pipes that have an associated job."
    # Datapipe classes this callback applies to (matched against `obj.__class__`).
    call_on = L()
    # Datapipe classes that disqualify this callback when found in a pipe's graph.
    exclude_under = L()
    # When True, `default_constructor` attaches a `copy` of the callback instead of the callback itself.
    do_copy = False
    # Class-level default only; `set_parents` gives every instance its own list.
    immediate_parents = L()
    # Set by `default_constructor` to the outermost datapipe it was called with.
    root_parent = None
    # Datapipe constructors that `attach_callbacks` wraps around a matching pipe.
    pipes = L()
    
    def set_parents(self,immediate_parent):
        "Record `immediate_parent` on this instance only; `None` is ignored."
        if immediate_parent is None: return
        # Appending to the class-level `L` would leak parents into every other
        # callback (and subclass), so rebind a per-instance list on first use.
        # `L(list(...))` is deliberate: `L(an_L)` hands back the same object.
        if 'immediate_parents' not in vars(self):
            self.immediate_parents = L(list(self.immediate_parents))
        self.immediate_parents.append(immediate_parent)

In [6]:
# export
def filter_call_on_cbs(obj, cbs):
    "Return, as a tuple, the callbacks in `cbs` whose `call_on` contains the class of `obj`."
    matched = []
    for cb in cbs:
        if obj.__class__ in cb.call_on: matched.append(cb)
    return tuple(matched)

In [7]:
# export
@patch
def attach_callbacks(self:dp.map.MapDataPipe,cbs):
    "Wrap this map pipe in every entry of `pipes` of each callback in `cbs` that targets its class; return the outermost pipe."
    wrapped = self
    for callback in filter_call_on_cbs(self,cbs):
        # Each entry of `callback.pipes` is called with the pipe built so far
        for pipe_cls in callback.pipes:
            wrapped = pipe_cls(wrapped)
    return wrapped

@patch
def attach_callbacks(self:dp.iter.IterDataPipe,cbs):
    "Wrap this iter pipe in every entry of `pipes` of each callback in `cbs` that targets its class; return the outermost pipe."
    wrapped = self
    for callback in filter_call_on_cbs(self,cbs):
        # Each entry of `callback.pipes` is called with the pipe built so far
        for pipe_cls in callback.pipes:
            wrapped = pipe_cls(wrapped)
    return wrapped
        

In [8]:
# export
# Class-level defaults so the patched `__repr__`/`__str__` can always read `self.callbacks`;
# `set_cbs` later assigns a per-instance list on each pipe it processes.
dp.map.MapDataPipe.callbacks = L()
dp.iter.IterDataPipe.callbacks = L()

@patch
def __repr__(self:dp.map.MapDataPipe):
    "Use `repr_hook` when one is set; otherwise the class qualname, followed by the callbacks when there are any."
    hook = self.repr_hook
    if hook is not None: return hook(self)
    # Show the class name rather than the default `<... object at 0x...>` form
    label = str(self.__class__.__qualname__)
    attached = self.callbacks
    return label + str(attached) if attached else label

@patch
def __str__(self:dp.map.MapDataPipe):
    "Use `str_hook` when one is set; otherwise the class qualname, followed by the callbacks when there are any."
    hook = self.str_hook
    if hook is not None: return hook(self)
    # Show the class name rather than the default `<... object at 0x...>` form
    label = str(self.__class__.__qualname__)
    attached = self.callbacks
    return label + str(attached) if attached else label

@patch
def __repr__(self:dp.iter.IterDataPipe):
    "Use `repr_hook` when one is set; otherwise the class qualname, followed by the callbacks when there are any."
    hook = self.repr_hook
    if hook is not None: return hook(self)
    # Show the class name rather than the default `<... object at 0x...>` form
    label = str(self.__class__.__qualname__)
    attached = self.callbacks
    return label + str(attached) if attached else label

@patch
def __str__(self:dp.iter.IterDataPipe):
    "Use `str_hook` when one is set; otherwise the class qualname, followed by the callbacks when there are any."
    hook = self.str_hook
    if hook is not None: return hook(self)
    # Show the class name rather than the default `<... object at 0x...>` form
    label = str(self.__class__.__qualname__)
    attached = self.callbacks
    return label + str(attached) if attached else label

In [9]:
# export
def set_cbs(loop,cbs):
    "Attach `cbs` to `loop` (instantiating any passed as classes) and collect their `{stage}_{loop class name}` methods into `loop.cb_{stage}`."
    suffix = loop.__class__.__name__.lower()
    instances = []
    for cb in cbs:
        # Callback classes are instantiated with no arguments; instances are used as-is
        instances.append(cb() if isinstance(cb, type) else cb)
    loop.callbacks = instances
    for cb in loop.callbacks:
        cb.set_parents(loop)
    for stage in ('before','on','after','failed','finally'):
        hook_name = f'{stage}_{suffix}'
        hooks = L(getattr(cb,hook_name) for cb in loop.callbacks if hasattr(cb,hook_name))
        setattr(loop,f'cb_{stage}', hooks)

In [10]:
# export
def filter_exclude_under_cbs(
    pipe:Union[dp.map.MapDataPipe,dp.iter.IterDataPipe], 
    cbs:List[Callback]
):
    "Drop every callback whose `exclude_under` contains the class of `pipe` or of any pipe reached by recursing through `traverse(pipe)`; return the survivors as a tuple."
    kept = []
    for cb in cbs:
        if pipe.__class__ in cb.exclude_under: continue
        kept.append(cb)
    cbs = tuple(kept)
    # Recurse into each child so exclusions found deeper in the graph also apply
    for children in traverse(pipe).values():
        for child in children.keys():
            cbs = filter_exclude_under_cbs(child,cbs)
    return cbs

Below is a simple example of a custom training setup, demonstrating a naive implementation. We will be redefining these pieces later.

In [11]:
# export
from functools import wraps

def soft_compose(loop,attr):
    """Return one callable that runs every hook stored in `loop.<attr>`.

    A missing attribute is treated as "no hooks". The returned callable:

    - called with no positional arguments, invokes each hook in order
      (forwarding keyword arguments) and returns the last hook's result,
      or `None` when there are no hooks;
    - called with positional arguments, chains the hooks: the first
      argument is threaded through each hook in turn and the final value
      is returned.

    The hook-running code in this file always uses the zero-argument form.
    fastcore's `compose` returns a function with a required first argument
    once two or more functions are composed, so the zero-argument call failed
    as soon as a second callback registered the same hook.
    """
    hooks = tuple(getattr(loop,attr,()))
    def _run(*args,**kwargs):
        if args:
            value,*rest = args
            for hook in hooks: value = hook(value,*rest,**kwargs)
            return value
        result = None
        for hook in hooks: result = hook(**kwargs)
        return result
    return _run

def callback_iter(f):
    """Wrap the `__iter__` generator `f` so the loop's callback hooks fire around iteration.

    Order of events for each full pass over `f(self)`:

    1. `cb_before` once, before the first record is requested.
    2. For each record: it is stored on `self.x`, `cb_on` runs, the record is
       yielded, then `self.x` is removed once the consumer asks for the next one.
    3. `cb_after` once, after the source is exhausted.
    4. If an `Exception` escapes: it is re-raised when the loop has no
       `cb_failed` hooks, otherwise `cb_failed` runs and the exception is swallowed.
    5. `cb_finally` always runs last.
    """
    @wraps(f)
    def _inner(self):
        try:
            # "before" hooks open the loop; "after" hooks run once the source is exhausted
            soft_compose(self,'cb_before')()
            for record in f(self):
                # Expose the current record to the `cb_on` hooks via `self.x`
                self.x = record
                soft_compose(self,'cb_on')()
                yield record
                del self.x
            soft_compose(self,'cb_after')()
        except Exception:
            if len(getattr(self,'cb_failed',L()))==0: raise
            else:                                     soft_compose(self,'cb_failed')()
        finally:
            soft_compose(self,'cb_finally')()
    return _inner

def callback_getitem(f):
    """Wrap the `__getitem__` method `f` so the loop's callback hooks fire around each lookup.

    Order of events for each call:

    1. `cb_before`, then `f(self, index)` whose result is stored on `self.x`.
    2. `cb_on`, after which the stored value is returned.
    3. If an `Exception` escapes: it is re-raised when the loop has no
       `cb_failed` hooks, otherwise `cb_failed` runs and `None` is returned.
    4. Always: `self.x` is removed if it was set, `cb_after` runs when no
       exception occurred, and `cb_finally` runs last.
    """
    @wraps(f)
    def _inner(self, index):
        ex_occured = False
        try:
            soft_compose(self,'cb_before')()
            self.x = f(self, index)
            soft_compose(self,'cb_on')()
            return self.x
        except Exception:
            ex_occured = True
            if len(getattr(self,'cb_failed',L()))==0: raise
            else:                                     soft_compose(self,'cb_failed')()
        finally:
            # `self.x` is never assigned when `cb_before` or `f` raised; an unguarded
            # `del` would raise AttributeError here and mask the original exception.
            try: del self.x
            except AttributeError: pass
            if not ex_occured: 
                soft_compose(self,'cb_after')()
            soft_compose(self,'cb_finally')()
    return _inner


In [12]:
# export
# Monkeypatch both Batcher datapipes so their element access runs through the callback wrappers
dp.map.Batcher.__getitem__ = callback_getitem(dp.map.Batcher.__getitem__)
dp.iter.Batcher.__iter__ = callback_iter(dp.iter.Batcher.__iter__)

In [13]:
# Smoke test for the patched iter Batcher: five elements in batches of two
pipe = dp.iter.IterableWrapper([1,2,3,4,5])
batch_dp = dp.iter.Batcher(pipe,2)

In [14]:
# Iterating goes through the `callback_iter`-wrapped `__iter__`; no callbacks are attached here
list(batch_dp)

[[1, 2], [3, 4], [5]]

In [15]:
from fastrl.fastai.data.pipes.map.demux import *
from fastrl.fastai.data.pipes.map.mux import *

class Iterationer(dp.iter.IterDataPipe):
    "Demo pipe that emits the constant `0.5` once for every element of `source_datapipe`."

    def __init__(self, source_datapipe, **kwargs) -> None:
        self.source_datapipe = source_datapipe
        self.kwargs = kwargs

    @callback_iter
    def __iter__(self) -> Iterator[T_co]:
        # The upstream element itself is discarded; only the count matters
        for _ in self.source_datapipe:
            yield 0.5


class Trainer(dp.iter.IterDataPipe):
    "Demo pass-through pipe: yields each element of `source_datapipe` unchanged."

    def __init__(self, source_datapipe, **kwargs) -> None:
        self.source_datapipe = source_datapipe
        self.kwargs = kwargs

    @callback_iter
    def __iter__(self) -> Iterator[T_co]:
        yield from self.source_datapipe
                
class Validater(dp.iter.IterDataPipe):
    "Demo pass-through pipe: yields each element of `source_datapipe` unchanged."

    def __init__(self, source_datapipe, **kwargs) -> None:
        self.source_datapipe = source_datapipe
        self.kwargs = kwargs

    @callback_iter
    def __iter__(self) -> Iterator[T_co]:
        yield from self.source_datapipe
               
class Epocher(dp.iter.IterDataPipe):
    "Demo pipe that zips a tuple of source pipes and yields one tuple of elements per step."

    def __init__(self, source_datapipes:tuple, **kwargs) -> None:
        # Only an actual `tuple` is accepted for the sources
        test_eq(type(source_datapipes), tuple)
        self.source_datapipes = source_datapipes
        self.kwargs = kwargs

    @callback_iter
    def __iter__(self) -> Iterator[T_co]:
        yield from zip(*self.source_datapipes)
                
class Fitter(dp.iter.IterDataPipe):
    "Demo root pipe: builds parallel train/valid chains (pass-through -> batches of 2 -> `Iterationer`) over one iterable and zips them."

    def __init__(self,iterable):
        # Both chains read from the same `iterable`
        self.trainer = Trainer(iterable)
        self.validater = Validater(iterable)
        self.train_b = dp.iter.Batcher(self.trainer,2)
        self.valid_b = dp.iter.Batcher(self.validater,2)
        self.train_it = Iterationer(self.train_b)
        self.valid_it = Iterationer(self.valid_b)
        self.epocher = dp.iter.Zipper(self.train_it,self.valid_it)
        self.source_datapipe = self.epocher

    @callback_iter
    def __iter__(self):
        yield from self.source_datapipe
        
class TrainerCallback(Callback):
    "Demo callback for `Trainer` pipes: prints a marker from its `on_trainer` hook."
    call_on = L(Trainer)
    exclude_under = L()

    def on_trainer(self):
        print('on_trainer')
    
class IterCallback(Callback):
    "Demo callback for `Iterationer` pipes that is excluded when a `Validater` is in the pipe's graph; defines no hooks."
    call_on=L(Iterationer)
    exclude_under=L(Validater)


In [16]:
# export
def default_constructor(
    datapipe:Union[dp.map.MapDataPipe,dp.iter.IterDataPipe,Dict], 
    cbs:List[Callback],
    _outer=True # Only used to differentiate between recursive calls and initial.
):
    "Walk the graph of `datapipe` and attach to each pipe the callbacks in `cbs` that target it and are not excluded by its graph."
    if _outer:
        # Only the initial call records the root pipe on the callbacks
        for cb in cbs: cb.root_parent=datapipe
    # Recursive calls receive an already-traversed sub-graph (a dict)
    if isinstance(datapipe,dict): graph = datapipe
    else:                         graph = traverse(datapipe)
    
    for loop,children in graph.items():
        matching = filter_call_on_cbs(loop,cbs)
        _logger.info('Given loop: %s, found callbacks: %s',loop.__class__,matching)
        allowed = filter_exclude_under_cbs(loop,matching)
        _logger.info('Given loop: %s, filtered callbacks: %s',loop.__class__,allowed)
        attached = []
        for cb in allowed:
            attached.append(copy(cb) if cb.do_copy else cb)
        set_cbs(loop,attached)
        # Descend with the full, unfiltered `cbs` so children are matched independently
        if children: default_constructor(children, cbs, _outer=False)
        

In [17]:
base_pipe = Fitter([1,2,3,4,5,6])

# Callback classes (not instances) are passed; `set_cbs` instantiates them for each pipe
default_constructor(
    base_pipe,
    [TrainerCallback,IterCallback]
)
# Show the pipe graph; the patched `__str__` appends attached callbacks to each pipe's class name
str(traverse(base_pipe))


'{Fitter: {Trainer[<__main__.TrainerCallback object at 0x7fa2d1001050>]: {}, Validater: {}, BatcherIterDataPipe: {Trainer[<__main__.TrainerCallback object at 0x7fa2d1001050>]: {}}, BatcherIterDataPipe: {Validater: {}}, Iterationer[<__main__.IterCallback object at 0x7fa2d1003610>]: {BatcherIterDataPipe: {Trainer[<__main__.TrainerCallback object at 0x7fa2d1001050>]: {}}}, Iterationer: {BatcherIterDataPipe: {Validater: {}}}, ZipperIterDataPipe: {Iterationer[<__main__.IterCallback object at 0x7fa2d1003610>]: {BatcherIterDataPipe: {Trainer[<__main__.TrainerCallback object at 0x7fa2d1001050>]: {}}}, Iterationer: {BatcherIterDataPipe: {Validater: {}}}}}}'

In [18]:
# Run the full loop; `TrainerCallback.on_trainer` prints as records pass through `Trainer`
list(base_pipe)

on_trainer
on_trainer
on_trainer
on_trainer
on_trainer
on_trainer


[(0.5, 0.5), (0.5, 0.5), (0.5, 0.5)]

In [19]:
# hide
from fastcore.imports import in_colab

# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.export import *
    from nbdev.export2html import *
    from nbdev.cli import *
    # NOTE(review): presumably regenerates README.md from index.ipynb (see the cell output) — confirm
    make_readme()
    # NOTE(review): presumably writes the `# export` cells to the module named by `# default_exp` — confirm
    notebook2script(silent=True)

converting /home/fastrl_user/fastrl/nbs/index.ipynb to README.md
