Skip to content

Commit

Permalink
fix: favour SA mapped type over impl type (#513)
Browse files Browse the repository at this point in the history
  • Loading branch information
adhtruong committed Mar 29, 2024
1 parent 719495e commit bb04b4e
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 6 deletions.
9 changes: 4 additions & 5 deletions polyfactory/factories/sqlalchemy_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,10 @@ def get_type_from_column(cls, column: Column) -> type:
elif issubclass(column_type, types.ARRAY):
annotation = List[column.type.item_type.python_type] # type: ignore[assignment,name-defined]
else:
annotation = (
column.type.impl.python_type # pyright: ignore[reportGeneralTypeIssues]
if hasattr(column.type, "impl")
else column.type.python_type
)
try:
annotation = column.type.python_type
except NotImplementedError:
annotation = column.type.impl.python_type # type: ignore[attr-defined]

if column.nullable:
annotation = Union[annotation, None] # type: ignore[assignment]
Expand Down
37 changes: 36 additions & 1 deletion tests/sqlalchemy_factory/test_sqlalchemy_factory_common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Type
from typing import Any, Callable, Type, Union
from uuid import UUID

import pytest
from sqlalchemy import Column, ForeignKey, Integer, String, create_engine, inspect, orm, types
Expand Down Expand Up @@ -347,3 +348,37 @@ class ModelFactory(SQLAlchemyFactory[ModelWithAlias]):

result = ModelFactory.build()
assert isinstance(result.name, str)


@pytest.mark.parametrize("python_type_", (UUID, None))
def test_sqlalchemy_custom_type_from_type_decorator(python_type_: Union[type, None]) -> None:
class CustomType(types.TypeDecorator):
impl = types.CHAR(32)
cache_ok = True

if python_type_ is not None:

@property
def python_type(self) -> type:
return python_type_

class Base(metaclass=DeclarativeMeta):
__abstract__ = True
__allow_unmapped__ = True

registry = _registry
metadata = _registry.metadata

class Model(Base):
__tablename__ = f"model_with_custom_types_{python_type_}"

id: Any = Column(Integer(), primary_key=True)
custom_type: Any = Column(type_=CustomType(), nullable=False)

class ModelFactory(SQLAlchemyFactory[Model]):
__model__ = Model

instance = ModelFactory.build()

expected_type = python_type_ if python_type_ is not None else CustomType.impl.python_type
assert isinstance(instance.custom_type, expected_type)

0 comments on commit bb04b4e

Please sign in to comment.