Skip to content

Commit

Permalink
Merge pull request #26012 from ngoldbaum/stringdtype-add-promoter
Browse files Browse the repository at this point in the history
ENH: install StringDType promoter for add
  • Loading branch information
mhvk committed Mar 14, 2024
2 parents a519781 + 5ec8a3d commit dda030f
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 8 deletions.
41 changes: 33 additions & 8 deletions numpy/_core/src/umath/stringdtype_ufuncs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1043,9 +1043,9 @@ string_startswith_endswith_strided_loop(PyArrayMethod_Context *context,
}

static int
strip_chars_promoter(PyObject *NPY_UNUSED(ufunc),
PyArray_DTypeMeta *op_dtypes[], PyArray_DTypeMeta *signature[],
PyArray_DTypeMeta *new_op_dtypes[])
all_strings_promoter(PyObject *NPY_UNUSED(ufunc),
PyArray_DTypeMeta *op_dtypes[], PyArray_DTypeMeta *signature[],
PyArray_DTypeMeta *new_op_dtypes[])
{
new_op_dtypes[0] = NPY_DT_NewRef(&PyArray_StringDType);
new_op_dtypes[1] = NPY_DT_NewRef(&PyArray_StringDType);
Expand Down Expand Up @@ -2312,6 +2312,28 @@ init_stringdtype_ufuncs(PyObject *umath)
return -1;
}

PyArray_DTypeMeta *rall_strings_promoter_dtypes[] = {
&PyArray_StringDType,
&PyArray_UnicodeDType,
&PyArray_StringDType,
};

if (add_promoter(umath, "add", rall_strings_promoter_dtypes, 3,
all_strings_promoter) < 0) {
return -1;
}

PyArray_DTypeMeta *lall_strings_promoter_dtypes[] = {
&PyArray_UnicodeDType,
&PyArray_StringDType,
&PyArray_StringDType,
};

if (add_promoter(umath, "add", lall_strings_promoter_dtypes, 3,
all_strings_promoter) < 0) {
return -1;
}

INIT_MULTIPLY(Int64, int64);
INIT_MULTIPLY(UInt64, uint64);

Expand Down Expand Up @@ -2446,10 +2468,6 @@ init_stringdtype_ufuncs(PyObject *umath)
"_lstrip_chars", "_rstrip_chars", "_strip_chars",
};

PyArray_DTypeMeta *strip_chars_promoter_dtypes[] = {
&PyArray_StringDType, &PyArray_UnicodeDType, &PyArray_StringDType
};

for (int i=0; i<3; i++) {
if (init_ufunc(umath, strip_chars_names[i], strip_chars_dtypes,
&strip_chars_resolve_descriptors,
Expand All @@ -2460,7 +2478,14 @@ init_stringdtype_ufuncs(PyObject *umath)
}

if (add_promoter(umath, strip_chars_names[i],
strip_chars_promoter_dtypes, 3, strip_chars_promoter) < 0) {
rall_strings_promoter_dtypes, 3,
all_strings_promoter) < 0) {
return -1;
}

if (add_promoter(umath, strip_chars_names[i],
lall_strings_promoter_dtypes, 3,
all_strings_promoter) < 0) {
return -1;
}
}
Expand Down
10 changes: 10 additions & 0 deletions numpy/_core/tests/test_stringdtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,6 +714,16 @@ def test_ufunc_add(dtype, string_list, other_strings, use_out):
np.add(arr1, arr2)


def test_add_promoter(string_list):
arr = np.array(string_list, dtype=StringDType())
lresult = np.array(["hello" + s for s in string_list], dtype=StringDType())
rresult = np.array([s + "hello" for s in string_list], dtype=StringDType())

for op in ["hello", np.str_("hello"), np.array(["hello"])]:
assert_array_equal(op + arr, lresult)
assert_array_equal(arr + op, rresult)


@pytest.mark.parametrize("use_out", [True, False])
@pytest.mark.parametrize("other", [2, [2, 1, 3, 4, 1, 3]])
@pytest.mark.parametrize(
Expand Down

0 comments on commit dda030f

Please sign in to comment.