# Custom GPU Operators with Numba CUDA

This tutorial shows how to write custom GPU kernels using **Numba CUDA** and integrate them
into the BrainEvent / JAX ecosystem.

[Numba](https://numba.readthedocs.io/) is a JIT compiler for Python that targets CUDA GPUs
via its `numba.cuda` subpackage. Kernels are written in Python, compiled at first call,
and run natively on the GPU. BrainEvent provides two functions that bridge Numba CUDA kernels
into JAX via XLA's Foreign Function Interface (FFI):

- **`numba_cuda_kernel`** – wraps a single `@cuda.jit` kernel with a fixed launch configuration.
- **`numba_cuda_callable`** – wraps an arbitrary Python function that may launch *multiple*
  CUDA kernels, allocate temporary memory, and orchestrate multi-step GPU computation.

## Contents
1. Why Numba CUDA?
2. Installation and Imports
3. Writing Numba CUDA Kernels (`@cuda.jit`)
4. `numba_cuda_kernel` – Single-Kernel Wrapper
5. Launch Configuration: `grid` / `block` vs. `launch_dims`
6. `numba_cuda_callable` – Multi-Kernel Wrapper
7. Registering with `XLACustomKernel`
8. Neuroscience Example: Parallel Spike Threshold Detection
9. Performance Tips
10. Summary

## 1. Why Numba CUDA?

| Feature | Numba CUDA | Warp | Raw CUDA C++ |
|---------|------------|------|--------------|
| Language | Python (`@cuda.jit`) | Python-like | C++ |
| Low-level control | Full thread/block/shared-mem | Partial | Full |
| Shared memory | Yes | Yes | Yes |
| Atomic operations | Yes | Yes | Yes |
| Explicit device memory allocation | Yes (`cuda.device_array`, called from host code) | Limited | Yes |
| Multi-kernel orchestration | Yes (via `numba_cuda_callable`) | No | Yes |

Choose Numba CUDA when you need:
- Fine-grained control over shared memory or thread synchronization
- Multi-kernel pipelines with temporary device allocations
- A familiar CUDA programming model in Python

**Requirements:**
- An NVIDIA GPU with a working CUDA driver
- Numba (`pip install numba`) and the CUDA toolkit
- JAX with GPU support (`pip install "jax[cuda12]"`)

## 2. Installation and Imports

In [None]:
# Install if needed:
# !pip install numba -U
# !pip install brainevent[cuda12] -U

import jax
import jax.numpy as jnp
import numpy as np

import brainevent
from brainevent import XLACustomKernel, numba_cuda_kernel, numba_cuda_callable

print(f"JAX version    : {jax.__version__}")
print(f"JAX backend    : {jax.default_backend()}")
print(f"BrainEvent     : {brainevent.__version__}")

try:
    from numba import cuda
    NUMBA_CUDA_AVAILABLE = cuda.is_available()
    if NUMBA_CUDA_AVAILABLE:
        import numba
        print(f"Numba version  : {numba.__version__}")
        print(f"CUDA available : {NUMBA_CUDA_AVAILABLE}")
        gpu = cuda.get_current_device()
        print(f"GPU            : {gpu.name}")
    else:
        print("Numba installed but CUDA device not found.")
except ImportError:
    print("Numba not installed. Run: pip install numba")
    NUMBA_CUDA_AVAILABLE = False

## 3. Writing Numba CUDA Kernels (`@cuda.jit`)

Numba CUDA kernels follow standard CUDA programming conventions:
- Decorated with `@cuda.jit`
- Each thread identifies itself via `cuda.grid(ndim)` (for `ndim=1`, equivalent to `blockIdx.x * blockDim.x + threadIdx.x`)
- Arrays are received as Numba device arrays (zero-copy from JAX GPU memory)
- Results are written **in-place** into output arrays (no return values)
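
To make the indexing rule concrete, the following plain-Python sketch (no GPU needed) emulates what `cuda.grid(1)` returns for every thread of a 1-D launch and why the bounds check matters. The helpers `emulate_grid_1d` and `emulate_relu_launch` are hypothetical names used only for illustration; they are not part of Numba or BrainEvent, and a real launch runs the threads in parallel rather than in a loop.

```python
def emulate_grid_1d(grid_dim, block_dim):
    """Yield the global index each thread would obtain from cuda.grid(1)."""
    for block_idx in range(grid_dim):
        for thread_idx in range(block_dim):
            yield block_idx * block_dim + thread_idx


def emulate_relu_launch(x, grid_dim, block_dim):
    """CPU stand-in for an element-wise kernel launch with a bounds check."""
    out = [0.0] * len(x)
    for i in emulate_grid_1d(grid_dim, block_dim):
        if i < len(out):              # surplus threads do nothing
            out[i] = max(x[i], 0.0)
    return out


x = [-2.0, -1.0, 0.5, 3.0, -0.25]
# 2 blocks x 4 threads = 8 threads for 5 elements; the last 3 are masked out.
result = emulate_relu_launch(x, grid_dim=2, block_dim=4)
print(result)
```

Without the `if i < len(out)` guard, the three surplus threads would index past the end of the arrays, which on a GPU means an out-of-bounds memory access.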

### 3.1 Basic Element-wise Kernel

In [None]:
if NUMBA_CUDA_AVAILABLE:
    @cuda.jit
    def elementwise_relu_kernel(x, out):
        """Element-wise ReLU: out[i] = max(x[i], 0)."""
        i = cuda.grid(1)          # global thread index
        if i < out.size:          # bounds check
            out[i] = max(x[i], 0.0)

    import math  # imported outside the kernel: Numba cannot compile `import` statements

    @cuda.jit
    def elementwise_sigmoid_kernel(x, out):
        """Element-wise sigmoid: out[i] = 1 / (1 + exp(-x[i]))."""
        i = cuda.grid(1)
        if i < out.size:
            out[i] = 1.0 / (1.0 + math.exp(-x[i]))

    print("Kernels defined:", elementwise_relu_kernel, elementwise_sigmoid_kernel)

### 3.2 Kernel with Shared Memory

Shared memory is fast, on-chip memory shared between threads in the same block.
It is useful for reduction operations or when threads need to communicate.

In [None]:
if NUMBA_CUDA_AVAILABLE:
    BLOCK_SIZE = 256

    @cuda.jit
    def block_sum_kernel(x, block_sums):
        """
        Computes the sum of each block's elements using shared memory reduction.
        block_sums[blockIdx.x] = sum of x elements processed by block blockIdx.x.
        """
        shared = cuda.shared.array(BLOCK_SIZE, dtype=numba.float32)

        tx  = cuda.threadIdx.x
        pos = cuda.grid(1)

        # Load into shared memory
        shared[tx] = x[pos] if pos < x.size else 0.0
        cuda.syncthreads()

        # Parallel reduction in shared memory
        stride = BLOCK_SIZE // 2
        while stride > 0:
            if tx < stride:
                shared[tx] += shared[tx + stride]
            cuda.syncthreads()
            stride //= 2

        # Thread 0 writes the block result
        if tx == 0:
            block_sums[cuda.blockIdx.x] = shared[0]

    print("Shared memory kernel defined.")
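
The reduction loop in `block_sum_kernel` can be traced on the CPU. The sketch below is a sequential stand-in, assuming one list slot per thread of a single block; `tree_reduce` is a hypothetical helper, not a Numba API. On the GPU the inner `for` loop is what the threads with `tx < stride` do in parallel, and each pass ends with `cuda.syncthreads()`.

```python
def tree_reduce(values):
    """Sum a list the way the shared-memory loop does (stride halving)."""
    shared = list(values)             # stands in for the shared-memory array
    stride = len(shared) // 2
    while stride > 0:
        # Each "thread" tx < stride adds the element one stride away.
        for tx in range(stride):
            shared[tx] += shared[tx + stride]
        stride //= 2                  # the barrier would sit here on the GPU
    return shared[0]


block = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
total = tree_reduce(block)
print(total)

# A length that is not a power of two silently drops elements:
wrong = tree_reduce([1.0] * 6)
print(wrong)
```

The second call shows why the halving scheme assumes a power-of-two block size, as `BLOCK_SIZE = 256` is: with six elements the loop never folds index 2 back into index 0.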

## 4. `numba_cuda_kernel` – Single-Kernel Wrapper

`numba_cuda_kernel` wraps a single `@cuda.jit` kernel so it can be called with
JAX GPU arrays.  The kernel receives Numba CUDA device arrays (zero-copy from JAX
device memory) and writes results into the output buffer.

**Function signature:**
```python
numba_cuda_kernel(
    kernel,                    # @cuda.jit decorated function
    outs,                      # jax.ShapeDtypeStruct or list thereof
    *,
    grid=None, block=None,     # explicit CUDA launch config
    launch_dims=None,          # OR total threads (auto grid/block)
    threads_per_block=256,     # only used with launch_dims
    shared_mem=0,              # dynamic shared memory bytes
) -> callable
```

The kernel function signature must be:
```python
kernel(input1, input2, ..., output1, output2, ...)
```
Inputs first, then outputs – all as Numba device arrays.

In [None]:
if NUMBA_CUDA_AVAILABLE:
    N = 1024
    x = jnp.linspace(-3.0, 3.0, N, dtype=jnp.float32)

    # Wrap the ReLU kernel
    relu_fn = numba_cuda_kernel(
        elementwise_relu_kernel,
        outs=jax.ShapeDtypeStruct((N,), jnp.float32),
        grid=4,
        block=256,
    )

    result  = relu_fn(x)
    expected = jnp.maximum(x, 0.0)
    print("ReLU max error:", float(jnp.max(jnp.abs(result - expected))))

    # Wrap the sigmoid kernel using launch_dims (auto grid/block)
    sigmoid_fn = numba_cuda_kernel(
        elementwise_sigmoid_kernel,
        outs=jax.ShapeDtypeStruct((N,), jnp.float32),
        launch_dims=N,           # launch exactly N threads
        threads_per_block=128,
    )

    result_sig  = sigmoid_fn(x)
    expected_sig = 1.0 / (1.0 + jnp.exp(-x))
    print("Sigmoid max error:", float(jnp.max(jnp.abs(result_sig - expected_sig))))

### 4.1 Multiple Outputs

In [None]:
if NUMBA_CUDA_AVAILABLE:
    @cuda.jit
    def split_kernel(x, pos_out, neg_out):
        """Split positive and negative parts of x."""
        i = cuda.grid(1)
        if i < x.size:
            v = x[i]
            pos_out[i] = max(v, 0.0)
            neg_out[i] = min(v, 0.0)

    N = 512
    x = jnp.linspace(-2.0, 2.0, N, dtype=jnp.float32)

    split_fn = numba_cuda_kernel(
        split_kernel,
        outs=[
            jax.ShapeDtypeStruct((N,), jnp.float32),  # pos_out
            jax.ShapeDtypeStruct((N,), jnp.float32),  # neg_out
        ],
        launch_dims=N,
    )

    pos, neg = split_fn(x)

    print("pos_out[:5]:", pos[:5])   # max(x, 0)
    print("neg_out[:5]:", neg[:5])   # min(x, 0)
    print("pos + neg == x:", bool(jnp.allclose(pos + neg, x)))

### 4.2 JIT Compatibility

In [None]:
if NUMBA_CUDA_AVAILABLE:
    @cuda.jit
    def add_kernel(x, y, out):
        i = cuda.grid(1)
        if i < out.size:
            out[i] = x[i] + y[i]

    N = 256
    add_fn = numba_cuda_kernel(
        add_kernel,
        outs=jax.ShapeDtypeStruct((N,), jnp.float32),
        launch_dims=N,
    )

    @jax.jit
    def jitted_add(a, b):
        return add_fn(a, b)

    a = jnp.arange(N, dtype=jnp.float32)
    b = jnp.ones(N, dtype=jnp.float32) * 2.0

    r = jitted_add(a, b)
    print("JIT add max error:", float(jnp.max(jnp.abs(r - (a + b)))))

    # Call multiple times (compilation cost is paid only on the first call)
    for _ in range(5):
        r = jitted_add(a, b)
    print("Multiple JIT calls OK:", bool(jnp.allclose(r, a + b)))

## 5. Launch Configuration: `grid` / `block` vs. `launch_dims`

CUDA kernels need a grid/block decomposition specifying how many threads to launch.

**Option A – explicit `grid` and `block`:**
```python
numba_cuda_kernel(kernel, outs=..., grid=8, block=128)
# launches 8 blocks × 128 threads = 1024 total threads
```

**Option B – `launch_dims` (auto-compute):**
```python
numba_cuda_kernel(kernel, outs=..., launch_dims=1024, threads_per_block=256)
# auto: block=256, grid=ceil(1024/256)=4
```

**2D / 3D launches:**
```python
# 2D: launch M×N threads
numba_cuda_kernel(kernel, outs=..., launch_dims=(M, N))
# auto: block=(16,16), grid=(ceil(M/16), ceil(N/16))
```
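
The arithmetic behind the automatic decomposition is ceiling division per axis. The sketch below is a hypothetical helper, `auto_launch_config`, written only to illustrate that arithmetic; it is an assumption about the general idea, not BrainEvent's actual implementation, and the block shape is passed in explicitly rather than chosen automatically.

```python
def auto_launch_config(launch_dims, block):
    """Return (grid, block) tuples that cover launch_dims via ceil division."""
    if isinstance(launch_dims, int):
        launch_dims = (launch_dims,)
    if isinstance(block, int):
        block = (block,)
    if len(launch_dims) != len(block):
        raise ValueError("launch_dims and block must have the same rank")
    # (n + b - 1) // b is integer ceil(n / b)
    grid = tuple((n + b - 1) // b for n, b in zip(launch_dims, block))
    return grid, tuple(block)


cfg_exact = auto_launch_config(1024, 256)            # divides evenly
cfg_ragged = auto_launch_config(1000, 256)           # last block is partly idle
cfg_2d = auto_launch_config((100, 70), (16, 16))     # 2-D launch
print(cfg_exact, cfg_ragged, cfg_2d)
```

When the size does not divide evenly, as in the second call, the grid launches more threads than elements (1024 for 1000 here), which is why every kernel in this tutorial carries a bounds check.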

In [None]:
if NUMBA_CUDA_AVAILABLE:
    @cuda.jit
    def hadamard_kernel(A, B, C):
        """C[i,j] = A[i,j] * B[i,j]  (element-wise product, 2D grid)."""
        i, j = cuda.grid(2)
        if i < C.shape[0] and j < C.shape[1]:
            C[i, j] = A[i, j] * B[i, j]

    M, N = 64, 64
    A = jnp.arange(M * N, dtype=jnp.float32).reshape(M, N)
    B = jnp.ones((M, N), dtype=jnp.float32) * 2.0

    hadamard_fn = numba_cuda_kernel(
        hadamard_kernel,
        outs=jax.ShapeDtypeStruct((M, N), jnp.float32),
        launch_dims=(M, N),      # 2D launch
    )

    C = hadamard_fn(A, B)
    print("2D kernel max error:", float(jnp.max(jnp.abs(C - A * B))))

## 6. `numba_cuda_callable` – Multi-Kernel Wrapper

Sometimes one kernel is not enough.  For example, a reduction may need two passes,
or a pipeline may require a temporary device buffer between stages.
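
As a CPU illustration of the two-pass case, the sketch below sums an array by first producing one partial sum per block and then reducing the partial sums. The names `block_sums_pass` and `two_pass_sum` are hypothetical; on a GPU each pass would be a separate kernel launch and `partial` would be the temporary device buffer between them.

```python
def block_sums_pass(x, block_size):
    """Pass 1: one partial sum per block of block_size elements."""
    n_blocks = (len(x) + block_size - 1) // block_size
    return [
        sum(x[b * block_size:(b + 1) * block_size])
        for b in range(n_blocks)
    ]


def two_pass_sum(x, block_size=4):
    """Sum x in two stages with a temporary buffer in between."""
    partial = block_sums_pass(x, block_size)   # temporary buffer
    return sum(partial)                        # pass 2: reduce the partials


data = list(range(1, 11))                      # 1..10
partials = block_sums_pass(data, 4)
total = two_pass_sum(data, block_size=4)
print(partials, total)
```

The second pass is needed because threads in different blocks cannot synchronize inside a single kernel launch the way threads within one block can.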

`numba_cuda_callable` wraps an **arbitrary Python function** that can:
- Launch multiple `@cuda.jit` kernels
- Allocate temporary device memory with `cuda.device_array`
- Use the XLA-managed CUDA stream (passed as the last argument)

**Required function signature:**
```python
def my_func(input1, input2, ..., output1, output2, ..., stream):
    # input* and output* are Numba CUDA device arrays
    # stream is a Numba CUDA stream from XLA
    ...
```

In [None]:
if NUMBA_CUDA_AVAILABLE:
    # ---- Kernel 1: element-wise square ----
    @cuda.jit
    def square_kernel(x, temp):
        i = cuda.grid(1)
        if i < temp.size:
            temp[i] = x[i] * x[i]

    # ---- Kernel 2: element-wise square root ----
    import math  # imported outside the kernel: Numba cannot compile `import` statements

    @cuda.jit
    def sqrt_kernel(temp, out):
        i = cuda.grid(1)
        if i < out.size:
            out[i] = math.sqrt(temp[i])

    # ---- Multi-kernel callable: |x| = sqrt(x^2) ----
    def abs_via_two_kernels(x, out, stream):
        """
        Compute |x| using two kernels and a temporary buffer.
        Demonstrates multi-kernel pipeline with device allocation.
        """
        n = x.shape[0]
        threads = 256
        blocks  = (n + threads - 1) // threads

        # Temporary buffer on the GPU (freed after this function returns)
        temp = cuda.device_array(n, dtype=np.float32)

        # Launch both kernels on the XLA-managed stream
        square_kernel[blocks, threads, stream](x, temp)
        sqrt_kernel  [blocks, threads, stream](temp, out)

    N = 512
    x = jnp.linspace(-5.0, 5.0, N, dtype=jnp.float32)

    abs_fn = numba_cuda_callable(
        abs_via_two_kernels,
        outs=jax.ShapeDtypeStruct((N,), jnp.float32),
    )

    result   = abs_fn(x)
    expected = jnp.abs(x)
    print("Multi-kernel |x| max error:", float(jnp.max(jnp.abs(result - expected))))
    print("First 5 values :", result[:5])
    print("Expected |x|   :", expected[:5])

## 7. Registering with `XLACustomKernel`

For production use, register your Numba CUDA kernel as a backend of an
`XLACustomKernel` primitive.  This integrates the kernel into JAX's lowering
pipeline and allows mixing with other backends (e.g., a Numba CPU fallback).

The **kernel generator** pattern is the same as for Warp (see Tutorial 6):
it is a Python callable that receives keyword arguments (forwarded from
`primitive.bind`) and returns a concrete kernel function.
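
The generator pattern is an ordinary Python closure. The CPU-only sketch below uses a hypothetical `cpu_leaky_relu_generator` to show the shape of the idea: keyword arguments are read once when the generator runs, and the returned function has those values baked in. It does not use `XLACustomKernel` and makes no claim about how BrainEvent invokes generators internally.

```python
def cpu_leaky_relu_generator(**kwargs):
    """Hypothetical CPU analogue of a kernel generator."""
    # Read configuration once; it is fixed for the lifetime of the kernel.
    alpha = float(kwargs.get("alpha", 0.01))

    def kernel(xs):
        return [v if v > 0.0 else alpha * v for v in xs]

    return kernel


kernel_a = cpu_leaky_relu_generator(alpha=0.1)
kernel_b = cpu_leaky_relu_generator()          # falls back to alpha=0.01
out_a = kernel_a([-2.0, 3.0])
out_b = kernel_b([-2.0, 3.0])
print(out_a, out_b)
```

Because `alpha` is consumed with `float(...)` at generation time, it has to be a concrete Python number at that point rather than a traced JAX value.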

In [None]:
if NUMBA_CUDA_AVAILABLE:
    # -----------------------------------------------------------------------
    # Kernel generator for element-wise leaky ReLU:
    #   out[i] = x[i] if x[i] > 0 else alpha * x[i]
    # 'alpha' is passed at trace time via kwargs.
    # -----------------------------------------------------------------------

    def leaky_relu_numba_cuda_generator(**kwargs):
        out_info = kwargs['outs'][0]
        n        = out_info.shape[0]
        alpha    = float(kwargs.get('alpha', 0.01))

        @cuda.jit
        def leaky_relu_kern(x, out):
            i = cuda.grid(1)
            if i < out.size:
                v = x[i]
                out[i] = v if v > 0.0 else alpha * v

        def kernel(x):
            return numba_cuda_kernel(
                leaky_relu_kern,
                outs=out_info,
                launch_dims=n,
            )(x)

        return kernel

    # Register the primitive
    leaky_relu_op = XLACustomKernel('tutorial_numba_cuda_leaky_relu')
    leaky_relu_op.def_numba_cuda_kernel(leaky_relu_numba_cuda_generator)

    print("Registered:", leaky_relu_op._kernels)

In [None]:
if NUMBA_CUDA_AVAILABLE:
    N = 256
    x = jnp.linspace(-3.0, 3.0, N, dtype=jnp.float32)

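    from functools import partial

    # The generator reads alpha with float(...), so it must be a concrete
    # Python value: mark it static instead of letting jit trace it.
    @partial(jax.jit, static_argnames=("alpha",))
    def jitted_leaky_relu(x, alpha=0.1):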
        return leaky_relu_op(
            x,
            outs=[jax.ShapeDtypeStruct(x.shape, x.dtype)],
            alpha=alpha,
        )[0]

    r = jitted_leaky_relu(x, alpha=0.1)

    expected = jnp.where(x > 0, x, 0.1 * x)
    print("Leaky ReLU max error:", float(jnp.max(jnp.abs(r - expected))))
    print("Values around 0     :", r[N//2 - 3 : N//2 + 3])

## 8. Neuroscience Example: Parallel Spike Threshold Detection

A common operation in spiking neural networks: given membrane potentials `V` and a
threshold `V_th`, detect which neurons fire and compute their reset potentials.

We implement this as two Numba CUDA kernels launched back to back inside one
`numba_cuda_callable`, with the results written to new output arrays:
1. **Detect** spikes: `spikes[i] = (V[i] >= V_th)`
2. **Reset** potentials: `V_out[i] = spikes[i] ? V_rest : V[i]`
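
Before moving to the GPU, the two steps can be written as a plain-Python reference that is easy to check by hand. `lif_step_reference` is a hypothetical helper for illustration only, and it takes the threshold and reset value as scalars rather than one-element arrays.

```python
def lif_step_reference(V, V_th, V_rest):
    """Detect spikes, then reset the potentials of the neurons that fired."""
    spikes = [1 if v >= V_th else 0 for v in V]               # step 1: detect
    V_out = [V_rest if s else v for s, v in zip(spikes, V)]   # step 2: reset
    return spikes, V_out


V_demo = [-72.0, -55.0, -50.5, -60.0]
spikes_demo, V_out_demo = lif_step_reference(V_demo, V_th=-55.0, V_rest=-70.0)
print(spikes_demo)
print(V_out_demo)
```

Note that the comparison is `>=`, so a neuron sitting exactly at threshold (the second entry) counts as firing.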

In [None]:
if NUMBA_CUDA_AVAILABLE:
    @cuda.jit
    def detect_spikes_kernel(V, V_th, spikes):
        """spikes[i] = 1 if V[i] >= V_th[0] else 0."""
        i = cuda.grid(1)
        if i < V.size:
            spikes[i] = 1 if V[i] >= V_th[0] else 0

    @cuda.jit
    def reset_potential_kernel(V, spikes, V_rest, V_out):
        """V_out[i] = V_rest[0] if spikes[i] else V[i]."""
        i = cuda.grid(1)
        if i < V.size:
            V_out[i] = V_rest[0] if spikes[i] else V[i]

    def lif_step(V, V_th, V_rest, spikes_out, V_out, stream):
        """One LIF step: detect spikes and reset membrane potential."""
        n       = V.shape[0]
        threads = 256
        blocks  = (n + threads - 1) // threads

        detect_spikes_kernel[blocks, threads, stream](V, V_th, spikes_out)
        reset_potential_kernel[blocks, threads, stream](V, spikes_out, V_rest, V_out)

    print("LIF step function defined.")

In [None]:
if NUMBA_CUDA_AVAILABLE:
    N_NEURONS = 10_000

    rng = np.random.default_rng(0)
    V      = jnp.array(rng.uniform(-75.0, -50.0, N_NEURONS).astype(np.float32))
    V_th   = jnp.array([-55.0], dtype=jnp.float32)   # threshold (mV)
    V_rest = jnp.array([-70.0], dtype=jnp.float32)   # reset potential (mV)

    lif_fn = numba_cuda_callable(
        lif_step,
        outs=[
            jax.ShapeDtypeStruct((N_NEURONS,), jnp.int32),    # spikes
            jax.ShapeDtypeStruct((N_NEURONS,), jnp.float32),  # V_out
        ],
    )

    spikes, V_out = lif_fn(V, V_th, V_rest)

    # Verify
    expected_spikes = (V >= V_th[0]).astype(jnp.int32)
    expected_V_out  = jnp.where(expected_spikes, V_rest[0], V)

    print(f"Neurons: {N_NEURONS}")
    print(f"Spikes detected: {int(spikes.sum())} / {N_NEURONS}  ({100*float(spikes.mean()):.1f}%)")
    print(f"Spike detection error: {int(jnp.sum(spikes != expected_spikes))}")
    print(f"V_out max error: {float(jnp.max(jnp.abs(V_out - expected_V_out))):.6f} mV")

In [None]:
if NUMBA_CUDA_AVAILABLE:
    import time

    @jax.jit
    def numba_lif_step(V, V_th, V_rest):
        return lif_fn(V, V_th, V_rest)

    @jax.jit
    def jax_lif_step(V, V_th, V_rest):
        spikes = (V >= V_th[0]).astype(jnp.int32)
        V_out  = jnp.where(spikes, V_rest[0], V)
        return spikes, V_out

    # Warm up
    jax.block_until_ready(numba_lif_step(V, V_th, V_rest))
    jax.block_until_ready(jax_lif_step(V, V_th, V_rest))

    N_TRIALS = 500

    t0 = time.perf_counter()   # monotonic, high-resolution timer
    for _ in range(N_TRIALS):
        jax.block_until_ready(numba_lif_step(V, V_th, V_rest))
    numba_time = (time.perf_counter() - t0) / N_TRIALS * 1000

    t0 = time.perf_counter()
    for _ in range(N_TRIALS):
        jax.block_until_ready(jax_lif_step(V, V_th, V_rest))
    jax_time = (time.perf_counter() - t0) / N_TRIALS * 1000

    print(f"Numba CUDA LIF step : {numba_time:.3f} ms")
    print(f"JAX native LIF step : {jax_time:.3f} ms")

## 9. Performance Tips

### 9.1 Choose the right launch configuration

- A warp is 32 threads; use block sizes that are multiples of 32 (128, 256, 512).
- Very small blocks leave the GPU underutilized; very large blocks can reduce occupancy
  when each thread needs many registers or a lot of shared memory.
- Use `launch_dims` when one thread per output element is enough (1-D, 2-D, or 3-D);
  specify explicit `grid`/`block` for fine control.

### 9.2 Minimize thread divergence

Threads in the same warp execute in lock-step. When threads in a warp take different
sides of a conditional branch (divergence), the warp executes each path in turn with the
other threads masked off, so the paths are serialized. Where possible, arrange data so
threads in the same warp take the same branch.
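
A simplified way to reason about divergence is to count the warps in which the branch condition is neither all true nor all false. The sketch below does this on the CPU for two data layouts. `count_divergent_warps` is a hypothetical helper, and the model is deliberately coarse: it ignores compiler predication and the scheduling details of real hardware.

```python
WARP_SIZE = 32


def count_divergent_warps(predicate_per_thread, warp_size=WARP_SIZE):
    """Count warps whose threads disagree on the branch condition."""
    divergent = 0
    for start in range(0, len(predicate_per_thread), warp_size):
        warp = predicate_per_thread[start:start + warp_size]
        if any(warp) and not all(warp):    # mixed True/False inside one warp
            divergent += 1
    return divergent


n = 128                                           # 4 warps of 32 threads
interleaved = [i % 2 == 0 for i in range(n)]      # branch alternates per thread
grouped = sorted(interleaved)                     # all False first, then all True
n_div_interleaved = count_divergent_warps(interleaved)
n_div_grouped = count_divergent_warps(grouped)
print(n_div_interleaved, n_div_grouped)
```

Both layouts contain the same number of true and false conditions; only their arrangement changes whether any warp has to execute both paths.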

### 9.3 Use shared memory for reuse

If multiple threads access the same data, load it into shared memory first and
synchronize with `cuda.syncthreads()` before reading.

### 9.4 Cache the wrapped callable

Each call to `numba_cuda_kernel` / `numba_cuda_callable` registers a new FFI target.
Create the wrapped callable **once** at module level and reuse it.

In [None]:
if NUMBA_CUDA_AVAILABLE:
    # Good pattern: create the callable once and reuse
    import math  # imported outside the kernel: Numba cannot compile `import` statements

    @cuda.jit
    def exp_decay_kernel(x, decay, out):
        i = cuda.grid(1)
        if i < out.size:
            out[i] = x[i] * math.exp(-decay[0])

    N = 1024
    # Create once at definition time
    _exp_decay_fn = numba_cuda_kernel(
        exp_decay_kernel,
        outs=jax.ShapeDtypeStruct((N,), jnp.float32),
        launch_dims=N,
    )

    @jax.jit
    def apply_exp_decay(x, decay):
        return _exp_decay_fn(x, decay)

    x     = jnp.ones(N, dtype=jnp.float32)
    decay = jnp.array([0.1], dtype=jnp.float32)

    r = apply_exp_decay(x, decay)
    print("Exp decay result:", float(r[0]), "| Expected:", float(np.exp(-0.1)))
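
When the array size is not known up front, one option is to memoize the factory so that each distinct size creates its wrapper only once. The sketch below shows the caching pattern on the CPU. `make_wrapped_callable` is a hypothetical stand-in for a call such as `numba_cuda_kernel(..., launch_dims=n)`, and the `registrations` list merely simulates the cost of registering an FFI target.

```python
from functools import lru_cache

registrations = []   # simulates FFI target registrations


def make_wrapped_callable(n):
    """Hypothetical stand-in for building a wrapped kernel of size n."""
    registrations.append(n)          # expensive step we want to do once

    def fn(x):
        return [v * 2.0 for v in x[:n]]

    return fn


@lru_cache(maxsize=None)
def get_wrapped_callable(n):
    """Return the cached wrapper for size n, creating it on first use."""
    return make_wrapped_callable(n)


for _ in range(3):
    f = get_wrapped_callable(1024)   # created on the first iteration only
g = get_wrapped_callable(2048)       # a new size creates a second wrapper
print(len(registrations))
```

The cache key here is just the size; a real cache would likely also need the dtype and any other argument that changes the wrapped kernel.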

## 10. Summary

In this tutorial we covered:

1. **`@cuda.jit`** – Write GPU kernels in Python; use `cuda.grid(ndim)` for thread indices
   and bounds-check with `if i < size`.
2. **`numba_cuda_kernel`** – Single-kernel JAX wrapper.  Specify launch config via
   `(grid, block)` for explicit control or `launch_dims` for automatic decomposition.
   Supports 1-D, 2-D, and 3-D launches.
3. **`numba_cuda_callable`** – Multi-kernel JAX wrapper.  Your Python function receives
   Numba device arrays and the XLA-managed CUDA stream; it can launch multiple kernels
   and allocate temporary device memory.
4. **`XLACustomKernel.def_numba_cuda_kernel`** – Register a kernel generator as the GPU
   backend of a multi-backend custom JAX primitive.
5. **Neuroscience application** – LIF spike detection and reset implemented as a
   two-kernel callable, showing how to chain several kernels inside one JAX operation.

## Next Steps

- **Tutorial 6**: Custom GPU operators with Warp
- **Tutorial 8**: Custom CPU operators with Numba (`@numba.njit`)

## References

- [Numba CUDA documentation](https://numba.readthedocs.io/en/stable/cuda/index.html)
- [BrainEvent GitHub](https://github.com/chaobrain/brainevent)
- [JAX FFI documentation](https://jax.readthedocs.io/en/latest/ffi.html)