In [1]:
import tensorflow as tf
print("TensorFlow version:", tf.__version__)

TensorFlow version: 2.12.1


## Load a dataset

Load and prepare the [MNIST dataset](http://yann.lecun.com/exdb/mnist/). The pixel values of the images range from 0 through 255. Scale these values to a range of 0 to 1 by dividing the values by `255.0`. This also converts the sample data from integers to floating-point numbers:

In [2]:
mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

## Build a machine learning model

In [3]:
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10)
])

In [4]:
y_train[:1]

array([5], dtype=uint8)

In [5]:
predictions = model(x_train[:1]).numpy()
predictions

array([[ 0.02042624, -0.31973466,  0.24808899, -0.5230112 , -0.57477915,
         0.5741872 , -0.45984206, -0.36943716,  0.5539082 , -0.22327459]],
      dtype=float32)

In [6]:
tf.nn.softmax(predictions).numpy()

array([[0.10391071, 0.07394867, 0.13047671, 0.06034599, 0.05730149,
        0.1807822 , 0.06428098, 0.07036307, 0.17715305, 0.08143712]],
      dtype=float32)

In [7]:
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

In [8]:
loss_fn(y_train[:1], predictions).numpy()

1.7104623

In [9]:
model.compile(optimizer='adam',
              loss=loss_fn,
              metrics=['accuracy'])

In [10]:
model.fit(x_train, y_train, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x1e1c4655990>

In [11]:
model.evaluate(x_test,  y_test, verbose=1)



[0.07359051704406738, 0.9769999980926514]

In [12]:
probability_model = tf.keras.Sequential([
  model,
  tf.keras.layers.Softmax()
])

In [13]:
probability_model(x_test[:5])

<tf.Tensor: shape=(5, 10), dtype=float32, numpy=
array([[8.1293012e-07, 2.2272904e-07, 2.4268515e-05, 3.0252698e-03,
        5.2854294e-11, 1.7376091e-05, 5.5397818e-14, 9.9692106e-01,
        6.1411461e-06, 4.9654795e-06],
       [1.3578257e-08, 1.3962031e-03, 9.9855298e-01, 2.2991642e-06,
        1.7025269e-15, 4.7827045e-05, 2.9793716e-08, 2.1571566e-11,
        7.5027606e-07, 1.0499384e-13],
       [5.1835389e-07, 9.9958223e-01, 3.3758701e-05, 3.2913606e-06,
        1.0406998e-05, 3.2530927e-06, 1.5944435e-06, 3.5138562e-04,
        1.3365309e-05, 2.8656291e-07],
       [9.9995780e-01, 1.9467366e-08, 7.4686559e-06, 6.6283985e-07,
        1.5662652e-06, 2.9255057e-07, 1.0034048e-05, 1.0841211e-05,
        5.3176699e-08, 1.1304189e-05],
       [1.8594164e-06, 4.8280118e-08, 1.0034024e-04, 2.4385317e-08,
        9.9887985e-01, 1.3232615e-06, 8.8608458e-06, 6.4652879e-05,
        2.5293562e-06, 9.4047457e-04]], dtype=float32)>