From ffd1218eaca3e3072b6ad3a892b496464124052b Mon Sep 17 00:00:00 2001 From: Peter Schutt Date: Mon, 27 Mar 2023 12:46:36 +1000 Subject: [PATCH 01/16] SQLAlchemy 2.0 `InitPluginProtocol` implementation. --- docs/conf.py | 5 +- .../__init__.py | 0 .../sqlalchemy_async.py | 0 .../sqlalchemy_relationships.py | 0 .../sqlalchemy_relationships_to_many.py | 0 .../sqlalchemy_sync.py | 0 .../sqlalchemy_init_plugin/__init__.py | 0 .../sqlalchemy_async.py | 31 +++ .../sqlalchemy_init_plugin/sqlalchemy_sync.py | 31 +++ ..._plugin.py => test_sqlalchemy_1_plugin.py} | 8 +- .../plugins/test_sqlalchemy_init_plugin.py | 20 ++ docs/reference/contrib/index.rst | 1 + .../contrib/sqlalchemy/config/asyncio.rst | 5 + .../contrib/sqlalchemy/config/common.rst | 5 + .../contrib/sqlalchemy/config/index.rst | 9 + .../contrib/sqlalchemy/config/sync.rst | 5 + docs/reference/contrib/sqlalchemy/index.rst | 8 + docs/reference/contrib/sqlalchemy/plugin.rst | 5 + docs/usage/plugins/sqlalchemy.rst | 138 ++--------- docs/usage/responses.rst | 2 +- starlite/constants.py | 5 +- .../sqlalchemy/init_plugin/__init__.py | 19 ++ .../sqlalchemy/init_plugin/config/__init__.py | 13 ++ .../sqlalchemy/init_plugin/config/asyncio.py | 92 ++++++++ .../sqlalchemy/init_plugin/config/common.py | 215 ++++++++++++++++++ .../sqlalchemy/init_plugin/config/engine.py | 76 +++++++ .../sqlalchemy/init_plugin/config/sync.py | 87 +++++++ .../contrib/sqlalchemy/init_plugin/plugin.py | 45 ++++ starlite/utils/__init__.py | 2 + starlite/utils/dataclass.py | 56 ++++- starlite/utils/scope.py | 18 +- .../sqlalchemy/init_plugin/__init__.py | 0 .../sqlalchemy/init_plugin/config/__init__.py | 0 .../init_plugin/config/test_common.py | 166 ++++++++++++++ 34 files changed, 930 insertions(+), 137 deletions(-) rename docs/examples/plugins/{sqlalchemy_plugin => sqlalchemy_1_plugin}/__init__.py (100%) rename docs/examples/plugins/{sqlalchemy_plugin => sqlalchemy_1_plugin}/sqlalchemy_async.py (100%) rename docs/examples/plugins/{sqlalchemy_plugin => sqlalchemy_1_plugin}/sqlalchemy_relationships.py (100%) rename docs/examples/plugins/{sqlalchemy_plugin => sqlalchemy_1_plugin}/sqlalchemy_relationships_to_many.py (100%) rename docs/examples/plugins/{sqlalchemy_plugin => sqlalchemy_1_plugin}/sqlalchemy_sync.py (100%) create mode 100644 docs/examples/plugins/sqlalchemy_init_plugin/__init__.py create mode 100644 docs/examples/plugins/sqlalchemy_init_plugin/sqlalchemy_async.py create mode 100644 docs/examples/plugins/sqlalchemy_init_plugin/sqlalchemy_sync.py rename docs/examples/tests/plugins/{test_sqlalchemy_plugin.py => test_sqlalchemy_1_plugin.py} (83%) create mode 100644 docs/examples/tests/plugins/test_sqlalchemy_init_plugin.py create mode 100644 docs/reference/contrib/sqlalchemy/config/asyncio.rst create mode 100644 docs/reference/contrib/sqlalchemy/config/common.rst create mode 100644 docs/reference/contrib/sqlalchemy/config/index.rst create mode 100644 docs/reference/contrib/sqlalchemy/config/sync.rst create mode 100644 docs/reference/contrib/sqlalchemy/index.rst create mode 100644 docs/reference/contrib/sqlalchemy/plugin.rst create mode 100644 starlite/contrib/sqlalchemy/init_plugin/__init__.py create mode 100644 starlite/contrib/sqlalchemy/init_plugin/config/__init__.py create mode 100644 starlite/contrib/sqlalchemy/init_plugin/config/asyncio.py create mode 100644 starlite/contrib/sqlalchemy/init_plugin/config/common.py create mode 100644 starlite/contrib/sqlalchemy/init_plugin/config/engine.py create mode 100644 starlite/contrib/sqlalchemy/init_plugin/config/sync.py create mode 100644 starlite/contrib/sqlalchemy/init_plugin/plugin.py create mode 100644 tests/contrib/sqlalchemy/init_plugin/__init__.py create mode 100644 tests/contrib/sqlalchemy/init_plugin/config/__init__.py create mode 100644 tests/contrib/sqlalchemy/init_plugin/config/test_common.py diff --git a/docs/conf.py b/docs/conf.py index b33c43c698..99cab96418 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -35,7 +35,7 @@ "msgspec": ("https://jcristharif.com/msgspec/", None), "anyio": ("https://anyio.readthedocs.io/en/stable/", None), "multidict": ("https://multidict.aio-libs.org/en/stable/", None), - "sqlalchemy": ("https://docs.sqlalchemy.org/en/14/", None), + "sqlalchemy": ("https://docs.sqlalchemy.org/en/20/", None), "click": ("https://click.palletsprojects.com/en/8.1.x/", None), "redis": ("https://redis-py.readthedocs.io/en/stable/", None), "picologging": ("https://microsoft.github.io/picologging", None), @@ -113,6 +113,9 @@ "starlite.contrib.sqlalchemy_1.plugin.SQLAlchemyPlugin.handle_string_type": {"BINARY", "VARBINARY", "LargeBinary"}, "starlite.contrib.sqlalchemy_1.plugin.SQLAlchemyPlugin.is_plugin_supported_type": {"DeclarativeMeta"}, re.compile(r"starlite\.plugins.*"): re.compile(".*(ModelT|DataContainerT)"), + re.compile(r"starlite\.contrib\.sqlalchemy\.init_plugin\.config\.common.*"): re.compile( + ".*(ConnectionT|EngineT|SessionT|SessionMakerT)" + ), } diff --git a/docs/examples/plugins/sqlalchemy_plugin/__init__.py b/docs/examples/plugins/sqlalchemy_1_plugin/__init__.py similarity index 100% rename from docs/examples/plugins/sqlalchemy_plugin/__init__.py rename to docs/examples/plugins/sqlalchemy_1_plugin/__init__.py diff --git a/docs/examples/plugins/sqlalchemy_plugin/sqlalchemy_async.py b/docs/examples/plugins/sqlalchemy_1_plugin/sqlalchemy_async.py similarity index 100% rename from docs/examples/plugins/sqlalchemy_plugin/sqlalchemy_async.py rename to docs/examples/plugins/sqlalchemy_1_plugin/sqlalchemy_async.py diff --git a/docs/examples/plugins/sqlalchemy_plugin/sqlalchemy_relationships.py b/docs/examples/plugins/sqlalchemy_1_plugin/sqlalchemy_relationships.py similarity index 100% rename from docs/examples/plugins/sqlalchemy_plugin/sqlalchemy_relationships.py rename to docs/examples/plugins/sqlalchemy_1_plugin/sqlalchemy_relationships.py diff --git a/docs/examples/plugins/sqlalchemy_plugin/sqlalchemy_relationships_to_many.py b/docs/examples/plugins/sqlalchemy_1_plugin/sqlalchemy_relationships_to_many.py similarity index 100% rename from docs/examples/plugins/sqlalchemy_plugin/sqlalchemy_relationships_to_many.py rename to docs/examples/plugins/sqlalchemy_1_plugin/sqlalchemy_relationships_to_many.py diff --git a/docs/examples/plugins/sqlalchemy_plugin/sqlalchemy_sync.py b/docs/examples/plugins/sqlalchemy_1_plugin/sqlalchemy_sync.py similarity index 100% rename from docs/examples/plugins/sqlalchemy_plugin/sqlalchemy_sync.py rename to docs/examples/plugins/sqlalchemy_1_plugin/sqlalchemy_sync.py diff --git a/docs/examples/plugins/sqlalchemy_init_plugin/__init__.py b/docs/examples/plugins/sqlalchemy_init_plugin/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/docs/examples/plugins/sqlalchemy_init_plugin/sqlalchemy_async.py b/docs/examples/plugins/sqlalchemy_init_plugin/sqlalchemy_async.py new file mode 100644 index 0000000000..a9f03a7fb0 --- /dev/null +++ b/docs/examples/plugins/sqlalchemy_init_plugin/sqlalchemy_async.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from sqlalchemy import text + +from starlite import Starlite, get +from starlite.contrib.sqlalchemy.init_plugin import SQLAlchemyAsyncConfig, SQLAlchemyInitPlugin + +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession + + +@get(path="/sqlalchemy-app") +async def async_sqlalchemy_init(db_session: AsyncSession, db_engine: AsyncEngine) -> str: + """Create a new company and return it.""" + + one = (await db_session.execute(text("SELECT 1"))).scalar_one() + + async with db_engine.begin() as conn: + two = (await conn.execute(text("SELECT 2"))).scalar_one() + + return f"{one} {two}" + + +sqlalchemy_config = SQLAlchemyAsyncConfig(connection_string="sqlite+aiosqlite:///test.sqlite") + +app = Starlite( + route_handlers=[async_sqlalchemy_init], + plugins=[SQLAlchemyInitPlugin(config=sqlalchemy_config)], +) diff --git a/docs/examples/plugins/sqlalchemy_init_plugin/sqlalchemy_sync.py b/docs/examples/plugins/sqlalchemy_init_plugin/sqlalchemy_sync.py new file mode 100644 index 0000000000..e0572b1440 --- /dev/null +++ b/docs/examples/plugins/sqlalchemy_init_plugin/sqlalchemy_sync.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from sqlalchemy import text + +from starlite import Starlite, get +from starlite.contrib.sqlalchemy.init_plugin import SQLAlchemyInitPlugin, SQLAlchemySyncConfig + +if TYPE_CHECKING: + from sqlalchemy import Engine + from sqlalchemy.orm import Session + + +@get(path="/sqlalchemy-app") +def async_sqlalchemy_init(db_session: Session, db_engine: Engine) -> str: + """Create a new company and return it.""" + one = db_session.execute(text("SELECT 1")).scalar_one() + + with db_engine.connect() as conn: + two = conn.execute(text("SELECT 2")).scalar_one() + + return f"{one} {two}" + + +sqlalchemy_config = SQLAlchemySyncConfig(connection_string="sqlite:///test.sqlite") + +app = Starlite( + route_handlers=[async_sqlalchemy_init], + plugins=[SQLAlchemyInitPlugin(config=sqlalchemy_config)], +) diff --git a/docs/examples/tests/plugins/test_sqlalchemy_plugin.py b/docs/examples/tests/plugins/test_sqlalchemy_1_plugin.py similarity index 83% rename from docs/examples/tests/plugins/test_sqlalchemy_plugin.py rename to docs/examples/tests/plugins/test_sqlalchemy_1_plugin.py index a327046bb6..1cc7727d6c 100644 --- a/docs/examples/tests/plugins/test_sqlalchemy_plugin.py +++ b/docs/examples/tests/plugins/test_sqlalchemy_1_plugin.py @@ -3,14 +3,14 @@ import pytest -from examples.plugins.sqlalchemy_plugin.sqlalchemy_async import app as async_sqla_app -from examples.plugins.sqlalchemy_plugin.sqlalchemy_relationships import ( +from examples.plugins.sqlalchemy_1_plugin.sqlalchemy_async import app as async_sqla_app +from examples.plugins.sqlalchemy_1_plugin.sqlalchemy_relationships import ( app as relationship_app, ) -from examples.plugins.sqlalchemy_plugin.sqlalchemy_relationships_to_many import ( +from examples.plugins.sqlalchemy_1_plugin.sqlalchemy_relationships_to_many import ( app as relationship_app_to_many, ) -from examples.plugins.sqlalchemy_plugin.sqlalchemy_sync import app as sync_sqla_app +from examples.plugins.sqlalchemy_1_plugin.sqlalchemy_sync import app as sync_sqla_app from starlite import Starlite from starlite.testing import TestClient diff --git a/docs/examples/tests/plugins/test_sqlalchemy_init_plugin.py b/docs/examples/tests/plugins/test_sqlalchemy_init_plugin.py new file mode 100644 index 0000000000..7e3d93a7df --- /dev/null +++ b/docs/examples/tests/plugins/test_sqlalchemy_init_plugin.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +from examples.plugins.sqlalchemy_init_plugin.sqlalchemy_async import app as async_sqla_app +from examples.plugins.sqlalchemy_init_plugin.sqlalchemy_sync import app as sync_sqla_app +from starlite.testing import TestClient + +if TYPE_CHECKING: + from starlite import Starlite + + +@pytest.mark.parametrize("app", [async_sqla_app, sync_sqla_app]) +def test_app(app: Starlite) -> None: + with TestClient(app=app) as client: + res = client.get("/sqlalchemy-app") + assert res.status_code == 200 + assert res.json() == "1 2" diff --git a/docs/reference/contrib/index.rst b/docs/reference/contrib/index.rst index e8b35dab0f..52970c1da4 100644 --- a/docs/reference/contrib/index.rst +++ b/docs/reference/contrib/index.rst @@ -10,5 +10,6 @@ contrib mako opentelemetry piccolo_orm + sqlalchemy/index sqlalchemy_1/index tortoise_orm diff --git a/docs/reference/contrib/sqlalchemy/config/asyncio.rst b/docs/reference/contrib/sqlalchemy/config/asyncio.rst new file mode 100644 index 0000000000..e138610870 --- /dev/null +++ b/docs/reference/contrib/sqlalchemy/config/asyncio.rst @@ -0,0 +1,5 @@ +asyncio +======= + +.. automodule:: starlite.contrib.sqlalchemy.init_plugin.config.asyncio + :members: diff --git a/docs/reference/contrib/sqlalchemy/config/common.rst b/docs/reference/contrib/sqlalchemy/config/common.rst new file mode 100644 index 0000000000..e8e2467cf7 --- /dev/null +++ b/docs/reference/contrib/sqlalchemy/config/common.rst @@ -0,0 +1,5 @@ +asyncio +======= + +.. automodule:: starlite.contrib.sqlalchemy.init_plugin.config.common + :members: diff --git a/docs/reference/contrib/sqlalchemy/config/index.rst b/docs/reference/contrib/sqlalchemy/config/index.rst new file mode 100644 index 0000000000..1919a60794 --- /dev/null +++ b/docs/reference/contrib/sqlalchemy/config/index.rst @@ -0,0 +1,9 @@ +config +====== + +.. toctree:: + :titlesonly: + + asyncio + common + sync diff --git a/docs/reference/contrib/sqlalchemy/config/sync.rst b/docs/reference/contrib/sqlalchemy/config/sync.rst new file mode 100644 index 0000000000..ddf769ccf1 --- /dev/null +++ b/docs/reference/contrib/sqlalchemy/config/sync.rst @@ -0,0 +1,5 @@ +sync +==== + +.. automodule:: starlite.contrib.sqlalchemy.init_plugin.config.sync + :members: diff --git a/docs/reference/contrib/sqlalchemy/index.rst b/docs/reference/contrib/sqlalchemy/index.rst new file mode 100644 index 0000000000..3ee4e8d69b --- /dev/null +++ b/docs/reference/contrib/sqlalchemy/index.rst @@ -0,0 +1,8 @@ +sqlalchemy +========== + +.. toctree:: + :titlesonly: + + config/index + plugin diff --git a/docs/reference/contrib/sqlalchemy/plugin.rst b/docs/reference/contrib/sqlalchemy/plugin.rst new file mode 100644 index 0000000000..1a2b06b0c3 --- /dev/null +++ b/docs/reference/contrib/sqlalchemy/plugin.rst @@ -0,0 +1,5 @@ +plugin +====== + +.. automodule:: starlite.contrib.sqlalchemy.init_plugin.plugin + :members: diff --git a/docs/usage/plugins/sqlalchemy.rst b/docs/usage/plugins/sqlalchemy.rst index 8f4d4af4da..bc98606da0 100644 --- a/docs/usage/plugins/sqlalchemy.rst +++ b/docs/usage/plugins/sqlalchemy.rst @@ -1,46 +1,29 @@ -SQLAlchemy Plugin -================= +SQLAlchemy Plugins +================== Starlite comes with built-in support for `SQLAlchemy `_ via -the :class:`SQLAlchemyPlugin <.contrib.sqlalchemy_1.plugin.SQLAlchemyPlugin>`. +the :class:`SQLAlchemyInitPlugin <.contrib.sqlalchemy.init_plugin.plugin.SQLAlchemyInitPlugin>`. Features -------- - * Managed `sessions `_ (sync and async) including dependency injection -* Automatic serialization of SQLAlchemy models powered pydantic -* Data validation based on SQLAlchemy models powered pydantic - -.. seealso:: - - The following examples use SQLAlchemy's "2.0 Style" introduced in SQLAlchemy 1.4. - - If you are unfamiliar with it, you can find a comprehensive migration guide in SQLAlchemy's - documentation `here `_, - and `a handy table `_ - comparing the ORM usage - -.. attention:: - - The :class:`SQLAlchemyPlugin <.contrib.sqlalchemy_1.plugin.SQLAlchemyPlugin>` supports only - `mapped classes `_. - `Tables `_ are - currently not supported since they are not easy to convert to pydantic models. +* Managed `engine `_ (sync and async) including dependency injection +* Typed configuration objects Basic Use --------- -You can simply pass an instance of :class:`SQLAlchemyPlugin <.contrib.sqlalchemy_1.plugin.SQLAlchemyPlugin>` without -passing config to the Starlite constructor. This will extend support for serialization, deserialization and DTO creation -for SQLAlchemy declarative models: +You can simply pass an instance of :class:`SQLAlchemyInitPlugin <.contrib.sqlalchemy.init_plugin.plugin.SQLAlchemyInitPlugin>` +to the Starlite constructor. This will automatically create a SQLAlchemy engine and session for you, and make them +available to your handlers and dependencies via dependency injection. .. tab-set:: .. tab-item:: Async :sync: async - .. literalinclude:: /examples/plugins/sqlalchemy_plugin/sqlalchemy_async.py + .. literalinclude:: /examples/plugins/sqlalchemy_init_plugin/sqlalchemy_async.py :caption: sqlalchemy_plugin.py :language: python @@ -48,109 +31,14 @@ for SQLAlchemy declarative models: .. tab-item:: Sync :sync: sync - .. literalinclude:: /examples/plugins/sqlalchemy_plugin/sqlalchemy_sync.py + .. literalinclude:: /examples/plugins/sqlalchemy_init_plugin/sqlalchemy_sync.py :caption: sqlalchemy_plugin.py :language: python -.. admonition:: Using imperative mappings - :class: info - - `Imperative mappings `_ - are supported as well, just make sure to use a mapped class instead of the table itself - - .. code-block:: python - - company_table = Table( - "company", - Base.registry.metadata, - Column("id", Integer, primary_key=True), - Column("name", String), - Column("worth", Float), - ) - - - class Company: - pass - - - Base.registry.map_imperatively(Company, company_table) - - -Relationships -------------- - -.. attention:: - - Currently only to-one relationships are supported because of the way the SQLAlchemy plugin handles relationships. - Since it recursively traverses relationships, a cyclic reference will result in an endless loop. To prevent this, - these relationships will be type as :class:`typing.Any` in the pydantic model - Relationships are typed as :class:`typing.Optional` in the pydantic model by default so sending incomplete models - won't cause any issues. - - -Simple relationships -^^^^^^^^^^^^^^^^^^^^ - -Simple relationships can be handled by the plugin automatically: - -.. literalinclude:: /examples/plugins/sqlalchemy_plugin/sqlalchemy_relationships.py - :caption: sqlalchemy_relationships.py - :language: python - - -.. admonition:: Example - :class: tip - - Run the above with ``uvicorn sqlalchemy_relationships:app``, navigate your browser to - `http://127.0.0.0:8000/user/1 `_ - and you will see: - - .. code-block:: json - - { - "id": 1, - "name": "Peter", - "company_id": 1, - "company": { - "id": 1, - "name": "Peter Co.", - "worth": 0 - } - } - - -To-Many relationships and circular references -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -For to-many relationships or those that contain circular references you need to define the pydantic models yourself: - -.. literalinclude:: /examples/plugins/sqlalchemy_plugin/sqlalchemy_relationships_to_many.py - :caption: sqlalchemy_relationships_to_many - :language: python - - -.. admonition:: Example - :class: tip - - Run the above with ``uvicorn sqlalchemy_relationships_to_many:app``, navigate your browser to `http://127.0.0.0:8000/user/1`_ - and you will see: - - .. code-block:: json - - { - "id": 1, - "name": "Peter", - "pets": [ - { - "id": 1, - "name": "Paul" - } - ] - } - - Configuration ------------- -You can configure the Plugin using the :class:`SQLAlchemyConfig <.contrib.sqlalchemy_1.config.SQLAlchemyConfig>` object. +You configure the Plugin using either +:class:`SQLAlchemyAsyncConfig <.contrib.sqlalchemy.init_plugin.config.asyncio.SQLAlchemyAsyncConfig>` or +:class:`SQLAlchemySyncConfig <.contrib.sqlalchemy.init_plugin.config.sync.SQLAlchemySyncConfig>`. diff --git a/docs/usage/responses.rst b/docs/usage/responses.rst index 6158a5c1fd..a31ff33ba8 100644 --- a/docs/usage/responses.rst +++ b/docs/usage/responses.rst @@ -825,7 +825,7 @@ kwargs>` :language: python -See :ref:`SQLAlchemy plugin ` for sqlalchemy integration. +See :ref:`SQLAlchemy plugin ` for sqlalchemy integration. Cursor Pagination +++++++++++++++++ diff --git a/starlite/constants.py b/starlite/constants.py index 14a20ea72c..ae012e1afa 100644 --- a/starlite/constants.py +++ b/starlite/constants.py @@ -6,6 +6,7 @@ DEFAULT_ALLOWED_CORS_HEADERS = {"Accept", "Accept-Language", "Content-Language", "Content-Type"} DEFAULT_CHUNK_SIZE = 1024 * 128 # 128KB +HTTP_DISCONNECT = "http.disconnect" HTTP_RESPONSE_BODY = "http.response.body" HTTP_RESPONSE_START = "http.response.start" ONE_MEGABYTE = 1024 * 1024 @@ -14,5 +15,7 @@ SCOPE_STATE_DEPENDENCY_CACHE = "dependency_cache" SCOPE_STATE_NAMESPACE = "__starlite__" SCOPE_STATE_RESPONSE_COMPRESSED = "response_compressed" -UNDEFINED_SENTINELS = {Undefined, Signature.empty, Empty, Ellipsis} SKIP_VALIDATION_NAMES = {"request", "socket", "scope", "receive", "send"} +UNDEFINED_SENTINELS = {Undefined, Signature.empty, Empty, Ellipsis} +WEBSOCKET_CLOSE = "websocket.close" +WEBSOCKET_DISCONNECT = "websocket.disconnect" diff --git a/starlite/contrib/sqlalchemy/init_plugin/__init__.py b/starlite/contrib/sqlalchemy/init_plugin/__init__.py new file mode 100644 index 0000000000..782770ea09 --- /dev/null +++ b/starlite/contrib/sqlalchemy/init_plugin/__init__.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +from .config import ( + AsyncSessionConfig, + EngineConfig, + SQLAlchemyAsyncConfig, + SQLAlchemySyncConfig, + SyncSessionConfig, +) +from .plugin import SQLAlchemyInitPlugin + +__all__ = ( + "AsyncSessionConfig", + "EngineConfig", + "SQLAlchemyAsyncConfig", + "SQLAlchemyInitPlugin", + "SQLAlchemySyncConfig", + "SyncSessionConfig", +) diff --git a/starlite/contrib/sqlalchemy/init_plugin/config/__init__.py b/starlite/contrib/sqlalchemy/init_plugin/config/__init__.py new file mode 100644 index 0000000000..deff28a94c --- /dev/null +++ b/starlite/contrib/sqlalchemy/init_plugin/config/__init__.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +from .asyncio import AsyncSessionConfig, SQLAlchemyAsyncConfig +from .engine import EngineConfig +from .sync import SQLAlchemySyncConfig, SyncSessionConfig + +__all__ = ( + "AsyncSessionConfig", + "EngineConfig", + "SQLAlchemyAsyncConfig", + "SQLAlchemySyncConfig", + "SyncSessionConfig", +) diff --git a/starlite/contrib/sqlalchemy/init_plugin/config/asyncio.py b/starlite/contrib/sqlalchemy/init_plugin/config/asyncio.py new file mode 100644 index 0000000000..89ab40b3f5 --- /dev/null +++ b/starlite/contrib/sqlalchemy/init_plugin/config/asyncio.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, cast + +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine + +from starlite.types import Empty +from starlite.utils import ( + delete_starlite_scope_state, + get_starlite_scope_state, +) + +from .common import SESSION_SCOPE_KEY, SESSION_TERMINUS_ASGI_EVENTS, GenericSessionConfig, GenericSQLAlchemyConfig + +if TYPE_CHECKING: + from typing import Any, Callable + + from sqlalchemy.orm import Session + + from starlite.datastructures.state import State + from starlite.types import BeforeMessageSendHookHandler, EmptyType, Message, Scope + +__all__ = ("SQLAlchemyAsyncConfig", "AsyncSessionConfig") + + +async def default_before_send_handler(message: Message, _: State, scope: Scope) -> None: + """Handle closing and cleaning up sessions before sending. + + Args: + message: ASGI-``Message`` + _: A ``State`` (not used) + scope: An ASGI-``Scope`` + + Returns: + None + """ + session = cast("AsyncSession | None", get_starlite_scope_state(scope, SESSION_SCOPE_KEY)) + if session and message["type"] in SESSION_TERMINUS_ASGI_EVENTS: + await session.close() + delete_starlite_scope_state(scope, SESSION_SCOPE_KEY) + + +@dataclass +class AsyncSessionConfig(GenericSessionConfig[AsyncConnection, AsyncEngine, AsyncSession]): + """SQLAlchemy async session config.""" + + sync_session_class: type[Session] | None | EmptyType = Empty + + +@dataclass +class SQLAlchemyAsyncConfig(GenericSQLAlchemyConfig[AsyncEngine, AsyncSession, async_sessionmaker]): + """Async SQLAlchemy Configuration.""" + + create_engine_callable: Callable[[str], AsyncEngine] = create_async_engine + """Callable that creates an :class:`AsyncEngine ` instance or instance of its + subclass. + """ + session_config: AsyncSessionConfig = field(default_factory=AsyncSessionConfig) + """Configuration options for the ``sessionmaker``. + + The configuration options are documented in the SQLAlchemy documentation. + """ + session_maker_class: type[async_sessionmaker] = async_sessionmaker + """Sessionmaker class to use.""" + before_send_handler: BeforeMessageSendHookHandler = default_before_send_handler + """Handler to call before the ASGI message is sent. + + The handler should handle closing the session stored in the ASGI scope, if its still open, and committing and + uncommitted data. + """ + + @property + def signature_namespace(self) -> dict[str, Any]: + """Return the plugin's signature namespace. + + Returns: + A string keyed dict of names to be added to the namespace for signature forward reference resolution. + """ + return {"AsyncEngine": AsyncEngine, "AsyncSession": AsyncSession} + + async def on_shutdown(self, state: State) -> None: + """Disposes of the SQLAlchemy engine. + + Args: + state: The ``Starlite.state`` instance. + + Returns: + None + """ + engine = cast("AsyncEngine", state.pop(self.engine_app_state_key)) + await engine.dispose() diff --git a/starlite/contrib/sqlalchemy/init_plugin/config/common.py b/starlite/contrib/sqlalchemy/init_plugin/config/common.py new file mode 100644 index 0000000000..a0a25d7eeb --- /dev/null +++ b/starlite/contrib/sqlalchemy/init_plugin/config/common.py @@ -0,0 +1,215 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Callable, Generic, TypeVar, cast + +from starlite.constants import HTTP_DISCONNECT, HTTP_RESPONSE_START, WEBSOCKET_CLOSE, WEBSOCKET_DISCONNECT +from starlite.exceptions import ImproperlyConfiguredException +from starlite.types import Empty +from starlite.utils import get_starlite_scope_state, set_starlite_scope_state +from starlite.utils.dataclass import simple_asdict_filter_empty + +from .engine import EngineConfig + +if TYPE_CHECKING: + from typing import Any + + from sqlalchemy import Connection, Engine + from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, AsyncSession, async_sessionmaker + from sqlalchemy.orm import Mapper, Query, Session, sessionmaker + from sqlalchemy.orm.session import JoinTransactionMode + from sqlalchemy.sql import TableClause + + from starlite.datastructures.state import State + from starlite.types import BeforeMessageSendHookHandler, EmptyType, Scope + +__all__ = ( + "SESSION_SCOPE_KEY", + "SESSION_TERMINUS_ASGI_EVENTS", + "GenericSQLAlchemyConfig", + "GenericSessionConfig", +) + +SESSION_SCOPE_KEY = "_sqlalchemy_db_session" +SESSION_TERMINUS_ASGI_EVENTS = {HTTP_RESPONSE_START, HTTP_DISCONNECT, WEBSOCKET_DISCONNECT, WEBSOCKET_CLOSE} + +ConnectionT = TypeVar("ConnectionT", bound="Connection | AsyncConnection") +EngineT = TypeVar("EngineT", bound="Engine | AsyncEngine") +SessionT = TypeVar("SessionT", bound="Session | AsyncSession") +SessionMakerT = TypeVar("SessionMakerT", bound="sessionmaker | async_sessionmaker") + + +@dataclass +class GenericSessionConfig(Generic[ConnectionT, EngineT, SessionT]): + """SQLAlchemy async session config.""" + + autobegin: bool | EmptyType = Empty + autoflush: bool | EmptyType = Empty + bind: EngineT | ConnectionT | None | EmptyType = Empty + binds: dict[type[Any] | Mapper[Any] | TableClause | str, EngineT | ConnectionT] | None | EmptyType = Empty + class_: type[SessionT] | EmptyType = Empty + enable_baked_queries: bool | EmptyType = Empty + expire_on_commit: bool | EmptyType = Empty + info: dict[str, Any] | None | EmptyType = Empty + join_transaction_mode: JoinTransactionMode | EmptyType = Empty + query_cls: type[Query] | None | EmptyType = Empty + twophase: bool | EmptyType = Empty + + +@dataclass +class GenericSQLAlchemyConfig(Generic[EngineT, SessionT, SessionMakerT]): + """Common SQLAlchemy Configuration.""" + + create_engine_callable: Callable[[str], EngineT] + """Callable that creates an :class:`AsyncEngine ` instance or instance of its + subclass. + """ + session_config: GenericSessionConfig + """Configuration options for the ``sessionmaker``. + + The configuration options are documented in the SQLAlchemy documentation. + """ + session_maker_class: type[sessionmaker] | type[async_sessionmaker] + """Sessionmaker class to use.""" + before_send_handler: BeforeMessageSendHookHandler + """Handler to call before the ASGI message is sent. + + The handler should handle closing the session stored in the ASGI scope, if its still open, and committing and + uncommitted data. + """ + connection_string: str | None = field(default=None) + """Database connection string in one of the formats supported by SQLAlchemy. + + Notes: + - For async connections, the connection string must include the correct async prefix. + e.g. ``'postgresql+asyncpg://...'`` instead of ``'postgresql://'``, and for sync connections its the opposite. + + """ + engine_dependency_key: str = "db_engine" + """Key to use for the dependency injection of database engines.""" + session_dependency_key: str = "db_session" + """Key to use for the dependency injection of database sessions.""" + engine_app_state_key: str = "db_engine" + """Key under which to store the SQLAlchemy engine in the application :class:`State <.datastructures.State>` + instance. + """ + engine_config: EngineConfig = field(default_factory=EngineConfig) + """Configuration for the SQLAlchemy engine. + + The configuration options are documented in the SQLAlchemy documentation. + """ + session_maker_app_state_key: str = "session_maker_class" + """Key under which to store the SQLAlchemy ``sessionmaker`` in the application + :class:`State <.datastructures.State>` instance. + """ + session_maker: Callable[[], SessionT] | None = None + """Callable that returns a session. + + If provided, the plugin will use this rather than instantiate a sessionmaker. + """ + engine_instance: EngineT | None = None + """Optional engine to use. + + If set, the plugin will use the provided instance rather than instantiate an engine. + """ + + def __post_init__(self) -> None: + if self.connection_string is not None and self.engine_instance is not None: + raise ImproperlyConfiguredException("Only one of 'connection_string' or 'engine_instance' can be provided.") + + @property + def engine_config_dict(self) -> dict[str, Any]: + """Return the engine configuration as a dict. + + Returns: + A string keyed dict of config kwargs for the SQLAlchemy ``create_engine`` function. + """ + return simple_asdict_filter_empty(self.engine_config) + + @property + def session_config_dict(self) -> dict[str, Any]: + """Return the session configuration as a dict. + + Returns: + A string keyed dict of config kwargs for the SQLAlchemy ``sessionmaker`` class. + """ + return simple_asdict_filter_empty(self.session_config) + + @property + def signature_namespace(self) -> dict[str, Any]: + """Return the plugin's signature namespace. + + Returns: + A string keyed dict of names to be added to the namespace for signature forward reference resolution. + """ + return {} + + def create_engine(self) -> EngineT: + """Return an engine. If none exists yet, create one. + + Returns: + Getter that returns the engine instance used by the plugin. + """ + if self.engine_instance: + return self.engine_instance + + if self.connection_string is None: + raise ImproperlyConfiguredException("One of 'connection_string' or 'engine_instance' must be provided.") + + engine_config = self.engine_config_dict + try: + return self.create_engine_callable(self.connection_string, **engine_config) + except ValueError: + # likely due to a dialect that doesn't support json type + del engine_config["json_deserializer"] + del engine_config["json_serializer"] + return self.create_engine_callable(self.connection_string, **engine_config) + + def create_session_maker(self) -> Callable[[], SessionT]: + """Get a session maker. If none exists yet, create one. + + Returns: + Session factory used by the plugin. + """ + if self.session_maker: + return self.session_maker + + session_kws = self.session_config_dict + if session_kws.get("bind") is None: + session_kws["bind"] = self.create_engine() + return self.session_maker_class(**session_kws) + + def provide_engine(self, state: State) -> EngineT: + """Create an engine instance. + + Args: + state: The ``Starlite.state`` instance. + + Returns: + An engine instance. + """ + return cast("EngineT", state.get(self.engine_app_state_key)) + + def provide_session(self, state: State, scope: Scope) -> SessionT: + """Create a session instance. + + Args: + state: The ``Starlite.state`` instance. + scope: The current connection's scope. + + Returns: + A session instance. + """ + session = cast("SessionT | None", get_starlite_scope_state(scope, SESSION_SCOPE_KEY)) + if session is None: + session_maker = cast("Callable[[], SessionT]", state[self.session_maker_app_state_key]) + session = session_maker() + set_starlite_scope_state(scope, SESSION_SCOPE_KEY, session) + return session + + def app_state(self) -> dict[str, Any]: + """Key/value pairs to be stored in application state.""" + return { + self.engine_app_state_key: self.create_engine(), + self.session_maker_app_state_key: self.create_session_maker(), + } diff --git a/starlite/contrib/sqlalchemy/init_plugin/config/engine.py b/starlite/contrib/sqlalchemy/init_plugin/config/engine.py new file mode 100644 index 0000000000..71659bc3ef --- /dev/null +++ b/starlite/contrib/sqlalchemy/init_plugin/config/engine.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Callable, Literal + +from starlite.exceptions import MissingDependencyException +from starlite.serialization import decode_json, encode_json +from starlite.types import Empty + +try: + import sqlalchemy # noqa: F401 +except ImportError as e: + raise MissingDependencyException("sqlalchemy is not installed") from e + +if TYPE_CHECKING: + from typing import Any, Mapping + + from sqlalchemy.engine.interfaces import IsolationLevel + from sqlalchemy.pool import Pool + from typing_extensions import TypeAlias + + from starlite.types import EmptyType + +__all__ = ("EngineConfig",) + +_EchoFlagType: TypeAlias = "None | bool | Literal['debug']" +_ParamStyle = Literal["qmark", "numeric", "named", "format", "pyformat", "numeric_dollar"] + + +def serializer(value: Any) -> str: + """Serialize JSON field values. + + Args: + value: Any json serializable value. + + Returns: + JSON string. + """ + return encode_json(value).decode("utf-8") + + +@dataclass +class EngineConfig: + """Configuration for SQLAlchemy's :class`Engine `. + + For details see: https://docs.sqlalchemy.org/en/20/core/engines.html + """ + + connect_args: dict[Any, Any] | EmptyType = Empty + echo: _EchoFlagType | EmptyType = Empty + echo_pool: _EchoFlagType | EmptyType = Empty + enable_from_linting: bool | EmptyType = Empty + execution_options: Mapping[str, Any] | EmptyType = Empty + hide_parameters: bool | EmptyType = Empty + insertmanyvalues_page_size: int | EmptyType = Empty + isolation_level: IsolationLevel | EmptyType = Empty + json_deserializer: Callable[[str], Any] = decode_json + json_serializer: Callable[[Any], str] = serializer + label_length: int | None | EmptyType = Empty + logging_name: str | EmptyType = Empty + max_identifier_length: int | None | EmptyType = Empty + max_overflow: int | EmptyType = Empty + module: Any | None | EmptyType = Empty + paramstyle: _ParamStyle | None | EmptyType = Empty + pool: Pool | None | EmptyType = Empty + poolclass: type[Pool] | None | EmptyType = Empty + pool_logging_name: str | EmptyType = Empty + pool_pre_ping: bool | EmptyType = Empty + pool_size: int | EmptyType = Empty + pool_recycle: int | EmptyType = Empty + pool_reset_on_return: Literal["rollback", "commit"] | EmptyType = Empty + pool_timeout: int | EmptyType = Empty + pool_use_lifo: bool | EmptyType = Empty + plugins: list[str] | EmptyType = Empty + query_cache_size: int | EmptyType = Empty + use_insertmanyvalues: bool | EmptyType = Empty diff --git a/starlite/contrib/sqlalchemy/init_plugin/config/sync.py b/starlite/contrib/sqlalchemy/init_plugin/config/sync.py new file mode 100644 index 0000000000..1652d2b0b0 --- /dev/null +++ b/starlite/contrib/sqlalchemy/init_plugin/config/sync.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, cast + +from sqlalchemy import Connection, Engine, create_engine +from sqlalchemy.orm import Session, sessionmaker + +from starlite.utils import ( + delete_starlite_scope_state, + get_starlite_scope_state, +) + +from .common import SESSION_SCOPE_KEY, SESSION_TERMINUS_ASGI_EVENTS, GenericSessionConfig, GenericSQLAlchemyConfig + +if TYPE_CHECKING: + from typing import Any, Callable + + from starlite.datastructures.state import State + from starlite.types import BeforeMessageSendHookHandler, Message, Scope + +__all__ = ("SQLAlchemySyncConfig", "SyncSessionConfig") + + +async def default_before_send_handler(message: Message, _: State, scope: Scope) -> None: + """Handle closing and cleaning up sessions before sending. + + Args: + message: ASGI-``Message`` + _: A ``State`` (not used) + scope: An ASGI-``Scope`` + + Returns: + None + """ + session = cast("Session | None", get_starlite_scope_state(scope, SESSION_SCOPE_KEY)) + if session and message["type"] in SESSION_TERMINUS_ASGI_EVENTS: + session.close() + delete_starlite_scope_state(scope, SESSION_SCOPE_KEY) + + +class SyncSessionConfig(GenericSessionConfig[Connection, Engine, Session]): + pass + + +@dataclass +class SQLAlchemySyncConfig(GenericSQLAlchemyConfig[Engine, Session, sessionmaker]): + """Sync SQLAlchemy Configuration.""" + + create_engine_callable: Callable[[str], Engine] = create_engine + """Callable that creates an :class:`AsyncEngine ` instance or instance of its + subclass. + """ + session_config: SyncSessionConfig = field(default_factory=SyncSessionConfig) # pyright:ignore + """Configuration options for the ``sessionmaker``. + + The configuration options are documented in the SQLAlchemy documentation. + """ + session_maker_class: type[sessionmaker] = sessionmaker + """Sessionmaker class to use.""" + before_send_handler: BeforeMessageSendHookHandler = default_before_send_handler + """Handler to call before the ASGI message is sent. + + The handler should handle closing the session stored in the ASGI scope, if its still open, and committing and + uncommitted data. + """ + + @property + def signature_namespace(self) -> dict[str, Any]: + """Return the plugin's signature namespace. + + Returns: + A string keyed dict of names to be added to the namespace for signature forward reference resolution. + """ + return {"Engine": Engine, "Session": Session} + + def on_shutdown(self, state: State) -> None: + """Disposes of the SQLAlchemy engine. + + Args: + state: The ``Starlite.state`` instance. + + Returns: + None + """ + engine = cast("Engine", state.pop(self.engine_app_state_key)) + engine.dispose() diff --git a/starlite/contrib/sqlalchemy/init_plugin/plugin.py b/starlite/contrib/sqlalchemy/init_plugin/plugin.py new file mode 100644 index 0000000000..5841f71bd7 --- /dev/null +++ b/starlite/contrib/sqlalchemy/init_plugin/plugin.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from starlite.di import Provide +from starlite.plugins import InitPluginProtocol + +if TYPE_CHECKING: + from starlite.config.app import AppConfig + + from .config import SQLAlchemyAsyncConfig, SQLAlchemySyncConfig + +__all__ = ("SQLAlchemyInitPlugin",) + + +class SQLAlchemyInitPlugin(InitPluginProtocol): + """SQLAlchemy application lifecycle configuration.""" + + __slots__ = ("_config",) + + def __init__(self, config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig) -> None: + """Initialize ``SQLAlchemyPlugin``. + + Args: + config: configure DB connection and hook handlers and dependencies. + """ + self._config = config + + def on_app_init(self, app_config: AppConfig) -> AppConfig: + """Configure application for use with SQLAlchemy. + + Args: + app_config: The :class:`AppConfig <.config.app.AppConfig>` instance. + """ + app_config.dependencies.update( + { + self._config.engine_dependency_key: Provide(self._config.provide_engine), + self._config.session_dependency_key: Provide(self._config.provide_session), + } + ) + app_config.before_send.append(self._config.before_send_handler) + app_config.on_shutdown.append(self._config.on_shutdown) + app_config.state.update(self._config.app_state()) + app_config.signature_namespace.update(self._config.signature_namespace) + return app_config diff --git a/starlite/utils/__init__.py b/starlite/utils/__init__.py index d3fb57f0a3..55e4056b63 100644 --- a/starlite/utils/__init__.py +++ b/starlite/utils/__init__.py @@ -19,6 +19,7 @@ create_parsed_model_field, ) from .scope import ( + delete_starlite_scope_state, get_serializer_from_scope, get_starlite_scope_state, set_starlite_scope_state, @@ -43,6 +44,7 @@ "convert_dataclass_to_model", "convert_typeddict_to_model", "create_parsed_model_field", + "delete_starlite_scope_state", "deprecated", "find_index", "get_enum_string_value", diff --git a/starlite/utils/dataclass.py b/starlite/utils/dataclass.py index 1ae226bf70..5f9c77ca15 100644 --- a/starlite/utils/dataclass.py +++ b/starlite/utils/dataclass.py @@ -1,12 +1,17 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Iterable, cast - -__all__ = ("extract_dataclass_fields",) +from dataclasses import asdict, fields +from typing import TYPE_CHECKING, cast +from starlite.types import DataclassProtocol, Empty if TYPE_CHECKING: - from starlite.types import DataclassProtocol + from typing import Any, Iterable + +__all__ = ( + "asdict_filter_empty", + "extract_dataclass_fields", +) def extract_dataclass_fields( @@ -28,3 +33,46 @@ def extract_dataclass_fields( if (not exclude_none or getattr(dt, field_name) is not None) and ((include is not None and field_name in include) or include is None) ) + + +def asdict_filter_empty(obj: DataclassProtocol) -> dict[str, Any]: + """Same as stdlib's ``dataclasses.asdict`` with additional filtering for :class:`Empty<.types.Empty>`. + + Args: + obj: A dataclass instance. + + Returns: + ``obj`` converted into a ``dict`` of its fields, with any :class:`Empty<.types.Empty>` values excluded. + """ + return {k: v for k, v in asdict(obj).items() if v is not Empty} + + +def simple_asdict(obj: DataclassProtocol) -> dict[str, Any]: + """Recursively convert a dataclass instance into a ``dict`` of its fields, without using ``copy.deepcopy()``. + + The standard library ``dataclasses.asdict()`` function uses ``copy.deepcopy()`` on any value that is not a + dataclass, dict, list or tuple, which presents a problem when the dataclass holds items that cannot be pickled. + + This function provides an alternative that does not use ``copy.deepcopy()``, and is a much simpler implementation, + only recursing into other dataclasses. + + Args: + obj: A dataclass instance. + + Returns: + ``obj`` converted into a ``dict`` of its fields. + """ + field_values = ((field.name, getattr(obj, field.name)) for field in fields(obj)) + return {k: simple_asdict(v) if isinstance(v, DataclassProtocol) else v for k, v in field_values} + + +def simple_asdict_filter_empty(obj: DataclassProtocol) -> dict[str, Any]: + """Same as asdict_filter_empty but uses ``simple_asdict``. + + Args: + obj: A dataclass instance. + + Returns: + ``obj`` converted into a ``dict`` of its fields, with any :class:`Empty<.types.Empty>` values excluded. + """ + return {k: v for k, v in simple_asdict(obj).items() if v is not Empty} diff --git a/starlite/utils/scope.py b/starlite/utils/scope.py index c6609c42cd..db197bf74d 100644 --- a/starlite/utils/scope.py +++ b/starlite/utils/scope.py @@ -4,7 +4,12 @@ from starlite.constants import SCOPE_STATE_NAMESPACE -__all__ = ("get_serializer_from_scope", "get_starlite_scope_state", "set_starlite_scope_state") +__all__ = ( + "delete_starlite_scope_state", + "get_serializer_from_scope", + "get_starlite_scope_state", + "set_starlite_scope_state", +) if TYPE_CHECKING: @@ -69,3 +74,14 @@ def set_starlite_scope_state(scope: Scope, key: str, value: Any) -> None: value: Value for key. """ scope["state"].setdefault(SCOPE_STATE_NAMESPACE, {})[key] = value + + +def delete_starlite_scope_state(scope: Scope, key: str) -> None: + """Delete an internal value from connection scope state. + + Args: + scope: The connection scope. + key: Key to set under internal namespace in scope state. + value: Value for key. + """ + del scope["state"][SCOPE_STATE_NAMESPACE][key] diff --git a/tests/contrib/sqlalchemy/init_plugin/__init__.py b/tests/contrib/sqlalchemy/init_plugin/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/contrib/sqlalchemy/init_plugin/config/__init__.py b/tests/contrib/sqlalchemy/init_plugin/config/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/contrib/sqlalchemy/init_plugin/config/test_common.py b/tests/contrib/sqlalchemy/init_plugin/config/test_common.py new file mode 100644 index 0000000000..8c6c0cbe49 --- /dev/null +++ b/tests/contrib/sqlalchemy/init_plugin/config/test_common.py @@ -0,0 +1,166 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING +from unittest.mock import MagicMock, patch + +import pytest +from sqlalchemy import create_engine + +from starlite.constants import SCOPE_STATE_NAMESPACE +from starlite.contrib.sqlalchemy.init_plugin.config import SQLAlchemyAsyncConfig, SQLAlchemySyncConfig +from starlite.contrib.sqlalchemy.init_plugin.config.common import SESSION_SCOPE_KEY +from starlite.datastructures import State +from starlite.exceptions import ImproperlyConfiguredException + +if TYPE_CHECKING: + from typing import Any + + from pytest import MonkeyPatch + + from starlite.types import Scope + + +@pytest.fixture(name="config_cls", params=[SQLAlchemySyncConfig, SQLAlchemyAsyncConfig]) +def _config_cls(request: Any) -> type[SQLAlchemySyncConfig | SQLAlchemyAsyncConfig]: + """Return SQLAlchemy config class.""" + return request.param # type:ignore[no-any-return] + + +def test_raise_improperly_configured_exception(config_cls: type[SQLAlchemySyncConfig | SQLAlchemySyncConfig]) -> None: + """Test raise ImproperlyConfiguredException if both engine and connection string are provided.""" + with pytest.raises(ImproperlyConfiguredException): + config_cls(connection_string="sqlite://", engine_instance=create_engine("sqlite://")) + + +def test_engine_config_dict_with_no_provided_config( + config_cls: type[SQLAlchemySyncConfig | SQLAlchemySyncConfig], +) -> None: + """Test engine_config_dict with no provided config.""" + config = config_cls() + assert config.engine_config_dict.keys() == {"json_deserializer", "json_serializer"} + + +def test_session_config_dict_with_no_provided_config( + config_cls: type[SQLAlchemySyncConfig | SQLAlchemySyncConfig], +) -> None: + """Test session_config_dict with no provided config.""" + config = config_cls() + assert config.session_config_dict == {} + + +def test_config_create_engine_if_engine_instance_provided( + config_cls: type[SQLAlchemySyncConfig | SQLAlchemySyncConfig], +) -> None: + """Test create_engine if engine instance provided.""" + engine = create_engine("sqlite://") + config = config_cls(engine_instance=engine) + assert config.create_engine() == engine + + +def test_create_engine_if_no_engine_instance_or_connection_string_provided( + config_cls: type[SQLAlchemySyncConfig | SQLAlchemySyncConfig], +) -> None: + """Test create_engine if no engine instance or connection string provided.""" + config = config_cls() + with pytest.raises(ImproperlyConfiguredException): + config.create_engine() + + +def test_call_create_engine_callable_value_error_handling( + config_cls: type[SQLAlchemySyncConfig | SQLAlchemySyncConfig], monkeypatch: MonkeyPatch +) -> None: + """If the dialect doesn't support JSON types, we get a ValueError. + This should be handled by removing the JSON serializer/deserializer kwargs. + """ + call_count = 0 + + def side_effect(*args: Any, **kwargs: Any) -> None: + nonlocal call_count + call_count += 1 + if call_count == 1: + raise ValueError() + + config = config_cls(connection_string="sqlite://") + create_engine_callable_mock = MagicMock(side_effect=side_effect) + monkeypatch.setattr(config, "create_engine_callable", create_engine_callable_mock) + + config.create_engine() + + assert create_engine_callable_mock.call_count == 2 + first_call, second_call = create_engine_callable_mock.mock_calls + assert first_call.kwargs.keys() == {"json_deserializer", "json_serializer"} + assert second_call.kwargs.keys() == set() + + +def test_create_session_maker_if_session_maker_provided( + config_cls: type[SQLAlchemySyncConfig | SQLAlchemySyncConfig], +) -> None: + """Test create_session_maker if session maker provided to config.""" + session_maker = MagicMock() + config = config_cls(session_maker=session_maker) + assert config.create_session_maker() == session_maker + + +def test_create_session_maker_if_no_session_maker_provided_and_bind_provided( + config_cls: type[SQLAlchemySyncConfig | SQLAlchemySyncConfig], monkeypatch: MonkeyPatch +) -> None: + """Test create_session_maker if no session maker provided to config.""" + config = config_cls() + config.session_config.bind = create_engine("sqlite://") + create_engine_mock = MagicMock() + monkeypatch.setattr(config, "create_engine", create_engine_mock) + assert config.session_maker is None + assert isinstance(config.create_session_maker(), config.session_maker_class) + create_engine_mock.assert_not_called() + + +def test_create_session_maker_if_no_session_maker_or_bind_provided( + config_cls: type[SQLAlchemySyncConfig | SQLAlchemySyncConfig], monkeypatch: MonkeyPatch +) -> None: + """Test create_session_maker if no session maker or bind provided to config.""" + config = config_cls() + create_engine_mock = MagicMock(return_value=create_engine("sqlite://")) + monkeypatch.setattr(config, "create_engine", create_engine_mock) + assert config.session_maker is None + assert isinstance(config.create_session_maker(), config.session_maker_class) + create_engine_mock.assert_called_once() + + +def test_create_session_instance_if_session_already_in_scope_state( + config_cls: type[SQLAlchemySyncConfig | SQLAlchemySyncConfig], +) -> None: + """Test provide_session if session already in scope state.""" + with patch( + "starlite.contrib.sqlalchemy.init_plugin.config.common.get_starlite_scope_state" + ) as get_starlite_scope_state_mock: + session_mock = MagicMock() + get_starlite_scope_state_mock.return_value = session_mock + config = config_cls() + assert config.provide_session(State(), {}) is session_mock # type:ignore[arg-type] + + +def test_create_session_instance_if_session_not_in_scope_state( + config_cls: type[SQLAlchemySyncConfig | SQLAlchemySyncConfig], +) -> None: + """Test provide_session if session not in scope state.""" + with patch( + "starlite.contrib.sqlalchemy.init_plugin.config.common.get_starlite_scope_state" + ) as get_starlite_scope_state_mock: + get_starlite_scope_state_mock.return_value = None + config = config_cls() + state = State() + state[config.session_maker_app_state_key] = MagicMock() + scope: Scope = {"state": {}} # type:ignore[assignment] + assert isinstance(config.provide_session(state, scope), MagicMock) + assert SESSION_SCOPE_KEY in scope["state"][SCOPE_STATE_NAMESPACE] + + +def test_app_state(config_cls: type[SQLAlchemySyncConfig | SQLAlchemySyncConfig], monkeypatch: MonkeyPatch) -> None: + """Test app_state.""" + config = config_cls(connection_string="sqlite://") + with patch.object(config, "create_session_maker") as create_session_maker_mock, patch.object( + config, "create_engine" + ) as create_engine_mock: + assert config.app_state().keys() == {config.engine_app_state_key, config.session_maker_app_state_key} + create_session_maker_mock.assert_called_once() + create_engine_mock.assert_called_once() From 4b546a0749f0f3e023f8d2228be0a739042aa0ea Mon Sep 17 00:00:00 2001 From: Peter Schutt Date: Mon, 27 Mar 2023 12:51:48 +1000 Subject: [PATCH 02/16] Fix example docstrings. --- .../plugins/sqlalchemy_init_plugin/sqlalchemy_async.py | 3 +-- .../examples/plugins/sqlalchemy_init_plugin/sqlalchemy_sync.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/docs/examples/plugins/sqlalchemy_init_plugin/sqlalchemy_async.py b/docs/examples/plugins/sqlalchemy_init_plugin/sqlalchemy_async.py index a9f03a7fb0..615c2cbdf0 100644 --- a/docs/examples/plugins/sqlalchemy_init_plugin/sqlalchemy_async.py +++ b/docs/examples/plugins/sqlalchemy_init_plugin/sqlalchemy_async.py @@ -13,8 +13,7 @@ @get(path="/sqlalchemy-app") async def async_sqlalchemy_init(db_session: AsyncSession, db_engine: AsyncEngine) -> str: - """Create a new company and return it.""" - + """Interact with SQLAlchemy engine and session.""" one = (await db_session.execute(text("SELECT 1"))).scalar_one() async with db_engine.begin() as conn: diff --git a/docs/examples/plugins/sqlalchemy_init_plugin/sqlalchemy_sync.py b/docs/examples/plugins/sqlalchemy_init_plugin/sqlalchemy_sync.py index e0572b1440..983ba1b92b 100644 --- a/docs/examples/plugins/sqlalchemy_init_plugin/sqlalchemy_sync.py +++ b/docs/examples/plugins/sqlalchemy_init_plugin/sqlalchemy_sync.py @@ -14,7 +14,7 @@ @get(path="/sqlalchemy-app") def async_sqlalchemy_init(db_session: Session, db_engine: Engine) -> str: - """Create a new company and return it.""" + """Interact with SQLAlchemy engine and session.""" one = db_session.execute(text("SELECT 1")).scalar_one() with db_engine.connect() as conn: From 1fed563874700ad38a4a441c1a0e328c1361a238 Mon Sep 17 00:00:00 2001 From: Peter Schutt Date: Mon, 27 Mar 2023 13:20:44 +1000 Subject: [PATCH 03/16] Tests for dataclass utils. --- starlite/utils/dataclass.py | 9 ++++++++- tests/utils/test_dataclass.py | 38 +++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) create mode 100644 tests/utils/test_dataclass.py diff --git a/starlite/utils/dataclass.py b/starlite/utils/dataclass.py index 5f9c77ca15..05a15ccfeb 100644 --- a/starlite/utils/dataclass.py +++ b/starlite/utils/dataclass.py @@ -11,6 +11,8 @@ __all__ = ( "asdict_filter_empty", "extract_dataclass_fields", + "simple_asdict", + "simple_asdict_filter_empty", ) @@ -75,4 +77,9 @@ def simple_asdict_filter_empty(obj: DataclassProtocol) -> dict[str, Any]: Returns: ``obj`` converted into a ``dict`` of its fields, with any :class:`Empty<.types.Empty>` values excluded. """ - return {k: v for k, v in simple_asdict(obj).items() if v is not Empty} + field_values = ((field.name, getattr(obj, field.name)) for field in fields(obj)) + return { + k: simple_asdict_filter_empty(v) if isinstance(v, DataclassProtocol) else v + for k, v in field_values + if v is not Empty + } diff --git a/tests/utils/test_dataclass.py b/tests/utils/test_dataclass.py new file mode 100644 index 0000000000..c4b5af588e --- /dev/null +++ b/tests/utils/test_dataclass.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +from starlite.types import Empty +from starlite.utils.dataclass import asdict_filter_empty, simple_asdict, simple_asdict_filter_empty + +if TYPE_CHECKING: + from starlite.types import EmptyType + + +@dataclass +class Foo: + bar: str = "baz" + baz: int | EmptyType = Empty + qux: list[str] = field(default_factory=lambda: ["quux", "quuz"]) + + +@dataclass +class Bar: + foo: Foo = field(default_factory=Foo) + quux: list[Foo] = field(default_factory=lambda: [Foo(), Foo()]) + + +def test_asdict_filter_empty() -> None: + foo = Foo() + assert asdict_filter_empty(foo) == {"bar": "baz", "qux": ["quux", "quuz"]} + + +def test_simple_asdict() -> None: + bar = Bar() + assert simple_asdict(bar) == {"foo": {"bar": "baz", "baz": Empty, "qux": ["quux", "quuz"]}, "quux": [Foo(), Foo()]} + + +def test_simple_asdict_filter_empty() -> None: + bar = Bar() + assert simple_asdict_filter_empty(bar) == {"foo": {"bar": "baz", "qux": ["quux", "quuz"]}, "quux": [Foo(), Foo()]} From 1a57ddce56557f6be2d4a937b52467b2d43e3bb3 Mon Sep 17 00:00:00 2001 From: Peter Schutt Date: Mon, 27 Mar 2023 13:43:11 +1000 Subject: [PATCH 04/16] Tests for default before send handlers. --- .../init_plugin/config/test_asyncio.py | 33 +++++++++++++++++++ .../init_plugin/config/test_sync.py | 33 +++++++++++++++++++ 2 files changed, 66 insertions(+) create mode 100644 tests/contrib/sqlalchemy/init_plugin/config/test_asyncio.py create mode 100644 tests/contrib/sqlalchemy/init_plugin/config/test_sync.py diff --git a/tests/contrib/sqlalchemy/init_plugin/config/test_asyncio.py b/tests/contrib/sqlalchemy/init_plugin/config/test_asyncio.py new file mode 100644 index 0000000000..71ae5b6620 --- /dev/null +++ b/tests/contrib/sqlalchemy/init_plugin/config/test_asyncio.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from starlite import get +from starlite.contrib.sqlalchemy.init_plugin import SQLAlchemyAsyncConfig, SQLAlchemyInitPlugin +from starlite.testing import create_test_client + +if TYPE_CHECKING: + from typing import Any + + from sqlalchemy.ext.asyncio import AsyncSession + + from starlite.types import Scope + + +def test_default_before_send_handler() -> None: + """Test default_before_send_handler.""" + + captured_scope_state: dict[str, Any] | None = None + config = SQLAlchemyAsyncConfig(connection_string="sqlite+aiosqlite://") + plugin = SQLAlchemyInitPlugin(config=config) + + @get() + def test_handler(db_session: AsyncSession, scope: Scope) -> None: + nonlocal captured_scope_state + captured_scope_state = scope["state"] + assert db_session is captured_scope_state[config.session_dependency_key] + + with create_test_client(route_handlers=[test_handler], plugins=[plugin]) as client: + client.get("/") + assert captured_scope_state is not None + assert config.session_dependency_key not in captured_scope_state # pyright: ignore diff --git a/tests/contrib/sqlalchemy/init_plugin/config/test_sync.py b/tests/contrib/sqlalchemy/init_plugin/config/test_sync.py new file mode 100644 index 0000000000..8c8835a438 --- /dev/null +++ b/tests/contrib/sqlalchemy/init_plugin/config/test_sync.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from starlite import get +from starlite.contrib.sqlalchemy.init_plugin import SQLAlchemyInitPlugin, SQLAlchemySyncConfig +from starlite.testing import create_test_client + +if TYPE_CHECKING: + from typing import Any + + from sqlalchemy.orm import Session + + from starlite.types import Scope + + +def test_default_before_send_handler() -> None: + """Test default_before_send_handler.""" + + captured_scope_state: dict[str, Any] | None = None + config = SQLAlchemySyncConfig(connection_string="sqlite+aiosqlite://") + plugin = SQLAlchemyInitPlugin(config=config) + + @get() + def test_handler(db_session: Session, scope: Scope) -> None: + nonlocal captured_scope_state + captured_scope_state = scope["state"] + assert db_session is captured_scope_state[config.session_dependency_key] + + with create_test_client(route_handlers=[test_handler], plugins=[plugin]) as client: + client.get("/") + assert captured_scope_state is not None + assert config.session_dependency_key not in captured_scope_state From 28322cb36b8ea6b890ba228facac98b7ccc35777 Mon Sep 17 00:00:00 2001 From: Peter Schutt Date: Mon, 27 Mar 2023 13:47:03 +1000 Subject: [PATCH 05/16] Test for engine json serializer. --- .../contrib/sqlalchemy/init_plugin/config/test_engine.py | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 tests/contrib/sqlalchemy/init_plugin/config/test_engine.py diff --git a/tests/contrib/sqlalchemy/init_plugin/config/test_engine.py b/tests/contrib/sqlalchemy/init_plugin/config/test_engine.py new file mode 100644 index 0000000000..dc72c40beb --- /dev/null +++ b/tests/contrib/sqlalchemy/init_plugin/config/test_engine.py @@ -0,0 +1,8 @@ +from __future__ import annotations + +from starlite.contrib.sqlalchemy.init_plugin.config.engine import serializer + + +def test_serializer_returns_string() -> None: + """Test that serializer returns a string.""" + assert isinstance(serializer({"a": "b"}), str) From 8f0ed462bb2aaedcb936ed9419685c24a7d8b022 Mon Sep 17 00:00:00 2001 From: Peter Schutt Date: Mon, 27 Mar 2023 13:51:48 +1000 Subject: [PATCH 06/16] Update starlite/contrib/sqlalchemy/init_plugin/config/common.py --- starlite/contrib/sqlalchemy/init_plugin/config/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/starlite/contrib/sqlalchemy/init_plugin/config/common.py b/starlite/contrib/sqlalchemy/init_plugin/config/common.py index a0a25d7eeb..694f6406eb 100644 --- a/starlite/contrib/sqlalchemy/init_plugin/config/common.py +++ b/starlite/contrib/sqlalchemy/init_plugin/config/common.py @@ -142,7 +142,7 @@ def signature_namespace(self) -> dict[str, Any]: Returns: A string keyed dict of names to be added to the namespace for signature forward reference resolution. """ - return {} + return {} # pragma: no cover def create_engine(self) -> EngineT: """Return an engine. If none exists yet, create one. From 8f02d0be4fc692e45ddafad41e52874ed283d1c3 Mon Sep 17 00:00:00 2001 From: Peter Schutt Date: Mon, 27 Mar 2023 12:46:36 +1000 Subject: [PATCH 07/16] SQLAlchemy 2.0 `InitPluginProtocol` implementation. --- docs/conf.py | 5 +- .../__init__.py | 0 .../sqlalchemy_async.py | 0 .../sqlalchemy_relationships.py | 0 .../sqlalchemy_relationships_to_many.py | 0 .../sqlalchemy_sync.py | 0 .../sqlalchemy_init_plugin/__init__.py | 0 .../sqlalchemy_async.py | 31 +++ .../sqlalchemy_init_plugin/sqlalchemy_sync.py | 31 +++ ..._plugin.py => test_sqlalchemy_1_plugin.py} | 8 +- .../plugins/test_sqlalchemy_init_plugin.py | 20 ++ docs/reference/contrib/index.rst | 1 + .../contrib/sqlalchemy/config/asyncio.rst | 5 + .../contrib/sqlalchemy/config/common.rst | 5 + .../contrib/sqlalchemy/config/index.rst | 9 + .../contrib/sqlalchemy/config/sync.rst | 5 + docs/reference/contrib/sqlalchemy/index.rst | 8 + docs/reference/contrib/sqlalchemy/plugin.rst | 5 + docs/usage/plugins/sqlalchemy.rst | 138 ++--------- docs/usage/responses.rst | 2 +- starlite/constants.py | 7 +- .../sqlalchemy/init_plugin/__init__.py | 19 ++ .../sqlalchemy/init_plugin/config/__init__.py | 13 ++ .../sqlalchemy/init_plugin/config/asyncio.py | 92 ++++++++ .../sqlalchemy/init_plugin/config/common.py | 215 ++++++++++++++++++ .../sqlalchemy/init_plugin/config/engine.py | 76 +++++++ .../sqlalchemy/init_plugin/config/sync.py | 87 +++++++ .../contrib/sqlalchemy/init_plugin/plugin.py | 45 ++++ starlite/utils/__init__.py | 2 + starlite/utils/dataclass.py | 56 ++++- starlite/utils/scope.py | 18 +- .../sqlalchemy/init_plugin/__init__.py | 0 .../sqlalchemy/init_plugin/config/__init__.py | 0 .../init_plugin/config/test_common.py | 166 ++++++++++++++ 34 files changed, 931 insertions(+), 138 deletions(-) rename docs/examples/plugins/{sqlalchemy_plugin => sqlalchemy_1_plugin}/__init__.py (100%) rename docs/examples/plugins/{sqlalchemy_plugin => sqlalchemy_1_plugin}/sqlalchemy_async.py (100%) rename docs/examples/plugins/{sqlalchemy_plugin => sqlalchemy_1_plugin}/sqlalchemy_relationships.py (100%) rename docs/examples/plugins/{sqlalchemy_plugin => sqlalchemy_1_plugin}/sqlalchemy_relationships_to_many.py (100%) rename docs/examples/plugins/{sqlalchemy_plugin => sqlalchemy_1_plugin}/sqlalchemy_sync.py (100%) create mode 100644 docs/examples/plugins/sqlalchemy_init_plugin/__init__.py create mode 100644 docs/examples/plugins/sqlalchemy_init_plugin/sqlalchemy_async.py create mode 100644 docs/examples/plugins/sqlalchemy_init_plugin/sqlalchemy_sync.py rename docs/examples/tests/plugins/{test_sqlalchemy_plugin.py => test_sqlalchemy_1_plugin.py} (83%) create mode 100644 docs/examples/tests/plugins/test_sqlalchemy_init_plugin.py create mode 100644 docs/reference/contrib/sqlalchemy/config/asyncio.rst create mode 100644 docs/reference/contrib/sqlalchemy/config/common.rst create mode 100644 docs/reference/contrib/sqlalchemy/config/index.rst create mode 100644 docs/reference/contrib/sqlalchemy/config/sync.rst create mode 100644 docs/reference/contrib/sqlalchemy/index.rst create mode 100644 docs/reference/contrib/sqlalchemy/plugin.rst create mode 100644 starlite/contrib/sqlalchemy/init_plugin/__init__.py create mode 100644 starlite/contrib/sqlalchemy/init_plugin/config/__init__.py create mode 100644 starlite/contrib/sqlalchemy/init_plugin/config/asyncio.py create mode 100644 starlite/contrib/sqlalchemy/init_plugin/config/common.py create mode 100644 starlite/contrib/sqlalchemy/init_plugin/config/engine.py create mode 100644 starlite/contrib/sqlalchemy/init_plugin/config/sync.py create mode 100644 starlite/contrib/sqlalchemy/init_plugin/plugin.py create mode 100644 tests/contrib/sqlalchemy/init_plugin/__init__.py create mode 100644 tests/contrib/sqlalchemy/init_plugin/config/__init__.py create mode 100644 tests/contrib/sqlalchemy/init_plugin/config/test_common.py diff --git a/docs/conf.py b/docs/conf.py index b33c43c698..99cab96418 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -35,7 +35,7 @@ "msgspec": ("https://jcristharif.com/msgspec/", None), "anyio": ("https://anyio.readthedocs.io/en/stable/", None), "multidict": ("https://multidict.aio-libs.org/en/stable/", None), - "sqlalchemy": ("https://docs.sqlalchemy.org/en/14/", None), + "sqlalchemy": ("https://docs.sqlalchemy.org/en/20/", None), "click": ("https://click.palletsprojects.com/en/8.1.x/", None), "redis": ("https://redis-py.readthedocs.io/en/stable/", None), "picologging": ("https://microsoft.github.io/picologging", None), @@ -113,6 +113,9 @@ "starlite.contrib.sqlalchemy_1.plugin.SQLAlchemyPlugin.handle_string_type": {"BINARY", "VARBINARY", "LargeBinary"}, "starlite.contrib.sqlalchemy_1.plugin.SQLAlchemyPlugin.is_plugin_supported_type": {"DeclarativeMeta"}, re.compile(r"starlite\.plugins.*"): re.compile(".*(ModelT|DataContainerT)"), + re.compile(r"starlite\.contrib\.sqlalchemy\.init_plugin\.config\.common.*"): re.compile( + ".*(ConnectionT|EngineT|SessionT|SessionMakerT)" + ), } diff --git a/docs/examples/plugins/sqlalchemy_plugin/__init__.py b/docs/examples/plugins/sqlalchemy_1_plugin/__init__.py similarity index 100% rename from docs/examples/plugins/sqlalchemy_plugin/__init__.py rename to docs/examples/plugins/sqlalchemy_1_plugin/__init__.py diff --git a/docs/examples/plugins/sqlalchemy_plugin/sqlalchemy_async.py b/docs/examples/plugins/sqlalchemy_1_plugin/sqlalchemy_async.py similarity index 100% rename from docs/examples/plugins/sqlalchemy_plugin/sqlalchemy_async.py rename to docs/examples/plugins/sqlalchemy_1_plugin/sqlalchemy_async.py diff --git a/docs/examples/plugins/sqlalchemy_plugin/sqlalchemy_relationships.py b/docs/examples/plugins/sqlalchemy_1_plugin/sqlalchemy_relationships.py similarity index 100% rename from docs/examples/plugins/sqlalchemy_plugin/sqlalchemy_relationships.py rename to docs/examples/plugins/sqlalchemy_1_plugin/sqlalchemy_relationships.py diff --git a/docs/examples/plugins/sqlalchemy_plugin/sqlalchemy_relationships_to_many.py b/docs/examples/plugins/sqlalchemy_1_plugin/sqlalchemy_relationships_to_many.py similarity index 100% rename from docs/examples/plugins/sqlalchemy_plugin/sqlalchemy_relationships_to_many.py rename to docs/examples/plugins/sqlalchemy_1_plugin/sqlalchemy_relationships_to_many.py diff --git a/docs/examples/plugins/sqlalchemy_plugin/sqlalchemy_sync.py b/docs/examples/plugins/sqlalchemy_1_plugin/sqlalchemy_sync.py similarity index 100% rename from docs/examples/plugins/sqlalchemy_plugin/sqlalchemy_sync.py rename to docs/examples/plugins/sqlalchemy_1_plugin/sqlalchemy_sync.py diff --git a/docs/examples/plugins/sqlalchemy_init_plugin/__init__.py b/docs/examples/plugins/sqlalchemy_init_plugin/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/docs/examples/plugins/sqlalchemy_init_plugin/sqlalchemy_async.py b/docs/examples/plugins/sqlalchemy_init_plugin/sqlalchemy_async.py new file mode 100644 index 0000000000..a9f03a7fb0 --- /dev/null +++ b/docs/examples/plugins/sqlalchemy_init_plugin/sqlalchemy_async.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from sqlalchemy import text + +from starlite import Starlite, get +from starlite.contrib.sqlalchemy.init_plugin import SQLAlchemyAsyncConfig, SQLAlchemyInitPlugin + +if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession + + +@get(path="/sqlalchemy-app") +async def async_sqlalchemy_init(db_session: AsyncSession, db_engine: AsyncEngine) -> str: + """Create a new company and return it.""" + + one = (await db_session.execute(text("SELECT 1"))).scalar_one() + + async with db_engine.begin() as conn: + two = (await conn.execute(text("SELECT 2"))).scalar_one() + + return f"{one} {two}" + + +sqlalchemy_config = SQLAlchemyAsyncConfig(connection_string="sqlite+aiosqlite:///test.sqlite") + +app = Starlite( + route_handlers=[async_sqlalchemy_init], + plugins=[SQLAlchemyInitPlugin(config=sqlalchemy_config)], +) diff --git a/docs/examples/plugins/sqlalchemy_init_plugin/sqlalchemy_sync.py b/docs/examples/plugins/sqlalchemy_init_plugin/sqlalchemy_sync.py new file mode 100644 index 0000000000..e0572b1440 --- /dev/null +++ b/docs/examples/plugins/sqlalchemy_init_plugin/sqlalchemy_sync.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from sqlalchemy import text + +from starlite import Starlite, get +from starlite.contrib.sqlalchemy.init_plugin import SQLAlchemyInitPlugin, SQLAlchemySyncConfig + +if TYPE_CHECKING: + from sqlalchemy import Engine + from sqlalchemy.orm import Session + + +@get(path="/sqlalchemy-app") +def async_sqlalchemy_init(db_session: Session, db_engine: Engine) -> str: + """Create a new company and return it.""" + one = db_session.execute(text("SELECT 1")).scalar_one() + + with db_engine.connect() as conn: + two = conn.execute(text("SELECT 2")).scalar_one() + + return f"{one} {two}" + + +sqlalchemy_config = SQLAlchemySyncConfig(connection_string="sqlite:///test.sqlite") + +app = Starlite( + route_handlers=[async_sqlalchemy_init], + plugins=[SQLAlchemyInitPlugin(config=sqlalchemy_config)], +) diff --git a/docs/examples/tests/plugins/test_sqlalchemy_plugin.py b/docs/examples/tests/plugins/test_sqlalchemy_1_plugin.py similarity index 83% rename from docs/examples/tests/plugins/test_sqlalchemy_plugin.py rename to docs/examples/tests/plugins/test_sqlalchemy_1_plugin.py index a327046bb6..1cc7727d6c 100644 --- a/docs/examples/tests/plugins/test_sqlalchemy_plugin.py +++ b/docs/examples/tests/plugins/test_sqlalchemy_1_plugin.py @@ -3,14 +3,14 @@ import pytest -from examples.plugins.sqlalchemy_plugin.sqlalchemy_async import app as async_sqla_app -from examples.plugins.sqlalchemy_plugin.sqlalchemy_relationships import ( +from examples.plugins.sqlalchemy_1_plugin.sqlalchemy_async import app as async_sqla_app +from examples.plugins.sqlalchemy_1_plugin.sqlalchemy_relationships import ( app as relationship_app, ) -from examples.plugins.sqlalchemy_plugin.sqlalchemy_relationships_to_many import ( +from examples.plugins.sqlalchemy_1_plugin.sqlalchemy_relationships_to_many import ( app as relationship_app_to_many, ) -from examples.plugins.sqlalchemy_plugin.sqlalchemy_sync import app as sync_sqla_app +from examples.plugins.sqlalchemy_1_plugin.sqlalchemy_sync import app as sync_sqla_app from starlite import Starlite from starlite.testing import TestClient diff --git a/docs/examples/tests/plugins/test_sqlalchemy_init_plugin.py b/docs/examples/tests/plugins/test_sqlalchemy_init_plugin.py new file mode 100644 index 0000000000..7e3d93a7df --- /dev/null +++ b/docs/examples/tests/plugins/test_sqlalchemy_init_plugin.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +from examples.plugins.sqlalchemy_init_plugin.sqlalchemy_async import app as async_sqla_app +from examples.plugins.sqlalchemy_init_plugin.sqlalchemy_sync import app as sync_sqla_app +from starlite.testing import TestClient + +if TYPE_CHECKING: + from starlite import Starlite + + +@pytest.mark.parametrize("app", [async_sqla_app, sync_sqla_app]) +def test_app(app: Starlite) -> None: + with TestClient(app=app) as client: + res = client.get("/sqlalchemy-app") + assert res.status_code == 200 + assert res.json() == "1 2" diff --git a/docs/reference/contrib/index.rst b/docs/reference/contrib/index.rst index e8b35dab0f..52970c1da4 100644 --- a/docs/reference/contrib/index.rst +++ b/docs/reference/contrib/index.rst @@ -10,5 +10,6 @@ contrib mako opentelemetry piccolo_orm + sqlalchemy/index sqlalchemy_1/index tortoise_orm diff --git a/docs/reference/contrib/sqlalchemy/config/asyncio.rst b/docs/reference/contrib/sqlalchemy/config/asyncio.rst new file mode 100644 index 0000000000..e138610870 --- /dev/null +++ b/docs/reference/contrib/sqlalchemy/config/asyncio.rst @@ -0,0 +1,5 @@ +asyncio +======= + +.. automodule:: starlite.contrib.sqlalchemy.init_plugin.config.asyncio + :members: diff --git a/docs/reference/contrib/sqlalchemy/config/common.rst b/docs/reference/contrib/sqlalchemy/config/common.rst new file mode 100644 index 0000000000..e8e2467cf7 --- /dev/null +++ b/docs/reference/contrib/sqlalchemy/config/common.rst @@ -0,0 +1,5 @@ +asyncio +======= + +.. automodule:: starlite.contrib.sqlalchemy.init_plugin.config.common + :members: diff --git a/docs/reference/contrib/sqlalchemy/config/index.rst b/docs/reference/contrib/sqlalchemy/config/index.rst new file mode 100644 index 0000000000..1919a60794 --- /dev/null +++ b/docs/reference/contrib/sqlalchemy/config/index.rst @@ -0,0 +1,9 @@ +config +====== + +.. toctree:: + :titlesonly: + + asyncio + common + sync diff --git a/docs/reference/contrib/sqlalchemy/config/sync.rst b/docs/reference/contrib/sqlalchemy/config/sync.rst new file mode 100644 index 0000000000..ddf769ccf1 --- /dev/null +++ b/docs/reference/contrib/sqlalchemy/config/sync.rst @@ -0,0 +1,5 @@ +sync +==== + +.. automodule:: starlite.contrib.sqlalchemy.init_plugin.config.sync + :members: diff --git a/docs/reference/contrib/sqlalchemy/index.rst b/docs/reference/contrib/sqlalchemy/index.rst new file mode 100644 index 0000000000..3ee4e8d69b --- /dev/null +++ b/docs/reference/contrib/sqlalchemy/index.rst @@ -0,0 +1,8 @@ +sqlalchemy +========== + +.. toctree:: + :titlesonly: + + config/index + plugin diff --git a/docs/reference/contrib/sqlalchemy/plugin.rst b/docs/reference/contrib/sqlalchemy/plugin.rst new file mode 100644 index 0000000000..1a2b06b0c3 --- /dev/null +++ b/docs/reference/contrib/sqlalchemy/plugin.rst @@ -0,0 +1,5 @@ +plugin +====== + +.. automodule:: starlite.contrib.sqlalchemy.init_plugin.plugin + :members: diff --git a/docs/usage/plugins/sqlalchemy.rst b/docs/usage/plugins/sqlalchemy.rst index 8f4d4af4da..bc98606da0 100644 --- a/docs/usage/plugins/sqlalchemy.rst +++ b/docs/usage/plugins/sqlalchemy.rst @@ -1,46 +1,29 @@ -SQLAlchemy Plugin -================= +SQLAlchemy Plugins +================== Starlite comes with built-in support for `SQLAlchemy `_ via -the :class:`SQLAlchemyPlugin <.contrib.sqlalchemy_1.plugin.SQLAlchemyPlugin>`. +the :class:`SQLAlchemyInitPlugin <.contrib.sqlalchemy.init_plugin.plugin.SQLAlchemyInitPlugin>`. Features -------- - * Managed `sessions `_ (sync and async) including dependency injection -* Automatic serialization of SQLAlchemy models powered pydantic -* Data validation based on SQLAlchemy models powered pydantic - -.. seealso:: - - The following examples use SQLAlchemy's "2.0 Style" introduced in SQLAlchemy 1.4. - - If you are unfamiliar with it, you can find a comprehensive migration guide in SQLAlchemy's - documentation `here `_, - and `a handy table `_ - comparing the ORM usage - -.. attention:: - - The :class:`SQLAlchemyPlugin <.contrib.sqlalchemy_1.plugin.SQLAlchemyPlugin>` supports only - `mapped classes `_. - `Tables `_ are - currently not supported since they are not easy to convert to pydantic models. +* Managed `engine `_ (sync and async) including dependency injection +* Typed configuration objects Basic Use --------- -You can simply pass an instance of :class:`SQLAlchemyPlugin <.contrib.sqlalchemy_1.plugin.SQLAlchemyPlugin>` without -passing config to the Starlite constructor. This will extend support for serialization, deserialization and DTO creation -for SQLAlchemy declarative models: +You can simply pass an instance of :class:`SQLAlchemyInitPlugin <.contrib.sqlalchemy.init_plugin.plugin.SQLAlchemyInitPlugin>` +to the Starlite constructor. This will automatically create a SQLAlchemy engine and session for you, and make them +available to your handlers and dependencies via dependency injection. .. tab-set:: .. tab-item:: Async :sync: async - .. literalinclude:: /examples/plugins/sqlalchemy_plugin/sqlalchemy_async.py + .. literalinclude:: /examples/plugins/sqlalchemy_init_plugin/sqlalchemy_async.py :caption: sqlalchemy_plugin.py :language: python @@ -48,109 +31,14 @@ for SQLAlchemy declarative models: .. tab-item:: Sync :sync: sync - .. literalinclude:: /examples/plugins/sqlalchemy_plugin/sqlalchemy_sync.py + .. literalinclude:: /examples/plugins/sqlalchemy_init_plugin/sqlalchemy_sync.py :caption: sqlalchemy_plugin.py :language: python -.. admonition:: Using imperative mappings - :class: info - - `Imperative mappings `_ - are supported as well, just make sure to use a mapped class instead of the table itself - - .. code-block:: python - - company_table = Table( - "company", - Base.registry.metadata, - Column("id", Integer, primary_key=True), - Column("name", String), - Column("worth", Float), - ) - - - class Company: - pass - - - Base.registry.map_imperatively(Company, company_table) - - -Relationships -------------- - -.. attention:: - - Currently only to-one relationships are supported because of the way the SQLAlchemy plugin handles relationships. - Since it recursively traverses relationships, a cyclic reference will result in an endless loop. To prevent this, - these relationships will be type as :class:`typing.Any` in the pydantic model - Relationships are typed as :class:`typing.Optional` in the pydantic model by default so sending incomplete models - won't cause any issues. - - -Simple relationships -^^^^^^^^^^^^^^^^^^^^ - -Simple relationships can be handled by the plugin automatically: - -.. literalinclude:: /examples/plugins/sqlalchemy_plugin/sqlalchemy_relationships.py - :caption: sqlalchemy_relationships.py - :language: python - - -.. admonition:: Example - :class: tip - - Run the above with ``uvicorn sqlalchemy_relationships:app``, navigate your browser to - `http://127.0.0.0:8000/user/1 `_ - and you will see: - - .. code-block:: json - - { - "id": 1, - "name": "Peter", - "company_id": 1, - "company": { - "id": 1, - "name": "Peter Co.", - "worth": 0 - } - } - - -To-Many relationships and circular references -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -For to-many relationships or those that contain circular references you need to define the pydantic models yourself: - -.. literalinclude:: /examples/plugins/sqlalchemy_plugin/sqlalchemy_relationships_to_many.py - :caption: sqlalchemy_relationships_to_many - :language: python - - -.. admonition:: Example - :class: tip - - Run the above with ``uvicorn sqlalchemy_relationships_to_many:app``, navigate your browser to `http://127.0.0.0:8000/user/1`_ - and you will see: - - .. code-block:: json - - { - "id": 1, - "name": "Peter", - "pets": [ - { - "id": 1, - "name": "Paul" - } - ] - } - - Configuration ------------- -You can configure the Plugin using the :class:`SQLAlchemyConfig <.contrib.sqlalchemy_1.config.SQLAlchemyConfig>` object. +You configure the Plugin using either +:class:`SQLAlchemyAsyncConfig <.contrib.sqlalchemy.init_plugin.config.asyncio.SQLAlchemyAsyncConfig>` or +:class:`SQLAlchemySyncConfig <.contrib.sqlalchemy.init_plugin.config.sync.SQLAlchemySyncConfig>`. diff --git a/docs/usage/responses.rst b/docs/usage/responses.rst index 6158a5c1fd..a31ff33ba8 100644 --- a/docs/usage/responses.rst +++ b/docs/usage/responses.rst @@ -825,7 +825,7 @@ kwargs>` :language: python -See :ref:`SQLAlchemy plugin ` for sqlalchemy integration. +See :ref:`SQLAlchemy plugin ` for sqlalchemy integration. Cursor Pagination +++++++++++++++++ diff --git a/starlite/constants.py b/starlite/constants.py index cc0065bec3..722608653b 100644 --- a/starlite/constants.py +++ b/starlite/constants.py @@ -6,14 +6,17 @@ DEFAULT_ALLOWED_CORS_HEADERS = {"Accept", "Accept-Language", "Content-Language", "Content-Type"} DEFAULT_CHUNK_SIZE = 1024 * 128 # 128KB +HTTP_DISCONNECT = "http.disconnect" HTTP_RESPONSE_BODY = "http.response.body" HTTP_RESPONSE_START = "http.response.start" ONE_MEGABYTE = 1024 * 1024 +OPENAPI_NOT_INITIALIZED = "Starlite has not been instantiated with OpenAPIConfig" REDIRECT_STATUS_CODES = {301, 302, 303, 307, 308} RESERVED_KWARGS = {"state", "headers", "cookies", "request", "socket", "data", "query", "scope", "body"} SCOPE_STATE_DEPENDENCY_CACHE = "dependency_cache" SCOPE_STATE_NAMESPACE = "__starlite__" SCOPE_STATE_RESPONSE_COMPRESSED = "response_compressed" -UNDEFINED_SENTINELS = {Undefined, Signature.empty, Empty, Ellipsis} SKIP_VALIDATION_NAMES = {"request", "socket", "scope", "receive", "send"} -OPENAPI_NOT_INITIALIZED = "Starlite has not been instantiated with OpenAPIConfig" +UNDEFINED_SENTINELS = {Undefined, Signature.empty, Empty, Ellipsis} +WEBSOCKET_CLOSE = "websocket.close" +WEBSOCKET_DISCONNECT = "websocket.disconnect" diff --git a/starlite/contrib/sqlalchemy/init_plugin/__init__.py b/starlite/contrib/sqlalchemy/init_plugin/__init__.py new file mode 100644 index 0000000000..782770ea09 --- /dev/null +++ b/starlite/contrib/sqlalchemy/init_plugin/__init__.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +from .config import ( + AsyncSessionConfig, + EngineConfig, + SQLAlchemyAsyncConfig, + SQLAlchemySyncConfig, + SyncSessionConfig, +) +from .plugin import SQLAlchemyInitPlugin + +__all__ = ( + "AsyncSessionConfig", + "EngineConfig", + "SQLAlchemyAsyncConfig", + "SQLAlchemyInitPlugin", + "SQLAlchemySyncConfig", + "SyncSessionConfig", +) diff --git a/starlite/contrib/sqlalchemy/init_plugin/config/__init__.py b/starlite/contrib/sqlalchemy/init_plugin/config/__init__.py new file mode 100644 index 0000000000..deff28a94c --- /dev/null +++ b/starlite/contrib/sqlalchemy/init_plugin/config/__init__.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +from .asyncio import AsyncSessionConfig, SQLAlchemyAsyncConfig +from .engine import EngineConfig +from .sync import SQLAlchemySyncConfig, SyncSessionConfig + +__all__ = ( + "AsyncSessionConfig", + "EngineConfig", + "SQLAlchemyAsyncConfig", + "SQLAlchemySyncConfig", + "SyncSessionConfig", +) diff --git a/starlite/contrib/sqlalchemy/init_plugin/config/asyncio.py b/starlite/contrib/sqlalchemy/init_plugin/config/asyncio.py new file mode 100644 index 0000000000..89ab40b3f5 --- /dev/null +++ b/starlite/contrib/sqlalchemy/init_plugin/config/asyncio.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, cast + +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine + +from starlite.types import Empty +from starlite.utils import ( + delete_starlite_scope_state, + get_starlite_scope_state, +) + +from .common import SESSION_SCOPE_KEY, SESSION_TERMINUS_ASGI_EVENTS, GenericSessionConfig, GenericSQLAlchemyConfig + +if TYPE_CHECKING: + from typing import Any, Callable + + from sqlalchemy.orm import Session + + from starlite.datastructures.state import State + from starlite.types import BeforeMessageSendHookHandler, EmptyType, Message, Scope + +__all__ = ("SQLAlchemyAsyncConfig", "AsyncSessionConfig") + + +async def default_before_send_handler(message: Message, _: State, scope: Scope) -> None: + """Handle closing and cleaning up sessions before sending. + + Args: + message: ASGI-``Message`` + _: A ``State`` (not used) + scope: An ASGI-``Scope`` + + Returns: + None + """ + session = cast("AsyncSession | None", get_starlite_scope_state(scope, SESSION_SCOPE_KEY)) + if session and message["type"] in SESSION_TERMINUS_ASGI_EVENTS: + await session.close() + delete_starlite_scope_state(scope, SESSION_SCOPE_KEY) + + +@dataclass +class AsyncSessionConfig(GenericSessionConfig[AsyncConnection, AsyncEngine, AsyncSession]): + """SQLAlchemy async session config.""" + + sync_session_class: type[Session] | None | EmptyType = Empty + + +@dataclass +class SQLAlchemyAsyncConfig(GenericSQLAlchemyConfig[AsyncEngine, AsyncSession, async_sessionmaker]): + """Async SQLAlchemy Configuration.""" + + create_engine_callable: Callable[[str], AsyncEngine] = create_async_engine + """Callable that creates an :class:`AsyncEngine ` instance or instance of its + subclass. + """ + session_config: AsyncSessionConfig = field(default_factory=AsyncSessionConfig) + """Configuration options for the ``sessionmaker``. + + The configuration options are documented in the SQLAlchemy documentation. + """ + session_maker_class: type[async_sessionmaker] = async_sessionmaker + """Sessionmaker class to use.""" + before_send_handler: BeforeMessageSendHookHandler = default_before_send_handler + """Handler to call before the ASGI message is sent. + + The handler should handle closing the session stored in the ASGI scope, if its still open, and committing and + uncommitted data. + """ + + @property + def signature_namespace(self) -> dict[str, Any]: + """Return the plugin's signature namespace. + + Returns: + A string keyed dict of names to be added to the namespace for signature forward reference resolution. + """ + return {"AsyncEngine": AsyncEngine, "AsyncSession": AsyncSession} + + async def on_shutdown(self, state: State) -> None: + """Disposes of the SQLAlchemy engine. + + Args: + state: The ``Starlite.state`` instance. + + Returns: + None + """ + engine = cast("AsyncEngine", state.pop(self.engine_app_state_key)) + await engine.dispose() diff --git a/starlite/contrib/sqlalchemy/init_plugin/config/common.py b/starlite/contrib/sqlalchemy/init_plugin/config/common.py new file mode 100644 index 0000000000..a0a25d7eeb --- /dev/null +++ b/starlite/contrib/sqlalchemy/init_plugin/config/common.py @@ -0,0 +1,215 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Callable, Generic, TypeVar, cast + +from starlite.constants import HTTP_DISCONNECT, HTTP_RESPONSE_START, WEBSOCKET_CLOSE, WEBSOCKET_DISCONNECT +from starlite.exceptions import ImproperlyConfiguredException +from starlite.types import Empty +from starlite.utils import get_starlite_scope_state, set_starlite_scope_state +from starlite.utils.dataclass import simple_asdict_filter_empty + +from .engine import EngineConfig + +if TYPE_CHECKING: + from typing import Any + + from sqlalchemy import Connection, Engine + from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, AsyncSession, async_sessionmaker + from sqlalchemy.orm import Mapper, Query, Session, sessionmaker + from sqlalchemy.orm.session import JoinTransactionMode + from sqlalchemy.sql import TableClause + + from starlite.datastructures.state import State + from starlite.types import BeforeMessageSendHookHandler, EmptyType, Scope + +__all__ = ( + "SESSION_SCOPE_KEY", + "SESSION_TERMINUS_ASGI_EVENTS", + "GenericSQLAlchemyConfig", + "GenericSessionConfig", +) + +SESSION_SCOPE_KEY = "_sqlalchemy_db_session" +SESSION_TERMINUS_ASGI_EVENTS = {HTTP_RESPONSE_START, HTTP_DISCONNECT, WEBSOCKET_DISCONNECT, WEBSOCKET_CLOSE} + +ConnectionT = TypeVar("ConnectionT", bound="Connection | AsyncConnection") +EngineT = TypeVar("EngineT", bound="Engine | AsyncEngine") +SessionT = TypeVar("SessionT", bound="Session | AsyncSession") +SessionMakerT = TypeVar("SessionMakerT", bound="sessionmaker | async_sessionmaker") + + +@dataclass +class GenericSessionConfig(Generic[ConnectionT, EngineT, SessionT]): + """SQLAlchemy async session config.""" + + autobegin: bool | EmptyType = Empty + autoflush: bool | EmptyType = Empty + bind: EngineT | ConnectionT | None | EmptyType = Empty + binds: dict[type[Any] | Mapper[Any] | TableClause | str, EngineT | ConnectionT] | None | EmptyType = Empty + class_: type[SessionT] | EmptyType = Empty + enable_baked_queries: bool | EmptyType = Empty + expire_on_commit: bool | EmptyType = Empty + info: dict[str, Any] | None | EmptyType = Empty + join_transaction_mode: JoinTransactionMode | EmptyType = Empty + query_cls: type[Query] | None | EmptyType = Empty + twophase: bool | EmptyType = Empty + + +@dataclass +class GenericSQLAlchemyConfig(Generic[EngineT, SessionT, SessionMakerT]): + """Common SQLAlchemy Configuration.""" + + create_engine_callable: Callable[[str], EngineT] + """Callable that creates an :class:`AsyncEngine ` instance or instance of its + subclass. + """ + session_config: GenericSessionConfig + """Configuration options for the ``sessionmaker``. + + The configuration options are documented in the SQLAlchemy documentation. + """ + session_maker_class: type[sessionmaker] | type[async_sessionmaker] + """Sessionmaker class to use.""" + before_send_handler: BeforeMessageSendHookHandler + """Handler to call before the ASGI message is sent. + + The handler should handle closing the session stored in the ASGI scope, if its still open, and committing and + uncommitted data. + """ + connection_string: str | None = field(default=None) + """Database connection string in one of the formats supported by SQLAlchemy. + + Notes: + - For async connections, the connection string must include the correct async prefix. + e.g. ``'postgresql+asyncpg://...'`` instead of ``'postgresql://'``, and for sync connections its the opposite. + + """ + engine_dependency_key: str = "db_engine" + """Key to use for the dependency injection of database engines.""" + session_dependency_key: str = "db_session" + """Key to use for the dependency injection of database sessions.""" + engine_app_state_key: str = "db_engine" + """Key under which to store the SQLAlchemy engine in the application :class:`State <.datastructures.State>` + instance. + """ + engine_config: EngineConfig = field(default_factory=EngineConfig) + """Configuration for the SQLAlchemy engine. + + The configuration options are documented in the SQLAlchemy documentation. + """ + session_maker_app_state_key: str = "session_maker_class" + """Key under which to store the SQLAlchemy ``sessionmaker`` in the application + :class:`State <.datastructures.State>` instance. + """ + session_maker: Callable[[], SessionT] | None = None + """Callable that returns a session. + + If provided, the plugin will use this rather than instantiate a sessionmaker. + """ + engine_instance: EngineT | None = None + """Optional engine to use. + + If set, the plugin will use the provided instance rather than instantiate an engine. + """ + + def __post_init__(self) -> None: + if self.connection_string is not None and self.engine_instance is not None: + raise ImproperlyConfiguredException("Only one of 'connection_string' or 'engine_instance' can be provided.") + + @property + def engine_config_dict(self) -> dict[str, Any]: + """Return the engine configuration as a dict. + + Returns: + A string keyed dict of config kwargs for the SQLAlchemy ``create_engine`` function. + """ + return simple_asdict_filter_empty(self.engine_config) + + @property + def session_config_dict(self) -> dict[str, Any]: + """Return the session configuration as a dict. + + Returns: + A string keyed dict of config kwargs for the SQLAlchemy ``sessionmaker`` class. + """ + return simple_asdict_filter_empty(self.session_config) + + @property + def signature_namespace(self) -> dict[str, Any]: + """Return the plugin's signature namespace. + + Returns: + A string keyed dict of names to be added to the namespace for signature forward reference resolution. + """ + return {} + + def create_engine(self) -> EngineT: + """Return an engine. If none exists yet, create one. + + Returns: + Getter that returns the engine instance used by the plugin. + """ + if self.engine_instance: + return self.engine_instance + + if self.connection_string is None: + raise ImproperlyConfiguredException("One of 'connection_string' or 'engine_instance' must be provided.") + + engine_config = self.engine_config_dict + try: + return self.create_engine_callable(self.connection_string, **engine_config) + except ValueError: + # likely due to a dialect that doesn't support json type + del engine_config["json_deserializer"] + del engine_config["json_serializer"] + return self.create_engine_callable(self.connection_string, **engine_config) + + def create_session_maker(self) -> Callable[[], SessionT]: + """Get a session maker. If none exists yet, create one. + + Returns: + Session factory used by the plugin. + """ + if self.session_maker: + return self.session_maker + + session_kws = self.session_config_dict + if session_kws.get("bind") is None: + session_kws["bind"] = self.create_engine() + return self.session_maker_class(**session_kws) + + def provide_engine(self, state: State) -> EngineT: + """Create an engine instance. + + Args: + state: The ``Starlite.state`` instance. + + Returns: + An engine instance. + """ + return cast("EngineT", state.get(self.engine_app_state_key)) + + def provide_session(self, state: State, scope: Scope) -> SessionT: + """Create a session instance. + + Args: + state: The ``Starlite.state`` instance. + scope: The current connection's scope. + + Returns: + A session instance. + """ + session = cast("SessionT | None", get_starlite_scope_state(scope, SESSION_SCOPE_KEY)) + if session is None: + session_maker = cast("Callable[[], SessionT]", state[self.session_maker_app_state_key]) + session = session_maker() + set_starlite_scope_state(scope, SESSION_SCOPE_KEY, session) + return session + + def app_state(self) -> dict[str, Any]: + """Key/value pairs to be stored in application state.""" + return { + self.engine_app_state_key: self.create_engine(), + self.session_maker_app_state_key: self.create_session_maker(), + } diff --git a/starlite/contrib/sqlalchemy/init_plugin/config/engine.py b/starlite/contrib/sqlalchemy/init_plugin/config/engine.py new file mode 100644 index 0000000000..71659bc3ef --- /dev/null +++ b/starlite/contrib/sqlalchemy/init_plugin/config/engine.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Callable, Literal + +from starlite.exceptions import MissingDependencyException +from starlite.serialization import decode_json, encode_json +from starlite.types import Empty + +try: + import sqlalchemy # noqa: F401 +except ImportError as e: + raise MissingDependencyException("sqlalchemy is not installed") from e + +if TYPE_CHECKING: + from typing import Any, Mapping + + from sqlalchemy.engine.interfaces import IsolationLevel + from sqlalchemy.pool import Pool + from typing_extensions import TypeAlias + + from starlite.types import EmptyType + +__all__ = ("EngineConfig",) + +_EchoFlagType: TypeAlias = "None | bool | Literal['debug']" +_ParamStyle = Literal["qmark", "numeric", "named", "format", "pyformat", "numeric_dollar"] + + +def serializer(value: Any) -> str: + """Serialize JSON field values. + + Args: + value: Any json serializable value. + + Returns: + JSON string. + """ + return encode_json(value).decode("utf-8") + + +@dataclass +class EngineConfig: + """Configuration for SQLAlchemy's :class`Engine `. + + For details see: https://docs.sqlalchemy.org/en/20/core/engines.html + """ + + connect_args: dict[Any, Any] | EmptyType = Empty + echo: _EchoFlagType | EmptyType = Empty + echo_pool: _EchoFlagType | EmptyType = Empty + enable_from_linting: bool | EmptyType = Empty + execution_options: Mapping[str, Any] | EmptyType = Empty + hide_parameters: bool | EmptyType = Empty + insertmanyvalues_page_size: int | EmptyType = Empty + isolation_level: IsolationLevel | EmptyType = Empty + json_deserializer: Callable[[str], Any] = decode_json + json_serializer: Callable[[Any], str] = serializer + label_length: int | None | EmptyType = Empty + logging_name: str | EmptyType = Empty + max_identifier_length: int | None | EmptyType = Empty + max_overflow: int | EmptyType = Empty + module: Any | None | EmptyType = Empty + paramstyle: _ParamStyle | None | EmptyType = Empty + pool: Pool | None | EmptyType = Empty + poolclass: type[Pool] | None | EmptyType = Empty + pool_logging_name: str | EmptyType = Empty + pool_pre_ping: bool | EmptyType = Empty + pool_size: int | EmptyType = Empty + pool_recycle: int | EmptyType = Empty + pool_reset_on_return: Literal["rollback", "commit"] | EmptyType = Empty + pool_timeout: int | EmptyType = Empty + pool_use_lifo: bool | EmptyType = Empty + plugins: list[str] | EmptyType = Empty + query_cache_size: int | EmptyType = Empty + use_insertmanyvalues: bool | EmptyType = Empty diff --git a/starlite/contrib/sqlalchemy/init_plugin/config/sync.py b/starlite/contrib/sqlalchemy/init_plugin/config/sync.py new file mode 100644 index 0000000000..1652d2b0b0 --- /dev/null +++ b/starlite/contrib/sqlalchemy/init_plugin/config/sync.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, cast + +from sqlalchemy import Connection, Engine, create_engine +from sqlalchemy.orm import Session, sessionmaker + +from starlite.utils import ( + delete_starlite_scope_state, + get_starlite_scope_state, +) + +from .common import SESSION_SCOPE_KEY, SESSION_TERMINUS_ASGI_EVENTS, GenericSessionConfig, GenericSQLAlchemyConfig + +if TYPE_CHECKING: + from typing import Any, Callable + + from starlite.datastructures.state import State + from starlite.types import BeforeMessageSendHookHandler, Message, Scope + +__all__ = ("SQLAlchemySyncConfig", "SyncSessionConfig") + + +async def default_before_send_handler(message: Message, _: State, scope: Scope) -> None: + """Handle closing and cleaning up sessions before sending. + + Args: + message: ASGI-``Message`` + _: A ``State`` (not used) + scope: An ASGI-``Scope`` + + Returns: + None + """ + session = cast("Session | None", get_starlite_scope_state(scope, SESSION_SCOPE_KEY)) + if session and message["type"] in SESSION_TERMINUS_ASGI_EVENTS: + session.close() + delete_starlite_scope_state(scope, SESSION_SCOPE_KEY) + + +class SyncSessionConfig(GenericSessionConfig[Connection, Engine, Session]): + pass + + +@dataclass +class SQLAlchemySyncConfig(GenericSQLAlchemyConfig[Engine, Session, sessionmaker]): + """Sync SQLAlchemy Configuration.""" + + create_engine_callable: Callable[[str], Engine] = create_engine + """Callable that creates an :class:`AsyncEngine ` instance or instance of its + subclass. + """ + session_config: SyncSessionConfig = field(default_factory=SyncSessionConfig) # pyright:ignore + """Configuration options for the ``sessionmaker``. + + The configuration options are documented in the SQLAlchemy documentation. + """ + session_maker_class: type[sessionmaker] = sessionmaker + """Sessionmaker class to use.""" + before_send_handler: BeforeMessageSendHookHandler = default_before_send_handler + """Handler to call before the ASGI message is sent. + + The handler should handle closing the session stored in the ASGI scope, if its still open, and committing and + uncommitted data. + """ + + @property + def signature_namespace(self) -> dict[str, Any]: + """Return the plugin's signature namespace. + + Returns: + A string keyed dict of names to be added to the namespace for signature forward reference resolution. + """ + return {"Engine": Engine, "Session": Session} + + def on_shutdown(self, state: State) -> None: + """Disposes of the SQLAlchemy engine. + + Args: + state: The ``Starlite.state`` instance. + + Returns: + None + """ + engine = cast("Engine", state.pop(self.engine_app_state_key)) + engine.dispose() diff --git a/starlite/contrib/sqlalchemy/init_plugin/plugin.py b/starlite/contrib/sqlalchemy/init_plugin/plugin.py new file mode 100644 index 0000000000..5841f71bd7 --- /dev/null +++ b/starlite/contrib/sqlalchemy/init_plugin/plugin.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from starlite.di import Provide +from starlite.plugins import InitPluginProtocol + +if TYPE_CHECKING: + from starlite.config.app import AppConfig + + from .config import SQLAlchemyAsyncConfig, SQLAlchemySyncConfig + +__all__ = ("SQLAlchemyInitPlugin",) + + +class SQLAlchemyInitPlugin(InitPluginProtocol): + """SQLAlchemy application lifecycle configuration.""" + + __slots__ = ("_config",) + + def __init__(self, config: SQLAlchemyAsyncConfig | SQLAlchemySyncConfig) -> None: + """Initialize ``SQLAlchemyPlugin``. + + Args: + config: configure DB connection and hook handlers and dependencies. + """ + self._config = config + + def on_app_init(self, app_config: AppConfig) -> AppConfig: + """Configure application for use with SQLAlchemy. + + Args: + app_config: The :class:`AppConfig <.config.app.AppConfig>` instance. + """ + app_config.dependencies.update( + { + self._config.engine_dependency_key: Provide(self._config.provide_engine), + self._config.session_dependency_key: Provide(self._config.provide_session), + } + ) + app_config.before_send.append(self._config.before_send_handler) + app_config.on_shutdown.append(self._config.on_shutdown) + app_config.state.update(self._config.app_state()) + app_config.signature_namespace.update(self._config.signature_namespace) + return app_config diff --git a/starlite/utils/__init__.py b/starlite/utils/__init__.py index d3fb57f0a3..55e4056b63 100644 --- a/starlite/utils/__init__.py +++ b/starlite/utils/__init__.py @@ -19,6 +19,7 @@ create_parsed_model_field, ) from .scope import ( + delete_starlite_scope_state, get_serializer_from_scope, get_starlite_scope_state, set_starlite_scope_state, @@ -43,6 +44,7 @@ "convert_dataclass_to_model", "convert_typeddict_to_model", "create_parsed_model_field", + "delete_starlite_scope_state", "deprecated", "find_index", "get_enum_string_value", diff --git a/starlite/utils/dataclass.py b/starlite/utils/dataclass.py index 1ae226bf70..5f9c77ca15 100644 --- a/starlite/utils/dataclass.py +++ b/starlite/utils/dataclass.py @@ -1,12 +1,17 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Iterable, cast - -__all__ = ("extract_dataclass_fields",) +from dataclasses import asdict, fields +from typing import TYPE_CHECKING, cast +from starlite.types import DataclassProtocol, Empty if TYPE_CHECKING: - from starlite.types import DataclassProtocol + from typing import Any, Iterable + +__all__ = ( + "asdict_filter_empty", + "extract_dataclass_fields", +) def extract_dataclass_fields( @@ -28,3 +33,46 @@ def extract_dataclass_fields( if (not exclude_none or getattr(dt, field_name) is not None) and ((include is not None and field_name in include) or include is None) ) + + +def asdict_filter_empty(obj: DataclassProtocol) -> dict[str, Any]: + """Same as stdlib's ``dataclasses.asdict`` with additional filtering for :class:`Empty<.types.Empty>`. + + Args: + obj: A dataclass instance. + + Returns: + ``obj`` converted into a ``dict`` of its fields, with any :class:`Empty<.types.Empty>` values excluded. + """ + return {k: v for k, v in asdict(obj).items() if v is not Empty} + + +def simple_asdict(obj: DataclassProtocol) -> dict[str, Any]: + """Recursively convert a dataclass instance into a ``dict`` of its fields, without using ``copy.deepcopy()``. + + The standard library ``dataclasses.asdict()`` function uses ``copy.deepcopy()`` on any value that is not a + dataclass, dict, list or tuple, which presents a problem when the dataclass holds items that cannot be pickled. + + This function provides an alternative that does not use ``copy.deepcopy()``, and is a much simpler implementation, + only recursing into other dataclasses. + + Args: + obj: A dataclass instance. + + Returns: + ``obj`` converted into a ``dict`` of its fields. + """ + field_values = ((field.name, getattr(obj, field.name)) for field in fields(obj)) + return {k: simple_asdict(v) if isinstance(v, DataclassProtocol) else v for k, v in field_values} + + +def simple_asdict_filter_empty(obj: DataclassProtocol) -> dict[str, Any]: + """Same as asdict_filter_empty but uses ``simple_asdict``. + + Args: + obj: A dataclass instance. + + Returns: + ``obj`` converted into a ``dict`` of its fields, with any :class:`Empty<.types.Empty>` values excluded. + """ + return {k: v for k, v in simple_asdict(obj).items() if v is not Empty} diff --git a/starlite/utils/scope.py b/starlite/utils/scope.py index c6609c42cd..db197bf74d 100644 --- a/starlite/utils/scope.py +++ b/starlite/utils/scope.py @@ -4,7 +4,12 @@ from starlite.constants import SCOPE_STATE_NAMESPACE -__all__ = ("get_serializer_from_scope", "get_starlite_scope_state", "set_starlite_scope_state") +__all__ = ( + "delete_starlite_scope_state", + "get_serializer_from_scope", + "get_starlite_scope_state", + "set_starlite_scope_state", +) if TYPE_CHECKING: @@ -69,3 +74,14 @@ def set_starlite_scope_state(scope: Scope, key: str, value: Any) -> None: value: Value for key. """ scope["state"].setdefault(SCOPE_STATE_NAMESPACE, {})[key] = value + + +def delete_starlite_scope_state(scope: Scope, key: str) -> None: + """Delete an internal value from connection scope state. + + Args: + scope: The connection scope. + key: Key to set under internal namespace in scope state. + value: Value for key. + """ + del scope["state"][SCOPE_STATE_NAMESPACE][key] diff --git a/tests/contrib/sqlalchemy/init_plugin/__init__.py b/tests/contrib/sqlalchemy/init_plugin/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/contrib/sqlalchemy/init_plugin/config/__init__.py b/tests/contrib/sqlalchemy/init_plugin/config/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/contrib/sqlalchemy/init_plugin/config/test_common.py b/tests/contrib/sqlalchemy/init_plugin/config/test_common.py new file mode 100644 index 0000000000..8c6c0cbe49 --- /dev/null +++ b/tests/contrib/sqlalchemy/init_plugin/config/test_common.py @@ -0,0 +1,166 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING +from unittest.mock import MagicMock, patch + +import pytest +from sqlalchemy import create_engine + +from starlite.constants import SCOPE_STATE_NAMESPACE +from starlite.contrib.sqlalchemy.init_plugin.config import SQLAlchemyAsyncConfig, SQLAlchemySyncConfig +from starlite.contrib.sqlalchemy.init_plugin.config.common import SESSION_SCOPE_KEY +from starlite.datastructures import State +from starlite.exceptions import ImproperlyConfiguredException + +if TYPE_CHECKING: + from typing import Any + + from pytest import MonkeyPatch + + from starlite.types import Scope + + +@pytest.fixture(name="config_cls", params=[SQLAlchemySyncConfig, SQLAlchemyAsyncConfig]) +def _config_cls(request: Any) -> type[SQLAlchemySyncConfig | SQLAlchemyAsyncConfig]: + """Return SQLAlchemy config class.""" + return request.param # type:ignore[no-any-return] + + +def test_raise_improperly_configured_exception(config_cls: type[SQLAlchemySyncConfig | SQLAlchemySyncConfig]) -> None: + """Test raise ImproperlyConfiguredException if both engine and connection string are provided.""" + with pytest.raises(ImproperlyConfiguredException): + config_cls(connection_string="sqlite://", engine_instance=create_engine("sqlite://")) + + +def test_engine_config_dict_with_no_provided_config( + config_cls: type[SQLAlchemySyncConfig | SQLAlchemySyncConfig], +) -> None: + """Test engine_config_dict with no provided config.""" + config = config_cls() + assert config.engine_config_dict.keys() == {"json_deserializer", "json_serializer"} + + +def test_session_config_dict_with_no_provided_config( + config_cls: type[SQLAlchemySyncConfig | SQLAlchemySyncConfig], +) -> None: + """Test session_config_dict with no provided config.""" + config = config_cls() + assert config.session_config_dict == {} + + +def test_config_create_engine_if_engine_instance_provided( + config_cls: type[SQLAlchemySyncConfig | SQLAlchemySyncConfig], +) -> None: + """Test create_engine if engine instance provided.""" + engine = create_engine("sqlite://") + config = config_cls(engine_instance=engine) + assert config.create_engine() == engine + + +def test_create_engine_if_no_engine_instance_or_connection_string_provided( + config_cls: type[SQLAlchemySyncConfig | SQLAlchemySyncConfig], +) -> None: + """Test create_engine if no engine instance or connection string provided.""" + config = config_cls() + with pytest.raises(ImproperlyConfiguredException): + config.create_engine() + + +def test_call_create_engine_callable_value_error_handling( + config_cls: type[SQLAlchemySyncConfig | SQLAlchemySyncConfig], monkeypatch: MonkeyPatch +) -> None: + """If the dialect doesn't support JSON types, we get a ValueError. + This should be handled by removing the JSON serializer/deserializer kwargs. + """ + call_count = 0 + + def side_effect(*args: Any, **kwargs: Any) -> None: + nonlocal call_count + call_count += 1 + if call_count == 1: + raise ValueError() + + config = config_cls(connection_string="sqlite://") + create_engine_callable_mock = MagicMock(side_effect=side_effect) + monkeypatch.setattr(config, "create_engine_callable", create_engine_callable_mock) + + config.create_engine() + + assert create_engine_callable_mock.call_count == 2 + first_call, second_call = create_engine_callable_mock.mock_calls + assert first_call.kwargs.keys() == {"json_deserializer", "json_serializer"} + assert second_call.kwargs.keys() == set() + + +def test_create_session_maker_if_session_maker_provided( + config_cls: type[SQLAlchemySyncConfig | SQLAlchemySyncConfig], +) -> None: + """Test create_session_maker if session maker provided to config.""" + session_maker = MagicMock() + config = config_cls(session_maker=session_maker) + assert config.create_session_maker() == session_maker + + +def test_create_session_maker_if_no_session_maker_provided_and_bind_provided( + config_cls: type[SQLAlchemySyncConfig | SQLAlchemySyncConfig], monkeypatch: MonkeyPatch +) -> None: + """Test create_session_maker if no session maker provided to config.""" + config = config_cls() + config.session_config.bind = create_engine("sqlite://") + create_engine_mock = MagicMock() + monkeypatch.setattr(config, "create_engine", create_engine_mock) + assert config.session_maker is None + assert isinstance(config.create_session_maker(), config.session_maker_class) + create_engine_mock.assert_not_called() + + +def test_create_session_maker_if_no_session_maker_or_bind_provided( + config_cls: type[SQLAlchemySyncConfig | SQLAlchemySyncConfig], monkeypatch: MonkeyPatch +) -> None: + """Test create_session_maker if no session maker or bind provided to config.""" + config = config_cls() + create_engine_mock = MagicMock(return_value=create_engine("sqlite://")) + monkeypatch.setattr(config, "create_engine", create_engine_mock) + assert config.session_maker is None + assert isinstance(config.create_session_maker(), config.session_maker_class) + create_engine_mock.assert_called_once() + + +def test_create_session_instance_if_session_already_in_scope_state( + config_cls: type[SQLAlchemySyncConfig | SQLAlchemySyncConfig], +) -> None: + """Test provide_session if session already in scope state.""" + with patch( + "starlite.contrib.sqlalchemy.init_plugin.config.common.get_starlite_scope_state" + ) as get_starlite_scope_state_mock: + session_mock = MagicMock() + get_starlite_scope_state_mock.return_value = session_mock + config = config_cls() + assert config.provide_session(State(), {}) is session_mock # type:ignore[arg-type] + + +def test_create_session_instance_if_session_not_in_scope_state( + config_cls: type[SQLAlchemySyncConfig | SQLAlchemySyncConfig], +) -> None: + """Test provide_session if session not in scope state.""" + with patch( + "starlite.contrib.sqlalchemy.init_plugin.config.common.get_starlite_scope_state" + ) as get_starlite_scope_state_mock: + get_starlite_scope_state_mock.return_value = None + config = config_cls() + state = State() + state[config.session_maker_app_state_key] = MagicMock() + scope: Scope = {"state": {}} # type:ignore[assignment] + assert isinstance(config.provide_session(state, scope), MagicMock) + assert SESSION_SCOPE_KEY in scope["state"][SCOPE_STATE_NAMESPACE] + + +def test_app_state(config_cls: type[SQLAlchemySyncConfig | SQLAlchemySyncConfig], monkeypatch: MonkeyPatch) -> None: + """Test app_state.""" + config = config_cls(connection_string="sqlite://") + with patch.object(config, "create_session_maker") as create_session_maker_mock, patch.object( + config, "create_engine" + ) as create_engine_mock: + assert config.app_state().keys() == {config.engine_app_state_key, config.session_maker_app_state_key} + create_session_maker_mock.assert_called_once() + create_engine_mock.assert_called_once() From 243019f90abd98592e308b241abbd567ab48c275 Mon Sep 17 00:00:00 2001 From: Peter Schutt Date: Mon, 27 Mar 2023 12:51:48 +1000 Subject: [PATCH 08/16] Fix example docstrings. --- .../plugins/sqlalchemy_init_plugin/sqlalchemy_async.py | 3 +-- .../examples/plugins/sqlalchemy_init_plugin/sqlalchemy_sync.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/docs/examples/plugins/sqlalchemy_init_plugin/sqlalchemy_async.py b/docs/examples/plugins/sqlalchemy_init_plugin/sqlalchemy_async.py index a9f03a7fb0..615c2cbdf0 100644 --- a/docs/examples/plugins/sqlalchemy_init_plugin/sqlalchemy_async.py +++ b/docs/examples/plugins/sqlalchemy_init_plugin/sqlalchemy_async.py @@ -13,8 +13,7 @@ @get(path="/sqlalchemy-app") async def async_sqlalchemy_init(db_session: AsyncSession, db_engine: AsyncEngine) -> str: - """Create a new company and return it.""" - + """Interact with SQLAlchemy engine and session.""" one = (await db_session.execute(text("SELECT 1"))).scalar_one() async with db_engine.begin() as conn: diff --git a/docs/examples/plugins/sqlalchemy_init_plugin/sqlalchemy_sync.py b/docs/examples/plugins/sqlalchemy_init_plugin/sqlalchemy_sync.py index e0572b1440..983ba1b92b 100644 --- a/docs/examples/plugins/sqlalchemy_init_plugin/sqlalchemy_sync.py +++ b/docs/examples/plugins/sqlalchemy_init_plugin/sqlalchemy_sync.py @@ -14,7 +14,7 @@ @get(path="/sqlalchemy-app") def async_sqlalchemy_init(db_session: Session, db_engine: Engine) -> str: - """Create a new company and return it.""" + """Interact with SQLAlchemy engine and session.""" one = db_session.execute(text("SELECT 1")).scalar_one() with db_engine.connect() as conn: From 9605b911116c8f87e3aec060a1bb4c4662984ef6 Mon Sep 17 00:00:00 2001 From: Peter Schutt Date: Mon, 27 Mar 2023 13:20:44 +1000 Subject: [PATCH 09/16] Tests for dataclass utils. --- starlite/utils/dataclass.py | 9 ++++++++- tests/utils/test_dataclass.py | 38 +++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) create mode 100644 tests/utils/test_dataclass.py diff --git a/starlite/utils/dataclass.py b/starlite/utils/dataclass.py index 5f9c77ca15..05a15ccfeb 100644 --- a/starlite/utils/dataclass.py +++ b/starlite/utils/dataclass.py @@ -11,6 +11,8 @@ __all__ = ( "asdict_filter_empty", "extract_dataclass_fields", + "simple_asdict", + "simple_asdict_filter_empty", ) @@ -75,4 +77,9 @@ def simple_asdict_filter_empty(obj: DataclassProtocol) -> dict[str, Any]: Returns: ``obj`` converted into a ``dict`` of its fields, with any :class:`Empty<.types.Empty>` values excluded. """ - return {k: v for k, v in simple_asdict(obj).items() if v is not Empty} + field_values = ((field.name, getattr(obj, field.name)) for field in fields(obj)) + return { + k: simple_asdict_filter_empty(v) if isinstance(v, DataclassProtocol) else v + for k, v in field_values + if v is not Empty + } diff --git a/tests/utils/test_dataclass.py b/tests/utils/test_dataclass.py new file mode 100644 index 0000000000..c4b5af588e --- /dev/null +++ b/tests/utils/test_dataclass.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +from starlite.types import Empty +from starlite.utils.dataclass import asdict_filter_empty, simple_asdict, simple_asdict_filter_empty + +if TYPE_CHECKING: + from starlite.types import EmptyType + + +@dataclass +class Foo: + bar: str = "baz" + baz: int | EmptyType = Empty + qux: list[str] = field(default_factory=lambda: ["quux", "quuz"]) + + +@dataclass +class Bar: + foo: Foo = field(default_factory=Foo) + quux: list[Foo] = field(default_factory=lambda: [Foo(), Foo()]) + + +def test_asdict_filter_empty() -> None: + foo = Foo() + assert asdict_filter_empty(foo) == {"bar": "baz", "qux": ["quux", "quuz"]} + + +def test_simple_asdict() -> None: + bar = Bar() + assert simple_asdict(bar) == {"foo": {"bar": "baz", "baz": Empty, "qux": ["quux", "quuz"]}, "quux": [Foo(), Foo()]} + + +def test_simple_asdict_filter_empty() -> None: + bar = Bar() + assert simple_asdict_filter_empty(bar) == {"foo": {"bar": "baz", "qux": ["quux", "quuz"]}, "quux": [Foo(), Foo()]} From 2392bfc150d24c4163daa79a2958e1615a86c7a9 Mon Sep 17 00:00:00 2001 From: Peter Schutt Date: Mon, 27 Mar 2023 13:43:11 +1000 Subject: [PATCH 10/16] Tests for default before send handlers. --- .../init_plugin/config/test_asyncio.py | 33 +++++++++++++++++++ .../init_plugin/config/test_sync.py | 33 +++++++++++++++++++ 2 files changed, 66 insertions(+) create mode 100644 tests/contrib/sqlalchemy/init_plugin/config/test_asyncio.py create mode 100644 tests/contrib/sqlalchemy/init_plugin/config/test_sync.py diff --git a/tests/contrib/sqlalchemy/init_plugin/config/test_asyncio.py b/tests/contrib/sqlalchemy/init_plugin/config/test_asyncio.py new file mode 100644 index 0000000000..71ae5b6620 --- /dev/null +++ b/tests/contrib/sqlalchemy/init_plugin/config/test_asyncio.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from starlite import get +from starlite.contrib.sqlalchemy.init_plugin import SQLAlchemyAsyncConfig, SQLAlchemyInitPlugin +from starlite.testing import create_test_client + +if TYPE_CHECKING: + from typing import Any + + from sqlalchemy.ext.asyncio import AsyncSession + + from starlite.types import Scope + + +def test_default_before_send_handler() -> None: + """Test default_before_send_handler.""" + + captured_scope_state: dict[str, Any] | None = None + config = SQLAlchemyAsyncConfig(connection_string="sqlite+aiosqlite://") + plugin = SQLAlchemyInitPlugin(config=config) + + @get() + def test_handler(db_session: AsyncSession, scope: Scope) -> None: + nonlocal captured_scope_state + captured_scope_state = scope["state"] + assert db_session is captured_scope_state[config.session_dependency_key] + + with create_test_client(route_handlers=[test_handler], plugins=[plugin]) as client: + client.get("/") + assert captured_scope_state is not None + assert config.session_dependency_key not in captured_scope_state # pyright: ignore diff --git a/tests/contrib/sqlalchemy/init_plugin/config/test_sync.py b/tests/contrib/sqlalchemy/init_plugin/config/test_sync.py new file mode 100644 index 0000000000..8c8835a438 --- /dev/null +++ b/tests/contrib/sqlalchemy/init_plugin/config/test_sync.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from starlite import get +from starlite.contrib.sqlalchemy.init_plugin import SQLAlchemyInitPlugin, SQLAlchemySyncConfig +from starlite.testing import create_test_client + +if TYPE_CHECKING: + from typing import Any + + from sqlalchemy.orm import Session + + from starlite.types import Scope + + +def test_default_before_send_handler() -> None: + """Test default_before_send_handler.""" + + captured_scope_state: dict[str, Any] | None = None + config = SQLAlchemySyncConfig(connection_string="sqlite+aiosqlite://") + plugin = SQLAlchemyInitPlugin(config=config) + + @get() + def test_handler(db_session: Session, scope: Scope) -> None: + nonlocal captured_scope_state + captured_scope_state = scope["state"] + assert db_session is captured_scope_state[config.session_dependency_key] + + with create_test_client(route_handlers=[test_handler], plugins=[plugin]) as client: + client.get("/") + assert captured_scope_state is not None + assert config.session_dependency_key not in captured_scope_state From 3f8bc315203de6cdb59627e52a1191c0b23d4e4b Mon Sep 17 00:00:00 2001 From: Peter Schutt Date: Mon, 27 Mar 2023 13:47:03 +1000 Subject: [PATCH 11/16] Test for engine json serializer. --- .../contrib/sqlalchemy/init_plugin/config/test_engine.py | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 tests/contrib/sqlalchemy/init_plugin/config/test_engine.py diff --git a/tests/contrib/sqlalchemy/init_plugin/config/test_engine.py b/tests/contrib/sqlalchemy/init_plugin/config/test_engine.py new file mode 100644 index 0000000000..dc72c40beb --- /dev/null +++ b/tests/contrib/sqlalchemy/init_plugin/config/test_engine.py @@ -0,0 +1,8 @@ +from __future__ import annotations + +from starlite.contrib.sqlalchemy.init_plugin.config.engine import serializer + + +def test_serializer_returns_string() -> None: + """Test that serializer returns a string.""" + assert isinstance(serializer({"a": "b"}), str) From 7fe22e70e5a3b169c511eb3bcad0738d07b090e4 Mon Sep 17 00:00:00 2001 From: Peter Schutt Date: Mon, 27 Mar 2023 13:51:48 +1000 Subject: [PATCH 12/16] Update starlite/contrib/sqlalchemy/init_plugin/config/common.py --- starlite/contrib/sqlalchemy/init_plugin/config/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/starlite/contrib/sqlalchemy/init_plugin/config/common.py b/starlite/contrib/sqlalchemy/init_plugin/config/common.py index a0a25d7eeb..694f6406eb 100644 --- a/starlite/contrib/sqlalchemy/init_plugin/config/common.py +++ b/starlite/contrib/sqlalchemy/init_plugin/config/common.py @@ -142,7 +142,7 @@ def signature_namespace(self) -> dict[str, Any]: Returns: A string keyed dict of names to be added to the namespace for signature forward reference resolution. """ - return {} + return {} # pragma: no cover def create_engine(self) -> EngineT: """Return an engine. If none exists yet, create one. From d2e4270a2bf69a04d0a3aa56717d5d99a706b2ab Mon Sep 17 00:00:00 2001 From: Peter Schutt Date: Tue, 28 Mar 2023 18:56:13 +1000 Subject: [PATCH 13/16] Use refactored dataclass utils. --- starlite/contrib/sqlalchemy/init_plugin/config/common.py | 6 +++--- starlite/utils/dataclass.py | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/starlite/contrib/sqlalchemy/init_plugin/config/common.py b/starlite/contrib/sqlalchemy/init_plugin/config/common.py index 694f6406eb..f8c657d2b3 100644 --- a/starlite/contrib/sqlalchemy/init_plugin/config/common.py +++ b/starlite/contrib/sqlalchemy/init_plugin/config/common.py @@ -7,7 +7,7 @@ from starlite.exceptions import ImproperlyConfiguredException from starlite.types import Empty from starlite.utils import get_starlite_scope_state, set_starlite_scope_state -from starlite.utils.dataclass import simple_asdict_filter_empty +from starlite.utils.dataclass import simple_asdict from .engine import EngineConfig @@ -124,7 +124,7 @@ def engine_config_dict(self) -> dict[str, Any]: Returns: A string keyed dict of config kwargs for the SQLAlchemy ``create_engine`` function. """ - return simple_asdict_filter_empty(self.engine_config) + return simple_asdict(self.engine_config, exclude_empty=True) @property def session_config_dict(self) -> dict[str, Any]: @@ -133,7 +133,7 @@ def session_config_dict(self) -> dict[str, Any]: Returns: A string keyed dict of config kwargs for the SQLAlchemy ``sessionmaker`` class. """ - return simple_asdict_filter_empty(self.session_config) + return simple_asdict(self.session_config, exclude_empty=True) @property def signature_namespace(self) -> dict[str, Any]: diff --git a/starlite/utils/dataclass.py b/starlite/utils/dataclass.py index 61dbb50794..e8eb599c4f 100644 --- a/starlite/utils/dataclass.py +++ b/starlite/utils/dataclass.py @@ -12,6 +12,7 @@ __all__ = ( "extract_dataclass_fields", "extract_dataclass_items", + "simple_asdict", ) From 3b69070209d5440ece72aab3b309251364404659 Mon Sep 17 00:00:00 2001 From: Peter Schutt Date: Wed, 29 Mar 2023 09:58:07 +1000 Subject: [PATCH 14/16] Address review items. --- docs/conf.py | 2 ++ docs/reference/contrib/sqlalchemy/config.rst | 5 +++++ .../contrib/sqlalchemy/config/asyncio.rst | 5 ----- .../contrib/sqlalchemy/config/common.rst | 5 ----- .../reference/contrib/sqlalchemy/config/index.rst | 9 --------- docs/reference/contrib/sqlalchemy/config/sync.rst | 5 ----- docs/reference/contrib/sqlalchemy/index.rst | 2 +- .../sqlalchemy/init_plugin/config/asyncio.py | 7 ++----- .../sqlalchemy/init_plugin/config/common.py | 15 ++++++++------- .../contrib/sqlalchemy/init_plugin/config/sync.py | 5 +---- starlite/contrib/sqlalchemy/init_plugin/plugin.py | 2 +- .../sqlalchemy/init_plugin/config/test_common.py | 5 ++++- 12 files changed, 24 insertions(+), 43 deletions(-) create mode 100644 docs/reference/contrib/sqlalchemy/config.rst delete mode 100644 docs/reference/contrib/sqlalchemy/config/asyncio.rst delete mode 100644 docs/reference/contrib/sqlalchemy/config/common.rst delete mode 100644 docs/reference/contrib/sqlalchemy/config/index.rst delete mode 100644 docs/reference/contrib/sqlalchemy/config/sync.rst diff --git a/docs/conf.py b/docs/conf.py index 99cab96418..869e117e9e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -82,6 +82,8 @@ # intentionally undocumented ("py:class", "NoneType"), ("py:class", "starlite._signature.models.SignatureField"), + ("py:class", "starlite.contrib.sqlalchemy.init_plugin.config.common.GenericSessionConfig"), + ("py:class", "starlite.contrib.sqlalchemy.init_plugin.config.common.GenericSQLAlchemyConfig"), ] nitpick_ignore_regex = [ (r"py:.*", r"starlite\.types.*"), diff --git a/docs/reference/contrib/sqlalchemy/config.rst b/docs/reference/contrib/sqlalchemy/config.rst new file mode 100644 index 0000000000..64a3568eb2 --- /dev/null +++ b/docs/reference/contrib/sqlalchemy/config.rst @@ -0,0 +1,5 @@ +config +====== + +.. automodule:: starlite.contrib.sqlalchemy.init_plugin.config + :members: diff --git a/docs/reference/contrib/sqlalchemy/config/asyncio.rst b/docs/reference/contrib/sqlalchemy/config/asyncio.rst deleted file mode 100644 index e138610870..0000000000 --- a/docs/reference/contrib/sqlalchemy/config/asyncio.rst +++ /dev/null @@ -1,5 +0,0 @@ -asyncio -======= - -.. automodule:: starlite.contrib.sqlalchemy.init_plugin.config.asyncio - :members: diff --git a/docs/reference/contrib/sqlalchemy/config/common.rst b/docs/reference/contrib/sqlalchemy/config/common.rst deleted file mode 100644 index e8e2467cf7..0000000000 --- a/docs/reference/contrib/sqlalchemy/config/common.rst +++ /dev/null @@ -1,5 +0,0 @@ -asyncio -======= - -.. automodule:: starlite.contrib.sqlalchemy.init_plugin.config.common - :members: diff --git a/docs/reference/contrib/sqlalchemy/config/index.rst b/docs/reference/contrib/sqlalchemy/config/index.rst deleted file mode 100644 index 1919a60794..0000000000 --- a/docs/reference/contrib/sqlalchemy/config/index.rst +++ /dev/null @@ -1,9 +0,0 @@ -config -====== - -.. toctree:: - :titlesonly: - - asyncio - common - sync diff --git a/docs/reference/contrib/sqlalchemy/config/sync.rst b/docs/reference/contrib/sqlalchemy/config/sync.rst deleted file mode 100644 index ddf769ccf1..0000000000 --- a/docs/reference/contrib/sqlalchemy/config/sync.rst +++ /dev/null @@ -1,5 +0,0 @@ -sync -==== - -.. automodule:: starlite.contrib.sqlalchemy.init_plugin.config.sync - :members: diff --git a/docs/reference/contrib/sqlalchemy/index.rst b/docs/reference/contrib/sqlalchemy/index.rst index 3ee4e8d69b..698f6bdda2 100644 --- a/docs/reference/contrib/sqlalchemy/index.rst +++ b/docs/reference/contrib/sqlalchemy/index.rst @@ -4,5 +4,5 @@ sqlalchemy .. toctree:: :titlesonly: - config/index + config plugin diff --git a/starlite/contrib/sqlalchemy/init_plugin/config/asyncio.py b/starlite/contrib/sqlalchemy/init_plugin/config/asyncio.py index 89ab40b3f5..ff09b6a0dc 100644 --- a/starlite/contrib/sqlalchemy/init_plugin/config/asyncio.py +++ b/starlite/contrib/sqlalchemy/init_plugin/config/asyncio.py @@ -57,16 +57,13 @@ class SQLAlchemyAsyncConfig(GenericSQLAlchemyConfig[AsyncEngine, AsyncSession, a subclass. """ session_config: AsyncSessionConfig = field(default_factory=AsyncSessionConfig) - """Configuration options for the ``sessionmaker``. - - The configuration options are documented in the SQLAlchemy documentation. - """ + """Configuration options for the :class:`async_sessionmaker`.""" session_maker_class: type[async_sessionmaker] = async_sessionmaker """Sessionmaker class to use.""" before_send_handler: BeforeMessageSendHookHandler = default_before_send_handler """Handler to call before the ASGI message is sent. - The handler should handle closing the session stored in the ASGI scope, if its still open, and committing and + The handler should handle closing the session stored in the ASGI scope, if it's still open, and committing and uncommitted data. """ diff --git a/starlite/contrib/sqlalchemy/init_plugin/config/common.py b/starlite/contrib/sqlalchemy/init_plugin/config/common.py index f8c657d2b3..b15f990c25 100644 --- a/starlite/contrib/sqlalchemy/init_plugin/config/common.py +++ b/starlite/contrib/sqlalchemy/init_plugin/config/common.py @@ -65,9 +65,8 @@ class GenericSQLAlchemyConfig(Generic[EngineT, SessionT, SessionMakerT]): subclass. """ session_config: GenericSessionConfig - """Configuration options for the ``sessionmaker``. - - The configuration options are documented in the SQLAlchemy documentation. + """Configuration options for either the :class:`async_sessionmaker` + or :class:`sessionmaker`. """ session_maker_class: type[sessionmaker] | type[async_sessionmaker] """Sessionmaker class to use.""" @@ -99,7 +98,7 @@ class GenericSQLAlchemyConfig(Generic[EngineT, SessionT, SessionMakerT]): The configuration options are documented in the SQLAlchemy documentation. """ session_maker_app_state_key: str = "session_maker_class" - """Key under which to store the SQLAlchemy ``sessionmaker`` in the application + """Key under which to store the SQLAlchemy :class:`sessionmaker` in the application :class:`State <.datastructures.State>` instance. """ session_maker: Callable[[], SessionT] | None = None @@ -122,7 +121,8 @@ def engine_config_dict(self) -> dict[str, Any]: """Return the engine configuration as a dict. Returns: - A string keyed dict of config kwargs for the SQLAlchemy ``create_engine`` function. + A string keyed dict of config kwargs for the SQLAlchemy :class:`create_engine` + function. """ return simple_asdict(self.engine_config, exclude_empty=True) @@ -131,7 +131,8 @@ def session_config_dict(self) -> dict[str, Any]: """Return the session configuration as a dict. Returns: - A string keyed dict of config kwargs for the SQLAlchemy ``sessionmaker`` class. + A string keyed dict of config kwargs for the SQLAlchemy :class:`sessionmaker` + class. """ return simple_asdict(self.session_config, exclude_empty=True) @@ -207,7 +208,7 @@ def provide_session(self, state: State, scope: Scope) -> SessionT: set_starlite_scope_state(scope, SESSION_SCOPE_KEY, session) return session - def app_state(self) -> dict[str, Any]: + def create_app_state_items(self) -> dict[str, Any]: """Key/value pairs to be stored in application state.""" return { self.engine_app_state_key: self.create_engine(), diff --git a/starlite/contrib/sqlalchemy/init_plugin/config/sync.py b/starlite/contrib/sqlalchemy/init_plugin/config/sync.py index 1652d2b0b0..a0dbe93dcc 100644 --- a/starlite/contrib/sqlalchemy/init_plugin/config/sync.py +++ b/starlite/contrib/sqlalchemy/init_plugin/config/sync.py @@ -52,10 +52,7 @@ class SQLAlchemySyncConfig(GenericSQLAlchemyConfig[Engine, Session, sessionmaker subclass. """ session_config: SyncSessionConfig = field(default_factory=SyncSessionConfig) # pyright:ignore - """Configuration options for the ``sessionmaker``. - - The configuration options are documented in the SQLAlchemy documentation. - """ + """Configuration options for the :class:`sessionmaker`.""" session_maker_class: type[sessionmaker] = sessionmaker """Sessionmaker class to use.""" before_send_handler: BeforeMessageSendHookHandler = default_before_send_handler diff --git a/starlite/contrib/sqlalchemy/init_plugin/plugin.py b/starlite/contrib/sqlalchemy/init_plugin/plugin.py index 5841f71bd7..fb620486b2 100644 --- a/starlite/contrib/sqlalchemy/init_plugin/plugin.py +++ b/starlite/contrib/sqlalchemy/init_plugin/plugin.py @@ -40,6 +40,6 @@ def on_app_init(self, app_config: AppConfig) -> AppConfig: ) app_config.before_send.append(self._config.before_send_handler) app_config.on_shutdown.append(self._config.on_shutdown) - app_config.state.update(self._config.app_state()) + app_config.state.update(self._config.create_app_state_items()) app_config.signature_namespace.update(self._config.signature_namespace) return app_config diff --git a/tests/contrib/sqlalchemy/init_plugin/config/test_common.py b/tests/contrib/sqlalchemy/init_plugin/config/test_common.py index 8c6c0cbe49..efe97dd0c2 100644 --- a/tests/contrib/sqlalchemy/init_plugin/config/test_common.py +++ b/tests/contrib/sqlalchemy/init_plugin/config/test_common.py @@ -161,6 +161,9 @@ def test_app_state(config_cls: type[SQLAlchemySyncConfig | SQLAlchemySyncConfig] with patch.object(config, "create_session_maker") as create_session_maker_mock, patch.object( config, "create_engine" ) as create_engine_mock: - assert config.app_state().keys() == {config.engine_app_state_key, config.session_maker_app_state_key} + assert config.create_app_state_items().keys() == { + config.engine_app_state_key, + config.session_maker_app_state_key, + } create_session_maker_mock.assert_called_once() create_engine_mock.assert_called_once() From af45e50007c4185f19bfde6c11e123ee623923e5 Mon Sep 17 00:00:00 2001 From: Peter Schutt Date: Wed, 29 Mar 2023 19:21:04 +1000 Subject: [PATCH 15/16] Makes `config.common` module private, expands docs. --- docs/conf.py | 4 +-- .../config/{common.py => _common.py} | 27 ++++++++++++++++++- .../sqlalchemy/init_plugin/config/asyncio.py | 2 +- .../sqlalchemy/init_plugin/config/sync.py | 2 +- .../init_plugin/config/test_common.py | 6 ++--- 5 files changed, 33 insertions(+), 8 deletions(-) rename starlite/contrib/sqlalchemy/init_plugin/config/{common.py => _common.py} (77%) diff --git a/docs/conf.py b/docs/conf.py index 869e117e9e..46e3315aac 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -82,8 +82,8 @@ # intentionally undocumented ("py:class", "NoneType"), ("py:class", "starlite._signature.models.SignatureField"), - ("py:class", "starlite.contrib.sqlalchemy.init_plugin.config.common.GenericSessionConfig"), - ("py:class", "starlite.contrib.sqlalchemy.init_plugin.config.common.GenericSQLAlchemyConfig"), + ("py:class", "starlite.contrib.sqlalchemy.init_plugin.config._common.GenericSessionConfig"), + ("py:class", "starlite.contrib.sqlalchemy.init_plugin.config._common.GenericSQLAlchemyConfig"), ] nitpick_ignore_regex = [ (r"py:.*", r"starlite\.types.*"), diff --git a/starlite/contrib/sqlalchemy/init_plugin/config/common.py b/starlite/contrib/sqlalchemy/init_plugin/config/_common.py similarity index 77% rename from starlite/contrib/sqlalchemy/init_plugin/config/common.py rename to starlite/contrib/sqlalchemy/init_plugin/config/_common.py index b15f990c25..46ec079ccb 100644 --- a/starlite/contrib/sqlalchemy/init_plugin/config/common.py +++ b/starlite/contrib/sqlalchemy/init_plugin/config/_common.py @@ -44,16 +44,41 @@ class GenericSessionConfig(Generic[ConnectionT, EngineT, SessionT]): """SQLAlchemy async session config.""" autobegin: bool | EmptyType = Empty + """Automatically start transactions when database access is requested by an operation.""" autoflush: bool | EmptyType = Empty + """When ``True``, all query operations will issue a flush call to this :class:`Session ` + before proceeding""" bind: EngineT | ConnectionT | None | EmptyType = Empty + """The :class:`Engine ` or :class:`Connection ` that new + :class:`Session ` objects will be bound to.""" binds: dict[type[Any] | Mapper[Any] | TableClause | str, EngineT | ConnectionT] | None | EmptyType = Empty + """A dictionary which may specify any number of :class:`Engine ` or :class:`Connection + ` objects as the source of connectivity for SQL operations on a per-entity basis. The + keys of the dictionary consist of any series of mapped classes, arbitrary Python classes that are bases for mapped + classes, :class:`Table ` objects and :class:`Mapper ` objects. The + values of the dictionary are then instances of :class:`Engine ` or less commonly + :class:`Connection` objects.""" class_: type[SessionT] | EmptyType = Empty - enable_baked_queries: bool | EmptyType = Empty + """Class to use in order to create new :class:`Session ` objects.""" expire_on_commit: bool | EmptyType = Empty + """If ``True``, all instances will be expired after each commit.""" info: dict[str, Any] | None | EmptyType = Empty + """Optional dictionary of information that will be available via the + :attr:`Session.info `""" join_transaction_mode: JoinTransactionMode | EmptyType = Empty + """Describes the transactional behavior to take when a given bind is a Connection that has already begun a + transaction outside the scope of this Session; in other words the + :attr:`Connection.in_transaction()` method returns True.""" query_cls: type[Query] | None | EmptyType = Empty + """Class which should be used to create new Query objects, as returned by the + :attr:`Session.query()` method.""" twophase: bool | EmptyType = Empty + """When ``True``, all transactions will be started as a “two phase” transaction, i.e. using the “two phase” + semantics of the database in use along with an XID. During a :attr:`commit()`, after + :attr:`flush()` has been issued for all attached databases, the + :attr:`TwoPhaseTransaction.prepare()` method on each database`s + :class:`TwoPhaseTransaction` will be called. This allows each database to + roll back the entire transaction, before each transaction is committed.""" @dataclass diff --git a/starlite/contrib/sqlalchemy/init_plugin/config/asyncio.py b/starlite/contrib/sqlalchemy/init_plugin/config/asyncio.py index ff09b6a0dc..ab86f67282 100644 --- a/starlite/contrib/sqlalchemy/init_plugin/config/asyncio.py +++ b/starlite/contrib/sqlalchemy/init_plugin/config/asyncio.py @@ -11,7 +11,7 @@ get_starlite_scope_state, ) -from .common import SESSION_SCOPE_KEY, SESSION_TERMINUS_ASGI_EVENTS, GenericSessionConfig, GenericSQLAlchemyConfig +from ._common import SESSION_SCOPE_KEY, SESSION_TERMINUS_ASGI_EVENTS, GenericSessionConfig, GenericSQLAlchemyConfig if TYPE_CHECKING: from typing import Any, Callable diff --git a/starlite/contrib/sqlalchemy/init_plugin/config/sync.py b/starlite/contrib/sqlalchemy/init_plugin/config/sync.py index a0dbe93dcc..1450c298d0 100644 --- a/starlite/contrib/sqlalchemy/init_plugin/config/sync.py +++ b/starlite/contrib/sqlalchemy/init_plugin/config/sync.py @@ -11,7 +11,7 @@ get_starlite_scope_state, ) -from .common import SESSION_SCOPE_KEY, SESSION_TERMINUS_ASGI_EVENTS, GenericSessionConfig, GenericSQLAlchemyConfig +from ._common import SESSION_SCOPE_KEY, SESSION_TERMINUS_ASGI_EVENTS, GenericSessionConfig, GenericSQLAlchemyConfig if TYPE_CHECKING: from typing import Any, Callable diff --git a/tests/contrib/sqlalchemy/init_plugin/config/test_common.py b/tests/contrib/sqlalchemy/init_plugin/config/test_common.py index efe97dd0c2..f2934ad7f0 100644 --- a/tests/contrib/sqlalchemy/init_plugin/config/test_common.py +++ b/tests/contrib/sqlalchemy/init_plugin/config/test_common.py @@ -8,7 +8,7 @@ from starlite.constants import SCOPE_STATE_NAMESPACE from starlite.contrib.sqlalchemy.init_plugin.config import SQLAlchemyAsyncConfig, SQLAlchemySyncConfig -from starlite.contrib.sqlalchemy.init_plugin.config.common import SESSION_SCOPE_KEY +from starlite.contrib.sqlalchemy.init_plugin.config._common import SESSION_SCOPE_KEY from starlite.datastructures import State from starlite.exceptions import ImproperlyConfiguredException @@ -131,7 +131,7 @@ def test_create_session_instance_if_session_already_in_scope_state( ) -> None: """Test provide_session if session already in scope state.""" with patch( - "starlite.contrib.sqlalchemy.init_plugin.config.common.get_starlite_scope_state" + "starlite.contrib.sqlalchemy.init_plugin.config._common.get_starlite_scope_state" ) as get_starlite_scope_state_mock: session_mock = MagicMock() get_starlite_scope_state_mock.return_value = session_mock @@ -144,7 +144,7 @@ def test_create_session_instance_if_session_not_in_scope_state( ) -> None: """Test provide_session if session not in scope state.""" with patch( - "starlite.contrib.sqlalchemy.init_plugin.config.common.get_starlite_scope_state" + "starlite.contrib.sqlalchemy.init_plugin.config._common.get_starlite_scope_state" ) as get_starlite_scope_state_mock: get_starlite_scope_state_mock.return_value = None config = config_cls() From 29b087ed1f8304be0ec24f96a76326e4a4e66ef8 Mon Sep 17 00:00:00 2001 From: Peter Schutt Date: Thu, 30 Mar 2023 15:32:41 +1000 Subject: [PATCH 16/16] Improve docs. --- docs/conf.py | 4 +- .../sqlalchemy/init_plugin/config/__init__.py | 3 + .../sqlalchemy/init_plugin/config/asyncio.py | 7 +- .../config/{_common.py => common.py} | 24 ++--- .../sqlalchemy/init_plugin/config/engine.py | 95 ++++++++++++++++++- .../sqlalchemy/init_plugin/config/sync.py | 2 +- .../init_plugin/config/test_common.py | 6 +- 7 files changed, 120 insertions(+), 21 deletions(-) rename starlite/contrib/sqlalchemy/init_plugin/config/{_common.py => common.py} (91%) diff --git a/docs/conf.py b/docs/conf.py index 46e3315aac..b92bbeb688 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -82,8 +82,6 @@ # intentionally undocumented ("py:class", "NoneType"), ("py:class", "starlite._signature.models.SignatureField"), - ("py:class", "starlite.contrib.sqlalchemy.init_plugin.config._common.GenericSessionConfig"), - ("py:class", "starlite.contrib.sqlalchemy.init_plugin.config._common.GenericSQLAlchemyConfig"), ] nitpick_ignore_regex = [ (r"py:.*", r"starlite\.types.*"), @@ -115,7 +113,7 @@ "starlite.contrib.sqlalchemy_1.plugin.SQLAlchemyPlugin.handle_string_type": {"BINARY", "VARBINARY", "LargeBinary"}, "starlite.contrib.sqlalchemy_1.plugin.SQLAlchemyPlugin.is_plugin_supported_type": {"DeclarativeMeta"}, re.compile(r"starlite\.plugins.*"): re.compile(".*(ModelT|DataContainerT)"), - re.compile(r"starlite\.contrib\.sqlalchemy\.init_plugin\.config\.common.*"): re.compile( + re.compile(r"starlite\.contrib\.sqlalchemy\.init_plugin\.config.*"): re.compile( ".*(ConnectionT|EngineT|SessionT|SessionMakerT)" ), } diff --git a/starlite/contrib/sqlalchemy/init_plugin/config/__init__.py b/starlite/contrib/sqlalchemy/init_plugin/config/__init__.py index deff28a94c..f2e39da99e 100644 --- a/starlite/contrib/sqlalchemy/init_plugin/config/__init__.py +++ b/starlite/contrib/sqlalchemy/init_plugin/config/__init__.py @@ -1,12 +1,15 @@ from __future__ import annotations from .asyncio import AsyncSessionConfig, SQLAlchemyAsyncConfig +from .common import GenericSessionConfig, GenericSQLAlchemyConfig from .engine import EngineConfig from .sync import SQLAlchemySyncConfig, SyncSessionConfig __all__ = ( "AsyncSessionConfig", "EngineConfig", + "GenericSQLAlchemyConfig", + "GenericSessionConfig", "SQLAlchemyAsyncConfig", "SQLAlchemySyncConfig", "SyncSessionConfig", diff --git a/starlite/contrib/sqlalchemy/init_plugin/config/asyncio.py b/starlite/contrib/sqlalchemy/init_plugin/config/asyncio.py index ab86f67282..640e706f19 100644 --- a/starlite/contrib/sqlalchemy/init_plugin/config/asyncio.py +++ b/starlite/contrib/sqlalchemy/init_plugin/config/asyncio.py @@ -11,7 +11,7 @@ get_starlite_scope_state, ) -from ._common import SESSION_SCOPE_KEY, SESSION_TERMINUS_ASGI_EVENTS, GenericSessionConfig, GenericSQLAlchemyConfig +from .common import SESSION_SCOPE_KEY, SESSION_TERMINUS_ASGI_EVENTS, GenericSessionConfig, GenericSQLAlchemyConfig if TYPE_CHECKING: from typing import Any, Callable @@ -46,6 +46,11 @@ class AsyncSessionConfig(GenericSessionConfig[AsyncConnection, AsyncEngine, Asyn """SQLAlchemy async session config.""" sync_session_class: type[Session] | None | EmptyType = Empty + """A :class:`Session ` subclass or other callable which will be used to construct the + :class:`Session ` which will be proxied. This parameter may be used to provide custom + :class:`Session ` subclasses. Defaults to the + :attr:`AsyncSession.sync_session_class ` class-level + attribute.""" @dataclass diff --git a/starlite/contrib/sqlalchemy/init_plugin/config/_common.py b/starlite/contrib/sqlalchemy/init_plugin/config/common.py similarity index 91% rename from starlite/contrib/sqlalchemy/init_plugin/config/_common.py rename to starlite/contrib/sqlalchemy/init_plugin/config/common.py index 46ec079ccb..6413153529 100644 --- a/starlite/contrib/sqlalchemy/init_plugin/config/_common.py +++ b/starlite/contrib/sqlalchemy/init_plugin/config/common.py @@ -57,7 +57,7 @@ class GenericSessionConfig(Generic[ConnectionT, EngineT, SessionT]): keys of the dictionary consist of any series of mapped classes, arbitrary Python classes that are bases for mapped classes, :class:`Table ` objects and :class:`Mapper ` objects. The values of the dictionary are then instances of :class:`Engine ` or less commonly - :class:`Connection` objects.""" + :class:`Connection ` objects.""" class_: type[SessionT] | EmptyType = Empty """Class to use in order to create new :class:`Session ` objects.""" expire_on_commit: bool | EmptyType = Empty @@ -68,16 +68,16 @@ class GenericSessionConfig(Generic[ConnectionT, EngineT, SessionT]): join_transaction_mode: JoinTransactionMode | EmptyType = Empty """Describes the transactional behavior to take when a given bind is a Connection that has already begun a transaction outside the scope of this Session; in other words the - :attr:`Connection.in_transaction()` method returns True.""" + :attr:`Connection.in_transaction() ` method returns True.""" query_cls: type[Query] | None | EmptyType = Empty """Class which should be used to create new Query objects, as returned by the - :attr:`Session.query()` method.""" + :attr:`Session.query() ` method.""" twophase: bool | EmptyType = Empty """When ``True``, all transactions will be started as a “two phase” transaction, i.e. using the “two phase” - semantics of the database in use along with an XID. During a :attr:`commit()`, after - :attr:`flush()` has been issued for all attached databases, the - :attr:`TwoPhaseTransaction.prepare()` method on each database`s - :class:`TwoPhaseTransaction` will be called. This allows each database to + semantics of the database in use along with an XID. During a :attr:`commit() `, after + :attr:`flush() ` has been issued for all attached databases, the + :attr:`TwoPhaseTransaction.prepare() ` method on each database`s + :class:`TwoPhaseTransaction ` will be called. This allows each database to roll back the entire transaction, before each transaction is committed.""" @@ -90,8 +90,8 @@ class GenericSQLAlchemyConfig(Generic[EngineT, SessionT, SessionMakerT]): subclass. """ session_config: GenericSessionConfig - """Configuration options for either the :class:`async_sessionmaker` - or :class:`sessionmaker`. + """Configuration options for either the :class:`async_sessionmaker ` + or :class:`sessionmaker `. """ session_maker_class: type[sessionmaker] | type[async_sessionmaker] """Sessionmaker class to use.""" @@ -123,7 +123,7 @@ class GenericSQLAlchemyConfig(Generic[EngineT, SessionT, SessionMakerT]): The configuration options are documented in the SQLAlchemy documentation. """ session_maker_app_state_key: str = "session_maker_class" - """Key under which to store the SQLAlchemy :class:`sessionmaker` in the application + """Key under which to store the SQLAlchemy :class:`sessionmaker ` in the application :class:`State <.datastructures.State>` instance. """ session_maker: Callable[[], SessionT] | None = None @@ -146,7 +146,7 @@ def engine_config_dict(self) -> dict[str, Any]: """Return the engine configuration as a dict. Returns: - A string keyed dict of config kwargs for the SQLAlchemy :class:`create_engine` + A string keyed dict of config kwargs for the SQLAlchemy :func:`create_engine ` function. """ return simple_asdict(self.engine_config, exclude_empty=True) @@ -156,7 +156,7 @@ def session_config_dict(self) -> dict[str, Any]: """Return the session configuration as a dict. Returns: - A string keyed dict of config kwargs for the SQLAlchemy :class:`sessionmaker` + A string keyed dict of config kwargs for the SQLAlchemy :class:`sessionmaker ` class. """ return simple_asdict(self.session_config, exclude_empty=True) diff --git a/starlite/contrib/sqlalchemy/init_plugin/config/engine.py b/starlite/contrib/sqlalchemy/init_plugin/config/engine.py index 71659bc3ef..aab9e9c5a8 100644 --- a/starlite/contrib/sqlalchemy/init_plugin/config/engine.py +++ b/starlite/contrib/sqlalchemy/init_plugin/config/engine.py @@ -41,36 +41,129 @@ def serializer(value: Any) -> str: @dataclass class EngineConfig: - """Configuration for SQLAlchemy's :class`Engine `. + """Configuration for SQLAlchemy's :class:`Engine `. For details see: https://docs.sqlalchemy.org/en/20/core/engines.html """ connect_args: dict[Any, Any] | EmptyType = Empty + """A dictionary of arguments which will be passed directly to the DBAPI's ``connect()`` method as keyword arguments. + """ echo: _EchoFlagType | EmptyType = Empty + """If ``True``, the Engine will log all statements as well as a ``repr()`` of their parameter lists to the default + log handler, which defaults to ``sys.stdout`` for output. If set to the string "debug", result rows will be printed + to the standard output as well. The echo attribute of Engine can be modified at any time to turn logging on and off; + direct control of logging is also available using the standard Python logging module. + """ echo_pool: _EchoFlagType | EmptyType = Empty + """If ``True``, the connection pool will log informational output such as when connections are invalidated as well + as when connections are recycled to the default log handler, which defaults to sys.stdout for output. If set to the + string "debug", the logging will include pool checkouts and checkins. Direct control of logging is also available + using the standard Python logging module.""" enable_from_linting: bool | EmptyType = Empty + """Defaults to True. Will emit a warning if a given SELECT statement is found to have un-linked FROM elements which + would cause a cartesian product.""" execution_options: Mapping[str, Any] | EmptyType = Empty + """Dictionary execution options which will be applied to all connections. See + :attr:`Connection.execution_options() ` for details.""" hide_parameters: bool | EmptyType = Empty + """Boolean, when set to ``True``, SQL statement parameters will not be displayed in INFO logging nor will they be + formatted into the string representation of :class:`StatementError ` objects.""" insertmanyvalues_page_size: int | EmptyType = Empty + """Number of rows to format into an INSERT statement when the statement uses “insertmanyvalues” mode, which is a + paged form of bulk insert that is used for many backends when using executemany execution typically in conjunction + with RETURNING. Defaults to 1000, but may also be subject to dialect-specific limiting factors which may override + this value on a per-statement basis.""" isolation_level: IsolationLevel | EmptyType = Empty + """Optional string name of an isolation level which will be set on all new connections unconditionally. Isolation + levels are typically some subset of the string names "SERIALIZABLE", "REPEATABLE READ", "READ COMMITTED", + "READ UNCOMMITTED" and "AUTOCOMMIT" based on backend.""" json_deserializer: Callable[[str], Any] = decode_json + """For dialects that support the :class:`JSON ` datatype, this is a Python callable that will + convert a JSON string to a Python object. By default, this is set to Starlite's + :attr:`decode_json() <.serialization.decode_json>` function.""" json_serializer: Callable[[Any], str] = serializer + """For dialects that support the JSON datatype, this is a Python callable that will render a given object as JSON. + By default, Starlite's :attr:`encode_json() <.serialization.encode_json>` is used.""" label_length: int | None | EmptyType = Empty + """Optional integer value which limits the size of dynamically generated column labels to that many characters. If + less than 6, labels are generated as “_(counter)”. If ``None``, the value of ``dialect.max_identifier_length``, + which may be affected via the + :attr:`create_engine.max_identifier_length parameter `, is + used instead. The value of + :attr:`create_engine.label_length ` may not be larger than that of + :attr:`create_engine.max_identfier_length `.""" logging_name: str | EmptyType = Empty + """String identifier which will be used within the “name” field of logging records generated within the + “sqlalchemy.engine” logger. Defaults to a hexstring of the object`s id.""" max_identifier_length: int | None | EmptyType = Empty + """Override the max_identifier_length determined by the dialect. if ``None`` or ``0``, has no effect. This is the + database`s configured maximum number of characters that may be used in a SQL identifier such as a table name, column + name, or label name. All dialects determine this value automatically, however in the case of a new database version + for which this value has changed but SQLAlchemy`s dialect has not been adjusted, the value may be passed here.""" max_overflow: int | EmptyType = Empty + """The number of connections to allow in connection pool “overflow”, that is connections that can be opened above + and beyond the pool_size setting, which defaults to five. This is only used with + :class:`QueuePool `.""" module: Any | None | EmptyType = Empty + """Reference to a Python module object (the module itself, not its string name). Specifies an alternate DBAPI module + to be used by the engine`s dialect. Each sub-dialect references a specific DBAPI which will be imported before first + connect. This parameter causes the import to be bypassed, and the given module to be used instead. Can be used for + testing of DBAPIs as well as to inject “mock” DBAPI implementations into the + :class:`Engine `.""" paramstyle: _ParamStyle | None | EmptyType = Empty + """The paramstyle to use when rendering bound parameters. This style defaults to the one recommended by the DBAPI + itself, which is retrieved from the ``.paramstyle`` attribute of the DBAPI. However, most DBAPIs accept more than + one paramstyle, and in particular it may be desirable to change a “named” paramstyle into a “positional” one, or + vice versa. When this attribute is passed, it should be one of the values "qmark", "numeric", "named", "format" or + "pyformat", and should correspond to a parameter style known to be supported by the DBAPI in use.""" pool: Pool | None | EmptyType = Empty + """An already-constructed instance of :class:`Pool `, such as a + :class:`QueuePool ` instance. If non-None, this pool will be used directly as the + underlying connection pool for the engine, bypassing whatever connection parameters are present in the URL argument. + For information on constructing connection pools manually, see + `Connection Pooling `_.""" poolclass: type[Pool] | None | EmptyType = Empty + """A :class:`Pool ` subclass, which will be used to create a connection pool instance using + the connection parameters given in the URL. Note this differs from pool in that you don`t actually instantiate the + pool in this case, you just indicate what type of pool to be used.""" pool_logging_name: str | EmptyType = Empty + """String identifier which will be used within the “name” field of logging records generated within the + “sqlalchemy.pool” logger. Defaults to a hexstring of the object`s id.""" pool_pre_ping: bool | EmptyType = Empty + """If True will enable the connection pool “pre-ping” feature that tests connections for liveness upon each + checkout.""" pool_size: int | EmptyType = Empty + """The number of connections to keep open inside the connection pool. This used with + :class:`QueuePool ` as well as + :class:`SingletonThreadPool `. With + :class:`QueuePool `, a pool_size setting of ``0`` indicates no limit; to disable pooling, + set ``poolclass`` to :class:`NullPool ` instead.""" pool_recycle: int | EmptyType = Empty + """This setting causes the pool to recycle connections after the given number of seconds has passed. It defaults to + ``-1``, or no timeout. For example, setting to ``3600`` means connections will be recycled after one hour. Note that + MySQL in particular will disconnect automatically if no activity is detected on a connection for eight hours + (although this is configurable with the MySQLDB connection itself and the server configuration as well).""" pool_reset_on_return: Literal["rollback", "commit"] | EmptyType = Empty + """Set the :attr:`Pool.reset_on_return ` object, which can be set to the values ``"rollback"``, ``"commit"``, or + ``None``.""" pool_timeout: int | EmptyType = Empty + """Number of seconds to wait before giving up on getting a connection from the pool. This is only used with + :class:`QueuePool `. This can be a float but is subject to the limitations of Python time + functions which may not be reliable in the tens of milliseconds.""" pool_use_lifo: bool | EmptyType = Empty + """Use LIFO (last-in-first-out) when retrieving connections from :class:`QueuePool ` + instead of FIFO (first-in-first-out). Using LIFO, a server-side timeout scheme can reduce the number of connections + used during non-peak periods of use. When planning for server-side timeouts, ensure that a recycle or pre-ping + strategy is in use to gracefully handle stale connections.""" plugins: list[str] | EmptyType = Empty + """String list of plugin names to load. See :class:`CreateEnginePlugin ` for + background.""" query_cache_size: int | EmptyType = Empty + """Size of the cache used to cache the SQL string form of queries. Set to zero to disable caching. + + See :attr:`query_cache_size ` for more info. + """ use_insertmanyvalues: bool | EmptyType = Empty + """``True`` by default, use the “insertmanyvalues” execution style for INSERT..RETURNING statements by default.""" diff --git a/starlite/contrib/sqlalchemy/init_plugin/config/sync.py b/starlite/contrib/sqlalchemy/init_plugin/config/sync.py index 1450c298d0..a0dbe93dcc 100644 --- a/starlite/contrib/sqlalchemy/init_plugin/config/sync.py +++ b/starlite/contrib/sqlalchemy/init_plugin/config/sync.py @@ -11,7 +11,7 @@ get_starlite_scope_state, ) -from ._common import SESSION_SCOPE_KEY, SESSION_TERMINUS_ASGI_EVENTS, GenericSessionConfig, GenericSQLAlchemyConfig +from .common import SESSION_SCOPE_KEY, SESSION_TERMINUS_ASGI_EVENTS, GenericSessionConfig, GenericSQLAlchemyConfig if TYPE_CHECKING: from typing import Any, Callable diff --git a/tests/contrib/sqlalchemy/init_plugin/config/test_common.py b/tests/contrib/sqlalchemy/init_plugin/config/test_common.py index f2934ad7f0..efe97dd0c2 100644 --- a/tests/contrib/sqlalchemy/init_plugin/config/test_common.py +++ b/tests/contrib/sqlalchemy/init_plugin/config/test_common.py @@ -8,7 +8,7 @@ from starlite.constants import SCOPE_STATE_NAMESPACE from starlite.contrib.sqlalchemy.init_plugin.config import SQLAlchemyAsyncConfig, SQLAlchemySyncConfig -from starlite.contrib.sqlalchemy.init_plugin.config._common import SESSION_SCOPE_KEY +from starlite.contrib.sqlalchemy.init_plugin.config.common import SESSION_SCOPE_KEY from starlite.datastructures import State from starlite.exceptions import ImproperlyConfiguredException @@ -131,7 +131,7 @@ def test_create_session_instance_if_session_already_in_scope_state( ) -> None: """Test provide_session if session already in scope state.""" with patch( - "starlite.contrib.sqlalchemy.init_plugin.config._common.get_starlite_scope_state" + "starlite.contrib.sqlalchemy.init_plugin.config.common.get_starlite_scope_state" ) as get_starlite_scope_state_mock: session_mock = MagicMock() get_starlite_scope_state_mock.return_value = session_mock @@ -144,7 +144,7 @@ def test_create_session_instance_if_session_not_in_scope_state( ) -> None: """Test provide_session if session not in scope state.""" with patch( - "starlite.contrib.sqlalchemy.init_plugin.config._common.get_starlite_scope_state" + "starlite.contrib.sqlalchemy.init_plugin.config.common.get_starlite_scope_state" ) as get_starlite_scope_state_mock: get_starlite_scope_state_mock.return_value = None config = config_cls()