In [2]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import numpy as np

In [4]:
def selu(x, alpha=1.67, lmbda=1.05):
    """Scaled exponential linear unit, applied element-wise.

    Returns ``lmbda * x`` where ``x > 0`` and
    ``lmbda * alpha * (exp(x) - 1)`` elsewhere.

    Args:
        x: Input scalar or array.
        alpha: Scale of the negative (exponential) branch.
        lmbda: Overall output scale.

    Returns:
        The activated values, broadcast to the shape of ``x``.
    """
    # jnp.where evaluates both branches, so the exponential must never see
    # large positive inputs: it would overflow to inf and, under grad, the
    # masked-out branch would contribute 0 * inf = NaN. Feed it 0 wherever
    # the exponential branch is not selected.
    safe_x = jnp.where(x > 0, 0.0, x)
    # expm1 computes exp(x) - 1 without cancellation for x close to 0.
    return lmbda * jnp.where(x > 0, x, alpha * jnp.expm1(safe_x))

x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()

707 µs ± 72.6 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [5]:
# Wrap selu with jit and time the wrapped version on the same input x.
selu_jit = jit(selu)
# block_until_ready() is called on the result so the measurement waits for the value.
%timeit selu_jit(x).block_until_ready()

57.3 µs ± 503 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [6]:
def sum_logistic(x):
    """Return the sum of the logistic function 1 / (1 + exp(-x)) over ``x``."""
    denominator = 1.0 + jnp.exp(-x)
    logistic = 1.0 / denominator
    return jnp.sum(logistic)

# Evaluation points 0., 1., 2.
x_small = jnp.arange(3.)
# grad(sum_logistic) is a new function returning the gradient of sum_logistic
# with respect to its argument.
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))

[0.25       0.19661194 0.10499357]


In [16]:
# Show the function value and its gradient at x_small side by side.
sum_logistic(x_small), derivative_fn(x_small)

(Array(2.1118557, dtype=float32),
 Array([0.25      , 0.19661194, 0.10499357], dtype=float32))

In [7]:
def first_finite_differences(f, x, eps=1e-3):
    """Approximate the gradient of ``f`` at ``x`` with central differences.

    For every coordinate direction ``v`` the estimate is
    ``(f(x + eps * v) - f(x - eps * v)) / (2 * eps)``.

    Args:
        f: Function of one array argument; called as ``f(x + eps * v)``.
        x: 1-D array at which the gradient is approximated.
        eps: Step size along each coordinate. Defaults to ``1e-3``.

    Returns:
        Array with one finite-difference estimate per element of ``x``.
    """
    # Each row of the identity matrix is the unit vector of one coordinate.
    directions = jnp.eye(len(x))
    return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                      for v in directions])


# Numerical estimate of the gradient of sum_logistic at x_small via finite differences.
print(first_finite_differences(sum_logistic, x_small))

[0.24998187 0.1965761  0.10502338]


In [21]:
def simple(x):
    """Return the cube of ``x``."""
    cubed = x ** 3
    return cubed
def der(x):
    """Return ``3 * x**2``, the analytic derivative of ``x ** 3``."""
    squared = x**2
    return 3 * squared

# Sample points 0., 1., 2. and the autodiff derivative of `simple`.
x_l = jnp.arange(3.)
derivative_f = grad(simple)
# Left disabled: simple(x_l) is a vector, and jax.grad only differentiates
# scalar-output functions, so this call on the whole array is expected to raise.
#print(derivative_f(x_l))

In [22]:
# Evaluate the derivative one scalar element of x_l at a time.
derivative_f(x_l[0]),derivative_f(x_l[1]),derivative_f(x_l[2])

(Array(0., dtype=float32), Array(3., dtype=float32), Array(12., dtype=float32))

In [20]:
# Derivative of `simple` at the first sample point x_l[0].
derivative_f(x_l[0])

Array(0., dtype=float32)