In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Projects the input into queries, keys and values with three independent
    ``nn.Linear(embed_dim, embed_dim)`` layers and computes
    ``softmax(Q @ K^T / sqrt(embed_dim)) @ V``.

    Args:
        embed_dim: Size of the last dimension of the input. The projections
            preserve it, so the output has the same shape as the input.
    """

    def __init__(self, embed_dim):
        super().__init__()
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key   = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        # Divide scores by sqrt(d) so their magnitude does not grow with
        # embed_dim and push the softmax into saturation.
        self.scale = embed_dim ** 0.5

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape (batch, seq_len, embed_dim).
            mask: Optional tensor broadcastable to (batch, seq_len, seq_len).
                - Boolean: ``True`` marks positions that may be attended to
                  (same convention as ``F.scaled_dot_product_attention``).
                  ``False`` positions receive zero attention weight.
                - Any other dtype: added to the raw scores as a bias (e.g.
                  ``-inf`` to block a position).
                - ``None`` (default): every position is attended to.
                A query row with all positions blocked yields NaN weights.

        Returns:
            Tuple ``(out, weights)``:
                - ``out`` has shape (batch, seq_len, embed_dim).
                - ``weights`` has shape (batch, seq_len, seq_len); each row
                  is a softmax distribution over key positions.
        """
        Q, K, V = self.query(x), self.key(x), self.value(x)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        if mask is not None:
            if mask.dtype == torch.bool:
                # -inf before softmax gives exactly zero weight for blocked keys.
                scores = scores.masked_fill(~mask, float("-inf"))
            else:
                scores = scores + mask
        weights = F.softmax(scores, dim=-1)
        out = torch.matmul(weights, V)
        return out, weights

In [3]:
# Example: run the layer on a random batch and check the output shapes.
# Seed first: both torch.randn and the nn.Linear initialisation are random,
# so without it Restart & Run All produces different tensors each time.
torch.manual_seed(0)
batch, seq_len, embed_dim = 2, 4, 8
x = torch.randn(batch, seq_len, embed_dim)
sa = SelfAttention(embed_dim)
out, attn = sa(x)
print("Output shape:", out.shape)
print("Attention weights shape:", attn.shape)

# Invariants:
# - the output keeps the input shape;
# - each row of attention weights is a probability distribution.
assert out.shape == (batch, seq_len, embed_dim)
assert attn.shape == (batch, seq_len, seq_len)
assert torch.allclose(attn.sum(dim=-1), torch.ones(batch, seq_len))

Output shape: torch.Size([2, 4, 8])
Attention weights shape: torch.Size([2, 4, 4])
