Implement the multi-head attention mechanism, a critical component of transformer models. Given Query (Q), Key (K), and Value (V) matrices, compute the attention outputs for multiple heads and concatenate the results.

Example:
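Input:
Q = np.array([[1, 0], [0, 1]]), K = np.array([[1, 0], [0, 1]]), V = np.array([[1, 0], [0, 1]]), n_heads = 2
Output:
[[0.73105858, 0.5], [0.5, 0.73105858]]
Reasoning:
Multi-head attention is computed for 2 heads using the input Q, K, and V matrices. With D = 2 and n_heads = 2, each head receives one feature column (d_h = 1) and computes softmax(Q_h K_h^T / sqrt(d_h)) V_h. In head 0, the first query scores [1, 0] against the two keys, giving weights [e/(1+e), 1/(1+e)] ≈ [0.731, 0.269] and an output of ≈ 0.731. The second query scores [0, 0], giving uniform weights and an output of 0.5. Head 1 is the mirror image. The resulting outputs for each head are concatenated to form the final attention output.
(Some versions of this example list [[1., 0.], [0., 1.]] as the output. For finite scores the softmax weights are never exactly one-hot, so standard scaled dot-product attention cannot produce that result for these inputs.)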



In [1]:
import numpy as np

def _softmax(x, axis=-1):
    """Numerically stable softmax of ``x`` along ``axis``.

    The per-slice maximum is subtracted before exponentiating so that large
    inputs cannot overflow ``np.exp``; mathematically the result is unchanged.
    """
    peak = np.max(x, axis=axis, keepdims=True)
    weights = np.exp(x - peak)
    total = np.sum(weights, axis=axis, keepdims=True)
    return weights / total

def compute_qkv(X: np.ndarray, W_q: np.ndarray, W_k: np.ndarray, W_v: np.ndarray):
    """Project the input sequence into queries, keys and values.

    X has shape (T, D) and each weight matrix has shape (D, D), so the three
    projections are each (T, D). They are returned as a (Q, K, V) tuple.
    """
    weights = (W_q, W_k, W_v)
    return tuple(np.matmul(X, W) for W in weights)

def self_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Q: (..., T_q, d_k), K: (..., T_k, d_k), V: (..., T_k, d_v).
    Any leading batch/head dimensions are broadcast by ``@``; plain 2-D
    (T, d) inputs behave exactly as single-head attention.

    Returns (..., T_q, d_v) -- (T, d) for 2-D self-attention inputs.
    """
    d_k = Q.shape[-1]
    # Transpose only the last two axes. ``K.T`` reverses *all* axes, which is
    # the same for 2-D K but breaks stacked (H, T, d) inputs.
    scores = (Q @ np.swapaxes(K, -1, -2)) / np.sqrt(d_k)   # (..., T_q, T_k)
    attn = _softmax(scores, axis=-1)  # each query's weights over keys sum to 1
    return attn @ V                   # (..., T_q, d_v)

def multi_head_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray, n_heads: int):
    """
    Split Q, K, V into ``n_heads`` heads, attend per head, and concatenate.

    Q: (T_q, D) queries, K: (T_k, D) keys, V: (T_k, D_v) values.
    For self-attention T_q == T_k and D_v == D, so the result is (T, D).
    Keys/values whose length differs from the queries' (cross-attention)
    are also supported.

    Head h works on feature columns [h * d_h, (h + 1) * d_h) of each input,
    where d_h = D // n_heads (D_v // n_heads for V), using
    scaled dot-product attention.

    Returns the concatenated head outputs, shape (T_q, D_v).

    Raises:
        ValueError: if n_heads is not positive, does not divide D or D_v,
            or K's feature dimension differs from Q's.
    """
    _, D = Q.shape  # unpacking also enforces that Q is 2-D
    # Explicit raises instead of ``assert``: assertions vanish under ``-O``.
    if n_heads <= 0:
        raise ValueError("n_heads must be a positive integer")
    if D % n_heads != 0:
        raise ValueError("D must be divisible by n_heads")
    if K.shape[-1] != D:
        # Without this, a K like (D, T) with the same element count as Q
        # would reshape "successfully" into meaningless heads.
        raise ValueError("K must have the same feature dimension as Q")
    if V.shape[-1] % n_heads != 0:
        raise ValueError("V's feature dimension must be divisible by n_heads")

    def split_heads(X):
        # (L, F) -> (H, L, F // H). Each array uses its own length L, so the
        # key/value sequence is not forced to match the query length.
        L, F = X.shape
        return X.reshape(L, n_heads, F // n_heads).transpose(1, 0, 2)

    Qh, Kh, Vh = map(split_heads, (Q, K, V))   # each (H, L, d_h)

    # Attention per head -> list of (T_q, d_h) arrays, then concat on features.
    outs = [self_attention(q, k, v) for q, k, v in zip(Qh, Kh, Vh)]
    return np.concatenate(outs, axis=-1)

# Example
if __name__ == "__main__":
    # Q = K = V = 2x2 identity, split into two 1-dimensional heads.
    Q = np.array([[1, 0], [0, 1]], dtype=float)
    K = np.array([[1, 0], [0, 1]], dtype=float)
    V = np.array([[1, 0], [0, 1]], dtype=float)
    print(multi_head_attention(Q, K, V, n_heads=2))
    # [[0.73105858 0.5       ]
    #  [0.5        0.73105858]]
    # Softmax gives soft weights, not a one-hot selection. In head 0 the
    # first query scores [1, 0] against the keys, so its weights are
    # [e/(1+e), 1/(1+e)] ~= [0.731, 0.269] and its output is ~= 0.731.
    # The second query scores [0, 0], so its weights are uniform and its
    # output is 0.5. Head 1 is the mirror image.

[[0.73105858 0.5       ]
 [0.5        0.73105858]]
