# Lecture 2: Tracers, JIT compilation and Sharp Edges


We saw that JAX provides a NumPy-like user API for array-based computing.

It can apply composable function transformations to our Python/NumPy code and hand the result to the XLA compiler, which runs it efficiently on CPU, GPU or TPU.


$$\text{Python} \rightarrow \text{Intermediate Representation} \rightarrow \text{Transformations}$$


Examples of composable transformations include automatic differentiation (`jax.grad`), JIT compilation (`jax.jit`), automatic vectorization (`jax.vmap`) and parallelization across multiple devices (`jax.pmap`).
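To make "composable" concrete, here is a minimal sketch (assuming a recent JAX release) that stacks three transformations on one small function. The function `f` and the batch `xs` are illustrative choices, not part of the lecture's code.

```python
import jax
import jax.numpy as jnp

# Illustrative scalar function: f(x) = -2 * log(x)
def f(x):
    return -2.0 * jnp.log(x)

df = jax.grad(f)                # derivative of f, i.e. -2 / x
fast_df = jax.jit(df)           # JIT-compile the derivative
batched_df = jax.vmap(fast_df)  # map the compiled derivative over a batch

xs = jnp.array([1.0, 2.0, 4.0])
print(batched_df(xs))
```

Each transformation takes a function and returns a new function, which is why they can be nested. Here the result should be $-2/x$ evaluated at each element of `xs`.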

We explored automatic differentiation in the last lecture.

In this lecture, we will explore JIT compilation and look under the hood at how JAX builds these intermediate representations and transforms them.


JAX uses a technique called function tracing (we will explore this in a bit) to convert Python code into a sequence of primitive operations, known as an intermediate representation (IR). In JAX, this IR is called a jaxpr (JAX expression).
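One way to look at this IR is `jax.make_jaxpr`, which traces a function on an example input and returns the resulting jaxpr. The sketch below uses a hypothetical function `neg_two_log`, chosen only for illustration. The printed jaxpr should list primitives such as `log` and `mul`, although its exact formatting varies between JAX versions.

```python
import jax
import jax.numpy as jnp

# Hypothetical example function, used only to show the IR
def neg_two_log(x):
    return jnp.multiply(-2.0, jnp.log(x))

# Trace the function with an example input and print the resulting jaxpr
jaxpr = jax.make_jaxpr(neg_two_log)(2.0)
print(jaxpr)
```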

During JIT compilation, JAX lowers the IR to the XLA compiler. XLA optimizes it (for example, by fusing operations) and compiles it into efficient executable code for the target CPU, GPU or TPU.
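JAX's ahead-of-time lowering API lets us inspect this pipeline step by step. The sketch below assumes a recent JAX version, and `nll` is an illustrative function. It traces and lowers the function to the program text handed to XLA (StableHLO in current releases), then asks XLA to compile it for the default backend.

```python
import jax
import jax.numpy as jnp

# Illustrative function to compile
def nll(x):
    return jnp.multiply(-2.0, jnp.log(x))

x = jnp.float32(2.0)

# Trace nll and lower the resulting jaxpr to the program handed to XLA
lowered = jax.jit(nll).lower(x)
hlo_text = lowered.as_text()

# Ask XLA to optimize and compile the program for the default backend
compiled = lowered.compile()

print(hlo_text)
print(compiled(x))
```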

Once a function has been JIT-compiled, JAX caches the compiled code so that it can be reused in subsequent calls with arguments of the same shapes and dtypes. Calling it with a new shape or dtype triggers a fresh trace and compilation.

This means that the performance benefits of JIT compilation persist across calls, and the cost of compilation is amortized over many calls.
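Caching can be observed with a Python side effect, because the Python body of a jitted function only runs while JAX traces it. This is a minimal sketch. The counter `trace_count` and the function `scaled` are hypothetical names used only for this illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical counter: incremented only when the Python body below is traced
trace_count = 0

@jax.jit
def scaled(x):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on cached calls
    return -2.0 * x

scaled(jnp.float32(1.0))  # first call: trace + compile
scaled(jnp.float32(3.0))  # same shape and dtype: reuses the cached executable
print(trace_count)  # → 1

scaled(jnp.ones(3))       # new shape (3,): traces and compiles again
print(trace_count)  # → 2
```

The second float32 scalar reuses the cached executable. The length-3 array forces a new trace because its shape differs.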


JAX works by first converting your Python code into an intermediate representation, as we saw in the last lecture.

This step is known as $\textit{tracing}$.

How does it work?

Let's say we have a function:


```python
from jax import lax
from jax import numpy as jnp

# Define a toy function that calculates the negative log likelihood of a scalar value
# (as written, log_x is computed but never used, so NLL simply returns -2 * x)
def NLL(x):
    log_x = jnp.log(x)
    nll = jnp.multiply(-2, x)
    return nll

print(NLL(2))
print(NLL(4))
```

```
-4
-8
```
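Tracing does not need real data. As a sketch of this idea, `jax.eval_shape` traces a function with abstract inputs that carry only a shape and a dtype. It reports the shape and dtype of the output without doing any arithmetic. The `NLL` below is a copy of the toy function above so that the snippet runs on its own. `scalar_spec` and `vector_spec` are illustrative names.

```python
import jax
import jax.numpy as jnp

# Copy of the toy function above (log_x is unused, as in the original)
def NLL(x):
    log_x = jnp.log(x)
    return jnp.multiply(-2, x)

# Abstract inputs: a shape and a dtype, but no values
scalar_spec = jax.ShapeDtypeStruct((), jnp.float32)
vector_spec = jax.ShapeDtypeStruct((3,), jnp.float32)

# Trace NLL on the abstract inputs; no numbers are computed
scalar_out = jax.eval_shape(NLL, scalar_spec)
vector_out = jax.eval_shape(NLL, vector_spec)

print(scalar_out)
print(vector_out)
```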



Okay, so how does JAX see this? Let's find out!

```python
from jax import lax
from jax import numpy as jnp
import jax

# Define a toy function that calculates the negative log likelihood of a scalar value
@jax.jit
def NLL(x):
    # Print the argument as JAX sees it inside the function
    print(x)
    log_x = jnp.log(x)
    nll = jnp.multiply(-2, x)
    return nll

# Let's apply a JAX transformation to see exactly how JAX works with the input
print(jax.grad(NLL)(2.0))
```

```
Traced<ShapedArray(float32[], weak_type=True)>with<JVPTrace(level=2/1)> with
  primal = Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
  tangent = Traced<ShapedArray(float32[], weak_type=True)>with<JaxprTrace(level=1/1)> with
    pval = (ShapedArray(float32[], weak_type=True), *)
    recipe = LambdaBinding()
-2.0
```
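The output shows that inside the transformed function, `x` is not the number 2.0. It is a tracer whose abstract value, `ShapedArray(float32[], weak_type=True)`, records only its shape and dtype; the exact text varies between JAX versions. The final `-2.0` is the gradient of $-2x$.

The sketch below shows two consequences of this, assuming a recent JAX release. First, Python's `print` only sees the tracer, whereas `jax.debug.print` is staged into the computation and prints the runtime value. Second, Python control flow that needs a concrete boolean fails under `jax.jit`. The functions `nll` and `abs_like` are illustrative.

```python
import jax
import jax.numpy as jnp

@jax.jit
def nll(x):
    # Runs once, while JAX traces the function: prints a tracer, not a number
    print("trace time:", x)
    # Staged into the compiled computation: prints the concrete value when it runs
    jax.debug.print("run time: {}", x)
    return jnp.multiply(-2.0, x)

result = nll(2.0)

@jax.jit
def abs_like(x):
    # Python's `if` needs a concrete bool, but under jit `x > 0` is a tracer
    if x > 0:
        return x
    return -x

try:
    abs_like(2.0)
    raised = False
except jax.errors.ConcretizationTypeError:
    raised = True

print(result, raised)
```

A common fix for the second case is to express the branch with `jnp.where` or `jax.lax.cond`, which operate on traced values.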


