<a href="https://colab.research.google.com/github/jecampagne/JaxTutos/blob/main/JAX_control_flow.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np

import jax
import jax.numpy as jnp

from jax import grad, jit, vmap
from jax import jacfwd, jacrev

jax.config.update("jax_enable_x64", True)

from functools import partial

# Theme: explore some JAX control flow

This notebook complements `JIT_fractals.ipynb`, where `jax.lax.while_loop` and `jax.lax.cond` are used.

## jax.lax.fori_loop 
The equivalent Python code is
```python
def fori_loop(start, stop, body_fun, init_val):
  val = init_val
  for i in range(start, stop):
    val = body_fun(i, val)
  return val
```

Basic use

In [None]:
def body(i,val):
  return val+i**2
val_init=0
N=10
res = jax.lax.fori_loop(0,N+1,body,val_init)
print(res, N*(N+1)*(2*N+1)//6)



385 385


Just a reminder that JAX arrays are immutable.

In [None]:
arr = jnp.arange(0,N+1)
arr

Array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10], dtype=int64)

In [None]:
jnp.sum(arr**0), jnp.sum(arr**1), jnp.sum(arr**2), jnp.sum(arr**3)

(Array(11, dtype=int64),
 Array(55, dtype=int64),
 Array(385, dtype=int64),
 Array(3025, dtype=int64))

In [None]:
def body(dum,val):
  a, b, iptr=val
  b = b.at[iptr].set(jnp.sum(a**iptr))  # functional update: arrays are immutable
  iptr += 1
  return a, b, iptr

In [None]:
b=jnp.zeros(shape=(4,))
_,b,_ = jax.lax.fori_loop(0,len(b),body,(arr,b,0))
b

Array([  11.,   55.,  385., 3025.], dtype=float64)
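
As a minimal sketch of what "immutable" means in practice: in-place item assignment on a JAX array raises an error, while `.at[...].set(...)` returns a new array and leaves the original untouched. The names `a`, `b` and `in_place_ok` are only illustrative.

```python
import jax.numpy as jnp

a = jnp.zeros(3)

# In-place assignment is rejected on JAX arrays.
try:
    a[0] = 1.0
    in_place_ok = True
except TypeError:
    in_place_ok = False

# The functional update returns a new array; `a` is unchanged.
b = a.at[0].set(1.0)
print(in_place_ok, a, b)
```

This is why the loop body above has to return the updated `b` in the carried value instead of modifying it in place.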

## jax.lax.scan

The equivalent Python code is

```python
def scan(f, init, xs, length=None):
  if xs is None:
    xs = [None] * length
  carry = init
  ys = []
  for x in xs:
    carry, y = f(carry, x)
    ys.append(y)
  return carry, np.stack(ys)
```

Basic use: compute $\sum_j a_j^i$ for several values of $i$.

In [None]:
def body(a, i):
  return a, jnp.sum(a**i) 

arr = jnp.arange(0,N+1)
_, res = jax.lax.scan(body,arr,jnp.array([0,1,2,3]))
print(res)

[  11   55  385 3025]
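
In the example above the carry is never modified. A small sketch of the complementary case, where the carry changes at every step and the per-step outputs are stacked, is a running sum. The function name `step` and the variable names are illustrative only.

```python
import jax
import jax.numpy as jnp

def step(carry, x):
    new_carry = carry + x          # running total passed to the next iteration
    return new_carry, new_carry    # the second output is stacked into `ys`

xs = jnp.arange(1, 6)
init = jnp.zeros((), dtype=xs.dtype)   # carry with the same dtype as xs
total, partial_sums = jax.lax.scan(step, init, xs)
print(total, partial_sums)
```

The final carry is the total, and the stacked outputs are the partial sums, one per element of `xs`.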


## A minimizer where carry is updated w/o changing its structure

In [None]:
from jax.example_libraries.optimizers  import adam,sgd

sum_abs = lambda x: jnp.sum(jnp.abs(x))

def minimize(
    fun, inputs, optimizer=sgd, schedule=1e-3, maxiter=4, gtol=1e-2,
):
    """Optimizes inputs to minimize the output of fun.
    Args:
        fun: maps x -> energy
        inputs: initial inputs
        optimizer: any jax.example_libraries.optimizers
        schedule: learning rate schedule for optimizer
        maxiter: maximum number of optimizer steps
        gtol: cutoff for gradients to stop early
    """
    init_optimizer, update_optimizer, get_inputs = optimizer(schedule)
    energy_gradient = grad(lambda x: jnp.sum(fun(x)))
    opt_state = init_optimizer(inputs)
    gradient = energy_gradient(inputs)
    sum_abs_gradient = sum_abs(gradient)
    carry = (0, opt_state, gradient, sum_abs_gradient, inputs)

    # no explicit jit needed: the function is traced by lax.cond/lax.scan
    def step_once(carry):
        (step, opt_state, gradient, sum_abs_gradient, inputs) = carry
        opt_state = update_optimizer(step, gradient, opt_state)
        inputs = get_inputs(opt_state)
        gradient = energy_gradient(inputs)
        sum_abs_gradient = sum_abs(gradient)
        carry = (step + 1, opt_state, gradient, sum_abs_gradient, inputs)
        return carry

    def noop(carry):
        return carry

    # no explicit jit needed: the body is traced by lax.scan
    def body_fn(carry, x):
        sum_abs_gradient = carry[-2]
        return jax.lax.cond(sum_abs_gradient < gtol, noop, step_once, carry), None

    carry, _ = jax.lax.scan(body_fn, carry, None, length=maxiter)
    return carry

In [None]:
def fun(x):
  return (x-1.)**2

In [None]:
_, final_state, _, _, _ =  minimize(fun,0.,optimizer=sgd, schedule=0.1, maxiter=50)
print("optimal x=",final_state.packed_state[0][0])

optimal x= 0.9952776335171303
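
The same "cond inside scan" pattern can be sketched without the optimizer library. The version below is a simplified stand-in, not the `minimize` function above: it assumes a scalar input and plain gradient descent, and the name `gd_with_early_stop` is hypothetical. Once the gradient falls below `gtol`, the remaining iterations are no-ops, so the first element of the carry counts the steps actually taken.

```python
import jax
import jax.numpy as jnp

def gd_with_early_stop(fun, x0, lr=0.1, maxiter=50, gtol=1e-2):
    grad_fun = jax.grad(fun)

    def step_once(carry):
        step, x, g = carry
        x = x - lr * g                      # plain gradient-descent update
        return step + 1, x, grad_fun(x)

    def noop(carry):
        return carry                        # same structure and dtypes as step_once

    def body_fn(carry, _):
        g = carry[2]
        return jax.lax.cond(jnp.abs(g) < gtol, noop, step_once, carry), None

    x0 = jnp.asarray(x0, dtype=jnp.float32)
    carry = (jnp.asarray(0), x0, grad_fun(x0))
    carry, _ = jax.lax.scan(body_fn, carry, None, length=maxiter)
    return carry

steps, x_opt, _ = gd_with_early_stop(lambda x: (x - 1.0) ** 2, 0.0)
print(steps, x_opt)
```

Both branches of `jax.lax.cond` must return a carry with the same structure, shapes and dtypes, which is why `noop` simply returns its input.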


## Auto-diff and scan/fori_loop

Let us define $f:\mathbb{R}^n \to \mathbb{R}^3$ as
$$
f(x_1,\dots,x_n) = \begin{pmatrix}
\sum_{i=1}^n x_i \\
\sum_{i=1}^n x_i^2 \\
\sum_{i=1}^n x_i^3
\end{pmatrix}
$$
The Jacobian is
$$
\nabla f(x) =  \begin{pmatrix}
\frac{\partial f_1}{\partial x_1}=1 & \dots & \frac{\partial f_1}{\partial x_n}\\
\frac{\partial f_2}{\partial x_1}=2x_1& \ddots & \vdots \\
\frac{\partial f_3}{\partial x_1}=3x_1^2 & \dots & \frac{\partial f_3}{\partial x_n} \\
\end{pmatrix}_{(x)}
$$
We can then compute the sum of each row of the Jacobian:
$$
\begin{pmatrix}
\sum_{i=1}^n 1 \\
\sum_{i=1}^n 2 x_i \\
\sum_{i=1}^n 3 x_i^2
\end{pmatrix}_{(x)}
$$
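
As a quick sanity check of the formulas above, here is a sketch that writes $f$ directly (without any loop primitive) and compares `jacfwd` with the analytic Jacobian. The names `f_direct` and `jac_analytic` are illustrative and are not used elsewhere in this notebook.

```python
import jax
import jax.numpy as jnp

def f_direct(x):
    # f(x) = (sum x_i, sum x_i^2, sum x_i^3)
    return jnp.stack([jnp.sum(x), jnp.sum(x**2), jnp.sum(x**3)])

def jac_analytic(x):
    # Rows: d f_1/d x_j = 1, d f_2/d x_j = 2 x_j, d f_3/d x_j = 3 x_j^2
    return jnp.stack([jnp.ones_like(x), 2 * x, 3 * x**2])

x = jnp.arange(0.0, 5.0)               # x = (0, 1, 2, 3, 4)
J = jax.jacfwd(f_direct)(x)
row_sums = jnp.sum(J, axis=1)
print(jnp.allclose(J, jac_analytic(x)), row_sums)
```

For this small input the row sums are $n=5$, $2\sum_i x_i = 20$ and $3\sum_i x_i^2 = 90$.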


Using scan

In [None]:
@jit
def f(x):

  def body(a, i):
    return a, jnp.sum(a**i) 

  _, res = jax.lax.scan(body,x,jnp.array([1,2,3]))

  return res

In [None]:
N=10
x=jnp.arange(0.,N+1)

In [None]:
f(x), N*(N+1)//2,  N*(N+1)*(2*N+1)//6, N*N*(N+1)*(N+1)//4

(Array([  55.,  385., 3025.], dtype=float64), 55, 385, 3025)

In [None]:
jnp.sum((jacfwd(f))(x), axis=1), N+1, N*(N+1),  N*(N+1)*(2*N+1)//2

(Array([  11.,  110., 1155.], dtype=float64), 11, 110, 1155)

In [None]:
jnp.sum((jacrev(f))(x), axis=1)

Array([  11.,  110., 1155.], dtype=float64)

Using a fori_loop

In [None]:
@jit
def g(x):

  def body(i,val):
    v1, v2, v3, a = val 
    return v1+a[i], v2+a[i]**2, v3+a[i]**3, a

  val_init=(0.,0.,0.,x)
  N=x.shape[0]
  res = jax.lax.fori_loop(0,N,body,val_init)
  return jnp.array(list(res[:-1]))

In [None]:
g(x)

Array([  55.,  385., 3025.], dtype=float64)

In [None]:
jnp.sum((jacfwd(g))(x), axis=1)

Array([  11.,  110., 1155.], dtype=float64)

In [None]:
jnp.sum((jacrev(g))(x), axis=1)

Array([  11.,  110., 1155.], dtype=float64)

In the examples above, the upper loop index `N=x.shape[0]` is known at compilation time because **array shapes are static**. So even if the array shape changes from one call to another, this only triggers a recompilation, and we can use either jacfwd or jacrev.

BUT there are cases where the lower or upper index is not known at compilation time; in that case, use scan.
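
One way to see the difference is to inspect the jaxpr. The sketch below assumes the primitive names printed by `jax.make_jaxpr` are `scan` and `while`; the function names are illustrative. With Python-integer bounds, `fori_loop` is lowered to a scan; with a traced bound it is lowered to a while loop.

```python
import jax
import jax.numpy as jnp

def body(i, v):
    return v + i

def static_bounds(v):
    return jax.lax.fori_loop(0, 5, body, v)   # bounds are Python ints

def traced_bounds(v, n):
    return jax.lax.fori_loop(0, n, body, v)   # n is traced by make_jaxpr

v0 = jnp.float32(0.0)
static_jaxpr = str(jax.make_jaxpr(static_bounds)(v0))
traced_jaxpr = str(jax.make_jaxpr(traced_bounds)(v0, 5))
print("scan" in static_jaxpr, "while" in traced_jaxpr)
```

Only the scan-based lowering supports reverse-mode differentiation, so `jax.grad(static_bounds)` works here.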



# Crash study:  JIT partialization is a solution here

As an example following the previous ones, suppose we truncate the loop at an upper bound "N" as follows:

In [None]:
#@partial(jit, static_argnums=(1,))  # switch to this after experiencing the jacrev crash
@jit
def h(x,N):
  print("N:",N, "\nx:",x)  # with JIT you will see if N,x are traced
  jax.debug.print("N={}",N) # to debug the true value
  def body(i,val):
    v1, v2, v3, a = val 
    return v1+a[i], v2+a[i]**2, v3+a[i]**3, a

  val_init=(0.,0.,0.,x)
  res = jax.lax.fori_loop(0,N,body,val_init)
  return jnp.array(list(res[:-1]))

The compilation of "h" is ok 

In [None]:
y = jax.random.uniform(jax.random.PRNGKey(10),shape=(100,))
h(y,10)

N: Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)> 
x: Traced<ShapedArray(float64[100])>with<DynamicJaxprTrace(level=1/0)>
N=10


Array([6.47489364, 4.8713031 , 4.04653465], dtype=float64)

The forward derivative wrt x is ok

In [None]:
jnp.sum((jacfwd(h))(y, 10), axis=1)

N=10


Array([10.        , 12.94978727, 14.61390931], dtype=float64)

But not the backward derivative

In [None]:
jnp.sum((jacrev(h))(y, 10), axis=1) # crashes: run it to see the traceback.

N=10


ValueError: ignored

The reason for the crash is that, during compilation, JAX treats N as a "Traced" object rather than a static variable, so fori_loop is implemented as a while_loop, which does not support reverse-mode differentiation.

One possibility in this example is to tell JIT not to trace "N"; this is done using `@partial(jit, static_argnums=(1,))`.

You can now go to the "h" definition and replace "@jit" by its partialization.

Notice that the same partialization is also needed for a scan version.
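
For reference, here is a self-contained sketch of the partialized version. It is the same loop as in "h" without the debug prints; the name `h_static` is only used here. With "N" static, the trip count is known at trace time and both forward and reverse Jacobians are available.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=(1,))   # N is treated as a compile-time constant
def h_static(x, N):
    def body(i, val):
        v1, v2, v3, a = val
        return v1 + a[i], v2 + a[i]**2, v3 + a[i]**3, a

    zero = jnp.zeros((), dtype=x.dtype)
    res = jax.lax.fori_loop(0, N, body, (zero, zero, zero, x))
    return jnp.stack(res[:-1])

y = jax.random.uniform(jax.random.PRNGKey(10), shape=(100,))
J_fwd = jax.jacfwd(h_static)(y, 10)
J_rev = jax.jacrev(h_static)(y, 10)
print(jnp.allclose(J_fwd, J_rev))
```

The price to pay is one recompilation for every new value of "N".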


In [None]:
#@partial(jit, static_argnums=(1,))
@jit
def fnew(x,N):
  
  tmp=x[:N]  # requires N to be static (not traced)

  def body(a, i):
    return a, jnp.sum(a**i) 

  _, res = jax.lax.scan(body,tmp,jnp.array([1,2,3]))

  return res

In [None]:
fnew(y,10)

IndexError: ignored

In [None]:
jnp.sum((jacrev(fnew))(y, 10), axis=1)

IndexError: ignored
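
Besides making "N" static, an alternative that is not used in this notebook is to keep the full static shape and mask the entries beyond "N". This is a sketch under the assumption that zeroed entries do not change the result, which holds here because the powers are at least 1. The name `fnew_masked` is hypothetical.

```python
import jax
import jax.numpy as jnp

@jax.jit
def fnew_masked(x, N):
    # Keep the full (static) shape; zero out the entries with index >= N.
    mask = jnp.arange(x.shape[0]) < N
    tmp = jnp.where(mask, x, 0.0)

    def body(a, i):
        return a, jnp.sum(a**i)

    _, res = jax.lax.scan(body, tmp, jnp.array([1, 2, 3]))
    return res

y = jax.random.uniform(jax.random.PRNGKey(10), shape=(100,))
print(fnew_masked(y, 10))
print(jnp.sum(jax.jacrev(fnew_masked)(y, 10), axis=1))
```

Here "N" may stay traced, so no recompilation is needed when it changes, at the cost of always looping over the full array.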

# fori_loop vs scan

Let us compute, for two arrays $x$ and $y$ of length $N$,
$$
S(x,y) = \frac{1}{N} \sum_{i=1}^{N} x_i y_i
$$

Of course, if we do not care about the loop primitives, S can be computed directly as
```python
x@y / x.shape[0]
``` 

In [None]:
_, key1, key2 = jax.random.split(jax.random.PRNGKey(10),3) 
N = 100_000
x = jax.random.uniform(key1, shape=(N,), dtype=jnp.float32)
y = jax.random.uniform(key2, shape=(N,), dtype=jnp.float32)


In [None]:
x@y / x.shape[0]

Array(0.25065652, dtype=float32)

In [None]:
@jit
def loop_fun(x, y):

    n = x.shape[0]

    def body(i, curr): 
        return curr + x[i] * y[i]

    return jax.lax.fori_loop(0, n, body, 0.) / n

@jit
def scan_fun(x, y):
    n = x.shape[0]

    def body(carry, x):
        curr, i = carry
        return (curr + x * y[i], i+1), None

    (res, _), _ = jax.lax.scan(body, (0., 0), x)
    return res / n

In [None]:
loop_fun(x,y)

Array(0.2506565, dtype=float32)

In [None]:
scan_fun(x, y)

Array(0.2506565, dtype=float32)

In [None]:
jnp.all(jax.jacrev(scan_fun)(x,y) ==  jax.jacrev(loop_fun)(x,y))

Array(True, dtype=bool)
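
As a variant of `scan_fun`, scan can also iterate over both arrays at once by passing a tuple as `xs`, which removes the index from the carry. This is a small sketch with an illustrative name, `scan_pair_fun`.

```python
import jax
import jax.numpy as jnp

@jax.jit
def scan_pair_fun(x, y):
    def body(acc, xy):
        xi, yi = xy                       # one element of each array per step
        return acc + xi * yi, None

    acc0 = jnp.zeros((), dtype=x.dtype)
    res, _ = jax.lax.scan(body, acc0, (x, y))
    return res / x.shape[0]

x_small = jnp.array([1.0, 2.0, 3.0])
y_small = jnp.array([4.0, 5.0, 6.0])
print(scan_pair_fun(x_small, y_small))    # (4 + 10 + 18) / 3
print(jax.grad(scan_pair_fun)(x_small, y_small))
```

The gradient with respect to the first argument is $y/N$, as expected from the definition of $S$.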

# While loops: avoid them when possible... they crash under JIT or in reverse-mode Jacobian

Think differently!

Here is an example: let
$$
X = (x_0,x_1,\dots,x_{n-1})\in [0,1]^n, \quad \sum_{i=0}^{n-1} x_i =  1
$$
Goal: repeatedly set the largest of the $(x_i)_i$ to $0$ as long as $\sum_{i=0}^{n-1} x_i> 0.5$

## Simple Python while loop (JIT crashes here)

In [181]:
def func(x):
  x  = x/jnp.sum(x)

  while jnp.sum(x)>0.5:
    idx = jax.lax.argmax(x,0,index_dtype=jnp.int16)
    x = x.at[idx].set(0.)
  return x


In [180]:
x = jax.random.uniform(jax.random.PRNGKey(10), shape=(10,))
print(x)
y = func(x)
y, jnp.sum(y)

[0.63337177 0.50169736 0.65120645 0.12358142 0.96813596 0.83877037
 0.91412169 0.60858754 0.88261905 0.71854831]


(Array([0.09258955, 0.07334071, 0.09519672, 0.01806577, 0.        ,
        0.        , 0.        , 0.08896646, 0.        , 0.1050411 ],      dtype=float64),
 Array(0.4732003, dtype=float64))

In [172]:
jacfwd_f= jax.jacfwd(func)(x)

In [173]:
jnp.all(jax.jacrev(func)(x) == jacfwd_f)

Array(True, dtype=bool)

In [174]:
@jit
def jfunc(x):
  x  = x/jnp.sum(x)

  while jnp.sum(x)>0.5:
    idx = jax.lax.argmax(x,0,index_dtype=jnp.int16)
    x = x.at[idx].set(0.)

  return x

In [175]:
x = jax.random.uniform(jax.random.PRNGKey(10), shape=(10,))
y = jfunc(x)
#y, jnp.sum(y)

ConcretizationTypeError: ignored

## With jax.lax.while_loop: jacrev crashes here

```python
def while_loop(cond_fun, body_fun, init_val):
  val = init_val
  while cond_fun(val):
    val = body_fun(val)
  return val
```

In [188]:
@jit
def new_func(x):
  x  = x/jnp.sum(x)

  def cond(val):
    sum, _ = val
    # `if sum>0.5: return True` would crash with a ConcretizationTypeError, even w/o JIT
    return sum > 0.5
  def body(val):
    sum, x = val
    idx = jax.lax.argmax(x,0,index_dtype=jnp.int16)
    x = x.at[idx].set(0.)
    return jnp.sum(x), x

  val= jax.lax.while_loop(cond, body,(1.,x))

  return val[1]

In [189]:
x = jax.random.uniform(jax.random.PRNGKey(10), shape=(10,))
y = new_func(x)
y, jnp.sum(y)

(Array([0.09258955, 0.07334071, 0.09519672, 0.01806577, 0.        ,
        0.        , 0.        , 0.08896646, 0.        , 0.1050411 ],      dtype=float64),
 Array(0.4732003, dtype=float64))

In [192]:
jnp.all(jax.jacfwd(new_func)(x) == jacfwd_f)

Array(True, dtype=bool)

In [193]:
jnp.all(jax.jacrev(new_func)(x) == jacfwd_f)

ValueError: ignored
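
The same behaviour can be reproduced on a much smaller example. The sketch below uses a hypothetical function `halve_until_small` whose trip count depends on the value of its input: forward mode goes through, while reverse mode raises (a ValueError in the run above; the code catches a generic `Exception` to stay version-agnostic).

```python
import jax
import jax.numpy as jnp

def halve_until_small(x):
    # Halve x while it exceeds 1; the trip count depends on the value of x.
    return jax.lax.while_loop(lambda v: v > 1.0, lambda v: v / 2.0, x)

x0 = jnp.float32(8.0)
fwd = jax.jacfwd(halve_until_small)(x0)   # forward mode works: 8 -> 4 -> 2 -> 1

try:
    jax.jacrev(halve_until_small)(x0)     # reverse mode is not supported
    rev_ok = True
except Exception as err:
    rev_ok = False

print(fwd, rev_ok)
```

Three halvings are applied, so the forward derivative is $1/8$.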

The solution: use jax.lax.cond & jax.lax.scan 

In [176]:
def jfunc_bis(x):
  n = x.size
  x  = x/jnp.sum(x)

  def noop(carry):
    return carry

  @jit
  def step_once(carry):
    (step, sum, x) = carry
    idx = jax.lax.argmax(x,0,index_dtype=jnp.int16)
    x = x.at[idx].set(0.)
    carry = (step+1, jnp.sum(x), x)
    return carry

  @jit
  def body_fn(carry, dummy):
      (step, sum, x) = carry
      return jax.lax.cond(sum < 0.5, noop, step_once, carry), None


  carry = (0, 1., x)
  carry, _ = jax.lax.scan(body_fn, carry, None, length=n)  
  return carry[2]

In [177]:
x = jax.random.uniform(jax.random.PRNGKey(10), shape=(10,))
print(x)
y = jfunc_bis(x)
y

[0.63337177 0.50169736 0.65120645 0.12358142 0.96813596 0.83877037
 0.91412169 0.60858754 0.88261905 0.71854831]


Array([0.09258955, 0.07334071, 0.09519672, 0.01806577, 0.        ,
       0.        , 0.        , 0.08896646, 0.        , 0.1050411 ],      dtype=float64)

In [178]:
jnp.all(jax.jacfwd(jfunc_bis)(x) == jacfwd_f)

Array(True, dtype=bool)

In [179]:
jnp.all(jax.jacrev(jfunc_bis)(x) == jacfwd_f)

Array(True, dtype=bool)

# Takeaway message
- Control flow: use jax.lax (cond, while_loop, fori_loop, scan)
- Reminder: jax.numpy.where and jax.numpy.piecewise also handle conditions on the values of arrays.
- One can use auto-diff on functions that use jax.lax control flow
- You have seen how to print values for debugging JIT functions
- Sometimes you need to turn off the tracing of arguments for JIT
- While is tricky: JIT or jacrev will be a problem! It forces you to think differently
- As a rule of thumb, when you hesitate between fori_loop/while_loop and scan: **"always scan when you can!"**
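
As a short illustration of the `jax.numpy.where` / `jax.numpy.piecewise` reminder, here is a sketch of an element-wise condition written both ways. The function (minus x for negative values, x squared otherwise) is chosen only for the example.

```python
import jax.numpy as jnp

x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])

# Element-wise selection: both branches are evaluated, then one is picked.
y_where = jnp.where(x < 0, -x, x**2)

# Same function with a list of conditions and a list of callables.
y_piece = jnp.piecewise(x, [x < 0, x >= 0], [lambda v: -v, lambda v: v**2])

print(y_where, y_piece)
```

Both forms work under JIT and auto-diff because the condition is applied to array values without any Python-level branching.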