# Chapter 3: Coding Attention Mechanisms

## 3.3.1: A simple self-attention mechanism without weights

In [None]:
import torch
# Toy input: one 3-dimensional embedding vector per token of the sentence
# "Your journey starts with one step".
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # x^1: Your
        [0.55, 0.87, 0.66],  # x^2: journey
        [0.57, 0.85, 0.64],  # x^3: starts
        [0.22, 0.58, 0.33],  # x^4: with
        [0.77, 0.25, 0.10],  # x^5: one
        [0.05, 0.80, 0.55],  # x^6: step
    ]
)

In [None]:
# Attention scores of every input token with respect to a single query token:
# each score is the dot product of that token's vector with the query vector.
query = inputs[1]  # the second input token serves as the query
attn_scores_2 = torch.stack([torch.dot(x_i, query) for x_i in inputs])
print(attn_scores_2)

In [None]:
# Simplest normalization: divide every score by the total so the weights sum to 1.
attn_weights_2_tmp = torch.div(attn_scores_2, attn_scores_2.sum())
print("Attention weights:", attn_weights_2_tmp)
print("Sum:", torch.sum(attn_weights_2_tmp))

In [None]:
# Same as above, but using softmax
def softmax_naive(x):
    """Return the softmax of ``x`` taken along dimension 0.

    Args:
        x: Tensor of scores. For a 1-D tensor the result sums to 1; for
            higher-rank input each slice along dimension 0 is normalized.

    Returns:
        Tensor with the same shape as ``x`` holding ``exp(x) / sum(exp(x))``
        along dimension 0.
    """
    # Nothing to normalize; also avoids reducing over an empty dimension.
    if x.numel() == 0:
        return torch.exp(x)
    # Softmax is unchanged by subtracting a constant, so shift by the maximum
    # to keep exp() from overflowing to inf (which would yield NaN weights).
    shifted = x - x.amax(dim=0, keepdim=True)
    exps = torch.exp(shifted)  # computed once and reused for the denominator
    return exps / exps.sum(dim=0)

# Turn the raw scores into weights with the hand-written softmax above.
attn_weights_2_naive = softmax_naive(x=attn_scores_2)
print("Attention weights:", attn_weights_2_naive)
print("Sum:", torch.sum(attn_weights_2_naive))

In [None]:
# Same normalization again, this time with PyTorch's built-in softmax.
attn_weights_2 = attn_scores_2.softmax(dim=0)
print("Attention weights:", attn_weights_2)
print("Sum:", torch.sum(attn_weights_2))

In [None]:
# Context vector for the second token: the sum of all input vectors, each
# scaled by its attention weight.
query = inputs[1]
context_vec_2 = torch.zeros(query.size())
for weight, x_i in zip(attn_weights_2, inputs):
    context_vec_2 += weight * x_i
print(context_vec_2)

## 3.3.2: Computing attention weights for all input tokens

In [None]:
# Pairwise attention scores: entry [i, j] is the dot product of input i with
# input j. The matrix is sized from the data instead of a hard-coded 6 so the
# cell keeps working if the number of input tokens changes.
attn_scores = torch.empty(inputs.shape[0], inputs.shape[0])
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = torch.dot(x_i, x_j)
print(attn_scores)

In [None]:
# The same pairwise dot products as a single matrix product, without Python loops.
attn_scores = torch.matmul(inputs, inputs.T)
print(attn_scores)

In [None]:
# Softmax over the last dimension, so that each row of weights sums to 1.
attn_weights = attn_scores.softmax(dim=-1)
print(attn_weights)

In [None]:
# Verify that the normalization worked: every row should sum to 1.
# The second row's sum is read from the tensor itself rather than from
# hand-copied printout values, so it checks the actual result.
row_2_sum = attn_weights[1].sum().item()
print("Row 2 sum:", row_2_sum)
print("All row sums:", attn_weights.sum(dim=-1))

In [None]:
# All context vectors at once: row i is the attention-weighted sum of the
# input vectors for token i.
all_context_vecs = torch.matmul(attn_weights, inputs)
print(all_context_vecs)

In [None]:
# Sanity check: the second row of the batched result should match the context
# vector that was computed for the second token on its own.
print("Previous 2nd context vector:", context_vec_2)
print("Current 2nd context vector:", all_context_vecs[1, :])

## 3.4.1: Computing the attention weights step by step

In [None]:
x_2 = inputs[1]  # the second input element
d_in = inputs.size(1)  # input embedding size
d_out = 2  # output embedding size

# Seed first so the random projection matrices are reproducible. The three
# matrices are drawn in the order query, key, value.
# requires_grad=False because nothing is trained here; actual training would
# need requires_grad=True.
torch.manual_seed(123)
W_query, W_key, W_value = (
    torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
    for _ in range(3)
)

# Project the single token into query, key and value space.
query_2 = torch.matmul(x_2, W_query)
key_2 = torch.matmul(x_2, W_key)
value_2 = torch.matmul(x_2, W_value)
print(query_2)

In [None]:
# Attending from a single query still needs the key and value vectors of
# every input token.
keys = torch.matmul(inputs, W_key)
values = torch.matmul(inputs, W_value)

print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

In [None]:
# Attention score of the second query against the second key only, as a
# single worked example before computing the scores for all keys.
keys_2 = keys[1]  # index 1 is the second token (indexing starts at 0)
attn_score_22 = torch.dot(query_2, keys_2)
print(attn_score_22)

In [None]:
# Generalize: scores of the second query against every key in one product.
attn_scores_2 = torch.matmul(query_2, keys.T)
print(attn_scores_2)

In [None]:
# Scaled dot-product attention: divide the scores by the square root of the
# key dimension before applying softmax.
d_k = keys.size(-1)
attn_weights_2 = (attn_scores_2 / d_k ** 0.5).softmax(dim=-1)
print(attn_weights_2)

In [None]:
# Context vector for the second token: attention-weighted sum of the values.
context_vec_2 = torch.matmul(attn_weights_2, values)
print(context_vec_2)

## 3.4.2: Implementing a compact self-attention Python class

In [None]:
import torch.nn as nn

class SelfAttention_v1(nn.Module):
    """Single-head scaled dot-product self-attention using raw weight matrices.

    The query, key and value projections are plain ``nn.Parameter`` matrices
    of shape ``(d_in, d_out)`` initialized with ``torch.rand``.
    """

    def __init__(self, d_in: int, d_out: int):
        """Create the three projection matrices.

        Args:
            d_in: Size of each input embedding vector.
            d_out: Size of each projected (query/key/value) vector.
        """
        super().__init__()
        # Creation order (query, key, value) is kept so that a fixed seed
        # keeps producing the same weights.
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute one context vector per input token.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)``; any leading batch
                dimensions, e.g. ``(batch, num_tokens, d_in)``, are also
                accepted.

        Returns:
            Tensor shaped like ``x`` with the last dimension replaced by
            ``d_out``.
        """
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        # Swap only the last two dimensions: unlike ``.T`` this also works
        # when the input carries leading batch dimensions.
        attn_scores = queries @ keys.transpose(-2, -1)  # omega
        # Scale by sqrt(d_out) before the row-wise softmax.
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1] ** 0.5, dim=-1
        )
        context_vec = attn_weights @ values
        return context_vec

In [None]:
# Seed before constructing the module so its random weights are reproducible,
# then run the whole input sequence through it.
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in=d_in, d_out=d_out)
print(sa_v1(inputs))

In [None]:
class SelfAttention_v2(nn.Module):
    """Single-head scaled dot-product self-attention built on ``nn.Linear``.

    The query, key and value projections are ``nn.Linear`` layers, optionally
    with a bias term.
    """

    def __init__(self, d_in: int, d_out: int, qkv_bias: bool = False):
        """Create the three projection layers.

        Args:
            d_in: Size of each input embedding vector.
            d_out: Size of each projected (query/key/value) vector.
            qkv_bias: Whether the projection layers include a bias term.
        """
        super().__init__()
        # Creation order (query, key, value) is kept so that a fixed seed
        # keeps producing the same weights.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute one context vector per input token.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)``; any leading batch
                dimensions, e.g. ``(batch, num_tokens, d_in)``, are also
                accepted.

        Returns:
            Tensor shaped like ``x`` with the last dimension replaced by
            ``d_out``.
        """
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        # Swap only the last two dimensions: unlike ``.T`` this also works
        # when the input carries leading batch dimensions.
        attn_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(d_out) before the row-wise softmax.
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1] ** 0.5, dim=-1
        )
        context_vec = attn_weights @ values
        return context_vec

In [None]:
# Seed before constructing the module so its random weights are reproducible,
# then run the whole input sequence through it.
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in=d_in, d_out=d_out)
print(sa_v2(inputs))

## 3.5.1: Applying a causal attention mask

In [None]:
# Unmasked attention weights first, reusing the query and key layers of sa_v2.
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
attn_scores = torch.matmul(queries, keys.T)
attn_weights = (attn_scores / keys.size(-1) ** 0.5).softmax(dim=-1)
print(attn_weights)

In [None]:
# Lower-triangular mask: ones on and below the diagonal, zeros above it.
# Multiplying by it zeroes the weights for tokens not yet encountered, i.e.
# the entries above the diagonal.
context_length = attn_scores.size(0)
mask_simple = torch.ones(context_length, context_length).tril()
print(mask_simple)

In [None]:
# Element-wise product with the mask zeroes every weight above the diagonal.
masked_simple = torch.mul(attn_weights, mask_simple)
print(masked_simple)

In [None]:
# After masking, rows no longer sum to 1; divide each row by its own sum.
# keepdim=True keeps a trailing dimension so the division broadcasts per row.
row_sums = torch.sum(masked_simple, dim=-1, keepdim=True)
masked_simple_norm = torch.div(masked_simple, row_sums)
print(masked_simple_norm)

In [None]:
# A more efficient approach: set the scores above the diagonal to -inf before
# softmax, so they come out as zero weights and no renormalization is needed.
mask = torch.ones(context_length, context_length).triu(diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), float("-inf"))
print(masked)

In [None]:
# Scale the masked scores and apply softmax row by row; the -inf entries
# become zero weights.
attn_weights = (masked / keys.shape[-1] ** 0.5).softmax(dim=-1)
print(attn_weights)