# MHA Parallel Implementation Verification

Comparison against PyTorch's `torch.nn.MultiheadAttention` to verify that the parallel multi-head attention (MHA) implementation is correct.

In [9]:
import compyute as cp
import numpy as np
import torch

# Single tolerance used as both atol and rtol in the comparisons below.
# It is kept fairly loose because floating point errors stack up over the
# many operations involved in attention.
tol = 1e-5

# Seed compyute's random number generator so the run is reproducible.
cp.random.set_seed(42)

In [10]:
# Problem sizes used throughout the notebook.
B = 16   # batch size
T = 256  # sequence length
C = 384  # channels (embedding dimension)
H = 6    # number of attention heads

# Random float32 input for the compyute layer ...
x = cp.random.uniform((B, T, C), dtype=cp.float32)

# ... and a PyTorch copy of the same values that tracks gradients, so the
# input gradient can be compared after the backward pass.
x_torch = torch.tensor(x.to_numpy()).requires_grad_(True)

In [11]:
from mha_parallel import ParallelMHA

# Build both attention layers without biases, so the two projection weight
# matrices are the only parameters that need to be synchronised.
mha = ParallelMHA(in_channels=C, n_heads=H, bias=False)
mha_torch = torch.nn.MultiheadAttention(
    embed_dim=C,
    num_heads=H,
    bias=False,
    batch_first=True,
)

# Replace PyTorch's own initial weights with copies of the compyute weights
# so both layers compute with identical parameters.
in_proj_w = torch.tensor(mha.in_proj.w.to_numpy())
out_proj_w = torch.tensor(mha.out_proj.w.to_numpy())
mha_torch.in_proj_weight = torch.nn.Parameter(in_proj_w)
mha_torch.out_proj.weight = torch.nn.Parameter(out_proj_w)

# Forward pass (self-attention: query, key and value are all the same input).
# The PyTorch layer returns (output, attention weights); only the output is kept.
out = mha(x)
out_torch = mha_torch(x_torch, x_torch, x_torch, need_weights=False)[0]

# Backward pass: feed the same random upstream gradient through both layers.
dy = cp.random.normal(out.shape)
dx = mha.backward(dy)
out_torch.backward(gradient=torch.tensor(dy.to_numpy()))

Check that the forward outputs match PyTorch

In [12]:
# Forward outputs of both layers must agree element-wise within `tol`.
np.allclose(
    out.to_numpy(),
    out_torch.detach().numpy(),
    rtol=tol,
    atol=tol,
)

True

Check that the input gradients match PyTorch

In [13]:
# Gradient w.r.t. the input: compyute's backward result vs. PyTorch autograd.
np.allclose(
    dx.to_numpy(),
    x_torch.grad.detach().numpy(),
    rtol=tol,
    atol=tol,
)

True

Check that the input projection weight gradients match PyTorch

In [14]:
# Gradient of the input projection weights in both layers.
np.allclose(
    mha.in_proj.w.grad.to_numpy(),
    mha_torch.in_proj_weight.grad.detach().numpy(),
    rtol=tol,
    atol=tol,
)

True

Check that the output projection weight gradients match PyTorch

In [15]:
# Gradient of the output projection weights in both layers.
np.allclose(
    mha.out_proj.w.grad.to_numpy(),
    mha_torch.out_proj.weight.grad.detach().numpy(),
    rtol=tol,
    atol=tol,
)

True