Skip to content

Commit

Permalink
ENH: Avoid use of item XINCREF and DECREF in fasttake
Browse files Browse the repository at this point in the history
Rather, use the cast function directly when the copy is not trivial
(which we know based on it not passing REFCHK, if that passes we assume
memcpy is fine).

Also uses memcpy, since no overlap is possible here.
  • Loading branch information
seberg committed Feb 20, 2023
1 parent af9f656 commit 94472bd
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 24 deletions.
2 changes: 1 addition & 1 deletion benchmarks/benchmarks/bench_itemselection.py
Expand Up @@ -7,7 +7,7 @@ class Take(Benchmark):
params = [
[(1000, 1), (1000, 2), (2, 1000, 1), (1000, 3)],
["raise", "wrap", "clip"],
TYPES1]
TYPES1 + ["O", "i,O"]]
param_names = ["shape", "mode", "dtype"]

def setup(self, shape, mode, dtype):
Expand Down
78 changes: 55 additions & 23 deletions numpy/core/src/multiarray/item_selection.c
Expand Up @@ -17,6 +17,7 @@

#include "multiarraymodule.h"
#include "common.h"
#include "dtype_transfer.h"
#include "arrayobject.h"
#include "ctors.h"
#include "lowlevel_strided_loops.h"
Expand All @@ -39,28 +40,49 @@ npy_fasttake_impl(
PyArray_Descr *dtype, int axis)
{
NPY_BEGIN_THREADS_DEF;
NPY_BEGIN_THREADS_DESCR(dtype);

NPY_cast_info cast_info;
NPY_ARRAYMETHOD_FLAGS flags;
NPY_cast_info_init(&cast_info);

if (!needs_refcounting) {
/* if "refcounting" is not needed memcpy is safe for a simple copy */
NPY_BEGIN_THREADS;
}
else {
if (PyArray_GetDTypeTransferFunction(
1, itemsize, itemsize, dtype, dtype, 0,
&cast_info, &flags) < 0) {
return -1;
}
if (!(flags & NPY_METH_REQUIRES_PYAPI)) {
NPY_BEGIN_THREADS;
}
}

switch (clipmode) {
case NPY_RAISE:
for (npy_intp i = 0; i < n; i++) {
for (npy_intp j = 0; j < m; j++) {
npy_intp tmp = indices[j];
if (check_and_adjust_index(&tmp, max_item, axis,
_save) < 0) {
return -1;
goto fail;
}
char *tmp_src = src + tmp * chunk;
if (needs_refcounting) {
for (npy_intp k = 0; k < nelem; k++) {
PyArray_Item_INCREF(tmp_src, dtype);
PyArray_Item_XDECREF(dest, dtype);
memmove(dest, tmp_src, itemsize);
dest += itemsize;
tmp_src += itemsize;
char *data[2] = {tmp_src, dest};
npy_intp strides[2] = {itemsize, itemsize};
if (cast_info.func(
&cast_info.context, data, &nelem, strides,
cast_info.auxdata) < 0) {
NPY_END_THREADS;
goto fail;
}
dest += itemsize * nelem;
}
else {
memmove(dest, tmp_src, chunk);
memcpy(dest, tmp_src, chunk);
dest += chunk;
}
}
Expand All @@ -83,16 +105,18 @@ npy_fasttake_impl(
}
char *tmp_src = src + tmp * chunk;
if (needs_refcounting) {
for (npy_intp k = 0; k < nelem; k++) {
PyArray_Item_INCREF(tmp_src, dtype);
PyArray_Item_XDECREF(dest, dtype);
memmove(dest, tmp_src, itemsize);
dest += itemsize;
tmp_src += itemsize;
char *data[2] = {tmp_src, dest};
npy_intp strides[2] = {itemsize, itemsize};
if (cast_info.func(
&cast_info.context, data, &nelem, strides,
cast_info.auxdata) < 0) {
NPY_END_THREADS;
goto fail;
}
dest += itemsize * nelem;
}
else {
memmove(dest, tmp_src, chunk);
memcpy(dest, tmp_src, chunk);
dest += chunk;
}
}
Expand All @@ -111,16 +135,18 @@ npy_fasttake_impl(
}
char *tmp_src = src + tmp * chunk;
if (needs_refcounting) {
for (npy_intp k = 0; k < nelem; k++) {
PyArray_Item_INCREF(tmp_src, dtype);
PyArray_Item_XDECREF(dest, dtype);
memmove(dest, tmp_src, itemsize);
dest += itemsize;
tmp_src += itemsize;
char *data[2] = {tmp_src, dest};
npy_intp strides[2] = {itemsize, itemsize};
if (cast_info.func(
&cast_info.context, data, &nelem, strides,
cast_info.auxdata) < 0) {
NPY_END_THREADS;
goto fail;
}
dest += itemsize * nelem;
}
else {
memmove(dest, tmp_src, chunk);
memcpy(dest, tmp_src, chunk);
dest += chunk;
}
}
Expand All @@ -130,7 +156,13 @@ npy_fasttake_impl(
}

NPY_END_THREADS;
NPY_cast_info_xfree(&cast_info);
return 0;

fail:
/* NPY_END_THREADS already ensured. */
NPY_cast_info_xfree(&cast_info);
return -1;
}


Expand Down

0 comments on commit 94472bd

Please sign in to comment.