In [1]:
# import tensorflow

import tensorflow as tf


In [2]:
# Load MNIST: 28x28 grayscale digit images with integer class labels.
mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Scale pixel intensities from [0, 255] to [0, 1]. Cast to float32 first:
# dividing the raw integer arrays by 255.0 directly yields float64 arrays
# (twice the memory), while the model below computes in float32 anyway.
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# Cheap invariants so a silent change in the data surfaces here, not mid-training.
assert x_train.shape[1:] == x_test.shape[1:] == (28, 28)  # must match Flatten(input_shape=(28, 28))
assert len(x_train) == len(y_train) and len(x_test) == len(y_test)
assert 0.0 <= x_train.min() and x_train.max() <= 1.0

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [3]:
# Fix TensorFlow's global seed so weight initialization and the random ops
# created afterwards (e.g. dropout during training) repeat across
# Restart & Run All. NOTE: some GPU kernels are nondeterministic, so results
# may still vary slightly there.
SEED = 42
tf.random.set_seed(SEED)

# Minimal classifier: flatten each 28x28 image, one hidden ReLU layer with
# dropout, and a 10-unit output layer with no activation, i.e. it emits raw
# logits. The loss later in the notebook is built with from_logits=True to match.
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10)
])


In [4]:
# Forward pass of the still-untrained model on the first training image.
# The output is one row of 10 raw logits (one score per class), not probabilities.
logits_tensor = model(x_train[:1])
predictions = logits_tensor.numpy()
predictions


array([[ 0.35473633, -0.3974684 , -0.10580429,  0.22280826, -0.1429728 ,
        -0.25611702,  0.04287297,  0.11826724,  0.32941687, -0.19723085]],
      dtype=float32)

In [5]:
# Turn the logits into per-class probabilities (softmax over the class axis).
# For the untrained model they come out roughly uniform.
tf.nn.softmax(predictions, axis=-1).numpy()


array([[0.13885441, 0.06544573, 0.08760915, 0.12169257, 0.08441261,
        0.07538232, 0.10165275, 0.10961311, 0.13538283, 0.07995459]],
      dtype=float32)

In [7]:
# The model's last layer has no softmax, so the loss must apply it itself
# (from_logits=True). The "Sparse" variant takes integer class labels.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Sanity check: an untrained 10-class model should score near -ln(1/10) ~= 2.3.
initial_loss = loss_fn(y_train[:1], predictions)
initial_loss.numpy()


2.5851827

In [8]:
# Attach the optimizer, the logits-based loss defined above and an accuracy
# metric to the model before training.
model.compile(
    optimizer='adam',
    loss=loss_fn,
    metrics=['accuracy'],
)


In [9]:
# Train on the full training split for a fixed number of passes.
EPOCHS = 5
model.fit(x_train, y_train, epochs=EPOCHS)


Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x7fd1082b1190>

In [10]:
# Score the trained model on the held-out test split; returns [loss, accuracy].
model.evaluate(x=x_test, y=y_test, verbose=2)


313/313 - 1s - loss: 0.0789 - accuracy: 0.9760


[0.07885855436325073, 0.9760000109672546]

In [11]:
# The image classifier is now trained to ~98% accuracy on the MNIST test set

In [12]:
# To get probabilities instead of logits, wrap the trained model and append a
# Softmax layer. The same `model` object is reused, so no retraining is needed.
probability_model = tf.keras.Sequential()
probability_model.add(model)
probability_model.add(tf.keras.layers.Softmax())

first_five_test = x_test[:5]
probability_model(first_five_test)



<tf.Tensor: shape=(5, 10), dtype=float32, numpy=
array([[3.4741070e-08, 6.0987686e-09, 5.6300615e-07, 2.5027734e-06,
        5.5240066e-12, 3.6851853e-07, 3.3023322e-15, 9.9999547e-01,
        4.0085617e-08, 1.0496844e-06],
       [6.0007470e-09, 3.0383419e-06, 9.9999547e-01, 4.2321554e-07,
        6.7677431e-19, 4.8279384e-07, 1.6692929e-09, 1.6528078e-14,
        6.1499560e-07, 5.8353476e-16],
       [1.8827624e-06, 9.9807560e-01, 4.5856618e-04, 4.3149892e-05,
        4.7369598e-05, 2.6921118e-06, 8.8935267e-06, 1.1801813e-03,
        1.8078087e-04, 8.1443994e-07],
       [9.9999046e-01, 9.2139393e-13, 6.3243110e-06, 9.0108188e-10,
        1.1290470e-08, 1.3256655e-06, 8.2880327e-08, 1.7018909e-06,
        3.4030676e-09, 1.3513302e-07],
       [9.5448934e-07, 1.9862059e-10, 4.8043688e-07, 1.6648470e-07,
        9.9840838e-01, 5.2410746e-07, 5.1501883e-07, 3.5493169e-05,
        3.9406530e-07, 1.5531185e-03]], dtype=float32)>