Skip to content

Commit

Permalink
[taichi] Make taichi caches more transparent and Add clean caches fun…
Browse files Browse the repository at this point in the history
…ction (#596)

* [taichi] Make taichi caches more transparent and Add clean caches function

* Update clean caches function

* Fix bugs

* Update test_taichi_clean_cache.py

* Remove taichi kernels cache size check

* Update operator_custom_with_taichi.ipynb
  • Loading branch information
Routhleck committed Jan 17, 2024
1 parent 02b85b2 commit c2f2db9
Show file tree
Hide file tree
Showing 6 changed files with 162 additions and 56 deletions.
1 change: 1 addition & 0 deletions brainpy/_src/math/op_register/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
from .numba_approach import (CustomOpByNumba,
register_op_with_numba,
compile_cpu_signature_with_numba)
from .taichi_aot_based import clean_caches, check_kernels_count
from .base import XLACustomOp
from .utils import register_general_batching
4 changes: 3 additions & 1 deletion brainpy/_src/math/op_register/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
# from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule
from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule
from .taichi_aot_based import (register_taichi_cpu_translation_rule,
register_taichi_gpu_translation_rule,)
register_taichi_gpu_translation_rule,
clean_caches)
from .utils import register_general_batching
from brainpy._src.math.op_register.ad_support import defjvp

Expand Down Expand Up @@ -138,6 +139,7 @@ def __init__(
if transpose_translation is not None:
ad.primitive_transposes[self.primitive] = transpose_translation


def __call__(self, *ins, outs: Optional[Sequence[ShapeDtype]] = None, **kwargs):
if outs is None:
outs = self.outs
Expand Down
41 changes: 38 additions & 3 deletions brainpy/_src/math/op_register/taichi_aot_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pathlib
import platform
import re
import shutil
from functools import partial, reduce
from typing import Any, Sequence

Expand Down Expand Up @@ -36,6 +37,34 @@ def encode_md5(source: str) -> str:

return md5.hexdigest()

def check_kernels_count() -> int:
  """Count the entries stored in the Taichi kernel cache.

  Every sub-directory found directly under ``kernels_aot_path`` is opened
  and the number of entries it contains is added to the total.
  (Presumably each sub-directory is named after a kernel and holds one
  entry per compiled variant -- confirm against the cache writer.)

  Entries of ``kernels_aot_path`` that are not directories (for example
  stray files dropped there by the OS or an editor) are ignored instead of
  raising ``NotADirectoryError``.

  Returns
  -------
  int
    The total number of cached entries, or ``0`` when the cache folder
    does not exist or is not a directory.
  """
  if not os.path.isdir(kernels_aot_path):
    return 0
  kernels_count = 0
  with os.scandir(kernels_aot_path) as cache_entries:
    for entry in cache_entries:
      # Only directories can hold cached kernels; skip anything else.
      if entry.is_dir():
        kernels_count += len(os.listdir(entry.path))
  return kernels_count

def clean_caches(kernels_name: Sequence[str] = None):
  """Remove cached Taichi kernels from ``kernels_aot_path``.

  Parameters
  ----------
  kernels_name : sequence of str, or str, optional
    Names of the kernel sub-directories to delete. A single string is
    treated as one kernel name. When ``None`` (the default) the whole
    cache folder is removed.

  Raises
  ------
  FileNotFoundError
    If ``kernels_name`` is ``None`` and the cache folder does not exist,
    or if one of the named kernels has no cache directory. Kernels listed
    before the missing one have already been removed at that point.
  """
  if kernels_name is None:
    if not os.path.exists(kernels_aot_path):
      # Adjacent literals keep source indentation out of the message.
      raise FileNotFoundError(
        "The kernels cache folder does not exist. "
        "Please define a kernel using `taichi.kernel` "
        "and customize the operator using `bm.XLACustomOp` "
        "before calling the operator."
      )
    shutil.rmtree(kernels_aot_path)
    print('Clean all kernel\'s cache successfully')
    return
  # A bare string would otherwise be iterated character by character.
  if isinstance(kernels_name, str):
    kernels_name = [kernels_name]
  for kernel_name in kernels_name:
    try:
      shutil.rmtree(os.path.join(kernels_aot_path, kernel_name))
    except FileNotFoundError as err:
      raise FileNotFoundError(f'Kernel {kernel_name} does not exist.') from err
  print('Clean kernel\'s cache successfully')

# TODO
# not a very good way
Expand Down Expand Up @@ -151,6 +180,9 @@ def _build_kernel(
if ti.lang.impl.current_cfg().arch != arch:
raise RuntimeError(f"Arch {arch} is not available")

# get kernel name
kernel_name = kernel.__name__

# replace the name of the func
kernel.__name__ = f'taichi_kernel_{device}'

Expand All @@ -170,6 +202,9 @@ def _build_kernel(
mod.add_kernel(kernel, template_args=template_args_dict)
mod.save(kernel_path)

# rename kernel name
kernel.__name__ = kernel_name


### KERNEL CALL PREPROCESS ###

Expand Down Expand Up @@ -246,7 +281,7 @@ def _preprocess_kernel_call_cpu(
return in_out_info


def preprocess_kernel_call_gpu(
def _preprocess_kernel_call_gpu(
source_md5_encode: str,
ins: dict,
outs: dict,
Expand Down Expand Up @@ -312,7 +347,7 @@ def _compile_kernel(kernel, c, platform, *ins, **kwargs):

# kernel to code
codes = _kernel_to_code(kernel, abs_ins, abs_outs, platform)
source_md5_encode = encode_md5(codes)
source_md5_encode = kernel.__name__ + '/' + encode_md5(codes)

# create ins, outs dict from kernel's args
in_num = len(ins)
Expand All @@ -332,7 +367,7 @@ def _compile_kernel(kernel, c, platform, *ins, **kwargs):
# returns
if platform in ['gpu', 'cuda']:
import_brainpylib_gpu_ops()
opaque = preprocess_kernel_call_gpu(source_md5_encode, ins_dict, outs_dict)
opaque = _preprocess_kernel_call_gpu(source_md5_encode, ins_dict, outs_dict)
return opaque
elif platform == 'cpu':
import_brainpylib_cpu_ops()
Expand Down
54 changes: 54 additions & 0 deletions brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import brainpy.math as bm
import jax
import jax.numpy as jnp
import platform
import pytest
import taichi

# Skip every test in this module unless ``platform.platform()`` starts with
# 'Windows'.
# NOTE(review): as written this disables the module on Linux and macOS --
# confirm that running on Windows only is intended.
if not platform.platform().startswith('Windows'):
  pytest.skip(allow_module_level=True)

# Taichi helper: returns element 0 of the 1-D ``weight`` array.
@taichi.func
def get_weight(weight: taichi.types.ndarray(ndim=1)) -> taichi.f32:
  return weight[0]


# Taichi helper: adds ``weight_val`` in place to ``out[index]``.
@taichi.func
def update_output(out: taichi.types.ndarray(ndim=1), index: taichi.i32, weight_val: taichi.f32):
  out[index] += weight_val

# Taichi kernel used as the CPU implementation of the custom operator below.
# For every row ``i`` of ``indices`` whose ``vector[i]`` is truthy, it adds
# ``weight[0]`` to ``out[indices[i, j]]`` for each column ``j``.
# ``taichi.loop_config(serialize=True)`` is applied to the outer row loop;
# presumably this makes the loop run serially so the ``+=`` updates on
# ``out`` do not race -- confirm against the Taichi documentation.
@taichi.kernel
def event_ell_cpu(indices: taichi.types.ndarray(ndim=2),
                  vector: taichi.types.ndarray(ndim=1),
                  weight: taichi.types.ndarray(ndim=1),
                  out: taichi.types.ndarray(ndim=1)):
  weight_val = get_weight(weight)
  num_rows, num_cols = indices.shape
  taichi.loop_config(serialize=True)
  for i in range(num_rows):
    if vector[i]:
      for j in range(num_cols):
        update_output(out, indices[i, j], weight_val)

# Wrap the kernel as a custom operator; only a CPU kernel is registered here.
prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu)

def test_taichi_clean_cache():
  """Run the custom operator, then check that ``bm.clean_caches`` empties the cache.

  The operator is called twice with identical arguments, as in the original
  test, before the cache is inspected. The kernel count reported by
  ``bm.check_kernels_count`` is asserted before and after cleaning, so a
  ``clean_caches`` that leaves entries behind fails the test instead of
  only printing a number.
  """
  s = 1000
  indices = bm.random.randint(0, s, (s, 1000))
  vector = bm.random.rand(s) < 0.1
  weight = bm.array([1.0])

  out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])

  out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])

  print(out)
  bm.clear_buffer_memory()

  count_before = bm.check_kernels_count()
  print('kernels: ', count_before)
  # Running the operator above is expected to have cached its kernel.
  assert count_before > 0

  bm.clean_caches()

  count_after = bm.check_kernels_count()
  print('kernels: ', count_after)
  # Cleaning without arguments removes the whole cache folder.
  assert count_after == 0

# test_taichi_clean_cache()
2 changes: 2 additions & 0 deletions brainpy/math/op_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from brainpy._src.math.op_register import (
CustomOpByNumba,
compile_cpu_signature_with_numba,
clean_caches,
check_kernels_count,
)

from brainpy._src.math.op_register.base import XLACustomOp
Expand Down
116 changes: 64 additions & 52 deletions docs/tutorial_advanced/operator_custom_with_taichi.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
},
{
"cell_type": "markdown",
"source": [
"This functionality is only available for ``brainpylib>=0.2.0``. "
],
"metadata": {
"collapsed": false
}
},
"source": [
"This functionality is only available for ``brainpylib>=0.2.0``. "
]
},
{
"cell_type": "markdown",
Expand Down Expand Up @@ -182,26 +182,6 @@
" # If the kernel is run on the CUDA backend, each block will have 16 threads.\n",
" for i in range(n):\n",
" val[i] = i\n",
"```\n",
"\n",
"#### `ti.grouped`\n",
"Groups the indices in the iterator returned by ndrange() into a 1-D vector.\n",
"This is often used when you want to iterate over all indices returned by ndrange() in one for loop and a single index.\n",
"\n",
"Example:\n",
"\n",
"```python\n",
"# without ti.grouped\n",
"for I in ti.ndrange(2, 3):\n",
" print(I)\n",
"prints 0, 1, 2, 3, 4, 5\n",
"```\n",
"\n",
"```python\n",
"# with ti.grouped\n",
"for I in ti.grouped(ndrange(2, 3)):\n",
" print(I)\n",
"prints [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]\n",
"```"
]
},
Expand Down Expand Up @@ -251,11 +231,12 @@
" vector: ti.types.ndarray(ndim=1), \n",
" weight: ti.types.ndarray(ndim=1), \n",
" out: ti.types.ndarray(ndim=1)):\n",
" weight_0 = weight[0]\n",
" ti.loop_config(block_dim=64)\n",
" for ij in ti.grouped(indices):\n",
" if vector[ij[0]]:\n",
" out[ij[1]] += weight_0\n",
" weight_val = get_weight(weight)\n",
" num_rows, num_cols = indices.shape\n",
" for i in range(num_rows):\n",
" if vector[i]:\n",
" for j in range(num_cols):\n",
" update_output(out, indices[i, j], weight_val)\n",
"\n",
"prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu, gpu_kernel=event_ell_gpu)\n",
"\n",
Expand All @@ -276,6 +257,32 @@
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### More Examples\n",
"For more examples, please refer to: \n",
"- [event/_csr_matvec_taichi.py](https://github.com/brainpy/BrainPy/blob/master/brainpy/_src/math/event/_csr_matvec_taichi.py)\n",
"- [sparse/_csr_mv_taichi.py](https://github.com/brainpy/BrainPy/blob/master/brainpy/_src/math/sparse/_csr_mv_taichi.py)\n",
"- [jitconn/_event_matvec_taichi.py](https://github.com/brainpy/BrainPy/blob/master/brainpy/_src/math/jitconn/_event_matvec_taichi.py)\n",
"- [jitconn/_matvec_taichi.py](https://github.com/brainpy/BrainPy/blob/master/brainpy/_src/math/jitconn/_matvec_taichi.py)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Clean the cache of taichi kernels\n",
"Because BrainPy integrates Taichi with JAX through Taichi's AOT (ahead-of-time) compilation, the compiled Taichi kernels are cached on the system. If you want to clean the cache, you can use the following code:\n",
"\n",
"```python\n",
"import brainpy.math as bm\n",
"\n",
"bm.clean_caches()\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -442,28 +449,7 @@
" # If the kernel is run on the CUDA backend, each block will have 16 threads.\n",
" for i in range(n):\n",
" val[i] = i\n",
"```\n",
"\n",
"#### `ti.grouped`\n",
"\n",
"将由`ndrange()`返回的迭代器中的索引组合成一个一维向量。\n",
"这通常在你想要在一个 for 循环中迭代 ndrange() 返回的所有索引时使用,并且只使用一个索引。\n",
"\n",
"示例:\n",
"\n",
"```python\n",
"# without ti.grouped\n",
"for I in ti.ndrange(2, 3):\n",
" print(I)\n",
"prints 0, 1, 2, 3, 4, 5\n",
"```\n",
"\n",
"```python\n",
"# with ti.grouped\n",
"for I in ti.grouped(ndrange(2, 3)):\n",
" print(I)\n",
"prints [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]\n",
"```"
"```\n"
]
},
{
Expand Down Expand Up @@ -536,6 +522,32 @@
"test_taichi_op_register()\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 更多示例\n",
"对于更多示例, 请参考: \n",
"- [event/_csr_matvec_taichi.py](https://github.com/brainpy/BrainPy/blob/master/brainpy/_src/math/event/_csr_matvec_taichi.py)\n",
"- [sparse/_csr_mv_taichi.py](https://github.com/brainpy/BrainPy/blob/master/brainpy/_src/math/sparse/_csr_mv_taichi.py)\n",
"- [jitconn/_event_matvec_taichi.py](https://github.com/brainpy/BrainPy/blob/master/brainpy/_src/math/jitconn/_event_matvec_taichi.py)\n",
"- [jitconn/_matvec_taichi.py](https://github.com/brainpy/BrainPy/blob/master/brainpy/_src/math/jitconn/_matvec_taichi.py)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 清除Taichi kernel的缓存\n",
"因为brainpy使用taichi的AOT方法来融合taichi和JAX,所以taichi的kernel会被缓存到系统中。如果你想清除缓存,可以使用以下代码:\n",
"\n",
"```python\n",
"import brainpy.math as bm\n",
"\n",
"bm.clean_caches()\n",
"```"
]
}
],
"metadata": {
Expand All @@ -554,7 +566,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
"version": "3.10.13"
}
},
"nbformat": 4,
Expand Down

0 comments on commit c2f2db9

Please sign in to comment.