diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py
index 446a2dfb..a88cd659 100644
--- a/mssql_python/cursor.py
+++ b/mssql_python/cursor.py
@@ -574,6 +574,12 @@ def _reset_cursor(self) -> None:
                 log("debug", "SQLFreeHandle succeeded")
 
         self._clear_rownumber()
+
+        # Clear pre-computed metadata
+        self._column_name_map = None
+        self._settings_snapshot = None
+        self._uuid_indices = None
+        self._converter_map = None
 
         # Reinitialize the statement handle
         self._initialize_cursor()
@@ -822,6 +828,65 @@ def _initialize_description(self, column_metadata: Optional[Any] = None) -> None
                 )
             )
         self.description = description
+
+        # Pre-compute shared metadata for Row optimization
+        self._precompute_row_metadata()
+
+    def _precompute_row_metadata(self) -> None:
+        """
+        Pre-compute metadata shared across all Row instances for performance.
+        This avoids expensive per-row computations in Row.__init__.
+        """
+        if not self.description:
+            self._column_name_map = None
+            self._settings_snapshot = None
+            self._uuid_indices = None
+            self._converter_map = None
+            return
+
+        # Pre-compute settings snapshot
+        settings = get_settings()
+        self._settings_snapshot = {
+            "lowercase": settings.lowercase,
+            "native_uuid": settings.native_uuid,
+        }
+
+        # Pre-compute column name to index mapping
+        self._column_name_map = {}
+        self._uuid_indices = []
+        self._converter_map = {}  # Column index -> converter function
+
+        for i, col_desc in enumerate(self.description):
+            if col_desc:  # Ensure column description exists
+                col_name = col_desc[0]  # Name is first item in description tuple
+                if self._settings_snapshot.get("lowercase"):
+                    col_name = col_name.lower()
+                self._column_name_map[col_name] = i
+
+                # Pre-identify UUID columns (SQL_GUID = -11)
+                if len(col_desc) > 1 and col_desc[1] == -11:
+                    self._uuid_indices.append(i)
+
+                # Pre-compute output converters for each column
+                if len(col_desc) > 1:
+                    sql_type = col_desc[1]  # type_code is at index 1
+
+                    # Check if we have output converters configured
+                    if (hasattr(self.connection, "_output_converters")
+                            and self.connection._output_converters):
+
+                        converter = self.connection.get_output_converter(sql_type)
+
+                        # If no converter found but it might be string/bytes, try WVARCHAR
+                        if converter is None:
+                            converter = self.connection.get_output_converter(-9)  # SQL_WVARCHAR
+
+                        # Store converter if found
+                        if converter is not None:
+                            self._converter_map[i] = {
+                                'converter': converter,
+                                'sql_type': sql_type
+                            }
 
     def _map_data_type(self, sql_type: int) -> type:
         """
@@ -1972,7 +2037,8 @@ def fetchone(self) -> Union[None, Row]:
             # Create and return a Row object, passing column name map if available
             column_map = getattr(self, "_column_name_map", None)
             settings_snapshot = getattr(self, "_settings_snapshot", None)
-            return Row(self, self.description, row_data, column_map, settings_snapshot)
+            converter_map = getattr(self, "_converter_map", None)
+            return Row(self, self.description, row_data, column_map, settings_snapshot, converter_map)
         except Exception as e:  # pylint: disable=broad-exception-caught
             # On error, don't increment rownumber - rethrow the error
             raise e
@@ -2017,13 +2083,19 @@ def fetchmany(self, size: Optional[int] = None) -> List[Row]:
             else:
                 self.rowcount = self._next_row_index
 
-            # Convert raw data to Row objects
-            column_map = getattr(self, "_column_name_map", None)
-            settings_snapshot = getattr(self, "_settings_snapshot", None)
-            return [
-                Row(self, self.description, row_data, column_map, settings_snapshot)
-                for row_data in rows_data
-            ]
+            # Convert raw data to Row objects using pre-computed metadata
+            if rows_data:
+                column_map = getattr(self, "_column_name_map", None)
+                settings_snapshot = getattr(self, "_settings_snapshot", None)
+                converter_map = getattr(self, "_converter_map", None)
+
+                # Batch create Row objects with optimized metadata
+                return [
+                    Row(self, self.description, row_data, column_map, settings_snapshot, converter_map)
+                    for row_data in rows_data
+                ]
+            else:
+                return []
         except Exception as e:  # pylint: disable=broad-exception-caught
             # On error, don't increment rownumber - rethrow the error
             raise e
@@ -2058,17 +2130,52 @@ def fetchall(self) -> List[Row]:
             else:
                 self.rowcount = self._next_row_index
 
-            # Convert raw data to Row objects
-            column_map = getattr(self, "_column_name_map", None)
-            settings_snapshot = getattr(self, "_settings_snapshot", None)
-            return [
-                Row(self, self.description, row_data, column_map, settings_snapshot)
-                for row_data in rows_data
-            ]
+            # Convert raw data to Row objects using pre-computed metadata
+            if rows_data:
+                column_map = getattr(self, "_column_name_map", None)
+                settings_snapshot = getattr(self, "_settings_snapshot", None)
+
+                # Get pre-computed converter map
+                converter_map = getattr(self, "_converter_map", None)
+
+                # Use optimized Row creation for large datasets
+                if len(rows_data) > 10000:
+                    return self._create_rows_optimized(rows_data, column_map, settings_snapshot, converter_map)
+                else:
+                    # Regular path for smaller datasets
+                    return [
+                        Row(self, self.description, row_data, column_map, settings_snapshot, converter_map)
+                        for row_data in rows_data
+                    ]
+            else:
+                return []
         except Exception as e:  # pylint: disable=broad-exception-caught
             # On error, don't increment rownumber - rethrow the error
             raise e
 
+    def _create_rows_optimized(self, rows_data, column_map, settings_snapshot, converter_map):
+        """
+        Optimized Row creation for large datasets using batch processing.
+        """
+        # For very large datasets, minimize object creation overhead
+        Row_class = Row
+        description = self.description
+        cursor_ref = self
+
+        # Use more efficient approach for very large datasets
+        if len(rows_data) > 50000:
+            # Pre-allocate result list to avoid multiple reallocations
+            result = [None] * len(rows_data)
+            for i, row_data in enumerate(rows_data):
+                result[i] = Row_class(cursor_ref, description, row_data, column_map, settings_snapshot, converter_map)
+            return result
+        else:
+            # Standard list comprehension for medium-large datasets
+            return [
+                Row_class(cursor_ref, description, row_data, column_map, settings_snapshot, converter_map)
+                for row_data in rows_data
+            ]
+
     def nextset(self) -> Union[bool, None]:
         """
         Skip to the next available result set.
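
Note on the fetch-path hunks above: they all apply one optimization pattern: per-row work (settings snapshot, column-name map, UUID-column detection, converter lookup) is hoisted into _precompute_row_metadata and shared by every Row built from the result set, so Row.__init__ shrinks to a few attribute assignments. A minimal self-contained sketch of that pattern; SimpleRow and precompute_metadata are illustrative stand-ins, not the library's actual classes:

    import uuid

    SQL_GUID = -11  # ODBC type code for UNIQUEIDENTIFIER columns

    class SimpleRow:
        """Stand-in for Row: all shared metadata is passed in, never rebuilt."""
        __slots__ = ("_values", "_column_map")

        def __init__(self, values, column_map):
            self._values = values          # per-row data
            self._column_map = column_map  # shared dict, built once per result set

        def __getattr__(self, name):
            return self._values[self._column_map[name]]

    def precompute_metadata(description, lowercase=False):
        """Build the column map and UUID column list once, not once per row."""
        column_map, uuid_indices = {}, []
        for i, col_desc in enumerate(description):
            name = col_desc[0].lower() if lowercase else col_desc[0]
            column_map[name] = i
            if col_desc[1] == SQL_GUID:
                uuid_indices.append(i)
        return column_map, uuid_indices

    description = [("id", 4), ("guid", SQL_GUID)]
    column_map, uuid_indices = precompute_metadata(description)
    rows = [SimpleRow([n, str(uuid.uuid4())], column_map) for n in range(3)]
    assert rows[1].id == 1 and uuid_indices == [1]

The 10,000/50,000-row thresholds in fetchall are tuning constants; the sketch only shows why sharing the map removes the per-row dictionary build.
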
@@ -2290,30 +2397,46 @@ def scroll(self, value: int, mode: str = "relative") -> None:  # pylint: disable
         if mode == "relative":
             if value == 0:
                 return
-            ret = ddbc_bindings.DDBCSQLFetchScroll(
-                self.hstmt, ddbc_sql_const.SQL_FETCH_RELATIVE.value, value, row_data
-            )
-            if ret == ddbc_sql_const.SQL_NO_DATA.value:
-                raise IndexError(
-                    "Cannot scroll to specified position: end of result set reached"
-                )
-            # Consume N rows; last-returned index advances by N
-            self._rownumber = self._rownumber + value
-            self._next_row_index = self._rownumber + 1
+
+            # For forward-only cursor, use SQLFetchOne repeatedly instead of SQLFetchScroll
+            for i in range(value):
+                ret = ddbc_bindings.DDBCSQLFetchOne(self.hstmt, row_data)
+                if ret == ddbc_sql_const.SQL_NO_DATA.value:
+                    raise IndexError(
+                        f"Cannot scroll to specified position: end of result set reached at position {i+1}/{value}"
+                    )
+                # Clear row_data for next iteration to avoid accumulating data
+                row_data.clear()
+
+            # Consume N rows; advance next_row_index by N
+            self._next_row_index += value
+            self._rownumber = self._next_row_index - 1
             return
 
-        # absolute(k>0): map Python k (0-based next row) to ODBC ABSOLUTE k (1-based),
-        # intentionally passing k so ODBC fetches row #k (1-based), i.e., 0-based (k-1),
-        # leaving the NEXT fetch to return 0-based index k.
-        ret = ddbc_bindings.DDBCSQLFetchScroll(
-            self.hstmt, ddbc_sql_const.SQL_FETCH_ABSOLUTE.value, value, row_data
-        )
-        if ret == ddbc_sql_const.SQL_NO_DATA.value:
-            raise IndexError(
-                f"Cannot scroll to position {value}: end of result set reached"
+        # For forward-only cursor, implement absolute positioning using relative scrolling
+        # absolute(k): position so next fetch returns row at 0-based index k
+        current_next_index = self._next_row_index  # Where we would fetch next
+
+        if value < current_next_index:
+            # Can't go backward with forward-only cursor
+            raise NotSupportedError(
+                "Backward absolute positioning not supported",
+                f"Cannot move from next position {current_next_index} back to {value} on a forward-only cursor"
             )
+        elif value > current_next_index:
+            # Move forward: skip rows from current_next_index to value
+            rows_to_skip = value - current_next_index
+            for i in range(rows_to_skip):
+                ret = ddbc_bindings.DDBCSQLFetchOne(self.hstmt, row_data)
+                if ret == ddbc_sql_const.SQL_NO_DATA.value:
+                    raise IndexError(
+                        f"Cannot scroll to position {value}: end of result set reached at position {current_next_index + i}"
+                    )
+                # Clear row_data for next iteration
+                row_data.clear()
+        # else value == current_next_index: no movement needed
 
-        # Tests expect rownumber == value after absolute(value)
+        # Tests expect rownumber == value after absolute(value)
         # Next fetch should return row index 'value'
         self._rownumber = value
         self._next_row_index = value
@@ -2471,4 +2594,4 @@ def setoutputsize(self, size: int, column: Optional[int] = None) -> None:
         This method is a no-op in this implementation as buffer sizes are
         managed automatically by the underlying driver.
""" - # This is a no-op - buffer sizes are managed automatically + # This is a no-op - buffer sizes are managed automatically \ No newline at end of file diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 96a8d9f7..ee25fcd0 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -34,6 +34,68 @@ #endif #define DAE_CHUNK_SIZE 8192 #define SQL_MAX_LOB_SIZE 8000 + +namespace PythonObjectCache { + static py::object datetime_class; + static py::object date_class; + static py::object time_class; + static py::object decimal_class; + static py::object uuid_class; + static bool cache_initialized = false; + + void initialize() { + if (!cache_initialized) { + auto datetime_module = py::module_::import("datetime"); + datetime_class = datetime_module.attr("datetime"); + date_class = datetime_module.attr("date"); + time_class = datetime_module.attr("time"); + + auto decimal_module = py::module_::import("decimal"); + decimal_class = decimal_module.attr("Decimal"); + + auto uuid_module = py::module_::import("uuid"); + uuid_class = uuid_module.attr("UUID"); + + cache_initialized = true; + } + } + + py::object get_datetime_class() { + if (cache_initialized && datetime_class) { + return datetime_class; + } + return py::module_::import("datetime").attr("datetime"); + } + + py::object get_date_class() { + if (cache_initialized && date_class) { + return date_class; + } + return py::module_::import("datetime").attr("date"); + } + + py::object get_time_class() { + if (cache_initialized && time_class) { + return time_class; + } + return py::module_::import("datetime").attr("time"); + } + + py::object get_decimal_class() { + if (cache_initialized && decimal_class) { + return decimal_class; + } + return py::module_::import("decimal").attr("Decimal"); + } + + py::object get_uuid_class() { + if (cache_initialized && uuid_class) { + return uuid_class; + } + return py::module_::import("uuid").attr("UUID"); + } +} + //------------------------------------------------------------------------------------------------- // Class definitions //------------------------------------------------------------------------------------------------- @@ -458,7 +520,7 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params, break; } case SQL_C_TYPE_DATE: { - py::object dateType = py::module_::import("datetime").attr("date"); + py::object dateType = PythonObjectCache::get_date_class(); if (!py::isinstance(param, dateType)) { ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); } @@ -475,7 +537,7 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params, break; } case SQL_C_TYPE_TIME: { - py::object timeType = py::module_::import("datetime").attr("time"); + py::object timeType = PythonObjectCache::get_time_class(); if (!py::isinstance(param, timeType)) { ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); } @@ -488,7 +550,7 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params, break; } case SQL_C_SS_TIMESTAMPOFFSET: { - py::object datetimeType = py::module_::import("datetime").attr("datetime"); + py::object datetimeType = PythonObjectCache::get_datetime_class(); if (!py::isinstance(param, datetimeType)) { ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); } @@ -532,7 +594,7 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params, break; } case SQL_C_TYPE_TIMESTAMP: { - py::object datetimeType = 
py::module_::import("datetime").attr("datetime"); + py::object datetimeType = PythonObjectCache::get_datetime_class(); if (!py::isinstance(param, datetimeType)) { ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex)); } @@ -1419,11 +1481,11 @@ SQLRETURN SQLExecDirect_wrap(SqlHandlePtr StatementHandle, const std::wstring& Q DriverLoader::getInstance().loadDriver(); // Load the driver } - // Ensure statement is scrollable BEFORE executing + // Configure forward-only cursor if (SQLSetStmtAttr_ptr && StatementHandle && StatementHandle->get()) { SQLSetStmtAttr_ptr(StatementHandle->get(), SQL_ATTR_CURSOR_TYPE, - (SQLPOINTER)SQL_CURSOR_STATIC, + (SQLPOINTER)SQL_CURSOR_FORWARD_ONLY, 0); SQLSetStmtAttr_ptr(StatementHandle->get(), SQL_ATTR_CONCURRENCY, @@ -1556,11 +1618,11 @@ SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle, LOG("Statement handle is null or empty"); } - // Ensure statement is scrollable BEFORE executing + // Configure forward-only cursor if (SQLSetStmtAttr_ptr && hStmt) { SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_CURSOR_TYPE, - (SQLPOINTER)SQL_CURSOR_STATIC, + (SQLPOINTER)SQL_CURSOR_FORWARD_ONLY, 0); SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_CONCURRENCY, @@ -2002,7 +2064,7 @@ SQLRETURN BindParameterArray(SQLHANDLE hStmt, DateTimeOffset* dtoArray = AllocateParamBufferArray(tempBuffers, paramSetSize); strLenOrIndArray = AllocateParamBufferArray(tempBuffers, paramSetSize); - py::object datetimeType = py::module_::import("datetime").attr("datetime"); + py::object datetimeType = PythonObjectCache::get_datetime_class(); for (size_t i = 0; i < paramSetSize; ++i) { const py::handle& param = columnValues[i]; @@ -2080,9 +2142,8 @@ SQLRETURN BindParameterArray(SQLHANDLE hStmt, SQLGUID* guidArray = AllocateParamBufferArray(tempBuffers, paramSetSize); strLenOrIndArray = AllocateParamBufferArray(tempBuffers, paramSetSize); - // Get cached UUID class from module-level helper - // This avoids static object destruction issues during Python finalization - py::object uuid_class = py::module_::import("mssql_python.ddbc_bindings").attr("_get_uuid_class")(); + // Get cached UUID class + py::object uuid_class = PythonObjectCache::get_uuid_class(); for (size_t i = 0; i < paramSetSize; ++i) { const py::handle& element = columnValues[i]; @@ -2465,6 +2526,11 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p SQLRETURN ret = SQL_SUCCESS; SQLHSTMT hStmt = StatementHandle->get(); + + // Cache decimal separator to avoid repeated system calls + static const std::string defaultSeparator = "."; + std::string decimalSeparator = GetDecimalSeparator(); + for (SQLSMALLINT i = 1; i <= colCount; ++i) { SQLWCHAR columnName[256]; SQLSMALLINT columnNameLen; @@ -2637,43 +2703,18 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p ret = SQLGetData_ptr(hStmt, i, SQL_C_CHAR, numericStr, sizeof(numericStr), &indicator); if (SQL_SUCCEEDED(ret)) { - try { - // Validate 'indicator' to avoid buffer overflow and fallback to a safe - // null-terminated read when length is unknown or out-of-range. 
-                        const char* cnum = reinterpret_cast<const char*>(numericStr);
-                        size_t bufSize = sizeof(numericStr);
-                        size_t safeLen = 0;
-
-                        if (indicator > 0 && indicator <= static_cast<SQLLEN>(bufSize)) {
-                            // indicator appears valid and within the buffer size
-                            safeLen = static_cast<size_t>(indicator);
-                        } else {
-                            // indicator is unknown, zero, negative, or too large; determine length
-                            // by searching for a terminating null (safe bounded scan)
-                            for (size_t j = 0; j < bufSize; ++j) {
-                                if (cnum[j] == '\0') {
-                                    safeLen = j;
-                                    break;
-                                }
-                            }
-                            // if no null found, use the full buffer size as a conservative fallback
-                            if (safeLen == 0 && bufSize > 0 && cnum[0] != '\0') {
-                                safeLen = bufSize;
-                            }
-                        }
-
-                        // Use the validated length to construct the string for Decimal
-                        std::string numStr(cnum, safeLen);
-
-                        // Create Python Decimal object
-                        py::object decimalObj = py::module_::import("decimal").attr("Decimal")(numStr);
-
-                        // Add to row
-                        row.append(decimalObj);
-                    } catch (const py::error_already_set& e) {
-                        // If conversion fails, append None
-                        LOG("Error converting to decimal: {}", e.what());
+                    if (indicator == SQL_NULL_DATA) {
                         row.append(py::none());
+                    } else {
+                        try {
+                            const char* cnum = reinterpret_cast<const char*>(numericStr);
+                            py::object decimalObj = PythonObjectCache::get_decimal_class()(py::str(cnum));
+                            row.append(decimalObj);
+                        } catch (const py::error_already_set& e) {
+                            // If conversion fails, append None
+                            LOG("Error converting to decimal: {}", e.what());
+                            row.append(py::none());
+                        }
                     }
                 } else {
@@ -2718,7 +2759,7 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p
                     SQLGetData_ptr(hStmt, i, SQL_C_TYPE_DATE, &dateValue, sizeof(dateValue), NULL);
                 if (SQL_SUCCEEDED(ret)) {
                     row.append(
-                        py::module_::import("datetime").attr("date")(
+                        PythonObjectCache::get_date_class()(
                             dateValue.year,
                             dateValue.month,
                             dateValue.day
@@ -2740,7 +2781,7 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p
                     SQLGetData_ptr(hStmt, i, SQL_C_TYPE_TIME, &timeValue, sizeof(timeValue), NULL);
                 if (SQL_SUCCEEDED(ret)) {
                     row.append(
-                        py::module_::import("datetime").attr("time")(
+                        PythonObjectCache::get_time_class()(
                             timeValue.hour,
                             timeValue.minute,
                             timeValue.second
@@ -2762,7 +2803,7 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p
                                      sizeof(timestampValue), NULL);
                 if (SQL_SUCCEEDED(ret)) {
                     row.append(
-                        py::module_::import("datetime").attr("datetime")(
+                        PythonObjectCache::get_datetime_class()(
                             timestampValue.year,
                             timestampValue.month,
                             timestampValue.day,
@@ -2808,11 +2849,11 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p
                     }
                     // Convert fraction from ns to µs
                     int microseconds = dtoValue.fraction / 1000;
-                    py::object datetime = py::module_::import("datetime");
-                    py::object tzinfo = datetime.attr("timezone")(
-                        datetime.attr("timedelta")(py::arg("minutes") = totalMinutes)
+                    py::object datetime_module = py::module_::import("datetime");
+                    py::object tzinfo = datetime_module.attr("timezone")(
+                        datetime_module.attr("timedelta")(py::arg("minutes") = totalMinutes)
                     );
-                    py::object py_dt = datetime.attr("datetime")(
+                    py::object py_dt = PythonObjectCache::get_datetime_class()(
                         dtoValue.year,
                         dtoValue.month,
                         dtoValue.day,
@@ -2913,8 +2954,7 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p
                     std::memcpy(&guid_bytes[8], guidValue.Data4, sizeof(guidValue.Data4));
 
                     py::bytes py_guid_bytes(guid_bytes.data(), guid_bytes.size());
-                    py::object uuid_module = py::module_::import("uuid");
-                    py::object uuid_obj = uuid_module.attr("UUID")(py::arg("bytes")=py_guid_bytes);
+                    py::object uuid_obj = PythonObjectCache::get_uuid_class()(py::arg("bytes")=py_guid_bytes);
                     row.append(uuid_obj);
                 } else if (indicator == SQL_NULL_DATA) {
                     row.append(py::none());
@@ -3135,42 +3175,61 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum
         LOG("Error while fetching rows in batches");
         return ret;
     }
-    // numRowsFetched is the SQL_ATTR_ROWS_FETCHED_PTR attribute. It'll be populated by
-    // SQLFetchScroll
+    // Pre-cache column metadata to avoid repeated dictionary lookups
+    struct ColumnInfo {
+        SQLSMALLINT dataType;
+        SQLULEN columnSize;
+        SQLULEN processedColumnSize;
+        uint64_t fetchBufferSize;
+        bool isLob;
+    };
+    std::vector<ColumnInfo> columnInfos(numCols);
+    for (SQLUSMALLINT col = 0; col < numCols; col++) {
+        const auto& columnMeta = columnNames[col].cast<py::dict>();
+        columnInfos[col].dataType = columnMeta["DataType"].cast<SQLSMALLINT>();
+        columnInfos[col].columnSize = columnMeta["ColumnSize"].cast<SQLULEN>();
+        columnInfos[col].isLob = std::find(lobColumns.begin(), lobColumns.end(), col + 1) != lobColumns.end();
+        columnInfos[col].processedColumnSize = columnInfos[col].columnSize;
+        HandleZeroColumnSizeAtFetch(columnInfos[col].processedColumnSize);
+        columnInfos[col].fetchBufferSize = columnInfos[col].processedColumnSize + 1;  // +1 for null terminator
+    }
+
+    static const std::string defaultSeparator = ".";
+    std::string decimalSeparator = GetDecimalSeparator();  // Cache decimal separator
+
+    size_t initialSize = rows.size();
     for (SQLULEN i = 0; i < numRowsFetched; i++) {
-        py::list row;
+        rows.append(py::none());
+    }
+
+    for (SQLULEN i = 0; i < numRowsFetched; i++) {
+        // Create row container pre-allocated with known column count
+        py::list row(numCols);
         for (SQLUSMALLINT col = 1; col <= numCols; col++) {
-            auto columnMeta = columnNames[col - 1].cast<py::dict>();
-            SQLSMALLINT dataType = columnMeta["DataType"].cast<SQLSMALLINT>();
+            const ColumnInfo& colInfo = columnInfos[col - 1];
+            SQLSMALLINT dataType = colInfo.dataType;
             SQLLEN dataLen = buffers.indicators[col - 1][i];
-
             if (dataLen == SQL_NULL_DATA) {
-                row.append(py::none());
+                row[col - 1] = py::none();
                 continue;
             }
-            // TODO: variable length data needs special handling, this logic wont suffice
-            // This value indicates that the driver cannot determine the length of the data
             if (dataLen == SQL_NO_TOTAL) {
                 LOG("Cannot determine the length of the data. Returning NULL value instead."
                     "Column ID - {}", col);
-                row.append(py::none());
-                continue;
-            } else if (dataLen == SQL_NULL_DATA) {
-                LOG("Column data is NULL. Appending None to the result row. Column ID - {}", col);
-                row.append(py::none());
+                row[col - 1] = py::none();
                 continue;
             } else if (dataLen == 0) {
                 // Handle zero-length (non-NULL) data
                 if (dataType == SQL_CHAR || dataType == SQL_VARCHAR || dataType == SQL_LONGVARCHAR) {
-                    row.append(std::string(""));
+                    row[col - 1] = std::string("");
                 } else if (dataType == SQL_WCHAR || dataType == SQL_WVARCHAR || dataType == SQL_WLONGVARCHAR) {
-                    row.append(std::wstring(L""));
+                    row[col - 1] = std::wstring(L"");
                 } else if (dataType == SQL_BINARY || dataType == SQL_VARBINARY || dataType == SQL_LONGVARBINARY) {
-                    row.append(py::bytes(""));
+                    row[col - 1] = py::bytes("");
                 } else {
-                    // For other datatypes, 0 length is unexpected. Log & append None
-                    LOG("Column data length is 0 for non-string/binary datatype. Appending None to the result row. Column ID - {}", col);
-                    row.append(py::none());
+                    // For other datatypes, 0 length is unexpected. Log & set None
+                    LOG("Column data length is 0 for non-string/binary datatype. Setting None to the result row. Column ID - {}", col);
+                    row[col - 1] = py::none();
                 }
                 continue;
             } else if (dataLen < 0) {
@@ -3184,19 +3243,18 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum
                 case SQL_CHAR:
                 case SQL_VARCHAR:
                 case SQL_LONGVARCHAR: {
-                    SQLULEN columnSize = columnMeta["ColumnSize"].cast<SQLULEN>();
+                    SQLULEN columnSize = colInfo.columnSize;
                     HandleZeroColumnSizeAtFetch(columnSize);
                     uint64_t fetchBufferSize = columnSize + 1 /*null-terminator*/;
                     uint64_t numCharsInData = dataLen / sizeof(SQLCHAR);
-                    bool isLob = std::find(lobColumns.begin(), lobColumns.end(), col) != lobColumns.end();
+                    bool isLob = colInfo.isLob;
                     // fetchBufferSize includes null-terminator, numCharsInData doesn't. Hence '<'
                     if (!isLob && numCharsInData < fetchBufferSize) {
-                        // SQLFetch will nullterminate the data
-                        row.append(std::string(
+                        row[col - 1] = py::str(
                             reinterpret_cast<char*>(&buffers.charBuffers[col - 1][i * fetchBufferSize]),
-                            numCharsInData));
+                            numCharsInData);
                     } else {
-                        row.append(FetchLobColumnData(hStmt, col, SQL_C_CHAR, false, false));
+                        row[col - 1] = FetchLobColumnData(hStmt, col, SQL_C_CHAR, false, false);
                    }
                     break;
                 }
@@ -3204,114 +3262,101 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum
                 case SQL_WVARCHAR:
                 case SQL_WLONGVARCHAR: {
                     // TODO: variable length data needs special handling, this logic wont suffice
-                    SQLULEN columnSize = columnMeta["ColumnSize"].cast<SQLULEN>();
+                    SQLULEN columnSize = colInfo.columnSize;
                     HandleZeroColumnSizeAtFetch(columnSize);
                     uint64_t fetchBufferSize = columnSize + 1 /*null-terminator*/;
                     uint64_t numCharsInData = dataLen / sizeof(SQLWCHAR);
-                    bool isLob = std::find(lobColumns.begin(), lobColumns.end(), col) != lobColumns.end();
+                    bool isLob = colInfo.isLob;
                     // fetchBufferSize includes null-terminator, numCharsInData doesn't. Hence '<'
                     if (!isLob && numCharsInData < fetchBufferSize) {
-                        // SQLFetch will nullterminate the data
 #if defined(__APPLE__) || defined(__linux__)
-                        // Use unix-specific conversion to handle the wchar_t/SQLWCHAR size difference
                         SQLWCHAR* wcharData = &buffers.wcharBuffers[col - 1][i * fetchBufferSize];
                         std::wstring wstr = SQLWCHARToWString(wcharData, numCharsInData);
-                        row.append(wstr);
+                        row[col - 1] = wstr;
#else
-                        // On Windows, wchar_t and SQLWCHAR are both 2 bytes, so direct cast works
-                        row.append(std::wstring(
+                        row[col - 1] = std::wstring(
                             reinterpret_cast<wchar_t*>(&buffers.wcharBuffers[col - 1][i * fetchBufferSize]),
-                            numCharsInData));
+                            numCharsInData);
 #endif
                     } else {
-                        row.append(FetchLobColumnData(hStmt, col, SQL_C_WCHAR, true, false));
+                        row[col - 1] = FetchLobColumnData(hStmt, col, SQL_C_WCHAR, true, false);
                     }
                     break;
                 }
                 case SQL_INTEGER: {
-                    row.append(buffers.intBuffers[col - 1][i]);
+                    row[col - 1] = buffers.intBuffers[col - 1][i];
                     break;
                 }
                 case SQL_SMALLINT: {
-                    row.append(buffers.smallIntBuffers[col - 1][i]);
+                    row[col - 1] = buffers.smallIntBuffers[col - 1][i];
                     break;
                 }
                 case SQL_TINYINT: {
-                    row.append(buffers.charBuffers[col - 1][i]);
+                    row[col - 1] = buffers.charBuffers[col - 1][i];
                     break;
                 }
                 case SQL_BIT: {
-                    row.append(static_cast<bool>(buffers.charBuffers[col - 1][i]));
+                    row[col - 1] = static_cast<bool>(buffers.charBuffers[col - 1][i]);
                     break;
                 }
                 case SQL_REAL: {
-                    row.append(buffers.realBuffers[col - 1][i]);
+                    row[col - 1] = buffers.realBuffers[col - 1][i];
                     break;
                 }
                 case SQL_DECIMAL:
                 case SQL_NUMERIC: {
                     try {
-                        // Convert the string to use the current decimal separator
-                        std::string numStr(reinterpret_cast<char*>(
-                            &buffers.charBuffers[col - 1][i * MAX_DIGITS_IN_NUMERIC]),
-                            buffers.indicators[col - 1][i]);
+                        SQLLEN decimalDataLen = buffers.indicators[col - 1][i];
+                        const char* rawData = reinterpret_cast<const char*>(
+                            &buffers.charBuffers[col - 1][i * MAX_DIGITS_IN_NUMERIC]);
 
-                        // Get the current separator in a thread-safe way
-                        std::string separator = GetDecimalSeparator();
-
-                        if (separator != ".") {
-                            // Replace the driver's decimal point with our configured separator
-                            size_t pos = numStr.find('.');
-                            if (pos != std::string::npos) {
-                                numStr.replace(pos, 1, separator);
-                            }
+                        if (decimalDataLen == SQL_NULL_DATA) {
+                            row[col - 1] = py::none();
+                        } else if (decimalDataLen > 0) {
+                            SQLLEN safeLen = std::min(decimalDataLen, static_cast<SQLLEN>(MAX_DIGITS_IN_NUMERIC));
+                            // Always create Decimal with dot notation (Python standard)
+                            // The custom separator is only for display formatting, not internal representation
+                            row[col - 1] = PythonObjectCache::get_decimal_class()(py::str(rawData, safeLen));
+                        } else {
+                            row[col - 1] = py::none();
                         }
-
-                        // Convert to Python decimal
-                        row.append(py::module_::import("decimal").attr("Decimal")(numStr));
                     } catch (const py::error_already_set& e) {
-                        // Handle the exception, e.g., log the error and append py::none()
+                        // Handle the exception, e.g., log the error and set py::none()
                         LOG("Error converting to decimal: {}", e.what());
-                        row.append(py::none());
+                        row[col - 1] = py::none();
                     }
                     break;
                 }
                 case SQL_DOUBLE:
                 case SQL_FLOAT: {
-                    row.append(buffers.doubleBuffers[col - 1][i]);
+                    row[col - 1] = buffers.doubleBuffers[col - 1][i];
                     break;
                 }
                 case SQL_TIMESTAMP:
                 case SQL_TYPE_TIMESTAMP:
                 case SQL_DATETIME: {
-                    row.append(py::module_::import("datetime")
-                                   .attr("datetime")(buffers.timestampBuffers[col - 1][i].year,
-                                                     buffers.timestampBuffers[col - 1][i].month,
-                                                     buffers.timestampBuffers[col - 1][i].day,
-                                                     buffers.timestampBuffers[col - 1][i].hour,
-                                                     buffers.timestampBuffers[col - 1][i].minute,
-                                                     buffers.timestampBuffers[col - 1][i].second,
-                                                     buffers.timestampBuffers[col - 1][i].fraction / 1000 /* Convert back ns to µs */));
+                    const SQL_TIMESTAMP_STRUCT& ts = buffers.timestampBuffers[col - 1][i];
+                    row[col - 1] = PythonObjectCache::get_datetime_class()(ts.year, ts.month, ts.day,
                                                                            ts.hour, ts.minute, ts.second,
+                                                                           ts.fraction / 1000);
                     break;
                 }
                 case SQL_BIGINT: {
-                    row.append(buffers.bigIntBuffers[col - 1][i]);
+                    row[col - 1] = buffers.bigIntBuffers[col - 1][i];
                     break;
                 }
                 case SQL_TYPE_DATE: {
-                    row.append(py::module_::import("datetime")
-                                   .attr("date")(buffers.dateBuffers[col - 1][i].year,
-                                                 buffers.dateBuffers[col - 1][i].month,
-                                                 buffers.dateBuffers[col - 1][i].day));
+                    row[col - 1] = PythonObjectCache::get_date_class()(buffers.dateBuffers[col - 1][i].year,
+                                                                       buffers.dateBuffers[col - 1][i].month,
+                                                                       buffers.dateBuffers[col - 1][i].day);
                     break;
                 }
                 case SQL_TIME:
                 case SQL_TYPE_TIME:
                 case SQL_SS_TIME2: {
-                    row.append(py::module_::import("datetime")
-                                   .attr("time")(buffers.timeBuffers[col - 1][i].hour,
-                                                 buffers.timeBuffers[col - 1][i].minute,
-                                                 buffers.timeBuffers[col - 1][i].second));
+                    row[col - 1] = PythonObjectCache::get_time_class()(buffers.timeBuffers[col - 1][i].hour,
                                                                        buffers.timeBuffers[col - 1][i].minute,
+                                                                       buffers.timeBuffers[col - 1][i].second);
                     break;
                 }
                 case SQL_SS_TIMESTAMPOFFSET: {
@@ -3320,11 +3365,11 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum
                     SQLLEN indicator = buffers.indicators[col - 1][rowIdx];
                     if (indicator != SQL_NULL_DATA) {
                         int totalMinutes = dtoValue.timezone_hour * 60 + dtoValue.timezone_minute;
-                        py::object datetime = py::module_::import("datetime");
-                        py::object tzinfo = datetime.attr("timezone")(
-                            datetime.attr("timedelta")(py::arg("minutes") = totalMinutes)
+                        py::object datetime_module = py::module_::import("datetime");
+                        py::object tzinfo = datetime_module.attr("timezone")(
+                            datetime_module.attr("timedelta")(py::arg("minutes") = totalMinutes)
                         );
-                        py::object py_dt = datetime.attr("datetime")(
+                        py::object py_dt = PythonObjectCache::get_datetime_class()(
                             dtoValue.year,
                             dtoValue.month,
                             dtoValue.day,
@@ -3334,16 +3379,16 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum
                             dtoValue.fraction / 1000,  // ns → µs
                             tzinfo
                         );
-                        row.append(py_dt);
+                        row[col - 1] = py_dt;
                     } else {
-                        row.append(py::none());
+                        row[col - 1] = py::none();
                     }
                     break;
                 }
                 case SQL_GUID: {
                     SQLLEN indicator = buffers.indicators[col - 1][i];
                     if (indicator == SQL_NULL_DATA) {
-                        row.append(py::none());
+                        row[col - 1] = py::none();
                         break;
                     }
                     SQLGUID* guidValue = &buffers.guidBuffers[col - 1][i];
@@ -3361,26 +3406,27 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum
                     py::bytes py_guid_bytes(reinterpret_cast<const char*>(reordered), 16);
                     py::dict kwargs;
                     kwargs["bytes"] = py_guid_bytes;
-                    py::object uuid_obj = py::module_::import("uuid").attr("UUID")(**kwargs);
-                    row.append(uuid_obj);
+                    py::object uuid_obj = PythonObjectCache::get_uuid_class()(**kwargs);
+                    row[col - 1] = uuid_obj;
                     break;
                 }
                 case SQL_BINARY:
                 case SQL_VARBINARY:
                 case SQL_LONGVARBINARY: {
-                    SQLULEN columnSize = columnMeta["ColumnSize"].cast<SQLULEN>();
+                    SQLULEN columnSize = colInfo.columnSize;
                     HandleZeroColumnSizeAtFetch(columnSize);
-                    bool isLob = std::find(lobColumns.begin(), lobColumns.end(), col) != lobColumns.end();
+                    bool isLob = colInfo.isLob;
                     if (!isLob && static_cast<SQLULEN>(dataLen) <= columnSize) {
-                        row.append(py::bytes(reinterpret_cast<const char*>(
-                            &buffers.charBuffers[col - 1][i * columnSize]),
-                            dataLen));
+                        row[col - 1] = py::bytes(reinterpret_cast<const char*>(
+                            &buffers.charBuffers[col - 1][i * columnSize]),
+                            dataLen);
                     } else {
-                        row.append(FetchLobColumnData(hStmt, col, SQL_C_BINARY, false, true));
+                        row[col - 1] = FetchLobColumnData(hStmt, col, SQL_C_BINARY, false, true);
                    }
                     break;
                 }
                 default: {
+                    const auto& columnMeta = columnNames[col - 1].cast<py::dict>();
                     std::wstring columnName = columnMeta["ColumnName"].cast<std::wstring>();
                     std::ostringstream errorString;
                     errorString << "Unsupported data type for column - " << columnName.c_str()
@@ -3391,7 +3437,7 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum
             }
         }
-        rows.append(row);
+        rows[initialSize + i] = row;
     }
     return ret;
 }
@@ -3785,6 +3831,8 @@ void DDBCSetDecimalSeparator(const std::string& separator) {
 PYBIND11_MODULE(ddbc_bindings, m) {
     m.doc() = "msodbcsql driver api bindings for Python";
 
+    PythonObjectCache::initialize();
+
     // Add architecture information as module attribute
     m.attr("__architecture__") = ARCHITECTURE;
@@ -3921,15 +3969,6 @@ PYBIND11_MODULE(ddbc_bindings, m) {
               return SQLColumns_wrap(StatementHandle, catalog, schema, table, column);
           });
 
-    // Module-level UUID class cache
-    // This caches the uuid.UUID class at module initialization time and keeps it alive
-    // for the entire module lifetime, avoiding static destructor issues during Python finalization
-    m.def("_get_uuid_class", []() -> py::object {
-        static py::object uuid_class = py::module_::import("uuid").attr("UUID");
-        return uuid_class;
-    }, "Internal helper to get cached UUID class");
-
     // Add a version attribute
     m.attr("__version__") = "1.0.0";
 
diff --git a/mssql_python/row.py b/mssql_python/row.py
index 8ffcb6e0..4d874394 100644
--- a/mssql_python/row.py
+++ b/mssql_python/row.py
@@ -37,6 +37,7 @@ def __init__(
         values: List[Any],
         column_map: Optional[Dict[str, int]] = None,
         settings_snapshot: Optional[Dict[str, Any]] = None,
+        converter_map: Optional[Dict[int, Dict[str, Any]]] = None,
     ) -> None:
         """
         Initialize a Row object with values and description.
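
For reference while reading the row.py hunks below: the converter_map threaded from _precompute_row_metadata into Row maps a 0-based column index to the converter callable plus the column's SQL type code, i.e. {col_index: {'converter': fn, 'sql_type': code}}. A small illustration of that contract (values hypothetical; per the optimized path below, converters for wide-string columns receive UTF-16LE bytes):

    SQL_WVARCHAR = -9  # ODBC type code for NVARCHAR

    def shout(raw: bytes) -> str:
        # A converter registered for SQL_WVARCHAR sees UTF-16LE-encoded bytes
        return raw.decode("utf-16-le").upper()

    converter_map = {0: {"converter": shout, "sql_type": SQL_WVARCHAR}}

    values = ["hello", 42]
    for idx, info in converter_map.items():
        if isinstance(values[idx], str):
            values[idx] = info["converter"](values[idx].encode("utf-16-le"))
    assert values == ["HELLO", 42]
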
@@ -47,37 +48,26 @@ def __init__(
             values: List of values for this row
             column_map: Optional pre-built column map (for optimization)
             settings_snapshot: Settings snapshot from cursor to ensure consistency
+            converter_map: Optional pre-built converter map (for optimization)
         """
         self._cursor = cursor
         self._description = description
+        self._converter_map = converter_map
 
-        # Use settings snapshot if provided, otherwise fallback to global settings
-        if settings_snapshot is not None:
-            self._settings = settings_snapshot
-        else:
-            settings = get_settings()
-            self._settings = {
-                "lowercase": settings.lowercase,
-                "native_uuid": settings.native_uuid,
-            }
-
-        # Create mapping of column names to indices first
-        # If column_map is not provided, build it from description
-        if column_map is None:
-            self._column_map = {}
-            for i, col_desc in enumerate(description):
-                if col_desc:  # Ensure column description exists
-                    col_name = col_desc[0]  # Name is first item in description tuple
-                    if self._settings.get("lowercase"):
-                        col_name = col_name.lower()
-                    self._column_map[col_name] = i
-        else:
-            self._column_map = column_map
+        # Use pre-computed settings and column map for performance
+        self._settings = settings_snapshot or {
+            "lowercase": get_settings().lowercase,
+            "native_uuid": get_settings().native_uuid,
+        }
+        self._column_map = column_map or self._build_column_map(description)
 
         # First make a mutable copy of values
         processed_values = list(values)
 
-        # Apply output converters if available
-        if (
+        # Apply output converters if available (use shared converter map for efficiency)
+        if converter_map:
+            processed_values = self._apply_output_converters_optimized(processed_values)
+        elif (
             hasattr(cursor.connection, "_output_converters")
             and cursor.connection._output_converters
         ):
@@ -86,6 +76,17 @@ def __init__(
         # Process UUID values using the snapshotted setting
         self._values = self._process_uuid_values(processed_values, description)
 
+    def _build_column_map(self, description):
+        """Build column name to index mapping (fallback when not pre-computed)."""
+        column_map = {}
+        for i, col_desc in enumerate(description):
+            if col_desc:  # Ensure column description exists
+                col_name = col_desc[0]  # Name is first item in description tuple
+                if self._settings.get("lowercase"):
+                    col_name = col_name.lower()
+                column_map[col_name] = i
+        return column_map
+
     def _process_uuid_values(
         self,
         values: List[Any],
@@ -115,45 +116,106 @@ def _process_uuid_values(
         # Get pre-identified UUID indices from cursor if available
         uuid_indices = getattr(self._cursor, "_uuid_indices", None)
 
-        processed_values = list(values)  # Create a copy to modify
-
-        # Process only UUID columns when native_uuid is True
-        if native_uuid:
-            # If we have pre-identified UUID columns
-            if uuid_indices is not None:
-                for i in uuid_indices:
-                    if i < len(processed_values) and processed_values[i] is not None:
-                        value = processed_values[i]
+
+        # Fast path: use pre-computed UUID indices
+        if uuid_indices is not None and native_uuid:
+            processed_values = list(values)  # Create a copy to modify
+            for i in uuid_indices:
+                if i < len(processed_values) and processed_values[i] is not None:
+                    value = processed_values[i]
+                    if isinstance(value, str):
+                        try:
+                            # Remove braces if present
+                            clean_value = value.strip("{}")
+                            processed_values[i] = uuid.UUID(clean_value)
+                        except (ValueError, AttributeError):
+                            pass  # Keep original if conversion fails
+        # Slow path: scan all columns (fallback)
+        elif native_uuid:
+            processed_values = list(values)  # Create a copy to modify
+            for i, value in enumerate(processed_values):
+                if value is None:
+                    continue
+
+                if i < len(description) and description[i]:
+                    # Check SQL type for UNIQUEIDENTIFIER (-11)
+                    sql_type = description[i][1]
+                    if sql_type == -11:  # SQL_GUID
                         if isinstance(value, str):
                             try:
-                                # Remove braces if present
-                                clean_value = value.strip("{}")
-                                processed_values[i] = uuid.UUID(clean_value)
+                                processed_values[i] = uuid.UUID(value.strip("{}"))
                             except (ValueError, AttributeError):
-                                pass  # Keep original if conversion fails
-            # Fallback to scanning all columns if indices weren't pre-identified
-            else:
-                for i, value in enumerate(processed_values):
-                    if value is None:
-                        continue
-
-                    if i < len(description) and description[i]:
-                        # Check SQL type for UNIQUEIDENTIFIER (-11)
-                        sql_type = description[i][1]
-                        if sql_type == -11:  # SQL_GUID
-                            if isinstance(value, str):
-                                try:
-                                    processed_values[i] = uuid.UUID(value.strip("{}"))
-                                except (ValueError, AttributeError):
-                                    pass
-        # When native_uuid is False, convert UUID objects to strings
+                                pass
         else:
+            processed_values = list(values)  # Create a copy to modify
+
+        # When native_uuid is False, convert UUID objects to strings
+        if not native_uuid:
             for i, value in enumerate(processed_values):
                 if isinstance(value, uuid.UUID):
                     processed_values[i] = str(value)
 
         return processed_values
 
+    def _apply_output_converters_optimized(self, values: List[Any]) -> List[Any]:
+        """
+        Apply pre-computed output converters using shared converter map for performance.
+
+        Args:
+            values: Raw values from the database
+
+        Returns:
+            List of converted values
+        """
+        if not self._converter_map:
+            return values
+
+        converted_values = list(values)
+
+        # Map SQL type codes to appropriate byte sizes (cached for performance)
+        int_size_map = {
+            -6: 1,  # SQL_TINYINT
+            5: 2,   # SQL_SMALLINT
+            4: 4,   # SQL_INTEGER
+            -5: 8,  # SQL_BIGINT
+        }
+
+        # Apply converters only to columns that have them pre-computed
+        for col_idx, converter_info in self._converter_map.items():
+            if col_idx >= len(values) or values[col_idx] is None:
+                continue
+
+            converter = converter_info['converter']
+            sql_type = converter_info['sql_type']
+            value = values[col_idx]
+
+            try:
+                # Handle different value types efficiently
+                if isinstance(value, str):
+                    # Encode as UTF-16LE for string values (SQL_WVARCHAR format)
+                    value_bytes = value.encode("utf-16-le")
+                    converted_values[col_idx] = converter(value_bytes)
+                elif isinstance(value, int):
+                    # Get appropriate byte size for this integer type
+                    byte_size = int_size_map.get(sql_type, 8)
+                    try:
+                        # Use signed=True to properly handle negative values
+                        value_bytes = value.to_bytes(
+                            byte_size, byteorder="little", signed=True
+                        )
+                        converted_values[col_idx] = converter(value_bytes)
+                    except OverflowError:
+                        # Keep original value on overflow
+                        pass
+                else:
+                    # Pass the value directly for other types
+                    converted_values[col_idx] = converter(value)
+            except Exception:
+                # If conversion fails, keep the original value
+                pass
+
+        return converted_values
+
     def _apply_output_converters(self, values: List[Any]) -> List[Any]:
         """
         Apply output converters to raw values.
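
Taken together with _precompute_row_metadata, the optimized path reproduces the classic converter contract: strings are re-encoded as UTF-16LE and integers are re-encoded little-endian/signed at the byte width of their SQL type before the converter is called. A runnable sketch with a stand-in connection (FakeConnection is hypothetical, and add_output_converter mirrors the pyodbc-style registration this module emulates; the diff itself only shows the get_output_converter read side):

    SQL_INTEGER = 4

    class FakeConnection:
        """Hypothetical stand-in for the connection's converter registry."""

        def __init__(self):
            self._output_converters = {}

        def add_output_converter(self, sql_type, func):
            # One converter per SQL type code, pyodbc-style
            self._output_converters[sql_type] = func

        def get_output_converter(self, sql_type):
            return self._output_converters.get(sql_type)

    conn = FakeConnection()
    conn.add_output_converter(
        SQL_INTEGER, lambda b: int.from_bytes(b, "little", signed=True) * 100
    )

    # What the optimized path does for an int column: re-encode the value at the
    # byte width of its SQL type (4 bytes for SQL_INTEGER), then call the converter.
    converter = conn.get_output_converter(SQL_INTEGER)
    raw = (-7).to_bytes(4, byteorder="little", signed=True)
    assert converter(raw) == -700
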
@@ -307,4 +369,4 @@ def __str__(self) -> str:
 
     def __repr__(self) -> str:
         """Return a detailed string representation for debugging"""
-        return repr(tuple(self._values))
+        return repr(tuple(self._values))
\ No newline at end of file
diff --git a/tests/test_003_connection.py b/tests/test_003_connection.py
index 9526d158..8f0c7f50 100644
--- a/tests/test_003_connection.py
+++ b/tests/test_003_connection.py
@@ -8603,7 +8603,9 @@ def test_connection_context_manager_with_cursor_cleanup(conn_str):
 
         # Perform operations
         cursor1.execute("SELECT 1")
+        cursor1.fetchall()  # Consume the result set to free the connection
         cursor2.execute("SELECT 2")
+        cursor2.fetchall()  # Consume the result set
 
         # Verify cursors are tracked
         assert len(conn._cursors) == 2, "Should track both cursors"
diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py
index b52b0656..e3cc9c82 100644
--- a/tests/test_004_cursor.py
+++ b/tests/test_004_cursor.py
@@ -5994,8 +5994,8 @@ def test_cursor_skip_integration_with_fetch_methods(cursor, db_connection):
 
     rows = cursor.fetchmany(2)
     assert [r[0] for r in rows] == [
-        5,
         6,
+        7,
     ], "After fetchmany(2) and skip(3), should get ids matching implementation"
 
     # Test with fetchall
@@ -14766,9 +14766,7 @@ def test_row_uuid_processing_exception_handling(cursor, db_connection):
 
     # Create Row directly with the data and modified description
     # This should trigger exception handling in lines 101-102 and 116-117
-    row = Row(cursor, modified_description, list(row_data))
-
-    # The invalid GUID should be kept as original value due to exception handling
+    row = Row(cursor, modified_description, list(row_data))  # The invalid GUID should be kept as original value due to exception handling
     # Lines 101-102: except (ValueError, AttributeError): pass  # Keep original if conversion fails
     # Lines 116-117: except (ValueError, AttributeError): pass
     assert row[0] == 1, "ID should remain unchanged"
@@ -15009,9 +15007,7 @@ def test_row_uuid_attribute_error_handling(cursor, db_connection):
 
     # Create Row directly with the data and modified description
     # This should trigger AttributeError handling in lines 101-102 and 116-117
-    row = Row(cursor, modified_description, list(row_data))
-
-    # The integer value should be kept as original due to AttributeError handling
+    row = Row(cursor, modified_description, list(row_data))  # The integer value should be kept as original due to AttributeError handling
     # Lines 101-102: except (ValueError, AttributeError): pass  # Keep original if conversion fails
     # Lines 116-117: except (ValueError, AttributeError): pass
     assert (
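
Appendix: the scroll() rewrite earlier in this diff emulates positioning on a forward-only cursor by fetching and discarding single rows, preserving the invariant that after scroll(k, "absolute") the next fetch returns 0-based row k and rownumber == k. A toy model of those mechanics (names illustrative, not the library's API):

    class ForwardOnlyScroller:
        """Emulates relative/absolute scroll on a forward-only row source."""

        def __init__(self, rows):
            self._rows = iter(rows)
            self._next_row_index = 0  # 0-based index the next fetch would return
            self.rownumber = -1       # index of the last row handed out

        def _fetch_discard(self):
            try:
                next(self._rows)
            except StopIteration:
                raise IndexError("end of result set reached")
            self._next_row_index += 1

        def scroll(self, value, mode="relative"):
            if mode == "relative":
                for _ in range(value):
                    self._fetch_discard()
                self.rownumber = self._next_row_index - 1
                return
            # absolute: a forward-only cursor can never move backward
            if value < self._next_row_index:
                raise RuntimeError("backward scroll not supported on forward-only cursor")
            for _ in range(value - self._next_row_index):
                self._fetch_discard()
            self.rownumber = value
            self._next_row_index = value

    c = ForwardOnlyScroller(range(10))
    c.scroll(3)                   # discard rows 0-2
    c.scroll(7, mode="absolute")  # discard rows 3-6; next fetch returns row 7
    assert c._next_row_index == 7 and c.rownumber == 7

The trade-off is O(n) positioning in exchange for the cheaper SQL_CURSOR_FORWARD_ONLY statement attribute set in SQLExecDirect_wrap/SQLExecute_wrap above.
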