Skip to content

Commit

Permalink
feat: adds a fastapi and standalone example (#33)
Browse files Browse the repository at this point in the history
  • Loading branch information
cofin committed Sep 19, 2023
1 parent 43e188a commit ae0cb75
Show file tree
Hide file tree
Showing 6 changed files with 566 additions and 31 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ repos:
rev: v2.2.5
hooks:
- id: codespell
exclude: "pdm.lock"
exclude: "pdm.lock|examples/us_state_lookup.json"
- repo: https://github.com/psf/black
rev: 23.9.1
hooks:
Expand Down
228 changes: 228 additions & 0 deletions examples/fastapi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
from __future__ import annotations

from datetime import date # noqa: TCH003
from typing import Annotated
from uuid import UUID # noqa: TCH003

from fastapi import APIRouter, Depends, FastAPI, Request
from pydantic import BaseModel as _BaseModel
from pydantic import TypeAdapter
from sqlalchemy import ForeignKey, select
from sqlalchemy.ext.asyncio import AsyncSession # noqa: TCH002
from sqlalchemy.orm import Mapped, mapped_column, relationship, selectinload

from advanced_alchemy.base import UUIDAuditBase, UUIDBase
from advanced_alchemy.config import AsyncSessionConfig, SQLAlchemyAsyncConfig
from advanced_alchemy.extensions.starlette import StarletteAdvancedAlchemy
from advanced_alchemy.filters import LimitOffset
from advanced_alchemy.repository import SQLAlchemyAsyncRepository

# #######################
# Models
# #######################


class BaseModel(_BaseModel):
"""Extend Pydantic's BaseModel to enable ORM mode"""

model_config = {"from_attributes": True}


# the SQLAlchemy base includes a declarative model for you to use in your models.
# The `Base` class includes a `UUID` based primary key (`id`)
class AuthorModel(UUIDBase):
# we can optionally provide the table name instead of auto-generating it
__tablename__ = "author" # type: ignore[assignment]
name: Mapped[str]
dob: Mapped[date | None]
books: Mapped[list[BookModel]] = relationship(back_populates="author", lazy="noload")


# The `AuditBase` class includes the same UUID` based primary key (`id`) and 2
# additional columns: `created` and `updated`. `created` is a timestamp of when the
# record created, and `updated` is the last time the record was modified.
class BookModel(UUIDAuditBase):
__tablename__ = "book" # type: ignore[assignment]
title: Mapped[str]
author_id: Mapped[UUID] = mapped_column(ForeignKey("author.id"))
author: Mapped[AuthorModel] = relationship(lazy="joined", innerjoin=True, viewonly=True)


# we will explicitly define the schema instead of using DTO objects for clarity.


class Author(BaseModel):
id: UUID | None
name: str
dob: date | None = None


class AuthorCreate(BaseModel):
name: str
dob: date | None = None


class AuthorUpdate(BaseModel):
name: str | None = None
dob: date | None = None


class AuthorPagination(BaseModel):
"""Container for data returned using limit/offset pagination."""

items: list[Author]
"""List of data being sent as part of the response."""
limit: int
"""Maximal number of items to send."""
offset: int
"""Offset from the beginning of the query.
Identical to an index.
"""
total: int
"""Total number of items."""


class AuthorRepository(SQLAlchemyAsyncRepository[AuthorModel]):
"""Author repository."""

model_type = AuthorModel


# #######################
# Dependencies
# #######################


async def provide_db_session(request: Request) -> AsyncSession:
"""Provide a DB session."""
return alchemy.get_session(request)


async def provide_authors_repo(db_session: Annotated[AsyncSession, Depends(provide_db_session)]) -> AuthorRepository:
"""This provides the default Authors repository."""
return AuthorRepository(session=db_session)


# we can optionally override the default `select` used for the repository to pass in
# specific SQL options such as join details
async def provide_author_details_repo(
db_session: Annotated[AsyncSession, Depends(provide_db_session)],
) -> AuthorRepository:
"""This provides a simple example demonstrating how to override the join options for the repository."""
return AuthorRepository(
statement=select(AuthorModel).options(selectinload(AuthorModel.books)),
session=db_session,
)


def provide_limit_offset_pagination(
current_page: int = 1,
page_size: int = 10,
) -> LimitOffset:
"""Add offset/limit pagination.
Return type consumed by `Repository.apply_limit_offset_pagination()`.
Parameters
----------
current_page : int
LIMIT to apply to select.
page_size : int
OFFSET to apply to select.
"""
return LimitOffset(page_size, page_size * (current_page - 1))


async def on_startup() -> None:
"""Initializes the database."""
async with sqlalchemy_config.get_engine().begin() as conn:
await conn.run_sync(UUIDBase.metadata.create_all)


# #######################
# Application
# #######################

session_config = AsyncSessionConfig(expire_on_commit=False)
sqlalchemy_config = SQLAlchemyAsyncConfig(
connection_string="sqlite+aiosqlite:///test.sqlite",
session_config=session_config,
) # Create 'db_session' dependency.
app = FastAPI(on_startup=[on_startup])
alchemy = StarletteAdvancedAlchemy(config=sqlalchemy_config, app=app)

# #######################
# Routes
# #######################
author_router = APIRouter()


@author_router.get(path="/authors", response_model=AuthorPagination)
async def list_authors(
authors_repo: Annotated[AuthorRepository, Depends(provide_authors_repo)],
limit_offset: Annotated[LimitOffset, Depends(provide_limit_offset_pagination)],
) -> AuthorPagination:
"""List authors."""
results, total = await authors_repo.list_and_count(limit_offset)
type_adapter = TypeAdapter(list[Author])
return AuthorPagination(
items=type_adapter.validate_python(results),
total=total,
limit=limit_offset.limit,
offset=limit_offset.offset,
)


@author_router.post(path="/authors", response_model=Author)
async def create_author(
authors_repo: Annotated[AuthorRepository, Depends(provide_authors_repo)],
data: AuthorCreate,
) -> Author:
"""Create a new author."""
obj = await authors_repo.add(
AuthorModel(**data.model_dump(exclude_unset=True, exclude_none=True)),
)
await authors_repo.session.commit()
return Author.model_validate(obj)


# we override the authors_repo to use the version that joins the Books in
@author_router.get(path="/authors/{author_id}", response_model=Author)
async def get_author(
authors_repo: Annotated[AuthorRepository, Depends(provide_authors_repo)],
author_id: UUID,
) -> Author:
"""Get an existing author."""
obj = await authors_repo.get(author_id)
return Author.model_validate(obj)


@author_router.patch(
path="/authors/{author_id}",
response_model=Author,
)
async def update_author(
authors_repo: Annotated[AuthorRepository, Depends(provide_authors_repo)],
data: AuthorUpdate,
author_id: UUID,
) -> Author:
"""Update an author."""
raw_obj = data.model_dump(exclude_unset=True, exclude_none=True)
raw_obj.update({"id": author_id})
obj = await authors_repo.update(AuthorModel(**raw_obj))
await authors_repo.session.commit()
return Author.model_validate(obj)


@author_router.delete(path="/authors/{author_id}")
async def delete_author(
authors_repo: Annotated[AuthorRepository, Depends(provide_authors_repo)],
author_id: UUID,
) -> None:
"""Delete a author from the system."""
_ = await authors_repo.delete(author_id)
await authors_repo.session.commit()


app.include_router(author_router)
88 changes: 88 additions & 0 deletions examples/standalone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import json
from pathlib import Path
from typing import Any

from rich import get_console
from sqlalchemy import create_engine
from sqlalchemy.orm import Mapped, Session, sessionmaker

from advanced_alchemy.base import UUIDBase
from advanced_alchemy.filters import LimitOffset
from advanced_alchemy.repository import SQLAlchemySyncRepository

here = Path(__file__).parent
console = get_console()


class USState(UUIDBase):
# you can optionally override the generated table name by manually setting it.
__tablename__ = "us_state_lookup" # type: ignore[assignment]
abbreviation: Mapped[str]
name: Mapped[str]


class USStateRepository(SQLAlchemySyncRepository[USState]):
"""US State repository."""

model_type = USState


engine = create_engine(
"duckdb:///:memory:",
future=True,
)
session_factory: sessionmaker[Session] = sessionmaker(engine, expire_on_commit=False)


def open_fixture(fixtures_path: Path, fixture_name: str) -> Any:
"""Loads JSON file with the specified fixture name
Args:
fixtures_path (Path): The path to look for fixtures
fixture_name (str): The fixture name to load.
Raises:
FileNotFoundError: Fixtures not found.
Returns:
Any: The parsed JSON data
"""
fixture = Path(fixtures_path / f"{fixture_name}.json")
if fixture.exists():
with fixture.open(mode="r", encoding="utf-8") as f:
f_data = f.read()
return json.loads(f_data)
msg = f"Could not find the {fixture_name} fixture"
raise FileNotFoundError(msg)


def run_script() -> None:
"""Load data from a fixture."""

# Initializes the database.
with engine.begin() as conn:
USState.metadata.create_all(conn)

with session_factory() as db_session:
# 1) Load the JSON data into the US States table.
repo = USStateRepository(session=db_session)
fixture = open_fixture(here, USStateRepository.model_type.__tablename__) # type: ignore[has-type]
objs = repo.add_many([USStateRepository.model_type(**raw_obj) for raw_obj in fixture])
db_session.commit()
console.print(f"Created {len(objs)} new objects.")

# 2) Select paginated data and total row count.
created_objs, total_objs = repo.list_and_count(LimitOffset(limit=10, offset=0))
console.print(f"Selected {len(created_objs)} records out of a total of {total_objs}.")

# 3) Let's remove the batch of records selected.
deleted_objs = repo.delete_many([new_obj.id for new_obj in created_objs])
console.print(f"Removed {len(deleted_objs)} records out of a total of {total_objs}.")

# 4) Let's count the remaining rows
remaining_count = repo.count()
console.print(f"Found {remaining_count} remaining records after delete.")


if __name__ == "__main__":
run_script()

0 comments on commit ae0cb75

Please sign in to comment.