In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

In [2]:
key = random.PRNGKey(0)
x = random.normal(key, (10, ))
print(x.shape)



(10,)


In [3]:
size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
%time jnp.dot(x, x.T)
print("hello")

CPU times: user 4.74 s, sys: 106 ms, total: 4.85 s
Wall time: 1.59 s
hello


In [4]:
import numpy as np
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()

1.43 s ± 44.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [5]:
from jax import device_put

x = np.random.normal(size=(size, size)).astype(np.float32)
x = device_put(x)
%time jnp.dot(x, x.T)

CPU times: user 4.47 s, sys: 99 ms, total: 4.57 s
Wall time: 1.38 s


DeviceArray([[ 3.0269419e+03,  7.0578247e+01,  3.6624695e+01, ...,
              -2.0423191e+01, -9.6192074e+00,  7.2796245e+00],
             [ 7.0578247e+01,  2.9739919e+03,  1.1413579e+00, ...,
              -1.2797272e+01,  3.2100609e+01,  2.9788399e+01],
             [ 3.6624695e+01,  1.1413579e+00,  2.9718857e+03, ...,
              -4.9126926e+01, -2.4262154e+01,  3.4415054e+01],
             ...,
             [-2.0423191e+01, -1.2797272e+01, -4.9126926e+01, ...,
               2.9759958e+03, -7.2979441e+00, -2.8910929e+01],
             [-9.6192074e+00,  3.2100609e+01, -2.4262154e+01, ...,
              -7.2979441e+00,  2.8355730e+03, -8.5347986e+00],
             [ 7.2796245e+00,  2.9788399e+01,  3.4415054e+01, ...,
              -2.8910929e+01, -8.5347986e+00,  3.0271863e+03]],            dtype=float32)

In [10]:
def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = random.normal(key, (1000000, ))
%timeit selu(x).block_until_ready()

8.13 ms ± 345 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [11]:
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()

2.29 ms ± 54.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [15]:
def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.0)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))

[0.25       0.19661197 0.10499357]


In [16]:
def first_finite_differences(f, x):
    eps = 1e-3
    return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                      for v in jnp.eye(len(x))])

print(first_finite_differences(sum_logistic, x_small))

[0.24998187 0.1964569  0.10502338]


In [17]:
print(grad(grad(grad(sum_logistic)))(1.0))

0.19661197


In [23]:
from jax import jacfwd, jacrev
def hessian(fun):
    return jit(jacfwd(jacrev(fun)))

In [32]:
print(jacfwd(sum_logistic)(x_small))
print(jacrev(sum_logistic)(x_small))

#print(hessian(sum_logistic)(x_small))

[0.25       0.19661197 0.10499357]
[0.25       0.19661197 0.10499357]


In [33]:
mat = random.normal(key, (150, 100))
batched_x = random.normal(key, (10, 100))

def apply_matrix(v):
    return jnp.dot(mat, v)

In [49]:
@jit
def naively_batched_apply_matrix(v_batched):
    return jnp.stack([apply_matrix(v) for v in v_batched])

print("Naively batched")
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()
result = naively_batched_apply_matrix(batched_x)
print(result.shape)

Naively batched
54.5 µs ± 19.3 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
(10, 150)


In [52]:
@jit
def batched_apply_matrix(v_batched):
    return jnp.dot(v_batched, mat.T)

print("Manually batched")
%timeit batched_apply_matrix(batched_x).block_until_ready()
result = batched_apply_matrix(batched_x)
print(result.shape)

Manually batched
53.4 µs ± 391 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
(10, 150)


In [51]:
@jit
def vmap_batched_apply_matrix(v_batched):
    return vmap(apply_matrix)(v_batched)

print("Auto-vectorized with vmap")
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

Auto-vectorized with vmap
81.1 µs ± 695 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
