Version Check

In [2]:
import tensorflow as tf
print("TensorFlow version: %s" % tf.__version__)

TensorFlow version: 2.11.0


Load MNIST dataset

In [3]:
mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


Build a ML model

In [4]:
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10)
])

logits (log-odds) scores for each class

In [5]:
pred = model(x_train[:1]).numpy()
pred

array([[-0.02304666,  0.28690127,  0.5209847 , -0.37468466, -0.12255435,
        -0.27093273,  0.01349288, -0.59458566, -0.12201694,  0.59636164]],
      dtype=float32)

converted logits to probabilities for each class

In [6]:
tf.nn.softmax(pred).numpy()

array([[0.09224217, 0.12575875, 0.1589276 , 0.06489558, 0.08350527,
        0.0719903 , 0.09567499, 0.05208508, 0.08355016, 0.17137013]],
      dtype=float32)

Define a loss function for training

In [10]:
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
loss_fn(y_train[:1], pred).numpy()

2.631224

Compiler settings

In [11]:
model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])

Train and Evaluate model

In [12]:
model.fit(x_train, y_train, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x1a19e687dc0>

In [13]:
model.evaluate(x_test, y_test, verbose=2)

313/313 - 0s - loss: 0.0810 - accuracy: 0.9754 - 418ms/epoch - 1ms/step


[0.08103282749652863, 0.9753999710083008]

Want model to return a probability?

In [14]:
probability_model = tf.keras.Sequential([
    model,
    tf.keras.layers.Softmax()
])
probability_model(x_test[:5])

<tf.Tensor: shape=(5, 10), dtype=float32, numpy=
array([[1.4945706e-06, 1.2007980e-07, 8.0937352e-05, 8.9258931e-05,
        7.0706746e-10, 2.0011934e-09, 9.7838829e-12, 9.9981290e-01,
        1.6361827e-07, 1.5126662e-05],
       [1.1956149e-08, 1.3406106e-03, 9.9865770e-01, 1.6346595e-06,
        1.7641273e-13, 9.0305939e-08, 3.5986297e-08, 1.3467435e-14,
        1.3184290e-08, 3.5104542e-13],
       [1.6279266e-06, 9.9767929e-01, 2.4012321e-04, 5.5782466e-06,
        1.9363788e-04, 8.5791635e-06, 1.6924376e-05, 1.6862213e-03,
        1.6683748e-04, 1.2062183e-06],
       [9.9977881e-01, 3.7237486e-09, 1.3385486e-04, 2.0349669e-08,
        2.0423356e-06, 2.7991592e-07, 5.1254417e-05, 1.2469779e-06,
        1.6224891e-07, 3.2423657e-05],
       [5.9129724e-07, 3.4108820e-09, 6.3673770e-06, 2.1865366e-08,
        9.9602258e-01, 2.0453236e-08, 1.5369412e-06, 5.9117599e-05,
        9.4570836e-08, 3.9095799e-03]], dtype=float32)>