## Week 2: Implementing Callbacks in TensorFlow using the MNIST Dataset

In [1]:
import tensorflow as tf
from tensorflow.keras import layers, losses

MNIST dataset
* 60,000 training and 10,000 test images of handwritten digits 0 to 9, each a 28x28 grayscale image

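Loaded with [tf.keras.datasets.mnist.load_data](https://www.tensorflow.org/api_docs/python/tf/keras/datasets/mnist/load_data), which returns `(x_train, y_train), (x_test, y_test)`; this notebook keeps only the training split.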

In [2]:
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
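`load_data` returns the training images as a `uint8` array of shape `(60000, 28, 28)` with pixel values from 0 to 255, and the labels as a `uint8` array of shape `(60000,)`. The model below rescales and flattens the images inside the network. As a rough NumPy sketch of what `Rescaling(1/255)` followed by `Flatten()` compute, using a small random batch as a stand-in for the real images (the helper name `preprocess` is hypothetical):

```python
import numpy as np

def preprocess(images: np.ndarray) -> np.ndarray:
    """Scale uint8 pixels to [0, 1] and flatten each 28x28 image to 784 values."""
    scaled = images.astype(np.float32) / 255.0   # roughly what Rescaling(1/255) does
    return scaled.reshape(len(images), -1)       # what Flatten() does

# Stand-in batch with MNIST's dtype and per-image shape (random pixels, not real digits).
rng = np.random.default_rng(0)
batch = rng.integers(0, 256, size=(4, 28, 28), dtype=np.uint8)
flat = preprocess(batch)
print(flat.shape, flat.dtype)  # (4, 784) float32
```

Keeping these steps inside the model as layers means the trained model accepts raw `uint8` images directly, with no separate preprocessing step to forget at inference time.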

In [3]:
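class myCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # Keras passes this epoch's metrics in `logs`; stop once accuracy exceeds 99%.
        logs = logs or {}  # avoid a mutable default argument
        accuracy = logs.get('accuracy')
        if accuracy is not None and accuracy > 0.99:
            print("\nReached 99% accuracy so cancelling training!")
            self.model.stop_training = True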
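Keras calls `on_epoch_end` after every epoch with a `logs` dict holding that epoch's metrics, and checks `model.stop_training` before starting the next epoch. The loop below is a framework-free sketch of that contract, not Keras's actual implementation; `FakeModel`, `ThresholdCallback`, `run_training`, and the accuracy values are all made up for illustration.

```python
class FakeModel:
    """Stand-in for a Keras model: only the stop_training flag matters here."""
    def __init__(self):
        self.stop_training = False


class ThresholdCallback:
    """Same rule as myCallback: stop once accuracy strictly exceeds a threshold."""
    def __init__(self, threshold=0.99):
        self.threshold = threshold
        self.model = None  # attached by the training loop, as Keras does

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        accuracy = logs.get("accuracy")
        if accuracy is not None and accuracy > self.threshold:
            self.model.stop_training = True


def run_training(epoch_accuracies, callback, epochs):
    """Simulate fit(): run up to `epochs` epochs and return how many actually ran."""
    model = FakeModel()
    callback.model = model
    ran = 0
    for epoch in range(epochs):
        ran += 1
        # Hypothetical per-epoch accuracy stands in for real training.
        callback.on_epoch_end(epoch, logs={"accuracy": epoch_accuracies[epoch]})
        if model.stop_training:
            break
    return ran


made_up = [0.95, 0.985, 0.992, 0.995]
print(run_training(made_up, ThresholdCallback(), epochs=4))  # 3
```

Because the comparison is strict, an epoch that lands exactly on 0.99 does not stop training; use `>=` if that boundary should count.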

In [4]:
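def train_mnist(x_train, y_train):
    callback = myCallback()

    model = tf.keras.Sequential([
        tf.keras.Input(shape=(28, 28)),
        layers.Rescaling(1/255),   # scale uint8 pixels from [0, 255] to [0, 1]
        layers.Flatten(),          # 28x28 image -> 784-value vector
        layers.Dense(512, activation='relu'),
        layers.Dense(10)])         # one raw score (logit) per digit

    # from_logits=True because the last layer has no softmax activation
    model.compile(optimizer='adam',
                  loss=losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])

    # Train for at most 10 epochs; the callback may stop training earlier
    history = model.fit(x_train, y_train, epochs=10, callbacks=[callback])

    return history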
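The network maps each flattened 784-value image through a 512-unit ReLU layer to 10 raw scores (logits), one per digit. There is no softmax at the end, which is why the loss is built with `from_logits=True`: it applies the softmax itself. Below is a NumPy sketch of that forward pass and of sparse categorical cross-entropy on logits, with randomly initialized weights. The names `forward` and `sparse_ce_from_logits` are hypothetical, and the simple normal initialization is an assumption, not Keras's default.

```python
import numpy as np

rng = np.random.default_rng(0)

# Random weights shaped like Dense(512) and Dense(10) applied to 784 inputs
# (assumed normal init, for illustration only).
W1 = rng.normal(0.0, 0.05, size=(784, 512))
b1 = np.zeros(512)
W2 = rng.normal(0.0, 0.05, size=(512, 10))
b2 = np.zeros(10)

def forward(x):
    """x: (batch, 784) floats in [0, 1]. Returns (batch, 10) logits."""
    hidden = np.maximum(x @ W1 + b1, 0.0)  # Dense(512, activation='relu')
    return hidden @ W2 + b2                # Dense(10): no activation, raw logits

def sparse_ce_from_logits(logits, labels):
    """Mean of -log softmax(logits)[label], computed stably via log-sum-exp."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

x = rng.random((8, 784))              # fake flattened images
y = rng.integers(0, 10, size=8)       # fake integer labels 0..9
logits = forward(x)
print(logits.shape, sparse_ce_from_logits(logits, y))
```

"Sparse" means the labels are plain integers (0 to 9) rather than one-hot vectors, which matches the format of `y_train`. As a sanity check, all-zero logits give every class probability 1/10, so the loss equals ln 10, about 2.303.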

In [5]:
history = train_mnist(x_train, y_train)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Reached 99% accuracy so cancelling training!
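`model.fit` returns a `History` object whose `history` attribute is a dict mapping each metric name (here `loss` and `accuracy`) to a list with one value per completed epoch, so the list length shows where the callback stopped training. A small helper operating on that dict follows; the name `summarize` is hypothetical, and the numbers in the usage line are placeholders, not results from the run above.

```python
def summarize(history_dict, metric="accuracy"):
    """Return (epochs completed, final value of `metric`) from History.history."""
    values = history_dict[metric]
    return len(values), values[-1]

# With the real run this would be: summarize(history.history)
placeholder = {"loss": [0.3, 0.1, 0.05], "accuracy": [0.91, 0.97, 0.992]}
print(summarize(placeholder))  # (3, 0.992)
```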
