Skip to content

Commit

Permalink
fix: Fix contiguity check for CAI/DLPack buffers with shape[i] <= 1
Browse files Browse the repository at this point in the history
  • Loading branch information
dalcinl committed Nov 25, 2021
1 parent d53c159 commit 9ccf55f
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 7 deletions.
7 changes: 4 additions & 3 deletions src/mpi4py/MPI/ascaibuf.pxi
Expand Up @@ -8,7 +8,7 @@ cdef inline int cuda_is_contig(tuple shape,
Py_ssize_t itemsize,
char order) except -1:
cdef Py_ssize_t i, ndim = len(shape)
cdef Py_ssize_t start, step, index
cdef Py_ssize_t start, step, index, dim, size = itemsize
if order == c'F':
start = 0
step = 1
Expand All @@ -17,9 +17,10 @@ cdef inline int cuda_is_contig(tuple shape,
step = -1
for i from 0 <= i < ndim:
index = start + step * i
if itemsize != <Py_ssize_t>strides[index]:
dim = <Py_ssize_t>shape[index]
if dim > 1 and size != <Py_ssize_t>strides[index]:
return 0
itemsize *= <Py_ssize_t>shape[index]
size *= dim
return 1

cdef inline char* cuda_get_format(char typekind, Py_ssize_t itemsize) nogil:
Expand Down
7 changes: 4 additions & 3 deletions src/mpi4py/MPI/asdlpack.pxi
Expand Up @@ -65,7 +65,7 @@ cdef inline int dlpack_is_contig(const DLTensor *dltensor, char order) nogil:
cdef int i, ndim = dltensor.ndim
cdef int64_t *shape = dltensor.shape
cdef int64_t *strides = dltensor.strides
cdef int64_t start, step, index, size = 1
cdef int64_t start, step, index, dim, size = 1
if strides == NULL:
if ndim > 1 and order == c'F':
return 0
Expand All @@ -78,9 +78,10 @@ cdef inline int dlpack_is_contig(const DLTensor *dltensor, char order) nogil:
step = -1
for i from 0 <= i < ndim:
index = start + step * i
if size != strides[index]:
dim = shape[index]
if dim > 1 and size != strides[index]:
return 0
size *= shape[index]
size *= dim
return 1

cdef inline int dlpack_check_shape(const DLTensor *dltensor) except -1:
Expand Down
18 changes: 17 additions & 1 deletion test/test_msgspec.py
Expand Up @@ -663,6 +663,18 @@ def testContiguous(self):
self.assertRaises(BufferError, MPI.Get_address, buf)
del s
#
dltensor.ndim, dltensor.shape, dltensor.strides = \
dlpack.make_dl_shape([1, 3, 1], order='C')
s = dltensor.strides
MPI.Get_address(buf)
for i in range(4):
for j in range(4):
s[0], s[2] = i, j
MPI.Get_address(buf)
s[1] = 0
self.assertRaises(BufferError, MPI.Get_address, buf)
del s
#
del dltensor

def testByteOffset(self):
Expand Down Expand Up @@ -690,8 +702,12 @@ def testNonReadonly(self):
def testNonContiguous(self):
smsg = CAIBuf('i', [1,2,3])
rmsg = CAIBuf('i', [0,0,0])
Sendrecv(smsg, rmsg)
strides = rmsg.__cuda_array_interface__['strides']
bad_strides = strides[:-1] + (7,)
good_strides = strides[:-2] + (0, 7)
rmsg.__cuda_array_interface__['strides'] = good_strides
Sendrecv(smsg, rmsg)
bad_strides = (7,) + strides[1:]
rmsg.__cuda_array_interface__['strides'] = bad_strides
self.assertRaises(BufferError, Sendrecv, smsg, rmsg)

Expand Down

0 comments on commit 9ccf55f

Please sign in to comment.