In [1]:
import sys

sys.path.append('../')

import jax
import jax.numpy as jnp
import numpy as np
import torch

# The star import supplies the project classes used below (MultiHeadAttention,
# DenseState, MultiHeadAttentionState); LayerNorm is imported explicitly.
from attention import *
from attention import LayerNorm

# Fixed seed for the JAX PRNG key so JAX-side random draws are reproducible.
SEED = 0
rng = jax.random.PRNGKey(SEED)

In [None]:
# Configuration shared by both implementations: a single head, no biases.
n_heads = 1
emb_size = 2
bias = False

# torch_mha's weights are randomly initialised on construction; seed torch first
# so the weights (and every output printed further down) are reproducible.
torch.manual_seed(0)

# NOTE(review): `v_bias` is a project-specific flag of attention.MultiHeadAttention;
# presumably it disables a value-projection bias — confirm against the module.
jax_mha = MultiHeadAttention(emb_size, n_heads, bias=bias, v_bias=False)

torch_mha = torch.nn.MultiheadAttention(emb_size, n_heads, bias=bias)

In [None]:
# Inspect the learnable parameters on both sides: torch's named parameters with
# their shapes, then the shape of every leaf in the JAX state pytree.
for param_name, param_tensor in torch_mha.named_parameters():
    print(param_name, param_tensor.shape)

mha_state = jax_mha.init_state(rng)
state_leaves = jax.tree_util.tree_leaves(mha_state)
for leaf in state_leaves:
    print(leaf.shape)

In [None]:
# torch stores the three input projections stacked row-wise in one
# in_proj_weight (rows [0, 3*emb_size)); cut it into three emb_size-row slices
# and append out_proj.weight to get four projection matrices.
torch_in_proj_weight = torch_mha.in_proj_weight
print(torch_in_proj_weight.shape)

q_proj_w, k_proj_w, v_proj_w = (
    torch_in_proj_weight[i * emb_size:(i + 1) * emb_size, :] for i in range(3)
)

# NOTE(review): weights are copied without transposing; torch Linear weights are
# (out_features, in_features). With a square emb_size x emb_size matrix a layout
# mismatch would not raise, so confirm DenseState expects the same orientation.
# The second DenseState field is presumably the bias; None matches bias=False.
torch_weights = tuple(
    DenseState(jnp.array(weight.detach().numpy()), None)
    for weight in (q_proj_w, k_proj_w, v_proj_w, torch_mha.out_proj.weight)
)

jax_mha_state = MultiHeadAttentionState(*torch_weights)


In [None]:
# Fixed seed so the random input below is reproducible.
torch.random.manual_seed(0)

# torch_mha was built without batch_first, so its inputs are laid out as
# (context_len, batch_size, emb_size).
# NOTE(review): context_len == batch_size here, so confusing the sequence and
# batch axes would not show up as a shape mismatch, only as differing values.
context_len, batch_size = 3, 3
x = torch.randn(context_len, batch_size, emb_size, requires_grad=False)

# Self-attention (query = key = value = x); keep only the attention output.
with torch.no_grad():
    torch_attn_output = torch_mha(x, x, x, need_weights=False)[0]
torch_out = torch_attn_output.detach().numpy()

print(torch_out)
print(torch_out.shape)  # (context_len, batch_size, emb_size)

In [None]:
# Run the JAX implementation on the same input with the weights copied from
# torch, then compare numerically against torch_out.
x_jnp = jnp.array(x.detach().numpy())

print(f"Calling jax mha forward with shape {x_jnp.shape} and type {type(x_jnp)}")
print()
jax_out = jax_mha.forward(jax_mha_state, x_jnp, x_jnp, x_jnp)
print(jax_out)
print(jax_out.shape)

print()
# np.allclose broadcasts its arguments, so an output with a different but
# broadcastable shape could still compare as "close"; check shapes first.
jax_out_np = np.asarray(jax_out)
shapes_match = torch_out.shape == jax_out_np.shape
print(f"shapes match: {shapes_match}")
if shapes_match:
    # Largest elementwise deviation, to see how far from the tolerance we are.
    print(f"max abs diff: {np.abs(torch_out - jax_out_np).max()}")
print(shapes_match and np.allclose(torch_out, jax_out_np, atol=1e-6))

In [2]:
# Scratch check for LayerNorm-style shape validation: do the trailing dims of
# `t` equal the normalized shape?
t = jax.random.normal(rng, (2, 2, 3))
print(t.shape)
normalized_shape = (2, 3)
# Slice from `ndim - len(...)` instead of `-len(...)`: for an empty normalized
# shape `-0` is 0, so the negative-index slice would return the whole shape
# rather than () and the check would wrongly fail.
t.shape[t.ndim - len(normalized_shape):] == normalized_shape

(2, 2, 3)


True

In [3]:
print(t.shape)
print(t)
norm_dims = (3,)
# One negative axis index per normalized dim, counting back from the last axis:
# len 1 -> (-1,), len 2 -> (-2, -1), empty -> ().
axes_to_reduce = tuple(-k for k in range(len(norm_dims), 0, -1))
print(axes_to_reduce)
t.mean(axis=axes_to_reduce, keepdims=True)

(2, 2, 3)
[[[ 1.1901639  -1.0996888   0.44367844]
  [ 0.5984697  -0.39189556  0.69261974]]

 [[ 0.46018356 -2.068578   -0.21438177]
  [-0.9898306  -0.6789304   0.27362573]]]
(-1,)


Array([[[ 0.17805116],
        [ 0.2997313 ]],

       [[-0.6075921 ],
        [-0.4650451 ]]], dtype=float32)

In [8]:
# Trying layer norm
context_len = 2
batch_size = 2
emb_size = 2

x = torch.randn(context_len, batch_size, emb_size, requires_grad=False) * 10
print(f"input shape: {x.shape}")
print(x)
layer_norm = torch.nn.LayerNorm((emb_size,))

# print layer norm learnable params
for name, param in layer_norm.named_parameters():
    print(name, param)

with torch.no_grad():
    out = layer_norm(x).detach().numpy()

print()
print(f"out.shape: {out.shape}")
print(out)

input shape: torch.Size([2, 2, 2])
tensor([[[-14.4555,   5.6489],
         [-13.8605,   4.7949]],

        [[  5.1084, -12.1866],
         [ -2.4934, -20.7785]]])
weight Parameter containing:
tensor([1., 1.], requires_grad=True)
bias Parameter containing:
tensor([0., 0.], requires_grad=True)

out.shape: (2, 2, 2)
[[[-1.         0.9999999]
  [-1.         0.9999999]]

 [[ 1.        -1.       ]
  [ 0.9999999 -1.       ]]]


In [9]:
# Run the project's JAX LayerNorm with its initial state on the same input as
# the torch LayerNorm cell and compare against torch's `out`.
x_jnp = jnp.array(x.detach().numpy())

ln = LayerNorm((emb_size,))
ln_state = ln.init_state()
print(ln_state)

out_jax = ln.forward(ln_state, x_jnp)
print(out_jax.shape)
print(out_jax)

# Explicit numeric check instead of eyeballing the two printed arrays; shapes
# are compared first because np.allclose broadcasts.
out_jax_np = np.asarray(out_jax)
print(out.shape == out_jax_np.shape and np.allclose(out, out_jax_np, atol=1e-6))

LayerNormState(gamma=Array([1., 1.], dtype=float32), beta=Array([0., 0.], dtype=float32))
(2, 2, 2)
[[[-0.9999999  0.9999999]
  [-0.9999999  0.9999999]]

 [[ 1.        -1.       ]
  [ 1.        -0.9999999]]]
