# Implementing Multi-head Attention

### Ref: [The AiEdge Newsletter](https://drive.google.com/file/d/1Je2SAFBlsWcgwzK_gl1_f-LtPK3SOzg3/view)

<p float="left">
  <img src="../../assets/multihead_attention.png" width="500" height="250">
  <img src="../../assets/attention_head.png" width="500" height="250"> 
</p>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class MultiHeadAttention(nn.Module):
    """
    Implementation of multi-head self-attention as used in the transformer architecture.

    The input (embeddings) is linearly projected to queries, keys, and values, split
    into multiple heads, scaled dot-product attention is computed for each head, and
    the per-head results are concatenated and passed through a final linear projection.

    Args:
        d_model (int): input (embedding) dimension (also known as hidden_size)
        n_head (int): number of attention heads (must be a positive integer)

    Raises:
        ValueError: if n_head is not positive, or if d_model is not divisible by n_head
    """
    def __init__(self, d_model: int, n_head: int) -> None:
        super().__init__()

        # Validate n_head before the modulo: n_head == 0 would raise ZeroDivisionError
        # there, and a negative n_head could pass the divisibility check and only fail
        # later inside forward() with an opaque reshape error.
        if n_head <= 0:
            raise ValueError('n_head should be a positive integer ...')
        if d_model % n_head != 0:
            raise ValueError('d_model should be divisible by n_head ...')
        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_model // n_head     # head dimension

        self.Q = nn.Linear(d_model, d_model)
        self.K = nn.Linear(d_model, d_model)
        self.V = nn.Linear(d_model, d_model)
        # output projection: lets the model remix the information coming from the heads
        self.O = nn.Linear(d_model, d_model)

    def _split_heads(self, t: torch.Tensor, batch_size: int, seq_len: int) -> torch.Tensor:
        """Reshape [batch_size, seq_len, d_model] into [batch_size, n_head, seq_len, d_head]."""
        return t.reshape((batch_size, seq_len, self.n_head, self.d_head)).transpose(1, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute multi-head self-attention.

        Args:
            x (torch.Tensor): input tensor of size [batch_size, seq_len, d_model]

        Returns:
            torch.Tensor: output tensor of shape [batch_size, seq_len, d_model]
        """
        batch_size, seq_len, d_model = x.size()

        # linear projection of input, then split into heads;
        # each tensor ends up with shape [batch_size, n_head, seq_len, d_head]
        queries = self._split_heads(self.Q(x), batch_size, seq_len)
        keys = self._split_heads(self.K(x), batch_size, seq_len)
        values = self._split_heads(self.V(x), batch_size, seq_len)

        # scaled dot-product attention; dividing by sqrt(d_head) scales the logits
        # before the softmax over the key positions (last dimension)
        scores = queries @ keys.transpose(-1, -2) / self.d_head ** 0.5
        self_attn = F.softmax(scores, dim=-1)   # [batch_size, n_head, seq_len, seq_len]
        out = self_attn @ values    # [batch_size, n_head, seq_len, d_head]

        # recombine heads
        out = out.transpose(1, 2)   # [batch_size, seq_len, n_head, d_head]
        out = out.reshape((batch_size, seq_len, d_model))

        # final linear projection
        out = self.O(out)   # [batch_size, seq_len, d_model]

        return out


#### Toy Example

In [3]:
# toy configuration
vocab_size = 15
seq_len = 5
batch_size = 2
n_head = 3
d_model = 12    # a.k.a. model dim / hidden size

# random token ids; they must be integers because they are fed to an Embedding layer
x = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len))

# token ids -> embeddings
embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)
embedded = embedding(x)     # [batch_size, seq_len, d_model]

# embeddings -> hidden states via multi-head self-attention
attn = MultiHeadAttention(d_model=d_model, n_head=n_head)
hidden_state = attn(embedded)   # [batch_size, seq_len, d_model]

print(f'input x: {x.size()}, and embedded: {embedded.size()}, and hidden state: {hidden_state.size()}')

input x: torch.Size([2, 5]), and embedded: torch.Size([2, 5, 12]), and hidden state: torch.Size([2, 5, 12])


## More Efficient Implementation of MultiHeadAttention
* It is typically faster and more memory-efficient on GPU, because Q, K, and V are computed with one big fused linear projection (a single large matrix multiplication) instead of three separate ones.

In [None]:
class MultiHeadAttentionV2(nn.Module):
    """
    Multi-head self-attention with a single fused linear projection for Q, K, and V.

    Args:
        d_model (int): input (embedding) dimension (also known as hidden_size)
        n_head (int): number of attention heads (must be a positive integer)

    Raises:
        ValueError: if n_head is not positive, or if d_model is not divisible by n_head
    """
    def __init__(self, d_model: int, n_head: int) -> None:
        super().__init__()

        # Validate n_head before the modulo: n_head == 0 would raise ZeroDivisionError
        # there, and a negative n_head could pass the divisibility check and only fail
        # later inside forward() with an opaque reshape error.
        if n_head <= 0:
            raise ValueError('n_head should be a positive integer ...')
        if d_model % n_head != 0:
            raise ValueError('d_model should be divisible by n_head ...')

        self.d_model = d_model
        self.n_head = n_head
        self.d_head = d_model // n_head   # head dimension

        # combined linear projection for Q, K, and V
        self.qkv_linear = nn.Linear(d_model, d_model * 3)
        # final output projection: lets the model remix the information coming from the heads
        self.O = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass to compute multi-head self-attention.

        Args:
            x (torch.Tensor): input tensor of shape [batch_size, seq_len, d_model]

        Returns:
            torch.Tensor: hidden states tensor with shape of [batch_size, seq_len, d_model]
        """
        batch_size, seq_len, d_model = x.size()

        qkv = self.qkv_linear(x)    # [batch_size, seq_len, d_model * 3]
        # The fused output is laid out per head: each head owns a contiguous slice of
        # size d_head * 3 holding its [query | key | value] features, which is why the
        # last dimension is chunked into three equal parts below.
        qkv = qkv.reshape((batch_size, seq_len, self.n_head, self.d_head * 3))
        qkv = qkv.transpose(1, 2)   # [batch_size, n_head, seq_len, d_head * 3]

        # each of queries / keys / values has shape [batch_size, n_head, seq_len, d_head]
        queries, keys, values = qkv.chunk(3, dim=-1)

        # scaled dot-product attention; softmax is taken over the key positions
        scores = queries @ keys.transpose(-1, -2) / self.d_head ** 0.5
        self_attn = F.softmax(scores, dim=-1)  # [batch_size, n_head, seq_len, seq_len]

        out = self_attn @ values    # [batch_size, n_head, seq_len, d_head]

        # recombine heads and apply the final linear projection
        out = out.transpose(1, 2)   # [batch_size, seq_len, n_head, d_head]
        out = out.reshape((batch_size, seq_len, d_model))
        out = self.O(out)   # [batch_size, seq_len, d_model]

        return out


#### Toy Example

In [5]:
# toy configuration
vocab_size = 15
seq_len = 5
batch_size = 2
n_head = 3
d_model = 12    # a.k.a. model dim / hidden size

# random token ids; they must be integers because they are fed to an Embedding layer
x = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len))

# token ids -> embeddings
embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)
embedded = embedding(x)     # [batch_size, seq_len, d_model]

# embeddings -> hidden states via the fused-projection multi-head self-attention
attn = MultiHeadAttentionV2(d_model=d_model, n_head=n_head)
hidden_state = attn(embedded)   # [batch_size, seq_len, d_model]

print(f'input x: {x.size()}, and embedded: {embedded.size()}, and hidden state: {hidden_state.size()}')

input x: torch.Size([2, 5]), and embedded: torch.Size([2, 5, 12]), and hidden state: torch.Size([2, 5, 12])
