## JAX

In [4]:
import jax.numpy as jnp

In [5]:
def relu(x):
    return jnp.maximum(0, x)

def softmax(x):
    return jnp.exp(x) / jnp.sum(jnp.exp(x), axis=0)

def cross_entropy(x, y):
    return jnp.sum(y * jnp.log(x))


In [6]:
x = jnp.arange(0, 6)
print(relu(x))
print(softmax(x))
print(cross_entropy(softmax(x), x))

[0 1 2 3 4 5]
[0.00426978 0.01160646 0.03154963 0.08576079 0.23312199 0.6336913 ]
-26.8429


In [7]:
from jax import random

key = random.PRNGKey(0)
key, subkey = random.split(key)
print(subkey)

[ 928981903 3453687069]


In [8]:
x = random.normal(key, (1_000_000,))
%timeit relu(x)

92.2 μs ± 6.7 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [9]:
from jax import jit

jitted_relu = jit(relu)
a = jitted_relu(x)
%timeit relu(a)

88.9 μs ± 7.33 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [11]:
def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

key = random.key(1000)
x = random.normal(key, (1_000_000,))
%timeit selu(x).block_until_ready()

1.02 ms ± 170 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [12]:
selu_jit = jit(selu)
_= selu_jit(x)
%timeit selu_jit(x).block_until_ready()

78.3 μs ± 5.11 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [13]:
def mm(x, y):
    return jnp.dot(x, y)

@jit
def mm_jit(x, y):
    return jnp.dot(x, y)

a = random.normal(key, (1000, 1000))
b = random.normal(key, (1000, 1000))

In [14]:
%timeit mm(a, b)

58.6 μs ± 8.97 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [15]:
%timeit mm_jit(a, b)

56.2 μs ± 268 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
