In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
#export
from exp.nb_03 import *
import pdb

In [3]:
x_train, y_train, x_valid, y_valid = get_data()
train_ds, valid_ds = Dataset(x_train, y_train), Dataset(x_valid, y_valid)

In [4]:
nh = 50
bs = 64

In [5]:
c = y_train.max().item() + 1

In [6]:
loss_func = F.cross_entropy

# DataBunch and Learner

Let's replace

```
fit(epochs, model, loss_func, opt, train_dl, valid_dl)
```

with:

```
fit(epochs, learn)
```

The `Learner` object is mutable, so any changes made to its attributes elsewhere will be seen by our training loop. This allows us to tweak what happens inside the training loop from the outside.

First, we combine the two dataloaders into one object:

In [7]:
#export
class DataBunch():
    """Bundle of the training and validation DataLoaders plus `c`.

    `c` is used by `get_model` as the size of the model's output layer
    (in this notebook it is the number of classes, max label + 1).
    """
    def __init__(self, train_dl, valid_dl, c=None):
        self.train_dl = train_dl
        self.valid_dl = valid_dl
        self.c = c

    @property
    def train_ds(self):
        """Dataset wrapped by the training DataLoader."""
        return self.train_dl.dataset

    @property
    def valid_ds(self):
        """Dataset wrapped by the validation DataLoader."""
        return self.valid_dl.dataset

In [8]:
data = DataBunch(*get_dls(train_ds, valid_ds, bs), c)

In [9]:
#export
def get_model(data, lr=0.5, nh=50):
    """Build a two-layer MLP sized from `data` plus an SGD optimizer for it.

    The input width is taken from `data.train_ds.x.shape[1]`, the hidden width
    is `nh`, and the output width is `data.c`. Returns `(model, optimizer)`.
    """
    n_in = data.train_ds.x.shape[1]
    layers = [nn.Linear(n_in, nh), nn.ReLU(), nn.Linear(nh, data.c)]
    net = nn.Sequential(*layers)
    sgd = optim.SGD(net.parameters(), lr=lr)
    return net, sgd

class Learner():
    """Mutable container tying together the model, optimizer, loss function and data.

    Because it is a plain mutable object, changes made to its attributes
    elsewhere are visible to any training loop holding a reference to it.
    """
    def __init__(self, model, opt, loss_func, data):
        self.model = model
        self.opt = opt
        self.loss_func = loss_func
        self.data = data

In [10]:
learner = Learner(*get_model(data), loss_func, data)

In [11]:
def fit(epochs, learn):
    """Train `learn.model` for `epochs` epochs, evaluating after each one.

    Each epoch runs one pass over `learn.data.train_dl` (forward, backward,
    optimizer step), then evaluates on `learn.data.valid_dl` without gradients.
    It prints the epoch index, the mean validation loss and the mean validation
    accuracy. Both means are weighted by batch size, so a smaller final batch
    does not skew them.

    Returns (loss, accuracy) from the last epoch's validation pass.
    Raises ValueError if `epochs` < 1, because there would be nothing to return.
    """
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")
    for epoch in range(epochs):
        learn.model.train()  # To handle batchnorm / dropout

        for xb, yb in learn.data.train_dl:
            loss = learn.loss_func(learn.model(xb), yb)
            loss.backward()
            learn.opt.step()
            learn.opt.zero_grad()

        learn.model.eval()
        with torch.no_grad():
            tot_loss, tot_acc, n_samples = 0., 0., 0
            for xb, yb in learn.data.valid_dl:
                pred = learn.model(xb)
                n = len(xb)  # weight each batch mean by its size
                tot_loss += learn.loss_func(pred, yb) * n
                tot_acc += accuracy(pred, yb) * n
                n_samples += n
        val_loss, val_acc = tot_loss / n_samples, tot_acc / n_samples
        print(epoch, val_loss, val_acc)
    return val_loss, val_acc

In [12]:
loss, acc = fit(2, learner)

0 tensor(0.7854) tensor(0.8161)
1 tensor(0.1205) tensor(0.9637)


# CallbackHandler

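Let's refactor this into a function that processes one batch, plus functions that loop over all batches and run the whole fit. At every stage they consult a `CallbackHandler` (`cb`), which decides whether to continue: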

In [13]:
def one_batch(xb, yb, cb):
    """Run one optimisation step on (xb, yb), consulting the CallbackHandler `cb` at each stage.

    A falsy `begin_batch` or `after_loss` aborts the step. `CallbackHandler.after_loss`
    is falsy outside training, so no backward pass happens during validation.
    A falsy `after_backward` skips `opt.step()`, and a falsy `after_step` skips
    `opt.zero_grad()`. `after_step` is called either way.
    """
    if cb.begin_batch(xb, yb):
        learn = cb.learn
        preds = learn.model(xb)
        loss = learn.loss_func(preds, yb)
        if cb.after_loss(loss):
            loss.backward()
            if cb.after_backward():
                learn.opt.step()
            if cb.after_step():
                learn.opt.zero_grad()
        
def all_batches(dl, cb):
    """Feed every (inputs, targets) pair of `dl` to `one_batch`.

    After each batch the handler's `do_stop()` is polled; a truthy answer ends
    the pass over `dl` early.
    """
    for inputs, targets in dl:
        one_batch(inputs, targets, cb)
        if cb.do_stop():
            break
        
def fit(epochs, learn, cb):
    """Run the training loop for `epochs` epochs, driven by the CallbackHandler `cb`.

    A falsy `begin_fit`, `begin_epoch` or `after_epoch`, or a truthy `cb.do_stop()`,
    ends training early. `cb.after_fit()` runs in a `finally` block, so callbacks
    always get to clean up: on normal completion, on any early exit, and when an
    exception propagates. This matches `Runner.fit`.
    """
    try:
        if not cb.begin_fit(learn): return
        for epoch in range(epochs):
            # `break` (not `return`) so the finally block still runs after_fit.
            if not cb.begin_epoch(epoch): break
            all_batches(learn.data.train_dl, cb)

            if cb.begin_validate():
                with torch.no_grad(): all_batches(learn.data.valid_dl, cb)
            if cb.do_stop() or not cb.after_epoch(): break
    finally:
        cb.after_fit()

In [14]:
class Callback():
    """Base class for fastai-v1-style callbacks used with CallbackHandler.

    Every hook stores what it was given on the instance and returns True,
    which the handler reads as "keep going".
    """
    def _record(self, **state):
        # Store each keyword as an attribute (in the given order), then signal "continue".
        for attr, value in state.items():
            setattr(self, attr, value)
        return True

    def begin_fit(self, learn): return self._record(learn=learn)
    def after_fit(self): return True
    def begin_epoch(self, epoch): return self._record(epoch=epoch)
    def begin_validate(self): return True
    def after_epoch(self): return True
    def begin_batch(self, xb, yb): return self._record(xb=xb, yb=yb)
    def after_loss(self, loss): return self._record(loss=loss)
    def after_backward(self): return True
    def after_step(self): return True

In [15]:
class CallbackHandler():
    """Dispatches training-loop events to a list of fastai-v1-style callbacks.

    Every event is forwarded to every callback, with no short-circuiting, so no
    callback misses an event because an earlier one returned False. An event
    method returns True only if all callbacks returned a truthy value. The
    training loop treats a falsy result as "skip / stop".
    """
    def __init__(self, cbs=None):
        self.cbs = cbs if cbs else []

    def _broadcast(self, event, *args):
        """Call `event(*args)` on every callback; return True iff all results are truthy."""
        results = [getattr(cb, event)(*args) for cb in self.cbs]
        return all(results)

    def begin_fit(self, learn):
        self.learn, self.in_train = learn, True
        learn.stop = False
        return self._broadcast('begin_fit', learn)

    def after_fit(self):
        # Notify callbacks first so they always run, even if fit ended while in_train.
        return self._broadcast('after_fit') and not self.in_train

    def begin_epoch(self, epoch):
        self.learn.model.train()
        self.in_train = True
        return self._broadcast('begin_epoch', epoch)

    def begin_validate(self):
        self.learn.model.eval()
        self.in_train = False
        return self._broadcast('begin_validate')

    def after_epoch(self):
        return self._broadcast('after_epoch')

    def begin_batch(self, xb, yb):
        return self._broadcast('begin_batch', xb, yb)

    def after_loss(self, loss):
        # Callbacks see every loss, including validation losses. The result is
        # forced to False outside training so the caller skips the backward pass.
        return self._broadcast('after_loss', loss) and self.in_train

    def after_backward(self):
        return self._broadcast('after_backward')

    def after_step(self):
        return self._broadcast('after_step')

    def do_stop(self):
        """Return the learner's stop flag and reset it, so each request is consumed once."""
        try:
            return self.learn.stop
        finally:
            self.learn.stop = False

In [16]:
class TestCallback(Callback):
    """Demo callback: counts optimizer steps, prints each count, and asks to stop after 10."""
    def begin_fit(self, learn):
        super().begin_fit(learn)
        self.n_iters = 0  # reset the step counter for every fit
        return True

    def begin_epoch(self, epoch):
        super().begin_epoch(epoch)
        return True

    def after_step(self):
        count = self.n_iters + 1
        self.n_iters = count
        print(count)
        if count >= 10:
            # CallbackHandler.do_stop() reads (and resets) this flag after each batch.
            self.learn.stop = True
        return True

In [17]:
fit(1, learner, cb=CallbackHandler([TestCallback()]))

1
2
3
4
5
6
7
8
9
10


This is basically how fastai v1 works.

Not very nice though for the following reasons:

1. All the callbacks have to return `True` for learning to continue. This leads to a lot of boilerplate (`return True` in every method). In the next version, training should always continue unless a callback returns `True`.
2. A `CallbackHandler` object (`cb`) is passed to the functions `one_batch`, `all_batches`, and `fit`. This hints that those three functions should be methods of a class that has a `cb` attribute.

Let's implement a class that does this for us; we call it `Runner`.

## Runner

In [18]:
#export
import re

_camel_re1 = re.compile('(.)([A-Z][a-z]+)')
_camel_re2 = re.compile('([a-z0-9])([A-Z])')

def camel2snake(name):
    """Convert a CamelCase identifier to snake_case.

    Example: 'BleuScoreCallback' -> 'bleu_score_callback'.
    """
    # Pass 1: insert '_' before every capitalised word that follows some character.
    separated = _camel_re1.sub(r'\1_\2', name)
    # Pass 2: insert '_' between a lowercase letter/digit and a following capital.
    separated = _camel_re2.sub(r'\1_\2', separated)
    return separated.lower()

In [19]:
camel2snake("BleuScoreCallback")

'bleu_score_callback'

In [20]:
#export
class Callback():
    """Base class for Runner callbacks.

    Attribute lookups that fail on the callback are delegated to the runner it
    was attached to with `set_runner`, so a callback can write `self.model`
    instead of `self.run.model`.
    """
    _order = 0  # the runner invokes callbacks in ascending `_order`

    def set_runner(self, run):
        self.run = run

    def __getattr__(self, k):
        # Only reached when normal lookup fails. `run` itself must not be
        # delegated: before set_runner() it does not exist, and delegating it
        # would re-enter __getattr__ forever. The result would be a RecursionError
        # instead of AttributeError, which breaks hasattr() and copy.copy().
        if k == 'run':
            raise AttributeError(f"{type(self).__name__} has no runner; call set_runner() first")
        return getattr(self.run, k)

    @property
    def name(self):
        """snake_case name derived from the class name with any 'Callback' suffix removed."""
        name = re.sub(r'Callback$', '', self.__class__.__name__)
        return camel2snake(name or 'callback')

`__getattr__` is only called when normal attribute lookup fails. Often, what a callback is looking for lives on the runner, so `Callback` delegates failed lookups to the runner: inside a callback you can write `self.stuff` even though `stuff` is an attribute of the runner. Here is a small example of how this works:

In [21]:
class A():
    """Toy object carrying a single attribute `c`."""
    def __init__(self):
        self.c = 1

class B():
    """Toy wrapper: owns `a` and `b`; any other attribute is looked up on the wrapped `A`."""
    def __init__(self, A):
        self.a, self.b = 2, 3
        self.A = A

    def __getattr__(self, k):
        # Python calls this only after ordinary lookup on the B instance has failed.
        wrapped = self.A
        return getattr(wrapped, k)

In [22]:
myA = A()

In [23]:
myB = B(myA)

In [24]:
myB.c

1

**Let's build a callback that is responsible for switching the model back and forth between training and evaluation mode, and for counting the iterations and the (fractional) number of epochs elapsed.**

In [25]:
#export
class TrainEvalCallback(Callback):
    """Switches the model between train/eval mode and maintains the runner's progress counters.

    Reads such as `self.in_train`, `self.iters`, `self.epoch` and `self.model`
    are resolved on the runner through Callback's delegation. Writes go through
    `self.run` explicitly so they land on the runner, not on the callback.
    """
    def begin_fit(self):
        # Fresh counters on the runner at the start of every fit.
        self.run.n_epochs, self.run.n_iter = 0, 0

    def after_batch(self):
        # Only training batches advance the counters.
        if self.in_train:
            # `iters` is the length of the DataLoader being iterated, so n_epochs
            # grows by a fraction of an epoch per batch.
            self.run.n_epochs += 1./self.iters
            self.run.n_iter += 1

    def begin_epoch(self):
        self.run.n_epochs = self.epoch  # snap to the integer epoch index
        self.model.train()              # training-mode behaviour (batchnorm / dropout)
        self.run.in_train = True

    def begin_validate(self):
        self.model.eval()
        self.run.in_train = False

Let's also implement our `TestCallback`:

In [26]:
class TestCallback(Callback):
    """Demo callback that ends training after 10 optimizer steps."""
    def after_step(self):
        # `n_iter` is kept on the runner by TrainEvalCallback and is reached
        # through Callback's attribute delegation. TrainEvalCallback increments
        # it in after_batch, i.e. after this hook, so the step that just
        # finished is number n_iter + 1.
        if self.n_iter + 1 >= 10:
            # Returning True here would only skip opt.zero_grad() for this batch.
            # Setting the runner's stop flag is what ends the loop over the batches.
            self.run.stop = True

    def after_epoch(self):
        # A True result makes Runner.fit break out of the epoch loop.
        return self.n_iter >= 10

In [27]:
cbname = 'TrainEvalCallback'
camel2snake(cbname)

'train_eval_callback'

In [28]:
TrainEvalCallback().name

'train_eval'

In [29]:
#export
from typing import *

def listify(o):
    """Return `o` as a list.

    - None                -> []
    - a list              -> the same list object (not a copy)
    - a str or bytes      -> [o]  (iterating would split it into characters)
    - any other iterable  -> list(o)
    - anything else       -> [o]
    """
    if o is None: return []
    if isinstance(o, list): return o
    if isinstance(o, (str, bytes)): return [o]
    if isinstance(o, Iterable): return list(o)
    return [o]

**Let's construct the runner:**

In [30]:
#export 
class Runner():
    """Runs the training loop and dispatches events to callbacks.

    Events are fired with `self('event_name')`. If any callback's hook for an
    event returns True, the runner skips the rest of the current stage:
    - the rest of the batch (`one_batch` hooks),
    - the training or validation pass of an epoch (`begin_epoch` / `begin_validate`),
    - the whole fit (`begin_fit`, `after_epoch`).
    Setting `self.stop = True` ends the current pass over a DataLoader.
    """
    def __init__(self, cbs=None, cb_funcs=None):
        # Copy, so that adding the cb_funcs-built callbacks never mutates a list
        # the caller passed in as `cbs`.
        cbs = list(listify(cbs))
        cbs += [cbf() for cbf in listify(cb_funcs)]  # each cb_func builds a callback (e.g. a partial)
        all_cbs = [TrainEvalCallback()] + cbs
        # Expose every callback under its snake_case name (run.train_eval,
        # run.avg_stats, ...), not only those built from cb_funcs, so callbacks
        # passed directly are reachable the same way.
        for cb in all_cbs: setattr(self, cb.name, cb)
        self.stop, self.cbs = False, all_cbs

    @property
    def opt(self):       return self.learn.opt
    @property
    def model(self):     return self.learn.model
    @property
    def loss_func(self): return self.learn.loss_func
    @property
    def data(self):      return self.learn.data

    def one_batch(self, xb, yb):
        """Forward, loss, backward, step and zero_grad for one batch, firing an event after each stage."""
        self.xb, self.yb = xb, yb  # current batch, readable by callbacks
        if self('begin_batch'): return
        self.pred = self.model(self.xb)
        if self('after_pred'): return
        self.loss = self.loss_func(self.pred, self.yb)
        # Outside training (validation) stop here: no backward pass / optimizer step.
        if self('after_loss') or not self.in_train: return
        self.loss.backward()
        if self('after_backward'): return
        self.opt.step()
        if self('after_step'): return
        self.opt.zero_grad()

    def all_batches(self, dl):
        """Run `one_batch` over `dl` until it is exhausted or `self.stop` is set; then clear `stop`."""
        self.iters = len(dl)
        for xb, yb in dl:
            if self.stop: break
            self.one_batch(xb, yb)
            self('after_batch')
        self.stop = False

    def fit(self, epochs, learn):
        """Train `learn` for `epochs` epochs; `after_fit` always runs and `self.learn` is always cleared."""
        self.epochs, self.learn = epochs, learn

        try:
            for cb in self.cbs: cb.set_runner(self)
            if self('begin_fit'): return
            for epoch in range(epochs):
                self.epoch = epoch
                # Remember: a callback returning True means "skip this stage".
                if not self('begin_epoch'): self.all_batches(self.data.train_dl)

                with torch.no_grad():
                    if not self('begin_validate'): self.all_batches(self.data.valid_dl)
                if self('after_epoch'): break
        finally:
            try:
                self('after_fit')
            finally:
                self.learn = None  # cleared even if an after_fit hook raises

    def __call__(self, cb_name):
        """Invoke hook `cb_name` on every callback that defines it, in ascending `_order`.

        Every callback is notified even if an earlier one returned True, so none
        misses an event. The result is True if any hook returned a truthy value.
        """
        res = False
        for cb in sorted(self.cbs, key=lambda x: x._order):
            f = getattr(cb, cb_name, None)  # None when the callback has no such hook
            if f and f(): res = True
        return res

**Check notebook `05a_foundations.ipynb` to see how we can check whether a callback method exists using `hasattr` instead of `getattr`!**

**Let's implement a Callback that computes metrics:**

In [31]:
#export
class AvgStats():
    """Running, batch-size-weighted averages of the loss and a list of metrics.

    `in_train` only labels the stats in `__repr__` ('train' vs 'valid').
    """
    def __init__(self, metrics, in_train):
        self.metrics = listify(metrics)
        self.in_train = in_train
        self.reset()  # start with zeroed totals so the object is usable immediately

    def reset(self):
        self.tot_loss, self.count = 0., 0
        self.tot_mets = [0.] * len(self.metrics)

    @property
    def all_stats(self):
        """Summed loss (as a Python float) followed by the summed metric values."""
        # float() accepts both the initial Python float and the 0-d tensor that
        # accumulate() produces. `.item()` fails before the first batch.
        return [float(self.tot_loss)] + self.tot_mets

    @property
    def avg_stats(self):
        """Per-sample averages of `all_stats`."""
        return [o/self.count for o in self.all_stats]

    def __repr__(self):
        if not self.count: return ""
        return f"{'train' if self.in_train else 'valid'}: {self.avg_stats}"

    def accumulate(self, run):
        """Add the current batch's loss and metrics from `run`, each weighted by batch size."""
        bn = run.xb.shape[0]  # batch size
        self.tot_loss += run.loss * bn
        self.count += bn
        for i, m in enumerate(self.metrics):
            self.tot_mets[i] += m(run.pred, run.yb) * bn

In [32]:
#export
class AvgStatsCallback(Callback):
    """Keeps separate running loss/metric averages for training and validation and prints both after each epoch."""
    def __init__(self, metrics):
        self.train_stats = AvgStats(metrics, True)
        self.valid_stats = AvgStats(metrics, False)

    def begin_epoch(self):
        for stats in (self.train_stats, self.valid_stats):
            stats.reset()

    def after_loss(self):
        # Route the batch to the accumulator of the current phase. no_grad keeps
        # the running totals out of the autograd graph.
        current = self.train_stats if self.in_train else self.valid_stats
        with torch.no_grad():
            current.accumulate(self.run)

    def after_epoch(self):
        print(self.train_stats)
        print(self.valid_stats)

In [33]:
learn = Learner(*get_model(data), loss_func, data)

In [34]:
stats = AvgStatsCallback([accuracy])

In [35]:
accuracy??

In [36]:
run = Runner(cbs=stats)

In [37]:
run.fit(2, learn)

train: [0.316653671875, tensor(0.9030)]
valid: [0.1528393310546875, tensor(0.9568)]
train: [0.139232421875, tensor(0.9577)]
valid: [0.14021868896484374, tensor(0.9589)]


In [38]:
loss, acc = stats.valid_stats.avg_stats
assert acc > 0.9
loss, acc

(0.14021868896484374, tensor(0.9589))

It is not very nice that we first have to create an `AvgStatsCallback` object, pass it to the runner, and then later get the metrics from this object. Instead, let's use `partial` to create a function that builds this object for us: the constructor of `Runner` calls every function in `cb_funcs` and stores the resulting callback on the runner via `setattr`. We can then write `run.avg_stats.valid_stats` instead of `stats.valid_stats`.

In [39]:
#export
from functools import partial

In [40]:
acc_cbf = partial(AvgStatsCallback, accuracy)

Example:

In [41]:
a = acc_cbf()

In [43]:
isinstance(a, AvgStatsCallback)

True

In [44]:
run = Runner(cb_funcs=acc_cbf)

In [45]:
run.fit(2, learn)

train: [0.106533125, tensor(0.9676)]
valid: [0.19543341064453126, tensor(0.9481)]
train: [0.090037509765625, tensor(0.9727)]
valid: [0.740310986328125, tensor(0.8502)]


In [46]:
run.avg_stats.valid_stats.avg_stats

[0.740310986328125, tensor(0.8502)]

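The first `avg_stats` is the attribute name that `Runner` derived from the class name `AvgStatsCallback` (via `Callback.name`: strip the `Callback` suffix, then convert to snake case); the last `avg_stats` is the property of the `AvgStats` object.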

## Export

In [47]:
!python notebook2script.py 04_callbacks.ipynb

Converted 04_callbacks.ipynb to exp/nb_04.py
