In [1]:
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The query, key and value inputs are each passed through their own linear
    projection, split into ``heads`` slices of width ``head_dim``, attended
    independently per head with ``softmax(Q @ K^T / sqrt(head_dim)) @ V``,
    then concatenated and passed through a final linear layer.

    Args:
        embed_size: Width of the input and output feature dimension.
        heads: Number of attention heads. Must be positive and divide
            ``embed_size`` evenly.

    Raises:
        ValueError: If ``heads`` is not positive or does not divide
            ``embed_size``.
    """

    def __init__(self, embed_size, heads):
        super().__init__()

        # Without this check ``embed_size // heads`` silently truncates, and
        # the reshapes in ``forward`` would then either fail with an unrelated
        # error or quietly fold features into the sequence axis.
        if heads <= 0 or embed_size % heads != 0:
            raise ValueError(
                f"embed_size ({embed_size}) must be divisible by a positive "
                f"number of heads (got {heads})"
            )

        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        # Score scaling factor, computed once instead of building a tensor
        # and moving it to the device on every forward pass.
        self.scale = self.head_dim ** 0.5

        # Linear layers for projecting queries, keys, and values
        self.query_fc = nn.Linear(embed_size, embed_size)
        self.key_fc = nn.Linear(embed_size, embed_size)
        self.value_fc = nn.Linear(embed_size, embed_size)

        # Linear layer for the output of the attention heads
        self.fc_out = nn.Linear(embed_size, embed_size)

    def _split_heads(self, projected, batch_size):
        """Reshape (batch, seq_len, embed_size) to (batch, heads, seq_len, head_dim)."""
        return projected.view(batch_size, -1, self.heads, self.head_dim).permute(0, 2, 1, 3)

    def forward(self, query, key, value, mask=None):
        """Apply multi-head attention.

        Args:
            query: Tensor of shape (batch_size, query_len, embed_size).
            key: Tensor of shape (batch_size, key_len, embed_size).
            value: Tensor of shape (batch_size, key_len, embed_size); its
                sequence length must match ``key``.
            mask: Optional tensor broadcastable to
                (batch_size, heads, query_len, key_len). Positions where
                ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape (batch_size, query_len, embed_size).

        Note:
            Masked scores are set to ``-inf``, so a query position whose keys
            are all masked produces NaN after the softmax.
        """
        batch_size = query.shape[0]

        # Project the inputs, then split each projection into heads
        Q = self._split_heads(self.query_fc(query), batch_size)  # (batch_size, heads, query_len, head_dim)
        K = self._split_heads(self.key_fc(key), batch_size)      # (batch_size, heads, key_len, head_dim)
        V = self._split_heads(self.value_fc(value), batch_size)  # (batch_size, heads, key_len, head_dim)

        # Scaled dot-product attention scores: (batch_size, heads, query_len, key_len)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))

        # Normalise over the key axis to get attention weights
        attention_weights = torch.softmax(scores, dim=-1)

        # Weighted sum of the values for every head
        attention_output = torch.matmul(attention_weights, V)

        # Move heads back next to head_dim and merge them into embed_size;
        # contiguous() is required before view() after the permute.
        attention_output = attention_output.permute(0, 2, 1, 3).contiguous()  # (batch_size, query_len, heads, head_dim)
        attention_output = attention_output.view(batch_size, -1, self.embed_size)  # (batch_size, query_len, embed_size)

        return self.fc_out(attention_output)

In [2]:
import torch

# Seed the RNG first so both the layer's initial weights and the random
# inputs below are reproducible after a kernel restart
torch.manual_seed(0)

# Create an instance of the MultiHeadAttention module
embed_size = 256
heads = 8
attention = MultiHeadAttention(embed_size, heads)

# Generate some random input tensors
batch_size = 4
query_len = 10
key_len = 12
value_len = 12  # must equal key_len: the attention weights are multiplied against the values
input_dim = embed_size
query = torch.randn(batch_size, query_len, input_dim)
key = torch.randn(batch_size, key_len, input_dim)
value = torch.randn(batch_size, value_len, input_dim)

# Compute the attention output
attention_output = attention(query, key, value)

# The output keeps the query's sequence length and the embedding width
assert attention_output.shape == (batch_size, query_len, embed_size)

# Print the shape of the attention output
print(attention_output.shape)
print(query)

torch.Size([4, 10, 256])
tensor([[[-0.6933,  0.1144, -0.6404,  ..., -1.5227, -0.8612, -0.4788],
         [-0.4414,  0.8099, -0.2313,  ...,  1.2108, -1.3390,  0.1567],
         [ 1.6410, -0.1443, -0.1261,  ..., -0.6409,  0.2189,  0.7566],
         ...,
         [ 1.5665,  1.0366,  0.8955,  ..., -0.6198,  1.4127,  0.0709],
         [-0.0961,  0.2212,  0.5636,  ..., -1.8731,  0.4230, -0.1373],
         [-0.0052, -0.3138, -2.1592,  ...,  0.5203, -1.4493,  1.4830]],

        [[-1.6706, -0.6624, -0.9619,  ...,  0.2553,  1.5594, -1.0813],
         [ 2.0302, -0.3904, -0.4299,  ...,  0.1839, -0.6835, -0.5188],
         [ 3.0286,  2.0743, -0.5512,  ...,  1.6619, -1.3685, -0.8486],
         ...,
         [-0.1518,  0.5441,  1.2427,  ...,  2.2955, -0.7163, -0.0622],
         [-1.7081, -0.6666, -1.3355,  ...,  1.3552,  0.6429,  1.0939],
         [-0.6722, -1.7599,  0.4784,  ...,  0.8531,  1.1426, -0.7205]],

        [[-0.3293,  0.7776, -1.4412,  ..., -0.6374,  0.9112, -1.0856],
         [-0.7597, -