## 1. 将数据集划分成训练集、验证集、测试集并存储成TFRecord文件

In [3]:
import numpy as np
from PIL import Image
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data


# Load MNIST; the returned object is indexed below as its train / validation /
# test splits. dtype is fixed to uint8 at load time so the raw image bytes
# written to the TFRecord files are uint8 pixels.
data = input_data.read_data_sets(
    'data_initial',
    dtype=tf.uint8,
    reshape=False,
)


# One entry per split, in the same order as the splits inside `data`.
# 'type' is used to build the output TFRecord file name.
config = [
    {'dir': 'data_initial/train/', 'type': 'train'},
    {'dir': 'data_initial/validation/', 'type': 'validation'},
    {'dir': 'data_initial/test/', 'type': 'test'},
]

# Convert every split into its own TFRecord file. `config` and `data` share
# the same ordering, so they are paired positionally.
for split_cfg, split in zip(config, data):
    # Look the arrays up once per split instead of once per image.
    split_images = split.images
    split_labels = split.labels
    # Output file name, e.g. mnist_train.tfrecords
    with tf.python_io.TFRecordWriter('mnist_' + split_cfg['type'] + '.tfrecords') as writer:
        for image, label in zip(split_images, split_labels):
            # Each record stores the integer label and the raw image bytes.
            example = tf.train.Example(features=tf.train.Features(feature={
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[int(label)])),
                'data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image.tobytes()])),
            }))
            writer.write(example.SerializeToString())
print('successful')

Extracting data_initial/train-images-idx3-ubyte.gz
Extracting data_initial/train-labels-idx1-ubyte.gz
Extracting data_initial/t10k-images-idx3-ubyte.gz
Extracting data_initial/t10k-labels-idx1-ubyte.gz
successful


## 2. 利用matplotlib等工具对TFRecord中的样本数据进行可视化，以验证存储在TFRecord文件中的样本与标记的完整性与对应性，并对数据集有个直观的认识。

In [4]:
import tensorflow as tf
from PIL import Image
import numpy as np


def read_tfrecord(config_dir, num=1):
    """Read randomly shuffled MNIST samples back from a TFRecord file.

    Args:
        config_dir: Path of the ``.tfrecords`` file to read (a file path,
            despite the parameter name).
        num: Number of samples to draw.

    Returns:
        When ``num == 1``: ``(img, label)`` where ``img`` is a 28x28 uint8
        array (suitable for ``Image.fromarray``) and ``label`` is the
        length-1 int32 label batch.
        Otherwise: ``(exm_images, exm_labels)`` with shapes ``(num, 784)``
        and ``(num, 1)``; pixel values are divided by 255.
    """
    # Build the pipeline in a private graph. Building it in the shared default
    # graph would make every call pile up more ops and queue runners, and
    # start_queue_runners would then restart all runners left by earlier calls.
    with tf.Graph().as_default():
        # Input queue over the single TFRecord file.
        filename_queue = tf.train.string_input_producer([config_dir])
        reader = tf.TFRecordReader()
        _, example = reader.read(filename_queue)  # (key, serialized record)

        # Parse the 'label' and 'data' features written by the export step.
        features = tf.parse_single_example(
            example,
            features={'label': tf.FixedLenFeature([], tf.int64),
                      'data': tf.FixedLenFeature([], tf.string)})
        # decode_raw turns the stored byte string back into uint8 pixels.
        image = tf.decode_raw(features['data'], tf.uint8)
        image = tf.cast(image, tf.float32)
        image = tf.reshape(image, [28, 28])
        label = tf.cast(features['label'], tf.int32)
        # Draw samples in random order so image/label pairing can be spot-checked.
        image_batch, label_batch = tf.train.shuffle_batch([image, label],
                                                          batch_size=1,
                                                          capacity=100,
                                                          min_after_dequeue=50)

        with tf.Session() as sess:
            exm_images = np.zeros((num, 784))
            exm_labels = np.zeros((num, 1))

            sess.run(tf.group(tf.global_variables_initializer(),
                              tf.local_variables_initializer()))
            # Start the queue-runner threads that feed the input pipeline.
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                for count in range(num):
                    batch_image, batch_label = sess.run([image_batch, label_batch])
                    if num == 1:
                        # Drop the batch dimension; PIL needs integer pixels.
                        img = batch_image.reshape([28, 28]).astype(np.uint8)
                        return img, batch_label
                    # Flatten and scale by 1/255 for the classifier.
                    exm_images[count, :] = batch_image.reshape(784) / 255
                    exm_labels[count, :] = batch_label
                    if count % 1000 == 0:
                        print(count)
            finally:
                # Always stop and join the reader threads, even on error or
                # on the early single-sample return.
                coord.request_stop()
                coord.join(threads)
    return exm_images, exm_labels
     

In [8]:
# Read samples back from every TFRecord file and save them for visual checking
for split_index, split_cfg in enumerate(config):
    config_path = 'mnist_' + split_cfg['type'] + '.tfrecords'
    for sample_index in range(3):
        img, label = read_tfrecord(config_path)
        # img is the single image returned by read_tfrecord; the output file
        # is named <split index><sample index>.jpg
        Image.fromarray(img).save(str(split_index) + str(sample_index) + '.jpg')

        print(label, end=' ')

[4] [3] [7] [0] [4] [4] [1] [7] [0] 

![image](00.jpg)
![image](01.jpg)
![image](02.jpg)
![image](10.jpg)
![image](11.jpg)
![image](12.jpg)
![image](20.jpg)
![image](21.jpg)
![image](22.jpg)

## 3. 设计并训练KNN算法对图片进行分类

In [10]:
from numpy import *
import operator
import os

# Classification function

def kNNClassify(newInput, dataSet, labels, k):
    """Classify one sample by majority vote among its k nearest neighbours.

    Args:
        newInput: Feature vector of the sample to classify (length d).
        dataSet: Training samples, one per row (n x d array-like).
        labels: Per-row labels; ``labels[i][0]`` is read as the label of
            row ``i`` of ``dataSet``.
        k: Number of nearest neighbours that vote; must satisfy
            ``1 <= k <= n``.

    Returns:
        The label with the most votes. On a tie, the tied label that was
        first voted for by the nearest neighbour wins.

    Raises:
        ValueError: If ``k`` is outside ``1..n``.
    """
    dataSet = asarray(dataSet)
    numSamples = dataSet.shape[0]  # number of training rows
    if not 1 <= k <= numSamples:
        raise ValueError(
            "k must be between 1 and the number of training samples "
            "(%d), got %r" % (numSamples, k))

    # Step 1: Euclidean distance from the new sample to every training row.
    # Broadcasting subtracts the sample from each row without materialising
    # a tiled copy of it.
    diff = asarray(newInput) - dataSet
    distance = (diff ** 2).sum(axis=1) ** 0.5

    # Step 2: indices of the k closest training rows, nearest first.
    nearest = argsort(distance)[:k]

    # Step 3: count one vote per neighbour label.
    classCount = {}
    for idx in nearest:
        voteLabel = labels[idx][0]
        classCount[voteLabel] = classCount.get(voteLabel, 0) + 1

    # Step 4: max() returns the first key with the highest count in insertion
    # order, i.e. ties go to the label seen closest to the new sample.
    return max(classCount, key=classCount.get)


# Classification accuracy
def testHandWritingClass(num_train=55000, num_test=500, k=5):
    """Measure kNN accuracy on MNIST samples read from the TFRecord files.

    Reads ``num_train`` samples from ``mnist_train.tfrecords`` and
    ``num_test`` samples from ``mnist_test.tfrecords``, classifies each test
    sample with ``kNNClassify`` and prints progress and the final accuracy.

    Args:
        num_train: Number of training samples to read.
        num_test: Number of test samples to read.
        k: Number of neighbours used by ``kNNClassify``.

    Raises:
        ValueError: If ``num_train`` or ``num_test`` is below 2.
            ``read_tfrecord`` returns a single 28x28 image instead of a
            sample matrix when asked for exactly one sample, and zero test
            samples would make the accuracy undefined.
    """
    if num_train < 2 or num_test < 2:
        raise ValueError("num_train and num_test must both be at least 2")

    print("step 1: load data...")
    train_x, train_y = read_tfrecord('mnist_train.tfrecords', num_train)
    test_x, test_y = read_tfrecord('mnist_test.tfrecords', num_test)

    print("step 2: testing...")
    numTestSamples = test_x.shape[0]
    matchCount = 0
    for i in range(numTestSamples):
        predict = kNNClassify(test_x[i], train_x, train_y, k)
        # test_y holds one label per row; compare against the scalar label.
        if predict == test_y[i][0]:
            matchCount += 1
        if i % 100 == 0:
            print(i)
    accuracy = float(matchCount) / numTestSamples

    print("step 3: show the result...")
    print('The classify accuracy is: %.2f%%' % (accuracy * 100))

testHandWritingClass()


step 1: load data...
0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
44000
45000
46000
47000
48000
49000
50000
51000
52000
53000
54000
0
step 2: testing...
0
100
200
300
400
step 3: show the result...
The classify accuracy is: 87.20%
