In [14]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [15]:
#export
import time
from functools import partial

import torch
from fastprogress.fastprogress import master_bar, progress_bar
from fastprogress.fastprogress import format_time

from exp.nb_utils import listify

In [5]:
#export
import re

# Matches any character followed by a capitalised word (one upper-case letter
# and at least one lower-case letter), e.g. the "lE" + "val" boundary in "TrainEval".
_camel_re1 = re.compile('(.)([A-Z][a-z]+)')
# Matches a lower-case letter or digit directly followed by an upper-case letter.
_camel_re2 = re.compile('([a-z0-9])([A-Z])')

def camel2snake(name):
    """Convert a CamelCase identifier such as ``TrainEval`` to ``train_eval``.

    Two substitution passes insert underscores at word boundaries, then the
    whole string is lower-cased.
    """
    # Pass 1: put "_" in front of every capitalised word.
    words_split = _camel_re1.sub(r'\1_\2', name)
    # Pass 2: split the remaining lower/digit -> upper transitions.
    fully_split = _camel_re2.sub(r'\1_\2', words_split)
    return fully_split.lower()

class Callback():
    """Base class for callbacks.

    Attributes that are not found on the callback itself are looked up on the
    learner attached with `set_learner`. Calling the callback with an event
    name runs the method of that name, if there is one.
    """
    _order=0

    def set_learner(self, learner):
        """Attach the learner that unknown attribute lookups are delegated to."""
        self.learner = learner

    def __getattr__(self, k):
        """Fallback lookup: forward attribute `k` to `self.learner`.

        Raises:
            AttributeError: if no learner has been attached yet, or the learner
                does not have `k`.
        """
        # `__getattr__` only runs after normal lookup failed. If `learner`
        # itself is the missing attribute, reading `self.learner` below would
        # re-enter this method forever (RecursionError), which also breaks
        # `getattr(cb, name, default)`, `hasattr`, and copy/pickle.
        if k == 'learner':
            raise AttributeError(
                f"{self.__class__.__name__!r} object has no attribute 'learner'; "
                "call set_learner() first")
        return getattr(self.learner, k)

    def __call__(self, cb_name):
        """Run the method named `cb_name` if it exists.

        Returns True only when the method exists and returns a truthy value;
        otherwise False.
        """
        f = getattr(self, cb_name, None)
        if f and f(): return True
        return False

    @property
    def name(self):
        """snake_case class name with a trailing 'Callback' removed
        ('callback' when nothing is left after removal)."""
        name = re.sub(r'Callback$', '', self.__class__.__name__)
        return camel2snake(name or 'callback')

    
class CancelTrainException(Exception):
    """Control-flow exception; presumably raised to abort the whole fit
    (the code that catches it is not in this file — confirm)."""


class CancelEpochException(Exception):
    """Control-flow exception; presumably raised to abort the current epoch
    (the code that catches it is not in this file — confirm)."""


class CancelBatchException(Exception):
    """Control-flow exception; presumably raised to abort the current batch
    (the code that catches it is not in this file — confirm)."""

In [16]:
#export
class TrainEvalCallback(Callback):
    """Maintains the learner's progress counters (`n_epochs`, `n_iter`) and
    `in_train` flag, and switches the model between train and eval mode."""

    def begin_fit(self):
        """Reset the progress counters on the learner."""
        self.learner.n_epochs=0.
        self.learner.n_iter=0

    def begin_epoch(self):
        """Snap `n_epochs` to the current epoch and enter training mode."""
        self.learner.n_epochs=self.epoch
        self.model.train()
        self.learner.in_train=True

    def begin_validate(self):
        """Enter evaluation mode."""
        self.model.eval()
        # Write through `self.learner`, like every other state update in this
        # class. The previous `self.run.in_train` resolved `run` on the learner
        # via `__getattr__`, so the flag was not cleared on the learner itself.
        self.learner.in_train=False

    def after_batch(self):
        """Advance the counters after a training batch; no-op while validating."""
        if not self.in_train: return
        # Each training batch adds 1/iters to the fractional epoch count.
        # presumably `iters` is the number of batches per epoch — TODO confirm
        self.learner.n_epochs += 1./self.iters
        self.learner.n_iter   += 1

In [8]:
#export
class AvgStats():
    """Running totals of the loss and of each metric for one phase
    (training when `in_train` is true, validation otherwise)."""

    def __init__(self, metrics, in_train):
        """
        Args:
            metrics: a metric callable or a collection of them (passed through `listify`).
            in_train: whether these stats describe the training phase.
        """
        self.metrics, self.in_train = listify(metrics), in_train
        self.reset()

    def reset(self):
        """Zero the accumulated loss, sample count and per-metric totals."""
        self.total_loss = 0.
        # `count` is read by `avg_stats` and `__repr__`, so it must be
        # initialised here (this line used to re-assign `total_loss` instead).
        self.count = 0
        self.total_metrics = [0.] * len(self.metrics)

    @property
    def all_stats(self):
        """Accumulated `[loss, *metrics]` totals, not yet averaged."""
        loss = self.total_loss
        # After `reset` the loss is a plain float; once values exposing
        # `.item()` have been accumulated it is unwrapped to a Python number.
        if hasattr(loss, 'item'): loss = loss.item()
        return [loss] + self.total_metrics

    @property
    def avg_stats(self):
        """`all_stats` divided by `count`.

        Raises ZeroDivisionError when nothing has been accumulated yet.
        """
        return [s / self.count for s in self.all_stats]

    def __repr__(self):
        if not self.count:
            return ""
        return f"{'train' if self.in_train else 'valid'}: {self.avg_stats}"

    def accumulate(self, learner):
        """Add the current batch's loss and metric values, weighted by batch size.

        NOTE(review): assumes the learner exposes `xb` (input batch, first
        dimension = batch size), `yb`, `pred` and `loss`, and that metrics are
        called as `metric(pred, yb)` — TODO confirm against the Learner class.
        """
        batch_size = learner.xb.shape[0]
        self.total_loss += learner.loss * batch_size
        self.count += batch_size
        for i, metric in enumerate(self.metrics):
            self.total_metrics[i] += metric(learner.pred, learner.yb) * batch_size

In [10]:
#export
class AvgStatsCallback(Callback):
    """Accumulates loss/metric statistics separately for the training and
    validation phases and sends one row per epoch to `self.logger`."""

    def __init__(self, metrics):
        self.train_stats,self.valid_stats = AvgStats(metrics,True),AvgStats(metrics,False)

    @staticmethod
    def _metric_name(metric):
        """Best-effort display name for a metric callable."""
        # Plain functions carry `__name__`; `functools.partial` objects do not,
        # so fall back to the wrapped function, then to the class name.
        name = getattr(metric, '__name__', None)
        if name is None:
            name = getattr(getattr(metric, 'func', None), '__name__', None)
        return name or metric.__class__.__name__

    def begin_fit(self):
        """Log the header row: epoch, train_* columns, valid_* columns, time."""
        met_names = ['loss'] + [self._metric_name(m) for m in self.train_stats.metrics]
        names = ['epoch'] + [f'train_{n}' for n in met_names] + [
            f'valid_{n}' for n in met_names] + ['time']
        self.logger(names)

    def begin_epoch(self):
        """Clear both accumulators and start the epoch timer."""
        self.train_stats.reset()
        self.valid_stats.reset()
        self.start_time = time.time()

    def after_loss(self):
        """Fold the current batch into the stats of the active phase."""
        stats = self.train_stats if self.in_train else self.valid_stats
        # Accumulation happens with gradient tracking disabled.
        with torch.no_grad(): stats.accumulate(self.learner)

    def after_epoch(self):
        """Log one row: epoch, averaged train stats, averaged valid stats, elapsed time."""
        stats = [str(self.epoch)]
        for o in [self.train_stats, self.valid_stats]:
            stats += [f'{v:.6f}' for v in o.avg_stats]
        stats += [format_time(time.time() - self.start_time)]
        self.logger(stats)


In [11]:
#export
class ProgressCallback(Callback):
    """Shows fastprogress bars: a master bar over epochs and a child bar over
    the batches of the current data loader. Also installs the master bar's
    table writer as the learner's logger."""
    # NOTE(review): presumably the learner sorts callbacks by `_order`, so -1
    # runs this one before the default-0 callbacks — TODO confirm.
    _order = -1

    def begin_fit(self):
        """Create the epoch-level bar and route logging through it."""
        self.mbar = master_bar(range(self.epochs))
        self.mbar.on_iter_begin()
        table_writer = partial(self.mbar.write, table=True)
        self.learner.set_logger(table_writer)

    def after_fit(self):
        """Close the epoch-level bar."""
        self.mbar.on_iter_end()

    def after_batch(self):
        """Move the batch-level bar to the current iteration."""
        self.pb.update(self.iter)

    def begin_epoch(self):
        """Start a fresh batch-level bar for the training pass."""
        self.set_pb()

    def begin_validate(self):
        """Start a fresh batch-level bar for the validation pass."""
        self.set_pb()

    def set_pb(self):
        """Create a child bar over `self.dl` and refresh the master bar."""
        batch_bar = progress_bar(self.dl, parent=self.mbar)
        self.pb = batch_bar
        self.mbar.update(self.epoch)

In [None]:
!python notebook2script.py callbacks.ipynb