https://www.tensorflow.org/tutorials/quickstart/beginner

In [1]:
import tensorflow as tf

print(f'Tensorflow version: {tf.__version__}')

Tensorflow version: 2.9.1


In [2]:
mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [3]:
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10)
])

In [4]:
predictions = model(x_train[:1]).numpy()
predictions

array([[-0.46707538, -0.49244565,  0.5335856 ,  0.3134041 , -0.55763656,
        -0.42966643,  0.48809025,  0.7674017 , -0.1901529 , -0.13632442]],
      dtype=float32)

In [5]:
tf.nn.softmax(predictions).numpy()

array([[0.05689628, 0.05547095, 0.15476237, 0.12417717, 0.05197011,
        0.05906502, 0.14787917, 0.19552866, 0.07504984, 0.07920036]],
      dtype=float32)

In [6]:
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

In [7]:
loss_fn(y_train[:1], predictions).numpy()

2.8291163

In [8]:
model.compile(
    optimizer='adam',
    loss=loss_fn,
    metrics=['accuracy']
)

In [9]:
model.fit(x_train, y_train, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x2cb003808b0>

In [10]:
model.evaluate(x_test, y_test, verbose=2)

313/313 - 1s - loss: 0.0743 - accuracy: 0.9777 - 707ms/epoch - 2ms/step


[0.07434866577386856, 0.9776999950408936]

In [11]:
probability_model = tf.keras.Sequential([
    model,
    tf.keras.layers.Softmax()
])

In [12]:
probability_model(x_test[:5])

<tf.Tensor: shape=(5, 10), dtype=float32, numpy=
array([[2.2682291e-07, 7.1808968e-09, 5.0859227e-05, 6.2780036e-04,
        6.7803915e-11, 1.8951259e-07, 2.9498845e-13, 9.9930298e-01,
        1.8858077e-06, 1.5926222e-05],
       [2.7615972e-09, 2.3805413e-05, 9.9989498e-01, 6.8856185e-05,
        5.5937161e-14, 1.4153603e-06, 1.0667700e-05, 4.3762427e-15,
        2.6060516e-07, 1.9824943e-13],
       [1.4113740e-06, 9.9388701e-01, 3.4873844e-03, 9.3590796e-05,
        1.6835325e-04, 1.2866127e-05, 5.6692923e-04, 1.3449437e-03,
        4.2908464e-04, 8.4739631e-06],
       [9.9992204e-01, 5.9425267e-13, 3.0092930e-05, 6.5419968e-08,
        3.5373662e-07, 3.7676532e-06, 4.2773368e-05, 4.3369823e-07,
        1.7770191e-09, 5.0434107e-07],
       [5.1520015e-05, 1.7230876e-11, 6.9031790e-05, 5.7809774e-07,
        9.9378115e-01, 9.6877818e-07, 1.8961146e-05, 9.3176175e-05,
        6.8011054e-06, 5.9779030e-03]], dtype=float32)>