In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as f


In [2]:
class MultiHeadAttention(nn.Module):
  def __init__(self,input_dim,sequence_length,d_model,num_heads,batch_size):
    super().__init__()
    self.input_dim = input_dim
    self.sequence_length = sequence_length
    self.batch_size = batch_size
    self.d_model = d_model
    self.num_heads = num_heads
    self.head_dims = self.d_model // self.num_heads # head_dims = d_k(dimension of key vector) = d_v(dimension of value vector)
    self.qkv_layer = nn.Linear(in_features=self.input_dim,out_features=3*self.d_model)
    self.linear_layer = nn.Linear(in_features=self.d_model,out_features=self.d_model)

  def scaled_dot_product_attention(self,q,k,v,mask = None):
    d_k = self.head_dims
    scaled = torch.matmul(q,k.transpose(-2,-1)) / np.sqrt(d_k)  # shape (batch_size,num_heads,num_queries,num_kv) (num_queries == num_kv)
    if (mask is not None):
      scaled += mask
    attention = f.softmax(scaled,dim = -1) # shape (batch_size,num_heads,num_queries,num_kv) (num_queries == num_kv)
    values = torch.matmul(attention,v) # shape (batch_size,num_heads,num_queries,head_dims) (head_dims = d_v)
    return values,attention



  def forward(self,x,mask = None):
    qkv = self.qkv_layer(x) # shape (batch_size,num_queries,3*d_model)
    qkv = qkv.reshape(self.batch_size,self.sequence_length,self.num_heads,3*self.head_dims) # shape (batch_size,num_queries,num_heads,3*head_dims)
    qkv = torch.permute(qkv,(0,2,1,3))  # shape (batch_size,num_heads,num_queries,3*head_dims)
    q,k,v = torch.chunk(qkv,3,dim= -1)  # each shape (batch_size,num_heads,num_queries,head_dims)
    values,attention = self.scaled_dot_product_attention(q,k,v,mask = mask)
    values = values.reshape(self.batch_size,self.sequence_length,self.head_dims * self.num_heads) # shape (batch_size,num_queries,head_dims * num_heads)
    out = self.linear_layer(values) # shape (batch_size,num_queries,d_model)
    return out





In [3]:
batch_size = 32
sequence_length = 20
input_dim = 250
d_model = 512
num_heads = 8
x = torch.randn((batch_size,sequence_length,input_dim)) ## (batch_size,num_queries,emb_dim)


In [4]:
case1 = MultiHeadAttention(input_dim,sequence_length,d_model,num_heads,batch_size)
out2 = case1(x)
out2.size()

torch.Size([32, 20, 512])

In [6]:
out2[0][0][:20]

tensor([-0.0318, -0.0378,  0.0708,  0.0550,  0.0378,  0.1422, -0.2409, -0.0942,
         0.0319, -0.0276, -0.0042, -0.0161, -0.0702,  0.1295,  0.0467,  0.0841,
         0.0282,  0.1663, -0.1363,  0.0547], grad_fn=<SliceBackward0>)