In [1]:
import sys
sys.path.append('../')  # make the sibling `attention` module importable
# NOTE(review): star import; presumably supplies MultiHeadAttention,
# MultiHeadAttentionState and DenseState — prefer importing those names explicitly.
from attention import *

# Explicit third-party imports come after the star import so they take precedence
# over any same-named names the `attention` module re-exports.
import jax
import jax.numpy as jnp
import numpy as np
import torch

SEED = 0
rng = jax.random.PRNGKey(SEED)  # root JAX PRNG key, fixed seed for reproducible init

In [2]:
# Hyper-parameters shared by the JAX implementation and the PyTorch reference.
n_heads = 1
emb_size = 2
bias = False

# JAX multi-head attention under test (presumably from the local `attention` module).
# NOTE(review): v_bias is hard-coded to False rather than following `bias`;
# confirm this is intended if `bias` is ever switched on.
jax_mha = MultiHeadAttention(emb_size, n_heads, bias=bias, v_bias=False)

# PyTorch reference. batch_first=False is torch's default; it is spelled out
# because the inputs built later are laid out as (seq, batch, embed).
torch_mha = torch.nn.MultiheadAttention(
    embed_dim=emb_size,
    num_heads=n_heads,
    bias=bias,
    batch_first=False,
)

In [3]:
# Show the learnable tensors of both implementations side by side so their
# parameter layouts can be matched up before copying weights across.
for param_name, tensor in torch_mha.named_parameters():
    print(param_name, tensor.shape)

mha_state = jax_mha.init_state(rng)
for leaf in jax.tree_util.tree_leaves(mha_state):
    print(leaf.shape)

in_proj_weight torch.Size([6, 2])
out_proj.weight torch.Size([2, 2])
(2, 2)
(2, 2)
(2, 2)
(2, 2)


In [4]:
# Build a JAX MultiHeadAttentionState holding exactly torch_mha's weights, so both
# implementations can be run on the same input and compared.
torch_in_proj_weight = torch_mha.in_proj_weight
print(torch_in_proj_weight.shape)

# Only weight matrices are transferred: every DenseState below gets `None` as its
# second field (presumably the bias — TODO confirm against attention.DenseState).
# Refuse to continue if torch_mha has biases, because they would be silently
# dropped and the comparison that follows would be meaningless.
assert torch_mha.in_proj_bias is None and torch_mha.out_proj.bias is None, (
    "bias transfer is not implemented; build torch_mha with bias=False"
)

# torch packs the query, key and value projections row-wise, in that order, into a
# single (3 * emb_size, emb_size) in_proj_weight; split it back into three blocks.
q_weight, k_weight, v_weight = torch_in_proj_weight.detach().chunk(3, dim=0)
torch_weight_tensors = (q_weight, k_weight, v_weight, torch_mha.out_proj.weight.detach())

torch_weights = tuple(DenseState(jnp.array(w.numpy()), None) for w in torch_weight_tensors)

jax_mha_state = MultiHeadAttentionState(*torch_weights)


torch.Size([6, 2])


In [5]:
# Draw a reproducible random input and run the PyTorch reference forward pass.
# torch_mha uses torch's default batch_first=False, so the input is (seq, batch, embed).
torch.random.manual_seed(0)

context_len = 3
batch_size = 3
# NOTE(review): context_len == batch_size, so shapes alone cannot reveal a
# seq/batch axis swap in the JAX port; distinct sizes would make that easier to spot.
x = torch.randn(context_len, batch_size, emb_size, requires_grad=False)

with torch.no_grad():
    torch_attn_output = torch_mha(x, x, x, need_weights=False)[0]
torch_out = torch_attn_output.detach().numpy()

print(torch_out)
print(torch_out.shape)  # (context_len, batch_size, emb_size)

[[[-0.09932961 -0.22264259]
  [ 0.0832648   0.1969153 ]
  [ 0.11124789  0.22671439]]

 [[-0.6437645  -1.113713  ]
  [ 0.06693697  0.16663295]
  [ 0.00666594  0.04319041]]

 [[-0.7300143  -1.2334375 ]
  [ 0.08545306  0.20034644]
  [ 0.03891064  0.11145894]]]
(3, 3, 2)


In [6]:
# Run the JAX port on the same input, using the weights copied from torch, and
# check that it reproduces torch's output.
x_jnp = jnp.array(x.detach().numpy())

print(f"Calling jax mha forward with shape {x_jnp.shape} and type {type(x_jnp)}")
print()
jax_out = jax_mha.forward(jax_mha_state, x_jnp, x_jnp, x_jnp)
print(jax_out)
print(jax_out.shape)

print()
# np.allclose broadcasts its arguments, so an output of the wrong shape could still
# compare as "close"; require identical shapes before comparing values.
assert jax_out.shape == torch_out.shape, (
    f"shape mismatch: jax {jax_out.shape} vs torch {torch_out.shape}"
)
outputs_match = np.allclose(torch_out, jax_out, atol=1e-6)
print(outputs_match)
# Make a mismatch a hard failure instead of only a printed False.
assert outputs_match, (
    "JAX output differs from torch; max |diff| = "
    f"{np.max(np.abs(torch_out - np.asarray(jax_out))):.3e}"
)

Calling jax mha forward with shape (3, 3, 2) and type <class 'jaxlib.xla_extension.ArrayImpl'>

q.shape = (3, 3, 2), k.shape = (3, 3, 2), v.shape = (3, 3, 2)
x has type <class 'jax._src.interpreters.batching.BatchTracer'>, d_k = 2
x has type <class 'jax._src.interpreters.batching.BatchTracer'>, d_k = 2
x has type <class 'jax._src.interpreters.batching.BatchTracer'>, d_k = 2
# Shapes after linear transform and split into heads
query.shape = (3, 3, 1, 2), key.shape = (3, 3, 1, 2), value.shape = (3, 3, 1, 2)
Scaling using 0.7071067690849304
q * k^T = s.shape = (3, 3, 3, 1)
Softmax attn.shape = (3, 3, 3, 1)
*v shape = (3, 3, 1, 2)
After reshape: x.shape = (3, 3, 2)
out.shape = (3, 3, 2)
[[[-0.09932958 -0.22264253]
  [ 0.08326481  0.19691533]
  [ 0.11124788  0.22671437]]

 [[-0.64376456 -1.113713  ]
  [ 0.06693695  0.1666329 ]
  [ 0.00666592  0.04319039]]

 [[-0.7300143  -1.2334375 ]
  [ 0.08545302  0.20034638]
  [ 0.03891061  0.11145889]]]
(3, 3, 2)

True
