Skip to content

Commit

Permalink
Handle dicts in check_text_params
Browse files Browse the repository at this point in the history
  • Loading branch information
9EOR9 committed Apr 17, 2023
1 parent b0366fa commit 9aedf1c
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 12 deletions.
2 changes: 1 addition & 1 deletion mariadb/cursors.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def execute(self, statement: str, data: Sequence = (), buffered=None):

# if one of the provided parameters has byte or datetime value,
# we don't use text protocol
if self._check_text_types() == True:
if data and self._check_text_types() == True:
self._text = False

if self._text:
Expand Down
32 changes: 21 additions & 11 deletions mariadb/mariadb_cursor.c
Original file line number Diff line number Diff line change
Expand Up @@ -653,18 +653,20 @@ PyObject *MrdbCursor_InitResultSet(MrdbCursor *self)
self->result= NULL;
}

if (Mrdb_GetFieldInfo(self))
return NULL;

if (!(self->values= (PyObject**)PyMem_RawCalloc(self->field_count, sizeof(PyObject *))))
return NULL;
if (!self->parseinfo.is_text)
mysql_stmt_attr_set(self->stmt, STMT_ATTR_CB_RESULT, field_fetch_callback);

if (self->field_count)
{
self->row_count= CURSOR_NUM_ROWS(self);
self->affected_rows= 0;
if (Mrdb_GetFieldInfo(self))
{
return NULL;
}

if (!(self->values= (PyObject**)PyMem_RawCalloc(self->field_count, sizeof(PyObject *))))
return NULL;
if (!self->parseinfo.is_text)
mysql_stmt_attr_set(self->stmt, STMT_ATTR_CB_RESULT, field_fetch_callback);

self->row_count= CURSOR_NUM_ROWS(self);
self->affected_rows= 0;
} else {
self->row_count= self->affected_rows= CURSOR_AFFECTED_ROWS(self);
}
Expand Down Expand Up @@ -1284,6 +1286,7 @@ static PyObject *
MrdbCursor_check_text_types(MrdbCursor *self)
{
PyDateTime_IMPORT;
Py_ssize_t ofs= 0;

if (!self || !self->data || !self->parseinfo.paramcount)
{
Expand All @@ -1292,7 +1295,14 @@ MrdbCursor_check_text_types(MrdbCursor *self)

for (uint32_t i= 0; i < self->parseinfo.paramcount; i++)
{
PyObject *obj= ListOrTuple_GetItem(self->data, i);
PyObject *obj;

if (PyDict_Check(self->data))
{
PyDict_Next(self->data, &ofs, NULL, &obj);
}
else
obj= ListOrTuple_GetItem(self->data, i);
if (PyBytes_Check(obj) ||
PyByteArray_Check(obj) ||
PyDate_Check(obj))
Expand Down

0 comments on commit 9aedf1c

Please sign in to comment.