# In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Queries, keys and values are produced from the same input by one fused
    linear projection, split into ``num_heads`` heads of size
    ``dim // num_heads``, attended independently, then concatenated and passed
    through an output projection.

    Args:
        dim: Size of the feature (last) dimension of the input and output.
        num_heads: Number of attention heads; must divide ``dim`` evenly.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim: int, num_heads: int) -> None:
        super().__init__()
        # Raise explicitly instead of using `assert`: assertions are removed
        # under `python -O`, which would let an invalid configuration through
        # and only fail later with an opaque reshape error in forward().
        if dim % num_heads != 0:
            raise ValueError("dim must be divisible by num_heads")
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.qkv_proj = nn.Linear(dim, dim * 3)  # Fused projection to Q, K, V
        self.out_proj = nn.Linear(dim, dim)  # Final projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention over the sequence dimension.

        Args:
            x: Tensor of shape ``[batch, seq_len, dim]``.

        Returns:
            Tensor of shape ``[batch, seq_len, dim]``.
        """
        batch_size, seq_len, dim = x.shape
        qkv = self.qkv_proj(x)  # [B, S, 3*D]
        # Split the fused projection into (Q|K|V) x heads x per-head features.
        qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(dim=2)  # Each is [B, S, H, D_head]

        # Move heads ahead of the sequence axis: [B, H, S, D_head]
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        # Scaled dot-product attention; dividing by sqrt(D_head) keeps the
        # logits' magnitude independent of the head size.
        scores = torch.matmul(q, k.transpose(-2, -1)) / (
            self.head_dim**0.5
        )  # [B, H, S, S]
        attn = F.softmax(scores, dim=-1)  # normalise over key positions
        context = torch.matmul(attn, v)  # [B, H, S, D_head]

        # Merge the heads back into a single feature dimension.
        context = context.transpose(1, 2).reshape(batch_size, seq_len, dim)  # [B, S, D]
        return self.out_proj(context)


class FeedForward(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) applied to the last dimension.

    Args:
        dim: Input and output feature size.
        hidden_dim: Width of the intermediate layer.
    """

    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)  # dim -> hidden_dim
        self.fc2 = nn.Linear(hidden_dim, dim)  # hidden_dim -> dim

    def forward(self, x):
        """Return ``fc2(relu(fc1(x)))``."""
        hidden = self.fc1(x)
        activated = F.relu(hidden)
        return self.fc2(activated)


class TransformerEncoderLayer(nn.Module):
    """Single encoder layer: self-attention followed by a feed-forward MLP.

    Each sub-block's output is added to its input (residual connection) and
    the sum is then layer-normalised.

    Args:
        dim: Feature size of the input and output.
        num_heads: Number of attention heads.
        ff_hidden_dim: Hidden width of the feed-forward sub-block.
    """

    def __init__(self, dim, num_heads, ff_hidden_dim):
        super().__init__()
        # Attention sub-block and the LayerNorm applied after its residual add.
        self.self_attn = MultiHeadSelfAttention(dim, num_heads)
        self.norm1 = nn.LayerNorm(dim)
        # Feed-forward sub-block and the LayerNorm applied after its residual add.
        self.ff = FeedForward(dim, ff_hidden_dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        """Run both sub-blocks on ``x`` and return the normalised result."""
        # Residual + norm around self-attention.
        attended = self.norm1(x + self.self_attn(x))
        # Residual + norm around the feed-forward network.
        return self.norm2(attended + self.ff(attended))


# Full transformer encoder (stacked layers)
class TransformerEncoder(nn.Module):
    """Stack of ``num_layers`` encoder layers applied one after another.

    Args:
        dim: Feature size of the input and output.
        num_heads: Number of attention heads in every layer.
        ff_hidden_dim: Hidden width of every layer's feed-forward sub-block.
        num_layers: How many encoder layers to stack.
    """

    def __init__(self, dim, num_heads, ff_hidden_dim, num_layers):
        super().__init__()
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(TransformerEncoderLayer(dim, num_heads, ff_hidden_dim))

    def forward(self, x):
        """Feed ``x`` through every layer in order and return the final output."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return hidden


# Demo usage: build a small encoder and push one random batch through it.
if __name__ == "__main__":
    # Input dimensions.
    batch_size, seq_len, dim = 2, 10, 64
    # Encoder hyperparameters.
    num_heads, ff_hidden_dim, num_layers = 8, 256, 2

    model = TransformerEncoder(
        dim=dim,
        num_heads=num_heads,
        ff_hidden_dim=ff_hidden_dim,
        num_layers=num_layers,
    )
    x = torch.randn(batch_size, seq_len, dim)
    out = model(x)
    print("Output shape:", out.shape)

# Output shape: torch.Size([2, 10, 64])
