Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sqlfactory): support nested type in pg.array types and others #530

Merged
merged 12 commits into from
May 9, 2024
Merged
8 changes: 7 additions & 1 deletion polyfactory/factories/sqlalchemy_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def get_sqlalchemy_types(cls) -> dict[Any, Callable[[], Any]]:
postgresql.NUMRANGE: lambda: tuple(sorted([cls.__faker__.pyint(), cls.__faker__.pyint()])),
postgresql.TSRANGE: lambda: (cls.__faker__.past_datetime(), datetime.now()), # noqa: DTZ005
postgresql.TSTZRANGE: lambda: (cls.__faker__.past_datetime(), datetime.now()), # noqa: DTZ005
postgresql.HSTORE: lambda: cls.__faker__.pydict(),
wangxin688 marked this conversation as resolved.
Show resolved Hide resolved
}

@classmethod
Expand Down Expand Up @@ -124,8 +125,13 @@ def should_column_be_set(cls, column: Any) -> bool:
@classmethod
def get_type_from_column(cls, column: Column) -> type:
wangxin688 marked this conversation as resolved.
Show resolved Hide resolved
column_type = type(column.type)
if column_type in cls.get_sqlalchemy_types():
if column_type in (sqla_types := cls.get_sqlalchemy_types()):
wangxin688 marked this conversation as resolved.
Show resolved Hide resolved
annotation = column_type
elif issubclass(column_type, postgresql.ARRAY):
if type(column.type.item_type) in sqla_types: # type: ignore[attr-defined]
annotation = List[type(column.type.item_type)] # type: ignore[attr-defined,misc]
else:
annotation = List[column.type.item_type.python_type] # type: ignore[assignment,name-defined]
elif issubclass(column_type, types.ARRAY):
annotation = List[column.type.item_type.python_type] # type: ignore[assignment,name-defined]
else:
Expand Down
35 changes: 34 additions & 1 deletion tests/sqlalchemy_factory/test_sqlalchemy_factory_v2.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from enum import Enum
from typing import Any, List
from ipaddress import ip_network
from typing import Any, Dict, List
from uuid import UUID

import pytest
from sqlalchemy import ForeignKey, __version__, orm, types
from sqlalchemy.dialects.postgresql import ARRAY, CIDR, HSTORE, INET
from sqlalchemy.ext.mutable import MutableList

from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory

Expand Down Expand Up @@ -64,6 +68,35 @@ class ModelFactory(SQLAlchemyFactory[Model]):
assert isinstance(instance.str_array_type[0], str)


def test_pg_dialect_types() -> None:
class Base(orm.DeclarativeBase): ...

class PgModel(Base):
__tablename__ = "pgmodel"
id: orm.Mapped[int] = orm.mapped_column(primary_key=True)
uuid_type: orm.Mapped[UUID] = orm.mapped_column(type_=types.UUID)
nested_array_inet: orm.Mapped[List[str]] = orm.mapped_column(type_=ARRAY(INET, dimensions=1))
nested_array_cidr: orm.Mapped[List[str]] = orm.mapped_column(type_=ARRAY(CIDR, dimensions=1))
hstore_type: orm.Mapped[Dict] = orm.mapped_column(type_=HSTORE)
mut_nested_arry_inet: orm.Mapped[List[str]] = orm.mapped_column(
type_=MutableList.as_mutable(ARRAY(INET, dimensions=1))
)

class ModelFactory(SQLAlchemyFactory[PgModel]):
__model__ = PgModel

instance = ModelFactory.build()

assert isinstance(instance.nested_array_inet[0], str)
assert ip_network(instance.nested_array_inet[0])
assert isinstance(instance.nested_array_cidr[0], str)
assert ip_network(instance.nested_array_cidr[0])
assert isinstance(instance.hstore_type, dict)
assert isinstance(instance.uuid_type, UUID)
assert isinstance(instance.mut_nested_arry_inet[0], str)
assert ip_network(instance.mut_nested_arry_inet[0])


@pytest.mark.parametrize(
"type_",
tuple(SQLAlchemyFactory.get_sqlalchemy_types().keys()),
Expand Down
2 changes: 1 addition & 1 deletion tests/test_recursive_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_recursive_pydantic_models(factory_use_construct: bool) -> None:
factory = ModelFactory.create_factory(PydanticNode)

result = factory.build(factory_use_construct)
assert result.child is _Sentinel, "Default is not used"
assert result.child is _Sentinel, "Default is not used" # type: ignore[comparison-overlap]
assert isinstance(result.union_child, int)
assert result.optional_child is None
assert result.list_child == []
Expand Down
Loading