diff --git a/README.md b/README.md index d57efda1f..ec82a3637 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,12 @@ or to a Databricks Runtime interactive cluster (e.g. /sql/protocolv1/o/123456789 > to authenticate the target Databricks user account and needs to open the browser for authentication. So it > can only run on the user's machine. +## Transaction Support + +The connector supports multi-statement transactions with manual commit/rollback control. Set `connection.autocommit = False` to disable autocommit mode, then use `connection.commit()` and `connection.rollback()` to control transactions. + +For detailed documentation, examples, and best practices, see **[TRANSACTIONS.md](TRANSACTIONS.md)**. + ## SQLAlchemy Starting from `databricks-sql-connector` version 4.0.0 SQLAlchemy support has been extracted to a new library `databricks-sqlalchemy`. diff --git a/TRANSACTIONS.md b/TRANSACTIONS.md new file mode 100644 index 000000000..590c298c0 --- /dev/null +++ b/TRANSACTIONS.md @@ -0,0 +1,387 @@ +# Transaction Support + +The Databricks SQL Connector for Python supports multi-statement transactions (MST). This allows you to group multiple SQL statements into atomic units that either succeed completely or fail completely. + +## Autocommit Behavior + +By default, every SQL statement executes in its own transaction and commits immediately (autocommit mode). This is the standard behavior for most database connectors. + +```python +from databricks import sql + +connection = sql.connect( + server_hostname="your-workspace.cloud.databricks.com", + http_path="/sql/1.0/warehouses/abc123" +) + +# Default: autocommit is True +print(connection.autocommit) # True + +# Each statement commits immediately +cursor = connection.cursor() +cursor.execute("INSERT INTO my_table VALUES (1, 'data')") +# Already committed - data is visible to other connections +``` + +To use explicit transactions, disable autocommit: + +```python +connection.autocommit = False + +# Now statements are grouped into a transaction +cursor = connection.cursor() +cursor.execute("INSERT INTO my_table VALUES (1, 'data')") +# Not committed yet - must call connection.commit() + +connection.commit() # Now it's visible +``` + +## Basic Transaction Operations + +### Committing Changes + +When autocommit is disabled, you must explicitly commit your changes: + +```python +connection.autocommit = False +cursor = connection.cursor() + +try: + cursor.execute("INSERT INTO orders VALUES (1, 100.00)") + cursor.execute("INSERT INTO order_items VALUES (1, 'Widget', 2)") + connection.commit() # Both inserts succeed together +except Exception as e: + connection.rollback() # Neither insert is saved + raise +finally: + connection.autocommit = True # Restore default state +``` + +### Rolling Back Changes + +Use `rollback()` to discard all changes made in the current transaction: + +```python +connection.autocommit = False +cursor = connection.cursor() + +cursor.execute("INSERT INTO accounts VALUES (1, 1000)") +cursor.execute("UPDATE accounts SET balance = balance - 500 WHERE id = 1") + +# Changed your mind? +connection.rollback() # All changes discarded +``` + +Note: Calling `rollback()` when autocommit is enabled is safe (it's a no-op), but calling `commit()` will raise a `TransactionError`. + +### Sequential Transactions + +After a commit or rollback, a new transaction starts automatically: + +```python +connection.autocommit = False + +# First transaction +cursor.execute("INSERT INTO logs VALUES (1, 'event1')") +connection.commit() + +# Second transaction starts automatically +cursor.execute("INSERT INTO logs VALUES (2, 'event2')") +connection.rollback() # Only the second insert is discarded +``` + +## Multi-Table Transactions + +Transactions span multiple tables atomically. Either all changes are committed, or all are rolled back: + +```python +connection.autocommit = False +cursor = connection.cursor() + +try: + # Insert into multiple tables + cursor.execute("INSERT INTO customers VALUES (1, 'Alice')") + cursor.execute("INSERT INTO orders VALUES (1, 1, 100.00)") + cursor.execute("INSERT INTO shipments VALUES (1, 1, 'pending')") + + connection.commit() # All three inserts succeed atomically +except Exception as e: + connection.rollback() # All three inserts are discarded + raise +finally: + connection.autocommit = True # Restore default state +``` + +This is particularly useful for maintaining data consistency across related tables. + +## Transaction Isolation + +Databricks uses **Snapshot Isolation** (mapped to `REPEATABLE_READ` in standard SQL terminology). This means: + +- **Repeatable reads**: Once you read data in a transaction, subsequent reads will see the same data (even if other transactions modify it) +- **Atomic commits**: Changes are visible to other connections only after commit +- **Write serializability within a single table**: Concurrent writes to the same table will cause conflicts +- **Snapshot isolation across tables**: Concurrent writes to different tables can succeed + +### Getting the Isolation Level + +```python +level = connection.get_transaction_isolation() +print(level) # Output: REPEATABLE_READ +``` + +### Setting the Isolation Level + +Currently, only `REPEATABLE_READ` is supported: + +```python +from databricks import sql + +# Using the constant +connection.set_transaction_isolation(sql.TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ) + +# Or using a string +connection.set_transaction_isolation("REPEATABLE_READ") + +# Other levels will raise NotSupportedError +connection.set_transaction_isolation("READ_COMMITTED") # Raises NotSupportedError +``` + +### What Repeatable Read Means in Practice + +Within a transaction, you'll always see a consistent snapshot of the data: + +```python +connection.autocommit = False +cursor = connection.cursor() + +# First read +cursor.execute("SELECT balance FROM accounts WHERE id = 1") +balance1 = cursor.fetchone()[0] # Returns 1000 + +# Another connection updates the balance +# (In a separate connection: UPDATE accounts SET balance = 500 WHERE id = 1) + +# Second read in the same transaction +cursor.execute("SELECT balance FROM accounts WHERE id = 1") +balance2 = cursor.fetchone()[0] # Still returns 1000 (repeatable read!) + +connection.commit() + +# After commit, new transactions will see the updated value (500) +``` + +## Error Handling + +### Setting Autocommit During a Transaction + +You cannot change autocommit mode while a transaction is active: + +```python +connection.autocommit = False +cursor = connection.cursor() + +try: + cursor.execute("INSERT INTO logs VALUES (1, 'data')") + + # This will raise TransactionError + connection.autocommit = True # Error: transaction is active + +except sql.TransactionError as e: + print(f"Cannot change autocommit: {e}") + connection.rollback() # Clean up the transaction +finally: + connection.autocommit = True # Now it's safe to restore +``` + +### Committing Without an Active Transaction + +If autocommit is enabled, there's no active transaction, so calling `commit()` will fail: + +```python +connection.autocommit = True # Default + +try: + connection.commit() # Raises TransactionError +except sql.TransactionError as e: + print(f"No active transaction: {e}") +``` + +However, `rollback()` is safe in this case (it's a no-op). + +### Recovering from Query Failures + +If a statement fails during a transaction, roll back and start a new transaction: + +```python +connection.autocommit = False +cursor = connection.cursor() + +try: + cursor.execute("INSERT INTO valid_table VALUES (1, 'data')") + cursor.execute("INSERT INTO nonexistent_table VALUES (2, 'data')") # Fails + connection.commit() +except Exception as e: + connection.rollback() # Discard the partial transaction + + # Log the error (with autocommit still disabled) + try: + cursor.execute("INSERT INTO error_log VALUES (1, 'Query failed')") + connection.commit() + except Exception: + connection.rollback() +finally: + connection.autocommit = True # Restore default state +``` + +## Querying Server State + +By default, the `autocommit` property returns a cached value for performance. If you need to query the server each time (for instance, when strong consistency is required): + +```python +connection = sql.connect( + server_hostname="your-workspace.cloud.databricks.com", + http_path="/sql/1.0/warehouses/abc123", + fetch_autocommit_from_server=True +) + +# Each access queries the server +state = connection.autocommit # Executes "SET AUTOCOMMIT" query +``` + +This is generally not needed for normal usage. + +## Write Conflicts + +### Within a Single Table + +Databricks enforces **write serializability** within a single table. If two transactions try to modify the same table concurrently, one will fail: + +```python +# Connection 1 +conn1.autocommit = False +cursor1 = conn1.cursor() +cursor1.execute("INSERT INTO accounts VALUES (1, 100)") + +# Connection 2 (concurrent) +conn2.autocommit = False +cursor2 = conn2.cursor() +cursor2.execute("INSERT INTO accounts VALUES (2, 200)") + +# First commit succeeds +conn1.commit() # OK + +# Second commit fails with concurrent write conflict +try: + conn2.commit() # Raises error about concurrent writes +except Exception as e: + conn2.rollback() + print(f"Concurrent write detected: {e}") +``` + +This happens even when the rows being modified are different. The conflict detection is at the table level. + +### Across Multiple Tables + +Concurrent writes to *different* tables can succeed. Each table tracks its own write conflicts independently: + +```python +# Connection 1: writes to table_a +conn1.autocommit = False +cursor1 = conn1.cursor() +cursor1.execute("INSERT INTO table_a VALUES (1, 'data')") + +# Connection 2: writes to table_b (different table) +conn2.autocommit = False +cursor2 = conn2.cursor() +cursor2.execute("INSERT INTO table_b VALUES (1, 'data')") + +# Both commits succeed (different tables) +conn1.commit() # OK +conn2.commit() # Also OK +``` + +## Best Practices + +1. **Keep transactions short**: Long-running transactions can cause conflicts with other connections. Commit as soon as your atomic unit of work is complete. + +2. **Always handle exceptions**: Wrap transaction code in try/except/finally and call `rollback()` on errors. + +```python +connection.autocommit = False +cursor = connection.cursor() + +try: + cursor.execute("INSERT INTO table1 VALUES (1, 'data')") + cursor.execute("UPDATE table2 SET status = 'updated'") + connection.commit() +except Exception as e: + connection.rollback() + logger.error(f"Transaction failed: {e}") + raise +finally: + connection.autocommit = True # Restore default state +``` + +3. **Use context managers**: If you're writing helper functions, consider using a context manager pattern: + +```python +from contextlib import contextmanager + +@contextmanager +def transaction(connection): + connection.autocommit = False + try: + yield connection + connection.commit() + except Exception: + connection.rollback() + raise + finally: + connection.autocommit = True + +# Usage +with transaction(connection): + cursor = connection.cursor() + cursor.execute("INSERT INTO logs VALUES (1, 'message')") + # Auto-commits on success, auto-rolls back on exception +``` + +4. **Reset autocommit when done**: Use a `finally` block to restore autocommit to `True`. This is especially important if the connection is reused or part of a connection pool: + +```python +connection.autocommit = False +try: + # ... transaction code ... + connection.commit() +except Exception: + connection.rollback() + raise +finally: + connection.autocommit = True # Restore to default state +``` + +5. **Be aware of isolation semantics**: Remember that repeatable read means you see a snapshot from the start of your transaction. If you need to see recent changes from other transactions, commit your current transaction and start a new one. + +## Requirements + +To use transactions, you need: +- A Databricks SQL warehouse that supports Multi-Statement Transactions (MST) +- Tables created with the `delta.feature.catalogOwned-preview` table property: + +```sql +CREATE TABLE my_table (id INT, value STRING) +USING DELTA +TBLPROPERTIES ('delta.feature.catalogOwned-preview' = 'supported') +``` + +## Related APIs + +- `connection.autocommit` - Get or set autocommit mode (boolean) +- `connection.commit()` - Commit the current transaction +- `connection.rollback()` - Roll back the current transaction +- `connection.get_transaction_isolation()` - Get the isolation level (returns `"REPEATABLE_READ"`) +- `connection.set_transaction_isolation(level)` - Validate/set isolation level (only `"REPEATABLE_READ"` supported) +- `sql.TransactionError` - Exception raised for transaction-specific errors + +All of these are extensions to [PEP 249](https://www.python.org/dev/peps/pep-0249/) (Python Database API Specification v2.0). diff --git a/examples/README.md b/examples/README.md index d73c58a6b..f52dede1d 100644 --- a/examples/README.md +++ b/examples/README.md @@ -31,6 +31,7 @@ To run all of these examples you can clone the entire repository to your disk. O - **`query_execute.py`** connects to the `samples` database of your default catalog, runs a small query, and prints the result to screen. - **`insert_data.py`** adds a tables called `squares` to your default catalog and inserts one hundred rows of example data. Then it fetches this data and prints it to the screen. +- **`transactions.py`** demonstrates multi-statement transaction support with explicit commit/rollback control. Shows how to group multiple SQL statements into an atomic unit that either succeeds completely or fails completely. - **`query_cancel.py`** shows how to cancel a query assuming that you can access the `Cursor` executing that query from a different thread. This is necessary because `databricks-sql-connector` does not yet implement an asynchronous API; calling `.execute()` blocks the current thread until execution completes. Therefore, the connector can't cancel queries from the same thread where they began. - **`interactive_oauth.py`** shows the simplest example of authenticating by OAuth (no need for a PAT generated in the DBSQL UI) while Bring Your Own IDP is in public preview. When you run the script it will open a browser window so you can authenticate. Afterward, the script fetches some sample data from Databricks and prints it to the screen. For this script, the OAuth token is not persisted which means you need to authenticate every time you run the script. - **`m2m_oauth.py`** shows the simplest example of authenticating by using OAuth M2M (machine-to-machine) for service principal. diff --git a/examples/transactions.py b/examples/transactions.py new file mode 100644 index 000000000..6f58dbd2d --- /dev/null +++ b/examples/transactions.py @@ -0,0 +1,47 @@ +from databricks import sql +import os + +with sql.connect( + server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), + http_path=os.getenv("DATABRICKS_HTTP_PATH"), + access_token=os.getenv("DATABRICKS_TOKEN"), +) as connection: + + # Disable autocommit to use explicit transactions + connection.autocommit = False + + with connection.cursor() as cursor: + try: + # Create tables for demonstration + cursor.execute("CREATE TABLE IF NOT EXISTS accounts (id int, balance int)") + cursor.execute( + "CREATE TABLE IF NOT EXISTS transfers (from_id int, to_id int, amount int)" + ) + connection.commit() + + # Start a new transaction - transfer money between accounts + cursor.execute("INSERT INTO accounts VALUES (1, 1000), (2, 500)") + cursor.execute("UPDATE accounts SET balance = balance - 100 WHERE id = 1") + cursor.execute("UPDATE accounts SET balance = balance + 100 WHERE id = 2") + cursor.execute("INSERT INTO transfers VALUES (1, 2, 100)") + + # Commit the transaction - all changes succeed together + connection.commit() + print("Transaction committed successfully") + + # Verify the results + cursor.execute("SELECT * FROM accounts ORDER BY id") + print("Accounts:", cursor.fetchall()) + + cursor.execute("SELECT * FROM transfers") + print("Transfers:", cursor.fetchall()) + + except Exception as e: + # Roll back on error - all changes are discarded + connection.rollback() + print(f"Transaction rolled back due to error: {e}") + raise + + finally: + # Restore autocommit to default state + connection.autocommit = True diff --git a/src/databricks/sql/__init__.py b/src/databricks/sql/__init__.py index 403a4d130..df44dd534 100644 --- a/src/databricks/sql/__init__.py +++ b/src/databricks/sql/__init__.py @@ -8,6 +8,9 @@ paramstyle = "named" +# Transaction isolation level constants (extension to PEP 249) +TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ = "REPEATABLE_READ" + import re from typing import TYPE_CHECKING diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 5bb191ca2..4db1ad118 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -20,6 +20,8 @@ InterfaceError, NotSupportedError, ProgrammingError, + TransactionError, + DatabaseError, ) from databricks.sql.thrift_api.TCLIService import ttypes @@ -86,6 +88,9 @@ NO_NATIVE_PARAMS: List = [] +# Transaction isolation level constants (extension to PEP 249) +TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ = "REPEATABLE_READ" + class Connection: def __init__( @@ -206,6 +211,11 @@ def read(self) -> Optional[OAuthToken]: This allows 1. cursor.tables() to return METRIC_VIEW table type 2. cursor.columns() to return "measure" column type + :param fetch_autocommit_from_server: `bool`, optional (default is False) + When True, the connection.autocommit property queries the server for current state + using SET AUTOCOMMIT instead of returning cached value. + Set to True if autocommit might be changed by external means (e.g., external SQL commands). + When False (default), uses cached state for better performance. """ # Internal arguments in **kwargs: @@ -304,6 +314,9 @@ def read(self) -> Optional[OAuthToken]: kwargs.get("use_inline_params", False) ) self.staging_allowed_local_path = kwargs.get("staging_allowed_local_path", None) + self._fetch_autocommit_from_server = kwargs.get( + "fetch_autocommit_from_server", False + ) self.force_enable_telemetry = kwargs.get("force_enable_telemetry", False) self.enable_telemetry = kwargs.get("enable_telemetry", False) @@ -473,15 +486,261 @@ def _close(self, close_cursors=True) -> None: if self.http_client: self.http_client.close() - def commit(self): - """No-op because Databricks does not support transactions""" - pass + @property + def autocommit(self) -> bool: + """ + Get auto-commit mode for this connection. - def rollback(self): - raise NotSupportedError( - "Transactions are not supported on Databricks", - session_id_hex=self.get_session_id_hex(), - ) + Extension to PEP 249. Returns cached value by default. + If fetch_autocommit_from_server=True was set during connection, + queries server for current state. + + Returns: + bool: True if auto-commit is enabled, False otherwise + + Raises: + InterfaceError: If connection is closed + TransactionError: If fetch_autocommit_from_server=True and query fails + """ + if not self.open: + raise InterfaceError( + "Cannot get autocommit on closed connection", + session_id_hex=self.get_session_id_hex(), + ) + + if self._fetch_autocommit_from_server: + return self._fetch_autocommit_state_from_server() + + return self.session.get_autocommit() + + @autocommit.setter + def autocommit(self, value: bool) -> None: + """ + Set auto-commit mode for this connection. + + Extension to PEP 249. Executes SET AUTOCOMMIT command on server. + + Args: + value: True to enable auto-commit, False to disable + + Raises: + InterfaceError: If connection is closed + TransactionError: If server rejects the change + """ + if not self.open: + raise InterfaceError( + "Cannot set autocommit on closed connection", + session_id_hex=self.get_session_id_hex(), + ) + + # Create internal cursor for transaction control + cursor = None + try: + cursor = self.cursor() + sql = f"SET AUTOCOMMIT = {'TRUE' if value else 'FALSE'}" + cursor.execute(sql) + + # Update cached state on success + self.session.set_autocommit(value) + + except DatabaseError as e: + # Wrap in TransactionError with context + raise TransactionError( + f"Failed to set autocommit to {value}: {e.message}", + context={ + **e.context, + "operation": "set_autocommit", + "autocommit_value": value, + }, + session_id_hex=self.get_session_id_hex(), + ) from e + finally: + if cursor: + cursor.close() + + def _fetch_autocommit_state_from_server(self) -> bool: + """ + Query server for current autocommit state using SET AUTOCOMMIT. + + Returns: + bool: Server's autocommit state + + Raises: + TransactionError: If query fails + """ + cursor = None + try: + cursor = self.cursor() + cursor.execute("SET AUTOCOMMIT") + + # Fetch result: should return row with value column + result = cursor.fetchone() + if result is None: + raise TransactionError( + "No result returned from SET AUTOCOMMIT query", + context={"operation": "fetch_autocommit"}, + session_id_hex=self.get_session_id_hex(), + ) + + # Parse value (first column should be "true" or "false") + value_str = str(result[0]).lower() + autocommit_state = value_str == "true" + + # Update cache + self.session.set_autocommit(autocommit_state) + + return autocommit_state + + except TransactionError: + # Re-raise TransactionError as-is + raise + except DatabaseError as e: + # Wrap other DatabaseErrors + raise TransactionError( + f"Failed to fetch autocommit state from server: {e.message}", + context={**e.context, "operation": "fetch_autocommit"}, + session_id_hex=self.get_session_id_hex(), + ) from e + finally: + if cursor: + cursor.close() + + def commit(self) -> None: + """ + Commit the current transaction. + + Per PEP 249. Should be called only when autocommit is disabled. + + When autocommit is False: + - Commits the current transaction + - Server automatically starts new transaction + + When autocommit is True: + - Server may throw error if no active transaction + + Raises: + InterfaceError: If connection is closed + TransactionError: If commit fails (e.g., no active transaction) + """ + if not self.open: + raise InterfaceError( + "Cannot commit on closed connection", + session_id_hex=self.get_session_id_hex(), + ) + + cursor = None + try: + cursor = self.cursor() + cursor.execute("COMMIT") + + except DatabaseError as e: + raise TransactionError( + f"Failed to commit transaction: {e.message}", + context={**e.context, "operation": "commit"}, + session_id_hex=self.get_session_id_hex(), + ) from e + finally: + if cursor: + cursor.close() + + def rollback(self) -> None: + """ + Rollback the current transaction. + + Per PEP 249. Should be called only when autocommit is disabled. + + When autocommit is False: + - Rolls back the current transaction + - Server automatically starts new transaction + + When autocommit is True: + - ROLLBACK is forgiving (no-op, doesn't throw exception) + + Note: ROLLBACK is safe to call even without active transaction. + + Raises: + InterfaceError: If connection is closed + TransactionError: If rollback fails + """ + if not self.open: + raise InterfaceError( + "Cannot rollback on closed connection", + session_id_hex=self.get_session_id_hex(), + ) + + cursor = None + try: + cursor = self.cursor() + cursor.execute("ROLLBACK") + + except DatabaseError as e: + raise TransactionError( + f"Failed to rollback transaction: {e.message}", + context={**e.context, "operation": "rollback"}, + session_id_hex=self.get_session_id_hex(), + ) from e + finally: + if cursor: + cursor.close() + + def get_transaction_isolation(self) -> str: + """ + Get the transaction isolation level. + + Extension to PEP 249. + + Databricks supports REPEATABLE_READ isolation level (Snapshot Isolation), + which is the default and only supported level. + + Returns: + str: "REPEATABLE_READ" - the transaction isolation level constant + + Raises: + InterfaceError: If connection is closed + """ + if not self.open: + raise InterfaceError( + "Cannot get transaction isolation on closed connection", + session_id_hex=self.get_session_id_hex(), + ) + + return TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ + + def set_transaction_isolation(self, level: str) -> None: + """ + Set transaction isolation level. + + Extension to PEP 249. + + Databricks supports only REPEATABLE_READ isolation level (Snapshot Isolation). + This method validates that the requested level is supported but does not + execute any SQL, as REPEATABLE_READ is the default server behavior. + + Args: + level: Isolation level. Must be "REPEATABLE_READ" or "REPEATABLE READ" + (case-insensitive, underscores and spaces are interchangeable) + + Raises: + InterfaceError: If connection is closed + NotSupportedError: If isolation level not supported + """ + if not self.open: + raise InterfaceError( + "Cannot set transaction isolation on closed connection", + session_id_hex=self.get_session_id_hex(), + ) + + # Normalize and validate isolation level + normalized_level = level.upper().replace("_", " ") + + if normalized_level != TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ.replace( + "_", " " + ): + raise NotSupportedError( + f"Setting transaction isolation level '{level}' is not supported. " + f"Only {TRANSACTION_ISOLATION_LEVEL_REPEATABLE_READ} is supported.", + session_id_hex=self.get_session_id_hex(), + ) class Cursor: diff --git a/src/databricks/sql/exc.py b/src/databricks/sql/exc.py index 4a772c49b..3a3a6b3c5 100644 --- a/src/databricks/sql/exc.py +++ b/src/databricks/sql/exc.py @@ -70,6 +70,23 @@ class NotSupportedError(DatabaseError): pass +class TransactionError(DatabaseError): + """ + Exception raised for transaction-specific errors. + + This exception is used when transaction control operations fail, such as: + - Setting autocommit mode (AUTOCOMMIT_SET_DURING_ACTIVE_TRANSACTION) + - Committing a transaction (MULTI_STATEMENT_TRANSACTION_NO_ACTIVE_TRANSACTION) + - Rolling back a transaction + - Setting transaction isolation level + + The exception includes context about which transaction operation failed + and preserves the underlying cause via exception chaining. + """ + + pass + + ### Custom error classes ### class InvalidServerResponseError(OperationalError): """Thrown if the server does not set the initial namespace correctly""" diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index d8ba5d125..0f723d144 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -45,6 +45,9 @@ def __init__( self.schema = schema self.http_path = http_path + # Initialize autocommit state (JDBC default is True) + self._autocommit = True + user_agent_entry = kwargs.get("user_agent_entry") if user_agent_entry is None: user_agent_entry = kwargs.get("_user_agent_entry") @@ -168,6 +171,24 @@ def guid_hex(self) -> str: """Get the session ID in hex format""" return self._session_id.hex_guid + def get_autocommit(self) -> bool: + """ + Get the cached autocommit state for this session. + + Returns: + bool: True if autocommit is enabled, False otherwise + """ + return self._autocommit + + def set_autocommit(self, value: bool) -> None: + """ + Update the cached autocommit state for this session. + + Args: + value: True to cache autocommit as enabled, False as disabled + """ + self._autocommit = value + def close(self) -> None: """Close the underlying session.""" logger.info("Closing session %s", self.guid_hex) diff --git a/tests/e2e/test_transactions.py b/tests/e2e/test_transactions.py new file mode 100644 index 000000000..09cbdae24 --- /dev/null +++ b/tests/e2e/test_transactions.py @@ -0,0 +1,597 @@ +""" +End-to-end integration tests for Multi-Statement Transaction (MST) APIs. + +These tests verify: +- autocommit property (getter/setter) +- commit() and rollback() methods +- get_transaction_isolation() and set_transaction_isolation() methods +- Transaction error handling + +Requirements: +- DBSQL warehouse that supports Multi-Statement Transactions (MST) +- Test environment configured via test.env file or environment variables + +Setup: +Set the following environment variables: +- DATABRICKS_SERVER_HOSTNAME +- DATABRICKS_HTTP_PATH +- DATABRICKS_ACCESS_TOKEN (or use OAuth) + +Usage: + pytest tests/e2e/test_transactions.py -v +""" + +import logging +import os +import pytest +from typing import Any, Dict + +import databricks.sql as sql +from databricks.sql import TransactionError, NotSupportedError, InterfaceError + +logger = logging.getLogger(__name__) + + +@pytest.mark.skip( + reason="Test environment does not yet support multi-statement transactions" +) +class TestTransactions: + """E2E tests for transaction control methods (MST support).""" + + # Test table name + TEST_TABLE_NAME = "transaction_test_table" + + @pytest.fixture(autouse=True) + def setup_and_teardown(self, connection_details): + """Setup test environment before each test and cleanup after.""" + self.connection_params = { + "server_hostname": connection_details["host"], + "http_path": connection_details["http_path"], + "access_token": connection_details.get("access_token"), + } + + # Get catalog and schema from environment or use defaults + self.catalog = os.getenv("DATABRICKS_CATALOG", "main") + self.schema = os.getenv("DATABRICKS_SCHEMA", "default") + + # Create connection for setup + self.connection = sql.connect(**self.connection_params) + + # Setup: Create test table + self._create_test_table() + + yield + + # Teardown: Cleanup + self._cleanup() + + def _get_fully_qualified_table_name(self) -> str: + """Get the fully qualified table name.""" + return f"{self.catalog}.{self.schema}.{self.TEST_TABLE_NAME}" + + def _create_test_table(self): + """Create the test table with Delta format and MST support.""" + fq_table_name = self._get_fully_qualified_table_name() + cursor = self.connection.cursor() + + try: + # Drop if exists + cursor.execute(f"DROP TABLE IF EXISTS {fq_table_name}") + + # Create table with Delta and catalog-owned feature for MST compatibility + cursor.execute( + f""" + CREATE TABLE IF NOT EXISTS {fq_table_name} + (id INT, value STRING) + USING DELTA + TBLPROPERTIES ('delta.feature.catalogOwned-preview' = 'supported') + """ + ) + + logger.info(f"Created test table: {fq_table_name}") + finally: + cursor.close() + + def _cleanup(self): + """Cleanup after test: rollback pending transactions, drop table, close connection.""" + try: + # Try to rollback any pending transaction + if ( + self.connection + and self.connection.open + and not self.connection.autocommit + ): + try: + self.connection.rollback() + except Exception as e: + logger.debug( + f"Rollback during cleanup failed (may be expected): {e}" + ) + + # Reset to autocommit mode + try: + self.connection.autocommit = True + except Exception as e: + logger.debug(f"Reset autocommit during cleanup failed: {e}") + + # Drop test table + if self.connection and self.connection.open: + fq_table_name = self._get_fully_qualified_table_name() + cursor = self.connection.cursor() + try: + cursor.execute(f"DROP TABLE IF EXISTS {fq_table_name}") + logger.info(f"Dropped test table: {fq_table_name}") + except Exception as e: + logger.warning(f"Failed to drop test table: {e}") + finally: + cursor.close() + + finally: + # Close connection + if self.connection: + self.connection.close() + + # ==================== BASIC AUTOCOMMIT TESTS ==================== + + def test_default_autocommit_is_true(self): + """Test that new connection defaults to autocommit=true.""" + assert ( + self.connection.autocommit is True + ), "New connection should have autocommit=true by default" + + def test_set_autocommit_to_false(self): + """Test successfully setting autocommit to false.""" + self.connection.autocommit = False + assert ( + self.connection.autocommit is False + ), "autocommit should be false after setting to false" + + def test_set_autocommit_to_true(self): + """Test successfully setting autocommit back to true.""" + # First disable + self.connection.autocommit = False + assert self.connection.autocommit is False + + # Then enable + self.connection.autocommit = True + assert ( + self.connection.autocommit is True + ), "autocommit should be true after setting to true" + + # ==================== COMMIT TESTS ==================== + + def test_commit_single_insert(self): + """Test successfully committing a transaction with single INSERT.""" + fq_table_name = self._get_fully_qualified_table_name() + + # Start transaction + self.connection.autocommit = False + + # Insert data + cursor = self.connection.cursor() + cursor.execute( + f"INSERT INTO {fq_table_name} (id, value) VALUES (1, 'test_value')" + ) + cursor.close() + + # Commit + self.connection.commit() + + # Verify data is persisted using a new connection + verify_conn = sql.connect(**self.connection_params) + try: + verify_cursor = verify_conn.cursor() + verify_cursor.execute(f"SELECT value FROM {fq_table_name} WHERE id = 1") + result = verify_cursor.fetchone() + verify_cursor.close() + + assert result is not None, "Should find inserted row after commit" + assert result[0] == "test_value", "Value should match inserted value" + finally: + verify_conn.close() + + def test_commit_multiple_inserts(self): + """Test successfully committing a transaction with multiple INSERTs.""" + fq_table_name = self._get_fully_qualified_table_name() + + self.connection.autocommit = False + + # Insert multiple rows + cursor = self.connection.cursor() + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (1, 'value1')") + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (2, 'value2')") + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (3, 'value3')") + cursor.close() + + self.connection.commit() + + # Verify all rows persisted + verify_conn = sql.connect(**self.connection_params) + try: + verify_cursor = verify_conn.cursor() + verify_cursor.execute(f"SELECT COUNT(*) FROM {fq_table_name}") + result = verify_cursor.fetchone() + verify_cursor.close() + + assert result[0] == 3, "Should have 3 rows after commit" + finally: + verify_conn.close() + + # ==================== ROLLBACK TESTS ==================== + + def test_rollback_single_insert(self): + """Test successfully rolling back a transaction.""" + fq_table_name = self._get_fully_qualified_table_name() + + self.connection.autocommit = False + + # Insert data + cursor = self.connection.cursor() + cursor.execute( + f"INSERT INTO {fq_table_name} (id, value) VALUES (100, 'rollback_test')" + ) + cursor.close() + + # Rollback + self.connection.rollback() + + # Verify data is NOT persisted + verify_conn = sql.connect(**self.connection_params) + try: + verify_cursor = verify_conn.cursor() + verify_cursor.execute( + f"SELECT COUNT(*) FROM {fq_table_name} WHERE id = 100" + ) + result = verify_cursor.fetchone() + verify_cursor.close() + + assert result[0] == 0, "Rolled back data should not be persisted" + finally: + verify_conn.close() + + # ==================== SEQUENTIAL TRANSACTION TESTS ==================== + + def test_multiple_sequential_transactions(self): + """Test executing multiple sequential transactions (commit, commit, rollback).""" + fq_table_name = self._get_fully_qualified_table_name() + + self.connection.autocommit = False + + # First transaction - commit + cursor = self.connection.cursor() + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (1, 'txn1')") + cursor.close() + self.connection.commit() + + # Second transaction - commit + cursor = self.connection.cursor() + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (2, 'txn2')") + cursor.close() + self.connection.commit() + + # Third transaction - rollback + cursor = self.connection.cursor() + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (3, 'txn3')") + cursor.close() + self.connection.rollback() + + # Verify only first two transactions persisted + verify_conn = sql.connect(**self.connection_params) + try: + verify_cursor = verify_conn.cursor() + verify_cursor.execute( + f"SELECT COUNT(*) FROM {fq_table_name} WHERE id IN (1, 2)" + ) + result = verify_cursor.fetchone() + assert result[0] == 2, "Should have 2 committed rows" + + verify_cursor.execute(f"SELECT COUNT(*) FROM {fq_table_name} WHERE id = 3") + result = verify_cursor.fetchone() + assert result[0] == 0, "Rolled back row should not exist" + verify_cursor.close() + finally: + verify_conn.close() + + def test_auto_start_transaction_after_commit(self): + """Test that new transaction automatically starts after commit.""" + fq_table_name = self._get_fully_qualified_table_name() + + self.connection.autocommit = False + + # First transaction - commit + cursor = self.connection.cursor() + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (1, 'first')") + cursor.close() + self.connection.commit() + + # New transaction should start automatically - insert and rollback + cursor = self.connection.cursor() + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (2, 'second')") + cursor.close() + self.connection.rollback() + + # Verify: first committed, second rolled back + verify_conn = sql.connect(**self.connection_params) + try: + verify_cursor = verify_conn.cursor() + verify_cursor.execute(f"SELECT COUNT(*) FROM {fq_table_name} WHERE id = 1") + result = verify_cursor.fetchone() + assert result[0] == 1, "First insert should be committed" + + verify_cursor.execute(f"SELECT COUNT(*) FROM {fq_table_name} WHERE id = 2") + result = verify_cursor.fetchone() + assert result[0] == 0, "Second insert should be rolled back" + verify_cursor.close() + finally: + verify_conn.close() + + def test_auto_start_transaction_after_rollback(self): + """Test that new transaction automatically starts after rollback.""" + fq_table_name = self._get_fully_qualified_table_name() + + self.connection.autocommit = False + + # First transaction - rollback + cursor = self.connection.cursor() + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (1, 'first')") + cursor.close() + self.connection.rollback() + + # New transaction should start automatically - insert and commit + cursor = self.connection.cursor() + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (2, 'second')") + cursor.close() + self.connection.commit() + + # Verify: first rolled back, second committed + verify_conn = sql.connect(**self.connection_params) + try: + verify_cursor = verify_conn.cursor() + verify_cursor.execute(f"SELECT COUNT(*) FROM {fq_table_name} WHERE id = 1") + result = verify_cursor.fetchone() + assert result[0] == 0, "First insert should be rolled back" + + verify_cursor.execute(f"SELECT COUNT(*) FROM {fq_table_name} WHERE id = 2") + result = verify_cursor.fetchone() + assert result[0] == 1, "Second insert should be committed" + verify_cursor.close() + finally: + verify_conn.close() + + # ==================== UPDATE/DELETE OPERATION TESTS ==================== + + def test_update_in_transaction(self): + """Test UPDATE operation in transaction.""" + fq_table_name = self._get_fully_qualified_table_name() + + # First insert a row with autocommit + cursor = self.connection.cursor() + cursor.execute( + f"INSERT INTO {fq_table_name} (id, value) VALUES (1, 'original')" + ) + cursor.close() + + # Start transaction and update + self.connection.autocommit = False + cursor = self.connection.cursor() + cursor.execute(f"UPDATE {fq_table_name} SET value = 'updated' WHERE id = 1") + cursor.close() + self.connection.commit() + + # Verify update persisted + verify_conn = sql.connect(**self.connection_params) + try: + verify_cursor = verify_conn.cursor() + verify_cursor.execute(f"SELECT value FROM {fq_table_name} WHERE id = 1") + result = verify_cursor.fetchone() + assert result[0] == "updated", "Value should be updated after commit" + verify_cursor.close() + finally: + verify_conn.close() + + # ==================== MULTI-TABLE TRANSACTION TESTS ==================== + + def test_multi_table_transaction_commit(self): + """Test atomic commit across multiple tables.""" + fq_table1_name = self._get_fully_qualified_table_name() + table2_name = self.TEST_TABLE_NAME + "_2" + fq_table2_name = f"{self.catalog}.{self.schema}.{table2_name}" + + # Create second table + cursor = self.connection.cursor() + cursor.execute(f"DROP TABLE IF EXISTS {fq_table2_name}") + cursor.execute( + f""" + CREATE TABLE IF NOT EXISTS {fq_table2_name} + (id INT, category STRING) + USING DELTA + TBLPROPERTIES ('delta.feature.catalogOwned-preview' = 'supported') + """ + ) + cursor.close() + + try: + # Start transaction and insert into both tables + self.connection.autocommit = False + + cursor = self.connection.cursor() + cursor.execute( + f"INSERT INTO {fq_table1_name} (id, value) VALUES (10, 'table1_data')" + ) + cursor.execute( + f"INSERT INTO {fq_table2_name} (id, category) VALUES (10, 'table2_data')" + ) + cursor.close() + + # Commit both atomically + self.connection.commit() + + # Verify both inserts persisted + verify_conn = sql.connect(**self.connection_params) + try: + verify_cursor = verify_conn.cursor() + + verify_cursor.execute( + f"SELECT COUNT(*) FROM {fq_table1_name} WHERE id = 10" + ) + result = verify_cursor.fetchone() + assert result[0] == 1, "Table1 insert should be committed" + + verify_cursor.execute( + f"SELECT COUNT(*) FROM {fq_table2_name} WHERE id = 10" + ) + result = verify_cursor.fetchone() + assert result[0] == 1, "Table2 insert should be committed" + + verify_cursor.close() + finally: + verify_conn.close() + + finally: + # Cleanup second table + self.connection.autocommit = True + cursor = self.connection.cursor() + cursor.execute(f"DROP TABLE IF EXISTS {fq_table2_name}") + cursor.close() + + def test_multi_table_transaction_rollback(self): + """Test atomic rollback across multiple tables.""" + fq_table1_name = self._get_fully_qualified_table_name() + table2_name = self.TEST_TABLE_NAME + "_2" + fq_table2_name = f"{self.catalog}.{self.schema}.{table2_name}" + + # Create second table + cursor = self.connection.cursor() + cursor.execute(f"DROP TABLE IF EXISTS {fq_table2_name}") + cursor.execute( + f""" + CREATE TABLE IF NOT EXISTS {fq_table2_name} + (id INT, category STRING) + USING DELTA + TBLPROPERTIES ('delta.feature.catalogOwned-preview' = 'supported') + """ + ) + cursor.close() + + try: + # Start transaction and insert into both tables + self.connection.autocommit = False + + cursor = self.connection.cursor() + cursor.execute( + f"INSERT INTO {fq_table1_name} (id, value) VALUES (20, 'rollback1')" + ) + cursor.execute( + f"INSERT INTO {fq_table2_name} (id, category) VALUES (20, 'rollback2')" + ) + cursor.close() + + # Rollback both atomically + self.connection.rollback() + + # Verify both inserts were rolled back + verify_conn = sql.connect(**self.connection_params) + try: + verify_cursor = verify_conn.cursor() + + verify_cursor.execute( + f"SELECT COUNT(*) FROM {fq_table1_name} WHERE id = 20" + ) + result = verify_cursor.fetchone() + assert result[0] == 0, "Table1 insert should be rolled back" + + verify_cursor.execute( + f"SELECT COUNT(*) FROM {fq_table2_name} WHERE id = 20" + ) + result = verify_cursor.fetchone() + assert result[0] == 0, "Table2 insert should be rolled back" + + verify_cursor.close() + finally: + verify_conn.close() + + finally: + # Cleanup second table + self.connection.autocommit = True + cursor = self.connection.cursor() + cursor.execute(f"DROP TABLE IF EXISTS {fq_table2_name}") + cursor.close() + + # ==================== ERROR HANDLING TESTS ==================== + + def test_set_autocommit_during_active_transaction(self): + """Test that setting autocommit during an active transaction throws error.""" + fq_table_name = self._get_fully_qualified_table_name() + + # Start transaction + self.connection.autocommit = False + cursor = self.connection.cursor() + cursor.execute(f"INSERT INTO {fq_table_name} (id, value) VALUES (99, 'test')") + cursor.close() + + # Try to set autocommit=True during active transaction + with pytest.raises(TransactionError) as exc_info: + self.connection.autocommit = True + + # Verify error message mentions autocommit or active transaction + error_msg = str(exc_info.value).lower() + assert ( + "autocommit" in error_msg or "active transaction" in error_msg + ), "Error should mention autocommit or active transaction" + + # Cleanup - rollback the transaction + self.connection.rollback() + + def test_commit_without_active_transaction_throws_error(self): + """Test that commit() throws error when autocommit=true (no active transaction).""" + # Ensure autocommit is true (default) + assert self.connection.autocommit is True + + # Attempt commit without active transaction should throw + with pytest.raises(TransactionError) as exc_info: + self.connection.commit() + + # Verify error message indicates no active transaction + error_message = str(exc_info.value) + assert ( + "MULTI_STATEMENT_TRANSACTION_NO_ACTIVE_TRANSACTION" in error_message + or "no active transaction" in error_message.lower() + ), "Error should indicate no active transaction" + + def test_rollback_without_active_transaction_is_safe(self): + """Test that rollback() without active transaction is a safe no-op.""" + # With autocommit=true (no active transaction) + assert self.connection.autocommit is True + + # ROLLBACK should be safe (no exception) + self.connection.rollback() + + # Verify connection is still usable + assert self.connection.autocommit is True + assert self.connection.open is True + + # ==================== TRANSACTION ISOLATION TESTS ==================== + + def test_get_transaction_isolation_returns_repeatable_read(self): + """Test that get_transaction_isolation() returns REPEATABLE_READ.""" + isolation_level = self.connection.get_transaction_isolation() + assert ( + isolation_level == "REPEATABLE_READ" + ), "Databricks MST should use REPEATABLE_READ (Snapshot Isolation)" + + def test_set_transaction_isolation_accepts_repeatable_read(self): + """Test that set_transaction_isolation() accepts REPEATABLE_READ.""" + # Should not raise - these are all valid formats + self.connection.set_transaction_isolation("REPEATABLE_READ") + self.connection.set_transaction_isolation("REPEATABLE READ") + self.connection.set_transaction_isolation("repeatable_read") + self.connection.set_transaction_isolation("repeatable read") + + def test_set_transaction_isolation_rejects_unsupported_level(self): + """Test that set_transaction_isolation() rejects unsupported levels.""" + with pytest.raises(NotSupportedError) as exc_info: + self.connection.set_transaction_isolation("READ_COMMITTED") + + error_message = str(exc_info.value) + assert "not supported" in error_message.lower() + assert "READ_COMMITTED" in error_message diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 19375cde3..cb810afbb 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -22,7 +22,13 @@ import databricks.sql import databricks.sql.client as client -from databricks.sql import InterfaceError, DatabaseError, Error, NotSupportedError +from databricks.sql import ( + InterfaceError, + DatabaseError, + Error, + NotSupportedError, + TransactionError, +) from databricks.sql.types import Row from databricks.sql.result_set import ResultSet, ThriftResultSet from databricks.sql.backend.types import CommandId, CommandState @@ -439,11 +445,6 @@ def test_executemany_parameter_passhthrough_and_uses_last_result_set(self): "last operation", ) - @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) - def test_commit_a_noop(self, mock_thrift_backend_class): - c = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) - c.commit() - def test_setinputsizes_a_noop(self): cursor = client.Cursor(Mock(), Mock()) cursor.setinputsizes(1) @@ -452,12 +453,6 @@ def test_setoutputsizes_a_noop(self): cursor = client.Cursor(Mock(), Mock()) cursor.setoutputsize(1) - @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) - def test_rollback_not_supported(self, mock_thrift_backend_class): - c = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) - with self.assertRaises(NotSupportedError): - c.rollback() - @unittest.skip("JDW: skipping winter 2024 as we're about to rewrite this interface") @patch("%s.client.ThriftDatabricksClient" % PACKAGE_NAME) def test_row_number_respected(self, mock_thrift_backend_class): @@ -639,11 +634,377 @@ def mock_close_normal(): ) +class TransactionTestSuite(unittest.TestCase): + """ + Unit tests for transaction control methods (MST support). + """ + + PACKAGE_NAME = "databricks.sql" + DUMMY_CONNECTION_ARGS = { + "server_hostname": "foo", + "http_path": "dummy_path", + "access_token": "tok", + } + + def _create_mock_connection(self, mock_session_class): + """Helper to create a mocked connection for transaction tests.""" + # Mock session + mock_session = Mock() + mock_session.is_open = True + mock_session.guid_hex = "test-session-id" + mock_session.get_autocommit.return_value = True + mock_session_class.return_value = mock_session + + # Create connection + conn = client.Connection(**self.DUMMY_CONNECTION_ARGS) + return conn + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_autocommit_getter_returns_cached_value(self, mock_session_class): + """Test that autocommit property returns cached session value by default.""" + conn = self._create_mock_connection(mock_session_class) + + # Get autocommit (should use cached value) + result = conn.autocommit + + conn.session.get_autocommit.assert_called_once() + self.assertTrue(result) + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_autocommit_setter_executes_sql(self, mock_session_class): + """Test that setting autocommit executes SET AUTOCOMMIT command.""" + conn = self._create_mock_connection(mock_session_class) + + mock_cursor = Mock() + with patch.object(conn, "cursor", return_value=mock_cursor): + conn.autocommit = False + + # Verify SQL was executed + mock_cursor.execute.assert_called_once_with("SET AUTOCOMMIT = FALSE") + mock_cursor.close.assert_called_once() + + conn.session.set_autocommit.assert_called_once_with(False) + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_autocommit_setter_with_true_value(self, mock_session_class): + """Test setting autocommit to True.""" + conn = self._create_mock_connection(mock_session_class) + + mock_cursor = Mock() + with patch.object(conn, "cursor", return_value=mock_cursor): + conn.autocommit = True + + mock_cursor.execute.assert_called_once_with("SET AUTOCOMMIT = TRUE") + conn.session.set_autocommit.assert_called_once_with(True) + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_autocommit_setter_wraps_database_error(self, mock_session_class): + """Test that autocommit setter wraps DatabaseError in TransactionError.""" + conn = self._create_mock_connection(mock_session_class) + + mock_cursor = Mock() + server_error = DatabaseError( + "AUTOCOMMIT_SET_DURING_ACTIVE_TRANSACTION", + context={"sql_state": "25000"}, + session_id_hex="test-session-id", + ) + mock_cursor.execute.side_effect = server_error + + with patch.object(conn, "cursor", return_value=mock_cursor): + with self.assertRaises(TransactionError) as ctx: + conn.autocommit = False + + self.assertIn("Failed to set autocommit", str(ctx.exception)) + self.assertEqual(ctx.exception.context["operation"], "set_autocommit") + self.assertEqual(ctx.exception.context["autocommit_value"], False) + + mock_cursor.close.assert_called_once() + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_autocommit_setter_preserves_exception_chain(self, mock_session_class): + """Test that exception chaining is preserved.""" + conn = self._create_mock_connection(mock_session_class) + + mock_cursor = Mock() + original_error = DatabaseError( + "Original error", session_id_hex="test-session-id" + ) + mock_cursor.execute.side_effect = original_error + + with patch.object(conn, "cursor", return_value=mock_cursor): + with self.assertRaises(TransactionError) as ctx: + conn.autocommit = False + + self.assertEqual(ctx.exception.__cause__, original_error) + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_commit_executes_sql(self, mock_session_class): + """Test that commit() executes COMMIT command.""" + conn = self._create_mock_connection(mock_session_class) + + mock_cursor = Mock() + with patch.object(conn, "cursor", return_value=mock_cursor): + conn.commit() + + mock_cursor.execute.assert_called_once_with("COMMIT") + mock_cursor.close.assert_called_once() + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_commit_wraps_database_error(self, mock_session_class): + """Test that commit() wraps DatabaseError in TransactionError.""" + conn = self._create_mock_connection(mock_session_class) + + mock_cursor = Mock() + server_error = DatabaseError( + "MULTI_STATEMENT_TRANSACTION_NO_ACTIVE_TRANSACTION", + context={"sql_state": "25000"}, + session_id_hex="test-session-id", + ) + mock_cursor.execute.side_effect = server_error + + with patch.object(conn, "cursor", return_value=mock_cursor): + with self.assertRaises(TransactionError) as ctx: + conn.commit() + + self.assertIn("Failed to commit", str(ctx.exception)) + self.assertEqual(ctx.exception.context["operation"], "commit") + mock_cursor.close.assert_called_once() + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_commit_on_closed_connection_raises_interface_error( + self, mock_session_class + ): + """Test that commit() on closed connection raises InterfaceError.""" + conn = self._create_mock_connection(mock_session_class) + conn.session.is_open = False + + with self.assertRaises(InterfaceError) as ctx: + conn.commit() + + self.assertIn("Cannot commit on closed connection", str(ctx.exception)) + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_rollback_executes_sql(self, mock_session_class): + """Test that rollback() executes ROLLBACK command.""" + conn = self._create_mock_connection(mock_session_class) + + mock_cursor = Mock() + with patch.object(conn, "cursor", return_value=mock_cursor): + conn.rollback() + + mock_cursor.execute.assert_called_once_with("ROLLBACK") + mock_cursor.close.assert_called_once() + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_rollback_wraps_database_error(self, mock_session_class): + """Test that rollback() wraps DatabaseError in TransactionError.""" + conn = self._create_mock_connection(mock_session_class) + + mock_cursor = Mock() + server_error = DatabaseError( + "Unexpected rollback error", + context={"sql_state": "HY000"}, + session_id_hex="test-session-id", + ) + mock_cursor.execute.side_effect = server_error + + with patch.object(conn, "cursor", return_value=mock_cursor): + with self.assertRaises(TransactionError) as ctx: + conn.rollback() + + self.assertIn("Failed to rollback", str(ctx.exception)) + self.assertEqual(ctx.exception.context["operation"], "rollback") + mock_cursor.close.assert_called_once() + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_rollback_on_closed_connection_raises_interface_error( + self, mock_session_class + ): + """Test that rollback() on closed connection raises InterfaceError.""" + conn = self._create_mock_connection(mock_session_class) + conn.session.is_open = False + + with self.assertRaises(InterfaceError) as ctx: + conn.rollback() + + self.assertIn("Cannot rollback on closed connection", str(ctx.exception)) + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_get_transaction_isolation_returns_repeatable_read( + self, mock_session_class + ): + """Test that get_transaction_isolation() returns REPEATABLE_READ.""" + conn = self._create_mock_connection(mock_session_class) + + result = conn.get_transaction_isolation() + + self.assertEqual(result, "REPEATABLE_READ") + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_get_transaction_isolation_on_closed_connection_raises_interface_error( + self, mock_session_class + ): + """Test that get_transaction_isolation() on closed connection raises InterfaceError.""" + conn = self._create_mock_connection(mock_session_class) + conn.session.is_open = False + + with self.assertRaises(InterfaceError) as ctx: + conn.get_transaction_isolation() + + self.assertIn( + "Cannot get transaction isolation on closed connection", str(ctx.exception) + ) + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_set_transaction_isolation_accepts_repeatable_read( + self, mock_session_class + ): + """Test that set_transaction_isolation() accepts REPEATABLE_READ.""" + conn = self._create_mock_connection(mock_session_class) + + # Should not raise + conn.set_transaction_isolation("REPEATABLE_READ") + conn.set_transaction_isolation("REPEATABLE READ") # With space + conn.set_transaction_isolation("repeatable_read") # Lowercase with underscore + conn.set_transaction_isolation("repeatable read") # Lowercase with space + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_set_transaction_isolation_rejects_other_levels(self, mock_session_class): + """Test that set_transaction_isolation() rejects non-REPEATABLE_READ levels.""" + conn = self._create_mock_connection(mock_session_class) + + with self.assertRaises(NotSupportedError) as ctx: + conn.set_transaction_isolation("READ_COMMITTED") + + self.assertIn("not supported", str(ctx.exception)) + self.assertIn("READ_COMMITTED", str(ctx.exception)) + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_set_transaction_isolation_on_closed_connection_raises_interface_error( + self, mock_session_class + ): + """Test that set_transaction_isolation() on closed connection raises InterfaceError.""" + conn = self._create_mock_connection(mock_session_class) + conn.session.is_open = False + + with self.assertRaises(InterfaceError) as ctx: + conn.set_transaction_isolation("REPEATABLE_READ") + + self.assertIn( + "Cannot set transaction isolation on closed connection", str(ctx.exception) + ) + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_fetch_autocommit_from_server_queries_server(self, mock_session_class): + """Test that fetch_autocommit_from_server=True queries server.""" + # Create connection with fetch_autocommit_from_server=True + mock_session = Mock() + mock_session.is_open = True + mock_session.guid_hex = "test-session-id" + mock_session_class.return_value = mock_session + + conn = client.Connection( + fetch_autocommit_from_server=True, **self.DUMMY_CONNECTION_ARGS + ) + + mock_cursor = Mock() + mock_row = Mock() + mock_row.__getitem__ = Mock(return_value="true") + mock_cursor.fetchone.return_value = mock_row + + with patch.object(conn, "cursor", return_value=mock_cursor): + result = conn.autocommit + + mock_cursor.execute.assert_called_once_with("SET AUTOCOMMIT") + mock_cursor.fetchone.assert_called_once() + mock_cursor.close.assert_called_once() + + conn.session.set_autocommit.assert_called_once_with(True) + + self.assertTrue(result) + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_fetch_autocommit_from_server_handles_false_value(self, mock_session_class): + """Test that fetch_autocommit_from_server correctly parses false value.""" + mock_session = Mock() + mock_session.is_open = True + mock_session.guid_hex = "test-session-id" + mock_session_class.return_value = mock_session + + conn = client.Connection( + fetch_autocommit_from_server=True, **self.DUMMY_CONNECTION_ARGS + ) + + mock_cursor = Mock() + mock_row = Mock() + mock_row.__getitem__ = Mock(return_value="false") + mock_cursor.fetchone.return_value = mock_row + + with patch.object(conn, "cursor", return_value=mock_cursor): + result = conn.autocommit + + conn.session.set_autocommit.assert_called_once_with(False) + self.assertFalse(result) + + conn.close() + + @patch("%s.client.Session" % PACKAGE_NAME) + def test_fetch_autocommit_from_server_raises_on_no_result(self, mock_session_class): + """Test that fetch_autocommit_from_server raises error when no result.""" + mock_session = Mock() + mock_session.is_open = True + mock_session.guid_hex = "test-session-id" + mock_session_class.return_value = mock_session + + conn = client.Connection( + fetch_autocommit_from_server=True, **self.DUMMY_CONNECTION_ARGS + ) + + mock_cursor = Mock() + mock_cursor.fetchone.return_value = None + + with patch.object(conn, "cursor", return_value=mock_cursor): + with self.assertRaises(TransactionError) as ctx: + _ = conn.autocommit + + self.assertIn("No result returned", str(ctx.exception)) + mock_cursor.close.assert_called_once() + + conn.close() + + if __name__ == "__main__": suite = unittest.TestLoader().loadTestsFromModule(sys.modules[__name__]) loader = unittest.TestLoader() test_classes = [ ClientTestSuite, + TransactionTestSuite, FetchTests, ThriftBackendTestSuite, ArrowQueueSuite,