# Single-Head Attention
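Let the input sequence be $X = [x_1, x_2, \dots, x_n]$, where each $x_i \in \mathbb{R}^d$ is the vector for one token (or feature). Stacking the $x_i$ as rows gives a matrix $X \in \mathbb{R}^{n \times d}$, which matches the `(seq_len, input_size)` layout used by the PyTorch code below. The attention mechanism can then be broken down into the following steps:

1. Query, Key, and Value:
   - Query: $Q = X W_Q$, where $W_Q \in \mathbb{R}^{d \times d_k}$ is a learnable weight matrix.
   - Key: $K = X W_K$, where $W_K \in \mathbb{R}^{d \times d_k}$ is a learnable weight matrix.
   - Value: $V = X W_V$, where $W_V \in \mathbb{R}^{d \times d_v}$ is a learnable weight matrix.

   Here, $Q$, $K$, and $V$ are the query, key, and value matrices: row $i$ of each is the query vector $q_i$, key vector $k_i$, or value vector $v_i$ of token $x_i$. (In the code below, each projection is an `nn.Linear` layer, which also adds a bias, and $d_k = d_v = d$.)

2. Similarity Calculation:
   - Compute the similarity scores between every query and every key: $S = Q K^\top \in \mathbb{R}^{n \times n}$, so that $S_{ij} = q_i \cdot k_j$.

   The score $S_{ij}$ measures how relevant the $j$-th key-value pair is to the $i$-th query. Many implementations, including the multi-head code later in this notebook, divide the scores by $\sqrt{d_k}$ (scaled dot-product attention) to keep their magnitude from growing with the dimension; the single-head code below keeps the plain dot product.

3. Attention Weights:
   - Apply softmax to each row of the similarity scores to obtain the attention weights: $A = \mathrm{softmax}(S)$, with $A_{ij} = \exp(S_{ij}) / \sum_{l=1}^{n} \exp(S_{il})$.

   Softmax makes every row of $A$ non-negative and sum to 1, so $A_{ij}$ can be read as the share of query $i$'s attention assigned to key-value pair $j$.

4. Weighted Sum:
   - Compute the weighted sum of the value vectors using the attention weights: $C = A V \in \mathbb{R}^{n \times d_v}$.

   Row $i$ of $C$ is $c_i = \sum_{j} A_{ij} v_j$, the context vector for token $i$: value vectors with higher attention weights contribute more to it.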
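The four steps can be traced with explicit matrices. This is a minimal sketch, assuming hypothetical sizes (n = 4 tokens, d = 8 features) and random stand-ins for the learned weights (no bias); tokens are stored as rows of `X`, as in the PyTorch code later in this notebook.

```python
import torch

torch.manual_seed(0)

n, d = 4, 8                 # hypothetical sizes: 4 tokens, 8 features each
X = torch.randn(n, d)       # row i is the token vector x_i

# Step 1: projections (random stand-ins for learned weights, scaled to keep scores moderate)
W_Q, W_K, W_V = (torch.randn(d, d) / d ** 0.5 for _ in range(3))
Q, K, V = X @ W_Q, X @ W_K, X @ W_V        # each (n, d)

# Step 2: similarity scores, S[i, j] = q_i . k_j
S = Q @ K.T                                # (n, n)

# Step 3: softmax over each row (i.e. over the keys)
A = torch.softmax(S, dim=-1)               # (n, n), each row sums to 1

# Step 4: weighted sum of the value vectors
C = A @ V                                  # (n, d)

print(S.shape, A.shape, C.shape)  # torch.Size([4, 4]) torch.Size([4, 4]) torch.Size([4, 8])
```

Each row of `C` is a convex combination of the rows of `V`, which is exactly the sum $c_i = \sum_j A_{ij} v_j$ written element by element.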

In summary, the attention mechanism computes attention weights by comparing the query vector with the key vectors. These weights are then used to compute a weighted sum of the value vectors, which represents the attended information or context. This mechanism allows the model to selectively focus on different parts of the input sequence based on their relevance to the query.
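To make the "selective focus" concrete, here is a hand-sized example in plain Python with toy numbers chosen purely for illustration: the query points in the same direction as the first key, so that key-value pair receives the largest weight and dominates the context vector.

```python
import math


def softmax(scores):
    # Subtracting the max improves numerical stability and does not change the result
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


# Hypothetical toy data: one query and three key/value pairs in 2-D
query = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
values = [[10.0, 0.0], [0.0, 10.0], [0.0, 0.0]]

scores = [dot(query, k) for k in keys]   # [1.0, 0.0, -1.0]
weights = softmax(scores)
context = [sum(w * v[i] for w, v in zip(weights, values)) for i in range(2)]

print([round(w, 3) for w in weights])  # [0.665, 0.245, 0.09]
print([round(c, 3) for c in context])  # [6.652, 2.447]
```

The first value vector contributes about two thirds of the context, while the key pointing away from the query contributes under 10%.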

Note that the above equations provide a general outline of the attention mechanism. Different variants and extensions of attention may incorporate additional components or modifications based on the specific task or architecture being used.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
```

```python
class SingleHeadAttention(nn.Module):
    """Single-head (unscaled) dot-product self-attention."""

    def __init__(self, input_size):
        super().__init__()

        # Learnable projections W_Q, W_K, W_V (nn.Linear also adds a bias)
        self.query_transform = nn.Linear(input_size, input_size)
        self.key_transform = nn.Linear(input_size, input_size)
        self.value_transform = nn.Linear(input_size, input_size)

    def forward(self, x):
        # x: (batch_size, seq_len, input_size)
        # Compute query, key, and value vectors
        query = self.query_transform(x)
        key = self.key_transform(x)
        value = self.value_transform(x)

        # Compute similarity scores: (batch_size, seq_len, seq_len)
        scores = torch.matmul(query, key.transpose(-2, -1))

        # Compute attention weights using softmax over the key dimension
        weights = F.softmax(scores, dim=-1)

        # Compute weighted sum of values: (batch_size, seq_len, input_size)
        weighted_sum = torch.matmul(weights, value)

        return weighted_sum
```

```python
# Example usage
input_size = 64
sequence_length = 10
batch_size = 32

# Generate random input tensor
x = torch.randn(batch_size, sequence_length, input_size)

# Create and apply single-head attention
attention = SingleHeadAttention(input_size)
output = attention(x)

print("Input shape:", x.shape)
print("Output shape:", output.shape)
```

```text
Input shape: torch.Size([32, 10, 64])
Output shape: torch.Size([32, 10, 64])
```
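`SingleHeadAttention` uses the plain dot product, whereas the multi-head class later in this notebook divides the scores by the square root of the head size. A minimal sketch of that scaled variant, checked against PyTorch's built-in `F.scaled_dot_product_attention` (assumes PyTorch 2.0 or later): `scaled_attention` is a hypothetical helper written for illustration, and the projections are left out because the built-in function takes already-projected queries, keys, and values.

```python
import math

import torch
import torch.nn.functional as F


def scaled_attention(query, key, value):
    """Scaled dot-product attention for tensors shaped (..., seq_len, dim)."""
    d_k = query.size(-1)
    # Divide the raw dot products by sqrt(d_k) before the softmax
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, value)


torch.manual_seed(0)
# Shapes are (batch, heads, seq_len, dim); a single head here
q = torch.randn(2, 1, 5, 16)
k = torch.randn(2, 1, 5, 16)
v = torch.randn(2, 1, 5, 16)

ours = scaled_attention(q, k, v)
# Built-in version; by default it scales the scores by 1/sqrt(dim)
reference = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(ours, reference, atol=1e-5))
```

Adding the same `/ math.sqrt(d_k)` factor to the `scores` line of `SingleHeadAttention` would turn it into this scaled variant.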


# Multi-Head Attention
Multi-head attention extends the basic attention mechanism by incorporating multiple attention heads. Each attention head performs its own attention computation independently. Here's the mathematical formulation for multi-head attention:

Given an input sequence $X = [x_1, x_2, \dots, x_n]$, stacked as rows of $X \in \mathbb{R}^{n \times d}$, and $h$ attention heads of size $d_h = d / h$, the multi-head attention mechanism can be broken down into the following steps:

1. Input Transformation:
   - Perform linear transformations to obtain the query ($Q_i$), key ($K_i$), and value ($V_i$) matrices for each attention head $i = 1, \dots, h$.
   - Let $Q_i = X W_{Q,i}$, $K_i = X W_{K,i}$, and $V_i = X W_{V,i}$, where $W_{Q,i}, W_{K,i}, W_{V,i} \in \mathbb{R}^{d \times d_h}$ are learnable weight matrices specific to the $i$-th attention head.

2. Similarity Calculation:
   - Compute the scaled similarity scores for each attention head: $S_i = Q_i K_i^\top / \sqrt{d_h} \in \mathbb{R}^{n \times n}$. The factor $1/\sqrt{d_h}$ corresponds to the `scores / (self.head_size ** 0.5)` step in the code below.

3. Attention Weights:
   - Apply a row-wise softmax to the similarity scores of each head to obtain its attention weights: $A_i = \mathrm{softmax}(S_i)$.

4. Weighted Sum:
   - Compute the weighted sum of the value vectors for each head using its attention weights: $C_i = A_i V_i \in \mathbb{R}^{n \times d_h}$.

5. Concatenation and Projection:
   - Concatenate the outputs from all heads along the feature dimension: $C = [C_1, C_2, \dots, C_h] \in \mathbb{R}^{n \times h d_h}$, where $h$ is the total number of attention heads.
   - Apply a linear transformation to the concatenated outputs: $Y = C W_O$, where $W_O \in \mathbb{R}^{h d_h \times d}$ is a learnable weight matrix.

Here, Y represents the final output of the multi-head attention mechanism, which incorporates information from multiple attention heads.
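Computing the heads one by one and splitting one wide projection into heads with a reshape (as the `MultiHeadAttention` class below does) are the same computation: the columns $[i d_h, (i+1) d_h)$ of a $d \times d$ projection act as the per-head matrix. A minimal sketch, assuming hypothetical sizes and random stand-in weights without biases, that runs steps 1 to 5 both ways and checks that they agree:

```python
import math

import torch

torch.manual_seed(0)
n, d, h = 6, 12, 3  # hypothetical sizes: 6 tokens, model dim 12, 3 heads
d_h = d // h        # head size; d must be divisible by h

X = torch.randn(n, d)
# Random stand-ins for learned weights (no biases). Columns
# [i*d_h, (i+1)*d_h) of W_Q play the role of W_{Q,i}, and likewise for K and V.
W_Q, W_K, W_V, W_O = (torch.randn(d, d) / math.sqrt(d) for _ in range(4))

# Steps 1-5, one head at a time
head_outputs = []
for i in range(h):
    cols = slice(i * d_h, (i + 1) * d_h)
    Q_i, K_i, V_i = X @ W_Q[:, cols], X @ W_K[:, cols], X @ W_V[:, cols]
    A_i = torch.softmax(Q_i @ K_i.T / math.sqrt(d_h), dim=-1)  # (n, n)
    head_outputs.append(A_i @ V_i)                             # (n, d_h)
Y_loop = torch.cat(head_outputs, dim=-1) @ W_O                 # (n, d)


def split_heads(M):
    """(n, d) -> (h, n, d_h): a batch-free version of the reshape in the class below."""
    return M.view(n, h, d_h).transpose(0, 1)


# The same computation, batched over heads
Q, K, V = split_heads(X @ W_Q), split_heads(X @ W_K), split_heads(X @ W_V)
A = torch.softmax(Q @ K.transpose(-2, -1) / math.sqrt(d_h), dim=-1)  # (h, n, n)
C = (A @ V).transpose(0, 1).reshape(n, d)  # concatenate heads: (n, h * d_h)
Y_batched = C @ W_O

print(torch.allclose(Y_loop, Y_batched, atol=1e-5))
```

The batched form replaces the Python loop with a single matrix multiplication over an extra head dimension, which is why implementations usually prefer it.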

Note that in practice, multi-head attention is usually embedded in a larger layer with additional components, such as residual connections, layer normalization, and a position-wise feed-forward network, which improve training stability and performance. These components sit around the attention computation rather than inside it, so the steps outlined above still capture the essence of multi-head attention itself.
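As an illustration of how these components are typically arranged, here is a minimal post-norm block in the style of the original Transformer encoder layer. It is a sketch rather than part of the formulation above: `AttentionBlock` and its sizes are hypothetical, and it uses PyTorch's built-in `nn.MultiheadAttention` (with `batch_first=True`, available since PyTorch 1.9) instead of the class defined below so that the snippet runs on its own.

```python
import torch
import torch.nn as nn


class AttentionBlock(nn.Module):
    """Hypothetical post-norm encoder block: attention and feed-forward sublayers."""

    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model)
        )
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)  # self-attention: query = key = value = x
        x = self.norm1(x + attn_out)      # residual connection, then layer norm
        x = self.norm2(x + self.ff(x))    # same pattern for the feed-forward sublayer
        return x


block = AttentionBlock(d_model=64, num_heads=8, d_ff=256)
x = torch.randn(32, 10, 64)
out = block(x)
print(out.shape)  # torch.Size([32, 10, 64])
```

The residual paths keep the output shape equal to the input shape, which is what allows such blocks to be stacked.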

```python
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention."""

    def __init__(self, input_size, num_heads):
        super().__init__()
        # The heads split the feature dimension evenly
        assert input_size % num_heads == 0, "input_size must be divisible by num_heads"

        self.num_heads = num_heads
        self.head_size = input_size // num_heads

        # One wide projection per role; its output is split into heads in forward()
        self.query_transform = nn.Linear(input_size, input_size)
        self.key_transform = nn.Linear(input_size, input_size)
        self.value_transform = nn.Linear(input_size, input_size)
        self.output_transform = nn.Linear(input_size, input_size)  # W_O

    def forward(self, x):
        # Split input into multiple heads
        batch_size, seq_len, _ = x.size()
        queries = self.query_transform(x).view(batch_size, seq_len, self.num_heads, self.head_size)
        keys = self.key_transform(x).view(batch_size, seq_len, self.num_heads, self.head_size)
        values = self.value_transform(x).view(batch_size, seq_len, self.num_heads, self.head_size)

        # Transpose dimensions for matrix multiplication
        queries = queries.transpose(1, 2)  # (batch_size, num_heads, seq_len, head_size)
        keys = keys.transpose(1, 2)  # (batch_size, num_heads, seq_len, head_size)
        values = values.transpose(1, 2)  # (batch_size, num_heads, seq_len, head_size)

        # Compute similarity scores
        scores = torch.matmul(queries, keys.transpose(-2, -1))  # (batch_size, num_heads, seq_len, seq_len)

        # Compute attention weights using softmax on the scaled scores
        weights = F.softmax(scores / (self.head_size ** 0.5), dim=-1)

        # Compute weighted sum of values
        weighted_sum = torch.matmul(weights, values)  # (batch_size, num_heads, seq_len, head_size)

        # Transpose and reshape to concatenate the heads
        weighted_sum = weighted_sum.transpose(1, 2).contiguous()  # (batch_size, seq_len, num_heads, head_size)
        weighted_sum = weighted_sum.view(batch_size, seq_len, -1)  # (batch_size, seq_len, input_size)

        # Apply the output projection W_O
        output = self.output_transform(weighted_sum)

        return output
```

```python
# Example usage
input_size = 64
sequence_length = 10
batch_size = 32
num_heads = 8

# Generate random input tensor
x = torch.randn(batch_size, sequence_length, input_size)

# Create and apply multi-head attention
attention = MultiHeadAttention(input_size, num_heads)
output = attention(x)

print("Input shape:", x.shape)
print("Output shape:", output.shape)
```

```text
Input shape: torch.Size([32, 10, 64])
Output shape: torch.Size([32, 10, 64])
```
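A useful sanity check for a hand-written multi-head layer is to compare it with PyTorch's `nn.MultiheadAttention` after giving both the same weights. The sketch below restates the `forward` logic of the class above as a standalone function (`manual_mha`, a hypothetical name) so the snippet runs on its own. It relies on `nn.MultiheadAttention` packing the query, key, and value projections row-wise into `in_proj_weight` and `in_proj_bias`, which holds when the key and value dimensions equal `embed_dim`; `batch_first=True` assumes PyTorch 1.9 or later.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
d, h = 64, 8

# Separate projections, laid out like the MultiHeadAttention class above
q_proj, k_proj, v_proj, o_proj = (nn.Linear(d, d) for _ in range(4))

# Built-in module with the same weights copied in; in_proj_weight has shape (3 * d, d)
mha = nn.MultiheadAttention(d, h, batch_first=True)
with torch.no_grad():
    mha.in_proj_weight.copy_(torch.cat([q_proj.weight, k_proj.weight, v_proj.weight]))
    mha.in_proj_bias.copy_(torch.cat([q_proj.bias, k_proj.bias, v_proj.bias]))
    mha.out_proj.weight.copy_(o_proj.weight)
    mha.out_proj.bias.copy_(o_proj.bias)


def manual_mha(x):
    """Standalone restatement of MultiHeadAttention.forward."""
    b, n, _ = x.shape
    d_h = d // h

    def split(t):
        # (b, n, d) -> (b, h, n, d_h)
        return t.view(b, n, h, d_h).transpose(1, 2)

    q, k, v = split(q_proj(x)), split(k_proj(x)), split(v_proj(x))
    weights = F.softmax(q @ k.transpose(-2, -1) / d_h ** 0.5, dim=-1)
    out = (weights @ v).transpose(1, 2).reshape(b, n, d)  # concatenate heads
    return o_proj(out)


x = torch.randn(4, 10, d)
expected, _ = mha(x, x, x)  # self-attention: query = key = value
print(torch.allclose(manual_mha(x), expected, atol=1e-5))
```

If the two disagree, the usual suspects are the order in which heads are split from the feature dimension and the scaling factor.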
