In [1]:
import jax
import jax.numpy as jnp

# Toy 1-D signal [0, 1, ..., 5] (float32) and a length-3 kernel [1, 0, -1].
x = jnp.arange(6).astype(jnp.float32)
w = jnp.asarray([1.0, 0.0, -1.0])

def convolve(x, w):
  """Valid-mode 1-D sliding dot product of `x` with kernel `w`.

  For every window of ``len(w)`` consecutive elements of `x`, takes the dot
  product with `w`. The kernel is not flipped and no padding is added, so this
  is technically a cross-correlation. The result has
  ``len(x) - len(w) + 1`` elements, e.g. 4 for a 6-element `x` and a
  length-3 kernel.

  The Python loop over windows is deliberate: this is the per-example
  function that the notebook batches by hand and with `jax.vmap`.

  Args:
    x: 1-D array of input samples.
    w: 1-D kernel of any length.

  Returns:
    1-D array of window dot products; empty when ``len(x) < len(w)``.
  """
  k = len(w)  # static under vmap/jit: it reads the per-example shape
  output = []
  # Window starts 0 .. len(x) - k; for k == 3 this is exactly the former
  # x[i-1:i+2] windows centred on i = 1 .. len(x) - 2.
  for i in range(len(x) - k + 1):
    output.append(jnp.dot(x[i:i + k], w))
  return jnp.array(output)

convolve(x, w=w)  # one output per full length-3 window of the 6-element x

2024-07-14 21:38:18.980222: W external/xla/xla/service/gpu/nvptx_compiler.cc:765] The NVIDIA driver's CUDA version is 12.1 which is older than the ptxas CUDA version (12.5.82). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


Array([-2., -2., -2., -2.], dtype=float32)

In [2]:
# Batch of two identical signals 0..11, shape (2, 12).
xs = jnp.tile(jnp.arange(12), (2, 1))
# Batch of two kernels: [1, 0, -1] and the same kernel scaled by 2, shape (2, 3).
ws = jnp.array([1.0, 0.0, -1.0]) * jnp.array([[1.0], [2.0]])

def manually_batched_convolve(xs, ws):
  """Naively batch `convolve` with a Python loop over the leading axis.

  Row ``b`` of the result is ``convolve(xs[b], ws[b])``; the per-row results
  are stacked into one array. This is the hand-written baseline that
  `jax.vmap` replaces later in the notebook.
  """
  rows = [convolve(xs[b], ws[b]) for b in range(xs.shape[0])]
  return jnp.stack(rows)

manually_batched_convolve(xs, ws=ws)  # one result row per batch element

Array([[-2., -2., -2., -2., -2., -2., -2., -2., -2., -2.],
       [-4., -4., -4., -4., -4., -4., -4., -4., -4., -4.]], dtype=float32)

In [3]:
%timeit manually_batched_convolve(xs, ws)

7.3 ms ± 104 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [4]:
auto_batch_convolve = jax.vmap(convolve)

auto_batch_convolve(xs, ws)

%timeit auto_batch_convolve(xs, ws)

11.6 ms ± 173 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [5]:
auto_batch_convolve = jax.jit(jax.vmap(convolve))

auto_batch_convolve(xs, ws)

%timeit auto_batch_convolve(xs, ws)

28.1 µs ± 456 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [6]:
# Map over axis 1 of both inputs (batch stored in the columns) and put the
# batch on axis 1 of the output as well.
auto_batch_convolve_v2 = jax.vmap(convolve, in_axes=(1, 1), out_axes=1)

xt = xs.T  # shape (12, 2)
wt = ws.T  # shape (3, 2)

auto_batch_convolve_v2(xt, wt)

Array([[-2., -4.],
       [-2., -4.],
       [-2., -4.],
       [-2., -4.],
       [-2., -4.],
       [-2., -4.],
       [-2., -4.],
       [-2., -4.],
       [-2., -4.],
       [-2., -4.]], dtype=float32)

In [7]:
# Batch over the signals only; the single kernel `w` is broadcast (None = not mapped).
batch_convolve_v3 = jax.vmap(convolve, in_axes=(0, None))

batch_convolve_v3(xs, w)

Array([[-2., -2., -2., -2., -2., -2., -2., -2., -2., -2.],
       [-2., -2., -2., -2., -2., -2., -2., -2., -2., -2.]], dtype=float32)