In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import warnings
# Silence every warning for the rest of the session. This keeps the notebook output clean,
# but it also hides deprecation/runtime warnings -- consider narrowing the filter.
warnings.filterwarnings('ignore')

In [3]:
#export
from lib.nb_01 import *

In [4]:
#export
import torch.nn as nn
from torch import optim

In [5]:
#export
import torch.nn.functional as F
from typing import *
from functools import partial
import time

In [6]:
#!conda install -c fastai fastprogress --yes

In [7]:
#export
from fastprogress import master_bar, progress_bar
from fastprogress.fastprogress import format_time

A simple deep-learning training pipeline looks something like this:

### Data

In [8]:
# Load the MNIST train/validation splits (get_mnist presumably comes from the lib.nb_01 star import).
# Note: the trailing ';' has no effect on an assignment.
x_train, y_train, x_valid, y_valid = get_mnist();

In [9]:
def normalize_to(x_train, x_valid):
    """Normalize both splits using statistics computed on the training split only.

    Reusing the training mean/std for the validation split keeps validation data from
    influencing the preprocessing.

    Returns:
        (normalized_train, normalized_valid) as a tuple.
    """
    train_mean = x_train.mean()
    train_std = x_train.std()
    return tuple(normalize(split, train_mean, train_std) for split in (x_train, x_valid))

In [10]:
# Replace the raw arrays with normalized ones (training-set mean/std applied to both splits).
x_train, x_valid = normalize_to(x_train, x_valid)

In [11]:
#export
class Databunch:
    """Bundle of training/validation dataloaders plus the model's input and output sizes.

    Args:
        train_dl: dataloader for the training split (must expose `.dataset`).
        valid_dl: dataloader for the validation split (must expose `.dataset`).
        c_in: number of input features.
        c_out: number of outputs.
    """

    def __init__(self, train_dl, valid_dl, c_in, c_out):
        self.train_dl = train_dl
        self.valid_dl = valid_dl
        self.c_in = c_in
        self.c_out = c_out

    @property
    def train_ds(self):
        """Dataset wrapped by the training dataloader."""
        loader = self.train_dl
        return loader.dataset

    @property
    def valid_ds(self):
        """Dataset wrapped by the validation dataloader."""
        loader = self.valid_dl
        return loader.dataset

In [12]:
# Pair inputs with labels for each split (Dataset presumably comes from the lib.nb_01 star import).
train_ds, valid_ds = Dataset(x_train, y_train), Dataset(x_valid, y_valid)

In [13]:
# Mini-batch size passed to get_dls.
bs = 64

In [14]:
# num_workers=6 is a machine-dependent choice (presumably forwarded to the DataLoader workers)
# -- TODO confirm; lower it on machines with fewer cores.
train_dl, valid_dl = get_dls(train_ds, valid_ds, bs=bs, num_workers=6)

In [15]:
# c_in: input features per example (28*28 -- assumes get_mnist yields flattened images);
# c_out: number of classes (10 digits).
data = Databunch(train_dl, valid_dl, c_in=28*28, c_out=10)

In [16]:
# Sanity check: input and output sizes stored on the Databunch.
data.c_in, data.c_out

(784, 10)

### Model

In [17]:
def get_model(data, lr=0.1, nh=200):
    """Build a one-hidden-layer MLP sized from `data`; return it with the optimizer class.

    Args:
        data: object exposing `c_in` (number of input features) and `c_out` (number of outputs).
        lr: unused -- kept only for backward compatibility. The learning rate that is actually
            used is whatever the optimizer's caller passes (e.g. `Learner(..., lr=...)`).
        nh: width of the hidden layer (previously hard-coded to 200).

    Returns:
        (model, opt_func): an `nn.Sequential` (Linear -> ReLU -> Linear) and the `optim.SGD`
        class itself (not an instance), to be called later with the model's parameters.
    """
    model = nn.Sequential(nn.Linear(data.c_in, nh), nn.ReLU(), nn.Linear(nh, data.c_out))
    return model, optim.SGD

In [18]:
# NOTE: get_model does not use `lr`; the Learner constructed later uses its own `lr` argument
# (default 1e-2), so 0.1 here has no effect on training.
model, opt_func = get_model(data, lr=0.1)

In [19]:
# Display the model architecture.
model

Sequential(
  (0): Linear(in_features=784, out_features=200, bias=True)
  (1): ReLU()
  (2): Linear(in_features=200, out_features=10, bias=True)
)

### Callbacks

```python
def fit():
    for epoch in range(epochs):
        for xb,yb in train_dl:
            pred = model(xb)
            loss = loss_func(pred, yb)
            loss.backward()
            opt.step()
            opt.zero_grad()
```

We want to be able to write a flexible training loop so that we can inject behaviour during training.

- callbacks
- learner
- recorder
- avgstats
- cuda

In [20]:
#export
import re

_camel_re1 = re.compile('(.)([A-Z][a-z]+)')
_camel_re2 = re.compile('([a-z0-9])([A-Z])')
def camel2snake(name):
    """Convert a CamelCase identifier to snake_case.

    Example: 'TrainEvalCallback' -> 'train_eval_callback', 'HTTPResponse' -> 'http_response'.
    """
    # Pass 1 inserts '_' before each capitalized word; pass 2 splits a lowercase letter/digit
    # from a following uppercase letter (catches words pass 1 skipped due to overlapping matches).
    snake = _camel_re1.sub(r'\1_\2', name)
    snake = _camel_re2.sub(r'\1_\2', snake)
    return snake.lower()

In [21]:
#export
class Callback():
    """Base class for training-loop callbacks.

    Subclasses implement methods named after runner events (e.g. `begin_fit`, `after_batch`).
    Attribute lookups that fail on the callback itself are forwarded to the runner installed by
    `set_runner`, so handlers can read runner state such as `self.model` or `self.epoch`.
    """
    _order=0  # runners dispatch callbacks sorted by ascending _order
    def set_runner(self, run): self.run=run

    def __getattr__(self, k):
        # Fetch `run` with object.__getattribute__ so that a callback without a runner raises
        # AttributeError instead of re-entering __getattr__('run') forever (RecursionError).
        # This keeps hasattr(), getattr(..., default) and copy.deepcopy working when unattached.
        try: run = object.__getattribute__(self, 'run')
        except AttributeError: raise AttributeError(k) from None
        return getattr(run, k)

    @property
    def name(self):
        """Snake-case class name with any trailing 'Callback' removed (falls back to 'callback')."""
        name = re.sub(r'Callback$', '', self.__class__.__name__)
        return camel2snake(name or 'callback')

    def __call__(self, cb_name):
        """Run the handler named `cb_name` if one is found (on the callback, else on the runner).

        Returns True only when the handler returned a truthy value, otherwise False.
        """
        f = getattr(self, cb_name, None)
        if f and f(): return True
        return False

class TrainEvalCallback(Callback):
    """Keeps the runner's progress counters and train/eval mode in sync with the loop.

    Maintains on the runner:
      - `n_epochs`: epoch counter, reset to the epoch index at each epoch start and advanced by
        1/iters after every training batch.
      - `n_iter`: number of training batches processed since `begin_fit`.
      - `in_train`: True during the training phase, False during validation.
    """
    def begin_fit(self):
        runner = self.run
        runner.n_epochs = 0.
        runner.n_iter = 0

    def after_batch(self):
        runner = self.run
        # Only training batches advance the counters.
        if runner.in_train:
            runner.n_epochs += 1. / self.iters
            runner.n_iter += 1

    def begin_epoch(self):
        runner = self.run
        runner.n_epochs = self.epoch
        self.model.train()
        runner.in_train = True

    def begin_validate(self):
        runner = self.run
        self.model.eval()
        runner.in_train = False

class CancelTrainException(Exception):
    """Raise from a callback to abort the remaining epochs of `fit` (after_fit still runs)."""

class CancelEpochException(Exception):
    """Raise from a callback to stop iterating over the current dataloader."""

class CancelBatchException(Exception):
    """Raise from a callback to skip the rest of the current batch (after_batch still runs)."""

In [22]:
#export
def param_getter(m):
    """Default `splitter`: hand every parameter of module `m` to the optimizer."""
    params = m.parameters()
    return params

In [23]:
# First tensor yielded by param_getter: the weight of the first nn.Linear (200 x 784).
list(param_getter(model))[0]

Parameter containing:
tensor([[-0.0323,  0.0097,  0.0147,  ..., -0.0096,  0.0028, -0.0228],
        [-0.0126,  0.0237, -0.0166,  ...,  0.0068,  0.0159, -0.0200],
        [ 0.0109, -0.0226,  0.0292,  ..., -0.0095, -0.0134,  0.0350],
        ...,
        [ 0.0325, -0.0256,  0.0106,  ..., -0.0240,  0.0055, -0.0149],
        [-0.0326,  0.0058,  0.0327,  ..., -0.0223, -0.0310,  0.0105],
        [-0.0262, -0.0265, -0.0114,  ..., -0.0346,  0.0191, -0.0156]],
       requires_grad=True)

In [24]:
#export
def listify(o):
    """Coerce `o` into a list.

    - `None` -> `[]`
    - a list -> returned unchanged (the same object, not a copy)
    - `str` / `bytes` -> wrapped as a single element (not split into characters / ints)
    - any other iterable (tuple, generator, ...) -> `list(o)`
    - anything else -> `[o]`
    """
    if o is None: return []
    if isinstance(o, list): return o
    # Text and byte strings are iterable but represent single values here.
    if isinstance(o, (str, bytes)): return [o]
    if isinstance(o, Iterable): return list(o)
    return [o]

In [25]:
#export
class AvgStats():
    """Batch-size-weighted running sums of the loss and metrics for one phase (train or valid).

    Args:
        metrics: a callable or list of callables, each called as `m(pred, yb)`.
        in_train: True for the training phase, False for validation (only affects `__repr__`).
    """
    def __init__(self, metrics, in_train): self.metrics,self.in_train = listify(metrics),in_train

    def reset(self):
        """Zero the running sums and sample count."""
        self.tot_loss,self.count = 0.,0
        self.tot_mets = [0.] * len(self.metrics)

    @property
    def all_stats(self):
        """Summed loss (as a Python float) followed by the summed metrics."""
        # float() accepts both the plain 0. left by reset() and the 0-d tensor produced by
        # accumulate(); `.item()` raised AttributeError when nothing had been accumulated.
        return [float(self.tot_loss)] + self.tot_mets

    @property
    def avg_stats(self):
        """Per-sample averages of `all_stats`; NaN for every entry if no samples were seen."""
        # Avoid ZeroDivisionError when a phase had no batches (e.g. validation was skipped).
        if not self.count: return [float('nan') for _ in self.all_stats]
        return [o/self.count for o in self.all_stats]

    def __repr__(self):
        if not self.count: return ""
        return f"{'train' if self.in_train else 'valid'}: {self.avg_stats}"

    def accumulate(self, run):
        """Add the runner's current batch loss and metrics, weighted by the batch size."""
        bn = run.xb.shape[0]  # samples in this batch (first dimension of xb)
        self.tot_loss += run.loss * bn
        self.count += bn
        for i,m in enumerate(self.metrics):
            self.tot_mets[i] += m(run.pred, run.yb) * bn

            
class AvgStatsCallback(Callback):
    """Tracks average loss/metrics for the training and validation phases; logs one row per epoch.

    `metrics` is a callable or list of callables taking `(pred, yb)`; each one's `__name__`
    becomes part of a column header.
    """
    def __init__(self, metrics):
        self.train_stats = AvgStats(metrics, True)
        self.valid_stats = AvgStats(metrics, False)

    def begin_fit(self):
        col_names = ['loss', *(metric.__name__ for metric in self.train_stats.metrics)]
        header = ['epoch']
        header += [f'train_{n}' for n in col_names]
        header += [f'valid_{n}' for n in col_names]
        header.append('time')
        self.logger(header)

    def begin_epoch(self):
        for phase_stats in (self.train_stats, self.valid_stats):
            phase_stats.reset()
        self.start_time = time.time()

    def after_loss(self):
        if self.in_train: current = self.train_stats
        else:             current = self.valid_stats
        with torch.no_grad():
            current.accumulate(self.run)

    def after_epoch(self):
        row = [str(self.epoch)]
        for phase_stats in (self.train_stats, self.valid_stats):
            row.extend(f'{v:.6f}' for v in phase_stats.avg_stats)
        row.append(format_time(time.time() - self.start_time))
        self.logger(row)

In [26]:
#export
class ProgressCallback(Callback):
    """Shows fastprogress bars for epochs and batches and routes the runner's logger to them."""
    # _order=-1 sorts this ahead of default-order (0) callbacks, so the logger is replaced in
    # begin_fit before other begin_fit handlers (e.g. AvgStatsCallback) log their header row.
    _order=-1
    def begin_fit(self):
        # One master bar over all epochs; the runner's logger now writes table rows into it.
        self.mbar = master_bar(range(self.epochs))
        self.mbar.on_iter_begin()
        self.run.logger = partial(self.mbar.write, table=True)
        
    def after_fit(self): self.mbar.on_iter_end()
    # self.iter is the 0-based batch index; after_batch fires for training and validation batches.
    def after_batch(self): self.pb.update(self.iter)
    def begin_epoch   (self): self.set_pb()
    def begin_validate(self): self.set_pb()
        
    def set_pb(self):
        # Called at the start of both phases; self.dl is whichever dataloader the runner selected.
        # auto_update=False: the bar is not iterated here, it is advanced manually in after_batch.
        self.pb = progress_bar(self.dl, parent=self.mbar, auto_update=False)
        self.mbar.update(self.epoch)

### Learner

In [27]:
#export

class Learner():
    """Callback-driven training loop.

    Holds the model, data, loss function and optimizer factory, and runs `fit` as a sequence of
    named events (see `ALL_CBS`). Each event is dispatched to every attached callback in ascending
    `_order`. If any callback returns True for 'begin_epoch' or 'begin_validate', the batches of
    that phase are skipped.
    """
    def __init__(self, model, data, loss_func, opt_func=optim.SGD, lr=1e-2, splitter=param_getter,
                 cbs=None, cb_funcs=None):
        """
        Args:
            model: module to train.
            data: object exposing `train_dl` and `valid_dl`.
            loss_func: called as `loss_func(pred, yb)`.
            opt_func: optimizer factory, called as `opt_func(splitter(model), lr=lr)` in `fit`.
            lr: learning rate handed to `opt_func`.
            splitter: maps the model to the parameters given to the optimizer.
            cbs: callback instance(s) attached for the lifetime of the learner.
            cb_funcs: zero-argument callables; each one's result is attached as a callback.
        """
        self.model,self.data,self.loss_func,self.opt_func,self.lr,self.splitter = model,data,loss_func,opt_func,lr,splitter
        self.in_train,self.logger,self.opt = False,print,None

        # NB: Things marked "NEW" are covered in lesson 12
        # NEW: avoid need for set_runner
        self.cbs = []
        self.add_cb(TrainEvalCallback())
        self.add_cbs(cbs)
        self.add_cbs(cbf() for cbf in listify(cb_funcs))

    def add_cbs(self, cbs):
        """Attach every callback in `cbs` (None, a single callback, or an iterable of them)."""
        for cb in listify(cbs): self.add_cb(cb)

    def add_cb(self, cb):
        """Attach `cb`: install this learner as its runner and expose it as `self.<cb.name>`."""
        cb.set_runner(self)
        setattr(self, cb.name, cb)
        self.cbs.append(cb)

    def remove_cbs(self, cbs):
        """Detach callbacks from dispatch (the `self.<cb.name>` attribute is left in place)."""
        for cb in listify(cbs): self.cbs.remove(cb)

    def one_batch(self, i, xb, yb):
        """Forward pass on batch `i`; when training, also backward pass and optimizer step."""
        try:
            self.iter = i
            self.xb,self.yb = xb,yb;                        self('begin_batch')
            self.pred = self.model(self.xb);                self('after_pred')
            self.loss = self.loss_func(self.pred, self.yb); self('after_loss')
            # Validation stops here; the finally clause still fires 'after_batch'.
            if not self.in_train: return
            self.loss.backward();                           self('after_backward')
            self.opt.step();                                self('after_step')
            self.opt.zero_grad()
        except CancelBatchException:                        self('after_cancel_batch')
        finally:                                            self('after_batch')

    def all_batches(self):
        """Run `one_batch` over `self.dl`; a CancelEpochException ends the loop early."""
        self.iters = len(self.dl)
        try:
            for i,(xb,yb) in enumerate(self.dl): self.one_batch(i, xb, yb)
        except CancelEpochException: self('after_cancel_epoch')

    def do_begin_fit(self, epochs):
        """Record the epoch count, reset `self.loss` and fire 'begin_fit'."""
        self.epochs,self.loss = epochs,tensor(0.)
        self('begin_fit')

    def do_begin_epoch(self, epoch):
        """Select the training dataloader and fire 'begin_epoch'; True means skip training batches."""
        self.epoch,self.dl = epoch,self.data.train_dl
        return self('begin_epoch')

    def fit(self, epochs, cbs=None, reset_opt=False):
        """Train for `epochs` epochs, each followed by a validation pass.

        Args:
            epochs: number of epochs to run.
            cbs: extra callback(s) attached only for this call and removed afterwards.
            reset_opt: build a fresh optimizer even if one already exists.
        """
        # NEW: pass callbacks to fit() and have them removed when done
        self.add_cbs(cbs)
        # NEW: create optimizer on fit(), optionally replacing existing
        if reset_opt or not self.opt: self.opt = self.opt_func(self.splitter(self.model), lr=self.lr)

        try:
            self.do_begin_fit(epochs)
            for epoch in range(epochs):
                # do_begin_epoch already fires 'begin_epoch'; firing it a second time here ran
                # every begin_epoch handler twice per epoch (e.g. duplicate progress bars).
                if not self.do_begin_epoch(epoch): self.all_batches()

                with torch.no_grad():
                    self.dl = self.data.valid_dl
                    if not self('begin_validate'): self.all_batches()
                self('after_epoch')

        except CancelTrainException: self('after_cancel_train')
        finally:
            self('after_fit')
            self.remove_cbs(cbs)

    ALL_CBS = {'begin_batch', 'after_pred', 'after_loss', 'after_backward', 'after_step',
        'after_cancel_batch', 'after_batch', 'after_cancel_epoch', 'begin_fit',
        'begin_epoch', 'begin_validate', 'after_epoch',
        'after_cancel_train', 'after_fit'}

    def __call__(self, cb_name):
        """Fire event `cb_name` on all callbacks in `_order`; True if any callback returned True."""
        assert cb_name in self.ALL_CBS
        res = False
        for cb in sorted(self.cbs, key=lambda x: x._order):
            # `or` (not `and`, which pinned res to False): every callback still runs because the
            # call is evaluated first, and a single True result propagates to the caller.
            res = cb(cb_name) or res
        return res

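Documentation of various variables present on the learner:

- `learner.epochs`: total number of epochs requested by the user in `fit`.
- `learner.epoch`: index of the current epoch (0-based).
- `learner.iter` / `learner.iters`: index of the current batch (0-based) and number of batches in the current dataloader.
- `learner.in_train`: `True` during the training phase of an epoch, `False` during validation.
- `learner.n_epochs` / `learner.n_iter`: epoch counter advanced by `1/iters` per training batch, and number of training batches processed, both maintained by `TrainEvalCallback`.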

In [28]:
#export
def accuracy(out, yb):
    """Fraction of rows of `out` whose argmax over dim 1 equals the matching target in `yb`."""
    predicted = out.argmax(dim=1)
    hits = (predicted == yb).float()
    return hits.mean()

In [29]:
# Callback factories: Learner calls each with no arguments to build a fresh callback instance.
cbfs = [partial(AvgStatsCallback,accuracy),
        ProgressCallback]

In [30]:
# `lr` is not passed, so the Learner's default (1e-2) is used -- not the lr=0.1 given to get_model.
learn = Learner(model,data, F.cross_entropy, opt_func, cb_funcs=cbfs)

In [31]:
# Train for 10 epochs; AvgStatsCallback logs one row of train/valid stats per epoch.
learn.fit(10)

epoch,train_loss,train_accuracy,valid_loss,valid_accuracy,time
0,0.636609,0.83982,0.326754,0.9084,00:01
1,0.319392,0.90898,0.267471,0.9234,00:01
2,0.270548,0.92282,0.233598,0.9331,00:01
3,0.238136,0.93258,0.211052,0.9404,00:01
4,0.213312,0.93908,0.190725,0.947,00:01
5,0.19246,0.94554,0.175108,0.9542,00:02
6,0.175379,0.95044,0.16226,0.9565,00:02
7,0.161475,0.95428,0.15298,0.9581,00:02
8,0.149135,0.95782,0.143697,0.9612,00:02
9,0.138424,0.96114,0.137345,0.9635,00:02


In [32]:
# Export this notebook's #export cells to lib/nb_02.py.
!python notebook2script.py 02_callback_and_learner.ipynb

Converted 02_callback_and_learner.ipynb to lib/nb_02.py
