# JAX as NumPy

In [1]:
import jax.numpy as jnp

In [2]:
def selu(x, alpha=1.67, lmbda=1.05):
    """Scaled exponential linear unit, applied elementwise.

    Args:
        x: Input array (or scalar) accepted by ``jnp`` operations.
        alpha: Scale of the exponential branch used where ``x <= 0``.
        lmbda: Overall output scale applied to both branches.

    Returns:
        ``lmbda * x`` where ``x > 0`` and
        ``lmbda * (alpha * exp(x) - alpha)`` elsewhere.
    """
    # Value used wherever the input is not strictly positive.
    negative_branch = alpha * jnp.exp(x) - alpha
    selected = jnp.where(x > 0, x, negative_branch)
    return lmbda * selected

In [3]:
# Evaluate selu on the floats 0..4; the bare last expression is shown as the cell output.
x = jnp.arange(5.0)
selu(x)

Array([0.       , 1.05     , 2.1      , 3.1499999, 4.2      ], dtype=float32)

In [None]:
from jax import random

# Benchmark the plain (not yet jitted) selu on one million random samples.
# Note: this rebinds `x`, replacing the small array from the previous cell.
key = random.key(1701)
x = random.normal(key, shape=(1000000,))
%timeit selu(x).block_until_ready()  # without block_until_ready() the reported time would be misleading

138 μs ± 5.26 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [5]:
from jax import jit

# Wrap selu with jit and call it once up front so that compilation
# happens before the timing loop rather than inside it.
selu_jit = jit(selu)
_ = selu_jit(x)  # compiles on first call
%timeit selu_jit(x).block_until_ready()

17 μs ± 1.78 μs per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [None]:
from jax import grad


def sum_logistic(x):
    """Sum of the logistic function ``1 / (1 + exp(-x))`` over all elements of ``x``.

    Args:
        x: Array (or scalar) accepted by ``jnp`` operations.

    Returns:
        A scalar: the logistic values of ``x`` added together.
    """
    denominator = 1.0 + jnp.exp(-x)
    logistic_values = 1.0 / denominator
    return jnp.sum(logistic_values)


# Build the gradient function of sum_logistic with grad and evaluate it at [0., 1., 2.].
x_small = jnp.arange(3.0)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))

def first_finite_differences(f, x, eps=1e-3):
    """Approximate the gradient of ``f`` at ``x`` with central finite differences.

    For every coordinate ``i`` the estimate is
    ``(f(x + eps * e_i) - f(x - eps * e_i)) / (2 * eps)``, where ``e_i`` is the
    ``i``-th row of the identity matrix of size ``len(x)``.

    Args:
        f: Callable evaluated at perturbed copies of ``x``; expected to return
            a scalar so the differences stack into a vector.
        x: Point at which to differentiate; intended to be a 1-D array, since
            the perturbations are rows of ``jnp.eye(len(x))``.
        eps: Step size of the perturbation along each coordinate.

    Returns:
        Array holding one central-difference estimate per coordinate of ``x``.
    """
    # Each row of the identity matrix is a unit vector along one coordinate.
    return jnp.array(
        [(f(x + eps * v) - f(x - eps * v)) / (2 * eps) for v in jnp.eye(len(x))]
    )


# Numerical cross-check: these values should be close to the grad result printed above.
print(first_finite_differences(sum_logistic, x_small))

[0.25       0.19661194 0.10499357]
[0.24998187 0.1965761  0.10502338]


In [None]:
# Third derivative of sum_logistic at 1.0, built by nesting grad (with jit in between).
print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))  # grad is defined for scalar-valued functions

-0.0353256


In [None]:
from jax import jacobian

# Jacobian of jnp.exp at the 3-element x_small, printed as a 3x3 matrix.
print(jacobian(jnp.exp)(x_small))

[[1.        0.        0.       ]
 [0.        2.7182817 0.       ]
 [0.        0.        7.389056 ]]


In [None]:
from jax import jacfwd, jacrev


def hessian(fun):
    """Return a jit-compiled function that computes the Hessian of ``fun``.

    The Hessian is obtained by applying ``jacrev`` to ``fun`` and then
    ``jacfwd`` to that result; the composition is wrapped in ``jit``.

    Args:
        fun: Function to differentiate twice.

    Returns:
        A callable taking the same argument as ``fun`` and returning the
        matrix of second derivatives.
    """
    first_derivative = jacrev(fun)
    second_derivative = jacfwd(first_derivative)
    return jit(second_derivative)


# Hessian of sum_logistic at x_small; only the diagonal entries of the printed matrix are non-zero.
print(hessian(sum_logistic)(x_small))

[[-0.         -0.         -0.        ]
 [-0.         -0.09085774 -0.        ]
 [-0.         -0.         -0.07996248]]


In [None]:
# Split the key so the matrix and the input batch are drawn from different subkeys.
key1, key2 = random.split(key)
mat = random.normal(key1, (150, 100))  # matrix of shape (150, 100)
batched_x = random.normal(key2, (10, 100))  # batch of 10 vectors of length 100


@jit
def apply_matrix(x):
    """Multiply the notebook-level matrix ``mat`` by ``x`` using ``jnp.dot``.

    Args:
        x: Array whose leading dimension matches the last dimension of ``mat``.

    Returns:
        The product ``jnp.dot(mat, x)``.
    """
    product = jnp.dot(mat, x)
    return product


@jit
def naively_batched_apply_matrix(v_batched):
    """Apply ``apply_matrix`` to each slice along the leading axis and stack the results.

    Args:
        v_batched: Batch of inputs; iterating over it yields one input per step.

    Returns:
        The per-input results of ``apply_matrix`` combined with ``jnp.stack``.
    """
    per_input_results = []
    # Plain Python loop over the batch: one apply_matrix call per input.
    for single_input in v_batched:
        per_input_results.append(apply_matrix(single_input))
    return jnp.stack(per_input_results)


print("Naively batched")
%timeit naively_btched_apply_matrix(batched_x).block_until_ready()

import numpy as np


@jit
def batched_apply_matrix(batched_x):
    """Apply the notebook-level matrix ``mat`` to a whole batch with one ``jnp.dot``.

    Args:
        batched_x: Batch of inputs, one per row, whose last dimension matches
            the last dimension of ``mat``.

    Returns:
        ``jnp.dot(batched_x, mat.T)``: one result row per input row.
    """
    mat_transposed = mat.T
    return jnp.dot(batched_x, mat_transposed)


# Confirm the manually batched version agrees with the naive loop (within
# tolerance) before timing it.
np.testing.assert_allclose(
    naively_batched_apply_matrix(batched_x),
    batched_apply_matrix(batched_x),
    atol=1e-4,
    rtol=1e-4,
)
print("Manually batched")
%timeit batched_apply_matrix(batched_x).block_until_ready()

from jax import vmap


@jit
def vmap_batched_apply_matrix(batched_x):
    """Apply ``apply_matrix`` across a batch by wrapping it with ``vmap``.

    Args:
        batched_x: Batch of inputs passed to the ``vmap``-wrapped ``apply_matrix``.

    Returns:
        The batched result of ``vmap(apply_matrix)(batched_x)``.
    """
    apply_over_batch = vmap(apply_matrix)
    return apply_over_batch(batched_x)


# Confirm the vmap-based version agrees with the naive loop (within
# tolerance) before timing it.
np.testing.assert_allclose(
    naively_batched_apply_matrix(batched_x),
    vmap_batched_apply_matrix(batched_x),
    atol=1e-4,
    rtol=1e-4,
)
print("Auto-vectorized with vmap")
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

Naively batched
526 μs ± 4.39 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
Manually batched
18.1 μs ± 800 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
Auto-vectorized with vmap
20.2 μs ± 328 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
