Permalink
Browse files

Cursor.executemany now accepts an iterator or a generator. closes #12

  • Loading branch information...
1 parent 6463a91 commit e543fe249ff6a7916365dc1cd86660524b04aab5 @mkleehammer committed Sep 27, 2012
Showing with 132 additions and 29 deletions.
  1. +79 −19 src/cursor.cpp
  2. +4 −1 src/wrapper.h
  3. +49 −9 tests2/pgtests.py
View
@@ -25,6 +25,7 @@
#include "dbspecific.h"
#include "sqlwchar.h"
#include <datetime.h>
+#include "wrapper.h"
enum
{
@@ -959,32 +960,91 @@ static PyObject* Cursor_executemany(PyObject* self, PyObject* args)
return 0;
}
- if (!IsSequence(param_seq))
+ if (IsSequence(param_seq))
{
- PyErr_SetString(ProgrammingError, "The second parameter to executemany must be a sequence.");
- return 0;
- }
+ Py_ssize_t c = PySequence_Size(param_seq);
- Py_ssize_t c = PySequence_Size(param_seq);
+ if (c == 0)
+ {
+ PyErr_SetString(ProgrammingError, "The second parameter to executemany must not be empty.");
+ return 0;
+ }
- if (c == 0)
- {
- PyErr_SetString(ProgrammingError, "The second parameter to executemany must not be empty.");
- return 0;
+ for (Py_ssize_t i = 0; i < c; i++)
+ {
+ PyObject* params = PySequence_GetItem(param_seq, i);
+ PyObject* result = execute(cursor, pSql, params, false);
+ bool success = result != 0;
+ Py_XDECREF(result);
+ Py_DECREF(params);
+ if (!success)
+ {
+ cursor->rowcount = -1;
+ return 0;
+ }
+ }
}
-
- for (Py_ssize_t i = 0; i < c; i++)
+ else if (PyGen_Check(param_seq) || PyIter_Check(param_seq))
{
- PyObject* params = PySequence_GetItem(param_seq, i);
- PyObject* result = execute(cursor, pSql, params, false);
- bool success = result != 0;
- Py_XDECREF(result);
- Py_DECREF(params);
- if (!success)
+ Object iter;
+
+ if (PyGen_Check(param_seq))
{
- cursor->rowcount = -1;
- return 0;
+ iter = PyObject_GetIter(param_seq);
}
+ else
+ {
+ iter = param_seq;
+ Py_INCREF(param_seq);
+ }
+
+ Object params;
+
+ // If the iterator/generator returns a single non-sequence object, we'll supply a temporary tuple so execute
+ // doesn't have to deal with single items.
+ Object tmptuple;
+
+ while (params.Attach(PyIter_Next(iter)))
+ {
+ PyObject* param = 0;
+
+ if (!IsSequence(params))
+ {
+ if (!tmptuple.IsValid())
+ {
+ tmptuple.Attach(PyTuple_New(1));
+ if (!tmptuple.IsValid())
+ return 0;
+ }
+ PyTuple_SetItem(tmptuple, 0, params);
+
+ param = tmptuple;
+ }
+ else
+ {
+ param = params;
+ }
+
+ PyObject* result = execute(cursor, pSql, param, false);
+ bool success = result != 0;
+ Py_XDECREF(result);
+
+ PyTuple_SetItem(tmptuple, 0, 0);
+
+ if (!success)
+ {
+ cursor->rowcount = -1;
+ return 0;
+ }
+ }
+
+ if (PyErr_Occurred())
+ return 0;
+ }
+ else
+ {
+ PyErr_SetString(ProgrammingError, "The second parameter to executemany must be a sequence, iterator, or generator.");
+ return 0;
}
cursor->rowcount = -1;
View
@@ -31,10 +31,13 @@ class Object
bool IsValid() const { return p != 0; }
- void Attach(PyObject* _p)
+ bool Attach(PyObject* _p)
{
+ // Returns true if the new pointer is non-zero.
+
Py_XDECREF(p);
p = _p;
+ return (_p != 0);
}
PyObject* Detach()
View
@@ -57,7 +57,7 @@ def setUp(self):
self.cnxn.commit()
except:
pass
-
+
self.cnxn.rollback()
@@ -192,7 +192,7 @@ def test_negative_decimal_scale(self):
def _exec(self):
self.cursor.execute(self.sql)
-
+
def test_close_cnxn(self):
"""Make sure using a Cursor after closing its connection doesn't crash."""
@@ -201,7 +201,7 @@ def test_close_cnxn(self):
self.cursor.execute("select * from t1")
self.cnxn.close()
-
+
# Now that the connection is closed, we expect an exception. (If the code attempts to use
# the HSTMT, we'll get an access violation instead.)
self.sql = "select * from t1"
@@ -263,13 +263,13 @@ def test_rowcount_select(self):
# PostgreSQL driver fails here?
# def test_rowcount_reset(self):
# "Ensure rowcount is reset to -1"
- #
+ #
# self.cursor.execute("create table t1(i int)")
# count = 4
# for i in range(count):
# self.cursor.execute("insert into t1 values (?)", i)
# self.assertEquals(self.cursor.rowcount, 1)
- #
+ #
# self.cursor.execute("create table t2(i int)")
# self.assertEquals(self.cursor.rowcount, -1)
@@ -291,7 +291,7 @@ def test_lower_case(self):
# Put it back so other tests don't fail.
pyodbc.lowercase = False
-
+
def test_row_description(self):
"""
Ensure Cursor.description is accessible as Row.cursor_description.
@@ -303,7 +303,7 @@ def test_row_description(self):
row = self.cursor.execute("select * from t1").fetchone()
self.assertEquals(self.cursor.description, row.cursor_description)
-
+
def test_executemany(self):
self.cursor.execute("create table t1(a int, b varchar(10))")
@@ -336,10 +336,50 @@ def test_executemany_failure(self):
params = [ (1, 'good'),
('error', 'not an int'),
(3, 'good') ]
-
+
self.failUnlessRaises(pyodbc.Error, self.cursor.executemany, "insert into t1(a, b) value (?, ?)", params)
-
+
+ def test_executemany_generator_scalar(self):
+ self.cursor.execute("create table t1(a int)")
+
+ self.cursor.executemany("insert into t1(a) values (?)", (i for i in range(4)))
+
+ row = self.cursor.execute("select min(a) mina, max(a) maxa from t1").fetchone()
+
+ self.assertEqual(row.mina, 0)
+ self.assertEqual(row.maxa, 3)
+
+
+
+ def test_executemany_generator_tuple(self):
+ self.cursor.execute("create table t1(a int)")
+
+ self.cursor.executemany("insert into t1(a) values (?)", ((i,) for i in range(4)))
+
+ row = self.cursor.execute("select min(a) mina, max(a) maxa from t1").fetchone()
+
+ self.assertEqual(row.mina, 0)
+ self.assertEqual(row.maxa, 3)
+
+
+ def test_executemany_iterator(self):
+ self.cursor.execute("create table t1(a int)")
+
+ values = [ 0, 1, 2, 3 ]
+
+ x = iter(values)
+ print('x:', x, type(x))
+
+ self.cursor.executemany("insert into t1(a) values (?)", x)
+
+ row = self.cursor.execute("select min(a) mina, max(a) maxa from t1").fetchone()
+
+ self.assertEqual(row.mina, 0)
+ self.assertEqual(row.maxa, 3)
+
+
+
def test_row_slicing(self):
self.cursor.execute("create table t1(a int, b int, c int, d int)");
self.cursor.execute("insert into t1 values(1,2,3,4)")

0 comments on commit e543fe2

Please sign in to comment.