In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds

In [2]:
# Load MNIST from TFDS.
#   shuffle_files=True -> the order of the underlying data files is shuffled
#   as_supervised=True -> elements are (image, label) tuples rather than
#                         {'image': img, 'label': label} dicts
#   with_info=True     -> also return the DatasetInfo (used later, e.g. for
#                         the training split size)
(ds_train, ds_test), ds_info = tfds.load(
    name='mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

In [3]:
# Building a training set pipeline

def normalize_img(image, label):
  """Cast an image to float32 and scale its pixel values by 1/255.

  Meant to be passed to `tf.data.Dataset.map` on a supervised dataset whose
  elements are `(image, label)` tuples.

  Args:
    image: Image tensor. Presumably `uint8` with values in [0, 255], as
      loaded from TFDS (TODO confirm if reused with another dataset).
    label: Class label; returned unchanged.

  Returns:
    A tuple `(tf.cast(image, tf.float32) / 255.0, label)`. For uint8 input
    the pixel values end up in [0.0, 1.0].
  """
  return tf.cast(image, tf.float32) / 255.0, label

# Training input pipeline, written as one chain so the stage order is explicit:
#   map      -> normalize_img (cast to float32, divide by 255); tf.data picks
#               the degree of parallelism (AUTOTUNE)
#   cache    -> keep the normalized examples after the first pass; placed
#               BEFORE shuffle so every epoch still gets a fresh shuffle
#   shuffle  -> buffer = whole train split for maximal randomness (lower it
#               if a large dataset cannot fit in memory)
#   batch    -> 128 examples per batch; because shuffling happens first,
#               batch composition differs from epoch to epoch
#   prefetch -> overlap input preparation with model execution; a good
#               final stage for any input pipeline
ds_train = (
    ds_train
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    .shuffle(buffer_size=ds_info.splits['train'].num_examples)
    .batch(batch_size=128)
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)

In [4]:
# Test input pipeline: the same normalize_img mapping as training, but with no
# shuffling. Here batching happens before cache(), so whole batches are cached;
# that is fine because no random transform follows the cache.
ds_test = (
    ds_test
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(128)
    .cache()
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)

In [5]:
# Create the model: a one-hidden-layer MLP that outputs raw logits.
# The input shape and number of classes come from the dataset metadata.
# TFDS MNIST images are (28, 28, 1); declaring the input as (28, 28)
# only worked because Keras adjusted the input rank at call time. An explicit
# Input also replaces the `input_shape=` layer kwarg, which Keras 3 deprecates.
# Flatten yields the same 784 features either way, so the weights are unchanged.
model = tf.keras.models.Sequential([
    tf.keras.Input(shape=ds_info.features['image'].shape),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    # No softmax here: the loss below is configured with from_logits=True.
    tf.keras.layers.Dense(ds_info.features['label'].num_classes),
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    # Integer class labels + logits -> sparse categorical cross-entropy.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

In [6]:
# Train for a fixed 10 epochs. The test pipeline is evaluated at the end of
# every epoch and reported as the val_* metrics.
model.fit(ds_train, validation_data=ds_test, epochs=10)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
 67/469 [===>..........................] - ETA: 1s - loss: 2.1783 - sparse_categorical_accuracy: 0.4781

KeyboardInterrupt: 

In [None]:
# Persist the trained model to the working directory. The ".keras" suffix
# selects Keras' native single-file format (recent TF/Keras releases).
model.save(filepath="model.keras")