# Self-Attention



Start with our token embeddings.

For this example we use 3-dimensional embeddings, created arbitrarily for illustration, to represent each token. We also assume each word is a single token (rather than breaking words into smaller sub-word tokens).

In reality, tokens are often sub-words, and embeddings have many more dimensions. GPT-3 used over 12k dimensions for its token embeddings.

In [1]:
import torch

inputs = torch.tensor(
    [[0.21, 0.47, 0.91], # I
     [0.52, 0.11, 0.65], # can't
     [0.03, 0.85, 0.19], # find
     [0.73, 0.64, 0.39], # the
     [0.13, 0.55, 0.68], # light
     [0.22, 0.77, 0.08]] # switch
)

In [27]:
inputs.shape

torch.Size([6, 3])

### Step 1 - Compute Attention Scores

Each token in our 6-word context window above needs to know how much it should "pay attention" to the other tokens in the sequence.

In other words: how much does each token affect another token's meaning or context?

To tell the model how much attention a token should give the other tokens in the sequence, we first compute attention scores and then normalize them into attention weights.

##### Computing Attention Scores
For every token in our context window, we take that token as the "query" and compute the dot product between the query and every token in the context window, including the query itself.

The dot product acts as a measure of similarity: the more closely two embeddings point in the same direction (and the larger they are), the larger their dot product.
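As a quick sanity check, the dot product is just the sum of element-wise products. This small sketch recomputes the score between "light" and "I" by hand, with the two toy embeddings copied from `inputs` above so the block runs on its own:

```python
import torch

light = torch.tensor([0.13, 0.55, 0.68])    # embedding for "light"
i_token = torch.tensor([0.21, 0.47, 0.91])  # embedding for "I"

# Sum of element-wise products, written out by hand
manual = (light * i_token).sum()
# The same value via torch.dot
builtin = torch.dot(light, i_token)

# Both are about 0.9046, the first entry of attention_scores_for_light
print(manual, builtin)
```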

In [None]:
# Computing attention scores for only one token.
# Example taking the token "light" as the query
query = inputs[4]

# Create empty tensor, sized for each token in our context window
attention_scores_for_light = torch.empty(inputs.shape[0])

# Compute attention scores for each token in the context window
for i, token_embedding in enumerate(inputs):
    attention_scores_for_light[i] = torch.dot(query, token_embedding)

In [3]:
attention_scores_for_light

tensor([0.9046, 0.5701, 0.6006, 0.7121, 0.7818, 0.5065])

### Step 2 - Normalize the Attention Scores to Weights

We want the scores for each query to sum to 1, so we normalize them with softmax to get our attention weights.
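Softmax exponentiates each score and divides by the sum of the exponentials, so every weight is positive and the weights sum to 1. A minimal sketch comparing a hand-written softmax with `F.softmax`, using the scores for "light" copied from the output above; the plain-sum normalization is included only for contrast:

```python
import torch
import torch.nn.functional as F

# Attention scores for "light", copied from the output above
scores = torch.tensor([0.9046, 0.5701, 0.6006, 0.7121, 0.7818, 0.5065])

# Softmax by hand: exp(s_i) / sum_j exp(s_j)
manual = torch.exp(scores) / torch.exp(scores).sum()
builtin = F.softmax(scores, dim=0)

# Naive alternative for contrast: divide by the plain sum.
# It also sums to 1, but weights can turn negative if any score is negative.
naive = scores / scores.sum()

print(manual)
print(torch.allclose(manual, builtin))  # True
```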

In [5]:
import torch.nn.functional as F

attention_weights_for_light = F.softmax(attention_scores_for_light, dim=0)
sum(attention_weights_for_light)

tensor(1.0000)

### Step 3 - Compute the Context Vector
Multiply each token's input vector by its attention weight (with respect to the query token) and sum the results. The sum is the query token's context vector (in this simplified version).
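The loop in the next cell can also be written as a single vector-matrix product: a `(6,)` weight vector times the `(6, 3)` input matrix gives the `(3,)` context vector. A self-contained sketch that rebuilds the toy inputs and checks that both forms agree:

```python
import torch
import torch.nn.functional as F

inputs = torch.tensor(
    [[0.21, 0.47, 0.91],  # I
     [0.52, 0.11, 0.65],  # can't
     [0.03, 0.85, 0.19],  # find
     [0.73, 0.64, 0.39],  # the
     [0.13, 0.55, 0.68],  # light
     [0.22, 0.77, 0.08]]  # switch
)
query = inputs[4]                           # "light"
weights = F.softmax(inputs @ query, dim=0)  # one weight per token, shape (6,)

# Loop form: accumulate weight_i * x_i
loop_result = torch.zeros_like(query)
for w, x in zip(weights, inputs):
    loop_result += w * x

# Vectorized form: (6,) @ (6, 3) -> (3,)
vectorized = weights @ inputs

print(vectorized)  # should match context_vector_for_light
```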

In [15]:
context_vector_for_light = torch.zeros_like(query)

for i, token_embedding in enumerate(inputs):
    context_vector_for_light += attention_weights_for_light[i] * token_embedding

context_vector_for_light

tensor([0.3039, 0.5600, 0.5155])

This is a simplified example of enriching a single token's input embedding with contextual information by "attending to" the other tokens in the sequence.

# Extending this to All Tokens
1) Compute attention scores
2) Compute attention weights
3) Compute context vectors

In [31]:
inputs.shape

torch.Size([6, 3])

In [32]:
inputs.T.shape

torch.Size([3, 6])

In [None]:
# Compute the dot product of the input embeddings with themselves
# Matrix multiplication (@) of the input embeddings with their transpose (.T)
attention_scores = inputs @ inputs.T
attention_scores

tensor([[1.0931, 0.7524, 0.5787, 0.8090, 0.9046, 0.4809],
        [0.7524, 0.7050, 0.2326, 0.7035, 0.5701, 0.2511],
        [0.5787, 0.2326, 0.7595, 0.6400, 0.6006, 0.6763],
        [0.8090, 0.7035, 0.6400, 1.0946, 0.7121, 0.6846],
        [0.9046, 0.5701, 0.6006, 0.7121, 0.7818, 0.5065],
        [0.4809, 0.2511, 0.6763, 0.6846, 0.5065, 0.6477]])

Each row above is an attention score vector for that token in our context window.

In [None]:
# dim=-1 applies softmax along the last dimension of attention_scores.
# Each row holds one query token's scores against every token, so this
# normalizes each row independently and every row sums to 1.
attention_weights = F.softmax(attention_scores, dim=-1)
attention_weights

tensor([[0.2256, 0.1605, 0.1349, 0.1698, 0.1869, 0.1223],
        [0.2024, 0.1931, 0.1204, 0.1928, 0.1687, 0.1226],
        [0.1641, 0.1161, 0.1966, 0.1745, 0.1677, 0.1809],
        [0.1705, 0.1534, 0.1440, 0.2268, 0.1547, 0.1505],
        [0.2068, 0.1480, 0.1526, 0.1706, 0.1829, 0.1389],
        [0.1552, 0.1233, 0.1887, 0.1902, 0.1592, 0.1834]])

In [35]:
context_vectors = attention_weights @ inputs
context_vectors

tensor([[0.3101, 0.5440, 0.5383],
        [0.3362, 0.5293, 0.5323],
        [0.2897, 0.6003, 0.4588],
        [0.3387, 0.5656, 0.4880],
        [0.3039, 0.5600, 0.5155],
        [0.3023, 0.5974, 0.4544]])

In [36]:
attention_weights.shape, inputs.shape, context_vectors.shape

(torch.Size([6, 6]), torch.Size([6, 3]), torch.Size([6, 3]))
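Two properties of this simplified version are easy to check: because each raw score is x_i · x_j, the score matrix equals its own transpose, while the per-row softmax makes the weight matrix asymmetric. Row 4 of `context_vectors` should also reproduce the single-query result for "light". A self-contained sketch:

```python
import torch
import torch.nn.functional as F

inputs = torch.tensor(
    [[0.21, 0.47, 0.91], [0.52, 0.11, 0.65], [0.03, 0.85, 0.19],
     [0.73, 0.64, 0.39], [0.13, 0.55, 0.68], [0.22, 0.77, 0.08]]
)
attention_scores = inputs @ inputs.T
attention_weights = F.softmax(attention_scores, dim=-1)
context_vectors = attention_weights @ inputs

# x_i . x_j == x_j . x_i, so the raw score matrix is symmetric
print(torch.allclose(attention_scores, attention_scores.T))    # True

# Softmax normalizes each row separately, so the weights are not symmetric
print(torch.allclose(attention_weights, attention_weights.T))  # False

# Row 4 matches the earlier single-query computation for "light"
single = F.softmax(inputs @ inputs[4], dim=0) @ inputs
print(torch.allclose(context_vectors[4], single))              # True
```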

# Self-Attention with Trainable Weights
More specifically, we implement scaled dot-product attention.

### Three Trainable Weight Matrices
`W_query`, `W_key`, and `W_value`.

These three matrices project each embedding from our input tokens into their respective Query, Key, and Value vectors.

Each token's context vector is the attention-weighted sum of all the value vectors.
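In matrix form, the steps in this section can be summarized with the standard scaled dot-product attention formula. It is written here in the row-vector convention this notebook uses (X holds one token embedding per row, and d_k is the key dimension, `keys.shape[-1]`):

```latex
Q = X W_{\text{query}}, \qquad K = X W_{\text{key}}, \qquad V = X W_{\text{value}}

\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
```

Row i of the result is token i's context vector, z_i = Σ_j α_ij v_j, where α_ij is row i of the softmax output.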

In [None]:
x_4 = inputs[4]  # Using the word "light" as our example again
x_4

tensor([0.1300, 0.5500, 0.6800])

In practice, the query, key, and value vectors for each attention head usually have fewer dimensions than the input embeddings.

In multi-head attention it's common to use:
```
model_input_dimensions / number_of_attention_heads = KQV output dimensions for each head
```
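For concreteness, here is that arithmetic as a tiny helper (`head_dim` is a hypothetical name used only for illustration). It assumes the common case where the combined Q/K/V width equals the embedding size: GPT-2 small uses 768 dimensions and 12 heads, and the largest GPT-3 model uses 12,288 dimensions and 96 heads.

```python
def head_dim(d_model: int, num_heads: int) -> int:
    """Per-head Q/K/V size when d_model is split evenly across heads (hypothetical helper)."""
    if d_model % num_heads != 0:
        raise ValueError("d_model must be divisible by num_heads")
    return d_model // num_heads

print(head_dim(768, 12))    # GPT-2 small: 64
print(head_dim(12288, 96))  # GPT-3 175B: 128
```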

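In PyTorch, the last dimension of the input tensor must match the input dimension of the weight matrices (`d_in`, the number of rows of `W_query`, `W_key`, and `W_value`); otherwise the matrix multiplication fails with a shape error.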
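A minimal sketch of that rule, with arbitrary random tensors: a `(6, 3)` input works with a `(3, 2)` weight matrix, while a `(6, 4)` input raises a `RuntimeError`.

```python
import torch

W = torch.rand(3, 2)        # (d_in=3, d_out=2)

ok = torch.rand(6, 3) @ W   # last dim 3 matches d_in: works
print(ok.shape)             # torch.Size([6, 2])

raised = False
try:
    torch.rand(6, 4) @ W    # last dim 4 != d_in=3
except RuntimeError as err:
    raised = True
    print("shape mismatch:", err)
```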

In [6]:
d_in = inputs.shape[-1] # last dimension of the input tensor, size of our token embeddings
d_out = 2 # arbitrary output size, chosen to differ from d_in for illustration

torch.manual_seed(123)
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key   = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)


In [8]:
W_query

Parameter containing:
tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]])

In [9]:
# Compute the vectors for our x_4 input
query_4 = x_4 @ W_query
key_4 = x_4 @ W_key
value_4 = x_4 @ W_value

In [11]:
print(x_4.shape, W_query.shape, query_4.shape)
query_4

torch.Size([3]) torch.Size([3, 2]) torch.Size([2])


tensor([0.2272, 1.0351])

In [12]:
# Compute for all inputs
keys = inputs @ W_key
values = inputs @ W_value

print(keys.shape, values.shape)

torch.Size([6, 2]) torch.Size([6, 2])


In [14]:
# Compute unnormalized attention scores, for one token
attention_score_4 = query_4.dot(keys[4])

# Compute attention scores for all tokens for a given query
attention_scores_4 = query_4 @ keys.T
attention_scores_4

tensor([1.1143, 0.6675, 0.8276, 0.9134, 0.9867, 0.7040])

In [None]:
# Compute the attention weights with softmax, scaling the scores by sqrt(d_k), the key dimension
d_k = keys.shape[-1]
attention_weights_4 = torch.softmax(attention_scores_4 / (d_k ** 0.5), dim=-1)
sum(attention_weights_4)

tensor(1.0000)
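Why divide by sqrt(d_k)? If the query and key components are independent with unit variance (an assumption made here for illustration), the variance of their dot product grows roughly linearly with d_k. Large scores push softmax toward a near one-hot output with very small gradients; dividing by sqrt(d_k) brings the variance back to about 1. A seeded sketch:

```python
import torch

torch.manual_seed(0)
d_k = 64
n = 10_000

# Random query/key vectors with unit-variance components (illustrative assumption)
q = torch.randn(n, d_k)
k = torch.randn(n, d_k)

raw = (q * k).sum(dim=-1)    # n dot products
scaled = raw / d_k ** 0.5    # the same scores divided by sqrt(d_k)

print(raw.var().item())      # roughly d_k, about 64
print(scaled.var().item())   # roughly 1
```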

In [17]:
# Get the context vector for the token "light"
context_vector_4 = attention_weights_4 @ values
context_vector_4

tensor([0.2620, 0.7043])

# SelfAttention Class

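Linear layers (`torch.nn.Linear` with `bias=False`) replace the manual matrix multiplication. The operation is equivalent, but `nn.Linear` uses a better weight initialization scheme than `torch.rand`, and it stores its weight transposed, with shape `(d_out, d_in)`.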
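A short sketch confirming the equivalence on random data: with `bias=False`, `layer(x)` is `x @ layer.weight.T`.

```python
import torch

torch.manual_seed(123)
x = torch.rand(6, 3)
layer = torch.nn.Linear(3, 2, bias=False)

# nn.Linear stores its weight as (out_features, in_features) = (2, 3)
print(layer.weight.shape)  # torch.Size([2, 3])

# With bias=False, the layer computes x @ W.T
print(torch.allclose(layer(x), x @ layer.weight.T, atol=1e-6))  # True
```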



In [3]:
class SelfAttention(torch.nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super(SelfAttention, self).__init__()
        # Initialize the linear layers (weight matrices) for the query, key, and value
        self.W_query = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = torch.nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, inputs):
        # Compute the query, key, and value vectors for all tokens in the input
        # (matrix multiplication of the input tensor with the weight matrices)
        keys = self.W_key(inputs)
        queries = self.W_query(inputs)
        values = self.W_value(inputs)

        # Compute the unnormalized attention scores for every query at once
        # (dot product between each token's query vector and the key vectors of all tokens)
        attention_scores = queries @ keys.T
        d_k = keys.shape[-1]
        # Scale by sqrt(d_k) and normalize each row with softmax
        attention_weights = torch.softmax(attention_scores / (d_k ** 0.5), dim=-1)

        # Compute the context vector for each token
        context_vector = attention_weights @ values
        return context_vector
    
self_attention = SelfAttention(d_in=3, d_out=2)
print(self_attention(inputs))

tensor([[ 0.2480, -0.0219],
        [ 0.2481, -0.0209],
        [ 0.2488, -0.0185],
        [ 0.2487, -0.0175],
        [ 0.2482, -0.0210],
        [ 0.2489, -0.0177]], grad_fn=<MmBackward0>)


# Causal Attention

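Mask future tokens in the sequence so that no token can look ahead: each token may only attend to itself and the tokens before it.

1. Mask the unnormalized attention scores above the diagonal (future tokens) with negative infinity.
2. Then scale the scores and normalize them with softmax. Because exp(-inf) = 0, the masked positions receive zero weight.
3. Optionally, apply dropout to the resulting attention weights.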
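Filling with -inf before softmax gives the same weights as applying softmax first, zeroing the future positions, and renormalizing each row, but it does so in a single softmax. A seeded sketch on random scores (the 4×4 size is arbitrary):

```python
import torch

torch.manual_seed(0)
scores = torch.randn(4, 4)
n = scores.shape[0]
future = torch.triu(torch.ones(n, n), diagonal=1).bool()  # True above the diagonal

# Approach used here: -inf before softmax, since exp(-inf) = 0
weights_a = torch.softmax(scores.masked_fill(future, -torch.inf), dim=-1)

# Alternative: softmax first, zero out the future, then renormalize each row
weights_b = torch.softmax(scores, dim=-1).masked_fill(future, 0.0)
weights_b = weights_b / weights_b.sum(dim=-1, keepdim=True)

print(torch.allclose(weights_a, weights_b, atol=1e-6))  # True
```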

In [None]:
# Get attention scores for all tokens
queries = self_attention.W_query(inputs)
keys = self_attention.W_key(inputs)
attention_scores = queries @ keys.T

# Create the mask and apply it to our attention scores
context_length = inputs.shape[0]
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attention_scores.masked_fill(mask.bool(), -torch.inf)
masked

tensor([[-0.1339,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.1302, -0.0741,    -inf,    -inf,    -inf,    -inf],
        [ 0.0732,  0.0379,  0.0315,    -inf,    -inf,    -inf],
        [-0.0057, -0.0097, -0.0499, -0.0668,    -inf,    -inf],
        [-0.0730, -0.0409, -0.0532, -0.0537, -0.0646,    -inf],
        [ 0.0839,  0.0422,  0.0275,  0.0157,  0.0663,  0.0084]],
       grad_fn=<MaskedFillBackward0>)

In [14]:
# Compute the attention weights with the masked scores
masked_attention_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=-1)
masked_attention_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4901, 0.5099, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3394, 0.3310, 0.3295, 0.0000, 0.0000, 0.0000],
        [0.2548, 0.2541, 0.2470, 0.2441, 0.0000, 0.0000],
        [0.1978, 0.2023, 0.2005, 0.2005, 0.1989, 0.0000],
        [0.1718, 0.1668, 0.1651, 0.1637, 0.1697, 0.1629]],
       grad_fn=<SoftmaxBackward0>)

In [None]:
# Apply dropout to the masked attention weights
# Dropout zeroes each weight with probability p; the weights that are kept are scaled by 1/(1-p) so the expected value of the output stays the same
dropout = torch.nn.Dropout(0.5)
masked_with_dropout = dropout(masked_attention_weights)
masked_with_dropout

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.9802, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.6788, 0.6621, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5097, 0.5082, 0.0000, 0.4881, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.3302, 0.0000, 0.3394, 0.0000]],
       grad_fn=<MulBackward0>)
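The 1/(1-p) rescaling is easy to see on a tensor of ones: with p = 0.5, each entry becomes either 0 or 2. Dropout is only active in training mode; after `.eval()` it passes its input through unchanged. A seeded sketch:

```python
import torch

torch.manual_seed(123)
dropout = torch.nn.Dropout(p=0.5)
x = torch.ones(4, 4)

dropout.train()  # training mode (the default for a new module)
y = dropout(x)
# Every entry is either 0 (dropped) or 1 / (1 - 0.5) = 2 (kept and rescaled)
print(y)

dropout.eval()   # evaluation mode: dropout does nothing
print(torch.equal(dropout(x), x))  # True
```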

# Causal Self-Attention with Dropout

In [16]:
# Stack two copies of the inputs into a batch to test batched inputs
batch = torch.stack((inputs, inputs), dim=0)
batch.shape
# A batch of 2 sequences, each with 6 tokens of 3 dimensions

torch.Size([2, 6, 3])

In [None]:
class CausalAttention(torch.nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.W_query = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        # Initialize dropout
        self.dropout = torch.nn.Dropout(dropout)
        # register_buffer is used to store the mask in the model's state_dict
        # it is not a trainable parameter
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))


    def forward(self, inputs):
        batch_size, num_tokens, d_in = inputs.shape
        keys = self.W_key(inputs)
        queries = self.W_query(inputs)
        values = self.W_value(inputs)

        # transpose(1, 2) swaps the token and feature dimensions while leaving the
        # batch dimension in place (.T would reverse all three dimensions)
        attention_scores = queries @ keys.transpose(1, 2)
        # A trailing underscore (masked_fill_) means the operation happens in place
        # Slice the mask to num_tokens, since an input may be shorter than context_length
        attention_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attention_weights = torch.softmax(attention_scores / (keys.shape[-1] ** 0.5), dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Compute the context vector for each token from this module's own attention weights
        context_vector = attention_weights @ values
        return context_vector
    
causal_attention = CausalAttention(d_in=3, d_out=2, context_length=6, dropout=0.5)
# Dropout makes the values differ from run to run, so we check the output shape
print(causal_attention(batch).shape)

torch.Size([2, 6, 2])


# Multihead Attention

We could simply stack several single-head attention modules, but it's more efficient to use one weight matrix each for Q, K, and V and then split their outputs across the attention heads.
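The split is only a reshape: `view` adds a head dimension, and `transpose(1, 2)` moves it in front of the tokens, so head 0 sees the first `head_dim` features of every token, head 1 the next `head_dim`, and so on. A sketch on a small tensor of consecutive numbers (sizes chosen arbitrarily) that splits and then merges back:

```python
import torch

b, num_tokens, num_heads, head_dim = 2, 6, 2, 3
d_out = num_heads * head_dim
x = torch.arange(b * num_tokens * d_out, dtype=torch.float32).view(b, num_tokens, d_out)

# Split: (b, num_tokens, d_out) -> (b, num_heads, num_tokens, head_dim)
heads = x.view(b, num_tokens, num_heads, head_dim).transpose(1, 2)
print(heads.shape)  # torch.Size([2, 2, 6, 3])

# Head 0 holds the first head_dim features of each token
print(torch.equal(heads[:, 0], x[..., :head_dim]))  # True

# Merge back: transpose, make contiguous, flatten the head dimensions
merged = heads.transpose(1, 2).contiguous().view(b, num_tokens, d_out)
print(torch.equal(merged, x))  # True
```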

In [24]:
class MultiHeadAttention(torch.nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        # The size of each individual head 
        self.head_dim = d_out // num_heads

        # Same as before
        self.W_query = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        # this out_proj isn't necessary, but it's a common practice to combine the head outputs
        self.out_proj = torch.nn.Linear(d_out, d_out)
        self.dropout = torch.nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )

    def forward(self, x):
        batch_size, num_tokens, d_in = x.shape

        keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split each matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        # Each head gets its own head_dim-sized slice of the projected features,
        # so different heads can learn to attend to different relationships between tokens
        keys = keys.view(batch_size, num_tokens, self.num_heads, self.head_dim) 
        values = values.view(batch_size, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(batch_size, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        # This transpose allows us to perform the matrix multiplication on each head separately
        # Now each head holds a batch of tokens, each of size head_dim
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
        attn_scores.masked_fill_(mask_bool, -torch.inf)
        
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2) 
        
        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(batch_size, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec) # optional projection

        return context_vec

torch.manual_seed(123)

print(batch.shape)
batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)

context_vecs = mha(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

torch.Size([2, 6, 3])
tensor([[[0.2958, 0.4672],
         [0.2963, 0.4761],
         [0.2559, 0.4713],
         [0.2581, 0.4334],
         [0.2586, 0.4435],
         [0.2442, 0.4428]],

        [[0.2958, 0.4672],
         [0.2963, 0.4761],
         [0.2559, 0.4713],
         [0.2581, 0.4334],
         [0.2586, 0.4435],
         [0.2442, 0.4428]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])
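For comparison, PyTorch 2.0 and newer ships a fused `torch.nn.functional.scaled_dot_product_attention`, which by default scales by 1/sqrt of the last query dimension and, with `is_causal=True`, applies the same upper-triangular mask. This sketch assumes PyTorch 2.0+ and random tensors already shaped `(b, num_heads, num_tokens, head_dim)`; it reproduces the attention core of `MultiHeadAttention.forward` without dropout or the output projection:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(123)
b, num_heads, num_tokens, head_dim = 2, 2, 6, 4
queries = torch.randn(b, num_heads, num_tokens, head_dim)
keys = torch.randn(b, num_heads, num_tokens, head_dim)
values = torch.randn(b, num_heads, num_tokens, head_dim)

# Manual version: the same steps as in MultiHeadAttention.forward (no dropout)
scores = queries @ keys.transpose(2, 3)
mask = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1).bool()
scores = scores.masked_fill(mask, -torch.inf)
weights = torch.softmax(scores / head_dim ** 0.5, dim=-1)
manual = weights @ values

# Fused built-in (PyTorch 2.0+)
fused = F.scaled_dot_product_attention(queries, keys, values, is_causal=True)

print(torch.allclose(manual, fused, atol=1e-5))  # True
```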
