In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
#export
from uti.basic_train import *

In [3]:
train_ds,valid_ds = get_dataset(*get_data())
data = Databunch(*get_dl(train_ds,valid_ds,bs=64),c=10)

/home/jupyter/.fastai/data/mnist.pkl.gz


In [4]:
loss_func = F.cross_entropy

In [5]:
learn = Learner(*get_model(data),loss_func,data)

In [6]:
def one_batch(xb,yb,cb_handler):
    if not cb_handler.begin_batch(xb,yb): return
    preds = cb_handler.learn.model(xb)
    loss = cb_handler.learn.loss_func(preds,yb)
    if not cb_handler.after_loss(loss,preds): return
    if not cb_handler.begin_backward(): return
    loss.backward()
    if cb_handler.after_backward(): cb_handler.learn.opt.step()
    if cb_handler.after_step(): cb_handler.learn.opt.zero_grad()

def all_batches(dl, cb_handler):
    for xb,yb in dl:
        one_batch(xb,yb,cb_handler)
        if cb_handler.do_stop(): return 


def fit(epochs,learn,cb_handler):
    if not cb_handler.begin_fit(learn): return
    for epoch in range(epochs):
        if not cb_handler.begin_epoch(epoch): continue
        all_batches(learn.data.train_dl,cb_handler)
        
        if cb_handler.begin_validate():
            with torch.no_grad(): all_batches(learn.data.valid_dl,cb_handler)
        if cb_handler.do_stop() or not cb_handler.after_epoch(): break
    
    cb_handler.after_fit()


In [7]:
#export
#abstract class for callbacks
class Callbacks():
    _order = 0
    def begin_fit(self,learn):
        self.learn = learn
        return True
    def after_fit(self): return True
    def begin_epoch(self,epoch):
        self.epoch = epoch
        return True
    def begin_validate(self): return True
    def after_epoch(self): return True
    def begin_batch(self,xb,yb):
        self.xb, self.yb = xb, yb
        return True
    def after_loss(self,loss,preds):
        self.loss = loss
        self.preds = preds
        return True
    def after_backward(self): return True
    def after_step(self): return True

In [27]:
#export
class CallbackHandler():
    def __init__(self,cbs=None):
        self.cbs = cbs if cbs else []
    
    def begin_fit(self,learn):
        self.learn = learn
        self.in_train = True
        self.learn.stop = False
        if len(self.cbs) is 0: return True #may cause problem
        res = True
        for cb in sorted(self.cbs,key=lambda x: x._order): res = res and cb.begin_fit(learn)
        return res
    
    def after_fit(self):
        if len(self.cbs) is 0: return True
        res = not self.in_train
        for cb in sorted(self.cbs,key=lambda x: x._order): res = res and cb.after_fit()
        return res
    
    def begin_epoch(self,epoch):
        self.learn.model.train()
        self.in_train = True
        if len(self.cbs) is 0: return True #may cause problem
        res = True
        for cb in sorted(self.cbs,key=lambda x: x._order): res = res and cb.begin_epoch(epoch)
        return res
    
    def begin_validate(self):
        self.learn.model.eval()
        self.in_train = False
        if len(self.cbs) is 0: return True #may cause problem
        res = True
        for cb in sorted(self.cbs,key=lambda x: x._order): res = res and cb.begin_validate()
        return res
    
    def after_epoch(self):
        if len(self.cbs) is 0: return True
        res = True
        for cb in sorted(self.cbs,key=lambda x: x._order): res = res and cb.after_epoch()
        return res
    
    def begin_batch(self,xb,yb):
        if len(self.cbs) is 0: return True
        res = True
        for cb in sorted(self.cbs,key=lambda x: x._order): res = res and cb.begin_batch(xb,yb)
        return res
    
    def after_loss(self,loss,preds):
        if len(self.cbs) is 0: return True
        res = True
        for cb in sorted(self.cbs,key=lambda x: x._order): res = res and cb.after_loss(loss,preds)
        return res
    
    def begin_backward(self):
        res = self.in_train
        return res
    
    def after_backward(self):
        if len(self.cbs) is 0: return True
        res = True
        for cb in sorted(self.cbs,key=lambda x: x._order): res = res and cb.after_backward()
        return res
    
    def after_step(self):
        if len(self.cbs) is 0: return True
        res = True
        for cb in sorted(self.cbs,key=lambda x: x._order): res = res and cb.after_step()
        return res
    
    def do_stop(self):
        try: return self.learn.stop
        finally: self.learn.stop = False

In [12]:
class Accuracy(Callbacks):
    _order = 10
    def begin_epoch(self,epoch):
        self.total_loss, self.total_acc = 0,0
        print('Acc')
        return super().begin_epoch(epoch)
    
    def begin_validate(self):
        nv = len(self.learn.data.train_dl)
        print('Train: ', self.epoch, self.total_loss / nv, self.total_acc / nv)
        self.total_loss, self.total_acc = 0,0
        return True
    
    def accuracy(self, out, yb): return (torch.argmax(out, dim=1)==yb).float().mean()
    
    def after_loss(self,loss,preds):
        self.total_loss += loss
        self.total_acc += self.accuracy(preds,self.yb)
        return True
    
    def after_epoch(self):
        total_batch = len(self.learn.data.valid_dl)
        print('Valid: ', self.epoch, self.total_loss / total_batch, self.total_acc / total_batch)
        return True

In [13]:
class TestEarly(Callbacks):
    _order = 1
    def begin_epoch(self,epoch):
        print('Early')
        return True

class TestMid(Callbacks):
    _order = 5
    def begin_epoch(self,epoch):
        print('Mid')
        return True

In [14]:
cb_early = TestEarly()
cb_mid = TestMid()
cb = Accuracy()
cb_early._order, cb_mid._order

(1, 5)

In [15]:
cbs = CallbackHandler([cb,cb_early,cb_mid])

In [16]:
for cb in cbs.cbs:
    print(cb)

<__main__.Accuracy object at 0x7f62e72f9208>
<__main__.TestEarly object at 0x7f62e72f96d8>
<__main__.TestMid object at 0x7f62e72f9748>


In [17]:
for cb in sorted(cbs.cbs,key=lambda x: x._order):
    print(cb)

<__main__.TestEarly object at 0x7f62e72f96d8>
<__main__.TestMid object at 0x7f62e72f9748>
<__main__.Accuracy object at 0x7f62e72f9208>


In [18]:
fit(2,learn,cbs)

Early
Mid
Acc
Train:  0 tensor(0.3181, grad_fn=<DivBackward0>) tensor(0.9017)
Valid:  0 tensor(0.1722) tensor(0.9493)
Early
Mid
Acc
Train:  1 tensor(0.1464, grad_fn=<DivBackward0>) tensor(0.9551)
Valid:  1 tensor(0.1212) tensor(0.9630)


In [19]:
learn = Learner(*get_model(data),loss_func,data)
cb_acc = Accuracy()
fit(4,learn,cb_handler=CallbackHandler([cb_acc]))

Acc
Train:  0 tensor(0.3126, grad_fn=<DivBackward0>) tensor(0.9042)
Valid:  0 tensor(0.2466) tensor(0.9229)
Acc
Train:  1 tensor(0.1408, grad_fn=<DivBackward0>) tensor(0.9572)
Valid:  1 tensor(0.1839) tensor(0.9467)
Acc
Train:  2 tensor(0.1082, grad_fn=<DivBackward0>) tensor(0.9672)
Valid:  2 tensor(0.1055) tensor(0.9691)
Acc
Train:  3 tensor(0.0896, grad_fn=<DivBackward0>) tensor(0.9728)
Valid:  3 tensor(0.0959) tensor(0.9708)


In [28]:
learn = Learner(*get_model(data),loss_func,data)
fit(2,learn,cb_handler=CallbackHandler([]))

# Fin

In [29]:
from notebook2script import *

In [30]:
notebook2script('04_callback.ipynb','callback')

Converted 04_callback.ipynb to uti/callback_04.py
