In [1]:
import tensorflow as tf
# TensorFlow ships a helper for the MNIST dataset that downloads the data and converts its format automatically.
from tensorflow.examples.tutorials.mnist import input_data # mnist dataset operation
import mnist_inference # mnist_inference.py
# Common pathname manipulations. https://docs.python.org/3/library/os.path.html
import os 

#### 1. 定义神经网络结构相关的参数。

In [2]:
# --- Optimisation hyper-parameters ---
BATCH_SIZE = 100                # examples requested from the dataset per training step
TRAINING_STEPS = 30001          # number of iterations of the training loop
LEARNING_RATE_BASE = 0.8        # starting learning rate, before exponential decay
LEARNING_RATE_DECAY = 0.99      # decay rate handed to tf.train.exponential_decay
REGULARIZATION_RATE = 0.0001    # scale of the L2 weight regularizer
MOVING_AVERAGE_DECAY = 0.99     # decay of the ExponentialMovingAverage over the weights

# --- Checkpoint location ---
MODEL_SAVE_PATH = "../MNIST_model/"   # directory that receives the checkpoints
MODEL_NAME = "mnist_model"            # checkpoint file-name prefix

#### 2. 定义训练过程。
```python
tf.train.exponential_decay(
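    learning_rate, # initial (starter) learning rate
    global_step, # current training step; drives the decay
    decay_steps, # number of steps in one decay period
    decay_rate, # multiplicative decay factor
    staircase=False, # if True, decay the learning rate at discrete intervals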
    name=None
)
```
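The decayed learning rate is computed as
$$decayed\_learning\_rate = learning\_rate *decay\_rate^{(global\_step / decay\_steps)}$$
In this notebook `staircase=True`, so the exponent is rounded down to an integer:
$$decayed\_learning\_rate = 0.8 * 0.99^{\lfloor global\_step / (mnist.train.num\_examples / BATCH\_SIZE) \rfloor}$$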

In [3]:
def train(mnist):
    """Build the MNIST training graph, run the training loop and save checkpoints.

    The graph combines the forward pass from ``mnist_inference.inference``, an
    L2 weight regularizer, an exponentially decayed learning rate, plain
    gradient descent and an exponential moving average over the trainable
    variables.

    Args:
        mnist: Dataset object exposing ``train.num_examples`` and
            ``train.next_batch(batch_size)`` (e.g. the value returned by
            ``input_data.read_data_sets(..., one_hot=True)``). Labels must be
            one-hot encoded because they are reduced with ``tf.argmax``.

    Returns:
        None. Every 1000 iterations the batch loss is printed and a checkpoint
        is written under ``MODEL_SAVE_PATH``.
    """
    # Input / label placeholders; the batch dimension is left open.
    x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
    y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')

    # Forward pass from mnist_inference.py, with an L2 weight regularizer.
    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    y = mnist_inference.inference(x, regularizer)
    global_step = tf.Variable(0, trainable=False)

    # Moving average over the variables that are trainable at this point
    # (global_step is excluded because it is created with trainable=False).
    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    averaged_variables = tf.trainable_variables()

    # Cross entropy averaged over the batch. The sparse variant takes class
    # indices, so the one-hot labels are converted with argmax; it is the
    # faster choice when every example has exactly one correct class.
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
    cross_entropy_mean = tf.reduce_mean(cross_entropy)

    # Total loss = data term + everything registered in the 'losses' collection
    # (presumably the regularization terms added by mnist_inference.inference;
    # verify against that module).
    loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))

    # Learning rate decays once per pass over the training set (staircase).
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE,
        global_step,
        mnist.train.num_examples / BATCH_SIZE,  # decay_steps
        LEARNING_RATE_DECAY,
        staircase=True)

    # Gradient-descent update; minimize() also increments global_step.
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)

    # Update the moving averages only AFTER the weights have been updated.
    # tf.group(train_step, averages_op) gives no ordering guarantee between the
    # two, so the averages could be computed from either the old or the new
    # weights; this is the ordering pattern shown in the
    # tf.train.ExponentialMovingAverage documentation.
    with tf.control_dependencies([train_step]):
        train_op = variable_averages.apply(averaged_variables)

    # Make sure the checkpoint directory exists before the first save;
    # depending on the TensorFlow version, Saver.save may fail otherwise.
    if not os.path.isdir(MODEL_SAVE_PATH):
        os.makedirs(MODEL_SAVE_PATH)
    checkpoint_prefix = os.path.join(MODEL_SAVE_PATH, MODEL_NAME)

    # Saver covers all variables created above, including the shadow averages.
    saver = tf.train.Saver()
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        for i in range(TRAINING_STEPS):
            # Feed one mini-batch of training data.
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
            if i % 1000 == 0:
                print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
                # global_step is appended to the checkpoint file name.
                saver.save(sess, checkpoint_prefix, global_step=global_step)

#### 3. 主程序入口。

In [4]:
def main(argv=None):
    """Load the MNIST dataset and run training.

    Args:
        argv: Unused; kept so existing callers that pass it keep working.
    """
    # train() builds its model in the process-wide default graph and reads
    # graph-global state (tf.trainable_variables(), the 'losses' collection).
    # Start from an empty graph so that re-running this cell in a live kernel
    # does not stack a second model on top of the previous one.
    tf.reset_default_graph()

    # Labels are loaded one-hot; train() converts them to class indices.
    mnist = input_data.read_data_sets("../../0_datasets/MNIST_data/", one_hot=True)
    train(mnist)


if __name__ == '__main__':
    main()

Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
Instructions for updating:
Please write your own downloading logic.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting ../../0_datasets/MNIST_data/train-images-idx3-ubyte.gz
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting ../../0_datasets/MNIST_data/train-labels-idx1-ubyte.gz
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting ../../0_datasets/MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ../../0_datasets/MNIST_data/t10k-labels-idx1-ubyte.gz
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
After 1 training step(s), loss on training batch is 2.76234.
After 1001 training step(s), loss on training batch is 0.242863.
After 2001 training step(s), loss on training batch is 0.161668.
After 3001 training step(s), loss on train