In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
#export
from exp.nb_03 import *

## DataBunch/Learner

[Jump_to lesson 9 video](https://course.fast.ai/videos/?lesson=9&t=4799)

Most libraries have callbacks, but fastai callbacks are more pervasive and are used in many ways within the library. The fastai v1 training loop contains a bunch of calls to callbacks, like `on_train_begin`, `on_epoch_begin`, `on_batch_begin` etc. With these, fastai v1 implements learning rate schedulers, early stopping, parallel trainers and gradient clipping as callbacks. You can mix them all together and write something like:

> `learn.fit(epochs, lr, wd, callbacks)`

For example, GANs are modules (`GANModule` class) that have two `nn.Module` objects, `generator` and `critic`, passed to their `__init__`, together with a boolean `gen_mode` flag. In the forward pass, we inspect the value of `gen_mode`. If `True`, we are in generator mode, otherwise we are in discriminator mode (called *critic* in the library). There is then a `switch` method that switches between generator and discriminator mode. Similarly, they create a `GANLoss` class containing a generator loss and a discriminator/critic loss. They then create a `GANTrainer` class that inherits from `LearnerCallback` containing `on_train_begin`, `on_train_end`, etc., where each method does the required thing. (You need to go through the code to understand how it works; J.H. is just showing screenshots of these classes.)

From the previous notebook we have a training loop. Let's start again grabbing the data.

In [3]:
x_train, y_train, x_valid, y_valid = get_data()
train_ds, valid_ds = Dataset(x_train, y_train), Dataset(x_valid, y_valid)
nh, bs = 50, 64  # N. of hidden units and batch size.
c = y_train.max().item() + 1  # N. of classes.
loss_func = F.cross_entropy

This is the current signature of our `fit` function. It contains a lot of parameters, which is usually a *code smell*. Can we package some of these things together? This is usually a good thing, as we can pass them around together, create factory methods to create them etc. 

`fit(epochs, model, loss_func, opt, train_dl, valid_dl)`

We can do this in two steps: first we observe that training and validation set should live in one place, as they are both processed in the training loop. We put these datasets into a `DataBunch`. Then we create another class that stores the `DataBunch`, the loss function, the optimizer and the model, and we call it a `Learner`. Our final goal is to have something like this:

`fit(1, learn)`

This will allow us to tweak what's happening inside the training loop from other places in the code: the `Learner` object is mutable, so changing any of its attributes elsewhere will be seen in our training loop.
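
A minimal sketch of why mutability matters here. It uses a `SimpleNamespace` as a stand-in for the real `Learner`, and the names `lr` and `train_step` are illustrative assumptions rather than part of the notebook. A function that holds a reference to the object sees attribute changes made elsewhere:

```python
from types import SimpleNamespace

# Stand-in for a Learner: just a mutable bag of attributes (hypothetical).
learn = SimpleNamespace(lr=0.5)


def train_step(learn):
    # The attribute is read at call time, so the latest value is always used.
    return learn.lr


before = train_step(learn)
learn.lr = 0.1  # changed "elsewhere", e.g. by a scheduler-like callback
after = train_step(learn)
print(before, after)
```

The second call picks up the new value without `train_step` being passed anything different.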

[Jump_to lesson 9 video](https://course.fast.ai/videos/?lesson=9&t=5363)

The `DataBunch` stores the training and the validation `DataLoader`s, and has a couple of properties that return the datasets.

In [4]:
#export
class DataBunch():
    def __init__(self, train_dl, valid_dl, c=None):
        self.train_dl = train_dl
        self.valid_dl = valid_dl
        self.c = c

    @property
    def train_ds(self):
        return self.train_dl.dataset

    @property
    def valid_ds(self):
        return self.valid_dl.dataset

The `get_dls` function simply takes two datasets, a batch size, and possibly some additional keyword arguments, and returns a tuple of `DataLoader`s. This is its definition:

```python
def get_dls(train_ds, valid_ds, bs, **kwargs):
    return (DataLoader(train_ds, batch_size=bs, shuffle=True, **kwargs),
            DataLoader(valid_ds, batch_size=bs*2, **kwargs))
```

In [5]:
data = DataBunch(*get_dls(train_ds, valid_ds, bs), c)

We create a `get_model` function that returns a simple sequential model and an SGD optimizer. This is just for illustrative purposes. One nice aspect of this function is that it sets the output size of the last layer to the correct value by using the `data.c` attribute.

The important bit is the `Learner` class, that bundles together the model, the optimizer, the loss function and the `DataBunch`. The `Learner` class has no logic at all. It is only a storage device.

In [6]:
#export
def get_model(data, lr=0.5, nh=50):
    m = data.train_ds.x.shape[1]
    model = nn.Sequential(nn.Linear(m, nh), nn.ReLU(), nn.Linear(nh, data.c))
    return model, optim.SGD(model.parameters(), lr=lr)


class Learner():
    def __init__(self, model, opt, loss_func, data):
        self.model = model
        self.opt = opt
        self.loss_func = loss_func
        self.data = data

In [7]:
learn = Learner(*get_model(data), loss_func, data)

This is the same as the previous `fit` function, but wherever we had `model` we now have `learn.model`; where we had `train_dl` and `valid_dl` we now have `learn.data.train_dl` and `learn.data.valid_dl`; where we had `opt` we have `learn.opt`, and so on.

In [8]:
def fit(epochs, learn):
    for epoch in range(epochs):
        learn.model.train()
        for xb, yb in learn.data.train_dl:
            loss = learn.loss_func(learn.model(xb), yb)
            loss.backward()
            learn.opt.step()
            learn.opt.zero_grad()

        learn.model.eval()
        with torch.no_grad():
            tot_loss, tot_acc = 0., 0.
            for xb, yb in learn.data.valid_dl:
                pred = learn.model(xb)
                tot_loss += learn.loss_func(pred, yb)
                tot_acc += accuracy(pred, yb)
        nv = len(learn.data.valid_dl)
        print(epoch, tot_loss/nv, tot_acc/nv)
    return tot_loss/nv, tot_acc/nv

In [9]:
loss, acc = fit(1, learn)

0 tensor(0.3741) tensor(0.8812)


## CallbackHandler

This was our training loop (without validation) from the previous notebook, with the inner loop contents factored out. More precisely we have factored out the computations in the `one_batch` function.

```python
def one_batch(xb,yb):
    pred = model(xb)
    loss = loss_func(pred, yb)
    loss.backward()
    opt.step()
    opt.zero_grad()
    
def fit():
    for epoch in range(epochs):
        for b in train_dl:
            one_batch(*b)
```

We also add an `all_batches` function that runs `one_batch` on all the batches of a `DataLoader`, for both the training and the validation set. We add callbacks so that we can remove complexity from the loop and make it flexible:

[Jump_to lesson 9 video](https://course.fast.ai/videos/?lesson=9&t=5628)

In [10]:
def one_batch(xb, yb, cb):
    if not cb.begin_batch(xb, yb):
        return
    loss = cb.learn.loss_func(cb.learn.model(xb), yb)
    if not cb.after_loss(loss):
        return
    loss.backward()
    if cb.after_backward():
        cb.learn.opt.step()
    if cb.after_step():
        cb.learn.opt.zero_grad()


def all_batches(dl, cb):
    for xb, yb in dl:
        one_batch(xb, yb, cb)
        if cb.do_stop():
            return


def fit(epochs, learn, cb):
    if not cb.begin_fit(learn):
        return
    for epoch in range(epochs):
        if not cb.begin_epoch(epoch):
            continue
        all_batches(learn.data.train_dl, cb)

        if cb.begin_validate():
            with torch.no_grad():
                all_batches(learn.data.valid_dl, cb)
        if cb.do_stop() or not cb.after_epoch():
            break
    cb.after_fit()

A custom callback inherits from the class below, which defines the methods we may use; each of them returns `True` by default. Some methods, like `begin_fit`, `begin_epoch`, `begin_batch` and `after_loss`, also store their arguments (the learner, the epoch, the batch and the loss) as attributes.

In [11]:
class Callback():
    def begin_fit(self, learn):
        self.learn = learn
        return True

    def after_fit(self):
        return True

    def begin_epoch(self, epoch):
        self.epoch = epoch
        return True

    def begin_validate(self):
        return True

    def after_epoch(self):
        return True

    def begin_batch(self, xb, yb):
        self.xb = xb
        self.yb = yb
        return True

    def after_loss(self, loss):
        self.loss = loss
        return True

    def after_backward(self):
        return True

    def after_step(self):
        return True

A `CallbackHandler` is initialized with a list of callbacks, if any, and then takes care of running them all in the order they appear in the list. For example, `begin_fit` takes a learner, stores it, sets `in_train` to `True` and `learn.stop` to `False`, initializes `res` to `True`, and loops over all the callbacks. For each of them, it calls `begin_fit(learn)` and combines the returned value with `res`. At the end of the loop it returns the final value of `res`.

The important thing is to understand how the callbacks, callback handler and training loop come together.

1. We create a learner.
2. We create one or more `Callback`s.
3. We pass the list of callbacks to a `CallbackHandler`.
4. We pass the learner and the callback handler to the `fit` function.

In [12]:
class CallbackHandler():
    def __init__(self, cbs=None):
        self.cbs = cbs if cbs else []

    def begin_fit(self, learn):
        self.learn = learn
        self.in_train = True
        learn.stop = False
        res = True
        for cb in self.cbs:
            res = res and cb.begin_fit(learn)
        return res

    def after_fit(self):
        res = not self.in_train
        for cb in self.cbs:
            res = res and cb.after_fit()
        return res

    def begin_epoch(self, epoch):
        self.learn.model.train()
        self.in_train = True
        res = True
        for cb in self.cbs:
            res = res and cb.begin_epoch(epoch)
        return res

    def begin_validate(self):
        self.learn.model.eval()
        self.in_train = False
        res = True
        for cb in self.cbs:
            res = res and cb.begin_validate()
        return res

    def after_epoch(self):
        res = True
        for cb in self.cbs:
            res = res and cb.after_epoch()
        return res

    def begin_batch(self, xb, yb):
        res = True
        for cb in self.cbs:
            res = res and cb.begin_batch(xb, yb)
        return res

    def after_loss(self, loss):
        res = self.in_train
        for cb in self.cbs:
            res = res and cb.after_loss(loss)
        return res

    def after_backward(self):
        res = True
        for cb in self.cbs:
            res = res and cb.after_backward()
        return res

    def after_step(self):
        res = True
        for cb in self.cbs:
            res = res and cb.after_step()
        return res

    def do_stop(self):
        try:
            return self.learn.stop
        finally:
            self.learn.stop = False
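
One detail of the handler is worth spelling out. Every method combines results with `res = res and cb.method()`, so Python's short-circuiting means that once one callback returns a falsy value, the remaining callbacks are not called for that event. The sketch below illustrates this with two hypothetical callbacks (`Veto` and `Logger`); it is plain Python, not the notebook's classes:

```python
calls = []


class Veto:
    def begin_batch(self):
        calls.append("Veto")
        return False  # asks the loop to skip this batch


class Logger:
    def begin_batch(self):
        calls.append("Logger")
        return True


def begin_batch(cbs):
    # Same accumulation pattern as CallbackHandler.begin_batch.
    res = True
    for cb in cbs:
        # The right-hand side is not evaluated once res is False.
        res = res and cb.begin_batch()
    return res


result = begin_batch([Veto(), Logger()])
print(result, calls)
```

Here `Logger.begin_batch` is never reached, so the order of the callbacks in the list can change which ones actually run.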

For example, the `TestCallback` callback below overrides `begin_fit` from the `Callback` class: it calls the parent implementation and sets the `n_iters` attribute to zero. At the end of each step, it increases `n_iters` by one. Once the number of iterations reaches 10, it stops the training. This can be a handy callback, since we may often want to run just a few iterations.

In [13]:
class TestCallback(Callback):
    def begin_fit(self, learn):
        super().begin_fit(learn)
        self.n_iters = 0
        return True

    def after_step(self):
        self.n_iters += 1
        print(self.n_iters)
        if self.n_iters >= 10:
            self.learn.stop = True
        return True

In [14]:
fit(1, learn, cb=CallbackHandler([TestCallback()]))

1
2
3
4
5
6
7
8
9
10
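
The stop flag set by `TestCallback` is consumed by `do_stop`, which relies on `try`/`finally` to return the current value of the flag and reset it in one go. A minimal sketch of that idiom follows; the `Handler` class is a simplified stand-in that keeps the flag on itself rather than on the learner:

```python
class Handler:
    def __init__(self):
        self.stop = False

    def do_stop(self):
        try:
            return self.stop  # the value to return is evaluated first
        finally:
            self.stop = False  # then the flag is reset before returning


h = Handler()
h.stop = True
first = h.do_stop()
second = h.do_stop()
print(first, second)
```

The first call reports the stop request and the second one finds the flag already cleared, so a stop request only interrupts the loop that is currently running.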


This is roughly how fastai does it now (except that the handler can also change and return `xb`, `yb`, and `loss`). But let's see if we can make things simpler and more flexible, so that a single class has access to everything and can change anything at any time. The fact that we're passing `cb` to so many functions is a strong hint that they should all live in the same class, which can then keep this state in one place.

## Runner

A `Runner` is a new class that is likely to appear in the new version of fastai, and that contains the three things we have been using so far: `one_batch`, `all_batches` and `fit`.

**Python Note**: whenever a Python function or method does *not* return a value, this is equivalent to returning `None` which evaluates to `False` in logical operations. In all the code below, the logic is:

> If your callback handler returns `False`, *keep going*.

This means that the callbacks don't need to return anything, most of the time. `TestCallback`, for example, returns `True` when we want to *stop* the training.

In [15]:
def foo(x):
    x += 1


print(foo(2))

None
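
To connect this with the convention above, here is a small sketch of how an implicit `None` is treated as "keep going". The helper `should_stop` is a hypothetical name; it mirrors the `if f and f(): return True` test that appears later in `Runner.__call__`:

```python
def silent_callback():
    pass  # implicitly returns None


def stopping_callback():
    return True  # explicitly asks to stop


def should_stop(f):
    # A missing method (None) or a falsy return value both mean "keep going".
    return bool(f and f())


results = (should_stop(silent_callback),
           should_stop(stopping_callback),
           should_stop(None))
print(results)
```

Only the callback that explicitly returns `True` triggers a stop.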


[Jump_to lesson 9 video](https://course.fast.ai/videos/?lesson=9&t=5811)

We re-write the `Callback` class adding an `_order` class attribute. We also add a `set_runner` method that assigns the runner as an attribute, and define the `__getattr__` method. Note that `__getattr__` is called only if the attribute `k` is *not* found through normal lookup.

When in the `Runner` class below we call the `fit` method, this goes through the callbacks in the callback list and calls their `set_runner` method.

In [16]:
# export
import re

_camel_re1 = re.compile('(.)([A-Z][a-z]+)')
_camel_re2 = re.compile('([a-z0-9])([A-Z])')


def camel2snake(name):
    s1 = re.sub(_camel_re1, r'\1_\2', name)
    return re.sub(_camel_re2, r'\1_\2', s1).lower()

We now define a new `Callback` class which is much shorter and simpler than the previous one. We will come back to the `_order` attribute later, but note for now the presence of `__getattr__()`. It is defined to return `getattr(self.run, k)`. `__getattr__` is only called by Python when it *cannot* find the attribute it is asked for. So if we ask for an attribute `k` and this does not exist, `__getattr__(self, k)` will call `getattr(self.run, k)`, *i.e.*, if the attribute `k` is not defined in the `Callback`, it will look for it in the `Runner`. This is a common pattern in fastai. If an object contains or composes another object, we very often delegate `__getattr__` to the other object.
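
The delegation pattern can be sketched in isolation. The classes `Run` and `Cb` below are hypothetical stand-ins, not the notebook's `Runner` and `Callback`. The sketch also shows that only *reads* are delegated: an assignment creates an attribute on the callback itself, which is presumably why `TrainEvalCallback` further down writes to `self.run.n_iter` rather than `self.n_iter`.

```python
class Run:
    def __init__(self):
        self.n_iter = 3


class Cb:
    own = "mine"

    def __init__(self, run):
        self.run = run

    def __getattr__(self, k):
        # Only reached when normal lookup on the instance and class fails.
        return getattr(self.run, k)


cb = Cb(Run())
print(cb.own)     # found on Cb itself, __getattr__ is not involved
print(cb.n_iter)  # not on Cb, so the lookup is delegated to cb.run
read_before = cb.n_iter

cb.n_iter = 10    # assignment is NOT delegated: it creates an attribute on cb
print(cb.n_iter, cb.run.n_iter)
```

After the assignment, `cb.n_iter` and `cb.run.n_iter` are two different attributes, so state that the runner should see has to be written through `self.run`.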

The `Callback` class has a `name` property. This works as follows: if we create a new callback called `MyOwnCallback`, its name will be the name of the class stripped of the `Callback` suffix and converted to snake case, *i.e.*, `my_own`. This is useful because in `Runner` we have the following code:

```python
def __init__(self, cbs=None, cb_funcs=None):
    cbs = listify(cbs)
    for cbf in listify(cb_funcs):
        cb = cbf()
        setattr(self, cb.name, cb)  # <--- here
        cbs.append(cb)
    self.stop = False
    self.cbs = [TrainEvalCallback()] + cbs
```

That line sets the attribute in the runner from the callback. This is used, for example, for the `Recorder`.
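
As a rough illustration of the naming and registration mechanism, the sketch below copies `camel2snake` and the `name` property from this notebook, and adds a stripped-down `MiniRunner` (a hypothetical stand-in for `Runner.__init__`, without `listify` or `TrainEvalCallback`):

```python
import re

_camel_re1 = re.compile('(.)([A-Z][a-z]+)')
_camel_re2 = re.compile('([a-z0-9])([A-Z])')


def camel2snake(name):
    s1 = re.sub(_camel_re1, r'\1_\2', name)
    return re.sub(_camel_re2, r'\1_\2', s1).lower()


class Callback:
    @property
    def name(self):
        name = re.sub(r'Callback$', '', self.__class__.__name__)
        return camel2snake(name or 'callback')


class MyOwnCallback(Callback):
    pass


class MiniRunner:
    # Hypothetical, simplified version of the registration loop.
    def __init__(self, cb_funcs):
        self.cbs = []
        for cbf in cb_funcs:
            cb = cbf()                  # build the callback
            setattr(self, cb.name, cb)  # expose it as an attribute by name
            self.cbs.append(cb)


run = MiniRunner([MyOwnCallback])
cb_name = MyOwnCallback().name
print(cb_name)
print(run.my_own is run.cbs[0])
```

The callback created inside the runner is reachable as `run.my_own`, without the caller having kept a reference to it.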

We use this to add metrics, via the `AvgStatsCallback` class. At the beginning of an epoch (`.begin_epoch()`) we reset the statistics and at the end of an epoch (`.after_epoch()`) we print them. After computing the loss (`.after_loss()`) we accumulate them. In order to do this our `AvgStatsCallback` class has an `accumulate()` method. There are also two properties `all_stats` and `avg_stats` that return what their names say.

In [17]:
class Callback():
    _order = 0

    def set_runner(self, run):
        self.run = run

    def __getattr__(self, k):
        return getattr(self.run, k)

    @property
    def name(self):
        name = re.sub(r'Callback$', '', self.__class__.__name__)
        return camel2snake(name or 'callback')

This first callback is responsible for switching the model back and forth between training and validation mode, as well as maintaining a count of the iterations and the fraction of the epoch elapsed. More details about this class are given below. The way the methods are organized, in fact, can only be understood in the light of the `Runner` class.

In [18]:
#export
class TrainEvalCallback(Callback):
    def begin_fit(self):
        self.run.n_epochs = 0.
        self.run.n_iter = 0

    def after_batch(self):
        if not self.in_train:
            return
        self.run.n_epochs += 1./self.iters
        self.run.n_iter += 1

    def begin_epoch(self):
        self.run.n_epochs = self.epoch
        self.model.train()
        self.run.in_train = True

    def begin_validate(self):
        self.model.eval()
        self.run.in_train = False

This is the same `TestCallback` as before, but notice how much smaller it has become.

In [19]:
class TestCallback(Callback):
    def after_step(self):
        if self.n_iter >= 10:  # `n_iter` lives on the runner and is reached via `__getattr__`
            return True

In [20]:
cbname = 'TrainEvalCallback'
camel2snake(cbname)

'train_eval_callback'

In [21]:
TrainEvalCallback().name

'train_eval'

In [22]:
#export
from typing import *


def listify(o):
    if o is None:
        return []
    if isinstance(o, list):
        return o
    if isinstance(o, str):
        return [o]
    if isinstance(o, Iterable):
        return list(o)
    return [o]

The `fit` method in the runner is very simple: it takes a number of `epochs` and a `learn` object (which has no logic in it). We first tell each of our callbacks which runner they are currently working with, and then call `begin_fit`. We go through each epoch in turn, starting with a call to `begin_epoch`; unless this returns `True`, we run `all_batches(self.data.train_dl)`. Once the training set is exhausted we move on to the validation step: inside a `torch.no_grad()` context manager we call `begin_validate` and, unless it returns `True`, run `all_batches(self.data.valid_dl)`.

Calls like `self('begin_fit')`, `self('begin_epoch')`, `self('after_fit')` etc. may look strange. What J.H. has done is define a `__call__` method that, as usual, lets you treat an object as if it were a function. In one of the next lessons we will see that we could have called this `callback` or `run_callbacks` or something else. Here, J.H. prefers to use `self` as a function, but I personally find that it makes the code much less readable. `__call__` goes through all the callbacks. J.H. says that he didn't like the fact that before each callback had to inherit from the `Callback` class in order to have the various methods (what's wrong with that?), so this is why he is now using `getattr(cb, cb_name, None)`, which means: "look inside `cb` and check whether the name `cb_name` exists, and if you cannot find it return `None`". If the method is found, it is assigned to a variable called (very unfortunately) `f` and then called.
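
The dispatch-by-name idea can be tried out on its own. In the sketch below, `EpochLogger` and `MiniRunner` are hypothetical names; the `__call__` body follows the one in the `Runner` class, and the callback deliberately does not inherit from any base class:

```python
class EpochLogger:
    # No Callback base class: only the methods it cares about are defined.
    _order = 0

    def __init__(self):
        self.seen = []

    def begin_epoch(self):
        self.seen.append("begin_epoch")  # returns None, i.e. "keep going"


class MiniRunner:
    def __init__(self, cbs):
        self.cbs = cbs

    def __call__(self, cb_name):
        for cb in sorted(self.cbs, key=lambda x: x._order):
            f = getattr(cb, cb_name, None)  # None if the method is missing
            if f and f():
                return True
        return False


cb = EpochLogger()
run = MiniRunner([cb])
stop_a = run('begin_epoch')  # method exists and returns None
stop_b = run('after_epoch')  # method missing, silently skipped
print(stop_a, stop_b, cb.seen)
```

Both calls return `False`, and only the event that the callback actually implements leaves a trace in `cb.seen`.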

This trick allows us to re-write our `TestCallback` in a much simpler way, because we can refer directly to things that are defined inside the runner, such as `self.n_iter`.

Technically we don't need to inherit from `Callback` any more, but we still do it for one detail: `_order`. This allows us to decide in which order we run our callbacks. Often you need things to run in a particular order. In the `__call__` method of the `Runner` there is the following line

```python
for cb in sorted(self.cbs, key=lambda x: x._order):
```

which sorts the callbacks by the `_order` attribute. For example, we may want to execute `TrainEvalCallback` (see below) first, assigning it `_order = 0`, and our `TestCallback` afterward, with `_order = 1`.
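
A small sketch of the ordering, with three hypothetical placeholder classes. Since Python's `sorted` is stable, callbacks that share the same `_order` keep the relative order they had in the list:

```python
class First:
    _order = 0


class Second:
    _order = 1


class AlsoFirst:
    _order = 0


cbs = [Second(), First(), AlsoFirst()]
ordered = [type(cb).__name__ for cb in sorted(cbs, key=lambda x: x._order)]
print(ordered)
```

`Second` moves to the end, while `First` and `AlsoFirst` stay in the order in which they were passed.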

You may have noticed that the `Runner` never calls `model.train()` or `model.eval()`. This is because there is a `TrainEvalCallback` that calls `model.train()` at the beginning of an epoch, when `begin_epoch()` is called, and `model.eval()` at the beginning of the validation phase, when `begin_validate()` is called. `TrainEvalCallback` also keeps track of the epoch expressed as a float (`n_epochs`), which is updated inside `after_batch()`, as well as the number of iterations (`n_iter`).
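
To make the fractional epoch counter concrete, the sketch below replays the bookkeeping of `TrainEvalCallback` in plain Python, with no model involved. The value `iters = 4` (batches per epoch) is an arbitrary assumption chosen for the example:

```python
iters = 4  # batches per epoch (assumed for this example)
history = []

n_epochs, n_iter = 0., 0  # what begin_fit does
for epoch in range(2):
    n_epochs = epoch  # what begin_epoch does
    for _ in range(iters):
        n_epochs += 1. / iters  # what after_batch does during training
        n_iter += 1
        history.append(n_epochs)

print(history)
print(n_iter)
```

After each training batch `n_epochs` advances by a quarter of an epoch, reaching `2.0` after the eighth batch, while `n_iter` counts the batches themselves.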

In [23]:
#export
class Runner():
    def __init__(self, cbs=None, cb_funcs=None):
        cbs = listify(cbs)
        for cbf in listify(cb_funcs):
            cb = cbf()
            setattr(self, cb.name, cb)
            cbs.append(cb)
        self.stop = False
        self.cbs = [TrainEvalCallback()] + cbs

    @property
    def opt(self):
        return self.learn.opt

    @property
    def model(self):
        return self.learn.model

    @property
    def loss_func(self):
        return self.learn.loss_func

    @property
    def data(self):
        return self.learn.data

    def one_batch(self, xb, yb):
        self.xb, self.yb = xb, yb
        if self('begin_batch'):
            return
        self.pred = self.model(self.xb)
        if self('after_pred'):
            return
        self.loss = self.loss_func(self.pred, self.yb)
        if self('after_loss') or not self.in_train:
            return
        self.loss.backward()
        if self('after_backward'):
            return
        self.opt.step()
        if self('after_step'):
            return
        self.opt.zero_grad()

    def all_batches(self, dl):
        self.iters = len(dl)
        for xb, yb in dl:
            if self.stop:
                break
            self.one_batch(xb, yb)
            self('after_batch')
        self.stop = False

    def fit(self, epochs, learn):
        self.epochs, self.learn = epochs, learn

        try:
            for cb in self.cbs:
                cb.set_runner(self)
            if self('begin_fit'):
                return
            for epoch in range(epochs):
                self.epoch = epoch
                if not self('begin_epoch'):
                    self.all_batches(self.data.train_dl)

                with torch.no_grad():
                    if not self('begin_validate'):
                        self.all_batches(self.data.valid_dl)
                if self('after_epoch'):
                    break

        finally:
            self('after_fit')
            self.learn = None

    def __call__(self, cb_name):
        for cb in sorted(self.cbs, key=lambda x: x._order):
            f = getattr(cb, cb_name, None)
            if f and f():
                return True
        return False

Third callback: how to compute metrics.

In [24]:
# export
class AvgStats():
    def __init__(self, metrics, in_train):
        self.metrics = listify(metrics)
        self.in_train = in_train

    def reset(self):
        self.tot_loss = 0.
        self.count = 0
        self.tot_mets = [0.] * len(self.metrics)

    @property
    def all_stats(self):
        return [self.tot_loss.item()] + self.tot_mets

    @property
    def avg_stats(self):
        return [o/self.count for o in self.all_stats]

    def __repr__(self):
        if not self.count:
            return ""
        return f"{'train' if self.in_train else 'valid'}: {self.avg_stats}"

    def accumulate(self, run):
        bn = run.xb.shape[0]
        self.tot_loss += run.loss * bn
        self.count += bn
        for i, m in enumerate(self.metrics):
            self.tot_mets[i] += m(run.pred, run.yb) * bn


class AvgStatsCallback(Callback):
    def __init__(self, metrics):
        self.train_stats = AvgStats(metrics, True)
        self.valid_stats = AvgStats(metrics, False)

    def begin_epoch(self):
        self.train_stats.reset()
        self.valid_stats.reset()

    def after_loss(self):
        stats = self.train_stats if self.in_train else self.valid_stats
        with torch.no_grad():
            stats.accumulate(self.run)

    def after_epoch(self):
        print(self.train_stats)
        print(self.valid_stats)
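
A note on why `accumulate` multiplies by the batch size `bn`: `F.cross_entropy` returns the *mean* loss over a batch by default, so batches of different sizes (typically the last one) should not count equally. The numbers below are made up purely to illustrate the difference between the weighted average used by `AvgStats` and a naive average of per-batch means:

```python
# (batch_size, mean_loss_of_batch) pairs; the last batch is smaller.
batches = [(64, 0.50), (64, 0.30), (16, 1.00)]

tot_loss, count = 0., 0
for bn, loss in batches:
    tot_loss += loss * bn  # same weighting as AvgStats.accumulate
    count += bn

weighted = tot_loss / count  # per-sample average
naive = sum(loss for _, loss in batches) / len(batches)  # per-batch average
print(weighted, naive)
```

The naive average over-weights the small last batch, whereas the weighted version equals the average loss per sample.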

In [25]:
learn = Learner(*get_model(data), loss_func, data)

In the code below, we create our `AvgStatsCallback` and then pass it to a `Runner`.

In [26]:
stats = AvgStatsCallback([accuracy])
run = Runner(cbs=stats)

In [27]:
run.fit(2, learn)

train: [0.3159608203125, tensor(0.9019)]
valid: [0.282188720703125, tensor(0.9173)]
train: [0.140204130859375, tensor(0.9574)]
valid: [0.12210206298828125, tensor(0.9634)]


In [28]:
loss, acc = stats.valid_stats.avg_stats
assert acc > 0.9
loss, acc

(0.12210206298828125, tensor(0.9634))

Another way of doing this is to create an accuracy **callback function**. We do this by using `partial` from `functools`.

In [29]:
# export
from functools import partial

The result of the operation below is a **function that can create a callback**.

In [30]:
acc_cbf = partial(AvgStatsCallback, accuracy)
type(acc_cbf)

functools.partial
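
A self-contained sketch of what `partial` does here, with `StatsCallback` as a hypothetical stand-in for `AvgStatsCallback` and a placeholder `accuracy` that is never actually called:

```python
from functools import partial


class StatsCallback:
    # Hypothetical stand-in for AvgStatsCallback.
    def __init__(self, metrics):
        self.metrics = metrics


def accuracy(pred, targ):  # placeholder metric, not called in this sketch
    return 0.0


make_cb = partial(StatsCallback, accuracy)  # nothing is constructed yet
cb1, cb2 = make_cb(), make_cb()             # each call builds a new callback
print(cb1 is cb2, cb1.metrics is accuracy)
```

The partial object only remembers the class and its argument; the callback itself is created later, when the runner calls `cbf()`.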

If we pass `acc_cbf` to `Runner` via `cb_funcs`, we don't need to keep our own reference to the stats callback anymore, because the runner creates the callback and stores it as an attribute when it runs this bit of code:

```python
class Runner():
    def __init__(self, cbs=None, cb_funcs=None):
        cbs = listify(cbs)
        for cbf in listify(cb_funcs):  # Go through each callback function
            cb = cbf()                 # Call the function and create the callback.
            setattr(self, cb.name, cb) # Save the callback as an attribute in the runner with name `cb.name`.
            cbs.append(cb)
        self.stop = False
        self.cbs = [TrainEvalCallback()] + cbs
```

In [31]:
run = Runner(cb_funcs=acc_cbf)

In [32]:
run.fit(1, learn)

train: [0.107763583984375, tensor(0.9668)]
valid: [0.2064674072265625, tensor(0.9432)]


By using `partial` we can now access the statistics from inside the runner. This is what fastai v1 does, with the difference that the statistics are saved in a `Learner` and not in a `Runner` (there is no `Runner` class in v1).

Using Jupyter means we can get tab-completion even for dynamic code like this.

In [33]:
run.avg_stats.valid_stats.avg_stats

[0.2064674072265625, tensor(0.9432)]

The approaches seen so far may look awkward at the beginning, but it's up to the user how deep s/he wants to go into the software engineering details of these solutions. The important bit to retain is summarized by this chunk of code:

```python
    def fit(self, epochs, learn):
        self.epochs, self.learn = epochs, learn

        try:
            for cb in self.cbs:
                cb.set_runner(self)
            if self('begin_fit'):
                return
            for epoch in range(epochs):
                self.epoch = epoch
                if not self('begin_epoch'):
                    self.all_batches(self.data.train_dl)

                with torch.no_grad():
                    if not self('begin_validate'):
                        self.all_batches(self.data.valid_dl)
                if self('after_epoch'):
                    break

        finally:
            self('after_fit')
            self.learn = None
```

## Export

In [34]:
!python notebook2script.py 04_callbacks.ipynb

Converted 04_callbacks.ipynb to exp/nb_04.py
