In [16]:
# Parameters
# The batch size, number of training epochs and location of the data files are defined here.
# Data files are hosted in a Google Cloud Storage (GCS) bucket which is why their address starts with gs://.

# Training hyper-parameters: mini-batch size and number of passes over the data.
BATCH_SIZE = 128
EPOCHS = 10

# MNIST files in the public GCS bucket, grouped as (images, labels) pairs.
training_images_file, training_labels_file = (
    'gs://mnist-public/train-images-idx3-ubyte',
    'gs://mnist-public/train-labels-idx1-ubyte',
)
validation_images_file, validation_labels_file = (
    'gs://mnist-public/t10k-images-idx3-ubyte',
    'gs://mnist-public/t10k-labels-idx1-ubyte',
)

In [17]:
# Imports
# All the necessary Python libraries are imported here, including TensorFlow.

import os, re, math, json, shutil, pprint
import numpy as np
import tensorflow as tf
# Show which TensorFlow build this kernel is running.
print(f"Tensorflow version {tf.__version__}")

Tensorflow version 2.9.2


In [18]:
# Parse files and prepare training and validation datasets.

# Sentinel that lets tf.data tune buffer sizes / parallelism at runtime.
# tf.data.AUTOTUNE is the stable name (TF >= 2.4) for the older
# tf.data.experimental.AUTOTUNE alias; both refer to the same value.
AUTO = tf.data.AUTOTUNE

def read_label(tf_bytestring):
    """Decode one raw label record into a one-hot vector over the 10 digit classes.

    Args:
        tf_bytestring: scalar string tensor holding a single label record.
            The reshape to a scalar below requires it to contain exactly one byte.

    Returns:
        A one-hot tensor of length 10 with the position of the digit set to 1.
    """
    decoded = tf.io.decode_raw(tf_bytestring, tf.uint8)
    # Reshaping to [] turns the one-element vector into a scalar class index.
    class_index = tf.reshape(decoded, [])
    return tf.one_hot(class_index, 10)
  
def read_image(tf_bytestring):
    """Decode one raw image record into a flat float vector of 28*28 pixels.

    Args:
        tf_bytestring: scalar string tensor holding the bytes of a single image.

    Returns:
        A float32 tensor of shape [28*28] with every byte divided by 256.0.
    """
    pixels = tf.cast(tf.io.decode_raw(tf_bytestring, tf.uint8), tf.float32)
    scaled = pixels / 256.0  # uint8 values 0..255 land in [0, 1)
    return tf.reshape(scaled, [28*28])

In [19]:
# Apply this function to the dataset using .map and obtain a dataset of images.
# The labels are read and decoded in the same way, and .zip then pairs each image with its label.

def load_dataset(image_file, label_file):
    """Build a dataset of (image, one-hot label) pairs from two fixed-length record files.

    Args:
        image_file: path of the image file; read as 28*28-byte records after
            skipping a 16-byte header.
        label_file: path of the label file; read as 1-byte records after
            skipping an 8-byte header.

    Returns:
        A tf.data.Dataset whose elements are (read_image(record), read_label(record))
        tuples, paired positionally by Dataset.zip.
    """
    # AUTO lets tf.data pick the map parallelism for the current host instead of
    # a fixed 16 workers. map() preserves element order by default, so images and
    # labels stay aligned when zipped.
    imagedataset = tf.data.FixedLengthRecordDataset(image_file, 28*28, header_bytes=16)
    imagedataset = imagedataset.map(read_image, num_parallel_calls=AUTO)
    labelsdataset = tf.data.FixedLengthRecordDataset(label_file, 1, header_bytes=8)
    labelsdataset = labelsdataset.map(read_label, num_parallel_calls=AUTO)
    dataset = tf.data.Dataset.zip((imagedataset, labelsdataset))
    return dataset

In [20]:
# Training dataset
# The tf.data.Dataset API has all the necessary utility functions for preparing datasets.
# .cache caches the dataset in RAM
# .shuffle shuffles it with a buffer of 5000 elements
# .repeat loops the dataset
# .batch pulls multiple images and labels together into a mini-batch
# .prefetch can use the CPU to prepare the next batch while the current batch is being trained on the GPU
# The validation dataset is prepared in a similar way.

def get_training_dataset(image_file, label_file, batch_size):
    """Return the cached, shuffled, endlessly repeating, batched training dataset.

    Args:
        image_file: path of the training image file, forwarded to load_dataset.
        label_file: path of the training label file, forwarded to load_dataset.
        batch_size: number of (image, label) pairs per mini-batch.

    Returns:
        A tf.data.Dataset yielding fixed-size batches forever.
    """
    return (
        load_dataset(image_file, label_file)
        # Small enough to keep entirely in RAM; on TPU this matters for throughput.
        .cache()
        .shuffle(5000, reshuffle_each_iteration=True)
        # Mandatory for Keras for now.
        .repeat()
        # drop_remainder keeps the batch size fixed, which TPUs require.
        .batch(batch_size, drop_remainder=True)
        # Prepare the next batches while the current one is being trained on
        # (AUTO: autotuned prefetch buffer size).
        .prefetch(AUTO)
    )
  
def get_validation_dataset(image_file, label_file, batch_size=10000):
    """Return the cached, batched, endlessly repeating validation dataset.

    Args:
        image_file: path of the validation image file, forwarded to load_dataset.
        label_file: path of the validation label file, forwarded to load_dataset.
        batch_size: number of items per batch. The default of 10000 puts the
            whole evaluation set (10000 items) in a single batch, which is what
            a caller using validation_steps=1 relies on.

    Returns:
        A tf.data.Dataset yielding fixed-size validation batches forever.
    """
    dataset = load_dataset(image_file, label_file)
    dataset = dataset.cache() # this small dataset can be entirely cached in RAM, for TPU this is important to get good performance from such a small dataset
    dataset = dataset.batch(batch_size, drop_remainder=True) # fixed batch size; with the default, all eval items are in one batch
    dataset = dataset.repeat() # Mandatory for Keras for now
    return dataset

# instantiate the datasets
training_dataset = get_training_dataset(training_images_file, training_labels_file, BATCH_SIZE)
validation_dataset = get_validation_dataset(validation_images_file, validation_labels_file)

# For TPU, we will need a function that returns the dataset.
# Plain defs are used instead of name-bound lambdas so the callables carry
# a real name and a docstring; they are still zero-argument factories.
def training_input_fn():
    """Build and return a fresh training dataset from the configured files and BATCH_SIZE."""
    return get_training_dataset(training_images_file, training_labels_file, BATCH_SIZE)

def validation_input_fn():
    """Build and return a fresh validation dataset from the configured files."""
    return get_validation_dataset(validation_images_file, validation_labels_file)

In [21]:
# Keras Model
# The models will be straight sequences of layers so tf.keras.Sequential is used to create them. 
# There are 10 neurons because we are classifying handwritten digits into 10 classes.
# A Keras model also needs to know the shape of its inputs, tf.keras.layers.Input can be used to define it.
# Configuring the model is done in Keras using the model.compile function. 
# A classification model requires a cross-entropy loss function, called 'categorical_crossentropy' in Keras. 
# The model computes the 'accuracy' metric, which is the percentage of correctly classified images.

# Build the network layer by layer: a flat 28*28-pixel input, two hidden
# ReLU layers (200 and 60 units) and a 10-way softmax output.
model = tf.keras.Sequential()
model.add(tf.keras.layers.Input(shape=(28*28,)))
model.add(tf.keras.layers.Dense(200, activation='relu'))
model.add(tf.keras.layers.Dense(60, activation='relu'))
model.add(tf.keras.layers.Dense(10, activation='softmax'))

# Extra hidden layers are added to improve the recognition accuracy.
# Softmax on the last layer turns the 10 outputs into class probabilities.
# ReLU is used on the intermediate layers because its derivative is 1 for
# positive inputs (and 0 otherwise), so gradients are not scaled down as
# they pass through active units.

# In high-dimensional spaces a better optimizer than plain gradient descent
# is needed so that training does not get stuck at saddle points.
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

# print model layers
model.summary()

Model: "sequential_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense_2 (Dense)             (None, 200)               157000    
                                                                 
 dense_3 (Dense)             (None, 60)                12060     
                                                                 
 dense_4 (Dense)             (None, 10)                610       
                                                                 
Total params: 169,670
Trainable params: 169,670
Non-trainable params: 0
_________________________________________________________________


In [22]:
# Train and validate the model
# The training happens by calling model.fit and passing in both the training and validation datasets.

# One epoch covers the 60,000 training items; integer division means a
# partial final batch is not counted as a step.
steps_per_epoch = 60000 // BATCH_SIZE
print("Steps per epoch: ", steps_per_epoch)

# Train on the repeating training dataset and run one validation step per epoch.
history = model.fit(
    training_dataset,
    epochs=EPOCHS,
    steps_per_epoch=steps_per_epoch,
    validation_data=validation_dataset,
    validation_steps=1,
)

Steps per epoch:  468
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [23]:
# With a batch of 128 images the model's output tensor has shape [128, 10]: one row per image holding the 10 class probabilities. The predicted digit is the argmax across those 10 probabilities, hence axis=1.
# This model recognises approximately 97% of the digits.