Commit

Update
Routhleck committed Mar 13, 2024
1 parent 23b5ab9 commit 8d8e30e
Showing 4 changed files with 90 additions and 3 deletions.
26 changes: 26 additions & 0 deletions brainpy/_src/dependency_check.py
@@ -8,6 +8,8 @@
  'raise_taichi_not_found',
  'import_numba',
  'raise_numba_not_found',
  'import_cupy',
  'raise_cupy_not_found',
  'import_brainpylib_cpu_ops',
  'import_brainpylib_gpu_ops',
]
@@ -17,6 +19,7 @@

numba = None
taichi = None
cupy = None
brainpylib_cpu_ops = None
brainpylib_gpu_ops = None

@@ -25,6 +28,8 @@
                       '> pip install taichi==1.7.0')
numba_install_info = ('We need numba. Please install numba by pip . \n'
                      '> pip install numba')
cupy_install_info = ('We need cupy. Please install cupy by pip . \n'
                     '> pip install cupy')
os.environ["TI_LOG_LEVEL"] = "error"


@@ -81,6 +86,27 @@ def raise_numba_not_found():
  raise ModuleNotFoundError(numba_install_info)


def import_cupy(error_if_not_found=True):
  """
  Internal API to import cupy.
  If cupy is not found, it will raise a ModuleNotFoundError if error_if_not_found is True,
  otherwise it will return None.
  """
  global cupy
  if cupy is None:
    try:
      import cupy as cupy
    except ModuleNotFoundError:
      if error_if_not_found:
        raise_cupy_not_found()
      else:
        return None
  return cupy


def raise_cupy_not_found():
  raise ModuleNotFoundError(cupy_install_info)

def is_brainpylib_gpu_installed():
  return False if brainpylib_gpu_ops is None else True

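For orientation, a minimal sketch of how this lazy-import helper is meant to be used by op-registration code: cupy is only resolved (or reported missing) at the point a GPU op actually needs it. The square_on_gpu function below is hypothetical and not part of this commit.

from brainpy._src.dependency_check import import_cupy, raise_cupy_not_found

cp = import_cupy(error_if_not_found=False)  # returns None instead of raising when cupy is absent


def square_on_gpu(x):
  # fail loudly only when the optional dependency is actually needed
  if cp is None:
    raise_cupy_not_found()
  return cp.asarray(x) ** 2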
9 changes: 6 additions & 3 deletions brainpy/_src/math/op_register/base.py
@@ -1,5 +1,5 @@
from functools import partial
- from typing import Callable, Sequence, Tuple, Protocol, Optional
+ from typing import Callable, Sequence, Tuple, Protocol, Optional, Union

import jax
import numpy as np
@@ -83,7 +83,7 @@ class XLACustomOp(BrainPyObject):
def __init__(
    self,
    cpu_kernel: Callable = None,
-     gpu_kernel: Callable = None,
+     gpu_kernel: Union[Callable, str] = None,
    batching_translation: Callable = None,
    jvp_translation: Callable = None,
    transpose_translation: Callable = None,
@@ -125,7 +125,10 @@ def __init__(
gpu_checked = False
if gpu_kernel is None:
  gpu_checked = True
- if hasattr(gpu_kernel, '_is_wrapped_kernel') and gpu_kernel._is_wrapped_kernel:  # taichi
+ elif isinstance(gpu_kernel, str):  # cupy
+   # TODO: register cupy translation rule
+   gpu_checked = True
+ elif hasattr(gpu_kernel, '_is_wrapped_kernel') and gpu_kernel._is_wrapped_kernel:  # taichi
  register_taichi_gpu_translation_rule(self.primitive, gpu_kernel)
  gpu_checked = True
if not gpu_checked:
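For reference, the widened gpu_kernel annotation (Union[Callable, str]) implies the call pattern sketched below. This assumes brainpy.math re-exports XLACustomOp; the CUDA source string is purely illustrative, and since the cupy translation rule is still a TODO in this commit, the op is only constructed, not lowered.

import brainpy.math as bm

# 'cuda_src' stands in for a raw CUDA kernel source string (a fuller example
# follows the cupy_based.py diff below); only the type dispatch is exercised here.
cuda_src = 'extern "C" __global__ void kernel(const float* x, float* y, int n) { }'

op = bm.XLACustomOp(gpu_kernel=cuda_src)  # a str is now accepted alongside taichi kernels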
40 changes: 40 additions & 0 deletions brainpy/_src/math/op_register/cupy_based.py
@@ -0,0 +1,40 @@
from jax.interpreters import xla, mlir
from jax.lib import xla_client
from jaxlib.hlo_helpers import custom_call
from functools import partial
from brainpy._src.dependency_check import (import_cupy)
from brainpy.errors import PackageMissingError

cp = import_cupy(error_if_not_found=False)


def _cupy_xla_gpu_translation_rule(kernel, c, *args, **kwargs):
  # TODO: implement the translation rule
  mod = cp.RawModule(code=kernel)
  # compile
  try:
    kernel_ptr = mod.get_function('kernel')
  except AttributeError:
    raise ValueError('The \'kernel\' function is not found in the module.')
  return xla_client.ops.CustomCallWithLayout(
    c,
    b'cupy_kernel_call_gpu',

  )
  ...


def register_cupy_xla_gpu_translation_rule(primitive, gpu_kernel):
  xla.backend_specific_translations['gpu'][primitive] = partial(_cupy_xla_gpu_translation_rule, gpu_kernel)


def _cupy_mlir_gpu_translation_rule(kernel, c, *args, **kwargs):
  # TODO: implement the translation rule
  ...

def register_cupy_mlir_gpu_translation_rule(primitive, gpu_kernel):
  if cp is None:
    raise PackageMissingError("cupy", 'register cupy mlir gpu translation rule')

  rule = partial(_cupy_mlir_gpu_translation_rule, gpu_kernel)
  mlir.register_lowering(primitive, rule, platform='gpu')
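To make the cupy calls above concrete, here is a self-contained sketch of RawModule/get_function outside the XLA plumbing. The kernel source and launch configuration are illustrative; the entry point is named 'kernel' only to mirror the hard-coded lookup in _cupy_xla_gpu_translation_rule.

import cupy as cp

src = r'''
extern "C" __global__ void kernel(const float* x, float* y, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) y[i] = x[i] * x[i];
}
'''

mod = cp.RawModule(code=src)           # JIT-compile the raw CUDA source
square = mod.get_function('kernel')    # look up the extern "C" __global__ entry point
x = cp.arange(8, dtype=cp.float32)
y = cp.empty_like(x)
square((1,), (8,), (x, y, cp.int32(8)))  # launch with grid=(1,), block=(8,)
print(y)  # [0. 1. 4. ... 49.]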
18 changes: 18 additions & 0 deletions brainpy/_src/math/op_register/tests/test_cupy_based.py
@@ -0,0 +1,18 @@
import jax.numpy as jnp
import jax
import cupy as cp
from time import time

import brainpy.math as bm
from brainpy._src.math import as_jax
bm.set_platform('gpu')

time1 = time()
a = bm.random.rand(4, 4)
time2 = time()
c = cp.from_dlpack(jax.dlpack.to_dlpack(as_jax(a)))
time3 = time()

c *= c
print(f'c: {c}')
print(f'a: {a}')
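A natural follow-up (not part of this test) is moving the CuPy result back into JAX over the same DLPack bridge; a sketch, noting that newer JAX versions also accept the CuPy array directly in from_dlpack:

b = jax.dlpack.from_dlpack(c.toDlpack())  # zero-copy view of the CuPy buffer as a jax.Array
print(f'b: {b}')  # expected to match a * a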
