In [1]:
import collections
import os
import numpy as np
import tensorflow as tf
import tensorflow.contrib.legacy_seq2seq as seq2seq

collections是python内建的一个集合模块，提供了许多有用的集合类  
seq2seq用于损失函数，这里使用交叉熵交叉熵

In [2]:
BEGIN_CHAR = '['
END_CHAR = ']'
UNKNOWN_CHAR = '*'
MAX_LENGTH = 100
MIN_LENGTH = 10
max_words = 3000
epochs = 30
poetry_file = './data/poetry.txt'
save_dir = 'log'

部分参数，确定数据集的位置以及checkpoint保存路径

# 唐诗数据封装类

In [3]:
class Data:
    def __init__(self):
        self.batch_size = 64
        self.poetry_file = poetry_file
        self.load()
        self.create_batches()

    def load(self):
        def handle(line):
            """
            处理过长的唐诗，并为每首诗添加开始符 '[' 和结束符 ']'
            """
            if len(line) > MAX_LENGTH:
                index_end = line.rfind('。', 0, MAX_LENGTH)
                index_end = index_end if index_end > 0 else MAX_LENGTH
                line = line[:index_end + 1]
            return BEGIN_CHAR + line + END_CHAR

        self.poetrys = []
        lines = open(self.poetry_file, encoding='utf-8')
        self.poetrys = [line.strip().replace(' ', '').split(':')[1] for line in lines]
        # 过滤短的唐诗
        self.poetrys = [handle(line) for line in self.poetrys if len(line) > MIN_LENGTH]
        
        # 统计所有出现的字
        words = []
        for poetry in self.poetrys:
            words += [word for word in poetry]
        counter = collections.Counter(words)
        count_pairs = sorted(counter.items(), key=lambda x: -x[1])
        words, _ = zip(*count_pairs)
        
        # 取出现频率最高的词的数量组成字典，不在字典中的字用'*'代替
        words_size = min(max_words, len(words))
        self.words = words[:words_size] + (UNKNOWN_CHAR,)
        self.words_size = len(self.words)

        # 建立字典char2id 和 id2char
        self.char2id_dict = {w: i for i, w in enumerate(self.words)}
        self.id2char_dict = {i: w for i, w in enumerate(self.words)}
        self.unknow_char = self.char2id_dict.get(UNKNOWN_CHAR)
        self.char2id = lambda char: self.char2id_dict.get(char, self.unknow_char)
        self.id2char = lambda num: self.id2char_dict.get(num)
        
        # 诗句->向量
        self.poetrys = sorted(self.poetrys, key=lambda line: len(line))
        self.poetrys_vector = [list(map(self.char2id, poetry)) for poetry in self.poetrys]
        # map(function, iterable, ...)

    def create_batches(self):
        self.n_size = len(self.poetrys_vector) // self.batch_size
        self.poetrys_vector = self.poetrys_vector[:self.n_size * self.batch_size]
        # 去掉末尾不足一个batch_sieze的数据
        self.x_batches = []
        self.y_batches = []
        for i in range(self.n_size):
            batches = self.poetrys_vector[i * self.batch_size: (i + 1) * self.batch_size]
            length = max(map(len, batches))
            
            # 把每首诗填充成长度为 length (每个batch中最长的诗) 的向量
            for row in range(self.batch_size):
                if len(batches[row]) < length:
                    r = length - len(batches[row])
                    batches[row][len(batches[row]): length] = [self.unknow_char] * r
            # 网络的输入是xdata,输出是ydata。ydata由xdata向后移动一位得到        
            xdata = np.array(batches)
            ydata = np.copy(xdata)
            ydata[:, :-1] = xdata[:, 1:]
            self.x_batches.append(xdata)
            self.y_batches.append(ydata)

# 模型构建类

In [4]:
class Model:
    def __init__(self, data, model='lstm', infer=False):
        # 设置RNN网络层超参数
        self.rnn_size = 128
        self.n_layers = 2
        
        # 如果是测试模式，设置batch_size为1，即每次只预测一首诗
        if infer:
            self.batch_size = 1
        else:
            self.batch_size = data.batch_size
        
        # 可选的RNN变体，本次实验选LSTM
        if model == 'rnn':
            cell_rnn = tf.nn.rnn_cell.BasicRNNCell
        elif model == 'gru':
            cell_rnn = tf.nn.rnn_cell.GRUCell
        elif model == 'lstm':
            cell_rnn = tf.nn.rnn_cell.LSTMCell
       
        # 开始构建图
        
        # 2层LSTM堆叠作为诗的表示层
        cell = cell_rnn(self.rnn_size, name='basic_lstm_cell')
        self.cell = tf.nn.rnn_cell.MultiRNNCell([cell] * self.n_layers)
        
        # 设置计算图xdata和ydata的占位符，可理解为形参，在执行run方法将实参传入
        self.x_tf = tf.placeholder(tf.int32, [self.batch_size, None])
        self.y_tf = tf.placeholder(tf.int32, [self.batch_size, None])
        
        #LSTM网络层初始化，每次初始化一个batch
        self.initial_state = self.cell.zero_state(self.batch_size, tf.float32)
        
        # 设置变量空间,可设置多个参数，这里只是简单的使用默认设置。（类似于C++命名空间）
        with tf.variable_scope('rnnlm'):
            softmax_w = tf.get_variable("softmax_w", [self.rnn_size, data.words_size])
            softmax_b = tf.get_variable("softmax_b", [data.words_size])
            # 将one-hot向量映射成符合RNN网络输入的向量
            embedding = tf.get_variable("embedding", [data.words_size, self.rnn_size])
            inputs = tf.nn.embedding_lookup(embedding, self.x_tf)

        # 前馈计算
        outputs, final_state = tf.nn.dynamic_rnn(self.cell, 
                                                 inputs, 
                                                 initial_state=self.initial_state, 
                                                 scope='rnnlm') 
        """
        创建批数据的时候，每首诗填充成长度为 length (每个batch中最长的诗),因此每个batch之间的shape可能不一致。
        tf.nn.dynamic_rnn适合输入的shape不同的情况，tf.nn.rnn必须要求输入的shape必须一致。
        上面的initial_state，final_state神经网络隐藏层单元状态值最后一个单元的值
        """ 
        
        self.output = tf.reshape(outputs, [-1, self.rnn_size])
        self.logits = tf.matmul(self.output, softmax_w) + softmax_b
        
        # 字典所有词的概率分布，在测试时会用到
        self.probs = tf.nn.softmax(self.logits)
        self.final_state = final_state
        pred = tf.reshape(self.y_tf, [-1])
        # seq2seq
        loss = seq2seq.sequence_loss_by_example([self.logits],
                                                [pred],
                                                [tf.ones_like(pred, dtype=tf.float32)],)
        self.cost = tf.reduce_mean(loss)
        self.learning_rate = tf.Variable(0.0, trainable=False)
        tvars = tf.trainable_variables()
        grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars), 5)
        optimizer = tf.train.AdamOptimizer(self.learning_rate)
        self.train_op = optimizer.apply_gradients(zip(grads, tvars))
        """
        gradient clipping(梯度修剪)的引入是为了处理梯度消失或者梯度爆炸的问题。
        当在一次迭代中权重的更新过于迅猛的话，很容易导致loss发散。
        clipping让权重的更新限制在一个合适的范围。
        apply_gradients：计算得到的梯度来更新对应的variable
        """


In [5]:
tf.nn.dynamic_rnn?

# 训练函数

In [6]:
def train(data, model):
    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True
    with tf.Session(config=tf_config) as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables())
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        elif len(os.listdir(save_dir))>0:
            model_file = tf.train.latest_checkpoint(save_dir)
            saver.restore(sess, model_file)
        n = 0
        for epoch in range(epochs):
            # 逐级降低学习率
            sess.run(tf.assign(model.learning_rate, 0.002 * (0.97 ** epoch)))
            pointer = 0
            for batch in range(data.n_size):
                n += 1
                feed_dict = {model.x_tf: data.x_batches[pointer], model.y_tf: data.y_batches[pointer]}
                pointer += 1
                train_loss, _, _ = sess.run([model.cost, model.final_state, model.train_op], feed_dict=feed_dict)
                info = "{}/{} (epoch {}) | train_loss {:.3f}" \
                    .format(epoch * data.n_size + batch,
                            epochs * data.n_size, epoch, train_loss)
                print(info)
                # save
                if (epoch * data.n_size + batch) % 1000 == 0 \
                        or (epoch == epochs-1 and batch == data.n_size-1):
                    checkpoint_path = os.path.join(save_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=n)
                    print("model saved to {}".format(checkpoint_path))
            print('\n')

# 测试函数

In [7]:
def test(data, model, head=u''):
    def to_word(weights):
        t = np.cumsum(weights)
        s = np.sum(weights)
        sample = int(np.searchsorted(t, np.random.rand(1) * s))
        return data.id2char(sample)
        """
        有两种采用策略，一种是贪婪采样（greedy sampling）,另一种是随机采样。其实前一种可以看作后一种的特例。
        随机就是根据sofmax输出的概率进行采样，如果预测的一个词概率是10%，那么它被采样的概率也是10%。
        """
    for word in head:
        if word not in data.words:
            return u'{} 不在字典中'.format(word)
    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True
    with tf.Session(config=tf_config) as sess:
        sess.run(tf.global_variables_initializer())

        saver = tf.train.Saver(tf.global_variables())
        model_file = tf.train.latest_checkpoint(save_dir)
        # print(model_file)
        saver.restore(sess, model_file)

        if head:
            print('生成藏头诗 ---> ', head)
            poem = BEGIN_CHAR
            for head_word in head:
                poem += head_word
                x = np.array([list(map(data.char2id, poem))])
                state = sess.run(model.cell.zero_state(1, tf.float32))
                feed_dict = {model.x_tf: x, model.initial_state: state}
                [probs, state] = sess.run([model.probs, model.final_state], feed_dict)
                word = to_word(probs[-1])
                while word != u'，' and word != u'。':
                    poem += word
                    x = np.zeros((1, 1))
                    x[0, 0] = data.char2id(word)
                    [probs, state] = sess.run([model.probs, model.final_state],
                                              {model.x_tf: x, model.initial_state: state})
                    word = to_word(probs[-1])
                poem += word
            return poem[1:]
        else:
            poem = ''
            head = BEGIN_CHAR
            x = np.array([list(map(data.char2id, head))])
            state = sess.run(model.cell.zero_state(1, tf.float32))
            feed_dict = {model.x_tf: x, model.initial_state: state}
            [probs, state] = sess.run([model.probs, model.final_state], feed_dict)
            word = to_word(probs[-1])
            while word != END_CHAR:
                poem += word
                x = np.zeros((1, 1))
                x[0, 0] = data.char2id(word)
                [probs, state] = sess.run([model.probs, model.final_state],
                                          {model.x_tf: x, model.initial_state: state})
                word = to_word(probs[-1])
            return poem

# 训练配置

In [8]:
 # 设置用哪一块GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
mode = "test"  # 默认训练模式
head = u'明月别枝惊鹊'
# head = u''
assert mode in ["train","test"]
print("%s mode..." % mode)

test mode...


# 训练测试入口

In [9]:
if mode == 'train':
    data = Data()
    model = Model(data=data, infer=False)
    train(data, model) 
        
elif mode == 'test':
    data = Data()
    model = Model(data=data, infer=True)
    poem = test(data, model, head=head)
    print(poem)

INFO:tensorflow:Restoring parameters from log\model.ckpt-9001
生成藏头诗 --->  明月别枝惊鹊
明月今华地，月中逢久闲。别中莺左右，枝汉半时秋。惊净妆行晚，鹊楼红艳波。


In [10]:
data.poetry_file

'./data/poetry.txt'

In [11]:
data.id2char_dict

{0: '，',
 1: '。',
 2: '[',
 3: ']',
 4: '不',
 5: '人',
 6: '山',
 7: '风',
 8: '日',
 9: '无',
 10: '一',
 11: '云',
 12: '来',
 13: '有',
 14: '花',
 15: '春',
 16: '何',
 17: '天',
 18: '水',
 19: '上',
 20: '月',
 21: '中',
 22: '时',
 23: '年',
 24: '相',
 25: '长',
 26: '生',
 27: '君',
 28: '秋',
 29: '心',
 30: '自',
 31: '为',
 32: '归',
 33: '知',
 34: '白',
 35: '如',
 36: '行',
 37: '见',
 38: '去',
 39: '江',
 40: '夜',
 41: '清',
 42: '此',
 43: '空',
 44: '在',
 45: '下',
 46: '高',
 47: '里',
 48: '得',
 49: '未',
 50: '客',
 51: '门',
 52: '处',
 53: '明',
 54: '寒',
 55: '多',
 56: '青',
 57: '是',
 58: '落',
 59: '雨',
 60: '声',
 61: '金',
 62: '远',
 63: '家',
 64: '千',
 65: '南',
 66: '玉',
 67: '三',
 68: '事',
 69: '路',
 70: '前',
 71: '今',
 72: '城',
 73: '出',
 74: '子',
 75: '草',
 76: '入',
 77: '朝',
 78: '道',
 79: '东',
 80: '万',
 81: '新',
 82: '树',
 83: '烟',
 84: '飞',
 85: '开',
 86: '流',
 87: '尽',
 88: '深',
 89: '别',
 90: '色',
 91: '思',
 92: '回',
 93: '应',
 94: '西',
 95: '酒',
 96: '马',
 97: '地',
 98: '闲',
 99: '已',
 100: '还',