In [1]:
from fastai.torch_core import *
from fastai.data import DataBunch
from fastai.callback import *
from fastai.basic_train import Learner, LearnerCallback

In [2]:
from fastai import * 
from fastai.docs import * 
from fastai.vision import * 

In [3]:
# __all__ = ['TerminateOnNaN', 'EarlyStopping', 'SaveModel']

In [4]:
untar_data(MNIST_PATH)
MNIST_PATH

PosixPath('../data/mnist_sample')

In [5]:
data = image_data_from_folder(MNIST_PATH, ds_tfms=(rand_pad(2, 28), []))

___

The callbacks bellow are based on Keras Callbacks of same name:
https://github.com/keras-team/keras/blob/master/keras/callbacks.py

___

### The Problem: Loss is Nan

In [6]:
learn = ConvLearner(data, tvm.resnet18, metrics=[accuracy])
learn.fit(2,1e4)

VBox(children=(HBox(children=(IntProgress(value=0, max=2), HTML(value='0.00% [0/2 00:00<00:00]'))), HTML(value…

Total time: 00:08
epoch  train loss  valid loss  accuracy
0      nan         nan         0.495584  (00:04)
1      nan         nan         0.495584  (00:03)



### The Solution

The Callback bellow is *very* influenced by Keras Callback of same name.

In [7]:
class TerminateOnNaN(LearnerCallback):
    "A `LearnerCallback` that terminates training if loss is NaN."
    
    def __init__(self):
        super().__init__(learn)
        self.stop = False
    
    def on_batch_end(self,epoch:int, num_batch:int,  **kwargs:Any)->None:
        loss = learn.recorder.losses[-1:][0]
        if loss is not None:
            if torch.isnan(loss):
                print (f'Epoch/Batch ({epoch}/{num_batch}): Invalid loss, terminating training.')
                self.stop = True
                return True
    
    def on_epoch_end(self, metrics, **kwargs:Any)->None:
        return self.stop
    

In [8]:
learn = ConvLearner(data, tvm.resnet18, metrics=[accuracy])

In [9]:
learn.fit(2,1e4, callbacks=[TerminateOnNaN()])

VBox(children=(HBox(children=(IntProgress(value=0, max=2), HTML(value='0.00% [0/2 00:00<00:00]'))), HTML(value…

Epoch/Batch (0/6): Invalid loss, terminating training.
Epoch/Batch (0/7): Invalid loss, terminating training.


Actually, I don't know why it went to the second batch if I set return `True` in the first.

### The Problem: Metric does not improve

In [10]:
learn = ConvLearner(data, tvm.resnet18, metrics=[accuracy])

In [11]:
learn.fit(3,1e-42)

VBox(children=(HBox(children=(IntProgress(value=0, max=3), HTML(value='0.00% [0/3 00:00<00:00]'))), HTML(value…

Total time: 00:12
epoch  train loss  valid loss  accuracy
0      1.036479    0.927704    0.396467  (00:04)
1      1.021361    0.945106    0.377821  (00:03)
2      1.049802    0.957063    0.383710  (00:03)



The Callback bellow is basically a simplified port of Keras Early Stopping callback to fastai/pytorch.

In [12]:
class EarlyStopping(LearnerCallback):
    "A `LearnerCallback` that terminates training when monitored quantity stops improving."
    def __init__(self, 
                 monitor='val_loss',
                 min_delta=0,
                 patience=0,
                 mode='auto'):
        
        super().__init__(learn)
        
        self.monitor = monitor
        self.min_delta = min_delta
        self.patience = patience
        self.mode = mode
        if mode not in ['auto', 'min', 'max']:
            #should I use warning?
            print(f'TerminateEarly mode {mode} is invalid, falling back to "auto" mode.')
            mode = 'auto'
        if mode == 'min':
            self.operator = np.less
        elif mode == 'max':
            self.operator = np.greater
        else:
            if 'loss' in self.monitor:
                self.operator = np.less
            else:
                self.operator = np.greater
        if self.operator == np.less:
            self.min_delta *= -1
    
    def on_train_begin(self, **kwargs:Any)->None:
        self.wait = 0
        self.stopped_epoch = 0
        self.best = float('inf') if self.operator == np.less else -float('inf')

    def on_epoch_end(self, epoch, **kwargs:Any)->None:
        current = self.get_monitor_value()
        if current is None:
            return
         
        if self.operator(current - self.min_delta, self.best):
            
            self.best = current
            self.wait = 0
        else:
            
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                return True
        return False
                
    def on_train_end(self, **kwargs:Any)->None:
        if self.stopped_epoch > 0:
            print(f'Epoch {self.stopped_epoch}: early stopping')
                                                                          
    def get_monitor_value(self):
        values = {'val_loss':learn.recorder.val_losses[-1:][0]
               ,'trn_loss':learn.recorder.losses[-1:][0].cpu().numpy()}
        for i, name in enumerate(learn.recorder.names[3:]):
            values[name]=learn.recorder.metrics[-1:][0][i] 

        if values.get(self.monitor) is None:
            print(f'Early stopping conditioned on metric `{self.monitor}` which is not available. Available metrics are: {", ".join(map(str, learn.recorder.names[3:]))}')   
        return values.get(self.monitor)
    

In [13]:
learn = ConvLearner(data, tvm.resnet18, metrics=[accuracy], callbacks=[EarlyStopping(monitor='accuracy', min_delta=0.01, patience=3)] )

In [14]:
learn.fit(50,1e-42)

VBox(children=(HBox(children=(IntProgress(value=0, max=50), HTML(value='0.00% [0/50 00:00<00:00]'))), HTML(val…

Epoch 4: early stopping


### The Problem: best result is not in the last epoch

In [15]:
learn = ConvLearner(data, tvm.resnet18, metrics=[accuracy] )

In [16]:
learn.fit(5,1e-42)

VBox(children=(HBox(children=(IntProgress(value=0, max=5), HTML(value='0.00% [0/5 00:00<00:00]'))), HTML(value…

Total time: 00:19
epoch  train loss  valid loss  accuracy
0      0.849374    0.728057    0.555937  (00:04)
1      0.844745    0.706033    0.575564  (00:03)
2      0.843050    0.715870    0.562316  (00:03)
3      0.861841    0.722247    0.558881  (00:04)
4      0.853869    0.718867    0.555937  (00:03)



Best epoch is #1. But model is in #4.

In [17]:
class SaveModel(LearnerCallback):
    "A `LearnerCallback` that terminates training when monitored quantity stops improving."
    def __init__(self,
                 monitor='val_loss',
                 every = 'improvement',
                 mode='auto'):
        
        super().__init__(learn)
        
        self.monitor = monitor
        self.every = every
        self.mode = mode
        if every not in ['improvement', 'epoch']:
            #should I use warning?
            print(f'SaveModel every {every} is invalid, falling back to "improvement".')
            every = 'improvement'
        
        if mode not in ['auto', 'min', 'max']:
            #should I use warning?
            print(f'SaveModel mode {mode} is invalid, falling back to "auto" mode.')
            mode = 'auto'
        if mode == 'min':
            self.operator = np.less
        elif mode == 'max':
            self.operator = np.greater
        else:
            if 'loss' in self.monitor:
                self.operator = np.less
            else:
                self.operator = np.greater
   
    def on_train_begin(self, **kwargs:Any)->None:
        self.best = float('inf') if self.operator == np.less else -float('inf')
    
    def on_epoch_end(self, epoch, **kwargs:Any)->None:
        if self.every=="epoch":
            filename = f'model__epoch{epoch}'
            learn.save(filename)
            return 
        else: #every="improvement"
            current = self.get_monitor_value()
            if current is None:
                return
            if self.operator(current, self.best):
                self.best = current
                filename = f'bestmodel_epoch{epoch}'
                learn.save(filename)
            return 
                                                                          
    def get_monitor_value(self):
        values = {'val_loss':learn.recorder.val_losses[-1:][0]
               ,'trn_loss':learn.recorder.losses[-1:][0].cpu().numpy()}
        for i, name in enumerate(learn.recorder.names[3:]):
            values[name]=learn.recorder.metrics[-1:][0][i] 

        if values.get(self.monitor) is None:
            print(f'Early stopping conditioned on metric `{self.monitor}` which is not available. Available metrics are: {", ".join(map(str, learn.recorder.names[3:]))}')   
        return values.get(self.monitor)
    

In [18]:
learn = ConvLearner(data, tvm.resnet18, metrics=[accuracy], callbacks=[SaveModel(every='epoch')] )

In [19]:
learn.fit(5,1e-42)

VBox(children=(HBox(children=(IntProgress(value=0, max=5), HTML(value='0.00% [0/5 00:00<00:00]'))), HTML(value…

Total time: 00:20
epoch  train loss  valid loss  accuracy
0      0.923145    0.757920    0.546124  (00:03)
1      0.904519    0.760853    0.550540  (00:04)
2      0.914042    0.753116    0.555447  (00:04)
3      0.913810    0.758505    0.547596  (00:04)
4      0.908953    0.764323    0.547105  (00:04)



In [20]:
!ls ../data/mnist_sample/models/

model__epoch0.pth  model__epoch2.pth  model__epoch4.pth
model__epoch1.pth  model__epoch3.pth


In [21]:
learn = ConvLearner(data, tvm.resnet18, metrics=[accuracy], callbacks=[SaveModel(every='improvement')] )

In [22]:
learn.fit(5,1e-2)

VBox(children=(HBox(children=(IntProgress(value=0, max=5), HTML(value='0.00% [0/5 00:00<00:00]'))), HTML(value…

Total time: 00:20
epoch  train loss  valid loss  accuracy
0      0.050595    0.040298    0.988714  (00:04)
1      0.026141    0.011412    0.995093  (00:04)
2      0.019063    0.012740    0.995093  (00:04)
3      0.030391    0.009357    0.998037  (00:04)
4      0.027890    0.007864    0.998528  (00:04)



In [23]:
!ls ../data/mnist_sample/models/

bestmodel_epoch0.pth  bestmodel_epoch4.pth  model__epoch2.pth
bestmodel_epoch1.pth  model__epoch0.pth     model__epoch3.pth
bestmodel_epoch3.pth  model__epoch1.pth     model__epoch4.pth
