Permalink
Browse files

ENH Improve output for all NaNs; Add performance notes.

  • Loading branch information...
kwgoodman committed May 18, 2012
1 parent 412b853 commit aa7d73940b955ea37b2e141ad1077b755d3b288d
Showing with 37 additions and 7 deletions.
  1. +1 −1 RELEASE.rst
  2. +28 −6 bottleneck/src/template/func/nn.py
  3. +8 −0 bottleneck/tests/func_test.py
View
@@ -16,7 +16,7 @@ Thanks to Dougal Sutherland, Bottleneck now runs on Python 3.2.
**New functions**
- replace(arr, old, new), e.g., replace(arr, np.nan, 0)
-- nn(arr, arr0, axis) nearest neighbor of 1d arr0 in 2d arr
+- nn(arr, arr0, axis) nearest neighbor and its index of 1d arr0 in 2d arr
**Enhancements**
@@ -22,10 +22,10 @@
@cython.wraparound(False)
def NAME_NDIMd_DTYPE_axisAXIS(np.ndarray[np.DTYPE_t, ndim=2] a,
np.ndarray[np.DTYPE_t, ndim=1] a0):
- "Nearest neighbor of NDIMd array with dtype=DTYPE along axis=AXIS."
+ "Nearest neighbor of 1d `a0` in 2d `a` with dtype=DTYPE, axis=AXIS."
cdef:
np.float64_t xsum = 0, d, xsummin=np.inf, dist
- Py_ssize_t imin = 0, n, a0size
+ Py_ssize_t imin = -1, n, a0size
"""
loop = {}
@@ -41,7 +41,11 @@ def NAME_NDIMd_DTYPE_axisAXIS(np.ndarray[np.DTYPE_t, ndim=2] a,
if xsum < xsummin:
xsummin = xsum
imin = iINDEX0
- dist = sqrt(xsummin)
+ if imin == -1:
+ dist = NAN
+ imin = 0
+ else:
+ dist = sqrt(xsummin)
return dist, imin
"""
floats['loop'] = loop
@@ -82,7 +86,10 @@ def nn(arr, arr0, int axis=1):
The squared distance used to determine the nearest neighbor of `arr0`
is equivalent to np.sum((arr - arr0) ** 2, axis) where `arr` is 2d
- and `arr0` is 1d.
+ and `arr0` is 1d; note that `arr0` must be reshaped if `axis` is 1.
+
+ If all distances are NaN then the distance returned is NaN and the
+ index is zero.
Parameters
----------
@@ -98,10 +105,25 @@ def nn(arr, arr0, int axis=1):
-------
dist : np.float64
The Euclidean distance between `arr0` and the nearest neighbor
- in `arr`.
+ in `arr`. If all distances are NaN then the distance returned
+ is NaN.
idx : int
- Index of nearest neighbor in `arr`.
+ Index of nearest neighbor in `arr`. If all distances are NaN
+ then the index returned is zero.
+
+ Notes
+ -----
+ A brute force algorithm is used to find the nearest neighbor.
+
+ Depending on the shapes of `arr` and `arr0`, SciPy's cKDTree may
+ be faster than bn.nn(). So benchmark if speed is important.
+ The relative speed also depends on how many times you will use
+ the same array `arr` to find nearest neighbors with different
+ `arr0`. That is because it takes time to set up SciPy's cKDTree.
+
+ If `arr` fits into your memory's cache then bn.nn() is fast.
+
Examples
--------
Create the input arrays:
@@ -172,6 +172,14 @@ def arrays2(dtypes=bn.dtypes):
yield arr.copy(), arr0[:2].copy(), axis
axis = 1
yield arr.copy(), arr0.copy(), axis
+ if issubclass(arr.dtype.type, np.inexact):
+ # Make sure NaNs are handled in the same way
+ arr.fill(np.nan)
+ arr0.fill(np.nan)
+ axis = 0
+ yield arr.copy(), arr0[:2].copy(), axis
+ axis = 1
+ yield arr.copy(), arr0.copy(), axis
def test_nn():
"Test that bn.nn gives the same output as bn.slow.nn."

0 comments on commit aa7d739

Please sign in to comment.