# Performance tips

By design, most functions in this library work on arrays, or even batches of arrays;
see [the `*batch` axes section](../batch_axes.md#the-batch-axes).

However, the [runtime type checking](type_checking.ipynb#runtime-type-checking)
of those functions, coupled with the use of Python logic, introduces some overhead
that can degrade performance, especially with nested function calls.

To reduce this overhead, we encourage using JAX's
[just-in-time compilation](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html)
(JIT). Please read the linked content if you are not familiar with this concept.

Almost all functions we provide can be wrapped with {func}`jax.jit` in order
to compile them to efficient code. The type checkers we use are aware of that
and will only check functions at compilation time.

Once compiled, no more type checking is performed, reducing the overhead to the
bare minimum.
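To illustrate why the overhead vanishes after compilation, here is a minimal sketch. The `check_vectors` helper and the `n_checks` counter are hypothetical stand-ins for a runtime type checker (they are not part of this library): like any Python code inside a jitted function, the check only runs while JAX traces the function, so calling the compiled function again with arrays of the same shape and dtype skips it entirely.

```python
import jax
import jax.numpy as jnp

n_checks = 0  # counts how many times the (hypothetical) check actually runs


def check_vectors(x):
    # Hypothetical stand-in for a runtime type checker: plain Python code
    global n_checks
    n_checks += 1
    if x.shape[-1] != 3:
        raise TypeError(f"expected shape (..., 3), got {x.shape}")


@jax.jit
def normalize(x):
    check_vectors(x)  # runs while tracing, on abstract shapes only
    return x / jnp.linalg.norm(x, axis=-1, keepdims=True)


x = jnp.ones((4, 3))
normalize(x)  # first call: traces (runs the check) and compiles
normalize(2.0 * x)  # same shape and dtype: compiled code is reused, no check
print(n_checks)  # 1
```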

## JIT-ing an existing function

Here, we will look at the
{func}`rays_intersect_triangles<differt.rt.utils.rays_intersect_triangles>`
function and how much it can benefit from JIT compilation.

In [None]:
import jax

from differt.rt.utils import rays_intersect_triangles

In [None]:
key = jax.random.PRNGKey(1234)
key1, key2, key3 = jax.random.split(key, 3)

batch = (10, 100)

# Random rays and triangles, each with batch dimensions (10, 100)
ray_origins = jax.random.uniform(key1, (*batch, 3))
ray_directions = jax.random.uniform(key2, (*batch, 3))
triangle_vertices = jax.random.uniform(key3, (*batch, 3, 3))

Let's look at the execution time without compilation.
The `[0].block_until_ready()` part is needed because:

1. the function returns a tuple, so we first select one of the output arrays
   (e.g., the first one with `[0]`);
2. we then call `.block_until_ready()` on it, so that the timing waits until JAX has
   actually finished the computation. JAX dispatches computations asynchronously,
   so a function call can return before its result is ready.

If the call to `.block_until_ready()` is omitted, the measured time *may* only reflect
the dispatch of the computation, not the computation itself.
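As an alternative to selecting a single output with `[0]`, {func}`jax.block_until_ready` accepts any pytree (such as a tuple of arrays) and waits for all of its leaves. The sketch below wraps that in a small timing helper; `time_call`, `n_repeats` and the toy `sin_cos` function are hypothetical names used for illustration only, and `%timeit` remains the more robust option in a notebook.

```python
import time

import jax
import jax.numpy as jnp


def time_call(fn, *args, n_repeats=10):
    """Return the average wall-clock time (in seconds) of ``fn(*args)``."""
    # Warm-up call: triggers compilation if ``fn`` is jitted
    jax.block_until_ready(fn(*args))
    start = time.perf_counter()
    for _ in range(n_repeats):
        # Wait for *all* outputs (the whole pytree) before stopping the clock
        jax.block_until_ready(fn(*args))
    return (time.perf_counter() - start) / n_repeats


def sin_cos(x):
    # Returns a tuple of arrays, like rays_intersect_triangles does
    return jnp.sin(x), jnp.cos(x)


x = jnp.linspace(0.0, 1.0, 10_000)
print(f"eager: {time_call(sin_cos, x):.2e} s")
print(f"jit:   {time_call(jax.jit(sin_cos), x):.2e} s")
```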

In [None]:
%timeit rays_intersect_triangles(ray_origins, ray_directions, triangle_vertices)[0].block_until_ready()

Then, let's compare it with its JIT-compiled version.

Note that we call the function once before timing it, so that
the compilation overhead is not included in the measurement.

In [None]:
rays_intersect_triangles_jit = jax.jit(rays_intersect_triangles)

# Warm-up call: the first call traces and compiles the function
rays_intersect_triangles_jit(
    ray_origins, ray_directions, triangle_vertices
)[0].block_until_ready()

%timeit rays_intersect_triangles_jit(ray_origins, ray_directions, triangle_vertices)[0].block_until_ready()

See! Here, we reduced the execution time by **more than one order of magnitude**, which is quite
nice given that we only had to wrap the function with {func}`jax.jit`, nothing more.
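The warm-up call above hides the compilation time inside the first call. If you want to measure compilation on its own, JAX's ahead-of-time API lets you lower and compile a jitted function explicitly for given argument shapes and dtypes. A minimal sketch on a toy function `f` (not from this library):

```python
import time

import jax
import jax.numpy as jnp


def f(x):
    # Toy function standing in for a more expensive computation
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2


x = jnp.linspace(0.0, 1.0, 1_000)

start = time.perf_counter()
# Trace, lower and compile for the shape and dtype of ``x`` only
compiled = jax.jit(f).lower(x).compile()
compile_time = time.perf_counter() - start

out = compiled(x).block_until_ready()  # no compilation happens here
print(f"compilation took {compile_time:.2e} s")
```

Note that `compiled` only accepts arguments with the same shapes and dtypes as the ones used when lowering.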

In general, the performance gain will highly depend on the function that is compiled.
We advise first trying **without** any JIT compilation, and then gradually adding `@jax.jit`
decorators to the functions you feel could benefit from it.
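When adding decorators, it is often enough to put a single `@jax.jit` on the outermost function: JAX traces through every function it calls, so helpers called from jitted code are compiled as part of it. A minimal sketch, where `squared_norm` and `unit_vectors` are hypothetical helpers (not part of this library):

```python
import jax
import jax.numpy as jnp


def squared_norm(x):
    # Plain helper: no decorator needed when only called from jitted code
    return jnp.sum(x * x, axis=-1)


@jax.jit
def unit_vectors(x):
    # ``squared_norm`` is traced and compiled as part of ``unit_vectors``
    return x / jnp.sqrt(squared_norm(x))[..., None]


print(unit_vectors(jnp.array([[3.0, 4.0]])))
```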

## Why not JIT all functions?

JIT compilation comes at the cost of compiling the function
during its first execution, which can slow down debugging.
Also, if some arguments are static, the function will need to be
recompiled every time the static arguments change. The same happens
when the shapes or dtypes of the input arrays change.
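The sketch below makes recompilation visible with a hypothetical `n_traces` counter, which only increases when JAX (re)traces the function. `total` is a toy function, and `axis` is marked static via `static_argnames` because it must be a concrete Python value:

```python
import functools

import jax
import jax.numpy as jnp

n_traces = 0  # only incremented when JAX (re)traces ``total``


@functools.partial(jax.jit, static_argnames="axis")
def total(x, axis):
    global n_traces
    n_traces += 1  # Python side effect: runs at trace time only
    return jnp.sum(x, axis=axis)


x = jnp.ones((2, 3))
total(x, axis=0)  # first call: trace and compile
total(x, axis=0)  # cached: same static value, same shape and dtype
total(x, axis=1)  # new static value: retrace and recompile
total(jnp.ones((4, 3)), axis=1)  # new input shape: retrace and recompile
print(n_traces)  # 3
```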

Moreover, under JIT compilation, Python `print` statements are only executed while the
function is traced, not on every call; JIT compilation does not allow for impure functions
(e.g., using globals), and it might not always produce faster code.
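Both pitfalls can be seen on a small toy example. Python's `print` only runs during tracing (and shows an abstract tracer instead of values), whereas {func}`jax.debug.print` runs on every call. Likewise, a global variable is read once, at trace time, so later changes are ignored as long as the cached compiled function is reused. `scaled` and `scale` are hypothetical names used for illustration:

```python
import jax
import jax.numpy as jnp

scale = 2.0  # global read by ``scaled`` (an impure dependency)


@jax.jit
def scaled(x):
    print("tracing with", x)  # only printed during tracing
    jax.debug.print("x = {}", x)  # printed on every call
    return scale * x


before = scaled(jnp.float32(1.0))
scale = 3.0  # ignored: the cached compiled function still uses 2.0
after = scaled(jnp.float32(1.0))
print(before, after)
```

To make such a value take effect, pass it as an explicit argument instead of reading a global.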

For all those reasons, it is the responsibility of the end user to
determine when to use JIT compilation in their code.