In [29]:
import matplotlib.pyplot as plt

import jax.numpy as jnp
from jax import random, jit, grad, jacobian

In [6]:
def selu(x, alpha=1.67, lmbda=1.05):
    """Scaled exponential linear unit, applied elementwise.

    Args:
        x: input array.
        alpha: scale of the exponential branch used where ``x <= 0``.
        lmbda: overall multiplier applied to both branches.

    Returns:
        ``lmbda * x`` where ``x > 0`` and
        ``lmbda * (alpha * exp(x) - alpha)`` elsewhere.
    """
    # Both branches are evaluated for every element; jnp.where then selects.
    exp_branch = alpha * jnp.exp(x) - alpha
    activated = jnp.where(x > 0, x, exp_branch)
    return lmbda * activated

# Quick check: evaluate selu on 0.0, 1.0, ..., 4.0. Only the first entry
# (x == 0) takes the exponential branch, which evaluates to 0 there.
x = jnp.arange(5.0)
print(selu(x))

[0.        1.05      2.1       3.1499999 4.2      ]


In [12]:
key = random.key(1111)

x = random.normal(key, (1_000_000))

selu_jit = jit(selu)
%timeit selu(x).block_until_ready()

3.51 ms ± 241 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [25]:
from jax import grad

def sum_logistic(x):
  """Return the sum over all elements of the logistic function 1 / (1 + exp(-x))."""
  logistic = 1.0 / (1.0 + jnp.exp(-x))
  return jnp.sum(logistic)

x_small = jnp.arange(3.)
# grad(sum_logistic) is itself a function; calling it on x_small evaluates
# the gradient of sum_logistic at that point (one entry per input element).
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))

[0.25       0.19661197 0.10499357]


In [None]:
def first_finite_differences(f, x, eps=1E-3):
  """Approximate the gradient of ``f`` at ``x`` with central differences.

  Args:
    f: function mapping an array like ``x`` to a scalar.
    x: point at which to differentiate; ``len(x)`` sets the number of
      directions probed.
    eps: half-width of the central-difference step.

  Returns:
    Array with one entry per row of ``jnp.eye(len(x))``, each equal to
    ``(f(x + eps * v) - f(x - eps * v)) / (2 * eps)``.
  """
  estimates = []
  # Each row of the identity matrix is a unit vector along one coordinate.
  for direction in jnp.eye(len(x)):
    step = eps * direction
    estimates.append((f(x + step) - f(x - step)) / (2 * eps))
  return jnp.array(estimates)

# Numerical cross-check: central-difference estimate of the gradient of
# sum_logistic at x_small, for comparison with derivative_fn(x_small).
print(first_finite_differences(sum_logistic, x_small))

[0.24998187 0.1965761  0.10502338]


In [32]:
# jnp.exp acts elementwise, so its Jacobian at x_small is a diagonal matrix
# with exp(x_small) on the diagonal.
print(jacobian(jnp.exp)(x_small))

[[1.        0.        0.       ]
 [0.        2.7182817 0.       ]
 [0.        0.        7.389056 ]]
