In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
#export
from exp.nb_03 import *
import torch.nn.functional as F

In [3]:
x_train, y_train, x_valid, y_valid = get_data()
ds_train = Dataset(x_train, y_train)
ds_valid = Dataset(x_valid, y_valid)

File mnist.pkl.gz already exists in data


In [4]:
# nh: hidden-layer width; bs: forwarded to get_dls (presumably the batch size).
# NOTE(review): get_model(data) below is called without nh, so it uses its own default of 50, not this value.
nh, bs = 50, 64

In [5]:
# Number of classes: largest label + 1 (assumes integer labels 0..c-1 — TODO confirm).
c = 1 + y_train.max().item()
# Loss used by the training loops below.
loss_func = F.cross_entropy

In [6]:
# export
class DataBunch:
    """Pair a training and a validation DataLoader with the class count `c`.

    `c` is used by get_model as the width of the final Linear layer.
    """
    def __init__(self, dl_train, dl_valid, c):
        self.dl_train, self.dl_valid, self.c = dl_train, dl_valid, c

    # Read-only views of the datasets behind each loader.
    ds_train = property(lambda self: self.dl_train.dataset)
    ds_valid = property(lambda self: self.dl_valid.dataset)

In [7]:
# Pass the class count derived from the labels (`c`, computed above) so the model's output width tracks the dataset.
data = DataBunch(*get_dls(ds_train, ds_valid, bs), c)

In [8]:
# export
def get_model(data, lr=0.5, nh=50):
    """Build a Linear -> ReLU -> Linear network sized from `data`, plus an SGD optimizer.

    The input width is the second dimension of `data.ds_train.x` and the output
    width is `data.c`; `nh` is the hidden width and `lr` the SGD learning rate.
    Returns (model, opt).
    """
    layers = [nn.Linear(data.ds_train.x.shape[1], nh), nn.ReLU(), nn.Linear(nh, data.c)]
    model = nn.Sequential(*layers)
    return model, optim.SGD(model.parameters(), lr)

In [9]:
# export
class Learner:
    """Plain container bundling a model, its optimizer, the loss function and the DataBunch."""
    def __init__(self, model, opt, loss_func, data):
        self.model, self.opt, self.loss_func, self.data = model, opt, loss_func, data

In [10]:
model, opt = get_model(data)
learn = Learner(model=model, opt=opt, loss_func=loss_func, data=data)

In [11]:
from IPython.core.debugger import set_trace

In [12]:
def fit(epochs, learn):
    """Train `learn.model` for `epochs` epochs, printing validation accuracy and loss after each.

    Everything (model, optimizer, loss function and both DataLoaders) is taken
    from `learn`, never from notebook globals. The function therefore trains
    and evaluates the same objects whatever `data` / `model` currently hold.
    Validation averages are weighted by batch size.
    """
    for _ in range(epochs):
        learn.model.train()
        for xb, yb in learn.data.dl_train:
            loss = learn.loss_func(learn.model(xb), yb)
            loss.backward()
            learn.opt.step()
            learn.opt.zero_grad()

        learn.model.eval()
        tot_correct = tot_loss = tot_seen = 0
        with torch.no_grad():
            for xb, yb in learn.data.dl_valid:
                preds = learn.model(xb)
                n = len(xb)
                tot_correct += accuracy(preds, yb).item() * n
                tot_loss += learn.loss_func(preds, yb).item() * n
                tot_seen += n
        # An empty validation loader has nothing to report (and would divide by zero).
        if tot_seen:
            print(f"Acc:{tot_correct/tot_seen:<10}| Loss:{(tot_loss/tot_seen)}")

In [13]:
fit(epochs=5, learn=learn)

Acc:0.9544    | Loss:0.15658490743637085
Acc:0.9527    | Loss:0.1588026722729206
Acc:0.9058    | Loss:0.3521318807423115
Acc:0.9697    | Loss:0.1075203885935247
Acc:0.9704    | Loss:0.10032944450825453


## Callback Handler

In [14]:
# Without callbacks
def one_batch(xb, yb):
    """One SGD step on a batch using the notebook-level `model`, `loss_func` and `opt`."""
    loss = loss_func(model(xb), yb)
    loss.backward()
    opt.step()
    opt.zero_grad()

def fit():
    """Bare training loop: call one_batch on every batch of `train_dl` for `epochs` epochs.

    NOTE(review): `epochs` and `train_dl` are read as globals, but nothing
    visible in this notebook defines them. Calling this as-is presumably
    raises NameError, so it is illustrative only.
    """
    for _ in range(epochs):
        for batch in train_dl:
            one_batch(*batch)

In [15]:
def onebatch(xb, yb, cbh):
    """Run one batch through `cbh.learn`, consulting the handler `cbh` at each stage.

    A falsy begin_batch skips the batch entirely. Validation batches
    (cbh.in_train false) stop once the loss is computed, without calling
    after_loss. A falsy after_loss skips backward. after_backward gates
    opt.step() and after_step gates opt.zero_grad().
    """
    if not cbh.begin_batch(xb, yb):
        return
    learn = cbh.learn
    loss = learn.loss_func(learn.model(xb), yb)
    # `and` short-circuits, so after_loss is never called for validation batches.
    if not (cbh.in_train and cbh.after_loss(loss)):
        return
    loss.backward()
    if cbh.after_backward():
        learn.opt.step()
    if cbh.after_step():
        learn.opt.zero_grad()
        
def all_batches(dl, cbh):
    """Feed each (xb, yb) batch of `dl` to onebatch, leaving early once `cbh.do_stop()` is truthy."""
    for inputs, targets in dl:
        onebatch(inputs, targets, cbh)
        if cbh.do_stop():
            break

def fit(epochs, learn, cbh):
    """Train `learn` for `epochs` epochs, letting the callback handler `cbh` steer the loop.

    A falsy begin_fit aborts before anything runs, including after_fit. A
    falsy begin_epoch skips that epoch. A falsy begin_validate skips that
    epoch's validation pass. A pending stop request or a falsy after_epoch
    ends training early; after_fit then still runs.
    """
    if cbh.begin_fit(learn):
        for epoch in range(epochs):
            if cbh.begin_epoch(epoch):
                all_batches(learn.data.dl_train, cbh)
                if cbh.begin_validate():
                    with torch.no_grad():
                        all_batches(learn.data.dl_valid, cbh)
                if cbh.do_stop() or not cbh.after_epoch():
                    break
        cbh.after_fit()

In [16]:
class Callback():
    """Base callback for CallbackHandler.

    Every hook returns True ("carry on"). Hooks that receive arguments also
    keep them on the instance for subclasses: learn, epoch, xb/yb and loss.
    """
    def begin_fit(self, learn):
        self.learn = learn
        return True

    def after_fit(self): return True

    def begin_epoch(self, epoch):
        self.epoch = epoch
        return True

    def after_epoch(self): return True

    def begin_validate(self): return True

    def begin_batch(self, xb, yb):
        self.xb, self.yb = xb, yb
        return True

    def after_loss(self, loss):
        self.loss = loss
        return True

    def after_backward(self): return True

    def after_step(self): return True

In [17]:
class CallbackHandler():
    """Fan each training-loop event out to a list of callbacks.

    Every event method folds the callbacks' return values with ``and``, so it
    is truthy only if every consulted callback returned something truthy.
    For all events except ``after_fit`` the fold short-circuits: once one
    callback returns a falsy value, the remaining callbacks are not called
    for that event.
    """
    def __init__(self, cbs=None):
        # `cbs` is optional; None or an empty sequence means "no callbacks".
        self.cbs = cbs if cbs else []

    def _dispatch(self, event, *args):
        "Call `event` on each callback in order until one returns a falsy value."
        res = True
        for cb in self.cbs:
            res = res and getattr(cb, event)(*args)
        return res

    def begin_fit(self, learn):
        self.learn = learn
        self.in_train = True
        learn.stop = False  # start every fit without a pending stop request
        return self._dispatch('begin_fit', learn)

    def after_fit(self):
        # Every callback is always notified here. after_fit is the last
        # chance to clean up, so it must not be skipped just because the
        # final epoch ended in training mode (e.g. its begin_epoch returned
        # False, so validation never ran). The result is still falsy when
        # in_train is True or when any callback returned a falsy value.
        res = not self.in_train
        for cb in self.cbs:
            res = cb.after_fit() and res
        return res

    def begin_epoch(self, epoch):
        self.learn.model.train()
        self.in_train = True
        return self._dispatch('begin_epoch', epoch)

    def after_epoch(self):
        return self._dispatch('after_epoch')

    def begin_validate(self):
        self.learn.model.eval()
        self.in_train = False
        return self._dispatch('begin_validate')

    def begin_batch(self, xb, yb):
        return self._dispatch('begin_batch', xb, yb)

    def after_loss(self, loss):
        return self._dispatch('after_loss', loss)

    def after_backward(self):
        return self._dispatch('after_backward')

    def after_step(self):
        return self._dispatch('after_step')

    def do_stop(self):
        """Return the learner's pending stop request and clear it in the same call."""
        try:
            return self.learn.stop
        finally:
            self.learn.stop = False

In [18]:
class StopAfterTenItersCallback(Callback):
    """Set `learn.stop` once ten after_step events (one per training batch) have been seen.

    The counter is reset only in begin_fit, not per epoch. CallbackHandler's
    do_stop clears the flag after each pass, so later epochs stop after
    their first training batch.
    """
    def begin_fit(self, learn):
        super().begin_fit(learn)
        self.n_iters = 0
        return True

    def after_step(self):
        self.n_iters += 1
        limit_hit = self.n_iters >= 10
        if limit_hit:
            self.learn.stop = True
        return True

In [19]:
custom_callbacks = [StopAfterTenItersCallback()]
fit(epochs=1, learn=learn, cbh=CallbackHandler(cbs=custom_callbacks))

AttributeError: 'StopAfterTenItersCallback' object has no attribute 'n_iter'

## Runner

In [None]:
# export
import re

_camel_re1 = re.compile('(.)([A-Z][a-z]+)')
_camel_re2 = re.compile('([a-z0-9])([A-Z])')

def camel2snake(name):
    """Convert a CamelCase identifier to snake_case, e.g. 'TrainEvalCallback' -> 'train_eval_callback'."""
    # First put an underscore before each capitalised word, then between a
    # lower-case letter/digit and a following capital, and finally lower-case it.
    words_split = _camel_re1.sub(r'\1_\2', name)
    return _camel_re2.sub(r'\1_\2', words_split).lower()

In [None]:
# export
class Callback:
    """Base class for Runner callbacks.

    Attributes a callback does not define itself are read from the Runner it
    was attached to with set_runner, so hooks can read `self.model`,
    `self.in_train`, `self.epoch`, ... directly. Assignments are not
    forwarded: to change Runner state, write through `self.run`.
    """
    _order = 0  # Runner.__call__ invokes callbacks sorted by this value

    def set_runner(self, run):
        self.run = run

    def __getattr__(self, x):
        # Only reached when normal lookup fails. `run` is read straight from
        # the instance dict: evaluating `self.run` before set_runner() would
        # re-enter __getattr__('run') and recurse until RecursionError.
        # Raising AttributeError instead keeps hasattr() and
        # getattr(obj, name, default) working on unattached callbacks.
        run = self.__dict__.get('run')
        if run is None:
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {x!r}")
        return getattr(run, x)

    @property
    def name(self):
        """Snake-case name from the class name minus a trailing 'Callback'.

        For example TrainEvalCallback -> 'train_eval'; 'callback' if nothing is left.
        """
        name = re.sub(r'Callback$', '', self.__class__.__name__)
        return camel2snake(name or 'callback')

In [None]:
# export
class TrainEvalCallback(Callback):
    """Keep the Runner's progress counters and the model's train/eval mode in sync.

    On the Runner it maintains n_epochs, n_iters and in_train. n_epochs is
    fractional: it advances by 1/self.iters per training batch, where
    all_batches sets iters to the loader length. n_iters counts training
    batches. All writes go through `self.run` because Callback.__getattr__
    only forwards reads.
    """
    def begin_fit(self):
        self.run.n_epochs, self.run.n_iters = 0, 0

    def begin_epoch(self):
        self.run.n_epochs = self.epoch
        self.run.in_train = True
        self.model.train()

    def after_batch(self):
        if self.in_train:
            self.run.n_epochs += 1./self.iters
            self.run.n_iters += 1

    def begin_validate(self):
        self.run.in_train = False
        self.model.eval()

In [None]:
class StopAfterTenIter(Callback):
    """Stop the training pass once ten training batches have been processed.

    Stopping is requested by setting `run.stop`, which Runner.all_batches
    checks before each batch. Returning True from an event would only skip
    the rest of the current batch. `n_iters` and `in_train` live on the
    Runner (TrainEvalCallback maintains them) and are read through
    Callback.__getattr__.
    """
    # Run after TrainEvalCallback (_order 0), so its after_batch has already
    # counted the current batch when this check runs.
    _order = 1

    def after_batch(self):
        # n_iters is never reset between epochs, so in later epochs the stop
        # is requested after the first training batch. all_batches clears
        # `stop` when each pass ends, so validation still runs in full.
        if self.in_train and self.n_iters >= 10:
            self.run.stop = True

In [None]:
cbname = "TrainEvalCallback"
camel2snake(name=cbname)

In [None]:
# On the class, `name` is the property object itself, so this displays its repr rather than a string.
TrainEvalCallback.name

In [None]:
# On an instance the property runs: the trailing "Callback" is stripped and the rest snake-cased, giving 'train_eval'.
TrainEvalCallback().name

In [None]:
# export
from typing import *

In [None]:
# export
def listify(o):
    """Coerce `o` into a list.

    None -> []; a list is returned as-is (not copied); a str is wrapped as a
    single item instead of being split into characters; any other iterable is
    materialised with list(); anything else becomes a one-element list.
    """
    if o is None: return []
    if isinstance(o, list): return o
    if isinstance(o, str): return [o]
    if isinstance(o, Iterable): return list(o)
    return [o]

In [None]:
# export
class Runner:
    """Drive the training loop for a Learner, delegating every stage to callbacks.

    Each stage fires an event with ``self('<event>')``. A handler returning
    True cancels the remainder of that stage (see ``__call__``). Callbacks can
    also set ``run.stop = True`` to end the current pass over a DataLoader.
    """
    def __init__(self, cbs=None, cb_funcs=None):
        # listify() hands lists back unchanged, so copy before appending the
        # cb_funcs instances; otherwise the caller's own list would grow on
        # every Runner constructed from it.
        cbs = list(listify(cbs))
        for cbf in listify(cb_funcs):
            cb = cbf()
            # Expose factory-built callbacks as attributes under their `name`.
            setattr(self, cb.name, cb)
            cbs.append(cb)
        self.stop = False
        # TrainEvalCallback is always installed first; it maintains n_epochs,
        # n_iters and in_train and switches the model between train/eval mode.
        self.cbs = [TrainEvalCallback()] + cbs

    @property
    def opt(self):
        return self.learn.opt

    @property
    def model(self):
        return self.learn.model

    @property
    def loss_func(self):
        return self.learn.loss_func

    @property
    def data(self):
        return self.learn.data

    def one_batch(self, xb, yb):
        """Process one batch; any event returning True skips the rest of it."""
        self.xb, self.yb = xb, yb

        if self('begin_batch'): return
        self.pred = self.model(self.xb)

        if self('after_pred'): return
        self.loss = self.loss_func(self.pred, self.yb)

        # after_loss fires for validation batches too; only training batches
        # continue on to backward/step.
        if self('after_loss') or not self.in_train: return
        self.loss.backward()

        if self('after_backward'): return
        self.opt.step()

        if self('after_step'): return
        self.opt.zero_grad()

    def all_batches(self, dl):
        """Run every batch of `dl`, honouring `self.stop` and clearing it when the pass ends."""
        # Read by TrainEvalCallback to advance n_epochs fractionally.
        self.iters = len(dl)

        for xb, yb in dl:
            if self.stop: break
            self.one_batch(xb, yb)
            self('after_batch')
        self.stop = False

    def fit(self, epochs, learn):
        """Train `learn` for `epochs` epochs.

        after_fit always fires (even if an exception escapes) and the learner
        is detached afterwards.
        """
        self.epochs = epochs
        self.learn = learn
        try:

            for cb in self.cbs:
                cb.set_runner(self)

            if self('begin_fit'):
                return

            for epoch in range(epochs):
                self.epoch = epoch
                if not self('begin_epoch'):
                    self.all_batches(self.data.dl_train)

                with torch.no_grad():
                    if not self('begin_validate'):
                        self.all_batches(self.data.dl_valid)

                if self('after_epoch'):
                    break
        finally:
            self('after_fit')
            self.learn = None

    def __call__(self, cb_name):
        """Fire event `cb_name` on the callbacks in `_order` order.

        Returns True (cancel) as soon as one handler returns a truthy value,
        otherwise False. Callbacks without a handler for the event are skipped.
        """
        for cb in sorted(self.cbs, key=lambda x:x._order):
            f = getattr(cb, cb_name, None)
            if f and f():
                return True
        return False

In [None]:
# export
class AvgStatsCallback(Callback):
    """Collect per-epoch loss/metric averages for training and validation and print both after each epoch."""

    def __init__(self, metrics):
        self.train_stats, self.valid_stats = AvgStats(metrics, in_train=True), AvgStats(metrics, in_train=False)

    def begin_epoch(self):
        for stats in (self.train_stats, self.valid_stats):
            stats.reset()

    def after_loss(self):
        # Runner fires after_loss for validation batches too, so route by phase.
        target = self.train_stats if self.in_train else self.valid_stats
        with torch.no_grad():
            target.accumulate(self.run)

    def after_epoch(self):
        for stats in (self.train_stats, self.valid_stats):
            print(stats)

In [None]:
# export
class AvgStats:
    """Running, batch-size-weighted averages of the loss and a list of metrics.

    `accumulate` adds ``batch_size * value`` for the loss and every metric.
    `avg_stats` divides those sums by the number of samples seen.
    `in_train` only selects the 'train'/'valid' label used by ``repr``.
    """
    def __init__(self, metrics, in_train):
        self.metrics = listify(metrics)
        self.in_train = in_train
        # Initialise the counters here too, so repr()/all_stats work even
        # before the first explicit reset().
        self.reset()

    def reset(self):
        "Zero the loss sum, the sample count and every metric sum."
        self.total_loss = 0
        self.count = 0
        self.total_metrics = [0.] * len(self.metrics)

    @property
    def all_stats(self):
        # Unlike `.item()`, float() accepts both the accumulated
        # single-element loss tensor and the plain int 0 stored by reset().
        return [float(self.total_loss)] + self.total_metrics

    @property
    def avg_stats(self):
        return [o/self.count for o in self.all_stats]

    def __repr__(self):
        if not self.count: return ""
        return f"{'train' if self.in_train else 'valid'}: {self.avg_stats}"

    def accumulate(self, run):
        """Add one batch read from run.xb, run.loss, run.pred and run.yb; the batch size is run.xb.shape[0]."""
        bn = run.xb.shape[0]
        self.total_loss += (bn * run.loss)
        self.count += bn
        for i, metric in enumerate(self.metrics):
            self.total_metrics[i] += metric(run.pred, run.yb) * bn

In [None]:
# DataBunch defines no __repr__, so this shows the default object repr.
data

In [None]:
learn = Learner(*get_model(data), loss_func=loss_func, data=data)

In [None]:
# The metric handed to AvgStatsCallback below; it is not defined in this notebook (presumably comes from `from exp.nb_03 import *`).
accuracy

In [None]:
stats = AvgStatsCallback([accuracy])
run = Runner(cbs=[stats])

In [None]:
run.fit(epochs=5, learn=learn)

In [None]:
# %debug

In [None]:
# Export the cells tagged `# export` to a Python module (presumably exp/nb_04.py, mirroring the exp.nb_03 import above — confirm).
!python notebook2script.py 04_callbacks.ipynb