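### The Illustrated Transformer
A small Keras re-implementation of the encoder–decoder Transformer described in Jay Alammar's post, built with two encoder and two decoder blocks and trained on random data as a smoke test.
Reference: http://jalammar.github.io/illustrated-transformer/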

In [1]:
from keras.layers import LSTM, Dense, Input, Add, Concatenate
from keras.layers import BatchNormalization
from keras.models import Model
import keras
import tensorflow as tf
import numpy as np

In [2]:
N = 100  # number of synthetic training examples
num_encoder_tokens = 20  # used as the encoder sequence length (axis 1 of the random input)
num_decoder_tokens = 30  # used as the decoder sequence length (axis 1 of the random input)
max_encoder_seq_length = 128  # used as the encoder feature width, i.e. d_model (512 in the original paper)
max_decoder_seq_length = 128  # used as the decoder feature width, i.e. d_model (512 in the original paper)
latent_dim = 32  # width of each attention head, i.e. d_k (64 in the original paper)
SEED = 0  # seeds TensorFlow so layer weight initialisation is reproducible on Restart & Run All
tf.random.set_seed(SEED)

In [3]:
class SelfAttention(keras.layers.Layer):
    """Single-head scaled dot-product self-attention.

    Keys, queries and values are produced from the same input by three
    ReLU-activated Dense projections, and the layer returns
    ``softmax(Q K^T / sqrt(units)) V``.

    Args:
        units: width of the key/query/value projections (d_k).
        **kwargs: forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, units, **kwargs):
        super(SelfAttention, self).__init__(**kwargs)
        self.units = units
        self.Key = Dense(units=self.units, activation='relu')
        self.Query = Dense(units=self.units, activation='relu')
        self.Value = Dense(units=self.units, activation='relu')

    def call(self, encoder_input):
        """Attend over every position of ``encoder_input``.

        Args:
            encoder_input: rank-3 tensor, presumably (batch, seq_len, features).

        Returns:
            Tensor of shape (batch, seq_len, units).
        """
        K = self.Key(encoder_input)    # (batch, seq_len, units)
        Q = self.Query(encoder_input)  # (batch, seq_len, units)
        V = self.Value(encoder_input)  # (batch, seq_len, units)
        # Scale by sqrt(d_k) of *this* layer. A hard-coded 8 is only right for
        # d_k = 64 (the paper's value), not for the configured latent_dim.
        scale = tf.math.sqrt(tf.cast(self.units, Q.dtype))
        scores = tf.matmul(Q, K, transpose_b=True) / scale  # (batch, seq_len, seq_len)
        # Normalise over the attended (key) positions.
        weights = tf.nn.softmax(scores, axis=-1)
        return tf.matmul(weights, V)  # (batch, seq_len, units)

In [4]:
class Encoder(keras.layers.Layer):
    """One Transformer encoder block with two self-attention heads.

    The block applies two independent ``SelfAttention`` heads and concatenates
    them. The result is reduced to ``latent_dim`` and projected back to
    ``max_encoder_seq_length`` features, then added to the input and normalised
    ("Add & Norm"). A position-wise Dense layer follows, with a second Add & Norm.

    Args:
        latent_dim: width of each attention head.
        max_encoder_seq_length: feature width of the block's input and output.
            This is the last axis of the input, not a sequence length; the
            residual adds require it to match.
        **kwargs: forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, latent_dim, max_encoder_seq_length, **kwargs):
        super(Encoder, self).__init__(**kwargs)
        self.latent_dim = latent_dim
        self.max_encoder_seq_length = max_encoder_seq_length
        # Two independent self-attention heads.
        self.selfattention_layer_1 = SelfAttention(self.latent_dim)
        self.selfattention_layer_2 = SelfAttention(self.latent_dim)
        self.cat = Concatenate(axis=2)
        self.reduce_dim = Dense(units=self.latent_dim, activation='relu')
        # Project back to the input width so the residual add lines up.
        self.up_dim = Dense(units=self.max_encoder_seq_length, activation='relu')
        # NOTE(review): the paper's "Add & Norm" uses layer normalisation;
        # BatchNormalization is kept here to preserve the existing architecture.
        self.add_norm_1 = BatchNormalization()
        self.add_norm_2 = BatchNormalization()
        # One Add layer is reused for both residual connections.
        self.add = Add()
        self.feed_forward = Dense(units=self.max_encoder_seq_length, activation='relu')

    def call(self, encoder_input, training=None):
        """Run the encoder block.

        Args:
            encoder_input: tensor whose last axis has size ``max_encoder_seq_length``.
            training: forwarded to BatchNormalization. Batch statistics are used
                when it is True (Keras sets this inside ``fit``), and moving
                statistics are used at inference. It was previously hard-coded
                to True, which made inference depend on the batch.

        Returns:
            Tensor with the same shape as ``encoder_input``.
        """
        head_1 = self.selfattention_layer_1(encoder_input)
        head_2 = self.selfattention_layer_2(encoder_input)
        attended = self.cat([head_1, head_2])
        attended = self.up_dim(self.reduce_dim(attended))
        add_norm_output = self.add_norm_1(self.add([encoder_input, attended]), training=training)
        ff_output = self.feed_forward(add_norm_output)
        return self.add_norm_2(self.add([add_norm_output, ff_output]), training=training)

In [5]:
class EncoderDecoderAttention(keras.layers.Layer):
    """Scaled dot-product attention from decoder queries to encoder keys/values.

    Queries come from the decoder stream; keys and values come from the encoder
    output. The attended context is projected to ``dim`` features by a
    ReLU-activated Dense layer.

    Args:
        units: width of the key/query/value projections (d_k).
        dim: width of the output projection. The Decoder passes its own
            feature width here.
        **kwargs: forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, units, dim, **kwargs):
        super(EncoderDecoderAttention, self).__init__(**kwargs)
        self.units = units
        self.dim = dim
        self.Key = Dense(units=self.units, activation='relu')
        self.Query = Dense(units=self.units, activation='relu')
        self.Value = Dense(units=self.units, activation='relu')
        self.up_dim_context = Dense(units=self.dim, activation='relu')

    def call(self, encoder_output, decoder_input):
        """Attend from each decoder position over all encoder positions.

        Args:
            encoder_output: rank-3 tensor, presumably (batch, enc_len, enc_features).
            decoder_input: rank-3 tensor, presumably (batch, dec_len, dec_features).

        Returns:
            Tensor of shape (batch, dec_len, dim).
        """
        K = self.Key(encoder_output)    # (batch, enc_len, units)
        Q = self.Query(decoder_input)   # (batch, dec_len, units)
        V = self.Value(encoder_output)  # (batch, enc_len, units)
        # Scale by sqrt(d_k) of this layer rather than the paper's constant 8 (d_k = 64).
        scale = tf.math.sqrt(tf.cast(self.units, Q.dtype))
        scores = tf.matmul(Q, K, transpose_b=True) / scale  # (batch, dec_len, enc_len)
        # Each decoder position gets a distribution over encoder positions.
        weights = tf.nn.softmax(scores, axis=-1)
        context = tf.matmul(weights, V)  # (batch, dec_len, units)
        return self.up_dim_context(context)  # (batch, dec_len, dim)

In [6]:
class Decoder(keras.layers.Layer):
    """One Transformer decoder block with two self-attention heads.

    The block has three sub-layers, each followed by a residual add and
    BatchNormalization:

    1. Two ``SelfAttention`` heads over the decoder input. They are
       concatenated, reduced to ``latent_dim`` and projected back to
       ``max_decoder_seq_length`` features.
    2. ``EncoderDecoderAttention``, with queries from step 1 and keys/values
       from the encoder output.
    3. A position-wise Dense layer.

    NOTE(review): the decoder self-attention is not causally masked, so each
    position can attend to later positions; the paper masks these. That is
    acceptable for this random-data demo but is required for autoregressive
    decoding.

    Args:
        latent_dim: width of each attention head.
        max_decoder_seq_length: feature width of the block's decoder input and
            output (the last axis, not a sequence length).
        **kwargs: forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, latent_dim, max_decoder_seq_length, **kwargs):
        super(Decoder, self).__init__(**kwargs)
        self.latent_dim = latent_dim
        self.max_decoder_seq_length = max_decoder_seq_length
        # Two independent self-attention heads.
        self.selfattention_layer_1 = SelfAttention(self.latent_dim)
        self.selfattention_layer_2 = SelfAttention(self.latent_dim)
        self.cat = Concatenate(axis=2)
        self.reduce_dim = Dense(units=self.latent_dim, activation='relu')
        # Project back to the decoder width so the residual add lines up.
        self.up_dim = Dense(units=self.max_decoder_seq_length, activation='relu')
        # NOTE(review): the paper's "Add & Norm" uses layer normalisation;
        # BatchNormalization is kept here to preserve the existing architecture.
        self.add_norm_1 = BatchNormalization()
        self.add_norm_2 = BatchNormalization()
        self.add_norm_3 = BatchNormalization()
        # One Add layer is reused for all three residual connections.
        self.add = Add()
        self.feed_forward = Dense(units=self.max_decoder_seq_length, activation='relu')
        self.en_de_attention = EncoderDecoderAttention(self.latent_dim, self.max_decoder_seq_length)

    def call(self, encoder_output, decoder_input, training=None):
        """Run the decoder block.

        Args:
            encoder_output: output of the encoder stack.
            decoder_input: tensor whose last axis has size ``max_decoder_seq_length``.
            training: forwarded to BatchNormalization. Batch statistics are used
                when it is True (Keras sets this inside ``fit``), and moving
                statistics are used at inference. It was previously hard-coded
                to True, which made inference depend on the batch.

        Returns:
            Tensor with the same shape as ``decoder_input``.
        """
        head_1 = self.selfattention_layer_1(decoder_input)
        head_2 = self.selfattention_layer_2(decoder_input)
        attended = self.cat([head_1, head_2])
        attended = self.up_dim(self.reduce_dim(attended))
        add_norm_output_1 = self.add_norm_1(self.add([decoder_input, attended]), training=training)
        context = self.en_de_attention(encoder_output, add_norm_output_1)
        add_norm_output_2 = self.add_norm_2(self.add([add_norm_output_1, context]), training=training)
        ff_output = self.feed_forward(add_norm_output_2)
        return self.add_norm_3(self.add([add_norm_output_2, ff_output]), training=training)

In [7]:
# Encoder stack: two encoder blocks applied in sequence.
encoder_input = Input(shape=(None, max_encoder_seq_length))
encoder_block_1 = Encoder(latent_dim, max_encoder_seq_length)
encoder_block_1_output = encoder_block_1(encoder_input)
encoder_block_2 = Encoder(latent_dim, max_encoder_seq_length)
encoder_block_2_output = encoder_block_2(encoder_block_1_output)

# Decoder stack: both decoder blocks attend to the final encoder output.
decoder_input = Input(shape=(None, max_decoder_seq_length))
decoder_block_1 = Decoder(latent_dim, max_decoder_seq_length)
decoder_block_1_output = decoder_block_1(encoder_block_2_output, decoder_input)
decoder_block_2 = Decoder(latent_dim, max_decoder_seq_length)
decoder_block_2_output = decoder_block_2(encoder_block_2_output, decoder_block_1_output)

# Output head: a ReLU Dense bottleneck of width latent_dim, then a Dense
# softmax over the max_decoder_seq_length output features.
linear = Dense(units=latent_dim, activation='relu')
bottleneck_output = linear(decoder_block_2_output)
softmax = Dense(units=max_decoder_seq_length, activation='softmax')
output = softmax(bottleneck_output)

In [8]:
# Two-input functional model: (encoder input, decoder input) -> softmax output.
model = Model(inputs=[encoder_input, decoder_input], outputs=output)

In [9]:
# Categorical cross-entropy pairs with the softmax output layer.
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

In [10]:
# summary() prints the table itself and returns None; wrapping it in print()
# only appended a stray "None" to the output.
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, None, 128)]  0           []                               
                                                                                                  
 encoder (Encoder)              (None, None, 128)    48608       ['input_1[0][0]']                
                                                                                                  
 encoder_1 (Encoder)            (None, None, 128)    48608       ['encoder[0][0]']                
                                                                                                  
 input_2 (InputLayer)           [(None, None, 128)]  0           []                               
                                                                                              

In [11]:
# Synthetic encoder input: integers in [0, 10) with shape
# (N, num_encoder_tokens, max_encoder_seq_length), seeded for reproducibility.
np.random.seed(seed=0)
input_encoder = np.random.randint(low=10, size=(N, num_encoder_tokens, max_encoder_seq_length))

In [12]:
# Synthetic decoder input: integers in [0, 10) with shape
# (N, num_decoder_tokens, max_decoder_seq_length), seeded for reproducibility.
np.random.seed(seed=1)
input_decoder = np.random.randint(low=10, size=(N, num_decoder_tokens, max_decoder_seq_length))

In [13]:
# Targets are the untrained model's own outputs, so the fit() below is only a
# smoke test that the graph trains end to end, not a meaningful learning task.
output_decoder = model(
    [input_encoder, input_decoder]
)

In [14]:
# Keep the History object (per-epoch loss/accuracy in history.history) for later
# inspection or plotting, instead of printing its uninformative repr.
history = model.fit(
    [input_encoder, input_decoder],
    output_decoder,
    batch_size=10,
    epochs=50,
)

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50
Epoch 50/50
<keras.callbacks.History object at 0x000002C826CE2FA0>
