In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [2]:
!python export.py -o loop

Exported: /home/ck/code/loop/dev/00a_annotations.ipynb -> loop/annotations.py
Exported: /home/ck/code/loop/dev/00b_config.ipynb -> loop/config.py
Exported: /home/ck/code/loop/dev/00c_utils.ipynb -> loop/utils.py
Exported: /home/ck/code/loop/dev/00d_mixins.ipynb -> loop/mixins.py
Exported: /home/ck/code/loop/dev/01a_callbacks.ipynb -> loop/callbacks.py
Exported: /home/ck/code/loop/dev/01b_modules.ipynb -> loop/modules.py
Exported: /home/ck/code/loop/dev/02a_metrics.ipynb -> loop/metrics.py
Exported: /home/ck/code/loop/dev/02b_phase.ipynb -> loop/phase.py
Exported: /home/ck/code/loop/dev/02c_training.ipynb -> loop/training.py
Exported: /home/ck/code/loop/dev/03a_schedule.ipynb -> loop/schedule.py
Exported: /home/ck/code/loop/dev/03b_early_stopping.ipynb -> loop/early_stopping.py
Exported: /home/ck/code/loop/dev/99_testing.ipynb -> loop/testing.py
12 notebook(s) exported into folder: loop


In [3]:
#export
from pathlib import Path

import torch

from loop.callbacks import Callback, Order
from loop.utils import autoformat

In [4]:
#export
class BestMetric(Callback):
    """A callback that memorizes the best value of metric.
    
    The class is intended to be a base class for other types of metric trackers that
    perform some action when metric stops to improve.
    """

    def __init__(self, phase: str='valid', metric: str='loss', better: 'callable'=min):
        self.phase = phase
        self.metric = metric
        self.better = better
        
    @property
    def formatted_best(self):
        return f'{self.phase}_{self.metric}={autoformat(self.best_value)}'
        
    def training_started(self, **kwargs):
        self.best_value = None
        
    def epoch_started(self, **kwargs):
        self.updated = False
        
    def phase_ended(self, phase, **kwargs):
        ignore = phase.name != self.phase
        if not ignore:
            self.update_best(phase, **kwargs)
        return ignore
        
    def update_best(self, phase, **kwargs):
        breakpoint()
        new_value = phase.get_last_value(self.metric)
        if self.best_value is None:
            self.best_value = new_value
        else:
            self.best_value = self.better(self.best_value, new_value)
        self.updated = self.best_value == new_value
        return self.updated

In [5]:
#export
class EarlyStopping(BestMetric):
    
    order = Order.Tracker(1)
    
    def __init__(self, patience: int=1, **kwargs):
        super().__init__(**kwargs)
        self.patience = patience
    
    def training_started(self, **kwargs):
        super().training_started(**kwargs)
        self.trials = 0
        self.running = True
        
    def phase_ended(self, phase, **kwargs):
        ignore = super().phase_ended(phase=phase, **kwargs)
        if ignore: return
        if self.updated: 
            self.trials = 0
        else:
            self.trials += 1
            if self.trials >= self.patience:
                self.running = False
        
    def epoch_ended(self, phases, epoch, **kwargs):
        super().epoch_ended(phases=phases, epoch=epoch, **kwargs)
        if self.running: return
        from loop.training import TrainingInterrupted
        msg = f'Early stopping at epoch {epoch} with {self.formatted_best}'
        raise TrainingInterrupted(msg)

In [6]:
#export
class ModelSaver(BestMetric):
    
    order = Order.Tracker(2)
    
    def __init__(self, mode: str='every', root: Path=Path.cwd(), **kwargs):
        super().__init__(**kwargs)
        assert mode in {'every', 'best'}
        self.root = Path(root)
        self.mode = mode
        
    def training_started(self, **kwargs):
        super().training_started(**kwargs)
        if not self.root.exists():
            self.root.mkdir(parents=True)
        self.last_saved = None
        
    def epoch_ended(self, phases, epoch, **kwargs):
        super().epoch_ended(phases=phases, epoch=epoch, **kwargs)
        fname = f'model__{self.formatted_best}__epoch={epoch}.pth'
        if self.mode == 'every' or self.updated:
            path = self.root/fname
            torch.save(self.group.model, path)
            self.last_saved = path
            
    def load_last_saved_state(self, model=None):
        if self.last_saved is None:
            raise ValueError('nothing was saved during training')
        model = model or self.group.model
        if model is None:
            raise ValueError('no model provided to restore the saved state')
        model.load_state_dict(torch.load(self.last_saved))

In [7]:
from loop.callbacks import Average
from loop.testing import train_classifier_with_callbacks
from loop.modules import TinyNet
from loop.metrics import accuracy

cbs = [Average(accuracy, alias='acc'), 
       EarlyStopping(metric='acc', patience=1),
       ModelSaver(mode='best', metric='acc', root=Path.home()/'models')]

loop = train_classifier_with_callbacks(TinyNet(1), cbs=cbs, n=10000, bs=100)

loop.cb['history'].plot()
cbs['early_stopping'].load_last_saved_state()

> <ipython-input-4-28fce76911c9>(32)update_best()
-> new_value = phase.get_last_value(self.metric)
(Pdb) self.metric
'acc'
(Pdb) u
> <ipython-input-4-28fce76911c9>(27)phase_ended()
-> self.update_best(phase, **kwargs)
(Pdb) u
> <ipython-input-5-ae4c545295dd>(16)phase_ended()
-> ignore = super().phase_ended(phase=phase, **kwargs)
(Pdb) u
> /home/ck/code/loop/dev/loop/callbacks.py(229)__call__()
-> method(**kwargs)
(Pdb) u
> /home/ck/code/loop/dev/loop/callbacks.py(200)phase_ended()
-> def phase_ended(self, **kwargs): self('phase_ended', **kwargs)
(Pdb) u
> /home/ck/code/loop/dev/loop/training.py(94)train_one_epoch()
-> cb.phase_ended(phase=phase)
(Pdb) ll
 63  	    def train_one_epoch(self, phases: list, curr_epoch: int=1):
 64  	        cb, model, opt = self.cb, self.model, self.opt
 65  	
 66  	        cb.epoch_started(epoch=curr_epoch)
 67  	
 68  	        for phase in phases:
 69  	            n = len(phase.loader)
 70  	            cb.phase_started(phase=phase, total_batches=n)
 71

BdbQuit: 

In [8]:
# Fix EarlyStopping: should check metrics on epoch end, not phase end