In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Layer, Dense


In [2]:
class DotProductAttention(Layer):
    """Scaled dot-product attention: ``softmax(q @ k^T / sqrt(d_k)) @ v``."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, q, k, v, mask=None):
        """Attend over ``v`` using the similarity between ``q`` and ``k``.

        Args:
            q: Query tensor.
            k: Key tensor; its last two axes are transposed before the
                product with ``q``, and the size of its last axis is the
                scaling depth ``d_k``.
            v: Value tensor, combined using the attention weights.
            mask: Optional additive mask. It is multiplied by ``-1e9`` and
                added to the scaled scores, so non-zero entries push the
                corresponding weights towards zero after the softmax.

        Returns:
            The product of the softmax-normalised scores with ``v``.
        """
        # Raw similarity between every query and every key.
        logits = tf.matmul(q, k, transpose_b=True)

        # Scale by sqrt(d_k), where d_k is the size of the last axis of k.
        key_depth = tf.cast(tf.shape(k)[-1], tf.float32)
        logits = logits / tf.math.sqrt(key_depth)

        if mask is not None:
            # A large negative offset drives masked positions to ~0 weight.
            logits = logits + (-1e9 * mask)

        # Normalise over the last axis (the softmax default).
        attention = tf.nn.softmax(logits, axis=-1)
        return tf.matmul(attention, v)
        

    

In [3]:
class MultiHeadAttention(Layer):
    """Multi-head attention built on ``DotProductAttention``.

    The query, key and value inputs are each projected by their own
    ``Dense(d_model)`` layer, split into ``number_heads`` heads of size
    ``d_model // number_heads``, attended independently per head, merged
    back to ``d_model`` features and passed through a final ``Dense`` layer.

    Args:
        d_model: Feature size of the projections and of the output.
        number_heads: Number of attention heads; must divide ``d_model``.
        **kwargs: Forwarded to ``Layer.__init__`` (e.g. ``name``, ``dtype``).

    Raises:
        AssertionError: If ``d_model`` is not divisible by ``number_heads``.
    """

    def __init__(self, d_model, number_heads, **kwargs):
        super().__init__(**kwargs)

        assert d_model % number_heads == 0, "d_model must be divisible by number_heads"

        self.d_model = d_model
        self.num_heads = number_heads
        self.depth = d_model // number_heads  # features per head
        self.attention = DotProductAttention()  # Scaled dot product attention

        # Separate learned projections for queries, keys and values, plus the
        # output projection applied after the heads are merged.
        self.wq = Dense(d_model)
        self.wk = Dense(d_model)
        self.wv = Dense(d_model)
        self.dense = Dense(d_model)

    def reshape_tensor(self, x, batch_size, flag=False):
        """Split a tensor into heads, or merge heads back together.

        Args:
            x: Tensor to reshape.
            batch_size: Size of the leading (batch) axis.
            flag: If True, split: reshape to
                ``(batch_size, -1, num_heads, depth)`` and transpose to
                ``(batch_size, num_heads, -1, depth)``. If False, merge:
                undo that transpose and reshape to
                ``(batch_size, -1, d_model)``.

        Returns:
            The reshaped tensor.
        """
        if flag:
            x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
            return tf.transpose(x, perm=(0, 2, 1, 3))

        x = tf.transpose(x, perm=(0, 2, 1, 3))
        return tf.reshape(x, (batch_size, -1, self.d_model))

    def call(self, q, k, v, mask=None):
        """Apply multi-head attention.

        Args:
            q: Query tensor; its leading axis is taken as the batch size.
            k: Key tensor.
            v: Value tensor.
            mask: Optional additive mask forwarded to the scaled dot-product
                attention. It is added to per-head scores, so it must
                broadcast against ``(batch, num_heads, seq_q, seq_k)``;
                non-zero entries are suppressed.

        Returns:
            Tensor of shape ``(batch, seq_q, d_model)``.
        """
        batch_size = tf.shape(q)[0]

        # Each input goes through its OWN projection before the head split.
        q = self.reshape_tensor(self.wq(q), batch_size, True)
        k = self.reshape_tensor(self.wk(k), batch_size, True)
        v = self.reshape_tensor(self.wv(v), batch_size, True)

        # Forward the mask so it is actually applied to the attention scores.
        attended = self.attention(q, k, v, mask=mask)

        merged = self.reshape_tensor(attended, batch_size)
        return self.dense(merged)
        
        
        
        
        

In [5]:
# Define the input tensor dimensions
d_model = 128   # Model dimensionality
num_heads = 8   # Number of attention heads
seq_len = 10    # Sequence length
batch_size = 2  # Number of sequences per batch

# Sample query, key, and value tensors, each (batch_size, seq_len, d_model)
query = tf.random.normal((batch_size, seq_len, d_model))
key = tf.random.normal((batch_size, seq_len, d_model))
value = tf.random.normal((batch_size, seq_len, d_model))

# Create the Multi-Head Attention layer
mha = MultiHeadAttention(d_model, num_heads)

# Forward pass: arguments follow the call(q, k, v) signature order
attention_output = mha(query, key, value)

print("Attention Output Shape:", attention_output.shape)  # (batch_size, seq_len, d_model)
print("Attention Ouputs:", attention_output)

Attention Output Shape: (2, 10, 128)
Attention Ouputs: tf.Tensor(
[[[ 0.03564724 -0.13026899 -0.2905817  ...  0.09156609  0.31391734
    0.7958979 ]
  [ 0.0491679   0.11829799 -0.813233   ...  0.59438455  0.33297265
    0.45926914]
  [-0.24264418  0.14396743 -0.41736877 ... -0.15397848  0.2737037
    0.23678291]
  ...
  [-0.3663321   0.11688503 -0.95523506 ... -0.03410807  0.5397025
    0.3563132 ]
  [-0.30889088  0.38681743 -0.62830323 ... -0.08417644 -0.01798632
    0.6061058 ]
  [-0.23371796  0.6275364  -0.12398793 ...  0.3999416   0.2280209
    1.0406909 ]]

 [[ 0.67229265 -0.58834255  0.17819001 ...  0.39587396  0.5937276
   -0.90302455]
  [ 0.5271286  -0.27466798  0.05163426 ...  0.3766877   0.06273318
   -0.08011136]
  [ 0.43079412 -0.72020763 -0.10172926 ...  0.35286143  1.2609713
   -0.88027924]
  ...
  [ 0.86087286 -0.19494893  0.04824557 ...  0.09723626  0.60974616
   -0.56391335]
  [-0.286786   -0.71712863  0.22400053 ... -0.25928017  1.113314
   -0.15178515]
  [-0.09421124

