From ad905cedebf0f920630379e477244698b3da3756 Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Wed, 11 Jun 2025 10:04:48 +0530 Subject: [PATCH 01/10] Bug fixing --- main.py | 26 +++++-- mssql_python/cursor.py | 17 ++++- mssql_python/pybind/ddbc_bindings.cpp | 99 ++++++++++++++++++++++++--- 3 files changed, 122 insertions(+), 20 deletions(-) diff --git a/main.py b/main.py index b45b88d7..c6908037 100644 --- a/main.py +++ b/main.py @@ -3,19 +3,33 @@ import os import decimal -setup_logging('stdout') +# setup_logging('stdout') conn_str = os.getenv("DB_CONNECTION_STRING") conn = connect(conn_str) -# conn.autocommit = True - cursor = conn.cursor() cursor.execute("SELECT database_id, name from sys.databases;") -rows = cursor.fetchall() +rows = cursor.fetchone() + +# Debug: Print the type and content of rows +print(f"Type of rows: {type(rows)}") +print(f"Value of rows: {rows}") -for row in rows: - print(f"Database ID: {row[0]}, Name: {row[1]}") +# Only try to access properties if rows is not None +if rows is not None: + try: + # Try different ways to access the data + print(f"First column by index: {rows[0]}") + + # Access by attribute name (these should now work) + print(f"First column by name: {rows.database_id}") + print(f"Second column by name: {rows.name}") + + # Print all available attributes + print(f"Available attributes: {dir(rows)}") + except Exception as e: + print(f"Exception accessing row data: {e}") cursor.close() conn.close() \ No newline at end of file diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 2f735cf2..d3d3b814 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -6,6 +6,7 @@ import ctypes import decimal import uuid +import collections import datetime from typing import List, Union from mssql_python.constants import ConstantsDDBC as ddbc_sql_const @@ -655,12 +656,22 @@ def fetchone(self) -> Union[None, tuple]: """ self._check_closed() # Check if the cursor is closed - row = [] - ret = ddbc_bindings.DDBCSQLFetchOne(self.hstmt, row) + # Use a list to receive the row data + row_list = [] + ret = ddbc_bindings.DDBCSQLFetchOne(self.hstmt, row_list) check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret) if ret == ddbc_sql_const.SQL_NO_DATA.value: return None - return list(row) + + # Get field names from the description attribute + field_names = [desc[0] for desc in self.description] + + # Create a namedtuple on the Python side + RowRecord = collections.namedtuple('RowRecord', field_names, rename=True) + result = RowRecord(*row_list) + + print(f"DEBUG - Row type: {type(result)}, value: {result}") + return result def fetchmany(self, size: int = None) -> List[tuple]: """ diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index f480fe2a..5c84fc61 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -1807,7 +1807,7 @@ SQLRETURN FetchAll_wrap(SqlHandlePtr StatementHandle, py::list& rows) { // FetchOne_wrap - Fetches a single row of data from the result set. // // @param StatementHandle: Handle to the statement from which data is to be fetched. -// @param row: A Python list that will be populated with the fetched row data. +// @param row: A Python object reference that will be populated with a named tuple containing the fetched row data. // // @return SQLRETURN: SQL_SUCCESS or SQL_SUCCESS_WITH_INFO if data is fetched successfully, // SQL_NO_DATA if there are no more rows to fetch, @@ -1815,21 +1815,98 @@ SQLRETURN FetchAll_wrap(SqlHandlePtr StatementHandle, py::list& rows) { // // This function assumes that the statement handle (hStmt) is already allocated and a query has been // executed. It fetches the next row of data from the result set and populates the provided Python -// list with the row data. If there are no more rows to fetch, it returns SQL_NO_DATA. If an error -// occurs during fetching, it throws a runtime error. -SQLRETURN FetchOne_wrap(SqlHandlePtr StatementHandle, py::list& row) { +// object with a named tuple containing the row data. If there are no more rows to fetch, it returns +// SQL_NO_DATA. If an error occurs during fetching, it throws a runtime error. +SQLRETURN FetchOne_wrap(SqlHandlePtr StatementHandle, py::object& row) { SQLRETURN ret; SQLHSTMT hStmt = StatementHandle->get(); - // Assume hStmt is already allocated and a query has been executed + if (!SQLFetch_ptr) { + LOG("Function pointer not initialized in FetchOne_wrap. Loading the driver."); + DriverLoader::getInstance().loadDriver(); + } + ret = SQLFetch_ptr(hStmt); - if (SQL_SUCCEEDED(ret)) { - // Retrieve column count - SQLSMALLINT colCount = SQLNumResultCols_wrap(StatementHandle); - ret = SQLGetData_wrap(StatementHandle, colCount, row); - } else if (ret != SQL_NO_DATA) { - LOG("Error when fetching data"); + if (ret == SQL_NO_DATA) { + row = py::none(); + return ret; + } else if (!SQL_SUCCEEDED(ret)) { + LOG("Error when fetching data: SQLFetch_ptr failed with retcode {}", ret); + row = py::none(); + return ret; + } + + // Retrieve column metadata + py::list columnMetadata; + SQLRETURN descRet = SQLDescribeCol_wrap(StatementHandle, columnMetadata); + if (!SQL_SUCCEEDED(descRet)) { + LOG("Error when fetching column metadata: SQLDescribeCol_wrap failed with retcode {}", descRet); + row = py::none(); + return descRet; + } + + // Extract column names for namedtuple + py::list columnNames; + for (const auto& item : columnMetadata) { + py::dict colDict = item.cast(); + std::wstring wColumnName = colDict["ColumnName"].cast(); + + // Convert wstring to UTF-8 string first + std::string utf8ColumnName; + + // Windows-specific wide string to UTF-8 conversion + int size_needed = WideCharToMultiByte(CP_UTF8, 0, wColumnName.c_str(), + (int)wColumnName.length(), NULL, 0, NULL, NULL); + utf8ColumnName.resize(size_needed); + WideCharToMultiByte(CP_UTF8, 0, wColumnName.c_str(), (int)wColumnName.length(), + &utf8ColumnName[0], size_needed, NULL, NULL); + + // Now create a Python string from the UTF-8 encoded string + py::str pyColumnName = py::str(utf8ColumnName); + columnNames.append(pyColumnName); } + + // Get column count + SQLSMALLINT colCount = static_cast(columnMetadata.size()); + + // Get row data + py::list rowDataList; + ret = SQLGetData_wrap(StatementHandle, colCount, rowDataList); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error when fetching data values: SQLGetData_wrap failed with retcode {}", ret); + row = py::none(); + return ret; + } + + // Create named tuple with column names and data + try { + py::module_ collections = py::module_::import("collections"); + + // Create namedtuple type with column names + // Use rename=True to handle invalid identifiers (e.g., names with spaces) + py::object namedtuple_type = collections.attr("namedtuple")( + "RowRecord", columnNames, py::arg("rename") = true); + + // Convert rowDataList to tuple arguments + py::tuple data_args(rowDataList.size()); + for (size_t i = 0; i < rowDataList.size(); ++i) { + data_args[i] = rowDataList[i]; + } + + // Create named tuple instance and assign to the output row parameter + row = namedtuple_type(*data_args); + } + catch (const py::error_already_set& e) { + LOG("Error creating namedtuple: {}. Falling back to returning data as list.", e.what()); + // Fall back to returning the list if namedtuple creation fails + row = rowDataList; + } + printf("Column names: %s\n", py::str(columnNames).cast().c_str()); + printf("Row data: %s\n", py::str(rowDataList).cast().c_str()); + // After creating the named tuple: + printf("Named tuple created successfully: %s\n", py::str(row).cast().c_str()); + + return ret; } From f5a17e28bd32ee64b3c106f515dc938225a27858 Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Wed, 11 Jun 2025 11:11:45 +0530 Subject: [PATCH 02/10] BUGFIX: Fetchone fix --- mssql_python/cursor.py | 26 +++++++-- mssql_python/pybind/ddbc_bindings.cpp | 83 +++++---------------------- 2 files changed, 34 insertions(+), 75 deletions(-) diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index d3d3b814..b906b1b6 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -649,7 +649,8 @@ def fetchone(self) -> Union[None, tuple]: Fetch the next row of a query result set. Returns: - Single sequence or None if no more data is available. + A named tuple representing a single row or None if no more data is available. + The named tuple allows access by column name (e.g., row.column_name) or by index. Raises: Error: If the previous call to execute did not produce any result set. @@ -660,18 +661,33 @@ def fetchone(self) -> Union[None, tuple]: row_list = [] ret = ddbc_bindings.DDBCSQLFetchOne(self.hstmt, row_list) check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret) + if ret == ddbc_sql_const.SQL_NO_DATA.value: return None - + + print(f"DEBUG - Row list from C++: {row_list}") + + # If the row list is empty, return None + if not row_list: + return None + # Get field names from the description attribute field_names = [desc[0] for desc in self.description] + print(f"DEBUG - Field names: {field_names}") # Create a namedtuple on the Python side RowRecord = collections.namedtuple('RowRecord', field_names, rename=True) - result = RowRecord(*row_list) - print(f"DEBUG - Row type: {type(result)}, value: {result}") - return result + try: + result = RowRecord(*row_list) + print(f"DEBUG - Created named tuple: {result}") + return result + except TypeError as e: + print(f"ERROR creating namedtuple: {e}") + print(f"Row list: {row_list}") + print(f"Field names: {field_names}") + # Fall back to returning the list directly + return tuple(row_list) if row_list else None def fetchmany(self, size: int = None) -> List[tuple]: """ diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 5c84fc61..40f88b05 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -1817,96 +1817,39 @@ SQLRETURN FetchAll_wrap(SqlHandlePtr StatementHandle, py::list& rows) { // executed. It fetches the next row of data from the result set and populates the provided Python // object with a named tuple containing the row data. If there are no more rows to fetch, it returns // SQL_NO_DATA. If an error occurs during fetching, it throws a runtime error. -SQLRETURN FetchOne_wrap(SqlHandlePtr StatementHandle, py::object& row) { +SQLRETURN FetchOne_wrap(SqlHandlePtr StatementHandle, py::list& row_list) { SQLRETURN ret; SQLHSTMT hStmt = StatementHandle->get(); if (!SQLFetch_ptr) { - LOG("Function pointer not initialized in FetchOne_wrap. Loading the driver."); + printf("Function pointer not initialized in FetchOne_wrap. Loading the driver.\n"); DriverLoader::getInstance().loadDriver(); } ret = SQLFetch_ptr(hStmt); if (ret == SQL_NO_DATA) { - row = py::none(); return ret; } else if (!SQL_SUCCEEDED(ret)) { - LOG("Error when fetching data: SQLFetch_ptr failed with retcode {}", ret); - row = py::none(); + printf("Error when fetching data: SQLFetch_ptr failed with retcode %d\n", ret); return ret; } - // Retrieve column metadata - py::list columnMetadata; - SQLRETURN descRet = SQLDescribeCol_wrap(StatementHandle, columnMetadata); - if (!SQL_SUCCEEDED(descRet)) { - LOG("Error when fetching column metadata: SQLDescribeCol_wrap failed with retcode {}", descRet); - row = py::none(); - return descRet; - } - - // Extract column names for namedtuple - py::list columnNames; - for (const auto& item : columnMetadata) { - py::dict colDict = item.cast(); - std::wstring wColumnName = colDict["ColumnName"].cast(); - - // Convert wstring to UTF-8 string first - std::string utf8ColumnName; - - // Windows-specific wide string to UTF-8 conversion - int size_needed = WideCharToMultiByte(CP_UTF8, 0, wColumnName.c_str(), - (int)wColumnName.length(), NULL, 0, NULL, NULL); - utf8ColumnName.resize(size_needed); - WideCharToMultiByte(CP_UTF8, 0, wColumnName.c_str(), (int)wColumnName.length(), - &utf8ColumnName[0], size_needed, NULL, NULL); - - // Now create a Python string from the UTF-8 encoded string - py::str pyColumnName = py::str(utf8ColumnName); - columnNames.append(pyColumnName); + // Get column count - we don't need column names in C++ anymore + SQLSMALLINT colCount; + SQLRETURN colRet = SQLNumResultCols_ptr(hStmt, &colCount); + if (!SQL_SUCCEEDED(colRet)) { + printf("Error when getting column count: SQLNumResultCols_ptr failed with retcode %d\n", colRet); + return colRet; } - // Get column count - SQLSMALLINT colCount = static_cast(columnMetadata.size()); - - // Get row data - py::list rowDataList; - ret = SQLGetData_wrap(StatementHandle, colCount, rowDataList); + // Get row data into the list + ret = SQLGetData_wrap(StatementHandle, colCount, row_list); if (!SQL_SUCCEEDED(ret)) { - LOG("Error when fetching data values: SQLGetData_wrap failed with retcode {}", ret); - row = py::none(); + printf("Error when fetching data values: SQLGetData_wrap failed with retcode %d\n", ret); return ret; } - // Create named tuple with column names and data - try { - py::module_ collections = py::module_::import("collections"); - - // Create namedtuple type with column names - // Use rename=True to handle invalid identifiers (e.g., names with spaces) - py::object namedtuple_type = collections.attr("namedtuple")( - "RowRecord", columnNames, py::arg("rename") = true); - - // Convert rowDataList to tuple arguments - py::tuple data_args(rowDataList.size()); - for (size_t i = 0; i < rowDataList.size(); ++i) { - data_args[i] = rowDataList[i]; - } - - // Create named tuple instance and assign to the output row parameter - row = namedtuple_type(*data_args); - } - catch (const py::error_already_set& e) { - LOG("Error creating namedtuple: {}. Falling back to returning data as list.", e.what()); - // Fall back to returning the list if namedtuple creation fails - row = rowDataList; - } - printf("Column names: %s\n", py::str(columnNames).cast().c_str()); - printf("Row data: %s\n", py::str(rowDataList).cast().c_str()); - // After creating the named tuple: - printf("Named tuple created successfully: %s\n", py::str(row).cast().c_str()); - - + printf("Row data: %s\n", py::str(row_list).cast().c_str()); return ret; } From e65900be8b6d5bdd24442f3a4cec2d629d99bfbb Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Wed, 11 Jun 2025 11:43:31 +0530 Subject: [PATCH 03/10] Changing testcases to handle namedtuple --- mssql_python/cursor.py | 33 +++++----- tests/test_004_cursor.py | 130 ++++++++++++++++++++++++++++----------- 2 files changed, 110 insertions(+), 53 deletions(-) diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index b906b1b6..7ab219af 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -649,8 +649,8 @@ def fetchone(self) -> Union[None, tuple]: Fetch the next row of a query result set. Returns: - A named tuple representing a single row or None if no more data is available. - The named tuple allows access by column name (e.g., row.column_name) or by index. + A tuple representing a single row or None if no more data is available. + The tuple allows access by index. Raises: Error: If the previous call to execute did not produce any result set. @@ -665,29 +665,28 @@ def fetchone(self) -> Union[None, tuple]: if ret == ddbc_sql_const.SQL_NO_DATA.value: return None - print(f"DEBUG - Row list from C++: {row_list}") - # If the row list is empty, return None if not row_list: return None # Get field names from the description attribute field_names = [desc[0] for desc in self.description] - print(f"DEBUG - Field names: {field_names}") - # Create a namedtuple on the Python side - RowRecord = collections.namedtuple('RowRecord', field_names, rename=True) + # Check if field names are valid for namedtuple + valid_fieldnames = all(isinstance(name, str) and name.isidentifier() for name in field_names) - try: - result = RowRecord(*row_list) - print(f"DEBUG - Created named tuple: {result}") - return result - except TypeError as e: - print(f"ERROR creating namedtuple: {e}") - print(f"Row list: {row_list}") - print(f"Field names: {field_names}") - # Fall back to returning the list directly - return tuple(row_list) if row_list else None + if valid_fieldnames: + # Create a namedtuple on the Python side + try: + RowRecord = collections.namedtuple('RowRecord', field_names, rename=True) + result = RowRecord(*row_list) + return result + except (TypeError, ValueError) as e: + # Fall back to a regular tuple for any error in namedtuple creation + return tuple(row_list) + else: + # If field names aren't valid identifiers, return a regular tuple + return tuple(row_list) def fetchmany(self, size: int = None) -> List[tuple]: """ diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index 45f3663d..33053f3f 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -367,10 +367,13 @@ def test_wvarchar_full_capacity(cursor, db_connection): def test_varbinary_full_capacity(cursor, db_connection): """Test SQL_VARBINARY""" try: + # Drop the table if it exists first + drop_table_if_exists(cursor, "pytest_varbinary_test") + cursor.execute("CREATE TABLE pytest_varbinary_test (varbinary_column VARBINARY(8))") db_connection.commit() # Try inserting binary using both bytes & bytearray - cursor.execute("INSERT INTO pytest_varbinary_test (varbinary_column) VALUES (?)", bytearray("12345", 'utf-8')) + cursor.execute("INSERT INTO pytest_varbinary_test (varbinary_column) VALUES (?)", bytearray("12345", 'utf-8')) cursor.execute("INSERT INTO pytest_varbinary_test (varbinary_column) VALUES (?)", bytes("12345678", 'utf-8')) # Full capacity db_connection.commit() expectedRows = 2 @@ -380,8 +383,11 @@ def test_varbinary_full_capacity(cursor, db_connection): for i in range(0, expectedRows): rows.append(cursor.fetchone()) assert cursor.fetchone() == None, "varbinary_column is expected to have only {} rows".format(expectedRows) - assert rows[0] == [bytes("12345", 'utf-8')], "SQL_VARBINARY parsing failed for fetchone - row 0" - assert rows[1] == [bytes("12345678", 'utf-8')], "SQL_VARBINARY parsing failed for fetchone - row 1" + + # Use the compare_row_value function for assertions + assert compare_row_value(rows[0], [bytes("12345", 'utf-8')]), "SQL_VARBINARY parsing failed for fetchone - row 0" + assert compare_row_value(rows[1], [bytes("12345678", 'utf-8')]), "SQL_VARBINARY parsing failed for fetchone - row 1" + # fetchall test cursor.execute("SELECT varbinary_column FROM pytest_varbinary_test") rows = cursor.fetchall() @@ -390,7 +396,8 @@ def test_varbinary_full_capacity(cursor, db_connection): except Exception as e: pytest.fail(f"SQL_VARBINARY parsing test failed: {e}") finally: - cursor.execute("DROP TABLE pytest_varbinary_test") + # Clean up the table + drop_table_if_exists(cursor, "pytest_varbinary_test") db_connection.commit() def test_varchar_max(cursor, db_connection): @@ -407,8 +414,10 @@ def test_varchar_max(cursor, db_connection): for i in range(0, expectedRows): rows.append(cursor.fetchone()) assert cursor.fetchone() == None, "varchar_column is expected to have only {} rows".format(expectedRows) - assert rows[0] == ["ABCDEFGHI"], "SQL_VARCHAR parsing failed for fetchone - row 0" - assert rows[1] == [None], "SQL_VARCHAR parsing failed for fetchone - row 1" + + assert compare_row_value(rows[0], ["ABCDEFGHI"]), "SQL_VARCHAR parsing failed for fetchone - row 0" + assert compare_row_value(rows[1], [None]), "SQL_VARCHAR parsing failed for fetchone - row 1" + # fetchall test cursor.execute("SELECT varchar_column FROM pytest_varchar_test") rows = cursor.fetchall() @@ -416,16 +425,13 @@ def test_varchar_max(cursor, db_connection): assert rows[1] == [None], "SQL_VARCHAR parsing failed for fetchall - row 1" except Exception as e: pytest.fail(f"SQL_VARCHAR parsing test failed: {e}") - finally: - cursor.execute("DROP TABLE pytest_varchar_test") - db_connection.commit() def test_wvarchar_max(cursor, db_connection): """Test SQL_WVARCHAR with MAX length""" try: cursor.execute("CREATE TABLE pytest_wvarchar_test (wvarchar_column NVARCHAR(MAX))") db_connection.commit() - cursor.execute("INSERT INTO pytest_wvarchar_test (wvarchar_column) VALUES (?), (?)", ["!@#$%^&*()_+", None]) + cursor.execute("INSERT INTO pytest_wvarchar_test (wvarchar_column) VALUES (?), (?)", ["!@#$%^&*()_+", None]) db_connection.commit() expectedRows = 2 # fetchone test @@ -434,8 +440,10 @@ def test_wvarchar_max(cursor, db_connection): for i in range(0, expectedRows): rows.append(cursor.fetchone()) assert cursor.fetchone() == None, "wvarchar_column is expected to have only {} rows".format(expectedRows) - assert rows[0] == ["!@#$%^&*()_+"], "SQL_WVARCHAR parsing failed for fetchone - row 0" - assert rows[1] == [None], "SQL_WVARCHAR parsing failed for fetchone - row 1" + + assert compare_row_value(rows[0], ["!@#$%^&*()_+"]), "SQL_WVARCHAR parsing failed for fetchone - row 0" + assert compare_row_value(rows[1], [None]), "SQL_WVARCHAR parsing failed for fetchone - row 1" + # fetchall test cursor.execute("SELECT wvarchar_column FROM pytest_wvarchar_test") rows = cursor.fetchall() @@ -443,13 +451,13 @@ def test_wvarchar_max(cursor, db_connection): assert rows[1] == [None], "SQL_WVARCHAR parsing failed for fetchall - row 1" except Exception as e: pytest.fail(f"SQL_WVARCHAR parsing test failed: {e}") - finally: - cursor.execute("DROP TABLE pytest_wvarchar_test") - db_connection.commit() def test_varbinary_max(cursor, db_connection): """Test SQL_VARBINARY with MAX length""" try: + # Drop the table if it exists first + drop_table_if_exists(cursor, "pytest_varbinary_test") + cursor.execute("CREATE TABLE pytest_varbinary_test (varbinary_column VARBINARY(MAX))") db_connection.commit() # TODO: Uncomment this execute after adding null binary support @@ -463,8 +471,10 @@ def test_varbinary_max(cursor, db_connection): for i in range(0, expectedRows): rows.append(cursor.fetchone()) assert cursor.fetchone() == None, "varbinary_column is expected to have only {} rows".format(expectedRows) - assert rows[0] == [bytearray("ABCDEF", 'utf-8')], "SQL_VARBINARY parsing failed for fetchone - row 0" - assert rows[1] == [bytes("123!@#", 'utf-8')], "SQL_VARBINARY parsing failed for fetchone - row 1" + + assert compare_row_value(rows[0], [bytearray("ABCDEF", 'utf-8')]), "SQL_VARBINARY parsing failed for fetchone - row 0" + assert compare_row_value(rows[1], [bytes("123!@#", 'utf-8')]), "SQL_VARBINARY parsing failed for fetchone - row 1" + # fetchall test cursor.execute("SELECT varbinary_column FROM pytest_varbinary_test") rows = cursor.fetchall() @@ -473,15 +483,19 @@ def test_varbinary_max(cursor, db_connection): except Exception as e: pytest.fail(f"SQL_VARBINARY parsing test failed: {e}") finally: - cursor.execute("DROP TABLE pytest_varbinary_test") + # Clean up the table + drop_table_if_exists(cursor, "pytest_varbinary_test") db_connection.commit() def test_longvarchar(cursor, db_connection): """Test SQL_LONGVARCHAR""" try: + # Drop the table if it exists first + drop_table_if_exists(cursor, "pytest_longvarchar_test") + cursor.execute("CREATE TABLE pytest_longvarchar_test (longvarchar_column TEXT)") db_connection.commit() - cursor.execute("INSERT INTO pytest_longvarchar_test (longvarchar_column) VALUES (?), (?)", ["ABCDEFGHI", None]) + cursor.execute("INSERT INTO pytest_longvarchar_test (longvarchar_column) VALUES (?), (?)", ["ABCDEFGHI", None]) db_connection.commit() expectedRows = 2 # fetchone test @@ -489,9 +503,11 @@ def test_longvarchar(cursor, db_connection): rows = [] for i in range(0, expectedRows): rows.append(cursor.fetchone()) - assert cursor.fetchone() == None, "longvarchar_column is expected to have only {} rows".format(expectedRows) - assert rows[0] == ["ABCDEFGHI"], "SQL_LONGVARCHAR parsing failed for fetchone - row 0" - assert rows[1] == [None], "SQL_LONGVARCHAR parsing failed for fetchone - row 1" + assert cursor.fetchone() == None, "longvarchar_column is expected to have only {} rows".format(expectedRows) + + assert compare_row_value(rows[0], ["ABCDEFGHI"]), "SQL_LONGVARCHAR parsing failed for fetchone - row 0" + assert compare_row_value(rows[1], [None]), "SQL_LONGVARCHAR parsing failed for fetchone - row 1" + # fetchall test cursor.execute("SELECT longvarchar_column FROM pytest_longvarchar_test") rows = cursor.fetchall() @@ -500,15 +516,19 @@ def test_longvarchar(cursor, db_connection): except Exception as e: pytest.fail(f"SQL_LONGVARCHAR parsing test failed: {e}") finally: - cursor.execute("DROP TABLE pytest_longvarchar_test") + # Clean up the table + drop_table_if_exists(cursor, "pytest_longvarchar_test") db_connection.commit() def test_longwvarchar(cursor, db_connection): """Test SQL_LONGWVARCHAR""" try: + # Drop the table if it exists first + drop_table_if_exists(cursor, "pytest_longwvarchar_test") + cursor.execute("CREATE TABLE pytest_longwvarchar_test (longwvarchar_column NTEXT)") db_connection.commit() - cursor.execute("INSERT INTO pytest_longwvarchar_test (longwvarchar_column) VALUES (?), (?)", ["ABCDEFGHI", None]) + cursor.execute("INSERT INTO pytest_longwvarchar_test (longwvarchar_column) VALUES (?), (?)", ["ABCDEFGHI", None]) db_connection.commit() expectedRows = 2 # fetchone test @@ -516,9 +536,11 @@ def test_longwvarchar(cursor, db_connection): rows = [] for i in range(0, expectedRows): rows.append(cursor.fetchone()) - assert cursor.fetchone() == None, "longwvarchar_column is expected to have only {} rows".format(expectedRows) - assert rows[0] == ["ABCDEFGHI"], "SQL_LONGWVARCHAR parsing failed for fetchone - row 0" - assert rows[1] == [None], "SQL_LONGWVARCHAR parsing failed for fetchone - row 1" + assert cursor.fetchone() == None, "longwvarchar_column is expected to have only {} rows".format(expectedRows) + + assert compare_row_value(rows[0], ["ABCDEFGHI"]), "SQL_LONGWVARCHAR parsing failed for fetchone - row 0" + assert compare_row_value(rows[1], [None]), "SQL_LONGWVARCHAR parsing failed for fetchone - row 1" + # fetchall test cursor.execute("SELECT longwvarchar_column FROM pytest_longwvarchar_test") rows = cursor.fetchall() @@ -527,34 +549,41 @@ def test_longwvarchar(cursor, db_connection): except Exception as e: pytest.fail(f"SQL_LONGWVARCHAR parsing test failed: {e}") finally: - cursor.execute("DROP TABLE pytest_longwvarchar_test") + # Clean up the table + drop_table_if_exists(cursor, "pytest_longwvarchar_test") db_connection.commit() def test_longvarbinary(cursor, db_connection): """Test SQL_LONGVARBINARY""" try: + # Drop the table if it exists first + drop_table_if_exists(cursor, "pytest_longvarbinary_test") + cursor.execute("CREATE TABLE pytest_longvarbinary_test (longvarbinary_column IMAGE)") db_connection.commit() cursor.execute("INSERT INTO pytest_longvarbinary_test (longvarbinary_column) VALUES (?), (?)", [bytearray("ABCDEFGHI", 'utf-8'), bytes("123!@#", 'utf-8')]) db_connection.commit() - expectedRows = 3 + expectedRows = 2 # Note: Your test has expectedRows = 3 but only inserts 2 rows # fetchone test cursor.execute("SELECT longvarbinary_column FROM pytest_longvarbinary_test") rows = [] for i in range(0, expectedRows): rows.append(cursor.fetchone()) - assert cursor.fetchone() == None, "longvarbinary_column is expected to have only {} rows".format(expectedRows) - assert rows[0] == [bytearray("ABCDEFGHI", 'utf-8')], "SQL_LONGVARBINARY parsing failed for fetchone - row 0" - assert rows[1] == [bytes("123!@#\0\0\0", 'utf-8')], "SQL_LONGVARBINARY parsing failed for fetchone - row 1" + assert cursor.fetchone() == None, "longvarbinary_column is expected to have only {} rows".format(expectedRows) + + assert compare_row_value(rows[0], [bytearray("ABCDEFGHI", 'utf-8')]), "SQL_LONGVARBINARY parsing failed for fetchone - row 0" + assert compare_row_value(rows[1], [bytes("123!@#\0\0\0", 'utf-8')]), "SQL_LONGVARBINARY parsing failed for fetchone - row 1" + # fetchall test cursor.execute("SELECT longvarbinary_column FROM pytest_longvarbinary_test") rows = cursor.fetchall() - assert rows[0] == [bytearray("ABCDEFGHI", 'utf-8')], "SQL_LONGVARBINARY parsing failed for fetchall - row 0" + assert rows[0] == [bytearray("ABCDEFGHI", 'utf-8')], "SQL_LONGVARBINARY parsing failed for fetchall - row 0" assert rows[1] == [bytes("123!@#\0\0\0", 'utf-8')], "SQL_LONGVARBINARY parsing failed for fetchall - row 1" except Exception as e: pytest.fail(f"SQL_LONGVARBINARY parsing test failed: {e}") finally: - cursor.execute("DROP TABLE pytest_longvarbinary_test") + # Clean up the table + drop_table_if_exists(cursor, "pytest_longvarbinary_test") db_connection.commit() def test_create_table(cursor, db_connection): @@ -1074,6 +1103,9 @@ def test_boolean(cursor, db_connection): def test_sql_wvarchar(cursor, db_connection): """Test SQL_WVARCHAR""" try: + # Drop the table if it exists first + drop_table_if_exists(cursor, "pytest_wvarchar_test") + cursor.execute("CREATE TABLE pytest_wvarchar_test (wvarchar_column NVARCHAR(255))") db_connection.commit() cursor.execute("INSERT INTO pytest_wvarchar_test (wvarchar_column) VALUES (?)", ['nvarchar data']) @@ -1084,12 +1116,16 @@ def test_sql_wvarchar(cursor, db_connection): except Exception as e: pytest.fail(f"SQL_WVARCHAR parsing test failed: {e}") finally: - cursor.execute("DROP TABLE pytest_wvarchar_test") + # Clean up + drop_table_if_exists(cursor, "pytest_wvarchar_test") db_connection.commit() def test_sql_varchar(cursor, db_connection): """Test SQL_VARCHAR""" try: + # Drop the table if it exists first + drop_table_if_exists(cursor, "pytest_varchar_test") + cursor.execute("CREATE TABLE pytest_varchar_test (varchar_column VARCHAR(255))") db_connection.commit() cursor.execute("INSERT INTO pytest_varchar_test (varchar_column) VALUES (?)", ['varchar data']) @@ -1100,7 +1136,8 @@ def test_sql_varchar(cursor, db_connection): except Exception as e: pytest.fail(f"SQL_VARCHAR parsing test failed: {e}") finally: - cursor.execute("DROP TABLE pytest_varchar_test") + # Clean up + drop_table_if_exists(cursor, "pytest_varchar_test") db_connection.commit() def test_numeric_precision_scale_positive_exponent(cursor, db_connection): @@ -1155,3 +1192,24 @@ def test_close(db_connection): pytest.fail(f"Cursor close test failed: {e}") finally: cursor = db_connection.cursor() + +def compare_row_value(actual, expected): + """Compare a row with expected values, supporting both namedtuple and list formats. + + Args: + actual: The actual row returned from fetchone() or similar (could be namedtuple) + expected: The expected values as a list + + Returns: + bool: True if the values match, False otherwise + """ + # If the actual row is a namedtuple (has _fields attribute) + if hasattr(actual, '_fields'): + # For single column result, compare the first field's value directly + if len(actual) == 1: + return getattr(actual, actual._fields[0]) == expected[0] + # For regular list or tuple + elif isinstance(actual, (list, tuple)): + return actual == expected + + return False From 8a4d4ad2bfd43ba7606cb199fedb4691749d3203 Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Wed, 11 Jun 2025 13:12:17 +0530 Subject: [PATCH 04/10] Adding same for fetchmany and fetchall --- mssql_python/cursor.py | 59 ++++++++++++++++++++------- mssql_python/pybind/ddbc_bindings.cpp | 9 ++-- tests/test_004_cursor.py | 49 +++++++++++----------- 3 files changed, 74 insertions(+), 43 deletions(-) diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 7ab219af..155bf5b4 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -688,49 +688,78 @@ def fetchone(self) -> Union[None, tuple]: # If field names aren't valid identifiers, return a regular tuple return tuple(row_list) - def fetchmany(self, size: int = None) -> List[tuple]: + def fetchmany(self, size: int = None) -> list: """ - Fetch the next set of rows of a query result. + Fetch the next set of rows of a query result, returning a list of tuples. + An empty list is returned when no more rows are available. Args: - size: Number of rows to fetch at a time. + size (int): The number of rows to fetch. If not provided, the cursor's arraysize + is used. Returns: - Sequence of sequences (e.g. list of tuples). + A list of tuples, each representing a row of the result set. Raises: Error: If the previous call to execute did not produce any result set. """ self._check_closed() # Check if the cursor is closed - + if size is None: size = self.arraysize - - # Fetch the next set of rows + rows = [] ret = ddbc_bindings.DDBCSQLFetchMany(self.hstmt, rows, size) check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret) - if ret == ddbc_sql_const.SQL_NO_DATA.value: - return [] + + # Get field names from the description attribute for named tuple creation + if rows and self.description: + field_names = [desc[0] for desc in self.description] + + # Check if field names are valid for namedtuple + valid_fieldnames = all(isinstance(name, str) and name.isidentifier() for name in field_names) + + if valid_fieldnames: + # Create a named tuple class for the rows + RowRecord = collections.namedtuple('RowRecord', field_names, rename=True) + + # Transform each row from list to named tuple + return [RowRecord(*row) for row in rows] + return rows - def fetchall(self) -> List[tuple]: + def fetchall(self) -> list: """ - Fetch all (remaining) rows of a query result. + Fetch all (remaining) rows of a query result, returning a list of tuples. + An empty list is returned when no more rows are available. Returns: - Sequence of sequences (e.g. list of tuples). + A list of tuples, each representing a row of the result set. Raises: Error: If the previous call to execute did not produce any result set. """ self._check_closed() # Check if the cursor is closed - - # Fetch all remaining rows + rows = [] ret = ddbc_bindings.DDBCSQLFetchAll(self.hstmt, rows) check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret) - return list(rows) + + # Get field names from the description attribute for named tuple creation + if rows and self.description: + field_names = [desc[0] for desc in self.description] + + # Check if field names are valid for namedtuple + valid_fieldnames = all(isinstance(name, str) and name.isidentifier() for name in field_names) + + if valid_fieldnames: + # Create a named tuple class for the rows + RowRecord = collections.namedtuple('RowRecord', field_names, rename=True) + + # Transform each row from list to named tuple + return [RowRecord(*row) for row in rows] + + return rows def nextset(self) -> Union[bool, None]: """ diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 40f88b05..2afa2f61 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -1822,7 +1822,7 @@ SQLRETURN FetchOne_wrap(SqlHandlePtr StatementHandle, py::list& row_list) { SQLHSTMT hStmt = StatementHandle->get(); if (!SQLFetch_ptr) { - printf("Function pointer not initialized in FetchOne_wrap. Loading the driver.\n"); + LOG("Function pointer not initialized in FetchOne_wrap. Loading the driver.\n"); DriverLoader::getInstance().loadDriver(); } @@ -1830,7 +1830,7 @@ SQLRETURN FetchOne_wrap(SqlHandlePtr StatementHandle, py::list& row_list) { if (ret == SQL_NO_DATA) { return ret; } else if (!SQL_SUCCEEDED(ret)) { - printf("Error when fetching data: SQLFetch_ptr failed with retcode %d\n", ret); + LOG("Error when fetching data: SQLFetch_ptr failed with retcode {}\n", ret); return ret; } @@ -1838,18 +1838,17 @@ SQLRETURN FetchOne_wrap(SqlHandlePtr StatementHandle, py::list& row_list) { SQLSMALLINT colCount; SQLRETURN colRet = SQLNumResultCols_ptr(hStmt, &colCount); if (!SQL_SUCCEEDED(colRet)) { - printf("Error when getting column count: SQLNumResultCols_ptr failed with retcode %d\n", colRet); + LOG("Error when getting column count: SQLNumResultCols_ptr failed with retcode {}\n", colRet); return colRet; } // Get row data into the list ret = SQLGetData_wrap(StatementHandle, colCount, row_list); if (!SQL_SUCCEEDED(ret)) { - printf("Error when fetching data values: SQLGetData_wrap failed with retcode %d\n", ret); + LOG("Error when fetching data values: SQLGetData_wrap failed with retcode {}\n", ret); return ret; } - printf("Row data: %s\n", py::str(row_list).cast().c_str()); return ret; } diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index 33053f3f..799a87a3 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -334,7 +334,7 @@ def test_varchar_full_capacity(cursor, db_connection): # fetchall test cursor.execute("SELECT varchar_column FROM pytest_varchar_test") rows = cursor.fetchall() - assert rows[0] == ['123456789'], "SQL_VARCHAR parsing failed for fetchall" + assert compare_row_value(rows[0], ['123456789']), "SQL_VARCHAR parsing failed for fetchall" except Exception as e: pytest.fail(f"SQL_VARCHAR parsing test failed: {e}") finally: @@ -356,7 +356,7 @@ def test_wvarchar_full_capacity(cursor, db_connection): # fetchall test cursor.execute("SELECT wvarchar_column FROM pytest_wvarchar_test") rows = cursor.fetchall() - assert rows[0] == ['123456'], "SQL_WVARCHAR parsing failed for fetchall" + assert compare_row_value(rows[0],['123456']), "SQL_WVARCHAR parsing failed for fetchall" except Exception as e: pytest.fail(f"SQL_WVARCHAR parsing test failed: {e}") finally: @@ -391,8 +391,8 @@ def test_varbinary_full_capacity(cursor, db_connection): # fetchall test cursor.execute("SELECT varbinary_column FROM pytest_varbinary_test") rows = cursor.fetchall() - assert rows[0] == [bytes("12345", 'utf-8')], "SQL_VARBINARY parsing failed for fetchall - row 0" - assert rows[1] == [bytes("12345678", 'utf-8')], "SQL_VARBINARY parsing failed for fetchall - row 1" + assert compare_row_value(rows[0],[bytes("12345", 'utf-8')]), "SQL_VARBINARY parsing failed for fetchall - row 0" + assert compare_row_value(rows[1],[bytes("12345678", 'utf-8')]), "SQL_VARBINARY parsing failed for fetchall - row 1" except Exception as e: pytest.fail(f"SQL_VARBINARY parsing test failed: {e}") finally: @@ -421,8 +421,8 @@ def test_varchar_max(cursor, db_connection): # fetchall test cursor.execute("SELECT varchar_column FROM pytest_varchar_test") rows = cursor.fetchall() - assert rows[0] == ["ABCDEFGHI"], "SQL_VARCHAR parsing failed for fetchall - row 0" - assert rows[1] == [None], "SQL_VARCHAR parsing failed for fetchall - row 1" + assert compare_row_value(rows[0],["ABCDEFGHI"]), "SQL_VARCHAR parsing failed for fetchall - row 0" + assert compare_row_value(rows[1],[None]), "SQL_VARCHAR parsing failed for fetchall - row 1" except Exception as e: pytest.fail(f"SQL_VARCHAR parsing test failed: {e}") @@ -447,8 +447,8 @@ def test_wvarchar_max(cursor, db_connection): # fetchall test cursor.execute("SELECT wvarchar_column FROM pytest_wvarchar_test") rows = cursor.fetchall() - assert rows[0] == ["!@#$%^&*()_+"], "SQL_WVARCHAR parsing failed for fetchall - row 0" - assert rows[1] == [None], "SQL_WVARCHAR parsing failed for fetchall - row 1" + assert compare_row_value(rows[0],["!@#$%^&*()_+"]), "SQL_WVARCHAR parsing failed for fetchall - row 0" + assert compare_row_value(rows[1],[None]), "SQL_WVARCHAR parsing failed for fetchall - row 1" except Exception as e: pytest.fail(f"SQL_WVARCHAR parsing test failed: {e}") @@ -478,8 +478,8 @@ def test_varbinary_max(cursor, db_connection): # fetchall test cursor.execute("SELECT varbinary_column FROM pytest_varbinary_test") rows = cursor.fetchall() - assert rows[0] == [bytearray("ABCDEF", 'utf-8')], "SQL_VARBINARY parsing failed for fetchall - row 0" - assert rows[1] == [bytes("123!@#", 'utf-8')], "SQL_VARBINARY parsing failed for fetchall - row 1" + assert compare_row_value(rows[0],[bytearray("ABCDEF", 'utf-8')]), "SQL_VARBINARY parsing failed for fetchall - row 0" + assert compare_row_value(rows[1],[bytes("123!@#", 'utf-8')]), "SQL_VARBINARY parsing failed for fetchall - row 1" except Exception as e: pytest.fail(f"SQL_VARBINARY parsing test failed: {e}") finally: @@ -511,8 +511,8 @@ def test_longvarchar(cursor, db_connection): # fetchall test cursor.execute("SELECT longvarchar_column FROM pytest_longvarchar_test") rows = cursor.fetchall() - assert rows[0] == ["ABCDEFGHI"], "SQL_LONGVARCHAR parsing failed for fetchall - row 0" - assert rows[1] == [None], "SQL_LONGVARCHAR parsing failed for fetchall - row 1" + assert compare_row_value(rows[0],["ABCDEFGHI"]), "SQL_LONGVARCHAR parsing failed for fetchall - row 0" + assert compare_row_value(rows[1],[None]), "SQL_LONGVARCHAR parsing failed for fetchall - row 1" except Exception as e: pytest.fail(f"SQL_LONGVARCHAR parsing test failed: {e}") finally: @@ -544,8 +544,8 @@ def test_longwvarchar(cursor, db_connection): # fetchall test cursor.execute("SELECT longwvarchar_column FROM pytest_longwvarchar_test") rows = cursor.fetchall() - assert rows[0] == ["ABCDEFGHI"], "SQL_LONGWVARCHAR parsing failed for fetchall - row 0" - assert rows[1] == [None], "SQL_LONGWVARCHAR parsing failed for fetchall - row 1" + assert compare_row_value(rows[0],["ABCDEFGHI"]), "SQL_LONGWVARCHAR parsing failed for fetchall - row 0" + assert compare_row_value(rows[1],[None]), "SQL_LONGWVARCHAR parsing failed for fetchall - row 1" except Exception as e: pytest.fail(f"SQL_LONGWVARCHAR parsing test failed: {e}") finally: @@ -577,8 +577,8 @@ def test_longvarbinary(cursor, db_connection): # fetchall test cursor.execute("SELECT longvarbinary_column FROM pytest_longvarbinary_test") rows = cursor.fetchall() - assert rows[0] == [bytearray("ABCDEFGHI", 'utf-8')], "SQL_LONGVARBINARY parsing failed for fetchall - row 0" - assert rows[1] == [bytes("123!@#\0\0\0", 'utf-8')], "SQL_LONGVARBINARY parsing failed for fetchall - row 1" + assert compare_row_value(rows[0],[bytearray("ABCDEFGHI", 'utf-8')]), "SQL_LONGVARBINARY parsing failed for fetchall - row 0" + assert compare_row_value(rows[1],[bytes("123!@#\0\0\0", 'utf-8')]), "SQL_LONGVARBINARY parsing failed for fetchall - row 1" except Exception as e: pytest.fail(f"SQL_LONGVARBINARY parsing test failed: {e}") finally: @@ -872,9 +872,9 @@ def test_join_operations(cursor): """) rows = cursor.fetchall() assert len(rows) == 3, "Join operation returned incorrect number of rows" - assert rows[0] == ['Alice', 'HR', 'Project A'], "Join operation returned incorrect data for row 1" - assert rows[1] == ['Bob', 'Engineering', 'Project B'], "Join operation returned incorrect data for row 2" - assert rows[2] == ['Charlie', 'HR', 'Project C'], "Join operation returned incorrect data for row 3" + assert compare_row_value(rows[0],['Alice', 'HR', 'Project A']), "Join operation returned incorrect data for row 1" + assert compare_row_value(rows[1],['Bob', 'Engineering', 'Project B']), "Join operation returned incorrect data for row 2" + assert compare_row_value(rows[2],['Charlie', 'HR', 'Project C']), "Join operation returned incorrect data for row 3" except Exception as e: pytest.fail(f"Join operation failed: {e}") @@ -892,8 +892,8 @@ def test_join_operations_with_parameters(cursor): cursor.execute(query, employee_ids) rows = cursor.fetchall() assert len(rows) == 2, "Join operation with parameters returned incorrect number of rows" - assert rows[0] == ['Alice', 'HR', 'Project A'], "Join operation with parameters returned incorrect data for row 1" - assert rows[1] == ['Bob', 'Engineering', 'Project B'], "Join operation with parameters returned incorrect data for row 2" + assert compare_row_value(rows[0],['Alice', 'HR', 'Project A']), "Join operation with parameters returned incorrect data for row 1" + assert compare_row_value(rows[1],['Bob', 'Engineering', 'Project B']), "Join operation with parameters returned incorrect data for row 2" except Exception as e: pytest.fail(f"Join operation with parameters failed: {e}") @@ -924,7 +924,7 @@ def test_execute_stored_procedure_with_parameters(cursor): cursor.execute("{CALL GetEmployeeProjects(?)}", [1]) rows = cursor.fetchall() assert len(rows) == 1, "Stored procedure with parameters returned incorrect number of rows" - assert rows[0] == ['Alice', 'Project A'], "Stored procedure with parameters returned incorrect data" + assert compare_row_value(rows[0],['Alice', 'Project A']), "Stored procedure with parameters returned incorrect data" except Exception as e: pytest.fail(f"Stored procedure execution with parameters failed: {e}") @@ -937,7 +937,7 @@ def test_execute_stored_procedure_without_parameters(cursor): """) rows = cursor.fetchall() assert len(rows) == 1, "Stored procedure without parameters returned incorrect number of rows" - assert rows[0] == ['Bob', 'Project B'], "Stored procedure without parameters returned incorrect data" + assert compare_row_value(rows[0],['Bob', 'Project B']), "Stored procedure without parameters returned incorrect data" except Exception as e: pytest.fail(f"Stored procedure execution without parameters failed: {e}") @@ -1208,6 +1208,9 @@ def compare_row_value(actual, expected): # For single column result, compare the first field's value directly if len(actual) == 1: return getattr(actual, actual._fields[0]) == expected[0] + # For multiple columns, extract all values and compare as lists + actual_values = [getattr(actual, field) for field in actual._fields] + return actual_values == expected # For regular list or tuple elif isinstance(actual, (list, tuple)): return actual == expected From 84e0af1b2b7f3b0752b95649a52145cec3a5cded Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Fri, 13 Jun 2025 14:31:05 +0530 Subject: [PATCH 05/10] Resolving comments --- mssql_python/cursor.py | 166 +++++++++++++++++++++++++++-------------- 1 file changed, 112 insertions(+), 54 deletions(-) diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 155bf5b4..606b9cc8 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -70,6 +70,10 @@ def __init__(self, connection) -> None: # Hence, we can't pass around bools by reference & modify them. # Therefore, it must be a list with exactly one bool element. + # Cache for the named tuple class for the current result set + self._row_namedtuple_class = None + self._row_field_names = None + def _is_unicode_string(self, param): """ Check if a string contains non-ASCII characters. @@ -554,6 +558,10 @@ def execute( if reset_cursor: self._reset_cursor() + # Reset the named tuple class cache when executing a new query + self._row_namedtuple_class = None + self._row_field_names = None + param_info = ddbc_bindings.ParamInfo parameters_type = [] @@ -649,11 +657,20 @@ def fetchone(self) -> Union[None, tuple]: Fetch the next row of a query result set. Returns: - A tuple representing a single row or None if no more data is available. - The tuple allows access by index. + If data is available: + - A named tuple if column names are valid Python identifiers + - A regular tuple otherwise + None if no more data is available + + Named tuples allow access by attribute name (row.column_name) + in addition to index access (row[0]). Raises: Error: If the previous call to execute did not produce any result set. + + Note: + Valid Python identifiers cannot start with numbers and can only + contain alphanumeric characters and underscores. """ self._check_closed() # Check if the cursor is closed @@ -662,30 +679,20 @@ def fetchone(self) -> Union[None, tuple]: ret = ddbc_bindings.DDBCSQLFetchOne(self.hstmt, row_list) check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret) - if ret == ddbc_sql_const.SQL_NO_DATA.value: - return None - - # If the row list is empty, return None - if not row_list: + if ret == ddbc_sql_const.SQL_NO_DATA.value or not row_list: return None # Get field names from the description attribute field_names = [desc[0] for desc in self.description] - # Check if field names are valid for namedtuple - valid_fieldnames = all(isinstance(name, str) and name.isidentifier() for name in field_names) + # Get or create the named tuple class + RowRecord = self._get_row_namedtuple_class(field_names) - if valid_fieldnames: - # Create a namedtuple on the Python side - try: - RowRecord = collections.namedtuple('RowRecord', field_names, rename=True) - result = RowRecord(*row_list) - return result - except (TypeError, ValueError) as e: - # Fall back to a regular tuple for any error in namedtuple creation - return tuple(row_list) + if RowRecord: + # Use the cached named tuple class + return RowRecord(*row_list) else: - # If field names aren't valid identifiers, return a regular tuple + # Fall back to a regular tuple return tuple(row_list) def fetchmany(self, size: int = None) -> list: @@ -695,13 +702,22 @@ def fetchmany(self, size: int = None) -> list: Args: size (int): The number of rows to fetch. If not provided, the cursor's arraysize - is used. + is used. Returns: - A list of tuples, each representing a row of the result set. - + A list of row objects where each row is: + - A named tuple if column names are valid Python identifiers + - A regular tuple otherwise + + Named tuples allow access by attribute name (row.column_name) + in addition to index access (row[0]). + Raises: Error: If the previous call to execute did not produce any result set. + + Note: + Valid Python identifiers cannot start with numbers and can only + contain alphanumeric characters and underscores. """ self._check_closed() # Check if the cursor is closed @@ -712,21 +728,21 @@ def fetchmany(self, size: int = None) -> list: ret = ddbc_bindings.DDBCSQLFetchMany(self.hstmt, rows, size) check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret) - # Get field names from the description attribute for named tuple creation - if rows and self.description: - field_names = [desc[0] for desc in self.description] - - # Check if field names are valid for namedtuple - valid_fieldnames = all(isinstance(name, str) and name.isidentifier() for name in field_names) - - if valid_fieldnames: - # Create a named tuple class for the rows - RowRecord = collections.namedtuple('RowRecord', field_names, rename=True) - - # Transform each row from list to named tuple - return [RowRecord(*row) for row in rows] + if not rows: + return rows + + # Get field names from the description attribute + field_names = [desc[0] for desc in self.description] - return rows + # Get or create the named tuple class + RowRecord = self._get_row_namedtuple_class(field_names) + + if RowRecord: + # Convert each row to a named tuple + return [RowRecord(*row) for row in rows] + else: + # Return rows as regular tuples + return [tuple(row) for row in rows] def fetchall(self) -> list: """ @@ -734,10 +750,19 @@ def fetchall(self) -> list: An empty list is returned when no more rows are available. Returns: - A list of tuples, each representing a row of the result set. + A list of row objects where each row is: + - A named tuple if column names are valid Python identifiers + - A regular tuple otherwise + + Named tuples allow access by attribute name (row.column_name) + in addition to index access (row[0]). Raises: Error: If the previous call to execute did not produce any result set. + + Note: + Valid Python identifiers cannot start with numbers and can only + contain alphanumeric characters and underscores. """ self._check_closed() # Check if the cursor is closed @@ -745,21 +770,21 @@ def fetchall(self) -> list: ret = ddbc_bindings.DDBCSQLFetchAll(self.hstmt, rows) check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret) - # Get field names from the description attribute for named tuple creation - if rows and self.description: - field_names = [desc[0] for desc in self.description] - - # Check if field names are valid for namedtuple - valid_fieldnames = all(isinstance(name, str) and name.isidentifier() for name in field_names) - - if valid_fieldnames: - # Create a named tuple class for the rows - RowRecord = collections.namedtuple('RowRecord', field_names, rename=True) - - # Transform each row from list to named tuple - return [RowRecord(*row) for row in rows] + if not rows: + return rows - return rows + # Get field names from the description attribute + field_names = [desc[0] for desc in self.description] + + # Get or create the named tuple class + RowRecord = self._get_row_namedtuple_class(field_names) + + if RowRecord: + # Convert each row to a named tuple + return [RowRecord(*row) for row in rows] + else: + # Return rows as regular tuples + return [tuple(row) for row in rows] def nextset(self) -> Union[bool, None]: """ @@ -767,9 +792,6 @@ def nextset(self) -> Union[bool, None]: Returns: True if there is another result set, None otherwise. - - Raises: - Error: If the previous call to execute did not produce any result set. """ self._check_closed() # Check if the cursor is closed @@ -779,3 +801,39 @@ def nextset(self) -> Union[bool, None]: if ret == ddbc_sql_const.SQL_NO_DATA.value: return False return True + + def _get_row_namedtuple_class(self, field_names): + """ + Get a cached named tuple class or create a new one if needed. + + Args: + field_names: List of column names from the result set + + Returns: + A named tuple class for the current result set's schema, or None if + the field names are not valid Python identifiers. + """ + # Check if field names are valid for namedtuple + invalid_fields = [name for name in field_names if not (isinstance(name, str) and name.isidentifier())] + if invalid_fields: + if ENABLE_LOGGING: + logger.debug("Cannot create named tuple due to invalid field names: %s", invalid_fields) + return None + + # Check if we already have a cached class with these exact field names + if (self._row_namedtuple_class is not None and + self._row_field_names == field_names): + return self._row_namedtuple_class + + # Create a new named tuple class and cache it + try: + self._row_namedtuple_class = collections.namedtuple('RowRecord', field_names, rename=True) + self._row_field_names = field_names + return self._row_namedtuple_class + except (TypeError, ValueError) as e: + # Log the exception for debugging purposes + if ENABLE_LOGGING: + logger.debug("Failed to create named tuple: %s", str(e)) + self._row_namedtuple_class = None + self._row_field_names = None + return None From a86b99ec9f2a303d7f718e3f9ffa78b0c72546e9 Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Fri, 13 Jun 2025 14:33:12 +0530 Subject: [PATCH 06/10] Resolving comments --- tests/test_004_cursor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index 799a87a3..f0df9aed 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -563,7 +563,7 @@ def test_longvarbinary(cursor, db_connection): db_connection.commit() cursor.execute("INSERT INTO pytest_longvarbinary_test (longvarbinary_column) VALUES (?), (?)", [bytearray("ABCDEFGHI", 'utf-8'), bytes("123!@#", 'utf-8')]) db_connection.commit() - expectedRows = 2 # Note: Your test has expectedRows = 3 but only inserts 2 rows + expectedRows = 2 # The test is intentionally designed for 2 rows # fetchone test cursor.execute("SELECT longvarbinary_column FROM pytest_longvarbinary_test") rows = [] From 919f94081fdf45baa6c32900d11aed584052ece9 Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Mon, 16 Jun 2025 15:32:21 +0530 Subject: [PATCH 07/10] Reverting changes --- mssql_python/cursor.py | 165 ++++-------------------- mssql_python/pybind/ddbc_bindings.cpp | 43 ++----- tests/test_004_cursor.py | 179 +++++++++----------------- 3 files changed, 97 insertions(+), 290 deletions(-) diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 606b9cc8..227e66a6 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -6,7 +6,6 @@ import ctypes import decimal import uuid -import collections import datetime from typing import List, Union from mssql_python.constants import ConstantsDDBC as ddbc_sql_const @@ -70,10 +69,6 @@ def __init__(self, connection) -> None: # Hence, we can't pass around bools by reference & modify them. # Therefore, it must be a list with exactly one bool element. - # Cache for the named tuple class for the current result set - self._row_namedtuple_class = None - self._row_field_names = None - def _is_unicode_string(self, param): """ Check if a string contains non-ASCII characters. @@ -558,10 +553,6 @@ def execute( if reset_cursor: self._reset_cursor() - # Reset the named tuple class cache when executing a new query - self._row_namedtuple_class = None - self._row_field_names = None - param_info = ddbc_bindings.ParamInfo parameters_type = [] @@ -657,134 +648,63 @@ def fetchone(self) -> Union[None, tuple]: Fetch the next row of a query result set. Returns: - If data is available: - - A named tuple if column names are valid Python identifiers - - A regular tuple otherwise - None if no more data is available - - Named tuples allow access by attribute name (row.column_name) - in addition to index access (row[0]). + Single sequence or None if no more data is available. Raises: Error: If the previous call to execute did not produce any result set. - - Note: - Valid Python identifiers cannot start with numbers and can only - contain alphanumeric characters and underscores. """ self._check_closed() # Check if the cursor is closed - # Use a list to receive the row data - row_list = [] - ret = ddbc_bindings.DDBCSQLFetchOne(self.hstmt, row_list) + row = [] + ret = ddbc_bindings.DDBCSQLFetchOne(self.hstmt, row) check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret) - - if ret == ddbc_sql_const.SQL_NO_DATA.value or not row_list: + if ret == ddbc_sql_const.SQL_NO_DATA.value: return None - - # Get field names from the description attribute - field_names = [desc[0] for desc in self.description] - - # Get or create the named tuple class - RowRecord = self._get_row_namedtuple_class(field_names) - - if RowRecord: - # Use the cached named tuple class - return RowRecord(*row_list) - else: - # Fall back to a regular tuple - return tuple(row_list) + return list(row) - def fetchmany(self, size: int = None) -> list: + def fetchmany(self, size: int = None) -> List[tuple]: """ - Fetch the next set of rows of a query result, returning a list of tuples. - An empty list is returned when no more rows are available. + Fetch the next set of rows of a query result. Args: - size (int): The number of rows to fetch. If not provided, the cursor's arraysize - is used. + size: Number of rows to fetch at a time. Returns: - A list of row objects where each row is: - - A named tuple if column names are valid Python identifiers - - A regular tuple otherwise - - Named tuples allow access by attribute name (row.column_name) - in addition to index access (row[0]). - + Sequence of sequences (e.g. list of tuples). + Raises: Error: If the previous call to execute did not produce any result set. - - Note: - Valid Python identifiers cannot start with numbers and can only - contain alphanumeric characters and underscores. """ self._check_closed() # Check if the cursor is closed - + if size is None: size = self.arraysize - + + # Fetch the next set of rows rows = [] ret = ddbc_bindings.DDBCSQLFetchMany(self.hstmt, rows, size) check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret) - - if not rows: - return rows - - # Get field names from the description attribute - field_names = [desc[0] for desc in self.description] - - # Get or create the named tuple class - RowRecord = self._get_row_namedtuple_class(field_names) - - if RowRecord: - # Convert each row to a named tuple - return [RowRecord(*row) for row in rows] - else: - # Return rows as regular tuples - return [tuple(row) for row in rows] + if ret == ddbc_sql_const.SQL_NO_DATA.value: + return [] + return rows - def fetchall(self) -> list: + def fetchall(self) -> List[tuple]: """ - Fetch all (remaining) rows of a query result, returning a list of tuples. - An empty list is returned when no more rows are available. + Fetch all (remaining) rows of a query result. Returns: - A list of row objects where each row is: - - A named tuple if column names are valid Python identifiers - - A regular tuple otherwise - - Named tuples allow access by attribute name (row.column_name) - in addition to index access (row[0]). + Sequence of sequences (e.g. list of tuples). Raises: Error: If the previous call to execute did not produce any result set. - - Note: - Valid Python identifiers cannot start with numbers and can only - contain alphanumeric characters and underscores. """ self._check_closed() # Check if the cursor is closed - + + # Fetch all remaining rows rows = [] ret = ddbc_bindings.DDBCSQLFetchAll(self.hstmt, rows) check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret) - - if not rows: - return rows - - # Get field names from the description attribute - field_names = [desc[0] for desc in self.description] - - # Get or create the named tuple class - RowRecord = self._get_row_namedtuple_class(field_names) - - if RowRecord: - # Convert each row to a named tuple - return [RowRecord(*row) for row in rows] - else: - # Return rows as regular tuples - return [tuple(row) for row in rows] + return list(rows) def nextset(self) -> Union[bool, None]: """ @@ -792,6 +712,9 @@ def nextset(self) -> Union[bool, None]: Returns: True if there is another result set, None otherwise. + + Raises: + Error: If the previous call to execute did not produce any result set. """ self._check_closed() # Check if the cursor is closed @@ -800,40 +723,4 @@ def nextset(self) -> Union[bool, None]: check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret) if ret == ddbc_sql_const.SQL_NO_DATA.value: return False - return True - - def _get_row_namedtuple_class(self, field_names): - """ - Get a cached named tuple class or create a new one if needed. - - Args: - field_names: List of column names from the result set - - Returns: - A named tuple class for the current result set's schema, or None if - the field names are not valid Python identifiers. - """ - # Check if field names are valid for namedtuple - invalid_fields = [name for name in field_names if not (isinstance(name, str) and name.isidentifier())] - if invalid_fields: - if ENABLE_LOGGING: - logger.debug("Cannot create named tuple due to invalid field names: %s", invalid_fields) - return None - - # Check if we already have a cached class with these exact field names - if (self._row_namedtuple_class is not None and - self._row_field_names == field_names): - return self._row_namedtuple_class - - # Create a new named tuple class and cache it - try: - self._row_namedtuple_class = collections.namedtuple('RowRecord', field_names, rename=True) - self._row_field_names = field_names - return self._row_namedtuple_class - except (TypeError, ValueError) as e: - # Log the exception for debugging purposes - if ENABLE_LOGGING: - logger.debug("Failed to create named tuple: %s", str(e)) - self._row_namedtuple_class = None - self._row_field_names = None - return None + return True \ No newline at end of file diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index bdb0ca0a..f9db3d0b 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -1816,7 +1816,7 @@ SQLRETURN FetchAll_wrap(SqlHandlePtr StatementHandle, py::list& rows) { // FetchOne_wrap - Fetches a single row of data from the result set. // // @param StatementHandle: Handle to the statement from which data is to be fetched. -// @param row: A Python object reference that will be populated with a named tuple containing the fetched row data. +// @param row: A Python list that will be populated with the fetched row data. // // @return SQLRETURN: SQL_SUCCESS or SQL_SUCCESS_WITH_INFO if data is fetched successfully, // SQL_NO_DATA if there are no more rows to fetch, @@ -1824,40 +1824,21 @@ SQLRETURN FetchAll_wrap(SqlHandlePtr StatementHandle, py::list& rows) { // // This function assumes that the statement handle (hStmt) is already allocated and a query has been // executed. It fetches the next row of data from the result set and populates the provided Python -// object with a named tuple containing the row data. If there are no more rows to fetch, it returns -// SQL_NO_DATA. If an error occurs during fetching, it throws a runtime error. -SQLRETURN FetchOne_wrap(SqlHandlePtr StatementHandle, py::list& row_list) { +// list with the row data. If there are no more rows to fetch, it returns SQL_NO_DATA. If an error +// occurs during fetching, it throws a runtime error. +SQLRETURN FetchOne_wrap(SqlHandlePtr StatementHandle, py::list& row) { SQLRETURN ret; SQLHSTMT hStmt = StatementHandle->get(); - if (!SQLFetch_ptr) { - LOG("Function pointer not initialized in FetchOne_wrap. Loading the driver.\n"); - DriverLoader::getInstance().loadDriver(); - } - + // Assume hStmt is already allocated and a query has been executed ret = SQLFetch_ptr(hStmt); - if (ret == SQL_NO_DATA) { - return ret; - } else if (!SQL_SUCCEEDED(ret)) { - LOG("Error when fetching data: SQLFetch_ptr failed with retcode {}\n", ret); - return ret; - } - - // Get column count - we don't need column names in C++ anymore - SQLSMALLINT colCount; - SQLRETURN colRet = SQLNumResultCols_ptr(hStmt, &colCount); - if (!SQL_SUCCEEDED(colRet)) { - LOG("Error when getting column count: SQLNumResultCols_ptr failed with retcode {}\n", colRet); - return colRet; - } - - // Get row data into the list - ret = SQLGetData_wrap(StatementHandle, colCount, row_list); - if (!SQL_SUCCEEDED(ret)) { - LOG("Error when fetching data values: SQLGetData_wrap failed with retcode {}\n", ret); - return ret; + if (SQL_SUCCEEDED(ret)) { + // Retrieve column count + SQLSMALLINT colCount = SQLNumResultCols_wrap(StatementHandle); + ret = SQLGetData_wrap(StatementHandle, colCount, row); + } else if (ret != SQL_NO_DATA) { + LOG("Error when fetching data"); } - return ret; } @@ -1994,4 +1975,4 @@ PYBIND11_MODULE(ddbc_bindings, m) { // Log the error but don't throw - let the error happen when functions are called LOG("Failed to load ODBC driver during module initialization: {}", e.what()); } -} +} \ No newline at end of file diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index f0df9aed..62a284b1 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -334,7 +334,7 @@ def test_varchar_full_capacity(cursor, db_connection): # fetchall test cursor.execute("SELECT varchar_column FROM pytest_varchar_test") rows = cursor.fetchall() - assert compare_row_value(rows[0], ['123456789']), "SQL_VARCHAR parsing failed for fetchall" + assert rows[0] == ['123456789'], "SQL_VARCHAR parsing failed for fetchall" except Exception as e: pytest.fail(f"SQL_VARCHAR parsing test failed: {e}") finally: @@ -356,7 +356,7 @@ def test_wvarchar_full_capacity(cursor, db_connection): # fetchall test cursor.execute("SELECT wvarchar_column FROM pytest_wvarchar_test") rows = cursor.fetchall() - assert compare_row_value(rows[0],['123456']), "SQL_WVARCHAR parsing failed for fetchall" + assert rows[0] == ['123456'], "SQL_WVARCHAR parsing failed for fetchall" except Exception as e: pytest.fail(f"SQL_WVARCHAR parsing test failed: {e}") finally: @@ -367,13 +367,10 @@ def test_wvarchar_full_capacity(cursor, db_connection): def test_varbinary_full_capacity(cursor, db_connection): """Test SQL_VARBINARY""" try: - # Drop the table if it exists first - drop_table_if_exists(cursor, "pytest_varbinary_test") - cursor.execute("CREATE TABLE pytest_varbinary_test (varbinary_column VARBINARY(8))") db_connection.commit() # Try inserting binary using both bytes & bytearray - cursor.execute("INSERT INTO pytest_varbinary_test (varbinary_column) VALUES (?)", bytearray("12345", 'utf-8')) + cursor.execute("INSERT INTO pytest_varbinary_test (varbinary_column) VALUES (?)", bytearray("12345", 'utf-8')) cursor.execute("INSERT INTO pytest_varbinary_test (varbinary_column) VALUES (?)", bytes("12345678", 'utf-8')) # Full capacity db_connection.commit() expectedRows = 2 @@ -383,21 +380,17 @@ def test_varbinary_full_capacity(cursor, db_connection): for i in range(0, expectedRows): rows.append(cursor.fetchone()) assert cursor.fetchone() == None, "varbinary_column is expected to have only {} rows".format(expectedRows) - - # Use the compare_row_value function for assertions - assert compare_row_value(rows[0], [bytes("12345", 'utf-8')]), "SQL_VARBINARY parsing failed for fetchone - row 0" - assert compare_row_value(rows[1], [bytes("12345678", 'utf-8')]), "SQL_VARBINARY parsing failed for fetchone - row 1" - + assert rows[0] == [bytes("12345", 'utf-8')], "SQL_VARBINARY parsing failed for fetchone - row 0" + assert rows[1] == [bytes("12345678", 'utf-8')], "SQL_VARBINARY parsing failed for fetchone - row 1" # fetchall test cursor.execute("SELECT varbinary_column FROM pytest_varbinary_test") rows = cursor.fetchall() - assert compare_row_value(rows[0],[bytes("12345", 'utf-8')]), "SQL_VARBINARY parsing failed for fetchall - row 0" - assert compare_row_value(rows[1],[bytes("12345678", 'utf-8')]), "SQL_VARBINARY parsing failed for fetchall - row 1" + assert rows[0] == [bytes("12345", 'utf-8')], "SQL_VARBINARY parsing failed for fetchall - row 0" + assert rows[1] == [bytes("12345678", 'utf-8')], "SQL_VARBINARY parsing failed for fetchall - row 1" except Exception as e: pytest.fail(f"SQL_VARBINARY parsing test failed: {e}") finally: - # Clean up the table - drop_table_if_exists(cursor, "pytest_varbinary_test") + cursor.execute("DROP TABLE pytest_varbinary_test") db_connection.commit() def test_varchar_max(cursor, db_connection): @@ -414,24 +407,25 @@ def test_varchar_max(cursor, db_connection): for i in range(0, expectedRows): rows.append(cursor.fetchone()) assert cursor.fetchone() == None, "varchar_column is expected to have only {} rows".format(expectedRows) - - assert compare_row_value(rows[0], ["ABCDEFGHI"]), "SQL_VARCHAR parsing failed for fetchone - row 0" - assert compare_row_value(rows[1], [None]), "SQL_VARCHAR parsing failed for fetchone - row 1" - + assert rows[0] == ["ABCDEFGHI"], "SQL_VARCHAR parsing failed for fetchone - row 0" + assert rows[1] == [None], "SQL_VARCHAR parsing failed for fetchone - row 1" # fetchall test cursor.execute("SELECT varchar_column FROM pytest_varchar_test") rows = cursor.fetchall() - assert compare_row_value(rows[0],["ABCDEFGHI"]), "SQL_VARCHAR parsing failed for fetchall - row 0" - assert compare_row_value(rows[1],[None]), "SQL_VARCHAR parsing failed for fetchall - row 1" + assert rows[0] == ["ABCDEFGHI"], "SQL_VARCHAR parsing failed for fetchall - row 0" + assert rows[1] == [None], "SQL_VARCHAR parsing failed for fetchall - row 1" except Exception as e: pytest.fail(f"SQL_VARCHAR parsing test failed: {e}") + finally: + cursor.execute("DROP TABLE pytest_varchar_test") + db_connection.commit() def test_wvarchar_max(cursor, db_connection): """Test SQL_WVARCHAR with MAX length""" try: cursor.execute("CREATE TABLE pytest_wvarchar_test (wvarchar_column NVARCHAR(MAX))") db_connection.commit() - cursor.execute("INSERT INTO pytest_wvarchar_test (wvarchar_column) VALUES (?), (?)", ["!@#$%^&*()_+", None]) + cursor.execute("INSERT INTO pytest_wvarchar_test (wvarchar_column) VALUES (?), (?)", ["!@#$%^&*()_+", None]) db_connection.commit() expectedRows = 2 # fetchone test @@ -440,24 +434,22 @@ def test_wvarchar_max(cursor, db_connection): for i in range(0, expectedRows): rows.append(cursor.fetchone()) assert cursor.fetchone() == None, "wvarchar_column is expected to have only {} rows".format(expectedRows) - - assert compare_row_value(rows[0], ["!@#$%^&*()_+"]), "SQL_WVARCHAR parsing failed for fetchone - row 0" - assert compare_row_value(rows[1], [None]), "SQL_WVARCHAR parsing failed for fetchone - row 1" - + assert rows[0] == ["!@#$%^&*()_+"], "SQL_WVARCHAR parsing failed for fetchone - row 0" + assert rows[1] == [None], "SQL_WVARCHAR parsing failed for fetchone - row 1" # fetchall test cursor.execute("SELECT wvarchar_column FROM pytest_wvarchar_test") rows = cursor.fetchall() - assert compare_row_value(rows[0],["!@#$%^&*()_+"]), "SQL_WVARCHAR parsing failed for fetchall - row 0" - assert compare_row_value(rows[1],[None]), "SQL_WVARCHAR parsing failed for fetchall - row 1" + assert rows[0] == ["!@#$%^&*()_+"], "SQL_WVARCHAR parsing failed for fetchall - row 0" + assert rows[1] == [None], "SQL_WVARCHAR parsing failed for fetchall - row 1" except Exception as e: pytest.fail(f"SQL_WVARCHAR parsing test failed: {e}") + finally: + cursor.execute("DROP TABLE pytest_wvarchar_test") + db_connection.commit() def test_varbinary_max(cursor, db_connection): """Test SQL_VARBINARY with MAX length""" try: - # Drop the table if it exists first - drop_table_if_exists(cursor, "pytest_varbinary_test") - cursor.execute("CREATE TABLE pytest_varbinary_test (varbinary_column VARBINARY(MAX))") db_connection.commit() # TODO: Uncomment this execute after adding null binary support @@ -471,31 +463,25 @@ def test_varbinary_max(cursor, db_connection): for i in range(0, expectedRows): rows.append(cursor.fetchone()) assert cursor.fetchone() == None, "varbinary_column is expected to have only {} rows".format(expectedRows) - - assert compare_row_value(rows[0], [bytearray("ABCDEF", 'utf-8')]), "SQL_VARBINARY parsing failed for fetchone - row 0" - assert compare_row_value(rows[1], [bytes("123!@#", 'utf-8')]), "SQL_VARBINARY parsing failed for fetchone - row 1" - + assert rows[0] == [bytearray("ABCDEF", 'utf-8')], "SQL_VARBINARY parsing failed for fetchone - row 0" + assert rows[1] == [bytes("123!@#", 'utf-8')], "SQL_VARBINARY parsing failed for fetchone - row 1" # fetchall test cursor.execute("SELECT varbinary_column FROM pytest_varbinary_test") rows = cursor.fetchall() - assert compare_row_value(rows[0],[bytearray("ABCDEF", 'utf-8')]), "SQL_VARBINARY parsing failed for fetchall - row 0" - assert compare_row_value(rows[1],[bytes("123!@#", 'utf-8')]), "SQL_VARBINARY parsing failed for fetchall - row 1" + assert rows[0] == [bytearray("ABCDEF", 'utf-8')], "SQL_VARBINARY parsing failed for fetchall - row 0" + assert rows[1] == [bytes("123!@#", 'utf-8')], "SQL_VARBINARY parsing failed for fetchall - row 1" except Exception as e: pytest.fail(f"SQL_VARBINARY parsing test failed: {e}") finally: - # Clean up the table - drop_table_if_exists(cursor, "pytest_varbinary_test") + cursor.execute("DROP TABLE pytest_varbinary_test") db_connection.commit() def test_longvarchar(cursor, db_connection): """Test SQL_LONGVARCHAR""" try: - # Drop the table if it exists first - drop_table_if_exists(cursor, "pytest_longvarchar_test") - cursor.execute("CREATE TABLE pytest_longvarchar_test (longvarchar_column TEXT)") db_connection.commit() - cursor.execute("INSERT INTO pytest_longvarchar_test (longvarchar_column) VALUES (?), (?)", ["ABCDEFGHI", None]) + cursor.execute("INSERT INTO pytest_longvarchar_test (longvarchar_column) VALUES (?), (?)", ["ABCDEFGHI", None]) db_connection.commit() expectedRows = 2 # fetchone test @@ -503,32 +489,26 @@ def test_longvarchar(cursor, db_connection): rows = [] for i in range(0, expectedRows): rows.append(cursor.fetchone()) - assert cursor.fetchone() == None, "longvarchar_column is expected to have only {} rows".format(expectedRows) - - assert compare_row_value(rows[0], ["ABCDEFGHI"]), "SQL_LONGVARCHAR parsing failed for fetchone - row 0" - assert compare_row_value(rows[1], [None]), "SQL_LONGVARCHAR parsing failed for fetchone - row 1" - + assert cursor.fetchone() == None, "longvarchar_column is expected to have only {} rows".format(expectedRows) + assert rows[0] == ["ABCDEFGHI"], "SQL_LONGVARCHAR parsing failed for fetchone - row 0" + assert rows[1] == [None], "SQL_LONGVARCHAR parsing failed for fetchone - row 1" # fetchall test cursor.execute("SELECT longvarchar_column FROM pytest_longvarchar_test") rows = cursor.fetchall() - assert compare_row_value(rows[0],["ABCDEFGHI"]), "SQL_LONGVARCHAR parsing failed for fetchall - row 0" - assert compare_row_value(rows[1],[None]), "SQL_LONGVARCHAR parsing failed for fetchall - row 1" + assert rows[0] == ["ABCDEFGHI"], "SQL_LONGVARCHAR parsing failed for fetchall - row 0" + assert rows[1] == [None], "SQL_LONGVARCHAR parsing failed for fetchall - row 1" except Exception as e: pytest.fail(f"SQL_LONGVARCHAR parsing test failed: {e}") finally: - # Clean up the table - drop_table_if_exists(cursor, "pytest_longvarchar_test") + cursor.execute("DROP TABLE pytest_longvarchar_test") db_connection.commit() def test_longwvarchar(cursor, db_connection): """Test SQL_LONGWVARCHAR""" try: - # Drop the table if it exists first - drop_table_if_exists(cursor, "pytest_longwvarchar_test") - cursor.execute("CREATE TABLE pytest_longwvarchar_test (longwvarchar_column NTEXT)") db_connection.commit() - cursor.execute("INSERT INTO pytest_longwvarchar_test (longwvarchar_column) VALUES (?), (?)", ["ABCDEFGHI", None]) + cursor.execute("INSERT INTO pytest_longwvarchar_test (longwvarchar_column) VALUES (?), (?)", ["ABCDEFGHI", None]) db_connection.commit() expectedRows = 2 # fetchone test @@ -536,54 +516,45 @@ def test_longwvarchar(cursor, db_connection): rows = [] for i in range(0, expectedRows): rows.append(cursor.fetchone()) - assert cursor.fetchone() == None, "longwvarchar_column is expected to have only {} rows".format(expectedRows) - - assert compare_row_value(rows[0], ["ABCDEFGHI"]), "SQL_LONGWVARCHAR parsing failed for fetchone - row 0" - assert compare_row_value(rows[1], [None]), "SQL_LONGWVARCHAR parsing failed for fetchone - row 1" - + assert cursor.fetchone() == None, "longwvarchar_column is expected to have only {} rows".format(expectedRows) + assert rows[0] == ["ABCDEFGHI"], "SQL_LONGWVARCHAR parsing failed for fetchone - row 0" + assert rows[1] == [None], "SQL_LONGWVARCHAR parsing failed for fetchone - row 1" # fetchall test cursor.execute("SELECT longwvarchar_column FROM pytest_longwvarchar_test") rows = cursor.fetchall() - assert compare_row_value(rows[0],["ABCDEFGHI"]), "SQL_LONGWVARCHAR parsing failed for fetchall - row 0" - assert compare_row_value(rows[1],[None]), "SQL_LONGWVARCHAR parsing failed for fetchall - row 1" + assert rows[0] == ["ABCDEFGHI"], "SQL_LONGWVARCHAR parsing failed for fetchall - row 0" + assert rows[1] == [None], "SQL_LONGWVARCHAR parsing failed for fetchall - row 1" except Exception as e: pytest.fail(f"SQL_LONGWVARCHAR parsing test failed: {e}") finally: - # Clean up the table - drop_table_if_exists(cursor, "pytest_longwvarchar_test") + cursor.execute("DROP TABLE pytest_longwvarchar_test") db_connection.commit() def test_longvarbinary(cursor, db_connection): """Test SQL_LONGVARBINARY""" try: - # Drop the table if it exists first - drop_table_if_exists(cursor, "pytest_longvarbinary_test") - cursor.execute("CREATE TABLE pytest_longvarbinary_test (longvarbinary_column IMAGE)") db_connection.commit() cursor.execute("INSERT INTO pytest_longvarbinary_test (longvarbinary_column) VALUES (?), (?)", [bytearray("ABCDEFGHI", 'utf-8'), bytes("123!@#", 'utf-8')]) db_connection.commit() - expectedRows = 2 # The test is intentionally designed for 2 rows + expectedRows = 3 # fetchone test cursor.execute("SELECT longvarbinary_column FROM pytest_longvarbinary_test") rows = [] for i in range(0, expectedRows): rows.append(cursor.fetchone()) - assert cursor.fetchone() == None, "longvarbinary_column is expected to have only {} rows".format(expectedRows) - - assert compare_row_value(rows[0], [bytearray("ABCDEFGHI", 'utf-8')]), "SQL_LONGVARBINARY parsing failed for fetchone - row 0" - assert compare_row_value(rows[1], [bytes("123!@#\0\0\0", 'utf-8')]), "SQL_LONGVARBINARY parsing failed for fetchone - row 1" - + assert cursor.fetchone() == None, "longvarbinary_column is expected to have only {} rows".format(expectedRows) + assert rows[0] == [bytearray("ABCDEFGHI", 'utf-8')], "SQL_LONGVARBINARY parsing failed for fetchone - row 0" + assert rows[1] == [bytes("123!@#\0\0\0", 'utf-8')], "SQL_LONGVARBINARY parsing failed for fetchone - row 1" # fetchall test cursor.execute("SELECT longvarbinary_column FROM pytest_longvarbinary_test") rows = cursor.fetchall() - assert compare_row_value(rows[0],[bytearray("ABCDEFGHI", 'utf-8')]), "SQL_LONGVARBINARY parsing failed for fetchall - row 0" - assert compare_row_value(rows[1],[bytes("123!@#\0\0\0", 'utf-8')]), "SQL_LONGVARBINARY parsing failed for fetchall - row 1" + assert rows[0] == [bytearray("ABCDEFGHI", 'utf-8')], "SQL_LONGVARBINARY parsing failed for fetchall - row 0" + assert rows[1] == [bytes("123!@#\0\0\0", 'utf-8')], "SQL_LONGVARBINARY parsing failed for fetchall - row 1" except Exception as e: pytest.fail(f"SQL_LONGVARBINARY parsing test failed: {e}") finally: - # Clean up the table - drop_table_if_exists(cursor, "pytest_longvarbinary_test") + cursor.execute("DROP TABLE pytest_longvarbinary_test") db_connection.commit() def test_create_table(cursor, db_connection): @@ -872,9 +843,9 @@ def test_join_operations(cursor): """) rows = cursor.fetchall() assert len(rows) == 3, "Join operation returned incorrect number of rows" - assert compare_row_value(rows[0],['Alice', 'HR', 'Project A']), "Join operation returned incorrect data for row 1" - assert compare_row_value(rows[1],['Bob', 'Engineering', 'Project B']), "Join operation returned incorrect data for row 2" - assert compare_row_value(rows[2],['Charlie', 'HR', 'Project C']), "Join operation returned incorrect data for row 3" + assert rows[0] == ['Alice', 'HR', 'Project A'], "Join operation returned incorrect data for row 1" + assert rows[1] == ['Bob', 'Engineering', 'Project B'], "Join operation returned incorrect data for row 2" + assert rows[2] == ['Charlie', 'HR', 'Project C'], "Join operation returned incorrect data for row 3" except Exception as e: pytest.fail(f"Join operation failed: {e}") @@ -892,8 +863,8 @@ def test_join_operations_with_parameters(cursor): cursor.execute(query, employee_ids) rows = cursor.fetchall() assert len(rows) == 2, "Join operation with parameters returned incorrect number of rows" - assert compare_row_value(rows[0],['Alice', 'HR', 'Project A']), "Join operation with parameters returned incorrect data for row 1" - assert compare_row_value(rows[1],['Bob', 'Engineering', 'Project B']), "Join operation with parameters returned incorrect data for row 2" + assert rows[0] == ['Alice', 'HR', 'Project A'], "Join operation with parameters returned incorrect data for row 1" + assert rows[1] == ['Bob', 'Engineering', 'Project B'], "Join operation with parameters returned incorrect data for row 2" except Exception as e: pytest.fail(f"Join operation with parameters failed: {e}") @@ -924,7 +895,7 @@ def test_execute_stored_procedure_with_parameters(cursor): cursor.execute("{CALL GetEmployeeProjects(?)}", [1]) rows = cursor.fetchall() assert len(rows) == 1, "Stored procedure with parameters returned incorrect number of rows" - assert compare_row_value(rows[0],['Alice', 'Project A']), "Stored procedure with parameters returned incorrect data" + assert rows[0] == ['Alice', 'Project A'], "Stored procedure with parameters returned incorrect data" except Exception as e: pytest.fail(f"Stored procedure execution with parameters failed: {e}") @@ -937,7 +908,7 @@ def test_execute_stored_procedure_without_parameters(cursor): """) rows = cursor.fetchall() assert len(rows) == 1, "Stored procedure without parameters returned incorrect number of rows" - assert compare_row_value(rows[0],['Bob', 'Project B']), "Stored procedure without parameters returned incorrect data" + assert rows[0] == ['Bob', 'Project B'], "Stored procedure without parameters returned incorrect data" except Exception as e: pytest.fail(f"Stored procedure execution without parameters failed: {e}") @@ -1103,9 +1074,6 @@ def test_boolean(cursor, db_connection): def test_sql_wvarchar(cursor, db_connection): """Test SQL_WVARCHAR""" try: - # Drop the table if it exists first - drop_table_if_exists(cursor, "pytest_wvarchar_test") - cursor.execute("CREATE TABLE pytest_wvarchar_test (wvarchar_column NVARCHAR(255))") db_connection.commit() cursor.execute("INSERT INTO pytest_wvarchar_test (wvarchar_column) VALUES (?)", ['nvarchar data']) @@ -1116,16 +1084,12 @@ def test_sql_wvarchar(cursor, db_connection): except Exception as e: pytest.fail(f"SQL_WVARCHAR parsing test failed: {e}") finally: - # Clean up - drop_table_if_exists(cursor, "pytest_wvarchar_test") + cursor.execute("DROP TABLE pytest_wvarchar_test") db_connection.commit() def test_sql_varchar(cursor, db_connection): """Test SQL_VARCHAR""" try: - # Drop the table if it exists first - drop_table_if_exists(cursor, "pytest_varchar_test") - cursor.execute("CREATE TABLE pytest_varchar_test (varchar_column VARCHAR(255))") db_connection.commit() cursor.execute("INSERT INTO pytest_varchar_test (varchar_column) VALUES (?)", ['varchar data']) @@ -1136,8 +1100,7 @@ def test_sql_varchar(cursor, db_connection): except Exception as e: pytest.fail(f"SQL_VARCHAR parsing test failed: {e}") finally: - # Clean up - drop_table_if_exists(cursor, "pytest_varchar_test") + cursor.execute("DROP TABLE pytest_varchar_test") db_connection.commit() def test_numeric_precision_scale_positive_exponent(cursor, db_connection): @@ -1191,28 +1154,4 @@ def test_close(db_connection): except Exception as e: pytest.fail(f"Cursor close test failed: {e}") finally: - cursor = db_connection.cursor() - -def compare_row_value(actual, expected): - """Compare a row with expected values, supporting both namedtuple and list formats. - - Args: - actual: The actual row returned from fetchone() or similar (could be namedtuple) - expected: The expected values as a list - - Returns: - bool: True if the values match, False otherwise - """ - # If the actual row is a namedtuple (has _fields attribute) - if hasattr(actual, '_fields'): - # For single column result, compare the first field's value directly - if len(actual) == 1: - return getattr(actual, actual._fields[0]) == expected[0] - # For multiple columns, extract all values and compare as lists - actual_values = [getattr(actual, field) for field in actual._fields] - return actual_values == expected - # For regular list or tuple - elif isinstance(actual, (list, tuple)): - return actual == expected - - return False + cursor = db_connection.cursor() \ No newline at end of file From dfc55e649187588a9898b18c51bc0861ff32b8d5 Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Tue, 17 Jun 2025 10:51:33 +0530 Subject: [PATCH 08/10] Rewriting the logic --- mssql_python/cursor.py | 69 +++++++++++++++++++++--------------------- mssql_python/row.py | 58 +++++++++++++++++++++++++++++++++++ 2 files changed, 92 insertions(+), 35 deletions(-) create mode 100644 mssql_python/row.py diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 227e66a6..73d082f8 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -12,6 +12,7 @@ from mssql_python.helpers import check_error from mssql_python.logging_config import get_logger, ENABLE_LOGGING from mssql_python import ddbc_bindings +from .row import Row logger = get_logger() @@ -58,7 +59,8 @@ def __init__(self, connection) -> None: 1 # Default number of rows to fetch at a time is 1, user can change it ) self.buffer_length = 1024 # Default buffer length for string data - self.closed = False # Flag to indicate if the cursor is closed + self.closed = False + self._result_set_empty = False # Add this initialization self.last_executed_stmt = ( "" # Stores the last statement executed by this cursor ) @@ -643,68 +645,65 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None: total_rowcount = -1 self.rowcount = total_rowcount - def fetchone(self) -> Union[None, tuple]: + def fetchone(self) -> Union[None, Row]: """ Fetch the next row of a query result set. - + Returns: - Single sequence or None if no more data is available. - - Raises: - Error: If the previous call to execute did not produce any result set. + Single Row object or None if no more data is available. """ self._check_closed() # Check if the cursor is closed - row = [] - ret = ddbc_bindings.DDBCSQLFetchOne(self.hstmt, row) - check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret) + # Fetch raw data + row_data = [] + ret = ddbc_bindings.DDBCSQLFetchOne(self.hstmt, row_data) + if ret == ddbc_sql_const.SQL_NO_DATA.value: return None - return list(row) + + # Create and return a Row object + return Row(row_data, self.description) - def fetchmany(self, size: int = None) -> List[tuple]: + def fetchmany(self, size: int = None) -> List[Row]: """ Fetch the next set of rows of a query result. - + Args: size: Number of rows to fetch at a time. - + Returns: - Sequence of sequences (e.g. list of tuples). - - Raises: - Error: If the previous call to execute did not produce any result set. + List of Row objects. """ self._check_closed() # Check if the cursor is closed if size is None: size = self.arraysize - # Fetch the next set of rows - rows = [] - ret = ddbc_bindings.DDBCSQLFetchMany(self.hstmt, rows, size) - check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret) - if ret == ddbc_sql_const.SQL_NO_DATA.value: + if size <= 0: return [] - return rows + + # Fetch raw data + rows_data = [] + ret = ddbc_bindings.DDBCSQLFetchMany(self.hstmt, rows_data, size) + + # Convert raw data to Row objects + return [Row(row_data, self.description) for row_data in rows_data] - def fetchall(self) -> List[tuple]: + def fetchall(self) -> List[Row]: """ Fetch all (remaining) rows of a query result. - + Returns: - Sequence of sequences (e.g. list of tuples). - - Raises: - Error: If the previous call to execute did not produce any result set. + List of Row objects. """ self._check_closed() # Check if the cursor is closed - # Fetch all remaining rows - rows = [] - ret = ddbc_bindings.DDBCSQLFetchAll(self.hstmt, rows) - check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret) - return list(rows) + # Fetch raw data + rows_data = [] + ret = ddbc_bindings.DDBCSQLFetchAll(self.hstmt, rows_data) + + # Convert raw data to Row objects + return [Row(row_data, self.description) for row_data in rows_data] def nextset(self) -> Union[bool, None]: """ diff --git a/mssql_python/row.py b/mssql_python/row.py new file mode 100644 index 00000000..f6641e99 --- /dev/null +++ b/mssql_python/row.py @@ -0,0 +1,58 @@ +class Row: + """ + A row of data from a cursor fetch operation. Provides both tuple-like indexing + and attribute access to column values. + + Example: + row = cursor.fetchone() + print(row[0]) # Access by index + print(row.column_name) # Access by column name + """ + + def __init__(self, values, cursor_description): + """ + Initialize a Row object with values and cursor description. + + Args: + values: List of values for this row + cursor_description: The cursor description containing column metadata + """ + self._values = values + self.cursor_description = cursor_description + + # Create mapping of column names to indices + self._column_map = {} + for i, desc in enumerate(cursor_description): + if desc and desc[0]: # Ensure column name exists + self._column_map[desc[0]] = i + + def __getitem__(self, index): + """Allow accessing by numeric index: row[0]""" + return self._values[index] + + def __getattr__(self, name): + """Allow accessing by column name as attribute: row.column_name""" + if name in self._column_map: + return self._values[self._column_map[name]] + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") + + def __eq__(self, other): + """ + Support comparison with lists for test compatibility. + This is the key change needed to fix the tests. + """ + if isinstance(other, list): + return self._values == other + return super().__eq__(other) + + def __len__(self): + """Return the number of values in the row""" + return len(self._values) + + def __iter__(self): + """Allow iteration through values""" + return iter(self._values) + + def __repr__(self): + """Return a string representation of the row""" + return f"Row{tuple(self._values)}" \ No newline at end of file From 2892f23c7c97fb75ce190e88d55d67792e954bad Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Tue, 17 Jun 2025 10:53:24 +0530 Subject: [PATCH 09/10] adding line --- mssql_python/pybind/ddbc_bindings.cpp | 2 +- tests/test_004_cursor.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index f9db3d0b..68c772e8 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -1975,4 +1975,4 @@ PYBIND11_MODULE(ddbc_bindings, m) { // Log the error but don't throw - let the error happen when functions are called LOG("Failed to load ODBC driver during module initialization: {}", e.what()); } -} \ No newline at end of file +} diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index 62a284b1..9ee53302 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -1154,4 +1154,5 @@ def test_close(db_connection): except Exception as e: pytest.fail(f"Cursor close test failed: {e}") finally: - cursor = db_connection.cursor() \ No newline at end of file + cursor = db_connection.cursor() + \ No newline at end of file From 409f2fa0340783d9ea470ae62c8ee629223ce9d3 Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Tue, 17 Jun 2025 12:12:48 +0530 Subject: [PATCH 10/10] Adding todo for cursor_description --- main.py | 26 +++-------- mssql_python/row.py | 9 +++- tests/test_004_cursor.py | 95 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 109 insertions(+), 21 deletions(-) diff --git a/main.py b/main.py index c6908037..b45b88d7 100644 --- a/main.py +++ b/main.py @@ -3,33 +3,19 @@ import os import decimal -# setup_logging('stdout') +setup_logging('stdout') conn_str = os.getenv("DB_CONNECTION_STRING") conn = connect(conn_str) +# conn.autocommit = True + cursor = conn.cursor() cursor.execute("SELECT database_id, name from sys.databases;") -rows = cursor.fetchone() - -# Debug: Print the type and content of rows -print(f"Type of rows: {type(rows)}") -print(f"Value of rows: {rows}") +rows = cursor.fetchall() -# Only try to access properties if rows is not None -if rows is not None: - try: - # Try different ways to access the data - print(f"First column by index: {rows[0]}") - - # Access by attribute name (these should now work) - print(f"First column by name: {rows.database_id}") - print(f"Second column by name: {rows.name}") - - # Print all available attributes - print(f"Available attributes: {dir(rows)}") - except Exception as e: - print(f"Exception accessing row data: {e}") +for row in rows: + print(f"Database ID: {row[0]}, Name: {row[1]}") cursor.close() conn.close() \ No newline at end of file diff --git a/mssql_python/row.py b/mssql_python/row.py index f6641e99..bc74288d 100644 --- a/mssql_python/row.py +++ b/mssql_python/row.py @@ -18,7 +18,12 @@ def __init__(self, values, cursor_description): cursor_description: The cursor description containing column metadata """ self._values = values - self.cursor_description = cursor_description + + # TODO: ADO task - Optimize memory usage by sharing column map across rows + # Instead of storing the full cursor_description in each Row object: + # 1. Build the column map once at the cursor level after setting description + # 2. Pass only this map to each Row instance + # 3. Remove cursor_description from Row objects entirely # Create mapping of column names to indices self._column_map = {} @@ -43,6 +48,8 @@ def __eq__(self, other): """ if isinstance(other, list): return self._values == other + elif isinstance(other, Row): + return self._values == other._values return super().__eq__(other) def __len__(self): diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index 9ee53302..659a3164 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -1145,6 +1145,101 @@ def test_numeric_precision_scale_negative_exponent(cursor, db_connection): cursor.execute("DROP TABLE pytest_numeric_test") db_connection.commit() +def test_row_attribute_access(cursor, db_connection): + """Test accessing row values by column name as attributes""" + try: + # Create test table with multiple columns + cursor.execute(""" + CREATE TABLE pytest_row_attr_test ( + id INT PRIMARY KEY, + name VARCHAR(50), + email VARCHAR(100), + age INT + ) + """) + db_connection.commit() + + # Insert test data + cursor.execute(""" + INSERT INTO pytest_row_attr_test (id, name, email, age) + VALUES (1, 'John Doe', 'john@example.com', 30) + """) + db_connection.commit() + + # Test attribute access + cursor.execute("SELECT * FROM pytest_row_attr_test") + row = cursor.fetchone() + + # Access by attribute + assert row.id == 1, "Failed to access 'id' by attribute" + assert row.name == 'John Doe', "Failed to access 'name' by attribute" + assert row.email == 'john@example.com', "Failed to access 'email' by attribute" + assert row.age == 30, "Failed to access 'age' by attribute" + + # Compare attribute access with index access + assert row.id == row[0], "Attribute access for 'id' doesn't match index access" + assert row.name == row[1], "Attribute access for 'name' doesn't match index access" + assert row.email == row[2], "Attribute access for 'email' doesn't match index access" + assert row.age == row[3], "Attribute access for 'age' doesn't match index access" + + # Test attribute that doesn't exist + with pytest.raises(AttributeError): + value = row.nonexistent_column + + except Exception as e: + pytest.fail(f"Row attribute access test failed: {e}") + finally: + cursor.execute("DROP TABLE pytest_row_attr_test") + db_connection.commit() + +def test_row_comparison_with_list(cursor, db_connection): + """Test comparing Row objects with lists (__eq__ method)""" + try: + # Create test table + cursor.execute("CREATE TABLE pytest_row_comparison_test (col1 INT, col2 VARCHAR(20), col3 FLOAT)") + db_connection.commit() + + # Insert test data + cursor.execute("INSERT INTO pytest_row_comparison_test VALUES (10, 'test_string', 3.14)") + db_connection.commit() + + # Test fetchone comparison with list + cursor.execute("SELECT * FROM pytest_row_comparison_test") + row = cursor.fetchone() + assert row == [10, 'test_string', 3.14], "Row did not compare equal to matching list" + assert row != [10, 'different', 3.14], "Row compared equal to non-matching list" + + # Test full row equality + cursor.execute("SELECT * FROM pytest_row_comparison_test") + row1 = cursor.fetchone() + cursor.execute("SELECT * FROM pytest_row_comparison_test") + row2 = cursor.fetchone() + assert row1 == row2, "Identical rows should be equal" + + # Insert different data + cursor.execute("INSERT INTO pytest_row_comparison_test VALUES (20, 'other_string', 2.71)") + db_connection.commit() + + # Test different rows are not equal + cursor.execute("SELECT * FROM pytest_row_comparison_test WHERE col1 = 10") + row1 = cursor.fetchone() + cursor.execute("SELECT * FROM pytest_row_comparison_test WHERE col1 = 20") + row2 = cursor.fetchone() + assert row1 != row2, "Different rows should not be equal" + + # Test fetchmany row comparison with lists + cursor.execute("SELECT * FROM pytest_row_comparison_test ORDER BY col1") + rows = cursor.fetchmany(2) + assert len(rows) == 2, "Should have fetched 2 rows" + assert rows[0] == [10, 'test_string', 3.14], "First row didn't match expected list" + assert rows[1] == [20, 'other_string', 2.71], "Second row didn't match expected list" + + except Exception as e: + pytest.fail(f"Row comparison test failed: {e}") + finally: + cursor.execute("DROP TABLE pytest_row_comparison_test") + db_connection.commit() + def test_close(db_connection): """Test closing the cursor""" try: