This post shows how to recognize MNIST handwritten digits with an LSTM in TensorFlow. Original article: https://jasdeep06.github.io/posts/Understanding-LSTM-in-Tensorflow-MNIST/

In [6]:
# requires TensorFlow 1.x: tf.contrib was removed in TensorFlow 2.x
import tensorflow as tf
from tensorflow.contrib import rnn

# load MNIST; one_hot=True encodes each label as a 10-dimensional one-hot vector
from tensorflow.examples.tutorials.mnist import input_data
mnist=input_data.read_data_sets("/tmp/tensorflow/mnist/input_data/",one_hot=True)

# define constants
# each image is unrolled over 28 time steps, one per pixel row
time_steps=28
# number of hidden units in the LSTM
num_units=128
# each time step receives one row of 28 pixels
n_input=28
# learning rate for the Adam optimizer
learning_rate=0.001
# MNIST has 10 classes (digits 0-9)
n_classes=10
# mini-batch size
batch_size=128

Extracting /tmp/tensorflow/mnist/input_data/train-images-idx3-ubyte.gz
Extracting /tmp/tensorflow/mnist/input_data/train-labels-idx1-ubyte.gz
Extracting /tmp/tensorflow/mnist/input_data/t10k-images-idx3-ubyte.gz
Extracting /tmp/tensorflow/mnist/input_data/t10k-labels-idx1-ubyte.gz


In [11]:
mnist.train.images.shape, mnist.train.labels.shape

((55000, 784), (55000, 10))
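Each training image is stored as a flat vector of 784 = 28 × 28 pixels. Before feeding a batch to the LSTM, the training loop reshapes it to `[batch_size, time_steps, n_input]`, so row t of an image becomes the input at time step t. The pure-Python sketch below shows what that reshape does to a single image, assuming row-major pixel order (NumPy's default reshape order); the helper `image_to_sequence` is hypothetical and not part of the original code.

```python
TIME_STEPS = 28  # rows per image, same as time_steps in the notebook
N_INPUT = 28     # pixels per row, same as n_input in the notebook


def image_to_sequence(flat_pixels, time_steps=TIME_STEPS, n_input=N_INPUT):
    """Split a flat row-major image into time_steps rows of n_input pixels.

    Row t of the result is what the LSTM sees at time step t.
    """
    if len(flat_pixels) != time_steps * n_input:
        raise ValueError("expected %d pixels, got %d"
                         % (time_steps * n_input, len(flat_pixels)))
    return [flat_pixels[t * n_input:(t + 1) * n_input] for t in range(time_steps)]


# Dummy 784-pixel "image" whose pixel values equal their flat index
flat = list(range(784))
seq = image_to_sequence(flat)
print(len(seq), len(seq[0]))  # 28 28
print(seq[1][:3])             # [28, 29, 30]
```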

In [12]:
# output layer: maps the last LSTM output (num_units values) to n_classes logits
out_weights=tf.Variable(tf.random_normal([num_units,n_classes]))
out_bias=tf.Variable(tf.random_normal([n_classes]))

# placeholders
# input images, each a sequence of time_steps rows of n_input pixels
x=tf.placeholder("float",[None,time_steps,n_input])
# one-hot labels
y=tf.placeholder("float",[None,n_classes])

In [15]:
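# split the input tensor of shape [batch_size, time_steps, n_input] along axis 1
# into a list of time_steps tensors, each of shape [batch_size, n_input]
# (note: the name `input` shadows Python's built-in input() function)
input=tf.unstack(x, num=time_steps, axis=1)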

In [17]:
x, y, len(input), input[0]

(<tf.Tensor 'Placeholder:0' shape=(?, 28, 28) dtype=float32>,
 <tf.Tensor 'Placeholder_1:0' shape=(?, 10) dtype=float32>,
 28,
 <tf.Tensor 'unstack:0' shape=(?, 28) dtype=float32>)
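The output confirms that `tf.unstack` produced 28 tensors of shape `(?, 28)`: element t holds row t of every image in the batch. The following dependency-free sketch mimics that split on nested lists; `unstack_axis1` is a hypothetical helper written for illustration, not a TensorFlow function.

```python
def unstack_axis1(batch, num):
    """Mimic tf.unstack(x, num=num, axis=1) for a nested list x of shape
    [batch_size, num, n_input]: return num lists of shape [batch_size, n_input]."""
    return [[example[t] for example in batch] for t in range(num)]


# Two "images", each with 3 time steps of 2 features
batch = [
    [[1, 2], [3, 4], [5, 6]],
    [[7, 8], [9, 10], [11, 12]],
]
steps = unstack_axis1(batch, 3)
print(len(steps))  # 3
print(steps[0])    # [[1, 2], [7, 8]]
```

Grouping by time step rather than by example is what lets `static_rnn` consume one list element per unrolled step.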

In [18]:
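# define the network: one LSTM layer with num_units hidden units;
# forget_bias=1.0 is added to the forget gate to reduce forgetting early in training
lstm_layer=rnn.BasicLSTMCell(num_units,forget_bias=1.0)
# static_rnn unrolls the cell over the list of time_steps inputs and returns one
# output per step (each [batch_size, num_units]) plus the final state (unused here)
outputs,_=rnn.static_rnn(lstm_layer,input,dtype=tf.float32)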
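`static_rnn` feeds the 28 row tensors into the LSTM cell one after another, carrying the cell state c and hidden state h between steps. The sketch below writes out one step without TensorFlow, assuming it mirrors the `BasicLSTMCell` update: the pre-activations are split into i, j, f, o, then c' = c·σ(f + forget_bias) + σ(i)·tanh(j) and h' = tanh(c')·σ(o). The helper names `lstm_step` and `static_rnn_sketch`, and the row-major weight layout, are assumptions for illustration only.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def matvec(W, v):
    # W is a list of rows; each row has len(v) entries
    return [sum(w * a for w, a in zip(row, v)) for row in W]


def lstm_step(x, h, c, W, b, forget_bias=1.0):
    """One LSTM step in the BasicLSTMCell style (assumed).

    W: 4*num_units rows, each with len(x) + num_units entries (the transpose
       of a [input + num_units, 4*num_units] kernel); b: 4*num_units biases.
    """
    n = len(h)
    z = [a + bias for a, bias in zip(matvec(W, x + h), b)]
    i, j, f, o = z[:n], z[n:2 * n], z[2 * n:3 * n], z[3 * n:]
    new_c = [ck * sigmoid(fk + forget_bias) + sigmoid(ik) * math.tanh(jk)
             for ck, ik, jk, fk in zip(c, i, j, f)]
    new_h = [math.tanh(ck) * sigmoid(ok) for ck, ok in zip(new_c, o)]
    return new_h, new_c


def static_rnn_sketch(inputs, W, b, num_units, forget_bias=1.0):
    """Unroll lstm_step over a list of inputs from a zero initial state,
    collecting the hidden output of every step."""
    h, c = [0.0] * num_units, [0.0] * num_units
    outputs = []
    for x in inputs:
        h, c = lstm_step(x, h, c, W, b, forget_bias)
        outputs.append(h)
    return outputs, (c, h)


# With all-zero weights every pre-activation is 0, so tanh(j) = 0 and the cell
# state is only scaled by the forget gate sigmoid(0 + forget_bias).
num_units, n_in = 1, 2
W = [[0.0] * (n_in + num_units) for _ in range(4 * num_units)]
b = [0.0] * (4 * num_units)
_, c1 = lstm_step([0.3, -0.7], [0.0], [1.0], W, b, forget_bias=1.0)
_, c0 = lstm_step([0.3, -0.7], [0.0], [1.0], W, b, forget_bias=0.0)
print(round(c1[0], 4), round(c0[0], 4))  # 0.7311 0.5
outputs, _ = static_rnn_sketch([[0.1, 0.2]] * 28, W, b, num_units)
print(len(outputs))  # 28
```

The usage line illustrates why `forget_bias=1.0` helps early in training: with untrained (here, zero) weights the cell keeps about 73% of its state per step instead of 50%.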

In [19]:
# project the last time step's output, shape [batch_size, num_units],
# to class logits of shape [batch_size, n_classes]
prediction=tf.matmul(outputs[-1],out_weights)+out_bias

In [22]:
len(outputs), outputs[0], prediction

(28,
 <tf.Tensor 'rnn/rnn/basic_lstm_cell/mul_2:0' shape=(?, 128) dtype=float32>,
 <tf.Tensor 'add:0' shape=(?, 10) dtype=float32>)
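Of the 28 outputs, only `outputs[-1]` is used: after reading all 28 rows, the final hidden state has to summarize the whole image, and one affine layer turns it into 10 logits. The sketch below spells out that projection for a single example in plain Python; `dense` is a hypothetical helper whose weight layout matches `out_weights` (`[num_units, n_classes]`).

```python
def dense(h, W, b):
    """Affine projection logits = h @ W + b for one example.

    h: num_units values (the last LSTM output);
    W: num_units rows of n_classes weights, laid out like out_weights;
    b: n_classes biases, like out_bias.
    """
    return [sum(h[u] * W[u][k] for u in range(len(h))) + b[k]
            for k in range(len(b))]


# Tiny example: 2 hidden units projected to 3 classes
h_last = [1.0, 2.0]
W = [[1.0, 0.0, -1.0],
     [0.0, 1.0, 1.0]]
b = [0.5, 0.0, 0.0]
print(dense(h_last, W, b))  # [1.5, 2.0, 1.0]
```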

In [23]:
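# loss: mean softmax cross-entropy between the logits and the one-hot labels
loss=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction,labels=y))
# optimizer: Adam minimizing the loss
opt=tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)

# evaluation: fraction of examples whose predicted class (argmax of the logits)
# matches the labelled class
correct_prediction=tf.equal(tf.argmax(prediction,1),tf.argmax(y,1))
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))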
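To make the loss and accuracy definitions concrete, here is a small pure-Python sketch of the same math on a batch of two examples. It is intended to mirror what softmax cross-entropy and the argmax comparison compute, not TensorFlow's implementation; the function names are hypothetical.

```python
import math


def softmax_cross_entropy(logits, one_hot):
    """-sum_k y_k * log(softmax(logits)_k), using the log-sum-exp trick
    (subtract the largest logit) for numerical stability."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return -sum(y * (z - log_sum_exp) for z, y in zip(logits, one_hot))


def argmax(v):
    # index of the largest entry (the first one on ties)
    return max(range(len(v)), key=lambda k: v[k])


def loss_and_accuracy(batch_logits, batch_labels):
    """Mean cross-entropy and fraction of correct argmax predictions."""
    losses = [softmax_cross_entropy(z, y) for z, y in zip(batch_logits, batch_labels)]
    hits = [argmax(z) == argmax(y) for z, y in zip(batch_logits, batch_labels)]
    return sum(losses) / float(len(losses)), sum(hits) / float(len(hits))


logits = [[2.0, 0.0, 0.0],   # predicts class 0, label is 0: correct
          [0.0, 1.0, 3.0]]   # predicts class 2, label is 1: wrong
labels = [[1, 0, 0],
          [0, 1, 0]]
loss, acc = loss_and_accuracy(logits, labels)
print(acc)  # 0.5
```

Note that the loss still rewards confident correct predictions and penalizes confident wrong ones, while accuracy only counts whether the argmax matches.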

In [28]:
# initialize variables
init=tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    for step in range(1, 801):
        batch_x,batch_y=mnist.train.next_batch(batch_size=batch_size)
        # reshape each flat 784-pixel image into time_steps rows of n_input pixels
        batch_x=batch_x.reshape((batch_size,time_steps,n_input))

        sess.run(opt, feed_dict={x: batch_x, y: batch_y})

        # every 100 steps, report accuracy and loss on the current training batch
        if step % 100==0:
            acc,los=sess.run([accuracy,loss],feed_dict={x:batch_x,y:batch_y})
            print("For iter ",step)
            print("Accuracy ",acc)
            print("Loss ",los)
            print("__________________")

    # test accuracy on the first 1280 of the 10,000 test images
    test_data = mnist.test.images[:1280].reshape((-1, time_steps, n_input))
    test_label = mnist.test.labels[:1280]
    print("Testing Accuracy:", sess.run(accuracy, feed_dict={x: test_data, y: test_label}))

('For iter ', 100)
('Accuracy ', 0.828125)
('Loss ', 0.46741468)
__________________
('For iter ', 200)
('Accuracy ', 0.9375)
('Loss ', 0.19848692)
__________________
('For iter ', 300)
('Accuracy ', 0.921875)
('Loss ', 0.21564379)
__________________
('For iter ', 400)
('Accuracy ', 0.9453125)
('Loss ', 0.18260565)
__________________
('For iter ', 500)
('Accuracy ', 0.9453125)
('Loss ', 0.15409306)
__________________
('For iter ', 600)
('Accuracy ', 0.9609375)
('Loss ', 0.14374393)
__________________
('For iter ', 700)
('Accuracy ', 0.9375)
('Loss ', 0.15059316)
__________________
('For iter ', 800)
('Accuracy ', 0.984375)
('Loss ', 0.063490272)
__________________
('Testing Accuracy:', 0.9609375)
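A little arithmetic on the numbers above puts the result in context: 800 steps of 128 examples is fewer than two passes over the 55,000 training images, and the reported test accuracy is measured on only 1,280 test images, where 0.9609375 corresponds to 1,230 correct predictions. The snippet below just recomputes these figures; it does not evaluate the model, and accuracy on the full 10,000-image test set may differ somewhat.

```python
iterations = 800        # training steps in the loop above
batch_size = 128
n_train = 55000         # mnist.train.images.shape[0], printed earlier
n_test_used = 1280      # mnist.test.images[:1280]
test_accuracy = 0.9609375

examples_seen = iterations * batch_size
epochs = examples_seen / float(n_train)
n_correct = int(round(test_accuracy * n_test_used))
print(examples_seen, round(epochs, 2), n_correct)  # 102400 1.86 1230
```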


Not bad: accuracy on the first 1280 images of the test set reached 96.09%.