In [28]:
import tensorflow as tf
import time

## Scaled dot product attention (SDA)

In [29]:
def scaled_dot_product_attention(Q, K, V, mask):
    """
    Calculates the scaled dot product attention weights
    and returns a matrix Z, which is the same size as 
    Q, K, and V.
    
    Parameters:
    
    Q - The result of applying Wq to X. 
      - Shape (..., sequence_len, dim_Wq)
    K - The result of applying Wk to X. 
      - Shape (..., sequence_len, dim_Wk)
    V - The result of applying Wv to X. 
      - Shape (..., sequence_len, dim_Wv)
    
    Returns:
    
    Z - The matrix created by applying the scaled attention 
            weights to the V matrix.
      - Shape (..., sequence_len, model_dim)
      
    """
    #compute the dot product of all query and key vectors
    QK_T = tf.matmul(Q, K, transpose_b=True)
    
    #scale the dot products by the depth of k (number of columns ie tf.shape(k)[-1])
    dk = tf.shape(K)[-1]
    dk = tf.cast(dk, tf.float32)
    scaled_attention_logits = QK_T / tf.sqrt(dk)
    
    if mask is not None:
        scaled_attention_logits += (mask * -1e9)
        
    #apply softmax to find the weights
    scaled_attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
    
    #multiply the weights by the Value matrix
    Z = tf.matmul(scaled_attention_weights, V)
    
    return Z, scaled_attention_weights

In [44]:
def print_sda_results(Q, K, V, mask):
    Z, weights = scaled_dot_product_attention(Q, K, V, mask)
    print(f'SDA Output is:\n {Z}')
    print(f'SDA Weights are:\n {weights}')

## Dot product attention (DA)

In [45]:
def dot_product_attention(Q, K, V, mask):
    """
    Calculates the dot product attention weights
    and returns a matrix Z, which is the same size as 
    Q, K, and V.
    
    Parameters:
    
    Q - The result of applying Wq to X. 
      - Shape (..., sequence_len, dim_Wq)
    K - The result of applying Wk to X. 
      - Shape (..., sequence_len, dim_Wk)
    V - The result of applying Wv to X. 
      - Shape (..., sequence_len, dim_Wv)
    
    Returns:
    
    Z - The matrix created by applying the scaled attention 
            weights to the V matrix.
      - Shape (..., sequence_len, model_dim)
      
    """
    #compute the dot product of all query and key vectors
    QK_T = tf.matmul(Q, K, transpose_b=True)
    attention_logits = QK_T
    
    if mask is not None:
        attention_logits += (mask * -1e9)
        
    #apply softmax to find the weights
    attention_weights = tf.nn.softmax(attention_logits, axis=-1)
    
    #multiply the weights by the Value matrix
    Z = tf.matmul(attention_weights, V)
    
    return Z, attention_weights

In [46]:
def print_da_results(Q, K, V, mask):
    Z, weights = dot_product_attention(Q, K, V, mask)
    print(f'DA Output is:\n {Z}')
    print(f'DA Weights are:\n {weights}')

## Normalized dot product attention (NDA)

In [47]:
def normalized_dot_product_attention(Q, K, V, mask):
    """
    Calculates the dot product attention weights
    and returns a matrix Z, which is the same size as 
    Q, K, and V.
    
    Parameters:
    
    Q - The result of applying Wq to X. 
      - Shape (..., sequence_len, dim_Wq)
    K - The result of applying Wk to X. 
      - Shape (..., sequence_len, dim_Wk)
    V - The result of applying Wv to X. 
      - Shape (..., sequence_len, dim_Wv)
    
    Returns:
    
    Z - The matrix created by applying the scaled attention 
            weights to the V matrix.
      - Shape (..., sequence_len, model_dim)
      
    """
    #Normalizing Q
    Q = tf.divide(Q, tf.norm(Q, axis=-1, keepdims=True))
    
    #Normalizing K
    K = tf.divide(K, tf.norm(K, axis=-1, keepdims=True))
    
    #compute the dot product of all query and key vectors (b/c they are normalized 0 >= values <=1
    attention_logits = tf.matmul(Q, K, transpose_b=True)
    
    attention_logits *= 1e3
    
    if mask is not None:
        attention_logits += (mask * -1e9)
        
    #apply softmax to find the weights
    attention_weights = tf.nn.softmax(attention_logits, axis=-1)
    
    #multiply the weights by the Value matrix
    Z = tf.matmul(attention_weights, V)
    
    return Z, attention_weights

In [48]:
def print_nda_results(Q, K, V, mask):
    Z, weights = normalized_dot_product_attention(Q, K, V, mask)
    print(f'NDA Output is:\n {Z}')
    print(f'NDA Weights are:\n {weights}')

## Example 1: small magnitude query key vectors

In [49]:
K1 = tf.constant([[0, 1, 0],
                 [0, 0, 1],
                 [1, 0, 0],
                 [1, 0, 0]], dtype=tf.float32)
V1 = tf.constant([[1, 0],
                 [10, 0],
                 [100, 0],
                 [5, 0]], dtype=tf.float32)

In [50]:
Q5 = tf.constant([[1, 0, 0]], dtype=tf.float32)
Q6 = tf.constant([[0, 1, 0]], dtype=tf.float32)
Q7 = tf.constant([[0, 0, 1]], dtype=tf.float32)

In [52]:
Q8 = tf.concat([Q5, Q6, Q7], 0)
print(f'Query matrix: \n{Q8}')
print_sda_results(Q8, K1, V1, None)
print_da_results(Q8, K1, V1, None)
print_nda_results(Q8, K1, V1, None)

Query matrix: 
[[1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]]
SDA Output is:
 [[35.6015    0.      ]
 [24.424534  0.      ]
 [25.895218  0.      ]]
SDA Weights are:
 [[0.17977124 0.17977124 0.32022873 0.32022873]
 [0.3725572  0.20914762 0.20914762 0.20914762]
 [0.20914762 0.3725572  0.20914762 0.20914762]]
DA Output is:
 [[39.859756  0.      ]
 [20.586304  0.      ]
 [23.290707  0.      ]]
DA Weights are:
 [[0.13447072 0.13447072 0.3655293  0.3655293 ]
 [0.47536692 0.17487772 0.17487772 0.17487772]
 [0.17487772 0.47536692 0.17487772 0.17487772]]
NDA Output is:
 [[52.5  0. ]
 [ 1.   0. ]
 [10.   0. ]]
NDA Weights are:
 [[0.  0.  0.5 0.5]
 [1.  0.  0.  0. ]
 [0.  1.  0.  0. ]]


## Example 2: medium magnitude query key vectors

In [415]:
K = tf.constant([[0, 10, 0],
                 [0, 0, 10],
                 [10, 0, 0],
                 [10, 0, 0]], dtype=tf.float32)
V = tf.constant([[1, 0],
                 [10, 0],
                 [100, 0],
                 [5, 0]], dtype=tf.float32)

In [416]:
Q1 = tf.constant([[10, 0, 0]], dtype=tf.float32)
Q2 = tf.constant([[0, 10, 0]], dtype=tf.float32)
Q3 = tf.constant([[0, 0, 10]], dtype=tf.float32)

In [417]:
Q4 = tf.concat([Q1, Q2, Q3], 0)
print(f'Query matrix: \n{Q4}')
print_sda_results(Q4, K, V)
print_sa_results(Q4, K, V)
print_nda_results(Q4, K, V)

Query matrix: 
[[10.  0.  0.]
 [ 0. 10.  0.]
 [ 0.  0. 10.]]
SDA Output is:
 [[52.5  0. ]
 [ 1.   0. ]
 [10.   0. ]]
SDA Weights are:
 [[4.2166372e-26 4.2166372e-26 5.0000000e-01 5.0000000e-01]
 [1.0000000e+00 8.4332744e-26 8.4332744e-26 8.4332744e-26]
 [8.4332744e-26 1.0000000e+00 8.4332744e-26 8.4332744e-26]]
Output is:
 [[52.5  0. ]
 [ 1.   0. ]
 [10.   0. ]]
Weights are:
 [[4.2166372e-26 4.2166372e-26 5.0000000e-01 5.0000000e-01]
 [1.0000000e+00 8.4332744e-26 8.4332744e-26 8.4332744e-26]
 [8.4332744e-26 1.0000000e+00 8.4332744e-26 8.4332744e-26]]
NDA Output is:
 [[52.5  0. ]
 [ 1.   0. ]
 [10.   0. ]]
NDA Weights are:
 [[0.  0.  0.5 0.5]
 [1.  0.  0.  0. ]
 [0.  1.  0.  0. ]]


## Example 3: large magnitude query key vectors

In [418]:
K2 = tf.constant([[0, 1000, 0],
                 [0, 0, 1000],
                 [1000, 0, 0],
                 [1000, 0, 0]], dtype=tf.float32)
V2 = tf.constant([[1, 0],
                 [10, 0],
                 [100, 0],
                 [5, 0]], dtype=tf.float32)

In [419]:
Q9 = tf.constant([[1000, 0, 0]], dtype=tf.float32)
Q10 = tf.constant([[0, 1000, 0]], dtype=tf.float32)
Q11 = tf.constant([[0, 0, 1000]], dtype=tf.float32)

In [420]:
Q12 = tf.concat([Q9, Q10, Q11], axis=0)
print(f'Query matrix: \n{Q12}')
print_sda_results(Q12, K2, V2)
print_sa_results(Q12, K2, V2)
print_nda_results(Q12, K2, V2)

Query matrix: 
[[1000.    0.    0.]
 [   0. 1000.    0.]
 [   0.    0. 1000.]]
SDA Output is:
 [[52.5  0. ]
 [ 1.   0. ]
 [10.   0. ]]
SDA Weights are:
 [[0.  0.  0.5 0.5]
 [1.  0.  0.  0. ]
 [0.  1.  0.  0. ]]
Output is:
 [[52.5  0. ]
 [ 1.   0. ]
 [10.   0. ]]
Weights are:
 [[0.  0.  0.5 0.5]
 [1.  0.  0.  0. ]
 [0.  1.  0.  0. ]]
NDA Output is:
 [[52.5  0. ]
 [ 1.   0. ]
 [10.   0. ]]
NDA Weights are:
 [[0.  0.  0.5 0.5]
 [1.  0.  0.  0. ]
 [0.  1.  0.  0. ]]


## Example 4: larger dimensional vectors (more realistic values)