In [1]:
import torch
import torch.nn as nn
from Transformer import Transformer
from processData import collate_fn, PAD_IDX, BOS_IDX, EOS_IDX, vocab_transform
from torchtext.datasets import Multi30k
from torch.utils.data import DataLoader



In [2]:
# --- Experiment configuration -------------------------------------------------
# Language pair: translate from SRC_LANGUAGE into TGT_LANGUAGE.
SRC_LANGUAGE = 'de'
TGT_LANGUAGE = 'en'

# Vocabulary sizes are derived from the language constants above so that changing
# the language pair in one place keeps the model dimensions consistent.
src_vocab_size = len(vocab_transform[SRC_LANGUAGE])
tgt_vocab_size = len(vocab_transform[TGT_LANGUAGE])

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Training schedule.
epochs = 12
batch_size = 128

# Model hyper-parameters; these are forwarded positionally to Transformer(...)
# when the model is built.
d_model = 512
n_head = 8
d_ff = 2048
n_layers = 3

In [10]:
# Build the network from the configured sizes, then move its parameters to the
# selected device.
model = Transformer(
    src_vocab_size,
    tgt_vocab_size,
    d_model,
    n_head,
    d_ff,
    n_layers,
)
model = model.to(device)

# Targets equal to PAD_IDX are ignored when the loss is computed.
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

# Adam over every model parameter with a fixed learning rate.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

In [None]:
# Training loop: one pass over the Multi30k training split per epoch.
#
# Explicitly select training mode so that re-running this cell after an inference
# cell (which may have switched the model to eval mode) still trains with any
# train-only layers the model has (e.g. dropout) enabled.
model.train()

for epoch in range(epochs):
    # The dataset and loader are rebuilt at the start of every epoch.
    # NOTE(review): presumably because the torchtext iterator is single-pass — confirm.
    train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    train_dataloader = DataLoader(train_iter, batch_size=batch_size, collate_fn=collate_fn)

    for i, (enc_inputs, dec_inputs) in enumerate(train_dataloader):
        enc_inputs, dec_inputs = enc_inputs.to(device), dec_inputs.to(device)

        # Teacher forcing: the decoder is fed the target without its last token and
        # is scored against the target shifted left by one position.
        # The attention maps returned by the model are not needed for training.
        outputs, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs[:, :-1])

        # The target is flattened to 1-D to pair with `outputs`.
        # NOTE(review): this assumes the model already returns logits flattened to
        # (batch * seq_len, vocab) — confirm against Transformer.forward.
        loss = criterion(outputs, dec_inputs[:, 1:].contiguous().view(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report the current batch loss every 50 steps.
        if (i + 1) % 50 == 0:
            print(f'Epoch: {epoch + 1}, loss = {loss.item()}')

In [5]:
# Persist only the model's state dict (not the whole module object) to disk.
torch.save(obj=model.state_dict(), f='./model.pth')

In [15]:
# Restore the saved parameters. `map_location=device` remaps the stored tensors onto
# the device selected in the config cell, so a checkpoint written during a GPU run
# can still be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
model.load_state_dict(torch.load('./model.pth', map_location=device))

<All keys matched successfully>

In [16]:
def greedy_decoder(model: nn.Module, enc_input: torch.Tensor, start_symbol: int, end_symbol: int,
                   max_len: int = 50) -> torch.Tensor:
    """Greedily decode a target sequence for a single source sentence.

    The decoder input starts as ``[start_symbol]`` and grows by one token per step:
    at every step the whole prefix is run through the decoder and the highest-scoring
    token at the last position is appended. Decoding stops when ``end_symbol`` is
    predicted or after ``max_len`` decoding steps.

    Args:
        model: Object exposing ``encoder(enc_input)``, ``decoder(dec_input, enc_input,
            enc_outputs)`` and ``projection(dec_outputs)``, plus the ``nn.Module``
            ``training`` / ``eval()`` / ``train(mode)`` API.
        enc_input: Source token ids for one sentence. The decoder input is built with
            a batch dimension of 1, so a single-sentence batch is expected.
        start_symbol: Token id used to start decoding.
        end_symbol: Token id that terminates decoding.
        max_len: Maximum number of decoding steps (default 50, the previous hard-coded cap).

    Returns:
        Tensor of shape ``(1, L)`` with the predicted token ids, on ``enc_input``'s
        device and with its dtype. The start symbol and the end symbol are excluded.
        ``L`` is at most ``max_len``.
    """
    # Work on the same device as the input rather than on a notebook-level global.
    device = enc_input.device

    # Run inference in eval mode and without autograd bookkeeping, then put the model
    # back into whatever mode the caller had it in.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            enc_outputs, _ = model.encoder(enc_input)

            dec_input = torch.full((1, 1), int(start_symbol), dtype=enc_input.dtype, device=device)
            for _step in range(max_len):
                dec_outputs, _, _ = model.decoder(dec_input, enc_input, enc_outputs)
                projected = model.projection(dec_outputs)

                # Incremental decoding: predictions for earlier positions are recomputed
                # on every step but ignored; only the prediction at the newest position
                # is used to choose the next token.
                next_symbol = projected.squeeze(0).argmax(dim=-1)[-1].item()
                if next_symbol == end_symbol:
                    break

                # Append before the next iteration so that the token predicted on the
                # final permitted step is kept instead of being silently dropped.
                next_token = torch.full((1, 1), next_symbol, dtype=enc_input.dtype, device=device)
                dec_input = torch.cat([dec_input, next_token], dim=-1)
    finally:
        model.train(was_training)

    # Drop the leading start symbol.
    greedy_dec_predict = dec_input[:, 1:]
    return greedy_dec_predict

In [17]:
# Translate one sentence from the first test batch and print the reference
# translation followed by the model's greedy prediction.
test_iter = Multi30k(split='test', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
test_dataloader = DataLoader(test_iter, batch_size=batch_size, collate_fn=collate_fn)

# Only the first batch is used; row 11 is the example shown below.
enc_inputs, dec_outputs = next(iter(test_dataloader))
enc_input = enc_inputs[11:12].to(device)
dec_output = dec_outputs[11:12]

greedy_dec_predict = greedy_decoder(model, enc_input, BOS_IDX, EOS_IDX)


def _ids_to_words(token_ids):
    """Map target-language token ids to words, skipping BOS, PAD and EOS ids."""
    tgt_vocab = vocab_transform[TGT_LANGUAGE]
    words = []
    for idx in token_ids:
        if idx == BOS_IDX or idx == PAD_IDX or idx == EOS_IDX:
            continue
        words.append(tgt_vocab.lookup_token(idx))
    return words


w1 = _ids_to_words(dec_output[0])          # reference translation
w2 = _ids_to_words(greedy_dec_predict[0])  # model prediction

print(' '.join(w1))
print(' '.join(w2))

Men playing volleyball , with one player missing the ball but hands still in the air .
Men playing volleyball , a man playing a ball , while the other man is kicking his hands in the air .
