In [2]:
# coding:utf-8
# author: Fengzhijin
# time: 2017.11.23
# ==================================
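'''
Fashion-MNIST image classification with a convolutional neural network.
1. init_weight() - creates a convolution kernel variable
2. get_weights_bases() - creates the weights and biases of a fully connected layer
3. output_inference() - forward pass of a fully connected layer
4. model() - two convolution + pooling layers followed by fully connected
   layers (two hidden layers plus the output layer)
5. To counter overfitting, the model is exported only when the accuracy
   improves on the best value seen so far
'''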

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.framework import graph_util

# Number of training samples fed to the optimizer per step.
batch_size = 1024

# Directory the frozen graph ('graph.pb') is written to by main().
MODEL_SAVE_PATH = "./model/pb/"


# Create a convolution kernel variable
def init_weight(shape):
    """Return a TensorFlow variable of the given shape for use as a conv kernel.

    The initial values are drawn from a normal distribution with a standard
    deviation of 0.01.

    Args:
        shape: Shape of the variable, passed straight to ``tf.random_normal``.

    Returns:
        A ``tf.Variable`` initialised with the random values.
    """
    initial_values = tf.random_normal(shape, stddev=0.01)
    return tf.Variable(initial_values)


# Create the weights and biases of one fully connected layer
def get_weights_bases(shape):
    """Build the weight matrix and bias vector for a dense layer.

    Args:
        shape: Indexable whose first two entries are the input and output
            widths of the layer; any further entries are ignored.

    Returns:
        A ``(weights, bases)`` tuple of ``tf.Variable`` objects. The weights
        have shape ``[shape[0], shape[1]]`` and start from a truncated normal
        distribution (stddev 0.1); the biases have shape ``[shape[1]]`` and
        start at the constant 0.1.
    """
    n_inputs = shape[0]
    n_outputs = shape[1]

    # Weights are created before biases so variable creation order is stable.
    weight_init = tf.truncated_normal([n_inputs, n_outputs], stddev=0.1)
    weights = tf.Variable(weight_init)

    bias_init = tf.constant(0.1, shape=[n_outputs])
    bases = tf.Variable(bias_init)

    return weights, bases


# Forward pass of one fully connected layer
def output_inference(input_tensor, weights, biases):
    """Apply an affine transform: ``input_tensor @ weights + biases``.

    No activation function is applied.

    Args:
        input_tensor: Left operand of the matrix product.
        weights: Right operand of the matrix product.
        biases: Term added to the matrix product.

    Returns:
        The resulting tensor.
    """
    projected = tf.matmul(input_tensor, weights)
    return projected + biases


# Model: two convolution + pooling stages followed by fully connected layers
def model(X, w, w2, w3, b3, w4, b4, w5, b5, p_keep_conv):
    """Build the forward graph of the CNN and return its output logits.

    Args:
        X: Input tensor fed to the first convolution.
        w: Kernel of the first convolution.
        w2: Kernel of the second convolution.
        w3, b3: Weights and biases of the first fully connected layer; the
            first dimension of ``w3`` also fixes the flattened feature width.
        w4, b4: Weights and biases of the second fully connected layer.
        w5, b5: Weights and biases of the output layer.
        p_keep_conv: Keep probability passed to both dropout ops.

    Returns:
        The output of the last layer; no softmax or other activation is
        applied to it.
    """
    unit_stride = [1, 1, 1, 1]
    pool_window = [1, 3, 3, 1]  # used as both ksize and strides of the pooling

    # Stage 1: convolution -> ReLU -> max pooling -> dropout.
    conv1 = tf.nn.conv2d(X, w, strides=unit_stride, padding='SAME')
    act1 = tf.nn.relu(conv1)
    pooled1 = tf.nn.max_pool(act1, ksize=pool_window, strides=pool_window,
                             padding='SAME')
    stage1 = tf.nn.dropout(pooled1, p_keep_conv)
    # Original note: shape = [?, 10, 10, 32] (for the tensors built in main()).

    # Stage 2: convolution -> ReLU -> max pooling, then flatten and dropout.
    conv2 = tf.nn.conv2d(stage1, w2, strides=unit_stride, padding='SAME')
    act2 = tf.nn.relu(conv2)
    pooled2 = tf.nn.max_pool(act2, ksize=pool_window, strides=pool_window,
                             padding='SAME')
    # Original note: shape = [?, 4, 4, 64] before flattening.
    flat_width = w3.get_shape().as_list()[0]
    flattened = tf.reshape(pooled2, [-1, flat_width])
    stage2 = tf.nn.dropout(flattened, p_keep_conv)

    # Two ReLU-activated fully connected layers, then a linear output layer.
    hidden1 = tf.nn.relu(tf.matmul(stage2, w3) + b3)
    hidden2 = tf.nn.relu(tf.matmul(hidden1, w4) + b4)
    logits = tf.matmul(hidden2, w5) + b5

    return logits


def main():
    """Train the CNN on the Fashion-MNIST data and export the best graph.

    Builds the model in a fresh graph, trains it for 50 passes over the
    training arrays with the Adam optimizer, prints the accuracy on the test
    arrays after every pass, and writes a frozen graph to
    ``MODEL_SAVE_PATH/graph.pb`` whenever that accuracy exceeds the best value
    seen so far.
    """
    with tf.Graph().as_default() as graph:
        # Build the model.
        mnist = input_data.read_data_sets("../data/fashion", one_hot=True)
        trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
        trX = trX.reshape(-1, 28, 28, 1)
        teX = teX.reshape(-1, 28, 28, 1)
        X = tf.placeholder("float", [None, 28, 28, 1], name='x-input')
        Y = tf.placeholder("float", [None, 10], name='y-input')
        w = init_weight([3, 3, 1, 32])
        w2 = init_weight([3, 3, 32, 64])
        w3, b3 = get_weights_bases([64 * 4 * 4, 300])
        w4, b4 = get_weights_bases([300, 100])
        w5, b5 = get_weights_bases([100, 10])
        p_keep_conv = tf.placeholder("float", name='p_keep_conv')
        py_x = model(X, w, w2, w3, b3, w4, b4, w5, b5, p_keep_conv)
        # Training: minimise the mean softmax cross-entropy with Adam.
        cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=py_x, labels=Y))
        train_op = tf.train.AdamOptimizer().minimize(cost)
        predict_op = tf.argmax(py_x, 1)
        # Take the export node name from the op itself instead of relying on
        # the auto-generated name staying 'ArgMax'.
        output_node_name = predict_op.op.name

    with tf.Session(graph=graph) as sess:
        tf.global_variables_initializer().run()
        max_acc = 0.0
        num_train = len(trX)
        test_labels = np.argmax(teY, axis=1)
        # Train for 50 passes over the training arrays.
        for i in range(50):
            # Step by start index so the final, possibly shorter, batch is
            # trained on as well; slicing clamps `end` to the array length.
            for start in range(0, num_train, batch_size):
                end = start + batch_size
                sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end],
                         p_keep_conv: 0.8})
            # Dropout is disabled (keep probability 1.0) for evaluation.
            predictions = sess.run(predict_op, feed_dict={X: teX, p_keep_conv: 1.0})
            accuracy = np.mean(test_labels == predictions)
            print(i, accuracy)
            # NOTE(review): the exported model is selected on test-set
            # accuracy, so the reported best figure is optimistic; consider
            # selecting on a held-out validation split instead.
            # Export the model only when it beats the best accuracy so far.
            if accuracy > max_acc:
                max_acc = accuracy
                new_graph = graph_util.convert_variables_to_constants(
                    sess, sess.graph_def, output_node_names=[output_node_name])
                tf.train.write_graph(new_graph, MODEL_SAVE_PATH, 'graph.pb', as_text=False)


# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()


Extracting ../data/fashion/train-images-idx3-ubyte.gz
Extracting ../data/fashion/train-labels-idx1-ubyte.gz
Extracting ../data/fashion/t10k-images-idx3-ubyte.gz
Extracting ../data/fashion/t10k-labels-idx1-ubyte.gz
0 0.7244
1 0.7858
2 0.8098
3 0.8288
4 0.8446
5 0.8415
6 0.8555
7 0.8617
8 0.8672
9 0.8663
10 0.871
11 0.8718
12 0.8734
13 0.8792
14 0.8826
15 0.8819
16 0.885
17 0.8868
18 0.8895
19 0.8907
20 0.8929
21 0.8935
22 0.8965
23 0.8985
24 0.9002
25 0.8988
26 0.9011
27 0.9035
28 0.9028
29 0.906
30 0.9069
31 0.9069
32 0.9068
33 0.9078
34 0.9094
35 0.9094
36 0.9099
37 0.9099
38 0.9114
39 0.9126
40 0.9136
41 0.9128
42 0.9144
43 0.915
44 0.9158
45 0.9135
46 0.9164
47 0.9155
48 0.918
49 0.9176
