<a href="https://colab.research.google.com/github/kasakun/CodeBook/blob/master/ml_coding/multihead_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch

In [10]:
class MultiHeadAttention(torch.nn.Module):
    """Multi-head scaled dot-product attention.

    The inputs are linearly projected into queries, keys and values, split
    into ``num_heads`` heads of width ``embedding_dim // num_heads``, attended
    independently per head, and the per-head results are concatenated back to
    ``embedding_dim``. No output projection is applied after concatenation.

    Args:
        num_heads: Number of attention heads; must be positive and divide
            ``embedding_dim`` evenly.
        embedding_dim: Size of the last dimension of the inputs and output.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``embedding_dim``.
    """

    def __init__(self, num_heads, embedding_dim):
        super().__init__()
        if num_heads <= 0 or embedding_dim % num_heads != 0:
            raise ValueError(
                f"embedding_dim ({embedding_dim}) must be divisible by a "
                f"positive num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.embedding_dim = embedding_dim
        self.head_dim = embedding_dim // num_heads

        self.W_Q = torch.nn.Linear(embedding_dim, embedding_dim)
        self.W_K = torch.nn.Linear(embedding_dim, embedding_dim)
        self.W_V = torch.nn.Linear(embedding_dim, embedding_dim)

    def scaled_dot_production(self, query, key, value, mask=None):
        """Compute ``softmax(Q K^T / sqrt(head_dim)) V`` for every head.

        Args:
            query: Tensor of shape [batch_size, num_heads, query_len, head_dim].
            key: Tensor of shape [batch_size, num_heads, key_len, head_dim].
            value: Tensor of shape [batch_size, num_heads, key_len, head_dim].
            mask: Optional tensor broadcastable to
                [batch_size, num_heads, query_len, key_len]. Positions where
                ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape [batch_size, num_heads, query_len, head_dim].
        """
        # scores shape: [batch_size, num_heads, query_len, key_len]
        scores = torch.matmul(query, key.transpose(-1, -2))
        scores = scores / (self.head_dim ** 0.5)

        if mask is not None:
            # Use the dtype's most negative finite value instead of a fixed
            # constant so the fill also works for half-precision scores.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        # Normalise over the key axis so each query's weights sum to 1.
        attention = torch.softmax(scores, dim=-1)

        # shape: [batch_size, num_heads, query_len, head_dim]
        return torch.matmul(attention, value)

    def _split_heads(self, projected, batch_size):
        """Reshape [batch, seq_len, embedding_dim] to [batch, num_heads, seq_len, head_dim]."""
        # The head axis must be swapped with the sequence axis (dims 1 and 2);
        # swapping the last two dims would make attention run over head_dim
        # instead of over sequence positions.
        return projected.view(
            batch_size, -1, self.num_heads, self.head_dim
        ).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        """Apply multi-head attention.

        Args:
            query: Tensor of shape [batch_size, query_len, embedding_dim].
            key: Tensor of shape [batch_size, key_len, embedding_dim].
            value: Tensor of shape [batch_size, key_len, embedding_dim].
            mask: Optional tensor broadcastable to
                [batch_size, num_heads, query_len, key_len]; positions where
                ``mask == 0`` are not attended to.

        Returns:
            Tensor of shape [batch_size, query_len, embedding_dim].
        """
        query = self.W_Q(query)
        key = self.W_K(key)
        value = self.W_V(value)

        batch_size = query.shape[0]

        # Each becomes [batch_size, num_heads, seq_len, head_dim].
        query = self._split_heads(query, batch_size)
        key = self._split_heads(key, batch_size)
        value = self._split_heads(value, batch_size)

        attention = self.scaled_dot_production(query, key, value, mask)

        # Move heads back next to head_dim and concatenate them:
        # [batch_size, query_len, embedding_dim]
        attention = attention.transpose(1, 2).contiguous()
        return attention.view(batch_size, -1, self.embedding_dim)

In [11]:
# Sizes for a small smoke test of the attention layer.
embed_dim, num_heads = 8, 2    # token vector width, number of attention heads
seq_len, batch_size = 5, 3     # tokens per sequence, sequences per batch

# Q, K and V are independent uniform-random tensors of identical shape
# (batch_size, seq_len, embed_dim), drawn in that order.
input_shape = (batch_size, seq_len, embed_dim)
query, key, value = (torch.rand(input_shape) for _ in range(3))

# Build the layer, run one forward pass, and report the output shape.
multihead_attn = MultiHeadAttention(num_heads=num_heads, embedding_dim=embed_dim)
output = multihead_attn(query, key, value)
print(output.shape)

torch.Size([3, 5, 8])
