Skip to content

Commit

Permalink
Fix for CONPY-105: Change behavior of cursor->rowcount and cursor->la…
Browse files Browse the repository at this point in the history
…strowid

rowcount:
In case of an error, or if statement wasn't executed rowcount should be -1 (see PEP-249)

    For DML statements the number of affected rows returned in OK packet by server:
        > 0 for DML statements which modify or insert, e.g. ALTER TABLE or CREATE TABLE .. SELECT FROM
        otherwise 0
    For DQL statement
        if field_count > 0: number of rows returned
        otherwise affected rows returned in OK packet by server.

lastrowid:

    if server returns no value (0) for last_insert_id, lastrowid should be None.
    if last_insert_id is > 0, return it's value
  • Loading branch information
9EOR9 committed Aug 14, 2020
1 parent 30d5793 commit afea681
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 70 deletions.
89 changes: 42 additions & 47 deletions mariadb/mariadb_cursor.c
Original file line number Diff line number Diff line change
Expand Up @@ -482,29 +482,31 @@ static void ma_set_result_column_value(MrdbCursor *self, PyObject *row, uint32_t
static
void ma_cursor_close(MrdbCursor *self)
{
if (!self->is_text && self->stmt)
if (!self->is_closed)
{
/* Todo: check if all the cursor stuff is deleted (when using prepared
statements this should be handled in mysql_stmt_close) */
MARIADB_BEGIN_ALLOW_THREADS(self)
mysql_stmt_close(self->stmt);
MARIADB_END_ALLOW_THREADS(self)
self->stmt= NULL;
}
MrdbCursor_clear(self, 0);
if (self->parser)
{
MrdbParser_end(self->parser);
self->parser= NULL;
if (!self->is_text && self->stmt)
{
/* Todo: check if all the cursor stuff is deleted (when using prepared
statements this should be handled in mysql_stmt_close) */
MARIADB_BEGIN_ALLOW_THREADS(self)
mysql_stmt_close(self->stmt);
MARIADB_END_ALLOW_THREADS(self)
self->stmt= NULL;
}
MrdbCursor_clear(self, 0);
if (self->parser)
{
MrdbParser_end(self->parser);
self->parser= NULL;
}
self->is_closed= 1;
}
self->is_closed= 1;
}

static
PyObject * MrdbCursor_close(MrdbCursor *self)
{
ma_cursor_close(self);
self->is_closed= 1;
Py_INCREF(Py_None);
return Py_None;
}
Expand All @@ -521,7 +523,6 @@ void MrdbCursor_dealloc(MrdbCursor *self)
static int Mrdb_GetFieldInfo(MrdbCursor *self)
{
self->row_number= 0;
self->row_count= CURSOR_AFFECTED_ROWS(self);

if (self->field_count)
{
Expand Down Expand Up @@ -732,7 +733,7 @@ PyObject *MrdbCursor_execute(MrdbCursor *self,
/* execute_direct was implemented together with bulk operations, so we need
to check if MARIADB_CLIENT_STMT_BULK_OPERATIONS is set in extended server
capabilities */
if (!(self->connection->extended_server_capabilities &
if ((self->connection->extended_server_capabilities &
(MARIADB_CLIENT_STMT_BULK_OPERATIONS >> 32)))
{
rc= mysql_stmt_prepare(self->stmt, self->parser->statement.str,
Expand Down Expand Up @@ -799,10 +800,21 @@ PyObject *MrdbCursor_execute(MrdbCursor *self,

if (MrdbCursor_InitResultSet(self))
goto error;

if (self->field_count)
{
self->row_count= CURSOR_NUM_ROWS(self);
} else {
self->row_count= CURSOR_AFFECTED_ROWS(self);
}
self->lastrow_id= CURSOR_INSERT_ID(self);

end:
MARIADB_FREE_MEM(self->value);
Py_RETURN_NONE;
error:
self->row_count= -1;
self->lastrow_id= 0;
MrdbParser_end(self->parser);
self->parser= NULL;
MrdbCursor_clear(self, 0);
Expand All @@ -813,7 +825,6 @@ PyObject *MrdbCursor_execute(MrdbCursor *self,
/* {{{ MrdbCursor_fieldcount() */
PyObject *MrdbCursor_fieldcount(MrdbCursor *self)
{
MARIADB_CHECK_STMT(self);
if (PyErr_Occurred())
return NULL;

Expand All @@ -833,7 +844,6 @@ PyObject *MrdbCursor_description(MrdbCursor *self)
PyObject *obj= NULL;
unsigned int field_count= self->field_count;

MARIADB_CHECK_STMT(self);
if (PyErr_Occurred())
return NULL;

Expand Down Expand Up @@ -937,7 +947,8 @@ MrdbCursor_fetchone(MrdbCursor *self)
uint32_t i;
unsigned int field_count= self->field_count;

MARIADB_CHECK_STMT(self);
if (self->cursor_type == CURSOR_TYPE_READ_ONLY)
MARIADB_CHECK_STMT(self);
if (PyErr_Occurred())
{
return NULL;
Expand Down Expand Up @@ -1116,7 +1127,7 @@ MrdbCursor_fetchmany(MrdbCursor *self,
{
goto end;
}
self->affected_rows= CURSOR_NUM_ROWS(self);
self->row_count= CURSOR_NUM_ROWS(self);
if (!(Row= mariadb_get_sequence_or_tuple(self)))
{
return NULL;
Expand Down Expand Up @@ -1242,6 +1253,7 @@ MrdbCursor_executemany_fallback(MrdbCursor *self,
}
self->row_count++;
}
self->lastrow_id= CURSOR_INSERT_ID(self);
rc= mysql_query(self->stmt->mysql, "COMMIT");
if (!rc)
return 0;
Expand Down Expand Up @@ -1360,6 +1372,8 @@ MrdbCursor_executemany(MrdbCursor *self,
}
}
end:
self->row_count= CURSOR_AFFECTED_ROWS(self);
self->lastrow_id= CURSOR_INSERT_ID(self);
MARIADB_FREE_MEM(self->values);
Py_RETURN_NONE;
error:
Expand Down Expand Up @@ -1422,36 +1436,13 @@ MrdbCursor_nextset(MrdbCursor *self)
static PyObject *
Mariadb_row_count(MrdbCursor *self)
{
int64_t row_count= 0;

MARIADB_CHECK_STMT(self);
if (PyErr_Occurred())
{
return NULL;
}

/* PEP-249 requires to return -1 if the cursor was not executed before */
if (!self->statement)
{
return PyLong_FromLongLong(-1);
}

if (self->field_count)
{
if (!self->is_buffered && !self->fetched)
{
row_count= -1;
} else
{
row_count= CURSOR_NUM_ROWS(self);
}
}
else {
row_count= self->row_count ? self->row_count : CURSOR_AFFECTED_ROWS(self);
if (!row_count)
row_count= -1;
}
return PyLong_FromLongLong(row_count);
return PyLong_FromLongLong(self->row_count);
}

static PyObject *
Expand Down Expand Up @@ -1499,8 +1490,12 @@ MrdbCursor_setbuffered(MrdbCursor *self, PyObject *arg)
static PyObject *
MrdbCursor_lastrowid(MrdbCursor *self)
{
MARIADB_CHECK_STMT(self);
return PyLong_FromUnsignedLongLong(CURSOR_INSERT_ID(self));
if (!self->lastrow_id)
{
Py_INCREF(Py_None);
return Py_None;
}
return PyLong_FromUnsignedLongLong(self->lastrow_id);
}

/* iterator protocol */
Expand Down
4 changes: 2 additions & 2 deletions testing/test/integration/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_conpy36(self):
default_conf = conf()
try:
conn= mariadb.connect(user=default_conf["user"], unix_socket="/does_not_exist/x.sock", port=default_conf["port"], host=default_conf["host"])
except mariadb.DatabaseError:
except (mariadb.OperationalError,):
pass

def test_connection_default_file(self):
Expand Down Expand Up @@ -59,7 +59,7 @@ def test_local_infile(self):
cursor.execute("CREATE TEMPORARY TABLE t1 (a int)")
try:
cursor.execute("LOAD DATA LOCAL INFILE 'x.x' INTO TABLE t1")
except (mariadb.ProgrammingError, mariadb.DatabaseError):
except (mariadb.OperationalError,):
pass
del cursor
del new_conn
Expand Down
12 changes: 6 additions & 6 deletions testing/test/integration/test_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,13 +166,13 @@ def test_fetchmany(self):
self.assertRaises(mariadb.Error, cursor.fetchall)

cursor.execute("SELECT id, name, city FROM test_fetchmany ORDER BY id")
self.assertEqual(-1, cursor.rowcount)
self.assertEqual(0, cursor.rowcount)
row = cursor.fetchall()
self.assertEqual(row, params)
self.assertEqual(5, cursor.rowcount)

cursor.execute("SELECT id, name, city FROM test_fetchmany ORDER BY id")
self.assertEqual(-1, cursor.rowcount)
self.assertEqual(0, cursor.rowcount)

row = cursor.fetchmany(1)
self.assertEqual(row, [params[0]])
Expand Down Expand Up @@ -543,7 +543,7 @@ def test_conpy_15(self):
cursor = self.connection.cursor()
cursor.execute(
"CREATE TEMPORARY TABLE test_conpy_15 (a int not null auto_increment primary key, b varchar(20))");
self.assertEqual(cursor.lastrowid, 0)
self.assertEqual(cursor.lastrowid, None)
cursor.execute("INSERT INTO test_conpy_15 VALUES (null, 'foo')")
self.assertEqual(cursor.lastrowid, 1)
cursor.execute("SELECT LAST_INSERT_ID()")
Expand All @@ -570,7 +570,7 @@ def test_conpy_14(self):
self.assertEqual(cursor.rowcount, -1)
cursor.execute(
"CREATE TEMPORARY TABLE test_conpy_14 (a int not null auto_increment primary key, b varchar(20))");
self.assertEqual(cursor.rowcount, -1)
self.assertEqual(cursor.rowcount, 0)
cursor.execute("INSERT INTO test_conpy_14 VALUES (null, 'foo')")
self.assertEqual(cursor.rowcount, 1)
vals = [(3, "bar"), (4, "this")]
Expand Down Expand Up @@ -961,13 +961,13 @@ def test_conpy67(self):
con= create_connection()
cur = con.cursor()
cur.execute("SELECT 1")
self.assertEqual(cur.rowcount, -1)
self.assertEqual(cur.rowcount, 0)
cur.close()

cur = con.cursor()
cur.execute("CREATE TEMPORARY TABLE test_conpy67 (a int)")
cur.execute("SELECT * from test_conpy67")
self.assertEqual(cur.rowcount, -1)
self.assertEqual(cur.rowcount, 0)
cur.fetchall()
self.assertEqual(cur.rowcount, 0)

Expand Down
18 changes: 6 additions & 12 deletions testing/test/integration/test_dbapi20.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,25 +307,19 @@ def test_rowcount(self):
try:
cur = con.cursor(buffered=True)
self.executeDDL1(cur)
self.assertEqual(cur.rowcount, -1,
'cursor.rowcount should be -1 after executing no-result '
self.assertEqual(cur.rowcount, 0,
'cursor.rowcount should be 0 after executing no-result '
'statements'
)
cur.execute("insert into %sbooze values ('Victoria Bitter')" % (
self.table_prefix
))
self.assertTrue(cur.rowcount in (-1, 1),
'cursor.rowcount should == number or rows inserted, or '
'set to -1 after executing an insert statement'
)
self.assertEqual(cur.rowcount, 1)
cur.execute("select name from %sbooze" % self.table_prefix)
self.assertTrue(cur.rowcount in (-1, 1),
'cursor.rowcount should == number of rows returned, or '
'set to -1 after executing a select statement'
)
self.assertEqual(cur.rowcount, 1)
self.executeDDL2(cur)
self.assertEqual(cur.rowcount, -1,
'cursor.rowcount not being reset to -1 after executing '
self.assertEqual(cur.rowcount, 0,
'cursor.rowcount not being reset to 0 after executing '
'no-result statements'
)
finally:
Expand Down
2 changes: 1 addition & 1 deletion testing/test/integration/test_exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_conn_timeout_exception(self):
start = datetime.today()
try:
create_connection({"connect_timeout": 1, "host": "8.8.8.8"})
except mariadb.DatabaseError as err:
except mariadb.OperationalError as err:
self.assertEqual(err.sqlstate, "HY000")
self.assertEqual(err.errno, 2002)
self.assertTrue(err.errmsg.find("server on '8.8.8.8'") > -1)
Expand Down
4 changes: 2 additions & 2 deletions testing/test/integration/test_nondbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_ping(self):
self.connection.kill(id)
try:
new_conn.ping()
except mariadb.DatabaseError:
except mariadb.InterfaceError:
pass
del new_conn
new_conn = create_connection()
Expand Down Expand Up @@ -83,7 +83,7 @@ def test_reset(self):
cursor.execute("SELECT 1 UNION SELECT 2")
try:
self.connection.ping()
except mariadb.DatabaseError:
except mariadb.InterfaceError:
pass

self.connection.reset()
Expand Down

0 comments on commit afea681

Please sign in to comment.