In [4]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Embedding, LSTM, MultiHeadAttention, Dense
from tensorflow.keras.layers import Concatenate, TimeDistributed

In [13]:
# Encoder-decoder (seq2seq) LSTM model: a 3-layer LSTM encoder, a 1-layer LSTM
# decoder initialised from the encoder's final state, and multi-head attention
# from the decoder steps over the encoder outputs.

# --- Hyperparameters (previously inline literals; values unchanged) ----------
MAX_ARTICLE_LEN = 32        # max length of an input sentence, in tokens
ARTICLE_VOCAB_SIZE = 1024   # vocab_size_article
SUMMARY_VOCAB_SIZE = 1024   # vocab_size_summary
LATENT_DIM = 32             # embedding width and LSTM units ("neurons")

# --- Encoder -----------------------------------------------------------------
# `shape` must be a tuple: `(32)` is just the int 32, which newer Keras rejects.
embedding_inputs = Input(shape=(MAX_ARTICLE_LEN,), name='Encoder_Input')
# NOTE(review): the embedding is frozen (trainable=False) but no pretrained
# weights are set in this cell -- confirm they are loaded elsewhere, otherwise
# the vectors stay at their random initialisation.
embedding = Embedding(ARTICLE_VOCAB_SIZE, LATENT_DIM, trainable=False,
                      name='Encoder_Embedding')(embedding_inputs)

# Three stacked LSTMs; each returns the full sequence (fed to the next layer)
# plus its final hidden/cell state.
encoder_layer_1 = LSTM(LATENT_DIM, return_sequences=True, return_state=True, name='LSTM_1')
encoder_output_1, state_h1, state_c1 = encoder_layer_1(embedding)

encoder_layer_2 = LSTM(LATENT_DIM, return_sequences=True, return_state=True, name='LSTM_2')
encoder_output_2, state_h2, state_c2 = encoder_layer_2(encoder_output_1)

encoder_layer_3 = LSTM(LATENT_DIM, return_sequences=True, return_state=True, name='LSTM_3')
encoder_output_3, state_h3, state_c3 = encoder_layer_3(encoder_output_2)

# --- Decoder -----------------------------------------------------------------
# Variable-length target sequence (time dimension left as None).
decoder_inputs = Input(shape=(None,), name='Decoder_Input')
# NOTE(review): same frozen-embedding caveat as the encoder embedding above.
decoder_embedding = Embedding(SUMMARY_VOCAB_SIZE, LATENT_DIM, trainable=False,
                              name='Decoder_Embedding')(decoder_inputs)

# The decoder LSTM starts from the final state of the last encoder layer.
decoder_layer_1 = LSTM(LATENT_DIM, return_sequences=True, return_state=True, name='Decoder_LSTM')
decoder_output_1, decoder_state_h1, decoder_state_c1 = decoder_layer_1(
    decoder_embedding,
    initial_state=[state_h3, state_c3],
)

# --- Attention ---------------------------------------------------------------
attention_layer = MultiHeadAttention(num_heads=2, key_dim=2, name='Attention')
# MultiHeadAttention is called as (query, value): the decoder steps are the
# queries and the encoder outputs are what is attended over. The original call
# had these swapped, which produced one context vector per *encoder* position
# and only concatenated cleanly when the decoder length happened to equal
# MAX_ARTICLE_LEN.
attn_out, attn_state = attention_layer(
    decoder_output_1,       # query: one row per decoder time step
    encoder_output_3,       # value (and key, by default): encoder sequence
    return_attention_scores=True,
)

# Join each decoder step with its attention context along the feature axis.
decoder_concat = Concatenate(axis=-1)([decoder_output_1, attn_out])

# --- Output projection -------------------------------------------------------
# Per-time-step softmax over SUMMARY_VOCAB_SIZE classes.
decoder_dense = TimeDistributed(Dense(SUMMARY_VOCAB_SIZE, activation='softmax'))
decoder_outputs = decoder_dense(decoder_concat)

model = Model([embedding_inputs, decoder_inputs], decoder_outputs)
model.summary()

Model: "model_5"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 Encoder_Input (InputLayer)  [(None, 32)]                 0         []                            
                                                                                                  
 Encoder_Embedding (Embeddi  (None, 32, 32)               32768     ['Encoder_Input[0][0]']       
 ng)                                                                                              
                                                                                                  
 LSTM_1 (LSTM)               [(None, 32, 32),             8320      ['Encoder_Embedding[0][0]']   
                              (None, 32),                                                         
                              (None, 32)]                                                   