In [1]:
import tensorflow as tf

# Show which TensorFlow release the kernel is running so the recorded outputs
# below can be tied to a specific version.
print(f'Tensorflow version: {tf.__version__}')

Tensorflow version: 2.9.1


In [2]:
# Handle to the MNIST dataset loader that ships with Keras.
mnist = tf.keras.datasets.mnist

# load_data() returns a (train, test) pair of (images, labels) tuples; on first
# use it downloads mnist.npz over the network.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Divide both image arrays by 255.0, turning them into floating-point arrays.
# (MNIST pixel intensities are documented as 0-255, so this rescales to [0, 1].)
x_train = x_train / 255.0
x_test = x_test / 255.0

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [3]:
# Seed Python's `random`, NumPy and TensorFlow in a single call so that the
# weight initialisation below -- and the dropout masks / batch shuffling drawn
# later during training -- are repeatable on Restart & Run All. Re-running this
# cell therefore always rebuilds the model with the same initial weights.
# NOTE(review): bit-exact results on GPU may additionally require
# tf.config.experimental.enable_op_determinism() -- confirm if that matters.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# Small fully connected classifier for 28x28 inputs. The last layer has no
# activation, so the model outputs 10 raw scores (logits), one per class.
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),   # 28x28 image -> flat vector
    tf.keras.layers.Dense(128, activation='relu'),   # hidden layer
    tf.keras.layers.Dropout(0.2),                    # dropout rate 0.2 (active only in training)
    tf.keras.layers.Dense(10)                        # logits, no softmax here
])

In [4]:
# Forward pass of the not-yet-trained model on a batch containing only the
# first training image, then conversion of the resulting tensor to a NumPy
# array. The values are raw logits (one per output unit), not probabilities.
predictions = model(x_train[:1])
predictions = predictions.numpy()
predictions

array([[ 1.3522527 , -1.4696175 , -0.1081751 , -0.01228109,  1.0481275 ,
         0.3614598 , -0.7485703 , -0.7804753 , -0.21214724,  0.5028113 ]],
      dtype=float32)

In [5]:
# Turn the logits into a probability distribution by applying softmax along
# the last (class) axis; shown only for inspection, nothing is stored.
tf.nn.softmax(predictions, axis=-1).numpy()

array([[0.282972  , 0.0168353 , 0.06568826, 0.07229927, 0.20876783,
        0.10506246, 0.03462324, 0.03353601, 0.05920156, 0.12101403]],
      dtype=float32)

In [6]:
# Cross-entropy loss taking integer class labels. from_logits=True because the
# model's final Dense layer has no activation, so the loss must apply the
# softmax itself.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True,
)

In [7]:
# Loss of the untrained model on the first training example. With 10 classes
# and near-random logits this is expected to be in the vicinity of
# ln(10) ~= 2.3.
loss_fn(y_true=y_train[:1], y_pred=predictions).numpy()

2.2532003

In [8]:
# Configure training: Adam optimizer (Keras defaults), the logits-based
# cross-entropy defined earlier as the objective, and accuracy as the
# reported metric.
model.compile(loss=loss_fn, optimizer='adam', metrics=['accuracy'])

In [9]:
# Train for 5 full passes over the training set.
# The History object returned by fit() is bound to a name so that the
# per-epoch loss/accuracy (history.history) stays available for inspection or
# plotting, and so the cell no longer ends by echoing an uninformative,
# address-dependent `<keras.callbacks.History at 0x...>` repr.
history = model.fit(x_train, y_train, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x1ec1abffe80>

In [10]:
# Score the trained model on the held-out test split. The call returns
# [loss, accuracy] (the compiled loss followed by the compiled metric);
# verbose=2 prints one summary line instead of a progress bar.
model.evaluate(x=x_test, y=y_test, verbose=2)

313/313 - 1s - loss: 0.0730 - accuracy: 0.9784 - 768ms/epoch - 2ms/step


[0.07302378118038177, 0.9783999919891357]

In [11]:
# Wrap the trained logits model with a Softmax layer so that calling the
# wrapper yields class probabilities directly. The wrapped `model` is reused
# as-is (same weights), not copied.
probability_model = tf.keras.Sequential()
probability_model.add(model)
probability_model.add(tf.keras.layers.Softmax())

In [12]:
# Class probabilities for the first five test images: a (5, 10) tensor in
# which each row is a distribution over the 10 classes.
probability_model(x_test[0:5])

<tf.Tensor: shape=(5, 10), dtype=float32, numpy=
array([[2.2373893e-08, 1.6543147e-10, 4.2447300e-06, 5.8041938e-04,
        7.5099974e-13, 7.9136761e-08, 3.6053051e-15, 9.9941337e-01,
        1.6536326e-07, 1.5884860e-06],
       [1.1309726e-09, 4.7098551e-06, 9.9998617e-01, 8.9196819e-06,
        2.4602296e-15, 6.7720045e-08, 5.6122627e-08, 1.2818018e-13,
        4.5172584e-08, 5.3779453e-14],
       [7.3085671e-06, 9.8871952e-01, 8.0784055e-04, 4.0582556e-04,
        6.8975110e-05, 7.9310748e-06, 5.4202523e-05, 9.4082896e-03,
        5.1509682e-04, 5.0878975e-06],
       [9.9997938e-01, 3.8765854e-10, 1.1603885e-05, 5.1430572e-08,
        7.2331358e-07, 1.6405336e-07, 5.4908983e-06, 2.2538941e-06,
        8.7231994e-10, 3.2594957e-07],
       [1.1145790e-05, 1.2602449e-09, 4.7361341e-06, 3.7230113e-06,
        9.7747505e-01, 1.4591249e-06, 8.5509437e-06, 8.2289116e-05,
        3.8850262e-06, 2.2409104e-02]], dtype=float32)>