Skip to content

Commit

Permalink
feat: split the main modules into multiple modules - DIA-61984 (#106)
Browse files Browse the repository at this point in the history
  • Loading branch information
arththebird committed Nov 1, 2023
1 parent 512c9d8 commit e9510ac
Show file tree
Hide file tree
Showing 17 changed files with 332 additions and 318 deletions.
8 changes: 3 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,7 @@ exception occurred:

```python
from fastapi import APIRouter, Depends
from fastapi_sqla import Session
from fastapi_sqla.asyncio_support import AsyncSession
from fastapi_sqla import Session, AsyncSession

router = APIRouter()

Expand All @@ -183,8 +182,7 @@ occurred:

```python
from fastapi import APIRouter, BackgroundTasks
from fastapi_sqla import open_session
from fastapi_sqla import asyncio_support
from fastapi_sqla import open_session, open_async_session

router = APIRouter()

Expand All @@ -201,7 +199,7 @@ def run_bg():


async def run_async_bg():
async with asyncio_support.open_session() as session:
async with open_async_session() as session:
await session.scalar("SELECT now()")
```

Expand Down
50 changes: 10 additions & 40 deletions fastapi_sqla/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,7 @@
import os

from fastapi import FastAPI

from fastapi_sqla import sqla
from fastapi_sqla.models import Collection, Item, Page
from fastapi_sqla.sqla import (
Base,
Paginate,
PaginateSignature,
Pagination,
Session,
open_session,
)
from fastapi_sqla.base import setup
from fastapi_sqla.models import Base, Collection, Item, Page
from fastapi_sqla.pagination import Paginate, PaginateSignature, Pagination
from fastapi_sqla.sqla import Session, open_session

__all__ = [
"Base",
Expand All @@ -23,17 +13,14 @@
"Pagination",
"Session",
"open_session",
"setup",
]


try:
from fastapi_sqla import asyncio_support
from fastapi_sqla.asyncio_support import ( # noqa
AsyncPaginate,
AsyncPagination,
AsyncSession,
)
from fastapi_sqla.asyncio_support import open_session as open_async_session # noqa
from fastapi_sqla.async_pagination import AsyncPaginate, AsyncPagination
from fastapi_sqla.async_sqla import AsyncSession
from fastapi_sqla.async_sqla import open_session as open_async_session

__all__ += [
"AsyncPaginate",
Expand All @@ -43,22 +30,5 @@
]
has_asyncio_support = True

except ImportError as err: # pragma: no cover
has_asyncio_support = False
asyncio_support_err = str(err)


def setup(app: FastAPI):
engine = sqla.new_engine()

if not sqla.is_async_dialect(engine):
app.add_event_handler("startup", sqla.startup)
app.middleware("http")(sqla.add_session_to_request)

has_async_config = "async_sqlalchemy_url" in os.environ or sqla.is_async_dialect(
engine
)
if has_async_config:
assert has_asyncio_support, asyncio_support_err
app.add_event_handler("startup", asyncio_support.startup)
app.middleware("http")(asyncio_support.add_session_to_request)
except ImportError: # pragma: no cover
pass
6 changes: 3 additions & 3 deletions fastapi_sqla/_pytest_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,13 +164,13 @@ async def async_sqla_connection(async_engine, event_loop):
@fixture
async def patch_new_engine(async_sqlalchemy_url, async_sqla_connection, request):
"""So that all async DB operations are never written to db for real."""
from fastapi_sqla.asyncio_support import _AsyncSession
from fastapi_sqla.async_sqla import _AsyncSession

if "dont_patch_engines" in request.keywords:
yield

else:
with patch("fastapi_sqla.asyncio_support.new_engine") as new_engine:
with patch("fastapi_sqla.async_sqla.new_engine") as new_engine:
new_engine.return_value = async_sqla_connection
_AsyncSession.configure(
bind=async_sqla_connection, expire_on_commit=False
Expand All @@ -187,7 +187,7 @@ async def async_sqla_reflection(sqla_modules, async_sqla_connection):
async def async_session(
async_sqla_connection, async_sqla_reflection, patch_new_engine
):
from fastapi_sqla.asyncio_support import _AsyncSession
from fastapi_sqla.async_sqla import _AsyncSession

session = _AsyncSession(bind=async_sqla_connection)
yield session
Expand Down
87 changes: 87 additions & 0 deletions fastapi_sqla/async_pagination.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import math
from collections.abc import Awaitable, Callable
from typing import Iterator, Optional, Union, cast

from fastapi import Depends, Query
from sqlalchemy.sql import Select, func, select

from fastapi_sqla.async_sqla import AsyncSession
from fastapi_sqla.models import Page

QueryCountDependency = Callable[..., Awaitable[int]]
PaginateSignature = Callable[[Select, Optional[bool]], Awaitable[Page]]
DefaultDependency = Callable[[AsyncSession, int, int], PaginateSignature]
WithQueryCountDependency = Callable[[AsyncSession, int, int, int], PaginateSignature]
PaginateDependency = Union[DefaultDependency, WithQueryCountDependency]


async def default_query_count(session: AsyncSession, query: Select) -> int:
result = await session.execute(select(func.count()).select_from(query.subquery()))
return cast(int, result.scalar())


async def paginate_query(
query: Select,
session: AsyncSession,
total_items: int,
offset: int,
limit: int,
*,
scalars: bool = True,
) -> Page:
total_pages = math.ceil(total_items / limit)
page_number = offset / limit + 1
query = query.offset(offset).limit(limit)
result = await session.execute(query)
data = iter(
cast(Iterator, result.unique().scalars() if scalars else result.mappings())
)
return Page(
data=data,
meta={
"offset": offset,
"total_items": total_items,
"total_pages": total_pages,
"page_number": page_number,
},
)


def AsyncPagination(
min_page_size: int = 10,
max_page_size: int = 100,
query_count: Union[QueryCountDependency, None] = None,
) -> PaginateDependency:
def default_dependency(
session: AsyncSession = Depends(),
offset: int = Query(0, ge=0),
limit: int = Query(min_page_size, ge=1, le=max_page_size),
) -> PaginateSignature:
async def paginate(query: Select, scalars=True) -> Page:
total_items = await default_query_count(session, query)
return await paginate_query(
query, session, total_items, offset, limit, scalars=scalars
)

return paginate

def with_query_count_dependency(
session: AsyncSession = Depends(),
offset: int = Query(0, ge=0),
limit: int = Query(min_page_size, ge=1, le=max_page_size),
total_items: int = Depends(query_count),
):
async def paginate(query: Select, scalars=True) -> Page:
return await paginate_query(
query, session, total_items, offset, limit, scalars=scalars
)

return paginate

if query_count:
return with_query_count_dependency
else:
return default_dependency


AsyncPaginate: PaginateDependency = AsyncPagination()
91 changes: 5 additions & 86 deletions fastapi_sqla/asyncio_support.py → fastapi_sqla/async_sqla.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,18 @@
import math
import os
from collections.abc import AsyncGenerator, Awaitable, Callable
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import Iterator, Optional, Union, cast
from typing import cast

import structlog
from fastapi import Depends, Query, Request
from fastapi import Request
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncEngine
from sqlalchemy.ext.asyncio import AsyncSession as SqlaAsyncSession
from sqlalchemy.orm.session import sessionmaker
from sqlalchemy.sql import Select, func, select

from fastapi_sqla import aws_aurora_support, aws_rds_iam_support
from fastapi_sqla.models import Page
from fastapi_sqla.sqla import Base, new_engine
from fastapi_sqla.models import Base
from fastapi_sqla.sqla import new_engine

logger = structlog.get_logger(__name__)
_ASYNC_SESSION_KEY = "fastapi_sqla_async_session"
Expand Down Expand Up @@ -127,82 +125,3 @@ async def get_users(session: fastapi_sqla.AsyncSession = Depends()):
await session.rollback()

return response


QueryCountDependency = Callable[..., Awaitable[int]]
PaginateSignature = Callable[[Select, Optional[bool]], Awaitable[Page]]
DefaultDependency = Callable[[AsyncSession, int, int], PaginateSignature]
WithQueryCountDependency = Callable[[AsyncSession, int, int, int], PaginateSignature]
PaginateDependency = Union[DefaultDependency, WithQueryCountDependency]


async def default_query_count(session: AsyncSession, query: Select) -> int:
result = await session.execute(select(func.count()).select_from(query.subquery()))
return cast(int, result.scalar())


async def paginate_query(
query: Select,
session: AsyncSession,
total_items: int,
offset: int,
limit: int,
*,
scalars: bool = True,
) -> Page:
total_pages = math.ceil(total_items / limit)
page_number = offset / limit + 1
query = query.offset(offset).limit(limit)
result = await session.execute(query)
data = iter(
cast(Iterator, result.unique().scalars() if scalars else result.mappings())
)
return Page(
data=data,
meta={
"offset": offset,
"total_items": total_items,
"total_pages": total_pages,
"page_number": page_number,
},
)


def AsyncPagination(
min_page_size: int = 10,
max_page_size: int = 100,
query_count: Union[QueryCountDependency, None] = None,
) -> PaginateDependency:
def default_dependency(
session: AsyncSession = Depends(),
offset: int = Query(0, ge=0),
limit: int = Query(min_page_size, ge=1, le=max_page_size),
) -> PaginateSignature:
async def paginate(query: Select, scalars=True) -> Page:
total_items = await default_query_count(session, query)
return await paginate_query(
query, session, total_items, offset, limit, scalars=scalars
)

return paginate

def with_query_count_dependency(
session: AsyncSession = Depends(),
offset: int = Query(0, ge=0),
limit: int = Query(min_page_size, ge=1, le=max_page_size),
total_items: int = Depends(query_count),
):
async def paginate(query: Select, scalars=True) -> Page:
return await paginate_query(
query, session, total_items, offset, limit, scalars=scalars
)

return paginate

if query_count:
return with_query_count_dependency
else:
return default_dependency


AsyncPaginate: PaginateDependency = AsyncPagination()
33 changes: 33 additions & 0 deletions fastapi_sqla/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import os

from fastapi import FastAPI
from sqlalchemy.engine import Engine

from fastapi_sqla import sqla

try:
from fastapi_sqla import async_sqla

has_asyncio_support = True

except ImportError as err: # pragma: no cover
has_asyncio_support = False
asyncio_support_err = str(err)


def setup(app: FastAPI):
engine = sqla.new_engine()

if not is_async_dialect(engine):
app.add_event_handler("startup", sqla.startup)
app.middleware("http")(sqla.add_session_to_request)

has_async_config = "async_sqlalchemy_url" in os.environ or is_async_dialect(engine)
if has_async_config:
assert has_asyncio_support, asyncio_support_err
app.add_event_handler("startup", async_sqla.startup)
app.middleware("http")(async_sqla.add_session_to_request)


def is_async_dialect(engine: Engine):
return engine.dialect.is_async if hasattr(engine.dialect, "is_async") else False
Loading

0 comments on commit e9510ac

Please sign in to comment.