In [13]:
import torch
import torch.nn as nn

We define the class by inheriting from **nn.Module**, then assert **d_out % num_heads == 0** so that the output dimension splits evenly across the attention heads.

In [22]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention.

    The input is projected to queries, keys and values of width ``d_out``,
    split into ``num_heads`` heads of width ``d_out // num_heads``, combined
    with scaled dot-product attention under an upper-triangular (causal) mask,
    and the concatenated heads are passed through a final linear layer.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the last dimension of the output; must be divisible by
            ``num_heads``.
        context_length: Side length of the square causal mask that is
            registered as a buffer.
        dropout: Probability passed to ``nn.Dropout`` for the attention weights.
        num_heads: Number of attention heads.
        qkv_bias: Whether the query/key/value projections carry a bias term.

    Raises:
        AssertionError: If ``d_out`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # Per-head width; the heads concatenate back to d_out

        # Linear projections for Q, K, V from d_in to d_out (all heads side by side)
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

        # Linear layer that combines the head outputs, and dropout for regularization
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)

        # Upper-triangular matrix of ones (above the diagonal) registered as a buffer;
        # it is turned into a boolean causal mask in `forward` to block future tokens.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )

    @staticmethod
    def _trace(enabled, label, tensor, gap=False):
        """Print the shape and contents of `tensor` under `label` when `enabled` is true."""
        if not enabled:
            return
        prefix = "\n" if gap else ""
        print(f"{prefix}{label} shape:", tensor.shape)
        print(f"{label}:\n", tensor)

    def forward(self, x, verbose=True):
        """Apply causal multi-head attention to `x`.

        Args:
            x: Tensor with three dimensions ``(b, num_tokens, d_in)``.
            verbose: When true (the default, which keeps the notebook's
                step-by-step walk-through), print the shape and values of every
                intermediate tensor. Pass ``verbose=False`` to run silently.

        Returns:
            Tensor of shape ``(b, num_tokens, d_out)``.
        """
        b, num_tokens, _ = x.shape
        # As in `CausalAttention`, for inputs where `num_tokens` exceeds `context_length`,
        # this will result in errors in the mask creation further below.
        # In practice, this is not a problem since the LLM (chapters 4-7) ensures that inputs
        # do not exceed `context_length` before reaching this forward method.
        show = self._trace
        show(verbose, "Input x", x)

        # Project the input into key/query/value space; d_out holds all heads concatenated.
        keys = self.W_key(x)  # Shape: (b, num_tokens, d_out)
        show(verbose, "Keys", keys)

        queries = self.W_query(x)
        show(verbose, "Queries", queries, gap=True)

        values = self.W_value(x)
        show(verbose, "Values", values, gap=True)

        # Split the last dimension: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        show(verbose, "Keys after view", keys)

        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        show(verbose, "Values after view", values, gap=True)

        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        show(verbose, "Queries after view", queries, gap=True)

        # Swap num_tokens and num_heads: -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        show(verbose, "Keys after transpose", keys)

        queries = queries.transpose(1, 2)
        show(verbose, "Queries after transpose", queries, gap=True)

        values = values.transpose(1, 2)
        show(verbose, "Values after transpose", values, gap=True)

        # Batched dot product of queries and keys, computed separately for each head
        attn_scores = queries @ keys.transpose(2, 3)
        show(verbose, "Attention scores", attn_scores)

        # Truncate the mask to num_tokens and convert it to bool
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        show(verbose, "Mask", mask_bool)

        # Masked (future) positions get -inf so that softmax assigns them zero weight
        attn_scores.masked_fill_(mask_bool, -torch.inf)
        show(verbose, "Masked attention scores", attn_scores, gap=True)

        # Scale by sqrt(head_dim) before the softmax over the key positions
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        show(verbose, "Attention weights after softmax", attn_weights)

        attn_weights = self.dropout(attn_weights)
        show(verbose, "After dropout", attn_weights, gap=True)

        # Weighted sum of values, then back to (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)
        show(verbose, "Context vec after matrix multiplication and transpose", context_vec)

        # Concatenate the heads; .contiguous() is required before .view() after a transpose
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        show(verbose, "Context vec after view", context_vec)

        context_vec = self.out_proj(context_vec)  # optional projection
        show(verbose, "Final output", context_vec, gap=True)
        return context_vec


# test
# Seed first so the sample input and the layer's initial weights are reproducible.
torch.manual_seed(42)

x = torch.randn(2, 3, 4)  # (batch, num_tokens, d_in)
mha = MultiHeadAttention(
    d_in=4,
    d_out=6,
    context_length=3,
    dropout=0.0,
    num_heads=3,
    qkv_bias=False,
)
# Call the module itself rather than .forward() so nn.Module hooks are honoured.
output = mha(x)


Input x shape: torch.Size([2, 3, 4])
Input x:
 tensor([[[ 1.9269,  1.4873,  0.9007, -2.1055],
         [ 0.6784, -1.2345, -0.0431, -1.6047],
         [ 0.3559, -0.6866, -0.4934,  0.2415]],

        [[-1.1109,  0.0915, -2.3169, -0.2168],
         [-0.3097, -0.3957,  0.8034, -0.6216],
         [-0.5920, -0.0631, -0.8286,  0.3309]]])
Keys shape: torch.Size([2, 3, 6])
Keys:
 tensor([[[ 1.1378, -0.9639,  0.6903, -1.0604, -0.3400,  1.4839],
         [ 0.6902, -0.1209,  1.4849, -1.1526,  0.5585,  0.9634],
         [-0.1501, -0.0830,  0.6047, -0.2467, -0.0375, -0.0455]],

        [[-1.0763, -0.7741,  0.6616, -0.1127, -0.3152, -0.9264],
         [ 0.5111,  0.4681, -0.1410, -0.0904,  0.5536,  0.4242],
         [-0.5258, -0.1221,  0.0227,  0.1755, -0.1123, -0.5377]]],
       grad_fn=<UnsafeViewBackward0>)

Queries shape: torch.Size([2, 3, 6])
Queries:
 tensor([[[-0.6932,  0.3034, -0.0030, -0.4352,  1.3922,  0.0545],
         [-0.4669,  0.3099,  0.8798, -1.1052,  0.6712,  0.7123],
         [ 0.144