Skip to content

Commit

Permalink
Merge pull request #4675 from sklam/fix/cuda_array_face
Browse files Browse the repository at this point in the history
Bump cuda array interface to version 2
  • Loading branch information
sklam committed Oct 9, 2019
2 parents 9eb26b3 + 31604cd commit 02188e5
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 6 deletions.
8 changes: 4 additions & 4 deletions docs/source/cuda/cuda_array_interface.rst
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
.. _cuda-array-interface:

====================
CUDA Array Interface
====================
================================
CUDA Array Interface (Version 2)
================================

The *cuda array inteface* is created for interoperability between different
implementation of GPU array-like objects in various projects. The idea is
Expand Down Expand Up @@ -91,7 +91,7 @@ include:
.. _numpy array interface: https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.interface.html#__array_interface__


Differences with CUDA Array Interface (Version 0)
Differences with CUDA Array Interface (Version 0)
-------------------------------------------------

Version 0 of the CUDA Array Interface did not have the optional **mask**
Expand Down
3 changes: 2 additions & 1 deletion numba/cuda/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ def from_cuda_array_interface(desc, owner=None):
The resulting DeviceNDArray will acquire a reference from it.
"""
version = desc.get('version')
if version == 1:
# Mask introduced in version 1
if 1 <= version:
mask = desc.get('mask')
# Would ideally be better to detect if the mask is all valid
if mask is not None:
Expand Down
2 changes: 1 addition & 1 deletion numba/cuda/cudadrv/devicearray.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def __cuda_array_interface__(self):
'strides': strides,
'data': (ptr, False),
'typestr': self.dtype.str,
'version': 1,
'version': 2,
}

def bind(self, stream=0):
Expand Down
14 changes: 14 additions & 0 deletions numba/cuda/tests/cudapy/test_cuda_array_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,20 @@ def test_strides(self):
c_arr = c_arr[:, 1, :]
self.assertNotEqual(c_arr.__cuda_array_interface__['strides'], None)

def test_consuming_strides(self):
hostarray = np.arange(10).reshape(2, 5)
face = cuda.to_device(hostarray).__cuda_array_interface__
self.assertIsNone(face['strides'])
got = cuda.from_cuda_array_interface(face).copy_to_host()
np.testing.assert_array_equal(got, hostarray)
self.assertTrue(got.flags['C_CONTIGUOUS'])
# Try non-NULL strides
face['strides'] = hostarray.strides
self.assertIsNotNone(face['strides'])
got = cuda.from_cuda_array_interface(face).copy_to_host()
np.testing.assert_array_equal(got, hostarray)
self.assertTrue(got.flags['C_CONTIGUOUS'])


if __name__ == "__main__":
unittest.main()

0 comments on commit 02188e5

Please sign in to comment.