## 3.3 Attending to different parts of the input with self-attention

### 3.3.1 A simple self-attention mechanism without trainable weights

In [6]:
import torch

# Toy input: one 3-dimensional embedding per token of the sentence
# "Your journey starts with one step" (6 tokens -> a (6, 3) tensor).
# Row i holds the embedding x^(i+1).
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # Your     (x^1)
        [0.55, 0.87, 0.66],  # journey  (x^2)
        [0.57, 0.85, 0.64],  # starts   (x^3)
        [0.22, 0.58, 0.33],  # with     (x^4)
        [0.77, 0.25, 0.10],  # one      (x^5)
        [0.05, 0.80, 0.55],  # step     (x^6)
    ]
)

inputs

tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])

Using the second token, `journey`, as the query:

In [7]:
# Score every token against the query "journey" (x^2): one dot product per token.
query = inputs[1]
attention_scores_2 = torch.stack([torch.dot(query, token_emb) for token_emb in inputs])

attention_scores_2

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])

Now we normalize each of the attention scores.

We want to ensure that the sum of the attention weights is 1.

In [8]:
# Naive normalization: divide each attention score by the sum of all scores,
# so the resulting attention weights sum to 1. (With negative scores this can
# produce negative weights, one reason softmax is preferred in the next cell.)
attention_weights_2_tmp = attention_scores_2 / attention_scores_2.sum()
print("Attention weights:", attention_weights_2_tmp)
print("Sum:", attention_weights_2_tmp.sum())

Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum: tensor(1.0000)


In practice, it's better to use the `softmax` function for normalization: it handles extreme values better, always produces positive weights (even for negative scores), and has better gradient properties during training.

In [9]:
# Softmax normalization: exponentiate each score, then divide by the sum of
# the exponentials, giving positive weights that sum to 1.
attention_weights_2 = attention_scores_2.softmax(dim=0)
print("Attention weights:", attention_weights_2)
print("Sum:", attention_weights_2.sum())

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)


The final step is to calculate the context vector by multiplying each embedded input token by its corresponding attention weight and summing the resulting vectors.

In [10]:
# Context vector for "journey": the attention-weighted sum of all input embeddings.
query = inputs[1]
context_vec_2 = torch.zeros(query.shape)
for idx, token_emb in enumerate(inputs):
    weight = attention_weights_2[idx]
    print(f"{idx} | {weight:.2f} * {token_emb}")
    context_vec_2 += weight * token_emb

context_vec_2

0 | 0.14 * tensor([0.4300, 0.1500, 0.8900])
1 | 0.24 * tensor([0.5500, 0.8700, 0.6600])
2 | 0.23 * tensor([0.5700, 0.8500, 0.6400])
3 | 0.12 * tensor([0.2200, 0.5800, 0.3300])
4 | 0.11 * tensor([0.7700, 0.2500, 0.1000])
5 | 0.16 * tensor([0.0500, 0.8000, 0.5500])


tensor([0.4419, 0.6515, 0.5683])

Retrieval practice: repeating all three steps with the first token, `Your`, as the query:

In [14]:
# Step 1: attention scores of the first token ("Your") against every token.
query = inputs[0]
attn_scores_1 = torch.zeros(inputs.shape[0])
for idx, token_emb in enumerate(inputs):
    attn_scores_1[idx] = torch.dot(query, token_emb)

# Step 2: normalize the scores into attention weights with softmax.
attn_weights_1 = torch.softmax(attn_scores_1, dim=0)

# Step 3: context vector = attention-weighted sum of the input embeddings.
context_vec_1 = torch.zeros(query.shape)
for weight, token_emb in zip(attn_weights_1, inputs):
    context_vec_1 += weight * token_emb

context_vec_1

tensor([0.4421, 0.5931, 0.5790])

### 3.3.2 Computing attention weights for all input tokens

In [17]:
# Pairwise dot products: attn_scores[row, col] = <x_row, x_col>.
num_tokens = inputs.shape[0]
attn_scores = torch.empty((num_tokens, num_tokens))
for row, query_emb in enumerate(inputs):
    for col, key_emb in enumerate(inputs):
        attn_scores[row, col] = torch.dot(query_emb, key_emb)
attn_scores

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

Element `[i, j]` of the `attn_scores` tensor is the attention score (dot product) between input `i` and input `j`.

You can picture it as the following matrix (the row and column labels are added only for readability):

| | Your | journey | starts | with | one | step |
|---|---|---|---|---|---|---|
| Your | 0.9995 | 0.9544 | 0.9422 | 0.4753 | 0.4576 | 0.6310 |
| journey | 0.9544 | 1.4950 | 1.4754 | 0.8434 | 0.7070 | 1.0865 |
| starts | 0.9422 | 1.4754 | 1.4570 | 0.8296 | 0.7154 | 1.0605 |
| with | 0.4753 | 0.8434 | 0.8296 | 0.4937 | 0.3474 | 0.6565 |
| one | 0.4576 | 0.7070 | 0.7154 | 0.3474 | 0.6654 | 0.2935 |
| step | 0.6310 | 1.0865 | 1.0605 | 0.6565 | 0.2935 | 0.9450 |

In [20]:
# A faster way: a single matrix product of the inputs with their own transpose
# (equivalent to `inputs @ inputs.T`).
attn_scores = torch.matmul(inputs, inputs.T)
attn_scores

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

In [29]:
# Normalize along the last dimension (dim=-1). For this 2-D [rows, columns]
# tensor, each row is softmax-normalized across its columns, so the weights
# in every row sum to 1.
attn_weights = attn_scores.softmax(dim=-1)
attn_weights

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])

In [32]:
# Row i of context_vecs is the attention-weighted sum of all input rows for query i.
context_vecs = torch.matmul(attn_weights, inputs)
context_vecs

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])

## 3.4 Implementing self-attention with trainable weights

### 3.4.1 Computing the attention weights step by step

In [48]:
x_2 = inputs[1]          # second input token, "journey"
d_in = inputs.size(1)    # input embedding size
d_out = 2                # output embedding size

In [49]:
torch.manual_seed(42)

# Query, key and value projection matrices, drawn in that order from the
# seeded RNG. Frozen here (requires_grad=False) to keep the outputs uncluttered.
W_query, W_key, W_value = (
    torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
    for _ in range(3)
)

We set `requires_grad=False` to reduce clutter in the outputs. If we were using the weight matrices for model training, we'd set it to `True` so they get updated during training.

In [99]:
# Project the "journey" token into query, key and value space.
query_2, key_2, value_2 = (x_2 @ W for W in (W_query, W_key, W_value))

query_2

tensor([1.0760, 1.7344])

In [53]:
# Project all six tokens at once.
keys = torch.matmul(inputs, W_key)
values = torch.matmul(inputs, W_value)

keys.shape, values.shape

(torch.Size([6, 2]), torch.Size([6, 2]))

We've now projected the six input tokens from a three-dimensional embedding space onto a two-dimensional one.

In [54]:
keys_2 = keys[1]
attn_scores_22 = torch.dot(query_2, keys_2)
attn_scores_22  # unnormalized attention score

tensor(3.3338)

In [55]:
# Generalized: scores of query_2 against every key in one matrix product.
attn_scores_2 = torch.matmul(query_2, keys.T)
attn_scores_2

tensor([2.7084, 3.3338, 3.3013, 1.7563, 1.7869, 2.1966])

From attention scores to attention weights.

We scale the attention scores by dividing them by the square root of the keys' embedding dimension, $\sqrt{d_k}$, and then apply the softmax function.

We scale by $\sqrt{d_k}$ to improve training performance by avoiding small gradients.

Large dot products can lead to very small gradients during backpropagation because of the softmax: as the dot products grow, the softmax output approaches a step function and its gradients approach zero. These tiny gradients can slow training down or cause it to stagnate.

We call this self-attention mechanism "scaled dot-product attention" because of this scaling by the square root of the keys' embedding dimension.

In [57]:
# Scaled dot-product: divide the scores by sqrt(d_k) before the softmax.
d_k = keys.shape[-1]
scaled_scores_2 = attn_scores_2 / d_k**0.5
attn_weights_2 = scaled_scores_2.softmax(dim=-1)
attn_weights_2

tensor([0.1723, 0.2681, 0.2620, 0.0879, 0.0898, 0.1200])

In [68]:
# Context vector: attention-weighted combination of the value vectors.
context_vec_2 = torch.matmul(attn_weights_2, values)

context_vec_2

tensor([1.4201, 0.8892])

### 3.4.2 Implementing a compact self-attention Python class

In [69]:
import torch.nn as nn

class SelfAttention_v1(nn.Module):
    """Scaled dot-product self-attention using raw parameter matrices.

    The input is projected into queries, keys and values with three
    (d_in, d_out) weight matrices. The module returns the attention-weighted
    sum of the values.

    Args:
        d_in: Size of each input embedding (last dimension of ``x``).
        d_out: Size of the query/key/value projections and of the output.
    """

    def __init__(self, d_in: int, d_out: int):
        super().__init__()
        # Created in query, key, value order, matching the step-by-step cells,
        # so the same seed yields the same weights.
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key   = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one context vector per input token.

        Args:
            x: Input embeddings of shape (num_tokens, d_in), optionally with
                leading batch dimensions, e.g. (batch, num_tokens, d_in).

        Returns:
            Context vectors of shape (..., num_tokens, d_out).
        """
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value

        # Swap only the last two dims (tokens <-> features). `.T` reverses *all*
        # dims, which breaks batched 3-D inputs; for 2-D input both are identical.
        attn_scores = queries @ keys.transpose(-2, -1)
        # Divide by sqrt(d_k) so large dot products don't saturate the softmax.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec

In [70]:
# Same seed as the step-by-step cells, so row 2 of the output reproduces context_vec_2.
torch.manual_seed(42)
sa_v1 = SelfAttention_v1(d_in=d_in, d_out=d_out)
sa_v1(inputs)

tensor([[1.3751, 0.8610],
        [1.4201, 0.8892],
        [1.4198, 0.8890],
        [1.3533, 0.8476],
        [1.3746, 0.8606],
        [1.3620, 0.8532]], grad_fn=<MmBackward0>)

We can use `nn.Linear` layers instead, which effectively perform matrix multiplications when their bias units are disabled.
Linear layers also come with an optimized weight initialization scheme, which leads to more stable and effective training.

In [92]:
class SelfAttention_v2(nn.Module):
    """Scaled dot-product self-attention using ``nn.Linear`` projections.

    Computes the same thing as ``SelfAttention_v1``, but the query/key/value
    projections are ``nn.Linear`` layers. These use PyTorch's default weight
    initialization and can optionally carry a bias.

    Args:
        d_in: Size of each input embedding (last dimension of ``x``).
        d_out: Size of the query/key/value projections and of the output.
        qkv_bias: Whether the three projection layers include a bias term.
    """

    def __init__(self, d_in: int, d_out: int, qkv_bias: bool = False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one context vector per input token.

        Args:
            x: Input embeddings of shape (num_tokens, d_in), optionally with
                leading batch dimensions, e.g. (batch, num_tokens, d_in).

        Returns:
            Context vectors of shape (..., num_tokens, d_out).
        """
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Swap only the last two dims (tokens <-> features). `.T` reverses *all*
        # dims, which breaks batched 3-D inputs; for 2-D input both are identical.
        attn_scores = queries @ keys.transpose(-2, -1)
        # Divide by sqrt(d_k) so large dot products don't saturate the softmax.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec

In [95]:
torch.manual_seed(42)
sa_v2 = SelfAttention_v2(d_in=d_in, d_out=d_out)

# Exercise 3.1: nn.Linear (bias=False) computes the same thing as the raw
# nn.Parameter matrices of SelfAttention_v1, only initialized differently.
# nn.Linear stores its weight transposed relative to v1's matrices, hence the
# `.T` when copying. Uncomment to give sa_v2 the same weights as sa_v1:
# sa_v2.W_query.weight = nn.Parameter(sa_v1.W_query.T)
# sa_v2.W_key.weight = nn.Parameter(sa_v1.W_key.T)
# sa_v2.W_value.weight = nn.Parameter(sa_v1.W_value.T)

sa_v2(inputs)  # differs from sa_v1(inputs) because the weight initialization schemes differ

tensor([[0.3755, 0.2777],
        [0.3761, 0.2831],
        [0.3761, 0.2833],
        [0.3768, 0.2763],
        [0.3754, 0.2836],
        [0.3772, 0.2746]], grad_fn=<MmBackward0>)

## 3.5 Hiding future words with causal attention