Automatic Vectorization in JAX
https://jax.readthedocs.io/en/latest/jax-101/03-vectorization.html

In [None]:
import jax
import jax.numpy as jnp

In [None]:
# Toy 1-D input and a length-3 kernel, both created directly as float32.
x = jnp.array([1, 0, 1, 2, 3, 2, 1, 0, 1], dtype=jnp.float32)
w = jnp.array([0, 1, 0], dtype=jnp.float32)

def convolve(image, kernel):
    """Slide ``kernel`` across ``image`` and return the dot product at each offset.

    This is a "valid" 1-D cross-correlation: the kernel is not flipped and no
    padding is added. It is written for a single, unbatched example so that
    ``jax.vmap`` can add a batch dimension afterwards.

    Args:
        image: 1-D array of input values (assumed 1-D; not checked).
        kernel: 1-D array of weights of any length ``k >= 1``.

    Returns:
        1-D array of length ``max(len(image) - len(kernel) + 1, 0)`` whose
        ``i``-th entry is ``jnp.dot(image[i:i + k], kernel)``. For a length-3
        kernel this is one output per interior position of ``image``.
    """
    k = len(kernel)
    # len() reads static shapes, so this Python-level loop is unrolled at
    # trace time and stays compatible with jax.vmap / jax.jit.
    starts = range(len(image) - k + 1)
    # jnp.array (not jnp.stack) so an empty result still yields a (0,) array.
    return jnp.array([jnp.dot(image[i:i + k], kernel) for i in starts])

# Unbatched reference run; the result is shown as the cell output.
convolve(image=x, kernel=w)

In [None]:
# Four identical copies stacked along a new leading (batch) axis:
# x_batch has shape (4, 9) and w_batch has shape (4, 3).
x_batch = jnp.stack([x] * 4)
w_batch = jnp.stack([w] * 4)
w_batch

In [None]:
# Automatic vectorization: jax.vmap maps `convolve` over axis 0 of every
# argument and stacks the per-example results along axis 0 of the output
# (these are vmap's defaults, spelled out here for clarity).
# To inspect the batched computation, evaluate:
#     jax.make_jaxpr(batch_convolve)(x_batch, w_batch)
batch_convolve = jax.vmap(convolve, in_axes=0, out_axes=0)
batch_convolve(x_batch, w_batch)


In [None]:
# Same batch stored column-wise: map over axis 1 of both inputs and place the
# batch along axis 1 of the output, giving shape (7, 4).
jax.vmap(convolve, in_axes=1, out_axes=1)(x_batch.T, w_batch.T)

In [None]:
# Share one kernel across the batch: axis 0 of `x_batch` is mapped, while the
# `None` entry passes the unbatched `w` unchanged to every call.
jax.vmap(convolve, in_axes=(0, None))(x_batch, w)