Skip to content
Browse files

Added Cursor.commit and Cursor.rollback

  • Loading branch information...
1 parent 244e84a commit b456b9f0cfd88526d6919784246a596f04df7313 Michael Kleehammer committed
Showing with 71 additions and 13 deletions.
  1. +26 −13 src/connection.cpp
  2. +5 −0 src/connection.h
  3. +26 −0 src/cursor.cpp
  4. +14 −0 tests2/sqlservertests.py
View
39 src/connection.cpp
@@ -616,24 +616,21 @@ static PyObject* Connection_getinfo(PyObject* self, PyObject* args)
return result;
}
-
-static PyObject* Connection_endtrans(PyObject* self, PyObject* args, SQLSMALLINT type)
+PyObject* Connection_endtrans(Connection* cnxn, SQLSMALLINT type)
{
- UNUSED(args);
-
- Connection* cnxn = Connection_Validate(self);
- if (!cnxn)
- return 0;
-
- TRACE("%s: cnxn=%p hdbc=%d\n", (type == SQL_COMMIT) ? "commit" : "rollback", cnxn, cnxn->hdbc);
+ // If called from Cursor.commit, it is possible that `cnxn` is deleted by another thread when we release them
+ // below. (The cursor has had its reference incremented by the method it is calling, but nothing has incremented
+ // the connections count. We could, but we really only need the HDBC.)
+ HDBC hdbc = cnxn->hdbc;
SQLRETURN ret;
Py_BEGIN_ALLOW_THREADS
- ret = SQLEndTran(SQL_HANDLE_DBC, cnxn->hdbc, type);
+ ret = SQLEndTran(SQL_HANDLE_DBC, hdbc, type);
Py_END_ALLOW_THREADS
+
if (!SQL_SUCCEEDED(ret))
{
- RaiseErrorFromHandle("SQLEndTran", cnxn->hdbc, SQL_NULL_HANDLE);
+ RaiseErrorFromHandle("SQLEndTran", hdbc, SQL_NULL_HANDLE);
return 0;
}
@@ -642,12 +639,28 @@ static PyObject* Connection_endtrans(PyObject* self, PyObject* args, SQLSMALLINT
static PyObject* Connection_commit(PyObject* self, PyObject* args)
{
- return Connection_endtrans(self, args, SQL_COMMIT);
+ UNUSED(args);
+
+ Connection* cnxn = Connection_Validate(self);
+ if (!cnxn)
+ return 0;
+
+ TRACE("commit: cnxn=%p hdbc=%d\n", cnxn, cnxn->hdbc);
+
+ return Connection_endtrans(cnxn, SQL_COMMIT);
}
static PyObject* Connection_rollback(PyObject* self, PyObject* args)
{
- return Connection_endtrans(self, args, SQL_ROLLBACK);
+ UNUSED(args);
+
+ Connection* cnxn = Connection_Validate(self);
+ if (!cnxn)
+ return 0;
+
+ TRACE("rollback: cnxn=%p hdbc=%d\n", cnxn, cnxn->hdbc);
+
+ return Connection_endtrans(cnxn, SQL_ROLLBACK);
}
static char cursor_doc[] =
View
5 src/connection.h
@@ -73,4 +73,9 @@ struct Connection
*/
PyObject* Connection_New(PyObject* pConnectString, bool fAutoCommit, bool fAnsi, bool fUnicodeResults, long timeout, bool fReadOnly);
+/*
+ * Used by the Cursor to implement commit and rollback.
+ */
+PyObject* Connection_endtrans(Connection* cnxn, SQLSMALLINT type);
+
#endif
View
26 src/cursor.cpp
@@ -1874,6 +1874,30 @@ static PyObject* Cursor_skip(PyObject* self, PyObject* args)
Py_RETURN_NONE;
}
+static const char* commit_doc =
+ "Commits any pending transaction to the database on the current connection,\n"
+ "including those from other cursors.\n";
+
+static PyObject* Cursor_commit(PyObject* self, PyObject* args)
+{
+ Cursor* cur = Cursor_Validate(self, CURSOR_REQUIRE_OPEN | CURSOR_RAISE_ERROR);
+ if (!cur)
+ return 0;
+ return Connection_endtrans(cur->cnxn, SQL_COMMIT);
+}
+
+static char rollback_doc[] =
+ "Rolls back any pending transaction to the database on the current connection,\n"
+ "including those from other cursors.\n";
+
+static PyObject* Cursor_rollback(PyObject* self, PyObject* args)
+{
+ Cursor* cur = Cursor_Validate(self, CURSOR_REQUIRE_OPEN | CURSOR_RAISE_ERROR);
+ if (!cur)
+ return 0;
+ return Connection_endtrans(cur->cnxn, SQL_ROLLBACK);
+}
+
static PyObject* Cursor_ignored(PyObject* self, PyObject* args)
{
@@ -2052,6 +2076,8 @@ static PyMethodDef Cursor_methods[] =
{ "procedures", (PyCFunction)Cursor_procedures, METH_VARARGS|METH_KEYWORDS, procedures_doc },
{ "procedureColumns", (PyCFunction)Cursor_procedureColumns, METH_VARARGS|METH_KEYWORDS, procedureColumns_doc },
{ "skip", (PyCFunction)Cursor_skip, METH_VARARGS, skip_doc },
+ { "commit", (PyCFunction)Cursor_commit, METH_NOARGS, commit_doc },
+ { "rollback", (PyCFunction)Cursor_rollback, METH_NOARGS, rollback_doc },
{ 0, 0, 0, 0 }
};
View
14 tests2/sqlservertests.py
@@ -1015,6 +1015,20 @@ def test_autocommit(self):
othercnxn.autocommit = False
self.assertEqual(othercnxn.autocommit, False)
+ def test_cursorcommit(self):
+ "Ensure cursor.commit works"
+ othercnxn = pyodbc.connect(self.connection_string)
+ othercursor = othercnxn.cursor()
+ othercnxn = None
+
+ othercursor.execute("create table t1(s varchar(20))")
+ othercursor.execute("insert into t1 values(?)", 'test')
+ othercursor.commit()
+
+ value = self.cursor.execute("select s from t1").fetchone()[0]
+ self.assertEqual(value, 'test')
+
+
def test_unicode_results(self):
"Ensure unicode_results forces Unicode"
othercnxn = pyodbc.connect(self.connection_string, unicode_results=True)

0 comments on commit b456b9f

Please sign in to comment.
Something went wrong with that request. Please try again.