Skip to content

Commit

Permalink
Merge pull request #531 from chaoming0625/master
Browse files Browse the repository at this point in the history
[math] remove the hard requirement of `taichi`
  • Loading branch information
chaoming0625 committed Nov 2, 2023
2 parents dcf91e3 + 121c7aa commit 1064116
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 66 deletions.
58 changes: 41 additions & 17 deletions brainpy/_src/math/brainpylib_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,49 @@
import platform
import ctypes

import taichi as ti
from jax.lib import xla_client

taichi_path = ti.__path__[0]
taichi_c_api_install_dir = os.path.join(taichi_path, '_lib', 'c_api')
os.environ['TAICHI_C_API_INSTALL_DIR'] = taichi_c_api_install_dir
os.environ['TI_LIB_DIR'] = os.path.join(taichi_c_api_install_dir, 'runtime')

# link DLL
if platform.system() == 'Windows':
try:
ctypes.CDLL(taichi_c_api_install_dir + '/bin/taichi_c_api.dll')
except OSError:
raise OSError(f'Does not found {taichi_c_api_install_dir + "/bin/taichi_c_api.dll"}')
elif platform.system() == 'Linux':
try:
ctypes.CDLL(taichi_c_api_install_dir + '/lib/libtaichi_c_api.so')
except OSError:
raise OSError(f'Does not found {taichi_c_api_install_dir + "/lib/taichi_c_api.dll"}')

try:
import taichi as ti
except (ImportError, ModuleNotFoundError):
ti = None


def import_taichi():
if ti is None:
raise ModuleNotFoundError(
'Taichi is needed. Please install taichi through:\n\n'
'> pip install -i https://pypi.taichi.graphics/simple/ taichi-nightly'
)
if ti.__version__ < (1, 7, 0):
raise RuntimeError(
'We need taichi>=1.7.0. Currently you can install taichi>=1.7.0 through taichi-nightly:\n\n'
'> pip install -i https://pypi.taichi.graphics/simple/ taichi-nightly'
)
return ti


if ti is None:
is_taichi_installed = False
else:
is_taichi_installed = True
taichi_path = ti.__path__[0]
taichi_c_api_install_dir = os.path.join(taichi_path, '_lib', 'c_api')
os.environ['TAICHI_C_API_INSTALL_DIR'] = taichi_c_api_install_dir
os.environ['TI_LIB_DIR'] = os.path.join(taichi_c_api_install_dir, 'runtime')

# link DLL
if platform.system() == 'Windows':
try:
ctypes.CDLL(taichi_c_api_install_dir + '/bin/taichi_c_api.dll')
except OSError:
raise OSError(f'Can not find {taichi_c_api_install_dir + "/bin/taichi_c_api.dll"}')
elif platform.system() == 'Linux':
try:
ctypes.CDLL(taichi_c_api_install_dir + '/lib/libtaichi_c_api.so')
except OSError:
raise OSError(f'Can not find {taichi_c_api_install_dir + "/lib/taichi_c_api.dll"}')

# Register the CPU XLA custom calls
try:
Expand Down
51 changes: 21 additions & 30 deletions brainpy/_src/math/op_register/taichi_aot_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
from functools import partial, reduce
from typing import Any

import jax.numpy as jnp
import numpy as np
import taichi as ti
from jax.interpreters import xla
from jax.lib import xla_client

import brainpy.math as bm
from .utils import _shape_to_layout
from ..brainpylib_check import import_taichi


### UTILS ###

Expand Down Expand Up @@ -122,25 +122,25 @@ def check_kernel_exist(source_md5_encode: str) -> bool:

### KERNEL AOT BUILD ###

# jnp dtype to taichi type
type_map4template = {
jnp.dtype("bool"): bool,
jnp.dtype("int8"): ti.int8,
jnp.dtype("int16"): ti.int16,
jnp.dtype("int32"): ti.int32,
jnp.dtype("int64"): ti.int64,
jnp.dtype("uint8"): ti.uint8,
jnp.dtype("uint16"): ti.uint16,
jnp.dtype("uint32"): ti.uint32,
jnp.dtype("uint64"): ti.uint64,
jnp.dtype("float16"): ti.float16,
jnp.dtype("float32"): ti.float32,
jnp.dtype("float64"): ti.float64,
}


def _array_to_field(dtype, shape) -> Any:
return ti.field(dtype=type_map4template[dtype], shape=shape)
ti = import_taichi()
if dtype == np.bool_:
dtype = bool
elif dtype == np.int8: dtype= ti.int8
elif dtype == np.int16: dtype= ti.int16
elif dtype == np.int32: dtype= ti.int32
elif dtype == np.int64: dtype= ti.int64
elif dtype == np.uint8: dtype= ti.uint8
elif dtype == np.uint16: dtype= ti.uint16
elif dtype == np.uint32: dtype= ti.uint32
elif dtype == np.uint64: dtype= ti.uint64
elif dtype == np.float16: dtype= ti.float16
elif dtype == np.float32: dtype= ti.float32
elif dtype == np.float64: dtype= ti.float64
else:
raise TypeError
return ti.field(dtype=dtype, shape=shape)


# build aot kernel
Expand All @@ -151,6 +151,8 @@ def build_kernel(
outs: dict,
device: str
):
ti = import_taichi()

# init arch
arch = None
if device == 'cpu':
Expand Down Expand Up @@ -191,17 +193,6 @@ def build_kernel(
int: 0,
float: 1,
bool: 2,
ti.int32: 0,
ti.float32: 1,
ti.u8: 3,
ti.u16: 4,
ti.u32: 5,
ti.u64: 6,
ti.i8: 7,
ti.i16: 8,
ti.i64: 9,
ti.f16: 10,
ti.f64: 11,
np.dtype('int32'): 0,
np.dtype('float32'): 1,
np.dtype('bool'): 2,
Expand Down
18 changes: 0 additions & 18 deletions brainpy/_src/tools/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,6 @@
except (ImportError, ModuleNotFoundError):
brainpylib = None

try:
import taichi as ti
except (ImportError, ModuleNotFoundError):
ti = None

__all__ = [
'import_numba',
Expand All @@ -27,20 +23,6 @@
]


def import_taichi():
if ti is None:
raise ModuleNotFoundError(
'Taichi is needed. Please install taichi through:\n\n'
'> pip install -i https://pypi.taichi.graphics/simple/ taichi-nightly'
)
if ti.__version__ < (1, 7, 0):
raise RuntimeError(
'We need taichi>=1.7.0. Currently you can install taichi>=1.7.0 through taichi-nightly:\n\n'
'> pip install -i https://pypi.taichi.graphics/simple/ taichi-nightly'
)
return ti


def import_numba():
if numba is None:
raise ModuleNotFoundError('Numba is needed. Please install numba through:\n\n'
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,3 @@ jax
tqdm
msgpack
numba
taichi

0 comments on commit 1064116

Please sign in to comment.