From 58afcd392b0e434871d48218828dce97d9ffde84 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sun, 16 Nov 2025 01:42:27 +0000 Subject: [PATCH] feat(oracle): correct lookup for `JSON` supported versions --- sqlspec/adapters/oracledb/data_dictionary.py | 118 +++--- .../adapters/test_oracledb_uuid_handlers.py | 358 ------------------ .../test_duckdb/test_extension_flags.py} | 6 + .../test_oracledb/test_data_dictionary.py | 139 +++++++ .../test_oracledb/test_uuid_handlers.py | 240 ++++++++++++ 5 files changed, 458 insertions(+), 403 deletions(-) delete mode 100644 tests/unit/adapters/test_oracledb_uuid_handlers.py rename tests/unit/{adapters/test_duckdb_config.py => test_adapters/test_duckdb/test_extension_flags.py} (84%) create mode 100644 tests/unit/test_adapters/test_oracledb/test_data_dictionary.py create mode 100644 tests/unit/test_adapters/test_oracledb/test_uuid_handlers.py diff --git a/sqlspec/adapters/oracledb/data_dictionary.py b/sqlspec/adapters/oracledb/data_dictionary.py index 55193081..272f43e7 100644 --- a/sqlspec/adapters/oracledb/data_dictionary.py +++ b/sqlspec/adapters/oracledb/data_dictionary.py @@ -27,12 +27,14 @@ ORACLE_MIN_OSON_VERSION = 19 # Compiled regex patterns -ORACLE_VERSION_PATTERN = re.compile(r"Oracle Database (\d+)c?.* Release (\d+)\.(\d+)\.(\d+)") +VERSION_NUMBER_PATTERN = re.compile(r"(\d+)") +VERSION_COMPONENT_COUNT = 3 COMPONENT_VERSION_SQL = ( - "SELECT product || ' Release ' || version AS \"banner\" " + 'SELECT product AS "product", version AS "version", status AS "status" ' "FROM product_component_version WHERE product LIKE 'Oracle%' " - "ORDER BY version DESC FETCH FIRST 1 ROWS ONLY" + "ORDER BY TO_NUMBER(REGEXP_SUBSTR(version, '^[0-9]+')) DESC, version DESC " + "FETCH FIRST 1 ROWS ONLY" ) AUTONOMOUS_SERVICE_SQL = "SELECT sys_context('USERENV','CLOUD_SERVICE') AS \"service\" FROM dual" @@ -137,47 +139,86 @@ def _get_columns_sql(self, table: str, schema: "str | None" = None) -> str: ORDER BY column_id """ - def _select_version_banner(self, driver: "OracleSyncDriver") -> str: - return str(driver.select_value(COMPONENT_VERSION_SQL)) + def _select_component_version_row(self, driver: "OracleSyncDriver") -> "dict[str, Any] | None": + """Fetch the latest Oracle component version row. - async def _select_version_banner_async(self, driver: "OracleAsyncDriver") -> str: - result = await driver.select_value(COMPONENT_VERSION_SQL) - return str(result) + Args: + driver: Oracle sync driver instance. - def _get_oracle_version(self, driver: "OracleAsyncDriver | OracleSyncDriver") -> "OracleVersionInfo | None": - """Get Oracle database version information. + Returns: + First matching row from product_component_version or None. + """ + + result = driver.execute(COMPONENT_VERSION_SQL) + data = result.get_data() + if not data: + logger.warning("No rows returned from product_component_version") + return None + return data[0] + + async def _select_component_version_row_async(self, driver: "OracleAsyncDriver") -> "dict[str, Any] | None": + """Async helper to fetch the latest Oracle component version row. Args: - driver: Database driver instance + driver: Oracle async driver instance. Returns: - Oracle version information or None if detection fails + First matching row from product_component_version or None. + """ + + result = await driver.execute(COMPONENT_VERSION_SQL) + data = result.get_data() + if not data: + logger.warning("No rows returned from product_component_version") + return None + return data[0] + + def _build_version_info_from_row(self, row: "dict[str, Any] | None") -> "OracleVersionInfo | None": + """Build Oracle version metadata from a component version row. + + Args: + row: Data dictionary row containing product/version fields. + + Returns: + OracleVersionInfo if parsing succeeds, otherwise None. """ - banner = self._select_version_banner(cast("OracleSyncDriver", driver)) - # Parse version from banner like "Oracle Database 21c Enterprise Edition Release 21.0.0.0.0 - Production" - # or "Oracle Database 19c Standard Edition 2 Release 19.0.0.0.0 - Production" - version_match = ORACLE_VERSION_PATTERN.search(str(banner)) + if not row: + logger.warning("Unable to determine Oracle version without component data") + return None + + version_value = row.get("version_full") or row.get("VERSION_FULL") or row.get("version") or row.get("VERSION") - if not version_match: - logger.warning("Could not parse Oracle version from banner: %s", banner) + if version_value is None: + logger.warning("Component version row missing VERSION column: %s", row) return None - major = int(version_match.group(1)) - release_major = int(version_match.group(2)) - minor = int(version_match.group(3)) - patch = int(version_match.group(4)) + matches = VERSION_NUMBER_PATTERN.findall(str(version_value)) + if not matches: + logger.warning("Unable to parse Oracle version from value: %s", version_value) + return None - # For Oracle 21c+, the major version is in the first group - # For Oracle 19c and earlier, use the release version - if major >= ORACLE_MIN_JSON_NATIVE_VERSION: - version_info = OracleVersionInfo(major, minor, patch) - else: - version_info = OracleVersionInfo(release_major, minor, patch) + numbers = [int(match) for match in matches[:VERSION_COMPONENT_COUNT]] + while len(numbers) < VERSION_COMPONENT_COUNT: + numbers.append(0) - logger.debug("Detected Oracle version: %s", version_info) + version_info = OracleVersionInfo(numbers[0], numbers[1], numbers[2]) + product_name = row.get("product") or row.get("PRODUCT") or "Oracle Database" + logger.debug("Detected Oracle component version for %s: %s", product_name, version_info) return version_info + def _get_oracle_version(self, driver: "OracleAsyncDriver | OracleSyncDriver") -> "OracleVersionInfo | None": + """Get Oracle database version information. + + Args: + driver: Database driver instance + + Returns: + Oracle version information or None if detection fails + """ + row = self._select_component_version_row(cast("OracleSyncDriver", driver)) + return self._build_version_info_from_row(row) + def _get_oracle_compatible(self, driver: "OracleAsyncDriver | OracleSyncDriver") -> "str | None": """Get Oracle compatible parameter value. @@ -372,25 +413,12 @@ async def get_version(self, driver: AsyncDriverAdapterBase) -> "OracleVersionInf Oracle version information or None if detection fails """ oracle_driver = cast("OracleAsyncDriver", driver) - banner = await self._select_version_banner_async(oracle_driver) + row = await self._select_component_version_row_async(oracle_driver) + version_info = self._build_version_info_from_row(row) - version_match = ORACLE_VERSION_PATTERN.search(str(banner)) - - if not version_match: - logger.warning("Could not parse Oracle version from banner: %s", banner) + if not version_info: return None - major = int(version_match.group(1)) - release_major = int(version_match.group(2)) - minor = int(version_match.group(3)) - patch = int(version_match.group(4)) - - if major >= ORACLE_MIN_JSON_NATIVE_VERSION: - version_info = OracleVersionInfo(major, minor, patch) - else: - version_info = OracleVersionInfo(release_major, minor, patch) - - # Enhance with additional information compatible = await self._get_oracle_compatible_async(oracle_driver) is_autonomous = await self._is_oracle_autonomous_async(oracle_driver) diff --git a/tests/unit/adapters/test_oracledb_uuid_handlers.py b/tests/unit/adapters/test_oracledb_uuid_handlers.py deleted file mode 100644 index 5e1b9e39..00000000 --- a/tests/unit/adapters/test_oracledb_uuid_handlers.py +++ /dev/null @@ -1,358 +0,0 @@ -"""Unit tests for Oracle UUID type handlers.""" - -import uuid -from unittest.mock import Mock - -from sqlspec.adapters.oracledb._uuid_handlers import ( - _input_type_handler, # pyright: ignore - _output_type_handler, # pyright: ignore - register_uuid_handlers, - uuid_converter_in, - uuid_converter_out, -) - - -def test_uuid_converter_in() -> None: - """Test UUID to bytes conversion.""" - test_uuid = uuid.UUID("12345678-1234-5678-1234-567812345678") - result = uuid_converter_in(test_uuid) - - assert isinstance(result, bytes) - assert len(result) == 16 - assert result == test_uuid.bytes - - -def test_uuid_converter_out_valid() -> None: - """Test valid bytes to UUID conversion.""" - test_uuid = uuid.UUID("87654321-4321-8765-4321-876543218765") - test_bytes = test_uuid.bytes - - result = uuid_converter_out(test_bytes) - - assert isinstance(result, uuid.UUID) - assert result == test_uuid - - -def test_uuid_converter_out_none() -> None: - """Test NULL handling returns None.""" - result = uuid_converter_out(None) - assert result is None - - -def test_uuid_converter_out_invalid_length() -> None: - """Test invalid length bytes returns original bytes.""" - invalid_bytes = b"12345" - result = uuid_converter_out(invalid_bytes) - - assert result is invalid_bytes - assert isinstance(result, bytes) - - -def test_uuid_converter_out_invalid_format() -> None: - """Test invalid UUID format bytes gracefully falls back to bytes. - - Note: Most 16-byte values are technically valid UUIDs, so this test - verifies that the converter attempts conversion and returns bytes - if it somehow fails (which is rare in practice). - """ - test_bytes = uuid.uuid4().bytes - result = uuid_converter_out(test_bytes) - - assert isinstance(result, uuid.UUID) - - -def test_uuid_converter_out_type_error() -> None: - """Test TypeError during UUID conversion falls back to original value.""" - from unittest.mock import patch - - test_bytes = b"1234567890123456" - - with patch("uuid.UUID", side_effect=TypeError("Invalid type")): - result = uuid_converter_out(test_bytes) - - assert result is test_bytes - assert isinstance(result, bytes) - - -def test_uuid_converter_out_value_error() -> None: - """Test ValueError during UUID conversion falls back to original value.""" - from unittest.mock import patch - - test_bytes = b"1234567890123456" - - with patch("uuid.UUID", side_effect=ValueError("Invalid UUID")): - result = uuid_converter_out(test_bytes) - - assert result is test_bytes - assert isinstance(result, bytes) - - -def test_uuid_variants() -> None: - """Test all UUID variants (v1, v4, v5) roundtrip correctly.""" - test_uuids = [uuid.uuid1(), uuid.uuid4(), uuid.uuid5(uuid.NAMESPACE_DNS, "example.com")] - - for test_uuid in test_uuids: - binary = uuid_converter_in(test_uuid) - converted = uuid_converter_out(binary) - assert converted == test_uuid - - -def test_uuid_roundtrip() -> None: - """Test complete roundtrip conversion.""" - original = uuid.uuid4() - binary = uuid_converter_in(original) - converted = uuid_converter_out(binary) - - assert converted == original - assert isinstance(converted, uuid.UUID) - - -def test_input_type_handler_with_uuid() -> None: - """Test input type handler detects UUID and creates cursor variable.""" - import oracledb - - cursor = Mock() - cursor_var = Mock() - cursor.var = Mock(return_value=cursor_var) - - test_uuid = uuid.uuid4() - arraysize = 1 - - result = _input_type_handler(cursor, test_uuid, arraysize) - - assert result is cursor_var - cursor.var.assert_called_once_with(oracledb.DB_TYPE_RAW, arraysize=arraysize, inconverter=uuid_converter_in) - - -def test_input_type_handler_with_non_uuid() -> None: - """Test input type handler returns None for non-UUID values.""" - cursor = Mock() - - result = _input_type_handler(cursor, "not a uuid", 1) - - assert result is None - cursor.var.assert_not_called() - - -def test_input_type_handler_with_string() -> None: - """Test input type handler returns None for string values.""" - cursor = Mock() - - result = _input_type_handler(cursor, "12345678-1234-5678-1234-567812345678", 1) - - assert result is None - - -def test_input_type_handler_with_bytes() -> None: - """Test input type handler returns None for bytes values.""" - cursor = Mock() - - result = _input_type_handler(cursor, b"some bytes", 1) - - assert result is None - - -def test_output_type_handler_with_raw16() -> None: - """Test output type handler detects RAW(16) columns.""" - import oracledb - - cursor = Mock() - cursor.arraysize = 50 - cursor_var = Mock() - cursor.var = Mock(return_value=cursor_var) - - metadata = ("RAW_COL", oracledb.DB_TYPE_RAW, 16, 16, None, None, True) - - result = _output_type_handler(cursor, metadata) - - assert result is cursor_var - cursor.var.assert_called_once_with(oracledb.DB_TYPE_RAW, arraysize=50, outconverter=uuid_converter_out) - - -def test_output_type_handler_with_raw32() -> None: - """Test output type handler returns None for RAW(32) columns.""" - import oracledb - - cursor = Mock() - metadata = ("RAW32_COL", oracledb.DB_TYPE_RAW, 32, 32, None, None, True) - - result = _output_type_handler(cursor, metadata) - - assert result is None - - -def test_output_type_handler_with_varchar() -> None: - """Test output type handler returns None for VARCHAR2 columns.""" - import oracledb - - cursor = Mock() - metadata = ("VARCHAR_COL", oracledb.DB_TYPE_VARCHAR, 36, 36, None, None, True) - - result = _output_type_handler(cursor, metadata) - - assert result is None - - -def test_output_type_handler_with_number() -> None: - """Test output type handler returns None for NUMBER columns.""" - import oracledb - - cursor = Mock() - metadata = ("NUM_COL", oracledb.DB_TYPE_NUMBER, 10, 10, 10, 0, True) - - result = _output_type_handler(cursor, metadata) - - assert result is None - - -def test_register_uuid_handlers_no_existing() -> None: - """Test registering UUID handlers on connection without existing handlers.""" - connection = Mock() - connection.inputtypehandler = None - connection.outputtypehandler = None - - register_uuid_handlers(connection) - - assert connection.inputtypehandler is not None - assert connection.outputtypehandler is not None - - -def test_register_uuid_handlers_with_chaining() -> None: - """Test UUID handler chaining with existing handlers.""" - existing_input = Mock(return_value=None) - existing_output = Mock(return_value=None) - - connection = Mock() - connection.inputtypehandler = existing_input - connection.outputtypehandler = existing_output - - register_uuid_handlers(connection) - - assert connection.inputtypehandler is not None - assert connection.outputtypehandler is not None - assert connection.inputtypehandler != existing_input - assert connection.outputtypehandler != existing_output - - -def test_register_uuid_handlers_chaining_fallback() -> None: - """Test chaining falls back to existing handler when UUID handler returns None.""" - existing_input_result = Mock() - existing_input = Mock(return_value=existing_input_result) - - connection = Mock() - connection.inputtypehandler = existing_input - connection.outputtypehandler = None - - register_uuid_handlers(connection) - - cursor = Mock() - non_uuid_value = "not a uuid" - - result = connection.inputtypehandler(cursor, non_uuid_value, 1) - - existing_input.assert_called_once_with(cursor, non_uuid_value, 1) - assert result is existing_input_result - - -def test_register_uuid_handlers_chaining_uuid_takes_priority() -> None: - """Test UUID handler takes priority over existing handler for UUID values.""" - import oracledb - - existing_input = Mock(return_value=Mock()) - - connection = Mock() - connection.inputtypehandler = existing_input - connection.outputtypehandler = None - - register_uuid_handlers(connection) - - cursor = Mock() - cursor_var = Mock() - cursor.var = Mock(return_value=cursor_var) - test_uuid = uuid.uuid4() - - result = connection.inputtypehandler(cursor, test_uuid, 1) - - existing_input.assert_not_called() - assert result is cursor_var - cursor.var.assert_called_once_with(oracledb.DB_TYPE_RAW, arraysize=1, inconverter=uuid_converter_in) - - -def test_register_uuid_handlers_output_chaining() -> None: - """Test output handler chaining delegates to existing handler for non-RAW16.""" - import oracledb - - existing_output_result = Mock() - existing_output = Mock(return_value=existing_output_result) - - connection = Mock() - connection.inputtypehandler = None - connection.outputtypehandler = existing_output - - register_uuid_handlers(connection) - - cursor = Mock() - metadata = ("VARCHAR_COL", oracledb.DB_TYPE_VARCHAR, 36, 36, None, None, True) - - result = connection.outputtypehandler(cursor, metadata) - - existing_output.assert_called_once_with(cursor, metadata) - assert result is existing_output_result - - -def test_combined_input_handler_no_existing_non_uuid() -> None: - """Test combined input handler returns None when no existing handler and non-UUID value.""" - connection = Mock() - connection.inputtypehandler = None - connection.outputtypehandler = None - - register_uuid_handlers(connection) - - cursor = Mock() - result = connection.inputtypehandler(cursor, "not a uuid", 1) - - assert result is None - - -def test_combined_output_handler_no_existing_non_raw16() -> None: - """Test combined output handler returns None when no existing handler and non-RAW16.""" - import oracledb - - connection = Mock() - connection.inputtypehandler = None - connection.outputtypehandler = None - - register_uuid_handlers(connection) - - cursor = Mock() - metadata = ("VARCHAR_COL", oracledb.DB_TYPE_VARCHAR, 36, 36, None, None, True) - - result = connection.outputtypehandler(cursor, metadata) - - assert result is None - - -def test_combined_output_handler_raw16_priority() -> None: - """Test combined output handler prioritizes UUID handler for RAW16.""" - import oracledb - - existing_output = Mock(return_value=Mock()) - - connection = Mock() - connection.inputtypehandler = None - connection.outputtypehandler = existing_output - - register_uuid_handlers(connection) - - cursor = Mock() - cursor.arraysize = 50 - cursor_var = Mock() - cursor.var = Mock(return_value=cursor_var) - - metadata = ("RAW_COL", oracledb.DB_TYPE_RAW, 16, 16, None, None, True) - - result = connection.outputtypehandler(cursor, metadata) - - existing_output.assert_not_called() - assert result is cursor_var - cursor.var.assert_called_once_with(oracledb.DB_TYPE_RAW, arraysize=50, outconverter=uuid_converter_out) diff --git a/tests/unit/adapters/test_duckdb_config.py b/tests/unit/test_adapters/test_duckdb/test_extension_flags.py similarity index 84% rename from tests/unit/adapters/test_duckdb_config.py rename to tests/unit/test_adapters/test_duckdb/test_extension_flags.py index 5ebc9c34..5a2b7381 100644 --- a/tests/unit/adapters/test_duckdb_config.py +++ b/tests/unit/test_adapters/test_duckdb/test_extension_flags.py @@ -1,3 +1,5 @@ +"""DuckDB configuration tests for security/extension flag promotion.""" + import pytest pytest.importorskip("duckdb", reason="DuckDB adapter requires duckdb package") @@ -6,6 +8,8 @@ def test_duckdb_config_promotes_security_flags() -> None: + """Extension flags should move from pool_config to driver_features.""" + config = DuckDBConfig( pool_config={ "database": ":memory:", @@ -27,6 +31,8 @@ def test_duckdb_config_promotes_security_flags() -> None: def test_duckdb_config_merges_existing_extension_flags() -> None: + """Existing driver feature flags should merge with promoted ones.""" + config = DuckDBConfig( pool_config={"database": ":memory:", "allow_community_extensions": True}, driver_features={"extension_flags": {"custom": "value"}}, diff --git a/tests/unit/test_adapters/test_oracledb/test_data_dictionary.py b/tests/unit/test_adapters/test_oracledb/test_data_dictionary.py new file mode 100644 index 00000000..55171e70 --- /dev/null +++ b/tests/unit/test_adapters/test_oracledb/test_data_dictionary.py @@ -0,0 +1,139 @@ +"""Unit tests for Oracle data dictionary version handling.""" + +from typing import Any, cast + +import pytest + +from sqlspec.adapters.oracledb.data_dictionary import OracleAsyncDataDictionary, OracleSyncDataDictionary +from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase + +ORACLE_AI_COMPONENT_ROW = { + "product": "Oracle AI Database 26ai Free", + "version": "23.26.0.0.0", + "status": "Develop, Learn, and Run for Free", +} + + +class _FakeResult: + """Return predetermined rows to mimic SQLResult.""" + + def __init__(self, rows: "list[dict[str, Any]]") -> None: + self._rows = rows + + def get_data(self, schema_type: "type[Any] | None" = None) -> "list[dict[str, Any]]": + """Return stored rows.""" + + _ = schema_type + return self._rows + + +class _FakeSyncOracleDriver: + """Minimal sync Oracle driver stub for data dictionary tests.""" + + def __init__( + self, rows: "list[dict[str, Any]]", compatible: str = "23.0.0.0.0", service: str = "AUTONOMOUS SHARED" + ) -> None: + self._rows = rows + self._compatible = compatible + self._service = service + + def execute(self, statement: str, *args: Any, **kwargs: Any) -> "_FakeResult": + """Return stored rows regardless of SQL.""" + + _ = (statement, args, kwargs) + return _FakeResult(self._rows) + + def select_value(self, statement: str, *args: Any, **kwargs: Any) -> str: + """Return compatible parameter when requested.""" + + _ = (args, kwargs) + if "v$parameter" in statement.lower(): + return self._compatible + raise ValueError(f"Unexpected select_value SQL: {statement}") + + def select_value_or_none(self, statement: str, *args: Any, **kwargs: Any) -> str | None: + """Return cloud service identifier when requested.""" + + _ = (args, kwargs) + if "sys_context" in statement.lower(): + return self._service + return None + + +class _FakeAsyncOracleDriver: + """Minimal async Oracle driver stub for data dictionary tests.""" + + def __init__( + self, rows: "list[dict[str, Any]]", compatible: str = "23.0.0.0.0", service: str = "AUTONOMOUS SHARED" + ) -> None: + self._rows = rows + self._compatible = compatible + self._service = service + + async def execute(self, statement: str, *args: Any, **kwargs: Any) -> "_FakeResult": + """Return stored rows regardless of SQL (async).""" + + _ = (statement, args, kwargs) + return _FakeResult(self._rows) + + async def select_value(self, statement: str, *args: Any, **kwargs: Any) -> str: + """Return compatible parameter when requested (async).""" + + _ = (args, kwargs) + if "v$parameter" in statement.lower(): + return self._compatible + raise ValueError(f"Unexpected select_value SQL: {statement}") + + async def select_value_or_none(self, statement: str, *args: Any, **kwargs: Any) -> str | None: + """Return cloud service identifier when requested (async).""" + + _ = (args, kwargs) + if "sys_context" in statement.lower(): + return self._service + return None + + +@pytest.fixture +def oracle_component_rows() -> "list[dict[str, Any]]": + """Return canonical Oracle component version rows for tests.""" + + return [dict(ORACLE_AI_COMPONENT_ROW)] + + +@pytest.fixture +def oracle_sync_driver(oracle_component_rows: "list[dict[str, Any]]") -> "_FakeSyncOracleDriver": + """Build a fake sync Oracle driver using the canonical component row.""" + + return _FakeSyncOracleDriver(oracle_component_rows, compatible="23.20.0.0.0", service="AUTONOMOUS AI") + + +@pytest.fixture +def oracle_async_driver(oracle_component_rows: "list[dict[str, Any]]") -> "_FakeAsyncOracleDriver": + """Build a fake async Oracle driver using the canonical component row.""" + + return _FakeAsyncOracleDriver(oracle_component_rows, compatible="23.20.0.0.0", service="AUTONOMOUS AI") + + +def test_sync_data_dictionary_detects_native_json_type(oracle_sync_driver: "_FakeSyncOracleDriver") -> None: + """Ensure sync data dictionary maps Oracle 23ai to native JSON columns.""" + + data_dictionary = OracleSyncDataDictionary() + sync_driver = cast("SyncDriverAdapterBase", oracle_sync_driver) + version_info = data_dictionary.get_version(sync_driver) + + assert version_info is not None + assert version_info.supports_native_json() + assert data_dictionary.get_optimal_type(sync_driver, "json") == "JSON" + + +@pytest.mark.anyio +async def test_async_data_dictionary_detects_native_json_type(oracle_async_driver: "_FakeAsyncOracleDriver") -> None: + """Ensure async data dictionary maps Oracle 23ai to native JSON columns.""" + + data_dictionary = OracleAsyncDataDictionary() + async_driver = cast("AsyncDriverAdapterBase", oracle_async_driver) + version_info = await data_dictionary.get_version(async_driver) + + assert version_info is not None + assert version_info.supports_native_json() + assert await data_dictionary.get_optimal_type(async_driver, "json") == "JSON" diff --git a/tests/unit/test_adapters/test_oracledb/test_uuid_handlers.py b/tests/unit/test_adapters/test_oracledb/test_uuid_handlers.py new file mode 100644 index 00000000..2afce0d5 --- /dev/null +++ b/tests/unit/test_adapters/test_oracledb/test_uuid_handlers.py @@ -0,0 +1,240 @@ +"""Unit tests for Oracle UUID type handlers.""" + +import uuid +from unittest.mock import Mock, patch + +from sqlspec.adapters.oracledb._uuid_handlers import ( + _input_type_handler, # pyright: ignore + _output_type_handler, # pyright: ignore + register_uuid_handlers, + uuid_converter_in, + uuid_converter_out, +) + + +def test_uuid_converter_in() -> None: + """UUID instances should convert to 16-byte RAW payloads.""" + + test_uuid = uuid.UUID("12345678-1234-5678-1234-567812345678") + result = uuid_converter_in(test_uuid) + + assert isinstance(result, bytes) + assert len(result) == 16 + assert result == test_uuid.bytes + + +def test_uuid_converter_out_valid() -> None: + """16-byte RAW payloads should convert back to UUID.""" + + test_uuid = uuid.UUID("87654321-4321-8765-4321-876543218765") + result = uuid_converter_out(test_uuid.bytes) + + assert isinstance(result, uuid.UUID) + assert result == test_uuid + + +def test_uuid_converter_out_none() -> None: + """None should stay None.""" + + assert uuid_converter_out(None) is None + + +def test_uuid_converter_out_invalid_length() -> None: + """Bytes with invalid length should be returned unchanged.""" + + invalid_bytes = b"12345" + result = uuid_converter_out(invalid_bytes) + + assert result is invalid_bytes + + +def test_uuid_converter_out_type_error() -> None: + """TypeError should fall back to original bytes.""" + + payload = b"1234567890123456" + with patch("uuid.UUID", side_effect=TypeError("Invalid type")): + result = uuid_converter_out(payload) + assert result is payload + + +def test_uuid_converter_out_value_error() -> None: + """ValueError should fall back to original bytes.""" + + payload = b"1234567890123456" + with patch("uuid.UUID", side_effect=ValueError("Invalid UUID")): + result = uuid_converter_out(payload) + assert result is payload + + +def test_uuid_variants_roundtrip() -> None: + """Multiple UUID variants should roundtrip via converters.""" + + variants = [uuid.uuid1(), uuid.uuid4(), uuid.uuid5(uuid.NAMESPACE_DNS, "example.com")] + for entry in variants: + binary = uuid_converter_in(entry) + assert uuid_converter_out(binary) == entry + + +def test_input_type_handler_with_uuid() -> None: + """Input handler should bind UUID values as RAW(16).""" + + import oracledb + + cursor = Mock() + cursor_var = Mock() + cursor.var = Mock(return_value=cursor_var) + + result = _input_type_handler(cursor, uuid.uuid4(), 1) + + assert result is cursor_var + cursor.var.assert_called_once_with(oracledb.DB_TYPE_RAW, arraysize=1, inconverter=uuid_converter_in) + + +def test_input_type_handler_non_uuid() -> None: + """Input handler should return None for non-UUID values.""" + + cursor = Mock() + assert _input_type_handler(cursor, "not-a-uuid", 1) is None + + +def test_output_type_handler_raw16() -> None: + """Output handler should wrap RAW(16) metadata.""" + + import oracledb + + cursor = Mock() + cursor.arraysize = 10 + cursor_var = Mock() + cursor.var = Mock(return_value=cursor_var) + + metadata = ("RAW_COL", oracledb.DB_TYPE_RAW, 16, 16, None, None, True) + result = _output_type_handler(cursor, metadata) + + assert result is cursor_var + cursor.var.assert_called_once_with(oracledb.DB_TYPE_RAW, arraysize=10, outconverter=uuid_converter_out) + + +def test_output_type_handler_non_raw16() -> None: + """Output handler should ignore non RAW(16) columns.""" + + import oracledb + + cursor = Mock() + metadata = ("VARCHAR_COL", oracledb.DB_TYPE_VARCHAR, 36, 36, None, None, True) + assert _output_type_handler(cursor, metadata) is None + + +def test_register_uuid_handlers_no_existing() -> None: + """Registering handlers without existing ones should set both hooks.""" + + connection = Mock() + connection.inputtypehandler = None + connection.outputtypehandler = None + + register_uuid_handlers(connection) + + assert connection.inputtypehandler is not None + assert connection.outputtypehandler is not None + + +def test_register_uuid_handlers_with_existing() -> None: + """Registering handlers should chain with existing hooks.""" + + existing_input = Mock(return_value=None) + existing_output = Mock(return_value=None) + + connection = Mock() + connection.inputtypehandler = existing_input + connection.outputtypehandler = existing_output + + register_uuid_handlers(connection) + + assert connection.inputtypehandler is not existing_input + assert connection.outputtypehandler is not existing_output + + +def test_input_handler_chain_uses_existing_for_non_uuid() -> None: + """Combined handler should defer to existing handler for non-UUID values.""" + + fallback = Mock() + connection = Mock() + connection.inputtypehandler = fallback + connection.outputtypehandler = None + + register_uuid_handlers(connection) + + cursor = Mock() + value = "not-a-uuid" + result = connection.inputtypehandler(cursor, value, 1) + + fallback.assert_called_once_with(cursor, value, 1) + assert result is fallback.return_value + + +def test_input_handler_chain_prioritizes_uuid() -> None: + """Combined handler should intercept UUID values before fallback handler.""" + + import oracledb + + fallback = Mock() + connection = Mock() + connection.inputtypehandler = fallback + connection.outputtypehandler = None + + register_uuid_handlers(connection) + + cursor = Mock() + cursor_var = Mock() + cursor.var = Mock(return_value=cursor_var) + test_uuid = uuid.uuid4() + + result = connection.inputtypehandler(cursor, test_uuid, 1) + + fallback.assert_not_called() + assert result is cursor_var + cursor.var.assert_called_once_with(oracledb.DB_TYPE_RAW, arraysize=1, inconverter=uuid_converter_in) + + +def test_output_handler_chain_uses_existing_for_non_uuid() -> None: + """Combined output handler should defer to fallback for non RAW(16).""" + + import oracledb + + fallback = Mock() + connection = Mock() + connection.inputtypehandler = None + connection.outputtypehandler = fallback + + register_uuid_handlers(connection) + + cursor = Mock() + metadata = ("VARCHAR_COL", oracledb.DB_TYPE_VARCHAR, 36, 36, None, None, True) + result = connection.outputtypehandler(cursor, metadata) + + fallback.assert_called_once_with(cursor, metadata) + assert result is fallback.return_value + + +def test_output_handler_chain_prioritizes_raw16() -> None: + """Combined output handler should intercept RAW(16) columns.""" + + import oracledb + + fallback = Mock() + connection = Mock() + connection.inputtypehandler = None + connection.outputtypehandler = fallback + + register_uuid_handlers(connection) + + cursor = Mock() + cursor.arraysize = 32 + cursor_var = Mock() + cursor.var = Mock(return_value=cursor_var) + metadata = ("RAW_COL", oracledb.DB_TYPE_RAW, 16, 16, None, None, True) + + result = connection.outputtypehandler(cursor, metadata) + + fallback.assert_not_called() + assert result is cursor_var + cursor.var.assert_called_once_with(oracledb.DB_TYPE_RAW, arraysize=32, outconverter=uuid_converter_out)