In [4]:
import tensorflow as tf

In [7]:
class Decoder(tf.keras.layers.Layer):
    """Single Transformer decoder block.

    Chains three sub-layers; each is followed by dropout, a residual
    connection and layer normalization:

    1. self-attention over the decoder input,
    2. attention over ``encoder_output``, queried by the result of step 1,
    3. the network built by ``feed_forward``.

    Args:
        embedding_dim: forwarded to ``MHAttention`` and ``feed_forward``.
        num_heads: forwarded to both ``MHAttention`` layers.
        ff_hidden_dim: forwarded to ``feed_forward``.
        dropout_rate: rate shared by the three ``Dropout`` layers.
    """

    def __init__(self, embedding_dim, num_heads, ff_hidden_dim, dropout_rate=0.1):
        super().__init__()

        # mha_1: self-attention; mha_2: attention over the encoder output.
        self.mha_1 = MHAttention(num_heads, embedding_dim)
        self.mha_2 = MHAttention(num_heads, embedding_dim)

        self.ff = feed_forward(embedding_dim, ff_hidden_dim)

        # One normalization and one dropout layer per sub-layer.
        self.layernorm_1 = tf.keras.layers.LayerNormalization()
        self.layernorm_2 = tf.keras.layers.LayerNormalization()
        self.layernorm_3 = tf.keras.layers.LayerNormalization()

        self.dropout_1 = tf.keras.layers.Dropout(dropout_rate)
        self.dropout_2 = tf.keras.layers.Dropout(dropout_rate)
        self.dropout_3 = tf.keras.layers.Dropout(dropout_rate)

    def call(self, input_tensor, encoder_output, training, mask):
        """Run the decoder block.

        Args:
            input_tensor: decoder-side input; also the residual of sub-layer 1.
            encoder_output: tensor attended to by the second attention
                sub-layer (used as its key and value).
            training: forwarded to the dropout layers.
            mask: forwarded unchanged to both attention sub-layers.

        Returns:
            The output of the final layer normalization.
        """
        # Sublayer 1: self-attention over the decoder input.
        mha_1_output, _ = self.mha_1(input_tensor, input_tensor, input_tensor, mask)
        mha_1_output = self.dropout_1(mha_1_output, training=training)
        sublayer_1_output = self.layernorm_1(mha_1_output + input_tensor)

        # Sublayer 2: attention over the encoder output. Previously
        # `encoder_output` was never read, so the block ignored the encoder.
        # NOTE(review): assumes MHAttention is called as
        # (query, key, value, mask) -- confirm against its definition; if it
        # is (value, key, query, mask), swap the first and third arguments.
        # NOTE(review): the same `mask` is still applied here as in sublayer 1;
        # a mask sized for the decoder sequence may not fit when
        # `encoder_output` has a different length -- confirm with callers.
        mha_2_output, _ = self.mha_2(sublayer_1_output, encoder_output, encoder_output, mask)
        mha_2_output = self.dropout_2(mha_2_output, training=training)
        sublayer_2_output = self.layernorm_2(mha_2_output + sublayer_1_output)

        # Sublayer 3: feed-forward network.
        ff_output = self.ff(sublayer_2_output)
        ff_output = self.dropout_3(ff_output, training=training)
        return self.layernorm_3(ff_output + sublayer_2_output)

In [9]:
# Smoke test: push a random batch through one decoder block in inference
# mode with no mask. The same tensor stands in for the encoder output.
temp_decoder = Decoder(embedding_dim=512, num_heads=8, ff_hidden_dim=1024)
y = tf.random.uniform(shape=(1, 60, 512))
output = temp_decoder(y, y, training=False, mask=None)
print(output.shape)

(1, 60, 512)


In [16]:
class DecoderStack(tf.keras.layers.Layer):
    """A sequence of ``Decoder`` blocks applied one after another.

    Args:
        num_decoders: number of ``Decoder`` blocks to create.
        embedding_dim: forwarded to every ``Decoder``.
        num_heads: forwarded to every ``Decoder``.
        ff_hidden_dim: forwarded to every ``Decoder``.
        dropout: dropout rate forwarded to every ``Decoder``.
    """

    def __init__(self, num_decoders, embedding_dim, num_heads, ff_hidden_dim, dropout = 0.1):
        super().__init__()

        self.num_decoders = num_decoders

        # Forward `dropout` explicitly: it used to be accepted but dropped,
        # so every block silently fell back to Decoder's own default rate.
        self.decoders = [
            Decoder(embedding_dim, num_heads, ff_hidden_dim, dropout_rate=dropout)
            for _ in range(num_decoders)
        ]

    def call(self, input_tensor, encoder_output, training, mask):
        """Feed ``input_tensor`` through every decoder block in order.

        Args:
            input_tensor: input to the first decoder block.
            encoder_output: passed unchanged to every block.
            training: passed unchanged to every block.
            mask: passed unchanged to every block.

        Returns:
            The output of the last decoder block, or ``input_tensor`` itself
            when the stack holds no blocks.
        """
        output_tensor = input_tensor

        # Each block consumes the previous block's output.
        for decoder in self.decoders:
            output_tensor = decoder(output_tensor, encoder_output, training, mask)

        return output_tensor

In [17]:
# Smoke test: a six-block stack fed the random batch `y` from the earlier
# cell, in inference mode with no mask.
temp_decoder_stack = DecoderStack(
    num_decoders=6, embedding_dim=512, num_heads=8, ff_hidden_dim=1024
)
output = temp_decoder_stack(y, y, training=False, mask=None)
output.shape

TensorShape([1, 60, 512])