In [1]:
#hide
#skip
%config Completer.use_jedi = False
# upgrade fastrl on colab
! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \
                     apt-get install -y xvfb python-opengl > /dev/null 2>&1 
# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS

In [2]:
# hide
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbverbose.showdoc import *
    from nbdev.imports import *
    if not os.environ.get("IN_TEST", None):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON
else:
    # Virutual display is needed for colab
    from pyvirtualdisplay import Display
    display=Display(visible=0,size=(400,300))
    display.start()

In [3]:
# default_exp fastai.loop

In [4]:
# export
# Python native modules
import os
from copy import deepcopy
from typing import *
import types
import logging
import inspect
# Third party libs
from fastcore.all import *
import numpy as np
# Local modules

_logger=logging.getLogger(__name__)

# Loop
> fastrl concept of generic loop objects. 

The goal for Loops is to make it easy to customize, and know how sections of code connects
to other parts.

### Why do we need this?
We have identified at least 3 different kinds of loops already:

    Learner (training)
    Source/Gym (Data Access)
    Agent (How an AI takes in data, generates actions)

### What is a loop?

    It should be capable of containing inner loops. 
    It should be able to handle "phases" that might be similar to each other. 
    It should self-describe its structure. 
    It should be easy to know which parts of the loop are taking long/short amounts of time.
    It should be flexible in state modification.
    It should alternatively make it easy show what fields are being changed at what points in time.

A Loop will act as a compiled structure. The actual result will be a compiled list of nodes that reference the original loop.

Loop is a compiled object that organizes the callbacks and loop calls into a possibly repeating sequence.

`Literal['order',int]` contains a int that should be more than or equal to 0. This determines when a 
function in the loops should be executed relative to other functions.

In [5]:
# export
PREFIXES=['before_','on_','after_','failed_','finally_']

class NodeException(Exception):pass

class Node(object):
    def __init__(self,
                 function:Callable,
                 parent:Optional['Node']=None,
                 children:Optional[List['Node']]=None):
        store_attr()
        self.order=None
        for anno in L(anno_ret(self.function)):
            if 'order' in anno.__args__: 
                self.order=anno.__args__[-1]
        if self.order is None: 
            raise NodeException(f'Node: {self.name} needs Literal["order",int]')
        
    def __str__(self):  return self.name
    def __repr__(self): return str(self)
    def __lt__(self,o:'Node'): return self.order<o.order
    def __eq__(self,o:Union['Node',Callable]):
        return getattr(o,'function',o).__qualname__==self.function.__qualname__

    @property
    def name(self): return self.function.__name__
    @property
    def prefix(self): return self.name.split('_')[0]
    @property
    def postfix(self): return '_'.join(self.name.split('_')[1:])

    @classmethod
    def isvalid(cls,name)->bool:
        if isinstance(name,Callable): name=name.__name__
        return not name.startswith('_') and \
          any(name.startswith(pre) for pre in PREFIXES)

Nodes are self organizing parts of the Loop Graph.Their execution order is
determined by their relationship in the graph.

In [91]:
# export
def call_on_node_idxs(nodes:List[Node],loop):
    "Given a loop, get the idx locations for them"
    idxs=nodes.argwhere(lambda o:not isinstance(o,(list,L)) and o in loop.call_on)
    return np.array(idxs)+np.cumsum(L.range(1,len(idxs)+1))

class Loop(object):
    
    @classmethod
    def nodes(cls,loops,instantiate=False):
        loops=L(loops)
        loop=cls() if instantiate else cls
        nodes=L(Node(node) for k,node in inspect.getmembers(loop) if Node.isvalid(k))
        nodes=nodes.sorted()
        for l in loops:
            for idx in call_on_node_idxs(nodes,l):
                nodes.insert(idx,l.nodes(loops,instantiate))
        return nodes
    
    def run(self):

class Outer(Loop):
    def before_step(self) ->Literal['order',1]:  print('before_step')
    def on_step(self)     ->Literal['order',2]:  print('on_step')
    def after_step(self)  ->Literal['order',3]:  print('after_step')
    def failed_step(self) ->Literal['order',4]:  print('failed_step')
    def finally_step(self)->Literal['order',5]:  print('finally_step')
 
    def before_jump(self) ->Literal['order',6]:  print('before_jump')
    def on_jump(self)     ->Literal['order',7]:  print('on_jump')
    def after_jump(self)  ->Literal['order',8]:  print('after_jump')
    def failed_jump(self) ->Literal['order',9]:  print('failed_jump')
    def finally_jump(self)->Literal['order',10]: print('finally_jump')

class Inner(Loop):
    call_on=L(Outer.on_step,Outer.after_step,Outer.finally_jump)
    
    def before_iteration(self) ->Literal['order',1]: print('before_iteration')
    def on_iteration(self)     ->Literal['order',2]: print('on_iteration')
    def after_iteration(self)  ->Literal['order',3]: print('after_iteration')
    def failed_iteration(self) ->Literal['order',4]: print('failed_iteration')
    def finally_iteration(self)->Literal['order',5]: print('finally_iteration')
    
class FailingInner(Loop):
    call_on=L(Inner.failed_iteration)
    
    def on_force_fail(self) ->Literal['order',1]:                    
        print('on_force_fail')
        raise Exception
        
list(Outer.nodes([Inner,FailingInner],instantiate=True))

[before_step,
 on_step,
 (#6) [before_iteration,on_iteration,after_iteration,failed_iteration,[on_force_fail],finally_iteration],
 after_step,
 failed_step,
 (#6) [before_iteration,on_iteration,after_iteration,failed_iteration,[on_force_fail],finally_iteration],
 finally_step,
 before_jump,
 on_jump,
 after_jump,
 failed_jump,
 finally_jump,
 (#6) [before_iteration,on_iteration,after_iteration,failed_iteration,[on_force_fail],finally_iteration]]

In [72]:
nodes=Outer.nodes([Inner,FailingInner],instantiate=True)

In [73]:
def node_run(n):
    if isinstance(n,L): 
        for o in n: node_run(o)
    else:
        n.function()
        

In [74]:
node_run(nodes)

before_step
on_step
before_iteration
on_iteration
after_iteration
failed_iteration
on_force_fail


Exception: 

In [62]:
list(nodes)

[before_step,
 on_step,
 (#6) [before_iteration,on_iteration,after_iteration,failed_iteration,[on_force_fail],finally_iteration],
 after_step,
 (#6) [before_iteration,on_iteration,after_iteration,failed_iteration,[on_force_fail],finally_iteration],
 failed_step,
 finally_step,
 before_jump,
 on_jump,
 after_jump,
 failed_jump,
 finally_jump,
 (#6) [before_iteration,on_iteration,after_iteration,failed_iteration,[on_force_fail],finally_iteration]]

In [18]:
Node(FailingInner.on_force_fail) in Inner.call_on

False

In [26]:
n

(#6) [before_iteration,on_iteration,after_iteration,failed_iteration,finally_iteration,[on_force_fail]]

In [None]:
list(n)

In [None]:
out=inspect.signature(n[0].function)

In [None]:
out

In [None]:
n[0].function.__module__

In [None]:
test.__annotations__

In [None]:
Inner.call_on[0].__qualname__

In [None]:
n[0].function.__qualname__

In [None]:
def run_or_return(f):
    try: return f()
    except: return f

{v:run_or_return(getattr(n[0].function,v)) for v in n[0].function.__dir__() if v not in ['__globals__']}

In [None]:
n[0].function.__class__

In [None]:
n[0].function.__class__

In [None]:
n[0].function.__class__

In [None]:
n[0].function.__class__

In [None]:
n[0].function.__class__

In [None]:
n[0].function.__class__

In [None]:
n[0].function.__class__

In [None]:
n[0].function.__

In [None]:
n.function

In [None]:
class TestLoop:
    def on_iteration(self)->Literal['order',3]:pass
    def on_step(self)->Literal['order',3]:pass

In [None]:
# hide
from fastcore.imports import in_colab

# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.export import *
    from nbdev.export2html import *
    from nbverbose.cli import *
    make_readme()
    notebook2script()
    notebook2html()