Permalink
Browse files

Implemented Row comparisons.

  • Loading branch information...
Michael Kleehammer authored and mkleehammer committed Aug 30, 2010
1 parent c660885 commit d89d7e97c41984ee4d005f06893fbe0d168e5a50
Showing with 71 additions and 2 deletions.
  1. +49 −1 src/row.cpp
  2. +22 −1 tests/sqlservertests.py
View
@@ -244,7 +244,55 @@ Row_repr(PyObject* o)
return result.Detach();
}
+static PyObject* Row_richcompare(PyObject* olhs, PyObject* orhs, int op)
+{
+ if (!Row_Check(olhs) || !Row_Check(orhs))
+ {
+ Py_INCREF(Py_NotImplemented);
+ return Py_NotImplemented;
+ }
+
+ Row* lhs = (Row*)olhs;
+ Row* rhs = (Row*)orhs;
+
+ if (lhs->cValues != rhs->cValues)
+ {
+ // Different sizes, so use the same rules as the tuple class.
+ bool result;
+ switch (op)
+ {
+ case Py_EQ: result = (lhs->cValues == rhs->cValues); break;
+ case Py_GE: result = (lhs->cValues >= rhs->cValues); break;
+ case Py_GT: result = (lhs->cValues > rhs->cValues); break;
+ case Py_LE: result = (lhs->cValues <= rhs->cValues); break;
+ case Py_LT: result = (lhs->cValues < rhs->cValues); break;
+ case Py_NE: result = (lhs->cValues != rhs->cValues); break;
+ }
+ PyObject* p = result ? Py_True : Py_False;
+ Py_INCREF(p);
+ return p;
+ }
+
+ for (Py_ssize_t i = 0, c = lhs->cValues; i < c; i++)
+ if (!PyObject_RichCompareBool(lhs->apValues[i], rhs->apValues[i], Py_EQ))
+ return PyObject_RichCompare(lhs->apValues[i], rhs->apValues[i], op);
+
+ // All items are equal.
+ switch (op)
+ {
+ case Py_EQ:
+ case Py_GE:
+ case Py_LE:
+ Py_RETURN_TRUE;
+
+ case Py_GT:
+ case Py_LT:
+ case Py_NE:
+ break;
+ }
+ Py_RETURN_FALSE;
+}
static PySequenceMethods row_as_sequence =
{
@@ -318,7 +366,7 @@ PyTypeObject RowType =
row_doc, // tp_doc
0, // tp_traverse
0, // tp_clear
- 0, // tp_richcompare
+ Row_richcompare, // tp_richcompare
0, // tp_weaklistoffset
0, // tp_iter
0, // tp_iternext
View
@@ -1134,11 +1134,32 @@ def convert(value):
self.assertEqual(value, None)
self.cnxn.clear_output_converters()
-
def test_login_timeout(self):
# This can only test setting since there isn't a way to cause it to block on the server side.
cnxns = pyodbc.connect(self.connection_string, timeout=2)
+ def test_row_equal(self):
+ self.cursor.execute("create table t1(n int, s varchar(20))")
+ self.cursor.execute("insert into t1 values (1, 'test')")
+ row1 = self.cursor.execute("select n, s from t1").fetchone()
+ row2 = self.cursor.execute("select n, s from t1").fetchone()
+ b = (row1 == row2)
+ self.assertEqual(b, True)
+
+ def test_row_gtlt(self):
+ self.cursor.execute("create table t1(n int, s varchar(20))")
+ self.cursor.execute("insert into t1 values (1, 'test1')")
+ self.cursor.execute("insert into t1 values (1, 'test2')")
+ rows = self.cursor.execute("select n, s from t1 order by s").fetchall()
+ self.assert_(rows[0] < rows[1])
+ self.assert_(rows[0] <= rows[1])
+ self.assert_(rows[1] > rows[0])
+ self.assert_(rows[1] >= rows[0])
+ self.assert_(rows[0] != rows[1])
+
+ rows = list(rows)
+ rows.sort() # uses <
+
def main():
from optparse import OptionParser

0 comments on commit d89d7e9

Please sign in to comment.