Permalink
Browse files

Added user-defined conversions to the connection.

  • Loading branch information...
1 parent f85004f commit d19c67b66ca19dbc242085591c39115e05374dee Michael Kleehammer committed with Aug 30, 2010
Showing with 220 additions and 17 deletions.
  1. +135 −15 src/connection.cpp
  2. +11 −0 src/connection.h
  3. +6 −2 src/cursor.cpp
  4. +34 −0 src/getdata.cpp
  5. +6 −0 src/getdata.h
  6. +28 −0 tests/sqlservertests.py
View
@@ -177,6 +177,9 @@ PyObject* Connection_New(PyObject* pConnectString, bool fAutoCommit, bool fAnsi,
cnxn->searchescape = 0;
cnxn->timeout = 0;
cnxn->unicode_results = fUnicodeResults;
+ cnxn->conv_count = 0;
+ cnxn->conv_types = 0;
+ cnxn->conv_funcs = 0;
//
// Initialize autocommit mode.
@@ -227,6 +230,33 @@ PyObject* Connection_New(PyObject* pConnectString, bool fAutoCommit, bool fAnsi,
return reinterpret_cast<PyObject*>(cnxn);
}
+static void _clear_conv(Connection* cnxn)
+{
+ if (cnxn->conv_count != 0)
+ {
+ free(cnxn->conv_types);
+ cnxn->conv_types = 0;
+
+ for (int i = 0; i < cnxn->conv_count; i++)
+ Py_XDECREF(cnxn->conv_funcs[i]);
+ free(cnxn->conv_funcs);
+ cnxn->conv_funcs = 0;
+
+ cnxn->conv_count = 0;
+ }
+}
+
+static char conv_clear_doc[] =
+ "clear_output_converters() --> None\n\n"
+ "Remove all output converter functions.";
+
+static PyObject*
+Connection_conv_clear(Connection* cnxn)
+{
+ _clear_conv(cnxn);
+
+ Py_RETURN_NONE;
+}
static int
Connection_clear(Connection* cnxn)
@@ -254,6 +284,8 @@ Connection_clear(Connection* cnxn)
Py_XDECREF(cnxn->searchescape);
cnxn->searchescape = 0;
+ _clear_conv(cnxn);
+
return 0;
}
@@ -585,17 +617,16 @@ Connection_rollback(PyObject* self, PyObject* args)
}
static char cursor_doc[] =
- "Return a new Cursor Object using the connection.";
+ "Return a new Cursor object using the connection.";
static char execute_doc[] =
- "execute(sql, [params]) --> None | Cursor | count\n" \
- "\n" \
- "Creates a new Cursor object, calls its execute method, and returns its return\n" \
- "value. See Cursor.execute for a description of the parameter formats and\n" \
- "return values.\n" \
- "\n" \
- "This is a convenience method that is not part of the DB API. Since a new\n" \
- "Cursor is allocated by each call, this should not be used if more than one SQL\n" \
+ "execute(sql, [params]) --> Cursor\n"
+ "\n"
+ "Create a new Cursor object, call its execute method, and return it. See\n"
+ "Cursor.execute for more details.\n"
+ "\n"
+ "This is a convenience method that is not part of the DB API. Since a new\n"
+ "Cursor is allocated by each call, this should not be used if more than one SQL\n"
"statement needs to be executed.";
static char commit_doc[] =
@@ -731,14 +762,103 @@ Connection_settimeout(PyObject* self, PyObject* value, void* closure)
return 0;
}
+static bool _add_converter(Connection* cnxn, int sqltype, PyObject* func)
+{
+ if (cnxn->conv_count)
+ {
+ // If the sqltype is already registered, replace the old conversion function with the new.
+ for (int i = 0; i < cnxn->conv_count; i++)
+ {
+ if (cnxn->conv_types[i] == sqltype)
+ {
+ Py_XDECREF(cnxn->conv_funcs[i]);
+ cnxn->conv_funcs[i] = func;
+ Py_INCREF(func);
+ return true;
+ }
+ }
+ }
+
+ int oldcount = cnxn->conv_count;
+ SQLSMALLINT* oldtypes = cnxn->conv_types;
+ PyObject** oldfuncs = cnxn->conv_funcs;
+
+ int newcount = oldcount + 1;
+ SQLSMALLINT* newtypes = (SQLSMALLINT*)malloc(sizeof(SQLSMALLINT) * newcount);
+ PyObject** newfuncs = (PyObject**)malloc(sizeof(PyObject*) * newcount);
+
+ if (newtypes == 0 || newfuncs == 0)
+ {
+ if (newtypes)
+ free(newtypes);
+ if (newfuncs)
+ free(newfuncs);
+ PyErr_NoMemory();
+ return false;
+ }
+
+ newtypes[0] = sqltype;
+ newfuncs[0] = func;
+ Py_INCREF(func);
+
+ cnxn->conv_count = newcount;
+ cnxn->conv_types = newtypes;
+ cnxn->conv_funcs = newfuncs;
+
+ if (oldcount != 0)
+ {
+ // copy old items
+ memcpy(&newtypes[1], oldtypes, sizeof(int) * oldcount);
+ memcpy(&newfuncs[1], oldfuncs, sizeof(PyObject*) * oldcount);
+
+ free(oldtypes);
+ free(oldfuncs);
+ }
+
+ return true;
+}
+
+static char conv_add_doc[] =
+ "add_output_converter(sqltype, func) --> None\n"
+ "\n"
+ "Register an output converter function that will be called whenever a value with\n"
+ "the given SQL type is read from the database.\n"
+ "\n"
+ "sqltype\n"
+ " The integer SQL type value to convert, which can be one of the defined\n"
+ " standard constants (e.g. pyodbc.SQL_VARCHAR) or a database-specific value\n"
+ " (e.g. -151 for the SQL Server 2008 geometry data type).\n"
+ "\n"
+ "func\n"
+ " The converter function which will be called with a single parameter, the\n"
+ " value, and should return the converted value. If the value is NULL, the\n"
+ " parameter will be None. Otherwise it will be a Python string.";
+
+
+static PyObject*
+Connection_conv_add(Connection* cnxn, PyObject* args)
+{
+ int sqltype;
+ PyObject* func;
+ if (!PyArg_ParseTuple(args, "iO", &sqltype, &func))
+ return 0;
+
+ if (!_add_converter(cnxn, sqltype, func))
+ return 0;
+
+ Py_RETURN_NONE;
+}
+
static struct PyMethodDef Connection_methods[] =
{
- { "cursor", (PyCFunction)Connection_cursor, METH_NOARGS, cursor_doc },
- { "close", (PyCFunction)Connection_close, METH_NOARGS, close_doc },
- { "execute", (PyCFunction)Connection_execute, METH_VARARGS, execute_doc },
- { "commit", (PyCFunction)Connection_commit, METH_NOARGS, commit_doc },
- { "rollback", (PyCFunction)Connection_rollback, METH_NOARGS, rollback_doc },
- { "getinfo", (PyCFunction)Connection_getinfo, METH_VARARGS, getinfo_doc },
+ { "cursor", (PyCFunction)Connection_cursor, METH_NOARGS, cursor_doc },
+ { "close", (PyCFunction)Connection_close, METH_NOARGS, close_doc },
+ { "execute", (PyCFunction)Connection_execute, METH_VARARGS, execute_doc },
+ { "commit", (PyCFunction)Connection_commit, METH_NOARGS, commit_doc },
+ { "rollback", (PyCFunction)Connection_rollback, METH_NOARGS, rollback_doc },
+ { "getinfo", (PyCFunction)Connection_getinfo, METH_VARARGS, getinfo_doc },
+ { "add_output_converter", (PyCFunction)Connection_conv_add, METH_VARARGS, conv_add_doc },
+ { "clear_output_converters", (PyCFunction)Connection_conv_clear, METH_NOARGS, conv_clear_doc },
{ 0, 0, 0, 0 }
};
View
@@ -51,6 +51,17 @@ struct Connection
int varchar_maxlength;
int wvarchar_maxlength;
int binary_maxlength;
+
+ // Output conversions. Maps from SQL type in conv_types to the converter function in conv_funcs.
+ //
+ // If conv_count is zero, conv_types and conv_funcs will also be zero.
+ //
+ // pyodbc uses this manual mapping for speed and portability. The STL collection classes use the new operator and
+ // throw exceptions when out of memory. pyodbc does not use any exceptions.
+
+ int conv_count; // how many items are in conv_types and conv_funcs.
+ SQLSMALLINT* conv_types; // array of SQL_TYPEs to convert
+ PyObject** conv_funcs; // array of Python functions
};
#define Connection_Check(op) PyObject_TypeCheck(op, &ConnectionType)
View
@@ -128,7 +128,7 @@ inline bool IsNumericType(SQLSMALLINT sqltype)
PyObject*
-PythonTypeFromSqlType(const SQLCHAR* name, SQLSMALLINT type, bool unicode_results)
+PythonTypeFromSqlType(Cursor* cur, const SQLCHAR* name, SQLSMALLINT type, bool unicode_results)
{
// Returns a type object ('int', 'str', etc.) for the given ODBC C type. This is used to populate
// Cursor.description with the type of Python object that will be returned for each column.
@@ -141,6 +141,10 @@ PythonTypeFromSqlType(const SQLCHAR* name, SQLSMALLINT type, bool unicode_result
//
// The returned object does not have its reference count incremented!
+ int conv_index = GetUserConvIndex(cur, type);
+ if (conv_index != -1)
+ return (PyObject*)&PyString_Type;
+
PyObject* pytype = 0;
switch (type)
@@ -273,7 +277,7 @@ create_name_map(Cursor* cur, SQLSMALLINT field_count, bool lower)
if (lower)
_strlwr((char*)name);
- type = PythonTypeFromSqlType(name, nDataType, cur->cnxn->unicode_results);
+ type = PythonTypeFromSqlType(cur, name, nDataType, cur->cnxn->unicode_results);
if (!type)
goto done;
View
@@ -347,6 +347,23 @@ GetDataString(Cursor* cur, Py_ssize_t iCol)
return 0;
}
+
+static PyObject*
+GetDataUser(Cursor* cur, Py_ssize_t iCol, int conv)
+{
+ // conv
+ // The index into the connection's user-defined conversions `conv_types`.
+
+ PyObject* value = GetDataString(cur, iCol);
+ if (value == 0)
+ return 0;
+
+ PyObject* result = PyObject_CallFunction(cur->cnxn->conv_funcs[conv], "(O)", value);
+ Py_DECREF(value);
+ return result;
+}
+
+
static PyObject*
GetDataBuffer(Cursor* cur, Py_ssize_t iCol)
{
@@ -568,6 +585,17 @@ GetDataTimestamp(Cursor* cur, Py_ssize_t iCol)
return PyDateTime_FromDateAndTime(value.year, value.month, value.day, value.hour, value.minute, value.second, micros);
}
+int GetUserConvIndex(Cursor* cur, SQLSMALLINT sql_type)
+{
+ // If this sql type has a user-defined conversion, the index into the connection's `conv_funcs` array is returned.
+ // Otherwise -1 is returned.
+
+ for (int i = 0; i < cur->cnxn->conv_count; i++)
+ if (cur->cnxn->conv_types[i] == sql_type)
+ return i;
+ return -1;
+}
+
PyObject*
GetData(Cursor* cur, Py_ssize_t iCol)
@@ -578,6 +606,12 @@ GetData(Cursor* cur, Py_ssize_t iCol)
ColumnInfo* pinfo = &cur->colinfos[iCol];
+ // First see if there is a user-defined conversion.
+
+ int conv_index = GetUserConvIndex(cur, pinfo->sql_type);
+ if (conv_index != -1)
+ return GetDataUser(cur, iCol, conv_index);
+
switch (pinfo->sql_type)
{
case SQL_WCHAR:
View
@@ -6,4 +6,10 @@ void GetData_init();
PyObject* GetData(Cursor* cur, Py_ssize_t iCol);
+/**
+ * If this sql type has a user-defined conversion, the index into the connection's `conv_funcs` array is returned.
+ * Otherwise -1 is returned.
+ */
+int GetUserConvIndex(Cursor* cur, SQLSMALLINT sql_type);
+
#endif // _GETDATA_H_
View
@@ -1107,6 +1107,34 @@ def test_none_param(self):
self.assertEqual(row.blob, None)
+ def test_output_conversion(self):
+ def convert(value):
+ # `value` will be a string. We'll simply add an X at the beginning at the end.
+ return 'X' + value + 'X'
+ self.cnxn.add_output_converter(pyodbc.SQL_VARCHAR, convert)
+ self.cursor.execute("create table t1(n int, v varchar(10))")
+ self.cursor.execute("insert into t1 values (1, '123.45')")
+ value = self.cursor.execute("select v from t1").fetchone()[0]
+ self.assertEqual(value, 'X123.45X')
+
+ # Now clear the conversions and try again. There should be no Xs this time.
+ self.cnxn.clear_output_converters()
+ value = self.cursor.execute("select v from t1").fetchone()[0]
+ self.assertEqual(value, '123.45')
+
+
+ def test_geometry_null_insert(self):
+ def convert(value):
+ return value
+
+ self.cnxn.add_output_converter(-151, convert) # -151 is SQL Server's geometry
+ self.cursor.execute("create table t1(n int, v geometry)")
+ self.cursor.execute("insert into t1 values (?, ?)", 1, None)
+ value = self.cursor.execute("select v from t1").fetchone()[0]
+ self.assertEqual(value, None)
+ self.cnxn.clear_output_converters()
+
+
def main():
from optparse import OptionParser
parser = OptionParser(usage=usage)

0 comments on commit d19c67b

Please sign in to comment.