Skip to content

Commit

Permalink
ENH: Add partition/rpartition ufunc for string dtypes
Browse files Browse the repository at this point in the history
Closes #25993.
  • Loading branch information
lysnikolaou committed Mar 19, 2024
1 parent 6059db1 commit 5993849
Show file tree
Hide file tree
Showing 8 changed files with 552 additions and 48 deletions.
10 changes: 10 additions & 0 deletions numpy/_core/code_generators/generate_umath.py
Expand Up @@ -1300,6 +1300,16 @@ def english_upper(s):
docstrings.get('numpy._core.umath._zfill'),
None,
),
'_partition':
Ufunc(2, 3, None,
docstrings.get('numpy._core.umath._partition'),
None,
),
'_rpartition':
Ufunc(2, 3, None,
docstrings.get('numpy._core.umath._rpartition'),
None,
),
}

def indent(st, spaces):
Expand Down
82 changes: 82 additions & 0 deletions numpy/_core/code_generators/ufunc_docstrings.py
Expand Up @@ -5028,3 +5028,85 @@ def add_newdoc(place, name, doc):
array(['001', '-01', '+01'], dtype='<U3')
""")

add_newdoc('numpy._core.umath', '_partition',
"""
Partition each element in ``x1`` around ``x2``.
For each element in ``x1``, split the element at the first
occurrence of ``x2``, and return a 3-tuple containing the part
before the separator, a boolean signifying whether the separator
was found, and the part after the separator. If the separator is
not found, the first part will contain the whole string,
the boolean will be false, and the third part will be the empty
string.
Parameters
----------
x1 : array-like, with ``StringDType``, ``bytes_``, or ``str_`` dtype
Input array
x2 : array-like, with ``StringDType``, ``bytes_``, or ``str_`` dtype
Separator to split each string element in ``x1``.
Returns
-------
out : 3-tuple:
- ``StringDType``, ``bytes_`` or ``str_`` dtype string with the part
before the separator
- ``bool_`` dtype, whether the separator was found
- ``StringDType``, ``bytes_`` or ``str_`` dtype string with the part
after the separator
See Also
--------
str.partition
Examples
--------
>>> x = np.array(["Numpy is nice!"])
>>> np.strings.partition(x, " ")
array([['Numpy', ' ', 'is nice!']], dtype='<U8')
""")

add_newdoc('numpy._core.umath', '_rpartition',
"""
Partition (split) each element around the right-most separator.
For each element in ``x1``, split the element at the first
occurrence of ``x2``, and return a 3-tuple containing the part
before the separator, a boolean signifying whether the separator
was found, and the part after the separator. If the separator is
not found, the first part will contain the whole string,
the boolean will be false, and the third part will be the empty
string.
Parameters
----------
x1 : array-like, with ``StringDType``, ``bytes_``, or ``str_`` dtype
Input array
x2 : array-like, with ``StringDType``, ``bytes_``, or ``str_`` dtype
Separator to split each string element in ``x1``.
Returns
-------
out : 3-tuple:
- ``StringDType``, ``bytes_`` or ``str_`` dtype string with the part
before the separator
- ``bool_`` dtype, whether the separator was found
- ``StringDType``, ``bytes_`` or ``str_`` dtype string with the part
after the separator
See Also
--------
str.rpartition
Examples
--------
>>> a = np.array(['aAaAaA', ' aA ', 'abBABba'])
>>> np.strings.rpartition(a, 'A')
array([['aAaAa', 'A', ''],
[' a', 'A', ' '],
['abB', 'A', 'Bba']], dtype='<U5')
""")
90 changes: 87 additions & 3 deletions numpy/_core/defchararray.py
Expand Up @@ -18,19 +18,19 @@
import functools

from .._utils import set_module
from .numerictypes import bytes_, str_, character
from .numerictypes import bytes_, str_, character, object_
from .numeric import ndarray, array as narray, asarray as asnarray
from numpy._core.multiarray import compare_chararrays
from numpy._core import overrides
from numpy.strings import *
from numpy.strings import multiply as strings_multiply
from numpy._core.strings import (
_partition as partition,
_rpartition as rpartition,
_split as split,
_rsplit as rsplit,
_splitlines as splitlines,
_join as join,
_to_bytes_or_str_array,
_vec_string,
)

__all__ = [
Expand Down Expand Up @@ -303,6 +303,90 @@ def multiply(a, i):
raise ValueError("Can only multiply by integers")


def partition(a, sep):
"""
Partition each element in `a` around `sep`.
Calls :meth:`str.partition` element-wise.
For each element in `a`, split the element as the first
occurrence of `sep`, and return 3 strings containing the part
before the separator, the separator itself, and the part after
the separator. If the separator is not found, return 3 strings
containing the string itself, followed by two empty strings.
Parameters
----------
a : array-like, with ``StringDType``, ``bytes_``, or ``str_`` dtype
Input array
sep : {str, unicode}
Separator to split each string element in `a`.
Returns
-------
out : ndarray
Output array of ``StringDType``, ``bytes_`` or ``str_`` dtype,
depending on input types. The output array will have an extra
dimension with 3 elements per input element.
Examples
--------
>>> x = np.array(["Numpy is nice!"])
>>> np.strings.partition(x, " ") # doctest: +SKIP
array([['Numpy', ' ', 'is nice!']], dtype='<U8') # doctest: +SKIP
See Also
--------
str.partition
"""
return _to_bytes_or_str_array(
_vec_string(a, object_, 'partition', (sep,)), a)


def rpartition(a, sep):
"""
Partition (split) each element around the right-most separator.
Calls :meth:`str.rpartition` element-wise.
For each element in `a`, split the element as the last
occurrence of `sep`, and return 3 strings containing the part
before the separator, the separator itself, and the part after
the separator. If the separator is not found, return 3 strings
containing the string itself, followed by two empty strings.
Parameters
----------
a : array-like, with ``StringDType``, ``bytes_``, or ``str_`` dtype
Input array
sep : str or unicode
Right-most separator to split each element in array.
Returns
-------
out : ndarray
Output array of ``StringDType``, ``bytes_`` or ``str_`` dtype,
depending on input types. The output array will have an extra
dimension with 3 elements per input element.
See Also
--------
str.rpartition
Examples
--------
>>> a = np.array(['aAaAaA', ' aA ', 'abBABba'])
>>> np.strings.rpartition(a, 'A') # doctest: +SKIP
array([['aAaAa', 'A', ''], # doctest: +SKIP
[' a', 'A', ' '], # doctest: +SKIP
['abB', 'A', 'Bba']], dtype='<U5') # doctest: +SKIP
"""
return _to_bytes_or_str_array(
_vec_string(a, object_, 'rpartition', (sep,)), a)


@set_module("numpy.char")
class chararray(ndarray):
"""
Expand Down
60 changes: 60 additions & 0 deletions numpy/_core/src/umath/string_buffer.h
Expand Up @@ -1593,4 +1593,64 @@ string_zfill(Buffer<enc> buf, npy_int64 width, Buffer<enc> out)
}


template <ENCODING enc>
static inline npy_bool
string_partition(Buffer<enc> buf1, Buffer<enc> buf2,
Buffer<enc> out1, Buffer<enc> out2,
npy_intp *final_len1, npy_intp *final_len2,
STARTPOSITION pos)
{
size_t len1 = buf1.num_codepoints();
size_t len2 = buf2.num_codepoints();

if (len2 == 0) {
npy_gil_error(PyExc_ValueError, "empty separator");
*final_len1 = *final_len2 = -1;
return false;
}

if (len1 < len2) {
buf1.buffer_memcpy(out1, len1);
*final_len1 = len1;
*final_len2 = 0;
return false;
}

npy_intp idx;
switch(enc) {
case ENCODING::UTF8:
assert(0); // TODO
break;
case ENCODING::ASCII:
idx = fastsearch(buf1.buf, len1, buf2.buf, len2, -1,
pos == STARTPOSITION::FRONT ? FAST_SEARCH : FAST_RSEARCH);
break;
case ENCODING::UTF32:
idx = fastsearch((npy_ucs4 *)buf1.buf, len1, (npy_ucs4 *)buf2.buf, len2, -1,
pos == STARTPOSITION::FRONT ? FAST_SEARCH : FAST_RSEARCH);
break;
}

if (idx < 0) {
if (pos == STARTPOSITION::FRONT) {
buf1.buffer_memcpy(out1, len1);
*final_len1 = len1;
*final_len2 = 0;
}
else {
buf1.buffer_memcpy(out2, len1);
*final_len1 = 0;
*final_len2 = len1;
}
return false;
}

buf1.buffer_memcpy(out1, idx);
*final_len1 = idx;
(buf1 + idx + len2).buffer_memcpy(out2, len1 - idx - len2);
*final_len2 = len1 - idx - len2;
return true;
}


#endif /* _NPY_CORE_SRC_UMATH_STRING_BUFFER_H_ */

0 comments on commit 5993849

Please sign in to comment.