diff --git a/Makefile b/Makefile index a1ad2f94..ce57ffdc 100644 --- a/Makefile +++ b/Makefile @@ -43,7 +43,7 @@ install-uv: ## Install latest version of .PHONY: install install: destroy clean ## Install the project, dependencies, and pre-commit @echo "${INFO} Starting fresh installation..." - @uv python pin 3.12 >/dev/null 2>&1 + @uv python pin 3.10 >/dev/null 2>&1 @uv venv >/dev/null 2>&1 @uv sync --all-extras --dev @echo "${OK} Installation complete! 🎉" @@ -51,7 +51,7 @@ install: destroy clean ## Install the project, depe .PHONY: install-compiled install-compiled: destroy clean ## Install with mypyc compilation for performance @echo "${INFO} Starting fresh installation with mypyc compilation..." - @uv python pin 3.12 >/dev/null 2>&1 + @uv python pin 3.10 >/dev/null 2>&1 @uv venv >/dev/null 2>&1 @echo "${INFO} Installing in editable mode with mypyc compilation..." @HATCH_BUILD_HOOKS_ENABLE=1 uv pip install -e . diff --git a/docs/examples/litestar_asyncpg.py b/docs/examples/litestar_asyncpg.py index 6a96500d..1da343b5 100755 --- a/docs/examples/litestar_asyncpg.py +++ b/docs/examples/litestar_asyncpg.py @@ -24,9 +24,10 @@ from litestar import Litestar, get -from sqlspec import SQL +from sqlspec import SQLSpec from sqlspec.adapters.asyncpg import AsyncpgConfig, AsyncpgDriver, AsyncpgPoolConfig -from sqlspec.extensions.litestar import DatabaseConfig, SQLSpec +from sqlspec.core.statement import SQL +from sqlspec.extensions.litestar import SQLSpecPlugin @get("/") @@ -70,19 +71,17 @@ async def get_status() -> dict[str, str]: # Configure SQLSpec with AsyncPG # Note: Modify this DSN to match your database configuration -sqlspec = SQLSpec( - config=[ - DatabaseConfig( - config=AsyncpgConfig( - pool_config=AsyncpgPoolConfig( - dsn="postgresql://postgres:postgres@localhost:5433/postgres", min_size=5, max_size=5 - ) - ), - commit_mode="autocommit", - ) - ] +sql = SQLSpec() +sql.add_config( + AsyncpgConfig( + pool_config=AsyncpgPoolConfig( + dsn="postgresql://postgres:postgres@localhost:5433/postgres", min_size=5, max_size=5 + ), + extension_config={"litestar": {"commit_mode": "autocommit"}}, + ) ) -app = Litestar(route_handlers=[hello_world, get_version, list_tables, get_status], plugins=[sqlspec], debug=True) +plugin = SQLSpecPlugin(sqlspec=sql) +app = Litestar(route_handlers=[hello_world, get_version, list_tables, get_status], plugins=[plugin], debug=True) if __name__ == "__main__": import os diff --git a/docs/examples/litestar_duckllm.py b/docs/examples/litestar_duckllm.py index 36c526f3..1e2ee5c9 100644 --- a/docs/examples/litestar_duckllm.py +++ b/docs/examples/litestar_duckllm.py @@ -16,8 +16,9 @@ from litestar import Litestar, post from msgspec import Struct +from sqlspec import SQLSpec from sqlspec.adapters.duckdb import DuckDBConfig, DuckDBDriver -from sqlspec.extensions.litestar import SQLSpec +from sqlspec.extensions.litestar import SQLSpecPlugin class ChatMessage(Struct): @@ -30,8 +31,9 @@ def duckllm_chat(db_session: DuckDBDriver, data: ChatMessage) -> ChatMessage: return db_session.to_schema(results or {"message": "No response from DuckLLM"}, schema_type=ChatMessage) -sqlspec = SQLSpec( - config=DuckDBConfig( +sql = SQLSpec() +sql.add_config( + DuckDBConfig( driver_features={ "extensions": [{"name": "open_prompt"}], "secrets": [ @@ -48,7 +50,8 @@ def duckllm_chat(db_session: DuckDBDriver, data: ChatMessage) -> ChatMessage: } ) ) -app = Litestar(route_handlers=[duckllm_chat], plugins=[sqlspec], debug=True) +plugin = SQLSpecPlugin(sqlspec=sql) +app = Litestar(route_handlers=[duckllm_chat], plugins=[plugin], debug=True) if __name__ == "__main__": import uvicorn diff --git a/docs/examples/litestar_multi_db.py b/docs/examples/litestar_multi_db.py index 5f81e13d..d9ee7b91 100644 --- a/docs/examples/litestar_multi_db.py +++ b/docs/examples/litestar_multi_db.py @@ -15,10 +15,11 @@ from litestar import Litestar, get +from sqlspec import SQLSpec from sqlspec.adapters.aiosqlite import AiosqliteConfig, AiosqliteDriver from sqlspec.adapters.duckdb import DuckDBConfig, DuckDBDriver from sqlspec.core.statement import SQL -from sqlspec.extensions.litestar import DatabaseConfig, SQLSpec +from sqlspec.extensions.litestar import SQLSpecPlugin @get("/test", sync_to_thread=True) @@ -35,22 +36,19 @@ async def simple_sqlite(db_session: AiosqliteDriver) -> dict[str, str]: return {"greeting": greeting["greeting"] if greeting is not None else "hi"} -sqlspec = SQLSpec( - config=[ - DatabaseConfig(config=AiosqliteConfig(), commit_mode="autocommit"), - DatabaseConfig( - config=DuckDBConfig( - driver_features={ - "extensions": [{"name": "vss", "force_install": True}], - "secrets": [{"secret_type": "s3", "name": "s3_secret", "value": {"key_id": "abcd"}}], - } - ), - connection_key="etl_connection", - session_key="etl_session", - ), - ] +sql = SQLSpec() +sql.add_config(AiosqliteConfig(extension_config={"litestar": {"commit_mode": "autocommit"}})) +sql.add_config( + DuckDBConfig( + driver_features={ + "extensions": [{"name": "vss", "force_install": True}], + "secrets": [{"secret_type": "s3", "name": "s3_secret", "value": {"key_id": "abcd"}}], + }, + extension_config={"litestar": {"connection_key": "etl_connection", "session_key": "etl_session"}}, + ) ) -app = Litestar(route_handlers=[simple_sqlite, simple_select], plugins=[sqlspec]) +plugin = SQLSpecPlugin(sqlspec=sql) +app = Litestar(route_handlers=[simple_sqlite, simple_select], plugins=[plugin]) if __name__ == "__main__": import os diff --git a/docs/examples/litestar_psycopg.py b/docs/examples/litestar_psycopg.py index 48341484..f5492aa4 100644 --- a/docs/examples/litestar_psycopg.py +++ b/docs/examples/litestar_psycopg.py @@ -15,29 +15,27 @@ from litestar import Litestar, get +from sqlspec import SQLSpec from sqlspec.adapters.psycopg import PsycopgAsyncConfig, PsycopgAsyncDriver -from sqlspec.extensions.litestar import DatabaseConfig, SQLSpec +from sqlspec.core.statement import SQL +from sqlspec.extensions.litestar import SQLSpecPlugin @get("/") async def simple_psycopg(db_session: PsycopgAsyncDriver) -> dict[str, str]: - from sqlspec.core.statement import SQL - result = await db_session.execute(SQL("SELECT 'Hello, world!' AS greeting")) return result.get_first() or {"greeting": "No result found"} -sqlspec = SQLSpec( - config=[ - DatabaseConfig( - config=PsycopgAsyncConfig( - pool_config={"conninfo": "postgres://app:app@localhost:15432/app", "min_size": 1, "max_size": 3} - ), - commit_mode="autocommit", - ) - ] +sql = SQLSpec() +sql.add_config( + PsycopgAsyncConfig( + pool_config={"conninfo": "postgres://app:app@localhost:15432/app", "min_size": 1, "max_size": 3}, + extension_config={"litestar": {"commit_mode": "autocommit"}}, + ) ) -app = Litestar(route_handlers=[simple_psycopg], plugins=[sqlspec]) +plugin = SQLSpecPlugin(sqlspec=sql) +app = Litestar(route_handlers=[simple_psycopg], plugins=[plugin]) if __name__ == "__main__": import os diff --git a/docs/examples/litestar_single_db.py b/docs/examples/litestar_single_db.py index 96c97ed0..5b7ca6fc 100644 --- a/docs/examples/litestar_single_db.py +++ b/docs/examples/litestar_single_db.py @@ -8,8 +8,9 @@ from aiosqlite import Connection from litestar import Litestar, get +from sqlspec import SQLSpec from sqlspec.adapters.aiosqlite import AiosqliteConfig -from sqlspec.extensions.litestar import SQLSpec +from sqlspec.extensions.litestar import SQLSpecPlugin @get("/") @@ -23,5 +24,7 @@ async def simple_sqlite(db_connection: Connection) -> dict[str, str]: return {"greeting": next(iter(result))[0]} -sqlspec = SQLSpec(config=AiosqliteConfig()) -app = Litestar(route_handlers=[simple_sqlite], plugins=[sqlspec]) +sql = SQLSpec() +sql.add_config(AiosqliteConfig()) +plugin = SQLSpecPlugin(sqlspec=sql) +app = Litestar(route_handlers=[simple_sqlite], plugins=[plugin]) diff --git a/sqlspec/adapters/aiosqlite/config.py b/sqlspec/adapters/aiosqlite/config.py index 2559c1d0..9fc7c613 100644 --- a/sqlspec/adapters/aiosqlite/config.py +++ b/sqlspec/adapters/aiosqlite/config.py @@ -79,6 +79,15 @@ def __init__( if "database" not in config_dict or config_dict["database"] == ":memory:": config_dict["database"] = "file::memory:?cache=shared" config_dict["uri"] = True + elif "database" in config_dict: + database_path = str(config_dict["database"]) + if database_path.startswith("file:") and not config_dict.get("uri"): + logger.debug( + "Database URI detected (%s) but uri=True not set. " + "Auto-enabling URI mode to prevent physical file creation.", + database_path, + ) + config_dict["uri"] = True super().__init__( pool_config=config_dict, diff --git a/sqlspec/adapters/sqlite/config.py b/sqlspec/adapters/sqlite/config.py index bb5a4a60..524df727 100644 --- a/sqlspec/adapters/sqlite/config.py +++ b/sqlspec/adapters/sqlite/config.py @@ -1,5 +1,6 @@ """SQLite database configuration with thread-local connections.""" +import logging import uuid from contextlib import contextmanager from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast @@ -11,6 +12,8 @@ from sqlspec.adapters.sqlite.pool import SqliteConnectionPool from sqlspec.config import SyncDatabaseConfig +logger = logging.getLogger(__name__) + if TYPE_CHECKING: from collections.abc import Generator @@ -64,6 +67,15 @@ def __init__( if "database" not in pool_config or pool_config["database"] == ":memory:": pool_config["database"] = f"file:memory_{uuid.uuid4().hex}?mode=memory&cache=private" pool_config["uri"] = True + elif "database" in pool_config: + database_path = str(pool_config["database"]) + if database_path.startswith("file:") and not pool_config.get("uri"): + logger.debug( + "Database URI detected (%s) but uri=True not set. " + "Auto-enabling URI mode to prevent physical file creation.", + database_path, + ) + pool_config["uri"] = True super().__init__( bind_key=bind_key, diff --git a/sqlspec/base.py b/sqlspec/base.py index 41ef5baf..b1aa326a 100644 --- a/sqlspec/base.py +++ b/sqlspec/base.py @@ -161,6 +161,15 @@ def get_config( logger.debug("Retrieved configuration: %s", self._get_config_name(name)) return config + @property + def configs(self) -> "dict[type, DatabaseConfigProtocol[Any, Any, Any]]": + """Access the registry of database configurations. + + Returns: + Dictionary mapping config types to config instances. + """ + return self._configs + @overload def get_connection( self, diff --git a/sqlspec/cli.py b/sqlspec/cli.py index fdadb62d..9d3ea5c9 100644 --- a/sqlspec/cli.py +++ b/sqlspec/cli.py @@ -174,12 +174,6 @@ def get_config_by_bind_key( console.print(f"[red]No config found for bind key: {bind_key}[/]") sys.exit(1) - # Extract the actual config from DatabaseConfig wrapper if needed - from sqlspec.extensions.litestar.config import DatabaseConfig - - if isinstance(config, DatabaseConfig): - config = config.config - return cast("AsyncDatabaseConfig[Any, Any, Any] | SyncDatabaseConfig[Any, Any, Any]", config) def get_configs_with_migrations(ctx: "click.Context", enabled_only: bool = False) -> "list[tuple[str, Any]]": @@ -195,18 +189,13 @@ def get_configs_with_migrations(ctx: "click.Context", enabled_only: bool = False configs = ctx.obj["configs"] migration_configs = [] - from sqlspec.extensions.litestar.config import DatabaseConfig - for config in configs: - # Extract the actual config from DatabaseConfig wrapper if needed - actual_config = config.config if isinstance(config, DatabaseConfig) else config - - migration_config = actual_config.migration_config + migration_config = config.migration_config if migration_config: enabled = migration_config.get("enabled", True) if not enabled_only or enabled: - config_name = actual_config.bind_key or str(type(actual_config).__name__) - migration_configs.append((config_name, actual_config)) + config_name = config.bind_key or str(type(config).__name__) + migration_configs.append((config_name, config)) return migration_configs @@ -506,7 +495,6 @@ def init_sqlspec( # pyright: ignore[reportUnusedFunction] """Initialize the database migrations.""" from rich.prompt import Confirm - from sqlspec.extensions.litestar.config import DatabaseConfig from sqlspec.migrations.commands import create_migration_commands from sqlspec.utils.sync_tools import run_ @@ -527,13 +515,11 @@ async def _init_sqlspec() -> None: ) for config in configs: - # Extract the actual config from DatabaseConfig wrapper if needed - actual_config = config.config if isinstance(config, DatabaseConfig) else config - migration_config = getattr(actual_config, "migration_config", {}) + migration_config = getattr(config, "migration_config", {}) target_directory = ( migration_config.get("script_location", "migrations") if directory is None else directory ) - migration_commands = create_migration_commands(config=actual_config) + migration_commands = create_migration_commands(config=config) await maybe_await(migration_commands.init(directory=cast("str", target_directory), package=package)) run_(_init_sqlspec)() diff --git a/sqlspec/extensions/litestar/__init__.py b/sqlspec/extensions/litestar/__init__.py index 6eab1a6f..2d60c576 100644 --- a/sqlspec/extensions/litestar/__init__.py +++ b/sqlspec/extensions/litestar/__init__.py @@ -1,6 +1,19 @@ -from sqlspec.extensions.litestar import handlers, providers from sqlspec.extensions.litestar.cli import database_group -from sqlspec.extensions.litestar.config import DatabaseConfig -from sqlspec.extensions.litestar.plugin import SQLSpec +from sqlspec.extensions.litestar.plugin import ( + DEFAULT_COMMIT_MODE, + DEFAULT_CONNECTION_KEY, + DEFAULT_POOL_KEY, + DEFAULT_SESSION_KEY, + CommitMode, + SQLSpecPlugin, +) -__all__ = ("DatabaseConfig", "SQLSpec", "database_group", "handlers", "providers") +__all__ = ( + "DEFAULT_COMMIT_MODE", + "DEFAULT_CONNECTION_KEY", + "DEFAULT_POOL_KEY", + "DEFAULT_SESSION_KEY", + "CommitMode", + "SQLSpecPlugin", + "database_group", +) diff --git a/sqlspec/extensions/litestar/cli.py b/sqlspec/extensions/litestar/cli.py index f27f0673..a57295f7 100644 --- a/sqlspec/extensions/litestar/cli.py +++ b/sqlspec/extensions/litestar/cli.py @@ -15,10 +15,10 @@ if TYPE_CHECKING: from litestar import Litestar - from sqlspec.extensions.litestar.plugin import SQLSpec + from sqlspec.extensions.litestar.plugin import SQLSpecPlugin -def get_database_migration_plugin(app: "Litestar") -> "SQLSpec": +def get_database_migration_plugin(app: "Litestar") -> "SQLSpecPlugin": """Retrieve the SQLSpec plugin from the Litestar application's plugins. Args: @@ -31,10 +31,10 @@ def get_database_migration_plugin(app: "Litestar") -> "SQLSpec": ImproperConfigurationError: If the SQLSpec plugin is not found """ from sqlspec.exceptions import ImproperConfigurationError - from sqlspec.extensions.litestar.plugin import SQLSpec + from sqlspec.extensions.litestar.plugin import SQLSpecPlugin with suppress(KeyError): - return app.plugins.get(SQLSpec) + return app.plugins.get(SQLSpecPlugin) msg = "Failed to initialize database migrations. The required SQLSpec plugin is missing." raise ImproperConfigurationError(msg) diff --git a/sqlspec/extensions/litestar/config.py b/sqlspec/extensions/litestar/config.py deleted file mode 100644 index 50b05776..00000000 --- a/sqlspec/extensions/litestar/config.py +++ /dev/null @@ -1,291 +0,0 @@ -from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Literal, cast - -from sqlspec.exceptions import ImproperConfigurationError -from sqlspec.extensions.litestar._utils import get_sqlspec_scope_state, set_sqlspec_scope_state -from sqlspec.extensions.litestar.handlers import ( - autocommit_handler_maker, - connection_provider_maker, - lifespan_handler_maker, - manual_handler_maker, - pool_provider_maker, - session_provider_maker, -) - -if TYPE_CHECKING: - from collections.abc import AsyncGenerator, Awaitable, Callable - from contextlib import AbstractAsyncContextManager, AbstractContextManager - - from litestar import Litestar - from litestar.datastructures.state import State - from litestar.types import BeforeMessageSendHookHandler, Scope - - from sqlspec.config import AsyncConfigT, DriverT, SyncConfigT - from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase - from sqlspec.typing import ConnectionT, PoolT - - -CommitMode = Literal["manual", "autocommit", "autocommit_include_redirect"] -DEFAULT_COMMIT_MODE: CommitMode = "manual" -DEFAULT_CONNECTION_KEY = "db_connection" -DEFAULT_POOL_KEY = "db_pool" -DEFAULT_SESSION_KEY = "db_session" - -__all__ = ( - "DEFAULT_COMMIT_MODE", - "DEFAULT_CONNECTION_KEY", - "DEFAULT_POOL_KEY", - "DEFAULT_SESSION_KEY", - "AsyncDatabaseConfig", - "CommitMode", - "DatabaseConfig", - "SyncDatabaseConfig", -) - - -@dataclass -class DatabaseConfig: - config: "SyncConfigT | AsyncConfigT" = field() # type: ignore[valid-type] # pyright: ignore[reportGeneralTypeIssues] - connection_key: str = field(default=DEFAULT_CONNECTION_KEY) - pool_key: str = field(default=DEFAULT_POOL_KEY) - session_key: str = field(default=DEFAULT_SESSION_KEY) - commit_mode: "CommitMode" = field(default=DEFAULT_COMMIT_MODE) - extra_commit_statuses: "set[int] | None" = field(default=None) - extra_rollback_statuses: "set[int] | None" = field(default=None) - enable_correlation_middleware: bool = field(default=True) - connection_provider: "Callable[[State, Scope], AsyncGenerator[ConnectionT, None]]" = field( # pyright: ignore[reportGeneralTypeIssues] - init=False, repr=False, hash=False - ) - pool_provider: "Callable[[State,Scope], Awaitable[PoolT]]" = field(init=False, repr=False, hash=False) # pyright: ignore[reportGeneralTypeIssues] - session_provider: "Callable[[Any], AsyncGenerator[DriverT, None]]" = field(init=False, repr=False, hash=False) # pyright: ignore[reportGeneralTypeIssues] - before_send_handler: "BeforeMessageSendHookHandler" = field(init=False, repr=False, hash=False) - lifespan_handler: "Callable[[Litestar], AbstractAsyncContextManager[None]]" = field( - init=False, repr=False, hash=False - ) - annotation: "type[SyncConfigT | AsyncConfigT]" = field(init=False, repr=False, hash=False) # type: ignore[valid-type] # pyright: ignore[reportGeneralTypeIssues] - - def __post_init__(self) -> None: - litestar_config = self.config.extension_config.get("litestar", {}) # type: ignore[union-attr] - - if self.connection_key == DEFAULT_CONNECTION_KEY and "connection_key" in litestar_config: - self.connection_key = litestar_config["connection_key"] - if self.pool_key == DEFAULT_POOL_KEY and "pool_key" in litestar_config: - self.pool_key = litestar_config["pool_key"] - if self.session_key == DEFAULT_SESSION_KEY and "session_key" in litestar_config: - self.session_key = litestar_config["session_key"] - if self.commit_mode == DEFAULT_COMMIT_MODE and "commit_mode" in litestar_config: - self.commit_mode = litestar_config["commit_mode"] - if self.extra_commit_statuses is None and "extra_commit_statuses" in litestar_config: - self.extra_commit_statuses = litestar_config["extra_commit_statuses"] - if self.extra_rollback_statuses is None and "extra_rollback_statuses" in litestar_config: - self.extra_rollback_statuses = litestar_config["extra_rollback_statuses"] - if "enable_correlation_middleware" in litestar_config: - self.enable_correlation_middleware = litestar_config["enable_correlation_middleware"] - - if not self.config.supports_connection_pooling and self.pool_key == DEFAULT_POOL_KEY: # type: ignore[union-attr,unused-ignore] - self.pool_key = f"_{self.pool_key}_{id(self.config)}" - if self.commit_mode == "manual": - self.before_send_handler = manual_handler_maker(connection_scope_key=self.connection_key) - elif self.commit_mode == "autocommit": - self.before_send_handler = autocommit_handler_maker( - commit_on_redirect=False, - extra_commit_statuses=self.extra_commit_statuses, - extra_rollback_statuses=self.extra_rollback_statuses, - connection_scope_key=self.connection_key, - ) - elif self.commit_mode == "autocommit_include_redirect": - self.before_send_handler = autocommit_handler_maker( - commit_on_redirect=True, - extra_commit_statuses=self.extra_commit_statuses, - extra_rollback_statuses=self.extra_rollback_statuses, - connection_scope_key=self.connection_key, - ) - else: - msg = f"Invalid commit mode: {self.commit_mode}" - raise ImproperConfigurationError(detail=msg) - self.lifespan_handler = lifespan_handler_maker(config=self.config, pool_key=self.pool_key) - self.connection_provider = connection_provider_maker( - connection_key=self.connection_key, pool_key=self.pool_key, config=self.config - ) - self.pool_provider = pool_provider_maker(config=self.config, pool_key=self.pool_key) - self.session_provider = session_provider_maker( - config=self.config, connection_dependency_key=self.connection_key - ) - - def get_request_session(self, state: "State", scope: "Scope") -> "SyncDriverAdapterBase | AsyncDriverAdapterBase": - """Get a session instance from the current request. - - This method provides access to the database session that has been added to the request - scope, similar to Advanced Alchemy's provide_session method. It first looks for an - existing session in the request scope state, and if not found, creates a new one using - the connection from the scope. - - Args: - state: The Litestar application State object. - scope: The ASGI scope containing the request context. - - Returns: - A driver session instance. - - Raises: - ImproperConfigurationError: If no connection is available in the scope. - """ - # Create a unique scope key for sessions to avoid conflicts - session_scope_key = f"{self.session_key}_instance" - - # Try to get existing session from scope - session = get_sqlspec_scope_state(scope, session_scope_key) - if session is not None: - return cast("SyncDriverAdapterBase | AsyncDriverAdapterBase", session) - - # Get connection from scope state - connection = get_sqlspec_scope_state(scope, self.connection_key) - if connection is None: - msg = f"No database connection found in scope for key '{self.connection_key}'. " - msg += "Ensure the connection dependency is properly configured and available." - raise ImproperConfigurationError(detail=msg) - - # Create new session using the connection - # Access driver_type which is available on all config types - session = self.config.driver_type(connection=connection) # type: ignore[union-attr] - - # Store session in scope for future use - set_sqlspec_scope_state(scope, session_scope_key, session) - - return cast("SyncDriverAdapterBase | AsyncDriverAdapterBase", session) - - def get_request_connection(self, state: "State", scope: "Scope") -> "Any": - """Get a connection instance from the current request. - - This method provides access to the database connection that has been added to the request - scope. This is useful in guards, middleware, or other contexts where you need direct - access to the connection that's been established for the current request. - - Args: - state: The Litestar application State object. - scope: The ASGI scope containing the request context. - - Returns: - A database connection instance. - - Raises: - ImproperConfigurationError: If no connection is available in the scope. - """ - connection = get_sqlspec_scope_state(scope, self.connection_key) - if connection is None: - msg = f"No database connection found in scope for key '{self.connection_key}'. " - msg += "Ensure the connection dependency is properly configured and available." - raise ImproperConfigurationError(detail=msg) - - return cast("Any", connection) - - -# Add passthrough methods to both specialized classes for convenience -class SyncDatabaseConfig(DatabaseConfig): - """Sync-specific DatabaseConfig with better typing for get_request_session.""" - - def get_request_session(self, state: "State", scope: "Scope") -> "SyncDriverAdapterBase": - """Get a sync session instance from the current request. - - This method provides access to the database session that has been added to the request - scope, similar to Advanced Alchemy's provide_session method. It first looks for an - existing session in the request scope state, and if not found, creates a new one using - the connection from the scope. - - Args: - state: The Litestar application State object. - scope: The ASGI scope containing the request context. - - Returns: - A sync driver session instance. - """ - session = super().get_request_session(state, scope) - return cast("SyncDriverAdapterBase", session) - - def provide_session(self) -> "AbstractContextManager[SyncDriverAdapterBase]": - """Provide a database session context manager. - - This is a passthrough to the underlying config's provide_session method - for convenient access to database sessions. - - Returns: - Context manager that yields a sync driver session. - """ - return self.config.provide_session() # type: ignore[union-attr,no-any-return] - - def provide_connection(self) -> "AbstractContextManager[Any]": - """Provide a database connection context manager. - - This is a passthrough to the underlying config's provide_connection method - for convenient access to database connections. - - Returns: - Context manager that yields a sync database connection. - """ - return self.config.provide_connection() # type: ignore[union-attr,no-any-return] - - def create_connection(self) -> "Any": - """Create and return a new database connection. - - This is a passthrough to the underlying config's create_connection method - for direct connection creation without context management. - - Returns: - A new sync database connection. - """ - return self.config.create_connection() # type: ignore[union-attr] - - -class AsyncDatabaseConfig(DatabaseConfig): - """Async-specific DatabaseConfig with better typing for get_request_session.""" - - def get_request_session(self, state: "State", scope: "Scope") -> "AsyncDriverAdapterBase": - """Get an async session instance from the current request. - - This method provides access to the database session that has been added to the request - scope, similar to Advanced Alchemy's provide_session method. It first looks for an - existing session in the request scope state, and if not found, creates a new one using - the connection from the scope. - - Args: - state: The Litestar application State object. - scope: The ASGI scope containing the request context. - - Returns: - An async driver session instance. - """ - session = super().get_request_session(state, scope) - return cast("AsyncDriverAdapterBase", session) - - def provide_session(self) -> "AbstractAsyncContextManager[AsyncDriverAdapterBase]": - """Provide a database session context manager. - - This is a passthrough to the underlying config's provide_session method - for convenient access to database sessions. - - Returns: - Context manager that yields an async driver session. - """ - return self.config.provide_session() # type: ignore[union-attr,no-any-return] - - def provide_connection(self) -> "AbstractAsyncContextManager[Any]": - """Provide a database connection context manager. - - This is a passthrough to the underlying config's provide_connection method - for convenient access to database connections. - - Returns: - Context manager that yields an async database connection. - """ - return self.config.provide_connection() # type: ignore[union-attr,no-any-return] - - async def create_connection(self) -> "Any": - """Create and return a new database connection. - - This is a passthrough to the underlying config's create_connection method - for direct connection creation without context management. - - Returns: - A new async database connection. - """ - return await self.config.create_connection() # type: ignore[union-attr] diff --git a/sqlspec/extensions/litestar/handlers.py b/sqlspec/extensions/litestar/handlers.py index f609867d..16b90bec 100644 --- a/sqlspec/extensions/litestar/handlers.py +++ b/sqlspec/extensions/litestar/handlers.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, cast from litestar.constants import HTTP_DISCONNECT, HTTP_RESPONSE_START, WEBSOCKET_CLOSE, WEBSOCKET_DISCONNECT +from litestar.params import Dependency from sqlspec.exceptions import ImproperConfigurationError from sqlspec.extensions.litestar._utils import ( @@ -37,11 +38,14 @@ ) -def manual_handler_maker(connection_scope_key: str) -> "Callable[[Message, Scope], Coroutine[Any, Any, None]]": +def manual_handler_maker( + connection_scope_key: str, is_async: bool = False +) -> "Callable[[Message, Scope], Coroutine[Any, Any, None]]": """Create handler for manual connection management. Args: connection_scope_key: The key used to store the connection in the ASGI scope. + is_async: Whether the database driver is async (uses direct await) or sync (uses ensure_async_). Returns: The handler callable. @@ -56,7 +60,9 @@ async def handler(message: "Message", scope: "Scope") -> None: """ connection = get_sqlspec_scope_state(scope, connection_scope_key) if connection and message["type"] in SESSION_TERMINUS_ASGI_EVENTS: - if hasattr(connection, "close") and callable(connection.close): + if is_async: + await connection.close() + else: await ensure_async_(connection.close)() delete_sqlspec_scope_state(scope, connection_scope_key) @@ -65,6 +71,7 @@ async def handler(message: "Message", scope: "Scope") -> None: def autocommit_handler_maker( connection_scope_key: str, + is_async: bool = False, commit_on_redirect: bool = False, extra_commit_statuses: "set[int] | None" = None, extra_rollback_statuses: "set[int] | None" = None, @@ -73,6 +80,7 @@ def autocommit_handler_maker( Args: connection_scope_key: The key used to store the connection in the ASGI scope. + is_async: Whether the database driver is async (uses direct await) or sync (uses ensure_async_). commit_on_redirect: Issue a commit when the response status is a redirect (3XX). extra_commit_statuses: A set of additional status codes that trigger a commit. extra_rollback_statuses: A set of additional status codes that trigger a rollback. @@ -108,13 +116,19 @@ async def handler(message: "Message", scope: "Scope") -> None: if (message["status"] in commit_range or message["status"] in extra_commit_statuses) and message[ "status" ] not in extra_rollback_statuses: - if hasattr(connection, "commit") and callable(connection.commit): + if is_async: + await connection.commit() + else: await ensure_async_(connection.commit)() - elif hasattr(connection, "rollback") and callable(connection.rollback): + elif is_async: + await connection.rollback() + else: await ensure_async_(connection.rollback)() finally: if connection and message["type"] in SESSION_TERMINUS_ASGI_EVENTS: - if hasattr(connection, "close") and callable(connection.close): + if is_async: + await connection.close() + else: await ensure_async_(connection.close)() delete_sqlspec_scope_state(scope, connection_scope_key) @@ -144,14 +158,23 @@ async def lifespan_handler(app: "Litestar") -> "AsyncGenerator[None, None]": Yields: Control to application during pool lifetime. """ - db_pool = await ensure_async_(config.create_pool)() + db_pool: Any + if config.is_async: + db_pool = await config.create_pool() + else: + db_pool = await ensure_async_(config.create_pool)() app.state.update({pool_key: db_pool}) try: yield finally: app.state.pop(pool_key, None) try: - await ensure_async_(config.close_pool)() + if config.is_async: + close_result = config.close_pool() + if close_result is not None: + await close_result + else: + await ensure_async_(config.close_pool)() except Exception as e: if app.logger: app.logger.warning("Error closing database pool for %s. Error: %s", pool_key, e) @@ -215,11 +238,11 @@ async def provide_connection(state: "State", scope: "Scope") -> "AsyncGenerator[ msg = f"Database pool with key '{pool_key}' not found. Cannot create a connection." raise ImproperConfigurationError(msg) - connection_cm = config.provide_connection(db_pool) + connection_cm: Any = config.provide_connection(db_pool) if not isinstance(connection_cm, AbstractAsyncContextManager): conn_instance: ConnectionT - if hasattr(connection_cm, "__await__"): + if inspect.isawaitable(connection_cm): conn_instance = await cast("Awaitable[ConnectionT]", connection_cm) else: conn_instance = cast("ConnectionT", connection_cm) @@ -256,8 +279,6 @@ async def provide_session(*args: Any, **kwargs: Any) -> "AsyncGenerator[DriverT, conn_type_annotation = config.connection_type - from litestar.params import Dependency - db_conn_param = inspect.Parameter( name=connection_dependency_key, kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, @@ -272,7 +293,7 @@ async def provide_session(*args: Any, **kwargs: Any) -> "AsyncGenerator[DriverT, provide_session.__signature__ = provider_signature # type: ignore[attr-defined] - if not hasattr(provide_session, "__annotations__") or provide_session.__annotations__ is None: + if provide_session.__annotations__ is None: provide_session.__annotations__ = {} provide_session.__annotations__[connection_dependency_key] = conn_type_annotation diff --git a/sqlspec/extensions/litestar/plugin.py b/sqlspec/extensions/litestar/plugin.py index 1a61e2e9..fec875ec 100644 --- a/sqlspec/extensions/litestar/plugin.py +++ b/sqlspec/extensions/litestar/plugin.py @@ -1,60 +1,190 @@ -from typing import TYPE_CHECKING, Any, Union, cast, overload +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Literal, NoReturn, cast, overload from litestar.di import Provide from litestar.plugins import CLIPlugin, InitPluginProtocol -from sqlspec.base import SQLSpec as SQLSpecBase -from sqlspec.config import AsyncConfigT, DatabaseConfigProtocol, DriverT, SyncConfigT +from sqlspec.base import SQLSpec +from sqlspec.config import ( + AsyncConfigT, + AsyncDatabaseConfig, + DatabaseConfigProtocol, + DriverT, + NoPoolAsyncConfig, + NoPoolSyncConfig, + SyncConfigT, + SyncDatabaseConfig, +) from sqlspec.exceptions import ImproperConfigurationError -from sqlspec.extensions.litestar.config import AsyncDatabaseConfig, DatabaseConfig, SyncDatabaseConfig +from sqlspec.extensions.litestar._utils import get_sqlspec_scope_state, set_sqlspec_scope_state +from sqlspec.extensions.litestar.handlers import ( + autocommit_handler_maker, + connection_provider_maker, + lifespan_handler_maker, + manual_handler_maker, + pool_provider_maker, + session_provider_maker, +) from sqlspec.typing import ConnectionT, PoolT from sqlspec.utils.logging import get_logger if TYPE_CHECKING: + from collections.abc import AsyncGenerator, Callable + from contextlib import AbstractAsyncContextManager + from click import Group + from litestar import Litestar from litestar.config.app import AppConfig from litestar.datastructures.state import State - from litestar.types import Scope + from litestar.types import BeforeMessageSendHookHandler, Scope from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase from sqlspec.loader import SQLFileLoader logger = get_logger("extensions.litestar") - -class SQLSpec(SQLSpecBase, InitPluginProtocol, CLIPlugin): +CommitMode = Literal["manual", "autocommit", "autocommit_include_redirect"] +DEFAULT_COMMIT_MODE: CommitMode = "manual" +DEFAULT_CONNECTION_KEY = "db_connection" +DEFAULT_POOL_KEY = "db_pool" +DEFAULT_SESSION_KEY = "db_session" + +__all__ = ( + "DEFAULT_COMMIT_MODE", + "DEFAULT_CONNECTION_KEY", + "DEFAULT_POOL_KEY", + "DEFAULT_SESSION_KEY", + "CommitMode", + "SQLSpecPlugin", +) + + +@dataclass +class _PluginConfigState: + """Internal state for each database configuration.""" + + config: "DatabaseConfigProtocol[Any, Any, Any]" + connection_key: str + pool_key: str + session_key: str + commit_mode: CommitMode + extra_commit_statuses: "set[int] | None" + extra_rollback_statuses: "set[int] | None" + enable_correlation_middleware: bool + connection_provider: "Callable[[State, Scope], AsyncGenerator[Any, None]]" = field(init=False) + pool_provider: "Callable[[State, Scope], Any]" = field(init=False) + session_provider: "Callable[..., AsyncGenerator[Any, None]]" = field(init=False) + before_send_handler: "BeforeMessageSendHookHandler" = field(init=False) + lifespan_handler: "Callable[[Litestar], AbstractAsyncContextManager[None]]" = field(init=False) + annotation: "type[DatabaseConfigProtocol[Any, Any, Any]]" = field(init=False) + + +class SQLSpecPlugin(InitPluginProtocol, CLIPlugin): """Litestar plugin for SQLSpec database integration.""" - __slots__ = ("_plugin_configs",) + __slots__ = ("_plugin_configs", "_sqlspec") - def __init__( - self, - config: Union["SyncConfigT", "AsyncConfigT", "DatabaseConfig", list["DatabaseConfig"]], - *, - loader: "SQLFileLoader | None" = None, - ) -> None: + def __init__(self, sqlspec: SQLSpec, *, loader: "SQLFileLoader | None" = None) -> None: """Initialize SQLSpec plugin. Args: - config: Database configuration for SQLSpec plugin. - loader: Optional SQL file loader instance. + sqlspec: Pre-configured SQLSpec instance with registered database configs. + loader: Optional SQL file loader instance (SQLSpec may already have one). """ - super().__init__(loader=loader) - if isinstance(config, DatabaseConfigProtocol): - self._plugin_configs: list[DatabaseConfig] = [DatabaseConfig(config=config)] # pyright: ignore - elif isinstance(config, DatabaseConfig): - self._plugin_configs = [config] + self._sqlspec = sqlspec + + self._plugin_configs: list[_PluginConfigState] = [] + for cfg in self._sqlspec.configs.values(): + config_union = cast( + "SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]", + cfg, + ) + settings = self._extract_litestar_settings(config_union) + state = self._create_config_state(config_union, settings) + self._plugin_configs.append(state) + + def _extract_litestar_settings( + self, + config: "SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]", + ) -> "dict[str, Any]": + """Extract Litestar settings from config.extension_config.""" + litestar_config = config.extension_config.get("litestar", {}) + + connection_key = litestar_config.get("connection_key", DEFAULT_CONNECTION_KEY) + pool_key = litestar_config.get("pool_key", DEFAULT_POOL_KEY) + session_key = litestar_config.get("session_key", DEFAULT_SESSION_KEY) + commit_mode = litestar_config.get("commit_mode", DEFAULT_COMMIT_MODE) + + if not config.supports_connection_pooling and pool_key == DEFAULT_POOL_KEY: + pool_key = f"_{DEFAULT_POOL_KEY}_{id(config)}" + + return { + "connection_key": connection_key, + "pool_key": pool_key, + "session_key": session_key, + "commit_mode": commit_mode, + "extra_commit_statuses": litestar_config.get("extra_commit_statuses"), + "extra_rollback_statuses": litestar_config.get("extra_rollback_statuses"), + "enable_correlation_middleware": litestar_config.get("enable_correlation_middleware", True), + } + + def _create_config_state( + self, + config: "SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]", + settings: "dict[str, Any]", + ) -> _PluginConfigState: + """Create plugin state with handlers for the given configuration.""" + state = _PluginConfigState( + config=config, + connection_key=settings["connection_key"], + pool_key=settings["pool_key"], + session_key=settings["session_key"], + commit_mode=settings["commit_mode"], + extra_commit_statuses=settings.get("extra_commit_statuses"), + extra_rollback_statuses=settings.get("extra_rollback_statuses"), + enable_correlation_middleware=settings["enable_correlation_middleware"], + ) + + self._setup_handlers(state) + return state + + def _setup_handlers(self, state: _PluginConfigState) -> None: + """Setup handlers for the plugin state.""" + connection_key = state.connection_key + pool_key = state.pool_key + commit_mode = state.commit_mode + config = state.config + is_async = config.is_async + + state.connection_provider = connection_provider_maker(config, pool_key, connection_key) + state.pool_provider = pool_provider_maker(config, pool_key) + state.session_provider = session_provider_maker(config, connection_key) + state.lifespan_handler = lifespan_handler_maker(config, pool_key) + + if commit_mode == "manual": + state.before_send_handler = manual_handler_maker(connection_key, is_async) else: - self._plugin_configs = config + commit_on_redirect = commit_mode == "autocommit_include_redirect" + state.before_send_handler = autocommit_handler_maker( + connection_key, is_async, commit_on_redirect, state.extra_commit_statuses, state.extra_rollback_statuses + ) @property - def config(self) -> "list[DatabaseConfig]": # pyright: ignore[reportInvalidTypeVarUse] - """Return the plugin configuration. + def config( + self, + ) -> "list[SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]]": + """Return the plugin configurations. Returns: List of database configurations. """ - return self._plugin_configs + return [ + cast( + "SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]", + state.config, + ) + for state in self._plugin_configs + ] def on_cli_init(self, cli: "Group") -> None: """Configure CLI commands for SQLSpec database operations. @@ -75,34 +205,31 @@ def on_app_init(self, app_config: "AppConfig") -> "AppConfig": Returns: The updated application configuration instance. """ - self._validate_dependency_keys() def store_sqlspec_in_state() -> None: app_config.state.sqlspec = self app_config.on_startup.append(store_sqlspec_in_state) - app_config.signature_types.extend( - [SQLSpec, ConnectionT, PoolT, DriverT, DatabaseConfig, DatabaseConfigProtocol, SyncConfigT, AsyncConfigT] - ) + app_config.signature_types.extend([SQLSpec, DatabaseConfigProtocol, SyncConfigT, AsyncConfigT]) - signature_namespace = {} + signature_namespace = {"ConnectionT": ConnectionT, "PoolT": PoolT, "DriverT": DriverT} - for c in self._plugin_configs: - c.annotation = self.add_config(c.config) - app_config.signature_types.append(c.annotation) - app_config.signature_types.append(c.config.connection_type) # type: ignore[union-attr] - app_config.signature_types.append(c.config.driver_type) # type: ignore[union-attr] + for state in self._plugin_configs: + state.annotation = type(state.config) + app_config.signature_types.append(state.annotation) + app_config.signature_types.append(state.config.connection_type) + app_config.signature_types.append(state.config.driver_type) - signature_namespace.update(c.config.get_signature_namespace()) # type: ignore[union-attr] + signature_namespace.update(state.config.get_signature_namespace()) # type: ignore[arg-type] - app_config.before_send.append(c.before_send_handler) - app_config.lifespan.append(c.lifespan_handler) # pyright: ignore[reportUnknownMemberType] + app_config.before_send.append(state.before_send_handler) + app_config.lifespan.append(state.lifespan_handler) app_config.dependencies.update( { - c.connection_key: Provide(c.connection_provider), - c.pool_key: Provide(c.pool_provider), - c.session_key: Provide(c.session_provider), + state.connection_key: Provide(state.connection_provider), + state.pool_key: Provide(state.pool_provider), + state.session_key: Provide(state.session_provider), } ) @@ -111,21 +238,30 @@ def store_sqlspec_in_state() -> None: return app_config - def get_annotations(self) -> "list[type[SyncConfigT | AsyncConfigT]]": # pyright: ignore[reportInvalidTypeVarUse] + def get_annotations( + self, + ) -> "list[type[SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]]]": """Return the list of annotations. Returns: List of annotations. """ - return [c.annotation for c in self.config] + return [ + cast( + "type[SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]]", + state.annotation, + ) + for state in self._plugin_configs + ] def get_annotation( - self, key: "str | SyncConfigT | AsyncConfigT | type[SyncConfigT | AsyncConfigT]" - ) -> "type[SyncConfigT | AsyncConfigT]": + self, + key: "str | SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any] | type[SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]]", + ) -> "type[SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]]": """Return the annotation for the given configuration. Args: - key: The configuration instance or key to lookup + key: The configuration instance or key to lookup. Raises: KeyError: If no configuration is found for the given key. @@ -133,51 +269,38 @@ def get_annotation( Returns: The annotation for the configuration. """ - for c in self.config: - # Check annotation only if it's been set (during on_app_init) - annotation_match = hasattr(c, "annotation") and key == c.annotation - if key == c.config or annotation_match or key in {c.connection_key, c.pool_key}: - if not hasattr(c, "annotation"): - msg = ( - "Annotation not set for configuration. Ensure the plugin has been initialized with on_app_init." - ) - raise AttributeError(msg) - return c.annotation + for state in self._plugin_configs: + if key in (state.config, state.annotation) or key in {state.connection_key, state.pool_key}: + return cast( + "type[SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]]", + state.annotation, + ) + msg = f"No configuration found for {key}" raise KeyError(msg) @overload - def get_config(self, name: "type[SyncConfigT]") -> "SyncConfigT": ... - - @overload - def get_config(self, name: "type[AsyncConfigT]") -> "AsyncConfigT": ... - - @overload - def get_config(self, name: str) -> "DatabaseConfig": ... + def get_config( + self, name: "type[SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any]]" + ) -> "SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any]": ... @overload - def get_config(self, name: "type[SyncDatabaseConfig]") -> "SyncDatabaseConfig": ... + def get_config( + self, name: "type[AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]]" + ) -> "AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]": ... @overload - def get_config(self, name: "type[AsyncDatabaseConfig]") -> "AsyncDatabaseConfig": ... - def get_config( - self, name: "type[DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]] | str | Any" - ) -> "DatabaseConfigProtocol[ConnectionT, PoolT, DriverT] | DatabaseConfig | SyncDatabaseConfig | AsyncDatabaseConfig": - """Get a configuration instance by name, supporting both base behavior and Litestar extensions. + self, name: str + ) -> "SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]": ... - This method extends the base get_config to support Litestar-specific lookup patterns - while maintaining compatibility with the base class signature. It supports lookup by - connection key, pool key, session key, config instance, or annotation type. + def get_config( + self, name: "type[DatabaseConfigProtocol[Any, Any, Any]] | str | Any" + ) -> "DatabaseConfigProtocol[Any, Any, Any]": + """Get a configuration instance by name. Args: - name: The configuration identifier - can be: - - Type annotation (base class behavior) - - connection_key (e.g., "auth_db_connection") - - pool_key (e.g., "analytics_db_pool") - - session_key (e.g., "reporting_db_session") - - config instance - - annotation type + name: The configuration identifier. Raises: KeyError: If no configuration is found for the given name. @@ -185,26 +308,14 @@ def get_config( Returns: The configuration instance for the specified name. """ - # First try base class behavior for type-based lookup - # Only call super() if name matches the expected base class types - if not isinstance(name, str): - try: - return super().get_config(name) # type: ignore[no-any-return] - except (KeyError, AttributeError): - # Fall back to Litestar-specific lookup patterns - pass - - # Litestar-specific lookups by string keys if isinstance(name, str): - for c in self.config: - if name in {c.connection_key, c.pool_key, c.session_key}: - return c # Return the DatabaseConfig wrapper for string lookups + for state in self._plugin_configs: + if name in {state.connection_key, state.pool_key, state.session_key}: + return cast("DatabaseConfigProtocol[Any, Any, Any]", state.config) # type: ignore[redundant-cast] - # Lookup by config instance or annotation - for c in self.config: - annotation_match = hasattr(c, "annotation") and name == c.annotation - if name == c.config or annotation_match: - return c.config # Return the underlying config for type-based lookups + for state in self._plugin_configs: + if name in (state.config, state.annotation): + return cast("DatabaseConfigProtocol[Any, Any, Any]", state.config) # type: ignore[redundant-cast] msg = f"No database configuration found for name '{name}'. Available keys: {self._get_available_keys()}" raise KeyError(msg) @@ -214,57 +325,44 @@ def provide_request_session( ) -> "SyncDriverAdapterBase | AsyncDriverAdapterBase": """Provide a database session for the specified configuration key from request scope. - This is a convenience method that combines get_config and get_request_session - into a single call, similar to Advanced Alchemy's provide_session pattern. - Args: - key: The configuration identifier (same as get_config) - state: The Litestar application State object - scope: The ASGI scope containing the request context + key: The configuration identifier (same as get_config). + state: The Litestar application State object. + scope: The ASGI scope containing the request context. Returns: - A driver session instance for the specified database configuration - - Example: - >>> sqlspec_plugin = connection.app.state.sqlspec - >>> # Direct session access by key - >>> auth_session = sqlspec_plugin.provide_request_session( - ... "auth_db", state, scope - ... ) - >>> analytics_session = sqlspec_plugin.provide_request_session( - ... "analytics_db", state, scope - ... ) + A driver session instance for the specified database configuration. """ - # Get DatabaseConfig wrapper for Litestar methods - db_config = self._get_database_config(key) - return db_config.get_request_session(state, scope) + plugin_state = self._get_plugin_state(key) + session_scope_key = f"{plugin_state.session_key}_instance" + + session = get_sqlspec_scope_state(scope, session_scope_key) + if session is not None: + return cast("SyncDriverAdapterBase | AsyncDriverAdapterBase", session) + + connection = get_sqlspec_scope_state(scope, plugin_state.connection_key) + if connection is None: + self._raise_missing_connection(plugin_state.connection_key) + + session = plugin_state.config.driver_type(connection=connection) + set_sqlspec_scope_state(scope, session_scope_key, session) + + return cast("SyncDriverAdapterBase | AsyncDriverAdapterBase", session) def provide_sync_request_session( self, key: "str | SyncConfigT | type[SyncConfigT]", state: "State", scope: "Scope" ) -> "SyncDriverAdapterBase": """Provide a sync database session for the specified configuration key from request scope. - This method provides better type hints for sync database sessions, ensuring the returned - session is properly typed as SyncDriverAdapterBase for better IDE support and type safety. - Args: - key: The sync configuration identifier - state: The Litestar application State object - scope: The ASGI scope containing the request context + key: The sync configuration identifier. + state: The Litestar application State object. + scope: The ASGI scope containing the request context. Returns: - A sync driver session instance for the specified database configuration - - Example: - >>> sqlspec_plugin = connection.app.state.sqlspec - >>> auth_session = sqlspec_plugin.provide_sync_request_session( - ... "auth_db", state, scope - ... ) - >>> # auth_session is now correctly typed as SyncDriverAdapterBase + A sync driver session instance for the specified database configuration. """ - # Get DatabaseConfig wrapper for Litestar methods - db_config = self._get_database_config(key) - session = db_config.get_request_session(state, scope) + session = self.provide_request_session(key, state, scope) return cast("SyncDriverAdapterBase", session) def provide_async_request_session( @@ -272,109 +370,88 @@ def provide_async_request_session( ) -> "AsyncDriverAdapterBase": """Provide an async database session for the specified configuration key from request scope. - This method provides better type hints for async database sessions, ensuring the returned - session is properly typed as AsyncDriverAdapterBase for better IDE support and type safety. - Args: - key: The async configuration identifier - state: The Litestar application State object - scope: The ASGI scope containing the request context + key: The async configuration identifier. + state: The Litestar application State object. + scope: The ASGI scope containing the request context. Returns: - An async driver session instance for the specified database configuration - - Example: - >>> sqlspec_plugin = connection.app.state.sqlspec - >>> auth_session = sqlspec_plugin.provide_async_request_session( - ... "auth_db", state, scope - ... ) - >>> # auth_session is now correctly typed as AsyncDriverAdapterBase + An async driver session instance for the specified database configuration. """ - # Get DatabaseConfig wrapper for Litestar methods - db_config = self._get_database_config(key) - session = db_config.get_request_session(state, scope) + session = self.provide_request_session(key, state, scope) return cast("AsyncDriverAdapterBase", session) def provide_request_connection( self, key: "str | SyncConfigT | AsyncConfigT | type[SyncConfigT | AsyncConfigT]", state: "State", scope: "Scope" - ) -> Any: + ) -> "Any": """Provide a database connection for the specified configuration key from request scope. - This is a convenience method that combines get_config and get_request_connection - into a single call. - Args: - key: The configuration identifier (same as get_config) - state: The Litestar application State object - scope: The ASGI scope containing the request context + key: The configuration identifier (same as get_config). + state: The Litestar application State object. + scope: The ASGI scope containing the request context. Returns: - A database connection instance for the specified database configuration - - Example: - >>> sqlspec_plugin = connection.app.state.sqlspec - >>> # Direct connection access by key - >>> auth_conn = sqlspec_plugin.provide_request_connection( - ... "auth_db", state, scope - ... ) - >>> analytics_conn = sqlspec_plugin.provide_request_connection( - ... "analytics_db", state, scope - ... ) + A database connection instance for the specified database configuration. """ - # Get DatabaseConfig wrapper for Litestar methods - db_config = self._get_database_config(key) - return db_config.get_request_connection(state, scope) - - def _get_database_config( - self, key: "str | SyncConfigT | AsyncConfigT | type[SyncConfigT | AsyncConfigT]" - ) -> DatabaseConfig: - """Get a DatabaseConfig wrapper instance by name. - - This is used internally by provide_request_session and provide_request_connection - to get the DatabaseConfig wrapper that has the request session methods. + plugin_state = self._get_plugin_state(key) + connection = get_sqlspec_scope_state(scope, plugin_state.connection_key) + if connection is None: + self._raise_missing_connection(plugin_state.connection_key) - Args: - key: The configuration identifier - - Returns: - The DatabaseConfig wrapper instance + return connection - Raises: - KeyError: If no configuration is found for the given key - """ - # For string keys, lookup by connection/pool/session keys + def _get_plugin_state( + self, key: "str | SyncConfigT | AsyncConfigT | type[SyncConfigT | AsyncConfigT]" + ) -> _PluginConfigState: + """Get plugin state for a configuration by key.""" if isinstance(key, str): - for c in self.config: - if key in {c.connection_key, c.pool_key, c.session_key}: - return c + for state in self._plugin_configs: + if key in {state.connection_key, state.pool_key, state.session_key}: + return state - # For other keys, lookup by config instance or annotation - for c in self.config: - annotation_match = hasattr(c, "annotation") and key == c.annotation - if key == c.config or annotation_match: - return c + for state in self._plugin_configs: + if key in (state.config, state.annotation): + return state - msg = f"No database configuration found for name '{key}'. Available keys: {self._get_available_keys()}" - raise KeyError(msg) + self._raise_config_not_found(key) + return None def _get_available_keys(self) -> "list[str]": """Get a list of all available configuration keys for error messages.""" keys = [] - for c in self.config: - keys.extend([c.connection_key, c.pool_key, c.session_key]) + for state in self._plugin_configs: + keys.extend([state.connection_key, state.pool_key, state.session_key]) return keys def _validate_dependency_keys(self) -> None: - """Validate that connection and pool keys are unique across configurations. + """Validate that connection and pool keys are unique across configurations.""" + connection_keys = [state.connection_key for state in self._plugin_configs] + pool_keys = [state.pool_key for state in self._plugin_configs] - Raises: - ImproperConfigurationError: If connection keys or pool keys are not unique. - """ - connection_keys = [c.connection_key for c in self.config] - pool_keys = [c.pool_key for c in self.config] if len(set(connection_keys)) != len(connection_keys): - msg = "When using multiple database configuration, each configuration must have a unique `connection_key`." - raise ImproperConfigurationError(detail=msg) + self._raise_duplicate_connection_keys() + if len(set(pool_keys)) != len(pool_keys): - msg = "When using multiple database configuration, each configuration must have a unique `pool_key`." - raise ImproperConfigurationError(detail=msg) + self._raise_duplicate_pool_keys() + + def _raise_missing_connection(self, connection_key: str) -> None: + """Raise error when connection is not found in scope.""" + msg = f"No database connection found in scope for key '{connection_key}'. " + msg += "Ensure the connection dependency is properly configured and available." + raise ImproperConfigurationError(detail=msg) + + def _raise_config_not_found(self, key: Any) -> NoReturn: + """Raise error when configuration is not found.""" + msg = f"No database configuration found for name '{key}'. Available keys: {self._get_available_keys()}" + raise KeyError(msg) + + def _raise_duplicate_connection_keys(self) -> None: + """Raise error when connection keys are not unique.""" + msg = "When using multiple database configuration, each configuration must have a unique `connection_key`." + raise ImproperConfigurationError(detail=msg) + + def _raise_duplicate_pool_keys(self) -> None: + """Raise error when pool keys are not unique.""" + msg = "When using multiple database configuration, each configuration must have a unique `pool_key`." + raise ImproperConfigurationError(detail=msg) diff --git a/tests/integration/test_adapters/test_aiosqlite/conftest.py b/tests/integration/test_adapters/test_aiosqlite/conftest.py index 1e595d85..fcd2be58 100644 --- a/tests/integration/test_adapters/test_aiosqlite/conftest.py +++ b/tests/integration/test_adapters/test_aiosqlite/conftest.py @@ -2,8 +2,9 @@ from __future__ import annotations +import os +import tempfile from collections.abc import AsyncGenerator -from uuid import uuid4 import pytest @@ -16,8 +17,7 @@ async def aiosqlite_session() -> AsyncGenerator[AiosqliteDriver, None]: """Create an aiosqlite session with test table.""" - unique_db = f"file:memdb{uuid4().hex}?mode=memory&cache=shared" - config = AiosqliteConfig(pool_config={"database": unique_db}) + config = AiosqliteConfig() try: async with config.provide_session() as session: @@ -49,8 +49,7 @@ async def aiosqlite_session() -> AsyncGenerator[AiosqliteDriver, None]: @pytest.fixture async def aiosqlite_config() -> AsyncGenerator[AiosqliteConfig, None]: """Provide AiosqliteConfig for connection tests.""" - unique_db = f"file:memdb{uuid4().hex}?mode=memory&cache=shared" - config = AiosqliteConfig(pool_config={"database": unique_db}) + config = AiosqliteConfig() try: yield config @@ -61,9 +60,6 @@ async def aiosqlite_config() -> AsyncGenerator[AiosqliteConfig, None]: @pytest.fixture async def aiosqlite_config_file() -> AsyncGenerator[AiosqliteConfig, None]: """Provide AiosqliteConfig with temporary file database for concurrent access tests.""" - import os - import tempfile - with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: db_path = tmp.name diff --git a/tests/integration/test_adapters/test_aiosqlite/test_connection.py b/tests/integration/test_adapters/test_aiosqlite/test_connection.py index 1a21644c..3594f239 100644 --- a/tests/integration/test_adapters/test_aiosqlite/test_connection.py +++ b/tests/integration/test_adapters/test_aiosqlite/test_connection.py @@ -3,6 +3,9 @@ from __future__ import annotations +from pathlib import Path +from uuid import uuid4 + import pytest from sqlspec.adapters.aiosqlite import AiosqliteConfig, AiosqliteDriver @@ -89,10 +92,7 @@ async def test_connection_with_transactions(aiosqlite_config: AiosqliteConfig) - async def test_connection_context_manager_cleanup() -> None: """Test proper cleanup of connection context manager.""" - from uuid import uuid4 - - unique_db = f"file:memdb{uuid4().hex}?mode=memory&cache=shared" - config = AiosqliteConfig(pool_config={"database": unique_db}) + config = AiosqliteConfig() driver_ref = None try: @@ -113,15 +113,11 @@ async def test_connection_context_manager_cleanup() -> None: async def test_provide_connection_direct() -> None: """Test direct connection provision without session wrapper.""" - from uuid import uuid4 - - unique_db = f"file:memdb{uuid4().hex}?mode=memory&cache=shared" - config = AiosqliteConfig(pool_config={"database": unique_db}) + config = AiosqliteConfig() try: - if hasattr(config, "provide_connection"): - async with config.provide_connection() as conn: - assert conn is not None + async with config.provide_connection() as conn: + assert conn is not None async with config.provide_session() as driver: assert driver.connection is not None @@ -134,22 +130,17 @@ async def test_provide_connection_direct() -> None: await config.close_pool() -async def test_config_with_pool_config() -> None: +async def test_config_with_pool_config(tmp_path: Path) -> None: """Test that AiosqliteConfig correctly accepts pool_config parameter.""" - from uuid import uuid4 - pool_config = { - "database": f"file:test_{uuid4().hex}.db?mode=memory&cache=shared", - "timeout": 10.0, - "isolation_level": None, - "check_same_thread": False, - } + db_path = tmp_path / f"test_{uuid4().hex}.db" + pool_config = {"database": str(db_path), "timeout": 10.0, "isolation_level": None, "check_same_thread": False} config = AiosqliteConfig(pool_config=pool_config) try: connection_config = config._get_connection_config_dict() - assert "test_" in connection_config["database"] + assert connection_config["database"] == str(db_path) assert connection_config["timeout"] == 10.0 assert connection_config["isolation_level"] is None @@ -165,20 +156,19 @@ async def test_config_with_pool_config() -> None: await config.close_pool() -async def test_config_with_kwargs_override() -> None: +async def test_config_with_kwargs_override(tmp_path: Path) -> None: """Test that kwargs properly override pool_config values.""" - from uuid import uuid4 pool_config = {"database": "base.db", "timeout": 5.0} - unique_db = f"file:override_{uuid4().hex}.db?mode=memory&cache=shared" + db_path = tmp_path / f"override_{uuid4().hex}.db" # Override pool_config with specific test values - test_pool_config = {**pool_config, "database": unique_db, "timeout": 15.0} + test_pool_config = {**pool_config, "database": str(db_path), "timeout": 15.0} config = AiosqliteConfig(pool_config=test_pool_config) try: connection_config = config._get_connection_config_dict() - assert connection_config["database"] == unique_db + assert connection_config["database"] == str(db_path) assert connection_config["timeout"] == 15.0 async with config.provide_session() as driver: @@ -228,16 +218,17 @@ async def test_config_default_database() -> None: await config.close_pool() -async def test_config_parameter_preservation() -> None: +async def test_config_parameter_preservation(tmp_path: Path) -> None: """Test that aiosqlite config properly preserves parameters.""" - pool_config = {"database": "parameter_test.db", "isolation_level": None, "cached_statements": 100} + db_path = tmp_path / "parameter_test.db" + pool_config = {"database": str(db_path), "isolation_level": None, "cached_statements": 100} config = AiosqliteConfig(pool_config=pool_config) try: connection_config = config._get_connection_config_dict() - assert connection_config["database"] == "parameter_test.db" + assert connection_config["database"] == str(db_path) assert connection_config["isolation_level"] is None assert connection_config["cached_statements"] == 100 diff --git a/tests/integration/test_adapters/test_aiosqlite/test_exceptions.py b/tests/integration/test_adapters/test_aiosqlite/test_exceptions.py index 127d0ea9..240b184f 100644 --- a/tests/integration/test_adapters/test_aiosqlite/test_exceptions.py +++ b/tests/integration/test_adapters/test_aiosqlite/test_exceptions.py @@ -1,7 +1,6 @@ """Exception handling integration tests for aiosqlite adapter.""" from collections.abc import AsyncGenerator -from uuid import uuid4 import pytest @@ -20,8 +19,7 @@ @pytest.fixture async def aiosqlite_exception_session() -> AsyncGenerator[AiosqliteDriver, None]: """Create an aiosqlite session for exception testing.""" - unique_db = f"file:memdb{uuid4().hex}?mode=memory&cache=shared" - config = AiosqliteConfig(pool_config={"database": unique_db}) + config = AiosqliteConfig() try: async with config.provide_session() as session: diff --git a/tests/integration/test_adapters/test_duckdb/test_connection.py b/tests/integration/test_adapters/test_duckdb/test_connection.py index 906e968e..e9e879e3 100644 --- a/tests/integration/test_adapters/test_duckdb/test_connection.py +++ b/tests/integration/test_adapters/test_duckdb/test_connection.py @@ -1,7 +1,12 @@ # pyright: reportPrivateImportUsage = false, reportPrivateUsage = false """Test DuckDB connection configuration.""" +import os +import tempfile +import time +from pathlib import Path from typing import Any +from uuid import uuid4 import pytest @@ -13,7 +18,6 @@ def create_permissive_config(**kwargs: Any) -> DuckDBConfig: """Create a DuckDB config with permissive SQL settings.""" - import uuid connection_config = kwargs.pop("connection_config", {}) @@ -35,7 +39,7 @@ def create_permissive_config(**kwargs: Any) -> DuckDBConfig: if "database" not in connection_config: # Use a unique memory database identifier to avoid configuration conflicts - connection_config["database"] = f":memory:{uuid.uuid4().hex}" + connection_config["database"] = f":memory:{uuid4().hex}" kwargs["pool_config"] = connection_config return DuckDBConfig(**kwargs) @@ -145,9 +149,6 @@ def connection_hook(conn: DuckDBConnection) -> None: def test_connection_read_only_mode() -> None: """Test DuckDB connection in read-only mode.""" - import os - import tempfile - import time temp_fd, temp_db_path = tempfile.mkstemp(suffix=".duckdb") os.close(temp_fd) @@ -238,16 +239,17 @@ def test_multiple_concurrent_connections() -> None: pass -def test_config_with_pool_config_parameter() -> None: +def test_config_with_pool_config_parameter(tmp_path: Path) -> None: """Test that DuckDBConfig correctly accepts pool_config parameter.""" - pool_config = {"database": "test.duckdb", "memory_limit": "256MB", "threads": 4} + db_path = tmp_path / "test.duckdb" + pool_config = {"database": str(db_path), "memory_limit": "256MB", "threads": 4} config = DuckDBConfig(pool_config=pool_config) try: connection_config = config._get_connection_config_dict() - assert connection_config["database"] == "test.duckdb" + assert connection_config["database"] == str(db_path) assert connection_config["memory_limit"] == "256MB" assert connection_config["threads"] == 4 @@ -314,11 +316,12 @@ def test_config_default_database_shared() -> None: config._close_pool() -def test_config_consistency_with_other_adapters() -> None: +def test_config_consistency_with_other_adapters(tmp_path: Path) -> None: """Test that DuckDB config behaves consistently with SQLite/aiosqlite.""" + db_path = tmp_path / "consistency_test.duckdb" pool_config = { - "database": "consistency_test.duckdb", + "database": str(db_path), "memory_limit": "512MB", "threads": 2, "pool_min_size": 1, @@ -329,7 +332,7 @@ def test_config_consistency_with_other_adapters() -> None: try: connection_config = config._get_connection_config_dict() - assert connection_config["database"] == "consistency_test.duckdb" + assert connection_config["database"] == str(db_path) assert connection_config["memory_limit"] == "512MB" assert connection_config["threads"] == 2 diff --git a/tests/integration/test_adapters/test_sqlite/test_pooling.py b/tests/integration/test_adapters/test_sqlite/test_pooling.py index 1cad61ed..c637e4cb 100644 --- a/tests/integration/test_adapters/test_sqlite/test_pooling.py +++ b/tests/integration/test_adapters/test_sqlite/test_pooling.py @@ -1,6 +1,8 @@ # pyright: reportPrivateImportUsage = false, reportPrivateUsage = false """Integration tests for SQLite connection pooling.""" +from pathlib import Path + import pytest from sqlspec.adapters.sqlite.config import SqliteConfig @@ -211,16 +213,17 @@ def test_pool_transaction_rollback(sqlite_config_shared_memory: SqliteConfig) -> config.close_pool() -def test_config_with_pool_config_parameter() -> None: +def test_config_with_pool_config_parameter(tmp_path: Path) -> None: """Test that SqliteConfig correctly accepts pool_config parameter.""" - pool_config = {"database": "test.sqlite", "timeout": 10.0, "check_same_thread": False} + db_path = tmp_path / "test.sqlite" + pool_config = {"database": str(db_path), "timeout": 10.0, "check_same_thread": False} config = SqliteConfig(pool_config=pool_config) try: connection_config = config._get_connection_config_dict() - assert connection_config["database"] == "test.sqlite" + assert connection_config["database"] == str(db_path) assert connection_config["timeout"] == 10.0 assert connection_config["check_same_thread"] is False diff --git a/tests/unit/test_extensions/test_litestar/test_config.py b/tests/unit/test_extensions/test_litestar/test_config.py deleted file mode 100644 index bf24b8f4..00000000 --- a/tests/unit/test_extensions/test_litestar/test_config.py +++ /dev/null @@ -1,482 +0,0 @@ -"""Test SQLSpec Litestar configuration extensions.""" - -from typing import TYPE_CHECKING, Any, cast -from unittest.mock import MagicMock - -import pytest - -from sqlspec.adapters.sqlite.config import SqliteConfig - -if TYPE_CHECKING: - from litestar.types import Scope -from sqlspec.exceptions import ImproperConfigurationError -from sqlspec.extensions.litestar._utils import set_sqlspec_scope_state -from sqlspec.extensions.litestar.config import AsyncDatabaseConfig, DatabaseConfig, SyncDatabaseConfig -from sqlspec.extensions.litestar.plugin import SQLSpec - - -def test_get_request_session_with_existing_session() -> None: - """Test get_request_session returns existing session from scope.""" - # Arrange - sqlite_config = SqliteConfig(pool_config={"database": ":memory:"}) - db_config = DatabaseConfig(config=sqlite_config) - - state = MagicMock() - scope = cast("Scope", {}) - - # Mock existing session in scope - session_scope_key = f"{db_config.session_key}_instance" - expected_session = MagicMock() - set_sqlspec_scope_state(scope, session_scope_key, expected_session) - - # Act - result = db_config.get_request_session(state, scope) - - # Assert - assert result is expected_session - - -def test_get_request_session_creates_new_session() -> None: - """Test get_request_session creates new session when none exists.""" - # Arrange - sqlite_config = SqliteConfig(pool_config={"database": ":memory:"}) - db_config = DatabaseConfig(config=sqlite_config) - - state = MagicMock() - scope = cast("Scope", {}) - - # Mock connection in scope - mock_connection = MagicMock() - set_sqlspec_scope_state(scope, db_config.connection_key, mock_connection) - - # Act - result = db_config.get_request_session(state, scope) - - # Assert - assert result is not None - # Verify the session was created with the connection - assert hasattr(result, "connection") - - -def test_get_request_session_raises_when_no_connection() -> None: - """Test get_request_session raises error when no connection in scope.""" - # Arrange - sqlite_config = SqliteConfig(pool_config={"database": ":memory:"}) - db_config = DatabaseConfig(config=sqlite_config) - - state = MagicMock() - scope = cast("Scope", {}) # No connection in scope - - # Act & Assert - with pytest.raises(ImproperConfigurationError) as exc_info: - db_config.get_request_session(state, scope) - - assert "No database connection found in scope" in str(exc_info.value) - assert db_config.connection_key in str(exc_info.value) - - -def test_get_request_connection_returns_connection() -> None: - """Test get_request_connection returns connection from scope.""" - # Arrange - sqlite_config = SqliteConfig(pool_config={"database": ":memory:"}) - db_config = DatabaseConfig(config=sqlite_config) - - state = MagicMock() - scope = cast("Scope", {}) - - # Mock connection in scope - expected_connection = MagicMock() - set_sqlspec_scope_state(scope, db_config.connection_key, expected_connection) - - # Act - result = db_config.get_request_connection(state, scope) - - # Assert - assert result is expected_connection - - -def test_get_request_connection_raises_when_no_connection() -> None: - """Test get_request_connection raises error when no connection in scope.""" - # Arrange - sqlite_config = SqliteConfig(pool_config={"database": ":memory:"}) - db_config = DatabaseConfig(config=sqlite_config) - - state = MagicMock() - scope = cast("Scope", {}) # No connection in scope - - # Act & Assert - with pytest.raises(ImproperConfigurationError) as exc_info: - db_config.get_request_connection(state, scope) - - assert "No database connection found in scope" in str(exc_info.value) - assert db_config.connection_key in str(exc_info.value) - - -def test_get_request_session_caches_session_in_scope() -> None: - """Test get_request_session stores created session in scope for reuse.""" - # Arrange - sqlite_config = SqliteConfig(pool_config={"database": ":memory:"}) - db_config = DatabaseConfig(config=sqlite_config) - - state = MagicMock() - scope = cast("Scope", {}) - - # Mock connection in scope - mock_connection = MagicMock() - set_sqlspec_scope_state(scope, db_config.connection_key, mock_connection) - - # Act - call twice - result1 = db_config.get_request_session(state, scope) - result2 = db_config.get_request_session(state, scope) - - # Assert - should return the same cached session - assert result1 is result2 - - -def test_database_config_provides_both_methods() -> None: - """Test DatabaseConfig exposes both get_request_session and get_request_connection methods.""" - # Arrange - sqlite_config = SqliteConfig(pool_config={"database": ":memory:"}) - db_config = DatabaseConfig(config=sqlite_config) - - # Assert - assert hasattr(db_config, "get_request_session") - assert callable(db_config.get_request_session) - assert hasattr(db_config, "get_request_connection") - assert callable(db_config.get_request_connection) - - -def test_sqlspec_plugin_get_config_by_connection_key() -> None: - """Test SQLSpec plugin get_config method with connection key.""" - # Arrange - sqlite_config = SqliteConfig(pool_config={"database": ":memory:"}) - auth_db_config = DatabaseConfig(config=sqlite_config, connection_key="auth_db_connection") - - plugin = SQLSpec(config=auth_db_config) - - # Act - result = plugin.get_config("auth_db_connection") - - # Assert - assert result is auth_db_config - - -def test_sqlspec_plugin_get_config_by_pool_key() -> None: - """Test SQLSpec plugin get_config method with pool key.""" - # Arrange - sqlite_config = SqliteConfig(pool_config={"database": ":memory:"}) - analytics_db_config = DatabaseConfig(config=sqlite_config, pool_key="analytics_db_pool") - - plugin = SQLSpec(config=analytics_db_config) - - # Act - result = plugin.get_config("analytics_db_pool") - - # Assert - assert result is analytics_db_config - - -def test_sqlspec_plugin_get_config_by_session_key() -> None: - """Test SQLSpec plugin get_config method with session key.""" - # Arrange - sqlite_config = SqliteConfig(pool_config={"database": ":memory:"}) - reporting_db_config = DatabaseConfig(config=sqlite_config, session_key="reporting_db_session") - - plugin = SQLSpec(config=reporting_db_config) - - # Act - result = plugin.get_config("reporting_db_session") - - # Assert - assert result is reporting_db_config - - -def test_sqlspec_plugin_get_config_with_multiple_configs() -> None: - """Test SQLSpec plugin get_config method with multiple database configurations.""" - # Arrange - auth_sqlite_config = SqliteConfig(pool_config={"database": ":memory:"}) - auth_db_config = DatabaseConfig(config=auth_sqlite_config, connection_key="auth_db_connection") - - analytics_sqlite_config = SqliteConfig(pool_config={"database": ":memory:"}) - analytics_db_config = DatabaseConfig(config=analytics_sqlite_config, connection_key="analytics_db_connection") - - plugin = SQLSpec(config=[auth_db_config, analytics_db_config]) - - # Act & Assert - auth_result = plugin.get_config("auth_db_connection") - analytics_result = plugin.get_config("analytics_db_connection") - - assert auth_result is auth_db_config - assert analytics_result is analytics_db_config - - -def test_sqlspec_plugin_get_config_raises_keyerror_for_unknown_key() -> None: - """Test SQLSpec plugin get_config raises KeyError for unknown key.""" - # Arrange - sqlite_config = SqliteConfig(pool_config={"database": ":memory:"}) - db_config = DatabaseConfig(config=sqlite_config) - - plugin = SQLSpec(config=db_config) - - # Act & Assert - with pytest.raises(KeyError) as exc_info: - plugin.get_config("nonexistent_key") - - assert "No database configuration found for name 'nonexistent_key'" in str(exc_info.value) - assert "Available keys:" in str(exc_info.value) - - -def test_sqlspec_plugin_provide_request_session() -> None: - """Test SQLSpec plugin provide_request_session method.""" - # Arrange - sqlite_config = SqliteConfig(pool_config={"database": ":memory:"}) - db_config = DatabaseConfig(config=sqlite_config, connection_key="test_db_connection") - - plugin = SQLSpec(config=db_config) - - state = MagicMock() - scope = cast("Scope", {}) - - # Mock connection in scope - mock_connection = MagicMock() - set_sqlspec_scope_state(scope, db_config.connection_key, mock_connection) - - # Act - result: Any = plugin.provide_request_session("test_db_connection", state, scope) - - # Assert - assert result is not None - assert hasattr(result, "connection") - - -def test_sqlspec_plugin_provide_request_connection() -> None: - """Test SQLSpec plugin provide_request_connection method.""" - # Arrange - sqlite_config = SqliteConfig(pool_config={"database": ":memory:"}) - db_config = DatabaseConfig(config=sqlite_config, connection_key="test_db_connection") - - plugin = SQLSpec(config=db_config) - - state = MagicMock() - scope = cast("Scope", {}) - - # Mock connection in scope - expected_connection = MagicMock() - set_sqlspec_scope_state(scope, db_config.connection_key, expected_connection) - - # Act - result: Any = plugin.provide_request_connection("test_db_connection", state, scope) - - # Assert - assert result is expected_connection - - -def test_sqlspec_plugin_provide_request_session_raises_keyerror_for_unknown_key() -> None: - """Test SQLSpec plugin provide_request_session raises KeyError for unknown key.""" - # Arrange - sqlite_config = SqliteConfig(pool_config={"database": ":memory:"}) - db_config = DatabaseConfig(config=sqlite_config) - - plugin = SQLSpec(config=db_config) - - state = MagicMock() - scope = cast("Scope", {}) - - # Act & Assert - with pytest.raises(KeyError): - plugin.provide_request_session("nonexistent_key", state, scope) - - -def test_sqlspec_plugin_provide_request_connection_raises_keyerror_for_unknown_key() -> None: - """Test SQLSpec plugin provide_request_connection raises KeyError for unknown key.""" - # Arrange - sqlite_config = SqliteConfig(pool_config={"database": ":memory:"}) - db_config = DatabaseConfig(config=sqlite_config) - - plugin = SQLSpec(config=db_config) - - state = MagicMock() - scope = cast("Scope", {}) - - # Act & Assert - with pytest.raises(KeyError): - plugin.provide_request_connection("nonexistent_key", state, scope) - - -def test_sync_database_config_returns_sync_driver_type() -> None: - """Test SyncDatabaseConfig.get_request_session returns SyncDriverAdapterBase type.""" - # Arrange - - sqlite_config = SqliteConfig(pool_config={"database": ":memory:"}) - sync_db_config = SyncDatabaseConfig(config=sqlite_config) - - state = MagicMock() - scope = cast("Scope", {}) - - # Mock connection in scope - mock_connection = MagicMock() - set_sqlspec_scope_state(scope, sync_db_config.connection_key, mock_connection) - - # Act - result = sync_db_config.get_request_session(state, scope) - - # Assert - assert result is not None - # The type checker should now know this is SyncDriverAdapterBase - assert hasattr(result, "execute") # Basic driver interface check - - -def test_async_database_config_returns_async_driver_type() -> None: - """Test AsyncDatabaseConfig.get_request_session returns AsyncDriverAdapterBase type.""" - # Arrange - using aiosqlite for async example - from sqlspec.adapters.aiosqlite.config import AiosqliteConfig - - aiosqlite_config = AiosqliteConfig(pool_config={"database": ":memory:"}) - async_db_config = AsyncDatabaseConfig(config=aiosqlite_config) - - state = MagicMock() - scope = cast("Scope", {}) - - # Mock connection in scope - mock_connection = MagicMock() - set_sqlspec_scope_state(scope, async_db_config.connection_key, mock_connection) - - # Act - result = async_db_config.get_request_session(state, scope) - - # Assert - assert result is not None - # The type checker should now know this is AsyncDriverAdapterBase - assert hasattr(result, "execute") # Basic driver interface check - - -def test_specialized_configs_inherit_from_base_config() -> None: - """Test that specialized configs inherit all functionality from DatabaseConfig.""" - # Arrange - sqlite_config = SqliteConfig(pool_config={"database": ":memory:"}) - sync_config = SyncDatabaseConfig(config=sqlite_config) - - from sqlspec.adapters.aiosqlite.config import AiosqliteConfig - - aiosqlite_config = AiosqliteConfig(pool_config={"database": ":memory:"}) - async_config = AsyncDatabaseConfig(config=aiosqlite_config) - - # Assert - Should have all the same attributes as base DatabaseConfig - base_attrs = [ - "connection_key", - "pool_key", - "session_key", - "commit_mode", - "get_request_session", - "get_request_connection", - ] - - for attr in base_attrs: - assert hasattr(sync_config, attr), f"SyncDatabaseConfig missing {attr}" - assert hasattr(async_config, attr), f"AsyncDatabaseConfig missing {attr}" - - -def test_specialized_configs_work_with_plugin() -> None: - """Test that SQLSpec plugin works with specialized database configs.""" - # Arrange - sqlite_config = SqliteConfig(pool_config={"database": ":memory:"}) - sync_db_config = SyncDatabaseConfig(config=sqlite_config, connection_key="sync_db") - - from sqlspec.adapters.aiosqlite.config import AiosqliteConfig - - aiosqlite_config = AiosqliteConfig(pool_config={"database": ":memory:"}) - async_db_config = AsyncDatabaseConfig(config=aiosqlite_config, connection_key="async_db") - - plugin = SQLSpec(config=[sync_db_config, async_db_config]) - - # Act & Assert - sync_config = plugin.get_config("sync_db") - async_config = plugin.get_config("async_db") - - assert isinstance(sync_config, SyncDatabaseConfig) - assert isinstance(async_config, AsyncDatabaseConfig) - assert sync_config is sync_db_config - assert async_config is async_db_config - - -def test_sqlspec_plugin_provide_sync_request_session() -> None: - """Test SQLSpec plugin provide_sync_request_session method returns properly typed session.""" - # Arrange - sqlite_config = SqliteConfig(pool_config={"database": ":memory:"}) - db_config = DatabaseConfig(config=sqlite_config, connection_key="sync_db_connection") - - plugin = SQLSpec(config=db_config) - - state = MagicMock() - scope = cast("Scope", {}) - - # Mock connection in scope - mock_connection = MagicMock() - set_sqlspec_scope_state(scope, db_config.connection_key, mock_connection) - - # Act - result: Any = plugin.provide_sync_request_session("sync_db_connection", state, scope) - - # Assert - assert result is not None - assert hasattr(result, "connection") - # The returned type should be SyncDriverAdapterBase according to type hints - - -def test_sqlspec_plugin_provide_async_request_session() -> None: - """Test SQLSpec plugin provide_async_request_session method returns properly typed session.""" - # Arrange - from sqlspec.adapters.aiosqlite.config import AiosqliteConfig - - aiosqlite_config = AiosqliteConfig(pool_config={"database": ":memory:"}) - db_config = DatabaseConfig(config=aiosqlite_config, connection_key="async_db_connection") - - plugin = SQLSpec(config=db_config) - - state = MagicMock() - scope = cast("Scope", {}) - - # Mock connection in scope - mock_connection = MagicMock() - set_sqlspec_scope_state(scope, db_config.connection_key, mock_connection) - - # Act - result: Any = plugin.provide_async_request_session("async_db_connection", state, scope) - - # Assert - assert result is not None - assert hasattr(result, "connection") - # The returned type should be AsyncDriverAdapterBase according to type hints - - -def test_sqlspec_plugin_provide_typed_session_methods_with_multiple_configs() -> None: - """Test typed session methods work with multiple database configurations.""" - # Arrange - sync_sqlite_config = SqliteConfig(pool_config={"database": ":memory:"}) - sync_db_config = DatabaseConfig(config=sync_sqlite_config, connection_key="sync_db") - - from sqlspec.adapters.aiosqlite.config import AiosqliteConfig - - async_sqlite_config = AiosqliteConfig(pool_config={"database": ":memory:"}) - async_db_config = DatabaseConfig(config=async_sqlite_config, connection_key="async_db") - - plugin = SQLSpec(config=[sync_db_config, async_db_config]) - - state = MagicMock() - scope = cast("Scope", {}) - - # Mock connections in scope - mock_sync_connection = MagicMock() - mock_async_connection = MagicMock() - set_sqlspec_scope_state(scope, sync_db_config.connection_key, mock_sync_connection) - set_sqlspec_scope_state(scope, async_db_config.connection_key, mock_async_connection) - - # Act & Assert - sync session - sync_result: Any = plugin.provide_sync_request_session("sync_db", state, scope) - assert sync_result is not None - assert hasattr(sync_result, "connection") - - # Act & Assert - async session - async_result: Any = plugin.provide_async_request_session("async_db", state, scope) - assert async_result is not None - assert hasattr(async_result, "connection") diff --git a/tests/unit/test_extensions/test_litestar/test_handlers.py b/tests/unit/test_extensions/test_litestar/test_handlers.py new file mode 100644 index 00000000..518b7961 --- /dev/null +++ b/tests/unit/test_extensions/test_litestar/test_handlers.py @@ -0,0 +1,290 @@ +"""Test handlers for SQLSpec Litestar extension.""" + +from typing import TYPE_CHECKING, Any, cast +from unittest.mock import AsyncMock, MagicMock + +import pytest +from litestar.constants import HTTP_RESPONSE_START + +from sqlspec.adapters.aiosqlite.config import AiosqliteConfig +from sqlspec.exceptions import ImproperConfigurationError +from sqlspec.extensions.litestar._utils import get_sqlspec_scope_state, set_sqlspec_scope_state +from sqlspec.extensions.litestar.handlers import ( + autocommit_handler_maker, + connection_provider_maker, + lifespan_handler_maker, + manual_handler_maker, + pool_provider_maker, + session_provider_maker, +) + +if TYPE_CHECKING: + from litestar.types import Message, Scope + + +@pytest.mark.asyncio +async def test_async_manual_handler_closes_connection() -> None: + """Test async manual handler closes connection on terminus event.""" + connection_key = "test_connection" + handler = manual_handler_maker(connection_key, is_async=True) + + mock_connection = AsyncMock() + mock_connection.close = AsyncMock() + + scope = cast("Scope", {}) + set_sqlspec_scope_state(scope, connection_key, mock_connection) + + message = cast("Message", {"type": HTTP_RESPONSE_START, "status": 200}) + + await handler(message, scope) + + mock_connection.close.assert_awaited_once() + assert get_sqlspec_scope_state(scope, connection_key) is None + + +@pytest.mark.asyncio +async def test_async_manual_handler_ignores_non_terminus_events() -> None: + """Test async manual handler ignores non-terminus events.""" + connection_key = "test_connection" + handler = manual_handler_maker(connection_key, is_async=True) + + mock_connection = AsyncMock() + mock_connection.close = AsyncMock() + + scope = cast("Scope", {}) + set_sqlspec_scope_state(scope, connection_key, mock_connection) + + message = cast("Message", {"type": "http.request"}) + + await handler(message, scope) + + mock_connection.close.assert_not_awaited() + assert get_sqlspec_scope_state(scope, connection_key) is mock_connection + + +@pytest.mark.asyncio +async def test_async_autocommit_handler_commits_on_success() -> None: + """Test async autocommit handler commits on 2xx status.""" + connection_key = "test_connection" + handler = autocommit_handler_maker(connection_key, is_async=True) + + mock_connection = AsyncMock() + mock_connection.commit = AsyncMock() + mock_connection.rollback = AsyncMock() + mock_connection.close = AsyncMock() + + scope = cast("Scope", {}) + set_sqlspec_scope_state(scope, connection_key, mock_connection) + + message = cast("Message", {"type": HTTP_RESPONSE_START, "status": 200}) + + await handler(message, scope) + + mock_connection.commit.assert_awaited_once() + mock_connection.rollback.assert_not_awaited() + mock_connection.close.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_async_autocommit_handler_rolls_back_on_error() -> None: + """Test async autocommit handler rolls back on 4xx/5xx status.""" + connection_key = "test_connection" + handler = autocommit_handler_maker(connection_key, is_async=True) + + mock_connection = AsyncMock() + mock_connection.commit = AsyncMock() + mock_connection.rollback = AsyncMock() + mock_connection.close = AsyncMock() + + scope = cast("Scope", {}) + set_sqlspec_scope_state(scope, connection_key, mock_connection) + + message = cast("Message", {"type": HTTP_RESPONSE_START, "status": 500}) + + await handler(message, scope) + + mock_connection.commit.assert_not_awaited() + mock_connection.rollback.assert_awaited_once() + mock_connection.close.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_async_autocommit_handler_with_redirect_commit() -> None: + """Test async autocommit handler commits on 3xx when enabled.""" + connection_key = "test_connection" + handler = autocommit_handler_maker(connection_key, is_async=True, commit_on_redirect=True) + + mock_connection = AsyncMock() + mock_connection.commit = AsyncMock() + mock_connection.rollback = AsyncMock() + + scope = cast("Scope", {}) + set_sqlspec_scope_state(scope, connection_key, mock_connection) + + message = cast("Message", {"type": HTTP_RESPONSE_START, "status": 301}) + + await handler(message, scope) + + mock_connection.commit.assert_awaited_once() + mock_connection.rollback.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_async_autocommit_handler_extra_commit_statuses() -> None: + """Test async autocommit handler uses extra commit statuses.""" + connection_key = "test_connection" + handler = autocommit_handler_maker(connection_key, is_async=True, extra_commit_statuses={418}) + + mock_connection = AsyncMock() + mock_connection.commit = AsyncMock() + mock_connection.rollback = AsyncMock() + + scope = cast("Scope", {}) + set_sqlspec_scope_state(scope, connection_key, mock_connection) + + message = cast("Message", {"type": HTTP_RESPONSE_START, "status": 418}) + + await handler(message, scope) + + mock_connection.commit.assert_awaited_once() + mock_connection.rollback.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_async_autocommit_handler_raises_on_conflicting_statuses() -> None: + """Test async autocommit handler raises error when status sets overlap.""" + with pytest.raises(ImproperConfigurationError) as exc_info: + autocommit_handler_maker("test", is_async=True, extra_commit_statuses={418}, extra_rollback_statuses={418}) + + assert "must not share" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_async_lifespan_handler_creates_and_closes_pool() -> None: + """Test async lifespan handler manages pool lifecycle.""" + config = AiosqliteConfig(pool_config={"database": ":memory:"}) + pool_key = "test_pool" + + handler = lifespan_handler_maker(config, pool_key) + + mock_app = MagicMock() + mock_app.state = {} + mock_app.logger = None + + async with handler(mock_app): + assert pool_key in mock_app.state + pool = mock_app.state[pool_key] + assert pool is not None + + assert pool_key not in mock_app.state + + +@pytest.mark.asyncio +async def test_async_pool_provider_returns_pool() -> None: + """Test async pool provider returns pool from state.""" + config = AiosqliteConfig(pool_config={"database": ":memory:"}) + pool_key = "test_pool" + + provider = pool_provider_maker(config, pool_key) + + mock_pool = MagicMock() + state = MagicMock() + state.get.return_value = mock_pool + scope = cast("Scope", {}) + + result: Any = await provider(state, scope) + + assert result is mock_pool + state.get.assert_called_once_with(pool_key) + + +@pytest.mark.asyncio +async def test_async_pool_provider_raises_when_pool_missing() -> None: + """Test async pool provider raises error when pool not in state.""" + config = AiosqliteConfig(pool_config={"database": ":memory:"}) + pool_key = "test_pool" + + provider = pool_provider_maker(config, pool_key) + + state = MagicMock() + state.get.return_value = None + scope = cast("Scope", {}) + + with pytest.raises(ImproperConfigurationError) as exc_info: + await provider(state, scope) + + assert pool_key in str(exc_info.value) + assert "not found in application state" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_async_connection_provider_creates_connection() -> None: + """Test async connection provider creates connection from pool.""" + config = AiosqliteConfig(pool_config={"database": ":memory:"}) + pool_key = "test_pool" + connection_key = "test_connection" + + provider = connection_provider_maker(config, pool_key, connection_key) + + mock_pool = await config.create_pool() + state = MagicMock() + state.get.return_value = mock_pool + scope = cast("Scope", {}) + + connection: Any + async for connection in provider(state, scope): + assert connection is not None + assert get_sqlspec_scope_state(scope, connection_key) is connection + + +@pytest.mark.asyncio +async def test_async_connection_provider_raises_when_pool_missing() -> None: + """Test async connection provider raises error when pool missing.""" + config = AiosqliteConfig(pool_config={"database": ":memory:"}) + pool_key = "test_pool" + connection_key = "test_connection" + + provider = connection_provider_maker(config, pool_key, connection_key) + + state = MagicMock() + state.get.return_value = None + scope = cast("Scope", {}) + + with pytest.raises(ImproperConfigurationError) as exc_info: + async for _ in provider(state, scope): + pass + + assert pool_key in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_async_session_provider_creates_session() -> None: + """Test async session provider creates driver session.""" + config = AiosqliteConfig(pool_config={"database": ":memory:"}) + connection_key = "test_connection" + + provider = session_provider_maker(config, connection_key) + + mock_connection = AsyncMock() + + session: Any + async for session in provider(mock_connection): + assert session is not None + assert session.connection is mock_connection + + +def test_handlers_conditionally_use_ensure_async() -> None: + """Test that unified handlers module imports ensure_async_ and uses it conditionally.""" + from pathlib import Path + + from sqlspec.extensions.litestar import handlers + + source = handlers.__file__ + assert source is not None + + content = Path(source).read_text() + + assert "from sqlspec.utils.sync_tools import ensure_async_" in content + assert "if is_async:" in content, "handlers should check is_async flag" + assert "await connection.close()" in content, "async path should use direct await" + assert "await ensure_async_(connection.close)()" in content, "sync path should use ensure_async_"