In [1]:
#hide
#skip
%config Completer.use_jedi = False
# upgrade fastrl on colab
! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \
                     apt-get install -y xvfb python-opengl > /dev/null 2>&1 
# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS

In [2]:
# hide
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbverbose.showdoc import *
    from nbdev.imports import *
    if not os.environ.get("IN_TEST", None):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON
else:
    # Virutual display is needed for colab
    from pyvirtualdisplay import Display
    display=Display(visible=0,size=(400,300))
    display.start()

In [3]:
####### default_exp fastai.loop

In [4]:
# export
# Python native modules
import os
from copy import deepcopy
from typing import *
import types
import logging
import inspect
# Third party libs
from fastcore.all import *
# Local modules

_logger=logging.getLogger(__name__)

# Loop
> fastrl concept of generic loop objects. 

The goal for Loops is to make it easy to customize, and know how sections of code connects
to other parts.

### Why do we need this?
We have identified at least 3 different kinds of loops already:

    Learner (training)
    Source/Gym (Data Access)
    Agent (How an AI takes in data, generates actions)

### What is a loop?

    It should be capable of containing inner loops. 
    It should be able to handle "phases" that might be similar to each other. 
    It should self-describe its structure. 
    It should be easy to know which parts of the loop are taking long/short amounts of time.
    It should be flexible in state modification.
    It should alternatively make it easy show what fields are being changed at what points in time.

A Loop will act as a compiled structure. The actual result will be a compiled list of nodes that reference the original loop.

Loop is a compiled object that organizes the callbacks and loop calls into a possibly repeating sequence.

`Literal['order',int]` contains a int that should be more than or equal to 0. This determines when a 
function in the loops should be executed relative to other functions.

In [5]:
class TestLoop:
    def on_iteration(self)->Literal['order',3]:pass
    def on_step(self)->Literal['order',3]:pass

In [6]:
# export
PREFIXES=['on_','after_','before_','failed_','finally_']
PRE2ORDER={'on_':2,'after_':3,'before_':1,'failed_':4,
           'finally_':5}

class Node(object):
    def __init__(self,
                 function:Callable, # The function to be called on this node.
                 base_loop=None,         
                 # The BASE loop that this node is a part of. This will be
                 # different from the loop it is defined in.
                 loop=None, # The loop this node immediately reports to
                 children:List['Node']=None,
                 parent:'Node'=None,
                 order=0,
                 call_on='',
                 call_on2str:bool=False, # Whether to show the call_on the __repr__ and __str__ functions
                 order2str:bool=False, # Whether to show the order the __repr__ and __str__ functions
                 level2str:bool=False, # Whether to show the level the __repr__ and __str__ functions
                 indent2str:str='' # Amount and characters to indent the __repr__ and __str__ functions
                ):
        store_attr()
        if not self.isvalid(function):
            raise ValueError(f"""{function} cannot be private ('_'), and must have
                                 either {PREFIXES}""")
        self.sig=inspect.signature(function)
        self.name=function.__name__
        self.order=order
        self.user_defined_order=False
        for anno in L(anno_ret(function)):
            if 'order' in anno.__args__: 
                self.order+=anno.__args__[-1]
                self.user_defined_order=True
            
        for pre in PREFIXES: 
            if self.name.startswith(pre): 
                self.postfix=self.name.replace(pre,'')
                self.prefix=pre
                # If order is 0, then we will define the order based on the prefix
                if self.order==0: self.order=PRE2ORDER[pre]
                break
        
    def __call__(self,*args,**kwargs): return self.function(*args,**kwargs)
    def __round__(self): return round(self.order)  
    def __lt__(self,o:'Node'): return self.order<o.order
            
    @delegates(__init__,but='function,base_loop,loop,children,parent,order,call_on')
    def adjust_str(self,**kwargs):
        n=deepcopy(self)
        for k,v in kwargs.items(): setattr(n,k,v)
        return n
        
    def __str__(self):
        base=f"{self.function.__name__}"
        if self.call_on!='' and self.call_on2str: base=f"{self.call_on}:"+base
        if self.level2str:                        base=f"level:{self.level} "+base
        if self.order2str:                        base=f"order:{self.order} "+base
        base=self.indent2str*(self.level-1)+base
        return base
    
    def __repr__(self): return str(self)
    def __hash__(self): return hash(f'{self.call_on} {str(self)}')

    @property
    def level(self):
        return 1 if self.parent is None else (1+self.parent.level)

    @classmethod
    def isvalid(cls,name)->bool:
        if isinstance(name,Callable): name=name.__name__
        return not name.startswith('_') and \
          any(name.startswith(pre) for pre in PREFIXES)
    
add_docs(Node,"""Nodes are used to generate a static loop. They wrap whatever functions
are defined in that loop.""",
        adjust_str='If we want to change some of the str related params on the fly, we can call this method.')

In [7]:
def on_test()->Literal['order',3]: return {'this':'that'}
def on_other_test()->Literal['order',3]: return {'this':'that'}
Node(on_test,0)

on_test

In [8]:
Node(on_test,0)()

{'this': 'that'}

In [9]:
str(Node(on_test,0))

'on_test'

In [10]:
str(Node(on_test,0,call_on='parent_method',call_on2str=True))

'parent_method:on_test'

In [11]:
str(Node(on_test,0,call_on='parent_method',call_on2str=True,order2str=True))

'order:3 parent_method:on_test'

In [12]:
parent=Node(on_other_test)

str(Node(on_test,0,parent=parent,call_on='parent_method',call_on2str=True,order2str=True,level2str=True))

'order:3 level:2 parent_method:on_test'

In [13]:
parent=Node(on_other_test)

str(Node(on_test,0,parent=parent,call_on='parent_method',call_on2str=True,order2str=True,
         level2str=True,indent2str='--'))

'--order:3 level:2 parent_method:on_test'

In [14]:
parent=Node(on_other_test)

str(Node(on_test,0,parent=parent,call_on='parent_method',call_on2str=True,order2str=True,
         level2str=True,indent2str='--').adjust_str(indent2str='+++++'))

'+++++order:3 level:2 parent_method:on_test'

In [15]:
Node(on_test,0).postfix

'test'

In [16]:
Node(on_test,0).__hash__()

5365586429828178065

In [17]:
# export
class Loop(object):
    core_obj=None
    call_on=''
    
    def __init__(self,
                 full_loop=None,
                 inner_loops=None
                ):
        store_attr()
    
    @classmethod
    def nodes(cls,
              loop_instance=None # If we want to actually run the nodes, we need to use a loop isntance
             )->List[Node]:
        nodes=L(Node(node) for k,node in inspect.getmembers(cls) if Node.isvalid(k))
        if loop_instance is not None: 
            for n in nodes: n.function=getattr(loop_instance,n.name)
        # For a given set of nodes, fix the orders, and defining the order if it
        # it not defined for a given node.
        for postfix,subnodes in groupby(nodes,lambda o:o.postfix).items():
            # For a given subnode, we want to see if the user defined any of the orders
            if any(n.user_defined_order for n in subnodes):
                # If the user defined ONE of the orders then we need to make sure 
                # the other nodes have orders that make sense.
                user_ordered_nodes=[n for n in subnodes if n.user_defined_order]
                if all([n.user_defined_order for n in subnodes]): pass
                elif len(user_ordered_nodes)==1:
                    max_order=max(user_ordered_nodes).order
                    idxs=L(subnodes).argwhere(lambda o:not o.user_defined_order)
                    for subn in L(subnodes)[idxs]:subn.order+=max_order
                elif len(user_ordered_nodes)>1:
                    # If the user defined MORE THAN ONE of the orders, then we will
                    # warn them and not mess with the orders.
                    _logger.warning("""nodes: %s have %s user defined orders. 
                                       because there are more than 1 user defined
                                       orders, we will not be able to define them
                                       automatically.""",str(subnodes),len(user_ordered_nodes))
        # Duplicate the nodes if for a given node, it is called on 2 different
        # locations.
        final_nodes=L()
        if cls.call_on!='': 
            for n in nodes: 
                for call in L(cls.call_on.split(',')):
                    n.call_on=call
                    final_nodes.append(deepcopy(n))
        else:
            final_nodes=nodes
        return final_nodes.sorted()
    
    @classmethod
    def with_inner_loops(cls,
             loop_instance=None, # If we want to actually run the nodes, we need to use a loop isntance
             inner_loops:L=None # Either list of loop cls or instances 
            ):
        if loop_instance is None:
            return cls.nodes()+L(inner_loops).map(Self.with_inner_loops()).concat()
        else:
            ns=cls.nodes(loop_instance=loop_instance)
            ns+=L(inner_loops).map(lambda o:o.with_inner_loops(loop_instance=o)).concat()
            return ns
    
    @classmethod
    def organized_nodes(cls,
             loop_instance=None, # If we want to actually run the nodes, we need to use a loop isntance
             inner_loops=None,
             as_dict=False)->Union[List[Node],Dict[str,Node]]:
        nodes=list(cls.with_inner_loops(loop_instance=loop_instance,inner_loops=inner_loops))
        # For the nodes that have `call_on`s it will be more efficient to map
        # them the function they are being called on. Nodes without call_ons we 
        # can be confident are the root nodes.
        grouped_nodes=groupby(L(nodes),lambda o:o.call_on)
        grouped_nodes={g:L(l).sorted() for g,l in grouped_nodes.items()}
        # For each node, add the call_on nodes if they exist 
        organized_nodes=L()
        order=0
        nodes=grouped_nodes['']
        while nodes:
            n=nodes.pop(0)
            # The order will cascade ensuring that all the nodes have their 
            # respective orders organized. This is needed incase you execute `sorted` 
            # on the final list of nodes.
            if order==0: 
                order=n.order
            else:
                _order=order+n.order
                n.order+=order
                n.order=round(n.order,2)
                order=_order
                
            organized_nodes.append(n)
            if n.name in grouped_nodes:
                # If there are call_on nodes for `n`, then we want to process them
                # immediately after `n`. We reverse the list and add insert them
                # on index 0, thus the nodes that should go first will be in 
                # index 0-ish and the ones that should go later will be farthur 
                # down this list.
                n.children=grouped_nodes[n.name]
                for inner_n in reversed(grouped_nodes[n.name]): 
                    inner_n.parent=n
                    nodes.insert(0,inner_n)
        return organized_nodes

In [18]:
class Outer(Loop):
    def before_step(self)->Literal['order',1]:   print('before_step')
    def on_step(self)->Literal['order',2]:       print('on_step')
    def after_step(self)->Literal['order',3]:    print('after_step')
    def failed_step(self)->Literal['order',4]:   print('failed_step')
    def finally_step(self)->Literal['order',5]:  print('finally_step')
 
    def before_jump(self)->Literal['order',6]:   print('before_jump')
    def on_jump(self)->Literal['order',7]:       print('on_jump')
    def after_jump(self)->Literal['order',8]:    print('after_jump')
    def failed_jump(self)->Literal['order',9]:   print('failed_jump')
    def finally_jump(self)->Literal['order',10]: print('finally_jump')

class Inner(Loop):
    call_on='on_step,after_step,finally_jump'
    
    def before_iteration(self)->Literal['order',1]: print('before_iteration')
    def on_iteration(self):                         print('on_iteration')
    def after_iteration(self):                      print('after_iteration')
    def failed_iteration(self):                     print('failed_iteration')
    def finally_iteration(self):                    print('finally_iteration')
    
class FailingInner(Loop):
    call_on='finally_iteration'
    
    def on_force_fail(self):                    
        print('on_force_fail')
        raise Exception

In [19]:
list(Outer.organized_nodes(inner_loops=L(Inner,FailingInner)).map(
    lambda o:o.adjust_str(call_on2str=True,indent2str='     ')
))

[before_step,
 on_step,
      on_step:before_iteration,
      on_step:on_iteration,
      on_step:after_iteration,
      on_step:failed_iteration,
      on_step:finally_iteration,
           finally_iteration:on_force_fail,
 after_step,
      after_step:before_iteration,
      after_step:on_iteration,
      after_step:after_iteration,
      after_step:failed_iteration,
      after_step:finally_iteration,
           finally_iteration:on_force_fail,
 failed_step,
 finally_step,
 before_jump,
 on_jump,
 after_jump,
 failed_jump,
 finally_jump,
      finally_jump:before_iteration,
      finally_jump:on_iteration,
      finally_jump:after_iteration,
      finally_jump:failed_iteration,
      finally_jump:finally_iteration,
           finally_iteration:on_force_fail]

In [20]:
@patch
def run(self:Loop):
    full_loop=self.organized_nodes(loop_instance=self,inner_loops=L(self.inner_loops))
    
    failed:Node=None
    failed_parent:Node=None
    for n in full_loop:
        # If there wasn't a failure, skip `failed_` events.
        if failed is None and n.prefix=='failed_': continue
        # If there was a failure in the loop, we need to skip parts of the loop
        # and see if any of the 'failed_' methods handle the exception as well as
        # calling the 'finally_' fields.
        if failed is not None:
            # If there was an exception, we need to do the following:
            # if we are out of the section that failed, shift up to the parent
            if n.postfix!=failed.postfix and failed_parent is None: 
                # If the postfix is no longer the failed node, then we need to
                # escalate to a lower level in the loop.
                failed_parent=failed.parent
            # skip all steps that are not finally or failed events
            if n.prefix not in ('finally_','failed_'): continue
            # Check that we arent going back into another sub loop
            if failed_parent is not None and n.level>failed_parent.level: 
                if n.postfix!=failed_parent.postfix and failed_parent.parent is not None:
                    failed_parent=failed_parent.parent
                continue
            if failed_parent is not None and n.level==failed_parent.level:
                if n.postfix!=failed_parent.postfix: continue
        
        try: 
            res=n()
            if failed is not None and type(res)==bool and res:
                failed,failed_parent=None,None
        except:
            failed=n
    
    return None

In [21]:
Outer(inner_loops=L(Inner(),FailingInner())).run()

before_step
on_step
before_iteration
on_iteration
after_iteration
finally_iteration
on_force_fail
failed_iteration
finally_iteration
failed_step
finally_step
failed_jump
finally_jump
failed_iteration
finally_iteration


In [1]:
# # hide
# from fastcore.imports import in_colab

# # Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
# if not in_colab():
#     from nbdev.export import *
#     from nbdev.export2html import *
#     from nbverbose.cli import *
#     make_readme()
#     notebook2script()
#     notebook2html()