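With Python __control flow__, if the path taken depends on the __values__ of the inputs, the function cannot be JIT-compiled as written.

But when the path depends only on the __shape__ or __dtype__ of the inputs, the function can be JIT-compiled. It is re-traced and re-compiled only when it is called with a new shape or dtype; otherwise the cached compiled version is reused.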

In [25]:
from jax import grad, jit
import jax.numpy as jnp

In [26]:
@jit 
def f(x):
    for i in range(3):
        x = 2*x
    return x
print(f(3))

24


In [27]:
@jit 
def g(x):
    y = 0. 
    for i in range(x.shape[0]):
        y = y + x[i]
    return y
print(g(jnp.array([1., 2., 3.])))

6.0


In [28]:
# # These will throw errors:
# @jit
# def f(x):
#     if x < 3:
#         return 3 * x ** 2
#     else:
#         return -4 * x
# f(2)  # TracerBoolConversionError


# @jit
# def g(x):
#     return (x > 0) and (x < 3)
# g(2)  # TracerBoolConversionError
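The errors can also be observed without commenting the code out. This sketch assumes a recent JAX in which `jax.errors.TracerBoolConversionError` is a subclass of `jax.errors.ConcretizationTypeError`; catching the parent class should keep it working on older releases that may raise the parent directly. The names `branchy` and `in_range` are made up for this example.

```python
import jax
from jax import jit

@jit
def branchy(x):
    if x < 3:  # Python `if` needs a concrete bool, but x is a tracer here
        return 3 * x ** 2
    else:
        return -4 * x

@jit
def in_range(x):
    return (x > 0) and (x < 3)  # `and` also calls bool() on a tracer

errors = []
for fn in (branchy, in_range):
    try:
        fn(2)
    except jax.errors.ConcretizationTypeError as e:
        # record which error type each function raised
        errors.append(type(e).__name__)
print(errors)
```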

JIT compilation lets a single compiled function be reused for many evaluations. To make that possible, ``jit`` traces the function with __abstract values__ instead of concrete values.

Specifically, it uses ``ShapedArray``, which represents the set of all arrays with a fixed __shape__ and __dtype__.

For example:
- ``ShapedArray((3,), jnp.float32)`` represents all arrays of shape ``(3,)`` and dtype ``float32``
- ``ShapedArray((), jnp.float32)`` represents all scalar arrays of dtype ``float32``
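One way to see abstract evaluation directly is `jax.eval_shape`, which runs a function on shape/dtype descriptions (`jax.ShapeDtypeStruct`) instead of real arrays. This is a sketch; `square` and `show` are illustrative names, and the exact text printed for a tracer differs between JAX versions.

```python
import jax
import jax.numpy as jnp

def square(x):
    return x ** 2

# Describe an input purely by shape and dtype: no concrete data is involved
spec = jax.ShapeDtypeStruct((3,), jnp.float32)
out = jax.eval_shape(square, spec)
print(out.shape, out.dtype)  # (3,) float32

@jax.jit
def show(x):
    # Runs at trace time: prints a tracer carrying a ShapedArray, not the numbers
    print("traced with:", x)
    return x ** 2

show(jnp.ones(3))
```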

In [29]:
@jit
def f(x):
  return x ** 2

# First call compiles the function
result1 = f(jnp.array([1., 2., 3.], dtype=jnp.float32))
print(result1)
# Second call reuses the compiled code
result2 = f(jnp.array([4., 5., 6.], dtype=jnp.float32))
print(result2)

[1. 4. 9.]
[16. 25. 36.]


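__Trade-off: Abstraction vs. Traceability__

If the function contains control flow that depends on __runtime values__, JAX cannot trace it, because the abstract value ``ShapedArray`` stands for every array with a given shape and dtype rather than one specific concrete value. More abstract tracing values let one compilation serve more inputs, but leave less information for Python code to inspect while tracing.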

In [30]:
# @jit
# def f(x):
#   if x < 3:  # x is a tracer with abstract value ShapedArray((), jnp.float32), not a concrete value
#     return 3. * x ** 2
#   else:
#     return -4 * x

# f(2.0)  # Raises TracerBoolConversionError

Here, ``x < 3`` evaluates to the abstract value ``ShapedArray((), jnp.bool_)``, which represents the set ``{True, False}``. JAX cannot decide which branch to follow because the value of $x$ is not concrete.

This problem can be handled with ``static_argnames`` or ``static_argnums``, which mark selected arguments as __static__: their concrete values are available during tracing, and the function is re-compiled whenever a static argument takes a new value. Static arguments must be hashable.

In [31]:
def f(x):
  if x < 3:
    return 3. * x ** 2
  else:
    return -4 * x

f = jit(f, static_argnames='x')

print(f(2.))

12.0
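A sketch of the same idea using ``static_argnums`` (position 0 is ``x``), assuming a recent JAX. The counter ``trace_count`` and the name ``h`` are illustrative; the counter shows that calling again with an equal static value reuses the cached compilation, while a new static value triggers a fresh trace.

```python
from jax import jit

trace_count = 0  # hypothetical counter, incremented only while tracing

def h(x):
    global trace_count
    trace_count += 1
    if x < 3:  # fine: x is static, so it is a concrete Python float here
        return 3. * x ** 2
    else:
        return -4 * x

h_jit = jit(h, static_argnums=0)  # equivalent to static_argnames='x'

print(h_jit(2.))    # 12.0, first trace
print(h_jit(2.))    # 12.0, cached: equal static value
print(h_jit(5.))    # -20.0, new static value triggers a new trace
print(trace_count)  # 2
```

Because static arguments are used as cache keys, a `jnp` array cannot be passed in a static position; Python scalars and tuples work.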


Similarly, ``static_argnames`` can be used to treat the loop bound as a static value. This causes the loop to be unrolled at compile time.

In [32]:
def f(x, n):
  y = 0.0
  for i in range(n):  # n is static here, so range(n) sees a concrete Python int
    y = y + x[i]
  return y

# Use static_argnames to treat 'n' as a static value
f_jitted = jit(f, static_argnames='n')

result = f_jitted(jnp.array([2., 3., 4.]), 2)
print(result)  # Output: 5.0

5.0
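Unrolling means every distinct ``n`` produces a new compilation, and the compiled program grows with ``n``. One alternative, swapped in here rather than taken from the cell above, is to leave ``n`` traced and loop with ``lax.fori_loop``, which lowers to a while loop when its bounds are not concrete. The name ``partial_sum`` is illustrative.

```python
import jax.numpy as jnp
from jax import jit, lax

@jit
def partial_sum(x, n):
    # n stays a traced value, so one compilation serves every n
    def body(i, y):
        return y + x[i]  # dynamic indexing with the traced loop counter
    return lax.fori_loop(0, n, body, jnp.zeros((), dtype=x.dtype))

x = jnp.array([2., 3., 4.])
print(partial_sum(x, 2))  # 5.0
print(partial_sum(x, 3))  # 9.0, same compiled function reused
```

The trade-off: a traced bound keeps compilation count at one but loses reverse-mode differentiation through the loop.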


#### Structured Control Flow Primitives

- ``lax.cond``: applies one of two functions, selected by a scalar boolean predicate; differentiable

In [33]:
# Python equivalent
def cond(pred, true_fun, false_fun, arg):
  if pred:
    return true_fun(arg)
  else:
    return false_fun(arg)

In [35]:
from jax import lax 

operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, operand)   # Output: Array([1.], dtype=float32)
lax.cond(False, lambda x: x+1, lambda x: x-1, operand)  # Output: Array([-1.], dtype=float32)

Array([-1.], dtype=float32)
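The earlier value-dependent ``if`` can be rewritten with ``lax.cond`` so that it compiles under ``jit`` and can still be differentiated. This is a sketch; ``branch_on_value`` is a hypothetical name. Both branches must return outputs with the same shape and dtype.

```python
from jax import grad, jit, lax

@jit
def branch_on_value(x):
    # The predicate x < 3 may be traced: lax.cond picks the branch at run time
    return lax.cond(x < 3,
                    lambda x: 3. * x ** 2,  # true branch
                    lambda x: -4. * x,      # false branch
                    x)

print(branch_on_value(2.))        # 12.0
print(branch_on_value(5.))        # -20.0
print(grad(branch_on_value)(2.))  # 12.0, since d/dx 3x^2 = 6x
```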

``lax.select`` chooses element-wise between two arrays based on a boolean array condition; ``on_true`` and ``on_false`` must have the same shape and dtype.

In [37]:
condition = jnp.array([True, False, True])
on_true = jnp.array([1., 2., 3.])
on_false = jnp.array([-1., -2., -3.])

result = lax.select(condition, on_true, on_false)  
print(result)  # Output: [ 1. -2.  3.]

[ 1. -2.  3.]
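Because the condition is an ordinary array, ``lax.select`` also works when it is computed from traced values. A minimal sketch of a ReLU under ``jit``; ``relu`` here is a hand-written illustration, not a library function.

```python
import jax.numpy as jnp
from jax import jit, lax

@jit
def relu(x):
    # on_true and on_false must match in shape and dtype, hence zeros_like(x)
    return lax.select(x > 0, x, jnp.zeros_like(x))

print(relu(jnp.array([-1., 0., 2.])))  # [0. 0. 2.]
```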


``lax.switch`` chooses between __multiple callables__ based on a single scalar integer index.

In [38]:
def f1(x):
    return 2*x
def f2(x):
    return 3*x
def f3(x):
    return 4*x

# Use lax.switch to pick a function by index (index 2 selects f3)
index = 2 
x = 3
result = lax.switch(index, [f1, f2, f3], x)
print(result) #Output: 12

12
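Under ``jit`` the index can be traced as well. According to the JAX documentation, an out-of-range index is clamped into ``[0, len(branches) - 1]`` rather than raising an error. A sketch; ``branches`` and ``apply_branch`` are illustrative names.

```python
from jax import jit, lax

branches = [lambda x: 2 * x, lambda x: 3 * x, lambda x: 4 * x]

@jit
def apply_branch(index, x):
    # index may be traced; out-of-range values are clamped to a valid branch
    return lax.switch(index, branches, x)

print(apply_branch(0, 3.))   # 6.0
print(apply_branch(2, 3.))   # 12.0
print(apply_branch(7, 3.))   # 12.0 (clamped to the last branch)
print(apply_branch(-1, 3.))  # 6.0 (clamped to the first branch)
```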


In its three-argument form, ``jnp.where`` is a NumPy-style wrapper around __lax.select__: it chooses element-wise between two arrays based on a boolean condition, and it also broadcasts its arguments.

In [39]:
condition = jnp.array([False, False, True])
on_true = jnp.array([1., 2., 3.])
on_false = jnp.array([-1., -2., -3.]) 

#use jnp.where to choose between on_true and on_false based on condition
result = jnp.where(condition, on_true, on_false)
print(result)  # Output: [-1. -2.  3.]

[-1. -2.  3.]
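``jnp.where`` gives a branch-free, element-wise version of the value-dependent function from earlier, which compiles under ``jit``. A sketch; ``branchless`` is an illustrative name. Note that both branch expressions are evaluated for every element, and ``jnp.where`` then keeps one value per element.

```python
import jax.numpy as jnp
from jax import jit

@jit
def branchless(x):
    # Both 3x^2 and -4x are computed; where picks one per element
    return jnp.where(x < 3, 3. * x ** 2, -4. * x)

print(branchless(jnp.array([1., 2., 3., 4.])))  # elementwise: 3., 12., -12., -16.
```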


``jnp.piecewise`` is a NumPy-style wrapper around ``lax.switch``: instead of a single scalar index, it takes a list of boolean conditions and applies the matching function element-wise.

In [47]:
x = jnp.array([-4, -3, -2, -1, 0, 1, 2, 3, 4])

conditions = [x < 0, x >= 0]
functions = [lambda x: 0 * x, lambda x: x]

result = jnp.piecewise(x, conditions, functions)
print(result) #Output: [0 0 0 0 0 1 2 3 4]

[0 0 0 0 0 1 2 3 4]
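As with NumPy's ``piecewise``, the function list may contain one more entry than the condition list; that last function acts as a default wherever no condition is true. A sketch under ``jit``; ``piecewise_demo`` is an illustrative name.

```python
import jax.numpy as jnp
from jax import jit

@jit
def piecewise_demo(x):
    conditions = [x < 0, x > 2]
    functions = [
        lambda x: -x,     # used where x < 0
        lambda x: 2 * x,  # used where x > 2
        lambda x: x,      # extra last entry: default where no condition holds
    ]
    return jnp.piecewise(x, conditions, functions)

print(piecewise_demo(jnp.array([-2, -1, 0, 1, 2, 3, 4])))  # [2 1 0 1 2 6 8]
```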


__while_loop__

In [48]:
# Python equivalent

def while_loop(cond_fun, body_fun, init_val):
  val = init_val
  while cond_fun(val):
    val = body_fun(val)
  return val

In [50]:
#JAX equivalent
init_val = 0
condition = lambda x: x< 10 
function_body = lambda x: x +1
result = lax.while_loop(condition, function_body, init_val)
print(result)  # Output: 10

10
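The loop state can be any pytree, such as a tuple, and the stopping condition may depend on traced arguments. A sketch that counts doublings until a value reaches a limit; ``count_doublings`` is an illustrative name. Per the JAX documentation, ``lax.while_loop`` supports forward-mode but not reverse-mode differentiation.

```python
import jax.numpy as jnp
from jax import jit, lax

@jit
def count_doublings(limit):
    # Carry is a tuple (value, steps); the loop stops once value >= limit
    def cond_fun(state):
        value, _ = state
        return value < limit
    def body_fun(state):
        value, steps = state
        return value * 2, steps + 1
    init = (jnp.int32(1), jnp.int32(0))
    _, steps = lax.while_loop(cond_fun, body_fun, init)
    return steps

print(count_doublings(100))   # 7, since 2**7 = 128 is the first power >= 100
print(count_doublings(1000))  # 10
```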


__fori_loop__

In [51]:
#Python equivalent
def fori_loop(start, stop, body_func, init_val):
    val = init_val
    for i in range(start, stop):
        val = body_func(i, val)
    return val

In [52]:
init_val = 0
start = 0
stop = 10 
function_body = lambda i, x: x+i 
lax.fori_loop(start, stop, function_body, init_val)

Array(45, dtype=int32, weak_type=True)
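When both bounds are concrete Python integers, the JAX documentation states that ``fori_loop`` is implemented with ``scan`` and supports reverse-mode differentiation. A sketch that computes $x^3$ by repeated multiplication; ``cube`` is an illustrative name.

```python
import jax.numpy as jnp
from jax import grad, lax

def cube(x):
    # Multiply by x three times; the bounds 0 and 3 are concrete Python ints
    return lax.fori_loop(0, 3, lambda i, acc: acc * x, jnp.ones_like(x))

print(cube(2.))        # 8.0
print(grad(cube)(2.))  # 12.0, since d/dx x^3 = 3x^2
```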

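JAX provides the __logical operators__ ``jnp.logical_and``, ``jnp.logical_or``, and ``jnp.logical_not``, which operate element-wise on arrays and can be used under __JIT compilation__.

The __bitwise operators__ ``&``, ``|``, and ``~`` can also be used under __JIT compilation__; on boolean arrays they act as element-wise logical operators. Python's ``and``, ``or``, and ``not`` cannot, because they require a concrete ``bool``.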

In [54]:
def python_check_positive_even(x):
    is_even = x % 2 == 0
    return is_even and (x > 0)

@jit 
def jax_check_positive_even(x):
    is_even = x % 2 == 0
    return jnp.logical_and(is_even, x > 0)

print(python_check_positive_even(2))  # Output: True
print(jax_check_positive_even(5))  # Output: False (5 is odd)

True
False


In [55]:
# Works element-wise on arrays as well
x = jnp.array([1, 2, 3, 4, 5, 6])
print(jax_check_positive_even(x)) #Output: [False  True False  True False  True]

[False  True False  True False  True]
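The range check that failed earlier with Python's ``and`` works once it is written with ``&``, and ``~`` negates a boolean array element-wise. A sketch; ``in_open_range`` is an illustrative name.

```python
import jax.numpy as jnp
from jax import jit

@jit
def in_open_range(x):
    # `&` and `~` act element-wise on boolean arrays and accept tracers,
    # unlike Python's `and` / `not`, which need a concrete bool
    inside = (x > 0) & (x < 3)
    return inside, ~inside

inside, outside = in_open_range(jnp.array([-1, 1, 2, 5]))
print(inside)   # [False  True  True False]
print(outside)  # [ True False False  True]
```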


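__Python control flow + autodiff__

Without ``jit``, ``grad`` traces the function with values that are still concrete, so ordinary Python control flow such as ``if x < 3`` works when differentiating.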

In [57]:
def f(x):
    if x < 3:
        return 30*x ** 2
    else:
        return - 4 * x
    
print(grad(f)(3.))
print(grad(f)(2.)) #Output: 6.0

-4.0
120.0
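The Python ``if`` in that cell only works because no ``jit`` is involved; wrapping the gradient in ``jit`` would bring back the TracerBoolConversionError. One workaround, a swapped-in formulation rather than the cell above, is to express the branch with ``jnp.where`` so that ``grad`` and ``jit`` compose. ``f_where`` and ``df`` are illustrative names.

```python
import jax.numpy as jnp
from jax import grad, jit

def f_where(x):
    # Same piecewise function as above, expressed without a Python `if`
    return jnp.where(x < 3, 30. * x ** 2, -4. * x)

df = jit(grad(f_where))
print(df(2.))  # 120.0
print(df(3.))  # -4.0
```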
