In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy configuration: a six-token vocabulary, six-dimensional embeddings,
# and two attention heads.
vocab_size, embedding_dim, num_heads = 6, 6, 2

# Seed the RNG *before* the embedding table is built so that its randomly
# initialised weights are reproducible from run to run.
torch.manual_seed(42)
embedding_layer = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)

# Look up every token id once, in ascending order.
word_indices = torch.arange(0, vocab_size)
word_embeddings = embedding_layer(word_indices)

# Define Multi-Head Self-Attention Class
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values with bias-free linear
    layers, split into ``num_heads`` heads of size ``embed_dim // num_heads``,
    attended independently per head, re-merged, and passed through a final
    linear output projection.

    Args:
        embed_dim: Size of the last dimension of the input and output.
        num_heads: Number of attention heads; must divide ``embed_dim``.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()

        if embed_dim % num_heads != 0:
            raise ValueError("Embedding dimension must be divisible by the number of heads")

        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Creation order (Q, K, V, out) is kept stable so that seeded
        # initialisation yields the same weights as before.
        self.W_Q = nn.Linear(embed_dim, embed_dim, bias=False)
        self.W_K = nn.Linear(embed_dim, embed_dim, bias=False)
        self.W_V = nn.Linear(embed_dim, embed_dim, bias=False)
        self.W_out = nn.Linear(embed_dim, embed_dim)

    def _split_heads(self, t, batch_size, seq_length):
        """Reshape (batch, seq, embed) into (batch, heads, seq, head_dim)."""
        t = t.view(batch_size, seq_length, self.num_heads, self.head_dim)
        return t.transpose(1, 2)

    def forward(self, x):
        """Apply self-attention to a batch of sequences.

        Args:
            x: Tensor of shape ``(batch_size, seq_length, embed_dim)``.

        Returns:
            A tuple ``(output, attention_weights)`` where ``output`` has shape
            ``(batch_size, seq_length, embed_dim)`` and ``attention_weights``
            has shape ``(batch_size, num_heads, seq_length, seq_length)``;
            each row of the weights sums to 1 over the last dimension.
        """
        batch_size, seq_length, embed_dim = x.shape

        # The batch axis must be carried through every reshape; dropping it
        # only works by accident when batch_size == 1.
        Q = self._split_heads(self.W_Q(x), batch_size, seq_length)
        K = self._split_heads(self.W_K(x), batch_size, seq_length)
        V = self._split_heads(self.W_V(x), batch_size, seq_length)

        # Scale by sqrt(head_dim) to keep the logits in a range where the
        # softmax does not saturate.
        scale = self.head_dim ** 0.5
        attention_scores = (Q @ K.transpose(-2, -1)) / scale
        attention_weights = F.softmax(attention_scores, dim=-1)

        attention_output = attention_weights @ V
        # Merge heads back: (batch, heads, seq, head_dim) -> (batch, seq, embed).
        # contiguous() is required because transpose returns a strided view.
        attention_output = (
            attention_output.transpose(1, 2)
            .contiguous()
            .view(batch_size, seq_length, embed_dim)
        )

        output = self.W_out(attention_output)
        return output, attention_weights

# Build the attention module from the configuration defined earlier
self_attention = MultiHeadSelfAttention(embed_dim=embedding_dim, num_heads=num_heads)

# Add a leading batch axis to the embeddings, then run one forward pass
word_embeddings = torch.unsqueeze(word_embeddings, 0)
output, attention_weights = self_attention(word_embeddings)

# Show the per-head attention weights followed by the projected output
print("Attention Weights:\n", attention_weights)
print("\nFinal Output After Attention:\n", output)