Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for np.partition #3320

Merged
merged 25 commits into from
Oct 17, 2018
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/reference/numpysupported.rst
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ The following top-level functions are supported:
* :class:`numpy.nditer` (only the first argument)
* :func:`numpy.ones` (only the 2 first arguments)
* :func:`numpy.ones_like` (only the 2 first arguments)
* :func:`numpy.partition` (only the 2 first arguments)
* :func:`numpy.ravel` (no order argument; 'C' order only)
* :func:`numpy.reshape` (no order argument; 'C' order only)
* :func:`numpy.roots`
Expand Down
186 changes: 142 additions & 44 deletions numba/targets/arraymath.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,54 +721,76 @@ def nancumsum_impl(arr):
# Median and partitioning

@register_jitable
def _partition(A, low, high):
mid = (low + high) >> 1
# NOTE: the pattern of swaps below for the pivot choice and the
# partitioning gives good results (i.e. regular O(n log n))
# on sorted, reverse-sorted, and uniform arrays. Subtle changes
# risk breaking this property.

# Use median of three {low, middle, high} as the pivot
if A[mid] < A[low]:
A[low], A[mid] = A[mid], A[low]
if A[high] < A[mid]:
A[high], A[mid] = A[mid], A[high]
if A[mid] < A[low]:
A[low], A[mid] = A[mid], A[low]
pivot = A[mid]
def less_than(a, b):
return a < b

A[high], A[mid] = A[mid], A[high]
i = low
j = high - 1
while True:
while i < high and A[i] < pivot:
@register_jitable
def nan_aware_less_than(a, b):
if np.isnan(a):
return False
else:
if np.isnan(b):
return True
else:
return a < b

def _partition_factory(pivotimpl):
def _partition(A, low, high):
mid = (low + high) >> 1
# NOTE: the pattern of swaps below for the pivot choice and the
# partitioning gives good results (i.e. regular O(n log n))
# on sorted, reverse-sorted, and uniform arrays. Subtle changes
# risk breaking this property.

# Use median of three {low, middle, high} as the pivot
if pivotimpl(A[mid], A[low]):
A[low], A[mid] = A[mid], A[low]
if pivotimpl(A[high], A[mid]):
A[high], A[mid] = A[mid], A[high]
if pivotimpl(A[mid], A[low]):
A[low], A[mid] = A[mid], A[low]
pivot = A[mid]

A[high], A[mid] = A[mid], A[high]
i = low
j = high - 1
while True:
while i < high and pivotimpl(A[i], pivot):
i += 1
while j >= low and pivotimpl(pivot, A[j]):
j -= 1
if i >= j:
break
A[i], A[j] = A[j], A[i]
i += 1
while j >= low and pivot < A[j]:
j -= 1
if i >= j:
break
A[i], A[j] = A[j], A[i]
i += 1
j -= 1
# Put the pivot back in its final place (all items before `i`
# are smaller than the pivot, all items at/after `i` are larger)
A[i], A[high] = A[high], A[i]
return i
# Put the pivot back in its final place (all items before `i`
# are smaller than the pivot, all items at/after `i` are larger)
A[i], A[high] = A[high], A[i]
return i
return _partition

@register_jitable
def _select(arry, k, low, high):
"""
Select the k'th smallest element in array[low:high + 1].
"""
i = _partition(arry, low, high)
while i != k:
if i < k:
low = i + 1
i = _partition(arry, low, high)
else:
high = i - 1
i = _partition(arry, low, high)
return arry[k]
_partition = register_jitable(_partition_factory(less_than))
_partition_w_nan = register_jitable(_partition_factory(nan_aware_less_than))

def _select_factory(partitionimpl):
def _select(arry, k, low, high):
"""
Select the k'th smallest element in array[low:high + 1].
"""
i = partitionimpl(arry, low, high)
while i != k:
if i < k:
low = i + 1
i = partitionimpl(arry, low, high)
else:
high = i - 1
i = partitionimpl(arry, low, high)
return arry[k]
return _select

_select = register_jitable(_select_factory(_partition))
_select_w_nan = register_jitable(_select_factory(_partition_w_nan))

@register_jitable
def _select_two(arry, k, low, high):
Expand Down Expand Up @@ -967,6 +989,82 @@ def nanmedian_impl(arry):

return nanmedian_impl

@register_jitable
def np_partition_impl_inner(a, kth_array):

# allocate and fill empty array rather than copy a and mutate in place
# as the latter approach fails to preserve strides
out = np.empty_like(a)

idx = np.ndindex(a.shape[:-1]) # Numpy default partition axis is -1
for s in idx:
arry = a[s].copy()
low = 0
high = len(arry) - 1

for kth in kth_array:
_select_w_nan(arry, kth, low, high)
low = kth # narrow span of subsequent partition

out[s] = arry
return out

@register_jitable
def valid_kths(a, kth):
"""
Returns a sorted, unique array of kth values which serve
as indexers for partitioning the input array, a.

If the absolute value of any of the provided values
is greater than a.shape[-1] an exception is raised since
we are partitioning along the last axis (per Numpy default
behaviour).

Values less than 0 are transformed to equivalent positive
index values.
"""
kth_array = _asarray(kth).astype(np.int64) # cast boolean to int, where relevant

if kth_array.ndim != 1:
raise ValueError('kth must be scalar or 1-D')
# numpy raises ValueError: object too deep for desired array

if np.any(np.abs(kth_array) >= a.shape[-1]):
raise ValueError("kth out of bounds")

out = np.empty_like(kth_array)

for index, val in np.ndenumerate(kth_array):
if val < 0:
out[index] = val + a.shape[-1] # equivalent positive index
else:
out[index] = val

return np.unique(out)

@overload(np.partition)
def np_partition(a, kth):

if not isinstance(a, (types.Array, types.Sequence, types.Tuple)):
raise TypeError('The first argument must be an array-like')

if isinstance(a, types.Array) and a.ndim == 0:
raise TypeError('The first argument must be at least 1-D (found 0-D)')

kthdt = getattr(kth, 'dtype', kth)
if not isinstance(kthdt, (types.Boolean, types.Integer)): # bool gets cast to int subsequently
raise TypeError('Partition index must be integer')

def np_partition_impl(a, kth):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think a needs to be put through _asarray. Without this asking for array attrs on a is an error. Also see my note on type legalization above.

As an aside, it's weird in NumPy that this is valid:

np.partition([], 2)
np.partition(np.array([]), 2)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, now using _asarray.

Those two edge cases are pretty weird - I added them explicitly to check the behaviour is equivalently weird.

a_tmp = _asarray(a)
if a_tmp.size == 0:
return a_tmp.copy()
else:
kth_array = valid_kths(a_tmp, kth)
return np_partition_impl_inner(a_tmp, kth_array)

return np_partition_impl

#----------------------------------------------------------------------------
# Building matrices

Expand Down