<a href="https://colab.research.google.com/github/ishammansoor/AI-and-Machine-Learning/blob/main/Multi_Head_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention with an output projection.

    Queries, keys and values are all linear projections of the same input
    ``x``. Attention scores are divided by ``sqrt(E)``, where ``E`` is the
    input's last dimension. No masking is applied, so every position attends
    to every position.

    Args:
        embed_dim: Size of the input's last (feature) dimension; also the
            size of the Q/K/V projections and of the output.
    """

    def __init__(self, embed_dim):
        super().__init__()

        self.embed_dim = embed_dim

        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)

        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        """Attend over the second-to-last dimension of ``x``.

        Args:
            x: Tensor of shape ``(..., T, E)``. Any number of leading batch
                dimensions (including none) is accepted, because every op
                below acts only on the last two dimensions.

        Returns:
            A tuple ``(output, attn_weights)``: ``output`` has shape
            ``(..., T, E)`` and ``attn_weights`` has shape ``(..., T, T)``,
            with each row of ``attn_weights`` summing to 1.
        """
        # Only the feature size is needed, for the score scaling.
        E = x.size(-1)

        # step1: project the input to queries, keys and values.
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)

        # step2: scaled dot-product scores, normalised over the key axis.
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / E ** 0.5
        attn_weights = F.softmax(attn_scores, dim=-1)

        # step3: weighted sum of values, then the output projection.
        attn_output = torch.matmul(attn_weights, V)
        output = self.out_proj(attn_output)

        return output, attn_weights




In [7]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention assembled from independent ``SelfAttention`` heads.

    The input's embedding dimension is split into ``num_heads`` contiguous
    chunks of ``head_dim = embed_dim // num_heads`` features. Chunk ``i`` is
    fed to ``self.heads[i]``. The per-head outputs are concatenated back to
    ``embed_dim`` features and passed through a final linear projection.

    NOTE(review): unlike the standard Transformer formulation, each head here
    only sees its own ``head_dim``-wide slice of the input. Its Q/K/V
    projections therefore do not mix information from the full
    ``embed_dim``. Each head also applies its own ``out_proj`` before the
    shared one below. Confirm this is intended.

    Args:
        embed_dim: Size of the last dimension of the input.
        num_heads: Number of heads; must divide ``embed_dim`` evenly.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()

        assert embed_dim % num_heads == 0, "Embedding dim must be divisible by num_heads"

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.heads = nn.ModuleList(
            [SelfAttention(self.head_dim) for _ in range(self.num_heads)]
        )

        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        """Run every head on its slice of ``x`` and merge the results.

        Args:
            x: Tensor of shape ``(B, T, embed_dim)``.

        Returns:
            A tuple ``(output, attn_weights_all)``. ``output`` has shape
            ``(B, T, embed_dim)``. ``attn_weights_all`` is a list with one
            attention-weight tensor per head, as returned by that head.

        Raises:
            AssertionError: If the last dimension of ``x`` is not ``embed_dim``.
        """
        B, T, E = x.shape
        assert E == self.embed_dim, f"Expected last dim {self.embed_dim}, got {E}"

        # reshape (not view) so non-contiguous inputs, e.g. a transposed
        # tensor, are accepted; view() would raise on them.
        # unbind(dim=2) yields num_heads tensors of shape (B, T, head_dim).
        head_inputs = x.reshape(B, T, self.num_heads, self.head_dim).unbind(dim=2)

        head_outputs = []
        attn_weights_all = []
        for head, head_input in zip(self.heads, head_inputs):
            out, attn_weights = head(head_input)
            head_outputs.append(out)
            attn_weights_all.append(attn_weights)

        concat = torch.cat(head_outputs, dim=-1)  # (B, T, embed_dim)
        output = self.out_proj(concat)

        return output, attn_weights_all



In [8]:
# Smoke test: 4 heads of width 8 over a batch of 2 sequences of length 5.
torch.manual_seed(0)  # make the random input and weights reproducible
x = torch.randn(2, 5, 32)
mha = MultiHeadAttention(embed_dim=32, num_heads=4)

out, weights = mha(x)
print(out.shape)         # (batch, seq_len, embed_dim)
print(len(weights))      # one attention map per head
print(weights[0].shape)  # (batch, seq_len, seq_len)

torch.Size([2, 5, 32])
4
torch.Size([2, 5, 5])
