Skip to content

Commit

Permalink
ENH: Create string dtype instances from the abstract dtype (#22923)
Browse files Browse the repository at this point in the history
Following up from #22863 (comment), this makes it possible to create string dtype instances from an abstract string DTypeMeta.

Co-authored-by: Sebastian Berg <sebastianb@nvidia.com>
  • Loading branch information
ngoldbaum and seberg committed Jan 5, 2023
1 parent c1697e0 commit 3b5ba53
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 1 deletion.
7 changes: 7 additions & 0 deletions doc/release/upcoming_changes/22863.new_feature.rst
@@ -0,0 +1,7 @@
String dtype instances can be created from the string abstract dtype classes
----------------------------------------------------------------------------
It is now possible to create a string dtype instance with a size without
using the string name of the dtype. For example, ``type(np.dtype('U'))(8)``
will create a dtype that is equivalent to ``np.dtype('U8')``. This feature
is most useful when writing generic code dealing with string dtype
classes.
49 changes: 48 additions & 1 deletion numpy/core/src/multiarray/dtypemeta.c
Expand Up @@ -18,6 +18,8 @@
#include "scalartypes.h"
#include "convert_datatype.h"
#include "usertypes.h"
#include "conversion_utils.h"
#include "templ_common.h"

#include <assert.h>

Expand Down Expand Up @@ -122,6 +124,50 @@ legacy_dtype_default_new(PyArray_DTypeMeta *self,
return (PyObject *)self->singleton;
}

/*
 * `tp_new` slot for the string DType classes (it is assigned to the dtype
 * class' `tp_new` in `dtypemeta_wrap_legacy_descriptor`).  Creates a new
 * descriptor instance of `self->type_num` with the requested size.
 *
 * Accepts exactly one positional argument, the size, converted with
 * `PyArray_IntpFromPyIntConverter`.  For `NPY_UNICODE` the size is multiplied
 * by 4 (bytes per character) to obtain the element size; for any other
 * type number it is used as the element size directly.
 *
 * Returns a new descriptor with `elsize` set, or NULL with an exception set:
 *   - ValueError if the size is negative;
 *   - TypeError if the resulting byte size overflows or exceeds NPY_MAX_INT;
 *   - whatever argument parsing or `PyArray_DescrNewFromType` raised.
 */
static PyObject *
string_unicode_new(PyArray_DTypeMeta *self, PyObject *args, PyObject *kwargs)
{
    npy_intp size;

    /* Empty name: the single argument is positional-only. */
    static char *kwlist[] = {"", NULL};

    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O&", kwlist,
                                     PyArray_IntpFromPyIntConverter, &size)) {
        return NULL;
    }

    if (size < 0) {
        PyErr_Format(PyExc_ValueError,
                     "Strings cannot have a negative size but a size of "
                     "%"NPY_INTP_FMT" was given", size);
        return NULL;
    }

    if (self->type_num == NPY_UNICODE) {
        /* unicode strings are 4 bytes per character */
        if (npy_mul_sizes_with_overflow(&size, size, 4)) {
            PyErr_SetString(
                    PyExc_TypeError,
                    "Strings too large to store inside array.");
            return NULL;
        }
    }

    /* `elsize` is stored as an int below, so the byte size must fit in one. */
    if (size > NPY_MAX_INT) {
        PyErr_SetString(PyExc_TypeError,
                        "Strings too large to store inside array.");
        return NULL;
    }

    /*
     * Allocate the descriptor only after all validation has passed, so that
     * none of the error returns above has a reference to release.
     */
    PyArray_Descr *res = PyArray_DescrNewFromType(self->type_num);

    if (res == NULL) {
        return NULL;
    }

    res->elsize = (int)size;
    return (PyObject *)res;
}

static PyArray_Descr *
nonparametric_discover_descr_from_pyobject(
Expand Down Expand Up @@ -151,7 +197,7 @@ string_discover_descr_from_pyobject(
}
if (itemsize > NPY_MAX_INT) {
PyErr_SetString(PyExc_TypeError,
"string to large to store inside array.");
"string too large to store inside array.");
}
PyArray_Descr *res = PyArray_DescrNewFromType(cls->type_num);
if (res == NULL) {
Expand Down Expand Up @@ -849,6 +895,7 @@ dtypemeta_wrap_legacy_descriptor(PyArray_Descr *descr)
string_discover_descr_from_pyobject);
dt_slots->common_dtype = string_unicode_common_dtype;
dt_slots->common_instance = string_unicode_common_instance;
((PyTypeObject*)dtype_class)->tp_new = (newfunc)string_unicode_new;
}
}

Expand Down
28 changes: 28 additions & 0 deletions numpy/core/tests/test_dtype.py
Expand Up @@ -195,6 +195,34 @@ def test_field_order_equality(self):
# This is a safe cast (not equiv) due to the different names:
assert np.can_cast(x, y, casting="safe")

@pytest.mark.parametrize(
    ["type_char", "char_size", "scalar_type"],
    [
        ["U", 4, np.str_],
        ["S", 1, np.bytes_],
    ],
)
def test_create_string_dtypes_directly(
        self, type_char, char_size, scalar_type):
    """Calling a string dtype class with a size yields a sized instance."""
    requested_size = 8
    string_dtype_cls = type(np.dtype(type_char))

    created = string_dtype_cls(requested_size)

    # The instance must report the matching scalar type ...
    assert created.type is scalar_type
    # ... and an itemsize of `char_size` bytes per requested character.
    assert created.itemsize == requested_size * char_size

def test_create_invalid_string_errors(self):
# Checks that the string dtype classes reject sizes that are too large
# (expected TypeError) or negative (expected ValueError).
# One more than the largest value representable by np.intc.
one_too_big = np.iinfo(np.intc).max + 1
# Unicode sizes are multiplied by 4 bytes per character, so a quarter of
# one_too_big is already expected to be rejected.
with pytest.raises(TypeError):
type(np.dtype("U"))(one_too_big // 4)

with pytest.raises(TypeError):
# Code coverage for very large numbers:
# (np.iinfo(np.intp).max // 4 + 1) * 4 is larger than the np.intp maximum.
type(np.dtype("U"))(np.iinfo(np.intp).max // 4 + 1)

# The bytes ("S") case is only attempted when one_too_big is below
# sys.maxsize.
# NOTE(review): indentation is lost in this view; presumably only the "S"
# check is guarded by this `if` and the ValueError check below always
# runs -- confirm against the repository.
if one_too_big < sys.maxsize:
with pytest.raises(TypeError):
type(np.dtype("S"))(one_too_big)

# A negative size is expected to raise ValueError rather than TypeError.
with pytest.raises(ValueError):
type(np.dtype("U"))(-1)


class TestRecord:
def test_equivalent_record(self):
Expand Down

0 comments on commit 3b5ba53

Please sign in to comment.