# Transformer: Attention is all you need

This Jupyter notebook is a PyTorch implementation of the Transformer from the paper [Attention is all you need](https://arxiv.org/pdf/1706.03762.pdf). The task is to translate a human-readable date into the fixed format **yyyy-mm-dd**, e.g. "24th Aug 19" -> "2019-08-24". A good way to start implementing a model from scratch is to use a small, simple dataset.

In [1]:
import numpy as np
import tqdm
from faker import Faker
from babel.dates import format_date
from nmt_utils import load_dataset_v2, preprocess_data, string_to_int, int_to_string, softmax
import matplotlib.pyplot as plt
%matplotlib inline
from sklearn.model_selection import train_test_split
import os
import torch
import torch.nn as nn

In [2]:
m = 40000
dataset, human_vocab, machine_vocab, inv_machine_vocab = load_dataset_v2(m)

100%|█████████████████████████████████████████████████████████████████████████| 40000/40000 [00:01<00:00, 32013.67it/s]


In [3]:
human_vocab

{'<pad>': 0,
 '<unk>': 1,
 ' ': 2,
 '.': 3,
 '/': 4,
 '0': 5,
 '1': 6,
 '2': 7,
 '3': 8,
 '4': 9,
 '5': 10,
 '6': 11,
 '7': 12,
 '8': 13,
 '9': 14,
 'a': 15,
 'b': 16,
 'c': 17,
 'd': 18,
 'e': 19,
 'f': 20,
 'g': 21,
 'h': 22,
 'i': 23,
 'j': 24,
 'l': 25,
 'm': 26,
 'n': 27,
 'o': 28,
 'p': 29,
 'r': 30,
 's': 31,
 't': 32,
 'u': 33,
 'v': 34,
 'w': 35,
 'y': 36}

In [4]:
machine_vocab

{'#': 0,
 '-': 1,
 '0': 2,
 '1': 3,
 '2': 4,
 '3': 5,
 '4': 6,
 '5': 7,
 '6': 8,
 '7': 9,
 '8': 10,
 '9': 11}

In [5]:
Tx = 30
Ty = 10

X, Y = preprocess_data(dataset, human_vocab, machine_vocab, Tx, Ty+1)

print("X.shape:", X.shape)
print("Y.shape:", Y.shape)

X.shape: (40000, 30)
Y.shape: (40000, 11)


## Transformer model with PyTorch

### Hyperparameters:

$d_{model}$: dimension of the word embedding, and of the output of each **Multi-head Attention** layer and each **Feed Forward** layer.

$d_k$: per-head dimension of the matrices Q and K

$d_v$: per-head dimension of the matrix V

$d_{ff}$: dimension of the intermediate **Feed Forward** layer

$h$: number of heads in each block.


### Positional Encoding:

The Transformer contains no recurrence and no convolution: after the embedding layer, all positions of the input sentence are processed in parallel, so the model has no built-in notion of word order. We therefore need to add information about the relative or absolute position of each token to its embedding. The authors use a fixed (non-trainable) sinusoidal function:

$$PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right) \quad \mbox{(even indices)}$$
$$PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right) \quad \mbox{(odd indices)}$$

where $pos$ is the position in the sequence and $i$ indexes the sine/cosine pairs, so $2i$ and $2i+1$ are the embedding dimensions.
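
As an illustration, the two formulas above can be computed for all positions at once. This is a minimal sketch, assuming an even $d_{model}$; the function name `sinusoidal_positions` is hypothetical and not part of the notebook's code.

```python
import torch

def sinusoidal_positions(seq_len, d_model):
    """Return a (seq_len, d_model) tensor of fixed sinusoidal encodings (d_model assumed even)."""
    pos = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)    # (seq_len, 1)
    two_i = torch.arange(0, d_model, 2, dtype=torch.float32)         # values of 2i: 0, 2, 4, ...
    angle = pos / torch.pow(torch.tensor(10000.0), two_i / d_model)  # (seq_len, d_model / 2)
    pe = torch.zeros(seq_len, d_model)
    pe[:, 0::2] = torch.sin(angle)  # even indices
    pe[:, 1::2] = torch.cos(angle)  # odd indices
    return pe

pe = sinusoidal_positions(seq_len=30, d_model=32)
print(pe.shape)
```

At position 0 every sine entry is 0 and every cosine entry is 1, which is a quick sanity check for an implementation.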


### Scaled Dot-Product Attention:

<img style="width:300px; height:300px" src="https://i.imgur.com/HuXNlr0.png" />

$$Attention(Q, K, V) = softmax\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
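
The formula can be written in a few lines of PyTorch. This is a minimal single-head sketch without masking; the function name and the tensor shapes are assumptions chosen for illustration.

```python
import math
import torch

def scaled_dot_product_attention(Q, K, V):
    """Q: (n, t, d_k), K: (n, y, d_k), V: (n, y, d_v) -> output (n, t, d_v), weights (n, t, y)."""
    d_k = Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # similarity of each query with each key
    weights = torch.softmax(scores, dim=-1)            # each row sums to 1 over the keys
    return weights @ V, weights

torch.manual_seed(0)
Q = torch.randn(2, 5, 16)
K = torch.randn(2, 7, 16)
V = torch.randn(2, 7, 8)
out, weights = scaled_dot_product_attention(Q, K, V)
print(out.shape, weights.shape)
```

Dividing by $\sqrt{d_k}$ keeps the dot products from growing with the key dimension, which would otherwise push the softmax into regions with very small gradients.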

### (Encoder-Decoder) Multi-Head Attention:

<img style="width:300px; height:300px" src="https://i.imgur.com/vgfOLR2.png" />

$$MultiHead(Q, K, V) = Concat(head_1, head_2, ..., head_h)W^O$$
$$\mbox{where } head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)$$
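
One common way to implement this is to compute the projections for all heads with a single linear layer each and then reshape the result into separate heads. The sketch below follows that approach; the layer names `W_q`, `W_k`, `W_v`, `W_o` and the sizes are assumptions for illustration only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n, t, d_model, h, d_k, d_v = 2, 5, 32, 2, 16, 16

# One linear layer per role holds the projections of all h heads side by side.
W_q = nn.Linear(d_model, h * d_k)
W_k = nn.Linear(d_model, h * d_k)
W_v = nn.Linear(d_model, h * d_v)
W_o = nn.Linear(h * d_v, d_model)

x = torch.randn(n, t, d_model)           # self-attention: Q, K and V all come from x
q = W_q(x).reshape(n, t, h, d_k)         # split the last axis into (head, feature)
k = W_k(x).reshape(n, t, h, d_k)
v = W_v(x).reshape(n, t, h, d_v)

scores = torch.einsum("nthk,nyhk->nhty", q, k) / d_k ** 0.5  # one score matrix per head
weights = torch.softmax(scores, dim=-1)
heads = torch.einsum("nhty,nyhv->nthv", weights, v)          # per-head attention output
out = W_o(heads.reshape(n, t, h * d_v))                      # Concat(head_1..head_h) W^O
print(out.shape)
```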

### Feed forward:

$$FFN(x) = max(0, xW_1 + b_1)W_2 + b_2$$
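
A minimal sketch of this layer, assuming two `nn.Linear` modules stand in for $(W_1, b_1)$ and $(W_2, b_2)$. The network is applied to every position independently, so running it on a single position gives the same result as running it on the full sequence and slicing.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_ff = 32, 64
W1 = nn.Linear(d_model, d_ff)   # x W_1 + b_1
W2 = nn.Linear(d_ff, d_model)   # (.) W_2 + b_2

def ffn(x):
    return W2(torch.relu(W1(x)))  # max(0, x W_1 + b_1) W_2 + b_2

x = torch.randn(2, 5, d_model)
y = ffn(x)
print(y.shape)
```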

### Encoder blocks:

Each encoder block includes 2 layers: a **Multi-head Attention Mechanism** followed by a **Position-wise Feed Forward** network. The output of each layer is added to its input through a residual connection and then passed through [Layer Normalization](https://arxiv.org/pdf/1607.06450.pdf): $LayerNorm(x + f(x))$
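
The residual-plus-normalization step can be sketched as follows. Here `sublayer` is a hypothetical stand-in for $f$ (attention or feed forward); any module that maps $d_{model}$ to $d_{model}$ works.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 32
layer_norm = nn.LayerNorm(d_model)
sublayer = nn.Linear(d_model, d_model)  # stand-in for f (attention or feed forward)

x = torch.randn(2, 5, d_model)
out = layer_norm(x + sublayer(x))       # LayerNorm(x + f(x))
print(out.shape)
```

Because the sub-layer output is added to its input, both must have the same last dimension, which is why every sub-layer in the model outputs $d_{model}$ features.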

### Decoder blocks:

Each decoder block includes 3 layers: a masked **Multi-head Attention Mechanism**, an **Encoder-Decoder Multi-head Attention** and a **Position-wise Feed Forward** network. As in the **Encoder** blocks, the output of each layer is added to its input through a residual connection, followed by Layer Normalization.

<img src="https://i.imgur.com/1NUHvLi.jpg" />
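
The decoder's self-attention must not look at later positions. A common way to enforce this, sketched below under the assumption of an additive mask, is to add a large negative number to the attention scores of future positions before the softmax, so their weights become zero.

```python
import torch

seq_len = 4
# 0 on and below the diagonal, -1e9 strictly above it (the future positions).
mask = torch.triu(torch.full((seq_len, seq_len), -1e9), diagonal=1)

scores = torch.zeros(seq_len, seq_len)          # dummy scores: all positions equally similar
weights = torch.softmax(scores + mask, dim=-1)  # row t only attends to positions <= t
print(weights)
```

With equal scores, row $t$ spreads its weight uniformly over positions $0..t$ and gives exactly zero weight to everything after $t$.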

In [6]:
class CustomLinear(nn.Linear):

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.weight)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

class Transformer(nn.Module):

    def __init__(self, num_blocks, num_heads, vocab_size, seq_len, d_model, d_k, d_v, d_ff, device):
        super(Transformer, self).__init__()
        self.num_blocks = num_blocks
        self.num_heads = num_heads
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.seq_len = seq_len
        self.d_k = d_k
        self.d_v = d_v
        self.d_ff = d_ff
        self.device = device
        self.word_embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model).to(device)

    def _init_structure(self, decoder_part=False):
        assert not hasattr(self, "pos_embed"), "The structure is initialized already."
        self.pos_embed = torch.zeros(size=(self.seq_len, self.d_model), requires_grad=False, device=self.device)
        for pos in range(self.seq_len):
            for i in range(0, self.d_model, 2):
                self.pos_embed[pos, i] = torch.sin(torch.Tensor([pos / (10000 ** (i/self.d_model))]))
                self.pos_embed[pos, i + 1] = torch.cos(torch.Tensor([pos / (10000 ** (i/self.d_model))]))
        self.pos_embed = self.pos_embed.unsqueeze(0)

        if decoder_part:
            self.mask = [[0]*(i+1) + [-1e9]*(self.seq_len-(i+1)) for i in range(self.seq_len)]
            self.mask = torch.tensor([self.mask], requires_grad=False).to(self.device)

        for block_id in range(self.num_blocks):
            # Self-attention sub-layer
            setattr(self, "Q" + str(block_id), 
                    CustomLinear(in_features=self.d_model, out_features=self.d_k * self.num_heads).to(self.device))
            setattr(self, "K" + str(block_id), 
                    CustomLinear(in_features=self.d_model, out_features=self.d_k * self.num_heads).to(self.device))
            setattr(self, "V" + str(block_id), 
                    CustomLinear(in_features=self.d_model, out_features=self.d_v * self.num_heads).to(self.device))
                        
            setattr(self, "LN1" + str(block_id), nn.LayerNorm(self.d_model).to(self.device))
            # ---------------------------
            
            # Encoder-Decoder attention sub-layer
            if decoder_part:
                setattr(self, "Qconn" + str(block_id), 
                        CustomLinear(in_features=self.d_model, out_features=self.d_k * self.num_heads).to(self.device))
                setattr(self, "Kconn" + str(block_id), 
                        CustomLinear(in_features=self.d_model, out_features=self.d_k * self.num_heads).to(self.device))
                setattr(self, "Vconn" + str(block_id), 
                        CustomLinear(in_features=self.d_model, out_features=self.d_v * self.num_heads).to(self.device))
                
                # The normalized tensor is the d_model-sized residual sum, so normalize over d_model.
                setattr(self, "LN2" + str(block_id), nn.LayerNorm(self.d_model).to(self.device))
            # -----------------------------------
            
            # Layer multi-head attention output
            setattr(self, "O" + str(block_id), 
                    CustomLinear(in_features=self.d_v * self.num_heads, out_features=self.d_model).to(self.device))
            # Layer FNN 1
            setattr(self, "FNN1" + str(block_id), 
                    CustomLinear(in_features=self.d_model, out_features=self.d_ff).to(self.device))
            # Layer FNN 2
            setattr(self, "FNN2" + str(block_id), 
                    CustomLinear(in_features=self.d_ff, out_features=self.d_model).to(self.device))
            
            setattr(self, "LN3" + str(block_id), nn.LayerNorm(self.d_model).to(self.device))

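    def _compute_multi_head_attention(self, Q, K, V, block_id, mask=False, connection_head=False):
        # Pick the projection set: encoder-decoder ("conn") or self-attention.
        suffix = "conn" if connection_head else ""
        Q = getattr(self, "Q" + suffix + str(block_id))(Q)
        K = getattr(self, "K" + suffix + str(block_id))(K)
        V = getattr(self, "V" + suffix + str(block_id))(V)
        # Split the projections into heads: (n, t, h * d) -> (n, t, h, d).
        Q = Q.reshape(Q.shape[0], Q.shape[1], self.num_heads, self.d_k)
        K = K.reshape(K.shape[0], K.shape[1], self.num_heads, self.d_k)
        V = V.reshape(V.shape[0], V.shape[1], self.num_heads, self.d_v)
        # Scaled dot product per head, shape (n, h, t, y).
        QK = torch.einsum("nthk,nyhk->nhty", Q, K) / (self.d_k ** 0.5)
        if mask:
            # Apply the mask so that future words cannot affect the current word in the decoder.
            QK = QK + self.mask[:, :QK.shape[2], :QK.shape[3]]
        QK = torch.softmax(QK, dim=-1)
        atts = torch.einsum("nhty,nyhv->nthv", QK, V)
        # Concatenate the heads again before the output projection.
        atts = atts.reshape(atts.shape[0], atts.shape[1], self.num_heads * self.d_v)
        O = getattr(self, "O" + str(block_id))(atts)
        return O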

    def _compute_layer_norm(self, fX, X, name):
        LN = getattr(self, name)(fX + X)
        return LN

    def _compute_fnn(self, X, block_id):
        ffn1 = nn.functional.relu(getattr(self, "FNN1" + str(block_id))(X)) 
        ffn2 = getattr(self, "FNN2" + str(block_id))(ffn1)
        return ffn2


In [7]:
class Encoder(Transformer):

    def __init__(self, num_blocks, num_heads, vocab_size, seq_len, d_model, d_k, d_v, d_ff, device):
        super(Encoder, self).__init__(num_blocks, num_heads, vocab_size, seq_len, d_model, d_k, d_v, d_ff, device)
        self._init_structure()

    def forward(self, X):
        x = self.word_embed(X)
        x = x + self.pos_embed

        for block_id in range(self.num_blocks):
            fx = self._compute_multi_head_attention(x, x, x, block_id)
            x = self._compute_layer_norm(fx, x, "LN1" + str(block_id))
 
            fx = self._compute_fnn(x, block_id)
            x = self._compute_layer_norm(fx, x, "LN3" + str(block_id))

        return x

In [8]:
class Decoder(Transformer):
    
    def __init__(self, num_blocks, num_heads, vocab_size, seq_len, d_model, d_k, d_v, d_ff, device):
        super(Decoder, self).__init__(num_blocks, num_heads, vocab_size, seq_len, d_model, d_k, d_v, d_ff, device)
        self._init_structure(decoder_part=True)
        self.O_last = CustomLinear(in_features=d_model, out_features=vocab_size).to(self.device)

    def forward(self, X, encoder_output):
        x = self.word_embed(X)
        x = x + self.pos_embed[:, :x.shape[1], :]

        for block_id in range(self.num_blocks):
            fx = self._compute_multi_head_attention(x, x, x, block_id, mask=True)
            x = self._compute_layer_norm(fx, x, "LN1" + str(block_id))

            fx = self._compute_multi_head_attention(x, encoder_output, encoder_output, block_id, connection_head=True)
            x = self._compute_layer_norm(fx, x, "LN2" + str(block_id))

            fx = self._compute_fnn(x, block_id)
            x = self._compute_layer_norm(fx, x, "LN3" + str(block_id))
        
        logits = self.O_last(x)
        return logits

In [9]:
def loss_function(logits, target):
    logits = logits.reshape(-1, logits.shape[-1])
    target = target.reshape(-1)
    return torch.nn.functional.cross_entropy(logits, target)

### Define hyperparameter for Transformer Model

In [10]:
NUM_BLOCKS = 2
NUM_HEADS = 2
DIMENSION_MODEL = 32
DIMENSION_K = 16
DIMENSION_V = 16
DIMENSION_FF = 64
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [11]:
X = torch.Tensor(X).long().to(DEVICE)
Y = torch.Tensor(Y).long().to(DEVICE)

In [12]:
encoder = Encoder(num_blocks=NUM_BLOCKS, num_heads=NUM_HEADS, vocab_size=len(human_vocab), seq_len=Tx, 
                  d_model=DIMENSION_MODEL, d_k=DIMENSION_K, d_v=DIMENSION_V, d_ff=DIMENSION_FF, device=DEVICE)

decoder = Decoder(num_blocks=NUM_BLOCKS, num_heads=NUM_HEADS, vocab_size=len(machine_vocab), seq_len=Ty, 
                  d_model=DIMENSION_MODEL, d_k=DIMENSION_K, d_v=DIMENSION_V, d_ff=DIMENSION_FF, device=DEVICE)

In [13]:
epochs = 2
batch_size = 64
num_batches = X.shape[0]//batch_size if X.shape[0] % batch_size == 0 else X.shape[0]//batch_size + 1
data = torch.cat((X, Y), dim=1)

In [14]:
params_update = list(encoder.parameters()) + list(decoder.parameters())
optimizer = torch.optim.Adam(params_update, lr=0.001)

In [15]:
for e in range(epochs):
    
    data = data[torch.randperm(data.shape[0])]
    
    X, Y = data[:, :Tx], data[:, Tx:]
    
    pbar = tqdm.tqdm_notebook(range(0, num_batches), desc="Epoch " + str(e+1))
    
    train_loss = 0
    
    for it in pbar:
        loss = 0
        start = it*batch_size
        end = (it+1)*batch_size
        
        encoder_output = encoder(X[start:end])
            
        logits = decoder(Y[start:end, :-1], encoder_output)
            
        loss = loss_function(logits, Y[start:end, 1:])
        
        optimizer.zero_grad()
        
        loss.backward()
        
        optimizer.step()
            
        train_loss += float(loss)
        
        pbar.set_description("Epoch %s - Training loss: %f" % (e+1, (train_loss / (it+1))))


In [17]:
EXAMPLES = ['3 May 1979', '5 April 09', '21th of August 2016', 'Tue 10 Jul 2007', 'Saturday May 9 2018', 'March 3 2001', 'March 3rd 2001', '1 March 2001']

for example in EXAMPLES:
    source = string_to_int(example, Tx, human_vocab)
    source = torch.Tensor([source]).long().to(DEVICE)

    encoder_output = encoder(source)
    sentence = [machine_vocab["#"]]

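    # Greedy decoding: feed the tokens generated so far and append the most likely next token.
    for t in range(Ty):
        prediction = decoder(torch.Tensor([sentence]).long().to(DEVICE), encoder_output)
        prediction = torch.softmax(prediction, dim=-1)
        prediction = torch.argmax(prediction, dim=-1)
        # Store a plain int so the list can be turned into a tensor again on the next step.
        sentence.append(prediction[0][-1].item())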

    prediction = prediction.tolist()
    #sequential_output = [inv_machine_vocab[s] for s in sentence[1:]]
    parallel_output = [inv_machine_vocab[s] for s in prediction[0]]
    
    print("source:", example)
    print("parallel output:", ''.join(parallel_output))
    print("-----------------------------------------------")

source: 3 May 1979
parallel output: 1979-05-03
-----------------------------------------------
source: 5 April 09
parallel output: 1990-04-05
-----------------------------------------------
source: 21th of August 2016
parallel output: 2016-08-21
-----------------------------------------------
source: Tue 10 Jul 2007
parallel output: 2007-07-10
-----------------------------------------------
source: Saturday May 9 2018
parallel output: 2018-05-09
-----------------------------------------------
source: March 3 2001
parallel output: 2001-03-03
-----------------------------------------------
source: March 3rd 2001
parallel output: 2001-03-03
-----------------------------------------------
source: 1 March 2001
parallel output: 2001-03-01
-----------------------------------------------
