Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions brainpy/math/operators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
# -*- coding: utf-8 -*-


from . import multiplication
from . import sparse_matmul, event_matmul
from . import op_register
from . import pre_syn_post as pre_syn_post_module
from . import wrap_jax
from . import spikegrad

__all__ = multiplication.__all__ + op_register.__all__
__all__ = event_matmul.__all__ + sparse_matmul.__all__ + op_register.__all__
__all__ += pre_syn_post_module.__all__ + wrap_jax.__all__ + spikegrad.__all__


from .multiplication import *
from .event_matmul import *
from .sparse_matmul import *
from .op_register import *
from .pre_syn_post import *
from .wrap_jax import *
Expand Down
57 changes: 57 additions & 0 deletions brainpy/math/operators/event_matmul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# -*- coding: utf-8 -*-


from typing import Tuple

from brainpy.math.numpy_ops import as_jax
from brainpy.types import Array
from .utils import _check_brainpylib

try:
import brainpylib
except ModuleNotFoundError:
brainpylib = None

__all__ = [
'event_csr_matvec',
]


def event_csr_matvec(values: Array,
indices: Array,
indptr: Array,
events: Array,
shape: Tuple[int, ...],
transpose: bool = False):
"""The pre-to-post event-driven synaptic summation with `CSR` synapse structure.

Parameters
----------
values: Array, float
An array of shape ``(nse,)`` or a float.
indices: Array
An array of shape ``(nse,)``.
indptr: Array
An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``.
events: Array
An array of shape ``(shape[0] if transpose else shape[1],)``
and dtype ``data.dtype``.
shape: tuple of int
A length-2 tuple representing the sparse matrix shape.
transpose: bool
A boolean specifying whether to transpose the sparse matrix
before computing. Default is False.

Returns
-------
out: Array
A tensor with the shape of ``shape[1]`` if `transpose=True`,
or ``shape[0]`` if `transpose=False`.
"""
_check_brainpylib('event_csr_matvec')
events = as_jax(events)
indices = as_jax(indices)
indptr = as_jax(indptr)
values = as_jax(values)
return brainpylib.event_csr_matvec(values, indices, indptr, events,
shape=shape, transpose=transpose)
36 changes: 23 additions & 13 deletions brainpy/math/operators/op_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Union, Sequence, Callable

from jax.abstract_arrays import ShapedArray
from jax.core import ShapedArray
from jax.tree_util import tree_map

from brainpy.base import Base
Expand Down Expand Up @@ -57,6 +57,10 @@ def __init__(
gpu_func: Callable = None,
apply_cpu_func_to_gpu: bool = False,
name: str = None,
batching_translation: Callable = None,
jvp_translation: Callable = None,
transpose_translation: Callable = None,
multiple_results: bool = False,
):
_check_brainpylib(register_op.__name__)
super(XLACustomOp, self).__init__(name=name)
Expand All @@ -77,19 +81,25 @@ def __init__(
gpu_func = None

# register OP
self.op = brainpylib.register_op(self.name,
cpu_func=cpu_func,
gpu_func=gpu_func,
out_shapes=eval_shape,
apply_cpu_func_to_gpu=apply_cpu_func_to_gpu)
self.op = brainpylib.register_op_with_numba(
self.name,
cpu_func=cpu_func,
gpu_func_translation=gpu_func,
out_shapes=eval_shape,
apply_cpu_func_to_gpu=apply_cpu_func_to_gpu,
batching_translation=batching_translation,
jvp_translation=jvp_translation,
transpose_translation=transpose_translation,
multiple_results=multiple_results,
)

def __call__(self, *args, **kwargs):
args = tree_map(lambda a: a.value if isinstance(a, JaxArray) else a,
args, is_leaf=lambda a: isinstance(a, JaxArray))
kwargs = tree_map(lambda a: a.value if isinstance(a, JaxArray) else a,
kwargs, is_leaf=lambda a: isinstance(a, JaxArray))
res = self.op.bind(*args, **kwargs)
return res[0] if len(res) == 1 else res
return res


def register_op(
Expand Down Expand Up @@ -122,15 +132,15 @@ def register_op(
A jitable JAX function.
"""
_check_brainpylib(register_op.__name__)
f = brainpylib.register_op(name,
cpu_func=cpu_func,
gpu_func=gpu_func,
out_shapes=eval_shape,
apply_cpu_func_to_gpu=apply_cpu_func_to_gpu)
f = brainpylib.register_op_with_numba(name,
cpu_func=cpu_func,
gpu_func_translation=gpu_func,
out_shapes=eval_shape,
apply_cpu_func_to_gpu=apply_cpu_func_to_gpu)

def fixed_op(*inputs, **info):
inputs = tuple([i.value if isinstance(i, JaxArray) else i for i in inputs])
res = f.bind(*inputs, **info)
return res[0] if len(res) == 1 else res
return res

return fixed_op
Loading