Skip to content

Commit

Permalink
Merge pull request #527 from Routhleck/master
Browse files Browse the repository at this point in the history
[math] Implement taichi op register
  • Loading branch information
chaoming0625 committed Oct 31, 2023
2 parents 3aea8c1 + f45a454 commit 18bb51b
Show file tree
Hide file tree
Showing 3 changed files with 567 additions and 5 deletions.
30 changes: 29 additions & 1 deletion brainpy/_src/math/op_register/base.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
import inspect
import os
from functools import partial
from typing import Callable, Sequence, Tuple, Protocol, Optional

import jax
import numpy as np
import taichi as ti
from jax.interpreters import xla, batching, ad, mlir
from numba.core.dispatcher import Dispatcher

from brainpy._src.math.ndarray import Array
from brainpy._src.math.object_transform.base import BrainPyObject
from .numba_based import register_numba_cpu_translation_rule
from .taichi_based import (register_taichi_cpu_translation_rule,
register_taichi_gpu_translation_rule)
register_taichi_gpu_translation_rule,
encode_md5,
preprocess_kernel_call_cpu,)
from .utils import register_general_batching

__all__ = [
Expand Down Expand Up @@ -82,6 +87,10 @@ def __init__(
name: str = None,
):
super().__init__(name)

# set cpu_kernel and gpu_kernel
self.cpu_kernel = cpu_kernel
self.gpu_kernel = gpu_kernel

# primitive
self.primitive = jax.core.Primitive(self.name)
Expand Down Expand Up @@ -134,8 +143,17 @@ def _abstract_eval(self, *args, **kwargs):
return self.outs

def __call__(self, *ins, outs: Optional[Sequence[ShapeDtype]] = None):
# _set_taichi_envir()
if outs is not None:
self.outs = tuple([_transform_to_shapedarray(o) for o in outs])
cpu_kernel = getattr(self, "cpu_kernel", None)
if hasattr(cpu_kernel, '_is_wrapped_kernel') and cpu_kernel._is_wrapped_kernel: # taichi
source_md5_encode = encode_md5('cpu' + inspect.getsource(cpu_kernel) + \
str([(value.dtype, value.shape) for value in ins]) + \
str([(value.dtype, value.shape) for value in outs]))
new_ins = preprocess_kernel_call_cpu(source_md5_encode, ins, outs)
new_ins.extend(ins)
ins = new_ins
ins = jax.tree_util.tree_map(_transform_to_array, ins, is_leaf=_is_bp_array)
return self.primitive.bind(*ins)

Expand Down Expand Up @@ -206,3 +224,13 @@ def _transform_to_array(a):
def _transform_to_shapedarray(a):
return jax.core.ShapedArray(a.shape, a.dtype)

def _set_taichi_envir():
# find the path of taichi in python site_packages
taichi_path = ti.__path__[0]
taichi_c_api_install_dir = os.path.join(taichi_path, '_lib', 'c_api')
taichi_lib_dir = os.path.join(taichi_path, '_lib', 'runtime')
os.environ.update({
'TAICHI_C_API_INSTALL_DIR': taichi_c_api_install_dir,
'TI_LIB_DIR': taichi_lib_dir
})

0 comments on commit 18bb51b

Please sign in to comment.