In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

### A simple linear function

`linear(x, w, b)` computes `jnp.dot(x, w) + b` for a single input vector `x`.

In [2]:
def linear(x, w, b):
    """Affine map of `x`: the dot product of `x` and `w`, plus the bias `b`.

    Args:
        x: input array (a single example in this notebook).
        w: weight array, contracted against `x` by `jnp.dot`.
        b: bias added (with broadcasting) to the dot product.

    Returns:
        The array ``jnp.dot(x, w) + b``.
    """
    weighted_sum = jnp.dot(x, w)
    return weighted_sum + b

In [3]:
# One 5-element input example, its weights, and a 5-element bias (all float32).
x = jnp.arange(1, 6, dtype=jnp.float32)  # [1., 2., 3., 4., 5.]
w = jnp.asarray([0.5, 0.4, 0.6, 0.8, 0.9], dtype=jnp.float32)
b = jnp.asarray([-1.0, -3.0, -6.0, -19.0, -2.0], dtype=jnp.float32)



In [4]:
# Apply `linear` to the single example; the scalar dot product broadcasts against `b`.
linear(x=x, w=w, b=b)

DeviceArray([ 9.8,  7.8,  4.8, -8.2,  8.8], dtype=float32)

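### Vectorization (batching)

`vmap` turns `linear`, which is written for one example, into a function that processes a whole batch of examples at once.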

In [5]:
# vmap maps `linear` over axis 0 of its first argument `x`. The entries for
# `w` and `b` are None, so they are not batched and every example shares them.
batch_in_axes = (0, None, None)
linear_batched = vmap(linear, in_axes=batch_in_axes)

In [6]:
# Three float32 examples stacked along axis 0: rows 1..5, 6..10 and 11..15.
x_batch = jnp.arange(1, 16, dtype=jnp.float32).reshape(3, 5)

In [7]:
# The first argument is now a batch of examples instead of a single `x`; the
# result has one row per example (the first row matches `linear(x, w, b)`).
# Keep `w` and `b` positional: per the jax.vmap docs, keyword arguments are
# always mapped over axis 0, which would bypass the None entries in in_axes.
linear_batched(x_batch, w, b)

DeviceArray([[ 9.8     ,  7.8     ,  4.8     , -8.2     ,  8.8     ],
             [25.800001, 23.800001, 20.800001,  7.800001, 24.800001],
             [41.8     , 39.8     , 36.8     , 23.8     , 40.8     ]],            dtype=float32)

### Vectorization combined with JIT

In [8]:
# Wrap the vmapped function with jit so that JAX compiles it with XLA.
# Compilation happens lazily on the first call for a given input shape/dtype,
# not at this line.
optimized_linear_batched = jit(linear_batched)

In [9]:
%timeit -n 100 linear_batched(x_batch, w, b)

1.13 ms ± 94 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [10]:
%timeit -n 100 optimized_linear_batched(x_batch, w, b)

The slowest run took 17.23 times longer than the fastest. This could mean that an intermediate result is being cached.
73.6 µs ± 119 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
