Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Allocated own string for decimals.

Was using str(decimal), but it can return items in exponential notation, which I'm not sure
many databases would handle properly.
  • Loading branch information...
commit 842aa9772c9cb3393fc08080233e801c61dc6304 1 parent a64f284
@mkleehammer authored
Showing with 109 additions and 31 deletions.
  1. +4 −10 src/cursor.h
  2. +97 −21 src/params.cpp
  3. +8 −0 tests/sqlservertests.py
View
14 src/cursor.h
@@ -42,21 +42,15 @@ struct ParamInfo
SQLULEN ColumnSize;
SQLSMALLINT DecimalDigits;
- // If it is possible to bind into the parameter itself (ANSI string), this points into the Python object and must
- // not be modified or freed. Otherwise, this is memory allocated with 'malloc' specifically for this parameter and
- // should be freed after the call.
+ // The value pointer that will be bound. If `alloc` is true, this was allocated with malloc and must be freed.
+ // Otherwise it is zero or points into memory owned by the original Python parameter.
SQLPOINTER ParameterValuePtr;
SQLLEN BufferLength;
SQLLEN StrLen_or_Ind;
- // Optional Python object if we converted from the original parameter to one (e.g. Decimal to String). If
- // non-zero, ParameterValuePtr will point into this and allocated will be zero.
- PyObject* temp;
-
- // The amount of memory allocated (bytes) if binding into the original parameter or a temporary Python object was
- // not posssible. Otherwise zero.
- Py_ssize_t allocated;
+ // If true, the memory in ParameterValuePtr was allocated via malloc and must be freed.
+ bool allocated;
// Optional data. If used, ParameterValuePtr will point into this.
union
View
118 src/params.cpp
@@ -20,16 +20,8 @@ static bool GetParamType(Cursor* cur, Py_ssize_t iParam, SQLSMALLINT& type);
static void FreeInfos(ParamInfo* a, Py_ssize_t count)
{
for (Py_ssize_t i = 0; i < count; i++)
- {
- if (a[i].temp != 0)
- {
- Py_DECREF(a[i].temp);
- }
- else if (a[i].allocated != 0)
- {
+ if (a[i].allocated)
pyodbc_free(a[i].ParameterValuePtr);
- }
- }
pyodbc_free(a);
}
@@ -133,7 +125,6 @@ static bool GetStringInfo(Cursor* cur, Py_ssize_t index, PyObject* param, ParamI
if (len <= cur->cnxn->varchar_maxlength)
{
info.ParameterType = SQL_VARCHAR;
- // info.BufferLength = len + 1; // + NULL
info.StrLen_or_Ind = len;
info.ParameterValuePtr = PyString_AS_STRING(param);
}
@@ -164,7 +155,7 @@ static bool GetUnicodeInfo(Cursor* cur, Py_ssize_t index, PyObject* param, Param
info.ParameterValuePtr = SQLWCHAR_FromUnicode(pch, len);
if (info.ParameterValuePtr == 0)
return false;
- info.allocated = len * sizeof(SQLWCHAR);
+ info.allocated = true;
}
else
{
@@ -294,15 +285,96 @@ static bool GetFloatInfo(Cursor* cur, Py_ssize_t index, PyObject* param, ParamIn
return true;
}
+static char* CreateDecimalString(long sign, PyObject* digits, long exp)
+{
+ long count = (long)PyTuple_GET_SIZE(digits);
+
+ char* pch;
+ long len;
+
+ if (exp >= 0)
+ {
+ // (1 2 3) exp = 2 --> '12300'
+
+ len = sign + count + exp + 1; // 1: NULL
+ pch = (char*)pyodbc_malloc(len);
+ if (pch)
+ {
+ char* p = pch;
+ if (sign)
+ *p++ = '-';
+ for (long i = 0; i < count; i++)
+ *p++ = (char)('0' + PyInt_AS_LONG(PyTuple_GET_ITEM(digits, i)));
+ for (long i = 0; i < exp; i++)
+ *p++ = '0';
+ *p = 0;
+ }
+ }
+ else if (-exp < count)
+ {
+ // (1 2 3) exp = -2 --> 1.23 : prec = 3, scale = 2
+
+ len = sign + count + 2; // 2: decimal + NULL
+ pch = (char*)pyodbc_malloc(len);
+ if (pch)
+ {
+ char* p = pch;
+ if (sign)
+ *p++ = '-';
+ int i = 0;
+ for (; i < (count + exp); i++)
+ *p++ = (char)('0' + PyInt_AS_LONG(PyTuple_GET_ITEM(digits, i)));
+ *p++ = '.';
+ for (; i < count; i++)
+ *p++ = (char)('0' + PyInt_AS_LONG(PyTuple_GET_ITEM(digits, i)));
+ *p++ = 0;
+ }
+ }
+ else
+ {
+ // (1 2 3) exp = -5 --> 0.00123 : prec = 5, scale = 5
+
+ len = sign + -exp + 3; // 3: leading zero + decimal + NULL
+
+ pch = (char*)pyodbc_malloc(len);
+ if (pch)
+ {
+ char* p = pch;
+ if (sign)
+ *p++ = '-';
+ *p++ = '0';
+ *p++ = '.';
+
+ for (int i = 0; i < -(exp + count); i++)
+ *p++ = '0';
+
+ for (int i = 0; i < count; i++)
+ *p++ = (char)('0' + PyInt_AS_LONG(PyTuple_GET_ITEM(digits, i)));
+ *p++ = 0;
+ }
+ }
+
+ I(pch == 0 || strlen(pch) + 1 == len);
+
+ return pch;
+}
+
static bool GetDecimalInfo(Cursor* cur, Py_ssize_t index, PyObject* param, ParamInfo& info)
{
+ // The NUMERIC structure never works right with SQL Server and probably a lot of other drivers. We'll bind as a
+ // string. Unfortunately, the Decimal class doesn't seem to have a way to force it to return a string without
+ // exponents, so we'll have to build it ourselves.
+
Object t = PyObject_CallMethod(param, "as_tuple", 0);
if (!t)
return false;
- Py_ssize_t digits = PyTuple_GET_SIZE(PyTuple_GET_ITEM(t.Get(), 1));
+ long sign = PyInt_AsLong(PyTuple_GET_ITEM(t.Get(), 0));
+ PyObject* digits = PyTuple_GET_ITEM(t.Get(), 1);
long exp = PyInt_AsLong(PyTuple_GET_ITEM(t.Get(), 2));
+ Py_ssize_t count = PyTuple_GET_SIZE(digits);
+
info.ValueType = SQL_C_CHAR;
info.ParameterType = SQL_NUMERIC;
@@ -310,30 +382,34 @@ static bool GetDecimalInfo(Cursor* cur, Py_ssize_t index, PyObject* param, Param
{
// (1 2 3) exp = 2 --> '12300'
- info.ColumnSize = digits + exp;
+ info.ColumnSize = count + exp;
info.DecimalDigits = 0;
+
}
- else if (-exp <= digits)
+ else if (-exp <= count)
{
// (1 2 3) exp = -2 --> 1.23 : prec = 3, scale = 2
- info.ColumnSize = digits;
+ info.ColumnSize = count;
info.DecimalDigits = (SQLSMALLINT)-exp;
}
else
{
// (1 2 3) exp = -5 --> 0.00123 : prec = 5, scale = 5
- info.ColumnSize = digits + (-exp);
+ info.ColumnSize = count + (-exp);
info.DecimalDigits = (SQLSMALLINT)info.ColumnSize;
}
- I(info.ColumnSize >= info.DecimalDigits);
+ I(info.ColumnSize >= (SQLULEN)info.DecimalDigits);
- info.temp = PyObject_CallMethod(param, "__str__", 0);
- if (!info.temp)
+ info.ParameterValuePtr = CreateDecimalString(sign, digits, exp);
+ if (!info.ParameterValuePtr)
+ {
+ PyErr_NoMemory();
return false;
+ }
+ info.allocated = true;
- info.ParameterValuePtr = (SQLPOINTER)PyString_AS_STRING(info.temp);
- info.StrLen_or_Ind = PyString_GET_SIZE(info.temp);
+ info.StrLen_or_Ind = strlen((char*)info.ParameterValuePtr);
return true;
}
View
8 tests/sqlservertests.py
@@ -391,6 +391,14 @@ def t(self):
locals()['test_decimal_%s_%s_%s' % (p, s, n and 'n' or 'p')] = _maketest(p, s, n)
+ def test_decimal_e(self):
+ """Ensure exponential notation decimals are properly handled"""
+ value = Decimal((0, (1, 2, 3), 5)) # prints as 1.23E+7
+ self.cursor.execute("create table t1(d decimal(10, 2))")
+ self.cursor.execute("insert into t1 values (?)", value)
+ result = self.cursor.execute("select * from t1").fetchone()[0]
+ self.assertEqual(result, value)
+
def test_subquery_params(self):
"""Ensure parameter markers work in a subquery"""
self.cursor.execute("create table t1(id integer, s varchar(20))")
Please sign in to comment.
Something went wrong with that request. Please try again.