Skip to content

Commit

Permalink
ENH: add stringdtype partition/rpartition
Browse files Browse the repository at this point in the history
  • Loading branch information
ngoldbaum committed Mar 22, 2024
1 parent 481a013 commit edd4f11
Show file tree
Hide file tree
Showing 6 changed files with 241 additions and 32 deletions.
20 changes: 20 additions & 0 deletions numpy/_core/src/multiarray/stringdtype/utf8_utils.c
Expand Up @@ -299,6 +299,26 @@ num_codepoints_for_utf8_bytes(const unsigned char *s, size_t *num_codepoints, si
return state != UTF8_ACCEPT;
}

NPY_NO_EXPORT npy_int64
num_bytes_until_index(char *buf, size_t buffer_size, npy_int64 index) {
size_t bytes_consumed = 0;
size_t num_codepoints = 0;

while (bytes_consumed < buffer_size && num_codepoints < (size_t) index) {
size_t num_bytes = num_bytes_for_utf8_character((const unsigned char *)buf);
num_codepoints += 1;
bytes_consumed += num_bytes;
buf += num_bytes;
}

if (num_codepoints < (size_t) index) {
// didn't hit the requested index
return -1;
}

return bytes_consumed;
}

NPY_NO_EXPORT void
find_start_end_locs(char* buf, size_t buffer_size, npy_int64 start_index, npy_int64 end_index,
char **start_loc, char **end_loc) {
Expand Down
3 changes: 3 additions & 0 deletions numpy/_core/src/multiarray/stringdtype/utf8_utils.h
Expand Up @@ -39,6 +39,9 @@ utf8_character_index(
const char* start_loc, size_t start_byte_offset, size_t start_index,
size_t search_byte_offset, size_t buffer_size);

NPY_NO_EXPORT npy_int64
num_bytes_until_index(char *buf, size_t buffer_size, npy_int64 index);

#ifdef __cplusplus
}
#endif
Expand Down
30 changes: 23 additions & 7 deletions numpy/_core/src/umath/string_buffer.h
Expand Up @@ -1600,16 +1600,31 @@ string_partition(Buffer<enc> buf1, Buffer<enc> buf2, npy_int64 idx,
npy_intp *final_len1, npy_intp *final_len2, npy_intp *final_len3,
STARTPOSITION pos)
{
size_t len1 = buf1.num_codepoints();
size_t len2 = buf2.num_codepoints();
size_t len1, len2;
npy_int64 offset;
if (enc == ENCODING::UTF8) {
len1 = buf1.after - buf1.buf;
len2 = buf2.after - buf2.buf;
if (idx > 0) {
offset = num_bytes_until_index(buf1.buf, len1, idx);
}
else {
offset = idx;
}
}
else {
len1 = buf1.num_codepoints();
len2 = buf2.num_codepoints();
offset = (size_t)idx;
}

if (len2 == 0) {
npy_gil_error(PyExc_ValueError, "empty separator");
*final_len1 = *final_len2 = *final_len3 = -1;
return;
}

if (idx < 0) {
if (offset < 0) {
if (pos == STARTPOSITION::FRONT) {
buf1.buffer_memcpy(out1, len1);
*final_len1 = len1;
Expand All @@ -1623,12 +1638,13 @@ string_partition(Buffer<enc> buf1, Buffer<enc> buf2, npy_int64 idx,
return;
}

buf1.buffer_memcpy(out1, idx);
*final_len1 = idx;
buf1.buffer_memcpy(out1, offset);
*final_len1 = offset;
buf2.buffer_memcpy(out2, len2);
*final_len2 = len2;
(buf1 + idx + len2).buffer_memcpy(out3, len1 - idx - len2);
*final_len3 = len1 - idx - len2;
buf1.advance_chars_or_bytes(offset + len2);
buf1.buffer_memcpy(out3, len1 - offset - len2);
*final_len3 = len1 - offset - len2;
}


Expand Down
185 changes: 185 additions & 0 deletions numpy/_core/src/umath/stringdtype_ufuncs.cpp
Expand Up @@ -1881,6 +1881,167 @@ zfill_strided_loop(PyArrayMethod_Context *context,
return -1;
}

static NPY_CASTING
string_partition_resolve_descriptors(
PyArrayMethodObject *self,
PyArray_DTypeMeta *NPY_UNUSED(dtypes[3]),
PyArray_Descr *given_descrs[3],
PyArray_Descr *loop_descrs[3],
npy_intp *NPY_UNUSED(view_offset))
{
if (given_descrs[3] || given_descrs[4] || given_descrs[5]) {
PyErr_Format(PyExc_TypeError, "The StringDType '%s' ufunc does not "
"support the 'out' keyword", self->name);
return (NPY_CASTING)-1;
}
for (int i=0; i<3; i++) {
Py_INCREF(given_descrs[i]);
loop_descrs[i] = given_descrs[i];
}
PyArray_StringDTypeObject *adescr = (PyArray_StringDTypeObject *)given_descrs[0];
for (int i=3; i<6; i++) {
loop_descrs[i] = (PyArray_Descr *)new_stringdtype_instance(
adescr->na_object, adescr->coerce);
if (loop_descrs[i] == NULL) {
return (NPY_CASTING)-1;
}
}

return NPY_NO_CASTING;
}

NPY_NO_EXPORT int
string_partition_strided_loop(
PyArrayMethod_Context *context,
char *const data[],
npy_intp const dimensions[],
npy_intp const strides[],
NpyAuxData *NPY_UNUSED(auxdata))
{
STARTPOSITION startposition = *(STARTPOSITION *)(context->method->static_data);

npy_intp N = dimensions[0];

char *in1 = data[0];
char *in2 = data[1];
char *in3 = data[2];
char *out1 = data[3];
char *out2 = data[4];
char *out3 = data[5];

npy_intp in1_stride = strides[0];
npy_intp in2_stride = strides[1];
npy_intp in3_stride = strides[2];
npy_intp out1_stride = strides[3];
npy_intp out2_stride = strides[4];
npy_intp out3_stride = strides[5];

npy_string_allocator *allocators[6] = {};
NpyString_acquire_allocators(6, context->descriptors, allocators);
npy_string_allocator *in1allocator = allocators[0];
npy_string_allocator *in2allocator = allocators[1];
// allocators[2] is NULL
npy_string_allocator *out1allocator = allocators[3];
npy_string_allocator *out2allocator = allocators[4];
npy_string_allocator *out3allocator = allocators[5];

PyArray_StringDTypeObject *idescr =
(PyArray_StringDTypeObject *)context->descriptors[0];
int has_string_na = idescr->has_string_na;
const npy_static_string *default_string = &idescr->default_string;

while (N--) {
const npy_packed_static_string *i1ps = (npy_packed_static_string *)in1;
npy_static_string i1s = {0, NULL};
const npy_packed_static_string *i2ps = (npy_packed_static_string *)in2;
npy_static_string i2s = {0, NULL};
npy_packed_static_string *o1ps = (npy_packed_static_string *)out1;
npy_packed_static_string *o2ps = (npy_packed_static_string *)out2;
npy_packed_static_string *o3ps = (npy_packed_static_string *)out3;

int i1_isnull = NpyString_load(in1allocator, i1ps, &i1s);
int i2_isnull = NpyString_load(in2allocator, i2ps, &i2s);

if (i1_isnull == -1 || i2_isnull == -1) {
npy_gil_error(PyExc_MemoryError, "Failed to load string in %s",
((PyUFuncObject *)context->caller)->name);
goto fail;
}
else if (NPY_UNLIKELY(i1_isnull || i2_isnull)) {
if (!has_string_na) {
npy_gil_error(PyExc_ValueError,
"Null values are not supported in %s",
((PyUFuncObject *)context->caller)->name);
goto fail;
}
else {
if (i1_isnull) {
i1s = *default_string;
}
if (i2_isnull) {
i2s = *default_string;
}
}
}

Buffer<ENCODING::UTF8> in1buf((char *)i1s.buf, i1s.size);
Buffer<ENCODING::UTF8> in2buf((char *)i2s.buf, i2s.size);

// out1 and out3 can be no longer than in1, so conservatively
// overallocate buffers big enough to store in1
// out2 must be no bigger than in2
char *out1_mem = (char *)PyMem_RawCalloc(i1s.size, 1);
char *out2_mem = (char *)PyMem_RawCalloc(i2s.size, 1);
char *out3_mem = (char *)PyMem_RawCalloc(i1s.size, 1);

Buffer<ENCODING::UTF8> out1buf(out1_mem, i1s.size);
Buffer<ENCODING::UTF8> out2buf(out2_mem, i2s.size);
Buffer<ENCODING::UTF8> out3buf(out3_mem, i1s.size);

npy_intp final_len1, final_len2, final_len3;

string_partition(in1buf, in2buf, *(npy_int64 *)in3, out1buf, out2buf, out3buf,
&final_len1, &final_len2, &final_len3, startposition);

if (final_len1 < 0 || final_len2 < 0 || final_len3 < 0) {
goto fail;
}

if (NpyString_pack(out1allocator, o1ps, out1buf.buf, final_len1) < 0) {
npy_gil_error(PyExc_MemoryError,
"Failed to pack string in %s",
((PyUFuncObject *)context->caller)->name);
}

if (NpyString_pack(out2allocator, o2ps, out2buf.buf, final_len2) < 0) {
npy_gil_error(PyExc_MemoryError,
"Failed to pack string in %s",
((PyUFuncObject *)context->caller)->name);
}

if (NpyString_pack(out3allocator, o3ps, out3buf.buf, final_len3) < 0) {
npy_gil_error(PyExc_MemoryError,
"Failed to pack string in %s",
((PyUFuncObject *)context->caller)->name);
}

in1 += in1_stride;
in2 += in2_stride;
in3 += in3_stride;
out1 += out1_stride;
out2 += out2_stride;
out3 += out3_stride;
}

NpyString_release_allocators(6, allocators);
return 0;

fail:

NpyString_release_allocators(6, allocators);
return -1;
}

NPY_NO_EXPORT int
string_inputs_promoter(
PyObject *ufunc_obj, PyArray_DTypeMeta *op_dtypes[],
Expand Down Expand Up @@ -2625,5 +2786,29 @@ init_stringdtype_ufuncs(PyObject *umath)
return -1;
}

PyArray_DTypeMeta *partition_dtypes[] = {
&PyArray_StringDType,
&PyArray_StringDType,
&PyArray_Int64DType,
&PyArray_StringDType,
&PyArray_StringDType,
&PyArray_StringDType
};

const char *partition_names[] = {"_partition", "_rpartition"};

static STARTPOSITION partition_startpositions[] = {
STARTPOSITION::FRONT, STARTPOSITION::BACK
};

for (int i=0; i<2; i++) {
if (init_ufunc(umath, partition_names[i], partition_dtypes,
string_partition_resolve_descriptors,
string_partition_strided_loop, 3, 3, NPY_NO_CASTING,
(NPY_ARRAYMETHOD_FLAGS) 0, &partition_startpositions[i]) < 0) {
return -1;
}
}

return 0;
}
16 changes: 10 additions & 6 deletions numpy/_core/strings.py
Expand Up @@ -1359,10 +1359,11 @@ def partition(a, sep):
"""
a = np.asanyarray(a)
sep = np.asanyarray(sep)
sep = np.asanyarray(sep).astype(a.dtype)
pos = _find_ufunc(a, sep, 0, MAX).astype('int64')

shape = np.broadcast_shapes(a.shape, sep.shape)
pos = _find_ufunc(a, sep, 0, MAX)
if a.dtype.char == "T":
return _partition(a, sep, pos)

a_len = str_len(a)
sep_len = str_len(sep)
Expand All @@ -1376,6 +1377,7 @@ def partition(a, sep):
1 if np.all(not_found) else sep_len.max(),
buffersizes3.max(),
)])
shape = np.broadcast_shapes(a.shape, sep.shape)
out = np.empty_like(a, shape=shape, dtype=out_dtype)
return _partition(a, sep, pos, out=(out["f0"], out["f1"], out["f2"]))

Expand Down Expand Up @@ -1422,10 +1424,11 @@ def rpartition(a, sep):
"""
a = np.asanyarray(a)
sep = np.asanyarray(sep)
sep = np.asanyarray(sep).astype(a.dtype)
pos = _rfind_ufunc(a, sep, 0, MAX).astype('int64')

shape = np.broadcast_shapes(a.shape, sep.shape)
pos = _rfind_ufunc(a, sep, 0, MAX)
if a.dtype.char == "T":
return _rpartition(a, sep, pos)

a_len = str_len(a)
sep_len = str_len(sep)
Expand All @@ -1439,6 +1442,7 @@ def rpartition(a, sep):
1 if np.all(not_found) else sep_len.max(),
buffersizes3.max(),
)])
shape = np.broadcast_shapes(a.shape, sep.shape)
out = np.empty_like(a, shape=shape, dtype=out_dtype)
return _rpartition(a, sep, pos, out=(out["f0"], out["f1"], out["f2"]))

Expand Down
19 changes: 0 additions & 19 deletions numpy/_core/tests/test_strings.py
Expand Up @@ -795,16 +795,6 @@ def test_zfill(self, buf, width, res, dt):
res = np.array(res, dtype=dt)
assert_array_equal(np.strings.zfill(buf, width), res)


@pytest.mark.parametrize("dt", [
"U",
"S",
pytest.param("T", marks=pytest.mark.xfail(
reason="StringDType support not implemented",
strict=True),
),
])
class TestMethodsWithoutStringDTypeSupport:
@pytest.mark.parametrize("buf,sep,res1,res2,res3", [
("this is the partition method", "ti", "this is the par",
"ti", "tion method"),
Expand Down Expand Up @@ -1032,15 +1022,6 @@ def test_rjust(self, buf, width, fillchar, res, dt):
res = np.array(res, dtype=dt)
assert_array_equal(np.strings.rjust(buf, width, fillchar), res)


@pytest.mark.parametrize("dt", [
"U",
pytest.param("T", marks=pytest.mark.xfail(
reason="StringDType support not implemented",
strict=True),
),
])
class TestMethodsWithUnicodeWithoutStringDTypeSupport:
@pytest.mark.parametrize("buf,sep,res1,res2,res3", [
("āāāāĀĀĀĀ", "Ă", "āāāāĀĀĀĀ", "", ""),
("āāāāĂĀĀĀĀ", "Ă", "āāāā", "Ă", "ĀĀĀĀ"),
Expand Down

0 comments on commit edd4f11

Please sign in to comment.