Skip to content

Commit 26d1426

Browse files
authored
Merge branch 'main' into jahnvi/native-uuid
2 parents 6a2aad3 + cd828b6 commit 26d1426

File tree

3 files changed

+546
-88
lines changed

3 files changed

+546
-88
lines changed

mssql_python/cursor.py

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,8 @@ def _get_numeric_data(self, param):
195195
the numeric data.
196196
"""
197197
decimal_as_tuple = param.as_tuple()
198-
num_digits = len(decimal_as_tuple.digits)
198+
digits_tuple = decimal_as_tuple.digits
199+
num_digits = len(digits_tuple)
199200
exponent = decimal_as_tuple.exponent
200201

201202
# Calculate the SQL precision & scale
@@ -215,12 +216,11 @@ def _get_numeric_data(self, param):
215216
precision = exponent * -1
216217
scale = exponent * -1
217218

218-
# TODO: Revisit this check, do we want this restriction?
219-
if precision > 15:
219+
if precision > 38:
220220
raise ValueError(
221221
"Precision of the numeric value is too high - "
222222
+ str(param)
223-
+ ". Should be less than or equal to 15"
223+
+ ". Should be less than or equal to 38"
224224
)
225225
Numeric_Data = ddbc_bindings.NumericData
226226
numeric_data = Numeric_Data()
@@ -229,12 +229,26 @@ def _get_numeric_data(self, param):
229229
numeric_data.sign = 1 if decimal_as_tuple.sign == 0 else 0
230230
# strip decimal point from param & convert the significant digits to integer
231231
# Ex: 12.34 ---> 1234
232-
val = str(param)
233-
if "." in val or "-" in val:
234-
val = val.replace(".", "")
235-
val = val.replace("-", "")
236-
val = int(val)
237-
numeric_data.val = val
232+
int_str = ''.join(str(d) for d in digits_tuple)
233+
if exponent > 0:
234+
int_str = int_str + ('0' * exponent)
235+
elif exponent < 0:
236+
if -exponent > num_digits:
237+
int_str = ('0' * (-exponent - num_digits)) + int_str
238+
239+
if int_str == '':
240+
int_str = '0'
241+
242+
# Convert decimal base-10 string to python int, then to 16 little-endian bytes
243+
big_int = int(int_str)
244+
byte_array = bytearray(16) # SQL_MAX_NUMERIC_LEN
245+
for i in range(16):
246+
byte_array[i] = big_int & 0xFF
247+
big_int >>= 8
248+
if big_int == 0:
249+
break
250+
251+
numeric_data.val = bytes(byte_array)
238252
return numeric_data
239253

240254
def _map_sql_type(self, param, parameters_list, i, min_val=None, max_val=None):
@@ -307,7 +321,27 @@ def _map_sql_type(self, param, parameters_list, i, min_val=None, max_val=None):
307321
)
308322

309323
if isinstance(param, decimal.Decimal):
310-
# Detect MONEY / SMALLMONEY range
324+
# First check precision limit for all decimal values
325+
decimal_as_tuple = param.as_tuple()
326+
digits_tuple = decimal_as_tuple.digits
327+
num_digits = len(digits_tuple)
328+
exponent = decimal_as_tuple.exponent
329+
330+
# Calculate the SQL precision (same logic as _get_numeric_data)
331+
if exponent >= 0:
332+
precision = num_digits + exponent
333+
elif (-1 * exponent) <= num_digits:
334+
precision = num_digits
335+
else:
336+
precision = exponent * -1
337+
338+
if precision > 38:
339+
raise ValueError(
340+
f"Precision of the numeric value is too high. "
341+
f"The maximum precision supported by SQL Server is 38, but got {precision}."
342+
)
343+
344+
# Detect MONEY / SMALLMONEY range
311345
if SMALLMONEY_MIN <= param <= SMALLMONEY_MAX:
312346
# smallmoney
313347
parameters_list[i] = str(param)

mssql_python/pybind/ddbc_bindings.cpp

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
#define SQL_SS_TIMESTAMPOFFSET (-155)
2222
#define SQL_C_SS_TIMESTAMPOFFSET (0x4001)
2323
#define MAX_DIGITS_IN_NUMERIC 64
24+
#define SQL_MAX_NUMERIC_LEN 16
25+
#define SQL_SS_XML (-152)
2426

2527
#define STRINGIFY_FOR_CASE(x) \
2628
case x: \
@@ -56,12 +58,18 @@ struct NumericData {
5658
SQLCHAR precision;
5759
SQLSCHAR scale;
5860
SQLCHAR sign; // 1=pos, 0=neg
59-
std::uint64_t val; // 123.45 -> 12345
61+
std::string val; // 123.45 -> 12345
6062

61-
NumericData() : precision(0), scale(0), sign(0), val(0) {}
63+
NumericData() : precision(0), scale(0), sign(0), val(SQL_MAX_NUMERIC_LEN, '\0') {}
6264

63-
NumericData(SQLCHAR precision, SQLSCHAR scale, SQLCHAR sign, std::uint64_t value)
64-
: precision(precision), scale(scale), sign(sign), val(value) {}
65+
NumericData(SQLCHAR precision, SQLSCHAR scale, SQLCHAR sign, const std::string& valueBytes)
66+
: precision(precision), scale(scale), sign(sign), val(SQL_MAX_NUMERIC_LEN, '\0') {
67+
if (valueBytes.size() > SQL_MAX_NUMERIC_LEN) {
68+
throw std::runtime_error("NumericData valueBytes size exceeds SQL_MAX_NUMERIC_LEN (16)");
69+
}
70+
// Copy binary data to buffer, remaining bytes stay zero-padded
71+
std::memcpy(&val[0], valueBytes.data(), valueBytes.size());
72+
}
6573
};
6674

6775
// Struct to hold the DateTimeOffset structure
@@ -557,9 +565,10 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params,
557565
decimalPtr->sign = decimalParam.sign;
558566
// Convert the integer decimalParam.val to char array
559567
std::memset(static_cast<void*>(decimalPtr->val), 0, sizeof(decimalPtr->val));
560-
std::memcpy(static_cast<void*>(decimalPtr->val),
561-
reinterpret_cast<char*>(&decimalParam.val),
562-
sizeof(decimalParam.val));
568+
size_t copyLen = std::min(decimalParam.val.size(), sizeof(decimalPtr->val));
569+
if (copyLen > 0) {
570+
std::memcpy(decimalPtr->val, decimalParam.val.data(), copyLen);
571+
}
563572
dataPtr = static_cast<void*>(decimalPtr);
564573
break;
565574
}
@@ -2050,15 +2059,17 @@ SQLRETURN BindParameterArray(SQLHANDLE hStmt,
20502059
throw std::runtime_error(MakeParamMismatchErrorStr(info.paramCType, paramIndex));
20512060
}
20522061
NumericData decimalParam = element.cast<NumericData>();
2053-
LOG("Received numeric parameter at [%zu]: precision=%d, scale=%d, sign=%d, val=%lld",
2054-
i, decimalParam.precision, decimalParam.scale, decimalParam.sign, decimalParam.val);
2055-
numericArray[i].precision = decimalParam.precision;
2056-
numericArray[i].scale = decimalParam.scale;
2057-
numericArray[i].sign = decimalParam.sign;
2058-
std::memset(numericArray[i].val, 0, sizeof(numericArray[i].val));
2059-
std::memcpy(numericArray[i].val,
2060-
reinterpret_cast<const char*>(&decimalParam.val),
2061-
std::min(sizeof(decimalParam.val), sizeof(numericArray[i].val)));
2062+
LOG("Received numeric parameter at [%zu]: precision=%d, scale=%d, sign=%d, val=%s",
2063+
i, decimalParam.precision, decimalParam.scale, decimalParam.sign, decimalParam.val.c_str());
2064+
SQL_NUMERIC_STRUCT& target = numericArray[i];
2065+
std::memset(&target, 0, sizeof(SQL_NUMERIC_STRUCT));
2066+
target.precision = decimalParam.precision;
2067+
target.scale = decimalParam.scale;
2068+
target.sign = decimalParam.sign;
2069+
size_t copyLen = std::min(decimalParam.val.size(), sizeof(target.val));
2070+
if (copyLen > 0) {
2071+
std::memcpy(target.val, decimalParam.val.data(), copyLen);
2072+
}
20622073
strLenOrIndArray[i] = sizeof(SQL_NUMERIC_STRUCT);
20632074
}
20642075
dataPtr = numericArray;
@@ -2525,6 +2536,12 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p
25252536
}
25262537
break;
25272538
}
2539+
case SQL_SS_XML:
2540+
{
2541+
LOG("Streaming XML for column {}", i);
2542+
row.append(FetchLobColumnData(hStmt, i, SQL_C_WCHAR, true, false));
2543+
break;
2544+
}
25282545
case SQL_WCHAR:
25292546
case SQL_WVARCHAR:
25302547
case SQL_WLONGVARCHAR: {
@@ -3395,6 +3412,7 @@ size_t calculateRowSize(py::list& columnNames, SQLUSMALLINT numCols) {
33953412
case SQL_LONGVARCHAR:
33963413
rowSize += columnSize;
33973414
break;
3415+
case SQL_SS_XML:
33983416
case SQL_WCHAR:
33993417
case SQL_WVARCHAR:
34003418
case SQL_WLONGVARCHAR:
@@ -3499,7 +3517,7 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetch
34993517

35003518
if ((dataType == SQL_WVARCHAR || dataType == SQL_WLONGVARCHAR ||
35013519
dataType == SQL_VARCHAR || dataType == SQL_LONGVARCHAR ||
3502-
dataType == SQL_VARBINARY || dataType == SQL_LONGVARBINARY) &&
3520+
dataType == SQL_VARBINARY || dataType == SQL_LONGVARBINARY || dataType == SQL_SS_XML) &&
35033521
(columnSize == 0 || columnSize == SQL_NO_TOTAL || columnSize > SQL_MAX_LOB_SIZE)) {
35043522
lobColumns.push_back(i + 1); // 1-based
35053523
}
@@ -3621,7 +3639,7 @@ SQLRETURN FetchAll_wrap(SqlHandlePtr StatementHandle, py::list& rows) {
36213639

36223640
if ((dataType == SQL_WVARCHAR || dataType == SQL_WLONGVARCHAR ||
36233641
dataType == SQL_VARCHAR || dataType == SQL_LONGVARCHAR ||
3624-
dataType == SQL_VARBINARY || dataType == SQL_LONGVARBINARY) &&
3642+
dataType == SQL_VARBINARY || dataType == SQL_LONGVARBINARY || dataType == SQL_SS_XML) &&
36253643
(columnSize == 0 || columnSize == SQL_NO_TOTAL || columnSize > SQL_MAX_LOB_SIZE)) {
36263644
lobColumns.push_back(i + 1); // 1-based
36273645
}
@@ -3792,7 +3810,7 @@ PYBIND11_MODULE(ddbc_bindings, m) {
37923810
// Define numeric data class
37933811
py::class_<NumericData>(m, "NumericData")
37943812
.def(py::init<>())
3795-
.def(py::init<SQLCHAR, SQLSCHAR, SQLCHAR, std::uint64_t>())
3813+
.def(py::init<SQLCHAR, SQLSCHAR, SQLCHAR, const std::string&>())
37963814
.def_readwrite("precision", &NumericData::precision)
37973815
.def_readwrite("scale", &NumericData::scale)
37983816
.def_readwrite("sign", &NumericData::sign)

0 commit comments

Comments
 (0)