In [5]:
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The query, key and value inputs are each projected by their own linear
    layer, split into ``heads`` slices of ``embed_size // heads`` features,
    attended per head, concatenated again and passed through a final linear
    layer.

    Args:
        embed_size: Feature size of the inputs and of the output.
        heads: Number of attention heads; must divide ``embed_size`` evenly.

    Raises:
        ValueError: If ``embed_size`` or ``heads`` is not positive, or if
            ``embed_size`` is not divisible by ``heads``.
    """

    def __init__(self, embed_size, heads):
        super().__init__()

        # Without this check the integer division below silently drops the
        # remainder, and the reshape in forward() then either fails with an
        # unrelated-looking shape error or regroups features across tokens.
        if embed_size <= 0 or heads <= 0:
            raise ValueError(
                f"embed_size and heads must be positive, got "
                f"embed_size={embed_size}, heads={heads}"
            )
        if embed_size % heads != 0:
            raise ValueError(
                f"embed_size ({embed_size}) must be divisible by heads ({heads})"
            )

        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        # Linear layers for projecting queries, keys, and values
        self.query_fc = nn.Linear(embed_size, embed_size)
        self.key_fc = nn.Linear(embed_size, embed_size)
        self.value_fc = nn.Linear(embed_size, embed_size)

        # Linear layer for the output of the attention heads
        self.fc_out = nn.Linear(embed_size, embed_size)

    def _split_heads(self, projected, batch_size):
        """Reshape (batch, len, embed_size) into (batch, heads, len, head_dim)."""
        return projected.view(batch_size, -1, self.heads, self.head_dim).permute(0, 2, 1, 3)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` positions over ``key``/``value`` positions.

        Args:
            query: Tensor of shape (batch_size, query_len, embed_size).
            key: Tensor of shape (batch_size, key_len, embed_size).
            value: Tensor of shape (batch_size, key_len, embed_size).
            mask: Optional tensor broadcastable to
                (batch_size, heads, query_len, key_len). Entries equal to 0
                are excluded from attention. A query row whose keys are all
                masked produces NaN weights, because softmax is taken over a
                row that is entirely ``-inf``.

        Returns:
            Tensor of shape (batch_size, query_len, embed_size).
        """
        batch_size = query.shape[0]

        # Project the inputs, then split the feature axis into heads
        Q = self._split_heads(self.query_fc(query), batch_size)  # (batch_size, heads, query_len, head_dim)
        K = self._split_heads(self.key_fc(key), batch_size)      # (batch_size, heads, key_len, head_dim)
        V = self._split_heads(self.value_fc(value), batch_size)  # (batch_size, heads, value_len, head_dim)

        # Scaled dot-product scores. The scale is a plain Python float, so no
        # tensor has to be built and moved to the device on every call.
        scale = self.head_dim ** 0.5
        scores = torch.matmul(Q, K.transpose(-2, -1)) / scale

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))

        # Normalise over the key axis to obtain attention weights
        attention_weights = torch.softmax(scores, dim=-1)

        # Weighted sum of the values for every head
        attention_output = torch.matmul(attention_weights, V)

        # Merge the heads back into one feature axis; permute() leaves the
        # tensor non-contiguous, so contiguous() is required before view().
        attention_output = attention_output.permute(0, 2, 1, 3).contiguous()  # (batch_size, query_len, heads, head_dim)
        attention_output = attention_output.view(batch_size, -1, self.embed_size)  # (batch_size, query_len, embed_size)

        return self.fc_out(attention_output)

In [10]:
import torch

# Create an instance of the MultiHeadAttention module
embed_size = 256
heads = 8
attention = MultiHeadAttention(embed_size, heads)

# Generate random input tensors from a dedicated, seeded generator so the
# printed values are the same on every Restart & Run All; the global RNG
# state is left untouched.
batch_size = 4
query_len = 10
key_len = 12
value_len = 12
input_dim = embed_size
rng = torch.Generator().manual_seed(0)
query = torch.randn(batch_size, query_len, input_dim, generator=rng)
key = torch.randn(batch_size, key_len, input_dim, generator=rng)
value = torch.randn(batch_size, value_len, input_dim, generator=rng)

# Compute the attention output
attention_output = attention(query, key, value)

# The output keeps the query's batch and length axes and has embed_size features
assert attention_output.shape == (batch_size, query_len, embed_size)

# Print the shape of the attention output
print(attention_output.shape)
print(query)

torch.Size([4, 10, 256])
tensor([[[-1.5570,  0.4454, -0.3750,  ..., -1.5871,  0.3023,  0.9385],
         [-1.1305,  1.5727,  1.2512,  ...,  2.3726, -0.0412,  0.3728],
         [ 0.9033, -1.4288,  0.1932,  ..., -1.5245, -2.7567, -1.0425],
         ...,
         [-0.7975, -0.5387,  0.7566,  ..., -0.3096,  1.3038, -0.2501],
         [ 0.2265, -1.7868,  0.4760,  ..., -1.0768, -0.3601, -0.7491],
         [-0.9038,  2.5302,  0.6968,  ..., -0.4758, -0.5395, -0.3028]],

        [[-0.9213,  0.2771,  1.5257,  ...,  1.4238,  1.4708, -0.4367],
         [ 0.9364,  0.4308, -2.0451,  ...,  1.3843,  0.0173,  0.4427],
         [ 0.9470, -0.0933,  2.0222,  ...,  0.3437, -1.1608, -0.0854],
         ...,
         [ 0.7012,  0.1685, -0.8767,  ..., -0.0453, -1.4106, -0.7392],
         [ 0.7173,  0.4339, -0.7196,  ..., -0.4849,  0.5645,  1.6565],
         [-0.5110, -0.9833,  0.1971,  ..., -0.5758, -0.8260,  0.1002]],

        [[ 0.1395, -0.5209, -1.1628,  ..., -1.8587,  0.3551,  1.0113],
         [ 0.9975, -