In [None]:
#| default_exp callback

In [None]:
#| export
from __future__ import annotations

In [None]:
#| hide
# %reload_ext autoreload
# %autoreload 0

# Callback

> Base classes and helpers for objects that can be augmented with callbacks
> 

# Prologue

In [None]:
#| export
import collections
import time
from contextlib import contextmanager
from functools import partial
from functools import wraps
from operator import attrgetter
from operator import length_hint
from typing import Any
from typing import Callable
from typing import Iterable
from typing import Iterator
from typing import Sequence
from typing import Type
from typing import TypeVar

import fastcore.all as FC


In [None]:
#| export
from olio.basic import AD
from olio.basic import empty
from olio.basic import EmptyT
from olio.basic import update_
from olio.common import Runner
from olio.common import setattrs


In [None]:
import itertools
import operator
from collections import ChainMap
from collections import Counter
from collections import deque
from itertools import islice
from itertools import repeat
from typing import cast
from typing import TypeAlias

from fastcore.test import *
from olio.common import *
from olio.project import setup_console
from olio.test import *


----

In [None]:
# Console helpers from olio.project for notebook output; 108 is presumably the console width — TODO confirm.
console, cprint = setup_console(108)

----

# Callback


In [None]:
#| export

class Callback:
    "Base class for callbacks: a subclass defines one method per event it wants to handle."
    # Sort key used by `run_cbs`: callbacks with a lower `order` run first.
    order = 0
    # Optional nested callbacks, run after this one. Annotation only (no class default),
    # so the attribute exists only when a subclass or an instance sets it.
    cbs: Sequence[Callback]
    

In [None]:
# Scratch check: `filter` returns a lazy filter object, not a list.
filter(lambda x: x != 'b', ['a', 'b', 'c'])

<filter>

In [None]:
#| export

def run_cbs(cbs: Iterable[Callback] | FC.L, method_nm:str, ctx=None, *args, **kwargs):
    """Run the `method_nm` handler of every callback in `cbs`, lowest `order` first.

    Each handler is invoked as `Runner(handler)(ctx, *args, **kwargs)`. Callbacks without a
    (truthy) `method_nm` attribute are skipped. After a callback's own handler, the callbacks
    in its optional `cbs` attribute are run recursively with the same arguments.
    """
    ordered = sorted(cbs, key=attrgetter('order'))
    for callback in ordered:
        handler = getattr(callback, method_nm, None)
        if handler:
            Runner(handler)(ctx, *args, **kwargs)
        children = getattr(callback, 'cbs', None)
        if not children:
            continue
        # Drop the parent itself so a callback listing itself does not recurse forever.
        run_cbs([child for child in children if child != callback], method_nm, ctx, *args, **kwargs)


## EchoCB

In [None]:
#| export

class EchoCB(Callback):
    "Debug callback that handles every event by printing the event name, its positional args and kwargs."
    cbs=()  # explicit "no nested callbacks", so `run_cbs` never routes 'cbs' through `__getattr__`

    def echo(self, ctx, *args, **kwargs):
        # Bound via `partial(self.echo, name)`, so `ctx` here is the event name and `args`
        # starts with the real context.
        print(ctx, args, kwargs)

    def __getattr__(self, name):
        # Dunder lookups (`__deepcopy__`, `__setstate__`, ...) must fail normally. Otherwise
        # `copy.deepcopy` would call the echo partial and return None instead of a copy.
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")
        return partial(self.echo, name)


In [None]:
# EchoCB answers any event name: it prints the name, then the positional args (starting with ctx), then kwargs.
run_cbs((EchoCB(),), 'on_start', AD(count=3))
run_cbs((EchoCB(),), 'on_update', AD(count=3), 12)

on_start ({'count': 3},) {}
on_update ({'count': 3}, 12) {}


In [None]:
class VerboseCB(Callback):
    "Test callback: counts `after_batch` events between `before_fit` and `after_fit`."
    def before_fit(self, ctx): self.count = 0
    def after_batch(self, ctx): self.count += 1
    def after_fit(self, ctx): print(f'{ctx} Completed {self.count} batches')
    

In [None]:
# No ctx is passed, so `run_cbs` uses its default ctx=None, hence the 'None Completed ...' output.
cbs = [VerboseCB()]
run_cbs(cbs, 'before_fit')
test_eq(cbs[0].count, 0)
run_cbs(cbs, 'after_batch')
test_eq(cbs[0].count, 1)
test_stdout(lambda: run_cbs(cbs, 'after_fit'), 'None Completed 1 batches')


In [None]:
# def FuncCB(**kwargs):
#     setattrs(cb := Callback(), kwargs)
#     return cb


## FuncCB

In [None]:
#| export

class FuncCB(Callback):
    "Callback assembled from keyword arguments: `FuncCB(on_x=handler)` handles the `on_x` event."

    def __init__(self, **kwargs):
        # `setattrs` (olio.common) puts the kwargs on the instance, where `run_cbs` finds them
        # via getattr. A handler may also be a tuple of callables (executed through `Runner`).
        setattrs(self, kwargs)


In [None]:
# FuncCB turns keyword arguments into event handlers.
cb = FuncCB(on_count=lambda ctx: print(ctx.count))
run_cbs((cb,), 'on_count', o := AD(count=3))


3


In [None]:
# A handler can be a tuple of functions; they run in order (`inc` runs before the print, so 4 is printed).
def inc(ctx): ctx.count += 1

cb = FuncCB(on_count=(inc, lambda ctx: print(ctx.count)))
run_cbs((cb,), 'on_count', o := AD(count=3))
test_eq(o.count, 4)


4


## PassCB

> A callback that does nothing

In [None]:
#| export

class PassCB(Callback):
    "A callback that accepts every event and does nothing."
    # Defined explicitly: `Callback` only annotates `cbs`. Without this, the `cbs` lookup in
    # `run_cbs` would fall through to `__getattr__`, get the bound `noop`, and then fail with
    # TypeError when it tries to iterate it.
    cbs = ()

    def noop(self, *args, **kwargs): pass

    def __getattr__(self, name):
        # Let dunder protocol lookups (`__deepcopy__`, `__setstate__`, ...) fail normally.
        # Otherwise e.g. `copy.deepcopy` would call `noop` and return None.
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")
        return self.noop


# HasCallbacks

In [None]:
#| export

class HasCallbacks:
    "Base for classes that can be augmented with callbacks."
    cbs: list[Callback]
    # Event names a subclass may call like methods (e.g. `self.on_think()`): `__getattr__`
    # dispatches them to `self.callback(name)`. A real method of the same name takes precedence.
    cbs_names: tuple[str,...] = ()

    def __new__(cls, *args, **kwargs):
        self = super().__new__(cls)
        # Guarantees `cbs` exists even if a subclass `__init__` never calls `super().__init__`.
        self.cbs = []
        return self

    def __init__(self, cbs:Sequence[Callback]=()): self.cbs = list(cbs)

    def with_cbs(self, cbs:Sequence[Callback], extend=False):
        "Replace the callback list (or append to it with `extend=True`); returns `self` for chaining."
        if extend: self.cbs.extend(cbs)
        else: self.cbs = list(cbs)
        return self

    @contextmanager
    def this_cbs(self, cbs:Sequence[Callback]):
        "Use temporary `cbs` in `with` block."
        # Materialize first: `cbs` is walked twice (add, then remove). A one-shot iterator
        # would be exhausted by the add, leaving the callbacks installed forever.
        cbs = list(cbs)
        self.cbs.extend(cbs)
        try:
            yield
        finally:
            for cb in cbs:
                # The block may already have dropped `cb` (e.g. via `with_cbs`). Raising
                # ValueError here would mask any exception raised inside the block.
                try: self.cbs.remove(cb)
                except ValueError: pass

    def __getattr__(self, name):
        "Run `name` as a callback if it's in `self.cbs_names`."
        if name in self.cbs_names: return partial(self.callback, name)
        raise AttributeError(name)

    def callback(self, method_nm, *args, **kwargs):
        "Run `method_nm` on every callback in `self.cbs`, passing `self` as the context."
        run_cbs(self.cbs, method_nm, self, *args, **kwargs)


In [None]:
# `Test` fires before_/after_ events via `self.callback(...)`. `on_think` is listed in `cbs_names`,
# so `self.on_think()` resolves through `HasCallbacks.__getattr__` and is a no-op unless a
# callback defines `on_think`.
class Test(HasCallbacks):
    count = 0
    cbs_names = ('on_think',)
    
    def think(self):
        self.callback('before_think')
        self.count += 1
        print('thinking...')
        self.on_think()
        self.callback('after_think')

    def act(self):
        self.callback('before_act')
        self.count = 1
        print('acting...')
        self.think()
        self.callback('after_act')

# Reports the count around `act`.
class VerboseCB(Callback):
    def before_act(self, ctx): print('before_act count:', ctx.count)
    def after_act(self, ctx): print('after_act count:', ctx.count)

# No callbacks: only the prints from act/think.
test = Test()
test.act()
test_eq(test.count, 2)

print()
test = Test([VerboseCB()])
test.act()
test_eq(test.count, 2)
# Handles the `on_think` event, bumping the count once more.
class ThinkCB(Callback):
    def on_think(self, ctx): ctx.count += 1

print()
test = Test([VerboseCB(), ThinkCB()])
test.act()
test_eq(test.count, 3)


acting...
thinking...

before_act count: 0
acting...
thinking...
after_act count: 2

before_act count: 0
acting...
thinking...
after_act count: 3


# with_cbs

In [None]:
#| export

class with_cbs:
    def __init__(self, nm:str|None=None): self.nm = nm
    def __call__(self, f):
        def _f(o, *args, **kwargs):
            nm = self.nm or f.__name__
            try:
                o.callback(f'before_{nm}')
                f(o, *args, **kwargs)
                o.callback(f'after_{nm}')
            except globals()[f'Cancel{nm.title()}Exception']: pass
            finally: o.callback(f'cleanup_{nm}')
        return _f

In [None]:
# `@with_cbs()` wraps a method with before_/after_/cleanup_ events. `act` installs extra callbacks
# via `this_cbs` only after `before_act` has fired, so a ThinkCB passed to `act` sees only the
# `think` events.
class Test(HasCallbacks):
    count = 0
    
    @with_cbs()
    def think(self):
        print('thinking...')
        self.count += 1

    @with_cbs('act')
    def act(self, cbs:Sequence[Callback]=()):
        with self.this_cbs(cbs):
            print('acting...')
            self.count = 1
            self.think()


# Prints around `act` and bumps the count after it.
class ActCB(Callback):
    def before_act(self, ctx): print('before_act count:', ctx.count)
    def after_act(self, ctx): 
        print('after_act count:', ctx.count)
        ctx.count += 1

# Prints around `think` and bumps the count after it.
class ThinkCB(Callback):
    def before_think(self, ctx): print('before_think count:', ctx.count)
    def after_think(self, ctx): 
        print('after_think count:', ctx.count)
        ctx.count += 1


test = Test()
test.act()
test_eq(test.count, 2)
print()

test = Test([ActCB()])
test.act()
test_eq(test.count, 3)
print()

test = Test([ActCB()])
test.act([ThinkCB()])
test_eq(test.count, 4)


acting...
thinking...

before_act count: 0
acting...
thinking...
after_act count: 2

before_act count: 0
acting...
before_think count: 1
thinking...
after_think count: 2
after_act count: 3


# Iteration

In [None]:
#| exporti

def _get_total(total: int|None|Type[EmptyT], source) -> int|None:
    """Resolve how many items of `source` to track.

    - `total is empty`: infer it from `source` — its `len()` if it has one, else its
      `operator.length_hint`, where a hint of 0 means unknown and gives `None`.
    - Otherwise: `total` itself if it is a non-negative int, else `None` (unknown).
    """
    if total is empty:
        try: return len(source)
        # `Exception`, not a bare `except:`, so KeyboardInterrupt/SystemExit still propagate
        # while unsized sources (TypeError) or a failing `__len__` fall back to the hint.
        except Exception: return length_hint(source) or None
    if total is not None and (not isinstance(total, int) or total < 0): total = None
    return total


In [None]:
# Negative/None totals mean unknown; `empty` infers from len() or length_hint(); a generator has neither.
test_eq(_get_total(-2, ()), None)
test_eq(_get_total(None, ()), None)
test_eq(_get_total(empty, ()), 0)
test_eq(_get_total(None, range(10)), None)
test_eq(_get_total(empty, range(10)), 10)
test_eq(_get_total(empty, [1,2,3]), 3)
test_eq(_get_total(empty, repeat(1, 5)), 5)
test_eq(_get_total(2, repeat(1, 5)), 2)
def _g(): yield 1
test_eq(_get_total(empty, _g()), None)
test_eq(_get_total(None, _g()), None)
test_eq(_get_total(3, _g()), 3)


In [None]:
class _Container:
    "Test double supporting only membership (`x in obj`)."
    def __init__(self, l:Sequence='abcdefghij'):
        self._l = l

    def __contains__(self, x):
        return x in self._l


class _Iterable:
    "Test double that is re-iterable but has no length."
    def __init__(self, l:Sequence='abcdefghij'):
        self._l = l

    def __iter__(self):
        return iter(self._l)


class _Iterator(_Iterable):
    "Test double that is a one-shot iterator: a second pass yields nothing."
    def __init__(self, l:Sequence='abcdefghij'):
        super().__init__(l)
        self._it = iter(self._l)

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._it)


class _Reversible(_Iterable):
    "Iterable that also supports `reversed()`."
    def __init__(self, l: str='abcdefghij'):
        self._l = l

    def __reversed__(self):
        return reversed(self._l)


class _Sized:
    "Test double supporting only `len()`."
    def __init__(self, l: str='abcdefghij'):
        self._l = l

    def __len__(self):
        return len(self._l)


class _Collection(_Sized, _Iterable, _Container):
    "Sized + iterable + container, roughly `collections.abc.Collection`."


class _Indexable(_Collection):
    """Sequence-like that only supports integer indexing, not slicing"""
    def __getitem__(self, idx) -> str:
        if isinstance(idx, int):
            return self._l[idx]
        raise TypeError("Indices must be integers")

    def index(self, x, start=None, stop=None) -> int:
        return self._l.index(x, start, stop)

    def count(self, x, start=None, stop=None) -> int:
        return self._l.count(x, start, stop)


class _Sequence(_Reversible, _Indexable):
    "Full sequence: integer indexing and slicing (slices come back as lists)."
    def __getitem__(self, idx: int | slice) -> str | list[str]:
        as_list = list(self._l)
        return as_list[idx]
    

In [None]:
# Sanity checks for the iterable-protocol test doubles (_Iterable, _Iterator, _Collection, _Indexable, _Sequence).
# Test basic iteration
test_eq(list(_Iterable()), list('abcdefghij'))

# Test iterator exhaustion
it = _Iterator()
test_eq(list(it), list('abcdefghij'))
test_eq(list(it), [])  # Should be empty on second pass

# Test collection capabilities
coll = _Collection()
test_eq(len(coll), 10)
test_is('a' in coll, True)
test_eq(list(coll), list('abcdefghij'))

# Test sequence capabilities
seq = _Sequence()
test_eq(list(reversed(seq)), list('jihgfedcba'))

# Test _Indexable: integer indexing works, slicing raises TypeError
idx_only = _Indexable()
test_eq(idx_only[0], 'a')
test_eq(idx_only[-1], 'j')
test_eq(idx_only[5], 'f')
with test_raises(TypeError): idx_only[1:5]

# Test full _Sequence
seq = _Sequence()
# Integer indexing
test_eq(seq[0], 'a')
test_eq(seq[-1], 'j')
test_eq(seq[5], 'f')

# Slicing
test_eq(seq[1:5], list('bcde'))
test_eq(seq[::2], list('acegi'))
test_eq(seq[::-1], list('jihgfedcba'))
test_eq(seq[-3:], list('hij'))

# Index and count
test_eq(seq.index('d'), 3)
test_eq(seq.index('h', 5), 7)
with test_raises(ValueError): seq.index('z')

seq = _Sequence('aabbbcccc')
test_eq(seq.count('a'), 2)
test_eq(seq.count('b'), 3)
test_eq(seq.count('c'), 4)
test_eq(seq.count('z'), 0)


# IterTracker

In [None]:
class IterTracker:
    """Track iteration over `source`: current `item`, index `n`, `total`, `progress`, `elapsed_time`.

    `total=empty` infers the total from `source` (see `_get_total`). An explicit int caps how
    many items are yielded. `None` means unknown; it becomes `n+1` once iteration ends.
    """

    n: int|None = None
    elapsed_time: float|None = None
    active, interrupted = False, False
    item: Any = empty
    def __init__(self, source: Iterable[Any] = (), total: int|None|Type[EmptyT] = empty):
        self._source, self._total = source, total
        self.total = _get_total(self._total, self._source)

    @property
    def progress(self):
        "Fraction done including the current item (4 decimals, capped at 1.0); None if unknown."
        # Truthiness (not `is not None`) also skips total == 0. That case has `n` set after
        # the early break and would otherwise divide by zero. Matches `CollBack.progress`.
        if self.total:
            return None if self.n is None else min(1., round((self.n+1)/float(self.total), 4))
        return None

    @property
    def state(self):
        "Snapshot as an `AD`; presumably `update_` drops values equal to `empty` (e.g. `item` when idle)."
        return AD(update_(item=self.item, n=self.n, total=self.total,
                        progress=self.progress, elapsed_time=self.elapsed_time,
                        interrupted=self.interrupted or empty,
                        empty_value=empty))

    def _iter_start(self):
        "Reset per-run state and return the start timestamp."
        self.total: int|None = _get_total(self._total, self._source)
        self.active, self.elapsed_time, self.n = True, 0.0, None
        # A new run starts uninterrupted even if a previous run was interrupted.
        self.interrupted = False
        return time.time()

    def _iter_update(self): ...  # no-op per-item hook, called after each consumed item

    def _iter_stop(self):
        # If the total was unknown, record how many items were actually seen.
        if self.total is None and self.n is not None: self.total = self.n + 1
        self.active, self.item= False, empty

    def _iter_interrupt(self):
        self.interrupted = True

    def __iter__(self):
        if self._source is None: return
        start_time = self._iter_start()
        try:
            for self.n, self.item in enumerate(self._source):
                # Only reachable with total == 0: stop before yielding anything.
                if self.total is not None and self.n >= self.total: break
                yield self.item
                self.elapsed_time = time.time() - start_time
                self._iter_update()
                # Stop right after the last counted item instead of pulling one more element
                # from `source`, which would be lost to whoever owns that iterator.
                if self.total is not None and self.n >= self.total - 1: break
        except Exception:
            self._iter_interrupt()
            raise
        finally: self._iter_stop()

    def trackiter(self) -> Iterator[tuple[AD, Any]]:
        "Iterate yielding `(state, item)` pairs."
        for elem in self:
            yield self.state, elem

def trackiter(source: Iterable[Any]) -> Iterator[tuple[AD, Any]]:
    "Iterate `source` through a fresh `IterTracker`, yielding `(state, item)` pairs."
    return IterTracker(source, empty).trackiter()


In [None]:
# A source without len/length_hint: total stays None during iteration and becomes n+1 afterwards.
tr = IterTracker(_Iterable())
test_eq(tr.active, False)
test_eq(tr.state, AD(n=None, total=None, progress=None, elapsed_time=None))
test_eq([(tr.n,o) for o in tr], zip(range(10), 'abcdefghij'))
test_eq(tr.active, False)
test_eq(tr.state, AD(n=9, total=10, progress=1.0, elapsed_time=tr.elapsed_time))

test_eq(list(IterTracker()), [])

for st,o in trackiter(_Iterable()): print(o, st)


a {'item': 'a', 'n': 0, 'total': None, 'progress': None, 'elapsed_time': 0.0}
b {'item': 'b', 'n': 1, 'total': None, 'progress': None, 'elapsed_time': 2.9802322387695312e-05}
c {'item': 'c', 'n': 2, 'total': None, 'progress': None, 'elapsed_time': 4.100799560546875e-05}
d {'item': 'd', 'n': 3, 'total': None, 'progress': None, 'elapsed_time': 4.76837158203125e-05}
e {'item': 'e', 'n': 4, 'total': None, 'progress': None, 'elapsed_time': 5.3882598876953125e-05}
f {'item': 'f', 'n': 5, 'total': None, 'progress': None, 'elapsed_time': 6.079673767089844e-05}
g {'item': 'g', 'n': 6, 'total': None, 'progress': None, 'elapsed_time': 6.771087646484375e-05}
h {'item': 'h', 'n': 7, 'total': None, 'progress': None, 'elapsed_time': 7.390975952148438e-05}
i {'item': 'i', 'n': 8, 'total': None, 'progress': None, 'elapsed_time': 8.082389831542969e-05}
j {'item': 'j', 'n': 9, 'total': None, 'progress': None, 'elapsed_time': 8.702278137207031e-05}


In [None]:
# One-shot iterator: the second pass yields nothing and the state is reset (elapsed_time 0.0 from _iter_start).
tr = IterTracker(_Iterator())
test_eq(tr.state, AD(n=None, total=None, progress=None, elapsed_time=None))
test_eq(''.join([o for o in tr]), 'abcdefghij')
test_eq(tr.state, AD(n=9, total=10, progress=1.0, elapsed_time=tr.elapsed_time))
test_eq(''.join([o for o in tr]), '')  # exhausted
test_eq(tr.state, AD(n=None, total=None, progress=None, elapsed_time=0.0))


In [None]:
# Sized source: total is known up front from len(); tr.n runs 0..9.
tr = IterTracker(_Collection())
test_eq(tr.state, AD(n=None, total=10, progress=None, elapsed_time=None))
test_eq(sum(tr.n for o in tr), 45)  # type: ignore


In [None]:
# _Indexable has len(), so total is known; itemgetter exercises positive and negative integer indexing.
tr = IterTracker(_Indexable())
test_eq(tr.state, AD(n=None, total=10, progress=None, elapsed_time=None))
test_eq(operator.itemgetter(1, -5, 5)(_Indexable()), ('b', 'f', 'f'))


In [None]:
# While item i is current, progress is (i+1)/total.
for i,o in enumerate(t := IterTracker(_Sequence())):
    test_eq(gets(t.state, 'item', 'n', 'total', 'progress'), (o, i, 10, (i+1)/10.))


# CollBack
> Even more impure iterators.

Track iterable progress with arbitrary side-effects.

In [None]:
#| export

class CollBack(HasCallbacks):
    """Track iterables and extend them with callbacks.

    Iterating a `CollBack` yields the items of `source` (at most `total` of them) while keeping
    `n`, `item`, `total`, `progress` and `elapsed_time` up to date. Callbacks receive the current
    `state` (an `AD`) as context for these events:
    - `before_iter`
    - `on_iter`, also given the item
    - `on_interrupt`, when an `Exception` is raised during iteration
    - `after_iter`, whenever iteration ends
    """

    cbs_names = ('before_iter', 'after_iter', 'on_iter', 'on_interrupt')

    n: int|None = None
    elapsed_time: float|None = None
    active, interrupted = False, False
    item: Any = empty
    def __init__(self,
            source: Iterable[Any] = (),
            total: int|None|Type[EmptyT] = empty,
            context: Any = empty,
            **kwargs):
        # Remaining `**kwargs` go to `HasCallbacks.__init__` (e.g. `cbs=[...]`).
        self._source, self._total = source, total
        self.total: int|None = _get_total(self._total, self._source)
        self.context = context
        super().__init__(**kwargs)

    @property
    def progress(self):
        "Fraction done including the current item (4 decimals, capped at 1.0); None if unknown or total is 0."
        if self.total:
            return None if self.n is None else min(1., round((self.n+1)/float(self.total), 4))
        return None

    @property
    def state(self):
        "Snapshot as an `AD`; presumably `update_` drops values equal to `empty` (e.g. no current item)."
        return AD(update_(item=self.item, n=self.n, total=self.total, progress=self.progress,
            context=self.context,
            elapsed_time=self.elapsed_time, interrupted=self.interrupted or empty, empty_value=empty))

    def __repr__(self): return f'{self.__class__.__name__}#{self._source}, total={self._total}'

    def _start(self):
        self.total = _get_total(self._total, self._source)
        # Reset per-run state, including `interrupted`, so a reused CollBack does not keep
        # reporting a previous run's failure.
        self.active, self.elapsed_time, self.n, self.interrupted = True, None, None, False
        run_cbs(self.cbs, 'before_iter', self.state)

    def _stop(self):
        # If the total was unknown, record how many items were actually seen.
        if self.total is None and self.n is not None: self.total = self.n + 1
        self.active, self.item = False, empty
        run_cbs(self.cbs, 'after_iter', self.state)

    def _interrupt(self): self.interrupted = True; run_cbs(self.cbs, 'on_interrupt', self.state)

    def __iter__(self) -> Iterator[Any]:
        if self._source is None: return
        try:
            start_time = time.time()
            self._start()
            for self.n, self.item in enumerate(self._source):
                # Only reachable with total == 0: stop before yielding anything.
                if self.total is not None and self.n >= self.total: break
                yield self.item
                self.elapsed_time = time.time() - start_time
                run_cbs(self.cbs, 'on_iter', self.state, self.item)
                # Stop after the last counted item without pulling another one from `source`.
                if self.total is not None and self.n >= self.total-1: break
        except Exception:
            self._interrupt()
            raise  # bare `raise` re-raises without adding a traceback entry for this line
        finally:
            self._stop()

    def trackback(self, cbs:Sequence[Callback]=()) -> Iterator[tuple[AD[Any], Any]]:
        "Iterate yielding `(state, item)` pairs, with `cbs` installed only for this iteration."
        with self.this_cbs(cbs):
            for elem in self:
                yield self.state, elem

def trackback(source: Iterable[Any], total: int|None|Type[EmptyT]=empty, context: Any=empty,
        cbs:Sequence[Callback]=()) -> Iterator[tuple[AD[Any], Any]]:
    "Iterate `source` through a new `CollBack` (with callbacks `cbs`), yielding `(state, item)` pairs."
    return CollBack(source, total, context, cbs=cbs).trackback()


In [None]:
# CollBack is transparent: wrapping standard-library iterables yields the same items.
test_eq(deque([1,2,3], maxlen=3), deque(CollBack([1,2,3]), maxlen=3))
test_eq(list(ChainMap({'a':1}, {'b':2})), list(CollBack(ChainMap({'a':1}, {'b':2}))))  # Search multiple dicts
test_eq(list(CollBack(Counter('hello').items())), [('h', 1), ('e', 1), ('l', 2), ('o', 1)])


In [None]:
# File objects
# import os

# tuple(CollTracker(open('file.txt')))  # Line iterator
# tuple(CollTracker(open('file.txt').read(3)))  # Char iterator
# tuple(CollTracker(os.scandir()))  # Directory iterator


In [None]:
# Recursive generators
# Demo only: depth-first walk yielding each path, then recursing into directories.
# The listed output depends on the current working directory's contents.
from pathlib import Path


def tree_walk(path):
    yield path
    if path.is_dir():
        for p in path.iterdir():
            yield from tree_walk(p)

list(tree_walk(Path().parent))

[Path('.'),
 Path('10_callback.ipynb'),
 Path('file.txt'),
 Path('15_config.ipynb'),
 Path('00_basic.ipynb'),
 Path('05_test.ipynb'),
 Path('apollo_astronauts.json'),
 Path('00_project.ipynb')]

In [None]:
# For comparison: iterdir() lists only the direct children.
list(Path().parent.iterdir())

[Path('10_callback.ipynb'),
 Path('file.txt'),
 Path('15_config.ipynb'),
 Path('00_basic.ipynb'),
 Path('05_test.ipynb'),
 Path('apollo_astronauts.json'),
 Path('00_project.ipynb')]

In [None]:

# # Coroutine-style generators
# def averager():
#     total = 0.0
#     count = 0
#     average = None
#     while True:
#         term = yield average
#         total += term
#         count += 1
#         average = total/count

# # State machines
# def parser():
#     state = 'START'
#     while True:
#         char = yield
#         if state == 'START':
#             if char == '{': state = 'OPEN'
#         elif state == 'OPEN':
#             if char == '}': state = 'CLOSE'

# # Generator pipelines
# def read_chunks(file):
#     while chunk := file.read(8192):
#         yield chunk

# def decompress(chunks):
#     decompressor = zlib.decompressobj()
#     for chunk in chunks:
#         yield decompressor.decompress(chunk)

In [None]:
# # Basic patterns
# for x in iterable: ...
# [x for x in iterable]
# {x for x in iterable}
# {x:f(x) for x in iterable}
# any(iterable)
# all(iterable)
# sum(iterable)

# # With itertools
# from itertools import *
# for x in islice(count(), 10): ...  # First 10
# for x,y in zip(it1, it2): ...     # Parallel
# for x in chain(it1, it2): ...     # Sequential
# for k,g in groupby(data): ...     # Group runs

# # Unpacking
# first, *rest = iterable
# *rest, last = iterable
# first, *mid, last = iterable

# # Multiple assignment
# for x,y in pairs: ...
# for i,(x,y) in enumerate(pairs): ...

# # Context managers
# with contextlib.closing(iterator) as it: ...

# # Async iteration
# async for x in async_iterable: ...
# [x async for x in async_iterable]

# # Sentinel values
# for line in iter(f.readline, ''): ...

# # Custom iteration
# class Custom:
#     def __iter__(self): ...
#     def __next__(self): ...
#     def __reversed__(self): ...

In [None]:
# Demo: a FuncCB `on_iter` handler reports progress every 100 lines while this notebook file is read.
# Must run from the directory containing '10_callback.ipynb'.
import random

def print_progress(ctx, line):
    if ctx.n % 100 == 0:
        print(f"Processed {ctx.n} {ctx.elapsed_time:.2f}")

def process_line(line): time.sleep(random.uniform(0.0001, 0.0005))


with open('10_callback.ipynb') as f:
    for st,line in trackback(f, cbs=[FuncCB(on_iter=print_progress)]):
        process_line(line)


Processed 0 0.00


Processed 100 0.04
Processed 200 0.08
Processed 300 0.12
Processed 400 0.16
Processed 500 0.20
Processed 600 0.24
Processed 700 0.28
Processed 800 0.32
Processed 900 0.36
Processed 1000 0.40
Processed 1100 0.43
Processed 1200 0.47
Processed 1300 0.51
Processed 1400 0.55
Processed 1500 0.59


In [None]:
# # 1. Processing files with progress tracking
# def process_large_file(file_path):
#     with open(file_path) as f:
#         tracker = CollTracker(f)
#         for line in tracker:
#             if tracker.state.progress and tracker.state.progress % 0.1 == 0:
#                 print(f"Processed {tracker.state.progress:.0%}")
#             process_line(line)

# # 2. Batch processing with time estimation
# def process_batches(items, batch_size=100):
#     tracker = CollTracker(range(0, len(items), batch_size))
#     for i in tracker:
#         batch = items[i:i+batch_size]
#         if tracker.elapsed_time:
#             eta = (tracker.elapsed_time / tracker.state.progress) * (1 - tracker.state.progress)
#             print(f"Batch {i//batch_size}, ETA: {eta:.2f}s")
#         process_batch(batch)

# # 3. Training loop with epoch tracking
# def train_model(model, epochs, data):
#     for epoch, (epoch_state, batch) in enumerate(iterstate(range(epochs))):
#         losses = []
#         for batch_state, (x, y) in iterstate(data):
#             loss = model.train_step(x, y)
#             losses.append(loss)
#             print(f"Epoch {epoch}/{epochs}: {batch_state.progress:.1%}, Loss: {loss:.4f}")

# # 4. Parallel processing with progress
# from concurrent.futures import ProcessPoolExecutor
# def parallel_process(items, n_workers=4):
#     with ProcessPoolExecutor(n_workers) as ex:
#         futures = [ex.submit(process_item, item) for item in items]
#         for state, future in iterstate(futures):
#             result = future.result()
#             print(f"Completed {state.progress:.1%}, Latest result: {result}")

# # 5. Data pipeline with checkpoints
# def process_pipeline(data, checkpoint_every=1000):
#     for state, item in iterstate(data):
#         processed = transform(item)
#         if state.count % checkpoint_every == 0:
#             save_checkpoint(processed, state.count)
#             print(f"Checkpoint at item {state.count}, Progress: {state.progress:.1%}")

# # 6. Sliding window analysis with state
# def analyze_windows(sequence, window_size=3):
#     tracker = CollTracker(sequence)
#     windows = []
#     for i in tracker:
#         if i >= window_size - 1:
#             window = sequence[i-window_size+1:i+1]
#             windows.append((window, tracker.state))
#     return windows

# # 7. Recursive processing with depth tracking
# def process_tree(node, max_depth=None):
#     def walk(node, depth=0):
#         if max_depth and depth >= max_depth: return
#         for state, child in iterstate(node.children):
#             print(f"Depth {depth}, Child {state.count}/{state.total}")
#             yield child
#             yield from walk(child, depth + 1)
#     return walk(node)

# # 8. Time-based iteration control
# def process_with_timeout(items, timeout_secs=60):
#     tracker = CollTracker(items)
#     for item in tracker:
#         if tracker.elapsed_time > timeout_secs:
#             print(f"Timeout after processing {tracker.state.count} items")
#             break
#         process_item(item)

In [None]:
# EchoCB shows the event sequence: before_iter, then on_iter after each yielded item, then after_iter.
for o in trackback(range(3), cbs=[EchoCB()]): print(o)


before_iter ({'n': None, 'total': 3, 'progress': None, 'elapsed_time': None},) {}
({'item': 0, 'n': 0, 'total': 3, 'progress': 0.3333, 'elapsed_time': None}, 0)
on_iter ({'item': 0, 'n': 0, 'total': 3, 'progress': 0.3333, 'elapsed_time': 5.2928924560546875e-05}, 0) {}
({'item': 1, 'n': 1, 'total': 3, 'progress': 0.6667, 'elapsed_time': 5.2928924560546875e-05}, 1)
on_iter ({'item': 1, 'n': 1, 'total': 3, 'progress': 0.6667, 'elapsed_time': 7.390975952148438e-05}, 1) {}
({'item': 2, 'n': 2, 'total': 3, 'progress': 1.0, 'elapsed_time': 7.390975952148438e-05}, 2)
on_iter ({'item': 2, 'n': 2, 'total': 3, 'progress': 1.0, 'elapsed_time': 9.107589721679688e-05}, 2) {}
after_iter ({'n': 2, 'total': 3, 'progress': 1.0, 'elapsed_time': 9.107589721679688e-05},) {}


In [None]:
# Empty source: total is 0 and the first `next` raises StopIteration.
t = CollBack(())
test_eq(t.total, 0)
with test_raises(StopIteration): next(iter(t))

# Print the state at each step (elapsed_time is None while the first item is current).
for _ in (t := CollBack(range(3))):
    print(t.state)

# Step through a sized source manually and check the state after every item and at exhaustion.
t = CollBack(range(3))
test_eq(t.state, {'n': None, 'total': 3, 'progress': None, 'elapsed_time': None})
test_eq(next(it := iter(t)), 0)
test_eq(t.state, {'item': 0, 'n': 0, 'total': 3, 'progress': 0.3333, 'elapsed_time': t.elapsed_time})
test_eq(next(it), 1)
test_eq(t.state, {'item': 1, 'n': 1, 'total': 3, 'progress': 0.6667, 'elapsed_time': t.elapsed_time})
test_eq(next(it), 2)
test_eq(t.state, {'item': 2, 'n': 2, 'total': 3, 'progress': 1.0, 'elapsed_time': t.elapsed_time})
with test_raises(StopIteration): next(it)
test_eq(t.state, {'n': 2, 'total': 3, 'progress': 1.0, 'elapsed_time': t.elapsed_time})

# Strings and iterator-based sources; `active` is cleared once iteration ends.
t = CollBack('abcdef')
test_eq([o for o in t], list('abcdef'))
test_eq(t.state, {'n': 5, 'total': 6, 'progress': 1.0, 'elapsed_time': t.elapsed_time})
test_eq(t.active, False)

t = CollBack(repeat(1, 3))
test_eq(list(map(lambda x: x, t)), [1, 1, 1])
test_eq(t.state, {'n': 2, 'total': 3, 'progress': 1.0, 'elapsed_time': t.elapsed_time})


{'item': 0, 'n': 0, 'total': 3, 'progress': 0.3333, 'elapsed_time': None}
{'item': 1, 'n': 1, 'total': 3, 'progress': 0.6667, 'elapsed_time': 2.7179718017578125e-05}
{'item': 2, 'n': 2, 'total': 3, 'progress': 1.0, 'elapsed_time': 3.933906555175781e-05}


In [None]:
# An explicit total truncates the source.
t = CollBack((12, 56, -1, 2, 67), 3)
test_eq([o for o in t], (12, 56, -1))
test_eq(t.state, {'n': 2, 'total': 3, 'progress': 1.0, 'elapsed_time': t.elapsed_time})


In [None]:
# CollBack works with consumers like `reduce`. CountCB counts `on_iter` events and is installed either
# temporarily (`this_cbs`) or at construction (`cbs=`).
from functools import reduce

t = CollBack(c := (1, 2, -1, -2, 7))
test_eq(reduce(lambda x, y: x+y, t), 7)
test_eq(t.state, {'n': 4, 'total': 5, 'progress': 1.0, 'elapsed_time': t.elapsed_time})


class CountCB(Callback):
    def before_iter(self, ctx): self.count = 0
    def on_iter(self, ctx, _): self.count += 1

with (t := CollBack(c)).this_cbs([cb := CountCB()]):
    test_eq(reduce(lambda x, y: x+y, t), 7)
test_eq(t.state, {'n': 4, 'total': 5, 'progress': 1.0, 'elapsed_time': t.elapsed_time})
test_eq(cb.count, 5)

test_eq(reduce(lambda x, y: x+y, (t := CollBack(c, cbs=[cb]))), 7)
test_eq(t.state, {'n': 4, 'total': 5, 'progress': 1.0, 'elapsed_time': t.elapsed_time})
test_eq(cb.count, 5)


In [None]:
# Infinite sources: an explicit total stops them.
t = CollBack(repeat(7), 3)
oo = [o for o in t]
test_eq(oo, (7, 7, 7))
test_eq(t.state, {'n': 2, 'total': 3, 'progress': 1.0, 'elapsed_time': t.elapsed_time})

# With total=None the consumer decides when to stop; the final total is then n+1.
t = CollBack(repeat(None), None)
for _ in t:
    if t.state.n >= 10: break  # type: ignore
test_eq(t.state, {'n': 10, 'total': 11, 'progress': 1.0, 'elapsed_time': t.elapsed_time})

# Only the first 10 values are taken from the infinite generator.
def fibonacci():
    a, b = 0, 1
    while True:
        yield a
        a, b = b, a + b

test_eq(CollBack(fibonacci(), 10), [0, 1, 1, 2, 3, 5, 8, 13, 21, 34])


In [None]:
# Empty source again; a for-loop over it simply does nothing.
t = CollBack(())
test_eq(t.total, 0)
with test_raises(StopIteration): next(iter(t))
for i in t: pass

# Pulling a single item through a fresh iterator updates the state to n=0.
t = CollBack(range(6))
test_eq(t.state, {'n': None, 'total': 6, 'progress': None, 'elapsed_time': None})
test_eq(next(iter(t)), 0)
test_eq(t.state, {'n': 0, 'total': 6, 'progress': 0.1667, 'elapsed_time': t.elapsed_time})

# `t.n` is the index of the item currently being yielded.
t = CollBack(range(6))
test_eq([t.n for i in t], range(6))
test_eq(t.state, {'n': 5, 'total': 6, 'progress': 1.0, 'elapsed_time': t.elapsed_time})


In [None]:
# An explicit total=None means unknown even for a sized source; it becomes n+1 after exhaustion.
t = CollBack('abc', None)
test_eq([_ for _ in t], ('a', 'b', 'c'))
test_eq(t.state, {'n': 2, 'total': 3, 'progress': 1.0, 'elapsed_time': t.elapsed_time})


# process_

https://stackoverflow.com/questions/50937966/fastest-most-pythonic-way-to-consume-an-iterator


In [None]:
#| export

_T = TypeVar('_T')

def process_(
        iterable:Iterable[_T], /,
        cbs: Callback|Sequence[Callback]=(),
        slc:slice|None=None,
        pred:Callable[[_T], bool]|None=None,
        context:Any=empty,
        **kwargs  # FuncCB kwargs
    ) -> tuple[Callback,...]:
    """Process a subset `slc` of `iterable` filtered by `pred` with callbacks from `cbs` and `FuncCB` `kwargs`.

    Steps:
    1. Materialize the whole `iterable` into an `FC.L`.
    2. Slice it with `slc` (all items when None), then filter it with `pred`.
    3. Drive the result through a `CollBack`, purely for the callbacks' side effects.

    Any `**kwargs` become one extra `FuncCB`, appended after `cbs`. Returns the tuple of
    callbacks used, so callers can inspect state they accumulated.
    """
    callbacks = FC.tuplify(cbs)
    if kwargs:
        callbacks += (FuncCB(**kwargs),)
    window = slc if slc else slice(None)
    selected = FC.L(iterable)[window].filter(pred)  # type: ignore
    tracker = CollBack(selected, context=context, cbs=callbacks)
    # A zero-length deque exhausts an iterator at C speed without storing anything.
    collections.deque(tracker, maxlen=0)
    return callbacks  # type: ignore


In [None]:
# Even numbers of range(10)[1:9]: the slice is applied first, then `pred`; EchoCB shows each event.
process_(range(10), EchoCB(), slice(1,9), pred=lambda x: x%2==0);

before_iter ({'n': None, 'total': 4, 'progress': None, 'elapsed_time': None},) {}
on_iter ({'item': 2, 'n': 0, 'total': 4, 'progress': 0.25, 'elapsed_time': 4.1961669921875e-05}, 2) {}
on_iter ({'item': 4, 'n': 1, 'total': 4, 'progress': 0.5, 'elapsed_time': 6.413459777832031e-05}, 4) {}
on_iter ({'item': 6, 'n': 2, 'total': 4, 'progress': 0.75, 'elapsed_time': 7.700920104980469e-05}, 6) {}
on_iter ({'item': 8, 'n': 3, 'total': 4, 'progress': 1.0, 'elapsed_time': 8.916854858398438e-05}, 8) {}
after_iter ({'n': 3, 'total': 4, 'progress': 1.0, 'elapsed_time': 8.916854858398438e-05},) {}


# Colophon
----

In [None]:
import fastcore.all as FC
import nbdev
from nbdev.clean import nbdev_clean


In [None]:
# Notebook only: clean the notebook, then export its `#| export` cells (module `callback`, per `default_exp`).
if FC.IN_NOTEBOOK:
    nb_path = '10_callback.ipynb'
    nbdev_clean(nb_path)
    nbdev.nbdev_export(nb_path)
