<a href="https://colab.research.google.com/github/lilianabs/learn-tensorflow/blob/main/Quickstart.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import tensorflow as tf

In [4]:
# Load the data. `load_data()` returns uint8 image arrays with pixel values 0-255.
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Scale pixels to [0, 1]. Casting to float32 *before* dividing keeps the arrays
# in the dtype the model computes in; dividing the uint8 arrays by 255.0
# directly would produce float64 copies (twice the memory) that Keras then
# casts down to float32 anyway.
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

In [5]:
# Build a model.
# Seed Python, NumPy and TensorFlow RNGs so weight initialisation, dropout
# masks and training-data shuffling repeat across Restart & Run All
# (GPU kernels may still introduce small nondeterminism).
tf.keras.utils.set_random_seed(42)

# Flatten each 28x28 image, pass it through one hidden ReLU layer with
# dropout, and map it to 10 raw logits (one per digit class). There is
# deliberately no softmax here: the loss is built with from_logits=True.
# The input shape is declared with a leading `tf.keras.Input` rather than the
# deprecated `input_shape=` layer argument.
model = tf.keras.models.Sequential([
    tf.keras.Input(shape=(28, 28)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10),
])

In [6]:
# Sanity check: run the (still untrained) model on the first training image
# in inference mode and show the raw logits it produces.
predictions = model(x_train[:1], training=False).numpy()
predictions

array([[ 1.0110352 ,  0.05250525,  0.73398215, -0.60775477, -0.7671656 ,
        -0.51996857, -0.17065623,  0.33238178, -0.4194876 , -0.13145015]],
      dtype=float32)

In [7]:
# Turn the logits from the previous cell into per-class probabilities.
# Softmax is applied here, not in the model, so the loss can work on logits.
tf.nn.softmax(predictions, axis=-1).numpy()

array([[0.24407287, 0.09359126, 0.18501072, 0.04836019, 0.04123412,
        0.05279746, 0.07487166, 0.12381808, 0.05837829, 0.07786538]],
      dtype=float32)

In [8]:
# Integer class labels + raw logits from the final Dense layer, hence the
# sparse variant with from_logits=True.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True,
)

In [9]:
# Attach the loss defined above, the Adam optimizer, and accuracy reporting.
model.compile(
    loss=loss_fn,
    optimizer='adam',
    metrics=['accuracy'],
)

In [10]:
# Train for 5 epochs on the scaled training images.
model.fit(x=x_train, y=y_train, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x7f75c40fa390>

In [11]:
# Measure loss and accuracy on the held-out test split (one summary line).
model.evaluate(x=x_test, y=y_test, verbose=2)

313/313 - 1s - loss: 0.0718 - accuracy: 0.9782 - 1s/epoch - 3ms/step


[0.0718393549323082, 0.9782000184059143]

In [12]:
# Wrap the trained logits model with a Softmax layer so inference returns
# class probabilities directly; `model` itself keeps emitting logits.
probability_model = tf.keras.Sequential()
probability_model.add(model)
probability_model.add(tf.keras.layers.Softmax())

In [13]:
probability_model(x_test[:5], training=False)

<tf.Tensor: shape=(5, 10), dtype=float32, numpy=
array([[1.8579631e-07, 2.0908115e-09, 3.8815229e-05, 2.5147726e-04,
        1.8013353e-11, 1.5209774e-07, 3.0897005e-13, 9.9970859e-01,
        5.4679759e-07, 3.2911757e-07],
       [2.2345134e-08, 6.6283651e-06, 9.9999237e-01, 5.4482019e-08,
        2.2648666e-16, 7.4092748e-07, 2.8528993e-08, 1.1556785e-15,
        2.7503745e-07, 1.0378026e-15],
       [2.0560262e-06, 9.9969554e-01, 1.1803074e-04, 5.7704160e-06,
        2.6051230e-06, 5.1169100e-06, 4.9376737e-05, 5.7986970e-05,
        6.3285421e-05, 2.1984300e-07],
       [9.9816257e-01, 4.3286306e-08, 4.7812538e-04, 1.5252723e-06,
        4.1912699e-06, 1.8903367e-05, 9.7377437e-05, 6.5808586e-06,
        1.3339054e-06, 1.2293858e-03],
       [1.5196691e-06, 1.2953996e-09, 1.6949100e-05, 1.8868249e-08,
        9.9670804e-01, 1.6068937e-07, 3.4668003e-06, 1.1095670e-04,
        1.3229997e-06, 3.1575591e-03]], dtype=float32)>