Skip to content

Commit

Permalink
feat: add create_all configuration parameter for Litestar (#70)
Browse files Browse the repository at this point in the history
  • Loading branch information
cofin committed Oct 24, 2023
1 parent 3a2ced4 commit 54d6a63
Show file tree
Hide file tree
Showing 8 changed files with 87 additions and 2 deletions.
5 changes: 5 additions & 0 deletions advanced_alchemy/config/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,8 @@ class SQLAlchemyAsyncConfig(GenericSQLAlchemyConfig[AsyncEngine, AsyncSession, a
The configuration options are documented in the Alembic documentation.
"""

def __post_init__(self) -> None:
if self.metadata:
self.alembic_config.target_metadata = self.metadata
super().__post_init__()
7 changes: 7 additions & 0 deletions advanced_alchemy/config/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,13 @@ class GenericSQLAlchemyConfig(Generic[EngineT, SessionT, SessionMakerT]):
If set, the plugin will use the provided instance rather than instantiate an engine.
"""
create_all: bool = False
"""If true, all models are automatically created on engine creation."""

metadata: MetaData | None = None
"""Optional metadata to use.
If set, the plugin will use the provided instance rather than the default metadata."""

def __post_init__(self) -> None:
if self.connection_string is not None and self.engine_instance is not None:
Expand Down
5 changes: 5 additions & 0 deletions advanced_alchemy/config/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,8 @@ class SQLAlchemySyncConfig(GenericSQLAlchemyConfig[Engine, Session, sessionmaker
The configuration options are documented in the Alembic documentation.
"""

def __post_init__(self) -> None:
if self.metadata:
self.alembic_config.target_metadata = self.metadata
super().__post_init__()
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def autocommit_handler_maker(
extra_commit_statuses: set[int] | None = None,
extra_rollback_statuses: set[int] | None = None,
) -> Callable[[Message, Scope], Coroutine[Any, Any, None]]:
"""Set up the handler to issue a transactin commit or rollback based on specified status codes
"""Set up the handler to issue a transaction commit or rollback based on specified status codes
Args:
commit_on_redirect: Issue a commit when the response status is a redirect (``3XX``)
extra_commit_statuses: A set of additional status codes that trigger a commit
Expand Down Expand Up @@ -196,6 +196,15 @@ async def on_shutdown(self, app: Litestar) -> None:
engine = cast("AsyncEngine", app.state.pop(self.engine_app_state_key))
await engine.dispose()

async def create_all_metadata(self, app: Litestar) -> None:
"""Create all metadata
Args:
app (Litestar): The ``Litestar`` instance
"""
async with self.get_engine().begin() as conn:
await conn.run_sync(self.alembic_config.target_metadata.create_all)

def create_app_state_items(self) -> dict[str, Any]:
"""Key/value pairs to be stored in application state."""
return {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,15 @@ def on_shutdown(self, app: Litestar) -> None:
engine = cast("Engine", app.state.pop(self.engine_app_state_key))
engine.dispose()

def create_all_metadata(self, app: Litestar) -> None:
"""Create all metadata
Args:
app (Litestar): The ``Litestar`` instance
"""
with self.get_engine().begin() as conn:
self.alembic_config.target_metadata.create_all(bind=conn)

def create_app_state_items(self) -> dict[str, Any]:
"""Key/value pairs to be stored in application state."""
return {
Expand Down
2 changes: 2 additions & 0 deletions advanced_alchemy/extensions/litestar/plugins/init/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ def on_app_init(self, app_config: AppConfig) -> AppConfig:
)
app_config.before_send.append(self._config.before_send_handler)
app_config.on_startup.insert(0, self._config.update_app_state)
if self._config.create_all:
app_config.on_startup.append(self._config.create_all_metadata)
app_config.on_shutdown.append(self._config.on_shutdown)
app_config.signature_namespace.update(self._config.signature_namespace)
app_config.signature_namespace.update(signature_namespace_values)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from litestar.testing import create_test_client
from litestar.types.asgi_types import HTTPResponseStartEvent
from litestar.utils import set_litestar_scope_state
from pytest import MonkeyPatch
from sqlalchemy.ext.asyncio import AsyncSession

from advanced_alchemy.extensions.litestar.plugins import SQLAlchemyAsyncConfig, SQLAlchemyInitPlugin
Expand Down Expand Up @@ -42,6 +43,28 @@ def test_handler(db_session: AsyncSession, scope: Scope) -> None:
assert config.session_dependency_key not in captured_scope_state # pyright: ignore


async def test_create_all_default(monkeypatch: MonkeyPatch) -> None:
"""Test default_before_send_handler."""

config = SQLAlchemyAsyncConfig(connection_string="sqlite+aiosqlite://")
plugin = SQLAlchemyInitPlugin(config=config)
mock_fx = MagicMock()
monkeypatch.setattr(config, "create_all_metadata", mock_fx)
with create_test_client(route_handlers=[], plugins=[plugin]) as _client:
mock_fx.assert_not_called()


async def test_create_all(monkeypatch: MonkeyPatch) -> None:
"""Test default_before_send_handler."""

config = SQLAlchemyAsyncConfig(connection_string="sqlite+aiosqlite://", create_all=True)
plugin = SQLAlchemyInitPlugin(config=config)
mock_fx = MagicMock()
monkeypatch.setattr(config, "create_all_metadata", mock_fx)
with create_test_client(route_handlers=[], plugins=[plugin]) as _client:
mock_fx.assert_called_once()


async def test_before_send_handler_success_response(create_scope: Callable[..., Scope]) -> None:
"""Test that the session is committed given a success response."""
config = SQLAlchemyAsyncConfig(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

import random
from typing import TYPE_CHECKING
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

from litestar import Litestar, get
from litestar.testing import create_test_client
from litestar.types.asgi_types import HTTPResponseStartEvent
from litestar.utils import set_litestar_scope_state
from pytest import MonkeyPatch
from sqlalchemy.orm import Session

from advanced_alchemy.extensions.litestar.plugins import (
Expand Down Expand Up @@ -45,6 +46,30 @@ def test_handler(db_session: Session, scope: Scope) -> None:
assert config.session_dependency_key not in captured_scope_state


def test_create_all_default(monkeypatch: MonkeyPatch) -> None:
"""Test default_before_send_handler."""

config = SQLAlchemySyncConfig(connection_string="sqlite+aiosqlite://")
plugin = SQLAlchemyInitPlugin(config=config)
with patch.object(
config,
"create_all_metadata",
) as create_all_metadata_mock, create_test_client(route_handlers=[], plugins=[plugin]) as _client:
create_all_metadata_mock.assert_not_called()


def test_create_all(monkeypatch: MonkeyPatch) -> None:
"""Test default_before_send_handler."""

config = SQLAlchemySyncConfig(connection_string="sqlite+aiosqlite://", create_all=True)
plugin = SQLAlchemyInitPlugin(config=config)
with patch.object(
config,
"create_all_metadata",
) as create_all_metadata_mock, create_test_client(route_handlers=[], plugins=[plugin]) as _client:
create_all_metadata_mock.assert_called_once()


def test_before_send_handler_success_response(create_scope: Callable[..., Scope]) -> None:
"""Test that the session is committed given a success response."""
config = SQLAlchemySyncConfig(connection_string="sqlite://", before_send_handler=autocommit_before_send_handler)
Expand Down

0 comments on commit 54d6a63

Please sign in to comment.