# Early stopping example on MNIST data

First we load the MNIST data. We use the test split returned by load_data as our validation set.

In [1]:
from keras.datasets import mnist
(train_images, train_labels), (val_images, val_labels) = mnist.load_data()

Using TensorFlow backend.


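We preprocess the data as in Chapter 6.5: each 28 x 28 image is flattened into a vector of 784 values scaled to [0, 1], and the labels are one-hot encoded.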

In [2]:
from keras.utils import to_categorical

# Flatten each 28x28 image into a 784-vector and scale pixels to [0, 1]
train_images = train_images.reshape((60000, 28 * 28))
train_images = train_images.astype("float32") / 255
val_images = val_images.reshape((10000, 28 * 28))
val_images = val_images.astype("float32") / 255

# One-hot encode the digit labels (10 classes)
train_labels = to_categorical(train_labels)
val_labels = to_categorical(val_labels)
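
To make the preprocessing steps concrete, here is a minimal NumPy-only sketch on a made-up batch of two images. The names fake_images and fake_labels and their values are illustrative assumptions, and indexing into np.eye stands in for to_categorical with 10 classes.

```python
import numpy as np

# Two fake 28x28 "images" with all pixels at 255 (illustrative data, not MNIST)
fake_images = np.full((2, 28, 28), 255, dtype="uint8")
fake_labels = np.array([7, 2])

# Flatten each image into a 784-vector and scale pixels to [0, 1]
flat = fake_images.reshape((2, 28 * 28)).astype("float32") / 255

# One-hot encode: row i has a single 1 in column fake_labels[i]
one_hot = np.eye(10, dtype="float32")[fake_labels]

print(flat.shape)
print(one_hot)
```

Each row of one_hot has a single 1 at the position of the digit, which is the target format that the categorical_crossentropy loss expects.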

We set up the same model as in Chapter 6.5: two hidden ReLU layers with 128 and 32 units, followed by a 10-way softmax output layer.

In [3]:
from keras import layers
from keras import models
model = models.Sequential()
model.add(layers.Dense(128, activation="relu", input_dim=28 * 28))
model.add(layers.Dense(32, activation="relu"))
model.add(layers.Dense(10, activation="softmax"))
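
As a quick sanity check on the size of this network, the number of trainable parameters can be worked out by hand: a dense layer with n_in inputs and n_out units has n_in * n_out weights plus n_out biases. The helper dense_params below is a hypothetical name introduced for this sketch; calling model.summary() should report the same totals.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out

# Input size followed by the widths of the three Dense layers
layer_sizes = [28 * 28, 128, 32, 10]

counts = [dense_params(a, b) for a, b in zip(layer_sizes, layer_sizes[1:])]
total = sum(counts)
print(counts, total)  # [100480, 4128, 330] 104938
```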



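Now we initialise the callbacks required for early stopping. EarlyStopping halts training once the validation loss has gone three epochs without improving on its best value, and ModelCheckpoint saves the model to mymodel.h5 each time the validation loss reaches a new minimum.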

In [4]:
from keras import callbacks
callbacks_list = [
    callbacks.EarlyStopping(monitor="val_loss", patience=3),
    callbacks.ModelCheckpoint(filepath="mymodel.h5", monitor="val_loss",
        save_best_only=True)
]
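
The patience rule can be illustrated without Keras. The following is a simplified sketch of the stopping logic, not the library's implementation: it ignores options such as min_delta and baseline, and the function name and loss values are made up for illustration.

```python
def early_stopping_epoch(val_losses, patience=3):
    """Return the 1-based epoch at which training would stop, or None."""
    best = float("inf")
    wait = 0  # epochs since the last improvement
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None

# Hypothetical validation losses: the best value occurs at epoch 3
losses = [0.30, 0.20, 0.15, 0.16, 0.17, 0.18, 0.10]
stop = early_stopping_epoch(losses)
print(stop)  # 6
```

Note that training stops at epoch 6 even though epoch 7 would have been better. This is the trade-off that patience controls: a larger value tolerates longer plateaus at the cost of more training time.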

Now we compile the model and train it for up to 20 epochs with the early stopping callbacks.

In [5]:
model.compile(optimizer="adam",
              loss="categorical_crossentropy",
              metrics=["accuracy"])
hist_MNIST = model.fit(train_images, train_labels, epochs=20, batch_size=64,
                       callbacks=callbacks_list,
                       validation_data=(val_images, val_labels))

Train on 60000 samples, validate on 10000 samples
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
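
The object returned by fit records the per-epoch metrics in its history dictionary, so the epoch with the lowest validation loss can be read off directly. Below is a minimal sketch using a made-up val_loss list in place of hist_MNIST.history; the numbers are illustrative and are not the values from this run, and best_epoch is a hypothetical helper.

```python
def best_epoch(history):
    """Return (1-based epoch, value) of the lowest validation loss."""
    val_loss = history["val_loss"]
    idx = min(range(len(val_loss)), key=val_loss.__getitem__)
    return idx + 1, val_loss[idx]

# Illustrative values only: lowest loss at epoch 7, then three worse epochs
example_history = {
    "val_loss": [0.14, 0.11, 0.09, 0.085, 0.08, 0.082, 0.075, 0.078, 0.08, 0.083]
}

epoch, loss = best_epoch(example_history)
print(epoch, loss)  # 7 0.075
```

With the real run, the same function could be called as best_epoch(hist_MNIST.history).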


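Training stopped after epoch 10 of 20 because the validation loss had gone three epochs without improving. ModelCheckpoint saved the model with the lowest validation loss to mymodel.h5, so we can load this optimal model and use it for prediction as follows: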

In [6]:
from keras.models import load_model
final_model = load_model("mymodel.h5")
# Class probabilities for the first two validation images
final_model.predict(val_images[0:2, :])

array([[1.1208791e-05, 3.9576929e-07, 5.1955813e-05, 2.5260064e-04,
        1.3371984e-10, 3.1532405e-08, 8.8002704e-12, 9.9959677e-01,
        3.1356396e-07, 8.6693050e-05],
       [1.2812493e-08, 1.1812105e-06, 9.9999690e-01, 1.2398214e-06,
        7.5552899e-15, 2.3686269e-09, 7.5436145e-07, 6.9821888e-14,
        1.9390784e-08, 2.8188596e-14]], dtype=float32)
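
Each row of this output holds the predicted probabilities for the digits 0 to 9. To turn them into class labels, take the index of the largest entry in each row. The sketch below copies the array printed above into a variable named probs (a name introduced here for illustration) so that it runs on its own.

```python
import numpy as np

# Probabilities copied from the prediction output above
probs = np.array(
    [[1.1208791e-05, 3.9576929e-07, 5.1955813e-05, 2.5260064e-04,
      1.3371984e-10, 3.1532405e-08, 8.8002704e-12, 9.9959677e-01,
      3.1356396e-07, 8.6693050e-05],
     [1.2812493e-08, 1.1812105e-06, 9.9999690e-01, 1.2398214e-06,
      7.5552899e-15, 2.3686269e-09, 7.5436145e-07, 6.9821888e-14,
      1.9390784e-08, 2.8188596e-14]], dtype="float32")

# The predicted digit is the column with the highest probability
predicted_digits = probs.argmax(axis=1)
print(predicted_digits)  # [7 2]
```

The model assigns the first image to the digit 7 and the second to the digit 2, both with probability above 0.999.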