# Build a Large Language Model (from scratch)

<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp" width="100px">

Author of notes: https://github.com/deburky

## Chapter 3: Coding attention mechanisms

### Self-attention

Self-attention serves as the cornerstone of every LLM based on the transformer architecture, such as the GPT series.

Self-attention is a mechanism that allows each position in the input sequence to consider the relevancy of, or “attend to,” every position in the same sequence (including itself) when computing its own representation.

In [46]:
import torch
from rich import print as rprint
from IPython.display import HTML, display

display(HTML(
    """In self-attention, our goal is to calculate context vectors <code>z(i)</code>
    for each element <code>x(i)</code> in the input sequence. A context vector can be
    interpreted as an enriched embedding vector."""
))

inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

rprint(inputs.shape)

# Use the second input token ("journey") as the query
query = inputs[1]
rprint(query)

# Attention scores: dot product of each input vector with the query
attn_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
    attn_scores_2[i] = torch.dot(x_i, query)
rprint(attn_scores_2)

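display(HTML(
    """The main goal behind the normalization is to obtain attention weights that sum up to 1.
    Simply dividing each score by the sum of all scores would achieve this, but in practice
    it's more common and advisable to use the softmax function, which also keeps every
    weight positive and handles extreme values better."""
))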

attn_weights_2 = torch.softmax(attn_scores_2, dim=0)
rprint(f"Attention weights: {attn_weights_2}")
rprint(f"Sum: {attn_weights_2.sum()}")

display(HTML(
    """The next step is calculating the context vector <code>z(2)</code>: the weighted
    sum of all input vectors, obtained by multiplying each embedded input token
    <code>x(i)</code> by its corresponding attention weight and summing the results:
    """
))

# Calculate context vector z(2)
query = inputs[1]
context_vec_2 = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    context_vec_2 += attn_weights_2[i] * x_i
rprint(context_vec_2)
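
The normalization step in the cell above can be spelled out by hand. The following is a minimal sketch, not the book's code: `softmax_naive` is a hypothetical helper name, and the inputs are repeated only so the block runs on its own. It compares dividing by the sum, a hand-written softmax, and `torch.softmax`.

```python
import torch

# Same six token embeddings as above, repeated so the block is self-contained
inputs = torch.tensor(
    [[0.43, 0.15, 0.89],
     [0.55, 0.87, 0.66],
     [0.57, 0.85, 0.64],
     [0.22, 0.58, 0.33],
     [0.77, 0.25, 0.10],
     [0.05, 0.80, 0.55]]
)
scores = inputs @ inputs[1]  # attention scores for the second token

def softmax_naive(x):
    # Illustrative helper: exponentiate every score, then divide by the total
    return torch.exp(x) / torch.exp(x).sum(dim=0)

weights_sum = scores / scores.sum()           # plain sum normalization
weights_naive = softmax_naive(scores)         # hand-written softmax
weights_torch = torch.softmax(scores, dim=0)  # library softmax

print(weights_sum.sum(), weights_naive.sum(), weights_torch.sum())
print(torch.allclose(weights_naive, weights_torch))
```

All three results sum to 1. The hand-written version is fine for small scores like these, but `torch.softmax` is the safer choice because a naive exponential can overflow for large inputs.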

---

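The dot product is a measure of similarity because it quantifies how closely two vectors are aligned: for vectors of comparable length, a higher dot product indicates a greater degree of alignment or similarity. Note that the dot product also grows with the lengths of the vectors, not only with their alignment.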
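
As a small illustration of this idea (a sketch with made-up 2D vectors, not data from the book), the dot product is largest for vectors pointing the same way, zero for orthogonal vectors, and negative for opposite ones. Cosine similarity is included only to show what changes when vector length is factored out.

```python
import torch
import torch.nn.functional as F

a = torch.tensor([1.0, 0.0])
b = torch.tensor([0.9, 0.1])   # nearly the same direction as a
c = torch.tensor([0.0, 1.0])   # orthogonal to a
d = torch.tensor([-1.0, 0.0])  # opposite direction to a

dot_ab = torch.dot(a, b)
dot_ac = torch.dot(a, c)
dot_ad = torch.dot(a, d)
print(dot_ab, dot_ac, dot_ad)

# Doubling b doubles the dot product, but the direction (cosine) is unchanged
dot_a2b = torch.dot(a, 2 * b)
cos_b = F.cosine_similarity(a, b, dim=0)
cos_2b = F.cosine_similarity(a, 2 * b, dim=0)
print(dot_a2b, cos_b, cos_2b)
```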

In [44]:
display(HTML(
    """First, we calculate attention scores for each pair of input vectors. <br><br> Then,
    we normalize the scores with softmax to obtain attention weights. <br><br>
    Finally, we calculate the context vector by taking the weighted sum of the input vectors.
    """
))

# Attention scores: all pairwise dot products between input vectors
attn_scores = inputs @ inputs.T
rprint(attn_scores)

# Normalize to get attention weights
attn_weights = torch.softmax(attn_scores, dim=-1)
rprint(attn_weights)

# Context vectors
all_context_vecs = attn_weights @ inputs
rprint(all_context_vecs)
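
One way to sanity-check the matrix form is to compare it with the per-token loop from the first cell. This sketch assumes a random stand-in tensor `x` in place of the real embeddings, since only the shapes matter here.

```python
import torch

torch.manual_seed(0)
x = torch.rand(6, 3)  # stand-in for six 3-dimensional token embeddings

# Matrix form: pairwise scores, row-wise softmax, weighted sums
weights = torch.softmax(x @ x.T, dim=-1)
context_all = weights @ x

# Loop form, for the second token only
scores_2 = torch.stack([torch.dot(x_i, x[1]) for x_i in x])
weights_2 = torch.softmax(scores_2, dim=0)
context_2 = sum(w * x_i for w, x_i in zip(weights_2, x))

print(torch.allclose(context_all[1], context_2, atol=1e-6))
print(weights.sum(dim=-1))  # every row of attention weights sums to 1
```

The `dim=-1` argument matters: it normalizes across each row, so every query token gets its own distribution over the inputs.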

### Self-attention with trainable weights

Our next step will be to implement the self-attention mechanism used in the original transformer architecture, the GPT models, and most other popular LLMs. This self-attention mechanism is also called scaled dot-product attention.

Weight parameters are the fundamental, learned coefficients that define the network’s connections and stay fixed once training is done, while attention weights are dynamic, context-specific values that are recomputed for every input.
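
The distinction can be made concrete with a short sketch. The names `W_q`, `W_k`, and `attention_weights` are assumptions for illustration only: the same two weight matrices are reused for two different inputs, and the resulting attention weights differ.

```python
import torch

torch.manual_seed(0)
W_q = torch.rand(3, 2)  # weight parameters: the same for every input
W_k = torch.rand(3, 2)

def attention_weights(x):
    # Attention weights are derived from the input on every call
    scores = (x @ W_q) @ (x @ W_k).T
    return torch.softmax(scores / W_k.shape[-1] ** 0.5, dim=-1)

x_a = torch.rand(4, 3)  # two different 4-token inputs
x_b = torch.rand(4, 3)
w_a = attention_weights(x_a)
w_b = attention_weights(x_b)

print(torch.allclose(w_a, w_b))  # different inputs give different attention weights
```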

In [None]:
import torch
torch.manual_seed(123)

display(HTML(
    """The most notable difference is the introduction of weight matrices
    that are updated during model training.<br><br> These trainable weight matrices
    are crucial so that the model (specifically, the attention module
    inside the model) can learn to produce "good" context vectors.
    """
))

# Select second input vector
x_2 = inputs[1]
d_in = inputs.shape[1]
d_out = 2

display(HTML(
    """First, we initialize three weight matrices: <code>W_query</code>, <code>W_key</code>,
    and <code>W_value</code>. Each matrix has a shape of <code>(d_in, d_out)</code>.
    """
))

W_query, W_key, W_value = [
    torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False) for _ in range(3)
]

display(HTML(
    """Then we multiply input <code>x(2)</code> by the query, key, and value weight
    matrices to obtain its query, key, and value vectors. The query and key vectors
    are used to calculate the attention score between <code>x(2)</code> and each
    input vector; the value vectors are only needed later, for the context vector.
    """
))

query_2 = x_2 @ W_query 
key_2 = x_2 @ W_key 
value_2 = x_2 @ W_value
rprint(query_2)

display(HTML(
    """To score <code>x(2)</code> against every input we need the keys and values of
    all tokens. This projects the six input tokens from a three-dimensional
    onto a two-dimensional embedding space:"""
))

keys = inputs @ W_key 
values = inputs @ W_value
rprint(f"keys.shape: {keys.shape}, values.shape: {values.shape}")

# Attention scores
display(HTML(
    """The result for the unnormalized attention score is:"""
))

keys_2 = keys[1]
attn_score_22 = query_2.dot(keys_2)
rprint(attn_score_22)

# Attention scores
display(HTML(
    """Calculate all attention scores via matrix multiplication:"""
))

# Attention scores of query_2 against the keys of all input tokens
attn_scores_2 = query_2 @ keys.T
rprint(attn_scores_2)

display(HTML(
    """The second element matches the previous calculation.
    Next we normalize the attention scores to get attention weights.
    Compared with the simplified version, there is one difference in the normalization
    step: we first divide the scores by the square root of the embedding dimension
    of the keys. <br><br>
    <b>The scaling by the square root of the embedding dimension is the reason why this
    self-attention mechanism is also called scaled dot-product attention.</b>
    """
))

# Normalize to get attention weights
d_k = keys.shape[-1]
rprint(f"d_k: {d_k}, Sqrt(d_k): {d_k**0.5:.2f}")
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
rprint(attn_weights_2, attn_weights_2.sum())

display(HTML(
    """Finally, we get the context vector as the weighted sum of the value
    vectors, using the attention weights as the weights.
    """
))

context_vec_2 = attn_weights_2 @ values
rprint(context_vec_2, context_vec_2.sum())
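
To see why the scaling matters, it helps to look at a larger key dimension than the `d_out = 2` used here. The sketch below assumes random query and key vectors with `d_k = 256` (values chosen only for illustration): unscaled dot products spread out as `d_k` grows, which pushes softmax toward a near one-hot distribution, while dividing by `sqrt(d_k)` keeps the weights softer.

```python
import torch

torch.manual_seed(0)
d_k = 256
query = torch.randn(d_k)
keys = torch.randn(6, d_k)

scores = keys @ query  # spread of these scores grows with d_k
unscaled = torch.softmax(scores, dim=-1)
scaled = torch.softmax(scores / d_k**0.5, dim=-1)

print(f"largest unscaled weight: {unscaled.max():.4f}")
print(f"largest scaled weight:   {scaled.max():.4f}")
```

A sharply peaked softmax has very small gradients for most positions, which is the usual motivation given for the scaling factor.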

---
**Why query, key, and value?**

The terms **key**, **query**, and **value** in the context of attention mechanisms are borrowed from the domain of information retrieval and databases, where similar concepts are used to store, search, and retrieve information.

- A query is analogous to a search query in a database. It represents the current item (e.g., a word or token in a sentence) the model focuses on or tries to understand. The query is used to probe the other parts of the input sequence to determine how much attention to pay to them.

- The key is like a database key used for indexing and searching. In the attention mechanism, each item in the input sequence (e.g., each word in a sentence) has an associated key. These keys are used to match the query.

- The value in this context is similar to the value in a key-value pair in a database. It represents the actual content or representation of the input items. Once the model determines which keys (and thus which parts of the input) are most relevant to the query (the current focus item), it retrieves the corresponding values.
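
The database analogy can be sketched as a "soft" lookup. This is an illustrative toy with made-up keys and values, not code from the book: a Python dictionary returns exactly one value for an exact key match, whereas attention returns a blend of all values, weighted by how well each key matches the query.

```python
import torch

# Hard lookup: exactly one value, and the key must match exactly
table = {"cat": 1.0, "dog": 2.0, "car": 3.0}
hard = table["dog"]

# Soft lookup: keys and the query are vectors, and every value contributes
keys = torch.tensor([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]])
values = torch.tensor([[1.0], [2.0], [3.0]])
query = torch.tensor([1.0, 0.0])

weights = torch.softmax(keys @ query, dim=0)  # match quality per key
soft = weights @ values                       # weighted blend of the values

print(hard)
print(weights, soft)
```

The key most similar to the query receives the largest weight, so its value dominates the result without fully excluding the others.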

In [124]:
import torch
import torch.nn as nn
torch.manual_seed(123)

display(HTML(
    """Random initialization with <code>torch.rand()</code>:"""
))

class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query, self.W_key, self.W_value = [
            nn.Parameter(torch.rand(d_in, d_out)) for _ in range(3)
        ]

    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        attn_scores = queries @ keys.T  # omega
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        return attn_weights @ values

sa_v1 = SelfAttention_v1(d_in, d_out)
rprint(sa_v1(inputs))

display(HTML(
    """<code>nn.Linear</code> with <code>bias=False</code> performs the same
    matrix multiplication, with a different default weight initialization:"""
))

class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query, self.W_key, self.W_value = [
            nn.Linear(d_in, d_out, bias=qkv_bias) for _ in range(3)
        ]

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        return attn_weights @ values
    
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
rprint(sa_v2(inputs))

display(HTML(
    """<code>nn.Linear</code> has an optimized weight
    initialization scheme, contributing to more stable and
    effective model training. <br><br>
    Note that <code>SelfAttention_v1</code> and <code>SelfAttention_v2</code>
    give different outputs because their weight matrices start from
    different initial values."""
))
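
Apart from initialization, the two variants compute the same thing. The sketch below (with a stand-in input `x` and an assumed variable name `W`) shows that `nn.Linear` stores its weight in transposed form, with shape `(d_out, d_in)`, so copying the transposed weight into a plain `nn.Parameter` reproduces the layer's output.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_in, d_out = 3, 2
x = torch.rand(6, d_in)  # stand-in for the six input tokens

linear = nn.Linear(d_in, d_out, bias=False)
print(linear.weight.shape)  # (d_out, d_in), the transpose of the v1 layout

# The forward pass of a bias-free nn.Linear is x @ weight.T
manual = x @ linear.weight.T

# Copy the transposed weight into a v1-style parameter of shape (d_in, d_out)
W = nn.Parameter(linear.weight.T.detach().clone())

print(torch.allclose(linear(x), manual))
print(torch.allclose(linear(x), x @ W))
```

The same transposed copy, applied to all three projections, would make a `SelfAttention_v1` instance reproduce the outputs of a given `SelfAttention_v2` instance.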

### Multi-head attention

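Each head learns different aspects of the data, allowing the model to simultaneously attend to information from different representation subspaces at different positions. In GPT-style models every head is a causal attention head, so the next cell first shows how a causal mask hides future tokens.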

In [132]:
import torch

queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs) 
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
rprint(attn_weights)

display(HTML(
    """<code>torch.tril</code> sets values above the diagonal to zero:"""
))

context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
rprint(mask_simple)

display(HTML(
    """Setting the attention weights for future tokens to zero:"""
))

masked_simple = attn_weights * mask_simple
rprint(masked_simple)

display(HTML(
    """Renormalize each row so the remaining weights sum to 1:"""
))

row_sums = masked_simple.sum(dim=-1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
rprint(masked_simple_norm)
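
The `CausalAttention` class further down takes a shortcut: instead of masking after softmax and renormalizing, it fills future positions with `-inf` before softmax. The following sketch, using random stand-in scores, suggests that both routes give the same weights.

```python
import torch

torch.manual_seed(0)
scores = torch.rand(4, 4)  # stand-in for unnormalized attention scores
n = scores.shape[0]

# Route 1: softmax, zero out future positions, renormalize each row
weights = torch.softmax(scores, dim=-1)
masked = weights * torch.tril(torch.ones(n, n))
route_1 = masked / masked.sum(dim=-1, keepdim=True)

# Route 2: set future positions to -inf first; softmax maps them to exactly 0
future = torch.triu(torch.ones(n, n), diagonal=1).bool()
route_2 = torch.softmax(scores.masked_fill(future, -torch.inf), dim=-1)

print(torch.allclose(route_1, route_2, atol=1e-6))
print(route_2)
```

Route 2 needs only one normalization pass, which is presumably why it is the form used inside the module.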

---

When applying dropout to an attention weight matrix with a rate of 50% during training, each element is set to zero independently with probability 0.5, so on average half of the elements in the matrix are dropped. To compensate for the reduction in active elements, the values of the remaining elements are scaled up by a factor of 1/(1 - 0.5) = 2.
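
A quick way to see this behavior is to apply `torch.nn.Dropout` to a matrix of ones (a stand-in for attention weights, chosen so the scaling is easy to read). This is a minimal sketch; which elements are dropped depends on the random seed.

```python
import torch

torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)  # a new module starts in training mode
example = torch.ones(6, 6)

dropped = dropout(example)
print(dropped)  # every entry is either 0 (dropped) or 2 (kept and rescaled)

# In evaluation mode dropout is a no-op
dropout.eval()
unchanged = dropout(example)
print(torch.equal(unchanged, example))
```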

### Single attention head

In [None]:
torch.manual_seed(123)

class CausalAttention(nn.Module):
    def __init__(
        self, d_in, d_out, context_length,
        dropout, qkv_bias=False
    ):
        super().__init__()
        self.d_out = d_out
        self.W_query, self.W_key, self.W_value = [
            nn.Linear(d_in, d_out, bias=qkv_bias) for _ in range(3)
        ]
  
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            'mask',
            torch.triu(
                torch.ones(context_length, context_length),
                diagonal=1
            )
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1, 2)
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        attn_weights = self.dropout(attn_weights)

        return attn_weights @ values

batch = torch.stack((inputs, inputs), dim=0)
rprint(f"Batch: {batch.shape}")
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
rprint(f"context_vecs.shape: {context_vecs.shape}")
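
The defining property of causal attention is that a token's context vector cannot depend on later tokens. The sketch below checks this with a functional stand-in for `CausalAttention.forward` (no batch dimension, no dropout; `causal_attention` and the random weights are assumptions for illustration): changing the last token leaves all earlier context vectors untouched.

```python
import torch

torch.manual_seed(0)
d_in, d_out, n = 3, 2, 6
W_q, W_k, W_v = (torch.rand(d_in, d_out) for _ in range(3))

def causal_attention(x):
    # Same steps as the module above, for a single unbatched sequence
    q, k, v = x @ W_q, x @ W_k, x @ W_v
    future = torch.triu(torch.ones(len(x), len(x)), diagonal=1).bool()
    scores = (q @ k.T).masked_fill(future, -torch.inf)
    return torch.softmax(scores / d_out**0.5, dim=-1) @ v

x = torch.rand(n, d_in)
x_changed = x.clone()
x_changed[-1] = torch.rand(d_in)  # alter only the last token

out = causal_attention(x)
out_changed = causal_attention(x_changed)

print(torch.allclose(out[:-1], out_changed[:-1]))  # earlier positions are unaffected
print(torch.allclose(out[-1], out_changed[-1]))    # the last position does change
```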

### Stacking multiple single-head attention layers

In [142]:
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(
        self, d_in, d_out, context_length,
        dropout, num_heads, qkv_bias=False
    ):
        super().__init__()
        self.heads = nn.ModuleList(
            [
                CausalAttention(
                    d_in, d_out, context_length, dropout, qkv_bias
                )
                for _ in range(num_heads)
            ]
        )

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)
    
torch.manual_seed(123)
context_length = batch.shape[1]  # This is the number of tokens
d_in, d_out = 3, 2
mha = MultiHeadAttentionWrapper(
    d_in, d_out, context_length, 0.0, num_heads=2
)
context_vecs = mha(batch)

rprint(context_vecs)
rprint(f"context_vecs.shape: {context_vecs.shape}")
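
The output shape follows directly from the concatenation in `forward`. As a sketch with random stand-in tensors in place of real head outputs, concatenating `num_heads` tensors of width `d_out` along the last dimension gives a width of `num_heads * d_out`, here 2 * 2 = 4.

```python
import torch

batch_size, num_tokens, d_out, num_heads = 2, 6, 2, 2

# Stand-ins for what each CausalAttention head would return
head_outputs = [
    torch.rand(batch_size, num_tokens, d_out) for _ in range(num_heads)
]
combined = torch.cat(head_outputs, dim=-1)

print(combined.shape)  # (batch_size, num_tokens, num_heads * d_out)
# The first d_out columns are exactly the first head's output
print(torch.equal(combined[..., :d_out], head_outputs[0]))
```

Because the heads are run one after another in a Python list comprehension, this wrapper is easy to read but processes the heads sequentially.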