# Multi-head Self-Attention (MHSA)

## Summary

**TL;DR**

* **Self-Attention** is like a single person reading a sentence and trying to understand the relationships between its words.
* **Multi-Head Self-Attention** is like a committee of people reading the same sentence, where each person looks for a specific type of relationship.

<div align="center">

|**Feature**|**Self-Attention**|**Multi-Head Self-Attention**|
|-------|----------------------------|---------------------------|
|Perspective|Views the sequence through a single "lens." <br> Smooths out details.|Views the sequence through multiple "lenses" simultaneously. <br> Focuses on different aspects.|

</div>

## Multi-head Self-Attention with Causal Masking and RoPE

**Multi-Head Self-Attention (MHSA)** is the core building block of the Transformer architecture. Self-attention maps dependencies within a single sequence; MHSA extends this by projecting the input into several "representation subspaces," so the model can attend to different relational features at the same time. Splitting the attention mechanism this way guards against "representational collapse": each head attends to information from its own subspace at different positions, so the final output combines varied contextual cues instead of flattening them into a single interpretation.

Let $X \in \mathbb{R}^{n \times d_{model}}$ be the input sequence, where $n$ is the sequence length and $d_{model}$ is the embedding dimension.

**1. Linear Projections**

First, we project the input $X$ into Query ($Q$), Key ($K$), and Value ($V$) spaces using learnable weight matrices. To optimize computation, we perform one large matrix multiplication for each:

$$
Q = X W_Q, \quad K = X W_K, \quad V = X W_V
$$

Where $W_Q, W_K \in \mathbb{R}^{d_{model} \times (h \cdot d_k)}$ and $W_V \in \mathbb{R}^{d_{model} \times (h \cdot d_v)}$. Here, $h$ is the number of heads, and $d_k, d_v$ are the dimensions per head.
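As a rough sketch of this projection step, the snippet below uses hypothetical sizes ($n = 4$, $d_{model} = 8$, $h = 2$) and random weights stored in the $(d_{model}, h \cdot d_k)$ layout of the formula above. The variable names are illustrative only.

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes chosen only for illustration.
n, d_model, h = 4, 8, 2
d_k = d_v = d_model // h

X = torch.randn(n, d_model)            # input sequence
W_Q = torch.randn(d_model, h * d_k)    # one matrix covers all heads
W_K = torch.randn(d_model, h * d_k)
W_V = torch.randn(d_model, h * d_v)

# One matrix multiplication per projection, shared across all heads.
Q, K, V = X @ W_Q, X @ W_K, X @ W_V

print(Q.shape, K.shape, V.shape)
```

Columns `0:d_k` of `Q` belong to the first head, columns `d_k:2*d_k` to the second, and so on. The class in the Code section stores its weights as `(d_out, d_in)` instead, which is why it multiplies by the transposed weight.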

**2. Multi-Head Splitting**

We reshape $Q, K, V$ to separate the heads, which lets the head dimension be treated as a batch dimension. In the common case, the per-head dimension is simply $d_{model}$ divided by $h$, i.e. $d_k = d_v = d_{model} / h$:

$$
Q_i, K_i, V_i \in \mathbb{R}^{n \times d_k} \quad \text{for each head } i \in \{1, \dots, h\}
$$
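The reshape can be sketched as follows, again with hypothetical sizes. An `arange` tensor is used instead of real projections so that it is easy to see which columns end up in which head.

```python
import torch

# Hypothetical sizes chosen only for illustration.
n, d_model, h = 4, 8, 2
d_k = d_model // h

Q = torch.arange(n * d_model, dtype=torch.float32).reshape(n, d_model)

# (n, h * d_k) -> (n, h, d_k) -> (h, n, d_k): heads become a leading, batch-like dimension.
Q_heads = Q.view(n, h, d_k).transpose(0, 1)

print(Q_heads.shape)  # torch.Size([2, 4, 4])
```

Head $i$ receives the contiguous column block `Q[:, i*d_k:(i+1)*d_k]`; no data is mixed between heads.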

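**3. Applying Rotary Positional Embeddings (RoPE)**

RoPE injects position information by rotating the query and key vectors of each head; the values are left unchanged. For a vector at position $m$ in the sequence, the RoPE transformation $f_{\text{RoPE}}$ is applied: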

$$
\tilde{Q}_i^{(m)} = f_{\text{RoPE}}(Q_i^{(m)}, m), \quad \tilde{K}_i^{(m)} = f_{\text{RoPE}}(K_i^{(m)}, m)
$$
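The text does not spell out $f_{\text{RoPE}}$, so the following is only a minimal sketch under stated assumptions: consecutive feature pairs $(2j, 2j+1)$ are rotated by the angle $m \cdot \theta^{-2j/d_k}$ with base $\theta = 10000$. The function `rope` here is a hypothetical stand-in, not the `RoPE` module imported in the Code section, whose pairing convention may differ.

```python
import torch


def rope(x: torch.Tensor, positions: torch.Tensor, theta: float = 10000.0) -> torch.Tensor:
    """Rotate consecutive feature pairs of x (shape (n, d)) by position-dependent angles."""
    d = x.shape[-1]
    # One frequency per feature pair: theta^(-2j/d) for j = 0 .. d/2 - 1.
    inv_freq = theta ** (-torch.arange(0, d, 2, dtype=torch.float32) / d)
    angles = positions[:, None].float() * inv_freq[None, :]  # (n, d/2)
    cos, sin = angles.cos(), angles.sin()

    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x_even * cos - x_odd * sin
    out[..., 1::2] = x_even * sin + x_odd * cos
    return out


torch.manual_seed(0)
q = torch.randn(1, 8)
k = torch.randn(1, 8)


def score(m: int, n: int) -> float:
    """Dot product of q placed at position m with k placed at position n."""
    q_m = rope(q, torch.tensor([m]))
    k_n = rope(k, torch.tensor([n]))
    return (q_m @ k_n.T).item()


# Both pairs are two positions apart, so the scores should agree up to rounding.
print(score(3, 1), score(10, 8))
```

The point of the rotation is that the query-key dot product depends only on the relative offset between the two positions, which is the property the two printed scores illustrate.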

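**4. Scaled Dot-Product Attention with Causal Masking**

For each head, we add a causal mask $M$ to the scaled attention scores before the softmax, so that each position can attend only to itself and to earlier positions. The output of head $i$ is:

$$
\text{head}_i = \text{Softmax}\left(\frac{\tilde{Q}_i \tilde{K}_i^T}{\sqrt{d_k}} + M\right)V_i \quad \text{where} \quad M_{st} = \begin{cases} 0 & \text{if } t \leq s \\ -\infty & \text{if } t > s \end{cases}
$$

Here $s$ indexes the query position and $t$ the key position. Because $e^{-\infty} = 0$, masked entries receive zero attention weight after the softmax.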
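A minimal sketch of this step for a single head, using an additive mask that is $0$ on and below the diagonal and $-\infty$ above it. The function name `causal_attention` is hypothetical and is not the `ScaledDotProductAttention` module used in the Code section.

```python
import math

import torch


def causal_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    """q, k: (..., n, d_k); v: (..., n, d_v). Returns the output and the attention weights."""
    n, d_k = q.shape[-2], q.shape[-1]
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)

    # Additive mask: 0 on and below the diagonal, -inf strictly above it.
    mask = torch.full((n, n), float("-inf"), device=q.device, dtype=q.dtype).triu(diagonal=1)

    # The mask is added before the softmax, so future positions get weight 0.
    weights = torch.softmax(scores + mask, dim=-1)
    return weights @ v, weights


torch.manual_seed(0)
q, k, v = torch.randn(5, 4), torch.randn(5, 4), torch.randn(5, 4)
out, weights = causal_attention(q, k, v)

print(weights)
```

The printed weight matrix is lower triangular and each row sums to one. The first position can only attend to itself, so the first output row equals the first row of `v`.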

**5. Concatenation and Output Projection**

Finally, the results from all heads are concatenated and projected back to the original model dimension using the output weight matrix $W_O \in \mathbb{R}^{(h \cdot d_v) \times d_{model}}$:

$$
\text{MultiHeadSelfAttention}(X) = \text{Concat}(\text{head}_1, \dots, \text{head}_h) W_O
$$
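The concatenation is the inverse of the head split. A small sketch with hypothetical sizes and random stand-ins for the per-head outputs:

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes chosen only for illustration.
h, n, d_v, d_model = 2, 4, 4, 8

heads = torch.randn(h, n, d_v)         # stand-in for the per-head attention outputs
W_O = torch.randn(h * d_v, d_model)

# (h, n, d_v) -> (n, h, d_v) -> (n, h * d_v): head outputs sit side by side per token.
concat = heads.transpose(0, 1).reshape(n, h * d_v)
out = concat @ W_O

print(out.shape)  # torch.Size([4, 8])
```

Multiplying the concatenation by $W_O$ gives the same result as projecting each head with its own row block of $W_O$ and summing, which is one way to see that the output projection is where information from different heads gets mixed.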

## Code

```python
import torch
import torch.nn as nn
from torch import Tensor

from cs336_basics.sdpa import ScaledDotProductAttention
from cs336_basics.rope import RoPE


class MultiHeadSelfAttentionRoPE(nn.Module):
    """Causal multi-head self-attention with RoPE applied to queries and keys."""

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        max_seq_len: int,
        theta: float,
    ) -> None:
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError("d_model must be divisible by num_heads")

        self.d_model = d_model
        self.num_heads = num_heads
        self.attention = ScaledDotProductAttention()

        # RoPE acts on each head separately, so it is sized by the per-head dimension.
        d_head = d_model // num_heads
        self.rope = RoPE(
            d_k=d_head,
            max_seq_len=max_seq_len,
            theta=theta,
        )

    def forward(
        self,
        q_proj_weight: Tensor,
        k_proj_weight: Tensor,
        v_proj_weight: Tensor,
        o_proj_weight: Tensor,
        in_features: Tensor,
        token_positions: Tensor,
    ) -> Tensor:
        # Linear projections; weights are stored as (d_out, d_in), hence the transpose.
        q = in_features @ q_proj_weight.T
        k = in_features @ k_proj_weight.T
        v = in_features @ v_proj_weight.T

        # (..., seq_len, h * d_head) -> (..., h, seq_len, d_head)
        def split_heads(x: Tensor) -> Tensor:
            *batch_dims, seq_len, d_total = x.shape
            d_head = d_total // self.num_heads
            return x.view(*batch_dims, seq_len, self.num_heads, d_head).transpose(-3, -2)

        q = split_heads(q)
        k = split_heads(k)
        v = split_heads(v)

        # Apply RoPE to queries and keys only. Batch and head dims are flattened so that
        # RoPE sees (batch * heads, seq_len, d_head); this assumes token_positions
        # broadcasts against that shape, e.g. a tensor of shape (seq_len,).
        *batch_dims, num_heads, seq_len, d_head = q.shape

        q = q.reshape(-1, seq_len, d_head)
        k = k.reshape(-1, seq_len, d_head)

        q = self.rope(q, token_positions)
        k = self.rope(k, token_positions)

        # reshape (not view) in case the RoPE output is not contiguous.
        q = q.reshape(*batch_dims, num_heads, seq_len, d_head)
        k = k.reshape(*batch_dims, num_heads, seq_len, d_head)

        # Lower-triangular boolean mask: True where a query may attend to a key
        # (assumes ScaledDotProductAttention treats True as "keep").
        causal_mask = torch.tril(
            torch.ones(seq_len, seq_len, device=in_features.device, dtype=torch.bool)
        )

        attn_output = self.attention(q, k, v, mask=causal_mask)

        # Concatenate heads: (..., h, seq_len, d_v) -> (..., seq_len, h * d_v)
        *batch_dims, num_heads, seq_len, d_v_head = attn_output.shape
        combined = attn_output.transpose(-3, -2).reshape(
            *batch_dims, seq_len, num_heads * d_v_head
        )

        # Output projection back to d_model.
        return combined @ o_proj_weight.T
```
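One way to sanity-check a module like the one above is a causality test: changing a later token must not change the outputs at earlier positions. The sketch below runs that test against `mhsa_reference`, a hypothetical, self-contained helper that follows steps 1, 2, 4 and 5 but omits RoPE, so it runs without the `cs336_basics` package. It assumes the same `(d_out, d_in)` weight layout as the class.

```python
import math

import torch


def mhsa_reference(x, q_w, k_w, v_w, o_w, num_heads):
    """Causal multi-head self-attention without RoPE; weights are (d_out, d_in)."""
    *batch, n, d_model = x.shape
    d_head = d_model // num_heads

    def split(t):
        # (..., n, d_model) -> (..., num_heads, n, d_head)
        return t.reshape(*batch, n, num_heads, d_head).transpose(-3, -2)

    q, k, v = split(x @ q_w.T), split(x @ k_w.T), split(x @ v_w.T)

    scores = q @ k.transpose(-2, -1) / math.sqrt(d_head)
    keep = torch.tril(torch.ones(n, n, dtype=torch.bool, device=x.device))
    scores = scores.masked_fill(~keep, float("-inf"))  # block attention to future keys

    out = torch.softmax(scores, dim=-1) @ v
    out = out.transpose(-3, -2).reshape(*batch, n, d_model)  # concatenate heads
    return out @ o_w.T


torch.manual_seed(0)
d_model, num_heads, n = 8, 2, 5
q_w, k_w, v_w, o_w = (torch.randn(d_model, d_model) for _ in range(4))
x = torch.randn(1, n, d_model)

# Perturb only the last token; outputs at earlier positions should be unaffected.
x_perturbed = x.clone()
x_perturbed[:, -1] += 1.0

y = mhsa_reference(x, q_w, k_w, v_w, o_w, num_heads)
y_perturbed = mhsa_reference(x_perturbed, q_w, k_w, v_w, o_w, num_heads)

earlier_unchanged = torch.allclose(y[:, :-1], y_perturbed[:, :-1], atol=1e-5)
last_unchanged = torch.allclose(y[:, -1], y_perturbed[:, -1], atol=1e-5)
print(earlier_unchanged, last_unchanged)
```

With a correct causal mask, the earlier positions match and only the last position differs. The same perturbation test could be applied to `MultiHeadSelfAttentionRoPE` by passing it the same weights along with suitable `token_positions`.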