In [1]:
import torch, torchvision
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from time import sleep
from torch import optim
from torch.utils.data import DataLoader, RandomSampler

### What is a callback?
- A callback is a function that you write to be run *when something happens*.
- For example, an `on_click_callback()` function is executed every time a button is clicked on a webpage.

In [2]:
import ipywidgets

def custom_callback(o):
    """Button click handler: print a confirmation message.

    ``o`` is whatever ``Button.on_click`` passes to its handler (presumably the
    clicked button); it is not used here.
    """
    message = 'You clicked on me!'
    print(message)
        
# Create a clickable button and register custom_callback as its click handler;
# each click runs custom_callback, which prints its message.
button = ipywidgets.Button(description='Click Me!')
button.on_click(custom_callback)
# Last expression of the cell, so the notebook renders the widget.
button

Button(description='Click Me!', style=ButtonStyle())

You clicked on me!
You clicked on me!


### Callback toy example

Consider the following basic routine:

In [3]:
def slow_calculation(cb=None, delay=1):
    """Sum the squares of 0..4, pausing after each step and reporting progress.

    Args:
        cb: optional callable invoked as ``cb(i)`` after step ``i``
            (``i`` runs from 0 to 4).
        delay: seconds to sleep after each step (default 1, simulating slow
            work); pass 0 to run instantly.

    Returns:
        The sum of squares 0 + 1 + 4 + 9 + 16 = 30.
    """
    res = 0
    for i in range(5):
        res += i**2
        sleep(delay)
        # `is not None` rather than truthiness: a callable object that defines
        # __len__/__bool__ could evaluate falsy and would be skipped silently.
        if cb is not None:
            cb(i)
    return res

# No callback: the loop just sleeps and sums (five 1-second sleeps), and the
# notebook displays the returned value, 30.
slow_calculation()

30

Now, let's add a callback that displays progress at each step:

In [4]:
def show_progress(epoch):
    """Progress callback: report the epoch number it is called with."""
    line = "We're at epoch {}".format(epoch)
    print(line)
    
# Pass the function object itself (no parentheses); slow_calculation calls it as cb(i).
slow_calculation(show_progress)

We're at epoch 0
We're at epoch 1
We're at epoch 2
We're at epoch 3
We're at epoch 4


30

Suppose now we want to add a custom message to `show_progress`, so we modify it as follows:

In [32]:
def show_progress(msg, epoch):
    """Progress callback that prefixes each report with a custom ``msg``.

    It takes two arguments, so it cannot be handed directly to a loop that
    invokes its callback with a single argument.
    """
    print("{} - we're at epoch {}".format(msg, epoch))
    
# Deliberately broken call: `show_progress('Hello')` is evaluated immediately with
# only one argument, so it raises TypeError before slow_calculation ever runs.
# Catch and display the error so "Restart & Run All" is not halted at this cell.
try:
    slow_calculation(show_progress('Hello'))
except TypeError as err:
    print(f'{type(err).__name__}: {err}')

TypeError: show_progress() missing 1 required positional argument: 'epoch'

### Lambda function 

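But now the code fails. `show_progress('Hello')` is called immediately with only one argument, so Python raises the error before `slow_calculation` even runs. Passing `show_progress` itself would not work either: `slow_calculation` calls its callback with a single argument via `cb(i)`, but `show_progress` now needs two. One way to resolve this is a `lambda` function. It wraps `show_progress` in a new one-parameter function with the message fixed in advance.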

In [15]:
# Wrap show_progress in a one-argument lambda with msg fixed to 'Hello';
# slow_calculation calls it as cb(i), which forwards to show_progress('Hello', i).
slow_calculation(lambda o: show_progress('Hello', o))

Hello - we're at epoch 0
Hello - we're at epoch 1
Hello - we're at epoch 2
Hello - we're at epoch 3
Hello - we're at epoch 4


30

The expression `lambda o: show_progress('Hello', o)` creates a new function that takes a single parameter `o`. We can make this reusable with a factory function: it takes the message and returns an inner function (a closure) that already knows the message.

In [5]:
def make_show_progress(msg):
    """Build a one-argument progress callback with ``msg`` baked in.

    Returns a closure ``f(epoch)`` that prints "<msg> - we're at epoch <epoch>",
    suitable for any caller that invokes its callback as ``cb(i)``.
    """
    template = "{} - we're at epoch {}"

    def _inner(epoch):
        print(template.format(msg, epoch))

    return _inner

# make_show_progress('Nice!') returns the callback; slow_calculation then calls it at each step.
slow_calculation(make_show_progress('Nice!'))

Nice! - we're at epoch 0
Nice! - we're at epoch 1
Nice! - we're at epoch 2
Nice! - we're at epoch 3
Nice! - we're at epoch 4


30

### Partial function

Another way to achieve this is with `functools.partial`, which returns a new callable with some of the arguments pre-filled:

In [31]:
from functools import partial

# partial(show_progress, 'Awesome!') pre-fills msg, leaving a callable that takes only epoch.
slow_calculation(partial(show_progress, 'Awesome!'))

Awesome! - we're at epoch 0
Awesome! - we're at epoch 1
Awesome! - we're at epoch 2
Awesome! - we're at epoch 3
Awesome! - we're at epoch 4


30

### Callbacks as callable classes

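This is the typical pattern in fastai. A class stores the configuration (here, the message) as an attribute and implements `__call__`, so an instance can be passed anywhere a plain function is expected. Compared to the approaches above, it is arguably the most minimal and natural one.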

In [7]:
class ProgressShowingCallback():
    """Callable progress reporter: stores a message and prints it with each epoch.

    Instances can be passed anywhere a one-argument callback ``cb(epoch)`` is expected.
    """

    def __init__(self, msg='hello'):
        # Message prefix shown in every progress line.
        self.msg = msg

    def __call__(self, epoch):
        line = '{} - epoch {}'.format(self.msg, epoch)
        print(line)

# instantiate callback: the message is stored on the object, so no lambda/partial is needed
cb = ProgressShowingCallback('Woah!')

# use the callback in slow_calculation; its cb(i) calls dispatch to ProgressShowingCallback.__call__
slow_calculation(cb)

Woah! - epoch 0
Woah! - epoch 1
Woah! - epoch 2
Woah! - epoch 3
Woah! - epoch 4


30

### `*args` & `**kwargs`

- `*args` collects any extra positional arguments into a tuple, while `**kwargs` collects any extra keyword arguments into a dict. The `*` and `**` prefixes indicate that an arbitrary number of these arguments can be passed.

- One use case is to accept `*args` and `**kwargs` in function definitions so that the function can handle arguments flexibly.

In [21]:
def args_and_kwargs(*args, **kwargs):
    """Print the collected positional-argument tuple and keyword-argument dict.

    Example: ``args_and_kwargs(1, 2, a=3)`` prints ``(1, 2) and {'a': 3}``.
    """
    positional, keyword = args, kwargs
    print('{} and {}'.format(positional, keyword))
    
# 1, 2, 3 are gathered into the tuple `args`; kwarg1=4 and kwargs=5 into the dict `kwargs`
# (a caller can even pass a keyword that is literally named `kwargs`).
args_and_kwargs(1,2,3,kwarg1=4,kwargs=5)

(1, 2, 3) and {'kwarg1': 4, 'kwargs': 5}


### Check callback method is defined

- In our control process, we can check whether a callback defines a particular method before calling it.

In [18]:
class PrintStatusCallback():
    """Callback exposing named hooks that print the calculation's status.

    Hooks:
        before_calc(epoch, **kwargs): announce the step about to run.
        after_calc(epoch, val, **kwargs): report the value after the step.
    Extra keyword arguments are accepted and ignored.
    """

    def before_calc(self, epoch, **kwargs):
        print('About to start calculation {}'.format(epoch))

    def after_calc(self, epoch, val, **kwargs):
        print('After {}: {}'.format(epoch, val))

- We can use the `hasattr(cb, 'method_name')` pattern to check whether our callback defines a particular method.

In [20]:
cb = PrintStatusCallback()
# hasattr lets a control loop probe for optional hooks before calling them:
# PrintStatusCallback defines before_calc and after_calc, but not between_calc.
print(hasattr(cb,'before_calc'))
print(hasattr(cb,'after_calc'))
print(hasattr(cb,'between_calc'))

True
True
False


### Use callback to check status

- Below, we modify our `after_calc` method to return a flag once the value exceeds 10. The control loop stops early when it receives that flag.

In [17]:
class PrintStatusCallback():
    """Status-printing callback that can also request early stopping.

    Args:
        threshold: ``after_calc`` signals a stop once the running value exceeds
            this; defaults to 10.

    Hooks:
        before_calc(epoch, **kwargs): announce the step about to run.
        after_calc(epoch, val, **kwargs): report ``val`` and return True when
            ``val > threshold``, otherwise False.
    Extra keyword arguments are accepted and ignored.
    """

    def __init__(self, threshold=10):
        self.threshold = threshold

    def before_calc(self, epoch, **kwargs):
        print(f'About to start calculation {epoch}')

    def after_calc(self, epoch, val, **kwargs):
        print(f'After {epoch}: {val}')
        # Always return an explicit bool instead of falling through to None.
        return val > self.threshold
        
        
def slow_calculation(cb=None, delay=1):
    """Sum the squares of 0..4, calling optional hooks and allowing an early stop.

    Args:
        cb: optional callback object. If it has ``before_calc``, that is called
            as ``cb.before_calc(i)`` before step ``i``. If it has ``after_calc``,
            that is called as ``cb.after_calc(i, res)`` after the step, and a
            truthy return value stops the loop early.
        delay: seconds to sleep after each step (default 1); pass 0 to run instantly.

    Returns:
        The running sum when the loop ended: 30 if it ran to completion, or the
        partial sum if a callback requested an early stop.
    """
    # Look the hooks up once; a missing hook (or no callback at all) becomes None.
    before = getattr(cb, 'before_calc', None)
    after = getattr(cb, 'after_calc', None)
    res = 0
    for i in range(5):
        if before is not None:
            before(i)
        res += i**2
        sleep(delay)
        if after is not None and after(i, res):
            print(f'Early stopping at epoch {i}')
            break
    return res
    
# PrintStatusCallback.after_calc returns a truthy flag once the running sum exceeds 10,
# so slow_calculation stops early (at epoch 3, where the sum is 14).
slow_calculation(PrintStatusCallback())

About to start calculation 0
After 0: 0
About to start calculation 1
After 1: 1
About to start calculation 2
After 2: 5
About to start calculation 3
After 3: 14
Early stopping at epoch 3


### Use callback to change state

- We can also use callback methods to change the value we care about. To do this, we put the main control process in a class and store the value we want to change as an attribute of that class.

- Our callback class can then read and modify that value through the attribute of the process object, which is passed to each callback method.

In [29]:
class SlowCalculator():
    """Run the sum-of-squares loop as an object so callbacks can modify its state.

    The running total lives in ``self.res``. The calculator passes itself to each
    callback hook, so a hook can read or change ``calc.res``.

    Args:
        cb: optional callback object whose methods are looked up by name
            (``before_calc``, ``after_calc``) and called as ``hook(calc, epoch)``.
        delay: seconds to sleep after each step (default 1); pass 0 to run instantly.
    """

    def __init__(self, cb=None, delay=1):
        self.cb, self.res, self.delay = cb, 0, delay

    def callback(self, cb_name, *args):
        """Call ``self.cb.<cb_name>(self, *args)`` if that hook exists and return its result.

        Returns None when there is no callback object or it has no attribute named ``cb_name``.
        """
        if self.cb is None:
            return None
        # getattr with a default yields None when the callback lacks a method named cb_name
        hook = getattr(self.cb, cb_name, None)
        # self is passed as the first argument, i.e. the `calc` parameter of a hook
        # such as ModifyingCallback.after_calc
        if hook is not None:
            return hook(self, *args)
        return None

    def calc(self):
        """Add i**2 to ``self.res`` for i in 0..4, stopping early if ``after_calc`` returns truthy.

        Note: ``self.res`` is not reset here, so repeated calls keep accumulating.
        """
        for i in range(5):
            self.callback('before_calc', i)
            self.res += i**2
            sleep(self.delay)
            if self.callback('after_calc', i):
                print('stopping early')
                break
                
                
class ModifyingCallback():
    """Callback that mutates the calculator's state and can stop it early.

    ``after_calc`` prints the current total. It returns True (stop) once the
    total exceeds 10; otherwise it doubles ``calc.res`` while it is below 3.
    """

    def after_calc(self, calc, epoch):
        print(f'After {epoch}: {calc.res}')
        stop = calc.res > 10
        if not stop and calc.res < 3:
            calc.res *= 2
        return True if stop else None

In [32]:
# ModifyingCallback doubles calc.res while it is below 3 and stops the run once it
# exceeds 10; the calculator's final state is read from its `res` attribute.
calculator = SlowCalculator(ModifyingCallback())
calculator.calc()
calculator.res

After 0: 0
After 1: 1
After 2: 6
After 3: 15
stopping early


15

### `__call__`

- The calculator + callback system above is a very generic pattern.
- One more simplification is to rename the method from `callback` to `__call__`. The object can then invoke callbacks with `self('name', ...)` instead of repeatedly writing `self.callback('name', ...)`.

In [33]:
class SlowCalculator():
    """Run the sum-of-squares loop as an object whose callbacks are dispatched via ``__call__``.

    The running total lives in ``self.res``. Each hook receives the calculator
    itself, so a hook can read or modify ``calc.res``.

    Args:
        cb: optional callback object; hooks are looked up by name
            (``before_calc``, ``after_calc``) and called as ``hook(calc, epoch)``.
        delay: seconds to sleep after each step (default 1); pass 0 to run instantly.
    """

    def __init__(self, cb=None, delay=1):
        self.cb, self.res, self.delay = cb, 0, delay

    def __call__(self, cb_name, *args):
        """Invoke hook ``cb_name`` on the callback (passing ``self`` first) and return its result.

        Defining ``__call__`` lets the loop write ``self('after_calc', i)``.
        Returns None when there is no callback object or it has no attribute
        named ``cb_name``.
        """
        if self.cb is None:
            return None
        hook = getattr(self.cb, cb_name, None)
        if hook is not None:
            return hook(self, *args)
        return None

    def calc(self):
        """Add i**2 to ``self.res`` for i in 0..4, stopping early if ``after_calc`` returns truthy.

        Note: ``self.res`` is not reset here, so repeated calls keep accumulating.
        """
        for i in range(5):
            self('before_calc', i)  # dispatches through __call__
            self.res += i**2
            sleep(self.delay)
            if self('after_calc', i):  # a truthy hook result requests an early stop
                print('stopping early')
                break
                
                
class ModifyingCallback():
    """Callback that mutates the calculator's state and can stop it early.

    ``after_calc`` prints the current total. It returns True (stop) once the
    total exceeds 10; otherwise it doubles ``calc.res`` while it is below 3.
    """

    def after_calc(self, calc, epoch):
        total = calc.res
        print(f'After {epoch}: {total}')
        if total > 10:
            return True
        elif total < 3:
            calc.res = total * 2
        return None

In [34]:
# ModifyingCallback's hooks are now dispatched through SlowCalculator.__call__;
# the final total is read from the calculator's `res` attribute.
calculator = SlowCalculator(ModifyingCallback())
calculator.calc()
calculator.res

After 0: 0
After 1: 1
After 2: 6
After 3: 15
stopping early


15