Skip to content

Commit

Permalink
fix: Make litestar example work again, and implement tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sherbang committed Mar 27, 2024
1 parent c3dba02 commit cba4f5f
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 82 deletions.
159 changes: 77 additions & 82 deletions examples/litestar.py
Original file line number Diff line number Diff line change
@@ -1,79 +1,66 @@
from __future__ import annotations

from datetime import date, datetime
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, List, Union
from uuid import UUID

from litestar import Litestar
from litestar.controller import Controller
from litestar.di import Provide
from litestar.exceptions import NotFoundException as LiteStarNotFoundException
from litestar.handlers.http_handlers.decorators import delete, get, patch, post
from litestar.pagination import OffsetPagination
from litestar.params import Parameter
from pydantic import BaseModel as _BaseModel
from pydantic import TypeAdapter
from sqlalchemy import ForeignKey, select
from sqlalchemy.orm import Mapped, mapped_column, relationship, selectinload
from typing_extensions import Annotated

from advanced_alchemy.base import UUIDAuditBase, UUIDBase
from advanced_alchemy.config import AsyncSessionConfig
from advanced_alchemy.exceptions import NotFoundError as AdvancedAlchemyNotFoundError
from advanced_alchemy.extensions.litestar.dto import SQLAlchemyDTO, SQLAlchemyDTOConfig
from advanced_alchemy.extensions.litestar.plugins import SQLAlchemyAsyncConfig, SQLAlchemyPlugin
from advanced_alchemy.filters import LimitOffset
from advanced_alchemy.repository import SQLAlchemyAsyncRepository

if TYPE_CHECKING:
from litestar.dto import DTOData
from sqlalchemy.ext.asyncio import AsyncSession


class BaseModel(_BaseModel):
"""Extend Pydantic's BaseModel to enable ORM mode"""

model_config = {"from_attributes": True}


# the SQLAlchemy base includes a declarative model for you to use in your models.
# The `Base` class includes a `UUID` based primary key (`id`)
class AuthorModel(UUIDBase):
class Author(UUIDBase):
# we can optionally provide the table name instead of auto-generating it
__tablename__ = "author" # type: ignore[assignment]
name: Mapped[str]
dob: Mapped[date | None]
books: Mapped[list[BookModel]] = relationship(back_populates="author", lazy="noload")
dob: Mapped[Union[date, None]] # noqa: UP007 - needed for SQLAlchemy on older python versions
books: Mapped[List[Book]] = relationship(back_populates="author", lazy="noload") # noqa: UP006


# The `AuditBase` class includes the same UUID` based primary key (`id`) and 2
# additional columns: `created` and `updated`. `created` is a timestamp of when the
# record created, and `updated` is the last time the record was modified.
class BookModel(UUIDAuditBase):
class Book(UUIDAuditBase):
__tablename__ = "book" # type: ignore[assignment]
title: Mapped[str]
author_id: Mapped[UUID] = mapped_column(ForeignKey("author.id"))
author: Mapped[AuthorModel] = relationship(lazy="joined", innerjoin=True, viewonly=True)


# we will explicitly define the schema instead of using DTO objects for clarity.

author: Mapped[Author] = relationship(lazy="joined", innerjoin=True, viewonly=True)

class Author(BaseModel):
id: UUID | None
name: str
dob: date | None = None

# DTO objects let us filter certain fields out of our request/response data
# without defining separate models
class AuthorDTO(SQLAlchemyDTO[Author]):
config = SQLAlchemyDTOConfig(exclude={"books"})

class AuthorCreate(BaseModel):
name: str
dob: date | None = None

class AuthorCreateUpdateDTO(SQLAlchemyDTO[Author]):
config = SQLAlchemyDTOConfig(exclude={"id", "books"})

class AuthorUpdate(BaseModel):
name: str | None = None
dob: date | None = None


class AuthorRepository(SQLAlchemyAsyncRepository[AuthorModel]):
class AuthorRepository(SQLAlchemyAsyncRepository[Author]):
"""Author repository."""

model_type = AuthorModel
model_type = Author


async def provide_authors_repo(db_session: AsyncSession) -> AuthorRepository:
Expand All @@ -86,7 +73,7 @@ async def provide_authors_repo(db_session: AsyncSession) -> AuthorRepository:
async def provide_author_details_repo(db_session: AsyncSession) -> AuthorRepository:
"""This provides a simple example demonstrating how to override the join options for the repository."""
return AuthorRepository(
statement=select(AuthorModel).options(selectinload(AuthorModel.books)),
statement=select(Author).options(selectinload(Author.books)),
session=db_session,
)

Expand Down Expand Up @@ -119,68 +106,74 @@ class AuthorController(Controller):

dependencies = {"authors_repo": Provide(provide_authors_repo)}

@get(path="/authors")
@get(path="/authors", return_dto=AuthorDTO)
async def list_authors(
self,
authors_repo: AuthorRepository,
limit_offset: LimitOffset,
) -> OffsetPagination[Author]:
"""List authors."""
results, total = await authors_repo.list_and_count(limit_offset)
type_adapter = TypeAdapter(list[Author])
return OffsetPagination[Author](
items=type_adapter.validate_python(results),
items=results,
total=total,
limit=limit_offset.limit,
offset=limit_offset.offset,
)

@post(path="/authors")
async def create_author(
self,
authors_repo: AuthorRepository,
data: AuthorCreate,
) -> Author:
@post(path="/authors", dto=AuthorCreateUpdateDTO)
async def create_author(self, authors_repo: AuthorRepository, data: DTOData[Author]) -> Author:
"""Create a new author."""
obj = await authors_repo.add(
AuthorModel(**data.model_dump(exclude_unset=True, exclude_none=True)),
)

# Turn the DTO object into an Author instance.
author = data.create_instance()

obj = await authors_repo.add(author)
await authors_repo.session.commit()
return Author.model_validate(obj)
return obj

# we override the authors_repo to use the version that joins the Books in
@get(path="/authors/{author_id:uuid}", dependencies={"authors_repo": Provide(provide_author_details_repo)})
async def get_author(
self,
authors_repo: AuthorRepository,
author_id: UUID = Parameter( # noqa: B008
title="Author ID",
description="The author to retrieve.",
),
author_id: Annotated[
UUID,
Parameter(
title="Author ID",
description="The author to retrieve.",
),
],
) -> Author:
"""Get an existing author."""
obj = await authors_repo.get(author_id)
return Author.model_validate(obj)
try:
return await authors_repo.get(author_id)
except AdvancedAlchemyNotFoundError as e:
msg = f"Author with id {author_id} not found."
raise LiteStarNotFoundException(msg) from e

@patch(
path="/authors/{author_id:uuid}",
dependencies={"authors_repo": Provide(provide_author_details_repo)},
dto=AuthorCreateUpdateDTO,
)
async def update_author(
self,
authors_repo: AuthorRepository,
data: AuthorUpdate,
author_id: UUID = Parameter( # noqa: B008
title="Author ID",
description="The author to update.",
),
data: DTOData[Author],
author_id: Annotated[
UUID,
Parameter(
title="Author ID",
description="The author to update.",
),
],
) -> Author:
"""Update an author."""
raw_obj = data.model_dump(exclude_unset=True, exclude_none=True)
raw_obj.update({"id": author_id})
obj = await authors_repo.update(AuthorModel(**raw_obj))
author = data.create_instance(id=author_id)
obj = await authors_repo.update(author)
await authors_repo.session.commit()
return Author.model_validate(obj)
return obj

@delete(path="/authors/{author_id:uuid}")
async def delete_author(
Expand All @@ -196,24 +189,26 @@ async def delete_author(
await authors_repo.session.commit()


session_config = AsyncSessionConfig(expire_on_commit=False)
sqlalchemy_config = SQLAlchemyAsyncConfig(
connection_string="sqlite+aiosqlite:///test.sqlite",
session_config=session_config,
) # Create 'db_session' dependency.
sqlalchemy_plugin = SQLAlchemyPlugin(config=sqlalchemy_config)


async def on_startup() -> None:
"""Initializes the database."""
async with sqlalchemy_config.get_engine().begin() as conn:
await conn.run_sync(UUIDBase.metadata.create_all)


app = Litestar(
route_handlers=[AuthorController],
on_startup=[on_startup],
plugins=[sqlalchemy_plugin],
dependencies={"limit_offset": Provide(provide_limit_offset_pagination, sync_to_thread=False)},
signature_namespace={"date": date, "datetime": datetime, "UUID": UUID},
)
def init_app(*, sqlalchemy_config: SQLAlchemyAsyncConfig | None = None) -> Litestar:
if not sqlalchemy_config:
# expire_on_commit=False prevents the sqlalchemy models from being invalidated on commit.
session_config = AsyncSessionConfig(expire_on_commit=False)
sqlalchemy_config = SQLAlchemyAsyncConfig(
connection_string="sqlite+aiosqlite:///test.sqlite",
session_config=session_config,
) # Create 'db_session' dependency.

sqlalchemy_plugin = SQLAlchemyPlugin(config=sqlalchemy_config)

async def on_startup() -> None:
"""Initializes the database."""
async with sqlalchemy_config.get_engine().begin() as conn:
await conn.run_sync(UUIDBase.metadata.create_all)

return Litestar(
route_handlers=[AuthorController],
on_startup=[on_startup],
plugins=[sqlalchemy_plugin],
dependencies={"limit_offset": Provide(provide_limit_offset_pagination, sync_to_thread=False)},
signature_namespace={"date": date, "datetime": datetime, "UUID": UUID},
)
Empty file added tests/examples/__init__.py
Empty file.
26 changes: 26 additions & 0 deletions tests/examples/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import pytest

if TYPE_CHECKING:
from pytest import MonkeyPatch


@pytest.fixture(autouse=True)
def _patch_bases(monkeypatch: MonkeyPatch) -> None:
"""
Ensure metadata isn't shared with other tests.
Within tests, imports that include SQLAlchemy models must be put into the
test functions so that this monkeypatch effects them. The joys of testing
with global variables.
"""
from sqlalchemy import orm
from sqlalchemy.schema import MetaData

class NewDeclarativeBase(orm.DeclarativeBase):
metadata = MetaData()

monkeypatch.setattr(orm, "DeclarativeBase", NewDeclarativeBase)
89 changes: 89 additions & 0 deletions tests/examples/test_litestar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from __future__ import annotations

from collections.abc import AsyncIterator
from typing import TYPE_CHECKING

import pytest
from litestar.testing import AsyncTestClient

from advanced_alchemy.base import UUIDBase
from advanced_alchemy.config import AsyncSessionConfig
from advanced_alchemy.extensions.litestar.plugins import SQLAlchemyAsyncConfig

if TYPE_CHECKING:
from litestar import Litestar


@pytest.fixture()
async def test_client() -> AsyncIterator[AsyncTestClient[Litestar]]:
# see _patch_bases in conftest.py
from examples.litestar import init_app

# Use an in-memory database for testing and create the tables.
engine = SQLAlchemyAsyncConfig.create_engine_callable("sqlite+aiosqlite:///:memory:")
async with engine.begin() as conn:
await conn.run_sync(UUIDBase.metadata.create_all)

sqlalchemy_config = SQLAlchemyAsyncConfig(
# Use the same session instance for all requests so the database doesn't disappear
engine_instance=engine,
session_config=AsyncSessionConfig(expire_on_commit=False),
)

app = init_app(sqlalchemy_config=sqlalchemy_config)
app.debug = True

async with AsyncTestClient(app=app) as client:
yield client


async def test_create_list(test_client: AsyncTestClient[Litestar]) -> None:
# see _patch_bases in conftest.py
from examples.litestar import Author

author = Author(name="foo")

response = await test_client.post(
"/authors",
json=author.to_dict(),
)
assert response.status_code == 201, response.text
assert response.json()["name"] == author.name

response = await test_client.get("/authors")
assert response.status_code == 200, response.text
assert response.json()["items"][0]["name"] == author.name


async def test_create_get_update_delete(test_client: AsyncTestClient[Litestar]) -> None:
# see _patch_bases in conftest.py
from examples.litestar import Author

author = Author(name="foo")

response = await test_client.post(
"/authors",
json=author.to_dict(),
)
assert response.status_code == 201, response.text
assert response.json()["name"] == author.name
author_id = response.json()["id"]

response = await test_client.get(f"/authors/{author_id}")
assert response.status_code == 200, response.text
assert response.json()["name"] == author.name
assert response.json()["id"] == author_id

response = await test_client.patch(
f"/authors/{author_id}",
json={"name": "bar"},
)
assert response.status_code == 200, response.text
assert response.json()["name"] == "bar"
assert response.json()["id"] == author_id

response = await test_client.delete(f"/authors/{author_id}")
assert response.status_code == 204, response.text

response = await test_client.get(f"/authors/{author_id}")
assert response.status_code == 404, response.text

0 comments on commit cba4f5f

Please sign in to comment.