# MHA Sequential Implementation Verification

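Compares the sequential multi-head attention (MHA) implementation against PyTorch's `torch.nn.MultiheadAttention` to verify that its forward pass and gradients are correct.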

In [1]:
import compyute as cp
import numpy as np
import torch

# Seed compyute's RNG so the random tensors generated below are reproducible.
cp.random.set_seed(42)

# Absolute and relative tolerance shared by the comparisons against PyTorch;
# floating point errors accumulate over the many operations in attention.
tol = 1e-5

In [2]:
B = 16   # batch size (dim 0, matching batch_first=True below)
T = 256  # sequence length
C = 384  # embedding / channel dimension
H = 6    # number of attention heads
input_shape = (B, T, C)

x = cp.random.uniform(input_shape, dtype=cp.float32)
# Same values as a PyTorch leaf tensor; requires_grad so backward fills x_torch.grad.
x_torch = torch.tensor(x.to_numpy(), requires_grad=True)

In [3]:
from mha_sequential import SequentialMHA

# Implementation under test and the PyTorch reference, configured identically:
# C-dimensional embeddings, H heads, no bias terms. batch_first=True makes
# PyTorch accept the (B, T, C) input layout used in this notebook.
mha = SequentialMHA(in_channels=C, n_heads=H, bias=False)
mha_torch = torch.nn.MultiheadAttention(
    embed_dim=C,
    num_heads=H,
    bias=False,
    batch_first=True,
)

In [4]:
# PyTorch implements MHA as a batched matrix multiplication, they therefore only have one
# input proj matrix containing queries, keys and values for all heads
# PyTorch implements MHA as one batched matrix multiplication, so it has a single
# input projection matrix containing the query, key and value weights of all heads.
# Its rows are ordered [q of every head, k of every head, v of every head], so the
# per-head projections are stacked in that order. The heads are collected over
# range(H) rather than by hard-coded indices so this cell keeps working if H changes.
heads = [mha.heads[i] for i in range(H)]
in_proj_weights = cp.concat(
    [head.q_proj.w for head in heads]
    + [head.k_proj.w for head in heads]
    + [head.v_proj.w for head in heads],
    dim=0,
)
out_proj_weights = mha.out_proj.w

in_proj_np = in_proj_weights.to_numpy()
out_proj_np = out_proj_weights.to_numpy()

# Fail early with a readable message if the assembled layout does not match the
# parameters being replaced (otherwise the error only surfaces inside PyTorch).
assert in_proj_np.shape == tuple(mha_torch.in_proj_weight.shape), (
    f"in_proj shape {in_proj_np.shape} != {tuple(mha_torch.in_proj_weight.shape)}"
)
assert out_proj_np.shape == tuple(mha_torch.out_proj.weight.shape), (
    f"out_proj shape {out_proj_np.shape} != {tuple(mha_torch.out_proj.weight.shape)}"
)

mha_torch.in_proj_weight = torch.nn.Parameter(torch.tensor(in_proj_np))
mha_torch.out_proj.weight = torch.nn.Parameter(torch.tensor(out_proj_np))

# forward pass
out = mha(x)
out_torch, _ = mha_torch(x_torch, x_torch, x_torch, need_weights=False)

Check whether the forward outputs match PyTorch

In [5]:
# Compare element-wise within `tol`. Assert so that a mismatch fails
# Restart & Run All loudly instead of just displaying False; the result
# is still shown as the cell output.
out_np = out.to_numpy()
out_torch_np = out_torch.detach().numpy()
outputs_match = np.allclose(out_np, out_torch_np, atol=tol, rtol=tol)
assert outputs_match, (
    f"forward outputs differ from PyTorch (max abs diff {np.abs(out_np - out_torch_np).max():.3e})"
)
outputs_match

True

In [6]:
# backward pass
dy = cp.random.normal(out.shape)
dx = mha.backward(dy)
out_torch.backward(torch.tensor(dy.to_numpy()))

Check whether the input gradients match PyTorch

In [7]:
# Gradient w.r.t. the input: assert (not just display) so a regression fails
# Restart & Run All; the boolean is still shown as the cell output.
dx_np = dx.to_numpy()
dx_torch_np = x_torch.grad.detach().numpy()
input_grads_match = np.allclose(dx_np, dx_torch_np, atol=tol, rtol=tol)
assert input_grads_match, (
    f"input gradients differ from PyTorch (max abs diff {np.abs(dx_np - dx_torch_np).max():.3e})"
)
input_grads_match

True

Check whether the input projection weight gradients match PyTorch

In [8]:
# Stack the per-head projection weight gradients in the same row layout as
# PyTorch's in_proj_weight: [q of every head, k of every head, v of every head].
# Heads are collected over range(H) rather than by hard-coded indices so this
# cell keeps working if H changes.
heads = [mha.heads[i] for i in range(H)]
in_proj_weight_grads = cp.concat(
    [head.q_proj.w.grad for head in heads]
    + [head.k_proj.w.grad for head in heads]
    + [head.v_proj.w.grad for head in heads],
    dim=0,
)

# Assert (not just display) so a regression fails Restart & Run All.
in_grad_np = in_proj_weight_grads.to_numpy()
in_grad_torch_np = mha_torch.in_proj_weight.grad.detach().numpy()
in_proj_grads_match = np.allclose(in_grad_np, in_grad_torch_np, atol=tol, rtol=tol)
assert in_proj_grads_match, (
    f"input projection gradients differ from PyTorch "
    f"(max abs diff {np.abs(in_grad_np - in_grad_torch_np).max():.3e})"
)
in_proj_grads_match

True

Check whether the output projection weight gradients match PyTorch

In [9]:
# Output projection weight gradient: assert (not just display) so a regression
# fails Restart & Run All; the boolean is still shown as the cell output.
out_grad_np = mha.out_proj.w.grad.to_numpy()
out_grad_torch_np = mha_torch.out_proj.weight.grad.detach().numpy()
out_proj_grads_match = np.allclose(out_grad_np, out_grad_torch_np, atol=tol, rtol=tol)
assert out_proj_grads_match, (
    f"output projection gradients differ from PyTorch "
    f"(max abs diff {np.abs(out_grad_np - out_grad_torch_np).max():.3e})"
)
out_proj_grads_match

True