# Lecture 3: JAX Control Flows

<br>

<br>

We saw in Exercise 2 how performing the Minuit optimization fit using JIT compilation speeds up the code by orders of magnitude when compared to regular Python code.

<br>

Why is that?

<br>
<br>
<div>   
    <center>
        <img src="figures/Minuit_output.png" width="1000"/>
    </center>
    </div>
  

<br>

<br>

You might think that we are done - how much better can this possibly get?

<br>

Well, there's more! You can make the code even more efficient before JIT compiling.

<br>

To do this, let us learn how JAX handles control flow operations - e.g. for loops, if/else statements, etc.

<br>

<br>

In [24]:
# Let's change our previous function into a loop - not taking advantage of array-based computation

from jax import jit

@jit
def fn(tuple_arr):
    """Sum ``x**2 - x**3 - x`` over the entries of ``tuple_arr``.

    The accumulation is deliberately written as an ordinary Python ``for``
    loop, so the traced computation contains one copy of the loop body per
    entry of the input.
    """
    total = 0
    for x in tuple_arr:
        # Contribution of a single entry to the running sum.
        term = x**2 - x**3 - x
        total = total + term
    return total


In [4]:
from jax import make_jaxpr
import jax.numpy as jnp

# Print the jaxpr of fn for a 2-entry input to see how the Python loop is traced
print(make_jaxpr(fn)(jnp.ones(2)))



{ lambda ; a:f32[2]. let
    b:f32[] = xla_call[
      call_jaxpr={ lambda ; c:f32[2]. let
          d:f32[1] = slice[limit_indices=(1,) start_indices=(0,) strides=(1,)] c
          e:f32[] = squeeze[dimensions=(0,)] d
          f:f32[1] = slice[limit_indices=(2,) start_indices=(1,) strides=(1,)] c
          g:f32[] = squeeze[dimensions=(0,)] f
          h:f32[] = integer_pow[y=2] e
          i:f32[] = integer_pow[y=3] e
          j:f32[] = sub h i
          k:f32[] = sub j e
          l:f32[] = add k 0.0
          m:f32[] = integer_pow[y=2] g
          n:f32[] = integer_pow[y=3] g
          o:f32[] = sub m n
          p:f32[] = sub o g
          q:f32[] = add l p
        in (q,) }
      name=fn
    ] a
  in (b,) }


In [5]:
# Repeat with a 5-entry input and compare the jaxpr with the 2-entry case
print(make_jaxpr(fn)(jnp.ones(5)))

{ lambda ; a:f32[5]. let
    b:f32[] = xla_call[
      call_jaxpr={ lambda ; c:f32[5]. let
          d:f32[1] = slice[limit_indices=(1,) start_indices=(0,) strides=(1,)] c
          e:f32[] = squeeze[dimensions=(0,)] d
          f:f32[1] = slice[limit_indices=(2,) start_indices=(1,) strides=(1,)] c
          g:f32[] = squeeze[dimensions=(0,)] f
          h:f32[1] = slice[limit_indices=(3,) start_indices=(2,) strides=(1,)] c
          i:f32[] = squeeze[dimensions=(0,)] h
          j:f32[1] = slice[limit_indices=(4,) start_indices=(3,) strides=(1,)] c
          k:f32[] = squeeze[dimensions=(0,)] j
          l:f32[1] = slice[limit_indices=(5,) start_indices=(4,) strides=(1,)] c
          m:f32[] = squeeze[dimensions=(0,)] l
          n:f32[] = integer_pow[y=2] e
          o:f32[] = integer_pow[y=3] e
          p:f32[] = sub n o
          q:f32[] = sub p e
          r:f32[] = add q 0.0
          s:f32[] = integer_pow[y=2] g
          t:f32[] = integer_pow[y=3] g
          u:f32[] = sub s t

In [22]:
# Now let's rewrite the loop using JAX's structured control flow (jax.lax.fori_loop) instead of a plain Python loop

from jax.lax import fori_loop

@jit
def fn_JCF(tuple_arr):
    """Sum ``x**2 - x**3 - x`` over ``tuple_arr`` using ``fori_loop``.

    Computes the same quantity as the plain Python-loop version, but the
    iteration is expressed through ``jax.lax.fori_loop`` rather than a
    Python ``for`` statement.
    """

    def add_term(idx, running_total):
        # idx is the loop counter supplied by fori_loop; running_total is
        # the carried partial sum.
        return running_total + (
            tuple_arr[idx]**2 - tuple_arr[idx]**3 - tuple_arr[idx]
        )

    # Run idx over 0 .. len(tuple_arr) - 1, starting the carry at 0.0.
    return fori_loop(0, len(tuple_arr), add_term, 0.0)

# print(lax.fori_loop(0, 10, lambda i,x: x+array[i], 0)) # expected result 45


In [25]:
# Sanity check: the Python-loop version and the fori_loop version should print the same value
print(fn(jnp.ones(2)))

print(fn_JCF(jnp.ones(2)))

-2.0
-2.0


In [29]:
# Print the jaxpr of the fori_loop version for a 5-entry input
print(make_jaxpr(fn_JCF)(jnp.ones(5)))

{ lambda ; a:f32[5]. let
    b:f32[] = xla_call[
      call_jaxpr={ lambda ; c:f32[5]. let
          _:i32[] d:f32[] = scan[
            jaxpr={ lambda ; e:f32[5] f:i32[] g:f32[]. let
                h:i32[] = add f 1
                i:bool[] = lt f 0
                j:i32[] = add f 5
                k:i32[] = select_n i f j
                l:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] k
                m:f32[] = gather[
                  dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,))
                  fill_value=None
                  indices_are_sorted=True
                  mode=GatherScatterMode.PROMISE_IN_BOUNDS
                  slice_sizes=(1,)
                  unique_indices=True
                ] e l
                n:f32[] = integer_pow[y=2] m
                o:bool[] = lt f 0
                p:i32[] = add f 5
                q:i32[] = select_n o f p
                r:i32[1] = broadcast_in_dim[broad

In [30]:
# Print the jaxpr of the fori_loop version for a 5000-entry input
print(make_jaxpr(fn_JCF)(jnp.ones(5000)))

{ lambda ; a:f32[5000]. let
    b:f32[] = xla_call[
      call_jaxpr={ lambda ; c:f32[5000]. let
          _:i32[] d:f32[] = scan[
            jaxpr={ lambda ; e:f32[5000] f:i32[] g:f32[]. let
                h:i32[] = add f 1
                i:bool[] = lt f 0
                j:i32[] = add f 5000
                k:i32[] = select_n i f j
                l:i32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] k
                m:f32[] = gather[
                  dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,))
                  fill_value=None
                  indices_are_sorted=True
                  mode=GatherScatterMode.PROMISE_IN_BOUNDS
                  slice_sizes=(1,)
                  unique_indices=True
                ] e l
                n:f32[] = integer_pow[y=2] m
                o:bool[] = lt f 0
                p:i32[] = add f 5000
                q:i32[] = select_n o f p
                r:i32[1] = broadca

In [32]:
%%timeit -n1 -r1

# Timed call of the Python-loop version on 1000 entries
# NOTE(review): presumably the first call with this input shape, so the time includes tracing/compilation - confirm
print(fn(jnp.ones(1000)))

-1000.0
5.98 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [33]:
%%timeit -n1 -r1

# Timed call of the fori_loop version on 1000 entries
# NOTE(review): presumably the first call with this input shape, so the time includes tracing/compilation - confirm
print(fn_JCF(jnp.ones(1000)))

-1000.0
76.1 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [34]:
%%timeit -n1 -r1

# Repeated timed call of the Python-loop version with the same input shape
# NOTE(review): presumably reuses the already-compiled function - confirm
print(fn(jnp.ones(1000)))

-1000.0
1.3 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [35]:
%%timeit -n1 -r1

# Repeated timed call of the fori_loop version with the same input shape
# NOTE(review): presumably reuses the already-compiled function - confirm
print(fn_JCF(jnp.ones(1000)))

-1000.0
1.53 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
