# Implementing Callbacks from Foundations

#### Last Time

In our [previous notebook](http://nbviewer.jupyter.org/github/jamesdellinger/fastai_deep_learning_course_part2_v3/blob/master/03_minibatch_training_my_reimplementation.ipynb?flush_cache=true) we implemented data loader classes from scratch so that our training cycle could take advantage of mini-batch training.

We observed that training on mini-batches lets us leverage the parallel processing capabilities of Nvidia GPUs: we can run a forward pass on several training inputs (64 in our case) simultaneously. This lets us train our models much faster than if we were forced to process only one input at a time.

#### Callbacks

In this notebook we demonstrate how to implement a callback system from scratch, and use it to hook into our model at various points during the training cycle. 

Fundamentally, callbacks allow us to observe and, if we choose, influence how our model is training, all while the training cycle is still ongoing. Useful things we might use callbacks to accomplish include:
* Switching our model between train and eval mode, depending on whether we are training or performing validation/inference.
* Dynamically updating the values of hyperparameters such as learning rate and momentum.
* Retrieving information that tells us *how well* our model is training, such as validation loss, accuracy, or other metrics.

Simply put, callbacks are an indispensable tool for the training of deep neural networks.
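To make the idea concrete, here is a minimal, framework-free sketch. All of its names (`toy_fit`, `Recorder`, `ModeSwitcher`) are hypothetical and not part of this notebook or of fastai. A loop calls named hooks at fixed points, and each callback implements only the hooks it cares about:

```python
class Recorder:
    """Hypothetical callback: collects the loss after every batch ('retrieve information')."""
    def __init__(self): self.losses = []
    def after_batch(self, state): self.losses.append(state["loss"])

class ModeSwitcher:
    """Hypothetical callback: flips a training flag at the start of each phase ('train vs eval')."""
    def begin_train(self, state): state["training"] = True
    def begin_valid(self, state): state["training"] = False

def toy_fit(batches, cbs):
    # A plain dict stands in for the model/optimizer state.
    state = {"training": None, "loss": None}
    def hook(name):
        for cb in cbs:
            f = getattr(cb, name, None)  # callbacks only implement the hooks they need
            if f: f(state)
    hook("begin_train")
    for loss in batches:                 # pretend each "batch" just yields a loss value
        state["loss"] = loss
        hook("after_batch")
    hook("begin_valid")
    return state

rec = Recorder()
final = toy_fit([0.9, 0.5, 0.3], [rec, ModeSwitcher()])
print(rec.losses, final["training"])  # [0.9, 0.5, 0.3] False
```

The loop itself never needs to know what the callbacks do; it only promises to call them at well-defined moments.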

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

In [2]:
#export
from exports.nb_03 import *

### Refactoring the data loader and model

Before we implement the classes and methods that will allow us to handle callbacks, let's quickly refactor the code we used to create our data sets and loaders, as well as the code we used to initialize our model.

As with all previous notebooks so far, we're using the [MNIST](http://yann.lecun.com/exdb/mnist/index.html) dataset.

In [3]:
x_train,y_train,x_valid,y_valid = get_data()
train_ds,valid_ds = Dataset(x_train,y_train), Dataset(x_valid,y_valid)
nh,bs = 50,64 # hidden layer size, batch size
c = y_train.max().item() + 1 # number of classes
loss_func = F.cross_entropy

Up until now, our approach has been to write a function called `fit()` that defined how our model's training loop would run. We then called it whenever we wanted to start a training cycle. It looked like this:
```
fit(epochs, model, loss_func, opt, train_dl, valid_dl)
```
What if we stored the `model`, `loss_func`, `opt`, `train_dl`, and `valid_dl` parameters inside another class called `Learner()`? This would not only make our `fit()` call much simpler, it would also have a nice side-effect: since the training loop reads everything from one shared, mutable `Learner` object, any adjustment made to the `Learner` while the model is training is *immediately seen* inside the training loop.

We could, for example, update the state of the learning rate value stored inside the `Learner` object at a particular point during the training cycle, and our model would immediately begin training at the updated learning rate.
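A minimal sketch of why this works, using a hypothetical `ToyLearner` that holds only a learning rate (no model or optimizer involved): because the loop re-reads `learn.lr` on every step, a change made partway through training takes effect on the very next step.

```python
class ToyLearner:
    """Hypothetical stand-in for Learner: just holds a learning rate."""
    def __init__(self, lr): self.lr = lr

def run_steps(steps, learn, on_step=None):
    lrs_used = []
    for i in range(steps):
        lrs_used.append(learn.lr)   # the loop reads lr from the shared object each step
        if on_step: on_step(i, learn)
    return lrs_used

def halve_lr_at_step_2(i, learn):
    # Mutates the same object the loop is reading from.
    if i == 2: learn.lr /= 2

lrs = run_steps(5, ToyLearner(0.4), halve_lr_at_step_2)
print(lrs)  # [0.4, 0.4, 0.4, 0.2, 0.2]
```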

In [4]:
#export
class DataBunch():
    def __init__(self, train_dl, valid_dl, c=None):
        self.train_dl,self.valid_dl,self.c = train_dl,valid_dl,c
        
    @property
    def train_ds(self): return self.train_dl.dataset
    
    @property
    def valid_ds(self): return self.valid_dl.dataset

The `DataBunch()` class gives us a handy way to manage both the train and validation datasets/loaders, all under one roof.
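Because `train_ds` and `valid_ds` are properties, the datasets are never stored twice; they are always read back from the loaders. The sketch below exercises a copy of `DataBunch()` with a hypothetical `ToyLoader` that batches a plain Python list (standing in for a PyTorch `DataLoader`, purely for illustration):

```python
class ToyLoader:
    """Hypothetical stand-in for a DataLoader: keeps a .dataset and yields batches."""
    def __init__(self, dataset, bs): self.dataset, self.bs = dataset, bs
    def __iter__(self):
        for i in range(0, len(self.dataset), self.bs):
            yield self.dataset[i:i + self.bs]
    def __len__(self): return (len(self.dataset) + self.bs - 1) // self.bs  # ceil division

class DataBunch():
    def __init__(self, train_dl, valid_dl, c=None):
        self.train_dl, self.valid_dl, self.c = train_dl, valid_dl, c
    @property
    def train_ds(self): return self.train_dl.dataset
    @property
    def valid_ds(self): return self.valid_dl.dataset

train_items, valid_items = list(range(10)), list(range(4))
data = DataBunch(ToyLoader(train_items, 4), ToyLoader(valid_items, 8), c=2)
print(data.train_ds is train_items)  # True: the same object the loader holds
print(list(data.train_dl))           # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```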

In [5]:
data = DataBunch(*get_dls(train_ds, valid_ds, bs), c) # get_dls() defined in notebook03

In [6]:
#export
def get_model(data, lr=0.5, nh=50):
    m = data.train_ds.x.shape[1] # Size of inputs
    model = nn.Sequential(nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,data.c))
    return model, optim.SGD(model.parameters(), lr=lr)

class Learner():
    def __init__(self, model, opt, loss_func, data):
        self.model,self.opt,self.loss_func,self.data=model,opt,loss_func,data

In [7]:
learn = Learner(*get_model(data), loss_func, data)

In [8]:
def fit(epochs, learn):
    for epoch in range(epochs):
        learn.model.train() # put the model in train mode
        for xb,yb in learn.data.train_dl:
            loss = learn.loss_func(learn.model(xb), yb)
            loss.backward()
            learn.opt.step()
            learn.opt.zero_grad()
            
        learn.model.eval() # put model in inference mode
        with torch.no_grad():
            tot_loss, tot_acc = 0., 0.
            for xb,yb in learn.data.valid_dl:
                pred = learn.model(xb)
                # Note: averaging per-batch loss and accuracy this way is
                # slightly off. We don't set drop_last=True on the validation
                # loader, so the last validation batch is smaller than the
                # others, yet it gets the same weight in the average.
                tot_loss += learn.loss_func(pred, yb)
                tot_acc  += accuracy(pred, yb)
            nv = len(learn.data.valid_dl)
            print(epoch, tot_loss/nv, tot_acc/nv)
    # Return the final epoch's metrics. This must sit outside the epoch
    # loop; otherwise fit() would return after the first epoch.
    return tot_loss/nv, tot_acc/nv

In [9]:
loss, acc = fit(1, learn)

0 tensor(0.2717) tensor(0.9133)


### Callback Handler

We can refactor our `fit()` training loop so that it is easy to identify when a single batch is trained, and also when all batches are trained. 

This simpler structure makes it easy to specify where we wire in our various callbacks. Each of the three functions below takes a `cb` parameter: a callback handler object that exposes one method per point in the training loop (`begin_fit()`, `begin_batch()`, `after_loss()`, and so on). Again, the ultimate goal of all of this is to make our training loop more flexible.

In [10]:
def one_batch(xb, yb, cb):
    if not cb.begin_batch(xb,yb): return
    loss = cb.learn.loss_func(cb.learn.model(xb), yb)
    if not cb.after_loss(loss): return
    loss.backward()
    if cb.after_backward(): cb.learn.opt.step()
    if cb.after_step(): cb.learn.opt.zero_grad()
        
def all_batches(dl, cb):
    for xb,yb in dl:
        one_batch(xb, yb, cb)
        if cb.do_stop(): return
        
def fit(epochs, learn, cb):
    if not cb.begin_fit(learn): return
    for epoch in range(epochs):
        if not cb.begin_epoch(epoch): continue
        all_batches(learn.data.train_dl, cb)
        
        if cb.begin_validate():
            with torch.no_grad(): all_batches(learn.data.valid_dl, cb)
        if cb.do_stop() or not cb.after_epoch(): break
    cb.after_fit()

In [11]:
class Callback():
    def begin_fit(self, learn):
        self.learn = learn
        return True
    def after_fit(self): return True
    def begin_epoch(self, epoch):
        self.epoch = epoch
        return True
    def begin_validate(self): return True
    def after_epoch(self): return True
    def begin_batch(self, xb, yb):
        self.xb,self.yb = xb,yb
        return True
    def after_loss(self, loss):
        self.loss = loss
        return True
    def after_backward(self): return True
    def after_step(self): return True

In [12]:
class CallbackHandler():
    def __init__(self,cbs=None):
        self.cbs = cbs if cbs else []
    
    def begin_fit(self, learn):
        self.learn, self.in_train = learn,True
        learn.stop = False
        res = True
        for cb in self.cbs: res = res and cb.begin_fit(learn)
        return res
    
    def after_fit(self):
        res = not self.in_train
        for cb in self.cbs: res = res and cb.after_fit()
        return res
    
    def begin_epoch(self, epoch):
        self.learn.model.train()
        self.in_train = True
        res = True
        for cb in self.cbs: res = res and cb.begin_epoch(epoch)
        return res
    
    def begin_validate(self):
        self.learn.model.eval()
        self.in_train=False
        res = True
        for cb in self.cbs: res = res and cb.begin_validate()
        return res
    
    def after_epoch(self):
        res = True
        for cb in self.cbs: res = res and cb.after_epoch()
        return res
    
    def begin_batch(self, xb, yb):
        res = True
        for cb in self.cbs: res = res and cb.begin_batch(xb, yb)
        return res
            
    def after_loss(self, loss):
        res = self.in_train
        for cb in self.cbs: res = res and cb.after_loss(loss)
        return res
    
    def after_backward(self):
        res = True
        for cb in self.cbs: res = res and cb.after_backward()
        return res
    
    def after_step(self):
        res = True
        for cb in self.cbs: res = res and cb.after_step()
        return res
    
    def do_stop(self):
        # Report whether a callback requested a stop, then reset the flag.
        try:     return self.learn.stop
        finally: self.learn.stop = False
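One subtlety of `CallbackHandler()`: the pattern `res = res and cb.method()` short-circuits, so once any callback returns a falsy value (including `None` from a method with no explicit `return`), the callbacks after it in `self.cbs` are not called at all for that event. A standalone sketch, with a hypothetical `Recorder` callback and `dispatch` function, makes this visible:

```python
class Recorder:
    """Hypothetical callback that logs its name and returns a fixed value."""
    def __init__(self, name, ret, log): self.name, self.ret, self.log = name, ret, log
    def begin_batch(self):
        self.log.append(self.name)
        return self.ret

def dispatch(cbs):
    # Same aggregation as CallbackHandler.begin_batch()
    res = True
    for cb in cbs: res = res and cb.begin_batch()
    return res

log = []
res = dispatch([Recorder('a', True, log), Recorder('b', False, log), Recorder('c', True, log)])
print(res, log)  # False ['a', 'b']
```

Callback `'c'` never ran. If every callback should see every event, one alternative is to call the callback first and combine afterwards, i.e. `res = cb.begin_batch() and res`.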

In [13]:
class TestCallback(Callback):
    def begin_fit(self, learn):
        super().begin_fit(learn)
        self.n_iters = 0
        return True
    
    def after_step(self):
        self.n_iters += 1
        print(self.n_iters)
        if self.n_iters>=10: self.learn.stop = True
        return True

In [14]:
fit(1, learn, cb=CallbackHandler([TestCallback()]))

1
2
3
4
5
6
7
8
9
10


The above structure is very similar to how version 1.0 of the [fastai library](https://github.com/fastai/fastai) implements callbacks, the main difference being that fastai's callback handler can also change and return `xb`, `yb`, and `loss`.

While the above architecture is a nice first attempt at a workable callback handler, we can make things simpler and more flexible. It would be more straightforward if a single class had access to everything, and could therefore change anything at any time.

After all, since we pass `cb` to each of the `one_batch()`, `all_batches()`, and `fit()` functions, it makes sense to gather these functions, together with the callbacks they rely on, into a single class.

### The `Runner()` Class

In fact, this is what we will do shortly. We will create a class called `Runner()`, which contains all the methods that compose our model training cycle, as well as the optimizer, model, loss function, and data.

First we'll rewrite our `Callback()` class to be compatible with the soon-to-be-implemented `Runner()` class:

In [15]:
#export
import re

# Helper function to convert the formatting of callback names 
# so they can be displayed the way we want: all lower-case, 
# with underscores in-between each word.
_camel_re1 = re.compile('(.)([A-Z][a-z])')
_camel_re2 = re.compile('([a-z0-9])([A-Z])')
def camel2snake(name):
    s1 = re.sub(_camel_re1, r'\1_\2', name)
    return re.sub(_camel_re2, r'\1_\2', s1).lower()

class Callback():
    _order=0
    def set_runner(self, run): self.run=run
    def __getattr__(self, k): return getattr(self.run, k)
    @property
    def name(self):
        name = re.sub(r'Callback$', '', self.__class__.__name__)
        return camel2snake(name or 'callback')
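`Callback.__getattr__` is what lets callbacks write `self.model` or `self.in_train` even though those values live on the runner: Python only calls `__getattr__` when normal attribute lookup fails, so any name the callback doesn't define itself is forwarded to `self.run`. Assignments are *not* forwarded, though, which is why the upcoming `TrainEvalCallback` writes `self.run.n_iter` rather than `self.n_iter`. A sketch with a hypothetical `ToyRunner` and `CountIters` callback:

```python
class Callback():
    _order = 0
    def set_runner(self, run): self.run = run
    # Only called when normal lookup fails, so missing names fall through to the runner.
    def __getattr__(self, k): return getattr(self.run, k)

class ToyRunner:
    """Hypothetical stand-in for the Runner's state."""
    def __init__(self): self.in_train, self.n_iter = True, 0

class CountIters(Callback):
    def after_batch(self):
        if not self.in_train: return  # read: resolved on self.run
        self.run.n_iter += 1          # write: must go through self.run explicitly

run, cb = ToyRunner(), CountIters()
cb.set_runner(run)
cb.after_batch(); cb.after_batch()
print(run.n_iter)             # 2
cb.n_iter = 99                # creates a new attribute on the callback only
print(run.n_iter, cb.n_iter)  # 2 99
```

One caveat of this design: looking up a missing attribute before `set_runner()` has been called raises a `RecursionError`, because `self.run` is itself missing and its lookup goes through `__getattr__` again.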

Next comes what will be the core callback of our training cycle, `TrainEvalCallback()`. It is responsible for switching between train and eval mode, and for keeping track of training progress: the total number of training iterations elapsed (`n_iter`) and the number of epochs elapsed as a fractional count (`n_epochs`). Our `Runner()` class will always include this callback by default:

In [16]:
#export
class TrainEvalCallback(Callback):
    def begin_fit(self):
        self.run.n_epochs=0.
        self.run.n_iter=0
        
    def after_batch(self):
        if not self.in_train: return
        self.run.n_epochs += 1./self.iters
        self.run.n_iter   += 1
        
    def begin_epoch(self):
        self.run.n_epochs=self.epoch
        self.model.train()
        self.run.in_train=True
        
    def begin_validate(self):
        self.model.eval()
        self.run.in_train=False
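To see what the two counters hold, here is a plain-Python simulation (a hypothetical `simulate` function, no model involved) that applies the same update rules: `begin_epoch` resets `n_epochs` to the integer epoch number, each training batch adds `1/iters` to it and 1 to `n_iter`, and validation batches leave both counters untouched.

```python
def simulate(epochs, n_train_batches, n_valid_batches):
    n_epochs, n_iter, after_each_epoch = 0.0, 0, []
    for epoch in range(epochs):
        n_epochs = float(epoch)                  # begin_epoch
        for _ in range(n_train_batches):         # after_batch while in_train is True
            n_epochs += 1.0 / n_train_batches
            n_iter += 1
        for _ in range(n_valid_batches):         # after_batch while in_train is False:
            pass                                 # the callback returns early
        after_each_epoch.append(n_epochs)
    return n_epochs, n_iter, after_each_epoch

print(simulate(2, 4, 2))  # (2.0, 8, [1.0, 2.0])
```

A side effect of resetting at every `begin_epoch` is that floating-point round-off from summing `1/iters` (which is usually not exactly representable) cannot accumulate across epochs.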

A quick test of what we just wrote:

In [17]:
cbname = 'TrainEvalCallback'
camel2snake(cbname)

'train_eval_callback'

In [18]:
TrainEvalCallback().name

'train_eval'
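The helper works in two passes: the first inserts an underscore before an uppercase letter that starts a lowercase run (which splits an acronym such as `HTTP` off the following word), and the second inserts one between a lowercase letter or digit and a following uppercase letter. Below is a standalone copy that inserts an underscore in both passes, plus a hypothetical `cb_name` helper mirroring the suffix-stripping in `Callback.name`:

```python
import re

_camel_re1 = re.compile('(.)([A-Z][a-z])')
_camel_re2 = re.compile('([a-z0-9])([A-Z])')

def camel2snake(name):
    s1 = re.sub(_camel_re1, r'\1_\2', name)         # e.g. 'HTTPServer' -> 'HTTP_Server'
    return re.sub(_camel_re2, r'\1_\2', s1).lower()  # e.g. 'TrainEval' -> 'Train_Eval'

def cb_name(cls_name):
    # Mirrors Callback.name: drop a trailing 'Callback', fall back to 'callback'.
    return camel2snake(re.sub(r'Callback$', '', cls_name) or 'callback')

print(cb_name('TrainEvalCallback'), cb_name('AvgStatsCallback'), cb_name('Callback'))
# train_eval avg_stats callback
print(camel2snake('HTTPServerCallback'))  # http_server_callback
```

With `r'\1\2'` in the first pass instead, that pass would be a no-op and `'HTTPServerCallback'` would come out as `'httpserver_callback'`.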

A quick helper function that transforms inputs into lists:

In [19]:
#export
from typing import *

def listify(o):
    if o is None: return []
    if isinstance(o, list): return o
    if isinstance(o, str): return [o]
    if isinstance(o, Iterable): return list(o)
    return [o]
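A few sample inputs show how `listify` behaves. This copy uses `collections.abc.Iterable`, which behaves the same as the `typing` alias for `isinstance` checks; `SomeCallback` is a hypothetical non-iterable object standing in for a single callback:

```python
from collections.abc import Iterable

def listify(o):
    if o is None: return []
    if isinstance(o, list): return o
    if isinstance(o, str): return [o]     # strings are iterable, but should stay whole
    if isinstance(o, Iterable): return list(o)
    return [o]

class SomeCallback: pass                  # hypothetical single (non-iterable) object

cb = SomeCallback()
print(listify(None), listify('abc'), listify((1, 2)), listify(range(3)))
# [] ['abc'] [1, 2] [0, 1, 2]
print(listify(cb) == [cb])  # True: a lone callback is wrapped in a list
```

Note that a list comes back as the *same* object, not a copy, so code that appends to the result (as `Runner.__init__` does with `cbs`) also modifies the caller's list.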

Finally, our new `Runner()` class:

In [20]:
#export
class Runner():
    def __init__(self, cbs=None, cb_funcs=None):
        cbs = listify(cbs)
        for cbf in listify(cb_funcs):
            cb = cbf()
            setattr(self, cb.name, cb)
            cbs.append(cb)
        self.stop, self.cbs = False, [TrainEvalCallback()] + cbs
        
    @property
    def opt(self):       return self.learn.opt
    @property
    def model(self):     return self.learn.model
    @property
    def loss_func(self): return self.learn.loss_func
    @property
    def data(self):      return self.learn.data
    
    def one_batch(self, xb, yb):
        self.xb, self.yb = xb, yb
        if self('begin_batch'): return
        self.pred = self.model(self.xb)
        if self('after_pred'): return
        self.loss = self.loss_func(self.pred, self.yb)
        if self('after_loss') or not self.in_train: return
        self.loss.backward()
        if self('after_backward'): return
        self.opt.step()
        if self('after_step'): return
        self.opt.zero_grad()
        
    def all_batches(self, dl):
        self.iters = len(dl)
        for xb,yb in dl:
            if self.stop: break
            self.one_batch(xb,yb)
            self('after_batch')
        self.stop=False
        
    def fit(self, epochs, learn):
        self.epochs, self.learn, self.loss= epochs, learn, tensor(0.)
        
        try:
            for cb in self.cbs: cb.set_runner(self)
            if self('begin_fit'): return
            for epoch in range(epochs):
                self.epoch = epoch
                if not self('begin_epoch'): self.all_batches(self.data.train_dl)
                
                with torch.no_grad():
                    if not self('begin_validate'): self.all_batches(self.data.valid_dl)
                if self('after_epoch'): break
                    
        finally:
            self('after_fit')
            self.learn = None
            
    def __call__(self, cb_name):
        for cb in sorted(self.cbs, key=lambda x: x._order):
            f = getattr(cb, cb_name, None)
            if f and f(): return True
        return False
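Note that `Runner.__call__` flips the convention used by `CallbackHandler()`: a hook now returns `True` to mean "stop/skip here", while returning `None` (no explicit `return`, as in `TrainEvalCallback`) means "carry on". Callbacks run in ascending `_order`, and the first one that returns `True` ends the dispatch for that event. A sketch with a hypothetical `Logger` callback and a `call_hook` function that copies the logic of `__call__`:

```python
class Logger:
    """Hypothetical callback: logs its name and returns `skip` from begin_batch."""
    def __init__(self, name, order, log, skip=False):
        self.name, self._order, self.log, self.skip = name, order, log, skip
    def begin_batch(self):
        self.log.append(self.name)
        return self.skip

def call_hook(cbs, cb_name):
    # Same logic as Runner.__call__
    for cb in sorted(cbs, key=lambda x: x._order):
        f = getattr(cb, cb_name, None)
        if f and f(): return True
    return False

log = []
cbs = [Logger('late', 10, log), Logger('early', 0, log), Logger('mid', 5, log)]
print(call_hook(cbs, 'begin_batch'), log)  # False ['early', 'mid', 'late']

log.clear()
cbs[2].skip = True                         # 'mid' now asks to skip the rest of the batch
print(call_hook(cbs, 'begin_batch'), log)  # True ['early', 'mid']

print(call_hook(cbs, 'after_pred'))        # False: no callback defines this hook
```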

#### `AvgStats()`
The second callback we will create is also of core importance: it calculates and displays the average loss and evaluation metrics during training. Unlike the way we'd been computing them up until now, this implementation weights each batch by its size, so the averages it reports are correct whether the batch size is constant or varies across iterations (for example, when the last batch is smaller):

In [21]:
#export
class AvgStats():
    def __init__(self, metrics, in_train): self.metrics, self.in_train = listify(metrics), in_train
        
    def reset(self):
        self.tot_loss, self.count = 0., 0
        self.tot_mets = [0.] * len(self.metrics)
        
    @property
    def all_stats(self): return [self.tot_loss.item()] + self.tot_mets
    @property
    def avg_stats(self): return [o/self.count for o in self.all_stats]
    
    def __repr__(self):
        if not self.count: return ""
        return f"{'train' if self.in_train else 'valid'}: {self.avg_stats}"
        
    def accumulate(self, run):
        bn = run.xb.shape[0]
        self.tot_loss += run.loss * bn
        self.count += bn
        for i,m in enumerate(self.metrics):
            self.tot_mets[i] += m(run.pred, run.yb) * bn
            
class AvgStatsCallback(Callback):
    def __init__(self, metrics):
        self.train_stats, self.valid_stats = AvgStats(metrics,in_train=True), AvgStats(metrics,in_train=False)
        
    def begin_epoch(self):
        self.train_stats.reset()
        self.valid_stats.reset()
        
    def after_loss(self):
        stats = self.train_stats if self.in_train else self.valid_stats
        with torch.no_grad(): stats.accumulate(self.run)
            
    def after_epoch(self):
        print(self.train_stats)
        print(self.valid_stats)
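To see why weighting by batch size matters, here is a small worked example in plain Python (the helper names are hypothetical). With batch sizes 64, 64, and 16 and per-batch accuracies of 1.0, 1.0, and 0.0, the dataset has 128 correct predictions out of 144:

```python
def naive_mean(batch_values):
    # What the earlier fit() did: every batch counts equally.
    return sum(batch_values) / len(batch_values)

def weighted_mean(batch_values, batch_sizes):
    # What AvgStats.accumulate() does: scale each batch mean by its size,
    # then divide by the total number of examples.
    total = sum(v * n for v, n in zip(batch_values, batch_sizes))
    return total / sum(batch_sizes)

sizes = [64, 64, 16]  # the last batch is smaller
accs = [1.0, 1.0, 0.0]
print(round(naive_mean(accs), 4))            # 0.6667
print(round(weighted_mean(accs, sizes), 4))  # 0.8889  (= 128 / 144)
```

Only the weighted version matches the accuracy over all 144 examples.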

Let's try it out!

In [22]:
learn = Learner(*get_model(data), loss_func, data)

Let's suppose that we use accuracy as our evaluation metric. Here's one way to tell our `AvgStatsCallback()` class that this is what we want to do:

In [23]:
stats = AvgStatsCallback([accuracy])
run = Runner(cbs=stats)

In [24]:
run.fit(2, learn)

train: [0.31132888671875, tensor(0.9030)]
valid: [0.18311549072265626, tensor(0.9440)]
train: [0.139973671875, tensor(0.9580)]
valid: [0.1413156494140625, tensor(0.9571)]


In [25]:
loss, acc = stats.valid_stats.avg_stats
assert acc > 0.9
loss, acc

(0.1413156494140625, tensor(0.9571))

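We can also use `functools.partial` to pre-bind our `accuracy` function to the `AvgStatsCallback` class. Calling the resulting `acc_cbf()` is equivalent to calling `AvgStatsCallback(accuracy)`, so we can hand it to `Runner()` through `cb_funcs` and let the runner create the callback itself: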

In [26]:
#export
from functools import partial

In [27]:
acc_cbf = partial(AvgStatsCallback, accuracy)

In [28]:
run = Runner(cb_funcs=acc_cbf)

In [29]:
run.fit(1, learn)

train: [0.10387189453125, tensor(0.9687)]
valid: [0.114280712890625, tensor(0.9661)]
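One benefit of passing a factory rather than an instance is that every call produces a fresh callback, so two runners built from the same `acc_cbf` won't share (and overwrite) each other's accumulated statistics. The sketch below uses hypothetical stand-ins for `accuracy` and `AvgStatsCallback` so that it runs without PyTorch:

```python
from functools import partial

def accuracy(preds, targets):
    """Hypothetical stand-in metric: fraction of matching items."""
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)

class AvgStatsCallback:
    """Hypothetical stand-in that only stores its metrics."""
    def __init__(self, metrics): self.metrics = metrics

acc_cbf = partial(AvgStatsCallback, accuracy)
cb1, cb2 = acc_cbf(), acc_cbf()              # each call builds a brand-new callback
print(cb1 is cb2, cb1.metrics is accuracy)   # False True
print(acc_cbf.func is AvgStatsCallback)      # True: partial remembers the wrapped callable
```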


Finally, because `Runner()` stores each callback created from `cb_funcs` as an attribute named after that callback (here `avg_stats`), Jupyter can offer tab-completion for this dynamically created attribute: if you type out the line below, press the <kbd>tab</kbd> key after each `.` to see a pop-up of the attributes and methods available next.

In [30]:
run.avg_stats.valid_stats.avg_stats

[0.114280712890625, tensor(0.9661)]
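The attribute `run.avg_stats` exists only because `Runner.__init__` calls `setattr(self, cb.name, cb)` for every callback built from `cb_funcs`. Since it is an ordinary instance attribute, it shows up in `dir(run)`, the kind of introspection that Jupyter's tab-completion relies on. A stripped-down sketch (hypothetical `MiniRunner`, with the callback name hard-coded rather than derived through `camel2snake`):

```python
class AvgStatsCallback:
    name = 'avg_stats'  # hypothetical: hard-coded instead of computed by Callback.name

class MiniRunner:
    """Hypothetical, stripped-down version of Runner.__init__."""
    def __init__(self, cb_funcs):
        self.cbs = []
        for cbf in cb_funcs:
            cb = cbf()
            setattr(self, cb.name, cb)  # expose the callback as run.<name>
            self.cbs.append(cb)

run = MiniRunner([AvgStatsCallback])
print(run.avg_stats is run.cbs[0])  # True
print('avg_stats' in dir(run))      # True
```

Callbacks passed through `cbs=` are not registered this way in `Runner.__init__`, so an attribute like `run.avg_stats` only exists for runners built with `cb_funcs`.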

### Export

In [31]:
!python notebook2script_my_reimplementation.py 04_callbacks_my_reimplementation.ipynb

Converted 04_callbacks_my_reimplementation.ipynb to nb_04.py
