# JIT-compilation with Numba and JAX

Let's consider the following quadratic formula:
```python
def quadratic_formula(a, b, c):
    return (-b + np.sqrt(b**2 - 4*a*c)) / (2*a)
```



What does this computation actually do? It becomes clearer if we write out every single step:
```python
def pedantic_quadratic_formula(a, b, c):
    tmp1 = np.negative(b)            # -b
    tmp2 = np.square(b)              # b**2
    tmp3 = np.multiply(4, a)         # 4*a
    tmp4 = np.multiply(tmp3, c)      # tmp3*c
    del tmp3
    tmp5 = np.subtract(tmp2, tmp4)   # tmp2 - tmp4
    del tmp2, tmp4
    tmp6 = np.sqrt(tmp5)             # sqrt(tmp5)
    del tmp5
    tmp7 = np.add(tmp1, tmp6)        # tmp1 + tmp6
    del tmp1, tmp6
    tmp8 = np.multiply(2, a)         # 2*a
    return np.divide(tmp7, tmp8)     # tmp7 / tmp8
```



There are **9(!)** elementwise operations, each of which runs its own compiled loop over the data and allocates a new array, i.e.:
```python
tmp1 = np.negative(b)  
tmp2 = np.square(b)
...

# is equivalent to
n = len(b)
tmp1 = np.empty(n)
for i in range(n):  # (compiled loop)
    tmp1[i] = -b[i]

tmp2 = np.empty(n)
for i in range(n):  # (compiled loop)
    tmp2[i] = b[i] ** 2

...
```



It would be much more efficient to apply all elementwise operations in a single loop:
```python
n = len(b)
out = np.empty(n)
for i in range(n):  # (compiled loop)
    out[i] = (-b[i] + np.sqrt(b[i]**2 - 4*a[i]*c[i])) / (2*a[i])
```

Essentially, just-in-time (JIT) compilation lets us get rid of the _intermediate_ arrays by "fusing" all operations into a _single_ iteration over our data.
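As a quick sanity check, here is the fused loop next to the NumPy expression on a small input. This is a plain-Python sketch (the loop is interpreted here, so it is slow); the point is only that fusing changes *how* the result is computed, not *what* is computed. The name `quadratic_formula_fused` is hypothetical.

```python
import numpy as np


def quadratic_formula(a, b, c):
    # NumPy version: 9 separate elementwise passes, 8 temporary arrays
    return (-b + np.sqrt(b**2 - 4*a*c)) / (2*a)


def quadratic_formula_fused(a, b, c):
    # "Fused" version: one pass over the data, no temporary arrays.
    # Interpreted Python here; a JIT compiler would turn this loop into machine code.
    n = len(b)
    out = np.empty(n)
    for i in range(n):
        out[i] = (-b[i] + np.sqrt(b[i]**2 - 4*a[i]*c[i])) / (2*a[i])
    return out


rng = np.random.default_rng(0)
a = rng.uniform(5, 10, 1_000)
b = rng.uniform(10, 20, 1_000)
c = rng.uniform(-0.1, 0.1, 1_000)

# Same numbers, computed in a different order of loops
print(np.allclose(quadratic_formula(a, b, c), quadratic_formula_fused(a, b, c)))
```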

Fusing operations is tricky, however. There are a few ways to achieve this for array processing in Python; I'd like to highlight two of them:

- Numba: https://numba.pydata.org
- JAX: https://github.com/jax-ml/jax

In [None]:
%pip install jax numba matplotlib

In [None]:
# NumPy
import numpy as np

# JAX
import jax
import jax.numpy as jnp

jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", True)

# Numba
import numba as nb

# Matplotlib
import matplotlib.pyplot as plt

Let's consider the quadratic formula example again, and compare the runtimes for NumPy, Numba, and JAX:

In [None]:
# Setup data
a = np.random.uniform(5, 10, 5_000_000)
b = np.random.uniform(10, 20, 5_000_000)
c = np.random.uniform(-0.1, 0.1, 5_000_000)


# Setup quadratic formula
def quadratic_formula(a, b, c):
    return (-b + np.sqrt(b**2 - 4*a*c)) / (2*a)

NumPy case:

In [None]:
%%timeit -n1 -r3

quadratic_formula(a, b, c)

Numba case:

In [None]:
@nb.njit  # JIT compile!
def quadratic_formula_numba(a, b, c):
    n = a.shape[0]
    out = np.empty(n)
    for i in range(n):
        out[i] = (-b[i] + np.sqrt(b[i]**2 - 4*a[i]*c[i])) / (2*a[i])
    return out

In [None]:
%%timeit -n1

quadratic_formula_numba(a, b, c)

In [None]:
%%timeit -n10 -r3

quadratic_formula_numba(a, b, c)

JAX case:

In [None]:
# Setup data
a_jax = jnp.asarray(a)
b_jax = jnp.asarray(b)
c_jax = jnp.asarray(c)


@jax.jit  # JIT compile!
def quadratic_formula_jax(a, b, c):
    return (-b + jnp.sqrt(b**2 - 4*a*c)) / (2*a)

In [None]:
%%timeit -n1

quadratic_formula_jax(a_jax, b_jax, c_jax).block_until_ready()

In [None]:
%%timeit -n10 -r3

quadratic_formula_jax(a_jax, b_jax, c_jax).block_until_ready()

The first invocation for JAX & Numba took longer than subsequent ones. That's the compile time! Afterwards, the compiled function is cached and reused for inputs of the same type (and, for JAX, the same shape).
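A small sketch to observe this caching in JAX: a plain Python side effect (here a hypothetical global counter) only runs while JAX traces the function, so it counts how often `jax.jit` (re-)traces. Calling again with the same shape and dtype reuses the cached program; a new shape triggers a fresh trace and compile.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # counts how often JAX (re-)traces the function


@jax.jit
def square_plus_one(x):
    global trace_count
    trace_count += 1  # Python side effect: runs during tracing, not in the compiled program
    return x**2 + 1


square_plus_one(jnp.ones(3))  # first call: trace + compile
square_plus_one(jnp.ones(3))  # same shape and dtype: cached program is reused
square_plus_one(jnp.ones(4))  # new shape: JAX traces and compiles again
print(trace_count)  # 2
```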

But JAX is still much faster. Why?

One important difference is that JAX uses as many threads as it has access to, while Numba is single-threaded by default. Numba can be multithreaded using `parallel=True`:

In [None]:
@nb.njit(parallel=True)  # JIT compile with `parallel=True`!
def quadratic_formula_numba_parallel(a, b, c):
    n = a.shape[0]
    out = np.empty(n)
    for i in nb.prange(n):  # note: `range` -> `nb.prange`
        out[i] = (-b[i] + np.sqrt(b[i]**2 - 4*a[i]*c[i])) / (2*a[i])
    return out

In [None]:
%%timeit -n1

quadratic_formula_numba_parallel(a, b, c)

In [None]:
%%timeit -n10 -r3

quadratic_formula_numba_parallel(a, b, c)

Now Numba and JAX are roughly on par, with runtimes of ~2-3 ms compared to NumPy's ~23 ms (exact timings depend on your machine and its number of cores).


You might have noticed a fundamental difference between JAX and Numba in how those kernels are written: 

- Numba forces[<sup id="fn1-back">1</sup>](#fn1) you to write _imperative_ code
- JAX forces[<sup id="fn2-back">2</sup>](#fn2) you to write _array-oriented_ code


![image](https://raw.githubusercontent.com/jpivarski-talks/2023-12-18-hsf-india-tutorial-bhubaneswar/refs/heads/main/img/slow-fast-imperative-vectorized.svg)



[<sup id="fn1">1</sup>](#fn1-back) <sup>Can be written array-oriented with `nb.vectorize`.</sup> 

[<sup id="fn2">2</sup>](#fn2-back) <sup>Can be written imperatively with JAX's own loop primitives, e.g. `jax.lax.fori_loop` or `jax.lax.scan`.</sup>

### How does JIT compilation even work? (JAX)

Let's have a look at the JAX example: what does `jax.jit` do?

It works in 4 steps:
1. Stage out a `jax.jit`-decorated function into a new program using a JAX internal IR (JaxPr)
2. Lower this IR (JaxPr) into the StableHLO IR
3. Compile the StableHLO program with the XLA compiler
4. Execute the compiled program

Let's see those 4 steps in action:

In [None]:
# Step 1: Create the JaxPr (through tracing)
traced = quadratic_formula_jax.trace(a_jax, b_jax, c_jax)
print(traced.jaxpr)

This JaxPr looks a lot like the pedantic version of the quadratic formula shown at the beginning of this notebook (and in lecture part 2):

```python
def pedantic_quadratic_formula(a, b, c):
    tmp1 = np.negative(b)            # -b
    tmp2 = np.square(b)              # b**2
    tmp3 = np.multiply(4, a)         # 4*a
    tmp4 = np.multiply(tmp3, c)      # tmp3*c
    del tmp3
    tmp5 = np.subtract(tmp2, tmp4)   # tmp2 - tmp4
    del tmp2, tmp4
    tmp6 = np.sqrt(tmp5)             # sqrt(tmp5)
    del tmp5
    tmp7 = np.add(tmp1, tmp6)        # tmp1 + tmp6
    del tmp1, tmp6
    tmp8 = np.multiply(2, a)         # 2*a
    return np.divide(tmp7, tmp8)     # tmp7 / tmp8
```

But instead of executing it line by line, we lower the JaxPr to StableHLO and then compile it with XLA, which fuses those kernels!

In [None]:
# Step 2: Lower the JaxPr to StableHLO (still looks similar to our pedantic code)
lowered = quadratic_formula_jax.lower(a_jax, b_jax, c_jax)
print(lowered.as_text())

In [None]:
# Step 3: Compile the StableHLO program with XLA
compiled = lowered.compile()
# print(compiled.as_text())

In [None]:
# Step 4: Execute the compiled program
print(compiled(a_jax, b_jax, c_jax))

### Limitations of Numba

You cannot JIT-compile arbitrary Python functions with Numba. It only supports a subset of Python: code built from types and functions that are "known" to Numba (mostly scalars, NumPy arrays, and NumPy operations).

For more information, see: https://numba.readthedocs.io/en/stable/user/5minguide.html#will-numba-work-for-my-code.


Check the following:

In [None]:
@nb.njit
def sum_dict_values(d):
    out = 0.
    for v in d.values():
        out += v
    return out

sum_dict_values({"a": 1.0, "b": 2.0, "c": 3.0})  # Fails, because Numba can't infer a type for a built-in Python `dict` argument

### Limitations of JAX

JAX infers which operations to run through a "tracing step": it runs your Python function once with abstract array objects that carry no data, only metadata (shape and dtype). Arbitrary Python code can run during tracing, **but** only JAX operations end up in the compiled program, and you can't JIT-compile operations whose behavior depends on the actual data values.

For more "sharp bits", see: https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html.

Check the following:

In [None]:
# Data-dependent operations are not traceable

@jax.jit
def accumulate_if(arr):
    print(arr)
    if jnp.any(arr > 3):
        return jnp.sum(arr)
    else:
        return jnp.prod(arr)


array = jnp.array([1., 2., 3., 4., 5.])
print(accumulate_if(array))  # Fails: Python's `if` needs a concrete bool, but `jnp.any(arr > 3)` is an abstract tracer during tracing
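One way to make this traceable, sketched below, is `jax.lax.cond(pred, true_fun, false_fun, *operands)`, which stages both branches into the compiled program and selects one at runtime based on the data. (`jnp.where(pred, jnp.sum(arr), jnp.prod(arr))` would also work, but always computes both results.)

```python
import jax
import jax.numpy as jnp


@jax.jit
def accumulate_if(arr):
    # Both branches must return the same shape and dtype (here: a scalar)
    return jax.lax.cond(jnp.any(arr > 3), jnp.sum, jnp.prod, arr)


array = jnp.array([1., 2., 3., 4., 5.])
print(accumulate_if(array))                   # 15.0 (sum, since some values are > 3)
print(accumulate_if(jnp.array([1., 2., 3.])))  # 6.0 (product, since no value is > 3)
```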

Another limitation of JAX is that all array shapes must be known at trace time, so you can't JIT-compile programs whose shapes depend on the data:

In [None]:
@jax.jit
def sum_greater_than_three(arr):
    return jnp.sum(arr[arr > 3.0])


array = jnp.array([1., 2., 3., 4., 5.])
print(sum_greater_than_three(array))  # Fails, because the shape of `arr[arr > 3.0]` depends on the data and can't be inferred during tracing
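A common workaround, sketched below, is to keep the shape fixed and mask values instead of removing them, e.g. with `jnp.where`:

```python
import jax
import jax.numpy as jnp


@jax.jit
def sum_greater_than_three(arr):
    # Same shape as `arr`: values <= 3.0 are replaced by 0.0 instead of being dropped
    return jnp.sum(jnp.where(arr > 3.0, arr, 0.0))


array = jnp.array([1., 2., 3., 4., 5.])
print(sum_greater_than_three(array))  # 9.0
```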

### Impure functions are dangerous with JIT compilation! (Numba & JAX)

In [None]:
do_sum = False

@nb.njit
def accumulate(arr):
    if do_sum:
        return np.sum(arr)
    else:
        return np.prod(arr)


array = np.array([1., 2., 3., 4., 5.])
print("Accumulate with `np.prod`:", accumulate(array))

# now we switch `do_sum` on!
do_sum = True
print("Accumulate with `np.sum`:", accumulate(array), f"...Hey, this should've been {np.sum(array)} instead!")

In [None]:
do_sum = False

@jax.jit
def accumulate(arr):
    if do_sum:
        return jnp.sum(arr)
    else:
        return jnp.prod(arr)


array = jnp.array([1., 2., 3., 4., 5.])

print("Accumulate with `jnp.prod`:", accumulate(array))

# now we switch `do_sum` on!
do_sum = True
print("Accumulate with `jnp.sum`:", accumulate(array), f"...Hey, this should've been {jnp.sum(array)} instead!")

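Both versions silently keep computing the product. Numba treats global variables as compile-time constants, so `do_sum` was frozen at `False` when `accumulate` was first compiled. In the JAX case we can see why by inspecting the programs in the next cell: the traced program _never knew_ that there's a `sum` operation in the first place, _and_ the compiled function is cached based on the shapes and dtypes of its input arguments, so it is not retraced when `do_sum` changes.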

In [None]:
print("Traced program:\n", accumulate.trace(array).jaxpr)
print()
print("HLO program:\n", accumulate.lower(array).as_text())  # this is the program that gets compiled by the XLA compiler
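A minimal sketch of the usual fix: make the flag an explicit argument instead of a global. In JAX, marking it static with `static_argnames` makes its value part of the cache key, so each distinct value gets its own trace and compiled program (static arguments must be hashable). In Numba, passing the flag as a regular argument is enough, because it is then evaluated at runtime.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="do_sum")
def accumulate(arr, do_sum):
    if do_sum:  # fine: `do_sum` is a concrete Python bool during tracing
        return jnp.sum(arr)
    else:
        return jnp.prod(arr)


array = jnp.array([1., 2., 3., 4., 5.])
print(accumulate(array, do_sum=False))  # 120.0
print(accumulate(array, do_sum=True))   # 15.0
```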

On to the [project1.ipynb](project1.ipynb)!

### Auto-differentiation with JAX


Knowing the computational graph of a program (i.e. its JaxPr) makes it possible to transform the program. JAX implements different _interpreters_ for JaxPrs, one of which replaces every operation with its derivative rule and chains them together, producing a new program that computes the gradient:


In [None]:
def fun(x):
    return 2.0 + jnp.sin(x)

print("JaxPr:")
print(jax.make_jaxpr(fun)(1.0))
print()

grad_fun = jax.grad(fun)
print("JaxPr (grad):")
print(jax.make_jaxpr(grad_fun)(1.0))
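We can check that the transformed program really computes the derivative: for `fun(x) = 2 + sin(x)` it should equal `cos(x)`. Since `jax.grad` expects a function with a scalar output, the sketch below composes it with `jax.vmap` to evaluate the derivative at many points at once.

```python
import jax
import jax.numpy as jnp


def fun(x):
    return 2.0 + jnp.sin(x)


grad_fun = jax.grad(fun)

# d/dx (2 + sin x) = cos x
print(grad_fun(1.0), jnp.cos(1.0))

# jax.grad works on scalar inputs/outputs; jax.vmap maps it over an array of points
xs = jnp.linspace(0.0, jnp.pi, 5)
print(jax.vmap(grad_fun)(xs))
print(jnp.cos(xs))
```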

Gradients are powerful! Many scientific problems involve gradient-based minimization.

Let's implement a gradient-based optimization on our own:

In [None]:
rng = jax.random.key(42)

true_a, true_b = 0.2, 1.1

# function that we want to fit
@jax.jit
def function(x, a, b):
    return b*x**2 - 4*a*x - b

# generate true data with some noise
def generate_data(rng):
    x_key, noise_key = jax.random.split(rng)

    xs = jax.random.uniform(x_key, (128, 1), minval=-3, maxval=3)
    noise = jax.random.normal(noise_key, (128, 1)) * 0.15

    ys = function(x=xs + noise, a=true_a, b=true_b)
    return xs, ys


# plot data
xs, ys = generate_data(rng=rng)
plt.scatter(xs, ys)

We want to recover the true underlying `a` and `b` values from this data. The next cell implements a gradient-based optimization that fits `function` to the data:

In [None]:
from typing import NamedTuple


# Just a struct that holds the parameters of the function
class Params(NamedTuple):
    a: jax.Array
    b: jax.Array


# Initialize parameters for the function (`a` and `b`)
def init(rng) -> Params:
    a_key, b_key = jax.random.split(rng)
    a = jax.random.normal(a_key, ())
    b = jax.random.normal(b_key, ())
    return Params(a, b)


# Compute the loss function (least squares error)
def loss(params: Params, x: jax.Array, y: jax.Array) -> jax.Array:
    pred = function(x=x, a=params.a, b=params.b)
    return jnp.mean((pred - y) ** 2)


# Perform one gradient descent update step on params using the given data (full-batch gradient descent; SGD would use random mini-batches)
@jax.jit
def update(params: Params, x: jax.Array, y: jax.Array) -> Params:
    # Computes the gradients of the loss function with respect to the parameters
    grads = jax.grad(loss)(params, x, y)

    # Define a step function that updates the parameters
    def step(param, grad):
        return param - 0.005 * grad  # 0.005 := learning rate

    # Apply the step function to each parameter
    return jax.tree.map(step, params, grads)


# Run the optimization
params = init(rng)
for _ in range(500):
    params = update(params, xs, ys)


print(f"True parameters  : a={true_a:.2f}, b={true_b:.2f}")
print(f"Fitted parameters: a={params.a:.2f}, b={params.b:.2f}")

plt.scatter(xs, ys)
pred_ys = function(x=xs, a=params.a, b=params.b)
plt.plot(xs, pred_ys, ".", c='red', label=f'Fit result: a={params.a:.2f}, b={params.b:.2f}')
plt.legend()

This is the key ingredient for training neural networks, see more at tomorrow's ML lecture by Liv!