In [None]:
import torch
from torch import nn,Tensor

In [None]:
class MHA(nn.Module):
    """Multi-head self-attention using scaled dot-product attention.

    Queries, keys and values are all projected from the same input ``x``,
    split into ``num_heads`` heads of size ``d_model // num_heads``, attended
    independently, then concatenated and passed through an output projection.

    Args:
        d_model: Feature dimension of the input and output.
        num_heads: Number of attention heads; must be positive and evenly
            divide ``d_model``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        # Integer division: ``view`` needs an int size, and ``/`` would yield a float.
        self.head_dim = d_model // num_heads

        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: Tensor, attention_mask: "Tensor | None" = None) -> Tensor:
        """Apply multi-head self-attention to ``x``.

        Args:
            x: Input whose last dimension is ``d_model``; it is reshaped as
                (batch, seq_len, d_model).
            attention_mask: Optional mask that must broadcast against the
                attention scores of shape (batch, num_heads, seq_len, seq_len).
                Positions where the mask equals 0 are excluded from attention.

        Returns:
            Tensor of shape (batch, seq_len, d_model).
        """
        batch_size = x.shape[0]

        def split_heads(t: Tensor) -> Tensor:
            # (batch, seq, d_model) -> (batch, heads, seq, head_dim)
            return t.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.q_proj(x))
        k = split_heads(self.k_proj(x))
        v = split_heads(self.v_proj(x))

        # Scale by sqrt(head_dim) (the per-head key size), as in scaled
        # dot-product attention; a plain Python float avoids building a tensor.
        scores = torch.matmul(q, k.transpose(-1, -2)) / (self.head_dim ** 0.5)

        if attention_mask is not None:
            # Use the most negative finite value (not +inf) so masked
            # positions get ~0 weight after softmax, and a fully-masked row
            # yields a uniform distribution instead of NaN.
            scores = scores.masked_fill(
                attention_mask == 0, torch.finfo(scores.dtype).min
            )

        weights = torch.softmax(scores, dim=-1)
        out = torch.matmul(weights, v)

        # (batch, heads, seq, head_dim) -> (batch, seq, d_model)
        out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.o_proj(out)
        
        