In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Layer, Dense


In [2]:
class DotProductAttention(Layer):
    """Scaled dot-product attention: ``softmax(q @ k^T / sqrt(d_k)) @ v``.

    The layer has no trainable weights; it only combines the tensors it is
    given.
    """

    def __init__(self, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)

    def call(self, q, k, v, mask=None):
        """Compute attention-weighted values.

        Args:
            q: Query tensor; its last dimension must match the last dimension of ``k``.
            k: Key tensor; its last dimension ``d_k`` is used for scaling.
            v: Value tensor; its second-to-last dimension must match that of ``k``.
            mask: Optional tensor broadcastable to the score tensor
                ``q @ k^T``. Positions where the mask is 1 receive a large
                negative score and therefore ~0 attention weight.

        Returns:
            ``softmax(scores) @ v``, where the softmax runs over the last axis
            (the key positions).
        """
        scores = tf.matmul(q, k, transpose_b=True)

        # Scale by sqrt(d_k) in the dtype of the scores themselves; a fixed
        # float32 cast breaks the division for float64 / float16 inputs.
        dk = tf.cast(tf.shape(k)[-1], scores.dtype)
        scores = scores / tf.math.sqrt(dk)

        if mask is not None:
            # Cast so bool/int masks (and float masks of a different dtype)
            # can be combined with the scores.
            # NOTE(review): -1e9 is not representable in float16; if this layer
            # is run in half precision with a mask, confirm the result is not NaN.
            scores += tf.cast(mask, scores.dtype) * -1e9

        # Normalise over the key axis.
        weights = tf.nn.softmax(scores, axis=-1)

        return tf.matmul(weights, v)
        

    

In [3]:
class MultiHeadAttention(Layer):
    """Multi-head attention built on ``DotProductAttention``.

    Queries, keys and values are each projected with their own ``Dense``
    layer, split into ``number_heads`` heads of size ``d_model // number_heads``,
    attended independently, merged back to ``d_model`` features and passed
    through a final ``Dense`` projection.

    Args:
        d_model: Width of the projections and of the output.
        number_heads: Number of attention heads; must divide ``d_model``.
        **kwargs: Standard Keras ``Layer`` keyword arguments (e.g. ``name``).

    Raises:
        AssertionError: If ``d_model`` is not divisible by ``number_heads``.
    """

    def __init__(self, d_model, number_heads, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)

        self.d_model = d_model
        self.num_heads = number_heads
        self.attention = DotProductAttention()  # Scaled dot product attention

        assert self.d_model % self.num_heads == 0, "d_model must be divisible by number_heads"

        # Feature size handled by each individual head.
        self.depth = self.d_model // self.num_heads

        # Separate learned projections for queries, keys and values, plus the
        # output projection applied after the heads are merged.
        self.wq = Dense(self.d_model)
        self.wk = Dense(self.d_model)
        self.wv = Dense(self.d_model)
        self.dense = Dense(self.d_model)

    def reshape_tensor(self, x, batch_size, flag=False):
        """Split ``x`` into heads or merge heads back together.

        Args:
            x: Tensor to reshape.
            batch_size: Size of the leading (batch) dimension.
            flag: If True, split: reshape to (batch, -1, num_heads, depth)
                and transpose to (batch, num_heads, seq, depth). If False,
                merge: transpose axes 1 and 2 back and reshape to
                (batch, -1, d_model).

        Returns:
            The reshaped tensor.
        """
        if flag:
            x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
            return tf.transpose(x, perm=(0, 2, 1, 3))

        x = tf.transpose(x, perm=(0, 2, 1, 3))
        return tf.reshape(x, (batch_size, -1, self.d_model))

    def call(self, q, k, v, mask=None):
        """Run multi-head attention.

        Args:
            q: Query tensor; the leading dimension is taken as the batch size.
            k: Key tensor.
            v: Value tensor.
            mask: Optional mask forwarded to the attention layer. It must be
                broadcastable to the per-head scores of shape
                (batch, num_heads, seq_q, seq_k); positions equal to 1 are
                masked out.

        Returns:
            Tensor of shape (batch, seq_q, d_model).
        """
        batch_size = tf.shape(q)[0]

        # Each input goes through its OWN projection (previously all three
        # were passed through self.wq, leaving self.wk / self.wv unused).
        wq = self.wq(q)
        wk = self.wk(k)
        wv = self.wv(v)

        q = self.reshape_tensor(wq, batch_size, True)
        k = self.reshape_tensor(wk, batch_size, True)
        v = self.reshape_tensor(wv, batch_size, True)

        # Forward the mask; it was previously dropped, so masking had no effect.
        scaled_dot_product = self.attention(q, k, v, mask=mask)

        output = self.reshape_tensor(scaled_dot_product, batch_size)

        return self.dense(output)
        
        
        
        
        

In [4]:
# Seed TensorFlow's global RNG so the sample tensors below are the same on
# every Restart & Run All.
tf.random.set_seed(42)

# Define the input tensor dimensions
batch_size = 2  # Number of sequences in the sample batch
d_model = 128  # Model dimensionality
num_heads = 8  # Number of attention heads
seq_len = 10   # Sequence length

# Sample query, key, and value tensors, each (batch_size, seq_len, d_model)
query = tf.random.normal((batch_size, seq_len, d_model))
key = tf.random.normal((batch_size, seq_len, d_model))
value = tf.random.normal((batch_size, seq_len, d_model))


# Create the Multi-Head Attention layer
mha = MultiHeadAttention(d_model, num_heads)

# Forward pass: arguments follow the layer's call signature (q, k, v).
# They were previously passed as (value, key, query), swapping q and v.
attention_output = mha(query, key, value)

# Sanity check: attention must preserve (batch_size, seq_len, d_model).
assert tuple(attention_output.shape) == (batch_size, seq_len, d_model)

print("Attention Output Shape:", attention_output.shape)  # (batch_size, seq_len, d_model)
print("Attention Ouputs:", attention_output)  

Attention Output Shape: (2, 10, 128)
Attention Ouputs: tf.Tensor(
[[[-0.32265615  0.9651206  -0.2806585  ... -0.10053803 -0.636615
   -0.17914836]
  [-0.07004663  0.63617736 -0.16153038 ...  0.38840243 -0.5573803
   -0.1173067 ]
  [ 0.6227718   0.8793129   0.599577   ...  0.24041915 -0.1250247
   -0.12637393]
  ...
  [ 0.3822587   0.67935157  0.09969366 ... -0.4497999  -0.10653894
    0.45871755]
  [ 0.3895616   0.8173699  -0.6279467  ...  0.88777405 -0.44590983
    0.8041307 ]
  [ 0.5237001   1.3001679  -0.1341544  ...  0.24196798  0.03595141
   -0.3759784 ]]

 [[-0.1789748   0.31377462  0.14176211 ... -0.23341519  0.21462253
   -0.36530882]
  [-0.5872561   0.88563955  0.1219231  ... -0.1828457  -0.46431234
   -0.66672117]
  [-0.915837   -0.48023933  0.02416225 ... -0.5908037  -0.38637185
    0.64776105]
  ...
  [-0.89791274  0.30910543  0.23455557 ... -0.43627554  0.20827626
   -0.0033973 ]
  [-0.40550196  0.54268426 -0.4642162  ... -0.3497911   0.243126
   -0.36116546]
  [-0.9663886

