# JAX Primer

GenJAX is Gen & _JAX_. JAX is an array programming system for Python that supports automatic differentiation, user-directed parallelism, and JIT compilation via XLA to native code.

While JAX is a powerful tool, [it does have its sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html).

In [None]:
import jax

## Common gotchas

Here's a quick list of things that you can expect to see if you're not careful with JAX.

### Don't use Python control flow

#### If statements

In [None]:
def f(x):
    if x > 5:
        return 3.0
    else:
        return 2.0


jax.jit(f)(3.0)

JAX works _by tracing_ your code. It passes symbolic values (tracers) through your code and records the operations performed on them.

_This error_ is saying that Python is attempting to convert _a JAX tracer value_ into a boolean, to resolve the conditional.

This is disallowed -- because _we don't know the value of the array_ at trace time, when Python runs your function to build the compiled program. Because we don't know the value, how could we know which branch to take?
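One way around this is to express the branch with a JAX primitive, so the choice is made when the compiled program runs rather than while Python is tracing. The sketch below rewrites the example with `jax.lax.cond` and, alternatively, `jax.numpy.where`. The names `f_cond` and `f_where` are made up for illustration.

```python
import jax
import jax.numpy as jnp


def f_cond(x):
    # Both branches are traced; which one runs is decided at run time.
    # The operand x is passed to whichever branch is taken (and ignored here).
    return jax.lax.cond(x > 5, lambda _: 3.0, lambda _: 2.0, x)


def f_where(x):
    # Elementwise select: both values are computed, then one is picked.
    return jnp.where(x > 5, 3.0, 2.0)


print(jax.jit(f_cond)(3.0))
print(jax.jit(f_where)(3.0))
```

`jnp.where` is usually the simpler choice when both sides are cheap to compute. `jax.lax.cond` is the closer analogue of a Python `if` when the branches are expensive.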

#### For loops

In contrast to `if`, `for` will "just work" -- except it's not a good idea to use it.

In [None]:
def f(x):
    y = 0.0
    for i in x:
        y += i
    return y


jax.jit(f)(jax.numpy.arange(10))

Why is it not a good idea to use this? Let's look at the _lowered Jaxpr code_ which JAX produces:

In [None]:
jax.make_jaxpr(f)(jax.numpy.arange(10))

JAX has _unrolled_ the entire loop. While this example JIT compiles relatively quickly, XLA compile time grows roughly quadratically with the size of the code.
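To make the unrolling concrete, here is a small sketch that counts the top-level equations in the Jaxpr for a few input sizes. The helper `num_equations` is hypothetical, written only for this illustration. The exact counts depend on the JAX version, so none are quoted here. The point is only that the count grows with the length of the input.

```python
import jax
import jax.numpy as jnp


def f(x):
    y = 0.0
    for i in x:
        y += i
    return y


def num_equations(n):
    # Hypothetical helper: trace f on an input of length n and count
    # the top-level equations in the resulting Jaxpr.
    closed_jaxpr = jax.make_jaxpr(f)(jnp.arange(n))
    return len(closed_jaxpr.jaxpr.eqns)


for n in (10, 20, 40):
    print(n, num_equations(n))
```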

For more complicated programs, compilation will take longer and longer...
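A common way to avoid the unrolling is a structured loop primitive such as `jax.lax.fori_loop`, which traces the loop body once regardless of the trip count. The sketch below is one possible rewrite of the summation, not the only one (`jax.lax.scan` or simply `jnp.sum` would also work). The name `sum_fori` and the explicit cast to `float32` are choices made for this example. The cast keeps the loop carry at a fixed dtype, which `fori_loop` requires.

```python
import jax
import jax.numpy as jnp


def sum_fori(x):
    # Cast up front so the accumulator has the same dtype on every iteration.
    x = x.astype(jnp.float32)

    def body(i, acc):
        # i is a traced loop index; acc is the running sum.
        return acc + x[i]

    return jax.lax.fori_loop(0, x.shape[0], body, jnp.float32(0.0))


print(jax.jit(sum_fori)(jnp.arange(10)))

# The number of top-level Jaxpr equations no longer depends on the input length.
small = jax.make_jaxpr(sum_fori)(jnp.arange(10))
large = jax.make_jaxpr(sum_fori)(jnp.arange(100))
print(len(small.jaxpr.eqns), len(large.jaxpr.eqns))
```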