diff --git a/sqlspec/extensions/litestar/handlers.py b/sqlspec/extensions/litestar/handlers.py index 59f30e18a..93114d8de 100644 --- a/sqlspec/extensions/litestar/handlers.py +++ b/sqlspec/extensions/litestar/handlers.py @@ -1,7 +1,7 @@ import contextlib import inspect from collections.abc import AsyncGenerator, Callable -from contextlib import AbstractAsyncContextManager +from contextlib import AbstractAsyncContextManager, AbstractContextManager from typing import TYPE_CHECKING, Any, cast from litestar.constants import HTTP_DISCONNECT, HTTP_RESPONSE_START, WEBSOCKET_CLOSE, WEBSOCKET_DISCONNECT @@ -13,7 +13,7 @@ get_sqlspec_scope_state, set_sqlspec_scope_state, ) -from sqlspec.utils.sync_tools import ensure_async_ +from sqlspec.utils.sync_tools import ensure_async_, with_ensure_async_ if TYPE_CHECKING: from collections.abc import Awaitable, Coroutine @@ -239,8 +239,14 @@ async def provide_connection(state: "State", scope: "Scope") -> "AsyncGenerator[ raise ImproperConfigurationError(msg) connection_cm: Any = config.provide_connection(db_pool) + context_manager: AbstractAsyncContextManager[ConnectionT] | None = None - if not isinstance(connection_cm, AbstractAsyncContextManager): + if isinstance(connection_cm, AbstractAsyncContextManager): + context_manager = connection_cm + elif isinstance(connection_cm, AbstractContextManager): + context_manager = with_ensure_async_(connection_cm) + + if context_manager is None: conn_instance: ConnectionT if inspect.isawaitable(connection_cm): conn_instance = await cast("Awaitable[ConnectionT]", connection_cm) @@ -250,12 +256,12 @@ async def provide_connection(state: "State", scope: "Scope") -> "AsyncGenerator[ yield conn_instance return - entered_connection = await connection_cm.__aenter__() + entered_connection = await context_manager.__aenter__() try: set_sqlspec_scope_state(scope, connection_key, entered_connection) yield entered_connection finally: - await connection_cm.__aexit__(None, None, None) + await context_manager.__aexit__(None, None, None) delete_sqlspec_scope_state(scope, connection_key) return provide_connection diff --git a/tests/unit/test_extensions/test_litestar/test_handlers.py b/tests/unit/test_extensions/test_litestar/test_handlers.py index 8af30136b..b6d7f1d67 100644 --- a/tests/unit/test_extensions/test_litestar/test_handlers.py +++ b/tests/unit/test_extensions/test_litestar/test_handlers.py @@ -7,6 +7,7 @@ from litestar.constants import HTTP_RESPONSE_START from sqlspec.adapters.aiosqlite.config import AiosqliteConfig +from sqlspec.adapters.sqlite.config import SqliteConfig from sqlspec.exceptions import ImproperConfigurationError from sqlspec.extensions.litestar._utils import get_sqlspec_scope_state, set_sqlspec_scope_state from sqlspec.extensions.litestar.handlers import ( @@ -245,6 +246,29 @@ async def test_async_connection_provider_raises_when_pool_missing() -> None: assert pool_key in str(exc_info.value) +async def test_sync_connection_provider_supports_context_manager() -> None: + """Test sync connection provider wraps sync context managers.""" + config = SqliteConfig(pool_config={"database": ":memory:"}) + pool_key = "test_pool" + connection_key = "test_connection" + + provider = connection_provider_maker(config, pool_key, connection_key) + + pool = config.create_pool() + state = MagicMock() + state.get.return_value = pool + scope = cast("Scope", {}) + + try: + async for connection in provider(state, scope): + assert connection is not None + assert get_sqlspec_scope_state(scope, connection_key) is connection + finally: + pool.close() + + assert get_sqlspec_scope_state(scope, connection_key) is None + + async def test_async_session_provider_creates_session() -> None: """Test async session provider creates driver session.""" config = AiosqliteConfig(pool_config={"database": ":memory:"})