Skip to content

Commit

Permalink
feat: adds a get_session context manager to the config (#83)
Browse files Browse the repository at this point in the history
Generally most apps end up creating some sort of context manager to help with sessions. While the Litestar plugin handles this for us, I still find the need to get to a context manager outside of requests. See here for an (example)[https://github.com/litestar-org/litestar-fullstack/blob/main/src/app/lib/db/base.py#L103].

This PR enables quick access to a `Session`, `AsyncSession` context manager from the SQL Alchemy configuration.
  • Loading branch information
cofin committed Oct 31, 2023
1 parent 82251f9 commit 1800fe7
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 7 deletions.
11 changes: 10 additions & 1 deletion advanced_alchemy/config/asyncio.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, AsyncGenerator

from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine

Expand Down Expand Up @@ -68,3 +69,11 @@ def __post_init__(self) -> None:
if self.metadata:
self.alembic_config.target_metadata = self.metadata
super().__post_init__()

@asynccontextmanager
async def get_session(
self,
) -> AsyncGenerator[AsyncSession, None]:
session_maker = self.create_session_maker()
async with session_maker() as session:
yield session
11 changes: 10 additions & 1 deletion advanced_alchemy/config/sync.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Generator

from sqlalchemy import Connection, Engine, create_engine
from sqlalchemy.orm import Session, sessionmaker
Expand Down Expand Up @@ -53,3 +54,11 @@ def __post_init__(self) -> None:
if self.metadata:
self.alembic_config.target_metadata = self.metadata
super().__post_init__()

@contextmanager
def get_session(
self,
) -> Generator[Session, None, None]:
session_maker = self.create_session_maker()
with session_maker() as session:
yield session
46 changes: 41 additions & 5 deletions tests/unit/test_extensions/test_litestar/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dataclasses import replace
from pathlib import Path
from types import ModuleType
from typing import Any, Callable, TypeVar, cast
from typing import Any, AsyncGenerator, Callable, TypeVar, cast
from unittest.mock import ANY

import pytest
Expand All @@ -27,8 +27,8 @@
from litestar.types.empty import Empty
from litestar.typing import FieldDefinition
from pytest import FixtureRequest, MonkeyPatch
from sqlalchemy import Engine
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
from sqlalchemy import Engine, NullPool, create_engine
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import Session, sessionmaker

from advanced_alchemy.extensions.litestar.alembic import AlembicCommands
Expand Down Expand Up @@ -224,20 +224,56 @@ def scope(create_scope: Callable[..., Scope]) -> Scope:


@pytest.fixture()
async def sync_sqlalchemy_plugin(engine: Engine, session_maker: sessionmaker[Session]) -> SQLAlchemyPlugin:
def engine() -> Generator[Engine, None, None]:
"""SQLite engine for end-to-end testing.
Returns:
Async SQLAlchemy engine instance.
"""
engine = create_engine("sqlite:///:memory:", poolclass=NullPool)
try:
yield engine
finally:
engine.dispose()


@pytest.fixture()
async def sync_sqlalchemy_plugin(
engine: Engine,
session_maker: sessionmaker[Session] | None = None,
) -> SQLAlchemyPlugin:
return SQLAlchemyPlugin(config=SQLAlchemySyncConfig(engine_instance=engine, session_maker=session_maker))


@pytest.fixture()
async def async_engine() -> AsyncGenerator[AsyncEngine, None]:
"""SQLite engine for end-to-end testing.
Returns:
Async SQLAlchemy engine instance.
"""
engine = create_async_engine("sqlite+aiosqlite:///:memory:", poolclass=NullPool)
try:
yield engine
finally:
await engine.dispose()


@pytest.fixture()
async def async_sqlalchemy_plugin(
async_engine: AsyncEngine,
async_session_maker: async_sessionmaker[AsyncSession],
async_session_maker: async_sessionmaker[AsyncSession] | None = None,
) -> SQLAlchemyPlugin:
return SQLAlchemyPlugin(
config=SQLAlchemyAsyncConfig(engine_instance=async_engine, session_maker=async_session_maker),
)


@pytest.fixture(params=[pytest.param("sync_sqlalchemy_plugin"), pytest.param("async_sqlalchemy_plugin")])
async def plugin(request: FixtureRequest) -> SQLAlchemyPlugin:
return cast(SQLAlchemyPlugin, request.getfixturevalue(request.param))


@pytest.fixture()
async def sync_app(sync_sqlalchemy_plugin: SQLAlchemyPlugin) -> Litestar:
return Litestar(plugins=[sync_sqlalchemy_plugin])
Expand Down
24 changes: 24 additions & 0 deletions tests/unit/test_extensions/test_litestar/test_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from __future__ import annotations

from typing import cast

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session

from advanced_alchemy.extensions.litestar.plugins import SQLAlchemyPlugin
from advanced_alchemy.extensions.litestar.plugins.init.config.asyncio import SQLAlchemyAsyncConfig
from advanced_alchemy.extensions.litestar.plugins.init.config.sync import SQLAlchemySyncConfig


async def test_sync_db_session(sync_sqlalchemy_plugin: SQLAlchemyPlugin) -> None:
config = cast("SQLAlchemySyncConfig", sync_sqlalchemy_plugin._config)

with config.get_session() as session:
assert isinstance(session, Session)


async def test_async_db_session(async_sqlalchemy_plugin: SQLAlchemyPlugin) -> None:
config = cast("SQLAlchemyAsyncConfig", async_sqlalchemy_plugin._config)

async with config.get_session() as session:
assert isinstance(session, AsyncSession)

0 comments on commit 1800fe7

Please sign in to comment.