In [1059]:
#hide
#skip
%config Completer.use_jedi = False
# upgrade fastrl on colab
# `[ -e /content ]` only succeeds where a /content directory exists, which is what limits this
# install to Colab. NOTE(review): package versions are unpinned, so results may drift over time.
! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \
                     apt-get install -y xvfb python-opengl > /dev/null 2>&1 
# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS

In [1060]:
# hide
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbverbose.showdoc import *
    from nbdev.imports import *
    # NOTE(review): `os` is not imported in this cell (the explicit `import os` is in a later
    # cell); presumably it arrives via one of the star imports above -- confirm.
    if not os.environ.get("IN_TEST", None):
        # Sanity-check the environment flags: a notebook, not Colab, running under IPython.
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON
else:
    # Virtual display is needed for colab
    # NOTE(review): binding the name `display` shadows IPython's `display()` helper for the
    # rest of the notebook; a different variable name would avoid that.
    from pyvirtualdisplay import Display
    display=Display(visible=0,size=(400,300))
    display.start()

In [1061]:
# default_exp fastai.loop

In [1062]:
# export
# Python native modules
import os
from copy import deepcopy,copy
from typing import *
import types
import logging
import inspect
# Third party libs
from fastcore.all import *
import numpy as np
# Local modules

_logger=logging.getLogger(__name__)

# Loop
> fastrl's concept of generic loop objects.

The goal for Loops is to make them easy to customize, and to make it easy to know how sections of code connect
to other parts.

### Why do we need this?
We have identified at least 3 different kinds of loops already:

- Learner (training)
- Source/Gym (data access)
- Agent (how an AI takes in data and generates actions)

### What is a loop?

- It should be capable of containing inner loops.
- It should be able to handle "phases" that might be similar to each other.
- It should self-describe its structure.
- It should be easy to know which parts of the loop are taking long/short amounts of time.
- It should be flexible in state modification.
- It should alternatively make it easy to show which fields are being changed at which points in time.

A Loop will act as a compiled structure. The actual result will be a compiled list of nodes that reference the original loop.

Loop is a compiled object that organizes the callbacks and loop calls into a possibly repeating sequence.

`Literal['order',int]` carries an int that should be greater than or equal to 0. It determines when a
function in a loop is executed relative to the other functions: a loop's nodes are sorted by this value.

In [1063]:
# export
# Method-name prefixes that mark a function as a loop Node or a callback hook (see `Node.isvalid`).
PREFIXES = [
    'before_',
    'on_',
    'after_',
    'failed_',
    'finally_',
]

class NodeException(Exception):
    "Raised when a function cannot be turned into a `Node` (it lacks a `Literal['order',int]` annotation)."

def unwrap_nodes(n:'Node'):
    "Yield `n`, then each of its ancestors, following `parent` links until None."
    while n is not None:
        yield n
        n=n.parent

def is_relevant_cb(cb:'Callback',n:'Node'):
    """True if callback `cb` has a hook named like node `n` and, when `cb.call_on` is set,
    `n` or one of its ancestors is listed in `cb.call_on`."""
    # The annotation is a string because `Callback` is defined in a later cell: a bare name
    # raises NameError on a fresh kernel and when the exported module is imported.
    if not hasattr(cb,n.name): return False
    # `call_on=None` means the callback is not restricted to any part of the loop graph.
    if cb.call_on is None: return True
    return any(o in cb.call_on for o in unwrap_nodes(n))

class Node(object):
    "One step of a compiled `Loop`: a loop function plus the callbacks relevant to it."
    def __init__(self,
                 function:Callable,
                 parent:Optional['Node']=None,
                 children:Optional[List['Node']]=None,
                 loop:'Loop'=None):
        # Saves function/parent/children/loop on `self`; `cbs` is built below.
        store_attr(but='cbs')
        self.order=None
        # The run order comes from a `Literal['order',int]` return annotation.
        for anno in L(anno_ret(self.function)):
            # Annotations without `__args__` (plain classes, dicts, ...) cannot carry an order;
            # skip them so the NodeException below is raised instead of an AttributeError.
            args=getattr(anno,'__args__',())
            if 'order' in args:
                self.order=args[-1]
        if self.order is None:
            raise NodeException(f'Node: {self.name} needs Literal["order",int]')
        self.cbs=[]
        if loop is not None:
            # Instantiate every callback class from the loop hierarchy that handles this node.
            self.cbs=L(cb() for cb in L(loop.get_cbs()) if is_relevant_cb(cb,self))

    def run(self):
        "Call the wrapped function, then each relevant callback's same-named hook with `n=self`."
        # NOTE(review): the function's return value is discarded, so the dict handling in
        # `Loop.run` never receives anything from here -- confirm this is intended.
        self.function()
        for cb in self.cbs: getattr(cb,self.name)(n=self)

    def __repr__(self):
        if not self.cbs: return self.name
        # Every callback after the first is prefixed with a down-arrow separator.
        strs=[cb.name(self.name,len(self.cbs)>i>0) for i,cb in enumerate(self.cbs)]
        # Not every Loop defines `tab` (the base class does not), so default to no indent.
        return self.name+'\n'+getattr(self.loop,'tab','')+'\n'.join(strs)+'\n'

    def __lt__(self,o:'Node'): return self.order<o.order
    def __eq__(self,o:Union['Node',Callable]):
        "Nodes compare equal to Nodes or functions with the same `__qualname__`."
        other=getattr(getattr(o,'function',o),'__qualname__',None)
        # Let Python fall back to its default comparison for unrelated objects.
        if other is None: return NotImplemented
        return other==self.function.__qualname__

    @property
    def name(self): return self.function.__name__
    @property
    def prefix(self): return self.name.split('_')[0]
    @property
    def postfix(self): return '_'.join(self.name.split('_')[1:])

    @classmethod
    def isvalid(cls,name)->bool:
        "True for public names (or functions) starting with one of `PREFIXES`."
        if isinstance(name,Callable): name=name.__name__
        return not name.startswith('_') and \
          any(name.startswith(pre) for pre in PREFIXES)

Nodes are self-organizing parts of the Loop graph. Their execution order is
determined by their `order` annotation and their position in the graph.

In [1064]:
# Minimal Node: the `Literal["order",1]` return annotation supplies the run order (Node raises without it).
def on_test()->Literal["order",1]:pass
Node(on_test)

on_test

In [1065]:
# The raw annotations: Node reads its order from the 'return' entry.
anno_dict(on_test)

{'return': typing.Literal['order', 1]}

In [1066]:
# Extra parameters (here `n`) do not affect Node creation; only the return annotation is read.
def on_test(n:Node=None)->Literal["order",1]:pass
Node(on_test)

on_test

In [1067]:
# `n` now appears alongside the 'return' annotation.
anno_dict(on_test)

{'n': __main__.Node, 'return': typing.Literal['order', 1]}

In [1068]:
# Redefine the zero-argument `on_test` from the first demo (this shadows the version with `n`).
def on_test()->Literal["order",1]:pass
Node(on_test)

on_test

In [1069]:
# Back to just the 'return' annotation.
anno_dict(on_test)

{'return': typing.Literal['order', 1]}

In [1070]:
# export
def intersection_idxs(a:L,b:L,after_insert=False):
    """Get the indices where the (non-list) elements of `a` appear in `b`.

    If `after_insert` is True, each index is shifted by its position + 1, so inserting
    one item at every returned index, in order, places each new item directly after its
    matching element of `a`.

    Always returns an integer numpy array -- also when nothing matches -- so the result
    is safe to use for indexing. `a` may be any iterable (e.g. an `L` or a list).
    """
    # Nested lists are already-inserted sub-loops and are never matched against `b`.
    idxs=np.array([i for i,o in enumerate(a) if not isinstance(o,(list,L)) and o in b],dtype=int)
    if after_insert: return idxs+np.arange(1,len(idxs)+1)
    return idxs

In [1071]:
# Twelve single-letter items to search in.
original_list=L('a b c d e f g h i j k l'.split())

In [1072]:
# Positions of d, g and l in `original_list`.
idxs=intersection_idxs(original_list,L('d g l'.split()));idxs

array([ 3,  6, 11])

In [1073]:
# Shifted by +1, +2, +3: inserting at these positions in turn lands right after each match.
offset_idxs=intersection_idxs(original_list,L('d g l'.split()),after_insert=True);offset_idxs

array([ 4,  8, 14])

In [1074]:
# Insert at the offset indices in order; each 'thing' should follow d, g and l respectively.
inserted_list=deepcopy(original_list)
for i in offset_idxs: inserted_list.insert(i,'thing')
list(inserted_list)

['a',
 'b',
 'c',
 'd',
 'thing',
 'e',
 'f',
 'g',
 'thing',
 'h',
 'i',
 'j',
 'k',
 'l',
 'thing']

In [1075]:
# The inserted items sit exactly at `offset_idxs`.
test_eq(offset_idxs,inserted_list.argwhere(lambda o:o=='thing'))

In [1076]:
# export
class CallbackException(Exception):pass  # raised for callback hooks with an invalid signature

class Callback(object):
    "Base class for loop callbacks: hook methods named after loop nodes run right after those nodes."
    # `call_on`: optional collection of loop functions restricting where this callback fires
    # (None means wherever its hook names match). `loop`: assigned by `Loop.get_cbs`.
    call_on,loop=None,None

    def __new__(cls,*args,**kwargs):
        # Validate hook signatures: every prefixed hook must declare an annotated `n` parameter.
        for k,fn in inspect.getmembers(cls):
            if not Node.isvalid(k): continue
            if 'n' not in anno_dict(fn):
                msg=f'Function {k} of {cls} needs to have n:Node=None as an arg.'
                raise CallbackException(msg)
        # `object.__new__` rejects extra arguments when `__new__` is overridden, so they are not
        # forwarded here; `__init__` still receives them.
        return super().__new__(cls)

    def __repr__(self): return type(self).__name__

    def name(self,event:str=None,add_arrow=False):
        """Display name of the callback.

        `event`: hook name whose return annotation is appended as ` -> ...`.
        `add_arrow`: prefix a centred down-arrow line, used as a separator between callbacks.
        """
        name=str(self)
        if event is not None:
            # Annotations are only read, so no copy is needed.
            annotations=anno_dict(getattr(self,event))
            if 'return' in annotations:
                name+=f' -> {annotations["return"]}'
        if add_arrow: name=' '*(len(name)//2)+'\u2193'+'\n'+name
        return name

In [1077]:
# The down-arrow character `Callback.name` uses as a separator.
print('\u2193')

↓


In [1078]:
class GoodCallback(Callback):
    # Valid: the hook declares `n:Node=None`; its dict return annotation is what `name()` displays.
    def before_iteration(self,n:Node=None)->dict(this=list,that=str):
        print(n.loop.tab,n.loop.common_obj,'   OuterCallback called lol')
        
class BadCallback(Callback):
    # Intentionally invalid: the hook has no `n` parameter, so instantiating this class raises
    # CallbackException (the body, which refers to an undefined `n`, never runs).
    def after_iteration(self):
        print(n.loop.tab,n.loop.common_obj,'   OuterCallback called lol')

In [1079]:
# Name plus the hook's return annotation.
GoodCallback().name('before_iteration')

"GoodCallback -> {'this': <class 'list'>, 'that': <class 'str'>}"

In [1080]:
# Same, with the centred down-arrow separator line on top.
print(GoodCallback().name('before_iteration',True))

                               ↓
GoodCallback -> {'this': <class 'list'>, 'that': <class 'str'>}


In [1081]:
# Without an event, the name is just the class name.
GoodCallback().name()

'GoodCallback'

In [1082]:
# A valid callback instantiates; an invalid one must fail with the signature message
# (checking `contains` keeps unrelated errors from passing the test).
GoodCallback()
print(GoodCallback().name())
test_fail(BadCallback, contains='needs to have n:Node=None')

GoodCallback


In [1083]:
# export
class Loop(object):
    "A loop whose prefixed methods (`before_`, `on_`, ...) compile into an ordered list of `Node`s."
    # `_common_obj`: shared object (falls back to the parent loop's); `cbs`: callback classes
    # attached to this loop; `parent_loop`: assigned by `nodes` when nested inside another loop.
    _common_obj,cbs,parent_loop=None,None,None

    def __init__(self,common_obj=None,cbs=None):
        self._common_obj=ifnone(common_obj,self._common_obj)
        self.cbs=ifnone(cbs,self.cbs)

    @classmethod
    def from_nodes(cls,loops,**kwargs):
        "Instantiate `cls(**kwargs)` and compile its nodes, nesting the sub-loops in `loops`."
        loop=cls(**kwargs)
        # This instance attribute shadows the `nodes` classmethod on `loop` from here on.
        loop.nodes=loop.nodes(loops,instantiate=True,instance=loop)
        return loop

    def __call__(self): return self

    def get_cbs(self):
        "Callback classes of all ancestor loops followed by this loop's own."
        # Each callback class records the loop it is attached to.
        for cb in L(self.cbs): cb.loop=self
        if self.parent_loop is None: return L(self.cbs)
        return self.parent_loop.get_cbs()+L(self.cbs)

    @property
    def common_obj(self):
        "This loop's shared object, else the parent loop's; None when neither is set."
        if self._common_obj is not None: return self._common_obj
        # At the root there is no parent to defer to.
        return None if self.parent_loop is None else self.parent_loop.common_obj

    @classmethod
    def nodes(cls,loops,parent_node=None,instantiate=False,instance=None):
        """Compile this loop's prefixed methods into `Node`s sorted by order, inserting a
        compiled sub-loop from `loops` right after each node listed in its `call_on`."""
        loops=L(loops)
        # Only build a fresh instance when instantiating and none was supplied.
        if instantiate: loop=instance if instance is not None else cls()
        else:           loop=cls
        nodes=L(Node(n,loop=loop,parent=parent_node)
                for k,n in inspect.getmembers(loop) if Node.isvalid(k))
        nodes=nodes.sorted()
        for l in loops:
            call_on=getattr(l,'call_on',None)
            # A loop without `call_on` declares no insertion points.
            if call_on is None: continue
            for idx in intersection_idxs(nodes,call_on,after_insert=True):
                # NOTE: this sets `parent_loop` on the sub-loop class, shared by its instances.
                l.parent_loop=loop
                nodes.insert(idx,l.nodes(loops,parent_node=nodes[idx-1],
                                         instantiate=instantiate))
                nodes[idx-1].children=nodes[idx]
        return nodes

    def run(self,nodes=None):
        "Run `nodes` (default: the compiled `self.nodes`) in order; nested lists run on their own loop."
        for n in ifnone(nodes,self.nodes):
            if isinstance(n,(L,list)):
                # A nested list is a compiled sub-loop, run by the loop that owns its nodes.
                if n: n[0].loop.run(n)
                continue
            ret=n.run()
            # A dict returned by a node is written onto the shared object as attributes.
            if isinstance(ret,dict) and self.common_obj is not None:
                for k,v in ret.items(): setattr(self.common_obj,k,v)

                
# NOTE(review): these demo classes (and the node listing that follows them) sit in the cell
# marked `# export`, so nbdev exports them into the library module -- consider a separate cell.
class Learner(object):pass
class Outer(Loop):
    # A plain class attribute that shadows the `Loop.common_obj` property, so loops nested
    # under Outer resolve their shared object to `Learner`.
    common_obj=Learner
    
    def run(self,nodes=None):
        print('--- ENTERING OUTER LOOP ---')
        super().run(nodes)
        print('--- EXITING OUTER LOOP ---')
    
    # Two phases ("step", "jump"); the Literal order fixes their sequence (1..10).
    def before_step(self) ->Literal['order',1]:  print('before_step')
    def on_step(self)     ->Literal['order',2]:  print('on_step')
    def after_step(self)  ->Literal['order',3]:  print('after_step')
    def failed_step(self) ->Literal['order',4]:  print('failed_step')
    def finally_step(self)->Literal['order',5]:  print('finally_step')
 
    def before_jump(self) ->Literal['order',6]:  print('before_jump')
    def on_jump(self)     ->Literal['order',7]:  print('on_jump')
    def after_jump(self)  ->Literal['order',8]:  print('after_jump')
    def failed_jump(self) ->Literal['order',9]:  print('failed_jump')
    def finally_jump(self)->Literal['order',10]: print('finally_jump')

class Inner(Loop):
    # A compiled copy of this loop is inserted right after each of these Outer nodes.
    call_on=L(Outer.on_step,Outer.after_step,Outer.finally_jump)
    
    # Indentation used by the prints below and by Node's repr.
    tab='\t'
    
    def run(self,nodes=None):
        print(self.tab,'--- ENTERING INNER LOOP ---')
        super().run(nodes)
        print(self.tab,'--- EXITING INNER LOOP ---')
    
    def before_iteration(self) ->Literal['order',1]: print(self.tab,'before_iteration')
    def on_iteration(self)     ->Literal['order',2]: print(self.tab,'on_iteration')
    def after_iteration(self)  ->Literal['order',3]: print(self.tab,'after_iteration')
    def failed_iteration(self) ->Literal['order',4]: print(self.tab,'failed_iteration')
    def finally_iteration(self)->Literal['order',5]: print(self.tab,'finally_iteration')
    
class FailingInner(Loop):
    # Nested one level deeper: runs after each Inner.failed_iteration node.
    # Note: `failed_*`/`finally_*` are ordinary ordered nodes here -- Loop.run has no exception
    # handling, so they run on every pass (see the outputs), not only on failure.
    call_on=L(Inner.failed_iteration)
    
    tab='\t\t'
    
    def run(self,nodes=None):
        print(self.tab,'--- ENTERING FailingInner LOOP ---')
        super().run(nodes)
        print(self.tab,'--- EXITING FailingInner LOOP ---')
    
    def on_force_fail(self) ->Literal['order',1]:                    
        print(self.tab,'on_force_fail')
        # raise Exception
        
# Show the compiled structure: nested lists are the inserted Inner/FailingInner sub-loops.
list(Outer.nodes([Inner,FailingInner],instantiate=True))

[before_step,
 on_step,
 (#6) [before_iteration,on_iteration,after_iteration,failed_iteration,[on_force_fail],finally_iteration],
 after_step,
 (#6) [before_iteration,on_iteration,after_iteration,failed_iteration,[on_force_fail],finally_iteration],
 failed_step,
 finally_step,
 before_jump,
 on_jump,
 after_jump,
 failed_jump,
 finally_jump,
 (#6) [before_iteration,on_iteration,after_iteration,failed_iteration,[on_force_fail],finally_iteration]]

In [1084]:
# Compile and execute: the prints trace the nesting order of the node list above.
Outer.from_nodes([Inner,FailingInner]).run()

--- ENTERING OUTER LOOP ---
before_step
on_step
	 --- ENTERING INNER LOOP ---
	 before_iteration
	 on_iteration
	 after_iteration
	 failed_iteration
		 --- ENTERING FailingInner LOOP ---
		 on_force_fail
		 --- EXITING FailingInner LOOP ---
	 finally_iteration
	 --- EXITING INNER LOOP ---
after_step
	 --- ENTERING INNER LOOP ---
	 before_iteration
	 on_iteration
	 after_iteration
	 failed_iteration
		 --- ENTERING FailingInner LOOP ---
		 on_force_fail
		 --- EXITING FailingInner LOOP ---
	 finally_iteration
	 --- EXITING INNER LOOP ---
failed_step
finally_step
before_jump
on_jump
after_jump
failed_jump
finally_jump
	 --- ENTERING INNER LOOP ---
	 before_iteration
	 on_iteration
	 after_iteration
	 failed_iteration
		 --- ENTERING FailingInner LOOP ---
		 on_force_fail
		 --- EXITING FailingInner LOOP ---
	 finally_iteration
	 --- EXITING INNER LOOP ---
--- EXITING OUTER LOOP ---


In [1085]:
class OuterCallback(Callback):
    # Only fires for nodes nested (at any depth) under Outer.on_step.
    call_on=L(Outer.on_step)
    
    def before_iteration(self,n:Node=None)->dict(this=list,that=str):
        # `n` is the Node being run; its loop supplies `tab` and the shared `common_obj`.
        print(n.loop.tab,n.loop.common_obj,'   OuterCallback called lol')

In [1089]:
# The same callback is attached twice to show how the repr separates several callbacks on one node.
list(Outer.from_nodes([Inner,FailingInner],cbs=L(OuterCallback,OuterCallback)).nodes)

[before_step,
 on_step,
 (#6) [before_iteration
	OuterCallback -> {'this': <class 'list'>, 'that': <class 'str'>}
                                ↓
OuterCallback -> {'this': <class 'list'>, 'that': <class 'str'>}
,on_iteration,after_iteration,failed_iteration,[on_force_fail],finally_iteration],
 after_step,
 (#6) [before_iteration,on_iteration,after_iteration,failed_iteration,[on_force_fail],finally_iteration],
 failed_step,
 finally_step,
 before_jump,
 on_jump,
 after_jump,
 failed_jump,
 finally_jump,
 (#6) [before_iteration,on_iteration,after_iteration,failed_iteration,[on_force_fail],finally_iteration]]

In [1090]:
# The callback fires only in the Inner pass nested under on_step, printing Outer's common_obj.
Outer.from_nodes([Inner,FailingInner],cbs=L(OuterCallback)).run()

--- ENTERING OUTER LOOP ---
before_step
on_step
	 --- ENTERING INNER LOOP ---
	 before_iteration
	 <class '__main__.Learner'>    OuterCallback called lol
	 on_iteration
	 after_iteration
	 failed_iteration
		 --- ENTERING FailingInner LOOP ---
		 on_force_fail
		 --- EXITING FailingInner LOOP ---
	 finally_iteration
	 --- EXITING INNER LOOP ---
after_step
	 --- ENTERING INNER LOOP ---
	 before_iteration
	 on_iteration
	 after_iteration
	 failed_iteration
		 --- ENTERING FailingInner LOOP ---
		 on_force_fail
		 --- EXITING FailingInner LOOP ---
	 finally_iteration
	 --- EXITING INNER LOOP ---
failed_step
finally_step
before_jump
on_jump
after_jump
failed_jump
finally_jump
	 --- ENTERING INNER LOOP ---
	 before_iteration
	 on_iteration
	 after_iteration
	 failed_iteration
		 --- ENTERING FailingInner LOOP ---
		 on_force_fail
		 --- EXITING FailingInner LOOP ---
	 finally_iteration
	 --- EXITING INNER LOOP ---
--- EXITING OUTER LOOP ---


In [1088]:
# hide
from fastcore.imports import in_colab

# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.export import *
    from nbdev.export2html import *
    from nbverbose.cli import *
    make_readme()
    notebook2script()
    notebook2html()

converting /home/fastrl_user/fastrl/nbs/index.ipynb to README.md
Converted 00_core.ipynb.
Converted 00_nbdev_extension.ipynb.
Converted 02_fastai.loop.ipynb.
Converted 02_fastai.loop.old.ipynb.
Converted 03_callback.core.ipynb.
Converted 04_agent.ipynb.
Converted 05_data.test_async.ipynb.
Converted 05a_data.block.ipynb.
Converted 05b_data.gym.ipynb.
Converted 06a_memory.experience_replay.ipynb.
Converted 06f_memory.tensorboard.ipynb.
Converted 10a_agents.dqn.core.ipynb.
Converted 10b_agents.dqn.targets.ipynb.
Converted 10c_agents.dqn.double.ipynb.
Converted 10d_agents.dqn.dueling.ipynb.
Converted 10e_agents.dqn.categorical.ipynb.
Converted 11a_agents.policy_gradient.ppo.ipynb.
Converted 20_test_utils.ipynb.
Converted index.ipynb.
Converted nbdev_template.ipynb.
converting: /home/fastrl_user/fastrl/nbs/02_fastai.loop.ipynb
An error occurred while executing the following cell:
------------------
from nbverbose.showdoc import show_doc
from fastrl.fastai.loop import *
------------------

