[toc]

# TensorFlow Dataset

## from_generator

The code below shows how to use `from_generator` to build a dataset of (x, y) pairs. A dataset that has only X and no y needs only minor changes.

```python
import tensorflow as tf
import numpy as np

print(tf.__version__) # 2.2.0

# Create a toy dataset: 10 samples with 3 features each, integer labels in [0, 10)
x_train = np.random.uniform(0, 1, [10, 3])
y_train = np.random.randint(0, 10, [10, ])

# Define the generator: yields one (x, y) pair per sample
def batch_generator():
    n_samples = x_train.shape[0]
    for i in range(n_samples):
        yield x_train[i], y_train[i]

# Build the dataset from the generator; the second argument gives the output dtypes
train_dataset = tf.data.Dataset.from_generator(batch_generator, (tf.float32, tf.int32))

# Repeat the data twice (2 epochs)
train_dataset = train_dataset.repeat(2)
# The dataset is not batched yet; use .batch() to set batch_size to 5
train_dataset = train_dataset.batch(5)
for x, y in train_dataset:
    print(x, y)
```

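```
1.15.0

RuntimeError: __iter__() is only supported inside of tf.function or when eager execution is enabled.
```

This output came from a TensorFlow 1.15.0 kernel, not the 2.2.0 noted in the code comment. Eager execution is off by default in TF 1.x, so a plain `for` loop over a dataset fails. You can avoid the error in either of two ways:

- call `tf.compat.v1.enable_eager_execution()` at program start;
- run under TF 2.x, where eager mode is the default.

With eager execution on, the loop prints 4 batches (10 samples × 2 repeats ÷ batch size 5).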
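The introduction noted that an X-only dataset needs only small changes. Below is a minimal sketch of that case, assuming TF 2.x, or TF 1.15 with eager mode switched on by the guard at the top. It makes two changes:

- the generator yields the bare array;
- `from_generator` receives a single dtype rather than a tuple.

The names `x_only` and `x_generator` are illustrative, not from the original.

```python
import numpy as np
import tensorflow as tf

# On TF 1.x eager mode is off by default; enable it so the for-loop below works.
# This must run before any other TensorFlow op is created. No-op on TF 2.x.
if not tf.executing_eagerly():
    tf.compat.v1.enable_eager_execution()

# Hypothetical feature-only data: 10 samples, 3 features, no labels
x_only = np.random.uniform(0, 1, [10, 3])

def x_generator():
    # Yield the feature row by itself, not wrapped in a tuple or list
    for i in range(x_only.shape[0]):
        yield x_only[i]

# A single dtype (not a 1-tuple) matches a generator that yields bare arrays
x_dataset = tf.data.Dataset.from_generator(x_generator, tf.float32)
x_dataset = x_dataset.repeat(2).batch(5)

# Collect the batches as NumPy arrays
batches = [x.numpy() for x in x_dataset]
print(len(batches), batches[0].shape)  # 4 (5, 3)
```

If the generator instead yielded `(x_only[i],)`, the dtype argument would have to be the matching 1-tuple `(tf.float32,)`.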

**There is a pitfall here: the generator function must yield a tuple, not a list.**

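If `batch_generator` above is changed as follows, iterating the dataset raises ``TypeError: `generator` yielded an element that did not match the expected structure. The expected structure was (tf.float32, tf.int32), but the yielded element was [array([0.71891118, 0.55713524, 0.83305131]), 2].`` tf.data checks each yielded element against the structure of the dtypes argument `(tf.float32, tf.int32)`, and a list does not match a tuple.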

```python
# Define the generator
def batch_generator():
    n_samples = x_train.shape[0]
    for i in range(n_samples):
        yield [x_train[i], y_train[i]] # Error! Yields a list, not a tuple
```
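When a generator you cannot easily edit yields lists, one workaround is to wrap it so every element becomes a tuple before tf.data sees it. Here is a minimal sketch, assuming TF 2.x. `as_tuples` and `list_generator` are hypothetical helper names, not TensorFlow APIs.

```python
import numpy as np
import tensorflow as tf

x_train = np.random.uniform(0, 1, [10, 3])
y_train = np.random.randint(0, 10, [10, ])

def list_generator():
    # Mirrors the failing example: each element is a list
    for i in range(x_train.shape[0]):
        yield [x_train[i], y_train[i]]

def as_tuples(gen_fn):
    """Hypothetical helper: wrap a generator function so that every
    yielded sequence is converted to a tuple."""
    def wrapped():
        for item in gen_fn():
            yield tuple(item)
    return wrapped

train_dataset = tf.data.Dataset.from_generator(
    as_tuples(list_generator), (tf.float32, tf.int32))

# Each element is now an (x, y) tuple of NumPy values
elements = list(train_dataset.as_numpy_iterator())
print(len(elements))  # 10
```

Changing the `yield` itself to a tuple, as in the first example, is simpler whenever you control the generator.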

### Passing arguments to the generator

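Sometimes we want to pass arguments to the generator when creating the dataset. The `args` parameter of `from_generator` does this. Its values are converted to tensors and handed to the generator as NumPy values when the dataset is iterated.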

```python
import tensorflow as tf
import numpy as np

print(tf.__version__) # 2.2.0

# Create a toy dataset
x_train = np.random.uniform(0, 1, [10, 3])
y_train = np.random.randint(0, 10, [10, ])

# Define the generator: yields consecutive slices of batch_size samples
def batch_generator(batch_size):
    n_samples = x_train.shape[0]
    start = 0
    while start < n_samples:
        yield x_train[start: start+batch_size], y_train[start: start+batch_size]
        start += batch_size

# Use the generator; args supplies batch_generator's arguments
train_dataset = tf.data.Dataset.from_generator(batch_generator, (tf.float32, tf.int32), args=(4,)) # Note: a single argument must be written as (4,), not (4)

for x, y in train_dataset:
    print(x, y)
```

```
2.2.0
tf.Tensor(
[[0.8848176  0.32335648 0.11673396]
 [0.635564   0.9853623  0.34263027]
 [0.525984   0.10160588 0.93273044]
 [0.9645629  0.07456936 0.7372238 ]], shape=(4, 3), dtype=float32) tf.Tensor([8 8 9 9], shape=(4,), dtype=int32)
tf.Tensor(
[[0.4510364  0.50084555 0.9311134 ]
 [0.00404259 0.4692437  0.44436535]
 [0.8974993  0.8824926  0.99141914]
 [0.9761399  0.7577163  0.5177157 ]], shape=(4, 3), dtype=float32) tf.Tensor([2 3 2 8], shape=(4,), dtype=int32)
tf.Tensor(
[[0.04106055 0.00701396 0.10839242]
 [0.660547   0.9916826  0.9171571 ]], shape=(2, 3), dtype=float32) tf.Tensor([7 8], shape=(2,), dtype=int32)
```
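An alternative to `args` is to close over the argument with a `lambda`. The value then stays an ordinary Python object instead of being converted to a tensor and back to NumPy. This minimal sketch assumes TF 2.x and also declares `output_shapes`. The batch dimension varies (4, 4, 2), so it is left as `None`.

```python
import numpy as np
import tensorflow as tf

x_train = np.random.uniform(0, 1, [10, 3])
y_train = np.random.randint(0, 10, [10, ])

def batch_generator(batch_size):
    n_samples = x_train.shape[0]
    start = 0
    while start < n_samples:
        yield x_train[start: start + batch_size], y_train[start: start + batch_size]
        start += batch_size

# Close over the argument instead of passing args=(4,)
train_dataset = tf.data.Dataset.from_generator(
    lambda: batch_generator(4),
    output_types=(tf.float32, tf.int32),
    # Variable batch dimension -> None; 3 features per row
    output_shapes=(tf.TensorShape([None, 3]), tf.TensorShape([None])),
)

# Batch sizes produced by the generator
sizes = [int(x.shape[0]) for x, _ in train_dataset]
print(sizes)  # [4, 4, 2]
```

With `output_shapes` declared, the dataset's static shapes are `(None, 3)` and `(None,)` rather than fully unknown, which downstream Keras layers can use.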


### Getting a single batch

```python
# as_numpy_iterator() yields elements as NumPy arrays instead of tf.Tensor
iterator = train_dataset.as_numpy_iterator()
next(iterator)
```

```
(array([[0.8848176 , 0.32335648, 0.11673396],
        [0.635564  , 0.9853623 , 0.34263027],
        [0.525984  , 0.10160588, 0.93273044],
        [0.9645629 , 0.07456936, 0.7372238 ]], dtype=float32),
 array([8, 8, 9, 9], dtype=int32))
```
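There are other common ways to pull out one batch. The sketch below compares three, assuming TF 2.x. To stay self-contained it builds the dataset with `from_tensor_slices` (covered in the next section) instead of the generator. Without shuffling, all three return the same first batch.

```python
import numpy as np
import tensorflow as tf

x_train = np.random.uniform(0, 1, [10, 3]).astype(np.float32)
y_train = np.random.randint(0, 10, [10, ]).astype(np.int32)
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(4)

# Option 1: a fresh Python iterator; next() returns eager tf.Tensors
x_batch, y_batch = next(iter(train_dataset))

# Option 2: take(1) gives a dataset containing only the first batch
for x_first, y_first in train_dataset.take(1):
    print(x_first.shape)  # (4, 3)

# Option 3: NumPy arrays via as_numpy_iterator() and the built-in next()
x_np, y_np = next(train_dataset.as_numpy_iterator())
```

Each call to `iter()` or `as_numpy_iterator()` starts a new pass over the data. Repeated calls therefore keep returning the first batch rather than advancing.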

## from_tensor_slices

```python
import tensorflow as tf
import numpy as np

print(tf.__version__) # 2.2.0

# Create a toy dataset
x_train = np.random.uniform(0, 1, [10, 3])
y_train = np.random.randint(0, 10, [10, ])

# Slice the arrays along their first dimension into (x, y) pairs
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))

# Repeat twice (2 epochs) and use a batch size of 5
train_dataset = train_dataset.repeat(2).batch(5)
for x, y in train_dataset:
    print(x, y)
```

```
2.2.0
tf.Tensor(
[[0.63026611 0.84018849 0.96822053]
 [0.00950266 0.82345767 0.14237034]
 [0.3628959  0.36985282 0.76039462]
 [0.60812065 0.30927249 0.17130571]
 [0.27561707 0.12677321 0.7169283 ]], shape=(5, 3), dtype=float64) tf.Tensor([6 7 0 5 8], shape=(5,), dtype=int64)
tf.Tensor(
[[0.32115764 0.44327448 0.6138822 ]
 [0.71774647 0.05496469 0.48456999]
 [0.00653992 0.61188734 0.01731185]
 [0.9451552  0.08030777 0.47051986]
 [0.06737426 0.55813358 0.97273234]], shape=(5, 3), dtype=float64) tf.Tensor([7 3 7 1 5], shape=(5,), dtype=int64)
tf.Tensor(
[[0.63026611 0.84018849 0.96822053]
 [0.00950266 0.82345767 0.14237034]
 [0.3628959  0.36985282 0.76039462]
 [0.60812065 0.30927249 0.17130571]
 [0.27561707 0.12677321 0.7169283 ]], shape=(5, 3), dtype=float64) tf.Tensor([6 7 0 5 8], shape=(5,), dtype=int64)
tf.Tensor(
[[0.32115764 0.44327448 0.6138822 ]
 [0.71774647 0.05496469 0.48456999]
 [0.00653992 0.61188734 0.01731185]
 [0.9451552  0.08030777 0.47051986]
 [0.06737426 0.55813358 0.97273234]], shape=(5, 3), dtype=float64) tf.Tensor([7 3 7 1 5], shape=(5,), dtype=int64)
```
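Unlike the `from_generator` examples, where the declared dtypes converted the data to `float32`/`int32`, `from_tensor_slices` keeps the NumPy dtypes (`float64`/`int64` in the output above). Below is a minimal sketch of one way to match the generator version, assuming TF 2.x: cast each element with `Dataset.map` before repeating and batching. The function name `cast_pair` is illustrative.

```python
import numpy as np
import tensorflow as tf

x_train = np.random.uniform(0, 1, [10, 3])   # float64
y_train = np.random.randint(0, 10, [10, ])   # platform-default NumPy integer

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
print(train_dataset.element_spec)  # shows the original NumPy-derived dtypes

def cast_pair(x, y):
    # Cast features and labels to the dtypes used in the generator examples
    return tf.cast(x, tf.float32), tf.cast(y, tf.int32)

# Cast first, then repeat twice (2 epochs) and batch by 5
train_dataset = train_dataset.map(cast_pair).repeat(2).batch(5)

x, y = next(iter(train_dataset))
print(x.dtype.name, y.dtype.name)  # float32 int32
```

Converting the arrays up front with `x_train.astype(np.float32)` would work equally well. `map` is handy when the casting or preprocessing belongs in the input pipeline itself.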
