# Acrostic Poem Generation with CharRNN

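Thanks to their structure and design, RNNs are naturally suited to processing text.
In this example we show how, given the first character of each line, a character-level RNN can generate a well-formed, rhyming acrostic poem. Besides poetry, char-rnn can also generate many other kinds of content with sequential structure, including plain text, articles written in TeX syntax, and code; see [this blog post](http://karpathy.github.io/2015/05/21/rnn-effectiveness/) for details.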

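### 1. Network architecture
The charRNN network used in this example has the following structure:
![char-rnn](img/char-rnn.svg)
It is built by stacking two layers of RNN cells. At each time step the input is a single character; in the figure, for instance, the first input character is “床”. The two RNN layers turn it into a vector, which is passed through a dense layer and a softmax to give a probability distribution over the whole vocabulary. During training the expected output is “前”, and cross-entropy is used directly as the loss function; “前” then serves as the input at the next time step, producing the next character. When generating after training, one of the top-K most probable characters can be chosen at each step to make the results more diverse.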
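
The top-K selection mentioned above can be sketched in plain Python. This is a minimal illustration rather than part of the original code: the function name `sample_top_k` and the example logits are assumptions, and the `predict` function defined later in this example simply takes the arg-max instead.

```python
import math
import random


def sample_top_k(logits, k=3, rng=random):
    """Sample one index from the k highest-scoring entries of `logits`."""
    # Indices of the k largest logits, best first.
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # Softmax restricted to those k entries (shifted by the max for stability).
    max_logit = logits[top[0]]
    weights = [math.exp(logits[i] - max_logit) for i in top]
    # Draw one index in proportion to its renormalised probability.
    return rng.choices(top, weights=weights, k=1)[0]


logits = [0.1, 2.0, -1.0, 1.5, 0.3]  # made-up scores over a 5-character vocabulary
rng = random.Random(0)
picks = {sample_top_k(logits, k=2, rng=rng) for _ in range(100)}
print(picks)  # always a subset of {1, 3}, the two highest-scoring indices
```

With `k=1` this reduces to the arg-max; larger `k` trades some coherence for variety.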

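### 2. Data preparation
GitHub already hosts many well-maintained collections of classical Chinese poetry; here we use the poems provided by [chinese-poetry](https://github.com/chinese-poetry/chinese-poetry) as training samples.

The poems in the repository are all in Traditional Chinese and are stored as JSON in the format shown below.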
>[{
    "strains": [
      "仄仄平平仄，平平仄仄平。", 
      "仄平平仄仄，平仄仄平平。"
    ],    
    "author": "蓋嘉運", 
    "paragraphs": [
      "聞道黃花戍，頻年不解兵。", 
      "可憐閨裏月，偏照漢家營。"
    ],    
    "title": "雜曲歌辭 伊州 歌第三"
  }, 
  {
    "strains": [
      "平仄平平仄，平平仄仄平。", 
      "仄平平仄仄，平仄仄平平。"
    ],    
    "author": "蓋嘉運", 
    "paragraphs": [
      "千里東歸客，無心憶舊遊。", 
      "挂帆游白水，高枕到青州。"
    ],    
    "title": "雜曲歌辭 伊州 歌第四"
  }, 
  {
    "strains": [
      "仄仄平平仄，平平仄仄平。", 
      "平平平仄仄，仄仄仄平平。"
    ],    
    "author": "蓋嘉運", 
    "paragraphs": [
      "桂殿江烏對，彫屏海燕重。", 
      "秖應多釀酒，醉罷樂高鐘。"
    ],    
    "title": "雜曲歌辭 伊州 歌第五"
  },
  ....
]

We first define a function that extracts the five-character poems we need:

In [1]:
import json
import os

def extract_poems(to_file,
                  origin_poetry_dir,
                  poetry_type=5):
    """Write every two-couplet poem with `poetry_type` characters per line to `to_file`."""
    # A couplet holds two lines, each followed by one punctuation mark.
    couplet_length = (poetry_type + 1) * 2
    with open(to_file, "w", encoding="utf-8") as fw:
        filenames = [name for name in os.listdir(origin_poetry_dir)
                     if name.startswith("poet")]
        for filename in filenames:
            input_file = os.path.join(origin_poetry_dir, filename)
            with open(input_file, encoding="utf-8") as fi:
                poems = json.load(fi)
            for poem in poems:
                paragraphs = poem["paragraphs"]
                if len(paragraphs) == 2 and all(len(line) == couplet_length for line in paragraphs):
                    for two_lines in paragraphs:
                        # One character per token, separated by spaces; "\n" is the last token.
                        fw.write(" ".join(two_lines) + " \n")
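
To make the filter and the output format concrete, here is a small self-contained sketch that applies the same length check and the same space-separated formatting to two in-memory records. The helper names `is_wanted` and `to_training_line` are introduced only for illustration, and the second record is a made-up seven-character placeholder, not real data from the repository.

```python
import json

# First record: a couplet pair from the sample above.
# Second record: a hypothetical seven-character placeholder.
sample = json.loads("""[
  {"paragraphs": ["聞道黃花戍，頻年不解兵。", "可憐閨裏月，偏照漢家營。"]},
  {"paragraphs": ["甲乙丙丁戊己庚，甲乙丙丁戊己庚。", "甲乙丙丁戊己庚，甲乙丙丁戊己庚。"]}
]""")


def is_wanted(paragraphs, poetry_type=5):
    """Same test as in extract_poems: two couplets of the expected length."""
    expected = (poetry_type + 1) * 2
    return len(paragraphs) == 2 and all(len(p) == expected for p in paragraphs)


def to_training_line(couplet):
    """Same formatting as in extract_poems: characters separated by spaces."""
    return " ".join(couplet) + " \n"


kept = [poem["paragraphs"] for poem in sample if is_wanted(poem["paragraphs"])]
tokens = to_training_line(kept[0][0]).split(" ")
print(len(kept), len(tokens), repr(tokens[-1]))
```

Each training line therefore splits into 13 tokens: 12 characters (including punctuation) followed by the newline token.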

Inside the network every character has to be embedded as a vector. In this example we use Word2vec embedding vectors as the initial values of the char-rnn embedding layer.

In [2]:
import numpy as np
from gensim.models import Word2Vec

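def word2vec(input_file):
    with open(input_file, encoding="utf-8") as fi:
        lines = fi.readlines()
    # Each line is a plain string, so Word2Vec treats every character in it
    # (including " " and "\n") as one token.
    # Note: `size` is the argument name in gensim < 4.0 (later renamed `vector_size`).
    model = Word2Vec(sentences=lines, size=128, min_count=1)
    return model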

We also define a DataSet class to make it easy to draw batches and to convert between indices and words.

In [3]:
class DataSet:
    def __init__(self, lines: list, index2word: list):
        self._index2word = index2word
        self._word2index = {word: i for (i, word) in enumerate(index2word)}
        self._index_in_epoch = 0
        self._num_examples = len(lines)
        self._epochs_completed = 0

        lines = map(lambda line: [self._word2index[word] for word in line.split(" ")], lines)
        self._data = np.array(list(lines))

    def next_batch(self, batch_size):
        start = self._index_in_epoch
        self._index_in_epoch += batch_size
        if self._index_in_epoch > self._num_examples:
            # Finished epoch
            self._epochs_completed += 1
            # Shuffle the data
            perm = np.arange(self._num_examples)
            np.random.shuffle(perm)
            self._data = self._data[perm]
            # Start next epoch
            start = 0
            self._index_in_epoch = batch_size
            assert batch_size <= self._num_examples
        end = self._index_in_epoch
        return self._data[start:end]

    def label2word(self, label):
        return self._index2word[label]

    def word2label(self, word):
        return self._word2index[word]
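
The index/word conversion that `DataSet` performs can be illustrated without NumPy. The sketch below is an assumption-laden toy: it builds the vocabulary from a single line by sorting its tokens, whereas the example itself takes the vocabulary order from the Word2vec model.

```python
# One training line in the format written by extract_poems.
lines = ["床 前 明 月 光 ， 疑 是 地 上 霜 。 \n"]

# Toy vocabulary (assumption: sorted unique tokens instead of the Word2vec order).
index2word = sorted({token for line in lines for token in line.split(" ")})
word2index = {word: i for i, word in enumerate(index2word)}

# Encode every line as a list of integer labels, as DataSet.__init__ does.
encoded = [[word2index[token] for token in line.split(" ")] for line in lines]

# Decoding with index2word recovers the original text.
decoded = [" ".join(index2word[i] for i in row) for row in encoded]
print(len(index2word), len(encoded[0]), decoded == lines)
```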

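### 3. Network definition
In our char-rnn network, the placeholders first of all include the input characters `input_ids` and the labels `input_labels`. To accommodate different batch sizes (training usually uses a batch larger than 1, whereas prediction typically uses a batch of 1), we also define a placeholder for the batch size. Because TensorFlow builds a static graph, prediction has to feed the hidden state from the previous step back in as input, so a placeholder for the hidden state is needed as well.

In this char-rnn each training sample is one pair of lines (a couplet). This preserves some of the connection between the two lines while avoiding the overly long dependencies that come with using a whole poem.

Some parameters and placeholders are defined as follows: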
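
As a small illustration of how one couplet becomes a training sample, the sketch below splits a 13-token line into a 12-token input and a 12-token label sequence shifted by one position, which is the slicing the training loop later applies to each batch. The variable names are chosen for illustration only.

```python
INPUT_LENGTH = 12

# One couplet as stored in the training file: 12 characters plus the newline token.
tokens = "床 前 明 月 光 ， 疑 是 地 上 霜 。 \n".split(" ")

inputs = tokens[:-1]  # what the network sees at each time step
labels = tokens[1:]   # what it should predict at each time step

for x, y in zip(inputs, labels):
    print(repr(x), "->", repr(y))
```

The last label is the newline token, so the network also learns where a couplet ends.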

In [4]:
import tensorflow as tf
from tensorflow.contrib.rnn import BasicLSTMCell, MultiRNNCell


ORIGIN_POETRY_DIR = "../data/chinese-poetry/json"
EXTRACTED_POETRY_FILE = "../data/train_poetry.txt"
TRAIN_FILE = "../data/train_poetry_s.txt"

EMBEDDING_SIZE = 128
HIDDEN_SIZE = 256
BATCH_SIZE = 128
STACK_LAYER = 2
INPUT_LENGTH = 12


input_ids = tf.placeholder(tf.int32, shape=[None, INPUT_LENGTH], name="input_id")
input_labels = tf.placeholder(tf.int32, shape=[None, INPUT_LENGTH], name="input_label")
input_batch_size = tf.placeholder(tf.int32, shape=())


# Use a standard LSTM as the basic RNN cell
stacked_cells = BasicLSTMCell(num_units=HIDDEN_SIZE, state_is_tuple=False)
if STACK_LAYER > 1:
    stacked_cells = MultiRNNCell(
        [BasicLSTMCell(HIDDEN_SIZE, state_is_tuple=False) for _ in range(STACK_LAYER)],
        state_is_tuple=False)
initial_zero_states = stacked_cells.zero_state(input_batch_size, tf.float32)
initial_states = tf.placeholder(tf.float32, initial_zero_states.shape)



Define the inference process:

In [5]:
def inference(input_ids, batch_size, initial_state, initial_embeddings,
              stacked_cells, hidden_size=HIDDEN_SIZE, input_length=INPUT_LENGTH):

    vocab_size, embedding_size = initial_embeddings.shape

    # embedding layer
    with tf.name_scope("embedding"):
        with tf.device("/cpu:0"):
            # Start from the Word2vec vectors, as described above, and fine-tune them.
            embeddings = tf.get_variable(name="embedding",
                                         initializer=tf.constant(initial_embeddings,
                                                                 dtype=tf.float32))
            # 3D array [batch, time_stamp, embedding_feature]
            input_tensor = tf.nn.embedding_lookup(embeddings, input_ids)

    # rnn layer
    with tf.variable_scope("rnn") as vs:
        # transpose to [time_stamp, batch, embedding_feature]
        input_tensor = tf.transpose(input_tensor, perm=[1, 0, 2])

        state = initial_state
        states = []
        outputs = []
        for time in range(input_length):
            input_this_time = input_tensor[time]
            if time > 0:
                vs.reuse_variables()
            # output shape: [batch_size, hidden_size]
            output, state = stacked_cells(input_this_time, state)
            outputs.append(output)
            states.append(state)
    
    # softmax layer
    with tf.name_scope("softmax"):
        # [batch, hidden_size * time]
        outputs = tf.concat(outputs, axis=1)
        # [batch * time, hidden_size]
        outputs = tf.reshape(outputs, [-1, hidden_size])
        w = tf.Variable(initial_value=tf.truncated_normal(shape=[hidden_size, vocab_size]), name="w")
        b = tf.Variable(initial_value=tf.zeros([vocab_size]), name="b")
        logits = tf.matmul(outputs, w) + b
        logits = tf.reshape(logits, [batch_size, input_length, vocab_size])

    return logits, states[0]
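
The concat-then-reshape step in the softmax layer is easy to get wrong, so here is a plain-Python sketch (nested lists standing in for tensors, tiny made-up sizes) that checks the row order it produces: row `b * TIME + t` of the flattened matrix is the output of sample `b` at time step `t`, which is what the final reshape to `[batch, time, vocab]` relies on.

```python
BATCH, TIME, HIDDEN = 2, 3, 4

# outputs[t][b] stands for the hidden vector of sample b at time step t;
# it is filled with the tag (t, b) so each vector is easy to recognise.
outputs = [[[(t, b)] * HIDDEN for b in range(BATCH)] for t in range(TIME)]

# Equivalent of tf.concat(outputs, axis=1): chain each sample's vectors over time.
concatenated = [[v for t in range(TIME) for v in outputs[t][b]]
                for b in range(BATCH)]

# Equivalent of tf.reshape(..., [-1, HIDDEN]): read row-major, HIDDEN values at a time.
flat = [v for row in concatenated for v in row]
rows = [flat[i:i + HIDDEN] for i in range(0, len(flat), HIDDEN)]

# Every row lands where the final reshape expects it.
for b in range(BATCH):
    for t in range(TIME):
        assert rows[b * TIME + t] == outputs[t][b]
print(len(rows), "rows of width", HIDDEN)
```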

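At prediction time what we need is the output of the RNN cell at the first time step. Because of TensorFlow's static graph, after each run we take the first output and state and feed them back in as the input of the next run, so the prediction function looks like this: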

In [10]:
def predict(head_words, sess, logits, first_out_state, input_pl, batch_size_pl, initial_states_pl,
            zero_states, dataset, output_length=INPUT_LENGTH):
    ret_list = ["1 ", "2 ", "3 ", "4 "]
    for (i, head_word) in enumerate(head_words):
        if i % 2 != 0:
            continue
        state = zero_states
        next_input = head_word
        ret_list[i] += next_input
        for j in range(output_length):
            input_index = dataset.word2label(next_input)
            input_batch = np.zeros([1, output_length], dtype=np.int32)
            input_batch[0][0] = input_index
            logits_val, state = sess.run([logits, first_out_state],
                                         feed_dict={input_pl: input_batch,
                                                    batch_size_pl: 1,
                                                    initial_states_pl: state})

            # At the start of the second line the head character must be set by hand
            if j == output_length // 2 - 1:
                next_input = head_words[i+1]
                i += 1
            else:
                next_input = dataset.label2word(np.argmax(logits_val[0][0]))

            if next_input == "\n":
                ret_list[i] += "\\n"
            else:
                ret_list[i] += next_input
    return ret_list
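
The control flow of `predict` for a single couplet can be seen in isolation by replacing the TensorFlow session with a stub. In the sketch below `toy_step` is a hypothetical stand-in for one `sess.run` call (it ignores its input and returns a fixed character per position), and `generate_couplet` is an illustrative name; only the loop structure mirrors the function above.

```python
# What the toy model "predicts" after each of the 12 steps.
PREDICTED = list("xxxx，?xxxx。") + ["\n"]


def toy_step(char, state):
    """Hypothetical model step: returns (predicted character, new state)."""
    # Here the state is just a counter of how many characters were fed in.
    return PREDICTED[state], state + 1


def generate_couplet(first_head, second_head, step, length=12):
    lines = [first_head, second_head]
    current, state, char = 0, 0, first_head
    for j in range(length):
        predicted, state = step(char, state)
        if j == length // 2 - 1:
            # The second line must start with its own head character,
            # so the prediction at this position is discarded.
            char = second_head
            current = 1
        else:
            char = predicted
            lines[current] += char
    return lines


result = generate_couplet("星", "环", toy_step)
print([line.replace("\n", "\\n") for line in result])
```

The state is carried across all 12 steps, so the second line is still conditioned on the first even though its head character is forced.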

Build the complete graph:

In [7]:
# Extract the five-character poems
extract_poems(to_file=EXTRACTED_POETRY_FILE, origin_poetry_dir=ORIGIN_POETRY_DIR, poetry_type=5)

# For readability, use opencc to convert Traditional Chinese to Simplified Chinese
os.system("opencc -i {0} -o {1} -c t2s".format(EXTRACTED_POETRY_FILE, TRAIN_FILE))

w2v_model = word2vec(TRAIN_FILE)

initial_embeddings = w2v_model.wv.syn0

logits, first_state = inference(input_ids, input_batch_size, initial_states,
                                initial_embeddings, stacked_cells)
loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=input_labels, logits=logits))

train_op = tf.train.AdamOptimizer(0.001).minimize(loss)

init_op = tf.global_variables_initializer()
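
For intuition about the loss value, the quantity computed per position by sparse softmax cross-entropy is `log(sum(exp(logits))) - logits[label]`. The sketch below evaluates it in plain Python on made-up logits; the function name is an assumption and this is not how TensorFlow implements it internally.

```python
import math


def sparse_softmax_cross_entropy(logits, label):
    """Negative log-probability of `label` under softmax(logits)."""
    # Subtract the max before exponentiating for numerical stability.
    max_logit = max(logits)
    log_sum_exp = max_logit + math.log(sum(math.exp(x - max_logit) for x in logits))
    return log_sum_exp - logits[label]


# A uniform prediction over 4 classes costs log(4).
uniform = sparse_softmax_cross_entropy([0.0, 0.0, 0.0, 0.0], label=2)
# A confident, correct prediction costs almost nothing.
confident = sparse_softmax_cross_entropy([0.0, 0.0, 10.0, 0.0], label=2)

# The training loss is the mean over all positions, as tf.reduce_mean does.
mean_loss = (uniform + confident) / 2
print(round(uniform, 4), round(confident, 6), round(mean_loss, 4))
```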

Training:

In [11]:
# Load the data
with open(TRAIN_FILE) as fi:
    lines = fi.readlines()
data = DataSet(lines, w2v_model.wv.index2word)

def get_zero_states(batch_size):
    return np.zeros(shape=[batch_size, initial_zero_states.shape[1]], dtype=np.float32)

# Maximum number of iterations
max_iter = 100000

config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.4
with tf.Session(graph=loss.graph, config=config) as sess:
    sess.run(init_op)
    for iter in range(max_iter):
        input_batch = data.next_batch(BATCH_SIZE)
        label_batch = input_batch[:, 1:]
        input_batch = input_batch[:, :-1]
        *_, loss_val = sess.run([train_op, logits, loss],
                                feed_dict={ input_ids: input_batch,
                                            input_labels: label_batch,
                                            input_batch_size: BATCH_SIZE,
                                            initial_states: get_zero_states(BATCH_SIZE)})
        if iter % 100 == 0:
            print("iter {0}/{1} with loss: {2:.3}".format(iter, max_iter, loss_val))
            head_words = "星环科技"
            ret = predict(head_words, sess,
                          logits, first_state,                          # output tensor of inference
                          input_ids, input_batch_size, initial_states,  # place holder
                          get_zero_states(1), data)
            for line in ret:
                print(line)
            print("===========================")

iter 0/100000 with loss: 8.81
1 星盟鳣梦梦梦
2 环梦粱粱梦覻梦
3 科骓隼怯喝湘
4 技翘业鹂鹂悭閙
iter 100/100000 with loss: 5.89
1 星不不知，不
2 环不知知。\n。
3 科，不不知知
4 技。\n。\n。\n
iter 200/100000 with loss: 5.47
1 星不知知处，
2 环不知知中。\n
3 科不知知处，
4 技不知不知。\n
iter 300/100000 with loss: 5.32
1 星不知知处，
2 环人不知春。\n
3 科不知知处，
4 技不知不知。\n
iter 400/100000 with loss: 5.24
1 星不知处处，
2 环是不可心。\n
3 科不知处处，
4 技是不可心。\n
iter 500/100000 with loss: 5.2
1 星人不可见，
2 环子不可心。\n
3 科不知君子，
4 技下不可人。\n
iter 600/100000 with loss: 4.98
1 星人不可见，
2 环子不可见。\n
3 科不知君处，
4 技下不可心。\n
iter 700/100000 with loss: 5.06
1 星人不见处，
2 环人不知处。\n
3 科来不知处，
4 技来一一枝。\n
iter 800/100000 with loss: 4.84
1 星人不可见，
2 环子不可怜。\n
3 科人不知处，
4 技来一一片。\n
iter 900/100000 with loss: 4.9
1 星人不可见，
2 环山不可见。\n
3 科中有一事，
4 技地不可怜。\n


KeyboardInterrupt: 