diff --git a/brainpy/math/operators/__init__.py b/brainpy/math/operators/__init__.py index 7466d06a7..59124391b 100644 --- a/brainpy/math/operators/__init__.py +++ b/brainpy/math/operators/__init__.py @@ -1,17 +1,18 @@ # -*- coding: utf-8 -*- -from . import multiplication +from . import sparse_matmul, event_matmul from . import op_register from . import pre_syn_post as pre_syn_post_module from . import wrap_jax from . import spikegrad -__all__ = multiplication.__all__ + op_register.__all__ +__all__ = event_matmul.__all__ + sparse_matmul.__all__ + op_register.__all__ __all__ += pre_syn_post_module.__all__ + wrap_jax.__all__ + spikegrad.__all__ -from .multiplication import * +from .event_matmul import * +from .sparse_matmul import * from .op_register import * from .pre_syn_post import * from .wrap_jax import * diff --git a/brainpy/math/operators/event_matmul.py b/brainpy/math/operators/event_matmul.py new file mode 100644 index 000000000..b44c39018 --- /dev/null +++ b/brainpy/math/operators/event_matmul.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- + + +from typing import Tuple + +from brainpy.math.numpy_ops import as_jax +from brainpy.types import Array +from .utils import _check_brainpylib + +try: + import brainpylib +except ModuleNotFoundError: + brainpylib = None + +__all__ = [ + 'event_csr_matvec', +] + + +def event_csr_matvec(values: Array, + indices: Array, + indptr: Array, + events: Array, + shape: Tuple[int, ...], + transpose: bool = False): + """The pre-to-post event-driven synaptic summation with `CSR` synapse structure. + + Parameters + ---------- + values: Array, float + An array of shape ``(nse,)`` or a float. + indices: Array + An array of shape ``(nse,)``. + indptr: Array + An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``. + events: Array + An array of shape ``(shape[0] if transpose else shape[1],)`` + and dtype ``data.dtype``. + shape: tuple of int + A length-2 tuple representing the sparse matrix shape. + transpose: bool + A boolean specifying whether to transpose the sparse matrix + before computing. Default is False. + + Returns + ------- + out: Array + A tensor with the shape of ``shape[1]`` if `transpose=True`, + or ``shape[0]`` if `transpose=False`. + """ + _check_brainpylib('event_csr_matvec') + events = as_jax(events) + indices = as_jax(indices) + indptr = as_jax(indptr) + values = as_jax(values) + return brainpylib.event_csr_matvec(values, indices, indptr, events, + shape=shape, transpose=transpose) diff --git a/brainpy/math/operators/op_register.py b/brainpy/math/operators/op_register.py index 838772b93..97d2f88f7 100644 --- a/brainpy/math/operators/op_register.py +++ b/brainpy/math/operators/op_register.py @@ -2,7 +2,7 @@ from typing import Union, Sequence, Callable -from jax.abstract_arrays import ShapedArray +from jax.core import ShapedArray from jax.tree_util import tree_map from brainpy.base import Base @@ -57,6 +57,10 @@ def __init__( gpu_func: Callable = None, apply_cpu_func_to_gpu: bool = False, name: str = None, + batching_translation: Callable = None, + jvp_translation: Callable = None, + transpose_translation: Callable = None, + multiple_results: bool = False, ): _check_brainpylib(register_op.__name__) super(XLACustomOp, self).__init__(name=name) @@ -77,11 +81,17 @@ def __init__( gpu_func = None # register OP - self.op = brainpylib.register_op(self.name, - cpu_func=cpu_func, - gpu_func=gpu_func, - out_shapes=eval_shape, - apply_cpu_func_to_gpu=apply_cpu_func_to_gpu) + self.op = brainpylib.register_op_with_numba( + self.name, + cpu_func=cpu_func, + gpu_func_translation=gpu_func, + out_shapes=eval_shape, + apply_cpu_func_to_gpu=apply_cpu_func_to_gpu, + batching_translation=batching_translation, + jvp_translation=jvp_translation, + transpose_translation=transpose_translation, + multiple_results=multiple_results, + ) def __call__(self, *args, **kwargs): args = tree_map(lambda a: a.value if isinstance(a, JaxArray) else a, @@ -89,7 +99,7 @@ def __call__(self, *args, **kwargs): kwargs = tree_map(lambda a: a.value if isinstance(a, JaxArray) else a, kwargs, is_leaf=lambda a: isinstance(a, JaxArray)) res = self.op.bind(*args, **kwargs) - return res[0] if len(res) == 1 else res + return res def register_op( @@ -122,15 +132,15 @@ def register_op( A jitable JAX function. """ _check_brainpylib(register_op.__name__) - f = brainpylib.register_op(name, - cpu_func=cpu_func, - gpu_func=gpu_func, - out_shapes=eval_shape, - apply_cpu_func_to_gpu=apply_cpu_func_to_gpu) + f = brainpylib.register_op_with_numba(name, + cpu_func=cpu_func, + gpu_func_translation=gpu_func, + out_shapes=eval_shape, + apply_cpu_func_to_gpu=apply_cpu_func_to_gpu) def fixed_op(*inputs, **info): inputs = tuple([i.value if isinstance(i, JaxArray) else i for i in inputs]) res = f.bind(*inputs, **info) - return res[0] if len(res) == 1 else res + return res return fixed_op diff --git a/brainpy/math/operators/pre_syn_post.py b/brainpy/math/operators/pre_syn_post.py index ea567b23e..6679e3565 100644 --- a/brainpy/math/operators/pre_syn_post.py +++ b/brainpy/math/operators/pre_syn_post.py @@ -6,8 +6,7 @@ from jax import vmap, jit, ops as jops from brainpy.errors import MathError -from brainpy.math.jaxarray import JaxArray -from brainpy.math.numpy_ops import as_device_array +from brainpy.math.numpy_ops import as_jax from brainpy.types import Array from .utils import _check_brainpylib @@ -25,9 +24,9 @@ 'pre2post_mean', # pre-to-post event operator - 'pre2post_csr_event_sum', 'pre2post_event_sum', + 'pre2post_event_sum', 'pre2post_coo_event_sum', - 'pre2post_csr_event_prod', 'pre2post_event_prod', + 'pre2post_event_prod', # pre-to-syn 'pre2syn', @@ -49,10 +48,10 @@ def _raise_pre_ids_is_none(pre_ids): f'(brainpy.math.ndim(pre_values) != 0).') -def pre2post_csr_event_sum(events: Array, - pre2post: Tuple[Array, Array], - post_num: int, - values: Union[float, Array] = 1.): +def pre2post_event_sum(events: Array, + pre2post: Tuple[Array, Array], + post_num: int, + values: Union[float, Array] = 1.): """The pre-to-post event-driven synaptic summation with `CSR` synapse structure. When ``values`` is a scalar, this function is equivalent to @@ -98,16 +97,15 @@ def pre2post_csr_event_sum(events: Array, out: JaxArray, jax.numpy.ndarray A tensor with the shape of ``post_num``. """ - _check_brainpylib('pre2post_event_sum') + _check_brainpylib('event_csr_matvec') indices, idnptr = pre2post - events = as_device_array(events) - indices = as_device_array(indices) - idnptr = as_device_array(idnptr) - values = as_device_array(values) - return brainpylib.csr_event_sum(events, (indices, idnptr), post_num, values) - - -pre2post_event_sum = pre2post_csr_event_sum + events = as_jax(events) + indices = as_jax(indices) + idnptr = as_jax(idnptr) + values = as_jax(values) + return brainpylib.event_csr_matvec(values, indices, idnptr, events, + shape=(events.shape[0], post_num), + transpose=True) def pre2post_coo_event_sum(events: Array, @@ -136,14 +134,14 @@ def pre2post_coo_event_sum(events: Array, A tensor with the shape of ``post_num``. """ _check_brainpylib('pre2post_event_sum') - events = as_device_array(events) - post_ids = as_device_array(post_ids) - pre_ids = as_device_array(pre_ids) - values = as_device_array(values) + events = as_jax(events) + post_ids = as_jax(post_ids) + pre_ids = as_jax(pre_ids) + values = as_jax(values) return brainpylib.coo_event_sum(events, pre_ids, post_ids, post_num, values) -def pre2post_csr_event_prod(events, pre2post, post_num, values=1.): +def pre2post_event_prod(events, pre2post, post_num, values=1.): """The pre-to-post synaptic computation with event-driven production. When ``values`` is a scalar, this function is equivalent to @@ -191,16 +189,13 @@ def pre2post_csr_event_prod(events, pre2post, post_num, values=1.): """ _check_brainpylib('pre2post_event_prod') indices, idnptr = pre2post - events = as_device_array(events) - indices = as_device_array(indices) - idnptr = as_device_array(idnptr) - values = as_device_array(values) + events = as_jax(events) + indices = as_jax(indices) + idnptr = as_jax(idnptr) + values = as_jax(values) return brainpylib.csr_event_prod(events, (indices, idnptr), post_num, values) -pre2post_event_prod = pre2post_csr_event_prod - - def pre2post_sum(pre_values, post_num, post_ids, pre_ids=None): """The pre-to-post synaptic summation. @@ -230,11 +225,11 @@ def pre2post_sum(pre_values, post_num, post_ids, pre_ids=None): The value with the size of post-synaptic neurons. """ out = jnp.zeros(post_num) - pre_values = as_device_array(pre_values) - post_ids = as_device_array(post_ids) + pre_values = as_jax(pre_values) + post_ids = as_jax(post_ids) if jnp.ndim(pre_values) != 0: _raise_pre_ids_is_none(pre_ids) - pre_ids = as_device_array(pre_ids) + pre_ids = as_jax(pre_ids) pre_values = pre_values[pre_ids] return out.at[post_ids].add(pre_values) @@ -268,11 +263,11 @@ def pre2post_prod(pre_values, post_num, post_ids, pre_ids=None): The value with the size of post-synaptic neurons. """ out = jnp.zeros(post_num) - pre_values = as_device_array(pre_values) - post_ids = as_device_array(post_ids) + pre_values = as_jax(pre_values) + post_ids = as_jax(post_ids) if jnp.ndim(pre_values) != 0: _raise_pre_ids_is_none(pre_ids) - pre_ids = as_device_array(pre_ids) + pre_ids = as_jax(pre_ids) pre_values = pre_values[pre_ids] return out.at[post_ids].multiply(pre_values) @@ -306,11 +301,11 @@ def pre2post_min(pre_values, post_num, post_ids, pre_ids=None): The value with the size of post-synaptic neurons. """ out = jnp.zeros(post_num) - pre_values = as_device_array(pre_values) - post_ids = as_device_array(post_ids) + pre_values = as_jax(pre_values) + post_ids = as_jax(post_ids) if jnp.ndim(pre_values) != 0: _raise_pre_ids_is_none(pre_ids) - pre_ids = as_device_array(pre_ids) + pre_ids = as_jax(pre_ids) pre_values = pre_values[pre_ids] return out.at[post_ids].min(pre_values) @@ -344,11 +339,11 @@ def pre2post_max(pre_values, post_num, post_ids, pre_ids=None): The value with the size of post-synaptic neurons. """ out = jnp.zeros(post_num) - pre_values = as_device_array(pre_values) - post_ids = as_device_array(post_ids) + pre_values = as_jax(pre_values) + post_ids = as_jax(post_ids) if jnp.ndim(pre_values) != 0: _raise_pre_ids_is_none(pre_ids) - pre_ids = as_device_array(pre_ids) + pre_ids = as_jax(pre_ids) pre_values = pre_values[pre_ids] return out.at[post_ids].max(pre_values) @@ -373,14 +368,14 @@ def pre2post_mean(pre_values, post_num, post_ids, pre_ids=None): The value with the size of post-synaptic neurons. """ out = jnp.zeros(post_num) - pre_values = as_device_array(pre_values) - post_ids = as_device_array(post_ids) + pre_values = as_jax(pre_values) + post_ids = as_jax(post_ids) if jnp.ndim(pre_values) == 0: return out.at[post_ids].set(pre_values) # return out.at[jnp.unique(post_ids)].set(pre_values) else: _raise_pre_ids_is_none(pre_ids) - pre_ids = as_device_array(pre_ids) + pre_ids = as_jax(pre_ids) pre_values = pre2syn(pre_values, pre_ids) return syn2post_mean(pre_values, post_ids, post_num) @@ -414,8 +409,8 @@ def pre2syn(pre_values, pre_ids): syn_val: jax.numpy.ndarray, JaxArray The synaptic value. """ - pre_values = as_device_array(pre_values) - pre_ids = as_device_array(pre_ids) + pre_values = as_jax(pre_values) + pre_ids = as_jax(pre_ids) if jnp.ndim(pre_values) == 0: return jnp.ones(len(pre_ids), dtype=pre_values.dtype) * pre_values else: @@ -454,8 +449,8 @@ def syn2post_sum(syn_values, post_ids, post_num: int, indices_are_sorted=False): post_val: jax.numpy.ndarray, JaxArray The post-synaptic value. """ - post_ids = as_device_array(post_ids) - syn_values = as_device_array(syn_values) + post_ids = as_jax(post_ids) + syn_values = as_jax(syn_values) if syn_values.dtype == jnp.bool_: syn_values = jnp.asarray(syn_values, dtype=jnp.int32) return _jit_seg_sum(syn_values, post_ids, post_num, indices_are_sorted) @@ -494,8 +489,8 @@ def syn2post_prod(syn_values, post_ids, post_num: int, indices_are_sorted=False) post_val: jax.numpy.ndarray, JaxArray The post-synaptic value. """ - post_ids = as_device_array(post_ids) - syn_values = as_device_array(syn_values) + post_ids = as_jax(post_ids) + syn_values = as_jax(syn_values) if syn_values.dtype == jnp.bool_: syn_values = jnp.asarray(syn_values, dtype=jnp.int32) return _jit_seg_prod(syn_values, post_ids, post_num, indices_are_sorted) @@ -531,8 +526,8 @@ def syn2post_max(syn_values, post_ids, post_num: int, indices_are_sorted=False): post_val: jax.numpy.ndarray, JaxArray The post-synaptic value. """ - post_ids = as_device_array(post_ids) - syn_values = as_device_array(syn_values) + post_ids = as_jax(post_ids) + syn_values = as_jax(syn_values) if syn_values.dtype == jnp.bool_: syn_values = jnp.asarray(syn_values, dtype=jnp.int32) return _jit_seg_max(syn_values, post_ids, post_num, indices_are_sorted) @@ -568,8 +563,8 @@ def syn2post_min(syn_values, post_ids, post_num: int, indices_are_sorted=False): post_val: jax.numpy.ndarray, JaxArray The post-synaptic value. """ - post_ids = as_device_array(post_ids) - syn_values = as_device_array(syn_values) + post_ids = as_jax(post_ids) + syn_values = as_jax(syn_values) if syn_values.dtype == jnp.bool_: syn_values = jnp.asarray(syn_values, dtype=jnp.int32) return _jit_seg_min(syn_values, post_ids, post_num, indices_are_sorted) @@ -596,8 +591,8 @@ def syn2post_mean(syn_values, post_ids, post_num: int, indices_are_sorted=False) post_val: jax.numpy.ndarray, JaxArray The post-synaptic value. """ - post_ids = as_device_array(post_ids) - syn_values = as_device_array(syn_values) + post_ids = as_jax(post_ids) + syn_values = as_jax(syn_values) if syn_values.dtype == jnp.bool_: syn_values = jnp.asarray(syn_values, dtype=jnp.int32) nominator = _jit_seg_sum(syn_values, post_ids, post_num, indices_are_sorted) @@ -626,8 +621,8 @@ def syn2post_softmax(syn_values, post_ids, post_num: int, indices_are_sorted=Fal post_val: jax.numpy.ndarray, JaxArray The post-synaptic value. """ - post_ids = as_device_array(post_ids) - syn_values = as_device_array(syn_values) + post_ids = as_jax(post_ids) + syn_values = as_jax(syn_values) if syn_values.dtype == jnp.bool_: syn_values = jnp.asarray(syn_values, dtype=jnp.int32) syn_maxs = _jit_seg_max(syn_values, post_ids, post_num, indices_are_sorted) diff --git a/brainpy/math/operators/multiplication.py b/brainpy/math/operators/sparse_matmul.py similarity index 69% rename from brainpy/math/operators/multiplication.py rename to brainpy/math/operators/sparse_matmul.py index af8dc9cf0..bceff338c 100644 --- a/brainpy/math/operators/multiplication.py +++ b/brainpy/math/operators/sparse_matmul.py @@ -1,16 +1,23 @@ # -*- coding: utf-8 -*- - -from typing import Union, Dict +from typing import Union, Dict, Tuple import jax.numpy as jnp -from jax import ops as jops +from jax import ops from brainpy.math.jaxarray import JaxArray -from brainpy.math.numpy_ops import _remove_jaxarray +from brainpy.math.numpy_ops import as_jax +from .utils import _check_brainpylib +from brainpy.types import Array + +try: + import brainpylib +except ModuleNotFoundError: + brainpylib = None __all__ = [ - 'sparse_matmul' + 'sparse_matmul', + 'csr_matvec', ] @@ -42,16 +49,16 @@ def _matmul_with_left_sparse( shape = sparse['shape'] if len(shape) != 2: raise ValueError(f'Sparse matrix must be a two-dimensional matrix. But we got {shape}') - values = _remove_jaxarray(values) - rows = _remove_jaxarray(rows) - cols = _remove_jaxarray(cols) - dense = _remove_jaxarray(dense) + values = as_jax(values) + rows = as_jax(rows) + cols = as_jax(cols) + dense = as_jax(dense) B = dense.take(cols, axis=0) if B.ndim == 2: prod = B * jnp.reshape(values, (-1, 1)) else: prod = B * values - return jops.segment_sum(prod, rows, shape[0]) + return ops.segment_sum(prod, rows, shape[0]) def _matmul_with_right_sparse( @@ -82,17 +89,17 @@ def _matmul_with_right_sparse( shape = sparse['shape'] if len(shape) != 2: raise ValueError(f'Sparse matrix must be a two-dimensional matrix. But we got {shape}') - values = _remove_jaxarray(values) - rows = _remove_jaxarray(rows) - cols = _remove_jaxarray(cols) - dense = _remove_jaxarray(dense) + values = as_jax(values) + rows = as_jax(rows) + cols = as_jax(cols) + dense = as_jax(dense) if dense.ndim == 2: A = dense[:, rows] prod = (A * values).T - res = jops.segment_sum(prod, cols, shape[1]).T + res = ops.segment_sum(prod, cols, shape[1]).T else: prod = dense[rows] * values - res = jops.segment_sum(prod, cols, shape[1]) + res = ops.segment_sum(prod, cols, shape[1]) return res @@ -164,3 +171,43 @@ def sparse_matmul(A, B): f'A:\n{A}\n' f'B:\n{B}') return _matmul_with_right_sparse(A, B) + + +def csr_matvec(values: Array, + indices: Array, + indptr: Array, + vector: Array, + shape: Tuple[int, ...], + transpose: bool = False): + """Product of CSR sparse matrix and a dense vector. + + Parameters + ---------- + values: Array + An array of shape ``(nse,)``. + indices: ndarray + An array of shape ``(nse,)``. + indptr: Array + An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``. + vector: Array + An array of shape ``(shape[0] if transpose else shape[1],)`` + and dtype ``data.dtype``. + shape: tuple of int + A length-2 tuple representing the matrix shape. + transpose: bool + A boolean specifying whether to transpose the sparse matrix + before computing. + + Returns + ------- + y : Array + The array of shape ``(shape[1] if transpose else shape[0],)`` representing + the matrix vector product. + """ + _check_brainpylib('pre2post_event_sum') + vector = as_jax(vector) + indices = as_jax(indices) + indptr = as_jax(indptr) + values = as_jax(values) + return brainpylib.csr_matvec(values, indices, indptr, vector, + shape=shape, transpose=transpose) diff --git a/brainpy/math/operators/tests/test_op_register.py b/brainpy/math/operators/tests/test_op_register.py index f9f99aea6..33362d909 100644 --- a/brainpy/math/operators/tests/test_op_register.py +++ b/brainpy/math/operators/tests/test_op_register.py @@ -23,9 +23,8 @@ def event_sum_op(outs, ins): outs[index] += v -event_sum = bm.register_op(name='event_sum', cpu_func=event_sum_op, eval_shape=abs_eval) -event_sum2 = bm.XLACustomOp(name='event_sum', cpu_func=event_sum_op, eval_shape=abs_eval) -event_sum = bm.jit(event_sum) +event_sum = bm.register_op(name='event_sum1', cpu_func=event_sum_op, eval_shape=abs_eval) +event_sum2 = bm.XLACustomOp(name='event_sum2', cpu_func=event_sum_op, eval_shape=abs_eval) class ExponentialSyn(bp.dyn.TwoEndConn): diff --git a/brainpy/math/operators/utils.py b/brainpy/math/operators/utils.py index a6143a437..40269d52f 100644 --- a/brainpy/math/operators/utils.py +++ b/brainpy/math/operators/utils.py @@ -8,7 +8,7 @@ brainpylib = None -_BRAINPYLIB_MINIMAL_VERSION = '0.0.7' +_BRAINPYLIB_MINIMAL_VERSION = '0.1.0' def _check_brainpylib(ops_name): diff --git a/docs/auto_generater.py b/docs/auto_generater.py index 0eeb8b3ec..e99e91227 100644 --- a/docs/auto_generater.py +++ b/docs/auto_generater.py @@ -493,7 +493,8 @@ def generate_math_docs(path='apis/auto/math/'): module_and_name = [ ('pre_syn_post', '``pre-syn-post`` Transformations',), - ('multiplication', 'Sparse Matrix Multiplication',), + ('sparse_matmul', 'Sparse Matrix Multiplication',), + ('event_matmul', 'Event-based Matrix Multiplication',), ('spikegrad', 'Surrogate Gradients for Spike Operation',), ('op_register', 'Operator Registration',), ('wrap_jax', 'Other Operators',),