In [1]:
import pandas as pd
import numpy as np

# Show numpy arrays with 3 decimal places and without scientific notation,
# so the logits/probabilities printed later in the notebook stay compact.
np.set_printoptions(suppress=True, precision=3)

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.layers.experimental import preprocessing

In [2]:
# Record which TensorFlow release produced the outputs stored in this notebook.
tf_version = tf.__version__
print("TensorFlow version:", tf_version)

TensorFlow version: 2.6.0


In [3]:
# Number of train and test examples to keep so the demo trains in seconds.
# Set to None to use the full dataset (slicing with [:None] keeps every row).
N_SAMPLES = 500

mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, y_train = x_train[:N_SAMPLES], y_train[:N_SAMPLES]
x_test, y_test = x_test[:N_SAMPLES], y_test[:N_SAMPLES]

# Rescale pixel intensities from the 0-255 range to 0.0-1.0.
x_train, x_test = x_train / 255.0, x_test / 255.0

In [4]:
# Seed TensorFlow's global RNG so that weight initialisation, dropout masks and
# fit() shuffling repeat across fresh-kernel runs. Some GPU kernels may still
# introduce nondeterminism.
SEED = 42
tf.random.set_seed(SEED)

# Small fully connected classifier. Each 28x28 image is flattened, passed
# through one 128-unit ReLU layer with dropout (rate 0.2), and mapped to 10
# output units. The last Dense layer has no activation, so the model emits raw
# logits; the loss below is configured with from_logits=True to match.
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10)
])

In [5]:
# Raw logits (one score per output unit) from the untrained model. The slice
# [0:1] keeps the leading batch dimension, giving a batch of one image.
sample_batch = x_train[0:1]
predictions = model(sample_batch).numpy()
predictions

array([[ 0.26 , -0.372,  0.314, -0.31 ,  0.238,  0.527,  0.337,  0.247,
         0.101, -0.045]], dtype=float32)

In [6]:
# The model's final Dense(10) layer has no activation. from_logits=True tells
# the loss to treat its predictions as logits rather than probabilities.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True,
)

In [7]:
# Loss of the untrained model on the first training example. If the initial
# logits are close to uniform over 10 classes, this is roughly ln(10) ≈ 2.3.
initial_loss = loss_fn(y_train[:1], predictions)
initial_loss.numpy()

1.9409138

In [8]:
# Configure training: the Adam optimizer with default settings, the
# logits-aware loss defined above, and accuracy as a reported metric.
model.compile(
    optimizer='adam',
    loss=loss_fn,
    metrics=['accuracy'],
)

In [9]:
# Train for five epochs on the training subset. Keep the returned History
# object so its recorded loss/accuracy values can be inspected later.
history = model.fit(x=x_train, y=y_train, epochs=5)
history

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x2022f3cce80>

In [10]:
# Score the trained model on the held-out test subset; returns [loss, accuracy].
model.evaluate(x=x_test, y=y_test, verbose=2)

16/16 - 0s - loss: 0.6061 - accuracy: 0.8220


[0.6060577034950256, 0.8220000267028809]

In [11]:
# Wrap the logit-producing model with a Softmax layer so that calls return
# per-class probabilities instead of raw logits.
probability_model = tf.keras.Sequential()
probability_model.add(model)
probability_model.add(tf.keras.layers.Softmax())

In [12]:
# Class probabilities for the first five test images (one row per image,
# each row sums to ~1 because of the Softmax layer).
first_five = x_test[:5]
probability_model(first_five)

<tf.Tensor: shape=(5, 10), dtype=float32, numpy=
array([[0.001, 0.   , 0.002, 0.003, 0.002, 0.001, 0.   , 0.97 , 0.002,
        0.019],
       [0.032, 0.008, 0.837, 0.022, 0.   , 0.028, 0.063, 0.   , 0.009,
        0.   ],
       [0.008, 0.873, 0.019, 0.011, 0.003, 0.014, 0.015, 0.021, 0.025,
        0.011],
       [0.886, 0.   , 0.006, 0.001, 0.   , 0.008, 0.08 , 0.017, 0.001,
        0.   ],
       [0.006, 0.001, 0.028, 0.003, 0.849, 0.006, 0.008, 0.034, 0.006,
        0.059]], dtype=float32)>