In [94]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

## `nn.MultiheadAttention` vs. a manual implementation

#### nn.MultiheadAttention

In [217]:
# Seed the global RNG first: the module's projection weights are randomly
# initialised, so without a seed every kernel restart yields different numbers.
torch.manual_seed(0)

# Single-head attention without biases. Queries live in a 2-dim embedding space
# while keys/values arrive with 3 features each (kdim = vdim = 3).
mha = nn.MultiheadAttention(embed_dim=2, num_heads=1, bias=False, kdim=3, vdim=3)

In [218]:
# Show the shape of every learnable tensor held by the attention module.
for param in mha.parameters():
    print(param.shape)

torch.Size([2, 2])
torch.Size([2, 3])
torch.Size([2, 3])
torch.Size([2, 2])


In [219]:
# Grab the module's own projection matrices so the manual computation below
# uses exactly the same weights.
# Shapes as printed above: q (2, 2), k (2, 3), v (2, 3), out (2, 2).
q_proj_w, k_proj_w, v_proj_w = (
    mha.q_proj_weight,
    mha.k_proj_weight,
    mha.v_proj_weight,
)
out_proj_w = mha.out_proj.weight

In [220]:
# Random inputs in sequence-first layout: (sequence length, batch size, features).
q = torch.rand((5, 3, 2))  # 5 target positions, batch of 3, 2 query features
k = torch.rand((6, 3, 3))  # 6 source positions, batch of 3, 3 key features
v = torch.rand((6, 3, 3))  # same layout as k, 3 value features

In [221]:
# Reference result from the built-in module: the attended output and the
# attention weights it used.
attn_out, attn_wts = mha(query=q, key=k, value=v)

In [222]:
# Output shape first, then attention-weight shape, one per line.
print(attn_out.shape, attn_wts.shape, sep="\n")

torch.Size([5, 3, 2])
torch.Size([3, 5, 6])


#### Manual implementation

In [223]:
# Input projections without bias: multiply each input by the transpose of the
# corresponding weight matrix taken from the module.
proj_q = torch.matmul(q, q_proj_w.T)
proj_k = torch.matmul(k, k_proj_w.T)
proj_v = torch.matmul(v, v_proj_w.T)
print(proj_q.shape, proj_k.shape, proj_v.shape, sep="\n")

torch.Size([5, 3, 2])
torch.Size([6, 3, 2])
torch.Size([6, 3, 2])


In [230]:
# Scaled dot-product attention scores.
#   q: (L, B, E) -> (B, L, E)
#   k: (S, B, E) -> (B, E, S)
#   logits = q @ k -> (B, L, S)
logits = torch.bmm(proj_q.transpose(0, 1), proj_k.permute(1, 2, 0))

# Scale by sqrt(head_dim), read from the projected queries instead of a
# hard-coded sqrt(2); with a single head, head_dim is the full projected width.
head_dim = proj_q.size(-1)

# Softmax over the last axis (source positions S) so each query's weights sum to 1.
weights = F.softmax(logits / head_dim ** 0.5, dim=-1)
print(weights.shape)

torch.Size([3, 5, 6])


In [231]:
# The hand-computed attention weights should match the module's, up to
# floating-point tolerance.
weights.allclose(attn_wts)

True

In [226]:
# Weighted sum of the projected values:
#   weights (B, L, S) @ proj_v (S, B, E)->(B, S, E)  =  (B, L, E)
attended = torch.bmm(weights, proj_v.transpose(0, 1))

# Return to sequence-first layout (L, B, E), then apply the bias-free output
# projection taken from the module.
out = F.linear(attended.transpose(0, 1), out_proj_w)
print(out.shape)

torch.Size([5, 3, 2])


In [227]:
# The manual attention output should match the module's output, up to
# floating-point tolerance.
out.allclose(attn_out)

True