In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
import torch
import matplotlib.pyplot as plt

# Callbacks

## 1. Callbacks as GUI events

In [5]:
import ipywidgets as widgets

```
The Button widget is used to handle mouse clicks.
Its "on_click()" method registers a callback function that is called whenever the button is clicked.
```

In [13]:
w = widgets.Button(description="ClickMe!")

In [14]:
w

Button(description='ClickMe!', style=ButtonStyle())

In [15]:
def f(o):
    """Click handler: print a fixed acknowledgement message.

    Args:
        o: the object the handler is invoked with (not used here).
    """
    message = "Hey you clicked me ! :)"
    print(message)

In [16]:
w.on_click(f)

In [17]:
w

Button(description='ClickMe!', style=ButtonStyle())

Hey you clicked me ! :)
Hey you clicked me ! :)
Hey you clicked me ! :)
Hey you clicked me ! :)
Hey you clicked me ! :)
Hey you clicked me ! :)
Hey you clicked me ! :)


```
When callbacks are used in this way, they are often called "events".
```

## 2. Creating your own callback

In [18]:
from time import sleep

In [19]:
def slow_calculation():
    """Return 0*0 + 1*1 + 2*2 + 3*3 + 4*4 (= 30), sleeping one second per step.

    The sleep makes each step deliberately slow.
    """
    total = 0
    for step in range(5):
        total = total + step ** 2
        sleep(1)
    return total

In [20]:
slow_calculation()

30

In [21]:
def slow_calculation(cb=None):
    """Sum the squares of 0..4, sleeping one second per step.

    Args:
        cb: optional callable, invoked after each step with the step index.
            It is skipped whenever ``cb`` is falsy (e.g. ``None``).

    Returns:
        The sum of squares (30).
    """
    total = 0
    for step in range(5):
        total += step * step
        sleep(1)
        if not cb:
            continue
        cb(step)
    return total


In [23]:
def show_progress(epoch):
    """Print a progress message saying that ``epoch`` has finished."""
    message = f"Awesome!! We've finished epoch {epoch}"
    print(message)

In [24]:
slow_calculation(cb=show_progress)

Awesome!! We've finished epoch 0
Awesome!! We've finished epoch 1
Awesome!! We've finished epoch 2
Awesome!! We've finished epoch 3
Awesome!! We've finished epoch 4


30

### Lambdas and partials

In [25]:
slow_calculation(lambda i: print(f"awesome !!!! You have finished epoch {i}"))

awesome !!!! You have finished epoch 0
awesome !!!! You have finished epoch 1
awesome !!!! You have finished epoch 2
awesome !!!! You have finished epoch 3
awesome !!!! You have finished epoch 4


30

In [32]:
def make_show_progress(greet):
    """Build a progress callback whose messages start with ``greet``.

    The returned lambda closes over ``greet``, so every call to this factory
    produces an independent callback taking the epoch index.
    """
    return lambda i: print(f"{greet}!!!! finished epoch {i}")

In [33]:
slow_calculation(cb=make_show_progress(greet="Wow"))

Wow!!!! finished epoch 0
Wow!!!! finished epoch 1
Wow!!!! finished epoch 2
Wow!!!! finished epoch 3
Wow!!!! finished epoch 4


30

In [34]:
def make_show_progress(greet):
    """Return a named closure printing ``"<greet>!! finished epoch <i>"``.

    Unlike a lambda, the inner ``def`` can hold multiple statements.
    """
    def _print_progress(i):
        text = f"{greet}!! finished epoch {i}"
        print(text)

    return _print_progress

In [35]:
slow_calculation(cb=make_show_progress(greet="Superb"))

Superb!! finished epoch 0
Superb!! finished epoch 1
Superb!! finished epoch 2
Superb!! finished epoch 3
Superb!! finished epoch 4


30

In [36]:
from functools import partial

In [37]:
def show_progress(greet, epoch):
    """Print ``"<greet>!!! finished epoch <epoch>"``.

    ``greet`` comes first so it can be pre-bound, e.g. with ``functools.partial``,
    leaving a one-argument callback.
    """
    line = f"{greet}!!! finished epoch {epoch}"
    print(line)

In [39]:
slow_calculation(cb=partial(show_progress, "Superb"))

Superb!!! finished epoch 0
Superb!!! finished epoch 1
Superb!!! finished epoch 2
Superb!!! finished epoch 3
Superb!!! finished epoch 4


30

In [45]:
a = partial(show_progress, "zzz")
a(20)

zzz!!! finished epoch 20


## Multiple callback functions: `*args` and `**kwargs`

In [46]:
def f(*args, **kwargs):
    """Print how the call's arguments were split into positional and keyword parts."""
    summary = f"args={args}, kwargs={kwargs}"
    print(summary)

In [47]:
f(3, "a", name="flash", work="save it")

args=(3, 'a'), kwargs={'name': 'flash', 'work': 'save it'}


In [53]:
def slow_calculation(cb=None):
    """Sum the squares of 0..4 while notifying an optional callback object.

    Args:
        cb: optional object exposing ``before_calc(i)`` and
            ``after_calc(i, val=...)``. ``before_calc`` gets the step index
            before the step runs; ``after_calc`` gets the index plus the running
            total as the keyword argument ``val``. Both are skipped when ``cb``
            is falsy.

    Returns:
        The sum of squares (30).
    """
    running_total = 0
    for step in range(5):
        if cb:
            cb.before_calc(step)
        running_total += step ** 2
        sleep(1)
        if cb:
            cb.after_calc(step, val=running_total)
    return running_total

In [69]:
class PrintStepCallback:
    """Callback that prints a fixed message before and after every step.

    Any positional or keyword arguments passed by the caller are accepted and ignored.
    """

    def __init__(self):
        pass

    def before_calc(self, *args, **kwargs):
        print("About to start.")

    def after_calc(self, *args, **kwargs):
        print("Done Step")

In [55]:
slow_calculation(cb=PrintStepCallback())

About to start.
Done Step
About to start.
Done Step
About to start.
Done Step
About to start.
Done Step
About to start.
Done Step


30

In [70]:
class PrintStatusCallback():
    """Callback that prints the epoch number and running value around each step."""

    def __init__(self):
        # Nothing to set up; the callback is stateless.
        pass

    def before_calc(self, epoch, **kwargs):
        """Announce that ``epoch`` is about to start; extra keywords are ignored."""
        print(f"About to start epoch {epoch}")

    def after_calc(self, epoch, val, **kwargs):
        """Report ``val`` (the running result passed by the caller) after ``epoch``."""
        print(f"After epoch {epoch}: val={val}")


In [71]:
slow_calculation(cb=PrintStatusCallback())

About to start epoch 0
After epoch 0: val=0
About to start epoch 1
After epoch 1: val=1
About to start epoch 2
After epoch 2: val=5
About to start epoch 3
After epoch 3: val=14
About to start epoch 4
After epoch 4: val=30


30

## Modifying Behaviour

In [83]:
def slow_calculation(cb=None):
    """Sum the squares of 0..4, letting an optional callback stop the loop early.

    Args:
        cb: optional object. Each hook is called only if ``cb`` is truthy and
            actually has that attribute:
            - ``before_calc(i)`` runs before each step with the step index;
            - ``after_calc(i, result)`` runs after each step with the index and
              the running total (both positional). A truthy return value
              prints "Stopping early" and ends the loop.

    Returns:
        The running total when the loop finished or was stopped.
    """
    total = 0
    for step in range(5):
        if cb and hasattr(cb, "before_calc"):
            cb.before_calc(step)
        total += step * step
        sleep(1)
        if not (cb and hasattr(cb, "after_calc")):
            continue
        if cb.after_calc(step, total):
            print("Stopping early")
            break
    return total
                

In [84]:
class PrintAfterCallback:
    """Callback that reports each step's value and asks to stop once it exceeds 10."""

    def __init__(self):
        pass

    def after_calc(self, epoch, val):
        """Print ``epoch`` and ``val``; return True when ``val > 10``, otherwise None."""
        print(f"After epoch={epoch}: val={val}")
        return True if val > 10 else None
        

In [85]:
slow_calculation(cb=PrintAfterCallback())

After epoch=0: val=0
After epoch=1: val=1
After epoch=2: val=5
After epoch=3: val=14
Stopping early


14

In [102]:
class SlowCalculator():
    """Sum the squares of 0..4 step by step, letting a callback object observe or stop it.

    Callback methods receive this calculator as their first argument, so they
    can read ``result`` (and any other attribute) directly.
    """

    def __init__(self, cb=None):
        self.cb = cb
        self.result = 0

    def callback(self, cb_name, *args):
        """Call ``self.cb.<cb_name>(self, *args)`` if that method exists.

        Returns the callback's return value, or None when there is no callback
        object or it has no (truthy) attribute named ``cb_name``.
        """
        if not self.cb:
            return None
        method = getattr(self.cb, cb_name, None)
        if method:
            return method(self, *args)
        return None

    def calc(self):
        """Run the five steps, stopping early if ``after_calc`` returns a truthy value.

        ``result`` is reset at the start so repeated calls give the same answer
        instead of accumulating onto the previous run.

        Returns:
            The final ``result`` (also left on ``self.result``).
        """
        self.result = 0
        for i in range(5):
            self.callback("before_calc", i)
            self.result += i*i
            sleep(1)
            if self.callback("after_calc", i):
                print("Stopping Early")
                break
        return self.result
        

In [103]:
class ModifyingCallback:
    """SlowCalculator callback that logs the running result and requests an early stop."""

    def after_calc(self, calc, epoch):
        """Print ``calc.result`` for ``epoch``; return True once it exceeds 10, else None."""
        print(f"After epoch {epoch}: val={calc.result}")
        should_stop = calc.result > 10
        return True if should_stop else None
        

In [104]:
calculator = SlowCalculator(cb=ModifyingCallback())

In [105]:
calculator.calc()

After epoch 0: val=0
After epoch 1: val=1
After epoch 2: val=5
After epoch 3: val=14
Stopping Early


In [106]:
calculator.result

14