In [4]:
# Step 4: wrap multi-head attention (MHA) in a class. AttnHead stays a separate class because each head owns its own Wq/Wk/Wv weights.

import torch
import torch.nn as nn
import torch.nn.functional as F
from math import sqrt

class MHSelfAttn(nn.Module):
    """Multi-head self-attention assembled from independent ``AttnHead`` modules.

    Each head receives the full ``embed_dim``-wide input and produces
    ``embed_dim // num_attn_heads`` features. The head outputs are
    concatenated along the last dimension and passed through the output
    projection ``Wo`` (``embed_dim -> embed_dim``).

    Args:
        embed_dim: Size of the last dimension of the input and of the output.
        num_attn_heads: Number of attention heads. Must be positive and
            divide ``embed_dim`` evenly.

    Raises:
        ValueError: If ``num_attn_heads`` is not positive, or if it does not
            divide ``embed_dim``.
    """

    def __init__(self, embed_dim, num_attn_heads):
        super().__init__()

        if num_attn_heads <= 0:
            raise ValueError(
                f"num_attn_heads must be positive, got {num_attn_heads}"
            )
        # Floor division below would silently drop the remainder, leaving the
        # concatenated head outputs narrower than embed_dim and making Wo fail
        # with a shape error only once forward() runs. Fail early instead.
        if embed_dim % num_attn_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by "
                f"num_attn_heads ({num_attn_heads})"
            )

        attn_head_input_dim = embed_dim
        attn_head_output_dim = embed_dim // num_attn_heads

        # Output projection applied to the concatenated head outputs.
        self.Wo = nn.Linear(embed_dim, embed_dim)

        # nn.ModuleList (not a plain list) so every head's parameters are
        # registered with this module.
        self.attn_heads = nn.ModuleList(
            [AttnHead(attn_head_input_dim, attn_head_output_dim) for _ in range(num_attn_heads)]
        )

    def forward(self, input_data):
        """Run every head on ``input_data`` and project the concatenated result.

        Args:
            input_data: Tensor whose last dimension has size ``embed_dim``.

        Returns:
            Tensor produced by ``Wo``; its last dimension has size ``embed_dim``.
        """
        # Each head yields embed_dim // num_attn_heads features; concatenating
        # on the last axis restores a width of embed_dim.
        concatenated_head_context_vectors = torch.cat(
            [attn_head(input_data) for attn_head in self.attn_heads], dim=-1
        )
        multihead_context_vector = self.Wo(concatenated_head_context_vectors)

        return multihead_context_vector


# The rest is unchanged from step 3 - this is the AttnHead class from step 3:

class AttnHead(nn.Module):
    """Single causal self-attention head using scaled dot-product attention.

    Queries, keys and values are bias-free linear projections of the same
    input. Each position may attend only to itself and to earlier positions.

    Args:
        attn_head_input_dim: Size of the last dimension of the input.
        attn_head_output_dim: Size of the last dimension of the queries,
            keys, values and of the returned context vectors.
    """

    def __init__(self, attn_head_input_dim, attn_head_output_dim):
        super().__init__()

        # Per-head projection weights (no bias terms).
        self.Wq = nn.Linear(attn_head_input_dim, attn_head_output_dim, bias=False)
        self.Wk = nn.Linear(attn_head_input_dim, attn_head_output_dim, bias=False)
        self.Wv = nn.Linear(attn_head_input_dim, attn_head_output_dim, bias=False)

    def forward(self, input_data):
        """Compute causally masked attention over the second-to-last axis.

        Args:
            input_data: Tensor whose last dimension has size
                ``attn_head_input_dim``; the second-to-last dimension is
                treated as the sequence axis.

        Returns:
            Tensor shaped like ``input_data`` except that the last dimension
            has size ``attn_head_output_dim``.
        """
        queries = self.Wq(input_data)
        keys = self.Wk(input_data)
        values = self.Wv(input_data)

        # Scale the dot products by sqrt(d_k) before the softmax.
        dim_of_key = keys.size(-1)
        attn_scores = queries @ keys.transpose(-2, -1) / sqrt(dim_of_key)

        # Strict upper triangle marks "future" positions. Build the mask on the
        # same device as the scores so masked_fill also works after the module
        # and its input have been moved off the CPU.
        seq_len = attn_scores.size(-1)
        mask = torch.ones(
            (seq_len, seq_len), dtype=torch.bool, device=attn_scores.device
        ).triu(diagonal=1)

        # -inf scores become zero weights after the softmax.
        attn_scores = attn_scores.masked_fill(mask, float("-inf"))
        attn_weights = F.softmax(attn_scores, dim=-1)

        head_context_vector = attn_weights @ values
        return head_context_vector

In [5]:
# Smoke test (same setup as step 3): push one random batch through MHSelfAttn
# and check that the output keeps the input's shape.
seq_len = 512
embed_dim = 256
num_attn_heads = 2
batch_size = 16

# Seed the global RNG so the random input and the freshly initialised weights
# are identical on every Restart & Run All.
torch.manual_seed(0)

input_data = torch.randn(batch_size, seq_len, embed_dim)

mha = MHSelfAttn(embed_dim, num_attn_heads)

output = mha(input_data)

# Wo maps embed_dim -> embed_dim, so the output shape must match the input shape.
assert output.shape == input_data.shape, output.shape

output.shape

torch.Size([16, 512, 256])

In [None]:
# Step 5: plug this MHSelfAttn into the training loop