Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mssql_python/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ class ConstantsDDBC(Enum):
SQL_DATETIMEOFFSET = -155
SQL_SS_TIME2 = -154
SQL_SS_XML = -152
SQL_SS_VARIANT = -150
SQL_C_SS_TIMESTAMPOFFSET = 0x4001
SQL_SCOPE_CURROW = 0
SQL_BEST_ROWID = 1
Expand Down Expand Up @@ -374,6 +375,7 @@ def get_valid_types(cls) -> set:
ConstantsDDBC.SQL_DATETIMEOFFSET.value,
ConstantsDDBC.SQL_SS_XML.value,
ConstantsDDBC.SQL_GUID.value,
ConstantsDDBC.SQL_SS_VARIANT.value,
}

# Could also add category methods for convenience
Expand Down
1 change: 1 addition & 0 deletions mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -878,6 +878,7 @@ def _get_c_type_for_sql_type(self, sql_type: int) -> int:
# Other types
ddbc_sql_const.SQL_GUID.value: ddbc_sql_const.SQL_C_GUID.value,
ddbc_sql_const.SQL_SS_XML.value: ddbc_sql_const.SQL_C_WCHAR.value,
ddbc_sql_const.SQL_SS_VARIANT.value: ddbc_sql_const.SQL_C_BINARY.value,
}
return sql_to_c_type.get(sql_type, ddbc_sql_const.SQL_C_DEFAULT.value)

Expand Down
186 changes: 146 additions & 40 deletions mssql_python/pybind/ddbc_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,19 @@
#define MAX_DIGITS_IN_NUMERIC 64
#define SQL_MAX_NUMERIC_LEN 16
#define SQL_SS_XML (-152)
// Data type code reported for sql_variant columns; the fetch paths compare the
// column's DataType against this value to route it through SQLGetData.
#define SQL_SS_VARIANT (-150)
// Column attribute identifier passed to SQLColAttribute to obtain the C type code
// of the value currently stored in a sql_variant column.
#define SQL_CA_SS_VARIANT_TYPE (1215)
// Fallback definitions for the date/time C type codes, used only when the included
// headers have not already defined them. MapVariantCTypeToSQLType uses them as case
// labels alongside the SQL_C_TYPE_* codes.
#ifndef SQL_C_DATE
#define SQL_C_DATE (9)
#endif
#ifndef SQL_C_TIME
#define SQL_C_TIME (10)
#endif
#ifndef SQL_C_TIMESTAMP
#define SQL_C_TIMESTAMP (11)
#endif
// SQL Server-specific variant TIME type code
// (mapped to SQL_TYPE_TIME by MapVariantCTypeToSQLType)
#define SQL_SS_VARIANT_TIME (16384)

#define STRINGIFY_FOR_CASE(x) \
case x: \
Expand Down Expand Up @@ -1153,7 +1166,8 @@ void SqlHandle::markImplicitlyFreed() {
// Log error but don't throw - we're likely in cleanup/destructor path
LOG_ERROR("SAFETY VIOLATION: Attempted to mark non-STMT handle as implicitly freed. "
"Handle type=%d. This will cause handle leak. Only STMT handles are "
"automatically freed by parent DBC handles.", _type);
"automatically freed by parent DBC handles.",
_type);
return; // Refuse to mark - let normal free() handle it
}
_implicitly_freed = true;
Expand Down Expand Up @@ -2891,6 +2905,67 @@ py::object FetchLobColumnData(SQLHSTMT hStmt, SQLUSMALLINT colIndex, SQLSMALLINT
}
}

// Translates the C type code reported for a sql_variant value into the SQL data type
// whose existing fetch/conversion branch should process that value.
//
// variantCType: the C type code obtained for the variant's underlying value.
// Returns:      the SQL type to dispatch on; SQL_WVARCHAR when the code has no
//               explicit mapping, so the value is converted as a wide string.
SQLSMALLINT MapVariantCTypeToSQLType(SQLLEN variantCType) {
    // Preset the fallback used for unmapped/future type codes.
    // Note: SQL Server enforces sql_variant restrictions at INSERT time, preventing
    // invalid types (text, ntext, image, timestamp, xml, MAX types, nested variants,
    // spatial types, hierarchyid, UDTs) from being stored, so only valid base types
    // are expected to reach this point.
    SQLSMALLINT mappedSqlType = SQL_WVARCHAR;

    switch (variantCType) {
        // --- Integer family ---
        case SQL_C_TINYINT:
        case SQL_C_STINYINT:
        case SQL_C_UTINYINT:
            mappedSqlType = SQL_TINYINT;
            break;
        case SQL_C_SHORT:
        case SQL_C_SSHORT:
            mappedSqlType = SQL_SMALLINT;
            break;
        case SQL_C_LONG:
        case SQL_C_SLONG:
            mappedSqlType = SQL_INTEGER;
            break;
        case SQL_C_SBIGINT:
            mappedSqlType = SQL_BIGINT;
            break;
        case SQL_C_BIT:
            mappedSqlType = SQL_BIT;
            break;

        // --- Floating point and exact numeric ---
        case SQL_C_FLOAT:
            mappedSqlType = SQL_REAL;
            break;
        case SQL_C_DOUBLE:
            mappedSqlType = SQL_DOUBLE;
            break;
        case SQL_C_NUMERIC:
            mappedSqlType = SQL_NUMERIC;
            break;

        // --- Character data ---
        case SQL_C_CHAR:
            mappedSqlType = SQL_VARCHAR;
            break;
        case SQL_C_WCHAR:
            mappedSqlType = SQL_WVARCHAR;
            break;

        // --- Date and time (both the older and the SQL_C_TYPE_* codes) ---
        case SQL_C_TYPE_DATE:
        case SQL_C_DATE:
            mappedSqlType = SQL_TYPE_DATE;
            break;
        case SQL_SS_VARIANT_TIME:
        case SQL_C_TYPE_TIME:
        case SQL_C_TIME:
            mappedSqlType = SQL_TYPE_TIME;
            break;
        case SQL_C_TYPE_TIMESTAMP:
        case SQL_C_TIMESTAMP:
            mappedSqlType = SQL_TYPE_TIMESTAMP;
            break;

        // --- Binary and GUID ---
        case SQL_C_BINARY:
            mappedSqlType = SQL_VARBINARY;
            break;
        case SQL_C_GUID:
            mappedSqlType = SQL_GUID;
            break;

        default:
            // Unknown C type code: keep the SQL_WVARCHAR fallback.
            break;
    }

    return mappedSqlType;
}

// Reports whether a column has to be read with SQLGetData instead of bound buffers.
// This is true for every sql_variant column, and for character/binary/XML columns
// whose reported size is zero, SQL_NO_TOTAL, or greater than SQL_MAX_LOB_SIZE.
static inline bool IsLobOrVariantColumn(SQLSMALLINT dataType, SQLULEN columnSize) {
    switch (dataType) {
        case SQL_SS_VARIANT:
            // sql_variant qualifies regardless of the reported column size.
            return true;

        case SQL_WVARCHAR:
        case SQL_WLONGVARCHAR:
        case SQL_VARCHAR:
        case SQL_LONGVARCHAR:
        case SQL_VARBINARY:
        case SQL_LONGVARBINARY:
        case SQL_SS_XML:
            // LOB-capable types qualify only when the size is unknown or too large.
            return columnSize == 0 || columnSize == SQL_NO_TOTAL ||
                   columnSize > SQL_MAX_LOB_SIZE;

        default:
            return false;
    }
}

// Helper function to retrieve column data
SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, py::list& row,
const std::string& charEncoding = "utf-8",
Expand Down Expand Up @@ -2929,7 +3004,41 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p
continue;
}

switch (dataType) {
// Preprocess sql_variant: detect underlying type to route to correct conversion logic
SQLSMALLINT effectiveDataType = dataType;
if (dataType == SQL_SS_VARIANT) {
// For sql_variant, we MUST call SQLGetData with SQL_C_BINARY (NULL buffer, len=0)
// first. This serves two purposes:
// 1. Detects NULL values via the indicator parameter
// 2. Initializes the variant metadata in the ODBC driver, which is required for
// SQLColAttribute(SQL_CA_SS_VARIANT_TYPE) to return the correct underlying C type.
// Without this probe call, SQLColAttribute returns incorrect type codes.
SQLLEN indicator;
ret = SQLGetData_ptr(hStmt, i, SQL_C_BINARY, NULL, 0, &indicator);
if (!SQL_SUCCEEDED(ret)) {
LOG_ERROR("SQLGetData: Failed to probe sql_variant column %d - SQLRETURN=%d", i, ret);
row.append(py::none());
continue;
}
if (indicator == SQL_NULL_DATA) {
row.append(py::none());
continue;
}
// Now retrieve the underlying C type
SQLLEN variantCType = 0;
ret =
SQLColAttribute_ptr(hStmt, i, SQL_CA_SS_VARIANT_TYPE, NULL, 0, NULL, &variantCType);
if (!SQL_SUCCEEDED(ret)) {
LOG_ERROR("SQLGetData: Failed to get sql_variant underlying type for column %d", i);
row.append(py::none());
continue;
}
effectiveDataType = MapVariantCTypeToSQLType(variantCType);
LOG("SQLGetData: sql_variant column %d has variantCType=%ld, mapped to SQL type %d", i,
(long)variantCType, effectiveDataType);
}

switch (effectiveDataType) {
case SQL_CHAR:
case SQL_VARCHAR:
case SQL_LONGVARCHAR: {
Expand Down Expand Up @@ -4041,10 +4150,7 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetch
SQLSMALLINT dataType = colMeta["DataType"].cast<SQLSMALLINT>();
SQLULEN columnSize = colMeta["ColumnSize"].cast<SQLULEN>();

if ((dataType == SQL_WVARCHAR || dataType == SQL_WLONGVARCHAR || dataType == SQL_VARCHAR ||
dataType == SQL_LONGVARCHAR || dataType == SQL_VARBINARY ||
dataType == SQL_LONGVARBINARY || dataType == SQL_SS_XML) &&
(columnSize == 0 || columnSize == SQL_NO_TOTAL || columnSize > SQL_MAX_LOB_SIZE)) {
if (IsLobOrVariantColumn(dataType, columnSize)) {
lobColumns.push_back(i + 1); // 1-based
}
}
Expand Down Expand Up @@ -4133,6 +4239,40 @@ SQLRETURN FetchAll_wrap(SqlHandlePtr StatementHandle, py::list& rows,
return ret;
}

std::vector<SQLUSMALLINT> lobColumns;
for (SQLSMALLINT i = 0; i < numCols; i++) {
auto colMeta = columnNames[i].cast<py::dict>();
SQLSMALLINT dataType = colMeta["DataType"].cast<SQLSMALLINT>();
SQLULEN columnSize = colMeta["ColumnSize"].cast<SQLULEN>();

// Detect LOB columns that need SQLGetData streaming
// sql_variant always uses SQLGetData for native type preservation
if (IsLobOrVariantColumn(dataType, columnSize)) {
lobColumns.push_back(i + 1); // 1-based
}
}

// If we have LOB or sql_variant columns → fall back to row-by-row fetch + SQLGetData_wrap
if (!lobColumns.empty()) {
LOG("FetchAll_wrap: LOB columns detected (%zu columns), using per-row "
"SQLGetData path",
lobColumns.size());
while (true) {
ret = SQLFetch_ptr(hStmt);
if (ret == SQL_NO_DATA)
break;
if (!SQL_SUCCEEDED(ret))
return ret;

py::list row;
SQLGetData_wrap(StatementHandle, numCols, row, charEncoding,
wcharEncoding); // <-- streams LOBs correctly
rows.append(row);
}
return SQL_SUCCESS;
}

// No LOBs detected - use binding path with batch fetching
// Define a memory limit (1 GB)
const size_t memoryLimit = 1ULL * 1024 * 1024 * 1024;
size_t totalRowSize = calculateRowSize(columnNames, numCols);
Expand Down Expand Up @@ -4173,40 +4313,6 @@ SQLRETURN FetchAll_wrap(SqlHandlePtr StatementHandle, py::list& rows,
}
LOG("FetchAll_wrap: Fetching data in batch sizes of %d", fetchSize);

std::vector<SQLUSMALLINT> lobColumns;
for (SQLSMALLINT i = 0; i < numCols; i++) {
auto colMeta = columnNames[i].cast<py::dict>();
SQLSMALLINT dataType = colMeta["DataType"].cast<SQLSMALLINT>();
SQLULEN columnSize = colMeta["ColumnSize"].cast<SQLULEN>();

if ((dataType == SQL_WVARCHAR || dataType == SQL_WLONGVARCHAR || dataType == SQL_VARCHAR ||
dataType == SQL_LONGVARCHAR || dataType == SQL_VARBINARY ||
dataType == SQL_LONGVARBINARY || dataType == SQL_SS_XML) &&
(columnSize == 0 || columnSize == SQL_NO_TOTAL || columnSize > SQL_MAX_LOB_SIZE)) {
lobColumns.push_back(i + 1); // 1-based
}
}

// If we have LOBs → fall back to row-by-row fetch + SQLGetData_wrap
if (!lobColumns.empty()) {
LOG("FetchAll_wrap: LOB columns detected (%zu columns), using per-row "
"SQLGetData path",
lobColumns.size());
while (true) {
ret = SQLFetch_ptr(hStmt);
if (ret == SQL_NO_DATA)
break;
if (!SQL_SUCCEEDED(ret))
return ret;

py::list row;
SQLGetData_wrap(StatementHandle, numCols, row, charEncoding,
wcharEncoding); // <-- streams LOBs correctly
rows.append(row);
}
return SQL_SUCCESS;
}

ColumnBuffers buffers(numCols, fetchSize);

// Bind columns
Expand Down
Loading
Loading