In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Layer, Dense


In [2]:
class DotProductAttention(Layer):
    """Scaled dot-product attention: ``softmax(q @ k^T / sqrt(d_k)) @ v``.

    The layer creates no weights of its own; it only combines the tensors
    passed to ``call``.
    """

    def __init__(self, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)

    def call(self, q, k, v, mask=None):
        """Attend over ``v`` using similarity scores between ``q`` and ``k``.

        Args:
            q: Query tensor; its last axis must match the last axis of ``k``
                so that ``q @ k^T`` is defined.
            k: Key tensor; the size of its last axis is the scaling term d_k.
            v: Value tensor, combined using the attention weights.
            mask: Optional tensor broadcastable to the score tensor
                (``q @ k^T``). Positions where it is 1 get -1e9 added to
                their score, so they receive a near-zero softmax weight.

        Returns:
            ``matmul(softmax(scores), v)``.
        """
        matmul_qk = tf.matmul(q, k, transpose_b=True)

        # Cast d_k to the dtype of the scores instead of a hard-coded float32,
        # so the division below does not raise a dtype mismatch for inputs
        # such as float64.
        dk = tf.cast(tf.shape(k)[-1], matmul_qk.dtype)
        scores = matmul_qk / tf.math.sqrt(dk)

        if mask is not None:
            # Integer/boolean masks are cast so they can be added to the scores.
            scores += -1e9 * tf.cast(mask, scores.dtype)

        # Normalise over the key axis (the last axis of the scores).
        weights = tf.nn.softmax(scores, axis=-1)

        return tf.matmul(weights, v)
        

    

In [3]:
class MultiHeadAttention(Layer):
    """Multi-head attention built on top of ``DotProductAttention``.

    Queries, keys and values are each projected with their own ``Dense``
    layer, split into ``number_heads`` heads of size
    ``d_model // number_heads``, attended per head, merged back to
    ``d_model`` features and passed through a final ``Dense`` layer.

    Args:
        d_model: Size of the projected feature axis; must be divisible by
            ``number_heads``.
        number_heads: Number of attention heads.
        **kwargs: Forwarded to the Keras ``Layer`` constructor (e.g. ``name``).
    """

    def __init__(self, d_model, number_heads, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)

        self.d_model = d_model
        self.num_heads = number_heads

        assert self.d_model % self.num_heads == 0, "d_model must be divisible by number_heads"

        # Feature size handled by each individual head.
        self.depth = self.d_model // self.num_heads

        self.attention = DotProductAttention()  # Scaled dot product attention

        # Separate projections for queries, keys and values, plus the output
        # projection applied after the heads are merged.
        self.wq = Dense(self.d_model)
        self.wk = Dense(self.d_model)
        self.wv = Dense(self.d_model)
        self.dense = Dense(self.d_model)

    def reshape_tensor(self, x, batch_size, flag=False):
        """Split a tensor into heads or merge heads back together.

        Args:
            x: Tensor to reshape.
            batch_size: Size of the leading (batch) axis of ``x``.
            flag: If True, split: reshape to
                ``(batch_size, -1, num_heads, depth)`` and transpose to
                ``(batch_size, num_heads, -1, depth)``. If False, merge:
                apply the inverse transpose and reshape to
                ``(batch_size, -1, d_model)``.

        Returns:
            The reshaped tensor.
        """
        if flag:
            x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
            return tf.transpose(x, perm=(0, 2, 1, 3))

        x = tf.transpose(x, perm=(0, 2, 1, 3))
        return tf.reshape(x, (batch_size, -1, self.d_model))

    def call(self, q, k, v, mask=None):
        """Run multi-head attention.

        Args:
            q: Query tensor; its leading axis is taken as the batch size.
            k: Key tensor.
            v: Value tensor.
            mask: Optional mask handed to the attention layer unchanged. It
                must be broadcastable to the per-head score tensor, which has
                one axis for batch, one for heads, one for query positions
                and one for key positions.

        Returns:
            Tensor reshaped to ``(batch_size, -1, d_model)`` after the output
            projection.
        """
        batch_size = tf.shape(q)[0]

        # Each input goes through its OWN projection; keys and values must
        # not reuse the query weights.
        q = self.reshape_tensor(self.wq(q), batch_size, True)
        k = self.reshape_tensor(self.wk(k), batch_size, True)
        v = self.reshape_tensor(self.wv(v), batch_size, True)

        # Forward the mask so that it actually takes effect.
        attended = self.attention(q, k, v, mask=mask)

        output = self.reshape_tensor(attended, batch_size)

        return self.dense(output)
        
        
        
        
        

In [4]:
# Seed TensorFlow's global random generator so the tensors sampled with
# tf.random.normal below are reproducible across runs.
SEED = 42
tf.random.set_seed(SEED)

# Define the input tensor dimensions
d_model = 128   # Model dimensionality
num_heads = 8   # Number of attention heads
seq_len = 10    # Sequence length
batch_size = 2  # Number of sequences per batch

# Sample query, key, and value tensors, each (batch_size, seq_len, d_model)
query = tf.random.normal((batch_size, seq_len, d_model))
key = tf.random.normal((batch_size, seq_len, d_model))
value = tf.random.normal((batch_size, seq_len, d_model))


# Create the Multi-Head Attention layer
mha = MultiHeadAttention(d_model, num_heads)

# Forward pass -- arguments follow the layer's call signature (q, k, v)
attention_output = mha(query, key, value)

# Sanity check: the output keeps the (batch_size, seq_len, d_model) layout
assert tuple(attention_output.shape) == (batch_size, seq_len, d_model)

print("Attention Output Shape:", attention_output.shape)  # (batch_size, seq_len, d_model)
print("Attention Ouputs:", attention_output)  

Attention Output Shape: (2, 10, 128)
Attention Ouputs: tf.Tensor(
[[[ 0.16796461 -0.29680312 -0.36854327 ...  0.1213422  -0.01294411
    0.20422964]
  [ 0.06026412 -0.41810444 -0.45633212 ...  0.41323772  0.15755515
    0.12214798]
  [ 0.59621954 -0.20482707 -0.26585236 ... -0.00865346  0.6714901
    0.10502053]
  ...
  [-0.11432634 -0.43540606 -0.30294836 ...  0.5691043  -0.3535357
    0.5544609 ]
  [ 0.22376566 -0.22547275  0.07857489 ...  0.20396392 -0.2593416
   -0.01794769]
  [-0.19177414 -0.34396467 -0.22086892 ...  0.42266598 -0.1686663
   -0.06621613]]

 [[-0.0271943   0.890152    0.06651966 ...  0.03335216 -0.06357548
    0.14888097]
  [ 0.09474072  0.36804473  0.23109931 ... -0.5623231   0.21032071
    0.3945907 ]
  [ 0.23362786  0.18425468 -0.06610259 ...  0.1196131   0.09615794
   -0.48757905]
  ...
  [ 0.7888516  -0.16088219  0.20459506 ... -0.03058572 -0.21484427
   -0.5165455 ]
  [ 0.33547083 -0.35500762  0.20217553 ... -0.6294464   0.10174162
   -0.53016406]
  [ 0.49657

