# Custom CPU Operators with Numba

This tutorial shows how to write high-performance CPU kernels using **Numba's `@njit`**
decorator and integrate them into the BrainEvent / JAX ecosystem.

[Numba](https://numba.readthedocs.io/) compiles Python functions to native machine code
via LLVM, often reaching speeds comparable to C or Fortran.  BrainEvent's `numba_kernel` function
bridges Numba JIT-compiled functions into JAX via XLA's Foreign Function Interface (FFI),
so your Numba kernels become first-class JAX operations compatible with `jax.jit`, `jax.vmap`,
and other transforms.

## Contents
1. Why Numba on CPU?
2. Installation and Imports
3. Writing Numba JIT Kernels (`@numba.njit`)
4. `numba_kernel` – Wrapping for JAX
5. Parallel Kernels with `numba.prange`
6. Multiple Inputs and Outputs
7. Registering with `XLACustomKernel`
8. Neuroscience Example: Sparse CSR × Float-Vector Multiplication
9. Combining Numba CPU and GPU Backends
10. Summary

## 1. Why Numba on CPU?

JAX runs on CPU, GPU, and TPU, but some algorithms do not map well to the GPU's
massively parallel execution model:

- **Sparse / irregular access patterns** – scattered (uncoalesced) memory accesses waste much of a GPU's memory bandwidth
- **Sequential algorithms** – recurrences in which each iteration depends on the previous one
- **Small to medium problem sizes** – kernel-launch and host-device transfer overhead can dominate for small arrays
- **CPU-only environments** – laptops, CI servers, edge devices

| Property | JAX native (CPU) | Numba (`@njit`) | C extension |
|----------|------------------|-----------------|-------------|
| Compilation | XLA JIT | LLVM JIT | Ahead of time |
| Python interpreter in the hot loop | No (traced once by `jax.jit`) | No | No |
| Parallelism | Automatic, op-dependent | Explicit via `prange` | pthreads / OpenMP |
| Custom imperative loops with in-place writes | No (functional `lax` loops only) | Yes | Yes |
| Written in Python | Yes | Yes | No |

Numba `@njit` lets you write the inner loop in Python while achieving native
performance, and `brainevent.numba_kernel` makes the result a proper JAX primitive.

**Requirements:** `pip install numba`  (no GPU needed)

## 2. Installation and Imports

In [None]:
# Install if needed:
# !pip install numba -U
# !pip install brainevent -U

import jax
import jax.numpy as jnp
import numpy as np

import brainevent
from brainevent import XLACustomKernel, numba_kernel

print(f"JAX version    : {jax.__version__}")
print(f"JAX backend    : {jax.default_backend()}")
print(f"BrainEvent     : {brainevent.__version__}")

try:
    import numba
    print(f"Numba version  : {numba.__version__}")
    NUMBA_AVAILABLE = True
except ImportError:
    print("Numba not installed. Run: pip install numba")
    NUMBA_AVAILABLE = False

## 3. Writing Numba JIT Kernels (`@numba.njit`)

Rules for Numba CPU kernels used with `numba_kernel`:
- Decorate with `@numba.njit`  (or `@numba.njit(parallel=True)` for parallelism)
- Function signature: `kernel(input1, input2, ..., output1, output2, ...)`
  – inputs first, then outputs; all as NumPy arrays (zero-copy from JAX)
- **Write results into output arrays** – no return values
- Only a subset of Python and NumPy compiles: scalar math (including the `math` module), array indexing and slicing, and `for` loops all work
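The convention above is often called destination-passing style: the caller allocates every output buffer and the kernel only fills it in. The plain-NumPy sketch below imitates that contract with a hypothetical helper `call_with_outputs` that allocates outputs from `(shape, dtype)` pairs; it illustrates the calling convention only and is not how `numba_kernel` is implemented internally.

```python
import numpy as np


def scale_kernel(x, factor, out):
    """Kernel in destination-passing style: inputs first, then outputs."""
    for i in range(out.size):
        out[i] = x[i] * factor[0]


def call_with_outputs(kernel, out_specs, *inputs):
    """Hypothetical helper: allocate outputs, call the kernel, return them.

    out_specs is a list of (shape, dtype) pairs, one per output buffer.
    """
    outs = [np.empty(shape, dtype=dtype) for shape, dtype in out_specs]
    kernel(*inputs, *outs)  # the kernel writes in place; any return value is ignored
    return tuple(outs)


x = np.arange(4, dtype=np.float32)
factor = np.array([2.0], dtype=np.float32)
(y,) = call_with_outputs(scale_kernel, [((4,), np.float32)], x, factor)
print(y)  # [0. 2. 4. 6.]
```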

### 3.1 Simple Element-wise Kernels

In [None]:
if NUMBA_AVAILABLE:
    @numba.njit
    def add_kernel(x, y, out):
        """out[i] = x[i] + y[i]"""
        for i in range(out.size):
            out[i] = x[i] + y[i]

    @numba.njit
    def relu_kernel(x, out):
        """out[i] = max(x[i], 0.0)"""
        for i in range(out.size):
            v = x[i]
            out[i] = v if v > 0.0 else 0.0

    @numba.njit
    def matvec_kernel(A, x, out):
        """Dense matrix-vector product: out = A @ x"""
        rows, cols = A.shape
        for i in range(rows):
            total = A.dtype.type(0)
            for j in range(cols):
                total += A[i, j] * x[j]
            out[i] = total

    print("Numba kernels defined:", add_kernel, relu_kernel, matvec_kernel)
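Because an `@numba.njit` function remains an ordinary Python callable, each kernel can be unit-tested directly on NumPy arrays before it is wrapped for JAX: preallocate the output, call the kernel, and compare against a NumPy reference. The sketch uses a standalone copy of the matvec loop named `matvec_demo` (hypothetical, with a `float64` accumulator), and the `njit` fallback shim is an assumption added only so it also runs without Numba.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fallback: run the kernel as plain (slow) Python
    def njit(func):
        return func


@njit
def matvec_demo(A, x, out):
    """Dense matrix-vector product written into a caller-allocated buffer."""
    rows, cols = A.shape
    for i in range(rows):
        total = 0.0  # float64 accumulator
        for j in range(cols):
            total += A[i, j] * x[j]
        out[i] = total


rng = np.random.default_rng(0)
A = rng.standard_normal((64, 32)).astype(np.float32)
x = rng.standard_normal(32).astype(np.float32)
out = np.empty(64, dtype=np.float32)  # caller-allocated output buffer
matvec_demo(A, x, out)
print("max abs error vs A @ x:", float(np.max(np.abs(out - A @ x))))
```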

### 3.2 Reduction Kernels

In [None]:
if NUMBA_AVAILABLE:
    @numba.njit
    def sum_kernel(x, out):
        """out[0] = sum(x)."""
        total = 0.0  # float64 accumulator avoids float32 round-off on long arrays
        for i in range(x.size):
            total += x[i]
        out[0] = total

    @numba.njit
    def max_kernel(x, out):
        """out[0] = max(x)."""
        m = x[0]
        for i in range(1, x.size):
            if x[i] > m:
                m = x[i]
        out[0] = m

    @numba.njit
    def running_stats_kernel(x, mean_out, std_out):
        """Compute mean and (population) std in one kernel call, using two passes over x."""
        n = x.size
        s = 0.0  # float64 accumulator avoids float32 round-off on long arrays
        for i in range(n):
            s += x[i]
        mean = s / n
        var = 0.0
        for i in range(n):
            d = x[i] - mean
            var += d * d
        mean_out[0] = mean
        std_out[0]  = (var / n) ** 0.5

    print("Reduction kernels defined.")
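The accumulator's dtype matters in reductions like these. With `float32` data, a long left-to-right `float32` sum starts losing low-order bits once the running total passes $2^{24}$ (about 16.8 million), the largest range in which `float32` represents every integer exactly, which is why a `float64` accumulator (`total = 0.0` in a Numba kernel) is a common choice. The plain-Python sketch below compares the two on the same input used in the next section; `sequential_sum` is a hypothetical helper.

```python
import numpy as np


def sequential_sum(x, acc_dtype):
    """Left-to-right sum with an explicit accumulator dtype, like a kernel loop."""
    total = acc_dtype(0)
    for v in x:
        total = acc_dtype(total + acc_dtype(v))
    return total


x = np.arange(10_000, dtype=np.float32)
exact = 10_000 * 9_999 // 2                # 49_995_000

s32 = sequential_sum(x, np.float32)        # rounding error accumulates past 2**24
s64 = sequential_sum(x, np.float64)        # exact for these integer inputs
print(f"float32 accumulator: {float(s32):.1f}")
print(f"float64 accumulator: {float(s64):.1f}  (exact: {exact})")
```

A different summation order (for example the one `jnp.sum` uses) changes the `float32` rounding pattern, so small discrepancies against JAX's result are expected when the kernel accumulates in `float32`.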

## 4. `numba_kernel` – Wrapping for JAX

`numba_kernel` wraps a Numba CPU kernel so it can be called with JAX CPU arrays
via XLA's typed FFI protocol.

**Signature:**
```python
numba_kernel(
    kernel,              # @numba.njit function
    outs,                # jax.ShapeDtypeStruct or list thereof
    *,
    vmap_method=None,
    input_output_aliases=None,
) -> callable
```

The returned callable accepts JAX arrays as inputs and returns JAX arrays as outputs.
It is compatible with `jax.jit`.

In [None]:
if NUMBA_AVAILABLE:
    N = 512
    a = jnp.arange(N, dtype=jnp.float32)
    b = jnp.ones(N, dtype=jnp.float32) * 3.0

    # Create the JAX-callable wrapper
    add_fn = numba_kernel(
        add_kernel,
        outs=jax.ShapeDtypeStruct((N,), jnp.float32),
    )

    result = add_fn(a, b)
    # The wrapper may return a tuple even for a single output; unwrap defensively
    result = result[0] if isinstance(result, tuple) else result

    expected = a + b
    print("Add max error  :", float(jnp.max(jnp.abs(result - expected))))

    # ---- ReLU ----
    x = jnp.linspace(-3.0, 3.0, N, dtype=jnp.float32)
    relu_fn = numba_kernel(
        relu_kernel,
        outs=jax.ShapeDtypeStruct((N,), jnp.float32),
    )
    r = relu_fn(x)
    r = r[0] if isinstance(r, tuple) else r
    print("ReLU max error :", float(jnp.max(jnp.abs(r - jnp.maximum(x, 0.0)))))

In [None]:
if NUMBA_AVAILABLE:
    # ---- Reduction ----
    N = 10_000
    x = jnp.arange(N, dtype=jnp.float32)

    sum_fn = numba_kernel(
        sum_kernel,
        outs=jax.ShapeDtypeStruct((1,), jnp.float32),
    )

    s = sum_fn(x)
    s = s[0] if isinstance(s, tuple) else s
    print(f"Sum: {float(s[0]):.1f}  |  Expected: {float(jnp.sum(x)):.1f}")

    # ---- Multiple outputs (mean and std from one kernel call) ----
    stats_fn = numba_kernel(
        running_stats_kernel,
        outs=[
            jax.ShapeDtypeStruct((1,), jnp.float32),  # mean
            jax.ShapeDtypeStruct((1,), jnp.float32),  # std
        ],
    )

    mean_val, std_val = stats_fn(x)
    print(f"Mean: {float(mean_val[0]):.2f}  |  Std: {float(std_val[0]):.2f}")
    print(f"jnp.mean: {float(jnp.mean(x)):.2f}  |  jnp.std: {float(jnp.std(x)):.2f}")

### 4.1 JIT Compatibility

In [None]:
if NUMBA_AVAILABLE:
    N = 128
    add_fn_cached = numba_kernel(
        add_kernel,
        outs=jax.ShapeDtypeStruct((N,), jnp.float32),
    )

    @jax.jit
    def jitted_pipeline(a, b):
        # Mix Numba kernel with standard JAX operations
        temp = add_fn_cached(a, b)
        temp = temp[0] if isinstance(temp, tuple) else temp
        return jnp.sin(temp) * jnp.sqrt(jnp.abs(temp) + 1.0)

    a = jnp.arange(N, dtype=jnp.float32)
    b = jnp.ones(N, dtype=jnp.float32)

    r1 = jitted_pipeline(a, b)
    r2 = jitted_pipeline(a * 2, b * 0.5)   # second call reuses compiled code

    print("JIT pipeline output shape:", r1.shape)
    print("First 5 values           :", r1[:5])

## 5. Parallel Kernels with `numba.prange`

Adding `parallel=True` to `@numba.njit` and replacing `range` with `numba.prange`
tells Numba to split the loop's iterations across CPU cores using its threading layer.

This is the easiest way to exploit multi-core CPUs without writing thread-management code.
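Two rules keep a `prange` loop correct: each iteration must write only to locations that no other iteration touches, and a shared scalar may be updated only through a reduction such as `total += ...`, which Numba recognizes and combines across threads. The sketch shows both patterns in one hypothetical kernel, `row_norms_and_total`; the `njit`/`prange` fallback to plain Python is an assumption added only so it runs without Numba.

```python
import numpy as np

try:
    from numba import njit, prange
except ImportError:  # fallback: sequential plain Python
    prange = range

    def njit(*args, **kwargs):
        if args and callable(args[0]):
            return args[0]
        return lambda func: func


@njit(parallel=True)
def row_norms_and_total(A, norms_out, total_out):
    total = 0.0
    for i in prange(A.shape[0]):        # iterations are independent
        acc = 0.0                       # private to this iteration
        for j in range(A.shape[1]):     # inner loop stays sequential
            acc += A[i, j] * A[i, j]
        norms_out[i] = np.sqrt(acc)     # each iteration writes its own slot
        total += acc                    # scalar reduction, combined by Numba
    total_out[0] = total


A = np.random.default_rng(0).standard_normal((1000, 64))
norms = np.empty(1000)
total = np.empty(1)
row_norms_and_total(A, norms, total)
print(np.allclose(norms, np.linalg.norm(A, axis=1)), np.isclose(total[0], np.sum(A * A)))
```

Writing `out[0] += ...` to a shared array element inside the `prange` loop instead of using a scalar reduction would be a data race.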

In [None]:
if NUMBA_AVAILABLE:
    @numba.njit(parallel=True)
    def parallel_add_kernel(x, y, out):
        """Parallel element-wise add using prange."""
        for i in numba.prange(out.size):  # parallelized loop
            out[i] = x[i] + y[i]

    @numba.njit(parallel=True)
    def parallel_matvec_kernel(A, x, out):
        """Parallel matrix-vector product: each row computed by a separate thread."""
        rows, cols = A.shape
        for i in numba.prange(rows):    # parallelize over rows
            total = A.dtype.type(0)
            for j in range(cols):       # inner loop stays sequential
                total += A[i, j] * x[j]
            out[i] = total

    @numba.njit(parallel=True)
    def parallel_exp_decay_kernel(trace, spikes, tau_inv, out):
        """
        Exponential trace update used in STDP:
          out[i] = trace[i] * exp(-tau_inv) + spikes[i]
        """
        # np is a module-level global; Numba does not compile `import` statements in jitted code
        decay = np.exp(-tau_inv[0])
        for i in numba.prange(out.size):
            out[i] = trace[i] * decay + spikes[i]

    print("Parallel kernels defined.")

In [None]:
if NUMBA_AVAILABLE:
    import os
    import time

    N = 1_000_000
    a = jnp.arange(N, dtype=jnp.float32)
    b = jnp.ones(N, dtype=jnp.float32)

    serial_fn   = numba_kernel(add_kernel,          outs=jax.ShapeDtypeStruct((N,), jnp.float32))
    parallel_fn = numba_kernel(parallel_add_kernel, outs=jax.ShapeDtypeStruct((N,), jnp.float32))

    # Warm up so that one-time compilation is excluded from the timings
    jax.block_until_ready(serial_fn(a, b))
    jax.block_until_ready(parallel_fn(a, b))

    N_TRIALS = 20

    # time.perf_counter is a monotonic, high-resolution clock suited to benchmarks
    t0 = time.perf_counter()
    for _ in range(N_TRIALS):
        jax.block_until_ready(serial_fn(a, b))
    serial_time = (time.perf_counter() - t0) / N_TRIALS * 1000   # ms per call

    t0 = time.perf_counter()
    for _ in range(N_TRIALS):
        jax.block_until_ready(parallel_fn(a, b))
    parallel_time = (time.perf_counter() - t0) / N_TRIALS * 1000

    n_cores = os.cpu_count()
    print(f"N = {N:,}  |  CPU cores: {n_cores}")
    print(f"Serial   : {serial_time:.2f} ms")
    print(f"Parallel : {parallel_time:.2f} ms")
    print(f"Speedup  : {serial_time / parallel_time:.2f}x")
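Averages over a few millisecond-scale calls are sensitive to outliers (a stray context switch, a garbage-collection pause). A slightly more robust pattern reports the median of several runs timed with `time.perf_counter`. The helper name `median_ms` is hypothetical; `fn` is any zero-argument callable, for example `lambda: jax.block_until_ready(parallel_fn(a, b))`.

```python
import statistics
import time


def median_ms(fn, n_warmup=2, n_trials=20):
    """Median wall-clock time of fn() in milliseconds (hypothetical helper)."""
    for _ in range(n_warmup):        # exclude compilation and first-call costs
        fn()
    samples = []
    for _ in range(n_trials):
        t0 = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - t0) * 1e3)
    return statistics.median(samples)


print(f"{median_ms(lambda: sum(range(100_000))):.3f} ms")
```

For a memory-bound operation such as element-wise addition, the parallel speedup is usually capped by memory bandwidth rather than by the core count.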

## 6. Multiple Inputs and Outputs

Numba kernels can take any number of inputs and outputs.
The `outs` argument to `numba_kernel` mirrors the output buffers:
a single `ShapeDtypeStruct` for one output, a list for multiple.

In [None]:
if NUMBA_AVAILABLE:
    @numba.njit(parallel=True)
    def lif_dynamics_kernel(
        V,         # membrane potentials  (N,)
        I_ext,     # external current     (N,)
        tau_inv,   # 1/tau_m              (1,)
        V_th,      # threshold            (1,)
        V_rest,    # reset potential      (1,)
        dt,        # time step            (1,)
        V_out,     # updated potentials   (N,)  – output
        spikes,    # spike vector         (N,)  – output
    ):
        """
        One Euler step of leaky integrate-and-fire dynamics:
          V_out[i] = V[i] + dt * (-(V[i] - V_rest) * tau_inv + I_ext[i])
        then threshold and reset.
        """
        th   = V_th[0]
        vr   = V_rest[0]
        ti   = tau_inv[0]
        step = dt[0]

        for i in numba.prange(V.size):
            v_new = V[i] + step * (-(V[i] - vr) * ti + I_ext[i])
            if v_new >= th:
                spikes[i]  = 1
                V_out[i]   = vr
            else:
                spikes[i]  = 0
                V_out[i]   = v_new

    print("LIF dynamics kernel defined.")

In [None]:
if NUMBA_AVAILABLE:
    N = 5_000
    rng = np.random.default_rng(0)

    V      = jnp.array(rng.uniform(-75.0, -50.0, N).astype(np.float32))
    I_ext  = jnp.array(rng.uniform(0.0,    5.0,  N).astype(np.float32))
    tau_inv = jnp.array([1.0 / 20.0], dtype=jnp.float32)  # tau_m = 20 ms
    V_th   = jnp.array([-55.0], dtype=jnp.float32)
    V_rest = jnp.array([-70.0], dtype=jnp.float32)
    dt     = jnp.array([0.1],   dtype=jnp.float32)         # dt = 0.1 ms

    lif_fn = numba_kernel(
        lif_dynamics_kernel,
        outs=[
            jax.ShapeDtypeStruct((N,), jnp.float32),  # V_out
            jax.ShapeDtypeStruct((N,), jnp.int32),    # spikes
        ],
    )

    V_new, spikes = lif_fn(V, I_ext, tau_inv, V_th, V_rest, dt)

    print(f"Neurons: {N}")
    print(f"Spikes : {int(spikes.sum())} ({100*float(spikes.mean()):.1f}%)")
    print(f"V range: [{float(V_new.min()):.2f}, {float(V_new.max()):.2f}] mV")

    # Verify against JAX reference
    V_ref = V + dt[0] * (-(V - V_rest[0]) * tau_inv[0] + I_ext)
    spk_ref = (V_ref >= V_th[0]).astype(jnp.int32)
    V_ref   = jnp.where(spk_ref, V_rest[0], V_ref)

    print(f"V max error    : {float(jnp.max(jnp.abs(V_new - V_ref))):.6f} mV")
    print(f"Spike mismatch : {int(jnp.sum(spikes != spk_ref))}")
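To see the update rule concretely, the scalar sketch below applies one Euler step with the same parameters as the cell above (τ_m = 20 ms, dt = 0.1 ms, V_th = -55 mV, V_rest = -70 mV) to two hand-picked neurons. `lif_step` is a hypothetical helper that mirrors the kernel's loop body.

```python
def lif_step(v, i_ext, tau_inv=1.0 / 20.0, v_th=-55.0, v_rest=-70.0, dt=0.1):
    """One Euler step plus threshold/reset, mirroring lif_dynamics_kernel's loop body."""
    v_new = v + dt * (-(v - v_rest) * tau_inv + i_ext)
    if v_new >= v_th:
        return v_rest, 1      # spike: reset to v_rest
    return v_new, 0


# Sub-threshold: dV = 0.1 * (-(-60 + 70) / 20 + 2) = 0.1 * 1.5 = 0.15 mV
print(lif_step(-60.0, 2.0))   # approximately (-59.85, 0)
# Near threshold: dV = 0.1 * (-(-55.1 + 70) / 20 + 5) = 0.4255 mV -> -54.6745 >= -55
print(lif_step(-55.1, 5.0))   # (-70.0, 1)
```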

## 7. Registering with `XLACustomKernel`

For production use, embed your Numba kernel inside a **kernel generator** and
register it with `XLACustomKernel`.  The generator receives shape/dtype
information forwarded from `primitive.bind` and returns the concrete callable.
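If a generator is invoked for each new shape/dtype combination (an assumption about how the primitive lowers, not something shown here), expensive setup such as compiling a Numba function or building the `numba_kernel` wrapper belongs in the generator body rather than in the callable it returns. The pure-Python sketch below models that idea with a hypothetical factory `make_scale_kernel` memoized on `(n, dtype)` via `functools.lru_cache`; how `XLACustomKernel` itself caches generators may differ.

```python
import functools

import numpy as np

build_log = []   # records each time the expensive setup runs


@functools.lru_cache(maxsize=None)
def make_scale_kernel(n, dtype_name):
    """Hypothetical generator: run the costly setup once per (n, dtype) pair."""
    build_log.append((n, dtype_name))   # stands in for Numba compilation and wrapping
    dtype = np.dtype(dtype_name)

    def kernel(x, s):
        out = np.empty(n, dtype=dtype)  # output described by (n, dtype)
        for i in range(n):
            out[i] = x[i] * s[0]
        return out

    return kernel


x = np.arange(8, dtype=np.float32)
s = np.array([0.5], dtype=np.float32)
for _ in range(3):
    y = make_scale_kernel(x.shape[0], x.dtype.name)(x, s)
print(y, "| setup runs:", len(build_log))   # three calls, one setup run
```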

In [None]:
if NUMBA_AVAILABLE:
    # -----------------------------------------------------------------------
    # Kernel generator: exponential trace update (STDP)
    #   out[i] = trace[i] * decay + spikes[i]
    # -----------------------------------------------------------------------

    def exp_trace_numba_generator(**kwargs):
        out_info = kwargs['outs'][0]
        n        = out_info.shape[0]

        @numba.njit(parallel=True)
        def trace_kern(trace, spikes, tau_inv, out):
            # np is a module-level global; Numba does not compile `import` statements in jitted code
            decay = np.exp(-tau_inv[0])
            for i in numba.prange(n):
                out[i] = trace[i] * decay + spikes[i]

        # Build the JAX-callable wrapper once per generator call, not on every invocation
        trace_fn = numba_kernel(trace_kern, outs=out_info)

        def kernel(trace, spikes, tau_inv):
            return trace_fn(trace, spikes, tau_inv)

        return kernel

    # Register the primitive
    trace_op = XLACustomKernel('tutorial_numba_exp_trace')
    trace_op.def_numba_kernel(exp_trace_numba_generator)

    print("Registered backends:", list(trace_op._kernels.keys()))

In [None]:
if NUMBA_AVAILABLE:
    N = 1000
    trace   = jnp.zeros(N, dtype=jnp.float32)
    spikes  = jnp.array(np.random.default_rng(1).random(N) < 0.1,
                        dtype=jnp.float32)
    tau_inv = jnp.array([1.0 / 20.0], dtype=jnp.float32)  # tau = 20 ms

    out_spec = jax.ShapeDtypeStruct((N,), jnp.float32)

    @jax.jit
    def update_trace(trace, spikes, tau_inv):
        return trace_op(
            trace, spikes, tau_inv,
            outs=[out_spec],
        )[0]

    # Simulate 100 time steps of trace dynamics
    import math
    decay = math.exp(-float(tau_inv[0]))

    trace_history = []
    for step in range(100):
        spikes = jnp.array(
            np.random.default_rng(step).random(N) < 0.05,
            dtype=jnp.float32
        )
        trace = update_trace(trace, spikes, tau_inv)
        trace_history.append(float(trace.mean()))

    print("Trace stats after 100 steps:")
    print(f"  Mean  : {float(trace.mean()):.4f}")
    print(f"  Max   : {float(trace.max()):.4f}")
    print(f"  Steady-state (theory): {0.05 / (1 - decay):.4f}")
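The theoretical value printed above follows from averaging the update. If each neuron spikes independently with probability $p$ per step, the expected trace $\bar{x}_t$ obeys a linear recurrence with decay factor $d = \exp(-\tau^{-1})$, where $\tau^{-1}$ is `tau_inv`, implicitly measured per time step. Starting from $\bar{x}_0 = 0$:

$$
\bar{x}_{t+1} = d\,\bar{x}_t + p
\quad\Longrightarrow\quad
\bar{x}_t = \frac{p}{1-d}\left(1 - d^{\,t}\right)
\;\xrightarrow{\;t\to\infty\;}\; \frac{p}{1-d}.
$$

With $p = 0.05$ and $d = e^{-0.05} \approx 0.951$, the limit is about $1.025$. After $t = 100$ steps the factor $1 - d^{100} = 1 - e^{-5} \approx 0.993$, so the simulated mean should sit slightly below the printed steady state, up to sampling noise over 1000 neurons.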

## 8. Neuroscience Example: Sparse CSR × Float-Vector Multiplication

A core operation in neural network simulation:
given a CSR weight matrix and a float input vector, compute the matrix-vector product.

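Each output neuron (one CSR row) is computed by a short sequential loop over its
incoming synapses, and rows are independent of one another, so parallelizing over
rows with `numba.prange` is a natural fit for Numba on CPU.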
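For readers new to the format: CSR stores row $i$'s non-zeros in `data[indptr[i]:indptr[i+1]]`, with the matching column numbers in the same slice of `indices`. The sketch below builds these arrays by hand for a 3 × 4 matrix and runs the same row loop as the kernel in plain NumPy; `csr_matvec_reference` is a hypothetical name.

```python
import numpy as np

# Dense matrix with 5 non-zeros (row 1 is empty):
#   [[1, 0, 2, 0],
#    [0, 0, 0, 0],
#    [0, 3, 4, 5]]
data    = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32)  # non-zero values, row by row
indices = np.array([0, 2, 1, 2, 3], dtype=np.int32)               # column of each value
indptr  = np.array([0, 2, 2, 5], dtype=np.int32)                  # row i spans indptr[i]:indptr[i+1]


def csr_matvec_reference(data, indices, indptr, x):
    out = np.zeros(indptr.size - 1, dtype=data.dtype)
    for i in range(indptr.size - 1):               # one output neuron per row
        for k in range(indptr[i], indptr[i + 1]):
            out[i] += data[k] * x[indices[k]]      # gather from the input vector
    return out


x = np.array([1.0, 10.0, 100.0, 1000.0], dtype=np.float32)
print(csr_matvec_reference(data, indices, indptr, x))  # rows: 201, 0, 5430
```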

In [None]:
if NUMBA_AVAILABLE:
    @numba.njit(parallel=True)
    def csr_matvec_numba(
        data,       # CSR non-zero values  (nnz,)
        indices,    # CSR column indices   (nnz,)
        indptr,     # CSR row pointers     (n_rows+1,)
        x,          # input vector         (n_cols,)
        out,        # output vector        (n_rows,)
    ):
        """
        Sparse matrix-vector product (CSR format).
        Each row is processed by one thread (parallel over rows).
        """
        n_rows = indptr.size - 1
        for i in numba.prange(n_rows):
            total = out.dtype.type(0)
            for k in range(indptr[i], indptr[i + 1]):
                total += data[k] * x[indices[k]]
            out[i] = total

    def csr_mv_numba_generator(**kwargs):
        out_info = kwargs['outs'][0]

        # Build the JAX-callable wrapper once per generator call, not on every invocation
        csr_fn = numba_kernel(csr_matvec_numba, outs=out_info)

        def kernel(data, indices, indptr, x):
            return csr_fn(data, indices, indptr, x)

        return kernel

    csr_mv_op = XLACustomKernel('tutorial_numba_csr_matvec')
    csr_mv_op.def_numba_kernel(csr_mv_numba_generator)

    print("CSR MV operator registered.")

In [None]:
if NUMBA_AVAILABLE:
    import scipy.sparse as sp

    N_PRE  = 2000
    N_POST = 1000
    PROB   = 0.05

    rng   = np.random.default_rng(42)
    dense = (rng.random((N_POST, N_PRE)) < PROB).astype(np.float32)
    dense *= rng.uniform(0.01, 0.5, dense.shape).astype(np.float32)
    csr   = sp.csr_matrix(dense)

    data    = jnp.array(csr.data,    dtype=jnp.float32)
    indices = jnp.array(csr.indices, dtype=jnp.int32)
    indptr  = jnp.array(csr.indptr,  dtype=jnp.int32)
    x       = jnp.array(rng.random(N_PRE).astype(np.float32))

    out_spec = jax.ShapeDtypeStruct((N_POST,), jnp.float32)

    result = csr_mv_op(
        data, indices, indptr, x,
        outs=[out_spec],
    )[0]

    expected = jnp.array(dense) @ x
    print(f"Network: {N_PRE} pre -> {N_POST} post  (nnz={csr.nnz})")
    print(f"Max error vs dense: {float(jnp.max(jnp.abs(result - expected))):.6f}")
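Numba does not bounds-check array indexing by default (checking can be enabled with `@numba.njit(boundscheck=True)` or the `NUMBA_BOUNDSCHECK` environment variable), so a malformed `indices` or `indptr` array makes the kernel read out of range instead of raising an error. A cheap NumPy check of the CSR invariants before calling the operator is a reasonable safeguard; `check_csr` is a hypothetical helper, not part of BrainEvent. In this tutorial it would be called as `check_csr(indices, indptr, N_POST, N_PRE, csr.nnz)`.

```python
import numpy as np


def check_csr(indices, indptr, n_rows, n_cols, nnz):
    """Raise ValueError if the CSR arrays violate the layout the kernel assumes."""
    indices = np.asarray(indices)
    indptr = np.asarray(indptr)
    if indptr.shape != (n_rows + 1,):
        raise ValueError(f"indptr must have length n_rows + 1 = {n_rows + 1}")
    if indptr[0] != 0 or indptr[-1] != nnz:
        raise ValueError("indptr must start at 0 and end at nnz")
    if np.any(np.diff(indptr) < 0):
        raise ValueError("indptr must be non-decreasing")
    if indices.shape != (nnz,):
        raise ValueError("indices must have length nnz")
    if nnz and (indices.min() < 0 or indices.max() >= n_cols):
        raise ValueError("column index out of range")


# A valid 3 x 4 example, then a corrupted copy with a column index of 7
demo_indptr = np.array([0, 2, 2, 5])
demo_indices = np.array([0, 2, 1, 2, 3])
check_csr(demo_indices, demo_indptr, n_rows=3, n_cols=4, nnz=5)   # passes silently
try:
    check_csr(np.array([0, 2, 1, 2, 7]), demo_indptr, 3, 4, 5)
except ValueError as err:
    print("rejected:", err)
```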

In [None]:
if NUMBA_AVAILABLE:
    import time

    @jax.jit
    def numba_csr_mv(data, indices, indptr, x):
        return csr_mv_op(
            data, indices, indptr, x,
            outs=[out_spec],
        )[0]

    @jax.jit
    def jax_dense_mv(A, x):
        return A @ x

    A_jnp = jnp.array(dense)

    # Warm up
    jax.block_until_ready(numba_csr_mv(data, indices, indptr, x))
    jax.block_until_ready(jax_dense_mv(A_jnp, x))

    N_TRIALS = 50

    t0 = time.perf_counter()   # monotonic, high-resolution clock
    for _ in range(N_TRIALS):
        jax.block_until_ready(numba_csr_mv(data, indices, indptr, x))
    numba_time = (time.perf_counter() - t0) / N_TRIALS * 1000   # ms per call

    t0 = time.perf_counter()
    for _ in range(N_TRIALS):
        jax.block_until_ready(jax_dense_mv(A_jnp, x))
    jax_time = (time.perf_counter() - t0) / N_TRIALS * 1000

    print(f"Numba CSR MV  : {numba_time:.2f} ms")
    print(f"JAX dense MV  : {jax_time:.2f} ms")
    print(f"Speedup       : {jax_time / numba_time:.2f}x  (sparsity: {1 - csr.nnz/(N_PRE*N_POST):.0%})")
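A rough operation count explains why the sparse kernel can win even though the dense product runs through XLA's optimized routines: the dense product performs one multiply-add per matrix entry, while the CSR kernel performs one per stored non-zero. The sketch uses the sizes above and estimates `nnz` as `N_PRE * N_POST * PROB`; the actual count varies with the random seed.

```python
N_PRE, N_POST, PROB = 2000, 1000, 0.05

dense_madds = N_POST * N_PRE                 # every entry, zero or not
expected_nnz = round(N_PRE * N_POST * PROB)  # about 100_000 stored entries
sparse_madds = expected_nnz                  # one multiply-add per stored entry

print(f"dense : {dense_madds:,} multiply-adds")
print(f"sparse: {sparse_madds:,} multiply-adds (~{dense_madds / sparse_madds:.0f}x fewer)")
```

The measured speedup is usually smaller than this ratio, because each CSR multiply-add also loads an index and gathers `x[indices[k]]` from a non-contiguous location, whereas the dense product streams memory contiguously.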

## 9. Combining Numba CPU and GPU Backends

The same `XLACustomKernel` primitive can have both a Numba CPU backend and a
GPU backend (Warp or Numba CUDA). JAX automatically dispatches to the correct
backend based on the device where the arrays live.
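Conceptually, a multi-backend primitive keeps one kernel generator per platform and picks one when it is called. The sketch below models only that idea with a hypothetical `BackendRegistry` class in plain Python; it does not reproduce `XLACustomKernel`'s internals, where, per the text above, the choice follows the device the arrays live on rather than an explicit `platform` argument.

```python
import numpy as np


class BackendRegistry:
    """Hypothetical sketch: one kernel generator per platform name."""

    def __init__(self, name):
        self.name = name
        self._generators = {}

    def register(self, platform, generator):
        self._generators[platform] = generator

    def __call__(self, *args, platform, **kwargs):
        if platform not in self._generators:
            raise NotImplementedError(f"{self.name}: no kernel registered for {platform!r}")
        kernel = self._generators[platform](**kwargs)  # specialize on shapes/dtypes
        return kernel(*args)


def cpu_scale_generator(**kwargs):
    return lambda x, s: x * s[0]   # stands in for the Numba CPU kernel


op = BackendRegistry("scale")
op.register("cpu", cpu_scale_generator)
print(op(np.arange(3.0), np.array([2.0]), platform="cpu"))  # [0. 2. 4.]
```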

In [None]:
try:
    import warp
    from warp.jax_experimental import jax_kernel as warp_jax_kernel
    from brainevent import jaxinfo_to_warpinfo
    warp.config.quiet = True
    WARP_AVAILABLE = True
except ImportError:
    WARP_AVAILABLE = False

if NUMBA_AVAILABLE:
    # CPU backend (same pattern as in the earlier sections)
    @numba.njit(parallel=True)
    def scale_numba(x, s, out):
        for i in numba.prange(out.size):
            out[i] = x[i] * s[0]

    def scale_numba_generator(**kwargs):
        out_info = kwargs['outs'][0]
        scale_fn = numba_kernel(scale_numba, outs=out_info)   # built once per generator call

        def kernel(x, s):
            return scale_fn(x, s)

        return kernel

    scale_op = XLACustomKernel('tutorial_multi_backend_scale')
    scale_op.def_numba_kernel(scale_numba_generator)   # CPU backend

    if WARP_AVAILABLE:
        def scale_warp_generator(**kwargs):
            out_info = kwargs['outs'][0]
            n = out_info.shape[0]
            t = jaxinfo_to_warpinfo(out_info)
            s_type = warp.array(dtype=jaxinfo_to_warpinfo(out_info).dtype, ndim=1)

            @warp.kernel
            def kern(x: t, s: s_type, out: t):
                i = warp.tid()
                out[i] = x[i] * s[0]

            # Build the JAX wrapper once, outside the per-call function
            fn = warp_jax_kernel(kern, launch_dims=[n], num_outputs=1,
                                 output_dims={'out': (n,)})

            def kernel(x, s):
                return fn(x, s)

            return kernel

        scale_op.def_warp_kernel(scale_warp_generator)  # GPU backend

    print("Multi-backend scale op registered.")
    print("Backends:", {p: list(b.keys()) for p, b in scale_op._kernels.items()})

    # Use it
    N = 256
    x = jnp.arange(N, dtype=jnp.float32)
    s = jnp.array([3.14], dtype=jnp.float32)

    r = scale_op(x, s, outs=[jax.ShapeDtypeStruct((N,), jnp.float32)])[0]
    print(f"Result matches: {bool(jnp.allclose(r, x * 3.14, atol=1e-5))}")

## 10. Summary

In this tutorial we covered:

1. **`@numba.njit`** – JIT-compile Python to native machine code. Kernel signature:
   `kernel(input1, ..., output1, ...)` – all NumPy arrays, no return values.
2. **`@numba.njit(parallel=True)` + `numba.prange`** – Multi-threaded parallelism
   on CPU cores with minimal code changes.
3. **`numba_kernel(kernel, outs=...)`** – Wrap a Numba kernel as a JAX-callable
   via XLA FFI. Returns a function compatible with `jax.jit`.
4. **Multiple outputs** – Pass a list of `jax.ShapeDtypeStruct` to `outs` to get
   multiple return arrays from a single kernel call.
5. **`XLACustomKernel.def_numba_kernel`** – Register a kernel generator as the CPU
   backend of a multi-backend custom JAX primitive.
6. **Neuroscience applications** – LIF dynamics and sparse CSR matrix-vector product
   implemented with parallel Numba, demonstrating realistic use cases.

## Key Guidelines

- Cache the wrapped callable (do **not** call `numba_kernel` inside `@jax.jit`);
  create it once at definition time.
- Use `@njit(parallel=True)` + `prange` for outer loops; keep inner loops sequential.
- Prefer Numba on CPU for irregular / sparse access patterns; prefer GPU backends
  (Warp, Numba CUDA) for large-scale parallel workloads.

## Next Steps

- **Tutorial 6**: Custom GPU operators with Warp
- **Tutorial 7**: Custom GPU operators with Numba CUDA

## References

- [Numba documentation](https://numba.readthedocs.io/)
- [Numba `prange` parallelism](https://numba.readthedocs.io/en/stable/user/parallel.html)
- [BrainEvent GitHub](https://github.com/chaobrain/brainevent)
- [JAX FFI documentation](https://jax.readthedocs.io/en/latest/ffi.html)