# 08 JIT Control Flow and Logic

Original Documentation: https://docs.jax.dev/en/latest/control-flow.html


In [1]:
import jax
import jax.numpy as jnp

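When executing eagerly (outside of `jax.jit`), JAX works with ordinary Python control flow.

Under `jax.jit`, however, the function is traced with abstract values that carry only a shape and dtype, not the runtime value. Python control flow that depends on runtime values (for example `if x > 0:`) cannot be resolved during tracing, so JIT cannot produce a compiled representation for it.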

Branches that depend only on an argument's shape or dtype are fine, because those are known at trace time. The function is simply recompiled whenever it is called with an argument of a new shape or dtype.
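A minimal sketch of shape-dependent branching under `jax.jit`. The function `normalize` and the `traces` list are hypothetical, used only to count how often JAX traces the function: a Python side effect like `traces.append` runs while tracing, not on cached calls.

```python
import jax
import jax.numpy as jnp

traces = []  # hypothetical helper: records the input shape on every trace


@jax.jit
def normalize(x):
    traces.append(x.shape)  # Python side effect: runs only while tracing
    # x.ndim is static under jit, so this `if` is resolved at trace time
    if x.ndim == 1:
        return x / jnp.sum(x)
    return x / jnp.sum(x, axis=-1, keepdims=True)


out_a = normalize(jnp.array([1.0, 3.0]))                # traces for shape (2,)
out_b = normalize(jnp.array([2.0, 2.0]))                # same shape: cached
out_c = normalize(jnp.array([[1.0, 1.0], [2.0, 6.0]]))  # new shape: retraces
print(out_a, out_c, len(traces))
```

Only two traces happen for three calls, because the second call reuses the compilation for shape `(2,)`.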


In [2]:
@jax.jit
def f(x):
    return (x > 0) and (x < 3)  # This will fail! `and` calls bool() on a traced value in JIT.


f(1.0)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function f at /var/folders/r0/zlgwr7551fg116vdpn_9s7ph0000gn/T/ipykernel_31847/3756118353.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError

When we JIT compile a function, we want to compile a version that works for many different values, so we can cache and reuse the compiled XLA computation instead of recompiling on every call.

To keep using Python control flow on a runtime value, you can mark the argument as static with `static_argnums` or `static_argnames`. JAX then treats the value as a compile-time constant, so static arguments must be hashable, and every new value passed for a static argument requires recompiling the function.
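For example, the failing `f` above works once `x` is marked static. This sketch names it `f_static` (a hypothetical name, to avoid shadowing `f`) and uses `static_argnums`; each distinct value of `x` compiles a separate version.

```python
import functools

import jax


@functools.partial(jax.jit, static_argnums=0)
def f_static(x):
    # x is static: a concrete, hashable Python value, so `and` works here
    return (x > 0) and (x < 3)


print(f_static(1.0))  # compiles a version specialized to x = 1.0
print(f_static(5.0))  # new static value: compiles again
```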

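A Python `for` loop whose range is a constant or depends only on static arguments is unrolled: the body is traced once per iteration and every iteration is inlined into the compiled program, so long loops can make compilation slow.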
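A sketch of this unrolling, assuming a hypothetical `power` helper with a static exponent `n`. `jax.make_jaxpr` shows the traced program, and counting its `mul` equations confirms one multiplication per loop iteration.

```python
import functools

import jax
import jax.numpy as jnp


def power(x, n):
    # With n a Python int, range(n) runs during tracing, so the loop is
    # unrolled: n separate multiplications end up in the traced program
    result = jnp.ones_like(x)
    for _ in range(n):
        result = result * x
    return result


power_jit = jax.jit(power, static_argnums=1)
print(power_jit(2.0, 3))

# Inspect the traced program for n = 3: one `mul` per loop iteration
jaxpr = jax.make_jaxpr(functools.partial(power, n=3))(2.0)
n_mul = sum(eqn.primitive.name == "mul" for eqn in jaxpr.jaxpr.eqns)
print(n_mul)
```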

### Functions with argument-value-dependent shapes

A common case programmers forget to account for:

Functions that we want to JIT compile also cannot make the shapes of internal arrays depend on argument values, because array shapes must be known at trace time.


In [3]:
@jax.jit
def alloc(length, val):
    return jnp.arange(length) * val  # Length of array can't be value dependent!


alloc(10, 1.0)

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[]
It arose in the jnp.arange argument 'stop'
The error occurred while tracing the function alloc at /var/folders/r0/zlgwr7551fg116vdpn_9s7ph0000gn/T/ipykernel_31847/298980794.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument length.

See https://docs.jax.dev/en/latest/errors.html#jax.errors.ConcretizationTypeError
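One way to make `alloc` work, sketched below as a hypothetical `alloc_static`, is to mark `length` as static so the output shape is fixed at trace time.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnums=0)
def alloc_static(length, val):
    # `length` is static, so jnp.arange receives a concrete Python int and
    # the output shape is known at trace time; each new length recompiles
    return jnp.arange(length) * val


print(alloc_static(5, 2.0))
```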

## Structured control flow primitives

To express control flow on runtime values without marking arguments static (which forces a recompilation for every new value) and without unrolling loops, JAX provides structured control flow primitives in `jax.lax`.

### Conditions


In [4]:
from jax import lax


# Python version
def cond(pred, true_fun, false_fun, operand):
    if pred:
        return true_fun(operand)
    else:
        return false_fun(operand)


# JAX Version
# lax.cond(pred, true_fun, false_fun, operand)
lax.cond(True, lambda x: x + 1, lambda x: x - 1, 10)

Array(11, dtype=int32, weak_type=True)
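The point of `lax.cond` is that the predicate may be a traced value. A minimal sketch with a hypothetical `safe_log` helper, jitted so that `x > 0` is only known at runtime:

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def safe_log(x):
    # `x > 0` is a traced boolean that a Python `if` could not branch on;
    # lax.cond traces both branches and picks one at runtime
    return lax.cond(x > 0, jnp.log, jnp.zeros_like, x)


print(safe_log(jnp.array(jnp.e)))
print(safe_log(jnp.array(-1.0)))
```

Both branches must return values with the same shape and dtype, which is why the false branch uses `jnp.zeros_like` rather than a bare Python `0`.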

LAX also provides other functions for conditions:


In [5]:
# 1) Choose between two precomputed values of the same shape and dtype
# lax.select(pred, a, b)
a = jnp.arange(10) * 10
b = jnp.arange(10) * 5
lax.select(a[-1] > b[-1], a, b)

# 2) Like lax.cond, but pick between any number of branch functions
# lax.switch(index, branches, operand)
branches = [lambda x: x + 1, lambda x: x - 1, lambda x: x * 2]
lax.switch(0, branches, 10)

Array(11, dtype=int32, weak_type=True)
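A sketch of `lax.switch` with a traced index inside `jax.jit` (the name `apply_op` is hypothetical). Out-of-range indices are clamped to the valid range rather than raising an error. The second part uses `lax.select` with an element-wise predicate; unlike `lax.cond`, both operands are already computed before the selection.

```python
import jax
import jax.numpy as jnp
from jax import lax

branches = [lambda v: v + 1, lambda v: v - 1, lambda v: v * 2]


@jax.jit
def apply_op(index, x):
    # `index` is traced; lax.switch clamps it to [0, len(branches) - 1]
    return lax.switch(index, branches, x)


print(apply_op(2, 10), apply_op(7, 10), apply_op(-3, 10))

# lax.select with an element-wise predicate: here, an element-wise maximum
a = jnp.array([1, 5, 3])
b = jnp.array([4, 2, 6])
print(lax.select(a > b, a, b))
```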

The `jax.numpy` API also provides conditional functions:


In [6]:
# 1) Element-wise, broadcasting counterpart of lax.select
# jnp.where(pred, a, b)
a = jnp.arange(10)
jnp.where(a % 2 == 0, a, 0)

# 2) Apply funcs[i] where conds[i] holds: an element-wise counterpart of lax.switch
# jnp.piecewise(a, conds, funcs)
conds = [a % 2 == 0, a % 5 == 0]
funcs = [lambda x: x - 1, lambda x: x + 1, lambda x: x]  # Last func is for `else` case
jnp.piecewise(a, conds, funcs)

# 3) jnp.select(condlist, choicelist, default): the first true condition picks the value
jnp.select(conds, [0, 1], default=2)

Array([0, 2, 0, 2, 0, 1, 0, 2, 0, 2], dtype=int32, weak_type=True)
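Because `jnp.where` computes both value arrays for every element, a division in the discarded branch still runs. A common workaround, sketched below with a hypothetical `safe_reciprocal`, is to sanitize the input with a first `jnp.where` so the gradient stays finite at zero (with a single `where`, the gradient at `x = 0` would be `nan`).

```python
import jax
import jax.numpy as jnp


@jax.jit
def safe_reciprocal(x):
    # First where: replace zeros so 1 / safe_x never divides by zero
    safe_x = jnp.where(x == 0, 1.0, x)
    # Second where: pick the intended output for the zero elements
    return jnp.where(x == 0, 0.0, 1.0 / safe_x)


x = jnp.array([0.0, 2.0, -4.0])
y = safe_reciprocal(x)
g = jax.grad(lambda v: jnp.sum(safe_reciprocal(v)))(x)
print(y, g)
```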

### While Loops


In [7]:
# Python version
def while_loop(cond_fun, body_fun, init_val):
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val


# JAX version
# lax.while_loop(cond_fun, body_fun, init_val)
lax.while_loop(lambda x: x < 10, lambda x: x + 1, 0)

Array(10, dtype=int32, weak_type=True)
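A `lax.while_loop` sketch whose trip count depends on a runtime value and that carries a tuple of state. `count_halvings` is a hypothetical example function. Note that `lax.while_loop` supports forward-mode differentiation but not reverse-mode (`jax.grad`).

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def count_halvings(x):
    # carry = (value, steps); the number of iterations depends on x at runtime
    def cond_fun(carry):
        value, _ = carry
        return value > 1.0

    def body_fun(carry):
        value, steps = carry
        return value / 2.0, steps + 1

    _, steps = lax.while_loop(cond_fun, body_fun, (x, 0))
    return steps


print(count_halvings(jnp.array(8.0)), count_halvings(jnp.array(100.0)))
```

The carry must keep the same structure, shapes, and dtypes on every iteration, which is why `body_fun` returns a tuple shaped exactly like the initial value.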

### For Loops


In [8]:
# Python version
def for_loop(start, stop, body_fun, init_val):
    val = init_val
    for i in range(start, stop):
        val = body_fun(i, val)
    return val


# JAX version
# lax.fori_loop(lower, upper, body_fun, init_val)
print(lax.fori_loop(0, 10, lambda i, x: x + i, 0))

45
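With `lax.fori_loop` the bounds may also be traced. In this sketch (hypothetical `sum_below`, plus a hypothetical `sum_traces` list to count traces), the trip count is a runtime argument, so calls with different `n` reuse one compilation. Per the `fori_loop` documentation, when the trip count is not known at trace time the loop is implemented with a `while_loop`, so reverse-mode differentiation is unavailable; with Python-integer bounds it is implemented with `scan` instead.

```python
import jax
import jax.numpy as jnp
from jax import lax

sum_traces = []  # hypothetical helper: records each trace of sum_below


@jax.jit
def sum_below(n):
    sum_traces.append(1)  # Python side effect: runs only while tracing
    # `n` is traced, so the trip count is only known at runtime
    return lax.fori_loop(0, n, lambda i, acc: acc + i, jnp.int32(0))


print(sum_below(10), sum_below(100), len(sum_traces))
```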


## Logical Operators

`jax.numpy` provides `logical_and`, `logical_or`, and `logical_not`, which operate element-wise on arrays and can be run inside JIT-compiled functions.


In [9]:
def python_check_positive_even(x):
    is_even = x % 2 == 0
    return is_even and (x > 0)  # Short-circuits: when `is_even` is False, x > 0 is not evaluated


@jax.jit
def jax_check_positive_even(x):
    is_even = x % 2 == 0
    return jnp.logical_and(is_even, x > 0)  # Does not short-circuit, x > 0 is always evaluated


print(python_check_positive_even(24))
print(jax_check_positive_even(24))

True
True


Python's `and`, `or`, and `not` cannot be applied element-wise, because they call `bool()` on the whole array:


In [10]:
x = jnp.array([1, 2, 5])
print(jax_check_positive_even(x))  # Success! Can be applied element-wise
print(python_check_positive_even(x))  # Error! Cannot be applied element-wise

[False  True False]


ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
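For boolean arrays, the operators `&`, `|`, and `~` behave like `jnp.logical_and`, `jnp.logical_or`, and `jnp.logical_not`, and work both element-wise and under JIT. The hypothetical `check_positive_even_ops` below is the operator form of `jax_check_positive_even`; the parentheses matter because `&` binds more tightly than comparisons.

```python
import jax
import jax.numpy as jnp


@jax.jit
def check_positive_even_ops(x):
    # Without parentheses, `x % 2 == 0 & x > 0` would parse as a chained
    # comparison around `0 & x`, not as the intended logical AND
    return (x % 2 == 0) & (x > 0)


x = jnp.array([1, 2, 5, -4])
print(check_positive_even_ops(x))
print(~(x > 0))
```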

## Behavior in Grad

Note that these restrictions on Python control flow apply only in JIT contexts.

Python control flow works without issues under `jax.grad`, as long as the function is not also wrapped in `jax.jit`.
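A minimal sketch: `piecewise_square` (a hypothetical function) branches on the value of `x` with a plain Python `if`, and `jax.grad` differentiates whichever branch is taken. Wrapping it as `jax.jit(jax.grad(piecewise_square))` would fail again, because under `jit` the value of `x` is abstract.

```python
import jax


def piecewise_square(x):
    # Ordinary Python branching on the value of x: fine under jax.grad alone
    if x > 0:
        return x ** 2
    return -x


dfdx = jax.grad(piecewise_square)
print(dfdx(3.0))   # 6.0
print(dfdx(-2.0))  # -1.0
```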
