In [15]:
# The Illustrated Transformer 
# http://jalammar.github.io/illustrated-transformer/

In [1]:
from keras.layers import LSTM, Dense, Input, Add, Concatenate
from keras.layers import BatchNormalization
from keras.models import Model
import keras
import tensorflow as tf
import numpy as np

In [9]:
# Hyperparameters shared by the cells below.
N = 1000 # number of samples (first axis) of the commented-out random smoke-test arrays
num_encoder_tokens = 128 # second axis of the commented-out random encoder array
num_decoder_tokens = 128 # second axis of the commented-out random decoder array
max_encoder_seq_length = 512 # size of the last axis of the encoder Input; 512 in original paper
max_decoder_seq_length = 512 # size of the last axis of the decoder Input; 512 in original paper
latent_dim = 64 # width of each attention head's key/query/value projection; 64 in original paper

In [3]:
class SelfAttention(keras.layers.Layer):
    """Single-head scaled dot-product self-attention.

    Keys, queries and values are each produced by a ReLU-activated ``Dense``
    projection of the same input. The attention scores are divided by
    ``sqrt(units)`` before the softmax.

    Args:
        units: Width of the key/query/value projections, and therefore of the
            layer's output.
        **kwargs: Forwarded to ``keras.layers.Layer`` (for example ``name``).
    """

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        # Derived from `units` rather than the previous literal 8, which was
        # only the correct sqrt(d_k) when units == 64.
        self.scale = float(units) ** 0.5
        self.Key = Dense(units=self.units, activation='relu')
        self.Query = Dense(units=self.units, activation='relu')
        self.Value = Dense(units=self.units, activation='relu')

    def call(self, encoder_input):
        """Attend from every position of ``encoder_input`` to every position.

        Shape comments assume a rank-3 input ``(batch, seq_len, features)``.

        Returns:
            Tensor of shape ``(batch, seq_len, units)``.
        """
        keys = self.Key(encoder_input)        # (batch, seq_len, units)
        queries = self.Query(encoder_input)   # (batch, seq_len, units)
        values = self.Value(encoder_input)    # (batch, seq_len, units)
        # Pairwise query/key similarity, scaled by sqrt(units).
        scores = tf.matmul(queries, keys, transpose_b=True) / self.scale  # (batch, seq_len, seq_len)
        # Normalise over the key positions (axis 2).
        weights = tf.nn.softmax(scores, axis=2)
        return tf.matmul(weights, values)     # (batch, seq_len, units)

In [13]:
class Encoder(keras.layers.Layer):
    """One encoder block: multi-head self-attention, then a feed-forward layer.

    Each of the two sub-layers is wrapped in a residual connection followed by
    batch normalisation.

    Args:
        latent_dim: Width of each attention head, and of the bottleneck the
            concatenated heads are reduced to.
        max_encoder_seq_length: Size of the last axis of the block's input and
            output. Despite its name it is used as a feature width here: the
            residual additions require the input's last axis to have this size.
        num_heads: Number of parallel ``SelfAttention`` heads. Defaults to 8,
            the number that used to be hard-coded.
        **kwargs: Forwarded to ``keras.layers.Layer`` (for example ``name``).

    Raises:
        ValueError: If ``num_heads`` is smaller than 1.
    """

    def __init__(self, latent_dim, max_encoder_seq_length, num_heads=8, **kwargs):
        super().__init__(**kwargs)
        if num_heads < 1:
            raise ValueError("num_heads must be at least 1, got %r" % (num_heads,))
        self.latent_dim = latent_dim
        self.max_encoder_seq_length = max_encoder_seq_length
        self.num_heads = num_heads
        # Keras tracks layers stored in a list attribute, so the heads'
        # weights are still part of this layer.
        self.attention_heads = [SelfAttention(self.latent_dim) for _ in range(num_heads)]
        self.cat = Concatenate(axis=2)
        self.reduce_dim = Dense(units=self.latent_dim, activation='relu')
        self.up_dim = Dense(units=self.max_encoder_seq_length, activation='relu')
        self.add_norm_1 = BatchNormalization()
        self.add_norm_2 = BatchNormalization()
        self.add = Add()
        self.feed_forward = Dense(units=self.max_encoder_seq_length, activation='relu')

    def call(self, encoder_input, training=None):
        """Run the block on ``encoder_input``.

        Args:
            encoder_input: Tensor whose last axis has size
                ``max_encoder_seq_length``.
            training: Standard Keras training flag, forwarded to the batch
                normalisation layers. It was previously hard-coded to ``True``,
                which made inference use per-batch statistics instead of the
                learned moving averages.

        Returns:
            Tensor with the same shape as ``encoder_input``.
        """
        head_outputs = [head(encoder_input) for head in self.attention_heads]
        # A single head needs no concatenation.
        x = head_outputs[0] if len(head_outputs) == 1 else self.cat(head_outputs)
        x = self.reduce_dim(x)
        # Project back to the input width so the residual Add is shape-compatible.
        x = self.up_dim(x)
        add_norm_output = self.add_norm_1(self.add([encoder_input, x]), training=training)
        x = self.feed_forward(add_norm_output)
        return self.add_norm_2(self.add([add_norm_output, x]), training=training)

In [5]:
class EncoderDecoderAttention(keras.layers.Layer):
    """Single-head scaled dot-product attention from decoder to encoder.

    Queries come from the decoder stream; keys and values come from the
    encoder output. All three projections are ReLU-activated ``Dense`` layers.
    The attended context is projected to ``dim`` features by a final
    ReLU-activated ``Dense`` layer.

    Args:
        units: Width of the key/query/value projections.
        dim: Width of the output projection applied to the attended context.
        **kwargs: Forwarded to ``keras.layers.Layer`` (for example ``name``).
    """

    def __init__(self, units, dim, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.dim = dim
        # Derived from `units` rather than the previous literal 8, which was
        # only the correct sqrt(d_k) when units == 64.
        self.scale = float(units) ** 0.5
        self.Key = Dense(units=self.units, activation='relu')
        self.Query = Dense(units=self.units, activation='relu')
        self.Value = Dense(units=self.units, activation='relu')
        self.up_dim_context = Dense(units=self.dim, activation='relu')

    def call(self, encoder_output, decoder_input):
        """Attend from each decoder position over all encoder positions.

        Shape comments assume rank-3 inputs ``(batch, length, features)``.

        Returns:
            Tensor of shape ``(batch, decoder_len, dim)``.
        """
        keys = self.Key(encoder_output)       # (batch, encoder_len, units)
        queries = self.Query(decoder_input)   # (batch, decoder_len, units)
        values = self.Value(encoder_output)   # (batch, encoder_len, units)
        # Decoder-query / encoder-key similarity, scaled by sqrt(units).
        scores = tf.matmul(queries, keys, transpose_b=True) / self.scale  # (batch, decoder_len, encoder_len)
        # Normalise over the encoder positions (axis 2).
        weights = tf.nn.softmax(scores, axis=2)
        context = tf.matmul(weights, values)  # (batch, decoder_len, units)
        return self.up_dim_context(context)   # (batch, decoder_len, dim)

In [14]:
class Decoder(keras.layers.Layer):
    """One decoder block: self-attention, encoder-decoder attention, feed-forward.

    Each of the three sub-layers is wrapped in a residual connection followed
    by batch normalisation.

    NOTE(review): the ``SelfAttention`` heads apply no look-ahead (causal)
    mask, so every decoder position can attend to later positions. Confirm
    this is intended.

    Args:
        latent_dim: Width of each attention head, and of the bottleneck the
            concatenated heads are reduced to.
        max_decoder_seq_length: Size of the last axis of the block's decoder
            input and of its output. Despite its name it is used as a feature
            width here: the residual additions require the decoder input's
            last axis to have this size.
        num_heads: Number of parallel ``SelfAttention`` heads. Defaults to 8,
            the number that used to be hard-coded.
        **kwargs: Forwarded to ``keras.layers.Layer`` (for example ``name``).

    Raises:
        ValueError: If ``num_heads`` is smaller than 1.
    """

    def __init__(self, latent_dim, max_decoder_seq_length, num_heads=8, **kwargs):
        super().__init__(**kwargs)
        if num_heads < 1:
            raise ValueError("num_heads must be at least 1, got %r" % (num_heads,))
        self.latent_dim = latent_dim
        self.max_decoder_seq_length = max_decoder_seq_length
        self.num_heads = num_heads
        # Keras tracks layers stored in a list attribute, so the heads'
        # weights are still part of this layer.
        self.attention_heads = [SelfAttention(self.latent_dim) for _ in range(num_heads)]
        self.cat = Concatenate(axis=2)
        self.reduce_dim = Dense(units=self.latent_dim, activation='relu')
        self.up_dim = Dense(units=self.max_decoder_seq_length, activation='relu')
        self.add_norm_1 = BatchNormalization()
        self.add_norm_2 = BatchNormalization()
        self.add_norm_3 = BatchNormalization()
        self.add = Add()
        self.feed_forward = Dense(units=self.max_decoder_seq_length, activation='relu')
        self.en_de_attention = EncoderDecoderAttention(self.latent_dim, self.max_decoder_seq_length)

    def call(self, encoder_output, decoder_input, training=None):
        """Run the block.

        Args:
            encoder_output: Output of the encoder stack; supplies keys and
                values for the encoder-decoder attention.
            decoder_input: Tensor whose last axis has size
                ``max_decoder_seq_length``.
            training: Standard Keras training flag, forwarded to the batch
                normalisation layers. It was previously hard-coded to ``True``,
                which made inference use per-batch statistics instead of the
                learned moving averages.

        Returns:
            Tensor with the same shape as ``decoder_input``.
        """
        # 1) Multi-head self-attention over the decoder stream.
        head_outputs = [head(decoder_input) for head in self.attention_heads]
        # A single head needs no concatenation.
        x = head_outputs[0] if len(head_outputs) == 1 else self.cat(head_outputs)
        x = self.reduce_dim(x)
        # Project back to the input width so the residual Add is shape-compatible.
        x = self.up_dim(x)
        add_norm_output_1 = self.add_norm_1(self.add([decoder_input, x]), training=training)
        # 2) Attention from the decoder stream over the encoder output.
        x = self.en_de_attention(encoder_output, add_norm_output_1)
        add_norm_output_2 = self.add_norm_2(self.add([add_norm_output_1, x]), training=training)
        # 3) Position-wise feed-forward layer.
        x = self.feed_forward(add_norm_output_2)
        return self.add_norm_3(self.add([add_norm_output_2, x]), training=training)

In [15]:
# Depth of both stacks. Previously expressed as six copy-pasted
# construct-and-call stanzas per stack.
NUM_BLOCKS = 6

# Encoder: NUM_BLOCKS Encoder blocks applied one after another.
encoder_input = Input(shape=(None, max_encoder_seq_length))
encoder_blocks = [Encoder(latent_dim, max_encoder_seq_length) for _ in range(NUM_BLOCKS)]
encoder_output = encoder_input
for encoder_block in encoder_blocks:
    encoder_output = encoder_block(encoder_output)

# Decoder: NUM_BLOCKS Decoder blocks. Every block attends over the output of
# the last encoder block, not over intermediate encoder outputs.
decoder_input = Input(shape=(None, max_decoder_seq_length))
decoder_blocks = [Decoder(latent_dim, max_decoder_seq_length) for _ in range(NUM_BLOCKS)]
x = decoder_input
for decoder_block in decoder_blocks:
    x = decoder_block(encoder_output, x)

# Linear - Softmax
linear = Dense(units=latent_dim, activation='relu')
x = linear(x)
softmax = Dense(units=max_decoder_seq_length, activation='softmax')
output = softmax(x)

In [16]:
# Functional model with two inputs (encoder stream, decoder stream) and the
# softmax layer's output as its single output.
model = Model(inputs=[encoder_input, decoder_input], outputs=output)

In [9]:
# Adam optimiser, categorical cross-entropy loss, accuracy reported as a metric.
# Compilation applies to whatever object `model` is bound to right now, so
# re-run this cell every time the model cell is re-executed.
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

In [17]:
# Print the layer-by-layer summary: output shapes, parameter counts and connections.
model.summary()

Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_4 (InputLayer)           [(None, None, 512)]  0           []                               
                                                                                                  
 encoder_3 (Encoder)            (None, None, 512)    1120832     ['input_4[0][0]']                
                                                                                                  
 encoder_4 (Encoder)            (None, None, 512)    1120832     ['encoder_3[0][0]']              
                                                                                                  
 encoder_5 (Encoder)            (None, None, 512)    1120832     ['encoder_4[0][0]']              
                                                                                            

In [18]:
# np.random.seed(0)
# input_encoder = np.random.randint(10,size = (N, num_encoder_tokens, max_encoder_seq_length))

In [19]:
# np.random.seed(1)
# input_decoder = np.random.randint(10,size = (N, num_decoder_tokens, max_decoder_seq_length))

In [20]:
# output_decoder = model([input_encoder,input_decoder])

In [21]:
# model.fit([input_encoder, input_decoder],output_decoder,batch_size = 50, epochs = 20)