# Setup

In [2]:
# noqa
import os
import sys

import torch

# `__file__` is not defined inside a notebook kernel, so anchor on the current
# working directory explicitly. (The previous `abspath("__file__")` resolved a
# dummy file name against the working directory and yielded the same path.)
current_dir = os.getcwd()
# The directory one level above the working directory.
parent_dir = os.path.dirname(current_dir)
# Make the parent directory importable. The membership check keeps the cell
# idempotent: re-running it no longer stacks duplicate entries on sys.path.
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)
# Import the class
from self_attention import SelfAttentionV1, SelfAttentionV2, CausalAttention

# Basic attention

Embeddings tensor

In [3]:
# Toy 3-dimensional embeddings: one row per token of the example sentence.
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # Your     (x^1)
        [0.55, 0.87, 0.66],  # journey  (x^2)
        [0.57, 0.85, 0.64],  # starts   (x^3)
        [0.22, 0.58, 0.33],  # with     (x^4)
        [0.77, 0.25, 0.10],  # one      (x^5)
        [0.05, 0.80, 0.55],  # step     (x^6)
    ]
)

Compute attention scores by computing the dot products of each input embedding tensor with the query

In [4]:
# The second token (index 1) acts as the query.
query = inputs[1]
# One score per input token: the dot product of that token with the query.
attn_scores_2 = torch.empty(len(inputs))
for idx in range(len(inputs)):
    attn_scores_2[idx] = torch.dot(inputs[idx], query)
attn_scores_2

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])

Real dot product function vs hand-coded version

In [5]:
# Hand-rolled dot product: multiply matching components and accumulate.
res = 0.0
for component, query_component in zip(inputs[0], query):
    res += component * query_component
print(res)
# Reference value from the built-in implementation.
print(torch.dot(inputs[0], query))

tensor(0.9544)
tensor(0.9544)


Now we normalize the dot products (attention scores) computed previously so that they sum to 1. Once normalized, they are called attention weights instead of attention scores; they carry the same information, just rescaled.

In [6]:
# Simplest normalization: divide every score by the total of all scores.
attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()
print("Attention weights:", attn_weights_2_tmp)
print("Sum:", attn_weights_2_tmp.sum())

Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum: tensor(1.0000)


Normally, we use the softmax function to normalize attention scores.
This uses the exponential function before normalizing: larger values become even bigger relative to smaller values, and all values become positive. 

Softmax produces a more "peaked" distribution, which is useful for attention mechanisms: it forces the model to focus on the more relevant parts of the input and sharpens the separation between important and unimportant tokens.

In [7]:
def softmax_naive(x: torch.Tensor) -> torch.Tensor:
    """Compute a naive softmax over all elements of ``x``.

    Every element is exponentiated (``e ** x_i`` with ``e`` ~ 2.7183) and then
    divided by the sum of all exponentials, so the result is strictly positive
    and sums to 1.

    This is intentionally the textbook formula without the usual
    max-subtraction trick, so very large inputs overflow; prefer
    ``torch.softmax`` outside of demonstrations.

    Args:
        x: Tensor of scores. The sum runs over every element, so this only
            matches ``torch.softmax(x, dim=0)`` for 1-D input.

    Returns:
        Tensor of the same shape as ``x`` whose elements sum to 1.
    """
    # Exponentiate once and reuse it for both numerator and denominator.
    exp_x = torch.exp(x)
    return exp_x / exp_x.sum()


# Normalize the scores for the second token with the hand-written softmax.
attn_weights_2_naive = softmax_naive(attn_scores_2)
print("Attention weights:", attn_weights_2_naive)
print("Sum:", attn_weights_2_naive.sum())

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)


However, it is generally better to use PyTorch's built-in softmax than to roll your own: a naive implementation such as `softmax_naive` can run into numerical instability (overflow and underflow) when the input values are very large or very small.

In [8]:
# Same normalization with the built-in softmax; dim=0 is the only axis of this 1-D tensor.
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)
print("Attention weights:", attn_weights_2)
print("Sum:", attn_weights_2.sum())

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)


Here, we compute the context vector as a weighted sum of the input vectors: each input vector (a token embedding) is multiplied by its attention weight (normalized attention score) and the results are added together.

In [9]:
# The query is the second token; below it is only needed for its shape.
query = inputs[1]
# Accumulate the attention-weighted sum of all input vectors.
context_vec_2 = torch.zeros(query.shape)
for weight, token_vec in zip(attn_weights_2, inputs):
    context_vec_2 += weight * token_vec
context_vec_2

tensor([0.4419, 0.6515, 0.5683])

Computing the attention weights for each input vector relative to the other input vectors

In [10]:
# Pairwise attention scores: entry [i, j] is the dot product of token i with
# token j. The size is taken from `inputs` instead of a hard-coded 6 so the
# cell keeps working if the number of tokens changes.
attn_scores = torch.empty(inputs.shape[0], inputs.shape[0])
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = torch.dot(x_i, x_j)
attn_scores

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

The same result can be obtained directly with matrix multiplication, using the "@" operator to multiply the inputs matrix by its transpose

In [11]:
# Transposed view of the inputs: rows become embedding dimensions, columns become tokens.
inputs.T

tensor([[0.4300, 0.5500, 0.5700, 0.2200, 0.7700, 0.0500],
        [0.1500, 0.8700, 0.8500, 0.5800, 0.2500, 0.8000],
        [0.8900, 0.6600, 0.6400, 0.3300, 0.1000, 0.5500]])

In [12]:
# All pairwise dot products in one matrix multiplication (replaces the nested loops).
attn_scores = inputs @ inputs.T
attn_scores

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

We can now normalize using softmax, transforming the attention scores into attention weights so that each row sums to 1.0

In [13]:
# dim=-1 selects the last dimension, so softmax is applied independently to each row.
attn_weights = torch.softmax(attn_scores, dim=-1)
attn_weights

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])

We can confirm using `sum(dim=-1)` that the weights in each row sum to 1 (and are thus normalized)

In [14]:
# Sum the second row (index 1) of the computed weights instead of re-typing
# the printed values, so the check stays valid if the inputs change. Rounded to
# 4 decimals, the precision at which the tensor is displayed.
row_2_sum = round(attn_weights[1].sum().item(), 4)
print("Row 2 sum:", row_2_sum)
# Every row should sum to 1 after the row-wise softmax.
print("All row sums:", attn_weights.sum(dim=-1))

Row 2 sum: 1.0
All row sums: tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])


Now we can compute all context vectors for each token using matrix multiplication

In [15]:
# Row i is the context vector of token i: the inputs weighted by row i of attn_weights.
all_context_vecs = attn_weights @ inputs
all_context_vecs

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])

In [16]:
# Compare against the loop-based result; it should match row index 1 of all_context_vecs.
print("Previous 2nd context vector:", context_vec_2)

Previous 2nd context vector: tensor([0.4419, 0.6515, 0.5683])


# Self attention with trainable weights

## Basic operations

We first set up a query vector (the embedding of the 2nd input token, `x_2 = inputs[1]`) and the dimensions of the weight matrices (input and output)

In [17]:
x_2 = inputs[1]  # second input token (index 1)
d_in = inputs.shape[1]  # input embedding size, read from the inputs tensor
d_out = 2  # size of the projected query/key/value vectors

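We initialize the weight matrices that would be updated during backpropagation when training the model. (Here they are created with `requires_grad=False`, so no gradients are tracked for them in this demonstration.)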

In [18]:
# Fix the seed so the randomly drawn matrices are reproducible.
torch.manual_seed(123)
# Three (d_in, d_out) projection matrices drawn from torch.rand.
# requires_grad=False means no gradients are tracked for them in this cell.
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
print(W_query)
print(W_key)
print(W_value)

Parameter containing:
tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]])
Parameter containing:
tensor([[0.1366, 0.1025],
        [0.1841, 0.7264],
        [0.3153, 0.6871]])
Parameter containing:
tensor([[0.0756, 0.1966],
        [0.3164, 0.4017],
        [0.1186, 0.8274]])


Here we multiply an input vector (`x_2`, the embedding of a single token) by each of the weight matrices to project it into `d_out` dimensions

In [19]:
# Project the single token x_2 into its query, key and value vectors.
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value
print(query_2)
print(key_2)
print(value_2)

tensor([0.4306, 1.4551])
tensor([0.4433, 1.1419])
tensor([0.3951, 1.0037])


We can obtain the key and value vectors for all tokens via matrix multiplication, projecting the inputs matrix (6 tokens * 3 embedding dimensions) into the corresponding keys or values matrix (6 tokens * 2 output dimensions, as defined above by the `d_out` variable)



In [20]:
# Project every token at once: one key row and one value row per input token.
keys = inputs @ W_key
values = inputs @ W_value
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)
print(keys)
print(values)

keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])
tensor([[0.3669, 0.7646],
        [0.4433, 1.1419],
        [0.4361, 1.1156],
        [0.2408, 0.6706],
        [0.1827, 0.3292],
        [0.3275, 0.9642]])
tensor([[0.1855, 0.8812],
        [0.3951, 1.0037],
        [0.3879, 0.9831],
        [0.2393, 0.5493],
        [0.1492, 0.3346],
        [0.3221, 0.7863]])


## Computing attention scores with trainable weights

With trainable weights, we compute an attention score by taking the query vector (dimension 2) of the current query token (e.g. x_2) and taking its dot product with the key vector of each token (e.g. x_1, x_3)

In [21]:
# Key vector of the token at index 1 (the same token that supplies the query).
keys_2 = keys[1]
print(keys_2)
# Unnormalized attention score of that token with itself:
# the dot product of its query vector and its key vector.
attn_score_22 = torch.dot(query_2, keys_2)
print(attn_score_22)

tensor([0.4433, 1.1419])
tensor(1.8524)


We generalise the computation above to all attention scores (matrix multiplication).
We multiply the query_2 vector (of x_2) with all projected key vectors to compute the attention scores of all input tokens relative to the selected query (x_2).

We do that because we want to compute the context vector for the second input token (x_2): we first need the attention scores, then the attention weights (normalized scores).

In [22]:
# Scores of query_2 against every key: one value per input token.
attn_scores_2 = query_2 @ keys.T
attn_scores_2

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])

Now we need to compute the attention weights from the attention scores

In [23]:
d_k = keys.shape[-1]  # dimension of the key vectors (last axis of `keys`)
# Divide the scores by sqrt(d_k) before the softmax (scaled dot-product attention).
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
print(attn_weights_2)

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])


#### The rationale behind scaled-dot product attention

The reason for the normalization by the embedding dimension size is to improve the training performance by avoiding small gradients. For instance, when scaling up the embedding dimension, which is typically greater than 1,000 for GPT-like LLMs, large dot products can result in very small gradients during backpropagation due to the softmax function applied to them. As dot products increase, the softmax function behaves more like a step function, resulting in gradients nearing zero. These small gradients can drastically slow down learning or cause training to stagnate.

The scaling by the square root of the embedding dimension is the reason why this self-attention mechanism is also called scaled-dot product attention.


## Computing the context vectors

Last step is to multiply the value vectors with the attention weights and summing to obtain the context vectors (of each token).

First we compute single context vector just for x_2 by multiplying its attention weights by the values matrix (for all tokens)

In [24]:
print(attn_weights_2.shape)
# A 1-D tensor of shape (6,) times a (6, 2) matrix yields a 1-D tensor of shape (2,):
# the attention-weighted sum of the six value vectors.
context_vec_2 = attn_weights_2 @ values
context_vec_2

torch.Size([6])


tensor([0.3061, 0.8210])

#### Why query, key, and value?
The terms “key,” “query,” and “value” in the context of attention mechanisms are borrowed from the domain of information retrieval and databases, where similar concepts are used to store, search, and retrieve information.

A query is analogous to a search query in a database. It represents the current item (e.g., a word or token in a sentence) the model focuses on or tries to understand. The query is used to probe the other parts of the input sequence to determine how much attention to pay to them.

The key is like a database key used for indexing and searching. In the attention mechanism, each item in the input sequence (e.g., each word in a sentence) has an associated key. These keys are used to match the query.

The value in this context is similar to the value in a key-value pair in a database. It represents the actual content or representation of the input items. Once the model determines which keys (and thus which parts of the input) are most relevant to the query (the current focus item), it retrieves the corresponding values.

## Using custom self attention class to compute context vectors for all tokens in one pass

In [25]:
# Seed first; presumably SelfAttentionV1 draws its weights randomly on construction
# (class is defined in self_attention.py — confirm there).
torch.manual_seed(123)
self_attention_v1 = SelfAttentionV1(d_in=d_in, d_out=d_out)
# Calling the module on the (6, 3) inputs returns one d_out-sized row per token.
self_attention_v1(inputs)

tensor([[0.2845, 0.4071],
        [0.2854, 0.4081],
        [0.2854, 0.4075],
        [0.2864, 0.3974],
        [0.2863, 0.3910],
        [0.2860, 0.4039]], grad_fn=<MmBackward0>)

### Using v2 with nn.Linear weight matrices

We can see the initial weights are different from SelfAttentionV1 because of how nn.Linear initialises its weights

In [26]:
# A different seed than the v1 cell above (789 instead of 123).
torch.manual_seed(789)
sa_v2 = SelfAttentionV2(d_in, d_out)
sa_v2(inputs)

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)

### Exercise 3.1  Comparing SelfAttention_v1 and SelfAttention_v2

In [27]:
# NOTE: despite their names, these three variables hold the layer objects of
# sa_v2 (each exposes a `.weight` attribute), not raw weight matrices.
key_layer_weight_matrix, query_layer_weight_matrix, value_layer_weight_matrix = (
    sa_v2.W_key,
    sa_v2.W_query,
    sa_v2.W_value,
)
# Copy the v2 weights into v1. nn.Linear stores its weight as (d_out, d_in),
# so it is transposed to the (d_in, d_out) layout used with `x @ W`.
# `.detach().clone()` gives v1 its own storage: wrapping `weight.T` directly
# would make v1's parameters a view of v2's memory, so an in-place update of
# one module would silently change the other.
self_attention_v1.W_key = torch.nn.Parameter(key_layer_weight_matrix.weight.T.detach().clone())
self_attention_v1.W_query = torch.nn.Parameter(query_layer_weight_matrix.weight.T.detach().clone())
self_attention_v1.W_value = torch.nn.Parameter(value_layer_weight_matrix.weight.T.detach().clone())

# With identical weights both implementations should print the same context vectors.
print(self_attention_v1(inputs))
print(sa_v2(inputs))

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)
tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)


# Causal attention

Also called "masked attention", because for NLP tasks such as language modelling, only the previous tokens in a context / sentence matter, not the future ones.

As a first step, compute the attention weights as before

In [28]:
# Project the inputs with sa_v2's query and key layers.
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
# Raw scores, then scale by sqrt(key dimension) before the row-wise softmax.
attn_scores = queries @ keys.T
scaled_scores = attn_scores / keys.shape[-1] ** 0.5
attn_weights = torch.softmax(scaled_scores, dim=-1)
attn_weights

tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)

Now we need to create a mask using `tril` that will zero all values above the diagonal

In [29]:
# Number of tokens, read from the score matrix.
context_length = attn_scores.shape[0]
# Lower-triangular matrix of ones: 1 on and below the diagonal, 0 above it.
mask_simple = torch.tril(torch.ones(context_length, context_length))
mask_simple

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])

Now we can multiply the mask with the attention weights, zeroing the values above the diagonal

In [30]:
# Element-wise product: weights above the diagonal are multiplied by 0.
masked_simple = attn_weights * mask_simple
masked_simple

tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<MulBackward0>)

We can again normalize the attention weights

In [31]:
# keepdim=True keeps row_sums as a column so it broadcasts across each row.
row_sums = masked_simple.sum(dim=-1, keepdim=True)
# Divide each row by its own sum so the remaining weights add up to 1 again.
masked_simple_normalized = masked_simple / row_sums
masked_simple_normalized

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<DivBackward0>)

### More efficient way of applying mask and normalizing with less compute power using infinity

In [32]:
ones = torch.ones(context_length, context_length)
# `triu` ("triangular upper") keeps the upper triangle and zeros the rest; it is the
# counterpart of `tril`. With diagonal=1 the main diagonal is zeroed as well, so the
# 1s mark the entries strictly above the diagonal.
mask = torch.triu(ones, diagonal=1)
print(mask)

# masked_fill(mask.bool(), -torch.inf) replaces the attention scores with negative
# infinity wherever the boolean mask is True (here: the entries above the diagonal).
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
masked

tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])


tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],
        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],
        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],
        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],
        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],
       grad_fn=<MaskedFillBackward0>)

Now compute the attention weights in one pass by scaling the masked attention scores by the square root of the key dimension (`keys.shape[-1] ** 0.5`) and applying softmax

In [33]:
# Scale by sqrt(key dimension), then softmax over the last axis (each row).
# Entries set to -inf come out as exactly 0 after the softmax.
attn_weights = torch.softmax(masked / keys.shape[-1] ** 0.5, dim=-1)
attn_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)

## Applying drop out mask to reduce overfitting

We apply a dropout mask to prevent overfitting during training. Dropout randomly sets some elements of the input to zero, which forces the model to not rely too heavily on any single feature or connection. This encourages the network to learn more robust and generalizable patterns. In attention mechanisms, applying dropout to the attention weights helps regularize the model and improves its ability to generalize to new data.

When applying dropout to an attention weight matrix with a rate of 50%, half of the elements in the matrix are randomly set to zero. To compensate for the reduction in active elements, the values of the remaining elements in the matrix are scaled up by a factor of 1/0.5 = 2. This scaling is crucial to maintain the overall balance of the attention weights, ensuring that the average influence of the attention mechanism remains consistent during both the training and inference phases.


In [34]:
# Seed so the random dropout pattern is reproducible.
torch.manual_seed(123)
# Drop each element with probability 0.5; survivors are scaled by 1 / (1 - 0.5) = 2.
dropout = torch.nn.Dropout(p=0.5)
# All-ones matrix makes the pattern visible: entries come out as 0 or 2.
example = torch.ones((6, 6))
dropout(example)

tensor([[2., 2., 2., 2., 2., 2.],
        [0., 2., 0., 0., 0., 0.],
        [0., 0., 2., 0., 2., 0.],
        [2., 2., 0., 0., 0., 2.],
        [2., 0., 0., 0., 0., 2.],
        [0., 2., 0., 0., 0., 0.]])

Dropout applied to attention weight matrix

In [35]:
# Re-seed with the same value so the same dropout pattern as in the all-ones example is drawn.
torch.manual_seed(123)
dropout(attn_weights)

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.8966, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.6206, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4921, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4350, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3327, 0.0000, 0.0000, 0.0000, 0.0000]],
       grad_fn=<MulBackward0>)

## Using causal attention class

First, let's create a batch of inputs (as normally happens with llms).

Instead of passing 2D matrices to our attention classes, we pass a 3D tensor of shape num_batches * num_tokens * dimensions, e.g. 8 * 6 * 3

In [36]:
# torch.stack is a PyTorch function that joins a sequence of tensors along a new dimension.
# Here we are just duplicating the inputs to get a batch of two (not meaningful data, only an example).
batch = torch.stack((inputs, inputs), dim=0)
batch, batch.shape  # as we can see, the result is a 2 * 6 * 3 tensor

(tensor([[[0.4300, 0.1500, 0.8900],
          [0.5500, 0.8700, 0.6600],
          [0.5700, 0.8500, 0.6400],
          [0.2200, 0.5800, 0.3300],
          [0.7700, 0.2500, 0.1000],
          [0.0500, 0.8000, 0.5500]],
 
         [[0.4300, 0.1500, 0.8900],
          [0.5500, 0.8700, 0.6600],
          [0.5700, 0.8500, 0.6400],
          [0.2200, 0.5800, 0.3300],
          [0.7700, 0.2500, 0.1000],
          [0.0500, 0.8000, 0.5500]]]),
 torch.Size([2, 6, 3]))

In [None]:
torch.manual_seed(123)
# Number of tokens per sequence: second axis of the (2, 6, 3) batch.
context_length = batch.shape[1]
# NOTE(review): the trailing 0.0 is presumably the dropout rate — confirm against
# CausalAttention's signature in self_attention.py.
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
print("context_vecs.shape:", context_vecs.shape)