Skip to content

Commit

Permalink
fix: handle empty lists and None collection filters (#62)
Browse files Browse the repository at this point in the history
  • Loading branch information
cofin committed Oct 19, 2023
1 parent e6b940f commit 70a4233
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 6 deletions.
12 changes: 8 additions & 4 deletions advanced_alchemy/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,10 @@ class CollectionFilter(Generic[T]):

field_name: str
"""Name of the model attribute to filter on."""
values: abc.Collection[T]
"""Values for ``IN`` clause."""
values: abc.Collection[T] | None
"""Values for ``IN`` clause.
An empty list will return an empty result set, however, if ``None``, the filter is not applied to the query, and all rows are returned. """


@dataclass
Expand All @@ -68,8 +70,10 @@ class NotInCollectionFilter(Generic[T]):

field_name: str
"""Name of the model attribute to filter on."""
values: abc.Collection[T]
"""Values for ``NOT IN`` clause."""
values: abc.Collection[T] | None
"""Values for ``NOT IN`` clause.
An empty list or ``None`` will return all rows."""


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion advanced_alchemy/repository/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,7 +919,7 @@ async def upsert_many(
existing_objs = await self.list(
CollectionFilter(
field_name=self.id_attribute,
values=[getattr(datum, self.id_attribute) for datum in data],
values=[getattr(datum, self.id_attribute) for datum in data] if data else None,
),
)
existing_ids = [getattr(datum, self.id_attribute) for datum in existing_objs]
Expand Down
2 changes: 1 addition & 1 deletion advanced_alchemy/repository/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -920,7 +920,7 @@ def upsert_many(
existing_objs = self.list(
CollectionFilter(
field_name=self.id_attribute,
values=[getattr(datum, self.id_attribute) for datum in data],
values=[getattr(datum, self.id_attribute) for datum in data] if data else None,
),
)
existing_ids = [getattr(datum, self.id_attribute) for datum in existing_objs]
Expand Down
File renamed without changes.
14 changes: 14 additions & 0 deletions tests/integration/test_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -1105,6 +1105,13 @@ async def test_repo_filter_no_obj_collection(
assert no_obj == []


async def test_repo_filter_null_collection(
author_repo: AuthorRepository,
) -> None:
no_obj = await maybe_async(author_repo.list(CollectionFilter(field_name="id", values=None)))
assert len(no_obj) > 0


async def test_repo_filter_not_in_collection(
author_repo: AuthorRepository,
existing_author_ids: Generator[Any, None, None],
Expand All @@ -1127,6 +1134,13 @@ async def test_repo_filter_not_in_no_obj_collection(
assert len(existing_obj) > 0


async def test_repo_filter_not_in_null_collection(
author_repo: AuthorRepository,
) -> None:
existing_obj = await maybe_async(author_repo.list(NotInCollectionFilter(field_name="id", values=None)))
assert len(existing_obj) > 0


async def test_repo_json_methods(
raw_rules_uuid: RawRecordData,
rule_repo: RuleRepository,
Expand Down
30 changes: 30 additions & 0 deletions tests/unit/test_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,21 @@ async def test_sqlalchemy_repo_list_with_collection_filter(
mock_repo._filter_in_collection.assert_called_with(field_name, values, statement=mock_repo.statement)


async def test_sqlalchemy_repo_list_with_null_collection_filter(
mock_repo: SQLAlchemyAsyncRepository,
monkeypatch: MonkeyPatch,
mock_repo_execute: AnyMock,
mocker: MockerFixture,
) -> None:
"""Test behavior of list operation given CollectionFilter."""
field_name = "id"
mock_repo_execute.return_value = MagicMock()
mock_repo.statement.where.return_value = mock_repo.statement
mocker.patch.object(mock_repo, "_filter_in_collection", return_value=mock_repo.statement)
await maybe_async(mock_repo.list(CollectionFilter(field_name, None)))
mock_repo._filter_in_collection.assert_not_called()


async def test_sqlalchemy_repo_empty_list_with_collection_filter(
mock_repo: SQLAlchemyAsyncRepository,
monkeypatch: MonkeyPatch,
Expand Down Expand Up @@ -678,6 +693,21 @@ async def test_sqlalchemy_repo_list_with_not_in_collection_filter(
mock_repo._filter_not_in_collection.assert_called_with(field_name, values, statement=mock_repo.statement)


async def test_sqlalchemy_repo_list_with_null_not_in_collection_filter(
mock_repo: SQLAlchemyAsyncRepository,
monkeypatch: MonkeyPatch,
mock_repo_execute: AnyMock,
mocker: MockerFixture,
) -> None:
"""Test behavior of list operation given CollectionFilter."""
field_name = "id"
mock_repo_execute.return_value = MagicMock()
mock_repo.statement.where.return_value = mock_repo.statement
mocker.patch.object(mock_repo, "_filter_not_in_collection", return_value=mock_repo.statement)
await maybe_async(mock_repo.list(NotInCollectionFilter(field_name, None)))
mock_repo._filter_not_in_collection.assert_not_called()


async def test_sqlalchemy_repo_unknown_filter_type_raises(mock_repo: SQLAlchemyAsyncRepository) -> None:
"""Test that repo raises exception if list receives unknown filter type."""
with pytest.raises(RepositoryError):
Expand Down

0 comments on commit 70a4233

Please sign in to comment.