## Usage of callbacks

A callback is a set of functions applied at given stages of the training procedure. You can use callbacks to inspect internal states and statistics of the model during training. Pass a list of callbacks to the .fit() method of the Sequential or Model classes through the callbacks keyword argument, and the relevant methods of each callback will then be called at the matching stage of training.
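
To make "called at each stage" concrete, here is a toy, pure-Python sketch of the order in which a training loop fires callback hooks. toy_fit and RecordingCallback are hypothetical names; the real .fit() also passes richer logs, handles validation and more, so treat this as an illustration of the calling order, not of Keras internals.

```python
class RecordingCallback(object):
    """Hypothetical callback that records each hook call in order."""

    def __init__(self):
        self.events = []

    def on_train_begin(self, logs=None):
        self.events.append('train_begin')

    def on_epoch_begin(self, epoch, logs=None):
        self.events.append('epoch_begin %d' % epoch)

    def on_batch_end(self, batch, logs=None):
        self.events.append('batch_end %d' % batch)

    def on_epoch_end(self, epoch, logs=None):
        self.events.append('epoch_end %d' % epoch)

    def on_train_end(self, logs=None):
        self.events.append('train_end')


def toy_fit(callbacks, epochs=2, batches_per_epoch=2):
    """Toy stand-in for fit(): fires each hook at its stage (no real training)."""
    for cb in callbacks:
        cb.on_train_begin({})
    for epoch in range(epochs):
        for cb in callbacks:
            cb.on_epoch_begin(epoch, {})
        for batch in range(batches_per_epoch):
            for cb in callbacks:
                cb.on_batch_end(batch, {'loss': 0.0})  # placeholder loss value
        for cb in callbacks:
            cb.on_epoch_end(epoch, {})
    for cb in callbacks:
        cb.on_train_end({})


recorder = RecordingCallback()
toy_fit([recorder], epochs=1, batches_per_epoch=2)
print(recorder.events)
# ['train_begin', 'epoch_begin 0', 'batch_end 0', 'batch_end 1', 'epoch_end 0', 'train_end']
```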

## Base Callbacks

Keras adds a few callbacks to every .fit() call on its own, so you are already using them without knowing it:

* BaseLogger: accumulates epoch averages of metrics.
* ProgbarLogger: prints metrics to stdout (added when verbose > 0).
* History: records events into a History object, which .fit() returns.
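
BaseLogger's epoch averages are weighted by the number of samples in each batch, so a smaller final batch counts for less. The sketch below reproduces that arithmetic in plain Python; epoch_averages is a hypothetical helper, and it ignores details such as stateful metrics.

```python
def epoch_averages(batches):
    """Sample-weighted mean of each metric over one epoch.

    `batches` is a list of (batch_size, logs) pairs, where logs maps
    metric names to that batch's value.
    """
    seen = 0
    totals = {}
    for batch_size, logs in batches:
        seen += batch_size
        for name, value in logs.items():
            totals[name] = totals.get(name, 0.0) + value * batch_size
    return {name: total / seen for name, total in totals.items()}


# Two full batches of 32 samples and a final partial batch of 16:
# (32 * 1.0 + 32 * 0.5 + 16 * 0.25) / 80 = 0.65
print(epoch_averages([(32, {'loss': 1.0}), (32, {'loss': 0.5}), (16, {'loss': 0.25})]))
# {'loss': 0.65}
```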


## Even More Callbacks

I'll walk through a set of built-in callbacks that work with any model, and then we will talk about custom callbacks.

```python
from keras.callbacks import ModelCheckpoint

# Checkpoint every 5 epochs (period), keeping only weights that improve val_loss.
# Loss is minimized, so use mode='min' (or 'auto').
mc = ModelCheckpoint(
    filepath='tmp/weights.{epoch:02d}-{val_loss:.2f}.hdf5',
    monitor='val_loss',
    verbose=0,
    save_best_only=True,
    save_weights_only=True,
    mode='min',
    period=5)
```
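
The filepath is an ordinary Python format string, filled in from the epoch number and the logs dict at the end of each epoch. Here is a sketch of that substitution, assuming (as in Keras 2 releases) that the epoch is passed 1-based and the logs keys are passed as keyword arguments; checkpoint_path and the metric values are made up for illustration.

```python
TEMPLATE = 'tmp/weights.{epoch:02d}-{val_loss:.2f}.hdf5'


def checkpoint_path(template, epoch_index, logs):
    """Fill a checkpoint template the way ModelCheckpoint is assumed to."""
    # epoch_index is 0-based (as seen by on_epoch_end); the file name is 1-based.
    return template.format(epoch=epoch_index + 1, **logs)


print(checkpoint_path(TEMPLATE, 4, {'loss': 0.52, 'val_loss': 0.4321}))
# tmp/weights.05-0.43.hdf5

# Without validation data there is no val_loss key, so formatting fails.
try:
    checkpoint_path(TEMPLATE, 0, {'loss': 0.52})
except KeyError as missing:
    print('missing log key:', missing)
```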


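```python
from keras.callbacks import EarlyStopping

# Stop once val_loss has not improved by at least 0.01 for 5 consecutive epochs.
# Loss is minimized, so use mode='min'.
es = EarlyStopping(
    monitor='val_loss',
    min_delta=0.01,
    patience=5,
    verbose=1,
    mode='min')
```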
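
How patience and min_delta interact is easiest to see by replaying a sequence of validation losses. The function below is a simplified, pure-Python sketch of the rule for a minimized metric (an epoch counts as an improvement only if the loss falls more than min_delta below the best value so far); it is not the Keras source, and early_stopping_epoch is a hypothetical name.

```python
def early_stopping_epoch(val_losses, min_delta=0.01, patience=5):
    """Return the 0-based epoch at which training would stop, or None."""
    best = float('inf')
    wait = 0  # consecutive epochs without a sufficient improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Improvements stop after epoch 2: 0.695 is not more than 0.01 below 0.7,
# so the later 0.6 is never reached.
losses = [1.0, 0.8, 0.7, 0.695, 0.695, 0.695, 0.695, 0.695, 0.6]
print(early_stopping_epoch(losses))  # 7
```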

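```python
from keras.callbacks import LearningRateScheduler

# The schedule receives the 0-based epoch index, so shift by one
# to avoid dividing by zero on the first epoch.
lrs = LearningRateScheduler(lambda epoch: 1. / (epoch + 1))
```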
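
Because the schedule function receives a 0-based epoch index, it is worth evaluating it by hand before training: a schedule written as 1./epoch raises ZeroDivisionError on the very first epoch. This check is plain Python and does not need Keras; schedule is a hypothetical name for the shifted version.

```python
def schedule(epoch):
    """Learning rate for a 0-based epoch index: 1.0, 0.5, 0.333..., 0.25, ..."""
    return 1. / (epoch + 1)


print([schedule(epoch) for epoch in range(4)])
# [1.0, 0.5, 0.3333333333333333, 0.25]

# The unshifted version fails as soon as training starts.
try:
    (lambda epoch: 1. / epoch)(0)
except ZeroDivisionError as err:
    print('epoch 0:', err)
```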

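```python
from keras.callbacks import ReduceLROnPlateau

# Multiply the learning rate by 0.1 after 10 epochs without improvement in
# val_loss, then wait 4 epochs (cooldown) before monitoring resumes.
rlrop = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.1,
    patience=10,
    verbose=0,
    mode='auto',
    min_delta=0.0001,  # called `epsilon` in older Keras 2 releases
    cooldown=4,
    min_lr=10e-7)      # note: 10e-7 == 1e-6
```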
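
The interplay of patience, factor, cooldown and min_lr can be traced with a small simulation. This is a simplified, pure-Python sketch of the reduce-on-plateau rule for a minimized metric, not the Keras implementation (whose cooldown bookkeeping differs in detail between versions); plateau_lrs and its small default values are chosen only for illustration.

```python
def plateau_lrs(val_losses, lr=0.1, factor=0.1, patience=2,
                cooldown=1, min_lr=1e-3, min_delta=1e-4):
    """Return the learning rate in effect after each epoch."""
    best = float('inf')
    wait = 0           # epochs since the last improvement
    cooldown_left = 0  # epochs still to ignore after a reduction
    history = []
    for loss in val_losses:
        in_cooldown = cooldown_left > 0
        if in_cooldown:
            cooldown_left -= 1
        if loss < best - min_delta:
            best = loss
            wait = 0
        elif not in_cooldown:
            wait += 1
            if wait >= patience and lr > min_lr:
                lr = max(lr * factor, min_lr)  # never go below min_lr
                cooldown_left = cooldown
                wait = 0
        history.append(lr)
    return history


# A completely flat loss curve: the rate drops every few epochs toward min_lr,
# roughly [0.1, 0.1, 0.01, 0.01, 0.01, 0.001, 0.001] up to float rounding.
print(plateau_lrs([1.0] * 7))
```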

```python
from keras.callbacks import CSVLogger

# Write one row of results per epoch to a CSV file, overwriting any old log.
csvl = CSVLogger(
    filename='tmp/training.log',
    separator=',',
    append=False)
```
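
Because the log is plain CSV with a header row, it can be analysed after training with the standard library. The helper best_epoch is hypothetical, and SAMPLE_LOG is made-up data whose columns assume a model compiled with an accuracy metric and trained with validation data; check the header of your own log for the actual column names.

```python
import csv
import io

# Made-up example contents of tmp/training.log (column names are assumptions).
SAMPLE_LOG = """epoch,acc,loss,val_acc,val_loss
0,0.71,0.80,0.69,0.85
1,0.78,0.62,0.74,0.70
2,0.83,0.50,0.75,0.72
"""


def best_epoch(log_file, metric='val_loss'):
    """Return (epoch, value) for the row with the lowest `metric`."""
    rows = list(csv.DictReader(log_file))
    best = min(rows, key=lambda row: float(row[metric]))
    return int(best['epoch']), float(best[metric])


print(best_epoch(io.StringIO(SAMPLE_LOG)))  # (1, 0.7)

# On a real log file:
# with open('tmp/training.log', newline='') as f:
#     print(best_epoch(f))
```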

```python
from keras.callbacks import TensorBoard

tb = TensorBoard(
    log_dir='./logs',
    histogram_freq=0,              # 0 disables weight/activation histograms
    write_graph=True,
    write_images=False,
    embeddings_freq=100,           # save selected embedding layers every 100 epochs
    embeddings_layer_names=None,   # names of embedding layers to save (None = all)
    embeddings_metadata=None)      # metadata files associated with those layers
```

## Lambda Callback

If that was not enough for you, here is the big one. 

This callback is constructed with functions (often anonymous lambdas) that are called at the appropriate time. Note that the functions are called with positional arguments:

* on_epoch_begin and on_epoch_end expect two positional arguments: epoch, logs
* on_batch_begin and on_batch_end expect two positional arguments: batch, logs
* on_train_begin and on_train_end expect one positional argument: logs

#### Arguments

* on_epoch_begin: called at the beginning of every epoch.
* on_epoch_end: called at the end of every epoch.
* on_batch_begin: called at the beginning of every batch.
* on_batch_end: called at the end of every batch.
* on_train_begin: called at the beginning of model training.
* on_train_end: called at the end of model training.
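
Conceptually, the callback stores the functions you pass and substitutes a do-nothing function for every hook you leave out. A hypothetical pure-Python stand-in (SimpleLambdaCallback is not a Keras class) illustrates that idea together with the positional signatures listed above.

```python
def _noop(*args):
    """Default hook: accept the positional arguments and do nothing."""
    return None


class SimpleLambdaCallback(object):
    """Hypothetical stand-in illustrating the LambdaCallback idea."""

    def __init__(self, on_epoch_begin=None, on_epoch_end=None,
                 on_batch_begin=None, on_batch_end=None,
                 on_train_begin=None, on_train_end=None):
        self.on_epoch_begin = on_epoch_begin or _noop  # (epoch, logs)
        self.on_epoch_end = on_epoch_end or _noop      # (epoch, logs)
        self.on_batch_begin = on_batch_begin or _noop  # (batch, logs)
        self.on_batch_end = on_batch_end or _noop      # (batch, logs)
        self.on_train_begin = on_train_begin or _noop  # (logs)
        self.on_train_end = on_train_end or _noop      # (logs)


seen = []
cb = SimpleLambdaCallback(
    on_epoch_end=lambda epoch, logs: seen.append((epoch, logs['loss'])))
cb.on_train_begin({})              # not supplied, so a no-op
cb.on_epoch_end(0, {'loss': 0.5})
cb.on_train_end({})                # not supplied, so a no-op
print(seen)  # [(0, 0.5)]
```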


```python
from keras.callbacks import LambdaCallback

# Print the batch number at the beginning of every batch.
def print_batch(batch, logs):
    print(batch)

batch_print_callback = LambdaCallback(
    on_batch_begin=print_batch)

# Terminate some processes after having finished model training.
# `processes` is assumed to hold objects such as multiprocessing.Process.
processes = []
cleanup_callback = LambdaCallback(
    on_train_end=lambda logs: [
        p.terminate() for p in processes if p.is_alive()])
```

## Super Custom Callbacks

You can create a custom callback by extending the abstract base class keras.callbacks.Callback. A callback has access to its associated model through the attribute self.model.

#### Properties

* params: dict. Training parameters (e.g. verbosity, batch size, number of epochs...).
* model: instance of keras.models.Model. Reference of the model being trained.

The logs dictionary that callback methods take as argument will contain keys for quantities relevant to the current batch or epoch.

Currently, the .fit() method of the Sequential model class will include the following quantities in the logs that it passes to its callbacks:

* on_epoch_end: logs include acc and loss, and optionally include val_loss (if validation is enabled in fit), and val_acc (if validation and accuracy monitoring are enabled).
* on_batch_begin: logs include size, the number of samples in the current batch.
* on_batch_end: logs include loss, and optionally acc (if accuracy monitoring is enabled).
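
Since acc, val_loss and val_acc only appear under certain fit() settings, a callback should read optional keys defensively rather than indexing them directly. summarize_epoch below is a hypothetical plain-Python helper that a custom on_epoch_end could call. Note that newer Keras releases may name the accuracy key after the metric passed to compile() (for example accuracy instead of acc).

```python
OPTIONAL_KEYS = ('acc', 'val_loss', 'val_acc')


def summarize_epoch(epoch, logs):
    """One-line summary that skips metrics missing from `logs`."""
    parts = ['epoch %d' % epoch, 'loss=%.4f' % logs.get('loss', float('nan'))]
    for key in OPTIONAL_KEYS:
        if key in logs:
            parts.append('%s=%.4f' % (key, logs[key]))
    return ', '.join(parts)


print(summarize_epoch(0, {'loss': 0.5}))
# epoch 0, loss=0.5000
print(summarize_epoch(1, {'loss': 0.4, 'acc': 0.9, 'val_loss': 0.45}))
# epoch 1, loss=0.4000, acc=0.9000, val_loss=0.4500
```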

Here's a simple example saving a list of losses over each batch during training:

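```python
import keras

class LossHistory(keras.callbacks.Callback):
    """Records the training loss after every batch."""

    def on_train_begin(self, logs=None):
        self.losses = []

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        self.losses.append(logs.get('loss'))

# Usage: history = LossHistory()
#        model.fit(x, y, callbacks=[history])
#        print(history.losses)
```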
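
Because callback hooks are ordinary methods, a custom callback can be exercised without training a model by calling the hooks yourself with hand-made logs, which is a handy way to unit-test one. The sketch below uses a hypothetical BestEpochTracker; the fallback Callback class is an assumption added only so the snippet also runs where Keras is not installed.

```python
try:
    from keras.callbacks import Callback
except ImportError:
    # Assumption: minimal stand-in so the sketch runs without Keras.
    class Callback(object):
        pass


class BestEpochTracker(Callback):
    """Hypothetical callback remembering the epoch with the lowest val_loss."""

    def on_train_begin(self, logs=None):
        self.best_epoch = None
        self.best_loss = float('inf')

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        val_loss = logs.get('val_loss')  # absent when fit() has no validation
        if val_loss is not None and val_loss < self.best_loss:
            self.best_epoch = epoch
            self.best_loss = val_loss


# Drive the hooks by hand with made-up logs instead of calling fit().
tracker = BestEpochTracker()
tracker.on_train_begin()
for epoch, val_loss in enumerate([0.9, 0.6, 0.7]):
    tracker.on_epoch_end(epoch, {'loss': 1.0, 'val_loss': val_loss})
print(tracker.best_epoch, tracker.best_loss)  # 1 0.6
```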