Skip to content

Commit

Permalink
Fix ravel for strides 0
Browse files Browse the repository at this point in the history
  • Loading branch information
Emilio Castillo committed Nov 1, 2021
1 parent 5b37c79 commit f5b13dd
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 9 deletions.
3 changes: 2 additions & 1 deletion cupy/_core/_routines_manipulation.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ cpdef ndarray _move_single_axis(ndarray a, Py_ssize_t source,
Py_ssize_t destination)
cpdef ndarray rollaxis(ndarray a, Py_ssize_t axis, Py_ssize_t start=*)
cpdef ndarray broadcast_to(ndarray array, shape)
cpdef ndarray _reshape(ndarray self, const shape_t &shape_spec)
cpdef ndarray _reshape(ndarray self, const shape_t &shape_spec,
bint enforce_copy=*)
cpdef ndarray _T(ndarray self)
cpdef ndarray _transpose(ndarray self, const vector.vector[Py_ssize_t] &axes)
cpdef ndarray _concatenate(
Expand Down
20 changes: 12 additions & 8 deletions cupy/_core/_routines_manipulation.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -133,15 +133,16 @@ cdef ndarray _ndarray_ravel(ndarray self, order):
shape.push_back(self.size)

order_char = internal._normalize_order(order, True)
enforce_copy = not self._f_contiguous and not self._c_contiguous
if order_char == b'A':
if self._f_contiguous and not self._c_contiguous:
order_char = b'F'
else:
order_char = b'C'
if order_char == b'C':
return _reshape(self, shape)
return _reshape(self, shape, enforce_copy)
elif order_char == b'F':
return _reshape(_T(self), shape)
return _reshape(_T(self), shape, enforce_copy)
elif order_char == b'K':
raise NotImplementedError(
'ravel with order=\'K\' not yet implemented.')
Expand Down Expand Up @@ -294,17 +295,20 @@ cpdef ndarray rollaxis(ndarray a, Py_ssize_t axis, Py_ssize_t start=0):
return _transpose(a, axes)


cpdef ndarray _reshape(ndarray self, const shape_t &shape_spec):
cpdef ndarray _reshape(
ndarray self, const shape_t &shape_spec, bint enforce_copy=0):
cdef shape_t shape
cdef strides_t strides
cdef ndarray newarray
shape = internal.infer_unknown_dimension(shape_spec, self.size)
if internal.vector_equal(shape, self._shape):
return self.view()
if not enforce_copy:
if internal.vector_equal(shape, self._shape):
return self.view()

_get_strides_for_nocopy_reshape(self, shape, strides)
if strides.size() == shape.size():
return self._view(shape, strides, False, True)

_get_strides_for_nocopy_reshape(self, shape, strides)
if strides.size() == shape.size():
return self._view(shape, strides, False, True)
newarray = self.copy()
_get_strides_for_nocopy_reshape(newarray, shape, strides)

Expand Down
14 changes: 14 additions & 0 deletions tests/cupy_tests/manipulation_tests/test_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,20 @@ def test_external_ravel(self, xp):
a = a.transpose(2, 0, 1)
return xp.ravel(a)

@testing.for_orders('CFA')
@testing.numpy_cupy_array_equal()
def test_ravel_non_contiguous(self, xp, order):
a = xp.array([1])
b = xp.broadcast_to(a, (10,))
assert not b.flags.c_contiguous and not b.flags.f_contiguous
b = b.ravel(order)
if order == 'C':
assert b.flags.c_contiguous
else:
assert b.flags.f_contiguous
return b



@testing.parameterize(*testing.product({
'order_init': ['C', 'F'],
Expand Down

0 comments on commit f5b13dd

Please sign in to comment.