In [3]:
#hide
#skip
%config Completer.use_jedi = False
# upgrade fastrl on colab
! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \
                     apt-get install -y xvfb python-opengl > /dev/null 2>&1 
# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS

In [37]:
# hide
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbverbose.showdoc import *
    from nbdev.imports import *
    if not os.environ.get("IN_TEST", None):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON
else:
    # Virutual display is needed for colab
    from pyvirtualdisplay import Display
    display=Display(visible=0,size=(400,300))
    display.start()

In [5]:
# default_exp fastai.loop

In [36]:
# export
# Python native modules
import os
from copy import deepcopy
from typing import *
import types
import logging
# Third party libs
from fastcore.all import *
# Local modules

_logger=logging.getLogger(__name__)

# Loop
> fastrl concept of generic loop objects. 

The goal for Loops is to make it easy to customize, and know how sections of code connects
to other parts.

### Why do we need this?
We have identified at least 3 different kinds of loops already:

    Learner (training)
    Source/Gym (Data Access)
    Agent (How an AI takes in data, generates actions)

### What is a loop?

    It should be capable of containing inner loops. 
    It should be able to handle "phases" that might be similar to each other. 
    It should self-describe its structure. 
    It should be easy to know which parts of the loop are taking long/short amounts of time.
    It should be flexible in state modification.
    It should alternatively make it easy show what fields are being changed at what points in time.

A Loop will act as a compiled structure. The actual result will be a compiled list of nodes that reference the original loop.

Loop is a compiled object that organizes the callbacks and loop calls into a possibly repeating sequence.

`Literal['order',int]` contains a int that should be more than or equal to 0. This determines when a 
function in the loops should be executed relative to other functions.

In [7]:
class TestLoop:
    def on_iteration(self)->Literal['order',3]:pass
    def on_step(self)->Literal['order',3]:pass

In [43]:
# export
PREFIXES=['on_','after_','before_','failed_','finally_']
PRE2ORDER={'on_':1,'after_':1.1,'before_':0.9,'failed_':1.2,
           'finally_':1.3}

class Node(object):
    def __init__(self,
                 function:Callable, # The function to be called on this node.
                 base_loop=None,         
                 # The BASE loop that this node is a part of. This will be
                 # different from the loop it is defined in.
                 loop=None,
                 children:List['Node']=None,
                 parent:'Node'=None,
                 order=0,
                 call_on='',
                ):
        store_attr()
        if not self.isvalid(function):
            raise ValueError(f"""{function} cannot be private ('_'), and must have
                                 either {PREFIXES}""")
        self.sig=inspect.signature(function)
        self.name=function.__name__
        self.order=order
        self.user_defined_order=False
        for anno in L(anno_ret(function)):
            if 'order' in anno.__args__: 
                self.order+=anno.__args__[-1]
                self.user_defined_order=True
            
        for pre in PREFIXES: 
            if self.name.startswith(pre): 
                self.postfix=self.name.strip(pre)
                self.prefix=pre
                # If order is 0, then we will define the order based on the prefix
                if self.order==0: self.order=PRE2ORDER[pre]
                break
        
    def __round__(self): return round(self.order)  
    def __lt__(self,o:'Node'): return self.order<o.order
            
    def __str__(self):
        base=f"{self.function.__name__}"
        if self.call_on!='': base+=f" -> on -> {self.call_on}"
        return base
    
    def __repr__(self): return str(self)
    def __hash__(self): return hash(f'{self.call_on} {str(self)}')

    @property
    def level(self):
        return 1 if self.parent is None else (1+self.parent.level)

    @classmethod
    def isvalid(cls,name)->bool:
        if isinstance(name,Callable): name=name.__name__
        return not name.startswith('_') and \
          any(name.startswith(pre) for pre in PREFIXES)
    
    def postfix(self):
        "Given `self.name` grab the postfix (i.e. without the `on_`,`after_`,etc.)"


        
def on_test()->Literal['order',3]: return {'this':'that'}

Node(on_test,0)

on_test

In [44]:
Node(on_test,0).postfix

'test'

In [45]:
Node(on_test,0).__hash__()

2112245071694279568

In [46]:
round

<function round(number, ndigits=None)>

In [58]:
class Loop(object):
    core_obj=None
    call_on=''
    
    @classmethod
    def nodes(cls)->List[Node]:
        nodes=L(Node(node) for k,node in inspect.getmembers(cls) if Node.isvalid(k))
        # For a given set of nodes, fix the orders, and defining the order if it
        # it not defined for a given node.
        for postfix,subnodes in groupby(nodes,lambda o:o.postfix).items():
            # For a given subnode, we want to see if the user defined any of the orders
            if any(n.user_defined_order for n in subnodes):
                # If the user defined ONE of the orders then we need to make sure 
                # the other nodes have orders that make sense.
                user_ordered_nodes=[n for n in subnodes if n.user_defined_order]
                if all([n.user_defined_order for n in subnodes]): pass
                elif len(user_ordered_nodes)==1:
                    max_order=round(user_ordered_nodes)
                    idxs=subnodes.argwhere(lambda o:not o.user_defined_order)
                    for subn in subnmodes[idxs]:subn.order+=max_order
                elif len(user_ordered_nodes)>1:
                    # If the user defined MORE THAN ONE of the orders, then we will
                    # warn them and not mess with the orders.
                    _logger.warning("""nodes: %s have %s user defined orders. 
                                       because there are more than 1 user defined
                                       orders, we will not be able to define them
                                       automatically.""",str(subnodes),len(user_ordered_nodes))
        # Duplicate the nodes if for a given node, it is called on 2 different
        # locations.
        final_nodes=L()
        if cls.call_on!='': 
            for n in nodes: 
                for call in L(cls.call_on.split(',')):
                    n.call_on=call
                    final_nodes.append(deepcopy(n))
        else:
            final_nodes=nodes
        return final_nodes.sorted()
    
    @classmethod
    def with_inner_loops(cls,inner_loops=None):
        return cls.nodes()+L(inner_loops).map(Self.with_inner_loops()).concat()
    
    @classmethod
    def organized_nodes(cls,inner_loops=None,as_dict=False)->Union[List[Node],Dict[str,Node]]:
        nodes=list(cls.with_inner_loops(inner_loops))
        # For the nodes that have `call_on`s it will be more efficient to map
        # them the function they are being called on. Nodes without call_ons we 
        # can be confident are the root nodes.
        call_on_nodes=L(nodes).filter(lambda o:o.call_on!='')
        nodes=L(nodes).filter(lambda o:o.call_on=='')
        call_ons=call_on_nodes.map(lambda o:o.call_on).unique()
        call_on_map={c:call_on_nodes.filter(lambda o:o.call_on==c).sorted() for c in call_ons}
        # For each node, add the call_on nodes if they exist 
        organized_nodes=L()
        order=0
        while nodes:
            n=nodes.pop(0)
            # The order will cascade ensuring that all the nodes have their 
            # respective orders organized. This is needed incase you execute `sorted` 
            # on the final list of nodes.
            if order==0: 
                order=n.order
            else:
                _order=order+n.order
                n.order+=order
                order=_order
                
            organized_nodes.append(n)
            if n.name in call_on_map:
                # If there are call_on nodes for `n`, then we want to process them
                # immediately after `n`. We reverse the list and add insert them
                # on index 0, thus the nodes that should go first will be in 
                # index 0-ish and the ones that should go later will be farthur 
                # down this list.
                n.children=call_on_map[n.name]
                for inner_n in reversed(call_on_map[n.name]): 
                    inner_n.parent=n
                    nodes.insert(0,inner_n)
        return organized_nodes

class Outer(Loop):
    def before_step(self)->Literal['order',1]:   return None
    def on_step(self)->Literal['order',2]:       return None
    def after_step(self)->Literal['order',3]:    return None
    def failed_step(self)->Literal['order',4]:   return None
    def finally_step(self)->Literal['order',5]:  return None
 
    def before_jump(self)->Literal['order',6]:   return None
    def on_jump(self)->Literal['order',7]:       return None
    def after_jump(self)->Literal['order',8]:    return None
    def failed_jump(self)->Literal['order',9]:   return None
    def finally_jump(self)->Literal['order',10]: return None

class Inner(Loop):
    call_on='on_step,after_step,finally_jump'
    
    def before_iteration(self)->Literal['order',1]: return None
    def on_iteration(self):                         return None
    def after_iteration(self):                      return None
    def failed_iteration(self):                     return None
    def finally_iteration(self):                    return None

In [59]:
groupby(Outer.nodes(),lambda o:o.postfix)

{'step': [before_step, on_step, after_step, failed_step, finally_step],
 'jump': [before_jump, on_jump, after_jump, failed_jump, finally_jump]}

In [60]:
list(Outer.nodes().map(lambda o:f'{o.order} {o.level} {o}'))

['1 1 before_step',
 '2 1 on_step',
 '3 1 after_step',
 '4 1 failed_step',
 '5 1 finally_step',
 '6 1 before_jump',
 '7 1 on_jump',
 '8 1 after_jump',
 '9 1 failed_jump',
 '10 1 finally_jump']

In [61]:
list(Outer.organized_nodes(Inner).map(lambda o:f'{o.order} {o.level} {o}'))

TypeError: type list doesn't define __round__ method

In [34]:
# export
def in_phase(hook:Hook,current_phase:str): return hook.phase in [None,current_phase]

class Loop(GetAttr):
    _default='_base'
    phases=None
    
    def __init__(self,
                 inner_loops:Optional[List['Loop']]=None, # Internal `Loop` objects that are called on `indexes` 
                 cbs:Optional[List['Callback']]=None, # Callbacks to be used for a given `Loop`. These will also be used in internal loops.
                 persist_cbs:bool=False, # Whether to review callbacks once a loop ends.
                 call_on:List[str]=None, # A list of methods to run this loop on.
                 parent:'Loop'=None, # Since loops can be nested, they need to reference the parent loops
                 phases:List[str]=None # Phases that a loop supports, for example train,validate,test...
                ):
        store_attr(but='cbs,inner_loops,call_on,phases')
        self.cbs=L()
        self.inner_loops=L()
        self.call_on=L(call_on)
        self.add_objs(L(cbs),self.add_cb)
        self.add_objs(L(inner_loops),self.add_loop)
        self._base=None
        self.__class__.phases=phases
        self.__class__.phase=None
        
    def _grab_objs(self, obj_cls,ls): 
        return L(o for o in ls if isinstance(o, obj_cls))
        
    def add_loop(self, loop):
        "Instantiate, add to `self.loops`"
        if isinstance(loop, type):   loop = loop(base_loop=self)
        elif loop.parent is None: loop.parent=ifnone(self.parent,self)
        self.inner_loops.append(loop)
        return self

    def remove_loop(self, loop):
        "Instantiate, remove from `self.loops`"
        loop.base_loop=None
        if isinstance(loop, type): 
            self.remove_objs(self._grab_objs(loop,self.inner_loops),self.remove_loop)
        else:
            if loop in self.inner_loops: self.inner_loops.remove(loop)
        return self
        
    def add_objs(self,objs,adder):
        "add `cbs` to `self`"
        L(objs).map(adder)
        return self

    def remove_objs(self,objs,remover):
        "rm all `objs` from `self` using `remover`"
        L(objs).map(remover)
        return self
        
    def add_cb(self, cb):
        "Instantiate, set as field in self, set as field in self._default, and add to `self.cbs`"
        if isinstance(cb, type): cb = cb()
        setattr(cb,self._default,self)
        setattr(self, cb.name, cb)
        self.cbs.append(cb)
        return self

    def remove_cb(self, cb):
        "Instantiate, rm `cb` from self, rm `cb` from self._default, and remove from `self.cbs`"
        if isinstance(cb, type): 
            self.remove_objs(self._grab_objs(cb,self.cbs),self.remove_cb)
        else:
            setattr(cb,self._default,None)
            if hasattr(self, cb.name): delattr(self, cb.name)
            if cb in self.cbs: self.cbs.remove(cb)
        return self
    
    def switch_phase(self,phase): self.phase=phase
    
    @classmethod
    def hooks(cls)->List[Hook]:
        hooks=L(Hook(hook) for k,hook in inspect.getmembers(cls) if Hook.ishook(k))
        return hooks.sorted()
        
    def hooks_with_phases(self,phase=None):
        hooks=self.hooks()
        hooks_with_phases=L()
        for phase in L(ifnone(phase,self.phases)):  
            hooks_with_phases+=L(Hook(self.switch_phase,phase=phase))
            hooks_with_phases+=hooks.filter(in_phase,current_phase=phase).map(Self.copy(phase=phase))
            
            
            return L(hooks_with_phases)

NameError: name 'Hook' is not defined

In [None]:
class Outer(Loop):
    def before_step(self):  return None
    def on_step(self,_order=2):      return None
    def after_step(self,_order=3):   return None
    def failed_step(self,_order=4):  return None
    def finally_step(self,_order=5,_phase='valid'): return None

class Inner(Loop):
    def before_iteration(self):  return None
    def on_iteration(self):      return None
    def after_iteration(self):   return None
    def failed_iteration(self):  return None
    def finally_iteration(self): return None

In [None]:
loop=Outer(phases=L(['train','valid']),
           inner_loops=Inner(call_on=['on_step']))
list(loop.hooks_with_phases())

In [None]:
loop=Outer(inner_loops=Inner())
list(loop.hooks())

In [None]:
# export
class CustomLoop(Loop):pass

In [None]:
# export 
class CustomCallback()

In [None]:
# hide
from fastcore.imports import in_colab

# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.export import *
    from nbdev.export2html import *
    from nbverbose.cli import *
    make_readme()
    notebook2script()
    notebook2html()