Skip to content

Commit

Permalink
Merge pull request #7996 from kmaehashi/add-real-cuda-test
Browse files Browse the repository at this point in the history
Add import test without CUDA Toolkit
  • Loading branch information
takagi committed Mar 8, 2024
2 parents d2abc0f + e89cfe5 commit 272dfda
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 2 deletions.
4 changes: 2 additions & 2 deletions .pfnci/linux/tests/cuda-build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ export CUPY_NVCC_GENERATE_CODE="arch=compute_70,code=sm_70"

# Make sure that CuPy can be imported without CUDA Toolkit installed.
rm -rf /usr/local/cuda*
pushd /
python3 -c 'import cupy, cupyx; cupy.show_config(_full=True)'
pushd tests/import_tests
python3 test_import.py
popd

"$ACTIONS/cleanup.sh"
50 changes: 50 additions & 0 deletions tests/import_tests/test_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import cupy
from cupyx import jit

"""
Test to ensure that this file can be imported without CUDA Toolkit.
"""


@cupy.memoize()
def user_func(a: cupy.ndarray):
a.sum()


squared_diff = cupy.ElementwiseKernel(
'float32 x, float32 y',
'float32 z',
'z = (x - y) * (x - y)',
'squared_diff')


l2norm_kernel = cupy.ReductionKernel(
'T x', # input params
'T y', # output params
'x * x', # map
'a + b', # reduce
'y = sqrt(a)', # post-reduction map
'0', # identity value
'l2norm' # kernel name
)

complex_kernel = cupy.RawKernel(r'''
#include <cupy/complex.cuh>
extern "C" __global__
void my_func(const complex<float>* x1, const complex<float>* x2,
complex<float>* y, float a) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
y[tid] = x1[tid] + a * x2[tid];
}
''', 'my_func')


@jit.rawkernel()
def elementwise_copy(x, y, size):
tid = jit.blockIdx.x * jit.blockDim.x + jit.threadIdx.x
ntid = jit.gridDim.x * jit.blockDim.x
for i in range(tid, size, ntid):
y[i] = x[i]


cupy.show_config(_full=True)

0 comments on commit 272dfda

Please sign in to comment.