In [None]:
import tensorflow as tf
import numpy as np

Simple linear layer:

In [None]:
class Linear(tf.Module):
    """Fully connected affine layer computing ``x @ w + b``.

    Args:
        input_dim: Size of the last axis of the inputs.
        output_dim: Size of the last axis of the outputs.
        name: Module name; also used as the prefix of the variable names.
    """

    def __init__(self, input_dim, output_dim, name="Linear"):
        super().__init__(name=name)
        # Kernel starts from tf.random.uniform's default range; the bias starts at zero.
        initial_kernel = tf.random.uniform([input_dim, output_dim])
        initial_bias = tf.zeros([output_dim])
        self.w = tf.Variable(initial_kernel, name=name + "_w")
        self.b = tf.Variable(initial_bias, name=name + "_b")

    @tf.function
    @tf.Module.with_name_scope
    def __call__(self, x):
        """Project ``x`` with the kernel and add the bias."""
        projected = tf.matmul(x, self.w)
        return projected + self.b

Gaussian Error Linear Unit (GELU) activation for the FFN.
Same as tf.keras.activations.gelu(x, approximate=True):

In [None]:
@tf.function
def gelu_new(x):
    """Tanh approximation of GELU.

    Computes ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``.
    """
    cubic_term = 0.044715 * tf.math.pow(x, 3)
    tanh_arg = np.sqrt(2 / np.pi) * (x + cubic_term)
    return 0.5 * x * (1 + tf.math.tanh(tanh_arg))

Create triangular mask for decoder layer, to make sure its attention uses only previous tokens:

In [None]:
@tf.function
def CasualMask(size):
    """Return a ``(size, size)`` look-ahead mask.

    Entries strictly above the main diagonal are 1 (positions to hide);
    the diagonal and everything below it are 0.
    """
    # band_part(-1, 0) keeps the lower triangle including the diagonal.
    lower_triangle = tf.linalg.band_part(tf.ones((size, size)), -1, 0)
    return 1 - lower_triangle

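Feed-forward network (FFN). Projects each position from embed_dim up to ffn_dim, applies GELU, and projects back down to embed_dim, so the output of a transformer block has the same width as its input and can be fed to the next block: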

In [None]:
class FFN(tf.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Linear.

    Args:
        embed_dim: Width of the input and of the output.
        ffn_dim: Width of the hidden layer.
        name: Module name; also used as the prefix of the sublayer names.
    """

    def __init__(self, embed_dim, ffn_dim, name="FeedForward"):
        super().__init__(name=name)
        # Give each projection its own name so their variables are
        # distinguishable (e.g. "FeedForward_dense0_w" vs "FeedForward_dense1_w")
        # instead of both being called "<name>_w" / "<name>_b".
        self.dense0 = Linear(embed_dim, ffn_dim, name=name + "_dense0")
        self.dense1 = Linear(ffn_dim, embed_dim, name=name + "_dense1")

    @tf.function
    @tf.Module.with_name_scope
    def __call__(self, x):
        """Expand ``x`` to ``ffn_dim``, apply GELU, and project back to ``embed_dim``."""
        hidden = self.dense0(x)
        activated = gelu_new(hidden)
        return self.dense1(activated)

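Calculate attention outputs and weights. The mask is multiplied by -1e9 and added to the logits before the softmax, so tokens marked with 1 in the mask receive (effectively) zero weight. Zeroing the weights after the softmax instead would break the probability distribution, because the remaining weights would no longer sum to 1.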

In [None]:
def scaled_dot_product_attention(q, k, v, mask=None):
    """Compute softmax(q @ k^T / sqrt(d_k)) @ v.

    Args:
        q: Query tensor.
        k: Key tensor; the size of its last axis is used as d_k.
        v: Value tensor.
        mask: Optional tensor broadcastable to the logits; entries equal to 1
            are pushed towards -inf (via ``* -1e9``) before the softmax.

    Returns:
        Tuple ``(output, attention_weights)``.
    """
    key_depth = tf.cast(tf.shape(k)[-1], tf.float32)
    logits = tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(key_depth)

    if mask is not None:
        # Large negative offset => near-zero probability after the softmax.
        logits = logits + (mask * -1e9)

    weights = tf.nn.softmax(logits, axis=-1)
    return tf.matmul(weights, v), weights

Multi-head attention layer. This layer projects Q, K and V and splits each of them into separate heads. After splitting, each head has dimension key_dim // num_heads.

In [None]:
class MHA(tf.Module):
    """Multi-head attention.

    Args:
        num_heads: Number of attention heads; must be positive and divide
            ``key_dim`` evenly.
        key_dim: Width of the inputs and of the output.
        name: Module name.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide ``key_dim``.
    """

    def __init__(self, num_heads, key_dim, name="MHA"):
        super().__init__(name=name)
        # Fail fast: with a non-divisible key_dim the floor division below
        # truncates, and split_heads would either reshape into a wrong
        # sequence length or fail later with an opaque reshape error.
        if num_heads <= 0 or key_dim % num_heads != 0:
            raise ValueError(
                "key_dim ({}) must be divisible by a positive num_heads ({})".format(
                    key_dim, num_heads))

        self.num_heads = num_heads
        self.key_dim = key_dim
        self.depth = key_dim // num_heads  # per-head width

        self.wq = Linear(key_dim, key_dim, name="MHA_query")
        self.wk = Linear(key_dim, key_dim, name="MHA_key")
        self.wv = Linear(key_dim, key_dim, name="MHA_value")

        self.dense = Linear(key_dim, key_dim, name="MHA_dense")

    @tf.function
    def split_heads(self, x, batch_size):
        """Reshape ``x`` to (batch, seq, heads, depth) and move heads to axis 1."""
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    @tf.function
    @tf.Module.with_name_scope
    def __call__(self, v, q, k, mask):
        """Run attention. Note the argument order: value, query, key, mask.

        Returns:
            Tuple ``(output, attention_weights)``.
        """
        batch_size = tf.shape(q)[0]

        # Learned projections, then split the last axis into heads.
        q = self.split_heads(self.wq(q), batch_size)
        k = self.split_heads(self.wk(k), batch_size)
        v = self.split_heads(self.wv(v), batch_size)

        scaled_attention, attention_weights = scaled_dot_product_attention(q, k, v, mask)

        # Undo the head split: back to (batch, seq, heads, depth), then merge heads.
        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])
        concat_attention = tf.reshape(scaled_attention, (batch_size, -1, self.key_dim))

        output = self.dense(concat_attention)

        return output, attention_weights

Token and positional embedding. If requested (masking=True), this layer also returns the padding mask used in the scaled dot product attention function.

In [None]:
class Embedding(tf.Module):
    """Sum of a learned token embedding and a learned positional embedding.

    Args:
        vocab_size: Number of rows in the token embedding table.
        embed_dim: Width of each embedding vector.
        maxlen: Number of rows in the positional embedding table.
        name: Module name.
    """

    def __init__(self, vocab_size, embed_dim, maxlen, name="Embeddings"):
        super().__init__(name=name)
        self.w0 = tf.Variable(tf.compat.v1.glorot_uniform_initializer()((vocab_size, embed_dim)),name="token_embedding")
        self.w1 = tf.Variable(tf.compat.v1.glorot_uniform_initializer()((maxlen, embed_dim)),name="pos_embedding")

    @tf.function
    @tf.Module.with_name_scope
    def __call__(self, x, masking=False):
        """Embed token ids ``x`` and add positional embeddings.

        Args:
            x: Dense tensor or ``tf.sparse.SparseTensor`` of token ids whose
                last axis is the sequence axis.
            masking: If True, also return a padding mask that is 1 where the
                token id equals 0.

        Returns:
            The embeddings, or ``(embeddings, mask)`` when ``masking`` is True.
        """
        x = tf.cast(x, tf.int32)

        if isinstance(x, tf.sparse.SparseTensor):
            sparse_inputs_expanded = tf.sparse.expand_dims(x, axis=-1)
            out = tf.nn.safe_embedding_lookup_sparse(embedding_weights=self.w0,sparse_ids=sparse_inputs_expanded, default_id=0)
            # dense_shape is int64; tf.range below is driven with int32 bounds.
            seq_len = tf.cast(x.dense_shape[-1], tf.int32)
            # Missing entries densify to 0, matching default_id=0 used in the lookup.
            dense_ids = tf.sparse.to_dense(x)
        else:
            out = tf.nn.embedding_lookup(self.w0, x)
            seq_len = tf.shape(x)[-1]
            dense_ids = x

        # Positional embeddings are needed for both input kinds; previously
        # they were only computed in the dense branch, so sparse inputs failed
        # on the final addition.
        positions = tf.range(start=0, limit=seq_len, delta=1)
        pos_out = tf.nn.embedding_lookup(self.w1, positions)

        if masking:
            #Create mask to prevent encoder paying attention to PAD tokens
            mask = tf.cast(tf.math.equal(dense_ids, 0)[:, tf.newaxis, tf.newaxis, :], tf.float32)
            return out + pos_out, mask

        return out + pos_out

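Layer normalization. Normalizes along the chosen axis (the last one by default) to mean = 0, std = 1, then applies a learned scale (gamma) and shift (beta).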

In [None]:
class LayerNorm(tf.Module):
    """Layer normalization with a learned scale (gamma) and shift (beta).

    Args:
        epsilon: Constant added to the variance before taking the reciprocal
            square root.
        axis: Axis over which mean and variance are computed.
        name: Module name.
    """

    def __init__(self, epsilon=1e-5, axis=-1, name="LayerNorm"):
        super().__init__(name=name)
        self.epsilon = epsilon
        self.axis = axis
        # Created on the first call, once the size of the input's last axis is known.
        self.gamma = None
        self.beta = None

    @tf.function
    @tf.Module.with_name_scope
    def __call__(self, x):
        """Normalize ``x`` along ``self.axis``, then scale by gamma and shift by beta."""
        if self.gamma is None and self.beta is None:
            width = x.shape[-1]
            self.gamma = tf.Variable(tf.ones(width), name="layernorm_gamma")
            self.beta = tf.Variable(tf.zeros(width), name="layernorm_beta")

        mean = tf.reduce_mean(x, axis=self.axis, keepdims=True)
        variance = tf.reduce_mean(tf.square(x - mean), axis=self.axis, keepdims=True)
        normalized = (x - mean) * tf.math.rsqrt(variance + self.epsilon)
        return normalized * self.gamma + self.beta

Create encoder block:

In [None]:
class Encoder(tf.Module):
    """Transformer encoder block: self-attention and FFN, each followed by
    dropout, a residual connection and layer normalization.

    Args:
        num_heads: Number of attention heads.
        embed_dim: Width of the block's input and output.
        ffn_dim: Hidden width of the feed-forward network.
        name: Module name.
        dropout_rate: Rate passed to ``tf.nn.dropout`` after the attention and
            after the FFN. Defaults to 0.1 (the previously hard-coded value);
            pass 0.0 to disable dropout.
    """

    def __init__(self, num_heads, embed_dim, ffn_dim, name="encoder", dropout_rate=0.1):
        super().__init__(name=name)
        self.att = MHA(num_heads, embed_dim)
        self.ffn = FFN(embed_dim, ffn_dim)
        self.norm1 = LayerNorm(epsilon=1e-5)
        self.norm2 = LayerNorm(epsilon=1e-5)

        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.ffn_dim = ffn_dim
        self.dropout_rate = dropout_rate

    @tf.function
    @tf.Module.with_name_scope
    def __call__(self, x, mask=None):
        """Apply the block to ``x``; ``mask`` is forwarded to the attention layer."""
        # MHA takes (v, q, k, mask); self-attention uses x for all three.
        att_output, att_weights = self.att(x, x, x, mask)
        # NOTE: tf.nn.dropout is applied on every call; there is no training flag.
        att_output = tf.nn.dropout(att_output, rate=self.dropout_rate)
        norm_output = self.norm1(x + att_output)

        ffn_out = self.ffn(norm_output)
        ffn_out = tf.nn.dropout(ffn_out, rate=self.dropout_rate)

        return self.norm2(norm_output + ffn_out)

Create decoder block:

In [None]:
class Decoder(tf.Module):
    """Transformer decoder block: causal self-attention, encoder-decoder
    attention and FFN, each followed by dropout, a residual connection and
    layer normalization.

    Args:
        num_heads: Number of attention heads.
        embed_dim: Width of the block's input and output.
        ffn_dim: Hidden width of the feed-forward network.
        name: Module name.
        dropout_rate: Rate passed to ``tf.nn.dropout`` after each sublayer.
            Defaults to 0.1 (the previously hard-coded value); pass 0.0 to
            disable dropout.
    """

    def __init__(self, num_heads, embed_dim, ffn_dim, name="decoder", dropout_rate=0.1):
        super().__init__(name=name)
        self.att1 = MHA(num_heads, embed_dim)
        self.att2 = MHA(num_heads, embed_dim)
        self.ffn = FFN(embed_dim, ffn_dim)

        self.layernorm1 = LayerNorm(epsilon=1e-5)
        self.layernorm2 = LayerNorm(epsilon=1e-5)
        self.layernorm3 = LayerNorm(epsilon=1e-5)

        self.dropout_rate = dropout_rate

    @tf.function
    @tf.Module.with_name_scope
    def __call__(self, x, encoder_seq, mask=None):
        """Apply the block.

        Args:
            x: Decoder input sequence.
            encoder_seq: Encoder output attended to by the second attention layer.
            mask: Optional padding mask for ``encoder_seq`` (1 = ignore); it is
                applied to the encoder-decoder attention only.
        """
        seq_len = tf.shape(x)[1]

        # Look-ahead mask: each position may attend only to itself and earlier ones.
        casual_mask = CasualMask(seq_len)

        if mask is not None:
            # Keep the encoder padding mask as is. It must not be merged with
            # the causal mask: it describes encoder positions, and the previous
            # tf.minimum merge only masked entries flagged by both masks.
            mask = tf.cast(mask, tf.float32)

        # MHA.__call__ takes its arguments in the order (v, q, k, mask).
        att1_output, _ = self.att1(x, x, x, casual_mask)
        att1_dropout = tf.nn.dropout(att1_output, rate=self.dropout_rate)
        out1 = self.layernorm1(x + att1_dropout)

        # Encoder-decoder attention: queries come from the decoder (out1),
        # keys and values from the encoder output.
        att2_output, _ = self.att2(encoder_seq, out1, encoder_seq, mask)
        att2_dropout = tf.nn.dropout(att2_output, rate=self.dropout_rate)
        out2 = self.layernorm2(out1 + att2_dropout)

        ffn_output = self.ffn(out2)
        ffn_dropout = tf.nn.dropout(ffn_output, rate=self.dropout_rate)
        return self.layernorm3(out2 + ffn_dropout)

These layers make it possible to stack multiple encoder and decoder blocks by iterating over the specified number of layers.

In [None]:
class Encoder_layer(tf.Module):
    """Stack of ``num_layers`` Encoder blocks applied one after another."""

    def __init__(self, num_layers, num_heads, embed_dim, ffn_dim, name="encoder_layer"):
        super().__init__(name=name)
        self.num_layers = num_layers

        # Every block needs its own name, or it will not be recognized after
        # loading the model from tf.saved_model.
        blocks = []
        for index in range(num_layers):
            blocks.append(Encoder(num_heads, embed_dim, ffn_dim, name="encoder" + str(index)))
        self.enc_layers = blocks

    @tf.function
    @tf.Module.with_name_scope
    def __call__(self, x, mask=None):
        """Feed ``x`` through every encoder block in order, passing ``mask`` to each."""
        for block in self.enc_layers:
            x = block(x, mask)
        return x

In [None]:
class Decoder_layer(tf.Module):
    """Stack of ``num_layers`` Decoder blocks applied one after another."""

    def __init__(self, num_layers, num_heads, embed_dim, ffn_dim, name="decoder_layer"):
        super().__init__(name=name)
        self.num_layers = num_layers

        # Each block gets a distinct name: "decoder0", "decoder1", ...
        blocks = []
        for index in range(num_layers):
            blocks.append(Decoder(num_heads, embed_dim, ffn_dim, name="decoder" + str(index)))
        self.dec_layers = blocks

    @tf.function
    @tf.Module.with_name_scope
    def __call__(self, x, encoder_seq, mask=None):
        """Feed ``x`` through every decoder block in order.

        ``encoder_seq`` and ``mask`` are passed unchanged to each block.
        """
        for block in self.dec_layers:
            x = block(x, encoder_seq, mask)
        return x

Define arguments for building the model

In [None]:
# Hyperparameters passed to the Transformer constructor further down.
num_layers = 6  # number of stacked Encoder blocks and of stacked Decoder blocks
num_heads = 8  # attention heads per MHA layer
ffn_dim = 512  # hidden width of each FFN
embed_dim = 128  # embedding width; MHA uses embed_dim // num_heads per head
vocab_size = 31000  # rows of the token embedding table and width of the output logits
maxlen = 200  # rows of the positional embedding table; also read by Transformer.__call__'s input_signature

Build the transformer model:

In [None]:
class Transformer(tf.Module):
    """Encoder-decoder transformer producing per-position vocabulary logits.

    Args:
        num_layers: Number of encoder blocks and of decoder blocks.
        num_heads: Attention heads per attention layer.
        ffn_dim: Hidden width of the feed-forward networks.
        embed_dim: Embedding / model width.
        vocab_size: Vocabulary size (embedding rows and logits width).
        maxlen: Number of positional embedding rows.
        name: Module name.
    """

    def __init__(self, num_layers, num_heads, ffn_dim, embed_dim, vocab_size, maxlen, name="transformer"):
        super().__init__(name=name)

        self.maxlen = maxlen
        self.vocab_size = vocab_size
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.ffn_dim = ffn_dim
        self.num_layers = num_layers

        #Weights and bias of output dense layer
        self.output_w = tf.Variable(tf.compat.v1.glorot_uniform_initializer()((embed_dim,vocab_size)), name="model_output_w")
        # Store the Variable itself. tf.zeros already yields float32, so the
        # former tf.cast wrapper was redundant and risked keeping a plain
        # tensor here instead of a trainable, tracked variable.
        self.output_b = tf.Variable(tf.zeros(vocab_size), name="model_output_b")

        self.embedding_encoder = Embedding(vocab_size, embed_dim, maxlen)
        self.embedding_decoder = Embedding(vocab_size, embed_dim, maxlen)

        self.encoder = Encoder_layer(num_layers, num_heads, embed_dim, ffn_dim)
        self.decoder = Decoder_layer(num_layers, num_heads, embed_dim, ffn_dim)

    # NOTE: `maxlen` in the input_signature below is the module-level variable,
    # read once when the class body is executed; it is not the constructor
    # argument of the same name.
    @tf.function(input_signature=[tf.TensorSpec(shape=[None, maxlen], dtype=tf.int32), tf.TensorSpec(shape=[None,maxlen], dtype=tf.int32)])
    @tf.Module.with_name_scope
    def __call__(self, inp, targ):
        """Return logits for ``targ`` given the source sequence ``inp``.

        Args:
            inp: int32 source token ids; id 0 is treated as padding.
            targ: int32 target token ids.
        """
        # The padding mask is derived from the source ids (1 where id == 0).
        encoder_emb_output, mask = self.embedding_encoder(inp, masking=True)
        encoder1_output = self.encoder(encoder_emb_output, mask=mask)

        decoder_emb_output = self.embedding_decoder(targ)
        decoder1_output = self.decoder(decoder_emb_output, encoder1_output, mask=mask)

        # Final projection from embed_dim to vocab_size.
        return tf.matmul(decoder1_output, self.output_w) + self.output_b

In [None]:
# Build the model from the hyperparameters defined in the configuration cell above.
transformer = Transformer(num_layers, num_heads, ffn_dim, embed_dim, vocab_size, maxlen)

In [None]:
# Export the module in SavedModel format to the relative directory "transformer".
tf.saved_model.save(transformer, "transformer")