In [3]:
import torch
import torch.nn as nn
import numpy as np

In [4]:
class Deep_NMT(nn.Module):
    """Sequence-to-sequence translation model built from stacked LSTMs.

    A multi-layer LSTM encoder reads the embedded source sequence. Its final
    (h, c) state for every layer initialises a decoder LSTM of the same depth
    and width. A linear layer projects decoder outputs onto the target
    vocabulary.

    Args:
        source_vocab_size: number of tokens in the source vocabulary.
        target_vocab_size: number of tokens in the target vocabulary.
        embedding_size: dimensionality of both embedding tables.
        source_length: nominal source sequence length. Stored for reference
            only; the encoder accepts sequences of any length.
        target_length: number of greedy decoding steps run in inference mode.
        lstm_size: hidden size of the encoder and decoder LSTMs.
        num_layers: number of stacked layers in the encoder and the decoder.
    """

    def __init__(self, source_vocab_size=3000, target_vocab_size=3000,
                 embedding_size=32, source_length=100, target_length=100, lstm_size=128,
                 num_layers=4):
        super().__init__()
        self.source_length = source_length
        self.target_length = target_length
        self.source_embedding = nn.Embedding(source_vocab_size, embedding_size)
        self.target_embedding = nn.Embedding(target_vocab_size, embedding_size)
        # Encoder and decoder share num_layers and hidden size so the encoder's
        # final (h, c) can be handed to the decoder as its initial state.
        self.encoder = nn.LSTM(input_size=embedding_size, hidden_size=lstm_size,
                               num_layers=num_layers, batch_first=True)
        self.decoder = nn.LSTM(input_size=embedding_size, hidden_size=lstm_size,
                               num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(lstm_size, target_vocab_size)

    def forward(self, source_data, target_data, mode="train"):
        """Encode ``source_data`` and decode target tokens.

        Args:
            source_data: LongTensor of source token ids, (batch, src_len).
            target_data: LongTensor of target token ids. In "train" mode this is
                the full (batch, tgt_len) target sequence (teacher forcing).
                Otherwise it is the decoder's first input, presumably (batch, 1)
                start tokens.
            mode: "train" for teacher-forced decoding. Any other value runs
                greedy decoding for ``self.target_length`` steps.

        Returns:
            In "train" mode, logits of shape (batch, tgt_len, target_vocab_size).
            Otherwise, a list of ``self.target_length`` numpy arrays of predicted
            token ids, each of shape (batch,) when target_data has one time step.
        """
        source_embedded = self.source_embedding(source_data)  # (batch, src_len, emb)
        # The first output (top layer's hidden state at every step) is unused.
        # enc_hidden is a tuple (h, c), each (num_layers, batch, lstm_size),
        # holding every layer's state at the last time step.
        _, enc_hidden = self.encoder(source_embedded)

        if mode == "train":
            # Teacher forcing: the whole target sequence is decoded in one pass.
            target_embedded = self.target_embedding(target_data)  # (batch, tgt_len, emb)
            dec_output, _ = self.decoder(target_embedded, enc_hidden)  # (batch, tgt_len, lstm_size)
            return self.fc(dec_output)  # (batch, tgt_len, target_vocab_size)

        # Greedy decoding: each step feeds back the previous step's argmax token.
        # No gradients are needed, so skip building the autograd graph.
        outs = []
        with torch.no_grad():
            dec_input = self.target_embedding(target_data)  # (batch, 1, emb)
            dec_hidden = enc_hidden
            for _ in range(self.target_length):
                dec_output, dec_hidden = self.decoder(dec_input, dec_hidden)
                pred = torch.argmax(self.fc(dec_output), dim=-1)  # (batch, 1)
                # Squeeze only the time axis so a batch of size 1 keeps its batch dim.
                outs.append(pred.squeeze(1).cpu().numpy())
                dec_input = self.target_embedding(pred)
        return outs

In [5]:
# Smoke test: run both modes on all-zero token ids and check output shapes.
torch.manual_seed(0)  # deterministic weight init so re-runs are reproducible
BATCH_SIZE = 64
SOURCE_LENGTH = 100
TARGET_LENGTH = 100  # equals Deep_NMT's default target_length (greedy decoding steps)
TARGET_VOCAB_SIZE = 3000  # Deep_NMT's default target_vocab_size

model = Deep_NMT()

# Training mode (teacher forcing): the full target sequence is fed at once.
source_data = torch.zeros(BATCH_SIZE, SOURCE_LENGTH, dtype=torch.long)
target_data = torch.zeros(BATCH_SIZE, TARGET_LENGTH, dtype=torch.long)
train_preds = model(source_data, target_data, mode="train")
print(train_preds.shape)
assert train_preds.shape == (BATCH_SIZE, TARGET_LENGTH, TARGET_VOCAB_SIZE)

# Inference mode: only one start token per sequence is supplied.
target_data = torch.zeros(BATCH_SIZE, 1, dtype=torch.long)
test_preds = model(source_data, target_data, mode="test")
print(np.array(test_preds).shape)
assert np.array(test_preds).shape == (TARGET_LENGTH, BATCH_SIZE)

torch.Size([64, 100, 3000])
(100, 64)
