From 459b4c1a61449e1bafa793e160f6f8f75b225ecb Mon Sep 17 00:00:00 2001 From: gargsaumya Date: Thu, 30 Oct 2025 23:53:51 +0530 Subject: [PATCH 1/4] adding perf improvements --- mssql_python/cursor.py | 160 +++++++++++-- mssql_python/pybind/ddbc_bindings.cpp | 326 ++++++++++++++++---------- mssql_python/row.py | 155 ++++-------- 3 files changed, 375 insertions(+), 266 deletions(-) diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 446a2dfb..8f81170a 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -14,7 +14,7 @@ import uuid import datetime import warnings -from typing import List, Union, Any, Optional, Tuple, Sequence, TYPE_CHECKING +from typing import List, Union, Any, Optional, Tuple, Sequence, TYPE_CHECKING, Dict, Callable from mssql_python.constants import ConstantsDDBC as ddbc_sql_const, SQLTypes from mssql_python.helpers import check_error, log from mssql_python import ddbc_bindings @@ -131,6 +131,9 @@ def __init__(self, connection: "Connection", timeout: int = 0) -> None: self._skip_increment_for_next_fetch: bool = ( False # Track if we need to skip incrementing the row index ) + self._cached_column_map: Optional[Dict[str, int]] = None + self._cached_converter_map: Optional[List[Optional[Callable[[Any], Any]]]] = None + self._settings_snapshot: Optional[Dict[str, Any]] = None self.messages: List[str] = [] # Store diagnostic messages @@ -574,10 +577,90 @@ def _reset_cursor(self) -> None: log("debug", "SQLFreeHandle succeeded") self._clear_rownumber() + + # Clear cached optimizations when resetting cursor + self._cached_column_map = None + self._cached_converter_map = None + self._settings_snapshot = None # Reinitialize the statement handle self._initialize_cursor() + def _build_shared_converter_map(self) -> Optional[List[Optional[Callable[[Any], Any]]]]: + """ + Build a shared converter map for all rows in this result set. + This optimization avoids repeated converter lookups for each row. + + Returns: + List of converters (one per column, None if no converter needed) + """ + if not self.description or not hasattr(self.connection, '_output_converters'): + return None + + if not self.connection._output_converters: + return None + + converter_map = [] + + # Map SQL type codes to appropriate byte sizes for integer conversion + int_size_map = { + ddbc_sql_const.SQL_TINYINT.value: 1, + ddbc_sql_const.SQL_SMALLINT.value: 2, + ddbc_sql_const.SQL_INTEGER.value: 4, + ddbc_sql_const.SQL_BIGINT.value: 8, + } + + for desc in self.description: + if desc is None: + converter_map.append(None) + continue + + sql_type = desc[1] # type_code is at index 1 in description tuple + + # Try to get a converter for this type + converter = self.connection.get_output_converter(sql_type) + + # If no converter found for the SQL type but we expect string/bytes, + # try the WVARCHAR converter as a fallback + if converter is None: + converter = self.connection.get_output_converter( + ddbc_sql_const.SQL_WVARCHAR.value + ) + + converter_map.append(converter) + + return converter_map + + def _build_settings_snapshot(self) -> Dict[str, Any]: + """ + Build a settings snapshot to avoid repeated get_settings() calls for each row. + + Returns: + Dictionary with current settings values + """ + settings = get_settings() + return { + "lowercase": settings.lowercase, + "native_uuid": settings.native_uuid, + } + + def _ensure_cached_optimizations(self) -> None: + """ + Ensure column map, converter map, and settings snapshot are cached. 
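The new _build_shared_converter_map, _build_settings_snapshot and _ensure_cached_optimizations helpers move converter lookups and get_settings() reads out of the per-row path: the cursor resolves one converter per column once per result set and every Row reuses that list. A minimal, self-contained sketch of the same pattern; the helper names and the -9 type code (standard ODBC SQL_WVARCHAR) are illustrative, not the mssql_python API:

from typing import Any, Callable, Dict, List, Optional

def build_converter_map(
    description: List[tuple],
    converters: Dict[int, Callable[[Any], Any]],
) -> List[Optional[Callable[[Any], Any]]]:
    # One entry per column; None means "pass the value through unchanged".
    return [converters.get(desc[1]) if desc else None for desc in description]

def convert_row(values: List[Any], converter_map: List[Optional[Callable[[Any], Any]]]) -> List[Any]:
    return [
        conv(v) if (conv is not None and v is not None) else v
        for v, conv in zip(values, converter_map)
    ]

description = [("id", 4, None, None, None, None, None),       # SQL_INTEGER
               ("name", -9, None, None, None, None, None)]    # SQL_WVARCHAR
converters = {-9: str.upper}            # user-registered output converter
cmap = build_converter_map(description, converters)           # built once per result set
rows = [[1, "alice"], [2, "bob"]]
print([convert_row(r, cmap) for r in rows])                   # [[1, 'ALICE'], [2, 'BOB']]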
+ Called before fetching rows to optimize row creation performance. + """ + # Only build settings snapshot - keep other optimizations minimal for now + if self._settings_snapshot is None: + self._settings_snapshot = self._build_settings_snapshot() + + # Build basic column map if description exists + if self._cached_column_map is None and self.description: + self._cached_column_map = {} + for i, col_desc in enumerate(self.description): + if col_desc: # Ensure column description exists + col_name = col_desc[0] # Name is first item in description tuple + self._cached_column_map[col_name] = i + def close(self) -> None: """ Close the connection now (rather than whenever .__del__() is called). @@ -1159,6 +1242,15 @@ def execute( # pylint: disable=too-many-locals,too-many-branches,too-many-state self.rowcount = ddbc_bindings.DDBCSQLRowCount(self.hstmt) self._clear_rownumber() + # After successful execution, initialize description if there are results + column_metadata = [] + try: + ddbc_bindings.DDBCSQLDescribeCol(self.hstmt, column_metadata) + self._initialize_description(column_metadata) + except Exception as e: + # If describe fails, it's likely there are no results (e.g., for INSERT) + self.description = None + self._reset_inputsizes() # Reset input sizes after execution # Return self for method chaining return self @@ -1913,7 +2005,11 @@ def executemany( # pylint: disable=too-many-locals,too-many-branches,too-many-s ) ret = ddbc_bindings.SQLExecuteMany( - self.hstmt, operation, columnwise_params, parameters_type, row_count + self.hstmt, + operation, + columnwise_params, + parameters_type, + row_count ) # Capture any diagnostic messages after execution @@ -1968,12 +2064,17 @@ def fetchone(self) -> Union[None, Row]: self._increment_rownumber() self.rowcount = self._next_row_index - - # Create and return a Row object, passing column name map if available - column_map = getattr(self, "_column_name_map", None) - settings_snapshot = getattr(self, "_settings_snapshot", None) - return Row(self, self.description, row_data, column_map, settings_snapshot) - except Exception as e: # pylint: disable=broad-exception-caught + self._ensure_cached_optimizations() + + return Row( + values=row_data, + cursor=self, + description=self.description, + column_map=self._cached_column_map, + converter_map=self._cached_converter_map, + settings_snapshot=self._settings_snapshot + ) + except Exception as e: # On error, don't increment rownumber - rethrow the error raise e @@ -2000,7 +2101,7 @@ def fetchmany(self, size: Optional[int] = None) -> List[Row]: # Fetch raw data rows_data = [] try: - _ = ddbc_bindings.DDBCSQLFetchMany(self.hstmt, rows_data, size) + ret = ddbc_bindings.DDBCSQLFetchMany(self.hstmt, rows_data, size) if self.hstmt: self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt)) @@ -2016,15 +2117,20 @@ def fetchmany(self, size: Optional[int] = None) -> List[Row]: self.rowcount = 0 else: self.rowcount = self._next_row_index - - # Convert raw data to Row objects - column_map = getattr(self, "_column_name_map", None) - settings_snapshot = getattr(self, "_settings_snapshot", None) + self._ensure_cached_optimizations() + + # Convert raw data to Row objects using shared cached optimizations return [ - Row(self, self.description, row_data, column_map, settings_snapshot) - for row_data in rows_data + Row( + values=row_data, + cursor=self, + description=self.description, + column_map=self._cached_column_map, + converter_map=self._cached_converter_map, + settings_snapshot=self._settings_snapshot + ) 
for row_data in rows_data ] - except Exception as e: # pylint: disable=broad-exception-caught + except Exception as e: # On error, don't increment rownumber - rethrow the error raise e @@ -2042,7 +2148,7 @@ def fetchall(self) -> List[Row]: # Fetch raw data rows_data = [] try: - _ = ddbc_bindings.DDBCSQLFetchAll(self.hstmt, rows_data) + ret = ddbc_bindings.DDBCSQLFetchAll(self.hstmt, rows_data) if self.hstmt: self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt)) @@ -2057,15 +2163,20 @@ def fetchall(self) -> List[Row]: self.rowcount = 0 else: self.rowcount = self._next_row_index - - # Convert raw data to Row objects - column_map = getattr(self, "_column_name_map", None) - settings_snapshot = getattr(self, "_settings_snapshot", None) + self._ensure_cached_optimizations() + + # Convert raw data to Row objects using shared cached optimizations return [ - Row(self, self.description, row_data, column_map, settings_snapshot) - for row_data in rows_data + Row( + values=row_data, + cursor=self, + description=self.description, + column_map=self._cached_column_map, + converter_map=self._cached_converter_map, + settings_snapshot=self._settings_snapshot + ) for row_data in rows_data ] - except Exception as e: # pylint: disable=broad-exception-caught + except Exception as e: # On error, don't increment rownumber - rethrow the error raise e @@ -2110,6 +2221,7 @@ def __exit__(self, *args): """Closes the cursor when exiting the context, ensuring proper resource cleanup.""" if not self.closed: self.close() + return None def fetchval(self): """ diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 96a8d9f7..11b05456 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -34,6 +34,68 @@ #endif #define DAE_CHUNK_SIZE 8192 #define SQL_MAX_LOB_SIZE 8000 + +namespace PythonObjectCache { + static py::object datetime_class; + static py::object date_class; + static py::object time_class; + static py::object decimal_class; + static py::object uuid_class; + static bool cache_initialized = false; + + void initialize() { + if (!cache_initialized) { + auto datetime_module = py::module_::import("datetime"); + datetime_class = datetime_module.attr("datetime"); + date_class = datetime_module.attr("date"); + time_class = datetime_module.attr("time"); + + auto decimal_module = py::module_::import("decimal"); + decimal_class = decimal_module.attr("Decimal"); + + auto uuid_module = py::module_::import("uuid"); + uuid_class = uuid_module.attr("UUID"); + + cache_initialized = true; + } + } + + py::object get_datetime_class() { + if (cache_initialized && datetime_class) { + return datetime_class; + } + return py::module_::import("datetime").attr("datetime"); + } + + py::object get_date_class() { + if (cache_initialized && date_class) { + return date_class; + } + return py::module_::import("datetime").attr("date"); + } + + py::object get_time_class() { + if (cache_initialized && time_class) { + return time_class; + } + return py::module_::import("datetime").attr("time"); + } + + py::object get_decimal_class() { + if (cache_initialized && decimal_class) { + return decimal_class; + } + return py::module_::import("decimal").attr("Decimal"); + } + + py::object get_uuid_class() { + if (cache_initialized && uuid_class) { + return uuid_class; + } + return py::module_::import("uuid").attr("UUID"); + } +} + //------------------------------------------------------------------------------------------------- // Class definitions 
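The PythonObjectCache namespace above resolves datetime.datetime, datetime.date, datetime.time, decimal.Decimal and uuid.UUID once at module initialization so the per-value fetch code stops re-importing them, and each getter falls back to a fresh import if the cache was never initialized. A rough pure-Python sketch of the same idea, with illustrative names:

_cache = {}

def initialize_cache() -> None:
    import datetime, decimal, uuid
    _cache.update(
        datetime=datetime.datetime,
        date=datetime.date,
        time=datetime.time,
        decimal=decimal.Decimal,
        uuid=uuid.UUID,
    )

def get_decimal_class():
    # Prefer the cached constructor; import lazily only if initialize_cache() never ran.
    if "decimal" in _cache:
        return _cache["decimal"]
    import decimal
    return decimal.Decimal

initialize_cache()
print(get_decimal_class()("1.25"))   # Decimal('1.25')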
//------------------------------------------------------------------------------------------------- @@ -458,7 +520,7 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params, break; } case SQL_C_TYPE_DATE: { - py::object dateType = py::module_::import("datetime").attr("date"); + py::object dateType = PythonObjectCache::get_date_class(); if (!py::isinstance(param, dateType)) { ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); } @@ -475,7 +537,7 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params, break; } case SQL_C_TYPE_TIME: { - py::object timeType = py::module_::import("datetime").attr("time"); + py::object timeType = PythonObjectCache::get_time_class(); if (!py::isinstance(param, timeType)) { ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); } @@ -488,7 +550,7 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params, break; } case SQL_C_SS_TIMESTAMPOFFSET: { - py::object datetimeType = py::module_::import("datetime").attr("datetime"); + py::object datetimeType = PythonObjectCache::get_datetime_class(); if (!py::isinstance(param, datetimeType)) { ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); } @@ -532,7 +594,7 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params, break; } case SQL_C_TYPE_TIMESTAMP: { - py::object datetimeType = py::module_::import("datetime").attr("datetime"); + py::object datetimeType = PythonObjectCache::get_datetime_class(); if (!py::isinstance(param, datetimeType)) { ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); } @@ -1419,11 +1481,11 @@ SQLRETURN SQLExecDirect_wrap(SqlHandlePtr StatementHandle, const std::wstring& Q DriverLoader::getInstance().loadDriver(); // Load the driver } - // Ensure statement is scrollable BEFORE executing + // Configure forward-only cursor if (SQLSetStmtAttr_ptr && StatementHandle && StatementHandle->get()) { SQLSetStmtAttr_ptr(StatementHandle->get(), SQL_ATTR_CURSOR_TYPE, - (SQLPOINTER)SQL_CURSOR_STATIC, + (SQLPOINTER)SQL_CURSOR_FORWARD_ONLY, 0); SQLSetStmtAttr_ptr(StatementHandle->get(), SQL_ATTR_CONCURRENCY, @@ -1556,11 +1618,11 @@ SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle, LOG("Statement handle is null or empty"); } - // Ensure statement is scrollable BEFORE executing + // Configure forward-only cursor if (SQLSetStmtAttr_ptr && hStmt) { SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_CURSOR_TYPE, - (SQLPOINTER)SQL_CURSOR_STATIC, + (SQLPOINTER)SQL_CURSOR_FORWARD_ONLY, 0); SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_CONCURRENCY, @@ -2002,7 +2064,7 @@ SQLRETURN BindParameterArray(SQLHANDLE hStmt, DateTimeOffset* dtoArray = AllocateParamBufferArray(tempBuffers, paramSetSize); strLenOrIndArray = AllocateParamBufferArray(tempBuffers, paramSetSize); - py::object datetimeType = py::module_::import("datetime").attr("datetime"); + py::object datetimeType = PythonObjectCache::get_datetime_class(); for (size_t i = 0; i < paramSetSize; ++i) { const py::handle& param = columnValues[i]; @@ -2080,9 +2142,8 @@ SQLRETURN BindParameterArray(SQLHANDLE hStmt, SQLGUID* guidArray = AllocateParamBufferArray(tempBuffers, paramSetSize); strLenOrIndArray = AllocateParamBufferArray(tempBuffers, paramSetSize); - // Get cached UUID class from module-level helper - // This avoids static object destruction issues during Python finalization - py::object uuid_class = py::module_::import("mssql_python.ddbc_bindings").attr("_get_uuid_class")(); + // Get cached UUID class + py::object uuid_class = 
PythonObjectCache::get_uuid_class(); for (size_t i = 0; i < paramSetSize; ++i) { const py::handle& element = columnValues[i]; @@ -2465,6 +2526,12 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p SQLRETURN ret = SQL_SUCCESS; SQLHSTMT hStmt = StatementHandle->get(); + + // Cache decimal separator to avoid repeated system calls + static const std::string defaultSeparator = "."; + std::string decimalSeparator = GetDecimalSeparator(); + bool isDefaultDecimalSeparator = (decimalSeparator == defaultSeparator); + for (SQLSMALLINT i = 1; i <= colCount; ++i) { SQLWCHAR columnName[256]; SQLSMALLINT columnNameLen; @@ -2661,15 +2728,18 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p safeLen = bufSize; } } - - // Use the validated length to construct the string for Decimal - std::string numStr(cnum, safeLen); - - // Create Python Decimal object - py::object decimalObj = py::module_::import("decimal").attr("Decimal")(numStr); - - // Add to row - row.append(decimalObj); + if (isDefaultDecimalSeparator) { + py::object decimalObj = PythonObjectCache::get_decimal_class()(py::str(cnum, safeLen)); + row.append(decimalObj); + } else { + std::string numStr(cnum, safeLen); + size_t pos = numStr.find('.'); + if (pos != std::string::npos) { + numStr.replace(pos, 1, decimalSeparator); + } + py::object decimalObj = PythonObjectCache::get_decimal_class()(numStr); + row.append(decimalObj); + } } catch (const py::error_already_set& e) { // If conversion fails, append None LOG("Error converting to decimal: {}", e.what()); @@ -2718,7 +2788,7 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p SQLGetData_ptr(hStmt, i, SQL_C_TYPE_DATE, &dateValue, sizeof(dateValue), NULL); if (SQL_SUCCEEDED(ret)) { row.append( - py::module_::import("datetime").attr("date")( + PythonObjectCache::get_date_class()( dateValue.year, dateValue.month, dateValue.day @@ -2740,7 +2810,7 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p SQLGetData_ptr(hStmt, i, SQL_C_TYPE_TIME, &timeValue, sizeof(timeValue), NULL); if (SQL_SUCCEEDED(ret)) { row.append( - py::module_::import("datetime").attr("time")( + PythonObjectCache::get_time_class()( timeValue.hour, timeValue.minute, timeValue.second @@ -2762,7 +2832,7 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p sizeof(timestampValue), NULL); if (SQL_SUCCEEDED(ret)) { row.append( - py::module_::import("datetime").attr("datetime")( + PythonObjectCache::datetime_class( timestampValue.year, timestampValue.month, timestampValue.day, @@ -2808,11 +2878,11 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p } // Convert fraction from ns to µs int microseconds = dtoValue.fraction / 1000; - py::object datetime = py::module_::import("datetime"); - py::object tzinfo = datetime.attr("timezone")( - datetime.attr("timedelta")(py::arg("minutes") = totalMinutes) + py::object datetime_module = py::module_::import("datetime"); + py::object tzinfo = datetime_module.attr("timezone")( + datetime_module.attr("timedelta")(py::arg("minutes") = totalMinutes) ); - py::object py_dt = datetime.attr("datetime")( + py::object py_dt = PythonObjectCache::get_datetime_class()( dtoValue.year, dtoValue.month, dtoValue.day, @@ -2913,8 +2983,7 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p std::memcpy(&guid_bytes[8], guidValue.Data4, sizeof(guidValue.Data4)); py::bytes py_guid_bytes(guid_bytes.data(), 
guid_bytes.size()); - py::object uuid_module = py::module_::import("uuid"); - py::object uuid_obj = uuid_module.attr("UUID")(py::arg("bytes")=py_guid_bytes); + py::object uuid_obj = PythonObjectCache::get_uuid_class()(py::arg("bytes")=py_guid_bytes); row.append(uuid_obj); } else if (indicator == SQL_NULL_DATA) { row.append(py::none()); @@ -3135,42 +3204,62 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum LOG("Error while fetching rows in batches"); return ret; } - // numRowsFetched is the SQL_ATTR_ROWS_FETCHED_PTR attribute. It'll be populated by - // SQLFetchScroll + // Pre-cache column metadata to avoid repeated dictionary lookups + struct ColumnInfo { + SQLSMALLINT dataType; + SQLULEN columnSize; + SQLULEN processedColumnSize; + uint64_t fetchBufferSize; + bool isLob; + }; + std::vector columnInfos(numCols); + for (SQLUSMALLINT col = 0; col < numCols; col++) { + const auto& columnMeta = columnNames[col].cast(); + columnInfos[col].dataType = columnMeta["DataType"].cast(); + columnInfos[col].columnSize = columnMeta["ColumnSize"].cast(); + columnInfos[col].isLob = std::find(lobColumns.begin(), lobColumns.end(), col + 1) != lobColumns.end(); + columnInfos[col].processedColumnSize = columnInfos[col].columnSize; + HandleZeroColumnSizeAtFetch(columnInfos[col].processedColumnSize); + columnInfos[col].fetchBufferSize = columnInfos[col].processedColumnSize + 1; // +1 for null terminator + } + + static const std::string defaultSeparator = "."; + std::string decimalSeparator = GetDecimalSeparator(); // Cache decimal separator + bool isDefaultDecimalSeparator = (decimalSeparator == defaultSeparator); + + size_t initialSize = rows.size(); + for (SQLULEN i = 0; i < numRowsFetched; i++) { + rows.append(py::none()); + } + for (SQLULEN i = 0; i < numRowsFetched; i++) { - py::list row; + // Create row container pre-allocated with known column count + py::list row(numCols); for (SQLUSMALLINT col = 1; col <= numCols; col++) { - auto columnMeta = columnNames[col - 1].cast(); - SQLSMALLINT dataType = columnMeta["DataType"].cast(); + const ColumnInfo& colInfo = columnInfos[col - 1]; + SQLSMALLINT dataType = colInfo.dataType; SQLLEN dataLen = buffers.indicators[col - 1][i]; - if (dataLen == SQL_NULL_DATA) { - row.append(py::none()); + row[col - 1] = py::none(); continue; } - // TODO: variable length data needs special handling, this logic wont suffice - // This value indicates that the driver cannot determine the length of the data if (dataLen == SQL_NO_TOTAL) { LOG("Cannot determine the length of the data. Returning NULL value instead." "Column ID - {}", col); - row.append(py::none()); - continue; - } else if (dataLen == SQL_NULL_DATA) { - LOG("Column data is NULL. Appending None to the result row. Column ID - {}", col); - row.append(py::none()); + row[col - 1] = py::none(); continue; } else if (dataLen == 0) { // Handle zero-length (non-NULL) data if (dataType == SQL_CHAR || dataType == SQL_VARCHAR || dataType == SQL_LONGVARCHAR) { - row.append(std::string("")); + row[col - 1] = std::string(""); } else if (dataType == SQL_WCHAR || dataType == SQL_WVARCHAR || dataType == SQL_WLONGVARCHAR) { - row.append(std::wstring(L"")); + row[col - 1] = std::wstring(L""); } else if (dataType == SQL_BINARY || dataType == SQL_VARBINARY || dataType == SQL_LONGVARBINARY) { - row.append(py::bytes("")); + row[col - 1] = py::bytes(""); } else { - // For other datatypes, 0 length is unexpected. Log & append None - LOG("Column data length is 0 for non-string/binary datatype. 
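FetchBatchData now hoists the per-column metadata (DataType, ColumnSize, the LOB flag and the derived buffer size) into a ColumnInfo vector before the row loop, pre-sizes the output list, and fills each pre-allocated row by index instead of appending. A small Python sketch of that shape, with hypothetical names standing in for the C++ buffers:

from dataclasses import dataclass
from typing import Any, Dict, List

@dataclass
class ColumnInfo:
    data_type: int
    column_size: int
    is_lob: bool = False

def fetch_batch(column_meta: List[Dict[str, Any]], raw_batch: List[List[Any]]) -> List[List[Any]]:
    # Per-column info computed once, reused for every row in the batch.
    infos = [ColumnInfo(m["DataType"], m["ColumnSize"], m.get("IsLob", False))
             for m in column_meta]
    rows: List[Any] = [None] * len(raw_batch)      # pre-sized output container
    for r, raw in enumerate(raw_batch):
        row: List[Any] = [None] * len(infos)       # pre-sized row, filled by index
        for c, info in enumerate(infos):
            row[c] = raw[c]                        # the real code branches on info.data_type here
        rows[r] = row
    return rows

meta = [{"DataType": 4, "ColumnSize": 4}, {"DataType": -9, "ColumnSize": 50}]
print(fetch_batch(meta, [[1, "a"], [2, "b"]]))     # [[1, 'a'], [2, 'b']]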
Appending None to the result row. Column ID - {}", col); - row.append(py::none()); + // For other datatypes, 0 length is unexpected. Log & set None + LOG("Column data length is 0 for non-string/binary datatype. Setting None to the result row. Column ID - {}", col); + row[col - 1] = py::none(); } continue; } else if (dataLen < 0) { @@ -3184,19 +3273,18 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum case SQL_CHAR: case SQL_VARCHAR: case SQL_LONGVARCHAR: { - SQLULEN columnSize = columnMeta["ColumnSize"].cast(); + SQLULEN columnSize = colInfo.columnSize; HandleZeroColumnSizeAtFetch(columnSize); uint64_t fetchBufferSize = columnSize + 1 /*null-terminator*/; uint64_t numCharsInData = dataLen / sizeof(SQLCHAR); - bool isLob = std::find(lobColumns.begin(), lobColumns.end(), col) != lobColumns.end(); + bool isLob = colInfo.isLob; // fetchBufferSize includes null-terminator, numCharsInData doesn't. Hence '<' if (!isLob && numCharsInData < fetchBufferSize) { - // SQLFetch will nullterminate the data - row.append(std::string( + row[col - 1] = py::str( reinterpret_cast(&buffers.charBuffers[col - 1][i * fetchBufferSize]), - numCharsInData)); + numCharsInData); } else { - row.append(FetchLobColumnData(hStmt, col, SQL_C_CHAR, false, false)); + row[col - 1] = FetchLobColumnData(hStmt, col, SQL_C_CHAR, false, false); } break; } @@ -3204,114 +3292,102 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum case SQL_WVARCHAR: case SQL_WLONGVARCHAR: { // TODO: variable length data needs special handling, this logic wont suffice - SQLULEN columnSize = columnMeta["ColumnSize"].cast(); + SQLULEN columnSize = colInfo.columnSize; HandleZeroColumnSizeAtFetch(columnSize); uint64_t fetchBufferSize = columnSize + 1 /*null-terminator*/; uint64_t numCharsInData = dataLen / sizeof(SQLWCHAR); - bool isLob = std::find(lobColumns.begin(), lobColumns.end(), col) != lobColumns.end(); + bool isLob = colInfo.isLob; // fetchBufferSize includes null-terminator, numCharsInData doesn't. 
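The '<' comparison in these branches works because the bound fetch buffer is columnSize + 1 (room for the null terminator) while numCharsInData excludes the terminator; data that does not fit, or any column flagged as a LOB, is re-read through FetchLobColumnData instead. A one-function sketch of that decision, assuming the same sizing convention:

def fits_inline(num_chars: int, column_size: int, is_lob: bool) -> bool:
    fetch_buffer_size = column_size + 1      # bound buffer includes the null terminator
    return (not is_lob) and num_chars < fetch_buffer_size

print(fits_inline(10, 10, False))   # True  - data exactly fills the column and still fits
print(fits_inline(11, 10, False))   # False - longer than the bound buffer, stream it instead
print(fits_inline(10, 10, True))    # False - declared LOB columns are always streamed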
Hence '<' if (!isLob && numCharsInData < fetchBufferSize) { - // SQLFetch will nullterminate the data #if defined(__APPLE__) || defined(__linux__) - // Use unix-specific conversion to handle the wchar_t/SQLWCHAR size difference SQLWCHAR* wcharData = &buffers.wcharBuffers[col - 1][i * fetchBufferSize]; std::wstring wstr = SQLWCHARToWString(wcharData, numCharsInData); - row.append(wstr); + row[col - 1] = wstr; #else - // On Windows, wchar_t and SQLWCHAR are both 2 bytes, so direct cast works - row.append(std::wstring( + row[col - 1] = std::wstring( reinterpret_cast(&buffers.wcharBuffers[col - 1][i * fetchBufferSize]), - numCharsInData)); + numCharsInData); #endif } else { - row.append(FetchLobColumnData(hStmt, col, SQL_C_WCHAR, true, false)); + row[col - 1] = FetchLobColumnData(hStmt, col, SQL_C_WCHAR, true, false); } break; } case SQL_INTEGER: { - row.append(buffers.intBuffers[col - 1][i]); + row[col - 1] = buffers.intBuffers[col - 1][i]; break; } case SQL_SMALLINT: { - row.append(buffers.smallIntBuffers[col - 1][i]); + row[col - 1] = buffers.smallIntBuffers[col - 1][i]; break; } case SQL_TINYINT: { - row.append(buffers.charBuffers[col - 1][i]); + row[col - 1] = buffers.charBuffers[col - 1][i]; break; } case SQL_BIT: { - row.append(static_cast(buffers.charBuffers[col - 1][i])); + row[col - 1] = static_cast(buffers.charBuffers[col - 1][i]); break; } case SQL_REAL: { - row.append(buffers.realBuffers[col - 1][i]); + row[col - 1] = buffers.realBuffers[col - 1][i]; break; } case SQL_DECIMAL: case SQL_NUMERIC: { try { - // Convert the string to use the current decimal separator - std::string numStr(reinterpret_cast( - &buffers.charBuffers[col - 1][i * MAX_DIGITS_IN_NUMERIC]), - buffers.indicators[col - 1][i]); - - // Get the current separator in a thread-safe way - std::string separator = GetDecimalSeparator(); + SQLLEN decimalDataLen = buffers.indicators[col - 1][i]; + const char* rawData = reinterpret_cast( + &buffers.charBuffers[col - 1][i * MAX_DIGITS_IN_NUMERIC]); - if (separator != ".") { - // Replace the driver's decimal point with our configured separator + // Use pre-cached decimal separator + if (isDefaultDecimalSeparator) { + row[col - 1] = PythonObjectCache::get_decimal_class()(py::str(rawData, decimalDataLen)); + } else { + std::string numStr(rawData, decimalDataLen); size_t pos = numStr.find('.'); if (pos != std::string::npos) { - numStr.replace(pos, 1, separator); + numStr.replace(pos, 1, decimalSeparator); } + row[col - 1] = PythonObjectCache::get_decimal_class()(numStr); } - - // Convert to Python decimal - row.append(py::module_::import("decimal").attr("Decimal")(numStr)); } catch (const py::error_already_set& e) { - // Handle the exception, e.g., log the error and append py::none() + // Handle the exception, e.g., log the error and set py::none() LOG("Error converting to decimal: {}", e.what()); - row.append(py::none()); + row[col - 1] = py::none(); } break; } case SQL_DOUBLE: case SQL_FLOAT: { - row.append(buffers.doubleBuffers[col - 1][i]); + row[col - 1] = buffers.doubleBuffers[col - 1][i]; break; } case SQL_TIMESTAMP: case SQL_TYPE_TIMESTAMP: case SQL_DATETIME: { - row.append(py::module_::import("datetime") - .attr("datetime")(buffers.timestampBuffers[col - 1][i].year, - buffers.timestampBuffers[col - 1][i].month, - buffers.timestampBuffers[col - 1][i].day, - buffers.timestampBuffers[col - 1][i].hour, - buffers.timestampBuffers[col - 1][i].minute, - buffers.timestampBuffers[col - 1][i].second, - buffers.timestampBuffers[col - 1][i].fraction / 1000 /* Convert back ns to µs 
*/)); + const SQL_TIMESTAMP_STRUCT& ts = buffers.timestampBuffers[col - 1][i]; + row[col - 1] = PythonObjectCache::get_datetime_class()(ts.year, ts.month, ts.day, + ts.hour, ts.minute, ts.second, + ts.fraction / 1000); break; } case SQL_BIGINT: { - row.append(buffers.bigIntBuffers[col - 1][i]); + row[col - 1] = buffers.bigIntBuffers[col - 1][i]; break; } case SQL_TYPE_DATE: { - row.append(py::module_::import("datetime") - .attr("date")(buffers.dateBuffers[col - 1][i].year, - buffers.dateBuffers[col - 1][i].month, - buffers.dateBuffers[col - 1][i].day)); + row[col - 1] = PythonObjectCache::get_date_class()(buffers.dateBuffers[col - 1][i].year, + buffers.dateBuffers[col - 1][i].month, + buffers.dateBuffers[col - 1][i].day); break; } case SQL_TIME: case SQL_TYPE_TIME: case SQL_SS_TIME2: { - row.append(py::module_::import("datetime") - .attr("time")(buffers.timeBuffers[col - 1][i].hour, - buffers.timeBuffers[col - 1][i].minute, - buffers.timeBuffers[col - 1][i].second)); + row[col - 1] = PythonObjectCache::get_time_class()(buffers.timeBuffers[col - 1][i].hour, + buffers.timeBuffers[col - 1][i].minute, + buffers.timeBuffers[col - 1][i].second); break; } case SQL_SS_TIMESTAMPOFFSET: { @@ -3320,11 +3396,11 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum SQLLEN indicator = buffers.indicators[col - 1][rowIdx]; if (indicator != SQL_NULL_DATA) { int totalMinutes = dtoValue.timezone_hour * 60 + dtoValue.timezone_minute; - py::object datetime = py::module_::import("datetime"); - py::object tzinfo = datetime.attr("timezone")( - datetime.attr("timedelta")(py::arg("minutes") = totalMinutes) + py::object datetime_module = py::module_::import("datetime"); + py::object tzinfo = datetime_module.attr("timezone")( + datetime_module.attr("timedelta")(py::arg("minutes") = totalMinutes) ); - py::object py_dt = datetime.attr("datetime")( + py::object py_dt = PythonObjectCache::get_datetime_class()( dtoValue.year, dtoValue.month, dtoValue.day, @@ -3334,16 +3410,16 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum dtoValue.fraction / 1000, // ns → µs tzinfo ); - row.append(py_dt); + row[col - 1] = py_dt; } else { - row.append(py::none()); + row[col - 1] = py::none(); } break; } case SQL_GUID: { SQLLEN indicator = buffers.indicators[col - 1][i]; if (indicator == SQL_NULL_DATA) { - row.append(py::none()); + row[col - 1] = py::none(); break; } SQLGUID* guidValue = &buffers.guidBuffers[col - 1][i]; @@ -3361,26 +3437,27 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum py::bytes py_guid_bytes(reinterpret_cast(reordered), 16); py::dict kwargs; kwargs["bytes"] = py_guid_bytes; - py::object uuid_obj = py::module_::import("uuid").attr("UUID")(**kwargs); - row.append(uuid_obj); + py::object uuid_obj = PythonObjectCache::get_uuid_class()(**kwargs); + row[col - 1] = uuid_obj; break; } case SQL_BINARY: case SQL_VARBINARY: case SQL_LONGVARBINARY: { - SQLULEN columnSize = columnMeta["ColumnSize"].cast(); + SQLULEN columnSize = colInfo.columnSize; HandleZeroColumnSizeAtFetch(columnSize); - bool isLob = std::find(lobColumns.begin(), lobColumns.end(), col) != lobColumns.end(); + bool isLob = colInfo.isLob; if (!isLob && static_cast(dataLen) <= columnSize) { - row.append(py::bytes(reinterpret_cast( - &buffers.charBuffers[col - 1][i * columnSize]), - dataLen)); + row[col - 1] = py::bytes(reinterpret_cast( + &buffers.charBuffers[col - 1][i * columnSize]), + dataLen); } else { - row.append(FetchLobColumnData(hStmt, col, SQL_C_BINARY, 
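The SQL_GUID branch reorders the SQLGUID fields into network order before passing 16 bytes to uuid.UUID(bytes=...). On little-endian hosts the raw struct bytes correspond to the layout Python's uuid module calls bytes_le, so a quick way to see what the reordering must produce is to compare the two constructors:

import uuid

# Data1/Data2/Data3 little-endian, Data4 as-is: the "bytes_le" layout.
raw_le = bytes.fromhex("33221100" "5544" "7766" "8899aabbccddeeff")
u = uuid.UUID(bytes_le=raw_le)
assert u == uuid.UUID("00112233-4455-6677-8899-aabbccddeeff")
print(u)   # 00112233-4455-6677-8899-aabbccddeeff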
false, true)); + row[col - 1] = FetchLobColumnData(hStmt, col, SQL_C_BINARY, false, true); } break; } default: { + const auto& columnMeta = columnNames[col - 1].cast(); std::wstring columnName = columnMeta["ColumnName"].cast(); std::ostringstream errorString; errorString << "Unsupported data type for column - " << columnName.c_str() @@ -3391,7 +3468,7 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum } } } - rows.append(row); + rows[initialSize + i] = row; } return ret; } @@ -3785,6 +3862,8 @@ void DDBCSetDecimalSeparator(const std::string& separator) { PYBIND11_MODULE(ddbc_bindings, m) { m.doc() = "msodbcsql driver api bindings for Python"; + PythonObjectCache::initialize(); + // Add architecture information as module attribute m.attr("__architecture__") = ARCHITECTURE; @@ -3921,15 +4000,6 @@ PYBIND11_MODULE(ddbc_bindings, m) { return SQLColumns_wrap(StatementHandle, catalog, schema, table, column); }); - - // Module-level UUID class cache - // This caches the uuid.UUID class at module initialization time and keeps it alive - // for the entire module lifetime, avoiding static destructor issues during Python finalization - m.def("_get_uuid_class", []() -> py::object { - static py::object uuid_class = py::module_::import("uuid").attr("UUID"); - return uuid_class; - }, "Internal helper to get cached UUID class"); - // Add a version attribute m.attr("__version__") = "1.0.0"; diff --git a/mssql_python/row.py b/mssql_python/row.py index 8ffcb6e0..7b824dba 100644 --- a/mssql_python/row.py +++ b/mssql_python/row.py @@ -22,6 +22,7 @@ class Row: def __init__( self, + values: List[Any], cursor: "Cursor", description: List[ Tuple[ @@ -34,54 +35,37 @@ def __init__( Optional[bool], ] ], - values: List[Any], column_map: Optional[Dict[str, int]] = None, + converter_map: Optional[List[Optional[Any]]] = None, settings_snapshot: Optional[Dict[str, Any]] = None, ) -> None: """ Initialize a Row object with values and description. 
Args: + values: List of values for this row cursor: The cursor object description: The cursor description containing column metadata - values: List of values for this row column_map: Optional pre-built column map (for optimization) + converter_map: Pre-computed converter map (shared across rows for performance) settings_snapshot: Settings snapshot from cursor to ensure consistency """ self._cursor = cursor self._description = description - # Use settings snapshot if provided, otherwise fallback to global settings - if settings_snapshot is not None: - self._settings = settings_snapshot - else: - settings = get_settings() - self._settings = { - "lowercase": settings.lowercase, - "native_uuid": settings.native_uuid, - } - # Create mapping of column names to indices first - # If column_map is not provided, build it from description - if column_map is None: - self._column_map = {} - for i, col_desc in enumerate(description): - if col_desc: # Ensure column description exists - col_name = col_desc[0] # Name is first item in description tuple - if self._settings.get("lowercase"): - col_name = col_name.lower() - self._column_map[col_name] = i - else: - self._column_map = column_map - - # First make a mutable copy of values - processed_values = list(values) + # Store pre-built column map + self._column_map = column_map or {} + self._settings = settings_snapshot or { + "lowercase": get_settings().lowercase, + "native_uuid": get_settings().native_uuid, + } - # Apply output converters if available - if ( - hasattr(cursor.connection, "_output_converters") - and cursor.connection._output_converters - ): - processed_values = self._apply_output_converters(processed_values) + # Apply output converters using pre-built converter map if available + if converter_map: + processed_values = self._apply_output_converters(values, converter_map) + else: + # Fallback to no conversion + processed_values = list(values) # Process UUID values using the snapshotted setting self._values = self._process_uuid_values(processed_values, description) @@ -154,90 +138,30 @@ def _process_uuid_values( return processed_values - def _apply_output_converters(self, values: List[Any]) -> List[Any]: + def _apply_output_converters(self, values, converter_map): """ - Apply output converters to raw values. - + Apply output converters using pre-computed converter map for optimal performance. 
+ Args: values: Raw values from the database - + converter_map: Pre-computed list of converters (one per column, None if no converter) + Returns: List of converted values """ - if not self._description: - return values - converted_values = list(values) - - # Map SQL type codes to appropriate byte sizes - int_size_map = { - # SQL_TINYINT - ConstantsDDBC.SQL_TINYINT.value: 1, - # SQL_SMALLINT - ConstantsDDBC.SQL_SMALLINT.value: 2, - # SQL_INTEGER - ConstantsDDBC.SQL_INTEGER.value: 4, - # SQL_BIGINT - ConstantsDDBC.SQL_BIGINT.value: 8, - } - - for i, (value, desc) in enumerate(zip(values, self._description)): - if desc is None or value is None: - continue - - # Get SQL type from description - sql_type = desc[1] # type_code is at index 1 in description tuple - - # Try to get a converter for this type - converter = self._cursor.connection.get_output_converter(sql_type) - - # If no converter found for the SQL type but the value is a string or bytes, - # try the WVARCHAR converter as a fallback - if converter is None and isinstance(value, (str, bytes)): - converter = self._cursor.connection.get_output_converter( - ConstantsDDBC.SQL_WVARCHAR.value - ) - - # If we found a converter, apply it - if converter: + + for i, (value, converter) in enumerate(zip(values, converter_map)): + if converter and value is not None: try: - # If value is already a Python type (str, int, etc.), - # we need to handle it appropriately if isinstance(value, str): - # Encode as UTF-16LE for string values (SQL_WVARCHAR format) - value_bytes = value.encode("utf-16-le") + value_bytes = value.encode('utf-16-le') converted_values[i] = converter(value_bytes) - elif isinstance(value, int): - # Get appropriate byte size for this integer type - byte_size = int_size_map.get(sql_type, 8) - try: - # Use signed=True to properly handle negative values - value_bytes = value.to_bytes( - byte_size, byteorder="little", signed=True - ) - converted_values[i] = converter(value_bytes) - except OverflowError: - # Log specific overflow error with details to help diagnose the issue - if hasattr(self._cursor, "log"): - self._cursor.log( - "warning", - f"Integer overflow: value {value} does not fit in " - f"{byte_size} bytes for SQL type {sql_type}", - ) - # Keep the original value in this case else: - # Pass the value directly for other types converted_values[i] = converter(value) - except Exception as e: - # Log the exception for debugging without leaking sensitive data - if hasattr(self._cursor, "log"): - self._cursor.log( - "warning", - f"Exception in output converter: {type(e).__name__} " - f"for SQL type {sql_type}", - ) - # If conversion fails, keep the original value - + except Exception: + pass + return converted_values def __getitem__(self, index: int) -> Any: @@ -286,19 +210,22 @@ def __iter__(self) -> Any: def __str__(self) -> str: """Return string representation of the row""" - # Local import to avoid circular dependency - from mssql_python import getDecimalSeparator parts = [] for value in self: if isinstance(value, decimal.Decimal): - # Apply custom decimal separator for display - sep = getDecimalSeparator() - if sep != "." and value is not None: - s = str(value) - if "." in s: - s = s.replace(".", sep) - parts.append(s) - else: + try: + # Apply custom decimal separator for display with safety checks + # Local import to avoid circular dependency + from mssql_python import getDecimalSeparator + sep = getDecimalSeparator() + if sep and sep != "." and value is not None: + s = str(value) + if "." 
in s: + s = s.replace(".", sep) + parts.append(s) + else: + parts.append(str(value)) + except Exception: parts.append(str(value)) else: parts.append(repr(value)) From eccfe5b37c59f1f80a7a8a9c8b83d5191310d107 Mon Sep 17 00:00:00 2001 From: gargsaumya Date: Fri, 31 Oct 2025 11:16:34 +0530 Subject: [PATCH 2/4] fixed converter test failures, decimal separator tests failure and row constructor in uuid --- mssql_python/cursor.py | 6 ++- mssql_python/pybind/ddbc_bindings.cpp | 65 ++++++++------------------- tests/test_003_connection.py | 2 + tests/test_004_cursor.py | 4 +- 4 files changed, 27 insertions(+), 50 deletions(-) diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 8f81170a..4b213f53 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -649,7 +649,7 @@ def _ensure_cached_optimizations(self) -> None: Ensure column map, converter map, and settings snapshot are cached. Called before fetching rows to optimize row creation performance. """ - # Only build settings snapshot - keep other optimizations minimal for now + # Build settings snapshot if self._settings_snapshot is None: self._settings_snapshot = self._build_settings_snapshot() @@ -660,6 +660,10 @@ def _ensure_cached_optimizations(self) -> None: if col_desc: # Ensure column description exists col_name = col_desc[0] # Name is first item in description tuple self._cached_column_map[col_name] = i + + # Build converter map if needed + if self._cached_converter_map is None: + self._cached_converter_map = self._build_shared_converter_map() def close(self) -> None: """ diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 11b05456..f46abc5e 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -2704,46 +2704,18 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p ret = SQLGetData_ptr(hStmt, i, SQL_C_CHAR, numericStr, sizeof(numericStr), &indicator); if (SQL_SUCCEEDED(ret)) { - try { - // Validate 'indicator' to avoid buffer overflow and fallback to a safe - // null-terminated read when length is unknown or out-of-range. 
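Both the __str__ change in row.py and the numeric fetch paths converge on the same convention: the Decimal value is always built with the standard '.', and a separator configured through setDecimalSeparator is applied only when the value is rendered for display. A minimal sketch of that split, with an illustrative helper name:

from decimal import Decimal

def format_decimal(value: Decimal, separator: str = ".") -> str:
    # The Decimal itself always uses '.'; the custom separator is display-only.
    s = str(value)
    return s.replace(".", separator) if separator and separator != "." else s

d = Decimal("1234.56")
print(format_decimal(d, ","))    # 1234,56
print(d + Decimal("0.44"))       # 1235.00 - arithmetic is unaffected by the display separator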
- const char* cnum = reinterpret_cast(numericStr); - size_t bufSize = sizeof(numericStr); - size_t safeLen = 0; - - if (indicator > 0 && indicator <= static_cast(bufSize)) { - // indicator appears valid and within the buffer size - safeLen = static_cast(indicator); - } else { - // indicator is unknown, zero, negative, or too large; determine length - // by searching for a terminating null (safe bounded scan) - for (size_t j = 0; j < bufSize; ++j) { - if (cnum[j] == '\0') { - safeLen = j; - break; - } - } - // if no null found, use the full buffer size as a conservative fallback - if (safeLen == 0 && bufSize > 0 && cnum[0] != '\0') { - safeLen = bufSize; - } - } - if (isDefaultDecimalSeparator) { - py::object decimalObj = PythonObjectCache::get_decimal_class()(py::str(cnum, safeLen)); - row.append(decimalObj); - } else { - std::string numStr(cnum, safeLen); - size_t pos = numStr.find('.'); - if (pos != std::string::npos) { - numStr.replace(pos, 1, decimalSeparator); - } - py::object decimalObj = PythonObjectCache::get_decimal_class()(numStr); + if (indicator == SQL_NULL_DATA) { + row.append(py::none()); + } else { + try { + const char* cnum = reinterpret_cast(numericStr); + py::object decimalObj = PythonObjectCache::get_decimal_class()(py::str(cnum)); row.append(decimalObj); + } catch (const py::error_already_set& e) { + // If conversion fails, append None + LOG("Error converting to decimal: {}", e.what()); + row.append(py::none()); } - } catch (const py::error_already_set& e) { - // If conversion fails, append None - LOG("Error converting to decimal: {}", e.what()); - row.append(py::none()); } } else { @@ -3340,16 +3312,15 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum const char* rawData = reinterpret_cast( &buffers.charBuffers[col - 1][i * MAX_DIGITS_IN_NUMERIC]); - // Use pre-cached decimal separator - if (isDefaultDecimalSeparator) { - row[col - 1] = PythonObjectCache::get_decimal_class()(py::str(rawData, decimalDataLen)); + if (decimalDataLen == SQL_NULL_DATA) { + row[col - 1] = py::none(); + } else if (decimalDataLen > 0) { + SQLLEN safeLen = std::min(decimalDataLen, static_cast(MAX_DIGITS_IN_NUMERIC)); + // Always create Decimal with dot notation (Python standard) + // The custom separator is only for display formatting, not internal representation + row[col - 1] = PythonObjectCache::get_decimal_class()(py::str(rawData, safeLen)); } else { - std::string numStr(rawData, decimalDataLen); - size_t pos = numStr.find('.'); - if (pos != std::string::npos) { - numStr.replace(pos, 1, decimalSeparator); - } - row[col - 1] = PythonObjectCache::get_decimal_class()(numStr); + row[col - 1] = py::none(); } } catch (const py::error_already_set& e) { // Handle the exception, e.g., log the error and set py::none() diff --git a/tests/test_003_connection.py b/tests/test_003_connection.py index 9526d158..8f0c7f50 100644 --- a/tests/test_003_connection.py +++ b/tests/test_003_connection.py @@ -8603,7 +8603,9 @@ def test_connection_context_manager_with_cursor_cleanup(conn_str): # Perform operations cursor1.execute("SELECT 1") + cursor1.fetchall() # Consume the result set to free the connection cursor2.execute("SELECT 2") + cursor2.fetchall() # Consume the result set # Verify cursors are tracked assert len(conn._cursors) == 2, "Should track both cursors" diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index b52b0656..0d3b5dc3 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -14766,7 +14766,7 @@ def 
test_row_uuid_processing_exception_handling(cursor, db_connection): # Create Row directly with the data and modified description # This should trigger exception handling in lines 101-102 and 116-117 - row = Row(cursor, modified_description, list(row_data)) + row = Row(list(row_data), cursor, modified_description) # The invalid GUID should be kept as original value due to exception handling # Lines 101-102: except (ValueError, AttributeError): pass # Keep original if conversion fails @@ -15009,7 +15009,7 @@ def test_row_uuid_attribute_error_handling(cursor, db_connection): # Create Row directly with the data and modified description # This should trigger AttributeError handling in lines 101-102 and 116-117 - row = Row(cursor, modified_description, list(row_data)) + row = Row(list(row_data), cursor, modified_description) # The integer value should be kept as original due to AttributeError handling # Lines 101-102: except (ValueError, AttributeError): pass # Keep original if conversion fails From 285f7ca865592c35b2f0b84cd675b0a75c76f208 Mon Sep 17 00:00:00 2001 From: gargsaumya Date: Fri, 31 Oct 2025 11:58:16 +0530 Subject: [PATCH 3/4] fixed cursor's scroll/skip logic with fast-fwd cursor and tests --- mssql_python/cursor.py | 222 +++++++++++---------------------------- mssql_python/row.py | 157 +++++++++++++++++++-------- tests/test_004_cursor.py | 10 +- 3 files changed, 179 insertions(+), 210 deletions(-) diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 4b213f53..e647d4c5 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -14,7 +14,7 @@ import uuid import datetime import warnings -from typing import List, Union, Any, Optional, Tuple, Sequence, TYPE_CHECKING, Dict, Callable +from typing import List, Union, Any, Optional, Tuple, Sequence, TYPE_CHECKING from mssql_python.constants import ConstantsDDBC as ddbc_sql_const, SQLTypes from mssql_python.helpers import check_error, log from mssql_python import ddbc_bindings @@ -131,9 +131,6 @@ def __init__(self, connection: "Connection", timeout: int = 0) -> None: self._skip_increment_for_next_fetch: bool = ( False # Track if we need to skip incrementing the row index ) - self._cached_column_map: Optional[Dict[str, int]] = None - self._cached_converter_map: Optional[List[Optional[Callable[[Any], Any]]]] = None - self._settings_snapshot: Optional[Dict[str, Any]] = None self.messages: List[str] = [] # Store diagnostic messages @@ -577,94 +574,10 @@ def _reset_cursor(self) -> None: log("debug", "SQLFreeHandle succeeded") self._clear_rownumber() - - # Clear cached optimizations when resetting cursor - self._cached_column_map = None - self._cached_converter_map = None - self._settings_snapshot = None # Reinitialize the statement handle self._initialize_cursor() - def _build_shared_converter_map(self) -> Optional[List[Optional[Callable[[Any], Any]]]]: - """ - Build a shared converter map for all rows in this result set. - This optimization avoids repeated converter lookups for each row. 
- - Returns: - List of converters (one per column, None if no converter needed) - """ - if not self.description or not hasattr(self.connection, '_output_converters'): - return None - - if not self.connection._output_converters: - return None - - converter_map = [] - - # Map SQL type codes to appropriate byte sizes for integer conversion - int_size_map = { - ddbc_sql_const.SQL_TINYINT.value: 1, - ddbc_sql_const.SQL_SMALLINT.value: 2, - ddbc_sql_const.SQL_INTEGER.value: 4, - ddbc_sql_const.SQL_BIGINT.value: 8, - } - - for desc in self.description: - if desc is None: - converter_map.append(None) - continue - - sql_type = desc[1] # type_code is at index 1 in description tuple - - # Try to get a converter for this type - converter = self.connection.get_output_converter(sql_type) - - # If no converter found for the SQL type but we expect string/bytes, - # try the WVARCHAR converter as a fallback - if converter is None: - converter = self.connection.get_output_converter( - ddbc_sql_const.SQL_WVARCHAR.value - ) - - converter_map.append(converter) - - return converter_map - - def _build_settings_snapshot(self) -> Dict[str, Any]: - """ - Build a settings snapshot to avoid repeated get_settings() calls for each row. - - Returns: - Dictionary with current settings values - """ - settings = get_settings() - return { - "lowercase": settings.lowercase, - "native_uuid": settings.native_uuid, - } - - def _ensure_cached_optimizations(self) -> None: - """ - Ensure column map, converter map, and settings snapshot are cached. - Called before fetching rows to optimize row creation performance. - """ - # Build settings snapshot - if self._settings_snapshot is None: - self._settings_snapshot = self._build_settings_snapshot() - - # Build basic column map if description exists - if self._cached_column_map is None and self.description: - self._cached_column_map = {} - for i, col_desc in enumerate(self.description): - if col_desc: # Ensure column description exists - col_name = col_desc[0] # Name is first item in description tuple - self._cached_column_map[col_name] = i - - # Build converter map if needed - if self._cached_converter_map is None: - self._cached_converter_map = self._build_shared_converter_map() - def close(self) -> None: """ Close the connection now (rather than whenever .__del__() is called). 
@@ -1246,15 +1159,6 @@ def execute( # pylint: disable=too-many-locals,too-many-branches,too-many-state self.rowcount = ddbc_bindings.DDBCSQLRowCount(self.hstmt) self._clear_rownumber() - # After successful execution, initialize description if there are results - column_metadata = [] - try: - ddbc_bindings.DDBCSQLDescribeCol(self.hstmt, column_metadata) - self._initialize_description(column_metadata) - except Exception as e: - # If describe fails, it's likely there are no results (e.g., for INSERT) - self.description = None - self._reset_inputsizes() # Reset input sizes after execution # Return self for method chaining return self @@ -2009,11 +1913,7 @@ def executemany( # pylint: disable=too-many-locals,too-many-branches,too-many-s ) ret = ddbc_bindings.SQLExecuteMany( - self.hstmt, - operation, - columnwise_params, - parameters_type, - row_count + self.hstmt, operation, columnwise_params, parameters_type, row_count ) # Capture any diagnostic messages after execution @@ -2068,17 +1968,12 @@ def fetchone(self) -> Union[None, Row]: self._increment_rownumber() self.rowcount = self._next_row_index - self._ensure_cached_optimizations() - - return Row( - values=row_data, - cursor=self, - description=self.description, - column_map=self._cached_column_map, - converter_map=self._cached_converter_map, - settings_snapshot=self._settings_snapshot - ) - except Exception as e: + + # Create and return a Row object, passing column name map if available + column_map = getattr(self, "_column_name_map", None) + settings_snapshot = getattr(self, "_settings_snapshot", None) + return Row(self, self.description, row_data, column_map, settings_snapshot) + except Exception as e: # pylint: disable=broad-exception-caught # On error, don't increment rownumber - rethrow the error raise e @@ -2105,7 +2000,7 @@ def fetchmany(self, size: Optional[int] = None) -> List[Row]: # Fetch raw data rows_data = [] try: - ret = ddbc_bindings.DDBCSQLFetchMany(self.hstmt, rows_data, size) + _ = ddbc_bindings.DDBCSQLFetchMany(self.hstmt, rows_data, size) if self.hstmt: self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt)) @@ -2121,20 +2016,15 @@ def fetchmany(self, size: Optional[int] = None) -> List[Row]: self.rowcount = 0 else: self.rowcount = self._next_row_index - self._ensure_cached_optimizations() - - # Convert raw data to Row objects using shared cached optimizations + + # Convert raw data to Row objects + column_map = getattr(self, "_column_name_map", None) + settings_snapshot = getattr(self, "_settings_snapshot", None) return [ - Row( - values=row_data, - cursor=self, - description=self.description, - column_map=self._cached_column_map, - converter_map=self._cached_converter_map, - settings_snapshot=self._settings_snapshot - ) for row_data in rows_data + Row(self, self.description, row_data, column_map, settings_snapshot) + for row_data in rows_data ] - except Exception as e: + except Exception as e: # pylint: disable=broad-exception-caught # On error, don't increment rownumber - rethrow the error raise e @@ -2152,7 +2042,7 @@ def fetchall(self) -> List[Row]: # Fetch raw data rows_data = [] try: - ret = ddbc_bindings.DDBCSQLFetchAll(self.hstmt, rows_data) + _ = ddbc_bindings.DDBCSQLFetchAll(self.hstmt, rows_data) if self.hstmt: self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt)) @@ -2167,20 +2057,15 @@ def fetchall(self) -> List[Row]: self.rowcount = 0 else: self.rowcount = self._next_row_index - self._ensure_cached_optimizations() - - # Convert raw data to Row objects using shared 
cached optimizations + + # Convert raw data to Row objects + column_map = getattr(self, "_column_name_map", None) + settings_snapshot = getattr(self, "_settings_snapshot", None) return [ - Row( - values=row_data, - cursor=self, - description=self.description, - column_map=self._cached_column_map, - converter_map=self._cached_converter_map, - settings_snapshot=self._settings_snapshot - ) for row_data in rows_data + Row(self, self.description, row_data, column_map, settings_snapshot) + for row_data in rows_data ] - except Exception as e: + except Exception as e: # pylint: disable=broad-exception-caught # On error, don't increment rownumber - rethrow the error raise e @@ -2225,7 +2110,6 @@ def __exit__(self, *args): """Closes the cursor when exiting the context, ensuring proper resource cleanup.""" if not self.closed: self.close() - return None def fetchval(self): """ @@ -2406,30 +2290,46 @@ def scroll(self, value: int, mode: str = "relative") -> None: # pylint: disable if mode == "relative": if value == 0: return - ret = ddbc_bindings.DDBCSQLFetchScroll( - self.hstmt, ddbc_sql_const.SQL_FETCH_RELATIVE.value, value, row_data - ) - if ret == ddbc_sql_const.SQL_NO_DATA.value: - raise IndexError( - "Cannot scroll to specified position: end of result set reached" - ) - # Consume N rows; last-returned index advances by N - self._rownumber = self._rownumber + value - self._next_row_index = self._rownumber + 1 + + # For forward-only cursor, use SQLFetchOne repeatedly instead of SQLFetchScroll + for i in range(value): + ret = ddbc_bindings.DDBCSQLFetchOne(self.hstmt, row_data) + if ret == ddbc_sql_const.SQL_NO_DATA.value: + raise IndexError( + f"Cannot scroll to specified position: end of result set reached at position {i+1}/{value}" + ) + # Clear row_data for next iteration to avoid accumulating data + row_data.clear() + + # Consume N rows; advance next_row_index by N + self._next_row_index += value + self._rownumber = self._next_row_index - 1 return - # absolute(k>0): map Python k (0-based next row) to ODBC ABSOLUTE k (1-based), - # intentionally passing k so ODBC fetches row #k (1-based), i.e., 0-based (k-1), - # leaving the NEXT fetch to return 0-based index k. 
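With the statement switched to SQL_CURSOR_FORWARD_ONLY, scroll() can no longer rely on SQLFetchScroll; the patch emulates relative(n) by fetching and discarding n rows, and absolute(k) by skipping forward (raising if k lies behind the current position). A self-contained sketch of that approach on a forward-only source; fetch_one() and the class name are illustrative, not the driver API:

class ForwardOnly:
    def __init__(self, rows):
        self._rows = iter(rows)
        self.next_row_index = 0            # index the next fetch would return

    def fetch_one(self):
        row = next(self._rows, None)
        if row is not None:
            self.next_row_index += 1
        return row

    def scroll(self, value: int, mode: str = "relative") -> None:
        if mode == "relative":
            target = self.next_row_index + value
        else:                              # "absolute"
            if value < self.next_row_index:
                raise ValueError("cannot move backwards on a forward-only cursor")
            target = value
        while self.next_row_index < target:
            if self.fetch_one() is None:   # discard skipped rows
                raise IndexError("end of result set reached")

cur = ForwardOnly([[i] for i in range(5)])
cur.scroll(2)                    # discard rows 0 and 1
print(cur.fetch_one())           # [2]
cur.scroll(4, mode="absolute")   # skip forward to row index 4
print(cur.fetch_one())           # [4]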
- ret = ddbc_bindings.DDBCSQLFetchScroll( - self.hstmt, ddbc_sql_const.SQL_FETCH_ABSOLUTE.value, value, row_data - ) - if ret == ddbc_sql_const.SQL_NO_DATA.value: - raise IndexError( - f"Cannot scroll to position {value}: end of result set reached" + # For forward-only cursor, implement absolute positioning using relative scrolling + # absolute(k): position so next fetch returns row at 0-based index k + current_next_index = self._next_row_index # Where we would fetch next + + if value < current_next_index: + # Can't go backward with forward-only cursor + raise NotSupportedError( + "Backward absolute positioning not supported", + f"Cannot move from next position {current_next_index} back to {value} on a forward-only cursor" ) + elif value > current_next_index: + # Move forward: skip rows from current_next_index to value + rows_to_skip = value - current_next_index + for i in range(rows_to_skip): + ret = ddbc_bindings.DDBCSQLFetchOne(self.hstmt, row_data) + if ret == ddbc_sql_const.SQL_NO_DATA.value: + raise IndexError( + f"Cannot scroll to position {value}: end of result set reached at position {current_next_index + i}" + ) + # Clear row_data for next iteration + row_data.clear() + # else value == current_next_index: no movement needed - # Tests expect rownumber == value after absolute(value) + # Tests expect rownumber == value after absolute(value) # Next fetch should return row index 'value' self._rownumber = value self._next_row_index = value @@ -2587,4 +2487,4 @@ def setoutputsize(self, size: int, column: Optional[int] = None) -> None: This method is a no-op in this implementation as buffer sizes are managed automatically by the underlying driver. """ - # This is a no-op - buffer sizes are managed automatically + # This is a no-op - buffer sizes are managed automatically \ No newline at end of file diff --git a/mssql_python/row.py b/mssql_python/row.py index 7b824dba..1ce1902d 100644 --- a/mssql_python/row.py +++ b/mssql_python/row.py @@ -22,7 +22,6 @@ class Row: def __init__( self, - values: List[Any], cursor: "Cursor", description: List[ Tuple[ @@ -35,37 +34,54 @@ def __init__( Optional[bool], ] ], + values: List[Any], column_map: Optional[Dict[str, int]] = None, - converter_map: Optional[List[Optional[Any]]] = None, settings_snapshot: Optional[Dict[str, Any]] = None, ) -> None: """ Initialize a Row object with values and description. 
Args: - values: List of values for this row cursor: The cursor object description: The cursor description containing column metadata + values: List of values for this row column_map: Optional pre-built column map (for optimization) - converter_map: Pre-computed converter map (shared across rows for performance) settings_snapshot: Settings snapshot from cursor to ensure consistency """ self._cursor = cursor self._description = description - # Store pre-built column map - self._column_map = column_map or {} - self._settings = settings_snapshot or { - "lowercase": get_settings().lowercase, - "native_uuid": get_settings().native_uuid, - } - - # Apply output converters using pre-built converter map if available - if converter_map: - processed_values = self._apply_output_converters(values, converter_map) + # Use settings snapshot if provided, otherwise fallback to global settings + if settings_snapshot is not None: + self._settings = settings_snapshot else: - # Fallback to no conversion - processed_values = list(values) + settings = get_settings() + self._settings = { + "lowercase": settings.lowercase, + "native_uuid": settings.native_uuid, + } + # Create mapping of column names to indices first + # If column_map is not provided, build it from description + if column_map is None: + self._column_map = {} + for i, col_desc in enumerate(description): + if col_desc: # Ensure column description exists + col_name = col_desc[0] # Name is first item in description tuple + if self._settings.get("lowercase"): + col_name = col_name.lower() + self._column_map[col_name] = i + else: + self._column_map = column_map + + # First make a mutable copy of values + processed_values = list(values) + + # Apply output converters if available + if ( + hasattr(cursor.connection, "_output_converters") + and cursor.connection._output_converters + ): + processed_values = self._apply_output_converters(processed_values) # Process UUID values using the snapshotted setting self._values = self._process_uuid_values(processed_values, description) @@ -138,30 +154,90 @@ def _process_uuid_values( return processed_values - def _apply_output_converters(self, values, converter_map): + def _apply_output_converters(self, values: List[Any]) -> List[Any]: """ - Apply output converters using pre-computed converter map for optimal performance. - + Apply output converters to raw values. 
+ Args: values: Raw values from the database - converter_map: Pre-computed list of converters (one per column, None if no converter) - + Returns: List of converted values """ + if not self._description: + return values + converted_values = list(values) - - for i, (value, converter) in enumerate(zip(values, converter_map)): - if converter and value is not None: + + # Map SQL type codes to appropriate byte sizes + int_size_map = { + # SQL_TINYINT + ConstantsDDBC.SQL_TINYINT.value: 1, + # SQL_SMALLINT + ConstantsDDBC.SQL_SMALLINT.value: 2, + # SQL_INTEGER + ConstantsDDBC.SQL_INTEGER.value: 4, + # SQL_BIGINT + ConstantsDDBC.SQL_BIGINT.value: 8, + } + + for i, (value, desc) in enumerate(zip(values, self._description)): + if desc is None or value is None: + continue + + # Get SQL type from description + sql_type = desc[1] # type_code is at index 1 in description tuple + + # Try to get a converter for this type + converter = self._cursor.connection.get_output_converter(sql_type) + + # If no converter found for the SQL type but the value is a string or bytes, + # try the WVARCHAR converter as a fallback + if converter is None and isinstance(value, (str, bytes)): + converter = self._cursor.connection.get_output_converter( + ConstantsDDBC.SQL_WVARCHAR.value + ) + + # If we found a converter, apply it + if converter: try: + # If value is already a Python type (str, int, etc.), + # we need to handle it appropriately if isinstance(value, str): - value_bytes = value.encode('utf-16-le') + # Encode as UTF-16LE for string values (SQL_WVARCHAR format) + value_bytes = value.encode("utf-16-le") converted_values[i] = converter(value_bytes) + elif isinstance(value, int): + # Get appropriate byte size for this integer type + byte_size = int_size_map.get(sql_type, 8) + try: + # Use signed=True to properly handle negative values + value_bytes = value.to_bytes( + byte_size, byteorder="little", signed=True + ) + converted_values[i] = converter(value_bytes) + except OverflowError: + # Log specific overflow error with details to help diagnose the issue + if hasattr(self._cursor, "log"): + self._cursor.log( + "warning", + f"Integer overflow: value {value} does not fit in " + f"{byte_size} bytes for SQL type {sql_type}", + ) + # Keep the original value in this case else: + # Pass the value directly for other types converted_values[i] = converter(value) - except Exception: - pass - + except Exception as e: + # Log the exception for debugging without leaking sensitive data + if hasattr(self._cursor, "log"): + self._cursor.log( + "warning", + f"Exception in output converter: {type(e).__name__} " + f"for SQL type {sql_type}", + ) + # If conversion fails, keep the original value + return converted_values def __getitem__(self, index: int) -> Any: @@ -210,22 +286,19 @@ def __iter__(self) -> Any: def __str__(self) -> str: """Return string representation of the row""" + # Local import to avoid circular dependency + from mssql_python import getDecimalSeparator parts = [] for value in self: if isinstance(value, decimal.Decimal): - try: - # Apply custom decimal separator for display with safety checks - # Local import to avoid circular dependency - from mssql_python import getDecimalSeparator - sep = getDecimalSeparator() - if sep and sep != "." and value is not None: - s = str(value) - if "." in s: - s = s.replace(".", sep) - parts.append(s) - else: - parts.append(str(value)) - except Exception: + # Apply custom decimal separator for display + sep = getDecimalSeparator() + if sep != "." 
and value is not None: + s = str(value) + if "." in s: + s = s.replace(".", sep) + parts.append(s) + else: parts.append(str(value)) else: parts.append(repr(value)) @@ -234,4 +307,4 @@ def __str__(self) -> str: def __repr__(self) -> str: """Return a detailed string representation for debugging""" - return repr(tuple(self._values)) + return repr(tuple(self._values)) \ No newline at end of file diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index 0d3b5dc3..e3cc9c82 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -5994,8 +5994,8 @@ def test_cursor_skip_integration_with_fetch_methods(cursor, db_connection): rows = cursor.fetchmany(2) assert [r[0] for r in rows] == [ - 5, 6, + 7, ], "After fetchmany(2) and skip(3), should get ids matching implementation" # Test with fetchall @@ -14766,9 +14766,7 @@ def test_row_uuid_processing_exception_handling(cursor, db_connection): # Create Row directly with the data and modified description # This should trigger exception handling in lines 101-102 and 116-117 - row = Row(list(row_data), cursor, modified_description) - - # The invalid GUID should be kept as original value due to exception handling + row = Row(cursor, modified_description, list(row_data)) # The invalid GUID should be kept as original value due to exception handling # Lines 101-102: except (ValueError, AttributeError): pass # Keep original if conversion fails # Lines 116-117: except (ValueError, AttributeError): pass assert row[0] == 1, "ID should remain unchanged" @@ -15009,9 +15007,7 @@ def test_row_uuid_attribute_error_handling(cursor, db_connection): # Create Row directly with the data and modified description # This should trigger AttributeError handling in lines 101-102 and 116-117 - row = Row(list(row_data), cursor, modified_description) - - # The integer value should be kept as original due to AttributeError handling + row = Row(cursor, modified_description, list(row_data)) # The integer value should be kept as original due to AttributeError handling # Lines 101-102: except (ValueError, AttributeError): pass # Keep original if conversion fails # Lines 116-117: except (ValueError, AttributeError): pass assert ( From c9197d14a9b9bf560ecdb9908bc4d175e0ff8970 Mon Sep 17 00:00:00 2001 From: gargsaumya Date: Fri, 31 Oct 2025 12:59:30 +0530 Subject: [PATCH 4/4] optimized cached column map --- mssql_python/cursor.py | 137 ++++++++++++++++++--- mssql_python/pybind/ddbc_bindings.cpp | 2 - mssql_python/row.py | 166 ++++++++++++++++++-------- 3 files changed, 236 insertions(+), 69 deletions(-) diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index e647d4c5..a88cd659 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -574,6 +574,12 @@ def _reset_cursor(self) -> None: log("debug", "SQLFreeHandle succeeded") self._clear_rownumber() + + # Clear pre-computed metadata + self._column_name_map = None + self._settings_snapshot = None + self._uuid_indices = None + self._converter_map = None # Reinitialize the statement handle self._initialize_cursor() @@ -822,6 +828,65 @@ def _initialize_description(self, column_metadata: Optional[Any] = None) -> None ) ) self.description = description + + # Pre-compute shared metadata for Row optimization + self._precompute_row_metadata() + + def _precompute_row_metadata(self) -> None: + """ + Pre-compute metadata shared across all Row instances for performance. + This avoids expensive per-row computations in Row.__init__. 
+ """ + if not self.description: + self._column_name_map = None + self._settings_snapshot = None + self._uuid_indices = None + self._converter_map = None + return + + # Pre-compute settings snapshot + settings = get_settings() + self._settings_snapshot = { + "lowercase": settings.lowercase, + "native_uuid": settings.native_uuid, + } + + # Pre-compute column name to index mapping + self._column_name_map = {} + self._uuid_indices = [] + self._converter_map = {} # Column index -> converter function + + for i, col_desc in enumerate(self.description): + if col_desc: # Ensure column description exists + col_name = col_desc[0] # Name is first item in description tuple + if self._settings_snapshot.get("lowercase"): + col_name = col_name.lower() + self._column_name_map[col_name] = i + + # Pre-identify UUID columns (SQL_GUID = -11) + if len(col_desc) > 1 and col_desc[1] == -11: + self._uuid_indices.append(i) + + # Pre-compute output converters for each column + if len(col_desc) > 1: + sql_type = col_desc[1] # type_code is at index 1 + + # Check if we have output converters configured + if (hasattr(self.connection, "_output_converters") + and self.connection._output_converters): + + converter = self.connection.get_output_converter(sql_type) + + # If no converter found but it might be string/bytes, try WVARCHAR + if converter is None: + converter = self.connection.get_output_converter(-9) # SQL_WVARCHAR + + # Store converter if found + if converter is not None: + self._converter_map[i] = { + 'converter': converter, + 'sql_type': sql_type + } def _map_data_type(self, sql_type: int) -> type: """ @@ -1972,7 +2037,8 @@ def fetchone(self) -> Union[None, Row]: # Create and return a Row object, passing column name map if available column_map = getattr(self, "_column_name_map", None) settings_snapshot = getattr(self, "_settings_snapshot", None) - return Row(self, self.description, row_data, column_map, settings_snapshot) + converter_map = getattr(self, "_converter_map", None) + return Row(self, self.description, row_data, column_map, settings_snapshot, converter_map) except Exception as e: # pylint: disable=broad-exception-caught # On error, don't increment rownumber - rethrow the error raise e @@ -2017,13 +2083,19 @@ def fetchmany(self, size: Optional[int] = None) -> List[Row]: else: self.rowcount = self._next_row_index - # Convert raw data to Row objects - column_map = getattr(self, "_column_name_map", None) - settings_snapshot = getattr(self, "_settings_snapshot", None) - return [ - Row(self, self.description, row_data, column_map, settings_snapshot) - for row_data in rows_data - ] + # Convert raw data to Row objects using pre-computed metadata + if rows_data: + column_map = getattr(self, "_column_name_map", None) + settings_snapshot = getattr(self, "_settings_snapshot", None) + converter_map = getattr(self, "_converter_map", None) + + # Batch create Row objects with optimized metadata + return [ + Row(self, self.description, row_data, column_map, settings_snapshot, converter_map) + for row_data in rows_data + ] + else: + return [] except Exception as e: # pylint: disable=broad-exception-caught # On error, don't increment rownumber - rethrow the error raise e @@ -2058,17 +2130,52 @@ def fetchall(self) -> List[Row]: else: self.rowcount = self._next_row_index - # Convert raw data to Row objects - column_map = getattr(self, "_column_name_map", None) - settings_snapshot = getattr(self, "_settings_snapshot", None) - return [ - Row(self, self.description, row_data, column_map, settings_snapshot) - for row_data 
in rows_data - ] + # Convert raw data to Row objects using pre-computed metadata + if rows_data: + column_map = getattr(self, "_column_name_map", None) + settings_snapshot = getattr(self, "_settings_snapshot", None) + + # Get pre-computed converter map + converter_map = getattr(self, "_converter_map", None) + + # Use optimized Row creation for large datasets + if len(rows_data) > 10000: + return self._create_rows_optimized(rows_data, column_map, settings_snapshot, converter_map) + else: + # Regular path for smaller datasets + return [ + Row(self, self.description, row_data, column_map, settings_snapshot, converter_map) + for row_data in rows_data + ] + else: + return [] except Exception as e: # pylint: disable=broad-exception-caught # On error, don't increment rownumber - rethrow the error raise e + def _create_rows_optimized(self, rows_data, column_map, settings_snapshot, converter_map): + """ + Optimized Row creation for large datasets using batch processing. + """ + # For very large datasets, minimize object creation overhead + Row_class = Row + description = self.description + cursor_ref = self + + # Use more efficient approach for very large datasets + if len(rows_data) > 50000: + # Pre-allocate result list to avoid multiple reallocations + result = [None] * len(rows_data) + for i, row_data in enumerate(rows_data): + result[i] = Row_class(cursor_ref, description, row_data, column_map, settings_snapshot, converter_map) + return result + else: + # Standard list comprehension for medium-large datasets + return [ + Row_class(cursor_ref, description, row_data, column_map, settings_snapshot, converter_map) + for row_data in rows_data + ] + def nextset(self) -> Union[bool, None]: """ Skip to the next available result set. diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index f46abc5e..ee25fcd0 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -2530,7 +2530,6 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p // Cache decimal separator to avoid repeated system calls static const std::string defaultSeparator = "."; std::string decimalSeparator = GetDecimalSeparator(); - bool isDefaultDecimalSeparator = (decimalSeparator == defaultSeparator); for (SQLSMALLINT i = 1; i <= colCount; ++i) { SQLWCHAR columnName[256]; @@ -3197,7 +3196,6 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum static const std::string defaultSeparator = "."; std::string decimalSeparator = GetDecimalSeparator(); // Cache decimal separator - bool isDefaultDecimalSeparator = (decimalSeparator == defaultSeparator); size_t initialSize = rows.size(); for (SQLULEN i = 0; i < numRowsFetched; i++) { diff --git a/mssql_python/row.py b/mssql_python/row.py index 1ce1902d..4d874394 100644 --- a/mssql_python/row.py +++ b/mssql_python/row.py @@ -37,6 +37,7 @@ def __init__( values: List[Any], column_map: Optional[Dict[str, int]] = None, settings_snapshot: Optional[Dict[str, Any]] = None, + converter_map: Optional[Dict[int, Dict[str, Any]]] = None, ) -> None: """ Initialize a Row object with values and description. 
@@ -47,37 +48,26 @@ def __init__( values: List of values for this row column_map: Optional pre-built column map (for optimization) settings_snapshot: Settings snapshot from cursor to ensure consistency + converter_map: Optional pre-built converter map (for optimization) """ self._cursor = cursor self._description = description + self._converter_map = converter_map - # Use settings snapshot if provided, otherwise fallback to global settings - if settings_snapshot is not None: - self._settings = settings_snapshot - else: - settings = get_settings() - self._settings = { - "lowercase": settings.lowercase, - "native_uuid": settings.native_uuid, - } - # Create mapping of column names to indices first - # If column_map is not provided, build it from description - if column_map is None: - self._column_map = {} - for i, col_desc in enumerate(description): - if col_desc: # Ensure column description exists - col_name = col_desc[0] # Name is first item in description tuple - if self._settings.get("lowercase"): - col_name = col_name.lower() - self._column_map[col_name] = i - else: - self._column_map = column_map + # Use pre-computed settings and column map for performance + self._settings = settings_snapshot or { + "lowercase": get_settings().lowercase, + "native_uuid": get_settings().native_uuid, + } + self._column_map = column_map or self._build_column_map(description) # First make a mutable copy of values processed_values = list(values) - # Apply output converters if available - if ( + # Apply output converters if available (use shared converter map for efficiency) + if converter_map: + processed_values = self._apply_output_converters_optimized(processed_values) + elif ( hasattr(cursor.connection, "_output_converters") and cursor.connection._output_converters ): @@ -86,6 +76,17 @@ def __init__( # Process UUID values using the snapshotted setting self._values = self._process_uuid_values(processed_values, description) + def _build_column_map(self, description): + """Build column name to index mapping (fallback when not pre-computed).""" + column_map = {} + for i, col_desc in enumerate(description): + if col_desc: # Ensure column description exists + col_name = col_desc[0] # Name is first item in description tuple + if self._settings.get("lowercase"): + col_name = col_name.lower() + column_map[col_name] = i + return column_map + def _process_uuid_values( self, values: List[Any], @@ -115,45 +116,106 @@ def _process_uuid_values( # Get pre-identified UUID indices from cursor if available uuid_indices = getattr(self._cursor, "_uuid_indices", None) - processed_values = list(values) # Create a copy to modify - - # Process only UUID columns when native_uuid is True - if native_uuid: - # If we have pre-identified UUID columns - if uuid_indices is not None: - for i in uuid_indices: - if i < len(processed_values) and processed_values[i] is not None: - value = processed_values[i] + + # Fast path: use pre-computed UUID indices + if uuid_indices is not None and native_uuid: + processed_values = list(values) # Create a copy to modify + for i in uuid_indices: + if i < len(processed_values) and processed_values[i] is not None: + value = processed_values[i] + if isinstance(value, str): + try: + # Remove braces if present + clean_value = value.strip("{}") + processed_values[i] = uuid.UUID(clean_value) + except (ValueError, AttributeError): + pass # Keep original if conversion fails + # Slow path: scan all columns (fallback) + elif native_uuid: + processed_values = list(values) # Create a copy to modify + for i, value in 
enumerate(processed_values): + if value is None: + continue + + if i < len(description) and description[i]: + # Check SQL type for UNIQUEIDENTIFIER (-11) + sql_type = description[i][1] + if sql_type == -11: # SQL_GUID if isinstance(value, str): try: - # Remove braces if present - clean_value = value.strip("{}") - processed_values[i] = uuid.UUID(clean_value) + processed_values[i] = uuid.UUID(value.strip("{}")) except (ValueError, AttributeError): - pass # Keep original if conversion fails - # Fallback to scanning all columns if indices weren't pre-identified - else: - for i, value in enumerate(processed_values): - if value is None: - continue - - if i < len(description) and description[i]: - # Check SQL type for UNIQUEIDENTIFIER (-11) - sql_type = description[i][1] - if sql_type == -11: # SQL_GUID - if isinstance(value, str): - try: - processed_values[i] = uuid.UUID(value.strip("{}")) - except (ValueError, AttributeError): - pass - # When native_uuid is False, convert UUID objects to strings + pass else: + processed_values = list(values) # Create a copy to modify + + # When native_uuid is False, convert UUID objects to strings + if not native_uuid: for i, value in enumerate(processed_values): if isinstance(value, uuid.UUID): processed_values[i] = str(value) return processed_values + def _apply_output_converters_optimized(self, values: List[Any]) -> List[Any]: + """ + Apply pre-computed output converters using shared converter map for performance. + + Args: + values: Raw values from the database + + Returns: + List of converted values + """ + if not self._converter_map: + return values + + converted_values = list(values) + + # Map SQL type codes to appropriate byte sizes (cached for performance) + int_size_map = { + -6: 1, # SQL_TINYINT + 5: 2, # SQL_SMALLINT + 4: 4, # SQL_INTEGER + -5: 8, # SQL_BIGINT + } + + # Apply converters only to columns that have them pre-computed + for col_idx, converter_info in self._converter_map.items(): + if col_idx >= len(values) or values[col_idx] is None: + continue + + converter = converter_info['converter'] + sql_type = converter_info['sql_type'] + value = values[col_idx] + + try: + # Handle different value types efficiently + if isinstance(value, str): + # Encode as UTF-16LE for string values (SQL_WVARCHAR format) + value_bytes = value.encode("utf-16-le") + converted_values[col_idx] = converter(value_bytes) + elif isinstance(value, int): + # Get appropriate byte size for this integer type + byte_size = int_size_map.get(sql_type, 8) + try: + # Use signed=True to properly handle negative values + value_bytes = value.to_bytes( + byte_size, byteorder="little", signed=True + ) + converted_values[col_idx] = converter(value_bytes) + except OverflowError: + # Keep original value on overflow + pass + else: + # Pass the value directly for other types + converted_values[col_idx] = converter(value) + except Exception: + # If conversion fails, keep the original value + pass + + return converted_values + def _apply_output_converters(self, values: List[Any]) -> List[Any]: """ Apply output converters to raw values.