# Callbacks

In [1]:
import tensorflow as tf

# Show which TensorFlow release this notebook runs on (the recorded output is 2.0.0).
print(tf.version.VERSION)

2.0.0


## `checkpoint_cb`: saving the model at the end of each epoch (with `save_best_only=True`, only when validation performance improves)

In [None]:
# Write 'my_keras_model.h5' only when the monitored validation metric improves,
# so the file always holds the best model seen so far (best valid performance).
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    'my_keras_model.h5', save_best_only=True)

history = model.fit(
    X_train,
    y_train,
    epochs=10,
    validation_data=(X_valid, y_valid),
    callbacks=[checkpoint_cb],
)

## `early_stopping_cb`: interrupt training when there is no more progress (combined with `checkpoint_cb`)

In [None]:
# Stop once the validation set has shown no progress for `patience` (10)
# consecutive epochs; restore_best_weights=True then rolls the model back
# to its best weights instead of keeping the last epoch's.
early_stopping_cb = tf.keras.callbacks.EarlyStopping(
    patience=10, restore_best_weights=True)

# Both callbacks together: checkpoint_cb (previous cell) keeps the best model
# on disk, early_stopping_cb ends the (up to 100-epoch) run early.
history = model.fit(
    X_train,
    y_train,
    epochs=100,
    validation_data=(X_valid, y_valid),
    callbacks=[checkpoint_cb, early_stopping_cb],
)
# combines both "checkpoint" and "early_stopping" callbacks above

## `custom`

A custom callback can implement any of the following methods. For training (called by `fit()`):
> - on_train_begin(), on_train_end()  
> - on_epoch_begin(), on_epoch_end()  
> - on_batch_begin(), on_batch_end()  

and for evaluation (called by `evaluate()`):
> - on_test_begin(), on_test_end()  
> - on_test_batch_begin(), on_test_batch_end()  

and for prediction (called by `predict()`):
> - on_predict_begin(), on_predict_end()  
> - on_predict_batch_begin(), on_predict_batch_end()  

In [None]:
## stopping training once a desired accuracy is reached

class myCallback(tf.keras.callbacks.Callback):
    """Stop training once a logged training metric exceeds a threshold.

    This callback does not save the model. It only sets
    ``self.model.stop_training`` at the end of an epoch.

    Args:
        threshold: value the metric must strictly exceed (default 0.97).
        monitor: key looked up in the epoch ``logs`` dict (default 'accuracy').
    """

    def __init__(self, threshold=0.97, monitor='accuracy'):
        super().__init__()
        self.threshold = threshold
        self.monitor = monitor

    def on_epoch_end(self, epoch, logs=None):
        # Avoid a shared mutable default; treat "no logs" as an empty dict.
        logs = logs if logs is not None else {}
        value = logs.get(self.monitor)
        # The key may be missing, e.g. when the model was compiled with
        # metrics=['acc'] rather than ['accuracy']. Comparing None with a
        # float would raise TypeError, so just keep training.
        if value is None:
            return
        if value > self.threshold:
            # With the defaults this prints "Reached 97.0% accuracy so cancelling training!"
            print("\nReached {:.1%} {} so cancelling training!".format(
                self.threshold, self.monitor))
            self.model.stop_training = True

# Train from generators; myCallback ends training early once its accuracy
# threshold is passed.
# NOTE(review): fit_generator is deprecated in later TF 2.x releases in favor of
# model.fit, which also accepts generators; it is fine on the 2.0.0 used here.
callbacks = myCallback()
history = model.fit_generator(
    train_generator,
    epochs=3,
    steps_per_epoch=100,
    validation_data=validation_generator,
    validation_steps=50,
    callbacks=[callbacks],
    verbose=2,
)

In [None]:
## printing results to console

class PrintValTrainRatioCallback(tf.keras.callbacks.Callback):
    """Print the ratio ``val_loss / loss`` at the end of every epoch.

    A ratio drifting well above 1 is commonly read as a sign of overfitting.
    """

    def on_epoch_end(self, epoch, logs=None):
        logs = logs if logs is not None else {}
        val_loss = logs.get('val_loss')
        loss = logs.get('loss')
        # Skip the epoch instead of crashing training when 'val_loss' is
        # absent (e.g. fit() without validation data) or loss is 0/missing,
        # which would otherwise raise KeyError / ZeroDivisionError.
        if val_loss is None or not loss:
            return
        print("\nval/train: {:.2f}".format(val_loss / loss))


## `tensorboard`

In [None]:
import os
root_logdir = os.path.join(os.curdir, "my_logs")


def get_run_logdir(root=None):
    """Return a timestamped log directory path for one TensorBoard run.

    The directory itself is not created here.

    Args:
        root: parent directory for the run folders. Defaults to the
            module-level ``root_logdir`` (``./my_logs``). It is looked up at
            call time, so reassigning ``root_logdir`` still takes effect.

    Returns:
        ``<root>/run_YYYY_MM_DD-HH_MM_SS``. The timestamp has one-second
        resolution, so two calls within the same second give the same path.
    """
    import time
    if root is None:
        root = root_logdir
    run_id = time.strftime("run_%Y_%m_%d-%H_%M_%S")
    return os.path.join(root, run_id)

# A fresh timestamped directory per run keeps each training run separate in TensorBoard.
run_logdir = get_run_logdir()
tensorboard_cb = tf.keras.callbacks.TensorBoard(run_logdir)  # CALLBACK

history = model.fit(
    X_train,
    y_train,
    epochs=100,
    validation_data=(X_valid, y_valid),
    callbacks=[tensorboard_cb],
)

## `custom tensorboard`: log your own scalars, histograms, images, text, audio, etc. with `tf.summary`

In [None]:
import numpy as np

import os
root_logdir = os.path.join(os.curdir, "my_logs")


def get_run_logdir(root=None):
    """Return a timestamped log directory path for one TensorBoard run.

    The directory itself is not created here.

    Args:
        root: parent directory for the run folders. Defaults to the
            module-level ``root_logdir`` (``./my_logs``). It is looked up at
            call time, so reassigning ``root_logdir`` still takes effect.

    Returns:
        ``<root>/run_YYYY_MM_DD-HH_MM_SS``. The timestamp has one-second
        resolution, so two calls within the same second give the same path.
    """
    import time
    if root is None:
        root = root_logdir
    run_id = time.strftime("run_%Y_%m_%d-%H_%M_%S")
    return os.path.join(root, run_id)

# Write hand-made summaries (scalar, histogram, images, text, audio) for
# steps 1..1000 into a fresh run directory for browsing in TensorBoard.
test_logdir = get_run_logdir()
writer = tf.summary.create_file_writer(test_logdir)
with writer.as_default():
    for step in range(1, 1000 + 1):
        # scalar: a sine curve over the steps
        tf.summary.scalar('my_scalar', np.sin(step / 10), step=step)

        # histogram: 100 normal samples whose center and spread grow with the step
        data = (np.random.randn(100) + 2) * step / 100  # some random data
        tf.summary.histogram("my_hist", data, buckets=50, step=step)

        # images: two random 32x32 RGB images. np.random.rand is uniform in
        # [0, 1), so the float pixels stay in the range image summaries expect.
        # With randn, about half the pixels were negative and were clipped to black.
        images = np.random.rand(2, 32, 32, 3)
        tf.summary.image('my_images', images * step / 1000, step=step)

        texts = ['The step is ' + str(step), "Its square is " + str(step ** 2)]
        tf.summary.text('my_text', texts, step=step)

        # audio: 12000 samples (0.25 s at 48 kHz) of a sine tone at `step` Hz
        sine_wave = tf.math.sin(tf.range(12000) / 48000 * 2 * np.pi * step)
        # shape [clips=1, frames, channels=1] for tf.summary.audio
        audio = tf.reshape(tf.cast(sine_wave, tf.float32), [1, -1, 1])
        tf.summary.audio('my_audio', audio, sample_rate=48000, step=step)

# Push any buffered events to disk so TensorBoard sees all 1000 steps right away.
writer.flush()

# Keras' TensorBoard callback logs this training run into the same directory
# as the hand-written summaries above.
tensorboard_cb = tf.keras.callbacks.TensorBoard(test_logdir)

# fit: presumably a two-input (A, B), two-output model, so the same targets
# are supplied once per output (2 outputs, 2 y's).
history = model.fit(
    (X_train_A, X_train_B),
    (y_train, y_train),
    epochs=100,
    validation_data=((X_valid_A, X_valid_B), (y_valid, y_valid)),
    callbacks=[checkpoint_cb, early_stopping_cb, tensorboard_cb],
)