In [1]:
import numpy as np
import tensorflow as tf

In [2]:
DATA_URL = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz'

path = tf.keras.utils.get_file('mnist.npz', DATA_URL)
with np.load(path) as data:
  train_examples = data['x_train']
  train_labels = data['y_train']
  test_examples = data['x_test']
  test_labels = data['y_test']

In [3]:
train_dataset = tf.data.Dataset.from_tensor_slices(({'image_input':train_examples}, {'label_output':train_labels}))
test_dataset = tf.data.Dataset.from_tensor_slices(({'image_input':test_examples}, {'label_output':test_labels}))

In [4]:
BATCH_SIZE = 64
SHUFFLE_BUFFER_SIZE = 100

train_dataset = train_dataset.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)
test_dataset = test_dataset.batch(BATCH_SIZE)

In [5]:
input = tf.keras.layers.Input(shape=(28, 28), dtype='float32', name='image_input')
flatten = tf.keras.layers.Flatten(name='mnist_flatten')(input)
dense = tf.keras.layers.Dense(units=128, activation='relu', name='mnist_dense')(flatten)
output = tf.keras.layers.Dense(units=10, activation='softmax', name='label_output')(dense)

In [6]:
model = tf.keras.models.Model(
    inputs={'image_input': input},
    outputs={'label_output':output},
    name='mnist_classification_model'
)

In [7]:
model.compile(optimizer=tf.keras.optimizers.RMSprop(),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['sparse_categorical_accuracy'])

In [8]:
model.fit(train_dataset, epochs=10)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x2964f339688>

In [9]:
model.evaluate(x={'image_input':test_examples}, y={'label_output':test_labels})



[2.09651517868042, 0.3646000027656555]

In [12]:
model.predict(x={'image_input':test_examples[0:10]})

{'label_output': array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]], dtype=float32)}

In [13]:
test_labels[0:10]

array([7, 2, 1, 0, 4, 1, 4, 9, 5, 9], dtype=uint8)