Skip to content

Commit

Permalink
Merge pull request #2696 from seberg/issue434
Browse files Browse the repository at this point in the history
BUG: Fix bug with size 1-dims in CreateSortedStridePerm
  • Loading branch information
certik committed Nov 16, 2012
2 parents 8ab301a + e565afb commit 3ecbac5
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 21 deletions.
1 change: 0 additions & 1 deletion numpy/core/src/multiarray/ctors.c
Expand Up @@ -1112,7 +1112,6 @@ PyArray_NewLikeArray(PyArrayObject *prototype, NPY_ORDER order,
int idim;

PyArray_CreateSortedStridePerm(PyArray_NDIM(prototype),
PyArray_SHAPE(prototype),
PyArray_STRIDES(prototype),
strideperm);

Expand Down
6 changes: 3 additions & 3 deletions numpy/core/src/multiarray/dtype_transfer.c
Expand Up @@ -3921,7 +3921,7 @@ PyArray_PrepareOneRawArrayIter(int ndim, npy_intp *shape,
}

/* Sort the axes based on the destination strides */
PyArray_CreateSortedStridePerm(ndim, shape, strides, strideperm);
PyArray_CreateSortedStridePerm(ndim, strides, strideperm);
for (i = 0; i < ndim; ++i) {
int iperm = strideperm[ndim - i - 1].perm;
out_shape[i] = shape[iperm];
Expand Down Expand Up @@ -4051,7 +4051,7 @@ PyArray_PrepareTwoRawArrayIter(int ndim, npy_intp *shape,
}

/* Sort the axes based on the destination strides */
PyArray_CreateSortedStridePerm(ndim, shape, stridesA, strideperm);
PyArray_CreateSortedStridePerm(ndim, stridesA, strideperm);
for (i = 0; i < ndim; ++i) {
int iperm = strideperm[ndim - i - 1].perm;
out_shape[i] = shape[iperm];
Expand Down Expand Up @@ -4185,7 +4185,7 @@ PyArray_PrepareThreeRawArrayIter(int ndim, npy_intp *shape,
}

/* Sort the axes based on the destination strides */
PyArray_CreateSortedStridePerm(ndim, shape, stridesA, strideperm);
PyArray_CreateSortedStridePerm(ndim, stridesA, strideperm);
for (i = 0; i < ndim; ++i) {
int iperm = strideperm[ndim - i - 1].perm;
out_shape[i] = shape[iperm];
Expand Down
18 changes: 5 additions & 13 deletions numpy/core/src/multiarray/shape.c
Expand Up @@ -849,7 +849,7 @@ int _npy_stride_sort_item_comparator(const void *a, const void *b)
bstride = -bstride;
}

if (astride == bstride || astride == 0 || bstride == 0) {
if (astride == bstride) {
/*
* Make the qsort stable by next comparing the perm order.
* (Note that two perm entries will never be equal)
Expand All @@ -861,9 +861,7 @@ int _npy_stride_sort_item_comparator(const void *a, const void *b)
if (astride > bstride) {
return -1;
}
else {
return 1;
}
return 1;
}

/*NUMPY_API
Expand All @@ -874,21 +872,15 @@ int _npy_stride_sort_item_comparator(const void *a, const void *b)
* [(2, 12), (0, 4), (1, -2)].
*/
NPY_NO_EXPORT void
PyArray_CreateSortedStridePerm(int ndim, npy_intp *shape,
npy_intp *strides,
PyArray_CreateSortedStridePerm(int ndim, npy_intp *strides,
npy_stride_sort_item *out_strideperm)
{
int i;

/* Set up the strideperm values */
for (i = 0; i < ndim; ++i) {
out_strideperm[i].perm = i;
if (shape[i] == 1) {
out_strideperm[i].stride = 0;
}
else {
out_strideperm[i].stride = strides[i];
}
out_strideperm[i].stride = strides[i];
}

/* Sort them */
Expand Down Expand Up @@ -1027,7 +1019,7 @@ PyArray_Ravel(PyArrayObject *arr, NPY_ORDER order)
npy_intp stride;
int i, ndim = PyArray_NDIM(arr);

PyArray_CreateSortedStridePerm(PyArray_NDIM(arr), PyArray_SHAPE(arr),
PyArray_CreateSortedStridePerm(PyArray_NDIM(arr),
PyArray_STRIDES(arr), strideperm);

stride = strideperm[ndim-1].stride;
Expand Down
4 changes: 2 additions & 2 deletions numpy/core/src/umath/reduction.c
Expand Up @@ -51,7 +51,7 @@ allocate_reduce_result(PyArrayObject *arr, npy_bool *axis_flags,
Py_INCREF(dtype);
}

PyArray_CreateSortedStridePerm(PyArray_NDIM(arr), PyArray_SHAPE(arr),
PyArray_CreateSortedStridePerm(PyArray_NDIM(arr),
PyArray_STRIDES(arr), strideperm);

/* Build the new strides and shape */
Expand All @@ -60,7 +60,7 @@ allocate_reduce_result(PyArrayObject *arr, npy_bool *axis_flags,
for (idim = ndim-1; idim >= 0; --idim) {
npy_intp i_perm = strideperm[idim].perm;
if (axis_flags[i_perm]) {
strides[i_perm] = 0;
strides[i_perm] = stride;
shape[i_perm] = 1;
}
else {
Expand Down
4 changes: 2 additions & 2 deletions numpy/core/tests/test_api.py
Expand Up @@ -138,9 +138,9 @@ def test_copyto():
assert_raises(TypeError, np.copyto, [1,2,3], [2,3,4])

def test_copy_order():
a = np.arange(24).reshape(2,3,4)
a = np.arange(24).reshape(2,1,3,4)
b = a.copy(order='F')
c = np.arange(24).reshape(2,4,3).swapaxes(1,2)
c = np.arange(24).reshape(2,1,4,3).swapaxes(2,3)

def check_copy_result(x, y, ccontig, fcontig, strides=False):
assert_(not (x is y))
Expand Down

0 comments on commit 3ecbac5

Please sign in to comment.