Skip to content

Commit

Permalink
CONPY-188:
Browse files Browse the repository at this point in the history
When a connection or cursor was closed, an exception will be returned
if a method or property of closed object will be called.
  • Loading branch information
9EOR9 committed Jan 18, 2022
1 parent c7b27ed commit 7b63daa
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 60 deletions.
9 changes: 5 additions & 4 deletions include/mariadb_python.h
Expand Up @@ -478,7 +478,7 @@ if ((obj)->thread_state)\
#define MARIADB_CHECK_CONNECTION(connection, ret)\
if (!(connection) || !(connection)->mysql)\
{\
mariadb_throw_exception(NULL, Mariadb_InterfaceError, 0, \
mariadb_throw_exception(NULL, Mariadb_ProgrammingError, 0, \
"Invalid connection or not connected");\
return (ret);\
}
Expand All @@ -498,12 +498,13 @@ if ((obj)->thread_state)\
(a)= NULL;\
}

#define MARIADB_CHECK_STMT(cursor)\
if (!cursor->connection->mysql || cursor->closed)\
#define MARIADB_CHECK_STMT(cursor, retval)\
if (!(cursor)->connection->mysql || (cursor)->closed)\
{\
(cursor)->closed= 1;\
mariadb_throw_exception(cursor->stmt, Mariadb_ProgrammingError, 1,\
mariadb_throw_exception((cursor)->stmt, Mariadb_ProgrammingError, 1,\
"Invalid cursor or not connected");\
return (retval);\
}

#define pooling_keywords "pool_name", "pool_size", "reset_session", "idle_timeout", "acquire_timeout"
Expand Down
11 changes: 8 additions & 3 deletions mariadb/mariadb_connection.c
Expand Up @@ -674,6 +674,8 @@ static PyObject *MrdbConnection_cursor(MrdbConnection *self,
PyObject *cursor= NULL;
PyObject *conn = NULL;

MARIADB_CHECK_CONNECTION(self, NULL);

conn= Py_BuildValue("(O)", self);
cursor= PyObject_Call((PyObject *)&MrdbCursor_Type, conn, kwargs);
Py_DECREF(conn);
Expand Down Expand Up @@ -792,6 +794,8 @@ MrdbConnection_tpc_begin(MrdbConnection *self, PyObject *args)
char stmt[192];
int rc= 0;

MARIADB_CHECK_CONNECTION(self, NULL);

if (!PyArg_ParseTuple(args, "(iss)", &format_id,
&transaction_id,
&branch_qualifier))
Expand Down Expand Up @@ -1143,6 +1147,8 @@ static PyObject *MrdbConnection_getreconnect(MrdbConnection *self,
{
uint8_t reconnect= 0;

MARIADB_CHECK_CONNECTION(self, NULL);

if (self->mysql) {
mysql_get_option(self->mysql, MYSQL_OPT_RECONNECT, &reconnect);
}
Expand All @@ -1162,9 +1168,7 @@ static int MrdbConnection_setreconnect(MrdbConnection *self,
{
uint8_t reconnect;

if (!self->mysql) {
return 0;
}
MARIADB_CHECK_CONNECTION(self, -1);

if (!args || !CHECK_TYPE(args, &PyBool_Type)) {
PyErr_SetString(PyExc_TypeError, "Argument must be boolean");
Expand Down Expand Up @@ -1406,6 +1410,7 @@ MrdbConnection_exit(MrdbConnection *self, PyObject *args __attribute__((unused))
static PyObject *MrdbConnection_get_server_version(MrdbConnection *self)
{
MARIADB_CHECK_CONNECTION(self, NULL);

Py_INCREF(self->server_version_info);
return self->server_version_info;
}
Expand Down
77 changes: 25 additions & 52 deletions mariadb/mariadb_cursor.c
Expand Up @@ -677,9 +677,7 @@ PyObject *MrdbCursor_execute(MrdbCursor *self,
static char *key_words[]= {"", "", "buffered", NULL};
char errmsg[128];

MARIADB_CHECK_STMT(self);
if (PyErr_Occurred())
return NULL;
MARIADB_CHECK_STMT(self, NULL);

if (!PyArg_ParseTupleAndKeywords(args, kwargs,
"s#|Ob", key_words, &statement, &statement_len, &Data, &is_buffered))
Expand Down Expand Up @@ -871,8 +869,7 @@ PyObject *MrdbCursor_execute(MrdbCursor *self,
/* {{{ MrdbCursor_fieldcount() */
PyObject *MrdbCursor_fieldcount(MrdbCursor *self)
{
if (PyErr_Occurred())
return NULL;
MARIADB_CHECK_STMT(self, NULL);

return PyLong_FromLong((long)self->field_count);
}
Expand All @@ -890,9 +887,7 @@ PyObject *MrdbCursor_description(MrdbCursor *self)
PyObject *obj= NULL;
unsigned int field_count= self->field_count;

if (PyErr_Occurred())
return NULL;

MARIADB_CHECK_STMT(self, NULL);

if (self->fields && field_count)
{
Expand Down Expand Up @@ -996,12 +991,7 @@ MrdbCursor_fetchone(MrdbCursor *self)
uint32_t i;
unsigned int field_count= self->field_count;

if (self->cursor_type == CURSOR_TYPE_READ_ONLY)
MARIADB_CHECK_STMT(self);
if (PyErr_Occurred())
{
return NULL;
}
MARIADB_CHECK_STMT(self, NULL);

if (!field_count)
{
Expand Down Expand Up @@ -1042,11 +1032,7 @@ MrdbCursor_scroll(MrdbCursor *self,
const char *scroll_modes[]= {"relative", "absolute", NULL};


MARIADB_CHECK_STMT(self);
if (PyErr_Occurred())
{
return NULL;
}
MARIADB_CHECK_STMT(self, NULL);

if (!self->field_count)
{
Expand Down Expand Up @@ -1133,11 +1119,7 @@ MrdbCursor_fetchmany(MrdbCursor *self,
static char *kw_list[]= {"size", NULL};
unsigned int field_count= self->field_count;

MARIADB_CHECK_STMT(self);
if (PyErr_Occurred())
{
return NULL;
}
MARIADB_CHECK_STMT(self, NULL);

if (!field_count)
{
Expand Down Expand Up @@ -1211,11 +1193,8 @@ MrdbCursor_fetchall(MrdbCursor *self)
{
PyObject *List;
unsigned int field_count= self->field_count;
MARIADB_CHECK_STMT(self);
if (PyErr_Occurred())
{
return NULL;
}

MARIADB_CHECK_STMT(self, NULL);

if (!field_count)
{
Expand Down Expand Up @@ -1332,12 +1311,7 @@ MrdbCursor_executemany(MrdbCursor *self,
uint8_t do_prepare= 1;
char errmsg[128];

MARIADB_CHECK_STMT(self);

if (PyErr_Occurred())
{
return NULL;
}
MARIADB_CHECK_STMT(self, NULL);

self->data= NULL;

Expand Down Expand Up @@ -1447,13 +1421,8 @@ static PyObject *
MrdbCursor_nextset(MrdbCursor *self)
{
int rc;
MARIADB_CHECK_STMT(self);
MARIADB_CHECK_STMT(self, NULL);

if (PyErr_Occurred())
{
return NULL;
}
/* hmmm */
if (!self->field_count)
{
mariadb_throw_exception(NULL, Mariadb_ProgrammingError, 0,
Expand Down Expand Up @@ -1519,14 +1488,15 @@ Mariadb_row_number(MrdbCursor *self)
static PyObject *
MrdbCursor_warnings(MrdbCursor *self)
{
MARIADB_CHECK_STMT(self);
MARIADB_CHECK_STMT(self, NULL);

return PyLong_FromLong((long)CURSOR_WARNING_COUNT(self));
}

static PyObject *
MrdbCursor_getbuffered(MrdbCursor *self)
{
MARIADB_CHECK_STMT(self, NULL);
if (self->is_buffered)
{
Py_RETURN_TRUE;
Expand All @@ -1537,6 +1507,7 @@ MrdbCursor_getbuffered(MrdbCursor *self)
static int
MrdbCursor_setbuffered(MrdbCursor *self, PyObject *arg)
{
MARIADB_CHECK_STMT(self, -1);
if (!arg || !CHECK_TYPE(arg, &PyBool_Type))
{
PyErr_SetString(PyExc_TypeError, "Argument must be boolean");
Expand All @@ -1550,6 +1521,7 @@ MrdbCursor_setbuffered(MrdbCursor *self, PyObject *arg)
static PyObject *
MrdbCursor_lastrowid(MrdbCursor *self)
{
MARIADB_CHECK_STMT(self, NULL);
if (!self->lastrow_id)
{
Py_INCREF(Py_None);
Expand All @@ -1563,7 +1535,7 @@ MrdbCursor_lastrowid(MrdbCursor *self)
static PyObject *
MrdbCursor_iter(PyObject *self)
{
MARIADB_CHECK_STMT(((MrdbCursor *)self));
MARIADB_CHECK_STMT((MrdbCursor *)self, NULL);
Py_INCREF(self);
return self;
}
Expand All @@ -1573,6 +1545,8 @@ MrdbCursor_iternext(PyObject *self)
{
PyObject *res;

MARIADB_CHECK_STMT((MrdbCursor *)self, NULL);

res= MrdbCursor_fetchone((MrdbCursor *)self);

if (res && res == Py_None)
Expand All @@ -1594,15 +1568,14 @@ static PyObject
static PyObject *
MrdbCursor_sp_outparams(MrdbCursor *self)
{
if (!self->closed && self->stmt &&
self->stmt->mysql)
uint32_t server_status;

MARIADB_CHECK_STMT(self, NULL);

mariadb_get_infov(self->stmt->mysql, MARIADB_CONNECTION_SERVER_STATUS, &server_status);
if (server_status & SERVER_PS_OUT_PARAMS)
{
uint32_t server_status;
mariadb_get_infov(self->stmt->mysql, MARIADB_CONNECTION_SERVER_STATUS, &server_status);
if (server_status & SERVER_PS_OUT_PARAMS)
{
Py_RETURN_TRUE;
}
Py_RETURN_TRUE;
}
Py_RETURN_FALSE;
}
Expand All @@ -1619,7 +1592,7 @@ MrdbCursor_callproc(MrdbCursor *self, PyObject *args)
PyObject *new_args= NULL;
PyObject *rc= NULL;

MARIADB_CHECK_STMT(((MrdbCursor *)self));
MARIADB_CHECK_STMT(self, NULL);

if (!PyArg_ParseTuple(args, "s#|O", &sp, &sp_len, &data))
return NULL;
Expand Down
2 changes: 1 addition & 1 deletion testing/test/integration/test_connection.py
Expand Up @@ -201,7 +201,7 @@ def test_conpy155(self):
c1.close()
try:
version= c1.get_server_version()
except mariadb.InterfaceError:
except mariadb.ProgrammingError:
pass

def test_conpy175(self):
Expand Down

0 comments on commit 7b63daa

Please sign in to comment.