In [204]:
import tensorflow as tf

In [215]:
def normalized_dot_product_attention(Q, K, V, mask):
    """
    Calculates the dot product attention weights
    and returns a matrix Z, which is the same size as 
    Q, K, and V.
    
    Parameters:
    
    Q - The result of applying Wq to X. 
      - Shape (..., sequence_len, dim_Wq)
    K - The result of applying Wk to X. 
      - Shape (..., sequence_len, dim_Wk)
    V - The result of applying Wv to X. 
      - Shape (..., sequence_len, dim_Wv)
    
    Returns:
    
    Z - The matrix created by applying the scaled attention 
            weights to the V matrix.
      - Shape (..., sequence_len, model_dim)
      
    """
    #Normalizing Q
    Q = tf.divide(Q, tf.norm(Q, axis=-1, keepdims=True))
    
    #Normalizing K
    K = tf.divide(K, tf.norm(K, axis=-1, keepdims=True))
    
    #compute the dot product of all query and key vectors (b/c they are normalized 0 >= values <=1
    attention_logits = tf.matmul(Q, K, transpose_b=True)
    
    attention_logits *= 1e2
        
    if mask is not None:
        attention_logits += (mask * -1e9)
    #apply softmax to find the weights
    attention_weights = tf.nn.softmax(attention_logits, axis=-1)
    #multiply the weights by the Value matrix
    Z = tf.matmul(attention_weights, V)
    
    return Z, attention_weights

In [216]:
class MHAttention(tf.keras.layers.Layer):
    
    def __init__(self, num_heads, embedding_dim):
        super(MHAttention, self).__init__()
        assert embedding_dim % num_heads == 0

        self.head_dim = embedding_dim // num_heads
        self.num_heads = num_heads
        self.embedding_dim = embedding_dim

        self.Wq = tf.keras.layers.Dense(self.embedding_dim)
        self.Wk = tf.keras.layers.Dense(self.embedding_dim)
        self.Wv = tf.keras.layers.Dense(self.embedding_dim)

        self.Wz = tf.keras.layers.Dense(self.embedding_dim)
        
    def create_heads(self, x, batch_size):
        
        return tf.reshape(tf.transpose(x), (batch_size, self.num_heads, -1, self.head_dim))
        
    def call(self, q, k, v, mask):
         
        batch_size = q.shape[0]
        
        q = self.Wq(q)
        k = self.Wk(k)
        v = self.Wv(v)
        
        q = self.create_heads(q, batch_size)
        k = self.create_heads(k, batch_size)
        v = self.create_heads(v, batch_size)
        
        z, attention_weights = normalized_dot_product_attention(q, k, v, mask)
        
        concat_z = tf.transpose(z, perm=[0, 2, 1, 3])
        
        concat_z = tf.reshape(concat_z, (batch_size, -1, self.embedding_dim))
        
        z = self.Wz(concat_z)
        
        return z, attention_weights


In [217]:
def create_padding_mask(seq):
    seq = tf.cast(tf.math.equal(seq, 0), tf.float32)

    # add extra dimensions to add the padding
    # to the attention logits.
    return seq[:, tf.newaxis, tf.newaxis, :]  # (batch_size, 1, 1, seq_len)

In [218]:
def create_look_ahead_mask(size):
    mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)
    return mask  # (seq_len, seq_len) 

In [219]:
def create_masks(inp, tar):
    # Encoder padding mask
    enc_padding_mask = create_padding_mask(inp)

    # Used in the 2nd attention block in the decoder.
    # This padding mask is used to mask the encoder outputs.
    dec_padding_mask = create_padding_mask(inp)

    # Used in the 1st attention block in the decoder.
    # It is used to pad and mask future tokens in the input received by 
    # the decoder.
    look_ahead_mask = create_look_ahead_mask(tf.shape(tar)[1])
    dec_target_padding_mask = create_padding_mask(tar)
    combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)

    return enc_padding_mask, combined_mask, dec_padding_mask

In [220]:
y = tf.constant([[1, 1, 2, 0, 0]], dtype=tf.float32)
y

<tf.Tensor: shape=(1, 5), dtype=float32, numpy=array([[1., 1., 2., 0., 0.]], dtype=float32)>

In [221]:
embed = tf.keras.layers.Embedding(3, 8)
y_embedded = embed(y)
y_embedded

<tf.Tensor: shape=(1, 5, 8), dtype=float32, numpy=
array([[[ 0.04474348,  0.00889816,  0.03163097,  0.01507231,
         -0.01323302, -0.01447612,  0.01418469, -0.01816808],
        [ 0.04474348,  0.00889816,  0.03163097,  0.01507231,
         -0.01323302, -0.01447612,  0.01418469, -0.01816808],
        [ 0.04628273, -0.02669678,  0.03306064,  0.01664596,
         -0.03828354,  0.02165109,  0.00331752, -0.04323517],
        [-0.00194035, -0.03082768, -0.01232624, -0.01268417,
         -0.00213158, -0.0156741 , -0.00873809,  0.00684651],
        [-0.00194035, -0.03082768, -0.01232624, -0.01268417,
         -0.00213158, -0.0156741 , -0.00873809,  0.00684651]]],
      dtype=float32)>

In [222]:
enc_padding_mask, look_ahead_mask, dec_padding_mask = create_masks(y, y)
create_masks(y, y)

(<tf.Tensor: shape=(1, 1, 1, 5), dtype=float32, numpy=array([[[[0., 0., 0., 1., 1.]]]], dtype=float32)>,
 <tf.Tensor: shape=(1, 1, 5, 5), dtype=float32, numpy=
 array([[[[0., 1., 1., 1., 1.],
          [0., 0., 1., 1., 1.],
          [0., 0., 0., 1., 1.],
          [0., 0., 0., 1., 1.],
          [0., 0., 0., 1., 1.]]]], dtype=float32)>,
 <tf.Tensor: shape=(1, 1, 1, 5), dtype=float32, numpy=array([[[[0., 0., 0., 1., 1.]]]], dtype=float32)>)

In [223]:
mha = MHAttention(1, 8)
out, attn_weights = mha(y_embedded, y_embedded, y_embedded, enc_padding_mask)

In [224]:
attn_weights.numpy()[0]

array([[[0.0000000e+00, 2.6526348e-19, 1.0000000e+00, 0.0000000e+00,
         0.0000000e+00],
        [1.6504135e-38, 1.0000000e+00, 0.0000000e+00, 0.0000000e+00,
         0.0000000e+00],
        [0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 0.0000000e+00,
         0.0000000e+00],
        [0.0000000e+00, 0.0000000e+00, 1.0000000e+00, 0.0000000e+00,
         0.0000000e+00],
        [5.8382332e-28, 1.0000000e+00, 5.0394636e-12, 0.0000000e+00,
         0.0000000e+00]]], dtype=float32)