# Callbacks with Keras and Writing your own Callbacks

Callbacks is the way to customize the behaviour of your model during either training or evaluation. You can use callbacks to get a view on internal states and statistics of the model during training. 

You can pass a list of callbacks (as the keyword argument callbacks) to the .fit() method of the Sequential or Model classes. The relevant methods of the callbacks will then be called at each stage of the training.

We have seen how to save a model, but we have seen it how to do it after the training. Suppose we want to save the model after each epoch. We can do it with a callback. Or save the best model. 

## 1. Imports and Configuration

In [1]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_datasets as tfds

# Configure GPU memory growth to be dynamic instead of allocating all memory at once
physical_devices = tf.config.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(physical_devices[0], True)

  from .autonotebook import tqdm as notebook_tqdm


## 2. Data Loading and Preprocessing

In [2]:
(ds_train, ds_test), ds_info = tfds.load(
    "mnist",
    split=["train", "test"],
    shuffle_files=True,
    as_supervised=True,  # will return tuple (img, label) otherwise dict
    with_info=True,  # able to get info about dataset
)

In [3]:
def normalize_img(image, label):
    """Normalizes images"""
    return tf.cast(image, tf.float32) / 255.0, label


AUTOTUNE = tf.data.experimental.AUTOTUNE
BATCH_SIZE = 128

# Setup for train dataset
ds_train = ds_train.map(normalize_img, num_parallel_calls=AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits["train"].num_examples)
ds_train = ds_train.batch(BATCH_SIZE)
ds_train = ds_train.prefetch(AUTOTUNE)

## 3. Model Definition

In [4]:
model = keras.Sequential(
    [
        keras.Input((28, 28, 1)),
        layers.Conv2D(32, 3, activation="relu"),
        layers.Flatten(),
        tf.keras.layers.Dense(10, activation="softmax"),
    ]
)

### Callback Functions

In [5]:
# Keras has many built-in callbacks, lookup write custom
# callbacks tensorflow documentation.
save_callback = keras.callbacks.ModelCheckpoint(
    "checkpoint_callback/", 
    save_weights_only=True, 
    monitor="train_acc", 
    save_best_only=False,
)

# This callback function is a learning reate scheduler
lr_scheduler = keras.callbacks.ReduceLROnPlateau(
    monitor="loss", 
    factor=0.1, 
    patience=3, 
    mode="max", 
    verbose=1
)

class OurOwnCallback(keras.callbacks.Callback):
    # you can also do on_batch_end, check documentation
    def on_epoch_end(self, epoch, logs=None):
        # printing logs will show us the keys we can use
        # print(logs.keys()) 
        if logs.get("accuracy") > 1:
            print("Accuracy over 70%, quitting training")
            self.model.stop_training = True

## 4. Compile Model

In [6]:
model.compile(
    optimizer=keras.optimizers.Adam(0.01),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=["accuracy"],
)

## 5. Model Training and Evaluation

In [7]:
model.fit(
    ds_train,
    epochs=10,
    callbacks=[save_callback, lr_scheduler, OurOwnCallback()],
    verbose=2,
)

Epoch 1/10
469/469 - 7s - loss: 0.1422 - accuracy: 0.9573
Epoch 2/10
469/469 - 1s - loss: 0.0572 - accuracy: 0.9824
Epoch 3/10
469/469 - 1s - loss: 0.0342 - accuracy: 0.9890
Epoch 4/10
469/469 - 1s - loss: 0.0245 - accuracy: 0.9923

Epoch 00004: ReduceLROnPlateau reducing learning rate to 0.0009999999776482583.
Epoch 5/10
469/469 - 1s - loss: 0.0087 - accuracy: 0.9973
Epoch 6/10
469/469 - 1s - loss: 0.0042 - accuracy: 0.9991
Epoch 7/10
469/469 - 1s - loss: 0.0028 - accuracy: 0.9995

Epoch 00007: ReduceLROnPlateau reducing learning rate to 9.999999310821295e-05.
Epoch 8/10
469/469 - 1s - loss: 0.0019 - accuracy: 0.9998
Epoch 9/10
469/469 - 1s - loss: 0.0018 - accuracy: 0.9998
Epoch 10/10
469/469 - 1s - loss: 0.0017 - accuracy: 0.9998

Epoch 00010: ReduceLROnPlateau reducing learning rate to 9.999999019782991e-06.


<tensorflow.python.keras.callbacks.History at 0x7fdc483cd8e0>