In [1]:
import warnings
warnings.filterwarnings('ignore')

In [2]:
import tensorflow as tf
import numpy as np
# 用tensorflow 导入数据
from tensorflow.examples.tutorials.mnist import input_data

In [3]:
# 设置GPU按需增长
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

In [4]:
mnist = input_data.read_data_sets('../../data/MNIST_data', one_hot=True) 
# 看看咱们样本的数量
print(mnist.test.labels.shape)
print(mnist.train.labels.shape)

Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
Instructions for updating:
Please write your own downloading logic.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting ../../data/MNIST_data\train-images-idx3-ubyte.gz
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting ../../data/MNIST_data\train-labels-idx1-ubyte.gz
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting ../../data/MNIST_data\t10k-images-idx3-ubyte.gz
Extracting ../../data/MNIST_data\t10k-labels-idx1-ubyte.gz
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
(10000, 10)
(55000, 10)


In [5]:
lr = 1e-3
input_size = 28      # 每个时刻的输入特征是28维的，就是每个时刻输入一行，一行有 28 个像素
timestep_size = 28   # 时序持续长度为28，即每做一次预测，需要先输入28行
hidden_size = 256    # 隐含层的数量
layer_num = 2        # LSTM layer 的层数
class_num = 10       # 最后输出分类类别数量，如果是回归预测的话应该是 1
cell_type = "lstm"   # lstm 或者 block_lstm

In [6]:
X_input = tf.placeholder(tf.float32, [None, 784])
y_input = tf.placeholder(tf.float32, [None, class_num])
batch_size = tf.placeholder(tf.int32, [])
keep_prob = tf.placeholder(tf.float32, [])

In [7]:
# RNN 的输入shape = (batch_size, timestep_size, input_size) 
X = tf.reshape(X_input, [-1,28,28])
print(X.shape)

(?, 28, 28)


In [8]:
def lstm_cell(cell_type, num_nodes, keep_prob):
    assert (cell_type in ['lstm', 'block_lstm'], 'wrong type')
    if cell_type == 'lstm':
        cell = tf.contrib.rnn.BasicLSTMCell(num_nodes)
    else:
        cell = tf.contrib.rnn.LSTMBlockCell(num_units=num_nodes)
    cell = tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=keep_prob)
    return cell

In [9]:
mlstm_cell = tf.contrib.rnn.MultiRNNCell(
    [lstm_cell(cell_type, hidden_size, keep_prob) for _ in range(layer_num)], 
    state_is_tuple=True)

In [10]:
init_state = mlstm_cell.zero_state(batch_size=batch_size, dtype=tf.float32)

In [11]:
outputs, state = tf.nn.dynamic_rnn(mlstm_cell, inputs=X, initial_state=init_state, time_major=False)
# print(state.shape)

In [12]:
h_state = state[-1][1]

In [13]:
import time

In [14]:
W = tf.Variable(tf.truncated_normal([hidden_size, class_num]))
bias = tf.Variable(tf.constant(0.1, shape=[class_num], dtype=tf.float32))
y_pre = tf.nn.softmax(tf.matmul(h_state, W) + bias)

In [15]:
cross_entropy = -tf.reduce_mean(y_input*tf.log(y_pre))
tf.summary.scalar('cross_entropy',cross_entropy)
train_op = tf.train.AdamOptimizer(lr).minimize(cross_entropy)

In [16]:
correct_pred = tf.equal(tf.argmax(y_pre,1), tf.argmax(y_input,1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, 'float'))
tf.summary.scalar('accuracy',accuracy)

<tf.Tensor 'accuracy:0' shape=() dtype=string>

In [17]:
merged = tf.summary.merge_all()
sess.run(tf.global_variables_initializer())
train_writer = tf.summary.FileWriter('./log', sess.graph)

In [18]:
time0 = time.time()

In [19]:
for i in range(300):#1000
    _batch_size = 64
    X_batch, y_batch = mnist.train.next_batch(batch_size=_batch_size)
    if i%50 == 0:
        # 配置运行时需要记录的信息
        run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        # 运行时记录运行信息的proto
        run_metadata = tf.RunMetadata()
        summary, cost, acc, _ = sess.run([merged, cross_entropy, accuracy, train_op], 
                                feed_dict={X_input:X_batch, y_input:y_batch,keep_prob:0.5,
                                          batch_size:_batch_size},
                                         options=run_options,
                                         run_metadata=run_metadata)
        # 将节点在运行时的信息写入日志文件
        train_writer.add_run_metadata(run_metadata, 'step%03d' % i)
    else:
        summary, cost, acc, _ = sess.run([merged, cross_entropy, accuracy, train_op],
                                feed_dict={X_input:X_batch, y_input:y_batch,keep_prob:0.5,
                                        batch_size:_batch_size})
    train_writer.add_summary(summary,i)
    if i%50 == 0:
        test_acc = 0.0
        test_cost = 0.0
        N = 100
        for j in range(N):
            X_batch, y_batch = mnist.test.next_batch(batch_size=_batch_size)
            _cost, _acc = sess.run([cross_entropy, accuracy], 
                                   feed_dict={X_input:X_batch, y_input:y_batch, keep_prob:1.0, 
                                              batch_size:_batch_size})
            test_acc += _acc
            test_cost += _cost
        print("step {}, train cost={:.6f}, acc={:.6f}; test cost={:.6f}, acc={:.6f}; pass {}s".format(i+1, cost, acc, test_cost/N, test_acc/N, time.time() - time0))
        time0 = time.time()
print('train finished')
train_writer.close()

step 1, train cost=0.234208, acc=0.156250; test cost=0.243825, acc=0.184375; pass 8.099560022354126s
step 51, train cost=0.092172, acc=0.750000; test cost=0.085505, acc=0.726719; pass 17.221086025238037s
step 101, train cost=0.068696, acc=0.781250; test cost=0.048413, acc=0.844531; pass 17.318379163742065s
step 151, train cost=0.043525, acc=0.890625; test cost=0.045675, acc=0.858750; pass 18.619638442993164s
step 201, train cost=0.034354, acc=0.906250; test cost=0.028370, acc=0.915000; pass 17.23573398590088s
step 251, train cost=0.029636, acc=0.890625; test cost=0.025004, acc=0.921250; pass 16.360023498535156s
train finished
