## Load and prepare the dataset
You will use the MNIST dataset of handwritten digits to train and evaluate an image classifier.

In [6]:
import tensorflow as tf
print("TensorFlow version:", tf.__version__)

TensorFlow version: 2.7.0


## Load a dataset

In [8]:
mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

In [9]:
print(x_train.shape)
print(x_test.shape)

(60000, 28, 28)
(10000, 28, 28)
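To see what the division by `255.0` does, here is a minimal sketch on a hypothetical stand-in array (random `uint8` pixels rather than the real MNIST images, so it runs without a download). It shows the integer pixels being converted to floating point and rescaled to the range 0 to 1.

```python
import numpy as np

# Hypothetical stand-in for a small batch of MNIST images:
# uint8 pixel intensities in [0, 255], shape (batch, 28, 28).
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(4, 28, 28), dtype=np.uint8)

# Dividing by 255.0 converts the integers to floats and rescales them to [0, 1].
scaled = fake_images / 255.0

print(fake_images.dtype, fake_images.min(), fake_images.max())
print(scaled.dtype, scaled.min(), scaled.max())
```

The shape is unchanged; only the dtype and value range differ.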


## Build a model
Build a `tf.keras.Sequential` model by stacking layers.

In [4]:
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10)
])

For each example, the model returns a vector of logits (also called log-odds scores), one per class.
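To make the shapes concrete, the following NumPy sketch mimics the forward pass of the stacked layers. The weights are hypothetical random values (Keras uses its own initializers), and dropout is left out because it is only active during training. This is an illustration of the computation, not the Keras implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical weights with the same shapes as the layers in the model above.
w1 = rng.normal(scale=0.05, size=(784, 128))  # Dense(128) kernel
b1 = np.zeros(128)                            # Dense(128) bias
w2 = rng.normal(scale=0.05, size=(128, 10))   # Dense(10) kernel
b2 = np.zeros(10)                             # Dense(10) bias

def forward(images):
    """Flatten -> Dense(128, relu) -> Dense(10), with dropout skipped."""
    flat = images.reshape(len(images), -1)    # (batch, 28, 28) -> (batch, 784)
    hidden = np.maximum(flat @ w1 + b1, 0.0)  # ReLU activation
    return hidden @ w2 + b2                   # raw logits, one per class

batch = rng.random((1, 28, 28))               # one fake image with values in [0, 1)
logits = forward(batch)
n_params = w1.size + b1.size + w2.size + b2.size

print(logits.shape)  # (1, 10)
print(n_params)      # 101770
```

The parameter count follows from the layer shapes: 784 × 128 + 128 for the hidden layer plus 128 × 10 + 10 for the output layer.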

In [10]:
x_train[:1].shape

(1, 28, 28)

In [11]:
predictions = model(x_train[:1]).numpy()
predictions

array([[ 0.03825606, -0.38697302, -0.138106  , -0.5955628 , -0.08486367,
         0.82853395, -1.0889051 ,  0.16788127,  0.15458284,  0.37596893]],
      dtype=float32)

The `tf.nn.softmax` function converts these logits to probabilities for each class:

In [12]:
tf.nn.softmax(predictions).numpy()

array([[0.09902837, 0.06472692, 0.08301691, 0.05254067, 0.0875567 ,
        0.2182594 , 0.0320804 , 0.11273405, 0.1112448 , 0.13881183]],
      dtype=float32)
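As a cross-check, softmax can be written in a few lines of NumPy. The sketch below applies it to the logits printed above; the helper name `softmax` is introduced here for illustration only.

```python
import numpy as np

# Logits copied from the model output shown above.
logits = np.array([[0.03825606, -0.38697302, -0.138106, -0.5955628, -0.08486367,
                    0.82853395, -1.0889051, 0.16788127, 0.15458284, 0.37596893]])

def softmax(z):
    """Exponentiate and normalize along the last axis."""
    shifted = z - z.max(axis=-1, keepdims=True)  # shifting does not change the result
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

probs = softmax(logits)
print(probs.round(4))
print(probs.sum())
```

The largest logit (index 5) receives the largest probability, and the probabilities sum to 1.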

Note: It is possible to bake the `tf.nn.softmax` function into the activation function for the last layer of the network. While this can make the model output more directly interpretable, this approach is discouraged because it is impossible to provide an exact and numerically stable loss calculation for all models when using a softmax output.
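The floating-point issue behind this note can be sketched with NumPy. The logits below are hypothetical and deliberately extreme. Computing the loss from softmax probabilities loses the true-class probability to underflow, while computing it directly from the logits with the log-sum-exp trick stays exact. This illustrates the general problem and is not a description of Keras internals.

```python
import numpy as np

# Hypothetical extreme logits; the true class has a very low score.
logits = np.array([50.0, 0.0, -750.0])
true_class = 2

# Route 1: softmax first, then log. exp(-750) underflows to exactly 0.0,
# so the true-class probability is 0 and its negative log is infinite.
probs = np.exp(logits) / np.exp(logits).sum()
with np.errstate(divide="ignore"):
    loss_from_probs = -np.log(probs[true_class])

# Route 2: work with the logits directly using the log-sum-exp trick.
shifted = logits - logits.max()
log_sum_exp = logits.max() + np.log(np.exp(shifted).sum())
loss_from_logits = log_sum_exp - logits[true_class]

print(loss_from_probs)   # inf
print(loss_from_logits)  # 800.0
```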

Define a loss function for training using `tf.keras.losses.SparseCategoricalCrossentropy`, which takes a vector of logits and the index of the true class and returns a scalar loss for each example.

In [13]:
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
print(loss_fn)

<keras.losses.SparseCategoricalCrossentropy object at 0x7f71ca6a0d90>


This loss is equal to the negative log probability of the true class. It is zero when the model assigns probability 1 to the correct class.

This untrained model gives probabilities close to random (1/10 for each class), so the initial loss should be close to `-tf.math.log(1/10) ~= 2.3`.

In [14]:
loss_fn(y_train[:1], predictions).numpy()

1.522071
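The value above can be reproduced by hand from the logits printed earlier. The sketch assumes the true label of the first training image is 5. That label is not shown in this notebook; it is inferred here because the softmax probability at index 5 (about 0.218) has a negative log that matches the reported loss.

```python
import numpy as np

# Logits copied from the model output shown earlier.
logits = np.array([0.03825606, -0.38697302, -0.138106, -0.5955628, -0.08486367,
                   0.82853395, -1.0889051, 0.16788127, 0.15458284, 0.37596893])

# Assumption: the true class index is 5 (inferred from the loss, not printed above).
true_label = 5

# Log-softmax, then take the negative log probability of the true class.
log_probs = logits - np.log(np.exp(logits).sum())
loss = -log_probs[true_label]

# Loss of a model that assigns exactly 1/10 to every class.
uniform_loss = -np.log(1 / 10)

print(round(float(loss), 4))          # 1.5221
print(round(float(uniform_loss), 4))  # 2.3026
```

The untrained model happens to favor the correct class slightly for this example, which is why its loss falls below the uniform baseline of about 2.3.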

Before you start training, configure and compile the model using Keras `Model.compile`. Set the [`optimizer`](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers) class to `adam`, set the `loss` to the `loss_fn` function you defined earlier, and specify a metric to be evaluated for the model by setting the `metrics` parameter to `accuracy`.

In [15]:
model.compile(optimizer='adam',
              loss=loss_fn,
              metrics=['accuracy'])

## Train and evaluate your model

Use the `Model.fit` method to adjust your model parameters and minimize the loss: 

In [16]:
model.fit(x_train, y_train, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x7f71b0363e20>

The `Model.evaluate` method checks the model's performance, usually on a [validation set](https://developers.google.com/machine-learning/glossary#validation-set) or [test set](https://developers.google.com/machine-learning/glossary#test-set).

In [17]:
model.evaluate(x_test, y_test, verbose=2)

313/313 - 0s - loss: 0.0756 - accuracy: 0.9772 - 290ms/epoch - 925us/step


[0.07558929175138474, 0.9771999716758728]
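The two numbers returned are the loss and the accuracy. As a rough sketch of what they measure, the snippet below computes both from a tiny set of hypothetical logits and labels (four examples, three classes). The real evaluation runs over the full test set in batches.

```python
import numpy as np

# Hypothetical logits for four examples over three classes, with true labels.
logits = np.array([[2.0, 0.1, -1.0],
                   [0.3, 1.5, 0.2],
                   [0.0, 0.2, 3.0],
                   [1.2, 0.9, 0.1]])
labels = np.array([0, 1, 2, 1])

# Accuracy: fraction of examples whose highest-scoring class matches the label.
predicted = logits.argmax(axis=1)
accuracy = (predicted == labels).mean()

# Loss: mean negative log probability of the true class.
log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
mean_loss = -log_probs[np.arange(len(labels)), labels].mean()

print(predicted)  # [0 1 2 0]
print(accuracy)   # 0.75
print(mean_loss)
```

The last example is misclassified (predicted 0, labeled 1), so three of four predictions are correct.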

If you want your model to return probabilities, you can wrap the trained model and attach a softmax layer to it:

In [18]:
probability_model = tf.keras.Sequential([
  model,
  tf.keras.layers.Softmax()
])

In [19]:
probability_model(x_test[:5])

<tf.Tensor: shape=(5, 10), dtype=float32, numpy=
array([[3.3632867e-07, 3.8181387e-09, 1.2544197e-06, 2.4038793e-04,
        4.8075648e-11, 2.8259055e-06, 1.1488607e-12, 9.9974459e-01,
        3.7654541e-08, 1.0539260e-05],
       [7.8477413e-10, 5.9313735e-04, 9.9940193e-01, 4.7491822e-06,
        2.5922315e-15, 1.1993889e-07, 1.0662419e-09, 3.5056810e-13,
        2.6458146e-08, 6.4379980e-15],
       [1.1580319e-06, 9.9733633e-01, 1.3129889e-03, 1.4027314e-05,
        9.1410024e-05, 1.2198306e-05, 1.2240747e-05, 8.2507334e-04,
        3.9288460e-04, 1.6073678e-06],
       [9.9999130e-01, 1.1157100e-10, 2.4173846e-06, 1.3360243e-08,
        1.5622463e-06, 6.6809213e-07, 5.6674997e-07, 7.2940554e-07,
        1.8458863e-08, 2.7202209e-06],
       [1.5119151e-05, 6.0381189e-10, 1.2906025e-05, 4.4622172e-07,
        9.9885631e-01, 1.3051203e-06, 2.4189778e-05, 1.4033301e-04,
        3.6530454e-07, 9.4894954e-04]], dtype=float32)>
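To turn these probabilities into class predictions, take the index of the largest entry in each row. The sketch below does this with NumPy for the first two rows of the output above, copied in by hand.

```python
import numpy as np

# First two rows of the probability output shown above.
probs = np.array([
    [3.3632867e-07, 3.8181387e-09, 1.2544197e-06, 2.4038793e-04,
     4.8075648e-11, 2.8259055e-06, 1.1488607e-12, 9.9974459e-01,
     3.7654541e-08, 1.0539260e-05],
    [7.8477413e-10, 5.9313735e-04, 9.9940193e-01, 4.7491822e-06,
     2.5922315e-15, 1.1993889e-07, 1.0662419e-09, 3.5056810e-13,
     2.6458146e-08, 6.4379980e-15],
])

predicted_digits = probs.argmax(axis=1)  # most probable class per row
confidence = probs.max(axis=1)           # probability assigned to that class

print(predicted_digits)  # [7 2]
print(confidence)
```

Both predictions carry a probability above 0.999, which reflects how confident the trained model is on these examples.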