Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions stringdtype/stringdtype/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
"""

from .scalar import StringScalar # isort: skip
from ._main import StringDType
from ._main import StringDType, _memory_usage

__all__ = ["StringDType", "StringScalar"]
__all__ = [
"StringDType",
"StringScalar",
"_memory_usage",
]
3 changes: 3 additions & 0 deletions stringdtype/stringdtype/src/dtype.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,7 @@ new_stringdtype_instance(void);
int
init_string_dtype(void);

// from dtypemeta.h, not public in numpy
#define NPY_DTYPE(descr) ((PyArray_DTypeMeta *)Py_TYPE(descr))

#endif /*_NPY_DTYPE_H*/
69 changes: 69 additions & 0 deletions stringdtype/stringdtype/src/main.c
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,81 @@
#include "numpy/experimental_dtype_api.h"

#include "dtype.h"
#include "static_string.h"
#include "umath.h"

static PyObject *
_memory_usage(PyObject *NPY_UNUSED(self), PyObject *obj)
{
if (!PyArray_Check(obj)) {
PyErr_SetString(PyExc_TypeError,
"can only be called with ndarray object");
return NULL;
}

PyArrayObject *arr = (PyArrayObject *)obj;

PyArray_Descr *descr = PyArray_DESCR(arr);
PyArray_DTypeMeta *dtype = NPY_DTYPE(descr);

if (dtype != &StringDType) {
PyErr_SetString(PyExc_TypeError,
"can only be called with a StringDType array");
return NULL;
}

NpyIter *iter =
NpyIter_New(arr, NPY_ITER_READONLY | NPY_ITER_EXTERNAL_LOOP,
NPY_KEEPORDER, NPY_NO_CASTING, NULL);

if (iter == NULL) {
return NULL;
}

NpyIter_IterNextFunc *iternext = NpyIter_GetIterNext(iter, NULL);

if (iternext == NULL) {
NpyIter_Deallocate(iter);
return NULL;
}

char **dataptr = NpyIter_GetDataPtrArray(iter);
npy_intp *strideptr = NpyIter_GetInnerStrideArray(iter);
npy_intp *innersizeptr = NpyIter_GetInnerLoopSizePtr(iter);

// initialize with the size of the internal buffer
size_t memory_usage = PyArray_NBYTES(arr);
size_t struct_size = sizeof(ss);

do {
ss **in = (ss **)*dataptr;
npy_intp stride = *strideptr / descr->elsize;
npy_intp count = *innersizeptr;

while (count--) {
// +1 byte for the null terminator
memory_usage += (*in)->len + struct_size + 1;
in += stride;
}

} while (iternext(iter));

PyObject *ret = PyLong_FromSize_t(memory_usage);

return ret;
}

static PyMethodDef string_methods[] = {
{"_memory_usage", _memory_usage, METH_O,
"get memory usage for an array"},
{NULL},
};

static struct PyModuleDef moduledef = {
PyModuleDef_HEAD_INIT,
.m_name = "stringdtype_main",
.m_size = -1,
.m_methods = string_methods,
};

/* Module initialization function */
Expand Down
15 changes: 14 additions & 1 deletion stringdtype/tests/test_stringdtype.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import pytest

from stringdtype import StringDType, StringScalar
from stringdtype import StringDType, StringScalar, _memory_usage


@pytest.fixture
Expand Down Expand Up @@ -111,3 +111,16 @@ def test_isnan(string_list):
np.testing.assert_array_equal(
np.isnan(sarr), np.zeros_like(sarr, dtype=np.bool_)
)


def test_memory_usage(string_list):
sarr = np.array(string_list, dtype=StringDType())
# 4 bytes for each ASCII string buffer in string_list
# (three characters and null terminator)
# plus enough bytes for the size_t length
# plus enough bytes for the pointer in the array buffer
assert _memory_usage(sarr) == (4 + 2 * np.dtype(np.uintp).itemsize) * 3
with pytest.raises(TypeError):
_memory_usage("hello")
with pytest.raises(TypeError):
_memory_usage(np.array([1, 2, 3]))