In [1]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader

from utils import pad_mask, subsequence_mask
from utils import PositionalEncoding, EncoderLayer, DecoderLayer
from utils import TransformerDataset

In [70]:
class CustomData:
    def __init__(self):
        self.src_vocab = {'P' : 0, 'ich' : 1, 'mochte' : 2, 'ein' : 3, 'bier' : 4, 'cola' : 5}
        self.src_vocab_size = len(self.src_vocab)
        self.src_idx2word = {i: w for i, w in enumerate(self.src_vocab)}

        self.tgt_vocab = {'P' : 0, 'i' : 1, 'want' : 2, 'a' : 3, 'beer' : 4, 'coke' : 5, 'S' : 6, 'E' : 7, '.' : 8}
        self.tgt_idx2word = {i: w for i, w in enumerate(self.tgt_vocab)}
        self.tgt_vocab_size = len(self.tgt_vocab)
        
    def make_data(self, sentences):
        enc_inputs, dec_inputs, dec_outputs = [], [], []
        for i in range(len(sentences)):
            enc_input = [[self.src_vocab[n] for n in sentences[i][0].split()]]
            dec_input = [[self.tgt_vocab[n] for n in sentences[i][1].split()]]        
            dec_output = [[self.tgt_vocab[n] for n in sentences[i][2].split()]]

            enc_inputs.extend(enc_input)
            dec_inputs.extend(dec_input)
            dec_outputs.extend(dec_output)

        return torch.LongTensor(enc_inputs), torch.LongTensor(dec_inputs), torch.LongTensor(dec_outputs)
    
    def get_test_data(self, sentences):
        return self.make_data(sentences)

In [3]:
class Encoder(nn.Module):
    def __init__(self, vocab_size, embedding_size, k_dim, v_dim, n_heads, hidden_size, num_layers):
        super(Encoder, self).__init__()
        self.embeddding = nn.Embedding(vocab_size, embedding_size)
        self.positonal_encoder = PositionalEncoding(embedding_size)
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            layer = EncoderLayer(embedding_size, k_dim, v_dim, n_heads, hidden_size)
            self.layers.append(layer)
            
    def forward(self, X):
        mask = pad_mask(X)
        X = self.embeddding(X)
        X = self.positonal_encoder(X)
        for layer in self.layers:
            X = layer.forward(X, mask)
        return X

In [4]:
class Decoder(nn.Module):
    def __init__(self, vocab_size, embedding_size, k_dim, v_dim, n_heads, hidden_size, num_layers):
        super(Decoder, self).__init__()
        self.embeddding = nn.Embedding(vocab_size, embedding_size)
        self.positonal_encoder = PositionalEncoding(embedding_size)
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            layer = DecoderLayer(embedding_size, k_dim, v_dim, n_heads, hidden_size)
            self.layers.append(layer)
            
    def forward(self, X, X_encoder, encoder_inputs):
        self_attention_mask = pad_mask(X)
        subseq_mask = subsequence_mask(X)
        # torch.gt(matrix, target), if element in matrix, if element > target, return 1, else 0
        self_attention_mask = torch.gt((self_attention_mask + subseq_mask), 0).to(X.device)
        decoder_encoder_attention_mask = pad_mask(encoder_inputs, X)
        
        X = self.embeddding(X)
        X = self.positonal_encoder(X)
        for layer in self.layers:
            X = layer.forward(X, X_encoder, self_attention_mask, decoder_encoder_attention_mask)
        return X

In [5]:
class Transformer(nn.Module):
    def __init__(self, encoder_vocab_size, decoder_vocab_size, 
                 embedding_size, k_dim, v_dim, n_heads, hidden_size, num_layers):
        super(Transformer, self).__init__()
        self.encoder = Encoder(encoder_vocab_size, embedding_size, k_dim, v_dim, n_heads, hidden_size, num_layers)
        self.decoder = Decoder(decoder_vocab_size, embedding_size, k_dim, v_dim, n_heads, hidden_size, num_layers)
        self.linear = nn.Linear(embedding_size, decoder_vocab_size, bias=False)
        
    def forward(self, X_encoder, X_decoder):
        Y_encoder = self.encoder(X_encoder)
        Y_decoder = self.decoder(X_decoder, Y_encoder, X_encoder)
        # Y: [batch_size, length_decoder, vocab_decoder_size]
        Y = self.linear(Y_decoder)
        return Y

### 生成测试数据

In [72]:
sentences = [
        # enc_input           dec_input         dec_output
        ['ich mochte ein bier P', 'S i want a beer .', 'i want a beer . E'],
        ['ich mochte ein cola P', 'S i want a coke .', 'i want a coke . E']
]
custom_data = CustomData()
enc_inputs, dec_inputs, dec_outputs = custom_data.get_test_data(sentences)
dataset = TransformerDataset(enc_inputs, dec_inputs, dec_outputs)
data_loader = DataLoader(dataset, batch_size=2, shuffle=True)

In [73]:
for batch in data_loader:
    enc_inputs, dec_inputs, dec_outputs = [x for x in batch]
    break

In [74]:
enc_inputs

tensor([[1, 2, 3, 5, 0],
        [1, 2, 3, 4, 0]])

### 超参数

In [9]:
batch_size = 2
lr = 1e-4
weight_decay =1e-5
epochs = 5
num_layers = 6
n_heads = 8
k_dim = 64
v_dim = 64
embedding_size = 512
hidden_size = 300
encoder_vocab_size = custom_data.src_vocab_size
decoder_vocab_size = custom_data.tgt_vocab_size

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = 'cpu'

### 初始化模型

In [10]:
model = Transformer(
    encoder_vocab_size, decoder_vocab_size,
    embedding_size, k_dim, v_dim, n_heads, hidden_size, num_layers
)
loss = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    model.parameters(),
    lr=lr, 
    weight_decay=weight_decay
)

### 模型训练

In [19]:
for epoch in range(epochs):
    for batch in data_loader:
        enc_inputs, dec_inputs, dec_outputs = [x.to(device) for x in batch]
        # outputs: [batch_size, length, vocab_size]
        outputs = model(enc_inputs, dec_inputs)
        l = loss(outputs.transpose(2, 1), dec_outputs)
        print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(l))
        optimizer.zero_grad()
        l.backward()
        optimizer.step()

Epoch: 0001 loss = 0.468823
Epoch: 0002 loss = 0.229420
Epoch: 0003 loss = 0.162966
Epoch: 0004 loss = 0.085150
Epoch: 0005 loss = 0.063321


### 预测

In [324]:
class BeamSearch:
    def __init__(self, model, k=2, start_symbol=None, stop_symbol=None, max_predict_length=1000):
        self.model = model
        self.k, self.max_predict_length = k, max_predict_length
        self.start_symbol, self.stop_symbol= start_symbol, stop_symbol
        
    def greedy_decoder(self, X):
        X_encoder = X.view(1, -1)
        Y_encoder = model.encoder(X_encoder)
        
        next_word = self.start_symbol
        dec_input = torch.zeros(1, 0).type_as(enc_input.data)
        while True:
            # concate next word
            dec_input = torch.cat([dec_input.detach(),torch.tensor([[next_word]],dtype=enc_input.dtype)],-1)
            # run decoder and linear
            Y_decoder = model.decoder(dec_input, Y_encoder, X_encoder)
            word_probability = model.linear(Y_decoder)
            word_probability = word_probability.squeeze(0).max(dim=-1, keepdim=False)[1]
            next_word = word_probability[-1]
            
            if next_word in self.stop_symbol or dec_input.size(1)>=self.max_predict_length:
                dec_input = torch.cat([dec_input.detach(),torch.tensor([[next_word]],dtype=enc_input.dtype)],-1)
                return dec_input.squeeze(0)
    
    def search(self, X):
        K = self.k
        X_encoder = X.view(1, -1)
        Y_encoder = model.encoder(X_encoder)
        
        next_word = self.start_symbol
        dec_input = torch.zeros(1, 0).type_as(enc_input.data)
        dec_input = torch.cat([dec_input.detach(),torch.tensor([[next_word]],dtype=enc_input.dtype)],-1)
        sequences = [(dec_input, 1)]
        for _ in range(self.max_predict_length):
            # concate next word
            counter = 0
            all_condidates = []
            for (sequence, prob) in sequences:
                dec_input = sequence
                # run decoder and linear
                Y_decoder = model.decoder(dec_input, Y_encoder, X_encoder)
                word_probability = model.linear(Y_decoder)
                word_probability = nn.Softmax(dim=-1)(word_probability.squeeze(0))
                word_probability = word_probability[-1].squeeze(0)
                for i in range(word_probability.size(-1)):
                    candidate = (
                        torch.cat([dec_input.detach(),torch.tensor([[i]],dtype=enc_input.dtype)],-1),
                        prob * word_probability[i].item()
                    )
                    all_condidates.append(candidate)
            ordered = sorted(all_condidates, key=lambda x:x[1], reverse=True)
            sequences = ordered[:K]
        return sequences

In [325]:
start_symbol= custom_data.tgt_vocab['S']
stop_symbol = [custom_data.tgt_vocab['.'], custom_data.tgt_vocab['E']]
stop_symbol = [custom_data.tgt_vocab['E']]
max_predict_length = 6
enc_inputs, _, _ = next(iter(data_loader))

In [326]:
enc_inputs, len(enc_inputs), start_symbol

(tensor([[1, 2, 3, 4, 0],
         [1, 2, 3, 5, 0]]),
 2,
 6)

In [332]:
beam_search = BeamSearch(model, k=5, start_symbol=start_symbol, stop_symbol=stop_symbol, 
                         max_predict_length=max_predict_length)

In [333]:
for i in range(len(enc_inputs)):
    predict = beam_search.greedy_decoder(enc_inputs[i])
    print(predict)
    print([custom_data.src_idx2word[x.item()] for x in enc_inputs[i]])
    print([custom_data.tgt_idx2word[x.item()] for x in predict])

tensor([6, 1, 2, 3, 4, 8, 7])
['ich', 'mochte', 'ein', 'bier', 'P']
['S', 'i', 'want', 'a', 'beer', '.', 'E']
tensor([6, 1, 2, 3, 5, 8, 7])
['ich', 'mochte', 'ein', 'cola', 'P']
['S', 'i', 'want', 'a', 'coke', '.', 'E']


In [334]:
for i in range(len(enc_inputs)):
    predict_sequences = beam_search.search(enc_inputs[i])
    print(predict_sequences)
    print([custom_data.src_idx2word[x.item()] for x in enc_inputs[i]])
    print([[custom_data.tgt_idx2word[x.item()] for x in seq.squeeze(0)] for (seq, _) in predict_sequences])

[(tensor([[6, 1, 2, 3, 4, 8, 7]]), 0.7471778901401462), (tensor([[6, 1, 2, 3, 4, 4, 8]]), 0.023645795037044923), (tensor([[6, 1, 2, 2, 3, 4, 8]]), 0.020296269745991666), (tensor([[6, 1, 2, 3, 4, 8, 4]]), 0.014324764169617329), (tensor([[6, 1, 2, 4, 8, 7, 7]]), 0.011409323775865048)]
['ich', 'mochte', 'ein', 'bier', 'P']
[['S', 'i', 'want', 'a', 'beer', '.', 'E'], ['S', 'i', 'want', 'a', 'beer', 'beer', '.'], ['S', 'i', 'want', 'want', 'a', 'beer', '.'], ['S', 'i', 'want', 'a', 'beer', '.', 'beer'], ['S', 'i', 'want', 'beer', '.', 'E', 'E']]
[(tensor([[6, 1, 2, 3, 5, 8, 7]]), 0.729272290460904), (tensor([[6, 1, 2, 3, 4, 8, 7]]), 0.021989770150991506), (tensor([[6, 1, 2, 2, 3, 5, 8]]), 0.020796378916303152), (tensor([[6, 1, 2, 3, 5, 5, 8]]), 0.017689265194172424), (tensor([[6, 1, 2, 5, 8, 7, 7]]), 0.017455388161037484)]
['ich', 'mochte', 'ein', 'cola', 'P']
[['S', 'i', 'want', 'a', 'coke', '.', 'E'], ['S', 'i', 'want', 'a', 'beer', '.', 'E'], ['S', 'i', 'want', 'want', 'a', 'coke', '.'],

In [335]:
predict_sequences

[(tensor([[6, 1, 2, 3, 5, 8, 7]]), 0.729272290460904),
 (tensor([[6, 1, 2, 3, 4, 8, 7]]), 0.021989770150991506),
 (tensor([[6, 1, 2, 2, 3, 5, 8]]), 0.020796378916303152),
 (tensor([[6, 1, 2, 3, 5, 5, 8]]), 0.017689265194172424),
 (tensor([[6, 1, 2, 5, 8, 7, 7]]), 0.017455388161037484)]

In [336]:
l

tensor(0.0633, grad_fn=<NllLoss2DBackward>)