|
15 | 15 | is_torch_namespace,
|
16 | 16 | )
|
17 | 17 | from ._lib._utils._compat import device as get_device
|
18 |
| -from ._lib._utils._helpers import asarrays |
| 18 | +from ._lib._utils._helpers import asarrays, eager_shape |
19 | 19 | from ._lib._utils._typing import Array, DType
|
20 | 20 |
|
21 | 21 | __all__ = [
|
@@ -645,3 +645,194 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
|
645 | 645 | return xp.sinc(x)
|
646 | 646 |
|
647 | 647 | return _funcs.sinc(x, xp=xp)
|
| 648 | + |
| 649 | + |
def partition(
    a: Array,
    kth: int,
    /,
    axis: int | None = -1,
    *,
    xp: ModuleType | None = None,
) -> Array:
    """
    Return a partitioned copy of an array.

    Creates a copy of the array and partially sorts it in such a way that the value
    of the element in k-th position is in the position it would be in a sorted array.
    In the output array, all elements smaller than the k-th element are located to
    the left of this element and all equal or greater are located to its right.
    The ordering of the elements in the two partitions on the either side of
    the k-th element in the output array is undefined.

    Parameters
    ----------
    a : Array
        Input array.
    kth : int
        Element index to partition by.
    axis : int, optional
        Axis along which to partition. The default is ``-1`` (the last axis).
        If ``None``, the flattened array is used.
    xp : array_namespace, optional
        The standard-compatible namespace for `a`. Default: infer.

    Returns
    -------
    partitioned_array
        Array of the same type and shape as `a`.

    Notes
    -----
    If `xp` implements ``partition`` or an equivalent function
    (e.g. ``topk`` for torch), complexity will likely be O(n).
    If not, this function simply calls ``xp.sort`` and complexity is O(n log n).
    """
    # Validate inputs.
    if xp is None:
        xp = array_namespace(a)
    if a.ndim < 1:
        msg = "`a` must be at least 1-dimensional"
        raise TypeError(msg)
    if axis is None:
        # Flatten, then recurse with the single remaining axis.
        return partition(xp.reshape(a, (-1,)), kth, axis=0, xp=xp)
    (size,) = eager_shape(a, axis)
    if not (0 <= kth < size):
        msg = f"kth(={kth}) out of bounds [0 {size})"
        raise ValueError(msg)

    # Delegate where possible.
    if is_numpy_namespace(xp) or is_cupy_namespace(xp) or is_jax_namespace(xp):
        return xp.partition(a, kth, axis=axis)

    # Use top-k when possible:
    if is_torch_namespace(xp):
        # Work along the last axis; transpose back before returning.
        if not (axis == -1 or axis == a.ndim - 1):
            a = xp.transpose(a, axis, -1)

        out = xp.empty_like(a)
        # Fix: allocate `ranks` on the same device as `a`. A bare
        # ``xp.arange(...)`` lands on the default device, and comparing or
        # boolean-indexing against a tensor on another device (e.g. CUDA)
        # raises a cross-device error.
        ranks = xp.arange(a.shape[-1], device=a.device).expand_as(a)

        # split_value is the value that would land at position `kth` in a
        # fully sorted array (kthvalue is 1-based, hence kth + 1).
        split_value, indices = xp.kthvalue(a, kth + 1, keepdim=True)
        del indices  # indices won't be used => del ASAP to reduce peak memory usage

        # fill the left-side of the partition (strictly smaller elements)
        mask_src = a < split_value
        n_left = mask_src.sum(dim=-1, keepdim=True)
        mask_dest = ranks < n_left
        out[mask_dest] = a[mask_src]

        # fill the middle of the partition (elements equal to split_value);
        # XOR with the previous band mask selects exactly the new band.
        mask_src = a == split_value
        n_left += mask_src.sum(dim=-1, keepdim=True)
        mask_dest ^= ranks < n_left
        out[mask_dest] = a[mask_src]

        # fill the right-side of the partition (strictly greater elements)
        mask_src = a > split_value
        mask_dest = ranks >= n_left
        out[mask_dest] = a[mask_src]

        if not (axis == -1 or axis == a.ndim - 1):
            out = xp.transpose(out, axis, -1)
        return out

    # Note: dask topk/argtopk sort the return values, so it's
    # not much more efficient than sorting everything when
    # kth is not small compared to x.size

    return _funcs.partition(a, kth, axis=axis, xp=xp)
| 745 | + |
| 746 | + |
def argpartition(
    a: Array,
    kth: int,
    /,
    axis: int | None = -1,
    *,
    xp: ModuleType | None = None,
) -> Array:
    """
    Perform an indirect partition along the given axis.

    It returns an array of indices of the same shape as `a` that
    index data along the given axis in partitioned order.

    Parameters
    ----------
    a : Array
        Input array.
    kth : int
        Element index to partition by.
    axis : int, optional
        Axis along which to partition. The default is ``-1`` (the last axis).
        If ``None``, the flattened array is used.
    xp : array_namespace, optional
        The standard-compatible namespace for `a`. Default: infer.

    Returns
    -------
    index_array
        Array of indices that partition `a` along the specified axis.

    Notes
    -----
    If `xp` implements ``argpartition`` or an equivalent function
    (e.g. ``topk`` for torch), complexity will likely be O(n).
    If not, this function simply calls ``xp.argsort`` and complexity is O(n log n).
    """
    # Validate inputs.
    if xp is None:
        xp = array_namespace(a)
    if is_pydata_sparse_namespace(xp):
        msg = "Not implemented for sparse backend: no argsort"
        raise NotImplementedError(msg)
    if a.ndim < 1:
        msg = "`a` must be at least 1-dimensional"
        raise TypeError(msg)
    if axis is None:
        # Flatten, then recurse with the single remaining axis.
        return argpartition(xp.reshape(a, (-1,)), kth, axis=0, xp=xp)
    (size,) = eager_shape(a, axis)
    if not (0 <= kth < size):
        msg = f"kth(={kth}) out of bounds [0 {size})"
        raise ValueError(msg)

    # Delegate where possible.
    if is_numpy_namespace(xp) or is_cupy_namespace(xp) or is_jax_namespace(xp):
        return xp.argpartition(a, kth, axis=axis)

    # Use top-k when possible:
    if is_torch_namespace(xp):
        # see `partition` above for commented details of those steps:
        if not (axis == -1 or axis == a.ndim - 1):
            a = xp.transpose(a, axis, -1)

        # Fix: allocate `ranks` on the same device as `a`; a bare
        # ``xp.arange(...)`` lands on the default device and would raise a
        # cross-device error when compared/indexed against a GPU tensor.
        ranks = xp.arange(a.shape[-1], device=a.device).expand_as(a)
        out = xp.empty_like(ranks)

        split_value, indices = xp.kthvalue(a, kth + 1, keepdim=True)
        del indices  # indices won't be used => del ASAP to reduce peak memory usage

        mask_src = a < split_value
        n_left = mask_src.sum(dim=-1, keepdim=True)
        mask_dest = ranks < n_left
        out[mask_dest] = ranks[mask_src]

        mask_src = a == split_value
        n_left += mask_src.sum(dim=-1, keepdim=True)
        mask_dest ^= ranks < n_left
        out[mask_dest] = ranks[mask_src]

        mask_src = a > split_value
        mask_dest = ranks >= n_left
        out[mask_dest] = ranks[mask_src]

        if not (axis == -1 or axis == a.ndim - 1):
            out = xp.transpose(out, axis, -1)
        return out

    # Note: dask topk/argtopk sort the return values, so it's
    # not much more efficient than sorting everything when
    # kth is not small compared to x.size

    return _funcs.argpartition(a, kth, axis=axis, xp=xp)
0 commit comments