# pseudo test of model attention

In [22]:
import jax
import jax.numpy as jnp

from functools import partial

In [3]:
@jax.jit
def _check_jit(state, comm):
    comm = jnp.sum(comm, axis=1)
    state_com = jnp.concatenate((state, comm), axis=-1)
    return jnp.mean(state_com, axis=-1)

In [64]:
key = jax.random.PRNGKey(0)
# (batch_size, neighbor_agent_num, obs_dim)
state = jax.random.normal(key, shape=(1000,500))
comm1 = jax.random.normal(key, shape=(1000,5,500))
comm2 = jax.random.normal(key, shape=(1000,10,500))

In [66]:
# dim=1のサイズが異なるとその度にrecompileしているみたい。サイズが変わる前の情報も保有しているのである程度時間が経つと早くなる
# てっきり、エラーが出るかと思ったら普通に通った
# なので、うまくvmapすれば、データごとに neighbor_agent_numが異なってもうまくいきそう
# →やっぱ無理だった。batch_dimに対してarray_shapeは一緒でなければエラー出るはず。for_i_loopを使うとましになる？
%time results = _check_jit(state,comm1)
%time results = _check_jit(state,comm1)
%time results = _check_jit(state,comm2)
%time results = _check_jit(state,comm2)

%time results = _check_jit(state,comm1)
%time results = _check_jit(state,comm2)

CPU times: user 2.82 ms, sys: 0 ns, total: 2.82 ms
Wall time: 1.49 ms
CPU times: user 2.46 ms, sys: 0 ns, total: 2.46 ms
Wall time: 2.48 ms
CPU times: user 4.54 ms, sys: 0 ns, total: 4.54 ms
Wall time: 4.03 ms
CPU times: user 3.92 ms, sys: 0 ns, total: 3.92 ms
Wall time: 2.25 ms
CPU times: user 2.4 ms, sys: 0 ns, total: 2.4 ms
Wall time: 1.46 ms
CPU times: user 3.99 ms, sys: 0 ns, total: 3.99 ms
Wall time: 2.28 ms


# Pseudo test of adaptive communication in environmnet

In [58]:
# test get only neighbor agnet information. 
# this code failed
@partial(jax.jit, static_argnames=("neighbor_r"))
# @jax.jit
def get_neighbor_pos(pos, r, neighbor_r):
    idx = jnp.where(r<neighbor_r)
    print(idx)
    neighbor_pos = jnp.take(pos, idx)
    return neighbor_pos, idx

pos = jax.random.normal(key, shape=(10,2))
r = jax.random.normal(key, shape=(10,))

get_neighbor_pos(pos,r, 0.1)

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.
While tracing the function get_neighbor_pos at /tmp/ipykernel_17011/3839109848.py:1 for jit, this concrete value was not available in Python because it depends on the value of the argument passed at flattened position 1.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

In [70]:
# test code. get n closest agent position

@partial(jax.jit, static_argnames=("n_neighbor"))
def get_neighbor_pos(pos, n_neighbor):
    dist = jnp.sum(pos**2, axis=-1)
    neighbor_index = jnp.argsort(dist)[:n_neighbor]
    neighbor_pos = jnp.take(pos, neighbor_index, axis=0)
    return neighbor_pos

pos1 = jax.random.normal(key, shape=(10,2))
pos2 = jax.random.normal(key, shape=(100,2))
pos3 = jax.random.normal(key, shape=(1000,2))
get_neighbor_pos(pos1,5)
%time get_neighbor_pos(pos1,5)
get_neighbor_pos(pos2,5)
%time get_neighbor_pos(pos2,5)
get_neighbor_pos(pos3,5)
%time get_neighbor_pos(pos3,5)
# この計算はそこまでボトルネックにならなさそう

CPU times: user 33 µs, sys: 3 µs, total: 36 µs
Wall time: 39.8 µs
CPU times: user 53 µs, sys: 0 ns, total: 53 µs
Wall time: 57.5 µs
CPU times: user 226 µs, sys: 16 µs, total: 242 µs
Wall time: 247 µs


DeviceArray([[ 0.00721349,  0.03755151],
             [-0.08827304,  0.04782276],
             [-0.09470108,  0.05764259],
             [-0.06916548,  0.09070419],
             [-0.11401487,  0.01173056]], dtype=float32)