In [None]:
import torch
import torch.nn as nn
import numpy as np
from torch.nn import functional as F
from torch.autograd import Variable

In [1]:
class Attention_NMT(nn.Module):
    """Attention-based encoder-decoder for neural machine translation.

    The encoder is a single-layer bidirectional LSTM over source-token
    embeddings. The decoder is a single-layer unidirectional LSTM whose input
    at every step is the current target-token embedding concatenated with an
    attention context vector. The context is computed by an MLP that scores
    [previous decoder hidden state; encoder output] pairs.

    Shape legend used in comments: B = batch, S = source length,
    T = target length, E = embedding_size, H = lstm_size,
    V = target_vocab_size.

    Args:
        source_vocab_size: Number of source-side token ids.
        target_vocab_size: Number of target-side token ids (output classes).
        embedding_size: Dimension of source and target embeddings.
        source_length: Nominal source length. Kept for backward
            compatibility; the real source length is read from the input.
        target_length: Number of greedy decoding steps performed when
            ``forward`` is called with ``mode != "train"``.
        lstm_size: Hidden size of each LSTM direction.
        batch_size: Kept for backward compatibility; the batch size is read
            from the input.
    """

    def __init__(self, source_vocab_size=3000, target_vocab_size=3000, embedding_size=128,
                 source_length=100, target_length=100, lstm_size=256, batch_size=64):
        super(Attention_NMT, self).__init__()
        self.source_length = source_length
        self.target_length = target_length

        self.source_embedding = nn.Embedding(source_vocab_size, embedding_size)
        self.target_embedding = nn.Embedding(target_vocab_size, embedding_size)
        # E -> 2H (both directions concatenated).
        self.encoder = nn.LSTM(input_size=embedding_size, hidden_size=lstm_size, num_layers=1,
                               bidirectional=True, batch_first=True)
        # Decoder input is [target embedding; attention context] = E + 2H -> H.
        self.decoder = nn.LSTM(input_size=embedding_size + 2 * lstm_size, hidden_size=lstm_size,
                               num_layers=1, batch_first=True)

        # Attention MLP over [s_{t-1} (H); h_j (2H)] = 3H -> 3H -> 1 score.
        self.attention_fc_1 = nn.Linear(3 * lstm_size, 3 * lstm_size)
        self.attention_fc_2 = nn.Linear(3 * lstm_size, 1)
        # Classifier over [embedding (E); context (2H); decoder output (H)] -> 2H -> V.
        self.class_fc_1 = nn.Linear(embedding_size + 2 * lstm_size + lstm_size, 2 * lstm_size)
        self.class_fc_2 = nn.Linear(2 * lstm_size, target_vocab_size)

    def attention_forward(self, input_embedding, dec_prev_hidden, enc_output):
        """Run one decoder step with attention.

        query: s_{t-1} (previous decoder hidden state);
        key = value = h_j (encoder outputs).

        Args:
            input_embedding: Current decoder input embedding, (B, 1, E).
            dec_prev_hidden: Previous decoder state ``(h, c)``, each
                (1, B, H). A list is also accepted.
            enc_output: Encoder outputs, (B, S, 2H).

        Returns:
            ``(atten_output, dec_output, dec_hidden)`` with shapes
            (B, 1, 2H), (B, 1, H) and ``(h, c)`` each (1, B, H).
        """
        src_len = enc_output.size(1)
        # s_{t-1}: take the last layer explicitly (not .squeeze(), which would
        # also drop the batch axis when B == 1), then broadcast over the source
        # positions: (B, H) -> (B, 1, H) -> (B, S, H).
        prev_dec_h = dec_prev_hidden[0][-1].unsqueeze(1).expand(-1, src_len, -1)
        # e_ij = a(s_{i-1}, h_j): score every source position with the MLP.
        atten_input = torch.cat([prev_dec_h, enc_output], dim=-1)  # B, S, 3H
        scores = self.attention_fc_2(F.relu(self.attention_fc_1(atten_input)))  # B, S, 1
        # a_ij = softmax_j(e_ij), normalised over the source positions.
        attention_weights = F.softmax(scores, dim=1)  # B, S, 1
        # c_i = sum_j a_ij * h_j
        atten_output = torch.sum(attention_weights * enc_output, dim=1, keepdim=True)  # B, 1, 2H

        dec_lstm_input = torch.cat([input_embedding, atten_output], dim=2)  # B, 1, E + 2H
        # nn.LSTM documents hx as a (h_0, c_0) tuple.
        dec_output, dec_hidden = self.decoder(dec_lstm_input, tuple(dec_prev_hidden))
        return atten_output, dec_output, dec_hidden

    def _init_decoder_hidden(self, enc_hidden):
        """Seed the decoder with the encoder's final forward-direction (h, c)."""
        enc_h, enc_c = enc_hidden  # each (2, B, H): index 0 = forward, 1 = backward
        # Slice (not index) so the leading num_layers axis of size 1 is kept.
        return enc_h[0:1], enc_c[0:1]

    def _classify(self, embedding, atten_output, dec_output):
        """Map [embedding; context; decoder output] (..., E + 3H) to logits (..., V)."""
        class_input = torch.cat([embedding, atten_output, dec_output], dim=2)
        return self.class_fc_2(F.relu(self.class_fc_1(class_input)))

    def forward(self, source_data, target_data, mode="train", is_gpu=True):
        """Encode the source and decode with attention.

        Args:
            source_data: Source token ids, (B, S).
            target_data: Target token ids. In ``"train"`` mode this is
                (B, T) and used for teacher forcing. Otherwise it is expected
                to be (B, 1), holding the first decoder input token.
            mode: ``"train"`` for teacher-forced decoding over all T target
                positions. Any other value runs greedy decoding for
                ``self.target_length`` steps.
            is_gpu: Kept for backward compatibility and ignored. Every tensor
                is derived from the inputs and module parameters, so it
                already lives on the model's device.

        Returns:
            ``"train"``: logits of shape (B, T, V).
            otherwise: a list of ``self.target_length`` numpy arrays of shape
            (B,), holding the arg-max token id chosen at each step.

        Side effects:
            Sets ``self.atten_outputs`` (B, steps, 2H) and
            ``self.dec_outputs`` (B, steps, H) to the per-step attention
            contexts and decoder outputs of this call.
        """
        source_data_embedding = self.source_embedding(source_data)  # B, S, E
        # enc_output: hidden states of both directions, concatenated: B, S, 2H.
        # enc_hidden: (h_n, c_n), each (2, B, H), the last step of each direction.
        enc_output, enc_hidden = self.encoder(source_data_embedding)
        dec_prev_hidden = self._init_decoder_hidden(enc_hidden)
        atten_steps, dec_steps = [], []

        if mode == "train":
            target_data_embedding = self.target_embedding(target_data)  # B, T, E
            # Teacher forcing: step i consumes ground-truth token i. Iterate
            # over the real target length rather than a fixed count.
            for i in range(target_data_embedding.size(1)):
                input_embedding = target_data_embedding[:, i:i + 1, :]  # B, 1, E
                atten_output, dec_output, dec_prev_hidden = self.attention_forward(
                    input_embedding, dec_prev_hidden, enc_output)
                atten_steps.append(atten_output)
                dec_steps.append(dec_output)
            # Concatenating per-step outputs avoids in-place writes into
            # preallocated buffers and keeps everything on the model's device.
            self.atten_outputs = torch.cat(atten_steps, dim=1)  # B, T, 2H
            self.dec_outputs = torch.cat(dec_steps, dim=1)      # B, T, H
            # B, T, E + 2H + H -> B, T, 2H -> B, T, V
            return self._classify(target_data_embedding, self.atten_outputs, self.dec_outputs)

        # Greedy decoding: each step's arg-max token becomes the next input.
        input_embedding = self.target_embedding(target_data)  # B, 1, E
        outs = []
        for _ in range(self.target_length):
            atten_output, dec_output, dec_prev_hidden = self.attention_forward(
                input_embedding, dec_prev_hidden, enc_output)
            atten_steps.append(atten_output)
            dec_steps.append(dec_output)
            logits = self._classify(input_embedding, atten_output, dec_output)  # B, 1, V
            pred = torch.argmax(logits, dim=-1)  # B, 1
            # Index the step axis explicitly so B == 1 still yields shape (B,).
            outs.append(pred[:, 0].cpu().numpy())
            input_embedding = self.target_embedding(pred)  # B, 1, E
        if atten_steps:
            self.atten_outputs = torch.cat(atten_steps, dim=1)
            self.dec_outputs = torch.cat(dec_steps, dim=1)
        return outs

NameError: name 'nn' is not defined

In [6]:
# Smoke test: run both decoding modes on dummy all-zero token ids and check
# the output shapes against the model's default sizes.
torch.manual_seed(0)  # deterministic weight init so the run is reproducible

model = Attention_NMT()
# Build integer token ids directly instead of going through a float numpy array.
source_data = torch.zeros(64, 100, dtype=torch.long)  # batch 64, source length 100
target_data = torch.zeros(64, 100, dtype=torch.long)  # batch 64, target length 100
# Teacher-forced pass: logits of shape (batch, target_len, target_vocab_size).
preds = model(source_data, target_data, is_gpu=False)
assert preds.shape == (64, 100, 3000)
print(preds.shape)

# Greedy decoding starts from a single token per sentence; no gradients needed.
target_data = torch.zeros(64, 1, dtype=torch.long)
with torch.no_grad():
    preds = model(source_data, target_data, mode="test", is_gpu=False)
# One (batch,) array of token ids per decoding step.
assert np.array(preds).shape == (100, 64)
print(np.array(preds).shape)

torch.Size([64, 100, 896])
torch.Size([64, 100, 3000])
(100, 64)
