In [1]:
#hide
#skip
%config Completer.use_jedi = False
%config IPCompleter.greedy=True
# upgrade fastrl on colab
! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \
                     apt-get install -y xvfb python-opengl > /dev/null 2>&1 
# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS

In [3]:
# export
# Python native modules
import os,sys
from copy import deepcopy,copy
from typing import *
import types
import logging
import inspect
from itertools import chain
from functools import partial
# Third party libs
from fastcore.all import *
import numpy as np
# Local modules

IN_IPYTHON=False

_logger=logging.getLogger(__name__)

In [4]:
# hide
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
# Outside Colab, pull in the nbdev/nbverbose notebook helpers and sanity-check the environment flags.
if not in_colab():
    from nbverbose.showdoc import *
    from nbdev.imports import *
    # The environment assertions are skipped when the IN_TEST environment variable is set.
    if not os.environ.get("IN_TEST", None):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON
else:
    # Virtual display is needed for colab
    from pyvirtualdisplay import Display
    display=Display(visible=0,size=(400,300))
    display.start()

# Loop
> fastrl concept of generic loop objects. 

The goal of Loops is to make them easy to customize and to make it clear how each section of code
connects to the other parts.

### Why do we need this?
We have identified at least 3 different kinds of loops already:

    Learner (training)
    Source/Gym (Data Access)
    Agent (How an AI takes in data, generates actions)

### What is a loop?

    It should be capable of containing inner loops. 
    It should self-describe its structure. 
    It should be easy to know which parts of the loop are taking long/short amounts of time.
    It should be flexible in state modification.
    It should alternatively make it easy to show what fields are being changed at what points in time.

A Loop will act as a compiled structure. The actual result will be a compiled list of nodes that reference the original loop.

In [5]:
# export
# Next order number to hand out, keyed by outer name (`module` + first `__qualname__` part);
# read and incremented in `Event.__init__`.
EVENT_ORDER_MAPPING={}
# Prefixes an event function's name has to start with; `Section` keys its event dict by these.
PREFIXES=['before_','on_','after_','failed_','finally_']

def _is_event(o):
    "Return `True` when `o` is an `Event` instance (subclasses included)."
    return isinstance(o,Event)
def _last_element(ls): return ls[-1]
def _grab_full_name(o:'Event'): return o.full_name

class EventException(Exception):
    "Raised by `Event` when the wrapped function's name lacks one of the required prefixes."

def event_parent_iter(event:'Event'):
    "Yield `event`, then each ancestor reached by following `parent_event` until it is `None`."
    yield event
    ancestor=event.parent_event
    while ancestor is not None:
        yield ancestor
        ancestor=ancestor.parent_event
    
def custom_traceback():
    "Placeholder: currently only prints a fixed marker string."
    marker='hi'
    print(marker)

class Event(object):
    """Wrap `function` as a named loop event that callbacks can be attached to.

    The function's name has to start with one of `PREFIXES` (and must not start
    with an underscore), otherwise `EventException` is raised. Every event gets
    an `order` number counted per outer name in `EVENT_ORDER_MAPPING`; `__lt__`
    compares on it so events can be sorted into definition order.
    """
    def __init__(self,
                 function:Callable,
                 loop=None,
                 parent_event=None
                ):
        store_attr()
        module,qualname=function.__module__,function.__qualname__
        # We set the order over the entire Loop definition
        self.outer_name=module+'.'+qualname.split('.')[0]
        self.order=EVENT_ORDER_MAPPING.get(self.outer_name,1)
        EVENT_ORDER_MAPPING[self.outer_name]=self.order+1

        self.full_name=module+'.'+qualname

        has_known_prefix=any(self.name.startswith(pre) for pre in PREFIXES)
        if self.name.startswith('_') or not has_known_prefix:
            raise EventException(f'{self.name} needs to start with any {PREFIXES}')

        self.cbs=L()

    def set_cbs(self,cbs=None):
        "Keep the callbacks from `cbs` that define a method named like this event and apply to it."
        if cbs is None: return

        def _applies(cb):
            # A callback must expose a method with this event's name ...
            if not hasattr(cb,self.name): return False
            # ... and either be unrestricted (`call_on` falsy) or target this event or an ancestor.
            if not cb.call_on: return True
            return any([isrelevent(cb,e) for e in event_parent_iter(self)])

        # Callback classes are instantiated; instances are used as given.
        self.cbs=L((cb() if isinstance(cb, type) else cb) for cb in L(cbs) if _applies(cb))

    @property
    def root_loop(self):
        "The `root_loop` of the loop this event is bound to."
        return self.loop.root_loop

    def __call__(self,*args,**kwargs): 
        ret=self.function(self.loop,*args,**kwargs)                             # fastrl.skip_traceback
        for cb in self.cbs: getattr(cb,self.name)()
        return ret

    def __lt__(self,o:'Event'):
        "Order events by their `order` number."
        return self.order<o.order

    @property
    def name(self):
        "`__name__` of the wrapped function."
        return self.function.__name__

    @property
    def prefix(self):
        "Text before the first underscore of `name`, with the underscore appended (e.g. `on_`)."
        return self.name.partition('_')[0]+'_'

    @property
    def postfix(self):
        "Everything after the first underscore of `name`."
        return self.name.partition('_')[2]

    def __repr__(self): return self.full_name

    def show(self,n_tab=0,n_cbs=False,include_cbs=False,**kwargs):
        "Return an indented text description of the event, optionally listing its callbacks."
        opts=merge(dict(n_cbs=n_cbs,include_cbs=include_cbs),kwargs)
        root=None if self.loop is None else self.loop.root_loop
        tab='\t' if root is None else root.tab
        indent=n_tab*tab
        header=indent+str(self.full_name)
        # The callback count is appended whenever callbacks are attached.
        if self.cbs: header+=f' #{len(self.cbs)}'
        lines=[header]
        if include_cbs:
            lines.extend(indent+cb.show(n_tab=n_tab+1,**opts) for cb in self.cbs)
        return '\n'.join(lines)
    
def isrelevent(loop_or_cb,event:Event):
    "Return `True` when `event.full_name` matches one of the events listed in `loop_or_cb.call_on`."
    # Lists of nested loops can show up where an event is expected; they never match.
    if isinstance(event,(list,L)): return False
    target_names=loop_or_cb.call_on.map(_grab_full_name)
    return event.full_name in target_names

# Lowercase alias of `Event`, used below in decorator form (`@event`).
event=Event

In [6]:
# Minimal example: a module-level function whose name carries the required `on_` prefix.
def on_test():pass

# Wrapping the function directly (no decorator syntax) yields an `Event`.
decorated_on_test=Event(on_test)

In [7]:
decorated_on_test.full_name

'__main__.on_test'

In [8]:
class B():
    "Example showing `@event` applied to a function defined inside a class body."
    @event
    def on_test_2_a(self):
        pass

In [9]:
B.on_test_2_a.full_name

'__main__.B.on_test_2_a'

In [10]:

class A(object):
    """Example event container used to exercise `Event` and `Section` without `Loop`.

    `loop` and `root_loop` are plain class attributes set to `None`; `Event.show`
    reads `self.loop.root_loop` and falls back to a tab character when it is `None`.
    """
    loop,root_loop=None,None

    @classmethod
    def events(cls,instance=None,instantiate=False):
        "Collect every `Event` member of the target and point each event's `loop` at that target."
        # Target precedence: an explicit instance, else a fresh instance when requested, else the class.
        loop=instance if instance is not None else (cls() if instantiate else cls)
        events=L(inspect.getmembers(loop)).map(_last_element).filter(_is_event)
        for o in events: o.loop=loop
        return events

    @event
    def before_test_c(self):print('before_test_c')
    @event
    def on_test_c(self):print('on_test_c')
    @event
    def after_test_c(self):print('after_test_c')
    @event
    def on_test_b(self):print('on_test_b')
    @event
    def on_test_a(self):print('on_test_a')

In [11]:
A().before_test_c()

before_test_c


Notes:
- Having a Section object might make the overall loop management better.
    - Remaining work: the Section object also needs to handle lists and execute them.
    - With this in mind, I think using Sections will make the loop run much more cleanly.

In [12]:
# export
def _grab_postfix(e:Union[Event,list],previous_event:dict=None): 
    """Return the grouping key for `e`: its own postfix, or the last recorded one for non-events.

    When `previous_event` is given, the postfix of each `Event` seen is stored
    under `'last_event'` so a following non-event item can reuse it. Raises
    `SectionException` for a non-event when nothing has been recorded.
    """
    if isinstance(e,Event):
        if previous_event is not None:
            previous_event['last_event']=e.postfix
        return e.postfix
    if not previous_event:
        raise SectionException(f'{e} doesnt have an event tied to it.')
    return previous_event['last_event']

def _grab_prefix(o:Event):
    "Return the `prefix` attribute of `o`."
    return getattr(o,'prefix')
def _default_raise(): raise 
def _event2dict(e:Union[Event,list,L],previous_event:dict):
    """Turn `e` into a `(key, value)` pair for a section's event dict.

    An `Event` is keyed by its prefix, which is also remembered in
    `previous_event['last_event']`. Anything else is keyed by the remembered
    prefix plus `'inner'`; `SectionException` is raised if nothing was remembered.
    """
    if isinstance(e,Event):
        previous_event['last_event']=e.prefix
        return (e.prefix,e)
    if not previous_event:
        raise SectionException(f'{e} doesnt have an event tied to it.')
    return (previous_event['last_event']+'inner',e)

def _noop_event_pair(k):
    "Build the default `(k, handler)` and `(k+'inner', [])` entries for prefix `k`."
    # The default `failed_` handler re-raises; every other prefix defaults to `noop`.
    handler=_default_raise if k=='failed_' else noop
    return ((k,handler),(k+'inner',[]))


class SectionException(Exception):
    "Raised when a non-event item has no preceding event to be attached to."

class Section(object):
    """A group of events sharing one postfix, run in before/on/after/failed/finally order.

    `events` is a sequence of `Event` objects, each optionally followed by a
    list of nested loops to run right after it. `self.events` is a dict keyed
    by prefix (`'on_'`) and prefix+`'inner'` (`'on_inner'`), pre-filled with
    defaults from `_noop_event_pair` for every prefix in `PREFIXES`.
    """

    def __init__(self,events,parent_event=None):
        store_attr(but='events')
        default_events=L(PREFIXES).map(_noop_event_pair)
        default_events=L(chain.from_iterable(default_events))
        previous_event={}
        # Empty lists of nested loops are dropped; events and non-empty lists are kept in order.
        self.events_ls=[o for o in events if not isinstance(o,(list,L)) or len(o)!=0]
        if parent_event is not None: 
            for event in self.events_ls: event.parent_event=parent_event
        # Entries built from the given items override the defaults.
        self.events=merge(
            dict(default_events),
            dict(L(self.events_ls).map(_event2dict,previous_event=previous_event))
        )
        
    @property
    def root_loop(self): return self.events_ls[0].root_loop
    @property
    def loop(self): return self.events_ls[0].loop
    def __repr__(self): return str(self.__class__.__name__)
    # Only real `Event` values are counted, not defaults or nested-loop lists.
    def __len__(self): return len(L(self.events.values()).filter(_is_event)) 

    @delegates(Event.show)
    def show(self,n_tab=0,n_events=False,include_events=False,include_defaults=False,**kwargs):
        """Return an indented text description of the section.

        `n_events` appends the event count, `include_events` lists the events
        and nested loops, and `include_defaults` also lists the default
        placeholder handlers. Remaining keyword arguments are passed on to the
        children's `show`.
        """
        kwargs=merge(dict(n_events=n_events,
                          include_events=include_events,
                          include_defaults=include_defaults),kwargs)
        tab='\t' if self.root_loop is None else self.root_loop.tab
        section=n_tab*tab+str(self.__class__.__name__)
        if n_events:section+=f' {len(self)} events'
        if include_events:
            # Every child line starts with this prefix; children then add their own
            # `(n_tab+1)*tab`, so events, nested loops and placeholders share one column.
            prefix='\n'+(n_tab*tab)
            for event in self.events.values(): 
                if isinstance(event,(list,L)):
                    for o in event:
                        section+=prefix+o.show(n_tab=n_tab+1,**kwargs)
                elif event in [_default_raise,noop]:
                    if include_defaults:
                        section+=prefix+((n_tab+1)*tab)+str(event)
                else:
                    section+=prefix+event.show(n_tab=n_tab+1,**kwargs)
        return section

    def run(self):
        "Run the section's events and nested loops; on error run the `failed_` entries and re-raise."
        try:
            self.events['before_']()
            for o in self.events['before_inner']: o.run()                       # fastrl.skip_traceback
            self.events['on_']()
            for o in self.events['on_inner']: o.run()
            self.events['after_']()
            for o in self.events['after_inner']: o.run()                        # fastrl.skip_traceback
        except Exception:
            try:     
                self.events['failed_']()                                        # fastrl.skip_traceback
                raise
            finally: 
                for o in self.events['failed_inner']: o.run()                   # fastrl.skip_traceback
        finally:
            self.events['finally_']()
            for o in self.events['finally_inner']: o.run()                      # fastrl.skip_traceback
    
    @classmethod
    def from_events(cls,events:List[Event],loop=None,parent_event=None):
        "Group `events` by postfix and build one section per group. `loop` is accepted but not used here."
        previous_event={}
        event_groups=groupby(events,partial(_grab_postfix,previous_event=previous_event))
        return [cls(o,parent_event=parent_event) for o in event_groups.values()]

In [13]:
Section.from_events(A.events())[0].events

{'before_': __main__.A.before_test_c,
 'before_inner': [],
 'on_': __main__.A.on_test_c,
 'on_inner': [],
 'after_': __main__.A.after_test_c,
 'after_inner': [],
 'failed_': <function __main__._default_raise()>,
 'failed_inner': [],
 'finally_': <function fastcore.imports.noop(x=None, *args, **kwargs)>,
 'finally_inner': []}

In [14]:
A.events(A())[0]()

after_test_c


In [15]:
[o.show( 
    n_events=      True,
    include_events=True,
    n_cbs=         True,
    include_cbs=   True) for o in Section.from_events(A.events(A()))]

['Section 3 events\n\t__main__.A.before_test_c\n\t__main__.A.on_test_c\n\t__main__.A.after_test_c',
 'Section 1 events\n\t__main__.A.on_test_a',
 'Section 1 events\n\t__main__.A.on_test_b']

In [16]:
[print(o.show( 
    n_events=      True,
    include_events=True,
    n_cbs=         True,
    include_cbs=   True)) for o in Section.from_events(A.events(A()))]

Section 3 events
	__main__.A.before_test_c
	__main__.A.on_test_c
	__main__.A.after_test_c
Section 1 events
	__main__.A.on_test_a
Section 1 events
	__main__.A.on_test_b


[None, None, None]

In [17]:
[o.run() for o in Section.from_events(A.events(A()))]

before_test_c
on_test_c
after_test_c
on_test_a
on_test_b


[None, None, None]

In [18]:
# export
def _is_event(o):
    "Return `True` when `o` is an `Event` instance (subclasses included)."
    return isinstance(o,Event)
def _last_element(ls): return ls[-1]
def _loop2sections(loop,loops,cbs): return loop.sections(loops=loops,cbs=cbs)
def _loop_with_sections(loop,loops,cbs,parent,parent_event): 
    loop.loop=parent
    return loop.from_sections(loops=loops,cbs=cbs,parent_event=parent_event)

class class_or_instancemethod(classmethod):
    "From: https://stackoverflow.com/questions/28237955/same-name-for-classmethod-and-instancemethod"
    def __get__(self, instance, type_):
        # Accessed on the class: behave like a regular classmethod.
        if instance is None:
            return super().__get__(instance, type_)
        # Accessed on an instance: bind the underlying function to that instance.
        return self.__func__.__get__(instance, type_)

class Loop(object):  
    """Base class for loops whose `@event` members are grouped into `Section`s and run in order.

    Subclasses declare events as members; `events` collects and sorts them,
    `get_sections` groups them into sections, and `run` executes the sections.
    Nested loops are attached to a parent's events through their `call_on` list.
    """
    # call_on: parent events this loop attaches to; loop: parent loop (None when top-level);
    # loops/cbs: nested loops and callbacks remembered by `get_sections`; tab: indent unit for `show`;
    # verbose: copied onto exceptions raised in `run` as `_show_loop_errors`.
    call_on,loop,loops,cbs,tab,verbose=L(),None,None,None,'  ',False
    
    @class_or_instancemethod
    def events(cls_or_self,loops=None,cbs=None):
        "Return an iterator alternating each event with the list of nested loops attached to it."
        events=L(inspect.getmembers(cls_or_self)).map(_last_element)\
                                                 .filter(_is_event)\
                                                 .sorted()
        for o in events: 
            o.loop=cls_or_self
            for cb in L(cbs): cb.loop=cls_or_self
            o.set_cbs(cbs)
        
        # The list is built eagerly, so nested loops are constructed during this call.
        events=chain.from_iterable([
            (o,L(ifnone(loops,L()).filter(isrelevent,event=o)\
                                .map(_loop_with_sections,
                                     loops=loops,cbs=cbs,
                                     parent=cls_or_self,
                                     parent_event=o))) 
            for o in events
        ])
        return events

    @class_or_instancemethod
    def get_sections(cls_or_self,loops=None,cbs=None,parent_event=None):
        "Build this loop's sections; `loops`/`cbs` already stored on the target win over the arguments."
        cls_or_self.loops=ifnone(cls_or_self.loops,loops)
        cls_or_self.cbs=ifnone(cls_or_self.cbs,cbs)
        events=cls_or_self.events(loops=L(cls_or_self.loops),cbs=L(cls_or_self.cbs))
        sections=Section.from_events(events,cls_or_self,parent_event)
        return sections
    
    @classmethod
    def from_sections(cls,**kwargs):
        "Create a new loop instance and store the result of `get_sections(**kwargs)` on it as `sections`."
        loop=cls()
        loop.sections=loop.get_sections(**kwargs)
        return loop
    
    def run(self,loops=None,cbs=None):
        "Build the sections and run them; exceptions are tagged with `_show_loop_errors` and re-raised."
        try:
            sections=self.get_sections(loops=loops,cbs=cbs)                     # fastrl.skip_traceback
            for section in sections:                                            # fastrl.skip_traceback
                section.run()                                                   # fastrl.skip_traceback
        except Exception as e:
            e._show_loop_errors=self.verbose
            raise
            
    @property
    def root_loop(self):
        "The top-most loop reached by following `loop` links; `self` when there is no parent."
        parent=self.loop
        if parent is None: return self
        # A parent given as a class exposes `root_loop` as a property object rather than a
        # loop, so only instances are followed further up the chain.
        return parent.root_loop if isinstance(parent,Loop) else parent
    # `sections` only exists on loops created through `from_sections`.
    def __len__(self):  return len(self.sections)
    def __repr__(self): return str(self.__class__.__name__)

    @delegates(Section.show)
    def show(self,n_tab=0,n_sections=False,include_sections=False,**kwargs):
        "Return an indented text description of the loop, optionally listing its sections."
        tab=self.root_loop.tab
        kwargs=merge(dict(n_sections=n_sections,include_sections=include_sections),kwargs)
        loop=n_tab*tab+str(self.__class__.__name__)
        if n_sections:loop+=f' {len(self)} sections'
        if include_sections:
            for section in self.sections: 
                loop+='\n'+(n_tab*tab)
                loop+=section.show(n_tab=n_tab+1,**kwargs)
        return loop

In [19]:
class Outer(Loop):
    "Example top-level loop with two event groups, `step` and `jump`, each using all five prefixes."
    @event
    def before_step(self): print('before_step')
    @event
    def on_step(self): print('on_step')
    @event
    def after_step(self): print('after_step')
    @event
    def failed_step(self): print('failed_step')
    @event
    def finally_step(self): print('finally_step')

    @event
    def before_jump(self): print('before_jump')
    @event
    def on_jump(self): print('on_jump')
    @event
    def after_jump(self): print('after_jump')
    @event
    def failed_jump(self): print('failed_jump')
    @event
    def finally_jump(self): print('finally_jump')

class Inner(Loop):
    "Example nested loop attached to `Outer.on_step`, `Outer.after_step` and `Outer.finally_jump`."
    call_on=L(Outer.on_step,Outer.after_step,Outer.finally_jump)

    @event
    def before_iteration(self): print('before_iteration')
    @event
    def on_iteration(self): print('on_iteration')
    @event
    def after_iteration(self): print('after_iteration')
    @event
    def failed_iteration(self): print('failed_iteration')
    @event
    def finally_iteration(self): print('finally_iteration')
    
class FailingInner(Loop):
    "Example nested loop, attached to `Inner.finally_iteration`, whose single event always raises."
    call_on=L(Inner.finally_iteration)

    @event
    def on_force_fail(self):
        # Print first so the output shows the event was reached before the failure.
        print('on_force_fail')
        raise Exception
        

In [20]:
Outer().get_sections(L(Inner(),FailingInner()))

[Section, Section]

In [21]:
# export
class CallbackException(Exception):
    "Exception type reserved for callback-related errors."

class Callback(object):
    "Base class for callbacks; methods named after an event are called when that event runs."
    # call_on: events the callback is restricted to (falsy means unrestricted);
    # loop: the loop the callback has been bound to, if any.
    call_on,loop=None,None

    def show(self,n_tab=0,**kwargs):
        "Return `str(self)` indented by `n_tab` units of the root loop's `tab` (a tab character if unbound)."
        if self.loop is None:
            tab='\t'
        else:
            tab=self.loop.root_loop.tab
        return (tab*n_tab)+str(self)

In [37]:
class OuterCallback(Callback):
    "Example callback restricted to `Outer.on_step` that hooks events named `before_iteration`."
    call_on=L(Outer.on_step)

    def before_iteration(self)->dict(this=list,that=str):
        message='   OuterCallback called lol'
        print(message)

Outer().get_sections(L(Inner(),FailingInner()),OuterCallback)

[Section, Section]

In [38]:
print(Outer.from_sections(
    loops=L(Inner(),FailingInner()),
    cbs=OuterCallback
                   
                   
).show(
    n_sections=      True,
    include_sections=True,
    n_events=        True,
    include_events=  True,
    n_cbs=           True,
    include_cbs=     True,
    include_defaults=True
))

Outer 2 sections
  Section 5 events
      __main__.Outer.before_step
      __main__.Outer.on_step
      Inner 1 sections
          Section 5 events
              __main__.Inner.before_iteration
              __main__.Inner.on_iteration
              __main__.Inner.after_iteration
              __main__.Inner.failed_iteration
              __main__.Inner.finally_iteration
              FailingInner 1 sections
                  Section 1 events
            <function noop at 0x7f8cc8309ca0>
                      __main__.FailingInner.on_force_fail
            <function noop at 0x7f8cc8309ca0>
            <function _default_raise at 0x7f8cc83091f0>
            <function noop at 0x7f8cc8309ca0>
      __main__.Outer.after_step
      Inner 1 sections
          Section 5 events
              __main__.Inner.before_iteration
              __main__.Inner.on_iteration
              __main__.Inner.after_iteration
              __main__.Inner.failed_iteration
              __main__.Inner.finally_ite

`Reference: https://stackoverflow.com/questions/31949760/how-to-limit-python-traceback-to-specific-files`

`Reference: https://github.com/ipython/ipython/blob/8520f3063ca36655b5febbbd18bf55e59cb2cbb5/IPython/core/interactiveshell.py#L1945`

In [24]:
# export
def _skip_traceback(s):
    "Return `True` when the traceback entry `s` contains the skip marker comment."
    marker='# fastrl.skip_traceback'
    return marker in s
    
def ipy_handle_exception(self, etype, value, tb, tb_offset):
    """Custom IPython exception handler that can hide loop-internal traceback entries.

    The structured traceback is built as usual. When the exception carries
    `_show_loop_errors=False`, every entry matched by `_skip_traceback` is
    removed and each run of consecutive removed entries is replaced by a single
    notice. Exceptions without that attribute are shown unfiltered.

    Parameters follow the handler signature used with `set_custom_exc`:
    `self` is the interactive shell, `etype`/`value`/`tb` describe the
    exception, and `tb_offset` is forwarded to `structured_traceback`.
    """
    stb = self.InteractiveTB.structured_traceback(etype,value,tb,tb_offset=tb_offset)
    if not getattr(value,'_show_loop_errors',True):
        msg='Skipped Loop Code due to # fastrl.skip_traceback found in source code,'
        msg+=' please use Loop(...verbose=True) to view loop tracebacks\n'
        filtered,prev_skipped=[],False
        for s in stb:
            skipped=_skip_traceback(s)
            # Emit the notice only at the start of a run of skipped entries.
            if skipped and not prev_skipped: filtered.append(msg)
            if not skipped: filtered.append(s)
            prev_skipped=skipped
        stb=filtered
    # Pass the exception type that was received, not the builtin `type`.
    self._showtraceback(etype, value, stb)

# Register the handler for every `Exception` subclass, but only when the IN_IPYTHON flag is set.
if IN_IPYTHON:
    get_ipython().set_custom_exc((Exception,),ipy_handle_exception)
    

In [32]:
# Disabled demo (`if False`): run the nested example loops with `Outer.verbose=False`.
if False:
    Outer.verbose=False
    Outer().run(L(Inner(),FailingInner()),OuterCallback())

before_step
on_step
before_iteration
on_iteration
after_iteration
finally_iteration
on_force_fail
failed_step
finally_step


type: 

In [34]:
# Disabled demo (`if False`): the same run with `Outer.verbose=True`.
if False:
    Outer.verbose=True
    Outer().run(L(Inner(),FailingInner()),OuterCallback())

before_step
on_step
before_iteration
on_iteration
after_iteration
finally_iteration
on_force_fail
failed_step
finally_step


type: 

In [None]:
# hide
from fastcore.imports import in_colab

# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
# Outside Colab, run the nbdev export steps: `make_readme`, `notebook2script`, `notebook2html`.
if not in_colab():
    from nbdev.export import *
    from nbdev.export2html import *
    from nbverbose.cli import *
    make_readme()
    notebook2script()
    notebook2html()
    # Alternative calls restricted by a notebook glob pattern, kept for reference:
    # notebook2script('[!02_fastai.loop.old]*.ipynb')
    # notebook2html('[!02_fastai.loop.old]*.ipynb')

In [None]:
# # export
# # Reference: https://stackoverflow.com/questions/31949760/how-to-limit-python-traceback-to-specific-files
# # __mycode = True

# # def is_mycode(tb):
# #     globals = tb.tb_frame.f_globals
# #     return globals.has_key('__mycode')

# # def mycode_traceback_levels(tb):
# #     length = 0
# #     while tb and is_mycode(tb):
# #         tb = tb.tb_next
# #         length += 1
# #     return length
# import traceback
# from traceback import TracebackException
# __mycode = True

# def callers_module():
#     module_name = inspect.currentframe().f_back.f_globals["__name__"]
#     return sys.modules[module_name]

# def is_mycode(tb):
#     return globals().get('__mycode',False)

# def mycode_traceback_levels(tb):
#     length = 0
#     while tb and is_mycode(tb):
#         tb = tb.tb_next
#         length += 1
#     return length

# def handle_exception(type, value, tb, tb_offset):
#     # 1. skip custom assert code, e.g.
#     # while tb and is_custom_assert_code(tb):
#     #   tb = tb.tb_next
#     # 2. only display your code
#     length = mycode_traceback_levels(tb)
    
#     # Reference: https://github.com/ipython/ipython/blob/8520f3063ca36655b5febbbd18bf55e59cb2cbb5/IPython/core/interactiveshell.py#L1945

#     # return [str(line) for line in TracebackException(
#     #         type, value, tb, limit=length).format(chain=True)]
#     return TracebackException(
#             type, value, tb, limit=length).format(chain=True)


# def ipy_handle_exception(self, etype, value, tb, tb_offset):
#     # 1. skip custom assert code, e.g.
#     # while tb and is_custom_assert_code(tb):
#     #   tb = tb.tb_next
#     # 2. only display your code
#     # length = mycode_traceback_levels(tb)
    
#     # Reference: https://github.com/ipython/ipython/blob/8520f3063ca36655b5febbbd18bf55e59cb2cbb5/IPython/core/interactiveshell.py#L1945

#     # return [str(line) for line in TracebackException(
#     #         type, value, tb, limit=length).format(chain=True)]
#     print('Returning traceback')
# #     ss= [str(line) for line in TracebackException(
# #             type, value, tb, limit=None).format(chain=True)]
    
#     stb = self.InteractiveTB.structured_traceback(etype,
#                                             value, tb, tb_offset=tb_offset)
    
#     # print(stb)
#     self._showtraceback(type, value, stb)
    
#     # print(ss)
#     # return ss
# get_ipython().set_custom_exc((Exception,),ipy_handle_exception)
    
# # sys.excepthook = handle_exception

# def custom_traceback():
#     exception_info=sys.exc_info()
#     # print(exception_info)
#     ex=handle_exception(*exception_info)
#     # get_ipython()
    