In [1]:
import jax.numpy as jnp
from jax import grad, jit
from jax import nn
from jax import random

In [2]:
# JAX keeps no global RNG state: every random draw takes an explicit key.
# This is the root key (seed 0) for the random inputs generated below.
key = random.PRNGKey(seed=0)

In [4]:
size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)

# We added that block_until_ready because JAX uses asynchronous execution by default
%timeit jnp.dot(x, x.T).block_until_ready()  # runs on the GPU

14 ms ± 4.18 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [5]:
import numpy as np
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()

30.5 ms ± 2.56 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [6]:
from jax import device_put

x = np.random.normal(size=(size, size)).astype(np.float32)
x = device_put(x)
%timeit jnp.dot(x, x.T).block_until_ready()

13 ms ± 442 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [7]:
def selu(x, alpha=1.67, lmbda=1.05):
  """Scaled exponential linear unit, applied element-wise.

  Returns ``lmbda * x`` where ``x > 0`` and ``lmbda * alpha * (exp(x) - 1)``
  elsewhere. The defaults appear to be rounded forms of the published SELU
  constants (alpha ~ 1.6733, lambda ~ 1.0507).

  Args:
    x: array-like input.
    alpha: scale of the negative (exponential) branch.
    lmbda: overall output scale (spelled this way because ``lambda`` is a
      Python keyword).

  Returns:
    Array broadcast to the shape of ``x``.
  """
  # jnp.where evaluates -- and differentiates -- both branches. Passing the raw
  # x to exp overflows to inf for large positive inputs, and the masked-out
  # gradient then becomes 0 * inf = NaN. Restrict the exp branch to its own
  # domain (x <= 0) so both values and gradients stay finite.
  x_nonpos = jnp.where(x > 0, 0.0, x)
  # expm1(x) is more accurate than exp(x) - 1 when x is close to 0.
  return lmbda * jnp.where(x > 0, x, alpha * jnp.expm1(x_nonpos))

x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()

1.54 ms ± 275 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [11]:
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()

86.5 µs ± 2.47 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [9]:
def sum_logistic(x):
  """Sum of the logistic sigmoid 1 / (1 + exp(-x)) over all elements of x.

  The result is a scalar, so ``grad(sum_logistic)`` yields the element-wise
  derivative sigmoid(x) * (1 - sigmoid(x)).
  """
  # The hand-written 1.0 / (1.0 + jnp.exp(-x)) overflows exp(-x) to inf for
  # very negative x; its gradient then multiplies 0 by inf and returns NaN.
  # nn.sigmoid computes the same function with a numerically stable derivative.
  return jnp.sum(nn.sigmoid(x))

In [10]:
# grad turns the scalar-valued sum_logistic into a function returning its
# derivative with respect to every element of the input.
derivative_fn = grad(sum_logistic)
x_small = jnp.arange(3.0)  # [0., 1., 2.]
print(derivative_fn(x_small))

[0.25       0.19661194 0.10499357]
