Permalink
Browse files

Allocated own string for decimals.

Was using str(decimal), but it can return items in exponential notation, which I'm not sure
many databases would handle properly.
  • Loading branch information...
mkleehammer committed Sep 6, 2010
1 parent a64f284 commit 842aa9772c9cb3393fc08080233e801c61dc6304
Showing with 109 additions and 31 deletions.
  1. +4 −10 src/cursor.h
  2. +97 −21 src/params.cpp
  3. +8 −0 tests/sqlservertests.py
View
@@ -42,21 +42,15 @@ struct ParamInfo
SQLULEN ColumnSize;
SQLSMALLINT DecimalDigits;
- // If it is possible to bind into the parameter itself (ANSI string), this points into the Python object and must
- // not be modified or freed. Otherwise, this is memory allocated with 'malloc' specifically for this parameter and
- // should be freed after the call.
+ // The value pointer that will be bound. If `alloc` is true, this was allocated with malloc and must be freed.
+ // Otherwise it is zero or points into memory owned by the original Python parameter.
SQLPOINTER ParameterValuePtr;
SQLLEN BufferLength;
SQLLEN StrLen_or_Ind;
- // Optional Python object if we converted from the original parameter to one (e.g. Decimal to String). If
- // non-zero, ParameterValuePtr will point into this and allocated will be zero.
- PyObject* temp;
-
- // The amount of memory allocated (bytes) if binding into the original parameter or a temporary Python object was
- // not posssible. Otherwise zero.
- Py_ssize_t allocated;
+ // If true, the memory in ParameterValuePtr was allocated via malloc and must be freed.
+ bool allocated;
// Optional data. If used, ParameterValuePtr will point into this.
union
View
@@ -20,16 +20,8 @@ static bool GetParamType(Cursor* cur, Py_ssize_t iParam, SQLSMALLINT& type);
static void FreeInfos(ParamInfo* a, Py_ssize_t count)
{
for (Py_ssize_t i = 0; i < count; i++)
- {
- if (a[i].temp != 0)
- {
- Py_DECREF(a[i].temp);
- }
- else if (a[i].allocated != 0)
- {
+ if (a[i].allocated)
pyodbc_free(a[i].ParameterValuePtr);
- }
- }
pyodbc_free(a);
}
@@ -133,7 +125,6 @@ static bool GetStringInfo(Cursor* cur, Py_ssize_t index, PyObject* param, ParamI
if (len <= cur->cnxn->varchar_maxlength)
{
info.ParameterType = SQL_VARCHAR;
- // info.BufferLength = len + 1; // + NULL
info.StrLen_or_Ind = len;
info.ParameterValuePtr = PyString_AS_STRING(param);
}
@@ -164,7 +155,7 @@ static bool GetUnicodeInfo(Cursor* cur, Py_ssize_t index, PyObject* param, Param
info.ParameterValuePtr = SQLWCHAR_FromUnicode(pch, len);
if (info.ParameterValuePtr == 0)
return false;
- info.allocated = len * sizeof(SQLWCHAR);
+ info.allocated = true;
}
else
{
@@ -294,46 +285,131 @@ static bool GetFloatInfo(Cursor* cur, Py_ssize_t index, PyObject* param, ParamIn
return true;
}
+static char* CreateDecimalString(long sign, PyObject* digits, long exp)
+{
+ long count = (long)PyTuple_GET_SIZE(digits);
+
+ char* pch;
+ long len;
+
+ if (exp >= 0)
+ {
+ // (1 2 3) exp = 2 --> '12300'
+
+ len = sign + count + exp + 1; // 1: NULL
+ pch = (char*)pyodbc_malloc(len);
+ if (pch)
+ {
+ char* p = pch;
+ if (sign)
+ *p++ = '-';
+ for (long i = 0; i < count; i++)
+ *p++ = (char)('0' + PyInt_AS_LONG(PyTuple_GET_ITEM(digits, i)));
+ for (long i = 0; i < exp; i++)
+ *p++ = '0';
+ *p = 0;
+ }
+ }
+ else if (-exp < count)
+ {
+ // (1 2 3) exp = -2 --> 1.23 : prec = 3, scale = 2
+
+ len = sign + count + 2; // 2: decimal + NULL
+ pch = (char*)pyodbc_malloc(len);
+ if (pch)
+ {
+ char* p = pch;
+ if (sign)
+ *p++ = '-';
+ int i = 0;
+ for (; i < (count + exp); i++)
+ *p++ = (char)('0' + PyInt_AS_LONG(PyTuple_GET_ITEM(digits, i)));
+ *p++ = '.';
+ for (; i < count; i++)
+ *p++ = (char)('0' + PyInt_AS_LONG(PyTuple_GET_ITEM(digits, i)));
+ *p++ = 0;
+ }
+ }
+ else
+ {
+ // (1 2 3) exp = -5 --> 0.00123 : prec = 5, scale = 5
+
+ len = sign + -exp + 3; // 3: leading zero + decimal + NULL
+
+ pch = (char*)pyodbc_malloc(len);
+ if (pch)
+ {
+ char* p = pch;
+ if (sign)
+ *p++ = '-';
+ *p++ = '0';
+ *p++ = '.';
+
+ for (int i = 0; i < -(exp + count); i++)
+ *p++ = '0';
+
+ for (int i = 0; i < count; i++)
+ *p++ = (char)('0' + PyInt_AS_LONG(PyTuple_GET_ITEM(digits, i)));
+ *p++ = 0;
+ }
+ }
+
+ I(pch == 0 || strlen(pch) + 1 == len);
+
+ return pch;
+}
+
static bool GetDecimalInfo(Cursor* cur, Py_ssize_t index, PyObject* param, ParamInfo& info)
{
+ // The NUMERIC structure never works right with SQL Server and probably a lot of other drivers. We'll bind as a
+ // string. Unfortunately, the Decimal class doesn't seem to have a way to force it to return a string without
+ // exponents, so we'll have to build it ourselves.
+
Object t = PyObject_CallMethod(param, "as_tuple", 0);
if (!t)
return false;
- Py_ssize_t digits = PyTuple_GET_SIZE(PyTuple_GET_ITEM(t.Get(), 1));
+ long sign = PyInt_AsLong(PyTuple_GET_ITEM(t.Get(), 0));
+ PyObject* digits = PyTuple_GET_ITEM(t.Get(), 1);
long exp = PyInt_AsLong(PyTuple_GET_ITEM(t.Get(), 2));
+ Py_ssize_t count = PyTuple_GET_SIZE(digits);
+
info.ValueType = SQL_C_CHAR;
info.ParameterType = SQL_NUMERIC;
if (exp >= 0)
{
// (1 2 3) exp = 2 --> '12300'
- info.ColumnSize = digits + exp;
+ info.ColumnSize = count + exp;
info.DecimalDigits = 0;
+
}
- else if (-exp <= digits)
+ else if (-exp <= count)
{
// (1 2 3) exp = -2 --> 1.23 : prec = 3, scale = 2
- info.ColumnSize = digits;
+ info.ColumnSize = count;
info.DecimalDigits = (SQLSMALLINT)-exp;
}
else
{
// (1 2 3) exp = -5 --> 0.00123 : prec = 5, scale = 5
- info.ColumnSize = digits + (-exp);
+ info.ColumnSize = count + (-exp);
info.DecimalDigits = (SQLSMALLINT)info.ColumnSize;
}
- I(info.ColumnSize >= info.DecimalDigits);
+ I(info.ColumnSize >= (SQLULEN)info.DecimalDigits);
- info.temp = PyObject_CallMethod(param, "__str__", 0);
- if (!info.temp)
+ info.ParameterValuePtr = CreateDecimalString(sign, digits, exp);
+ if (!info.ParameterValuePtr)
+ {
+ PyErr_NoMemory();
return false;
+ }
+ info.allocated = true;
- info.ParameterValuePtr = (SQLPOINTER)PyString_AS_STRING(info.temp);
- info.StrLen_or_Ind = PyString_GET_SIZE(info.temp);
+ info.StrLen_or_Ind = strlen((char*)info.ParameterValuePtr);
return true;
}
View
@@ -391,6 +391,14 @@ def t(self):
locals()['test_decimal_%s_%s_%s' % (p, s, n and 'n' or 'p')] = _maketest(p, s, n)
+ def test_decimal_e(self):
+ """Ensure exponential notation decimals are properly handled"""
+ value = Decimal((0, (1, 2, 3), 5)) # prints as 1.23E+7
+ self.cursor.execute("create table t1(d decimal(10, 2))")
+ self.cursor.execute("insert into t1 values (?)", value)
+ result = self.cursor.execute("select * from t1").fetchone()[0]
+ self.assertEqual(result, value)
+
def test_subquery_params(self):
"""Ensure parameter markers work in a subquery"""
self.cursor.execute("create table t1(id integer, s varchar(20))")

0 comments on commit 842aa97

Please sign in to comment.