In [1]:
import tensorflow as tf

# Seed Python's `random`, NumPy and TensorFlow in one call so that weight
# initialisation, Dropout masks and the per-epoch shuffling done by `fit`
# are reproducible under Restart & Run All (API available since TF 2.7).
# NOTE: bit-exact results on GPU may additionally require
# tf.config.experimental.enable_op_determinism().
SEED = 42
tf.keras.utils.set_random_seed(SEED)

print(f'Tensorflow version: {tf.__version__}')

Tensorflow version: 2.8.0


In [2]:
mnist = tf.keras.datasets.mnist

# load_data() returns two (images, labels) pairs: train split, then test split.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Rescale pixel intensities from [0, 255] into [0.0, 1.0].
x_train = x_train / 255.0
x_test = x_test / 255.0

In [3]:
# Simple MLP classifier: flatten each 28x28 image, one hidden ReLU layer with
# dropout, then a 10-unit output layer with no activation (raw logits).
model = tf.keras.models.Sequential(
    [
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(units=128, activation='relu'),
        tf.keras.layers.Dropout(rate=0.2),
        tf.keras.layers.Dense(units=10),  # logits, one per class
    ]
)

In [4]:
# Forward pass of the (not yet trained) model on the first training image.
# Slicing [0:1] rather than indexing [0] keeps the leading batch dimension.
predictions = model(x_train[0:1]).numpy()
predictions

array([[-0.180335  ,  0.6092944 ,  0.10032096, -0.40905142,  0.6716069 ,
        -0.2799277 , -0.7782379 ,  0.8119817 , -0.26066893,  0.04133122]],
      dtype=float32)

In [5]:
# Turn the logits into per-class probabilities (softmax over the last axis).
tf.nn.softmax(logits=predictions, axis=-1).numpy()

array([[0.07147996, 0.15744032, 0.09463932, 0.05686618, 0.16756293,
        0.06470409, 0.03931139, 0.1928155 , 0.06596228, 0.08921805]],
      dtype=float32)

In [6]:
# Integer class labels -> "Sparse" cross-entropy; from_logits=True because the
# model's final Dense layer has no softmax and outputs raw logits.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True,
)

In [7]:
# Sanity check: with near-uniform predictions from an untrained model the loss
# should be in the neighbourhood of -ln(1/10) ~= 2.3.
loss_fn(y_true=y_train[0:1], y_pred=predictions).numpy()

2.7379308

In [8]:
# Configure training: the logits-based loss defined above, Adam with default
# hyper-parameters, and accuracy reported alongside the loss.
model.compile(
    loss=loss_fn,
    optimizer='adam',
    metrics=['accuracy'],
)

In [9]:
# Train for 5 full passes over the training set. Keep the returned History
# object so per-epoch loss/accuracy remain available in `history.history`
# (e.g. for plotting). Binding it also stops the cell from displaying an
# opaque `<keras.callbacks.History at 0x...>` repr.
history = model.fit(x_train, y_train, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x1d8020d8fa0>

In [10]:
# Score the trained model on the held-out test split; verbose=2 prints a
# single summary line instead of a progress bar.
model.evaluate(x=x_test, y=y_test, verbose=2)

313/313 - 1s - loss: 0.0702 - accuracy: 0.9773 - 682ms/epoch - 2ms/step


[0.0702342614531517, 0.9772999882698059]

In [11]:
# Wrap the trained logits model with a Softmax layer so that calling it
# returns class probabilities directly.
probability_model = tf.keras.models.Sequential([model, tf.keras.layers.Softmax()])

In [12]:
# Class probabilities for the first five test images (one row per image).
probability_model(x_test[0:5])

<tf.Tensor: shape=(5, 10), dtype=float32, numpy=
array([[7.63878859e-07, 2.51029810e-08, 8.29281416e-05, 7.51016487e-04,
        5.88687571e-11, 1.18414471e-06, 1.04319304e-13, 9.99008000e-01,
        1.35203015e-06, 1.54750756e-04],
       [2.66884115e-09, 3.13593409e-05, 9.99959469e-01, 6.95918834e-06,
        3.47565316e-15, 6.78201957e-07, 3.98407707e-09, 2.96731368e-12,
        1.56456565e-06, 3.24018725e-13],
       [9.19934450e-07, 9.96926725e-01, 1.28133222e-04, 1.35570690e-05,
        1.79752769e-05, 1.20809964e-05, 3.26703121e-05, 2.47043581e-03,
        3.93576047e-04, 3.92364745e-06],
       [9.99908447e-01, 1.83503227e-10, 1.54094105e-05, 4.19410867e-07,
        1.92666061e-08, 1.75807804e-06, 8.54424343e-06, 6.03223525e-05,
        3.80325460e-08, 4.93411881e-06],
       [1.09366274e-05, 2.78489978e-08, 4.68277794e-05, 9.38753089e-07,
        9.85092103e-01, 5.68624955e-07, 2.56449380e-06, 5.14418585e-04,
        2.20195252e-05, 1.43096428e-02]], dtype=float32)>