# Tutorial 6: Custom GPU Operators with Warp

This tutorial shows how to write custom GPU kernels using **NVIDIA Warp** and integrate them
into the BrainEvent / JAX ecosystem.

[NVIDIA Warp](https://github.com/NVIDIA/warp) is a Python framework for authoring high-performance
GPU kernels. Kernels are written in a Python-like syntax, JIT-compiled to CUDA PTX,
and can be called from JAX through the experimental `warp.jax_experimental.jax_kernel` wrapper.

## Contents
1. Why Warp?
2. Installation and Imports
3. Writing Your First Warp Kernel
4. Type Annotations – `jaxinfo_to_warpinfo` / `jaxtype_to_warptype`
5. Kernel Generators – Dynamic Kernel Construction
6. In-place (Accumulation) vs. Pure-output Patterns
7. Registering Kernels with `XLACustomKernel`
8. Neuroscience Example: Sparse Synaptic Input Accumulation
9. Multiple Backends with `XLACustomKernel`
10. Summary

## 1. Why Warp?

| Feature | Warp | Raw CUDA C++ |
|---------|------|--------------|
| Language | Python-like syntax | C++ |
| Compilation | Automatic JIT | Manual |
| JAX integration | Built-in (`jax_kernel`) | Manual XLA FFI |
| Autodiff | Limited (scalar ops) | Manual |
| Best for | Custom GPU ops in Python | Maximum control |

Warp is the recommended path when you want GPU acceleration without leaving Python.
BrainEvent's `XLACustomKernel` infrastructure makes it straightforward to register a Warp kernel
as a backend for any custom JAX primitive.

**Requirements:**
- NVIDIA GPU with CUDA
- `pip install warp-lang` (installs as `import warp`)
- JAX with GPU support (`pip install -U "jax[cuda12]"`; the quotes stop shells such as zsh from expanding the brackets)

## 2. Installation and Imports

In [None]:
# Install if needed
# !pip install warp-lang -U
# !pip install brainevent[cuda12] -U

import jax
import jax.numpy as jnp
import numpy as np

import brainevent
from brainevent import XLACustomKernel, jaxinfo_to_warpinfo, jaxtype_to_warptype

print(f"JAX version    : {jax.__version__}")
print(f"JAX backend    : {jax.default_backend()}")
print(f"BrainEvent     : {brainevent.__version__}")

try:
    import warp
    from warp.jax_experimental import jax_kernel
    warp.config.quiet = True
    print(f"Warp version   : {warp.__version__}")
    WARP_AVAILABLE = True
except ImportError:
    print("Warp not installed. Run: pip install warp-lang")
    WARP_AVAILABLE = False

## 3. Writing Your First Warp Kernel

A Warp kernel is a Python function decorated with `@warp.kernel`. Key rules:
- **Only Warp types as arguments** – scalars, vectors/matrices, structs and arrays; no Python lists, dicts or objects
- **Thread index** obtained via `warp.tid()` (replaces `blockIdx.x * blockDim.x + threadIdx.x` in CUDA C)
- **Every argument must be type-annotated**; arrays use `warp.array(dtype=..., ndim=...)`
- The kernel body runs **once per thread**, so you typically launch one thread per element
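The "one thread per element" model can be emulated on the CPU to build intuition. The sketch below is a hypothetical illustration and not part of Warp. `emulate_launch` calls the kernel body once for every thread index, passing that index explicitly, which is the role `warp.tid()` plays on the GPU. On real hardware the iterations run concurrently and in no guaranteed order.

```python
import numpy as np

# Hypothetical CPU emulation of a 1-D Warp launch (not Warp's API).
# The kernel body receives its thread index explicitly instead of calling warp.tid().
def emulate_launch(body, dim, *arrays):
    for tid in range(dim):          # on a GPU these iterations run concurrently
        body(tid, *arrays)

def relu_body(tid, x, out):
    # One "thread" handles exactly one element, like out[i] = max(x[i], 0).
    out[tid] = max(x[tid], np.float32(0.0))

x = np.linspace(-2.0, 2.0, 8, dtype=np.float32)
out = np.empty_like(x)
emulate_launch(relu_body, x.size, x, out)
```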

### 3.1 Element-wise ReLU

In [None]:
if WARP_AVAILABLE:
    @warp.kernel
    def relu_kernel(
        x:   warp.array(dtype=warp.float32, ndim=1),
        out: warp.array(dtype=warp.float32, ndim=1),
    ):
        i = warp.tid()           # thread index = element index
        out[i] = warp.max(x[i], warp.float32(0.0))

    print("relu_kernel defined successfully")
    print(f"Kernel type: {type(relu_kernel)}")

### 3.2 Calling the Kernel via `jax_kernel`

`jax_kernel` wraps a Warp kernel so it can be called with JAX arrays.

**Signature:**
```python
fn = jax_kernel(
    warp_kernel,
    launch_dims=[n],         # total threads to launch per dimension
    num_outputs=1,           # how many output arrays the kernel writes
    output_dims={'out': (n,)} # shape of each output (allocated by Warp)
)
result = fn(x)  # pass only input arrays; outputs are returned
```

There are two output modes:
- **`output_dims`** – Warp allocates the output buffer; you only pass inputs.
- **`in_out_argnames`** – You pass a pre-allocated (e.g., `jnp.zeros`) buffer; Warp writes into it.
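The two calling conventions differ in who owns the output buffer. The following is a minimal CPU sketch of that difference. The helpers `emulate_output_dims` and `emulate_in_out` are hypothetical stand-ins, not Warp functions, and they assume the output is the kernel's last parameter. In the first mode the wrapper allocates an output whose contents are unspecified until the kernel writes them. In the second mode the caller's buffer provides the starting values, and a fresh array is returned, matching JAX's immutable arrays.

```python
import numpy as np

# Hypothetical CPU stand-ins for the two jax_kernel output modes (not Warp's API).
def emulate_output_dims(body, launch_dim, out_shape, dtype=np.float32):
    """Like output_dims: the wrapper allocates the output; the caller passes inputs only."""
    def call(*inputs):
        out = np.empty(out_shape, dtype=dtype)  # contents unspecified until written
        for tid in range(launch_dim):
            body(tid, *inputs, out)
        return (out,)
    return call

def emulate_in_out(body, launch_dim):
    """Like in_out_argnames: the caller passes the initial buffer as the last argument."""
    def call(*args):
        *inputs, initial = args
        acc = np.array(initial, copy=True)      # caller's array is left untouched
        for tid in range(launch_dim):
            body(tid, *inputs, acc)
        return (acc,)
    return call

def add_body(tid, a, b, out):
    out[tid] = a[tid] + b[tid]                  # writes every element exactly once

def accumulate_body(tid, a, acc):
    acc[tid] += a[tid]                          # reads the initial value, then updates it

a = np.arange(4, dtype=np.float32)
b = np.ones(4, dtype=np.float32)
(summed,) = emulate_output_dims(add_body, 4, (4,))(a, b)

init = np.full(4, 10.0, dtype=np.float32)
(accumulated,) = emulate_in_out(accumulate_body, 4)(a, init)
```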

In [None]:
if WARP_AVAILABLE:
    N = 1024
    x = jnp.linspace(-2.0, 2.0, N, dtype=jnp.float32)

    # Build the JAX-callable wrapper
    relu_fn = jax_kernel(
        relu_kernel,
        launch_dims=[N],
        num_outputs=1,
        output_dims={'out': (N,)},
    )

    # Call it – returns a tuple of output arrays
    (result,) = relu_fn(x)

    # Verify against JAX reference
    expected = jnp.maximum(x, 0.0)
    print("Max error:", float(jnp.max(jnp.abs(result - expected))))
    print("First 8 values:", result[:8])

## 4. Type Annotations – `jaxinfo_to_warpinfo` / `jaxtype_to_warptype`

When embedding a Warp kernel inside a **kernel generator** (a function that receives
shape/dtype information at trace time), you need to create the Warp type annotations
dynamically.  BrainEvent provides two helpers:

```python
from brainevent import jaxinfo_to_warpinfo, jaxtype_to_warptype

# Convert jax.ShapeDtypeStruct  ->  warp.array(dtype=..., ndim=...)
warp_arr_type = jaxinfo_to_warpinfo(jax.ShapeDtypeStruct((1024,), jnp.float32))

# Convert numpy/JAX dtype  ->  warp scalar type
warp_scalar_type = jaxtype_to_warptype(jnp.float32)  # -> warp.float32
```

These utilities support: `float16`, `float32`, `float64`, `int8`–`int64`, `uint8`–`uint64`, `bool`.
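It helps to see what information these conversions actually need. The sketch below is hypothetical: it is not BrainEvent's implementation, and it returns Warp type *names* as strings so it runs without Warp installed. A scalar conversion is a dtype lookup. An array annotation keeps only the element type and the rank (`len(shape)`), because Warp array types are parameterised by `dtype` and `ndim`, not by concrete extents.

```python
import numpy as np

# Hypothetical lookup table; Warp type names are strings so this runs without Warp.
_WARP_SCALAR_NAMES = {
    np.dtype(np.bool_): "warp.bool",
    np.dtype(np.float16): "warp.float16",
    np.dtype(np.float32): "warp.float32",
    np.dtype(np.float64): "warp.float64",
    np.dtype(np.int8): "warp.int8",
    np.dtype(np.int16): "warp.int16",
    np.dtype(np.int32): "warp.int32",
    np.dtype(np.int64): "warp.int64",
    np.dtype(np.uint8): "warp.uint8",
    np.dtype(np.uint16): "warp.uint16",
    np.dtype(np.uint32): "warp.uint32",
    np.dtype(np.uint64): "warp.uint64",
}

def sketch_scalar_type(dtype):
    """dtype -> Warp scalar type name (plays the role of jaxtype_to_warptype)."""
    return _WARP_SCALAR_NAMES[np.dtype(dtype)]

def sketch_array_type(shape, dtype):
    """(shape, dtype) -> array annotation (plays the role of jaxinfo_to_warpinfo).

    Only the element type and the rank are encoded; the size of each
    dimension comes from the actual array when the kernel is launched.
    """
    return f"warp.array(dtype={sketch_scalar_type(dtype)}, ndim={len(shape)})"
```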

In [None]:
if WARP_AVAILABLE:

    for jax_dtype in [jnp.float32, jnp.float64, jnp.int32, jnp.bool_]:
        warp_type = jaxtype_to_warptype(jax_dtype)
        info = jax.ShapeDtypeStruct((8, 4), jax_dtype)
        warp_arr = jaxinfo_to_warpinfo(info)
        print(f"  jnp.{jax_dtype.__name__:<8} -> warp scalar: {warp_type}  |  warp array: {warp_arr}")

## 5. Kernel Generators – Dynamic Kernel Construction

When integrating with `XLACustomKernel`, kernels are not defined statically.
Instead you define a **kernel generator**: a plain Python function that receives
shape/dtype keyword arguments (forwarded from `primitive.bind`) and returns a
callable that runs the actual computation.

This pattern allows the same generator to handle different dtypes and shapes
without re-registering the primitive.
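The generator pattern itself does not depend on Warp. In this minimal NumPy sketch, `relu_generator` and its `out_shape` / `out_dtype` keywords are hypothetical stand-ins for the metadata BrainEvent forwards. The generator reads that metadata once, builds a dtype-specific constant, and returns a closure specialised to it. Calling the same generator with different metadata yields independent kernels without touching any registration.

```python
import numpy as np

def relu_generator(**kwargs):
    """Hypothetical generator: reads trace-time metadata, returns a specialised callable."""
    out_shape = kwargs["out_shape"]
    out_dtype = np.dtype(kwargs["out_dtype"])
    zero = out_dtype.type(0)            # constant built once, in the requested dtype

    def kernel(x):
        if x.shape != out_shape:
            raise ValueError(f"kernel was generated for shape {out_shape}, got {x.shape}")
        return np.maximum(x, zero).astype(out_dtype, copy=False)

    return kernel

# One generator, two specialisations; nothing is re-registered.
relu_f32 = relu_generator(out_shape=(4,), out_dtype=np.float32)
relu_f64 = relu_generator(out_shape=(3,), out_dtype=np.float64)
```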

### 5.1 Template for a Warp Kernel Generator

In [None]:
if WARP_AVAILABLE:
    def my_relu_kernel_generator(**kwargs):
        """
        Kernel generator for element-wise ReLU.

        kwargs contains whatever was passed to XLACustomKernel.__call__,
        e.g. kwargs['outs'] = [jax.ShapeDtypeStruct(shape, dtype)]
        """
        # --- 1. Extract shape/dtype information from kwargs ---------------
        out_info = kwargs['outs'][0]            # jax.ShapeDtypeStruct
        n = out_info.shape[0]

        # --- 2. Build Warp type annotations dynamically -------------------
        x_warp_type   = jaxinfo_to_warpinfo(out_info)   # same dtype for input
        out_warp_type = jaxinfo_to_warpinfo(out_info)

        # --- 3. Define the @warp.kernel with dynamic type annotations -----
        @warp.kernel
        def relu_kern(
            x:   x_warp_type,
            out: out_warp_type,
        ):
            i = warp.tid()
            out[i] = warp.max(x[i], out_warp_type.dtype(0.0))

        # --- 4. Return the concrete kernel function -----------------------
        def kernel(x):
            fn = jax_kernel(
                relu_kern,
                launch_dims=[n],
                num_outputs=1,
                output_dims={'out': (n,)},
            )
            return fn(x)

        return kernel

    print("Kernel generator defined.")

## 6. In-place (Accumulation) vs. Pure-output Patterns

Many neuroscience operations **scatter-add** values into an output buffer
(e.g., synaptic current accumulation). Warp handles this via atomic operations
and the `in_out_argnames` mechanism.

### 6.1 Pure output (Warp allocates)

```python
fn = jax_kernel(kernel, launch_dims=[N], num_outputs=1, output_dims={'out': (N,)})
result, = fn(x)  # only pass inputs
```

### 6.2 In-place / accumulation (caller provides buffer)

```python
fn = jax_kernel(kernel, launch_dims=[M], num_outputs=1, in_out_argnames=['acc'])
result, = fn(x, jnp.zeros((N,), jnp.float32))  # arguments in kernel-signature order; 'acc' supplies the initial values
```

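The `in_out_argnames` list names the arguments that are both read and written. Because the caller supplies that buffer,
the kernel starts from known initial values (usually zeros) and `warp.atomic_add` accumulates on top of them.
Buffers allocated through `output_dims` should not be relied on to be zero-initialized, so that mode suits kernels
that write every output element exactly once.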
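Accumulation needs atomics because several sources can target the same slot. NumPy shows the same hazard on the CPU, with one caveat: NumPy's behaviour is deterministic, whereas a GPU race is not, so this is an analogy only. With fancy-index `+=`, duplicate indices collapse into a single update. `np.add.at` performs unbuffered accumulation, so every contribution lands, which is the effect `warp.atomic_add` guarantees on the GPU.

```python
import numpy as np

targets = np.array([0, 1, 1, 3, 1])      # slot 1 is targeted three times
values = np.ones(5, dtype=np.float32)

# Buffered read-modify-write: duplicate indices overwrite each other,
# so only one of the three updates to slot 1 survives.
racy = np.zeros(4, dtype=np.float32)
racy[targets] += values

# Unbuffered accumulation: every contribution is added, as atomic_add would do.
safe = np.zeros(4, dtype=np.float32)
np.add.at(safe, targets, values)
```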

In [None]:
if WARP_AVAILABLE:
    # Scatter-add example: for every source element i,
    # add values[i] * scale to acc[targets[i]].

    N_SRC = 512   # source elements
    N_DST = 128   # destination (output) size

    @warp.kernel
    def scatter_add_kernel(
        values:  warp.array(dtype=warp.float32, ndim=1),
        targets: warp.array(dtype=warp.int32,   ndim=1),
        scale:   warp.array(dtype=warp.float32, ndim=1),  # 1-element array
        acc:     warp.array(dtype=warp.float32, ndim=1),  # in-place output
    ):
        i = warp.tid()
        # Atomic add is thread-safe – multiple threads may target the same slot
        warp.atomic_add(acc, targets[i], values[i] * scale[0])

    # Create test data
    rng = np.random.default_rng(0)
    values  = jnp.array(rng.random(N_SRC).astype(np.float32))
    targets = jnp.array(rng.integers(0, N_DST, N_SRC).astype(np.int32))
    scale   = jnp.array([2.0], dtype=jnp.float32)

    # Build callable with in-place accumulator
    scatter_fn = jax_kernel(
        scatter_add_kernel,
        launch_dims=[N_SRC],
        num_outputs=1,
        in_out_argnames=['acc'],       # 'acc' is both input and output
    )

    # Run: pass (values, targets, scale, initial_acc)
    init_acc = jnp.zeros(N_DST, dtype=jnp.float32)
    (result,) = scatter_fn(values, targets, scale, init_acc)

    # Verify with NumPy reference
    ref = np.zeros(N_DST, dtype=np.float32)
    np.add.at(ref, np.array(targets), np.array(values) * 2.0)
    print("Scatter-add max error:", float(jnp.max(jnp.abs(result - jnp.array(ref)))))
    print("Result sum:", float(result.sum()), "| Expected:", float(ref.sum()))

## 7. Registering Kernels with `XLACustomKernel`

`XLACustomKernel` is BrainEvent's central abstraction for multi-backend custom
JAX primitives. It lets you register different backend implementations
(Warp, Numba, Pallas, …) for the same logical operation, then dispatch to the
right one at runtime.

**Workflow:**
1. Create an `XLACustomKernel` instance with a unique name
2. Register your Warp kernel generator via `def_warp_kernel()`
3. (Optionally) register a CPU fallback via `def_numba_kernel()`
4. Call the instance like a function, passing the inputs plus the output specs: `op(*inputs, outs=[jax.ShapeDtypeStruct(shape, dtype)])`
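The register-then-dispatch idea behind this workflow can be sketched in a few lines of plain Python. `ToyMultiBackendOp` is hypothetical and heavily simplified. It is not BrainEvent's implementation and does not create a JAX primitive. It stores kernel generators keyed by `(platform, backend)`, and at call time it picks one, specialises it to the output specs, and runs it. Here `(shape, dtype)` tuples stand in for `jax.ShapeDtypeStruct`.

```python
import numpy as np

class ToyMultiBackendOp:
    """Hypothetical, simplified stand-in for a multi-backend primitive (not BrainEvent's)."""

    def __init__(self, name):
        self.name = name
        self.generators = {}   # (platform, backend) -> kernel generator
        self.defaults = {}     # platform -> backend used when none is requested

    def register(self, platform, backend, generator):
        self.generators[(platform, backend)] = generator
        self.defaults.setdefault(platform, backend)   # first registration wins

    def __call__(self, *args, outs, platform, backend=None, **kwargs):
        backend = backend or self.defaults[platform]
        generator = self.generators[(platform, backend)]
        kernel = generator(outs=outs, **kwargs)       # specialise to these output specs
        return kernel(*args)


def numpy_scale_add_generator(**kwargs):
    shape, dtype = kwargs["outs"][0]                  # (shape, dtype) pair

    def kernel(a, b, c):
        return [np.asarray(a * b + c, dtype=dtype).reshape(shape)]

    return kernel


op = ToyMultiBackendOp("toy_scale_add")
op.register("cpu", "numpy", numpy_scale_add_generator)

a = np.arange(4, dtype=np.float32)
b = np.full(4, 2.0, dtype=np.float32)
c = np.ones(4, dtype=np.float32)
(result,) = op(a, b, c, outs=[((4,), np.float32)], platform="cpu")
```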

In [None]:
if WARP_AVAILABLE:
    # -----------------------------------------------------------------------
    # Step 1: Define the kernel generator
    # -----------------------------------------------------------------------
    def warp_scale_add_generator(**kwargs):
        """Element-wise: out[i] = a[i] * b[i] + c[i]"""
        out_info = kwargs['outs'][0]
        n        = out_info.shape[0]
        t        = jaxinfo_to_warpinfo(out_info)

        @warp.kernel
        def kern(
            a:   t,
            b:   t,
            c:   t,
            out: t,
        ):
            i = warp.tid()
            out[i] = a[i] * b[i] + c[i]

        def run(a, b, c):
            fn = jax_kernel(kern, launch_dims=[n], num_outputs=1,
                            output_dims={'out': (n,)})
            return fn(a, b, c)

        return run

    # -----------------------------------------------------------------------
    # Step 2: Create and register the primitive
    # -----------------------------------------------------------------------
    scale_add_op = XLACustomKernel('tutorial_warp_scale_add')
    scale_add_op.def_warp_kernel(warp_scale_add_generator)

    print("Registered backends:", scale_add_op._kernels)
    print("Default backends  :", scale_add_op.defaults)

In [None]:
if WARP_AVAILABLE:
    # -----------------------------------------------------------------------
    # Step 3: Call the primitive
    # -----------------------------------------------------------------------
    N = 256
    a = jnp.arange(N, dtype=jnp.float32)
    b = jnp.full(N, 2.0, dtype=jnp.float32)
    c = jnp.ones(N, dtype=jnp.float32)

    out_spec = jax.ShapeDtypeStruct((N,), jnp.float32)
    result   = scale_add_op(a, b, c, outs=[out_spec])

    expected = a * b + c
    print("Max error:", float(jnp.max(jnp.abs(result[0] - expected))))
    print("First 5  :", result[0][:5])

    # -----------------------------------------------------------------------
    # Step 4: Use inside jax.jit (the primitive is JIT-compatible)
    # -----------------------------------------------------------------------
    @jax.jit
    def jitted_op(a, b, c):
        return scale_add_op(a, b, c, outs=[jax.ShapeDtypeStruct(a.shape, a.dtype)])[0]

    r = jitted_op(a, b, c)
    print("JIT result matches:", bool(jnp.allclose(r, expected)))

## 8. Neuroscience Example: Sparse Synaptic Input Accumulation

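A classic operation in spiking neural network simulation:
given a binary spike vector `spikes` (shape `[N_pre]`) and a CSR weight matrix
`(data, indices, indptr)` with `N_pre` rows, compute the postsynaptic current

$$I_{\text{post}}[j] = \sum_{i:\, \text{spikes}[i]>0} \;\sum_{k=\text{indptr}[i]}^{\text{indptr}[i+1]-1} \text{data}[k]\, \mathbb{1}\big[\text{indices}[k]=j\big]$$

Row $i$ stores its target columns in `indices[indptr[i]:indptr[i+1]]`, so every spiking presynaptic neuron adds its outgoing weights to the currents of its targets. In this example all synapses share one weight $w$, so `data` holds a single value and $\text{data}[k]$ is replaced by $w$.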

We implement this with a Warp kernel that:
1. Iterates over pre-synaptic neurons in parallel
2. Skips silent neurons (no spike)
3. Atomically accumulates weights into the postsynaptic current buffer
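Before trusting a GPU kernel, it is useful to have a sequential reference that follows the same three steps. The helper `csr_binary_mv_reference` below is a hypothetical NumPy sketch for the homogeneous-weight case. The inner `+=` is safe here because the loop is sequential; the Warp kernel needs `warp.atomic_add` instead. The tiny network is cross-checked against the dense product `spikes @ W`.

```python
import numpy as np

def csr_binary_mv_reference(weight, indices, indptr, spikes, n_post):
    """Sequential CPU reference for the three steps above, with one shared weight."""
    post = np.zeros(n_post, dtype=np.float32)
    n_pre = len(indptr) - 1
    for i in range(n_pre):                    # 1. visit each pre-synaptic neuron
        if not spikes[i]:                     # 2. silent neurons contribute nothing
            continue
        for k in range(indptr[i], indptr[i + 1]):
            post[indices[k]] += weight        # 3. accumulate (warp.atomic_add on the GPU)
    return post

# Tiny hand-made network: 3 pre-synaptic neurons, 4 post-synaptic neurons.
indptr = np.array([0, 2, 4, 5])            # row i owns indices[indptr[i]:indptr[i+1]]
indices = np.array([0, 2, 1, 2, 3])        # post-synaptic targets
spikes = np.array([True, True, False])     # neuron 2 is silent
current = csr_binary_mv_reference(0.5, indices, indptr, spikes, n_post=4)

# Cross-check against the dense formulation spikes @ W.
dense = np.zeros((3, 4), dtype=np.float32)
for i in range(3):
    dense[i, indices[indptr[i]:indptr[i + 1]]] = 0.5
expected = spikes.astype(np.float32) @ dense
```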

In [None]:
if WARP_AVAILABLE:
    def csr_binary_mv_warp_generator(**kwargs):
        """
        Kernel generator for the binary-spike-vector × CSR-matrix product (spikes @ W).
        Signature: kernel(weights, indices, indptr, spikes) -> post_current
        """
        weight_info  = kwargs['weight_info']
        spike_info   = kwargs['spike_info']
        indices_info = kwargs['indices_info']
        indptr_info  = kwargs['indptr_info']
        n_pre        = indptr_info.shape[0] - 1
        n_post       = kwargs['n_post']
        out_dtype    = kwargs['outs'][0].dtype

        # Build Warp type descriptors
        w_type      = jaxinfo_to_warpinfo(weight_info)
        idx_type    = jaxinfo_to_warpinfo(indices_info)
        indptr_type = jaxinfo_to_warpinfo(indptr_info)
        spk_type    = jaxinfo_to_warpinfo(spike_info)
        out_type    = warp.array(dtype=jaxtype_to_warptype(out_dtype), ndim=1)

        @warp.kernel
        def mv_kern(
            weights: w_type,
            indices: idx_type,
            indptr:  indptr_type,
            spikes:  spk_type,
            posts:   out_type,
        ):
            i = warp.tid()              # one thread per pre-synaptic neuron
            if spikes[i]:               # skip silent neurons
                w = weights[0]          # scalar weight (homogeneous)
                for j in range(indptr[i], indptr[i + 1]):
                    warp.atomic_add(posts, indices[j], w)

        def kernel(weights, indices, indptr, spikes):
            fn = jax_kernel(
                mv_kern,
                launch_dims=[n_pre],
                num_outputs=1,
                in_out_argnames=['posts'],
            )
            return fn(weights, indices, indptr, spikes,
                      jnp.zeros(n_post, dtype=out_dtype))

        return kernel

    print("CSR binary MV kernel generator defined.")

In [None]:
if WARP_AVAILABLE:
    import scipy.sparse as sp

    # Build a random CSR connectivity matrix
    N_PRE  = 1000
    N_POST = 500
    PROB   = 0.05   # 5 % connection probability
    W      = 0.1    # homogeneous weight

    rng = np.random.default_rng(42)
    dense = (rng.random((N_PRE, N_POST)) < PROB).astype(np.float32) * W
    csr   = sp.csr_matrix(dense)

    data   = jnp.array([W], dtype=jnp.float32)          # scalar weight
    indices = jnp.array(csr.indices, dtype=jnp.int32)
    indptr  = jnp.array(csr.indptr,  dtype=jnp.int32)

    # Generate binary spikes (10 % firing rate)
    spikes = jnp.array(rng.random(N_PRE) < 0.10, dtype=jnp.bool_)

    # Register the primitive
    csr_mv_op = XLACustomKernel('tutorial_warp_csr_mv')
    csr_mv_op.def_warp_kernel(csr_binary_mv_warp_generator)

    # Build output spec and call
    out_spec = jax.ShapeDtypeStruct((N_POST,), jnp.float32)

    result = csr_mv_op(
        data, indices, indptr, spikes,
        outs=[out_spec],
        # extra kwargs forwarded to the generator:
        weight_info  = jax.ShapeDtypeStruct(data.shape,    data.dtype),
        spike_info   = jax.ShapeDtypeStruct(spikes.shape,  spikes.dtype),
        indices_info = jax.ShapeDtypeStruct(indices.shape, indices.dtype),
        indptr_info  = jax.ShapeDtypeStruct(indptr.shape,  indptr.dtype),
        n_post       = N_POST,
    )

    # Reference: dense matmul
    spikes_f  = spikes.astype(jnp.float32)
    expected  = spikes_f @ jnp.array(dense)

    print(f"Network: {N_PRE} pre -> {N_POST} post  |  {int(spikes.sum())} spikes")
    print(f"Max error vs dense reference: {float(jnp.max(jnp.abs(result[0] - expected))):.6f}")
    print(f"Post current range: [{float(result[0].min()):.3f}, {float(result[0].max()):.3f}]")

In [None]:
if WARP_AVAILABLE:
    import time

    @jax.jit
    def warp_mv(data, indices, indptr, spikes):
        return csr_mv_op(
            data, indices, indptr, spikes,
            outs=[out_spec],
            weight_info  = jax.ShapeDtypeStruct(data.shape,    data.dtype),
            spike_info   = jax.ShapeDtypeStruct(spikes.shape,  spikes.dtype),
            indices_info = jax.ShapeDtypeStruct(indices.shape, indices.dtype),
            indptr_info  = jax.ShapeDtypeStruct(indptr.shape,  indptr.dtype),
            n_post       = N_POST,
        )[0]

    @jax.jit
    def dense_mv(spikes_f, dense):
        return spikes_f @ dense

    # Move the dense matrix to the device once, outside the timed loops
    dense_j = jnp.array(dense)

    # Warm up (triggers JIT compilation for both functions)
    jax.block_until_ready(warp_mv(data, indices, indptr, spikes))
    jax.block_until_ready(dense_mv(spikes_f, dense_j))

    N_TRIALS = 200
    t0 = time.perf_counter()
    for _ in range(N_TRIALS):
        jax.block_until_ready(warp_mv(data, indices, indptr, spikes))
    warp_time = (time.perf_counter() - t0) / N_TRIALS * 1000

    t0 = time.perf_counter()
    for _ in range(N_TRIALS):
        jax.block_until_ready(dense_mv(spikes_f, dense_j))
    dense_time = (time.perf_counter() - t0) / N_TRIALS * 1000

    print(f"Warp sparse kernel : {warp_time:.3f} ms")
    print(f"JAX dense matmul   : {dense_time:.3f} ms")
    print(f"Speedup            : {dense_time / warp_time:.2f}x")
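
Wall-clock benchmarks like the one above are easy to distort: compilation can leak into the first call, asynchronous dispatch can return before the work finishes, and occasional outliers can skew a mean. Below is a small, reusable timing helper, offered as a sketch. `time_call` is a hypothetical name. With JAX you would pass `sync=jax.block_until_ready`; the default no-op suits synchronous code. The helper reports the median, which is less sensitive to outliers than the mean.

```python
import statistics
import time

def time_call(fn, *args, n_warmup=3, n_trials=50, sync=lambda r: r):
    """Return the median wall-clock time of fn(*args) in milliseconds.

    `sync` must block until asynchronous work has finished
    (e.g. jax.block_until_ready); the default is a no-op.
    """
    for _ in range(n_warmup):           # exclude compilation / first-call overhead
        sync(fn(*args))
    samples = []
    for _ in range(n_trials):
        t0 = time.perf_counter()        # monotonic, high-resolution clock
        sync(fn(*args))
        samples.append((time.perf_counter() - t0) * 1e3)
    return statistics.median(samples)   # robust to occasional outliers
```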

## 9. Multiple Backends with `XLACustomKernel`

You can register multiple backends for the same operation and switch at runtime.

In [None]:
# CPU fallback using Numba (demonstrated here even if GPU is unavailable)
try:
    import numba
    from brainevent import numba_kernel
    NUMBA_AVAILABLE = True
except ImportError:
    NUMBA_AVAILABLE = False

if NUMBA_AVAILABLE:
    @numba.njit(parallel=True)
    def _scale_add_numba(a, b, c, out):
        for i in numba.prange(out.size):
            out[i] = a[i] * b[i] + c[i]

    def numba_scale_add_generator(**kwargs):
        out_info = kwargs['outs'][0]

        def kernel(a, b, c):
            return numba_kernel(_scale_add_numba, outs=out_info)(a, b, c)

        return kernel

    # Create op with both GPU (Warp) and CPU (Numba) backends
    multi_backend_op = XLACustomKernel('tutorial_multi_backend_scale_add')

    if WARP_AVAILABLE:
        multi_backend_op.def_warp_kernel(warp_scale_add_generator)   # GPU

    multi_backend_op.def_numba_kernel(numba_scale_add_generator)     # CPU

    print("Registered backends:", list(multi_backend_op._kernels.keys()))

    # On GPU, Warp is default; on CPU, Numba is used automatically
    N = 128
    a = jnp.arange(N, dtype=jnp.float32)
    b = jnp.full(N, 3.0, dtype=jnp.float32)
    c = jnp.ones(N, dtype=jnp.float32)
    r = multi_backend_op(a, b, c, outs=[jax.ShapeDtypeStruct((N,), jnp.float32)])
    print("Result matches:", bool(jnp.allclose(r[0], a * b + c)))

## 10. Summary

In this tutorial we covered:

1. **`@warp.kernel`** – Write GPU kernels in Python-like syntax; use `warp.tid()` for the thread index.
2. **`jax_kernel`** – Wrap a Warp kernel so JAX can call it with `jax.Array` inputs.
   - `output_dims` mode: Warp allocates the output buffer.
   - `in_out_argnames` mode: caller provides the initial buffer (needed for atomic accumulation).
3. **`jaxinfo_to_warpinfo` / `jaxtype_to_warptype`** – Convert JAX dtype/shape info to Warp types
   for dynamic kernel construction inside kernel generators.
4. **`XLACustomKernel.def_warp_kernel`** – Register a Warp kernel generator as the GPU backend
   of a multi-backend custom JAX primitive; `def_numba_kernel` adds a CPU backend to the same primitive.
5. **Neuroscience application** – Sparse CSR × binary-spike matrix-vector product implemented
   with Warp atomic operations, demonstrating the key pattern used throughout BrainEvent.

## Next Steps

- **Tutorial 7**: Custom GPU operators with Numba CUDA (`@cuda.jit`)
- **Tutorial 8**: Custom CPU operators with Numba (`@numba.njit`)

## References

- [NVIDIA Warp documentation](https://nvidia.github.io/warp/)
- [BrainEvent GitHub](https://github.com/chaobrain/brainevent)
- [JAX FFI documentation](https://jax.readthedocs.io/en/latest/ffi.html)