# JIT Compilation in UnifyML

[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/holl-/UnifyML/blob/main/docs/JIT.ipynb)
&nbsp; • &nbsp; [🌐 **UnifyML**](https://github.com/holl-/UnifyML)
&nbsp; • &nbsp; [📖 **Documentation**](https://holl-.github.io/UnifyML/)
&nbsp; • &nbsp; [🔗 **API**](https://holl-.github.io/UnifyML/unifyml)
&nbsp; • &nbsp; [**▶ Videos**]()
&nbsp; • &nbsp; [<img src="images/colab_logo_small.png" height=4>](https://colab.research.google.com/github/holl-/UnifyML/blob/main/docs/Examples.ipynb) [**Examples**](https://holl-.github.io/UnifyML/Examples.html)

Just-in-time (JIT) compilation can drastically speed up your code because Python overhead is eliminated and the backend can optimize the compiled computation as a whole.

```python
%%capture
!pip install unifyml

from unifyml import math

math.use('jax')
```

In UnifyML, you can JIT-compile a function using the [`math.jit_compile()`](unifyml/math#unifyml.math.jit_compile) decorator.

```python
@math.jit_compile
def fun(x):
    print(f"Tracing fun with x = {x}")
    return 2 * x
```

The first time the function is called with new arguments, it is traced, i.e. all tensor operations are recorded.
During tracing, the passed arguments have concrete shapes and data types but no concrete values.
Consequently, traced tensors cannot be used in Python control flow, such as `if` conditions or loop conditions.
Replace value-dependent `if` statements with [`math.where()`](unifyml/math#unifyml.math.where).

Depending on the backend, the function may be called multiple times during tracing.
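The restriction on control flow can be illustrated with a toy tracer in plain Python. This is a simplified model under our own assumptions, not how UnifyML or JAX are implemented: `Tracer` and `where` are hypothetical names. The tracer knows its shape and dtype, records every operation into a graph, and raises an error when converted to `bool`, which is exactly what an `if` statement requires. A select operation like `where` sidesteps this by recording the condition and both candidate values.

```python
class Tracer:
    """Toy placeholder for a tensor during tracing: shape and dtype known, value unknown."""

    def __init__(self, name, shape, dtype, graph):
        self.name, self.shape, self.dtype, self.graph = name, shape, dtype, graph

    def _record(self, op, *operands, dtype=None):
        # Append the operation to the graph instead of computing it; return a new tracer.
        out = Tracer(f"t{len(self.graph)}", self.shape, dtype or self.dtype, self.graph)
        self.graph.append((out.name, op, operands))
        return out

    def __mul__(self, other):
        return self._record("mul", self, other)

    __rmul__ = __mul__  # multiplication is commutative, so operand order does not matter here

    def __gt__(self, other):
        return self._record("gt", self, other, dtype="bool")

    def __bool__(self):
        # `if` and `while` need a concrete True/False, which a tracer cannot provide.
        raise TypeError(f"Cannot use traced value '{self.name}' in Python control flow")

    def __repr__(self):
        return self.name


def where(cond, if_true, if_false):
    """Hypothetical traceable select: records the choice instead of branching in Python."""
    return cond._record("where", cond, if_true, if_false, dtype=if_true.dtype)


graph = []
x = Tracer("x", shape=(), dtype="float32", graph=graph)
y = 2 * x          # recorded as ('t0', 'mul', (x, 2))
cond = x > 0       # recorded as ('t1', 'gt', (x, 0)); the result is unknown during tracing
try:
    if cond:       # calls bool(cond), which raises TypeError
        pass
except TypeError as err:
    message = str(err)
z = where(cond, y, x)  # records ('t2', 'where', (t1, t0, x)) instead of branching
```

Because the select is itself an operation, it ends up in the recorded graph and can be replayed on any input, which a Python `if` cannot.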

```python
fun(math.tensor(1.))
```

```
Tracing fun with x = () float32 jax tracer
2.0
```

Whenever the function is called with arguments of the same shapes and data types as a previous call, the compiled version of the function is evaluated without running the Python code.
Instead, the previously recorded tensor operations are replayed on the new input.

```python
fun(math.tensor(1.))
```

```
2.0
```

Note that the `print` statement was not executed since the Python body of `fun` was not run.
If we call the function with different shapes or dtypes, it is traced again.

```python
fun(math.tensor([1, 2]))
```

```
Tracing fun with x = (vectorᶜ=2) int64 jax tracer
(2, 4) int64
```

## NumPy Operations

All [NumPy operations are performed at JIT-compile time](NumPy_Constants.html) and will not be executed once the function is compiled, similar to the `print` statement.
NumPy-backed tensors always have concrete values and can be used in `if` statements as well as loop conditions.

```python
@math.jit_compile
def fun(x):
    print(f"Tracing fun with x = {x}")
    y = math.wrap(2)
    z = math.sin(y ** 2)
    print(f"z = {z}")
    if z > 1:
        return z * x
    else:
        return z / x

fun(math.tensor(1.))
```

```
Tracing fun with x = () float32 jax tracer
z = float64 -0.7568025
-0.7568025
```

Here, the control flow can depend on `z` because it is backed by NumPy and therefore has a concrete value during tracing.
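The effect of evaluating concrete values at trace time can be sketched with another plain-Python toy, in which Python floats from the standard `math` module stand in for NumPy-backed tensors (an assumption for illustration only). `Tracer` is a hypothetical recorder, not UnifyML API. After tracing, the recorded program holds a single division by the precomputed constant `sin(4)`; neither the `sin` call nor the `if` appears in it.

```python
from math import sin


class Tracer:
    """Toy stand-in for a traced tensor: arithmetic on it is recorded, not computed."""

    def __init__(self, name, program):
        self.name, self.program = name, program

    def _record(self, op, const):
        out = Tracer(f"t{len(self.program)}", self.program)
        self.program.append((out.name, op, self.name, const))
        return out

    def __mul__(self, other):
        return self._record("mul", other)

    __rmul__ = __mul__

    def __rtruediv__(self, other):
        # Handles `const / tracer`, recorded as "rdiv".
        return self._record("rdiv", other)


def fun(x):
    # Concrete values (stand-ins for NumPy-backed tensors) are computed right away, at trace time.
    y = 2
    z = sin(y ** 2)
    # Control flow on a concrete value is allowed; only the branch taken is recorded.
    if z > 1:
        return z * x
    else:
        return z / x


program = []
result = fun(Tracer("x", program))
print(program)  # one recorded op: ('t0', 'rdiv', 'x', <sin(4) as a constant>)
```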

## Auxiliary Arguments

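If we want the control flow to depend on a function argument, we must declare it as an auxiliary argument.
Auxiliary arguments are passed to the function as-is instead of being traced, so their values are available during tracing.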

```python
@math.jit_compile(auxiliary_args='y')
def fun(x, y):
    print(f"Tracing fun with x = {x}, y = {y}")
    z = math.sin(y ** 2)
    print(f"z = {z}")
    if (z > 1).all:
        return z * x
    else:
        return z / x

fun(math.tensor(1.), math.wrap(2))
```

```
Tracing fun with x = () float32 jax tracer, y = 2
z = float64 -0.7568025
-0.7568025
```

The function always needs to be re-traced if an auxiliary argument changes in any way.

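You can check whether an existing trace can be reused for given arguments, i.e. whether a call would avoid re-tracing, using [`math.trace_check()`](unifyml/math#unifyml.math.trace_check).
It returns a tuple whose first entry is `True` if the compiled version can be reused; otherwise it is `False` and the second entry explains why re-tracing is needed.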

```python
math.trace_check(fun, math.tensor(1.), math.wrap(2))
```

```
(True, '')
```

```python
math.trace_check(fun, math.tensor(1.), math.wrap(-1))
```

```
(False, 'Auxiliary arguments do not match')
```
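Why auxiliary arguments force re-tracing, and what the tuples returned by `math.trace_check()` express, can be modeled with a toy cache key in plain Python. This is an assumption-laden sketch, not UnifyML's implementation: `TraceCache`, `signature`, the comma-separated parsing of `auxiliary_args`, and the message `'No trace for these shapes/dtypes'` are hypothetical, and plain Python values stand in for tensors. Traced arguments contribute only their shape and dtype to the key, while auxiliary arguments contribute their value, so any change to them misses the cache.

```python
import inspect


def signature(value):
    """Abstract signature of a traced argument: (shape, dtype) only, never the value."""
    if isinstance(value, list):
        return (len(value),), type(value[0]).__name__
    return (), type(value).__name__


class TraceCache:
    """Toy model of how a JIT cache could key traces when some arguments are auxiliary."""

    def __init__(self, fn, auxiliary_args=()):
        if isinstance(auxiliary_args, str):  # toy convention: comma-separated names
            auxiliary_args = [name.strip() for name in auxiliary_args.split(",")]
        self.params = list(inspect.signature(fn).parameters)
        self.aux = set(auxiliary_args)
        self.traces = {}  # key -> placeholder for the recorded program

    def key(self, *args):
        named = dict(zip(self.params, args))
        traced = tuple(signature(v) for n, v in named.items() if n not in self.aux)
        # Auxiliary arguments enter the key by value, so any change forces a re-trace.
        aux = tuple((n, v) for n, v in named.items() if n in self.aux)
        return traced, aux

    def check(self, *args):
        """Return (True, '') if an existing trace can be reused, else (False, reason)."""
        traced, aux = self.key(*args)
        if (traced, aux) in self.traces:
            return True, ''
        if any(t == traced for t, _ in self.traces):
            return False, 'Auxiliary arguments do not match'
        return False, 'No trace for these shapes/dtypes'

    def record(self, *args):
        self.traces[self.key(*args)] = "<recorded program>"


def fun(x, y):
    return x * y


cache = TraceCache(fun, auxiliary_args='y')
cache.record(1.0, 2)               # pretend fun(1.0, 2) has been traced
print(cache.check(5.0, 2))         # (True, '')  new x value, same shape and dtype
print(cache.check(1.0, -1))        # (False, 'Auxiliary arguments do not match')
print(cache.check([1.0, 2.0], 2))  # (False, 'No trace for these shapes/dtypes')
```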

## Further Reading

[🌐 **UnifyML**](https://github.com/holl-/UnifyML)
&nbsp; • &nbsp; [📖 **Documentation**](https://holl-.github.io/UnifyML/unifyml/)
&nbsp; • &nbsp; [🔗 **API**](https://holl-.github.io/UnifyML/unifyml)
&nbsp; • &nbsp; [**▶ Videos**]()
&nbsp; • &nbsp; [<img src="images/colab_logo_small.png" height=4>](https://colab.research.google.com/github/holl-/UnifyML/blob/main/docs/Examples.ipynb) [**Examples**](https://holl-.github.io/UnifyML/Examples.html)