In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

In [2]:
# Build an explicit PRNG key from a fixed integer seed, then draw ten
# standard-normal samples with it and print them.
key = random.PRNGKey(0)
x = random.normal(key, shape=(10,))
print(x)

[-0.3721109   0.26423115 -0.18252768 -0.7368197  -0.44030377 -0.1521442
 -0.67135346 -0.5908641   0.73168886  0.5673026 ]


In [3]:
# Benchmark: product of a (size x size) float32 matrix with its transpose,
# using an array generated by jax.random.
# NOTE(review): `key` was already passed to random.normal in an earlier cell;
# JAX's PRNG guidance is to split a key rather than reuse it.
size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
# block_until_ready() makes the measurement wait for the result rather than
# only the asynchronous dispatch of the operation.
%timeit jnp.dot(x, x.T).block_until_ready()

145 ms ± 10 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [4]:
import numpy as np
# Same benchmark, but the operand is a NumPy array handed directly to jnp.dot.
# NOTE(review): np.random.normal is unseeded here, so the sampled values differ
# from run to run (only the timing is reported, not the values).
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()

149 ms ± 4.57 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [5]:
from jax import device_put

# Same benchmark again, after explicitly passing the NumPy array through
# device_put, so the timed jnp.dot receives the array returned by device_put
# instead of the raw NumPy array.
x = np.random.normal(size=(size, size)).astype(np.float32)
x = device_put(x)
%timeit jnp.dot(x, x.T).block_until_ready()

144 ms ± 6.11 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [7]:
def selu(x, alpha=1.67, lmbda=1.05):
    """Scaled exponential linear unit, applied elementwise.

    Returns ``lmbda * x`` where ``x > 0`` and
    ``lmbda * (alpha * exp(x) - alpha)`` elsewhere.

    Args:
        x: Input array or scalar.
        alpha: Scale of the exponential branch.
        lmbda: Overall output scale.

    Returns:
        Array with the same shape as ``x``.
    """
    positive = x > 0
    # Feed exp() a value that cannot overflow. For positive inputs the
    # exponential branch is discarded by the outer where, but an overflowing
    # exp(x) there still poisons the gradient (0 * inf = NaN) when the
    # function is differentiated. Forward values are unchanged.
    safe_x = jnp.where(positive, 0.0, x)
    return lmbda * jnp.where(positive, x, alpha * jnp.exp(safe_x) - alpha)

# One million standard-normal samples, used as the input when timing selu.
x = random.normal(key, shape=(1_000_000,))

In [8]:
# Time the un-jitted selu; block_until_ready() makes the measurement wait for
# the result rather than only the asynchronous dispatch.
%timeit selu(x).block_until_ready()

1.48 ms ± 42.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [9]:
# Wrap selu with jit and time the wrapped version on the same input, for
# comparison with the un-jitted timing.
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()

415 µs ± 8.56 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [11]:
def sum_logistic(x):
    """Return the sum over ``x`` of the elementwise logistic ``1 / (1 + exp(-x))``."""
    logistic = 1.0 / (1.0 + jnp.exp(-x))
    return jnp.sum(logistic)

# Small float input [0., 1., 2.] for the gradient examples; the bare name on
# the last line displays it as the cell output.
x_small = jnp.arange(3.0)
x_small

Array([0., 1., 2.], dtype=float32)

In [20]:
# Forward evaluation of sum_logistic on the small input.
sum_logistic(x_small)

Array(2.1118555, dtype=float32)

In [13]:
# grad(sum_logistic) is a new function that evaluates the gradient of
# sum_logistic with respect to its argument.
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))

[0.25       0.19661197 0.10499357]


In [14]:
# Identity matrix with one row per element of x_small; its rows are the unit
# vectors that first_finite_differences steps along.
jnp.eye(len(x_small))

Array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)

In [18]:
def first_finite_differences(f, x, eps=1e-3):
    """Approximate the gradient of ``f`` at ``x`` with central finite differences.

    Args:
        f: Scalar-valued function of a 1-D array.
        x: 1-D point at which to estimate the gradient.
        eps: Step taken along each coordinate direction. Defaults to ``1e-3``,
            the value that was previously hard-coded.

    Returns:
        Array with ``len(x)`` entries; entry ``i`` is
        ``(f(x + eps*e_i) - f(x - eps*e_i)) / (2*eps)``.
    """
    # Each row of the identity matrix is the unit vector e_i for one coordinate.
    directions = jnp.eye(len(x))
    estimates = [
        (f(x + eps * e_i) - f(x - eps * e_i)) / (2 * eps)
        for e_i in directions
    ]
    return jnp.array(estimates)

In [19]:
# Numerical estimate of the gradient of sum_logistic at x_small, for comparison
# with the values printed from grad(sum_logistic).
first_finite_differences(sum_logistic, x_small)

Array([0.24998187, 0.1964569 , 0.10502338], dtype=float32)

In [26]:
# grad applied to a jit-wrapped function, evaluated at the scalar 0.0.
grad(jit(sum_logistic))(0.0)

Array(0.25, dtype=float32, weak_type=True)

In [25]:
# The transformations nest: jit of grad of jit, evaluated at the same point.
jit(grad(jit(sum_logistic)))(0.0)

Array(0.25, dtype=float32, weak_type=True)

In [29]:
# Three nested grad calls (each around a jit wrapper): the third derivative of
# sum_logistic, evaluated at the scalar 1.0.
grad(jit(grad(jit(grad(jit(sum_logistic))))))(1.0)

Array(-0.0353256, dtype=float32, weak_type=True)

In [32]:
# Draw the matrix and the batch from separate subkeys. Passing the same key to
# both random.normal calls reuses PRNG state, which JAX's random documentation
# says to avoid; splitting gives each array its own key.
mat_key, batch_key = random.split(key)
mat = random.normal(mat_key, (150, 100))
batched_x = random.normal(batch_key, (10, 100))

def apply_matrix(v):
    """Return ``jnp.dot(mat, v)``, the product of the module-level ``mat`` with ``v``."""
    product = jnp.dot(mat, v)
    return product

In [37]:
# Print the operand and result shapes for every vector in the batch.
# The loop variable is `v`, not `x`, so the loop does not overwrite the
# module-level `x` that the selu timing cells read when they are re-run.
for v in batched_x:
    print(mat.shape)
    print(v.shape)
    print(apply_matrix(v).shape)

(150, 100)
(100,)
(150,)
(150, 100)
(100,)
(150,)
(150, 100)
(100,)
(150,)
(150, 100)
(100,)
(150,)
(150, 100)
(100,)
(150,)
(150, 100)
(100,)
(150,)
(150, 100)
(100,)
(150,)
(150, 100)
(100,)
(150,)
(150, 100)
(100,)
(150,)
(150, 100)
(100,)
(150,)


In [33]:
def naively_batched_apply_matrix(v_batched):
    """Call ``apply_matrix`` on each item of ``v_batched`` in a Python loop and stack the results."""
    outputs = []
    for vector in v_batched:
        outputs.append(apply_matrix(vector))
    return jnp.stack(outputs)

In [38]:
# Check the shape of the stacked result of the Python-loop version.
naively_batched_apply_matrix(batched_x).block_until_ready().shape

(10, 150)

In [34]:
# Time the Python-loop version; block_until_ready() makes the measurement wait
# for the result rather than only the asynchronous dispatch.
print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

Naively batched
480 µs ± 2.02 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [41]:
@jit
def batched_apply_matrix(v_batched):
    """Apply ``mat`` to the whole batch at once as ``jnp.dot(v_batched, mat.T)``."""
    mat_transposed = mat.T
    return jnp.dot(v_batched, mat_transposed)

# Time the hand-written batched version (a single jnp.dot under jit).
print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()

Manually batched
9.59 µs ± 76 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [42]:
@jit
def vmap_batched_apply_matrix(v_batched):
    """Map ``apply_matrix`` over the leading axis of ``v_batched`` using ``vmap``."""
    batched_fn = vmap(apply_matrix)
    return batched_fn(v_batched)

# Time the vmap-based version for comparison with the two timings above it.
print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

Auto-vectorized with vmap
13.7 µs ± 36.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
