In [1]:
import tensorflow as tf

# Initializable iterator: the dataset size is supplied through a placeholder,
# so the iterator's initializer is run with a concrete value before reading.

max_value = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.range(max_value)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    # Reuse the same graph for two dataset sizes: each pass re-initializes the
    # iterator with a new `max_value` and checks elements come back in order.
    for limit in (10, 100):
        sess.run(iterator.initializer, feed_dict={max_value: limit})
        for expected in range(limit):
            assert sess.run(next_element) == expected


可重新初始化迭代器

In [2]:
# Reinitializable iterator: one iterator, defined only by its element
# structure (types and shapes), that can be bound to several datasets.
train_dataset = tf.data.Dataset.range(100).map(
    lambda x: x + tf.random_uniform([], -10, 10, tf.int64))
validation_dataset = tf.data.Dataset.range(50)

# The iterator is built from the training dataset's structure; the validation
# dataset is attached to the same iterator below via make_initializer().
iterator = tf.data.Iterator.from_structure(
    train_dataset.output_types, train_dataset.output_shapes)

next_element = iterator.get_next()

# One initializer op per dataset; running an op (re)binds the iterator to that
# dataset and restarts it from the beginning.
training_init_op = iterator.make_initializer(train_dataset)
validation_init_op = iterator.make_initializer(validation_dataset)

with tf.Session() as sess:
    # 20 epochs, each a full training pass followed by a full validation pass.
    for _ in range(20):
        sess.run(training_init_op)
        for _ in range(100):
            sess.run(next_element)

        # Re-bind the iterator to the validation dataset before the
        # validation pass so the 50 elements read come from that dataset.
        sess.run(validation_init_op)
        for _ in range(50):
            sess.run(next_element)


可馈送迭代器可以与 tf.placeholder 一起使用，通过熟悉的 feed_dict 机制选择每次调用 tf.Session.run 时所使用的 Iterator。它提供的功能与可重新初始化迭代器的相同，但在迭代器之间切换时不需要从数据集的开头初始化迭代器。

In [3]:
# Training and validation datasets.
training_dataset = tf.data.Dataset.range(10).map(
    lambda x: x + tf.random_uniform([], -10, 10, tf.int64)).repeat()
validation_dataset = tf.data.Dataset.range(5)

# One concrete iterator per dataset.
training_iterator = training_dataset.make_one_shot_iterator()
validation_iterator = validation_dataset.make_initializable_iterator()

# Feedable iterator: `handle` is a string placeholder that selects, on every
# sess.run call, which concrete iterator `next_element` reads from.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next()

with tf.Session() as sess:
    # Evaluate each iterator's string handle once; the results are fed
    # through `handle` below.
    training_handle = sess.run(training_iterator.string_handle())
    validation_handle = sess.run(validation_iterator.string_handle())

    # Switching between iterators does not require re-initializing them from
    # the start of the dataset. training_dataset uses repeat() with no count,
    # so it never runs out and 20 reads from a 10-element range succeed.
    for step in range(20):
        # `value` rather than `next`: avoids shadowing the builtin next()
        # for the rest of the kernel session.
        value = sess.run(next_element, feed_dict={handle: training_handle})
        print(step, value)

    # The validation iterator is initializable, so it must be initialized
    # before its first read.
    sess.run(validation_iterator.initializer)
    for i in range(5):
        value = sess.run(next_element, feed_dict={handle: validation_handle})
        assert i == value


0 7
1 0
2 2
3 5
4 -6
5 6
6 10
7 -2
8 11
9 11
10 0
11 -1
12 7
13 7
14 -1
15 3
16 0
17 4
18 14
19 7


消耗迭代器中的值

In [4]:
dataset = tf.data.Dataset.range(5)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
# Both operands are the same `next_element` tensor, so every evaluation of
# `result` consumes one element and yields twice its value.
result = tf.add(next_element, next_element)
with tf.Session() as sess:
    sess.run(iterator.initializer)
    # First pass: read and print all five elements.
    for _ in range(5):
        print(sess.run(result))

    # Second pass: restart the iterator and drain it until it reports that
    # the dataset is exhausted.
    sess.run(iterator.initializer)
    exhausted = False
    while not exhausted:
        try:
            sess.run(result)
        except tf.errors.OutOfRangeError:
            exhausted = True

# Consuming an iterator whose elements are nested: zip((d1, d2)) where d2 is
# itself a pair, so each element has the structure (a, (b, c)).
dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([10, 1]))
dataset2 = tf.data.Dataset.from_tensor_slices(
    (tf.random_uniform([10, 1]), tf.random_uniform([10, 2])))
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))
iterator = dataset3.make_initializable_iterator()
with tf.Session() as sess:
    sess.run(iterator.initializer)
    # A single run fetches the whole nested element; unpack it with a
    # matching nested pattern.
    fetched = sess.run(iterator.get_next())
    next1, (next2, next3) = fetched
    print(next1, (next2, next3))

0
2
4
6
8
[0.9367877] (array([0.6502514], dtype=float32), array([0.89872587, 0.5918896 ], dtype=float32))


简单的批处理

In [5]:
# An ascending and a descending range, paired element-wise and then grouped
# into batches of 4 consecutive pairs.
inc_dataset = tf.data.Dataset.range(100)
dec_dataset = tf.data.Dataset.range(0, -100, -1)
dataset = tf.data.Dataset.zip((inc_dataset, dec_dataset))
batched_dataset = dataset.batch(4)

iterator = batched_dataset.make_initializable_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
    sess.run(iterator.initializer)
    # Show the first three batches.
    for _ in range(3):
        print(sess.run(next_element))


(array([0, 1, 2, 3]), array([ 0, -1, -2, -3]))
(array([4, 5, 6, 7]), array([-4, -5, -6, -7]))
(array([ 8,  9, 10, 11]), array([ -8,  -9, -10, -11]))


使用填充批处理张量

In [6]:
# Element x is mapped to a length-x vector filled with x, giving
# variable-length elements. (tf.fill(dims, value): for example
# tf.fill([2, 3], 9) ==> [[9, 9, 9], [9, 9, 9]].)
# padded_batch(4, padded_shapes=[None]) then groups 4 vectors per batch,
# padding along the unknown-length dimension.
dataset = (
    tf.data.Dataset.range(100)
    .map(lambda x: tf.fill([tf.cast(x, tf.int32)], x))
    .padded_batch(4, padded_shapes=[None])
)

iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
    # Show the first two padded batches.
    for _ in range(2):
        print(sess.run(next_element))


[[0 0 0]
 [1 0 0]
 [2 2 0]
 [3 3 3]]
[[4 4 4 4 0 0 0]
 [5 5 5 5 5 0 0]
 [6 6 6 6 6 6 0]
 [7 7 7 7 7 7 7]]


处理多个周期

In [7]:
# repeat(num_epochs) makes the dataset yield its elements num_epochs times in
# a row. Both sizes are named so the dataset definition and the loop that
# consumes it cannot drift apart.
num_epochs = 2   # number of passes over the data
epoch_size = 10  # number of elements in one pass
dataset = tf.data.Dataset.range(epoch_size).repeat(num_epochs)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
    sess.run(iterator.initializer)
    for i in range(epoch_size * num_epochs):
        value = sess.run(next_element)
        # Values restart from 0 at the beginning of every epoch.
        assert i % epoch_size == value


## 随机重排输入数据

Dataset.shuffle() 转换会使用类似于 tf.RandomShuffleQueue 的算法随机重排输入数据集：它会维持一个固定大小的缓冲区，并从该缓冲区中均匀地随机选择下一个元素。

In [8]:
# 15 elements (0..4 repeated three times) are batched first and shuffled
# afterwards, so shuffle() reorders whole batches of 3 rather than
# individual elements.
dataset = tf.data.Dataset.range(5).repeat(3)
dataset = dataset.batch(3).shuffle(buffer_size=6)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
    # 15 elements / 3 per batch = 5 batches in total.
    for _ in range(5):
        print(sess.run(next_element))


[2 3 4]
[0 1 2]
[4 0 1]
[3 4 0]
[1 2 3]


## tf.train.MonitoredTrainingSession 
tf.train.MonitoredTrainingSession API 简化了在分布式设置下运行 TensorFlow 的很多方面。MonitoredTrainingSession 使用 tf.errors.OutOfRangeError 表示训练已完成，因此要将其与 tf.data API 结合使用，我们建议使用 Dataset.make_one_shot_iterator()。

In [9]:
# Same pipeline as the shuffle example: repeat, batch, then shuffle batches.
dataset = (
    tf.data.Dataset.range(5)
    .repeat(3)
    .batch(3)
    .shuffle(buffer_size=6)
)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
# No explicit OutOfRangeError handling here: the loop simply polls
# should_stop() and keeps reading batches until the session says to stop.
with tf.train.MonitoredTrainingSession() as sess:
    while not sess.should_stop():
        print(sess.run(next_element))


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Running local_init_op.


INFO:tensorflow:Done running local_init_op.


[0 1 2]
[3 4 0]
[2 3 4]
[1 2 3]
[4 0 1]
