In [2]:
from importlib.metadata import version

pkgs = [
   "torch",
]
for p in pkgs:
    print(f"{p} version: {version(p)}")

torch version: 2.4.1


In [3]:
import math
from typing import Optional, List
import torch
from torch import nn
from labml import tracker

- regarding `def foward(x)`:
    - Input has shape `[seq_len, batch_size, d_model]` or `[batch_size, d_model]`. We apply the linear transformation to the last dimension and split that into the heads.
- Output shape `return x` has `[seq_len, batch_size, heads, d_k] or [batch_size, heads, d_model]`

In [4]:
class PrepareForMultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
        super().__init__()
        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
        self.heads = heads
        self.d_k = d_k
    
    def forward(self, x: torch.Tensor):
        head_shape = x.shape[:-1]
        x = self.linear(x) # Linear transformation
        x = x.view(*head_shape, self.heads, self.d_k) # Split last dimensions into heads
        return x

## Multi-Head Attention Module
- this computes scaled multi-head attention for a given `query`, `key` and `value` vectors
- In simple terms: It finds the key that matches the query, and gets the value of those keys
- It uses dot product of query and key as the indicator of how matching they are
- Before taking the softmax the dot-products are scaled by $\frac{1}{\sqrt{d_k}}$
- This is done to avoid large dot-product values causing softmax to give very small gradients when $d_k$ is large
- Softmax is calculated along the axis of the sequence (or time)
- regarding `def get_scores(query, key)`:
    - This calculates $QK^T$
    - but this method can also be overriden for other variations like relative attention
- regarding `def prepare_mask(mask, query_shape, key_shape)`:
    - mask has shape `[seq_len_q, seq_len_k, batch_size]` , where first dimension is the query dimension. If the query dimension is equal to 1 it will be broadcasted
    - resulting mask has shape `[seq_len_q, seq_len_k, batch_size, heads]`

In [15]:
class MultiHeadAttention(nn.Module):
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
        super().__init__()
        self.d_k = d_model // heads # Number of features per head
        self.heads = heads
        # These transform the query, key and value vectors for multi-head attention
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)
        self.softmax = nn.Softmax(dim=1) # Softmax for attention along the time dimension of 'key'
        self.output = nn.Linear(d_model, d_model) # Output layer
        self.dropout = nn.Dropout(dropout_prob) # Dropout
        self.scale = 1 / math.sqrt(self.d_k) # Scaling factor before the softmax
        # We store the attentions so that it can be used for logging, or other computations if needed
        self.attn = None

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
        return torch.einsum('ibhd,jbhd->ijbh', query, key)
    
    def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):
        assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
        assert mask.shape[1] == key_shape[0]
        assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]

        mask = mask.unsqueeze(-1) # Same mask applied to all heads
        return mask
    
    def forward(self, *, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Optional[torch.Tensor] = None):
        seq_len, batch_size, _ = query.shape # query , key and value have shape [seq_len, batch_size, d_model]
        if mask is not None:
            mask = self.prepare_mask(mask, query.shape, key.shape)
        
        # Prepare query, key and value for attention computation. These will then have shape [seq_len, batch_size, heads, d_k]
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

        # Compute the attention scores QK^T. This gives a tensor of shape [seq_len, seq_len, batch_size, heads]
        scores = self.get_scores(query, key)
        scores *= self.scale # Scaling the scores

        # Apply mask
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Softmax attention along the key sequence dimension
        attn = self.softmax(scores)

        # Save attentions if debugging
        tracker.debug('attn', attn)

        # Apply Dropout
        attn = self.dropout(attn)

        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)

        # Save attentions for any other calculations
        self.attn = attn.detach()

        # Concatenate multiple heads
        x = x.reshape(seq_len, batch_size, -1)

        return self.output(x) # Output layer

## Testing MHA

In [16]:
d_model = 512
heads = 8
seq_len = 10
batch_size = 32
dropout_prob = 0.1

In [17]:
attention_layer = MultiHeadAttention(heads=heads, d_model=d_model, dropout_prob=dropout_prob)

In [18]:
query = torch.rand(seq_len, batch_size, d_model) # Shape: [seq_len, batch_size, d_model]
key = torch.rand(seq_len, batch_size, d_model) # Shape: [seq_len, batch_size, d_model]
value = torch.rand(seq_len, batch_size, d_model) # Shape: [seq_len, batch_size, d_model]

In [21]:
mask = torch.ones(1, seq_len, seq_len) # Optional mask (shape: [batch_size, seq_len])

In [22]:
output = attention_layer(query=query, key=key, value=value, mask=mask) # forward pass

RuntimeError: The size of tensor a (10) must match the size of tensor b (32) at non-singleton dimension 2