From 94472bd702afd2a063e630b3aa020d9ccc0ec7ec Mon Sep 17 00:00:00 2001 From: Sebastian Berg Date: Mon, 20 Feb 2023 15:51:17 +0100 Subject: [PATCH] ENH: Avoid use of item XINCREF and DECREF in fasttake Rather, use the cast function directly when the copy is not trivial (which we know based on it not passing REFCHK, if that passes we assume memcpy is fine). Also uses memcpy, since no overlap is possible here. --- benchmarks/benchmarks/bench_itemselection.py | 2 +- numpy/core/src/multiarray/item_selection.c | 78 ++++++++++++++------ 2 files changed, 56 insertions(+), 24 deletions(-) diff --git a/benchmarks/benchmarks/bench_itemselection.py b/benchmarks/benchmarks/bench_itemselection.py index 518258a8f564..81c788a2c8bd 100644 --- a/benchmarks/benchmarks/bench_itemselection.py +++ b/benchmarks/benchmarks/bench_itemselection.py @@ -7,7 +7,7 @@ class Take(Benchmark): params = [ [(1000, 1), (1000, 2), (2, 1000, 1), (1000, 3)], ["raise", "wrap", "clip"], - TYPES1] + TYPES1 + ["O", "i,O"]] param_names = ["shape", "mode", "dtype"] def setup(self, shape, mode, dtype): diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c index 44c5531337f1..25ea011c3be5 100644 --- a/numpy/core/src/multiarray/item_selection.c +++ b/numpy/core/src/multiarray/item_selection.c @@ -17,6 +17,7 @@ #include "multiarraymodule.h" #include "common.h" +#include "dtype_transfer.h" #include "arrayobject.h" #include "ctors.h" #include "lowlevel_strided_loops.h" @@ -39,7 +40,26 @@ npy_fasttake_impl( PyArray_Descr *dtype, int axis) { NPY_BEGIN_THREADS_DEF; - NPY_BEGIN_THREADS_DESCR(dtype); + + NPY_cast_info cast_info; + NPY_ARRAYMETHOD_FLAGS flags; + NPY_cast_info_init(&cast_info); + + if (!needs_refcounting) { + /* if "refcounting" is not needed memcpy is safe for a simple copy */ + NPY_BEGIN_THREADS; + } + else { + if (PyArray_GetDTypeTransferFunction( + 1, itemsize, itemsize, dtype, dtype, 0, + &cast_info, &flags) < 0) { + return -1; + } + if (!(flags & NPY_METH_REQUIRES_PYAPI)) { + NPY_BEGIN_THREADS; + } + } + switch (clipmode) { case NPY_RAISE: for (npy_intp i = 0; i < n; i++) { @@ -47,20 +67,22 @@ npy_fasttake_impl( npy_intp tmp = indices[j]; if (check_and_adjust_index(&tmp, max_item, axis, _save) < 0) { - return -1; + goto fail; } char *tmp_src = src + tmp * chunk; if (needs_refcounting) { - for (npy_intp k = 0; k < nelem; k++) { - PyArray_Item_INCREF(tmp_src, dtype); - PyArray_Item_XDECREF(dest, dtype); - memmove(dest, tmp_src, itemsize); - dest += itemsize; - tmp_src += itemsize; + char *data[2] = {tmp_src, dest}; + npy_intp strides[2] = {itemsize, itemsize}; + if (cast_info.func( + &cast_info.context, data, &nelem, strides, + cast_info.auxdata) < 0) { + NPY_END_THREADS; + goto fail; } + dest += itemsize * nelem; } else { - memmove(dest, tmp_src, chunk); + memcpy(dest, tmp_src, chunk); dest += chunk; } } @@ -83,16 +105,18 @@ npy_fasttake_impl( } char *tmp_src = src + tmp * chunk; if (needs_refcounting) { - for (npy_intp k = 0; k < nelem; k++) { - PyArray_Item_INCREF(tmp_src, dtype); - PyArray_Item_XDECREF(dest, dtype); - memmove(dest, tmp_src, itemsize); - dest += itemsize; - tmp_src += itemsize; + char *data[2] = {tmp_src, dest}; + npy_intp strides[2] = {itemsize, itemsize}; + if (cast_info.func( + &cast_info.context, data, &nelem, strides, + cast_info.auxdata) < 0) { + NPY_END_THREADS; + goto fail; } + dest += itemsize * nelem; } else { - memmove(dest, tmp_src, chunk); + memcpy(dest, tmp_src, chunk); dest += chunk; } } @@ -111,16 +135,18 @@ npy_fasttake_impl( } char *tmp_src = src + tmp * chunk; if (needs_refcounting) { - for (npy_intp k = 0; k < nelem; k++) { - PyArray_Item_INCREF(tmp_src, dtype); - PyArray_Item_XDECREF(dest, dtype); - memmove(dest, tmp_src, itemsize); - dest += itemsize; - tmp_src += itemsize; + char *data[2] = {tmp_src, dest}; + npy_intp strides[2] = {itemsize, itemsize}; + if (cast_info.func( + &cast_info.context, data, &nelem, strides, + cast_info.auxdata) < 0) { + NPY_END_THREADS; + goto fail; } + dest += itemsize * nelem; } else { - memmove(dest, tmp_src, chunk); + memcpy(dest, tmp_src, chunk); dest += chunk; } } @@ -130,7 +156,13 @@ npy_fasttake_impl( } NPY_END_THREADS; + NPY_cast_info_xfree(&cast_info); return 0; + + fail: + /* NPY_END_THREADS already ensured. */ + NPY_cast_info_xfree(&cast_info); + return -1; }