In [1]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import *

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0  # normalize to between 0-1

# model layers
xIn = Input((28, 28))
x = Flatten()(xIn)
x = Dense(256, activation='swish')(x)
x = Dropout(0.2)(x)
x = Dense(256, activation='swish')(x)
x = Dropout(0.2)(x)
x = Dense(256, activation='swish')(x)
x = Dropout(0.2)(x)
xOut = Dense(10)(x)

model = Model(inputs=xIn, outputs=xOut)

model.compile(
    optimizer=tf.keras.optimizers.Adam(3e-3),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])

callbacks = [
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=3, verbose=1)
]

model.summary()
model.fit(x_train, y_train, epochs=30, batch_size=1024, validation_data=(x_test, y_test), callbacks=callbacks)

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 28, 28)]          0         
                                                                 
 flatten (Flatten)           (None, 784)               0         
                                                                 
 dense (Dense)               (None, 256)               200960    
                                                                 
 dropout (Dropout)           (None, 256)               0         
                                                                 
 dense_1 (Dense)             (None, 256)               65792     
                                                                 
 dropout_1 (Dropout)         (None, 256)               0         
                                                                 
 dense_2 (Dense)             (None, 256)               65792 

<keras.callbacks.History at 0x22896498f10>

In [2]:
model.save("mymodel.h5")
model.evaluate(x_test, y_test)



[0.05432068556547165, 0.985200047492981]