From 32bde41319b09fafbc7e8a04cb76b41f1734b86d Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Wed, 5 Nov 2025 20:38:16 +0000 Subject: [PATCH 1/5] feat(adk): Implement decimal coercion in Oracle ADK store deserialization tests --- sqlspec/adapters/oracledb/adk/store.py | 182 +++++++++--------- .../test_oracledb/test_adk_store.py | 49 +++++ 2 files changed, 140 insertions(+), 91 deletions(-) create mode 100644 tests/unit/test_adapters/test_oracledb/test_adk_store.py diff --git a/sqlspec/adapters/oracledb/adk/store.py b/sqlspec/adapters/oracledb/adk/store.py index c2156e63..fe60ed0c 100644 --- a/sqlspec/adapters/oracledb/adk/store.py +++ b/sqlspec/adapters/oracledb/adk/store.py @@ -1,11 +1,17 @@ """Oracle ADK store for Google Agent Development Kit session/event storage.""" +from decimal import Decimal from enum import Enum -from typing import TYPE_CHECKING, Any, Final +from typing import TYPE_CHECKING, Any, Final, cast import oracledb from sqlspec import SQL +from sqlspec.adapters.oracledb.data_dictionary import ( + OracleAsyncDataDictionary, + OracleSyncDataDictionary, + OracleVersionInfo, +) from sqlspec.extensions.adk import BaseAsyncADKStore, BaseSyncADKStore, EventRecord, SessionRecord from sqlspec.utils.logging import get_logger from sqlspec.utils.serializers import from_json, to_json @@ -33,6 +39,41 @@ class JSONStorageType(str, Enum): BLOB_PLAIN = "blob_plain" +def _coerce_decimal_values(value: Any) -> Any: + if isinstance(value, Decimal): + return float(value) + if isinstance(value, dict): + return {key: _coerce_decimal_values(val) for key, val in value.items()} + if isinstance(value, list): + return [_coerce_decimal_values(item) for item in value] + if isinstance(value, tuple): + return tuple(_coerce_decimal_values(item) for item in value) + if isinstance(value, set): + return {_coerce_decimal_values(item) for item in value} + if isinstance(value, frozenset): + return frozenset(_coerce_decimal_values(item) for item in value) + return value + + +def _storage_type_from_version(version_info: "OracleVersionInfo | None") -> JSONStorageType: + """Determine JSON storage type based on Oracle version metadata.""" + + if version_info and version_info.supports_native_json(): + logger.info("Detected Oracle %s with compatible >= 20, using JSON_NATIVE", version_info) + return JSONStorageType.JSON_NATIVE + + if version_info and version_info.supports_json_blob(): + logger.info("Detected Oracle %s, using BLOB_JSON (recommended)", version_info) + return JSONStorageType.BLOB_JSON + + if version_info: + logger.info("Detected Oracle %s (pre-12c), using BLOB_PLAIN", version_info) + return JSONStorageType.BLOB_PLAIN + + logger.warning("Oracle version could not be detected; defaulting to BLOB_JSON storage") + return JSONStorageType.BLOB_JSON + + def _to_oracle_bool(value: "bool | None") -> "int | None": """Convert Python boolean to Oracle NUMBER(1). @@ -103,7 +144,7 @@ class OracleAsyncADKStore(BaseAsyncADKStore["OracleAsyncConfig"]): - Configuration is read from config.extension_config["adk"] """ - __slots__ = ("_in_memory", "_json_storage_type") + __slots__ = ("_in_memory", "_json_storage_type", "_oracle_version_info") def __init__(self, config: "OracleAsyncConfig") -> None: """Initialize Oracle ADK store. @@ -120,6 +161,7 @@ def __init__(self, config: "OracleAsyncConfig") -> None: """ super().__init__(config) self._json_storage_type: JSONStorageType | None = None + self._oracle_version_info: OracleVersionInfo | None = None adk_config = config.extension_config.get("adk", {}) self._in_memory: bool = bool(adk_config.get("in_memory", False)) @@ -160,44 +202,22 @@ async def _detect_json_storage_type(self) -> JSONStorageType: if self._json_storage_type is not None: return self._json_storage_type - async with self._config.provide_connection() as conn: - cursor = conn.cursor() - await cursor.execute( - """ - SELECT version FROM product_component_version - WHERE product LIKE 'Oracle%Database%' - """ - ) - row = await cursor.fetchone() - - if row is None: - logger.warning("Could not detect Oracle version, defaulting to BLOB_JSON") - self._json_storage_type = JSONStorageType.BLOB_JSON - return self._json_storage_type - - version_str = str(row[0]) - version_parts = version_str.split(".") - major_version = int(version_parts[0]) if version_parts else 0 - - if major_version >= ORACLE_MIN_JSON_NATIVE_VERSION: - await cursor.execute("SELECT value FROM v$parameter WHERE name = 'compatible'") - compatible_row = await cursor.fetchone() - if compatible_row: - compatible_parts = str(compatible_row[0]).split(".") - compatible_major = int(compatible_parts[0]) if compatible_parts else 0 - if compatible_major >= ORACLE_MIN_JSON_NATIVE_COMPATIBLE: - logger.info("Detected Oracle %s with compatible >= 20, using JSON_NATIVE", version_str) - self._json_storage_type = JSONStorageType.JSON_NATIVE - return self._json_storage_type - - if major_version >= ORACLE_MIN_JSON_BLOB_VERSION: - logger.info("Detected Oracle %s, using BLOB_JSON (recommended)", version_str) - self._json_storage_type = JSONStorageType.BLOB_JSON - return self._json_storage_type - - logger.info("Detected Oracle %s (pre-12c), using BLOB_PLAIN", version_str) - self._json_storage_type = JSONStorageType.BLOB_PLAIN - return self._json_storage_type + version_info = await self._get_version_info() + self._json_storage_type = _storage_type_from_version(version_info) + return self._json_storage_type + + async def _get_version_info(self) -> "OracleVersionInfo | None": + """Return cached Oracle version info, querying via data dictionary if needed.""" + + if self._oracle_version_info is not None: + return self._oracle_version_info + + async with self._config.provide_session() as driver: + data_dictionary = OracleAsyncDataDictionary() + self._oracle_version_info = await data_dictionary.get_version(driver) + if self._oracle_version_info is None: + logger.warning("Could not detect Oracle version, defaulting to BLOB_JSON storage") + return self._oracle_version_info async def _serialize_state(self, state: "dict[str, Any]") -> "str | bytes": """Serialize state dictionary to appropriate format based on storage type. @@ -232,7 +252,7 @@ async def _deserialize_state(self, data: Any) -> "dict[str, Any]": data = await data.read() if isinstance(data, dict): - return data + return cast("dict[str, Any]", _coerce_decimal_values(data)) if isinstance(data, bytes): return from_json(data) # type: ignore[no-any-return] @@ -280,7 +300,7 @@ async def _deserialize_json_field(self, data: Any) -> "dict[str, Any] | None": data = await data.read() if isinstance(data, dict): - return data + return cast("dict[str, Any]", _coerce_decimal_values(data)) if isinstance(data, bytes): return from_json(data) # type: ignore[no-any-return] @@ -490,7 +510,7 @@ async def create_tables(self) -> None: Uses version-appropriate table schema. """ storage_type = await self._detect_json_storage_type() - logger.info("Creating ADK tables with storage type: %s", storage_type) + logger.debug("Creating ADK tables with storage type: %s", storage_type) async with self._config.provide_session() as driver: await driver.execute_script(self._get_create_sessions_table_sql_for_type(storage_type)) @@ -561,16 +581,17 @@ async def get_session(self, session_id: str) -> "SessionRecord | None": State is deserialized using version-appropriate format. """ - sql = f""" - SELECT id, app_name, user_id, state, create_time, update_time - FROM {self._session_table} - WHERE id = :id - """ - try: async with self._config.provide_connection() as conn: cursor = conn.cursor() - await cursor.execute(sql, {"id": session_id}) + await cursor.execute( + f""" + SELECT id, app_name, user_id, state, create_time, update_time + FROM {self._session_table} + WHERE id = :id + """, + {"id": session_id}, + ) row = await cursor.fetchone() if row is None: @@ -881,7 +902,7 @@ class OracleSyncADKStore(BaseSyncADKStore["OracleSyncConfig"]): - Configuration is read from config.extension_config["adk"] """ - __slots__ = ("_in_memory", "_json_storage_type") + __slots__ = ("_in_memory", "_json_storage_type", "_oracle_version_info") def __init__(self, config: "OracleSyncConfig") -> None: """Initialize Oracle synchronous ADK store. @@ -898,6 +919,7 @@ def __init__(self, config: "OracleSyncConfig") -> None: """ super().__init__(config) self._json_storage_type: JSONStorageType | None = None + self._oracle_version_info: OracleVersionInfo | None = None adk_config = config.extension_config.get("adk", {}) self._in_memory: bool = bool(adk_config.get("in_memory", False)) @@ -938,44 +960,22 @@ def _detect_json_storage_type(self) -> JSONStorageType: if self._json_storage_type is not None: return self._json_storage_type - with self._config.provide_connection() as conn: - cursor = conn.cursor() - cursor.execute( - """ - SELECT version FROM product_component_version - WHERE product LIKE 'Oracle%Database%' - """ - ) - row = cursor.fetchone() - - if row is None: - logger.warning("Could not detect Oracle version, defaulting to BLOB_JSON") - self._json_storage_type = JSONStorageType.BLOB_JSON - return self._json_storage_type - - version_str = str(row[0]) - version_parts = version_str.split(".") - major_version = int(version_parts[0]) if version_parts else 0 - - if major_version >= ORACLE_MIN_JSON_NATIVE_VERSION: - cursor.execute("SELECT value FROM v$parameter WHERE name = 'compatible'") - compatible_row = cursor.fetchone() - if compatible_row: - compatible_parts = str(compatible_row[0]).split(".") - compatible_major = int(compatible_parts[0]) if compatible_parts else 0 - if compatible_major >= ORACLE_MIN_JSON_NATIVE_COMPATIBLE: - logger.info("Detected Oracle %s with compatible >= 20, using JSON_NATIVE", version_str) - self._json_storage_type = JSONStorageType.JSON_NATIVE - return self._json_storage_type - - if major_version >= ORACLE_MIN_JSON_BLOB_VERSION: - logger.info("Detected Oracle %s, using BLOB_JSON (recommended)", version_str) - self._json_storage_type = JSONStorageType.BLOB_JSON - return self._json_storage_type - - logger.info("Detected Oracle %s (pre-12c), using BLOB_PLAIN", version_str) - self._json_storage_type = JSONStorageType.BLOB_PLAIN - return self._json_storage_type + version_info = self._get_version_info() + self._json_storage_type = _storage_type_from_version(version_info) + return self._json_storage_type + + def _get_version_info(self) -> "OracleVersionInfo | None": + """Return cached Oracle version info for sync store.""" + + if self._oracle_version_info is not None: + return self._oracle_version_info + + with self._config.provide_session() as driver: + data_dictionary = OracleSyncDataDictionary() + self._oracle_version_info = data_dictionary.get_version(driver) + if self._oracle_version_info is None: + logger.warning("Could not detect Oracle version, defaulting to BLOB_JSON storage") + return self._oracle_version_info def _serialize_state(self, state: "dict[str, Any]") -> "str | bytes": """Serialize state dictionary to appropriate format based on storage type. @@ -1010,7 +1010,7 @@ def _deserialize_state(self, data: Any) -> "dict[str, Any]": data = data.read() if isinstance(data, dict): - return data + return cast("dict[str, Any]", _coerce_decimal_values(data)) if isinstance(data, bytes): return from_json(data) # type: ignore[no-any-return] @@ -1058,7 +1058,7 @@ def _deserialize_json_field(self, data: Any) -> "dict[str, Any] | None": data = data.read() if isinstance(data, dict): - return data + return cast("dict[str, Any]", _coerce_decimal_values(data)) if isinstance(data, bytes): return from_json(data) # type: ignore[no-any-return] diff --git a/tests/unit/test_adapters/test_oracledb/test_adk_store.py b/tests/unit/test_adapters/test_oracledb/test_adk_store.py new file mode 100644 index 00000000..c206ec1c --- /dev/null +++ b/tests/unit/test_adapters/test_oracledb/test_adk_store.py @@ -0,0 +1,49 @@ +"""Tests for Oracle ADK store Decimal coercion.""" + +from decimal import Decimal + +import pytest + +from sqlspec.adapters.oracledb.adk.store import OracleAsyncADKStore, OracleSyncADKStore + + +@pytest.mark.asyncio +async def test_oracle_async_adk_store_deserialize_dict_coerces_decimal() -> None: + store = OracleAsyncADKStore.__new__(OracleAsyncADKStore) # type: ignore[call-arg] + + payload = {"value": Decimal("1.25"), "nested": {"score": Decimal("0.5")}} + + result = await store._deserialize_json_field(payload) # type: ignore[attr-defined] + + assert result == {"value": 1.25, "nested": {"score": 0.5}} + + +@pytest.mark.asyncio +async def test_oracle_async_adk_store_deserialize_state_dict_coerces_decimal() -> None: + store = OracleAsyncADKStore.__new__(OracleAsyncADKStore) # type: ignore[call-arg] + + payload = {"state": Decimal("2.0")} + + result = await store._deserialize_state(payload) # type: ignore[attr-defined] + + assert result == {"state": 2.0} + + +def test_oracle_sync_adk_store_deserialize_dict_coerces_decimal() -> None: + store = OracleSyncADKStore.__new__(OracleSyncADKStore) # type: ignore[call-arg] + + payload = {"value": Decimal("3.14"), "items": [Decimal("1.0"), Decimal("2.0")]} + + result = store._deserialize_json_field(payload) # type: ignore[attr-defined] + + assert result == {"value": 3.14, "items": [1.0, 2.0]} + + +def test_oracle_sync_adk_store_deserialize_state_dict_coerces_decimal() -> None: + store = OracleSyncADKStore.__new__(OracleSyncADKStore) # type: ignore[call-arg] + + payload = {"state": Decimal("5.0")} + + result = store._deserialize_state(payload) # type: ignore[attr-defined] + + assert result == {"state": 5.0} From 106fbe21ee2d0ef345d4876c0c3d91ebaa783de0 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Wed, 5 Nov 2025 21:40:18 +0000 Subject: [PATCH 2/5] feat: Update version info retrieval to use Oracle data dictionary in async and sync stores --- sqlspec/adapters/oracledb/adk/store.py | 16 ++++++++++------ sqlspec/core/parameters/_validator.py | 10 +++++----- tests/unit/test_core/test_parameters.py | 2 ++ 3 files changed, 17 insertions(+), 11 deletions(-) diff --git a/sqlspec/adapters/oracledb/adk/store.py b/sqlspec/adapters/oracledb/adk/store.py index fe60ed0c..32a0a082 100644 --- a/sqlspec/adapters/oracledb/adk/store.py +++ b/sqlspec/adapters/oracledb/adk/store.py @@ -207,16 +207,18 @@ async def _detect_json_storage_type(self) -> JSONStorageType: return self._json_storage_type async def _get_version_info(self) -> "OracleVersionInfo | None": - """Return cached Oracle version info, querying via data dictionary if needed.""" + """Return cached Oracle version info using Oracle data dictionary.""" if self._oracle_version_info is not None: return self._oracle_version_info async with self._config.provide_session() as driver: - data_dictionary = OracleAsyncDataDictionary() - self._oracle_version_info = await data_dictionary.get_version(driver) + dictionary = OracleAsyncDataDictionary() + self._oracle_version_info = await dictionary.get_version(driver) + if self._oracle_version_info is None: logger.warning("Could not detect Oracle version, defaulting to BLOB_JSON storage") + return self._oracle_version_info async def _serialize_state(self, state: "dict[str, Any]") -> "str | bytes": @@ -965,16 +967,18 @@ def _detect_json_storage_type(self) -> JSONStorageType: return self._json_storage_type def _get_version_info(self) -> "OracleVersionInfo | None": - """Return cached Oracle version info for sync store.""" + """Return cached Oracle version info using Oracle data dictionary.""" if self._oracle_version_info is not None: return self._oracle_version_info with self._config.provide_session() as driver: - data_dictionary = OracleSyncDataDictionary() - self._oracle_version_info = data_dictionary.get_version(driver) + dictionary = OracleSyncDataDictionary() + self._oracle_version_info = dictionary.get_version(driver) + if self._oracle_version_info is None: logger.warning("Could not detect Oracle version, defaulting to BLOB_JSON storage") + return self._oracle_version_info def _serialize_state(self, state: "dict[str, Any]") -> "str | bytes": diff --git a/sqlspec/core/parameters/_validator.py b/sqlspec/core/parameters/_validator.py index af4c069e..6fa131d5 100644 --- a/sqlspec/core/parameters/_validator.py +++ b/sqlspec/core/parameters/_validator.py @@ -20,11 +20,11 @@ (?P::(?P\w+)) | (?P%\((?P\w+)\)s) | (?P%s) | - (?P:(?P\d+)) | - (?P:(?P\w+)) | - (?P@(?P\w+)) | - (?P\$(?P\d+)) | - (?P\$(?P\w+)) | + (?P(?\d+)) | + (?P(?\w+)) | + (?P(?\w+)) | + (?P(?\d+)) | + (?P(?\w+)) | (?P\?) """, re.VERBOSE | re.IGNORECASE | re.MULTILINE | re.DOTALL, diff --git a/tests/unit/test_core/test_parameters.py b/tests/unit/test_core/test_parameters.py index 80f96aa6..1248de4b 100644 --- a/tests/unit/test_core/test_parameters.py +++ b/tests/unit/test_core/test_parameters.py @@ -669,6 +669,8 @@ def test_extract_parameters( ("SELECT * FROM json WHERE data ?| array['key']", True), ("SELECT * FROM json WHERE data ?& array['key']", True), ("SELECT * FROM users WHERE id::int = 5", False), + ("SELECT * FROM v$version", True), + ('SELECT * FROM "V$VERSION"', True), ], ) def test_extract_parameters_ignores_special_cases( From 6cb7041e859103822a6becceb08edf6034ade82d Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Wed, 5 Nov 2025 21:42:46 +0000 Subject: [PATCH 3/5] feat: Enhance PARAMETER_REGEX to support SQL Server global variables and add performance tests --- sqlspec/core/parameters/_validator.py | 2 + .../test_parameter_regex_performance.py | 206 ++++++++++++++++++ 2 files changed, 208 insertions(+) create mode 100644 tests/unit/test_core/test_parameter_regex_performance.py diff --git a/sqlspec/core/parameters/_validator.py b/sqlspec/core/parameters/_validator.py index 6fa131d5..7bc12ebd 100644 --- a/sqlspec/core/parameters/_validator.py +++ b/sqlspec/core/parameters/_validator.py @@ -18,6 +18,7 @@ (?P/\*(?:[^*]|\*(?!/))*\*/) | (?P\?\?|\?\||\?&) | (?P::(?P\w+)) | + (?P@@(?P\w+)) | (?P%\((?P\w+)\)s) | (?P%s) | (?P(?\d+)) | @@ -85,6 +86,7 @@ def extract_parameters(self, sql: str) -> "list[ParameterInfo]": "block_comment", "pg_q_operator", "pg_cast", + "sql_server_global", ) for match in PARAMETER_REGEX.finditer(sql): diff --git a/tests/unit/test_core/test_parameter_regex_performance.py b/tests/unit/test_core/test_parameter_regex_performance.py new file mode 100644 index 00000000..67cbc76b --- /dev/null +++ b/tests/unit/test_core/test_parameter_regex_performance.py @@ -0,0 +1,206 @@ +"""Performance and edge case tests for PARAMETER_REGEX. + +Tests regex efficiency, order dependency, and edge cases across all SQL dialects. +""" + +import time + +from sqlspec.core.parameters import ParameterValidator + + +class TestParameterRegexPerformance: + """Test parameter regex performance and correctness.""" + + def test_oracle_system_views_not_detected_as_parameters(self) -> None: + """Verify Oracle system views with $ are not detected as parameters.""" + validator = ParameterValidator() + + # Oracle system views should NOT be detected as parameters + test_cases = [ + ("SELECT * FROM v$version", []), + ("SELECT * FROM v$session WHERE sid = :sid", ["sid"]), + ("SELECT * FROM v$database, v$instance", []), + ("SELECT banner FROM v$version WHERE banner LIKE :pattern", ["pattern"]), + ("SELECT * FROM gv$session WHERE inst_id = :inst", ["inst"]), + ] + + for sql, expected_params in test_cases: + params = validator.extract_parameters(sql) + param_names = [p.name for p in params if p.name is not None] + assert param_names == expected_params, f"Failed for: {sql}" + + def test_sql_server_global_variables_not_detected(self) -> None: + """Verify SQL Server @@variables are not detected as parameters.""" + validator = ParameterValidator() + + test_cases = [ + ("SELECT @@VERSION", []), + ("SELECT @@IDENTITY", []), + ("SELECT @@ROWCOUNT, @param", ["param"]), + ("IF @@ERROR <> 0 SELECT @value", ["value"]), + ] + + for sql, expected_params in test_cases: + params = validator.extract_parameters(sql) + param_names = [p.name for p in params if p.name is not None] + assert param_names == expected_params, f"Failed for: {sql}" + + def test_postgres_dollar_quoted_strings_not_detected(self) -> None: + """Verify PostgreSQL dollar-quoted strings don't create false parameters.""" + validator = ParameterValidator() + + test_cases = [ + ("SELECT $$hello$$", []), + ("SELECT $tag$world$tag$", []), + ("SELECT $$value:123$$, :param", ["param"]), + ("SELECT $func$SELECT $1$func$, $1", ["1"]), + ] + + for sql, expected_params in test_cases: + params = validator.extract_parameters(sql) + param_names = [p.name for p in params if p.name is not None] + assert param_names == expected_params, f"Failed for: {sql}" + + def test_xml_namespaces_not_detected(self) -> None: + """Verify XML namespaces with colons are not detected as parameters.""" + validator = ParameterValidator() + + test_cases = [ + ("SELECT 'data'", []), + ("SELECT '' WHERE id = :id", ["id"]), + ("UPDATE xml SET data = '' WHERE id = :id", ["id"]), + ] + + for sql, expected_params in test_cases: + params = validator.extract_parameters(sql) + param_names = [p.name for p in params if p.name is not None] + assert param_names == expected_params, f"Failed for: {sql}" + + def test_parameter_order_precedence(self) -> None: + """Verify parameter detection order is correct.""" + validator = ParameterValidator() + + # Positional colon (:1, :2) MUST be detected before named colon (:name) + params = validator.extract_parameters("SELECT :1, :name, :123, :user") + assert len(params) == 4 + assert params[0].name == "1" + assert params[1].name == "name" + assert params[2].name == "123" + assert params[3].name == "user" + + def test_mixed_identifiers_and_parameters(self) -> None: + """Test SQL with identifiers that contain parameter-like characters.""" + validator = ParameterValidator() + + test_cases = [ + # Table names with special chars (not detected due to negative lookbehind) + ("SELECT * FROM user$data WHERE id = :id", ["id"]), + ("SELECT * FROM price@2023 WHERE amount > @amount", ["amount"]), + ("SELECT * FROM log:entry WHERE time = :time", ["time"]), + + # Column names - NOTE: @column IS detected as parameter in SQL Server + # This is correct behavior - if you write SELECT @var, it's a parameter + ("SELECT user$id, :param FROM table", ["param"]), + # @@VERSION is skipped, but @column and @param are detected (correct!) + ("SELECT @column, @@VERSION, @param FROM t", ["column", "param"]), + ] + + for sql, expected_params in test_cases: + params = validator.extract_parameters(sql) + param_names = [p.name for p in params if p.name is not None] + assert param_names == expected_params, f"Failed for: {sql}" + + def test_all_dialect_parameter_styles(self) -> None: + """Comprehensive test of all supported parameter styles.""" + validator = ParameterValidator() + + dialect_tests = { + "qmark": ("SELECT * FROM t WHERE a = ? AND b = ?", [None, None]), + "numeric": ("SELECT * FROM t WHERE a = $1 AND b = $2", ["1", "2"]), + "named_dollar": ("SELECT * FROM t WHERE a = $foo AND b = $bar", ["foo", "bar"]), + "named_colon": ("SELECT * FROM t WHERE a = :foo AND b = :bar", ["foo", "bar"]), + "positional_colon": ("SELECT * FROM t WHERE a = :1 AND b = :2", ["1", "2"]), + "named_at": ("SELECT * FROM t WHERE a = @foo AND b = @bar", ["foo", "bar"]), + "pyformat_named": ("SELECT * FROM t WHERE a = %(foo)s AND b = %(bar)s", ["foo", "bar"]), + "pyformat_pos": ("SELECT * FROM t WHERE a = %s AND b = %s", [None, None]), + } + + for style, (sql, expected) in dialect_tests.items(): + params = validator.extract_parameters(sql) + param_names = [p.name for p in params] + assert param_names == expected, f"Failed for {style}: {sql}" + + def test_regex_performance_on_large_sql(self) -> None: + """Benchmark regex performance on large SQL statements.""" + validator = ParameterValidator() + + # Generate large SQL with many parameters + large_sql = "SELECT * FROM t WHERE " + " AND ".join([f"col{i} = :param{i}" for i in range(1000)]) + + start = time.perf_counter() + params = validator.extract_parameters(large_sql) + elapsed = time.perf_counter() - start + + assert len(params) == 1000 + assert elapsed < 0.1, f"Regex took too long: {elapsed:.4f}s" # Should be <100ms + + # Test cache hit (should be much faster) + start = time.perf_counter() + params_cached = validator.extract_parameters(large_sql) + elapsed_cached = time.perf_counter() - start + + assert len(params_cached) == 1000 + assert elapsed_cached < 0.001, f"Cache lookup too slow: {elapsed_cached:.6f}s" # Should be <1ms + + def test_no_catastrophic_backtracking(self) -> None: + """Ensure regex doesn't have catastrophic backtracking.""" + validator = ParameterValidator() + + # Pathological cases that could cause backtracking + pathological_cases = [ + # Many nested quotes + "SELECT '" + ("x" * 10000) + "'", + # Many dollar signs (but not valid parameters) + "SELECT price" + ("$" * 1000) + "2023", + # Many colons in strings + "SELECT '" + ("::" * 1000) + "'", + ] + + for sql in pathological_cases: + start = time.perf_counter() + validator.extract_parameters(sql) + elapsed = time.perf_counter() - start + assert elapsed < 0.1, f"Pathological case too slow: {elapsed:.4f}s" + + def test_edge_case_empty_and_whitespace(self) -> None: + """Test edge cases with empty strings and whitespace.""" + validator = ParameterValidator() + + test_cases = [ + ("", []), + (" ", []), + ("SELECT :param", ["param"]), + ("SELECT :param1 , :param2 ", ["param1", "param2"]), + ("-- comment :not_param\nSELECT :param", ["param"]), + ] + + for sql, expected in test_cases: + params = validator.extract_parameters(sql) + param_names = [p.name for p in params if p.name is not None] + assert param_names == expected, f"Failed for: {sql!r}" + + def test_unicode_and_special_characters(self) -> None: + """Test parameter detection with Unicode and special characters.""" + validator = ParameterValidator() + + test_cases = [ + ("SELECT :café FROM table", ["café"]), + ("SELECT :用户 FROM table", ["用户"]), + ("SELECT :Москва FROM table", ["Москва"]), + ("SELECT :param_123 FROM table", ["param_123"]), + ] + + for sql, expected in test_cases: + params = validator.extract_parameters(sql) + param_names = [p.name for p in params if p.name is not None] + assert param_names == expected, f"Failed for: {sql}" From 56a48d59d4f79bfc6746295a4b537af4e53d8667 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Wed, 5 Nov 2025 21:47:00 +0000 Subject: [PATCH 4/5] fix: Remove unnecessary blank line in parameter regex performance tests --- tests/unit/test_core/test_parameter_regex_performance.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unit/test_core/test_parameter_regex_performance.py b/tests/unit/test_core/test_parameter_regex_performance.py index 67cbc76b..61bf82e1 100644 --- a/tests/unit/test_core/test_parameter_regex_performance.py +++ b/tests/unit/test_core/test_parameter_regex_performance.py @@ -97,7 +97,6 @@ def test_mixed_identifiers_and_parameters(self) -> None: ("SELECT * FROM user$data WHERE id = :id", ["id"]), ("SELECT * FROM price@2023 WHERE amount > @amount", ["amount"]), ("SELECT * FROM log:entry WHERE time = :time", ["time"]), - # Column names - NOTE: @column IS detected as parameter in SQL Server # This is correct behavior - if you write SELECT @var, it's a parameter ("SELECT user$id, :param FROM table", ["param"]), From 51407adf2452da350d09095aba9f4e0685a75a00 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Wed, 5 Nov 2025 21:52:29 +0000 Subject: [PATCH 5/5] refactor: Change logger level from info to debug for Oracle version detection in storage type --- sqlspec/adapters/oracledb/adk/store.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sqlspec/adapters/oracledb/adk/store.py b/sqlspec/adapters/oracledb/adk/store.py index 32a0a082..35eb4d2b 100644 --- a/sqlspec/adapters/oracledb/adk/store.py +++ b/sqlspec/adapters/oracledb/adk/store.py @@ -59,15 +59,15 @@ def _storage_type_from_version(version_info: "OracleVersionInfo | None") -> JSON """Determine JSON storage type based on Oracle version metadata.""" if version_info and version_info.supports_native_json(): - logger.info("Detected Oracle %s with compatible >= 20, using JSON_NATIVE", version_info) + logger.debug("Detected Oracle %s with compatible >= 20, using JSON_NATIVE", version_info) return JSONStorageType.JSON_NATIVE if version_info and version_info.supports_json_blob(): - logger.info("Detected Oracle %s, using BLOB_JSON (recommended)", version_info) + logger.debug("Detected Oracle %s, using BLOB_JSON (recommended)", version_info) return JSONStorageType.BLOB_JSON if version_info: - logger.info("Detected Oracle %s (pre-12c), using BLOB_PLAIN", version_info) + logger.debug("Detected Oracle %s (pre-12c), using BLOB_PLAIN", version_info) return JSONStorageType.BLOB_PLAIN logger.warning("Oracle version could not be detected; defaulting to BLOB_JSON storage")