# Lecture 2: Tracers, JIT compilation and Sharp Edges

<br>

<br>

We saw that JAX has a NumPy-like user API for array-based computing.

It can apply composable function transformations to our Python/NumPy code and use the XLA compiler to run the result very fast on CPU, GPU or TPU.

<br>

$$\text{Python} \rightarrow \text{Intermediate Representation} \rightarrow \text{Transformations}$$

<br>

<br>

<br>

Some examples of composable transformations include automatic differentiation, JIT compilation, parallelization on multi-core hardware, etc.

We explored automatic differentiation in the last lecture.

In this lecture, we will explore JIT compilation as well as find out a bit about how JAX works under the hood with these intermediate representations and transformations.

<br>

<br>

### JIT Compilation

<br>

JIT stands for Just-In-Time compilation, as opposed to AOT (Ahead-Of-Time) compilation.

As the name suggests, the code is compiled $\textit{just in time}$: when the function is first called, rather than before the program runs.


<br>

<br>

During JIT compilation, JAX traces the function into a sequence of primitive `lax` operations, and the XLA compiler then optimizes this sequence into efficient executable code for CPU, GPU or TPU.

Once a function has been JIT-compiled, JAX caches the resulting executable so that it can be reused in subsequent calls with inputs of the same shape and dtype.

<br>
<br>

$$\textbf{This is where the power of JIT compilation comes in - after an initial compilation phase,} \\ \textbf{the subsequent calls to the JIT-compiled function are super fast!}$$ 

<br>

<br>

In [76]:
import jax
import jax.numpy as jnp
from jax import jit

def fn(tuple_arr):
    
    return jnp.sum(tuple_arr ** 2 - tuple_arr ** 3 - tuple_arr)

%timeit -r1 -n1 fn(jnp.ones(100)).block_until_ready()

fn_compiled = jit(fn)
%timeit -r1 -n1 fn_compiled(jnp.ones(100)).block_until_ready()

1.47 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
56.5 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


<br>

<br>

As mentioned before, JIT compilation only pays off when we intend to call a function many times, for example when minimizing a loss function by repeatedly evaluating the loss and its gradient.

During the first call, JAX traces the function and XLA compiles it, which takes some time; hence the slower first call.
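`%timeit -r1 -n1` measures a single call, so the numbers below mix one-off costs with execution time. Here is a minimal sketch, using only the standard-library `time.perf_counter`, that times the first call separately from an average over later calls. The names `fn_jit`, `first_call` and `later_call` are our own, and the exact timings will depend on your machine:

```python
import time

import jax
import jax.numpy as jnp


def fn(x):
    return jnp.sum(x ** 2 - x ** 3 - x)


fn_jit = jax.jit(fn)
x = jnp.ones(100)

# First call: JAX traces fn and XLA compiles it before it runs.
start = time.perf_counter()
fn_jit(x).block_until_ready()
first_call = time.perf_counter() - start

# Later calls with the same shape and dtype reuse the cached executable.
n_runs = 100
start = time.perf_counter()
for _ in range(n_runs):
    fn_jit(x).block_until_ready()
later_call = (time.perf_counter() - start) / n_runs

print(f"first call: {first_call * 1e3:.2f} ms, later calls: {later_call * 1e3:.3f} ms")
```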

<br>

<br>

In [79]:
# Let's try again

%timeit -r1 -n1 fn(jnp.ones(100)).block_until_ready()

%timeit -r1 -n1 fn_compiled(jnp.ones(100)).block_until_ready()

1.64 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
490 µs ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


<br>

In [80]:
# Let's try again with a different sized array

%timeit -r1 -n1 fn(jnp.ones(1000)).block_until_ready()

%timeit -r1 -n1 fn_compiled(jnp.ones(1000)).block_until_ready()

1.3 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
42.9 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


<br>

In [83]:
# Re-run the same computation with an array of the same shape (1000) as in the previous cell

%timeit -r1 -n1 fn(jnp.ones(1000)).block_until_ready()

%timeit -r1 -n1 fn_compiled(jnp.ones(1000)).block_until_ready()

2.45 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
560 µs ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


<br>

<br>

To better understand what's going on, let's look at what JAX does under the hood.

<br>

<br>

In [85]:
import jax
import jax.numpy as jnp
from jax import jit

# Instead of calling jit(fn), we can apply @jit as a decorator; the effect is the same

@jit
def fn(tuple_arr):
    
    print(tuple_arr)   # Let's print out the input we passed to the function
    
    return jnp.sum(tuple_arr ** 2 - tuple_arr ** 3 - tuple_arr)


print(fn(jnp.ones(100)))


Traced<ShapedArray(float32[100])>with<DynamicJaxprTrace(level=0/1)>
-100.0


<br>

<br>

JAX first takes the shape and dtype of the input and uses them to create an abstract value called a `ShapedArray` (wrapped in a `Traced` object, as the printout shows) that has a specific data type and size.

But it has no value!
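A minimal sketch of what this means in practice (the function names `show_abstract` and `needs_value` are our own): inside a jitted function, a plain Python `print` runs at trace time and can only report the tracer's shape and dtype, `jax.debug.print` is needed to see values at run time, and anything that demands a concrete Python number, such as `float()`, fails.

```python
import jax
import jax.numpy as jnp


@jax.jit
def show_abstract(x):
    # This Python print runs once, at trace time: the tracer knows its shape and dtype...
    print("trace time:", x.shape, x.dtype)
    # ...but not its values, so jax.debug.print is used to see them at run time.
    jax.debug.print("run time: first element = {v}", v=x[0])
    return x * 2


out = show_abstract(jnp.arange(3.0))


@jax.jit
def needs_value(x):
    # float() needs a concrete number, which a tracer cannot provide.
    return float(x[0])


try:
    needs_value(jnp.arange(3.0))
    concrete_failed = False
except jax.errors.ConcretizationTypeError:
    concrete_failed = True
print("float() on a tracer failed:", concrete_failed)
```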

<br>

<br>

In [86]:
print(jax.make_jaxpr(fn)(jnp.ones(100)))

Traced<ShapedArray(float32[100])>with<DynamicJaxprTrace(level=1/1)>
{ lambda ; a:f32[100]. let
    b:f32[] = xla_call[
      call_jaxpr={ lambda ; c:f32[100]. let
          d:f32[100] = integer_pow[y=2] c
          e:f32[100] = integer_pow[y=3] c
          f:f32[100] = sub d e
          g:f32[100] = sub f c
          h:f32[] = reduce_sum[axes=(0,)] g
        in (h,) }
      name=fn
    ] a
  in (b,) }


<br>

<br>

Then, it passes this `ShapedArray` through the primitive `lax` operations, recording each one to build a computation graph for the function. This part of the process is known as $\textit{tracing}$, and the jaxpr printed above is its result.


Note that the input now appears as `a:f32[100]`, a float32 array of size 100, and that is all the tracer knows: it stores no values.
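`jax.eval_shape` exposes the same idea directly: it traces a function with abstract inputs and returns only the shape and dtype of the output, without computing anything. A minimal sketch (here `fn` is the un-jitted version of the function above, and `abstract_in`/`abstract_out` are our own names); calling `jax.make_jaxpr` on the un-jitted function also shows the primitives without the outer `xla_call` wrapper:

```python
import jax
import jax.numpy as jnp


def fn(x):
    return jnp.sum(x ** 2 - x ** 3 - x)


# eval_shape traces fn with an abstract input only: no data, no arithmetic.
abstract_in = jax.ShapeDtypeStruct((100,), jnp.float32)
abstract_out = jax.eval_shape(fn, abstract_in)
print(abstract_out)

# The jaxpr of the un-jitted function lists the lax primitives directly.
jaxpr = jax.make_jaxpr(fn)(jnp.ones(100))
print(jaxpr)
```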

<br>

<br>

Once JAX has built the computation graph of `lax` operations, it hands it to the XLA compiler, which optimizes it and compiles it for the target device (CPU, GPU or TPU).

The compiled executable is then cached.

On later calls with inputs of the same shape and dtype, whatever their values, JAX reuses the cached executable, which is why those calls are so fast. An input with a new shape or dtype triggers a fresh trace and compilation, which is why the first call on the 1000-element array was slow again.
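Because Python side effects inside a jitted function only run while JAX is tracing, a counter is a handy (if hacky) way to watch the cache at work. A minimal sketch, where `trace_count` is a hypothetical global of our own:

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, incremented only when JAX traces fn


@jax.jit
def fn(x):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on cached calls
    return jnp.sum(x ** 2 - x ** 3 - x)


fn(jnp.ones(100))                    # first float32[100] input: trace + compile
fn(jnp.zeros(100))                   # same shape/dtype, new values: cached
fn(jnp.ones(1000))                   # new shape: retrace + recompile
fn(jnp.ones(100, dtype=jnp.int32))   # new dtype: retrace + recompile
print("number of traces:", trace_count)
```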

<br>

<br>

#### Key Summary of JIT compilation

<br>

<br>

By default JAX executes operations one at a time, in sequence.

We saw an example of this with our deliberately badly written function and how JAX converted it.

<br>

<br>

In [3]:
# Let's look at the jaxpr JAX records for our expression. Nothing is simplified yet:
# `div a b` appears three times and `exp b` twice.
import jax
import jax.numpy as jnp

def fnc_jax(x1, x2):
    
    return (jnp.divide(x1,x2) - jnp.exp(x2))*(jnp.sin(jnp.divide(x1,x2)) + jnp.divide(x1,x2) - jnp.exp(x2))

jax.make_jaxpr(fnc_jax)(1.0,1.0)



{ lambda ; a:f32[] b:f32[]. let
    c:f32[] = div a b
    d:f32[] = exp b
    e:f32[] = sub c d
    f:f32[] = div a b
    g:f32[] = sin f
    h:f32[] = div a b
    i:f32[] = add g h
    j:f32[] = exp b
    k:f32[] = sub i j
    l:f32[] = mul e k
  in (l,) }

<br>

<br>

Using a just-in-time (JIT) compilation decorator, sequences of operations can be optimized together and run at once.

Let's see this in action!

<br>

<br>

In [1]:
# JAX also supports ahead-of-time (AOT) compilation: lower() + compile() let us inspect
# the optimized code XLA produces for fnc_jax (defined in the cell above)
from jax import jit

print(jit(fnc_jax).lower(1., 1.).compile().as_text())

$\textbf{Output (recent JAX version, with 64-bit floats enabled, hence f64):}$

HloModule jit_fnc_jax, entry_computation_layout={(f64[],f64[])->f64[]}, allow_spmd_sharding_propagation_to_output={true}

fused_computation {
  param_0.2 = f64[] parameter(0)
  param_1.4 = f64[] parameter(1)
  exponential.0 = f64[] exponential(param_1.4)
  subtract.1 = f64[] subtract(param_0.2, exponential.0)
  sine.0 = f64[] sine(param_0.2)
  add.0 = f64[] add(sine.0, param_0.2)
  subtract.0 = f64[] subtract(add.0, exponential.0)
  ROOT multiply.0 = f64[] multiply(subtract.1, subtract.0)
}

ENTRY main.13 {
  Arg_0.1 = f64[] parameter(0), sharding={replicated}
  Arg_1.2 = f64[] parameter(1), sharding={replicated}
  divide.3 = f64[] divide(Arg_0.1, Arg_1.2)
  ROOT fusion = f64[] fusion(divide.3, Arg_1.2), kind=kLoop, calls=fused_computation
}

<br>

<br>

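Compared with the jaxpr above, XLA has computed `x1 / x2` only once (`divide.3`) and `exp(x2)` only once (`exponential.0`), and fused all the remaining element-wise operations into a single kernel (`fusion`).

Not all JAX code can be JIT compiled, however: under `jit`, array shapes must be static, i.e. known at compile time.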
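To make the last point concrete, here is a minimal sketch (the function names `get_negatives`, `negatives_or_zero` and `make_ones` are our own) of a function whose output shape depends on the input values. It works eagerly but fails under `jit`. Two common workarounds follow: keeping the output shape fixed with `jnp.where`, and marking a shape-determining argument as static with `static_argnums`.

```python
import jax
import jax.numpy as jnp


def get_negatives(x):
    # The output shape depends on the *values* in x, not just its shape.
    return x[x < 0]


x = jnp.array([-1.0, 2.0, -3.0])
eager_result = get_negatives(x)  # fine without jit

try:
    jax.jit(get_negatives)(x)
    jit_failed = False
except jax.errors.NonConcreteBooleanIndexError as err:
    jit_failed = True
    print("jit failed:", type(err).__name__)


# Workaround 1: keep the output shape fixed and fill the rest with zeros.
@jax.jit
def negatives_or_zero(x):
    return jnp.where(x < 0, x, 0.0)


# Workaround 2: treat the shape-determining argument as a static Python value.
def make_ones(n):
    return jnp.ones(n)


make_ones_jit = jax.jit(make_ones, static_argnums=0)
print(negatives_or_zero(x), make_ones_jit(3).shape)
```

With `static_argnums`, every new value of `n` triggers a separate compilation, so this works best for arguments that take only a few distinct values.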


<br>

<br>

### Things to Keep in Mind when using JAX

<br>

<br>

JAX provides the `jax.numpy` module to mimic NumPy's familiar interface.

Under the hood, however, JAX performs its computations using the more powerful, but stricter, `jax.lax` API.

We just explored this with some examples.

<br>

<br>

But `jax.numpy` is not a drop-in replacement for `numpy`.

For example, unlike NumPy arrays, JAX arrays are always immutable.

In [98]:
import jax.numpy as jnp
import numpy as np

arr_jnp = jnp.array([1,2,3,4])

print(type(arr_jnp))

arr_np = np.array([1,2,3,4])

print(type(arr_np))

arr_np[1] = 0
print(arr_np)

arr_jnp[1] = 0  # raises a TypeError: JAX arrays are immutable



<class 'jaxlib.xla_extension.DeviceArray'>
<class 'numpy.ndarray'>
[1 0 3 4]


TypeError: '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

In [99]:
# For updating individual elements, JAX provides an indexed update syntax that returns an updated copy:

arr_jnp = arr_jnp.at[1].set(0)  # .at[].set() returns a new array; re-bind the name to it
print(arr_jnp)


[1 0 3 4]
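`.set()` is only one of the `.at[]` update methods. A minimal sketch showing `set`, `add` and `multiply`, and that the original array is left untouched (the variable names are our own):

```python
import jax.numpy as jnp

arr = jnp.array([1, 2, 3, 4])

updated = arr.at[1].set(0)         # replace element 1
added = arr.at[1].add(10)          # add 10 to element 1
scaled = arr.at[::2].multiply(3)   # multiply elements 0 and 2 by 3

print(arr.tolist())      # the original is unchanged
print(updated.tolist())
print(added.tolist())
print(scaled.tolist())
```

Each call returns a new array. According to the JAX documentation, inside `jit`-compiled code XLA can often perform such updates in place, so the copy is not necessarily paid at run time.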
