In [4]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
#export
import torch
import matplotlib.pyplot as plt

import time
from functools import partial

from fastprogress.fastprogress import master_bar, progress_bar
from fastprogress.fastprogress import format_time

from exp.nb_utils import listify

In [6]:
#export
import re

_camel_re1 = re.compile('(.)([A-Z][a-z]+)')
_camel_re2 = re.compile('([a-z0-9])([A-Z])')


def camel2snake(name):
    """Convert a CamelCase identifier to snake_case.

    Two regex passes are applied: the first inserts an underscore before a
    capitalised word preceded by any character (``'EvalCallback'`` ->
    ``'Eval_Callback'``), the second before a capital that directly follows
    a lowercase letter or digit (``'getHTTP'`` -> ``'get_HTTP'``). The
    result is then lowercased.
    """
    words_split = _camel_re1.sub(r'\1_\2', name)
    fully_split = _camel_re2.sub(r'\1_\2', words_split)
    return fully_split.lower()


class Callback():
    """Base class for training-loop callbacks.

    ``cb(event_name)`` runs the method named ``event_name`` if the callback
    defines one. Attributes not found on the callback itself are looked up on
    the attached learner (see ``set_learner``), so handlers can read e.g.
    ``self.model`` or ``self.epoch`` directly. Assignments are NOT delegated:
    writing state for the learner must go through ``self.learner.<attr>``.
    """
    # Sort key for ordering callbacks; the sorting happens outside this file
    # (presumably ascending, lower runs first -- confirm in the learner).
    _order = 0

    def set_learner(self, learner): self.learner = learner

    def __getattr__(self, k):
        # __getattr__ only runs after normal lookup fails. If 'learner' itself
        # is missing (callback not attached yet, or a fresh instance being
        # built by copy/pickle), looking up self.learner would re-enter
        # __getattr__('learner') forever and raise RecursionError, which
        # hasattr()/getattr(..., default) do not catch. Raise AttributeError.
        if k == 'learner':
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute 'learner' "
                f"(call set_learner first)")
        return getattr(self.learner, k)

    def __call__(self, cb_name):
        """Run handler ``cb_name`` if defined.

        Returns True only when the handler exists and returns a truthy value;
        otherwise False.
        """
        f = getattr(self, cb_name, None)
        if f and f(): return True
        return False

    @property
    def name(self):
        """Snake-case name of the class with a trailing 'Callback' removed
        ('callback' if nothing else remains)."""
        name = re.sub(r'Callback$', '', self.__class__.__name__)
        return camel2snake(name or 'callback')


class CancelTrainException(Exception):
    """Raised from a callback to stop training altogether (LR_Find uses it);
    presumably caught by the learner's fit loop, which is not in this file."""


class CancelEpochException(Exception):
    """Raised from a callback to cut the current epoch short; presumably
    caught by the learner's epoch loop, which is not in this file."""


class CancelBatchException(Exception):
    """Raised from a callback to skip the rest of the current batch;
    presumably caught by the learner's batch loop, which is not in this file."""

In [5]:
#export
class TrainEvalCallback(Callback):
    """Toggles train/eval mode and maintains training-progress counters.

    Writes these attributes onto the learner (readable from any callback via
    attribute delegation):
      * ``in_train`` -- True during training batches, False during validation.
      * ``cur_train_epoch`` -- copy of the current ``epoch`` index.
      * ``cur_train_epoch_flt`` -- fractional epochs of training completed,
        advanced by ``1/iters`` after each training batch.
      * ``cur_train_iter`` -- training batches processed since ``begin_fit``.
    """
    def begin_fit(self):
        self.learner.cur_train_epoch_flt = 0.
        self.learner.cur_train_iter = 0

    def begin_epoch(self):
        self.learner.cur_train_epoch = self.epoch
        # Re-sync the fractional counter to the epoch index. Otherwise float
        # round-off from summing 1/iters accumulates across epochs, and an epoch
        # cut short (e.g. by CancelEpochException) leaves the counter permanently
        # behind, skewing any schedule driven by it.
        # Assumes `epoch` is 0-based (consistent with begin_fit starting at 0.)
        # -- confirm against the learner.
        self.learner.cur_train_epoch_flt = float(self.epoch)
        self.model.train()
        self.learner.in_train = True

    def begin_validate(self):
        self.model.eval()
        self.learner.in_train = False

    def after_batch(self):
        # Only training batches advance the progress counters.
        if not self.in_train: return
        self.learner.cur_train_epoch_flt += 1./self.iters
        self.learner.cur_train_iter += 1

In [29]:
#export
class AvgStats():
    """Running, batch-size-weighted totals and averages of loss and metrics.

    Args:
        metrics: a metric callable or list of them; each is called as
            ``m(pred, yb)`` and is expected to return a scalar (tensor or number).
        in_train: True for training statistics, False for validation; only
            used to label the ``repr``.
    """
    def __init__(self, metrics, in_train):
        self.metrics, self.in_train = listify(metrics), in_train
        self.reset()

    def reset(self):
        """Zero the accumulated loss, sample count and metric totals."""
        self.total_loss = 0.
        self.count = 0
        self.total_metrics = [0.] * len(self.metrics)

    @property
    def all_stats(self):
        """Summed ``[loss, *metrics]`` over all accumulated samples, as floats.

        ``float()`` accepts both the initial ``0.`` (which has no ``.item()``)
        and a single-element tensor, so this is safe right after ``reset()``,
        and metrics come back as plain floats just like the loss.
        """
        return [float(s) for s in [self.total_loss] + self.total_metrics]

    @property
    def avg_stats(self):
        """Per-sample averages of ``all_stats``.

        If nothing was accumulated (e.g. a phase that ran no batches), every
        entry is NaN instead of raising ZeroDivisionError.
        """
        if not self.count:
            return [float('nan')] * (1 + len(self.metrics))
        return [s / self.count for s in self.all_stats]

    def __repr__(self):
        if not self.count:
            return ""
        return f"{'train' if self.in_train else 'valid'}: {self.avg_stats}"

    def accumulate(self, learner):
        """Add the current batch's loss and metrics, weighted by batch size.

        Reads ``learner.xb`` (size of the first dimension is the batch size),
        ``learner.loss``, ``learner.pred`` and ``learner.yb``.
        """
        bn = learner.xb.shape[0]
        self.total_loss += learner.loss * bn
        self.count += bn
        for i, m in enumerate(self.metrics):
            self.total_metrics[i] += m(learner.pred, learner.yb) * bn

In [36]:
#export
class AvgStatsCallback(Callback):
    """Accumulates per-epoch loss/metric averages and logs them as a table.

    Args:
        metrics: metric callable(s), shared by the train and valid AvgStats.

    Logging goes through ``self.logger`` (looked up on the learner): one header
    row in ``begin_fit`` and one row per epoch in ``after_epoch``.
    """
    def __init__(self, metrics):
        self.train_stats,self.valid_stats = AvgStats(metrics,True),AvgStats(metrics,False)

    def begin_fit(self):
        def metric_name(m):
            # functools.partial objects and callable instances have no
            # __name__; unwrap partials, then fall back to the class name.
            while isinstance(m, partial): m = m.func
            return getattr(m, '__name__', type(m).__name__)
        met_names = ['loss'] + [metric_name(m) for m in self.train_stats.metrics]
        names = (['epoch'] + [f'train_{n}' for n in met_names]
                 + [f'valid_{n}' for n in met_names] + ['time'])
        # Header row of the results table.
        self.logger(names)

    def begin_epoch(self):
        self.train_stats.reset()
        self.valid_stats.reset()
        self.start_time = time.time()

    def after_loss(self):
        # Route the batch into the stats for the current phase; no gradients
        # are needed for bookkeeping.
        stats = self.train_stats if self.in_train else self.valid_stats
        with torch.no_grad(): stats.accumulate(self.learner)

    def after_epoch(self):
        # One row: epoch index, train averages, valid averages, epoch wall time.
        stats = [str(self.epoch)]
        for o in [self.train_stats, self.valid_stats]:
            stats += [f'{v:.6f}' for v in o.avg_stats]
        stats += [format_time(time.time() - self.start_time)]
        self.logger(stats)

In [37]:
#export
class Recorder(Callback):
    """Records the learning rate and loss of every training batch and plots them.

    Only the last optimizer parameter group's ``'lr'`` is recorded.
    """
    def begin_fit(self):
        # Fresh histories for each fit.
        self.lrs = []
        self.losses = []

    def after_batch(self):
        if self.in_train:
            last_group = self.opt.param_groups[-1]
            self.lrs.append(last_group['lr'])
            # Store a detached CPU copy of the loss rather than the live tensor.
            self.losses.append(self.loss.detach().cpu())

    def plot_lr(self):
        """Plot recorded learning rates against training batch index."""
        plt.plot(self.lrs)

    def plot_loss(self):
        """Plot recorded training losses against training batch index."""
        plt.plot(self.losses)

    def plot(self, skip_last=0):
        """Plot loss against learning rate with a log-scaled x axis.

        Args:
            skip_last: number of trailing points to leave out (e.g. the
                diverging tail of an LR-finder run).
        """
        loss_values = [t.item() for t in self.losses]
        keep = len(loss_values) - skip_last
        plt.xscale('log')
        plt.plot(self.lrs[:keep], loss_values[:keep])


class ParamScheduler(Callback):
    """Sets an optimizer hyper-parameter from schedule function(s) before
    every training batch.

    Args:
        pname: key to set in each optimizer param group (e.g. ``'lr'``).
        sched_funcs: a single schedule (broadcast to every param group) or one
            per param group; if the counts differ otherwise, ``zip`` pairs them
            up to the shorter length. Each is called with the fraction of the
            fit completed, ``cur_train_epoch_flt / epochs``.
    """
    _order = 1

    def __init__(self, pname, sched_funcs):
        self.pname = pname
        self.sched_funcs = listify(sched_funcs)

    def begin_batch(self):
        if self.in_train:
            groups = self.opt.param_groups
            funcs = self.sched_funcs
            if len(funcs) == 1:
                # Broadcast a single schedule to every parameter group.
                funcs = funcs * len(groups)
            progress = self.cur_train_epoch_flt / self.epochs
            for func, group in zip(funcs, groups):
                group[self.pname] = func(progress)


class LR_Find(Callback):
    """Learning-rate range test.

    The LR grows geometrically from ``min_lr`` to ``max_lr`` over ``max_iter``
    training batches (applied to every param group). Training is cancelled
    once ``max_iter`` batches have run, or the loss becomes NaN or exceeds 10x
    the best loss seen.

    Args:
        max_iter: training batches over which the sweep runs.
        min_lr: learning rate for the first batch.
        max_lr: learning rate reached after ``max_iter`` batches.
    """
    _order = 1

    def __init__(self, max_iter=100, min_lr=1e-6, max_lr=10):
        self.max_iter,self.min_lr,self.max_lr = max_iter,min_lr,max_lr
        self.best_loss = 1e9

    def begin_fit(self):
        # Reset so a reused instance does not inherit a previous run's best loss.
        self.best_loss = 1e9

    def begin_batch(self):
        if not self.in_train: return
        pos = self.cur_train_iter/self.max_iter
        # Geometric interpolation: min_lr at pos 0, max_lr at pos 1.
        lr = self.min_lr * (self.max_lr/self.min_lr) ** pos
        for pg in self.opt.param_groups: pg['lr'] = lr

    def after_step(self):
        import math
        loss = float(self.loss)
        # Stop on the same global counter begin_batch uses. The previous check
        # on `self.iter` looks like a per-epoch batch index (ProgressCallback
        # feeds it to a per-epoch bar), so with fewer than max_iter batches per
        # epoch the search never stopped and the LR climbed past max_lr.
        # NaN must be checked explicitly: `nan > x` is always False.
        if (self.cur_train_iter >= self.max_iter or math.isnan(loss)
                or loss > self.best_loss*10):
            raise CancelTrainException()
        # Keep a plain float, not the loss tensor.
        if loss < self.best_loss: self.best_loss = loss

In [42]:
#export
class ProgressCallback(Callback):
    """Shows fastprogress bars and routes table logging into them.

    A master bar spans ``range(epochs)``. A fresh child bar over the current
    dataloader is created at the start of each training and validation pass.
    The learner's logger is set to write table rows to the master bar.
    """
    _order = -1

    def set_pb(self):
        # New child bar over the current dataloader, then advance the master bar.
        self.pb = progress_bar(self.dl, parent=self.mbar)
        self.mbar.update(self.epoch)

    def begin_fit(self):
        self.mbar = master_bar(range(self.epochs))
        # NOTE(review): a call to self.mbar.on_iter_begin() was commented out
        # here; confirm it is not required by the installed fastprogress version.
        table_writer = partial(self.mbar.write, table=True)
        self.learner.set_logger(table_writer)

    def begin_epoch(self):
        self.set_pb()

    def begin_validate(self):
        self.set_pb()

    def after_batch(self):
        self.pb.update(self.iter)

    def after_fit(self):
        self.mbar.on_iter_end()

In [43]:
#export
class CudaCallback(Callback):
    """Moves the model to CUDA at fit start and each batch's inputs/targets
    to CUDA in ``begin_batch``."""
    def begin_fit(self): self.model.cuda()

    def begin_batch(self):
        # Write through self.learner: a plain assignment on the callback would
        # only set callback attributes. The previous code wrote to `self.run`,
        # which nothing in this file defines (it resolves via __getattr__ to
        # learner.run and fails unless the learner happens to have one).
        self.learner.xb, self.learner.yb = self.xb.cuda(), self.yb.cuda()

In [44]:
!python notebook2script.py callbacks.ipynb

Converted callbacks.ipynb to exp\nb_callbacks.py
