In [11]:
#hide
#skip
%config Completer.use_jedi = False
%config IPCompleter.greedy=True
# upgrade fastrl on colab
! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \
                     apt-get install -y xvfb python-opengl > /dev/null 2>&1 
# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS

In [12]:
# hide
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.showdoc import *
    from nbdev.imports import *
    if not os.environ.get("IN_TEST", None):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON
else:
    # Virtual display is needed for colab
    from pyvirtualdisplay import Display
    display = Display(visible=0, size=(400, 300))
    display.start()

In [13]:
# default_exp fastai.data.loop

In [14]:
# export
# Python native modules
import os
import logging
# Third party libs
from fastcore.all import *
import torchdata.datapipes as dp
from torch.utils.data.graph import traverse
# Local modules
from fastrl.fastai.data.pipes.demux import *
from fastrl.fastai.data.pipes.mux import *

_logger = logging.getLogger()

In [15]:
logging.basicConfig(level='INFO')

# Loop
> Customizable loop API for fastrl.

So we possibly need loops within loops. Or do they need to be loops? Maybe they need to
just be sections? Do we even need sections? I wonder if we can leverage torchdata's datapipes as more
of a base... I want to see how far we can get with this.

We also need to think about whether this is always iterable, or whether we can mix and match
map-style and iter-style datapipes...






In [16]:
class Callback:
    "Base callback; `call_on` holds the loop classes it should be attached to (empty by default)."
    call_on = L()

In [17]:

class Loop(dp.iter.IterDataPipe):
    "A datapipe with nesting and callback capabilities."

    def set_cbs(self,cbs): self.callbacks = cbs
    
    def filter_cbs(self, cbs):
        return tuple(cb for cb in cbs if self.__class__ in cb.call_on)

In [18]:
# export
# class ExceptionCatcher(Loop):
#     def __init__(self, source_datapipe, exception_cls, **kwargs) -> None:
#         self.source_datapipe = source_datapipe
#         self.exception_cls = exception_cls
#         self.kwargs = kwargs
    
#     def __iter__(self) -> Iterator[T_co]:
#         while True:
#             try:
#                 yield from self.source_datapipe
#             except self.exception_cls:
#                 return

class Iterationer(Loop):
    "Loop stage that yields a placeholder `0.5` once per element of `source_datapipe`."

    def __init__(self, source_datapipe, **kwargs) -> None:
        "Keep the upstream pipe and any extra keyword arguments on the instance."
        self.kwargs = kwargs
        self.source_datapipe = source_datapipe

    def __iter__(self) -> Iterator[T_co]:
        "Walk `source_datapipe`, ignoring each element and yielding `0.5` in its place."
        for _ in self.source_datapipe:
            # Sketch of the intended training step, not yet enabled:
            # xb,yb = next(self.source_datapipe)
            # pred = self.kwargs['model'](xb)
            # loss_grad = self.kwargs['loss_grad']
            # opt = self.kwargs['opt']
            # loss_func = self.kwargs['loss_func']
            # if len(yb):
            #     loss_grad = self.loss_func(pred, *yb)
            #     loss = loss_grad.clone()
            # self('after_loss')
            # if not self.training or not len(yb): return
            # self('before_backward')
            # loss_grad.backward()
            # opt.step()
            # opt.zero_grad()
            # yield loss.detach()
            yield 0.5

class Batcher(Loop):
    "Loop stage that currently passes every element of `source_datapipe` through unchanged."

    def __init__(self, source_datapipe, **kwargs) -> None:
        "Keep the upstream pipe and any extra keyword arguments on the instance."
        self.kwargs = kwargs
        self.source_datapipe = source_datapipe

    def __iter__(self) -> Iterator[T_co]:
        "Yield each upstream element as-is."
        yield from self.source_datapipe
        
class Trainer(Loop):
    "Loop stage for the training branch; currently passes `source_datapipe` elements through unchanged."

    def __init__(self, source_datapipe, **kwargs) -> None:
        "Keep the upstream pipe and any extra keyword arguments on the instance."
        self.kwargs = kwargs
        self.source_datapipe = source_datapipe

    def __iter__(self) -> Iterator[T_co]:
        "Yield each upstream element as-is."
        yield from self.source_datapipe
                
class Validater(Loop):
    "Loop stage for the validation branch; currently passes `source_datapipe` elements through unchanged."

    def __init__(self, source_datapipe, **kwargs) -> None:
        "Keep the upstream pipe and any extra keyword arguments on the instance."
        self.kwargs = kwargs
        self.source_datapipe = source_datapipe

    def __iter__(self) -> Iterator[T_co]:
        "Yield each upstream element as-is."
        yield from self.source_datapipe
               
class Epocher(Loop):
    "Loop stage that zips a tuple of source datapipes and yields the combined tuples."

    def __init__(self, source_datapipes:tuple, **kwargs) -> None:
        "Check that `source_datapipes` is exactly a `tuple` (via `test_eq`), then store it and `kwargs`."
        # Validation stays first so nothing is assigned when the check fails.
        test_eq(type(source_datapipes), tuple)
        self.kwargs = kwargs
        self.source_datapipes = source_datapipes

    def __iter__(self) -> Iterator[T_co]:
        "Yield one tuple per step, stopping with the shortest source (plain `zip` semantics)."
        yield from zip(*self.source_datapipes)
                
class Fitter(Loop):
    "Builds a train branch and a valid branch over the same `iterable` and zips them together."

    def __init__(self,iterable):
        "Wire `Trainer`/`Validater` -> `Batcher` -> `Iterationer` per branch; store the zip as `self.epocher`."
        train_branch = Iterationer(Batcher(Trainer(iterable)))
        valid_branch = Iterationer(Batcher(Validater(iterable)))
        self.epocher = dp.iter.Zipper(train_branch,valid_branch)

    def __iter__(self):
        "Yield every item produced by `self.epocher`."
        yield from self.epocher

        
class TrainerCallback:
    "Callback whose `call_on` targets the `Trainer` loop class."
    call_on = L(Trainer)
class IterCallback:
    "Callback whose `call_on` targets the `Iterationer` loop class."
    call_on = L(Iterationer)


In [24]:
# export
def default_constructor(
    datapipe:Union[Loop,Dict],
    cbs:List[Callback]
):
    """Walk a datapipe graph and attach the matching callbacks to every `Loop` in it.

    `datapipe` is either a pipe, which is expanded with `traverse`, or an
    already-expanded graph dict mapping each pipe to the sub-graph of its
    sources. For each key that is a `Loop`, the callbacks returned by its
    `filter_cbs(cbs)` are logged and stored through `set_cbs`. Non-empty
    sub-graphs are processed recursively. Returns `None`.
    """
    graph = datapipe if isinstance(datapipe,dict) else traverse(datapipe)

    for pipe,children in graph.items():
        if isinstance(pipe,Loop):
            matched = pipe.filter_cbs(cbs)
            _logger.info('Given loop: %s, found callbacks: %s',pipe,matched)
            pipe.set_cbs(matched)
        # Only descend when this pipe has upstream sources.
        if children: default_constructor(children, cbs)
        

In [25]:
# Attach the two callback classes to every `Loop` reachable from the zipped
# train/valid pipeline; the matches per loop are reported through the logger.
default_constructor(
    Fitter([1,2,3,4,5,6]).epocher,
    [TrainerCallback,IterCallback]
)

INFO:root:Given loop: <__main__.Iterationer object at 0x7fa617348790>, found callbacks: (<class '__main__.IterCallback'>,)
INFO:root:Given loop: <__main__.Batcher object at 0x7fa617348410>, found callbacks: ()
INFO:root:Given loop: <__main__.Trainer object at 0x7fa617348d50>, found callbacks: (<class '__main__.TrainerCallback'>,)
INFO:root:Given loop: <__main__.Iterationer object at 0x7fa617348c90>, found callbacks: (<class '__main__.IterCallback'>,)
INFO:root:Given loop: <__main__.Batcher object at 0x7fa617348e90>, found callbacks: ()
INFO:root:Given loop: <__main__.Validater object at 0x7fa617348ed0>, found callbacks: ()


In [21]:
# Show the nested datapipe graph, the same structure `default_constructor` walks.
traverse(Fitter([1,2,3,4,5,6]).epocher)

{<torch.utils.data.datapipes.iter.combining.ZipperIterDataPipe at 0x7fa61735b7d0>: {<__main__.Iterationer at 0x7fa61735b150>: {<__main__.Batcher at 0x7fa617348f10>: {<__main__.Trainer at 0x7fa617348610>: {}}},
  <__main__.Iterationer at 0x7fa61735b450>: {<__main__.Batcher at 0x7fa617348e50>: {<__main__.Validater at 0x7fa617348f50>: {}}}}}

In [22]:
# Run the whole pipeline: each item pairs the train and valid `Iterationer` outputs (placeholder 0.5 each).
list(Fitter([1,2,3,4,5,6]))

[(0.5, 0.5), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5)]

In [23]:
# export
# NOTE(review): `events` is not defined or imported in the visible cells; the recorded
# output of this cell is "NameError: name 'events' is not defined". Its intended
# semantics cannot be determined from here — confirm before exporting.
@events
class Outer(Loop):
    "Loop whose `*_step` and `*_jump` hook methods each print their own name."
    
    def before_step(self) :  print('before_step')
    def on_step(self)     :  print('on_step')
    def after_step(self)  :  print('after_step')
    def failed_step(self) :  print('failed_step')
    def finally_step(self):  print('finally_step')
 
    def before_jump(self) :  print('before_jump')
    def on_jump(self)     :  print('on_jump')
    def after_jump(self)  :  print('after_jump')
    def failed_jump(self) :  print('failed_jump')
    def finally_jump(self):  print('finally_jump')

class Inner(Loop):
    "Loop whose `*_iteration` hook methods each print their own name."
    # `call_on` references three hook methods of `Outer` rather than loop classes.
    # NOTE(review): `event` is not defined or imported in the visible cells — confirm where it comes from.
    call_on=L(Outer.on_step,Outer.after_step,Outer.finally_jump)
    
    @event
    def before_iteration(self) : print('before_iteration')
    @event
    def on_iteration(self)     : print('on_iteration')
    @event
    def after_iteration(self)  : print('after_iteration')
    @event
    def failed_iteration(self) : print('failed_iteration')
    @event
    def finally_iteration(self): print('finally_iteration')
    
    # Plain method without the `@event` decorator; its body does nothing.
    def thingy(self): pass

class FailingInner(Loop):
    "Loop with a single hook that prints its name and then raises a bare `Exception`."
    # `call_on` references the `Inner.finally_iteration` hook.
    # NOTE(review): `event` is not defined or imported in the visible cells — confirm where it comes from.
    call_on=L(Inner.finally_iteration)
    
    @event
    def on_force_fail(self):                    
        print('on_force_fail')
        # Always fails.
        raise Exception

NameError: name 'events' is not defined

In [None]:
# hide
from fastcore.imports import in_colab

# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.export import *
    from nbdev.export2html import *
    from nbdev.cli import *
    make_readme()
    notebook2script(silent=True)