References:
- https://machinelearningmastery.com/how-to-implement-scaled-dot-product-attention-from-scratch-in-tensorflow-and-keras/

In [1]:
import numpy as np
import tensorflow as tf

In [2]:
print('NumPy version:', np.__version__)
print('TensorFlow version:', tf.__version__)

NumPy version: 1.22.4
TensorFlow version: 2.12.0


## Scaled-Dot Product Attention

In [3]:
# Implementing the Scaled-Dot Product Attention
class DotProductAttention(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
 
    def call(self, queries, keys, values, d_k, mask=None):
        # Scoring the queries against the keys after transposing the latter, 
        # and scaling
        scores = tf.matmul(queries, keys, transpose_b=True) \
        / tf.math.sqrt(tf.cast(d_k, tf.float32))
 
        # Apply mask to the attention scores
        if mask is not None:
            scores += -1e9 * mask
 
        # Computing the weights by a softmax operation
        weights = tf.keras.activations.softmax(scores)
 
        # Computing the attention by a weighted sum of the value vectors
        return tf.matmul(weights, values)

### Test

In [4]:
input_seq_length = 5  # Maximum length of the input sequence
d_k = 64  # Dimensionality of the linearly projected queries and keys
d_v = 64  # Dimensionality of the linearly projected values
batch_size = 64  # Batch size from the training process
 
queries = np.random.random((batch_size, input_seq_length, d_k))
keys = np.random.random((batch_size, input_seq_length, d_k))
values = np.random.random((batch_size, input_seq_length, d_v))

print(queries.shape, keys.shape, values.shape)

(64, 5, 64) (64, 5, 64) (64, 5, 64)


In [5]:
attention = DotProductAttention()
attention_output = attention(queries, keys, values, d_k)
print(attention_output.shape)
print(attention_output)

(64, 5, 64)
tf.Tensor(
[[[0.35911444 0.27986366 0.43175602 ... 0.4172761  0.7320518  0.5585019 ]
  [0.35966444 0.28836572 0.42785645 ... 0.42110783 0.7298768  0.54711026]
  [0.37074962 0.25623086 0.4194161  ... 0.4300927  0.7316338  0.5674546 ]
  [0.36677894 0.2786251  0.42215994 ... 0.42453918 0.73416126 0.5516065 ]
  [0.37361792 0.25851774 0.42539382 ... 0.42708644 0.73355174 0.5652497 ]]

 [[0.6587192  0.53937733 0.55646425 ... 0.40068153 0.5077247  0.58336484]
  [0.6719176  0.5389432  0.5485055  ... 0.4065726  0.50623816 0.5716106 ]
  [0.64990133 0.54997545 0.55187446 ... 0.40260786 0.509494   0.5744652 ]
  [0.6299228  0.5615862  0.5461196  ... 0.39455885 0.5144636  0.5840039 ]
  [0.6494949  0.53913033 0.5454416  ... 0.41275638 0.49786437 0.5701406 ]]

 [[0.58725727 0.60501516 0.64156836 ... 0.53461623 0.6168226  0.2626707 ]
  [0.5684277  0.6140933  0.638802   ... 0.50707954 0.61244035 0.26776162]
  [0.5767349  0.5994956  0.633645   ... 0.52529156 0.5953866  0.2644099 ]
  [0.560262

## Multi-Head Attention

In [6]:
# Implementing the Multi-Head Attention
class MultiHeadAttention(tf.keras.layers.Layer):
    def __init__(self, h, d_k, d_v, d_model, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.attention = DotProductAttention() # Scaled dot product attention
        self.heads = h  # Number of attention heads to use
        self.d_k = d_k  
        # Dimensionality of the linearly projected queries and keys
        self.d_v = d_v  
        # Dimensionality of the linearly projected values
        self.d_model = d_model  # Dimensionality of the model
        self.W_q = tf.keras.layers.Dense(d_k)  
        # Learned projection matrix for the queries
        self.W_k = tf.keras.layers.Dense(d_k)  
        # Learned projection matrix for the keys
        self.W_v = tf.keras.layers.Dense(d_v)  
        # Learned projection matrix for the values
        self.W_o = tf.keras.layers.Dense(d_model)  
        # Learned projection matrix for the multi-head output
 
    def reshape_tensor(self, x, heads, flag):
        if flag:
            # Tensor shape after reshaping and transposing: 
            # (batch_size, heads, seq_length, -1)
            x = tf.reshape(x, shape=(tf.shape(x)[0], tf.shape(x)[1], heads, -1))
            x = tf.transpose(x, perm=(0, 2, 1, 3))
        else:
            # Reverting the reshaping and transposing operations: 
            # (batch_size, seq_length, d_k)
            x = tf.transpose(x, perm=(0, 2, 1, 3))
            x = tf.reshape(x, shape=(tf.shape(x)[0], tf.shape(x)[1], self.d_k))
        return x
 
    def call(self, queries, keys, values, mask=None):
        # Rearrange the queries to be able to compute all heads in parallel
        q_reshaped = self.reshape_tensor(self.W_q(queries), self.heads, True)
        # Resulting tensor shape: (batch_size, heads, input_seq_length, -1)
 
        # Rearrange the keys to be able to compute all heads in parallel
        k_reshaped = self.reshape_tensor(self.W_k(keys), self.heads, True)
        # Resulting tensor shape: (batch_size, heads, input_seq_length, -1)
 
        # Rearrange the values to be able to compute all heads in parallel
        v_reshaped = self.reshape_tensor(self.W_v(values), self.heads, True)
        # Resulting tensor shape: (batch_size, heads, input_seq_length, -1)
 
        # Compute the multi-head attention output using 
        # the reshaped queries, keys and values
        o_reshaped = self.attention(q_reshaped, k_reshaped, v_reshaped, 
                                    self.d_k, mask)
        # Resulting tensor shape: (batch_size, heads, input_seq_length, -1)
 
        # Rearrange back the output into concatenated form
        output = self.reshape_tensor(o_reshaped, self.heads, False)
        # Resulting tensor shape: (batch_size, input_seq_length, d_v)
 
        # Apply one final linear projection to the output to generate 
        # the multi-head attention
        # Resulting tensor shape: (batch_size, input_seq_length, d_model)
        return self.W_o(output)

### Test

In [7]:
input_seq_length = 5  # Maximum length of the input sequence
h = 8  # Number of self-attention heads
d_k = 64  # Dimensionality of the linearly projected queries and keys
d_v = 64  # Dimensionality of the linearly projected values
d_model = 512  # Dimensionality of the model sub-layers' outputs
batch_size = 64  # Batch size from the training process
 
queries = np.random.random((batch_size, input_seq_length, d_k))
keys = np.random.random((batch_size, input_seq_length, d_k))
values = np.random.random((batch_size, input_seq_length, d_v))

print(queries.shape, keys.shape, values.shape)

(64, 5, 64) (64, 5, 64) (64, 5, 64)


In [8]:
multihead_attention = MultiHeadAttention(h, d_k, d_v, d_model)
multihead_attention_output = multihead_attention(queries, keys, values)
print(multihead_attention_output.shape)
print(multihead_attention_output)

(64, 5, 512)
tf.Tensor(
[[[ 0.20084411 -0.04835438  0.08506437 ... -0.30154708 -0.09826905
    0.02063181]
  [ 0.19723743 -0.05085796  0.07998941 ... -0.30138454 -0.09685958
    0.01925156]
  [ 0.2006126  -0.05019293  0.08474492 ... -0.301715   -0.0968755
    0.01739117]
  [ 0.20130381 -0.04796845  0.08251181 ... -0.3006359  -0.09820308
    0.0220839 ]
  [ 0.20102677 -0.04776622  0.0855033  ... -0.3031875  -0.0991528
    0.01686744]]

 [[ 0.28139213  0.06235213  0.12181309 ... -0.33304513 -0.22435343
    0.00647064]
  [ 0.27891335  0.06149861  0.12167642 ... -0.333166   -0.22487411
    0.00797319]
  [ 0.28005055  0.0647476   0.12119246 ... -0.33376288 -0.22502379
    0.00722794]
  [ 0.28137016  0.0615837   0.12217928 ... -0.33417442 -0.22468513
    0.00510436]
  [ 0.2817069   0.05971115  0.12134649 ... -0.33609265 -0.22557211
    0.0078532 ]]

 [[ 0.36538604  0.15977253  0.15339254 ... -0.3448407  -0.19333048
    0.02876215]
  [ 0.36616623  0.16131091  0.1524326  ... -0.34518772 -0.195

## Dependencies

In [9]:
!pip install session-info

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [10]:
import session_info

session_info.show()