In [2]:
import numpy as np

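# utils: print how much (in %) token `who` attends to token `to`
# self-attention: rows and columns of `mat` both index `source`
def debug_att(source, mat, who, to):
    print(f"'{who}' attends '{to}': {(mat[source.index(who), source.index(to)] * 100).round()}%")

# cross-attention: rows of `mat` index `source` (queries), columns index `target` (keys)
def debug_x_att(source, target, mat, who, to):
    print(f"'{who}' attends '{to}': {(mat[source.index(who), target.index(to)] * 100).round()}%")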

In [3]:
# Suppose we are in the middle of training on some dataset (e.g. a text)
# and we look at a single example from some batch of B examples (attention processes all tokens of an
# example in parallel, and the B examples of a batch are processed concurrently)

np.random.seed(1234)

tokens_eng = ["nah", "I", "d", "win"]
tokens_jp = ["いいえ", "勝つ", "さ"]


$$attention := softmax(Query(s_q) . Key(s_k)^T) . Value(s_v)$$

With proper scaling, assuming the entries of $Q$ and $K$ are independent draws from $N(\mu = 0, \sigma^2 = 1)$, we get:

$$attention := softmax(\frac{Query(s_q) . Key(s_k)^T} {\sqrt{EmbeddingDim}}) . Value(s_v)$$

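Note: for a query row $q$ and a key row $k$ (here $EmbeddingDim$ is the query/key dimension, $d_k$ in the original paper), independence and zero means give $Var(q . k) = \sum_i Var(q_i k_i) = \sum_i Var(q_i) Var(k_i) = \sum_i 1 . 1 = EmbeddingDim$. Dividing by $\sqrt{EmbeddingDim}$ scales the variance back to 1, which keeps the softmax stable: on large inputs it saturates towards a one-hot distribution and its gradients vanish.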
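A minimal NumPy sketch of the scaled formula, plus an empirical check of the variance argument. The helper names `softmax` and `attention` are our own, and the usage at the end skips the $Query$/$Key$/$Value$ projections (an assumption for brevity), feeding the same random embeddings in as $q$, $k$ and $v$.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtracting the row max avoids overflow in exp and leaves the result unchanged
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(q, k, v):
    # q: (n_q, d), k: (n_k, d), v: (n_k, d_v) -> output: (n_q, d_v)
    d = q.shape[-1]
    scores = q @ k.T / np.sqrt(d)  # scaled dot-product similarities, (n_q, n_k)
    return softmax(scores, axis=-1) @ v

rng = np.random.default_rng(0)

# Empirical check of the variance argument with i.i.d. N(0, 1) entries
d = 64
q = rng.standard_normal((10_000, d))
k = rng.standard_normal((10_000, d))
dots = (q * k).sum(axis=1)         # 10,000 independent dot products q . k
print(dots.var())                  # close to d = 64
print((dots / np.sqrt(d)).var())   # close to 1 after scaling

# Usage: self-attention on 4 hypothetical 8-dim token embeddings (s_q = s_k = s_v)
x = rng.standard_normal((4, 8))
out = attention(x, x, x)
print(out.shape)                   # (4, 8)
```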

$s_q, s_k, s_v$ are respectively the input "sources" for the Query, Key and Value matrices.

Commonly $s_k = s_v = t$; $s_q$ is then referred to as the query source $s$, whereas $t$ is the target source.

Why does attention take this form? It is a set of **ad hoc** assumptions built on top of **well-founded intuition**. \
**As far as I know**, there is no established explanation yet for why it works so well on arbitrary distributions.

On the **well-founded intuition** part, the whole thing is not that crazy: the key is to define mathematically what "similarity" means. \
For vectors, this can be defined in many ways: cosine similarity, dot product, absolute (L1) distance, etc.
For attention, we commonly use the dot product.
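To make the choice concrete, here is a small NumPy comparison of three of these measures on toy vectors (the vectors and the helpers `dot`, `cosine`, `l1_distance` are made up for illustration). The dot product grows with vector length, cosine similarity ignores length, and L1 distance goes the other way: smaller means more similar.

```python
import numpy as np

def dot(u, v):
    return float(u @ v)

def cosine(u, v):
    # dot product of the normalized vectors: ignores length, only direction counts
    return float(u @ v / (np.linalg.norm(u) * np.linalg.norm(v)))

def l1_distance(u, v):
    # a distance, not a similarity: smaller means more alike
    return float(np.abs(u - v).sum())

a = np.array([1.0, 2.0, 3.0])
b = np.array([2.0, 4.0, 6.0])  # same direction as a, twice as long
c = np.array([3.0, 2.0, 1.0])  # different direction

for name, other in [("b", b), ("c", c)]:
    print(name, dot(a, other), round(cosine(a, other), 3), l1_distance(a, other))
```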

$Q$ is a matrix embedding of the query source: each row is the query vector of one token, such that when we compute the similarity of the \
i-th row with the keys, $Q_i . K^T$, we get a row of numbers whose j-th entry is the affinity between query token i and key token j. We then do this for every row... \
All these ops are compactified in the single matrix product $Q . K^T$.
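A quick NumPy check that the row-by-row computation and the compact matrix form agree, using random matrices with hypothetical sizes (4 queries, 3 keys, dimension 8):

```python
import numpy as np

rng = np.random.default_rng(42)
Q = rng.standard_normal((4, 8))  # 4 query tokens, 8-dim (hypothetical sizes)
K = rng.standard_normal((3, 8))  # 3 key tokens

# Row by row: query i against every key gives one affinity per key
rows = [Q[i] @ K.T for i in range(Q.shape[0])]  # each row has shape (3,)
loop_scores = np.stack(rows)                    # (4, 3)

# Compact matrix form: all rows at once
matrix_scores = Q @ K.T                         # (4, 3)

print(np.allclose(loop_scores, matrix_scores))  # True
```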

When the query and the key come from the same source, it is self-attention (used both in encoders and in decoder-only models such as GPT); when they come from different sources, it is cross-attention, which links the decoder to the encoder in an encoder/decoder architecture. \
There we compute the similarity between a source and a target source, which in some sense induces an embedding of a translation from source to target; \
this is the machine-translation setting of the [Attention is All You Need](https://arxiv.org/abs/1706.03762) paper.

Why go to all the trouble of intermediary embeddings $Q$, $K$, and $V$? \
The short answer: we want to generalize, so instead of comparing the raw sources directly, we keep them private and delegate their representation to learned intermediary vector spaces. \
This whole scheme can also be viewed as a lossy compression of the sources: $Q$ compresses the query source, $K$ and $V$ the target source.
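A minimal sketch of these intermediary spaces, assuming (as in the original Transformer) that $Query$, $Key$ and $Value$ are learned linear projections. The weights here are random stand-ins for trained ones, and the sizes `d_model` and `d_k` are hypothetical; choosing `d_k < d_model` mirrors the compression view.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_k = 16, 8  # hypothetical sizes

# Hypothetical embedded sources: 4 English tokens (query source), 3 Japanese tokens (target source)
s = rng.standard_normal((4, d_model))
t = rng.standard_normal((3, d_model))

# Projection matrices: random here, learned during training in practice
W_q = rng.standard_normal((d_model, d_k))
W_k = rng.standard_normal((d_model, d_k))
W_v = rng.standard_normal((d_model, d_k))

Q = s @ W_q  # (4, d_k): query source mapped into the query space
K = t @ W_k  # (3, d_k): target source mapped into the key space
V = t @ W_v  # (3, d_k): target source mapped into the value space

print(Q.shape, K.shape, V.shape)  # (4, 8) (3, 8) (3, 8)
```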

In particular, the $softmax(Query . Key^T)$ component can be viewed as a weighted directed graph:
* Each token is a node.
* The weight of the edge from token i to token j quantifies how much token i "cares" about token j; \
  it depends on both endpoints (the query token and the key token), and the outgoing weights of each node sum to 1.
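The graph view can be made explicit with the toy self-attention weights used in the next cell: every matrix entry becomes a weighted edge, and we print each node's strongest outgoing edge. This is only a sketch; the `edges` list is one possible representation among many.

```python
import numpy as np

tokens = ["nah", "I", "d", "win"]
# Row-stochastic attention matrix: row = source node, column = target node
A = np.array([
    [0.1, 0.2, 0.3, 0.4],
    [0.1, 0.4, 0.3, 0.2],
    [0.3, 0.2, 0.2, 0.3],
    [0.3, 0.5, 0.1, 0.1],
])

# Directed weighted edges: (source token, target token, weight A[i, j])
edges = [(tokens[i], tokens[j], A[i, j])
         for i in range(len(tokens)) for j in range(len(tokens))]

# Each node's outgoing weights form a probability distribution
assert np.allclose(A.sum(axis=1), 1.0)

# Strongest outgoing edge per node (argmax picks the first one on ties)
for i, src in enumerate(tokens):
    j = int(A[i].argmax())
    print(f"{src} -> {tokens[j]} ({A[i, j]:.1f})")
```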


In [4]:
# Self-attention: the example attends to itself, i.e. s_q, s_k, s_v all come from the same source (the example)
softmax_qk = np.array([ 
    # nah  I   d    win
    [0.1, 0.2, 0.3, 0.4], # nah
    [0.1, 0.4, 0.3, 0.2], # I
    [0.3, 0.2, 0.2, 0.3], # d
    [0.3, 0.5, 0.1, 0.1], # win
])

debug_att(tokens_eng, softmax_qk, "I", "win")
debug_att(tokens_eng, softmax_qk, "win", "I")
debug_att(tokens_eng, softmax_qk, "d", "nah")
softmax_qk.sum(axis=1)

'I' attends 'win': 20.0%
'win' attends 'I': 50.0%
'd' attends 'nah': 30.0%


array([1., 1., 1., 1.])
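The cell above stops at $softmax(Query . Key^T)$. To finish the attention formula we multiply by the values; with made-up 2-dim value vectors (an assumption, not part of the original example), each output row is a weighted average of the value rows:

```python
import numpy as np

# Same toy attention weights as above (rows = queries, columns = keys)
A = np.array([
    [0.1, 0.2, 0.3, 0.4],  # nah
    [0.1, 0.4, 0.3, 0.2],  # I
    [0.3, 0.2, 0.2, 0.3],  # d
    [0.3, 0.5, 0.1, 0.1],  # win
])
# Hypothetical 2-dim value vectors for nah, I, d, win
V = np.array([
    [1.0, 0.0],
    [0.0, 1.0],
    [1.0, 1.0],
    [2.0, 0.0],
])
out = A @ V  # each output row is a weighted average of the rows of V
print(out)
```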

In [5]:
# Causal self-attention: self-attention with a mask

# Goal: a token at position t cannot attend to tokens at positions >= t + 1. This matters for autoregressive
# generative models like GPT; bidirectional encoders like BERT do not use a causal mask.
# In practice, it is just self-attention with a mask that hides the strictly upper-triangular part.
# This is done by setting the pre-softmax scores above the diagonal to -infinity: softmax maps them to 0
# and renormalizes the remaining finite components so that each row still sums to 1.

softmax_qk = np.array([ 
    # nah  I   d    win
    [1.0, 0.0, 0.0, 0.0], # nah
    [0.2, 0.8, 0.0, 0.0], # I
    [0.8, 0.1, 0.1, 0.0], # d
    [0.3, 0.5, 0.1, 0.1], # win
])

debug_att(tokens_eng, softmax_qk, "I", "win") # cannot see into the future
debug_att(tokens_eng, softmax_qk, "win", "I")
debug_att(tokens_eng, softmax_qk, "d", "nah")
softmax_qk.sum(axis=1)

'I' attends 'win': 0.0%
'win' attends 'I': 50.0%
'd' attends 'nah': 80.0%


array([1., 1., 1., 1.])
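A minimal sketch of the masking step described at the top of this cell, starting from raw scores rather than hand-written probabilities. The embeddings are random stand-ins, and the $Q$/$K$ projections are skipped for brevity (an assumption).

```python
import numpy as np

def softmax(x, axis=-1):
    # the diagonal is never masked, so each row max is finite
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)  # exp(-inf) == 0
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(1234)
n, d = 4, 8
x = rng.standard_normal((n, d))  # hypothetical embeddings for nah, I, d, win
scores = x @ x.T / np.sqrt(d)    # raw self-attention scores

# True strictly above the diagonal, i.e. key position j > query position i
future = np.triu(np.ones((n, n), dtype=bool), k=1)
masked = np.where(future, -np.inf, scores)

weights = softmax(masked, axis=-1)
print(np.round(weights, 2))
```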

In [6]:
# Cross-attention: Query comes from the main source, Key (and Value) from the target source
softmax_qk = np.array([ 
    # いいえ  勝つ    さ
    [ 0.9, 0.05, 0.05],  # nah
    [ 0.3,  0.6,  0.1],  # I
    [0.33, 0.33, 0.33],  # d  (rounded 1/3 each, hence the 0.99 row sum in the output)
    [0.15,  0.8, 0.05],  # win
])

debug_x_att(tokens_eng, tokens_jp, softmax_qk, "I", "いいえ")
debug_x_att(tokens_eng, tokens_jp, softmax_qk, "win", "勝つ")
debug_x_att(tokens_eng, tokens_jp, softmax_qk, "I", "勝つ")
softmax_qk.sum(axis=1)

'I' attends 'いいえ': 30.0%
'win' attends '勝つ': 80.0%
'I' attends '勝つ': 60.0%


array([1.  , 1.  , 0.99, 1.  ])
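The same computation with random stand-in embeddings, showing the shapes involved when the two sources have different lengths. As a simplification (an assumption), the projections are skipped, so the target embeddings serve directly as keys and values:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
d = 8
src = rng.standard_normal((4, d))  # hypothetical embeddings: nah, I, d, win (query source)
tgt = rng.standard_normal((3, d))  # hypothetical embeddings: いいえ, 勝つ, さ (key/value source)

weights = softmax(src @ tgt.T / np.sqrt(d))  # (4, 3): one row per English token
out = weights @ tgt                          # (4, d): Japanese info gathered per English token
print(weights.shape, out.shape)              # (4, 3) (4, 8)
```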

In [None]:
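# Flash attention

# Mathematically this is still exact attention; the trick is in how the ops are laid out on the GPU:
# keys/values are processed in tiles with an online softmax, so the full N x N score matrix
# is never written to GPU high-bandwidth memory, which cuts memory traffic rather than FLOPs
# Original: https://arxiv.org/abs/2205.14135
# Follow-up improvement (FlashAttention-2): https://arxiv.org/abs/2307.08691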
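The core idea behind the tiling can be sketched in NumPy: keys and values are consumed block by block with an online softmax (a running row max and a running denominator), and the result matches naive attention. This only illustrates the math; it does not model GPU memory, the backward pass, or the actual kernels, and the `block` size is arbitrary.

```python
import numpy as np

def naive_attention(q, k, v):
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ v

def tiled_attention(q, k, v, block=2):
    # Consumes keys/values in blocks with an online softmax,
    # so the full (n_q, n_k) score matrix is never held at once.
    n_q, d = q.shape
    row_max = np.full(n_q, -np.inf)    # running max of scores per query
    denom = np.zeros(n_q)              # running softmax denominator
    acc = np.zeros((n_q, v.shape[1]))  # running unnormalized output
    for start in range(0, k.shape[0], block):
        kb, vb = k[start:start + block], v[start:start + block]
        s = q @ kb.T / np.sqrt(d)                    # scores for this tile
        new_max = np.maximum(row_max, s.max(axis=1))
        scale = np.exp(row_max - new_max)            # rescales earlier partial sums (0 on first tile)
        p = np.exp(s - new_max[:, None])
        denom = denom * scale + p.sum(axis=1)
        acc = acc * scale[:, None] + p @ vb
        row_max = new_max
    return acc / denom[:, None]

rng = np.random.default_rng(0)
q = rng.standard_normal((4, 8))
k = rng.standard_normal((7, 8))  # 7 keys: the last block is smaller than the others
v = rng.standard_normal((7, 8))
print(np.allclose(naive_attention(q, k, v), tiled_attention(q, k, v)))  # True
```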