In [12]:
# Import packages
import tensorflow as tf

from tensorflow.keras import backend as k

tf.__version__

'2.7.0'

# Callbacks

Callbacks are objects that Keras calls at different points during training (at the start of an epoch, at the end of a batch, at the end of an epoch, etc.). They can be used for tasks such as:

- Doing validation at different points during training (beyond the built-in per-epoch validation)
- Checkpointing the model at regular intervals or when it exceeds a certain accuracy threshold
- Changing the learning rate of the model when training seems to be plateauing
- Stopping training when validation loss starts increasing
- Doing fine-tuning of the top layers when training seems to be plateauing
- Sending email or instant message notifications when training ends or when a certain performance threshold is exceeded

## Model Early Stopping via Callbacks

In [13]:
early_stopping_callback = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    min_delta=1e-2,
    patience=2,
    verbose=1,
)
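
To make the roles of `min_delta` and `patience` concrete, here is a simplified pure-Python sketch of the stopping rule. It is an approximation for illustration, not the Keras source: the helper name `simulate_early_stopping` and the sample loss values are hypothetical, and details such as `baseline` and `restore_best_weights` are ignored.

```python
def simulate_early_stopping(val_losses, min_delta=1e-2, patience=2):
    """Return the 0-based epoch at which training would stop, or None.

    Simplified model of the rule: an epoch counts as an improvement only
    if the loss beats the best value seen so far by more than `min_delta`.
    """
    best = float("inf")
    wait = 0  # Number of consecutive epochs without improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Hypothetical validation losses: the drops after epoch 1 are smaller than min_delta
stop_epoch = simulate_early_stopping([1.0, 0.9, 0.895, 0.893, 0.8])
print(stop_epoch)  # 3
```

With these values, epochs 2 and 3 improve on the best loss by less than `min_delta`, so the patience of 2 runs out at epoch 3 and the later drop to 0.8 is never reached.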

## Model Checkpointing via Callbacks

In [14]:
# Define a callback to save the model while monitoring validation loss
# A save happens only when validation loss improves. Because the path
# contains {epoch}, each save goes to a new epoch-numbered path.
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath="tmp/mymodel_{epoch}", # Path to save model ({epoch} is filled in at save time)
    save_best_only=True,  # Save only if `val_loss` has improved.
    monitor="val_loss",
    verbose=1
)

## Custom Callback

You can create a custom callback by extending the base class tf.keras.callbacks.Callback. 

A callback has access to its associated model through the instance attribute self.model.

In [15]:
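# Save a dict of per-batch loss values during training, instead of the default per-epoch history
class LossHistoryBatch(tf.keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        self.per_batch_losses = dict()

    def on_epoch_begin(self, epoch, logs=None):
        self.current_epoch = epoch  # Batch indices restart at 0 every epoch

    def on_train_batch_end(self, batch, logs=None):
        # Key by (epoch, batch) so later epochs do not overwrite earlier entries
        self.per_batch_losses[(self.current_epoch, batch)] = (logs or {}).get("loss")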
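
To see a custom callback in use, the following minimal sketch passes one to `model.fit` through the `callbacks` argument. The `BatchLossRecorder` class, the tiny one-layer model, and the random data are all hypothetical and exist only to keep the example self-contained.

```python
import numpy as np
import tensorflow as tf


class BatchLossRecorder(tf.keras.callbacks.Callback):
    """Hypothetical callback that appends the logged loss after every batch."""

    def on_train_begin(self, logs=None):
        self.losses = []

    def on_train_batch_end(self, batch, logs=None):
        self.losses.append((logs or {}).get("loss"))


# Random toy data: 64 samples with 4 features each (illustration only)
x = np.random.rand(64, 4).astype("float32")
y = np.random.rand(64, 1).astype("float32")

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer="sgd", loss="mse")

recorder = BatchLossRecorder()
model.fit(x, y, batch_size=16, epochs=2, verbose=0, callbacks=[recorder])

# 64 samples / 16 per batch = 4 batches per epoch, over 2 epochs
print(len(recorder.losses))  # 8
```

Several callbacks can be combined by adding them to the same `callbacks` list.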

## Learning Rate Scheduler via Callbacks

A common pattern when training deep learning models is to gradually reduce the learning rate as training progresses. This is generally known as "learning rate decay".

The learning rate decay schedule can be static (fixed in advance, as a function of the current epoch or the current batch index), or dynamic (responding to the current behavior of the model, in particular the validation loss).

### Static

You can easily use a static learning rate decay schedule by passing a schedule object as the learning_rate argument in your optimizer.

In [16]:
# Set initial learning rate
initial_learning_rate = 0.01

# Define a scheduler
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate,
    decay_steps=100000,
    decay_rate=0.96,
    staircase=True
)

# Define optimizer with a learning rate scheduler
optimizer = tf.keras.optimizers.RMSprop(learning_rate=lr_schedule)
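
As a rough illustration of what this schedule computes, the sketch below reimplements the exponential decay formula in plain Python, using the same parameter values as the cell above. The function name `exponential_decay` is hypothetical; the formula is `initial_learning_rate * decay_rate ** (step / decay_steps)`, with the exponent rounded down when `staircase=True`.

```python
import math


def exponential_decay(step, initial_learning_rate=0.01, decay_steps=100000,
                      decay_rate=0.96, staircase=True):
    """Plain-Python sketch of the exponential decay formula."""
    exponent = step / decay_steps
    if staircase:
        # Staircase mode changes the rate only at multiples of decay_steps
        exponent = math.floor(exponent)
    return initial_learning_rate * decay_rate ** exponent


# Print the learning rate at a few optimizer steps
for step in (0, 99_999, 100_000, 250_000):
    print(step, exponential_decay(step))
```

With `staircase=True`, the rate stays at 0.01 for the first 100000 optimizer steps and then drops to about 0.0096. Note that a step here is one batch update, not one epoch.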

### Dynamic

A dynamic learning rate schedule (for instance, decreasing the learning rate when the validation loss is no longer improving) cannot be achieved with these schedule objects since the optimizer does not have access to validation metrics.

However, callbacks do have access to all metrics, including validation metrics! You can thus achieve this pattern by using a callback that modifies the current learning rate on the optimizer. In fact, this is even built-in as the ReduceLROnPlateau callback.

In [17]:
# A very simple dynamic learning rate scheduler using callbacks
# Note that this is just for the sake of example
class IncreaseLR(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        lr = k.get_value(self.model.optimizer.lr)
        new_lr = lr + 0.001 # Increase learning rate
        k.set_value(self.model.optimizer.lr, new_lr) # Set new learning rate
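
For comparison, the built-in `ReduceLROnPlateau` callback mentioned above can be configured as in the following sketch. The values chosen for `factor`, `patience`, and `min_lr` are arbitrary illustrations, not recommendations.

```python
import tensorflow as tf

# Halve the learning rate when `val_loss` has not improved for 2 epochs,
# but never go below 1e-5 (all values here are illustrative assumptions)
reduce_lr_callback = tf.keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss",
    factor=0.5,
    patience=2,
    min_lr=1e-5,
    verbose=1,
)
```

Like the other callbacks, it is passed to `model.fit` in the `callbacks` list, and validation data must be provided so that `val_loss` is available.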