In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
#export
from exp.nb_02_minibatch_training import *

# `DataBunch` / `Learner`

In [25]:
# Fetch the data behind MNIST_URL and wrap each split in a Dataset
x_train, y_train, x_valid, y_valid = get_data(url=MNIST_URL)
train_ds = Dataset(x_train, y_train)
valid_ds = Dataset(x_valid, y_valid)
nh = 50                        # hidden-layer width, passed to get_model() further down
bs = 16                        # mini-batch size, passed to get_dls() further down
c = y_train.max().item() + 1   # largest label value + 1; ends up as the model's output width
loss_fn = F.cross_entropy

```
Step-1: Factor out the connected pieces of info out of the fit() argument list.
        fit(epochs, model, loss_fn, opt, train_dl, valid_dl)
Step-2: Replace the above fit() function with something more concise like:
        fit(epochs, learner)
        
This will allow us to tweak what's happening inside the training loop from other places in the code, because the Learner() object is mutable: changing any of its attributes will be noticed by the training loop.

```

In [44]:
#export
class DataBunch():
    """Keeps the training and validation dataloaders together with an optional `c`.

    `c` is stored untouched; `get_model()` reads it as the width of the last layer.
    """
    def __init__(self, train_dl, valid_dl, c=None):
        self.train_dl, self.valid_dl, self.c = train_dl, valid_dl, c

    @property
    def train_ds(self):
        """The `dataset` attribute of the training dataloader."""
        loader = self.train_dl
        return loader.dataset

    @property
    def valid_ds(self):
        """The `dataset` attribute of the validation dataloader."""
        loader = self.valid_dl
        return loader.dataset
    

In [45]:
# get_dls() returns the dataloaders that are unpacked into DataBunch's first two arguments
data = DataBunch(*get_dls(train_ds=train_ds, valid_ds=valid_ds, bs=bs), c=c)

In [46]:
#export
def get_model(data, lr=0.5, nh=50):
    """Build a Linear -> ReLU -> Linear network plus an SGD optimizer for it.

    The input width is `data.train_ds.x.shape[1]`, the hidden width is `nh`
    and the output width is `data.c`.

    Returns:
        (model, optimizer) where the optimizer is SGD over the model's
        parameters with learning rate `lr`.
    """
    n_in = data.train_ds.x.shape[1]
    layers = [torch.nn.Linear(n_in, nh),
              torch.nn.ReLU(),
              torch.nn.Linear(nh, data.c)]
    net = torch.nn.Sequential(*layers)
    sgd = optim.SGD(net.parameters(), lr=lr)
    return net, sgd


In [47]:
#export
class Learner():
    """Plain container for the four things the training loop needs.

    It holds the model, its optimizer, the loss function and the data, each
    stored as a same-named attribute with no processing.
    """
    def __init__(self, model, opt, loss_fn, data):
        self.model, self.opt = model, opt
        self.loss_fn, self.data = loss_fn, data
        

In [48]:
# get_model() returns (model, optimizer), which fill Learner's first two arguments
learner = Learner(*get_model(data=data, nh=nh), loss_fn=loss_fn, data=data)

In [84]:
def fit(epochs, learner):
    """Train `learner.model` for `epochs` epochs, validating after each one.

    Args:
        epochs:  number of passes over `learner.data.train_dl`.
        learner: object exposing `model`, `opt`, `loss_fn` and `data`
                 (the latter with `train_dl` and `valid_dl`).

    Returns:
        `(loss, accuracy)` of the training phase of the last epoch. Each value is
        the sum of the per-batch results divided by the number of training
        batches. With `epochs == 0` nothing is accumulated and both sums are zero.
    """
    num_train = len(learner.data.train_dl)
    num_valid = len(learner.data.valid_dl)
    print(f"num_train={num_train}\t num_valid={num_valid}")

    # Defined before the loop so the final return is valid even when it never runs
    loss_train = 0.
    acc_train  = 0.

    for epoch in range(epochs):
        loss_train = 0.
        loss_valid = 0.
        acc_train  = 0.
        acc_valid  = 0.

        ### Training phase
        learner.model.train()
        for xb, yb in learner.data.train_dl:
            preds = learner.model(xb)
            # Use the learner's own loss function: swapping `learner.loss_fn`
            # must change what the loop optimises.
            loss  = learner.loss_fn(input=preds, target=yb)
            acc   = accuracy(preds=preds, labels=yb)

            loss.backward()
            learner.opt.step()
            learner.opt.zero_grad()

            loss_train += loss.item()
            acc_train  += acc

        ### Validation Phase
        learner.model.eval()
        with torch.no_grad():
            for xb, yb in learner.data.valid_dl:
                preds = learner.model(xb)
                loss  = learner.loss_fn(input=preds, target=yb)
                acc   = accuracy(preds=preds, labels=yb)

                loss_valid += loss.item()
                acc_valid  += acc

        # Averages are taken over batches, not over individual samples
        print(f"epoch=[{epoch}/{epochs}]\t"\
              f"loss_train={(loss_train/num_train):.5f}\t acc_train={(acc_train/num_train):.5f}\t"\
              f"loss_valid={(loss_valid/num_valid):.5f}\t acc_valid={(acc_valid/num_valid):.5f}")

    return loss_train/num_train, acc_train/num_train


In [85]:
fit(epochs=2, learner=learner)

num_train=3125	 num_valid=313
epoch=[0/2]	loss_train=0.13000	 acc_train=0.97514	loss_valid=0.37034	 acc_valid=0.95927
epoch=[1/2]	loss_train=0.11727	 acc_train=0.97704	loss_valid=0.40302	 acc_valid=0.95737


(0.11726952236815576, tensor(0.9770))

# `CallbackHandler`

```
The "fit()" function performs the same operations on every batch,
so let's create a separate function "one_batch()" that does the work for a single batch.

NOTE: Add CALLBACKS, so that we can easily remove the complexity of our "fit()" function and make it more flexible
```

In [86]:
def one_batch(xb, yb, cb):
    """Process a single batch, asking the callback handler before each step.

    A falsy answer from `begin_batch` or `after_loss` abandons the batch at that
    point. `after_backward` only gates the optimizer step and `after_step` only
    gates `zero_grad`, so `after_step` is consulted even when the step was skipped.
    """
    if cb.begin_batch(xb, yb):
        out = cb.learner.model(xb)
        batch_loss = cb.learner.loss_fn(input=out, target=yb)
        if cb.after_loss(batch_loss):
            batch_loss.backward()
            if cb.after_backward():
                cb.learner.opt.step()
            if cb.after_step():
                cb.learner.opt.zero_grad()

def all_batches(dl, cb):
    """Feed every batch of `dl` to `one_batch()`, ending early once `cb.do_stop()` is truthy.

    The stop check happens after each batch, so the batch that triggers the stop
    is still processed in full.
    """
    for batch in dl:
        inputs, targets = batch
        one_batch(inputs, targets, cb)
        if cb.do_stop():
            break

def fit(epochs, learner, cb):
    """Callback-driven training loop.

    `cb.begin_fit` can veto the whole run, in which case `after_fit` is not called.
    A falsy `cb.begin_epoch` skips that epoch entirely, and a falsy
    `cb.begin_validate` skips only the validation pass. After each epoch the loop
    ends if `cb.do_stop()` is truthy or `cb.after_epoch()` is falsy.
    """
    if cb.begin_fit(learner):
        for epoch in range(epochs):
            if cb.begin_epoch(epoch):
                all_batches(learner.data.train_dl, cb)

                # Validation runs without gradient tracking
                if cb.begin_validate():
                    with torch.no_grad():
                        all_batches(learner.data.valid_dl, cb)

                if cb.do_stop() or not cb.after_epoch():
                    break
        cb.after_fit()
    

In [87]:
class Callback():
    """Base callback for the handler-based loop.

    Every hook returns True; the hooks that receive arguments also store them on
    the instance so that subclasses can read them later.
    """
    def begin_fit(self, learner):
        """Remember the learner being trained."""
        self.learner = learner
        return True

    def after_fit(self):
        return True

    def begin_epoch(self, epoch):
        """Remember the index of the epoch that is starting."""
        self.epoch = epoch
        return True

    def after_epoch(self):
        return True

    def begin_validate(self):
        return True

    def begin_batch(self, xb, yb):
        """Remember the current batch."""
        self.xb, self.yb = xb, yb
        return True

    def after_loss(self, loss):
        """Remember the loss of the current batch."""
        self.loss = loss
        return True

    def after_backward(self):
        return True

    def after_step(self):
        return True
    

In [91]:
class CallbackHandler():
    """Fans each training-loop event out to a list of callbacks.

    Each event starts from an initial answer and and-combines it with the
    callbacks' answers in list order. Once the running answer is falsy the
    remaining callbacks are not called for that event.
    """
    def __init__(self, cbs=None):
        self.cbs = cbs or []

    def _chain(self, result, hook, *args):
        """And-combine `result` with `cb.<hook>(*args)` for every callback, short-circuiting."""
        for cb in self.cbs:
            result = result and getattr(cb, hook)(*args)
        return result

    def begin_fit(self, learner):
        self.learner = learner
        self.in_train = True
        self.learner.stop = False
        return self._chain(True, 'begin_fit', learner)

    def after_fit(self):
        # Starts from "not in training", so callbacks only run when that is truthy
        return self._chain(not self.in_train, 'after_fit')

    def begin_epoch(self, epoch):
        self.learner.model.train()
        self.in_train = True
        return self._chain(True, 'begin_epoch', epoch)

    def after_epoch(self):
        return self._chain(True, 'after_epoch')

    def begin_validate(self):
        self.learner.model.eval()
        self.in_train = False
        return self._chain(True, 'begin_validate')

    def begin_batch(self, xb, yb):
        return self._chain(True, 'begin_batch', xb, yb)

    def after_loss(self, loss):
        # Falsy during validation, which makes one_batch() skip the backward pass
        return self._chain(self.in_train, 'after_loss', loss)

    def after_backward(self):
        return self._chain(True, 'after_backward')

    def after_step(self):
        return self._chain(True, 'after_step')

    def do_stop(self):
        """Report the learner's stop flag and clear it, so a request is seen only once."""
        try:
            return self.learner.stop
        finally:
            self.learner.stop = False


In [96]:
class TestCallback(Callback):
    """Counts optimizer steps, prints the count, and requests a stop from the 10th step on."""
    def begin_fit(self, learner):
        # Let the base class store the learner, then start counting from zero
        super().begin_fit(learner)
        self.n_iters = 0
        return True

    def after_step(self):
        self.n_iters = self.n_iters + 1
        print(f"n_iters = {self.n_iters}")
        limit_reached = self.n_iters >= 10
        if limit_reached:
            # Picked up (and cleared) by CallbackHandler.do_stop()
            self.learner.stop = True
        return True
    

In [97]:
# Run the callback-driven fit(); TestCallback sets learner.stop once it has counted 10 steps
fit(epochs=1, learner=learner, cb=CallbackHandler([TestCallback()]))

n_iters = 1
n_iters = 2
n_iters = 3
n_iters = 4
n_iters = 5
n_iters = 6
n_iters = 7
n_iters = 8
n_iters = 9
n_iters = 10


```
This is roughly how "fastai" does it for now (except that the "CallbackHandler()" can also modify and return "xb", "yb" and "loss").

We will try to create a single class that has access to everything and can change anything at any time.
The fact that we're passing "cb" to so many functions is a hint that they all should be in the same class!!!
We will call this class "Runner()" !!!
```

# `Runner`

In [109]:
#export
import re

# First pass: an underscore before every capital that starts a "Capitalised" word.
_camel_re1 = re.compile('(.)([A-Z][a-z]+)')
# Second pass: an underscore between a lowercase letter or digit and the capital after it.
_camel_re2 = re.compile('([a-z0-9])([A-Z])')

def camel2snake(name):
    """Convert a CamelCase identifier to snake_case, e.g. 'TrainEvalCallback' -> 'train_eval_callback'."""
    partly_split = _camel_re1.sub(r'\1_\2', name)
    fully_split = _camel_re2.sub(r'\1_\2', partly_split)
    return fully_split.lower()

class Callback():
    """Base class for `Runner` callbacks.

    Attributes that are not found on the callback itself are looked up on the
    runner it was attached to with `set_runner()`. Callbacks are called in
    ascending `_order`.
    """
    _order = 0

    def set_runner(self, run):
        """Attach the runner whose attributes this callback may read."""
        self.run = run

    def __getattr__(self, key):
        # Only reached when normal lookup failed. If `run` itself is what is missing
        # (set_runner() has not been called yet), evaluating `self.run` below would
        # re-enter this method without end, so fail with a plain AttributeError.
        # That also lets hasattr() and getattr(obj, name, default) behave as usual.
        if key == 'run':
            raise AttributeError(f"{type(self).__name__!r} object has no runner yet; call set_runner() first")
        return getattr(self.run, key)

    @property
    def name(self):
        """Snake-cased class name without a trailing 'Callback' ('callback' if nothing is left)."""
        name = re.sub(r'Callback$', '', self.__class__.__name__)
        return camel2snake(name or 'callback')
    

In [116]:
#export
class TrainEvalCallback(Callback):
    """
    Switches the model between training and validation mode and maintains two
    counters on the runner: `n_iters`, the number of training batches seen so
    far, and `n_epochs`, the fractional number of epochs elapsed.
    """

    def begin_fit(self):
        runner = self.run
        runner.n_epochs = 0.
        runner.n_iters  = 0

    def after_batch(self):
        # Validation batches do not advance the counters
        if self.in_train:
            self.run.n_epochs += 1./self.iters
            self.run.n_iters  += 1

    def begin_epoch(self):
        # Snap the fractional counter to the whole epoch that is starting
        self.run.n_epochs = self.epoch
        self.model.train()
        self.run.in_train = True

    def begin_validate(self):
        self.model.eval()
        self.run.in_train = False
        

```Recreating our TestCallback```

In [111]:
class TestCallback(Callback):
    """Asks the runner to stop once it has counted at least 10 training iterations."""
    def after_step(self):
        # The counter is written to the runner by TrainEvalCallback (`self.run.n_iters`),
        # and the runner has no `train_eval` attribute, so read it through the
        # callback -> runner attribute delegation instead of `self.train_eval`.
        if self.n_iters >= 10:
            # A truthy return here would only make Runner.one_batch() skip
            # opt.zero_grad(). Setting `stop` is what ends the batch loop.
            self.run.stop = True

In [112]:
# Check the conversion on a typical callback class name
cbname = 'TrainEvalCallback'
camel2snake(cbname)

'train_eval_callback'

In [115]:
# `name` strips the trailing "Callback" from the class name before snake-casing it
TrainEvalCallback().name

'train_eval'

In [119]:
# A subclass whose class name does not end in "Callback": `name` is just the snake-cased class name
class Lilashah(Callback):
    pass
a = Lilashah()
a.name

'lilashah'

In [120]:
#export
from typing import *

def listify(o):
    """Return `o` as a list.

    None becomes an empty list and a list is returned as-is (the same object,
    not a copy). A string is wrapped rather than split into characters, any
    other iterable is expanded with `list()`, and everything else is wrapped
    in a one-element list.
    """
    if o is None:
        result = []
    elif isinstance(o, list):
        result = o
    elif isinstance(o, str) or not isinstance(o, Iterable):
        result = [o]
    else:
        result = list(o)
    return result


In [141]:
#export
class Runner():
    """Training loop that owns its callbacks and exposes its state to them.

    At every stage the runner calls itself with the stage name (see `__call__`).
    A truthy result means "a callback asked to skip what would come next", so
    the step that follows the stage is not executed.
    """
    def __init__(self, cbs=None, cb_funcs=None):
        """
        Args:
            cbs:      callback instance(s) to use as given.
            cb_funcs: zero-argument callable(s); each is called to create a callback,
                      which is also attached to the runner as an attribute named
                      after the callback's `name`.

        A `TrainEvalCallback` is always appended and attached by name as well.
        """
        # Copy: listify() returns the caller's own list when it is given one, and
        # the appends below must not leak into the caller's argument.
        cbs = list(listify(cbs))
        for cbf in listify(cb_funcs):
            cb = cbf()
            setattr(self, cb.name, cb)
            cbs.append(cb)
        # Registered by name like the cb_funcs callbacks, so it can be reached
        # as an attribute of the runner.
        train_eval = TrainEvalCallback()
        setattr(self, train_eval.name, train_eval)
        self.stop = False
        self.cbs = cbs + [train_eval]

    # Shortcuts into the learner that is currently being fitted
    @property
    def opt(self):       return self.learner.opt
    @property
    def model(self):     return self.learner.model
    @property
    def loss_func(self): return self.learner.loss_fn
    @property
    def data(self):      return self.learner.data

    def one_batch(self, xb, yb):
        """Run one batch; any stage returning truthy abandons the batch at that point."""
        self.xb = xb
        self.yb = yb
        if self("begin_batch"):
            return
        self.pred = self.model(self.xb)
        if self("after_pred"):
            return
        self.loss = self.loss_func(self.pred, self.yb)
        # Outside training there is nothing to back-propagate
        if self("after_loss") or not self.in_train:
            return
        self.loss.backward()
        if self("after_backward"):
            return
        self.opt.step()
        if self("after_step"):
            return
        self.opt.zero_grad()

    def all_batches(self, dl):
        """Run every batch of `dl`, stopping early once `self.stop` is set."""
        self.iters = len(dl)
        for xb, yb in dl:
            if self.stop:
                break
            self.one_batch(xb, yb)
            self("after_batch")
        # A stop request only ends the current pass over a dataloader
        self.stop = False

    def fit(self, epochs, learner):
        """Fit `learner` for `epochs` epochs, validating after each training pass.

        "after_fit" is always dispatched, even when an exception is raised, and the
        reference to the learner is dropped afterwards.
        """
        self.epochs  = epochs
        self.learner = learner

        try:
            for cb in self.cbs:
                cb.set_runner(run=self)
            if self("begin_fit"):
                return
            for epoch in range(self.epochs):
                self.epoch = epoch

                ### Training Phase
                if not self("begin_epoch"):
                    self.all_batches(dl=self.data.train_dl)

                ### Validation Phase
                with torch.no_grad():
                    if not self("begin_validate"):
                        self.all_batches(dl=self.data.valid_dl)

                if self("after_epoch"): break

        finally:
            self("after_fit")
            self.learner = None

    def __call__(self, cb_name):
        """Dispatch the stage `cb_name` to the callbacks, in ascending `_order`.

        Callbacks without an attribute of that name are skipped. Returns True as
        soon as one callback returns a truthy value (later callbacks are then not
        called for this stage), and False otherwise.
        """
        for cb in sorted(self.cbs, key=lambda x: x._order):
            f = getattr(cb, cb_name, None)                    ### equivalent to "cb.cb_name"; default is None
            if f and f():
                return True
        return False


### We will also define a new `Callback` that calculates and stores the `metrics`

In [158]:
#export
class AvgStats():
    """Running, sample-weighted totals of the loss and of each metric for one phase.

    Args:
        metrics:  metric function(s) called as `m(pred, yb)`.
        in_train: whether these are training statistics; only used to label `repr()`.
    """
    def __init__(self, metrics, in_train):
        self.metrics  = listify(metrics)
        self.in_train = in_train
        # Start from a clean state so repr() and accumulate() work before any explicit reset()
        self.reset()

    def reset(self):
        """Zero the totals and the sample count."""
        self.tot_loss = 0.
        self.count    = 0
        self.tot_mets = [0.] * len(self.metrics)

    @property
    def all_stats(self):
        """Metric totals followed by the loss total (the latter as a plain number)."""
        # `tot_loss` is the float 0. until accumulate() has added a loss with `.item()` to it
        loss = self.tot_loss.item() if hasattr(self.tot_loss, 'item') else self.tot_loss
        return self.tot_mets + [loss]

    @property
    def avg_stats(self):
        """`all_stats` divided by the number of samples accumulated so far."""
        return [o/self.count for o in self.all_stats]

    def __repr__(self):
        # Nothing accumulated yet: print nothing rather than dividing by zero
        if not self.count:
            return ""
        return f"{'train: ' if self.in_train else 'valid: '}{self.avg_stats}"

    def accumulate(self, run):
        """Add the current batch of `run` (its `xb`, `yb`, `pred` and `loss`) to the totals."""
        # Weight by the batch size so that a smaller batch does not skew the averages
        bn = run.xb.shape[0]
        self.tot_loss += run.loss * bn
        self.count    += bn
        for i, m in enumerate(self.metrics):
            self.tot_mets[i] += m(run.pred, run.yb) * bn
                  

In [163]:
#export
class AvgStatsCallback(Callback):
    """Keeps separate training and validation `AvgStats` and prints both after every epoch."""
    def __init__(self, metrics):
        self.train_stats, self.valid_stats = (
            AvgStats(metrics=metrics, in_train=flag) for flag in (True, False)
        )

    def begin_epoch(self):
        # Each epoch is reported on its own
        for stats in (self.train_stats, self.valid_stats):
            stats.reset()

    def after_loss(self):
        if self.in_train:
            current = self.train_stats
        else:
            current = self.valid_stats
        # Bookkeeping only: keep it out of the autograd graph
        with torch.no_grad():
            current.accumulate(self.run)

    def after_epoch(self):
        summary = (f"epoch=[{self.epoch+1}/{self.epochs}]:\t"
                   f"{self.train_stats}\t"
                   f"{self.valid_stats}")
        print(summary)
        

In [164]:
# New model/optimizer pair (get_model's default lr and nh) wrapped in a new Learner
learner = Learner(*get_model(data), loss_fn=loss_fn, data=data)

In [165]:
# Track accuracy alongside the loss; Runner appends its own TrainEvalCallback to the callbacks passed here
stats_callback = AvgStatsCallback(metrics=[accuracy])
run = Runner(cbs=stats_callback)

In [166]:
# Train for two epochs; AvgStatsCallback prints one summary line after each epoch
run.fit(epochs=2, learner=learner)

epoch=[1/2]:	train: [tensor(0.8930)]	valid: [tensor(0.9423)]
epoch=[2/2]:	train: [tensor(0.9367)]	valid: [tensor(0.9480)]


In [167]:
# Display the batch size set at the top of the notebook
bs

16