Just-In-Time Compilation with JAX

https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html

In [None]:
import jax
import jax.numpy as jnp

In [None]:
global_scope = []


def some_unpure_func(x):
    """Return log2(x) while also performing side effects.

    Appending to the module-level ``global_scope`` list and printing are
    side effects that JAX does not know about, which makes this function
    impure; they do not appear in the jaxpr produced by ``make_jaxpr``.
    """
    global_scope.append(x)  # mutates state living outside the function
    print(f'inside unpure {x}')  # I/O is a side effect as well
    log_x = jnp.log(x)
    return log_x / jnp.log(2)


jax.make_jaxpr(some_unpure_func)(1)


In [None]:
def some_pure_func(x):
    """Return log2(x); the result depends only on ``x`` and nothing else is touched."""
    log_x = jnp.log(x)
    return log_x / jnp.log(2)

jax.make_jaxpr(some_pure_func)(1)

In [None]:
# Python control flow on *static* information such as ndim is resolved while
# tracing: make_jaxpr records only the branch actually taken, so the two
# jaxprs printed below differ (a constant 0 vs. the log ops).

def conditional_func(x):
    """Return log2(x) when ``x`` has at least two dimensions, otherwise 0.

    ``jnp.ndim`` is used rather than ``x.ndim`` so that Python scalars are
    accepted as well as arrays and tracers.
    """
    return jnp.log(x) / jnp.log(2) if jnp.ndim(x) >= 2 else 0

print(jax.make_jaxpr(conditional_func)(jnp.asarray([1])))
print(jax.make_jaxpr(conditional_func)(jnp.asarray([[1]])))


In [None]:
import matplotlib.pyplot as plt

def selu(x, alpha=1.67, lambda_=1.05):
    """Scaled exponential linear unit (SELU).

    Computes ``lambda_ * x`` where ``x > 0`` and
    ``lambda_ * alpha * (exp(x) - 1)`` elsewhere.

    Args:
        x: input array or scalar.
        alpha: scale of the negative branch. The default 1.67 appears to be
            a rounded form of the published SELU constant (~1.6733).
        lambda_: overall output scale. The default 1.05 appears to be a
            rounded form of the published SELU constant (~1.0507).

    Returns:
        An array with the same shape as ``x``.
    """
    # expm1(x) computes exp(x) - 1 accurately; the explicit subtraction
    # alpha * exp(x) - alpha loses most significant digits for x near 0.
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.expm1(x))

# Sample inputs on [-1, 1) in steps of 0.01 and plot the activation.
x = jnp.arange(-100, 100) * 0.01
plt.plot(x, selu(x))

jax.make_jaxpr(selu)(x)

In [None]:
# Baseline: time the un-jitted selu so it can be compared with selu_jit below.
%timeit selu(x).block_until_ready()

In [None]:
# Create a jit-compiled version of selu. jax.jit only wraps the function
# here; the compiled code is created on the first call (the warm-up below).
selu_jit = jax.jit(selu)


# One "warm-up" call to create the compiled code, so the %timeit below
# measures calls to the already-compiled function.
selu_jit(x).block_until_ready()

%timeit selu_jit(x).block_until_ready()

In [None]:
# jit cannot trace a Python `if` whose condition depends on the *value* of
# an argument (as opposed to static information such as its shape).
def another_unpure_func(x):
    """Return ``x`` when ``x > 1``, otherwise ``x * x``.

    Despite its name this function has no side effects; what breaks jit is
    the value-dependent branch, not impurity.
    """
    return x if x > 1 else x * x

# Under jit the value-dependent `if x > 1` in another_unpure_func raises.
# Catch the error so the demonstration below still runs in this cell.
try:
    jax.jit(another_unpure_func)(1.0)
except jax.errors.ConcretizationTypeError as err:
    print(f'jit failed as expected: {type(err).__name__}')

# It does work when not conditioned on value, but it is misleading:
# conditional_func branches on ndim, which is fixed at trace time, so for a
# scalar input the compiled function always returns 0.
jax.jit(conditional_func)(1.0)

In [None]:
# Separate out the jit-able parts of a function: the while-condition below
# needs a concrete value on every iteration, so only the loop body is jitted.

@jax.jit
def jittable_segment(x):
    """Return ``x + 1``, compiled with ``jax.jit``."""
    return x + 1

def un_jittable_func(x, n):
    """Return ``x`` plus the smallest non-negative integer that is ``>= n``.

    The loop runs as plain Python and calls the jitted ``jittable_segment``
    once per step.
    """
    counter = 0
    while counter < n:
        counter = jittable_segment(counter)
    return x + counter

# Time the Python while-loop whose body is the jitted jittable_segment.
%timeit un_jittable_func(10, 20)

In [None]:
def jittable_segment_2(x):
    """Un-jitted counterpart of ``jittable_segment``: return ``x + 1``."""
    return x + 1

def un_jittable_func_2(x, n):
    """Same loop as ``un_jittable_func``, but the body is not jitted.

    Returns ``x`` plus the smallest non-negative integer that is ``>= n``.
    """
    step = 0
    while step < n:
        step = jittable_segment_2(step)
    return x + step

# Same loop with an un-jitted body, for comparison with the timing above.
%timeit un_jittable_func_2(10, 20)

Generally, you want to jit the largest possible chunk of your computation (ideally the entire update step), because this gives the compiler the most freedom to optimize.