In [6]:
import numpy as np

Why?

One key component of the transformer architecture (making up about 60% of it) is multi-headed attention.\
There is nothing magical about it either: it is just attention, chunked into $N$ parts along the embedding dim.

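Practically, $Query$, $Key$ and $Value$ are all of shape $(C, E)$ where $C$ is the context length and $E$ is the embedding dim. \
Meaning... $softmax(\frac {Query . Key^T} {\sqrt {E}} + Mask) . Value$ is of shape $(C, E)$: the scores $Query . Key^T$ (and the $Mask$) are of shape $(C, C)$, and multiplying by $Value$ brings us back to $(C, E)$.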

In the multiheaded scenario... we chunk the $Query$, $Key$ and $Value$ along the embedding dim by the number of heads we want (a hyperparameter of the model).

i.e. for $N$ heads we have $N$ small $Query$, $Key$ and $Value$, each of shape $(C, E / N)$. For simplicity, $N$ should divide $E$.
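
As a small sketch of the chunking described here: slicing the embedding axis into $N$ contiguous blocks gives the same pieces as a single reshape from $(C, E)$ to $(C, N, E / N)$. The sizes and the names `x`, `slices` and `heads` are illustration-only assumptions.

```python
import numpy as np

C, E, N = 5, 8, 4            # illustrative sizes, not tied to any real model
assert E % N == 0
d = E // N                   # per-head width, E / N

x = np.arange(C * E, dtype=float).reshape(C, E)   # stand-in for Query, Key or Value

# Option 1: explicit slices along the embedding axis, each of shape (C, d)
slices = [x[:, i * d:(i + 1) * d] for i in range(N)]

# Option 2: one reshape, then move the head axis to the front -> (N, C, d)
heads = x.reshape(C, N, d).transpose(1, 0, 2)

same_split = all(np.array_equal(slices[i], heads[i]) for i in range(N))
print(heads.shape, same_split)
```

The reshape form is usually what makes a loop-free implementation possible, since all heads then live in one array.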

But how do we merge it back? \
Concatenation. \
Like... literally. This is yet another non-rigorous, ad hoc solution by DL folks. But intuitively, we are still in the wrong space, as $softmax$ is not linear and does not have such fancy convenient properties.
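
To make the "wrong space" remark concrete, here is a hedged sketch that compares plain attention with the raw concatenation of per-head attentions on random inputs. The shapes match, but the values generally do not. The seed and sizes are arbitrary assumptions.

```python
import numpy as np

def softmax(x):
    x_exp = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return x_exp / np.sum(x_exp, axis=-1, keepdims=True)

rng = np.random.default_rng(0)     # fixed seed, arbitrary choice
C, E, N = 6, 8, 4                  # illustrative sizes
d = E // N
q, k, v = rng.random((3, C, E))    # unpacks into three (C, E) arrays

# Plain single-head attention over the full embedding
full = softmax(q @ k.T / np.sqrt(E)) @ v            # (C, E)

# Per-head attention on each chunk, then raw concatenation
heads = []
for i in range(N):
    s = slice(i * d, (i + 1) * d)
    heads.append(softmax(q[:, s] @ k[:, s].T / np.sqrt(d)) @ v[:, s])
concat = np.concatenate(heads, axis=-1)             # (C, E)

print(full.shape == concat.shape)   # shapes agree
print(np.allclose(full, concat))    # values are not expected to agree
```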

The other trick is to add yet another intermediary map that projects the ad hoc concatenation back into the original, expected embedding space.

$$(Head_1 \oplus Head_2 \oplus ... \oplus Head_N) \xrightarrow{} \text{raw concat} \xrightarrow{proj} \text{original}$$
$$(C, E / N) \oplus (C, E / N) \oplus ... \oplus (C, E / N) \xrightarrow{} (C, E) \xrightarrow{proj} (C, E) $$

$proj$ can be as simple as (yet) another linear transformation that can be learned through training. 
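
One way to read the linear $proj$, sketched under the assumption that it is a plain $(E, E)$ matrix with no bias: multiplying the concatenation by it is the same as giving each head its own $(E / N, E)$ block of rows and summing the results. That is how the heads get mixed. The names `W`, `heads` and `per_head` are hypothetical stand-ins.

```python
import numpy as np

rng = np.random.default_rng(0)     # fixed seed, arbitrary choice
C, E, N = 5, 8, 4                  # illustrative sizes
d = E // N

heads = [rng.random((C, d)) for _ in range(N)]   # stand-ins for the head outputs
W = rng.random((E, E))                           # stand-in for the learned proj

concat = np.concatenate(heads, axis=-1)          # (C, E)
projected = concat @ W                           # (C, E)

# Same result written as a sum of per-head contributions:
# head i only ever touches rows i*d .. (i+1)*d of W.
per_head = sum(h @ W[i * d:(i + 1) * d, :] for i, h in enumerate(heads))

print(np.allclose(projected, per_head))
```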

i.e. $$attention = proj(Head_1 \oplus Head_2 \oplus ... \oplus Head_N) $$
with $$Head_i = softmax(\frac {Split_i(Query) . Split_i(Key)^T} {\sqrt {E / N}} + Mask) . Split_i(Value)$$
where $Split_i$ takes the $i$-th chunk of width $E / N$ along the embedding dim.

In [7]:
def softmax(x):
    x_exp = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return x_exp / np.sum(x_exp, axis=-1, keepdims=True)

In [8]:
B, C, E = 3, 11, 8 # batch, context length, embedding dim
q = np.random.rand(B, C, E)
k = np.random.rand(C, E)
v = np.random.rand(C, E)

In [9]:
# classic (no mask for simplicity)
classic_attention = softmax((q @ k.T) / np.sqrt(E)) @ v
classic_attention.shape

(3, 11, 8)
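
The cell above drops the $Mask$ term. As a hedged sketch of what it could look like, assuming a causal mask (each position only attends to itself and earlier positions) written as an additive $(C, C)$ matrix of $0$ and $-\infty$:

```python
import numpy as np

def softmax(x):
    x_exp = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return x_exp / np.sum(x_exp, axis=-1, keepdims=True)

rng = np.random.default_rng(0)     # fixed seed, arbitrary choice
B, C, E = 3, 11, 8                 # batch, context length, embedding dim
q = rng.random((B, C, E))
k = rng.random((C, E))
v = rng.random((C, E))

# Additive causal mask (an assumption, other masks are possible):
# 0 where attending is allowed, -inf on future positions.
mask = np.where(np.tri(C, dtype=bool), 0.0, -np.inf)    # (C, C)

weights = softmax(q @ k.T / np.sqrt(E) + mask)          # (B, C, C)
masked_attention = weights @ v                          # (B, C, E)

print(masked_attention.shape)
print(np.allclose(np.triu(weights, k=1), 0.0))   # no weight on future positions
print(np.allclose(weights.sum(axis=-1), 1.0))    # each row still sums to 1
```

The mask broadcasts over the batch dim, so the same $(C, C)$ matrix is shared by all $B$ sequences.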

In [11]:
# multiheaded
N = 4
assert E % N == 0

proj = np.random.rand(E, E) # stand-in for the projection (random here, learned with the rest of the model in practice)

split_dim = E // N
heads = []

for i in range(N):
    q_head = q[:, :, i * split_dim: (i + 1) * split_dim]    # (B, C, E / N)
    k_head = k[   :, i * split_dim: (i + 1) * split_dim]    # (C, E / N)
    v_head = v[   :, i * split_dim: (i + 1) * split_dim]    # (C, E / N)
    # print(q_head.shape, k_head.T.shape, v_head.shape)

    scores = q_head @ k_head.T / np.sqrt(split_dim)         # (B, C, C)
    attention_weights = softmax(scores)                     # (B, C, C)

    # weighted sum of values
    local_head = attention_weights @ v_head                 # (B, C, E / N)
    heads.append(local_head)

raw_multi_head_output = np.concatenate(heads, axis=-1)      # (B, C, E)
print("concat", raw_multi_head_output.shape)

attention = raw_multi_head_output @ proj                    # (B, C, E) . (E, E) -> (B, C, E)
print("proj(concat) -> attention", attention.shape)

concat (3, 11, 8)
proj(concat) -> attention (3, 11, 8)
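
The loop over heads can also be written without a loop by reshaping the embedding axis into `(N, E / N)` and letting `@` broadcast over the head axis. The sketch below is one possible way to do it and checks it against a loop version. The function names `multi_head_loop` and `multi_head_vectorized` are hypothetical, and the mask is again left out.

```python
import numpy as np

def softmax(x):
    x_exp = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return x_exp / np.sum(x_exp, axis=-1, keepdims=True)

def multi_head_loop(q, k, v, proj, n_heads):
    """Reference version: one slice per head, then concat and project."""
    d = q.shape[-1] // n_heads
    heads = []
    for i in range(n_heads):
        s = slice(i * d, (i + 1) * d)
        w = softmax(q[..., s] @ k[:, s].T / np.sqrt(d))       # (B, C, C)
        heads.append(w @ v[:, s])                             # (B, C, d)
    return np.concatenate(heads, axis=-1) @ proj              # (B, C, E)

def multi_head_vectorized(q, k, v, proj, n_heads):
    """Same computation with the head axis handled by broadcasting."""
    B, C, E = q.shape
    d = E // n_heads
    qh = q.reshape(B, C, n_heads, d).transpose(0, 2, 1, 3)    # (B, N, C, d)
    kh = k.reshape(C, n_heads, d).transpose(1, 0, 2)          # (N, C, d)
    vh = v.reshape(C, n_heads, d).transpose(1, 0, 2)          # (N, C, d)
    scores = qh @ kh.transpose(0, 2, 1) / np.sqrt(d)          # (B, N, C, C)
    out = softmax(scores) @ vh                                # (B, N, C, d)
    concat = out.transpose(0, 2, 1, 3).reshape(B, C, E)       # (B, C, E)
    return concat @ proj

rng = np.random.default_rng(0)     # fixed seed, arbitrary choice
B, C, E, N = 3, 11, 8, 4
q = rng.random((B, C, E))
k = rng.random((C, E))
v = rng.random((C, E))
proj = rng.random((E, E))          # stand-in for the learned projection

loop_out = multi_head_loop(q, k, v, proj, N)
vec_out = multi_head_vectorized(q, k, v, proj, N)
print(vec_out.shape, np.allclose(loop_out, vec_out))
```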
