From b2c55d517720ac4b025e7738541352a185d72e04 Mon Sep 17 00:00:00 2001
From: Cody Fincher
Date: Tue, 18 Nov 2025 05:14:22 +0000
Subject: [PATCH] fix(driver): validate FROM clause in COUNT query generation

Adds validation to _create_count_query to ensure SELECT statements have a
FROM clause before attempting to generate COUNT(*) queries.

Previously, malformed SQL like "SELECT * ORDER BY id" would raise a
confusing "empty SQL expression" error. Now provides clear error:
"SELECT statement missing FROM clause. COUNT queries require a FROM clause
to determine which table to count rows from."

Uses sqlglot AST inspection (expr.args.get("from")) to detect missing FROM
clauses before attempting to build COUNT queries. This prevents None from
being passed to sqlglot's .from_() method.

Test Plan:
- Added 12 comprehensive unit tests for edge cases
- Verified error messages are clear and actionable
- Confirmed no regression in valid SELECT...FROM queries
- All existing tests pass (4100+ tests)
---
Review note: the new test file was edited after export (HAVING assertion
added, redundant case checks simplified), so the abbreviated blob id on its
"index" line is stale. Regenerate with git format-patch after applying.

 sqlspec/driver/_common.py                     |   7 +
 .../test_count_query_edge_cases.py            | 234 ++++++++++++++++++
 2 files changed, 241 insertions(+)
 create mode 100644 tests/unit/test_driver/test_count_query_edge_cases.py

diff --git a/sqlspec/driver/_common.py b/sqlspec/driver/_common.py
index 0cb2e7b9..afdad1ed 100644
--- a/sqlspec/driver/_common.py
+++ b/sqlspec/driver/_common.py
@@ -985,6 +985,13 @@ def _create_count_query(self, original_sql: "SQL") -> "SQL":
         expr = original_sql.expression
 
         if isinstance(expr, exp.Select):
+            if not expr.args.get("from"):
+                msg = (
+                    "Cannot create COUNT query: SELECT statement missing FROM clause. "
+                    "COUNT queries require a FROM clause to determine which table to count rows from."
+                )
+                raise ImproperConfigurationError(msg)
+
             if expr.args.get("group"):
                 subquery = expr.subquery(alias="grouped_data")
                 count_expr = exp.select(exp.Count(this=exp.Star())).from_(subquery)
diff --git a/tests/unit/test_driver/test_count_query_edge_cases.py b/tests/unit/test_driver/test_count_query_edge_cases.py
new file mode 100644
index 00000000..68527710
--- /dev/null
+++ b/tests/unit/test_driver/test_count_query_edge_cases.py
@@ -0,0 +1,234 @@
+"""Tests for _create_count_query edge cases and validation.
+
+This module tests COUNT query generation validation, particularly for edge cases
+where SELECT statements are missing required clauses (FROM, etc.).
+"""
+
+# pyright: reportPrivateUsage=false
+
+import pytest
+
+from sqlspec.core import SQL, StatementConfig
+from sqlspec.driver._sync import SyncDriverAdapterBase
+from sqlspec.exceptions import ImproperConfigurationError
+
+
+class MockSyncDriver(SyncDriverAdapterBase):
+    """Mock driver for testing _create_count_query method."""
+
+    def __init__(self) -> None:
+        self.statement_config = StatementConfig()
+
+    @property
+    def connection(self):
+        return None
+
+    def _execute_statement(self, *args, **kwargs):
+        raise NotImplementedError("Mock driver - not implemented")
+
+    def _execute_many(self, *args, **kwargs):
+        raise NotImplementedError("Mock driver - not implemented")
+
+    def with_cursor(self, *args, **kwargs):
+        raise NotImplementedError("Mock driver - not implemented")
+
+    def handle_database_exceptions(self, *args, **kwargs):
+        raise NotImplementedError("Mock driver - not implemented")
+
+    def create_connection(self, *args, **kwargs):
+        raise NotImplementedError("Mock driver - not implemented")
+
+    def close_connection(self, *args, **kwargs):
+        raise NotImplementedError("Mock driver - not implemented")
+
+    def begin(self, *args, **kwargs):
+        raise NotImplementedError("Mock driver - not implemented")
+
+    def commit(self, *args, **kwargs):
+        raise NotImplementedError("Mock driver - not implemented")
+
+    def rollback(self, *args, **kwargs):
+        raise NotImplementedError("Mock driver - not implemented")
+
+    def _try_special_handling(self, *args, **kwargs):
+        raise NotImplementedError("Mock driver - not implemented")
+
+    @property
+    def data_dictionary(self):
+        raise NotImplementedError("Mock driver - not implemented")
+
+
+class TestCountQueryValidation:
+    """Test COUNT query generation validation."""
+
+    def test_count_query_missing_from_clause_with_order_by(self) -> None:
+        """Test COUNT query fails with clear error when FROM clause missing (ORDER BY only).
+
+        This is the reported bug scenario from upstream.
+        """
+        driver = MockSyncDriver()
+        sql = driver.prepare_statement(SQL("SELECT * ORDER BY id"), statement_config=driver.statement_config)
+        sql.compile()  # Parse the SQL to populate expression
+
+        with pytest.raises(ImproperConfigurationError, match="missing FROM clause"):
+            driver._create_count_query(sql)
+
+    def test_count_query_missing_from_clause_with_where(self) -> None:
+        """Test COUNT query fails when only WHERE clause present (no FROM)."""
+        driver = MockSyncDriver()
+        sql = driver.prepare_statement(SQL("SELECT * WHERE active = true"), statement_config=driver.statement_config)
+        sql.compile()
+
+        with pytest.raises(ImproperConfigurationError, match="missing FROM clause"):
+            driver._create_count_query(sql)
+
+    def test_count_query_select_star_no_from(self) -> None:
+        """Test COUNT query fails for SELECT * without FROM clause."""
+        driver = MockSyncDriver()
+        sql = driver.prepare_statement(SQL("SELECT *"), statement_config=driver.statement_config)
+        sql.compile()
+
+        with pytest.raises(ImproperConfigurationError, match="missing FROM clause"):
+            driver._create_count_query(sql)
+
+    def test_count_query_select_columns_no_from(self) -> None:
+        """Test COUNT query fails for SELECT columns without FROM clause."""
+        driver = MockSyncDriver()
+        sql = driver.prepare_statement(SQL("SELECT id, name"), statement_config=driver.statement_config)
+        sql.compile()
+
+        with pytest.raises(ImproperConfigurationError, match="missing FROM clause"):
+            driver._create_count_query(sql)
+
+    def test_count_query_valid_select_with_from(self) -> None:
+        """Test COUNT query succeeds with valid SELECT...FROM."""
+        driver = MockSyncDriver()
+        sql = driver.prepare_statement(SQL("SELECT * FROM users ORDER BY id"), statement_config=driver.statement_config)
+        sql.compile()
+
+        count_sql = driver._create_count_query(sql)
+
+        count_str = str(count_sql)
+        assert "COUNT(*)" in count_str.upper()
+        assert "FROM USERS" in count_str.upper()
+        assert "ORDER BY" not in count_str.upper()
+
+    def test_count_query_with_where_and_from(self) -> None:
+        """Test COUNT query preserves WHERE clause when FROM present."""
+        driver = MockSyncDriver()
+        sql = driver.prepare_statement(
+            SQL("SELECT * FROM users WHERE active = true ORDER BY id"), statement_config=driver.statement_config
+        )
+        sql.compile()
+
+        count_sql = driver._create_count_query(sql)
+
+        count_str = str(count_sql)
+        assert "COUNT(*)" in count_str.upper()
+        assert "FROM USERS" in count_str.upper()
+        assert "WHERE" in count_str.upper()
+        assert "ACTIVE" in count_str.upper()
+        assert "ORDER BY" not in count_str.upper()
+
+    def test_count_query_with_group_by(self) -> None:
+        """Test COUNT query wraps grouped query in subquery."""
+        driver = MockSyncDriver()
+        sql = driver.prepare_statement(
+            SQL("SELECT status, COUNT(*) FROM users GROUP BY status"), statement_config=driver.statement_config
+        )
+        sql.compile()
+
+        count_sql = driver._create_count_query(sql)
+
+        count_str = str(count_sql)
+        assert "COUNT(*)" in count_str.upper()
+        assert "grouped_data" in count_str.lower()
+
+    def test_count_query_removes_limit_offset(self) -> None:
+        """Test COUNT query removes LIMIT and OFFSET clauses."""
+        driver = MockSyncDriver()
+        sql = driver.prepare_statement(
+            SQL("SELECT * FROM users ORDER BY id LIMIT 10 OFFSET 20"), statement_config=driver.statement_config
+        )
+        sql.compile()
+
+        count_sql = driver._create_count_query(sql)
+
+        count_str = str(count_sql)
+        assert "LIMIT" not in count_str.upper()
+        assert "OFFSET" not in count_str.upper()
+        assert "ORDER BY" not in count_str.upper()
+
+    def test_count_query_with_having(self) -> None:
+        """Test COUNT query preserves HAVING clause."""
+        driver = MockSyncDriver()
+        sql = driver.prepare_statement(
+            SQL("SELECT status, COUNT(*) as cnt FROM users GROUP BY status HAVING cnt > 5"),
+            statement_config=driver.statement_config,
+        )
+        sql.compile()
+
+        count_sql = driver._create_count_query(sql)
+
+        count_str = str(count_sql)
+        assert "COUNT(*)" in count_str.upper()
+        assert "HAVING" in count_str.upper()
+
+
+class TestCountQueryEdgeCases:
+    """Test COUNT query edge cases that previously caused issues."""
+
+    def test_complex_select_with_join(self) -> None:
+        """Test complex SELECT with JOIN generates correct COUNT."""
+        driver = MockSyncDriver()
+        sql = driver.prepare_statement(
+            SQL("""
+                SELECT u.id, u.name, o.total
+                FROM users u
+                JOIN orders o ON u.id = o.user_id
+                WHERE u.active = true
+                  AND o.total > 100
+                ORDER BY o.total DESC
+                LIMIT 10
+            """),
+            statement_config=driver.statement_config,
+        )
+        sql.compile()
+
+        count_sql = driver._create_count_query(sql)
+
+        count_str = str(count_sql)
+        assert "COUNT(*)" in count_str.upper()
+        assert "FROM USERS" in count_str.upper()
+        assert "ORDER BY" not in count_str.upper()
+        assert "LIMIT" not in count_str.upper()
+
+    def test_select_with_subquery_in_from(self) -> None:
+        """Test SELECT with subquery in FROM clause."""
+        driver = MockSyncDriver()
+        sql = driver.prepare_statement(
+            SQL("""
+                SELECT t.id
+                FROM (SELECT id FROM users WHERE active = true) t
+                ORDER BY t.id
+            """),
+            statement_config=driver.statement_config,
+        )
+        sql.compile()
+
+        count_sql = driver._create_count_query(sql)
+
+        count_str = str(count_sql)
+        assert "COUNT(*)" in count_str.upper()
+
+    def test_error_message_clarity(self) -> None:
+        """Test that error message explains why FROM clause is required."""
+        driver = MockSyncDriver()
+        sql = driver.prepare_statement(SQL("SELECT * ORDER BY id"), statement_config=driver.statement_config)
+        sql.compile()
+
+        with pytest.raises(
+            ImproperConfigurationError,
+            match="COUNT queries require a FROM clause to determine which table to count rows from",
+        ):
+            driver._create_count_query(sql)