Skip to content

Commit

Permalink
Implement cupy based customized operators and Need to be tested
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Mar 14, 2024
1 parent 8d8e30e commit cc3a817
Show file tree
Hide file tree
Showing 3 changed files with 176 additions and 22 deletions.
7 changes: 5 additions & 2 deletions brainpy/_src/math/op_register/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
from .numba_based import register_numba_mlir_cpu_translation_rule as register_numba_cpu_translation_rule
from .taichi_aot_based import (register_taichi_aot_mlir_cpu_translation_rule as register_taichi_cpu_translation_rule,
register_taichi_aot_mlir_gpu_translation_rule as register_taichi_gpu_translation_rule)
from .cupy_based import register_cupy_mlir_gpu_translation_rule as register_cupy_gpu_translation_rule
else:
from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule
from .taichi_aot_based import (register_taichi_aot_xla_cpu_translation_rule as register_taichi_cpu_translation_rule,
register_taichi_aot_xla_gpu_translation_rule as register_taichi_gpu_translation_rule)
from .cupy_based import register_cupy_xla_gpu_translation_rule as register_cupy_gpu_translation_rule
from .utils import register_general_batching
from brainpy._src.math.op_register.ad_support import defjvp

Expand Down Expand Up @@ -125,14 +127,15 @@ def __init__(
gpu_checked = False
if gpu_kernel is None:
gpu_checked = True
elif gpu_kernel is str: # cupy
elif isinstance(gpu_kernel, str): # cupy
# TODO: register cupy translation rule
register_cupy_gpu_translation_rule(self.primitive, gpu_kernel)
gpu_checked = True
elif hasattr(gpu_kernel, '_is_wrapped_kernel') and gpu_kernel._is_wrapped_kernel: # taichi
register_taichi_gpu_translation_rule(self.primitive, gpu_kernel)
gpu_checked = True
if not gpu_checked:
raise ValueError(f'"cpu_kernel" must be a taichi kernel function. But we got {gpu_kernel}')
raise ValueError(f'"gpu_kernel" must be a taichi kernel function. But we got {gpu_kernel}')

# batching rule
if batching_translation is None:
Expand Down
135 changes: 123 additions & 12 deletions brainpy/_src/math/op_register/cupy_based.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,151 @@
from functools import partial, reduce
from typing import List

import jax
import numpy as np
from jax.interpreters import xla, mlir
from jax.lib import xla_client
from jaxlib.hlo_helpers import custom_call
from functools import partial
from brainpy._src.dependency_check import (import_cupy)

from brainpy._src.dependency_check import (import_cupy,
import_brainpylib_gpu_ops)
from brainpy._src.math.op_register.utils import _shape_to_layout
from brainpy.errors import PackageMissingError

# cupy is an optional dependency. NOTE(review): with error_if_not_found=False the
# helper presumably returns None instead of raising when cupy is not installed
# (the MLIR registration below checks `cp is None`) -- confirm in dependency_check.
cp = import_cupy(error_if_not_found=False)

# Integer codes used to describe output element types in the opaque descriptor.
# The Python scalar types reuse the codes of int32 / float32 / bool; the numpy
# dtypes are numbered by their position in the tuple below.
type_number_map = {
  int: 0,
  float: 1,
  bool: 2,
  **{
    np.dtype(dtype_name): code
    for code, dtype_name in enumerate((
      'int32',
      'float32',
      'bool',
      'uint8',
      'uint16',
      'uint32',
      'uint64',
      'int8',
      'int16',
      'int64',
      'float16',
      'float64',
    ))
  },
}


def _preprocess_kernel_call_gpu(
grid: int,
block: int,
func_ptr: int,
shared_mem: int,
*ins,
outs: List[jax.ShapeDtypeStruct],
):
grid = (grid + (1, 1))[:3]
block = (block + (1, 1))[:3]
in_num = len(ins)
out_num = len(outs)
in_out_num = [in_num, out_num]

out_type_list = [0] * out_num
out_elem_count_list = [0] * out_num

for i, value in enumerate(outs):
out_type_list[i] = type_number_map[value.dtype]
out_elem_count_list[i] = reduce(lambda x, y: x * y, value.shape)

grid = ",".join(str(i) for i in grid)
block = ",".join(str(i) for i in block)
in_out_num_str = ",".join(str(i) for i in in_out_num)
out_type_list_str = ",".join(str(i) for i in out_type_list)
out_elem_count_list_str = ",".join(str(i) for i in out_elem_count_list)

opaque = (bytes(str(func_ptr), encoding='utf-8') + b';' +
bytes(str(shared_mem), encoding='utf-8') + b';' +
bytes(in_out_num_str, encoding='utf-8') + b';' +
bytes(grid, encoding='utf-8') + b';' +
bytes(block, encoding='utf-8') + b';' +
bytes(out_type_list_str, encoding='utf-8') + b';' +
bytes(out_elem_count_list_str, encoding='utf-8') + b';')
return opaque


def _cupy_xla_gpu_translation_rule(kernel, c, *ins, **kwargs):
  """XLA GPU translation rule that launches a cupy-compiled raw kernel.

  Args:
    kernel: CUDA source code; it must define a function named ``kernel``.
    c: The XLA builder handed to the translation rule.
    *ins: The XLA operands forwarded to the custom call.
    **kwargs: Must contain ``grid``, ``block`` and ``outs`` (the output
      specifications, each exposing ``dtype`` and ``shape``); ``shared_mem``
      is optional and defaults to 0.

  Returns:
    The ``cupy_kernel_call_gpu`` custom-call op.

  Raises:
    ValueError: If ``grid`` or ``block`` is missing, or if looking up the
      ``kernel`` function raises ``AttributeError``.
  """
  grid = kwargs.get('grid', None)
  block = kwargs.get('block', None)
  shared_mem = kwargs.get('shared_mem', 0)
  if grid is None or block is None:
    raise ValueError('The grid and block should be specified for the cupy kernel.')

  # compile
  mod = cp.RawModule(code=kernel)
  try:
    kernel_func = mod.get_function('kernel')
  except AttributeError as e:
    raise ValueError('The \'kernel\' function is not found in the module.') from e

  # preprocess
  # NOTE(review): presumably this loads the brainpylib GPU ops that provide the
  # 'cupy_kernel_call_gpu' custom-call target -- confirm in dependency_check.
  import_brainpylib_gpu_ops()
  opaque = _preprocess_kernel_call_gpu(grid, block, kernel_func.kernel.ptr, shared_mem, *ins, outs=kwargs['outs'])

  # create custom call
  return xla_client.ops.CustomCallWithLayout(
    c,
    b'cupy_kernel_call_gpu',
    operands=ins,
    operand_shapes_with_layout=tuple(c.get_shape(value) for value in ins),
    shape_with_layout=xla_client.Shape.tuple_shape(
      [xla_client.Shape.array_shape(value.dtype, value.shape, _shape_to_layout(value.shape))
       for value in kwargs['outs']]
    ),
    opaque=opaque,
  )


def register_cupy_xla_gpu_translation_rule(primitive, gpu_kernel):
  """Register the cupy-based XLA translation rule of ``primitive`` for the GPU backend.

  Args:
    primitive: The JAX primitive to register the rule for.
    gpu_kernel: CUDA source code bound as the first argument of the rule.

  Raises:
    PackageMissingError: If cupy could not be imported.
  """
  # Fail at registration time, like the MLIR variant does, rather than with an
  # attribute error on ``None`` when the rule is first invoked.
  if cp is None:
    raise PackageMissingError("cupy", 'register cupy xla gpu translation rule')

  xla.backend_specific_translations['gpu'][primitive] = partial(_cupy_xla_gpu_translation_rule, gpu_kernel)


def _cupy_mlir_gpu_translation_rule(kernel, c, *ins, **kwargs):
  """MLIR GPU lowering rule that launches a cupy-compiled raw kernel.

  Args:
    kernel: CUDA source code; it must define a function named ``kernel``.
    c: The lowering context; its ``avals_in`` / ``avals_out`` give the operand
      and result abstract values.
    *ins: The operands forwarded to the custom call.
    **kwargs: Must contain ``grid``, ``block`` and ``outs`` (the output
      specifications, each exposing ``dtype`` and ``shape``); ``shared_mem``
      is optional and defaults to 0.

  Returns:
    The results of the ``cupy_kernel_call_gpu`` custom call.

  Raises:
    ValueError: If ``grid`` or ``block`` is missing, or if looking up the
      ``kernel`` function raises ``AttributeError``.
  """
  grid = kwargs.get('grid', None)
  block = kwargs.get('block', None)
  shared_mem = kwargs.get('shared_mem', 0)
  if grid is None or block is None:
    raise ValueError('The grid and block should be specified for the cupy kernel.')

  # compile
  mod = cp.RawModule(code=kernel)
  try:
    kernel_func = mod.get_function('kernel')
  except AttributeError as e:
    raise ValueError('The \'kernel\' function is not found in the module.') from e

  # preprocess
  # NOTE(review): presumably this loads the brainpylib GPU ops that provide the
  # 'cupy_kernel_call_gpu' custom-call target -- confirm in dependency_check.
  import_brainpylib_gpu_ops()
  opaque = _preprocess_kernel_call_gpu(grid, block, kernel_func.kernel.ptr, shared_mem, *ins, outs=kwargs['outs'])

  input_layouts = [_shape_to_layout(a.shape) for a in c.avals_in]
  result_types = [mlir.aval_to_ir_type(out) for out in c.avals_out]
  output_layouts = [_shape_to_layout(a.shape) for a in c.avals_out]

  return custom_call(
    call_target_name='cupy_kernel_call_gpu',
    operands=ins,
    operand_layouts=input_layouts,
    result_layouts=output_layouts,
    result_types=result_types,
    backend_config=opaque,
    has_side_effect=False,
  ).results


def register_cupy_mlir_gpu_translation_rule(primitive, gpu_kernel):
  """Register the cupy-based MLIR lowering rule of ``primitive`` for the GPU platform.

  Args:
    primitive: The JAX primitive to register the lowering for.
    gpu_kernel: CUDA source code bound as the first argument of the rule.

  Raises:
    PackageMissingError: If cupy could not be imported.
  """
  if cp is None:
    raise PackageMissingError("cupy", 'register cupy mlir gpu translation rule')

  rule = partial(_cupy_mlir_gpu_translation_rule, gpu_kernel)
  mlir.register_lowering(primitive, rule, platform='gpu')
56 changes: 48 additions & 8 deletions brainpy/_src/math/op_register/tests/test_cupy_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,54 @@

import brainpy.math as bm
from brainpy._src.math import as_jax

# The custom operator under test only registers a GPU kernel, so run on the GPU platform.
bm.set_platform('gpu')
def test_cupy_based():
  """Check that a raw CUDA kernel registered through ``XLACustomOp`` adds two arrays."""
  source_code = r'''
  extern "C"{
    __global__ void kernel(const float* x1, const float* x2, unsigned int N, float* y)
    {
      unsigned int tid = blockDim.x * blockIdx.x + threadIdx.x;
      if (tid < N)
      {
        y[tid] = x1[tid] + x2[tid];
      }
    }
  }
  '''
  N = 10
  x1 = bm.random.rand(N, N)
  x2 = bm.ones((N, N))

  prim = bm.XLACustomOp(gpu_kernel=source_code)

  # Total element count, handed to the operator as a one-element int32 array.
  # NOTE(review): the kernel declares ``N`` as a by-value ``unsigned int`` --
  # confirm how the custom call passes this operand to the kernel.
  n = jnp.asarray([N**2,], dtype=jnp.int32)

  # N blocks of N threads: one thread per element of the (N, N) output.
  y = prim(x1, x2, n, grid=(N,), block=(N,), outs=[jax.ShapeDtypeStruct((N, N), dtype=jnp.float32)])

  print(y)
  assert jnp.allclose(y, x1 + x2)


# Only run directly when executed as a script; test runners collect the
# function by name, and importing this module should not launch a GPU kernel.
if __name__ == '__main__':
  test_cupy_based()

0 comments on commit cc3a817

Please sign in to comment.