Skip to content

Commit

Permalink
Merge pull request #3320 from rjenc29/partition
Browse files Browse the repository at this point in the history
Support for np.partition
  • Loading branch information
stuartarchibald committed Oct 17, 2018
2 parents 61b2933 + bc83639 commit 92d7b37
Show file tree
Hide file tree
Showing 3 changed files with 465 additions and 47 deletions.
1 change: 1 addition & 0 deletions docs/source/reference/numpysupported.rst
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ The following top-level functions are supported:
* :class:`numpy.nditer` (only the first argument)
* :func:`numpy.ones` (only the 2 first arguments)
* :func:`numpy.ones_like` (only the 2 first arguments)
* :func:`numpy.partition` (only the 2 first arguments)
* :func:`numpy.ravel` (no order argument; 'C' order only)
* :func:`numpy.reshape` (no order argument; 'C' order only)
* :func:`numpy.roots`
Expand Down
186 changes: 142 additions & 44 deletions numba/targets/arraymath.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,54 +721,76 @@ def nancumsum_impl(arr):
# Median and partitioning

@register_jitable
def _partition(A, low, high):
mid = (low + high) >> 1
# NOTE: the pattern of swaps below for the pivot choice and the
# partitioning gives good results (i.e. regular O(n log n))
# on sorted, reverse-sorted, and uniform arrays. Subtle changes
# risk breaking this property.

# Use median of three {low, middle, high} as the pivot
if A[mid] < A[low]:
A[low], A[mid] = A[mid], A[low]
if A[high] < A[mid]:
A[high], A[mid] = A[mid], A[high]
if A[mid] < A[low]:
A[low], A[mid] = A[mid], A[low]
pivot = A[mid]
def less_than(a, b):
return a < b

A[high], A[mid] = A[mid], A[high]
i = low
j = high - 1
while True:
while i < high and A[i] < pivot:
@register_jitable
def nan_aware_less_than(a, b):
if np.isnan(a):
return False
else:
if np.isnan(b):
return True
else:
return a < b

def _partition_factory(pivotimpl):
def _partition(A, low, high):
mid = (low + high) >> 1
# NOTE: the pattern of swaps below for the pivot choice and the
# partitioning gives good results (i.e. regular O(n log n))
# on sorted, reverse-sorted, and uniform arrays. Subtle changes
# risk breaking this property.

# Use median of three {low, middle, high} as the pivot
if pivotimpl(A[mid], A[low]):
A[low], A[mid] = A[mid], A[low]
if pivotimpl(A[high], A[mid]):
A[high], A[mid] = A[mid], A[high]
if pivotimpl(A[mid], A[low]):
A[low], A[mid] = A[mid], A[low]
pivot = A[mid]

A[high], A[mid] = A[mid], A[high]
i = low
j = high - 1
while True:
while i < high and pivotimpl(A[i], pivot):
i += 1
while j >= low and pivotimpl(pivot, A[j]):
j -= 1
if i >= j:
break
A[i], A[j] = A[j], A[i]
i += 1
while j >= low and pivot < A[j]:
j -= 1
if i >= j:
break
A[i], A[j] = A[j], A[i]
i += 1
j -= 1
# Put the pivot back in its final place (all items before `i`
# are smaller than the pivot, all items at/after `i` are larger)
A[i], A[high] = A[high], A[i]
return i
# Put the pivot back in its final place (all items before `i`
# are smaller than the pivot, all items at/after `i` are larger)
A[i], A[high] = A[high], A[i]
return i
return _partition

@register_jitable
def _select(arry, k, low, high):
"""
Select the k'th smallest element in array[low:high + 1].
"""
i = _partition(arry, low, high)
while i != k:
if i < k:
low = i + 1
i = _partition(arry, low, high)
else:
high = i - 1
i = _partition(arry, low, high)
return arry[k]
_partition = register_jitable(_partition_factory(less_than))
_partition_w_nan = register_jitable(_partition_factory(nan_aware_less_than))

def _select_factory(partitionimpl):
def _select(arry, k, low, high):
"""
Select the k'th smallest element in array[low:high + 1].
"""
i = partitionimpl(arry, low, high)
while i != k:
if i < k:
low = i + 1
i = partitionimpl(arry, low, high)
else:
high = i - 1
i = partitionimpl(arry, low, high)
return arry[k]
return _select

_select = register_jitable(_select_factory(_partition))
_select_w_nan = register_jitable(_select_factory(_partition_w_nan))

@register_jitable
def _select_two(arry, k, low, high):
Expand Down Expand Up @@ -967,6 +989,82 @@ def nanmedian_impl(arry):

return nanmedian_impl

@register_jitable
def np_partition_impl_inner(a, kth_array):

# allocate and fill empty array rather than copy a and mutate in place
# as the latter approach fails to preserve strides
out = np.empty_like(a)

idx = np.ndindex(a.shape[:-1]) # Numpy default partition axis is -1
for s in idx:
arry = a[s].copy()
low = 0
high = len(arry) - 1

for kth in kth_array:
_select_w_nan(arry, kth, low, high)
low = kth # narrow span of subsequent partition

out[s] = arry
return out

@register_jitable
def valid_kths(a, kth):
"""
Returns a sorted, unique array of kth values which serve
as indexers for partitioning the input array, a.
If the absolute value of any of the provided values
is greater than a.shape[-1] an exception is raised since
we are partitioning along the last axis (per Numpy default
behaviour).
Values less than 0 are transformed to equivalent positive
index values.
"""
kth_array = _asarray(kth).astype(np.int64) # cast boolean to int, where relevant

if kth_array.ndim != 1:
raise ValueError('kth must be scalar or 1-D')
# numpy raises ValueError: object too deep for desired array

if np.any(np.abs(kth_array) >= a.shape[-1]):
raise ValueError("kth out of bounds")

out = np.empty_like(kth_array)

for index, val in np.ndenumerate(kth_array):
if val < 0:
out[index] = val + a.shape[-1] # equivalent positive index
else:
out[index] = val

return np.unique(out)

@overload(np.partition)
def np_partition(a, kth):

if not isinstance(a, (types.Array, types.Sequence, types.Tuple)):
raise TypeError('The first argument must be an array-like')

if isinstance(a, types.Array) and a.ndim == 0:
raise TypeError('The first argument must be at least 1-D (found 0-D)')

kthdt = getattr(kth, 'dtype', kth)
if not isinstance(kthdt, (types.Boolean, types.Integer)): # bool gets cast to int subsequently
raise TypeError('Partition index must be integer')

def np_partition_impl(a, kth):
a_tmp = _asarray(a)
if a_tmp.size == 0:
return a_tmp.copy()
else:
kth_array = valid_kths(a_tmp, kth)
return np_partition_impl_inner(a_tmp, kth_array)

return np_partition_impl

#----------------------------------------------------------------------------
# Building matrices

Expand Down

0 comments on commit 92d7b37

Please sign in to comment.