In [9]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [10]:
#export
from pathlib import Path

import torch

from loop.callbacks import Callback, Order
from loop.utils import autoformat

In [11]:
#export
class BestMetric(Callback):
    """A callback that memorizes the best value of metric.
    
    The class is intended to be a base class for other types of metric trackers that
    perform some action when metric stops to improve.
    """

    def __init__(self, phase: str='valid', metric: str='loss', better: 'callable'=min):
        self.phase = phase
        self.metric = metric
        self.better = better
        
    @property
    def formatted_best(self):
        return f'{self.phase}_{self.metric}={autoformat(self.best_value)}'
        
    def training_started(self, **kwargs):
        self.best_value = None
        
    def epoch_started(self, **kwargs):
        self.updated = False
        
    def phase_ended(self, phase, **kwargs):
        ignore = phase.name != self.phase
        if not ignore:
            self.update_best(phase, **kwargs)
        return ignore
        
    def update_best(self, phase, **kwargs):
        new_value = phase.get_last_value(self.metric)
        if self.best_value is None:
            self.best_value = new_value
        else:
            self.best_value = self.better(self.best_value, new_value)
        self.updated = self.best_value == new_value
        return self.updated

In [12]:
#export
class EarlyStopping(BestMetric):
    
    order = Order.Tracker(1)
    
    def __init__(self, patience: int=1, **kwargs):
        super().__init__(**kwargs)
        self.patience = patience
    
    def training_started(self, **kwargs):
        super().training_started(**kwargs)
        self.trials = 0
        self.running = True
        
    def phase_ended(self, phase, **kwargs):
        ignore = super().phase_ended(phase=phase, **kwargs)
        if ignore: return
        if self.updated: 
            self.trials = 0
        else:
            self.trials += 1
            if self.trials >= self.patience:
                breakpoint()
                self.running = False
        
    def epoch_ended(self, phases, epoch, **kwargs):
        super().epoch_ended(phases=phases, epoch=epoch, **kwargs)
        if self.running: return
        from loop.training import TrainingInterrupted
        msg = f'Early stopping at epoch {epoch} with {self.formatted_best}'
        raise TrainingInterrupted(msg)

In [13]:
#export
class ModelSaver(BestMetric):
    
    order = Order.Tracker(2)
    
    def __init__(self, mode: str='every', root: Path=Path.cwd(), **kwargs):
        super().__init__(**kwargs)
        assert mode in {'every', 'best'}
        self.root = Path(root)
        self.mode = mode
        
    def training_started(self, **kwargs):
        super().training_started(**kwargs)
        if not self.root.exists():
            self.root.mkdir(parents=True)
        self.last_saved = None
        
    def epoch_ended(self, phases, epoch, **kwargs):
        super().epoch_ended(phases=phases, epoch=epoch, **kwargs)
        fname = f'model__{self.formatted_best}__epoch={epoch}.pth'
        if self.mode == 'every' or self.updated:
            path = self.root/fname
            torch.save(self.group.model, path)
            self.last_saved = path
            
    def load_last_saved_state(self, model=None):
        if self.last_saved is None:
            raise ValueError('nothing was saved during training')
        model = model or self.group.model
        if model is None:
            raise ValueError('no model provided to restore the saved state')
        model.load_state_dict(torch.load(self.last_saved))

In [20]:
from loop.testing import train_classifier_with_callbacks
from loop.modules import TinyNet

cbs = [C.Average(accuracy, alias='acc'), 
       EarlyStopping(metric='acc', patience=1),
       ModelSaver(mode='best', metric='acc', root=Path.home()/'models')]

train_classifier_with_callbacks(TinyNet(1), cbs=cbs, n=10000, bs=100)

loop.cb['history'].plot()
cbs['early_stopping'].load_last_saved_state()

Exception ignored in: <function _DataLoaderIter.__del__ at 0x7eff21d1d950>
Traceback (most recent call last):
  File "/home/ck/anaconda3/envs/fastai_10/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 717, in __del__
    self._shutdown_workers()
  File "/home/ck/anaconda3/envs/fastai_10/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 713, in _shutdown_workers
    w.join()
  File "/home/ck/anaconda3/envs/fastai_10/lib/python3.7/multiprocessing/process.py", line 138, in join
    assert self._parent_pid == os.getpid(), 'can only join a child process'
AssertionError: can only join a child process
Exception ignored in: <function _DataLoaderIter.__del__ at 0x7eff21d1d950>
Traceback (most recent call last):
  File "/home/ck/anaconda3/envs/fastai_10/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 717, in __del__
    self._shutdown_workers()
  File "/home/ck/anaconda3/envs/fastai_10/lib/python3.7/site-packages/torch/utils/data/dataloader.py",

KeyError: 'valid_acc'

In [8]:
!python export.py -o loop

Exported: /home/ck/code/loop/dev/00a_annotations.ipynb -> loop/annotations.py
Exported: /home/ck/code/loop/dev/00b_config.ipynb -> loop/config.py
Exported: /home/ck/code/loop/dev/00c_utils.ipynb -> loop/utils.py
Exported: /home/ck/code/loop/dev/01a_callbacks.ipynb -> loop/callbacks.py
Exported: /home/ck/code/loop/dev/01b_modules.ipynb -> loop/modules.py
Exported: /home/ck/code/loop/dev/02a_metrics.ipynb -> loop/metrics.py
Exported: /home/ck/code/loop/dev/02b_phase.ipynb -> loop/phase.py
Exported: /home/ck/code/loop/dev/02c_training.ipynb -> loop/training.py
Exported: /home/ck/code/loop/dev/03a_schedule.ipynb -> loop/schedule.py
Exported: /home/ck/code/loop/dev/03b_early_stopping.ipynb -> loop/early_stopping.py
Exported: /home/ck/code/loop/dev/99_testing.ipynb -> loop/testing.py
11 notebook(s) exported into folder: loop
