# Training a neural network on MNIST with Keras

This simple example demonstrates how to plug TensorFlow Datasets (TFDS) into a Keras model.


Copyright 2020 The TensorFlow Datasets Authors, Licensed under the Apache License, Version 2.0


In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds



## Step 1: Create your input pipeline

Start by building an efficient input pipeline using advice from:
* The [Performance tips](https://www.tensorflow.org/datasets/performances) guide
* The [Better performance with the `tf.data` API](https://www.tensorflow.org/guide/data_performance#optimize_performance) guide


### Load a dataset

Load the MNIST dataset with the following arguments:

* `shuffle_files=True`: The MNIST data is only stored in a single file, but for larger datasets with multiple files on disk, it's good practice to shuffle them when training.
* `as_supervised=True`: Returns a tuple `(img, label)` instead of a dictionary `{'image': img, 'label': label}`.

In [2]:
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)



### Build a training pipeline

Apply the following transformations:

* `tf.data.Dataset.map`: TFDS provides images of type `tf.uint8`, while the model expects `tf.float32`. Therefore, you need to cast the images to `tf.float32` and scale them to the `[0, 1]` range.
* `tf.data.Dataset.cache`: Because the dataset fits in memory, cache it before shuffling for better performance.<br/>
__Note:__ Random transformations should be applied after caching.
* `tf.data.Dataset.shuffle`: For true randomness, set the shuffle buffer to the full dataset size.<br/>
__Note:__ For large datasets that can't fit in memory, use `buffer_size=1000` if your system allows it.
* `tf.data.Dataset.batch`: Batch elements of the dataset after shuffling to get unique batches at each epoch.
* `tf.data.Dataset.prefetch`: It is good practice to end the pipeline by prefetching [for performance](https://www.tensorflow.org/guide/data_performance#prefetching).
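
To see why the buffer size matters for `shuffle`, here is a simplified pure-Python model of a fixed-size shuffle buffer. It is a sketch for illustration only, not TensorFlow's implementation; the function name `buffered_shuffle` and its `seed` argument are hypothetical. It assumes `buffer_size >= 1`.

```python
import random


def buffered_shuffle(items, buffer_size, seed=0):
    """Simplified model of a fixed-size shuffle buffer (not TensorFlow's code)."""
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)  # still filling the buffer
            continue
        # Buffer is full: emit a random element and put the new item in its slot.
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    # Input exhausted: drain the remaining elements in random order.
    rng.shuffle(buffer)
    yield from buffer


data = list(range(1000))
small = list(buffered_shuffle(data, buffer_size=10))
full = list(buffered_shuffle(data, buffer_size=len(data)))

# With a small buffer, an element can only move forward by a few positions:
# the value at output position p always comes from the first p + buffer_size inputs.
print(max(value - pos for pos, value in enumerate(small)))  # at most 9
```

In this model, a buffer of size 1 returns the input order unchanged, while a buffer as large as the dataset shuffles all elements together, which is the case the training pipeline uses.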

In [3]:
def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(128)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)
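
As a quick sanity check on the batching step, the number of batches per epoch can be worked out by hand. The sketch below assumes the MNIST `train` split holds 60,000 examples and that `batch` keeps the final partial batch (its default behavior, `drop_remainder=False`); the variable names are illustrative.

```python
import math

num_train_examples = 60_000  # size of the MNIST 'train' split
batch_size = 128             # value passed to ds_train.batch(...)

# Number of batches Keras iterates over in one epoch.
steps_per_epoch = math.ceil(num_train_examples / batch_size)
# The last batch holds whatever is left after the full batches.
last_batch_size = num_train_examples - (steps_per_epoch - 1) * batch_size

print(steps_per_epoch)  # 469
print(last_batch_size)  # 96
```

The result matches the `469` steps per epoch shown in the training log in Step 2.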

### Build an evaluation pipeline

Your testing pipeline is similar to the training pipeline with small differences:

* You don't need to call `tf.data.Dataset.shuffle`.
* Cache after batching: without shuffling, the batches are identical from one epoch to the next, so they can be stored and reused.

In [4]:
ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)
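
The placement of `cache` relative to `shuffle` and `batch` can be illustrated with a toy pure-Python pipeline. This is only a sketch of the idea, not how `tf.data` is implemented; `make_batches` and `run_epochs` are hypothetical helpers introduced for this example.

```python
import random


def make_batches(items, batch_size):
    """Group a list into consecutive batches; the last one may be smaller."""
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]


def run_epochs(items, epochs, shuffle, cache, batch_size=4, seed=0):
    """Toy pipeline: optional shuffle -> batch -> optional cache.

    When `cache` is True, the batches produced in the first epoch are stored
    and replayed in later epochs, like caching at the end of the pipeline.
    """
    rng = random.Random(seed)
    cached = None
    history = []
    for _ in range(epochs):
        if cache and cached is not None:
            history.append(cached)  # replay the stored batches
            continue
        order = list(items)
        if shuffle:
            rng.shuffle(order)
        batched = make_batches(order, batch_size)
        if cache:
            cached = batched
        history.append(batched)
    return history


items = list(range(12))

# Evaluation-style pipeline: no shuffle, so caching the batches loses nothing.
eval_epochs = run_epochs(items, epochs=3, shuffle=False, cache=True)
print(eval_epochs[0] == eval_epochs[1] == eval_epochs[2])  # True

# Caching after a shuffle replays the first epoch's order in every epoch.
frozen = run_epochs(items, epochs=3, shuffle=True, cache=True)
print(frozen[0] == frozen[1] == frozen[2])  # True
```

The second case is why the training pipeline caches before shuffling, while the evaluation pipeline can safely cache the finished batches.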

## Step 2: Create and train the model

Plug the TFDS input pipeline into a simple Keras model, compile the model, and train it.

In [5]:
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28, 1)),  # TFDS MNIST images have shape (28, 28, 1)
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

model.fit(
    ds_train,
    epochs=6,
    validation_data=ds_test,
)

Epoch 1/6
  1/469 [..............................] - ETA: 18:08 - loss: 2.4104 - sparse_categorical_accuracy: 0.1250
 20/469 [>.............................] - ETA: 1s - loss: 1.5395 - sparse_categorical_accuracy: 0.6156
 41/469 [=>............................] - ETA: 1s - loss: 1.1002 - sparse_categorical_accuracy: 0.7237
 63/469 [===>..........................] - ETA: 1s - loss: 0.8878 - sparse_categorical_accuracy: 0.7748
 85/469 [====>.........................] - ETA: 0s - loss: 0.7715 - sparse_categorical_accuracy: 0.8010
107/469 [=====>........................] - ETA: 0s - loss: 0.6931 - sparse_categorical_accuracy: 0.8194

Epoch 2/6
  1/469 [..............................] - ETA: 30s - loss: 0.1981 - sparse_categorical_accuracy: 0.9297
 24/469 [>.............................] - ETA: 0s - loss: 0.1907 - sparse_categorical_accuracy: 0.9453
 47/469 [==>...........................] - ETA: 0s - loss: 0.1923 - sparse_categorical_accuracy: 0.9463
 70/469 [===>..........................] - ETA: 0s - loss: 0.1884 - sparse_categorical_accuracy: 0.9477
 93/469 [====>.........................] - ETA: 0s - loss: 0.1876 - sparse_categorical_accuracy: 0.9479

Epoch 3/6
  1/469 [..............................] - ETA: 28s - loss: 0.1427 - sparse_categorical_accuracy: 0.9531
 24/469 [>.............................] - ETA: 0s - loss: 0.1292 - sparse_categorical_accuracy: 0.9642
 47/469 [==>...........................] - ETA: 0s - loss: 0.1280 - sparse_categorical_accuracy: 0.9644
 70/469 [===>..........................] - ETA: 0s - loss: 0.1283 - sparse_categorical_accuracy: 0.9648
 93/469 [====>.........................] - ETA: 0s - loss: 0.1267 - sparse_categorical_accuracy: 0.9655

Epoch 4/6
  1/469 [..............................] - ETA: 28s - loss: 0.0815 - sparse_categorical_accuracy: 0.9844
 23/469 [>.............................] - ETA: 1s - loss: 0.0926 - sparse_categorical_accuracy: 0.9759
 46/469 [=>............................] - ETA: 0s - loss: 0.0921 - sparse_categorical_accuracy: 0.9754
 69/469 [===>..........................] - ETA: 0s - loss: 0.0923 - sparse_categorical_accuracy: 0.9741
 92/469 [====>.........................] - ETA: 0s - loss: 0.0940 - sparse_categorical_accuracy: 0.9742

Epoch 5/6
  1/469 [..............................] - ETA: 30s - loss: 0.0508 - sparse_categorical_accuracy: 0.9766
 24/469 [>.............................] - ETA: 0s - loss: 0.0732 - sparse_categorical_accuracy: 0.9811
 47/469 [==>...........................] - ETA: 0s - loss: 0.0756 - sparse_categorical_accuracy: 0.9806
 70/469 [===>..........................] - ETA: 0s - loss: 0.0757 - sparse_categorical_accuracy: 0.9794
 93/469 [====>.........................] - ETA: 0s - loss: 0.0769 - sparse_categorical_accuracy: 0.9785

Epoch 6/6
  1/469 [..............................] - ETA: 27s - loss: 0.0442 - sparse_categorical_accuracy: 0.9922
 24/469 [>.............................] - ETA: 0s - loss: 0.0582 - sparse_categorical_accuracy: 0.9840
 47/469 [==>...........................] - ETA: 0s - loss: 0.0622 - sparse_categorical_accuracy: 0.9815
 70/469 [===>..........................] - ETA: 0s - loss: 0.0641 - sparse_categorical_accuracy: 0.9817
 93/469 [====>.........................] - ETA: 0s - loss: 0.0647 - sparse_categorical_accuracy: 0.9826

<keras.callbacks.History at 0x7f9bc821b3d0>
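
Because the final `Dense(10)` layer has no activation, the model outputs raw logits, and the loss is configured with `from_logits=True` to apply the softmax itself. The pure-Python sketch below shows the arithmetic this implies for a single example. It is illustrative only, not the Keras implementation, and the helper names are hypothetical.

```python
import math


def softmax(logits):
    """Turn raw logits into probabilities that sum to 1."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sparse_categorical_crossentropy(logits, label):
    """Cross-entropy for one example, given raw logits and an integer label."""
    return -math.log(softmax(logits)[label])


# An untrained model that is equally unsure about all 10 digits.
uniform_logits = [0.0] * 10
print(round(sparse_categorical_crossentropy(uniform_logits, 3), 4))  # 2.3026

# The predicted class is the index of the largest logit (or probability).
logits = [0.1, 2.5, -1.0, 0.3, 0.0, 0.2, -0.4, 1.1, 0.6, -2.0]
probs = softmax(logits)
print(probs.index(max(probs)))  # 1
```

The uniform-guess loss of `ln(10) ≈ 2.3026` is roughly consistent with the first-step loss of `2.4104` in the training log above. To get class probabilities from the trained model, a softmax such as `tf.nn.softmax` would need to be applied to its outputs.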