# Implementing Callbacks in TensorFlow using the MNIST Dataset

We will build a classifier for the MNIST dataset that trains until it reaches 98% training accuracy and stops as soon as that threshold is achieved.

In [1]:
import os
import tensorflow as tf

In [2]:
# Locate the local copy of the dataset under the current working directory
current_dir = os.getcwd()
data_path = os.path.join(current_dir, "data/mnist.npz")

# Load both splits, then keep only the training one (the test split is discarded)
train_split, _ = tf.keras.datasets.mnist.load_data(path=data_path)
training_images, training_labels = train_split

print(f"training_images is of type {type(training_images)}.\ntraining_labels is of type {type(training_labels)}\n")

# Report how many examples were loaded and the dimensions of each one
data_shape = training_images.shape

print(f"There are {data_shape[0]} examples with shape ({data_shape[1]}, {data_shape[2]})")

training_images is of type <class 'numpy.ndarray'>.
training_labels is of type <class 'numpy.ndarray'>

There are 60000 examples with shape (28, 28)


In [3]:
# Normalize pixel values to the [0, 1] range.
# The guard keeps this cell safe to re-run: once the data has been scaled its
# maximum is at most 1, so it is not divided by 255 a second time.
if training_images.max() > 1.0:
    training_images = training_images / 255.0

In [4]:
def create_and_compile_model():
    """Returns the compiled (but untrained) model.

    The network takes 28x28 inputs, flattens them, applies one hidden ReLU
    layer of 128 units and ends in a 10-unit softmax layer. It is compiled
    with the Adam optimizer, sparse categorical cross-entropy loss and the
    accuracy metric.

    Returns:
        tf.keras.Model: The model that will be trained to predict handwritten digits.
    """

    model = tf.keras.models.Sequential([
        # Declare the input explicitly rather than passing `input_shape` to
        # the first layer, which newer Keras versions warn against.
        tf.keras.Input(shape=(28, 28)),
        tf.keras.layers.Flatten(),                       # Flatten 28x28 input images
        tf.keras.layers.Dense(128, activation='relu'),   # Hidden layer
        tf.keras.layers.Dense(10, activation='softmax')  # Output layer for 10 classes
    ])

    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

    return model

In [5]:
# Build a fresh, untrained model so it can be inspected before any training
untrained_model = create_and_compile_model()

In [6]:
# Print a summary of the untrained model
untrained_model.summary()

In [7]:
# Run the untrained model on the first 5 training images as a smoke test
first_five_images = training_images[:5]
predictions = untrained_model.predict(first_five_images, verbose=False)

print(f"predictions have shape: {predictions.shape}")

predictions have shape: (5, 10)


In [8]:
class EarlyStoppingCallback(tf.keras.callbacks.Callback):
    """Stops training as soon as the reported accuracy reaches 98%."""

    def on_epoch_end(self, epoch, logs=None):
        """Halts training when the epoch's accuracy is at or above 0.98.

        Args:
            epoch: Index of the epoch that just finished (unused here).
            logs (dict, optional): Metric results for the epoch. May be None,
                and may lack the 'accuracy' key; in both cases nothing happens.
        """
        # `logs` defaults to None and 'accuracy' is only present when the
        # model tracks that metric, so guard both before comparing.
        accuracy = (logs or {}).get('accuracy')

        # Check if the accuracy is greater or equal to 0.98
        if accuracy is not None and accuracy >= 0.98:

            # Stop training once the above condition is met
            self.model.stop_training = True

            print("\nReached 98% accuracy so cancelling training!")

In [9]:
def train_mnist(training_images, training_labels):
    """Builds a new digit classifier and fits it on the given data.

    Training runs for at most 10 epochs with an EarlyStoppingCallback attached.

    Args:
        training_images (numpy.ndarray): The images of handwritten digits
        training_labels (numpy.ndarray): The labels of each image

    Returns:
        tf.keras.callbacks.History : The training history.
    """
    classifier = create_and_compile_model()
    stop_at_threshold = EarlyStoppingCallback()

    # Hand the History object returned by fit straight back to the caller
    return classifier.fit(
        training_images,
        training_labels,
        epochs=10,
        callbacks=[stop_at_threshold],
    )

In [10]:
# Train the classifier and keep the value returned by train_mnist
training_history = train_mnist(training_images, training_labels)

Epoch 1/10
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 5s 2ms/step - accuracy: 0.8796 - loss: 0.4303
Epoch 2/10
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 4s 2ms/step - accuracy: 0.9672 - loss: 0.1148
Epoch 3/10
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 4s 2ms/step - accuracy: 0.9765 - loss: 0.0804
Epoch 4/10
1868/1875 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - accuracy: 0.9827 - loss: 0.0570
Reached 98% accuracy so cancelling training!
1875/1875 ━━━━━━━━━━━━━━━━━━━━ 4s 2ms/step - accuracy: 0.9827 - loss: 0.0570
