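References:
- [How to Implement Scaled Dot-Product Attention from Scratch in TensorFlow and Keras](https://machinelearningmastery.com/how-to-implement-scaled-dot-product-attention-from-scratch-in-tensorflow-and-keras/) (Machine Learning Mastery)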

In [1]:
import numpy as np
import tensorflow as tf

In [2]:
# Record library versions so the outputs below can be tied to a specific environment.
version_report = (
    ('NumPy version:', np.__version__),
    ('TensorFlow version:', tf.__version__),
)
for version_label, version_string in version_report:
    print(version_label, version_string)

NumPy version: 1.22.4
TensorFlow version: 2.12.0


## Scaled Dot-Product Attention

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

In [7]:
# Implementing Scaled Dot-Product Attention
class DotProductAttention(tf.keras.layers.Layer):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k) + mask penalty) V.

    The layer creates no weights; queries, keys and values are all supplied
    to ``call``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, queries, keys, values, d_k=None, mask=None):
        """Attend from ``queries`` over ``keys`` and return the weighted ``values``.

        Args:
            queries: Tensor whose last axis has the same size as that of ``keys``;
                presumably (batch, seq_q, d_k) — TODO confirm with callers.
            keys: Tensor that is transposed over its last two axes before scoring.
            values: Tensor whose second-to-last axis matches the key sequence axis.
            d_k: Dimension used for the 1/sqrt(d_k) scaling. When None (new,
                backward-compatible default) it is taken from the size of the
                last axis of ``keys``.
            mask: Optional tensor broadcastable to the score matrix. Entries of 1
                add a -1e9 penalty to the matching score (effectively removing
                it from the softmax); entries of 0 leave the score unchanged.
                Float, integer and boolean masks are accepted.

        Returns:
            ``softmax(scores) @ values``.
        """
        # transpose_b swaps the last two axes of keys, pairing every query with every key.
        scores = tf.matmul(queries, keys, transpose_b=True)

        if d_k is None:
            d_k = tf.shape(keys)[-1]
        # Cast to the scores' own dtype instead of a hard-coded float32 so the
        # division also works for float64 layers or float16 mixed precision.
        scores = scores / tf.math.sqrt(tf.cast(d_k, scores.dtype))

        if mask is not None:
            # Build the penalty in float32 (where -1e9 is exactly representable and
            # 0-mask entries stay exactly 0), then match the scores' dtype. This also
            # lets integer/boolean masks combine with floating-point scores.
            penalty = -1e9 * tf.cast(mask, tf.float32)
            scores += tf.cast(penalty, scores.dtype)

        # softmax normalizes over the last (key) axis, so each row of weights sums to 1.
        weights = tf.keras.activations.softmax(scores)

        # Attention output: weighted sum of the value vectors.
        return tf.matmul(weights, values)

### Test

In [11]:
input_seq_length = 5  # Maximum length of the input sequence
d_k = 64  # Dimensionality of the linearly projected queries and keys
d_v = 64  # Dimensionality of the linearly projected values
batch_size = 64  # Batch size from the training process
SEED = 42  # Fixed seed so the random test inputs (and printed outputs) are reproducible

# A seeded Generator makes Restart & Run All reproduce the same tensors.
# float32 matches Keras' default compute dtype, avoiding implicit float64 -> float32 casts.
rng = np.random.default_rng(SEED)
queries = rng.random((batch_size, input_seq_length, d_k), dtype=np.float32)
keys = rng.random((batch_size, input_seq_length, d_k), dtype=np.float32)
values = rng.random((batch_size, input_seq_length, d_v), dtype=np.float32)

print(queries.shape, keys.shape, values.shape)

(64, 5, 64) (64, 5, 64) (64, 5, 64)


In [9]:
attention = DotProductAttention()
attention_output = attention(queries, keys, values, d_k)

# weights (batch, seq_q, seq_k) @ values (batch, seq_k, d_v) -> (batch, seq_q, d_v)
assert tuple(attention_output.shape) == (batch_size, input_seq_length, d_v)
print(attention_output)

tf.Tensor(
[[[0.5405959  0.6993936  0.5026516  ... 0.6073619  0.60819995 0.38466918]
  [0.56965035 0.6883776  0.5247059  ... 0.5629483  0.60342413 0.3896538 ]
  [0.54332215 0.69759405 0.5120189  ... 0.5976771  0.6104103  0.37741628]
  [0.55770063 0.68541634 0.5183922  ... 0.5784359  0.61775726 0.3709305 ]
  [0.5459266  0.6864781  0.5009302  ... 0.60004544 0.6235326  0.37195563]]

 [[0.58028376 0.56536126 0.6102319  ... 0.753921   0.460257   0.71111435]
  [0.56735826 0.55447745 0.61917734 ... 0.7621567  0.48218375 0.6857721 ]
  [0.58125937 0.5737533  0.60878515 ... 0.7615038  0.46895206 0.71606445]
  [0.56346095 0.55564106 0.6274744  ... 0.76294804 0.47242552 0.6780621 ]
  [0.5818496  0.5666653  0.6097885  ... 0.7601096  0.47798562 0.7199864 ]]

 [[0.4112715  0.5293472  0.84149915 ... 0.47350776 0.34974098 0.45024145]
  [0.38375038 0.54324013 0.8284927  ... 0.47079813 0.35099322 0.4360806 ]
  [0.40054032 0.5157875  0.82854545 ... 0.50581276 0.3657748  0.4326033 ]
  [0.40492395 0.5375647

## Dependencies

In [5]:
!pip install session-info

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting session-info
  Downloading session_info-1.0.0.tar.gz (24 kB)
  Preparing metadata (setup.py) ... done
Collecting stdlib_list
  Downloading stdlib_list-0.8.0-py3-none-any.whl (63 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 63.5/63.5 kB 699.0 kB/s eta 0:00:00
Building wheels for collected packages: session-info
  Building wheel for session-info (setup.py) ... done
  Created wheel for session-info: filename=session_info-1.0.0-py3-none-any.whl size=8042 sha256=59944e38e2581e2ee17d531f4b5728fc828c92595a38ee3c6cf23951be259992
  Stored in directory: /root/.cache/pip/wheels/6a/aa/b9/eb5d4031476ec10802795b97ccf937b9bd998d68a9b268765a
Successfully built session-info
Installing collected packages: stdlib_list, session-info
Successfully installed session-info-1.0.0 stdlib_list-0.8.0


In [6]:
# Environment report: version information for the loaded modules, Python and
# the OS, as produced by the session_info package.
import session_info

session_info.show()