In [None]:
import tensorflow as tf

In [15]:
class DotProductAttention(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)

    def call(self, queries, keys, values, d_k, mask=None):
        scores = tf.matmul(queries, keys, transpose_b=True) / tf.sqrt(d_k)

        # Effectively zeroing entries in softmax
        if mask is not None:
            scores += -1e9 * mask 

        weights = tf.nn.softmax(scores)

        return tf.matmul(weights, values)

        

In [16]:
d_k = 64
d_v = 64
batch_size = 64

## Encoded Q, K, V in d_k, d_v dimensional embeddings 

In [17]:
input_seq_length = 5

queries = tf.random.normal((batch_size, input_seq_length, d_k))
keys = tf.random.normal((batch_size, input_seq_length, d_k))
values = tf.random.normal((batch_size, input_seq_length, d_v))

In [18]:
attention = DotProductAttention()

In [19]:
attention(queries, keys, values, float(d_k))

<tf.Tensor: shape=(64, 5, 64), dtype=float32, numpy=
array([[[ 3.75479087e-02,  4.72116023e-01, -6.70002759e-01, ...,
          6.63412809e-01, -2.52603471e-01, -2.19621927e-01],
        [-3.02511733e-02,  1.93555504e-01, -9.69966292e-01, ...,
          9.31107283e-01, -2.16214523e-01,  5.26975095e-03],
        [-3.55486959e-01,  8.85416269e-01, -1.56694031e+00, ...,
          1.09475648e+00,  3.37226242e-01,  8.43522102e-02],
        [ 1.45715311e-01,  1.69898242e-01, -4.37818170e-01, ...,
          7.30177879e-01, -2.84205258e-01, -1.48013800e-01],
        [ 3.85785073e-01, -3.60571355e-01, -1.58017725e-02, ...,
          7.51868129e-01, -5.30208945e-01, -1.70763046e-01]],

       [[ 8.44509482e-01,  7.87269831e-01,  6.28012180e-01, ...,
         -8.55544209e-02,  4.41683620e-01,  1.23979166e-01],
        [ 3.44490707e-01,  2.38743469e-01,  1.76922619e-01, ...,
         -3.50720286e-01,  2.88799822e-01, -1.85766384e-01],
        [ 4.57370162e-01,  6.12573564e-01,  4.18371260e-01, ...