Skip to content

Commit

Permalink
feat: add get_and_update to the repository and services (#77)
Browse files Browse the repository at this point in the history
Co-authored-by: Peter Schutt <peter.github@proton.me>
  • Loading branch information
cofin and peterschutt committed Oct 27, 2023
1 parent d5ca3f8 commit 1e50d41
Show file tree
Hide file tree
Showing 5 changed files with 281 additions and 7 deletions.
78 changes: 76 additions & 2 deletions advanced_alchemy/repository/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ async def get_or_upsert(
match_fields: a list of keys to use to match the existing model. When
empty, all fields are matched.
upsert: When using match_fields and actual model values differ from
`kwargs`, perform an update operation on the model.
`kwargs`, automatically perform an update operation on the model.
attribute_names: an iterable of attribute names to pass into the ``update``
method.
with_for_update: indicating FOR UPDATE should be used, or may be a
Expand Down Expand Up @@ -543,7 +543,15 @@ async def get_or_upsert(
match_filter = kwargs
existing = await self.get_one_or_none(**match_filter)
if not existing:
return await self.add(self.model_type(**kwargs)), True # pyright: ignore[reportGeneralTypeIssues]
return (
await self.add(
self.model_type(**kwargs),
auto_commit=auto_commit,
auto_refresh=auto_refresh,
auto_expunge=auto_expunge,
),
True,
)
if upsert:
for field_name, new_field_value in kwargs.items():
field = getattr(existing, field_name, None)
Expand All @@ -560,6 +568,72 @@ async def get_or_upsert(
self._expunge(existing, auto_expunge=auto_expunge)
return existing, False

async def get_and_update(
self,
match_fields: list[str] | str | None = None,
attribute_names: Iterable[str] | None = None,
with_for_update: bool | None = None,
auto_commit: bool | None = None,
auto_expunge: bool | None = None,
auto_refresh: bool | None = None,
**kwargs: Any,
) -> tuple[ModelT, bool]:
"""Get instance identified by ``kwargs`` and update the model if the arguments are different.
Args:
match_fields: a list of keys to use to match the existing model. When
empty, all fields are matched.
attribute_names: an iterable of attribute names to pass into the ``update``
method.
with_for_update: indicating FOR UPDATE should be used, or may be a
dictionary containing flags to indicate a more specific set of
FOR UPDATE flags for the SELECT
auto_expunge: Remove object from session before returning. Defaults to
:class:`SQLAlchemyAsyncRepository.auto_expunge <SQLAlchemyAsyncRepository>`.
auto_refresh: Refresh object from session before returning. Defaults to
:class:`SQLAlchemyAsyncRepository.auto_refresh <SQLAlchemyAsyncRepository>`
auto_commit: Commit objects before returning. Defaults to
:class:`SQLAlchemyAsyncRepository.auto_commit <SQLAlchemyAsyncRepository>`
**kwargs: Identifier of the instance to be retrieved.
Returns:
a tuple that includes the instance and whether it needed to be updated.
When using match_fields and actual model values differ from ``kwargs``, the
model value will be updated.
Raises:
NotFoundError: If no instance found identified by `item_id`.
"""
match_fields = match_fields or self.match_fields
if isinstance(match_fields, str):
match_fields = [match_fields]
if match_fields:
match_filter = {
field_name: kwargs.get(field_name, None)
for field_name in match_fields
if kwargs.get(field_name, None) is not None
}
else:
match_filter = kwargs
existing = await self.get_one(**match_filter)
updated = False
for field_name, new_field_value in kwargs.items():
field = getattr(existing, field_name, None)
if field and field != new_field_value:
updated = True
setattr(existing, field_name, new_field_value)
existing = await self._attach_to_session(existing, strategy="merge")
await self._flush_or_commit(auto_commit=auto_commit)
await self._refresh(
existing,
attribute_names=attribute_names,
with_for_update=with_for_update,
auto_refresh=auto_refresh,
)
self._expunge(existing, auto_expunge=auto_expunge)
return existing, updated

async def count(
self,
*filters: FilterTypes | ColumnElement[bool],
Expand Down
78 changes: 76 additions & 2 deletions advanced_alchemy/repository/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ def get_or_upsert(
match_fields: a list of keys to use to match the existing model. When
empty, all fields are matched.
upsert: When using match_fields and actual model values differ from
`kwargs`, perform an update operation on the model.
`kwargs`, automatically perform an update operation on the model.
attribute_names: an iterable of attribute names to pass into the ``update``
method.
with_for_update: indicating FOR UPDATE should be used, or may be a
Expand Down Expand Up @@ -544,7 +544,15 @@ def get_or_upsert(
match_filter = kwargs
existing = self.get_one_or_none(**match_filter)
if not existing:
return self.add(self.model_type(**kwargs)), True # pyright: ignore[reportGeneralTypeIssues]
return (
self.add(
self.model_type(**kwargs),
auto_commit=auto_commit,
auto_refresh=auto_refresh,
auto_expunge=auto_expunge,
),
True,
)
if upsert:
for field_name, new_field_value in kwargs.items():
field = getattr(existing, field_name, None)
Expand All @@ -561,6 +569,72 @@ def get_or_upsert(
self._expunge(existing, auto_expunge=auto_expunge)
return existing, False

def get_and_update(
self,
match_fields: list[str] | str | None = None,
attribute_names: Iterable[str] | None = None,
with_for_update: bool | None = None,
auto_commit: bool | None = None,
auto_expunge: bool | None = None,
auto_refresh: bool | None = None,
**kwargs: Any,
) -> tuple[ModelT, bool]:
"""Get instance identified by ``kwargs`` and update the model if the arguments are different.
Args:
match_fields: a list of keys to use to match the existing model. When
empty, all fields are matched.
attribute_names: an iterable of attribute names to pass into the ``update``
method.
with_for_update: indicating FOR UPDATE should be used, or may be a
dictionary containing flags to indicate a more specific set of
FOR UPDATE flags for the SELECT
auto_expunge: Remove object from session before returning. Defaults to
:class:`SQLAlchemyAsyncRepository.auto_expunge <SQLAlchemyAsyncRepository>`.
auto_refresh: Refresh object from session before returning. Defaults to
:class:`SQLAlchemyAsyncRepository.auto_refresh <SQLAlchemyAsyncRepository>`
auto_commit: Commit objects before returning. Defaults to
:class:`SQLAlchemyAsyncRepository.auto_commit <SQLAlchemyAsyncRepository>`
**kwargs: Identifier of the instance to be retrieved.
Returns:
a tuple that includes the instance and whether it needed to be updated.
When using match_fields and actual model values differ from ``kwargs``, the
model value will be updated.
Raises:
NotFoundError: If no instance found identified by `item_id`.
"""
match_fields = match_fields or self.match_fields
if isinstance(match_fields, str):
match_fields = [match_fields]
if match_fields:
match_filter = {
field_name: kwargs.get(field_name, None)
for field_name in match_fields
if kwargs.get(field_name, None) is not None
}
else:
match_filter = kwargs
existing = self.get_one(**match_filter)
updated = False
for field_name, new_field_value in kwargs.items():
field = getattr(existing, field_name, None)
if field and field != new_field_value:
updated = True
setattr(existing, field_name, new_field_value)
existing = self._attach_to_session(existing, strategy="merge")
self._flush_or_commit(auto_commit=auto_commit)
self._refresh(
existing,
attribute_names=attribute_names,
with_for_update=with_for_update,
auto_refresh=auto_refresh,
)
self._expunge(existing, auto_expunge=auto_expunge)
return existing, updated

def count(
self,
*filters: FilterTypes | ColumnElement[bool],
Expand Down
45 changes: 44 additions & 1 deletion advanced_alchemy/service/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ async def upsert_many(

async def get_or_upsert(
self,
match_fields: list[str] | None = None,
match_fields: list[str] | str | None = None,
upsert: bool = True,
attribute_names: Iterable[str] | None = None,
with_for_update: bool | None = None,
Expand Down Expand Up @@ -478,6 +478,49 @@ async def get_or_upsert(
**validated_model.to_dict(),
)

async def get_and_update(
self,
match_fields: list[str] | str | None = None,
attribute_names: Iterable[str] | None = None,
with_for_update: bool | None = None,
auto_commit: bool | None = None,
auto_expunge: bool | None = None,
auto_refresh: bool | None = None,
**kwargs: Any,
) -> tuple[ModelT, bool]:
"""Wrap repository instance creation.
Args:
match_fields: a list of keys to use to match the existing model. When
empty, all fields are matched.
attribute_names: an iterable of attribute names to pass into the ``update``
method.
with_for_update: indicating FOR UPDATE should be used, or may be a
dictionary containing flags to indicate a more specific set of
FOR UPDATE flags for the SELECT
auto_expunge: Remove object from session before returning. Defaults to
:class:`SQLAlchemyAsyncRepository.auto_expunge <SQLAlchemyAsyncRepository>`.
auto_refresh: Refresh object from session before returning. Defaults to
:class:`SQLAlchemyAsyncRepository.auto_refresh <SQLAlchemyAsyncRepository>`
auto_commit: Commit objects before returning. Defaults to
:class:`SQLAlchemyAsyncRepository.auto_commit <SQLAlchemyAsyncRepository>`
**kwargs: Identifier of the instance to be retrieved.
Returns:
Representation of updated instance.
"""
match_fields = match_fields or self.match_fields
validated_model = await self.to_model(kwargs, "update")
return await self.repository.get_and_update(
match_fields=match_fields,
attribute_names=attribute_names,
with_for_update=with_for_update,
auto_commit=auto_commit,
auto_expunge=auto_expunge,
auto_refresh=auto_refresh,
**validated_model.to_dict(),
)

async def delete(
self,
item_id: Any,
Expand Down
45 changes: 44 additions & 1 deletion advanced_alchemy/service/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ def upsert_many(

def get_or_upsert(
self,
match_fields: list[str] | None = None,
match_fields: list[str] | str | None = None,
upsert: bool = True,
attribute_names: Iterable[str] | None = None,
with_for_update: bool | None = None,
Expand Down Expand Up @@ -479,6 +479,49 @@ def get_or_upsert(
**validated_model.to_dict(),
)

def get_and_update(
self,
match_fields: list[str] | str | None = None,
attribute_names: Iterable[str] | None = None,
with_for_update: bool | None = None,
auto_commit: bool | None = None,
auto_expunge: bool | None = None,
auto_refresh: bool | None = None,
**kwargs: Any,
) -> tuple[ModelT, bool]:
"""Wrap repository instance creation.
Args:
match_fields: a list of keys to use to match the existing model. When
empty, all fields are matched.
attribute_names: an iterable of attribute names to pass into the ``update``
method.
with_for_update: indicating FOR UPDATE should be used, or may be a
dictionary containing flags to indicate a more specific set of
FOR UPDATE flags for the SELECT
auto_expunge: Remove object from session before returning. Defaults to
:class:`SQLAlchemyAsyncRepository.auto_expunge <SQLAlchemyAsyncRepository>`.
auto_refresh: Refresh object from session before returning. Defaults to
:class:`SQLAlchemyAsyncRepository.auto_refresh <SQLAlchemyAsyncRepository>`
auto_commit: Commit objects before returning. Defaults to
:class:`SQLAlchemyAsyncRepository.auto_commit <SQLAlchemyAsyncRepository>`
**kwargs: Identifier of the instance to be retrieved.
Returns:
Representation of updated instance.
"""
match_fields = match_fields or self.match_fields
validated_model = self.to_model(kwargs, "update")
return self.repository.get_and_update(
match_fields=match_fields,
attribute_names=attribute_names,
with_for_update=with_for_update,
auto_commit=auto_commit,
auto_expunge=auto_expunge,
auto_refresh=auto_refresh,
**validated_model.to_dict(),
)

def delete(
self,
item_id: Any,
Expand Down
42 changes: 41 additions & 1 deletion tests/integration/test_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from sqlalchemy.orm import Session, sessionmaker

from advanced_alchemy import SQLAlchemyAsyncRepository, SQLAlchemyAsyncRepositoryService, base
from advanced_alchemy.exceptions import RepositoryError
from advanced_alchemy.exceptions import NotFoundError, RepositoryError
from advanced_alchemy.filters import (
BeforeAfter,
CollectionFilter,
Expand Down Expand Up @@ -959,6 +959,36 @@ async def test_repo_get_or_upsert_match_filter(author_repo: AuthorRepository, fi
assert existing_created is False


async def test_repo_get_or_upsert_match_filter_no_upsert(author_repo: AuthorRepository, first_author_id: Any) -> None:
now = datetime.now()
existing_obj, existing_created = await maybe_async(
author_repo.get_or_upsert(match_fields="name", upsert=False, name="Agatha Christie", dob=now.date()),
)
assert existing_obj.id == first_author_id
assert existing_obj.dob != now.date()
assert existing_created is False


async def test_repo_get_and_update(author_repo: AuthorRepository, first_author_id: Any) -> None:
existing_obj, existing_updated = await maybe_async(
author_repo.get_and_update(name="Agatha Christie"),
)
assert existing_obj.id == first_author_id
assert existing_updated is False


async def test_repo_get_and_upsert_match_filter(author_repo: AuthorRepository, first_author_id: Any) -> None:
now = datetime.now()
with pytest.raises(NotFoundError):
_ = await maybe_async(
author_repo.get_and_update(match_fields="name", name="Agatha Christie123", dob=now.date()),
)
with pytest.raises(NotFoundError):
_ = await maybe_async(
author_repo.get_and_update(name="Agatha Christie123"),
)


async def test_repo_upsert_method(
author_repo: AuthorRepository,
first_author_id: Any,
Expand Down Expand Up @@ -1493,6 +1523,16 @@ async def test_service_get_or_upsert_method(author_service: AuthorService, first
assert new_created


async def test_service_get_and_update_method(author_service: AuthorService, first_author_id: Any) -> None:
existing_obj, existing_created = await maybe_async(
author_service.get_and_update(name="Agatha Christie", match_fields="name"),
)
assert existing_obj.id == first_author_id
assert existing_created is False
with pytest.raises(NotFoundError):
_ = await maybe_async(author_service.get_and_update(name="New Author"))


async def test_service_upsert_method(
author_service: AuthorService,
first_author_id: Any,
Expand Down

0 comments on commit 1e50d41

Please sign in to comment.