# Import packages

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '1'
import haiku as hk
import jax
import jax.numpy as jnp
from deeprte.model.modules import Attention,Attention_v2
from deeprte.model.config import CONFIG

In [3]:
# Pull out the two config nodes that every module below is constructed with:
# the global settings and the attention-specific sub-config.
gc = CONFIG.global_config
c = CONFIG.green_function.attenuation.attention

# Define Haiku modules

In [18]:
def forward_v1(*inputs):
    """Instantiate `Attention` from the shared configs and call it on `inputs`."""
    module = Attention(c, gc)
    return module(*inputs)


# Transform the forward function so it can be driven through `.init` / `.apply`.
attn_v1 = hk.transform(forward_v1)

def forward_v2(*inputs):
    """Instantiate `Attention_v2` from the shared configs and call it on `inputs`."""
    module = Attention_v2(c, gc)
    return module(*inputs)


# Transform the forward function so it can be driven through `.init` / `.apply`.
attn_v2 = hk.transform(forward_v2)

In [19]:
# Fixed-seed key sequence; the cells below draw keys from it with `next(rng)`.
root_key = jax.random.PRNGKey(42)
rng = hk.PRNGSequence(root_key)

# Generate Q, K, V

In [20]:
# Sizes of the toy inputs: a single query vector and `num_seq` key/value rows.
q_dim, k_dim, v_dim = 2, 3, 3
num_seq = 4000

# Keys are drawn from `rng` in a fixed order (q, k, v, then mask); keep this
# order so each array is generated from the same key on every fresh run.
q = jax.random.uniform(next(rng), shape=(q_dim,))
k = jax.random.uniform(next(rng), shape=(num_seq, k_dim))
v = jax.random.uniform(next(rng), shape=(num_seq, v_dim))

# Integer mask drawn from {0, 1} with a leading singleton axis.
mask = jax.random.randint(next(rng), shape=(1, num_seq), minval=0, maxval=2)

# Init params

In [21]:
# Initialise the parameters once with the v1 module; the same `params` tree is
# passed to both implementations in the comparison cell further down.
params = attn_v1.init(next(rng), q, k, v, mask)

# Display only the shape of every leaf in the parameter tree. This must read
# `params` (the name assigned above) so the cell runs on a fresh kernel.
jax.tree_util.tree_map(lambda x: x.shape, params)

{'attention/key': {'b': (64,), 'w': (3, 64)},
 'attention/output_projection': {'b': (64,), 'w': (64, 64)},
 'attention/query': {'b': (64,), 'w': (2, 64)},
 'attention/value': {'b': (64,), 'w': (3, 64)}}

# Compare output

In [27]:
# Give both implementations the same PRNG key so that any stochastic layer
# inside the modules cannot, on its own, make the two outputs differ.
apply_key = next(rng)
logits_v2 = attn_v2.apply(params, apply_key, q, k, v, mask)
logits_v1 = attn_v1.apply(params, apply_key, q, k, v, mask)

In [26]:
# Closeness check of the two outputs; no rtol/atol are passed, so
# `jnp.allclose`'s default tolerances apply.
jnp.allclose(logits_v2, logits_v1)

DeviceArray(False, dtype=bool)

In [35]:
# Relative RMS error of v2 against v1:
#   sqrt( mean((v2 - v1)^2) / mean(v1^2) )
diff_power = jnp.mean((logits_v2 - logits_v1) ** 2)
ref_power = jnp.mean(logits_v1 ** 2)
jnp.sqrt(diff_power / ref_power)

DeviceArray(0.00295489, dtype=float32)