In [8]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader

from utils import pad_mask, subsequence_mask
from utils import PositionalEncoding, EncoderLayer, DecoderLayer
from utils import TransformerDataset

In [9]:
class CustomData:
    def __init__(self):
        self.src_vocab = {'P' : 0, 'ich' : 1, 'mochte' : 2, 'ein' : 3, 'bier' : 4, 'cola' : 5}
        self.src_vocab_size = len(self.src_vocab)
        self.src_idx2word = {i: w for i, w in enumerate(self.src_vocab)}

        self.tgt_vocab = {'P' : 0, 'i' : 1, 'want' : 2, 'a' : 3, 'beer' : 4, 'coke' : 5, 'S' : 6, 'E' : 7, '.' : 8}
        self.tgt_idx2word = {i: w for i, w in enumerate(self.tgt_vocab)}
        self.tgt_vocab_size = len(self.tgt_vocab)
        
    def make_data(self, sentences):
        enc_inputs, dec_inputs, dec_outputs = [], [], []
        for i in range(len(sentences)):
            enc_input = [[self.src_vocab[n] for n in sentences[i][0].split()]]
            dec_input = [[self.tgt_vocab[n] for n in sentences[i][1].split()]]        
            dec_output = [[self.tgt_vocab[n] for n in sentences[i][2].split()]]

            enc_inputs.extend(enc_input)
            dec_inputs.extend(dec_input)
            dec_outputs.extend(dec_output)

        return torch.LongTensor(enc_inputs), torch.LongTensor(dec_inputs), torch.LongTensor(dec_outputs)
    
    def get_test_data(self, sentences):
        return self.make_data(sentences)

### pytorch.nn.Transformer mask参数解释

**\*_mask: attn_mask(T, T), 为了在decoder中屏蔽未来的词**  
**\*_key_padding_mask: pad_mask(B, S/T), 避免PAD填充项参与运算**   
**另外，在最后计算loss的时候，也要指定ignoreindex=pad_idx**

### pytorch.nn.Transformer 是没有实现Positional Encoding的，需要自己实现 

In [27]:
# 等同于 nn.Transformer.generate_square_subsequent_mask
def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz))) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

In [25]:
def create_mask(X_encoder, X_decoder):
    encoder_sequence_len = X_encoder.shape[1]
    decoder_sequence_len = X_decoder.shape[1]
    
    decoder_sub_sequence_mask = generate_square_subsequent_mask(decoder_sequence_len)
    encoder_sub_sequence_mask = torch.zeros((encoder_sequence_len, encoder_sequence_len)).type(torch.bool)
    
    encoder_padding_mask = (X_encoder == PAD_IDX)
    decoder_padding_mask = (X_decoder == PAD_IDX)
    return encoder_sub_sequence_mask, decoder_sub_sequence_mask, encoder_padding_mask, decoder_padding_mask

In [36]:
class Seq2SeqTransformer(nn.Module):
    def __init__(self, encoder_vocab_size, decoder_vocab_size, 
                 embedding_size, n_heads, hidden_size, num_encoder_layers, num_decoder_layers, dropout=0.1,
                 batch_first=True
                ):
        super(Seq2SeqTransformer, self).__init__()
        self.transformer = nn.Transformer(
            d_model=embedding_size,
            nhead=n_heads,
            num_encoder_layers= num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=hidden_size,
            dropout=dropout,
            batch_first=batch_first
        )
        self.linear = nn.Linear(embedding_size, decoder_vocab_size)
        self.encoder_embedding = nn.Embedding(encoder_vocab_size, embedding_size)
        self.decoder_embedding = nn.Embedding(decoder_vocab_size, embedding_size)
        self.positional_encoding = PositionalEncoding(embedding_size, dropout=dropout)
        
    def forward(self, X_encoder, X_decoder, 
                encoder_mask, decoder_mask, 
                encoder_key_padding_mask, 
                decoder_key_padding_mask,
                memory_key_padding_mask
               ):
        # embedding and positional encoding
        X_encoder = self.positional_encoding(
            self.encoder_embedding(X_encoder)
        )
        X_decoder = self.positional_encoding(
            self.decoder_embedding(X_decoder)
        )
        # transformer forward
        Y = self.transformer(
            X_encoder, X_decoder,
            encoder_mask, decoder_mask,
            None, encoder_key_padding_mask, decoder_key_padding_mask,
            memory_key_padding_mask
        )
        # Y: [batch_size, length_decoder, vocab_decoder_size]
        Y = self.linear(Y)
        return Y
    
    def encoder(self, X, mask):
        X = self.positional_encoding(
            self.encoder_embedding(X)
        )
        Y = self.transformer.encoder(X, mask)
        return Y
    
    def decoder(self, X, memory, mask):
        X = self.positional_encoding(
            self.decoder_embedding(X)
        )
        Y = self.transformer.decoder(X, memory, mask)
        return Y

### 生成测试数据

In [10]:
sentences = [
        # enc_input           dec_input         dec_output
        ['ich mochte ein bier P', 'S i want a beer .', 'i want a beer . E'],
        ['ich mochte ein cola P', 'S i want a coke .', 'i want a coke . E']
]
custom_data = CustomData()
enc_inputs, dec_inputs, dec_outputs = custom_data.get_test_data(sentences)
dataset = TransformerDataset(enc_inputs, dec_inputs, dec_outputs)
data_loader = DataLoader(dataset, batch_size=2, shuffle=True)

In [11]:
for batch in data_loader:
    enc_inputs, dec_inputs, dec_outputs = [x for x in batch]
    break

In [12]:
enc_inputs

tensor([[1, 2, 3, 4, 0],
        [1, 2, 3, 5, 0]])

### 超参数

In [94]:
batch_size = 2
lr = 1e-4
weight_decay =1e-5
epochs = 100
num_encoder_layers= num_decoder_layers = 6
n_heads = 8
embedding_size = 512
hidden_size = 300
dropout = 0.1
batch_first=True
encoder_vocab_size = custom_data.src_vocab_size
decoder_vocab_size = custom_data.tgt_vocab_size

PAD_IDX = 0
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = 'cpu'

### 初始化模型

In [95]:
model = Seq2SeqTransformer(
    encoder_vocab_size, decoder_vocab_size,
    embedding_size, n_heads, hidden_size, num_encoder_layers, num_decoder_layers, dropout, True
)
loss = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    model.parameters(),
    lr=lr, 
    weight_decay=weight_decay
)

### 模型训练

In [96]:
for epoch in range(epochs):
    for batch in data_loader:
        enc_inputs, dec_inputs, dec_outputs = [x.to(device) for x in batch]
        # mask
        encoder_sub_sequence_mask, decoder_sub_sequence_mask, encoder_padding_mask, decoder_padding_mask = create_mask(
            enc_inputs, dec_inputs
        )
        
        # outputs: [batch_size, length, vocab_size]
        outputs = model(
            enc_inputs, dec_inputs, 
            encoder_sub_sequence_mask, decoder_sub_sequence_mask, encoder_padding_mask, decoder_padding_mask,
            encoder_padding_mask
        )
        l = loss(outputs.transpose(2, 1), dec_outputs)
        print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(l))
        optimizer.zero_grad()
        l.backward()
        optimizer.step()

Epoch: 0001 loss = 2.342865
Epoch: 0002 loss = 1.988479
Epoch: 0003 loss = 1.983640
Epoch: 0004 loss = 1.990919
Epoch: 0005 loss = 1.912256
Epoch: 0006 loss = 1.940863
Epoch: 0007 loss = 1.961138
Epoch: 0008 loss = 1.937107
Epoch: 0009 loss = 1.792636
Epoch: 0010 loss = 1.867918
Epoch: 0011 loss = 1.973863
Epoch: 0012 loss = 1.801910
Epoch: 0013 loss = 1.827447
Epoch: 0014 loss = 1.761702
Epoch: 0015 loss = 1.789762
Epoch: 0016 loss = 1.654130
Epoch: 0017 loss = 1.629984
Epoch: 0018 loss = 1.649120
Epoch: 0019 loss = 1.677012
Epoch: 0020 loss = 1.472674
Epoch: 0021 loss = 1.595101
Epoch: 0022 loss = 1.534014
Epoch: 0023 loss = 1.495038
Epoch: 0024 loss = 1.390178
Epoch: 0025 loss = 1.299923
Epoch: 0026 loss = 1.251586
Epoch: 0027 loss = 1.121074
Epoch: 0028 loss = 1.097341
Epoch: 0029 loss = 1.165710
Epoch: 0030 loss = 1.097595
Epoch: 0031 loss = 0.904429
Epoch: 0032 loss = 0.982850
Epoch: 0033 loss = 0.976796
Epoch: 0034 loss = 0.692688
Epoch: 0035 loss = 0.682223
Epoch: 0036 loss = 0

### 预测

In [97]:
class BeamSearch:
    def __init__(self, model, k=2, start_symbol=None, stop_symbol=None, max_predict_length=1000):
        self.model = model
        self.k, self.max_predict_length = k, max_predict_length
        self.start_symbol, self.stop_symbol= start_symbol, stop_symbol
        
    def greedy_decoder(self, X):
        X_encoder = X.view(1, -1)
        num_tokens = X_encoder.shape[1]
        X_encoder_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
        Y_encoder = model.encoder(X_encoder, X_encoder_mask)
        
        next_word = self.start_symbol
        dec_input = torch.zeros(1, 0).type_as(X_encoder.data)
        while True:
            # concate next word
            dec_input = torch.cat([dec_input.detach(),torch.tensor([[next_word]],dtype=X_encoder.dtype)],-1)
            # run decoder and linear
            tgt_mask = (generate_square_subsequent_mask(dec_input.size(1))
                    .type(torch.bool))
            Y_decoder = model.decoder(dec_input, Y_encoder, tgt_mask)
            word_probability = model.linear(Y_decoder)
            word_probability = word_probability.squeeze(0).max(dim=-1, keepdim=False)[1]
            next_word = word_probability[-1]
            
            if next_word in self.stop_symbol or dec_input.size(1)>=self.max_predict_length:
                dec_input = torch.cat([dec_input.detach(),torch.tensor([[next_word]],dtype=X_encoder.dtype)],-1)
                return dec_input.squeeze(0)
    
    def search(self, X):
        K = self.k
        X_encoder = X.view(1, -1)
        Y_encoder = model.encoder(X_encoder)
        
        next_word = self.start_symbol
        dec_input = torch.zeros(1, 0).type_as(enc_input.data)
        dec_input = torch.cat([dec_input.detach(),torch.tensor([[next_word]],dtype=enc_input.dtype)],-1)
        sequences = [(dec_input, 1)]
        for _ in range(self.max_predict_length):
            # concate next word
            counter = 0
            all_condidates = []
            for (sequence, prob) in sequences:
                dec_input = sequence
                # run decoder and linear
                Y_decoder = model.decoder(dec_input, Y_encoder, X_encoder)
                word_probability = model.linear(Y_decoder)
                word_probability = nn.Softmax(dim=-1)(word_probability.squeeze(0))
                word_probability = word_probability[-1].squeeze(0)
                for i in range(word_probability.size(-1)):
                    candidate = (
                        torch.cat([dec_input.detach(),torch.tensor([[i]],dtype=enc_input.dtype)],-1),
                        prob * word_probability[i].item()
                    )
                    all_condidates.append(candidate)
            ordered = sorted(all_condidates, key=lambda x:x[1], reverse=True)
            sequences = ordered[:K]
        return sequences

In [98]:
start_symbol= custom_data.tgt_vocab['S']
stop_symbol = [custom_data.tgt_vocab['.'], custom_data.tgt_vocab['E']]
stop_symbol = [custom_data.tgt_vocab['E']]
max_predict_length = 6
enc_inputs, _, _ = next(iter(data_loader))

In [99]:
enc_inputs, len(enc_inputs), start_symbol

(tensor([[1, 2, 3, 4, 0],
         [1, 2, 3, 5, 0]]),
 2,
 6)

In [100]:
beam_search = BeamSearch(model, k=5, start_symbol=start_symbol, stop_symbol=stop_symbol, 
                         max_predict_length=max_predict_length)

In [101]:
for i in range(len(enc_inputs)):
    predict = beam_search.greedy_decoder(enc_inputs[i])
    print(predict)
    print([custom_data.src_idx2word[x.item()] for x in enc_inputs[i]])
    print([custom_data.tgt_idx2word[x.item()] for x in predict])

tensor([6, 1, 2, 3, 4, 8, 7])
['ich', 'mochte', 'ein', 'bier', 'P']
['S', 'i', 'want', 'a', 'beer', '.', 'E']
tensor([6, 1, 2, 3, 5, 8, 7])
['ich', 'mochte', 'ein', 'cola', 'P']
['S', 'i', 'want', 'a', 'coke', '.', 'E']


In [85]:
for i in range(len(enc_inputs)):
    predict_sequences = beam_search.search(enc_inputs[i])
    print(predict_sequences)
    print([custom_data.src_idx2word[x.item()] for x in enc_inputs[i]])
    print([[custom_data.tgt_idx2word[x.item()] for x in seq.squeeze(0)] for (seq, _) in predict_sequences])

TypeError: encoder() missing 1 required positional argument: 'mask'

In [None]:
predict_sequences

In [None]:
l