Commit 81c2aad: Merge 993774f into bf9e9df

keckj committed Mar 23, 2021
2 parents: bf9e9df + 993774f
Showing 8 changed files with 614 additions and 7 deletions.
23 changes: 23 additions & 0 deletions cupy/cuda/function.pyx
@@ -90,6 +90,15 @@ cdef class CIntptr(CPointer):
self.ptr = <void*>&self.val


cdef class CNumpyArray(CPointer):
cdef:
object val

def __init__(self, v):
self.val = v
self.ptr = <void*><size_t>v.__array_interface__['data'][0]


cdef set _pointer_numpy_types = {numpy.dtype(i).type
for i in '?bhilqBHILQefdFD'}

@@ -109,6 +118,20 @@ cdef inline CPointer _pointer(x):
return x
if isinstance(x, (TextureObject, SurfaceObject)):
return CUIntMax(x.ptr)
if isinstance(x, numpy.ndarray):
# Any numpy.ndarray can be passed by value through CNumpyArray. Here we
# allow only arrays of size one so that users do not mistakenly pass
# numpy.ndarrays instead of cupy.ndarrays to kernels. This could happen
# if they forget to convert numpy arrays to cupy arrays before the
# kernel call and would otherwise go unnoticed without this check.
if x.size == 1:
return CNumpyArray(x)
else:
msg = ('You are trying to pass a numpy.ndarray of shape {} as a '
'kernel parameter. Only numpy.ndarrays of size one can be '
'passed by value. If you meant to pass a pointer to __glob'
'al__ memory, you need to pass a cupy.ndarray instead.')
raise TypeError(msg.format(x.shape))

if type(x) not in _pointer_numpy_types:
if isinstance(x, int):
118 changes: 112 additions & 6 deletions docs/source/tutorial/kernel.rst
@@ -263,16 +263,122 @@ Accessing texture (surface) memory in :class:`~cupy.RawKernel` is supported via
The kernel does not have return values.
You need to pass both input arrays and output arrays as arguments.

.. note::
No validation will be performed by CuPy for arguments passed to the kernel, including types and number of arguments.
In particular, note that when passing a :class:`~cupy.ndarray`, its ``dtype`` should match the type of the argument declared in the function signature of the CUDA source code (unless you are casting arrays intentionally).
For example, ``cupy.float32`` and ``cupy.uint64`` arrays must be passed to arguments typed as ``float*`` and ``unsigned long long*``, respectively.
For Python primitive types, ``int``, ``float``, ``complex`` and ``bool`` map to ``long long``, ``double``, ``cuDoubleComplex`` and ``bool``, respectively.

.. note::
When using ``printf()`` in your CUDA kernel, you may need to synchronize the stream to see the output.
You can use ``cupy.cuda.Stream.null.synchronize()`` if you are using the default stream.

Kernel arguments
----------------
Python primitive types and NumPy scalars are passed to the kernel by value.
Array arguments (pointer arguments) have to be passed as CuPy ndarrays.
No validation is performed by CuPy for arguments passed to the kernel, including types and number of arguments.

In particular, note that when passing a CuPy :class:`~cupy.ndarray`, its ``dtype`` should match the type of the argument declared in the function signature of the CUDA source code (unless you are casting arrays intentionally).

As an example, ``cupy.float32`` and ``cupy.uint64`` arrays must be passed to the argument typed as ``float*`` and ``unsigned long long*``, respectively. CuPy does not directly support arrays of non-primitive types such as ``float3``, but nothing prevents you from casting a ``float*`` or ``void*`` to a ``float3*`` in a kernel.

The Python primitive types ``int``, ``float``, ``complex`` and ``bool`` map to ``long long``, ``double``, ``cuDoubleComplex`` and ``bool``, respectively.

NumPy scalars (``numpy.generic``) and NumPy arrays (``numpy.ndarray``) **of size one**
are passed to the kernel by value.
This means that you can pass by value any base NumPy type such as ``numpy.int8`` or ``numpy.float64``, provided the corresponding kernel argument matches in size. You can refer to the following table to match CuPy/NumPy dtypes with CUDA types:

+-----------------+-----------------------------------------------+------------------+
| CuPy/NumPy type | Corresponding kernel types | itemsize (bytes) |
+=================+===============================================+==================+
| bool | bool | 1 |
+-----------------+-----------------------------------------------+------------------+
| int8 | char, signed char | 1 |
+-----------------+-----------------------------------------------+------------------+
| int16 | short, signed short | 2 |
+-----------------+-----------------------------------------------+------------------+
| int32 | int, signed int | 4 |
+-----------------+-----------------------------------------------+------------------+
| int64 | long long, signed long long | 8 |
+-----------------+-----------------------------------------------+------------------+
| uint8 | unsigned char | 1 |
+-----------------+-----------------------------------------------+------------------+
| uint16 | unsigned short | 2 |
+-----------------+-----------------------------------------------+------------------+
| uint32 | unsigned int | 4 |
+-----------------+-----------------------------------------------+------------------+
| uint64 | unsigned long long | 8 |
+-----------------+-----------------------------------------------+------------------+
| float16 | half | 2 |
+-----------------+-----------------------------------------------+------------------+
| float32 | float | 4 |
+-----------------+-----------------------------------------------+------------------+
| float64 | double | 8 |
+-----------------+-----------------------------------------------+------------------+
| complex64 | float2, cuFloatComplex, complex<float> | 8 |
+-----------------+-----------------------------------------------+------------------+
| complex128 | double2, cuDoubleComplex, complex<double> | 16 |
+-----------------+-----------------------------------------------+------------------+

The CUDA standard guarantees that the sizes of fundamental types on the host and device always match.
The itemsizes of ``size_t``, ``ptrdiff_t``, ``intptr_t``, ``uintptr_t``,
``long``, ``signed long`` and ``unsigned long`` are, however, platform dependent.
To pass CUDA vector builtins such as ``float3``, or any other user-defined structure,
as kernel arguments (provided they match the device-side kernel parameter types), see the section :ref:`custom_user_structs` below.
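
As a short illustration (a minimal sketch; the kernel ``axpy`` and its signature
are written here only for this example and are not part of CuPy), a NumPy scalar
is passed by value next to CuPy array pointers:

.. code-block:: python

    import cupy
    import numpy as np

    axpy = cupy.RawKernel(r'''
    extern "C" __global__ void axpy(const float* x, float a, float* y) {
        int i = threadIdx.x;
        y[i] = a * x[i] + y[i];
    }
    ''', 'axpy')

    x = cupy.arange(4, dtype=cupy.float32)
    y = cupy.zeros(4, dtype=cupy.float32)
    # np.float32(2.0) is a NumPy scalar matching the 'float a' parameter,
    # so it is copied to the kernel by value; x and y are passed as pointers.
    axpy((1,), (4,), (x, np.float32(2.0), y))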

.. _custom_user_structs:

Custom user types
-----------------

It is possible to use custom types (composite types such as structures and structures of structures)
as kernel arguments by defining a custom NumPy dtype.
When doing this, it is your responsibility to match the host and device structure memory layout.
The CUDA standard guarantees that the sizes of fundamental types on the host and device always match.
The device may, however, impose alignment requirements on composite types.
This means that for composite types the struct member offsets may be different from what you might expect.

When a kernel argument is passed by value, the CUDA driver will copy exactly ``sizeof(param_type)`` bytes starting from the beginning of the NumPy object data pointer, where ``param_type`` is the parameter type in your kernel.
You have to match ``param_type``'s memory layout (e.g. size, alignment and struct padding/packing)
by defining a corresponding `NumPy dtype <https://numpy.org/doc/stable/reference/arrays.dtypes.html>`_.

For builtin CUDA vector types such as ``int2`` and ``double4``, and for other packed structures with
named members, you can define the corresponding NumPy dtypes directly, as follows:

.. doctest::

>>> import numpy as np
>>> names = ['x', 'y', 'z']
>>> types = [np.float32]*3
>>> float3 = np.dtype({'names': names, 'formats': types})
>>> arg = np.random.rand(3).astype(np.float32).view(float3)
>>> print(arg) # doctest: +SKIP
[(0.9940819, 0.62873816, 0.8953669)]
>>> arg['x'] = 42.0
>>> print(arg) # doctest: +SKIP
[(42., 0.62873816, 0.8953669)]
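
As a hedged sketch of how such a value could be consumed (the kernel ``scale_x``
below and its signature are written only for this example and are not part of
CuPy), the size-one ``float3`` element is passed by value while CuPy arrays are
passed as pointers:

.. code-block:: python

    import cupy

    scale_x = cupy.RawKernel(r'''
    extern "C" __global__ void scale_x(const float3 v, float* out) {
        int i = threadIdx.x;
        out[i] = i * v.x;
    }
    ''', 'scale_x')

    out = cupy.empty(4, dtype=cupy.float32)
    # 'arg' is the size-one NumPy array of dtype float3 defined above.
    scale_x((1,), (4,), (arg, out))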

As shown in the sketch above, ``arg`` can be used directly as a kernel argument.
When there is no need to name fields, you may prefer this syntax to define packed structures such as
vectors or matrices:

.. doctest::

>>> import numpy as np
>>> float5x5 = np.dtype({'names': ['dummy'], 'formats': [(np.float32,(5,5))]})
>>> arg = np.random.rand(25).astype(np.float32).view(float5x5)
>>> print(arg.itemsize)
100

Here ``arg`` represents a 100-byte scalar (i.e. a NumPy array of size one)
that can be passed by value to any kernel.
Kernel parameters are passed by value in a dedicated 4 kB memory bank which has its own cache with broadcast.
The total size of all kernel parameters is thus bounded by 4 kB
(see `Function Parameters <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#function-parameters>`_ in the CUDA C++ Programming Guide).
Note that this dedicated memory bank is not shared with the device ``__constant__`` memory space.

For now, CuPy offers no helper routines to create user-defined composite types.
Such composite types can, however, be built recursively using the NumPy dtype `offsets` and `itemsize` capabilities;
see ``examples/custom_struct`` in the CuPy repository for examples of advanced usage.
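
As a minimal sketch (the layout below assumes a device struct ``{ int a; double b; }``
where ``double`` requires 8-byte alignment, an assumption about the target device that
is not queried here), explicit ``offsets`` and ``itemsize`` reproduce the padding:

.. doctest::

    >>> import numpy as np
    >>> int_double = np.dtype({'names': ['a', 'b'],
    ...                        'formats': [np.int32, np.float64],
    ...                        'offsets': [0, 8],
    ...                        'itemsize': 16})
    >>> int_double.itemsize
    16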

.. warning::
You cannot directly pass static arrays as kernel arguments with the ``type arg[N]`` syntax, where N is a compile-time constant. The signature of ``__global__ void kernel(float arg[5])`` is seen as ``__global__ void kernel(float* arg)`` by the compiler. If you want to pass five floats to the kernel by value, you need to define a custom structure ``struct float5 { float val[5]; };`` and modify the kernel signature to ``__global__ void kernel(float5 arg)`` (a host-side sketch follows this warning).
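
A minimal host-side sketch of this workaround (the name ``float5`` mirrors the
struct in the warning above):

.. doctest::

    >>> import numpy as np
    >>> float5 = np.dtype({'names': ['val'], 'formats': [(np.float32, (5,))]})
    >>> arg = np.arange(5, dtype=np.float32).view(float5)
    >>> arg.itemsize  # sizeof(struct float5 { float val[5]; })
    20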


Raw modules
-----------
11 changes: 11 additions & 0 deletions examples/custom_struct/README.md
@@ -0,0 +1,11 @@
# Custom user structure examples

This folder contains examples of custom user structures in `cupy.RawKernel` (see [https://docs.cupy.dev/en/stable/tutorial/kernel.html](https://docs.cupy.dev/en/stable/tutorial/kernel.html) for corresponding documentation).

This folder provides three scripts, ordered by increasing complexity:

1. `builtin_vectors.py` shows how to use CUDA builtin vectors such as `float4` both as a scalar parameter (passed by value from the host) and as an array parameter in RawKernels.
2. `packed_matrix.py` demonstrates how to create and use templated packed structures in RawModules.
3. `complex_struct.py` shows how to recursively build complex NumPy dtypes that match the memory layout of a device structure.

All examples can be run as plain Python scripts: `python3 example_name.py`.
48 changes: 48 additions & 0 deletions examples/custom_struct/builtin_vectors.py
@@ -0,0 +1,48 @@
import sys
import numpy
import cupy

code = '''
__device__ double3 operator+(const double3& lhs, const double3& rhs) {
return make_double3(lhs.x + rhs.x,
lhs.y + rhs.y,
lhs.z + rhs.z);
}
extern "C" __global__ void sum_kernel(const double3* lhs,
double3 rhs,
double3* out) {
int i = threadIdx.x;
out[i] = lhs[i] + rhs;
}
'''

double3 = numpy.dtype(
{
'names': ['x', 'y', 'z'],
'formats': [numpy.float64]*3
}
)


def main():
N = 8

# The kernel computes out = lhs+rhs where lhs and rhs are double3 vectors.
# lhs is an array of N such vectors and rhs is a double3 kernel parameter.

lhs = cupy.random.rand(3*N, dtype=numpy.float64).reshape(N, 3)
rhs = numpy.random.rand(3).astype(numpy.float64)
out = cupy.empty_like(lhs)

kernel = cupy.RawKernel(code, 'sum_kernel')
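# rhs.view(double3) reinterprets the (3,) float64 host array as a single
# double3 element (a NumPy array of size one), so it is passed by value.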
args = (lhs, rhs.view(double3), out)
kernel((1,), (N,), args)

expected = lhs + cupy.asarray(rhs[None, :])
cupy.testing.assert_array_equal(expected, out)
print("Kernel output matches expected value.")


if __name__ == '__main__':
sys.exit(main())
132 changes: 132 additions & 0 deletions examples/custom_struct/complex_struct.py
@@ -0,0 +1,132 @@
import sys
import numpy
import cupy

struct_definition = '''
struct complex_struct {
int4 a;
char b;
double c[2];
short1 d;
unsigned long long int e[3];
};
'''

struct_layout_code = '''
{struct_definition}
extern "C" __global__ void get_struct_layout(
unsigned long long *itemsize,
unsigned long long *sizes,
unsigned long long *offsets) {{
const complex_struct* ptr = NULL;
itemsize[0] = sizeof(complex_struct);
sizes[0] = sizeof(ptr->a);
sizes[1] = sizeof(ptr->b);
sizes[2] = sizeof(ptr->c);
sizes[3] = sizeof(ptr->d);
sizes[4] = sizeof(ptr->e);
offsets[0] = (unsigned long long)&ptr->a;
offsets[1] = (unsigned long long)&ptr->b;
offsets[2] = (unsigned long long)&ptr->c;
offsets[3] = (unsigned long long)&ptr->d;
offsets[4] = (unsigned long long)&ptr->e;
}}
'''.format(struct_definition=struct_definition)


kernel_code = '''
{struct_definition}
extern "C" __global__ void test_kernel(const complex_struct s,
double* out) {{
int i = threadIdx.x;
double sum = 0.0;
sum += s.a.x + s.a.y + s.a.z + s.a.w;
sum += s.b;
sum += s.c[0] + s.c[1];
sum += s.d.x;
sum += s.e[0] + s.e[1] + s.e[2];
out[i] = i * sum;
}}
'''.format(struct_definition=struct_definition)


def make_packed(basetype, N, itemsize):
# A small utility function to make packed structs
# Can represent simple packed vectors such as float4 or double[3].
assert 0 < N <= 4, N
names = list('xyzw')[:N]
formats = [basetype]*N
return numpy.dtype(dict(names=names,
formats=formats,
itemsize=itemsize))


def main():
# This program demonstrates how to build a host-side
# representation of the device structure 'complex_struct'
# (defined in the variable 'struct_definition') that can be
# used as a RawKernel argument.

# First step is to determine structure memory layout
# itemsize -> overall struct size
# sizes -> individual struct member sizes, determined with sizeof
# offsets -> individual struct member offsets, determined with offsetof
# Results (in terms of bytes) are copied to host after kernel launch.
# Note that 'complex_struct' has 5 members named a, b, c, d and e.
itemsize = cupy.ndarray(shape=(1,), dtype=numpy.uint64)
sizes = cupy.ndarray(shape=(5,), dtype=numpy.uint64)
offsets = cupy.ndarray(shape=(5,), dtype=numpy.uint64)

kernel = cupy.RawKernel(struct_layout_code, 'get_struct_layout')
kernel((1,), (1,), (itemsize, sizes, offsets))

(itemsize, sizes, offsets) = map(cupy.asnumpy, (itemsize, sizes, offsets))
print("Overall structure itemsize: {} bytes".format(itemsize.item()))
print("Structure members itemsize: {}".format(sizes))
print("Structure members offsets: {}".format(offsets))

# Second step: build a numpy dtype for each struct member
atype = make_packed(numpy.int32, 4, sizes[0])
btype = make_packed(numpy.int8, 1, sizes[1])
ctype = make_packed(numpy.float64, 2, sizes[2])
dtype = make_packed(numpy.int16, 1, sizes[3])
etype = make_packed(numpy.uint64, 3, sizes[4])

# Third step: create the complex struct representation with
# the right offsets
names = list('abcde')
formats = [atype, btype, ctype, dtype, etype]
complex_struct = numpy.dtype(dict(names=names,
formats=formats,
offsets=offsets,
itemsize=itemsize.item()))

# Build a complex_struct kernel argument
s = numpy.empty(shape=(1,), dtype=complex_struct)
s['a'] = numpy.arange(0, 4).astype(numpy.int32).view(atype)
s['b'] = numpy.arange(4, 5).astype(numpy.int8).view(btype)
s['c'] = numpy.arange(5, 7).astype(numpy.float64).view(ctype)
s['d'] = numpy.arange(7, 8).astype(numpy.int16).view(dtype)
s['e'] = numpy.arange(8, 11).astype(numpy.uint64).view(etype)
print("Complex structure value:\n {}".format(s))

# Setup test kernel
N = 8
out = cupy.empty(shape=(N,), dtype=numpy.float64)
kernel = cupy.RawKernel(kernel_code, 'test_kernel')
kernel((1,), (N,), (s, out))

# the sum of all members of our complex struct instance is 55.0
expected = cupy.arange(N) * 55.0

cupy.testing.assert_array_almost_equal(expected, out)
print("Kernel output matches expected value.")


if __name__ == '__main__':
sys.exit(main())
