From 937ba01cf0748291935dff4dc723baf8ab04f0b7 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Fri, 21 Nov 2025 17:49:06 +0000 Subject: [PATCH 1/3] fix: `returns_row` return `False` incorrectly --- sqlspec/adapters/adbc/driver.py | 4 +- sqlspec/adapters/bigquery/driver.py | 9 ++- sqlspec/adapters/duckdb/driver.py | 4 +- sqlspec/adapters/oracledb/driver.py | 4 +- sqlspec/core/compiler.py | 8 ++- sqlspec/driver/_common.py | 26 +++++++ tests/unit/test_core/test_statement.py | 21 ++++++ tests/unit/test_driver/test_force_select.py | 78 +++++++++++++++++++++ 8 files changed, 147 insertions(+), 7 deletions(-) create mode 100644 tests/unit/test_driver/test_force_select.py diff --git a/sqlspec/adapters/adbc/driver.py b/sqlspec/adapters/adbc/driver.py index 2da43c46..4aa0e457 100644 --- a/sqlspec/adapters/adbc/driver.py +++ b/sqlspec/adapters/adbc/driver.py @@ -513,7 +513,9 @@ def _execute_statement(self, cursor: "Cursor", statement: SQL) -> "ExecutionResu self._handle_postgres_rollback(cursor) raise - if statement.returns_rows(): + is_select_like = statement.returns_rows() or self._should_force_select(statement, cursor) + + if is_select_like: fetched_data = cursor.fetchall() column_names = [col[0] for col in cursor.description or []] diff --git a/sqlspec/adapters/bigquery/driver.py b/sqlspec/adapters/bigquery/driver.py index fd41f76f..287bba61 100644 --- a/sqlspec/adapters/bigquery/driver.py +++ b/sqlspec/adapters/bigquery/driver.py @@ -686,9 +686,13 @@ def _execute_statement(self, cursor: Any, statement: "SQL") -> ExecutionResult: """ sql, parameters = self._get_compiled_sql(statement, self.statement_config) cursor.job = self._run_query_job(sql, parameters, connection=cursor) + job_result = cursor.job.result(job_retry=self._job_retry) + statement_type = str(cursor.job.statement_type or "").upper() + is_select_like = ( + statement.returns_rows() or statement_type == "SELECT" or self._should_force_select(statement, cursor) + ) - if statement.returns_rows(): - job_result = cursor.job.result(job_retry=self._job_retry) + if is_select_like: rows_list = self._rows_to_results(iter(job_result)) column_names = [field.name for field in cursor.job.schema] if cursor.job.schema else [] @@ -700,7 +704,6 @@ def _execute_statement(self, cursor: Any, statement: "SQL") -> ExecutionResult: is_select_result=True, ) - cursor.job.result(job_retry=self._job_retry) affected_rows = cursor.job.num_dml_affected_rows or 0 return self.create_execution_result(cursor, rowcount_override=affected_rows) diff --git a/sqlspec/adapters/duckdb/driver.py b/sqlspec/adapters/duckdb/driver.py index ff13fc0e..9d023fdc 100644 --- a/sqlspec/adapters/duckdb/driver.py +++ b/sqlspec/adapters/duckdb/driver.py @@ -349,7 +349,9 @@ def _execute_statement(self, cursor: Any, statement: SQL) -> "ExecutionResult": sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) cursor.execute(sql, prepared_parameters or ()) - if statement.returns_rows(): + is_select_like = statement.returns_rows() or self._should_force_select(statement, cursor) + + if is_select_like: fetched_data = cursor.fetchall() column_names = [col[0] for col in cursor.description or []] diff --git a/sqlspec/adapters/oracledb/driver.py b/sqlspec/adapters/oracledb/driver.py index 62cc8dd0..caef2e18 100644 --- a/sqlspec/adapters/oracledb/driver.py +++ b/sqlspec/adapters/oracledb/driver.py @@ -1243,7 +1243,9 @@ async def _execute_statement(self, cursor: Any, statement: "SQL") -> "ExecutionR await cursor.execute(sql, prepared_parameters or {}) # SELECT result processing for Oracle - if statement.returns_rows(): + is_select_like = statement.returns_rows() or self._should_force_select(statement, cursor) + + if is_select_like: fetched_data = await cursor.fetchall() column_names = [col[0] for col in cursor.description or []] column_names = _normalize_column_names(column_names, self.driver_features) diff --git a/sqlspec/core/compiler.py b/sqlspec/core/compiler.py index b8bba8b3..adbb62b4 100644 --- a/sqlspec/core/compiler.py +++ b/sqlspec/core/compiler.py @@ -55,6 +55,10 @@ OPERATION_TYPE_MAP: "dict[type[exp.Expression], OperationType]" = { exp.Select: "SELECT", + exp.Union: "SELECT", + exp.Except: "SELECT", + exp.Intersect: "SELECT", + exp.With: "SELECT", exp.Insert: "INSERT", exp.Update: "UPDATE", exp.Delete: "DELETE", @@ -554,7 +558,9 @@ def _build_operation_profile( modifies_rows = False expr = expression - if isinstance(expr, (exp.Select, exp.Values, exp.Table, exp.TableSample, exp.With)): + if isinstance( + expr, (exp.Select, exp.Union, exp.Except, exp.Intersect, exp.Values, exp.Table, exp.TableSample, exp.With) + ): returns_rows = True elif isinstance(expr, (exp.Insert, exp.Update, exp.Delete, exp.Merge)): modifies_rows = True diff --git a/sqlspec/driver/_common.py b/sqlspec/driver/_common.py index afdad1ed..7d686d0a 100644 --- a/sqlspec/driver/_common.py +++ b/sqlspec/driver/_common.py @@ -548,6 +548,32 @@ def build_statement_result(self, statement: "SQL", execution_result: ExecutionRe metadata=execution_result.special_data or {"status_message": "OK"}, ) + def _should_force_select(self, statement: "SQL", cursor: Any) -> bool: + """Determine if a statement with unknown type should be treated as SELECT. + + Uses driver metadata (statement_type, description/schema) as a safety net when + the compiler cannot classify the operation. This remains conservative by only + triggering when the operation type is "UNKNOWN". + + Args: + statement: SQL statement being executed. + cursor: Database cursor/job object that may expose metadata. + + Returns: + True when cursor metadata indicates a row-returning operation despite an + unknown operation type; otherwise False. + """ + + if statement.operation_type != "UNKNOWN": + return False + + statement_type = getattr(cursor, "statement_type", None) + if isinstance(statement_type, str) and statement_type.upper() == "SELECT": + return True + + description = getattr(cursor, "description", None) + return bool(description) + def prepare_statement( self, statement: "Statement | QueryBuilder", diff --git a/tests/unit/test_core/test_statement.py b/tests/unit/test_core/test_statement.py index 536b2751..44f9769b 100644 --- a/tests/unit/test_core/test_statement.py +++ b/tests/unit/test_core/test_statement.py @@ -534,6 +534,27 @@ def test_sql_returns_rows_detection() -> None: assert show_stmt.returns_rows() is True +@pytest.mark.parametrize( + "sql_text", + [ + "SELECT 1 UNION ALL SELECT 2", + "SELECT 1 EXCEPT SELECT 2", + "SELECT 1 INTERSECT SELECT 1", + "WITH cte AS (SELECT 1 AS id) SELECT * FROM cte", + ], + ids=["union", "except", "intersect", "cte_select"], +) +def test_sql_set_and_cte_operations_detect_as_select(sql_text: str) -> None: + """Ensure set operations and CTE queries are detected as SELECT and return rows.""" + + stmt = SQL(sql_text) + stmt.compile() + + assert stmt.operation_type == "SELECT" + assert stmt.returns_rows() is True + assert stmt.is_modifying_operation() is False + + def test_sql_slots_prevent_new_attributes() -> None: """Test SQL __slots__ prevent adding new attributes.""" stmt = SQL("SELECT * FROM users") diff --git a/tests/unit/test_driver/test_force_select.py b/tests/unit/test_driver/test_force_select.py new file mode 100644 index 00000000..576f619c --- /dev/null +++ b/tests/unit/test_driver/test_force_select.py @@ -0,0 +1,78 @@ +"""Tests for the _should_force_select safety net.""" + +from sqlspec.adapters.bigquery import bigquery_statement_config +from sqlspec.core import SQL, ProcessedState +from sqlspec.driver._common import CommonDriverAttributesMixin + + +class _DummyDriver(CommonDriverAttributesMixin): + """Minimal driver to expose _should_force_select for testing.""" + + __slots__ = () + + def __init__(self) -> None: + super().__init__(connection=None, statement_config=bigquery_statement_config) + + +class _CursorWithStatementType: + """Cursor exposing a statement_type attribute.""" + + def __init__(self, statement_type: str | None) -> None: + self.statement_type = statement_type + self.description = None + + +class _CursorWithDescription: + """Cursor exposing a description attribute.""" + + def __init__(self, has_description: bool) -> None: + self.description = [("col",)] if has_description else None + self.statement_type = None + + +def _make_unknown_statement(sql_text: str = "select 1") -> "SQL": + stmt = SQL(sql_text) + stmt._processed_state = ProcessedState( # pylint: disable=protected-access + compiled_sql=sql_text, execution_parameters={}, operation_type="UNKNOWN" + ) + return stmt + + +def _make_select_statement(sql_text: str = "select 1") -> "SQL": + stmt = SQL(sql_text) + stmt._processed_state = ProcessedState( # pylint: disable=protected-access + compiled_sql=sql_text, execution_parameters={}, operation_type="SELECT" + ) + return stmt + + +def test_force_select_uses_statement_type_select() -> None: + driver = _DummyDriver() + stmt = _make_unknown_statement() + cursor = _CursorWithStatementType("SELECT") + + assert driver._should_force_select(stmt, cursor) is True # pylint: disable=protected-access + + +def test_force_select_uses_description_when_unknown() -> None: + driver = _DummyDriver() + stmt = _make_unknown_statement() + cursor = _CursorWithDescription(True) + + assert driver._should_force_select(stmt, cursor) is True # pylint: disable=protected-access + + +def test_force_select_false_when_no_metadata() -> None: + driver = _DummyDriver() + stmt = _make_unknown_statement() + cursor = _CursorWithDescription(False) + + assert driver._should_force_select(stmt, cursor) is False # pylint: disable=protected-access + + +def test_force_select_ignored_when_operation_known() -> None: + driver = _DummyDriver() + stmt = _make_select_statement() + cursor = _CursorWithDescription(True) + + assert driver._should_force_select(stmt, cursor) is False # pylint: disable=protected-access From ebed6d48ca9105893bddd3fb6b4a48931f318024 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Fri, 21 Nov 2025 17:53:55 +0000 Subject: [PATCH 2/3] fix: update access to protected attributes in test_force_select.py --- tests/unit/test_driver/test_force_select.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/unit/test_driver/test_force_select.py b/tests/unit/test_driver/test_force_select.py index 576f619c..7613f36c 100644 --- a/tests/unit/test_driver/test_force_select.py +++ b/tests/unit/test_driver/test_force_select.py @@ -1,8 +1,10 @@ """Tests for the _should_force_select safety net.""" +from typing import Any, cast + +from sqlspec import SQL, ProcessedState from sqlspec.adapters.bigquery import bigquery_statement_config -from sqlspec.core import SQL, ProcessedState -from sqlspec.driver._common import CommonDriverAttributesMixin +from sqlspec.driver import CommonDriverAttributesMixin class _DummyDriver(CommonDriverAttributesMixin): @@ -32,7 +34,7 @@ def __init__(self, has_description: bool) -> None: def _make_unknown_statement(sql_text: str = "select 1") -> "SQL": stmt = SQL(sql_text) - stmt._processed_state = ProcessedState( # pylint: disable=protected-access + cast("Any", stmt)._processed_state = ProcessedState( compiled_sql=sql_text, execution_parameters={}, operation_type="UNKNOWN" ) return stmt @@ -40,7 +42,7 @@ def _make_unknown_statement(sql_text: str = "select 1") -> "SQL": def _make_select_statement(sql_text: str = "select 1") -> "SQL": stmt = SQL(sql_text) - stmt._processed_state = ProcessedState( # pylint: disable=protected-access + cast("Any", stmt)._processed_state = ProcessedState( compiled_sql=sql_text, execution_parameters={}, operation_type="SELECT" ) return stmt @@ -51,7 +53,7 @@ def test_force_select_uses_statement_type_select() -> None: stmt = _make_unknown_statement() cursor = _CursorWithStatementType("SELECT") - assert driver._should_force_select(stmt, cursor) is True # pylint: disable=protected-access + assert cast("Any", driver)._should_force_select(stmt, cursor) is True def test_force_select_uses_description_when_unknown() -> None: @@ -59,7 +61,7 @@ def test_force_select_uses_description_when_unknown() -> None: stmt = _make_unknown_statement() cursor = _CursorWithDescription(True) - assert driver._should_force_select(stmt, cursor) is True # pylint: disable=protected-access + assert cast("Any", driver)._should_force_select(stmt, cursor) is True def test_force_select_false_when_no_metadata() -> None: @@ -67,7 +69,7 @@ def test_force_select_false_when_no_metadata() -> None: stmt = _make_unknown_statement() cursor = _CursorWithDescription(False) - assert driver._should_force_select(stmt, cursor) is False # pylint: disable=protected-access + assert cast("Any", driver)._should_force_select(stmt, cursor) is False def test_force_select_ignored_when_operation_known() -> None: @@ -75,4 +77,4 @@ def test_force_select_ignored_when_operation_known() -> None: stmt = _make_select_statement() cursor = _CursorWithDescription(True) - assert driver._should_force_select(stmt, cursor) is False # pylint: disable=protected-access + assert cast("Any", driver)._should_force_select(stmt, cursor) is False From b71d7461f08f051ebe467971432336e430ca7dda Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Fri, 21 Nov 2025 18:07:13 +0000 Subject: [PATCH 3/3] fix: add ProcessedState to core imports and __all__ export --- sqlspec/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sqlspec/__init__.py b/sqlspec/__init__.py index 626c30d9..d1b0ae29 100644 --- a/sqlspec/__init__.py +++ b/sqlspec/__init__.py @@ -28,6 +28,7 @@ ParameterProcessor, ParameterStyle, ParameterStyleConfig, + ProcessedState, SQLResult, StackOperation, StackResult, @@ -66,6 +67,7 @@ "ParameterStyle", "ParameterStyleConfig", "PoolT", + "ProcessedState", "QueryBuilder", "SQLFactory", "SQLFile",