In [2]:
import torch as th
from torch import nn

In [28]:
class multihead_attention(nn.Module):
    """Causal (masked) multi-head self-attention.

    The input is projected to queries, keys and values, split into
    ``n_heads`` heads of size ``n_dim // n_heads``, attended with an
    upper-triangular mask so a token cannot see later tokens, and the
    per-head results are concatenated back to ``n_dim`` features.

    Args:
        contenxt_length: Size of the square causal mask, i.e. the maximum
            number of tokens ``forward`` can handle. (The spelling is kept
            so existing keyword callers continue to work.)
        n_dim: Embedding size of the input and of the output.
        n_heads: Number of attention heads; must divide ``n_dim`` evenly.
        dropout: Dropout probability applied to the attention weights.
            Defaults to 0.5, the value previously hard-coded in ``forward``.

    Raises:
        ValueError: If ``n_dim`` is not divisible by ``n_heads``.
    """

    def __init__(self, contenxt_length, n_dim, n_heads, dropout=0.5):
        super().__init__()
        # Fail fast: otherwise the mismatch only surfaces later as an
        # opaque shape error inside ``view`` during the forward pass.
        if n_dim % n_heads != 0:
            raise ValueError(
                f"n_dim ({n_dim}) must be divisible by n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        self.n_dim = n_dim
        self.head_dim = n_dim // n_heads
        self.q = nn.Linear(n_dim, n_dim)
        self.k = nn.Linear(n_dim, n_dim)
        self.v = nn.Linear(n_dim, n_dim)
        # Registered as a submodule so that model.train() / model.eval()
        # toggle it; a Dropout created inside forward() is always active.
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the diagonal mark the "future" positions.
        self.register_buffer(
            "mask",
            th.triu(th.ones(contenxt_length, contenxt_length), diagonal=1),
        )

    def forward(self, x):
        """Compute the attention context vectors for ``x``.

        Args:
            x: Tensor of shape ``(batch, num_tokens, n_dim)``. ``num_tokens``
                must not exceed the ``contenxt_length`` given at
                construction, because the mask is sliced to
                ``num_tokens x num_tokens``.

        Returns:
            Tensor of shape ``(batch, num_tokens, n_dim)``.
        """
        b, num_tokens, _ = x.shape

        query = self.q(x)  # (batch, tokens, n_dim)
        keys = self.k(x)
        values = self.v(x)

        # Split the feature axis into heads and bring the head axis forward:
        # (batch, tokens, n_dim) -> (batch, heads, tokens, head_dim)
        query = query.view(b, num_tokens, self.n_heads, self.head_dim).transpose(1, 2)
        keys = keys.view(b, num_tokens, self.n_heads, self.head_dim).transpose(1, 2)
        values = values.view(b, num_tokens, self.n_heads, self.head_dim).transpose(1, 2)

        # Per-head attention scores: (batch, heads, tokens, tokens)
        attention_score = query @ keys.transpose(2, 3)

        # Causal mask: future positions get -inf so softmax gives them weight 0.
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attention_score = attention_score.masked_fill(mask_bool, -th.inf)

        # Scale by sqrt(head_dim) before the softmax over the key axis.
        attention_weight = th.softmax(attention_score / self.head_dim ** 0.5, dim=-1)
        attention_weight = self.dropout(attention_weight)

        # Context vectors, moved back to (batch, tokens, heads, head_dim)
        context_vector = (attention_weight @ values).transpose(1, 2)

        # Combine heads: (batch, tokens, n_dim)
        return context_vector.contiguous().view(b, num_tokens, self.n_dim)
         
        
        

In [30]:
# Small smoke-test configuration: 7-token context, 36-dim embeddings, 3 heads.
model = multihead_attention(contenxt_length=7, n_dim=36, n_heads=3)

In [31]:
# Random input batch: 2 sequences x 7 tokens x 36 features, uniform in [0, 1).
dummy = th.rand(2, 7, 36)

In [32]:
# Sanity check: the output should keep the input shape (batch, tokens, n_dim).
model(dummy).size()

torch.Size([2, 7, 36])