In [56]:
#hide
#skip
%config Completer.use_jedi = False
# upgrade fastrl on colab
! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \
                     apt-get install -y xvfb python-opengl > /dev/null 2>&1 
# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS

In [57]:
# hide
import os
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbverbose.showdoc import *
    from nbdev.imports import *
    # Sanity-check the interactive environment unless `IN_TEST` is set
    # (presumably by the test runner — confirm).
    if not os.environ.get("IN_TEST", None):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON
else:
    # Virtual display is needed for colab
    from pyvirtualdisplay import Display
    # Named `virtual_display` so it does not shadow IPython's built-in `display()` function.
    virtual_display = Display(visible=0, size=(400, 300))
    virtual_display.start()

In [58]:
# default_exp fastai.loop

In [59]:
# export
# Python native modules
import os
# Third party libs
from fastcore.all import *
# Local modules

# Loop
> fastrl's concept of generic loop objects.

The goal for Loops is to make them easy to customize, and to make it easy to know how each section of code connects
to other parts.

### Why do we need this?
We have identified at least 3 different kinds of loops already:

- Learner (training)
- Source/Gym (data access)
- Agent (how an AI takes in data and generates actions)

### What is a loop?

- It should be capable of containing inner loops.
- It should be able to handle "phases" that might be similar to each other.
- It should self-describe its structure.
- It should be easy to know which parts of the loop are taking long/short amounts of time.
- It should be flexible in state modification.
- It should alternatively make it easy to show which fields are being changed at which points in time.

In [270]:
# export 
class Hook(object):
    "Wrap a loop method `function` together with the `phase` it belongs to and its sort `order`."
    def __init__(self,function,phase=None): 
        store_attr()
        self.sig=inspect.signature(function)
        
        # An explicit `phase` wins; otherwise use the function's `_phase` default, if it declares one.
        if '_phase' not in self.sig.parameters or phase is not None: 
            self.phase=phase 
        else: 
            phases=L(self.sig.parameters['_phase'].default)
            # A listy default contributes only its first entry; an empty one (e.g. `_phase=None`)
            # means "no phase" instead of raising IndexError.
            self.phase=phases[0] if len(phases) else None
        
        # Functions without an `_order` parameter get order -1.
        if '_order' not in self.sig.parameters: self.order=-1 
        else: self.order=self.sig.parameters['_order'].default
        
    def __lt__(self,o:'Hook'): 
        "Hooks sort by ascending `order`."
        return self.order<o.order

    def __str__(self): return f'{self.order}_{self.phase}_{self.function.__name__}'
    # `__hash__` must return an int, so hash the string form rather than returning it.
    def __hash__(self): return hash(str(self))
    def __repr__(self): return str(self)

    def copy(self,**kwargs):
        "New `Hook` for the same function, phase and order, with attributes overridden by `kwargs`."
        h=Hook(self.function,self.phase)
        # Keep an `order` that may have been changed after construction.
        h.order=self.order
        for k,v in kwargs.items(): setattr(h,k,v)
        return h
        
    @classmethod
    def ishook(cls,name)->bool:
        "True for public names that start with `on_`, `after_`, `before_`, `failed_` or `finally_`."
        return not name.startswith('_') and \
            name.startswith(('on_','after_','before_','failed_','finally_'))

In [271]:
def test():
    "Fixture with neither `_order` nor `_phase`."
def test2(_order=4):
    "Fixture declaring `_order=4`."
def test3(_phase='valid'):
    "Fixture declaring `_phase='valid'`."

In [272]:
tuple(map(Hook,(test,test2,test3)))

(-1_None_test, 4_None_test2, -1_valid_test3)

In [283]:
# export
def in_phase(hook:Hook,current_phase:str):
    "True when `hook` has no phase (`None`) or its phase is `current_phase`."
    return hook.phase in (None,current_phase)

class Loop(GetAttr):
    "A nestable loop whose `on_`/`before_`/`after_`/`failed_`/`finally_` methods are `Hook`s, optionally expanded per phase."
    _default='_base'
    phases=None # Class-level default, used when `__init__` receives no `phases`
    
    def __init__(self,
                 inner_loops:Optional[List['Loop']]=None, # Internal `Loop` objects that are called on `indexes` 
                 cbs:Optional[List['Callback']]=None, # Callbacks to be used for a given `Loop`. These will also be used in internal loops.
                 persist_cbs:bool=False, # Whether to review callbacks once a loop ends.
                 call_on:List[str]=None, # A list of methods to run this loop on.
                 parent:'Loop'=None, # Since loops can be nested, they need to reference the parent loops
                 phases:List[str]=None # Phases that a loop supports, for example train,validate,test...
                ):
        store_attr(but='cbs,inner_loops,call_on,phases')
        # Phase state is kept per instance: assigning it to `self.__class__` would make every
        # instance of the same subclass share (and overwrite) one another's phases.
        self.phases=ifnone(phases,self.phases)
        self.phase=None
        self.cbs=L()
        self.inner_loops=L()
        self.call_on=L(call_on)
        self.add_objs(L(cbs),self.add_cb)
        self.add_objs(L(inner_loops),self.add_loop)
        self._base=None
        
    def _grab_objs(self, obj_cls,ls): 
        "Items of `ls` that are instances of `obj_cls`"
        return L(o for o in ls if isinstance(o, obj_cls))
        
    def add_loop(self, loop):
        "Add `loop` to `self.inner_loops`, instantiating it first if a `Loop` class is passed"
        # `Loop.__init__` takes `parent` (there is no `base_loop` argument); both branches attach the same parent.
        if isinstance(loop, type):   loop = loop(parent=ifnone(self.parent,self))
        elif loop.parent is None: loop.parent=ifnone(self.parent,self)
        self.inner_loops.append(loop)
        return self

    def remove_loop(self, loop):
        "Remove `loop` from `self.inner_loops`; passing a class removes every inner loop of that type"
        if isinstance(loop, type): 
            self.remove_objs(self._grab_objs(loop,self.inner_loops),self.remove_loop)
        elif loop in self.inner_loops:
            self.inner_loops.remove(loop)
            # Detach the removed instance (never set attributes on a class object).
            loop.parent=None
        return self
        
    def add_objs(self,objs,adder):
        "Call `adder` on each of `objs`"
        L(objs).map(adder)
        return self

    def remove_objs(self,objs,remover):
        "rm all `objs` from `self` using `remover`"
        L(objs).map(remover)
        return self
        
    def add_cb(self, cb):
        "Instantiate `cb` if it is a class, set `self` as its `_default` attribute, expose it as `self.<cb.name>`, and add to `self.cbs`"
        if isinstance(cb, type): cb = cb()
        setattr(cb,self._default,self)
        setattr(self, cb.name, cb)
        self.cbs.append(cb)
        return self

    def remove_cb(self, cb):
        "Remove `cb` from `self` and `self.cbs`, clearing its `_default` attribute; a class removes every callback of that type"
        if isinstance(cb, type): 
            self.remove_objs(self._grab_objs(cb,self.cbs),self.remove_cb)
        else:
            setattr(cb,self._default,None)
            if hasattr(self, cb.name): delattr(self, cb.name)
            if cb in self.cbs: self.cbs.remove(cb)
        return self
    
    def switch_phase(self,phase): 
        "Record `phase` as the phase this loop is currently in"
        self.phase=phase
    
    @classmethod
    def hooks(cls)->List[Hook]:
        "Every hook method of `cls` (per `Hook.ishook`), sorted by `order`"
        hooks=L(Hook(hook) for k,hook in inspect.getmembers(cls) if Hook.ishook(k))
        return hooks.sorted()
        
    def hooks_with_phases(self,phase=None):
        "For each phase (`phase` if given, else `self.phases`): a `switch_phase` hook, then copies of the hooks valid in that phase"
        hooks=self.hooks()
        res=L()
        for p in L(ifnone(phase,self.phases)):  
            res+=L(Hook(self.switch_phase,phase=p))
            res+=hooks.filter(in_phase,current_phase=p).map(Self.copy(phase=p))
        # Return only after every phase has been expanded (an empty `L` when there are no phases).
        return L(res)

In [285]:
class Outer(Loop):
    "Example loop with `*_step` hooks; `_order` fixes their sort position and `_phase='valid'` limits `finally_step` to that phase."
    def before_step(self): "No `_order`, so it takes the default order."
    def on_step(self,_order=2): "Sorts after `before_step`."
    def after_step(self,_order=3): "Sorts after `on_step`."
    def failed_step(self,_order=4): "Sorts after `after_step`."
    def finally_step(self,_order=5,_phase='valid'): "Sorts last and only belongs to the 'valid' phase."

class Inner(Loop):
    "Example inner loop with `*_iteration` hooks; none declare `_order` or `_phase`."
    def before_iteration(self): "No-op hook."
    def on_iteration(self): "No-op hook."
    def after_iteration(self): "No-op hook."
    def failed_iteration(self): "No-op hook."
    def finally_iteration(self): "No-op hook."

In [286]:
# Two phases: each gets a `switch_phase` hook followed by the hooks belonging to that phase
loop=Outer(inner_loops=Inner(call_on=['on_step']),
           phases=L(['train','valid']))
[*loop.hooks_with_phases()]

[-1_train_switch_phase,
 -1_train_before_step,
 2_train_on_step,
 3_train_after_step,
 4_train_failed_step,
 -1_valid_switch_phase,
 -1_valid_before_step,
 2_valid_on_step,
 3_valid_after_step,
 4_valid_failed_step,
 5_valid_finally_step]

In [280]:
# Without phase expansion, `hooks()` reports each hook's own phase (None unless it declares `_phase`)
loop=Outer(inner_loops=Inner())
[*loop.hooks()]

[-1_None_before_step,
 2_None_on_step,
 3_None_after_step,
 4_None_failed_step,
 5_valid_finally_step]

In [None]:
# export
class CustomLoop(Loop):
    "Placeholder `Loop` subclass with no hooks of its own yet."

In [None]:
# export 
class CustomCallback():
    "Placeholder for a custom callback; currently defines no behavior."

In [None]:
# hide
from fastcore.imports import in_colab

# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    # Build step, skipped on Colab. Wildcard imports: if two modules export the same name,
    # the later import wins, so keep this order.
    from nbdev.export import *
    from nbdev.export2html import *
    from nbverbose.cli import *
    # Presumably regenerates the README, exports the `# export` cells to the library module,
    # and builds the HTML docs (nbdev/nbverbose helpers) — confirm against their documentation.
    make_readme()
    notebook2script()
    notebook2html()