In [1]:
import torch
import torch.nn as nn
from Transformer import Transformer
from processData import collate_fn, PAD_IDX, vocab_transform
from torchtext.datasets import Multi30k
from torch.utils.data import DataLoader



In [2]:
SRC_LANGUAGE = 'de'
TGT_LANGUAGE = 'en'
src_vocab_size = len(vocab_transform['de'])
tgt_vocab_size = len(vocab_transform['en'])

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
epochs = 10
batch_size = 128
d_model = 512
n_head = 8
d_ff = 2048
n_layers = 3

In [6]:
model = Transformer(src_vocab_size, tgt_vocab_size, d_model, n_head, d_ff, n_layers).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [7]:
for epoch in range(epochs):
    train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    train_dataloader = DataLoader(train_iter, batch_size=batch_size, collate_fn=collate_fn)

    for i, (enc_inputs, dec_inputs) in enumerate(train_dataloader):
        enc_inputs, dec_inputs = enc_inputs.to(device), dec_inputs.to(device)
        outputs, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs)
        loss = criterion(outputs, dec_inputs.contiguous().view(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 50 == 0:
            print(f'Epoch: {epoch + 1}, loss = {loss.item()}')

Epoch: 1, loss = 2.7161667346954346
Epoch: 1, loss = 1.72246515750885
Epoch: 1, loss = 1.2473042011260986
Epoch: 1, loss = 1.312986969947815
Epoch: 2, loss = 0.7550764083862305
Epoch: 2, loss = 0.7086830139160156
Epoch: 2, loss = 0.5863784551620483
Epoch: 2, loss = 0.7209687829017639
Epoch: 3, loss = 0.4586000144481659
Epoch: 3, loss = 0.45253798365592957
Epoch: 3, loss = 0.36540523171424866
Epoch: 3, loss = 0.47708404064178467
Epoch: 4, loss = 0.3233679234981537
Epoch: 4, loss = 0.3235170543193817
Epoch: 4, loss = 0.2527376413345337
Epoch: 4, loss = 0.3431154191493988
Epoch: 5, loss = 0.24506182968616486
Epoch: 5, loss = 0.24522532522678375
Epoch: 5, loss = 0.18657676875591278
Epoch: 5, loss = 0.2624085545539856
Epoch: 6, loss = 0.19309698045253754
Epoch: 6, loss = 0.19185684621334076
Epoch: 6, loss = 0.13875356316566467
Epoch: 6, loss = 0.20751777291297913
Epoch: 7, loss = 0.15355652570724487
Epoch: 7, loss = 0.15084387362003326
Epoch: 7, loss = 0.10686091333627701
Epoch: 7, loss = 0

In [12]:
torch.save(model, './model.pth')

In [None]:
def greedy_decoder(model, enc_input, start_symbol):
    """
    For simplicity, a Greedy Decoder is Beam search when K=1. This is necessary for inference as we don't know the
    target sequence input. Therefore we try to generate the target input word by word, then feed it into the transformer.
    Starting Reference: http://nlp.seas.harvard.edu/2018/04/03/attention.html#greedy-decoding
    :param model: Transformer Model
    :param enc_input: The encoder input
    :param start_symbol: The start symbol. In this example it is 'S' which corresponds to index 4
    :return: The target input
    """
    enc_outputs, enc_self_attns = model.encoder(enc_input)
    dec_input = torch.zeros(1, 0).type_as(enc_input.data)
    terminal = False
    next_symbol = start_symbol
    while not terminal:
        # 预测阶段：dec_input序列会一点点变长（每次添加一个新预测出来的单词）
        dec_input = torch.cat([dec_input.to(device), torch.tensor([[next_symbol]], dtype=enc_input.dtype).to(device)],
                              -1)
        dec_outputs, _, _ = model.decoder(dec_input, enc_input, enc_outputs)
        projected = model.projection(dec_outputs)
        prob = projected.squeeze(0).max(dim=-1, keepdim=False)[1]
        # 增量更新（我们希望重复单词预测结果是一样的）
        # 我们在预测是会选择性忽略重复的预测的词，只摘取最新预测的单词拼接到输入序列中
        next_word = prob.data[-1]  # 拿出当前预测的单词(数字)。我们用x'_t对应的输出z_t去预测下一个单词的概率，不用z_1,z_2..z_{t-1}
        next_symbol = next_word
        if next_symbol == tgt_vocab["E"]:
            terminal = True
        # print(next_word)

    # greedy_dec_predict = torch.cat(
    #     [dec_input.to(device), torch.tensor([[next_symbol]], dtype=enc_input.dtype).to(device)],
    #     -1)
    greedy_dec_predict = dec_input[:, 1:]
    return greedy_dec_predict


# ==========================================================================================
# 预测阶段
enc_inputs, _, _ = next(iter(loader))
for i in range(len(enc_inputs)):
    greedy_dec_predict = greedy_decoder(model, enc_inputs[i].view(1, -1).to(device), start_symbol=tgt_vocab["S"])
    print(enc_inputs[i], '->', greedy_dec_predict.squeeze())
    print([src_idx2word[t.item()] for t in enc_inputs[i]], '->',
          [idx2word[n.item()] for n in greedy_dec_predict.squeeze()])