Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion asciidtype/asciidtype/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,7 @@
from .scalar import ASCIIScalar # isort: skip
from ._asciidtype_main import ASCIIDType

__all__ = ["ASCIIDType", "ASCIIScalar"]
__all__ = [
"ASCIIDType",
"ASCIIScalar",
]
77 changes: 77 additions & 0 deletions asciidtype/asciidtype/src/dtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,82 @@ static PyMemberDef ASCIIDType_members[] = {
{NULL},
};

static int PICKLE_VERSION = 1;

static PyObject *
asciidtype__reduce__(ASCIIDTypeObject *self)
{
PyObject *ret, *mod, *obj, *state;

ret = PyTuple_New(3);
if (ret == NULL) {
return NULL;
}

mod = PyImport_ImportModule("asciidtype");
if (mod == NULL) {
Py_DECREF(ret);
return NULL;
}

obj = PyObject_GetAttrString(mod, "ASCIIDType");
Py_DECREF(mod);
if (obj == NULL) {
Py_DECREF(ret);
return NULL;
}

PyTuple_SET_ITEM(ret, 0, obj);

PyTuple_SET_ITEM(ret, 1, Py_BuildValue("(l)", self->size));

state = PyTuple_New(1);

PyTuple_SET_ITEM(state, 0, PyLong_FromLong(PICKLE_VERSION));

PyTuple_SET_ITEM(ret, 2, state);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think:

Py_BuildValue("O(l)", Py_TYPE(self) self->size,self->size)

is good enough? presumably we can consider ASCIIDType(length) to be stable API? And you made the type pickable using the copyreg, so no need for the reconstruct function? Plus, if we have a reconstruction function, then we certainly can assume stable API?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, you didn't do the copyreg for the ASCIIDtype, I think. But maybe that is nicer? For the StringDType, this seems true even more so? We can assume StringDType() will always give the same thing, so no need for __setstate__ logic?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not using copyreg at all in either dtype. It is a good point though that we don't actually need to save the dtype itself at all though, thanks for the hint.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I misread... somehow I thought one of the reconstructs was to make the class work. So I guess you need the reconstruct helper until we fix the class pickling.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's exactly right. I originally had it as you suggested, but that didn't work because ASCIIDType is an instance of DTypeMeta, so python checks the type, sees it's in the copyreg, and dutifully calls the pickler numpy registers for the type.


return ret;
}

static PyObject *
asciidtype__setstate__(ASCIIDTypeObject *NPY_UNUSED(self), PyObject *args)
{
if (PyTuple_GET_SIZE(args) != 1 ||
!(PyLong_Check(PyTuple_GET_ITEM(args, 0)))) {
PyErr_BadInternalCall();
return NULL;
}

long version = PyLong_AsLong(PyTuple_GET_ITEM(args, 0));

if (version != PICKLE_VERSION) {
PyErr_Format(PyExc_ValueError,
"Pickle version mismatch. Got version %d but expected "
"version %d.",
version, PICKLE_VERSION);
return NULL;
}

Py_RETURN_NONE;
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not convinced we need it, and I would try to avoid it :).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you saying we don't need to save a pickle version? I thought that was generally useful for forward compatibility, just in case we ever want to change what gets pickled in the future.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am saying there is no point in __setstate__.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But then how do I check the pickle version when I load the pickle?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, if you want a version, then it is not possible. But, you don't need a version if you assume your API is stable, and if you have a loader function, you can keep the old one around also?
Anyway, not a strong opinion, the ndarray.__setstate__ in NumPy is just so messy, that it probably makes me want to not use it if possible :).


static PyMethodDef ASCIIDType_methods[] = {
{
"__reduce__",
(PyCFunction)asciidtype__reduce__,
METH_NOARGS,
"Reduction method for an ASCIIDType object",
},
{
"__setstate__",
(PyCFunction)asciidtype__setstate__,
METH_O,
"Unpickle an ASCIIDType object",
},
{NULL},
};

/*
* This is the basic things that you need to create a Python Type/Class in C.
* However, there is a slight difference here because we create a
Expand All @@ -242,6 +318,7 @@ PyArray_DTypeMeta ASCIIDType = {
.tp_repr = (reprfunc)asciidtype_repr,
.tp_str = (reprfunc)asciidtype_repr,
.tp_members = ASCIIDType_members,
.tp_methods = ASCIIDType_methods,
}},
/* rest, filled in during DTypeMeta initialization */
};
Expand Down
18 changes: 18 additions & 0 deletions asciidtype/tests/test_asciidtype.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import os
import pickle
import re
import tempfile

import numpy as np
import pytest
Expand Down Expand Up @@ -230,3 +233,18 @@ def test_insert_scalar_directly():
val = arr[0]
arr[1] = val
np.testing.assert_array_equal(arr, np.array(["some", "some"], dtype=dtype))


def test_pickle():
dtype = ASCIIDType(6)
arr = np.array(["this", "is", "an", "array"], dtype=dtype)
with tempfile.NamedTemporaryFile("wb", delete=False) as f:
pickle.dump([arr, dtype], f)

with open(f.name, "rb") as f:
res = pickle.load(f)

np.testing.assert_array_equal(arr, res[0])
assert res[1] == dtype

os.remove(f.name)
5 changes: 2 additions & 3 deletions stringdtype/stringdtype/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
"""A dtype for working with string data
"""A dtype for working with variable-length string data

This is an example usage of the experimental new dtype API
in Numpy and is not intended for any real purpose.
"""

from .scalar import StringScalar # isort: skip
from ._main import StringDType, _memory_usage


__all__ = [
"StringDType",
"StringScalar",
Expand Down
5 changes: 0 additions & 5 deletions stringdtype/stringdtype/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,6 @@


class StringScalar(str):
def __new__(cls, value, dtype):
instance = super().__new__(cls, value)
instance.dtype = dtype
return instance

def partition(self, sep):
ret = super().partition(sep)
return (str(ret[0]), str(ret[1]), str(ret[2]))
Expand Down
82 changes: 80 additions & 2 deletions stringdtype/stringdtype/src/dtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ new_stringdtype_instance(void)
new->base.elsize = sizeof(ss *);
new->base.alignment = _Alignof(ss *);
new->base.flags |= NPY_NEEDS_INIT;
new->base.flags |= NPY_LIST_PICKLE;

return new;
}
Expand Down Expand Up @@ -68,7 +69,7 @@ string_discover_descriptor_from_pyobject(PyArray_DTypeMeta *NPY_UNUSED(cls),
return NULL;
}

PyArray_Descr *ret = (PyArray_Descr *)PyObject_GetAttrString(obj, "dtype");
PyArray_Descr *ret = (PyArray_Descr *)new_stringdtype_instance();
if (ret == NULL) {
return NULL;
}
Expand Down Expand Up @@ -143,7 +144,7 @@ stringdtype_getitem(StringDTypeObject *descr, char **dataptr)
}

PyObject *res = PyObject_CallFunctionObjArgs((PyObject *)StringScalar_Type,
val_obj, descr, NULL);
val_obj, NULL);

if (res == NULL) {
return NULL;
Expand Down Expand Up @@ -200,6 +201,82 @@ stringdtype_repr(StringDTypeObject *NPY_UNUSED(self))
return PyUnicode_FromString("StringDType()");
}

static int PICKLE_VERSION = 1;

static PyObject *
stringdtype__reduce__(StringDTypeObject *NPY_UNUSED(self))
{
PyObject *ret, *mod, *obj, *state;

ret = PyTuple_New(3);
if (ret == NULL) {
return NULL;
}

mod = PyImport_ImportModule("stringdtype");
if (mod == NULL) {
Py_DECREF(ret);
return NULL;
}

obj = PyObject_GetAttrString(mod, "StringDType");
Py_DECREF(mod);
if (obj == NULL) {
Py_DECREF(ret);
return NULL;
}

PyTuple_SET_ITEM(ret, 0, obj);

PyTuple_SET_ITEM(ret, 1, PyTuple_New(0));

state = PyTuple_New(1);

PyTuple_SET_ITEM(state, 0, PyLong_FromLong(PICKLE_VERSION));

PyTuple_SET_ITEM(ret, 2, state);

return ret;
}

static PyObject *
stringdtype__setstate__(StringDTypeObject *NPY_UNUSED(self), PyObject *args)
{
if (PyTuple_GET_SIZE(args) != 1 ||
!(PyLong_Check(PyTuple_GET_ITEM(args, 0)))) {
PyErr_BadInternalCall();
return NULL;
}

long version = PyLong_AsLong(PyTuple_GET_ITEM(args, 0));

if (version != PICKLE_VERSION) {
PyErr_Format(PyExc_ValueError,
"Pickle version mismatch. Got version %d but expected "
"version %d.",
version, PICKLE_VERSION);
return NULL;
}

Py_RETURN_NONE;
}

static PyMethodDef StringDType_methods[] = {
{
"__reduce__",
(PyCFunction)stringdtype__reduce__,
METH_NOARGS,
"Reduction method for an StringDType object",
},
{
"__setstate__",
(PyCFunction)stringdtype__setstate__,
METH_O,
"Unpickle an StringDType object",
},
{NULL},
};

/*
* This is the basic things that you need to create a Python Type/Class in C.
* However, there is a slight difference here because we create a
Expand All @@ -215,6 +292,7 @@ PyArray_DTypeMeta StringDType = {
.tp_dealloc = (destructor)stringdtype_dealloc,
.tp_repr = (reprfunc)stringdtype_repr,
.tp_str = (reprfunc)stringdtype_repr,
.tp_methods = StringDType_methods,
}},
/* rest, filled in during DTypeMeta initialization */
};
Expand Down
49 changes: 43 additions & 6 deletions stringdtype/tests/test_stringdtype.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
import concurrent.futures
import os
import pickle
import tempfile

import numpy as np
import pytest

Expand All @@ -10,7 +15,7 @@ def string_list():


def test_scalar_creation():
assert str(StringScalar("abc", StringDType())) == "abc"
assert str(StringScalar("abc")) == "abc"


def test_dtype_creation():
Expand Down Expand Up @@ -38,12 +43,11 @@ def test_array_creation_utf8(data):


def test_array_creation_scalars(string_list):
dtype = StringDType()
arr = np.array(
[
StringScalar("abc", dtype=dtype),
StringScalar("def", dtype=dtype),
StringScalar("ghi", dtype=dtype),
StringScalar("abc"),
StringScalar("def"),
StringScalar("ghi"),
]
)
assert repr(arr) == repr(np.array(string_list, dtype=StringDType()))
Expand Down Expand Up @@ -94,7 +98,7 @@ def test_unicode_casts(string_list):
def test_insert_scalar(string_list):
dtype = StringDType()
arr = np.array(string_list, dtype=dtype)
arr[1] = StringScalar("what", dtype=dtype)
arr[1] = StringScalar("what")
assert repr(arr) == repr(np.array(["abc", "what", "ghi"], dtype=dtype))


Expand Down Expand Up @@ -124,3 +128,36 @@ def test_memory_usage(string_list):
_memory_usage("hello")
with pytest.raises(TypeError):
_memory_usage(np.array([1, 2, 3]))


def _pickle_load(filename):
with open(filename, "rb") as f:
res = pickle.load(f)

return res


def test_pickle(string_list):
dtype = StringDType()

arr = np.array(string_list, dtype=dtype)

with tempfile.NamedTemporaryFile("wb", delete=False) as f:
pickle.dump([arr, dtype], f)

with open(f.name, "rb") as f:
res = pickle.load(f)

np.testing.assert_array_equal(res[0], arr)
assert res[1] == dtype

# load the pickle in a subprocess to ensure the string data are
# actually stored in the pickle file
with concurrent.futures.ProcessPoolExecutor() as executor:
e = executor.submit(_pickle_load, f.name)
res = e.result()

np.testing.assert_array_equal(res[0], arr)
assert res[1] == dtype

os.remove(f.name)