https://jax.readthedocs.io/en/latest/automatic-vectorization.html

`vmap`의 여러 활용 방식

In [1]:
import jax
import jax.numpy as jnp

# Toy 1-D signal [0, 1, 2, 3, 4] and a length-3 kernel used by the cells below.
x = jnp.arange(0, 5)
w = jnp.asarray([2.0, 3.0, 4.0])

def convolve(x, w):
  """Valid-mode 1-D sliding dot product of ``x`` with kernel ``w``.

  Slides ``w`` across ``x`` one position at a time and takes the dot product
  of each ``len(w)``-long window of ``x`` with ``w``. The kernel is not
  flipped, so this is a cross-correlation. It is written for one unbatched
  example; the cells below batch it with ``jax.vmap``.

  Args:
    x: 1-D signal.
    w: 1-D kernel of any length (previously only length 3 worked).

  Returns:
    1-D array with ``len(x) - len(w) + 1`` entries, or an empty array when
    the kernel is longer than the signal.
  """
  # Window width comes from the kernel instead of being hard-coded to 3.
  # Under vmap, len() sees the per-example shape, so this stays static.
  k = len(w)
  output = []
  for i in range(len(x) - k + 1):
    output.append(jnp.dot(x[i:i + k], w))
  return jnp.array(output)

convolve(x=x, w=w)  # unbatched reference result

Array([11., 20., 29.], dtype=float32)

In [9]:
# Two identical copies stacked along a new leading batch axis.
xs, ws = (jnp.stack([arr, arr]) for arr in (x, w))
# The same batches with the axes swapped, so the batch axis comes last.
xst, wst = xs.T, ws.T

print(f"{xs.shape=}")
print(f"{ws.shape=}")
print(f"{xst.shape=}")
print(f"{wst.shape=}")

xs.shape=(2, 5)
ws.shape=(2, 3)
xst.shape=(5, 2)
wst.shape=(3, 2)


In [3]:
# Both arguments carry the batch on axis 1; the per-example results are
# stacked along output axis 1. Call positionally: vmap maps keyword args over axis 0.
auto_batch_convolve_v2 = jax.vmap(convolve, out_axes=1, in_axes=(1, 1))
auto_batch_convolve_v2(xst, wst)

Array([[11., 11.],
       [20., 20.],
       [29., 29.]], dtype=float32)

In [14]:
# Mixed input layouts: x is batched on axis 0 (xs), w on axis 1 (wst).
# Results are still stacked along output axis 1.
auto_batch_convolve_v3 = jax.vmap(convolve, out_axes=1, in_axes=(0, 1))
auto_batch_convolve_v3(xs, wst)

Array([[11., 11.],
       [20., 20.],
       [29., 29.]], dtype=float32)

In [15]:
# Same inputs as v3, but the batch axis is placed first in the output.
auto_batch_convolve_v4 = jax.vmap(convolve, out_axes=0, in_axes=(0, 1))
auto_batch_convolve_v4(xs, wst)

Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)

In [16]:
# Layouts swapped relative to v4: x is batched on axis 1 (xst), w on axis 0 (ws).
# The output batch axis is still first.
auto_batch_convolve_v5 = jax.vmap(convolve, out_axes=0, in_axes=(1, 0))
auto_batch_convolve_v5(xst, ws)

Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)