## IMPLEMENTING MULTI-HEAD ATTENTION WITH WEIGHT SPLITS

Instead of maintaining two separate classes, MultiHeadAttentionWrapper and CausalAttention, we can combine both of these concepts into a single MultiHeadAttention class.

Also, in addition to merging the MultiHeadAttentionWrapper with the CausalAttention code, we will make some other modifications to implement multi-head attention more efficiently.

In the MultiHeadAttentionWrapper, multiple heads are implemented by creating a list of CausalAttention objects (self.heads), each representing a separate attention head.

The CausalAttention class independently performs the attention mechanism, and the results from each head are concatenated.

In contrast, the following MultiHeadAttention class integrates the multi-head functionality within a single class.

It splits the input into multiple heads by reshaping the projected query, key, and value tensors and then combines the results from these heads after computing attention.

Before we discuss the MultiHeadAttention class further, let's look at the steps it performs:

Step 1: Reduce the projection dim to match desired output dim

Step 2: Use a Linear layer to combine head outputs

Step 3: Project the inputs to keys, queries, and values; tensor shape: (b, num_tokens, d_out)

Step 4: We implicitly split the matrix by adding a num_heads dimension. Then we unroll the last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)

Step 5: Transpose from shape (b, num_tokens, num_heads, head_dim) to (b, num_heads, num_tokens, head_dim)

Step 6: Compute dot product for each head

Step 7: Mask truncated to the number of tokens

Step 8: Use the mask to fill attention scores

Step 9: Compute the context vectors and transpose them back; tensor shape: (b, num_tokens, num_heads, head_dim)

Step 10: Combine heads, where self.d_out = self.num_heads * self.head_dim

Step 11: Add an optional linear projection
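The following is a minimal sketch of how the numbered steps could fit together. It assumes the constructor signature used in the example later in this section (d_in, d_out, context_length, dropout, num_heads). The `qkv_bias` flag and the scaling by the square root of head_dim are assumptions, and the `llm.MultiHeadAttention` module imported below may differ in its details.

```python
import torch
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    """Sketch of causal multi-head attention with weight splits (illustrative)."""

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        self.d_out = d_out
        self.num_heads = num_heads
        # Step 1: each head works on a slice of d_out // num_heads dimensions
        self.head_dim = d_out // num_heads
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)  # qkv_bias is an assumption
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        # Step 2: linear layer that combines the head outputs
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        # Ones above the diagonal mark "future" tokens that must be hidden
        self.register_buffer(
            "mask", torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, _ = x.shape
        # Step 3: one projection each for keys, queries, values: (b, num_tokens, d_out)
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Step 4: split d_out into (num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)

        # Step 5: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Step 6: batched dot products: (b, num_heads, num_tokens, num_tokens)
        attn_scores = queries @ keys.transpose(2, 3)
        # Step 7: truncate the mask to the current number of tokens
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        # Step 8: hide future tokens before the softmax
        attn_scores.masked_fill_(mask_bool, float("-inf"))

        attn_weights = torch.softmax(attn_scores / self.head_dim ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Step 9: (b, num_heads, num_tokens, head_dim) -> (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)
        # Step 10: combine heads, d_out = num_heads * head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        # Step 11: optional output projection
        return self.out_proj(context_vec)


torch.manual_seed(123)
x = torch.rand(2, 3, 6)  # (b, num_tokens, d_in), random illustrative input
mha = MultiHeadAttention(d_in=6, d_out=4, context_length=3, dropout=0.0, num_heads=2)
out = mha(x)
print(out.shape)  # torch.Size([2, 3, 4])
```

Because the mask hides future positions, changing the last token leaves the outputs for the earlier tokens unchanged.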

Even though the reshaping (.view) and transposing (.transpose) of tensors inside the MultiHeadAttention class look very complicated, mathematically the MultiHeadAttention class implements the same concept as the MultiHeadAttentionWrapper earlier.

On a big-picture level, in the previous MultiHeadAttentionWrapper, we stacked multiple single-head attention layers that we combined into a multi-head layer.

The MultiHeadAttention class takes an integrated approach.

It starts with a multi-head layer and then internally splits this layer into individual attention heads.
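To see why the two approaches are equivalent, the sketch below checks two ways of computing the queries. The first uses one query projection of size d_in × d_out and splits its output into heads. The second uses separate per-head projections whose weights are the matching row slices of the large matrix. The sizes and random values are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
b, num_tokens, d_in = 2, 4, 6
num_heads, head_dim = 2, 3
d_out = num_heads * head_dim
x = torch.rand(b, num_tokens, d_in)

# Integrated approach: a single projection for all heads, then split the last dim
W_query = nn.Linear(d_in, d_out, bias=False)
q_all = W_query(x).view(b, num_tokens, num_heads, head_dim)

# Wrapper-style approach: one smaller projection per head,
# initialized from the corresponding rows of the big weight matrix
per_head = []
for h in range(num_heads):
    W_h = nn.Linear(d_in, head_dim, bias=False)
    with torch.no_grad():
        W_h.weight.copy_(W_query.weight[h * head_dim:(h + 1) * head_dim])
    per_head.append(W_h(x))

q_stacked = torch.stack(per_head, dim=2)  # (b, num_tokens, num_heads, head_dim)
print(torch.allclose(q_all, q_stacked, atol=1e-6))  # True
```

The integrated version computes the same numbers with one matrix multiplication instead of num_heads separate ones.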

## DETAILED EXPLANATION OF THE MULTI-HEAD ATTENTION CLASS

The splitting of the query, key, and value tensors is achieved through tensor reshaping and transposing operations using PyTorch's .view and .transpose methods.

The input is first transformed (via linear layers for queries, keys, and values) and then reshaped to represent multiple heads.

The key operation is to split the d_out dimension into num_heads and head_dim, where head_dim = d_out // num_heads (so d_out must be divisible by num_heads).

This splitting is then achieved using the .view method: a tensor of dimensions (b, num_tokens, d_out) is reshaped to dimensions (b, num_tokens, num_heads, head_dim).

The tensors are then transposed to bring the num_heads dimension before the num_tokens dimension, resulting in a shape of (b, num_heads, num_tokens, head_dim).
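Here is a minimal sketch of the split applied to a random tensor that stands in for a projected query tensor. The sizes b = 2, num_tokens = 3, d_out = 6 and num_heads = 2 are arbitrary choices for illustration.

```python
import torch

b, num_tokens, d_out, num_heads = 2, 3, 6, 2
assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
head_dim = d_out // num_heads  # 6 // 2 = 3

queries = torch.rand(b, num_tokens, d_out)  # stand-in for self.W_query(x)

q = queries.view(b, num_tokens, num_heads, head_dim)
print(q.shape)  # torch.Size([2, 3, 2, 3])

q = q.transpose(1, 2)  # swap num_tokens and num_heads
print(q.shape)  # torch.Size([2, 2, 3, 3])

# Head 1 holds the last head_dim features of every token
print(torch.equal(q[:, 1], queries[..., head_dim:]))  # True
```

Neither operation copies data. Both only change how the same memory is indexed, so each head simply sees a contiguous slice of the d_out features.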

This transposition is crucial for correctly aligning the queries, keys, and values across the different heads and for performing batched matrix multiplications efficiently.

To illustrate this batched matrix multiplication, suppose we have the following example tensor:
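As a minimal stand-in, take a tensor of shape (b, num_heads, num_tokens, head_dim) = (1, 2, 3, 4). Its random values are placeholders, not data from an actual model. Multiplying it by its own transpose over the last two dimensions computes the dot products for all heads at once, and the result matches computing them one head at a time:

```python
import torch

torch.manual_seed(123)
# Example tensor: (b, num_heads, num_tokens, head_dim) = (1, 2, 3, 4)
a = torch.rand(1, 2, 3, 4)

# Batched matrix multiplication: @ multiplies the last two dimensions
# and treats the leading (b, num_heads) dimensions as batch dimensions
scores = a @ a.transpose(2, 3)
print(scores.shape)  # torch.Size([1, 2, 3, 3])

# Equivalent per-head computation
first_head = a[0, 0, :, :]
second_head = a[0, 1, :, :]
print(torch.allclose(scores[0, 0], first_head @ first_head.T))  # True
print(torch.allclose(scores[0, 1], second_head @ second_head.T))  # True
```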

Continuing with MultiHeadAttention, after computing the attention weights and context vectors, the context vectors from all heads are transposed back to the shape (b, num_tokens, num_heads, head_dim).

These vectors are then reshaped (flattened) into the shape (b, num_tokens, d_out), effectively combining the outputs from all heads.
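The reverse operation can be sketched with a random stand-in for the per-head context vectors, using shapes chosen only for illustration. After `.transpose`, the tensor's memory layout no longer matches the new dimension order, so `.view` cannot merge the last two dimensions directly. The sketch therefore calls `.contiguous()` first, which copies the data into the expected order (`.reshape` would do this automatically).

```python
import torch

b, num_heads, num_tokens, head_dim = 2, 2, 3, 3
d_out = num_heads * head_dim

# Stand-in for per-head context vectors: (b, num_heads, num_tokens, head_dim)
context_vec = torch.rand(b, num_heads, num_tokens, head_dim)

# Move num_tokens back in front of num_heads
context_vec = context_vec.transpose(1, 2)  # (b, num_tokens, num_heads, head_dim)
print(context_vec.is_contiguous())  # False

# Flatten the heads back into a single d_out dimension
combined = context_vec.contiguous().view(b, num_tokens, d_out)
print(combined.shape)  # torch.Size([2, 3, 6])

# The features of head 0 come first, followed by those of head 1
print(torch.equal(combined[..., :head_dim], context_vec[:, :, 0, :]))  # True
```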

Additionally, we added a so-called output projection layer (self.out_proj) to MultiHeadAttention after combining the heads, which is not present in the CausalAttention class.

This output projection layer is not strictly necessary, but it is commonly used in many LLM architectures, which is why we added it here for completeness.

Even though the MultiHeadAttention class looks more complicated than the MultiHeadAttentionWrapper due to the additional reshaping and transposition of tensors, it is more efficient.

The reason is that we only need one matrix multiplication to compute the keys, for instance, keys = self.W_key(x) (the same is true for the queries and values).

In the MultiHeadAttentionWrapper, we needed to repeat this matrix multiplication, which is computationally one of the most expensive steps, for each attention head.


```python
import torch
from llm import MultiHeadAttention
```


```python
torch.manual_seed(123)

inputs = torch.tensor(
    [[0.43, 0.15, 0.89, 0.55, 0.87, 0.66],
     [0.57, 0.85, 0.64, 0.22, 0.58, 0.33],
     [0.77, 0.25, 0.10, 0.05, 0.80, 0.55]]
)

batch = torch.stack([inputs, inputs], dim=0)  # (2, 3, 6)
print(batch.shape)

batch_size, context_length, d_in = batch.shape
d_out = 6
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape", context_vecs.shape)  # (2, 3, 6)
```

```
torch.Size([2, 3, 6])
tensor([[[ 0.1195, -0.0484,  0.0306, -0.0639, -0.2782, -0.2564],
         [ 0.1208, -0.0497,  0.0319, -0.0638, -0.2779, -0.2566],
         [ 0.1196, -0.0491,  0.0318, -0.0635, -0.2788, -0.2578]],

        [[ 0.1195, -0.0484,  0.0306, -0.0639, -0.2782, -0.2564],
         [ 0.1208, -0.0497,  0.0319, -0.0638, -0.2779, -0.2566],
         [ 0.1196, -0.0491,  0.0318, -0.0635, -0.2788, -0.2578]]],
       grad_fn=<ViewBackward0>)
context_vecs.shape torch.Size([2, 3, 6])
```


As we can see from the results, the output dimension is directly controlled by the d_out argument: the last dimension of context_vecs is 6 because we set d_out = 6.

In this section, we implemented the MultiHeadAttention class that we will use in the upcoming sections when implementing and training the LLM itself.

Note that while the code is fully functional, we used relatively small embedding sizes and numbers of attention heads to keep the outputs readable.

For comparison, the smallest GPT-2 model (117 million parameters) has 12 attention heads and a context vector embedding size of 768.

The largest GPT-2 model (1.5 billion parameters) has 25 attention heads and a context vector embedding size of 1,600.

Note that the embedding sizes of the token inputs and the context embeddings are the same in GPT models (d_in = d_out).
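The per-head dimension for these configurations follows from head_dim = d_out // num_heads. A quick check using the GPT-2 small and GPT-2 XL sizes quoted above shows that both models use 64 dimensions per head:

```python
# (embedding size d_out, number of attention heads) for two GPT-2 configurations
configs = {
    "GPT-2 small (117M)": (768, 12),
    "GPT-2 XL (1.5B)": (1600, 25),
}

head_dims = {}
for name, (d_out, num_heads) in configs.items():
    assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
    head_dims[name] = d_out // num_heads
    print(name, "head_dim =", head_dims[name])
# GPT-2 small (117M) head_dim = 64
# GPT-2 XL (1.5B) head_dim = 64
```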