## Implementing MultiHead Attention with Weight Splits

Instead of maintaining two separate classes, **MultiHeadAttention** and **CausalAttention**, we can combine them into a single class. This class integrates the multi-head functionality within a single module. It splits the input into multiple heads by reshaping the projected query, key, and value tensors, and then combines the results from these heads after computing attention.

In [9]:
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with heads split out of shared projections.

    One ``d_in -> d_out`` linear projection is computed for each of query, key
    and value. Its ``d_out`` features are reshaped into ``num_heads`` heads of
    size ``head_dim = d_out // num_heads``. Each position attends only to itself
    and earlier positions (causal mask). The per-head context vectors are
    concatenated back to ``d_out`` features and passed through ``out_proj``.

    Args:
        d_in: Size of the last dimension of the input tensor.
        d_out: Total output size; must be divisible by ``num_heads``.
        context_length: Maximum sequence length supported by the causal mask.
        dropout: Dropout probability applied to the attention weights.
        num_heads: Number of attention heads (must be positive).
        qkv_bias: Whether the query/key/value projections use a bias term.

    Raises:
        ValueError: If ``num_heads`` is not positive or ``d_out`` is not
            divisible by ``num_heads``.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        # Raise instead of ``assert`` so validation survives ``python -O``.
        if num_heads <= 0:
            raise ValueError("num_heads must be a positive integer")
        if d_out % num_heads != 0:
            raise ValueError("d_out must be divisible by num_heads")
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        # Creation order is kept (Query, Key, Value, out_proj) so that random
        # initialisation under a fixed seed is unchanged.
        self.W_Query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_Key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_Value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        # Ones strictly above the diagonal mark "future" positions to hide.
        # Registered as a buffer so it moves with the module's device and is
        # stored in the state_dict (kept as float for checkpoint compatibility).
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def _split_heads(self, t, b, num_tokens):
        """Reshape (b, num_tokens, d_out) -> (b, num_heads, num_tokens, head_dim)."""
        return t.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Apply causal multi-head self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch, num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(batch, num_tokens, d_out)``.

        Raises:
            ValueError: If ``num_tokens`` exceeds the ``context_length`` the
                module was built with.
        """
        b, num_tokens, _ = x.shape
        max_len = self.mask.shape[0]
        if num_tokens > max_len:
            raise ValueError(
                f"sequence length {num_tokens} exceeds context_length {max_len}"
            )

        # Project, then split d_out into heads and move the head axis forward:
        # (b, num_tokens, d_out) -> (b, num_heads, num_tokens, head_dim)
        keys = self._split_heads(self.W_Key(x), b, num_tokens)
        queries = self._split_heads(self.W_Query(x), b, num_tokens)
        values = self._split_heads(self.W_Value(x), b, num_tokens)

        # Only the last two dims are transposed; the (b, num_heads) dims are
        # treated as batch dims by the matmul.
        attention_scores = queries @ keys.transpose(-2, -1)

        # Hide future positions (causal mask), cropped to the current length.
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attention_scores.masked_fill_(mask_bool, float("-inf"))

        # Scaled dot-product: divide by sqrt(head_dim) before the softmax.
        attention_weights = torch.softmax(attention_scores / self.head_dim ** 0.5, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # (b, num_heads, num_tokens, head_dim) -> (b, num_tokens, num_heads, head_dim)
        context_vectors = (attention_weights @ values).transpose(1, 2)
        # Merge heads back into d_out; contiguous() is required before view()
        # because transpose() produced a non-contiguous tensor.
        context_vectors = context_vectors.contiguous().view(b, num_tokens, self.d_out)
        return self.out_proj(context_vectors)

In [10]:
# Hyper-parameters for a quick shape check of the attention module.
context_length = 200
d_in = 512
d_out = 512
num_heads = 8
dropout = 0.1
qkv_bias = False

# Random input: a batch of 8 sequences, each context_length tokens of d_in features.
# Created before the module so the random-number stream matches the original order.
data = torch.randn(8, context_length, d_in)

mha = MultiHeadAttention(
    d_in=d_in,
    d_out=d_out,
    context_length=context_length,
    dropout=dropout,
    num_heads=num_heads,
    qkv_bias=qkv_bias,
)

output = mha(data)
print(output.shape)

torch.Size([8, 200, 512])
