Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sqlspec/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
ParameterProcessor,
ParameterStyle,
ParameterStyleConfig,
ProcessedState,
SQLResult,
StackOperation,
StackResult,
Expand Down Expand Up @@ -66,6 +67,7 @@
"ParameterStyle",
"ParameterStyleConfig",
"PoolT",
"ProcessedState",
"QueryBuilder",
"SQLFactory",
"SQLFile",
Expand Down
4 changes: 3 additions & 1 deletion sqlspec/adapters/adbc/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,9 @@ def _execute_statement(self, cursor: "Cursor", statement: SQL) -> "ExecutionResu
self._handle_postgres_rollback(cursor)
raise

if statement.returns_rows():
is_select_like = statement.returns_rows() or self._should_force_select(statement, cursor)

if is_select_like:
fetched_data = cursor.fetchall()
column_names = [col[0] for col in cursor.description or []]

Expand Down
9 changes: 6 additions & 3 deletions sqlspec/adapters/bigquery/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,9 +686,13 @@ def _execute_statement(self, cursor: Any, statement: "SQL") -> ExecutionResult:
"""
sql, parameters = self._get_compiled_sql(statement, self.statement_config)
cursor.job = self._run_query_job(sql, parameters, connection=cursor)
job_result = cursor.job.result(job_retry=self._job_retry)
statement_type = str(cursor.job.statement_type or "").upper()
is_select_like = (
statement.returns_rows() or statement_type == "SELECT" or self._should_force_select(statement, cursor)
)

if statement.returns_rows():
job_result = cursor.job.result(job_retry=self._job_retry)
if is_select_like:
rows_list = self._rows_to_results(iter(job_result))
column_names = [field.name for field in cursor.job.schema] if cursor.job.schema else []

Expand All @@ -700,7 +704,6 @@ def _execute_statement(self, cursor: Any, statement: "SQL") -> ExecutionResult:
is_select_result=True,
)

cursor.job.result(job_retry=self._job_retry)
affected_rows = cursor.job.num_dml_affected_rows or 0
return self.create_execution_result(cursor, rowcount_override=affected_rows)

Expand Down
4 changes: 3 additions & 1 deletion sqlspec/adapters/duckdb/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,9 @@ def _execute_statement(self, cursor: Any, statement: SQL) -> "ExecutionResult":
sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
cursor.execute(sql, prepared_parameters or ())

if statement.returns_rows():
is_select_like = statement.returns_rows() or self._should_force_select(statement, cursor)

if is_select_like:
fetched_data = cursor.fetchall()
column_names = [col[0] for col in cursor.description or []]

Expand Down
4 changes: 3 additions & 1 deletion sqlspec/adapters/oracledb/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -1243,7 +1243,9 @@ async def _execute_statement(self, cursor: Any, statement: "SQL") -> "ExecutionR
await cursor.execute(sql, prepared_parameters or {})

# SELECT result processing for Oracle
if statement.returns_rows():
is_select_like = statement.returns_rows() or self._should_force_select(statement, cursor)

if is_select_like:
fetched_data = await cursor.fetchall()
column_names = [col[0] for col in cursor.description or []]
column_names = _normalize_column_names(column_names, self.driver_features)
Expand Down
8 changes: 7 additions & 1 deletion sqlspec/core/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@

OPERATION_TYPE_MAP: "dict[type[exp.Expression], OperationType]" = {
exp.Select: "SELECT",
exp.Union: "SELECT",
exp.Except: "SELECT",
exp.Intersect: "SELECT",
exp.With: "SELECT",
exp.Insert: "INSERT",
exp.Update: "UPDATE",
exp.Delete: "DELETE",
Expand Down Expand Up @@ -554,7 +558,9 @@ def _build_operation_profile(
modifies_rows = False

expr = expression
if isinstance(expr, (exp.Select, exp.Values, exp.Table, exp.TableSample, exp.With)):
if isinstance(
expr, (exp.Select, exp.Union, exp.Except, exp.Intersect, exp.Values, exp.Table, exp.TableSample, exp.With)
):
returns_rows = True
elif isinstance(expr, (exp.Insert, exp.Update, exp.Delete, exp.Merge)):
modifies_rows = True
Expand Down
26 changes: 26 additions & 0 deletions sqlspec/driver/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,32 @@ def build_statement_result(self, statement: "SQL", execution_result: ExecutionRe
metadata=execution_result.special_data or {"status_message": "OK"},
)

def _should_force_select(self, statement: "SQL", cursor: Any) -> bool:
"""Determine if a statement with unknown type should be treated as SELECT.

Uses driver metadata (statement_type, description/schema) as a safety net when
the compiler cannot classify the operation. This remains conservative by only
triggering when the operation type is "UNKNOWN".

Args:
statement: SQL statement being executed.
cursor: Database cursor/job object that may expose metadata.

Returns:
True when cursor metadata indicates a row-returning operation despite an
unknown operation type; otherwise False.
"""

if statement.operation_type != "UNKNOWN":
return False

statement_type = getattr(cursor, "statement_type", None)
if isinstance(statement_type, str) and statement_type.upper() == "SELECT":
return True

description = getattr(cursor, "description", None)
return bool(description)

def prepare_statement(
self,
statement: "Statement | QueryBuilder",
Expand Down
21 changes: 21 additions & 0 deletions tests/unit/test_core/test_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,27 @@ def test_sql_returns_rows_detection() -> None:
assert show_stmt.returns_rows() is True


@pytest.mark.parametrize(
"sql_text",
[
"SELECT 1 UNION ALL SELECT 2",
"SELECT 1 EXCEPT SELECT 2",
"SELECT 1 INTERSECT SELECT 1",
"WITH cte AS (SELECT 1 AS id) SELECT * FROM cte",
],
ids=["union", "except", "intersect", "cte_select"],
)
def test_sql_set_and_cte_operations_detect_as_select(sql_text: str) -> None:
"""Ensure set operations and CTE queries are detected as SELECT and return rows."""

stmt = SQL(sql_text)
stmt.compile()

assert stmt.operation_type == "SELECT"
assert stmt.returns_rows() is True
assert stmt.is_modifying_operation() is False


def test_sql_slots_prevent_new_attributes() -> None:
"""Test SQL __slots__ prevent adding new attributes."""
stmt = SQL("SELECT * FROM users")
Expand Down
80 changes: 80 additions & 0 deletions tests/unit/test_driver/test_force_select.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""Tests for the _should_force_select safety net."""

from typing import Any, cast

from sqlspec import SQL, ProcessedState
from sqlspec.adapters.bigquery import bigquery_statement_config
from sqlspec.driver import CommonDriverAttributesMixin


class _DummyDriver(CommonDriverAttributesMixin):
"""Minimal driver to expose _should_force_select for testing."""

__slots__ = ()

def __init__(self) -> None:
super().__init__(connection=None, statement_config=bigquery_statement_config)


class _CursorWithStatementType:
"""Cursor exposing a statement_type attribute."""

def __init__(self, statement_type: str | None) -> None:
self.statement_type = statement_type
self.description = None


class _CursorWithDescription:
"""Cursor exposing a description attribute."""

def __init__(self, has_description: bool) -> None:
self.description = [("col",)] if has_description else None
self.statement_type = None


def _make_unknown_statement(sql_text: str = "select 1") -> "SQL":
stmt = SQL(sql_text)
cast("Any", stmt)._processed_state = ProcessedState(
compiled_sql=sql_text, execution_parameters={}, operation_type="UNKNOWN"
)
return stmt


def _make_select_statement(sql_text: str = "select 1") -> "SQL":
stmt = SQL(sql_text)
cast("Any", stmt)._processed_state = ProcessedState(
compiled_sql=sql_text, execution_parameters={}, operation_type="SELECT"
)
return stmt


def test_force_select_uses_statement_type_select() -> None:
driver = _DummyDriver()
stmt = _make_unknown_statement()
cursor = _CursorWithStatementType("SELECT")

assert cast("Any", driver)._should_force_select(stmt, cursor) is True


def test_force_select_uses_description_when_unknown() -> None:
driver = _DummyDriver()
stmt = _make_unknown_statement()
cursor = _CursorWithDescription(True)

assert cast("Any", driver)._should_force_select(stmt, cursor) is True


def test_force_select_false_when_no_metadata() -> None:
driver = _DummyDriver()
stmt = _make_unknown_statement()
cursor = _CursorWithDescription(False)

assert cast("Any", driver)._should_force_select(stmt, cursor) is False


def test_force_select_ignored_when_operation_known() -> None:
driver = _DummyDriver()
stmt = _make_select_statement()
cursor = _CursorWithDescription(True)

assert cast("Any", driver)._should_force_select(stmt, cursor) is False