# Build a Large Language Model (from scratch)

<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp" width="100px">

Author of notes: https://github.com/deburky

## Chapter 3: Coding attention mechanisms

### Context length

* The context length is the maximum number of tokens the model can attend to when making a prediction. For example, with a context length of 512, the model can only look at the last 512 tokens.

* The representation of the last token in the window attends to all previous tokens, so it carries information about them. The model uses this representation to predict the next token in the sequence.

> During inference, the context window holds both the prompt and the tokens the model has generated so far. Tokens are always generated one at a time. With a context length of 512, a 511-token prompt leaves room for exactly one generated token before the window is full. From then on (or immediately, if the prompt already has 512 tokens), the oldest tokens must be dropped, typically by keeping only the last 512 tokens, before each new prediction.
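A minimal sketch of that cropping step, assuming a model that accepts at most `context_length` tokens. The `next_token` function is a hypothetical stand-in for a real model (it just returns the window length modulo a toy vocabulary size), so only the windowing logic is meaningful here.

```python
def crop_to_context(token_ids, context_length):
    """Keep only the most recent `context_length` tokens."""
    return token_ids[-context_length:]


def next_token(window, vocab_size=50):
    """Hypothetical stand-in for a model's prediction (not a real LLM)."""
    return len(window) % vocab_size


def generate(prompt_ids, max_new_tokens, context_length):
    token_ids = list(prompt_ids)
    window_sizes = []
    for _ in range(max_new_tokens):
        # The model only ever sees the last `context_length` tokens
        window = crop_to_context(token_ids, context_length)
        window_sizes.append(len(window))
        # Exactly one new token is appended per step
        token_ids.append(next_token(window))
    return token_ids, window_sizes


tokens, sizes = generate(prompt_ids=range(511), max_new_tokens=3, context_length=512)
print(len(tokens))  # 514
print(sizes)        # [511, 512, 512]
```

The first step sees the full 511-token prompt; every later step sees a window capped at 512 tokens.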

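```python
import torch

VOCAB_SIZE = 1000
EMB_DIM = 128
CONTEXT_LENGTH = 1024

# Token embeddings: one learned vector per vocabulary entry
embeddings = torch.nn.Embedding(VOCAB_SIZE, EMB_DIM)
emb_weights = embeddings.weight.data  # shape: [VOCAB_SIZE, EMB_DIM]

# Positional embeddings: one learned vector per position in the context window
pos_embeddings = torch.nn.Embedding(CONTEXT_LENGTH, EMB_DIM)
pos_weights = pos_embeddings.weight.data  # shape: [CONTEXT_LENGTH, EMB_DIM]
```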

```python
# Simulate an input token sequence of length 64
input_ids = torch.randint(0, VOCAB_SIZE, (1, 64))  # (batch=1, seq_len=64)
position_ids = torch.arange(0, input_ids.size(1)).unsqueeze(0)  # (1, 64)

print("Input IDs:", input_ids[0][:5])

# Look up token and position embeddings
tok_emb = embeddings(input_ids)        # shape: [1, 64, 128]
pos_emb = pos_embeddings(position_ids) # shape: [1, 64, 128]

# Combine them (element-wise addition)
combined = tok_emb + pos_emb  # shape: [1, 64, 128]
combined
```

```
Input IDs: tensor([11846,  1905, 19843,  7581, 13627])

tensor([[[ 0.5851, -1.2503, -2.9723,  ..., -0.2092,  2.3059,  0.4104],
         [ 2.2785, -1.0144, -0.3601,  ...,  1.2413,  0.1254, -0.1556],
         [-0.9434, -0.4623,  0.9273,  ...,  0.3718,  1.0774, -1.3366],
         ...,
         [-0.6133,  1.6394, -1.4568,  ..., -0.1773, -0.0806, -1.7509],
         [-0.2132, -1.5831, -2.8863,  ...,  2.2222,  4.4905, -0.2002],
         [ 1.3369,  0.5410, -1.2254,  ...,  0.1226,  0.7172,  0.7807]]],
       grad_fn=<AddBackward0>)
```
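Because the positional embedding table has exactly `CONTEXT_LENGTH` rows, positions beyond the context window have no embedding at all. The sketch below uses deliberately tiny, hypothetical sizes (a context length of 4) to show this, and also shows that positional embeddings of shape `[seq_len, emb_dim]` broadcast across the batch dimension when added to token embeddings.

```python
import torch

torch.manual_seed(0)
vocab_size, emb_dim, context_length = 10, 8, 4  # toy sizes for illustration

tok_embedding = torch.nn.Embedding(vocab_size, emb_dim)
pos_embedding = torch.nn.Embedding(context_length, emb_dim)

# A batch of 2 sequences with 4 tokens each
token_ids = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
positions = torch.arange(token_ids.size(1))  # tensor([0, 1, 2, 3])

# [2, 4, 8] + [4, 8] broadcasts over the batch dimension
x = tok_embedding(token_ids) + pos_embedding(positions)
print(x.shape)  # torch.Size([2, 4, 8])

# Position 4 does not exist in a table with 4 rows (valid indices are 0..3)
raised = False
try:
    pos_embedding(torch.arange(5))
except IndexError as err:
    raised = True
    print("IndexError:", err)
```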

```python
from transformers import AutoTokenizer
import torch
import torch.nn as nn

# Setup
VOCAB_SIZE = 1000
EMB_DIM = 128
CONTEXT_LENGTH = 1024

# Tokenizer: truncate or pad the text to exactly CONTEXT_LENGTH tokens
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
text = "Transformers are powerful models for sequence data."
tokens = tokenizer(
    text, return_tensors="pt", max_length=CONTEXT_LENGTH,
    truncation=True, padding="max_length"
)

input_ids = tokens["input_ids"]   # shape: [1, CONTEXT_LENGTH]

# Match the embedding table to the tokenizer's vocabulary
VOCAB_SIZE = tokenizer.vocab_size

# Embedding layers
embeddings = nn.Embedding(VOCAB_SIZE, EMB_DIM)
pos_embeddings = nn.Embedding(CONTEXT_LENGTH, EMB_DIM)

# Create position IDs: 0, 1, ..., CONTEXT_LENGTH - 1
position_ids = torch.arange(0, input_ids.size(1)).unsqueeze(0)

# Get embeddings
tok_emb = embeddings(input_ids)
pos_emb = pos_embeddings(position_ids)
combined = tok_emb + pos_emb  # shape: [1, CONTEXT_LENGTH, EMB_DIM]

# Inspect
print("Input IDs:", input_ids)
print("Token Embeddings Shape:", tok_emb.shape)
print("Combined Embedding Shape:", combined.shape)
```

```
Input IDs: tensor([[  101, 19081,  2024,  ...,     0,     0,     0]])
Token Embeddings Shape: torch.Size([1, 1024, 128])
Combined Embedding Shape: torch.Size([1, 1024, 128])
```


### Self-attention (not trainable)

Self-attention serves as the cornerstone of every LLM based on the transformer architecture.

Self-attention is a mechanism that allows each position in the input sequence to consider the relevance of, or “attend to,” all other positions in the same sequence when computing the representation of that position. It is a key component of contemporary LLMs such as the GPT series.

```python
import torch
from rich import print as rprint
from IPython.display import HTML, display

display(HTML(
    """In self-attention, our goal is to calculate context vectors <code>z(i)</code>
    for each element <code>x(i)</code> in the input sequence. A context vector can be
    interpreted as an enriched embedding vector."""
))

inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

rprint(inputs.shape)  # 6 tokens, each a 3-dimensional embedding

# Use the second input token, x(2) = "journey", as the query
query = inputs[1]
rprint(query)

# Attention scores: dot product of each input vector with the query vector
attn_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
    attn_scores_2[i] = torch.dot(x_i, query)
rprint(attn_scores_2)

display(HTML(
    """The main goal behind the normalization is to obtain attention weights that sum up to 1.
    In practice, it's more common and advisable to use the softmax function for normalization,
    since it also keeps all weights positive."""
))

attn_weights_2 = torch.softmax(attn_scores_2, dim=0)
rprint(f"Attention weights: {attn_weights_2}")
rprint(f"Sum: {attn_weights_2.sum()}")

display(HTML(
    """The next step is calculating the context vector <code>z(2)</code> by multiplying
    the embedded input tokens, <code>x(i)</code>, with the corresponding attention weights
    and then summing the resulting vectors. Thus, context vector <code>z(2)</code> is the
    weighted sum of all input vectors, obtained by multiplying each input vector
    by its corresponding attention weight:
    """
))

# Calculate context vector z(2) as the weighted sum of all input vectors
context_vec_2 = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    context_vec_2 += attn_weights_2[i] * x_i
rprint(context_vec_2)
```
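In equation form, the cell above computes the following for query token 2, with $T = 6$ input tokens ($\omega$ are the attention scores and $\alpha$ the softmax-normalized attention weights):

$$
\omega_{2i} = x^{(i)} \cdot x^{(2)}, \qquad
\alpha_{2i} = \frac{\exp(\omega_{2i})}{\sum_{j=1}^{T} \exp(\omega_{2j})}, \qquad
z^{(2)} = \sum_{i=1}^{T} \alpha_{2i}\, x^{(i)}
$$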

---

The dot product is a measure of similarity because it quantifies how closely two vectors are aligned: a higher dot product indicates a greater degree of alignment or similarity between the vectors.
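To make the link to similarity concrete: the dot product equals the product of the two vector lengths times the cosine of the angle between them, so it grows both with alignment and with magnitude. The sketch below compares it with cosine similarity (the length-normalized version, via `torch.nn.functional.cosine_similarity`) for the "journey", "starts", and "one" embeddings from the toy input.

```python
import torch
import torch.nn.functional as F

journey = torch.tensor([0.55, 0.87, 0.66])
starts = torch.tensor([0.57, 0.85, 0.64])
one = torch.tensor([0.77, 0.25, 0.10])

# Dot product: |a| * |b| * cos(angle)
dot_js = torch.dot(journey, starts)
dot_jo = torch.dot(journey, one)

# Cosine similarity: the dot product of the length-normalized vectors
cos_js = F.cosine_similarity(journey, starts, dim=0)
cos_jo = F.cosine_similarity(journey, one, dim=0)

print(f"journey.starts = {dot_js:.4f}, cos = {cos_js:.4f}")
print(f"journey.one    = {dot_jo:.4f}, cos = {cos_jo:.4f}")
```

"journey" and "starts" point in almost the same direction, so both measures rank them as more similar than "journey" and "one".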

```python
display(HTML(
    """First, we calculate attention scores for each pair of input vectors. <br><br> Then,
    we normalize the scores with softmax to obtain attention weights. <br><br>
    Finally, we calculate the context vectors by taking weighted sums of the input vectors.
    """
))

# Attention scores: matrix of pairwise dot products (a Gram matrix, not a
# covariance matrix, since the inputs are not centered)
attn_scores = inputs @ inputs.T  # shape: [6, 6]
rprint(attn_scores)

# Normalize each row to get attention weights
attn_weights = torch.softmax(attn_scores, dim=-1)
rprint(attn_weights)

# Context vectors: row i is the weighted sum of the inputs for query i
all_context_vecs = attn_weights @ inputs  # shape: [6, 3]
rprint(all_context_vecs)
```
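As a consistency check, the second row of the matrix-based result should equal the context vector `z(2)` computed with explicit loops earlier. A self-contained sketch that re-creates the toy inputs:

```python
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],
     [0.55, 0.87, 0.66],
     [0.57, 0.85, 0.64],
     [0.22, 0.58, 0.33],
     [0.77, 0.25, 0.10],
     [0.05, 0.80, 0.55]]
)

# Loop version for query 2 (index 1)
query = inputs[1]
scores_2 = torch.stack([torch.dot(x_i, query) for x_i in inputs])
weights_2 = torch.softmax(scores_2, dim=0)
z_2 = sum(w * x_i for w, x_i in zip(weights_2, inputs))

# Matrix version for all queries at once
all_z = torch.softmax(inputs @ inputs.T, dim=-1) @ inputs

print(torch.allclose(all_z[1], z_2, atol=1e-6))  # True
```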

### Self-attention (trainable)

Our next step will be to implement the self-attention mechanism used in the original transformer architecture, the GPT models, and most other popular LLMs. This self-attention mechanism is also called scaled dot-product attention.

Weight parameters are the fundamental, learned coefficients that define the network’s connections, while attention weights are dynamic, context-specific values.

```python
import torch
torch.manual_seed(123)

display(HTML(
    """The most notable difference is the introduction of weight matrices
    that are updated during model training.<br><br> These trainable weight matrices
    are crucial so that the model (specifically, the attention module
    inside the model) can learn to produce "good" context vectors.
    """
))

# Select the second input vector as the query token
x_2 = inputs[1]
d_in = inputs.shape[1]  # input embedding size (3)
d_out = 2               # output embedding size

display(HTML(
    """First, we initialize three weight matrices: <code>W_query</code>, <code>W_key</code>,
    and <code>W_value</code>. Each matrix has a shape of <code>(d_in, d_out)</code>.
    """
))

# requires_grad=False only keeps the printed output uncluttered; training would need True
W_query, W_key, W_value = [
    torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False) for _ in range(3)
]

display(HTML(
    """Then we multiply the input <code>x(2)</code> by each weight matrix to obtain its
    query, key, and value vectors. Only the query and key vectors are needed to calculate
    the attention scores between <code>x(2)</code> and each input vector; the value vectors
    are used later to build the context vector.
    """
))

query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value
rprint(query_2)

display(HTML(
    """Computing keys and values for all inputs projects the six input tokens from
    a three-dimensional onto a two-dimensional embedding space:"""
))

keys = inputs @ W_key
values = inputs @ W_value
rprint(f"keys.shape: {keys.shape}, values.shape: {values.shape}")

display(HTML(
    """The unnormalized attention score between <code>x(2)</code> and itself is:"""
))

keys_2 = keys[1]
attn_score_22 = query_2.dot(keys_2)
rprint(attn_score_22)

display(HTML(
    """Calculate all attention scores for query 2 via matrix multiplication:"""
))

# Attention scores of query 2 against all keys
attn_scores_2 = query_2 @ keys.T
rprint(attn_scores_2)

display(HTML(
    """The second element matches the previous calculation.
    Next, we normalize the attention scores to get attention weights.
    The normalization differs slightly from before: we divide the scores by the square
    root of the embedding dimension of the keys before applying softmax. <br><br>
    <b>The scaling by the square root of the embedding dimension is the reason why this
    self-attention mechanism is also called scaled dot-product attention.</b>
    """
))

# Scale by sqrt(d_k), then normalize to get attention weights
d_k = keys.shape[-1]
rprint(f"d_k: {d_k}, Sqrt(d_k): {d_k**0.5:.2f}")
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
rprint(attn_weights_2, attn_weights_2.sum())

display(HTML(
    """Finally, we compute the context vector as the weighted sum of the value
    vectors, using the attention weights as coefficients.
    """
))

context_vec_2 = attn_weights_2 @ values
rprint(context_vec_2, context_vec_2.sum())
```
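Why divide by `sqrt(d_k)`? If query and key entries are roughly independent with unit variance, their dot product has a variance of about `d_k`, so scores grow with the key dimension and softmax saturates toward a one-hot distribution, where gradients become very small. The sketch below checks both effects, assuming standard-normal random vectors for illustration.

```python
import torch

torch.manual_seed(0)
d_k, n = 64, 10_000

# n random query/key pairs with standard-normal entries
q = torch.randn(n, d_k)
k = torch.randn(n, d_k)
dots = (q * k).sum(dim=-1)

print(f"variance of q.k:             {dots.var():.1f}")               # close to d_k
print(f"variance of q.k / sqrt(d_k): {(dots / d_k**0.5).var():.2f}")  # close to 1

# Large scores push softmax toward a one-hot distribution
scores = torch.tensor([1.0, 2.0, 3.0])
print(torch.softmax(scores * 10, dim=-1))  # almost all weight on the last element
print(torch.softmax(scores, dim=-1))       # noticeably smoother
```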

---
**Why query, key, and value?**

The terms **key**, **query**, and **value** in the context of attention mechanisms are borrowed from the domain of information retrieval and databases, where similar concepts are used to store, search, and retrieve information.

- A query is analogous to a search query in a database. It represents the current item (e.g., a word or token in a sentence) the model focuses on or tries to understand. The query is used to probe the other parts of the input sequence to determine how much attention to pay to them.

- The key is like a database key used for indexing and searching. In the attention mechanism, each item in the input sequence (e.g., each word in a sentence) has an associated key. These keys are used to match the query.

- The value in this context is similar to the value in a key-value pair in a database. It represents the actual content or representation of the input items. Once the model determines which keys (and thus which parts of the input) are most relevant to the query (the current focus item), it retrieves the corresponding values.
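The database analogy can be made concrete. A Python `dict` lookup returns the value of exactly one matching key; attention instead returns a blend of all values, weighted by how well each key matches the query. The toy keys, values, and query below are made up for illustration. Sharpening the scores (multiplying them by a large factor) makes the soft lookup approach the hard one.

```python
import torch

# Hard lookup: an exact key match returns exactly one value
table = {"a": 10.0, "b": 20.0, "c": 30.0}
print(table["b"])  # 20.0

# Soft lookup: keys and query are vectors, and the values are blended
keys = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]])
values = torch.tensor([10.0, 20.0, 30.0])
query = torch.tensor([0.0, 1.0])  # most similar to the second key

scores = keys @ query  # similarity of the query to each key
soft = torch.softmax(scores, dim=0) @ values
sharp = torch.softmax(scores * 50, dim=0) @ values

print(soft)   # a blend of all three values
print(sharp)  # close to 20.0, the value of the best-matching key
```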

```python
import torch
import torch.nn as nn
torch.manual_seed(123)

display(HTML(
    """Random initialization with <code>torch.rand()</code>:"""
))

class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        # Trainable projection matrices of shape (d_in, d_out)
        self.W_query, self.W_key, self.W_value = [
            nn.Parameter(torch.rand(d_in, d_out)) for _ in range(3)
        ]

    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        attn_scores = queries @ keys.T  # omega
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        return attn_weights @ values

sa_v1 = SelfAttention_v1(d_in, d_out)
rprint(sa_v1(inputs))

display(HTML(
    """<code>nn.Linear</code> layers with <code>bias=False</code> perform the same
    matrix multiplication, but use PyTorch's default weight initialization:"""
))

class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query, self.W_key, self.W_value = [
            nn.Linear(d_in, d_out, bias=qkv_bias) for _ in range(3)
        ]

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        return attn_weights @ values

torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
rprint(sa_v2(inputs))

display(HTML(
    """<code>nn.Linear</code> has an optimized weight
    initialization scheme, contributing to more stable and
    effective model training. <br><br>
    Note that <code>SelfAttention_v1</code> and <code>SelfAttention_v2</code>
    give different outputs because their weight matrices start from different
    values: a different random seed and a different initialization scheme."""
))
```
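`nn.Linear` stores its weight as a `(d_out, d_in)` matrix and computes `x @ weight.T`, whereas `SelfAttention_v1` multiplies by `(d_in, d_out)` parameters directly. Copying the transposed `nn.Linear` weights into a v1 module should therefore make both produce the same output. A self-contained sketch of that check, with compact copies of the two classes and random toy inputs:

```python
import torch
import torch.nn as nn


class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        q, k, v = x @ self.W_query, x @ self.W_key, x @ self.W_value
        w = torch.softmax(q @ k.T / k.shape[-1] ** 0.5, dim=-1)
        return w @ v


class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        q, k, v = self.W_query(x), self.W_key(x), self.W_value(x)
        w = torch.softmax(q @ k.T / k.shape[-1] ** 0.5, dim=-1)
        return w @ v


torch.manual_seed(789)
x = torch.rand(6, 3)
sa_v1, sa_v2 = SelfAttention_v1(3, 2), SelfAttention_v2(3, 2)

# Copy the transposed nn.Linear weights (d_out, d_in) into the v1 parameters (d_in, d_out)
with torch.no_grad():
    sa_v1.W_query.copy_(sa_v2.W_query.weight.T)
    sa_v1.W_key.copy_(sa_v2.W_key.weight.T)
    sa_v1.W_value.copy_(sa_v2.W_value.weight.T)

print(torch.allclose(sa_v1(x), sa_v2(x), atol=1e-6))  # True
```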

### Causal attention

Causal (masked) attention restricts each token to attending only to itself and the tokens before it, so the model cannot look at future tokens when predicting the next one. This is done by masking out the attention weights above the diagonal.

```python
import torch

# Attention weights from the trainable self-attention module above
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
rprint(attn_weights)

display(HTML(
    """<code>torch.tril</code> sets values above the diagonal to zero:"""
))

context_length = attn_scores.shape[0]
rprint(f"Context length: {context_length}")
mask_simple = torch.tril(torch.ones(context_length, context_length))
rprint(f"Mask:\n{mask_simple}")

display(HTML(
    """Multiplying by the mask sets the attention weights of future tokens to zero:"""
))

masked_simple = attn_weights * mask_simple
rprint(masked_simple)

display(HTML(
    """Renormalize each row so that the remaining weights sum to 1 again:"""
))

row_sums = masked_simple.sum(dim=-1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
rprint(masked_simple_norm)
```
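Masking after the softmax and then renormalizing gives the same result as filling the masked scores with `-inf` before the softmax, because `exp(-inf) = 0` and softmax already normalizes over the remaining entries. The `-inf` version needs only one normalization and is the approach used by the `CausalAttention` class. A sketch on random scores (toy data for illustration):

```python
import torch

torch.manual_seed(0)
T = 6
scores = torch.randn(T, T)
future = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)  # True above the diagonal

# Approach 1: softmax, zero out future positions, renormalize rows
w = torch.softmax(scores, dim=-1)
w = w * (~future)
w_renorm = w / w.sum(dim=-1, keepdim=True)

# Approach 2: set future scores to -inf, then apply a single softmax
w_inf = torch.softmax(scores.masked_fill(future, float("-inf")), dim=-1)

print(torch.allclose(w_renorm, w_inf, atol=1e-6))  # True
print(w_inf[0])  # the first token can only attend to itself
```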

---

When applying dropout with a rate of 50% to an attention weight matrix during training, each element is independently set to zero with probability 0.5, so about half of the elements are dropped. To compensate for the reduction in active elements, the remaining elements are scaled up by a factor of 1/0.5 = 2.
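A quick check with `torch.nn.Dropout`: in training mode each element is zeroed with probability 0.5 and the survivors are scaled by 2; in evaluation mode (after `.eval()`), dropout does nothing. A 6×6 matrix of ones stands in for an attention weight matrix.

```python
import torch

torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)
example = torch.ones(6, 6)

dropout.train()
out_train = dropout(example)
print(out_train)  # entries are 0.0 (dropped) or 2.0 (kept and scaled by 1 / 0.5)

dropout.eval()
out_eval = dropout(example)
print(torch.equal(out_eval, example))  # True: dropout is disabled at inference
```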

### Single attention head

An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors.

```python
torch.manual_seed(123)


class CausalAttention(nn.Module):
    def __init__(
        self, d_in, d_out, context_length,
        dropout, qkv_bias=False
    ):
        super().__init__()
        self.d_out = d_out
        self.W_query, self.W_key, self.W_value = [
            nn.Linear(d_in, d_out, bias=qkv_bias) for _ in range(3)
        ]
        self.dropout = nn.Dropout(dropout)
        # Upper-triangular mask (1s above the diagonal mark future tokens);
        # a buffer moves with the module across devices but is not trained
        self.register_buffer(
            'mask',
            torch.triu(
                torch.ones(context_length, context_length),
                diagonal=1
            )
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Batched scores: transpose the last two dims of keys (dim 0 is the batch)
        attn_scores = queries @ keys.transpose(1, 2)
        # Future positions get -inf, so softmax assigns them zero weight
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf
        )
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        attn_weights = self.dropout(attn_weights)

        return attn_weights @ values

# A batch of two identical input sequences: shape [2, 6, 3]
batch = torch.stack((inputs, inputs), dim=0)
rprint(f"Batch: {batch.shape}")
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
rprint(f"context_vecs.shape: {context_vecs.shape}")
rprint(context_vecs)
```
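PyTorch 2.0 and later also provide a fused `torch.nn.functional.scaled_dot_product_attention`, which with `is_causal=True` applies a causal mask and the same `1/sqrt(d_k)` scaling. The sketch below (assuming PyTorch >= 2.0, and using random queries, keys, and values rather than trained projections) compares it with the manual computation from `CausalAttention.forward` without dropout.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(123)
b, num_tokens, d = 2, 6, 4
queries, keys, values = (torch.randn(b, num_tokens, d) for _ in range(3))

# Manual causal attention, as in CausalAttention.forward (without dropout)
mask = torch.triu(torch.ones(num_tokens, num_tokens, dtype=torch.bool), diagonal=1)
scores = (queries @ keys.transpose(1, 2)).masked_fill(mask, float("-inf"))
manual = torch.softmax(scores / d**0.5, dim=-1) @ values

# Built-in fused version
fused = F.scaled_dot_product_attention(queries, keys, values, is_causal=True)

print(torch.allclose(manual, fused, atol=1e-5))  # True
```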

### Multi-head attention

Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions.

`torch.cat` performs concatenation along a specified dimension.
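For example, concatenating two head outputs of shape `[batch, num_tokens, d_out]` along the last dimension doubles the feature size. Random tensors stand in for the head outputs here:

```python
import torch

head_1 = torch.rand(2, 6, 2)  # [batch, num_tokens, d_out]
head_2 = torch.rand(2, 6, 2)

combined = torch.cat([head_1, head_2], dim=-1)
print(combined.shape)  # torch.Size([2, 6, 4])

# The first d_out features come from head_1, the next d_out from head_2
print(torch.equal(combined[..., :2], head_1), torch.equal(combined[..., 2:], head_2))  # True True
```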

```python
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(
        self, d_in, d_out, context_length,
        dropout, num_heads, qkv_bias=False
    ):
        super().__init__()
        # num_heads independent causal attention heads, each producing d_out features
        self.heads = nn.ModuleList(
            [
                CausalAttention(
                    d_in, d_out, context_length, dropout, qkv_bias
                )
                for _ in range(num_heads)
            ]
        )

    def forward(self, x):
        # Concatenate the head outputs along the feature dimension
        return torch.cat([head(x) for head in self.heads], dim=-1)

torch.manual_seed(123)
context_length = batch.shape[1]  # This is the number of tokens
d_in, d_out = 3, 2
mha = MultiHeadAttentionWrapper(
    d_in, d_out, context_length, 0.0, num_heads=2
)
context_vecs = mha(batch)

rprint(context_vecs)
rprint(f"context_vecs.shape: {context_vecs.shape}")  # [batch, num_tokens, num_heads * d_out]
```

### Multi-head attention with weight splits

In softmax regression (multiclass classification), we train a single weight matrix for all classes instead of a separate classifier for each class.

The `MultiHeadAttention` class below follows the same idea. Instead of a separate set of query, key, and value matrices per head (as in `MultiHeadAttentionWrapper`), it stacks the per-head matrices into one large matrix per projection, computes all heads with a single matrix multiplication, and then splits the result into heads.
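The sketch below checks this equivalence numerically: one `nn.Linear(d_in, d_out)` with `d_out = num_heads * head_dim`, reshaped with `view`, yields the same per-head queries as multiplying by each head's slice of the weight matrix separately. The sizes are arbitrary toy values.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
b, num_tokens, d_in = 2, 5, 8
num_heads, head_dim = 4, 3
d_out = num_heads * head_dim  # 12

W_query = nn.Linear(d_in, d_out, bias=False)  # one stacked matrix for all heads
x = torch.randn(b, num_tokens, d_in)

# All heads at once, then split the last dimension into heads
stacked = W_query(x).view(b, num_tokens, num_heads, head_dim)

# Each head separately, using its block of rows of the stacked weight [d_out, d_in]
for h in range(num_heads):
    W_h = W_query.weight[h * head_dim:(h + 1) * head_dim]  # [head_dim, d_in]
    per_head = x @ W_h.T                                    # [b, num_tokens, head_dim]
    assert torch.allclose(stacked[:, :, h, :], per_head, atol=1e-6)

print("stacked projection matches the per-head projections")
```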

---

**What is different?**

`MultiHeadAttentionWrapper`: *Concatenation of all heads → [batch, num_tokens, d_out * num_heads]*

Each head has its own query, key, and value matrices, and the head outputs are concatenated without being mixed. This is useful for illustration but less efficient, since the heads are computed one after another.

`MultiHeadAttention`: *Final projection to d_out → [batch, num_tokens, d_out]*

This matches standard Transformer implementations (e.g., GPT, BERT) and the original architecture of Vaswani et al. Here `d_out` is the total output size, split evenly across the heads (`head_dim = d_out // num_heads`), and a final projection layer (`out_proj`) mixes the outputs of the heads.

```python
import torch


class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), "d_out must be divisible by num_heads"
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # each head gets an equal slice of d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # mixes the outputs of all heads
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask", torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        # One matrix multiplication per projection for all heads: [b, num_tokens, d_out]
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Split the last dimension into heads: [b, num_tokens, num_heads, head_dim]
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Move heads before tokens: [b, num_heads, num_tokens, head_dim]
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Scores for all heads at once: [b, num_heads, num_tokens, num_tokens]
        attn_scores = queries @ keys.transpose(2, 3)
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Causal mask: future positions get -inf, i.e. zero weight after softmax
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Back to [b, num_tokens, num_heads, head_dim]
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Merge the heads into [b, num_tokens, d_out]; view() needs a contiguous tensor
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)
        return context_vec
```
```python
torch.manual_seed(123)

display(HTML(
    """Take a tensor <code>a</code> with shape
    <code>(b, num_heads, num_tokens, head_dim) = (1, 2, 3, 4)</code>:"""
))

a = torch.tensor([[[[0.2745, 0.6584, 0.2775, 0.8573],
                    [0.8993, 0.0390, 0.9268, 0.7388],
                    [0.7179, 0.7058, 0.9156, 0.4340]],

                   [[0.0772, 0.3565, 0.1479, 0.5331],
                    [0.4066, 0.2318, 0.4545, 0.9737],
                    [0.4606, 0.5159, 0.4220, 0.5786]]]])

rprint(a.shape)
rprint(a)

display(HTML(
    """Now we perform a batched matrix multiplication
    between the tensor itself and a view of the tensor
    where we transposed the last two dimensions,
    <code>num_tokens</code> and <code>head_dim</code>:"""
))

rprint(a @ a.transpose(2, 3))

display(HTML(
    """This is equivalent to computing the matrix product separately for each head:"""
))

first_head = a[0, 0, :, :]
first_res = first_head @ first_head.T
rprint(f"First head\n\na[0, 0, :, :] @ a[0, 0, :, :].T:\n\n{first_res}")

second_head = a[0, 1, :, :]
second_res = second_head @ second_head.T
rprint(f"Second head\n\na[0, 1, :, :] @ a[0, 1, :, :].T:\n\n{second_res}")

display(HTML(
    """The results show that the output dimension is
    directly controlled by the <code>d_out</code> argument:"""
))

batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)

rprint(context_vecs)
rprint(f"context_vecs.shape: {context_vecs.shape}")
```
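The split-and-merge reshaping inside `MultiHeadAttention.forward` is lossless, but the merge step needs `contiguous()`: `transpose` only swaps strides, and `view` cannot merge dimensions whose memory layout is no longer contiguous. A small check with toy sizes, where the first `.contiguous()` mimics `attn_weights @ values`, which returns a fresh tensor in the `[b, num_heads, num_tokens, head_dim]` layout:

```python
import torch

b, num_tokens, num_heads, head_dim = 2, 6, 2, 3
d_out = num_heads * head_dim
x = torch.arange(b * num_tokens * d_out, dtype=torch.float32).view(b, num_tokens, d_out)

# Split: [b, num_tokens, d_out] -> [b, num_heads, num_tokens, head_dim]
heads = x.view(b, num_tokens, num_heads, head_dim).transpose(1, 2).contiguous()

# Transpose back: [b, num_tokens, num_heads, head_dim], no longer contiguous
swapped = heads.transpose(1, 2)
print(swapped.is_contiguous())  # False

view_failed = False
try:
    swapped.view(b, num_tokens, d_out)
except RuntimeError:
    view_failed = True
    print("view() fails on the non-contiguous tensor")

# Merge: copy into contiguous memory first, then view
merged = swapped.contiguous().view(b, num_tokens, d_out)
print(torch.equal(merged, x))  # True
```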

**Exercise 3.3 Initializing GPT-2 size attention modules**

Using the `MultiHeadAttention` class, initialize a multi-head attention module that has the same number of attention heads as the smallest GPT-2 model (12 attention heads). Also ensure that you use the respective input and output embedding sizes similar to GPT-2 (768 dimensions). Note that the smallest GPT-2 model supports a context length of 1,024 tokens.

```python
torch.manual_seed(123)

# GPT-2 small: 12 heads, 768-dimensional embeddings, 1,024-token context
batch_size, context_length, d_in = 1, 1024, 768
d_out = 768
n_heads = 12  # head_dim = 768 // 12 = 64

batch = torch.rand(batch_size, context_length, d_in)

mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=n_heads)
context_vecs = mha(batch)

display(HTML(
    """In GPT models, the input embeddings and the attention outputs have the same dimensions:"""
))

rprint(context_vecs)
rprint(f"context_vecs.shape: {context_vecs.shape}")  # torch.Size([1, 1024, 768])
```
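For a sense of scale, the sketch below counts the parameters of layers matching `MultiHeadAttention` at GPT-2 size with `qkv_bias=False` (three bias-free 768×768 projections plus the 768×768 `out_proj` with bias), and the number of attention-score elements for one 1,024-token sequence. The layer list mirrors the class above; float32 storage is an assumption.

```python
import torch.nn as nn

d_in = d_out = 768
num_heads, context_length = 12, 1024

# Same layers as MultiHeadAttention(d_in, d_out, ..., qkv_bias=False)
layers = nn.ModuleList([
    nn.Linear(d_in, d_out, bias=False),  # W_query
    nn.Linear(d_in, d_out, bias=False),  # W_key
    nn.Linear(d_in, d_out, bias=False),  # W_value
    nn.Linear(d_out, d_out),             # out_proj (with bias)
])
num_params = sum(p.numel() for p in layers.parameters())
print(f"parameters: {num_params:,}")  # 2,360,064

# Attention scores per sequence: [num_heads, context_length, context_length]
score_elems = num_heads * context_length * context_length
print(f"score elements: {score_elems:,}, about {score_elems * 4 / 2**20:.0f} MiB in float32")
```

The score tensor grows quadratically with the context length, which is one reason long contexts are expensive.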

### Built-in multi-head attention (`nn.MultiheadAttention`)

[Docs > torch.nn > MultiheadAttention](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html)

```python
multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
attn_output, attn_output_weights = multihead_attn(query, key, value)
```

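By default (`batch_first=False`), `query`, `key`, and `value` have shape `(L, N, E)`, where L is the target sequence length, N is the batch size, and E is the embedding dimension. With `batch_first=True`, they have shape `(N, L, E)`.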

```python
# (L, N, E): default sequence-first layout
import torch
from torch.nn import MultiheadAttention

batch_size, context_length, embed_dim = 1, 1024, 768

# Input tensors of shape [context_length, batch_size, embed_dim] = [1024, 1, 768]
q, k, v = [torch.rand(context_length, batch_size, embed_dim) for _ in range(3)]

# Initialize MultiheadAttention (batch_first=False by default)
mha = MultiheadAttention(embed_dim, num_heads=12, dropout=0.0)

# Compute attention; the inputs are already (L, N, E), so no transpose is needed
context_vecs, context_weights = mha(q, k, v)

rprint(context_vecs.shape)  # Expected: [1024, 1, 768]
```

```python
# (N, L, E): batch-first layout
from torch.nn import MultiheadAttention

batch_size, context_length, embed_dim = 1, 1024, 768

# Input tensors of shape [batch_size, context_length, embed_dim] = [1, 1024, 768]
q, k, v = [torch.rand(batch_size, context_length, embed_dim) for _ in range(3)]

# Initialize MultiheadAttention with batch_first=True
mha = MultiheadAttention(embed_dim, num_heads=12, dropout=0.0, batch_first=True)

# Compute attention
context_vecs, context_weights = mha(q, k, v)

rprint(context_vecs.shape)  # Expected: [1, 1024, 768]
```
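`nn.MultiheadAttention` does not apply a causal mask by default. One way to get GPT-style behavior is to pass a boolean `attn_mask` in which `True` marks positions that may not be attended to; this is the same upper-triangular convention as the `mask` buffer in `CausalAttention`. A small sketch with toy sizes and random inputs:

```python
import torch
from torch.nn import MultiheadAttention

torch.manual_seed(123)
batch_size, num_tokens, embed_dim = 2, 6, 16

x = torch.rand(batch_size, num_tokens, embed_dim)
causal_mask = torch.triu(torch.ones(num_tokens, num_tokens, dtype=torch.bool), diagonal=1)

mha = MultiheadAttention(embed_dim, num_heads=4, dropout=0.0, batch_first=True)
context_vecs, attn_weights = mha(x, x, x, attn_mask=causal_mask)

print(context_vecs.shape)  # torch.Size([2, 6, 16])
print(attn_weights.shape)  # torch.Size([2, 6, 6]): averaged over heads by default
print(attn_weights[0])     # zeros above the diagonal
```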