https://www.pragmatic.ml/first-look-at-jax/

In [0]:
import jax
import jax.numpy as np

def gpu_backed_hidden_layer(x):
  return jax.nn.relu(np.dot(X, w) + b)

In [0]:
from jax.scipy.linalg import svd

x = np.array([[1, 2], [3, 4]])
singular_vectors, singular_values, _ = svd(x)

In [7]:
from jax import grad

grad_hidden_layer = grad(gpu_backed_hidden_layer)
grad_hidden_layer

<function __main__.grad.<locals>.grad_f>

In [0]:
# Also diffentiate through native python controll structures

def absolute_value(x):
  if x >= 0:
    return x
  else:
    return -x

grad_absolute_value = grad(absolute_value)

In [11]:
# JAX also includes support for taking higher-order derivatives
# the `grad` funciton can be chained arbitrarily.

# grads all the way down
print(grad(grad(grad(np.tanh)))(1.0))

0.6216267


In [0]:
from jax import jacfwd, jacrev

hessian_fn = jacfwd(jacrev(absolute_value))

In [0]:
def unoptimized_fn(x, y, z):
  return np.sum(x + y * z)

In [0]:
@jax.jit
def xla_optimized_fn(x, y, z):
  return np.sum(x + y * z)

In [0]:
lax_optimized_grad = jax.jit(grad(xla_optimized_fn))

In [21]:
# Convention to distinguish between
# jax.numpy and numpy
import numpy as onp

def hidden_layer(x):
  W = np.zeros([128, 128], dtype=np.float32)
  b = np.zeros([128])
  return jax.nn.relu(np.dot(W, x) + b)

print(hidden_layer(onp.random.randn(128)).shape)

(128,)


In [22]:
batch_hidden_layer = jax.vmap(hidden_layer)
print(batch_hidden_layer(onp.random.randn(32, 128)).shape)

(32, 128)


In [0]:
batch_hidden_layer = jax.vmap(hidden_layer, in_axes=(0, ))

In [0]:
# first dimension must align with number of XLA-enabled
# devices
spmd_hidden_layer = jax.pmap(hidden_layer)

In [28]:
# Hypothetical setup for high-throughput inference
outputs = jax.pmap(jax.vmap(hidden_layer))(onp.random.randn(4, 32, 128))
print(outputs.shape)

ValueError: ignored