In [40]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
mnist = tf.keras.datasets.mnist

In [41]:
# Fetch (or read from the local Keras cache) the MNIST train/test split.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Scale pixel intensities from [0, 255] to [0, 1].
# Cast to float32 *before* dividing: uint8 / Python-float promotes to float64,
# doubling memory for no gain because Keras computes in float32 anyway.
x_train = x_train.astype(np.float32) / 255.0
x_test = x_test.astype(np.float32) / 255.0

# Sanity-check the invariants later cells rely on: 28x28 images (the model's
# input_shape) and the 60000/10000 train/test sizes.
assert x_train.shape == (60000, 28, 28), x_train.shape
assert x_test.shape == (10000, 28, 28), x_test.shape
assert 0.0 <= x_train.min() and x_train.max() <= 1.0

In [42]:
# A small MLP classifier, built layer by layer:
#   - flatten each 28x28 image into a 784-long vector,
#   - apply one 128-unit ReLU hidden layer followed by Dropout(0.2) for
#     regularisation,
#   - finish with a 10-unit output layer.
# The output layer has no activation, so the model emits raw logits; the loss
# is therefore constructed with from_logits=True.
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))
model.add(tf.keras.layers.Dense(128, activation='relu'))
model.add(tf.keras.layers.Dropout(0.2))
model.add(tf.keras.layers.Dense(10))

In [43]:
# One inference-mode forward pass of the still-untrained network.
# Slicing with [:1] rather than indexing with [0] keeps the leading batch axis.
# The final Dense layer has no activation, so these values are logits.
predictions = model(x_train[:1], training=False).numpy()
predictions

array([[-0.20031293,  0.4417467 ,  0.2571655 , -0.04770537,  0.60568327,
        -0.33207938,  0.15622663, -0.10232562, -0.134523  ,  1.0202215 ]],
      dtype=float32)

In [44]:
# Turn the logits into per-class probabilities (normalised over the last axis)
# purely for inspection; an untrained model spreads them fairly evenly.
tf.nn.softmax(predictions, axis=-1).numpy()

array([[0.06349552, 0.12066632, 0.1003283 , 0.07396388, 0.14216177,
        0.05565672, 0.09069561, 0.07003231, 0.06781337, 0.21518622]],
      dtype=float32)

In [45]:
# Cross-entropy against integer class labels. from_logits=True because the
# model ends in a plain Dense layer. Letting the loss apply softmax internally
# is more numerically stable than putting a softmax inside the model.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True,
)

In [46]:
# Loss of the untrained model on the first training example. With near-uniform
# predictions over 10 classes this should be on the order of -ln(1/10) ≈ 2.3.
loss_fn(y_true=y_train[:1], y_pred=predictions).numpy()

2.8885522

In [47]:
# Configure training: the logits-aware loss above, Adam with default settings,
# and accuracy reported alongside the loss.
model.compile(
    loss=loss_fn,
    optimizer='adam',
    metrics=['accuracy'],
)

In [48]:
# Number of full passes over the training set.
EPOCHS = 5

# Train the model. Bind the returned History object (its `.history` dict holds
# per-epoch loss/accuracy) rather than leaving its uninformative
# `<History at 0x...>` repr as the cell output. This keeps the learning curve
# available for inspection or plotting; the progress bar still reports live.
history = model.fit(x_train, y_train, epochs=EPOCHS)

Train on 60000 samples
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x18170046ac8>

In [49]:
# Held-out performance. The result is [loss, accuracy], in the order configured
# in compile(); verbose=2 prints a single summary line instead of a progress bar.
model.evaluate(x=x_test, y=y_test, verbose=2)

10000/10000 - 0s - loss: 0.0736 - accuracy: 0.9766


[0.07362722312160767, 0.9766]

In [50]:
# Wrap the trained logit model with a Softmax layer so its output is class
# probabilities. The same `model` object is reused (not copied), so the
# wrapper shares its trained weights.
probability_model = tf.keras.Sequential()
probability_model.add(model)
probability_model.add(tf.keras.layers.Softmax())

In [51]:
# Class probabilities for the first five test images, one row per image.
# Each row sums to 1, and the row's arg-max is the predicted digit.
probability_model(x_test[:5], training=False)

<tf.Tensor: shape=(5, 10), dtype=float32, numpy=
array([[6.28850749e-09, 5.48203372e-10, 2.10331996e-06, 3.67189750e-05,
        1.21331235e-11, 2.99903782e-08, 3.70090548e-14, 9.99956489e-01,
        2.71196569e-07, 4.37804283e-06],
       [3.02143222e-09, 2.43908630e-07, 9.99999762e-01, 5.21061168e-08,
        1.76105646e-15, 2.29124062e-08, 3.45586448e-10, 6.61151440e-15,
        1.01501128e-08, 3.82422962e-13],
       [1.96367864e-07, 9.99158621e-01, 3.64593616e-05, 3.77283118e-06,
        7.18693627e-05, 1.87530975e-06, 1.03297480e-05, 4.68929444e-04,
        2.47391989e-04, 4.80209053e-07],
       [9.99552190e-01, 6.93497343e-11, 1.70992527e-04, 2.34924418e-08,
        1.64179966e-07, 1.22462882e-06, 2.10678309e-06, 7.07164727e-05,
        8.13159104e-08, 2.02554424e-04],
       [3.43466927e-05, 4.49118936e-11, 8.22148195e-06, 6.68538149e-08,
        9.51210439e-01, 5.99645659e-07, 7.53421818e-06, 5.57809253e-05,
        4.38885509e-05, 4.86391671e-02]], dtype=float32)>