Skip to content

Commit

Permalink
Format codes
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Mar 15, 2024
1 parent baec6d7 commit f9cba21
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 6 deletions.
6 changes: 4 additions & 2 deletions brainpy/_src/dependency_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
numba_install_info = ('We need numba. Please install numba by pip . \n'
'> pip install numba')
cupy_install_info = ('We need cupy. Please install cupy by pip . \n'
'For CUDA v11.2 ~ 11.8 > pip install cupy-cuda11x\n'
'For CUDA v12.x > pip install cupy-cuda12x\n')
'For CUDA v11.2 ~ 11.8 > pip install cupy-cuda11x\n'
'For CUDA v12.x > pip install cupy-cuda12x\n')
os.environ["TI_LOG_LEVEL"] = "error"


Expand Down Expand Up @@ -105,9 +105,11 @@ def import_cupy(error_if_not_found=True):
return None
return cupy


def raise_cupy_not_found():
raise ModuleNotFoundError(cupy_install_info)


def is_brainpylib_gpu_installed():
return False if brainpylib_gpu_ops is None else True

Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/math/op_register/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def __init__(
gpu_checked = False
if gpu_kernel is None:
gpu_checked = True
elif isinstance(gpu_kernel, str): # cupy
elif isinstance(gpu_kernel, str): # cupy
register_cupy_gpu_translation_rule(self.primitive, gpu_kernel)
gpu_checked = True
elif hasattr(gpu_kernel, '_is_wrapped_kernel') and gpu_kernel._is_wrapped_kernel: # taichi
Expand Down
5 changes: 2 additions & 3 deletions brainpy/_src/math/op_register/tests/test_cupy_based.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import jax
import jax.numpy as jnp

import pytest

import brainpy.math as bm
from brainpy._src.dependency_check import import_cupy

Expand Down Expand Up @@ -43,7 +43,7 @@ def test_cupy_based():

# n = jnp.asarray([N**2,], dtype=jnp.int32)

y = prim(x1, x2, N**2, grid=(N,), block=(N,), outs=[jax.ShapeDtypeStruct((N, N), dtype=jnp.float32)])[0]
y = prim(x1, x2, N ** 2, grid=(N,), block=(N,), outs=[jax.ShapeDtypeStruct((N, N), dtype=jnp.float32)])[0]

print(y)
assert jnp.allclose(y, x1 + x2)
Expand All @@ -57,5 +57,4 @@ def test_cupy_based():
# ker_times((N,), (N,), (x1, x2, y, N**2)) # y = x1 * x2
# assert cp.allclose(y, x1 * x2)


# test_cupy_based()

0 comments on commit f9cba21

Please sign in to comment.