# JAX - The Sharp Bits

In [None]:
import numpy as np
from jax import grad, jit
from jax import lax
from jax import random
import jax
import jax.numpy as jnp

## Pure functions

In [None]:
def impure_print_side_effect(x):
    """Return ``x`` unchanged after printing a message.

    The ``print`` call is a side effect. Under ``jit`` it runs while JAX traces
    the Python function, but not on later calls that reuse a cached
    compilation.
    """
    message = "Executing function"
    print(message)  # side effect: not part of the compiled computation
    return x


# The side effects appear during the first run, while JAX traces the function
print("First call: ", jit(impure_print_side_effect)(4.0))

# Subsequent runs with parameters of the same type and shape may not show the side effect
# This is because JAX now invokes a cached compilation of the function
print("Second call: ", jit(impure_print_side_effect)(5.0))

# JAX re-runs the Python function when the type or shape of the argument changes
# (here a Python float scalar is replaced by a one-element array)
print("Third call, different type: ", jit(impure_print_side_effect)(jnp.array([5.0])))

Executing function
First call:  4.0
Second call:  5.0
Executing function
Third call, different type:  [5.]


In [None]:
g = 0.0  # module-level value read by impure_uses_globals


def impure_uses_globals(x):
    """Return ``x`` plus the current module-level ``g``.

    Impure: the result depends on global state. Under ``jit`` the global is
    read while the function is traced, so a cached compilation keeps using the
    value ``g`` had at trace time.
    """
    offset = g
    return x + offset


# JAX captures the value of the global during the first run (while tracing)
print("First call: ", jit(impure_uses_globals)(4.0))
g = 10.0  # Update the global

# Subsequent runs may silently use the cached value of the global:
# this prints 5.0, not 15.0
print("Second call: ", jit(impure_uses_globals)(5.0))

# JAX re-runs the Python function when the type or shape of the argument changes
# This will end up reading the latest value of the global (4.0 + 10.0 -> [14.])
print("Third call, different type: ", jit(impure_uses_globals)(jnp.array([4.0])))

First call:  4.0
Second call:  5.0
Third call, different type:  [14.]


In [None]:
g = 0.  # module-level slot overwritten by impure_saves_global


def impure_saves_global(x):
    """Store ``x`` in the module-level ``g`` and return ``x``.

    Impure: writing a global is a side effect. Under ``jit`` the assignment
    happens during tracing, so ``g`` ends up holding a traced value rather
    than the concrete argument.
    """
    global g
    g = saved = x
    return saved


# JAX runs the transformed function once, with special Traced values as the
# arguments, so the assignment inside it stores a tracer rather than 4.0
print("First call: ", jit(impure_saves_global)(4.0))
print("Saved global: ", g)  # Saved global holds a Traced value, not a number

First call:  4.0
Saved global:  Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>


In [11]:
def pure_uses_internal_state(x):
    """Return ``x`` accumulated ten times, using a local dict as scratch state.

    A Python function can be functionally pure even if it uses stateful
    objects internally, as long as it does not read or write external state.
    Here the mutable ``state`` dict is created fresh on every call and never
    escapes, so the function can be ``jit``-compiled safely.
    """
    state = {"even": 0, "odd": 0}
    for step in range(10):
        # Alternate between the two accumulators; each receives x five times.
        bucket = "odd" if step % 2 else "even"
        state[bucket] += x
    return state["even"] + state["odd"]

# 5.0 is added ten times inside the function, so this prints 50.0
print(jit(pure_uses_internal_state)(5.0))

50.0


In [None]:
# It is not recommended to use iterators in any JAX function you want to jit,
# or in any control-flow primitive (e.g. lax.fori_loop, lax.scan, lax.cond).
# An iterator is a stateful Python object: retrieving the next element mutates
# it, which is incompatible with JAX's functional programming model. The code
# below shows some incorrect attempts to use iterators with JAX. Most of them
# raise an error, but some silently give unexpected results.
import jax.numpy as jnp
from jax import make_jaxpr

# lax.fori_loop
# Indexing a captured array works: the loop sums 0 + 1 + ... + 9.
array = jnp.arange(10)
print(lax.fori_loop(0, 10, lambda i,x: x+array[i], 0)) # expected result 45
# With an iterator, next() evidently runs only once, while the body is traced,
# yielding 0; that constant is then added on every iteration.
iterator = iter(range(10))
print(lax.fori_loop(0, 10, lambda i,x: x+next(iterator), 0)) # unexpected result 0

# lax.scan
def func11(arr, extra):
    """Scan over ``arr`` paired element-wise with an all-ones array.

    Starting from a carry of ``0.``, each step adds ``elem * one + extra`` to
    the carry and emits the carry value from *before* the update. Returns the
    ``(final_carry, stacked_outputs)`` pair produced by ``lax.scan``.
    """
    def step(acc, pair):
        elem, one = pair
        updated = acc + elem * one + extra
        return updated, acc

    multipliers = jnp.ones(arr.shape)
    return lax.scan(step, 0., (arr, multipliers))
# Tracing func11 with an array argument succeeds (the jaxpr is not printed here)
make_jaxpr(func11)(jnp.arange(16), 5.)
# Passing an iterator in place of the array does not work:
# make_jaxpr(func11)(iter(range(16)), 5.) # throws error

# lax.cond
# Works with an array operand; the result of this call is discarded.
array_operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, array_operand)
# Only used by the commented-out call below.
iter_operand = iter(range(10))
# lax.cond(True, lambda x: next(x)+1, lambda x: next(x)-1, iter_operand) # throws error