In this series, I want to discuss the creation of a small library for training neural networks: `nntrain`. It's based on the excellent [part 2](https://course.fast.ai/) of Practical Deep Learning for Coders by Jeremy Howard, in which the development of the `miniai` library is discussed in (roughly) lessons 13 to 18.

The library will build upon PyTorch. We'll try as much as possible to build from scratch to understand how it all works. Once the main functionality of a component is implemented and verified, we can switch over to PyTorch's version. This is similar to how things are done in the course. However, this is not just a "copy / paste" of the course: on many occasions I take a different route, and most of the code is my own. That is not to say that all of this is meant to be extremely innovative. Instead, I had the following goals:

- Deeply understand the training of neural networks with a focus on PyTorch
- Try to create an even better narrative than what's presented in FastAI 🙉🤷‍♂️🙈
- Get hands-on experience with creating a library with [`nb_dev`](https://nbdev.fast.ai/)

`nb_dev` is another great project from the fastai community, which allows Python libraries to be written in Jupyter notebooks. This may sound a bit weird, since the mainstream paradigm is to only do experimental work in notebooks. It has the advantage, though, that we can create the source code for our library in the very same environment in which we want to experiment and interact with our methods, objects and structure **while we are building the library**. For more details on why this is a good idea and other nice features of `nb_dev`, see [here](https://www.fast.ai/posts/2022-07-28-nbdev2.html).

So without further ado, let's start with where we left off in the previous [post](https://lucasvw.github.io/posts/08_nntrain_setup/):

## End of last post:

We finished the last post by exporting the `dataloaders` module into the `nntrain` library, which helps transform a Hugging Face dataset dictionary into PyTorch dataloaders, so let's use [that](https://lucasvw.github.io/nntrain/dataloaders.html):

In [None]:
from datasets import load_dataset,load_dataset_builder

from nntrain.dataloaders import DataLoaders, hf_ds_collate_fn

In [None]:
 #| export
import torchvision.transforms.functional as TF
import torch
import torch.nn as nn
import torch.nn.functional as F
from operator import attrgetter
import fastcore.all as fc

In [None]:
name = "fashion_mnist"
ds_builder = load_dataset_builder(name)
hf_dd = load_dataset(name)

bs = 1024
dls = DataLoaders.from_hf_dd(hf_dd, batch_size=bs)

# As a reminder, `DataLoaders` expose a PyTorch train and validation dataloader as `train` and `valid` attributes:

dls.train, dls.valid


(<torch.utils.data.dataloader.DataLoader>,
 <torch.utils.data.dataloader.DataLoader>)

## Learner Class

Let's continue to formalize our training loop into a `Learner` class with a `fit()` method. So far the training loop looks like:

```{.python code-line-numbers='true'}
def fit(epochs):
    for epoch in range(epochs):
        model.train()                                       
        n_t = train_loss_s = 0                              
        for xb, yb in dls.train:
            preds = model(xb)
            train_loss = loss_func(preds, yb)
            train_loss.backward()
            
            n_t += len(xb)
            train_loss_s += train_loss.item() * len(xb)
            
            opt.step()
            opt.zero_grad()
        
        model.eval()                                        
        n_v = valid_loss_s = acc_s = 0                      
        for xb, yb in dls.valid: 
            with torch.no_grad():                           
                preds = model(xb)
                valid_loss = loss_func(preds, yb)
                
                n_v += len(xb)
                valid_loss_s += valid_loss.item() * len(xb)
                acc_s += accuracy(preds, yb) * len(xb)
        
        train_loss = train_loss_s / n_t                     
        valid_loss = valid_loss_s / n_v
        acc = acc_s / n_v
        print(f'{epoch=} | {train_loss=:.3f} | {valid_loss=:.3f} | {acc=:.3f}')
```

Let's build this class in steps. Initialization is straightforward: pass in everything the class needs to have access to. Note that we pass in the optimizer class and instantiate it during initialization, so that we can hand it the model parameters and the learning rate.

In [None]:
class Learner():
    def __init__(self, model, dls, loss_fn, metric_fn, optim_class, lr):
        self.model = model
        self.dls = dls
        self.loss_fn = loss_fn
        self.metric_fn = metric_fn
        self.optim = optim_class(model.parameters(), lr)
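
Because the `Learner` only calls `optim_class(model.parameters(), lr)`, any callable with that calling convention should work. As a minimal sketch, assuming we want SGD with momentum (the model `toy_model` and the momentum value are just illustrations, not part of `nntrain`), the extra arguments can be pre-bound with `functools.partial`:

```python
from functools import partial

import torch
import torch.nn as nn

# Tiny stand-in model, only used to have some parameters to optimize
toy_model = nn.Linear(4, 2)

# Pre-bind extra keyword arguments; the result is still called as optim_class(params, lr)
sgd_momentum = partial(torch.optim.SGD, momentum=0.9)

# This mirrors what Learner.__init__ does with optim_class and lr
toy_optim = sgd_momentum(toy_model.parameters(), 0.01)

print(type(toy_optim).__name__)               # SGD
print(toy_optim.param_groups[0]['lr'])        # 0.01
print(toy_optim.param_groups[0]['momentum'])  # 0.9
```

The `Learner` itself does not need to know about momentum this way; it keeps receiving a single callable.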

Next, let's define the outermost call: `fit()`. The main improvement here is to call `one_epoch()` twice, once for training and once for validation. Both passes are fairly similar, as can be seen from comparing lines 3-8 with lines 16-21 of the training loop above.

In [None]:
@fc.patch
def fit(self:Learner, epochs):
    for epoch in range(epochs):                # iterate through the epochs
        self.one_epoch(epoch, train=True)      # one epoch through the training dataloader
        with torch.no_grad():                  # for the validation epoch we don't need grads
            self.one_epoch(epoch, train=False) # one epoch through the validation dataloader

Next, let's implement `one_epoch()`. To keep the functionality nice and small, we factor `one_batch()` out into its own method.

In [None]:
@fc.patch
def one_epoch(self:Learner, epoch, train):
    self.reset_stats()                         # reset the stats at beginning of each epoch
    self.model.train(train)                    # put the model either in train or validation mode
    self.dl = self.dls.train if train else self.dls.valid # reference to the active dataloader
    for self.batch in self.dl:                 # iterate through the active dataloader
        self.one_batch(train)                  # do one batch
    self.print_stats(epoch, train)             # print stats at the end of the epoch
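
The call `self.model.train(train)` relies on `nn.Module.train()` accepting a boolean, so a single line covers both modes. A small sketch with a stand-in model (`toy_model` is hypothetical and only serves to inspect the `training` flag):

```python
import torch.nn as nn

# Stand-in model with a Dropout layer, which behaves differently in train and eval mode
toy_model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.5))

toy_model.train(True)        # same as toy_model.train()
mode_after_train = toy_model.training

toy_model.train(False)       # same as toy_model.eval()
mode_after_eval = toy_model.training

print(mode_after_train, mode_after_eval)   # True False

# The flag is propagated to every submodule
all_in_eval = all(not m.training for m in toy_model.modules())
print(all_in_eval)                         # True
```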

And finally the method responsible for dealing with a single batch of data:

In [None]:
@fc.patch
def one_batch(self:Learner, train):
    self.xb, self.yb = self.batch
    self.preds = self.model(self.xb)           # forward pass through the model
    self.loss = self.loss_fn(self.preds, self.yb)  # loss
    if train:                                  # only do a backward and weight update if train
        self.loss.backward()
        self.optim.step()
        self.optim.zero_grad()
    self.update_stats()                        # update stats

We also add the methods related to the computation of the statistics:

In [None]:
@fc.patch
def update_stats(self:Learner):
    n = len(self.xb)
    self.loss_s += self.loss.item() * n
    self.metric_s += self.metric_fn(self.preds, self.yb).item() * n
    self.counter += n

@fc.patch
def reset_stats(self:Learner):
    self.counter = 0
    self.loss_s = 0
    self.metric_s = 0

@fc.patch
def print_stats(self:Learner, epoch, train):
    loss = self.loss_s / self.counter
    metric = self.metric_s / self.counter
    print(f'{epoch=:02d} | {"train" if train else "eval":<5} | {loss=:.3f} | {metric=:.3f}')
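
The reason for multiplying by `n` in `update_stats()` is that the last batch is usually smaller than the others, so a plain average of per-batch values would give its samples too much weight. A minimal sketch of the arithmetic with made-up numbers, assuming the loss function returns the mean over the batch (as `F.cross_entropy` does by default):

```python
# Hypothetical per-batch mean losses and batch sizes (the last batch is smaller)
batch_losses = [1.0, 1.0, 4.0]
batch_sizes = [100, 100, 10]

# What update_stats() accumulates: loss * n and n
loss_s = sum(loss * n for loss, n in zip(batch_losses, batch_sizes))
counter = sum(batch_sizes)

weighted_mean = loss_s / counter                     # mean over all samples
naive_mean = sum(batch_losses) / len(batch_losses)   # mean over batches

print(f'{weighted_mean=:.3f} | {naive_mean=:.3f}')
# weighted_mean=1.143 | naive_mean=2.000
```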

And let's do a round of training:

In [None]:
n_in = 28*28
n_h = 50
n_out = 10
lr = 0.01

def accuracy(preds, targs):
    return (preds.argmax(dim=1) == targs).float().mean()

layers = [nn.Linear(n_in, n_h), nn.ReLU(), nn.Linear(n_h, n_out)]
model = nn.Sequential(*layers)

l = Learner(model, dls, F.cross_entropy, accuracy, torch.optim.SGD, lr)

l.fit(5)

epoch=00 | train | loss=2.177 | metric=0.218
epoch=00 | eval  | loss=2.039 | metric=0.339
epoch=01 | train | loss=1.905 | metric=0.493
epoch=01 | eval  | loss=1.766 | metric=0.588
epoch=02 | train | loss=1.640 | metric=0.628
epoch=02 | eval  | loss=1.523 | metric=0.637
epoch=03 | train | loss=1.423 | metric=0.651
epoch=03 | eval  | loss=1.337 | metric=0.652
epoch=04 | train | loss=1.261 | metric=0.660
epoch=04 | eval  | loss=1.202 | metric=0.657
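
To make the metric concrete, here is the `accuracy()` function from above applied to a handful of made-up logits (the tensors are purely illustrative):

```python
import torch

def accuracy(preds, targs):
    return (preds.argmax(dim=1) == targs).float().mean()

# Hypothetical logits for 4 samples and 3 classes
preds = torch.tensor([[2.0, 0.1, 0.3],
                      [0.2, 1.5, 0.1],
                      [0.9, 0.2, 0.4],
                      [0.1, 0.3, 2.2]])
targs = torch.tensor([0, 1, 2, 2])

# Predicted classes are [0, 1, 0, 2], so 3 out of 4 are correct
acc = accuracy(preds, targs)
print(acc.item())   # 0.75
```

Note that the function returns a tensor, which is why `update_stats()` calls `.item()` on the result.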


## Training on the GPU

Now let's see if we can train this on the GPU instead of the CPU. For that we have to move all our tensors to the GPU, notably:

- the data tensors in the dataloaders
- all parameters of our model, i.e. the weight and bias tensors from each layer

Let's first define a variable that will represent whether we can train on the GPU or not:

In [None]:
 #| export
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

We can put the model (parameters) on the GPU when we instantiate it:

In [None]:
model = nn.Sequential(*layers).to(device)
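
A quick way to check where the parameters ended up is to inspect their `device` attribute. The sketch below uses a small stand-in model (`toy_model` is hypothetical) and falls back to the CPU when no GPU is available:

```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

toy_model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2)).to(device)

# All parameters (weights and biases) now live on the selected device
param_device_types = {p.device.type for p in toy_model.parameters()}
print(param_device_types)

# Inputs have to be on the same device as the parameters
xb = torch.randn(5, 4).to(device)
out = toy_model(xb)
print(out.shape)   # torch.Size([5, 2])
```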

To put all data tensors on the GPU we update our `one_batch()` method:

In [None]:
@fc.patch
def one_batch(self:Learner, train):
    self.xb, self.yb = map(lambda x: x.to(device), self.batch)  # move the batch to the device
    self.preds = self.model(self.xb)           
    self.loss = self.loss_fn(self.preds, self.yb)
    if train:                                  
        self.loss.backward()
        self.optim.step()
        self.optim.zero_grad()
    self.update_stats()

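We also have to make sure that when we store the loss and the metric, we detach them from the computational graph. Otherwise, PyTorch will keep a reference to the entire autograd history, which unnecessarily fills GPU memory. Strictly speaking, `.item()` already returns a plain Python number, so the explicit `.detach().cpu()` mainly makes the intent visible: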

In [None]:
@fc.patch
def update_stats(self:Learner):
    n = len(self.xb)
    self.loss_s += self.loss.detach().cpu().item() * n
    self.metric_s += self.metric_fn(self.preds, self.yb).detach().cpu().item() * n
    self.counter += n
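
The difference between a loss tensor that is still attached to the graph and the value we accumulate can be inspected directly. A small sketch with random stand-in data (the tensors are illustrative, not the Fashion MNIST batches):

```python
import torch
import torch.nn.functional as F

preds = torch.randn(8, 10, requires_grad=True)   # stand-in for model outputs
targs = torch.randint(0, 10, (8,))
loss = F.cross_entropy(preds, targs)

print(loss.grad_fn is not None)        # True: the loss tensor is attached to the graph
print(loss.detach().grad_fn is None)   # True: the detached tensor is not

# What update_stats() accumulates is a plain Python float
loss_value = loss.detach().cpu().item()
print(type(loss_value))                # <class 'float'>
```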

## Callbacks, pubsub and event handlers

On the one hand, we want to keep the `Learner` and its training loop generic; on the other hand, we need to be able to tweak the dynamics of the training loop depending on the use case. One way to customize the training loop without having to rewrite it is to add a publish/subscribe (pubsub) mechanism. In the FastAI course these are referred to as "callbacks". Although callbacks, event handlers and pubsub are all related, I think the mechanism implemented here is best described as pubsub. It can be compared to the way front-end development works: whenever the user takes an action, such as clicking or hovering over a button, certain events are **published**. The developer can **subscribe** to these events by adding a function (a **callback** or **event handler**) that gets called whenever they occur.

For the purposes of training neural networks we have the following requirements:

- The Learner framework defines a number of "events" that are published:
  - `before_fit`, `after_fit`
  - `before_epoch`, `after_epoch`
  - `before_batch`, `after_batch`
- Subscribers are classes that implement methods (e.g. `before_fit()`) that will be triggered whenever the associated event is published. They also have an `order` attribute which determines the order in which they are called in case multiple Subscribers subscribed to the same event.
- As an additional feature, subscribers will be able to redirect flow, but we will come back to that later
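
Before turning to the `Learner`, the front-end analogy can be sketched in a few lines of plain Python. Everything here (`EventBus`, the `click` event) is hypothetical and only illustrates the publish/subscribe idea; it is not part of `nntrain`:

```python
from collections import defaultdict

class EventBus:
    """Hypothetical minimal pubsub: maps event names to handler functions."""
    def __init__(self):
        self.handlers = defaultdict(list)

    def subscribe(self, event, handler):
        self.handlers[event].append(handler)

    def publish(self, event, *args):
        # Call every handler that subscribed to this event, in subscription order
        for handler in self.handlers[event]:
            handler(*args)

log = []
bus = EventBus()
bus.subscribe('click', lambda target: log.append(f'clicked {target}'))
bus.subscribe('click', lambda target: log.append(f'tracked click on {target}'))

bus.publish('click', 'button')
bus.publish('hover', 'button')   # nobody subscribed, so nothing happens

print(log)   # ['clicked button', 'tracked click on button']
```

The publisher does not need to know who is listening, which is the property we want for the training loop.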

So let's implement this. First, we will need to store subscribers in the Learner class:

In [None]:
class Learner():
    def __init__(self, model, dls, loss_fn, metric_fn, optim_class, lr, subs):
        self.model = model
        self.dls = dls
        self.loss_fn = loss_fn
        self.metric_fn = metric_fn
        self.optim = optim_class(model.parameters(), lr)
        self.subs = subs

Next, let's define a method for publishing events. The method will go through the registered subscribers and if a method with the name of the event is declared, call that method passing the `learner` object as an argument:

In [None]:
@fc.patch
def publish(self:Learner, event):
    for sub in sorted(self.subs, key=attrgetter('order')):
        method = getattr(sub, event, None)
        if method is not None: method(self)

To publish the events during the training loop, note that we have the same construct three times:

```
publish "before_event" event
do event
publish "after_event" event
```

With `event` being either `fit`, `epoch` or `batch`. So instead of adding this construct multiple times in the training loop, let's define a class we can use as a decorator wrapping the actual "event":

In [None]:
 #| export

class PublishEvents():
    def __init__(self, event): 
        self.event = event
    
    def __call__(self, decorated_fn):
        def decorated_fn_with_publishing(learner, *args, **kwargs):
            learner.publish(f'before_{self.event}')
            decorated_fn(learner, *args, **kwargs)
            learner.publish(f'after_{self.event}')
        return decorated_fn_with_publishing
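
To check the order in which events get published, the decorator can be tried on a stand-in learner that simply records every event. `ToyLearner` is hypothetical and its `publish()` appends to a list instead of notifying subscribers:

```python
class PublishEvents():
    def __init__(self, event):
        self.event = event

    def __call__(self, decorated_fn):
        def decorated_fn_with_publishing(learner, *args, **kwargs):
            learner.publish(f'before_{self.event}')
            decorated_fn(learner, *args, **kwargs)
            learner.publish(f'after_{self.event}')
        return decorated_fn_with_publishing

class ToyLearner:
    """Stand-in learner that records published events."""
    def __init__(self): self.log = []

    def publish(self, event): self.log.append(event)

    @PublishEvents('fit')
    def fit(self, epochs):
        for epoch in range(epochs):
            self.one_epoch(epoch)

    @PublishEvents('epoch')
    def one_epoch(self, epoch):
        self.log.append(f'work for epoch {epoch}')

learn = ToyLearner()
learn.fit(2)
print(learn.log)
# ['before_fit', 'before_epoch', 'work for epoch 0', 'after_epoch',
#  'before_epoch', 'work for epoch 1', 'after_epoch', 'after_fit']
```

One thing to be aware of: the wrapper does not return the result of the decorated function, which is fine here because none of the training methods return anything.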

To integrate this into the `Learner`, we have to factor out exactly the code we want to be executed between the publishing of the `before` and `after` events; see the additional `_one_epoch()` method:

In [None]:
 #| export
class Learner():
    def __init__(self, model, dls, loss_fn, metric_fn, optim_class, lr, subs):
        self.model = model
        self.dls = dls
        self.loss_fn = loss_fn
        self.metric_fn = metric_fn
        self.optim = optim_class(model.parameters(), lr)
        self.subs = subs
    
    @PublishEvents('fit')
    def fit(self, epochs):
        for epoch in range(epochs):
            self.one_epoch(epoch, train=True)
            with torch.no_grad():
                self.one_epoch(epoch, train=False)

    def one_epoch(self, epoch, train):
        self.reset_stats()
        self.model.train(train)
        self.dl = self.dls.train if train else self.dls.valid
        self._one_epoch(epoch, train)
        self.print_stats(epoch, train)
        
    @PublishEvents('epoch')
    def _one_epoch(self, epoch, train):
        for self.batch in self.dl:
            self.xb, self.yb = map(lambda x: x.to(device), self.batch)
            self.one_batch(train)
    
    @PublishEvents('batch')
    def one_batch(self, train):
        self.preds = self.model(self.xb)           
        self.loss = self.loss_fn(self.preds, self.yb)
        if train:                                  
            self.loss.backward()
            self.optim.step()
            self.optim.zero_grad()
        self.update_stats()
        
    def publish(self, event):
        for sub in sorted(self.subs, key=attrgetter('order')):
            method = getattr(sub, event, None)
            if method is not None: method(self)
            
    def update_stats(self):
        n = len(self.xb)
        self.loss_s += self.loss.detach().cpu().item() * n
        self.metric_s += self.metric_fn(self.preds, self.yb).detach().cpu().item() * n
        self.counter += n

    def reset_stats(self):
        self.counter = 0
        self.loss_s = 0
        self.metric_s = 0

    def print_stats(self, epoch, train):
        loss = self.loss_s / self.counter
        metric = self.metric_s / self.counter
        print(f'{epoch=:02d} | {"train" if train else "eval":<5} | {loss=:.3f} | {metric=:.3f}')

Let's create a dummy subscriber and test it out:

In [None]:
class Subscriber():
    order = 0

class DummySub(Subscriber):
    
    def before_fit(self, learn):
        print('before fit👋')
        
    def after_fit(self, learn):
        print('after fit👋')
        
    def before_epoch(self, learn):
        print('before epoch 💥')
        
    def after_epoch(self, learn):
        print('after epoch 💥')

l = Learner(model, dls, F.cross_entropy, accuracy, torch.optim.SGD, lr, [DummySub()])
l.fit(1)

before fit👋
before epoch 💥
after epoch 💥
epoch=00 | train | loss=1.146 | metric=0.664
before epoch 💥
after epoch 💥
epoch=00 | eval  | loss=1.106 | metric=0.657
after fit👋


Nice! Now let's add the last component of our pubsub system: subscribers should be able to cancel processing of certain events. For example, a subscriber that implements Early Stopping has to be able to cancel any further epochs when the validation loss starts increasing. One way to implement this is with the help of exceptions and `try` / `except` blocks.

This logic is easy to implement: we only need to define custom exceptions and update the `PublishEvents` class we are using as a decorator to catch the exceptions that are raised in any subscriber:

In [None]:
class CancelFitException(Exception): pass
class CancelEpochException(Exception): pass
class CancelBatchException(Exception): pass


class PublishEvents():
    def __init__(self, name): 
        self.name = name
    
    def __call__(self, decorated_fn):
        def decorated_fn_with_publishing(learner, *args, **kwargs):
            try:
                learner.publish(f'before_{self.name}')
                decorated_fn(learner, *args, **kwargs)
                learner.publish(f'after_{self.name}')
            except globals()[f'Cancel{self.name.title()}Exception']: pass
        return decorated_fn_with_publishing
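
Each decorator only catches the exception that belongs to its own event, so a `CancelEpochException` skips a single epoch while the fit continues. The sketch below illustrates this with a stand-in learner. Two things are assumptions made for the illustration: `ToyLearner` raises the exception from `publish()` itself instead of from a real subscriber, and the exception class is looked up in an explicit dictionary rather than through `globals()`:

```python
class CancelFitException(Exception): pass
class CancelEpochException(Exception): pass

# Explicit lookup table, a variation on the globals() lookup
CANCEL_EXCEPTIONS = {'fit': CancelFitException, 'epoch': CancelEpochException}

class PublishEvents():
    def __init__(self, name): self.name = name

    def __call__(self, decorated_fn):
        def decorated_fn_with_publishing(learner, *args, **kwargs):
            try:
                learner.publish(f'before_{self.name}')
                decorated_fn(learner, *args, **kwargs)
                learner.publish(f'after_{self.name}')
            except CANCEL_EXCEPTIONS[self.name]: pass
        return decorated_fn_with_publishing

class ToyLearner:
    def __init__(self, skip_epoch): self.skip_epoch, self.log = skip_epoch, []

    def publish(self, event):
        self.log.append(event)
        # Stand-in for a subscriber that cancels one specific epoch
        if event == 'before_epoch' and self.epoch == self.skip_epoch:
            raise CancelEpochException

    @PublishEvents('fit')
    def fit(self, epochs):
        for self.epoch in range(epochs): self.one_epoch()

    @PublishEvents('epoch')
    def one_epoch(self): self.log.append(f'epoch {self.epoch} done')

learn = ToyLearner(skip_epoch=0)
learn.fit(2)
print(learn.log)
# ['before_fit', 'before_epoch', 'before_epoch', 'epoch 1 done', 'after_epoch', 'after_fit']
```

Epoch 0 is cancelled before any work is done and its `after_epoch` is never published, while epoch 1 and the `after_fit` event run as usual.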

In [None]:
class DummySub(Subscriber):
    
    def before_fit(self, learn): print('before fit👋')
        
    def before_epoch(self, learn): raise CancelFitException
    
    def after_fit(self, learn): print('after fit 👋')

In [None]:
 #| export
    
class Subscriber():
    order = 0

class Learner():
    def __init__(self, model, dls, loss_fn, metric_fn, optim_class, lr, subs):
        self.model = model
        self.dls = dls
        self.loss_fn = loss_fn
        self.metric_fn = metric_fn
        self.optim = optim_class(model.parameters(), lr)
        self.subs = subs
    
    @PublishEvents('fit')
    def fit(self, epochs):
        for epoch in range(epochs):
            self.one_epoch(epoch, train=True)
            with torch.no_grad():
                self.one_epoch(epoch, train=False)

    def one_epoch(self, epoch, train):
        self.reset_stats()
        self.model.train(train)
        self.dl = self.dls.train if train else self.dls.valid
        self._one_epoch(epoch, train)
        
    @PublishEvents('epoch')
    def _one_epoch(self, epoch, train):
        for self.batch in self.dl:
            self.xb, self.yb = map(lambda x: x.to(device), self.batch)
            self.one_batch(train)
        self.print_stats(epoch, train)
    
    @PublishEvents('batch')
    def one_batch(self, train):
        self.preds = self.model(self.xb)           
        self.loss = self.loss_fn(self.preds, self.yb)
        if train:                                  
            self.loss.backward()
            self.optim.step()
            self.optim.zero_grad()
        self.update_stats()
        
    def publish(self, event):
        for sub in sorted(self.subs, key=attrgetter('order')):
            method = getattr(sub, event, None)
            if method is not None: method(self)
            
    def update_stats(self):
        n = len(self.xb)
        self.loss_s += self.loss.detach().cpu().item() * n
        self.metric_s += self.metric_fn(self.preds, self.yb).detach().cpu().item() * n
        self.counter += n

    def reset_stats(self):
        self.counter = 0
        self.loss_s = 0
        self.metric_s = 0

    def print_stats(self, epoch, train):
        loss = self.loss_s / self.counter
        metric = self.metric_s / self.counter
        print(f'{epoch=:02d} | {"train" if train else "eval":<5} | {loss=:.3f} | {metric=:.3f}')

In [None]:
l = Learner(model, dls, F.cross_entropy, accuracy, torch.optim.SGD, lr, [DummySub()])
l.fit(5)

before fit👋


And indeed, the `after_fit` event is never published, since the fit was cancelled during `before_epoch`.
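
Coming back to the Early Stopping example mentioned earlier, a subscriber along these lines could look as follows. This is only a sketch under assumptions: `EarlyStoppingSub` and its `patience` parameter are hypothetical, it reads `loss_s` and `counter` from the learner as defined above, and it is exercised here with stand-in objects rather than a real training run:

```python
from types import SimpleNamespace

class CancelFitException(Exception): pass

class EarlyStoppingSub:
    """Hypothetical subscriber: cancel the fit once the validation loss stops improving."""
    order = 0

    def __init__(self, patience=1):
        self.patience, self.best, self.bad_epochs = patience, float('inf'), 0

    def after_epoch(self, learn):
        if learn.model.training: return            # only look at validation epochs
        valid_loss = learn.loss_s / learn.counter  # same average that print_stats() reports
        if valid_loss < self.best:
            self.best, self.bad_epochs = valid_loss, 0
        else:
            self.bad_epochs += 1
            if self.bad_epochs >= self.patience: raise CancelFitException

# Exercise the subscriber with stand-in learner objects and made-up validation losses
sub = EarlyStoppingSub(patience=1)
stopped_at = None
for epoch, loss in enumerate([1.0, 0.8, 0.9, 0.7]):
    fake_learn = SimpleNamespace(model=SimpleNamespace(training=False),
                                 loss_s=loss * 10, counter=10)
    try:
        sub.after_epoch(fake_learn)
    except CancelFitException:
        stopped_at = epoch
        break

print(stopped_at)   # 2
```

With the `Learner` as defined above, raising `CancelFitException` from `after_epoch` would pass through the epoch decorator and be caught by the one around `fit()`, which also means `after_fit` would not be published in that case.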