In [4]:
#hide
#skip
# Tab-completion settings for interactive use: disable jedi, enable greedy completion.
%config Completer.use_jedi = False
%config IPCompleter.greedy=True
# upgrade fastrl on colab
# The `[ -e /content ]` guard only runs the installs where a /content directory exists (presumably Colab).
# The trailing `> /dev/null 2>&1` silences only the apt-get command; xvfb/python-opengl are presumably
# what the pyvirtualdisplay `Display` started in a later cell needs.
! [ -e /content ] && pip install -Uqq fastrl['dev'] pyvirtualdisplay && \
                     apt-get install -y xvfb python-opengl > /dev/null 2>&1 
# NOTE: IF YOU SEE VERSION ERRORS, IT IS SAFE TO IGNORE THEM. COLAB IS BEHIND IN SOME OF THE PACKAGE VERSIONS

In [5]:
# default_exp fastai.loop

In [6]:
# export
# Python native modules
import os,sys
from copy import deepcopy,copy
from typing import *
import types
import logging
import inspect
from itertools import chain
from functools import partial
# Third party libs
from fastcore.all import *
import numpy as np
# Local modules

# NOTE(review): this rebinds any `IN_IPYTHON` brought in by the star imports above; confirm intent.
IN_IPYTHON = False

# Module logger, named after this module.
_logger = logging.getLogger(name=__name__)

In [7]:
# hide
from fastcore.imports import in_colab
# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    from nbverbose.showdoc import *
    from nbdev.imports import *
    # Sanity-check the notebook environment unless the IN_TEST environment variable is set.
    if not os.environ.get("IN_TEST", None):
        assert IN_NOTEBOOK
        assert not IN_COLAB
        assert IN_IPYTHON
else:
    # Virtual display is needed for colab.
    # Bound to `virtual_display` (not `display`) so IPython's built-in `display()` function
    # is not shadowed in the notebook namespace.
    from pyvirtualdisplay import Display
    virtual_display=Display(visible=0,size=(400,300))
    virtual_display.start()

# Loop
> The fastrl concept of generic loop objects.

The goal of Loops is to make them easy to customize, and to make it easy to know how each section of code
connects to other parts.

### Why do we need this?
We have identified at least 4 different kinds of loops already:

- Learner (Training)
- Source/Gym (Data Access)
- Agent (How an AI takes in data and generates actions)
- DataLoader (Transforms)

### What is a loop?

- It should be capable of containing inner loops/sections/events/callbacks.
- It should self-describe its structure.
- It should be easy to know which parts of the loop take a long or short amount of time.
- It should be stateless.
- Alternatively (if it cannot be stateless), it should make it easy to show which fields are being changed at which points in time.
    
There is a possibility that portions of the loop … (**TODO**: this thought is unfinished — complete or remove it.)

In [69]:
# export
# Next order number to hand out per outer definition (`module.OuterName`); updated by `Event.__init__`.
EVENT_ORDER_MAPPING=dict()
# Recognised event-name prefixes, listed in the order `Section.run` invokes them
# (`failed_` only when an exception occurs).
PREFIXES=[
    'before_',
    'on_',
    'after_',
    'failed_',
    'finally_',
]

def isnone(o):
    "Return True only when `o` is the `None` singleton."
    return None is o
def _is_event(o):
    "True if `o`'s class is `Event` or a subclass of it."
    klass=o.__class__
    return issubclass(klass,Event)
def _last_element(ls):
    "Return the final item of the indexable sequence `ls`."
    last=ls[-1]
    return last
def _grab_full_name(o:'Event'):
    "Return the dotted `full_name` of an event."
    full=o.full_name
    return full
def iseventable(o:Any):
    "Return True if `o` is a public callable whose name starts with one of `PREFIXES`."
    if not isinstance(o,Callable): return False
    # Not every callable has a `__name__` (e.g. `functools.partial`, callable instances);
    # treat those as non-events instead of raising AttributeError.
    name=getattr(o,'__name__',None)
    if not isinstance(name,str): return False
    if name.startswith('_'): return False
    return name.startswith(tuple(PREFIXES))

class EventException(Exception):
    "Raised when a function cannot be wrapped as an `Event` because its name lacks a recognised prefix."

def event_parent_iter(event:'Event'):
    "Yield `event`, then each successive `parent_event`, stopping at the first `None` parent."
    current=event
    while True:
        yield current
        current=current.parent_event
        if current is None: return

class Event(object):
    "A named loop step (one of `PREFIXES` + a postfix) wrapping the function that implements it."
    def __init__(self,
                 function:Callable, # The function/method from a loop to execute.
                 loop:'Loop'=None, # The immediate loop that this event is a part of.
                 parent_event:'Event'=None # If this event is nested, then this will be a reference to outer events.
                ):
        # Validate before touching EVENT_ORDER_MAPPING so a rejected function does not
        # consume an order slot for its outer definition.
        if not iseventable(function):
            name=getattr(function,'__name__',repr(function))
            raise EventException(f'{name} needs to start with any {PREFIXES}')
        store_attr()
        # We set the order over the entire Loop definition (first component of `__qualname__`).
        # The counter is global and never reset, so building events for the same definition
        # again continues numbering from where it left off.
        self.outer_name=function.__module__+'.'+function.__qualname__.split('.')[0]
        self.order=EVENT_ORDER_MAPPING.get(self.outer_name,1)
        EVENT_ORDER_MAPPING[self.outer_name]=self.order+1

        self.full_name=function.__module__+'.'+function.__qualname__
        self.cbs=L()  # presumably callbacks attached to this event — TODO confirm

    def __lt__(self,o:'Event'):
        "Events sort by the order in which they were created for their definition."
        return self.order<o.order
    def __repr__(self): return self.full_name
    @property
    def name(self):
        "The wrapped function's `__name__`, e.g. `'before_test_c'`."
        return self.function.__name__
    @property
    def prefix(self):
        "Name text up to and including the first underscore, e.g. `'before_'`."
        return self.name.split('_')[0]+'_'
    @property
    def postfix(self):
        "Name text after the first underscore, e.g. `'test_c'`."
        return '_'.join(self.name.split('_')[1:])
    @property
    def root_loop(self):
        "The `root_loop` of the loop this event belongs to."
        return self.loop.root_loop

In [71]:

class A(object):
    "Example loop used to demonstrate event discovery."
    loop,root_loop=None,None
    
    @classmethod
    def events(cls,instance=None,instantiate=False):
        """Collect the eventable methods defined directly on the class as `Event`s.

        If `instance` is given (or `instantiate=True`), the events wrap that object's bound
        methods and have `loop` set to it; otherwise they wrap the plain functions on `cls`.
        """
        loop=instance if instance is not None else (cls() if instantiate else cls)
        # Methods live in the class `__dict__`, not the instance one, so always scan the class
        # namespace, then look the eventable names up on `loop` (binding them for instances).
        owner=loop if isinstance(loop,type) else type(loop)
        events=L(list(owner.__dict__.items()))\
            .filter(lambda kv: not isnone(kv[1]) and iseventable(kv[1]))\
            .map(lambda kv: getattr(loop,kv[0]))\
            .map(Event)
        for o in events: o.loop=loop
        return events
    def not_event(self):     print('This shouldnt be included in the events')
    def before_test_c(self): print('before_test_c')
    def on_test_c(self):     print('on_test_c')
    def after_test_c(self):  print('after_test_c')
    def on_test_a(self):     print('on_test_a')
    def on_test_b(self):     print('on_test_b')

In [72]:
# Collect A's prefixed methods as Events; `not_event` and the `events` classmethod are not included.
A.events()

(#5) [__main__.A.before_test_c,__main__.A.on_test_c,__main__.A.after_test_c,__main__.A.on_test_a,__main__.A.on_test_b]

In [73]:
# export
def _grab_postfix(e:Union[Event,list],previous_event:dict=None):
    """Grouping key for `Section.from_events`: an event's postfix.

    A non-event (nested list) takes the postfix of the most recent event recorded in
    `previous_event`; with no such event, SectionException is raised.
    """
    if isinstance(e,Event):
        if previous_event is not None: previous_event['last_event']=e.postfix
        return e.postfix
    if previous_event: return previous_event['last_event']
    raise SectionException(f'{e} doesnt have an event tied to it.')

def _grab_prefix(o:Event):
    "Return the event's prefix (e.g. `'before_'`)."
    prefix=o.prefix
    return prefix
def _default_raise():
    "Fallback `failed_` handler: re-raise the exception currently being handled."
    raise
def _event2dict(e:Union[Event,list,L],previous_event:dict):
    """Map `e` to a `(key, value)` pair for `Section.events`.

    Events are keyed by their prefix; a nested list is keyed `<last prefix>inner`, using the
    prefix recorded in `previous_event`. With no recorded prefix, SectionException is raised.
    """
    if isinstance(e,Event):
        # Remember this prefix so a following nested list attaches to it.
        previous_event['last_event']=e.prefix
        return (e.prefix,e)
    if previous_event:
        return (previous_event['last_event']+'inner',e)
    raise SectionException(f'{e} doesnt have an event tied to it.')

def _noop_event_pair(k):
    "Default `(k, handler)` and `(k+'inner', [])` pairs for prefix `k`; `failed_` defaults to re-raising."
    handler=_default_raise if k=='failed_' else noop
    inner_key=k+'inner'
    return ((k,handler),(inner_key,[]))

class SectionException(Exception):
    "Raised when a nested list is given to a `Section` with no preceding event to attach it to."

class Section(object):
    "Events sharing one postfix (e.g. `before_x`/`on_x`/`after_x`) plus their nested sections, run as a unit."
    
    def __init__(self,events,parent_event=None):
        store_attr(but='events')
        # Start from a fallback for every prefix (`noop`, or `_default_raise` for `failed_`) and an
        # empty `<prefix>inner` slot, then overlay the events actually supplied.
        default_events=L(PREFIXES).map(_noop_event_pair)
        default_events=L(chain.from_iterable(default_events))
        previous_event={}
        # Empty nested lists have nothing to run, so drop them.
        self.events_ls=[o for o in events if not isinstance(o,(list,L)) or len(o)!=0]
        if parent_event is not None:
            # Only Events carry `parent_event`; nested lists (e.g. a plain `list`) cannot take the attribute.
            # NOTE(review): sections inside nested lists are not updated here — confirm that is intended.
            for event in self.events_ls:
                if _is_event(event): event.parent_event=parent_event
        self.events=merge(
            dict(default_events),
            dict(L(self.events_ls).map(_event2dict,previous_event=previous_event))
        )
        
    @property
    def root_loop(self):
        "The root loop of the first event in this section."
        return self.events_ls[0].root_loop
    @property
    def loop(self):
        "The loop of the first event in this section."
        return self.events_ls[0].loop
    def __repr__(self):  return str(self.__class__.__name__)
    def __len__(self):
        "Number of real `Event`s in this section (default handlers and nested lists are not counted)."
        return len(L(self.events.values()).filter(_is_event))
    def run_inner(self,event):
        "Run each nested section stored under `self.events[event]` (e.g. `'on_inner'`), in order."
        for o in self.events[event]: o.run()                                    # fastrl.skip_traceback
        
    def run(self):
        "Run before_/on_/after_ (each followed by its nested sections); failed_ on error; finally_ always."
        try:
            self.events['before_']()
            self.run_inner('before_inner')                                      # fastrl.skip_traceback
            self.events['on_']()
            self.run_inner('on_inner')
            self.events['after_']()
            self.run_inner('after_inner')                                       # fastrl.skip_traceback
        except Exception:
            try:
                # The default `failed_` handler re-raises; if a custom one returns normally, the bare
                # `raise` re-raises the handled exception, so the failure still propagates.
                self.events['failed_']()                                        # fastrl.skip_traceback
                raise
            finally:
                self.run_inner('failed_inner')                                  # fastrl.skip_traceback
        finally:
            self.events['finally_']()
            self.run_inner('finally_inner')                                     # fastrl.skip_traceback
    
    @classmethod
    def from_events(cls,events:List[Event],loop=None,parent_event=None):
        """Group `events` by postfix and build one `Section` per group.

        A nested list in `events` joins the group of the event just before it.
        `loop` is currently unused.
        """
        previous_event={}
        event_groups=groupby(events,partial(_grab_postfix,previous_event=previous_event))
        return [cls(o,parent_event=parent_event) for o in event_groups.values()]

In [None]:
# hide
# Re-imported here, presumably so this cell can be run on its own.
from fastcore.imports import in_colab

# Since colab still requires tornado<6, we don't want to import nbdev if we don't have to
if not in_colab():
    # Star-import order matters: later modules can shadow names from earlier ones.
    from nbdev.export import *
    from nbdev.export2html import *
    from nbverbose.cli import *
    # nbdev build steps (by name: README, python modules from `# export` cells, HTML docs).
    make_readme()
    notebook2script()
    notebook2html()