In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
D_MODEL = 100     # feature width of each input position (input/output width of the attention block)
HEAD_SIZE = 10    # width of each attention head; MultiHead builds D_MODEL // HEAD_SIZE heads
BLOCK_SIZE = 3    # number of positions (rows) in the demo input
SEED = 0          # fixed seed so weight init and demo inputs are reproducible on Restart & Run All
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators

In [None]:
class Head(nn.Module):
  """A single head of scaled dot-product self-attention (no causal mask).

  Args:
    head_size: Width of the key/query/value projections.
    d_model: Width of the input features. When omitted, falls back to the
      notebook-level ``D_MODEL`` constant (the original behaviour).

  ``forward`` takes ``x`` of shape ``(..., T, d_model)`` (for example
  ``(BLOCK_SIZE, D_MODEL)`` or ``(batch, BLOCK_SIZE, D_MODEL)``) and returns
  a tensor of shape ``(..., T, head_size)``.
  """

  def __init__(self, head_size, d_model=None):
    super().__init__()
    if d_model is None:
      d_model = D_MODEL
    self.head_size = head_size
    self.key = nn.Linear(d_model, head_size)
    self.query = nn.Linear(d_model, head_size)
    self.value = nn.Linear(d_model, head_size)

  def forward(self, x):
    k = self.key(x)    # (..., T, d_model) -> (..., T, head_size)
    q = self.query(x)  # (..., T, d_model) -> (..., T, head_size)
    v = self.value(x)  # (..., T, d_model) -> (..., T, head_size)
    # Pairwise scores between positions. transpose(-2, -1) swaps only the last
    # two axes; `.T` would reverse every axis of a batched tensor.
    # Dividing by sqrt(head_size) keeps the logits from growing with the head
    # width, so the softmax does not saturate.
    scores = q @ k.transpose(-2, -1) * self.head_size ** -0.5  # (..., T, T)
    # Normalise over the key axis: each query's weights sum to 1.
    weights = F.softmax(scores, dim=-1)
    return weights @ v  # (..., T, T) @ (..., T, head_size) -> (..., T, head_size)



In [None]:
class MultiHead(nn.Module):
  """Multi-head self-attention built from ``D_MODEL // head_size`` Head modules.

  Every head sees the same input; their outputs are concatenated on the last
  axis, so the output's last dimension is ``num_heads * head_size == D_MODEL``.

  Args:
    head_size: Width of each head. Must be positive and divide ``D_MODEL``.
  """

  def __init__(self, head_size):
    super().__init__()
    # Without this, a negative head_size passes the divisibility check and
    # yields zero heads, which only fails later inside torch.cat.
    assert head_size > 0, "head_size must be positive"
    assert D_MODEL % head_size == 0, "D_MODEL must be divisible by head_size"
    self.num_heads = D_MODEL // head_size
    self.heads = nn.ModuleList([Head(head_size) for _ in range(self.num_heads)])

  def forward(self, x):
    # No softmax here. Normalisation of attention weights belongs inside each
    # head. A softmax across the concatenated features would couple the
    # independent heads and squash the output into a probability distribution.
    # NOTE(review): standard multi-head attention also applies a linear output
    # projection after the concat; add one if this is used in a real model.
    return torch.cat([head(x) for head in self.heads], dim=-1)


In [None]:
# Smoke test: one random sequence through the multi-head block.
X = torch.randn((BLOCK_SIZE, D_MODEL), device=DEVICE)  # allocate on the target device directly
model = MultiHead(HEAD_SIZE).to(DEVICE)
with torch.no_grad():  # forward pass only; no autograd graph needed
  demo_out = model(X)
# Concatenating D_MODEL // HEAD_SIZE heads of width HEAD_SIZE restores width D_MODEL.
assert demo_out.shape == (BLOCK_SIZE, D_MODEL), demo_out.shape
demo_out.shape