# JAX transform functions
## vmap()
Write your functions as if you were dealing with a single datapoint!

In [12]:
import jax.numpy as jnp
import numpy as np
from jax import jit, grad, vmap
from jax import random
import matplotlib.pyplot as plt

In [13]:
seed = 0
key = random.PRNGKey(seed)

# JAX PRNG keys must never be reused: feeding the same key to two sampling
# calls yields correlated (not independent) draws. Split the root key into a
# fresh carry-over key plus one dedicated subkey per array.
key, w_key, x_key = random.split(key, 3)

W = random.normal(w_key, (150, 100))  # e.g. weights of a linear NN layer
batched_x = random.normal(x_key, (10, 100))  # e.g. a batch of 10 flattened images

def apply_matrix(x):
    """Multiply the module-level weight matrix ``W`` (150, 100) by ``x``.

    Shapes:
      * single datapoint: (150, 100) . (100,)    -> (150,)
      * transposed batch: (150, 100) . (100, 10) -> (150, 10)
    """
    projected = jnp.dot(W, x)
    return projected

In [15]:
# Shape of the batch once transposed so that datapoints lie along the columns.
jnp.transpose(batched_x).shape

(100, 10)

In [17]:
# Manual batching: hand the whole transposed batch to apply_matrix at once.
apply_matrix(jnp.transpose(batched_x)).shape

(150, 10)

In [18]:
@jit  # JAX transforms compose freely: here jit wraps a vmap-ed function.
def vmap_batched_apply_matrix(batched_x):
    """Apply ``apply_matrix`` independently to every row of ``batched_x``.

    ``vmap`` maps over the leading axis (axis 0, which is also its default),
    so ``apply_matrix`` only ever sees one datapoint at a time and the
    per-datapoint results are stacked along a new leading axis.
    """
    per_example_fn = vmap(apply_matrix, in_axes=0)
    return per_example_fn(batched_x)

print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

Auto-vectorized with vmap
The slowest run took 190.40 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 5: 86.7 µs per loop
