From 3fb0c8322c52b1d80be73b4ecdbc09769157e177 Mon Sep 17 00:00:00 2001 From: Arthur Loiselle Date: Wed, 1 Nov 2023 13:06:23 -0400 Subject: [PATCH 1/3] feat: split the main modules into multiple modules - DIA-61984 --- README.md | 8 +- fastapi_sqla/__init__.py | 50 ++-------- fastapi_sqla/async_pagination.py | 87 ++++++++++++++++ fastapi_sqla/asyncio_support.py | 91 +---------------- fastapi_sqla/base.py | 33 +++++++ fastapi_sqla/models.py | 25 +++-- fastapi_sqla/pagination.py | 146 +++++++++++++++++++++++++++ fastapi_sqla/sqla.py | 165 +------------------------------ tests/test_base.py | 6 +- tests/test_open_session.py | 3 +- 10 files changed, 313 insertions(+), 301 deletions(-) create mode 100644 fastapi_sqla/async_pagination.py create mode 100644 fastapi_sqla/base.py create mode 100644 fastapi_sqla/pagination.py diff --git a/README.md b/README.md index eadf00d..769101a 100644 --- a/README.md +++ b/README.md @@ -158,8 +158,7 @@ exception occurred: ```python from fastapi import APIRouter, Depends -from fastapi_sqla import Session -from fastapi_sqla.asyncio_support import AsyncSession +from fastapi_sqla import Session, AsyncSession router = APIRouter() @@ -183,8 +182,7 @@ occurred: ```python from fastapi import APIRouter, BackgroundTasks -from fastapi_sqla import open_session -from fastapi_sqla import asyncio_support +from fastapi_sqla import open_session, open_async_session router = APIRouter() @@ -201,7 +199,7 @@ def run_bg(): async def run_async_bg(): - async with asyncio_support.open_session() as session: + async with open_async_session() as session: await session.scalar("SELECT now()") ``` diff --git a/fastapi_sqla/__init__.py b/fastapi_sqla/__init__.py index 2f59589..b7ff68c 100644 --- a/fastapi_sqla/__init__.py +++ b/fastapi_sqla/__init__.py @@ -1,17 +1,7 @@ -import os - -from fastapi import FastAPI - -from fastapi_sqla import sqla -from fastapi_sqla.models import Collection, Item, Page -from fastapi_sqla.sqla import ( - Base, - Paginate, - PaginateSignature, - Pagination, - Session, - open_session, -) +from fastapi_sqla.base import setup +from fastapi_sqla.models import Base, Collection, Item, Page +from fastapi_sqla.pagination import Paginate, PaginateSignature, Pagination +from fastapi_sqla.sqla import Session, open_session __all__ = [ "Base", @@ -23,17 +13,14 @@ "Pagination", "Session", "open_session", + "setup", ] try: - from fastapi_sqla import asyncio_support - from fastapi_sqla.asyncio_support import ( # noqa - AsyncPaginate, - AsyncPagination, - AsyncSession, - ) - from fastapi_sqla.asyncio_support import open_session as open_async_session # noqa + from fastapi_sqla.async_pagination import AsyncPaginate, AsyncPagination + from fastapi_sqla.asyncio_support import AsyncSession + from fastapi_sqla.asyncio_support import open_session as open_async_session __all__ += [ "AsyncPaginate", @@ -43,22 +30,5 @@ ] has_asyncio_support = True -except ImportError as err: # pragma: no cover - has_asyncio_support = False - asyncio_support_err = str(err) - - -def setup(app: FastAPI): - engine = sqla.new_engine() - - if not sqla.is_async_dialect(engine): - app.add_event_handler("startup", sqla.startup) - app.middleware("http")(sqla.add_session_to_request) - - has_async_config = "async_sqlalchemy_url" in os.environ or sqla.is_async_dialect( - engine - ) - if has_async_config: - assert has_asyncio_support, asyncio_support_err - app.add_event_handler("startup", asyncio_support.startup) - app.middleware("http")(asyncio_support.add_session_to_request) +except ImportError: # pragma: no cover + pass diff --git a/fastapi_sqla/async_pagination.py b/fastapi_sqla/async_pagination.py new file mode 100644 index 0000000..272ac0d --- /dev/null +++ b/fastapi_sqla/async_pagination.py @@ -0,0 +1,87 @@ +import math +from collections.abc import Awaitable, Callable +from typing import Iterator, Optional, Union, cast + +from fastapi import Depends, Query +from sqlalchemy.sql import Select, func, select + +from fastapi_sqla.asyncio_support import AsyncSession +from fastapi_sqla.models import Page + +QueryCountDependency = Callable[..., Awaitable[int]] +PaginateSignature = Callable[[Select, Optional[bool]], Awaitable[Page]] +DefaultDependency = Callable[[AsyncSession, int, int], PaginateSignature] +WithQueryCountDependency = Callable[[AsyncSession, int, int, int], PaginateSignature] +PaginateDependency = Union[DefaultDependency, WithQueryCountDependency] + + +async def default_query_count(session: AsyncSession, query: Select) -> int: + result = await session.execute(select(func.count()).select_from(query.subquery())) + return cast(int, result.scalar()) + + +async def paginate_query( + query: Select, + session: AsyncSession, + total_items: int, + offset: int, + limit: int, + *, + scalars: bool = True, +) -> Page: + total_pages = math.ceil(total_items / limit) + page_number = offset / limit + 1 + query = query.offset(offset).limit(limit) + result = await session.execute(query) + data = iter( + cast(Iterator, result.unique().scalars() if scalars else result.mappings()) + ) + return Page( + data=data, + meta={ + "offset": offset, + "total_items": total_items, + "total_pages": total_pages, + "page_number": page_number, + }, + ) + + +def AsyncPagination( + min_page_size: int = 10, + max_page_size: int = 100, + query_count: Union[QueryCountDependency, None] = None, +) -> PaginateDependency: + def default_dependency( + session: AsyncSession = Depends(), + offset: int = Query(0, ge=0), + limit: int = Query(min_page_size, ge=1, le=max_page_size), + ) -> PaginateSignature: + async def paginate(query: Select, scalars=True) -> Page: + total_items = await default_query_count(session, query) + return await paginate_query( + query, session, total_items, offset, limit, scalars=scalars + ) + + return paginate + + def with_query_count_dependency( + session: AsyncSession = Depends(), + offset: int = Query(0, ge=0), + limit: int = Query(min_page_size, ge=1, le=max_page_size), + total_items: int = Depends(query_count), + ): + async def paginate(query: Select, scalars=True) -> Page: + return await paginate_query( + query, session, total_items, offset, limit, scalars=scalars + ) + + return paginate + + if query_count: + return with_query_count_dependency + else: + return default_dependency + + +AsyncPaginate: PaginateDependency = AsyncPagination() diff --git a/fastapi_sqla/asyncio_support.py b/fastapi_sqla/asyncio_support.py index 3674fd4..074f295 100644 --- a/fastapi_sqla/asyncio_support.py +++ b/fastapi_sqla/asyncio_support.py @@ -1,20 +1,18 @@ -import math import os -from collections.abc import AsyncGenerator, Awaitable, Callable +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager -from typing import Iterator, Optional, Union, cast +from typing import cast import structlog -from fastapi import Depends, Query, Request +from fastapi import Request from sqlalchemy import text from sqlalchemy.ext.asyncio import AsyncEngine from sqlalchemy.ext.asyncio import AsyncSession as SqlaAsyncSession from sqlalchemy.orm.session import sessionmaker -from sqlalchemy.sql import Select, func, select from fastapi_sqla import aws_aurora_support, aws_rds_iam_support -from fastapi_sqla.models import Page -from fastapi_sqla.sqla import Base, new_engine +from fastapi_sqla.models import Base +from fastapi_sqla.sqla import new_engine logger = structlog.get_logger(__name__) _ASYNC_SESSION_KEY = "fastapi_sqla_async_session" @@ -127,82 +125,3 @@ async def get_users(session: fastapi_sqla.AsyncSession = Depends()): await session.rollback() return response - - -QueryCountDependency = Callable[..., Awaitable[int]] -PaginateSignature = Callable[[Select, Optional[bool]], Awaitable[Page]] -DefaultDependency = Callable[[AsyncSession, int, int], PaginateSignature] -WithQueryCountDependency = Callable[[AsyncSession, int, int, int], PaginateSignature] -PaginateDependency = Union[DefaultDependency, WithQueryCountDependency] - - -async def default_query_count(session: AsyncSession, query: Select) -> int: - result = await session.execute(select(func.count()).select_from(query.subquery())) - return cast(int, result.scalar()) - - -async def paginate_query( - query: Select, - session: AsyncSession, - total_items: int, - offset: int, - limit: int, - *, - scalars: bool = True, -) -> Page: - total_pages = math.ceil(total_items / limit) - page_number = offset / limit + 1 - query = query.offset(offset).limit(limit) - result = await session.execute(query) - data = iter( - cast(Iterator, result.unique().scalars() if scalars else result.mappings()) - ) - return Page( - data=data, - meta={ - "offset": offset, - "total_items": total_items, - "total_pages": total_pages, - "page_number": page_number, - }, - ) - - -def AsyncPagination( - min_page_size: int = 10, - max_page_size: int = 100, - query_count: Union[QueryCountDependency, None] = None, -) -> PaginateDependency: - def default_dependency( - session: AsyncSession = Depends(), - offset: int = Query(0, ge=0), - limit: int = Query(min_page_size, ge=1, le=max_page_size), - ) -> PaginateSignature: - async def paginate(query: Select, scalars=True) -> Page: - total_items = await default_query_count(session, query) - return await paginate_query( - query, session, total_items, offset, limit, scalars=scalars - ) - - return paginate - - def with_query_count_dependency( - session: AsyncSession = Depends(), - offset: int = Query(0, ge=0), - limit: int = Query(min_page_size, ge=1, le=max_page_size), - total_items: int = Depends(query_count), - ): - async def paginate(query: Select, scalars=True) -> Page: - return await paginate_query( - query, session, total_items, offset, limit, scalars=scalars - ) - - return paginate - - if query_count: - return with_query_count_dependency - else: - return default_dependency - - -AsyncPaginate: PaginateDependency = AsyncPagination() diff --git a/fastapi_sqla/base.py b/fastapi_sqla/base.py new file mode 100644 index 0000000..c5b0ee2 --- /dev/null +++ b/fastapi_sqla/base.py @@ -0,0 +1,33 @@ +import os + +from fastapi import FastAPI +from sqlalchemy.engine import Engine + +from fastapi_sqla import sqla + +try: + from fastapi_sqla import asyncio_support + + has_asyncio_support = True + +except ImportError as err: # pragma: no cover + has_asyncio_support = False + asyncio_support_err = str(err) + + +def setup(app: FastAPI): + engine = sqla.new_engine() + + if not is_async_dialect(engine): + app.add_event_handler("startup", sqla.startup) + app.middleware("http")(sqla.add_session_to_request) + + has_async_config = "async_sqlalchemy_url" in os.environ or is_async_dialect(engine) + if has_async_config: + assert has_asyncio_support, asyncio_support_err + app.add_event_handler("startup", asyncio_support.startup) + app.middleware("http")(asyncio_support.add_session_to_request) + + +def is_async_dialect(engine: Engine): + return engine.dialect.is_async if hasattr(engine.dialect, "is_async") else False diff --git a/fastapi_sqla/models.py b/fastapi_sqla/models.py index b2874fd..ae68870 100644 --- a/fastapi_sqla/models.py +++ b/fastapi_sqla/models.py @@ -2,6 +2,14 @@ from pydantic import BaseModel, Field from pydantic import __version__ as pydantic_version +from sqlalchemy.ext.declarative import DeferredReflection + +try: + from sqlalchemy.orm import DeclarativeBase +except ImportError: + from sqlalchemy.ext.declarative import declarative_base + + DeclarativeBase = declarative_base() # type: ignore major, _, _ = [int(v) for v in pydantic_version.split(".")] is_pydantic2 = major == 2 @@ -10,19 +18,24 @@ else: from pydantic.generics import GenericModel # type:ignore -T = TypeVar("T") + +class Base(DeclarativeBase, DeferredReflection): + __abstract__ = True + + +ItemT = TypeVar("ItemT") -class Item(GenericModel, Generic[T]): +class Item(GenericModel, Generic[ItemT]): """Item container.""" - data: T + data: ItemT -class Collection(GenericModel, Generic[T]): +class Collection(GenericModel, Generic[ItemT]): """Collection container.""" - data: list[T] + data: list[ItemT] class Meta(BaseModel): @@ -34,7 +47,7 @@ class Meta(BaseModel): page_number: int = Field(..., description="Current page number. Starts at 1.") -class Page(Collection[T], Generic[T]): +class Page(Collection[ItemT], Generic[ItemT]): """A page of the collection with info on current page and total items in meta.""" meta: Meta diff --git a/fastapi_sqla/pagination.py b/fastapi_sqla/pagination.py new file mode 100644 index 0000000..a879a47 --- /dev/null +++ b/fastapi_sqla/pagination.py @@ -0,0 +1,146 @@ +import math +from collections.abc import Callable +from functools import singledispatch +from typing import Iterator, Optional, Union, cast + +from fastapi import Depends, Query +from sqlalchemy.orm import Query as LegacyQuery +from sqlalchemy.sql import Select, func, select + +from fastapi_sqla.models import Page +from fastapi_sqla.sqla import Session + +DbQuery = Union[LegacyQuery, Select] +QueryCountDependency = Callable[..., int] +PaginateSignature = Callable[[DbQuery, Optional[bool]], Page] +DefaultDependency = Callable[[Session, int, int], PaginateSignature] +WithQueryCountDependency = Callable[[Session, int, int, int], PaginateSignature] +PaginateDependency = Union[DefaultDependency, WithQueryCountDependency] + + +def default_query_count(session: Session, query: DbQuery) -> int: + """Default function used to count items returned by a query. + + It is slower than a manually written query could be: It runs the query in a + subquery, and count the number of elements returned. + + See https://gist.github.com/hest/8798884 + """ + if isinstance(query, LegacyQuery): + result = query.count() + + elif isinstance(query, Select): + result = cast( + int, + session.execute( + select(func.count()).select_from(query.subquery()) + ).scalar(), + ) + + else: # pragma: no cover + raise NotImplementedError(f"Query type {type(query)!r} is not supported") + + return result + + +@singledispatch +def paginate_query( + query: DbQuery, + session: Session, + total_items: int, + offset: int, + limit: int, + scalars: bool = True, +) -> Page: # pragma: no cover + "Dispatch on registered functions based on `query` type" + raise NotImplementedError(f"no paginate_query registered for type {type(query)!r}") + + +@paginate_query.register +def _paginate_legacy( + query: LegacyQuery, + session: Session, + total_items: int, + offset: int, + limit: int, + scalars: bool = True, +) -> Page: + total_pages = math.ceil(total_items / limit) + page_number = offset / limit + 1 + return Page( + data=query.offset(offset).limit(limit).all(), + meta={ + "offset": offset, + "total_items": total_items, + "total_pages": total_pages, + "page_number": page_number, + }, + ) + + +@paginate_query.register +def _paginate( + query: Select, + session: Session, + total_items: int, + offset: int, + limit: int, + *, + scalars: bool = True, +) -> Page: + total_pages = math.ceil(total_items / limit) + page_number = offset / limit + 1 + query = query.offset(offset).limit(limit) + result = session.execute(query) + data = iter( + cast(Iterator, result.unique().scalars() if scalars else result.mappings()) + ) + return Page( + data=data, + meta={ + "offset": offset, + "total_items": total_items, + "total_pages": total_pages, + "page_number": page_number, + }, + ) + + +def Pagination( + min_page_size: int = 10, + max_page_size: int = 100, + query_count: Union[QueryCountDependency, None] = None, +) -> PaginateDependency: + def default_dependency( + session: Session = Depends(), + offset: int = Query(0, ge=0), + limit: int = Query(min_page_size, ge=1, le=max_page_size), + ) -> PaginateSignature: + def paginate(query: DbQuery, scalars=True) -> Page: + total_items = default_query_count(session, query) + return paginate_query( + query, session, total_items, offset, limit, scalars=scalars + ) + + return paginate + + def with_query_count_dependency( + session: Session = Depends(), + offset: int = Query(0, ge=0), + limit: int = Query(min_page_size, ge=1, le=max_page_size), + total_items: int = Depends(query_count), + ) -> PaginateSignature: + def paginate(query: DbQuery, scalars=True) -> Page: + return paginate_query( + query, session, total_items, offset, limit, scalars=scalars + ) + + return paginate + + if query_count: + return with_query_count_dependency + else: + return default_dependency + + +Paginate: PaginateDependency = Pagination() diff --git a/fastapi_sqla/sqla.py b/fastapi_sqla/sqla.py index 42387f5..0c594ef 100644 --- a/fastapi_sqla/sqla.py +++ b/fastapi_sqla/sqla.py @@ -1,33 +1,20 @@ import asyncio -import math import os -from collections.abc import Callable, Generator +from collections.abc import Generator from contextlib import contextmanager -from functools import singledispatch -from typing import Iterator, Optional, Union, cast +from typing import Union import structlog -from fastapi import Depends, Query, Request +from fastapi import Request from fastapi.concurrency import contextmanager_in_threadpool from fastapi.responses import PlainTextResponse from sqlalchemy import engine_from_config, text from sqlalchemy.engine import Engine -from sqlalchemy.ext.declarative import DeferredReflection -from sqlalchemy.orm import Query as LegacyQuery from sqlalchemy.orm.session import Session as SqlaSession from sqlalchemy.orm.session import sessionmaker -from sqlalchemy.sql import Select, func, select from fastapi_sqla import aws_aurora_support, aws_rds_iam_support -from fastapi_sqla.models import Page - -try: - from sqlalchemy.orm import DeclarativeBase -except ImportError: - from sqlalchemy.ext.declarative import declarative_base - - DeclarativeBase = declarative_base() # type: ignore - +from fastapi_sqla.models import Base logger = structlog.get_logger(__name__) @@ -44,10 +31,6 @@ def new_engine(*, envvar_prefix: Union[str, None] = None) -> Engine: return engine_from_config(lowercase_environ, prefix=envvar_prefix) -def is_async_dialect(engine): - return engine.dialect.is_async if hasattr(engine.dialect, "is_async") else False - - def startup(): engine = new_engine() aws_rds_iam_support.setup(engine.engine) @@ -68,10 +51,6 @@ def startup(): logger.info("startup", engine=engine) -class Base(DeclarativeBase, DeferredReflection): - __abstract__ = True - - class Session(SqlaSession): def __new__(cls, request: Request): """Yield the sqlalchmey session for that request. @@ -176,139 +155,3 @@ def get_users(session: fastapi_sqla.Session = Depends()): await loop.run_in_executor(None, session.rollback) return response - - -DbQuery = Union[LegacyQuery, Select] -QueryCountDependency = Callable[..., int] -PaginateSignature = Callable[[DbQuery, Optional[bool]], Page] -DefaultDependency = Callable[[Session, int, int], PaginateSignature] -WithQueryCountDependency = Callable[[Session, int, int, int], PaginateSignature] -PaginateDependency = Union[DefaultDependency, WithQueryCountDependency] - - -def default_query_count(session: Session, query: DbQuery) -> int: - """Default function used to count items returned by a query. - - It is slower than a manually written query could be: It runs the query in a - subquery, and count the number of elements returned. - - See https://gist.github.com/hest/8798884 - """ - if isinstance(query, LegacyQuery): - result = query.count() - - elif isinstance(query, Select): - result = cast( - int, - session.execute( - select(func.count()).select_from(query.subquery()) - ).scalar(), - ) - - else: # pragma: no cover - raise NotImplementedError(f"Query type {type(query)!r} is not supported") - - return result - - -@singledispatch -def paginate_query( - query: DbQuery, - session: Session, - total_items: int, - offset: int, - limit: int, - scalars: bool = True, -) -> Page: # pragma: no cover - "Dispatch on registered functions based on `query` type" - raise NotImplementedError(f"no paginate_query registered for type {type(query)!r}") - - -@paginate_query.register -def _paginate_legacy( - query: LegacyQuery, - session: Session, - total_items: int, - offset: int, - limit: int, - scalars: bool = True, -) -> Page: - total_pages = math.ceil(total_items / limit) - page_number = offset / limit + 1 - return Page( - data=query.offset(offset).limit(limit).all(), - meta={ - "offset": offset, - "total_items": total_items, - "total_pages": total_pages, - "page_number": page_number, - }, - ) - - -@paginate_query.register -def _paginate( - query: Select, - session: Session, - total_items: int, - offset: int, - limit: int, - *, - scalars: bool = True, -) -> Page: - total_pages = math.ceil(total_items / limit) - page_number = offset / limit + 1 - query = query.offset(offset).limit(limit) - result = session.execute(query) - data = iter( - cast(Iterator, result.unique().scalars() if scalars else result.mappings()) - ) - return Page( - data=data, - meta={ - "offset": offset, - "total_items": total_items, - "total_pages": total_pages, - "page_number": page_number, - }, - ) - - -def Pagination( - min_page_size: int = 10, - max_page_size: int = 100, - query_count: Union[QueryCountDependency, None] = None, -) -> PaginateDependency: - def default_dependency( - session: Session = Depends(), - offset: int = Query(0, ge=0), - limit: int = Query(min_page_size, ge=1, le=max_page_size), - ) -> PaginateSignature: - def paginate(query: DbQuery, scalars=True) -> Page: - total_items = default_query_count(session, query) - return paginate_query( - query, session, total_items, offset, limit, scalars=scalars - ) - - return paginate - - def with_query_count_dependency( - session: Session = Depends(), - offset: int = Query(0, ge=0), - limit: int = Query(min_page_size, ge=1, le=max_page_size), - total_items: int = Depends(query_count), - ) -> PaginateSignature: - def paginate(query: DbQuery, scalars=True) -> Page: - return paginate_query( - query, session, total_items, offset, limit, scalars=scalars - ) - - return paginate - - if query_count: - return with_query_count_dependency - else: - return default_dependency - - -Paginate: PaginateDependency = Pagination() diff --git a/tests/test_base.py b/tests/test_base.py index 3a2285a..891e55e 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -15,7 +15,8 @@ def setup_tear_down(engine): def test_startup_reflect_test_table(): - from fastapi_sqla.sqla import Base, _Session, startup + from fastapi_sqla.models import Base + from fastapi_sqla.sqla import _Session, startup class TestTable(Base): __tablename__ = "test_table" @@ -43,7 +44,8 @@ def expected_error(sqla_version_tuple): def test_startup_fails_when_table_doesnt_exist(expected_error): - from fastapi_sqla.sqla import Base, startup + from fastapi_sqla.models import Base + from fastapi_sqla.sqla import startup class TestTable(Base): __tablename__ = "does_not_exist" diff --git a/tests/test_open_session.py b/tests/test_open_session.py index c5a7d57..c30745c 100644 --- a/tests/test_open_session.py +++ b/tests/test_open_session.py @@ -26,7 +26,8 @@ def setup(sqla_connection): @fixture(scope="module") def TestTable(module_setup_tear_down): - from fastapi_sqla.sqla import Base, startup + from fastapi_sqla import Base + from fastapi_sqla.sqla import startup class TestTable(Base): __tablename__ = "test_table" From 492a5a98725ccbdd60f4e183f3b8265b117b63ed Mon Sep 17 00:00:00 2001 From: Arthur Loiselle Date: Wed, 1 Nov 2023 13:12:55 -0400 Subject: [PATCH 2/3] asyncio_support to async_session --- fastapi_sqla/__init__.py | 4 ++-- fastapi_sqla/_pytest_plugin.py | 6 +++--- fastapi_sqla/async_pagination.py | 2 +- .../{asyncio_support.py => async_session.py} | 0 fastapi_sqla/base.py | 6 +++--- tests/conftest.py | 1 + ...test_asyncio_support.py => test_async_session.py} | 12 ++++++------ tests/test_aws_aurora_support.py | 2 +- tests/test_pytest_plugin.py | 2 +- tests/test_setup.py | 7 ++++--- tests/test_startup.py | 6 +++--- 11 files changed, 25 insertions(+), 23 deletions(-) rename fastapi_sqla/{asyncio_support.py => async_session.py} (100%) rename tests/{test_asyncio_support.py => test_async_session.py} (77%) diff --git a/fastapi_sqla/__init__.py b/fastapi_sqla/__init__.py index b7ff68c..860aa34 100644 --- a/fastapi_sqla/__init__.py +++ b/fastapi_sqla/__init__.py @@ -19,8 +19,8 @@ try: from fastapi_sqla.async_pagination import AsyncPaginate, AsyncPagination - from fastapi_sqla.asyncio_support import AsyncSession - from fastapi_sqla.asyncio_support import open_session as open_async_session + from fastapi_sqla.async_session import AsyncSession + from fastapi_sqla.async_session import open_session as open_async_session __all__ += [ "AsyncPaginate", diff --git a/fastapi_sqla/_pytest_plugin.py b/fastapi_sqla/_pytest_plugin.py index 9eeb687..df0c595 100644 --- a/fastapi_sqla/_pytest_plugin.py +++ b/fastapi_sqla/_pytest_plugin.py @@ -164,13 +164,13 @@ async def async_sqla_connection(async_engine, event_loop): @fixture async def patch_new_engine(async_sqlalchemy_url, async_sqla_connection, request): """So that all async DB operations are never written to db for real.""" - from fastapi_sqla.asyncio_support import _AsyncSession + from fastapi_sqla.async_session import _AsyncSession if "dont_patch_engines" in request.keywords: yield else: - with patch("fastapi_sqla.asyncio_support.new_engine") as new_engine: + with patch("fastapi_sqla.async_session.new_engine") as new_engine: new_engine.return_value = async_sqla_connection _AsyncSession.configure( bind=async_sqla_connection, expire_on_commit=False @@ -187,7 +187,7 @@ async def async_sqla_reflection(sqla_modules, async_sqla_connection): async def async_session( async_sqla_connection, async_sqla_reflection, patch_new_engine ): - from fastapi_sqla.asyncio_support import _AsyncSession + from fastapi_sqla.async_session import _AsyncSession session = _AsyncSession(bind=async_sqla_connection) yield session diff --git a/fastapi_sqla/async_pagination.py b/fastapi_sqla/async_pagination.py index 272ac0d..be15947 100644 --- a/fastapi_sqla/async_pagination.py +++ b/fastapi_sqla/async_pagination.py @@ -5,7 +5,7 @@ from fastapi import Depends, Query from sqlalchemy.sql import Select, func, select -from fastapi_sqla.asyncio_support import AsyncSession +from fastapi_sqla.async_session import AsyncSession from fastapi_sqla.models import Page QueryCountDependency = Callable[..., Awaitable[int]] diff --git a/fastapi_sqla/asyncio_support.py b/fastapi_sqla/async_session.py similarity index 100% rename from fastapi_sqla/asyncio_support.py rename to fastapi_sqla/async_session.py diff --git a/fastapi_sqla/base.py b/fastapi_sqla/base.py index c5b0ee2..41358c7 100644 --- a/fastapi_sqla/base.py +++ b/fastapi_sqla/base.py @@ -6,7 +6,7 @@ from fastapi_sqla import sqla try: - from fastapi_sqla import asyncio_support + from fastapi_sqla import async_session has_asyncio_support = True @@ -25,8 +25,8 @@ def setup(app: FastAPI): has_async_config = "async_sqlalchemy_url" in os.environ or is_async_dialect(engine) if has_async_config: assert has_asyncio_support, asyncio_support_err - app.add_event_handler("startup", asyncio_support.startup) - app.middleware("http")(asyncio_support.add_session_to_request) + app.add_event_handler("startup", async_session.startup) + app.middleware("http")(async_session.add_session_to_request) def is_async_dialect(engine: Engine): diff --git a/tests/conftest.py b/tests/conftest.py index b3ebd59..4795c4c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -86,6 +86,7 @@ def tear_down(environ): close_all_sessions() # reload fastapi_sqla to clear sqla deferred reflection mapping stored in Base + importlib.reload(fastapi_sqla.models) importlib.reload(fastapi_sqla.sqla) importlib.reload(fastapi_sqla) diff --git a/tests/test_asyncio_support.py b/tests/test_async_session.py similarity index 77% rename from tests/test_asyncio_support.py rename to tests/test_async_session.py index a54de24..494a4bc 100644 --- a/tests/test_asyncio_support.py +++ b/tests/test_async_session.py @@ -8,14 +8,14 @@ @fixture async def startup(environ): - from fastapi_sqla.asyncio_support import startup + from fastapi_sqla.async_session import startup await startup() yield async def test_startup_configure_async_session(startup): - from fastapi_sqla.asyncio_support import _AsyncSession + from fastapi_sqla.async_session import _AsyncSession async with _AsyncSession() as session: res = await session.execute(text("SELECT 123")) @@ -24,7 +24,7 @@ async def test_startup_configure_async_session(startup): async def test_open_async_session(startup): - from fastapi_sqla.asyncio_support import open_session + from fastapi_sqla.async_session import open_session async with open_session() as session: res = await session.execute(text("select 123")) @@ -35,7 +35,7 @@ async def test_open_async_session(startup): async def test_new_async_engine_without_async_alchemy_url( monkeypatch, async_sqlalchemy_url ): - from fastapi_sqla.asyncio_support import new_async_engine + from fastapi_sqla.async_session import new_async_engine monkeypatch.delenv("async_sqlalchemy_url") monkeypatch.setenv("sqlalchemy_url", async_sqlalchemy_url) @@ -45,13 +45,13 @@ async def test_new_async_engine_without_async_alchemy_url( @fixture def AsyncSessionMock(): - with patch("fastapi_sqla.asyncio_support._AsyncSession") as AsyncSessionMock: + with patch("fastapi_sqla.async_session._AsyncSession") as AsyncSessionMock: AsyncSessionMock.return_value = AsyncMock() yield AsyncSessionMock async def test_context_manager_rollbacks_on_error(AsyncSessionMock): - from fastapi_sqla.asyncio_support import open_session + from fastapi_sqla.async_session import open_session session = AsyncSessionMock.return_value with raises(Exception) as raise_info: diff --git a/tests/test_aws_aurora_support.py b/tests/test_aws_aurora_support.py index ab34e8a..3be05eb 100644 --- a/tests/test_aws_aurora_support.py +++ b/tests/test_aws_aurora_support.py @@ -24,7 +24,7 @@ def test_sync_disconnects_on_readonly_error(monkeypatch): @mark.require_asyncpg @mark.dont_patch_engines async def test_async_disconnects_on_readonly_error(monkeypatch, async_sqlalchemy_url): - from fastapi_sqla.asyncio_support import _AsyncSession, startup + from fastapi_sqla.async_session import _AsyncSession, startup monkeypatch.setenv("fastapi_sqla_aws_aurora_enabled", "true") monkeypatch.setenv("async_sqlalchemy_url", async_sqlalchemy_url) diff --git a/tests/test_pytest_plugin.py b/tests/test_pytest_plugin.py index 35034ff..2cdc28e 100644 --- a/tests/test_pytest_plugin.py +++ b/tests/test_pytest_plugin.py @@ -88,7 +88,7 @@ def test_sqla_13_all_opened_sessions_are_within_the_same_transaction( async def test_all_opened_async_sessions_are_within_the_same_transaction( async_sqla_connection, async_session, singer_cls ): - from fastapi_sqla.asyncio_support import _AsyncSession + from fastapi_sqla.async_session import _AsyncSession async_session.add(singer_cls(id=1, name="Bob Marley", country="Jamaica")) await async_session.commit() diff --git a/tests/test_setup.py b/tests/test_setup.py index 84a6c85..260bfc0 100644 --- a/tests/test_setup.py +++ b/tests/test_setup.py @@ -8,7 +8,8 @@ def test_setup_with_async_sqlalchemy_url_adds_asyncio_support_startup( monkeypatch, async_sqlalchemy_url ): - from fastapi_sqla import asyncio_support, setup + from fastapi_sqla import async_session + from fastapi_sqla.base import setup monkeypatch.delenv("async_sqlalchemy_url") monkeypatch.setenv("sqlalchemy_url", async_sqlalchemy_url) @@ -16,8 +17,8 @@ def test_setup_with_async_sqlalchemy_url_adds_asyncio_support_startup( app = Mock() setup(app) - app.add_event_handler.assert_called_once_with("startup", asyncio_support.startup) + app.add_event_handler.assert_called_once_with("startup", async_session.startup) app.middleware.assert_called_once_with("http") app.middleware.return_value.assert_called_once_with( - asyncio_support.add_session_to_request + async_session.add_session_to_request ) diff --git a/tests/test_startup.py b/tests/test_startup.py index b7aeb59..b18a043 100644 --- a/tests/test_startup.py +++ b/tests/test_startup.py @@ -91,9 +91,9 @@ async def test_async_startup_fail_on_bad_async_sqlalchemy_url(monkeypatch): ) with raises(Exception): - from fastapi_sqla import asyncio_support + from fastapi_sqla import async_session - await asyncio_support.startup() + await async_session.startup() @mark.require_boto3 @@ -119,7 +119,7 @@ def test_sync_startup_with_aws_rds_iam_enabled( async def test_async_startup_with_aws_rds_iam_enabled( monkeypatch, async_sqlalchemy_url, boto_session, boto_client_mock, db_host, db_user ): - from fastapi_sqla.asyncio_support import startup + from fastapi_sqla.async_session import startup monkeypatch.setenv("fastapi_sqla_aws_rds_iam_enabled", "true") monkeypatch.setenv("async_sqlalchemy_url", async_sqlalchemy_url) From 954cf21692b357ff534c9ba2e925a1f83b8120e7 Mon Sep 17 00:00:00 2001 From: Arthur Loiselle Date: Wed, 1 Nov 2023 13:25:17 -0400 Subject: [PATCH 3/3] async_session -> async_sqla --- fastapi_sqla/__init__.py | 4 ++-- fastapi_sqla/_pytest_plugin.py | 6 +++--- fastapi_sqla/async_pagination.py | 2 +- fastapi_sqla/{async_session.py => async_sqla.py} | 0 fastapi_sqla/base.py | 6 +++--- tests/{test_async_session.py => test_async_sqla.py} | 12 ++++++------ tests/test_aws_aurora_support.py | 2 +- tests/test_pytest_plugin.py | 2 +- tests/test_setup.py | 6 +++--- tests/test_startup.py | 6 +++--- 10 files changed, 23 insertions(+), 23 deletions(-) rename fastapi_sqla/{async_session.py => async_sqla.py} (100%) rename tests/{test_async_session.py => test_async_sqla.py} (78%) diff --git a/fastapi_sqla/__init__.py b/fastapi_sqla/__init__.py index 860aa34..3a1f87a 100644 --- a/fastapi_sqla/__init__.py +++ b/fastapi_sqla/__init__.py @@ -19,8 +19,8 @@ try: from fastapi_sqla.async_pagination import AsyncPaginate, AsyncPagination - from fastapi_sqla.async_session import AsyncSession - from fastapi_sqla.async_session import open_session as open_async_session + from fastapi_sqla.async_sqla import AsyncSession + from fastapi_sqla.async_sqla import open_session as open_async_session __all__ += [ "AsyncPaginate", diff --git a/fastapi_sqla/_pytest_plugin.py b/fastapi_sqla/_pytest_plugin.py index df0c595..0d18edb 100644 --- a/fastapi_sqla/_pytest_plugin.py +++ b/fastapi_sqla/_pytest_plugin.py @@ -164,13 +164,13 @@ async def async_sqla_connection(async_engine, event_loop): @fixture async def patch_new_engine(async_sqlalchemy_url, async_sqla_connection, request): """So that all async DB operations are never written to db for real.""" - from fastapi_sqla.async_session import _AsyncSession + from fastapi_sqla.async_sqla import _AsyncSession if "dont_patch_engines" in request.keywords: yield else: - with patch("fastapi_sqla.async_session.new_engine") as new_engine: + with patch("fastapi_sqla.async_sqla.new_engine") as new_engine: new_engine.return_value = async_sqla_connection _AsyncSession.configure( bind=async_sqla_connection, expire_on_commit=False @@ -187,7 +187,7 @@ async def async_sqla_reflection(sqla_modules, async_sqla_connection): async def async_session( async_sqla_connection, async_sqla_reflection, patch_new_engine ): - from fastapi_sqla.async_session import _AsyncSession + from fastapi_sqla.async_sqla import _AsyncSession session = _AsyncSession(bind=async_sqla_connection) yield session diff --git a/fastapi_sqla/async_pagination.py b/fastapi_sqla/async_pagination.py index be15947..0cac58c 100644 --- a/fastapi_sqla/async_pagination.py +++ b/fastapi_sqla/async_pagination.py @@ -5,7 +5,7 @@ from fastapi import Depends, Query from sqlalchemy.sql import Select, func, select -from fastapi_sqla.async_session import AsyncSession +from fastapi_sqla.async_sqla import AsyncSession from fastapi_sqla.models import Page QueryCountDependency = Callable[..., Awaitable[int]] diff --git a/fastapi_sqla/async_session.py b/fastapi_sqla/async_sqla.py similarity index 100% rename from fastapi_sqla/async_session.py rename to fastapi_sqla/async_sqla.py diff --git a/fastapi_sqla/base.py b/fastapi_sqla/base.py index 41358c7..cbd6642 100644 --- a/fastapi_sqla/base.py +++ b/fastapi_sqla/base.py @@ -6,7 +6,7 @@ from fastapi_sqla import sqla try: - from fastapi_sqla import async_session + from fastapi_sqla import async_sqla has_asyncio_support = True @@ -25,8 +25,8 @@ def setup(app: FastAPI): has_async_config = "async_sqlalchemy_url" in os.environ or is_async_dialect(engine) if has_async_config: assert has_asyncio_support, asyncio_support_err - app.add_event_handler("startup", async_session.startup) - app.middleware("http")(async_session.add_session_to_request) + app.add_event_handler("startup", async_sqla.startup) + app.middleware("http")(async_sqla.add_session_to_request) def is_async_dialect(engine: Engine): diff --git a/tests/test_async_session.py b/tests/test_async_sqla.py similarity index 78% rename from tests/test_async_session.py rename to tests/test_async_sqla.py index 494a4bc..1a3bac2 100644 --- a/tests/test_async_session.py +++ b/tests/test_async_sqla.py @@ -8,14 +8,14 @@ @fixture async def startup(environ): - from fastapi_sqla.async_session import startup + from fastapi_sqla.async_sqla import startup await startup() yield async def test_startup_configure_async_session(startup): - from fastapi_sqla.async_session import _AsyncSession + from fastapi_sqla.async_sqla import _AsyncSession async with _AsyncSession() as session: res = await session.execute(text("SELECT 123")) @@ -24,7 +24,7 @@ async def test_startup_configure_async_session(startup): async def test_open_async_session(startup): - from fastapi_sqla.async_session import open_session + from fastapi_sqla.async_sqla import open_session async with open_session() as session: res = await session.execute(text("select 123")) @@ -35,7 +35,7 @@ async def test_open_async_session(startup): async def test_new_async_engine_without_async_alchemy_url( monkeypatch, async_sqlalchemy_url ): - from fastapi_sqla.async_session import new_async_engine + from fastapi_sqla.async_sqla import new_async_engine monkeypatch.delenv("async_sqlalchemy_url") monkeypatch.setenv("sqlalchemy_url", async_sqlalchemy_url) @@ -45,13 +45,13 @@ async def test_new_async_engine_without_async_alchemy_url( @fixture def AsyncSessionMock(): - with patch("fastapi_sqla.async_session._AsyncSession") as AsyncSessionMock: + with patch("fastapi_sqla.async_sqla._AsyncSession") as AsyncSessionMock: AsyncSessionMock.return_value = AsyncMock() yield AsyncSessionMock async def test_context_manager_rollbacks_on_error(AsyncSessionMock): - from fastapi_sqla.async_session import open_session + from fastapi_sqla.async_sqla import open_session session = AsyncSessionMock.return_value with raises(Exception) as raise_info: diff --git a/tests/test_aws_aurora_support.py b/tests/test_aws_aurora_support.py index 3be05eb..a1821e6 100644 --- a/tests/test_aws_aurora_support.py +++ b/tests/test_aws_aurora_support.py @@ -24,7 +24,7 @@ def test_sync_disconnects_on_readonly_error(monkeypatch): @mark.require_asyncpg @mark.dont_patch_engines async def test_async_disconnects_on_readonly_error(monkeypatch, async_sqlalchemy_url): - from fastapi_sqla.async_session import _AsyncSession, startup + from fastapi_sqla.async_sqla import _AsyncSession, startup monkeypatch.setenv("fastapi_sqla_aws_aurora_enabled", "true") monkeypatch.setenv("async_sqlalchemy_url", async_sqlalchemy_url) diff --git a/tests/test_pytest_plugin.py b/tests/test_pytest_plugin.py index 2cdc28e..371e030 100644 --- a/tests/test_pytest_plugin.py +++ b/tests/test_pytest_plugin.py @@ -88,7 +88,7 @@ def test_sqla_13_all_opened_sessions_are_within_the_same_transaction( async def test_all_opened_async_sessions_are_within_the_same_transaction( async_sqla_connection, async_session, singer_cls ): - from fastapi_sqla.async_session import _AsyncSession + from fastapi_sqla.async_sqla import _AsyncSession async_session.add(singer_cls(id=1, name="Bob Marley", country="Jamaica")) await async_session.commit() diff --git a/tests/test_setup.py b/tests/test_setup.py index 260bfc0..7292b1b 100644 --- a/tests/test_setup.py +++ b/tests/test_setup.py @@ -8,7 +8,7 @@ def test_setup_with_async_sqlalchemy_url_adds_asyncio_support_startup( monkeypatch, async_sqlalchemy_url ): - from fastapi_sqla import async_session + from fastapi_sqla import async_sqla from fastapi_sqla.base import setup monkeypatch.delenv("async_sqlalchemy_url") @@ -17,8 +17,8 @@ def test_setup_with_async_sqlalchemy_url_adds_asyncio_support_startup( app = Mock() setup(app) - app.add_event_handler.assert_called_once_with("startup", async_session.startup) + app.add_event_handler.assert_called_once_with("startup", async_sqla.startup) app.middleware.assert_called_once_with("http") app.middleware.return_value.assert_called_once_with( - async_session.add_session_to_request + async_sqla.add_session_to_request ) diff --git a/tests/test_startup.py b/tests/test_startup.py index b18a043..4334a78 100644 --- a/tests/test_startup.py +++ b/tests/test_startup.py @@ -91,9 +91,9 @@ async def test_async_startup_fail_on_bad_async_sqlalchemy_url(monkeypatch): ) with raises(Exception): - from fastapi_sqla import async_session + from fastapi_sqla import async_sqla - await async_session.startup() + await async_sqla.startup() @mark.require_boto3 @@ -119,7 +119,7 @@ def test_sync_startup_with_aws_rds_iam_enabled( async def test_async_startup_with_aws_rds_iam_enabled( monkeypatch, async_sqlalchemy_url, boto_session, boto_client_mock, db_host, db_user ): - from fastapi_sqla.async_session import startup + from fastapi_sqla.async_sqla import startup monkeypatch.setenv("fastapi_sqla_aws_rds_iam_enabled", "true") monkeypatch.setenv("async_sqlalchemy_url", async_sqlalchemy_url)