# Early stopping example on MNIST data

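First, we load the MNIST data. `mnist.load_data()` returns a training split of 60,000 images and a test split of 10,000 images. Here the test split serves as our validation set.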

```python
from keras.datasets import mnist

# Images are uint8 arrays of shape (n, 28, 28); labels are integer digits 0-9
(train_images, train_labels), (val_images, val_labels) = mnist.load_data()
```

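We preprocess the data as in Chapter 6.5. Each 28 × 28 image is flattened into a vector of 784 pixel values scaled to [0, 1], and the integer labels are one-hot encoded.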

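```python
from keras.utils import to_categorical

# Flatten each 28x28 image into a 784-vector and scale pixels from [0, 255] to [0, 1]
train_images = train_images.reshape((60000, 28 * 28))
train_images = train_images.astype("float32") / 255
val_images = val_images.reshape((10000, 28 * 28))
val_images = val_images.astype("float32") / 255

# One-hot encode the digit labels, e.g. 3 -> [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
train_labels = to_categorical(train_labels)
val_labels = to_categorical(val_labels)
```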
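To see what these steps do to the array shapes without downloading MNIST, here is a minimal NumPy sketch. It assumes a small batch of random uint8 images as a stand-in for the real data and uses a hypothetical `one_hot` helper that mimics what `to_categorical` produces for integer labels. It is an illustration, not Keras's implementation.

```python
import numpy as np

def one_hot(labels, num_classes=10):
    """Hypothetical stand-in for keras.utils.to_categorical on integer labels."""
    encoded = np.zeros((len(labels), num_classes), dtype="float32")
    encoded[np.arange(len(labels)), labels] = 1.0
    return encoded

# Random stand-in for five 28x28 uint8 MNIST images (not real data)
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(5, 28, 28)).astype(np.uint8)
labels = np.array([5, 0, 4, 1, 9])

# Same preprocessing as above: flatten, then scale to [0, 1]
flat = images.reshape((5, 28 * 28)).astype("float32") / 255
targets = one_hot(labels)

print(flat.shape)   # (5, 784)
print(targets[0])   # 1.0 at index 5, zeros elsewhere
```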

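We set up the same model as in Chapter 6.5: a fully connected network with two hidden ReLU layers (128 and 32 units) and a 10-unit softmax output layer.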

```python
from keras import layers
from keras import models

model = models.Sequential()
# Two hidden ReLU layers on the flattened 784-pixel input
model.add(layers.Dense(128, activation="relu", input_shape=(28 * 28,)))
model.add(layers.Dense(32, activation="relu"))
# Softmax output: one probability per digit class
model.add(layers.Dense(10, activation="softmax"))
```
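As a quick check on the size of this network: each `Dense` layer has one weight per input-output pair plus one bias per output unit. The small helper below (hypothetical, not part of Keras) does that arithmetic for the layer sizes used here; the total should match the parameter count `model.summary()` reports for this architecture.

```python
def dense_params(n_in, n_out):
    """Weights (n_in * n_out) plus biases (n_out) for a fully connected layer."""
    return n_in * n_out + n_out

# Layer sizes of the model above: 784 -> 128 -> 32 -> 10
sizes = [28 * 28, 128, 32, 10]
per_layer = [dense_params(a, b) for a, b in zip(sizes, sizes[1:])]

print(per_layer)       # [100480, 4128, 330]
print(sum(per_layer))  # 104938
```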

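Next, we initialise two callbacks. `EarlyStopping` halts training once the validation loss has not improved for `patience=3` epochs. `ModelCheckpoint` with `save_best_only=True` writes the model to `mymodel.h5` only when the validation loss improves, so the file always holds the best model seen so far.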

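```python
from keras import callbacks

callbacks_list = [
    # Stop training when val_loss has not improved for 3 epochs
    callbacks.EarlyStopping(monitor="val_loss", patience=3),
    # Save the full model whenever val_loss reaches a new minimum
    callbacks.ModelCheckpoint(filepath="mymodel.h5", monitor="val_loss",
                              save_best_only=True),
]
```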
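The stopping rule can be illustrated with a simplified pure-Python re-implementation. This is a sketch, not Keras's code: it assumes the default `min_delta=0`, counts an epoch as "no improvement" when its validation loss is not strictly below the best so far, and stops once that happens `patience` times in a row. The exact off-by-one handling of `patience` has differed between Keras versions, and the loss values below are made up for illustration.

```python
def simulate_early_stopping(val_losses, patience=3):
    """Return (epochs_run, best_epoch), both 1-based, under a simplified
    early-stopping rule: stop once `patience` consecutive epochs fail to
    improve on the best validation loss seen so far."""
    best_loss = float("inf")
    best_epoch = 0
    wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            # New best: the epoch a save_best_only checkpoint would keep
            best_loss, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return len(val_losses), best_epoch

# Hypothetical validation losses for a 20-epoch budget (not this notebook's results)
losses = [0.25, 0.15, 0.12, 0.10, 0.095, 0.091, 0.093, 0.092, 0.097, 0.099] + [0.1] * 10
print(simulate_early_stopping(losses, patience=3))  # (9, 6)
```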

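Now we compile the model and train it for at most 20 epochs, passing the callbacks to `fit()`. Both callbacks monitor `val_loss`, so `validation_data` must be supplied for them to work.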

```python
model.compile(optimizer="adam",
              loss="categorical_crossentropy",
              metrics=["accuracy"])
# Train for up to 20 epochs; EarlyStopping may end training sooner
hist_MNIST = model.fit(train_images, train_labels, epochs=20, batch_size=64,
                       callbacks=callbacks_list,
                       validation_data=(val_images, val_labels))
```

```
Train on 60000 samples, validate on 10000 samples
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
```
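`fit()` returns a `History` object whose `history` attribute is a dict mapping each logged metric name (such as `loss` and `val_loss`) to a list of per-epoch values. A minimal sketch for finding the epoch whose model the checkpoint kept, assuming `hist_MNIST.history` has that structure. The helper name is hypothetical, and the dict below uses made-up numbers so the snippet runs on its own.

```python
def best_epoch(history, monitor="val_loss"):
    """Return the 1-based epoch with the lowest value of `monitor`
    (the first such epoch if there is a tie)."""
    values = history[monitor]
    return min(range(len(values)), key=values.__getitem__) + 1

# Stand-in for hist_MNIST.history (made-up numbers, not this notebook's output)
history = {
    "loss": [0.28, 0.11, 0.08, 0.06],
    "val_loss": [0.14, 0.10, 0.09, 0.095],
}
print(best_epoch(history))  # 3

# In the notebook itself: best_epoch(hist_MNIST.history)
```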


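Training stopped after epoch 9 instead of running all 20 epochs. Because of `save_best_only=True`, `mymodel.h5` contains the model from the epoch with the lowest validation loss, which is not necessarily the last epoch trained. We load that model and use it for prediction: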

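```python
from keras.models import load_model

# Load the best checkpoint and predict class probabilities for the first two validation images
final_model = load_model("mymodel.h5")
final_model.predict(val_images[0:2, :])
```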

```
array([[  5.52271729e-07,   2.02607815e-08,   1.72516559e-06,
          1.85115478e-05,   7.65666003e-11,   9.02019863e-08,
          1.40622661e-11,   9.99976993e-01,   2.05270680e-07,
          1.96792871e-06],
       [  5.22331726e-11,   2.22933821e-07,   9.99997258e-01,
          1.74370825e-06,   9.01629845e-13,   4.81929358e-11,
          1.73911274e-10,   2.25338465e-13,   8.66929554e-07,
          3.24682874e-14]], dtype=float32)
```
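Each row of this output is a probability distribution over the ten digit classes, so the predicted digit is the index of the largest entry. A short NumPy sketch, with the two rows copied from the output above, shows the conversion. In the notebook you would apply `argmax(axis=1)` to the `predict()` result directly, and likewise to `val_labels[0:2]` to recover the true digits from the one-hot encoding.

```python
import numpy as np

# The two probability rows printed above, copied verbatim
probs = np.array([
    [5.52271729e-07, 2.02607815e-08, 1.72516559e-06, 1.85115478e-05, 7.65666003e-11,
     9.02019863e-08, 1.40622661e-11, 9.99976993e-01, 2.05270680e-07, 1.96792871e-06],
    [5.22331726e-11, 2.22933821e-07, 9.99997258e-01, 1.74370825e-06, 9.01629845e-13,
     4.81929358e-11, 1.73911274e-10, 2.25338465e-13, 8.66929554e-07, 3.24682874e-14],
], dtype="float32")

predicted_digits = probs.argmax(axis=1)  # index of the most probable class per row
confidence = probs.max(axis=1)           # probability assigned to that class

print(predicted_digits)  # [7 2]
```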