In [29]:
import numpy
import torch
from torch import nn
import utils

In [30]:
# Data-loading configuration: mini-batch size and the per-sentence step count
# handed to the project-local loader.
batch_size, num_steps = 32, 20
# NOTE(review): utils.get_train_iter is project-local; it presumably returns a
# batch iterator plus source/target vocabularies -- confirm against utils.
train_iter, source_Vocab, target_Vocab = utils.get_train_iter(batch_size, num_steps)

In [31]:
class Encoder(nn.Module):
    """LSTM encoder that returns only its final (hidden, cell) state.

    Args:
        vocab_size: number of rows in the embedding table.
        embed_size: embedding dimension fed to the LSTM.
        num_hiddens: LSTM hidden size (also the size of the returned states).
        num_layers: number of stacked LSTM layers.
        bidirectional: when true, run the LSTM in both directions and project
            each layer's concatenated forward/backward state back down to
            ``num_hiddens`` so it matches a unidirectional decoder.
    """

    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, bidirectional=False):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.LSTM(embed_size, num_hiddens, num_layers, bidirectional=bidirectional)
        self.num_layers = num_layers
        self.num_hiddens = num_hiddens
        # Always record the flag: forward() reads it in both modes, and leaving
        # it unset for the unidirectional case raised AttributeError.
        self.bidirectional = bool(bidirectional)
        if self.bidirectional:
            # Every layer has two directions, so their states are merged
            # (2 * num_hiddens -> num_hiddens) by these projections.
            self.linear_hidden = nn.Linear(self.num_hiddens * 2, self.num_hiddens)
            self.linear_content = nn.Linear(self.num_hiddens * 2, self.num_hiddens)

    @staticmethod
    def _merge_directions(state):
        """Concatenate each layer's two direction states along the feature axis.

        ``state`` is indexed as layer0-dir0, layer0-dir1, layer1-dir0, ... on
        axis 0, so even rows are one direction and odd rows the other.
        """
        return torch.cat([state[0::2], state[1::2]], dim=2)

    def forward(self, X):
        """Encode a (batch, steps) tensor of token indices.

        Returns:
            ``(hidden_state, content_state)``, each of shape
            (num_layers, batch, num_hiddens).
        """
        X = self.embedding(X)
        # The LSTM is not batch_first: reorder to (steps, batch, embed).
        X = X.permute(1, 0, 2)
        _, (hidden_state, content_state) = self.rnn(X)
        if self.bidirectional:
            # Join the forward/backward state of each layer, then project so
            # the result matches the decoder's num_hiddens.
            hidden_state = self.linear_hidden(self._merge_directions(hidden_state))
            content_state = self.linear_content(self._merge_directions(content_state))
        return hidden_state, content_state


class Decoder(nn.Module):
    """LSTM decoder whose every input step is augmented with the top-layer
    hidden and cell state it was initialised from."""

    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # Per-step input = token embedding + top-layer hidden + top-layer cell.
        rnn_input_size = embed_size + num_hiddens * 2
        self.rnn = nn.LSTM(rnn_input_size, num_hiddens, num_layers)
        self.linear = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, encoder_output_state):
        """Use the encoder's final state unchanged as the initial decoder state."""
        return encoder_output_state

    def forward(self, X, state, predict=False):
        """Run the decoder.

        With ``predict=False`` ``X`` holds (batch, steps) token indices and the
        LSTM input is assembled here. With ``predict=True`` ``X`` must already
        be the assembled (steps, batch, embed + 2 * hidden) input.

        Returns:
            ``(logits, state)`` with logits shaped (batch, steps, vocab).
        """
        if not predict:
            embedded = self.embedding(X).permute(1, 0, 2)
            steps = embedded.shape[0]
            # The decoder's only view of the source is the encoder's final
            # state, so its last layer is also fed in as input at every step.
            hidden_state, content_state = state
            top_hidden = hidden_state[-1].unsqueeze(0)
            top_content = content_state[-1].unsqueeze(0)
            tiled_hidden = torch.cat([top_hidden for _ in range(steps)], dim=0)
            tiled_content = torch.cat([top_content for _ in range(steps)], dim=0)
            X = torch.cat([tiled_hidden, tiled_content, embedded], dim=2)
        # X is now (num_steps, batch_size, decoder_embed_size + encoder_hidden_num * 2).
        rnn_out, state = self.rnn(X, state)
        logits = self.linear(rnn_out).permute(1, 0, 2)
        return logits, state


def value_mask(X, valid_len):
    """Zero, in place, every position at or beyond each row's valid length.

    Args:
        X: tensor whose first two axes are (batch, steps); modified in place.
        valid_len: 1-D tensor with one entry per row, treated as the number of
            leading positions to keep.

    Returns:
        ``X`` itself, with ``X[i, j]`` set to 0 wherever ``j >= valid_len[i]``.
    """
    positions = torch.arange(X.shape[1], dtype=torch.float32, device=X.device)[None, :]
    # `>=` rather than `>`: index `valid_len` is already the first invalid
    # position, so `>` left one extra (padding) element unmasked per row.
    mask = positions >= valid_len[:, None]
    X[mask] = 0
    return X


class Myloss(nn.CrossEntropyLoss):
    """Cross-entropy over sequences, with positions weighted by ``value_mask``."""

    def forward(self, predict, target, valid_len=None):
        """Compute the mean per-token loss.

        Args:
            predict: logits shaped (batch, steps, vocab).
            target: class indices shaped (batch, steps).
            valid_len: optional per-row lengths passed to ``value_mask``. When
                ``None`` no masking is applied and every position counts.

        Returns:
            Scalar tensor: the mean over all (batch, steps) positions of the
            (optionally masked) per-token loss.
        """
        # Per-token losses are needed so they can be weighted individually.
        self.reduction = 'none'
        # CrossEntropyLoss expects the class axis second: (batch, vocab, steps).
        unweighted_loss = super().forward(predict.permute(0, 2, 1), target)
        if valid_len is None:
            # The declared default used to crash inside value_mask.
            return unweighted_loss.mean()
        weights = value_mask(torch.ones_like(target), valid_len)
        weighted_loss = unweighted_loss * weights
        return weighted_loss.mean()


class EncoderDecoder(nn.Module):
    """Pairs an encoder with a decoder and chains them in ``forward``."""

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder, self.decoder = encoder, decoder

    def forward(self, source, target):
        """Encode ``source``, seed the decoder with the result, decode ``target``.

        Returns whatever the decoder returns.
        """
        initial_state = self.decoder.init_state(self.encoder(source))
        return self.decoder(target, initial_state)

In [32]:
def train(net, data_iter, lr, num_epochs, device, bos_index=2):
    """Train ``net`` with Adam and the masked sequence loss, using teacher forcing.

    Args:
        net: model called as ``net(source, decoder_input)`` and returning
            ``(logits, state)``.
        data_iter: iterable of batches, each unpacking into
            ``(source, source_valid_len, target, target_valid_len)`` tensors.
        lr: Adam learning rate.
        num_epochs: number of passes over ``data_iter``.
        device: device the model and every batch are moved to.
        bos_index: vocabulary index of the ``<bos>`` token prepended to each
            target sentence (defaults to 2, the value previously hard-coded).

    Prints the loss of the last processed batch (``None`` if no batch was seen).
    """
    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    loss = Myloss()
    net.train()
    # Pre-bind so the final print cannot hit an unbound name when
    # num_epochs == 0 or data_iter yields nothing.
    l = None
    for _ in range(num_epochs):
        for batch in data_iter:
            optimizer.zero_grad()
            # Move the batch to the target device.
            source, _source_valid_len, target, target_valid_len = [x.to(device) for x in batch]
            # Teacher forcing: the decoder input is the target shifted right by
            # one, with <bos> in front. Matching target's dtype keeps the
            # concatenation free of dtype promotion.
            bos = torch.full((target.shape[0], 1), bos_index, dtype=target.dtype, device=device)
            decoder_input = torch.cat([bos, target[:, :-1]], dim=1)
            # Optimisation step.
            Y_hat, _ = net(source, decoder_input)
            l = loss(Y_hat, target, target_valid_len)
            l.backward()
            optimizer.step()
    print('训练完毕', l)

In [None]:
# Model hyper-parameters.
encoder_embed_size = 300
decoder_embed_size = 300
hidden_size = 64
num_layers = 2
# The encoder takes its own embedding size; it was previously built with
# decoder_embed_size, leaving encoder_embed_size unused.
encoder = Encoder(len(source_Vocab), encoder_embed_size, hidden_size, num_layers, True)
decoder = Decoder(len(target_Vocab), decoder_embed_size, hidden_size, num_layers)
net = EncoderDecoder(encoder, decoder)
# Training configuration.
num_epoch = 200
lr = 0.005
# Fall back to the CPU so the notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
train(net, train_iter, lr, num_epoch, device)

In [None]:
def predict(net, source_sentence, source_Vocab, target_Vocab, num_steps, device, bos_index=2):
    """Greedily translate one sentence.

    Args:
        net: trained model exposing ``encoder`` and ``decoder``.
        source_sentence: raw source text.
        source_Vocab: source vocabulary; its ``prase`` method turns the
            sentence into an index tensor.
            NOTE(review): assumed to return shape (1, steps) -- confirm.
        target_Vocab: target vocabulary providing ``word_to_index['<eos>']``
            and ``to_word``.
        num_steps: maximum number of tokens to generate.
        device: device on which inference runs.
        bos_index: index of the ``<bos>`` token fed as the first decoder input.
            Defaults to 2, the index the training loop prepends to targets.

    Returns:
        The generated tokens joined by single spaces (``<eos>`` excluded).
    """
    eos_index = target_Vocab.word_to_index['<eos>']
    # Collected target-token indices.
    result = []
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        source = source_Vocab.prase(source_sentence).to(device)
        # Final encoder state.
        state = net.encoder(source)
        # The top-layer hidden/cell state is appended to every decoder input.
        hidden_state, content_state = state
        new_hidden_state = hidden_state[-1].unsqueeze(0)
        new_content_state = content_state[-1].unsqueeze(0)
        # Initial decoder state.
        state = net.decoder.init_state(state)

        def build_input(token):
            """Embed a (1, 1) token index and append the encoder summary."""
            embedded = net.decoder.embedding(token).permute(1, 0, 2)
            return torch.cat([new_hidden_state, new_content_state, embedded], dim=2)

        # Start from <bos>, exactly as in training; starting from <eos> fed the
        # decoder an input it never saw in first position.
        token = torch.tensor(bos_index).reshape(-1, 1).to(device)
        for _ in range(num_steps):
            # predict=True: the input is already embedded and concatenated.
            Y, state = net.decoder(build_input(token), state, True)
            # Greedy choice: index with the highest score.
            token = Y.argmax(dim=2)
            pred = token.squeeze(dim=0).type(torch.int32).item()
            # Stop as soon as <eos> is produced.
            if pred == eos_index:
                break
            result.append(pred)
    return ' '.join(target_Vocab.to_word(result))

In [None]:
# Translate a sample sentence with the trained model; the returned string is
# the cell's displayed output.
predict(
    net,
    'Can anybody stop them?',
    source_Vocab,
    target_Vocab,
    num_steps,
    device,
)