# Low-level Operator Customization

@[Tianqiu Zhang](https://github.com/ztqakita)

BrainPy is built on JAX and accelerates models through [Just-in-Time (JIT) compilation](./compilation.ipynb). To further improve performance on CPU and GPU, we publish a separate package, ``BrainPyLib``, which provides several built-in low-level operators for synaptic computation. These operators are written in C++/CUDA and wrapped as JAX primitives through ``XLA``. However, users cannot easily write their own operators without this specialized background. To solve this problem, we introduce `numba.cfunc` and provide convenient interfaces that let users customize operators without touching the underlying logic. This tutorial shows how to customize operators on CPU. Note that BrainPy currently supports customization of CPU operators only; GPU operators will be supported in the future.

In [1]:
import brainpy as bp
import brainpy.math as bm
import jax
from jax import jit
import jax.numpy as jnp
from jax.core import ShapedArray
import numba
import time

bm.set_platform('cpu')


We have previously discussed the benefits of computing with our built-in operators. These operators are provided by the `brainpylib` package and can be accessed through the `brainpy.math` module. More specifically, to speed up sparse synaptic computation, we implemented several low-level operators for CPU and GPU, which are written in C++/CUDA and converted into JAX/XLA-compatible primitives using `Pybind11`.

Writing a C++/CUDA operator and carrying out the required conversions is not easy. Users have to learn how to write a C++/CUDA operator, how to write a customized JAX primitive, and how to convert the C++/CUDA operator into a JAX primitive. Here are some links for users who want to dive into the details: [Jax primitives](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html), [XLA custom calls](https://www.tensorflow.org/xla/custom_call).

However, we can only provide a limited number of operators, and it would be better if users could customize their own operators in a relatively simple way. To achieve this goal, BrainPy provides a convenient interface, `XLACustomOp`, for registering customized operators on CPU. Users no longer need to do any C++ programming or XLA compilation. This is accomplished with the help of [`numba.cfunc`](https://numba.pydata.org/numba-doc/latest/user/cfunc.html), which wraps Python code as a compiled function callable from foreign C code. The C function object exposes the address of the compiled C callback, so that it can be passed into XLA and registered as a jittable JAX primitive. The following sections show how to use `XLACustomOp` on CPU.

## How to customize operators?

### CPU version

First, users write a simple operator in Python. Note that this Python operator will be JIT-compiled in nopython mode, and some language features are not available inside Numba-compiled functions. Please refer to the [Numba documentation](https://numba.pydata.org/numba-doc/latest/reference/pysupported.html#pysupported) for details.

In [2]:
def custom_op(outs, ins):
  y, y1 = outs
  x, x2 = ins
  y[:] = x + 1
  y1[:] = x2 + 2

There are some restrictions that users should know:
- The operator takes two parameters, `outs` and `ins`, corresponding to the output variable(s) and input variable(s). The order cannot be changed.
- The function must not return a value. Results are written in place into `outs`, as in `y[:] = x + 1` above.
- When applying a CPU function to GPU, users only need to implement the CPU operator.
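
Because the operator only reads from `ins` and writes into `outs`, its logic can be checked in plain Python with NumPy arrays before it is registered. The sketch below assumes the same `custom_op` as above; allocating the output buffers by hand with `np.empty_like` is an illustrative stand-in for what the framework does when it calls the operator.

```python
import numpy as np


def custom_op(outs, ins):
    y, y1 = outs
    x, x2 = ins
    y[:] = x + 1     # write the first result in place
    y1[:] = x2 + 2   # write the second result in place


x = np.ones((1, 2), dtype=np.float32)

# Pre-allocate output buffers with the expected shapes and dtypes.
y = np.empty_like(x)
y1 = np.empty_like(x)

returned = custom_op((y, y1), (x, x))
print(y, y1, returned)
```

The function returns `None`; the results are available only through the buffers passed as `outs`.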

Next, users should describe the shapes and types of the outputs. JAX can deduce the shapes and types of the inputs when the operator is called, but it cannot infer the shapes and types of the outputs. The description can be:
- a `ShapedArray`,
- a sequence of `ShapedArray`,
- a function that returns the correct output shapes as `ShapedArray`.

Here we use a function to describe the output shapes and types. Its arguments are all the inputs of the custom operator, but only their shapes and types are accessible.

In [3]:
def abs_eval_1(*ins):
  # ins: input arguments; only shapes and types are accessible.
  # The output shapes and types of custom_op are exactly the same
  # as those of its inputs, so we can simply return the inputs.
  return ins

The function above is somewhat abstract, so we give an equivalent alternative below that passes the shape information explicitly. ``abs_eval_1`` and ``abs_eval_2`` do the same thing.

In [4]:
def abs_eval_2(*ins):
  return ShapedArray(ins[0].shape, ins[0].dtype), ShapedArray(ins[1].shape, ins[1].dtype)
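
As a quick check of that equivalence, the sketch below calls both functions on hand-built abstract inputs and compares the shapes and dtypes they return. Constructing `ShapedArray` inputs manually is an assumption made for illustration; in normal use the abstract inputs are supplied by JAX during tracing.

```python
import numpy as np
from jax.core import ShapedArray


def abs_eval_1(*ins):
    return ins


def abs_eval_2(*ins):
    return (ShapedArray(ins[0].shape, ins[0].dtype),
            ShapedArray(ins[1].shape, ins[1].dtype))


# An abstract value carries only a shape and a dtype, no data.
abstract_in = ShapedArray((1, 2), np.float32)

out_1 = abs_eval_1(abstract_in, abstract_in)
out_2 = abs_eval_2(abstract_in, abstract_in)

same = all(a.shape == b.shape and a.dtype == b.dtype
           for a, b in zip(out_1, out_2))
print(same)
```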

Now we are ready to register a CPU operator. `XLACustomOp` wraps your operator and returns a jittable JAX primitive. Here are the parameters users can define:
- `name`: Name of the operator.
- `eval_shape`: The function that evaluates the shape and dtype of the outputs from the inputs. It receives the abstract information of the inputs and returns the abstract information of the outputs.
- `con_compute`: The function that performs the concrete computation. It receives inputs and returns outputs.
- `cpu_func`: The function that defines the computation on the CPU backend. Same as ``con_compute``.
- `gpu_func`: The function that defines the computation on the GPU backend. Currently, this function is not supported.
- `apply_cpu_func_to_gpu`: Whether to allow applying the CPU function on the GPU backend. If True, the GPU data is moved to CPU, and after the calculation the outputs on the CPU backend are moved back to GPU.
- `batching_translation`: The batching translation for the primitive.
- `jvp_translation`:  The forward autodiff translation rule.
- `transpose_translation`: The backward autodiff translation rule.
- `multiple_results`: Whether the primitive returns multiple results.

In [5]:
z = jnp.ones((1, 2), dtype=jnp.float32)
# Users could try eval_shape=abs_eval_2 and check that the result is the same
op = bm.XLACustomOp(
  name='add',
  eval_shape=abs_eval_1,
  cpu_func=custom_op,
)
jit_op = jit(op)
print(jit_op(z, z))

[Array([[2., 2.]], dtype=float32), Array([[3., 3.]], dtype=float32)]


### GPU version

We have discussed how to customize a CPU operator above. Next we discuss the GPU operator, which is slightly different from the CPU version. There are two additional parameters users need to provide:
- `gpu_func`: The GPU version of the customized operator.
- `apply_cpu_func_to_gpu`: Whether to run the kernel function on CPU as a fallback for the GPU version.

```{warning}
  GPU operators will be wrapped by `cuda.jit` in `numba`, but `numba` currently does not support launching CUDA kernels from `cfuncs`. For this reason, `gpu_func` is `None` by default, and an error is raised if users pass a GPU operator to `gpu_func`.
```

Therefore, BrainPy lets users set `apply_cpu_func_to_gpu` to True as a fallback. All inputs are initialized on GPU and transferred to CPU for computing. The user-defined operator runs on CPU, and the results are transferred back to GPU for further tasks.

## Performance

To illustrate the effectiveness of this approach, we compare a customized operator with a BrainPy built-in operator. Here we use `event_sum` as an example. Its implementation using our customization interface is shown below:

In [6]:
def abs_eval(data, indices, indptr, vector, shape):
  out_shape = shape[0]
  return ShapedArray((out_shape,), data.dtype),

@numba.njit(fastmath=True)
def sparse_op(outs, ins):
  res_val = outs[0]
  res_val.fill(0)
  values, col_indices, row_ptr, vector, shape = ins

  for row_i in range(shape[0]):
      v = vector[row_i]
      for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
          res_val[col_indices[j]] += values * v

sparse_cus_op = bm.XLACustomOp(name='sparse', eval_shape=abs_eval, con_compute=sparse_op)
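
To see what the loop in `sparse_op` computes, the following sketch reimplements it in plain NumPy on a tiny hand-made CSR structure and compares the result with a dense transposed matrix-vector product. The helper name `csr_transpose_matvec` and the 3x3 connectivity are illustrative assumptions, not part of BrainPy.

```python
import numpy as np


def csr_transpose_matvec(value, col_indices, row_ptr, vector, n_cols):
    """Accumulate value * vector[row] into every column connected to row."""
    out = np.zeros(n_cols, dtype=vector.dtype)
    for row_i in range(len(row_ptr) - 1):
        v = vector[row_i]
        for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
            out[col_indices[j]] += value * v
    return out


# CSR connectivity: row 0 -> cols {1, 2}, row 1 -> col {0}, row 2 -> cols {0, 2}.
col_indices = np.array([1, 2, 0, 0, 2])
row_ptr = np.array([0, 2, 3, 5])
vector = np.array([1.0, 2.0, 3.0])
value = 0.5  # one weight shared by all connections

result = csr_transpose_matvec(value, col_indices, row_ptr, vector, n_cols=3)

# Build the equivalent dense matrix and multiply its transpose by the vector.
dense = np.zeros((3, 3))
for row_i in range(3):
    for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
        dense[row_i, col_indices[j]] += value
expected = dense.T @ vector
print(result, expected)
```

Both computations give the same vector, which is the transposed product that the built-in benchmark below requests with `transpose=True`.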

We use sparse matrix-vector multiplication as the benchmark for testing speed. We first time the built-in operator, `bm.sparse.csrmv`.

In [7]:
def sparse(size, prob):
  bm.random.seed()
  vector = bm.random.randn(size)
  sparse_A = bp.conn.FixedProb(prob=prob, allow_multi_conn=True)(size, size).require('pre2post')
  t0 = time.time()
  for _ in range(100):
    hidden = jax.block_until_ready(bm.sparse.csrmv(1., sparse_A[0], sparse_A[1], vector, shape=(size, size), transpose=True, method='vector'))
  cost_t = time.time() - t0
  print(f'Sparse: size {size}, prob {prob}, cost_t {cost_t} s.')
  bm.clear_buffer_memory()

sparse(50000, 0.01)

Sparse: size 50000, prob 0.01, cost_t 2.222744941711426 s.


The total time is 2.22 seconds. Next we use our customized operator.

In [9]:
def sparse_customize(size, prob):
  bm.random.seed()
  vector = bm.random.randn(size)
  sparse_A = bp.conn.FixedProb(prob=prob, allow_multi_conn=True)(size, size).require('pre2post')
  t0 = time.time()
  f = jit(lambda a: sparse_cus_op(a, sparse_A[0], sparse_A[1], vector, shape=(size, size)))
  for _ in range(100):
      hidden = jax.block_until_ready(f(1.))
  cost_t = time.time() - t0
  print(f'Sparse: size {size}, prob {prob}, cost_t {cost_t} s.')
  bm.clear_buffer_memory()

sparse_customize(50000, 0.01)

Sparse: size 50000, prob 0.01, cost_t 2.364152193069458 s.


In this run, the customized operator (2.36 s) is about 6% slower than the built-in operator (2.22 s), so it is almost as fast. Users can build their own operators without a significant loss in computation speed.
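
Both timings above include the compilation triggered by the first call inside the timed region. To compare steady-state speed only, one option is to run a warm-up call before starting the clock. The sketch below shows this pattern on a stand-in jitted function; the helper `time_jitted` and the function `f` are hypothetical and are not the operators benchmarked above.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    # Stand-in workload; replace with the operator under test.
    return jnp.dot(x, x)


def time_jitted(fn, arg, repeats=100):
    # Warm-up call: triggers compilation so it is excluded from the timing.
    jax.block_until_ready(fn(arg))
    t0 = time.perf_counter()
    for _ in range(repeats):
        # Block until the result is ready so asynchronous dispatch is timed.
        jax.block_until_ready(fn(arg))
    return time.perf_counter() - t0


x = jnp.ones(1000, dtype=jnp.float32)
elapsed = time_jitted(f, x)
print(f'cost_t {elapsed} s.')
```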