## 1. 任务定义
---
语言模型——即对语句的概率分布的建模，定义如下： 

对于语言序列：$w_1,w_2,\dots,w_n$, 计算该序列的概率，即 $P(w_1,w_2,\dots,w_n)$
主要有以下`n-gram` 和 深度序列生成方法

深度序列模型一般可以分为三个模块：嵌入层、特征层、输出层．

### 嵌入层

![Snipaste_2020-08-22_13-43-47.png](http://ww1.sinaimg.cn/large/005XIOOugy1ghzjeu51a7j30hq07p74i.jpg)
### 特征层
![Snipaste_2020-08-22_13-44-02.png](http://ww1.sinaimg.cn/large/005XIOOugy1ghzjfljqg1j30rc0da75b.jpg)
### 输出层
输出层一般采用softmax函数来进行分类

## 2.环境准备

In [29]:
import pickle
import os
import json
import re
import collections

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence,pack_padded_sequence,pad_packed_sequence

## 3. 配置参数

In [23]:
class Config(object):
    data_path = "../data/poetry.txt"

    batch_size = 16
    num_epochs = 200
    shuffle = True

    embedding_size = 256  # 嵌入词向量维度
    hidden_size = 256    # 隐状态维度
    num_layers = 3     # rnn的层数
    dropout_rate = 0.5
    bidirectional = False
    
    lr = 0.01           # 学习率
    
    max_len = 125 # 诗句最大长度
    max_gen_len = 128
    
    START = '<START>'
    EOP = '<EOP>'
    PAD = '<PAD>'
    
config = Config()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## 3. 处理数据

In [26]:
def parse_raw_data(data_path):
#     poetries = []
#     with open(data_path,'r',encoding='utf-8') as f:
#         lines = f.readlines()
#         poetry =""
#         for line in lines:
#             line = line.strip()
#             if len(line)>0:
#                 poetry += line
#             elif len(poetry)>0:
#                 poetries.append(poetry)
#                 poetry = ""
#     print("共有唐诗{}首".format(len(poetries)))
#     return poetries

    poetries = []
    with  open(config.data_path,'r',encoding='utf-8') as f:
        lines = f.readlines()
        for line in lines:
            try:
                line = line.strip()
                title,content = line.split(u':',1)
                if u'_' in content or u'(' in content or u'（' in content or u'《' in content or u'[' in content:  
                        continue  
                if len(content) < 5 or len(content) > 79:  
                        continue  
                content = u'[' + content + u']'  
                poetries.append(content)  
            except Exception as e:   
                pass  

        
    poetries = sorted(poetries,key=lambda line: len(line))  
    print('共有唐诗: ', len(poetries))
    return poetries
 
           
    
def get_data(config):
    # 1.获取数据
    data = parse_raw_data(config.data_path)
    
    # 2. 构建词典
    words = [w for sent in data for w in sent]
    words = list(set(words))
    word_list = [config.PAD,config.START,config.EOP]
    word_list.extend(words)
    
    word2id = {w:i for i,w in enumerate(word_list)}
    id2word = {i:w for w,i in list(word2id.items())}
    
    # 3. 处理每首诗
    for i in range(len(data)):
        data[i] =[config.START] + list(data[i]) + [config.EOP]
     
    #文字转id
    data_id = [[word2id[w] for w in sent] for sent in data]
    
    return data_id,word2id,id2word

data_id,word2id,id2word = get_data(config)

共有唐诗:  34647


[[1, 4102, 5717, 142, 5666, 3507, 5639, 580, 2],
 [1, 4102, 1868, 45, 3589, 4074, 5639, 580, 2],
 [1, 4102, 418, 3015, 5266, 3339, 5639, 580, 2],
 [1, 4102, 3200, 4030, 2045, 32, 5639, 580, 2],
 [1, 4102, 1273, 2738, 2510, 5748, 5639, 580, 2],
 [1, 4102, 1246, 78, 1705, 2489, 328, 5639, 580, 2],
 [1, 4102, 459, 2122, 3457, 1503, 5595, 5639, 580, 2],
 [1, 4102, 5258, 187, 5617, 4584, 3517, 5639, 580, 2],
 [1, 4102, 4710, 2314, 2583, 5714, 2106, 5639, 580, 2],
 [1, 4102, 2227, 5666, 5617, 603, 4141, 5639, 580, 2],
 [1, 4102, 1868, 45, 3589, 4074, 5050, 5639, 580, 2],
 [1, 4102, 4686, 5862, 3095, 4696, 5862, 5639, 580, 2],
 [1, 4102, 2481, 5717, 1126, 2481, 3329, 5639, 580, 2],
 [1, 4102, 4453, 3234, 5717, 3605, 3758, 5639, 580, 2],
 [1, 4102, 1919, 3597, 3589, 1405, 3782, 5639, 580, 2],
 [1, 4102, 5727, 2497, 1307, 4686, 4945, 5639, 580, 2],
 [1, 4102, 5688, 3589, 1026, 2609, 48, 5639, 580, 2],
 [1, 4102, 975, 950, 950, 6049, 5076, 1556, 5639, 580, 2],
 [1, 4102, 4171, 3865, 3589, 4171, 

In [44]:
config.vocab_size = len(word2id)

## 4. 构建 Dataset 和 DataLoader

In [32]:
class TangDataset(Dataset):
    def __init__(self,data):
        super(TangDataset,self).__init__()
        self.data = data
    
    def __getitem__(self,index):
        return torch.LongTensor(self.data[index])
    
    def __len__(self):
        return len(self.data)
    
def collate_fn(batch_data):
#     length = [data.shape[0] for data in batch_data]
#     max_len = max(length)
    
    batch = pad_sequence(batch_data,batch_first=True,padding_value=0)
    return batch
    
    
data_set = TangDataset(data_id)
data_loader = DataLoader(data_set,config.batch_size,config.shuffle,collate_fn=collate_fn)
for batch in data_loader:
    print(batch)

tensor([[   1, 4102, 2608,  ...,    0,    0,    0],
        [   1, 4102, 1557,  ...,    0,    0,    0],
        [   1, 4102,  984,  ...,    0,    0,    0],
        ...,
        [   1, 4102,  685,  ...,    0,    0,    0],
        [   1, 4102, 2831,  ...,    0,    0,    0],
        [   1, 4102, 4211,  ...,    0,    0,    0]])
tensor([[   1, 4102, 4748,  ..., 5639,  580,    2],
        [   1, 4102, 5894,  ...,    0,    0,    0],
        [   1, 4102, 4136,  ..., 5639,  580,    2],
        ...,
        [   1, 4102, 4219,  ...,    0,    0,    0],
        [   1, 4102, 1344,  ...,    0,    0,    0],
        [   1, 4102, 5293,  ...,    0,    0,    0]])
tensor([[   1, 4102, 4692,  ...,    0,    0,    0],
        [   1, 4102, 5236,  ..., 5639,  580,    2],
        [   1, 4102, 1499,  ...,    0,    0,    0],
        ...,
        [   1, 4102,  778,  ...,    0,    0,    0],
        [   1, 4102, 4024,  ...,    0,    0,    0],
        [   1, 4102,  968,  ...,    0,    0,    0]])
tensor([[   1, 4102, 4

        [   1, 4102, 2473,  ...,    0,    0,    0]])
tensor([[   1, 4102,  238,  ...,    0,    0,    0],
        [   1, 4102, 2489,  ...,    0,    0,    0],
        [   1, 4102, 5636,  ...,    0,    0,    0],
        ...,
        [   1, 4102,  653,  ...,    0,    0,    0],
        [   1, 4102, 3785,  ...,    0,    0,    0],
        [   1, 4102, 4656,  ..., 5639,  580,    2]])
tensor([[   1, 4102, 2168,  ...,    0,    0,    0],
        [   1, 4102, 4267,  ...,    0,    0,    0],
        [   1, 4102, 2251,  ..., 5639,  580,    2],
        ...,
        [   1, 4102, 2641,  ...,    0,    0,    0],
        [   1, 4102, 5538,  ...,    0,    0,    0],
        [   1, 4102,  960,  ...,    0,    0,    0]])
tensor([[   1, 4102, 2473,  ...,    0,    0,    0],
        [   1, 4102, 5595,  ...,    0,    0,    0],
        [   1, 4102,  881,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 2967,  ...,    0,    0,    0],
        [   1, 4102, 3252,  ...,    0,    0,    0],
        [   1, 4102,  

        [   1, 4102, 4136,  ...,    0,    0,    0]])
tensor([[   1, 4102, 2481,  ...,    0,    0,    0],
        [   1, 4102, 5112,  ...,    0,    0,    0],
        [   1, 4102, 5236,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 1023,  ...,    0,    0,    0],
        [   1, 4102,  975,  ...,    0,    0,    0],
        [   1, 4102, 2898,  ...,    0,    0,    0]])
tensor([[   1, 4102, 1239,  ...,    0,    0,    0],
        [   1, 4102, 4608,  ...,    0,    0,    0],
        [   1, 4102, 4401,  ..., 5639,  580,    2],
        ...,
        [   1, 4102, 3589,  ...,    0,    0,    0],
        [   1, 4102, 3541,  ...,    0,    0,    0],
        [   1, 4102, 2969,  ...,    0,    0,    0]])
tensor([[   1, 4102, 3367,  ...,    0,    0,    0],
        [   1, 4102,  174,  ...,    0,    0,    0],
        [   1, 4102, 2281,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 5748,  ..., 5639,  580,    2],
        [   1, 4102, 5879,  ..., 5639,  580,    2],
        [   1, 4102, 1

        [   1, 4102,  315,  ...,    0,    0,    0]])
tensor([[   1, 4102, 2738,  ...,    0,    0,    0],
        [   1, 4102,   35,  ...,    0,    0,    0],
        [   1, 4102, 1852,  ..., 5639,  580,    2],
        ...,
        [   1, 4102, 3603,  ...,    0,    0,    0],
        [   1, 4102, 2489,  ...,    0,    0,    0],
        [   1, 4102, 1874,  ...,    0,    0,    0]])
tensor([[   1, 4102, 5702,  ...,    0,    0,    0],
        [   1, 4102, 3854,  ...,    0,    0,    0],
        [   1, 4102, 2723,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 6090,  ...,    0,    0,    0],
        [   1, 4102,  881,  ...,    0,    0,    0],
        [   1, 4102,  371,  ...,    0,    0,    0]])
tensor([[   1, 4102, 2340,  ...,    0,    0,    0],
        [   1, 4102, 5051,  ...,    0,    0,    0],
        [   1, 4102, 4608,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 1903,  ..., 5639,  580,    2],
        [   1, 4102, 1262,  ...,    0,    0,    0],
        [   1, 4102,  

        [   1, 4102, 5450,  ...,    0,    0,    0]])
tensor([[   1, 4102, 2489,  ...,    0,    0,    0],
        [   1, 4102, 2983,  ..., 5639,  580,    2],
        [   1, 4102,  281,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 2019,  ..., 5639,  580,    2],
        [   1, 4102, 4608,  ...,    0,    0,    0],
        [   1, 4102, 3529,  ...,    0,    0,    0]])
tensor([[   1, 4102, 4140,  ...,    0,    0,    0],
        [   1, 4102, 3757,  ...,    0,    0,    0],
        [   1, 4102, 1123,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 1390,  ...,    0,    0,    0],
        [   1, 4102, 3318,  ...,    0,    0,    0],
        [   1, 4102, 1336,  ...,    0,    0,    0]])
tensor([[   1, 4102, 3375,  ..., 5639,  580,    2],
        [   1, 4102, 2552,  ...,    0,    0,    0],
        [   1, 4102, 5688,  ...,    0,    0,    0],
        ...,
        [   1, 4102,  936,  ...,    0,    0,    0],
        [   1, 4102, 1356,  ...,    0,    0,    0],
        [   1, 4102, 2

tensor([[   1, 4102,  702,  ..., 5639,  580,    2],
        [   1, 4102, 2507,  ...,    0,    0,    0],
        [   1, 4102, 1642,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 5349,  ...,    0,    0,    0],
        [   1, 4102,  372,  ...,    0,    0,    0],
        [   1, 4102, 5051,  ...,    0,    0,    0]])
tensor([[   1, 4102, 2169,  ...,    0,    0,    0],
        [   1, 4102, 5450,  ...,    0,    0,    0],
        [   1, 4102, 5481,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 5948,  ...,    0,    0,    0],
        [   1, 4102, 3260,  ...,    0,    0,    0],
        [   1, 4102,  423,  ...,    0,    0,    0]])
tensor([[   1, 4102,  973,  ...,    0,    0,    0],
        [   1, 4102, 2473,  ...,    0,    0,    0],
        [   1, 4102,  155,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 4200,  ...,    0,    0,    0],
        [   1, 4102, 3785,  ...,    0,    0,    0],
        [   1, 4102, 5556,  ..., 5639,  580,    2]])
tensor([[   1, 4102, 4

tensor([[   1, 4102, 4432,  ...,    0,    0,    0],
        [   1, 4102, 3882,  ...,    0,    0,    0],
        [   1, 4102, 3730,  ..., 5639,  580,    2],
        ...,
        [   1, 4102, 5692,  ...,    0,    0,    0],
        [   1, 4102,  964,  ...,    0,    0,    0],
        [   1, 4102, 4220,  ...,    0,    0,    0]])
tensor([[   1, 4102, 4608,  ...,    0,    0,    0],
        [   1, 4102, 2400,  ...,    0,    0,    0],
        [   1, 4102, 4686,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 1416,  ..., 5639,  580,    2],
        [   1, 4102, 2514,  ...,    0,    0,    0],
        [   1, 4102, 3236,  ...,    0,    0,    0]])
tensor([[   1, 4102, 2095,  ...,    0,    0,    0],
        [   1, 4102, 2047,  ...,    0,    0,    0],
        [   1, 4102, 3777,  ...,    0,    0,    0],
        ...,
        [   1, 4102,   33,  ...,    0,    0,    0],
        [   1, 4102, 2601,  ..., 5639,  580,    2],
        [   1, 4102, 2749,  ...,    0,    0,    0]])
tensor([[   1, 4102,  

        [   1, 4102, 2216,  ...,    0,    0,    0]])
tensor([[   1, 4102, 2738,  ...,    0,    0,    0],
        [   1, 4102, 6017,  ...,    0,    0,    0],
        [   1, 4102, 3032,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 3695,  ...,    0,    0,    0],
        [   1, 4102, 5236,  ...,    0,    0,    0],
        [   1, 4102, 5584,  ...,    0,    0,    0]])
tensor([[   1, 4102, 5016,  ...,    0,    0,    0],
        [   1, 4102,  585,  ...,    0,    0,    0],
        [   1, 4102, 5382,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 5445,  ..., 5639,  580,    2],
        [   1, 4102,  280,  ...,    0,    0,    0],
        [   1, 4102,  722,  ...,    0,    0,    0]])
tensor([[   1, 4102, 4696,  ...,    0,    0,    0],
        [   1, 4102,  210,  ...,    0,    0,    0],
        [   1, 4102,  950,  ..., 5639,  580,    2],
        ...,
        [   1, 4102, 2409,  ...,    0,    0,    0],
        [   1, 4102, 5051,  ...,    0,    0,    0],
        [   1, 4102, 6

        [   1, 4102, 1062,  ...,    0,    0,    0]])
tensor([[   1, 4102,  751,  ..., 5639,  580,    2],
        [   1, 4102, 3481,  ...,    0,    0,    0],
        [   1, 4102, 5375,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 1970,  ...,    0,    0,    0],
        [   1, 4102, 5236,  ..., 5639,  580,    2],
        [   1, 4102,  191,  ...,    0,    0,    0]])
tensor([[   1, 4102, 3948,  ..., 5639,  580,    2],
        [   1, 4102,  571,  ...,    0,    0,    0],
        [   1, 4102, 3801,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 3820,  ...,    0,    0,    0],
        [   1, 4102, 2095,  ..., 5639,  580,    2],
        [   1, 4102, 2095,  ...,    0,    0,    0]])
tensor([[   1, 4102, 2489,  ..., 5639,  580,    2],
        [   1, 4102, 2608,  ...,    0,    0,    0],
        [   1, 4102, 4610,  ..., 5639,  580,    2],
        ...,
        [   1, 4102, 3962,  ..., 5639,  580,    2],
        [   1, 4102, 3200,  ...,    0,    0,    0],
        [   1, 4102, 2

        [   1, 4102, 1262,  ...,    0,    0,    0]])
tensor([[   1, 4102, 5524,  ..., 5639,  580,    2],
        [   1, 4102, 5051,  ...,    0,    0,    0],
        [   1, 4102, 5702,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 2129,  ...,    0,    0,    0],
        [   1, 4102, 2461,  ...,    0,    0,    0],
        [   1, 4102, 4021,  ...,    0,    0,    0]])
tensor([[   1, 4102, 3236,  ...,    0,    0,    0],
        [   1, 4102, 4735,  ...,    0,    0,    0],
        [   1, 4102, 3659,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 5666,  ...,    0,    0,    0],
        [   1, 4102, 1903,  ...,    0,    0,    0],
        [   1, 4102, 2738,  ...,    0,    0,    0]])
tensor([[   1, 4102, 5333,  ...,    0,    0,    0],
        [   1, 4102, 2520,  ...,    0,    0,    0],
        [   1, 4102, 4253,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 1239,  ...,    0,    0,    0],
        [   1, 4102, 1336,  ...,    0,    0,    0],
        [   1, 4102, 3

        [   1, 4102, 3640,  ...,    0,    0,    0]])
tensor([[   1, 4102, 1886,  ...,    0,    0,    0],
        [   1, 4102, 4316,  ...,    0,    0,    0],
        [   1, 4102, 1838,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 5469,  ..., 5639,  580,    2],
        [   1, 4102,  936,  ...,    0,    0,    0],
        [   1, 4102,   10,  ...,    0,    0,    0]])
tensor([[   1, 4102, 5051,  ...,    0,    0,    0],
        [   1, 4102,  693,  ...,    0,    0,    0],
        [   1, 4102, 3630,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 2294,  ...,    0,    0,    0],
        [   1, 4102, 6090,  ..., 5639,  580,    2],
        [   1, 4102,  490,  ..., 5639,  580,    2]])
tensor([[   1, 4102, 4828,  ...,    0,    0,    0],
        [   1, 4102,  436,  ...,    0,    0,    0],
        [   1, 4102, 1748,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 5958,  ...,    0,    0,    0],
        [   1, 4102,  120,  ...,    0,    0,    0],
        [   1, 4102, 5

        [   1, 4102, 3200,  ..., 5639,  580,    2]])
tensor([[   1, 4102, 4867,  ...,    0,    0,    0],
        [   1, 4102,   78,  ...,    0,    0,    0],
        [   1, 4102,   33,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 5435,  ...,    0,    0,    0],
        [   1, 4102, 5349,  ...,    0,    0,    0],
        [   1, 4102, 5688,  ...,    0,    0,    0]])
tensor([[   1, 4102, 4943,  ...,    0,    0,    0],
        [   1, 4102, 2738,  ...,    0,    0,    0],
        [   1, 4102, 5957,  ..., 5639,  580,    2],
        ...,
        [   1, 4102, 3213,  ..., 5639,  580,    2],
        [   1, 4102, 5445,  ..., 5639,  580,    2],
        [   1, 4102, 3188,  ...,    0,    0,    0]])
tensor([[   1, 4102,  418,  ...,    2,    0,    0],
        [   1, 4102, 4843,  ...,    0,    0,    0],
        [   1, 4102, 3589,  ...,    0,    0,    0],
        ...,
        [   1, 4102,  132,  ..., 5639,  580,    2],
        [   1, 4102,  880,  ...,    0,    0,    0],
        [   1, 4102, 1

        [   1, 4102, 2239,  ...,    0,    0,    0]])
tensor([[   1, 4102, 5114,  ...,    0,    0,    0],
        [   1, 4102, 2155,  ...,    0,    0,    0],
        [   1, 4102,  287,  ..., 5639,  580,    2],
        ...,
        [   1, 4102, 2340,  ...,    0,    0,    0],
        [   1, 4102, 4170,  ...,    0,    0,    0],
        [   1, 4102, 2753,  ...,    0,    0,    0]])
tensor([[   1, 4102, 2473,  ...,    0,    0,    0],
        [   1, 4102, 3591,  ...,    0,    0,    0],
        [   1, 4102, 3882,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 4828,  ...,    0,    0,    0],
        [   1, 4102, 6073,  ..., 5639,  580,    2],
        [   1, 4102, 2376,  ...,    0,    0,    0]])
tensor([[   1, 4102,  988,  ...,    0,    0,    0],
        [   1, 4102, 3367,  ...,    0,    0,    0],
        [   1, 4102,  265,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 5293,  ...,    0,    0,    0],
        [   1, 4102, 2753,  ...,    0,    0,    0],
        [   1, 4102, 5

        [   1, 4102, 1872,  ..., 5639,  580,    2]])
tensor([[   1, 4102, 2489,  ...,    0,    0,    0],
        [   1, 4102, 3541,  ...,    0,    0,    0],
        [   1, 4102,  701,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 5617,  ...,    0,    0,    0],
        [   1, 4102, 2489,  ...,    0,    0,    0],
        [   1, 4102,  984,  ..., 5639,  580,    2]])
tensor([[   1, 4102, 1063,  ...,    0,    0,    0],
        [   1, 4102, 4372,  ...,    0,    0,    0],
        [   1, 4102, 3882,  ...,    0,    0,    0],
        ...,
        [   1, 4102,  287,  ...,    0,    0,    0],
        [   1, 4102, 1462,  ...,    0,    0,    0],
        [   1, 4102, 4002,  ..., 5639,  580,    2]])
tensor([[   1, 4102,  543,  ...,    0,    0,    0],
        [   1, 4102,  287,  ...,    0,    0,    0],
        [   1, 4102, 2690,  ...,    0,    0,    0],
        ...,
        [   1, 4102,  759,  ...,    0,    0,    0],
        [   1, 4102, 2423,  ...,    0,    0,    0],
        [   1, 4102, 2

        [   1, 4102, 3333,  ...,    0,    0,    0]])
tensor([[   1, 4102, 5945,  ...,    0,    0,    0],
        [   1, 4102, 2407,  ...,    0,    0,    0],
        [   1, 4102, 3785,  ...,    0,    0,    0],
        ...,
        [   1, 4102,  354,  ...,    0,    0,    0],
        [   1, 4102, 3625,  ...,    0,    0,    0],
        [   1, 4102,  701,  ..., 5639,  580,    2]])
tensor([[   1, 4102, 4849,  ...,    0,    0,    0],
        [   1, 4102, 5258,  ..., 5639,  580,    2],
        [   1, 4102, 3198,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 1774,  ...,    0,    0,    0],
        [   1, 4102, 2849,  ...,    0,    0,    0],
        [   1, 4102,  281,  ...,    0,    0,    0]])
tensor([[   1, 4102,  672,  ...,    0,    0,    0],
        [   1, 4102,  702,  ...,    0,    0,    0],
        [   1, 4102, 5665,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 5628,  ..., 5639,  580,    2],
        [   1, 4102, 2095,  ...,    0,    0,    0],
        [   1, 4102, 1

            0,    0,    0,    0]])
tensor([[   1, 4102,  603,  ...,    0,    0,    0],
        [   1, 4102, 2963,  ...,    0,    0,    0],
        [   1, 4102, 5016,  ...,    0,    0,    0],
        ...,
        [   1, 4102,  547,  ...,    0,    0,    0],
        [   1, 4102, 5682,  ...,    0,    0,    0],
        [   1, 4102, 3695,  ...,    0,    0,    0]])
tensor([[   1, 4102, 2473,  ...,    0,    0,    0],
        [   1, 4102, 2473,  ...,    0,    0,    0],
        [   1, 4102, 6047,  ...,    0,    0,    0],
        ...,
        [   1, 4102,  751,  ...,    0,    0,    0],
        [   1, 4102,  678,  ...,    0,    0,    0],
        [   1, 4102,  224,  ...,    0,    0,    0]])
tensor([[   1, 4102, 4909,  ..., 5639,  580,    2],
        [   1, 4102, 4954,  ...,    0,    0,    0],
        [   1, 4102, 3598,  ...,    0,    0,    0],
        ...,
        [   1, 4102,  701,  ...,    0,    0,    0],
        [   1, 4102,  653,  ...,    0,    0,    0],
        [   1, 4102, 4749,  ..., 5639,  

        [   1, 4102, 3248,  ...,    0,    0,    0]])
tensor([[   1, 4102,  742,  ...,    0,    0,    0],
        [   1, 4102, 2608,  ...,    0,    0,    0],
        [   1, 4102, 3785,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 4300,  ...,    0,    0,    0],
        [   1, 4102, 1051,  ...,    0,    0,    0],
        [   1, 4102, 2641,  ...,    0,    0,    0]])
tensor([[   1, 4102, 2297,  ...,    0,    0,    0],
        [   1, 4102, 2297,  ...,    0,    0,    0],
        [   1, 4102, 2905,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 3695,  ...,    0,    0,    0],
        [   1, 4102,   67,  ...,    0,    0,    0],
        [   1, 4102, 2022,  ...,    0,    0,    0]])
tensor([[   1, 4102,  603,  ...,    0,    0,    0],
        [   1, 4102, 2655,  ...,    0,    0,    0],
        [   1, 4102, 3854,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 5051,  ...,    0,    0,    0],
        [   1, 4102,  984,  ...,    0,    0,    0],
        [   1, 4102, 3

        [   1, 4102, 1012,  ...,    0,    0,    0]])
tensor([[   1, 4102, 5227,  ...,    0,    0,    0],
        [   1, 4102, 5027,  ...,    0,    0,    0],
        [   1, 4102, 1209,  ..., 5639,  580,    2],
        ...,
        [   1, 4102, 2075,  ...,    0,    0,    0],
        [   1, 4102, 2019,  ...,    0,    0,    0],
        [   1, 4102,  970,  ...,    0,    0,    0]])
tensor([[   1, 4102, 3442,  ...,    0,    0,    0],
        [   1, 4102, 1737,  ...,    0,    0,    0],
        [   1, 4102,  702,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 5237,  ...,    0,    0,    0],
        [   1, 4102, 4338,  ...,    0,    0,    0],
        [   1, 4102, 5894,  ...,    0,    0,    0]])
tensor([[   1, 4102, 3236,  ...,    0,    0,    0],
        [   1, 4102,  658,  ...,    0,    0,    0],
        [   1, 4102,  603,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 4692,  ...,    0,    0,    0],
        [   1, 4102, 5948,  ..., 5639,  580,    2],
        [   1, 4102,  

        [   1, 4102, 5188,  ...,    0,    0,    0]])
tensor([[   1, 4102, 1509,  ..., 5639,  580,    2],
        [   1, 4102, 2489,  ...,    0,    0,    0],
        [   1, 4102, 1903,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 5237,  ...,    0,    0,    0],
        [   1, 4102, 1386,  ..., 5639,  580,    2],
        [   1, 4102,  287,  ...,    0,    0,    0]])
tensor([[   1, 4102, 5385,  ...,    0,    0,    0],
        [   1, 4102,   92,  ...,    0,    0,    0],
        [   1, 4102, 6088,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 3695,  ...,    0,    0,    0],
        [   1, 4102,  759,  ..., 5639,  580,    2],
        [   1, 4102, 1533,  ...,    0,    0,    0]])
tensor([[   1, 4102, 2366,  ...,    0,    0,    0],
        [   1, 4102,  702,  ...,    0,    0,    0],
        [   1, 4102, 3529,  ...,    0,    0,    0],
        ...,
        [   1, 4102,  701,  ...,    0,    0,    0],
        [   1, 4102,  463,  ...,    0,    0,    0],
        [   1, 4102, 2

tensor([[   1, 4102, 5385,  ...,    0,    0,    0],
        [   1, 4102,  389,  ..., 5639,  580,    2],
        [   1, 4102, 5842,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 5227,  ...,    0,    0,    0],
        [   1, 4102, 2749,  ...,    0,    0,    0],
        [   1, 4102, 6090,  ...,    0,    0,    0]])
tensor([[   1, 4102, 3471,  ..., 5639,  580,    2],
        [   1, 4102, 4481,  ...,    0,    0,    0],
        [   1, 4102,  121,  ...,    0,    0,    0],
        ...,
        [   1, 4102,  603,  ..., 5639,  580,    2],
        [   1, 4102, 4100,  ...,    0,    0,    0],
        [   1, 4102,  759,  ...,    0,    0,    0]])
tensor([[   1, 4102, 3777,  ...,    0,    0,    0],
        [   1, 4102, 2600,  ...,    0,    0,    0],
        [   1, 4102, 2638,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 5275,  ...,    0,    0,    0],
        [   1, 4102, 1483,  ...,    0,    0,    0],
        [   1, 4102, 1130,  ...,    0,    0,    0]])
tensor([[   1, 4102,  

        [   1, 4102, 4071,  ...,    0,    0,    0]])
tensor([[   1, 4102, 4684,  538, 2839, 1302, 1501, 2976, 6063, 5431, 4709, 6080,
         5631, 5639,  145, 1063, 3777, 2514, 5663, 2976, 4654,  191, 3854, 2217,
         1886, 5639, 5344, 5050,  653, 1642,  342, 2976, 4818, 1247, 3233, 2239,
         4450, 5639, 4482, 5385,  287, 4609, 5324, 2976, 1383, 4491, 2507, 6049,
         1659, 5639,  580,    2],
        [   1, 4102, 5344, 5590, 4467, 5333,  156, 2976, 2095, 5842, 1037,   70,
         2445, 5639, 2602, 3457, 5628, 4295, 3277, 2976, 1928, 3589, 3447, 2217,
         4021, 5639,  580,    2,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0],
        [   1, 4102,  287, 2350, 4798, 1593, 1405, 3258, 5263, 2976, 2905, 5553,
         5975, 4139, 5379, 5021, 5021, 5639,  881,  490, 2239, 5144, 4661, 3020,
         4999, 2976, 3020, 4999,  778, 2122,  459, 2608, 4021, 5639, 

tensor([[   1, 4102, 3329,  ...,    0,    0,    0],
        [   1, 4102, 5748,  ...,    0,    0,    0],
        [   1, 4102, 4138,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 2507,  ..., 5639,  580,    2],
        [   1, 4102, 3077,  ..., 5639,  580,    2],
        [   1, 4102, 3198,  ...,    0,    0,    0]])
tensor([[   1, 4102, 1919,  ...,    0,    0,    0],
        [   1, 4102,  280,  ...,    0,    0,    0],
        [   1, 4102, 3582,  ..., 5639,  580,    2],
        ...,
        [   1, 4102,  747,  ..., 5639,  580,    2],
        [   1, 4102, 5636,  ...,    0,    0,    0],
        [   1, 4102, 3000,  ...,    0,    0,    0]])
tensor([[   1, 4102, 5595,  ...,    0,    0,    0],
        [   1, 4102, 4100,  ...,    0,    0,    0],
        [   1, 4102, 5509,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 5894,  ...,    0,    0,    0],
        [   1, 4102,  459,  ...,    0,    0,    0],
        [   1, 4102, 4930,  ...,    0,    0,    0]])
tensor([[   1, 4102, 1

        [   1, 4102, 1357,  ...,    0,    0,    0]])
tensor([[   1, 4102,  904,  ...,    0,    0,    0],
        [   1, 4102, 3375,  ...,    0,    0,    0],
        [   1, 4102, 1066,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 3640,  ...,    0,    0,    0],
        [   1, 4102, 1564,  ...,    0,    0,    0],
        [   1, 4102, 3188,  ...,    0,    0,    0]])
tensor([[   1, 4102, 4608,  ...,    0,    0,    0],
        [   1, 4102,  281,  ...,    0,    0,    0],
        [   1, 4102, 5975,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 3090,  ...,    0,    0,    0],
        [   1, 4102, 2340,  ..., 5639,  580,    2],
        [   1, 4102,  702,  ...,    0,    0,    0]])
tensor([[   1, 4102, 2508,  ...,    0,    0,    0],
        [   1, 4102, 2404,  ...,    0,    0,    0],
        [   1, 4102, 3785,  ...,    0,    0,    0],
        ...,
        [   1, 4102,  234,  ...,    0,    0,    0],
        [   1, 4102, 2019,  ...,    0,    0,    0],
        [   1, 4102,  

        [   1, 4102, 5227,  ...,    0,    0,    0]])
tensor([[   1, 4102, 4696,  ...,    0,    0,    0],
        [   1, 4102,  148,  ...,    0,    0,    0],
        [   1, 4102,  612,  ..., 5639,  580,    2],
        ...,
        [   1, 4102,  759,  ...,    0,    0,    0],
        [   1, 4102, 5556,  ...,    0,    0,    0],
        [   1, 4102, 5702,  ...,    0,    0,    0]])
tensor([[   1, 4102, 3260,  ...,    0,    0,    0],
        [   1, 4102, 5016,  ...,    0,    0,    0],
        [   1, 4102, 1239,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 2297,  ...,    0,    0,    0],
        [   1, 4102, 3466,  ...,    0,    0,    0],
        [   1, 4102,  950,  ...,    0,    0,    0]])
tensor([[   1, 4102, 5419,  ...,    0,    0,    0],
        [   1, 4102, 1037,  ...,    0,    0,    0],
        [   1, 4102, 3248,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 1655,  ...,    0,    0,    0],
        [   1, 4102, 2518,  ...,    0,    0,    0],
        [   1, 4102, 3

        [   1, 4102,  148,  ...,    0,    0,    0]])
tensor([[   1, 4102, 5674,  ...,    0,    0,    0],
        [   1, 4102, 5154,  ..., 5639,  580,    2],
        [   1, 4102,  482,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 3442,  ...,    0,    0,    0],
        [   1, 4102, 2473,  ...,    0,    0,    0],
        [   1, 4102,  306,  ..., 5639,  580,    2]])
tensor([[   1, 4102, 2445,  ...,    0,    0,    0],
        [   1, 4102, 4106,  ...,    0,    0,    0],
        [   1, 4102, 4494,  ..., 5639,  580,    2],
        ...,
        [   1, 4102, 2738,  ..., 5639,  580,    2],
        [   1, 4102, 5236,  ..., 5639,  580,    2],
        [   1, 4102, 1099,  ...,    0,    0,    0]])
tensor([[   1, 4102, 2161,  ...,    0,    0,    0],
        [   1, 4102, 3660,  ..., 5639,  580,    2],
        [   1, 4102, 4863,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 3850,  ...,    0,    0,    0],
        [   1, 4102, 3521,  ...,    0,    0,    0],
        [   1, 4102, 5

        [   1, 4102, 2297,  ..., 5639,  580,    2]])
tensor([[   1, 4102, 3729,  ..., 5639,  580,    2],
        [   1, 4102, 5227,  ...,    0,    0,    0],
        [   1, 4102, 5628,  ...,    0,    0,    0],
        ...,
        [   1, 4102,  701,  ...,    0,    0,    0],
        [   1, 4102, 2402,  ...,    0,    0,    0],
        [   1, 4102, 4467,  ..., 5639,  580,    2]])
tensor([[   1, 4102, 5379,  ...,    0,    0,    0],
        [   1, 4102, 4936,  ...,    0,    0,    0],
        [   1, 4102, 4684,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 4686,  ...,    0,    0,    0],
        [   1, 4102,   45,  ...,    0,    0,    0],
        [   1, 4102, 3634,  ...,    0,    0,    0]])
tensor([[   1, 4102, 6080,  ...,    0,    0,    0],
        [   1, 4102, 1509,  ...,    0,    0,    0],
        [   1, 4102, 2417,  ...,    0,    0,    0],
        ...,
        [   1, 4102, 5688,  ...,    0,    0,    0],
        [   1, 4102, 4974,  ...,    0,    0,    0],
        [   1, 4102, 2

## 5. 建立模型

In [46]:
class TangPoetry(nn.Module):
    def __init__(self, config):
        super(TangPoetry,self).__init__()
        if config.bidirectional:
            self.direction_num =2
        else:
            self.direction_num = 1
            
        self.embedding = nn.Embedding(config.vocab_size, config.embedding_size)
        self.lstm = nn.LSTM(config.embedding_size,
                            config.hidden_size,
                            num_layers=config.num_layers,
                            batch_first=True,
                            dropout
                            bidirectional=config.bidirectional)
        self.classifier=nn.Sequential(
            nn.Linear(config.hidden_size * self.direction_num, 512), 
            nn.ReLU(inplace=True), 
            nn.Linear(512, 2048), 
            nn.ReLU(inplace=True),
            nn.Linear(2048, config.vocab_size)
        )
        
    def forward(self,x):
        batch_size,seq_len= x.size()
        embeds = self.embedding(x) #[batch_size,seq_len,embedding_size]
        output,_ = self.lstm(embeds) # [batch_size,seq_len,hidden_size*direction_num]
        #nn.Linear输入是二维张量，所以得将lstm的输出展平
        output = self.classifier(output.reshape(batch_size*seq_len,-1)) 
        return output

In [47]:
model = TangPoetry(config)
model.to(device)

TangPoetry(
  (embedding): Embedding(2507, 256)
  (lstm): LSTM(256, 256, num_layers=3, batch_first=True)
  (classifier): Sequential(
    (0): Linear(in_features=256, out_features=512, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=512, out_features=2048, bias=True)
    (3): ReLU(inplace=True)
    (4): Linear(in_features=2048, out_features=2507, bias=True)
  )
)

## 6. 训练

In [48]:
def train(data_loader,model,epochs,device,lr):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(),lr=lr)
    train_loss = []
    for epoch in range(epochs):
        model.train()
        train_loss_epoch = 0.0
        for batch_idx, data in enumerate(data_loader):
            input_x, target = data[:,:-1].to(device),data[:,1:].to(device)
            target = target.reshape(-1)
            optimizer.zero_grad()
            output = model(input_x)
            loss = criterion(output,target)
            train_loss_epoch += loss.item()
            loss.backward()
            optimizer.step()
            if batch_idx % 10 == 0:
                print('Train Epoch:{}[{}/{}({:.0f}%)]\tLoss:{:.6f}'.format(
                    epoch, batch_idx * len(input_x), len(data_loader.dataset),
                    100. * batch_idx / len(data_loader), loss.item()))
        train_loss_epoch /= len(data_loader)
        train_loss.append(train_loss_epoch)
        print('Train Epoch:{}\t average loss:{:.6f}'.format(epoch,train_loss_epoch))
    return train_loss
                
    

In [49]:
train_loss = train(data_loader,model,config.num_epochs,device,config.lr)



## 7. 生成

###  7.1 给定首句，生成

In [2]:
def generate(model,start_words,id2word,word2id,max_gen_len,device):
    res = list(start_words) #先将首句添加进诗中
    sent_len = len(start_words) #每句的长度

    # 设置第一个词为<START>
    input_x = torch.Tensor([word2ix['<START>']]).view(1, 1).long()
    input_x = input_x.to(DEVICE)
    for i in range(max_gen_len):
        output = model(input_x)
        if i < start_word_len:
            w = results[i]
            input_x = input_x.data.new([word2ix[w]]).view(1, 1)
        # 生成后面的句子
        else:
            top_index = output.data[0].topk(1)[1][0].item()
            w = id2word[top_index]
            res.append(w)
            input_x = input_x.data.new([top_index]).view(1,1)

        if w == '<EOP>':
            del res[-1]
            break
    return res

In [None]:
start_words = '春江花月夜'
max_gen_len = config.max_gen_len

res = generate(model,start_words,id2word,word2id,max_gen_len,device)
poetry = ''
for word in res:
    poetry += word
    if word == '。' or word == '!':
        poetry += '\n'

print(poetry)

### 7.2 生成藏头诗

In [50]:
def generate_head(model,device,max_gen_len,head_sentence,word2id,id2word):
    """
    生成藏头诗
    """
    res = []
    poetry_len = len(head_sentence) #要生成的句子的数量
    index = 0 #当前句子的数量
    pre_word = '<START>' # 前一个已经生成的字
    
    # 准备第一步输入的数据
    input_x = (torch.Tensor([word2id['<START>']]).view(1, 1).long()).to(device)
    input_x = input_x.to(device)
    
    for i in range(max_gen_len):
            # 前向计算出概率最大的当前词
            output= model(input)
            top_index = output.data[0].topk(1)[1][0].item()
            w = id2word[top_index]

            # 句首的字用藏头字代替
            if pre_word in {u'。', u'！', '<START>'}:
                if index == poetry_len:
                    break
                else:
                    w = head_sentence[index]
                    index += 1
                    input_x = (input_x.data.new([word2id[w]])).view(1,1)
            else:
                input_x = (input_x.data.new([top_index])).view(1,1)

            res.append(char)
            pre_word = w


    return res

In [51]:
start_words_acrostic = '春江花月夜'  # 唐诗的“头”
max_gen_len_acrostic = 120               # 生成藏诗的最长长度

res = generate_head(model,
                    device,
                    max_gen_len,
                    head_sentence,
                    word2id,
                    id2word)

poetry = ''
for word in res:
    poetry += word
    if word == '。' or word == '!':
        poetry += '\n'

print(poetry)

['春',
 '云',
 '，',
 '望',
 '女',
 '<PAD>',
 '江',
 '取',
 '。',
 '花',
 '棹',
 '里',
 '。',
 '月',
 '。',
 '夜',
 '师',
 '制',
 '，',
 '望',
 '女',
 '<PAD>']