Skip to content

Commit

Permalink
Merge pull request #540 from chaoming0625/master
Browse files Browse the repository at this point in the history
[math] simplify the taichi AOT operator customization interface
  • Loading branch information
chaoming0625 committed Nov 9, 2023
2 parents 3cfa047 + bc0e2b5 commit 8e201f6
Show file tree
Hide file tree
Showing 19 changed files with 290 additions and 394 deletions.
94 changes: 94 additions & 0 deletions brainpy/_src/dependency_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from jax.lib import xla_client


__all__ = [
'import_taichi',
'import_brainpylib_cpu_ops',
'import_brainpylib_gpu_ops',
]


_minimal_brainpylib_version = '0.1.10'
_minimal_taichi_version = (1, 7, 0)

taichi = None
has_import_ti = False
brainpylib_cpu_ops = None
brainpylib_gpu_ops = None


def import_taichi():
global taichi, has_import_ti
if not has_import_ti:
try:
import taichi as taichi # noqa
has_import_ti = True
except ModuleNotFoundError:
raise ModuleNotFoundError(
'Taichi is needed. Please install taichi through:\n\n'
'> pip install -i https://pypi.taichi.graphics/simple/ taichi-nightly'
)

if taichi is None:
raise ModuleNotFoundError(
'Taichi is needed. Please install taichi through:\n\n'
'> pip install -i https://pypi.taichi.graphics/simple/ taichi-nightly'
)
if taichi.__version__ < _minimal_taichi_version:
raise RuntimeError(
f'We need taichi>={_minimal_taichi_version}. '
f'Currently you can install taichi>={_minimal_taichi_version} through taichi-nightly:\n\n'
'> pip install -i https://pypi.taichi.graphics/simple/ taichi-nightly'
)
return taichi


def is_brainpylib_gpu_installed():
return False if brainpylib_gpu_ops is None else True


def import_brainpylib_cpu_ops():
global brainpylib_cpu_ops
if brainpylib_cpu_ops is None:
try:
from brainpylib import cpu_ops as brainpylib_cpu_ops

for _name, _value in brainpylib_cpu_ops.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="cpu")

import brainpylib
if brainpylib.__version__ < _minimal_brainpylib_version:
raise SystemError(f'This version of brainpy needs brainpylib >= {_minimal_brainpylib_version}.')
if hasattr(brainpylib, 'check_brainpy_version'):
brainpylib.check_brainpy_version()

except ImportError:
raise ImportError('Please install brainpylib. \n'
'See https://brainpy.readthedocs.io for installation instructions.')

return brainpylib_cpu_ops


def import_brainpylib_gpu_ops():
global brainpylib_gpu_ops
if brainpylib_gpu_ops is None:
try:
from brainpylib import gpu_ops as brainpylib_gpu_ops

for _name, _value in brainpylib_gpu_ops.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="gpu")

import brainpylib
if brainpylib.__version__ < _minimal_brainpylib_version:
raise SystemError(f'This version of brainpy needs brainpylib >= {_minimal_brainpylib_version}.')
if hasattr(brainpylib, 'check_brainpy_version'):
brainpylib.check_brainpy_version()

except ImportError:
raise ImportError('Please install GPU version of brainpylib. \n'
'See https://brainpy.readthedocs.io for installation instructions.')

return brainpylib_gpu_ops



3 changes: 0 additions & 3 deletions brainpy/_src/math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@
#


from . import brainpylib_check

# data structure
from .ndarray import *
from .delayvars import *
Expand Down Expand Up @@ -62,4 +60,3 @@
from .environment import *
from .scales import *

del brainpylib_check
80 changes: 0 additions & 80 deletions brainpy/_src/math/brainpylib_check.py

This file was deleted.

7 changes: 2 additions & 5 deletions brainpy/_src/math/event/_csr_matvec.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,9 @@
register_general_batching)
from brainpy._src.math.sparse._csr_mv import csrmv as normal_csrmv
from brainpy._src.math.sparse._utils import csr_to_coo
from brainpy._src.dependency_check import (import_brainpylib_gpu_ops)
from brainpy.errors import GPUOperatorNotFound

try:
from brainpylib import gpu_ops
except ImportError:
gpu_ops = None

__all__ = [
'csrmv'
]
Expand Down Expand Up @@ -455,6 +451,7 @@ def _event_csr_matvec_cpu_translation(c, values, indices, indptr, events, *, sha


def _event_csr_matvec_gpu_translation(c, data, indices, indptr, vector, *, shape, transpose):
gpu_ops = import_brainpylib_gpu_ops()
if gpu_ops is None:
raise GPUOperatorNotFound(event_csr_matvec_p.name)

Expand Down
8 changes: 3 additions & 5 deletions brainpy/_src/math/event/_info_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,10 @@

from brainpy._src.math.interoperability import as_jax
from brainpy._src.math.op_register import register_op_with_numba
from brainpy.errors import GPUOperatorNotFound
from brainpy._src.math.ndarray import Array
from brainpy._src.dependency_check import import_brainpylib_gpu_ops
from brainpy.errors import GPUOperatorNotFound

try:
from brainpylib import gpu_ops
except ImportError:
gpu_ops = None

__all__ = [
'info'
Expand Down Expand Up @@ -79,6 +76,7 @@ def _batch_event_info_batching_rule(args, axes):


def _event_info_gpu_translation(c, events):
gpu_ops = import_brainpylib_gpu_ops()
if gpu_ops is None:
raise GPUOperatorNotFound(event_info_p.name)

Expand Down
7 changes: 0 additions & 7 deletions brainpy/_src/math/event/tests/test_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,8 @@
import brainpy.math as bm
from jax import vmap



import brainpylib as bl
import pytest

if bl.__version__ < '0.1.9':
pytest.skip('Need brainpylib>=0.1.9', allow_module_level=True)



class Test_event_info(unittest.TestCase):
def __init__(self, *args, platform='cpu', **kwargs):
Expand Down
12 changes: 7 additions & 5 deletions brainpy/_src/math/jitconn/_event_matvec.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from jax.interpreters import xla, ad
from jax.lib import xla_client

from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_brainpylib_cpu_ops
from brainpy._src.math.interoperability import as_jax
from brainpy._src.math.jitconn._matvec import (mv_prob_homo_p,
mv_prob_uniform_p,
Expand All @@ -21,11 +22,6 @@
from brainpy._src.math.op_register import register_general_batching
from brainpy.errors import GPUOperatorNotFound

try:
from brainpylib import gpu_ops
except ImportError:
gpu_ops = None

__all__ = [
'event_mv_prob_homo',
'event_mv_prob_uniform',
Expand Down Expand Up @@ -167,6 +163,7 @@ def _event_matvec_prob_homo_abstract(
def _event_matvec_prob_homo_cpu_translation(
c, events, weight, clen, seed, *, shape, transpose, outdim_parallel
):
import_brainpylib_cpu_ops()
n_row, n_col = (shape[1], shape[0]) if transpose else shape
out_dtype, event_type, type_name = _get_types(c.get_shape(events))

Expand Down Expand Up @@ -201,6 +198,7 @@ def _event_matvec_prob_homo_cpu_translation(
def _event_matvec_prob_homo_gpu_translation(
c, events, weight, clen, seed, *, shape, transpose, outdim_parallel
):
gpu_ops = import_brainpylib_gpu_ops()
if gpu_ops is None:
raise GPUOperatorNotFound(event_mv_prob_homo_p.name)

Expand Down Expand Up @@ -349,6 +347,7 @@ def _event_matvec_prob_uniform_abstract(
def _event_matvec_prob_uniform_cpu_translation(
c, events, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel
):
import_brainpylib_cpu_ops()
n_row, n_col = (shape[1], shape[0]) if transpose else shape

out_dtype, event_type, type_name = _get_types(c.get_shape(events))
Expand Down Expand Up @@ -385,6 +384,7 @@ def _event_matvec_prob_uniform_cpu_translation(
def _event_matvec_prob_uniform_gpu_translation(
c, events, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel
):
gpu_ops = import_brainpylib_gpu_ops()
if gpu_ops is None:
raise GPUOperatorNotFound(event_mv_prob_uniform_p.name)

Expand Down Expand Up @@ -541,6 +541,7 @@ def _get_types(event_shape):
def _event_matvec_prob_normal_cpu_translation(
c, events, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel
):
import_brainpylib_cpu_ops()
n_row, n_col = (shape[1], shape[0]) if transpose else shape

out_dtype, event_type, type_name = _get_types(c.get_shape(events))
Expand Down Expand Up @@ -577,6 +578,7 @@ def _event_matvec_prob_normal_cpu_translation(
def _event_matvec_prob_normal_gpu_translation(
c, events, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel
):
gpu_ops = import_brainpylib_gpu_ops()
if gpu_ops is None:
raise GPUOperatorNotFound(event_mv_prob_normal_p.name)

Expand Down
15 changes: 8 additions & 7 deletions brainpy/_src/math/jitconn/_matvec.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# -*- coding: utf-8 -*-


import math
from functools import partial
from typing import Tuple, Optional, Union

Expand All @@ -12,15 +11,11 @@
from jax.interpreters import xla, ad
from jax.lib import xla_client

from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_brainpylib_cpu_ops
from brainpy._src.math.interoperability import as_jax
from brainpy._src.math.ndarray import Array, _get_dtype
from brainpy._src.math.op_register import register_general_batching
from brainpy.errors import GPUOperatorNotFound, MathError

try:
from brainpylib import gpu_ops
except ImportError:
gpu_ops = None
from brainpy.errors import GPUOperatorNotFound

__all__ = [
'mv_prob_homo',
Expand Down Expand Up @@ -304,6 +299,7 @@ def _matvec_prob_homo_abstract(
def _matvec_prob_homo_cpu_translation(
c, vector, weight, clen, seed, *, shape, transpose, outdim_parallel
):
import_brainpylib_cpu_ops()
n_row, n_col = (shape[1], shape[0]) if transpose else shape

vec_shape = c.get_shape(vector)
Expand Down Expand Up @@ -345,6 +341,7 @@ def _matvec_prob_homo_cpu_translation(
def _matvec_prob_homo_gpu_translation(
c, vector, weight, clen, seed, *, shape, transpose, outdim_parallel
):
gpu_ops = import_brainpylib_gpu_ops()
if gpu_ops is None:
raise GPUOperatorNotFound(mv_prob_homo_p.name)

Expand Down Expand Up @@ -492,6 +489,7 @@ def _matvec_prob_uniform_abstract(
def _matvec_prob_uniform_cpu_translation(
c, vector, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel
):
import_brainpylib_cpu_ops()
n_row, n_col = (shape[1], shape[0]) if transpose else shape

vec_shape = c.get_shape(vector)
Expand Down Expand Up @@ -537,6 +535,7 @@ def _matvec_prob_uniform_cpu_translation(
def _matvec_prob_uniform_gpu_translation(
c, vector, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel
):
gpu_ops = import_brainpylib_gpu_ops()
if gpu_ops is None:
raise GPUOperatorNotFound(mv_prob_homo_p.name)

Expand Down Expand Up @@ -672,6 +671,7 @@ def _matvec_prob_normal_abstract(
def _matvec_prob_normal_cpu_translation(
c, vector, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel
):
import_brainpylib_cpu_ops()
n_row, n_col = (shape[1], shape[0]) if transpose else shape

vec_shape = c.get_shape(vector)
Expand Down Expand Up @@ -717,6 +717,7 @@ def _matvec_prob_normal_cpu_translation(
def _matvec_prob_normal_gpu_translation(
c, vector, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel
):
gpu_ops = import_brainpylib_gpu_ops()
if gpu_ops is None:
raise GPUOperatorNotFound(mv_prob_homo_p.name)

Expand Down

0 comments on commit 8e201f6

Please sign in to comment.