Commit

fix: to_schema incorrectly checks for ModelMetaclass instead of `BaseModel` (#198)

* feat: revert BaseModel / ModelMetaclass check for Pydantic

* feat: export `LoadSpec`

* fix: function signature help

Co-authored-by: Janek Nouvertné <provinzkraut@posteo.de>

---------

Co-authored-by: Janek Nouvertné <provinzkraut@posteo.de>
cofin and provinzkraut committed May 23, 2024
1 parent f302cce commit bbda30b
Showing 7 changed files with 107 additions and 94 deletions.
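Note on the core fix: a Pydantic model class is an instance of `ModelMetaclass`, not a subclass of it, so the old `issubclass(schema_type, ModelMetaclass)` test could never match a user schema. A minimal sketch (the `Author` schema is hypothetical, for illustration only):

from pydantic import BaseModel

class Author(BaseModel):  # hypothetical user schema
    name: str

# A Pydantic model class is created by ModelMetaclass (a metaclass relationship) ...
assert isinstance(Author, type(BaseModel))
# ... but it does not inherit from it, so the old issubclass check never matched.
assert not issubclass(Author, type(BaseModel))
# The corrected check matches any Pydantic schema.
assert issubclass(Author, BaseModel)
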
8 changes: 4 additions & 4 deletions advanced_alchemy/base.py
@@ -19,6 +19,7 @@
orm_insert_sentinel,
registry,
)
from sqlalchemy.orm.decl_base import _TableArgsType as TableArgsType # pyright: ignore[reportPrivateUsage]

from advanced_alchemy.types import GUID, UUID_UTILS_INSTALLED, BigIntIdentity, DateTimeUTC, JsonB

@@ -32,7 +33,6 @@
uuid7 = uuid4 # type: ignore[assignment]

if TYPE_CHECKING:
from sqlalchemy.orm.decl_base import _TableArgsType as TableArgsType # pyright: ignore[reportPrivateUsage]
from sqlalchemy.sql import FromClause
from sqlalchemy.sql.schema import (
_NamingSchemaParameter as NamingSchemaParameter, # pyright: ignore[reportPrivateUsage]
@@ -61,6 +61,7 @@
"SQLQuery",
"orm_registry",
"merge_table_arguments",
"TableArgsType",
)


@@ -81,7 +82,6 @@
"""Regular expression for table name"""



def merge_table_arguments(cls: type[DeclarativeBase], table_args: TableArgsType | None = None) -> TableArgsType:
"""Merge Table Arguments.
@@ -91,15 +91,15 @@ def merge_table_arguments(cls: type[DeclarativeBase], table_args: TableArgsType
Args:
cls (DeclarativeBase): This is the model that will get the table args
table_args: additional information to add to tableargs
table_args: additional information to add to table_args
Returns:
tuple | dict: The merged __table_args__ property
"""
args: list[Any] = []
kwargs: dict[str, Any] = {}

mixin_table_args = (getattr(super(base_cls, cls), "__table_args__", None) for base_cls in cls.__bases__)
mixin_table_args = (getattr(super(base_cls, cls), "__table_args__", None) for base_cls in cls.__bases__) # pyright: ignore[reportUnknownParameter,reportUnknownArgumentType,reportArgumentType]

for arg_to_merge in (*mixin_table_args, table_args):
if arg_to_merge:
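As a side note, one way the now runtime-importable `TableArgsType` and `merge_table_arguments` might be used together; the `TodoItem` model below (including its base and comment) is an assumption for illustration, not part of this change:

from sqlalchemy.orm import Mapped, declared_attr, mapped_column

from advanced_alchemy.base import TableArgsType, UUIDBase, merge_table_arguments

class TodoItem(UUIDBase):
    __tablename__ = "todo_item"
    title: Mapped[str] = mapped_column()

    @declared_attr.directive
    @classmethod
    def __table_args__(cls) -> TableArgsType:
        # Merge table args contributed by parent classes with this model's own.
        return merge_table_arguments(cls, table_args={"comment": "Stores todo items"})
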
3 changes: 2 additions & 1 deletion advanced_alchemy/repository/__init__.py
@@ -8,7 +8,7 @@
SQLAlchemySyncRepository,
SQLAlchemySyncSlugRepository,
)
from advanced_alchemy.repository._util import get_instrumented_attr, model_from_dict
from advanced_alchemy.repository._util import LoadSpec, get_instrumented_attr, model_from_dict

__all__ = (
"SQLAlchemyAsyncRepository",
@@ -19,4 +19,5 @@
"SQLAlchemySyncQueryRepository",
"get_instrumented_attr",
"model_from_dict",
"LoadSpec",
)
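A short sketch of the newly re-exported name; the Author/Book models and the eager-load usage are assumptions for illustration only:

from sqlalchemy import ForeignKey
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship, selectinload

from advanced_alchemy.repository import LoadSpec

class Base(DeclarativeBase):
    pass

class Book(Base):
    __tablename__ = "book"
    id: Mapped[int] = mapped_column(primary_key=True)
    author_id: Mapped[int] = mapped_column(ForeignKey("author.id"))

class Author(Base):
    __tablename__ = "author"
    id: Mapped[int] = mapped_column(primary_key=True)
    books: Mapped[list[Book]] = relationship()

# Assumed usage: loader options gathered under the exported alias, e.g. for a
# repository call such as repo.list(load=author_load).
author_load: LoadSpec = [selectinload(Author.books)]
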
43 changes: 23 additions & 20 deletions advanced_alchemy/service/_converters.py
@@ -12,6 +12,7 @@
)
from uuid import UUID

from advanced_alchemy.exceptions import AdvancedAlchemyError
from advanced_alchemy.filters import FilterTypes, LimitOffset
from advanced_alchemy.repository.typing import ModelOrRowMappingT
from advanced_alchemy.service.pagination import OffsetPagination
@@ -34,20 +35,17 @@ def convert(*args: Any, **kwargs: Any) -> Any: # type: ignore[no-redef] # noqa:


try:
from pydantic.main import ModelMetaclass # pyright: ignore[reportAssignmentType]
from pydantic import BaseModel # pyright: ignore[reportAssignmentType]
from pydantic.type_adapter import TypeAdapter # pyright: ignore[reportAssignmentType]
except ImportError: # pragma: nocover

class ModelMetaclass: # type: ignore[no-redef]
class BaseModel: # type: ignore[no-redef]
"""Placeholder Implementation"""

class TypeAdapter: # type: ignore[no-redef]
"""Placeholder Implementation"""


EMPTY_FILTER: list[FilterTypes] = []


def _default_deserializer(
target_type: Any,
value: Any,
@@ -101,10 +99,24 @@ def _find_filter(
def to_schema(
data: ModelOrRowMappingT | Sequence[ModelOrRowMappingT],
total: int | None = None,
filters: Sequence[FilterTypes | ColumnElement[bool]] | Sequence[FilterTypes] = EMPTY_FILTER,
filters: Sequence[FilterTypes | ColumnElement[bool]] | Sequence[FilterTypes] | None = None,
schema_type: type[ModelDTOT] | None = None,
) -> ModelOrRowMappingT | OffsetPagination[ModelOrRowMappingT] | ModelDTOT | OffsetPagination[ModelDTOT]:
if schema_type is not None and issubclass(schema_type, Struct):
if filters is None:
filters = []
if schema_type is None:
if not issubclass(type(data), Sequence):
return cast("ModelOrRowMappingT", data)
limit_offset = _find_filter(LimitOffset, filters=filters)
total = total or len(data) # type: ignore[arg-type]
limit_offset = limit_offset if limit_offset is not None else LimitOffset(limit=len(data), offset=0) # type: ignore[arg-type]
return OffsetPagination[ModelOrRowMappingT](
items=cast("List[ModelOrRowMappingT]", data),
limit=limit_offset.limit,
offset=limit_offset.offset,
total=total,
)
if issubclass(schema_type, Struct):
if not isinstance(data, Sequence):
return convert( # type: ignore # noqa: PGH003
obj=data,
@@ -137,9 +149,9 @@ def to_schema(
total=total,
)

if schema_type is not None and issubclass(schema_type, ModelMetaclass):
if issubclass(schema_type, BaseModel):
if not isinstance(data, Sequence):
return TypeAdapter(schema_type).validate_python(data, from_attributes=True) # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType,reportAttributeAccessIssue,reportCallIssue]
return TypeAdapter(schema_type).validate_python(data, from_attributes=True) # type: ignore[return-value] # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType,reportAttributeAccessIssue,reportCallIssue]
limit_offset = _find_filter(LimitOffset, filters=filters)
total = total if total else len(data)
limit_offset = limit_offset if limit_offset is not None else LimitOffset(limit=len(data), offset=0)
@@ -149,14 +161,5 @@
offset=limit_offset.offset,
total=total,
)
if not issubclass(type(data), Sequence):
return cast("ModelOrRowMappingT", data)
limit_offset = _find_filter(LimitOffset, filters=filters)
total = total or len(data) # type: ignore[arg-type]
limit_offset = limit_offset if limit_offset is not None else LimitOffset(limit=len(data), offset=0) # type: ignore[arg-type]
return OffsetPagination[ModelOrRowMappingT](
items=cast("List[ModelOrRowMappingT]", data),
limit=limit_offset.limit,
offset=limit_offset.offset,
total=total,
)
msg = "`schema_type` should be a valid Pydantic or Msgspec schema"
raise AdvancedAlchemyError(msg)
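A runnable sketch of the reworked dispatch in `to_schema` (no `schema_type` gives passthrough or pagination, a msgspec `Struct` goes through `convert`, a Pydantic `BaseModel` through `TypeAdapter`, and anything else now raises `AdvancedAlchemyError`); the `AuthorRow`/`AuthorDTO` types are assumptions for illustration:

from dataclasses import dataclass

import msgspec

from advanced_alchemy.exceptions import AdvancedAlchemyError
from advanced_alchemy.filters import LimitOffset
from advanced_alchemy.service._converters import to_schema

@dataclass
class AuthorRow:  # stand-in for an ORM instance, for illustration only
    id: int
    name: str

class AuthorDTO(msgspec.Struct):
    id: int
    name: str

rows = [AuthorRow(1, "Ada"), AuthorRow(2, "Grace")]

# schema_type omitted: the sequence is wrapped in OffsetPagination unchanged.
page = to_schema(rows, total=2, filters=[LimitOffset(limit=10, offset=0)])

# schema_type given: each row is converted (msgspec Struct path here).
dto_page = to_schema(rows, total=2, schema_type=AuthorDTO)

# A schema_type that is neither a Struct nor a BaseModel now raises.
try:
    to_schema(rows, schema_type=dict)  # type: ignore[arg-type]
except AdvancedAlchemyError:
    pass
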
19 changes: 11 additions & 8 deletions advanced_alchemy/service/_util.py
@@ -8,7 +8,7 @@

from typing import TYPE_CHECKING, overload

from advanced_alchemy.service._converters import EMPTY_FILTER, to_schema
from advanced_alchemy.service._converters import to_schema

if TYPE_CHECKING:
from collections.abc import Sequence
@@ -31,40 +31,43 @@ def to_schema(
self,
data: ModelOrRowMappingT,
total: int | None = None,
filters: Sequence[FilterTypes | ColumnElement[bool]] | Sequence[FilterTypes] = EMPTY_FILTER,
filters: Sequence[FilterTypes | ColumnElement[bool]] | Sequence[FilterTypes] = ...,
) -> ModelOrRowMappingT: ...

@overload
def to_schema(
self,
data: Sequence[ModelOrRowMappingT],
total: int | None = None,
filters: Sequence[FilterTypes | ColumnElement[bool]] | Sequence[FilterTypes] = EMPTY_FILTER,
filters: Sequence[FilterTypes | ColumnElement[bool]] | Sequence[FilterTypes] | None = None,
) -> OffsetPagination[ModelOrRowMappingT]: ...

@overload
def to_schema(
self,
data: ModelProtocol | RowMapping,
total: int | None = None,
filters: Sequence[FilterTypes | ColumnElement[bool]] | Sequence[FilterTypes] = EMPTY_FILTER,
schema_type: type[ModelDTOT] = ...,
filters: Sequence[FilterTypes | ColumnElement[bool]] | Sequence[FilterTypes] | None = None,
*,
schema_type: type[ModelDTOT],
) -> ModelDTOT: ...

@overload
def to_schema(
self,
data: Sequence[ModelOrRowMappingT],
total: int | None = None,
filters: Sequence[FilterTypes | ColumnElement[bool]] | Sequence[FilterTypes] = EMPTY_FILTER,
schema_type: type[ModelDTOT] = ...,
filters: Sequence[FilterTypes | ColumnElement[bool]] | Sequence[FilterTypes] | None = None,
*,
schema_type: type[ModelDTOT],
) -> OffsetPagination[ModelDTOT]: ...

def to_schema(
self,
data: ModelOrRowMappingT | Sequence[ModelOrRowMappingT],
total: int | None = None,
filters: Sequence[FilterTypes | ColumnElement[bool]] | Sequence[FilterTypes] = EMPTY_FILTER,
filters: Sequence[FilterTypes | ColumnElement[bool]] | Sequence[FilterTypes] | None = None,
*,
schema_type: type[ModelDTOT] | None = None,
) -> ModelOrRowMappingT | OffsetPagination[ModelOrRowMappingT] | ModelDTOT | OffsetPagination[ModelDTOT]:
"""Convert the object to a response schema. When `schema_type` is None, the model is returned with no conversion.
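For reference, the updated overloads take `schema_type` as a keyword-only argument and default `filters` to `None` rather than the module-level `EMPTY_FILTER` list. A stub that only mirrors the new parameter shape (not the library implementation):

from __future__ import annotations

from typing import Any, Sequence

def to_schema_signature(
    data: Any,
    total: int | None = None,
    filters: Sequence[Any] | None = None,
    *,
    schema_type: type | None = None,
) -> Any:
    # Placeholder body; the real method converts and paginates as shown above.
    return data

to_schema_signature("row", schema_type=dict)    # ok: schema_type passed by keyword
# to_schema_signature("row", None, None, dict)  # would raise TypeError: keyword-only
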
