#### Manual Vectorization

In [1]:
import jax 
import jax.numpy as jnp

x = jnp.arange(5)             # input signal: [0, 1, 2, 3, 4]
w = jnp.array([2., 3., 4.])   # length-3 kernel

def convolve(x, w):
    """Slide kernel ``w`` over ``x`` and take the dot product at each position.

    Only "valid" positions (where the whole kernel fits inside ``x``) are
    computed, i.e. a valid-mode cross-correlation: the kernel is not flipped.

    Args:
        x: 1-D array, the input signal.
        w: 1-D array, the kernel (any length).

    Returns:
        1-D array with ``len(x) - len(w) + 1`` window dot products, or an
        empty array when ``x`` is shorter than ``w``.
    """
    k = w.shape[0]
    output = []
    # Plain Python loop: len(x) and k are static shapes at trace time, so this
    # also works under jax.vmap / jax.jit.
    for i in range(len(x) - k + 1):
        output.append(jnp.dot(x[i:i + k], w))
    return jnp.array(output)

convolve(x, w)

Array([11., 20., 29.], dtype=float32)

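Given a batch of input vectors ``xs`` and a matching batch of weight vectors ``ws``, we can batch the computation manually by looping over the batch and calling ``convolve`` on each pair: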

In [4]:
# Two identical examples stacked along a new leading (batch) axis.
xs = jnp.stack((x, x))
ws = jnp.stack((w, w))

def manually_batched_convolve(xs, ws):
    """Batch ``convolve`` by hand: call it once per batch row and stack the results.

    Args:
        xs: batch of input vectors, batch along axis 0.
        ws: batch of kernels, one per row of ``xs``.

    Returns:
        Array whose row ``b`` is ``convolve(xs[b], ws[b])``.
    """
    rows = [convolve(xs[b], ws[b]) for b in range(xs.shape[0])]
    return jnp.stack(rows)

manually_batched_convolve(xs, ws)

Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)

In [5]:
# To vectorize across the batch dimension, operate on all batch rows at once instead of looping over them

def manually_vectorized_convolve(xs, ws):
    """Batched ``convolve`` written with array ops over the batch axis.

    Each iteration handles one window position for *all* batch rows at once,
    instead of looping over batch rows.

    Args:
        xs: 2-D array of shape ``(batch, n)``, one signal per row.
        ws: 2-D array of shape ``(batch, k)``, one kernel per row.

    Returns:
        2-D array of shape ``(batch, n - k + 1)``, or ``(batch, 0)`` when
        ``n < k``.
    """
    k = ws.shape[1]
    n_out = xs.shape[1] - k + 1
    if n_out <= 0:
        # jnp.stack([]) would raise; return an empty result instead, mirroring
        # the empty array convolve returns for too-short signals.
        return jnp.zeros((xs.shape[0], 0), dtype=jnp.result_type(xs, ws))
    # Row-wise dot products: multiply each (batch, k) window by the (batch, k)
    # kernels elementwise, then reduce over the window axis.
    columns = [jnp.sum(xs[:, i:i + k] * ws, axis=1) for i in range(n_out)]
    return jnp.stack(columns, axis=1)
manually_vectorized_convolve(xs=xs, ws=ws)

Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)

#### Automatic Vectorization

In [6]:
# 0 is vmap's default for both in_axes and out_axes: map over the leading axis
# of every argument and stack the per-example results along axis 0.
auto_batch_convolve = jax.vmap(convolve, in_axes=0, out_axes=0)
auto_batch_convolve(xs, ws)

Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)

By default, ``jax.vmap`` assumes the batch dimension is the first axis (axis $0$) of every input, and it places the batch dimension in the first axis of the output.

When the batch dimension is not the first axis, use ``in_axes`` and ``out_axes`` to specify where it is located in the inputs and in the output, respectively.

In [7]:
# Batch dimension lives on axis 1 for both inputs and for the output.
auto_batch_convolve2 = jax.vmap(convolve, in_axes=1, out_axes=1)

# Transpose so the batch runs along axis 1 instead of axis 0.
xst = xs.T
wst = ws.T
auto_batch_convolve2(xst, wst)

Array([[11., 11.],
       [20., 20.],
       [29., 29.]], dtype=float32)

In many ML settings some inputs are batched (e.g., data points) while others are not (e.g., model parameters or constants). In that case, pass ``None`` in ``in_axes`` for the unbatched arguments, e.g. ``in_axes=[0, None]`` (assuming the batched input's batch dimension is its first axis).

In [8]:
# Map over axis 0 of the first argument; None marks the second argument (w) as
# unbatched, so the same kernel is used for every example.
batch_convolve3 = jax.vmap(convolve, in_axes=(0, None))
batch_convolve3(xs, w)

Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)

**Combining transformations.** ``jax.vmap`` composes with other JAX transformations; for example, the batched function can be wrapped with ``jax.jit``:

In [9]:
# vmap first (vectorize over axis 0 with the default axes), then jit-compile
# the resulting batched function.
jitted_batch_convolve = jax.jit(jax.vmap(convolve, in_axes=0, out_axes=0))
jitted_batch_convolve(xs, ws)

Array([[11., 20., 29.],
       [11., 20., 29.]], dtype=float32)