In [14]:
import compyute as cp
import numpy as np
import torch

from transformer import MultiHeadAttention

In [15]:
# Problem size (torch side uses batch_first=True, so x is (batch, seq, channels)).
B = 1  # batch size
T = 5  # sequence length
C = 8  # channels / embedding dim
H = 4  # number of attention heads

# Random compyute input and a PyTorch copy with the same values. The copy tracks
# gradients so its `.grad` can be compared with the compyute backward pass.
# NOTE(review): no random seed is set here, so the inputs differ between runs.
x = cp.random.uniform((B, T, C), dtype=cp.float32)
x_torch = torch.tensor(x.to_numpy(), requires_grad=True)

In [16]:
# Module under test: the compyute-based MultiHeadAttention (imported in the top cell).
mha = MultiHeadAttention(in_channels=C, n_heads=H)

In [17]:
mha_torch = torch.nn.MultiheadAttention(C, H, bias=False, batch_first=True)

# PyTorch computes the query, key and value projections of all heads in one batched
# matrix multiplication, so it stores them as a single stacked input-projection
# matrix with rows ordered (query, key, value). Build the same matrix from compyute.
in_proj_weights = cp.concatenate(
    [mha.query_proj.w, mha.key_proj.w, mha.value_proj.w],
    axis=0,
)
# The output projection is the last sub-module of the compyute MHA.
out_proj_weights = mha.modules[-1].w

# Assigning a new Parameter does not check shapes, so verify the compyute weights
# match the layout PyTorch initialised before replacing its parameters.
assert tuple(in_proj_weights.shape) == tuple(mha_torch.in_proj_weight.shape), "in_proj shape mismatch"
assert tuple(out_proj_weights.shape) == tuple(mha_torch.out_proj.weight.shape), "out_proj shape mismatch"

mha_torch.in_proj_weight = torch.nn.Parameter(torch.tensor(in_proj_weights.to_numpy()))
mha_torch.out_proj.weight = torch.nn.Parameter(torch.tensor(out_proj_weights.to_numpy()))

In [18]:
# Shapes of the stacked compyute weights copied into PyTorch (in_proj, then out_proj).
for compyute_weight in (in_proj_weights, out_proj_weights):
    print(compyute_weight.shape)

(24, 8)
(8, 8)


In [19]:
# Shapes of the corresponding PyTorch parameters after the copy (in_proj, then out_proj).
for torch_weight in (mha_torch.in_proj_weight, mha_torch.out_proj.weight):
    print(torch_weight.shape)

torch.Size([24, 8])
torch.Size([8, 8])


In [20]:
# Forward pass through both implementations on identical inputs (self-attention:
# query, key and value are all the input). With need_weights=False PyTorch does not
# return attention weights, so only the first element of the result is kept.
out_torch = mha_torch(query=x_torch, key=x_torch, value=x_torch, need_weights=False)[0]

# NOTE(review): train mode presumably keeps what `mha.backward` needs later — confirm.
with mha.train():
    out = mha(x)

In [21]:
tol = 1.0e-6  # used as both atol and rtol in the np.allclose comparisons

In [22]:
# Forward check: the compyute output must match PyTorch's within `tol`.
forward_match = np.allclose(
    out.to_numpy(),
    out_torch.detach().numpy(),
    atol=tol,
    rtol=tol,
)
# Fail loudly so Restart & Run All catches a mismatch; still display the result.
assert forward_match, "MHA forward output differs from torch.nn.MultiheadAttention"
forward_match

True

In [23]:
# Random upstream gradient, fed to both implementations.
dy = cp.random.normal(out.shape)

# PyTorch accumulates into `.grad`, whereas `dx` is only the value returned by this
# backward call. Clear the input grad so re-running the forward and backward cells
# does not compare a summed gradient against a single one.
# NOTE(review): `mha_torch` parameter grads also accumulate across runs; whether
# compyute's `w.grad` does is not visible here. If the weight-gradient checks fail
# after a re-run, restart the kernel.
x_torch.grad = None

with mha.train():
    dx = mha.backward(dy)
out_torch.backward(torch.tensor(dy.to_numpy()))

# Input-gradient check: compyute's dx must match PyTorch's x.grad within `tol`.
input_grad_match = np.allclose(
    dx.to_numpy(),
    x_torch.grad.detach().numpy(),
    atol=tol,
    rtol=tol,
)
assert input_grad_match, "MHA input gradient differs from torch.nn.MultiheadAttention"
input_grad_match

True

In [24]:
# Stack the compyute projection-weight gradients in PyTorch's (query, key, value)
# row order so they line up with `mha_torch.in_proj_weight.grad`.
in_proj_weight_grads = cp.concatenate(
    [mha.query_proj.w.grad, mha.key_proj.w.grad, mha.value_proj.w.grad],
    axis=0,
)
# Gradient of the output projection (the last sub-module of the compyute MHA).
out_proj_weight_grads = mha.modules[-1].w.grad

In [25]:
# Stacked query/key/value projection-weight gradients must match PyTorch's within `tol`.
in_proj_grad_match = np.allclose(
    in_proj_weight_grads.to_numpy(),
    mha_torch.in_proj_weight.grad.detach().numpy(),
    atol=tol,
    rtol=tol,
)
assert in_proj_grad_match, "input-projection weight gradients differ from PyTorch"
in_proj_grad_match

True

In [26]:
# Output-projection weight gradient must match PyTorch's within `tol`.
out_proj_grad_match = np.allclose(
    out_proj_weight_grads.to_numpy(),
    mha_torch.out_proj.weight.grad.detach().numpy(),
    atol=tol,
    rtol=tol,
)
assert out_proj_grad_match, "output-projection weight gradient differs from PyTorch"
out_proj_grad_match

True