In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import sys
#sys.path.append('shallow')

# Imports

In [None]:
#export
from functools import partial

from shallow import nb_utils

#from nb_utils import GetAttr

# Code

In [None]:
#export 
class Callback(nb_utils.GetAttr):
    """Base class for all callbacks in this module.

    Only sets `_default` to 'learner' for `nb_utils.GetAttr`.
    NOTE(review): presumably GetAttr forwards unknown attribute lookups to
    `self.learner` (subclasses read `self.model`, `self.batch`, ... directly) —
    confirm against nb_utils.
    """
    _default = 'learner'

class ParamScheduler(Callback):
    """Callback that writes one learner attribute from a schedule function.

    phase:      name of the callback event (e.g. 'before_epoch') under which
                `set_param` is made available on this instance.
    pname:      name of the attribute to set on the learner.
    sched_func: callable that receives `n_epochs/epochs` and returns the
                value to store.
    """
    def __init__(self, phase, pname, sched_func):
        self.pname = pname
        self.sched_func = sched_func
        # Expose the bound `set_param` under the event name, so looking up
        # `phase` on this callback yields the scheduling step.
        setattr(self, phase, self.set_param)

    def set_param(self):
        """Evaluate the schedule at `n_epochs/epochs` and store it on the learner."""
        position = self.n_epochs / self.epochs
        setattr(self.learner, self.pname, self.sched_func(position))
    
class SetupLearnerCB(Callback):
    """Callback that prepares the model before fitting and each batch before use.

    NOTE(review): `to_device` and `tfm_x` are not defined in the visible part
    of this notebook — they must be in scope wherever this class is used.
    """
    def before_fit(self):
        """Call `.cuda()` on the model before fitting starts."""
        self.model.cuda()

    def before_batch(self):
        """Pass the batch through `to_device`, transform the inputs with `tfm_x`,
        and store the resulting (inputs, targets) pair on the learner."""
        inputs, targets = to_device(self.batch)
        self.learner.batch = (tfm_x(inputs), targets)

class TrackResults(Callback):
    """Callback that accumulates per-batch loss and accuracy and prints the
    sample-weighted averages at the end of every epoch."""

    def before_epoch(self):
        """Reset the per-epoch accumulators (correct counts, weighted losses, batch sizes)."""
        self.accs, self.losses, self.ns = [], [], []

    def after_epoch(self):
        """Print the epoch, the model's `training` flag, mean loss and mean accuracy."""
        n = sum(self.ns)  # total number of samples seen this epoch
        print(self.epoch, self.model.training,
              sum(self.losses).item()/n, sum(self.accs).item()/n)

    def after_batch(self):
        """Record this batch's correct-prediction count, size-weighted loss and size."""
        xb, yb = self.batch
        n = len(xb)
        # Number of rows whose argmax over dim 1 equals the target.
        acc = (self.preds.argmax(dim=1)==yb).float().sum()
        self.accs.append(acc)
        # Detach before storing: the loss presumably still references its
        # autograd graph, and keeping it in a list for the whole epoch would
        # hold that history alive for every batch. The value is unchanged.
        self.losses.append(self.loss.detach()*n)
        self.ns.append(n)

class LRFinder(Callback):
    """Callback that sweeps the learning rate upward during training and
    records the (lr, loss) pair for each training batch.

    NOTE(review): `torch` and `CancelFitException` are not imported/defined in
    the visible export cell — confirm they are in scope where this is used.
    """
    def before_fit(self):
        """Clear the recorded curves and set the learner's lr to 1e-6."""
        self.losses, self.lrs = [], []
        self.learner.lr = 1e-6

    def before_batch(self):
        """Multiply the optimizer's lr by 1.2 before every training batch."""
        if self.model.training:
            self.opt.lr *= 1.2

    def after_batch(self):
        """Record lr and loss for a training batch; cancel the fit once the
        lr exceeds 10 or the loss is NaN."""
        if not self.model.training:
            return
        diverged = self.opt.lr > 10 or torch.isnan(self.loss)
        if diverged:
            raise CancelFitException
        self.losses.append(self.loss.item())
        self.lrs.append(self.opt.lr)

In [None]:
def append_stats(hook, mod, inp, outp, bins=100, vmin=0, vmax=0):
    """Record mean, std and a histogram of `outp` on `hook.stats`.

    hook:  object used as storage; gets a `stats` tuple of three lists
           (means, stds, histograms) on first use.
    mod, inp: unused here; part of the hook call signature.
    outp:  output whose `.data` is summarised.
    bins, vmin, vmax: forwarded positionally to `histc`
           (per the torch.histc docs, min == max == 0 means the data's own
           range is used).
    """
    if not hasattr(hook, 'stats'):
        hook.stats = ([], [], [])
    activations = outp.data
    mean_log, std_log, hist_log = hook.stats
    mean_log.append(activations.mean().cpu())
    std_log.append(activations.std().cpu())
    hist_log.append(activations.cpu().histc(bins, vmin, vmax))
    
class Hook():
    """Registers `f` as a forward hook on module `m` and keeps the handle."""

    def __init__(self, m, f):
        # Bind this Hook as f's first argument so f can store state on it;
        # the remaining arguments are whatever the module passes to the hook.
        bound = partial(f, self)
        self.hook = m.register_forward_hook(bound)

    def remove(self):
        """Remove the registered hook via its handle."""
        self.hook.remove()

    def __del__(self):
        # Unregister when this wrapper is garbage-collected.
        self.remove()
        
class Hooks(nb_utils.ListContainer):
    """Container of `Hook` objects, one per module in `ms`, all using `f`.

    Usable as a context manager: leaving the `with` block removes every hook.
    """
    def __init__(self, ms, f):
        hooks = [Hook(module, f) for module in ms]
        super().__init__(hooks)

    def __enter__(self, *args):
        return self

    def __exit__(self, *args):
        self.remove()

    def __del__(self):
        # Remove all hooks when the container is garbage-collected.
        self.remove()

    def __delitem__(self, i):
        """Remove the hook at `i` before dropping it from the container."""
        self[i].remove()
        super().__delitem__(i)

    def remove(self):
        """Remove every hook held by this container."""
        for hook in self:
            hook.remove()

# Tests

### Param sched

In [None]:
import torch
import matplotlib.pyplot as plt

from shallow import nb_schedulers

In [None]:
class Learner:
    """Minimal stand-in learner used to exercise callbacks in this notebook.

    cbs: list of callback objects; each gets its `learner` attribute set to
         this instance.

    Fixed attributes `lr`, `epochs` and `n_epochs` give callbacks something
    to read and write.
    """
    def __init__(self, cbs):
        self.cbs = cbs
        self.lr = -1
        self.epochs = 100
        self.n_epochs = 13
        for c in self.cbs: c.learner = self

    def t(self):
        """Fire the 'before_epoch' event."""
        self('before_epoch')

    def tt(self):
        """Fire an event name that no callback is expected to define."""
        self('bla_bla')

    def __call__(self, name):
        """Invoke the hook called `name` on every callback that has one.

        Callbacks without such an attribute are skipped. The fallback is not
        built by calling a helper, so a missing hook can never end up as
        `None()` and raise TypeError.
        """
        for cb in self.cbs:
            handler = getattr(cb, name, None)
            if handler is not None: handler()

In [None]:

# Scheduler exposed under the 'before_epoch' event; when fired it sets the
# learner's `lr` from the schedule returned by nb_schedulers.sched_lin(0, .1).
p = ParamScheduler('before_epoch', 'lr', nb_schedulers.sched_lin(0,.1))

In [None]:
# Attach the scheduler to a stub learner and fire 'before_epoch' once.
l = Learner([p])
l.n_epochs = 88 # out of 100
l.t()
# Show the lr written by the scheduler: the schedule evaluated at 88/100.
# NOTE(review): presumably 0.088 if sched_lin interpolates linearly from 0 to .1 — confirm.
l.lr

### Hooks