
Source: [Understanding and Coding Self-Attention, Multi-Head Attention, Cross-Attention, and Causal-Attention in LLMs](https://magazine.sebastianraschka.com/p/understanding-and-coding-self-attention)

Date: 2024-01-23


Restrict the dictionary to the words of the input sentence for simplicity.

Typical vocabulary sizes range from 30k to 50k entries.

In [12]:
sentence = 'Life is short, eat dessert first'

dc = {s:i for i,s 
      in enumerate(sorted(sentence.replace(',', '').split()))}

print(dc)

{'Life': 0, 'dessert': 1, 'eat': 2, 'first': 3, 'is': 4, 'short': 5}


Encode the input sentence as a tensor of token ids:

In [13]:
import torch

sentence_int = torch.tensor(
    [dc[s] for s in sentence.replace(',', '').split()]
)
print(sentence_int)

tensor([0, 4, 5, 2, 1, 3])


Embed each token in `3` dimensions using a randomly initialised embedding layer (real models typically use embedding sizes on the order of `10^2` or `10^3`):

In [16]:
vocab_size = 50_000

torch.manual_seed(123)
embed = torch.nn.Embedding(vocab_size, 3)
embedded_sentence = embed(sentence_int).detach()

print(embedded_sentence)
print(embedded_sentence.shape)

tensor([[ 0.3374, -0.1778, -0.3035],
        [ 0.1794,  1.8951,  0.4954],
        [ 0.2692, -0.0770, -1.0205],
        [-0.2196, -0.3792,  0.7671],
        [-0.5880,  0.3486,  0.6603],
        [-1.1925,  0.6984, -1.4097]])
torch.Size([6, 3])
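
As a small sanity check of what the embedding layer does, the sketch below shows that calling the layer is a row lookup in its weight matrix. The names `via_layer` and `via_indexing` are illustrative and not part of the original notebook; the seed and sizes are copied from the cell above.

```python
import torch

torch.manual_seed(123)
embed = torch.nn.Embedding(50_000, 3)            # same sizes as in the cell above
sentence_int = torch.tensor([0, 4, 5, 2, 1, 3])  # token ids from the encoding step

# Calling the layer selects one row of the weight matrix per token id.
via_layer = embed(sentence_int).detach()
via_indexing = embed.weight[sentence_int].detach()

print(torch.equal(via_layer, via_indexing))  # True
print(via_layer.shape)                       # torch.Size([6, 3])
```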


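Project each embedding to a `2`-dimensional *query* vector and *key* vector and to a `4`-dimensional *value* vector. Queries and keys must have the same size (`d_q = d_k`) because they are combined through a dot product; `d_v` can be chosen freely.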

In [17]:
torch.manual_seed(123)

d = embedded_sentence.shape[1]

d_q, d_k, d_v = 2, 2, 4

W_query = torch.nn.Parameter(torch.rand(d, d_q))
W_key = torch.nn.Parameter(torch.rand(d, d_k))
W_value = torch.nn.Parameter(torch.rand(d, d_v))

Compute the query, key, and value vectors for the second input element:

In [19]:
x_2 = embedded_sentence[1]
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value

print(query_2.shape)
print(key_2.shape)
print(value_2.shape)

torch.Size([2])
torch.Size([2])
torch.Size([4])


Generalize this to compute the key and value vectors for all inputs:

In [21]:
keys = embedded_sentence @ W_key
values = embedded_sentence @ W_value

print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 4])
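
The batched product above is just the per-token projection applied to every row at once. A minimal sketch of that equivalence, using a random stand-in `x` for `embedded_sentence` (an assumption made so the snippet runs on its own; the numbers differ from the notebook's):

```python
import torch

torch.manual_seed(123)
x = torch.randn(6, 3)    # stand-in for embedded_sentence: 6 tokens, d = 3
W_key = torch.rand(3, 2)  # d x d_k, as in the notebook

# One matrix product projects all tokens at once ...
keys = x @ W_key

# ... and gives the same rows as projecting each token separately.
keys_loop = torch.stack([x_i @ W_key for x_i in x])

print(keys.shape)                                  # torch.Size([6, 2])
print(torch.allclose(keys, keys_loop, atol=1e-6))  # True
```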


Compute the unnormalized attention weight between the query of the second input element and the key of the 5th input element:

In [24]:
omega_24 = query_2.dot(keys[4])
print(omega_24)

tensor(1.2903, grad_fn=<DotBackward0>)


Compute the ω values for all input tokens:

In [25]:
omega_2 = query_2 @ keys.T
print(omega_2)

tensor([-0.6004,  3.4707, -1.5023,  0.4991,  1.2903, -1.3374],
       grad_fn=<SqueezeBackward4>)
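
The same computation extends to every query at once. The sketch below builds the full score matrix, where entry `[i, j]` is the dot product of query `i` and key `j`. The tensor `x` is a random stand-in for `embedded_sentence`, and `queries` and `omega` are names assumed here for illustration.

```python
import torch

torch.manual_seed(123)
x = torch.randn(6, 3)       # stand-in for embedded_sentence: 6 tokens, d = 3
W_query = torch.rand(3, 2)  # d x d_q
W_key = torch.rand(3, 2)    # d x d_k

queries = x @ W_query
keys = x @ W_key

# omega[i, j] = queries[i] . keys[j]
omega = queries @ keys.T

# Row 1 of the matrix is the vector computed above for the second token.
omega_2 = queries[1] @ keys.T

print(omega.shape)                                  # torch.Size([6, 6])
print(torch.allclose(omega[1], omega_2, atol=1e-6))  # True
```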


Normalize the attention weights by scaling ω with `1/sqrt(d_k)` and applying softmax:

In [26]:
import torch.nn.functional as F

attention_weights_2 = F.softmax(omega_2 / d_k**0.5, dim=0)
print(attention_weights_2)

tensor([0.0386, 0.6870, 0.0204, 0.0840, 0.1470, 0.0229],
       grad_fn=<SoftmaxBackward0>)
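
To make the normalization step concrete, here is the scaled softmax written out by hand. The ω values are copied from the printed output above (rounded to four decimals), so the result should match the printed weights only up to rounding; `manual` and `builtin` are illustrative names.

```python
import torch
import torch.nn.functional as F

# Unnormalized weights copied from the printed output above.
omega_2 = torch.tensor([-0.6004, 3.4707, -1.5023, 0.4991, 1.2903, -1.3374])
d_k = 2

# Scaled softmax by hand: exponentiate the scaled scores, then normalize.
scaled = omega_2 / d_k**0.5
manual = torch.exp(scaled) / torch.exp(scaled).sum()

builtin = F.softmax(scaled, dim=0)

print(torch.allclose(manual, builtin))  # True
print(manual.sum())                     # sums to 1, up to floating point error
```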


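Compute the context vector for the second input element as the attention-weighted sum of the value vectors:

In [27]: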
context_vector_2 = attention_weights_2 @ values

print(context_vector_2.shape)
print(context_vector_2)

torch.Size([4])
tensor([0.5313, 1.3607, 0.7891, 1.3110], grad_fn=<SqueezeBackward4>)
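
Putting the steps together, the following is a minimal sketch of scaled dot-product self-attention for all tokens at once. The function name `self_attention` and the random stand-in input `x` are assumptions made for illustration; this is not the implementation from the linked article.

```python
import torch
import torch.nn.functional as F

def self_attention(x, W_query, W_key, W_value):
    """Scaled dot-product self-attention for one sequence x of shape (n, d)."""
    queries = x @ W_query                           # (n, d_q)
    keys = x @ W_key                                # (n, d_k)
    values = x @ W_value                            # (n, d_v)
    d_k = keys.shape[-1]
    scores = queries @ keys.T                       # (n, n) unnormalized weights
    weights = F.softmax(scores / d_k**0.5, dim=-1)  # each row sums to 1
    return weights @ values, weights                # context vectors: (n, d_v)

torch.manual_seed(0)
x = torch.randn(6, 3)  # stand-in for embedded_sentence
W_query, W_key, W_value = torch.rand(3, 2), torch.rand(3, 2), torch.rand(3, 4)

context, weights = self_attention(x, W_query, W_key, W_value)
print(context.shape)  # torch.Size([6, 4])
print(weights.shape)  # torch.Size([6, 6])
```

Row `1` of `context` corresponds to `context_vector_2` in the step-by-step computation, here evaluated on the stand-in input.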


cont. https://magazine.sebastianraschka.com/i/140464659/self-attention