-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support for np.partition #3320
Support for np.partition #3320
Changes from 13 commits
45423a7
849a2ce
9adad60
44a3d6f
3869e06
dc8170e
cba8568
f4cf6fb
b03cc56
d68de2f
3cf00b2
888103e
8df9811
bbdd731
380f301
8e9301b
ebd6e61
a0544c5
821f060
511faa7
2e3f111
614fcb8
6293617
8ee1adb
bc83639
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -721,54 +721,76 @@ def nancumsum_impl(arr): | |
# Median and partitioning | ||
|
||
@register_jitable | ||
def _partition(A, low, high): | ||
mid = (low + high) >> 1 | ||
# NOTE: the pattern of swaps below for the pivot choice and the | ||
# partitioning gives good results (i.e. regular O(n log n)) | ||
# on sorted, reverse-sorted, and uniform arrays. Subtle changes | ||
# risk breaking this property. | ||
|
||
# Use median of three {low, middle, high} as the pivot | ||
if A[mid] < A[low]: | ||
A[low], A[mid] = A[mid], A[low] | ||
if A[high] < A[mid]: | ||
A[high], A[mid] = A[mid], A[high] | ||
if A[mid] < A[low]: | ||
A[low], A[mid] = A[mid], A[low] | ||
pivot = A[mid] | ||
def less_than(a, b): | ||
return a < b | ||
|
||
A[high], A[mid] = A[mid], A[high] | ||
i = low | ||
j = high - 1 | ||
while True: | ||
while i < high and A[i] < pivot: | ||
@register_jitable | ||
def nan_aware_less_than(a, b): | ||
if np.isnan(a): | ||
return False | ||
else: | ||
if np.isnan(b): | ||
return True | ||
else: | ||
return a < b | ||
|
||
def _partition_factory(pivotimpl): | ||
def _partition(A, low, high): | ||
mid = (low + high) >> 1 | ||
# NOTE: the pattern of swaps below for the pivot choice and the | ||
# partitioning gives good results (i.e. regular O(n log n)) | ||
# on sorted, reverse-sorted, and uniform arrays. Subtle changes | ||
# risk breaking this property. | ||
|
||
# Use median of three {low, middle, high} as the pivot | ||
if pivotimpl(A[mid], A[low]): | ||
A[low], A[mid] = A[mid], A[low] | ||
if pivotimpl(A[high], A[mid]): | ||
A[high], A[mid] = A[mid], A[high] | ||
if pivotimpl(A[mid], A[low]): | ||
A[low], A[mid] = A[mid], A[low] | ||
pivot = A[mid] | ||
|
||
A[high], A[mid] = A[mid], A[high] | ||
i = low | ||
j = high - 1 | ||
while True: | ||
while i < high and pivotimpl(A[i], pivot): | ||
i += 1 | ||
while j >= low and pivotimpl(pivot, A[j]): | ||
j -= 1 | ||
if i >= j: | ||
break | ||
A[i], A[j] = A[j], A[i] | ||
i += 1 | ||
while j >= low and pivot < A[j]: | ||
j -= 1 | ||
if i >= j: | ||
break | ||
A[i], A[j] = A[j], A[i] | ||
i += 1 | ||
j -= 1 | ||
# Put the pivot back in its final place (all items before `i` | ||
# are smaller than the pivot, all items at/after `i` are larger) | ||
A[i], A[high] = A[high], A[i] | ||
return i | ||
# Put the pivot back in its final place (all items before `i` | ||
# are smaller than the pivot, all items at/after `i` are larger) | ||
A[i], A[high] = A[high], A[i] | ||
return i | ||
return _partition | ||
|
||
@register_jitable | ||
def _select(arry, k, low, high): | ||
""" | ||
Select the k'th smallest element in array[low:high + 1]. | ||
""" | ||
i = _partition(arry, low, high) | ||
while i != k: | ||
if i < k: | ||
low = i + 1 | ||
i = _partition(arry, low, high) | ||
else: | ||
high = i - 1 | ||
i = _partition(arry, low, high) | ||
return arry[k] | ||
_partition = register_jitable(_partition_factory(less_than)) | ||
_partition_w_nan = register_jitable(_partition_factory(nan_aware_less_than)) | ||
|
||
def _select_factory(partitionimpl): | ||
def _select(arry, k, low, high): | ||
""" | ||
Select the k'th smallest element in array[low:high + 1]. | ||
""" | ||
i = partitionimpl(arry, low, high) | ||
while i != k: | ||
if i < k: | ||
low = i + 1 | ||
i = partitionimpl(arry, low, high) | ||
else: | ||
high = i - 1 | ||
i = partitionimpl(arry, low, high) | ||
return arry[k] | ||
return _select | ||
|
||
_select = register_jitable(_select_factory(_partition)) | ||
_select_w_nan = register_jitable(_select_factory(_partition_w_nan)) | ||
|
||
@register_jitable | ||
def _select_two(arry, k, low, high): | ||
|
@@ -967,6 +989,74 @@ def nanmedian_impl(arry): | |
|
||
return nanmedian_impl | ||
|
||
@register_jitable | ||
def np_partition_impl_inner(a, kth_array): | ||
|
||
# allocate and fill empty array rather than copy a and mutate in place | ||
# as the latter approach fails to preserve strides | ||
out = np.empty_like(a) | ||
|
||
idx = np.ndindex(a.shape[:-1]) # Numpy default partition axis is -1 | ||
for s in idx: | ||
dense = a[s].copy() | ||
low = 0 | ||
high = len(dense) - 1 | ||
|
||
for kth in kth_array: | ||
_select_w_nan(dense, kth, low, high) | ||
low = kth # narrow span of subsequent partition | ||
|
||
out[s] = dense | ||
return out | ||
|
||
@register_jitable | ||
def valid_kths(a, kth): | ||
""" | ||
Returns a sorted, unique array of kth values which serve | ||
as indexers for partitioning the input array, a. | ||
|
||
If the absolute value of any of the provided values | ||
is greater than a.shape[-1] an exception is raised since | ||
we are partitioning along the last axis (per Numpy default | ||
behaviour). | ||
|
||
Values less than 0 are transformed to equivalent positive | ||
index values. | ||
""" | ||
kth_array = _asarray(kth) | ||
|
||
if np.any(np.abs(kth_array) >= a.shape[-1]): | ||
raise ValueError("kth out of bounds") | ||
|
||
out = np.empty_like(kth_array) | ||
|
||
for index, val in np.ndenumerate(kth_array): | ||
if val < 0: | ||
out[index] = val + a.shape[-1] # equivalent positive index | ||
else: | ||
out[index] = val | ||
|
||
return np.unique(out) | ||
|
||
@overload(np.partition) | ||
def np_partition(a, kth): | ||
|
||
if isinstance(kth, (types.Array, types.Sequence)): | ||
if not isinstance(kth.dtype, types.Integer): | ||
raise TypeError('Partition index must be integer') | ||
else: | ||
if not isinstance(kth, types.Integer): | ||
raise TypeError('Partition index must be integer') | ||
|
||
def np_partition_impl(a, kth): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think As an aside, it's weird in NumPy that this is valid:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agreed, now using _asarray. Those two edge cases are pretty weird - I added them explicitly to check the behaviour is equivalently weird. |
||
if len(a.flat) == 0: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
return a.copy() | ||
else: | ||
kth_array = valid_kths(a, kth) | ||
return np_partition_impl_inner(a, kth_array) | ||
|
||
return np_partition_impl | ||
|
||
#---------------------------------------------------------------------------- | ||
# Building matrices | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps fold this branching into:
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I also think that
a
needs some type legalization similar to that ofkth
, what is valid as a type fora
? Seems like Numpy rejects at least scalars and 0d arrays likenp.array(1)
. I guessa
could feasibly a sequence type too?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Took out the branching and added support for Boolean (who knew)!
Put in a guard to reject 0D arrays and support array-like inputs (tuple, list).