In [53]:
import tensorflow as tf
import os
import numpy as np
import tempfile

# Seed NumPy's global RNG so the randomly generated records are reproducible.
np.random.seed(0)

# The demo TFRecord file is written to the OS temporary directory.
example_path = os.path.join(tempfile.gettempdir(), "example.tfrecords")

* TFRecord是Google专门为TensorFlow设计、并官方推荐使用的一种数据格式
* TFRecord内部使用了“Protocol Buffer”二进制数据编码方案,它只占用一个内存块,只需要一次性加载一个二进制文件即可,简单、快速,尤其对大型训练数据很友好
* 除了"快",另外一个优点是:在多模态学习(比如视频+音频+文案作为特征)中,可以将各种形式的特征预处理后统一放在TFRecord中,避免了读取数据时的麻烦


The tf.train.Feature message type can accept one of the following three types (See the .proto file for reference). Most other generic types can be coerced into one of these:

1. tf.train.BytesList (the following types can be coerced)
    * string
    * byte

2. tf.train.FloatList (the following types can be coerced)
    * float (float32)
    * double (float64)

3. tf.train.Int64List (the following types can be coerced)
    * bool
    * enum
    * int32
    * uint32
    * int64
    * uint64

### 实例一

#### 写入

In [None]:
with tf.io.TFRecordWriter(example_path) as file_writer:
    for _ in range(4):
        # 1. Draw one sample: x from np.random.random(), y from np.random.randint(0, 9).
        #    x is drawn before y so the RNG stream is consumed in a fixed order.
        x = np.random.random()
        y = np.random.randint(0, 9)

        # 2. Wrap each value in its typed list message. `value` must be a list,
        #    even when it holds a single scalar.
        x_feature = tf.train.Feature(float_list=tf.train.FloatList(value=[x]))
        y_feature = tf.train.Feature(int64_list=tf.train.Int64List(value=[y]))
        features = tf.train.Features(feature={"x": x_feature, "y": y_feature})

        # 3. Build the Example message from the features and serialize it to bytes.
        example = tf.train.Example(features=features)
        record_bytes = example.SerializeToString()

        # 4. Write the serialized record to the TFRecord file.
        file_writer.write(record_bytes)

#### 加载

In [None]:
def decode_fn(record_bytes):
    """Parse one serialized tf.train.Example into a dict of tensors.

    Args:
        record_bytes: A single serialized tf.train.Example, as yielded by
            tf.data.TFRecordDataset.

    Returns:
        The result of tf.io.parse_single_example: a dict with keys "x"
        (scalar float32) and "y" (scalar int64).
    """
    # Describe each feature's shape and dtype so the parser knows how to
    # decode it; [] means a scalar.
    feature_spec = {
        "x": tf.io.FixedLenFeature([], dtype=tf.float32),
        "y": tf.io.FixedLenFeature([], dtype=tf.int64),
    }

    # Decode the serialized Example according to the spec.
    return tf.io.parse_single_example(record_bytes, feature_spec)

In [56]:
# 1. Open the TFRecord file as a dataset of raw serialized records.
dataset = tf.data.TFRecordDataset([example_path])

# A TFRecordDataset supports the same transformations as any tf.data.Dataset;
# map() applies decode_fn to every serialized record.
parsed_dataset = dataset.map(decode_fn)

# Despite the name, each `batch` is one parsed record: the dataset is never batched.
for batch in parsed_dataset:
    print("x = {x:.4f},  y = {y:.4f}".format(x=batch["x"], y=batch["y"]))

{'x': <tf.Tensor 'ParseSingleExample/ParseExample/ParseExampleV2:0' shape=() dtype=float32>, 'y': <tf.Tensor 'ParseSingleExample/ParseExample/ParseExampleV2:1' shape=() dtype=int64>}
x = 0.5488,  y = 5.0000
x = 0.8443,  y = 3.0000
x = 0.8473,  y = 3.0000
x = 0.6459,  y = 4.0000
