Skip to content

Commit

Permalink
Merge 7001fba into ae6ecfa
Browse files Browse the repository at this point in the history
  • Loading branch information
niboshi committed Jul 9, 2017
2 parents ae6ecfa + 7001fba commit abb17b8
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 0 deletions.
11 changes: 11 additions & 0 deletions cupy/cuda/function.pyx
@@ -1,6 +1,7 @@
# distutils: language = c++

import numpy
import six

cimport cpython
from libcpp cimport vector
Expand Down Expand Up @@ -69,6 +70,16 @@ cdef inline CPointer _pointer(x):
if isinstance(x, core.Indexer):
return (<core.Indexer>x).get_pointer()

if type(x) not in _pointer_numpy_types:
if isinstance(x, six.integer_types):
x = numpy.int64(x)
elif isinstance(x, float):
x = numpy.float64(x)
elif isinstance(x, bool):
x = numpy.bool_(x)
else:
raise TypeError('Unsupported type %s' % type(x))

itemsize = x.itemsize
if itemsize == 1:
return CInt8(x.view(numpy.int8))
Expand Down
56 changes: 56 additions & 0 deletions tests/cupy_tests/core_tests/test_function.py
@@ -0,0 +1,56 @@
import unittest

import numpy

import cupy
from cupy.cuda import compiler
from cupy import testing


def _compile_func(kernel_name, code):
mod = compiler.compile_with_cache(code)
return mod.get_function(kernel_name)


@testing.gpu
class TestFunction(unittest.TestCase):

def test_python_scalar(self):
code = '''
extern "C" __global__ void test_kernel(const double* a, double b, double* x) {
int i = threadIdx.x;
x[i] = a[i] + b;
}
'''

a_cpu = numpy.arange(24, dtype=numpy.float64).reshape((4, 6))
a = cupy.array(a_cpu)
b = float(2)
x = cupy.empty_like(a)

func = _compile_func('test_kernel', code)

func.linear_launch(a.size, (a, b, x))

expected = a_cpu + b
testing.assert_array_equal(x, expected)

def test_numpy_scalar(self):
code = '''
extern "C" __global__ void test_kernel(const double* a, double b, double* x) {
int i = threadIdx.x;
x[i] = a[i] + b;
}
'''

a_cpu = numpy.arange(24, dtype=numpy.float64).reshape((4, 6))
a = cupy.array(a_cpu)
b = numpy.float64(2)
x = cupy.empty_like(a)

func = _compile_func('test_kernel', code)

func.linear_launch(a.size, (a, b, x))

expected = a_cpu + b
testing.assert_array_equal(x, expected)
37 changes: 37 additions & 0 deletions tests/cupy_tests/core_tests/test_userkernel.py
@@ -1,5 +1,7 @@
import unittest

import numpy

import cupy
from cupy import testing

Expand Down Expand Up @@ -29,3 +31,38 @@ def test_manual_indexing(self, n=100):
out2 = uesr_kernel_2(in1, in2, size=n)

testing.assert_array_equal(out1, out2)

def test_python_scalar(self):
for typ in (int, float, bool):
dtype = numpy.dtype(typ).type
in1_cpu = numpy.random.randint(0, 1, (4, 5)).astype(dtype)
in1 = cupy.array(in1_cpu)
scalar_value = typ(2)
uesr_kernel_1 = cupy.ElementwiseKernel(
'T x, T y',
'T z',
'''
z = x + y;
''',
'uesr_kernel_1')
out1 = uesr_kernel_1(in1, scalar_value)

expected = in1_cpu + dtype(2)
testing.assert_array_equal(out1, expected)

@testing.for_all_dtypes()
def test_numpy_scalar(self, dtype):
in1_cpu = numpy.random.randint(0, 1, (4, 5)).astype(dtype)
in1 = cupy.array(in1_cpu)
scalar_value = dtype(2)
uesr_kernel_1 = cupy.ElementwiseKernel(
'T x, T y',
'T z',
'''
z = x + y;
''',
'uesr_kernel_1')
out1 = uesr_kernel_1(in1, scalar_value)

expected = in1_cpu + dtype(2)
testing.assert_array_equal(out1, expected)

0 comments on commit abb17b8

Please sign in to comment.