## Transformer put together as a model class

In [None]:
class WideAndDeepModel(K.Model):
    """Two-input, two-output model combining a wide path and a deep path.

    The deep input is normalized and passed through two dense hidden layers.
    The wide input is only normalized and is then concatenated with the
    output of the second hidden layer. ``call`` returns a dict holding a
    'main_output' computed from that concatenation and an 'aux_output'
    computed from the deep path alone.

    Args:
        units: Number of units in each of the two hidden dense layers.
        activation: Activation used by both hidden dense layers.
        **kwargs: Forwarded unchanged to the ``K.Model`` constructor
            (for example ``name``).
    """

    def __init__(self, units=30, activation='relu', **kwargs):
        super().__init__(**kwargs)

        # One normalization layer per input branch.
        self.norm_layer_wide = K.layers.Normalization()
        self.norm_layer_deep = K.layers.Normalization()

        # Deep branch: two stacked dense layers of identical width.
        self.hidden_layer1 = K.layers.Dense(units, activation=activation)
        self.hidden_layer2 = K.layers.Dense(units, activation=activation)

        # Single-unit output heads.
        self.main_output = K.layers.Dense(1, name='main_output')
        self.aux_output = K.layers.Dense(1, name='aux_output')

    def call(self, inputs):
        """Run the forward pass.

        Args:
            inputs: Indexable pair; ``inputs[0]`` feeds the wide branch and
                ``inputs[1]`` feeds the deep branch.

        Returns:
            Dict with keys 'main_output' and 'aux_output'.
        """
        wide_in, deep_in = inputs[0], inputs[1]

        wide_features = self.norm_layer_wide(wide_in)
        deep_features = self.norm_layer_deep(deep_in)

        # Deep branch output; reused by both heads below.
        deep_path = self.hidden_layer2(self.hidden_layer1(deep_features))
        merged = K.layers.concatenate([wide_features, deep_path])

        main_prediction = self.main_output(merged)
        aux_prediction = self.aux_output(deep_path)
        return {'main_output': main_prediction, 'aux_output': aux_prediction}

In [None]:
# Instantiate the wide-and-deep model with 30 units per hidden layer.
model = WideAndDeepModel(units=30, activation='relu', name='my_cool_model')

In [None]:
class Transformer(K.Model):
    """Encoder-decoder Transformer written as a subclassed Keras model.

    ``call`` takes a pair ``(encoder_inputs, decoder_inputs)``, tokenizes
    them with the module-level ``text_vec_layer_en`` / ``text_vec_layer_es``
    objects, embeds the token ids, adds positional encodings, runs ``N``
    encoder iterations followed by ``N`` decoder iterations, and returns a
    softmax over ``vocab_size`` for every decoder position.

    NOTE(review): every sub-layer is created once in ``__init__`` and reused
    in each of the ``N`` loop iterations, so all iterations share the same
    weights, and a single ``LayerNormalization`` instance is reused at every
    residual connection in both encoder and decoder. Confirm this sharing is
    intended; independent blocks would need one set of layers per iteration.

    Args:
        vocab_size: Input dimension of both embedding layers and number of
            units of the final softmax layer.
        embed_size: Embedding dimension; also used as ``key_dim`` of every
            attention layer and as the width of the second feed-forward
            dense layer.
        max_length: First argument passed to ``PositionalEncoding``
            (presumably the longest supported sequence -- confirm against
            that class).
        N: Number of encoder iterations and number of decoder iterations.
        num_heads: Number of heads of every multi-head attention layer.
        dropout_rate: Dropout rate of the attention layers and of the
            encoder feed-forward dropout layer.
        n_units: Width of the first (ReLU) feed-forward dense layer.
        **kwargs: Forwarded unchanged to the ``K.Model`` constructor.
    """

    def __init__(self,
                 vocab_size,
                 embed_size,
                 max_length=50,
                 N=2,
                 num_heads=8,
                 dropout_rate=0.1,
                 n_units=128,
                 **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.max_length = max_length
        self.N = N
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.n_units = n_units

        # Encoder and decoder token ids are embedded in an
        # embed_size-dimensional space. mask_zero=True marks id 0 (padding)
        # as masked so padded positions do not contribute to the loss.
        self.encoder_embedding_layer = tf.keras.layers.Embedding(
            self.vocab_size, self.embed_size, mask_zero=True)
        self.decoder_embedding_layer = tf.keras.layers.Embedding(
            self.vocab_size, self.embed_size, mask_zero=True)

        # One positional-encoding layer, applied to both embeddings.
        self.pos_embed_layer = PositionalEncoding(self.max_length, self.embed_size)

        # Encoder self-attention.
        self.attn_layer_enc = tf.keras.layers.MultiHeadAttention(
            num_heads=self.num_heads, key_dim=self.embed_size,
            dropout=self.dropout_rate)

        # Residual "add & norm" layers, reused at every skip connection.
        self.norm_layer = tf.keras.layers.LayerNormalization()
        self.add_layer = tf.keras.layers.Add()

        # Encoder feed-forward sub-block.
        self.dense_enc_0 = tf.keras.layers.Dense(self.n_units, activation="relu")
        self.dense_enc_1 = tf.keras.layers.Dense(self.embed_size)
        self.dropout = tf.keras.layers.Dropout(self.dropout_rate)

        # Decoder attention: index 0 is masked self-attention, index 1 is
        # cross-attention over the encoder outputs.
        self.attn_layer_dec_0 = tf.keras.layers.MultiHeadAttention(
            num_heads=self.num_heads, key_dim=self.embed_size,
            dropout=self.dropout_rate)
        self.attn_layer_dec_1 = tf.keras.layers.MultiHeadAttention(
            num_heads=self.num_heads, key_dim=self.embed_size,
            dropout=self.dropout_rate)

        # Decoder feed-forward sub-block.
        self.dense_dec_0 = tf.keras.layers.Dense(self.n_units, activation="relu")
        self.dense_dec_1 = tf.keras.layers.Dense(self.embed_size)

        # Per-position distribution over the vocabulary.
        self.dense_output = tf.keras.layers.Dense(self.vocab_size, activation="softmax")

    def call(self, inputs):
        """Run the encoder-decoder forward pass.

        Args:
            inputs: Indexable pair; ``inputs[0]`` is handed to
                ``text_vec_layer_en`` and ``inputs[1]`` to
                ``text_vec_layer_es`` (presumably batches of raw strings --
                confirm against the callers).

        Returns:
            Output of the final softmax layer applied to the decoder result:
            one distribution over ``vocab_size`` per decoder position.
        """
        encoder_inputs = inputs[0]
        decoder_inputs = inputs[1]

        # Tokenization relies on the module-level vectorization layers; they
        # are not owned by this model and must already be defined (and
        # adapted) before the model is called.
        encoder_input_ids = text_vec_layer_en(encoder_inputs)
        decoder_input_ids = text_vec_layer_es(decoder_inputs)

        # Cast kept from the original implementation.
        # NOTE(review): embedding lookups take integer ids, so confirm this
        # float32 cast is actually required.
        encoder_input_ids = tf.cast(encoder_input_ids, tf.float32)
        decoder_input_ids = tf.cast(decoder_input_ids, tf.float32)

        encoder_embeddings = self.encoder_embedding_layer(encoder_input_ids)
        decoder_embeddings = self.decoder_embedding_layer(decoder_input_ids)

        # Decoder sequence length of the current batch, for the causal mask.
        batch_max_len_dec = tf.shape(decoder_embeddings)[1]

        encoder_in = self.pos_embed_layer(encoder_embeddings)
        decoder_in = self.pos_embed_layer(decoder_embeddings)

        # True where the token is not padding; the inserted axis lets the
        # mask broadcast over the query positions.
        encoder_pad_mask = tf.math.not_equal(encoder_input_ids, 0)[:, tf.newaxis]

        # ---- Encoder ----
        Z = encoder_in
        for _ in range(self.N):
            # Self-attention + add & norm.
            skip = Z
            Z = self.attn_layer_enc(Z, value=Z, attention_mask=encoder_pad_mask)
            Z = self.norm_layer(self.add_layer([Z, skip]))

            # Feed-forward + dropout + add & norm.
            skip = Z
            Z = self.dense_enc_0(Z)
            Z = self.dense_enc_1(Z)
            Z = self.dropout(Z)
            Z = self.norm_layer(self.add_layer([Z, skip]))

        encoder_outputs = Z

        # ---- Decoder ----
        decoder_pad_mask = tf.math.not_equal(decoder_input_ids, 0)[:, tf.newaxis]

        # Lower-triangular mask: position i may only attend to positions <= i.
        causal_mask = tf.linalg.band_part(
            tf.ones((batch_max_len_dec, batch_max_len_dec), tf.bool), -1, 0)

        Z = decoder_in
        for _ in range(self.N):
            # Masked self-attention + add & norm.
            skip = Z
            Z = self.attn_layer_dec_0(
                Z, value=Z, attention_mask=causal_mask & decoder_pad_mask)
            Z = self.norm_layer(self.add_layer([Z, skip]))

            # Cross-attention: query from the decoder, key and value from
            # the encoder outputs.
            skip = Z
            Z = self.attn_layer_dec_1(
                Z, value=encoder_outputs, attention_mask=encoder_pad_mask)
            Z = self.norm_layer(self.add_layer([Z, skip]))

            # Feed-forward + add & norm.
            skip = Z
            Z = self.dense_dec_0(Z)
            Z = self.dense_dec_1(Z)
            Z = self.norm_layer(self.add_layer([Z, skip]))

        Y_proba = self.dense_output(Z)
        return Y_proba
