In [None]:
#| default_exp callback

In [None]:
#| export
from __future__ import annotations

In [None]:
#| hide
# %reload_ext autoreload
# %autoreload 0

# Callback

> Base classes and helpers for objects that can be augmented with callbacks
> 

# Prologue

In [None]:
#| export

import time
from contextlib import contextmanager
from functools import partial
from operator import attrgetter
from typing import Any
from typing import Iterable
from typing import Sequence
from typing import Type
from typing import TypeAlias

import fastcore.all as FC


In [None]:
#| export
from olio.basic import _EMPTY
from olio.basic import AD


In [None]:
from itertools import repeat

from fastcore.test import *
from olio.project import setup_console
from olio.test import *


----

In [None]:
# Notebook-only console helpers; 108 is presumably the console width — TODO confirm against `olio.project.setup_console`.
console, cprint = setup_console(108)

----

# Callback


In [None]:
#| export

class Callback:
    "Base class for callbacks; `run_cbs` invokes them sorted by `order`."
    cbs: list[Callback]  # optional: nested callbacks that `run_cbs` runs right after this one
    order = 0  # sort key used by `run_cbs`; lower values run first


In [None]:
#| export

def run_cbs(cbs: Sequence[Callback] | FC.L, method_nm:str, caller=None, *args, **kwargs):
    """Call `method_nm(caller, *args, **kwargs)` on each callback in `cbs`, lowest `order` first.

    Callbacks lacking `method_nm` are skipped. A callback with a non-empty `cbs` attribute has
    those nested callbacks run (recursively, with the same arguments) right after itself."""
    by_order = sorted(cbs, key=lambda c: c.order)
    for cb in by_order:
        handler = getattr(cb, method_nm, None)
        if handler:
            handler(caller, *args, **kwargs)
        children = getattr(cb, 'cbs', None)
        if children:
            run_cbs(children, method_nm, caller, *args, **kwargs)


In [None]:
class VerboseCB(Callback):
    "Count batches between `before_fit` and `after_fit`, then print the total."
    def before_fit(self, caller):
        self.count = 0

    def after_batch(self, caller):
        self.count = self.count + 1

    def after_fit(self, caller):
        print(f'{caller} Completed {self.count} batches')
    

In [None]:
cbs = [verbose := VerboseCB()]
run_cbs(cbs, 'before_fit'); test_eq(verbose.count, 0)
run_cbs(cbs, 'after_batch'); test_eq(verbose.count, 1)
# No caller was passed, so `caller` is None and the message starts with 'None'.
test_stdout(lambda: run_cbs(cbs, 'after_fit'), 'None Completed 1 batches')


# PassCB

> A callback that does nothing

In [None]:
#| export

class PassCB(Callback):
    "Callback that answers every callback-method lookup with a no-op."
    def noop(self, *args, **kwargs): pass

    def __getattr__(self, name):
        # `__getattr__` only runs when normal lookup fails. Refuse dunder probes (e.g. `copy.deepcopy`
        # looks up `__deepcopy__` and would call the no-op, returning None) and `cbs`: `run_cbs` does
        # `getattr(cb, 'cbs', None)` and would otherwise try to iterate a bound method.
        if name == 'cbs' or (name.startswith('__') and name.endswith('__')): raise AttributeError(name)
        return self.noop


## HasCallbacks

In [None]:
#| export

class HasCallbacks:
    "Base for classes that can be augmented with callbacks."
    cbs: list[Callback]
    # Names callable as methods: `self.<name>(*args)` runs `self.callback('<name>', *args)` (see `__getattr__`).
    cbs_names: tuple[str,...] = ()

    def __new__(cls, *args, **kwargs):
        # Create `cbs` here so it exists even when a subclass `__init__` doesn't call `super().__init__`.
        self = super().__new__(cls)
        self.cbs = []
        return self

    def __init__(self, cbs:Sequence[Callback]=()): self.cbs = list(cbs)

    def with_cbs(self, cbs:Sequence[Callback], extend=False):
        "Replace the callbacks with `cbs` (or append them if `extend`); returns `self` for chaining."
        if extend: self.cbs.extend(cbs)
        else: self.cbs = list(cbs)
        return self

    @contextmanager
    def this_cbs(self, cbs:Sequence[Callback]):
        "Use temporary `cbs` in `with` block."
        added = list(cbs)  # materialize once: a one-shot iterable could not be walked again for removal
        self.cbs.extend(added)
        try:
            yield
        finally:
            # Remove exactly the instances appended above, matching by identity from the end. A callback
            # that was already registered therefore keeps its original position. If the body replaced
            # `self.cbs`, missing entries are ignored instead of raising ValueError and masking the body's exception.
            for cb in reversed(added):
                for i in range(len(self.cbs) - 1, -1, -1):
                    if self.cbs[i] is cb:
                        del self.cbs[i]
                        break

    def __getattr__(self, name):
        "Run `name` as a callback if it's in `self.cbs_names`."
        # Only reached when normal attribute lookup fails.
        if name in self.cbs_names: return partial(self.callback, name)
        raise AttributeError(name)

    def callback(self, method_nm, *args, **kwargs):
        "Run `method_nm` on every callback in `self.cbs` via `run_cbs`, passing `self` as the caller."
        run_cbs(self.cbs, method_nm, self, *args, **kwargs)


In [None]:
# `think`/`act` fire callbacks explicitly via `self.callback(...)`. Because 'on_think' is listed in
# `cbs_names`, `self.on_think()` also dispatches to the registered callbacks.
class Test(HasCallbacks):
    count = 0
    cbs_names = ('on_think',)
    
    def think(self):
        self.callback('before_think')
        self.count += 1
        print('thinking...')
        self.on_think()  # resolved by `HasCallbacks.__getattr__` since 'on_think' is in `cbs_names`
        self.callback('after_think')

    def act(self):
        self.callback('before_act')
        self.count = 1  # `act` resets the count; the nested `think` then increments it
        print('acting...')
        self.think()
        self.callback('after_act')

class VerboseCB(Callback):
    def before_act(self, ctx): print('before_act count:', ctx.count)
    def after_act(self, ctx): print('after_act count:', ctx.count)

# No callbacks: 1 after `act`'s reset, 2 after `think`.
test = Test()
test.act()
test_eq(test.count, 2)

print()
# VerboseCB only prints; the count is unchanged.
test = Test([VerboseCB()])
test.act()
test_eq(test.count, 2)
class ThinkCB(Callback):
    def on_think(self, ctx): ctx.count += 1  # mutates the caller, so the final count becomes 3

print()
test = Test([VerboseCB(), ThinkCB()])
test.act()
test_eq(test.count, 3)


acting...
thinking...

before_act count: 0
acting...
thinking...
after_act count: 2

before_act count: 0
acting...
thinking...
after_act count: 3


## with_cbs

In [None]:
#| export

class with_cbs:
    def __init__(self, nm:str|None=None): self.nm = nm
    def __call__(self, f):
        def _f(o, *args, **kwargs):
            nm = self.nm or f.__name__
            try:
                o.callback(f'before_{nm}')
                f(o, *args, **kwargs)
                o.callback(f'after_{nm}')
            except globals()[f'Cancel{nm.title()}Exception']: pass
            finally: o.callback(f'cleanup_{nm}')
        return _f

In [None]:
# `with_cbs` wraps each method in `before_`/`after_`/`cleanup_` callbacks. `act` also accepts
# temporary callbacks that stay registered only while its body runs.
class Test(HasCallbacks):
    count = 0
    
    @with_cbs()
    def think(self):
        print('thinking...')
        self.count += 1

    @with_cbs('act')
    def act(self, cbs:Sequence[Callback]=()):
        # `before_act` has already fired and `after_act` fires after this block exits, so the temporary
        # cbs only take part in the nested `think` callbacks.
        with self.this_cbs(cbs):
            print('acting...')
            self.count = 1
            self.think()


class ActCB(Callback):
    def before_act(self, ctx): print('before_act count:', ctx.count)
    def after_act(self, ctx): 
        print('after_act count:', ctx.count)
        ctx.count += 1

class ThinkCB(Callback):
    def before_think(self, ctx): print('before_think count:', ctx.count)
    def after_think(self, ctx): 
        print('after_think count:', ctx.count)
        ctx.count += 1


# No callbacks: 1 (reset in `act`) + 1 (`think`) = 2.
test = Test()
test.act()
test_eq(test.count, 2)
print()

# ActCB adds 1 in `after_act`.
test = Test([ActCB()])
test.act()
test_eq(test.count, 3)
print()

# ThinkCB is registered only during `act`; it adds 1 in `after_think`.
test = Test([ActCB()])
test.act([ThinkCB()])
test_eq(test.count, 4)


acting...
thinking...

before_act count: 0
acting...
thinking...
after_act count: 2

before_act count: 0
acting...
before_think count: 1
thinking...
after_think count: 2
after_act count: 3


# CollectionTracker

> Track iteration progress over a collection

In [None]:
#| export

# Annotation alias for an explicitly unknown length. `CollectionTracker(total=None)` means "length unknown",
# while the `_EMPTY` default means "try `len(source)`".
NoTotalT: TypeAlias = None

In [None]:
#| export

class CollectionTracker(HasCallbacks):
    "Base for tracking iteration state over a collection, extensible via subclassing."
    idx: int|None = None  # index of the current item; None until tracking starts
    start_time: float|None = None  # `time.time()` when tracking started
    elapsed_time: float|None = None  # seconds since `start_time`, refreshed on each update

    # Each name can be called as a method (e.g. `self.on_update()`) to run the matching callbacks.
    cbs_names = ('on_start', 'on_stop', 'before_iter', 'after_iter', 'on_update', 'on_interrupt')

    @property
    def state(self):
        "Snapshot of `idx`, `total`, `progress` and `elapsed_time`, plus `interrupted=True` if iteration raised."
        st = AD(idx=self.idx, total=self.total,
            progress=self.progress, elapsed_time=self.elapsed_time)
        if self.interrupted: st.interrupted = True
        return st

    @property
    def progress(self):
        "`(idx+1)/total`, capped at 1.0 and rounded to 4 places; None if `idx` is unset or `total` is unknown or 0."
        if self.total and self.idx is not None:
            return min(1.0, round((self.idx+1) / float(self.total), 4))
        return None

    def update(self, idx:int|None=None, item:Any=_EMPTY):
        """Record progress on the current item, then move to `idx` (default: the next index).

        Starts tracking on the first call. Does nothing once stopped or once `total` is reached.
        Runs `on_update` (with `item` when given). When the new index reaches `total`, it also runs
        `after_iter` and stops."""
        if self.idx is None: self.idx = self._start(idx or 0)
        if not self.active or (self.total is not None and self.idx >= self.total): return
        self.elapsed_time = time.time() - self.start_time  # type: ignore
        if item is _EMPTY: self.on_update()
        else: self.on_update(item)
        if idx is None: idx = self.idx + 1
        if self.total is not None and idx >= self.total:
            self.after_iter()
            self._stop()
        self.idx = idx

    def _start(self, idx:int):
        "Mark active, record the start time, run `on_start` then `before_iter`, and set `idx` (returned)."
        self.active = True
        self.start_time = time.time()
        self.on_start()
        self.before_iter()
        self.idx = idx
        return idx

    def _stop(self):
        "Deactivate and run `on_stop`, at most once per start."
        if self.active:
            self.active = False
            self.on_stop()

    def __iter__(self):
        """Yield items from `source` (at most `total` of them), calling `update` after each one.

        An `Exception` raised while iterating (from the source or a callback) sets `interrupted`,
        runs `on_interrupt` and is re-raised. A consumer's early `break` is not an `Exception`
        (the generator is closed with `GeneratorExit`), so it does not count as an interruption.
        `on_stop` runs in every case."""
        self._start(0)
        try:
            for i, o in enumerate(self.source):
                if self.total is not None and i >= self.total: break
                yield o
                self.update(item=o)
            if self.total is None:
                # The length is known now. `update(self.total)` would return early here
                # (`idx >= total`), so finish explicitly; this fires `after_iter` just like the known-total path.
                self.total = self.idx
                self.elapsed_time = time.time() - self.start_time  # type: ignore
                self.after_iter()
        except Exception:
            self.interrupted = True
            self.on_interrupt()
            raise
        finally: self._stop()

    def __init__(self, source:Iterable[Any], total:int|NoTotalT|Type[_EMPTY]=_EMPTY, **kwargs):
        """Track `source`. `total` defaults to `len(source)` when available; pass None for an unknown length.
        Extra `kwargs` (e.g. `cbs`) are forwarded to `HasCallbacks`."""
        super().__init__(**kwargs)
        self.source = source
        if total is _EMPTY:
            try: total = len(source)  # type: ignore
            except TypeError: total = None  # no `__len__`, e.g. a generator
        self.total: int|None = total  # type: ignore
        self.interrupted = False
        self.active = False

In [None]:
# Empty source: total is 0 and iteration yields nothing.
t = CollectionTracker(())
test_eq(t.total, 0)
with test_raises(StopIteration): next(iter(t))
for i in t: pass

# State before iteration vs. while positioned on the first item.
t = CollectionTracker(range(6))
test_eq(t.state, {'idx': None, 'total': 6, 'progress': None, 'elapsed_time': None})
test_eq(next(iter(t)), 0)
test_eq(t.state, {'idx': 0, 'total': 6, 'progress': 0.1667, 'elapsed_time': t.elapsed_time})

# Inside the loop `idx` is the current item's index; afterwards it equals `total`.
t = CollectionTracker(range(6))
test_eq([t.idx for i in t], range(6))
test_eq(t.state, {'idx': 6, 'total': 6, 'progress': 1.0, 'elapsed_time': t.elapsed_time})


In [None]:
# A string has a `len`, so its total is known up front.
t = CollectionTracker('abc')
test_eq(t.total, 3)
test_eq(t.state, {'idx': None, 'total': 3, 'progress': None, 'elapsed_time': None})
test_eq(next(it := iter(t)), 'a')
test_eq(next(it), 'b')
test_eq(t.state, {'idx': 1, 'total': 3, 'progress': 0.6667, 'elapsed_time': t.elapsed_time})

# `total=None` forces an unknown length; the tracker sets it once the source is exhausted.
t = CollectionTracker('abc', None)
test_eq([_ for _ in t], ('a', 'b', 'c'))
test_eq(t.state, {'idx': 3, 'total': 3, 'progress': 1.0, 'elapsed_time': t.elapsed_time})


In [None]:
# Drive the tracker manually with `update`, without iterating it.
t = CollectionTracker(c := (12, 56, -1, 2, 67), None)
oo = []
for i, o in enumerate(c):
    t.update(item=o)
    test_eq(t.idx, i+1)
    oo.append(o)
test_eq(oo, c)
test_eq(t.state, {'idx': len(c), 'total': None, 'progress': None, 'elapsed_time': t.elapsed_time})

# With total 3, updates past the end are ignored.
t = CollectionTracker(c, 3)
for i in range(5):
    t.update()
    test_eq(t.idx, min(i+1, 3))
test_eq(t.state, {'idx':3, 'total':3, 'progress': 1.0, 'elapsed_time': t.elapsed_time})

# Iteration stops after `total` items even when the source is longer.
t = CollectionTracker(c, 3)
oo = [o for o in t]
test_eq(oo, (12, 56, -1))
test_eq(t.state, {'idx':3, 'total':3, 'progress': 1.0, 'elapsed_time': t.elapsed_time})


In [None]:
from functools import reduce

# The tracker works with any consumer of iterables, e.g. `reduce`.
t = CollectionTracker(c := (1, 2, -1, -2, 7))
test_eq(reduce(lambda x, y: x+y, t), 7)
test_eq(t.state, {'idx': 5, 'total': 5, 'progress': 1.0, 'elapsed_time': t.elapsed_time})


# Counts `on_update` calls. `on_start` resets the count, so the same instance can be reused below.
class CountCB(Callback):
    def on_start(self, ctx): self.count = 0
    def on_update(self, ctx, item=_EMPTY): 
        self.count += 1

# Callbacks added temporarily with `this_cbs`...
with (t := CollectionTracker(c)).this_cbs([cb := CountCB()]):
    test_eq(reduce(lambda x, y: x+y, t), 7)
test_eq(t.state, {'idx': 5, 'total': 5, 'progress': 1.0, 'elapsed_time': t.elapsed_time})
test_eq(cb.count, 5)

# ...or passed at construction via `cbs=`.
test_eq(reduce(lambda x, y: x+y, (t := CollectionTracker(c, cbs=[cb]))), 7)
test_eq(t.state, {'idx': 5, 'total': 5, 'progress': 1.0, 'elapsed_time': t.elapsed_time})
test_eq(cb.count, 5)


In [None]:
# Infinite source capped by `total`: iteration stops after 3 items.
t = CollectionTracker(repeat(7), 3)
oo = list(t)
test_eq(oo, (7, 7, 7))
test_eq(t.state, {'idx': 3, 'total': 3, 'progress': 1.0, 'elapsed_time': t.elapsed_time})

# Infinite source with unknown total: the consumer decides when to stop.
t = CollectionTracker(repeat(None), None)
for _ in t:
    if t.idx >= 10:  # type: ignore
        break
test_eq(t.state, {'idx': 10, 'total': None, 'progress': None, 'elapsed_time': t.elapsed_time})


# Colophon
----

In [None]:
import fastcore.all as FC
import nbdev
from nbdev.clean import nbdev_clean


In [None]:
# When running inside a notebook, run `nbdev_clean` and then `nbdev.nbdev_export` on this notebook file.
if FC.IN_NOTEBOOK:
    nb_path = '10_callback.ipynb'
    for step in (nbdev_clean, nbdev.nbdev_export):
        step(nb_path)
