# Multi-Head Attention

An implementation of Multi-Head Attention (MHA) with learnable linear projections for the query, key, and value (W_q, W_k, W_v), plus a final output projection (W_o) that combines the outputs of all heads.
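
In the standard formulation (as in the original Transformer), each head applies scaled dot-product attention to its own projections of the input, and the concatenated heads are mixed by $W^O$. Bias terms are omitted here, although the `nn.Linear` layers below include them:

$$
\text{head}_i = \text{Attention}\left(X W_i^Q,\; X W_i^K,\; X W_i^V\right), \qquad
\text{MultiHead}(X) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\, W^O
$$

$$
\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V
$$

Here $X$ holds the token embeddings, $h$ is `num_heads`, and $d_k$ is `head_dim`. The code stores all per-head matrices $W_i^Q$ side by side in a single `embed_dim × embed_dim` layer (and likewise for $K$ and $V$), then recovers the heads by reshaping the projected output.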

## Code

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

We reuse the [scaled dot-product attention](../scaled_dot_product_attention/) implementation.

In [2]:
class ScaledDotProductAttention(nn.Module):
    def __init__(self, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        d_k = query.size(-1)
        scores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=query.dtype))
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        output = torch.matmul(attention_weights, value)
        return output, attention_weights
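
As a quick sanity check of the formula above, the same computation done by hand (with dropout off) should match PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention`. This is a sketch that assumes PyTorch 2.0 or later, where that function exists; the tensor sizes are arbitrary.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Arbitrary sizes: (batch_size, num_heads, seq_len, head_dim)
q = torch.randn(2, 4, 5, 8)
k = torch.randn(2, 4, 5, 8)
v = torch.randn(2, 4, 5, 8)

# Manual computation, same formula as ScaledDotProductAttention (no dropout)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
weights = F.softmax(scores, dim=-1)
manual = torch.matmul(weights, v)

# Built-in version (PyTorch >= 2.0); dropout_p defaults to 0.0 and the
# default scale is 1 / sqrt(head_dim)
fused = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(manual, fused, atol=1e-5))
print(weights.sum(dim=-1))  # each row of attention weights sums to 1
```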

In [None]:
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by number of heads"

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Learnable parameters (linear projections) for query, key, value, and output
        self.q_proj = nn.Linear(embed_dim, embed_dim) # W_q for query
        self.k_proj = nn.Linear(embed_dim, embed_dim) # W_k for key
        self.v_proj = nn.Linear(embed_dim, embed_dim) # W_v for value
        self.out_proj = nn.Linear(embed_dim, embed_dim) # W_o to learn how to combine each head's output into a final output

        self.attention = ScaledDotProductAttention(dropout)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape

        # Pass the input x through the Q, K, V projections,
        # then reshape to multi-head form (batch_size, seq_len, num_heads, head_dim)
        # and transpose to (batch_size, num_heads, seq_len, head_dim) to apply attention for each head in parallel
        q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Apply scaled dot-product attention to each head
        attn_output, attn_weights = self.attention(q, k, v, mask)

        # Concatenate each head's output
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)
        # Pass the concatenated output through the final projection (learnable parameters that combine the heads' outputs)
        output = self.out_proj(attn_output)

        return output, attn_weights

## Line-by-line Explanation

### def \_\_init\_\_(self, embed_dim, num_heads, dropout=0.1)

- Make sure the embedding dimension is divisible by the number of heads
  - `assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by number of heads"`
- Store the embedding dimension, number of heads, and head dimension as instance attributes
  - ```
    self.embed_dim = embed_dim
    self.num_heads = num_heads
    self.head_dim = embed_dim // num_heads
    ```
- Create the linear projections (weight matrices) for Q, K, V, and for the output that combines the heads
  - ```
    self.q_proj = nn.Linear(embed_dim, embed_dim)
    self.k_proj = nn.Linear(embed_dim, embed_dim)
    self.v_proj = nn.Linear(embed_dim, embed_dim)
    self.out_proj = nn.Linear(embed_dim, embed_dim)
    ```
- Store the scaled dot-product attention module as an instance attribute so that `forward` can call it
  - `self.attention = ScaledDotProductAttention(dropout)`
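
A single `nn.Linear(embed_dim, embed_dim)` is equivalent to `num_heads` separate per-head projections from `embed_dim` to `head_dim`: head `i` simply owns a contiguous block of output features, which is exactly the block that `.view(..., num_heads, head_dim)` later assigns to it. The sketch below (arbitrary sizes, not from the original code) checks this for a query projection and also counts the parameters of the four projections, `4 * (embed_dim**2 + embed_dim)` including biases.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim, num_heads = 8, 2
head_dim = embed_dim // num_heads

q_proj = nn.Linear(embed_dim, embed_dim)  # one W_q shared by all heads
x = torch.randn(3, embed_dim)             # 3 token embeddings
full = q_proj(x)                          # (3, embed_dim)

# nn.Linear stores weight as (out_features, in_features), so head i uses
# rows i*head_dim:(i+1)*head_dim of the weight and bias
matches = []
for i in range(num_heads):
    rows = slice(i * head_dim, (i + 1) * head_dim)
    per_head = x @ q_proj.weight[rows].T + q_proj.bias[rows]  # (3, head_dim)
    matches.append(torch.allclose(full[:, rows], per_head, atol=1e-5))
print(matches)

# Parameters of W_q, W_k, W_v, W_o for embed_dim=512 (weights + biases)
projections = [nn.Linear(512, 512) for _ in range(4)]
num_params = sum(p.numel() for layer in projections for p in layer.parameters())
print(num_params)  # 4 * (512 * 512 + 512) = 1050624
```

The parameter count does not depend on `num_heads`: splitting into more heads only changes how the same `embed_dim` features are grouped.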

### def forward(self, x, mask=None)

- Get the batch size and sequence length from the input tensor `x`; they are the first two dimensions of its shape
  - `batch_size, seq_len, _ = x.shape`
- Pass the input `x` through the query, key, and value projections, then split the last dimension `embed_dim` into `(num_heads, head_dim)`. Then transpose the second and third dimensions to get shape `(batch_size, num_heads, seq_len, head_dim)`, so that attention can be computed for all heads in parallel
  - ```
    q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
    k = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
    v = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
    ```
- Apply attention calculation to each head
  - `attn_output, attn_weights = self.attention(q, k, v, mask)`
    - `torch.matmul` treats all leading dimensions as batch dimensions, so with inputs of shape `(batch_size, num_heads, seq_len, head_dim)` this one line computes attention for every head in parallel, which is more efficient than an explicit Python for-loop over heads.
- Concatenate the output of each head
  - `attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_dim)`
    - First, transpose the `num_heads` and `seq_len` dimensions back, giving shape `(batch_size, seq_len, num_heads, head_dim)`
    - `.transpose()` only swaps strides, so the result is no longer contiguous in memory. `.contiguous()` copies it into a contiguous row-major layout, which `.view()` requires here (`.reshape()` would perform this copy automatically)
    - Use `.view()` to merge the last two dimensions `num_heads` and `head_dim` back into `self.embed_dim`
- Pass the concatenated output through the output projection, which learns how to mix the heads' outputs into the final output
  - `output = self.out_proj(attn_output)`
- Return the final output and attention weights
  - `return output, attn_weights`
    - `output` is of shape `(batch_size, seq_len, embed_dim)`, ready to be passed to the next layer
    - `attn_weights` is the attention weights per head, so its shape is `(batch_size, num_heads, seq_len, seq_len)`
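
The claim that a single `torch.matmul` call covers every head can be checked directly. This sketch (random tensors, arbitrary sizes) compares the batched score computation with an explicit loop over the `num_heads` dimension.

```python
import torch

torch.manual_seed(0)
batch_size, num_heads, seq_len, head_dim = 2, 4, 5, 8
q = torch.randn(batch_size, num_heads, seq_len, head_dim)
k = torch.randn(batch_size, num_heads, seq_len, head_dim)

# One call: (batch_size, num_heads) are treated as batch dimensions
batched = torch.matmul(q, k.transpose(-2, -1))  # (B, H, S, S)

# Equivalent explicit loop: one (B, S, D) @ (B, D, S) product per head
looped = torch.stack(
    [torch.matmul(q[:, h], k[:, h].transpose(-2, -1)) for h in range(num_heads)],
    dim=1,
)

print(batched.shape)
print(torch.allclose(batched, looped, atol=1e-5))
```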
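
The split and merge steps are exact inverses, and `.contiguous()` is what makes the merging `.view()` legal. In this sketch, `attn_output` is a stand-in for the contiguous `(batch_size, num_heads, seq_len, head_dim)` tensor that `torch.matmul` returns, simulated here with `.contiguous()`; the sizes are arbitrary.

```python
import torch

torch.manual_seed(0)
batch_size, seq_len, num_heads, head_dim = 2, 3, 4, 5
embed_dim = num_heads * head_dim

x = torch.randn(batch_size, seq_len, embed_dim)

# Split: (B, S, E) -> (B, S, H, D) -> (B, H, S, D)
heads = x.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
print(heads.is_contiguous())  # False: transpose only swaps strides

# Stand-in for the contiguous tensor produced by torch.matmul
attn_output = heads.contiguous()

# Merging with .view() alone fails: the transposed strides cannot be merged
try:
    attn_output.transpose(1, 2).view(batch_size, seq_len, embed_dim)
    view_failed = False
except RuntimeError:
    view_failed = True
print(view_failed)

# .contiguous().view() (or .reshape()) restores the original layout exactly
merged = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)
merged_reshape = attn_output.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
print(torch.equal(merged, x), torch.equal(merged_reshape, x))
```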

In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        assert embed_dim % num_heads == 0, (
            f'Model embedding dimension (received embed_dim: {embed_dim}) '
            f'must be divisible by number of attention heads (received num_heads: {num_heads})'
        )
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads  # integer division: .view() needs int sizes

        # initialize Q, K, V, and output linear projection layers
        # input x (batch_size, seq_len, embed_dim) will go through them
        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)
        self.W_out = nn.Linear(embed_dim, embed_dim)

        # initialize dropout layer
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        # x: (batch_size, seq_len, embed_dim)
        batch_size, seq_len, embed_dim = x.size()

        # Step 1: Project input through the Q, K, V linear projections
        #       shape of x: (batch_size, seq_len, embed_dim)
        #       shape of W_q, W_k, W_v: (embed_dim, embed_dim)
        #       shape after projection: (batch_size, seq_len, embed_dim)
        # Step 2: Split the last embedding dimension into multiple heads
        #       shape after split: (batch_size, seq_len, num_heads, head_dim)
        # Step 3: Transpose the seq_len and num_heads dimensions for 
        #           parallel attention computation
        #       shape after transpose: (batch_size, num_heads, seq_len, head_dim)
        queries = self.W_q(x).view(batch_size, seq_len, self.num_heads,
                                   self.head_dim).transpose(1, 2)
        keys = self.W_k(x).view(batch_size, seq_len, self.num_heads,
                                self.head_dim).transpose(1, 2)
        values = self.W_v(x).view(batch_size, seq_len, self.num_heads,
                                  self.head_dim).transpose(1, 2)

        # Step 4: Compute attention output of each head
        # Step 4-1: attention score computation
        #       shapes: 
        #           queries: (batch_size, num_heads, seq_len, head_dim)
        #           keys.transpose(): (batch_size, num_heads, head_dim, seq_len)
        #           attn_scores: (batch_size, num_heads, seq_len, seq_len)
        attn_scores = torch.matmul(queries, keys.transpose(-2,-1)) / math.sqrt(self.head_dim)
        # Step 4-2: apply mask
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))
        # Step 4-3: softmax
        attn_weights = F.softmax(attn_scores, dim=-1)
        # Step 4-4: dropout
        attn_weights = self.dropout(attn_weights)
        # Step 4-5: attention value of each head
        #       shapes:
        #           attn_weights: (batch_size, num_heads, seq_len, seq_len)
        #           values: (batch_size, num_heads, seq_len, head_dim)
        #           attn_values: (batch_size, num_heads, seq_len, head_dim)
        attn_values = torch.matmul(attn_weights, values)

        # Step 5: rearrange dimensions and reshape to concatenate heads
        #       shapes:
        #           before rearrange: (batch_size, num_heads, seq_len, head_dim)
        #           after rearrange: (batch_size, seq_len, num_heads, head_dim)
        #           after concatenation: (batch_size, seq_len, embed_dim)
        attn_values = attn_values.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)

        # Step 6: Go through the final output projection
        output = self.W_out(attn_values)

        return output
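
The `mask` argument is compared with `0` and broadcast against `attn_scores`, which has shape `(batch_size, num_heads, seq_len, seq_len)`. A common choice for decoder self-attention, assumed here for illustration, is a causal mask of shape `(seq_len, seq_len)` that blocks attention to future positions. The sketch below shows that masked entries receive exactly zero weight after the softmax.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, num_heads, seq_len, head_dim = 2, 4, 5, 8
queries = torch.randn(batch_size, num_heads, seq_len, head_dim)
keys = torch.randn(batch_size, num_heads, seq_len, head_dim)

attn_scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(head_dim)

# Causal mask: 1 = may attend, 0 = blocked (a future position).
# Shape (seq_len, seq_len) broadcasts over batch_size and num_heads.
mask = torch.tril(torch.ones(seq_len, seq_len))
attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))
attn_weights = F.softmax(attn_scores, dim=-1)

print(attn_weights[0, 0])  # lower-triangular; row 0 is [1, 0, 0, 0, 0]
```

A padding mask of shape `(batch_size, 1, 1, seq_len)` broadcasts the same way. If every key in a row is masked, that row becomes all `-inf` and the softmax returns NaN, so each query must keep at least one unmasked position.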