Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions sqlspec/extensions/litestar/handlers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import contextlib
import inspect
from collections.abc import AsyncGenerator, Callable
from contextlib import AbstractAsyncContextManager
from contextlib import AbstractAsyncContextManager, AbstractContextManager
from typing import TYPE_CHECKING, Any, cast

from litestar.constants import HTTP_DISCONNECT, HTTP_RESPONSE_START, WEBSOCKET_CLOSE, WEBSOCKET_DISCONNECT
Expand All @@ -13,7 +13,7 @@
get_sqlspec_scope_state,
set_sqlspec_scope_state,
)
from sqlspec.utils.sync_tools import ensure_async_
from sqlspec.utils.sync_tools import ensure_async_, with_ensure_async_

if TYPE_CHECKING:
from collections.abc import Awaitable, Coroutine
Expand Down Expand Up @@ -239,8 +239,14 @@ async def provide_connection(state: "State", scope: "Scope") -> "AsyncGenerator[
raise ImproperConfigurationError(msg)

connection_cm: Any = config.provide_connection(db_pool)
context_manager: AbstractAsyncContextManager[ConnectionT] | None = None

if not isinstance(connection_cm, AbstractAsyncContextManager):
if isinstance(connection_cm, AbstractAsyncContextManager):
context_manager = connection_cm
elif isinstance(connection_cm, AbstractContextManager):
context_manager = with_ensure_async_(connection_cm)

if context_manager is None:
conn_instance: ConnectionT
if inspect.isawaitable(connection_cm):
conn_instance = await cast("Awaitable[ConnectionT]", connection_cm)
Expand All @@ -250,12 +256,12 @@ async def provide_connection(state: "State", scope: "Scope") -> "AsyncGenerator[
yield conn_instance
return

entered_connection = await connection_cm.__aenter__()
entered_connection = await context_manager.__aenter__()
try:
set_sqlspec_scope_state(scope, connection_key, entered_connection)
yield entered_connection
finally:
await connection_cm.__aexit__(None, None, None)
await context_manager.__aexit__(None, None, None)
delete_sqlspec_scope_state(scope, connection_key)

return provide_connection
Expand Down
24 changes: 24 additions & 0 deletions tests/unit/test_extensions/test_litestar/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from litestar.constants import HTTP_RESPONSE_START

from sqlspec.adapters.aiosqlite.config import AiosqliteConfig
from sqlspec.adapters.sqlite.config import SqliteConfig
from sqlspec.exceptions import ImproperConfigurationError
from sqlspec.extensions.litestar._utils import get_sqlspec_scope_state, set_sqlspec_scope_state
from sqlspec.extensions.litestar.handlers import (
Expand Down Expand Up @@ -245,6 +246,29 @@ async def test_async_connection_provider_raises_when_pool_missing() -> None:
assert pool_key in str(exc_info.value)


async def test_sync_connection_provider_supports_context_manager() -> None:
"""Test sync connection provider wraps sync context managers."""
config = SqliteConfig(pool_config={"database": ":memory:"})
pool_key = "test_pool"
connection_key = "test_connection"

provider = connection_provider_maker(config, pool_key, connection_key)

pool = config.create_pool()
state = MagicMock()
state.get.return_value = pool
scope = cast("Scope", {})

try:
async for connection in provider(state, scope):
assert connection is not None
assert get_sqlspec_scope_state(scope, connection_key) is connection
finally:
pool.close()

assert get_sqlspec_scope_state(scope, connection_key) is None


async def test_async_session_provider_creates_session() -> None:
"""Test async session provider creates driver session."""
config = AiosqliteConfig(pool_config={"database": ":memory:"})
Expand Down
Loading