# 15 MultiHeadAttention

In [1]:
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class AttentionConfig:
    """Hyperparameters used to build a ``MultiHeadAttention`` module."""

    # Model width E: input and output size of the q/k/v/out linear projections.
    embed_dim: int
    # Number of attention heads H; MultiHeadAttention asserts it divides embed_dim.
    num_heads: int
    # Probability passed to the nn.Dropout layer created by MultiHeadAttention.
    dropout_p: float = 0.0
    # Whether the q/k/v/out linear projections are created with a bias term.
    bias: bool = True

In [2]:
def causal_mask(T: int, S: Optional[int] = None, device=None) -> torch.Tensor:
    """
    Build a boolean causal mask of shape [T, S] (True = ALLOW, False = BLOCK).

    Query row i may attend to key column j exactly when j <= i. With S == T this
    is the usual lower-triangular mask; with S != T the same rule is applied to
    the rectangular grid, so key columns beyond the query index are blocked.

    Args:
      T: number of query positions.
      S: number of key positions; defaults to T when omitted.
      device: device on which the mask is created.
    Returns:
      mask: [T, S] bool tensor.
    """
    num_keys = T if S is None else S
    query_pos = torch.arange(T, device=device).unsqueeze(1)        # [T, 1]
    key_pos = torch.arange(num_keys, device=device).unsqueeze(0)   # [1, S]
    # Broadcasting the comparison yields the full [T, S] grid directly.
    return key_pos <= query_pos

In [None]:
class MultiHeadAttention(nn.Module):
    """
    Multi-head scaled dot-product attention (self-attn and cross-attn).

    Conventions:
      - x: [B, T, E]
      - kv: [B, S, E] (only for cross-attn; S may differ from T)
      - attn_mask: [T, S] or [B, 1, T, S] (bool or float)
      - key_padding_mask: [B, S] bool (True = keep, False = pad)

    Query rows for which every key is blocked receive all-zero attention
    weights (and therefore a bias-only output) instead of NaN.
    """

    def __init__(self, cfg: "AttentionConfig"):
        """
        Args:
          cfg: configuration providing embed_dim, num_heads, dropout_p and bias.
        Raises:
          AssertionError: if embed_dim is not divisible by num_heads.
        """
        super().__init__()
        assert cfg.embed_dim % cfg.num_heads == 0, "embed_dim must be divisible by num_heads"
        self.cfg = cfg
        self.head_dim = cfg.embed_dim // cfg.num_heads

        self.q_proj = nn.Linear(cfg.embed_dim, cfg.embed_dim, bias=cfg.bias)
        self.k_proj = nn.Linear(cfg.embed_dim, cfg.embed_dim, bias=cfg.bias)
        self.v_proj = nn.Linear(cfg.embed_dim, cfg.embed_dim, bias=cfg.bias)
        self.out_proj = nn.Linear(cfg.embed_dim, cfg.embed_dim, bias=cfg.bias)
        self.dropout = nn.Dropout(cfg.dropout_p)

    # ---------- helpers ----------

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """
        Split the embedding axis into heads.

        Args:
          x: [B, T, E]
        Returns:
          xh: [B, H, T, Dh], where head h holds features [h*Dh, (h+1)*Dh) of x.
        """
        batch, length, _ = x.shape
        per_head = x.reshape(batch, length, self.cfg.num_heads, self.head_dim)
        return per_head.transpose(1, 2)

    def _merge_heads(self, xh: torch.Tensor) -> torch.Tensor:
        """
        Inverse of ``_split_heads``: concatenate heads back along the embedding axis.

        Args:
          xh: [B, H, T, Dh]
        Returns:
          x:  [B, T, E]
        """
        batch, _, length, _ = xh.shape
        # transpose makes the tensor non-contiguous, so reshape (not view) is required.
        return xh.transpose(1, 2).reshape(batch, length, self.cfg.embed_dim)

    def _apply_masks(
        self,
        attn_logits: torch.Tensor,
        attn_mask: Optional[torch.Tensor],
        key_padding_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """
        Apply masks to attention logits before softmax.

        Args:
          attn_logits: [B, H, T, S]
          attn_mask: optional, either
            - bool mask where True means ALLOW, False means BLOCK
              shape [T, S] or [B, 1, T, S]
            - OR float additive mask (e.g., 0 for allow, -inf for block)
              shape [T, S] or [B, 1, T, S]
          key_padding_mask: optional bool mask [B, S] where True means token is real (keep),
                            False means padding (block).
        Returns:
          masked_logits: [B, H, T, S]; blocked positions hold -inf.
        Raises:
          ValueError: if attn_mask is not 2-D or 4-D, or key_padding_mask is not 2-D.
        """
        blocked = float("-inf")

        if attn_mask is not None:
            if attn_mask.dim() == 2:
                # [T, S] -> [1, 1, T, S] so it broadcasts over batch and heads.
                attn_mask = attn_mask[None, None, :, :]
            elif attn_mask.dim() != 4:
                raise ValueError(
                    f"attn_mask must have shape [T, S] or [B, 1, T, S], got {tuple(attn_mask.shape)}"
                )
            if attn_mask.dtype == torch.bool:
                attn_logits = attn_logits.masked_fill(~attn_mask, blocked)
            else:
                attn_logits = attn_logits + attn_mask.to(attn_logits.dtype)

        if key_padding_mask is not None:
            if key_padding_mask.dim() != 2:
                raise ValueError(
                    f"key_padding_mask must have shape [B, S], got {tuple(key_padding_mask.shape)}"
                )
            # [B, S] -> [B, 1, 1, S]: the same keys are hidden from every head and query.
            keep = key_padding_mask.to(torch.bool)[:, None, None, :]
            attn_logits = attn_logits.masked_fill(~keep, blocked)

        return attn_logits

    # ---------- main API ----------

    def forward(
        self,
        x: torch.Tensor,
        kv: Optional[torch.Tensor] = None,
        *,
        attn_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
        need_weights: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Multi-head attention.

        Self-attn when kv is None:
          q,k,v all come from x.

        Cross-attn when kv is not None:
          q from x, k,v from kv.

        Args:
          x:  [B, T, E]
          kv: [B, S, E] or None
          attn_mask: optional [T, S] or [B, 1, T, S], bool or float additive
          key_padding_mask: optional [B, S] bool (True keep, False pad)
          need_weights: if True, also return attention probs averaged over heads [B, T, S]

        Returns:
          y: [B, T, E]
          weights (optional): [B, T, S] head-averaged attention probabilities,
            taken before dropout; None when need_weights is False.
        """
        source = x if kv is None else kv

        q = self._split_heads(self.q_proj(x))        # [B, H, T, Dh]
        k = self._split_heads(self.k_proj(source))   # [B, H, S, Dh]
        v = self._split_heads(self.v_proj(source))   # [B, H, S, Dh]

        # Scale by sqrt(Dh) so logit variance does not grow with head size.
        logits = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        logits = self._apply_masks(logits, attn_mask, key_padding_mask)

        # softmax over a row of all -inf yields NaN; neutralise such rows first
        # and zero their probabilities afterwards.
        dead_rows = torch.isneginf(logits).all(dim=-1, keepdim=True)
        probs = F.softmax(logits.masked_fill(dead_rows, 0.0), dim=-1)
        probs = probs.masked_fill(dead_rows, 0.0)

        weights = probs.mean(dim=1) if need_weights else None

        context = torch.matmul(self.dropout(probs), v)   # [B, H, T, Dh]
        y = self.out_proj(self._merge_heads(context))
        return y, weights

## Run Test Cases

In [None]:
def _seed_all(seed=0):
    """Seed torch's global RNG and then the RNGs of all CUDA devices with ``seed``."""
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

def _assert_close(a: torch.Tensor, b: torch.Tensor, atol=1e-5, rtol=1e-5, msg=""):
    """
    Raise AssertionError unless ``a`` and ``b`` agree within the given tolerances.

    The error text is ``msg`` followed by the largest absolute element-wise
    difference, to make the size of the mismatch visible.
    """
    if torch.allclose(a, b, atol=atol, rtol=rtol):
        return
    worst = (a - b).abs().max().item()
    raise AssertionError(msg + f"\nmax abs diff: {worst}")

def run_mha_basic_tests():
    """
    Run basic shape, normalisation and masking checks against MultiHeadAttention.

    Keep these private or reveal progressively.

    Raises:
      AssertionError: on the first check that fails.
    """
    _seed_all(0)
    device = "cpu"

    cfg = AttentionConfig(embed_dim=8, num_heads=2, dropout_p=0.0, bias=True)
    mha = MultiHeadAttention(cfg).to(device)
    mha.eval()

    B, T, S, E = 2, 4, 5, cfg.embed_dim
    x = torch.randn(B, T, E, device=device)
    kv = torch.randn(B, S, E, device=device)
    # Created on `device` so the comparisons keep working if `device` is changed.
    row_sums = torch.ones(B, T, device=device)

    # ---- Test 1: output shape (self-attn) ----
    y, w = mha(x, need_weights=True)
    assert y.shape == (B, T, E), f"self-attn output shape wrong: {y.shape}"
    assert w is not None and w.shape == (B, T, T), f"weights shape wrong: {None if w is None else w.shape}"
    _assert_close(w.sum(dim=-1), row_sums, msg="weights should sum to 1 over keys (self-attn)")

    # ---- Test 2: output shape (cross-attn) ----
    y2, w2 = mha(x, kv=kv, need_weights=True)
    assert y2.shape == (B, T, E), f"cross-attn output shape wrong: {y2.shape}"
    assert w2 is not None and w2.shape == (B, T, S), f"cross weights shape wrong: {None if w2 is None else w2.shape}"
    _assert_close(w2.sum(dim=-1), row_sums, msg="weights should sum to 1 over keys (cross-attn)")

    # ---- Test 3: key padding mask blocks padded keys ----
    # Mark last 2 keys as padding (False = pad)
    kpm = torch.ones(B, S, dtype=torch.bool, device=device)
    kpm[:, -2:] = False
    _, w3 = mha(x, kv=kv, key_padding_mask=kpm, need_weights=True)
    assert (w3[..., -2:] < 1e-6).all(), "padded keys should get ~0 attention probability"
    # The probability mass must be redistributed over the real keys, not lost.
    _assert_close(w3.sum(dim=-1), row_sums, msg="weights should sum to 1 over unpadded keys")

    # ---- Test 4: causal mask blocks future keys (self-attn) ----
    cm = causal_mask(T, device=device)  # [T, T], True allow
    _, w4 = mha(x, attn_mask=cm, need_weights=True)
    # For each query t, positions > t should be ~0
    triu = torch.triu(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=1)
    assert (w4[:, triu] < 1e-6).all(), "causal mask should block future positions"
    _assert_close(w4.sum(dim=-1), row_sums, msg="weights should sum to 1 over allowed keys (causal)")

    # ---- Test 5: determinism / stability sanity ----
    # If x is all zeros and projections have bias, outputs should be consistent; mostly checks no NaNs.
    x0 = torch.zeros(B, T, E, device=device)
    y0, w0 = mha(x0, need_weights=True)
    assert torch.isfinite(y0).all() and torch.isfinite(w0).all(), "should not produce NaNs/Infs"

    print("All MHA basic tests passed ✅")

In [None]:
# Run the basic MHA checks; any exception raised inside propagates out of this cell.
run_mha_basic_tests()