In [None]:
import compyute as cp
import numpy as np
import torch

In [None]:
B = 1  # batch size
T = 5  # sequence length
C = 8  # embedding dimension
H = 4  # number of attention heads

# Shared random input; the torch copy tracks gradients so dx can be compared later.
x = cp.random.uniform((B, T, C), dtype="float32")
x_torch = torch.tensor(x.to_numpy(), requires_grad=True)

In [None]:
from transformer.transformer import MultiHeadAttention

# compyute MHA under test. Biases are disabled, so only the weight matrices
# have to be mirrored into the torch reference module.
mha = MultiHeadAttention(n_heads=H, emb_dim=C, bias=False)

In [None]:

mha_torch = torch.nn.MultiheadAttention(C, H, bias=False, batch_first=True)

# PyTorch packs the Q, K and V projections of all heads into a single
# `in_proj_weight` of shape (3 * C, C), stacked as [W_q; W_k; W_v]. Within each
# (C, C) slice, head h owns a contiguous block of C // H rows. The per-head
# compyute weights are therefore concatenated head by head: all Q heads, then
# all K heads, then all V heads.
# The list is built from H rather than hard-coded, so changing H in the config
# cell keeps the packing correct.
# NOTE(review): assumes compyute stores Linear weights as (out_features, in_features)
# like torch; the shape assertions below fail loudly otherwise.
heads = [mha.modules[0].modules[h] for h in range(H)]
in_proj_weights = cp.concatenate(
    [head.q.w for head in heads]
    + [head.k.w for head in heads]
    + [head.v.w for head in heads],
    axis=0,
)
out_proj_weights = mha.modules[1].w

assert tuple(in_proj_weights.shape) == tuple(mha_torch.in_proj_weight.shape), in_proj_weights.shape
assert tuple(out_proj_weights.shape) == tuple(mha_torch.out_proj.weight.shape), out_proj_weights.shape

# nn.Parameter already sets requires_grad=True, so the tensors need no flag of their own.
mha_torch.in_proj_weight = torch.nn.Parameter(torch.tensor(in_proj_weights.to_numpy()))
mha_torch.out_proj.weight = torch.nn.Parameter(torch.tensor(out_proj_weights.to_numpy()))

In [None]:
# Shapes of the packed compyute weights, one per line (expected (3*C, C) and (C, C)).
print(in_proj_weights.shape, out_proj_weights.shape, sep="\n")

In [None]:
# Shapes of the torch parameters that received the compyute weights, one per line.
print(mha_torch.in_proj_weight.shape, mha_torch.out_proj.weight.shape, sep="\n")

In [None]:
# Forward pass through both implementations on the same input
# (self-attention: query, key and value are all x).
with mha.training():
    out = mha(x)

out_torch, _ = mha_torch(query=x_torch, key=x_torch, value=x_torch, need_weights=False)

In [None]:
# Shared absolute and relative tolerance for the element-wise comparisons below.
# NOTE(review): 1e-6 is close to float32 resolution for O(1) values; loosen if checks flake.
tol = 1.0e-6

In [None]:
# Forward outputs of compyute and torch should agree element-wise.
np.allclose(a=out.to_numpy(), b=out_torch.detach().numpy(), rtol=tol, atol=tol)

In [None]:
# Backpropagate the same upstream gradient through both implementations and
# compare the resulting input gradients.
# dy is drawn explicitly as float32, matching the float32 input x, so both
# backward passes run at the same precision.
# NOTE(review): this cell is not re-runnable on its own. torch frees its graph after
# the first backward, and gradients would accumulate. Re-run the forward cell first.
dy = cp.random.normal(out.shape, dtype="float32")
with mha.training():
    dx = mha.backward(dy)
out_torch.backward(torch.tensor(dy.to_numpy()))

np.allclose(dx.to_numpy(), x_torch.grad.detach().numpy(), atol=tol, rtol=tol)

In [None]:
# Gather compyute's per-head projection-weight gradients in the same packed layout
# as torch's in_proj_weight.grad: [dW_q; dW_k; dW_v], head by head within each.
# The list is built from H instead of hard-coding 4 heads, so it stays correct if H changes.
heads = [mha.modules[0].modules[h] for h in range(H)]
in_proj_weight_grads = cp.concatenate(
    [head.q.w.grad for head in heads]
    + [head.k.w.grad for head in heads]
    + [head.v.w.grad for head in heads],
    axis=0,
)
out_proj_weight_grads = mha.modules[1].w.grad

In [None]:
# Packed Q/K/V projection-weight gradients should agree element-wise.
np.allclose(a=in_proj_weight_grads.to_numpy(), b=mha_torch.in_proj_weight.grad.detach().numpy(), rtol=tol, atol=tol)

In [None]:
# Output-projection weight gradients should agree element-wise.
np.allclose(a=out_proj_weight_grads.to_numpy(), b=mha_torch.out_proj.weight.grad.detach().numpy(), rtol=tol, atol=tol)