# Part 3: JIT-compilation with Numba and JAX

In lecture 2 you've seen that fusing operations is powerful and crucial for performance!

Essentially, just-in-time (JIT) compilation lets us get rid of _intermediate_ arrays by "fusing" operations: all of them are applied in a _single_ pass over our data.

Fusing operations is tricky, however. There are a few ways to achieve it for array processing in Python, and I'd like to highlight two of them:

- Numba: https://numba.pydata.org
- JAX: https://github.com/jax-ml/jax

In [1]:
# NumPy
import numpy as np

# JAX
import jax
import jax.numpy as jnp

jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", True)

# Numba
import numba as nb

Let's consider the quadratic formula example again, and compare the runtimes for NumPy, Numba, and JAX:

In [2]:
# Setup data
a = np.random.uniform(5, 10, 5_000_000)
b = np.random.uniform(10, 20, 5_000_000)
c = np.random.uniform(-0.1, 0.1, 5_000_000)


# Setup quadratic formula
def quadratic_formula(a, b, c):
    return (-b + np.sqrt(b**2 - 4*a*c)) / (2*a)

NumPy case:

In [3]:
%%timeit -n1 -r3

quadratic_formula(a, b, c)

22.3 ms ± 1.82 ms per loop (mean ± std. dev. of 3 runs, 1 loop each)


Numba case:

In [4]:
@nb.njit  # JIT compile!
def quadratic_formula_numba(a, b, c):
    n = a.shape[0]
    out = np.empty(n)
    for i in range(n):
        out[i] = (-b[i] + np.sqrt(b[i]**2 - 4*a[i]*c[i])) / (2*a[i])
    return out

In [5]:
%%timeit -n1

quadratic_formula_numba(a, b, c)

The slowest run took 94.57 times longer than the fastest. This could mean that an intermediate result is being cached.
81.6 ms ± 184 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [6]:
%%timeit -n10 -r3

quadratic_formula_numba(a, b, c)

6.86 ms ± 474 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)


JAX case:

In [7]:
# Setup data
a_jax = jnp.asarray(a)
b_jax = jnp.asarray(b)
c_jax = jnp.asarray(c)


@jax.jit  # JIT compile!
def quadratic_formula_jax(a, b, c):
    return (-b + jnp.sqrt(b**2 - 4*a*c)) / (2*a)

In [8]:
%%timeit -n1

quadratic_formula_jax(a_jax, b_jax, c_jax).block_until_ready()

The slowest run took 19.43 times longer than the fastest. This could mean that an intermediate result is being cached.
7.8 ms ± 13.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [9]:
%%timeit -n10 -r3

quadratic_formula_jax(a_jax, b_jax, c_jax).block_until_ready()

2.24 ms ± 23.4 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)


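The first invocation of the JAX and Numba functions took much longer than subsequent ones. That's the compile time! Afterwards, the compiled function is cached and reused, which is also why `%%timeit` warns that the slowest run took much longer than the fastest.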

But JAX is still about 3x faster. Why?

One important difference is that JAX uses as many threads as it has access to, whereas Numba is single-threaded by default. Numba can be multithreaded using `parallel=True`:

In [10]:
@nb.njit(parallel=True)  # JIT compile with `parallel=True`!
def quadratic_formula_numba_parallel(a, b, c):
    n = a.shape[0]
    out = np.empty(n)
    for i in nb.prange(n):  # note: `range` -> `nb.prange`
        out[i] = (-b[i] + np.sqrt(b[i]**2 - 4*a[i]*c[i])) / (2*a[i])
    return out

In [11]:
%%timeit -n1

quadratic_formula_numba_parallel(a, b, c)

The slowest run took 90.82 times longer than the fastest. This could mean that an intermediate result is being cached.
25.5 ms ± 57.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)




In [12]:
%%timeit -n10 -r3

quadratic_formula_numba_parallel(a, b, c)

1.98 ms ± 31.1 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)


Now JAX and Numba are roughly on par, at ~2 ms per call compared to NumPy's ~22 ms.


You might have noticed a fundamental difference between JAX and Numba in how those kernels are written: 

- Numba forces[<sup id="fn1-back">1</sup>](#fn1) you to write _imperative_ code
- JAX forces[<sup id="fn2-back">2</sup>](#fn2) you to write _array-oriented_ code


![image](https://raw.githubusercontent.com/jpivarski-talks/2023-12-18-hsf-india-tutorial-bhubaneswar/refs/heads/main/img/slow-fast-imperative-vectorized.svg)



[<sup id="fn1">1</sup>](#fn1-back) <sup>Can be written array-oriented with `nb.vectorize`.</sup> 

[<sup id="fn2">2</sup>](#fn2-back) <sup>Can be written imperatively with JAX's own loop primitives, e.g. `jax.lax.fori_loop` or `jax.lax.scan`.</sup>

### Limitations of Numba

You cannot JIT-compile arbitrary Python functions with Numba. Numba only supports a subset of Python, i.e. code whose types are "known" to Numba (mostly Python scalars, NumPy arrays, and NumPy operations).

For more information, see: https://numba.readthedocs.io/en/stable/user/5minguide.html#will-numba-work-for-my-code.


Check the following:

In [13]:
@nb.njit
def sum_dict_values(d):
    out = 0.
    for v in d.values():
        out += v
    return out

sum_dict_values({"a": 1.0, "b": 2.0, "c": 3.0})  # Fails, because a plain Python `dict` is not a type Numba can infer

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
non-precise type pyobject
During: typing of argument at /var/folders/34/_lhgt3t51_d5yqwccswc8k9w0000gn/T/ipykernel_23296/463706393.py (1)

File "../../../../../../var/folders/34/_lhgt3t51_d5yqwccswc8k9w0000gn/T/ipykernel_23296/463706393.py", line 1:
<source missing, REPL/exec in use?>

During: Pass nopython_type_inference 

This error may have been caused by the following argument(s):
- argument 0: Cannot determine Numba type of <class 'dict'>


### Limitations of JAX

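JAX infers which operations are going to be run through a "tracing" step. Essentially, JAX runs your function once with abstract array objects (tracers) that carry only metadata such as shape and dtype, but no data. That lets you use arbitrary Python code inside a JIT-compiled function, since it only runs during tracing, **but** you can't JIT-compile operations whose control flow or output shapes depend on the data itself.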

For more "sharp bits", see: https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html.
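Tracing can be observed directly. A minimal sketch with a hypothetical function `scaled_sin`: `jax.eval_shape` traces it with a pure shape/dtype spec (`jax.ShapeDtypeStruct`), so no data is allocated at all, and `jax.make_jaxpr` prints the program that tracing records.

```python
import jax
import jax.numpy as jnp


def scaled_sin(x):
    # Python code like this print runs only while tracing;
    # `x` is then a tracer carrying shape and dtype, but no values
    print("tracing with:", x)
    return jnp.sin(x) * 2.0


# Trace with a shape/dtype spec only: the output is again just metadata
spec = jax.ShapeDtypeStruct((5,), jnp.float32)
out = jax.eval_shape(scaled_sin, spec)
print(out)  # a ShapeDtypeStruct with shape (5,) and dtype float32

# The recorded program (jaxpr) that tracing produces
print(jax.make_jaxpr(scaled_sin)(jnp.ones(5, dtype=jnp.float32)))
```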

Check the following:

In [None]:
# Data-dependent operations are not traceable

@jax.jit
def accumulate_if(arr):
    print(arr)  # runs only at trace time and prints a tracer, not the data
    if jnp.any(arr > 3):
        return jnp.sum(arr)
    else:
        return jnp.prod(arr)


array = jnp.array([1., 2., 3., 4., 5.])
print(accumulate_if(array))  # Fails: `if` needs a concrete Python bool, but `jnp.any(arr > 3)` is a tracer during tracing!

Traced<ShapedArray(float64[5])>with<DynamicJaxprTrace>


TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function sum_if at /var/folders/34/_lhgt3t51_d5yqwccswc8k9w0000gn/T/ipykernel_18396/583950977.py:3 for jit. This concrete value was not available in Python because it depends on the value of the argument arr.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError
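One way around this is to move the branch into the compiled program with `jax.lax.cond`, which takes the traced predicate, two branch functions, and their operands. A sketch of `accumulate_if` rewritten this way (the name `accumulate_if_cond` is illustrative):

```python
import jax
import jax.numpy as jnp


@jax.jit
def accumulate_if_cond(arr):
    # lax.cond traces both branches and selects one at run time based on
    # the traced predicate, so no Python bool is needed during tracing
    return jax.lax.cond(jnp.any(arr > 3), jnp.sum, jnp.prod, arr)


r_big = accumulate_if_cond(jnp.array([1., 2., 3., 4., 5.]))  # some element > 3 -> sum
r_small = accumulate_if_cond(jnp.array([1., 2., 3.]))        # none > 3 -> prod
print(r_big, r_small)  # 15.0 6.0
```

Both branches must return outputs of the same shape and dtype, since the compiled program has to accommodate either one.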

Another limitation of JAX is that you can't JIT-compile programs whose array shapes depend on the data:

In [None]:
@jax.jit
def sum_greater_than_three(arr):
    return jnp.sum(arr[arr > 3.0])


array = jnp.array([1., 2., 3., 4., 5.])
print(sum_greater_than_three(array))  # Fails, because the output shape of `arr[arr > 3.0]` is not inferrable through tracing (without data)

NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[5])

See https://docs.jax.dev/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError
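The usual workaround is to keep the shape fixed and mask instead of filter: `jnp.where` replaces the unwanted elements with zeros, so the array going into `jnp.sum` always has the input's shape. A sketch (the function name is illustrative):

```python
import jax
import jax.numpy as jnp


@jax.jit
def sum_greater_than_three_masked(arr):
    # Same shape as `arr`: elements <= 3 become 0.0 instead of being dropped
    return jnp.sum(jnp.where(arr > 3.0, arr, 0.0))


array = jnp.array([1., 2., 3., 4., 5.])
result = sum_greater_than_three_masked(array)
print(result)  # 9.0
```

`jnp.sum` also accepts a NumPy-style `where=` argument (`jnp.sum(arr, where=arr > 3.0)`), which expresses the same masking.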

### Impure functions are dangerous with JIT compilation! (Numba & JAX)

In [None]:
do_sum = False

@nb.njit
def accumulate(arr):
    if do_sum:
        return np.sum(arr)
    else:
        return np.prod(arr)


array = np.array([1., 2., 3., 4., 5.])
print("Accumulate with `np.prod`:", accumulate(array))

# now we switch `do_sum` on!
do_sum = True
print("Accumulate with `np.sum`:", accumulate(array), f"...Hey, this should've been {np.sum(array)} instead!")

Accumulate with `np.prod`: 120.0
Accumulate with `np.sum`: 120.0 ...Hey, this should've been 15.0 instead!


In [None]:
do_sum = False

@jax.jit
def accumulate(arr):
    if do_sum:
        return jnp.sum(arr)
    else:
        return jnp.prod(arr)


array = jnp.array([1., 2., 3., 4., 5.])

print("Accumulate with `jnp.prod`:", accumulate(array))

# now we switch `do_sum` on!
do_sum = True
print("Accumulate with `jnp.sum`:", accumulate(array), f"...Hey, this should've been {jnp.sum(array)} instead!")

Accumulate with `jnp.prod`: 120.0
Accumulate with `jnp.sum`: 120.0 ...Hey, this should've been 15.0 instead!


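In the JAX case we can see why: the traced program _never knew_ that there's a `sum` operation in the first place, _and_ the compiled function is cached based on its input arguments (their shapes and dtypes), so changing the global `do_sum` does not trigger a retrace. Numba behaves the same way because it treats global variables as compile-time constants.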

In [None]:
print("Traced program:\n", accumulate.trace(array).jaxpr)
print()
print("HLO program:\n", accumulate.lower(array).as_text()) # this is the program that gets compiled by the XLA compiler

Traced program:
 { lambda ; a:f64[5]. let b:f64[] = reduce_prod[axes=(0,)] a in (b,) }

HLO program:
 module @jit_accumulate attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<5xf64>) -> (tensor<f64> {jax.result_info = "result"}) {
    %cst = stablehlo.constant dense<1.000000e+00> : tensor<f64>
    %0 = stablehlo.reduce(%arg0 init: %cst) applies stablehlo.multiply across dimensions = [0] : (tensor<5xf64>, tensor<f64>) -> tensor<f64>
    return %0 : tensor<f64>
  }
}
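In JAX, a Python-level flag like `do_sum` can instead be made an explicit argument and marked static with `static_argnames`. It is then part of the compilation cache key, so each distinct value gets its own traced and compiled version. A sketch (the name `accumulate_jax_pure` is illustrative):

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="do_sum")
def accumulate_jax_pure(arr, do_sum):
    # `do_sum` is a plain Python bool at trace time, so `if` works here
    if do_sum:
        return jnp.sum(arr)
    else:
        return jnp.prod(arr)


array = jnp.array([1., 2., 3., 4., 5.])
r_prod = accumulate_jax_pure(array, do_sum=False)
r_sum = accumulate_jax_pure(array, do_sum=True)  # new static value -> retrace
print(r_prod, r_sum)  # 120.0 15.0
```

Static arguments must be hashable, and every new value triggers a fresh compilation, so they suit a handful of configuration flags rather than frequently changing data.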



### JIT-compilation for GPUs



#### Numba.cuda

Numba exposes CUDA to Python through the `numba.cuda` module. Its programming model closely follows NVIDIA's CUDA C.

In [None]:
from numba import cuda

@cuda.jit
def matmul(A, B, C):
    """
    Perform square matrix multiplication C = A @ B
    """
    i, j = cuda.grid(2)
    if i < C.shape[0] and j < C.shape[1]:
        tmp = 0.
        for k in range(A.shape[1]):
            tmp += A[i, k] * B[k, j]
        C[i, j] = tmp

#### JAX on GPUs

JAX can run on GPUs without any code modifications (the power of array-oriented programming). The _symbolic_ operations of the intermediate representation (the jaxpr) are simply compiled to GPU kernels instead of CPU kernels.

In [None]:
# this runs on CPU and GPU, depending on the available `jax.devices()`

@jax.jit
def matmul(A, B): # -> C
    """
    Perform square matrix multiplication C = A @ B
    """
    return A @ B


print("Available devices:", jax.devices())

# explicitly move data to a device (CPU or GPU):
array = jax.device_put(jnp.arange(10), device=jax.devices()[0])

print(f"{array=} lives on", array.device)

Available devices: [CpuDevice(id=0)]
array=Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int64) lives on TFRT_CPU_0
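Computation follows data: a jitted function runs on the device that holds its inputs, and its outputs stay there. A sketch using the same `matmul` definition on small example matrices; `Array.devices()` reports the set of devices an array lives on.

```python
import jax
import jax.numpy as jnp


@jax.jit
def matmul(A, B):  # -> C
    """Perform square matrix multiplication C = A @ B."""
    return A @ B


# First device of the default backend: a GPU if one is available (and not disabled), else the CPU
device = jax.devices()[0]
A = jax.device_put(jnp.eye(3), device)
B = jax.device_put(jnp.arange(9.0).reshape(3, 3), device)

C = matmul(A, B)  # runs on the device holding the inputs
print(C)
print("C lives on", C.devices())
```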


On to the [project.ipynb](project.ipynb)!