# What is Just In Time (JIT) Compilation?

By [definition](https://en.wikipedia.org/wiki/Just-in-time_compilation), JIT compilation means compiling code during execution rather than ahead of time. A system implementing a JIT compiler typically analyses the code being executed continuously and identifies the parts where the speedup gained from compilation or recompilation would outweigh the overhead of compiling them.


In [1]:
import jax
import jax.numpy as jnp

from jax import jit, grad, random

# Jaxprs (JAX Expressions)

Jaxpr is an intermediate language for representing ordinary Python functions. When you transform a function, JAX first traces it into a simple, statically typed intermediate expression (a jaxpr), and the transformation is then applied directly to that jaxpr.

Behind the scenes, a jaxpr describes the function as a flow of data through primitive operations, which can then be compiled using XLA.

1. A jaxpr instance represents a function with one or more typed parameters (input variables) and one or more typed results
2. The inputs and outputs have `types` and are represented as abstract values
3. Not all Python programs can be represented by jaxprs but many scientific computations and machine learning programs can
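The points in the list above can be checked by inspecting a jaxpr programmatically. The following is a minimal sketch; the function `scale_and_sum` is a made-up example, and it assumes the `invars`, `outvars`, and `eqns` attributes of the jaxpr object, which may differ between JAX versions.

```python
import jax
import jax.numpy as jnp


def scale_and_sum(x, y):
    # Hypothetical example function with two typed inputs and one typed result.
    return jnp.sum(x * 2.0 + y)


closed_jaxpr = jax.make_jaxpr(scale_and_sum)(
    jnp.ones((3,), dtype=jnp.float32),
    jnp.ones((3,), dtype=jnp.float32),
)
jaxpr = closed_jaxpr.jaxpr

# Each input and output variable carries an abstract value (shape + dtype).
input_types = [(v.aval.shape, str(v.aval.dtype)) for v in jaxpr.invars]
output_types = [(v.aval.shape, str(v.aval.dtype)) for v in jaxpr.outvars]

# The body is a sequence of equations, each applying one primitive.
primitive_names = [eqn.primitive.name for eqn in jaxpr.eqns]

print(input_types)
print(output_types)
print(primitive_names)
```

Only shapes and dtypes appear in the types; the concrete values passed to `make_jaxpr` are not recorded.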


In [4]:
def relu(x):
    return jnp.maximum(0.0, x)

In [5]:
print(jax.make_jaxpr(relu)(5.0))

{ lambda ; a:f32[]. let b:f32[] = max 0.0 a in (b,) }


Side effects are not captured in a jaxpr. A jaxpr is built by tracing, so the behavior of any transformed function depends only on the operations performed on the traced values. You may notice a side effect on the first run, when the Python body is executed for tracing, but not necessarily on subsequent calls. This is why the jaxpr below contains no record of the append to the global list.

In [7]:
input_list = []

def sigmoid(x):
  global input_list
  
  input_list.append(x)

  res = 1 / (1 + jnp.exp(-x))
  
  return res

print(jax.make_jaxpr(sigmoid)(5.0))

{ lambda ; a:f32[]. let
    b:f32[] = neg a
    c:f32[] = exp b
    d:f32[] = add c 1.0
    e:f32[] = div 1.0 d
  in (e,) }


In [8]:
def sigmoid_with_print(x):
  print('Executing sigmoid activation on :', x)
  res = sigmoid(x)
  return res

print(jax.make_jaxpr(sigmoid_with_print)(5.0))

Executing sigmoid activation on : Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
{ lambda ; a:f32[]. let
    b:f32[] = neg a
    c:f32[] = exp b
    d:f32[] = add c 1.0
    e:f32[] = div 1.0 d
  in (e,) }
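To see what the side effect actually recorded, one can look at the list after tracing. This is a small sketch with a hypothetical function `sigmoid_logging` and list `seen_inputs`; it assumes `jax.core.Tracer` is available as the base class of tracer objects.

```python
import jax
import jax.numpy as jnp

seen_inputs = []  # hypothetical global used only for this illustration


def sigmoid_logging(x):
    seen_inputs.append(x)  # Python side effect, runs during tracing
    return 1 / (1 + jnp.exp(-x))


jaxpr_text = str(jax.make_jaxpr(sigmoid_logging)(5.0))

# The list received the abstract tracer, not the concrete value 5.0.
first_is_tracer = isinstance(seen_inputs[0], jax.core.Tracer)

print(len(seen_inputs), first_is_tracer)
print(jaxpr_text)
```

The append happened once, at trace time, and it stored a tracer; nothing about the list shows up in the jaxpr itself.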


# JIT compiling a function


In [12]:
def sigmoid_with_print(x):

  print('Executing sigmoid activation on :', x)

  res = 1 / (1 + jnp.exp(-x))

  return res

In [13]:
sigmoid_jit = jax.jit(sigmoid_with_print)

The side effects of the jitted function are visible only in the first iteration.

Also observe that the first execution of the jitted function is much slower than the later ones. The first call takes additional time because JAX traces the function, converts it into a jaxpr, and compiles it with XLA. The remaining calls invoke the already compiled code.

In [14]:
for i, num in enumerate([0, 2, 4, 8]):
    print('Iteration: ', i + 1)
    %time print('Output: ', sigmoid_jit(num).block_until_ready())
    print('*'*30)

Iteration:  1
Executing sigmoid activation on : Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
Output:  0.5
CPU times: user 64 ms, sys: 482 µs, total: 64.5 ms
Wall time: 64.7 ms
******************************
Iteration:  2
Output:  0.8807971
CPU times: user 3.66 ms, sys: 1.18 ms, total: 4.84 ms
Wall time: 3.91 ms
******************************
Iteration:  3
Output:  0.98201376
CPU times: user 3.42 ms, sys: 1.3 ms, total: 4.72 ms
Wall time: 3.77 ms
******************************
Iteration:  4
Output:  0.99966466
CPU times: user 2.61 ms, sys: 2.15 ms, total: 4.76 ms
Wall time: 3.74 ms
******************************
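One way to confirm that the Python body runs only while tracing is to record each trace in a list. The sketch below uses hypothetical names (`trace_log`, `sigmoid_counted`) and assumes that a change of input shape triggers a fresh trace, which is the behavior seen later in this notebook when the function is called on an array.

```python
import jax
import jax.numpy as jnp

trace_log = []  # one entry is appended every time the Python body is traced


@jax.jit
def sigmoid_counted(x):
    trace_log.append(x.shape)  # side effect: only happens during tracing
    return 1 / (1 + jnp.exp(-x))


# Four scalar calls with the same abstract type share a single trace.
for value in [0.0, 2.0, 4.0, 8.0]:
    sigmoid_counted(value).block_until_ready()
traces_after_scalars = len(trace_log)

# An input with a different shape needs a new trace and compilation.
sigmoid_counted(jnp.arange(4.0)).block_until_ready()
traces_after_array = len(trace_log)

print(traces_after_scalars, traces_after_array, trace_log)
```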


Without jit, the code below sends one operation at a time to the accelerator. This limits the ability of the XLA compiler to optimize our function.

In [15]:
x = jnp.arange(10000000)

%timeit sigmoid_with_print(x).block_until_ready()

Executing sigmoid activation on : [      0       1       2 ... 9999997 9999998 9999999]
Executing sigmoid activation on : [      0       1       2 ... 9999997 9999998 9999999]
Executing sigmoid activation on : [      0       1       2 ... 9999997 9999998 9999999]
Executing sigmoid activation on : [      0       1       2 ... 9999997 9999998 9999999]
Executing sigmoid activation on : [      0       1       2 ... 9999997 9999998 9999999]
Executing sigmoid activation on : [      0       1       2 ... 9999997 9999998 9999999]
The slowest run took 136.55 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 5: 2.68 ms per loop


The first call of sigmoid_jit on x is where JAX does its tracing, since it needs some inputs to wrap in tracers. The jaxpr is then compiled using XLA into very efficient code optimized for your CPU, GPU, or TPU. Subsequent calls to sigmoid_jit with inputs of the same shape and dtype use that compiled code, skipping our old Python implementation entirely.

(There is no separate warm-up call here, so the compilation time is included in the first timed run, which is why %timeit reports a slow outlier. The best-of-5 figure is still meaningful because the benchmark runs many loops, but for a fair comparison you would call sigmoid_jit(x) once before timing.)

Next we time the execution speed of the compiled version. (Note the use of block_until_ready(), which is required because of JAX’s asynchronous execution model.)

In [16]:
%timeit sigmoid_jit(x).block_until_ready()

Executing sigmoid activation on : Traced<ShapedArray(int32[10000000])>with<DynamicJaxprTrace(level=0/1)>
The slowest run took 235.32 times longer than the fastest. This could mean that an intermediate result is being cached.
1000 loops, best of 5: 509 µs per loop


In [17]:
x = jnp.arange(10000000)

for i in range(5):
    print('Iteration :' , i+1)

    %timeit sigmoid_jit(x).block_until_ready()

Iteration : 1
1000 loops, best of 5: 507 µs per loop
Iteration : 2
The slowest run took 4.25 times longer than the fastest. This could mean that an intermediate result is being cached.
1000 loops, best of 5: 498 µs per loop
Iteration : 3
The slowest run took 5.00 times longer than the fastest. This could mean that an intermediate result is being cached.
1000 loops, best of 5: 499 µs per loop
Iteration : 4
1000 loops, best of 5: 495 µs per loop
Iteration : 5
1000 loops, best of 5: 492 µs per loop
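Outside a notebook, where `%timeit` is not available, a comparable measurement can be sketched with `time.perf_counter`. This is only an illustration: the array size, the number of runs, and the variable names are arbitrary choices, and the timings will vary by machine. The explicit warm-up call keeps tracing and compilation out of the timed loop.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def sigmoid(x):
    return 1 / (1 + jnp.exp(-x))


x = jnp.arange(1_000_000, dtype=jnp.float32)

# Warm-up call: tracing and XLA compilation happen here, not in the timed loop.
sigmoid(x).block_until_ready()

n_runs = 100
start = time.perf_counter()
for _ in range(n_runs):
    # block_until_ready() waits for the asynchronous computation to finish,
    # so the measured time covers the actual work, not just the dispatch.
    out = sigmoid(x).block_until_ready()
elapsed = time.perf_counter() - start

per_call_seconds = elapsed / n_runs
print(f"{per_call_seconds * 1e6:.1f} µs per call")
```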


# JIT and Python Control Flow

**Important**: because JIT compilation is done without information on the content of the array, control flow statements in the function cannot depend on traced values.

In [28]:
def f(x):
  if x > 0:
    return 3 * x**3 + 2 * x**2 + 5 * x
  else:
    return 2 * x

jitted_fn = jax.jit(f)

jitted_fn(10) 

ConcretizationTypeError: ignored


The value of `x` isn't concrete while tracing. As a result, when we hit a line like `if x > 0`, the expression `x > 0` evaluates to an abstract `ShapedArray((), jnp.bool_)` that represents the set {True, False}. **When Python attempts to coerce that to a concrete True or False, we get an error: we don’t know which branch to take, and can’t continue tracing!**

If the computation inside a branch is expensive, you can still jit that part of the function body. Let's see it in action.
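The failure described above can be reproduced and caught in plain Python. The sketch below assumes that the raised exception is `jax.errors.ConcretizationTypeError` or a subclass of it; the exact class name reported may differ between JAX versions.

```python
import jax


def f(x):
    if x > 0:  # needs a concrete True/False, which a tracer cannot provide
        return 3 * x**3 + 2 * x**2 + 5 * x
    else:
        return 2 * x


eager_result = f(10)  # works without jit: x is a concrete Python int

try:
    jax.jit(f)(10)
    error_name = None
except jax.errors.ConcretizationTypeError as err:
    # Raised while tracing, at the point where `x > 0` is coerced to a bool.
    error_name = type(err).__name__

print(eager_result, error_name)
```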

Jitting the expensive computational part

In [26]:
@jit
def exp_fun(x):
  return 3 * x**3 + 2 * x**2 + 5 * x


In [27]:
def f_inner_jitted(x):
  if x > 0:
    return exp_fun(x)
  else:
    return 2*x

print(f_inner_jitted(10))

3250
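A different approach, not used in this notebook, is to express the branch with `jnp.where` so that the whole function can be jitted. The sketch below is an assumption about what would fit this example, with a hypothetical function name `f_where`. Note that both branches are computed for every input, which is acceptable here because they are cheap.

```python
import jax
import jax.numpy as jnp


@jax.jit
def f_where(x):
    # Both expressions are evaluated; jnp.where selects between them
    # element-wise, so no Python-level branch on a traced value is needed.
    return jnp.where(x > 0, 3 * x**3 + 2 * x**2 + 5 * x, 2 * x)


positive_result = int(f_where(10))
negative_result = int(f_where(-4))
print(positive_result, negative_result)
```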


If we really need to JIT a function that has a condition on the value of an input, we can tell JAX to use a less abstract tracer for a particular input by specifying static_argnums or static_argnames. The cost is that the resulting jaxpr is less flexible: JAX has to re-compile the function for every new value of the static input. This is only a good strategy if the function is guaranteed to see a limited number of distinct values for that input. In the examples below, the branch depends on n, so marking x as static (static_argnums = 0) does not help; n is the argument that must be static.

In [30]:
def f(x, n):
  print(f"x = {x}, n = {n}")

  if (n >= 0):
    return (n) * x**3 + (n - 1) *x**2 +  (n - 2)*x
  else:
    return 2 * x

jitted_fn = jax.jit(f)

jitted_fn(10, 5) 

x = Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, n = Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>


ConcretizationTypeError: ignored

In [31]:
f_jit_wrong = jax.jit(f, static_argnums = 0)

print(f_jit_wrong(10, 5))

x = 10, n = Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>


ConcretizationTypeError: ignored

In [32]:
f_jit_correct = jax.jit(f, static_argnames = ['n'])

print(f_jit_correct(10, 5))

x = Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, n = 5
5430


In [33]:
print(f_jit_correct(100, 5))

5040300


In [34]:
print(f_jit_correct(10, 3))

x = Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, n = 3
3210


In [35]:
print(f_jit_correct(100, 3))

3020100
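The outputs above suggest that the function body is re-run only when a new value of `n` appears. The following sketch makes that explicit by recording the static value on every trace; `poly` and `traced_with` are hypothetical names introduced for this illustration.

```python
from functools import partial

import jax

traced_with = []  # records the static value of n each time the body is traced


@partial(jax.jit, static_argnames=["n"])
def poly(x, n):
    traced_with.append(n)  # n is a plain Python int here because it is static
    if n >= 0:
        return n * x**3 + (n - 1) * x**2 + (n - 2) * x
    else:
        return 2 * x


poly(10, n=5)
poly(100, n=5)  # same static value and same input type: cached code is reused
poly(10, n=3)   # new static value: the function is traced and compiled again
poly(100, n=3)

result = int(poly(10, n=5))
print(traced_with, result)
```

Each distinct value of `n` costs one compilation, which is why static arguments suit inputs that take only a few different values.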


To specify such arguments when using jit as a decorator, a common pattern is to use Python’s functools.partial:

In [37]:
from functools import partial

@partial(jax.jit, static_argnames = ['n'])
def f_jit_decorated(x, n):
  print(f"x = {x}, n = {n}")

  if (n >= 0):
    return (n) * x**3 + (n - 1) *x**2 +  (n - 2)*x
  else:
    return 2 * x

print(f_jit_decorated(10, 5))

x = Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, n = 5
5430


In many of the examples above, jitting is not worth it. This is because jax.jit introduces some overhead itself. Therefore, it usually only saves time if the compiled function is complex and you will run it numerous times. This is common in machine learning, where we tend to compile a large, complicated model, then run it for millions of iterations.

Generally, you want to jit the largest possible chunk of your computation; ideally, the entire update step. This gives the compiler maximum freedom to optimise.

In [39]:
def f(x, n):
  if (n >= 0):
    return (n) * x**3 + (n - 1) *x**2 +  (n - 2)*x
  else:
    return 2 * x

f_jit_correct = jax.jit(f, static_argnames = ['n'])

print('f jitted:')
%timeit f_jit_correct(10, 5).block_until_ready()

print('f:')
%timeit f(10, 5)

f jitted:
The slowest run took 190.36 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 5: 135 µs per loop
f:
The slowest run took 4.63 times longer than the fastest. This could mean that an intermediate result is being cached.
1000000 loops, best of 5: 627 ns per loop
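As an illustration of jitting the entire update step rather than small pieces, here is a minimal sketch of gradient descent on a toy linear model. The model, the data, the learning rate, and all names (`loss`, `update`, `LEARNING_RATE`) are assumptions made for this example and are not part of the notebook.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # arbitrary choice for this toy problem


def loss(params, x, y):
    w, b = params
    return jnp.mean((w * x + b - y) ** 2)


@jax.jit
def update(params, x, y):
    # Gradient computation and parameter update are compiled together,
    # so XLA sees the whole step at once.
    grads = jax.grad(loss)(params, x, y)
    return jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )


# Synthetic data generated from y = 3x + 1.
x = jnp.linspace(-1.0, 1.0, 64)
y = 3.0 * x + 1.0

params = (jnp.array(0.0), jnp.array(0.0))
for _ in range(200):
    params = update(params, x, y)

w, b = params
print(float(w), float(b))
```

The loop calls the same compiled `update` 200 times, which is the situation where the up-front compilation cost pays off.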


# Caching
Suppose I define f = jax.jit(g). When I first invoke f, it gets compiled, and the resulting XLA code is cached. Subsequent calls of f reuse the cached code. This is how jax.jit makes up for the up-front cost of compilation.

Avoid calling jax.jit inside loops. In most cases, JAX is able to reuse the compiled, cached function in subsequent calls to jax.jit. However, because the cache relies on the hash of the function, it becomes problematic when equivalent functions are redefined, as happens when a new lambda is created on every call. This causes unnecessary compilation each time in the loop:

In [41]:
def unjitted_exp_fun(x):
  return 3 * x**3 + 2 * x**2 + 5 * x

In [42]:
def f_inner_jitted_lambda(x):
  if x > 0:
    return jax.jit(lambda x: unjitted_exp_fun(x))(x)
  else:
    return 2 * x

In [44]:
def f_inner_jitted_normal(x):
  if x > 0:
    return jax.jit(unjitted_exp_fun)(x)
  else:
    return 2 * x

jit called in a loop with lambdas:

In [45]:
%timeit f_inner_jitted_lambda(10).block_until_ready()

100 loops, best of 5: 16.3 ms per loop


jit called in a loop with caching:

In [46]:
%timeit f_inner_jitted_normal(10).block_until_ready()

The slowest run took 49.80 times longer than the fastest. This could mean that an intermediate result is being cached.
1000 loops, best of 5: 501 µs per loop
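The timing gap above is consistent with the lambda version being re-traced on every call. A sketch that counts traces directly is shown below; the counter dictionary and function names are hypothetical, and the exact caching behavior is an assumption that may depend on the JAX version.

```python
import jax

trace_counts = {"lambda": 0, "named": 0}


def poly_for_lambda(x):
    trace_counts["lambda"] += 1  # incremented only when the body is traced
    return 3 * x**3 + 2 * x**2 + 5 * x


def poly_named(x):
    trace_counts["named"] += 1
    return 3 * x**3 + 2 * x**2 + 5 * x


for _ in range(3):
    # A fresh lambda is a new function object on every iteration,
    # so the cache keyed on the function cannot be reused.
    jax.jit(lambda x: poly_for_lambda(x))(10).block_until_ready()

    # The same function object is passed each time, so the cached
    # compilation can be found again.
    jax.jit(poly_named)(10).block_until_ready()

print(trace_counts)
```

The simplest way to avoid the problem is to call jax.jit once, outside the loop, and reuse the returned function.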
