Skip to content

Commit

Permalink
Merge pull request #529 from chaoming0625/master
Browse files Browse the repository at this point in the history
update taichi op customization
  • Loading branch information
chaoming0625 committed Nov 1, 2023
2 parents 3110da1 + 728dcbe commit 8d523da
Show file tree
Hide file tree
Showing 15 changed files with 317 additions and 307 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/CI-models.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install taichi-nightly -i https://pypi.taichi.graphics/simple/
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
pip uninstall brainpy -y
python setup.py install
Expand Down Expand Up @@ -79,6 +80,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install taichi-nightly -i https://pypi.taichi.graphics/simple/
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
pip uninstall brainpy -y
python setup.py install
Expand Down Expand Up @@ -128,6 +130,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install numpy>=1.21.0
pip install taichi-nightly -i https://pypi.taichi.graphics/simple/
python -m pip install -r requirements-dev.txt
python -m pip install tqdm brainpylib
pip uninstall brainpy -y
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest
pip install taichi-nightly -i https://pypi.taichi.graphics/simple/
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
pip uninstall brainpy -y
python setup.py install
Expand Down Expand Up @@ -102,6 +103,7 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest
pip install taichi-nightly -i https://pypi.taichi.graphics/simple/
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
pip uninstall brainpy -y
python setup.py install
Expand Down
11 changes: 1 addition & 10 deletions brainpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

__version__ = "2.4.5.post6"
__version__ = "2.4.6"
_minimal_brainpylib_version = '0.1.10'

# fundamental supporting modules
Expand All @@ -12,15 +12,6 @@
except ModuleNotFoundError:
raise ModuleNotFoundError(tools.jaxlib_install_info) from None


try:
import brainpylib
if brainpylib.__version__ < _minimal_brainpylib_version:
raise SystemError(f'This version of brainpy ({__version__}) needs brainpylib >= {_minimal_brainpylib_version}.')
del brainpylib
except ModuleNotFoundError:
pass

# Part: Math Foundation #
# ----------------------- #

Expand Down
25 changes: 18 additions & 7 deletions brainpy/_src/math/brainpylib_check.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,26 @@
from jax.lib import xla_client
import taichi as ti
import os
import platform
import ctypes

import taichi as ti
from jax.lib import xla_client

taichi_path = ti.__path__[0]
taichi_c_api_install_dir = os.path.join(taichi_path, '_lib', 'c_api')
import ctypes
try:
ctypes.CDLL(taichi_c_api_install_dir + '/lib/libtaichi_c_api.so')
except OSError:
print('taichi aot custom call, Only support linux now.')
os.environ['TAICHI_C_API_INSTALL_DIR'] = taichi_c_api_install_dir
os.environ['TI_LIB_DIR'] = os.path.join(taichi_c_api_install_dir, 'runtime')

# Link the Taichi C-API runtime library so the XLA custom calls registered
# below can resolve its symbols.  Only Windows and Linux ship a loadable
# library here; on other platforms this step is skipped.
# Fix: the Linux failure message previously named "/lib/taichi_c_api.dll",
# although the loader actually tried "/lib/libtaichi_c_api.so".
if platform.system() == 'Windows':
  _taichi_lib_path = taichi_c_api_install_dir + '/bin/taichi_c_api.dll'
  try:
    ctypes.CDLL(_taichi_lib_path)
  except OSError:
    raise OSError(f'Can not find {_taichi_lib_path}')
elif platform.system() == 'Linux':
  _taichi_lib_path = taichi_c_api_install_dir + '/lib/libtaichi_c_api.so'
  try:
    ctypes.CDLL(_taichi_lib_path)
  except OSError:
    raise OSError(f'Can not find {_taichi_lib_path}')

# Register the CPU XLA custom calls
try:
Expand Down
15 changes: 10 additions & 5 deletions brainpy/_src/math/op_register/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,18 @@

from brainpy._src.math.ndarray import Array
from brainpy._src.math.object_transform.base import BrainPyObject
from .numba_based import register_numba_cpu_translation_rule
from .taichi_based import (register_taichi_cpu_translation_rule,
register_taichi_gpu_translation_rule,
encode_md5,
preprocess_kernel_call_cpu,)
# if jax.__version__ >= '0.4.16':
# from .numba_based import register_numba_mlir_cpu_translation_rule as register_numba_cpu_translation_rule
# else:
# from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule
from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule
from .taichi_aot_based import (register_taichi_cpu_translation_rule,
register_taichi_gpu_translation_rule,
encode_md5,
preprocess_kernel_call_cpu, )
from .utils import register_general_batching


__all__ = [
'XLACustomOp',
]
Expand Down
117 changes: 82 additions & 35 deletions brainpy/_src/math/op_register/numba_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,20 @@
import ctypes
from functools import partial

from jax.interpreters import xla
from jax.interpreters import xla, mlir
from jax.lib import xla_client
from jaxlib.hlo_helpers import custom_call
from numba import types, carray, cfunc

from .utils import _shape_to_layout


__all__ = [
'register_numba_cpu_translation_rule',
'register_numba_xla_cpu_translation_rule',
'register_numba_mlir_cpu_translation_rule',
]

ctypes.pythonapi.PyCapsule_New.argtypes = [
ctypes.c_void_p, # void* pointer
ctypes.c_char_p, # const char *name
ctypes.c_void_p, # PyCapsule_Destructor destructor
]
ctypes.pythonapi.PyCapsule_New.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p]
ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object


Expand All @@ -27,12 +28,6 @@ def _cpu_signature(
output_shapes,
debug: bool = False
):
# kernel_key = str(id(kernel))
# input_keys = [f'{dtype}({shape})' for dtype, shape in zip(input_dtypes, input_shapes)]
# output_keys = [f'{dtype}({shape})' for dtype, shape in zip(output_dtypes, output_shapes)]
# key = f'{kernel_key}-ins=[{", ".join(input_keys)}]-outs=[{", ".join(output_keys)}]'
# if key not in __cache:

code_scope = dict(
func_to_call=kernel,
input_shapes=input_shapes,
Expand All @@ -42,7 +37,7 @@ def _cpu_signature(
carray=carray,
)

# inputs
# inputs, outputs, arguments
args_in = [f'in{i} = carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}])'
for i in range(len(input_shapes))]
args_out = [f'out{i} = carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}])'
Expand All @@ -51,33 +46,27 @@ def _cpu_signature(

# function body
code_string = '''
def xla_cpu_custom_call_target(output_ptrs, input_ptrs):
{args_in}
{args_out}
func_to_call({args_call})
'''.format(args_in="\n ".join(args_in),
args_out="\n ".join(args_out),
def xla_cpu_custom_call_target(output_ptrs, input_ptrs):
{args_in}
{args_out}
func_to_call({args_call})
'''.format(args_in="\n ".join(args_in),
args_out="\n ".join(args_out),
args_call=", ".join(args_call))
if debug: print(code_string)
exec(compile(code_string.strip(), '', 'exec'), code_scope)

# register
new_f = code_scope['xla_cpu_custom_call_target']
xla_c_rule = cfunc(types.void(types.CPointer(types.voidptr),
types.CPointer(types.voidptr)))(new_f)
target_name = xla_c_rule.native_name.encode("ascii")
capsule = ctypes.pythonapi.PyCapsule_New(
xla_c_rule.address, # A CFFI pointer to a function
b"xla._CUSTOM_CALL_TARGET", # A binary string
None # PyCapsule object run at destruction
)
xla_c_rule = cfunc(types.void(types.CPointer(types.voidptr), types.CPointer(types.voidptr)))(new_f)
target_name = f'numba_custom_call_{str(xla_c_rule.address)}'
capsule = ctypes.pythonapi.PyCapsule_New(xla_c_rule.address, b"xla._CUSTOM_CALL_TARGET", None)
xla_client.register_custom_call_target(target_name, capsule, "cpu")

# else:
# target_name = __cache[key]
return target_name


def _numba_cpu_translation_rule(prim, kernel, debug: bool, c, *ins):
def _numba_xla_cpu_translation_rule(prim, kernel, debug: bool, c, *ins):
outs = prim.abstract_eval()[0]

# output information
Expand All @@ -103,13 +92,71 @@ def _numba_cpu_translation_rule(prim, kernel, debug: bool, c, *ins):
# call
return xla_client.ops.CustomCallWithLayout(
c,
target_name,
target_name.encode("ascii"),
operands=tuple(ins),
operand_shapes_with_layout=input_layouts,
shape_with_layout=output_infos,
)


def register_numba_cpu_translation_rule(primitive, cpu_kernel, debug=False):
xla.backend_specific_translations['cpu'][primitive] = partial(_numba_cpu_translation_rule,
primitive, cpu_kernel, debug)
def register_numba_xla_cpu_translation_rule(primitive, cpu_kernel, debug=False):
  """Install the Numba-backed CPU lowering for ``primitive`` via the legacy XLA
  translation-rule table.

  Args:
    primitive: The JAX primitive to lower.
    cpu_kernel: The Numba-compatible Python kernel implementing the primitive.
    debug: If True, the generated wrapper source is printed during lowering.
  """
  rule = partial(_numba_xla_cpu_translation_rule, primitive, cpu_kernel, debug)
  xla.backend_specific_translations['cpu'][primitive] = rule


def _numba_mlir_cpu_translation_rule(kernel, debug: bool, ctx, *ins):
  """MLIR CPU lowering rule that wraps a Numba ``kernel`` as an XLA custom call.

  The kernel is compiled with ``numba.cfunc`` to a C-callable with the XLA
  custom-call signature ``void(void** outs, void** ins)``, registered with
  ``xla_client``, and then emitted as an MLIR ``custom_call`` op.

  Args:
    kernel: Numba-compatible Python function; called as
      ``kernel(in0, ..., out_k, ...)`` with ``carray`` views of the buffers.
    debug: if True, print the generated wrapper source code.
    ctx: the MLIR lowering context (provides ``avals_in`` / ``avals_out``).
    *ins: the MLIR values of the primitive's operands.

  Returns:
    The results of the emitted ``custom_call`` op.
  """
  # output information: shapes/dtypes/layouts and MLIR result types
  outs = ctx.avals_out
  output_shapes = tuple([out.shape for out in outs])
  output_dtypes = tuple([out.dtype for out in outs])
  output_layouts = tuple([_shape_to_layout(out.shape) for out in outs])
  result_types = [mlir.aval_to_ir_type(out) for out in outs]

  # input information
  avals_in = ctx.avals_in
  input_layouts = [_shape_to_layout(a.shape) for a in avals_in]
  input_dtypes = tuple(inp.dtype for inp in avals_in)
  input_shapes = tuple(inp.shape for inp in avals_in)

  # compiling function: build the source of a wrapper that converts the raw
  # void** buffer pointers into typed carray views and calls the user kernel
  code_scope = dict(func_to_call=kernel, input_shapes=input_shapes, input_dtypes=input_dtypes,
                    output_shapes=output_shapes, output_dtypes=output_dtypes, carray=carray)
  args_in = [f'in{i} = carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}])'
             for i in range(len(input_shapes))]
  args_out = [f'out{i} = carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}])'
              for i in range(len(output_shapes))]
  args_call = [f'in{i}' for i in range(len(input_shapes))] + [f'out{i}' for i in range(len(output_shapes))]
  code_string = '''
def numba_cpu_custom_call_target(output_ptrs, input_ptrs):
    {args_in}
    {args_out}
    func_to_call({args_call})
    '''.format(args_in="\n    ".join(args_in),
               args_out="\n    ".join(args_out),
               args_call=", ".join(args_call))
  if debug: print(code_string)
  # exec the generated source; the wrapper lands in code_scope
  exec(compile(code_string.strip(), '', 'exec'), code_scope)
  new_f = code_scope['numba_cpu_custom_call_target']

  # register: compile to a cfunc, wrap its address in a PyCapsule, and
  # register it as a CPU custom-call target keyed by the function address
  xla_c_rule = cfunc(types.void(types.CPointer(types.voidptr), types.CPointer(types.voidptr)))(new_f)
  target_name = f'numba_custom_call_{str(xla_c_rule.address)}'
  capsule = ctypes.pythonapi.PyCapsule_New(xla_c_rule.address, b"xla._CUSTOM_CALL_TARGET", None)
  xla_client.register_custom_call_target(target_name, capsule, "cpu")

  # call: emit the MLIR custom_call op referencing the registered target
  call = custom_call(call_target_name=target_name,
                     operands=list(ins),
                     operand_layouts=list(input_layouts),
                     result_layouts=list(output_layouts),
                     result_types=list(result_types)).results
  return call


def register_numba_mlir_cpu_translation_rule(primitive, cpu_kernel, debug=False):
  """Install the Numba-backed MLIR CPU lowering for ``primitive``.

  Args:
    primitive: The JAX primitive to lower.
    cpu_kernel: The Numba-compatible Python kernel implementing the primitive.
    debug: If True, the generated wrapper source is printed during lowering.
  """
  mlir.register_lowering(primitive,
                         partial(_numba_mlir_cpu_translation_rule, cpu_kernel, debug),
                         platform='cpu')


0 comments on commit 8d523da

Please sign in to comment.