Support for relationship building models in SQLAlchemyFactory #467

Closed
0x587 opened this issue Dec 23, 2023 · 10 comments · Fixed by #468
Labels
enhancement New feature or request

Comments

@0x587

0x587 commented Dec 23, 2023

Summary

Suppose I have a chain of cascading models A, B, C. When I call the factory for A with __set_relationships__ enabled, the factory builds A with its B instances, but the B instances have no C instances.
For example:
A
|-B1
|  |- expected C, but none was built
|-B2
|  |- expected C, but none was built

I read the source code and tried to modify it. The cause is that the __set_relationships__ parameter is not set when sub-factories are constructed recursively. When I tried to set the parameter manually, it triggered a recursion error, so this feature may require a larger change to the structure of the code.

I am trying to use this library to generate fake data for a database. If this feature were implemented, it would unblock my work.

Basic Example

Here's a sample test

class Author(Base):
    __tablename__ = "authors"

    id: Any = Column(Integer(), primary_key=True)
    books: Any = orm.relationship(
        "Book",
        uselist=True,
        back_populates="author",
    )


class Book(Base):
    __tablename__ = "books"

    id: Any = Column(Integer(), primary_key=True)
    author_id: Any = Column(
        Integer(),
        ForeignKey(Author.id),
        nullable=False,
    )
    author: Any = orm.relationship(
        Author,
        uselist=False,
        back_populates="books",
    )
    comments: Any = orm.relationship(
        "Comment",
        uselist=True,
        back_populates="book",
    )


class Comment(Base):
    __tablename__ = "comments"

    id: Any = Column(Integer(), primary_key=True)
    book_id: Any = Column(
        Integer(),
        ForeignKey(Book.id),
        nullable=False,
    )
    book: Any = orm.relationship(
        Book,
        uselist=False,
        back_populates="comments",
    )

def test_recursive_relationship_resolution() -> None:
    class AuthorFactory(SQLAlchemyFactory[Author]):
        __model__ = Author
        __set_relationships__ = True

    result = AuthorFactory.build()
    assert isinstance(result.books, list)
    assert isinstance(result.books[0], Book)
    assert isinstance(result.books[0].comments, list)
    assert isinstance(result.books[0].comments[0], Comment)

Drawbacks and Impact

I tried adding code at
https://github.com/litestar-org/polyfactory/blob/8dc8e1a4594a75ad9a16e1b6f5041b6044fc4f51/polyfactory/factories/base.py#L650C1-L650C80

like this:

if BaseFactory.is_batch_factory_type(annotation=unwrapped_annotation):
    factory = cls._get_or_create_factory(model=field_meta.type_args[0])
    if hasattr(factory, '__set_relationships__'):
        setattr(factory, '__set_relationships__', getattr(cls, '__set_relationships__'))
    ......

This only solves recursive construction for one-to-many or many-to-many relationships, and it is inelegant. Extending it to one-to-one or many-to-one relationships raises recursion errors, for example:

if BaseFactory.is_factory_type(annotation=unwrapped_annotation):
    factory = cls._get_or_create_factory(model=unwrapped_annotation)
    if hasattr(factory, '__set_relationships__'):
        setattr(factory, '__set_relationships__', getattr(cls, '__set_relationships__'))

    return factory.build(
        **(field_build_parameters if isinstance(field_build_parameters, Mapping) else {}),
    )

Unresolved questions

No response


@0x587 0x587 added the enhancement New feature or request label Dec 23, 2023
@0x587
Author

0x587 commented Dec 28, 2023

I'm trying to build a factory for each model, building from the bottom up in the order in which the models are referenced.
Referring to the example above, the AuthorFactory constructs a number of authors, but when the BookFactory starts constructing, the book's author attribute reconstructs new authors instead of using the authors already constructed by the AuthorFactory.
Perhaps we could prioritize the use of already generated factories when constructing relationships instead of constructing a new factory.
Maybe we can correct it from this?
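
The idea of reusing already-built instances can be sketched in plain Python. This is only an illustration under assumptions: the Author and Book dataclasses are stand-ins for the ORM models, and build_books is a hypothetical helper, not a polyfactory API.

```python
import random
from dataclasses import dataclass, field
from typing import List


@dataclass(eq=False)
class Author:
    """Stand-in for the Author model (not an ORM class)."""

    id: int
    books: List["Book"] = field(default_factory=list)


@dataclass(eq=False)
class Book:
    """Stand-in for the Book model (not an ORM class)."""

    id: int
    author: Author


def build_books(authors: List[Author], count: int, rng: random.Random) -> List[Book]:
    """Build books that reuse existing authors instead of creating new ones."""
    books: List[Book] = []
    for book_id in range(count):
        author = rng.choice(authors)  # pick a parent from the existing pool
        book = Book(id=book_id, author=author)
        author.books.append(book)  # keep both sides of the relationship in sync
        books.append(book)
    return books


authors = [Author(id=i) for i in range(3)]
books = build_books(authors, count=10, rng=random.Random(0))
```

Every book ends up pointing at one of the three existing authors, so no extra authors are created while building the books.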

@0x587 0x587 changed the title Support for recursively building models in SQLAlchemyFactory Support for relationship building models in SQLAlchemyFactory Dec 28, 2023
@0x587
Author

0x587 commented Dec 28, 2023

I've implemented a usable multilevel generation example with some factory configurations that I hope will be helpful.

Rule

My idea is to set __set_relationships__ only on the factories of classes that have foreign keys, such as Grade and Course. If a relationship attribute points to a class whose factory already has __set_relationships__ set, the attribute returns None or [] to prevent recursion, as with Course.grades.


Here's a diagram of the foreign key associations for these models.
[image: foreign key diagram]

Model definitions

class DBStudent(Base):
    __tablename__ = "student"
    id: Mapped[uuid.UUID] = mapped_column(sqltypes.UUID, primary_key=True)
    name: Mapped[str] = mapped_column(sqltypes.String, nullable=False)

    grades: Mapped[List["DBGrade"]] = relationship(back_populates="student")

class DBCourse(Base):
    __tablename__ = "course"
    id: Mapped[uuid.UUID] = mapped_column(sqltypes.UUID, primary_key=True)
    name: Mapped[str] = mapped_column(sqltypes.String, nullable=False)

    grades: Mapped[List["DBGrade"]] = relationship(back_populates="course")
    teacher: Mapped['DBTeacher'] = relationship(back_populates="courses")
    fk_teacher_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("teacher.id"), nullable=False)

class DBTeacher(Base):
    __tablename__ = "teacher"
    id: Mapped[uuid.UUID] = mapped_column(sqltypes.UUID, primary_key=True)
    courses: Mapped[List["DBCourse"]] = relationship(back_populates="teacher")

class DBGrade(Base):
    __tablename__ = "grade"
    id: Mapped[uuid.UUID] = mapped_column(sqltypes.UUID, primary_key=True)
    grade: Mapped[float] = mapped_column(sqltypes.Float, nullable=False)

    student: Mapped['DBStudent'] = relationship(back_populates="grades")
    fk_student_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("student.id"), nullable=False)
    course: Mapped['DBCourse'] = relationship(back_populates="grades")
    fk_course_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("course.id"), nullable=False)

Factory definitions

class StudentFactory(SQLAlchemyFactory[DBStudent]):
    __sync_persistence__ = True

class CourseFactory(SQLAlchemyFactory[DBCourse]):
    __sync_persistence__ = True
    __set_relationships__ = True

    @classmethod
    def teacher(cls):
        return TeacherFactory.build()

    @classmethod
    def grades(cls):
        return []

class TeacherFactory(SQLAlchemyFactory[DBTeacher]):
    __sync_persistence__ = True

class GradeFactory(SQLAlchemyFactory[DBGrade]):
    __sync_persistence__ = True
    __set_relationships__ = True

    @classmethod
    def student(cls):
        return StudentFactory.build()

    @classmethod
    def course(cls):
        return CourseFactory.build()

Main call

async def main():
    async with sessionmanager.session() as session:
        GradeFactory.__async_session__ = session
        await GradeFactory.create_batch_async(size=10)

@adhtruong
Collaborator

I think the settings here may cover some of the desired use cases https://polyfactory.litestar.dev/reference/factories/base.html#polyfactory.factories.base.BaseFactory.__set_as_default_factory_for_type__.

class MySQLAlchemyFactory(SQLAlchemyFactory):
     __is_base_factory__ = True
     __set_relationships__ = True

will allow changing this globally. Inheriting config for dynamically created subfactories is a known issue (see #426).


My thinking is to only set __set_relationships__ on the factories of classes that have foreign keys, such as Grade and Course. If a relationship attribute points to a class with __set_relationships__ set to prevent recursion, the attribute returns None or [] to prevent recursion, such as Course.grades.

I am not sure about this rule. It works for Grade, but when calling TeacherFactory, would generating some grades be expected? Maybe an enum or literal value would be appropriate to allow a different strategy on a per-factory basis, e.g. __set_relationships__: bool | Literal['all', 'collections', 'foreign_key'] or similar.
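
To make the literal-value idea concrete, here is a rough sketch of how such a setting could select relationships. Everything in it is hypothetical: RelationshipStrategy, RelationshipInfo and select_relationships are illustrative names rather than polyfactory APIs, and the reading of 'foreign_key' as "scalar relationships only" is an assumption.

```python
from dataclasses import dataclass
from typing import List, Literal, Union

# Hypothetical strategy type mirroring the suggestion above; not a polyfactory API.
RelationshipStrategy = Union[bool, Literal["all", "collections", "foreign_key"]]


@dataclass(frozen=True)
class RelationshipInfo:
    """Minimal stand-in for the relationship metadata an ORM exposes."""

    name: str
    uselist: bool  # True for collection (one-to-many / many-to-many) relationships


def select_relationships(
    relationships: List[RelationshipInfo],
    strategy: RelationshipStrategy,
) -> List[str]:
    """Return the names of the relationships a factory would populate."""
    if strategy is False:
        return []
    if strategy is True or strategy == "all":
        return [rel.name for rel in relationships]
    if strategy == "collections":
        return [rel.name for rel in relationships if rel.uselist]
    if strategy == "foreign_key":
        # Assumption: "foreign_key" means scalar (many-to-one) relationships only.
        return [rel.name for rel in relationships if not rel.uselist]
    raise ValueError(f"Unknown relationship strategy: {strategy!r}")


# Relationships of the Course model from the example above.
course_relationships = [
    RelationshipInfo(name="grades", uselist=True),
    RelationshipInfo(name="teacher", uselist=False),
]
```

Under this reading, a CourseFactory using 'foreign_key' would build the teacher but leave grades empty, which matches the hand-written factories earlier in the thread.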

@0x587
Author

0x587 commented Dec 30, 2023

Hi @adhtruong. Thanks for the reply.

Regarding your suggestion:

I think the settings here may cover some of the desired use cases https://polyfactory.litestar.dev/reference/factories/base.html#polyfactory.factories.base.BaseFactory.__set_as_default_factory_for_type__.

class MySQLAlchemyFactory(SQLAlchemyFactory):
     __is_base_factory__ = True
     __set_relationships__ = True

will allow changing this globally. Inheriting config for dynamically created subfactories is a known issue (see #426).

This would cause the program to go into infinite recursion, which is why I need to manually disable generation for certain fields, like Course.grades.

@0x587
Author

0x587 commented Dec 30, 2023

In the example I gave, the only way to get the expected result is to call GradeFactory. I built a graph of the model relationships and computed a spanning tree, then defined the factories with reference to this tree. GradeFactory is the root of this tree, so only GradeFactory can generate the expected results.

Giving __set_relationships__ more possible values might be a good solution, but would that be a big change to the whole library?

I wrote a script that uses networkx to compute the spanning tree of the model relationships and then generates the factory definition code. That is my current solution.

Perhaps we can find a more elegant way to generate a series of correlated models.
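
The spanning-tree step can be sketched as follows. This is not the script mentioned above: it swaps networkx for a plain breadth-first search from the standard library, and the RELATIONSHIPS mapping and spanning_tree helper are illustrative names based on the four models in this thread.

```python
from collections import deque
from typing import Dict, List

# Undirected relationship graph for the models above (names only).
RELATIONSHIPS: Dict[str, List[str]] = {
    "Grade": ["Student", "Course"],
    "Student": ["Grade"],
    "Course": ["Grade", "Teacher"],
    "Teacher": ["Course"],
}


def spanning_tree(graph: Dict[str, List[str]], root: str) -> Dict[str, List[str]]:
    """Breadth-first spanning tree: maps each model to the children it should build."""
    tree: Dict[str, List[str]] = {root: []}
    queue = deque([root])
    while queue:
        node = queue.popleft()
        for neighbour in graph.get(node, []):
            if neighbour not in tree:  # first visit: keep this edge in the tree
                tree[node].append(neighbour)
                tree[neighbour] = []
                queue.append(neighbour)
    return tree


tree = spanning_tree(RELATIONSHIPS, root="Grade")
```

With Grade as the root, the tree has Grade building Student and Course, and Course building Teacher. Edges that would lead back to an already-visited model (such as Course.grades) are dropped, which corresponds to the relationships that return [] or None in the factory definitions.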

@adhtruong
Collaborator

adhtruong commented Dec 30, 2023

Giving __set_relationships__ more possible values might be a good solution, but would that be a big change to the whole library?

This configuration is currently only used by the SQLAlchemyFactory. A configuration option here makes sense if it is useful to have these distinct behaviours in the library itself.


Going back to your previous comment

This would cause the program to go into infinite recursion, which is why I need to artificially disable certain field generation. Like Course.grades.

Would just keeping track of seen types beforehand handle this case? Note this may be a more naive solution than full graph resolution, but it may suffice for a lot of use cases. Here's a quick prototype of this, based on an example in the docs:

from __future__ import annotations

import contextlib
from typing import Any, Iterator, List, TypedDict

from sqlalchemy import ForeignKey, create_engine, inspect
from sqlalchemy.orm import DeclarativeBase, Mapped, Mapper, Session, mapped_column, relationship

from polyfactory.factories.sqlalchemy_factory import SQLAlchemyFactory
from polyfactory.field_meta import FieldMeta


class Base(DeclarativeBase):
    ...


class Author(Base):
    __tablename__ = "authors"

    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str]

    books: Mapped[list["Book"]] = relationship("Book", uselist=True)


class Book(Base):
    __tablename__ = "books"

    id: Mapped[int] = mapped_column(primary_key=True)
    author_id: Mapped[int] = mapped_column(ForeignKey(Author.id))

    author: Mapped[Author] = relationship("Author", uselist=False)


class Context(TypedDict):
    seen: set[Any]


_context: Context = {"seen": set()}  # This should not be global


@contextlib.contextmanager
def add_to_context(model: type) -> Iterator[None]:
    _context["seen"].add(model)
    try:
        yield
    finally:
        # Always clear the marker, even if the build raises.
        _context["seen"].discard(model)


class ImprovedSQLAlchemyFactory(SQLAlchemyFactory):
    __is_base_factory__ = True

    @classmethod
    def build(cls, **kwargs: Any) -> Any:
        with add_to_context(cls.__model__):
            return super().build(**kwargs)

    @classmethod
    def get_model_fields(cls) -> list[FieldMeta]:
        fields_meta = super().get_model_fields()

        table: Mapper = inspect(cls.__model__)  # type: ignore[assignment]
        for name, relationship_ in table.relationships.items():
            class_ = relationship_.entity.class_
            if class_ in _context["seen"]:
                continue

            annotation = class_ if not relationship_.uselist else List[class_]  # type: ignore[valid-type]
            fields_meta.append(
                FieldMeta.from_type(
                    name=name,
                    annotation=annotation,
                    random=cls.__random__,
                ),
            )

        return fields_meta


def test_sqla_factory() -> None:
    author: Author = ImprovedSQLAlchemyFactory.create_factory(Author).build()
    assert isinstance(author.books[0], Book)
    assert author.books[0].author is None

    book: Book = ImprovedSQLAlchemyFactory.create_factory(Book).build()
    assert book.author is not None
    assert book.author.books == []


def test_sqla_factory_create() -> None:
    engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(engine)
    ImprovedSQLAlchemyFactory.__session__ = Session(engine)

    author: Author = ImprovedSQLAlchemyFactory.create_factory(Author).create_sync()
    assert isinstance(author.books[0], Book)
    assert author.books[0].author is author

    book = ImprovedSQLAlchemyFactory.create_factory(Book).create_sync()
    assert book.author is not None
    assert book.author.books == [book]

Edit: Add test for example with session
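
For readers who want to see the seen-types guard without SQLAlchemy, here is a minimal plain-Python sketch. It is an illustration under assumptions: SCHEMA, mark_seen and build are hypothetical names, instances are plain dicts, and the seen set is passed explicitly as one possible way to avoid the global noted in the prototype.

```python
import contextlib
from typing import Any, Dict, Iterator, Set, Tuple

# Hypothetical schema: model name -> {relationship name: (target model, is_list)}.
SCHEMA: Dict[str, Dict[str, Tuple[str, bool]]] = {
    "Author": {"books": ("Book", True)},
    "Book": {"author": ("Author", False)},
}


@contextlib.contextmanager
def mark_seen(seen: Set[str], model: str) -> Iterator[None]:
    """Mark a model as being built for the duration of the block."""
    seen.add(model)
    try:
        yield
    finally:
        seen.discard(model)


def build(model: str, seen: Set[str]) -> Dict[str, Any]:
    """Build a dict instance, skipping relationships to models already being built."""
    instance: Dict[str, Any] = {"type": model}
    with mark_seen(seen, model):
        for name, (target, is_list) in SCHEMA[model].items():
            if target in seen:
                # Back-reference to a model further up the call stack: stop here.
                instance[name] = [] if is_list else None
                continue
            child = build(target, seen)
            instance[name] = [child] if is_list else child
    return instance


author = build("Author", set())
book = build("Book", set())
```

As in the build() test above, the author's book has no author set, and the book's author has an empty books list, because the guard cuts each back-reference.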

@0x587
Author

0x587 commented Dec 31, 2023

The results of this prototype don't seem to be what was expected.

This assert

assert author.books[0].author is None

I think this should be assert author.books[0].author == author

@adhtruong
Collaborator

@0x587 I think that difference comes from using create_sync / create_async vs build. The former add the instances to a session, so the relationships are resolved by SQLA. Out of the box, these won't necessarily be set correctly with build.

I've extended the example to use create_sync, which does match the expected assertion. Do you think this logic would meet your use case? If so, I think it is generic enough to be in the library itself.

@0x587
Author

0x587 commented Dec 31, 2023

The example of this extension perfectly meets my needs, thank you.

I think this need commonly arises when fake data has to be generated for a database. It would be worth adding a SQLAlchemyFactory subclass like this to the library under a suitable name.

Maybe we can close this issue.

@adhtruong
Collaborator

Great, thanks for checking and confirming!

I would be in favour of keeping this issue open so the workaround is documented here, as I agree with you that this is probably a common issue. I'll see if this feature can become part of the main library, or at least document the above as a workaround.
