# 目录
## 1. 模块导入
## 2. 获取mnist数据集
## 3. 写入mnist到tfrecords
  - `tf.python_io.TFRecordWriter`
  - `tf.train.Example`
  - `tf.train.Features`
  - `tf.train.Feature`
  - `tf.train.BytesList`
  - `tf.train.Int64List`
  - `SerializeToString`
  
## 4. 读取tfrecords（单一读取，批量读取）
  - `tf.TFRecordReader`
  - `tf.train.string_input_producer`
  - `read`
  - `tf.io.FixedLenFeature`
  - `tf.io.parse_single_example`
  - `tf.decode_raw`
  - `tf.train.shuffle_batch`
  - `tf.train.batch`
  - `tf.train.Coordinator`
  - `tf.train.start_queue_runners`
  - `request_stop`
  - `join`

## 1. 模块导入

In [1]:
%matplotlib inline
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import sklearn

from tensorflow import keras
import tensorflow as tf
import sys
import os
import time
import datetime

# Record the version of every core dependency so the run can be reproduced.
for module in (np, pd, mpl, sklearn, keras, tf):
    print(f"{module.__name__} {module.__version__}")

numpy 1.17.2
pandas 0.25.1
matplotlib 3.1.1
sklearn 0.21.3
tensorflow.python.keras.api._v1.keras 2.2.4-tf
tensorflow 1.15.0


## 2. 获取mnist数据集

In [29]:
from tensorflow.examples.tutorials.mnist import input_data

output_dir = "tf1_data"
# makedirs(exist_ok=True) is a no-op when the directory already exists and
# avoids the check-then-create race of os.path.exists() followed by os.mkdir().
os.makedirs(output_dir, exist_ok=True)

# Derive the dataset location from output_dir so the two paths cannot drift apart.
# Pixels are requested as uint8 and labels as plain integers (not one-hot).
mnist = input_data.read_data_sets(
    os.path.join(output_dir, "MNIST_data"), dtype=tf.uint8, one_hot=False)

Extracting tf1_data/MNIST_data/train-images-idx3-ubyte.gz
Extracting tf1_data/MNIST_data/train-labels-idx1-ubyte.gz
Extracting tf1_data/MNIST_data/t10k-images-idx3-ubyte.gz
Extracting tf1_data/MNIST_data/t10k-labels-idx1-ubyte.gz


## 3. 写入mnist到tfrecords

In [7]:
def mnist_to_tfrecords(images, labels, save_dir, name_prefix):
    """Serialize paired images and labels into a single TFRecord file.

    Each (image, label) pair becomes one ``tf.train.Example`` holding two
    features: ``"image"`` (the raw bytes of the image array) and ``"label"``
    (a single int64 value).

    Args:
        images: Iterable of NumPy arrays, one per example.
        labels: Iterable of integer labels, paired positionally with ``images``.
        save_dir: Directory in which the file is written.
        name_prefix: File name without extension; ".tfrecords" is appended.

    Returns:
        The path of the written file (``save_dir`` joined with the file name).
    """
    filename_fullpath = os.path.join(save_dir, name_prefix + ".tfrecords")
    with tf.python_io.TFRecordWriter(filename_fullpath) as writer:
        for image, label in zip(images, labels):
            example = tf.train.Example(
                features=tf.train.Features(
                    feature={
                        # tobytes() returns the same bytes as the deprecated
                        # ndarray.tostring() alias it replaces.
                        "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image.tobytes()])),
                        "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
                    }
                )
            )
            writer.write(example.SerializeToString())
    return filename_fullpath

In [8]:
# Write each MNIST split (train, validation, test) to its own TFRecord file
# under "tf1_data", keeping the returned paths for the reading cells below.
train_filename, valid_filename, test_filename = [
    mnist_to_tfrecords(split.images, split.labels, "tf1_data", prefix)
    for split, prefix in (
        (mnist.train, "train"),
        (mnist.validation, "valid"),
        (mnist.test, "test"),
    )
]

# One path per line.
print(train_filename, valid_filename, test_filename, sep="\n")

tf1_data/train.tfrecords
tf1_data/valid.tfrecords
tf1_data/test.tfrecords


## 4. 读取tfrecords（单一读取，批量读取）

In [26]:
def mnist_tfrecords_single_reader(filename, epochs=None):
    """Build graph ops that yield one decoded MNIST example per evaluation.

    Args:
        filename: A list of TFRecord file names.
        epochs: Passed to ``tf.train.string_input_producer`` as ``num_epochs``.

    Returns:
        A ``(image, label)`` pair of tensors: ``image`` is the uint8 pixel
        data reshaped to ``[784]``; ``label`` is the first (only) element of
        the int64 ``"label"`` feature.
    """
    record_reader = tf.TFRecordReader()

    # Queue of input file names, consumed in the given order (no shuffling).
    file_queue = tf.train.string_input_producer(filename, shuffle=False, num_epochs=epochs)

    # Read one serialized Example; the record key is not needed.
    _, raw_record = record_reader.read(file_queue)

    # Feature schema: one byte string and one int64 per Example.
    schema = {
        "image": tf.io.FixedLenFeature([1], tf.string),
        "label": tf.io.FixedLenFeature([1], tf.int64),
    }
    parsed = tf.io.parse_single_example(raw_record, schema)  # dict of tensors

    # Turn the raw bytes back into uint8 pixels and flatten to 784 values.
    pixels = tf.reshape(tf.decode_raw(parsed["image"], tf.uint8), shape=[784])

    # The label feature has shape [1]; take its only element.
    digit = parsed["label"][0]

    return pixels, digit

def mnist_tfrecords_batch_reader(filename, batch_size=32, shuffle=True, epochs=None):
    """Build graph ops that yield batches of decoded MNIST examples.

    Args:
        filename: A list of TFRecord file names.
        batch_size: Number of examples per batch.
        shuffle: If true, batch with ``tf.train.shuffle_batch``; otherwise
            use ``tf.train.batch``.
        epochs: Passed to ``tf.train.string_input_producer`` as ``num_epochs``.
            When not None, TensorFlow creates a local epoch counter, so
            ``tf.local_variables_initializer()`` must be run before reading.

    Returns:
        A ``(image_batch, label_batch)`` pair of tensors.
    """
    # Reuse the single-example pipeline instead of duplicating the
    # reader / parse / decode steps here.
    image, label = mnist_tfrecords_single_reader(filename, epochs=epochs)

    # Both batching queues use the same capacity.
    capacity = 5000 + 3 * batch_size

    # Group single examples into batches.
    if shuffle:
        image_batch, label_batch = tf.train.shuffle_batch(
            [image, label],
            batch_size=batch_size,
            capacity=capacity,
            min_after_dequeue=100)
    else:
        image_batch, label_batch = tf.train.batch(
            [image, label], batch_size=batch_size, capacity=capacity)

    return image_batch, label_batch

In [27]:
image, label = mnist_tfrecords_single_reader([train_filename])
with tf.Session() as sess:

    # Start the queue-runner threads registered in the graph.
    coordinator = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coordinator)

    try:
        # Read two examples, one per sess.run call.
        for i in range(2):
            image_value, label_value = sess.run([image, label])
            print(label_value)
            print(image_value.shape)
    finally:
        # Always stop and join the threads, even if sess.run raises,
        # so they are not left running after the cell finishes.
        coordinator.request_stop()
        coordinator.join(threads)

7
(784,)
3
(784,)


In [28]:
image_batch, label_batch = mnist_tfrecords_batch_reader([train_filename], batch_size=3)
with tf.Session() as sess:

    # Start the queue-runner threads registered in the graph.
    coordinator = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coordinator)

    try:
        # Read two batches, one per sess.run call.
        for i in range(2):
            image_value, label_value = sess.run([image_batch, label_batch])
            print(label_value)
            print(image_value.shape)
    finally:
        # Always stop and join the threads, even if sess.run raises,
        # so they are not left running after the cell finishes.
        coordinator.request_stop()
        coordinator.join(threads)

[4 8 9]
(3, 784)
[3 0 4]
(3, 784)
