In [130]:
import torch
from torch import nn, Tensor

from arithmetic_lm.model.pos_encoding import RelativeMultiheadAttention

In [131]:
# Build a standard and a relative-position multi-head attention module with
# identical hyperparameters, so the next cell can check that they agree once
# they share the same projection weights.
torch.manual_seed(0)  # reproducible parameter init for both modules

EMBED_DIM = 16
NUM_HEADS = 8  # nn.MultiheadAttention requires EMBED_DIM % NUM_HEADS == 0
REL_POS_K = 16  # forwarded as `rel_pos_k`; meaning defined by RelativeMultiheadAttention — TODO document

# .eval() so the comparison stays deterministic even if dropout is later raised.
standard_mha = nn.MultiheadAttention(
    embed_dim=EMBED_DIM, num_heads=NUM_HEADS, dropout=0.0, batch_first=True
).eval()
relative_mha = RelativeMultiheadAttention(
    embed_dim=EMBED_DIM,
    num_heads=NUM_HEADS,
    dropout=0.0,
    rel_pos_k=REL_POS_K,
    batch_first=True,
).eval()

In [132]:
# Copy the standard module's weights into the relative one. strict=False is
# needed because the relative module may own extra parameters, but it also
# silently ignores mismatched names — so verify the copy actually happened.
load_result = relative_mha.load_state_dict(standard_mha.state_dict(), strict=False)
assert not load_result.unexpected_keys, (
    f"standard weights not consumed by relative_mha: {load_result.unexpected_keys}"
)
print("params only in relative_mha:", load_result.missing_keys)

# Ensure same output on a fixed random input.
# x: [B, L, D] (batch_first=True); D taken from the module rather than repeated.
gen = torch.Generator().manual_seed(0)
x = torch.randn(2, 3, standard_mha.embed_dim, generator=gen)
with torch.no_grad():  # pure forward comparison; no autograd graph needed
    standard_output = standard_mha(x, x, x, need_weights=False)[0]
    relative_output = relative_mha(x, x, x, need_weights=False)[0]
print("standard:", standard_output.shape)
print("relative:", relative_output.shape)

assert standard_output.shape == relative_output.shape == x.shape
# Same tolerances as torch.allclose(..., atol=1e-6) with its default rtol,
# but reports the mismatching elements on failure.
torch.testing.assert_close(relative_output, standard_output, rtol=1e-5, atol=1e-6)

standard: torch.Size([2, 3, 16])
relative: torch.Size([2, 3, 16])
