In [1]:
# Load corporate proxy configuration
import sys
sys.path.insert(0, '..')
try:
    from _proxy_config import *
except ImportError:
    print("Warning: _proxy_config.py not found. Proxy settings may not be configured.")
except Exception as e:
    print(f"Error loading proxy configuration: {e}")

In [2]:
import torch
import mermaid as md
from mermaid.graph import Graph
import inspect
from build_llm_module.SimpleTokenizer import SimpleTokenizer
from build_llm_module.SelfAttention import SelfAttention

# Attention mechanisms

Before modern LLMs, language tasks were performed by other neural network architectures, such as recurrent neural networks (RNNs). Some of these tasks, such as translation, relied on two parts: an encoder and a decoder.

The encoder would be fed a text input in the source language and produce a hidden state. That hidden state would then be fed into the decoder, which turned it into text in the target language.

However, text translations are not done by simply translating word after word. Different languages have different grammatical structures, where sentence entities can come in different orders and sometimes a word in a language could need multiple words in another language (and vice-versa).

RNNs were good enough for short sentences. However, the decoder only received the encoder's hidden state, not the individual earlier words of the input, which made them unsuitable for longer texts. The Bahdanau attention mechanism was proposed in 2014 to fix that: it gives the RNN decoder the ability to selectively access different parts of the input sequence at each decoding step.

## 1. Self-attention

Bahdanau attention inspired the self-attention mechanism used in the transformer architecture.

Self-attention gives the model the ability to compute attention weights that relate different positions within the same input sequence. In traditional attention mechanisms, the relationships were between elements of two different sequences, for example between an input sequence and an output sequence.

In [3]:
# Mermaid block diagram sketching self-attention: each token of the example sentence maps to a
# 3-D embedding (placeholder numbers), and every embedding feeds into the context vector z2.
SELF_ATTENTION_DIAGRAM = """
block
    columns 1
                 
    block:input:1
        columns 6
        x1["The"] x2["dog"] x3["attacks"] x4["the"] x5["wild"] x6["cat"]
    end
    space
    block:embeddings:1
        block:embedding1:1
            columns 3
            x1d1["0.1"] x1d2["0.2"] x1d3["0.3"]
        end
        block:embedding2:1
            columns 3
            x2d1["0.4"] x2d2["0.1"] x2d3["0.2"]
        end
        block:embedding3:1
            columns 3
            x3d1["0.3"] x3d2["0.2"] x3d3["0.1"]
        end
        block:embedding4:1
            columns 3
            x4d1["0.2"] x4d2["0.3"] x4d3["0.1"]
        end
        block:embedding5:1
            columns 3
            x5d1["0.1"] x5d2["0.4"] x5d3["0.2"]
        end
        block:embedding6:1
            columns 3
            x6d1["0.3"] x6d2["0.1"] x6d3["0.4"]
        end
    end
    x1 --> embedding1
    x2 --> embedding2
    x3 --> embedding3
    x4 --> embedding4
    x5 --> embedding5
    x6 --> embedding6
    space
    block:context_vectors:1
        columns 6
        block:z1:1
            columns 3
            z1d1["0.2"] z1d2["0.3"] z1d3["0.4"]
        end
        block:z2:1
            columns 3
            z2d1["0.1"] z2d2["0.4"] z2d3["0.3"]
        end
        block:z3:1
            columns 3
            z3d1["0.4"] z3d2["0.2"] z3d3["0.1"]
        end
        block:z4:1
            columns 3
            z4d1["0.3"] z4d2["0.1"] z4d3["0.2"]
        end
        block:z5:1
            columns 3
            z5d1["0.2"] z5d2["0.4"] z5d3["0.1"]
        end
        block:z6:1
            columns 3
            z6d1["0.1"] z6d2["0.2"] z6d3["0.4"]
        end
    end
    embedding1 --> z2
    embedding2 --> z2
    embedding3 --> z2
    embedding4 --> z2
    embedding5 --> z2
    embedding6 --> z2

"""

# Theme for the rendered diagram.
DIAGRAM_CONFIG = {
    "theme": "base",
    "themeVariables": {
        "primaryColor": "#350F0F",
        "secondaryColor": "#e0e0e0",
        "tertiaryColor": "#d0d0d0",
        "lineColor": "#F8B229",
    },
}

sequence = Graph("Self-attention", SELF_ATTENTION_DIAGRAM, config=DIAGRAM_CONFIG)
render = md.Mermaid(sequence)
render

### 1.1. Simplified self-attention

What self-attention does is transform each input embedding (a vector of N dimensions) into another vector called a context vector. 

#### Context vectors

Context vectors can be defined as "enriched" versions of the input token embeddings. Each context vector is a representation of one of the embeddings in the input sequence, but it contains information about all other tokens.

#### Attention scores


The first step in calculating the context vectors is to calculate intermediate values called attention scores. For each embedded input token, the attention scores form a vector containing the dot products between its embedding and the embedding of every token in the sequence, including itself. The token for which the attention scores are being calculated is called the query token.

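Let's take this example, using the simple tokenizer and 3-dimensional embeddings. Note that this tokenizer keeps whitespace as separate tokens, so the six-word sentence becomes 11 tokens, and every `' '` token shares the same embedding: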

In [4]:
# Fix the RNG so the randomly initialised embedding table (and every number below) is reproducible.
torch.manual_seed(69)

input_sequence = "The dog attacks the wild cat"
simple_tokenizer = SimpleTokenizer(input_sequence)  # presumably builds its vocabulary from this text
tokens = SimpleTokenizer.tokenize(input_sequence)
print(f"Tokens:{tokens}")

simple_token_ids = torch.tensor(simple_tokenizer.encode(input_sequence))
embedding_layer = torch.nn.Embedding(simple_tokenizer.vocab_size, embedding_dim=3)
# detach() drops the autograd graph so the printed tensor carries no grad_fn clutter
embeddings = embedding_layer(simple_token_ids).detach()

print(f"Token ids:{simple_token_ids}", f"Embeddings:\n{embeddings}", sep="\n")

Tokens:['The', ' ', 'dog', ' ', 'attacks', ' ', 'the', ' ', 'wild', ' ', 'cat']
Token ids:tensor([1, 0, 4, 0, 2, 0, 5, 0, 6, 0, 3])
Embeddings:
tensor([[ 1.2221,  1.0395,  0.9608],
        [-0.5300, -1.3035,  0.4438],
        [ 0.6370,  1.3158, -0.4287],
        [-0.5300, -1.3035,  0.4438],
        [ 0.4214,  0.7452, -1.8389],
        [-0.5300, -1.3035,  0.4438],
        [ 1.9435, -0.8080, -0.8735],
        [-0.5300, -1.3035,  0.4438],
        [ 0.9367, -0.3077, -1.4196],
        [-0.5300, -1.3035,  0.4438],
        [-1.2497, -0.2485, -1.0530]])


The attention scores for the first token (embedding `[1.2221, 1.0395, 0.9608]`) can be calculated as follows:

In [5]:
e0 = embeddings[0]  # use the first token ("The") as the query
attention_scores = torch.zeros(len(embeddings))
for i, e in enumerate(embeddings):
    # dot product = sum of the element-wise products of the query and this token's embedding
    for q_j, e_j in zip(e0, e):
        attention_scores[i] += q_j * e_j

print(attention_scores)

tensor([ 3.4975, -1.5764,  1.7344, -1.5764, -0.4773, -1.5764,  0.6959, -1.5764,
        -0.5391, -1.5764, -2.7973])


PyTorch provides a dot-product function, `torch.dot`:

In [6]:
dot_scores = [torch.dot(e0, token_embedding) for token_embedding in embeddings]  # one score per token
print(torch.tensor(dot_scores))

tensor([ 3.4975, -1.5764,  1.7344, -1.5764, -0.4773, -1.5764,  0.6959, -1.5764,
        -0.5391, -1.5764, -2.7973])


Now let's get the attention scores for all the tokens:

In [7]:
# Row `row` holds the scores of token `row` (as query) against every token in the sequence.
n_tokens = len(embeddings)
attention_scores = torch.zeros(n_tokens, n_tokens)
for row in range(n_tokens):
    query = embeddings[row]
    attention_scores[row] = torch.tensor([torch.dot(query, other) for other in embeddings])
print(attention_scores)

tensor([[ 3.4975, -1.5764,  1.7344, -1.5764, -0.4773, -1.5764,  0.6959, -1.5764,
         -0.5391, -1.5764, -2.7973],
        [-1.5764,  2.1770, -2.2430,  2.1770, -2.0108,  2.1770, -0.3645,  2.1770,
         -0.7254,  2.1770,  0.5189],
        [ 1.7344, -2.2430,  2.3208, -2.2430,  2.0373, -2.2430,  0.5492, -2.2430,
          0.8003, -2.2430, -0.6715],
        [-1.5764,  2.1770, -2.2430,  2.1770, -2.0108,  2.1770, -0.3645,  2.1770,
         -0.7254,  2.1770,  0.5189],
        [-0.4773, -2.0108,  2.0373, -2.0108,  4.1144, -2.0108,  1.8230, -2.0108,
          2.7759, -2.0108,  1.2246],
        [-1.5764,  2.1770, -2.2430,  2.1770, -2.0108,  2.1770, -0.3645,  2.1770,
         -0.7254,  2.1770,  0.5189],
        [ 0.6959, -0.3645,  0.5492, -0.3645,  1.8230, -0.3645,  5.1931, -0.3645,
          3.3092, -0.3645, -1.3082],
        [-1.5764,  2.1770, -2.2430,  2.1770, -2.0108,  2.1770, -0.3645,  2.1770,
         -0.7254,  2.1770,  0.5189],
        [-0.5391, -0.7254,  0.8003, -0.7254,  2.7759, -0

Or, more succinctly, as a single matrix product:

In [8]:
# Entry (i, j) is the dot product of embedding i with embedding j.
attention_scores = torch.matmul(embeddings, embeddings.T)
print(attention_scores)

tensor([[ 3.4975, -1.5764,  1.7344, -1.5764, -0.4773, -1.5764,  0.6959, -1.5764,
         -0.5391, -1.5764, -2.7973],
        [-1.5764,  2.1770, -2.2430,  2.1770, -2.0108,  2.1770, -0.3645,  2.1770,
         -0.7254,  2.1770,  0.5189],
        [ 1.7344, -2.2430,  2.3208, -2.2430,  2.0373, -2.2430,  0.5492, -2.2430,
          0.8003, -2.2430, -0.6715],
        [-1.5764,  2.1770, -2.2430,  2.1770, -2.0108,  2.1770, -0.3645,  2.1770,
         -0.7254,  2.1770,  0.5189],
        [-0.4773, -2.0108,  2.0373, -2.0108,  4.1144, -2.0108,  1.8230, -2.0108,
          2.7759, -2.0108,  1.2246],
        [-1.5764,  2.1770, -2.2430,  2.1770, -2.0108,  2.1770, -0.3645,  2.1770,
         -0.7254,  2.1770,  0.5189],
        [ 0.6959, -0.3645,  0.5492, -0.3645,  1.8230, -0.3645,  5.1931, -0.3645,
          3.3092, -0.3645, -1.3082],
        [-1.5764,  2.1770, -2.2430,  2.1770, -2.0108,  2.1770, -0.3645,  2.1770,
         -0.7254,  2.1770,  0.5189],
        [-0.5391, -0.7254,  0.8003, -0.7254,  2.7759, -0

#### Attention weights

The next step is to transform the attention scores into attention **weights**. We can do this by normalizing the scores so that they sum up to 1:

In [9]:
atscs0 = attention_scores[0]  # scores of the first token against every token
# Naive normalisation: divide by the total. Negative scores (and a negative total) give
# negative "weights", which motivates softmax below.
atwts0 = atscs0 / atscs0.sum()
print(f"scores: {atscs0}", f"weights: {atwts0}", sep="\n")
print(f"sum: {atwts0.sum()}")

scores: tensor([ 3.4975, -1.5764,  1.7344, -1.5764, -0.4773, -1.5764,  0.6959, -1.5764,
        -0.5391, -1.5764, -2.7973])
weights: tensor([-0.6064,  0.2733, -0.3007,  0.2733,  0.0828,  0.2733, -0.1207,  0.2733,
         0.0935,  0.2733,  0.4850])
sum: 1.0


However, it is more advisable to normalize using the softmax function. It handles extreme values better and has more favorable gradient properties. It also guarantees that all weights are positive (the naive normalization above even produced negative weights), so the outputs can be interpreted as probabilities.

$\sigma(z)_i = \frac{e^{z_i}}{\sum_{j=1}^{K} e^{z_j}}$, where $K$ is the number of scores being normalized.

In [10]:
def softmax(scores: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """Softmax of ``scores`` along ``dim`` (default 0, as before).

    Implements sigma(z)_i = exp(z_i) / sum_j exp(z_j). The maximum is subtracted first:
    this does not change the result mathematically, but it keeps ``exp`` from
    overflowing to ``inf`` (and producing ``nan``) for large scores.
    """
    if scores.numel() == 0:
        # amax is undefined on empty tensors; there is nothing to normalise anyway.
        return torch.exp(scores)
    shifted = scores - scores.amax(dim=dim, keepdim=True)
    exps = torch.exp(shifted)  # computed once and reused for numerator and denominator
    return exps / exps.sum(dim=dim, keepdim=True)

atwts0_softmax = softmax(atscs0)
print(f"weights (softmax): {atwts0_softmax}", f"sum: {atwts0_softmax.sum()}", sep="\n")

weights (softmax): tensor([0.7682, 0.0048, 0.1318, 0.0048, 0.0144, 0.0048, 0.0466, 0.0048, 0.0136,
        0.0048, 0.0014])
sum: 1.0


Of course, PyTorch also has a built-in function for that:

In [11]:
atscs0.softmax(dim=0)  # Tensor method form of torch.softmax

tensor([0.7682, 0.0048, 0.1318, 0.0048, 0.0144, 0.0048, 0.0466, 0.0048, 0.0136,
        0.0048, 0.0014])

Let's get the attention weights for all the tokens:

In [12]:
# Normalise each row independently: one row of weights per query token.
attention_weights = torch.softmax(attention_scores, dim=1)
print(f"Attention weights:\n{attention_weights}")
row_sums = attention_weights.sum(dim=1)  # every entry should be 1
print(f"Sum sanity checks:\n{row_sums}")

Attention weights:
tensor([[0.7682, 0.0048, 0.1318, 0.0048, 0.0144, 0.0048, 0.0466, 0.0048, 0.0136,
         0.0048, 0.0014],
        [0.0044, 0.1861, 0.0022, 0.1861, 0.0028, 0.1861, 0.0147, 0.1861, 0.0102,
         0.1861, 0.0354],
        [0.1987, 0.0037, 0.3571, 0.0037, 0.2689, 0.0037, 0.0607, 0.0037, 0.0781,
         0.0037, 0.0179],
        [0.0044, 0.1861, 0.0022, 0.1861, 0.0028, 0.1861, 0.0147, 0.1861, 0.0102,
         0.1861, 0.0354],
        [0.0065, 0.0014, 0.0800, 0.0014, 0.6389, 0.0014, 0.0646, 0.0014, 0.1675,
         0.0014, 0.0355],
        [0.0044, 0.1861, 0.0022, 0.1861, 0.0028, 0.1861, 0.0147, 0.1861, 0.0102,
         0.1861, 0.0354],
        [0.0091, 0.0031, 0.0078, 0.0031, 0.0280, 0.0031, 0.8144, 0.0031, 0.1238,
         0.0031, 0.0012],
        [0.0044, 0.1861, 0.0022, 0.1861, 0.0028, 0.1861, 0.0147, 0.1861, 0.0102,
         0.1861, 0.0354],
        [0.0083, 0.0069, 0.0318, 0.0069, 0.2294, 0.0069, 0.3910, 0.0069, 0.2835,
         0.0069, 0.0213],
        [0.0044, 0

#### Calculating context vectors

Finally, having the weights in hand, we calculate the context vectors by multiplying the embedded input tokens with the corresponding attention weights, then summing the resulting vectors.

For example, using the first embedded token as query:

In [13]:
print(
    f"embeddings:\n{embeddings}",
    f"query: {e0}",
    f"weights(query): {atwts0_softmax}",
    sep="\n",
)

embeddings:
tensor([[ 1.2221,  1.0395,  0.9608],
        [-0.5300, -1.3035,  0.4438],
        [ 0.6370,  1.3158, -0.4287],
        [-0.5300, -1.3035,  0.4438],
        [ 0.4214,  0.7452, -1.8389],
        [-0.5300, -1.3035,  0.4438],
        [ 1.9435, -0.8080, -0.8735],
        [-0.5300, -1.3035,  0.4438],
        [ 0.9367, -0.3077, -1.4196],
        [-0.5300, -1.3035,  0.4438],
        [-1.2497, -0.2485, -1.0530]])
query: tensor([1.2221, 1.0395, 0.9608])
weights(query): tensor([0.7682, 0.0048, 0.1318, 0.0048, 0.0144, 0.0048, 0.0466, 0.0048, 0.0136,
        0.0048, 0.0014])


The embedding and its corresponding context vector have the same shape.

In [14]:
ctxtv0 = torch.zeros_like(e0)  # same shape and dtype as the query embedding
print(ctxtv0)

tensor([0., 0., 0.])


For one query (which in this case is the first embedding, e0), we loop through all the embeddings (including the query itself) and multiply each of them by the corresponding element of the query's attention weights vector. Then we sum the results:

In [15]:
# Re-initialise here so re-running this cell does not keep accumulating into ctxtv0.
ctxtv0 = torch.zeros_like(e0)
for i, ei in enumerate(embeddings):
    weighted = atwts0_softmax[i] * ei  # scale each embedding by its attention weight
    print(f"{atwts0_softmax[i]:.2f} * {ei}  => {weighted}")
    ctxtv0 += weighted
print(f"context vector (query) = {ctxtv0}")

0.77 * tensor([1.2221, 1.0395, 0.9608])  => tensor([0.9388, 0.7985, 0.7381])
0.00 * tensor([-0.5300, -1.3035,  0.4438])  => tensor([-0.0025, -0.0063,  0.0021])
0.13 * tensor([ 0.6370,  1.3158, -0.4287])  => tensor([ 0.0839,  0.1734, -0.0565])
0.00 * tensor([-0.5300, -1.3035,  0.4438])  => tensor([-0.0025, -0.0063,  0.0021])
0.01 * tensor([ 0.4214,  0.7452, -1.8389])  => tensor([ 0.0061,  0.0108, -0.0265])
0.00 * tensor([-0.5300, -1.3035,  0.4438])  => tensor([-0.0025, -0.0063,  0.0021])
0.05 * tensor([ 1.9435, -0.8080, -0.8735])  => tensor([ 0.0906, -0.0377, -0.0407])
0.00 * tensor([-0.5300, -1.3035,  0.4438])  => tensor([-0.0025, -0.0063,  0.0021])
0.01 * tensor([ 0.9367, -0.3077, -1.4196])  => tensor([ 0.0127, -0.0042, -0.0193])
0.00 * tensor([-0.5300, -1.3035,  0.4438])  => tensor([-0.0025, -0.0063,  0.0021])
0.00 * tensor([-1.2497, -0.2485, -1.0530])  => tensor([-0.0018, -0.0004, -0.0015])
context vector (query) = tensor([1.1176, 0.9091, 0.6043])


Generalizing for all the tokens:

In [16]:
context_vectors = torch.zeros_like(embeddings)  # one context vector per token, same size as the embeddings
print(context_vectors.shape)

torch.Size([11, 3])


In [17]:
# For each query i, the context vector is the attention-weighted sum of all embeddings.
# Each row is accumulated in a fresh local vector and then assigned, so re-running the
# cell gives the same result instead of adding onto the previous run.
for i in range(len(embeddings)):
    atwtsi = attention_weights[i]
    ctxtvi = torch.zeros_like(context_vectors[i])
    for j, ej in enumerate(embeddings):
        weight = atwtsi[j]
        weighted = weight * ej
        print(f"{weight:.2f} * {ej}  => {weighted}")
        ctxtvi += weighted
    context_vectors[i] = ctxtvi
    print(f"-----done for query = embeddings[{i}]")


0.77 * tensor([1.2221, 1.0395, 0.9608])  => tensor([0.9388, 0.7985, 0.7381])
0.00 * tensor([-0.5300, -1.3035,  0.4438])  => tensor([-0.0025, -0.0063,  0.0021])
0.13 * tensor([ 0.6370,  1.3158, -0.4287])  => tensor([ 0.0839,  0.1734, -0.0565])
0.00 * tensor([-0.5300, -1.3035,  0.4438])  => tensor([-0.0025, -0.0063,  0.0021])
0.01 * tensor([ 0.4214,  0.7452, -1.8389])  => tensor([ 0.0061,  0.0108, -0.0265])
0.00 * tensor([-0.5300, -1.3035,  0.4438])  => tensor([-0.0025, -0.0063,  0.0021])
0.05 * tensor([ 1.9435, -0.8080, -0.8735])  => tensor([ 0.0906, -0.0377, -0.0407])
0.00 * tensor([-0.5300, -1.3035,  0.4438])  => tensor([-0.0025, -0.0063,  0.0021])
0.01 * tensor([ 0.9367, -0.3077, -1.4196])  => tensor([ 0.0127, -0.0042, -0.0193])
0.00 * tensor([-0.5300, -1.3035,  0.4438])  => tensor([-0.0025, -0.0063,  0.0021])
0.00 * tensor([-1.2497, -0.2485, -1.0530])  => tensor([-0.0018, -0.0004, -0.0015])
-----done for query = embeddings[0]
0.00 * tensor([1.2221, 1.0395, 0.9608])  => tensor([0.005

Our final context vectors are:

In [18]:
print(f"all context vectors:\n{context_vectors}")
# Sanity check: row 0 should match ctxtv0. ctxtv0 was built from the custom softmax while
# context_vectors used torch.softmax, so compare with a tolerance instead of bitwise ==.
torch.isclose(context_vectors[0], ctxtv0).all()

all context vectors:
tensor([[ 1.1176,  0.9091,  0.6043],
        [-0.4914, -1.2268,  0.3463],
        [ 0.7425,  0.7750, -0.6312],
        [-0.4914, -1.2268,  0.3463],
        [ 0.5625,  0.4664, -1.5314],
        [-0.4914, -1.2268,  0.3463],
        [ 1.7167, -0.6763, -0.9275],
        [-0.4914, -1.2268,  0.3463],
        [ 1.1076, -0.2321, -1.1786],
        [-0.4914, -1.2268,  0.3463],
        [-0.6744, -0.4126, -0.7193]])


tensor(True)

More concisely, we could have done:

In [19]:
# (n_tokens, n_tokens) @ (n_tokens, d) -> (n_tokens, d): every weighted sum at once.
context_vectors = torch.matmul(attention_weights, embeddings)
print(context_vectors)

tensor([[ 1.1176,  0.9091,  0.6043],
        [-0.4914, -1.2268,  0.3463],
        [ 0.7425,  0.7750, -0.6312],
        [-0.4914, -1.2268,  0.3463],
        [ 0.5625,  0.4664, -1.5314],
        [-0.4914, -1.2268,  0.3463],
        [ 1.7167, -0.6763, -0.9275],
        [-0.4914, -1.2268,  0.3463],
        [ 1.1076, -0.2321, -1.1786],
        [-0.4914, -1.2268,  0.3463],
        [-0.6744, -0.4126, -0.7193]])


### 1.2. Self-attention with trainable weights

In order to make our self-attention mechanism trainable, we project each embedded input token into three different vectors: a query vector (q), a key vector (k) and a value vector (v). These vectors are obtained by multiplying the embedding by the trainable weight matrices $W_q$, $W_k$, and $W_v$, respectively.

#### Context vector for a single token

We will start by computing the q, k and v vectors for a single embedded input token. Let's pick the second token this time:

In [20]:
e1 = embeddings[1, :]  # second token (with this tokenizer, the whitespace token ' ')

In GPT-like models, input and output dimensions are usually the same. But for better illustration, we will use:

In [21]:
d_in: int = e1.shape[-1]  # number of embedding dimensions (e1 is a 1-D vector)
d_out = 2  # deliberately different from d_in to make the projection visible
print(f"d_in={d_in}, d_out={d_out}")

d_in=3, d_out=2


Initializing $W_q$, $W_k$, and $W_v$:

In [22]:
torch.manual_seed(69)
# Drawn in the order query, key, value; requires_grad=False only to keep the printed output tidy.
W_query, W_key, W_value = (
    torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False) for _ in range(3)
)
print(W_query, W_key, W_value, sep="\n\n")

Parameter containing:
tensor([[0.8398, 0.8042],
        [0.1213, 0.5309],
        [0.6646, 0.4077]])

Parameter containing:
tensor([[0.0888, 0.2429],
        [0.7053, 0.6216],
        [0.9188, 0.0185]])

Parameter containing:
tensor([[0.8741, 0.0560],
        [0.9659, 0.0073],
        [0.3628, 0.4197]])


Query, key and value vectors can be calculated by matrix multiplication:

In [23]:
# Project the single embedding: (d_in,) @ (d_in, d_out) -> (d_out,)
query_1, key_1, value_1 = (e1 @ weight for weight in (W_query, W_key, W_value))
print(query_1, key_1, value_1, sep="\n")

tensor([-0.3083, -0.9373])
tensor([-0.5587, -0.9308])
tensor([-1.5614,  0.1471])


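Note that each output vector has 2 elements, as specified by `d_out`. Multiplying a $(1 \times d_{in})$ vector by a $(d_{in} \times d_{out})$ matrix gives a $(1 \times d_{out})$ vector, following the general rule:


$M_{(i,j)} \times M_{(j,k)} \rightarrow M_{(i,k)}$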

Also note: **weight parameters** and **attention weights** are **not** the same thing.

Even though we are computing the context vector for a single token, we still need the key and value vectors for all the tokens:

In [24]:
# Keys and values for every token: (n_tokens, d_in) @ (d_in, d_out) -> (n_tokens, d_out)
key_vectors: torch.Tensor = embeddings @ W_key
value_vectors: torch.Tensor = embeddings @ W_value
key_report = f"key vectors (shape={key_vectors.shape}):\n{key_vectors}"
value_report = f"value vectors (shape={value_vectors.shape}):\n{value_vectors}"
print(key_report, value_report, sep="\n\n")

key vectors (shape=torch.Size([11, 2])):
tensor([[ 1.7245,  0.9609],
        [-0.5587, -0.9308],
        [ 0.5908,  0.9647],
        [-0.5587, -0.9308],
        [-1.1265,  0.5316],
        [-0.5587, -0.9308],
        [-1.2000, -0.0464],
        [-0.5587, -0.9308],
        [-1.4382,  0.0100],
        [-0.5587, -0.9308],
        [-1.2536, -0.4775]])

value vectors (shape=torch.Size([11, 2])):
tensor([[ 2.4211,  0.4793],
        [-1.5614,  0.1471],
        [ 1.6722, -0.1347],
        [-1.5614,  0.1471],
        [ 0.4209, -0.7427],
        [-1.5614,  0.1471],
        [ 0.6014, -0.2637],
        [-1.5614,  0.1471],
        [ 0.0065, -0.5456],
        [-1.5614,  0.1471],
        [-1.7144, -0.5137]])


##### Attention scores

The attention scores are once again calculated via dot product. But there's a difference:

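In the simplified case, the attention score was the dot product between two input embeddings. This time, it is the dot product between the query vector of one token and the key vector of another token. For all tokens at once, this means multiplying the query by the transposed matrix of key vectors.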

Calculating the attention score between the second embedding and itself:

In [25]:
keys_1 = key_vectors[1]  # key vector of the second token
attn_score_1_1 = torch.dot(query_1, keys_1)  # query of token 1 against key of token 1
print(f"α₁,₁={attn_score_1_1}")

α₁,₁=1.044725775718689


All the attention scores with respect to the second embedding can be calculated as:

In [26]:
# One score per token: (d_out,) @ (d_out, n_tokens) -> (n_tokens,)
attn_scores_1_all = query_1 @ key_vectors.T
print(attn_scores_1_all, attn_scores_1_all.shape, sep="\n")

# torch.dot and matmul may accumulate in a different order, so compare with a tolerance.
assert torch.isclose(attn_score_1_1, attn_scores_1_all[1])

tensor([-1.4323,  1.0447, -1.0864,  1.0447, -0.1510,  1.0447,  0.4134,  1.0447,
         0.4340,  1.0447,  0.8340])
torch.Size([11])


With the attention scores in hand, the next step would be to normalize them to get attention weights.

However, in real-world models with hundreds or thousands of embedding dimensions, a technique is commonly applied before normalization: **scaling down the attention scores by the square root of the embedding dimension of the keys ($\sqrt{d_k}$).**

Dot products grow in magnitude with the number of dimensions. Large scores push the softmax function towards a step-like (almost one-hot) output, which results in very small gradients during backpropagation.

This technique is the reason why this self-attention mechanism is also called *scaled dot-product attention*.

In [27]:
d_k = key_vectors.shape[-1]  # dimensionality of the key vectors (equal to d_out here)
# Scale by sqrt(d_k) so large dot products don't push the softmax into its flat regions.
scaled_attn_scores_1_all = attn_scores_1_all / (d_k ** 0.5)
print(f"attention scores for embedding 1: {attn_scores_1_all}")
print(f"scaled attention scores for embedding 1: {scaled_attn_scores_1_all}")

attention scores for embedding 1: tensor([-1.4323,  1.0447, -1.0864,  1.0447, -0.1510,  1.0447,  0.4134,  1.0447,
         0.4340,  1.0447,  0.8340])
scaled attention scores for embedding 1: tensor([-1.0128,  0.7387, -0.7682,  0.7387, -0.1067,  0.7387,  0.2923,  0.7387,
         0.3069,  0.7387,  0.5897])


##### Attention weights 

We can now apply the softmax function to the scaled attention scores.

In [28]:
attn_weights_1_all = scaled_attn_scores_1_all.softmax(dim=-1)
total = attn_weights_1_all.sum()  # should be 1
print(f"attention weights for embedding 1: {attn_weights_1_all}", total, sep="\n")

attention weights for embedding 1: tensor([0.0218, 0.1254, 0.0278, 0.1254, 0.0538, 0.1254, 0.0802, 0.1254, 0.0814,
        0.1254, 0.1080])
tensor(1.)


##### Calculating the context vectors

In the simplified case, the context vector for an input was calculated as the weighted sum **over the input embeddings themselves**.

Now, the weighted sum is performed **over the value vectors** instead:

In [29]:
print(f"{attn_weights_1_all}\n\n{value_vectors}")

tensor([0.0218, 0.1254, 0.0278, 0.1254, 0.0538, 0.1254, 0.0802, 0.1254, 0.0814,
        0.1254, 0.1080])

tensor([[ 2.4211,  0.4793],
        [-1.5614,  0.1471],
        [ 1.6722, -0.1347],
        [-1.5614,  0.1471],
        [ 0.4209, -0.7427],
        [-1.5614,  0.1471],
        [ 0.6014, -0.2637],
        [-1.5614,  0.1471],
        [ 0.0065, -0.5456],
        [-1.5614,  0.1471],
        [-1.7144, -0.5137]])


In [30]:
# Explicit version: add up each value vector scaled by its attention weight.
context_vector_1 = torch.zeros(d_out)
for weight, value in zip(attn_weights_1_all, value_vectors):
    context_vector_1 += weight * value
print(context_vector_1)

# Equivalent vectorised version: (n_tokens,) @ (n_tokens, d_out) -> (d_out,)
context_vector_1 = attn_weights_1_all @ value_vectors
print(context_vector_1)

tensor([-0.9935, -0.0622])
tensor([-0.9935, -0.0622])


#### Generalizing to all tokens

Let's wrap it all up by creating a module that performs the self-attention mechanism for the whole input:

In [31]:
class SelfAttention_P(torch.nn.Module): # _P for Parameter (as opposed to Linear which we'll see soon)
    """Scaled dot-product self-attention with raw ``torch.nn.Parameter`` projection matrices.

    Args:
        d_in: size of each input embedding.
        d_out: size of the query/key/value vectors (and of the returned context vectors).

    The projection matrices are drawn with ``torch.rand`` in the order query, key, value,
    so seeding the RNG beforehand reproduces the same weights.
    """

    def __init__(self, d_in:int, d_out:int):
        super().__init__()

        self.W_query = torch.nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = torch.nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = torch.nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        """Return one context vector per input token.

        Args:
            x: embeddings of shape ``(num_tokens, d_in)`` or batched ``(..., num_tokens, d_in)``.

        Returns:
            Context vectors of shape ``(..., num_tokens, d_out)``.
        """
        key_vectors = x @ self.W_key
        query_vectors = x @ self.W_query
        value_vectors = x @ self.W_value

        # Swap only the last two dims; `.T` reverses *all* dims and breaks batched inputs.
        attn_scores = query_vectors @ key_vectors.transpose(-2, -1)
        # Scale by sqrt(d_k) so large dot products don't saturate the softmax.
        scaled_attn_scores = attn_scores / (key_vectors.shape[-1] ** 0.5)
        attn_weights = torch.softmax(scaled_attn_scores, dim=-1)  # each row sums to 1
        context_vectors = attn_weights @ value_vectors

        return context_vectors

In [32]:
torch.manual_seed(69)  # same seed and draw order as before, so the weights match W_query/W_key/W_value
self_attention_P = SelfAttention_P(d_in=d_in, d_out=d_out)
context_vectors_P = self_attention_P(embeddings)
print(context_vectors_P)
second_row = context_vectors_P[1]  # should match the hand-computed context_vector_1
print(second_row, context_vector_1)

tensor([[ 2.1652e+00,  3.3566e-01],
        [-9.9352e-01, -6.2155e-02],
        [ 8.1129e-01,  3.4023e-04],
        [-9.9352e-01, -6.2155e-02],
        [-6.3151e-01, -1.8909e-01],
        [-9.9352e-01, -6.2155e-02],
        [ 1.1980e+00,  1.3611e-01],
        [-9.9352e-01, -6.2155e-02],
        [-4.8231e-01, -1.1911e-01],
        [-9.9352e-01, -6.2155e-02],
        [-1.0848e+00, -1.2670e-01]], grad_fn=<MmBackward0>)
tensor([-0.9935, -0.0622], grad_fn=<SelectBackward0>) tensor([-0.9935, -0.0622])


We can use `torch.nn.Linear` instead of `torch.nn.Parameter`. `Linear` layers come with a better-suited default weight initialization scheme:


(code moved to a module so it can be reused)

In [33]:
self_attention_source = inspect.getsource(SelfAttention)  # show the reusable module's code
print(self_attention_source)

class SelfAttention(torch.nn.Module):
    def __init__(self, d_in:int, d_out:int, qkv_bias=False):
        super().__init__()

        self.W_query = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = torch.nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        key_vectors = self.W_key(x)
        query_vectors = self.W_query(x)
        value_vectors = self.W_value(x)

        attn_scores = query_vectors @ key_vectors.T
        scaled_attn_scores = attn_scores/(key_vectors.shape[-1] ** 0.5)
        attn_weights = torch.softmax(scaled_attn_scores, dim=-1)
        context_vectors = attn_weights @ value_vectors
        
        return context_vectors



In [34]:
torch.manual_seed(69)
self_attention_L = SelfAttention(d_in=d_in, d_out=d_out)
context_vectors_L = self_attention_L(embeddings)
print(context_vectors_L)

tensor([[ 0.1785,  0.1088],
        [ 0.1216, -0.1495],
        [ 0.1985,  0.1906],
        [ 0.1216, -0.1495],
        [ 0.2426,  0.2571],
        [ 0.1216, -0.1495],
        [ 0.2941,  0.2379],
        [ 0.1216, -0.1495],
        [ 0.2602,  0.2216],
        [ 0.1216, -0.1495],
        [ 0.1258, -0.0132]], grad_fn=<MmBackward0>)


Since the weight initialization is different, the results naturally differ. But we can assign the weights from the Linear version to the Parameter version to check whether the results are the same:

In [35]:
print(self_attention_P.W_key)  # a raw Parameter
print(self_attention_L.W_key)  # a Linear layer

Parameter containing:
tensor([[0.0888, 0.2429],
        [0.7053, 0.6216],
        [0.9188, 0.0185]], requires_grad=True)
Linear(in_features=3, out_features=2, bias=False)


The two classes store their weight matrices differently: `SelfAttention_P` holds a `Parameter` directly, while `SelfAttention` holds a `Linear` layer.
We must make these compatible, so we will create `Parameter` objects from the `Linear` layers of the second class.

Accessing the weights of the linear self-attention object:

In [36]:
linear_key_weight = self_attention_L.W_key.weight  # the Parameter wrapped by the Linear layer
linear_key_weight

Parameter containing:
tensor([[-0.4748, -0.2969,  0.2371],
        [ 0.1405,  0.4836, -0.5560]], requires_grad=True)

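The weights are `Parameter` objects, but they are stored transposed. `nn.Linear` keeps its weight with shape `(out_features, in_features)` (here `(2, 3)`) and computes `x @ weight.T`. So we must create the `Parameter` objects from the transposed weight matrices: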

In [37]:
# nn.Linear stores its weight as (out_features, in_features), so the Parameter version needs
# the transpose. detach().clone() copies the values: Parameter(weight.T) alone would alias the
# Linear layer's storage, so an in-place update to one module would silently change the other.
self_attention_P.W_key = torch.nn.Parameter(self_attention_L.W_key.weight.detach().T.clone())
self_attention_P.W_value = torch.nn.Parameter(self_attention_L.W_value.weight.detach().T.clone())
self_attention_P.W_query = torch.nn.Parameter(self_attention_L.W_query.weight.detach().T.clone())

Testing to see if we get the same results:

In [38]:
# Element-wise comparison with a tolerance rather than bitwise ==: nn.Linear and the explicit
# `x @ W` path may use different matmul kernels, so tiny floating-point differences are possible.
torch.isclose(self_attention_L(embeddings), self_attention_P(embeddings))

tensor([[True, True],
        [True, True],
        [True, True],
        [True, True],
        [True, True],
        [True, True],
        [True, True],
        [True, True],
        [True, True],
        [True, True],
        [True, True]])

## 2. Causal attention