
**Multi-Headed Attention** extends self-attention. Instead of a single attention mechanism, it runs several attention mechanisms (heads) in parallel, each with its own learned projections. Each head can focus on different aspects of the input, which lets the model capture several kinds of relationships between words at once.
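
PyTorch also ships a ready-made module, `torch.nn.MultiheadAttention`. The sketch below only shows its input and output shapes; the sizes (`embed_dim=8`, `num_heads=2`) are arbitrary choices for illustration, and the `batch_first` and `average_attn_weights` arguments assume a reasonably recent PyTorch release.

```python
import torch
import torch.nn as nn

embed_dim, num_heads = 8, 2      # illustrative sizes; embed_dim must be divisible by num_heads
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

x = torch.randn(1, 3, embed_dim)  # (batch, seq_len, embed_dim): a 3-token sequence

# Self-attention: query, key and value are all the same sequence.
# average_attn_weights=False keeps one attention map per head instead of averaging them.
out, weights = mha(x, x, x, need_weights=True, average_attn_weights=False)

print(out.shape)      # torch.Size([1, 3, 8])
print(weights.shape)  # torch.Size([1, 2, 3, 3]): one 3x3 attention map per head
```

The notebook builds the same mechanism by hand so each step stays visible.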

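**Problem with Single-Head Attention**

A single attention head computes one set of attention weights per word, so each word's output is a single weighted average of the other words' values. That makes it hard to capture several kinds of relationships at once. For example, the head might focus on subject-verb relationships but miss object-adjective dependencies.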
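
A small numeric illustration of this limitation, using hand-picked value vectors and scores rather than learned ones: a single head's weights form one softmax distribution that sums to 1, so attending sharply to one word leaves almost no weight for another. Two heads, each with its own scores, can each attend sharply to a different word, and concatenating their outputs keeps both signals.

```python
import torch
import torch.nn.functional as F

# Hand-picked value vectors for a 3-token sequence (illustrative, not learned)
V = torch.tensor([[1.0, 0.0],   # token 0, e.g. the subject
                  [0.0, 1.0],   # token 1, e.g. an object
                  [0.0, 0.0]])  # token 2

# Single head: one score row for the query token, peaked on token 0
single = F.softmax(torch.tensor([6.0, 0.0, 0.0]), dim=-1)
print(single.sum())   # the weights always sum to 1
print(single @ V)     # dominated by token 0; token 1 contributes very little

# Two heads with their own (hand-picked) scores for the same query token
head1 = F.softmax(torch.tensor([6.0, 0.0, 0.0]), dim=-1)  # attends to token 0
head2 = F.softmax(torch.tensor([0.0, 6.0, 0.0]), dim=-1)  # attends to token 1
combined = torch.cat([head1 @ V, head2 @ V])               # concatenation keeps both
print(combined)
```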

**Solution: Multi-Head Attention**

Each head has its own Q, K, and V projections, so it can learn a different representation of the input.

Combining the heads' outputs lets the model capture several contextual relationships simultaneously.
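
To see why separate projections matter, the sketch below runs two heads over the notebook's toy embeddings for "The", "cat", "sat". The 2x1 projection matrices are hypothetical and hand-picked (head 1 reads embedding dimension 0, head 2 reads dimension 1); in a real model they are learned. The same input produces different attention patterns in each head.

```python
import torch
import torch.nn.functional as F

# Same toy embeddings as the notebook: "The", "cat", "sat"
x = torch.tensor([[1.0, 0.5],
                  [0.8, 0.2],
                  [0.9, 0.7]])

def head_weights(W_q, W_k):
    """Attention weights of one head, given that head's own query/key projections."""
    Q, K = x @ W_q, x @ W_k
    scores = Q @ K.T / Q.shape[-1] ** 0.5  # scaled dot product
    return F.softmax(scores, dim=-1)

# Hypothetical hand-picked projections (2 -> 1 dimensions)
W_dim0 = torch.tensor([[4.0], [0.0]])  # head 1: looks only at embedding dim 0
W_dim1 = torch.tensor([[0.0], [4.0]])  # head 2: looks only at embedding dim 1

w1 = head_weights(W_dim0, W_dim0)
w2 = head_weights(W_dim1, W_dim1)
print(w1)  # "The" attends mostly to itself
print(w2)  # "The" attends mostly to "sat"
```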
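🔹 Example: "The cat sat on the mat."

The head roles below are illustrative only: heads are not assigned roles explicitly, and any specialization they show emerges during training.

* **Head 1** might focus on relations between nouns (cat ↔ mat).
* **Head 2** might focus on grammatical structure (sat ↔ on).
* **Head 3** might focus on subject-verb relations (cat ↔ sat).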

By combining these perspectives, the model can build a richer representation of the sentence than any single head would.


**Multi-Head Attention steps**
1. Convert the input words to embeddings.
2. Project the embeddings into separate Q, K, and V for each head (each head has its own projection weights).
3. Each head computes its own scaled dot-product attention scores and weighted outputs.
4. Concatenate the head outputs and combine them into one representation with a final linear layer.
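
In the usual notation, with $X$ the (seq_len × $d_{model}$) matrix of embeddings, $h$ heads, and $d_k = d_{model}/h$ (this is `self.depth` in the class below), the steps above read:

$$
\begin{aligned}
Q_i &= X W_i^Q,\quad K_i = X W_i^K,\quad V_i = X W_i^V, \qquad i = 1,\dots,h \\
\text{head}_i &= \operatorname{softmax}\!\left(\frac{Q_i K_i^\top}{\sqrt{d_k}}\right) V_i \\
\operatorname{MultiHead}(X) &= \operatorname{Concat}(\text{head}_1, \dots, \text{head}_h)\, W^O
\end{aligned}
$$

Each $W_i^Q, W_i^K, W_i^V$ is $d_{model} \times d_k$ and $W^O$ is $(h\,d_k) \times d_{model}$. In the class below, the per-head projections are stored side by side in one `nn.Linear(mode, mode)` layer per Q, K, and V, and `split_heads` separates them afterwards.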




```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Example embeddings (word representations), hand-written for illustration
token_embeddings = torch.tensor([
    [1.0, 0.5],  # "The"
    [0.8, 0.2],  # "cat"
    [0.9, 0.7],  # "sat"
], dtype=torch.float32).unsqueeze(0)  # Shape: (1, 3, 2) -> (batch, seq_length, embedding_size)
```
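
The embeddings above are written by hand. In a trained model, step 1 is typically a learned lookup table such as `torch.nn.Embedding`. The sketch below assumes a hypothetical five-word vocabulary; real models use a tokenizer with a far larger one. It stores the result as `learned_embeddings` so it does not overwrite `token_embeddings`.

```python
import torch
import torch.nn as nn

# Hypothetical toy vocabulary: word -> integer id
vocab = {"the": 0, "cat": 1, "sat": 2, "on": 3, "mat": 4}

# Learnable lookup table: one 2-dim vector per vocabulary entry (randomly initialized)
embedding = nn.Embedding(num_embeddings=len(vocab), embedding_dim=2)

tokens = ["the", "cat", "sat"]
ids = torch.tensor([[vocab[t] for t in tokens]])  # (batch, seq_length) = (1, 3)
learned_embeddings = embedding(ids)               # (batch, seq_length, embedding_size)
print(learned_embeddings.shape)  # torch.Size([1, 3, 2])
```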


```python
class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention.

    mode: model (embedding) dimension, often written d_model.
    num_heads: number of attention heads; mode must be divisible by num_heads.
    """

    def __init__(self, mode=2, num_heads=2):
        super().__init__()
        assert mode % num_heads == 0, "mode must be divisible by num_heads"
        self.num_heads = num_heads
        self.mode = mode
        self.depth = mode // num_heads  # Size of each head's Q, K, V

        # Learnable Q, K, V projections. Each is one (mode x mode) layer whose
        # output is later split into num_heads chunks, one chunk per head.
        self.W_q = nn.Linear(mode, mode, bias=False)
        self.W_k = nn.Linear(mode, mode, bias=False)
        self.W_v = nn.Linear(mode, mode, bias=False)

        # Final linear layer that mixes the concatenated head outputs
        self.W_o = nn.Linear(mode, mode, bias=False)

    def split_heads(self, x):
        """
        Split the last dimension into (num_heads, depth) and transpose to
        shape (batch, num_heads, seq_len, depth).
        """
        batch_size, seq_length, _ = x.shape
        x = x.view(batch_size, seq_length, self.num_heads, self.depth)
        return x.transpose(1, 2)  # (batch, num_heads, seq_length, depth)

    def forward(self, x):
        """
        x: input token embeddings, shape (batch_size, seq_length, mode).
        Returns the output, shape (batch_size, seq_length, mode), and the
        attention weights, shape (batch_size, num_heads, seq_length, seq_length).
        """
        batch_size, seq_length, _ = x.shape

        # Compute Query, Key, Value matrices for all heads at once
        Q = self.W_q(x)  # (batch_size, seq_length, mode)
        K = self.W_k(x)
        V = self.W_v(x)

        # Split into multiple heads
        Q = self.split_heads(Q)  # (batch_size, num_heads, seq_length, depth)
        K = self.split_heads(K)
        V = self.split_heads(V)

        # Scaled dot-product attention, computed for every head in parallel
        scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.depth, dtype=torch.float32))
        attention_weights = F.softmax(scores, dim=-1)  # Each row sums to 1
        attention_output = torch.matmul(attention_weights, V)  # Weighted sum of values

        # Merge heads back: (batch, num_heads, seq, depth) -> (batch, seq, mode)
        attention_output = attention_output.transpose(1, 2).reshape(batch_size, seq_length, self.mode)

        # Final linear transformation
        output = self.W_o(attention_output)

        return output, attention_weights


# Initialize multi-head attention with 2 heads (mode=2, so each head has depth 1)
multi_head_attention = MultiHeadSelfAttention(mode=2, num_heads=2)

# Apply multi-head attention
output, attention_weights = multi_head_attention(token_embeddings)

print("\nMulti-Head Attention Output:\n", output)
print("\nAttention Weights:\n", attention_weights)
```



```
Multi-Head Attention Output:
 tensor([[[0.0865, 0.0770],
         [0.0865, 0.0770],
         [0.0866, 0.0770]]], grad_fn=<UnsafeViewBackward0>)

Attention Weights:
 tensor([[[[0.3319, 0.3470, 0.3211],
          [0.3316, 0.3497, 0.3187],
          [0.3328, 0.3388, 0.3284]],

         [[0.3398, 0.3332, 0.3270],
          [0.3400, 0.3332, 0.3268],
          [0.3372, 0.3333, 0.3295]]]], grad_fn=<SoftmaxBackward0>)
```
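
The weights above sit close to 1/3 in both heads, likely because the projections are untrained random weights and each head has only one dimension, so the scores for the three tokens are similar. The class projects all heads at once with one fused `nn.Linear` per Q, K, and V and then splits the result. The sketch below checks that this is equivalent to running each head separately with its own slice of the weights and concatenating the results before `W_o`. The sizes (`d_model=4`, two heads) and the random input are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq_len, d_model, num_heads = 1, 3, 4, 2
depth = d_model // num_heads

x = torch.randn(batch, seq_len, d_model)
W_q, W_k, W_v, W_o = (nn.Linear(d_model, d_model, bias=False) for _ in range(4))

def split(t):
    """Same reshaping as split_heads: (b, s, d_model) -> (b, heads, s, depth)."""
    return t.view(batch, seq_len, num_heads, depth).transpose(1, 2)

# Vectorized: project once, split into heads, attend, merge
Q, K, V = split(W_q(x)), split(W_k(x)), split(W_v(x))
w = F.softmax(Q @ K.transpose(-2, -1) / depth ** 0.5, dim=-1)
vectorized = W_o((w @ V).transpose(1, 2).reshape(batch, seq_len, d_model))

# Explicit loop: head h owns rows h*depth:(h+1)*depth of each fused weight matrix
# (nn.Linear computes x @ weight.T, so output features map to weight rows)
head_outputs = []
for h in range(num_heads):
    rows = slice(h * depth, (h + 1) * depth)
    q = x @ W_q.weight[rows].T  # (batch, seq_len, depth)
    k = x @ W_k.weight[rows].T
    v = x @ W_v.weight[rows].T
    a = F.softmax(q @ k.transpose(-2, -1) / depth ** 0.5, dim=-1)
    head_outputs.append(a @ v)
looped = W_o(torch.cat(head_outputs, dim=-1))  # concatenate heads, then mix with W_o

print(torch.allclose(vectorized, looped, atol=1e-6))  # True
```

The fused version does the same arithmetic with fewer, larger matrix multiplications, which is why implementations usually prefer it.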
