In [1]:
# tfrecord 文件格式
# tf.train.Example
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
import tensorflow as tf

from tensorflow import keras

# Report interpreter and library versions so results can be reproduced.
print(tf.__version__)
print(sys.version_info)
for lib in (mpl, np, pd, sklearn, tf, keras):
    print(f"{lib.__name__} {lib.__version__}")


2.0.0
sys.version_info(major=3, minor=6, micro=3, releaselevel='final', serial=0)
matplotlib 3.1.2
numpy 1.17.4
pandas 0.25.3
sklearn 0.21.3
tensorflow 2.0.0
tensorflow_core.keras 2.2.4-tf


#### 1.1 手动构建 tf.train.Features

In [17]:
# Proto nesting used by the TFRecord format:
#   tf.train.Example
#     -> tf.train.Features: a map {"key": tf.train.Feature}
#       -> tf.train.Feature: wraps a BytesList, FloatList or Int64List
# BytesList holds bytes values, so the titles are UTF-8 encoded first.
favourite_books = [
    title.encode("utf-8") for title in ["machine learning", "cc150"]
]
favourite_books_bytelist = tf.train.BytesList(value=favourite_books)
print(favourite_books_bytelist)
# FloatList stores 32-bit floats, hence the rounding visible in the printed output.
hours_floatlist = tf.train.FloatList(value=[14.4, 9.3, 121.2, 121.1])

features = tf.train.Features(
    feature={
        "favourite_books": tf.train.Feature(bytes_list=favourite_books_bytelist),
        "hours": tf.train.Feature(float_list=hours_floatlist),
    }
)
print(features)

value: "machine learning"
value: "cc150"

feature {
  key: "favourite_books"
  value {
    bytes_list {
      value: "machine learning"
      value: "cc150"
    }
  }
}
feature {
  key: "hours"
  value {
    float_list {
      value: 14.399999618530273
      value: 9.300000190734863
      value: 121.19999694824219
      value: 121.0999984741211
    }
  }
}



#### 1.2 封装成 tf.train.Example，再序列化为字节串

In [22]:
# Wrap the Features map in a tf.train.Example, then serialize it to the
# protobuf wire format (a bytes string) that TFRecord files store.
example = tf.train.Example(features=features)
print(example)
serialized_example = example.SerializeToString()
print(serialized_example)

features {
  feature {
    key: "favourite_books"
    value {
      bytes_list {
        value: "machine learning"
        value: "cc150"
      }
    }
  }
  feature {
    key: "hours"
    value {
      float_list {
        value: 14.399999618530273
        value: 9.300000190734863
        value: 121.19999694824219
        value: 121.0999984741211
      }
    }
  }
}

b'\nO\n\x1d\n\x05hours\x12\x14\x12\x12\n\x10fffA\xcd\xcc\x14Aff\xf2B33\xf2B\n.\n\x0ffavourite_books\x12\x1b\n\x19\n\x10machine learning\n\x05cc150'


#### 2.1 存成普通（未压缩）的 tfrecord 文件

In [32]:
# Write the same serialized Example three times into an uncompressed TFRecord file.
output_dir = "tfrecord_basic"
filename = "test.tfrecords"
if not os.path.exists(output_dir):
    os.mkdir(output_dir)
filename_fullpath = os.path.join(output_dir, filename)
with tf.io.TFRecordWriter(filename_fullpath) as writer:
    for _ in range(3):
        writer.write(serialized_example)

#### 2.2 读取

In [24]:
# Read the records back without parsing: each element is a scalar string
# tensor holding one serialized Example.
dataset = tf.data.TFRecordDataset([filename_fullpath])
for raw_record in dataset:
    print(raw_record)

tf.Tensor(b'\nO\n\x1d\n\x05hours\x12\x14\x12\x12\n\x10fffA\xcd\xcc\x14Aff\xf2B33\xf2B\n.\n\x0ffavourite_books\x12\x1b\n\x19\n\x10machine learning\n\x05cc150', shape=(), dtype=string)
tf.Tensor(b'\nO\n\x1d\n\x05hours\x12\x14\x12\x12\n\x10fffA\xcd\xcc\x14Aff\xf2B33\xf2B\n.\n\x0ffavourite_books\x12\x1b\n\x19\n\x10machine learning\n\x05cc150', shape=(), dtype=string)
tf.Tensor(b'\nO\n\x1d\n\x05hours\x12\x14\x12\x12\n\x10fffA\xcd\xcc\x14Aff\xf2B33\xf2B\n.\n\x0ffavourite_books\x12\x1b\n\x19\n\x10machine learning\n\x05cc150', shape=(), dtype=string)


In [28]:
# Parsing schema: both features are variable-length, so parse_single_example
# returns them as sparse tensors that must be densified before use.
expected_features = {
    "favourite_books": tf.io.VarLenFeature(dtype=tf.string),
    "hours": tf.io.VarLenFeature(dtype=tf.float32),
}
dataset = tf.data.TFRecordDataset([filename_fullpath])
for record in dataset:
    example = tf.io.parse_single_example(record, expected_features)
    # Any slot missing from the sparse tensor becomes b"" in the dense result.
    books = tf.sparse.to_dense(example["favourite_books"], default_value=b"")
    for book in books:
        print(book.numpy().decode("UTF-8"))

machine learning
cc150
machine learning
cc150
machine learning
cc150


#### 3.1 存成 GZIP 压缩的 tfrecord 文件

In [31]:
# Write the same serialized Example three times into a GZIP-compressed
# TFRecord file. GZIP produces a gzip stream, not a zip archive, so the file
# gets a ".gz" suffix; the variable name is kept for the reader cell.
filename_fullpath_zip = filename_fullpath + ".gz"
options = tf.io.TFRecordOptions(compression_type="GZIP")
with tf.io.TFRecordWriter(filename_fullpath_zip, options) as writer:
    for _ in range(3):
        writer.write(serialized_example)

#### 3.2 读取

In [40]:
# Read the compressed file back; compression_type must match what the writer used.
dataset_zip = tf.data.TFRecordDataset(
    [filename_fullpath_zip], compression_type="GZIP"
)
for record in dataset_zip:
    example = tf.io.parse_single_example(record, expected_features)
    books = tf.sparse.to_dense(example["favourite_books"], default_value=b"")
    for book in books:
        print(book.numpy().decode("UTF-8"))

machine learning
cc150
machine learning
cc150
machine learning
cc150


#### 4.1 实战
##### 将csv文件中的数据转换为tfrecord文件

In [74]:
# Collect the CSV file names to convert.
import pprint

source_dir = "./generate_csv/"
# To inspect the directory: pprint.pprint(os.listdir(source_dir))
def get_filenames_by_prefix(source_dir, prefix_name):
    """Return paths of entries in ``source_dir`` whose names start with ``prefix_name``.

    Args:
        source_dir: directory to scan (non-recursive).
        prefix_name: filename prefix to match, e.g. "train".

    Returns:
        Sorted list of ``os.path.join(source_dir, name)`` paths. Sorting makes
        the result deterministic, because ``os.listdir`` returns entries in
        arbitrary order.
    """
    return sorted(
        os.path.join(source_dir, name)
        for name in os.listdir(source_dir)
        if name.startswith(prefix_name)
    )

# Collect the CSV file paths for each data split.
train_filenames = get_filenames_by_prefix(source_dir, "train")
valid_filenames = get_filenames_by_prefix(source_dir, "valid")
test_filenames = get_filenames_by_prefix(source_dir, "test")

# Show the validation files; this is the output displayed for this cell.
pprint.pprint(valid_filenames)


['./generate_csv/valid_00.csv',
 './generate_csv/valid_01.csv',
 './generate_csv/valid_02.csv',
 './generate_csv/valid_03.csv',
 './generate_csv/valid_04.csv',
 './generate_csv/valid_05.csv',
 './generate_csv/valid_06.csv',
 './generate_csv/valid_07.csv',
 './generate_csv/valid_08.csv',
 './generate_csv/valid_09.csv']


In [75]:
# Read CSV files.
def parse_csv_line(line, n_fields=9):
    """Decode one CSV text line into a ``(x, y)`` pair of tensors.

    Every one of the ``n_fields`` columns is decoded with ``tf.constant(np.nan)``
    as its default (so empty cells become NaN). The first ``n_fields - 1``
    columns are stacked into ``x``; the last column is stacked into ``y``,
    a length-1 vector.
    """
    record_defaults = [tf.constant(np.nan)] * n_fields
    columns = tf.io.decode_csv(line, record_defaults=record_defaults)
    inputs = tf.stack(columns[:-1])
    target = tf.stack(columns[-1:])
    return inputs, target

def csv_reader_dataset(filenames, n_readers=5, batch_size = 32, n_parse_threads=5,
                      shuffle_buffer_size = 10000):
    """Build a repeating, shuffled, batched tf.data pipeline over CSV files.

    Args:
        filenames: list of CSV file paths; the first line of each file is
            skipped (assumed to be a header).
        n_readers: number of files read concurrently (``interleave`` cycle_length).
        batch_size: number of parsed examples per batch.
        n_parse_threads: parallel calls used when mapping ``parse_csv_line``.
        shuffle_buffer_size: size of the line-level shuffle buffer.

    Returns:
        A ``tf.data.Dataset`` yielding ``(x, y)`` batches. It never ends
        because of ``repeat()``.
    """
    dataset = tf.data.Dataset.list_files(filenames)
    dataset = dataset.repeat()
    dataset = dataset.interleave(
        # skip(1) drops the first (header) line of each file.
        lambda filename: tf.data.TextLineDataset(filename).skip(1),
        cycle_length=n_readers,
    )
    # Shuffle lines across files. tf.data transformations return a NEW dataset,
    # so the result must be reassigned; otherwise the shuffle is silently lost.
    dataset = dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.map(parse_csv_line, num_parallel_calls=n_parse_threads)
    dataset = dataset.batch(batch_size)
    return dataset
batch_size = 32
# Build one batched, repeating pipeline per split, in train/valid/test order.
train_set, valid_set, test_set = [
    csv_reader_dataset(split_files, batch_size=batch_size)
    for split_files in (train_filenames, valid_filenames, test_filenames)
]


In [80]:
def serilized_example(x, y):
    """Pack a feature vector and a label into a serialized tf.train.Example.

    ``x`` is stored as a FloatList under the key "input_features" and ``y``
    as a FloatList under "label". Returns the Example's serialized bytes.
    """
    feature_map = {
        "input_features": tf.train.Feature(float_list=tf.train.FloatList(value=x)),
        "label": tf.train.Feature(float_list=tf.train.FloatList(value=y)),
    }
    proto = tf.train.Example(features=tf.train.Features(feature=feature_map))
    return proto.SerializeToString()


In [87]:
def csv_dataset_to_tfrecords(base_filename, dataset,
                             n_shards, step_per_shard,
                             compression_type=None):
    """Write consecutive batches of ``dataset`` into ``n_shards`` TFRecord files.

    Each shard receives the next ``step_per_shard`` batches. Every example in
    a batch is written as one record produced by ``serilized_example``.
    Shard files are named ``"{base_filename}_{shard_id:05d}-of-{n_shards:05d}"``.

    Args:
        base_filename: path prefix for the shard files.
        dataset: dataset yielding ``(x_batch, y_batch)`` pairs.
        n_shards: number of shard files to write.
        step_per_shard: number of batches written into each shard.
        compression_type: forwarded to ``tf.io.TFRecordOptions``
            (e.g. None or "GZIP").

    Returns:
        List of the shard file paths, in shard order.
    """
    import itertools

    options = tf.io.TFRecordOptions(compression_type=compression_type)
    # Use ONE iterator for all shards. Iterating ``dataset.take(...)`` inside
    # the shard loop would start a fresh pass over the dataset for every
    # shard instead of continuing with the next, distinct batches.
    batches = iter(dataset)
    all_filenames = []
    for shard_id in range(n_shards):
        filename_fullpath = "{}_{:05d}-of-{:05d}".format(
            base_filename, shard_id, n_shards
        )
        with tf.io.TFRecordWriter(filename_fullpath, options) as writer:
            # islice stops quietly (shorter shard) if a finite dataset runs out.
            for x_batch, y_batch in itertools.islice(batches, step_per_shard):
                for x_example, y_example in zip(x_batch, y_batch):
                    writer.write(serilized_example(x_example, y_example))
        all_filenames.append(filename_fullpath)
    return all_filenames

In [88]:
# Convert each CSV-backed dataset into n_shards TFRecord files.
n_shards = 20
# Batches per shard = rows in the split // batch_size // n_shards.
# NOTE(review): the row counts 11610 / 3880 / 5170 are hard-coded; presumably
# they are the sizes of the generated train/valid/test CSV splits — verify.
train_steps_per_shard = 11610 // batch_size // n_shards
valid_steps_per_shard = 3880 // batch_size // n_shards
test_steps_per_shard = 5170 // batch_size // n_shards

output_dir = "generate_tfrecord"
os.makedirs(output_dir, exist_ok=True)
train_basename = os.path.join(output_dir, "train")
valid_basename = os.path.join(output_dir, "valid")
test_basename = os.path.join(output_dir, "test")

train_tfrecord_filename = csv_dataset_to_tfrecords(
    train_basename, train_set, n_shards, train_steps_per_shard, None
)
valid_tfrecord_filename = csv_dataset_to_tfrecords(
    valid_basename, valid_set, n_shards, valid_steps_per_shard, None
)
# Test shards must be built from the test split, not the validation split.
test_tfrecord_filename = csv_dataset_to_tfrecords(
    test_basename, test_set, n_shards, test_steps_per_shard, None
)
