diff --git a/asciidtype/asciidtype/src/dtype.c b/asciidtype/asciidtype/src/dtype.c index 6d93a890..90a9695b 100644 --- a/asciidtype/asciidtype/src/dtype.c +++ b/asciidtype/asciidtype/src/dtype.c @@ -55,20 +55,20 @@ new_asciidtype_instance(long size) } /* - * This is used to determine the correct dtype to return when operations mix - * dtypes (I think?). For now just return the first one. + * This is used to determine the correct dtype to return when dealing + * with a mix of different dtypes (for example when creating an array + * from a list of scalars). Always return the dtype with the biggest + * size. */ static ASCIIDTypeObject * common_instance(ASCIIDTypeObject *dtype1, ASCIIDTypeObject *dtype2) { - if (!PyObject_RichCompareBool((PyObject *)dtype1, (PyObject *)dtype2, - Py_EQ)) { - PyErr_SetString( - PyExc_RuntimeError, - "common_instance called on unequal ASCIIDType instances"); - return NULL; + if (dtype1->size >= dtype2->size) { + Py_INCREF(dtype1); + return dtype1; } - return dtype1; + Py_INCREF(dtype2); + return dtype2; } static PyArray_DTypeMeta * diff --git a/asciidtype/tests/test_asciidtype.py b/asciidtype/tests/test_asciidtype.py index 3e507bdb..06ec4c5a 100644 --- a/asciidtype/tests/test_asciidtype.py +++ b/asciidtype/tests/test_asciidtype.py @@ -24,6 +24,15 @@ def test_creation_with_explicit_dtype(): ) +def test_creation_from_scalar(): + data = [ + ASCIIScalar("hello", ASCIIDType(6)), + ASCIIScalar("array", ASCIIDType(7)), + ] + arr = np.array(data) + assert repr(arr) == ("array(['hello', 'array'], dtype=ASCIIDType(7))") + + def test_creation_truncation(): inp = ["hello", "this", "is", "an", "array"] diff --git a/metadatadtype/metadatadtype/src/dtype.c b/metadatadtype/metadatadtype/src/dtype.c index e9140562..3ec9a606 100644 --- a/metadatadtype/metadatadtype/src/dtype.c +++ b/metadatadtype/metadatadtype/src/dtype.c @@ -52,6 +52,7 @@ get_metadata(PyObject *scalar) } PyObject *metadata = dtype->metadata; + Py_DECREF(dtype); if (metadata == NULL) { return NULL; } @@ -87,6 +88,7 @@ new_metadatadtype_instance(PyObject *metadata) static MetadataDTypeObject * common_instance(MetadataDTypeObject *dtype1, MetadataDTypeObject *dtype2) { + Py_INCREF(dtype1); return dtype1; }