In [37]:
from tensorflow.keras.layers import Layer,Dense
from tensorflow import matmul,math,cast,float32,reshape,shape,transpose
from tensorflow.keras.backend import softmax
class DotProductAttention(Layer):
  """Scaled dot-product attention layer.

  Computes softmax(queries @ keys^T / sqrt(d_k)) @ values, optionally adding
  a large negative penalty to the logits wherever `mask` is non-zero.
  """
  def __init__(self,**kwargs):
    super().__init__(**kwargs)
  def call(self,queries,keys,values,d_k,mask=None):
    """Apply attention of `queries` over `keys` and return the weighted `values`.

    Args:
      queries, keys, values: tensors accepted by `matmul`; `keys` is
        transposed on its last two axes before the product.
      d_k: scalar whose square root divides the raw scores.
      mask: optional tensor that is multiplied by -1e9 and added to the
        scaled scores; skipped when None.

    Returns:
      The product of the softmax-normalised scores with `values`.
    """
    scale=math.sqrt(cast(d_k,float32))
    logits=matmul(queries,keys,transpose_b=True)/scale
    if mask is not None:
      # Non-zero mask entries push their logit towards -1e9, so softmax
      # assigns them (almost) no weight.
      logits=logits+(-1e9*mask)
    attention=softmax(logits)
    return matmul(attention,values)
class MultiHeadAttention(Layer):
  """Multi-head attention built on DotProductAttention and four Dense projections.

  Queries, keys and values are projected with W_q / W_k / W_v, split into `h`
  heads, passed through DotProductAttention, merged back into one feature
  axis and finally projected with W_o.

  Args:
    h: number of heads the projected feature axis is split into.
    d_k: output width of the query and key projections (all heads together).
    d_v: output width of the value projection (all heads together).
    d_model: output width of the final projection W_o.
    **kwargs: forwarded to the Keras `Layer` constructor.
  """
  def __init__(self,h,d_k,d_v,d_model,**kwargs):
    super(MultiHeadAttention,self).__init__(**kwargs)
    self.attention=DotProductAttention()
    self.head=h
    self.d_k=d_k
    self.d_v=d_v
    self.d_model=d_model
    self.W_q=Dense(d_k)
    self.W_k=Dense(d_k)
    self.W_v=Dense(d_v)
    self.W_o=Dense(d_model)
  def reshape_tensor(self,x,head,flag):
    """Split the feature axis into heads, or merge the heads back.

    Args:
      x: tensor to rearrange.
      head: number of heads used when splitting.
      flag: truthy -> split  (batch, seq, features) into
              (batch, head, seq, features/head);
            falsy  -> merge  (batch, head, seq, depth) into
              (batch, seq, self.d_v).

    Returns:
      The rearranged tensor.
    """
    if flag:
      x=reshape(x,shape=(shape(x)[0],shape(x)[1],head,-1))
      x=transpose(x,perm=(0,2,1,3))
    else:
      x=transpose(x,perm=(0,2,1,3))
      # The tensor merged here is the attention output, whose per-head depth
      # comes from the value projection, so the merged width is d_v. Using
      # d_k (as before) only worked when d_k == d_v and raised a reshape
      # error otherwise.
      x=reshape(x,shape=(shape(x)[0],shape(x)[1],self.d_v))
    return x
  def call(self,queries,keys,values,mask=None):
    """Run multi-head attention and return the W_o projection of the merged heads.

    Args:
      queries, keys, values: input tensors fed to W_q, W_k and W_v.
      mask: optional mask forwarded unchanged to DotProductAttention.
    """
    q_reshape=self.reshape_tensor(self.W_q(queries),self.head,True)
    k_reshape=self.reshape_tensor(self.W_k(keys),self.head,True)
    v_reshape=self.reshape_tensor(self.W_v(values),self.head,True)
    # NOTE(review): the scale handed to the attention is the full projection
    # width self.d_k, while each head only holds self.d_k/self.head features
    # after the split -- confirm this is the intended scaling.
    o_reshape=self.attention(q_reshape,k_reshape,v_reshape,d_k=self.d_k,mask=mask)
    output=self.reshape_tensor(o_reshape,self.head,False)
    return self.W_o(output)

In [38]:
from numpy import random
# Toy problem sizes used to smoke-test the attention layer.
h=8                  # number of attention heads
batch_size=65
input_seq_length=5
d_model=512          # width of the layer's final output projection
d_k=64               # feature width of queries/keys (also the projection width)
d_v=64               # feature width of values (also the projection width)
# Draw the inputs from a seeded generator so that re-running this cell always
# produces the same arrays (uniform floats in [0, 1), as before).
# NOTE(review): only the inputs are seeded here; the layer's own weights are
# presumably initialised randomly by Keras, so downstream outputs can still
# differ between runs unless TensorFlow is seeded as well -- confirm.
seed=42
rng=random.default_rng(seed)
queries=rng.random((batch_size,input_seq_length,d_k))
keys=rng.random((batch_size,input_seq_length,d_k))
values=rng.random((batch_size,input_seq_length,d_v))

In [39]:
# Build the layer from the hyperparameters defined in the previous cell and run
# it once on the random inputs; the returned tensor is the cell's displayed value.
multiheadattention=MultiHeadAttention(h=h,d_k=d_k,d_v=d_v,d_model=d_model)
multiheadattention(queries,keys,values)

<tf.Tensor: shape=(65, 5, 512), dtype=float32, numpy=
array([[[-2.09929975e-04, -1.35063753e-01,  8.57737754e-03, ...,
         -5.60619950e-01,  2.46605952e-04,  2.50632633e-02],
        [-8.39318789e-04, -1.35312438e-01,  5.77640254e-03, ...,
         -5.61561406e-01,  1.50563207e-03,  2.13867649e-02],
        [-6.12106523e-04, -1.34001926e-01,  7.88539369e-03, ...,
         -5.62368453e-01, -2.00813753e-04,  2.41375770e-02],
        [ 1.02884864e-04, -1.33864313e-01,  7.87307229e-03, ...,
         -5.62577963e-01,  1.62191840e-03,  2.34233178e-02],
        [ 4.50047315e-04, -1.35322511e-01,  7.94218760e-03, ...,
         -5.61856627e-01,  2.37617129e-03,  2.26644929e-02]],

       [[ 7.61391744e-02, -1.24775805e-01,  4.18224074e-02, ...,
         -4.08251643e-01, -1.65483616e-02,  2.77380459e-02],
        [ 7.64259920e-02, -1.22853100e-01,  4.23289947e-02, ...,
         -4.06363100e-01, -1.35547323e-02,  3.02770548e-02],
        [ 7.73765147e-02, -1.24992512e-01,  4.11587618e-02, ..