In [2]:
from fastai.torch_core import *
from fastai.data import DataBunch
from fastai.callback import *
from fastai.basic_train import Learner, LearnerCallback

In [28]:
from fastai import * 
from fastai.docs import * 
from fastai.vision import * 

In [27]:
# __all__ = ['TerminateOnNaN', 'EarlyStopping']

In [7]:
untar_data(MNIST_PATH)
MNIST_PATH

HBox(children=(IntProgress(value=0, max=3214948), HTML(value='')))

PosixPath('../data/mnist_sample')

In [62]:
data = image_data_from_folder(MNIST_PATH, ds_tfms=(rand_pad(2, 28), []))

### The Problem: Loss is Nan

In [88]:
learn = ConvLearner(data, tvm.resnet18, metrics=[accuracy])
learn.fit(2,1e4)

VBox(children=(HBox(children=(IntProgress(value=0, max=2), HTML(value='0.00% [0/2 00:00<00:00]'))), HTML(value…

Total time: 00:07
epoch  train loss  valid loss  accuracy
0      nan         nan         0.495584  (00:03)
1      nan         nan         0.495584  (00:03)



### The Solution

The Callback bellow is influenced by Keras Callback of same name.

In [163]:
class TerminateOnNaN(LearnerCallback):
    "A `LearnerCallback` that terminates training if loss is NaN."
    
    def __init__(self):
        super().__init__(learn)
        self.stop = False
    
    def on_batch_end(self,epoch:int, num_batch:int,  **kwargs:Any)->None:
        loss = learn.recorder.losses[-1:][0]
        if loss is not None:
            if torch.isnan(loss):
                print (f'Epoch/Batch ({epoch}/{num_batch}): Invalid loss, terminating training.')
                self.stop = True
                return True
    
    def on_epoch_end(self, metrics, **kwargs:Any)->None:
        return self.stop
    

In [164]:
learn = ConvLearner(data, tvm.resnet18, metrics=[accuracy])

In [165]:
learn.fit(2,1e4, callbacks=[TerminateOnNaN()])

VBox(children=(HBox(children=(IntProgress(value=0, max=2), HTML(value='0.00% [0/2 00:00<00:00]'))), HTML(value…

Epoch/Batch (0/6): Invalid loss, terminating training.
Epoch/Batch (0/7): Invalid loss, terminating training.


Actually, I don't know why it went to the second batch if I set return `True` in the first.

### The Problem: Metric does not improve

In [122]:
learn = ConvLearner(data, tvm.resnet18, metrics=[accuracy, accuracy])

In [123]:
learn.fit(3,1e-42)

VBox(children=(HBox(children=(IntProgress(value=0, max=10), HTML(value='0.00% [0/10 00:00<00:00]'))), HTML(val…

Total time: 00:42
epoch  train loss  valid loss  accuracy  accuracy
0      0.921106    0.787993    0.516683  0.516683  (00:04)
1      0.912194    0.763384    0.539745  0.539745  (00:03)
2      0.927705    0.771625    0.530422  0.530422  (00:04)
3      0.952577    0.780007    0.529931  0.529931  (00:04)
4      0.914487    0.792268    0.523552  0.523552  (00:04)
5      0.920072    0.779946    0.533857  0.533857  (00:04)
6      0.922355    0.779907    0.523552  0.523552  (00:04)
7      0.935991    0.794458    0.513739  0.513739  (00:04)
8      0.925322    0.776205    0.535819  0.535819  (00:04)
9      0.925311    0.784154    0.521099  0.521099  (00:04)



The Callback bellow is basically a simplified port of Keras Early Stopping callback to fastai/pytorch.

In [310]:
class EarlyStopping(LearnerCallback):
    "A `LearnerCallback` that terminates training when monitored quantity stops improving."
    def __init__(self, 
                 monitor='val_loss',
                 min_delta=0,
                 patience=0,
                 mode='auto'):
        
        super().__init__(learn)
        
        self.monitor = monitor
        self.min_delta = min_delta
        self.patience = patience
        self.mode = mode
        if mode not in ['auto', 'min', 'max']:
            #should I use warning?
            print(f'TerminateEarly mode {mode} is invalid, falling back to "auto" mode.')
            mode = 'auto'
        if mode == 'min':
            self.operator = np.less
        elif mode == 'max':
            self.operator = np.greater
        else:
            if 'acc' in self.monitor:
                self.operator = np.greater
            else:
                self.operator = np.less
        if self.operator == np.less:
            self.min_delta *= -1
    
    def on_train_begin(self, **kwargs:Any)->None:
        self.wait = 0
        self.stopped_epoch = 0
        self.best = float('inf') if self.operator == np.less else -float('inf')

    def on_epoch_end(self, epoch, **kwargs:Any)->None:
        current = self.get_monitor_value()
        if current is None:
            return
         
        if self.operator(current - self.min_delta, self.best):
            
            self.best = current
            self.wait = 0
        else:
            
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                return True
        return False
                
    def on_train_end(self, **kwargs:Any)->None:
        if self.stopped_epoch > 0:
            print(f'Epoch {self.stopped_epoch}: early stopping')
                                                                          
    def get_monitor_value(self):
        values = {'val_loss':learn.recorder.val_losses[-1:][0]
               ,'trn_loss':learn.recorder.losses[-1:][0].cpu().numpy()}
        for i, name in enumerate(learn.recorder.names[3:]):
            values[name]=learn.recorder.metrics[-1:][0][i] 

        if values.get(self.monitor) is None:
            print(f'Early stopping conditioned on metric `{self.monitor}` which is not available. Available metrics are: {", ".join(map(str, learn.recorder.names[3:]))}')   
        return values.get(self.monitor)
    

In [308]:
learn = ConvLearner(data, tvm.resnet18, metrics=[accuracy], callbacks=[EarlyStopping(monitor='accuracy', min_delta=0.01, patience=3)] )

In [309]:
learn.fit(50,1e-42)

best::::::-inf


VBox(children=(HBox(children=(IntProgress(value=0, max=50), HTML(value='0.00% [0/50 00:00<00:00]'))), HTML(val…

current:0.5505397319793701, best:-inf
best:0.5505397319793701, current:0.5456329584121704, operator:<ufunc 'greater'>
best:0.5505397319793701, current:0.5466143488883972, operator:<ufunc 'greater'>
best:0.5505397319793701, current:0.5510303974151611, operator:<ufunc 'greater'>
Epoch 3: early stopping


In [232]:
learn.recorder.val_losses[-1:][0]

0.8651589

In [247]:
a

{'c': 1, 'accuracy': 0.5210991}

In [250]:
a.get('cebola')