In [1]:
import torch
import torch.nn as nn
import math

In [26]:
class SelfAttention(nn.Module):
    """Parameter-free scaled dot-product attention.

    The constructor only records the trailing (feature) dimension of each
    input shape; no learnable weights are created.

    Args:
        query_shape: Shape of the query tensors; the last entry is kept as ``d_q``.
        key_shape: Shape of the key tensors; the last entry is kept as ``d_k``
            and is the quantity the scores are scaled by.
        value_shape: Shape of the value tensors; the last entry is kept as ``d_v``.
        model_size: Stored as ``self.model_size``; not used by ``forward``.
    """

    def __init__(self, query_shape, key_shape, value_shape, model_size=512):
        super().__init__()
        self.d_q, self.d_k, self.d_v = query_shape[-1], key_shape[-1], value_shape[-1]
        self.model_size = model_size

        # Normalise scores over the last axis (one distribution per query row).
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, query, key, value, attention_mask=None):
        """Return ``softmax(query @ key^T / sqrt(d_k)) @ value``.

        Args:
            query, key, value: Input tensors; the last two axes of ``key`` are
                swapped before the matrix product with ``query``.
            attention_mask: Optional mask, converted to bool. Positions where
                it is truthy are filled with ``-inf`` before the softmax, so
                they receive zero weight.
        """
        scores = query @ key.transpose(-2, -1)
        scores = scores / math.sqrt(self.d_k)

        if attention_mask is not None:
            blocked = attention_mask.bool()
            scores = scores.masked_fill(blocked, float("-inf"))

        weights = self.softmax(scores)
        return weights @ value

In [32]:
# The example inputs are created in this cell so that it runs on a fresh
# kernel; without them `query`, `key` and `value` are undefined at this point
# in the notebook.
query = torch.randn(2, 3, 512)
key = torch.randn(2, 3, 512)
value = torch.randn(2, 3, 512)

# Self-attention: the same tensor is used as query, key and value.
attn = SelfAttention(query.shape, key.shape, value.shape)
output = attn(query, query, query)

In [35]:
# Display the tensor returned by the SelfAttention call.
output

tensor([[[ 2.0064,  0.2102,  0.8838,  ...,  0.0970,  0.6234, -1.1400],
         [-0.0885, -1.0831,  1.2557,  ..., -1.1865, -1.3762, -1.3190],
         [-0.1477,  0.5764, -0.0053,  ...,  1.2574,  1.0753, -0.3646]],

        [[ 1.7582,  0.0844, -1.0347,  ..., -1.9086,  0.8512,  0.2577],
         [ 1.1745,  0.5628,  0.5414,  ...,  1.1986, -0.5605,  0.3247],
         [-0.4274, -0.3267, -0.6601,  ...,  0.0339,  1.0798,  0.9191]]])

In [36]:
class Attention(nn.Module):
    """Single-head scaled dot-product attention with learned input projections.

    Query, key and value are each multiplied by their own weight matrix into a
    ``model_size``-wide space before attention is computed, so the three inputs
    may have different feature widths.

    Args:
        query_shape: Shape of the query tensors; only the last entry (feature
            width) is used.
        key_shape: Shape of the key tensors; only the last entry is used.
        value_shape: Shape of the value tensors; only the last entry is used.
        model_size: Width of the projected space and of the returned tensor.
    """

    def __init__(self,query_shape,key_shape,value_shape, model_size=512):
        super().__init__()
        self.d_q = query_shape[-1]
        self.d_k = key_shape[-1]
        self.d_v = value_shape[-1]
        self.model_size = model_size

        # Weights are laid out as (in_features, model_size) so that `x @ W` is
        # valid for any input width. A (model_size, in_features) layout only
        # multiplies cleanly when every width happens to equal model_size.
        self.W_q = self._init_weight(self.d_q, self.model_size)
        self.W_k = self._init_weight(self.d_k, self.model_size)
        self.W_v = self._init_weight(self.d_v, self.model_size)

        self.softmax = nn.Softmax(dim=-1)

    @staticmethod
    def _init_weight(in_features, out_features):
        """Return a Xavier-uniform initialised (in_features, out_features) parameter."""
        return nn.Parameter(nn.init.xavier_uniform_(torch.empty(in_features, out_features)))

    def forward(self, query, key, value, attention_mask=None):
        """Project the inputs and return ``softmax(Q K^T / sqrt(d)) V``.

        Args:
            query, key, value: Tensors whose last dimension is ``d_q``,
                ``d_k`` and ``d_v`` respectively.
            attention_mask: Optional mask, converted to bool. Positions where
                it is truthy are filled with ``-inf`` before the softmax, so
                they receive zero weight.

        Returns:
            Tensor whose last dimension is ``model_size``.
        """
        query_t = torch.matmul(query, self.W_q)
        key_t = torch.matmul(key, self.W_k)
        value_t = torch.matmul(value, self.W_v)

        # Scale by the width the dot product actually runs over, i.e. the
        # projected key width rather than the raw input width.
        scale = math.sqrt(key_t.size(-1))
        query_key = torch.matmul(query_t, key_t.transpose(-2, -1)) / scale
        if attention_mask is not None:
            query_key = query_key.masked_fill(attention_mask.bool(), -torch.inf)

        return torch.matmul(self.softmax(query_key), value_t)

In [37]:
# Three independent random example tensors of shape (2, 3, 512), drawn in the
# order query, key, value.
query, key, value = (torch.randn(2, 3, 512) for _ in range(3))

In [42]:
# Projected single-head attention, with `query` supplied as query, key and
# value (self-attention).
attn = Attention(
    query_shape=query.shape,
    key_shape=key.shape,
    value_shape=value.shape,
)
output = attn(query=query, key=query, value=query)

In [43]:
# Display the tensor returned by the projected Attention call.
output

tensor([[[ 0.5517, -0.2130,  0.1068,  ..., -0.0675,  0.0676, -0.3287],
         [ 0.5630, -0.2621,  0.1051,  ...,  0.0772,  0.0251, -0.2169],
         [ 0.1573,  0.1718,  0.3963,  ...,  0.3809,  0.1841, -0.5546]],

        [[-0.3631, -0.9195, -0.4250,  ...,  0.3630,  0.0185, -0.4522],
         [ 0.3783, -0.2530, -0.7024,  ..., -0.2203,  0.5756, -0.1280],
         [-0.1998, -0.4986, -0.5135,  ...,  0.0769, -0.4642,  1.2835]]],
       grad_fn=<UnsafeViewBackward0>)

In [38]:
class MultiHeadAttention(nn.Module):
    """Runs several ``Attention`` heads in parallel and mixes their outputs.

    Every head receives the same query/key/value; the head outputs are
    concatenated along the last axis and multiplied by ``W_O``.

    Args:
        query_shape: Shape of the query tensors, forwarded to each head.
        key_shape: Shape of the key tensors, forwarded to each head.
        value_shape: Shape of the value tensors, forwarded to each head.
        head_count: Number of attention heads to create.
        model_size: Width each head projects to, and width of the final output.
    """

    def __init__(self, query_shape,key_shape,value_shape, head_count, model_size=512):
        super().__init__()
        self.head_count = head_count
        self.model_size = model_size
        self.query_shape = query_shape
        self.key_shape = key_shape
        self.value_shape = value_shape

        # Each head returns `model_size` features, so the concatenation is
        # head_count * model_size wide. This equals head_count *
        # value_shape[-1] whenever the value width matches model_size.
        self.W_O = nn.Parameter(
            nn.init.xavier_uniform_(
                torch.empty(self.head_count * self.model_size, self.model_size)
            )
        )

        # nn.ModuleList (not a plain list) registers the heads as submodules,
        # so their weights appear in .parameters() / state_dict(), follow
        # .to(device), and are seen by optimisers.
        self.heads = nn.ModuleList(
            Attention(self.query_shape, self.key_shape, self.value_shape, self.model_size)
            for _ in range(self.head_count)
        )

    def forward(self, query, key, value, attention_mask=None):
        """Apply every head, concatenate on the last axis and project with ``W_O``.

        Args:
            query, key, value: Inputs passed unchanged to every head.
            attention_mask: Optional mask passed unchanged to every head;
                ``None`` (the default) applies no masking.

        Returns:
            Tensor whose last dimension is ``model_size``.
        """
        mh_p1 = torch.cat(
            [head(query, key, value, attention_mask) for head in self.heads], -1
        )
        mh_p2 = torch.matmul(mh_p1, self.W_O)
        return mh_p2

In [39]:
# Eight-head attention block sized from the example tensors.
attn = MultiHeadAttention(
    query_shape=query.shape,
    key_shape=key.shape,
    value_shape=value.shape,
    head_count=8,
    model_size=512,
)

In [40]:
# Run the multi-head block on the separate query, key and value tensors.
multi_head = attn(query=query, key=key, value=value)

In [41]:
# Display the multi-head attention result.
multi_head

tensor([[[-6.0698e-01, -3.3274e-01, -3.5579e-01,  ..., -1.9445e+00,
           1.0646e+00, -7.4402e-01],
         [-7.5650e-01, -4.9202e-01, -2.3502e-01,  ..., -1.5646e+00,
           6.3224e-01, -9.9018e-01],
         [-6.9253e-01, -5.1152e-02, -3.8043e-01,  ..., -1.2231e+00,
           6.1537e-01, -6.1394e-01]],

        [[ 5.3048e-01, -5.9201e-01, -6.9080e-01,  ..., -1.1865e+00,
           5.9823e-01, -6.3175e-01],
         [ 1.7772e-01,  4.1436e-01, -6.5130e-01,  ...,  7.5560e-01,
          -7.4086e-02,  8.8687e-01],
         [-7.3537e-01, -1.7250e-01, -7.4415e-01,  ..., -7.9224e-04,
          -4.8310e-01,  6.0002e-02]]], grad_fn=<UnsafeViewBackward0>)

In [25]:
# Shape check on the multi-head attention result.
multi_head.shape

torch.Size([2, 3, 512])