diff --git a/mariadb/mariadb_cursor.c b/mariadb/mariadb_cursor.c index ec4a0bd..6003777 100644 --- a/mariadb/mariadb_cursor.c +++ b/mariadb/mariadb_cursor.c @@ -482,29 +482,31 @@ static void ma_set_result_column_value(MrdbCursor *self, PyObject *row, uint32_t static void ma_cursor_close(MrdbCursor *self) { - if (!self->is_text && self->stmt) + if (!self->is_closed) { - /* Todo: check if all the cursor stuff is deleted (when using prepared - statements this should be handled in mysql_stmt_close) */ - MARIADB_BEGIN_ALLOW_THREADS(self) - mysql_stmt_close(self->stmt); - MARIADB_END_ALLOW_THREADS(self) - self->stmt= NULL; - } - MrdbCursor_clear(self, 0); - if (self->parser) - { - MrdbParser_end(self->parser); - self->parser= NULL; + if (!self->is_text && self->stmt) + { + /* Todo: check if all the cursor stuff is deleted (when using prepared + statements this should be handled in mysql_stmt_close) */ + MARIADB_BEGIN_ALLOW_THREADS(self) + mysql_stmt_close(self->stmt); + MARIADB_END_ALLOW_THREADS(self) + self->stmt= NULL; + } + MrdbCursor_clear(self, 0); + if (self->parser) + { + MrdbParser_end(self->parser); + self->parser= NULL; + } + self->is_closed= 1; } - self->is_closed= 1; } static PyObject * MrdbCursor_close(MrdbCursor *self) { ma_cursor_close(self); - self->is_closed= 1; Py_INCREF(Py_None); return Py_None; } @@ -521,7 +523,6 @@ void MrdbCursor_dealloc(MrdbCursor *self) static int Mrdb_GetFieldInfo(MrdbCursor *self) { self->row_number= 0; - self->row_count= CURSOR_AFFECTED_ROWS(self); if (self->field_count) { @@ -732,7 +733,7 @@ PyObject *MrdbCursor_execute(MrdbCursor *self, /* execute_direct was implemented together with bulk operations, so we need to check if MARIADB_CLIENT_STMT_BULK_OPERATIONS is set in extended server capabilities */ - if (!(self->connection->extended_server_capabilities & + if ((self->connection->extended_server_capabilities & (MARIADB_CLIENT_STMT_BULK_OPERATIONS >> 32))) { rc= mysql_stmt_prepare(self->stmt, self->parser->statement.str, @@ -799,10 +800,21 @@ PyObject *MrdbCursor_execute(MrdbCursor *self, if (MrdbCursor_InitResultSet(self)) goto error; + + if (self->field_count) + { + self->row_count= CURSOR_NUM_ROWS(self); + } else { + self->row_count= CURSOR_AFFECTED_ROWS(self); + } + self->lastrow_id= CURSOR_INSERT_ID(self); + end: MARIADB_FREE_MEM(self->value); Py_RETURN_NONE; error: + self->row_count= -1; + self->lastrow_id= 0; MrdbParser_end(self->parser); self->parser= NULL; MrdbCursor_clear(self, 0); @@ -813,7 +825,6 @@ PyObject *MrdbCursor_execute(MrdbCursor *self, /* {{{ MrdbCursor_fieldcount() */ PyObject *MrdbCursor_fieldcount(MrdbCursor *self) { - MARIADB_CHECK_STMT(self); if (PyErr_Occurred()) return NULL; @@ -833,7 +844,6 @@ PyObject *MrdbCursor_description(MrdbCursor *self) PyObject *obj= NULL; unsigned int field_count= self->field_count; - MARIADB_CHECK_STMT(self); if (PyErr_Occurred()) return NULL; @@ -937,7 +947,8 @@ MrdbCursor_fetchone(MrdbCursor *self) uint32_t i; unsigned int field_count= self->field_count; - MARIADB_CHECK_STMT(self); + if (self->cursor_type == CURSOR_TYPE_READ_ONLY) + MARIADB_CHECK_STMT(self); if (PyErr_Occurred()) { return NULL; @@ -1116,7 +1127,7 @@ MrdbCursor_fetchmany(MrdbCursor *self, { goto end; } - self->affected_rows= CURSOR_NUM_ROWS(self); + self->row_count= CURSOR_NUM_ROWS(self); if (!(Row= mariadb_get_sequence_or_tuple(self))) { return NULL; @@ -1242,6 +1253,7 @@ MrdbCursor_executemany_fallback(MrdbCursor *self, } self->row_count++; } + self->lastrow_id= CURSOR_INSERT_ID(self); rc= mysql_query(self->stmt->mysql, "COMMIT"); if (!rc) return 0; @@ -1360,6 +1372,8 @@ MrdbCursor_executemany(MrdbCursor *self, } } end: + self->row_count= CURSOR_AFFECTED_ROWS(self); + self->lastrow_id= CURSOR_INSERT_ID(self); MARIADB_FREE_MEM(self->values); Py_RETURN_NONE; error: @@ -1422,36 +1436,13 @@ MrdbCursor_nextset(MrdbCursor *self) static PyObject * Mariadb_row_count(MrdbCursor *self) { - int64_t row_count= 0; - - MARIADB_CHECK_STMT(self); - if (PyErr_Occurred()) - { - return NULL; - } - /* PEP-249 requires to return -1 if the cursor was not executed before */ if (!self->statement) { return PyLong_FromLongLong(-1); } - if (self->field_count) - { - if (!self->is_buffered && !self->fetched) - { - row_count= -1; - } else - { - row_count= CURSOR_NUM_ROWS(self); - } - } - else { - row_count= self->row_count ? self->row_count : CURSOR_AFFECTED_ROWS(self); - if (!row_count) - row_count= -1; - } - return PyLong_FromLongLong(row_count); + return PyLong_FromLongLong(self->row_count); } static PyObject * @@ -1499,8 +1490,12 @@ MrdbCursor_setbuffered(MrdbCursor *self, PyObject *arg) static PyObject * MrdbCursor_lastrowid(MrdbCursor *self) { - MARIADB_CHECK_STMT(self); - return PyLong_FromUnsignedLongLong(CURSOR_INSERT_ID(self)); + if (!self->lastrow_id) + { + Py_INCREF(Py_None); + return Py_None; + } + return PyLong_FromUnsignedLongLong(self->lastrow_id); } /* iterator protocol */ diff --git a/testing/test/integration/test_connection.py b/testing/test/integration/test_connection.py index 7e09cb3..7ad06b8 100644 --- a/testing/test/integration/test_connection.py +++ b/testing/test/integration/test_connection.py @@ -25,7 +25,7 @@ def test_conpy36(self): default_conf = conf() try: conn= mariadb.connect(user=default_conf["user"], unix_socket="/does_not_exist/x.sock", port=default_conf["port"], host=default_conf["host"]) - except mariadb.DatabaseError: + except (mariadb.OperationalError,): pass def test_connection_default_file(self): @@ -59,7 +59,7 @@ def test_local_infile(self): cursor.execute("CREATE TEMPORARY TABLE t1 (a int)") try: cursor.execute("LOAD DATA LOCAL INFILE 'x.x' INTO TABLE t1") - except (mariadb.ProgrammingError, mariadb.DatabaseError): + except (mariadb.OperationalError,): pass del cursor del new_conn diff --git a/testing/test/integration/test_cursor.py b/testing/test/integration/test_cursor.py index ddca3cb..a8ab83b 100644 --- a/testing/test/integration/test_cursor.py +++ b/testing/test/integration/test_cursor.py @@ -166,13 +166,13 @@ def test_fetchmany(self): self.assertRaises(mariadb.Error, cursor.fetchall) cursor.execute("SELECT id, name, city FROM test_fetchmany ORDER BY id") - self.assertEqual(-1, cursor.rowcount) + self.assertEqual(0, cursor.rowcount) row = cursor.fetchall() self.assertEqual(row, params) self.assertEqual(5, cursor.rowcount) cursor.execute("SELECT id, name, city FROM test_fetchmany ORDER BY id") - self.assertEqual(-1, cursor.rowcount) + self.assertEqual(0, cursor.rowcount) row = cursor.fetchmany(1) self.assertEqual(row, [params[0]]) @@ -543,7 +543,7 @@ def test_conpy_15(self): cursor = self.connection.cursor() cursor.execute( "CREATE TEMPORARY TABLE test_conpy_15 (a int not null auto_increment primary key, b varchar(20))"); - self.assertEqual(cursor.lastrowid, 0) + self.assertEqual(cursor.lastrowid, None) cursor.execute("INSERT INTO test_conpy_15 VALUES (null, 'foo')") self.assertEqual(cursor.lastrowid, 1) cursor.execute("SELECT LAST_INSERT_ID()") @@ -570,7 +570,7 @@ def test_conpy_14(self): self.assertEqual(cursor.rowcount, -1) cursor.execute( "CREATE TEMPORARY TABLE test_conpy_14 (a int not null auto_increment primary key, b varchar(20))"); - self.assertEqual(cursor.rowcount, -1) + self.assertEqual(cursor.rowcount, 0) cursor.execute("INSERT INTO test_conpy_14 VALUES (null, 'foo')") self.assertEqual(cursor.rowcount, 1) vals = [(3, "bar"), (4, "this")] @@ -961,13 +961,13 @@ def test_conpy67(self): con= create_connection() cur = con.cursor() cur.execute("SELECT 1") - self.assertEqual(cur.rowcount, -1) + self.assertEqual(cur.rowcount, 0) cur.close() cur = con.cursor() cur.execute("CREATE TEMPORARY TABLE test_conpy67 (a int)") cur.execute("SELECT * from test_conpy67") - self.assertEqual(cur.rowcount, -1) + self.assertEqual(cur.rowcount, 0) cur.fetchall() self.assertEqual(cur.rowcount, 0) diff --git a/testing/test/integration/test_dbapi20.py b/testing/test/integration/test_dbapi20.py index 043f7ba..d496949 100644 --- a/testing/test/integration/test_dbapi20.py +++ b/testing/test/integration/test_dbapi20.py @@ -307,25 +307,19 @@ def test_rowcount(self): try: cur = con.cursor(buffered=True) self.executeDDL1(cur) - self.assertEqual(cur.rowcount, -1, - 'cursor.rowcount should be -1 after executing no-result ' + self.assertEqual(cur.rowcount, 0, + 'cursor.rowcount should be 0 after executing no-result ' 'statements' ) cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( self.table_prefix )) - self.assertTrue(cur.rowcount in (-1, 1), - 'cursor.rowcount should == number or rows inserted, or ' - 'set to -1 after executing an insert statement' - ) + self.assertEqual(cur.rowcount, 1) cur.execute("select name from %sbooze" % self.table_prefix) - self.assertTrue(cur.rowcount in (-1, 1), - 'cursor.rowcount should == number of rows returned, or ' - 'set to -1 after executing a select statement' - ) + self.assertEqual(cur.rowcount, 1) self.executeDDL2(cur) - self.assertEqual(cur.rowcount, -1, - 'cursor.rowcount not being reset to -1 after executing ' + self.assertEqual(cur.rowcount, 0, + 'cursor.rowcount not being reset to 0 after executing ' 'no-result statements' ) finally: diff --git a/testing/test/integration/test_exception.py b/testing/test/integration/test_exception.py index 7190002..6e6216e 100644 --- a/testing/test/integration/test_exception.py +++ b/testing/test/integration/test_exception.py @@ -41,7 +41,7 @@ def test_conn_timeout_exception(self): start = datetime.today() try: create_connection({"connect_timeout": 1, "host": "8.8.8.8"}) - except mariadb.DatabaseError as err: + except mariadb.OperationalError as err: self.assertEqual(err.sqlstate, "HY000") self.assertEqual(err.errno, 2002) self.assertTrue(err.errmsg.find("server on '8.8.8.8'") > -1) diff --git a/testing/test/integration/test_nondbapi.py b/testing/test/integration/test_nondbapi.py index 80fcbe1..9e25644 100644 --- a/testing/test/integration/test_nondbapi.py +++ b/testing/test/integration/test_nondbapi.py @@ -27,7 +27,7 @@ def test_ping(self): self.connection.kill(id) try: new_conn.ping() - except mariadb.DatabaseError: + except mariadb.InterfaceError: pass del new_conn new_conn = create_connection() @@ -83,7 +83,7 @@ def test_reset(self): cursor.execute("SELECT 1 UNION SELECT 2") try: self.connection.ping() - except mariadb.DatabaseError: + except mariadb.InterfaceError: pass self.connection.reset()