diff --git a/stringdtype/stringdtype/__init__.py b/stringdtype/stringdtype/__init__.py index 9bf20952..7e727903 100644 --- a/stringdtype/stringdtype/__init__.py +++ b/stringdtype/stringdtype/__init__.py @@ -5,6 +5,10 @@ """ from .scalar import StringScalar # isort: skip -from ._main import StringDType +from ._main import StringDType, _memory_usage -__all__ = ["StringDType", "StringScalar"] +__all__ = [ + "StringDType", + "StringScalar", + "_memory_usage", +] diff --git a/stringdtype/stringdtype/src/dtype.h b/stringdtype/stringdtype/src/dtype.h index 7f7496ff..c10cf67b 100644 --- a/stringdtype/stringdtype/src/dtype.h +++ b/stringdtype/stringdtype/src/dtype.h @@ -26,4 +26,7 @@ new_stringdtype_instance(void); int init_string_dtype(void); +// from dtypemeta.h, not public in numpy +#define NPY_DTYPE(descr) ((PyArray_DTypeMeta *)Py_TYPE(descr)) + #endif /*_NPY_DTYPE_H*/ diff --git a/stringdtype/stringdtype/src/main.c b/stringdtype/stringdtype/src/main.c index d2591dc5..7cc79de6 100644 --- a/stringdtype/stringdtype/src/main.c +++ b/stringdtype/stringdtype/src/main.c @@ -6,12 +6,81 @@ #include "numpy/experimental_dtype_api.h" #include "dtype.h" +#include "static_string.h" #include "umath.h" +static PyObject * +_memory_usage(PyObject *NPY_UNUSED(self), PyObject *obj) +{ + if (!PyArray_Check(obj)) { + PyErr_SetString(PyExc_TypeError, + "can only be called with ndarray object"); + return NULL; + } + + PyArrayObject *arr = (PyArrayObject *)obj; + + PyArray_Descr *descr = PyArray_DESCR(arr); + PyArray_DTypeMeta *dtype = NPY_DTYPE(descr); + + if (dtype != &StringDType) { + PyErr_SetString(PyExc_TypeError, + "can only be called with a StringDType array"); + return NULL; + } + + NpyIter *iter = + NpyIter_New(arr, NPY_ITER_READONLY | NPY_ITER_EXTERNAL_LOOP, + NPY_KEEPORDER, NPY_NO_CASTING, NULL); + + if (iter == NULL) { + return NULL; + } + + NpyIter_IterNextFunc *iternext = NpyIter_GetIterNext(iter, NULL); + + if (iternext == NULL) { + NpyIter_Deallocate(iter); + return NULL; + } + + char **dataptr = NpyIter_GetDataPtrArray(iter); + npy_intp *strideptr = NpyIter_GetInnerStrideArray(iter); + npy_intp *innersizeptr = NpyIter_GetInnerLoopSizePtr(iter); + + // initialize with the size of the internal buffer + size_t memory_usage = PyArray_NBYTES(arr); + size_t struct_size = sizeof(ss); + + do { + ss **in = (ss **)*dataptr; + npy_intp stride = *strideptr / descr->elsize; + npy_intp count = *innersizeptr; + + while (count--) { + // +1 byte for the null terminator + memory_usage += (*in)->len + struct_size + 1; + in += stride; + } + + } while (iternext(iter)); + + PyObject *ret = PyLong_FromSize_t(memory_usage); + + return ret; +} + +static PyMethodDef string_methods[] = { + {"_memory_usage", _memory_usage, METH_O, + "get memory usage for an array"}, + {NULL}, +}; + static struct PyModuleDef moduledef = { PyModuleDef_HEAD_INIT, .m_name = "stringdtype_main", .m_size = -1, + .m_methods = string_methods, }; /* Module initialization function */ diff --git a/stringdtype/tests/test_stringdtype.py b/stringdtype/tests/test_stringdtype.py index df3d468c..7677571e 100644 --- a/stringdtype/tests/test_stringdtype.py +++ b/stringdtype/tests/test_stringdtype.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from stringdtype import StringDType, StringScalar +from stringdtype import StringDType, StringScalar, _memory_usage @pytest.fixture @@ -111,3 +111,16 @@ def test_isnan(string_list): np.testing.assert_array_equal( np.isnan(sarr), np.zeros_like(sarr, dtype=np.bool_) ) + + +def test_memory_usage(string_list): + sarr = np.array(string_list, dtype=StringDType()) + # 4 bytes for each ASCII string buffer in string_list + # (three characters and null terminator) + # plus enough bytes for the size_t length + # plus enough bytes for the pointer in the array buffer + assert _memory_usage(sarr) == (4 + 2 * np.dtype(np.uintp).itemsize) * 3 + with pytest.raises(TypeError): + _memory_usage("hello") + with pytest.raises(TypeError): + _memory_usage(np.array([1, 2, 3]))