# Lecture 3: JAX Control Flows

<br>

We saw in Exercise 2 how JIT-compiling the Minuit optimization fit speeds up the code by orders of magnitude compared with regular Python code.

<br>

Why is that?

<br>
<br>
<div>   
    <center>
        <img src="figures/Minuit_output.png" width="1000"/>
    </center>
    </div>
  

<br>

<br>

Our test statistic function in the profile likelihood fit done in Exercise 2 has a for loop:

<br>

```python
# Excerpt from Exercise 2: poly_interp, exp_interp and the arr_ratio_*_sig arrays are defined there
for n in range(1, param_array.shape[0]):
    fact_sig *= jnp.where(
        jnp.abs(param_array[n]) <= 1.0,
        poly_interp(param_array[n], arr_ratio_up_sig[n-1], arr_ratio_down_sig[n-1], 1.0),
        exp_interp(param_array[n], arr_ratio_up_sig[n-1], arr_ratio_down_sig[n-1], 1.0),
    )
```

<br>

Let's see how JAX interprets this for loop with a simple example.

<br>

```python
# Let's start with a simpler function and introduce a loop: not taking advantage of array-based computation

from jax import jit

@jit
def fn(tuple_arr):
    summed_array = 0
    for entry in tuple_arr:
        summed_array += entry**2 - entry**3 - entry
    return summed_array
```


<br>

```python
from jax import make_jaxpr
import jax.numpy as jnp

# Look at how the loop is represented in the computation graph (the jaxpr).
# For an array of size 2, we get two sets of repeated computations, one per iteration.
print(make_jaxpr(fn)(jnp.ones(2)))
```

```
{ lambda ; a:f32[2]. let
    b:f32[] = xla_call[
      call_jaxpr={ lambda ; c:f32[2]. let
          d:f32[1] = slice[limit_indices=(1,) start_indices=(0,) strides=(1,)] c
          e:f32[] = squeeze[dimensions=(0,)] d
          f:f32[1] = slice[limit_indices=(2,) start_indices=(1,) strides=(1,)] c
          g:f32[] = squeeze[dimensions=(0,)] f
          h:f32[] = integer_pow[y=2] e
          i:f32[] = integer_pow[y=3] e
          j:f32[] = sub h i
          k:f32[] = sub j e
          l:f32[] = add k 0.0
          m:f32[] = integer_pow[y=2] g
          n:f32[] = integer_pow[y=3] g
          o:f32[] = sub m n
          p:f32[] = sub o g
          q:f32[] = add l p
        in (q,) }
      name=fn
    ] a
  in (b,) }
```


<br>

```python
# Look at how the loop is represented in the computation graph (the jaxpr).
# With a size-5 array, the loop body is repeated 5 times.
print(make_jaxpr(fn)(jnp.ones(5)))
```

```
{ lambda ; a:f32[5]. let
    b:f32[] = xla_call[
      call_jaxpr={ lambda ; c:f32[5]. let
          d:f32[1] = slice[limit_indices=(1,) start_indices=(0,) strides=(1,)] c
          e:f32[] = squeeze[dimensions=(0,)] d
          f:f32[1] = slice[limit_indices=(2,) start_indices=(1,) strides=(1,)] c
          g:f32[] = squeeze[dimensions=(0,)] f
          h:f32[1] = slice[limit_indices=(3,) start_indices=(2,) strides=(1,)] c
          i:f32[] = squeeze[dimensions=(0,)] h
          j:f32[1] = slice[limit_indices=(4,) start_indices=(3,) strides=(1,)] c
          k:f32[] = squeeze[dimensions=(0,)] j
          l:f32[1] = slice[limit_indices=(5,) start_indices=(4,) strides=(1,)] c
          m:f32[] = squeeze[dimensions=(0,)] l
          n:f32[] = integer_pow[y=2] e
          o:f32[] = integer_pow[y=3] e
          p:f32[] = sub n o
          q:f32[] = sub p e
          r:f32[] = add q 0.0
          s:f32[] = integer_pow[y=2] g
          t:f32[] = integer_pow[y=3] g
          u:f32[] = sub s t
```

<br>

As the array grows, the for loop runs for more iterations, and the traced program grows with it: every iteration adds its own copy of the slicing and arithmetic operations.

<br>

<br>

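When you use a regular Python loop inside a `jit`-compiled function, Python executes the loop while JAX traces the function. JAX never sees the loop itself, only the operations from each iteration, so it traces out an unrolled computation.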

<br>

This is clearly sub-optimal if we pass a very large array to our function: JAX tracing unrolls the for loop for every iteration.

XLA then has to compile and optimize the entire unrolled program, which repeats the same block of operations once per element, so compilation time grows with the size of the array.
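
One way to see the unrolling without reading the jaxpr line by line is to count its equations. The sketch below traces an un-jitted copy of the loop function with `jax.make_jaxpr` for a few array sizes and counts the top-level equations in `.jaxpr.eqns`. The name `loop_fn` is hypothetical, and the exact counts will depend on your JAX version, but they grow with the array length.

```python
import jax
import jax.numpy as jnp

# Hypothetical un-jitted copy of `fn` from above, so that make_jaxpr shows the
# loop body directly instead of wrapping it in a single jit call
def loop_fn(x):
    total = 0.0
    for entry in x:  # runs in Python at trace time: one copy of the body per element
        total += entry**2 - entry**3 - entry
    return total

# Count the top-level equations in the traced program for increasing array sizes
eqn_counts = {}
for size in (2, 5, 50):
    closed_jaxpr = jax.make_jaxpr(loop_fn)(jnp.ones(size))
    eqn_counts[size] = len(closed_jaxpr.jaxpr.eqns)

print(eqn_counts)  # the counts grow with the array size
```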

<br>

<br>

**JAX Control Flow Operations**

<br>

JAX offers alternative control flow primitives to replace Python control flow operations. For example:

<br>

- `lax.cond` - Substitute for the if/else statement

- `lax.fori_loop` -  Substitute for the for loop

- `lax.while_loop` -  Substitute for the while loop
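
This lecture does not return to `lax.while_loop`, so here is a minimal sketch of how it is called. `while_loop(cond_fun, body_fun, init_val)` keeps applying `body_fun` to a carried value for as long as `cond_fun` returns `True`. The function `count_halvings` is a hypothetical illustration and is not part of the exercises.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def count_halvings(x):
    """Hypothetical example: count how many halvings it takes for x to drop below 1."""
    # The carry is a (value, counter) tuple; its shapes and dtypes must stay fixed
    def cond_fun(carry):
        value, _ = carry
        return value >= 1.0

    def body_fun(carry):
        value, count = carry
        return value / 2.0, count + 1

    _, count = lax.while_loop(cond_fun, body_fun, (x, 0))
    return count

print(count_halvings(jnp.float32(10.0)))  # prints 4
```

The number of iterations depends on the run-time value of `x`. A Python `while` loop on a traced value inside `jit` could not handle that.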


<br>

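These are only rough correspondences: the syntax and inner workings differ from regular Python control flow. In particular, the loop body and the branches are written as functions that JAX traces and stages out, rather than as statements that Python executes. For the loops, the value carried between iterations must also keep the same shape and dtype.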

Why use JAX's control flow operations? Let's start with an example of `lax.fori_loop`.


<br>

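```python
# Substitute the regular Python for loop with a lax.fori_loop

from jax import jit
from jax.lax import fori_loop

@jit
def fn_JCF(tuple_arr):

    # body_fun(n, carry) receives the loop index n and the running sum (the "carry")
    # and returns the updated carry
    def body_fun(n, carry):
        carry += tuple_arr[n]**2 - tuple_arr[n]**3 - tuple_arr[n]
        return carry

    # Loop n over [0, len(tuple_arr)), starting from a carry of 0.0
    sum_array = fori_loop(0, len(tuple_arr), body_fun, 0.0)

    return sum_array
```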
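
In terms of semantics, `fori_loop(lower, upper, body_fun, init_val)` behaves like the pure-Python loop below, which closely follows how the JAX documentation describes it. The helper name `fori_loop_reference` is ours and is for illustration only. The real `lax.fori_loop` is traced and staged out as a single loop rather than executed step by step in Python.

```python
import jax.numpy as jnp
from jax import lax

def fori_loop_reference(lower, upper, body_fun, init_val):
    # Plain-Python equivalent: thread the carried value through every iteration
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val

x = jnp.ones(5)

# Same loop body as fn_JCF, closing over the array x
def body_fun(n, carry):
    return carry + x[n]**2 - x[n]**3 - x[n]

result_ref = fori_loop_reference(0, len(x), body_fun, 0.0)
result_lax = lax.fori_loop(0, len(x), body_fun, 0.0)
print(result_ref, result_lax)  # both -5.0 for an array of ones
```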


<br>

```python
# Let's see if our new function does the same thing as the original function

print(fn(jnp.ones(2)))

print(fn_JCF(jnp.ones(2)))
```

```
-2.0
-2.0
```


<br>

Okay, so now let's see how JAX tracing handles this new function with the alternative control flow:

<br>

```python
print(make_jaxpr(fn_JCF)(jnp.ones(5)))
```

```
{ lambda ; a:f32[5]. let
    b:f32[] = xla_call[
      call_jaxpr={ lambda ; c:f32[5]. let
          _:i32[] d:f32[] = scan[
            jaxpr={ lambda ; e:f32[5] f:i32[] g:f32[]. let
                h:i32[] = add f 1
                i:bool[] = lt f 0
                j:i32[] = add f 5
                k:i32[] = select_n i f j
                l:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] k
                m:f32[] = gather[
                  dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,))
                  fill_value=None
                  indices_are_sorted=True
                  mode=GatherScatterMode.PROMISE_IN_BOUNDS
                  slice_sizes=(1,)
                  unique_indices=True
                ] e l
                n:f32[] = integer_pow[y=2] m
                o:bool[] = lt f 0
                p:i32[] = add f 5
                q:i32[] = select_n o f p
                r:i32[1] = broadcast_in_dim[broad
```

<br>

```python
print(make_jaxpr(fn_JCF)(jnp.ones(5000)))
```

```
{ lambda ; a:f32[5000]. let
    b:f32[] = xla_call[
      call_jaxpr={ lambda ; c:f32[5000]. let
          _:i32[] d:f32[] = scan[
            jaxpr={ lambda ; e:f32[5000] f:i32[] g:f32[]. let
                h:i32[] = add f 1
                i:bool[] = lt f 0
                j:i32[] = add f 5000
                k:i32[] = select_n i f j
                l:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] k
                m:f32[] = gather[
                  dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,))
                  fill_value=None
                  indices_are_sorted=True
                  mode=GatherScatterMode.PROMISE_IN_BOUNDS
                  slice_sizes=(1,)
                  unique_indices=True
                ] e l
                n:f32[] = integer_pow[y=2] m
                o:bool[] = lt f 0
                p:i32[] = add f 5000
                q:i32[] = select_n o f p
                r:i32[1] = broadca
```

<br>

<br>

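Using `lax.fori_loop` avoids unrolling large loops! The jaxpr for a 5000-element array has the same structure as the one for a 5-element array; only the constant `5000` changes.

When you use `fori_loop`, the JAX tracer sees the loop as a primitive (note that `fori_loop` is a function from `jax.lax`) and stages out a loop construct in the XLA program. The trip count is known at trace time here, so the loop shows up in the jaxpr as a `scan`.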
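
Because a `fori_loop` whose bounds are known at trace time is lowered to `scan`, reverse-mode differentiation works through it. A `while_loop`, whose trip count is only known at run time, does not support `jax.grad`. This matters if the test statistic is later differentiated. The sketch below redefines the loop under the hypothetical name `loop_sum` and checks `jax.grad` against the analytic derivative `2*x - 3*x**2 - 1`.

```python
import jax
import jax.numpy as jnp
from jax import lax

def loop_sum(x):
    # Same computation as fn_JCF: sum of x**2 - x**3 - x over the elements
    def body_fun(n, carry):
        return carry + x[n]**2 - x[n]**3 - x[n]
    # Python-int bounds (len(x) is static), so JAX can lower this loop to a scan
    return lax.fori_loop(0, len(x), body_fun, 0.0)

x = jnp.array([0.5, 1.0, 2.0])
grad_loop = jax.grad(loop_sum)(x)
grad_analytic = 2 * x - 3 * x**2 - 1   # derivative of x**2 - x**3 - x

print(grad_loop)
print(grad_analytic)
```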

<br>

Let's see the effect this has on the compilation and execution time of a JIT-compiled function.

<br>


<br>

```python
%%timeit -n1 -r1

# Time the first call of the Python-loop version on a 1000-element array:
# this includes tracing and JIT compilation as well as execution

print(fn(jnp.ones(1000)))
```

```
-1000.0
5.98 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
```


<br>

```python
%%timeit -n1 -r1

# Do the same for the JIT-compiled function with JAX fori_loop

print(fn_JCF(jnp.ones(1000)))  # Notice how JIT compilation is now significantly faster!
```

```
-1000.0
76.1 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
```


<br>

```python
%%timeit -n1 -r1

# Let's time post-compilation execution

print(fn(jnp.ones(1000)))
```

```
-1000.0
1.3 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
```


<br>

```python
%%timeit -n1 -r1

print(fn_JCF(jnp.ones(1000)))   # No significant difference: once compiled, both functions run efficiently,
                                # but the Python-loop version took far longer to compile because of its unrolled size
```

```
-1000.0
1.53 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
```


<br>

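Using primitive `jax.lax` control flow operations can speed up the initial JIT compilation step considerably. In this example, the first call dropped from about 6 s to about 76 ms, while the post-compilation run times were similar (1.3 ms vs 1.53 ms). If compilation time is the limiting factor, this can be very powerful.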

<br>

For a handful of loop iterations, Python control flow works relatively well. For many loop iterations, and particularly for memory-intensive computations, JAX’s structured control flow operations can speed up the compilation step significantly.
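
The `%%timeit` runs above time the first call, which mixes tracing, compilation and execution together. The sketch below separates them. It assumes a JAX version that provides the ahead-of-time `jax.jit(f).lower(...).compile()` API, which recent releases do. The function names `python_loop` and `lax_loop` are illustrative, and your timings will depend on your machine and JAX version.

```python
import time

import jax
import jax.numpy as jnp
from jax import lax

def python_loop(x):
    total = 0.0
    for entry in x:                       # unrolled at trace time
        total += entry**2 - entry**3 - entry
    return total

def lax_loop(x):
    def body_fun(n, carry):
        return carry + x[n]**2 - x[n]**3 - x[n]
    return lax.fori_loop(0, len(x), body_fun, 0.0)   # staged out as one loop

x = jnp.ones(1000)
timings = {}
for name, f in [("python_loop", python_loop), ("lax_loop", lax_loop)]:
    start = time.perf_counter()
    compiled = jax.jit(f).lower(x).compile()   # trace + compile only
    compile_s = time.perf_counter() - start

    start = time.perf_counter()
    result = compiled(x).block_until_ready()   # execute only
    run_s = time.perf_counter() - start

    timings[name] = (compile_s, run_s, float(result))
    print(f"{name}: compile {compile_s:.3f} s, run {run_s * 1e3:.3f} ms, result {float(result)}")
```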

<br>

**One more primitive control flow example**

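We saw in the last lecture that Python control flow which depends on the value of a traced input fails under JIT compilation. During tracing, the input is an abstract tracer that carries only a shape and dtype, not a concrete value.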

<br>


```python
# Let's see this for a simple example

@jit
def f(x):
    if x < 3:                        # Python `if` needs a concrete bool, but x is a tracer here
        return 3. * x ** 2
    else:
        return -4 * x

# This will fail!
f(2)
```

```
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function. 
While tracing the function f at /tmp/ipykernel_538/3701872584.py:1 for jit, this concrete value was not available in Python because it depends on the value of the argument 'x'.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
```
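
Besides `lax.cond`, JAX offers a second fix: mark the argument as static with `static_argnums`. Then `x` is a concrete Python value during tracing and the ordinary `if` works. The trade-off is that JAX recompiles the function for every new value of `x`, so this approach only suits arguments that take a few distinct values. The names `f_py` and `f_static` are illustrative.

```python
import jax

def f_py(x):
    if x < 3:                 # x is a concrete Python value here, so this is fine
        return 3. * x ** 2
    else:
        return -4 * x

# Treat argument 0 as static: each distinct value of x triggers a new compilation
f_static = jax.jit(f_py, static_argnums=0)

print(f_static(2))
print(f_static(5))
```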

<br>

```python
# Let's try to use JAX control flow on this function
import jax

@jit
def f(x):
    # lax.cond(pred, true_fun, false_fun, operand)
    return jax.lax.cond(x < 3.0, lambda y: 3.0 * y ** 2, lambda y: -4.0 * y, x)

# This will NOT fail!
f(2)
```

```
DeviceArray(12., dtype=float32, weak_type=True)
```

<br>

Let's see how JAX traces out this function when we use its primitive control flow operation `lax.cond`:

<br>

```python
make_jaxpr(f)(2)     # See how both branches are traced out!
```

```
{ lambda ; a:i32[]. let
    b:f32[] = xla_call[
      call_jaxpr={ lambda ; c:i32[]. let
          d:f32[] = convert_element_type[new_dtype=float32 weak_type=True] c
          e:bool[] = lt d 3.0
          f:i32[] = convert_element_type[new_dtype=int32 weak_type=False] e
          g:f32[] = cond[
            branches=(
              { lambda ; h:i32[]. let
                  i:f32[] = convert_element_type[
                    new_dtype=float32
                    weak_type=True
                  ] h
                  j:f32[] = mul i -4.0
                in (j,) }
              { lambda ; k:i32[]. let
                  l:i32[] = integer_pow[y=2] k
                  m:f32[] = convert_element_type[
                    new_dtype=float32
                    weak_type=True
                  ] l
                  n:f32[] = mul m 3.0
                in (n,) }
            )
            linear=(False,)
          ] f c
        in (g,) }
      name=f
    ] a
  in (b,) }
```
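
In the jaxpr above, `cond` converts the boolean predicate into an integer branch index. That is why the first branch listed is the false branch (`mul i -4.0`) and the second is the true branch. At run time, `lax.cond` behaves like the plain-Python `cond_reference` below, which is a hypothetical helper written for illustration. By contrast, `jnp.where`, used in the Exercise 2 loop, computes both of its inputs and then selects elementwise. Under `jax.vmap` with a batched predicate, `lax.cond` is also converted into a select that evaluates both branches.

```python
import jax
import jax.numpy as jnp
from jax import lax

def cond_reference(pred, true_fun, false_fun, operand):
    # Plain-Python semantics of lax.cond: only the selected branch is evaluated
    if pred:
        return true_fun(operand)
    else:
        return false_fun(operand)

def true_fun(y):
    return 3.0 * y ** 2

def false_fun(y):
    return -4.0 * y

@jax.jit
def f_cond(x):
    # x must be a scalar here: lax.cond needs a scalar predicate
    return lax.cond(x < 3.0, true_fun, false_fun, x)

@jax.jit
def f_where(x):
    # jnp.where computes both expressions for every element, then selects
    return jnp.where(x < 3.0, true_fun(x), false_fun(x))

for value in (2.0, 5.0):
    print(value, cond_reference(value < 3.0, true_fun, false_fun, value),
          f_cond(value), f_where(value))

# jnp.where also works elementwise on arrays
print(f_where(jnp.array([2.0, 5.0])))
```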

<br>

That is all for today's lecture! For more information on the other similar operations, you may refer to https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow.

<br>

In today's exercise you will convert the test statistic function from Exercise 2 into one that uses JAX's control flow operations and see the magic!