From 67be1de9c297401b529a3b9d106473d00ec88ae9 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Wed, 8 May 2024 13:39:05 +0800 Subject: [PATCH 01/12] [math] Update `CustomOpByNumba` to support JAX version >= 0.4.24 --- brainpy/_src/math/op_register/__init__.py | 2 +- .../op_register/numba_approach/__init__.py | 187 ++++++++++++++++-- .../op_register/tests/test_numba_based.py | 18 ++ 3 files changed, 185 insertions(+), 22 deletions(-) diff --git a/brainpy/_src/math/op_register/__init__.py b/brainpy/_src/math/op_register/__init__.py index 21c222c00..7e59e8c09 100644 --- a/brainpy/_src/math/op_register/__init__.py +++ b/brainpy/_src/math/op_register/__init__.py @@ -1,5 +1,5 @@ from .numba_approach import (CustomOpByNumba, - register_op_with_numba, + register_op_with_numba_xla, compile_cpu_signature_with_numba) from .base import XLACustomOp from .utils import register_general_batching diff --git a/brainpy/_src/math/op_register/numba_approach/__init__.py b/brainpy/_src/math/op_register/numba_approach/__init__.py index 5bbd04e0c..8d5cd3de1 100644 --- a/brainpy/_src/math/op_register/numba_approach/__init__.py +++ b/brainpy/_src/math/op_register/numba_approach/__init__.py @@ -1,29 +1,41 @@ # -*- coding: utf-8 -*- - +import ctypes from functools import partial from typing import Callable from typing import Union, Sequence import jax -from jax.interpreters import xla, batching, ad +from jax.interpreters import xla, batching, ad, mlir +from jax.lib import xla_client from jax.tree_util import tree_map +from jaxlib.hlo_helpers import custom_call from brainpy._src.dependency_check import import_numba from brainpy._src.math.ndarray import Array from brainpy._src.math.object_transform.base import BrainPyObject +from brainpy._src.math.op_register.utils import _shape_to_layout from brainpy.errors import PackageMissingError from .cpu_translation import _cpu_translation, compile_cpu_signature_with_numba numba = import_numba(error_if_not_found=False) - +if numba is not None: + from numba import types, carray, cfunc __all__ = [ 'CustomOpByNumba', - 'register_op_with_numba', + 'register_op_with_numba_xla', 'compile_cpu_signature_with_numba', ] +def _transform_to_shapedarray(a): + return jax.core.ShapedArray(a.shape, a.dtype) + + +def convert_shapedarray_to_shapedtypestruct(shaped_array): + return jax.ShapeDtypeStruct(shape=shaped_array.shape, dtype=shaped_array.dtype) + + class CustomOpByNumba(BrainPyObject): """Creating a XLA custom call operator with Numba JIT on CPU backend. @@ -61,20 +73,35 @@ def __init__( # abstract evaluation function if eval_shape is None: raise ValueError('Must provide "eval_shape" for abstract evaluation.') + self.eval_shape = eval_shape # cpu function cpu_func = con_compute # register OP - self.op = register_op_with_numba( - self.name, - cpu_func=cpu_func, - out_shapes=eval_shape, - batching_translation=batching_translation, - jvp_translation=jvp_translation, - transpose_translation=transpose_translation, - multiple_results=multiple_results, - ) + if jax.__version__ > '0.4.23': + self.op_method = 'mlir' + self.op = register_op_with_numba_mlir( + self.name, + cpu_func=cpu_func, + out_shapes=eval_shape, + gpu_func_translation=None, + batching_translation=batching_translation, + jvp_translation=jvp_translation, + transpose_translation=transpose_translation, + multiple_results=multiple_results, + ) + else: + self.op_method = 'xla' + self.op = register_op_with_numba_xla( + self.name, + cpu_func=cpu_func, + out_shapes=eval_shape, + batching_translation=batching_translation, + jvp_translation=jvp_translation, + transpose_translation=transpose_translation, + multiple_results=multiple_results, + ) def __call__(self, *args, **kwargs): args = tree_map(lambda a: a.value if isinstance(a, Array) else a, @@ -85,7 +112,7 @@ def __call__(self, *args, **kwargs): return res -def register_op_with_numba( +def register_op_with_numba_xla( op_name: str, cpu_func: Callable, out_shapes: Union[Callable, jax.core.ShapedArray, Sequence[jax.core.ShapedArray]], @@ -132,13 +159,6 @@ def register_op_with_numba( A JAX Primitive object. """ - if jax.__version__ > '0.4.23': - raise RuntimeError(f'{CustomOpByNumba.__name__} and {register_op_with_numba.__name__} are ' - f'only supported in JAX version <= 0.4.23. \n' - f'However, you can use brainpy.math.XLACustomOp to create a custom op with numba syntax. ' - f'For more information, please refer to the documentation: ' - f'https://brainpy.readthedocs.io/en/latest/tutorial_advanced/operator_custom_with_taichi.html.') - if numba is None: raise PackageMissingError.by_purpose('numba', 'custom op with numba') @@ -202,3 +222,128 @@ def abs_eval_rule(*input_shapes, **info): ad.primitive_transposes[prim] = transpose_translation return prim + + +def _numba_mlir_cpu_translation_rule(kernel, debug: bool, ctx, *ins, **kwargs): + # output information + outs = ctx.avals_out + output_shapes = tuple([out.shape for out in outs]) + output_dtypes = tuple([out.dtype for out in outs]) + output_layouts = tuple([_shape_to_layout(out.shape) for out in outs]) + result_types = [mlir.aval_to_ir_type(out) for out in outs] + + # input information + avals_in = ctx.avals_in + input_layouts = [_shape_to_layout(a.shape) for a in avals_in] + input_dtypes = tuple(inp.dtype for inp in avals_in) + input_shapes = tuple(inp.shape for inp in avals_in) + + # compiling function + code_scope = dict(func_to_call=kernel, input_shapes=input_shapes, input_dtypes=input_dtypes, + output_shapes=output_shapes, output_dtypes=output_dtypes, carray=carray) + args_in = [f'in{i} = carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}])' + for i in range(len(input_shapes))] + if len(output_shapes) > 1: + args_out = [f'out{i} = carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}])' + for i in range(len(output_shapes))] + sig = types.void(types.CPointer(types.voidptr), types.CPointer(types.voidptr)) + else: + args_out = [f'out0 = carray(output_ptrs, output_shapes[0], dtype=output_dtypes[0])'] + sig = types.void(types.voidptr, types.CPointer(types.voidptr)) + args_call = [f'out{i}' for i in range(len(output_shapes))] + [f'in{i}' for i in range(len(input_shapes))] + code_string = ''' + def numba_cpu_custom_call_target(output_ptrs, input_ptrs): + {args_out} + {args_in} + func_to_call({args_call}) + '''.format(args_out="\n ".join(args_out), args_in="\n ".join(args_in), args_call=", ".join(args_call)) + + if debug: + print(code_string) + exec(compile(code_string.strip(), '', 'exec'), code_scope) + new_f = code_scope['numba_cpu_custom_call_target'] + + # register + xla_c_rule = cfunc(sig)(new_f) + target_name = f'numba_custom_call_{str(xla_c_rule.address)}' + capsule = ctypes.pythonapi.PyCapsule_New(xla_c_rule.address, b"xla._CUSTOM_CALL_TARGET", None) + xla_client.register_custom_call_target(target_name, capsule, "cpu") + + # call + return custom_call( + call_target_name=target_name, + operands=ins, + operand_layouts=list(input_layouts), + result_layouts=list(output_layouts), + result_types=list(result_types), + has_side_effect=False, + ).results + + +def register_op_with_numba_mlir( + op_name: str, + cpu_func: Callable, + out_shapes: Union[Callable, jax.core.ShapedArray, Sequence[jax.core.ShapedArray]], + gpu_func_translation: Callable = None, + batching_translation: Callable = None, + jvp_translation: Callable = None, + transpose_translation: Callable = None, + multiple_results: bool = False, +): + if numba is None: + raise PackageMissingError.by_purpose('numba', 'custom op with numba') + + if out_shapes is None: + raise RuntimeError('out_shapes cannot be None. It can be a `ShapedArray` or ' + 'a sequence of `ShapedArray`. If it is a function, it takes as input the argument ' + 'shapes and dtypes and should return correct output shapes of `ShapedArray`.') + + prim = jax.core.Primitive(op_name) + prim.multiple_results = multiple_results + + from numba.core.dispatcher import Dispatcher + if not isinstance(cpu_func, Dispatcher): + cpu_func = numba.jit(fastmath=True, nopython=True)(cpu_func) + + def abs_eval_rule(*input_shapes, **info): + if callable(out_shapes): + shapes = out_shapes(*input_shapes, **info) + else: + shapes = out_shapes + + if isinstance(shapes, jax.core.ShapedArray): + assert not multiple_results, "multiple_results is True, while the abstract evaluation returns only one data." + elif isinstance(shapes, (tuple, list)): + assert multiple_results, "multiple_results is False, while the abstract evaluation returns multiple data." + for elem in shapes: + if not isinstance(elem, jax.core.ShapedArray): + raise ValueError(f'Elements in "out_shapes" must be instances of ' + f'jax.abstract_arrays.ShapedArray, but we got ' + f'{type(elem)}: {elem}') + else: + raise ValueError(f'Unknown type {type(shapes)}, only ' + f'supports function, ShapedArray or ' + f'list/tuple of ShapedArray.') + return shapes + + prim.def_abstract_eval(abs_eval_rule) + prim.def_impl(partial(xla.apply_primitive, prim)) + + def cpu_translation_rule(ctx, *ins, **kwargs): + return _numba_mlir_cpu_translation_rule(cpu_func, False, ctx, *ins, **kwargs) + + mlir.register_lowering(prim, cpu_translation_rule, platform='cpu') + + if gpu_func_translation is not None: + mlir.register_lowering(prim, gpu_func_translation, platform='gpu') + + if batching_translation is not None: + jax.interpreters.batching.primitive_batchers[prim] = batching_translation + + if jvp_translation is not None: + jax.interpreters.ad.primitive_jvps[prim] = jvp_translation + + if transpose_translation is not None: + jax.interpreters.ad.primitive_transposes[prim] = transpose_translation + + return prim diff --git a/brainpy/_src/math/op_register/tests/test_numba_based.py b/brainpy/_src/math/op_register/tests/test_numba_based.py index 28b80d0f4..f7adc695c 100644 --- a/brainpy/_src/math/op_register/tests/test_numba_based.py +++ b/brainpy/_src/math/op_register/tests/test_numba_based.py @@ -1,5 +1,6 @@ import jax.core import pytest +from jax.core import ShapedArray import brainpy.math as bm from brainpy._src.dependency_check import import_numba @@ -35,3 +36,20 @@ def test_event_ELL(): call(1000) call(100) bm.clear_buffer_memory() + +# CustomOpByNumba Test + +def eval_shape(a): + b = ShapedArray(a.shape, dtype=a.dtype) + return b + +@numba.njit(parallel=True) +def con_compute(outs, ins): + b = outs + a = ins + b[:] = a + 1 + +def test_CustomOpByNumba(): + op = bm.CustomOpByNumba(eval_shape, con_compute, multiple_results=False) + print(op(bm.zeros(10))) + assert bm.allclose(op(bm.zeros(10)), bm.ones(10)) \ No newline at end of file From e9c21a4862d63e19e39643dcbf8174a1f4fdbaa6 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Wed, 8 May 2024 13:40:55 +0800 Subject: [PATCH 02/12] Update dependency_check.py --- brainpy/_src/dependency_check.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/brainpy/_src/dependency_check.py b/brainpy/_src/dependency_check.py index 1e1060625..7ab47b822 100644 --- a/brainpy/_src/dependency_check.py +++ b/brainpy/_src/dependency_check.py @@ -57,9 +57,11 @@ def import_taichi(error_if_not_found=True): if taichi is None: return None - if taichi.__version__ != _minimal_taichi_version: - raise RuntimeError(taichi_install_info) - return taichi + if taichi.__version__[0] >= _minimal_taichi_version[0] and taichi.__version__[1] >= _minimal_taichi_version[1] and \ + taichi.__version__[2] >= _minimal_taichi_version[2]: + return taichi + else: + raise ModuleNotFoundError(taichi_install_info) def raise_taichi_not_found(*args, **kwargs): @@ -182,4 +184,4 @@ def import_brainpylib_gpu_ops(): raise ImportError('Please install GPU version of brainpylib. \n' 'See https://brainpy.readthedocs.io for installation instructions.') - return brainpylib_gpu_ops + return brainpylib_gpu_ops \ No newline at end of file From bafc425bc3673f264632e8c1da99c029853cedc5 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Wed, 8 May 2024 14:24:45 +0800 Subject: [PATCH 03/12] Update dependency_check.py --- brainpy/_src/dependency_check.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/brainpy/_src/dependency_check.py b/brainpy/_src/dependency_check.py index 7ab47b822..75c2051f9 100644 --- a/brainpy/_src/dependency_check.py +++ b/brainpy/_src/dependency_check.py @@ -57,8 +57,10 @@ def import_taichi(error_if_not_found=True): if taichi is None: return None - if taichi.__version__[0] >= _minimal_taichi_version[0] and taichi.__version__[1] >= _minimal_taichi_version[1] and \ - taichi.__version__[2] >= _minimal_taichi_version[2]: + taichi_version = taichi.__version__[0] * 10000 + taichi.__version__[1] * 100 + taichi.__version__[2] + minimal_taichi_version = _minimal_taichi_version[0] * 10000 + _minimal_taichi_version[1] * 100 + \ + _minimal_taichi_version[2] + if taichi_version >= minimal_taichi_version: return taichi else: raise ModuleNotFoundError(taichi_install_info) @@ -184,4 +186,4 @@ def import_brainpylib_gpu_ops(): raise ImportError('Please install GPU version of brainpylib. \n' 'See https://brainpy.readthedocs.io for installation instructions.') - return brainpylib_gpu_ops \ No newline at end of file + return brainpylib_gpu_ops From 789865846e49ffd076d0286f31abeebad2d03545 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Wed, 8 May 2024 14:44:54 +0800 Subject: [PATCH 04/12] Update requirements-dev.txt --- requirements-dev.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 641f99fde..754073f44 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -6,7 +6,7 @@ matplotlib msgpack tqdm pathos -taichi +taichi==1.7.0 numba braincore braintools From e55a9b0ea9cd34e1577cc0d73c17a36df771478d Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Tue, 14 May 2024 23:06:35 +0800 Subject: [PATCH 05/12] Update --- .../op_register/numba_approach/__init__.py | 68 ++--------------- .../numba_approach/cpu_translation.py | 76 +++++++++++++++++++ .../tests/test_numba_approach.py | 49 ++++++++++++ 3 files changed, 131 insertions(+), 62 deletions(-) create mode 100644 brainpy/_src/math/op_register/numba_approach/tests/test_numba_approach.py diff --git a/brainpy/_src/math/op_register/numba_approach/__init__.py b/brainpy/_src/math/op_register/numba_approach/__init__.py index 8d5cd3de1..1ad489cf1 100644 --- a/brainpy/_src/math/op_register/numba_approach/__init__.py +++ b/brainpy/_src/math/op_register/numba_approach/__init__.py @@ -6,16 +6,15 @@ import jax from jax.interpreters import xla, batching, ad, mlir -from jax.lib import xla_client + from jax.tree_util import tree_map -from jaxlib.hlo_helpers import custom_call from brainpy._src.dependency_check import import_numba from brainpy._src.math.ndarray import Array from brainpy._src.math.object_transform.base import BrainPyObject -from brainpy._src.math.op_register.utils import _shape_to_layout + from brainpy.errors import PackageMissingError -from .cpu_translation import _cpu_translation, compile_cpu_signature_with_numba +from .cpu_translation import _cpu_translation, compile_cpu_signature_with_numba, _numba_mlir_cpu_translation_rule numba = import_numba(error_if_not_found=False) if numba is not None: @@ -224,62 +223,6 @@ def abs_eval_rule(*input_shapes, **info): return prim -def _numba_mlir_cpu_translation_rule(kernel, debug: bool, ctx, *ins, **kwargs): - # output information - outs = ctx.avals_out - output_shapes = tuple([out.shape for out in outs]) - output_dtypes = tuple([out.dtype for out in outs]) - output_layouts = tuple([_shape_to_layout(out.shape) for out in outs]) - result_types = [mlir.aval_to_ir_type(out) for out in outs] - - # input information - avals_in = ctx.avals_in - input_layouts = [_shape_to_layout(a.shape) for a in avals_in] - input_dtypes = tuple(inp.dtype for inp in avals_in) - input_shapes = tuple(inp.shape for inp in avals_in) - - # compiling function - code_scope = dict(func_to_call=kernel, input_shapes=input_shapes, input_dtypes=input_dtypes, - output_shapes=output_shapes, output_dtypes=output_dtypes, carray=carray) - args_in = [f'in{i} = carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}])' - for i in range(len(input_shapes))] - if len(output_shapes) > 1: - args_out = [f'out{i} = carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}])' - for i in range(len(output_shapes))] - sig = types.void(types.CPointer(types.voidptr), types.CPointer(types.voidptr)) - else: - args_out = [f'out0 = carray(output_ptrs, output_shapes[0], dtype=output_dtypes[0])'] - sig = types.void(types.voidptr, types.CPointer(types.voidptr)) - args_call = [f'out{i}' for i in range(len(output_shapes))] + [f'in{i}' for i in range(len(input_shapes))] - code_string = ''' - def numba_cpu_custom_call_target(output_ptrs, input_ptrs): - {args_out} - {args_in} - func_to_call({args_call}) - '''.format(args_out="\n ".join(args_out), args_in="\n ".join(args_in), args_call=", ".join(args_call)) - - if debug: - print(code_string) - exec(compile(code_string.strip(), '', 'exec'), code_scope) - new_f = code_scope['numba_cpu_custom_call_target'] - - # register - xla_c_rule = cfunc(sig)(new_f) - target_name = f'numba_custom_call_{str(xla_c_rule.address)}' - capsule = ctypes.pythonapi.PyCapsule_New(xla_c_rule.address, b"xla._CUSTOM_CALL_TARGET", None) - xla_client.register_custom_call_target(target_name, capsule, "cpu") - - # call - return custom_call( - call_target_name=target_name, - operands=ins, - operand_layouts=list(input_layouts), - result_layouts=list(output_layouts), - result_types=list(result_types), - has_side_effect=False, - ).results - - def register_op_with_numba_mlir( op_name: str, cpu_func: Callable, @@ -329,8 +272,9 @@ def abs_eval_rule(*input_shapes, **info): prim.def_abstract_eval(abs_eval_rule) prim.def_impl(partial(xla.apply_primitive, prim)) - def cpu_translation_rule(ctx, *ins, **kwargs): - return _numba_mlir_cpu_translation_rule(cpu_func, False, ctx, *ins, **kwargs) + cpu_translation_rule = partial(_numba_mlir_cpu_translation_rule, + cpu_func, + True) mlir.register_lowering(prim, cpu_translation_rule, platform='cpu') diff --git a/brainpy/_src/math/op_register/numba_approach/cpu_translation.py b/brainpy/_src/math/op_register/numba_approach/cpu_translation.py index 4b06effdf..363ce6b17 100644 --- a/brainpy/_src/math/op_register/numba_approach/cpu_translation.py +++ b/brainpy/_src/math/op_register/numba_approach/cpu_translation.py @@ -5,8 +5,11 @@ from jax import dtypes, numpy as jnp from jax.core import ShapedArray from jax.lib import xla_client +from jaxlib.hlo_helpers import custom_call +from jax.interpreters import mlir from brainpy._src.dependency_check import import_numba +from brainpy._src.math.op_register.utils import _shape_to_layout numba = import_numba(error_if_not_found=False) ctypes.pythonapi.PyCapsule_New.argtypes = [ @@ -19,6 +22,7 @@ __all__ = [ '_cpu_translation', 'compile_cpu_signature_with_numba', + '_numba_mlir_cpu_translation_rule', ] if numba is not None: @@ -150,3 +154,75 @@ def compile_cpu_signature_with_numba( if multiple_results else output_layouts[0]) return target_name, tuple(inputs) + tuple(info_inputs), input_layouts, output_layouts + + +def _numba_mlir_cpu_translation_rule( + cpu_func, + debug, + ctx, + *ins, + **kwargs +): + # output information + outs = ctx.avals_out + output_shapes = tuple([out.shape for out in outs]) + output_dtypes = tuple([out.dtype for out in outs]) + output_layouts = tuple([_shape_to_layout(out.shape) for out in outs]) + result_types = [mlir.aval_to_ir_type(out) for out in outs] + + # input information + avals_in = ctx.avals_in + input_layouts = [_shape_to_layout(a.shape) for a in avals_in] + input_dtypes = tuple(inp.dtype for inp in avals_in) + input_shapes = tuple(inp.shape for inp in avals_in) + + # compiling function + code_scope = dict(func_to_call=cpu_func, input_shapes=input_shapes, input_dtypes=input_dtypes, + output_shapes=output_shapes, output_dtypes=output_dtypes, carray=carray) + if len(input_shapes) > 1: + args_in = [ + f'carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}]),' + for i in range(len(input_shapes)) + ] + args_in = '(\n ' + "\n ".join(args_in) + '\n )' + else: + args_in = 'carray(input_ptrs[0], input_shapes[0], dtype=input_dtypes[0])' + if len(output_shapes) > 1: + args_out = [ + f'carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}]),' + for i in range(len(output_shapes)) + ] + args_out = '(\n ' + "\n ".join(args_out) + '\n )' + sig = types.void(types.CPointer(types.voidptr), types.CPointer(types.voidptr)) + else: + args_out = 'carray(output_ptrs, output_shapes[0], dtype=output_dtypes[0])' + sig = types.void(types.voidptr, types.CPointer(types.voidptr)) + # args_call = [f'out{i}' for i in range(len(output_shapes))] + [f'in{i}' for i in range(len(input_shapes))] + code_string = ''' +def numba_cpu_custom_call_target(output_ptrs, input_ptrs): + args_out = {args_out} + args_in = {args_in} + func_to_call(args_out, args_in) + '''.format(args_in=args_in, + args_out=args_out) + + if debug: + print(code_string) + exec(compile(code_string.strip(), '', 'exec'), code_scope) + new_f = code_scope['numba_cpu_custom_call_target'] + + # register + xla_c_rule = cfunc(sig)(new_f) + target_name = f'numba_custom_call_{str(xla_c_rule.address)}' + capsule = ctypes.pythonapi.PyCapsule_New(xla_c_rule.address, b"xla._CUSTOM_CALL_TARGET", None) + xla_client.register_custom_call_target(target_name, capsule, "cpu") + + # call + return custom_call( + call_target_name=target_name, + operands=ins, + operand_layouts=list(input_layouts), + result_layouts=list(output_layouts), + result_types=list(result_types), + has_side_effect=False, + ).results diff --git a/brainpy/_src/math/op_register/numba_approach/tests/test_numba_approach.py b/brainpy/_src/math/op_register/numba_approach/tests/test_numba_approach.py new file mode 100644 index 000000000..e1bed7de5 --- /dev/null +++ b/brainpy/_src/math/op_register/numba_approach/tests/test_numba_approach.py @@ -0,0 +1,49 @@ +import jax.core +import pytest +from jax.core import ShapedArray + +import brainpy.math as bm +from brainpy._src.dependency_check import import_numba + +numba = import_numba(error_if_not_found=False) +if numba is None: + pytest.skip('no numba', allow_module_level=True) + +bm.set_platform('cpu') + + +def eval_shape(a): + b = ShapedArray(a.shape, dtype=a.dtype) + return b + +@numba.njit(parallel=True) +def con_compute(outs, ins): + b = outs + a = ins + b[:] = a + 1 + +def test_CustomOpByNumba_single_result(): + op = bm.CustomOpByNumba(eval_shape, con_compute, multiple_results=False) + print(op(bm.zeros(10))) + +def eval_shape2(a, b): + c = ShapedArray(a.shape, dtype=a.dtype) + d = ShapedArray(b.shape, dtype=b.dtype) + return c, d + +@numba.njit(parallel=True) +def con_compute2(outs, ins): + c = outs[0] # take out all the outputs + d = outs[1] + a = ins[0] # take out all the inputs + b = ins[1] + # c, d = outs + # a, b = ins + c[:] = a + 1 + d[:] = b * 2 + +def test_CustomOpByNumba_multiple_results(): + op2 = bm.CustomOpByNumba(eval_shape2, con_compute2, multiple_results=True) + print(op2(bm.zeros(10), bm.ones(10))) + +test_CustomOpByNumba_multiple_results() \ No newline at end of file From 48023d22372e3d340ac055f55b005ba36667f6e7 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Tue, 14 May 2024 23:12:30 +0800 Subject: [PATCH 06/12] Update operator_custom_with_numba.ipynb --- .../operator_custom_with_numba.ipynb | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/docs/tutorial_advanced/operator_custom_with_numba.ipynb b/docs/tutorial_advanced/operator_custom_with_numba.ipynb index e1121f5b6..e4f8dd208 100644 --- a/docs/tutorial_advanced/operator_custom_with_numba.ipynb +++ b/docs/tutorial_advanced/operator_custom_with_numba.ipynb @@ -149,8 +149,10 @@ " return c, d\n", "\n", "def con_compute2(outs, ins):\n", - " c, d = outs # 取出所有的输出\n", - " a, b = ins # 取出所有的输入\n", + " c = outs[0] # take out all the outputs\n", + " d = outs[1]\n", + " a = ins[0] # take out all the inputs\n", + " b = ins[1]\n", " c[:] = a + 1\n", " d[:] = a * 2\n", "\n", @@ -191,7 +193,8 @@ "\n", "def con_compute3(outs, ins):\n", " c = outs # Take out all the outputs\n", - " a, b = ins # Take out all inputs\n", + " a = ins[0] # Take out all inputs\n", + " b = ins[1]\n", " c[:] = 2.\n", "\n", "op3 = bm.CustomOpByNumba(eval_shape3, con_compute3, multiple_results=False)\n", @@ -434,8 +437,10 @@ " return c, d # 返回多个抽象数组信息\n", "\n", "def con_compute2(outs, ins):\n", - " c, d = outs # 取出所有的输出\n", - " a, b = ins # 取出所有的输入\n", + " c = outs[0] # 取出所有的输出\n", + " d = outs[1]\n", + " a = ins[0] # 取出所有的输入\n", + " b = ins[1]\n", " c[:] = a + 1\n", " d[:] = a * 2\n", "\n", @@ -476,7 +481,8 @@ "\n", "def con_compute3(outs, ins):\n", " c = outs # 取出所有的输出\n", - " a, b = ins # 取出所有的输入\n", + " a = ins[0] # 取出所有的输入\n", + " b = ins[1]\n", " c[:] = 2.\n", "\n", "op3 = bm.CustomOpByNumba(eval_shape3, con_compute3, multiple_results=False)\n", From 9728844e1a6621566f54f3466521529f02d3a8e1 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Tue, 14 May 2024 23:14:19 +0800 Subject: [PATCH 07/12] Update __init__.py --- brainpy/_src/math/op_register/numba_approach/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/brainpy/_src/math/op_register/numba_approach/__init__.py b/brainpy/_src/math/op_register/numba_approach/__init__.py index 7429659ec..5ac07191b 100644 --- a/brainpy/_src/math/op_register/numba_approach/__init__.py +++ b/brainpy/_src/math/op_register/numba_approach/__init__.py @@ -148,7 +148,6 @@ def __call__(self, *args, **kwargs): return res -def register_op_with_numba_xla( def register_op_with_numba_xla( op_name: str, cpu_func: Callable, From 8094c079df8916c8b83b6eafe5b903376f3527d0 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Tue, 14 May 2024 23:18:46 +0800 Subject: [PATCH 08/12] Update dependency_check.py --- brainpy/_src/dependency_check.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/brainpy/_src/dependency_check.py b/brainpy/_src/dependency_check.py index 75c2051f9..05a7c79c1 100644 --- a/brainpy/_src/dependency_check.py +++ b/brainpy/_src/dependency_check.py @@ -25,9 +25,9 @@ brainpylib_cpu_ops = None brainpylib_gpu_ops = None -taichi_install_info = (f'We need taichi=={_minimal_taichi_version}. ' - f'Currently you can install taichi=={_minimal_taichi_version} through:\n\n' - '> pip install taichi==1.7.0') +taichi_install_info = (f'We need taichi>={_minimal_taichi_version}. ' + f'Currently you can install taichi=={_minimal_taichi_version} by pip . \n' + '> pip install taichi -U') numba_install_info = ('We need numba. Please install numba by pip . \n' '> pip install numba') cupy_install_info = ('We need cupy. Please install cupy by pip . \n' From c70b3dd2bfd7ccd5e491745b3250a5261f250e1d Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Tue, 14 May 2024 23:21:47 +0800 Subject: [PATCH 09/12] Update __init__.py --- .../op_register/numba_approach/__init__.py | 35 ------------------- 1 file changed, 35 deletions(-) diff --git a/brainpy/_src/math/op_register/numba_approach/__init__.py b/brainpy/_src/math/op_register/numba_approach/__init__.py index 5ac07191b..4d9b284f5 100644 --- a/brainpy/_src/math/op_register/numba_approach/__init__.py +++ b/brainpy/_src/math/op_register/numba_approach/__init__.py @@ -21,25 +21,14 @@ numba = import_numba(error_if_not_found=False) if numba is not None: from numba import types, carray, cfunc -if numba is not None: - from numba import types, carray, cfunc __all__ = [ 'CustomOpByNumba', 'register_op_with_numba_xla', - 'register_op_with_numba_xla', 'compile_cpu_signature_with_numba', ] -def _transform_to_shapedarray(a): - return jax.core.ShapedArray(a.shape, a.dtype) - - -def convert_shapedarray_to_shapedtypestruct(shaped_array): - return jax.ShapeDtypeStruct(shape=shaped_array.shape, dtype=shaped_array.dtype) - - def _transform_to_shapedarray(a): return jax.core.ShapedArray(a.shape, a.dtype) @@ -86,7 +75,6 @@ def __init__( if eval_shape is None: raise ValueError('Must provide "eval_shape" for abstract evaluation.') self.eval_shape = eval_shape - self.eval_shape = eval_shape # cpu function cpu_func = con_compute @@ -115,29 +103,6 @@ def __init__( transpose_translation=transpose_translation, multiple_results=multiple_results, ) - if jax.__version__ > '0.4.23': - self.op_method = 'mlir' - self.op = register_op_with_numba_mlir( - self.name, - cpu_func=cpu_func, - out_shapes=eval_shape, - gpu_func_translation=None, - batching_translation=batching_translation, - jvp_translation=jvp_translation, - transpose_translation=transpose_translation, - multiple_results=multiple_results, - ) - else: - self.op_method = 'xla' - self.op = register_op_with_numba_xla( - self.name, - cpu_func=cpu_func, - out_shapes=eval_shape, - batching_translation=batching_translation, - jvp_translation=jvp_translation, - transpose_translation=transpose_translation, - multiple_results=multiple_results, - ) def __call__(self, *args, **kwargs): args = tree_map(lambda a: a.value if isinstance(a, Array) else a, From a9a9845ece81f68efdb52632297f8b5bd265c7a2 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Wed, 15 May 2024 08:37:16 +0800 Subject: [PATCH 10/12] Fix --- .../tests/test_numba_approach.py | 14 +++++++------- .../operator_custom_with_numba.ipynb | 18 ++++++------------ 2 files changed, 13 insertions(+), 19 deletions(-) diff --git a/brainpy/_src/math/op_register/numba_approach/tests/test_numba_approach.py b/brainpy/_src/math/op_register/numba_approach/tests/test_numba_approach.py index e1bed7de5..091468c9e 100644 --- a/brainpy/_src/math/op_register/numba_approach/tests/test_numba_approach.py +++ b/brainpy/_src/math/op_register/numba_approach/tests/test_numba_approach.py @@ -31,14 +31,14 @@ def eval_shape2(a, b): d = ShapedArray(b.shape, dtype=b.dtype) return c, d -@numba.njit(parallel=True) + def con_compute2(outs, ins): - c = outs[0] # take out all the outputs - d = outs[1] - a = ins[0] # take out all the inputs - b = ins[1] - # c, d = outs - # a, b = ins + # c = outs[0] # take out all the outputs + # d = outs[1] + # a = ins[0] # take out all the inputs + # b = ins[1] + c, d = outs + a, b = ins c[:] = a + 1 d[:] = b * 2 diff --git a/docs/tutorial_advanced/operator_custom_with_numba.ipynb b/docs/tutorial_advanced/operator_custom_with_numba.ipynb index e4f8dd208..0b840db04 100644 --- a/docs/tutorial_advanced/operator_custom_with_numba.ipynb +++ b/docs/tutorial_advanced/operator_custom_with_numba.ipynb @@ -149,10 +149,8 @@ " return c, d\n", "\n", "def con_compute2(outs, ins):\n", - " c = outs[0] # take out all the outputs\n", - " d = outs[1]\n", - " a = ins[0] # take out all the inputs\n", - " b = ins[1]\n", + " c, d = outs # take out all the outputs\n", + " a, b = ins # take out all the inputs\n", " c[:] = a + 1\n", " d[:] = a * 2\n", "\n", @@ -193,8 +191,7 @@ "\n", "def con_compute3(outs, ins):\n", " c = outs # Take out all the outputs\n", - " a = ins[0] # Take out all inputs\n", - " b = ins[1]\n", + " a, b = ins # Take out all inputs\n", " c[:] = 2.\n", "\n", "op3 = bm.CustomOpByNumba(eval_shape3, con_compute3, multiple_results=False)\n", @@ -437,10 +434,8 @@ " return c, d # 返回多个抽象数组信息\n", "\n", "def con_compute2(outs, ins):\n", - " c = outs[0] # 取出所有的输出\n", - " d = outs[1]\n", - " a = ins[0] # 取出所有的输入\n", - " b = ins[1]\n", + " c, d = outs # 取出所有的输出\n", + " a, b = ins # 取出所有的输入\n", " c[:] = a + 1\n", " d[:] = a * 2\n", "\n", @@ -481,8 +476,7 @@ "\n", "def con_compute3(outs, ins):\n", " c = outs # 取出所有的输出\n", - " a = ins[0] # 取出所有的输入\n", - " b = ins[1]\n", + " a, b = ins # 取出所有的输入\n", " c[:] = 2.\n", "\n", "op3 = bm.CustomOpByNumba(eval_shape3, con_compute3, multiple_results=False)\n", From e61db3f5316e41331449a6fdcb2ac86c2919bd5e Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Wed, 15 May 2024 11:11:57 +0800 Subject: [PATCH 11/12] Update docs --- .../op_register/numba_approach/__init__.py | 2 +- .../tests/test_numba_approach.py | 13 +- .../operator_custom_with_numba.ipynb | 120 ++++++++++++++++-- 3 files changed, 113 insertions(+), 22 deletions(-) diff --git a/brainpy/_src/math/op_register/numba_approach/__init__.py b/brainpy/_src/math/op_register/numba_approach/__init__.py index 4d9b284f5..35c9beef6 100644 --- a/brainpy/_src/math/op_register/numba_approach/__init__.py +++ b/brainpy/_src/math/op_register/numba_approach/__init__.py @@ -276,7 +276,7 @@ def abs_eval_rule(*input_shapes, **info): cpu_translation_rule = partial(_numba_mlir_cpu_translation_rule, cpu_func, - True) + False) mlir.register_lowering(prim, cpu_translation_rule, platform='cpu') diff --git a/brainpy/_src/math/op_register/numba_approach/tests/test_numba_approach.py b/brainpy/_src/math/op_register/numba_approach/tests/test_numba_approach.py index 091468c9e..21099cb61 100644 --- a/brainpy/_src/math/op_register/numba_approach/tests/test_numba_approach.py +++ b/brainpy/_src/math/op_register/numba_approach/tests/test_numba_approach.py @@ -31,14 +31,13 @@ def eval_shape2(a, b): d = ShapedArray(b.shape, dtype=b.dtype) return c, d - def con_compute2(outs, ins): - # c = outs[0] # take out all the outputs - # d = outs[1] - # a = ins[0] # take out all the inputs - # b = ins[1] - c, d = outs - a, b = ins + c = outs[0] # take out all the outputs + d = outs[1] + a = ins[0] # take out all the inputs + b = ins[1] + # c, d = outs + # a, b = ins c[:] = a + 1 d[:] = b * 2 diff --git a/docs/tutorial_advanced/operator_custom_with_numba.ipynb b/docs/tutorial_advanced/operator_custom_with_numba.ipynb index 0b840db04..7f00cd56e 100644 --- a/docs/tutorial_advanced/operator_custom_with_numba.ipynb +++ b/docs/tutorial_advanced/operator_custom_with_numba.ipynb @@ -65,8 +65,6 @@ "source": [ "### ``brainpy.math.CustomOpByNumba``\n", "\n", - "``brainpy.math.CustomOpByNumba`` is also called ``brainpy.math.XLACustomOp``.\n", - "\n", "BrainPy provides ``brainpy.math.CustomOpByNumba`` for customizing the operator on the CPU device. Two parameters are required to provide in ``CustomOpByNumba``:\n", "\n", "- ``eval_shape``: evaluates the *shape* and *datatype* of the output argument based on the *shape* and *datatype* of the input argument.\n", @@ -137,7 +135,7 @@ "collapsed": false }, "source": [ - "### Return multiple values ``multiple_returns=True``\n", + "#### Return multiple values ``multiple_returns=True``\n", "\n", "If the result of our computation needs to return multiple arrays, then we need to use ``multiple_returns=True`` in our use of registering the operator. In this case, ``outs`` will be a list containing multiple arrays, not an array.\n", "\n", @@ -149,8 +147,10 @@ " return c, d\n", "\n", "def con_compute2(outs, ins):\n", - " c, d = outs # take out all the outputs\n", - " a, b = ins # take out all the inputs\n", + " c = outs[0] # take out all the outputs\n", + " d = outs[1]\n", + " a = ins[0] # take out all the inputs\n", + " b = ins[1]\n", " c[:] = a + 1\n", " d[:] = a * 2\n", "\n", @@ -170,7 +170,7 @@ "collapsed": false }, "source": [ - "### Non-Tracer parameters\n", + "#### Non-Tracer parameters\n", "\n", "In the ``eval_shape`` function, all arguments are abstract information (containing only the shape and type) if they are arguments that can be traced by ``jax.jit``. However, if we infer the output data type requires additional information beyond the input parameter information, then we need to define non-Tracer parameters.\n", "\n", @@ -191,7 +191,8 @@ "\n", "def con_compute3(outs, ins):\n", " c = outs # Take out all the outputs\n", - " a, b = ins # Take out all inputs\n", + " a = ins[0] # Take out all inputs\n", + " b = ins[1]\n", " c[:] = 2.\n", "\n", "op3 = bm.CustomOpByNumba(eval_shape3, con_compute3, multiple_results=False)\n", @@ -221,7 +222,7 @@ "collapsed": false }, "source": [ - "### Example: A sparse operator\n", + "#### Example: A sparse operator\n", "\n", "To illustrate the effectiveness of this approach, we define in this an event-driven sparse computation operator." ] @@ -297,6 +298,50 @@ "f(1.)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### brainpy.math.XLACustomOp\n", + "\n", + "`brainpy.math.XLACustomOp` is a new method for customizing operators on the CPU device. It is similar to `brainpy.math.CustomOpByNumba`, but it is more flexible and supports more advanced features. If you want to use this new method with numba, you only need to define a kernel using @numba.jit or @numba.njit, and then pass the kernel to `brainpy.math.XLACustomOp`.\n", + "\n", + "Detailed steps are as follows:\n", + "\n", + "#### Define the kernel\n", + "\n", + "```python\n", + "@numba.njit(fastmath=True)\n", + "def numba_event_csrmv(weight, indices, vector, outs):\n", + " outs.fill(0)\n", + " weight = weight[()] # 0d\n", + " for row_i in range(vector.shape[0]):\n", + " if vector[row_i]:\n", + " for j in indices[row_i]:\n", + " outs[j] += weight\n", + "```\n", + "\n", + "In the declaration of parameters, the last few parameters need to be output parameters so that numba can compile correctly. This operator numba_event_csrmv receives four parameters: `weight`, `indices`, `vector`, and `outs`. The first three parameters are input parameters, and the last parameter is the output parameter. The output parameter is a 1D array, and the input parameters are 0D, 1D, and 2D arrays, respectively." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Registering and Using Custom Operators\n", + "After defining a custom operator, it can be registered into a specific framework and used where needed. When registering, you can specify cpu_kernel and gpu_kernel, so the operator can run on different devices. Specify the outs parameter when calling, using `jax.ShapeDtypeStruct` to define the shape and data type of the output.\n", + "\n", + "Note: Maintain the order of the operator's declared parameters consistent with the order when calling.\n", + "\n", + "```python\n", + "prim = bm.XLACustomOp(cpu_kernel=numba_event_csrmv)\n", + "indices = bm.random.randint(0, s, (s, 80))\n", + "vector = bm.random.rand(s) < 0.1\n", + "out = prim(1., indices, vector, outs=[jax.ShapeDtypeStruct([s], dtype=bm.float32)])\n", + "print(out)\n", + "```" + ] + }, { "cell_type": "markdown", "metadata": { @@ -423,7 +468,7 @@ "collapsed": false }, "source": [ - "### 返回多个值 ``multiple_returns=True``\n", + "#### 返回多个值 ``multiple_returns=True``\n", "\n", "如果我们的计算结果需要返回多个数组,那么,我们在注册算子的使用需要使用``multiple_returns=True``。此时,``outs``将会是一个包含多个数组的列表,而不是一个数组。\n", "\n", @@ -434,8 +479,10 @@ " return c, d # 返回多个抽象数组信息\n", "\n", "def con_compute2(outs, ins):\n", - " c, d = outs # 取出所有的输出\n", - " a, b = ins # 取出所有的输入\n", + " c = outs[0] # 取出所有的输出\n", + " d = outs[1]\n", + " a = ins[0] # 取出所有的输入\n", + " b = ins[1]\n", " c[:] = a + 1\n", " d[:] = a * 2\n", "\n", @@ -455,7 +502,7 @@ "collapsed": false }, "source": [ - "### 非Tracer参数\n", + "#### 非Tracer参数\n", "\n", "在``eval_shape``函数中推断数据类型时,如果所有参数都是可以被``jax.jit``追踪的参数,那么所有参数都是抽象信息(只包含形状和类型)。如果有时推断输出数据类型时还需要除输入参数信息以外的额外信息,此时我们需要定义非Tracer参数。\n", "\n", @@ -476,7 +523,8 @@ "\n", "def con_compute3(outs, ins):\n", " c = outs # 取出所有的输出\n", - " a, b = ins # 取出所有的输入\n", + " a = ins[0] # 取出所有的输入\n", + " b = ins[1]\n", " c[:] = 2.\n", "\n", "op3 = bm.CustomOpByNumba(eval_shape3, con_compute3, multiple_results=False)\n", @@ -506,7 +554,7 @@ "collapsed": false }, "source": [ - "### 示例:一个稀疏算子\n", + "#### 示例:一个稀疏算子\n", "\n", "为了说明这种方法的有效性,我们在这个定义一个事件驱动的稀疏计算算子。" ] @@ -581,6 +629,50 @@ "f = jit(lambda a: sparse_cus_op(a, sparse_A[0], sparse_A[1], vector, shape=(size, size)))\n", "f(1.)" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### brainpy.math.XLACustomOp\n", + "\n", + "`brainpy.math.XLACustomOp` is a new method for customizing operators on the CPU device. It is similar to `brainpy.math.CustomOpByNumba`, but it is more flexible and supports more advanced features. If you want to use this new method with numba, you only need to define a kernel using `@numba.jit` or `@numba.njit` decorator, and then pass the kernel to `brainpy.math.XLACustomOp`.\n", + "`brainpy.math.XLACustomOp`是一种自定义算子的新方法。它类似于`brainpy.math.CustomOpByNumba`,但它更灵活并支持更高级的特性。如果您想用numba使用这种新方法,只需要使用 `@numba.jit`或`@numba.njit`装饰器定义一个kernel,然后将内核传递给`brainpy.math.XLACustomOp`。\n", + "\n", + "详细步骤如下:\n", + "\n", + "#### 定义kernel\n", + "在参数声明中,最后几个参数需要是输出参数,这样numba才能正确编译。这个算子`numba_event_csrmv`接受四个参数:weight、indices、vector 和 outs。前三个参数是输入参数,最后一个参数是输出参数。输出参数是一个一维数组,输入参数分别是 0D、1D 和 2D 数组。\n", + "\n", + "```python\n", + "@numba.njit(fastmath=True)\n", + "def numba_event_csrmv(weight, indices, vector, outs):\n", + " outs.fill(0)\n", + " weight = weight[()] # 0d\n", + " for row_i in range(vector.shape[0]):\n", + " if vector[row_i]:\n", + " for j in indices[row_i]:\n", + " outs[j] += weight\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 注册并使用自定义算子\n", + "在定义了自定义算子之后,可以将其注册到特定框架中,并在需要的地方使用它。在注册时可以指定`cpu_kernel`和`gpu_kernel`,这样算子就可以在不同的设备上运行。并在调用中指定`outs`参数,用`jax.ShapeDtypeStruct`来指定输出的形状和数据类型。\n", + "\n", + "注意: 在算子声明的参数与调用时需要保持顺序的一致。\n", + "\n", + "```python\n", + "prim = bm.XLACustomOp(cpu_kernel=numba_event_csrmv)\n", + "indices = bm.random.randint(0, s, (s, 80))\n", + "vector = bm.random.rand(s) < 0.1\n", + "out = prim(1., indices, vector, outs=[jax.ShapeDtypeStruct([s], dtype=bm.float32)])\n", + "print(out)\n", + "```" + ] } ], "metadata": { From 3820aa71f7b3f3dba4b44deee5f2b90f04bd1bc5 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Wed, 15 May 2024 15:45:04 +0800 Subject: [PATCH 12/12] Update operator_custom_with_taichi.ipynb --- docs/tutorial_advanced/operator_custom_with_taichi.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tutorial_advanced/operator_custom_with_taichi.ipynb b/docs/tutorial_advanced/operator_custom_with_taichi.ipynb index 4b86a4269..e927bf72c 100644 --- a/docs/tutorial_advanced/operator_custom_with_taichi.ipynb +++ b/docs/tutorial_advanced/operator_custom_with_taichi.ipynb @@ -127,7 +127,7 @@ "metadata": {}, "source": [ "### Registering and Using Custom Operators\n", - "After defining a custom operator, it can be registered into a specific framework and used where needed. When registering, you can specify cpu_kernel and gpu_kernel, so the operator can run on different devices. Specify the outs parameter when calling, using jax.ShapeDtypeStruct to define the shape and data type of the output.\n", + "After defining a custom operator, it can be registered into a specific framework and used where needed. When registering, you can specify cpu_kernel and gpu_kernel, so the operator can run on different devices. Specify the outs parameter when calling, using `jax.ShapeDtypeStruct` to define the shape and data type of the output.\n", "\n", "Note: Maintain the order of the operator's declared parameters consistent with the order when calling.\n", "\n",