# Self-Attention



Start with our token embeddings.

For this example we use 3-dimensional embeddings with arbitrary values chosen purely for illustration. We also treat each word as a single token rather than splitting words into smaller sub-word tokens.

In practice, tokens are often sub-words and embeddings have many more dimensions. The largest GPT-3 model used over 12k dimensions (12,288) for its token embeddings.

In [1]:
import torch

inputs = torch.tensor(
    [[0.21, 0.47, 0.91], # I
     [0.52, 0.11, 0.65], # can't
     [0.03, 0.85, 0.19], # find
     [0.73, 0.64, 0.39], # the
     [0.13, 0.55, 0.68], # light
     [0.22, 0.77, 0.08]] # switch
)

In [27]:
inputs.shape

torch.Size([6, 3])

### Step 1 - Compute Attention Scores

Each token in our 6-word context window above needs to know how much it should "pay attention" to the other tokens in the sequence.

In other words: how much does one token affect another token's meaning, or change its context?

To tell our model how much attention a token should give to the other tokens in the sequence, we first compute attention scores and then normalize them into attention weights.

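##### Computing Attention Scores
For every token in our context window, we take that token as our "query" and compute a dot product between the query and every token in the context window, including the query token itself.

The dot product is a measure of similarity between the tokens: the more closely two embedding vectors point in the same direction, and the larger their magnitudes, the larger their dot product.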
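To make the dot product concrete, here is a small sketch that computes it by hand for the "light" and "I" embeddings from the tensor above and compares the result with `torch.dot`. The variable names `light`, `i_token`, `manual`, and `builtin` are illustrative only.

```python
import torch

# Embeddings copied from the inputs tensor above
light = torch.tensor([0.13, 0.55, 0.68])    # "light"
i_token = torch.tensor([0.21, 0.47, 0.91])  # "I"

# A dot product is the sum of the element-wise products
manual = (light * i_token).sum()
builtin = torch.dot(light, i_token)

print(manual, builtin)  # both approximately 0.9046
```

Working it out on paper gives the same value: 0.13 * 0.21 + 0.55 * 0.47 + 0.68 * 0.91 = 0.0273 + 0.2585 + 0.6188 = 0.9046.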

In [None]:
# Computing attention scores for only one token.
# Example taking the token "light" as the query
query = inputs[4]

# Create empty tensor, sized for each token in our context window
attention_scores_for_light = torch.empty(inputs.shape[0])

# Compute attention scores for each token in the context window
for i, token_embedding in enumerate(inputs):
    attention_scores_for_light[i] = torch.dot(query, token_embedding)

In [3]:
attention_scores_for_light

tensor([0.9046, 0.5701, 0.6006, 0.7121, 0.7818, 0.5065])
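The loop above can also be written as a single matrix-vector product. This is a minimal sketch, assuming the same `inputs` tensor as above; `scores_loop` and `scores_vectorized` are names chosen here for illustration.

```python
import torch

inputs = torch.tensor(
    [[0.21, 0.47, 0.91], # I
     [0.52, 0.11, 0.65], # can't
     [0.03, 0.85, 0.19], # find
     [0.73, 0.64, 0.39], # the
     [0.13, 0.55, 0.68], # light
     [0.22, 0.77, 0.08]] # switch
)
query = inputs[4]  # "light"

# Loop version, as in the cell above
scores_loop = torch.empty(inputs.shape[0])
for i, token_embedding in enumerate(inputs):
    scores_loop[i] = torch.dot(query, token_embedding)

# Vectorized version: every row of inputs is dotted with the query at once
scores_vectorized = inputs @ query

print(torch.allclose(scores_loop, scores_vectorized))
```

The vectorized form avoids the Python loop and is the stepping stone to computing the scores for all tokens in one matrix multiplication.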

### Step 2 - Normalize the Attention Scores to Weights

We want the scores to sum to 1, so we normalize them with the softmax function to create our attention weights.

In [5]:
import torch.nn.functional as F

attention_weights_for_light = F.softmax(attention_scores_for_light, dim=0)
sum(attention_weights_for_light)

tensor(1.0000)
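Softmax exponentiates each score and divides by the sum of the exponentials, which yields positive weights that sum to 1. The sketch below writes that out by hand and compares it with `F.softmax`. The helper name `naive_softmax` is hypothetical, and this naive form skips the numerical-stability measures (such as subtracting the maximum score first) that library implementations typically apply.

```python
import torch
import torch.nn.functional as F

def naive_softmax(x):
    # exp(x_i) / sum_j exp(x_j); adequate for small values like ours
    exp_x = torch.exp(x)
    return exp_x / exp_x.sum(dim=0)

# Attention scores for "light", copied from the output above
scores = torch.tensor([0.9046, 0.5701, 0.6006, 0.7121, 0.7818, 0.5065])

manual_weights = naive_softmax(scores)
library_weights = F.softmax(scores, dim=0)

print(manual_weights)
print(torch.allclose(manual_weights, library_weights))
```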

### Step 3 - Compute the Context Vector
Multiply each token's input vector by its attention weight with respect to a given query token, then sum the results to get that query token's context vector (simplified version).

In [15]:
context_vector_for_light = torch.zeros_like(query)

for i, token_embedding in enumerate(inputs):
    context_vector_for_light += attention_weights_for_light[i] * token_embedding

context_vector_for_light

tensor([0.3039, 0.5600, 0.5155])
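The weighted sum in the loop above is equivalent to a vector-matrix product. A minimal sketch, assuming the same `inputs` tensor and "light" as the query; `context_loop` and `context_vectorized` are illustrative names.

```python
import torch
import torch.nn.functional as F

inputs = torch.tensor(
    [[0.21, 0.47, 0.91], # I
     [0.52, 0.11, 0.65], # can't
     [0.03, 0.85, 0.19], # find
     [0.73, 0.64, 0.39], # the
     [0.13, 0.55, 0.68], # light
     [0.22, 0.77, 0.08]] # switch
)
query = inputs[4]  # "light"
weights = F.softmax(inputs @ query, dim=0)

# Loop version: accumulate weight_i * embedding_i
context_loop = torch.zeros_like(query)
for i, token_embedding in enumerate(inputs):
    context_loop += weights[i] * token_embedding

# Vectorized version: shape [6] @ shape [6, 3] -> shape [3]
context_vectorized = weights @ inputs

print(torch.allclose(context_loop, context_vectorized))
```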

This is one simplified example of updating a token's input embedding with contextual information: the resulting context vector "attends to" every token in the context sequence, weighted by how similar each one is to the query.

# Extending this to All Tokens
1) Compute attention scores
2) Compute attention weights
3) Compute context vectors

In [31]:
inputs.shape

torch.Size([6, 3])

In [32]:
inputs.T.shape

torch.Size([3, 6])

In [None]:
# Compute the dot product of the input embeddings with themselves
# Matrix multiplication (@) of the input embeddings with their transpose (.T)
attention_scores = inputs @ inputs.T
attention_scores

tensor([[1.0931, 0.7524, 0.5787, 0.8090, 0.9046, 0.4809],
        [0.7524, 0.7050, 0.2326, 0.7035, 0.5701, 0.2511],
        [0.5787, 0.2326, 0.7595, 0.6400, 0.6006, 0.6763],
        [0.8090, 0.7035, 0.6400, 1.0946, 0.7121, 0.6846],
        [0.9046, 0.5701, 0.6006, 0.7121, 0.7818, 0.5065],
        [0.4809, 0.2511, 0.6763, 0.6846, 0.5065, 0.6477]])

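Each row above is the attention score vector for the corresponding token in our context window. For example, the fifth row (index 4) matches the scores we computed for "light" earlier.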
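Two quick sanity checks on this matrix, sketched under the assumption that `inputs` is defined as above: since the score between two tokens is the same dot product in either order, the matrix should be symmetric, and row 4 should reproduce the scores computed with "light" as the query. The names `is_symmetric` and `light_row_matches` are illustrative.

```python
import torch

inputs = torch.tensor(
    [[0.21, 0.47, 0.91], # I
     [0.52, 0.11, 0.65], # can't
     [0.03, 0.85, 0.19], # find
     [0.73, 0.64, 0.39], # the
     [0.13, 0.55, 0.68], # light
     [0.22, 0.77, 0.08]] # switch
)
attention_scores = inputs @ inputs.T

# dot(a, b) == dot(b, a), so the score matrix equals its transpose
is_symmetric = torch.allclose(attention_scores, attention_scores.T)

# Row 4 holds the scores with "light" as the query
light_row_matches = torch.allclose(attention_scores[4], inputs @ inputs[4])

print(is_symmetric, light_row_matches)
```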

In [None]:
# dim=-1 applies softmax along the last dimension of attention_scores,
# i.e. across the columns within each row, so each token's row of scores
# is normalized independently and sums to 1
attention_weights = F.softmax(attention_scores, dim=-1)
attention_weights

tensor([[0.2256, 0.1605, 0.1349, 0.1698, 0.1869, 0.1223],
        [0.2024, 0.1931, 0.1204, 0.1928, 0.1687, 0.1226],
        [0.1641, 0.1161, 0.1966, 0.1745, 0.1677, 0.1809],
        [0.1705, 0.1534, 0.1440, 0.2268, 0.1547, 0.1505],
        [0.2068, 0.1480, 0.1526, 0.1706, 0.1829, 0.1389],
        [0.1552, 0.1233, 0.1887, 0.1902, 0.1592, 0.1834]])
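It is worth confirming that the normalization ran along the intended dimension. In this sketch, which assumes the same `inputs` tensor as above, each row of the weight matrix should sum to 1, while the columns generally do not. `row_sums` and `col_sums` are names chosen here for illustration.

```python
import torch
import torch.nn.functional as F

inputs = torch.tensor(
    [[0.21, 0.47, 0.91], # I
     [0.52, 0.11, 0.65], # can't
     [0.03, 0.85, 0.19], # find
     [0.73, 0.64, 0.39], # the
     [0.13, 0.55, 0.68], # light
     [0.22, 0.77, 0.08]] # switch
)
attention_weights = F.softmax(inputs @ inputs.T, dim=-1)

# Sum across the columns of each row: one total per query token
row_sums = attention_weights.sum(dim=-1)

# Sum down each column: these are not normalized by softmax(dim=-1)
col_sums = attention_weights.sum(dim=0)

print(row_sums)
print(col_sums)
```

If the row sums were not all 1, that would suggest softmax had been applied along the wrong dimension.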

In [35]:
context_vectors = attention_weights @ inputs
context_vectors

tensor([[0.3101, 0.5440, 0.5383],
        [0.3362, 0.5293, 0.5323],
        [0.2897, 0.6003, 0.4588],
        [0.3387, 0.5656, 0.4880],
        [0.3039, 0.5600, 0.5155],
        [0.3023, 0.5974, 0.4544]])

In [36]:
attention_weights.shape, inputs.shape, context_vectors.shape

(torch.Size([6, 6]), torch.Size([6, 3]), torch.Size([6, 3]))
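The three steps can be collected into one small function. This is a sketch of the simplified procedure shown in this notebook only; the function name `simple_self_attention` is hypothetical, and it uses the raw embeddings directly with no learned parameters.

```python
import torch
import torch.nn.functional as F

def simple_self_attention(x):
    """Simplified self-attention over a [num_tokens, embed_dim] tensor."""
    attention_scores = x @ x.T                               # step 1: [n, n]
    attention_weights = F.softmax(attention_scores, dim=-1)  # step 2: rows sum to 1
    return attention_weights @ x                             # step 3: [n, embed_dim]

inputs = torch.tensor(
    [[0.21, 0.47, 0.91], # I
     [0.52, 0.11, 0.65], # can't
     [0.03, 0.85, 0.19], # find
     [0.73, 0.64, 0.39], # the
     [0.13, 0.55, 0.68], # light
     [0.22, 0.77, 0.08]] # switch
)

context_vectors = simple_self_attention(inputs)
print(context_vectors.shape)
print(context_vectors[4])  # context vector for "light"
```

The output has the same shape as the input, so each token's embedding is replaced by a context vector of the same size.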
