Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Python 2 crash when ASCII keyword values passed to connect.

MakeConnectionString assumed that values were already converted to Unicode, but this was not
the case.  Added TextCopyToUnicode to contain the difference which cleans the code up nicely.

Discovered while trying to reproduce Issue 223.
  • Loading branch information...
commit 9a32afdff1cd2c71ec11a9f499575af9838bbd4c 1 parent 83db883
@mkleehammer authored
Showing with 75 additions and 53 deletions.
  1. +24 −0 src/pyodbccompat.h
  2. +51 −53 src/pyodbcmodule.cpp
View
24 src/pyodbccompat.h
@@ -102,5 +102,29 @@ inline Py_ssize_t Text_Size(PyObject* o)
return (o && PyUnicode_Check(o)) ? PyUnicode_GET_SIZE(o) : 0;
}
+inline Py_ssize_t TextCopyToUnicode(Py_UNICODE* buffer, PyObject* o)
+{
+ // Copies a String or Unicode object to a Unicode buffer and returns the number of characters copied.
+ // No NULL terminator is appended!
+
+#if PY_MAJOR_VERSION < 3
+ if (PyBytes_Check(o))
+ {
+ const Py_ssize_t cch = PyBytes_GET_SIZE(o);
+ const char * pch = PyBytes_AS_STRING(o);
+ for (Py_ssize_t i = 0; i < cch; i++)
+ *buffer++ = (Py_UNICODE)*pch++;
+ return cch;
+ }
+ else
+ {
+#endif
+ Py_ssize_t cch = PyUnicode_GET_SIZE(o);
+ memcpy(buffer, PyUnicode_AS_UNICODE(o), cch * sizeof(Py_UNICODE));
+ return cch;
+#if PY_MAJOR_VERSION < 3
+ }
+#endif
+}
#endif // PYODBCCOMPAT_H
View
104 src/pyodbcmodule.cpp
@@ -3,7 +3,7 @@
// documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
// rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to
// permit persons to whom the Software is furnished to do so.
-//
+//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
// WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS
// OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
@@ -91,7 +91,7 @@ struct ExcInfo
#define MAKEEXCINFO(name, parent, doc) { #name, "pyodbc." #name, &name, &parent, doc }
static ExcInfo aExcInfos[] = {
- MAKEEXCINFO(Error, PyExc_Exception,
+ MAKEEXCINFO(Error, PyExc_Exception,
"Exception that is the base class of all other error exceptions. You can use\n"
"this to catch all errors with one single 'except' statement."),
MAKEEXCINFO(Warning, PyExc_Exception,
@@ -101,7 +101,7 @@ static ExcInfo aExcInfos[] = {
"Exception raised for errors that are related to the database interface rather\n"
"than the database itself."),
MAKEEXCINFO(DatabaseError, Error, "Exception raised for errors that are related to the database."),
- MAKEEXCINFO(DataError, DatabaseError,
+ MAKEEXCINFO(DataError, DatabaseError,
"Exception raised for errors that are due to problems with the processed data\n"
"like division by zero, numeric value out of range, etc."),
MAKEEXCINFO(OperationalError, DatabaseError,
@@ -168,12 +168,12 @@ static bool import_types()
// imported (among other problems).
PyObject* pdt = PyImport_ImportModule("datetime");
-
+
if (!pdt)
return false;
PyDateTime_IMPORT;
-
+
Cursor_init();
CnxnInfo_init();
GetData_init();
@@ -185,13 +185,13 @@ static bool import_types()
PyErr_SetString(PyExc_RuntimeError, "Unable to import decimal");
return false;
}
-
+
decimal_type = PyObject_GetAttrString(decimalmod, "Decimal");
Py_DECREF(decimalmod);
if (decimal_type == 0)
PyErr_SetString(PyExc_RuntimeError, "Unable to import decimal.Decimal.");
-
+
return decimal_type != 0;
}
@@ -210,7 +210,7 @@ static bool AllocateEnv()
return false;
}
}
-
+
if (!SQL_SUCCEEDED(SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &henv)))
{
Py_FatalError("Can't initialize module pyodbc. SQLAllocEnv failed.");
@@ -246,7 +246,7 @@ static keywordmap keywordmaps[] =
static PyObject* mod_connect(PyObject* self, PyObject* args, PyObject* kwargs)
{
UNUSED(self);
-
+
Object pConnectString = 0;
int fAutoCommit = 0;
int fAnsi = 0; // force ansi
@@ -289,7 +289,7 @@ static PyObject* mod_connect(PyObject* self, PyObject* args, PyObject* kwargs)
return PyErr_Format(PyExc_TypeError, "Dictionary items passed to connect must be strings");
// // Note: key and value are *borrowed*.
- //
+ //
// // Check for the two non-connection string keywords we accept. (If we get many more of these, create something
// // table driven. Are we sure there isn't a Python function to parse keywords but leave those it doesn't know?)
// const char* szKey = PyString_AsString(key);
@@ -316,7 +316,7 @@ static PyObject* mod_connect(PyObject* self, PyObject* args, PyObject* kwargs)
return 0;
continue;
}
-
+
// Map DB API recommended names to ODBC names (e.g. user --> uid).
for (size_t i = 0; i < _countof(keywordmaps); i++)
@@ -338,19 +338,20 @@ static PyObject* mod_connect(PyObject* self, PyObject* args, PyObject* kwargs)
PyObject* str = PyObject_Str(value); // convert if necessary
if (!str)
return 0;
-
+
if (PyDict_SetItem(partsdict.Get(), key, str) == -1)
{
Py_XDECREF(str);
return 0;
}
+
Py_XDECREF(str);
}
if (PyDict_Size(partsdict.Get()))
pConnectString.Attach(MakeConnectionString(pConnectString.Get(), partsdict));
}
-
+
if (!pConnectString.IsValid())
return PyErr_Format(PyExc_TypeError, "no connection information was passed");
@@ -359,7 +360,7 @@ static PyObject* mod_connect(PyObject* self, PyObject* args, PyObject* kwargs)
if (!AllocateEnv())
return 0;
}
-
+
return (PyObject*)Connection_New(pConnectString.Get(), fAutoCommit != 0, fAnsi != 0, fUnicodeResults != 0, timeout);
}
@@ -368,7 +369,7 @@ static PyObject*
mod_datasources(PyObject* self)
{
UNUSED(self);
-
+
if (henv == SQL_NULL_HANDLE && !AllocateEnv())
return 0;
@@ -392,17 +393,17 @@ mod_datasources(PyObject* self)
Py_END_ALLOW_THREADS
if (!SQL_SUCCEEDED(ret))
break;
-
+
PyDict_SetItemString(result, (const char*)szDSN, PyString_FromString((const char*)szDesc));
nDirection = SQL_FETCH_NEXT;
}
-
+
if (ret != SQL_NO_DATA)
{
Py_DECREF(result);
return RaiseErrorFromHandle("SQLDataSources", SQL_NULL_HANDLE, SQL_NULL_HANDLE);
}
-
+
return result;
}
@@ -410,11 +411,11 @@ mod_datasources(PyObject* self)
static PyObject* mod_timefromticks(PyObject* self, PyObject* args)
{
UNUSED(self);
-
+
PyObject* num;
if (!PyArg_ParseTuple(args, "O", &num))
return 0;
-
+
if (!PyNumber_Check(num))
return PyErr_Format(PyExc_TypeError, "TimeFromTicks requires a number.");
@@ -495,21 +496,21 @@ static char connect_doc[] =
" attribute of the connection. The default is 0 which means the database's\n"
" default timeout, if any, is used.\n";
-static char timefromticks_doc[] =
+static char timefromticks_doc[] =
"TimeFromTicks(ticks) --> datetime.time\n"
"\n"
"Returns a time object initialized from the given ticks value (number of seconds\n"
"since the epoch; see the documentation of the standard Python time module for\n"
"details).";
-static char datefromticks_doc[] =
+static char datefromticks_doc[] =
"DateFromTicks(ticks) --> datetime.date\n" \
"\n" \
"Returns a date object initialized from the given ticks value (number of seconds\n" \
"since the epoch; see the documentation of the standard Python time module for\n" \
"details).";
-static char timestampfromticks_doc[] =
+static char timestampfromticks_doc[] =
"TimestampFromTicks(ticks) --> datetime.datetime\n" \
"\n" \
"Returns a datetime object initialized from the given ticks value (number of\n" \
@@ -542,7 +543,7 @@ static PyObject* mod_drivers(PyObject* self, PyObject* args)
long ret = RegOpenKeyEx(HKEY_LOCAL_MACHINE, "SOFTWARE\\ODBC\\ODBCINST.INI\\ODBC Drivers", 0, KEY_QUERY_VALUE, &key.hkey);
if (ret != ERROR_SUCCESS)
return PyErr_Format(PyExc_RuntimeError, "Unable to access the driver list in the registry. error=%ld", ret);
-
+
Object results(PyList_New(0));
DWORD index = 0;
char name[255];
@@ -556,7 +557,7 @@ static PyObject* mod_drivers(PyObject* self, PyObject* args)
PyObject* oname = PyString_FromStringAndSize(name, (Py_ssize_t)length);
if (!oname)
return 0;
-
+
if (PyList_Append(results.Get(), oname) != 0)
{
Py_DECREF(oname);
@@ -917,12 +918,12 @@ initpyodbc(void)
#else
module.Attach(Py_InitModule4("pyodbc", pyodbc_methods, module_doc, NULL, PYTHON_API_VERSION));
#endif
-
+
pModule = module.Get();
if (!module || !import_types() || !CreateExceptions())
return MODRETURN(0);
-
+
init_locale_info();
const char* szVersion = TOSTRING(PYODBC_VERSION);
@@ -935,7 +936,7 @@ initpyodbc(void)
Py_INCREF(Py_True);
PyModule_AddObject(module, "lowercase", Py_False);
Py_INCREF(Py_False);
-
+
PyModule_AddObject(module, "Connection", (PyObject*)&ConnectionType);
Py_INCREF((PyObject*)&ConnectionType);
PyModule_AddObject(module, "Cursor", (PyObject*)&CursorType);
@@ -972,7 +973,7 @@ initpyodbc(void)
Py_INCREF(binary_type);
PyModule_AddObject(module, "Binary", binary_type);
Py_INCREF(binary_type);
-
+
PyModule_AddIntConstant(module, "UNICODE_SIZE", sizeof(Py_UNICODE));
PyModule_AddIntConstant(module, "SQLWCHAR_SIZE", sizeof(SQLWCHAR));
@@ -984,7 +985,7 @@ initpyodbc(void)
{
ErrorCleanup();
}
-
+
return MODRETURN(pModule);
}
@@ -1000,14 +1001,26 @@ BOOL WINAPI DllMain(
}
#endif
+
static PyObject* MakeConnectionString(PyObject* existing, PyObject* parts)
{
// Creates a connection string from an optional existing connection string plus a dictionary of keyword value
- // pairs. The keywords must be String or Unicode objects and the values must be Unicode objects.
-
- Py_ssize_t length = 0;
+ // pairs.
+ //
+ // existing
+ // Optional Unicode connection string we will be appending to. Used when a partial connection string is passed
+ // in, followed by keyword parameters:
+ //
+ // connect("driver={x};database={y}", user='z')
+ //
+ // parts
+ // A dictionary of text keywords and text values that will be appended.
+
+ I(PyUnicode_Check(existing));
+
+ Py_ssize_t length = 0; // length in *characters*
if (existing)
- length = Text_Size(existing) + 1; // + 1 to add a trailing
+ length = Text_Size(existing) + 1; // + 1 to add a trailing semicolon
Py_ssize_t pos = 0;
PyObject* key = 0;
@@ -1017,7 +1030,7 @@ static PyObject* MakeConnectionString(PyObject* existing, PyObject* parts)
{
length += Text_Size(key) + 1 + Text_Size(value) + 1; // key=value;
}
-
+
PyObject* result = PyUnicode_FromUnicode(0, length);
if (!result)
return 0;
@@ -1027,32 +1040,17 @@ static PyObject* MakeConnectionString(PyObject* existing, PyObject* parts)
if (existing)
{
- memcpy(&buffer[offset], PyUnicode_AS_UNICODE(existing), PyUnicode_GET_SIZE(existing) * sizeof(Py_UNICODE));
- offset += PyUnicode_GET_SIZE(existing);
+ offset += TextCopyToUnicode(&buffer[offset], existing);
buffer[offset++] = (Py_UNICODE)';';
}
- Object okey;
-
pos = 0;
while (PyDict_Next(parts, &pos, &key, &value))
{
-#if PY_MAJOR_VERSION < 3
- if (PyBytes_Check(key))
- {
- okey = PyUnicode_FromString(PyBytes_AS_STRING(key));
- key = okey.Get();
- }
-#endif
-
- memcpy(&buffer[offset], PyUnicode_AS_UNICODE(key), PyUnicode_GET_SIZE(key) * sizeof(Py_UNICODE));
- offset += PyUnicode_GET_SIZE(key);
-
+ offset += TextCopyToUnicode(&buffer[offset], key);
buffer[offset++] = (Py_UNICODE)'=';
- memcpy(&buffer[offset], PyUnicode_AS_UNICODE(value), PyUnicode_GET_SIZE(value) * sizeof(Py_UNICODE));
- offset += PyUnicode_GET_SIZE(value);
-
+ offset += TextCopyToUnicode(&buffer[offset], value);
buffer[offset++] = (Py_UNICODE)';';
}
Please sign in to comment.
Something went wrong with that request. Please try again.