### AvgStats callback
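A callback that manages the training and validation statistics during the training loop: it resets them at the start of each epoch, accumulates them after every loss computation, and prints them at the end of each epoch.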

In [None]:
class AvgStatsCallback(Callback):
    """Collect and report per-epoch training and validation statistics.

    Two `AvgStats` accumulators are kept: one for training batches (built
    with the flag ``True``) and one for validation batches (flag ``False``).
    Both are cleared when an epoch begins, fed after every loss computation,
    and printed when the epoch ends.
    """
    def __init__(self, metrics):
        self.train_stats = AvgStats(metrics, True)
        self.valid_stats = AvgStats(metrics, False)

    def begin_epoch(self):
        # start every epoch from empty accumulators (training first, then validation)
        for acc in (self.train_stats, self.valid_stats):
            acc.reset()

    def after_loss(self):
        # route this batch to the accumulator matching the current phase;
        # no_grad keeps the metric bookkeeping out of autograd
        target = self.valid_stats
        if self.in_train:
            target = self.train_stats
        with torch.no_grad():
            target.accumulate(self.run)

    def after_epoch(self):
        # report both phases once the epoch is over
        for acc in (self.train_stats, self.valid_stats):
            print(acc)

### Recorder callback

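The Recorder callback records the learning rate of every optimizer parameter group and the loss of each training batch. It is used together with the LRF callback (below) to plot the loss against the learning rate.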

In [None]:
class Recorder(Callback):
    """Log learning rates and losses of training batches, and plot them.

    ``lrs`` holds one list per optimizer parameter group; ``losses`` holds the
    detached, CPU-resident loss of each training batch.
    """
    def begin_fit(self):
        # one empty learning-rate history per parameter group
        self.lrs = [list() for _ in self.opt.param_groups]
        self.losses = []

    def after_batch(self):
        # validation batches are not recorded
        if self.in_train:
            for group, history in zip(self.opt.param_groups, self.lrs):
                history.append(group['lr'])
            self.losses.append(self.loss.detach().cpu())

    def plot_lr(self, pgid=-1):
        # learning-rate history of parameter group `pgid`
        plt.plot(self.lrs[pgid])

    def plot_loss(self, skip_last=0):
        # raw loss curve, optionally omitting the final `skip_last` points
        stop = len(self.losses) - skip_last
        plt.plot(self.losses[:stop])

    def plot(self, skip_last=0, pgid=-1):
        """Plot loss against learning rate (log-scaled x axis)."""
        values = [t.item() for t in self.losses]
        rates = self.lrs[pgid]
        stop = len(values) - skip_last
        plt.xscale('log')
        plt.plot(rates[:stop], values[:stop])

### Param scheduler callback

In [None]:
class ParamScheduler(Callback):
    """Drive one optimizer hyper-parameter from schedule functions.

    Args:
        pname: key of the hyper-parameter inside each optimizer param group
            (e.g. ``'lr'``).
        sched_funcs: either a single schedule function, which is shared by
            every param group, or a list/tuple holding one function per param
            group. Each function receives ``n_epochs / epochs`` (training
            progress as tracked by the runner) and returns the new value.
    """
    _order = 1

    def __init__(self, pname, sched_funcs):
        self.pname = pname
        self.sched_funcs = sched_funcs

    def begin_fit(self):
        # already one function per group: nothing to broadcast
        if isinstance(self.sched_funcs, (list, tuple)): return
        self.sched_funcs = [self.sched_funcs] * len(self.opt.param_groups)

    def set_param(self):
        groups = self.opt.param_groups
        assert len(groups) == len(self.sched_funcs)
        for group, sched in zip(groups, self.sched_funcs):
            # evaluate the schedule at the current training progress
            group[self.pname] = sched(self.n_epochs / self.epochs)

    def begin_batch(self):
        # schedules are only applied to training batches
        if not self.in_train: return
        self.set_param()

### Learning Rate Finder

The Learning Rate Finder (LRF) is, in my opinion, one of the most valuable techniques popularized by fastai (it builds on Leslie Smith's learning-rate range test), and it makes neural-net training more reliable. The central idea is to run through a limited number of batches of the dataset, increasing the learning rate at every batch.

As a result, the parameter updates take larger and larger "steps" with each batch. At first, the loss is expected to decrease gradually as the learning rate increases. Eventually the learning rate reaches a point where the steps are "too large", and this shows up as a sudden, large jump in the loss value. A good learning rate is usually picked somewhat before that point, while the loss is still clearly decreasing.

Below is the LRF implemented as a callback. 

**NOTE:** this implementation is missing the part where it saves the model's weights initially and loads them back after the LRF finishes its routine, so the model is left in whatever state the last (large-learning-rate) update produced.

In [None]:
class LR_Find(Callback):
    """Learning-rate range test: raise the LR geometrically on every training batch.

    At training iteration ``n_iter`` every optimizer param group gets
    ``lr = min_lr * (max_lr / min_lr) ** (n_iter / max_iter)``. The LR therefore
    starts at ``min_lr`` and reaches ``max_lr`` at iteration ``max_iter``.
    Training is cancelled (``CancelTrainException``) once ``max_iter``
    iterations have run, or as soon as the loss diverges. Divergence means the
    loss is NaN/inf or more than 10x the best loss seen so far.

    Model weights are NOT saved or restored by this callback.

    Args:
        max_iter: number of iterations after which the search stops.
        min_lr: learning rate at iteration 0.
        max_lr: learning rate reached at iteration ``max_iter``.
    """
    # note this _order=1 where its parent class Callback has _order=0
    _order = 1

    def __init__(self, max_iter=100, min_lr=1e-6, max_lr=10):
        self.max_iter, self.min_lr, self.max_lr = max_iter, min_lr, max_lr
        self.best_loss = 1e9  # wildly large sentinel so the first loss always improves on it

    def begin_fit(self):
        # reset per fit, so a reused instance does not compare against a previous run's best loss
        self.best_loss = 1e9

    def begin_batch(self):  # runs before each batch begins
        if not self.in_train: return
        # increase lr exponentially; n_iter is read from the runner
        pos = self.n_iter / self.max_iter
        lr = self.min_lr * (self.max_lr / self.min_lr) ** pos
        for pg in self.opt.param_groups: pg['lr'] = lr

    def after_step(self):  # decide whether to stop, after the model params were updated
        import math
        # plain float: avoids holding on to the loss tensor (and the graph it references)
        loss = float(self.loss)
        # NaN compares False with everything, so without the isfinite check a
        # run whose loss became NaN would never be cancelled early
        if self.n_iter >= self.max_iter or not math.isfinite(loss) or loss > self.best_loss * 10:
            raise CancelTrainException()
        if loss < self.best_loss: self.best_loss = loss

#### LRF usage

In [None]:
# Run the learning-rate finder. LR_Find raises the learning rate on every
# training batch and raises CancelTrainException to end the run; Recorder logs
# the LR and loss of each training batch.
# NOTE(review): LR_Find does not restore the weights, so the model inside
# `learn` is left modified afterwards -- presumably rebuild it before real training.
learn = create_learner(get_model, loss_func, data)
run = Runner(cb_funcs=[LR_Find, Recorder])
run.fit(2, learn)  # 2 epochs is only an upper bound; LR_Find cancels after max_iter (default 100) iterations or on divergence

# show result
run.recorder.plot(skip_last=5)  # loss vs. learning rate (log x axis), dropping the last 5 points
run.recorder.plot_lr()  # learning-rate history of the last param group (pgid=-1)

### Cuda callback

Uses the callback mechanism to move the model's parameters to a GPU device at the start of training, and each batch of data to that device before it is processed.

In [None]:
# Device that CudaCallback moves the model and data to by default: the first CUDA GPU.
device = torch.device('cuda', 0)

class CudaCallback(Callback):
    """Place the model and every batch on ``device``.

    Args:
        device: target ``torch.device``. It defaults to the module-level
            ``device`` defined just above. The class itself can therefore be
            put in ``cb_funcs``, whose entries the Runner instantiates with no
            arguments (as with ``LR_Find``/``Recorder``). Passing a device
            explicitly still works as before.
    """
    def __init__(self, device=device):
        # NOTE: the default is bound once, when the class is defined.
        self.device = device

    def begin_fit(self):
        # move the model's parameters once, before training starts
        self.model.to(self.device)

    def begin_batch(self):
        # store the moved tensors on the runner (not on the callback) --
        # assumes the runner reads xb/yb from itself for the rest of the batch
        self.run.xb, self.run.yb = self.xb.to(self.device), self.yb.to(self.device)

#### usage

In [None]:
from functools import partial

# Train a CNN with CudaCallback added to the callback factories.
# Each entry of cb_funcs is presumably instantiated with no arguments (as with
# LR_Find/Recorder above), so the target device is bound explicitly here.
cbfs.append(partial(CudaCallback, device))
model = get_cnn_model(data)
opt = optim.SGD(model.parameters(), lr=0.4)
learn = Learner(model, opt, loss_func, data)
run = Runner(cb_funcs=cbfs)
run.fit(3, learn)