In [497]:
#hide
#skip
%config Completer.use_jedi = False
%config IPCompleter.greedy=True
# upgrade fastrl on colab
! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \
                     apt-get install -y xvfb python-opengl > /dev/null 2>&1 
# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS

In [498]:
# default_exp fastai.loop

In [499]:
# export
# Python native modules
import os,sys,json
from copy import deepcopy,copy
from typing import *
import types
import logging
import inspect
from itertools import chain,product
from functools import partial
# Third party libs
from fastcore.all import *
import numpy as np
# Local modules
from fastrl.core import test_in

# Default for plain-Python runs. This flag gates installation of the custom
# IPython exception handler further down in this file.
# NOTE(review): presumably overridden by `from nbdev.imports import *` in the
# notebook-only cell below — confirm.
IN_IPYTHON=False

# Module-level logger named after this module
_logger=logging.getLogger(__name__)

In [500]:
# hide
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbverbose.showdoc import *
    from nbdev.imports import *
    # Unless the IN_TEST environment variable is set, insist that this is a
    # local (non-colab) notebook running under IPython
    if not os.environ.get("IN_TEST", None):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON
else:
    # Virtual display is needed for colab
    from pyvirtualdisplay import Display
    display=Display(visible=0,size=(400,300))
    display.start()

# Loop
> fastrl concept of generic loop objects. 

The goal of Loops is to make training loops easy to customize, and to make it clear
how each section of code connects to the other parts.

Ideally, a loop can be composed of inner loops, for example:
```python
FitLoop
BatchLoop
ClassicTrain
BatchLoop
StepLoop
ValidStepLoop
```
Which might have a structure:
```json
{
"FitLoop": {
    "FitLoop (epoch)": {
        "ClassicTrain (train)": {
            "BatchLoop (batches)": {
                "StepLoop (pred)": {},
                "StepLoop (loss)": {},
                "StepLoop (backward)": {},
                "StepLoop (step)": {},
                "StepLoop (zero_grad)": {}
            }
        },
        "ClassicTrain (valid)": {
            "BatchLoop (batches)": {
                "ValidStepLoop (pred)": {},
                "ValidStepLoop (loss)": {}
            }
        }
    }
}
}
```

In [501]:
# export
def map_obj_attr2func_attr(obj,fn):
    """Build the keyword arguments for calling `fn` from same-named attributes of `obj`.

    obj: object whose attributes supply the argument values.
    fn:  callable whose signature names the attributes to read.

    Returns a dict mapping each parameter name (except `self`, `*args` and
    `**kwargs`) to `getattr(obj,name)`. Parameters that declare a default fall
    back to that default when `obj` lacks the attribute.

    Raises AttributeError when a parameter without a default is missing on `obj`.
    """
    # Variadic parameters cannot be filled from a single attribute, and a `*args`
    # entry could not be passed back by keyword anyway.
    variadic=(inspect.Parameter.VAR_POSITIONAL,inspect.Parameter.VAR_KEYWORD)
    got_attrs={}
    for k,v in inspect.signature(fn).parameters.items():
        if k=='self' or v.kind in variadic: continue
        # Identity check against the public sentinel: `==` would invoke the
        # default value's own `__eq__`, which may not return a plain bool.
        if v.default is inspect.Parameter.empty:
            got_attrs[k]=getattr(obj,k)
        else:
            got_attrs[k]=getattr(obj,k,v.default)
    return got_attrs

In [502]:
class A():
    loss=0.5

# Stand-in with the same signature shape as a callback handler. It is defined
# here so that this cell runs on a fresh kernel: `OuterCallback` is only
# defined much further down the notebook.
def _demo_before_iteration(self,loss:int): pass

a=A()

map_obj_attr2func_attr(a,_demo_before_iteration)

{'loss': 0.5}

In [503]:
# export
EVENT_ORDER_MAPPING={}
PREFIXES=['before_','on_','after_','failed_','finally_']

def isevent(o): return issubclass(o.__class__,Event)
class EventException(Exception):pass
def _default_raise(_placeholder): raise
def grab_parent_event(o): return o.parent_event

class KwargSetAttr(object):
    # Defining `__setattr__` as a plain Python method lets callers pass
    # `name=` and `value=` by keyword, as `Loop.__init__` does through `L.map`.
    def __setattr__(self,name,value):
        "Allow setting attrs via kwarg."
        return super().__setattr__(name,value)

class Events(KwargSetAttr,L):
    "An `L` of events (or inner loops) that also carries naming and ordering metadata."
    def __init__(self,items=None,postfix=None,prefix=None,item_iter_hint='prefix',
                 order=0,parent_event=None,*args,**kwargs):
        # `items` goes to `L`; the remaining named arguments become attributes
        store_attr(but='items')
        super().__init__(items=items,*args,**kwargs)

    def flat(self):
        "Concatenate the iterables held in `self` into one `Events`, keeping this object's metadata."
        meta_names=('postfix','prefix','item_iter_hint','order','parent_event')
        meta={k:getattr(self,k) for k in meta_names}
        return Events(chain.from_iterable(self),**meta)

    def __lt__(self,o:'Event'): return self.order<o.order

    def todict(self):
        "Map each item's `item_iter_hint` attribute (by default `prefix`) to the item."
        key=self.item_iter_hint
        return {getattr(item,key):item for item in self}

    def __repr__(self):
        # Empty collections fall back to the parent repr
        if not len(self): return super().__repr__()
        body='\n'.join(map(str,self))
        return f'[{body}]'

    def run(self):
        "Call `run` on every item, in order."
        for o in self: o.run()                                                  # fastrl.skip_traceback

class Event(KwargSetAttr):
    """Wraps `function` as a named, ordered step of a loop.

    Calling the event runs `function(self.loop,...)` and then every callback in
    `self.cbs` that defines a method with this event's name. The name must start
    with one of `PREFIXES`, otherwise `EventException` is raised.
    """
    def __init__(self,
                 function:Callable,
                 loop=None,
                 override_name=None,
                 override_qualname=None,
                 override_module=None,
                 order=None
                ):
        store_attr()
        # A placeholder (`noop`) `failed_` event re-raises the exception being handled
        if self.function==noop and self.prefix=='failed_':
            self.function=_default_raise
        # We set the order over the entire Loop definition: each `outer_name`
        # keeps a running counter in EVENT_ORDER_MAPPING
        if self.order is None:
            if self.outer_name not in EVENT_ORDER_MAPPING: self.order=1
            else: self.order=EVENT_ORDER_MAPPING[self.outer_name]
            EVENT_ORDER_MAPPING[self.outer_name]=self.order+1

        # self.original_name=self.function.__module__+'.'+self.function.__qualname__
            
        if self.name.startswith('_') or not any(self.name.startswith(pre) for pre in PREFIXES):
            raise EventException(f'{self.name} needs to start with any {PREFIXES}')
            
        # Callbacks attached to this event; filled by `init_cbs`
        self.cbs=L()
        
    def climb(self):
        "Yield `self.loop` followed by each of its ancestor loops; yields nothing when `loop` is None"
        if self.loop is not None:
            yield from self.loop.climb()
            
    # Number of loops between this event and the root, inclusive
    @property
    def level(self): return len(list(self.climb()))
        
    # Build a placeholder event: the function is `noop` and only the name is set
    @classmethod
    def from_override_name(cls,name,**kwargs):
        return cls(noop,override_name=name,**kwargs)
  
    def init_cbs(self):
        "Append to `self.cbs` the root loop's callbacks that define this event's name and whose `call_on` allows it"
        # Callbacks always live on the last (root) loop of the climb
        cbs=L(self.climb())[-1].cbs
        # parent_events=[self.name]+[o.parent_event.name for o in self.climb() if o.parent_event is not None]
        # Qualified names of this event and of every non-None parent event above it
        parent_events=[self.qualname]+L(self.climb())\
                                   .map(grab_parent_event)\
                                   .filter(ifnone,b=False)\
                                   .map(Self.qualname())
        # Check if the callback has an event relevant to self
        for cb in L(cbs):
            if hasattr(cb,self.name):
                # A falsy `call_on` means "no restriction"
                if not cb.call_on or any(o.qualname in parent_events for o in cb.call_on):
                    self.cbs.append(cb)
        
    # Last loop yielded by `climb`, i.e. the outermost loop
    @property
    def root_loop(self): return list(self.climb())[-1]
    def __call__(self,*args,**kwargs): 
        ret=self.function(self.loop,*args,**kwargs)                             # fastrl.skip_traceback
        for cb in self.cbs: 
            fn=getattr(cb,self.name)
            # Handler arguments are read from same-named attributes of the root loop
            params=map_obj_attr2func_attr(self.root_loop,fn)
            
            cb_ret=fn(**params)
            
            # A dict returned by a handler is written back onto the root loop
            if isinstance(cb_ret,dict):
                loop=self.root_loop
                for k,v in cb_ret.items(): setattr(loop,k,v)
            
        return ret

    def __lt__(self,o:'Event'): return self.order<o.order
    # name/module/qualname come from the wrapped function unless overridden
    @property
    def name(self): return ifnone(self.override_name,self.function.__name__)
    @property
    def module(self): return ifnone(self.override_module,self.function.__module__)
    @property
    def qualname(self): return ifnone(self.override_qualname,self.function.__qualname__)
    # Text before the first underscore of `name`, plus the underscore (e.g. 'on_')
    @property
    def prefix(self): return self.name.split('_')[0]+'_'
    # Everything after the first underscore of `name` (e.g. 'step')
    @property
    def postfix(self): return '_'.join(self.name.split('_')[1:])
    # Module plus the first component of the qualified name; used as the key
    # into EVENT_ORDER_MAPPING
    @property
    def outer_name(self): return self.module+'.'+self.qualname.split('.')[0]
    # Always derived from the wrapped function, ignoring any overrides
    @property
    def original_name(self): 
        return self.function.__module__+'.'+self.function.__qualname__
    
    def __repr__(self): return self.module+'.'+self.name
    # Pair this event with an empty `Events` list (prefix + 'inner') that shares
    # its postfix and order
    def with_inner(self):
        return (self,Events(postfix=self.postfix,
                            prefix=self.prefix+'inner',
                            order=self.order))

# Lower-case alias so `Event` can be used as a method decorator (`@event`)
event=Event

In [504]:
# export
class Loops(L):
    "An `L` whose items can all be run in sequence."
    def run(self):
        "Call `run` on each item, in order."
        for loop in self: loop.run()                                            # fastrl.skip_traceback

class Loop(object):
    """A set of `Event`s, grouped into sections by postfix, that can be run in order.

    Every attribute of the instance that is an `Event` is collected, paired with
    an initially empty `Events` list for inner loops, and grouped by event
    postfix into `self.sections`. Prefixes that a section does not define are
    filled with copies of placeholder events.

    cbs:     callbacks stored on the loop; events look them up on the root loop.
    verbose: copied onto exceptions raised by `run` as `_show_loop_errors`.
    """
    def __init__(self,cbs:L=None,verbose:bool=False):
        store_attr()
        # When a loop is initialized, we need to make sure that the events
        # are re-initialized also
        events(self,reset=True)

        self.parent_loop=None
        self.parent_event=None

        _events=Events(inspect.getmembers(self)).map(Self[-1]).filter(isevent).sorted()
        # A loop that declares no events has nothing to borrow a module name
        # from, so fall back to the module of the loop's own class instead of
        # failing on `_events[0]`.
        default_module=_events[0].module if len(_events) else self.__class__.__module__
        # 1. Make Events have the same module as the function being run
        # 2. Convert the Events to Events+Inner Events
        # 3. Convert [(Event,[]*inner events*)...] to [Event,[]*inner events*...]
        # 4. Make sure they are sorted correctly
        self.default_events=Events(PREFIXES)\
            .map(Event.from_override_name,override_module=default_module)\
            .map(Event.with_inner)\
            .flat()\
            .sorted()
        self.events=_events.sorted().map(Event.with_inner).flat().sorted()
        # Point every event (and inner `Events` list) back at this loop
        self.events.map(Event.__setattr__,name='loop',value=self)
        self.sections=groupby(self.events,Self.postfix())
        for k,v in self.sections.items():
            # Events defined by the loop override the placeholder defaults
            self.sections[k]=merge(self.default_events.map(copy).todict(),
                                   Events(v).todict())

    def copy(self):
        "Return a freshly initialized loop of the same class that keeps this loop's `cbs` and `verbose`."
        new_loop=self.__class__()
        # `Loop.__init__` only stores these two values, so carrying them over by
        # assignment also works for subclasses whose `__init__` takes no arguments.
        new_loop.cbs,new_loop.verbose=self.cbs,self.verbose
        return new_loop

    def climb(self):
        "Yield this loop, then each ancestor reached through `parent_loop`, ending at the root"
        yield self
        if self.parent_loop is not None:
            yield from self.parent_loop.climb()

    def run(self):
        "Run every section in order; exceptions are tagged with `_show_loop_errors` and re-raised."
        try:                                                                    # fastrl.skip_traceback
            for v in self.sections.values(): run_section(v)                     # fastrl.skip_traceback
        except Exception as e:
            e._show_loop_errors=self.verbose
            raise
    

In [505]:
# export
class _Events():
    "Callable that turns prefix-named callables on a loop class or instance into `Event`s."
    def __call__(self,loop,reset=False):
        # Classes: only what is defined directly in the class `__dict__`.
        # Instances: everything `inspect.getmembers` reports.
        attrs=loop.__dict__.items() if isinstance(loop,type) else inspect.getmembers(loop)

        for attr_name,attr in attrs:
            if not callable(attr): continue
            if not any(map(attr_name.startswith,PREFIXES)): continue
            # Plain callables get wrapped; existing events are only rebuilt
            # (from their wrapped function) when `reset` is requested.
            if not isevent(attr): setattr(loop,attr_name,Event(attr))
            elif reset:           setattr(loop,attr_name,Event(attr.function))
        return loop

# Shared instance, usable as a class decorator (`@events`) or called directly
events=_Events()

In [506]:
# `@events` wraps every prefix-named method of the class in an `Event`
@events
class A(Loop):
    def on_step(self): print('on_step')
    

Check that we can re-initialize `A` and that the `Event`s also reinitialize. 

In [507]:
a=A()
other_a=A()
# Each instance gets its own `Event` object for `on_step`...
test_ne(id(a.on_step),id(other_a.on_step))
# ...and that event's `loop` is the instance that owns it
test_eq_type(a.on_step.loop,a)
a_copy=a.copy()
# Creating a copy leaves the original's event attached to the original
test_eq_type(a.on_step.loop,a)

In [508]:
# Root loop with two full sections (`step` and `jump`); `@events` wraps every
# prefix-named method in an `Event`
@events
class Outer(Loop):
    def before_step(self) :  print('before_step')
    def on_step(self)     :  print('on_step')
    def after_step(self)  :  print('after_step')
    def failed_step(self) :  print('failed_step')
    def finally_step(self):  print('finally_step')

    def before_jump(self) :  print('before_jump')
    def on_jump(self)     :  print('on_jump')
    def after_jump(self)  :  print('after_jump')
    def failed_jump(self) :  print('failed_jump')
    def finally_jump(self):  print('finally_jump')

# Same idea, but marking each event individually with the `@event` decorator
class Inner(Loop):
    # Events of `Outer` under which a copy of this loop gets attached
    call_on=L(Outer.on_step,Outer.after_step,Outer.finally_jump)
    
    @event
    def before_iteration(self) : print('before_iteration')
    @event
    def on_iteration(self)     : print('on_iteration')
    @event
    def after_iteration(self)  : print('after_iteration')
    @event
    def failed_iteration(self) : print('failed_iteration')
    @event
    def finally_iteration(self): print('finally_iteration')

# Loop whose only event always raises, used to exercise the failure path
class FailingInner(Loop):
    call_on=L(Inner.finally_iteration)
    
    @event
    def on_force_fail(self):                    
        print('on_force_fail')
        raise Exception

In [509]:
# export
def run_section(section:Dict):
    """Run one section of a loop.

    `section` maps each prefix ('before_', 'on_', ...) to a callable event and
    each prefix+'inner' key to an object with a `run` method.

    Order: before, on, after, each followed by its inner items. If an
    `Exception` escapes, `failed_` is called and the exception is re-raised,
    with `failed_inner` run on the way out. `finally_` and `finally_inner` are
    attempted in every case.
    """
    try:
        section['before_']()
        section['before_inner'].run()                                           # fastrl.skip_traceback
        section['on_']()
        section['on_inner'].run()
        section['after_']()
        section['after_inner'].run()                                            # fastrl.skip_traceback
    except Exception as ex:
        try:     
            section['failed_']()                                                # fastrl.skip_traceback
            # Re-raise the original exception once the failure event returns
            raise
        finally: 
            # Runs even though the exception is propagating
            section['failed_inner'].run()                                       # fastrl.skip_traceback
    finally:
        section['finally_']()
        section['finally_inner'].run()                                          # fastrl.skip_traceback

In [510]:
# export
def eq_loops(a:Loop,b:Loop):
    "True when `a` and `b` have the same class."
    cls_a,cls_b=a.__class__,b.__class__
    return cls_a==cls_b

# NOTE(review): `with_cast` presumably converts `loops` to `Loops` based on the
# annotation, which is why callers in this notebook can pass a tuple or a
# single loop — confirm against fastcore.
@with_cast
def connect_loops2loop(loops:Loops,to_loop):
    """Attach copies of `loops` underneath `to_loop` and return `to_loop`.

    For every loop in `loops` whose class differs from `to_loop`'s, each entry
    of its `call_on` that names one of `to_loop`'s events gets a fresh copy of
    that loop appended to the matching `<prefix>inner` list of `to_loop`. The
    function then recurses on the attached copy, so nested loops are attached
    too. `to_loop` is modified in place.
    """
    # Given `to_loop`, generate some fresh `loops`...
    loops=loops.map(Self.copy())
    # Original (un-overridden) names of the events defined on `to_loop`
    to_events=to_loop.events.filter(isevent).map(Self.original_name()) 
    for from_loop in loops.filter(eq_loops,b=to_loop,negate=True):
        for call_on in from_loop.call_on:
            if call_on.original_name in to_events:
                # One separate copy per attachment point
                _from_loop=from_loop.copy()
                
                _from_loop.parent_event=to_loop.sections[call_on.postfix][call_on.prefix]
                _from_loop.parent_loop=to_loop
                
                # Parent links must be set first: `init_cbs` climbs to the root loop
                _from_loop.events.filter(isevent).map(Self.init_cbs())

                to_loop.sections[call_on.postfix][call_on.prefix+'inner'].extend([_from_loop])
                connect_loops2loop(loops,_from_loop)
    return to_loop

Check that a single loop connects to the parent...

In [511]:
# Three single-event loops forming a chain: C hangs off B.on_event, which
# hangs off A.on_step
@events
class A(Loop):
    def on_step(self): print('on_step')
    
@events
class B(Loop):
    call_on=L(A.on_step)
    def on_event(self): print('on_event')
    
@events
class C(Loop):
    call_on=L(B.on_event)
    def on_second_step(self): print('on_second_step')

Check that the parent loop does not reference any non-existent parents. We also
expect it to have 1 section (`step`, from `on_step`).

In [512]:
a_instance=connect_loops2loop((B(),C()),A())
# The root has nothing above it...
test_eq(a_instance.parent_loop,None)
test_eq(a_instance.parent_event,None)
# ...so climbing from its own event ends at itself
test_eq(a_instance.on_step.root_loop,a_instance)
test_eq(len(a_instance.sections),1)

Check that the inner B loop properly references A, as well as the parent event `on_step`

In [515]:
# The attached copy of B lives in the `on_inner` list of A's `step` section
b_instance=a_instance.sections['step']['on_inner'][0]
test_eq_type(b_instance.__class__,B)
test_eq(b_instance.parent_loop,a_instance)
test_eq(b_instance.parent_event,a_instance.on_step)
# B's event belongs to this copy, and climbing from it reaches A
test_eq(id(b_instance.on_event.loop),id(b_instance))
test_eq(b_instance.on_event.loop.parent_loop,a_instance)
test_eq(b_instance.on_event.root_loop,a_instance)
test_eq(len(b_instance.sections),1)

Check that the inner C loop properly references B, as well as the parent event `on_event`.
We also should expect the root_loop to still be the a_instance.

In [534]:
# The attached copy of C lives in the `on_inner` list of B's `event` section
c_instance=b_instance.sections['event']['on_inner'][0]
test_eq_type(c_instance.__class__,C)
test_eq(c_instance.parent_loop,b_instance)
test_eq(c_instance.parent_event,b_instance.on_event)
test_eq(id(c_instance.on_second_step.loop),id(c_instance))
test_eq(c_instance.on_second_step.loop.parent_loop,b_instance)
# Two levels down, the root is still A
test_eq(c_instance.on_second_step.root_loop,a_instance)
test_eq(len(c_instance.sections),1)

In [535]:
# Wire Inner and FailingInner under a fresh Outer and keep its section tree for the cells below
sections=connect_loops2loop(Loops(FailingInner(),Inner()),Outer()).sections

In [536]:
# export
def dict2loops(d):
    "Recursively yield every `Loop` reachable from `d` (dicts, `Loops`/`Events`, and loops' own `sections`)."
    if isinstance(d,dict):
        children=d.values()
    elif isinstance(d,(Loops,Events)):
        children=d
    elif issubclass(d.__class__,Loop):
        # A loop is reported first, then searched for nested loops
        yield d
        children=(d.sections,)
    else:
        # Anything else holds no loops
        return
    for child in children:
        yield from dict2loops(child)
        
def dict2events(d):
    "Recursively yield every `Event` reachable from `d` (dicts, `Loops`/`Events`, and loops' `sections`)."
    if isinstance(d,dict):
        children=d.values()
    elif isinstance(d,(Loops,Events)):
        children=d
    elif issubclass(d.__class__,Loop):
        # Loops themselves are not reported, only searched
        children=(d.sections,)
    elif issubclass(d.__class__,Event):
        yield d
        return
    else:
        return
    for child in children:
        yield from dict2events(child)

In [537]:
# For every event in the tree, climb to the top and keep the last loop reached;
# events whose climb yields nothing are dropped by the bare `filter()`
list(L(dict2events(sections)).map(Self.climb()).map(L).filter().map(Self[-1]))

[<__main__.Outer at 0x7f742e0f5bb0>,
 <__main__.Outer at 0x7f742e0f5bb0>,
 <__main__.Outer at 0x7f742e0f5bb0>,
 <__main__.Outer at 0x7f742e0f5bb0>,
 <__main__.Outer at 0x7f742e0f5bb0>,
 <__main__.Outer at 0x7f742e0f5bb0>,
 <__main__.Outer at 0x7f742e0f5bb0>,
 <__main__.Outer at 0x7f742e0f5bb0>,
 <__main__.Outer at 0x7f742e0f5bb0>,
 <__main__.Outer at 0x7f742e0f5bb0>,
 <__main__.Outer at 0x7f742e0f5bb0>,
 <__main__.Outer at 0x7f742e0f5bb0>,
 <__main__.Outer at 0x7f742e0f5bb0>,
 <__main__.Outer at 0x7f742e0f5bb0>,
 <__main__.Outer at 0x7f742e0f5bb0>,
 <__main__.Outer at 0x7f742e0f5bb0>,
 <__main__.Outer at 0x7f742e0f5bb0>,
 <__main__.Outer at 0x7f742e0f5bb0>,
 <__main__.Outer at 0x7f742e0f5bb0>,
 <__main__.Outer at 0x7f742e0f5bb0>,
 <__main__.Outer at 0x7f742e0f5bb0>,
 <__main__.Outer at 0x7f742e0f5bb0>,
 <__main__.Outer at 0x7f742e0f5bb0>,
 <__main__.Outer at 0x7f742e0f5bb0>,
 <__main__.Outer at 0x7f742e0f5bb0>,
 <__main__.Outer at 0x7f742e0f5bb0>,
 <__main__.Outer at 0x7f742e0f5bb0>,
 

In [540]:
# export
class CallbackException(Exception):pass

class Callback(object):
    """Base class for objects whose event-named methods are run after a loop `Event`.

    call_on: events this callback is restricted to; a falsy value means it is
             attached to every event whose name it defines.
    loop:    the loop (or event) this callback belongs to; None until assigned.
    """
    call_on,loop=None,None

    @property
    def root(self):
        "The outermost loop reachable from `self.loop`."
        # Read the instance's own `loop` rather than a global of that name.
        # Both `Loop` and `Event` expose `climb`, and its last item is the root.
        return list(self.loop.climb())[-1]

In [541]:
class OuterCallback(Callback):
    # Restrict this callback to events reached through `Outer.on_step`
    call_on=L(Outer.on_step)

    # `loss` is filled from the root loop's `loss` attribute when the event fires.
    # NOTE(review): the return annotation is not a type; it looks like a sketch
    # of the keys a handler could return in a dict — confirm.
    def before_iteration(self,loss:int)->dict(this=list,that=str):
        print('   OuterCallback called lol')

In [542]:
# export
def map_obj_attr2func_attr(obj,fn):
    """Build the keyword arguments for calling `fn` from same-named attributes of `obj`.

    obj: object whose attributes supply the argument values.
    fn:  callable whose signature names the attributes to read.

    Returns a dict mapping each parameter name (except `self`, `*args` and
    `**kwargs`) to `getattr(obj,name)`. Parameters that declare a default fall
    back to that default when `obj` lacks the attribute.

    Raises AttributeError when a parameter without a default is missing on `obj`.
    """
    # Variadic parameters cannot be filled from a single attribute, and a `*args`
    # entry could not be passed back by keyword anyway.
    variadic=(inspect.Parameter.VAR_POSITIONAL,inspect.Parameter.VAR_KEYWORD)
    got_attrs={}
    for k,v in inspect.signature(fn).parameters.items():
        if k=='self' or v.kind in variadic: continue
        # Identity check against the public sentinel: `==` would invoke the
        # default value's own `__eq__`, which may not return a plain bool.
        if v.default is inspect.Parameter.empty:
            got_attrs[k]=getattr(obj,k)
        else:
            got_attrs[k]=getattr(obj,k,v.default)
    return got_attrs

In [543]:
# Minimal object exposing the `loss` attribute that the handler's signature asks for.
# Note that this rebinds the name `A` used for a loop class earlier in the notebook.
class A():
    loss=0.5
    
a=A()
    
# `self` is skipped, so only `loss` is looked up on `a`
map_obj_attr2func_attr(a,OuterCallback.before_iteration)

{'loss': 0.5}

In [544]:
# Same wiring as before, now with a callback stored on the root loop; display the section tree
connect_loops2loop(Loops(FailingInner(),Inner()),Outer(cbs=OuterCallback())).sections

{'jump': {'failed_': __main__.failed_jump,
  'failed_inner': [],
  'before_': __main__.before_jump,
  'before_inner': [],
  'on_': __main__.on_jump,
  'on_inner': [],
  'after_': __main__.after_jump,
  'after_inner': [],
  'finally_': __main__.finally_jump,
  'finally_inner': [<__main__.Inner object at 0x7f742ca7eb80>]},
 'step': {'failed_': __main__.failed_step,
  'failed_inner': [],
  'before_': __main__.before_step,
  'before_inner': [],
  'on_': __main__.on_step,
  'on_inner': [<__main__.Inner object at 0x7f742cb02ac0>],
  'after_': __main__.after_step,
  'after_inner': [<__main__.Inner object at 0x7f742cb6cf70>],
  'finally_': __main__.finally_step,
  'finally_inner': []}}

In [545]:
# With no callbacks on the root loop, the nested `before_iteration` event ends up with no callbacks
connect_loops2loop(Loops(FailingInner(),Inner()),Outer()).sections['step']['on_inner'][0].sections['iteration']['before_'].cbs

(#0) []

In [528]:
# The first argument may also be a single loop rather than a `Loops` collection
connect_loops2loop(FailingInner(),Inner()).sections

{'iteration': {'failed_': __main__.failed_iteration,
  'failed_inner': [],
  'before_': __main__.before_iteration,
  'before_inner': [],
  'on_': __main__.on_iteration,
  'on_inner': [],
  'after_': __main__.after_iteration,
  'after_inner': [],
  'finally_': __main__.finally_iteration,
  'finally_inner': [<__main__.FailingInner object at 0x7f742ce55730>]}}

In [529]:
# export
def _skip_traceback(s):
    return in_('# fastrl.skip_traceback',s)
    
def ipy_handle_exception(self, etype, value, tb, tb_offset):
    ## Do something fancy
    stb = self.InteractiveTB.structured_traceback(etype,value,tb,tb_offset=tb_offset)
    if not getattr(value,'_show_loop_errors',True):
        tmp,idxs=[],L(stb).argwhere(_skip_traceback)
        prev_skipped_idx=idxs[0] if idxs else 0
        for i,s in enumerate(stb):
            if i in idxs and i-1!=prev_skipped_idx: 
                msg='Skipped Loop Code due to # fastrl.skip_traceback found in source code,'
                msg+=' please use Loop(...verbose=True) to view loop tracebacks\n'
                tmp.append(msg)
            if i not in idxs:
                tmp.append(s)
            else:
                prev_skipped_idx=i
        stb=tmp
    ## Do something fancy
    self._showtraceback(type, value, stb)

# Register the handler for every `Exception`, but only when running under IPython
if IN_IPYTHON:
    get_ipython().set_custom_exc((Exception,),ipy_handle_exception)
    

In [530]:
# Disabled on purpose. NOTE(review): `Loop.run` takes no arguments in this
# notebook, so this call would raise a TypeError if enabled — update or remove.
if False: Outer(verbose=False).run(L(Inner(),FailingInner()),OuterCallback())

In [531]:
# Disabled on purpose. NOTE(review): `Loop.run` takes no arguments in this
# notebook, so this call would raise a TypeError if enabled — update or remove.
if False: Outer(verbose=True).run(L(Inner(),FailingInner()),OuterCallback())

In [532]:
# Expected to raise: `FailingInner.on_force_fail` raises once it is reached
connect_loops2loop(Loops(FailingInner(),Inner()),Outer(cbs=OuterCallback())).run()

before_jump
on_jump
after_jump
finally_jump
before_iteration
on_iteration
after_iteration
finally_iteration
on_force_fail


type: 

In [None]:
# The exception raised by `FailingInner.on_force_fail` must propagate out of `run`
with ExceptionExpected():
    connect_loops2loop(Loops(FailingInner(),Inner()),Outer()).run()

In [None]:
# hide
from fastcore.imports import in_colab

# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbdev.export import *
    from nbdev.export2html import *
    from nbverbose.cli import *
    # Regenerate the README, the exported modules and the HTML docs from the notebooks
    make_readme()
    notebook2script()
    notebook2html()