[TensorFlow Datasets (TFDS)](https://tensorflow.org/datasets) makes it very easy to download common datasets, from small ones such as MNIST or Fashion MNIST to large ones such as ImageNet.

TFDS includes image, text, audio, and video datasets.

All available datasets and their descriptions are listed [here](https://homl.info/tfds).

<br>

- First, install tensorflow-datasets (`pip install tensorflow-datasets`).


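- Then call `tfds.load()` to download the dataset you want (unless it has already been downloaded). The data is returned as a dict whose contents depend on the chosen dataset; there is usually one split for training and one for testing.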


In [1]:
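import os

# Select the GPU before TensorFlow is imported, so it applies before any device is initialized
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import numpy as np
import tensorflow as tf
from tensorflow import keras
import tensorflow_datasets as tfds

# Seed NumPy and TensorFlow for reproducible results
np.random.seed(42)
tf.random.set_seed(42)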

In [2]:
print("TensorFlow version:", tf.__version__)
print("Keras version:", keras.__version__)
print("TensorFlow-Datasets version:", tfds.__version__)

TensorFlow version: 2.3.0
Keras version: 2.4.0
TensorFlow-Datasets version: 3.1.0


In [3]:
# Download MNIST
dataset = tfds.load(name='mnist')

Downloading and preparing dataset mnist/3.0.1 (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /home/libing/tensorflow_datasets/mnist/3.0.1...


local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead pass
`try_gcs=True` to `tfds.load` or set `data_dir=gs://tfds-data/datasets`.






Dataset mnist downloaded and prepared to /home/libing/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.


In [4]:
dataset

{'test': <PrefetchDataset shapes: {image: (28, 28, 1), label: ()}, types: {image: tf.uint8, label: tf.int64}>,
 'train': <PrefetchDataset shapes: {image: (28, 28, 1), label: ()}, types: {image: tf.uint8, label: tf.int64}>}

In [5]:
isinstance(dataset, dict)

True

In [6]:
mnist_train, mnist_test = dataset['train'], dataset['test']

In [7]:
mnist_train

<PrefetchDataset shapes: {image: (28, 28, 1), label: ()}, types: {image: tf.uint8, label: tf.int64}>

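**Each item in the dataset is a dict containing the features (`image`) and the label (`label`).** Keras, however, expects each item to be a tuple of two elements, `(features, label)`. You can convert the items with `map`: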

In [9]:
# Repeat the data 5 times, batch it, turn each dict into an (image, label) tuple, then prefetch
mnist_train = mnist_train.repeat(5).batch(32)
mnist_train = mnist_train.map(lambda item: (item['image'], item['label']))
mnist_train = mnist_train.prefetch(1)

for image, label in mnist_train.take(1):
    print(image.shape)
    print(label.numpy())

(32, 28, 28, 1)
[4 1 0 7 8 1 2 7 1 6 6 4 7 7 3 3 7 9 9 1 0 6 6 9 9 4 8 9 4 7 3 3]


<div class="alert alert-block alert-info">
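    <b>Setting <code>shuffle_files=True</code> makes <code>tfds.load()</code> shuffle the downloaded files. This only shuffles at the file level, which is not enough on its own, so it is still best to shuffle the training data as well.</b>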
</div>
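A minimal sketch of that advice, assuming the `(image, label)` structure used in this notebook. To keep it runnable offline, the dataset is a hypothetical stand-in built from random tensors with `tf.data.Dataset.from_tensor_slices`; with real data you would pass in the result of `tfds.load(name='mnist', split='train', shuffle_files=True, as_supervised=True)` instead. The helper name `make_train_pipeline` and the buffer size of 10,000 are illustrative choices: `shuffle()` only shuffles within its buffer, so a larger buffer gives a better shuffle at the cost of memory.

```python
import numpy as np
import tensorflow as tf

def make_train_pipeline(ds, buffer_size=10_000, batch_size=32):
    """Hypothetical helper: shuffle individual examples, then batch and prefetch.

    `ds` is assumed to yield (image, label) tuples, for example the output of
    tfds.load(name='mnist', split='train', shuffle_files=True, as_supervised=True).
    """
    return (ds.shuffle(buffer_size, seed=42)  # example-level shuffle, on top of shuffle_files
              .batch(batch_size)
              .prefetch(1))

# Hypothetical stand-in for MNIST: 100 random 28x28x1 uint8 images with labels 0-9
images = np.random.randint(0, 256, size=(100, 28, 28, 1), dtype=np.uint8)
labels = np.random.randint(0, 10, size=(100,), dtype=np.int64)
stand_in = tf.data.Dataset.from_tensor_slices((images, labels))

train_ds = make_train_pipeline(stand_in)
for image_batch, label_batch in train_ds.take(1):
    print(image_batch.shape, label_batch.shape)
```

Shuffling before `batch()` mixes individual examples; shuffling after it would only reorder whole batches.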

An even simpler approach:

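**Pass `as_supervised=True` to `tfds.load()` so that each item is an `(image, label)` tuple instead of a dict; you can also set the batch size with `batch_size`. The resulting dataset can then be passed directly to a tf.keras model.**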

In [10]:
dataset = tfds.load(
    name='mnist',
    batch_size=32,
    as_supervised=True
)

In [11]:
dataset

{'test': <PrefetchDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.uint8, tf.int64)>,
 'train': <PrefetchDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.uint8, tf.int64)>}

In [12]:
mnist_train = dataset['train'].repeat().prefetch(1)

In [13]:
mnist_train

<PrefetchDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.uint8, tf.int64)>
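Calling `repeat()` without a count makes the dataset infinite, so Keras cannot tell where an epoch ends on its own; that is why the training cell below passes `steps_per_epoch=60000//32`, i.e. 1875 batches, one pass over the 60,000 MNIST training images. A small sketch on a hypothetical stand-in dataset (`tf.data.Dataset.range`) showing how tf.data reports this through `tf.data.experimental.cardinality`:

```python
import tensorflow as tf

# Hypothetical stand-in: 64 elements batched by 32 -> 2 batches
ds = tf.data.Dataset.range(64).batch(32)

finite = tf.data.experimental.cardinality(ds)
infinite = tf.data.experimental.cardinality(ds.repeat())

print(finite.numpy())                                                  # → 2
print(infinite.numpy() == tf.data.experimental.INFINITE_CARDINALITY)  # → True

# steps_per_epoch used in this notebook: 60,000 images in batches of 32
print(60000 // 32)                                                     # → 1875
```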

In [14]:
for image, label in mnist_train.take(1):
    print(image.shape)
    print(label)

(32, 28, 28, 1)
tf.Tensor([4 1 0 7 8 1 2 7 1 6 6 4 7 7 3 3 7 9 9 1 0 6 6 9 9 4 8 9 4 7 3 3], shape=(32,), dtype=int64)


In [16]:
tf.keras.backend.clear_session()
model = keras.Sequential([
    keras.layers.Input(shape=(28, 28, 1)),
    keras.layers.Flatten(),
    keras.layers.Lambda(lambda image: tf.cast(image, tf.float32)),  # cast uint8 pixels to float32
    keras.layers.Dense(64, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='adam',
    metrics=['acc']
)
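# mnist_train repeats forever, so fit() needs steps_per_epoch: 60000 // 32 = 1875 batches per epoch
model.fit(mnist_train, epochs=5, steps_per_epoch=60000//32)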

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x7f20c04180b8>
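The `Lambda` layer above casts the `uint8` pixels inside the model. An alternative design, sketched here on hypothetical stand-in data, is to convert the images in the input pipeline with `map`, so the model receives `float32` tensors directly. Scaling the pixels to [0, 1] by dividing by 255 is an extra assumption in this sketch (a common choice that tends to help optimization); the helper name `to_float` is also hypothetical.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

def to_float(image, label):
    # Cast uint8 pixels to float32 and scale them to [0, 1]; labels pass through unchanged
    return tf.cast(image, tf.float32) / 255.0, label

# Hypothetical stand-in for the batched (image, label) MNIST pipeline
images = np.random.randint(0, 256, size=(64, 28, 28, 1), dtype=np.uint8)
labels = np.random.randint(0, 10, size=(64,), dtype=np.int64)
ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(32).map(to_float).prefetch(1)

# With preprocessing done in the pipeline, the model no longer needs the Lambda layer
model = keras.Sequential([
    keras.Input(shape=(28, 28, 1)),
    keras.layers.Flatten(),
    keras.layers.Dense(64, activation='relu'),
    keras.layers.Dense(10, activation='softmax'),
])
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['acc'])
history = model.fit(ds, epochs=1, verbose=0)
```

Keeping the conversion inside the model, as the original cell does, has the advantage that the trained model accepts raw `uint8` images at inference time; moving it into the pipeline keeps the model simpler but means every caller must apply the same preprocessing.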