Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,11 @@ repos:
- id: check-added-large-files
- id: check-ast
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.217
rev: v0.0.254
hooks:
- id: ruff
# Respect `exclude` and `extend-exclude` settings.
args: ["--force-exclude"]
- repo: https://github.com/pre-commit/mirrors-prettier
rev: v3.0.0-alpha.4
rev: v3.0.0-alpha.6
hooks:
- id: prettier
types:
Expand All @@ -88,7 +86,7 @@ repos:
yaml,
]
- repo: https://github.com/pycqa/isort
rev: 5.11.4
rev: 5.12.0
hooks:
- id: isort
name: isort (python)
Expand All @@ -99,7 +97,7 @@ repos:
name: isort (pyi)
types: [pyi]
- repo: https://github.com/psf/black
rev: 22.12.0
rev: 23.1.0
hooks:
- id: black
name: 'black for asciidtype'
Expand Down
26 changes: 22 additions & 4 deletions stringdtype/stringdtype/src/dtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -179,13 +179,30 @@ nonzero(void *data, void *NPY_UNUSED(arr))
// Implementation of PyArray_CompareFunc.
// Compares unicode strings by their code points.
int
compare_strings(char **a, char **b, PyArrayObject *NPY_UNUSED(arr))
compare(void *a, void *b, void *NPY_UNUSED(arr))
{
ss *ss_a = (ss *)a;
ss *ss_b = (ss *)b;
ss *ss_a = NULL;
ss *ss_b = NULL;
load_string(a, &ss_a);
load_string(b, &ss_b);
return strcmp(ss_a->buf, ss_b->buf);
}

// Implementation of PyArray_ArgFunc.
// Views data as an array of n `ss` elements and stores in *max_ind the index
// of the element that compare() ranks highest. Only a strictly greater
// element replaces the current best, so on ties the earliest index wins.
// When n <= 1 the result is index 0. Always returns 0.
int
argmax(void *data, npy_intp n, npy_intp *max_ind, void *arr)
{
    ss *dptr = (ss *)data;
    npy_intp best = 0;
    // Index with npy_intp, the same type as n: an int counter would overflow
    // (undefined behaviour) before reaching n whenever n exceeds INT_MAX.
    for (npy_intp i = 1; i < n; i++) {
        if (compare(&dptr[i], &dptr[best], arr) > 0) {
            best = i;
        }
    }
    *max_ind = best;
    return 0;
}

static StringDTypeObject *
stringdtype_ensure_canonical(StringDTypeObject *self)
{
Expand Down Expand Up @@ -232,8 +249,9 @@ static PyType_Slot StringDType_Slots[] = {
{NPY_DT_setitem, &stringdtype_setitem},
{NPY_DT_getitem, &stringdtype_getitem},
{NPY_DT_ensure_canonical, &stringdtype_ensure_canonical},
{NPY_DT_PyArray_ArrFuncs_compare, &compare_strings},
{NPY_DT_PyArray_ArrFuncs_nonzero, &nonzero},
{NPY_DT_PyArray_ArrFuncs_compare, &compare},
{NPY_DT_PyArray_ArrFuncs_argmax, &argmax},
{NPY_DT_get_clear_loop, &stringdtype_get_clear_loop},
{0, NULL}};

Expand Down
30 changes: 30 additions & 0 deletions stringdtype/tests/test_stringdtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,33 @@ def test_creation_functions():

def test_is_numeric():
assert not StringDType._is_numeric


@pytest.mark.parametrize(
    "strings",
    [
        ["left", "right", "leftovers", "righty", "up", "down"],
        ["🤣🤣", "🤣", "📵", "😰"],
        ["🚜", "🙃", "😾"],
        ["😹", "🚠", "🚌"],
        ["A¢☃€ 😊", " A☃€¢😊", "☃€😊 A¢", "😊☃A¢ €"],
    ],
)
def test_argmax(strings):
    """np.argmax on a StringDType array agrees with Python's own max()."""
    # Position of the largest string according to Python's str ordering.
    expected_index = strings.index(max(strings))
    string_array = np.array(strings, dtype=StringDType())
    assert np.argmax(string_array) == expected_index


@pytest.mark.parametrize(
    "arrfunc,expected",
    [
        [np.sort, np.empty(10, dtype=StringDType())],
        [np.nonzero, (np.array([], dtype=np.int64),)],
        [np.argmax, 0],
    ],
)
def test_arrfuncs_empty(arrfunc, expected):
    """Apply each array function to an np.empty StringDType array of
    length 10 and compare the outcome strictly against the expected value."""
    empty_arr = np.empty(10, dtype=StringDType())
    np.testing.assert_array_equal(arrfunc(empty_arr), expected, strict=True)