## 简介

In [73]:
"""
在pytorch tutorial的某个例子的基础上作了改动，可以运行，可以作为参考
暂时还是用的例子里面的attention方式,已经实现了bidirectional LSTM
尚未仔细调参，效果待定
"""

'\n在pytorch tutorial的某个例子的基础上作了改动，可以运行，可以作为参考\n暂时还是用的例子里面的attention方式,已经实现了bidirectional LSTM\n尚未仔细调参，效果待定\n'

## 包导入与常量定义

In [74]:
import torch
import torch.nn as nn
import torch.nn.functional as F 
import torch.optim as optim
import os
import random
import numpy as np
import time
import math
torch.manual_seed(1)
random.seed(1)


In [75]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


## 数据处理

In [76]:
root_path = "./rnnpg_data_emnlp-2014/partitions_in_Table_2/rnnpg/"  #到数据集的路径，可能根据具体情况修改
BATCH_SIZE=128
LEN = 7 # 用于决定5言还是7言

In [77]:
def get_train_data(fileName):
    """
    @params:
        fileName:文件名，具体应该为"qtrain"

    @return:
        poem_line_lst5:五言绝句列表
        poem_line_lst7:七言绝句列表
        wd2Idx5:适用于五言绝句的wd2Idx映射
        wd2Idx7:适用于七言绝句的wd2Idx映射
        idx2Wd5:适用于五言绝句的idx2Wd映射
        idx2Wd7:适用于七言绝句的idx2Wd映射
        poem_vec_lst5:映射后的五言绝句列表
        poem_vec_lst7:映射后的七言绝句列表
    
    其它:
        暂时没有为每句诗加上<S>和<E>
    """
    poem_line_lst5 = []
    poem_line_lst7 = []

    poem_vec_lst5 = []
    poem_vec_lst7 = []

    vocab5 = []
    vocab7 = []

    with open(root_path + fileName, 'r', encoding='utf-8') as fin:
        for line in fin:
            line = (" ".join(line.strip().split("\t"))).split(" ")
            line = ["<S>"] + line + ["<E>"]
            if len(line) == 22:
                poem_line_lst5.append(line)
                vocab5.extend(line)
            elif len(line) == 30:
                poem_line_lst7.append(line)
                vocab7.extend(line)

    vocab5 = list(set(vocab5))
    vocab7 = list(set(vocab7))
    
    random.shuffle(poem_line_lst5)
    random.shuffle(poem_line_lst7)

    wd2Idx5 = {wd: idx for idx, wd in enumerate(vocab5)}
    wd2Idx7 = {wd: idx for idx, wd in enumerate(vocab7)}

    idx2Wd5 = {idx: wd for idx, wd in enumerate(vocab5)}
    idx2Wd7 = {idx: wd for idx, wd in enumerate(vocab7)}

    poem_vec_lst5 = [[wd2Idx5[wd] for wd in line] for line in poem_line_lst5]
    poem_vec_lst7 = [[wd2Idx7[wd] for wd in line] for line in poem_line_lst7]

    print(len(poem_line_lst5), len(poem_line_lst7))
    print(len(wd2Idx5), len(wd2Idx7))
    print(len(poem_vec_lst5), len(poem_vec_lst7))

    return poem_line_lst5, poem_line_lst7, wd2Idx5, wd2Idx7, idx2Wd5, idx2Wd7,poem_vec_lst5, poem_vec_lst7


poem_line_lst5, poem_line_lst7, wd2Idx5, wd2Idx7, idx2Wd5, idx2Wd7, poem_vec_lst5, poem_vec_lst7 = get_train_data( 
    "qtrain")

11274 63535
5260 6742
11274 63535


In [78]:
def get_batch(data,bat,sent_len):
    """
    @params:
        data:待划分的数据集
        bat:BATCH_SIZE
        sent_len:单句长度
    
    @returns:
        X_batch:shape: len(data)//bat,bat,seq_len,其中seq_len包含四句诗
        Y_batch:shape: len(data)//bat,bat,seq_len,其中seq_len包含后三句诗
    """
    X_batch = []
    Y_batch = []
    for idx in range(len(data)//bat):
        st = idx * bat
        ed = st + bat
        X_batch.append([vec[:sent_len] for vec in data[st:ed]])
        Y_batch.append([vec[sent_len:] for vec in data[st:ed]])
    X_batch = torch.tensor(X_batch,device=device)
    Y_batch = torch.tensor(Y_batch,device=device)
    
    return X_batch,Y_batch

X_batch,Y_batch = get_batch(poem_vec_lst7,BATCH_SIZE,LEN+1)

In [79]:
print(X_batch.shape)
print(X_batch.size(0))
# print(X_batch[0].permute(1,0))

torch.Size([496, 128, 8])
496


## 时间处理函数

In [80]:
def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))

## Encoder 模块

In [81]:
class Encoder(nn.Module):
    def __init__(self,input_size,hidden_size,vec_dim,num_layer):
        super(Encoder,self).__init__()
        self.hidden_size = hidden_size
        self.vec_dim = vec_dim
        self.embedding = nn.Embedding(input_size,vec_dim)
        self.lstm = nn.LSTM(vec_dim,hidden_size,num_layers=num_layer,bidirectional=True)
        self.num_layer = num_layer
        self.num_dir = 1 if self.lstm.bidirectional == False else 2

    def forward(self,input,hidden):
        """
        @params:
            input:(seq_len,batch)
            hidden=(hn,cn):(num_layers*num_dirs,batch,hidden_size)*2
        """
        seq_len,batch = input.size()

        embedded = self.embedding(input).view(seq_len,batch,-1) 
        output = embedded  # output:(seq_len,batch,vec_dim)
        output,hidden = self.lstm(output,hidden) # output:(seq_len,batch,num_dir*hidden_size)
                                                # hidden:(num_layer*num_dir,batch,hidden_size)*2
        output = output[:,:,:self.hidden_size]+output[:,:,self.hidden_size:]
        hn,cn = hidden
        hn = hn.view(self.num_layer,self.num_dir,batch,self.hidden_size)
        hn = hn[:,0,:,:] + hn[:,1,:,:]

        cn = cn.view(self.num_layer,self.num_dir,batch,self.hidden_size)
        cn = cn[:,0,:,:] + cn[:,1,:,:]
        # output:seq_len,batch,hidden_size
        # hidden=(hc,cn):(num_layer,batch,hidden_size)*2
        return output,(hn,cn)

    def initHidden(self,bat):
        """
        @params
            bat:batch参数
        """
        h0 = torch.zeros(self.num_layer*self.num_dir, bat, self.hidden_size, device=device)
        c0 = torch.zeros(self.num_layer*self.num_dir, bat, self.hidden_size, device=device)
        return (h0,c0) 
        


## 带attention机制的Decoder模块

In [82]:
class Decoder(nn.Module):
    def __init__(self,input_size,hidden_size,vec_dim,num_layer,dropout_p):
        super(Decoder,self).__init__()
        self.hidden_size = hidden_size
        self.vec_dim = vec_dim
        self.embedding = nn.Embedding(input_size,vec_dim)
        self.encode_seq_len = LEN+1
        self.dropout_p = dropout_p
        self.input_size = input_size
        self.num_layer = num_layer

        self.lstm = nn.LSTM(vec_dim,hidden_size,num_layers=num_layer,bidirectional=False)
        self.attn = nn.Linear(self.hidden_size+self.vec_dim,self.encode_seq_len)
        self.attn_combine = nn.Linear(self.hidden_size+self.vec_dim,self.vec_dim)
        self.dropout = nn.Dropout(self.dropout_p)
        self.out = nn.Linear(self.hidden_size,self.input_size)

        self.num_dir = 1 if self.lstm.bidirectional == False else 2

    def forward(self,input,hidden,encoder_outputs):
        """
        @params:
            encoder_outputs:encode_seq_len,batch,num_dir*hidden_size
            hidden=(hc,cn):(num_layer*num_dir,batch,hidden_size)*2
            input:seq_len,batch
        """
        seq_len,batch = input.size()  # when decoding ,we let seq_len = 1

        embedded = self.embedding(input).view(seq_len,batch,-1)
        embedded = self.dropout(embedded)      # embedded:1,batch,vec_dim

        attn_weights = F.softmax(              # attn_weights:batch,encode_seq_len
            self.attn(torch.cat((embedded[0], hidden[0][0]), 1)), dim=1) 
        attn_applied = torch.bmm(attn_weights.unsqueeze(1),
                                encoder_outputs.permute(1,0,2).contiguous()) 
        # so far,shape of attn_applied:batch,1,hidden_size
        attn_applied = attn_applied.permute(1,0,2).contiguous()

        output = torch.cat((embedded[0], attn_applied[0]), 1)
        output = self.attn_combine(output).unsqueeze(0)
        output = F.relu(output)
        # output:1,batch,vec_dim
        output,hidden = self.lstm(output,hidden)
        # output:1,batch,vec_dim
        logits = self.out(output)  # logits:1,batch,input_size
        logits = logits.view(-1,self.input_size) #logits:(-1,input_size)
        
        return logits,hidden,attn_weights
    
    def initHidden(self,bat):
        h0 = torch.zeros(self.num_dir * self.num_layer, bat, self.hidden_size, device=device)
        c0 = torch.zeros(self.num_dir * self.num_layer, bat, self.hidden_size, device=device)
        return (h0,c0)


## train 模块

In [83]:
def train(input_tensor,target_tensor,encoder,decoder,encoder_optimizer,decoder_optimizer,criterion,wd2Idx):
    """
    @params
        input_tensor:batch,seq_len
    """

    input_tensor = input_tensor.permute(1,0).contiguous()
    target_tensor = target_tensor.permute(1,0).contiguous()

    encoder_hidden = encoder.initHidden(input_tensor.size()[1])

    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    input_len = input_tensor.size(0)
    target_len = target_tensor.size(0)

    encoder_outputs =  torch.zeros(LEN+1,encoder.hidden_size,device = device) # 单向、batch=1
    loss = 0

    encoder_outputs,encoder_hidden = encoder(input_tensor,encoder_hidden)
    # encoder_outputs:encode_seq_len,batch,num_dir*hidden_size
    # encoder_hidden:num_layer*num_dir,batch,hidden_size
    
    decoder_input = torch.tensor([wd2Idx["<S>"]]*BATCH_SIZE,device=device).view(1,BATCH_SIZE)
    decoder_hidden = encoder_hidden

    # Teacher forcing
    for di in range(target_len):
        decoder_output,decoder_hidden,decoder_attention = decoder(
            decoder_input,decoder_hidden,encoder_outputs
        )
        loss += criterion(decoder_output,target_tensor[di])
        decoder_input = target_tensor[di].view(1,-1)
    loss.backward()
    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.item()/target_len

## trainIters 模块

In [84]:
def trainIters(encoder, decoder, wd2Idx,epoch,print_every=100, plot_every=100, learning_rate=0.01):

    global X_batch,Y_batch

    start = time.time()
    plot_losses = []
    print_loss_total = 0  # Reset every print_every
    plot_loss_total = 0  # Reset every plot_every

    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)

    criterion = nn.CrossEntropyLoss()
    batch_len = len(X_batch)
    for ep in range(epoch):
        print("epoch:{}".format(ep))
        for iter in range(0, batch_len):
            input_tensor = X_batch[iter]
            target_tensor = Y_batch[iter]

            loss = train(input_tensor, target_tensor, encoder,
                        decoder, encoder_optimizer, decoder_optimizer, criterion,wd2Idx)
            print_loss_total += loss
            plot_loss_total += loss

            if iter % print_every == 0:
                print_loss_avg = print_loss_total / print_every
                print_loss_total = 0
                print('%s (%d %d%%) %.4f' % (timeSince(start, (iter+1) / batch_len),
                                            iter+1, (iter+1) / batch_len * 100, print_loss_avg))

            if iter % plot_every == 0:
                plot_loss_avg = plot_loss_total / plot_every
                plot_losses.append(plot_loss_avg)
                plot_loss_total = 0

    # showPlot(plot_losses)

In [85]:
hidden_size = 256
vec_dim = 200
num_layer = 1

In [86]:
encoder = Encoder(len(wd2Idx7), hidden_size,vec_dim,num_layer).to(device)
decoder = Decoder(len(wd2Idx7), hidden_size,vec_dim,num_layer,dropout_p=0.1).to(device)

In [87]:
trainIters(encoder, decoder,wd2Idx7,10,print_every=5)

epoch:0
0m 0s (- 2m 14s) (1 0%) 1.7638
0m 1s (- 2m 26s) (6 1%) 7.5657
0m 3s (- 2m 24s) (11 2%) 7.1505
0m 4s (- 2m 22s) (16 3%) 7.0559
0m 6s (- 2m 20s) (21 4%) 6.9869
0m 7s (- 2m 19s) (26 5%) 6.9128
0m 9s (- 2m 17s) (31 6%) 6.8822
0m 10s (- 2m 15s) (36 7%) 6.8506
0m 11s (- 2m 13s) (41 8%) 6.8084
0m 13s (- 2m 11s) (46 9%) 6.7811
0m 14s (- 2m 9s) (51 10%) 6.7550
0m 16s (- 2m 7s) (56 11%) 6.8223
0m 17s (- 2m 6s) (61 12%) 6.7848
0m 19s (- 2m 4s) (66 13%) 6.7626
0m 20s (- 2m 2s) (71 14%) 6.7480
0m 21s (- 2m 1s) (76 15%) 6.7182
0m 23s (- 1m 59s) (81 16%) 6.7333
0m 24s (- 1m 57s) (86 17%) 6.7087
0m 26s (- 1m 56s) (91 18%) 6.6663
0m 27s (- 1m 55s) (96 19%) 6.6634
0m 29s (- 1m 53s) (101 20%) 6.6681
0m 30s (- 1m 52s) (106 21%) 6.6412
0m 31s (- 1m 50s) (111 22%) 6.5857
0m 33s (- 1m 49s) (116 23%) 6.5706
0m 34s (- 1m 47s) (121 24%) 6.5818
0m 36s (- 1m 46s) (126 25%) 6.6142
0m 37s (- 1m 44s) (131 26%) 6.5296
0m 38s (- 1m 43s) (136 27%) 6.5576
0m 40s (- 1m 41s) (141 28%) 6.5234
0m 41s (- 1m 40s) (146

KeyboardInterrupt: 

## Generate with Beam Search

In [88]:
class trace:
    def __init__(self):
        self.poem = ["<S>"]
        self.hidden = None
        self.posb = 0

In [89]:
def generate(encoder,decoder,wd2Idx,idx2Wd,input_tensor):
    """
    @params:
        input_tensor:(1,seq_len)  已经向量化了
    """
    encoder.eval()
    decoder.eval()

    input_tensor =  input_tensor.permute(1,0).contiguous()

    encoder_hidden = encoder.initHidden(1)
    encoder_outputs = torch.zeros(LEN+1,encoder.hidden_size, device=device)
    encoder_outputs,encoder_hidden = encoder(input_tensor,encoder_hidden)

    beam = [trace()]
    beam[0].hidden = encoder_hidden

    k=5
    for _ in range(4*LEN):
        btmp = []
        for tce in beam:
            inputs = torch.tensor([wd2Idx[tce.poem[-1]]]).view(1,1).to(device)
            outputs,hidden,attention = decoder(inputs,tce.hidden,encoder_outputs)
            topk = torch.topk(F.softmax(outputs[0]),k)
            for i in range(k):
                nxt = trace()
                nxt.poem = tce.poem+[idx2Wd[topk[1][i].item()]]
                nxt.hidden = hidden
                nxt.posb = tce.posb + np.log(topk[0][i].item())
                btmp.append(nxt)
        beam = []
        for _ in range(k):
            posMax = -1e6
            idxMax = 0
            for idx,tce in enumerate(btmp):
                if tce.posb - posMax > 1e-6:
                    posMax = tce.posb
                    idxMax = idx
            beam.append(btmp[idxMax])
            btmp.remove(btmp[idxMax])
        
    return beam
 

In [90]:
input_tensor = X_batch[3][0].view(1,-1)
beam_res = generate(encoder,decoder,wd2Idx7,idx2Wd7,input_tensor)


In [91]:
encoded_words = [idx2Wd7[idx.item()] for idx in input_tensor[0]]
for idx,each in enumerate(beam_res):
    print("No.",idx,sep="")
    poem_lst = each.poem
    poem_lst = poem_lst[1:poem_lst.index("<E>")]
    print("".join(encoded_words[1:]))
    for i in range(len(poem_lst)):
        print(poem_lst[i],end="")
        if (i+1)%7 == 0:
            print("")
    print("")
    print("posb:",each.posb)

No.0
晴日东山饱看花
不知何处有人家
不知此地无消息
只有春风一夜来

posb: -49.7628253349643
No.1
晴日东山饱看花
不知何处有人家
不知此地无消息
只有春风一片云

posb: -49.870424454870495
No.2
晴日东山饱看花
不知何处有人家
不知此地无消息
只有春风一片花

posb: -50.053576593636784
No.3
晴日东山饱看花
不知何处有人家
不知此地无消息
只有春风不肯知

posb: -50.059569612965994
No.4
晴日东山饱看花
不知何处有人家
不知此地无消息
只有人间不肯知

posb: -50.40054312784527


In [29]:
a = torch.tensor([[[1,2,3],[3,4,5]],[[1,2,5],[6,7,8]]])
print(a)

tensor([[[1, 2, 3],
         [3, 4, 5]],

        [[1, 2, 5],
         [6, 7, 8]]])


In [30]:
a = 1
print("a={}".format(a))

a=1
