In [3]:
import tensorflow as tf

# Fix TensorFlow's global seed before any random op runs, so weight
# initialisation and dropout masks repeat across Restart & Run All.
# NOTE(review): some GPU kernels may still be nondeterministic.
SEED = 42
tf.random.set_seed(SEED)

mnist = tf.keras.datasets.mnist  # MNIST handwritten-digit dataset

# load_data() fetches mnist.npz (the download is logged in this cell's output)
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Scale pixel values from the integer range 0-255 into floating-point [0, 1].
# Cast to float32 first: dividing the integer arrays by 255.0 directly would
# yield float64 arrays (twice the memory), while Keras computes in float32.
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# Sequential model: layers are stacked linearly, each taking exactly one input
# tensor and producing exactly one output tensor that feeds the next layer.
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),  # unroll each 28x28 image into a 784-element vector
  tf.keras.layers.Dense(128, activation='relu'),  # ReLU: passes positive inputs through unchanged, outputs 0 otherwise
  tf.keras.layers.Dropout(0.2),  # during training: zero 20% of inputs at random, scale the rest by 1/(1-0.2)
  tf.keras.layers.Dense(10)  # one raw score (logit) per class; no softmax applied here
])

# Loss function: sparse categorical cross-entropy.
# "Sparse" means the labels are integer class ids rather than one-hot vectors.
# from_logits=True makes the loss apply softmax itself, which matches the final
# Dense(10) layer having no activation.
# Per example, the loss is the negative log probability of the true class, so
# it is 0 when the model is certain of the correct class.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)



Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [5]:
# Compile: attach the optimiser, loss and metrics that training will use.
# Adam is a stochastic-gradient-descent variant that keeps a separate, adaptive
# learning rate for every parameter. Each rate is derived from running averages
# of the gradients (first moment) and of the squared gradients (second moment).
model.compile(optimizer='adam',
              loss=loss_fn,
              metrics=['accuracy'])

EPOCHS = 5  # number of full passes over x_train

# Keep the returned History object (per-epoch metrics live in history.history)
# instead of letting the cell echo its bare repr.
history = model.fit(x_train, y_train, epochs=EPOCHS)


Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x7fe3020c9220>

In [6]:
# Measure performance on the held-out test set. verbose=2 prints a single
# summary line. The result is [loss, accuracy], following the metrics given
# to compile().
model.evaluate(x=x_test, y=y_test, verbose=2)

313/313 - 0s - loss: 0.0710 - accuracy: 0.9799 - 332ms/epoch - 1ms/step


[0.07104834914207458, 0.9799000024795532]

In [8]:
# Wrap the trained model with a Softmax layer so that it returns class
# probabilities (each row sums to 1) instead of raw logits.
probability_model = tf.keras.Sequential()
probability_model.add(model)
probability_model.add(tf.keras.layers.Softmax())

# Class probabilities for the first five test images.
probability_model(x_test[:5])

<tf.Tensor: shape=(5, 10), dtype=float32, numpy=
array([[4.27526348e-10, 8.34430525e-10, 4.55912016e-07, 2.01845505e-05,
        4.69350210e-12, 2.13073008e-07, 1.91312278e-16, 9.99975681e-01,
        8.41720578e-08, 3.36002927e-06],
       [3.56374236e-11, 6.23037488e-09, 1.00000000e+00, 9.65691882e-09,
        6.07158155e-23, 2.57470556e-09, 1.94802551e-12, 7.33492379e-19,
        1.85159874e-10, 2.13866417e-15],
       [2.31159675e-10, 9.99843240e-01, 1.65007095e-05, 3.83161478e-07,
        5.25838959e-06, 1.15934881e-07, 3.37331386e-07, 6.16176403e-05,
        7.18960509e-05, 7.16414320e-07],
       [9.99984980e-01, 3.01509642e-12, 1.01412516e-05, 1.43225765e-10,
        6.55711631e-07, 1.30340595e-07, 3.50769824e-06, 2.36958371e-07,
        1.06400617e-10, 3.29658434e-07],
       [3.65661504e-06, 2.51789236e-12, 2.76277569e-06, 4.20638219e-10,
        9.88357842e-01, 8.96466901e-08, 1.85052585e-07, 1.08786764e-04,
        2.58277009e-08, 1.15266824e-02]], dtype=float32)>