In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
#export
from exp.nb_03 import *

## Initial setup

### Data

In [3]:
# Training/validation inputs and labels. `get_data` is not defined in this notebook;
# presumably it arrives through the `exp.nb_03` star import — confirm.
x_train,y_train,x_valid,y_valid = get_data()

In [4]:
# Wrap each (x, y) pair in a `Dataset` (defined outside this notebook, presumably in `exp.nb_03`).
train_ds,valid_ds = Dataset(x_train, y_train),Dataset(x_valid, y_valid)

In [5]:
# `bs` is forwarded to get_dls() below (presumably the batch size — confirm).
# `nh` is not passed anywhere in this notebook; get_model() has its own nh=50 default.
nh,bs = 50,64

In [6]:
# Loss handed to every Learner created below. `F` is not imported in this notebook;
# presumably it is `torch.nn.functional` coming from the star import — confirm.
loss_func = F.cross_entropy

### DataBunch, Learner, Callbacks

Factor the connected pieces of info out of the `fit()` argument list:

In [None]:
#export
class DataBunch():
    """Bundle a training and a validation data loader and record the class count `c`.

    train_dl, valid_dl: objects exposing a `.dataset` attribute.
    c: number of classes. When left as None it is computed as
       `train_ds.y.max().item()+1`, which assumes the training dataset has a `y`
       holding integer labels numbered from 0. Pass `c` explicitly for datasets
       that do not fit that assumption.
    """
    def __init__(self, train_dl, valid_dl, c=None):
        self.train_dl,self.valid_dl = train_dl,valid_dl
        # Only touch the labels when the caller did not supply the class count.
        self.c = c if c is not None else self.train_ds.y.max().item()+1

    @property
    def train_ds(self):
        "Dataset behind the training loader."
        return self.train_dl.dataset

    @property
    def valid_ds(self):
        "Dataset behind the validation loader."
        return self.valid_dl.dataset

In [None]:
# `get_dls` is defined outside this notebook; its result is unpacked as (train_dl, valid_dl).
data = DataBunch(*get_dls(train_ds, valid_ds, bs))

In [None]:
#export
def get_model(data, lr=0.5, nh=50):
    """Build a Linear -> ReLU -> Linear classifier for `data` plus an SGD optimizer.

    The input width is `data.train_ds.x.shape[1]`, the hidden width is `nh` and
    the output width is `data.c`. Returns the tuple `(model, optimizer)`.
    """
    n_in = data.train_ds.x.shape[1]
    layers = [nn.Linear(n_in, nh), nn.ReLU(), nn.Linear(nh, data.c)]
    model = nn.Sequential(*layers)
    opt = optim.SGD(model.parameters(), lr=lr)
    return model, opt

class Learner():
    "Plain container for the four things a training loop needs: model, optimizer, loss function and data."
    def __init__(self, model, opt, loss_func, data):
        self.model = model
        self.opt = opt
        self.loss_func = loss_func
        self.data = data

In [None]:
# get_model() returns (model, optimizer), which fill Learner's first two positional arguments.
learn = Learner(*get_model(data), loss_func, data)

Add callbacks so we can remove complexity from the loop and make it more flexible:

In [None]:
def one_batch(xb, yb, cb):
    """Process one mini-batch, asking the handler `cb` before each stage.

    A falsy answer from `begin_batch` or `after_loss` abandons the batch.
    `after_backward` gates the optimizer step and `after_step` gates the
    gradient reset; `after_step` is consulted even when the step was skipped.
    """
    if cb.begin_batch(xb, yb):
        loss = cb.learn.loss_func(cb.learn.model(xb), yb)
        if cb.after_loss(loss):
            loss.backward()
            if cb.after_backward():
                cb.learn.opt.step()
            if cb.after_step():
                cb.learn.opt.zero_grad()

In [None]:
def all_batches(dl, cb):
    "Feed every (xb, yb) pair of `dl` to one_batch, stopping early once `cb.do_stop()` is truthy."
    for inputs, targets in dl:
        one_batch(inputs, targets, cb)
        if cb.do_stop():
            break

In [None]:
def fit(epochs, learn, cb):
    """Train `learn` for up to `epochs` epochs under the control of handler `cb`.

    - `begin_fit` falsy: nothing runs, not even `after_fit`.
    - `begin_epoch` falsy: that epoch is skipped entirely (no `after_epoch`).
    - `begin_validate` truthy: the validation loader is run with gradients disabled.
    - `after_epoch` falsy: remaining epochs are abandoned.
    """
    if cb.begin_fit(learn):
        for epoch_idx in range(epochs):
            if cb.begin_epoch(epoch_idx):
                all_batches(learn.data.train_dl, cb)

                if cb.begin_validate():
                    with torch.no_grad():
                        all_batches(learn.data.valid_dl, cb)

                if not cb.after_epoch():
                    break
        cb.after_fit()

In [None]:
class CallbackHandler():
    """Default do-nothing handler: every hook lets the training loop proceed.

    The handler remembers the learner passed to `begin_fit`, switches its model
    between train and eval mode, and tracks `in_train` so that `after_loss`
    stops the batch before the backward pass while validating.
    """
    def __init__(self): self.stop,self.cbs = False,[]

    def begin_fit(self, learn):
        self.learn,self.in_train = learn,True
        return True
    def after_fit(self): pass

    def begin_epoch(self, epoch):
        # Use the learner stored by begin_fit rather than a module-level `learn`,
        # which may not exist or may be a different object.
        self.learn.model.train()
        self.in_train=True
        return True
    def begin_validate(self):
        self.learn.model.eval()
        self.in_train=False
        return True
    def after_epoch(self): return True

    def begin_batch(self, xb, yb): return True
    # Falsy during validation, so the loop skips backward/step for those batches.
    def after_loss(self, loss): return self.in_train
    def after_backward(self): return True
    def after_step(self): return True

    def do_stop(self):
        "Return the pending stop flag and clear it, so a stop request only applies once."
        stop,self.stop = self.stop,False
        return stop

In [None]:
# One epoch with the default handler, whose hooks all allow the loop to continue.
fit(1, learn, cb=CallbackHandler())

This is roughly how fastai does it now (except that the handler can also change and return `xb`, `yb`, and `loss`). But let's see if we can make things simpler and more flexible, so that a single class has access to everything and can change anything at any time. The fact that we're passing `cb` to so many functions is a strong hint they should all be in the same class!

## Runner

In [None]:
#export
class Callback():
    "Base class for callbacks; `_order` is the sort key deciding the order in which callbacks are invoked."
    _order = 0

class TrainEvalCallback(Callback):
    """Keep the runner's progress counters and train/eval mode up to date.

    Every hook returns None, so this callback never cancels a step.
    """
    def begin_fit(self, run):
        "Zero the fractional-epoch and iteration counters."
        run.n_epochs = 0.
        run.n_iter = 0

    def begin_epoch(self, run):
        "Snap `n_epochs` to the current epoch and put the model in training mode."
        run.n_epochs = run.epoch
        run.model.train()
        run.in_train = True

    def begin_validate(self, run):
        "Put the model in evaluation mode."
        run.model.eval()
        run.in_train = False

    def after_batch(self, run):
        "Advance the counters by one batch; validation batches are not counted."
        if not run.in_train:
            return
        run.n_epochs += 1./run.iters
        run.n_iter += 1

In [None]:
#export
def listify(o):
    "Coerce `o` to a list: lists pass through unchanged, tuples are converted, None gives [], anything else is wrapped."
    if isinstance(o, list):
        return o
    if isinstance(o, tuple):
        return list(o)
    return [] if o is None else [o]

In [None]:
#export
class Runner():
    """Training loop that owns its callbacks and dispatches named events to them.

    Calling the runner with an event name (`run('after_loss')`) invokes that
    method, when present, on each callback in `_order` order and returns True as
    soon as one of them returns a truthy value. A True result means "cancel":
    the current batch, epoch phase or whole fit is cut short at that point.
    `in_train` is not set here; it is expected to be maintained by a callback.
    """
    def __init__(self, cbs=None):
        callbacks = [TrainEvalCallback()] + listify(cbs)
        self.stop = False
        self.cbs = callbacks

    # Shortcuts into the learner currently being fitted.
    @property
    def opt(self):
        return self.learn.opt

    @property
    def model(self):
        return self.learn.model

    @property
    def loss_func(self):
        return self.learn.loss_func

    @property
    def data(self):
        return self.learn.data

    def one_batch(self, xb, yb):
        "Run forward, loss, backward and optimizer step for one batch, firing an event after each stage."
        self.xb, self.yb = xb, yb
        if self('begin_batch'):
            return

        self.pred = self.model(self.xb)
        if self('after_pred'):
            return

        self.loss = self.loss_func(self.pred, self.yb)
        # Evaluated first so that `in_train` is only consulted when no callback cancelled.
        cancelled = self('after_loss')
        if cancelled or not self.in_train:
            return

        self.loss.backward()
        if self('after_backward'):
            return

        self.opt.step()
        if self('after_step'):
            return

        self.opt.zero_grad()

    def all_batches(self, dl):
        "Iterate over `dl`, honouring `self.stop`, and clear the stop flag afterwards."
        self.iters = len(dl)
        for inputs, targets in dl:
            if self.stop:
                break
            self.one_batch(inputs, targets)
            self('after_batch')
        self.stop = False

    def fit(self, epochs, learn):
        "Fit `learn` for up to `epochs` epochs; `after_fit` always fires and the learner reference is dropped."
        self.epochs = epochs
        self.learn = learn

        try:
            if not self('begin_fit'):
                for epoch_idx in range(epochs):
                    self.epoch = epoch_idx
                    if not self('begin_epoch'):
                        self.all_batches(self.data.train_dl)

                    with torch.no_grad():
                        if not self('begin_validate'):
                            self.all_batches(self.data.valid_dl)

                    if self('after_epoch'):
                        break
        finally:
            self('after_fit')
            self.learn = None

    def __call__(self, cb_name):
        ordered = sorted(self.cbs, key=lambda cb: cb._order)
        handlers = (getattr(cb, cb_name, None) for cb in ordered)
        # any() stops at the first truthy result, so later callbacks are not invoked.
        return any(h and h(self) for h in handlers)

In [None]:
#export
class AvgStats():
    """Running, sample-weighted totals of the loss and of each metric for one phase.

    metrics:  a metric callable, a list/tuple of them, or None; each is called as
              `m(run.pred, run.yb)`.
    in_train: whether these are training statistics (only affects the repr label).
    """
    def __init__(self, metrics, in_train):
        self.metrics,self.in_train = listify(metrics),in_train
        # Start from zeroed totals so repr() and accumulate() work even when no
        # one has called reset() yet.
        self.reset()

    def reset(self):
        "Clear all totals."
        self.tot_loss,self.count = 0.,0
        self.tot_mets = [0.] * len(self.metrics)

    @property
    def all_stats(self):
        "Totals as `[loss, *metrics]`; the loss is unwrapped with `.item()` when it has one."
        # tot_loss stays the plain float 0. until a loss has been accumulated.
        loss = self.tot_loss.item() if hasattr(self.tot_loss, 'item') else self.tot_loss
        return [loss] + self.tot_mets

    def __repr__(self):
        if not self.count: return ""
        return f"{'train' if self.in_train else 'valid'}: {[o/self.count for o in self.all_stats]}"

    def accumulate(self, run):
        "Add the current batch of `run`, weighting loss and metrics by `run.xb.shape[0]`."
        bn = run.xb.shape[0]
        self.tot_loss += run.loss * bn
        self.count += bn
        for i,m in enumerate(self.metrics):
            self.tot_mets[i] += m(run.pred, run.yb) * bn

class AvgStatsCallback(Callback):
    "Track separate training and validation AvgStats and print both at the end of every epoch."
    def __init__(self, metrics):
        self.train_stats = AvgStats(metrics, True)
        self.valid_stats = AvgStats(metrics, False)

    def stats(self, run):
        "Pick the accumulator matching the runner's current phase."
        if run.in_train:
            return self.train_stats
        return self.valid_stats

    def begin_epoch(self, run):
        for phase_stats in (self.train_stats, self.valid_stats):
            phase_stats.reset()

    def after_loss(self, run):
        phase_stats = self.stats(run)
        # Accumulate with gradient tracking disabled.
        with torch.no_grad():
            phase_stats.accumulate(run)

    def after_epoch(self, run):
        for phase_stats in (self.train_stats, self.valid_stats):
            print(phase_stats)

In [None]:
# get_model() builds a brand-new model and optimizer, so this rebinds `learn` to an untrained learner.
learn = Learner(*get_model(data), loss_func, data)

In [None]:
# Runner() starts with only its built-in TrainEvalCallback; callbacks appended to
# `run.cbs` afterwards are picked up because the list is re-sorted on every event.
run = Runner()
run.cbs.append(AvgStatsCallback([accuracy]))

In [None]:
# Three epochs; AvgStatsCallback prints one train line and one valid line per epoch.
run.fit(3, learn)

train: [0.3119046484375, tensor(0.9049)]
valid: [0.17293448486328125, tensor(0.9499)]
train: [0.140401572265625, tensor(0.9569)]
valid: [0.16141004638671874, tensor(0.9548)]
train: [0.105048544921875, tensor(0.9676)]
valid: [0.1088798828125, tensor(0.9681)]


## Export

In [None]:
!./notebook2script.py 04_callbacks.ipynb

Converted 04_callbacks.ipynb to nb_04.py
