Skip to content

Commit

Permalink
MAINT: Use strong references/copies for sorting buffer
Browse files Browse the repository at this point in the history
I accidentally included a start for a cleanup in the other PR which
was incorrect because we actually sorted weak references.

Sorting weak references buffer is correct, but seems a bit trickier
to reason about and also potentially to generalize.

This makes sure we have strong references everywhere and fixes the
issue seen by pandas.  Also adds a (slightly complex) test to cover
both the sort and argsort path.
  • Loading branch information
seberg committed Feb 25, 2023
1 parent 5f04e74 commit 07f603d
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 19 deletions.
28 changes: 16 additions & 12 deletions numpy/core/src/multiarray/item_selection.c
Expand Up @@ -22,6 +22,7 @@
#include "ctors.h"
#include "lowlevel_strided_loops.h"
#include "array_assign.h"
#include "refcount.h"

#include "npy_sort.h"
#include "npy_partition.h"
Expand Down Expand Up @@ -1070,6 +1071,9 @@ _new_sortlike(PyArrayObject *op, int axis, PyArray_SortFunc *sort,
ret = -1;
goto fail;
}
if (PyDataType_FLAGCHK(PyArray_DESCR(op), NPY_NEEDS_INIT)) {
memset(buffer, 0, N * elsize);
}
}

NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(op));
Expand Down Expand Up @@ -1114,16 +1118,7 @@ _new_sortlike(PyArrayObject *op, int axis, PyArray_SortFunc *sort,
}

if (needcopy) {
if (hasrefs) {
if (swap) {
copyswapn(buffer, elsize, NULL, 0, N, swap, op);
}
_unaligned_strided_byte_copy(it->dataptr, astride,
buffer, elsize, N, elsize);
}
else {
copyswapn(it->dataptr, astride, buffer, elsize, N, swap, op);
}
copyswapn(it->dataptr, astride, buffer, elsize, N, swap, op);
}

PyArray_ITER_NEXT(it);
Expand All @@ -1132,7 +1127,10 @@ _new_sortlike(PyArrayObject *op, int axis, PyArray_SortFunc *sort,
fail:
NPY_END_THREADS_DESCR(PyArray_DESCR(op));
/* cleanup internal buffer */
PyDataMem_UserFREE(buffer, N * elsize, mem_handler);
if (needcopy) {
PyArray_ClearBuffer(PyArray_DESCR(op), buffer, elsize, N, 1);
PyDataMem_UserFREE(buffer, N * elsize, mem_handler);
}
if (ret < 0 && !PyErr_Occurred()) {
/* Out of memory during sorting or buffer creation */
PyErr_NoMemory();
Expand Down Expand Up @@ -1206,6 +1204,9 @@ _new_argsortlike(PyArrayObject *op, int axis, PyArray_ArgSortFunc *argsort,
ret = -1;
goto fail;
}
if (PyDataType_FLAGCHK(PyArray_DESCR(op), NPY_NEEDS_INIT)) {
memset(valbuffer, 0, N * elsize);
}
}

if (needidxbuffer) {
Expand Down Expand Up @@ -1281,7 +1282,10 @@ _new_argsortlike(PyArrayObject *op, int axis, PyArray_ArgSortFunc *argsort,
fail:
NPY_END_THREADS_DESCR(PyArray_DESCR(op));
/* cleanup internal buffers */
PyDataMem_UserFREE(valbuffer, N * elsize, mem_handler);
if (needcopy) {
PyArray_ClearBuffer(PyArray_DESCR(op), valbuffer, elsize, N, 1);
PyDataMem_UserFREE(valbuffer, N * elsize, mem_handler);
}
PyDataMem_UserFREE(idxbuffer, N * sizeof(npy_intp), mem_handler);
if (ret < 0) {
if (!PyErr_Occurred()) {
Expand Down
21 changes: 14 additions & 7 deletions numpy/core/tests/test_multiarray.py
Expand Up @@ -2116,19 +2116,26 @@ def test_sort_object(self):
c.sort(kind=kind)
assert_equal(c, a, msg)

def test_sort_structured(self):
@pytest.mark.parametrize("dt", [
np.dtype([('f', float), ('i', int)]),
np.dtype([('f', float), ('i', object)])])
@pytest.mark.parametrize("step", [1, 2])
def test_sort_structured(self, dt, step):
# test record array sorts.
dt = np.dtype([('f', float), ('i', int)])
a = np.array([(i, i) for i in range(101)], dtype=dt)
a = np.array([(i, i) for i in range(101*step)], dtype=dt)
b = a[::-1]
for kind in ['q', 'h', 'm']:
msg = "kind=%s" % kind
c = a.copy()
c = a.copy()[::step]
indx = c.argsort(kind=kind)
c.sort(kind=kind)
assert_equal(c, a, msg)
c = b.copy()
assert_equal(c, a[::step], msg)
assert_equal(a[::step][indx], a[::step], msg)
c = b.copy()[::step]
indx = c.argsort(kind=kind)
c.sort(kind=kind)
assert_equal(c, a, msg)
assert_equal(c, a[step-1::step], msg)
assert_equal(b[::step][indx], a[step-1::step], msg)

@pytest.mark.parametrize('dtype', ['datetime64[D]', 'timedelta64[D]'])
def test_sort_time(self, dtype):
Expand Down

0 comments on commit 07f603d

Please sign in to comment.