In [93]:
#hide
#skip
# Completer settings: turn jedi off and turn greedy completion on.
%config Completer.use_jedi = False
%config IPCompleter.greedy=True
# upgrade fastrl on colab
# The `[ -e /content ]` test short-circuits the whole command when /content does not exist,
# so the install only runs where that directory is present (used here as the Colab marker).
! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \
                     apt-get install -y xvfb python-opengl > /dev/null 2>&1 
# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS

In [94]:
# hide
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.showdoc import *
    from nbdev.imports import *
    # NOTE(review): `os` is used here but `import os` only appears in a later cell;
    # presumably one of the star imports above provides it — confirm.
    # The sanity checks are skipped when the IN_TEST environment variable is set.
    if not os.environ.get("IN_TEST", None):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON
else:
    # Virtual display is needed for colab
    from pyvirtualdisplay import Display
    display = Display(visible=0, size=(400, 300))
    display.start()

In [95]:
# default_exp fastai.data.loop

In [96]:
# export
# Python native modules
import os
import logging
# Third party libs
from fastcore.all import *
import torchdata.datapipes as dp
from torch.utils.data.graph import traverse
# Local modules
from fastrl.fastai.data.pipes.demux import *
from fastrl.fastai.data.pipes.mux import *

_logger = logging.getLogger()

# Loop
> Customizable loop API for fastrl.

So we possibly need loops within loops. Or do they need to be loops at all? Maybe they 
just need to be sections? Do we even need sections? I wonder if we can leverage torchdata as more 
of a base... I want to see how far we can get with this...

We also need to think about whether this is always iterable, or whether we can mix and match
map vs. iter...






In [97]:
# export
class Callback():
    "Base class for objects whose `before_/on_/after_/failed_/finally_<loop>` methods are invoked by `Loop`."
    # Loop classes this callback should be attached to (checked by `Loop.filter_call_on_cbs`).
    call_on = L()
    # Loop classes that veto attachment: `filter_exclude_under_cbs` drops the callback from a
    # loop when that loop, or any pipe beneath it in the traversed graph, is one of these classes.
    exclude_under = L()
    # When True, `default_constructor` attaches `copy(cb)` instead of `cb` itself.
    do_copy = False
    # Class-level default only. It is never appended to: `set_parents` gives every
    # instance its own list the first time a parent is recorded.
    immediate_parents = L()
    # Assigned by `default_constructor` to the outermost `Loop` it was called with.
    root_parent = None
    
    def set_parents(self,immediate_parent):
        "Record `immediate_parent` as a loop this callback is attached to; `None` is ignored."
        if immediate_parent is None: return
        # Appending straight to the inherited class attribute would mutate a single list
        # shared by every callback of every subclass, so each callback would end up
        # listing the parents of all the others. Shadow it with a per-instance copy first.
        # `list(...)` forces a fresh underlying list before wrapping it in `L` again.
        if 'immediate_parents' not in vars(self):
            self.immediate_parents = L(list(self.immediate_parents))
        self.immediate_parents.append(immediate_parent)

In [98]:
# export
class Loop(dp.iter.IterDataPipe):
    "An `IterDataPipe` that fires named callback events around the records of its `__subiter__`."
    callbacks = L()
    
    def set_cbs(self,cbs): 
        "Store `cbs` on this loop (instantiating any given as classes) and register this loop as their parent."
        resolved = []
        for cb in cbs:
            resolved.append(cb() if isinstance(cb, type) else cb)
        self.callbacks = resolved
        for cb in self.callbacks: 
            cb.set_parents(self)
    
    def filter_call_on_cbs(self, cbs):
        "Return a tuple of the members of `cbs` whose `call_on` contains this loop's class."
        own_cls = self.__class__
        return tuple(filter(lambda cb: own_cls in cb.call_on, cbs))
    
    def __repr__(self): 
        return f'{self.__class__} {self.callbacks}'
    
    def handle_exeption(self,ex): 
        "Default error policy: re-raise the exception currently being handled."
        # Bare `raise` relies on being called from inside an `except` block.
        raise
    
    def __iter__(self):
        "Yield the records of `__subiter__`, firing `<stage>_<lowercased class name>` on every callback."
        loop_name = self.__class__.__name__.lower()
        
        def fire(stage):
            # Callbacks that do not define the event fall back to `noop`.
            for cb in self.callbacks: 
                getattr(cb, stage+'_'+loop_name, noop)()
        
        try:
            fire('before')
            for record in self.__subiter__():
                # `on_` fires once per record, just before the record is handed downstream.
                fire('on')
                yield record
            # Only reached when `__subiter__` is exhausted normally.
            fire('after')
        except Exception as e:
            fire('failed')
            self.handle_exeption(e)
        finally:
            fire('finally')

In [105]:
# export
def filter_exclude_under_cbs(
    pipe:Union[Loop,dp.iter.IterDataPipe], 
    cbs:List[Callback]
):
    "Return a tuple of the `cbs` not excluded by `pipe` or by any pipe beneath it in the `traverse` graph."
    # Drop callbacks that list this pipe's own class in `exclude_under`...
    kept = tuple(cb for cb in cbs if pipe.__class__ not in cb.exclude_under)
    # ...then let every child pipe (and, recursively, its children) veto the survivors.
    for children in traverse(pipe).values():
        for child in children:
            kept = filter_exclude_under_cbs(child, kept)
    return kept

In [100]:
class Iterationer(Loop):
    "Loop that emits the placeholder value `0.5` once for every element of `source_datapipe`."
        
    def __init__(self, source_datapipe, **kwargs) -> None:
        self.source_datapipe, self.kwargs = source_datapipe, kwargs
    
    
    def __subiter__(self) -> Iterator[T_co]:
        # The source elements are consumed but not used yet. The commented-out sketch
        # below is not executed and is kept only for reference:
        #
        # xb,yb = next(self.source_datapipe)
        # pred = self.kwargs['model'](xb)
        # loss_grad = self.kwargs['loss_grad']
        # opt = self.kwargs['opt']
        # loss_func = self.kwargs['loss_func']
        # if len(yb):
        #     loss_grad = self.loss_func(pred, *yb)
        #     loss = loss_grad.clone()
        # self('after_loss')
        # if not self.training or not len(yb): return
        # self('before_backward')
        # loss_grad.backward()
        # opt.step()
        # opt.zero_grad()
        # yield loss.detach()
        for _ in self.source_datapipe:
            yield 0.5

class Batcher(Loop):
    "Pass-through loop: re-yields every element of `source_datapipe` unchanged."
    def __init__(self, source_datapipe, **kwargs) -> None:
        self.source_datapipe, self.kwargs = source_datapipe, kwargs
    
    def __subiter__(self) -> Iterator[T_co]:
        yield from self.source_datapipe
        
class Trainer(Loop):
    "Pass-through loop: re-yields every element of `source_datapipe` unchanged."
        
    def __init__(self, source_datapipe, **kwargs) -> None:
        self.source_datapipe, self.kwargs = source_datapipe, kwargs
    
    def __subiter__(self) -> Iterator[T_co]:
        yield from self.source_datapipe
                
class Validater(Loop):
    "Pass-through loop: re-yields every element of `source_datapipe` unchanged."
        
    def __init__(self, source_datapipe, **kwargs) -> None:
        self.source_datapipe, self.kwargs = source_datapipe, kwargs
    
    def __subiter__(self) -> Iterator[T_co]:
        yield from self.source_datapipe
               
class Epocher(Loop):
    "Loop over several source datapipes in lockstep, yielding one tuple per step."
    def __init__(self, source_datapipes:tuple, **kwargs) -> None:
        # `test_eq` doubles as a runtime guard: the argument's type must be exactly `tuple`.
        test_eq(type(source_datapipes), tuple)
        self.source_datapipes, self.kwargs = source_datapipes, kwargs
    
    def __subiter__(self) -> Iterator[T_co]:
        # `zip` ends as soon as the shortest source is exhausted.
        yield from zip(*self.source_datapipes)
                
class Fitter(Loop):
    "Outermost demo loop: zips an `Iterationer(Batcher(Trainer))` branch with an `Iterationer(Batcher(Validater))` branch."
    def __init__(self,iterable):
        # Both branches read from the same `iterable`.
        trainer = Trainer(iterable)
        validater = Validater(iterable)
        train_b = Batcher(trainer)
        valid_b = Batcher(validater)
        train_it = Iterationer(train_b)
        valid_it = Iterationer(valid_b)
        # `epocher` and `source_datapipe` are two names for the same Zipper pipe.
        self.epocher = self.source_datapipe = dp.iter.Zipper(train_it,valid_it)
            
    def __subiter__(self):
        yield from self.source_datapipe

        
class TrainerCallback(Callback):
    "Demo callback attached to `Trainer` loops; prints a marker on every `on_trainer` event."
    call_on = L(Trainer)
    exclude_under = L()
    
    def on_trainer(self):
        print('on_trainer')
    
class IterCallback(Callback):
    "Demo callback attached to `Iterationer` loops, except those with a `Validater` beneath them."
    call_on = L(Iterationer)
    exclude_under = L(Validater)


In [101]:
# export
def default_constructor(
    datapipe:Union[Loop,Dict], 
    cbs:List[Callback],
    _outer=True
):
    "Walk `datapipe`'s graph and give every `Loop` in it the callbacks from `cbs` that apply to it."
    # Only the top-level call (not the recursive ones) marks the root loop on the callbacks.
    if _outer and isinstance(datapipe,Loop):
        for cb in cbs: 
            cb.root_parent = datapipe
    # Recursive calls pass the already-traversed sub-graph (a dict) straight through.
    graph = datapipe if isinstance(datapipe,dict) else traverse(datapipe)
    
    for pipe,children in graph.items():
        if isinstance(pipe,Loop): 
            # First keep callbacks that target this loop class, then drop those vetoed by `exclude_under`.
            candidates = pipe.filter_call_on_cbs(cbs)
            _logger.info('Given loop: %s, found callbacks: %s',pipe.__class__,candidates)
            survivors = filter_exclude_under_cbs(pipe,candidates)
            _logger.info('Given loop: %s, filtered callbacks: %s',pipe.__class__,survivors)
            pipe.set_cbs([copy(cb) if cb.do_copy else cb for cb in survivors])
        # Descend with the full `cbs` list; each nested loop does its own filtering.
        if children: 
            default_constructor(children, cbs, _outer=False)
        

In [102]:
# Build the nested demo pipeline over a six-element list.
base_pipe = Fitter([1,2,3,4,5,6])

# Attach the callbacks. They are passed as classes; `Loop.set_cbs` instantiates them.
default_constructor(
    base_pipe,
    [TrainerCallback,IterCallback]
)
# Show the resulting pipe graph; each `Loop` repr includes the callbacks attached to it.
str(traverse(base_pipe))


INFO:root:Given loop: <class '__main__.Fitter'>, found callbacks: ()
INFO:root:Given loop: <class '__main__.Fitter'>, filtered callbacks: ()
INFO:root:Given loop: <class '__main__.Iterationer'>, found callbacks: (<class '__main__.IterCallback'>,)
INFO:root:Given loop: <class '__main__.Iterationer'>, filtered callbacks: (<class '__main__.IterCallback'>,)
INFO:root:Given loop: <class '__main__.Batcher'>, found callbacks: ()
INFO:root:Given loop: <class '__main__.Batcher'>, filtered callbacks: ()
INFO:root:Given loop: <class '__main__.Trainer'>, found callbacks: (<class '__main__.TrainerCallback'>,)
INFO:root:Given loop: <class '__main__.Trainer'>, filtered callbacks: (<class '__main__.TrainerCallback'>,)
INFO:root:Given loop: <class '__main__.Iterationer'>, found callbacks: (<class '__main__.IterCallback'>,)
INFO:root:Given loop: <class '__main__.Iterationer'>, filtered callbacks: ()
INFO:root:Given loop: <class '__main__.Batcher'>, found callbacks: ()
INFO:root:Given loop: <class '__mai

"{<class '__main__.Fitter'> []: {<torch.utils.data.datapipes.iter.combining.ZipperIterDataPipe object at 0x7ff07bd51a90>: {<class '__main__.Iterationer'> [<__main__.IterCallback object at 0x7ff07bd51410>]: {<class '__main__.Batcher'> []: {<class '__main__.Trainer'> [<__main__.TrainerCallback object at 0x7ff07bd5b3d0>]: {}}}, <class '__main__.Iterationer'> []: {<class '__main__.Batcher'> []: {<class '__main__.Validater'> []: {}}}}}}"

In [103]:
# Consume the whole pipeline; `TrainerCallback.on_trainer` prints once per record the `Trainer` loop yields.
list(base_pipe)

on_trainer
on_trainer
on_trainer
on_trainer
on_trainer
on_trainer


[(0.5, 0.5), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5)]

In [110]:
# hide
from fastcore.imports import in_colab

# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.export import *
    from nbdev.export2html import *
    from nbdev.cli import *
    # The cell output shows this converting nbs/index.ipynb to README.md.
    make_readme()
    # presumably writes the `# export` cells out to the `default_exp` module — confirm against the nbdev docs
    notebook2script(silent=True)

converting /home/fastrl_user/fastrl/nbs/index.ipynb to README.md
