Skip to content

Commit

Permalink
fix: sqlalchemy dto for models non Column fields (#75)
Browse files Browse the repository at this point in the history
  • Loading branch information
abdulhaq-e committed Oct 25, 2023
1 parent 090725a commit c17c83e
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 7 deletions.
12 changes: 7 additions & 5 deletions advanced_alchemy/extensions/litestar/dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
RelationshipDirection,
RelationshipProperty,
)
from sqlalchemy.sql.expression import ColumnClause, Label

from advanced_alchemy.exceptions import ImproperConfigurationError

Expand All @@ -38,7 +39,7 @@

T = TypeVar("T", bound="DeclarativeBase | Collection[DeclarativeBase]")

ElementType: TypeAlias = "Column | RelationshipProperty | CompositeProperty"
ElementType: TypeAlias = "Column | RelationshipProperty | CompositeProperty | ColumnClause | Label"
SQLA_NS = {**vars(orm), **vars(sql)}


Expand Down Expand Up @@ -102,8 +103,8 @@ def _(

elem: ElementType
if isinstance(orm_descriptor.property, ColumnProperty):
if not isinstance(orm_descriptor.property.expression, Column):
msg = f"Expected 'Column', got: '{orm_descriptor.property.expression}'"
if not isinstance(orm_descriptor.property.expression, (Column, ColumnClause, Label)):
msg = f"Expected 'Column', got: '{orm_descriptor.property.expression}, {type(orm_descriptor.property.expression)}'"
raise NotImplementedError(msg)
elem = orm_descriptor.property.expression
elif isinstance(orm_descriptor.property, (RelationshipProperty, CompositeProperty)):
Expand All @@ -123,6 +124,7 @@ def _(
except KeyError:
field_definition = parse_type_from_element(elem)

dto_field = elem.info.get(DTO_FIELD_META_KEY, DTOField()) if hasattr(elem, "info") else DTOField()
return [
DTOFieldDefinition.from_field_definition(
field_definition=replace(
Expand All @@ -131,7 +133,7 @@ def _(
default=default,
),
default_factory=default_factory,
dto_field=elem.info.get(DTO_FIELD_META_KEY, DTOField()),
dto_field=dto_field,
model_name=model_name,
),
]
Expand Down Expand Up @@ -348,7 +350,7 @@ def parse_type_from_element(elem: ElementType) -> FieldDefinition:
if isinstance(elem, CompositeProperty):
return FieldDefinition.from_annotation(elem.composite_class)

msg = f"Unable to parse type from element '{elem}'. Consider adding a type hint." # type: ignore[unreachable]
msg = f"Unable to parse type from element '{elem}'. Consider adding a type hint."
raise ImproperConfigurationError(
msg,
)
Expand Down
38 changes: 36 additions & 2 deletions tests/unit/test_extensions/test_litestar/test_dto_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from dataclasses import dataclass
from types import ModuleType
from typing import Any, Callable, Dict, List, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple

import pytest
from litestar import get, post
Expand All @@ -12,7 +12,7 @@
from litestar.dto.field import DTO_FIELD_META_KEY
from litestar.dto.types import RenameStrategy
from litestar.testing import create_test_client
from sqlalchemy import Column, ForeignKey, Integer, String, Table
from sqlalchemy import Column, ForeignKey, Integer, String, Table, func, select
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import (
DeclarativeBase,
Expand Down Expand Up @@ -56,6 +56,9 @@ class Book(Base):
bar: Mapped[str] = mapped_column(default="Hello") # pyright: ignore
SPAM: Mapped[str] = mapped_column(default="Bye") # pyright: ignore
spam_bar: Mapped[str] = mapped_column(default="Goodbye") # pyright: ignore
number_of_reviews: Mapped[Optional[int]] = column_property( # noqa: UP007
select(func.count(BookReview.id)).where(BookReview.book_id == id).scalar_subquery(),
)


@dataclass
Expand All @@ -70,6 +73,7 @@ class BookAuthorTestData:
book_spam_bar: str = "GoodBye"
book_review_id: str = "23432"
book_review: str = "Excellent!"
number_of_reviews: int | None = None


@pytest.fixture
Expand All @@ -94,6 +98,7 @@ def _generate(rename_strategy: RenameStrategy, test_data: BookAuthorTestData) ->
_rename_field(name="review", strategy=rename_strategy): test_data.book_review,
},
],
_rename_field(name="number_of_reviews", strategy=rename_strategy): test_data.number_of_reviews,
}
book = Book(
id=test_data.book_id,
Expand Down Expand Up @@ -147,6 +152,35 @@ def get_handler() -> Book:
assert response_callback.json() == json_data


class ConcreteBase(Base):
pass


func_result_query = select(func.count(1)).scalar_subquery()
model_with_func_query = select(ConcreteBase, func_result_query.label("func_result")).subquery()


class ModelWithFunc(Base):
__table__ = model_with_func_query
func_result: Mapped[Optional[int]] = column_property(model_with_func_query.c.func_result) # noqa: UP007


def test_model_using_func() -> None:
instance = ModelWithFunc(id="hi")
config = SQLAlchemyDTOConfig()
dto = SQLAlchemyDTO[Annotated[ModelWithFunc, config]]

@get(dto=dto, signature_namespace={"ModelWithFunc": ModelWithFunc})
def get_handler() -> ModelWithFunc:
return instance

with create_test_client(
route_handlers=[get_handler],
) as client:
response_callback = client.get("/")
assert response_callback


def test_dto_with_association_proxy(create_module: Callable[[str], ModuleType]) -> None:
module = create_module(
"""
Expand Down

0 comments on commit c17c83e

Please sign in to comment.