In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
import torch
import matplotlib.pyplot as plt

# Callbacks
## Create your own callback

In [13]:
from time import sleep

If there is a callback, call it and pass in the epoch number:

In [16]:
def slow_calculation(cb=None):
    """Sum the squares of 0..4, sleeping one second per step.

    If ``cb`` is truthy it is called after every step with the step index.
    Returns the final sum.
    """
    total = 0
    for step in range(5):
        total += step * step
        sleep(1)
        if cb:
            cb(step)
    return total

In [17]:
# Without a callback the calculation runs silently and only returns its result.
slow_calculation(cb=None)

30

In [18]:
def show_progress(epoch):
    """Print a line saying that ``epoch`` has finished."""
    message = f"We finished epoch {epoch}"
    print(message)

In [19]:
# Pass the function object itself (not its result) as the callback.
slow_calculation(cb=show_progress)

We finished epoch 0
We finished epoch 1
We finished epoch 2
We finished epoch 3
We finished epoch 4


30

### Lambdas and partials

If you only need the function once, define it inline as a lambda:

In [20]:
# One-off callback defined inline.
slow_calculation(cb=lambda epoch: print(f"We finished epoch {epoch}"))

We finished epoch 0
We finished epoch 1
We finished epoch 2
We finished epoch 3
We finished epoch 4


30

Want to print a custom exclamation? `make_show_progress` creates a new function with a custom exclamation that takes the argument `epoch`. This is called a closure (a closure is the combination of a function and the lexical environment within which that function was declared).

In [25]:
def make_show_progress(exclamation):
    """Return a one-argument progress callback that prefixes ``exclamation``.

    The returned ``_inner`` is a closure: it keeps access to ``exclamation``
    after ``make_show_progress`` has returned. An equivalent form is
    ``_inner = lambda epoch: print(...)``.
    """
    def _inner(epoch):
        print(f"{exclamation} We finished epoch {epoch}")
    return _inner

In [24]:
# make_show_progress(...) builds the callback; slow_calculation calls it each epoch.
slow_calculation(cb=make_show_progress("Ok, great!"))

Ok, great! We finished epoch 0
Ok, great! We finished epoch 1
Ok, great! We finished epoch 2
Ok, great! We finished epoch 3
Ok, great! We finished epoch 4


30

**Alternative:** use `functools.partial` to fix some arguments of a function. Here it turns a two-argument function into one that takes only the remaining argument (`epoch`).

In [26]:
from functools import partial

In [27]:
def show_progress(exclamation, epoch):
    """Print ``exclamation`` followed by an epoch-finished message.

    Takes two arguments, so it must be bound (e.g. with ``partial``) before
    it can be used as a one-argument callback.
    NOTE: this redefines the earlier one-argument ``show_progress(epoch)``;
    cells that pass ``show_progress`` directly only work before this cell runs.
    """
    message = f"{exclamation} We finished epoch {epoch}"
    print(message)

In [28]:
# partial fixes `exclamation`, leaving a callback that only needs `epoch`.
slow_calculation(cb=partial(show_progress, "Ok, great."))

Ok, great. We finished epoch 0
Ok, great. We finished epoch 1
Ok, great. We finished epoch 2
Ok, great. We finished epoch 3
Ok, great. We finished epoch 4


30

### Callbacks as callable classes

In [32]:
class ProgressShowingCallback():
    """Callable callback object that prints an exclamation after each epoch.

    Because it defines ``__call__``, an instance can be passed anywhere a
    plain one-argument callback function is expected.
    """

    def __init__(self, exclamation="Awesome"):
        # Stored on the instance so every call reuses the same prefix.
        self.exclamation = exclamation

    def __call__(self, epoch):
        print(f"{self.exclamation} We finished epoch {epoch}")

In [33]:
# Configure the callback object once; it is invoked like a function later.
cb = ProgressShowingCallback(exclamation="Super!")

In [34]:
# The callable instance is used exactly like a callback function.
slow_calculation(cb=cb)

Super! We finished epoch 0
Super! We finished epoch 1
Super! We finished epoch 2
Super! We finished epoch 3
Super! We finished epoch 4


30

### Multiple callback funcs; `*args` and `**kwargs`

Positional arguments end up in a tuple called `args`. Keyword arguments end up in a dictionary called `kwargs`.

In [44]:
def f(*args, **kwargs):
    """Print the positional arguments (a tuple) and keyword arguments (a dict)."""
    summary = f"args: {args}; kwargs: {kwargs}"
    print(summary)

In [45]:
# 3 and "a" are collected into `args`; thing1 and thing2 into `kwargs`.
f(3, "a", thing1="hello", thing2="world")

args: (3, 'a'); kwargs: {'thing1': 'hello', 'thing2': 'world'}


In [47]:
def slow_calculation(cb=None):
    """Sum the squares of 0..4, notifying a callback object around each step.

    When ``cb`` is truthy, ``cb.before_calc(step)`` runs before each step and
    ``cb.after_calc(step, val=running_total)`` runs after it; both methods
    must exist. Returns the final sum.
    """
    total = 0
    for step in range(5):
        if cb:
            cb.before_calc(step)
        total += step * step
        sleep(1)
        if cb:
            cb.after_calc(step, val=total)
    return total

If a callback doesn't care about the arguments it is given, or you don't even know how many there will be, accept them with `*args` and `**kwargs` and simply ignore them:

In [48]:
class PrintStepCallback():
    """Callback that announces each step and ignores whatever it is passed.

    Both hooks accept ``*args, **kwargs`` so they work with any call
    signature the calculation uses; the values are deliberately unused.
    """

    def before_calc(self, *args, **kwargs):
        print("About to start")

    def after_calc(self, *args, **kwargs):
        print("Done step")

In [49]:
# Callback object with before/after hooks that ignore their arguments.
slow_calculation(cb=PrintStepCallback())

About to start
Done step
About to start
Done step
About to start
Done step
About to start
Done step
About to start
Done step


30

If you want to use the parameters, name them explicitly. Keep `**kwargs` anyway, so the callback doesn't break if the caller later starts passing additional keyword arguments:

In [54]:
class PrintStatusCallback():
    """Callback that reports each epoch before it starts and the result after.

    The arguments it needs (``epoch``, ``val``) are named explicitly; any
    extra keyword arguments are accepted and ignored so the callback keeps
    working if the caller starts passing more of them.
    """

    def before_calc(self, epoch, **kwargs):
        print(f"About to start epoch {epoch}")

    def after_calc(self, epoch, val, **kwargs):
        print(f"After epoch {epoch} the result is {val}")

In [55]:
# Callback object that uses the epoch and running value it receives.
slow_calculation(cb=PrintStatusCallback())

About to start epoch 0
After epoch 0 the result is 0
About to start epoch 1
After epoch 1 the result is 1
About to start epoch 2
After epoch 2 the result is 5
About to start epoch 3
After epoch 3 the result is 14
About to start epoch 4
After epoch 4 the result is 30


30

#### Modifying behaviour

1. Check whether the callback object defines a given method, and only call it if it does.
2. Use the callback's return value to modify the behaviour of the slow calculation (e.g. stop early).
3. Use the callback to change the calculation itself!

In [57]:
def slow_calculation(cb=None):
    """Sum the squares of 0..4 with optional, possibly partial, callback hooks.

    Each hook is only called if ``cb`` is truthy and defines it:
    ``cb.before_calc(step)`` before a step, ``cb.after_calc(step, total)``
    after it. A truthy return from ``after_calc`` prints "Stopping early"
    and ends the loop. Returns the sum accumulated so far.
    """
    total = 0
    for step in range(5):
        if cb and hasattr(cb, 'before_calc'):
            cb.before_calc(step)
        total += step * step
        sleep(1)
        wants_stop = bool(cb) and hasattr(cb, 'after_calc') and cb.after_calc(step, total)
        if wants_stop:
            print("Stopping early")
            break
    return total

In [58]:
class PrintAfterCallback():
    """Print the running result after each step and request an early stop.

    ``after_calc`` returns True once ``val`` exceeds ``threshold`` (10 by
    default) and False otherwise; a truthy return is how a callback asks
    ``slow_calculation`` to stop early.
    """

    def __init__(self, threshold=10):
        self.threshold = threshold

    def after_calc(self, epoch, val, **kwargs):
        # **kwargs: tolerate extra keyword arguments the caller may add later.
        print(f"After {epoch}: {val}")
        return val > self.threshold

In [59]:
# This callback only defines after_calc and asks for an early stop.
slow_calculation(cb=PrintAfterCallback())

After 0: 0
After 1: 1
After 2: 5
After 3: 14
Stopping early


14

If we want the callback to change the result of the calculation, we can pass the calculator object itself to the callback so that it can read and modify the calculator's state:

In [68]:
class SlowCalculator():
    """Stateful slow calculation whose callbacks may inspect and modify it.

    Callback hooks are looked up on ``cb`` by name and called with the
    calculator itself as their first argument, so a callback can read and
    overwrite ``self.res``.
    """

    def __init__(self, cb=None):
        self.cb, self.res = cb, 0

    def callback(self, cb_name, *args, **kwargs):
        """Call ``self.cb.<cb_name>(self, *args, **kwargs)`` if it is defined.

        Returns the hook's return value, or None when there is no callback
        object or it does not define ``cb_name``.
        """
        if not self.cb:
            return None
        hook = getattr(self.cb, cb_name, None)
        if not hook:
            return None
        return hook(self, *args, **kwargs)

    def calc(self):
        """Run the five-step sum of squares and return the final ``self.res``.

        ``self.res`` is reset first so repeated calls give the same result
        instead of accumulating onto the previous run. A truthy return from
        the ``after_calc`` hook prints "Stopping early" and ends the loop.
        """
        self.res = 0
        for i in range(5):
            self.callback('before_calc', i)
            self.res += i*i
            sleep(1)
            if self.callback('after_calc', i):
                print("Stopping early")
                break
        return self.res

In [69]:
class ModifyingCallback():
    """Callback that prints progress, stops past 10 and doubles small results.

    It receives the calculator object itself as ``calc``, so it can read and
    overwrite ``calc.res``.
    """

    def after_calc(self, calc, epoch):
        print(f"After epoch {epoch}: {calc.res}")
        if calc.res > 10:
            return True  # ask the calculator to stop early
        if calc.res < 3:
            calc.res *= 2  # change the calculation's running result in place

In [70]:
# Attach a callback that is allowed to modify the calculator's state.
s = SlowCalculator(cb=ModifyingCallback())

In [71]:
# Run the calculation; the attached ModifyingCallback prints progress,
# may double small intermediate results, and decides when to stop early.
s.calc()

After epoch 0: 0
After epoch 1: 1
After epoch 2: 6
After epoch 3: 15
Stopping early
