In [2]:
import tensorflow as tf

In [3]:
def normalized_dot_product_attention(Q, K, V, mask):
    """
    Calculates the dot product attention weights
    and returns a matrix Z, which is the same size as 
    Q, K, and V.
    
    Parameters:
    
    Q - The result of applying Wq to X. 
      - Shape (..., sequence_len, dim_Wq)
    K - The result of applying Wk to X. 
      - Shape (..., sequence_len, dim_Wk)
    V - The result of applying Wv to X. 
      - Shape (..., sequence_len, dim_Wv)
    
    Returns:
    
    Z - The matrix created by applying the scaled attention 
            weights to the V matrix.
      - Shape (..., sequence_len, model_dim)
      
    """
    #Normalizing Q
    Q = tf.divide(Q, tf.norm(Q, axis=-1, keepdims=True))
    
    #Normalizing K
    K = tf.divide(K, tf.norm(K, axis=-1, keepdims=True))
    
    #compute the dot product of all query and key vectors (b/c they are normalized 0 >= values <=1
    attention_logits = tf.matmul(Q, K, transpose_b=True)
    
    attention_logits *= 1e2
    
    attention_logits += (mask * -1e9)
        
    #apply softmax to find the weights
    attention_weights = tf.nn.softmax(attention_logits, axis=-1)
    
    #multiply the weights by the Value matrix
    Z = tf.matmul(attention_weights, V)
    
    return Z, attention_weights

In [6]:
class MHAttention(tf.keras.layers.Layer):
    
    def __init__(self, num_heads, embedding_dim):
        super(MHAttention, self).__init__()
        assert embedding_dim % num_heads == 0

        self.head_dim = embedding_dim // num_heads
        self.num_heads = num_heads
        self.embedding_dim = embedding_dim

        self.Wq = tf.keras.layers.Dense(self.embedding_dim)
        self.Wk = tf.keras.layers.Dense(self.embedding_dim)
        self.Wv = tf.keras.layers.Dense(self.embedding_dim)

        self.Wz = tf.keras.layers.Dense(self.embedding_dim)
        
    def create_heads(self, x, batch_size):
        
        return tf.reshape(tf.transpose(x), (batch_size, self.num_heads, -1, self.head_dim))
        
    def call(self, q, k, v):
         
        batch_size = q.shape[0]
        
        q = self.Wq(q)
        k = self.Wk(k)
        v = self.Wv(v)
        
        q = self.create_heads(q, batch_size)
        k = self.create_heads(k, batch_size)
        v = self.create_heads(v, batch_size)
        
        z, attention_weights = normalized_dot_product_attention(q, k, v)
        
        concat_z = tf.transpose(z, perm=[0, 2, 1, 3])
        
        concat_z = tf.reshape(concat_z, (batch_size, -1, self.embedding_dim))
        
        z = self.Wz(concat_z)
        
        return z, attention_weights


In [8]:
temp_mha = MHAttention(num_heads=8, embedding_dim=512)
y = tf.random.uniform((1, 60, 512))  # (batch_size, encoder_sequence, d_model)
out, attn = temp_mha(y, y, y)
out.shape

TensorShape([1, 60, 512])

In [9]:
def feed_forward(embedding_dim, hidden_dim):
    
    hidden_layer = tf.keras.layers.Dense(hidden_dim, activation='relu') #(batch size, seq len, hidden dim)
    output_layer = tf.keras.layers.Dense(embedding_dim) #(batch size, seq len, embedding dim)
    
    return tf.keras.Sequential([
        hidden_layer,
        output_layer
    ])

In [10]:
temp_ff = feed_forward(512, 1024)
temp_ff(y).shape

TensorShape([1, 60, 512])

## Scratch Work

In [149]:
q = tf.constant([[[0, 10, 0, 0, 0, 100, 1000, 0],
                  [0, 0,  1, 0, 100, 0, 0, 0],
                  [10, 0, 0, 1, 0, 0, 0, 1000],
                  [10, 0, 0, 0, 0, 0, 0, 0],
                  [10, 0, 0, 0, 0, 0, 0, 0],
                  [0, 0,  1, 0, 100, 0, 1000, 0]]], dtype=tf.float32)
num_heads = 4
head_dim = 2
batch_size = 1

We are using 4 heads. So we in theory pass x through 4 different Wq layers, then use the resulting q in self attention. We can actually do this with one layer and split up the result into 4 q's to then be passed into 4 different self attention layers. Below I'm going to try and figure out how to split up the results of one large Wq.

In [150]:
q

<tf.Tensor: shape=(1, 6, 8), dtype=float32, numpy=
array([[[   0.,   10.,    0.,    0.,    0.,  100., 1000.,    0.],
        [   0.,    0.,    1.,    0.,  100.,    0.,    0.,    0.],
        [  10.,    0.,    0.,    1.,    0.,    0.,    0., 1000.],
        [  10.,    0.,    0.,    0.,    0.,    0.,    0.,    0.],
        [  10.,    0.,    0.,    0.,    0.,    0.,    0.,    0.],
        [   0.,    0.,    1.,    0.,  100.,    0., 1000.,    0.]]],
      dtype=float32)>

In [151]:
q_multi_head = tf.reshape(tf.transpose(q), (batch_size, num_heads, -1, head_dim))

#The shape is (batch size, number of heads (width of tensor), sequence_length (height of tensor), head_dim (# of features in tensor))

In [152]:
q_multi_head

<tf.Tensor: shape=(1, 4, 6, 2), dtype=float32, numpy=
array([[[[   0.,    0.],
         [  10.,   10.],
         [  10.,    0.],
         [  10.,    0.],
         [   0.,    0.],
         [   0.,    0.]],

        [[   0.,    1.],
         [   0.,    0.],
         [   0.,    1.],
         [   0.,    0.],
         [   1.,    0.],
         [   0.,    0.]],

        [[   0.,  100.],
         [   0.,    0.],
         [   0.,  100.],
         [ 100.,    0.],
         [   0.,    0.],
         [   0.,    0.]],

        [[1000.,    0.],
         [   0.,    0.],
         [   0., 1000.],
         [   0.,    0.],
         [1000.,    0.],
         [   0.,    0.]]]], dtype=float32)>

In [175]:
z, attn = normalized_dot_product_attention(q_multi_head, q_multi_head, q_multi_head)

In [177]:
z

<tf.Tensor: shape=(1, 4, 6, 2), dtype=float32, numpy=
array([[[[nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan]],

        [[nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan]],

        [[nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan]],

        [[nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan]]]], dtype=float32)>

In [178]:
attn

<tf.Tensor: shape=(1, 4, 6, 6), dtype=float32, numpy=
array([[[[nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan]],

        [[nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan]],

        [[nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan]],

        [[nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan],
         [nan, nan, nan, nan, nan, nan

Self attention with the split heads creates num_heads output. All these are concatenated then passed through a dense layer to convert them back to the original embedding size

In [154]:
temp = tf.transpose(q_multi_head, perm=[0, 2, 1, 3])
temp

<tf.Tensor: shape=(1, 6, 4, 2), dtype=float32, numpy=
array([[[[   0.,    0.],
         [   0.,    1.],
         [   0.,  100.],
         [1000.,    0.]],

        [[  10.,   10.],
         [   0.,    0.],
         [   0.,    0.],
         [   0.,    0.]],

        [[  10.,    0.],
         [   0.,    1.],
         [   0.,  100.],
         [   0., 1000.]],

        [[  10.,    0.],
         [   0.,    0.],
         [ 100.,    0.],
         [   0.,    0.]],

        [[   0.,    0.],
         [   1.,    0.],
         [   0.,    0.],
         [1000.,    0.]],

        [[   0.,    0.],
         [   0.,    0.],
         [   0.,    0.],
         [   0.,    0.]]]], dtype=float32)>

In [155]:
tf.reshape(temp, (1, -1, 8))

<tf.Tensor: shape=(1, 6, 8), dtype=float32, numpy=
array([[[   0.,    0.,    0.,    1.,    0.,  100., 1000.,    0.],
        [  10.,   10.,    0.,    0.,    0.,    0.,    0.,    0.],
        [  10.,    0.,    0.,    1.,    0.,  100.,    0., 1000.],
        [  10.,    0.,    0.,    0.,  100.,    0.,    0.,    0.],
        [   0.,    0.,    1.,    0.,    0.,    0., 1000.,    0.],
        [   0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.]]],
      dtype=float32)>