Coded by Lujia Zhong @lujiazho<br>
Reference: https://github.com/graykode/nlp-tutorial, https://github.com/Kyubyong/transformer

In [1]:
import numpy as np
import torch
import torch.nn as nn

# positional encoding
def sinusoid_pos_encoding(n_position, d_model):
    position_enc = np.array([
        [pos / np.power(10000, (i-i%2)/d_model) for i in range(d_model)] 
        for pos in range(n_position)])
    position_enc[:, 0::2] = np.sin(position_enc[:, 0::2])  # dim 2i
    position_enc[:, 1::2] = np.cos(position_enc[:, 1::2])  # dim 2i+1
    return torch.FloatTensor(position_enc)

# for encoder & decoder
def padding_mask(seq_q, seq_k):
    B, len_q = seq_q.shape
    B, len_k = seq_k.shape
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)  # 0 is padding
    return pad_attn_mask.expand(B, len_q, len_k)

# for decoder
def subsequent_mask(seq):
    B, len_ = seq.shape
    mask = np.triu(np.ones([B, len_, len_]), k=1)
    return torch.from_numpy(mask).type(torch.uint8)


class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super(MultiHeadAttention, self).__init__()
        
        self.d_key = config.d_key
        self.d_value = config.d_value
        self.n_heads = config.n_heads
        
        self.W_Q = nn.Linear(config.d_model, self.d_key * self.n_heads)
        self.W_K = nn.Linear(config.d_model, self.d_key * self.n_heads)
        self.W_V = nn.Linear(config.d_model, self.d_value * self.n_heads)
        
        self.linear = nn.Linear(self.n_heads * self.d_value, config.d_model)
        self.layer_norm = nn.LayerNorm(config.d_model)
        
        self.attn_dropout = nn.Dropout(config.dropout_rate)
        self.proj_dropout = nn.Dropout(config.dropout_rate)
        
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, Q, K, V, attn_mask):
        # Self-Attn: Q,K,V are torch.Size([2, len_of_sentence, 512])
        # Dec-Enc-Attn: Q is torch.Size([2, len_of_dec_sent, 512]) K,V are torch.Size([2, len_of_enc_sent, 512])

        residual, batch_size = Q, Q.shape[0]
        
        query_layer = self.W_Q(Q).view(batch_size, -1, self.n_heads, self.d_key).transpose(1,2)
        key_layer = self.W_K(K).view(batch_size, -1, self.n_heads, self.d_key).transpose(1,2)
        value_layer = self.W_V(V).view(batch_size, -1, self.n_heads, self.d_value).transpose(1,2)

        # expand in heads' dimension
        attn_mask = attn_mask.unsqueeze(1).repeat(1, self.n_heads, 1, 1)
        
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) / np.sqrt(self.d_key)
        attention_scores.masked_fill_(attn_mask, -1e9) # masked_fill_: 1 masked, 0 unmasked
        attention_probs = self.softmax(attention_scores)
        attention_probs = self.attn_dropout(attention_probs)
        
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.transpose(1, 2)
        context_layer = context_layer.contiguous().view(batch_size, -1, self.n_heads*self.d_value)
        
        attention_output = self.linear(context_layer)
        attention_output = self.proj_dropout(attention_output)
        
        return self.layer_norm(attention_output + residual), attention_probs

class MLP(nn.Module):
    def __init__(self, config):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(config.d_model, config.d_mlp)
        self.fc2 = nn.Linear(config.d_mlp, config.d_model)
        self.layer_norm = nn.LayerNorm(config.d_model)
        self.dropout = nn.Dropout(config.dropout_rate)
        
        self._init_weights()
    
    def _init_weights(self):
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.normal_(self.fc1.bias, std=1e-6)
        nn.init.normal_(self.fc2.bias, std=1e-6)

    def forward(self, inputs):
        # torch.Size([2, len_of_sentence, 512])
        residual = inputs
        
        output = nn.ReLU()(self.fc1(inputs))
        # torch.Size([2, len_of_sentence, 2048])
        output = self.dropout(output)
        
        output = self.fc2(output)
        # torch.Size([2, len_of_sentence, 512])
        output = self.dropout(output)
        
        return self.layer_norm(output + residual)

class EncoderLayer(nn.Module):
    def __init__(self, config):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention(config)
        self.ffn = MLP(config)

    def forward(self, enc_inputs, enc_self_attn_mask):
        # torch.Size([2, 4, 512]) torch.Size([2, 4, 4])

        enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)
        # torch.Size([2, 4, 512]) torch.Size([2, 8, 4, 4])
        
        enc_outputs = self.ffn(enc_outputs)
        # torch.Size([2, 4, 512])
        
        return enc_outputs, attn

class DecoderLayer(nn.Module):
    def __init__(self, config):
        super(DecoderLayer, self).__init__()
        self.dec_self_attn = MultiHeadAttention(config)
        self.dec_enc_attn = MultiHeadAttention(config)
        self.ffn = MLP(config)

    def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):
        # torch.Size([2, 6, 512]) torch.Size([2, 4, 512]) torch.Size([2, 6, 6]) torch.Size([2, 6, 4])

        dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)
        # torch.Size([2, 6, 512]) torch.Size([2, 8, 6, 6])
        
        dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask)
        # torch.Size([2, 6, 512]) torch.Size([2, 8, 6, 4])
        
        dec_outputs = self.ffn(dec_outputs)
        # torch.Size([2, 6, 512])
        
        return dec_outputs, dec_self_attn, dec_enc_attn

class Encoder(nn.Module):
    def __init__(self, config):
        super(Encoder, self).__init__()
        self.src_emb = nn.Embedding(src_vocab_size, config.d_model)
        self.pos_emb = nn.Embedding.from_pretrained(sinusoid_pos_encoding(config.src_len+1, config.d_model),
                                                    freeze=True)
        self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.n_layers)])

    def forward(self, enc_inputs):
        # torch.Size([2, 4])

        enc_outputs = self.src_emb(enc_inputs) + self.pos_emb(enc_inputs)
        # torch.Size([2, 4, 512])
        
        enc_self_attn_mask = padding_mask(enc_inputs, enc_inputs)
        # torch.Size([2, 4, 4])

        enc_self_attns = []
        for layer in self.layers:
            enc_outputs, enc_self_attn = layer(enc_outputs, enc_self_attn_mask)
            # enc_outputs: torch.Size([2, 4, 512])
            enc_self_attns.append(enc_self_attn)
        
        return enc_outputs, enc_self_attns

class Decoder(nn.Module):
    def __init__(self, config):
        super(Decoder, self).__init__()
        self.tgt_emb = nn.Embedding(tgt_vocab_size, config.d_model)
        self.pos_emb = nn.Embedding.from_pretrained(sinusoid_pos_encoding(config.tgt_len+1, config.d_model),
                                                    freeze=True)
        self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.n_layers)])

    def forward(self, dec_inputs, enc_inputs, enc_outputs):
        # torch.Size([2, 6]) torch.Size([2, 4]) torch.Size([2, 4, 512])

        dec_outputs = self.tgt_emb(dec_inputs) + self.pos_emb(dec_inputs)
        # torch.Size([2, 6, 512])
        
        dec_self_attn_pad_mask = padding_mask(dec_inputs, dec_inputs)
        # torch.Size([2, 6, 6])
        dec_self_attn_subsequent_mask = subsequent_mask(dec_inputs)
        # torch.Size([2, 6, 6])
        
        dec_self_attn_mask = torch.logical_or(dec_self_attn_pad_mask, dec_self_attn_subsequent_mask)
        # torch.Size([2, 6, 6])

        dec_enc_attn_mask = padding_mask(dec_inputs, enc_inputs)
        # torch.Size([2, 6, 4])

        dec_self_attns, dec_enc_attns = [], []
        for layer in self.layers:
            dec_outputs, dec_self_attn, dec_enc_attn = layer(dec_outputs, enc_outputs, 
                                                             dec_self_attn_mask, dec_enc_attn_mask)
            # torch.Size([2, 6, 512]) torch.Size([2, 8, 6, 6]) torch.Size([2, 8, 6, 4])
            dec_self_attns.append(dec_self_attn)
            dec_enc_attns.append(dec_enc_attn)
        
        return dec_outputs, dec_self_attns, dec_enc_attns

class Transformer(nn.Module):
    def __init__(self, config):
        super(Transformer, self).__init__()
        self.encoder = Encoder(config)
        self.decoder = Decoder(config)
        self.projection = nn.Linear(config.d_model, tgt_vocab_size, bias=False)
    
    def forward(self, enc_inputs, dec_inputs):
        # torch.Size([2, 4]) torch.Size([2, 6])

        enc_outputs, enc_self_attns = self.encoder(enc_inputs)
        # torch.Size([2, 4, 512]) [6, torch.Size([2, 8, 4, 4])]

        dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(dec_inputs, enc_inputs, enc_outputs)
        # torch.Size([2, 6, 512]) [6, torch.Size([2, 8, 6, 6])] [6, torch.Size([1, 8, 6, 4])]

        logits = self.projection(dec_outputs)
        # torch.Size([2, 6, 8])
        
        return logits.view(-1, logits.shape[-1]), enc_self_attns, dec_self_attns, dec_enc_attns

class ModelConfig:
    src_len = 4             # length of encoder inputs
    tgt_len = 6             # length of decoder inputs/outputs

    d_model = 512           # embedding Size
    d_mlp = 4*d_model       # MLP hidden dimension
    d_key = d_value = 64    # dimension of K == Q, V could be different in dot_product_attention
    n_layers = 6            # number of Encoder & Decoder Layer
    n_heads = 8             # number of heads in Multi-Head Attention
    
    dropout_rate = 0.1

**Data**

In [2]:
sentences = [['quiero una cerveza P', 'quiero una buena cerveza'], 
             ['S i want a beer P', 'S i want a good beer'], 
             ['i want a beer P E', 'i want a good beer E']]

# S: starting of decoder input
# E: ending of decoder output
# P: padding will fill in blank sequence for shorter sentence
src_vocab = {'P': 0, 'quiero': 1, 'una': 2, 'cerveza': 3, 'buena': 4}
idx2src = {i: w for i, w in enumerate(src_vocab)}
src_vocab_size = len(src_vocab)

tgt_vocab = {'P': 0, 'i': 1, 'want': 2, 'a': 3, 'good': 4, 'beer': 5, 'S': 6, 'E': 7}
idx2tgt = {i: w for i, w in enumerate(tgt_vocab)}
tgt_vocab_size = len(tgt_vocab)

batch = 2

enc_inputs = torch.LongTensor([[src_vocab[n] for n in sentences[0][i].split()] for i in range(batch)])
dec_inputs = torch.LongTensor([[tgt_vocab[n] for n in sentences[1][i].split()] for i in range(batch)])
target = torch.LongTensor([[tgt_vocab[n] for n in sentences[2][i].split()] for i in range(batch)])

In [3]:
model = Transformer(ModelConfig())

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(200):
    optimizer.zero_grad()
    outputs, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs)
    loss = criterion(outputs, target.contiguous().view(-1))
    if (epoch+1) % 20 == 0:
        print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))
    loss.backward()
    optimizer.step()

Epoch: 0020 cost = 1.810086
Epoch: 0040 cost = 1.901784
Epoch: 0060 cost = 1.956345
Epoch: 0080 cost = 1.947199
Epoch: 0100 cost = 1.923884
Epoch: 0120 cost = 1.912930
Epoch: 0140 cost = 1.919369
Epoch: 0160 cost = 1.898509
Epoch: 0180 cost = 1.958702
Epoch: 0200 cost = 1.897602


In [4]:
# testing
predict, _, _, _ = model(enc_inputs, dec_inputs)
predict = predict.data.max(1, keepdim=True)[1].squeeze().view(batch,-1)
for i in range(batch):
    print(sentences[0][i], '->', ' '.join([idx2tgt[n.item()] for n in predict[i]]))

quiero una cerveza P -> i want i a want i
quiero una buena cerveza -> beer want a a beer i


In [5]:
# predicting
maximum = 10
idx = 0
src = enc_inputs[0].unsqueeze(0)
pred = torch.tensor([[6]])
while (idx < maximum):
    predict, _, _, _ = model(src, pred)
    nxt = torch.softmax(predict, -1).argmax(-1)[-1]
    pred = torch.cat((pred, nxt.unsqueeze(0).unsqueeze(0)), -1)
    if (nxt == tgt_vocab['E']):
        break
    idx += 1
    
print(' '.join([idx2src[n.item()] for n in src[0]]), 
      '->', 
      ' '.join([idx2tgt[n.item()] for n in pred[0]]))

quiero una cerveza P -> S beer a want want E
