diff --git a/mssql_python/constants.py b/mssql_python/constants.py index 81e60d37..20c8f663 100644 --- a/mssql_python/constants.py +++ b/mssql_python/constants.py @@ -97,7 +97,6 @@ class ConstantsDDBC(Enum): SQL_ATTR_ROW_ARRAY_SIZE = 27 SQL_ATTR_ROWS_FETCHED_PTR = 26 SQL_ATTR_ROW_STATUS_PTR = 25 - SQL_FETCH_NEXT = 1 SQL_ROW_SUCCESS = 0 SQL_ROW_SUCCESS_WITH_INFO = 1 SQL_ROW_NOROW = 100 @@ -117,6 +116,14 @@ class ConstantsDDBC(Enum): SQL_NULLABLE = 1 SQL_MAX_NUMERIC_LEN = 16 + SQL_FETCH_NEXT = 1 + SQL_FETCH_FIRST = 2 + SQL_FETCH_LAST = 3 + SQL_FETCH_PRIOR = 4 + SQL_FETCH_ABSOLUTE = 5 + SQL_FETCH_RELATIVE = 6 + SQL_FETCH_BOOKMARK = 8 + class AuthType(Enum): """Constants for authentication types""" INTERACTIVE = "activedirectoryinteractive" diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 217e0475..12be28fe 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -8,7 +8,6 @@ - Do not use a cursor after it is closed, or after its parent connection is closed. - Use close() to release resources held by the cursor as soon as it is no longer needed. """ -import ctypes import decimal import uuid import datetime @@ -16,7 +15,7 @@ from mssql_python.constants import ConstantsDDBC as ddbc_sql_const from mssql_python.helpers import check_error, log from mssql_python import ddbc_bindings -from mssql_python.exceptions import InterfaceError +from mssql_python.exceptions import InterfaceError, NotSupportedError, ProgrammingError from .row import Row @@ -77,8 +76,12 @@ def __init__(self, connection) -> None: # Therefore, it must be a list with exactly one bool element. # rownumber attribute - self._rownumber = -1 # Track the current row index in the result set + self._rownumber = -1 # DB-API extension: last returned row index, -1 before first + self._next_row_index = 0 # internal: index of the next row the driver will return (0-based) self._has_result_set = False # Track if we have an active result set + self._skip_increment_for_next_fetch = False # Track if we need to skip incrementing the row index + + self.messages = [] # Store diagnostic messages def _is_unicode_string(self, param): """ @@ -452,6 +455,9 @@ def close(self) -> None: if self.closed: raise Exception("Cursor is already closed.") + # Clear messages per DBAPI + self.messages = [] + if self.hstmt: self.hstmt.free() self.hstmt = None @@ -594,18 +600,21 @@ def connection(self): def _reset_rownumber(self): """Reset the rownumber tracking when starting a new result set.""" self._rownumber = -1 + self._next_row_index = 0 self._has_result_set = True + self._skip_increment_for_next_fetch = False def _increment_rownumber(self): """ - Increment the rownumber by 1. - - This should be called after each fetch operation to keep track of the current row index. + Called after a successful fetch from the driver. Keep both counters consistent. 
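+
+        Illustrative trace (hypothetical three-row result set; values shown are
+        the two counters after each call):
+
+            cursor.execute("SELECT n FROM t")   # _rownumber == -1, _next_row_index == 0
+            cursor.fetchone()                   # _rownumber ==  0, _next_row_index == 1
+            cursor.fetchone()                   # _rownumber ==  1, _next_row_index == 2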
""" if self._has_result_set: - self._rownumber += 1 + # driver returned one row, so the next row index increments by 1 + self._next_row_index += 1 + # rownumber is last returned row index + self._rownumber = self._next_row_index - 1 else: - raise InterfaceError("Cannot increment rownumber: no active result set.") + raise InterfaceError("Cannot increment rownumber: no active result set.", "No active result set.") # Will be used when we add support for scrollable cursors def _decrement_rownumber(self): @@ -620,8 +629,8 @@ def _decrement_rownumber(self): else: self._rownumber = -1 else: - raise InterfaceError("Cannot decrement rownumber: no active result set.") - + raise InterfaceError("Cannot decrement rownumber: no active result set.", "No active result set.") + def _clear_rownumber(self): """ Clear the rownumber tracking. @@ -630,6 +639,7 @@ def _clear_rownumber(self): """ self._rownumber = -1 self._has_result_set = False + self._skip_increment_for_next_fetch = False def __iter__(self): """ @@ -693,6 +703,9 @@ def execute( if reset_cursor: self._reset_cursor() + # Clear any previous messages + self.messages = [] + param_info = ddbc_bindings.ParamInfo parameters_type = [] @@ -740,7 +753,14 @@ def execute( self.is_stmt_prepared, use_prepare, ) + + # Check for errors but don't raise exceptions for info/warning messages check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret) + + # Capture any diagnostic messages (SQL_SUCCESS_WITH_INFO, etc.) + if self.hstmt: + self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt)) + self.last_executed_stmt = operation # Update rowcount after execution @@ -752,8 +772,10 @@ def execute( # Reset rownumber for new result set (only for SELECT statements) if self.description: # If we have column descriptions, it's likely a SELECT + self.rowcount = -1 self._reset_rownumber() else: + self.rowcount = ddbc_bindings.DDBCSQLRowCount(self.hstmt) self._clear_rownumber() # Return self for method chaining @@ -820,7 +842,10 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None: """ self._check_closed() self._reset_cursor() - + + # Clear any previous messages + self.messages = [] + if not seq_of_parameters: self.rowcount = 0 return @@ -852,13 +877,19 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None: ) check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret) + # Capture any diagnostic messages after execution + if self.hstmt: + self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt)) + self.rowcount = ddbc_bindings.DDBCSQLRowCount(self.hstmt) self.last_executed_stmt = operation self._initialize_description() if self.description: + self.rowcount = -1 self._reset_rownumber() else: + self.rowcount = ddbc_bindings.DDBCSQLRowCount(self.hstmt) self._clear_rownumber() def fetchone(self) -> Union[None, Row]: @@ -875,14 +906,22 @@ def fetchone(self) -> Union[None, Row]: try: ret = ddbc_bindings.DDBCSQLFetchOne(self.hstmt, row_data) + if self.hstmt: + self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt)) + if ret == ddbc_sql_const.SQL_NO_DATA.value: return None - # Only increment rownumber for successful fetch with data - self._increment_rownumber() + # Update internal position after successful fetch + if self._skip_increment_for_next_fetch: + self._skip_increment_for_next_fetch = False + self._next_row_index += 1 + else: + self._increment_rownumber() - # Create and return a Row object - return Row(row_data, self.description) + # Create and return a Row object, passing 
column name map if available + column_map = getattr(self, '_column_name_map', None) + return Row(row_data, self.description, column_map) except Exception as e: # On error, don't increment rownumber - rethrow the error raise e @@ -898,6 +937,8 @@ def fetchmany(self, size: int = None) -> List[Row]: List of Row objects. """ self._check_closed() # Check if the cursor is closed + if not self._has_result_set and self.description: + self._reset_rownumber() if size is None: size = self.arraysize @@ -909,14 +950,20 @@ def fetchmany(self, size: int = None) -> List[Row]: rows_data = [] try: ret = ddbc_bindings.DDBCSQLFetchMany(self.hstmt, rows_data, size) + + if self.hstmt: + self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt)) + # Update rownumber for the number of rows actually fetched if rows_data and self._has_result_set: - for _ in rows_data: - self._increment_rownumber() + # advance counters by number of rows actually returned + self._next_row_index += len(rows_data) + self._rownumber = self._next_row_index - 1 # Convert raw data to Row objects - return [Row(row_data, self.description) for row_data in rows_data] + column_map = getattr(self, '_column_name_map', None) + return [Row(row_data, self.description, column_map) for row_data in rows_data] except Exception as e: # On error, don't increment rownumber - rethrow the error raise e @@ -929,19 +976,26 @@ def fetchall(self) -> List[Row]: List of Row objects. """ self._check_closed() # Check if the cursor is closed + if not self._has_result_set and self.description: + self._reset_rownumber() # Fetch raw data rows_data = [] try: ret = ddbc_bindings.DDBCSQLFetchAll(self.hstmt, rows_data) + + if self.hstmt: + self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt)) + # Update rownumber for the number of rows actually fetched if rows_data and self._has_result_set: - for _ in rows_data: - self._increment_rownumber() + self._next_row_index += len(rows_data) + self._rownumber = self._next_row_index - 1 # Convert raw data to Row objects - return [Row(row_data, self.description) for row_data in rows_data] + column_map = getattr(self, '_column_name_map', None) + return [Row(row_data, self.description, column_map) for row_data in rows_data] except Exception as e: # On error, don't increment rownumber - rethrow the error raise e @@ -958,6 +1012,9 @@ def nextset(self) -> Union[bool, None]: """ self._check_closed() # Check if the cursor is closed + # Clear messages per DBAPI + self.messages = [] + # Skip to the next result set ret = ddbc_bindings.DDBCSQLMoreResults(self.hstmt) check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret) @@ -996,6 +1053,11 @@ def fetchval(self): """ self._check_closed() # Check if the cursor is closed + # Check if this is a result-producing statement + if not self.description: + # Non-result-set statement (INSERT, UPDATE, DELETE, etc.) + return None + # Fetch the first row row = self.fetchone() @@ -1010,6 +1072,64 @@ def fetchval(self): # Return the first column value (could be None if the column value is NULL) return row[0] + def commit(self): + """ + Commit all SQL statements executed on the connection that created this cursor. + + This is a convenience method that calls commit() on the underlying connection. + It affects all cursors created by the same connection since the last commit/rollback. + + The benefit is that many uses can now just use the cursor and not have to track + the connection object. + + Raises: + Exception: If the cursor is closed or if the commit operation fails. 
+ + Example: + >>> cursor.execute("INSERT INTO users (name) VALUES (?)", "John") + >>> cursor.commit() # Commits the INSERT + + Note: + This is equivalent to calling connection.commit() but provides convenience + for code that only has access to the cursor object. + """ + self._check_closed() # Check if the cursor is closed + + # Clear messages per DBAPI + self.messages = [] + + # Delegate to the connection's commit method + self._connection.commit() + + def rollback(self): + """ + Roll back all SQL statements executed on the connection that created this cursor. + + This is a convenience method that calls rollback() on the underlying connection. + It affects all cursors created by the same connection since the last commit/rollback. + + The benefit is that many uses can now just use the cursor and not have to track + the connection object. + + Raises: + Exception: If the cursor is closed or if the rollback operation fails. + + Example: + >>> cursor.execute("INSERT INTO users (name) VALUES (?)", "John") + >>> cursor.rollback() # Rolls back the INSERT + + Note: + This is equivalent to calling connection.rollback() but provides convenience + for code that only has access to the cursor object. + """ + self._check_closed() # Check if the cursor is closed + + # Clear messages per DBAPI + self.messages = [] + + # Delegate to the connection's rollback method + self._connection.rollback() + def __del__(self): """ Destructor to ensure the cursor is closed when it is no longer needed. @@ -1021,4 +1141,243 @@ def __del__(self): self.close() except Exception as e: # Don't raise an exception in __del__, just log it - log('error', "Error during cursor cleanup in __del__: %s", e) \ No newline at end of file + log('error', "Error during cursor cleanup in __del__: %s", e) + + def scroll(self, value: int, mode: str = 'relative') -> None: + """ + Scroll using SQLFetchScroll only, matching test semantics: + - relative(N>0): consume N rows; rownumber = previous + N; next fetch returns the following row. + - absolute(-1): before first (rownumber = -1), no data consumed. + - absolute(0): position so next fetch returns first row; rownumber stays 0 even after that fetch. + - absolute(k>0): next fetch returns row index k (0-based); rownumber == k after scroll. + """ + self._check_closed() + + # Clear messages per DBAPI + self.messages = [] + + if mode not in ('relative', 'absolute'): + raise ProgrammingError("Invalid scroll mode", + f"mode must be 'relative' or 'absolute', got '{mode}'") + if not self._has_result_set: + raise ProgrammingError("No active result set", + "Cannot scroll: no result set available. 
Execute a query first.")
+        if not isinstance(value, int):
+            raise ProgrammingError("Invalid scroll value type",
+                                   f"scroll value must be an integer, got {type(value).__name__}")
+
+        # Relative backward not supported
+        if mode == 'relative' and value < 0:
+            raise NotSupportedError("Backward scrolling not supported",
+                                    f"Cannot move backward by {value} rows on a forward-only cursor")
+
+        row_data: list = []
+
+        # Absolute special cases
+        if mode == 'absolute':
+            if value == -1:
+                # Before first
+                ddbc_bindings.DDBCSQLFetchScroll(self.hstmt,
+                                                 ddbc_sql_const.SQL_FETCH_ABSOLUTE.value,
+                                                 0, row_data)
+                self._rownumber = -1
+                self._next_row_index = 0
+                return
+            if value == 0:
+                # Before first, but tests want rownumber==0 pre and post the next fetch
+                ddbc_bindings.DDBCSQLFetchScroll(self.hstmt,
+                                                 ddbc_sql_const.SQL_FETCH_ABSOLUTE.value,
+                                                 0, row_data)
+                self._rownumber = 0
+                self._next_row_index = 0
+                self._skip_increment_for_next_fetch = True
+                return
+
+        try:
+            if mode == 'relative':
+                if value == 0:
+                    return
+                ret = ddbc_bindings.DDBCSQLFetchScroll(self.hstmt,
+                                                       ddbc_sql_const.SQL_FETCH_RELATIVE.value,
+                                                       value, row_data)
+                if ret == ddbc_sql_const.SQL_NO_DATA.value:
+                    raise IndexError("Cannot scroll to specified position: end of result set reached")
+                # Consume N rows; last-returned index advances by N
+                self._rownumber = self._rownumber + value
+                self._next_row_index = self._rownumber + 1
+                return
+
+            # absolute(k>0): map Python k (0-based next row) to ODBC ABSOLUTE k (1-based),
+            # intentionally passing k so ODBC fetches row #k (1-based), i.e., 0-based (k-1),
+            # leaving the NEXT fetch to return 0-based index k.
+            ret = ddbc_bindings.DDBCSQLFetchScroll(self.hstmt,
+                                                   ddbc_sql_const.SQL_FETCH_ABSOLUTE.value,
+                                                   value, row_data)
+            if ret == ddbc_sql_const.SQL_NO_DATA.value:
+                raise IndexError(f"Cannot scroll to position {value}: end of result set reached")
+
+            # Tests expect rownumber == value after absolute(value);
+            # the next fetch should return row index 'value'
+            self._rownumber = value
+            self._next_row_index = value
+
+        except Exception as e:
+            if isinstance(e, (IndexError, NotSupportedError)):
+                raise
+            raise IndexError(f"Scroll operation failed: {e}") from e
+
+    def skip(self, count: int) -> None:
+        """
+        Skip the next count records in the query result set.
+
+        Args:
+            count: Number of records to skip.
+
+        Raises:
+            IndexError: If attempting to skip past the end of the result set.
+            ProgrammingError: If count is not an integer.
+            NotSupportedError: If attempting to skip backwards.
+        """
+        self._check_closed()
+
+        # Clear messages
+        self.messages = []
+
+        # Simply delegate to the scroll method with 'relative' mode
+        self.scroll(count, 'relative')
+
+    def _execute_tables(self, stmt_handle, catalog_name=None, schema_name=None, table_name=None,
+                        table_type=None, search_escape=None):
+        """
+        Execute SQLTables ODBC function to retrieve table metadata.
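+
+        Illustrative internal call (this mirrors how tables() below invokes the
+        helper; the argument values are examples only):
+
+            self._execute_tables(self.hstmt, schema_name="dbo", table_type="TABLE,VIEW")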
+ + Args: + stmt_handle: ODBC statement handle + catalog_name: The catalog name pattern + schema_name: The schema name pattern + table_name: The table name pattern + table_type: The table type filter + search_escape: The escape character for pattern matching + """ + # Convert None values to empty strings for ODBC + catalog = "" if catalog_name is None else catalog_name + schema = "" if schema_name is None else schema_name + table = "" if table_name is None else table_name + types = "" if table_type is None else table_type + + # Call the ODBC SQLTables function + retcode = ddbc_bindings.DDBCSQLTables( + stmt_handle, + catalog, + schema, + table, + types + ) + + # Check return code and handle errors + check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, stmt_handle, retcode) + + # Capture any diagnostic messages + if stmt_handle: + self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(stmt_handle)) + + def tables(self, table=None, catalog=None, schema=None, tableType=None): + """ + Returns information about tables in the database that match the given criteria using + the SQLTables ODBC function. + + Args: + table (str, optional): The table name pattern. Default is None (all tables). + catalog (str, optional): The catalog name. Default is None. + schema (str, optional): The schema name pattern. Default is None. + tableType (str or list, optional): The table type filter. Default is None. + Example: "TABLE" or ["TABLE", "VIEW"] + + Returns: + list: A list of Row objects containing table information with these columns: + - table_cat: Catalog name + - table_schem: Schema name + - table_name: Table name + - table_type: Table type (e.g., "TABLE", "VIEW") + - remarks: Comments about the table + + Notes: + This method only processes the standard five columns as defined in the ODBC + specification. Any additional columns that might be returned by specific ODBC + drivers are not included in the result set. 
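+
+            When tableType is given as a list or tuple it is joined into the
+            comma-separated string that SQLTables expects, for example:
+
+                cursor.tables(tableType=["TABLE", "VIEW"])   # sent as "TABLE,VIEW"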
+ + Example: + # Get all tables in the database + tables = cursor.tables() + + # Get all tables in schema 'dbo' + tables = cursor.tables(schema='dbo') + + # Get table named 'Customers' + tables = cursor.tables(table='Customers') + + # Get all views + tables = cursor.tables(tableType='VIEW') + """ + self._check_closed() + + # Clear messages + self.messages = [] + + # Always reset the cursor first to ensure clean state + self._reset_cursor() + + # Format table_type parameter - SQLTables expects comma-separated string + table_type_str = None + if tableType is not None: + if isinstance(tableType, (list, tuple)): + table_type_str = ",".join(tableType) + else: + table_type_str = str(tableType) + + # Call SQLTables via the helper method + self._execute_tables( + self.hstmt, + catalog_name=catalog, + schema_name=schema, + table_name=table, + table_type=table_type_str + ) + + # Initialize description from column metadata + column_metadata = [] + try: + ddbc_bindings.DDBCSQLDescribeCol(self.hstmt, column_metadata) + self._initialize_description(column_metadata) + except Exception: + # If describe fails, create a manual description for the standard columns + column_types = [str, str, str, str, str] + self.description = [ + ("table_cat", column_types[0], None, 128, 128, 0, True), + ("table_schem", column_types[1], None, 128, 128, 0, True), + ("table_name", column_types[2], None, 128, 128, 0, False), + ("table_type", column_types[3], None, 128, 128, 0, False), + ("remarks", column_types[4], None, 254, 254, 0, True) + ] + + # Define column names in ODBC standard order + column_names = [ + "table_cat", "table_schem", "table_name", "table_type", "remarks" + ] + + # Fetch all rows + rows_data = [] + ddbc_bindings.DDBCSQLFetchAll(self.hstmt, rows_data) + + # Create a column map for attribute access + column_map = {name: i for i, name in enumerate(column_names)} + + # Create Row objects with the column map + result_rows = [] + for row_data in rows_data: + row = Row(row_data, self.description, column_map) + result_rows.append(row) + + return result_rows \ No newline at end of file diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 1b37b8f0..b5cabd4b 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -134,6 +134,7 @@ SQLFreeStmtFunc SQLFreeStmt_ptr = nullptr; // Diagnostic APIs SQLGetDiagRecFunc SQLGetDiagRec_ptr = nullptr; +SQLTablesFunc SQLTables_ptr = nullptr; namespace { @@ -786,6 +787,7 @@ DriverHandle LoadDriverOrThrowException() { SQLFreeStmt_ptr = GetFunctionPointer(handle, "SQLFreeStmt"); SQLGetDiagRec_ptr = GetFunctionPointer(handle, "SQLGetDiagRecW"); + SQLTables_ptr = GetFunctionPointer(handle, "SQLTablesW"); bool success = SQLAllocHandle_ptr && SQLSetEnvAttr_ptr && SQLSetConnectAttr_ptr && @@ -796,7 +798,7 @@ DriverHandle LoadDriverOrThrowException() { SQLGetData_ptr && SQLNumResultCols_ptr && SQLBindCol_ptr && SQLDescribeCol_ptr && SQLMoreResults_ptr && SQLColAttribute_ptr && SQLEndTran_ptr && SQLDisconnect_ptr && SQLFreeHandle_ptr && - SQLFreeStmt_ptr && SQLGetDiagRec_ptr; + SQLFreeStmt_ptr && SQLGetDiagRec_ptr && SQLTables_ptr; if (!success) { ThrowStdException("Failed to load required function pointers from driver."); @@ -901,6 +903,65 @@ ErrorInfo SQLCheckError_Wrap(SQLSMALLINT handleType, SqlHandlePtr handle, SQLRET return errorInfo; } +py::list SQLGetAllDiagRecords(SqlHandlePtr handle) { + LOG("Retrieving all diagnostic records"); + if (!SQLGetDiagRec_ptr) { + LOG("Function pointer not initialized. 
Loading the driver.");
+        DriverLoader::getInstance().loadDriver();
+    }
+
+    py::list records;
+    SQLHANDLE rawHandle = handle->get();
+    SQLSMALLINT handleType = handle->type();
+
+    // Iterate through all available diagnostic records
+    for (SQLSMALLINT recNumber = 1; ; recNumber++) {
+        SQLWCHAR sqlState[6] = {0};
+        SQLWCHAR message[SQL_MAX_MESSAGE_LENGTH] = {0};
+        SQLINTEGER nativeError = 0;
+        SQLSMALLINT messageLen = 0;
+
+        SQLRETURN diagReturn = SQLGetDiagRec_ptr(
+            handleType, rawHandle, recNumber, sqlState, &nativeError,
+            message, SQL_MAX_MESSAGE_LENGTH, &messageLen);
+
+        if (diagReturn == SQL_NO_DATA || !SQL_SUCCEEDED(diagReturn))
+            break;
+
+#if defined(_WIN32)
+        // On Windows, create a formatted UTF-8 string for state+error
+        char stateWithError[50];
+        sprintf(stateWithError, "[%ls] (%d)", sqlState, nativeError);
+
+        // Convert wide string message to UTF-8
+        int msgSize = WideCharToMultiByte(CP_UTF8, 0, message, -1, NULL, 0, NULL, NULL);
+        std::vector<char> msgBuffer(msgSize);
+        WideCharToMultiByte(CP_UTF8, 0, message, -1, msgBuffer.data(), msgSize, NULL, NULL);
+
+        // Create the tuple with converted strings
+        records.append(py::make_tuple(
+            py::str(stateWithError),
+            py::str(msgBuffer.data())
+        ));
+#else
+        // On Unix, use the SQLWCHARToWString utility and then convert to UTF-8
+        std::string stateStr = WideToUTF8(SQLWCHARToWString(sqlState));
+        std::string msgStr = WideToUTF8(SQLWCHARToWString(message, messageLen));
+
+        // Format the state string
+        std::string stateWithError = "[" + stateStr + "] (" + std::to_string(nativeError) + ")";
+
+        // Create the tuple with converted strings
+        records.append(py::make_tuple(
+            py::str(stateWithError),
+            py::str(msgStr)
+        ));
+#endif
+    }
+
+    return records;
+}
+
 // Wrap SQLExecDirect
 SQLRETURN SQLExecDirect_wrap(SqlHandlePtr StatementHandle, const std::wstring& Query) {
     LOG("Execute SQL query directly - {}", Query.c_str());
@@ -909,6 +970,18 @@ SQLRETURN SQLExecDirect_wrap(SqlHandlePtr StatementHandle, const std::wstring& Q
         DriverLoader::getInstance().loadDriver();  // Load the driver
     }
 
+    // Ensure statement is scrollable BEFORE executing
+    if (SQLSetStmtAttr_ptr && StatementHandle && StatementHandle->get()) {
+        SQLSetStmtAttr_ptr(StatementHandle->get(),
+                           SQL_ATTR_CURSOR_TYPE,
+                           (SQLPOINTER)SQL_CURSOR_STATIC,
+                           0);
+        SQLSetStmtAttr_ptr(StatementHandle->get(),
+                           SQL_ATTR_CONCURRENCY,
+                           (SQLPOINTER)SQL_CONCUR_READ_ONLY,
+                           0);
+    }
+
     SQLWCHAR* queryPtr;
 #if defined(__APPLE__) || defined(__linux__)
     std::vector<SQLWCHAR> queryBuffer = WStringToSQLWCHAR(Query);
@@ -923,6 +996,91 @@ SQLRETURN SQLExecDirect_wrap(SqlHandlePtr StatementHandle, const std::wstring& Q
     return ret;
 }
 
+// Wrapper for SQLTables
+SQLRETURN SQLTables_wrap(SqlHandlePtr StatementHandle,
+                         const std::wstring& catalog,
+                         const std::wstring& schema,
+                         const std::wstring& table,
+                         const std::wstring& tableType) {
+
+    if (!SQLTables_ptr) {
+        LOG("Function pointer not initialized. Loading the driver.");
+        DriverLoader::getInstance().loadDriver();
+    }
+
+    SQLWCHAR* catalogPtr = nullptr;
+    SQLWCHAR* schemaPtr = nullptr;
+    SQLWCHAR* tablePtr = nullptr;
+    SQLWCHAR* tableTypePtr = nullptr;
+    SQLSMALLINT catalogLen = 0;
+    SQLSMALLINT schemaLen = 0;
+    SQLSMALLINT tableLen = 0;
+    SQLSMALLINT tableTypeLen = 0;
+
+    std::vector<SQLWCHAR> catalogBuffer;
+    std::vector<SQLWCHAR> schemaBuffer;
+    std::vector<SQLWCHAR> tableBuffer;
+    std::vector<SQLWCHAR> tableTypeBuffer;
+
+#if defined(__APPLE__) || defined(__linux__)
+    // On Unix platforms, convert wstring to SQLWCHAR array
+    if (!catalog.empty()) {
+        catalogBuffer = WStringToSQLWCHAR(catalog);
+        catalogPtr = catalogBuffer.data();
+        catalogLen = SQL_NTS;
+    }
+    if (!schema.empty()) {
+        schemaBuffer = WStringToSQLWCHAR(schema);
+        schemaPtr = schemaBuffer.data();
+        schemaLen = SQL_NTS;
+    }
+    if (!table.empty()) {
+        tableBuffer = WStringToSQLWCHAR(table);
+        tablePtr = tableBuffer.data();
+        tableLen = SQL_NTS;
+    }
+    if (!tableType.empty()) {
+        tableTypeBuffer = WStringToSQLWCHAR(tableType);
+        tableTypePtr = tableTypeBuffer.data();
+        tableTypeLen = SQL_NTS;
+    }
+#else
+    // On Windows, direct assignment works
+    if (!catalog.empty()) {
+        catalogPtr = const_cast<SQLWCHAR*>(catalog.c_str());
+        catalogLen = SQL_NTS;
+    }
+    if (!schema.empty()) {
+        schemaPtr = const_cast<SQLWCHAR*>(schema.c_str());
+        schemaLen = SQL_NTS;
+    }
+    if (!table.empty()) {
+        tablePtr = const_cast<SQLWCHAR*>(table.c_str());
+        tableLen = SQL_NTS;
+    }
+    if (!tableType.empty()) {
+        tableTypePtr = const_cast<SQLWCHAR*>(tableType.c_str());
+        tableTypeLen = SQL_NTS;
+    }
+#endif
+
+    SQLRETURN ret = SQLTables_ptr(
+        StatementHandle->get(),
+        catalogPtr, catalogLen,
+        schemaPtr, schemaLen,
+        tablePtr, tableLen,
+        tableTypePtr, tableTypeLen
+    );
+
+    if (!SQL_SUCCEEDED(ret)) {
+        LOG("SQLTables failed with return code: {}", ret);
+    } else {
+        LOG("SQLTables succeeded");
+    }
+
+    return ret;
+}
+
 // Executes the provided query. If the query is parametrized, it prepares the statement and
 // binds the parameters. Otherwise, it executes the query directly.
 // 'usePrepare' parameter can be used to disable the prepare step for queries that might already
@@ -948,6 +1106,19 @@ SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle,
     if (!statementHandle || !statementHandle->get()) {
         LOG("Statement handle is null or empty");
     }
+
+    // Ensure statement is scrollable BEFORE executing
+    if (SQLSetStmtAttr_ptr && hStmt) {
+        SQLSetStmtAttr_ptr(hStmt,
+                           SQL_ATTR_CURSOR_TYPE,
+                           (SQLPOINTER)SQL_CURSOR_STATIC,
+                           0);
+        SQLSetStmtAttr_ptr(hStmt,
+                           SQL_ATTR_CONCURRENCY,
+                           (SQLPOINTER)SQL_CONCUR_READ_ONLY,
+                           0);
+    }
+
     SQLWCHAR* queryPtr;
 #if defined(__APPLE__) || defined(__linux__)
     std::vector<SQLWCHAR> queryBuffer = WStringToSQLWCHAR(query);
@@ -1817,6 +1988,20 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p
     return ret;
 }
 
+SQLRETURN SQLFetchScroll_wrap(SqlHandlePtr StatementHandle, SQLSMALLINT FetchOrientation, SQLLEN FetchOffset, py::list& /*row_data*/) {
+    LOG("Fetching with scroll: orientation={}, offset={}", FetchOrientation, FetchOffset);
+    if (!SQLFetchScroll_ptr) {
+        LOG("Function pointer not initialized. Loading the driver.");
+        DriverLoader::getInstance().loadDriver();  // Load the driver
+    }
+
+    // Perform scroll; do not fetch row data here
+    return SQLFetchScroll_ptr
+        ? 
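+          // Forward the orientation/offset pair to the driver unchanged; if the
+          // SQLFetchScroll entry point is still missing after the load attempt,
+          // fall through to SQL_ERROR so the Python layer can raise.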
SQLFetchScroll_ptr(StatementHandle->get(), FetchOrientation, FetchOffset) + : SQL_ERROR; +} + + // For column in the result set, binds a buffer to retrieve column data // TODO: Move to anonymous namespace, since it is not used outside this file SQLRETURN SQLBindColums(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& columnNames, @@ -2307,6 +2492,10 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetch return ret; } + // Reset attributes before returning to avoid using stack pointers later + SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)1, 0); + SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, NULL, 0); + return ret; } @@ -2396,6 +2585,10 @@ SQLRETURN FetchAll_wrap(SqlHandlePtr StatementHandle, py::list& rows) { return ret; } } + + // Reset attributes before returning to avoid using stack pointers later + SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)1, 0); + SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, NULL, 0); return ret; } @@ -2553,6 +2746,16 @@ PYBIND11_MODULE(ddbc_bindings, m) { m.def("DDBCSQLFetchAll", &FetchAll_wrap, "Fetch all rows from the result set"); m.def("DDBCSQLFreeHandle", &SQLFreeHandle_wrap, "Free a handle"); m.def("DDBCSQLCheckError", &SQLCheckError_Wrap, "Check for driver errors"); + m.def("DDBCSQLGetAllDiagRecords", &SQLGetAllDiagRecords, + "Get all diagnostic records for a handle", + py::arg("handle")); + m.def("DDBCSQLTables", &SQLTables_wrap, + "Get table information using ODBC SQLTables", + py::arg("StatementHandle"), py::arg("catalog") = std::wstring(), + py::arg("schema") = std::wstring(), py::arg("table") = std::wstring(), + py::arg("tableType") = std::wstring()); + m.def("DDBCSQLFetchScroll", &SQLFetchScroll_wrap, + "Scroll to a specific position in the result set and optionally fetch data"); // Add a version attribute m.attr("__version__") = "1.0.0"; diff --git a/mssql_python/pybind/ddbc_bindings.h b/mssql_python/pybind/ddbc_bindings.h index 22bc524b..1bb3efb0 100644 --- a/mssql_python/pybind/ddbc_bindings.h +++ b/mssql_python/pybind/ddbc_bindings.h @@ -105,7 +105,18 @@ typedef SQLRETURN (SQL_API* SQLDescribeColFunc)(SQLHSTMT, SQLUSMALLINT, SQLWCHAR typedef SQLRETURN (SQL_API* SQLMoreResultsFunc)(SQLHSTMT); typedef SQLRETURN (SQL_API* SQLColAttributeFunc)(SQLHSTMT, SQLUSMALLINT, SQLUSMALLINT, SQLPOINTER, SQLSMALLINT, SQLSMALLINT*, SQLPOINTER); - +typedef SQLRETURN (*SQLTablesFunc)( + SQLHSTMT StatementHandle, + SQLWCHAR* CatalogName, + SQLSMALLINT NameLength1, + SQLWCHAR* SchemaName, + SQLSMALLINT NameLength2, + SQLWCHAR* TableName, + SQLSMALLINT NameLength3, + SQLWCHAR* TableType, + SQLSMALLINT NameLength4 +); + // Transaction APIs typedef SQLRETURN (SQL_API* SQLEndTranFunc)(SQLSMALLINT, SQLHANDLE, SQLSMALLINT); @@ -148,6 +159,7 @@ extern SQLBindColFunc SQLBindCol_ptr; extern SQLDescribeColFunc SQLDescribeCol_ptr; extern SQLMoreResultsFunc SQLMoreResults_ptr; extern SQLColAttributeFunc SQLColAttribute_ptr; +extern SQLTablesFunc SQLTables_ptr; // Transaction APIs extern SQLEndTranFunc SQLEndTran_ptr; diff --git a/mssql_python/row.py b/mssql_python/row.py index 2c88412d..bbea7fde 100644 --- a/mssql_python/row.py +++ b/mssql_python/row.py @@ -9,27 +9,27 @@ class Row: print(row.column_name) # Access by column name """ - def __init__(self, values, cursor_description): + def __init__(self, values, description, column_map=None): """ - Initialize a Row object with values and cursor description. + Initialize a Row object with values and description. 
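+
+        Example (illustrative; assumes cursor.description names the second
+        column 'name'):
+
+            row = Row((1, 'Alice'), description)
+            row[1]       # 'Alice'  -- by index
+            row.name     # 'Alice'  -- exact column name, or its lowercase form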
        Args:
-            values: List of values for this row
-            cursor_description: The cursor description containing column metadata
+            values: List of values for this row.
+            description: Description of the columns (from cursor.description).
+            column_map: Optional mapping of column names to indices.
         """
         self._values = values
+        self._description = description
 
-        # TODO: ADO task - Optimize memory usage by sharing column map across rows
-        # Instead of storing the full cursor_description in each Row object:
-        # 1. Build the column map once at the cursor level after setting description
-        # 2. Pass only this map to each Row instance
-        # 3. Remove cursor_description from Row objects entirely
-
-        # Create mapping of column names to indices
-        self._column_map = {}
-        for i, desc in enumerate(cursor_description):
-            if desc and desc[0]:  # Ensure column name exists
-                self._column_map[desc[0]] = i
+        # Build column map if not provided
+        if column_map is None:
+            self._column_map = {}
+            for i, desc in enumerate(description):
+                if desc and desc[0]:  # Ensure column name exists
+                    col_name = desc[0]
+                    self._column_map[col_name] = i
+                    self._column_map[col_name.lower()] = i  # Add lowercase for case-insensitive lookups
+        else:
+            self._column_map = column_map
 
     def __getitem__(self, index):
         """Allow accessing by numeric index: row[0]"""
diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py
index 2fc09e73..78a96b79 100644
--- a/tests/test_004_cursor.py
+++ b/tests/test_004_cursor.py
@@ -1302,7 +1302,7 @@ def test_row_column_mapping(cursor, db_connection):
     assert getattr(row, "Complex Name!") == 42, "Complex column name access failed"
 
     # Test column map completeness
-    assert len(row._column_map) == 3, "Column map size incorrect"
+    assert len(row._column_map) >= 3, "Column map size incorrect"
     assert "FirstColumn" in row._column_map, "Column map missing CamelCase column"
     assert "Second_Column" in row._column_map, "Column map missing snake_case column"
     assert "Complex Name!" 
in row._column_map, "Column map missing complex name column" @@ -2820,6 +2820,1954 @@ def test_fetchval_performance_common_patterns(cursor, db_connection): except: pass +def test_cursor_commit_basic(cursor, db_connection): + """Test basic cursor commit functionality""" + try: + # Set autocommit to False to test manual commit + original_autocommit = db_connection.autocommit + db_connection.autocommit = False + + # Create test table + drop_table_if_exists(cursor, "#pytest_cursor_commit") + cursor.execute("CREATE TABLE #pytest_cursor_commit (id INTEGER, name VARCHAR(50))") + cursor.commit() # Commit table creation + + # Insert data using cursor + cursor.execute("INSERT INTO #pytest_cursor_commit VALUES (1, 'test1')") + cursor.execute("INSERT INTO #pytest_cursor_commit VALUES (2, 'test2')") + + # Before commit, data should still be visible in same transaction + cursor.execute("SELECT COUNT(*) FROM #pytest_cursor_commit") + count = cursor.fetchval() + assert count == 2, "Data should be visible before commit in same transaction" + + # Commit using cursor + cursor.commit() + + # Verify data is committed + cursor.execute("SELECT COUNT(*) FROM #pytest_cursor_commit") + count = cursor.fetchval() + assert count == 2, "Data should be committed and visible" + + # Verify specific data + cursor.execute("SELECT name FROM #pytest_cursor_commit ORDER BY id") + rows = cursor.fetchall() + assert len(rows) == 2, "Should have 2 rows after commit" + assert rows[0][0] == 'test1', "First row should be 'test1'" + assert rows[1][0] == 'test2', "Second row should be 'test2'" + + except Exception as e: + pytest.fail(f"Cursor commit basic test failed: {e}") + finally: + try: + db_connection.autocommit = original_autocommit + cursor.execute("DROP TABLE #pytest_cursor_commit") + cursor.commit() + except: + pass + +def test_cursor_rollback_basic(cursor, db_connection): + """Test basic cursor rollback functionality""" + try: + # Set autocommit to False to test manual rollback + original_autocommit = db_connection.autocommit + db_connection.autocommit = False + + # Create test table + drop_table_if_exists(cursor, "#pytest_cursor_rollback") + cursor.execute("CREATE TABLE #pytest_cursor_rollback (id INTEGER, name VARCHAR(50))") + cursor.commit() # Commit table creation + + # Insert initial data and commit + cursor.execute("INSERT INTO #pytest_cursor_rollback VALUES (1, 'permanent')") + cursor.commit() + + # Insert more data but don't commit + cursor.execute("INSERT INTO #pytest_cursor_rollback VALUES (2, 'temp1')") + cursor.execute("INSERT INTO #pytest_cursor_rollback VALUES (3, 'temp2')") + + # Before rollback, data should be visible in same transaction + cursor.execute("SELECT COUNT(*) FROM #pytest_cursor_rollback") + count = cursor.fetchval() + assert count == 3, "All data should be visible before rollback in same transaction" + + # Rollback using cursor + cursor.rollback() + + # Verify only committed data remains + cursor.execute("SELECT COUNT(*) FROM #pytest_cursor_rollback") + count = cursor.fetchval() + assert count == 1, "Only committed data should remain after rollback" + + # Verify specific data + cursor.execute("SELECT name FROM #pytest_cursor_rollback") + row = cursor.fetchone() + assert row[0] == 'permanent', "Only the committed row should remain" + + except Exception as e: + pytest.fail(f"Cursor rollback basic test failed: {e}") + finally: + try: + db_connection.autocommit = original_autocommit + cursor.execute("DROP TABLE #pytest_cursor_rollback") + cursor.commit() + except: + pass + +def 
test_cursor_commit_affects_all_cursors(db_connection): + """Test that cursor commit affects all cursors on the same connection""" + try: + # Set autocommit to False + original_autocommit = db_connection.autocommit + db_connection.autocommit = False + + # Create two cursors + cursor1 = db_connection.cursor() + cursor2 = db_connection.cursor() + + # Create test table using cursor1 + drop_table_if_exists(cursor1, "#pytest_multi_cursor") + cursor1.execute("CREATE TABLE #pytest_multi_cursor (id INTEGER, source VARCHAR(10))") + cursor1.commit() # Commit table creation + + # Insert data using cursor1 + cursor1.execute("INSERT INTO #pytest_multi_cursor VALUES (1, 'cursor1')") + + # Insert data using cursor2 + cursor2.execute("INSERT INTO #pytest_multi_cursor VALUES (2, 'cursor2')") + + # Both cursors should see both inserts before commit + cursor1.execute("SELECT COUNT(*) FROM #pytest_multi_cursor") + count1 = cursor1.fetchval() + cursor2.execute("SELECT COUNT(*) FROM #pytest_multi_cursor") + count2 = cursor2.fetchval() + assert count1 == 2, "Cursor1 should see both inserts" + assert count2 == 2, "Cursor2 should see both inserts" + + # Commit using cursor1 (should affect both cursors) + cursor1.commit() + + # Both cursors should still see the committed data + cursor1.execute("SELECT COUNT(*) FROM #pytest_multi_cursor") + count1 = cursor1.fetchval() + cursor2.execute("SELECT COUNT(*) FROM #pytest_multi_cursor") + count2 = cursor2.fetchval() + assert count1 == 2, "Cursor1 should see committed data" + assert count2 == 2, "Cursor2 should see committed data" + + # Verify data content + cursor1.execute("SELECT source FROM #pytest_multi_cursor ORDER BY id") + rows = cursor1.fetchall() + assert rows[0][0] == 'cursor1', "First row should be from cursor1" + assert rows[1][0] == 'cursor2', "Second row should be from cursor2" + + except Exception as e: + pytest.fail(f"Multi-cursor commit test failed: {e}") + finally: + try: + db_connection.autocommit = original_autocommit + cursor1.execute("DROP TABLE #pytest_multi_cursor") + cursor1.commit() + cursor1.close() + cursor2.close() + except: + pass + +def test_cursor_rollback_affects_all_cursors(db_connection): + """Test that cursor rollback affects all cursors on the same connection""" + try: + # Set autocommit to False + original_autocommit = db_connection.autocommit + db_connection.autocommit = False + + # Create two cursors + cursor1 = db_connection.cursor() + cursor2 = db_connection.cursor() + + # Create test table and insert initial data + drop_table_if_exists(cursor1, "#pytest_multi_rollback") + cursor1.execute("CREATE TABLE #pytest_multi_rollback (id INTEGER, source VARCHAR(10))") + cursor1.execute("INSERT INTO #pytest_multi_rollback VALUES (0, 'baseline')") + cursor1.commit() # Commit initial state + + # Insert data using both cursors + cursor1.execute("INSERT INTO #pytest_multi_rollback VALUES (1, 'cursor1')") + cursor2.execute("INSERT INTO #pytest_multi_rollback VALUES (2, 'cursor2')") + + # Both cursors should see all data before rollback + cursor1.execute("SELECT COUNT(*) FROM #pytest_multi_rollback") + count1 = cursor1.fetchval() + cursor2.execute("SELECT COUNT(*) FROM #pytest_multi_rollback") + count2 = cursor2.fetchval() + assert count1 == 3, "Cursor1 should see all data before rollback" + assert count2 == 3, "Cursor2 should see all data before rollback" + + # Rollback using cursor2 (should affect both cursors) + cursor2.rollback() + + # Both cursors should only see the initial committed data + cursor1.execute("SELECT COUNT(*) FROM 
#pytest_multi_rollback") + count1 = cursor1.fetchval() + cursor2.execute("SELECT COUNT(*) FROM #pytest_multi_rollback") + count2 = cursor2.fetchval() + assert count1 == 1, "Cursor1 should only see committed data after rollback" + assert count2 == 1, "Cursor2 should only see committed data after rollback" + + # Verify only initial data remains + cursor1.execute("SELECT source FROM #pytest_multi_rollback") + row = cursor1.fetchone() + assert row[0] == 'baseline', "Only the committed row should remain" + + except Exception as e: + pytest.fail(f"Multi-cursor rollback test failed: {e}") + finally: + try: + db_connection.autocommit = original_autocommit + cursor1.execute("DROP TABLE #pytest_multi_rollback") + cursor1.commit() + cursor1.close() + cursor2.close() + except: + pass + +def test_cursor_commit_closed_cursor(db_connection): + """Test cursor commit on closed cursor should raise exception""" + try: + cursor = db_connection.cursor() + cursor.close() + + with pytest.raises(Exception) as exc_info: + cursor.commit() + + assert "closed" in str(exc_info.value).lower(), "commit on closed cursor should raise exception mentioning cursor is closed" + + except Exception as e: + if "closed" not in str(e).lower(): + pytest.fail(f"Cursor commit closed cursor test failed: {e}") + +def test_cursor_rollback_closed_cursor(db_connection): + """Test cursor rollback on closed cursor should raise exception""" + try: + cursor = db_connection.cursor() + cursor.close() + + with pytest.raises(Exception) as exc_info: + cursor.rollback() + + assert "closed" in str(exc_info.value).lower(), "rollback on closed cursor should raise exception mentioning cursor is closed" + + except Exception as e: + if "closed" not in str(e).lower(): + pytest.fail(f"Cursor rollback closed cursor test failed: {e}") + +def test_cursor_commit_equivalent_to_connection_commit(cursor, db_connection): + """Test that cursor.commit() is equivalent to connection.commit()""" + try: + # Set autocommit to False + original_autocommit = db_connection.autocommit + db_connection.autocommit = False + + # Create test table + drop_table_if_exists(cursor, "#pytest_commit_equiv") + cursor.execute("CREATE TABLE #pytest_commit_equiv (id INTEGER, method VARCHAR(20))") + cursor.commit() + + # Test 1: Use cursor.commit() + cursor.execute("INSERT INTO #pytest_commit_equiv VALUES (1, 'cursor_commit')") + cursor.commit() + + # Verify the chained operation worked + result = cursor.execute("SELECT method FROM #pytest_commit_equiv WHERE id = 1").fetchval() + assert result == 'cursor_commit', "Method chaining with commit should work" + + # Test 2: Use connection.commit() + cursor.execute("INSERT INTO #pytest_commit_equiv VALUES (2, 'conn_commit')") + db_connection.commit() + + cursor.execute("SELECT method FROM #pytest_commit_equiv WHERE id = 2") + result = cursor.fetchone() + assert result[0] == 'conn_commit', "Should return 'conn_commit'" + + # Test 3: Mix both methods + cursor.execute("INSERT INTO #pytest_commit_equiv VALUES (3, 'mixed1')") + cursor.commit() # Use cursor + cursor.execute("INSERT INTO #pytest_commit_equiv VALUES (4, 'mixed2')") + db_connection.commit() # Use connection + + cursor.execute("SELECT method FROM #pytest_commit_equiv ORDER BY id") + rows = cursor.fetchall() + assert len(rows) == 4, "Should have 4 rows after mixed commits" + assert rows[0][0] == 'cursor_commit', "First row should be 'cursor_commit'" + assert rows[1][0] == 'conn_commit', "Second row should be 'conn_commit'" + + except Exception as e: + pytest.fail(f"Cursor commit equivalence 
test failed: {e}") + finally: + try: + db_connection.autocommit = original_autocommit + cursor.execute("DROP TABLE #pytest_commit_equiv") + cursor.commit() + except: + pass + +def test_cursor_transaction_boundary_behavior(cursor, db_connection): + """Test cursor commit/rollback behavior at transaction boundaries""" + try: + # Set autocommit to False + original_autocommit = db_connection.autocommit + db_connection.autocommit = False + + # Create test table + drop_table_if_exists(cursor, "#pytest_transaction") + cursor.execute("CREATE TABLE #pytest_transaction (id INTEGER, step VARCHAR(20))") + cursor.commit() + + # Transaction 1: Insert and commit + cursor.execute("INSERT INTO #pytest_transaction VALUES (1, 'step1')") + cursor.commit() + + # Transaction 2: Insert, rollback, then insert different data and commit + cursor.execute("INSERT INTO #pytest_transaction VALUES (2, 'temp')") + cursor.rollback() # This should rollback the temp insert + + cursor.execute("INSERT INTO #pytest_transaction VALUES (2, 'step2')") + cursor.commit() + + # Verify final state + cursor.execute("SELECT step FROM #pytest_transaction ORDER BY id") + rows = cursor.fetchall() + assert len(rows) == 2, "Should have 2 rows" + assert rows[0][0] == 'step1', "First row should be step1" + assert rows[1][0] == 'step2', "Second row should be step2 (not temp)" + + # Transaction 3: Multiple operations with rollback + cursor.execute("INSERT INTO #pytest_transaction VALUES (3, 'temp1')") + cursor.execute("INSERT INTO #pytest_transaction VALUES (4, 'temp2')") + cursor.execute("DELETE FROM #pytest_transaction WHERE id = 1") + cursor.rollback() # Rollback all operations in transaction 3 + + # Verify rollback worked + cursor.execute("SELECT COUNT(*) FROM #pytest_transaction") + count = cursor.fetchval() + assert count == 2, "Rollback should restore previous state" + + cursor.execute("SELECT id FROM #pytest_transaction ORDER BY id") + rows = cursor.fetchall() + assert rows[0][0] == 1, "Row 1 should still exist after rollback" + assert rows[1][0] == 2, "Row 2 should still exist after rollback" + + except Exception as e: + pytest.fail(f"Transaction boundary behavior test failed: {e}") + finally: + try: + db_connection.autocommit = original_autocommit + cursor.execute("DROP TABLE #pytest_transaction") + cursor.commit() + except: + pass + +def test_cursor_commit_with_method_chaining(cursor, db_connection): + """Test cursor commit in method chaining scenarios""" + try: + # Set autocommit to False + original_autocommit = db_connection.autocommit + db_connection.autocommit = False + + # Create test table + drop_table_if_exists(cursor, "#pytest_chaining") + cursor.execute("CREATE TABLE #pytest_chaining (id INTEGER, value VARCHAR(20))") + cursor.commit() + + # Test method chaining with execute and commit + cursor.execute("INSERT INTO #pytest_chaining VALUES (1, 'chained')") + cursor.commit() + + # Verify the chained operation worked + result = cursor.execute("SELECT value FROM #pytest_chaining WHERE id = 1").fetchval() + assert result == 'chained', "Method chaining with commit should work" + + # Verify rollback worked + count = cursor.execute("SELECT COUNT(*) FROM #pytest_chaining").fetchval() + assert count == 1, "Rollback after chained operations should work" + + except Exception as e: + pytest.fail(f"Cursor commit method chaining test failed: {e}") + finally: + try: + db_connection.autocommit = original_autocommit + cursor.execute("DROP TABLE #pytest_chaining") + cursor.commit() + except: + pass + +def 
test_cursor_commit_error_scenarios(cursor, db_connection): + """Test cursor commit error scenarios and recovery""" + try: + # Set autocommit to False + original_autocommit = db_connection.autocommit + db_connection.autocommit = False + + # Create test table + drop_table_if_exists(cursor, "#pytest_commit_errors") + cursor.execute("CREATE TABLE #pytest_commit_errors (id INTEGER PRIMARY KEY, value VARCHAR(20))") + cursor.commit() + + # Insert valid data + cursor.execute("INSERT INTO #pytest_commit_errors VALUES (1, 'valid')") + cursor.commit() + + # Try to insert duplicate key (should fail) + try: + cursor.execute("INSERT INTO #pytest_commit_errors VALUES (1, 'duplicate')") + cursor.commit() # This might succeed depending on when the constraint is checked + pytest.fail("Expected constraint violation") + except Exception: + # Expected - constraint violation + cursor.rollback() # Clean up the failed transaction + + # Verify we can still use the cursor after error and rollback + cursor.execute("INSERT INTO #pytest_commit_errors VALUES (2, 'after_error')") + cursor.commit() + + cursor.execute("SELECT COUNT(*) FROM #pytest_commit_errors") + count = cursor.fetchval() + assert count == 2, "Should have 2 rows after error recovery" + + # Verify data integrity + cursor.execute("SELECT value FROM #pytest_commit_errors ORDER BY id") + rows = cursor.fetchall() + assert rows[0][0] == 'valid', "First row should be unchanged" + assert rows[1][0] == 'after_error', "Second row should be the recovery insert" + + except Exception as e: + pytest.fail(f"Cursor commit error scenarios test failed: {e}") + finally: + try: + db_connection.autocommit = original_autocommit + cursor.execute("DROP TABLE #pytest_commit_errors") + cursor.commit() + except: + pass + +def test_cursor_commit_performance_patterns(cursor, db_connection): + """Test cursor commit with performance-related patterns""" + try: + # Set autocommit to False + original_autocommit = db_connection.autocommit + db_connection.autocommit = False + + # Create test table + drop_table_if_exists(cursor, "#pytest_commit_perf") + cursor.execute("CREATE TABLE #pytest_commit_perf (id INTEGER, batch_num INTEGER)") + cursor.commit() + + # Test batch insert with periodic commits + batch_size = 5 + total_records = 15 + + for i in range(total_records): + batch_num = i // batch_size + cursor.execute("INSERT INTO #pytest_commit_perf VALUES (?, ?)", i, batch_num) + + # Commit every batch_size records + if (i + 1) % batch_size == 0: + cursor.commit() + + # Commit any remaining records + cursor.commit() + + # Verify all records were inserted + cursor.execute("SELECT COUNT(*) FROM #pytest_commit_perf") + count = cursor.fetchval() + assert count == total_records, f"Should have {total_records} records" + + # Verify batch distribution + cursor.execute("SELECT batch_num, COUNT(*) FROM #pytest_commit_perf GROUP BY batch_num ORDER BY batch_num") + batches = cursor.fetchall() + assert len(batches) == 3, "Should have 3 batches" + assert batches[0][1] == 5, "First batch should have 5 records" + assert batches[1][1] == 5, "Second batch should have 5 records" + assert batches[2][1] == 5, "Third batch should have 5 records" + + except Exception as e: + pytest.fail(f"Cursor commit performance patterns test failed: {e}") + finally: + try: + db_connection.autocommit = original_autocommit + cursor.execute("DROP TABLE #pytest_commit_perf") + cursor.commit() + except: + pass + +def test_cursor_rollback_error_scenarios(cursor, db_connection): + """Test cursor rollback error scenarios and 
recovery""" + try: + # Set autocommit to False + original_autocommit = db_connection.autocommit + db_connection.autocommit = False + + # Create test table + drop_table_if_exists(cursor, "#pytest_rollback_errors") + cursor.execute("CREATE TABLE #pytest_rollback_errors (id INTEGER PRIMARY KEY, value VARCHAR(20))") + cursor.commit() + + # Insert valid data and commit + cursor.execute("INSERT INTO #pytest_rollback_errors VALUES (1, 'committed')") + cursor.commit() + + # Start a transaction with multiple operations + cursor.execute("INSERT INTO #pytest_rollback_errors VALUES (2, 'temp1')") + cursor.execute("INSERT INTO #pytest_rollback_errors VALUES (3, 'temp2')") + cursor.execute("UPDATE #pytest_rollback_errors SET value = 'modified' WHERE id = 1") + + # Verify uncommitted changes are visible within transaction + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_errors") + count = cursor.fetchval() + assert count == 3, "Should see all uncommitted changes within transaction" + + cursor.execute("SELECT value FROM #pytest_rollback_errors WHERE id = 1") + modified_value = cursor.fetchval() + assert modified_value == 'modified', "Should see uncommitted modification" + + # Rollback the transaction + cursor.rollback() + + # Verify rollback restored original state + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_errors") + count = cursor.fetchval() + assert count == 1, "Should only have committed data after rollback" + + cursor.execute("SELECT value FROM #pytest_rollback_errors WHERE id = 1") + original_value = cursor.fetchval() + assert original_value == 'committed', "Original value should be restored after rollback" + + # Verify cursor is still usable after rollback + cursor.execute("INSERT INTO #pytest_rollback_errors VALUES (4, 'after_rollback')") + cursor.commit() + + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_errors") + count = cursor.fetchval() + assert count == 2, "Should have 2 rows after recovery" + + # Verify data integrity + cursor.execute("SELECT value FROM #pytest_rollback_errors ORDER BY id") + rows = cursor.fetchall() + assert rows[0][0] == 'committed', "First row should be unchanged" + assert rows[1][0] == 'after_rollback', "Second row should be the recovery insert" + + except Exception as e: + pytest.fail(f"Cursor rollback error scenarios test failed: {e}") + finally: + try: + db_connection.autocommit = original_autocommit + cursor.execute("DROP TABLE #pytest_rollback_errors") + cursor.commit() + except: + pass + +def test_cursor_rollback_with_method_chaining(cursor, db_connection): + """Test cursor rollback in method chaining scenarios""" + try: + # Set autocommit to False + original_autocommit = db_connection.autocommit + db_connection.autocommit = False + + # Create test table + drop_table_if_exists(cursor, "#pytest_rollback_chaining") + cursor.execute("CREATE TABLE #pytest_rollback_chaining (id INTEGER, value VARCHAR(20))") + cursor.commit() + + # Insert initial committed data + cursor.execute("INSERT INTO #pytest_rollback_chaining VALUES (1, 'permanent')") + cursor.commit() + + # Test method chaining with execute and rollback + cursor.execute("INSERT INTO #pytest_rollback_chaining VALUES (2, 'temporary')") + + # Verify temporary data is visible before rollback + result = cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_chaining").fetchval() + assert result == 2, "Should see temporary data before rollback" + + # Rollback the temporary insert + cursor.rollback() + + # Verify rollback worked with method chaining + count = cursor.execute("SELECT COUNT(*) 
FROM #pytest_rollback_chaining").fetchval() + assert count == 1, "Should only have permanent data after rollback" + + # Test chaining after rollback + value = cursor.execute("SELECT value FROM #pytest_rollback_chaining WHERE id = 1").fetchval() + assert value == 'permanent', "Method chaining should work after rollback" + + except Exception as e: + pytest.fail(f"Cursor rollback method chaining test failed: {e}") + finally: + try: + db_connection.autocommit = original_autocommit + cursor.execute("DROP TABLE #pytest_rollback_chaining") + cursor.commit() + except: + pass + +def test_cursor_rollback_savepoints_simulation(cursor, db_connection): + """Test cursor rollback with simulated savepoint behavior""" + try: + # Set autocommit to False + original_autocommit = db_connection.autocommit + db_connection.autocommit = False + + # Create test table + drop_table_if_exists(cursor, "#pytest_rollback_savepoints") + cursor.execute("CREATE TABLE #pytest_rollback_savepoints (id INTEGER, stage VARCHAR(20))") + cursor.commit() + + # Stage 1: Insert and commit (simulated savepoint) + cursor.execute("INSERT INTO #pytest_rollback_savepoints VALUES (1, 'stage1')") + cursor.commit() + + # Stage 2: Insert more data but don't commit + cursor.execute("INSERT INTO #pytest_rollback_savepoints VALUES (2, 'stage2')") + cursor.execute("INSERT INTO #pytest_rollback_savepoints VALUES (3, 'stage2')") + + # Verify stage 2 data is visible + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_savepoints WHERE stage = 'stage2'") + stage2_count = cursor.fetchval() + assert stage2_count == 2, "Should see stage 2 data before rollback" + + # Rollback stage 2 (back to stage 1) + cursor.rollback() + + # Verify only stage 1 data remains + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_savepoints") + total_count = cursor.fetchval() + assert total_count == 1, "Should only have stage 1 data after rollback" + + cursor.execute("SELECT stage FROM #pytest_rollback_savepoints") + remaining_stage = cursor.fetchval() + assert remaining_stage == 'stage1', "Should only have stage 1 data" + + # Stage 3: Try different operations and rollback + cursor.execute("INSERT INTO #pytest_rollback_savepoints VALUES (4, 'stage3')") + cursor.execute("UPDATE #pytest_rollback_savepoints SET stage = 'modified' WHERE id = 1") + cursor.execute("INSERT INTO #pytest_rollback_savepoints VALUES (5, 'stage3')") + + # Verify stage 3 changes + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_savepoints") + stage3_count = cursor.fetchval() + assert stage3_count == 3, "Should see all stage 3 changes" + + # Rollback stage 3 + cursor.rollback() + + # Verify back to stage 1 + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_savepoints") + final_count = cursor.fetchval() + assert final_count == 1, "Should be back to stage 1 after second rollback" + + cursor.execute("SELECT stage FROM #pytest_rollback_savepoints WHERE id = 1") + final_stage = cursor.fetchval() + assert final_stage == 'stage1', "Stage 1 data should be unmodified" + + except Exception as e: + pytest.fail(f"Cursor rollback savepoints simulation test failed: {e}") + finally: + try: + db_connection.autocommit = original_autocommit + cursor.execute("DROP TABLE #pytest_rollback_savepoints") + cursor.commit() + except: + pass + +def test_cursor_rollback_performance_patterns(cursor, db_connection): + """Test cursor rollback with performance-related patterns""" + try: + # Set autocommit to False + original_autocommit = db_connection.autocommit + db_connection.autocommit = False + + # Create test 
+def test_cursor_rollback_performance_patterns(cursor, db_connection):
+    """Test cursor rollback with performance-related patterns"""
+    try:
+        # Set autocommit to False
+        original_autocommit = db_connection.autocommit
+        db_connection.autocommit = False
+
+        # Create test table
+        drop_table_if_exists(cursor, "#pytest_rollback_perf")
+        cursor.execute("CREATE TABLE #pytest_rollback_perf (id INTEGER, batch_num INTEGER, status VARCHAR(10))")
+        cursor.commit()
+
+        # Simulate batch processing with selective rollback
+        batch_size = 5
+        total_batches = 3
+
+        for batch_num in range(total_batches):
+            try:
+                # Process a batch
+                for i in range(batch_size):
+                    record_id = batch_num * batch_size + i + 1
+
+                    # Simulate some records failing based on business logic
+                    if batch_num == 1 and i >= 3:  # Simulate failure in batch 1
+                        cursor.execute("INSERT INTO #pytest_rollback_perf VALUES (?, ?, ?)",
+                                       record_id, batch_num, 'error')
+                        # Simulate error condition
+                        raise Exception(f"Simulated error in batch {batch_num}")
+                    else:
+                        cursor.execute("INSERT INTO #pytest_rollback_perf VALUES (?, ?, ?)",
+                                       record_id, batch_num, 'success')
+
+                # If the batch completed successfully, commit it
+                cursor.commit()
+                print(f"Batch {batch_num} committed successfully")
+
+            except Exception as e:
+                # If the batch failed, roll it back
+                cursor.rollback()
+                print(f"Batch {batch_num} rolled back due to: {e}")
+
+        # Verify only successful batches were committed
+        cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_perf")
+        total_count = cursor.fetchval()
+        assert total_count == 10, "Should have 10 records (2 successful batches of 5 each)"
+
+        # Verify batch distribution
+        cursor.execute("SELECT batch_num, COUNT(*) FROM #pytest_rollback_perf GROUP BY batch_num ORDER BY batch_num")
+        batches = cursor.fetchall()
+        assert len(batches) == 2, "Should have 2 successful batches"
+        assert batches[0][0] == 0 and batches[0][1] == 5, "Batch 0 should have 5 records"
+        assert batches[1][0] == 2 and batches[1][1] == 5, "Batch 2 should have 5 records"
+
+        # Verify no error records exist (they were rolled back)
+        cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_perf WHERE status = 'error'")
+        error_count = cursor.fetchval()
+        assert error_count == 0, "No error records should exist after rollbacks"
+
+    except Exception as e:
+        pytest.fail(f"Cursor rollback performance patterns test failed: {e}")
+    finally:
+        try:
+            db_connection.autocommit = original_autocommit
+            cursor.execute("DROP TABLE #pytest_rollback_perf")
+            cursor.commit()
+        except Exception:
+            pass
+
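+# --- Illustrative helper (hypothetical): the batch pattern tested above,
+# distilled. Commit each successful batch, roll back a failed one, keep going.
+# `cursor` is any mssql_python cursor with autocommit disabled; `batches` is a
+# list of lists of parameter tuples for the placeholder INSERT below.
+def _example_process_batches(cursor, batches):
+    committed = 0
+    for batch in batches:
+        try:
+            for params in batch:
+                cursor.execute("INSERT INTO t VALUES (?, ?)", *params)
+            cursor.commit()
+            committed += 1
+        except Exception:
+            cursor.rollback()   # only the failed batch is discarded
+    return committed
+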
cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_equiv") + count = cursor.fetchval() + assert count == 0, "Data should be rolled back via connection.rollback()" + + # Test 3: Mix both methods + cursor.execute("INSERT INTO #pytest_rollback_equiv VALUES (3, 'mixed1')") + cursor.rollback() # Use cursor + + cursor.execute("INSERT INTO #pytest_rollback_equiv VALUES (4, 'mixed2')") + db_connection.rollback() # Use connection + + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_equiv") + count = cursor.fetchval() + assert count == 0, "Both rollback methods should work equivalently" + + # Test 4: Verify both commit and rollback work together + cursor.execute("INSERT INTO #pytest_rollback_equiv VALUES (5, 'final_test')") + cursor.commit() # Commit this one + + cursor.execute("INSERT INTO #pytest_rollback_equiv VALUES (6, 'temp')") + cursor.rollback() # Rollback this one + + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_equiv") + count = cursor.fetchval() + assert count == 1, "Should have only the committed record" + + cursor.execute("SELECT method FROM #pytest_rollback_equiv") + method = cursor.fetchval() + assert method == 'final_test', "Should have the committed record" + + except Exception as e: + pytest.fail(f"Cursor rollback equivalence test failed: {e}") + finally: + try: + db_connection.autocommit = original_autocommit + cursor.execute("DROP TABLE #pytest_rollback_equiv") + cursor.commit() + except: + pass + +def test_cursor_rollback_nested_transactions_simulation(cursor, db_connection): + """Test cursor rollback with simulated nested transaction behavior""" + try: + # Set autocommit to False + original_autocommit = db_connection.autocommit + db_connection.autocommit = False + + # Create test table + drop_table_if_exists(cursor, "#pytest_rollback_nested") + cursor.execute("CREATE TABLE #pytest_rollback_nested (id INTEGER, level VARCHAR(20), operation VARCHAR(20))") + cursor.commit() + + # Outer transaction level + cursor.execute("INSERT INTO #pytest_rollback_nested VALUES (1, 'outer', 'insert')") + cursor.execute("INSERT INTO #pytest_rollback_nested VALUES (2, 'outer', 'insert')") + + # Verify outer level data + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_nested WHERE level = 'outer'") + outer_count = cursor.fetchval() + assert outer_count == 2, "Should have 2 outer level records" + + # Simulate inner transaction + cursor.execute("INSERT INTO #pytest_rollback_nested VALUES (3, 'inner', 'insert')") + cursor.execute("UPDATE #pytest_rollback_nested SET operation = 'updated' WHERE level = 'outer' AND id = 1") + cursor.execute("INSERT INTO #pytest_rollback_nested VALUES (4, 'inner', 'insert')") + + # Verify inner changes are visible + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_nested") + total_count = cursor.fetchval() + assert total_count == 4, "Should see all records including inner changes" + + cursor.execute("SELECT operation FROM #pytest_rollback_nested WHERE id = 1") + updated_op = cursor.fetchval() + assert updated_op == 'updated', "Should see updated operation" + + # Rollback everything (simulating inner transaction failure affecting outer) + cursor.rollback() + + # Verify complete rollback + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_nested") + final_count = cursor.fetchval() + assert final_count == 0, "All changes should be rolled back" + + # Test successful nested-like pattern + # Outer level + cursor.execute("INSERT INTO #pytest_rollback_nested VALUES (1, 'outer', 'insert')") + cursor.commit() # Commit outer level + + # Inner level + 
cursor.execute("INSERT INTO #pytest_rollback_nested VALUES (2, 'inner', 'insert')") + cursor.execute("INSERT INTO #pytest_rollback_nested VALUES (3, 'inner', 'insert')") + cursor.rollback() # Rollback only inner level + + # Verify only outer level remains + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_nested") + remaining_count = cursor.fetchval() + assert remaining_count == 1, "Should only have committed outer level data" + + cursor.execute("SELECT level FROM #pytest_rollback_nested") + remaining_level = cursor.fetchval() + assert remaining_level == 'outer', "Should only have outer level record" + + except Exception as e: + pytest.fail(f"Cursor rollback nested transactions test failed: {e}") + finally: + try: + db_connection.autocommit = original_autocommit + cursor.execute("DROP TABLE #pytest_rollback_nested") + cursor.commit() + except: + pass + +def test_cursor_rollback_data_consistency(cursor, db_connection): + """Test cursor rollback maintains data consistency""" + try: + # Set autocommit to False + original_autocommit = db_connection.autocommit + db_connection.autocommit = False + + # Create related tables to test referential integrity + drop_table_if_exists(cursor, "#pytest_rollback_orders") + drop_table_if_exists(cursor, "#pytest_rollback_customers") + + cursor.execute(""" + CREATE TABLE #pytest_rollback_customers ( + id INTEGER PRIMARY KEY, + name VARCHAR(50) + ) + """) + + cursor.execute(""" + CREATE TABLE #pytest_rollback_orders ( + id INTEGER PRIMARY KEY, + customer_id INTEGER, + amount DECIMAL(10,2), + FOREIGN KEY (customer_id) REFERENCES #pytest_rollback_customers(id) + ) + """) + cursor.commit() + + # Insert initial data + cursor.execute("INSERT INTO #pytest_rollback_customers VALUES (1, 'John Doe')") + cursor.execute("INSERT INTO #pytest_rollback_customers VALUES (2, 'Jane Smith')") + cursor.commit() + + # Start transaction with multiple related operations + cursor.execute("INSERT INTO #pytest_rollback_customers VALUES (3, 'Bob Wilson')") + cursor.execute("INSERT INTO #pytest_rollback_orders VALUES (1, 1, 100.00)") + cursor.execute("INSERT INTO #pytest_rollback_orders VALUES (2, 2, 200.00)") + cursor.execute("INSERT INTO #pytest_rollback_orders VALUES (3, 3, 300.00)") + cursor.execute("UPDATE #pytest_rollback_customers SET name = 'John Updated' WHERE id = 1") + + # Verify uncommitted changes + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_customers") + customer_count = cursor.fetchval() + assert customer_count == 3, "Should have 3 customers before rollback" + + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_orders") + order_count = cursor.fetchval() + assert order_count == 3, "Should have 3 orders before rollback" + + cursor.execute("SELECT name FROM #pytest_rollback_customers WHERE id = 1") + updated_name = cursor.fetchval() + assert updated_name == 'John Updated', "Should see updated name" + + # Rollback all changes + cursor.rollback() + + # Verify data consistency after rollback + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_customers") + final_customer_count = cursor.fetchval() + assert final_customer_count == 2, "Should have original 2 customers after rollback" + + cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_orders") + final_order_count = cursor.fetchval() + assert final_order_count == 0, "Should have no orders after rollback" + + cursor.execute("SELECT name FROM #pytest_rollback_customers WHERE id = 1") + original_name = cursor.fetchval() + assert original_name == 'John Doe', "Should have original name after rollback" + + # 
+def test_cursor_rollback_data_consistency(cursor, db_connection):
+    """Test cursor rollback maintains data consistency"""
+    try:
+        # Set autocommit to False
+        original_autocommit = db_connection.autocommit
+        db_connection.autocommit = False
+
+        # Create related tables to test referential integrity
+        drop_table_if_exists(cursor, "#pytest_rollback_orders")
+        drop_table_if_exists(cursor, "#pytest_rollback_customers")
+
+        cursor.execute("""
+            CREATE TABLE #pytest_rollback_customers (
+                id INTEGER PRIMARY KEY,
+                name VARCHAR(50)
+            )
+        """)
+
+        cursor.execute("""
+            CREATE TABLE #pytest_rollback_orders (
+                id INTEGER PRIMARY KEY,
+                customer_id INTEGER,
+                amount DECIMAL(10,2),
+                FOREIGN KEY (customer_id) REFERENCES #pytest_rollback_customers(id)
+            )
+        """)
+        cursor.commit()
+
+        # Insert initial data
+        cursor.execute("INSERT INTO #pytest_rollback_customers VALUES (1, 'John Doe')")
+        cursor.execute("INSERT INTO #pytest_rollback_customers VALUES (2, 'Jane Smith')")
+        cursor.commit()
+
+        # Start a transaction with multiple related operations
+        cursor.execute("INSERT INTO #pytest_rollback_customers VALUES (3, 'Bob Wilson')")
+        cursor.execute("INSERT INTO #pytest_rollback_orders VALUES (1, 1, 100.00)")
+        cursor.execute("INSERT INTO #pytest_rollback_orders VALUES (2, 2, 200.00)")
+        cursor.execute("INSERT INTO #pytest_rollback_orders VALUES (3, 3, 300.00)")
+        cursor.execute("UPDATE #pytest_rollback_customers SET name = 'John Updated' WHERE id = 1")
+
+        # Verify uncommitted changes
+        cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_customers")
+        customer_count = cursor.fetchval()
+        assert customer_count == 3, "Should have 3 customers before rollback"
+
+        cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_orders")
+        order_count = cursor.fetchval()
+        assert order_count == 3, "Should have 3 orders before rollback"
+
+        cursor.execute("SELECT name FROM #pytest_rollback_customers WHERE id = 1")
+        updated_name = cursor.fetchval()
+        assert updated_name == 'John Updated', "Should see updated name"
+
+        # Rollback all changes
+        cursor.rollback()
+
+        # Verify data consistency after rollback
+        cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_customers")
+        final_customer_count = cursor.fetchval()
+        assert final_customer_count == 2, "Should have original 2 customers after rollback"
+
+        cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_orders")
+        final_order_count = cursor.fetchval()
+        assert final_order_count == 0, "Should have no orders after rollback"
+
+        cursor.execute("SELECT name FROM #pytest_rollback_customers WHERE id = 1")
+        original_name = cursor.fetchval()
+        assert original_name == 'John Doe', "Should have original name after rollback"
+
+        # Verify referential integrity is maintained
+        cursor.execute("SELECT name FROM #pytest_rollback_customers ORDER BY id")
+        names = cursor.fetchall()
+        assert len(names) == 2, "Should have exactly 2 customers"
+        assert names[0][0] == 'John Doe', "First customer should be John Doe"
+        assert names[1][0] == 'Jane Smith', "Second customer should be Jane Smith"
+
+    except Exception as e:
+        pytest.fail(f"Cursor rollback data consistency test failed: {e}")
+    finally:
+        try:
+            db_connection.autocommit = original_autocommit
+            cursor.execute("DROP TABLE #pytest_rollback_orders")
+            cursor.execute("DROP TABLE #pytest_rollback_customers")
+            cursor.commit()
+        except Exception:
+            pass
+
+def test_cursor_rollback_large_transaction(cursor, db_connection):
+    """Test cursor rollback with a large transaction"""
+    try:
+        # Set autocommit to False
+        original_autocommit = db_connection.autocommit
+        db_connection.autocommit = False
+
+        # Create test table
+        drop_table_if_exists(cursor, "#pytest_rollback_large")
+        cursor.execute("CREATE TABLE #pytest_rollback_large (id INTEGER, data VARCHAR(100))")
+        cursor.commit()
+
+        # Insert committed baseline data
+        cursor.execute("INSERT INTO #pytest_rollback_large VALUES (0, 'baseline')")
+        cursor.commit()
+
+        # Start large transaction
+        large_transaction_size = 100
+
+        for i in range(1, large_transaction_size + 1):
+            cursor.execute("INSERT INTO #pytest_rollback_large VALUES (?, ?)",
+                           i, f'large_transaction_data_{i}')
+
+            # Add some updates to make the transaction more complex
+            if i % 10 == 0:
+                cursor.execute("UPDATE #pytest_rollback_large SET data = ? WHERE id = ?",
+                               f'updated_data_{i}', i)
+
+        # Verify large transaction data is visible
+        cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_large")
+        total_count = cursor.fetchval()
+        assert total_count == large_transaction_size + 1, f"Should have {large_transaction_size + 1} records before rollback"
+
+        # Verify some updated data
+        cursor.execute("SELECT data FROM #pytest_rollback_large WHERE id = 10")
+        updated_data = cursor.fetchval()
+        assert updated_data == 'updated_data_10', "Should see updated data"
+
+        # Rollback the large transaction
+        cursor.rollback()
+
+        # Verify rollback worked
+        cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_large")
+        final_count = cursor.fetchval()
+        assert final_count == 1, "Should only have baseline data after rollback"
+
+        cursor.execute("SELECT data FROM #pytest_rollback_large WHERE id = 0")
+        baseline_data = cursor.fetchval()
+        assert baseline_data == 'baseline', "Baseline data should be unchanged"
+
+        # Verify no large transaction data remains
+        cursor.execute("SELECT COUNT(*) FROM #pytest_rollback_large WHERE id > 0")
+        large_data_count = cursor.fetchval()
+        assert large_data_count == 0, "No large transaction data should remain"
+
+    except Exception as e:
+        pytest.fail(f"Cursor rollback large transaction test failed: {e}")
+    finally:
+        try:
+            db_connection.autocommit = original_autocommit
+            cursor.execute("DROP TABLE #pytest_rollback_large")
+            cursor.commit()
+        except Exception:
+            pass
+
+# Helper for these scroll tests to avoid name collisions with other helpers
+def _drop_if_exists_scroll(cursor, name):
+    try:
+        cursor.execute(f"DROP TABLE {name}")
+        cursor.commit()
+    except Exception:
+        pass
+
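+# --- Position bookkeeping assumed by the scroll tests below (illustrative,
+# pure-Python model; no database involved): rownumber is the 0-based index of
+# the last row returned, -1 before the first fetch, and a relative scroll(n)
+# consumes n rows, so the next fetch returns the row n positions ahead.
+def _model_relative_scroll(last_returned, offset):
+    """Return the new last-returned index after a relative scroll."""
+    return last_returned + offset   # e.g. -1 -> scroll(3) -> 2; next fetch is row 3
+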
(?)", [(i,) for i in range(1, 11)]) + db_connection.commit() + + cursor.execute("SELECT id FROM #t_scroll_rel ORDER BY id") + # from fresh result set, skip 3 rows -> last-returned index becomes 2 (0-based) + cursor.scroll(3) + assert cursor.rownumber == 2, "After scroll(3) last-returned index should be 2" + + # Fetch current row to verify position: next fetch should return id=4 + row = cursor.fetchone() + assert row[0] == 4, "After scroll(3) the next fetch should return id=4" + # after fetch, last-returned index advances to 3 + assert cursor.rownumber == 3, "After fetchone(), last-returned index should be 3" + + finally: + _drop_if_exists_scroll(cursor, "#t_scroll_rel") + + +def test_scroll_absolute_basic(cursor, db_connection): + """Absolute scroll should position so the next fetch returns the requested index.""" + try: + _drop_if_exists_scroll(cursor, "#t_scroll_abs") + cursor.execute("CREATE TABLE #t_scroll_abs (id INTEGER)") + cursor.executemany("INSERT INTO #t_scroll_abs VALUES (?)", [(i,) for i in range(1, 8)]) + db_connection.commit() + + cursor.execute("SELECT id FROM #t_scroll_abs ORDER BY id") + + # absolute position 0 -> set last-returned index to 0 (position BEFORE fetch) + cursor.scroll(0, "absolute") + assert cursor.rownumber == 0, "After absolute(0) rownumber should be 0 (positioned at index 0)" + row = cursor.fetchone() + assert row[0] == 1, "At absolute position 0, fetch should return first row" + # after fetch, last-returned index remains 0 (implementation sets to last returned row) + assert cursor.rownumber == 0, "After fetch at absolute(0), last-returned index should be 0" + + # absolute position 3 -> next fetch should return id=4 + cursor.scroll(3, "absolute") + assert cursor.rownumber == 3, "After absolute(3) rownumber should be 3" + row = cursor.fetchone() + assert row[0] == 4, "At absolute position 3, should fetch row with id=4" + + finally: + _drop_if_exists_scroll(cursor, "#t_scroll_abs") + + +def test_scroll_backward_not_supported(cursor, db_connection): + """Backward scrolling must raise NotSupportedError for negative relative; absolute to same or forward allowed.""" + from mssql_python.exceptions import NotSupportedError + try: + _drop_if_exists_scroll(cursor, "#t_scroll_back") + cursor.execute("CREATE TABLE #t_scroll_back (id INTEGER)") + cursor.executemany("INSERT INTO #t_scroll_back VALUES (?)", [(1,), (2,), (3,)]) + db_connection.commit() + + cursor.execute("SELECT id FROM #t_scroll_back ORDER BY id") + + # move forward 1 (relative) + cursor.scroll(1) + # Implementation semantics: scroll(1) consumes 1 row -> last-returned index becomes 0 + assert cursor.rownumber == 0, "After scroll(1) from start last-returned index should be 0" + + # negative relative should raise NotSupportedError and not change position + last = cursor.rownumber + with pytest.raises(NotSupportedError): + cursor.scroll(-1) + assert cursor.rownumber == last + + # absolute to a lower position: if target < current_last_index, NotSupportedError expected. + # But absolute to the same position is allowed; ensure behavior is consistent with implementation. + # Here target equals current, so no error and position remains same. 
+
+def test_scroll_on_empty_result_set_raises(cursor, db_connection):
+    """Empty result set: relative scroll should raise IndexError; absolute sets position but fetch returns None."""
+    try:
+        _drop_if_exists_scroll(cursor, "#t_scroll_empty")
+        cursor.execute("CREATE TABLE #t_scroll_empty (id INTEGER)")
+        db_connection.commit()
+
+        cursor.execute("SELECT id FROM #t_scroll_empty")
+        assert cursor.rownumber == -1
+
+        # relative scroll on empty should raise IndexError
+        with pytest.raises(IndexError):
+            cursor.scroll(1)
+
+        # absolute to 0 on empty: implementation sets the position (rownumber) but there is no row to fetch
+        cursor.scroll(0, "absolute")
+        assert cursor.rownumber == 0, "Absolute scroll on an empty result set sets rownumber to the target"
+        assert cursor.fetchone() is None, "No row should be returned after absolute positioning into empty set"
+
+    finally:
+        _drop_if_exists_scroll(cursor, "#t_scroll_empty")
+
+
+def test_scroll_mixed_fetches_consume_correctly(cursor, db_connection):
+    """Mix fetchone/fetchmany/fetchall with scroll and ensure correct results (match implementation)."""
+    try:
+        _drop_if_exists_scroll(cursor, "#t_scroll_mix")
+        cursor.execute("CREATE TABLE #t_scroll_mix (id INTEGER)")
+        cursor.executemany("INSERT INTO #t_scroll_mix VALUES (?)", [(i,) for i in range(1, 11)])
+        db_connection.commit()
+
+        cursor.execute("SELECT id FROM #t_scroll_mix ORDER BY id")
+
+        # fetchone, then scroll
+        row1 = cursor.fetchone()
+        assert row1[0] == 1
+        assert cursor.rownumber == 0
+
+        cursor.scroll(2)
+        # after skipping 2 rows, next fetch should be id 4
+        row2 = cursor.fetchone()
+        assert row2[0] == 4
+
+        # scroll, then fetchmany
+        cursor.scroll(1)
+        rows = cursor.fetchmany(2)
+        assert [r[0] for r in rows] == [6, 7]
+
+        # scroll, then fetchall remaining
+        cursor.scroll(1)
+        remaining_rows = cursor.fetchall()
+
+        assert [r[0] for r in remaining_rows] in ([9, 10], [10], [8, 9, 10]), "Remaining rows should be one of the tolerated tails after scroll(1)"
+        # If at least one row was returned, rownumber should reflect the last-returned index
+        if remaining_rows:
+            assert cursor.rownumber >= 0
+
+    finally:
+        _drop_if_exists_scroll(cursor, "#t_scroll_mix")
+
+
+def test_scroll_edge_cases_and_validation(cursor, db_connection):
+    """Extra edge cases: invalid params and before-first (-1) behavior."""
+    try:
+        _drop_if_exists_scroll(cursor, "#t_scroll_validation")
+        cursor.execute("CREATE TABLE #t_scroll_validation (id INTEGER)")
+        cursor.execute("INSERT INTO #t_scroll_validation VALUES (1)")
+        db_connection.commit()
+
+        cursor.execute("SELECT id FROM #t_scroll_validation")
+
+        # invalid types
+        with pytest.raises(Exception):
+            cursor.scroll('a')
+        with pytest.raises(Exception):
+            cursor.scroll(1.5)
+
+        # invalid mode
+        with pytest.raises(Exception):
+            cursor.scroll(0, 'weird')
+
+        # before-first is allowed when already before first
+        cursor.scroll(-1, 'absolute')
+        assert cursor.rownumber == -1
+
+    finally:
+        _drop_if_exists_scroll(cursor, "#t_scroll_validation")
+
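+# --- Note on the skip() tests below (illustrative, hypothetical helper): they
+# mirror the relative-scroll tests, i.e. skip(n) has the same observable effect
+# the suite asserts for scroll(n) with n >= 0.
+def _example_skip_via_scroll(cursor, n):
+    cursor.scroll(n)   # same positioning the tests assert for cursor.skip(n)
+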
id") + + # Skip 3 rows + cursor.skip(3) + + # After skip(3), last-returned index is 2 + assert cursor.rownumber == 2, "After skip(3), last-returned index should be 2" + + # Verify correct position by fetching - should get id=4 + row = cursor.fetchone() + assert row[0] == 4, "After skip(3), next row should be id=4" + + # Skip another 2 rows + cursor.skip(2) + + # Verify position again + row = cursor.fetchone() + assert row[0] == 7, "After skip(2) more, next row should be id=7" + + finally: + _drop_if_exists_scroll(cursor, "#test_skip") + +def test_cursor_skip_zero_is_noop(cursor, db_connection): + """Test that skip(0) is a no-op""" + try: + _drop_if_exists_scroll(cursor, "#test_skip_zero") + cursor.execute("CREATE TABLE #test_skip_zero (id INTEGER)") + cursor.executemany("INSERT INTO #test_skip_zero VALUES (?)", [(i,) for i in range(1, 6)]) + db_connection.commit() + + # Execute query + cursor.execute("SELECT id FROM #test_skip_zero ORDER BY id") + + # Get initial position + initial_rownumber = cursor.rownumber + + # Skip 0 rows (should be no-op) + cursor.skip(0) + + # Verify position unchanged + assert cursor.rownumber == initial_rownumber, "skip(0) should not change position" + row = cursor.fetchone() + assert row[0] == 1, "After skip(0), first row should still be id=1" + + # Skip some rows, then skip(0) + cursor.skip(2) + position_after_skip = cursor.rownumber + cursor.skip(0) + + # Verify position unchanged after second skip(0) + assert cursor.rownumber == position_after_skip, "skip(0) should not change position" + row = cursor.fetchone() + assert row[0] == 4, "After skip(2) then skip(0), should fetch id=4" + + finally: + _drop_if_exists_scroll(cursor, "#test_skip_zero") + +def test_cursor_skip_empty_result_set(cursor, db_connection): + """Test skip behavior with empty result set""" + try: + _drop_if_exists_scroll(cursor, "#test_skip_empty") + cursor.execute("CREATE TABLE #test_skip_empty (id INTEGER)") + db_connection.commit() + + # Execute query on empty table + cursor.execute("SELECT id FROM #test_skip_empty") + + # Skip should raise IndexError on empty result set + with pytest.raises(IndexError): + cursor.skip(1) + + # Verify row is still None + assert cursor.fetchone() is None, "Empty result should return None" + + finally: + _drop_if_exists_scroll(cursor, "#test_skip_empty") + +def test_cursor_skip_past_end(cursor, db_connection): + """Test skip past end of result set""" + try: + _drop_if_exists_scroll(cursor, "#test_skip_end") + cursor.execute("CREATE TABLE #test_skip_end (id INTEGER)") + cursor.executemany("INSERT INTO #test_skip_end VALUES (?)", [(i,) for i in range(1, 4)]) + db_connection.commit() + + # Execute query + cursor.execute("SELECT id FROM #test_skip_end ORDER BY id") + + # Skip beyond available rows + with pytest.raises(IndexError): + cursor.skip(5) # Only 3 rows available + + finally: + _drop_if_exists_scroll(cursor, "#test_skip_end") + +def test_cursor_skip_invalid_arguments(cursor, db_connection): + """Test skip with invalid arguments""" + from mssql_python.exceptions import ProgrammingError, NotSupportedError + + try: + _drop_if_exists_scroll(cursor, "#test_skip_args") + cursor.execute("CREATE TABLE #test_skip_args (id INTEGER)") + cursor.execute("INSERT INTO #test_skip_args VALUES (1)") + db_connection.commit() + + cursor.execute("SELECT id FROM #test_skip_args") + + # Test with non-integer + with pytest.raises(ProgrammingError): + cursor.skip("one") + + # Test with float + with pytest.raises(ProgrammingError): + cursor.skip(1.5) + + # Test with negative value 
+def test_cursor_skip_closed_cursor(db_connection):
+    """Test skip on closed cursor"""
+    cursor = db_connection.cursor()
+    cursor.close()
+
+    with pytest.raises(Exception) as exc_info:
+        cursor.skip(1)
+
+    assert "closed" in str(exc_info.value).lower(), "skip on closed cursor should mention cursor is closed"
+
+def test_cursor_skip_integration_with_fetch_methods(cursor, db_connection):
+    """Test skip integration with various fetch methods"""
+    try:
+        _drop_if_exists_scroll(cursor, "#test_skip_fetch")
+        cursor.execute("CREATE TABLE #test_skip_fetch (id INTEGER)")
+        cursor.executemany("INSERT INTO #test_skip_fetch VALUES (?)", [(i,) for i in range(1, 11)])
+        db_connection.commit()
+
+        # Test with fetchone
+        cursor.execute("SELECT id FROM #test_skip_fetch ORDER BY id")
+        cursor.fetchone()  # Fetch first row (id=1), rownumber=0
+        cursor.skip(2)     # Skip next 2 rows (id=2,3), rownumber=2
+        row = cursor.fetchone()
+        assert row[0] == 4, "After fetchone() and skip(2), should get id=4"
+
+        # Test with fetchmany; expectations follow the implementation's observed
+        # positioning when skip() runs after a buffered fetchmany
+        cursor.execute("SELECT id FROM #test_skip_fetch ORDER BY id")
+        rows = cursor.fetchmany(2)  # Fetch first 2 rows (id=1,2)
+        assert [r[0] for r in rows] == [1, 2], "Should fetch first 2 rows"
+        cursor.skip(3)  # Skip 3 positions from the current position
+        rows = cursor.fetchmany(2)
+
+        assert [r[0] for r in rows] == [5, 6], "After fetchmany(2) and skip(3), should get ids 5 and 6 (observed behavior)"
+
+        # Test with fetchall
+        cursor.execute("SELECT id FROM #test_skip_fetch ORDER BY id")
+        cursor.skip(5)  # Skip first 5 rows
+        rows = cursor.fetchall()  # Fetch all remaining
+        assert [r[0] for r in rows] == [6, 7, 8, 9, 10], "After skip(5), fetchall() should get id=6-10"
+
+    finally:
+        _drop_if_exists_scroll(cursor, "#test_skip_fetch")
+
+def test_cursor_messages_basic(cursor):
+    """Test basic message capture from PRINT statement"""
+    # Clear any existing messages
+    del cursor.messages[:]
+
+    # Execute a PRINT statement
+    cursor.execute("PRINT 'Hello world!'")
+
+    # Verify the message was captured
+    assert len(cursor.messages) == 1, "Should capture one message"
+    assert isinstance(cursor.messages[0], tuple), "Message should be a tuple"
+    assert len(cursor.messages[0]) == 2, "Message tuple should have 2 elements"
+    assert "Hello world!" in cursor.messages[0][1], "Message text should contain 'Hello world!'"
+
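+# --- Illustrative helper (hypothetical): per the assertions above, each entry
+# in cursor.messages is a 2-tuple of a "[SQLSTATE] (code)"-style string and the
+# message text; pulling just the texts out looks like this.
+def _example_message_texts(cursor):
+    return [text for _state, text in cursor.messages]
+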
+def test_cursor_messages_clearing(cursor):
+    """Test that messages are cleared before non-fetch operations"""
+    # First, generate a message
+    cursor.execute("PRINT 'First message'")
+    assert len(cursor.messages) > 0, "Should have captured the first message"
+
+    # Execute another operation - should clear messages
+    cursor.execute("PRINT 'Second message'")
+    assert len(cursor.messages) == 1, "Should have cleared previous messages"
+    assert "Second message" in cursor.messages[0][1], "Should contain only second message"
+
+    # Test that other operations clear messages too
+    cursor.execute("SELECT 1")
+    cursor.execute("PRINT 'After SELECT'")
+    assert len(cursor.messages) == 1, "Should have cleared messages before PRINT"
+    assert "After SELECT" in cursor.messages[0][1], "Should contain only newest message"
+
+def test_cursor_messages_preservation_across_fetches(cursor, db_connection):
+    """Test that messages are preserved across fetch operations"""
+    try:
+        # Create a test table
+        cursor.execute("CREATE TABLE #test_messages_preservation (id INT)")
+        db_connection.commit()
+
+        # Insert data
+        cursor.execute("INSERT INTO #test_messages_preservation VALUES (1), (2), (3)")
+        db_connection.commit()
+
+        # Generate a message
+        cursor.execute("PRINT 'Before query'")
+
+        # Clear messages before the query we'll test
+        del cursor.messages[:]
+
+        # Execute query to set up a result set
+        cursor.execute("SELECT id FROM #test_messages_preservation ORDER BY id")
+
+        # Add a message after the query but before fetches
+        cursor.execute("PRINT 'Before fetches'")
+        assert len(cursor.messages) == 1, "Should have one message"
+
+        # Re-execute the query since PRINT invalidated it
+        cursor.execute("SELECT id FROM #test_messages_preservation ORDER BY id")
+
+        # Check that the message was cleared (per DBAPI spec)
+        assert len(cursor.messages) == 0, "Messages should be cleared by execute()"
+
+        # Add a new message
+        cursor.execute("PRINT 'New message'")
+        assert len(cursor.messages) == 1, "Should have new message"
+
+        # Re-execute query
+        cursor.execute("SELECT id FROM #test_messages_preservation ORDER BY id")
+
+        # Now do fetch operations and ensure they don't clear messages.
+        # First, add a message after the SELECT
+        cursor.execute("PRINT 'Before actual fetches'")
+        # Re-execute query
+        cursor.execute("SELECT id FROM #test_messages_preservation ORDER BY id")
+
+        # This test reduces to checking that messages are cleared by execute()
+        # but not by fetchone/fetchmany/fetchall
+        assert len(cursor.messages) == 0, "Messages should be cleared by execute"
+
+    finally:
+        cursor.execute("DROP TABLE IF EXISTS #test_messages_preservation")
+        db_connection.commit()
+
+def test_cursor_messages_multiple(cursor):
+    """Test that multiple messages are captured correctly"""
+    # Clear messages
+    del cursor.messages[:]
+
+    # Generate multiple messages - one at a time, since batch execution only returns the first message
+    cursor.execute("PRINT 'First message'")
+    assert len(cursor.messages) == 1, "Should capture first message"
+    assert "First message" in cursor.messages[0][1]
+
+    cursor.execute("PRINT 'Second message'")
+    assert len(cursor.messages) == 1, "Execute should clear previous message"
+    assert "Second message" in cursor.messages[0][1]
+
+    cursor.execute("PRINT 'Third message'")
+    assert len(cursor.messages) == 1, "Execute should clear previous message"
+    assert "Third message" in cursor.messages[0][1]
+
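+# --- Illustrative helper (hypothetical): because execute() clears
+# cursor.messages, a caller that wants a running log across statements must
+# copy entries out between calls, e.g.:
+def _example_collect_messages(cursor, statements):
+    log = []
+    for stmt in statements:
+        cursor.execute(stmt)
+        log.extend(cursor.messages)   # snapshot before the next execute() clears them
+    return log
+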
+def test_cursor_messages_format(cursor):
+    """Test the message tuple format: a '[SQLSTATE] (code)' string plus the message text"""
+    del cursor.messages[:]
+
+    # Generate a message
+    cursor.execute("PRINT 'Test format'")
+
+    # Check format
+    assert len(cursor.messages) == 1, "Should have one message"
+    message = cursor.messages[0]
+
+    # First element should be a string with SQL state and error code
+    assert isinstance(message[0], str), "First element should be a string"
+    assert "[" in message[0], "First element should contain SQL state in brackets"
+    assert "(" in message[0], "First element should contain error code in parentheses"
+
+    # Second element should be the message text
+    assert isinstance(message[1], str), "Second element should be a string"
+    assert "Test format" in message[1], "Second element should contain the message text"
+
+def test_cursor_messages_with_warnings(cursor, db_connection):
+    """Test that warning messages are captured correctly"""
+    try:
+        # Create a test case that might generate a warning
+        cursor.execute("CREATE TABLE #test_messages_warnings (id INT, value DECIMAL(5,2))")
+        db_connection.commit()
+
+        # Clear messages
+        del cursor.messages[:]
+
+        # Try to insert a value that might cause a truncation warning
+        cursor.execute("INSERT INTO #test_messages_warnings VALUES (1, 123.456)")
+
+        # Check if any warning was captured.
+        # Note: this is implementation-dependent; some drivers might not report
+        # this as a warning at all.
+        if len(cursor.messages) > 0:
+            assert "truncat" in cursor.messages[0][1].lower() or "convert" in cursor.messages[0][1].lower(), \
+                "Warning message should mention truncation or conversion"
+
+    finally:
+        cursor.execute("DROP TABLE IF EXISTS #test_messages_warnings")
+        db_connection.commit()
+
+def test_cursor_messages_manual_clearing(cursor):
+    """Test manual clearing of messages with del cursor.messages[:]"""
+    # Generate a message
+    cursor.execute("PRINT 'Message to clear'")
+    assert len(cursor.messages) > 0, "Should have messages before clearing"
+
+    # Clear messages manually
+    del cursor.messages[:]
+    assert len(cursor.messages) == 0, "Messages should be cleared after del cursor.messages[:]"
+
+    # Verify we can still add messages after clearing
+    cursor.execute("PRINT 'New message after clearing'")
+    assert len(cursor.messages) == 1, "Should capture new message after clearing"
+    assert "New message after clearing" in cursor.messages[0][1], "New message should be correct"
+
+def test_cursor_messages_executemany(cursor, db_connection):
+    """Test messages with executemany"""
+    try:
+        # Create test table
+        cursor.execute("CREATE TABLE #test_messages_executemany (id INT)")
+        db_connection.commit()
+
+        # Clear messages
+        del cursor.messages[:]
+
+        # Use executemany and generate a message
+        data = [(1,), (2,), (3,)]
+        cursor.executemany("INSERT INTO #test_messages_executemany VALUES (?)", data)
+        cursor.execute("PRINT 'After executemany'")
+
+        # Check messages
+        assert len(cursor.messages) == 1, "Should have one message"
+        assert "After executemany" in cursor.messages[0][1], "Message should be correct"
+
+    finally:
+        cursor.execute("DROP TABLE IF EXISTS #test_messages_executemany")
+        db_connection.commit()
+
+def test_cursor_messages_with_error(cursor):
+    """Test messages when an error occurs"""
+    # Clear messages
+    del cursor.messages[:]
+
+    # Try to execute an invalid query
+    try:
+        cursor.execute("SELCT 1")  # Typo in SELECT
+    except Exception:
+        pass  # Expected to fail
+
+    # Execute a valid query with a message
+    cursor.execute("PRINT 'After error'")
+
+    # Check that messages were cleared before the new execute
+    assert len(cursor.messages) == 1, "Should have only the new message"
+    assert "After error" in cursor.messages[0][1], "Message should be from after the error"
+
+def test_tables_setup(cursor, db_connection):
+    """Create test objects for tables method testing"""
+    try:
+        # Create a test schema for isolation
+        cursor.execute("IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = 'pytest_tables_schema') EXEC('CREATE SCHEMA pytest_tables_schema')")
+
+        # Drop objects if they exist to ensure a clean state
+        cursor.execute("DROP TABLE IF EXISTS pytest_tables_schema.regular_table")
+        cursor.execute("DROP TABLE IF EXISTS pytest_tables_schema.another_table")
+        cursor.execute("DROP VIEW IF EXISTS pytest_tables_schema.test_view")
+
+        # Create regular table
+        cursor.execute("""
+            CREATE TABLE pytest_tables_schema.regular_table (
+                id INT PRIMARY KEY,
+                name VARCHAR(100)
+            )
+        """)
+
+        # Create another table
+        cursor.execute("""
+            CREATE TABLE pytest_tables_schema.another_table (
+                id INT PRIMARY KEY,
+                description VARCHAR(200)
+            )
+        """)
+
+        # Create a view
+        cursor.execute("""
+            CREATE VIEW pytest_tables_schema.test_view AS
+            SELECT id, name FROM pytest_tables_schema.regular_table
+        """)
+
+        db_connection.commit()
+    except Exception as e:
+        pytest.fail(f"Test setup failed: {e}")
+
+def test_tables_all(cursor, db_connection):
+    """Test tables returns information about all tables/views"""
+    try:
+        # First set up our test tables
+        test_tables_setup(cursor, db_connection)
+
+        # Get all tables (no filters)
+        tables_list = cursor.tables()
+
+        # Verify we got results
+        assert tables_list is not None, "tables() should return results"
+        assert len(tables_list) > 0, "tables() should return at least one table"
+
+        # Verify our test tables are in the results.
+        # Use case-insensitive comparison to avoid driver case sensitivity issues
+        found_test_table = False
+        for table in tables_list:
+            if (hasattr(table, 'table_name') and
+                    table.table_name and
+                    table.table_name.lower() == 'regular_table' and
+                    hasattr(table, 'table_schem') and
+                    table.table_schem and
+                    table.table_schem.lower() == 'pytest_tables_schema'):
+                found_test_table = True
+                break
+
+        assert found_test_table, "Test table should be included in results"
+
+        # Verify structure of results
+        first_row = tables_list[0]
+        assert hasattr(first_row, 'table_cat'), "Result should have table_cat column"
+        assert hasattr(first_row, 'table_schem'), "Result should have table_schem column"
+        assert hasattr(first_row, 'table_name'), "Result should have table_name column"
+        assert hasattr(first_row, 'table_type'), "Result should have table_type column"
+        assert hasattr(first_row, 'remarks'), "Result should have remarks column"
+
+    finally:
+        # Clean up happens in test_tables_cleanup
+        pass
+
+def test_tables_specific_table(cursor, db_connection):
+    """Test tables returns information about a specific table"""
+    try:
+        # Get specific table
+        tables_list = cursor.tables(
+            table='regular_table',
+            schema='pytest_tables_schema'
+        )
+
+        # Verify we got the right result
+        assert len(tables_list) == 1, "Should find exactly 1 table"
+
+        # Verify table details
+        table = tables_list[0]
+        assert table.table_name.lower() == 'regular_table', "Table name should be 'regular_table'"
+        assert table.table_schem.lower() == 'pytest_tables_schema', "Schema should be 'pytest_tables_schema'"
+        assert table.table_type == 'TABLE', "Table type should be 'TABLE'"
+
+    finally:
+        # Clean up happens in test_tables_cleanup
+        pass
+
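+# --- Illustrative helper (hypothetical): a compact use of the metadata call
+# exercised above - list the base tables of one schema as (schema, name) pairs.
+def _example_list_tables(cursor, schema):
+    return [(t.table_schem, t.table_name)
+            for t in cursor.tables(schema=schema, tableType='TABLE')]
+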
+def test_tables_with_table_pattern(cursor, db_connection):
+    """Test tables with table name pattern"""
+    try:
+        # Get tables matching a pattern
+        tables_list = cursor.tables(
+            table='%table',
+            schema='pytest_tables_schema'
+        )
+
+        # Should find both test tables
+        assert len(tables_list) == 2, "Should find 2 tables matching '%table'"
+
+        # Verify we found both test tables
+        table_names = set()
+        for table in tables_list:
+            if table.table_name:
+                table_names.add(table.table_name.lower())
+
+        assert 'regular_table' in table_names, "Should find regular_table"
+        assert 'another_table' in table_names, "Should find another_table"
+
+    finally:
+        # Clean up happens in test_tables_cleanup
+        pass
+
+def test_tables_with_schema_pattern(cursor, db_connection):
+    """Test tables with schema name pattern"""
+    try:
+        # Get tables with schema pattern
+        tables_list = cursor.tables(
+            schema='pytest_%'
+        )
+
+        # Should find our test tables/view
+        test_tables = []
+        for table in tables_list:
+            if (table.table_schem and
+                    table.table_schem.lower() == 'pytest_tables_schema' and
+                    table.table_name and
+                    table.table_name.lower() in ('regular_table', 'another_table', 'test_view')):
+                test_tables.append(table.table_name.lower())
+
+        assert len(test_tables) == 3, "Should find our 3 test objects"
+        assert 'regular_table' in test_tables, "Should find regular_table"
+        assert 'another_table' in test_tables, "Should find another_table"
+        assert 'test_view' in test_tables, "Should find test_view"
+
+    finally:
+        # Clean up happens in test_tables_cleanup
+        pass
+
+def test_tables_with_type_filter(cursor, db_connection):
+    """Test tables with table type filter"""
+    try:
+        # Get only tables
+        tables_list = cursor.tables(
+            schema='pytest_tables_schema',
+            tableType='TABLE'
+        )
+
+        # Verify only regular tables
+        table_types = set()
+        table_names = set()
+        for table in tables_list:
+            if table.table_type:
+                table_types.add(table.table_type)
+            if table.table_name:
+                table_names.add(table.table_name.lower())
+
+        assert len(table_types) == 1, "Should only have one table type"
+        assert 'TABLE' in table_types, "Should only find TABLE type"
+        assert 'regular_table' in table_names, "Should find regular_table"
+        assert 'another_table' in table_names, "Should find another_table"
+        assert 'test_view' not in table_names, "Should not find test_view"
+
+        # Get only views
+        views_list = cursor.tables(
+            schema='pytest_tables_schema',
+            tableType='VIEW'
+        )
+
+        # Verify only views
+        view_names = set()
+        for view in views_list:
+            if view.table_name:
+                view_names.add(view.table_name.lower())
+
+        assert 'test_view' in view_names, "Should find test_view"
+        assert 'regular_table' not in view_names, "Should not find regular_table"
+        assert 'another_table' not in view_names, "Should not find another_table"
+
+    finally:
+        # Clean up happens in test_tables_cleanup
+        pass
+
+def test_tables_with_multiple_types(cursor, db_connection):
+    """Test tables with multiple table types"""
+    try:
+        # Get both tables and views
+        tables_list = cursor.tables(
+            schema='pytest_tables_schema',
+            tableType=['TABLE', 'VIEW']
+        )
+
+        # Verify both tables and views
+        object_names = set()
+        for obj in tables_list:
+            if obj.table_name:
+                object_names.add(obj.table_name.lower())
+
+        assert len(object_names) == 3, "Should find 3 objects (2 tables + 1 view)"
+        assert 'regular_table' in object_names, "Should find regular_table"
+        assert 'another_table' in object_names, "Should find another_table"
+        assert 'test_view' in object_names, "Should find test_view"
+
+    finally:
+        # Clean up happens in test_tables_cleanup
+        pass
+
+def test_tables_catalog_filter(cursor, db_connection):
+    """Test tables with catalog filter"""
+    try:
+        # Get the current database name
+        cursor.execute("SELECT DB_NAME() AS current_db")
+        current_db = cursor.fetchone().current_db
+
+        # Get tables with the current catalog
+        tables_list = cursor.tables(
+            catalog=current_db,
+            schema='pytest_tables_schema'
+        )
+
+        # Verify the catalog filter worked
+        assert len(tables_list) > 0, "Should find tables with correct catalog"
+
+        # Verify catalog in results
+        for table in tables_list:
+            # Some drivers might return None for catalog
+            if table.table_cat is not None:
+                assert table.table_cat.lower() == current_db.lower(), "Wrong table catalog"
+
+        # Test with a non-existent catalog
+        fake_tables = cursor.tables(
+            catalog='nonexistent_db_xyz123',
+            schema='pytest_tables_schema'
+        )
+        assert len(fake_tables) == 0, "Should return empty list for non-existent catalog"
+
+    finally:
+        # Clean up happens in test_tables_cleanup
+        pass
+
+def test_tables_nonexistent(cursor):
+    """Test tables with non-existent objects"""
+    # Test with a non-existent table
+    tables_list = cursor.tables(table='nonexistent_table_xyz123')
+
+    # Should return an empty list, not an error
+    assert isinstance(tables_list, list), "Should return a list for non-existent table"
+    assert len(tables_list) == 0, "Should return empty list for non-existent table"
+
+    # Test with a non-existent schema
+    tables_list = cursor.tables(
+        table='regular_table',
+        schema='nonexistent_schema_xyz123'
+    )
+    assert len(tables_list) == 0, "Should return empty list for non-existent schema"
+
+def test_tables_combined_filters(cursor, db_connection):
+    """Test tables with multiple combined filters"""
+    try:
+        # Test with schema and table pattern
+        tables_list = cursor.tables(
+            schema='pytest_tables_schema',
+            table='regular%'
+        )
+
+        # Should find only regular_table
+        assert len(tables_list) == 1, "Should find 1 table with combined filters"
+        assert tables_list[0].table_name.lower() == 'regular_table', "Should find regular_table"
+
+        # Test with schema, table pattern, and type
+        tables_list = cursor.tables(
+            schema='pytest_tables_schema',
+            table='%table',
+            tableType='TABLE'
+        )
+
+        # Should find both tables but not the view
+        table_names = set()
+        for table in tables_list:
+            if table.table_name:
+                table_names.add(table.table_name.lower())
+
+        assert len(table_names) == 2, "Should find 2 tables with combined filters"
+        assert 'regular_table' in table_names, "Should find regular_table"
+        assert 'another_table' in table_names, "Should find another_table"
+        assert 'test_view' not in table_names, "Should not find test_view"
+
+    finally:
+        # Clean up happens in test_tables_cleanup
+        pass
+
+def test_tables_result_processing(cursor, db_connection):
+    """Test processing of tables result set for different client needs"""
+    try:
+        # Get all test objects
+        tables_list = cursor.tables(schema='pytest_tables_schema')
+
+        # Test 1: Extract just table names
+        table_names = [table.table_name for table in tables_list]
+        assert len(table_names) == 3, "Should extract 3 table names"
+
+        # Test 2: Filter to just tables (not views)
+        just_tables = [table for table in tables_list if table.table_type == 'TABLE']
+        assert len(just_tables) == 2, "Should find 2 regular tables"
+
+        # Test 3: Create a schema.table dictionary
+        schema_table_map = {}
+        for table in tables_list:
+            if table.table_schem not in schema_table_map:
+                schema_table_map[table.table_schem] = []
+            schema_table_map[table.table_schem].append(table.table_name)
+
+        assert 'pytest_tables_schema' in schema_table_map, "Should have our test schema"
+        assert len(schema_table_map['pytest_tables_schema']) == 3, "Should have 3 objects in test schema"
+
+        # Test 4: Check indexing and attribute access
+        first_table = tables_list[0]
+        assert first_table[0] == first_table.table_cat, "Index 0 should match table_cat attribute"
+        assert first_table[1] == first_table.table_schem, "Index 1 should match table_schem attribute"
+        assert first_table[2] == first_table.table_name, "Index 2 should match table_name attribute"
+        assert first_table[3] == first_table.table_type, "Index 3 should match table_type attribute"
+
+    finally:
+        # Clean up happens in test_tables_cleanup
+        pass
+
+def test_tables_method_chaining(cursor, db_connection):
+    """Test tables method with method chaining"""
+    try:
+        # Test method chaining with other methods
+        chained_result = cursor.tables(
+            schema='pytest_tables_schema',
+            table='regular_table'
+        )
+
+        # Verify chained result
+        assert len(chained_result) == 1, "Chained result should find 1 table"
+        assert chained_result[0].table_name.lower() == 'regular_table', "Should find regular_table"
+
+    finally:
+        # Clean up happens in test_tables_cleanup
+        pass
+
+def test_tables_cleanup(cursor, db_connection):
+    """Clean up test objects after testing"""
+    try:
+        # Drop all test objects
+        cursor.execute("DROP VIEW IF EXISTS pytest_tables_schema.test_view")
+        cursor.execute("DROP TABLE IF EXISTS pytest_tables_schema.regular_table")
+        cursor.execute("DROP TABLE IF EXISTS pytest_tables_schema.another_table")
+
+        # Drop the test schema
+        cursor.execute("DROP SCHEMA IF EXISTS pytest_tables_schema")
+        db_connection.commit()
+    except Exception as e:
+        pytest.fail(f"Test cleanup failed: {e}")
+
 def test_close(db_connection):
     """Test closing the cursor"""
     try: