In [1]:
import torch
import matplotlib.pyplot as plt
import random

# Callbacks
Our goal is to build a `Learner`: a flexible, general-purpose training loop.
Its flexibility will come from callbacks.

## Callbacks as GUI events
Let's see how callbacks are used to handle GUI events.

In [2]:
# %pip install ipywidgets   # uncomment if ipywidgets is missing (%pip installs into the kernel's environment)
import ipywidgets as widgets

from IPython.display import display
# An Output widget collects whatever is printed/displayed inside a `with output:` block.
output = widgets.Output()

In [3]:
# Create a button; leaving it as the cell's last expression makes Jupyter render it.
w = widgets.Button(description="click me")
w

Button(description='click me', style=ButtonStyle())

In [5]:
def f(o):
    """Click handler: ipywidgets calls it with the clicked Button as `o`."""
    # Output printed from widget callbacks is not attached to any cell (JupyterLab sends it
    # to the log console), so print inside the Output widget to make it visible.
    with output:
        print("hi")

w.on_click(f)  # register f as a callback; note that re-running this cell registers another one
display(output)  # show the Output widget where the "hi" lines will appear

## Let's create our own callback

In [6]:
from time import sleep

def slow_calculation():
    """Return the sum of squares of 0..4, sleeping one second per step to mimic slow work."""
    total = 0
    for step in range(5):
        total += step * step
        sleep(1)
    return total

In [7]:
# Takes about 5 seconds: each of the 5 steps sleeps for one second.
slow_calculation()

30

In [10]:
def slow_calculation(cb=None):
    """Sum the squares of 0..4 slowly, optionally reporting progress.

    Args:
        cb: optional callable (function, lambda, partial, callable object, ...)
            invoked as ``cb(step)`` at the end of every iteration.

    Returns:
        The sum of squares, 30.
    """
    total = 0
    for step in range(5):
        total += step * step
        sleep(1)
        if cb:  # notify the caller once this iteration's work is done
            cb(step)
    return total

In [11]:
def show_progres(epoch):
    """Print a short progress greeting for `epoch` (name spelling kept so later calls still work)."""
    message = f'Hi {epoch}'
    print(message)

In [12]:
# Pass the function object itself (no parentheses); slow_calculation calls it as cb(i).
slow_calculation(cb=show_progres)

Hi 0
Hi 1
Hi 2
Hi 3
Hi 4


30

In [13]:
# An anonymous one-argument function works just as well as a named one.
slow_calculation(cb=lambda step: print(f"hi {step}"))

hi 0
hi 1
hi 2
hi 3
hi 4


30

In [16]:
def show_progress(inp1, inp2):
    """Print a greeting made of a fixed prefix `inp1` and the step index `inp2`."""
    print("hi {}, {}".format(inp1, inp2))

# The lambda adapts the two-argument function to the one-argument callback slow_calculation expects.
slow_calculation(lambda step: show_progress("firstInp", step))

hi firstInp, 0
hi firstInp, 1
hi firstInp, 2
hi firstInp, 3
hi firstInp, 4


30

In [17]:
def make_show_progress(inp1):
    """Return a one-argument callback that prints `inp1` together with the step index."""
    def _show(inp2):
        # `inp1` is captured from the enclosing call (a closure).
        print(f"hi {inp1}, {inp2}")
    return _show

slow_calculation(make_show_progress("nice!"))

hi nice!, 0
hi nice!, 1
hi nice!, 2
hi nice!, 3
hi nice!, 4


30

## We can do the same with `functools.partial`

In [18]:
from functools import partial

# partial(show_progress, "OK I guess") builds a new callable whose first argument is
# already filled in; slow_calculation then supplies the second one (the step index).
slow_calculation(cb=partial(show_progress, "OK I guess"))

hi OK I guess, 0
hi OK I guess, 1
hi OK I guess, 2
hi OK I guess, 3
hi OK I guess, 4


30

In [19]:
# Keep the partial in a variable so the same pre-filled callback can be reused.
f2 = partial(show_progress, "OK I guess")
slow_calculation(f2)

hi OK I guess, 0
hi OK I guess, 1
hi OK I guess, 2
hi OK I guess, 3
hi OK I guess, 4


30

### A callback only needs to be callable, and instances of a class that defines `__call__` are callable too
# Callbacks as Callable Classes

In [21]:
class ProgressShowingCallback():
    """Callable object that prints a notice each time an epoch finishes."""

    def __init__(self, exlamation="Hi"):
        # Prefix of every printed notice (parameter spelling kept for existing callers).
        self.exlamation = exlamation

    def __call__(self, epoch):
        # Defining __call__ lets an instance be used wherever a function is expected.
        notice = f"{self.exlamation}, {epoch} is over!"
        print(notice)

In [22]:
# An instance is passed where a function was passed before; slow_calculation just calls cb(i).
cb = ProgressShowingCallback()
slow_calculation(cb)

Hi, 0 is over!
Hi, 1 is over!
Hi, 2 is over!
Hi, 3 is over!
Hi, 4 is over!


30

## *args and **kwargs

In [25]:
def f(*args, **kwargs):
    """Print the positional arguments (as a tuple) and keyword arguments (as a dict)."""
    print("args: {}, kwargs: {}".format(args, kwargs))

def f1(*a, **b):
    """Same as f: only the * and ** matter, the names after them are arbitrary."""
    print(f"args: {a}, kwargs: {b}")
    

In [24]:
f(3, "a", thing1="hello")
# Positional arguments are collected into the tuple `args`, keyword arguments into the dict `kwargs`.

args: (3, 'a'), kwargs: {'thing1': 'hello'}


In [26]:
# f1 behaves exactly like f even though its parameters are named a and b.
f1(3, "a", thing1="hello")

args: (3, 'a'), kwargs: {'thing1': 'hello'}


In [28]:
c = [1,2]
d = {"name": "marco" , "cognome": "nobile"}
# * unpacks the list into positional arguments, ** unpacks the dict into keyword arguments.
f(*c, **d)

args: (1, 2), kwargs: {'name': 'marco', 'cognome': 'nobile'}


## Multiple callbacks
Now, instead of passing a plain callable, let's pass an object that implements two methods:
- pre_calc()
- post_calc()


In [40]:
def slow_calculation(cb=None):
    """Sum the squares of 0..4 slowly, notifying `cb` around every step.

    Args:
        cb: optional object with ``pre_calc(step)`` (called before each step) and
            ``post_calc(step, val=running_total)`` (called after each step).

    Returns:
        The sum of squares, 30.
    """
    total = 0
    for step in range(5):
        if cb:
            cb.pre_calc(step)
        total += step * step
        sleep(1)
        if cb:
            cb.post_calc(step, val=total)
    return total

In [41]:
class PrintStepCallback():
    """Callback object with the two hooks slow_calculation calls: pre_calc and post_calc.

    Both hooks accept and discard any positional/keyword arguments, so the
    object keeps working whatever the caller passes to it.
    """

    def pre_calc(self, *_args, **_kwargs):
        print("about to start")

    def post_calc(self, *_args, **_kwargs):
        print("about to end")
    

In [42]:
# A fresh PrintStepCallback instance supplies the pre_calc/post_calc hooks.
slow_calculation(cb=PrintStepCallback())

about to start
about to end
about to start
about to end
about to start
about to end
about to start
about to end
about to start
about to end


30

In [43]:
class PrintStepCallback():
    """Callback that names the arguments it needs and ignores any extra keywords."""

    def pre_calc(self, epoch, **kwargs):
        print("about to start epoch {}".format(epoch))

    def post_calc(self, epoch, val, **kwargs):
        # `val` is the running total slow_calculation passes as a keyword argument.
        print("epoch {} about to end with val: {}".format(epoch, val))


In [44]:
# Same call as before, now with the PrintStepCallback that prints epoch and value.
slow_calculation(cb=PrintStepCallback())

about to start epoch 0
epoch 0 about to end with val: 0
about to start epoch 1
epoch 1 about to end with val: 1
about to start epoch 2
epoch 2 about to end with val: 5
about to start epoch 3
epoch 3 about to end with val: 14
about to start epoch 4
epoch 4 about to end with val: 30


30

## Modifying behavior

In [72]:
def slow_calculation(cb=None):
    """Sum the squares of 0..4 slowly, with optional duck-typed hooks on `cb`.

    `cb` may define either, both, or neither of:
      * ``pre_calc(epoch)``: called before each step;
      * ``after_calc(epoch, val)``: called after each step with the running
        total; a truthy return value stops the loop early.

    Returns:
        The running total at the point the loop finished or was stopped.
    """
    res = 0
    for i in range(5):
        if cb and hasattr(cb, 'pre_calc'):  # only call hooks the object actually defines
            cb.pre_calc(i)
        res += i*i
        sleep(1)
        if not (cb and hasattr(cb, 'after_calc')):
            continue
        if cb.after_calc(i, res):  # the callback's return value controls the loop
            print("stopping early")
            break
    return res

In [73]:
class PrintAfterCallback():
    """after_calc-only callback: prints each step's value and asks to stop once it exceeds 10."""

    def after_calc(self, epoch, val):
        print(f"After {epoch}: {val}")
        # True makes slow_calculation stop; None (falsy) lets it continue.
        if val > 10:
            return True
        return None

In [74]:
# Stops after epoch 3: after_calc returns True as soon as the running value exceeds 10.
slow_calculation(PrintAfterCallback())


After 0: 0
After 1: 1
After 2: 5
After 3: 14
stopping early


14

In [75]:
class SlowCalculator():
    """Object-based version of slow_calculation whose callback can read and modify it.

    The callback object may define ``before_calc(calc, epoch)`` and/or
    ``after_calc(calc, epoch)``; each receives this calculator as ``calc``,
    so it can inspect or change ``calc.res``. A truthy return from
    ``after_calc`` stops the loop early.

    Note: ``res`` is only set to 0 in ``__init__``, so calling ``calc()`` again
    on the same instance keeps adding to the previous total.
    """

    def __init__(self, cb=None):
        self.cb, self.res = cb, 0

    def callback(self, cb_name, *args):
        """Call ``self.cb.<cb_name>(self, *args)`` if that hook exists; otherwise return None."""
        # Compare with None instead of relying on truthiness: a callback object that
        # defines __len__/__bool__ and evaluates falsy is still a valid callback.
        if self.cb is None:
            return None
        hook = getattr(self.cb, cb_name, None)  # hooks the callback doesn't define are skipped
        if not hook:
            return None
        return hook(self, *args)

    def calc(self):
        """Run 5 slow steps adding i*i to ``self.res``; return ``self.res`` like slow_calculation."""
        for i in range(5):
            self.callback('before_calc', i)
            self.res += i*i
            sleep(1)
            if self.callback('after_calc', i):
                print("stopping early")
                break
        return self.res

In [76]:
class ModifyingCallback():
    """after_calc callback that inspects, and sometimes rewrites, the calculator's result."""

    def after_calc(self, calc, epoch):
        print(f"After {epoch}: {calc.res}")
        if calc.res > 10:
            return True  # ask the calculator to stop
        if calc.res < 3:
            # Callbacks receive the calculator itself, so they can change its state.
            calc.res = calc.res * 2

In [77]:
calculator = SlowCalculator(ModifyingCallback())
calculator.calc()
# 15 rather than 14: at epoch 1 the callback doubled res from 1 to 2, so every later total is one higher.
calculator.res

After 0: 0
After 1: 1
After 2: 6
After 3: 15
stopping early


15

# \_\_dunder\_\_ methods
All of Python's special ("dunder") methods are documented in the data model reference: https://docs.python.org/3/reference/datamodel.html  
For example, when Python evaluates `a + b` it actually calls `a.__add__(b)`.

In [50]:
class ExampleAdder():
    """Tiny wrapper around a value `a` that supports `+` through the __add__ dunder."""

    def __init__(self, a):
        self.a = a

    def __add__(self, b):
        # Python evaluates `x + y` as x.__add__(y). Returning NotImplemented for an
        # unsupported operand lets Python try y.__radd__(x) and otherwise raise the
        # usual TypeError, instead of failing with an AttributeError on `b.a`.
        if not isinstance(b, ExampleAdder):
            return NotImplemented
        return ExampleAdder(self.a + b.a)

    def __repr__(self):
        # Used when an instance is the last expression of a cell.
        return str(self.a)

In [51]:
a = ExampleAdder(1)
b = ExampleAdder(2)
a+b  # Python turns this into a.__add__(b)

3

Important dunder methods:

- \_\_getitem\_\_
- \_\_getattr\_\_
- \_\_setattr\_\_
- \_\_del\_\_
- \_\_init\_\_
- \_\_new\_\_
- \_\_enter\_\_
- \_\_exit\_\_
- \_\_len\_\_
- \_\_repr\_\_
- … and many more (see the data model reference linked above)

Let's take a closer look at \_\_getattr\_\_ and \_\_setattr\_\_.

In [52]:
class A:
    """Plain class with two class-level attributes, used to demo attribute lookup."""
    a, b = 1, 2

In [55]:
a = A()  # note: rebinds `a`, which held an ExampleAdder earlier
a.a, a.b  # normal attribute lookup; getattr(a, 'a') in the next cell is the function-call form

(1, 2)

In [56]:
getattr(a, 'a'), getattr(a, 'b')  # same lookups, with the attribute name given as a string

(1, 2)

In [60]:
class B:
    """Class whose instances answer lookups of any missing public attribute with a greeting."""
    a = 1
    b = 2

    def __getattr__(self, k):
        # Python calls __getattr__ only after normal lookup fails, so `a` and `b`
        # above are returned as usual and only missing names end up here.
        if k.startswith('_'):
            # Keep raising AttributeError for private/dunder names: code that probes for
            # optional hooks with hasattr/getattr (copy, pickle, IPython display) relies
            # on it. startswith, unlike k[0], also copes with an empty attribute name.
            raise AttributeError(k)
        return f'Hello from {k}'

In [61]:
b = B()  # note: rebinds `b`, which held an ExampleAdder earlier

In [62]:
b.a  # found on the class, so B.__getattr__ is not called

1

In [63]:
b.foo  # not found normally, so Python falls back to B.__getattr__('foo')

'Hello from foo'

In [65]:
a.foo = 10  # instances accept new attributes at runtime...
getattr(a, 'foo') # ...and getattr reads them back, exactly like a.foo

10

In [7]:
class B:
    """Class that logs public attribute assignments and returns None for missing public names."""
    a = 1
    b = 2

    def __setattr__(self, k, v):
        # Only public (non-underscore) assignments are logged...
        if not k.startswith("_"):
            print(f"setting attr {k}, with value {v}")
        # ...but every assignment is stored; otherwise `self._x = ...` would be silently dropped.
        super().__setattr__(k, v)

    def __getattr__(self, k):
        # Reached only when normal lookup fails.
        if k.startswith("_"):
            raise AttributeError(k)
        # Missing public names yield None.
        return None

In [8]:
b = B()  # instance of the redefined B (with __setattr__)

In [9]:
b.hello = "world"  # routed through B.__setattr__, which logs it

setting attr hello, with value world


In [10]:
getattr(b, 'ggoo')  # B.__getattr__ returns None for this name, so the cell displays nothing

In [12]:
print(b.ggoo)  # print makes the None returned by __getattr__ visible

None
