Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 95 additions & 91 deletions sqlspec/adapters/oracledb/adk/store.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
"""Oracle ADK store for Google Agent Development Kit session/event storage."""

from decimal import Decimal
from enum import Enum
from typing import TYPE_CHECKING, Any, Final
from typing import TYPE_CHECKING, Any, Final, cast

import oracledb

from sqlspec import SQL
from sqlspec.adapters.oracledb.data_dictionary import (
OracleAsyncDataDictionary,
OracleSyncDataDictionary,
OracleVersionInfo,
)
from sqlspec.extensions.adk import BaseAsyncADKStore, BaseSyncADKStore, EventRecord, SessionRecord
from sqlspec.utils.logging import get_logger
from sqlspec.utils.serializers import from_json, to_json
Expand Down Expand Up @@ -33,6 +39,41 @@ class JSONStorageType(str, Enum):
BLOB_PLAIN = "blob_plain"


def _coerce_decimal_values(value: Any) -> Any:
if isinstance(value, Decimal):
return float(value)
if isinstance(value, dict):
return {key: _coerce_decimal_values(val) for key, val in value.items()}
if isinstance(value, list):
return [_coerce_decimal_values(item) for item in value]
if isinstance(value, tuple):
return tuple(_coerce_decimal_values(item) for item in value)
if isinstance(value, set):
return {_coerce_decimal_values(item) for item in value}
if isinstance(value, frozenset):
return frozenset(_coerce_decimal_values(item) for item in value)
return value


def _storage_type_from_version(version_info: "OracleVersionInfo | None") -> JSONStorageType:
"""Determine JSON storage type based on Oracle version metadata."""

if version_info and version_info.supports_native_json():
logger.debug("Detected Oracle %s with compatible >= 20, using JSON_NATIVE", version_info)
return JSONStorageType.JSON_NATIVE

if version_info and version_info.supports_json_blob():
logger.debug("Detected Oracle %s, using BLOB_JSON (recommended)", version_info)
return JSONStorageType.BLOB_JSON

if version_info:
logger.debug("Detected Oracle %s (pre-12c), using BLOB_PLAIN", version_info)
return JSONStorageType.BLOB_PLAIN

logger.warning("Oracle version could not be detected; defaulting to BLOB_JSON storage")
return JSONStorageType.BLOB_JSON


def _to_oracle_bool(value: "bool | None") -> "int | None":
"""Convert Python boolean to Oracle NUMBER(1).

Expand Down Expand Up @@ -103,7 +144,7 @@ class OracleAsyncADKStore(BaseAsyncADKStore["OracleAsyncConfig"]):
- Configuration is read from config.extension_config["adk"]
"""

__slots__ = ("_in_memory", "_json_storage_type")
__slots__ = ("_in_memory", "_json_storage_type", "_oracle_version_info")

def __init__(self, config: "OracleAsyncConfig") -> None:
"""Initialize Oracle ADK store.
Expand All @@ -120,6 +161,7 @@ def __init__(self, config: "OracleAsyncConfig") -> None:
"""
super().__init__(config)
self._json_storage_type: JSONStorageType | None = None
self._oracle_version_info: OracleVersionInfo | None = None

adk_config = config.extension_config.get("adk", {})
self._in_memory: bool = bool(adk_config.get("in_memory", False))
Expand Down Expand Up @@ -160,44 +202,24 @@ async def _detect_json_storage_type(self) -> JSONStorageType:
if self._json_storage_type is not None:
return self._json_storage_type

async with self._config.provide_connection() as conn:
cursor = conn.cursor()
await cursor.execute(
"""
SELECT version FROM product_component_version
WHERE product LIKE 'Oracle%Database%'
"""
)
row = await cursor.fetchone()

if row is None:
logger.warning("Could not detect Oracle version, defaulting to BLOB_JSON")
self._json_storage_type = JSONStorageType.BLOB_JSON
return self._json_storage_type

version_str = str(row[0])
version_parts = version_str.split(".")
major_version = int(version_parts[0]) if version_parts else 0

if major_version >= ORACLE_MIN_JSON_NATIVE_VERSION:
await cursor.execute("SELECT value FROM v$parameter WHERE name = 'compatible'")
compatible_row = await cursor.fetchone()
if compatible_row:
compatible_parts = str(compatible_row[0]).split(".")
compatible_major = int(compatible_parts[0]) if compatible_parts else 0
if compatible_major >= ORACLE_MIN_JSON_NATIVE_COMPATIBLE:
logger.info("Detected Oracle %s with compatible >= 20, using JSON_NATIVE", version_str)
self._json_storage_type = JSONStorageType.JSON_NATIVE
return self._json_storage_type

if major_version >= ORACLE_MIN_JSON_BLOB_VERSION:
logger.info("Detected Oracle %s, using BLOB_JSON (recommended)", version_str)
self._json_storage_type = JSONStorageType.BLOB_JSON
return self._json_storage_type

logger.info("Detected Oracle %s (pre-12c), using BLOB_PLAIN", version_str)
self._json_storage_type = JSONStorageType.BLOB_PLAIN
return self._json_storage_type
version_info = await self._get_version_info()
self._json_storage_type = _storage_type_from_version(version_info)
return self._json_storage_type

async def _get_version_info(self) -> "OracleVersionInfo | None":
"""Return cached Oracle version info using Oracle data dictionary."""

if self._oracle_version_info is not None:
return self._oracle_version_info

async with self._config.provide_session() as driver:
dictionary = OracleAsyncDataDictionary()
self._oracle_version_info = await dictionary.get_version(driver)

if self._oracle_version_info is None:
logger.warning("Could not detect Oracle version, defaulting to BLOB_JSON storage")

return self._oracle_version_info

async def _serialize_state(self, state: "dict[str, Any]") -> "str | bytes":
"""Serialize state dictionary to appropriate format based on storage type.
Expand Down Expand Up @@ -232,7 +254,7 @@ async def _deserialize_state(self, data: Any) -> "dict[str, Any]":
data = await data.read()

if isinstance(data, dict):
return data
return cast("dict[str, Any]", _coerce_decimal_values(data))

if isinstance(data, bytes):
return from_json(data) # type: ignore[no-any-return]
Expand Down Expand Up @@ -280,7 +302,7 @@ async def _deserialize_json_field(self, data: Any) -> "dict[str, Any] | None":
data = await data.read()

if isinstance(data, dict):
return data
return cast("dict[str, Any]", _coerce_decimal_values(data))

if isinstance(data, bytes):
return from_json(data) # type: ignore[no-any-return]
Expand Down Expand Up @@ -490,7 +512,7 @@ async def create_tables(self) -> None:
Uses version-appropriate table schema.
"""
storage_type = await self._detect_json_storage_type()
logger.info("Creating ADK tables with storage type: %s", storage_type)
logger.debug("Creating ADK tables with storage type: %s", storage_type)

async with self._config.provide_session() as driver:
await driver.execute_script(self._get_create_sessions_table_sql_for_type(storage_type))
Expand Down Expand Up @@ -561,16 +583,17 @@ async def get_session(self, session_id: str) -> "SessionRecord | None":
State is deserialized using version-appropriate format.
"""

sql = f"""
SELECT id, app_name, user_id, state, create_time, update_time
FROM {self._session_table}
WHERE id = :id
"""

try:
async with self._config.provide_connection() as conn:
cursor = conn.cursor()
await cursor.execute(sql, {"id": session_id})
await cursor.execute(
f"""
SELECT id, app_name, user_id, state, create_time, update_time
FROM {self._session_table}
WHERE id = :id
""",
{"id": session_id},
)
row = await cursor.fetchone()

if row is None:
Expand Down Expand Up @@ -881,7 +904,7 @@ class OracleSyncADKStore(BaseSyncADKStore["OracleSyncConfig"]):
- Configuration is read from config.extension_config["adk"]
"""

__slots__ = ("_in_memory", "_json_storage_type")
__slots__ = ("_in_memory", "_json_storage_type", "_oracle_version_info")

def __init__(self, config: "OracleSyncConfig") -> None:
"""Initialize Oracle synchronous ADK store.
Expand All @@ -898,6 +921,7 @@ def __init__(self, config: "OracleSyncConfig") -> None:
"""
super().__init__(config)
self._json_storage_type: JSONStorageType | None = None
self._oracle_version_info: OracleVersionInfo | None = None

adk_config = config.extension_config.get("adk", {})
self._in_memory: bool = bool(adk_config.get("in_memory", False))
Expand Down Expand Up @@ -938,44 +962,24 @@ def _detect_json_storage_type(self) -> JSONStorageType:
if self._json_storage_type is not None:
return self._json_storage_type

with self._config.provide_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"""
SELECT version FROM product_component_version
WHERE product LIKE 'Oracle%Database%'
"""
)
row = cursor.fetchone()

if row is None:
logger.warning("Could not detect Oracle version, defaulting to BLOB_JSON")
self._json_storage_type = JSONStorageType.BLOB_JSON
return self._json_storage_type

version_str = str(row[0])
version_parts = version_str.split(".")
major_version = int(version_parts[0]) if version_parts else 0

if major_version >= ORACLE_MIN_JSON_NATIVE_VERSION:
cursor.execute("SELECT value FROM v$parameter WHERE name = 'compatible'")
compatible_row = cursor.fetchone()
if compatible_row:
compatible_parts = str(compatible_row[0]).split(".")
compatible_major = int(compatible_parts[0]) if compatible_parts else 0
if compatible_major >= ORACLE_MIN_JSON_NATIVE_COMPATIBLE:
logger.info("Detected Oracle %s with compatible >= 20, using JSON_NATIVE", version_str)
self._json_storage_type = JSONStorageType.JSON_NATIVE
return self._json_storage_type

if major_version >= ORACLE_MIN_JSON_BLOB_VERSION:
logger.info("Detected Oracle %s, using BLOB_JSON (recommended)", version_str)
self._json_storage_type = JSONStorageType.BLOB_JSON
return self._json_storage_type

logger.info("Detected Oracle %s (pre-12c), using BLOB_PLAIN", version_str)
self._json_storage_type = JSONStorageType.BLOB_PLAIN
return self._json_storage_type
version_info = self._get_version_info()
self._json_storage_type = _storage_type_from_version(version_info)
return self._json_storage_type

def _get_version_info(self) -> "OracleVersionInfo | None":
"""Return cached Oracle version info using Oracle data dictionary."""

if self._oracle_version_info is not None:
return self._oracle_version_info

with self._config.provide_session() as driver:
dictionary = OracleSyncDataDictionary()
self._oracle_version_info = dictionary.get_version(driver)

if self._oracle_version_info is None:
logger.warning("Could not detect Oracle version, defaulting to BLOB_JSON storage")

return self._oracle_version_info

def _serialize_state(self, state: "dict[str, Any]") -> "str | bytes":
"""Serialize state dictionary to appropriate format based on storage type.
Expand Down Expand Up @@ -1010,7 +1014,7 @@ def _deserialize_state(self, data: Any) -> "dict[str, Any]":
data = data.read()

if isinstance(data, dict):
return data
return cast("dict[str, Any]", _coerce_decimal_values(data))

if isinstance(data, bytes):
return from_json(data) # type: ignore[no-any-return]
Expand Down Expand Up @@ -1058,7 +1062,7 @@ def _deserialize_json_field(self, data: Any) -> "dict[str, Any] | None":
data = data.read()

if isinstance(data, dict):
return data
return cast("dict[str, Any]", _coerce_decimal_values(data))

if isinstance(data, bytes):
return from_json(data) # type: ignore[no-any-return]
Expand Down
12 changes: 7 additions & 5 deletions sqlspec/core/parameters/_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@
(?P<block_comment>/\*(?:[^*]|\*(?!/))*\*/) |
(?P<pg_q_operator>\?\?|\?\||\?&) |
(?P<pg_cast>::(?P<cast_type>\w+)) |
(?P<sql_server_global>@@(?P<global_var_name>\w+)) |
(?P<pyformat_named>%\((?P<pyformat_name>\w+)\)s) |
(?P<pyformat_pos>%s) |
(?P<positional_colon>:(?P<colon_num>\d+)) |
(?P<named_colon>:(?P<colon_name>\w+)) |
(?P<named_at>@(?P<at_name>\w+)) |
(?P<numeric>\$(?P<numeric_num>\d+)) |
(?P<named_dollar_param>\$(?P<dollar_param_name>\w+)) |
(?P<positional_colon>(?<![A-Za-z0-9_]):(?P<colon_num>\d+)) |
(?P<named_colon>(?<![A-Za-z0-9_]):(?P<colon_name>\w+)) |
(?P<named_at>(?<![A-Za-z0-9_])@(?P<at_name>\w+)) |
(?P<numeric>(?<![A-Za-z0-9_])\$(?P<numeric_num>\d+)) |
(?P<named_dollar_param>(?<![A-Za-z0-9_])\$(?P<dollar_param_name>\w+)) |
(?P<qmark>\?)
""",
re.VERBOSE | re.IGNORECASE | re.MULTILINE | re.DOTALL,
Expand Down Expand Up @@ -85,6 +86,7 @@ def extract_parameters(self, sql: str) -> "list[ParameterInfo]":
"block_comment",
"pg_q_operator",
"pg_cast",
"sql_server_global",
)

for match in PARAMETER_REGEX.finditer(sql):
Expand Down
49 changes: 49 additions & 0 deletions tests/unit/test_adapters/test_oracledb/test_adk_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""Tests for Oracle ADK store Decimal coercion."""

from decimal import Decimal

import pytest

from sqlspec.adapters.oracledb.adk.store import OracleAsyncADKStore, OracleSyncADKStore


@pytest.mark.asyncio
async def test_oracle_async_adk_store_deserialize_dict_coerces_decimal() -> None:
store = OracleAsyncADKStore.__new__(OracleAsyncADKStore) # type: ignore[call-arg]

payload = {"value": Decimal("1.25"), "nested": {"score": Decimal("0.5")}}

result = await store._deserialize_json_field(payload) # type: ignore[attr-defined]

assert result == {"value": 1.25, "nested": {"score": 0.5}}


@pytest.mark.asyncio
async def test_oracle_async_adk_store_deserialize_state_dict_coerces_decimal() -> None:
store = OracleAsyncADKStore.__new__(OracleAsyncADKStore) # type: ignore[call-arg]

payload = {"state": Decimal("2.0")}

result = await store._deserialize_state(payload) # type: ignore[attr-defined]

assert result == {"state": 2.0}


def test_oracle_sync_adk_store_deserialize_dict_coerces_decimal() -> None:
store = OracleSyncADKStore.__new__(OracleSyncADKStore) # type: ignore[call-arg]

payload = {"value": Decimal("3.14"), "items": [Decimal("1.0"), Decimal("2.0")]}

result = store._deserialize_json_field(payload) # type: ignore[attr-defined]

assert result == {"value": 3.14, "items": [1.0, 2.0]}


def test_oracle_sync_adk_store_deserialize_state_dict_coerces_decimal() -> None:
store = OracleSyncADKStore.__new__(OracleSyncADKStore) # type: ignore[call-arg]

payload = {"state": Decimal("5.0")}

result = store._deserialize_state(payload) # type: ignore[attr-defined]

assert result == {"state": 5.0}
Loading