In [17]:
#| eval: false
! [ -e /content ] && pip install -Uqq fastai # upgrade fastai on colab

In [18]:
#| default_exp l2r.callbacks

In [19]:
from fastai.torch_imports import *
from fastai.torch_core import *
from fastai.callback.core import *
from fastcore.all import *
from xcube.imports import *
from xcube.metrics import *

In [20]:
from nbdev.showdoc import *

In [21]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [22]:
class TrainEval(Callback):
    "Put the model in training mode for the train phase and in eval mode for the validate phase."
    # Low `order` so this runs ahead of the higher-order callbacks that inspect `model.training`.
    order = -10

    def before_validate(self): self.model.eval()

    def before_train(self): self.model.train()

In [23]:
class TrackResults(Callback):
    """Record per-batch loss, NDCG, NDCG@6 and accuracy, then print and store per-epoch means.

    Args:
        train_metrics: if True, NDCG / NDCG@6 / accuracy are also computed on training
            batches and the training epoch means are kept in `metrics_full['trn']`.
            Validation batches always get these metrics.

    Attributes filled during `fit`:
        losses_full: every per-batch training loss (each already averaged over the batch).
        grads_full: parameter name -> list of gradient copies, one per backward pass.
        metrics_full: 'trn' / 'val' -> list of per-epoch rows ordered like `names`.
    """
    def __init__(self, train_metrics=False):
        store_attr()
        # Column order of every logged row; `Monitor` looks a metric up by its index in this list.
        self.names = ['loss', 'ndcg', 'ndcg_at_6', 'acc']

    def before_fit(self):
        self.losses_full, self.grads_full, self.metrics_full = [], defaultdict(list), defaultdict(list)

    def before_train(self): self._initialise_metrics()

    def before_validate(self): self._initialise_metrics()

    def after_train(self):
        # Collect the raw per-batch losses before `_compute_epoch_mean` replaces the list with a tensor.
        self.losses_full.extend(self.losses)
        log = self._compute_epoch_mean()
        if self.train_metrics:
            self.metrics_full['trn'].append(log)
        print(self.epoch, self.model.training, *log)

    def after_validate(self):
        log = self._compute_epoch_mean()
        # `metrics_full` only exists once `before_fit` ran; the guards presumably allow validating outside `fit`.
        if hasattr(self, 'metrics_full'):
            self.metrics_full['val'].append(log)
        print(self.epoch if hasattr(self, 'epoch') else 0, self.model.training, *log)

    def _compute_epoch_mean(self):
        """Stack each per-batch list into a tensor (stored back on self) and return the rounded means.

        A metric with no recorded batches is reported as the string "NA".
        """
        _li = [self.losses, self.ndcgs, self.ndcgs_at_6, self.accs]
        _li = [torch.stack(o) if o else torch.Tensor() for o in _li]
        self.losses, self.ndcgs, self.ndcgs_at_6, self.accs = _li
        # Test emptiness with `numel()`: testing `o.sum()` also turned a genuine all-zero
        # metric (e.g. 0 accuracy) into "NA", which downstream comparisons cannot handle.
        return [round(o.mean().item(), 4) if o.numel() else "NA" for o in _li]

    def _initialise_metrics(self): self.losses, self.ndcgs, self.ndcgs_at_6, self.accs = [], [], [], []

    def after_batch(self):
        "Log the batch loss, plus ranking metrics on validation (and on training when `train_metrics`)."
        with torch.no_grad():
            # NOTE(review): the loss is evaluated on (preds, xb); presumably the relevance labels live in xb.
            loss = self.loss_func(self.preds, self.xb)
            self.losses.append(loss.mean())
            if self.train_metrics or not self.model.training: self._compute_metrics()

    def _compute_metrics(self):
        "Append this batch's mean NDCG, NDCG@6 and accuracy."
        # `ndcg` returns a sequence whose last two items are the NDCG and NDCG@k values.
        *_, _ndcg, _ndcg_at_k = ndcg(self.preds, self.xb, k=6)
        self.ndcgs.append(_ndcg.mean())
        self.ndcgs_at_6.append(_ndcg_at_k.mean())
        self.accs.append(accuracy(self.xb, self.model).mean())

    def after_backward(self):
        "Store a copy of every parameter's gradient after each backward pass."
        # NOTE: one full gradient copy per parameter per batch is kept, so memory grows with the number of batches.
        for name, param in self.model.named_parameters():
            # Frozen parameters, or ones not involved in the loss, have `grad is None`; skip instead of crashing.
            if param.grad is None: continue
            self.grads_full[name].append(param.grad.detach().clone())

In [24]:
class ProgressBarCallback(Callback):
    "Show a master bar over the epochs and a child bar over the batches of each train/validate phase."
    order = 70

    def before_fit(self):
        epochs = range(self.n_epochs)
        self.mbar = master_bar(epochs)

    def before_epoch(self):
        mbar = getattr(self, 'mbar', False)
        if mbar: mbar.update(self.epoch)

    def _launch_pbar(self):
        "Open a fresh batch-level bar over the current dataloader, nested under the master bar if any."
        parent = getattr(self, 'mbar', None)
        self.pbar = progress_bar(self.dl, parent=parent, leave=False)
        self.pbar.update(0)

    def before_train(self): self._launch_pbar()

    def before_validate(self): self._launch_pbar()

    def after_batch(self):
        done = self.iter_num + 1
        self.pbar.update(done)

    def _close_pbar(self):
        "Finish the batch-level bar of the phase that just ended."
        self.pbar.on_iter_end()

    def after_train(self): self._close_pbar()

    def after_validate(self): self._close_pbar()

    def after_fit(self):
        # Nothing to close when no master bar was created.
        if not getattr(self, 'mbar', False): return
        self.mbar.on_iter_end()
        delattr(self, 'mbar')

In [25]:
class Monitor(Callback):
    """Track the best validation value of one metric logged by `TrackResults`.

    After every epoch `new_best` is set to True when the latest value beats `best`.

    Args:
        monitor: name of the metric to watch; must appear in `TrackResults.names`.
        comp: comparison deciding "better"; defaults to `np.greater` (higher is better).
        min_delta: amount by which a value must beat `best` to count as an improvement.
        reset_on_fit: if True, forget the previous best at the start of every fit.
    """
    order = 60
    def __init__(self, monitor='ndcg_at_6', comp=None, min_delta=0., reset_on_fit=False):
        if comp is None: comp = np.greater
        # Flip the sign so `comp(val - min_delta, best)` requires beating `best` by more than |min_delta| either way.
        if comp == np.less: min_delta *= -1
        self.monitor,self.comp,self.min_delta,self.reset_on_fit,self.best = monitor,comp,min_delta,reset_on_fit,None

    def before_fit(self):
        if self.reset_on_fit or self.best is None: self.best = float('inf') if self.comp == np.less else -float('inf')
        # `track_results` is presumably the TrackResults callback reached through the learner — confirm registration.
        assert self.monitor in self.track_results.names
        self.idx = list(self.track_results.names).index(self.monitor)

    def after_epoch(self):
        "Compare the latest validation value of the monitored metric with `best` and set `new_best`."
        val = self.track_results.metrics_full.get('val')[-1][self.idx]
        # TrackResults logs the string "NA" for a metric it could not compute; that is never a new best
        # (subtracting `min_delta` from it would raise TypeError).
        if isinstance(val, str):
            self.new_best = False
            return
        if self.comp(val - self.min_delta, self.best): self.best, self.new_best = val, True
        else: self.new_best = False

In [26]:
class SaveCallBack(Monitor):
    """Save the learner via `learn.save(fname)` whenever the monitored metric reaches a new best.

    Args:
        fname: name passed to `learn.save`.
        monitor, comp, min_delta, reset_on_fit: forwarded to `Monitor`.

    Attributes:
        last_saved_path: value returned by the most recent `learn.save` call (None before any save).
    """
    order = Monitor.order+1
    def __init__(self,
        fname,
        monitor='ndcg_at_6',
        comp=None,
        min_delta=0.,
        reset_on_fit=False,
    ):
        super().__init__(monitor=monitor, comp=comp, min_delta=min_delta, reset_on_fit=reset_on_fit)
        self.last_saved_path = None
        store_attr('fname')

    @property
    def best(self): return self._best
    @best.setter
    def best(self, b): self._best = b

    def after_epoch(self):
        super().after_epoch()
        if self.new_best:
            print(f'Better model found at epoch {self.epoch} with {self.monitor} value: {self.best}.')
            # Record what `save` returns (fastai's `Learner.save` returns the file path — confirm for this learner),
            # so `last_saved_path` is no longer permanently None.
            self.last_saved_path = self.learn.save(self.fname)

    # NOTE(review): restoring the best weights at the end of fit (e.g. `self.learn.load(self.fname)`
    # behind a `best_at_end` option) is not implemented.

## Export

In [27]:
import nbdev; nbdev.nbdev_export()