In [8]:
# coding: utf-8
# author: Fengzhijin
# time: 2017.11.16
# ==================================
'''
实现了将FashionMNIST数据集源文件转换成tfrecords文件
1._int64_feature() - int64数据转换函数
2._bytes_feature() - 二进制字符串转换函数
3.convert_to() - tfrecords文件生成函数
'''

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets import mnist


def _int64_feature(value):
    """Wrap a single integer as a ``tf.train.Feature``.

    Args:
        value: one integer; it becomes the only element of an ``Int64List``.

    Returns:
        A ``tf.train.Feature`` whose ``int64_list`` holds ``[value]``.
    """
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)


def _bytes_feature(value):
    """Wrap a single byte string as a ``tf.train.Feature``.

    Args:
        value: one bytes object; it becomes the only element of a ``BytesList``.

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` holds ``[value]``.
    """
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)


def convert_to(data_set, name, out_dir='../data/tfrecords/'):
    """Serialize one dataset split to ``<out_dir>/<name>.tfrecords``.

    Every example is written as a ``tf.train.Example`` with two features:
    ``'label'`` (a single int64) and ``'image_raw'`` (the image array's raw
    bytes, with no shape or dtype information attached).

    Args:
        data_set: object exposing ``images``, ``labels`` and ``num_examples``
            -- in this file, one split returned by ``mnist.read_data_sets``.
        name: base name of the output file, e.g. ``'train'``.
        out_dir: directory the file is written into; created if missing.
            Defaults to the previously hard-coded ``'../data/tfrecords/'``.
    """
    import os

    images = data_set.images            # image arrays
    labels = data_set.labels            # integer labels
    num_examples = data_set.num_examples

    # Make sure the destination exists; MakeDirs is a no-op if it already does.
    tf.gfile.MakeDirs(out_dir)
    path = os.path.join(out_dir, name + '.tfrecords')

    with tf.python_io.TFRecordWriter(path) as writer:
        print('Writing:  ' + name + '.tfrecords')
        for index in range(num_examples):
            # tobytes() yields the same bytes as the deprecated (and in
            # NumPy 2.0 removed) tostring(). Readers must know the dtype and
            # shape to decode it -- presumably uint8, since __main__ loads
            # with dtype=tf.uint8; confirm for other callers.
            image_raw = images[index].tobytes()
            example = tf.train.Example(features=tf.train.Features(feature={
                'label': _int64_feature(int(labels[index])),
                'image_raw': _bytes_feature(image_raw)}))
            writer.write(example.SerializeToString())


if __name__ == '__main__':
    # Directory containing the raw Fashion-MNIST .gz files, which are read
    # with the MNIST loader.
    data_file = '../data/fashion/'
    # Load the data and split it into train / validation / test;
    # validation_size=5000 examples go to the validation split.
    data_sets = mnist.read_data_sets(
        data_file,
        dtype=tf.uint8,
        reshape=False,
        validation_size=5000,
    )
    # Write one .tfrecords file per split, in the same order as before.
    for split in ('train', 'validation', 'test'):
        convert_to(getattr(data_sets, split), split)

Extracting ../data/fashion/train-images-idx3-ubyte.gz
Extracting ../data/fashion/train-labels-idx1-ubyte.gz
Extracting ../data/fashion/t10k-images-idx3-ubyte.gz
Extracting ../data/fashion/t10k-labels-idx1-ubyte.gz
Writing:  train.tfrecords
Writing:  validation.tfrecords
Writing:  test.tfrecords
