In [None]:
# JIT compiling a function

import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lambda_=1.05):
  """Scaled exponential linear unit, applied element-wise.

  Where ``x > 0`` the result is ``lambda_ * x``; everywhere else it is
  ``lambda_ * (alpha * exp(x) - alpha)``.
  """
  # Both branches are computed for every element; jnp.where then selects per element.
  negative_branch = alpha * jnp.exp(x) - alpha
  selected = jnp.where(x > 0, x, negative_branch)
  return lambda_ * selected

x = jnp.arange(1000000)
%timeit selu(x).block_until_ready()

selu_jit = jax.jit(selu)

selu_jit(x).block_until_ready()

%timeit selu_jit(x).block_until_ready()

1.05 ms ± 7.85 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
228 μs ± 1.69 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [None]:
# Inspect how JAX traces a function: jax.make_jaxpr shows the jaxpr (JAX's intermediate representation)

from jax import make_jaxpr

def f(x, y):
  """Return the dot product of ``x + 1`` and ``y + 1``."""
  x_shifted = x + 1
  y_shifted = y + 1
  return jnp.dot(x_shifted, y_shifted)

# y is a float array with the same shape as x from the first cell, so the traced
# function receives one integer (i32) and one float (f32) input.
y = jnp.ones(shape=x.shape)
# make_jaxpr(f) returns a function; calling it on example arguments traces f and
# returns its jaxpr, which becomes this cell's displayed output.
make_jaxpr(f)(x, y)

{ lambda ; a:i32[1000000] b:f32[1000000]. let
    c:i32[1000000] = add a 1:i32[]
    d:f32[1000000] = add b 1.0:f32[]
    e:f32[] = dot_general[
      dimension_numbers=(([0], [0]), ([], []))
      preferred_element_type=float32
    ] c d
  in (e,) }

In [None]:
# Mark arguments as static so JIT does not trace them (allowing Python control flow on them); note that each new value of a static argument triggers a recompilation

from functools import partial

# Use jax.jit: only the `jax` module is imported in this notebook, so a bare
# `jit` name is undefined and would raise NameError.
# NOTE: this redefinition shadows the `f` used in the make_jaxpr cell above.
@partial(jax.jit, static_argnums=(1,))
def f(x, neg):
  """Return ``-x`` if ``neg`` is truthy, otherwise ``x``.

  ``neg`` (argument index 1) is marked static: JAX does not trace it, so the
  plain Python ``if`` on it is allowed inside the jitted function. Its value
  must be hashable, and each new value of ``neg`` triggers a fresh
  trace and compilation.
  """
  return -x if neg else x

f(1, True)