In [None]:
import compyute as cp
import numpy as np

In [None]:
# Problem size: B = batch, T = sequence length, C = embedding dim, H = number of heads.
B = 1
T = 5
C = 8
H = 4

# Random float32 input laid out as (batch, sequence, embedding), the layout
# torch expects with batch_first=True.
x = cp.random.uniform((B, T, C), dtype="float32")

In [None]:
from transformer.transformer_batched import MultiHeadAttention

# Implementation under test: embedding size C, H heads, no bias terms
# (the torch reference below is also built with bias=False).
mha = MultiHeadAttention(
    emb_dim=C,
    n_heads=H,
    bias=False,
)

In [None]:
import torch

# Reference implementation: same embedding size / head count, no biases,
# (batch, seq, feature) input layout to match `x`.
mha_torch = torch.nn.MultiheadAttention(C, H, bias=False, batch_first=True)

in_proj_weights = mha.modules[0].w
out_proj_weights = mha.modules[1].w


def to_torch_param(weight, like):
    """Copy a compyute weight into a fresh torch Parameter, checking its shape.

    Args:
        weight: compyute tensor exposing ``to_numpy()``.
        like: the torch parameter being replaced; only its shape is used.

    Returns:
        A new ``torch.nn.Parameter`` (requires_grad=True by default) holding a
        copy of ``weight``.

    Raises:
        ValueError: if the copied weight's shape differs from ``like.shape``.
            torch does not validate shapes when a parameter attribute is
            reassigned, so without this check a layout mismatch would only
            surface later as an obscure error in the forward pass.
    """
    param = torch.nn.Parameter(torch.tensor(weight.to_numpy()))
    if param.shape != like.shape:
        raise ValueError(
            f"shape mismatch: compyute weight {tuple(param.shape)} "
            f"vs torch parameter {tuple(like.shape)}"
        )
    return param


mha_torch.in_proj_weight = to_torch_param(in_proj_weights, mha_torch.in_proj_weight)
mha_torch.out_proj.weight = to_torch_param(out_proj_weights, mha_torch.out_proj.weight)

In [None]:
# Compyute projection weight shapes: input projection first, then output projection.
print(in_proj_weights.shape, out_proj_weights.shape, sep="\n")

In [None]:
# Torch projection weight shapes, for comparison with the compyute ones above.
print(mha_torch.in_proj_weight.shape, mha_torch.out_proj.weight.shape, sep="\n")

In [None]:
# Forward pass of both implementations on the same input.
# The compyute forward runs under `mha.training()`, as does the later
# `mha.backward` call (presumably so intermediates are kept for backward — confirm).
with mha.training():
    out = mha(x)

# Leaf tensor with requires_grad=True so torch populates `x_torch.grad` on backward.
x_torch = torch.tensor(x.to_numpy(), requires_grad=True)
# Self-attention: query, key and value are all `x_torch`. With need_weights=False
# the second return value is unused, so index it away rather than unpacking into
# `_`, which would shadow IPython's output-history variable.
out_torch = mha_torch(x_torch, x_torch, x_torch, need_weights=False)[0]

In [None]:
# Shared absolute and relative tolerance for every np.allclose comparison below.
tol = 1.0e-6

In [None]:
# Forward outputs of the two implementations should agree within `tol`.
out_np = out.to_numpy()
out_torch_np = out_torch.detach().numpy()
np.allclose(out_np, out_torch_np, rtol=tol, atol=tol)

In [None]:
# Random upstream gradient. Drawn as float32 to match `x` and the float32 torch
# output. Torch would cast a wider dtype down anyway, but compyute would not, so
# without this the two backends would not see the same upstream gradient.
dy = cp.random.normal(out.shape, dtype="float32")
with mha.training():
    dx = mha.backward(dy)
out_torch.backward(torch.tensor(dy.to_numpy()))

# Gradient w.r.t. the input: compyute vs torch.
np.allclose(dx.to_numpy(), x_torch.grad.detach().numpy(), atol=tol, rtol=tol)

In [None]:
# Weight gradients accumulated by compyute's backward pass, one per projection.
in_proj_weight_grads, out_proj_weight_grads = (
    mha.modules[0].w.grad,
    mha.modules[1].w.grad,
)

In [None]:
# Input-projection weight gradient: compyute vs torch.
np.allclose(
    in_proj_weight_grads.to_numpy(),
    mha_torch.in_proj_weight.grad.detach().numpy(),
    rtol=tol,
    atol=tol,
)

In [None]:
# Output-projection weight gradient: compyute vs torch.
np.allclose(
    out_proj_weight_grads.to_numpy(),
    mha_torch.out_proj.weight.grad.detach().numpy(),
    rtol=tol,
    atol=tol,
)