In [1]:
from tensorflow import matmul, math, cast, float32
from tensorflow.keras.layers import Layer
from tensorflow.keras.backend import softmax

In [2]:
class DotProductAttention(Layer):
    """Keras layer that computes scaled dot-product attention.

    The layer holds no weights of its own; all of the work happens in
    ``call``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, queries, keys, values, d_k, mask=None):
        """Return the attention-weighted combination of ``values``.

        Args:
            queries: Query tensor; multiplied against the transpose of ``keys``.
            keys: Key tensor; its last two axes are transposed for the product.
            values: Value tensor combined using the attention weights.
            d_k: Scaling dimension; the raw scores are divided by ``sqrt(d_k)``.
            mask: Optional tensor. When given, ``-1e9 * mask`` is added to the
                scaled scores before the softmax, so entries where the mask is
                nonzero end up with a negligible weight.
                (Presumably broadcastable to the score tensor - confirm
                against the caller.)

        Returns:
            The matrix product of the softmax-normalised scores and ``values``.
        """
        # Similarity of every query with every key, scaled by sqrt(d_k).
        raw_scores = matmul(queries, keys, transpose_b=True)
        scaled_scores = raw_scores / math.sqrt(cast(d_k, float32))

        # A large negative offset suppresses masked entries in the softmax.
        if mask is not None:
            scaled_scores = scaled_scores + (-1e9 * mask)

        attention_weights = softmax(scaled_scores)
        return matmul(attention_weights, values)

Testing the implementation

In [3]:
# NOTE: the stray `from turtle import back` was removed. It was unused, and
# importing `turtle` requires tkinter, which is missing on many headless setups.
from numpy import random

# Dimensions used for the smoke test.
d_k = 64               # size of the last axis of the queries and keys
d_v = 64               # size of the last axis of the values
batch_size = 64
input_seq_length = 5

# Random inputs; the values are arbitrary, so the printed numbers differ per run.
queries = random.random((batch_size, input_seq_length, d_k))
keys = random.random((batch_size, input_seq_length, d_k))
values = random.random((batch_size, input_seq_length, d_v))

attention = DotProductAttention()
output = attention(queries, keys, values, d_k)

# One d_v-sized output vector is expected per query position.
assert tuple(output.shape) == (batch_size, input_seq_length, d_v)

print(output)

tf.Tensor(
[[[0.404837   0.6081327  0.4263291  ... 0.5371667  0.607482   0.591229  ]
  [0.40648323 0.60440564 0.41669473 ... 0.5330911  0.60841334 0.58539027]
  [0.42123097 0.5900426  0.42422932 ... 0.55082476 0.599424   0.5745475 ]
  [0.43376115 0.60014343 0.45548874 ... 0.5712822  0.5775714  0.57115483]
  [0.40480608 0.589259   0.39576858 ... 0.53021646 0.6200434  0.57937914]]

 [[0.73453027 0.48621315 0.59504896 ... 0.5679474  0.5870134  0.5708886 ]
  [0.72074604 0.49343777 0.58282137 ... 0.5863008  0.598035   0.5903868 ]
  [0.7164197  0.49830198 0.56411326 ... 0.5707411  0.6160005  0.5771624 ]
  [0.7267853  0.4943127  0.58252853 ... 0.58087105 0.5915294  0.57505494]
  [0.72391635 0.49807006 0.57659745 ... 0.58670205 0.59035337 0.5738055 ]]

 [[0.34524274 0.40637815 0.5660924  ... 0.5997388  0.82626635 0.47036666]
  [0.33530208 0.402887   0.56048715 ... 0.608022   0.81650347 0.47291195]
  [0.3346155  0.3957529  0.5769265  ... 0.6215534  0.820185   0.48420572]
  [0.35206616 0.4118525