From 76f33f435f0885d547147f0401f60ad918756e9c Mon Sep 17 00:00:00 2001 From: Marcel Nageler Date: Fri, 23 Jun 2023 12:09:48 +0200 Subject: [PATCH] support buffer protocol in Solver.add_clauses --- python/src/pycryptosat.cpp | 109 ++++++++----------------------------- 1 file changed, 23 insertions(+), 86 deletions(-) diff --git a/python/src/pycryptosat.cpp b/python/src/pycryptosat.cpp index adfff23e2..6a4fd29af 100644 --- a/python/src/pycryptosat.cpp +++ b/python/src/pycryptosat.cpp @@ -364,28 +364,21 @@ static int _add_clauses_from_array(Solver *self, const size_t array_length, cons return 1; } -static int _add_clauses_from_buffer_info(Solver *self, PyObject *buffer_info, const size_t itemsize) +static int _add_clauses_from_buffer(Solver *self, Py_buffer *view) { - PyObject *py_array_length = PyTuple_GetItem(buffer_info, 1); - if (py_array_length == NULL) { - PyErr_SetString(PyExc_ValueError, "invalid clause array: could not get array length"); + if (view->ndim != 1) { + PyErr_Format(PyExc_ValueError, "invalid clause array: expected 1-D array, got %d-D", view->ndim); return 0; } - long array_length = PyLong_AsLong(py_array_length); - if (array_length < 0) { - PyErr_SetString(PyExc_ValueError, "invalid clause array: could not get array length"); - return 0; - } - PyObject *py_array_address = PyTuple_GetItem(buffer_info, 0); - if (py_array_address == NULL) { - PyErr_SetString(PyExc_ValueError, "invalid clause array: could not get array address"); - return 0; - } - const void *array_address = PyLong_AsVoidPtr(py_array_address); - if (array_address == NULL) { - PyErr_SetString(PyExc_ValueError, "invalid clause array: could not get array address"); + if (strcmp(view->format, "i") != 0 && strcmp(view->format, "l") != 0 && strcmp(view->format, "q") != 0) { + PyErr_Format(PyExc_ValueError, "invalid clause array: invalid format '%s'", view->format); return 0; } + + void * array_address = view->buf; + size_t itemsize = view->itemsize; + size_t array_length = view->len / itemsize; + if (itemsize == sizeof(int)) { return _add_clauses_from_array(self, array_length, (const int *) array_address); } @@ -399,74 +392,14 @@ static int _add_clauses_from_buffer_info(Solver *self, PyObject *buffer_info, co return 0; } -static int _check_array_typecode(PyObject *clauses) -{ - PyObject *py_typecode = PyObject_GetAttrString(clauses, "typecode"); - if (py_typecode == NULL) { - PyErr_SetString(PyExc_ValueError, "invalid clause array: typecode is NULL"); - return 0; - } - - PyObject *typecode_bytes = PyUnicode_AsASCIIString(py_typecode); - Py_DECREF(py_typecode); - if (typecode_bytes == NULL) { - PyErr_SetString(PyExc_ValueError, "invalid clause array: could not get typecode bytes"); - return 0; - } - - const char *typecode_cstr = PyBytes_AsString(typecode_bytes); - if (typecode_cstr == NULL) { - Py_DECREF(typecode_bytes); - PyErr_SetString(PyExc_ValueError, "invalid clause array: could not get typecode cstring"); - return 0; - } - const char typecode = typecode_cstr[0]; - if (typecode == '\0' || typecode_cstr[1] != '\0') { - PyErr_Format(PyExc_ValueError, "invalid clause array: invalid typecode '%s'", typecode_cstr); - Py_DECREF(typecode_bytes); - return 0; - } - Py_DECREF(typecode_bytes); - if (typecode != 'i' && typecode != 'l' && typecode != 'q') { - PyErr_Format(PyExc_ValueError, "invalid clause array: invalid typecode '%c'", typecode); - return 0; - } - return 1; -} - -static int add_clauses_array(Solver *self, PyObject *clauses) -{ - if (_check_array_typecode(clauses) == 0) { - return 0; - } - PyObject *py_itemsize = PyObject_GetAttrString(clauses, "itemsize"); - if (py_itemsize == NULL) { - PyErr_SetString(PyExc_ValueError, "invalid clause array: itemsize is NULL"); - return 0; - } - const long itemsize = PyLong_AsLong(py_itemsize); - Py_DECREF(py_itemsize); - if (itemsize < 0) { - PyErr_SetString(PyExc_ValueError, "invalid clause array: could not get itemsize"); - return 0; - } - PyObject *buffer_info = PyObject_CallMethod(clauses, "buffer_info", NULL); - if (buffer_info == NULL) { - PyErr_SetString(PyExc_ValueError, "invalid clause array: buffer_info is NULL"); - return 0; - } - int ret = _add_clauses_from_buffer_info(self, buffer_info, itemsize); - Py_DECREF(buffer_info); - return ret; -} - PyDoc_STRVAR(add_clauses_doc, "add_clauses(clauses)\n\ Add iterable of clauses to the solver.\n\ \n\ :param clauses: List of clauses. Each clause contains literals (ints)\n\ - Alternatively, this can be a flat array.array (typecode 'i', 'l', or 'q')\n\ - of zero separated and terminated clauses of literals (ints).\n\ + Alternatively, this can be a flat array.array or other contiguous\n\ + buffer (format 'i', 'l', or 'q') of zero separated and terminated\n\ + clauses of literals (ints).\n\ :type clauses: or \n\ :return: None\n\ :rtype: " @@ -480,12 +413,16 @@ static PyObject* add_clauses(Solver *self, PyObject *args, PyObject *kwds) return NULL; } - if ( - PyObject_HasAttr(clauses, PyUnicode_FromString("buffer_info")) && - PyObject_HasAttr(clauses, PyUnicode_FromString("typecode")) && - PyObject_HasAttr(clauses, PyUnicode_FromString("itemsize")) - ) { - int ret = add_clauses_array(self, clauses); + if (PyObject_CheckBuffer(clauses)) { + Py_buffer view; + memset(&view, 0, sizeof(view)); + if (PyObject_GetBuffer(clauses, &view, PyBUF_CONTIG_RO | PyBUF_FORMAT) != 0) { + return NULL; + } + + int ret = _add_clauses_from_buffer(self, &view); + PyBuffer_Release(&view); + if (ret == 0 || PyErr_Occurred()) { return 0; }