参考

1. [JAX入門～高速なNumPyとして使いこなすためのチュートリアル～](https://qiita.com/koshian2/items/44a871386576b4f80aff)

In [1]:
import jax.numpy as jnp

In [2]:
# Initialise a 5x5 float32 matrix holding 0..24, filled row by row.
x = jnp.reshape(jnp.arange(25, dtype=jnp.float32), (5, 5))
print(x)

[[ 0.  1.  2.  3.  4.]
 [ 5.  6.  7.  8.  9.]
 [10. 11. 12. 13. 14.]
 [15. 16. 17. 18. 19.]
 [20. 21. 22. 23. 24.]]


In [4]:
# JAX dispatches computations asynchronously; block_until_ready() blocks until
# the product below has actually been computed before execution continues.
y = x + 1
x_gram = jnp.matmul(x, jnp.transpose(y)).block_until_ready()
print(x_gram)

[[  40.   90.  140.  190.  240.]
 [ 115.  290.  465.  640.  815.]
 [ 190.  490.  790. 1090. 1390.]
 [ 265.  690. 1115. 1540. 1965.]
 [ 340.  890. 1440. 1990. 2540.]]


In [5]:
# jitによるXLAコンパイル
# メモリや推論時間で有利
# https://www.tensorflow.org/xla

from jax import jit

In [7]:
@jit
def static_jax_dot():
    """Return M @ M.T for the fixed 5x5 float32 matrix M = arange(25).reshape(5, 5).

    The function takes no arguments, so every shape is known when jit traces it.
    """
    m = jnp.arange(25, dtype=jnp.float32).reshape(5, 5)
    return jnp.dot(m, m.T)

In [8]:
# block_until_ready() is called on the array returned by the jitted function
# (outside jit) and waits until JAX's asynchronous computation has finished.
static_jax_dot().block_until_ready()

DeviceArray([[  30.,   80.,  130.,  180.,  230.],
             [  80.,  255.,  430.,  605.,  780.],
             [ 130.,  430.,  730., 1030., 1330.],
             [ 180.,  605., 1030., 1455., 1880.],
             [ 230.,  780., 1330., 1880., 2430.]], dtype=float32)

In [None]:
# 以下の書き方も可能
# jit(static_jax_dot)().block_until_ready()

# Errorの例

In [10]:
# Bad example: calling block_until_ready() inside a jit-compiled function.
# While jit traces the function, x_gram is an abstract (shaped) value rather
# than a concrete array, so the call fails with (as observed when this ran):
# AttributeError: 'ShapedArray' object has no attribute 'block_until_ready'
# Call block_until_ready() on the jitted function's *result* instead, as in
# static_jax_dot().block_until_ready().
@jit
def static_jax_dot_badexample():
    """Intentionally broken demo of the error described above; do not call."""
    x = jnp.arange(25, dtype=jnp.float32).reshape(5, 5)
    x_gram = jnp.dot(x, x.T)
    return x_gram.block_until_ready()

# Left commented out so that the cell runs without raising.
#static_jax_dot_badexample()

In [12]:
# Bad example: passing an argument to a jit-wrapped function that uses it as a shape.
# Under jit, `size` is an abstract tracer, but jnp.arange()/reshape() need a
# concrete Python int to determine the array shape, so the call fails with:
# ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.
@jit
def variable_jax_dot_badexample(size):
    """Intentionally broken demo of the error above; the static_argnums cells below fix it."""
    x = jnp.arange(size**2, dtype=jnp.float32).reshape(size, size)
    x_gram = jnp.dot(x, x.T)
    return x_gram

# Left commented out so that the cell runs without raising.
#variable_jax_dot_badexample(5)


In [13]:
# ↑の引数を渡したい時は、どうするか
# 1. static_argnums : 引数の何番目が固定値であるか
def variable_jax_dot(size):
    """Return M @ M.T where M is the size x size float32 matrix holding 0..size**2 - 1.

    When wrapped with jit, `size` must be marked static so that it stays a
    concrete Python int usable as an array shape.
    """
    values = jnp.arange(size**2, dtype=jnp.float32)
    square = values.reshape(size, size)
    return jnp.dot(square, square.T)

# static_argnums=(0,) marks argument 0 (`size`) as static, so inside jit it is a
# concrete Python int that jnp.arange()/reshape() can use as a shape.
jit(variable_jax_dot, static_argnums=(0,))(5).block_until_ready()


DeviceArray([[  30.,   80.,  130.,  180.,  230.],
             [  80.,  255.,  430.,  605.,  780.],
             [ 130.,  430.,  730., 1030., 1330.],
             [ 180.,  605., 1030., 1455., 1880.],
             [ 230.,  780., 1330., 1880., 2430.]], dtype=float32)

In [14]:
# 2. Decorate with jit through partial.
# Use functools.partial: newer JAX releases removed `jax.partial`, which was
# only a re-export of functools.partial, so `from jax import partial` fails there.
from functools import partial

In [15]:
@partial(jit, static_argnums=(0,))
def variable_jax_dot_deco(size):
    """Return M @ M.T for the size x size float32 matrix M holding 0..size**2 - 1.

    Argument 0 (`size`) is declared static in the decorator, so it is a
    concrete Python int while tracing and can be used as a shape.
    """
    square = jnp.arange(size**2, dtype=jnp.float32).reshape(size, size)
    gram = jnp.dot(square, square.T)
    return gram

# Because `size` is static in the decorator, the function is called like a normal function.
variable_jax_dot_deco(5).block_until_ready()


DeviceArray([[  30.,   80.,  130.,  180.,  230.],
             [  80.,  255.,  430.,  605.,  780.],
             [ 130.,  430.,  730., 1030., 1330.],
             [ 180.,  605., 1030., 1455., 1880.],
             [ 230.,  780., 1330., 1880., 2430.]], dtype=float32)

# パフォーマンスの比較

In [27]:
# Result (timings below): jax + jit beat NumPy at size 1000 and 10000,
# while NumPy was faster at size 10 and 100.
import numpy as np
import jax.numpy as jnp
# functools.partial instead of `jax.partial`, which newer JAX releases removed.
from functools import partial
from jax import jit


@partial(jit, static_argnums=(0,))
def jax_jit_mod(size):
    """Build the (size, size) int32 matrix mat[i, j] = i * j and return it mod 256 (jitted).

    `size` is static (static_argnums=(0,)), so it must be a hashable Python int;
    each distinct value is traced separately.
    """
    x = jnp.arange(size, dtype=jnp.int32)
    mat = x[None, :] * x[:, None]  # outer product via broadcasting: (size, size)
    return mat % 256

def jax_nojit_mod(size):
    """Eager (no jit) JAX version: (size, size) int32 matrix of i * j, taken mod 256."""
    idx = jnp.arange(size, dtype=jnp.int32)
    products = idx[:, None] * idx[None, :]  # products[i, j] = i * j
    return products % 256

def numpy_mod(size):
    """NumPy baseline: (size, size) int32 matrix of i * j, taken mod 256."""
    idx = np.arange(size, dtype=np.int32)
    return np.multiply.outer(idx, idx) % 256

for i in range(4):
    size = 10**(i+1)
    repeat = 10**(4-i)
    print("size =", size, "repeat =", repeat)
    %timeit numpy_mod(size)
    %timeit jax_nojit_mod(size).block_until_ready() # jitなしJAX
    %timeit jax_jit_mod(size).block_until_ready() # jitありJAX

size = 10 repeat = 10000
6.15 µs ± 318 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
2.1 ms ± 56.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
51.8 µs ± 1.66 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
size = 100 repeat = 1000
59.6 µs ± 1.74 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
2.23 ms ± 166 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
98.5 µs ± 9.86 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
size = 1000 repeat = 100
4.82 ms ± 162 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
4.48 ms ± 1.03 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
210 µs ± 23.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
size = 10000 repeat = 10
839 ms ± 24.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
694 ms ± 50 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
92.1 ms ± 5.19 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [29]:
import jax
# deviceの指定
# devicesは実験的機能で変更される可能性がある
def dot_function():
    """Return M @ M.T for the 1000x1000 float32 matrix M = arange(1000**2).reshape(1000, 1000)."""
    side = 1000
    mat = jnp.arange(side * side, dtype=jnp.float32).reshape(side, side)
    return jnp.dot(mat, mat.T)


%timeit -n 100 jit(dot_function, device=jax.devices("cpu")[0])().block_until_ready()

11.3 ms ± 1.5 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


# Numpy配列を併用する場合

In [30]:
import cv2
from jax import device_put

In [35]:
@jit
def blend_color(x):
    """Blend an image with the fixed color RGB(235, 86, 230) using a soft-light formula.

    Args:
        x: image as read by cv2.imread — presumably an (h, w, 3) uint8 array
            in BGR channel order (TODO confirm for other callers).

    Returns:
        uint8 array of shape (h, w, 3) in RGB channel order: the BGR -> RGB
        flip below is not undone before returning.
    """
    # NOTE(review): under jit, x is already a traced value, so this explicit
    # device_put presumably adds nothing; kept to mirror the original cell.
    x = device_put(x)
    # Reverse the channel axis (BGR -> RGB) and scale to [0, 1].
    x = x[:, :, ::-1].astype(jnp.float32) / 255.
    blend = jnp.ones(x.shape[:-1], dtype=jnp.float32)[..., None]  # (h, w, 1)
    # Broadcast the blend color to every pixel: (h, w, 1) * (1, 1, 3) -> (h, w, 3).
    blend = blend * (jnp.array([235, 86, 230], dtype=jnp.float32).reshape(1, 1, -1) / 255.0)

    # The two branches of the soft-light formula, chosen per channel by the blend value.
    a = 2 * x * blend + x ** 2 * (1 - 2 * blend)  # used where blend < 0.5
    b = 2 * x * (1 - blend) + jnp.sqrt(x) * (2 * blend - 1)  # used where blend >= 0.5
    # Bug fix: the original selected the constant 2 instead of `a` where blend < 0.5,
    # leaving `a` unused and pushing those channels to 2 * 255 = 510 (beyond uint8).
    out = jnp.where(blend < 0.5, a, b)

    # Back to 8-bit; no explicit rounding is applied before the cast.
    out = (out * 255.).astype(jnp.uint8)
    return out

img_path = "../data/sample.jpg"
img = cv2.imread(img_path)
if img is not None:
    %timeit blend_color(img).block_until_ready()
else:
    print(f"Can not read {img_path}")

331 µs ± 31.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
