Skip to content

Commit

Permalink
fix: favour SA mapped type over impl type
Browse files Browse the repository at this point in the history
  • Loading branch information
adhtruong committed Mar 23, 2024
1 parent 719495e commit 82d64d4
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 5 deletions.
9 changes: 4 additions & 5 deletions polyfactory/factories/sqlalchemy_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,10 @@ def get_type_from_column(cls, column: Column) -> type:
elif issubclass(column_type, types.ARRAY):
annotation = List[column.type.item_type.python_type] # type: ignore[assignment,name-defined]
else:
annotation = (
column.type.impl.python_type # pyright: ignore[reportGeneralTypeIssues]
if hasattr(column.type, "impl")
else column.type.python_type
)
try:
annotation = column.type.python_type
except NotImplementedError:
annotation = column.type.impl.python_type # type: ignore[attr-defined]

if column.nullable:
annotation = Union[annotation, None] # type: ignore[assignment]
Expand Down
43 changes: 43 additions & 0 deletions tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Type
from uuid import UUID

import pytest
from sqlalchemy import Column, ForeignKey, Integer, String, create_engine, inspect, orm, types
Expand Down Expand Up @@ -347,3 +348,45 @@ class ModelFactory(SQLAlchemyFactory[ModelWithAlias]):

result = ModelFactory.build()
assert isinstance(result.name, str)


@pytest.mark.parametrize("python_type_", (UUID, None))
@pytest.mark.parametrize(
"impl_",
(
types.Uuid(),
types.Uuid(native_uuid=False),
types.CHAR(32),
),
)
def test_sqlalchemy_custom_type_from_type_decorator(impl_: types.TypeEngine, python_type_: type | None) -> None:
class CustomType(types.TypeDecorator):
impl = impl_
cache_ok = True

if python_type_ is not None:

@property
def python_type(self) -> type:
return python_type_

class Base(orm.DeclarativeBase):
type_annotation_map = {
UUID: CustomType,
}

class Model(Base):
__tablename__ = "model_with_custom_types"

id: orm.Mapped[int] = orm.mapped_column(primary_key=True)
custom_type: orm.Mapped[UUID] = orm.mapped_column(type_=CustomType(), nullable=False)
custom_type_from_annotation_map: orm.Mapped[UUID]

class ModelFactory(SQLAlchemyFactory[Model]):
__model__ = Model

instance = ModelFactory.build()

expected_type = python_type_ if python_type_ is not None else CustomType.impl.python_type
assert isinstance(instance.custom_type, expected_type)
assert isinstance(instance.custom_type_from_annotation_map, expected_type)

0 comments on commit 82d64d4

Please sign in to comment.