Skip to content
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ alloydb = ["google-cloud-alloydb-connector"]
asyncmy = ["asyncmy"]
asyncpg = ["asyncpg"]
attrs = ["attrs", "cattrs"]
bigquery = ["google-cloud-bigquery"]
bigquery = ["google-cloud-bigquery", "google-cloud-storage"]
cli = ["rich-click"]
cloud-sql = ["cloud-sql-python-connector"]
duckdb = ["duckdb"]
Expand Down
34 changes: 16 additions & 18 deletions sqlspec/adapters/adbc/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
from typing_extensions import NotRequired

from sqlspec.adapters.adbc._types import AdbcConnection
from sqlspec.adapters.adbc.driver import AdbcCursor, AdbcDriver, get_adbc_statement_config
from sqlspec.adapters.adbc.driver import AdbcCursor, AdbcDriver, AdbcExceptionHandler, get_adbc_statement_config
from sqlspec.config import ADKConfig, FastAPIConfig, FlaskConfig, LitestarConfig, NoPoolSyncConfig, StarletteConfig
from sqlspec.core import StatementConfig
from sqlspec.exceptions import ImproperConfigurationError
from sqlspec.utils.module_loader import import_string
from sqlspec.utils.serializers import to_json

if TYPE_CHECKING:
from collections.abc import Generator
Expand Down Expand Up @@ -140,20 +141,12 @@ def __init__(
detected_dialect = str(self._get_dialect() or "sqlite")
statement_config = get_adbc_statement_config(detected_dialect)

from sqlspec.utils.serializers import to_json
processed_driver_features: dict[str, Any] = dict(driver_features) if driver_features else {}
json_serializer = processed_driver_features.setdefault("json_serializer", to_json)
processed_driver_features.setdefault("enable_cast_detection", True)
processed_driver_features.setdefault("strict_type_coercion", False)
processed_driver_features.setdefault("arrow_extension_types", True)

if driver_features is None:
driver_features = {}
if "json_serializer" not in driver_features:
driver_features["json_serializer"] = to_json
if "enable_cast_detection" not in driver_features:
driver_features["enable_cast_detection"] = True
if "strict_type_coercion" not in driver_features:
driver_features["strict_type_coercion"] = False
if "arrow_extension_types" not in driver_features:
driver_features["arrow_extension_types"] = True

json_serializer = driver_features.get("json_serializer")
if json_serializer is not None:
parameter_config = statement_config.parameter_config
previous_list_converter = parameter_config.type_coercion_map.get(list)
Expand All @@ -172,7 +165,7 @@ def __init__(
connection_config=self.connection_config,
migration_config=migration_config,
statement_config=statement_config,
driver_features=dict(driver_features),
driver_features=processed_driver_features,
bind_key=bind_key,
extension_config=extension_config,
)
Expand Down Expand Up @@ -420,13 +413,18 @@ def _get_connection_config_dict(self) -> dict[str, Any]:

return config

def get_signature_namespace(self) -> "dict[str, type[Any]]":
def get_signature_namespace(self) -> "dict[str, Any]":
"""Get the signature namespace for types.

Returns:
Dictionary mapping type names to types.
"""

namespace = super().get_signature_namespace()
namespace.update({"AdbcConnection": AdbcConnection, "AdbcCursor": AdbcCursor})
namespace.update({
"AdbcConnection": AdbcConnection,
"AdbcConnectionParams": AdbcConnectionParams,
"AdbcCursor": AdbcCursor,
"AdbcDriver": AdbcDriver,
"AdbcExceptionHandler": AdbcExceptionHandler,
})
return namespace
32 changes: 16 additions & 16 deletions sqlspec/adapters/aiosqlite/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,19 @@
from typing_extensions import NotRequired

from sqlspec.adapters.aiosqlite._types import AiosqliteConnection
from sqlspec.adapters.aiosqlite.driver import AiosqliteCursor, AiosqliteDriver, aiosqlite_statement_config
from sqlspec.adapters.aiosqlite.driver import (
AiosqliteCursor,
AiosqliteDriver,
AiosqliteExceptionHandler,
aiosqlite_statement_config,
)
from sqlspec.adapters.aiosqlite.pool import (
AiosqliteConnectionPool,
AiosqliteConnectTimeoutError,
AiosqlitePoolClosedError,
AiosqlitePoolConnection,
)
from sqlspec.adapters.sqlite._type_handlers import register_type_handlers
from sqlspec.config import ADKConfig, AsyncDatabaseConfig, FastAPIConfig, FlaskConfig, LitestarConfig, StarletteConfig
from sqlspec.utils.serializers import from_json, to_json

Expand Down Expand Up @@ -117,20 +123,11 @@ def __init__(
config_dict["uri"] = True

processed_driver_features: dict[str, Any] = dict(driver_features) if driver_features else {}

if "enable_custom_adapters" not in processed_driver_features:
processed_driver_features["enable_custom_adapters"] = True

if "json_serializer" not in processed_driver_features:
processed_driver_features["json_serializer"] = to_json

if "json_deserializer" not in processed_driver_features:
processed_driver_features["json_deserializer"] = from_json
processed_driver_features.setdefault("enable_custom_adapters", True)
json_serializer = processed_driver_features.setdefault("json_serializer", to_json)
json_deserializer = processed_driver_features.setdefault("json_deserializer", from_json)

base_statement_config = statement_config or aiosqlite_statement_config

json_serializer = processed_driver_features.get("json_serializer")
json_deserializer = processed_driver_features.get("json_deserializer")
if json_serializer is not None:
parameter_config = base_statement_config.parameter_config.with_json_serializers(
json_serializer, deserializer=json_deserializer
Expand Down Expand Up @@ -250,8 +247,6 @@ def _register_type_adapters(self) -> None:
sync adapter, so this shares the implementation.
"""
if self.driver_features.get("enable_custom_adapters", False):
from sqlspec.adapters.sqlite._type_handlers import register_type_handlers

register_type_handlers(
json_serializer=self.driver_features.get("json_serializer"),
json_deserializer=self.driver_features.get("json_deserializer"),
Expand Down Expand Up @@ -283,7 +278,7 @@ async def provide_pool(self) -> AiosqliteConnectionPool:
self.pool_instance = await self.create_pool()
return self.pool_instance

def get_signature_namespace(self) -> "dict[str, type[Any]]":
def get_signature_namespace(self) -> "dict[str, Any]":
"""Get the signature namespace for aiosqlite types.

Returns:
Expand All @@ -292,11 +287,16 @@ def get_signature_namespace(self) -> "dict[str, type[Any]]":
namespace = super().get_signature_namespace()
namespace.update({
"AiosqliteConnection": AiosqliteConnection,
"AiosqliteConnectionParams": AiosqliteConnectionParams,
"AiosqliteConnectionPool": AiosqliteConnectionPool,
"AiosqliteConnectTimeoutError": AiosqliteConnectTimeoutError,
"AiosqliteCursor": AiosqliteCursor,
"AiosqliteDriver": AiosqliteDriver,
"AiosqliteDriverFeatures": AiosqliteDriverFeatures,
"AiosqliteExceptionHandler": AiosqliteExceptionHandler,
"AiosqlitePoolClosedError": AiosqlitePoolClosedError,
"AiosqlitePoolConnection": AiosqlitePoolConnection,
"AiosqlitePoolParams": AiosqlitePoolParams,
})
return namespace

Expand Down
16 changes: 10 additions & 6 deletions sqlspec/adapters/asyncmy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from sqlspec.adapters.asyncmy.driver import (
AsyncmyCursor,
AsyncmyDriver,
AsyncmyExceptionHandler,
asyncmy_statement_config,
build_asyncmy_statement_config,
)
Expand Down Expand Up @@ -121,10 +122,8 @@ def __init__(
extras = processed_pool_config.pop("extra")
processed_pool_config.update(extras)

if "host" not in processed_pool_config:
processed_pool_config["host"] = "localhost"
if "port" not in processed_pool_config:
processed_pool_config["port"] = 3306
processed_pool_config.setdefault("host", "localhost")
processed_pool_config.setdefault("port", 3306)

processed_driver_features: dict[str, Any] = dict(driver_features) if driver_features else {}
serializer = processed_driver_features.setdefault("json_serializer", to_json)
Expand Down Expand Up @@ -221,7 +220,7 @@ async def provide_pool(self, *args: Any, **kwargs: Any) -> "Pool": # pyright: i
self.pool_instance = await self.create_pool()
return self.pool_instance

def get_signature_namespace(self) -> "dict[str, type[Any]]":
def get_signature_namespace(self) -> "dict[str, Any]":
"""Get the signature namespace for Asyncmy types.

Returns:
Expand All @@ -231,7 +230,12 @@ def get_signature_namespace(self) -> "dict[str, type[Any]]":
namespace = super().get_signature_namespace()
namespace.update({
"AsyncmyConnection": AsyncmyConnection,
"AsyncmyPool": AsyncmyPool,
"AsyncmyConnectionParams": AsyncmyConnectionParams,
"AsyncmyCursor": AsyncmyCursor,
"AsyncmyDriver": AsyncmyDriver,
"AsyncmyDriverFeatures": AsyncmyDriverFeatures,
"AsyncmyExceptionHandler": AsyncmyExceptionHandler,
"AsyncmyPool": AsyncmyPool,
"AsyncmyPoolParams": AsyncmyPoolParams,
})
return namespace
20 changes: 11 additions & 9 deletions sqlspec/adapters/asyncpg/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
from asyncpg.pool import Pool, PoolConnectionProxy, PoolConnectionProxyMeta
from typing_extensions import NotRequired

from sqlspec.adapters.asyncpg._types import AsyncpgConnection
from sqlspec.adapters.asyncpg._type_handlers import register_json_codecs, register_pgvector_support
from sqlspec.adapters.asyncpg._types import AsyncpgConnection, AsyncpgPool
from sqlspec.adapters.asyncpg.driver import (
AsyncpgCursor,
AsyncpgDriver,
AsyncpgExceptionHandler,
asyncpg_statement_config,
build_asyncpg_statement_config,
)
Expand Down Expand Up @@ -329,8 +331,7 @@ async def _create_pool(self) -> "Pool[Record]":
elif self.driver_features.get("enable_alloydb", False):
self._setup_alloydb_connector(config)

if "init" not in config:
config["init"] = self._init_connection
config.setdefault("init", self._init_connection)

return await asyncpg_create_pool(**config)

Expand All @@ -341,17 +342,13 @@ async def _init_connection(self, connection: "AsyncpgConnection") -> None:
connection: AsyncPG connection to initialize.
"""
if self.driver_features.get("enable_json_codecs", True):
from sqlspec.adapters.asyncpg._type_handlers import register_json_codecs

await register_json_codecs(
connection,
encoder=self.driver_features.get("json_serializer", to_json),
decoder=self.driver_features.get("json_deserializer", from_json),
)

if self.driver_features.get("enable_pgvector", False):
from sqlspec.adapters.asyncpg._type_handlers import register_pgvector_support

await register_pgvector_support(connection)

async def _close_pool(self) -> None:
Expand Down Expand Up @@ -432,7 +429,7 @@ async def provide_pool(self, *args: Any, **kwargs: Any) -> "Pool[Record]":
self.pool_instance = await self.create_pool()
return self.pool_instance

def get_signature_namespace(self) -> "dict[str, type[Any]]":
def get_signature_namespace(self) -> "dict[str, Any]":
"""Get the signature namespace for AsyncPG types.

This provides all AsyncPG-specific types that Litestar needs to recognize
Expand All @@ -450,7 +447,12 @@ def get_signature_namespace(self) -> "dict[str, type[Any]]":
"PoolConnectionProxyMeta": PoolConnectionProxyMeta,
"ConnectionMeta": ConnectionMeta,
"Record": Record,
"AsyncpgConnection": AsyncpgConnection, # type: ignore[dict-item]
"AsyncpgConnection": AsyncpgConnection,
"AsyncpgConnectionConfig": AsyncpgConnectionConfig,
"AsyncpgCursor": AsyncpgCursor,
"AsyncpgDriver": AsyncpgDriver,
"AsyncpgExceptionHandler": AsyncpgExceptionHandler,
"AsyncpgPool": AsyncpgPool,
"AsyncpgPoolConfig": AsyncpgPoolConfig,
})
return namespace
22 changes: 15 additions & 7 deletions sqlspec/adapters/bigquery/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@
from typing_extensions import NotRequired

from sqlspec.adapters.bigquery._types import BigQueryConnection
from sqlspec.adapters.bigquery.driver import BigQueryCursor, BigQueryDriver, build_bigquery_statement_config
from sqlspec.adapters.bigquery.driver import (
BigQueryCursor,
BigQueryDriver,
BigQueryExceptionHandler,
build_bigquery_statement_config,
)
from sqlspec.config import ADKConfig, FastAPIConfig, FlaskConfig, LitestarConfig, NoPoolSyncConfig, StarletteConfig
from sqlspec.exceptions import ImproperConfigurationError
from sqlspec.typing import Empty
Expand Down Expand Up @@ -134,10 +139,7 @@ def __init__(
self.connection_config.update(extras)

self.driver_features: dict[str, Any] = dict(driver_features) if driver_features else {}

if "enable_uuid_conversion" not in self.driver_features:
self.driver_features["enable_uuid_conversion"] = True

self.driver_features.setdefault("enable_uuid_conversion", True)
serializer = self.driver_features.setdefault("json_serializer", to_json)

self._connection_instance: BigQueryConnection | None = self.driver_features.get("connection_instance")
Expand Down Expand Up @@ -263,13 +265,19 @@ def provide_session(
)
yield driver

def get_signature_namespace(self) -> "dict[str, type[Any]]":
def get_signature_namespace(self) -> "dict[str, Any]":
"""Get the signature namespace for BigQuery types.

Returns:
Dictionary mapping type names to types.
"""

namespace = super().get_signature_namespace()
namespace.update({"BigQueryConnection": BigQueryConnection, "BigQueryCursor": BigQueryCursor})
namespace.update({
"BigQueryConnection": BigQueryConnection,
"BigQueryConnectionParams": BigQueryConnectionParams,
"BigQueryCursor": BigQueryCursor,
"BigQueryDriver": BigQueryDriver,
"BigQueryExceptionHandler": BigQueryExceptionHandler,
})
return namespace
25 changes: 20 additions & 5 deletions sqlspec/adapters/duckdb/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@
from typing_extensions import NotRequired

from sqlspec.adapters.duckdb._types import DuckDBConnection
from sqlspec.adapters.duckdb.driver import DuckDBCursor, DuckDBDriver, build_duckdb_statement_config
from sqlspec.adapters.duckdb.driver import (
DuckDBCursor,
DuckDBDriver,
DuckDBExceptionHandler,
build_duckdb_statement_config,
)
from sqlspec.adapters.duckdb.pool import DuckDBConnectionPool
from sqlspec.config import ADKConfig, FastAPIConfig, FlaskConfig, LitestarConfig, StarletteConfig, SyncDatabaseConfig
from sqlspec.utils.serializers import to_json
Expand Down Expand Up @@ -209,8 +214,7 @@ def __init__(
"""
if pool_config is None:
pool_config = {}
if "database" not in pool_config:
pool_config["database"] = ":memory:shared_db"
pool_config.setdefault("database", ":memory:shared_db")

if pool_config.get("database") in {":memory:", ""}:
pool_config["database"] = ":memory:shared_db"
Expand Down Expand Up @@ -331,7 +335,7 @@ def provide_session(
)
yield driver

def get_signature_namespace(self) -> "dict[str, type[Any]]":
def get_signature_namespace(self) -> "dict[str, Any]":
"""Get the signature namespace for DuckDB types.

This provides all DuckDB-specific types that Litestar needs to recognize
Expand All @@ -342,5 +346,16 @@ def get_signature_namespace(self) -> "dict[str, type[Any]]":
"""

namespace = super().get_signature_namespace()
namespace.update({"DuckDBConnection": DuckDBConnection, "DuckDBCursor": DuckDBCursor})
namespace.update({
"DuckDBConnection": DuckDBConnection,
"DuckDBConnectionParams": DuckDBConnectionParams,
"DuckDBConnectionPool": DuckDBConnectionPool,
"DuckDBCursor": DuckDBCursor,
"DuckDBDriver": DuckDBDriver,
"DuckDBDriverFeatures": DuckDBDriverFeatures,
"DuckDBExceptionHandler": DuckDBExceptionHandler,
"DuckDBExtensionConfig": DuckDBExtensionConfig,
"DuckDBPoolParams": DuckDBPoolParams,
"DuckDBSecretConfig": DuckDBSecretConfig,
})
return namespace
Loading