# Keras callbacks

In Keras, a **callback** is an object that can perform actions at various stages of training (for example, at the start or end of a batch or an epoch), testing, and predicting. `Callback` is a Python base class meant to be subclassed: you override only the methods for the stages you care about. Callbacks are useful for getting a view of the internal states and statistics of the model during training. Keras also ships several built-in [callbacks](https://keras.io/api/callbacks/), and the following sections show how to use them.

## Model methods that take callbacks
Users can supply a list of callbacks to the following `tf.keras.Model` methods:
* [`fit()`](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/Model#fit), [`fit_generator()`](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/Model#fit_generator)
Trains the model for a fixed number of epochs (iterations over a dataset, or data yielded batch-by-batch by a Python generator).
* [`evaluate()`](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/Model#evaluate), [`evaluate_generator()`](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/Model#evaluate_generator)
Evaluates the model on the given data or data generator and returns the loss and metric values.
* [`predict()`](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/Model#predict), [`predict_generator()`](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/Model#predict_generator)
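Generates output predictions for the input data or data generator.

In recent TF 2.x releases, including the 2.4.1 used below, `fit()`, `evaluate()`, and `predict()` accept generators directly, and the `*_generator` variants are deprecated.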

In [1]:
from __future__ import absolute_import, division, print_function, unicode_literals

try:
    %tensorflow_version 2.x
except Exception:
    pass

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.utils import plot_model

import matplotlib.pyplot as plt
import io

from PIL import Image

from tensorflow.keras.callbacks import TensorBoard, EarlyStopping, LearningRateScheduler, ModelCheckpoint, CSVLogger, \
ReduceLROnPlateau

%load_ext tensorboard

import os
import numpy as np, pandas as pd
import math
import datetime

print("Version: ", tf.__version__)
tf.get_logger().setLevel('INFO')

Version:  2.4.1


## Examples of Keras callback applications

In [2]:
# Download and prepare the horses or humans dataset

splits, info = tfds.load('horses_or_humans', as_supervised=True, with_info=True, split=['train[:80%]', 'train[80%:]', 'test'])

(train_examples, validation_examples, test_examples) = splits

num_examples = info.splits['train'].num_examples
num_classes = info.features['label'].num_classes

In [3]:
SIZE = 150 #@param {type:"slider", min:64, max:300, step:1}
IMAGE_SIZE = (SIZE, SIZE)

In [4]:
def format_image(image, label):
  image = tf.image.resize(image, IMAGE_SIZE) / 255.0
  return image, label

In [5]:
BATCH_SIZE = 32 #@param {type:"integer"}

In [6]:
train_batches = train_examples.shuffle(num_examples // 4).map(format_image).batch(BATCH_SIZE).prefetch(1)
validation_batches = validation_examples.map(format_image).batch(BATCH_SIZE).prefetch(1)
test_batches = test_examples.map(format_image).batch(1)

In [7]:
for image_batch, label_batch in train_batches.take(1):
  pass

image_batch.shape

TensorShape([32, 150, 150, 3])

In [8]:
def build_model(dense_units, input_shape=IMAGE_SIZE + (3,)):
    model = tf.keras.models.Sequential([
      tf.keras.layers.Conv2D(16, (3, 3), activation='relu', input_shape=input_shape),
      tf.keras.layers.MaxPooling2D(2, 2),
      tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
      tf.keras.layers.MaxPooling2D(2, 2),
      tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
      tf.keras.layers.MaxPooling2D(2, 2),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(dense_units, activation='relu'),
      tf.keras.layers.Dense(2, activation='softmax')
    ])
    return model

## [TensorBoard](https://keras.io/api/callbacks/tensorboard/)

Enable visualizations for TensorBoard.

In [9]:
!rm -rf logs

In [10]:
model = build_model(dense_units=256)
model.compile(optimizer='sgd',
             loss='sparse_categorical_crossentropy',
             metrics=['accuracy'])

logdir = os.path.join('logs', datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir)

model.fit(train_batches,
         epochs=5,
         validation_data=validation_batches,
         callbacks=[tensorboard_callback])

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x7fa6c0dee190>

In [11]:
%tensorboard --logdir logs/

ERROR: Could not find `tensorboard`. Please ensure that your PATH
contains an executable `tensorboard` program, or explicitly specify
the path to a TensorBoard binary by setting the `TENSORBOARD_BINARY`
environment variable.

## [Model Checkpoint](https://keras.io/api/callbacks/model_checkpoint/)

Callback to save the Keras model or model weights at some frequency.

In [12]:
model = build_model(dense_units=256)
model.compile(optimizer='sgd',
             loss='sparse_categorical_crossentropy',
             metrics=['accuracy'])

model.fit(train_batches,
         epochs=1,
         validation_data=validation_batches,
         verbose=2,
         callbacks=[ModelCheckpoint('weights.{epoch:02d}-{val_loss:.2f}.h5', verbose=1)])

26/26 - 14s - loss: 0.6670 - accuracy: 0.5900 - val_loss: 0.6202 - val_accuracy: 0.6780

Epoch 00001: saving model to weights.01-0.62.h5


<tensorflow.python.keras.callbacks.History at 0x7fa6a9d370d0>
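The file name `weights.01-0.62.h5` comes from filling the placeholders in the `filepath` template with the epoch number and the values in the epoch-end `logs` dict. The sketch below reproduces only that formatting step with Python's `str.format`, assuming (as the output above suggests) a 1-based epoch; `checkpoint_name` is a hypothetical helper, and the metric values are copied from the log above.

```python
# Reproduce only the file-name formatting step (no Keras needed).
# Assumption: the epoch is 1-based and the other fields come from the epoch-end logs.
template = 'weights.{epoch:02d}-{val_loss:.2f}.h5'

def checkpoint_name(template, epoch_index, logs):
    """Hypothetical helper: file name for a 0-based epoch index and a logs dict."""
    return template.format(epoch=epoch_index + 1, **logs)

# Metric values copied from the epoch-1 log line above.
logs = {'loss': 0.6670, 'accuracy': 0.5900, 'val_loss': 0.6202, 'val_accuracy': 0.6780}
print(checkpoint_name(template, 0, logs))  # weights.01-0.62.h5
```

A fixed name without placeholders, such as `model.h5`, is overwritten at every epoch instead.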

In [13]:
model = build_model(dense_units=256)
model.compile(optimizer='sgd',
             loss='sparse_categorical_crossentropy',
             metrics=['accuracy'])

model.fit(train_batches,
         epochs=1,
         validation_data=validation_batches,
         verbose=2,
         callbacks=[ModelCheckpoint('saved_model', verbose=1)])

26/26 - 14s - loss: 0.6751 - accuracy: 0.5681 - val_loss: 0.6545 - val_accuracy: 0.7756

Epoch 00001: saving model to saved_model
INFO:tensorflow:Assets written to: saved_model/assets


<tensorflow.python.keras.callbacks.History at 0x7fa6cb076be0>

In [14]:
model = build_model(dense_units=256)
model.compile(
    optimizer='sgd',
    loss='sparse_categorical_crossentropy', 
    metrics=['accuracy'])
  
model.fit(train_batches,
          epochs=2,
          validation_data=validation_batches,
          verbose=2,
          callbacks=[ModelCheckpoint('model.h5', verbose=1)])

Epoch 1/2
26/26 - 14s - loss: 0.6751 - accuracy: 0.5718 - val_loss: 0.6676 - val_accuracy: 0.6049

Epoch 00001: saving model to model.h5
Epoch 2/2
26/26 - 13s - loss: 0.6350 - accuracy: 0.6934 - val_loss: 0.6163 - val_accuracy: 0.7317

Epoch 00002: saving model to model.h5


<tensorflow.python.keras.callbacks.History at 0x7fa6caf4b850>
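Because the file name `model.h5` contains no placeholders, the second epoch overwrote the checkpoint from the first. To keep only the best model instead, `ModelCheckpoint` accepts `monitor` and `save_best_only=True`, for example `ModelCheckpoint('model.h5', monitor='val_loss', save_best_only=True, verbose=1)`. The plain-Python sketch below illustrates the decision rule for `mode='min'`; the helper `epochs_saved` is hypothetical and simplified, not Keras's implementation.

```python
import math

def epochs_saved(monitored_values):
    """Hypothetical helper: 1-based epochs at which a save_best_only checkpoint
    with mode='min' would write a file, given per-epoch monitored values."""
    best = math.inf
    saved = []
    for epoch, value in enumerate(monitored_values, start=1):
        if value < best:  # strictly better than every earlier epoch
            best = value
            saved.append(epoch)
    return saved

# val_loss values from the two-epoch run above, plus a hypothetical third epoch.
print(epochs_saved([0.6676, 0.6163, 0.6400]))  # [1, 2]
```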

## [Early stopping](https://keras.io/api/callbacks/early_stopping/)

Stop training when a monitored metric has stopped improving.

In [15]:
model = build_model(dense_units=256)
model.compile(
    optimizer='sgd',
    loss='sparse_categorical_crossentropy', 
    metrics=['accuracy'])

model.fit(train_batches,
          epochs=50,
          validation_data=validation_batches,
          verbose=2,
          callbacks=[EarlyStopping(
              monitor='val_loss',
              mode='min',
              patience=3,
              min_delta=0.05,
              baseline=0.8,
              restore_best_weights=True,
              verbose=1)])

Epoch 1/50
26/26 - 14s - loss: 0.6783 - accuracy: 0.5499 - val_loss: 0.6619 - val_accuracy: 0.8146
Epoch 2/50
26/26 - 13s - loss: 0.6505 - accuracy: 0.6375 - val_loss: 0.6253 - val_accuracy: 0.8341
Epoch 3/50
26/26 - 14s - loss: 0.6069 - accuracy: 0.7117 - val_loss: 0.5939 - val_accuracy: 0.6732
Epoch 4/50
26/26 - 14s - loss: 0.5714 - accuracy: 0.7105 - val_loss: 0.5180 - val_accuracy: 0.9024
Epoch 5/50
26/26 - 14s - loss: 0.5223 - accuracy: 0.7603 - val_loss: 0.5951 - val_accuracy: 0.5951
Epoch 6/50
26/26 - 14s - loss: 0.4797 - accuracy: 0.7956 - val_loss: 0.4709 - val_accuracy: 0.6878
Epoch 7/50
26/26 - 13s - loss: 0.4125 - accuracy: 0.8455 - val_loss: 0.3230 - val_accuracy: 0.9415
Epoch 8/50
26/26 - 13s - loss: 0.3112 - accuracy: 0.8966 - val_loss: 0.2431 - val_accuracy: 0.9463
Epoch 9/50
26/26 - 13s - loss: 0.2936 - accuracy: 0.8832 - val_loss: 0.5663 - val_accuracy: 0.6439
Epoch 10/50
26/26 - 14s - loss: 0.2532 - accuracy: 0.9270 - val_loss: 0.1708 - val_accuracy: 0.9610
Epoch 11/

<tensorflow.python.keras.callbacks.History at 0x7fa6cace1160>
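To make the stopping rule concrete, the plain-Python sketch below replays the `val_loss` values printed above through a simplified version of the `EarlyStopping` logic for `mode='min'`: an epoch counts as an improvement only if the value drops more than `min_delta` below the best value so far (initially `baseline`), and training stops after `patience` epochs without improvement. The helper `early_stopping_epoch` is hypothetical and not Keras's implementation; details such as how `baseline` interacts with the patience counter have changed between versions, and `restore_best_weights` is not modeled. With these settings no stop is triggered within the first ten epochs, which is consistent with the run continuing past epoch 10.

```python
import math

def early_stopping_epoch(values, patience, min_delta=0.0, baseline=None):
    """Simplified EarlyStopping sketch for mode='min'.
    Returns (stop_epoch, best): stop_epoch is the 1-based epoch at which
    training would stop, or None if it never stops within `values`."""
    best = baseline if baseline is not None else math.inf
    wait = 0
    for epoch, current in enumerate(values, start=1):
        if current < best - min_delta:  # improvement must exceed min_delta
            best = current
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best
    return None, best

# val_loss for epochs 1-10, copied from the output above.
val_loss = [0.6619, 0.6253, 0.5939, 0.5180, 0.5951,
            0.4709, 0.3230, 0.2431, 0.5663, 0.1708]
print(early_stopping_epoch(val_loss, patience=3, min_delta=0.05, baseline=0.8))
# (None, 0.1708)
```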

## [CSV Logger](https://keras.io/api/callbacks/csv_logger/)

Callback that streams epoch results to a CSV file.

In [16]:
model = build_model(dense_units=256)
model.compile(
    optimizer='sgd',
    loss='sparse_categorical_crossentropy', 
    metrics=['accuracy'])

csv_file = 'training.csv'

model.fit(train_batches,
         epochs=5,
         validation_data=validation_batches,
         callbacks=[CSVLogger(csv_file)])

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x7fa6cabe3fd0>

In [17]:
pd.read_csv(csv_file).head()

|   | epoch | accuracy | loss | val_accuracy | val_loss |
|---|---|---|---|---|---|
| 0 | 0 | 0.608272 | 0.673407 | 0.765854 | 0.656475 |
| 1 | 1 | 0.716545 | 0.63093 | 0.526829 | 0.655732 |
| 2 | 2 | 0.742092 | 0.566111 | 0.75122 | 0.538003 |
| 3 | 3 | 0.750608 | 0.524267 | 0.804878 | 0.511137 |
| 4 | 4 | 0.80292 | 0.452956 | 0.814634 | 0.443315 |
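Since `CSVLogger` writes a plain CSV file, the log can also be analysed without pandas or Keras, for example to find the epoch with the lowest validation loss. The sketch below uses only the standard library and embeds the values shown above in place of reading `training.csv` from disk (the column order is assumed to match the file). Note that the `epoch` column is 0-based.

```python
import csv
import io

# The logged values shown above, standing in for the contents of training.csv.
log_text = """epoch,accuracy,loss,val_accuracy,val_loss
0,0.608272,0.673407,0.765854,0.656475
1,0.716545,0.63093,0.526829,0.655732
2,0.742092,0.566111,0.75122,0.538003
3,0.750608,0.524267,0.804878,0.511137
4,0.80292,0.452956,0.814634,0.443315
"""

# Each row becomes a dict keyed by the header; values are strings.
rows = list(csv.DictReader(io.StringIO(log_text)))
best = min(rows, key=lambda row: float(row['val_loss']))
print(f"Lowest val_loss {float(best['val_loss']):.4f} at epoch {int(best['epoch'])}")
# Lowest val_loss 0.4433 at epoch 4
```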


## [Learning Rate Scheduler](https://keras.io/api/callbacks/learning_rate_scheduler/)

Updates the learning rate during training.

In [18]:
model = build_model(dense_units=256)
model.compile(
    optimizer='sgd',
    loss='sparse_categorical_crossentropy', 
    metrics=['accuracy'])

def step_decay(epoch):
    initial_lr = 0.01
    drop = 0.5
    epochs_drop = 1
    lr = initial_lr * math.pow(drop, math.floor((1+epoch) / epochs_drop))
    return lr

model.fit(train_batches,
         epochs=5,
         validation_data=validation_batches,
         callbacks=[LearningRateScheduler(step_decay, verbose=1),
                   TensorBoard(log_dir='./log_dir')])

Epoch 1/5

Epoch 00001: LearningRateScheduler reducing learning rate to 0.005.
Epoch 2/5

Epoch 00002: LearningRateScheduler reducing learning rate to 0.0025.
Epoch 3/5

Epoch 00003: LearningRateScheduler reducing learning rate to 0.00125.
Epoch 4/5

Epoch 00004: LearningRateScheduler reducing learning rate to 0.000625.
Epoch 5/5

Epoch 00005: LearningRateScheduler reducing learning rate to 0.0003125.


<tensorflow.python.keras.callbacks.History at 0x7fa6caab0eb0>
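As written, `step_decay` already halves the rate for the first epoch (epoch 0 gives 0.01 × 0.5 = 0.005, matching the log above), so the model never trains at `initial_lr`. If the intent is to start at `initial_lr`, dropping the `1+` from the exponent does that. In TF 2.x, `LearningRateScheduler` can also call the schedule as `schedule(epoch, lr)`, passing the current learning rate. Both functions below are plain-Python sketches; the names `step_decay_from_start`, `halve_every_two`, and the module-level constants are hypothetical.

```python
import math

# Hypothetical constants mirroring step_decay above.
INITIAL_LR, DROP, EPOCHS_DROP = 0.01, 0.5, 1

def step_decay_from_start(epoch):
    """Step decay that trains epoch 0 at INITIAL_LR and multiplies the rate
    by DROP every EPOCHS_DROP epochs. It takes only `epoch`: a second
    parameter (even one with a default) would receive the current learning
    rate from LearningRateScheduler."""
    return INITIAL_LR * math.pow(DROP, math.floor(epoch / EPOCHS_DROP))

def halve_every_two(epoch, lr):
    """Two-argument schedule: halve the current rate at epochs 2, 4, 6, ..."""
    if epoch > 0 and epoch % 2 == 0:
        return lr * 0.5
    return lr

print([step_decay_from_start(e) for e in range(5)])
# [0.01, 0.005, 0.0025, 0.00125, 0.000625]

# Simulate how the two-argument form evolves over 5 epochs, starting at 0.01.
lr = 0.01
for epoch in range(5):
    lr = halve_every_two(epoch, lr)
    print(epoch, lr)
```

Either function could then be passed as `LearningRateScheduler(step_decay_from_start, verbose=1)` or `LearningRateScheduler(halve_every_two, verbose=1)`.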

In [19]:
%tensorboard --logdir log_dir

ERROR: Could not find `tensorboard`. Please ensure that your PATH
contains an executable `tensorboard` program, or explicitly specify
the path to a TensorBoard binary by setting the `TENSORBOARD_BINARY`
environment variable.

## [ReduceLROnPlateau](https://keras.io/api/callbacks/reduce_lr_on_plateau/)

Reduce learning rate when a metric has stopped improving.

In [20]:
model = build_model(dense_units=256)
model.compile(
    optimizer='sgd',
    loss='sparse_categorical_crossentropy', 
    metrics=['accuracy'])

model.fit(train_batches,
         epochs=50,
         validation_data=validation_batches,
         callbacks=[ReduceLROnPlateau(monitor='val_loss',
                                     factor=0.2, verbose=1,
                                     patience=1, min_lr=0.001),
                   TensorBoard(log_dir='./log_dir')])

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50

Epoch 00004: ReduceLROnPlateau reducing learning rate to 0.0019999999552965165.
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50

Epoch 00009: ReduceLROnPlateau reducing learning rate to 0.001.
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50

Epoch 00018: ReduceLROnPlateau reducing learning rate to 0.001.
Epoch 19/50

Epoch 00019: ReduceLROnPlateau reducing learning rate to 0.001.
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50

Epoch 00028: ReduceLROnPlateau reducing learning rate to 0.001.
Epoch 29/50

Epoch 00029: ReduceLROnPlateau reducing learning rate to 0.001.
Epoch 30/50
Epoch 31/50

Epoch 00031: ReduceLROnPlateau reducing learning rate to 0.001.
Epoch 32/50
Epoch 33/50

Epoch 00033: ReduceLROnPlateau reducing learning rate to 0.001.
Epoch 34/50

Epoch 00034: ReduceLROnPlateau reducing learning rate to 0.001.



Epoch 00050: ReduceLROnPlateau reducing learning rate to 0.001.


<tensorflow.python.keras.callbacks.History at 0x7fa6ca9b4df0>
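The log above shows the rate dropping from the SGD default of 0.01 to about 0.002 (multiplied by `factor=0.2`) and then being clamped at `min_lr=0.001`. The sketch below reproduces that rule in plain Python for `mode='min'` without `cooldown`; the function name `reduce_on_plateau` is hypothetical, its defaults are taken from the call above (plus `min_delta=1e-4`), and the `val_loss` values in the example are made up because the notebook does not print them. Unlike the log above, which keeps printing "reducing learning rate to 0.001" after the floor is reached, the sketch leaves the rate unchanged once it equals `min_lr`.

```python
def reduce_on_plateau(values, lr, factor=0.2, patience=1, min_lr=0.001, min_delta=1e-4):
    """Simplified sketch of the ReduceLROnPlateau rule for mode='min' and
    cooldown=0. Returns the learning rate in effect after each epoch."""
    best = float('inf')
    wait = 0
    history = []
    for current in values:
        if current < best - min_delta:  # monitored value improved
            best, wait = current, 0
        else:
            wait += 1
            # Reduce only if the rate is still above the floor.
            if wait >= patience and lr > min_lr:
                lr = max(lr * factor, min_lr)
                wait = 0
        history.append(lr)
    return history

# Hypothetical val_loss values; the notebook above does not print them.
val_loss = [0.69, 0.68, 0.685, 0.67, 0.675, 0.676]
print([round(x, 6) for x in reduce_on_plateau(val_loss, lr=0.01)])
# [0.01, 0.01, 0.002, 0.002, 0.001, 0.001]
```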

In [21]:
%tensorboard --logdir log_dir

ERROR: Could not find `tensorboard`. Please ensure that your PATH
contains an executable `tensorboard` program, or explicitly specify
the path to a TensorBoard binary by setting the `TENSORBOARD_BINARY`
environment variable.

# Keras custom callbacks

In [22]:
# Define the Keras model to add callbacks to
def get_model():
    model = tf.keras.Sequential()
    model.add(tf.keras.layers.Dense(1, activation='linear', input_dim=784))
    model.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.1), loss='mean_squared_error', metrics=['mae'])
    return model

In [23]:
# Load example MNIST data and pre-process it
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32') / 255
x_test = x_test.reshape(10000, 784).astype('float32') / 255

Now let's define a simple custom callback that tracks the start and end of every training batch. In each of these calls, it prints the index of the current batch together with a timestamp.

In [27]:
class CustomCallback(tf.keras.callbacks.Callback):
    def on_train_batch_begin(self, batch, logs=None):
        print(f'Training batch {batch} begins at {datetime.datetime.now().time()}')
        
    def on_train_batch_end(self, batch, logs=None):
        print(f'Training batch {batch} ends at {datetime.datetime.now().time()}')

Passing a callback to model methods such as `tf.keras.Model.fit()` ensures that its methods are called at the corresponding stages:

In [29]:
model = get_model()
_ = model.fit(x_train, y_train,
             batch_size=64,
             epochs=1,
             steps_per_epoch=5,
             verbose=0,
             callbacks=[CustomCallback()])

Training batch 0 begins at 14:33:36.656532
Training batch 0 ends at 14:33:37.057915
Training batch 1 begins at 14:33:37.058091
Training batch 1 ends at 14:33:37.059771
Training batch 2 begins at 14:33:37.059878
Training batch 2 ends at 14:33:37.060954
Training batch 3 begins at 14:33:37.061069
Training batch 3 ends at 14:33:37.062189
Training batch 4 begins at 14:33:37.062525
Training batch 4 ends at 14:33:37.063596


## **An overview of callback methods**

### **Common methods for training/testing/predicting**

For training, testing, and predicting, the following methods can be overridden.

**on_(train|test|predict)_begin(self, logs=None)**
Called at the beginning of fit/evaluate/predict.

**on_(train|test|predict)_end(self, logs=None)**
Called at the end of fit/evaluate/predict.

**on_(train|test|predict)_batch_begin(self, batch, logs=None)**
Called right before processing a batch during training/testing/predicting.

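In the TF 2.x training loop, no batch information is currently passed in `logs` for this method; the batch index is available through the `batch` argument instead. (Older Keras versions passed a `logs` dict with `batch` and `size` keys, holding the current batch number and the size of the batch.)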

**on_(train|test|predict)_batch_end(self, batch, logs=None)**
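Called at the end of training/testing/predicting a batch. Within this method, `logs` is a dict holding the current metric results during training and testing (for example, `loss`), and the batch outputs (under `outputs`) during prediction.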

### **Training specific methods**

In addition, the following methods are provided for training.

**on_epoch_begin(self, epoch, logs=None)**
Called at the beginning of an epoch during training.

**on_epoch_end(self, epoch, logs=None)**
Called at the end of an epoch during training.
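To see how these hooks fit together, the sketch below mimics the order in which `fit()` calls them. It is a simplified toy loop, not Keras's implementation: `toy_fit` and `Recorder` are hypothetical names, no actual training happens, and every hook receives an empty `logs` dict. The validation part reflects our understanding that validation inside `fit()` triggers the `on_test_*` hooks before `on_epoch_end`, which is why validation metrics are available there.

```python
class Recorder:
    """Hypothetical stand-in for a Callback subclass: records every on_* hook call."""
    def __init__(self):
        self.calls = []

    def __getattr__(self, name):
        # Only reached for attributes not defined normally, i.e. the hooks.
        if not name.startswith('on_'):
            raise AttributeError(name)
        return lambda *args, logs=None: self.calls.append((name,) + args)


def toy_fit(callbacks, epochs, steps_per_epoch, validation_steps=0):
    """Toy training loop showing the hook order of fit() (no real training)."""
    def call(hook, *args):
        for cb in callbacks:
            getattr(cb, hook)(*args, logs={})

    call('on_train_begin')
    for epoch in range(epochs):
        call('on_epoch_begin', epoch)
        for batch in range(steps_per_epoch):
            call('on_train_batch_begin', batch)
            call('on_train_batch_end', batch)
        if validation_steps:
            # Validation runs the test hooks before the epoch ends.
            call('on_test_begin')
            for batch in range(validation_steps):
                call('on_test_batch_begin', batch)
                call('on_test_batch_end', batch)
            call('on_test_end')
        call('on_epoch_end', epoch)
    call('on_train_end')


rec = Recorder()
toy_fit([rec], epochs=1, steps_per_epoch=2, validation_steps=1)
for entry in rec.calls:
    print(*entry)
```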

### Usage of `logs` dict
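The `logs` dict contains the loss value and all the metrics at the end of a batch or epoch. For the model above, these are the loss and the mean absolute error (`mae`); at the end of an epoch with validation data, the validation counterparts (`val_loss`, `val_mae`) are included as well.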
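Keys such as `val_loss` only appear in `logs` when `fit()` receives validation data, so a callback that indexes `logs['val_loss']` directly raises a `KeyError` otherwise. A defensive version of the validation/training loss ratio can be sketched in plain Python; `val_train_ratio` is a hypothetical helper, and the dict values in the example are made up.

```python
def val_train_ratio(logs):
    """Hypothetical helper: val_loss / loss from an epoch-end logs dict, or None
    when either value is missing (e.g. fit() was called without validation data)."""
    logs = logs or {}
    loss, val_loss = logs.get('loss'), logs.get('val_loss')
    if loss is None or val_loss is None or loss == 0:
        return None
    return val_loss / loss

print(val_train_ratio({'loss': 0.5, 'mae': 0.6, 'val_loss': 0.25, 'val_mae': 0.4}))  # 0.5
print(val_train_ratio({'loss': 0.5, 'mae': 0.6}))  # None
```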

In [30]:
callback = tf.keras.callbacks.LambdaCallback(
    on_epoch_end = lambda epoch, logs:
    print('Epoch: {}, Val/Train loss ratio: {:.2f}'.format(epoch, logs['val_loss'] / logs['loss']))
)

model = get_model()
_ = model.fit(x_train, y_train, 
             validation_data=(x_test, y_test),
             batch_size=64,
             epochs=3,
             verbose=0,
             callbacks=[callback])

Epoch: 0, Val/Train loss ratio: 0.48
Epoch: 1, Val/Train loss ratio: 0.13
Epoch: 2, Val/Train loss ratio: 1.81


In [32]:
class DetectOverfittingCallback(tf.keras.callbacks.Callback):
    def __init__(self, threshold=0.7):
        super(DetectOverfittingCallback, self).__init__()
        self.threshold = threshold
        
    def on_epoch_end(self, epoch, logs=None):
        ratio = logs['val_loss'] / logs['loss']
        print('Epoch: {}, Val/Train loss ratio: {:.2f}'.format(epoch, ratio))
        
        if ratio > self.threshold:
            print('Stopping training...')
            self.model.stop_training = True
            
model = get_model()
_ = model.fit(x_train, y_train, 
             validation_data=(x_test, y_test),
             batch_size=64,
             epochs=3,
             verbose=0,
             callbacks=[DetectOverfittingCallback()])

Epoch: 0, Val/Train loss ratio: 1.41
Stopping training...
