In [140]:
from abc import ABC, abstractmethod
from functools import partial, wraps
import inspect
import numpy as np

from htools import debug_call

In [294]:
def debug(func=None, prefix='', arguments=True):
    """Decorator that prints a message each time the decorated function is
    called, optionally listing the arguments it was called with.

    Can be applied bare (`@debug`), with options (`@debug(prefix='>>')`), or
    called directly (`debug(func, prefix='>>')`).

    Parameters
    ----------
    func: callable or None
        Function to decorate. Left as None when options are passed as
        keywords, in which case a configured decorator is returned.
    prefix: str
        Text printed before "CALLING". A single space separates a non-empty
        prefix from "CALLING" no matter how the decorator was applied.
    arguments: bool
        If True, print every bound argument as `name=repr(value)`, with
        defaults filled in. The entries of a `**kwargs`-style catch-all
        (whatever it is named) are flattened into individual `key=value`
        pairs.

    Returns
    -------
    callable
        The wrapped function, or a decorator when `func` is None.
    """
    if func is None:
        return partial(debug, prefix=prefix, arguments=arguments)

    # Normalize here (not only on the partial path) so a direct call
    # `debug(f, prefix='>>')` gets the same spacing as `@debug(prefix='>>')`.
    lead = f'{prefix} ' if prefix else ''

    @wraps(func)
    def wrapper(*args, **kwargs):
        arg_str = ''
        if arguments:
            bound = inspect.signature(func).bind_partial(*args, **kwargs)
            bound.apply_defaults()
            shown = dict(bound.arguments)
            # Flatten the var-keyword parameter by its kind rather than by
            # assuming it is named "kwargs".
            for name, param in bound.signature.parameters.items():
                if param.kind is inspect.Parameter.VAR_KEYWORD:
                    shown.update(shown.pop(name, {}))
            arg_str = ', '.join(f'{k}={v!r}' for k, v in shown.items())

        # Print call message and return output.
        print(f'\n{lead}CALLING {func.__qualname__}({arg_str})')
        return func(*args, **kwargs)

    return wrapper

In [296]:
class Callback(ABC):
    """Abstract base class for the callback objects consumed by the
    @callbacks decorator.

    Subclasses must override both on_begin and on_end. Each hook receives the
    decorated function's bound inputs and, for on_end, its output.

    The @debug decorator is often useful on one or both hooks. When both hooks
    should do the same thing, a convenient pattern is to write one plain
    __call__ method and have the (debug-decorated) on_begin and on_end simply
    return self(inputs, output).
    """

    @abstractmethod
    def on_begin(self, inputs, output=None):
        """Hook that @callbacks runs before the decorated function.

        Parameters
        ----------
        inputs: dict
            Bound arguments (defaults applied) of the current call to the
            function decorated with @callbacks.
        output: any
            @callbacks passes None here because the function has not run yet.
        """

    @abstractmethod
    def on_end(self, inputs, output=None):
        """Hook that @callbacks runs after the decorated function.

        Parameters
        ----------
        inputs: dict
            Bound arguments (defaults applied) of the current call to the
            function decorated with @callbacks.
        output: any
            Value returned by the decorated function.
        """

    def __repr__(self):
        # e.g. "StdCallback()"
        return '%s()' % self.__class__.__name__

In [315]:
def callbacks(cbs):
    """Decorator that attaches callbacks to a function. Callbacks should be
    defined as classes inheriting from abstract base class Callback and
    implementing its on_begin and on_end methods. Using objects allows
    callbacks to store state rather than just printing outputs or relying on
    global vars.

    On every call of the decorated function, each callback's
    on_begin(inputs, None) runs in order, then the function runs, then each
    callback's on_end(inputs, output) runs in order. `inputs` holds the
    call's bound arguments with defaults applied; the same mapping object is
    passed to every hook of a given call.

    Parameters
    ----------
    cbs: iterable
        Callbacks to execute before and after the decorated function. It is
        snapshotted into a tuple once, when `callbacks` is called, so that
        one-shot iterables such as generators work. Later mutation of the
        original list is therefore not reflected.

    Examples
    --------
    @callbacks([PrintHyperparameters(), PlotActivationHist(),
                ActivationMeans(), PrintOutput()])
    def train_one_epoch(**kwargs):
        # Train model.
    """
    # Iterated twice per call (begin and end), so a generator must be
    # materialized or on_end would silently never run.
    cb_seq = tuple(cbs)

    def decorator(func):
        # The signature depends only on func, so compute it once.
        sig = inspect.signature(func)

        @wraps(func)
        def wrapper(*args, **kwargs):
            bound = sig.bind_partial(*args, **kwargs)
            bound.apply_defaults()
            for cb in cb_seq:
                cb.on_begin(bound.arguments, None)
            out = func(*args, **kwargs)
            for cb in cb_seq:
                cb.on_end(bound.arguments, out)
            return out
        return wrapper
    return decorator

In [298]:
class PrintOutputCallback(Callback):
    """Callback that prints the decorated function's return value after each
    call. on_begin does nothing.
    """

    # No __init__ needed: the class holds no state.

    def on_begin(self, inputs, output=None):
        """No-op before the call (signature matches Callback.on_begin)."""

    @debug
    def on_end(self, inputs, output=None):
        """Print the decorated function's output."""
        print(output)

In [314]:
class StdCallback(Callback):
    """Callback that prints np.std(inputs['nums']) both before and after each
    call of the decorated function.

    Notice in this example we're calculating std of inputs, not outputs.
    That's why on_begin prints a new value but on_end prints the same as the
    previous on_begin call.
    """

    # No __init__ needed: the class holds no state.

    @debug(prefix='>'*2, arguments=True)
    def on_begin(self, inputs, output=None):
        """Print the std of the call's `nums` argument before the call."""
        return self(inputs, output)

    @debug(prefix='>'*8, arguments=False)
    def on_end(self, inputs, output=None):
        """Print the std of the call's `nums` argument after the call."""
        return self(inputs, output)

    def __call__(self, inputs, output=None):
        """Shared implementation of both hooks; `output` is ignored.

        Assumes the decorated function has a `nums` argument.
        """
        print(np.std(inputs['nums']))

In [310]:
@callbacks([PrintOutputCallback(),
            StdCallback()])
def foo(nums, a=6, **kwargs):
    """Toy function for demonstrating @callbacks.

    Prints any extra keyword arguments, then returns a new list containing
    every element of `nums` multiplied by `a`.
    """
    print('kwargs', kwargs)
    scaled = []
    for num in nums:
        scaled.append(num * a)
    return scaled

In [311]:
# Chain four calls, feeding each result back in as the next input and
# scaling by 1, 2, 3, 4 in turn; c and d land in foo's **kwargs.
nums = [3, 4, 5]
for i in [1, 2, 3, 4]:
    nums = foo(nums, a=i, c=True, d='d')


>> CALLING StdCallback.on_begin(self=StdCallback(), inputs=OrderedDict([('nums', [3, 4, 5]), ('a', 1), ('kwargs', {'c': True, 'd': 'd'})]), output=None)
0.816496580927726
kwargs {'c': True, 'd': 'd'}

CALLING PrintOutputCallback.on_end(self=PrintOutputCallback(), inputs=OrderedDict([('nums', [3, 4, 5]), ('a', 1), ('kwargs', {'c': True, 'd': 'd'})]), output=[3, 4, 5])
[3, 4, 5]

>>>>>>>> CALLING StdCallback.on_end()
0.816496580927726

>> CALLING StdCallback.on_begin(self=StdCallback(), inputs=OrderedDict([('nums', [3, 4, 5]), ('a', 2), ('kwargs', {'c': True, 'd': 'd'})]), output=None)
0.816496580927726
kwargs {'c': True, 'd': 'd'}

CALLING PrintOutputCallback.on_end(self=PrintOutputCallback(), inputs=OrderedDict([('nums', [3, 4, 5]), ('a', 2), ('kwargs', {'c': True, 'd': 'd'})]), output=[6, 8, 10])
[6, 8, 10]

>>>>>>>> CALLING StdCallback.on_end()
0.816496580927726

>> CALLING StdCallback.on_begin(self=StdCallback(), inputs=OrderedDict([('nums', [6, 8, 10]), ('a', 3), ('kwargs', {'c': 

In [302]:
# Population std (ddof=0) of nums after the third foo call ([6, 8, 10] * 3).
np.array([18, 24, 30]).std()

4.898979485566356