# 02 Automatic Vectorization

Original Documentation: https://docs.jax.dev/en/latest/automatic-vectorization.html


In [16]:
import jax
import jax.numpy as jnp

## Manual vectorization


In [17]:
def convolve(x, w):
    output = []
    for i in range(1, len(x) - 1):
        output.append(jnp.dot(x[i - 1 : i + 2], w))
    return jnp.array(output)


x = jnp.arange(5)
w = jnp.array([2.0, 3.0, 4.0])

print(convolve(x, w))

[11. 20. 29.]


Suppose we wanted to apply `convolve` to a batch of weights `w` and vectors `x`. The naive implementation would introduce a new Python loop that calls `convolve`:


In [18]:
# Sample batch of vectors and weights
xs = jnp.stack([x, x])
ws = jnp.stack([w, w])


def naive_batch_convolve(xs, ws):
    assert xs.shape[0] == ws.shape[0]
    output = []
    for i in range(xs.shape[0]):
        output.append(convolve(xs[i], ws[i]))
    return jnp.stack(output)


print(naive_batch_convolve(xs, ws))

[[11. 20. 29.]
 [11. 20. 29.]]


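This produces the correct result, but it is inefficient: the batch elements are independent, so their convolutions could run in parallel instead of one at a time in a Python loop. To make it more efficient, we can rewrite the function so that each operation works on the whole batch at once: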


In [19]:
def manual_batch_convolve(xs, ws):
    assert xs.shape[0] == ws.shape[0]
    output = []
    for i in range(1, xs.shape[1] - 1):  # xs.shape[1] == 5
        # Apply the convolution along axis=1 for every batch element at once.
        # xs[:, i - 1 : i + 2] is a sliding window of size 3 along axis=1.
        output.append(jnp.sum(xs[:, i - 1 : i + 2] * ws, axis=1))
    return jnp.stack(output, axis=1)


print(manual_batch_convolve(xs, ws))

[[11. 20. 29.]
 [11. 20. 29.]]


Here, each loop iteration handles one window position along the last axis and computes that window's result for every batch element at once: the elementwise product `xs[:, i - 1 : i + 2] * ws` and the `jnp.sum()` over `axis=1` both operate on the whole batch.

However, this manual rewrite is messy and error-prone, because the indexing and axis arguments have to be adjusted by hand.
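To make the batched form easier to follow, here is a minimal sketch that unpacks a single loop iteration of `manual_batch_convolve` and prints the intermediate shapes. The names `window`, `products`, and `column` are illustrative and do not appear in the original code.

```python
import jax.numpy as jnp

x = jnp.arange(5)
w = jnp.array([2.0, 3.0, 4.0])
xs = jnp.stack([x, x])  # shape (2, 5): batch of 2 vectors
ws = jnp.stack([w, w])  # shape (2, 3): batch of 2 kernels

i = 1  # first iteration of the loop in manual_batch_convolve

# Illustrative names: one size-3 window per batch element.
window = xs[:, i - 1 : i + 2]  # shape (2, 3)
# Elementwise product: row b of the window is multiplied by ws[b].
products = window * ws  # shape (2, 3)
# Summing over axis=1 gives one dot product per batch element.
column = jnp.sum(products, axis=1)  # shape (2,)

print(window.shape, products.shape, column.shape)
print(column)
```

Each iteration therefore yields one column of the final result, which is why the loop ends with `jnp.stack(output, axis=1)`.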

## Automatic vectorization

`jax.vmap()` automatically generates a vectorized implementation of a function:


In [20]:
auto_batch_convolve = jax.vmap(convolve)

print(auto_batch_convolve(xs, ws))

[[11. 20. 29.]
 [11. 20. 29.]]
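As a quick sanity check, the sketch below compares `jax.vmap(convolve)` against an explicit Python loop. It assumes a hypothetical batch whose rows differ (unlike `xs` and `ws` above), so that any mix-up between batch elements would show up in the comparison.

```python
import jax
import jax.numpy as jnp


def convolve(x, w):
    output = []
    for i in range(1, len(x) - 1):
        output.append(jnp.dot(x[i - 1 : i + 2], w))
    return jnp.array(output)


# Hypothetical batch with distinct rows (not the xs/ws used above).
xs_demo = jnp.stack([jnp.arange(5.0), jnp.arange(5.0) * 2.0])  # shape (2, 5)
ws_demo = jnp.stack(
    [jnp.array([2.0, 3.0, 4.0]), jnp.array([1.0, 0.0, -1.0])]
)  # shape (2, 3)

# Reference: call convolve once per batch element.
looped = jnp.stack(
    [convolve(xs_demo[b], ws_demo[b]) for b in range(xs_demo.shape[0])]
)

# vmap maps over axis 0 of both arguments by default.
vmapped = jax.vmap(convolve)(xs_demo, ws_demo)

print(vmapped.shape)
print(jnp.allclose(looped, vmapped))
```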


If the batch dimension is not the first axis, use the `in_axes` and `out_axes` arguments to specify the location of the batch dimension in inputs and outputs.


In [21]:
# Transpose so batch dimension is now axis=1
xst = xs.transpose()  # shape (5, 2)
wst = ws.transpose()  # shape (3, 2)

auto_batch_convolve_v2 = jax.vmap(convolve, in_axes=1, out_axes=1)

print(auto_batch_convolve_v2(xst, wst))

[[11. 11.]
 [20. 20.]
 [29. 29.]]
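`in_axes` can also be given per argument, and an entry of `None` means that argument is not batched. The sketch below applies one shared kernel `w` to every row of `xs`; the name `shared_w_convolve` is illustrative.

```python
import jax
import jax.numpy as jnp


def convolve(x, w):
    output = []
    for i in range(1, len(x) - 1):
        output.append(jnp.dot(x[i - 1 : i + 2], w))
    return jnp.array(output)


x = jnp.arange(5)
w = jnp.array([2.0, 3.0, 4.0])
xs = jnp.stack([x, x])  # shape (2, 5)

# Batch over axis 0 of the first argument; None leaves w unbatched,
# so the same kernel is reused for every row of xs.
shared_w_convolve = jax.vmap(convolve, in_axes=(0, None))

out = shared_w_convolve(xs, w)
print(out.shape)
print(out)
```

This avoids stacking copies of `w` into `ws` when every batch element uses the same weights.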


## Combining transformations

`jax.jit()` and `jax.vmap()` are composable; we can wrap them in either order.

Notably, the order of wrapping them does matter: JIT compiling a vectorized function is different from vectorizing a JIT compiled function.

Usually, you want to JIT compile a vectorized function (`jax.jit(jax.vmap(fn))`) since it gives the XLA compiler more code to optimize upfront.

However, `jax.vmap(jax.jit(fn))` is also valid and returns the same result. If the shapes of batches are not consistent between calls, neither order fails: `jax.jit` retraces and recompiles the function for each new input shape, so the cost of inconsistent shapes is repeated compilation rather than an error.


In [22]:
compiled_batch_convolve = jax.jit(
    jax.vmap(convolve)
)  # Shape along axis=1 is consistent

print(compiled_batch_convolve(xs, ws))

[[11. 20. 29.]
 [11. 20. 29.]]
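The following sketch checks two things about composing the transformations: that both wrapping orders return the same values, and that a jitted function can still be called with a new input shape (`jax.jit` retraces and recompiles for it). The names `jit_of_vmap`, `vmap_of_jit`, and `xs_long` are illustrative.

```python
import jax
import jax.numpy as jnp


def convolve(x, w):
    output = []
    for i in range(1, len(x) - 1):
        output.append(jnp.dot(x[i - 1 : i + 2], w))
    return jnp.array(output)


x = jnp.arange(5)
w = jnp.array([2.0, 3.0, 4.0])
xs = jnp.stack([x, x])  # shape (2, 5)
ws = jnp.stack([w, w])  # shape (2, 3)

jit_of_vmap = jax.jit(jax.vmap(convolve))
vmap_of_jit = jax.vmap(jax.jit(convolve))

# Both compositions compute the same batched convolution.
same = jnp.allclose(jit_of_vmap(xs, ws), vmap_of_jit(xs, ws))
print(same)

# Illustrative longer batch: shape (2, 7) instead of (2, 5).
# The new shape triggers a retrace and recompile, not an error.
xs_long = jnp.stack([jnp.arange(7), jnp.arange(7)])
out_long = jit_of_vmap(xs_long, ws)
print(out_long.shape)
```

Every new input shape adds a compilation, so keeping batch shapes consistent across calls avoids paying that cost repeatedly.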
