In [None]:
# Import TensorFlow
#from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPooling2D
from tensorflow import keras
print('TensorFlow imported')

# Load MNIST
mnist = tf.keras.datasets.mnist
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Rescale raw pixel values by 1/255 so every pixel lies in [0, 1].
X_train = X_train / 255.0
X_test = X_test / 255.0

# Carve the first 5000 training examples off as a validation set; the
# remainder stays as the training set. The validation slices are taken
# before the training arrays are rebound to their tails.
X_valid, y_valid = X_train[:5000], y_train[:5000]
X_train, y_train = X_train[5000:], y_train[5000:]

print('MNIST loaded, normalized and split')

from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import BatchNormalization, Flatten, Dense, Dropout
from tensorflow.keras import Sequential

# Make CNN model

def make_model(activation='elu',learning_rate=.001):
    """Build and compile a small convolutional classifier for 28x28x1 inputs.

    The network stacks three 3x3 Conv2D layers (32, 64 and 64 filters) with
    2x2 max pooling after the first two, flattens the feature maps, and
    finishes with a 64-unit Dense layer followed by a 10-unit Dense layer.
    The final layer has no activation, so the model outputs raw logits; the
    loss is created with ``from_logits=True`` to match.

    Args:
        activation: Activation applied in every hidden Conv2D/Dense layer.
        learning_rate: Learning rate handed to the Adam optimizer.

    Returns:
        The compiled ``Sequential`` model, reporting ``accuracy`` as a metric.
    """
    model = Sequential([
        Conv2D(32, (3, 3), activation=activation, input_shape=(28, 28, 1)),
        MaxPooling2D((2, 2)),
        Conv2D(64, (3, 3), activation=activation),
        MaxPooling2D((2, 2)),
        Conv2D(64, (3, 3), activation=activation),
        Flatten(),
        Dense(64, activation=activation, kernel_initializer='he_normal'),
        # No activation on the output layer: it emits logits.
        Dense(10),
    ])

    # Compile with an Adam optimizer using the requested learning rate.
    # Early stopping is not configured here; it is a fit-time callback that
    # the caller passes to model.fit().
    optimizer = keras.optimizers.Adam(learning_rate=learning_rate)

    model.compile(
        optimizer=optimizer,
        # from_logits=True because the last Dense layer applies no softmax.
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'],
    )

    return model

In [None]:
# Construct model, fit with callbacks

from tensorflow.keras.callbacks import EarlyStopping
from tensorflow import reshape


model = make_model()

# Early stopping watches the validation loss rather than a training metric:
# training accuracy keeps improving while the model overfits, so it cannot
# signal when generalization stops getting better. Training halts after 3
# epochs without improvement in val_loss, and restore_best_weights asks Keras
# to roll back to the weights from the best monitored epoch when that happens.
callback = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)

# Conv2D expects an explicit channel axis, so the (N, 28, 28) image arrays are
# reshaped to (N, 28, 28, 1) before being fed to the model.
model.fit(reshape(X_train, (len(X_train), 28, 28, 1)), y_train, epochs=20, callbacks=[callback],
          validation_data=(reshape(X_valid, (len(X_valid), 28, 28, 1)), y_valid))

In [None]:
# Evaluate the fitted model on the test images and labels.
# The test images get the same trailing channel axis the model was built for.
results = model.evaluate(
    x=reshape(X_test, (X_test.shape[0], 28, 28, 1)),
    y=y_test,
    batch_size=128,
)
print('test loss, test acc:', results)