In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

# Getting Rid of the `Runner`

In [2]:
from exp.nb_09 import *

- On closer inspection, the `Runner` is not essential: the `Learner` already holds all the state training needs (model, data, loss function, optimizer settings). So we move the training loop into `Learner` itself and drop the extra object.

## Loading Imagenette data

In [3]:
# Fetch and extract the IMAGENETTE_160 variant of Imagenette; `path` is its root folder.
imagenette_url = datasets.URLs.IMAGENETTE_160
path = datasets.untar_data(imagenette_url)

In [5]:
bs = 64
# Per-item transforms, applied in order: force RGB, resize to a fixed 128,
# then convert to a byte tensor and finally a float tensor.
tfms = [make_rgb, ResizeFixed(128), to_byte_tensor, to_float_tensor]

img_list = ImageList.from_files(path, tfms=tfms)
# Train/validation split: files are assigned by `grandparent_splitter`, with
# 'val' naming the validation side (presumably the grandparent folder name).
split = SplitData.split_by_func(
    img_list, partial(grandparent_splitter, valid_name='val'))
lab_list = label_by_func(split, parent_labeler, proc_y=CategoryProcessor())
# 3 input channels (RGB images) and 10 output classes.
data = lab_list.to_databunch(bs, c_in=3, c_out=10, num_workers=4)

In [6]:
# Callback factories (classes or partials), not instances — presumably passed to
# `Learner` as `cb_funcs`, which calls each one with no arguments.
callbacks = [
    partial(AvgStatsCallback, accuracy),
    CudaCallback,
    partial(BatchTransformXCallback, norm_imagenette),
]

In [7]:
# Number of filters per layer — presumably consumed by the model builder later on.
nfs = [32, 32, 32, 32]

In [None]:
# Rebuilding the Learner
def param_getter(m):
    """Return `m.parameters()`; used as the default `splitter` of `Learner`."""
    params = m.parameters()
    return params

class Learner():
    def __init__(self, model, data, loss_func, opt_func=sgd_opt, lr=1e-2,
                 splitter=param_getter, cbs=None, cb_funcs=None):
        """Store the training state and register the callbacks.

        Args:
            model: network that is called on each input batch.
            data: the data object this learner trains on.
            loss_func: called as `loss_func(pred, yb)` to compute the loss.
            opt_func: stored; presumably used later to build `self.opt`.
            lr: stored learning rate.
            splitter: stored; defaults to `param_getter`.
            cbs: callback instance(s) to register.
            cb_funcs: zero-argument callable(s); each is called once and the
                resulting callback is registered.
        """
        self.model = model
        self.data = data
        self.loss_func = loss_func
        self.opt_func = opt_func
        self.lr = lr
        self.splitter = splitter
        self.in_train = False
        self.logger = print
        self.opt = None

        # Callbacks are attached to the Learner itself, so there is no
        # separate `Runner` to create and wire up.
        self.cbs = []
        self.add_cb(TrainEvalCallback())  # always registered first
        self.add_cbs(cbs)
        self.add_cbs(make_cb() for make_cb in listify(cb_funcs))
        
    def add_cbs(self, cbs):
        """Register every callback in `cbs` via `add_cb`.

        `cbs` goes through `listify`, so presumably a single callback, a
        collection, or None are all accepted.
        """
        for callback in listify(cbs):
            self.add_cb(callback)
            
    def add_cb(self, cb):
        """Register a single callback on this learner.

        Hands the learner to `cb` via `set_runner`, exposes the callback as
        the attribute `self.<cb.name>`, and appends it to `self.cbs`.
        """
        # Give the callback its back-reference to the learner (the role the
        # old `Runner` used to play). The method is `set_runner`; the previous
        # `set_runne` typo broke every registration, including in `__init__`.
        cb.set_runner(self)
        setattr(self, cb.name, cb)
        self.cbs.append(cb)
    
    def remove_cbs(self, cbs):
        """Remove each callback in `cbs` from `self.cbs`.

        NOTE(review): the attribute that `add_cb` set under `cb.name` is left
        in place; only the entry in `self.cbs` is removed.
        """
        for callback in listify(cbs):
            self.cbs.remove(callback)
            
    def one_batch(self, i, xb, yb):
        """Process one batch: forward pass and loss, plus backward and step when training.

        `self(name)` fires the named callback event (the dispatcher is not
        shown here). A `CancelBatchException` raised during the batch fires
        'after_cancel_batch'. 'after_batch' always fires via `finally`.

        Args:
            i: index of the batch within the loader being iterated.
            xb: input batch.
            yb: target batch.
        """
        try:
            self.iter = i
            # The model and loss read `self.xb`/`self.yb` (not the arguments),
            # so 'begin_batch' callbacks can replace the batch.
            self.xb, self.yb = xb, yb;                      self('begin_batch')
            self.pred = self.model(self.xb);                self('after_pred')
            self.loss = self.loss_func(self.pred, self.yb); self('after_loss')
            # Outside training, stop after the loss; `finally` still fires 'after_batch'.
            # (`in_train` starts False; presumably toggled by TrainEvalCallback — confirm.)
            if not self.in_train: return
            self.loss.backward();                           self('after_backward')
            self.opt.step();                                self('after_step')
            # Reset gradients so they don't accumulate into the next batch.
            self.opt.zero_grad()
        except CancelBatchException:                        self('after_cancel_batch')
        finally:                                            self('after_batch')
            
    def all_batches(self):
        """Run `one_batch` on every `(xb, yb)` batch of `self.dl`.

        Records the number of batches in `self.iters`. A `CancelEpochException`
        ends the loop early and fires 'after_cancel_epoch'.
        """
        self.iters = len(self.dl)
        try:
            for batch_idx, (inputs, targets) in enumerate(self.dl):
                self.one_batch(batch_idx, inputs, targets)
        except CancelEpochException:
            self('after_cancel_epoch')
            
    def do_begin_fit(self, epochs):
        """Record the epoch count, reset `self.loss` to a zero tensor, and fire 'begin_fit'."""
        self.epochs = epochs
        self.loss = tensor(0.)
        self('begin_fit')
        
    def do_begin_epoch(self, epoch):
        """Start epoch `epoch` on the training loader and fire 'begin_epoch'.

        Sets `self.epoch` and points `self.dl` at `self.data.train_dl`.

        Returns:
            Whatever the 'begin_epoch' dispatch returns (presumably a flag that
            lets callbacks skip the epoch — TODO confirm against `__call__`).
        """
        # Previously `epoch, self.data, train_dl`: three values unpacked into
        # two targets (ValueError) plus an undefined `train_dl`. The intent is
        # the training loader stored on the data object.
        self.epoch, self.dl = epoch, self.data.train_dl
        return self('begin_epoch')
    
    
        