Skip to content

Commit

Permalink
feat: Add SQLA persistence handles, update class attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
adhtruong committed Sep 17, 2023
1 parent ae1e78b commit 0d89d72
Show file tree
Hide file tree
Showing 6 changed files with 253 additions and 113 deletions.
2 changes: 0 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ repos:
mongomock_motor,
msgspec,
odmantic,
sqlalchemy>=2,
pydantic>=2,
pytest,
sphinx,
Expand All @@ -67,7 +66,6 @@ repos:
mongomock_motor,
msgspec,
odmantic,
sqlalchemy>=2,
pydantic>=2,
pytest,
sphinx,
Expand Down
31 changes: 28 additions & 3 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

35 changes: 4 additions & 31 deletions polyfactory/factories/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,7 @@
from os.path import realpath
from pathlib import Path
from random import Random
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Collection,
Generic,
Mapping,
Sequence,
Type,
TypeVar,
cast,
)
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Collection, Generic, Mapping, Sequence, Type, TypeVar, cast
from uuid import NAMESPACE_DNS, UUID, uuid1, uuid3, uuid5

from faker import Faker
Expand All @@ -46,21 +34,10 @@
MIN_COLLECTION_LENGTH,
RANDOMIZE_COLLECTION_LENGTH,
)
from polyfactory.exceptions import (
ConfigurationException,
MissingBuildKwargException,
ParameterException,
)
from polyfactory.exceptions import ConfigurationException, MissingBuildKwargException, ParameterException
from polyfactory.fields import Fixture, Ignore, PostGenerated, Require, Use
from polyfactory.utils.helpers import unwrap_annotation, unwrap_args, unwrap_optional
from polyfactory.utils.predicates import (
get_type_origin,
is_any,
is_literal,
is_optional,
is_safe_subclass,
is_union,
)
from polyfactory.utils.predicates import get_type_origin, is_any, is_literal, is_optional, is_safe_subclass, is_union
from polyfactory.value_generators.complex_types import handle_collection_type
from polyfactory.value_generators.constrained_collections import (
handle_constrained_collection,
Expand All @@ -76,11 +53,7 @@
from polyfactory.value_generators.constrained_strings import handle_constrained_string_or_bytes
from polyfactory.value_generators.constrained_url import handle_constrained_url
from polyfactory.value_generators.constrained_uuid import handle_constrained_uuid
from polyfactory.value_generators.primitives import (
create_random_boolean,
create_random_bytes,
create_random_string,
)
from polyfactory.value_generators.primitives import create_random_boolean, create_random_bytes, create_random_string

if TYPE_CHECKING:
from typing_extensions import TypeGuard
Expand Down
69 changes: 63 additions & 6 deletions polyfactory/factories/sqlalchemy_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from polyfactory.exceptions import MissingDependencyException
from polyfactory.factories.base import BaseFactory
from polyfactory.field_meta import FieldMeta
from polyfactory.persistence import AsyncPersistenceProtocol, SyncPersistenceProtocol

try:
from sqlalchemy import Column, inspect, orm, types
Expand All @@ -16,23 +17,59 @@
raise MissingDependencyException("sqlalchemy is not installed") from e

if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from typing_extensions import TypeGuard


T = TypeVar("T", bound=orm.DeclarativeBase)


class SQLASyncPersistence(SyncPersistenceProtocol[T]):
def __init__(self, session: Session) -> None:
self.session = session

def save(self, data: T) -> T:
self.session.add(data)
self.session.commit()
return data

def save_many(self, data: list[T]) -> list[T]:
self.session.add_all(data)
self.session.commit()
return data


class SQLAASyncPersistence(AsyncPersistenceProtocol[T]):
def __init__(self, session: AsyncSession) -> None:
self.session = session

async def save(self, data: T) -> T:
self.session.add(data)
await self.session.commit()
return data

async def save_many(self, data: list[T]) -> list[T]:
self.session.add_all(data)
await self.session.commit()
return data


class SQLAlchemyFactory(Generic[T], BaseFactory[T]):
"""Base factory for SQLAlchemy models."""

__is_base_factory__ = True
__resolve_primary_key__: ClassVar[bool] = True

__set_primary_key__: ClassVar[bool] = True
"""Configuration to consider primary key columns as a field or not."""
__resolve_foreign_keys__: ClassVar[bool] = True
__set_foreign_keys__: ClassVar[bool] = True
"""Configuration to consider columns with foreign keys as a field or not."""
__resolve_relationships__: ClassVar[bool] = False
__set_relationships__: ClassVar[bool] = False
"""Configuration to consider relationships property as a model field or not."""

__session__: ClassVar[Session | Callable[[], Session] | None] = None
__async_session__: ClassVar[AsyncSession | Callable[[], AsyncSession] | None] = None

@classmethod
def get_sqlalchemy_types(cls) -> dict[Any, Callable[[], Any]]:
"""Get mapping of types where column type."""
Expand Down Expand Up @@ -66,10 +103,10 @@ def is_supported_type(cls, value: Any) -> TypeGuard[type[T]]:

@classmethod
def should_column_be_set(cls, column: Column) -> bool:
if not cls.__resolve_primary_key__ and column.primary_key:
if not cls.__set_primary_key__ and column.primary_key:
return False

return bool(cls.__resolve_foreign_keys__ or not column.foreign_keys)
return bool(cls.__set_foreign_keys__ or not column.foreign_keys)

@classmethod
def get_type_from_column(cls, column: Column) -> type:
Expand Down Expand Up @@ -103,7 +140,7 @@ def get_model_fields(cls) -> list[FieldMeta]:
for name, column in table.columns.items()
if cls.should_column_be_set(column)
)
if cls.__resolve_relationships__:
if cls.__set_relationships__:
for name, relationship in table.relationships.items():
class_ = relationship.entity.class_
annotation = class_ if not relationship.uselist else List[class_] # type: ignore[valid-type]
Expand All @@ -119,3 +156,23 @@ def get_model_fields(cls) -> list[FieldMeta]:
)

return fields_meta

@classmethod
def _get_sync_persistence(cls) -> SyncPersistenceProtocol[T]:
if cls.__session__ is not None:
return (
SQLASyncPersistence(cls.__session__())
if callable(cls.__session__)
else SQLASyncPersistence(cls.__session__)
)
return super()._get_sync_persistence()

@classmethod
def _get_async_persistence(cls) -> AsyncPersistenceProtocol[T]:
if cls.__async_session__ is not None:
return (
SQLAASyncPersistence(cls.__async_session__())
if callable(cls.__async_session__)
else SQLAASyncPersistence(cls.__async_session__)
)
return super()._get_async_persistence()
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ mongomock-motor = "*"
pytest = "*"
pytest-asyncio = "*"
pytest-cov = "*"
aiosqlite = "^0.19.0"

[tool.poetry.group.docs.dependencies]
auto-pytabs = "*"
Expand Down Expand Up @@ -105,7 +106,6 @@ include = ["polyfactory", "tests", "examples"]
omit = ["*/tests/*"]

[tool.pytest.ini_options]
addopts = "tests docs/examples"
asyncio_mode = "auto"
filterwarnings = [
"ignore:.*pkg_resources.declare_namespace\\('sphinxcontrib'\\).*:DeprecationWarning",
Expand Down
Loading

0 comments on commit 0d89d72

Please sign in to comment.