# 使用RNN训练一个古诗词生成器

## 1.数据预处理

In [3]:
import codecs
from collections import Counter

# 1.过滤
poetrys = []
with codecs.open('./poetry.txt', encoding='utf-8') as fr:
    for line in fr:
        try:
            title, content = line.strip().split(':')
            content = content.replace(' ', '')
            if u'（' in content or u'(' in content or u'《' in content or u'_' in content or u'[' in content:
                continue
            if len(content) < 5 or len(content) > 80:
                continue
            content = '[' + content + ']'
            poetrys.append(content)
        except Exception as e:
            pass
print u'唐诗总数：', len(poetrys)
poetrys = sorted(poetrys, key=lambda x: len(x))

# 2.构建word和id的双向映射表
cnt = Counter(''.join(poetrys)).most_common()
words_sorted_by_freq = zip(*cnt)[0] + (' ', )
print '不重复字总数：', len(words_sorted_by_freq)
id2word = dict(enumerate(words_sorted_by_freq))
word2id = dict(zip(id2word.values(), id2word.keys()))

# 3.将诗歌转为id向量形式
poetrys_vec = [[word2id[word] for word in poetry] for poetry in poetrys]
word_nums = len(words_sorted_by_freq)
print poetrys_vec[1000]

唐诗总数： 34813
不重复字总数： 6122
[3, 279, 573, 114, 422, 973, 0, 487, 104, 468, 872, 1209, 1, 10, 14, 226, 212, 3451, 0, 31, 49, 98, 6, 12, 1, 2]


## 2.定义batch函数

In [4]:
import numpy as np
def next_batch(data, batch_size, num_steps, pad_id=None):
    """Yield padded ``(x, y)`` training batches from id-encoded sequences.

    ``data`` is cut into consecutive groups of ``batch_size`` sequences; a
    trailing group with fewer than ``batch_size`` sequences is dropped.
    Inside a batch every sequence is right-padded with ``pad_id`` to the
    length of the longest sequence of that batch.

    Args:
        data: Sequence of id sequences (lists of ints).
        batch_size: Number of sequences per batch.
        num_steps: Maximum number of batches to yield; ``-1`` yields all of
            them.
        pad_id: Id used to fill the padding positions. ``None`` (default)
            uses the module-level ``word2id[' ']``.

    Yields:
        Tuple ``(x, y)`` of ``np.int32`` arrays of shape
        ``(batch_size, longest_sequence_in_batch)``. ``y`` is ``x`` shifted
        one position to the left; its last column repeats the last column
        of ``x``.
    """
    # Floor division keeps the batch count an int even under true division.
    chunk_nums = len(data) // batch_size
    batch_len = chunk_nums if num_steps == -1 else min(chunk_nums, num_steps)
    for i in range(batch_len):
        batches = data[i * batch_size:(i + 1) * batch_size]
        feature_nums = max(len(seq) for seq in batches)
        # Resolve the default lazily so that the global vocabulary is only
        # needed when a batch is actually produced.
        fill = word2id[' '] if pad_id is None else pad_id
        x = np.full((batch_size, feature_nums), fill, np.int32)
        for row, poetry in enumerate(batches):
            x[row, :len(poetry)] = poetry
        # Targets are the inputs shifted by one step (next-character
        # prediction).
        y = np.copy(x)
        y[:, :-1] = x[:, 1:]
        yield x, y

思考：对于RNN来讲，**rnn_size(hidden state size)** 与输入序列中每个**元素**的嵌入向量维度相同（此处与字的嵌入维度相关，比如使用BOW模型）。但是由于各序列的长度长短不一，这不利于RNN的训练，因此**一般**都会对输入序列做 **padding** 操作，使序列长度均为 **time_step** 步。这样就带来两种选择：1. 手动做 padding 处理；2. 使用 tensorflow 做处理（**dynamic_rnn 甚至可以动态构建，而不需要 padding**）。

具体信息，请参见：
 - http://www.wildml.com/2016/08/rnns-in-tensorflow-a-practical-guide-and-undocumented-features/
 - http://r2rt.com/recurrent-neural-networks-in-tensorflow-ii.html

## 3.定义RNN网络

### 3.1.引入依赖包

In [5]:
import numpy as np
import tensorflow as tf

def reset_graph():
    """Close a module-level ``sess`` (if one exists and is truthy), then reset the default TF graph."""
    stale_session = globals().get('sess')
    if stale_session:
        stale_session.close()
    tf.reset_default_graph()

# reset_graph()
# tf.contrib.seq2seq.sequence_loss?
# tf.variable_scope?

### 3.2.定义RNN网络模型

In [6]:
def MultiLayerRNN(cell_type='lstm', hstate_size=128, layer_nums=2, learning_rate=1e-4, batch_size=64):
    """Build the multi-layer character RNN in the default TensorFlow graph.

    The vocabulary size is read from the module-level ``word_nums``.

    Args:
        cell_type: Recurrent cell to use: ``'lstm'``, ``'gru'`` or ``'rnn'``.
        hstate_size: Number of units per recurrent layer; the embedding
            vectors use the same width.
        layer_nums: Number of stacked recurrent layers.
        learning_rate: Learning rate passed to the Adam optimizer.
        batch_size: Batch size used to build the zero initial state.

    Returns:
        dict with the graph handles ``x``, ``y`` (int32 placeholders),
        ``initial_cstate``, ``final_cstate``, ``probs``, ``total_loss``,
        ``train_step`` and ``saver``.

    Raises:
        ValueError: If ``cell_type`` is not one of the supported names.
    """
    # Only the LSTM cell takes ``state_is_tuple``; passing it to the GRU or
    # basic RNN constructors fails, so each type gets its own factory.
    cell_factories = {
        'lstm': lambda: tf.contrib.rnn.BasicLSTMCell(hstate_size, state_is_tuple=True),
        'gru': lambda: tf.contrib.rnn.GRUCell(hstate_size),
        'rnn': lambda: tf.contrib.rnn.BasicRNNCell(hstate_size),
    }
    if cell_type not in cell_factories:
        raise ValueError(
            "cell_type must be one of 'lstm', 'gru', 'rnn'; got %r" % (cell_type,))
    make_cell = cell_factories[cell_type]

    # First dimension is the batch, second the sequence length (time steps);
    # both are left dynamic.
    x = tf.placeholder(tf.int32, [None, None])
    y = tf.placeholder(tf.int32, [None, None])

    # 1. Embedding layer: ids -> dense vectors of width hstate_size.
    with tf.variable_scope('embedding'):
        embeddings = tf.get_variable('embedding_matrix', [word_nums, hstate_size])
        rnn_inputs = tf.nn.embedding_lookup(embeddings, x)

    # 2. Recurrent layers. A separate cell object is created per layer:
    # repeating one instance ([cell] * layer_nums) is rejected by
    # TensorFlow >= 1.1 because the layers would share one set of weights.
    cell = tf.contrib.rnn.MultiRNNCell(
        [make_cell() for _ in range(layer_nums)], state_is_tuple=True)
    initial_cstate = cell.zero_state(batch_size, tf.float32)
    rnn_outputs, final_cstate = tf.nn.dynamic_rnn(cell, rnn_inputs, initial_state=initial_cstate)

    with tf.variable_scope('softmax'):
        W = tf.get_variable('W', [hstate_size, word_nums])
        b = tf.get_variable('b', [word_nums], initializer=tf.constant_initializer(0.0))

    # Flatten outputs and targets so the logits come from a single matmul.
    rnn_outputs = tf.reshape(rnn_outputs, [-1, hstate_size])
    y_reshaped = tf.reshape(y, [-1])

    # 3. Output layer and loss (mean cross entropy over every position).
    logits = tf.matmul(rnn_outputs, W) + b
    probs = tf.nn.softmax(logits)
    total_loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=y_reshaped))

    # Gradients are computed by hand so they can be clipped to a global norm
    # of 5 before Adam applies them.
    tvars = tf.trainable_variables()
    grads = tf.gradients(total_loss, tvars)
    clipped_grads, _ = tf.clip_by_global_norm(grads, 5)
    opti = tf.train.AdamOptimizer(learning_rate)
    train_step = opti.apply_gradients(list(zip(clipped_grads, tvars)))

    saver = tf.train.Saver(tf.global_variables())

    return dict(
        x=x,
        y=y,
        initial_cstate=initial_cstate,
        final_cstate=final_cstate,
        probs=probs,
        total_loss=total_loss,
        train_step=train_step,
        saver=saver,
    )

### 3.3.定义训练RNN网络

In [9]:
def train_network(g, num_epochs, num_steps = -1, batch_size = 64, verbose = True, save=False):
    """Train the graph ``g`` on the module-level ``poetrys_vec`` corpus.

    Within an epoch the final RNN state of one batch is fed back as the
    initial state of the next batch; the state is reset at the start of
    every epoch.

    Args:
        g: Graph-handle dict as returned by ``MultiLayerRNN``.
        num_epochs: Number of passes over the data.
        num_steps: Maximum number of batches per epoch, ``-1`` for all.
        batch_size: Sequences per batch. NOTE(review): presumably has to
            match the batch size ``g`` was built with, since the fed state
            must fit the graph's initial state - confirm.
        verbose: Print the average loss after every epoch.
        save: Write a checkpoint after every 5th epoch. Files are named
            ``poetry_gen_model-<epoch>`` using the zero-based epoch index
            (``-4``, ``-9``, ...).

    Returns:
        List with the average training loss of each epoch.

    Raises:
        ValueError: If an epoch produces no batches, which would otherwise
            end in a division by zero when averaging the loss.
    """
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        training_losses = []
        for epoch in range(num_epochs):
            training_loss = 0
            steps = 0
            training_state = None
            for X, Y in next_batch(poetrys_vec, batch_size, num_steps):
                steps += 1
                feed_dict = {g['x']: X, g['y']: Y}
                # The very first batch of an epoch starts from the graph's
                # zero state; later batches continue from the previous one.
                if training_state is not None:
                    feed_dict[g['initial_cstate']] = training_state
                training_loss_, training_state, _ = sess.run(
                    [g['total_loss'], g['final_cstate'], g['train_step']],
                    feed_dict)
                training_loss += training_loss_
            if steps == 0:
                raise ValueError(
                    'next_batch produced no batches (batch_size=%r, num_steps=%r)'
                    % (batch_size, num_steps))
            epoch_loss = training_loss / steps
            if verbose:
                print("Average training loss for Epoch {}/{}:{}".format(epoch+1, num_epochs, epoch_loss))
            training_losses.append(epoch_loss)

            if save and (epoch+1) % 5 == 0:
                g['saver'].save(sess, 'poetry_gen_model', global_step=epoch)

    return training_losses

### 3.5.训练RNN

In [10]:
# Start from a clean default graph, build the model with its default
# hyper-parameters, then train for 100 epochs, checkpointing every 5th epoch.
reset_graph()
g = MultiLayerRNN()
train_network(g, 100, save=True)

Average training loss for Epoch 1/100:6.86393547409
Average training loss for Epoch 2/100:6.42438518584
Average training loss for Epoch 3/100:6.41937117111
Average training loss for Epoch 4/100:6.41712725272
Average training loss for Epoch 5/100:6.4161247847
Average training loss for Epoch 6/100:6.4156662667
Average training loss for Epoch 7/100:6.41529980285
Average training loss for Epoch 8/100:6.4113553769
Average training loss for Epoch 9/100:6.36540350642
Average training loss for Epoch 10/100:6.30113961025
Average training loss for Epoch 11/100:6.14469979306
Average training loss for Epoch 12/100:6.03695265189
Average training loss for Epoch 13/100:5.98728574661
Average training loss for Epoch 14/100:5.94861563705
Average training loss for Epoch 15/100:5.89806809873
Average training loss for Epoch 16/100:5.82521183934
Average training loss for Epoch 17/100:5.78797029702
Average training loss for Epoch 18/100:5.76041570597
Average training loss for Epoch 19/100:5.73833363798
Avera

[6.8639354740936671,
 6.424385185838843,
 6.4193711711117798,
 6.4171272527227528,
 6.4161247847049498,
 6.4156662666995219,
 6.4152998028539159,
 6.4113553769022058,
 6.3654035064177403,
 6.3011396102483763,
 6.1446997930551319,
 6.0369526518862111,
 5.9872857466147966,
 5.948615637053881,
 5.8980680987321215,
 5.825211839342205,
 5.7879702970248337,
 5.7604157059671248,
 5.7383336379822225,
 5.7199812154963094,
 5.6996603811204327,
 5.6804684292984708,
 5.6583372874813183,
 5.6377923317377077,
 5.6239160894249904,
 5.6139115400296768,
 5.5974927738885194,
 5.5840448700920655,
 5.5702980254016969,
 5.5572971533675224,
 5.5435449317454415,
 5.5307136477686427,
 5.5186498468093452,
 5.5063811182756233,
 5.4955081404023849,
 5.4840180579689548,
 5.4729930401726543,
 5.4635661003997971,
 5.4525784515324656,
 5.4424307289264036,
 5.4321890676438702,
 5.4218463774763634,
 5.4110386894113667,
 5.3999513130820258,
 5.388680164765697,
 5.3774263274823326,
 5.3665183675003751,
 5.35539175704478

### 3.6.诗词生成（RNN模型重用）

In [168]:
import numpy as np

def gen_poetry(checkpoint='./poetry_gen_model-55', max_len=None):
    """Sample one poem, character by character, from a trained checkpoint.

    The graph is rebuilt with ``batch_size=1``, the checkpoint is restored,
    the start marker ``'['`` is fed first, and characters are then sampled
    from the predicted distribution until the end marker ``']'`` is drawn.

    Args:
        checkpoint: Path prefix of the checkpoint to restore.
        max_len: Optional upper bound on the number of generated
            characters; ``None`` (default) keeps sampling until ``']'``.

    Returns:
        The generated poem without the surrounding ``'['`` / ``']'``.
    """

    def prob2word(weights):
        """Draw one character with probability proportional to ``weights``."""
        cum_weights = np.cumsum(weights)
        # Scale the random draw by the cumulative total itself: np.sum may
        # round differently from the running sum, which could push the
        # threshold past the last bucket.
        threshold = np.random.rand() * cum_weights[-1]
        idx = int(np.searchsorted(cum_weights, threshold))
        # Guard against landing one past the end through float rounding.
        idx = min(idx, len(cum_weights) - 1)
        return id2word[idx]

    reset_graph()
    g = MultiLayerRNN(batch_size=1)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        g['saver'].restore(sess, checkpoint)

        # Prime the network with the start-of-poem marker.
        X = np.zeros((1, 1), dtype=np.int32)
        X[0, 0] = word2id['[']
        probs_, state_ = sess.run([g['probs'], g['final_cstate']], feed_dict={g['x']: X})
        word = prob2word(probs_)

        chars = []
        while word != ']' and (max_len is None or len(chars) < max_len):
            chars.append(word)
            # Feed the sampled character back in, continuing from the
            # state left by the previous step.
            X[0, 0] = word2id[word]
            feed_dict = {g['x']: X, g['initial_cstate']: state_}
            probs_, state_ = sess.run([g['probs'], g['final_cstate']], feed_dict=feed_dict)
            word = prob2word(probs_)
        return ''.join(chars)
    
# Sample one poem from the restored model and print it.
print gen_poetry()

生东山市雀客，四邻来四邻。


### 3.7.藏头诗生成

In [162]:
import numpy as np

def gen_poetry_with_head(heads, checkpoint='./poetry_gen_model-59', max_line_len=None):
    """Generate an acrostic poem: line *i* starts with ``heads[i]``.

    Every head character opens a line; further characters are sampled from
    the model until it emits ``'，'`` or ``'。'``. Lines are then closed
    alternately with ``'，'`` and ``'。'`` plus a newline. The RNN state is
    carried over from one line to the next.

    Args:
        heads: Unicode string with one leading character per line.
        checkpoint: Path prefix of the checkpoint to restore.
        max_line_len: Optional upper bound on the characters of one line
            (head character included); ``None`` (default) means no bound.

    Returns:
        The generated poem.

    Raises:
        KeyError: If a head character is not in the vocabulary ``word2id``.
    """

    def prob2word(weights):
        """Draw one character with probability proportional to ``weights``."""
        cum_weights = np.cumsum(weights)
        # Scale by the cumulative total itself so the threshold can never
        # exceed the last bucket, and clamp against float rounding.
        threshold = np.random.rand() * cum_weights[-1]
        idx = int(np.searchsorted(cum_weights, threshold))
        idx = min(idx, len(cum_weights) - 1)
        return id2word[idx]

    reset_graph()
    g = MultiLayerRNN(batch_size=1)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        g['saver'].restore(sess, checkpoint)

        state_ = None
        pieces = []
        X = np.zeros((1, 1), dtype=np.int32)
        for i, word in enumerate(heads):
            line_len = 0
            while word != u'，' and word != u'。':
                if max_line_len is not None and line_len >= max_line_len:
                    break
                pieces.append(word)
                line_len += 1
                X[0, 0] = word2id[word]
                feed_dict = {g['x']: X}
                # The first step starts from the graph's zero state.
                if state_ is not None:
                    feed_dict[g['initial_cstate']] = state_
                probs_, state_ = sess.run([g['probs'], g['final_cstate']], feed_dict=feed_dict)
                word = prob2word(probs_)
            # Even lines end with a comma, odd lines end the couplet.
            pieces.append(u'，' if i % 2 == 0 else u'。\n')
        return ''.join(pieces)
    
# Generate and print an acrostic poem whose lines start with these characters.
print gen_poetry_with_head(u'一二三四')

一战眼，二壁。
三守，四。



### 3.8.网络服务

In [None]:
import web

# Template renderer bound to the local 'template/' directory.
render = web.template.render('template/')
# URL routing table: the site root '/' is mapped to the handler named 'index'.
urls = (
    '/', 'index'
)

class index:
    """web.py handler that renders an acrostic poem for the ``heads`` query parameter."""

    def GET(self):
        """Generate a poem from ``?heads=...`` and render it with the index template."""
        # web.input is a function and has to be called to obtain the request
        # parameters; the default keeps a request without ``heads`` from
        # raising AttributeError.
        heads = web.input(heads=u'').heads
        poem = gen_poetry_with_head(heads)
        return render.index(poem)
    
if __name__ == '__main__':
    # Build the web.py application from the routing table; this module's
    # globals are passed along with it (presumably so the handler name in
    # `urls` can be resolved to the class defined here).
    app = web.application(urls, globals())
    app.run()