# Multi-Head Self-Attention

- 📺 **Video:** [https://youtu.be/nHXrdLMo8Uk](https://youtu.be/nHXrdLMo8Uk)

## Overview
- Show how multi-head self-attention allows the model to focus on different relational patterns simultaneously.
- Explain head concatenation and projection back to the model dimension.

## Key ideas
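- **Multiple heads:** each head has its own learned query, key, and value projections, so different heads can capture different kinds of dependencies.
- **Concatenation:** head outputs are concatenated and then linearly mixed by an output projection.
- **Dimensionality:** in the standard design, head dimension times number of heads equals model dimension, so the total cost is close to that of a single full-width head.
- **Interpretability:** individual heads often appear to specialize (e.g., in syntactic relations or coreference).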
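For reference, the same computation can be written compactly in the notation of *Attention Is All You Need*, with $h$ heads, an input matrix $X$ whose rows are token embeddings, and $d_k = d_{\text{model}} / h$:

$$
\begin{aligned}
\text{head}_i &= \operatorname{softmax}\!\left(\frac{(X W_i^Q)(X W_i^K)^\top}{\sqrt{d_k}}\right) X W_i^V, \qquad i = 1, \dots, h, \\
\operatorname{MultiHead}(X) &= \operatorname{Concat}(\text{head}_1, \dots, \text{head}_h)\, W^O,
\end{aligned}
$$

where $W_i^Q, W_i^K, W_i^V \in \mathbb{R}^{d_{\text{model}} \times d_k}$ and $W^O \in \mathbb{R}^{h d_k \times d_{\text{model}}}$. Since $h \cdot d_k = d_{\text{model}}$, the concatenated heads have exactly the width that $W^O$ maps back to the model dimension. In the demo, each $W_i^Q$ (and likewise $W_i^K$, $W_i^V$) corresponds to a contiguous block of $d_k$ columns of the full matrices `W_q`, `W_k`, `W_v`.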

## Demo
Implement a two-head self-attention block on three toy 4-dimensional token embeddings to illustrate the computation pipeline described in the [lecture](https://youtu.be/6FM2ctkEoMc).

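```python
import numpy as np

# Toy input: 3 tokens (rows), each a 4-dimensional embedding
X = np.array([
    [0.5, 0.1, 0.3, 0.2],
    [0.2, 0.4, 0.1, 0.5],
    [0.7, 0.0, 0.2, 0.4]
])
model_dim = X.shape[1]
num_heads = 2
head_dim = model_dim // num_heads  # 4 // 2 = 2 features per head

# Random stand-ins for learned projections (fixed seed for reproducibility)
rng = np.random.default_rng(0)
W_q = rng.normal(scale=0.4, size=(model_dim, model_dim))
W_k = rng.normal(scale=0.4, size=(model_dim, model_dim))
W_v = rng.normal(scale=0.4, size=(model_dim, model_dim))
W_o = rng.normal(scale=0.4, size=(model_dim, model_dim))  # output projection

# Project every token to queries, keys, and values at full model width
Q = X @ W_q
K = X @ W_k
V = X @ W_v

heads = []
for h in range(num_heads):
    # Each head works on its own contiguous slice of head_dim columns
    q_h = Q[:, h*head_dim:(h+1)*head_dim]
    k_h = K[:, h*head_dim:(h+1)*head_dim]
    v_h = V[:, h*head_dim:(h+1)*head_dim]
    # Scaled dot-product logits: (seq_len, seq_len)
    scale = np.sqrt(head_dim)
    logits = q_h @ k_h.T / scale
    # Numerically stable row-wise softmax (subtract the row max before exp)
    weights = np.exp(logits - logits.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    # Each output row is a weighted average of this head's value vectors
    heads.append(weights @ v_h)

# Concatenate heads along the feature axis, then project back to model_dim
concat = np.concatenate(heads, axis=-1)
output = concat @ W_o

print('Head 0 output:')
print(heads[0])
print()
print('Head 1 output:')
print(heads[1])
print()
print('Multi-head combination:')
print(output)
```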

```
Head 0 output:
[[-0.13502184  0.16859793]
 [-0.13458499  0.16709715]
 [-0.13486361  0.16837749]]

Head 1 output:
[[0.41382381 0.51585003]
 [0.41388256 0.51598671]
 [0.41415734 0.51665625]]

Multi-head combination:
[[-0.11650333 -0.15291629  0.05670701 -0.11995107]
 [-0.11619959 -0.15312817  0.05761014 -0.12044207]
 [-0.1164723  -0.15317532  0.05724618 -0.12031059]]
```
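
The per-head loop above can also be written without a Python loop by reshaping each projected matrix into a `(num_heads, seq_len, head_dim)` array. The sketch below is one possible batched NumPy version; the function name `multi_head_self_attention` is our own and not from the lecture. It reuses the demo's embeddings and seed, so its `output` should match the loop's result up to floating-point rounding, and it also returns the per-head attention weights so you can inspect what each head attends to.

```python
import numpy as np

def multi_head_self_attention(X, W_q, W_k, W_v, W_o, num_heads):
    """Multi-head self-attention with all heads computed in one batched pass.

    X has shape (seq_len, model_dim); each W_* has shape (model_dim, model_dim).
    Returns the (seq_len, model_dim) output and the per-head attention weights.
    """
    seq_len, model_dim = X.shape
    if model_dim % num_heads != 0:
        raise ValueError("num_heads must divide model_dim")
    head_dim = model_dim // num_heads

    def split_heads(M):
        # (seq_len, model_dim) -> (num_heads, seq_len, head_dim); head h gets
        # columns h*head_dim:(h+1)*head_dim, the same slices as the demo loop
        return M.reshape(seq_len, num_heads, head_dim).transpose(1, 0, 2)

    Q = split_heads(X @ W_q)
    K = split_heads(X @ W_k)
    V = split_heads(X @ W_v)

    # Scaled dot-product logits for every head: (num_heads, seq_len, seq_len)
    logits = Q @ K.transpose(0, 2, 1) / np.sqrt(head_dim)
    # Numerically stable softmax over the key axis
    weights = np.exp(logits - logits.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)

    heads = weights @ V  # (num_heads, seq_len, head_dim)
    # Undo the split, which concatenates the heads along the feature axis
    concat = heads.transpose(1, 0, 2).reshape(seq_len, model_dim)
    return concat @ W_o, weights


# Same toy embeddings and weights as the demo (same seed, same draw order)
X = np.array([[0.5, 0.1, 0.3, 0.2],
              [0.2, 0.4, 0.1, 0.5],
              [0.7, 0.0, 0.2, 0.4]])
rng = np.random.default_rng(0)
W_q, W_k, W_v, W_o = (rng.normal(scale=0.4, size=(4, 4)) for _ in range(4))

output, weights = multi_head_self_attention(X, W_q, W_k, W_v, W_o, num_heads=2)
print(output.shape)   # (3, 4)
print(weights[0])     # head 0's attention pattern: one row per query token
print(weights[1])     # head 1's attention pattern
```

Because each head's projection is just a contiguous block of columns of a full `model_dim × model_dim` matrix, keeping one large matrix per query/key/value is equivalent to keeping `num_heads` separate `model_dim × head_dim` matrices; the single-matrix layout simply lets all heads share one matrix multiply.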


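## Try it
- Modify the demo: change `num_heads` (it must divide `model_dim`), the random seed, or the embeddings in `X`, and compare how the per-head outputs change.
- Add a tiny dataset or a counter-example, e.g., reorder the rows of `X` and check how the outputs change.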
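
One tiny counter-example worth trying: self-attention as implemented in the demo has no notion of token order. Reordering the rows of `X` reorders the output rows in exactly the same way (the block is permutation-equivariant), which is why Transformers add positional information, the topic of some of the references below. The sketch wraps the demo's loop in a helper named `mhsa` (a hypothetical name, not from the lecture) and compares outputs under one permutation.

```python
import numpy as np

def mhsa(X, W_q, W_k, W_v, W_o, num_heads):
    # Loop-based multi-head self-attention, mirroring the demo cell
    head_dim = X.shape[1] // num_heads
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    heads = []
    for h in range(num_heads):
        s = slice(h * head_dim, (h + 1) * head_dim)
        logits = Q[:, s] @ K[:, s].T / np.sqrt(head_dim)
        w = np.exp(logits - logits.max(axis=-1, keepdims=True))
        w /= w.sum(axis=-1, keepdims=True)
        heads.append(w @ V[:, s])
    return np.concatenate(heads, axis=-1) @ W_o

# Same toy embeddings and weights as the demo
X = np.array([[0.5, 0.1, 0.3, 0.2],
              [0.2, 0.4, 0.1, 0.5],
              [0.7, 0.0, 0.2, 0.4]])
rng = np.random.default_rng(0)
W_q, W_k, W_v, W_o = (rng.normal(scale=0.4, size=(4, 4)) for _ in range(4))

perm = np.array([2, 0, 1])  # reorder the three "tokens"
out = mhsa(X, W_q, W_k, W_v, W_o, num_heads=2)
out_perm = mhsa(X[perm], W_q, W_k, W_v, W_o, num_heads=2)

# Permuting the inputs only permutes the outputs: order is invisible
print(np.allclose(out_perm, out[perm]))  # True
```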


## References
- [Eisenstein 6.1](https://github.com/jacobeisenstein/gt-nlp-class/blob/master/notes/eisenstein-nlp-notes.pdf)
- [Eisenstein 6.2](https://github.com/jacobeisenstein/gt-nlp-class/blob/master/notes/eisenstein-nlp-notes.pdf)
- [Eisenstein 6.3](https://github.com/jacobeisenstein/gt-nlp-class/blob/master/notes/eisenstein-nlp-notes.pdf)
- [Eisenstein 6.4](https://github.com/jacobeisenstein/gt-nlp-class/blob/master/notes/eisenstein-nlp-notes.pdf)
- [[Blog] Understanding LSTMs](http://colah.github.io/posts/2015-08-Understanding-LSTMs/)
- [Neural Machine Translation by Jointly Learning to Align and Translate](https://arxiv.org/abs/1409.0473)
- [Attention Is All You Need](https://arxiv.org/pdf/1706.03762.pdf)
- [[Blog] The Illustrated Transformer](http://jalammar.github.io/illustrated-transformer/)
- [Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation](https://arxiv.org/abs/2108.12409)
- [The Impact of Positional Encoding on Length Generalization in Transformers](https://arxiv.org/abs/2305.19466)


*Links only; we do not redistribute slides or papers.*