# **A Deep Dive Into Multi-Head Attention And PyTorch Buffers**

## Comparisons of Efficient Multi-Head Attention Implementations

In [1]:
import torch
import torch.nn as nn

# Reproducibility: seed PyTorch's RNG before any random tensor is drawn.
torch.manual_seed(42)

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"PyTorch version: {torch.__version__}")

# Benchmark input shared by every variant below:
# random embeddings of shape (batch_size, context_length, embed_dim).
batch_size, context_length, embed_dim = 8, 1024, 768
embeddings = torch.randn(batch_size, context_length, embed_dim, device=device)

PyTorch version: 2.7.0+cu126


### **1. CausalAttention MHA Wrapper**

In [5]:
class CausalAttention(nn.Module):
    """Single-head causal self-attention.

    Projects the input to queries, keys and values of width ``d_out``, hides
    future positions with an upper-triangular mask, and returns the
    attention-weighted sum of the values.

    Args:
        d_in: input feature size (last dim of ``x``).
        d_out: width of the query/key/value projections and of the output.
        context_length: maximum sequence length covered by the mask buffer.
        dropout: dropout probability applied to the attention weights.
        qkv_bias: whether the Q/K/V projections use a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # 1.0 strictly above the diagonal marks the future positions to hide.
        self.register_buffer("mask",
                             torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        """Attend over ``x`` of shape (b, num_tokens, d_in); returns (b, num_tokens, d_out)."""
        b, num_tokens, _ = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1, 2)  # (b, num_tokens, num_tokens)
        # `masked_fill` is out-of-place: its result must be assigned. Otherwise the
        # causal mask is silently discarded and every token can attend to the future.
        attn_scores = attn_scores.masked_fill(
            self.mask.bool()[:num_tokens, :num_tokens],
            -torch.inf)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ values
        return context_vec
    
class MHA_Wrapper(nn.Module):
    """Naive multi-head attention built from independent single-head modules.

    Runs ``num_heads`` separate CausalAttention heads on the same input,
    concatenates their outputs along the feature axis, and mixes them with a
    final linear layer. The output width is ``num_heads * d_out``.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        attention_heads = [
            CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
            for _ in range(num_heads)
        ]
        self.heads = nn.ModuleList(attention_heads)
        concat_dim = d_out * num_heads
        self.out_proj = nn.Linear(concat_dim, concat_dim)

    def forward(self, x):
        # Heads are evaluated one after another, then laid side by side.
        per_head_outputs = [attention_head(x) for attention_head in self.heads]
        stacked = torch.cat(per_head_outputs, dim=-1)
        return self.out_proj(stacked)

In [6]:
# 12 independent heads, each projecting to embed_dim // 12 dims, so the
# concatenated output is embed_dim wide; no dropout so timings cover only the math.
mha_wrapper = MHA_Wrapper(
    d_in=embed_dim,
    d_out=embed_dim // 12,
    num_heads=12,
    context_length=context_length,
    dropout=0.0,
    qkv_bias=False,
).to(device)


In [9]:
%time out = mha_wrapper(embeddings)
out.shape

CPU times: user 2.2 ms, sys: 208 μs, total: 2.4 ms
Wall time: 2.25 ms


torch.Size([8, 1024, 768])

### **2. Multi-Head Attention With Split Weights**

In [10]:
class MHA(nn.Module):
    """Causal multi-head attention with separate Q/K/V projections.

    Instead of one module per head, each projection produces all heads at once
    and is then split into heads by reshaping.

    Args:
        d_in: input feature size.
        d_out: output feature size; must be divisible by ``num_heads``.
        context_length: maximum sequence length covered by the mask buffer.
        dropout: dropout probability applied to the attention weights.
        num_heads: number of attention heads.
        qkv_bias: whether the Q/K/V projections use a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # per-head width; all heads together span d_out

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # mixes the concatenated head outputs
        self.dropout = nn.Dropout(dropout)
        # 1.0 strictly above the diagonal marks the future positions to hide.
        future_positions = torch.triu(torch.ones(context_length, context_length), diagonal=1)
        self.register_buffer("mask", future_positions)

    def _split_heads(self, t, batch, seq_len):
        """Reshape (batch, seq_len, d_out) -> (batch, num_heads, seq_len, head_dim)."""
        return t.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        batch, seq_len, _ = x.shape

        k = self._split_heads(self.W_key(x), batch, seq_len)
        q = self._split_heads(self.W_query(x), batch, seq_len)
        v = self._split_heads(self.W_value(x), batch, seq_len)

        # Per-head dot products: (batch, num_heads, seq_len, seq_len).
        scores = q @ k.transpose(2, 3)
        causal = self.mask.bool()[:seq_len, :seq_len]
        scores = scores.masked_fill(causal, -torch.inf)

        weights = torch.softmax(scores / self.head_dim ** 0.5, dim=-1)
        weights = self.dropout(weights)

        # Back to (batch, seq_len, num_heads, head_dim), then merge heads into d_out.
        merged = (weights @ v).transpose(1, 2).contiguous().view(batch, seq_len, self.d_out)
        return self.out_proj(merged)
        

In [11]:
# 12 heads over the full embedding width (head_dim = embed_dim // 12),
# no dropout and no Q/K/V bias.
mha = MHA(
    d_in=embed_dim,
    d_out=embed_dim,
    num_heads=12,
    context_length=context_length,
    dropout=0.0,
    qkv_bias=False,
).to(device)

In [12]:
%timeit out = mha(embeddings)
print(out.shape)

14.9 ms ± 45 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
torch.Size([8, 1024, 768])


### **3. Alternative MHA With Combined Weights**

The code for the `MultiHeadAttentionCombinedQKV` class below is based on code that was shared by [Rayed Bin Wahed](https://github.com/rasbt/LLMs-from-scratch/discussions/51).

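The difference between the `MultiHeadAttentionCombinedQKV` class and the `MHA` class in the previous section is that it uses a single combined weight matrix (`self.qkv`, of shape `d_in × 3·d_out`) to compute the queries, keys, and values in one matrix multiplication, instead of three separate weight matrices.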

In [13]:
class MultiHeadAttentionCombinedQKV(nn.Module):
    """Causal multi-head attention using one fused Q/K/V projection.

    A single ``nn.Linear(d_in, 3 * d_out)`` produces queries, keys and values in
    one matrix multiplication; the result is reshaped into three per-head tensors.

    Args:
        d_in: input feature size.
        d_out: output feature size; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        context_length: maximum sequence length covered by the mask buffer.
        dropout: dropout probability applied to the attention weights.
        qkv_bias: whether the fused projection uses a bias term.
    """

    def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False):
        super().__init__()

        assert d_out % num_heads == 0, "embed_dim is indivisible by num_heads"

        self.num_heads = num_heads
        self.context_length = context_length
        self.head_dim = d_out // num_heads
        self.d_out = d_out

        self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)  # Key change: one fused projection
        self.proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)

        # 1.0 strictly above the diagonal marks the future positions to hide.
        self.register_buffer(
            "mask", torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        batch_size, num_tokens, _ = x.shape

        # (b, num_tokens, d_in) --> (b, num_tokens, 3 * d_out)
        qkv = self.qkv(x)

        # (b, num_tokens, 3 * d_out) --> (b, num_tokens, 3, num_heads, head_dim)
        qkv = qkv.view(batch_size, num_tokens, 3, self.num_heads, self.head_dim)

        # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)

        # (3, b, num_heads, num_tokens, head_dim) --> 3 x (b, num_heads, num_tokens, head_dim)
        queries, keys, values = qkv.unbind(0)

        # (b, num_heads, num_tokens, head_dim) --> (b, num_heads, num_tokens, num_tokens)
        attn_scores = queries @ keys.transpose(-2, -1)
        attn_scores = attn_scores.masked_fill(
            self.mask.bool()[:num_tokens, :num_tokens],
            -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # (b, num_heads, num_tokens, num_tokens) --> (b, num_heads, num_tokens, head_dim)
        context_vec = attn_weights @ values

        # (b, num_heads, num_tokens, head_dim) --> (b, num_tokens, num_heads, head_dim)
        context_vec = context_vec.transpose(1, 2)

        # (b, num_tokens, num_heads, head_dim) --> (b, num_tokens, d_out).
        # Must use d_out (= num_heads * head_dim), not x's input width, which differs when d_in != d_out.
        context_vec = context_vec.contiguous().view(batch_size, num_tokens, self.d_out)

        context_vec = self.proj(context_vec)
        return context_vec

In [14]:
# Fused-QKV variant, configured like the others: 12 heads, no dropout, no Q/K/V bias.
mha_combined_qkv = MultiHeadAttentionCombinedQKV(
    d_in=embed_dim,
    d_out=embed_dim,
    num_heads=12,
    context_length=context_length,
    dropout=0.0,
    qkv_bias=False,
).to(device)

In [17]:
%timeit out = mha_combined_qkv(embeddings)

16.7 ms ± 37.7 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [18]:
# Recompute explicitly: assignments made inside `%timeit` do not persist, so `out`
# would otherwise still hold a tensor produced by an earlier cell.
out = mha_combined_qkv(embeddings)
out.shape

torch.Size([8, 1024, 768])

### **4. Implementing MHA with Einsum**

`einsum` expresses each projection and attention contraction directly in index notation, which removes some of the explicit `view`/`transpose`/`@` bookkeeping and makes the intended tensor shapes easy to read. It should not be expected to be faster, though. PyTorch lowers `einsum` contractions to the same batched matrix-multiplication kernels used by `@`/`torch.matmul`, optionally choosing the contraction order via `opt_einsum`. In the timing below it is in fact slightly slower than the `MHA` class.

In [21]:
import math

class MHAEinsum(nn.Module):
    """Causal multi-head attention with every contraction written as ``torch.einsum``.

    Q/K/V weights are stored like ``nn.Linear`` weights, with shape (d_out, d_in),
    and are initialized the same way ``nn.Linear`` does it.

    Args:
        d_in: input feature size.
        d_out: output feature size; must be divisible by ``num_heads``.
        context_length: maximum sequence length covered by the mask buffer.
        dropout: dropout probability applied to the attention weights.
        num_heads: number of attention heads.
        qkv_bias: whether the Q/K/V projections use a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        # Parameters for Q, K, V in nn.Linear layout (d_out, d_in); values set in reset_parameters().
        self.W_query = nn.Parameter(torch.empty(d_out, d_in))
        self.W_key = nn.Parameter(torch.empty(d_out, d_in))
        self.W_value = nn.Parameter(torch.empty(d_out, d_in))

        if qkv_bias:
            self.bias_q = nn.Parameter(torch.zeros(d_out))
            self.bias_k = nn.Parameter(torch.zeros(d_out))
            self.bias_v = nn.Parameter(torch.zeros(d_out))
        else:
            self.register_parameter("bias_q", None)
            self.register_parameter("bias_k", None)
            self.register_parameter("bias_v", None)

        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        # 1.0 strictly above the diagonal marks the future positions to hide.
        self.register_buffer("mask",
                             torch.triu(torch.ones(context_length, context_length), diagonal=1))

        # Initialize parameters
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize Q/K/V weights and biases the same way ``nn.Linear.reset_parameters`` does."""
        nn.init.kaiming_uniform_(self.W_query, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.W_key, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.W_value, a=math.sqrt(5))

        if self.bias_q is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_query)  # fan_in == d_in
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias_q, -bound, bound)
            nn.init.uniform_(self.bias_k, -bound, bound)
            nn.init.uniform_(self.bias_v, -bound, bound)

    def forward(self, x):
        b, n, _ = x.shape

        # Linear projections: "bni,oi->bno" contracts x's feature axis (d_in) with the
        # *second* axis of the (d_out, d_in) weight, i.e. x @ W.T, matching nn.Linear.
        # (Contracting over the first axis only works when d_in == d_out and computes x @ W.)
        Q = torch.einsum("bni,oi->bno", x, self.W_query)
        K = torch.einsum("bni,oi->bno", x, self.W_key)
        V = torch.einsum("bni,oi->bno", x, self.W_value)

        # Add biases if they are used
        if self.bias_q is not None:
            Q = Q + self.bias_q
            K = K + self.bias_k
            V = V + self.bias_v

        # Reshape for multi-head attention: (b, n, d_out) -> (b, num_heads, n, head_dim)
        Q = Q.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention scores: (b, num_heads, n, n)
        scores = torch.einsum("bhnd,bhmd->bhnm", Q, K) / (self.head_dim ** 0.5)

        # The (n, n) causal mask broadcasts over the batch and head dimensions.
        scores = scores.masked_fill(self.mask.bool()[:n, :n], -torch.inf)

        # Softmax and dropout
        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Aggregate the attended context vectors: (b, num_heads, n, head_dim)
        context_vec = torch.einsum("bhnm,bhmd->bhnd", attn_weights, V)

        # Combine heads and project the output
        context_vec = context_vec.transpose(1, 2).reshape(b, n, self.d_out)
        context_vec = self.out_proj(context_vec)

        return context_vec

In [22]:
# einsum variant, configured like the others: 12 heads, no dropout, no Q/K/V bias.
mha_einsum = MHAEinsum(
    d_in=embed_dim,
    d_out=embed_dim,
    num_heads=12,
    context_length=context_length,
    dropout=0.0,
    qkv_bias=False,
).to(device)

In [23]:
%timeit out = mha_einsum(embeddings)

17.9 ms ± 35.8 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [24]:
# Recompute explicitly: assignments made inside `%timeit` do not persist, so `out`
# would otherwise still hold a tensor produced by an earlier cell.
out = mha_einsum(embeddings)
out.shape

torch.Size([8, 1024, 768])

### **5. MHA With PyTorch's Scaled Dot Product Attention and FlashAttention**

In [29]:
# Uses PyTorch's fused scaled_dot_product_attention (SDPA). With `is_causal=True` and no
# explicit mask, SDPA is free to dispatch to a FlashAttention kernel. Which backend actually
# runs depends on device, dtype and shapes; the flash kernels require fp16/bf16 CUDA inputs,
# so for fp32 inputs a memory-efficient or plain math backend is typically chosen instead.
class MHAPyTorchScaledDotProduct(nn.Module):
    """Causal multi-head attention: fused Q/K/V projection + ``F.scaled_dot_product_attention``.

    Args:
        d_in: input feature size.
        d_out: output feature size; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        context_length: stored for API parity with the other variants (no mask buffer needed).
        dropout: attention dropout probability, applied only in training mode.
        qkv_bias: whether the fused projection uses a bias term.
    """

    def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "embed_dim is indivisible by num_heads"

        self.num_heads = num_heads
        self.context_length = context_length
        self.head_dim = d_out // num_heads
        self.d_out = d_out

        self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)
        self.proj = nn.Linear(d_out, d_out)
        self.dropout = dropout  # a float, handed to SDPA as dropout_p

    def forward(self, x):
        batch, seq_len, _ = x.shape

        # (b, n, 3 * d_out) -> (b, n, 3, heads, head_dim) -> (3, b, heads, n, head_dim) -> q, k, v
        q, k, v = (
            self.qkv(x)
            .view(batch, seq_len, 3, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
            .unbind(0)
        )

        # Dropout must be disabled outside of training.
        dropout_p = self.dropout if self.training else 0.0

        attended = nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=dropout_p, is_causal=True)

        # (b, heads, n, head_dim) -> (b, n, heads * head_dim) == (b, n, d_out)
        merged = attended.transpose(1, 2).contiguous().view(batch, seq_len, self.d_out)
        return self.proj(merged)

In [30]:
# SDPA variant (is_causal=True), configured like the others: 12 heads, no dropout, no Q/K/V bias.
mha_pytorch_scaled = MHAPyTorchScaledDotProduct(
    d_in=embed_dim,
    d_out=embed_dim,
    num_heads=12,
    context_length=context_length,
    dropout=0.0,
    qkv_bias=False,
).to(device)

In [31]:
%timeit out = mha_pytorch_scaled(embeddings)

6.05 ms ± 18.6 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [32]:
# Recompute explicitly: assignments made inside `%timeit` do not persist, so `out`
# would otherwise still hold a tensor produced by an earlier cell.
out = mha_pytorch_scaled(embeddings)
out.shape

torch.Size([8, 1024, 768])

### **6. PyTorch's Scaled Dot-Product Attention w/o FlashAttention** 

In [33]:
# Compared to the above, we pass an explicit boolean causal mask (and is_causal=False).
# SDPA's FlashAttention backend does not accept an arbitrary attn_mask, so PyTorch must
# dispatch to a different backend (memory-efficient or math) for this variant.

class MHASDPAWithoutFlash(nn.Module):
    """Causal multi-head attention via SDPA with an explicit boolean mask.

    Args:
        d_in: input feature size.
        d_out: output feature size; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        context_length: maximum sequence length covered by the mask buffer.
        dropout: attention dropout probability, applied only in training mode.
        qkv_bias: whether the fused projection uses a bias term.

    Raises:
        ValueError: from ``forward`` if the sequence is longer than ``context_length``.
    """

    def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False):
        super().__init__()

        assert d_out % num_heads == 0, "embed_dim is indivisible by num_heads"

        self.num_heads = num_heads
        self.context_length = context_length
        self.head_dim = d_out // num_heads
        self.d_out = d_out

        self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)
        self.proj = nn.Linear(d_out, d_out)
        self.dropout = dropout
        # True strictly above the diagonal marks *future* positions (the ones to hide).
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1).bool())

    def forward(self, x):
        batch_size, num_tokens, _ = x.shape

        # The mask buffer only covers context_length positions; slicing a smaller mask
        # would not match the attention-score shape, so fail early with a clear message.
        if num_tokens > self.context_length:
            raise ValueError(
                f"num_tokens ({num_tokens}) exceeds context_length ({self.context_length})")

        # (b, num_tokens, d_in) --> (b, num_tokens, 3 * d_out)
        qkv = self.qkv(x)

        # (b, num_tokens, 3 * d_out) --> (b, num_tokens, 3, num_heads, head_dim)
        qkv = qkv.view(batch_size, num_tokens, 3, self.num_heads, self.head_dim)

        # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)

        # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)
        queries, keys, values = qkv

        use_dropout = 0. if not self.training else self.dropout

        # SDPA's boolean attn_mask uses the opposite convention from masked_fill:
        # True means "may attend". Invert the future-mask so each token sees itself and
        # the past only. The (n, n) mask broadcasts over batch and heads.
        attn_mask = ~self.mask[:num_tokens, :num_tokens]

        context_vec = nn.functional.scaled_dot_product_attention(
            queries, keys, values, attn_mask=attn_mask, dropout_p=use_dropout, is_causal=False)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)
        context_vec = self.proj(context_vec)

        return context_vec

In [34]:
# SDPA with an explicit mask, configured like the others: 12 heads, no dropout, no Q/K/V bias.
mha_sdpa_no_flash = MHASDPAWithoutFlash(
    d_in=embed_dim,
    d_out=embed_dim,
    num_heads=12,
    context_length=context_length,
    dropout=0.0,
    qkv_bias=False,
).to(device)

In [36]:
%timeit out = mha_sdpa_no_flash(embeddings)
print(out.shape)

8.02 ms ± 58.3 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
torch.Size([8, 1024, 768])


In [37]:
# Return unused blocks held by PyTorch's CUDA caching allocator to the driver.
# Tensors that are still referenced (e.g. the models above) are not freed.
if torch.cuda.is_available():
    torch.cuda.empty_cache()