[math] Add taichi customized operators (event csrmv, csrmv, jitconn event mv, jitconn mv) (#553)

* Add _csr_matvec_taichi.py

* Test event csr matvec using taichi custom op

* Update _csr_matvec_taichi.py

* Add sparse csr matvec using taichi customized op

* Test event csr matvec using taichi customized op

* Implement autograd of event csr matvec using taichi customized op

* Update test of `test_event_csrmv_taichi.py`

* Update _csr_mv_taichi.py

* Test sparse csr matvec using taichi customized op

* Update test_csrmv_taichi.py

* Remove taichi event and sparse csrmv tests from pytest

* Fix autograd bug and update test_csrmv_taichi.py

* Fix autograd bug and update `test_event_csr_matvec_taichi.py`

* Fix event csr matvec kernel bug

* Fix test bugs

* Add taichi.func random generators

* Update `test_taichi_random.py`

* Implement `mv_prob_homo_taichi` and `mv_prob_uniform_taichi` (dense-equivalent semantics sketched after the changed-file summary below)

* Implement jitconn matvec using taichi customized op (needs testing)

* Fix bugs

* Remove pytest usage in 'test_taichi_random.py'

* Implement jitconn event matvec using taichi customized op (needs testing)

* Implement lfsr88 random generator algorithm (recurrence sketched after the changed-file summary below)

* Refactor `jitconn/_matvec_taichi.py` with lfsr88 random generator

* [csrmv taichi] format code and redefine JVP rules using the `.defjvp` interface (pattern sketched after the changed-file summary below)

* [csrmv taichi] format code of `brainpy.math.sparse.csrmv` and redefine JVP rules using the `.defjvp` interface

* [math] suppress taichi import logging by enforcing use of the `import_taichi()` utility; move taichi random functions into a separate file

* Fix missing file

* Optimize event csr matvec with taichi customized op and add taichi event csr matvec benchmark

* Update event_csrmv_taichi_VS_event_csrmv.py

* Optimize csr matvec with taichi customized op and add taichi csr matvec benchmark

* Fix bugs

* Add more benchmarks

* Update benchmarks

* Optimize taichi event csr matvec on GPU

* Update benchmarks

* Update benchmarks

* Update benchmarks

* Update benchmarks

* Optimize taichi customized CPU kernels for event csr matvec and csr matvec

* Add taichi jitconn matvec benchmark and optimize taichi jitconn matvec op

* Refactor taichi event matvec op

* Add taichi jitconn event matvec benchmark

* Optimize taichi jitconn matvec op on CPU backend

* Update taichi jitconn matvec op

* Update test files for taichi jitconn op

* Update taichi random generator

* Fix bugs

* Add new function for taichi random seed initialization

* Update taichi_random_time_test.py

* Update taichi_random_time_test.py

* Update taichi_random_time_test.py

* Fix bugs

* Remove taichi_random_time_test.py

* Update test_taichi_random.py

* [event csr taichi] small upgrade

* [csr mv taichi] fix bugs

* [math] new module `brainpy.math.tifunc` for taichi functionality

* [math] move default environment setting into `defaults.py`

* [math] fix and update taichi jitconn operators

---------

Co-authored-by: chaoming <adaduo@outlook.com>
Routhleck and chaoming0625 committed Dec 29, 2023
1 parent 9662fbb commit 6368289
Showing 34 changed files with 8,732 additions and 410 deletions.
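
The headline operators are taichi kernels that BrainPy compiles and registers as custom XLA calls through its operator-registration layer (`brainpy.math.op_register`, re-exported in `brainpy/_src/math/__init__.py` below). As a minimal sketch of the kind of CPU kernel involved (illustrative only, not the repository's exact kernel), a CSR matvec in taichi looks like:

  import taichi as ti

  ti.init(arch=ti.cpu)

  @ti.kernel
  def csrmv_kernel(values: ti.types.ndarray(ndim=1),
                   indices: ti.types.ndarray(ndim=1),
                   indptr: ti.types.ndarray(ndim=1),
                   vector: ti.types.ndarray(ndim=1),
                   out: ti.types.ndarray(ndim=1)):
    # taichi parallelizes the outermost loop: one CSR row per iteration.
    for i in range(out.shape[0]):
      r = 0.0
      for j in range(indptr[i], indptr[i + 1]):
        r += values[j] * vector[indices[j]]  # out[i] += A[i, col] * v[col]
      out[i] = r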
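
`mv_prob_homo_taichi` and `mv_prob_uniform_taichi` (referenced in the change list) compute a matvec against just-in-time connectivity: the random matrix is regenerated from a seed on every call instead of being stored. A dense JAX reference for the homogeneous-weight case (a semantics sketch; the function name and argument list are illustrative, not BrainPy's API):

  import jax
  import jax.numpy as jnp

  def mv_prob_homo_dense(vector, weight, conn_prob, seed, shape):
    # Materialize the Bernoulli(conn_prob) connectivity that the jitconn
    # operator only samples on the fly, then do an ordinary dense matvec.
    mask = jax.random.bernoulli(jax.random.PRNGKey(seed), conn_prob, shape)
    return weight * (mask.astype(vector.dtype) @ vector)

The taichi operators compute the same kind of product without ever materializing the full `shape[0] * shape[1]` mask in memory, which is the point of the jitconn family.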
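
The lfsr88 generator is L'Ecuyer's three-component combined Tausworthe generator (taus88, 1996). A pure-Python sketch of one step of the recurrence (the state handling here is illustrative; seeds must satisfy s1 > 1, s2 > 7, s3 > 15):

  def lfsr88_next(s1, s2, s3):
    # One step of the combined Tausworthe generator; the mask emulates
    # 32-bit wrap-around on Python's arbitrary-precision integers.
    M = 0xFFFFFFFF
    s1 = (((s1 & 0xFFFFFFFE) << 12) & M) ^ ((((s1 << 13) & M) ^ s1) >> 19)
    s2 = (((s2 & 0xFFFFFFF8) << 4) & M) ^ ((((s2 << 2) & M) ^ s2) >> 25)
    s3 = (((s3 & 0xFFFFFFF0) << 17) & M) ^ ((((s3 << 3) & M) ^ s3) >> 11)
    u = (s1 ^ s2 ^ s3) * 2.3283064365386963e-10  # scale by 2**-32 into [0, 1)
    return s1, s2, s3, u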
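
Several items above move the operators' forward-mode autodiff onto JAX's `defjvp` interface. The general pattern, shown on a plain dense matvec rather than BrainPy's custom primitives (illustrative):

  import jax
  import jax.numpy as jnp

  @jax.custom_jvp
  def matvec(data, v):
    return data @ v

  @matvec.defjvp
  def matvec_jvp(primals, tangents):
    data, v = primals
    d_data, d_v = tangents
    # Product rule for y = data @ v: dy = d_data @ v + data @ d_v.
    return matvec(data, v), d_data @ v + matvec(data, d_v)

A call like `jax.jvp(matvec, (A, x), (dA, dx))` then uses this rule instead of differentiating through the operator body.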
2 changes: 1 addition & 1 deletion brainpy/_src/dependency_check.py

@@ -17,7 +17,7 @@
 
 taichi_install_info = (f'We need taichi=={_minimal_taichi_version}. '
                        f'Currently you can install taichi=={_minimal_taichi_version} through:\n\n'
-                       '> pip install taichi==1.7.0 -U')
+                       '> pip install taichi==1.7.0')
 os.environ["TI_LOG_LEVEL"] = "error"
 
 
6 changes: 5 additions & 1 deletion brainpy/_src/deprecations.py

@@ -61,14 +61,18 @@ def new_func(*args, **kwargs):
   return new_func
 
 
-def deprecation_getattr(module, deprecations):
+def deprecation_getattr(module, deprecations, redirects=None):
+  redirects = redirects or {}
+
   def getattr(name):
     if name in deprecations:
       message, fn = deprecations[name]
       if fn is None:
         raise AttributeError(message)
       _deprecate(message)
       return fn
+    if name in redirects:
+      return redirects[name]
     raise AttributeError(f"module {module!r} has no attribute {name!r}")
 
   return getattr
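
The new `redirects` argument lets a module's `__getattr__` serve a fixed set of names directly, without touching the deprecation path. A plausible wiring, assuming the `defaults.redirects` mapping introduced below (the actual call site is not part of this excerpt):

  from brainpy._src.deprecations import deprecation_getattr
  from brainpy._src.math import defaults

  deprecations = {}  # e.g. {'old_name': ('use new_name instead', new_fn)}
  __getattr__ = deprecation_getattr(__name__, deprecations,
                                    redirects=defaults.redirects)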
2 changes: 1 addition & 1 deletion brainpy/_src/math/__init__.py

@@ -44,7 +44,7 @@
 from .compat_numpy import *
 from .compat_tensorflow import *
 from .others import *
-from . import random, linalg, fft
+from . import random, linalg, fft, tifunc
 
 # operators
 from .op_register import *
48 changes: 48 additions & 0 deletions brainpy/_src/math/defaults.py

@@ -0,0 +1,48 @@
+import jax.numpy as jnp
+from jax import config
+
+from brainpy._src.dependency_check import import_taichi
+from .modes import NonBatchingMode
+from .scales import IdScaling
+
+__all__ = ['mode', 'membrane_scaling', 'dt', 'bool_', 'int_', 'ti_int', 'float_', 'ti_float', 'complex_']
+
+ti = import_taichi()
+
+# Default computation mode.
+mode = NonBatchingMode()
+
+# '''Default membrane scaling.'''
+membrane_scaling = IdScaling()
+
+# '''Default time step.'''
+dt = 0.1
+
+# '''Default bool data type.'''
+bool_ = jnp.bool_
+
+# '''Default integer data type.'''
+int_ = jnp.int64 if config.read('jax_enable_x64') else jnp.int32
+
+# '''Default integer data type in Taichi.'''
+ti_int = ti.int64 if config.read('jax_enable_x64') else ti.int32
+
+# '''Default float data type.'''
+float_ = jnp.float64 if config.read('jax_enable_x64') else jnp.float32
+
+# '''Default float data type in Taichi.'''
+ti_float = ti.float64 if config.read('jax_enable_x64') else ti.float32
+
+# '''Default complex data type.'''
+complex_ = jnp.complex128 if config.read('jax_enable_x64') else jnp.complex64
+
+# redirects
+redirects = {'mode': mode,
+             'membrane_scaling': membrane_scaling,
+             'dt': dt,
+             'bool_': bool_,
+             'int_': int_,
+             'ti_int': ti_int,
+             'float_': float_,
+             'ti_float': ti_float,
+             'complex_': complex_}
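
Every x64-gated default above is evaluated once at import time from `jax_enable_x64`, so the flag has to be set before `brainpy.math` is first imported. A quick check (a sketch; assumes a fresh interpreter):

  from jax import config
  config.update('jax_enable_x64', True)  # must precede the brainpy.math import

  import brainpy.math as bm
  print(bm.int_, bm.float_)  # int64 / float64 under x64, 32-bit otherwise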
96 changes: 42 additions & 54 deletions brainpy/_src/math/environment.py

@@ -15,8 +15,10 @@
 
 from . import modes
 from . import scales
+from . import defaults
+from brainpy._src.dependency_check import import_taichi
 
-bm = None
+ti = import_taichi()
 
 __all__ = [
   # context manage for environment setting
@@ -389,9 +391,7 @@ def ditype():
   """
   # raise errors.NoLongerSupportError('\nGet default integer data type through `ditype()` has been deprecated. \n'
   #                                   'Use `brainpy.math.int_` instead.')
-  global bm
-  if bm is None: from brainpy import math as bm
-  return bm.int_
+  return defaults.int_
 
 
 def dftype():
@@ -403,9 +403,7 @@ def dftype():
 
   # raise errors.NoLongerSupportError('\nGet default floating data type through `dftype()` has been deprecated. \n'
   #                                   'Use `brainpy.math.float_` instead.')
-  global bm
-  if bm is None: from brainpy import math as bm
-  return bm.float_
+  return defaults.float_
 
 
 def set_float(dtype: type):
@@ -416,11 +414,17 @@ def set_float(dtype: type):
   dtype: type
     The float type.
   """
-  if dtype not in [jnp.float16, jnp.float32, jnp.float64, ]:
-    raise TypeError(f'Float data type {dtype} is not supported.')
-  global bm
-  if bm is None: from brainpy import math as bm
-  bm.__dict__['float_'] = dtype
+  if dtype in [jnp.float16, 'float16', 'f16']:
+    defaults.__dict__['float_'] = jnp.float16
+    defaults.__dict__['ti_float'] = ti.float16
+  elif dtype in [jnp.float32, 'float32', 'f32']:
+    defaults.__dict__['float_'] = jnp.float32
+    defaults.__dict__['ti_float'] = ti.float32
+  elif dtype in [jnp.float64, 'float64', 'f64']:
+    defaults.__dict__['float_'] = jnp.float64
+    defaults.__dict__['ti_float'] = ti.float64
+  else:
+    raise NotImplementedError
 
 
 def get_float():
@@ -431,9 +435,7 @@ def get_float():
   dftype: type
     The default float data type.
   """
-  global bm
-  if bm is None: from brainpy import math as bm
-  return bm.float_
+  return defaults.float_
 
 
 def set_int(dtype: type):
@@ -444,12 +446,20 @@ def set_int(dtype: type):
   dtype: type
     The integer type.
   """
-  if dtype not in [jnp.int8, jnp.int16, jnp.int32, jnp.int64,
-                   jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64, ]:
-    raise TypeError(f'Integer data type {dtype} is not supported.')
-  global bm
-  if bm is None: from brainpy import math as bm
-  bm.__dict__['int_'] = dtype
+  if dtype in [jnp.int8, 'int8', 'i8']:
+    defaults.__dict__['int_'] = jnp.int8
+    defaults.__dict__['ti_int'] = ti.int8
+  elif dtype in [jnp.int16, 'int16', 'i16']:
+    defaults.__dict__['int_'] = jnp.int16
+    defaults.__dict__['ti_int'] = ti.int16
+  elif dtype in [jnp.int32, 'int32', 'i32']:
+    defaults.__dict__['int_'] = jnp.int32
+    defaults.__dict__['ti_int'] = ti.int32
+  elif dtype in [jnp.int64, 'int64', 'i64']:
+    defaults.__dict__['int_'] = jnp.int64
+    defaults.__dict__['ti_int'] = ti.int64
+  else:
+    raise NotImplementedError
 
 
 def get_int():
@@ -460,9 +470,7 @@ def get_int():
   dftype: type
     The default int data type.
   """
-  global bm
-  if bm is None: from brainpy import math as bm
-  return bm.int_
+  return defaults.int_
 
 
 def set_bool(dtype: type):
@@ -473,9 +481,7 @@ def set_bool(dtype: type):
   dtype: type
     The bool type.
   """
-  global bm
-  if bm is None: from brainpy import math as bm
-  bm.__dict__['bool_'] = dtype
+  defaults.__dict__['bool_'] = dtype
 
 
 def get_bool():
@@ -486,9 +492,7 @@ def get_bool():
   dftype: type
     The default bool data type.
   """
-  global bm
-  if bm is None: from brainpy import math as bm
-  return bm.bool_
+  return defaults.bool_
 
 
 def set_complex(dtype: type):
@@ -499,9 +503,7 @@ def set_complex(dtype: type):
   dtype: type
     The complex type.
   """
-  global bm
-  if bm is None: from brainpy import math as bm
-  bm.__dict__['complex_'] = dtype
+  defaults.__dict__['complex_'] = dtype
 
 
 def get_complex():
@@ -512,9 +514,7 @@ def get_complex():
   dftype: type
     The default complex data type.
   """
-  global bm
-  if bm is None: from brainpy import math as bm
-  return bm.complex_
+  return defaults.complex_
 
 
 # numerical precision
@@ -529,9 +529,7 @@ def set_dt(dt):
     Numerical integration precision.
   """
   assert isinstance(dt, float), f'"dt" must a float, but we got {dt}'
-  global bm
-  if bm is None: from brainpy import math as bm
-  bm.__dict__['dt'] = dt
+  defaults.__dict__['dt'] = dt
 
 
 def get_dt():
@@ -542,9 +540,7 @@ def get_dt():
   dt : float
     Numerical integration precision.
   """
-  global bm
-  if bm is None: from brainpy import math as bm
-  return bm.dt
+  return defaults.dt
 
 
 def set_mode(mode: modes.Mode):
@@ -558,9 +554,7 @@ def set_mode(mode: modes.Mode):
   if not isinstance(mode, modes.Mode):
     raise TypeError(f'Must be instance of brainpy.math.Mode. '
                     f'But we got {type(mode)}: {mode}')
-  global bm
-  if bm is None: from brainpy import math as bm
-  bm.__dict__['mode'] = mode
+  defaults.__dict__['mode'] = mode
 
 
 def get_mode() -> modes.Mode:
@@ -571,9 +565,7 @@ def get_mode() -> modes.Mode:
   mode: Mode
     The default computing mode.
   """
-  global bm
-  if bm is None: from brainpy import math as bm
-  return bm.mode
+  return defaults.mode
 
 
 def set_membrane_scaling(membrane_scaling: scales.Scaling):
@@ -587,9 +579,7 @@ def set_membrane_scaling(membrane_scaling: scales.Scaling):
   if not isinstance(membrane_scaling, scales.Scaling):
     raise TypeError(f'Must be instance of brainpy.math.Scaling. '
                     f'But we got {type(membrane_scaling)}: {membrane_scaling}')
-  global bm
-  if bm is None: from brainpy import math as bm
-  bm.__dict__['membrane_scaling'] = membrane_scaling
+  defaults.__dict__['membrane_scaling'] = membrane_scaling
 
 
 def get_membrane_scaling() -> scales.Scaling:
@@ -600,9 +590,7 @@ def get_membrane_scaling() -> scales.Scaling:
   membrane_scaling: Scaling
     The default computing membrane_scaling.
   """
-  global bm
-  if bm is None: from brainpy import math as bm
-  return bm.membrane_scaling
+  return defaults.membrane_scaling
 
 
 def enable_x64(x64=None):
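
A short usage sketch of the reworked setters (based only on the branches shown above; the string aliases are new in this commit):

  import jax.numpy as jnp
  import brainpy.math as bm

  bm.set_float('float16')               # jnp.float16 and 'f16' are also accepted
  assert bm.get_float() == jnp.float16
  # The same call rebinds defaults.ti_float to ti.float16, keeping taichi
  # kernels and JAX arrays in matching precision.

  bm.set_int('i32')                     # int_ -> jnp.int32, ti_int -> ti.int32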
1 change: 1 addition & 0 deletions brainpy/_src/math/event/__init__.py

@@ -1,4 +1,5 @@
 
 from ._info_collection import *
 from ._csr_matvec import *
+from ._csr_matvec_taichi import *
 
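The newly exported taichi kernels sit alongside the existing `brainpy.math.event.csrmv` interface. A small example against that pre-existing interface (the argument layout follows the public documentation; treat it as a sketch):

  import numpy as np
  import brainpy.math as bm

  # A 2 x 3 CSR matrix (2 pre-synaptic, 3 post-synaptic) with three weights.
  indptr = np.array([0, 2, 3])
  indices = np.array([0, 2, 1])
  weights = np.array([1.5, 2.0, 0.5])
  events = bm.asarray([True, False])  # boolean spike vector

  # transpose=True accumulates the rows of the spiking pre-synaptic neurons.
  out = bm.event.csrmv(weights, indices, indptr, events,
                       shape=(2, 3), transpose=True)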