Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sqlfactory): support nested type in pg.array types and others #530

Merged
merged 12 commits into from
May 9, 2024
Merged
11 changes: 10 additions & 1 deletion polyfactory/factories/sqlalchemy_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ def get_sqlalchemy_types(cls) -> dict[Any, Callable[[], Any]]:
postgresql.NUMRANGE: lambda: tuple(sorted([cls.__faker__.pyint(), cls.__faker__.pyint()])),
postgresql.TSRANGE: lambda: (cls.__faker__.past_datetime(), datetime.now()), # noqa: DTZ005
postgresql.TSTZRANGE: lambda: (cls.__faker__.past_datetime(), datetime.now()), # noqa: DTZ005
postgresql.HSTORE: lambda: cls.__faker__.pydict(),
wangxin688 marked this conversation as resolved.
Show resolved Hide resolved
# `types.JSON` is compatible for sqlachemy extend dialects. Such as `pg.JSON` and `JSONB`
types.JSON: lambda: cls.__faker__.pydict(),
}

@classmethod
Expand Down Expand Up @@ -124,8 +127,14 @@ def should_column_be_set(cls, column: Any) -> bool:
@classmethod
def get_type_from_column(cls, column: Column) -> type:
wangxin688 marked this conversation as resolved.
Show resolved Hide resolved
column_type = type(column.type)
if column_type in cls.get_sqlalchemy_types():
sqla_types = cls.get_sqlalchemy_types()
if column_type in sqla_types:
annotation = column_type
elif issubclass(column_type, postgresql.ARRAY):
if type(column.type.item_type) in sqla_types: # type: ignore[attr-defined]
annotation = List[type(column.type.item_type)] # type: ignore[attr-defined,misc]
else:
annotation = List[column.type.item_type.python_type] # type: ignore[assignment,name-defined]
elif issubclass(column_type, types.ARRAY):
annotation = List[column.type.item_type.python_type] # type: ignore[assignment,name-defined]
else:
Expand Down
68 changes: 66 additions & 2 deletions tests/sqlalchemy_factory/test_sqlalchemy_factory_v2.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
from enum import Enum
from typing import Any, List
from ipaddress import ip_network
from typing import Any, Dict, List
from uuid import UUID

import pytest
from sqlalchemy import ForeignKey, __version__, orm, types
from sqlalchemy import ForeignKey, Text, __version__, orm, types
from sqlalchemy.dialects.mssql import JSON as MSSQL_JSON
from sqlalchemy.dialects.mysql import JSON as MYSQL_JSON
from sqlalchemy.dialects.postgresql import ARRAY, CIDR, HSTORE, INET, JSON, JSONB
from sqlalchemy.dialects.sqlite import JSON as SQLITE_JSON
from sqlalchemy.ext.mutable import MutableDict, MutableList

from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory

Expand Down Expand Up @@ -64,6 +71,63 @@ class ModelFactory(SQLAlchemyFactory[Model]):
assert isinstance(instance.str_array_type[0], str)


def test_pg_dialect_types() -> None:
class Base(orm.DeclarativeBase): ...

class SqlaModel(Base):
__tablename__ = "sql_models"
id: orm.Mapped[int] = orm.mapped_column(primary_key=True)
uuid_type: orm.Mapped[UUID] = orm.mapped_column(type_=types.UUID)
nested_array_inet: orm.Mapped[List[str]] = orm.mapped_column(type_=ARRAY(INET, dimensions=1))
nested_array_cidr: orm.Mapped[List[str]] = orm.mapped_column(type_=ARRAY(CIDR, dimensions=1))
hstore_type: orm.Mapped[Dict] = orm.mapped_column(type_=HSTORE)
mut_nested_arry_inet: orm.Mapped[List[str]] = orm.mapped_column(
type_=MutableList.as_mutable(ARRAY(INET, dimensions=1))
)
pg_json_type: orm.Mapped[Dict] = orm.mapped_column(type_=JSON)
pg_jsonb_type: orm.Mapped[Dict] = orm.mapped_column(type_=JSONB)
common_json_type: orm.Mapped[Dict] = orm.mapped_column(type_=types.JSON)
mysql_json: orm.Mapped[Dict] = orm.mapped_column(type_=MYSQL_JSON)
sqlite_json: orm.Mapped[Dict] = orm.mapped_column(type_=SQLITE_JSON)
mssql_json: orm.Mapped[Dict] = orm.mapped_column(type_=MSSQL_JSON)

multible_pg_json_type: orm.Mapped[Dict] = orm.mapped_column(
type_=MutableDict.as_mutable(JSON(astext_type=Text())) # type: ignore[no-untyped-call]
)
multible_pg_jsonb_type: orm.Mapped[Dict] = orm.mapped_column(
type_=MutableDict.as_mutable(JSONB(astext_type=Text())) # type: ignore[no-untyped-call]
)
multible_common_json_type: orm.Mapped[Dict] = orm.mapped_column(type_=MutableDict.as_mutable(types.JSON()))
multible_mysql_json: orm.Mapped[Dict] = orm.mapped_column(type_=MutableDict.as_mutable(MYSQL_JSON()))
multible_sqlite_json: orm.Mapped[Dict] = orm.mapped_column(type_=MutableDict.as_mutable(SQLITE_JSON()))
multible_mssql_json: orm.Mapped[Dict] = orm.mapped_column(type_=MutableDict.as_mutable(MSSQL_JSON()))

class ModelFactory(SQLAlchemyFactory[SqlaModel]):
__model__ = SqlaModel

instance = ModelFactory.build()
assert isinstance(instance.nested_array_inet[0], str)
assert ip_network(instance.nested_array_inet[0])
assert isinstance(instance.nested_array_cidr[0], str)
assert ip_network(instance.nested_array_cidr[0])
assert isinstance(instance.hstore_type, dict)
assert isinstance(instance.uuid_type, UUID)
assert isinstance(instance.mut_nested_arry_inet[0], str)
assert ip_network(instance.mut_nested_arry_inet[0])
assert isinstance(instance.pg_json_type, dict)
assert isinstance(instance.pg_jsonb_type, dict)
assert isinstance(instance.common_json_type, dict)
assert isinstance(instance.mysql_json, dict)
assert isinstance(instance.sqlite_json, dict)
assert isinstance(instance.mssql_json, dict)
assert isinstance(instance.multible_pg_json_type, dict)
assert isinstance(instance.multible_pg_jsonb_type, dict)
assert isinstance(instance.multible_common_json_type, dict)
assert isinstance(instance.multible_mysql_json, dict)
assert isinstance(instance.multible_sqlite_json, dict)
assert isinstance(instance.multible_mssql_json, dict)


@pytest.mark.parametrize(
"type_",
tuple(SQLAlchemyFactory.get_sqlalchemy_types().keys()),
Expand Down
Loading