Skip to content

Commit

Permalink
Merge pull request numpy#4244 from jaimefrio/binsearch
Browse files Browse the repository at this point in the history
ENH: Type specific binary search functions for `searchsorted`
  • Loading branch information
juliantaylor committed Feb 14, 2014
2 parents 0e9956e + 9350d4d commit db198d5
Show file tree
Hide file tree
Showing 8 changed files with 500 additions and 228 deletions.
6 changes: 6 additions & 0 deletions doc/release/1.9.0-notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,12 @@ The performance of converting lists containing arrays to arrays using
`np.array` has been improved. It is now equivalent in speed to
`np.vstack(list)`.

Performance improvement for `np.searchsorted`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
For the built-in numeric types, `np.searchsorted` no longer relies on the
data type's `compare` function to perform the search, but is now implemented
by type specific functions. Depending on the size of the inputs, this can
result in performance improvements over 2x.

Changes
=======
Expand Down
4 changes: 3 additions & 1 deletion numpy/core/bento.info
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@ Library:
CompiledLibrary: npysort
Sources:
src/private/npy_partition.h.src,
src/private/npy_binsearch.h.src,
src/npysort/quicksort.c.src,
src/npysort/mergesort.c.src,
src/npysort/heapsort.c.src,
src/npysort/selection.c.src
src/npysort/selection.c.src,
src/npysort/binsearch.c.src
Extension: multiarray
Sources:
src/multiarray/multiarraymodule_onefile.c
Expand Down
7 changes: 5 additions & 2 deletions numpy/core/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,8 +707,11 @@ def get_mathlib_info(*args):
npysort_sources=[join('src', 'npysort', 'quicksort.c.src'),
join('src', 'npysort', 'mergesort.c.src'),
join('src', 'npysort', 'heapsort.c.src'),
join('src','private', 'npy_partition.h.src'),
join('src', 'npysort', 'selection.c.src')]
join('src', 'private', 'npy_partition.h.src'),
join('src', 'npysort', 'selection.c.src'),
join('src', 'private', 'npy_binsearch.h.src'),
join('src', 'npysort', 'binsearch.c.src'),
]
config.add_library('npysort',
sources=npysort_sources,
include_dirs=[])
Expand Down
2 changes: 1 addition & 1 deletion numpy/core/src/multiarray/ctors.c
Original file line number Diff line number Diff line change
Expand Up @@ -1842,7 +1842,7 @@ PyArray_CheckFromAny(PyObject *op, PyArray_Descr *descr, int min_depth,
else if (descr && !PyArray_ISNBO(descr->byteorder)) {
PyArray_DESCR_REPLACE(descr);
}
if (descr) {
if (descr && descr->byteorder != NPY_IGNORE) {
descr->byteorder = NPY_NATIVE;
}
}
Expand Down
274 changes: 50 additions & 224 deletions numpy/core/src/multiarray/item_selection.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "item_selection.h"
#include "npy_sort.h"
#include "npy_partition.h"
#include "npy_binsearch.h"

/*NUMPY_API
* Take
Expand Down Expand Up @@ -1866,196 +1867,6 @@ PyArray_LexSort(PyObject *sort_keys, int axis)
}


/** @brief Left-sided bisection: locate the first entry that is >= each key.
 *
 * Every key is looked up independently with a binary search over arr. The
 * value stored for a key is the smallest index i for which the compare
 * function reports arr[i] >= key, or the length of arr when no entry
 * qualifies. arr and key must share the same comparable type, since the
 * compare function is taken from the descriptor of key.
 *
 * @param arr 1d, strided, sorted array to be searched.
 * @param key contiguous array of keys.
 * @param ret contiguous array of intp receiving one index per key.
 * @return void
 */
static void
local_search_left(PyArrayObject *arr, PyArrayObject *key, PyArrayObject *ret)
{
    PyArray_CompareFunc *cmp = PyArray_DESCR(key)->f->compare;
    const npy_intp haystack_len = PyArray_DIMS(arr)[PyArray_NDIM(arr) - 1];
    const npy_intp haystack_stride = *PyArray_STRIDES(arr);
    const npy_intp key_count = PyArray_SIZE(key);
    const int key_size = PyArray_DESCR(key)->elsize;
    char *haystack = PyArray_DATA(arr);
    char *needle = PyArray_DATA(key);
    npy_intp *out = (npy_intp *)PyArray_DATA(ret);
    npy_intp k;

    /* keys are walked in steps of one element size */
    for (k = 0; k < key_count; ++k, needle += key_size) {
        npy_intp lo = 0;
        npy_intp hi = haystack_len;

        /* shrink the half-open range [lo, hi) until it is empty */
        while (lo < hi) {
            /* midpoint written this way cannot overflow */
            const npy_intp mid = lo + ((hi - lo) >> 1);

            if (cmp(haystack + haystack_stride*mid, needle, key) < 0) {
                /* arr[mid] < key: the answer lies strictly to the right */
                lo = mid + 1;
            }
            else {
                /* arr[mid] >= key: mid itself is still a candidate */
                hi = mid;
            }
        }
        *out++ = lo;
    }
}


/** @brief Right-sided bisection: locate the first entry that is > each key.
 *
 * Every key is looked up independently with a binary search over arr. The
 * value stored for a key is the smallest index i for which the compare
 * function reports arr[i] > key, or the length of arr when no entry
 * qualifies. arr and key must share the same comparable type, since the
 * compare function is taken from the descriptor of key.
 *
 * @param arr 1d, strided, sorted array to be searched.
 * @param key contiguous array of keys.
 * @param ret contiguous array of intp receiving one index per key.
 * @return void
 */
static void
local_search_right(PyArrayObject *arr, PyArrayObject *key, PyArrayObject *ret)
{
    PyArray_CompareFunc *cmp = PyArray_DESCR(key)->f->compare;
    const npy_intp haystack_len = PyArray_DIMS(arr)[PyArray_NDIM(arr) - 1];
    const npy_intp haystack_stride = *PyArray_STRIDES(arr);
    const npy_intp key_count = PyArray_SIZE(key);
    const int key_size = PyArray_DESCR(key)->elsize;
    char *haystack = PyArray_DATA(arr);
    char *needle = PyArray_DATA(key);
    npy_intp *out = (npy_intp *)PyArray_DATA(ret);
    npy_intp k;

    /* keys are walked in steps of one element size */
    for (k = 0; k < key_count; ++k, needle += key_size) {
        npy_intp lo = 0;
        npy_intp hi = haystack_len;

        /* shrink the half-open range [lo, hi) until it is empty */
        while (lo < hi) {
            /* midpoint written this way cannot overflow */
            const npy_intp mid = lo + ((hi - lo) >> 1);

            if (cmp(haystack + haystack_stride*mid, needle, key) > 0) {
                /* arr[mid] > key: mid itself is still a candidate */
                hi = mid;
            }
            else {
                /* arr[mid] <= key: the answer lies strictly to the right */
                lo = mid + 1;
            }
        }
        *out++ = lo;
    }
}

/** @brief Left-sided bisection through a sorter: first entries >= keys.
 *
 * arr is not read in storage order; position p of the search reads
 * arr[sorter[p]]. For each key the smallest position p is stored for which
 * the compare function reports arr[sorter[p]] >= key, or the length of arr
 * when no position qualifies. arr and key must share the same comparable
 * type, since the compare function is taken from the descriptor of key.
 *
 * Every sorter entry that is visited is range checked before it is used.
 * The search stops at the first invalid entry; results already written to
 * ret for earlier keys are left in place.
 *
 * @param arr 1d, strided array to be searched.
 * @param key contiguous array of keys.
 * @param sorter 1d, strided array of intp that sorts arr.
 * @param ret contiguous array of intp receiving one index per key.
 * @return 0 on success, -1 if a visited sorter entry is outside
 *         [0, length of arr).
 */
static int
local_argsearch_left(PyArrayObject *arr, PyArrayObject *key,
                     PyArrayObject *sorter, PyArrayObject *ret)
{
    PyArray_CompareFunc *cmp = PyArray_DESCR(key)->f->compare;
    const npy_intp haystack_len = PyArray_DIMS(arr)[PyArray_NDIM(arr) - 1];
    const npy_intp haystack_stride = *PyArray_STRIDES(arr);
    const npy_intp order_stride = *PyArray_STRIDES(sorter);
    const npy_intp key_count = PyArray_SIZE(key);
    const int key_size = PyArray_DESCR(key)->elsize;
    char *haystack = PyArray_DATA(arr);
    char *needle = PyArray_DATA(key);
    char *order = PyArray_DATA(sorter);
    npy_intp *out = (npy_intp *)PyArray_DATA(ret);
    npy_intp k;

    /* keys are walked in steps of one element size */
    for (k = 0; k < key_count; ++k, needle += key_size) {
        npy_intp lo = 0;
        npy_intp hi = haystack_len;

        /* shrink the half-open range [lo, hi) until it is empty */
        while (lo < hi) {
            /* midpoint written this way cannot overflow */
            const npy_intp mid = lo + ((hi - lo) >> 1);
            /* translate the sorted position into an index of arr */
            const npy_intp idx = *(npy_intp *)(order + order_stride * mid);

            /* refuse to dereference arr with an out of range index */
            if (!(0 <= idx && idx < haystack_len)) {
                return -1;
            }
            if (cmp(haystack + haystack_stride*idx, needle, key) < 0) {
                /* element < key: the answer lies strictly to the right */
                lo = mid + 1;
            }
            else {
                /* element >= key: mid itself is still a candidate */
                hi = mid;
            }
        }
        *out++ = lo;
    }
    return 0;
}


/** @brief Right-sided bisection through a sorter: first entries > keys.
 *
 * arr is not read in storage order; position p of the search reads
 * arr[sorter[p]]. For each key the smallest position p is stored for which
 * the compare function reports arr[sorter[p]] > key, or the length of arr
 * when no position qualifies. arr and key must share the same comparable
 * type, since the compare function is taken from the descriptor of key.
 *
 * Every sorter entry that is visited is range checked before it is used.
 * The search stops at the first invalid entry; results already written to
 * ret for earlier keys are left in place.
 *
 * @param arr 1d, strided array to be searched.
 * @param key contiguous array of keys.
 * @param sorter 1d, strided array of intp that sorts arr.
 * @param ret contiguous array of intp receiving one index per key.
 * @return 0 on success, -1 if a visited sorter entry is outside
 *         [0, length of arr).
 */
static int
local_argsearch_right(PyArrayObject *arr, PyArrayObject *key,
                      PyArrayObject *sorter, PyArrayObject *ret)
{
    PyArray_CompareFunc *cmp = PyArray_DESCR(key)->f->compare;
    const npy_intp haystack_len = PyArray_DIMS(arr)[PyArray_NDIM(arr) - 1];
    const npy_intp haystack_stride = *PyArray_STRIDES(arr);
    const npy_intp order_stride = *PyArray_STRIDES(sorter);
    const npy_intp key_count = PyArray_SIZE(key);
    const int key_size = PyArray_DESCR(key)->elsize;
    char *haystack = PyArray_DATA(arr);
    char *needle = PyArray_DATA(key);
    char *order = PyArray_DATA(sorter);
    npy_intp *out = (npy_intp *)PyArray_DATA(ret);
    npy_intp k;

    /* keys are walked in steps of one element size */
    for (k = 0; k < key_count; ++k, needle += key_size) {
        npy_intp lo = 0;
        npy_intp hi = haystack_len;

        /* shrink the half-open range [lo, hi) until it is empty */
        while (lo < hi) {
            /* midpoint written this way cannot overflow */
            const npy_intp mid = lo + ((hi - lo) >> 1);
            /* translate the sorted position into an index of arr */
            const npy_intp idx = *(npy_intp *)(order + order_stride * mid);

            /* refuse to dereference arr with an out of range index */
            if (!(0 <= idx && idx < haystack_len)) {
                return -1;
            }
            if (cmp(haystack + haystack_stride*idx, needle, key) > 0) {
                /* element > key: mid itself is still a candidate */
                hi = mid;
            }
            else {
                /* element <= key: the answer lies strictly to the right */
                lo = mid + 1;
            }
        }
        *out++ = lo;
    }
    return 0;
}

/*NUMPY_API
*
* Search the sorted array op1 for the location of the items in op2. The
Expand Down Expand Up @@ -2096,47 +1907,63 @@ PyArray_SearchSorted(PyArrayObject *op1, PyObject *op2,
PyArrayObject *ret = NULL;
PyArray_Descr *dtype;
int ap1_flags = NPY_ARRAY_NOTSWAPPED | NPY_ARRAY_ALIGNED;
PyArray_BinSearchFunc *binsearch = NULL;
PyArray_ArgBinSearchFunc *argbinsearch = NULL;
NPY_BEGIN_THREADS_DEF;

/* Find common type */
dtype = PyArray_DescrFromObject((PyObject *)op2, PyArray_DESCR(op1));
if (dtype == NULL) {
return NULL;
}
/* refs to dtype we own = 1 */

/* Look for binary search function */
if (perm) {
argbinsearch = get_argbinsearch_func(dtype, side);
}
else {
binsearch = get_binsearch_func(dtype, side);
}
if (binsearch == NULL && argbinsearch == NULL) {
PyErr_SetString(PyExc_TypeError, "compare not supported for type");
/* refs to dtype we own = 1 */
Py_DECREF(dtype);
/* refs to dtype we own = 0 */
return NULL;
}

/* need ap2 as contiguous array and of right type */
/* refs to dtype we own = 1 */
Py_INCREF(dtype);
/* refs to dtype we own = 2 */
ap2 = (PyArrayObject *)PyArray_CheckFromAny(op2, dtype,
0, 0,
NPY_ARRAY_CARRAY_RO | NPY_ARRAY_NOTSWAPPED,
NULL);
/* refs to dtype we own = 1, array creation steals one even on failure */
if (ap2 == NULL) {
Py_DECREF(dtype);
/* refs to dtype we own = 0 */
return NULL;
}

/*
* If the needle (ap2) is larger than the haystack (op1) we copy the
* haystack to a continuous array for improved cache utilization.
* haystack to a contiguous array for improved cache utilization.
*/
if (PyArray_SIZE(ap2) > PyArray_SIZE(op1)) {
ap1_flags |= NPY_ARRAY_CARRAY_RO;
}

ap1 = (PyArrayObject *)PyArray_CheckFromAny((PyObject *)op1, dtype,
1, 1, ap1_flags, NULL);
/* refs to dtype we own = 0, array creation steals one even on failure */
if (ap1 == NULL) {
goto fail;
}
/* check that comparison function exists */
if (PyArray_DESCR(ap2)->f->compare == NULL) {
PyErr_SetString(PyExc_TypeError,
"compare not supported for type");
goto fail;
}

if (perm) {
/* need ap3 as contiguous array and of right type */
/* need ap3 as a 1D aligned, not swapped, array of right type */
ap3 = (PyArrayObject *)PyArray_CheckFromAny(perm, NULL,
1, 1,
NPY_ARRAY_ALIGNED | NPY_ARRAY_NOTSWAPPED,
Expand Down Expand Up @@ -2167,7 +1994,7 @@ PyArray_SearchSorted(PyArrayObject *op1, PyObject *op2,
}
}

/* ret is a contiguous array of intp type to hold returned indices */
/* ret is a contiguous array of intp type to hold returned indexes */
ret = (PyArrayObject *)PyArray_New(Py_TYPE(ap2), PyArray_NDIM(ap2),
PyArray_DIMS(ap2), NPY_INTP,
NULL, NULL, 0, 0, (PyObject *)ap2);
Expand All @@ -2176,33 +2003,32 @@ PyArray_SearchSorted(PyArrayObject *op1, PyObject *op2,
}

if (ap3 == NULL) {
if (side == NPY_SEARCHLEFT) {
NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap2));
local_search_left(ap1, ap2, ret);
NPY_END_THREADS_DESCR(PyArray_DESCR(ap2));
}
else if (side == NPY_SEARCHRIGHT) {
NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap2));
local_search_right(ap1, ap2, ret);
NPY_END_THREADS_DESCR(PyArray_DESCR(ap2));
}
/* do regular binsearch */
NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap2));
binsearch((const char *)PyArray_DATA(ap1),
(const char *)PyArray_DATA(ap2),
(char *)PyArray_DATA(ret),
PyArray_SIZE(ap1), PyArray_SIZE(ap2),
PyArray_STRIDES(ap1)[0], PyArray_DESCR(ap2)->elsize,
NPY_SIZEOF_INTP, ap2);
NPY_END_THREADS_DESCR(PyArray_DESCR(ap2));
}
else {
int err=0;

if (side == NPY_SEARCHLEFT) {
NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap2));
err = local_argsearch_left(ap1, ap2, sorter, ret);
NPY_END_THREADS_DESCR(PyArray_DESCR(ap2));
}
else if (side == NPY_SEARCHRIGHT) {
NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap2));
err = local_argsearch_right(ap1, ap2, sorter, ret);
NPY_END_THREADS_DESCR(PyArray_DESCR(ap2));
}
if (err < 0) {
/* do binsearch with a sorter array */
int error = 0;
NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap2));
error = argbinsearch((const char *)PyArray_DATA(ap1),
(const char *)PyArray_DATA(ap2),
(const char *)PyArray_DATA(ap3),
(char *)PyArray_DATA(ret),
PyArray_SIZE(ap1), PyArray_SIZE(ap2),
PyArray_STRIDES(ap1)[0],
PyArray_DESCR(ap2)->elsize,
PyArray_STRIDES(ap3)[0], NPY_SIZEOF_INTP, ap2);
NPY_END_THREADS_DESCR(PyArray_DESCR(ap2));
if (error < 0) {
PyErr_SetString(PyExc_ValueError,
"Sorter index out of range.");
"Sorter index out of range.");
goto fail;
}
Py_DECREF(ap3);
Expand Down
Loading

0 comments on commit db198d5

Please sign in to comment.