Skip to content

Commit

Permalink
When converting parameters also check subtype of a Python Object
Browse files Browse the repository at this point in the history
  • Loading branch information
9EOR9 committed Oct 2, 2020
1 parent bce98d7 commit cbd51de
Showing 1 changed file with 33 additions and 33 deletions.
66 changes: 33 additions & 33 deletions mariadb/mariadb_codecs.c
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,6 @@

#define CHARSET_BINARY 63

#define CHECK_TYPE(obj, type) \
(Py_TYPE((obj)) == type || \
PyType_IsSubtype(Py_TYPE((obj)), type))

#define IS_DECIMAL_TYPE(type) \
((type) == MYSQL_TYPE_NEWDECIMAL || (type) == MYSQL_TYPE_DOUBLE || (type) == MYSQL_TYPE_FLOAT)

Expand Down Expand Up @@ -428,6 +424,7 @@ field_fetch_fromtext(MrdbCursor *self, char *data, unsigned int column)
{
MYSQL_TIME tm;
unsigned long *length;
enum enum_extended_field_type ext_type= mariadb_extended_field_type(&self->fields[column]);

if (!data)
{
Expand Down Expand Up @@ -509,7 +506,8 @@ field_fetch_fromtext(MrdbCursor *self, char *data, unsigned int column)
{
self->fields[column].max_length= length[column];
}
if (self->fields[column].charsetnr== CHARSET_BINARY)
if (self->fields[column].charsetnr== CHARSET_BINARY &&
ext_type != EXT_TYPE_JSON)
{
self->values[column]=
PyBytes_FromStringAndSize((const char *)data,
Expand Down Expand Up @@ -537,8 +535,7 @@ field_fetch_fromtext(MrdbCursor *self, char *data, unsigned int column)
case MYSQL_TYPE_ENUM:
{
unsigned long len;
enum enum_extended_field_type ext_type= mariadb_extended_field_type(&self->fields[column]);
if ( self->fields[column].charsetnr == CHARSET_BINARY && ext_type != EXT_TYPE_JSON)
if ( self->fields[column].charsetnr == CHARSET_BINARY)
{
self->values[column]=
PyBytes_FromStringAndSize((const char *)data,
Expand All @@ -563,8 +560,14 @@ field_fetch_fromtext(MrdbCursor *self, char *data, unsigned int column)
if (self->connection->converter)
{
PyObject *val;
enum enum_field_types type;

if (ext_type == EXT_TYPE_JSON)
type= MYSQL_TYPE_JSON;
else
type= self->fields[column].type;

if ((val= ma_convert_value(self, self->fields[column].type, self->values[column])))
if ((val= ma_convert_value(self, type, self->values[column])))
self->values[column]= val;
}
}
Expand All @@ -583,7 +586,7 @@ void
field_fetch_callback(void *data, unsigned int column, unsigned char **row)
{
MrdbCursor *self= (MrdbCursor *)data;

enum enum_extended_field_type ext_type= mariadb_extended_field_type(&self->fields[column]);
MARIADB_UNBLOCK_THREADS(self);

if (!row)
Expand Down Expand Up @@ -735,7 +738,8 @@ field_fetch_callback(void *data, unsigned int column, unsigned char **row)
unsigned long length= mysql_net_field_length(row);
if (length > self->fields[column].max_length)
self->fields[column].max_length= length;
if (self->fields[column].charsetnr == CHARSET_BINARY)
if (self->fields[column].charsetnr == CHARSET_BINARY &&
ext_type != EXT_TYPE_JSON)
{
self->values[column]=
PyBytes_FromStringAndSize((const char *)*row,
Expand Down Expand Up @@ -777,11 +781,9 @@ field_fetch_callback(void *data, unsigned int column, unsigned char **row)
unsigned long length;
unsigned long utf8len;
length= mysql_net_field_length(row);
enum enum_extended_field_type ext_type= mariadb_extended_field_type(&self->fields[column]);

if ((self->fields[column].flags & BINARY_FLAG ||
self->fields[column].charsetnr == CHARSET_BINARY) &&
ext_type != EXT_TYPE_JSON)
self->fields[column].charsetnr == CHARSET_BINARY))
{
self->values[column]=
PyBytes_FromStringAndSize((const char *)*row,
Expand All @@ -805,8 +807,14 @@ field_fetch_callback(void *data, unsigned int column, unsigned char **row)
if (self->connection->converter)
{
PyObject *val;
enum enum_field_types type;

if ((val= ma_convert_value(self, self->fields[column].type, self->values[column])))
if (ext_type == EXT_TYPE_JSON)
type= MYSQL_TYPE_JSON;
else
type= self->fields[column].type;

if ((val= ma_convert_value(self, type, self->values[column])))
self->values[column]= val;
}
end:
Expand Down Expand Up @@ -876,10 +884,10 @@ mariadb_get_column_info(PyObject *obj, MrdbParamInfo *paraminfo)

static PyObject *ListOrTuple_GetItem(PyObject *obj, Py_ssize_t index)
{
if (Py_TYPE(obj) == &PyList_Type)
if (CHECK_TYPE(obj, &PyList_Type))
{
return PyList_GetItem(obj, index);
} else if (Py_TYPE(obj) == &PyTuple_Type)
} else if (CHECK_TYPE(obj, &PyTuple_Type))
{
return PyTuple_GetItem(obj, index);
}
Expand Down Expand Up @@ -1111,10 +1119,10 @@ mariadb_get_parameter_info(MrdbCursor *self,

static Py_ssize_t ListOrTuple_Size(PyObject *obj)
{
if (Py_TYPE(obj) == &PyList_Type)
if (CHECK_TYPE(obj, &PyList_Type))
{
return PyList_Size(obj);
} else if (Py_TYPE(obj) == &PyTuple_Type)
} else if (CHECK_TYPE(obj, &PyTuple_Type))
{
return PyTuple_Size(obj);
}
Expand All @@ -1132,8 +1140,8 @@ mariadb_check_bulk_parameters(MrdbCursor *self,
{
uint32_t i;

if (Py_TYPE(data) != &PyList_Type &&
Py_TYPE(data) != &PyTuple_Type)
if (!CHECK_TYPE((data), &PyList_Type) &&
!CHECK_TYPE(data, &PyTuple_Type))
{
mariadb_throw_exception(self->stmt, Mariadb_InterfaceError, 1,
"Data must be passed as sequence (Tuple or List)");
Expand All @@ -1151,16 +1159,16 @@ mariadb_check_bulk_parameters(MrdbCursor *self,
{
PyObject *obj= ListOrTuple_GetItem(data, i);
if (self->parser->paramstyle != PYFORMAT &&
(Py_TYPE(obj) != &PyTuple_Type &&
Py_TYPE(obj) != &PyList_Type))
(!CHECK_TYPE(obj, &PyTuple_Type) &&
!CHECK_TYPE(obj, &PyList_Type)))
{
mariadb_throw_exception(NULL, Mariadb_DataError, 0,
"Invalid parameter type in row %d. "\
" (Row data must be provided as tuple(s))", i+1);
return 1;
}
if (self->parser->paramstyle == PYFORMAT &&
Py_TYPE(obj) != &PyDict_Type)
!CHECK_TYPE(obj, &PyDict_Type))
{
mariadb_throw_exception(NULL, Mariadb_DataError, 0,
"Invalid parameter type in row %d. "\
Expand Down Expand Up @@ -1347,16 +1355,8 @@ mariadb_param_to_bind(MrdbCursor *self,
*(double *)value->num= (double)PyFloat_AsDouble(value->value);
break;
case MYSQL_TYPE_LONG_BLOB:
if (Py_TYPE(value->value) != &PyBytes_Type)
{
PyObject *dump= NULL;
dump= PyObject_CallMethod(Mrdb_Pickle, "dumps", "O", value->value);
bind->buffer_length= (unsigned long)PyBytes_GET_SIZE(dump);
bind->buffer= (void *) PyBytes_AS_STRING(dump);
} else {
bind->buffer_length= (unsigned long)PyBytes_GET_SIZE(value->value);
bind->buffer= (void *) PyBytes_AS_STRING(value->value);
}
bind->buffer_length= (unsigned long)PyBytes_GET_SIZE(value->value);
bind->buffer= (void *) PyBytes_AS_STRING(value->value);
break;
case MYSQL_TYPE_DATE:
case MYSQL_TYPE_TIME:
Expand Down

0 comments on commit cbd51de

Please sign in to comment.