In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
from exp.nb_03 import *

## DataBunchLearner 

In [3]:
x_train,y_train,x_valid,y_valid=get_data()

In [4]:
train_mean=x_train.mean()
train_std=x_train.std()
x_train=normalize(x_train,train_mean,train_std)
x_valid=normalize(x_valid,train_mean,train_std)


In [12]:
train_ds,valid_ds = Dataset(x_train,y_train),Dataset(x_valid,y_valid)

In [25]:
#export
class DataBunch():
    def __init__(self,train_dl,valid_dl,c):
        self.train_dl=train_dl
        self.valid_dl=valid_dl
        self.c=c
        
    @property
    def train_ds(self): return self.train_dl.dataset
    
    @property
    def valid_ds(self): return self.valid_dl.dataset

In [26]:
nh=50
bs=64
c=y_train.max().item()+1
data = DataBunch(*get_dls(train_ds,valid_ds,bs=bs),c=c)

In [32]:
#export
def get_model(data,lr=0.05):
    ni=data.train_ds.x.shape[1]
    model=nn.Sequential(nn.Linear(ni,nh),nn.ReLU(),nn.Linear(nh,c))
    return model,optim.SGD(model.parameters(),lr=lr),

In [33]:
class Learner():
    def __init__(self,model,opt,data,loss):
        self.data=data
        self.model=model
        self.opt=opt
        self.loss=loss
    

In [34]:
learner=Learner(*get_model(data),data,F.cross_entropy) 

In [35]:
learner

<__main__.Learner at 0x7fc8fe178f10>

In [41]:
def fit(epochs,learner):
    for epoch in range(epochs):
        learner.model.train()
        for xb,yb in learner.data.train_dl:
            pred=learner.model(xb)
            loss=learner.loss(pred,yb)
            loss.backward()
            learner.opt.step()
            learner.opt.zero_grad
        learner.model.eval()
        with torch.no_grad():
            tot_loss,tot_acc = 0.,0.
            for xb,yb in learner.data.valid_dl:
                pred=learner.model(xb)
                loss=learner.loss(pred,yb)
                tot_loss+=loss
                tot_acc+=accuracy(pred,yb)
        nv = len(learner.data.valid_dl)
        print(epoch, tot_loss/nv, tot_acc/nv)
    return tot_loss/nv, tot_acc/nv
    

In [42]:
fit(10,learner)

0 tensor(10.4060) tensor(0.1120)
1 tensor(2.5311) tensor(0.1027)
2 tensor(2.8002) tensor(0.1075)
3 tensor(2.8951) tensor(0.0996)
4 tensor(3.3626) tensor(0.0973)
5 tensor(4.0307) tensor(0.0981)
6 tensor(4.0291) tensor(0.0981)
7 tensor(5.0162) tensor(0.0926)
8 tensor(7.8101) tensor(0.0973)
9 tensor(10.2441) tensor(0.0996)


(tensor(10.2441), tensor(0.0996))

In [57]:
def one_batch(xb,yb,cb):
    if not cb.begin_batch(xb,yb): return
    pred=cb.learner.model(xb)
    loss=learner.loss(pred,yb)
    if not cb.after_loss(loss): return
    loss.backward()
    if cb.after_backward(): cb.learner.opt.step()
    if cb.after_step(): cb.learner.opt.zero_grad


In [80]:
def all_batches(dl,cb):
    for xb,yb in dl:
        one_batch(xb,yb,cb)
        if cb.do_stop(): return


In [82]:
def fit(epochs,learn,cb):
    cb.begin_fit(learn)
    for epoch in range(epochs):
        if not cb.begin_epoch(epoch): continue
        all_batches(learn.data.train_dl,cb)
        if cb.begin_validate():
            with torch.no_grad(): all_batches(learn.data.valid_dl, cb)
        if cb.do_stop() or not cb.after_epoch(): break
    cb.after_fit()
        
        

In [83]:
class Callback():
    def begin_fit(self, learner):
        self.learner = learner
        return True
    def after_fit(self): return True
    def begin_epoch(self, epoch):
        self.epoch=epoch
        return True
    def begin_validate(self): return True
    def after_epoch(self): return True
    def begin_batch(self, xb, yb):
        self.xb,self.yb = xb,yb
        return True
    def after_loss(self, loss):
        self.loss = loss
        return True
    def after_backward(self): return True
    def after_step(self): return True

In [84]:
class CallbackHandler():
    def __init__(self,cbs=None):
        self.cbs = cbs if cbs else []

    def begin_fit(self, learner):
        self.learner,self.in_train = learner,True
        learner.stop = False
        res = True
        for cb in self.cbs: res = res and cb.begin_fit(learner)
        return res

    def after_fit(self):
        res = not self.in_train
        for cb in self.cbs: res = res and cb.after_fit()
        return res
    
    def begin_epoch(self, epoch):
        self.learner.model.train()
        self.in_train=True
        res = True
        for cb in self.cbs: res = res and cb.begin_epoch(epoch)
        return res

    def begin_validate(self):
        self.learner.model.eval()
        self.in_train=False
        res = True
        for cb in self.cbs: res = res and cb.begin_validate()
        return res

    def after_epoch(self):
        res = True
        for cb in self.cbs: res = res and cb.after_epoch()
        return res
    
    def begin_batch(self, xb, yb):
        res = True
        for cb in self.cbs: res = res and cb.begin_batch(xb, yb)
        return res

    def after_loss(self, loss):
        res = self.in_train
        for cb in self.cbs: res = res and cb.after_loss(loss)
        return res

    def after_backward(self):
        res = True
        for cb in self.cbs: res = res and cb.after_backward()
        return res

    def after_step(self):
        res = True
        for cb in self.cbs: res = res and cb.after_step()
        return res
    
    def do_stop(self):
        try:     return self.learner.stop
        finally: self.learner.stop = False

In [85]:
class TestCallback(Callback):
    def begin_fit(self,learn):
        super().begin_fit(learn)
        self.n_iters = 0
        return True
        
    def after_step(self):
        self.n_iters += 1
        print(self.n_iters)
        if self.n_iters>=10: self.learner.stop = True
        return True

In [86]:
fit(1,learner,CallbackHandler(cbs=[TestCallback()]))

1
2
3
4
5
6
7
8
9
10
