In [None]:
import compyute as cp
import numpy as np
import torch

from mha_semibatched import MultiHeadAttention

In [None]:
SEED = 42
# Seed every library that draws random numbers. compyute generates the input
# and the upstream gradient (and presumably the layer's initial weights);
# torch is seeded as well so the reference module's initialisation is
# reproducible, even though its projection weights are overwritten later.
torch.manual_seed(SEED)
cp.random.set_seed(SEED)

In [None]:
# Problem size: batch B, sequence length T, embedding width C, heads H.
# Kept tiny so the comparison is fast and easy to inspect by hand; a larger,
# more realistic configuration is B, T, C, H = 16, 256, 384, 6.
# B > 1 and H > 1 so that batch handling and the per-head split/merge are
# actually exercised (with H = 1 multi-head attention degenerates to a single head).
B, T, C, H = 2, 4, 8, 2
assert C % H == 0, "embedding width C must be divisible by the number of heads H"

x = cp.random.uniform((B, T, C), dtype=cp.float32)
# Same values as a torch leaf tensor so PyTorch records the input gradient.
x_torch = torch.tensor(x.to_numpy(), requires_grad=True)

In [None]:
# Implementation under test. Bias is disabled so it matches the bias-free
# PyTorch reference module constructed in the next cell.
mha = MultiHeadAttention(in_channels=C, n_heads=H, bias=False)

In [None]:
# Reference implementation: same width, head count and (absent) bias as `mha`,
# with batch-first inputs to match the (B, T, C) layout of `x`.
mha_torch = torch.nn.MultiheadAttention(
    embed_dim=C,
    num_heads=H,
    bias=False,
    batch_first=True,
)

# PyTorch fuses the query, key and value projections of all heads into one
# `in_proj_weight`, stacked in q, k, v order along dim 0. Build that matrix
# from compyute's separate projections so both modules compute the same thing.
qkv_weights = [proj.w for proj in (mha.q_proj, mha.k_proj, mha.v_proj)]
in_proj_weights = cp.concat(qkv_weights, dim=0)
out_proj_weights = mha.out_proj.w

# Replace torch's randomly initialised weights with copies of compyute's.
in_proj_param = torch.nn.Parameter(torch.tensor(in_proj_weights.to_numpy()))
out_proj_param = torch.nn.Parameter(torch.tensor(out_proj_weights.to_numpy()))
mha_torch.in_proj_weight = in_proj_param
mha_torch.out_proj.weight = out_proj_param

In [None]:
# Forward pass through both implementations. For self-attention the same
# tensor serves as query, key and value; attention weights are not requested,
# so torch's second return value is discarded.
out = mha(x)
out_torch, _ = mha_torch(
    query=x_torch,
    key=x_torch,
    value=x_torch,
    need_weights=False,
)

In [None]:
# Shared absolute/relative tolerance for all comparisons below: float32
# rounding errors accumulate across the projections, softmax and their gradients.
tol = 1.0e-5

In [None]:
# Forward check: compyute output vs. PyTorch output.
# `assert_allclose` raises (with a max-error report) on mismatch, so
# Restart & Run All fails loudly instead of silently displaying `False`.
np.testing.assert_allclose(
    out.to_numpy(),
    out_torch.detach().numpy(),
    rtol=tol,
    atol=tol,
    equal_nan=False,
    err_msg="forward output mismatch",
)

In [None]:
# Random upstream gradient, fed to both implementations. Generated as float32
# to match the dtype of `x`, so both backends work at the same precision.
dy = cp.random.normal(out.shape, dtype=cp.float32)
dx = mha.backward(dy)

# Clear torch gradients left over from an earlier run of forward + backward;
# torch accumulates into `.grad`, which would otherwise skew the checks below.
# NOTE(review): whether compyute's `.grad` accumulates too is not visible here — confirm.
x_torch.grad = None
mha_torch.zero_grad(set_to_none=True)
out_torch.backward(torch.tensor(dy.to_numpy()))

In [None]:
# Input-gradient check: compyute's dx vs. PyTorch's x_torch.grad.
# Raises with a diagnostic report on mismatch instead of displaying `False`.
np.testing.assert_allclose(
    dx.to_numpy(),
    x_torch.grad.detach().numpy(),
    rtol=tol,
    atol=tol,
    equal_nan=False,
    err_msg="input gradient mismatch",
)

In [None]:
# Arrange compyute's weight gradients in PyTorch's layout: the q, k and v
# projection gradients are stacked along dim 0 to mirror `in_proj_weight`.
qkv_grads = [proj.w.grad for proj in (mha.q_proj, mha.k_proj, mha.v_proj)]
in_proj_weight_grads = cp.concat(qkv_grads, dim=0)
out_proj_weight_grads = mha.out_proj.w.grad

In [None]:
# Fused q/k/v projection weight-gradient check against `in_proj_weight.grad`.
# Raises with a diagnostic report on mismatch instead of displaying `False`.
np.testing.assert_allclose(
    in_proj_weight_grads.to_numpy(),
    mha_torch.in_proj_weight.grad.detach().numpy(),
    rtol=tol,
    atol=tol,
    equal_nan=False,
    err_msg="in_proj weight gradient mismatch",
)

In [None]:
# Output projection weight-gradient check against `out_proj.weight.grad`.
# Raises with a diagnostic report on mismatch instead of displaying `False`.
np.testing.assert_allclose(
    out_proj_weight_grads.to_numpy(),
    mha_torch.out_proj.weight.grad.detach().numpy(),
    rtol=tol,
    atol=tol,
    equal_nan=False,
    err_msg="out_proj weight gradient mismatch",
)