When a function contains __control flow__ whose path depends on the __values__ of its inputs, it cannot be JIT-compiled as written.

When the path depends only on the __shape__ or __dtype__ of the inputs, the function can be JIT-compiled, but it is re-compiled each time it is called with inputs of a new shape or dtype.
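Here is a minimal sketch of shape-driven re-compilation, assuming a recent JAX release. The hypothetical counter `trace_count` is incremented by a Python side effect. That side effect runs only while JAX traces the function, so the counter records how many times compilation happened.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter; bumped only while JAX traces the function

@jax.jit
def total(x):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on cached calls
    y = 0.0
    for i in range(x.shape[0]):  # loop bound comes from the (static) shape
        y = y + x[i]
    return y

total(jnp.array([1., 2., 3.]))      # shape (3,): traced and compiled
total(jnp.array([4., 5., 6.]))      # same shape and dtype: cached code is reused
total(jnp.array([1., 2., 3., 4.]))  # shape (4,): traced and compiled again
print(trace_count)  # 2
```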

In [2]:
from jax import grad, jit
import jax.numpy as jnp

In [3]:
@jit 
def f(x):
    for i in range(3):
        x = 2*x
    return x
print(f(3))

24
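`range(3)` is an ordinary Python loop, so it runs during tracing and is unrolled. The traced program therefore contains three separate multiplications. The sketch below checks this with `jax.make_jaxpr`, which traces a function the same way `jit` does and returns the resulting jaxpr.

```python
import jax

def f(x):
    for i in range(3):
        x = 2 * x
    return x

jaxpr = jax.make_jaxpr(f)(3)
print(jaxpr)  # shows the traced program

# Each loop iteration became its own `mul` equation.
mul_count = [eqn.primitive.name for eqn in jaxpr.eqns].count("mul")
print(mul_count)  # 3
```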


In [4]:
@jit 
def g(x):
    y = 0. 
    for i in range(x.shape[0]):
        y = y + x[i]
    return y
print(g(jnp.array([1., 2., 3.])))

6.0


In [None]:
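# Each call below raises TracerBoolConversionError.
# The first error stops the cell, so run g(2) separately to see the second one.
@jit
def f(x):
    if x < 3:  # Python `if` needs a concrete bool, but x is traced
        return 3 * x ** 2
    else:
        return -4 * x
f(2)  # TracerBoolConversionError


@jit
def g(x):
    return (x > 0) and (x < 3)  # `and` also converts the traced value to bool
g(2)  # TracerBoolConversionError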
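One common workaround, not covered in this section, is to express value-dependent choices with array operations instead of Python `if` and `and`. The minimal sketch below uses `jnp.where`, which evaluates both branches and selects elementwise, together with `jnp.logical_and`.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Both branches are computed; jnp.where picks one based on the traced predicate.
    return jnp.where(x < 3, 3 * x ** 2, -4 * x)

@jax.jit
def g(x):
    # Elementwise logical AND works on traced values, unlike Python's `and`.
    return jnp.logical_and(x > 0, x < 3)

print(f(2), f(5))  # 12 -20
print(g(2), g(5))  # True False
```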

JIT compilation traces the function with __abstract values__ instead of concrete values, so that the same compiled code can be reused for many different inputs.

Specifically, it uses ``ShapedArray``, which represents the set of all arrays with a fixed __shape__ and __dtype__.

For example:
- ``ShapedArray((3,), jnp.float32)`` represents all arrays of shape ``(3,)`` and dtype ``float32``
- ``ShapedArray((), jnp.float32)`` represents all scalar arrays of dtype ``float32``
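You can inspect these abstract values directly with `jax.make_jaxpr`, which traces a function the way `jit` does. The `in_avals` of the returned jaxpr are the `ShapedArray`s used during tracing. A small sketch:

```python
import jax
import jax.numpy as jnp

def square(x):
    return x ** 2

vec = jax.make_jaxpr(square)(jnp.array([1., 2., 3.], dtype=jnp.float32))
scalar = jax.make_jaxpr(square)(jnp.float32(2.0))

# Each input is represented by a ShapedArray: only shape and dtype are kept.
print(vec.in_avals)     # abstract value for a float32 array of shape (3,)
print(scalar.in_avals)  # abstract value for a float32 scalar
print(vec)              # the traced program, with input type f32[3]
```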

In [7]:
@jit
def f(x):
  return x ** 2

# First call compiles the function
result1 = f(jnp.array([1., 2., 3.], dtype=jnp.float32))
print(result1)
# Second call reuses the compiled code
result2 = f(jnp.array([4., 5., 6.], dtype=jnp.float32))
print(result2)

[1. 4. 9.]
[16. 25. 36.]
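The compiled code is reused only while the abstract value stays the same. In the minimal sketch below, a hypothetical `traces` list is filled by a trace-time side effect. It shows that an input with a new dtype triggers a fresh trace.

```python
import jax
import jax.numpy as jnp

traces = []  # hypothetical log; appended to only while JAX traces the function

@jax.jit
def f(x):
    traces.append((x.shape, x.dtype))  # trace-time side effect
    return x ** 2

f(jnp.array([1., 2., 3.], dtype=jnp.float32))  # traced: ShapedArray((3,), float32)
f(jnp.array([4., 5., 6.], dtype=jnp.float32))  # same abstract value: reused
f(jnp.array([1, 2, 3], dtype=jnp.int32))       # new dtype: traced again
print(len(traces))  # 2
```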


__Trade-off: Abstraction vs. Traceability__

If the function contains control flow that depends on __runtime values__, JAX cannot trace it. An abstract ``ShapedArray`` stands for every array of a given shape and dtype, not for one specific concrete value.

In [None]:
@jit
def f(x):
  if x < 3:  # x is traced as ShapedArray((), float32), not a concrete value
    return 3. * x ** 2
  else:
    return -4 * x

f(2.0)  # Raises TracerBoolConversionError

Here, ``x < 3`` evaluates to the abstract ``ShapedArray((), jnp.bool_)``, which represents the set ``{True, False}``. JAX cannot decide which branch to follow because the value of $x$ is not concrete.
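You can inspect the abstract result of the comparison with `jax.make_jaxpr`. The minimal sketch below uses a hypothetical helper, `less_than_three`.

```python
import jax
import jax.numpy as jnp

def less_than_three(x):  # hypothetical helper containing only the comparison
    return x < 3

closed = jax.make_jaxpr(less_than_three)(2.0)
out = closed.out_avals[0]
print(out)  # a boolean scalar ShapedArray, not True or False
```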

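This problem can be handled with ``static_argnames`` or ``static_argnums``, which mark certain arguments as __static__. JAX uses the concrete values of static arguments during tracing and re-compiles the function whenever a static argument takes a new value. Static arguments must therefore be hashable.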

In [11]:
def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

f = jit(f, static_argnames='x')

print(f(2.))

12.0
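``static_argnums`` does the same job, but selects arguments by position. The minimal sketch below uses a hypothetical `trace_count` counter and assumes floats as static values (floats are hashable). Each distinct static value triggers a new trace, while a repeated value reuses the cached compilation.

```python
import jax

trace_count = 0  # hypothetical counter; bumped only while JAX traces f

def f(x):
    global trace_count
    trace_count += 1
    if x < 3:  # x is static, so this is an ordinary Python comparison
        return 3. * x ** 2
    else:
        return -4 * x

f_static = jax.jit(f, static_argnums=0)  # positional counterpart of static_argnames='x'

print(f_static(2.))  # 12.0
print(f_static(2.))  # 12.0, same static value: cached
print(f_static(5.))  # -20.0, new static value: traced again
print(trace_count)   # 2
```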


Similarly, ``static_argnames`` can be used to treat the loop bound as a static value. The loop is then unrolled during tracing, and the function is re-compiled for each new value of ``n``.

In [12]:
def f(x, n):
  y = 0.0
  for i in range(n):  # n is static, so range(n) receives a concrete Python int
    y = y + x[i]
  return y

# Use static_argnames to treat 'n' as a static value
f_jitted = jit(f, static_argnames='n')

result = f_jitted(jnp.array([2., 3., 4.]), 2)
print(result)  # Output: 5.0

5.0
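You can observe the unrolling by tracing with different static bounds: the number of operations in the traced program grows with ``n``. The sketch below uses `jax.make_jaxpr`, which also accepts `static_argnums`. The exact counts depend on the JAX version, so only the trend matters here.

```python
import jax
import jax.numpy as jnp

def f(x, n):
    y = 0.0
    for i in range(n):  # unrolled during tracing because n is static
        y = y + x[i]
    return y

x = jnp.arange(8.0)
eqn_counts = []
for n in (2, 4, 8):
    jaxpr = jax.make_jaxpr(f, static_argnums=1)(x, n)
    eqn_counts.append(len(jaxpr.eqns))  # one group of ops per loop iteration
print(eqn_counts)  # increases with n
```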


#### Structured Control Flow Primitives

- ``lax.cond``: differentiable (supports both forward- and reverse-mode differentiation)

In [None]:
# Pure-Python equivalent of lax.cond's semantics (pred must be a concrete bool here)
def cond(pred, true_fun, false_fun, arg):
  if pred:
    return true_fun(arg)
  else:
    return false_fun(arg)

In [13]:
from jax import lax

operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, operand)  # Array([1.], dtype=float32)

lax.cond(False, lambda x: x+1, lambda x: x-1, operand)  # Array([-1.], dtype=float32)

Array([-1.], dtype=float32)
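Unlike the pure-Python version, ``lax.cond`` accepts a traced predicate. The earlier value-dependent ``f`` can therefore be jitted without making ``x`` static, and it can also be differentiated. The minimal sketch below shows both. Both branches are traced, so they must return the same shape and dtype.

```python
import jax
from jax import lax

@jax.jit
def f(x):
    # x < 3 is traced; lax.cond chooses the branch when the compiled code runs.
    return lax.cond(x < 3, lambda v: 3. * v ** 2, lambda v: -4. * v, x)

print(f(2.))            # 12.0
print(f(5.))            # -20.0
print(jax.grad(f)(2.))  # 12.0, derivative of 3x^2 at x = 2
print(jax.grad(f)(5.))  # -4.0, derivative of -4x
```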