Skip to content

Commit

Permalink
fix: change upsert_many behavior (#90)
Browse files Browse the repository at this point in the history
Co-authored-by: Janek Nouvertné <provinzkraut@posteo.de>
Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
  • Loading branch information
3 people committed Nov 4, 2023
1 parent c53b2ea commit 7a7d755
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 50 deletions.
81 changes: 56 additions & 25 deletions advanced_alchemy/repository/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,10 +526,7 @@ async def get_or_upsert(
When using match_fields and actual model values differ from ``kwargs``, the
model value will be updated.
"""
match_fields = match_fields or self.match_fields
if isinstance(match_fields, str):
match_fields = [match_fields]
if match_fields:
if match_fields := self._get_match_fields(match_fields=match_fields):
match_filter = {
field_name: kwargs.get(field_name, None)
for field_name in match_fields
Expand Down Expand Up @@ -601,10 +598,7 @@ async def get_and_update(
Raises:
NotFoundError: If no instance found identified by `item_id`.
"""
match_fields = match_fields or self.match_fields
if isinstance(match_fields, str):
match_fields = [match_fields]
if match_fields:
if match_fields := self._get_match_fields(match_fields=match_fields):
match_filter = {
field_name: kwargs.get(field_name, None)
for field_name in match_fields
Expand Down Expand Up @@ -932,10 +926,7 @@ async def upsert(
Raises:
NotFoundError: If no instance found with same identifier as `data`.
"""
match_fields = match_fields or self.match_fields
if isinstance(match_fields, str):
match_fields = [match_fields]
if match_fields:
if match_fields := self._get_match_fields(match_fields=match_fields):
match_filter = {
field_name: getattr(data, field_name, None)
for field_name in match_fields
Expand Down Expand Up @@ -1013,29 +1004,37 @@ async def upsert_many(
instances: list[ModelT] = []
data_to_update: list[ModelT] = []
data_to_insert: list[ModelT] = []
match_fields = match_fields or self.match_fields
if isinstance(match_fields, str):
match_fields = [match_fields]
match_filter: list[FilterTypes | ColumnElement[bool]] = [
CollectionFilter(
field_name=self.id_attribute,
values=[getattr(datum, self.id_attribute) for datum in data if datum is not None] if data else None,
),
]
match_fields = self._get_match_fields(match_fields=match_fields)
if match_fields is None:
match_fields = [self.id_attribute]
match_filter: list[FilterTypes | ColumnElement[bool]] = []
if match_fields:
for field_name in match_fields:
field = get_instrumented_attr(self.model_type, field_name)
matched_values = [getattr(datum, field_name) for datum in data if datum is not None]
matched_values = [
field_data for datum in data if (field_data := getattr(datum, field_name)) is not None
]
if self._prefer_any:
match_filter.append(any_(matched_values) == field) # type: ignore[arg-type]
else:
match_filter.append(field.in_(matched_values))

with wrap_sqlalchemy_exception():
existing_objs = await self.list(*match_filter, auto_expunge=False)
existing_ids = [getattr(datum, self.id_attribute) for datum in existing_objs if datum is not None]
existing_objs = await self.list(
*match_filter,
auto_expunge=False,
)
for field_name in match_fields:
field = get_instrumented_attr(self.model_type, field_name)
matched_values = [getattr(datum, field_name) for datum in existing_objs if datum is not None]
if self._prefer_any:
match_filter.append(any_(matched_values) == field) # type: ignore[arg-type]
else:
match_filter.append(field.in_(matched_values))
existing_ids = self._get_object_ids(existing_objs=existing_objs)
data = self._merge_on_match_fields(data, existing_objs, match_fields)
for datum in data:
if getattr(datum, self.id_attribute) is not None in existing_ids:
if getattr(datum, self.id_attribute, None) in existing_ids:
data_to_update.append(datum)
else:
data_to_insert.append(datum)
Expand All @@ -1052,6 +1051,38 @@ async def upsert_many(
self._expunge(instance, auto_expunge=auto_expunge)
return instances

def _get_object_ids(self, existing_objs: list[ModelT]) -> list[Any]:
return [obj_id for datum in existing_objs if (obj_id := getattr(datum, self.id_attribute)) is not None]

def _get_match_fields(
self,
match_fields: list[str] | str | None = None,
id_attribute: str | None = None,
) -> list[str] | None:
id_attribute = id_attribute or self.id_attribute
match_fields = match_fields or self.match_fields
if isinstance(match_fields, str):
match_fields = [match_fields]
return match_fields

def _merge_on_match_fields(
self,
data: list[ModelT],
existing_data: list[ModelT],
match_fields: list[str] | str | None = None,
) -> list[ModelT]:
match_fields = self._get_match_fields(match_fields=match_fields)
if match_fields is None:
match_fields = [self.id_attribute]
for existing_datum in existing_data:
for row_id, datum in enumerate(data):
match = all(
getattr(datum, field_name) == getattr(existing_datum, field_name) for field_name in match_fields
)
if match and getattr(existing_datum, self.id_attribute) is not None:
setattr(data[row_id], self.id_attribute, getattr(existing_datum, self.id_attribute))
return data

async def list(
self,
*filters: FilterTypes | ColumnElement[bool],
Expand Down
81 changes: 56 additions & 25 deletions advanced_alchemy/repository/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,10 +527,7 @@ def get_or_upsert(
When using match_fields and actual model values differ from ``kwargs``, the
model value will be updated.
"""
match_fields = match_fields or self.match_fields
if isinstance(match_fields, str):
match_fields = [match_fields]
if match_fields:
if match_fields := self._get_match_fields(match_fields=match_fields):
match_filter = {
field_name: kwargs.get(field_name, None)
for field_name in match_fields
Expand Down Expand Up @@ -602,10 +599,7 @@ def get_and_update(
Raises:
NotFoundError: If no instance found identified by `item_id`.
"""
match_fields = match_fields or self.match_fields
if isinstance(match_fields, str):
match_fields = [match_fields]
if match_fields:
if match_fields := self._get_match_fields(match_fields=match_fields):
match_filter = {
field_name: kwargs.get(field_name, None)
for field_name in match_fields
Expand Down Expand Up @@ -933,10 +927,7 @@ def upsert(
Raises:
NotFoundError: If no instance found with same identifier as `data`.
"""
match_fields = match_fields or self.match_fields
if isinstance(match_fields, str):
match_fields = [match_fields]
if match_fields:
if match_fields := self._get_match_fields(match_fields=match_fields):
match_filter = {
field_name: getattr(data, field_name, None)
for field_name in match_fields
Expand Down Expand Up @@ -1014,29 +1005,37 @@ def upsert_many(
instances: list[ModelT] = []
data_to_update: list[ModelT] = []
data_to_insert: list[ModelT] = []
match_fields = match_fields or self.match_fields
if isinstance(match_fields, str):
match_fields = [match_fields]
match_filter: list[FilterTypes | ColumnElement[bool]] = [
CollectionFilter(
field_name=self.id_attribute,
values=[getattr(datum, self.id_attribute) for datum in data if datum is not None] if data else None,
),
]
match_fields = self._get_match_fields(match_fields=match_fields)
if match_fields is None:
match_fields = [self.id_attribute]
match_filter: list[FilterTypes | ColumnElement[bool]] = []
if match_fields:
for field_name in match_fields:
field = get_instrumented_attr(self.model_type, field_name)
matched_values = [getattr(datum, field_name) for datum in data if datum is not None]
matched_values = [
field_data for datum in data if (field_data := getattr(datum, field_name)) is not None
]
if self._prefer_any:
match_filter.append(any_(matched_values) == field) # type: ignore[arg-type]
else:
match_filter.append(field.in_(matched_values))

with wrap_sqlalchemy_exception():
existing_objs = self.list(*match_filter, auto_expunge=False)
existing_ids = [getattr(datum, self.id_attribute) for datum in existing_objs if datum is not None]
existing_objs = self.list(
*match_filter,
auto_expunge=False,
)
for field_name in match_fields:
field = get_instrumented_attr(self.model_type, field_name)
matched_values = [getattr(datum, field_name) for datum in existing_objs if datum is not None]
if self._prefer_any:
match_filter.append(any_(matched_values) == field) # type: ignore[arg-type]
else:
match_filter.append(field.in_(matched_values))
existing_ids = self._get_object_ids(existing_objs=existing_objs)
data = self._merge_on_match_fields(data, existing_objs, match_fields)
for datum in data:
if getattr(datum, self.id_attribute) is not None in existing_ids:
if getattr(datum, self.id_attribute, None) in existing_ids:
data_to_update.append(datum)
else:
data_to_insert.append(datum)
Expand All @@ -1053,6 +1052,38 @@ def upsert_many(
self._expunge(instance, auto_expunge=auto_expunge)
return instances

def _get_object_ids(self, existing_objs: list[ModelT]) -> list[Any]:
return [obj_id for datum in existing_objs if (obj_id := getattr(datum, self.id_attribute)) is not None]

def _get_match_fields(
self,
match_fields: list[str] | str | None = None,
id_attribute: str | None = None,
) -> list[str] | None:
id_attribute = id_attribute or self.id_attribute
match_fields = match_fields or self.match_fields
if isinstance(match_fields, str):
match_fields = [match_fields]
return match_fields

def _merge_on_match_fields(
self,
data: list[ModelT],
existing_data: list[ModelT],
match_fields: list[str] | str | None = None,
) -> list[ModelT]:
match_fields = self._get_match_fields(match_fields=match_fields)
if match_fields is None:
match_fields = [self.id_attribute]
for existing_datum in existing_data:
for row_id, datum in enumerate(data):
match = all(
getattr(datum, field_name) == getattr(existing_datum, field_name) for field_name in match_fields
)
if match and getattr(existing_datum, self.id_attribute) is not None:
setattr(data[row_id], self.id_attribute, getattr(existing_datum, self.id_attribute))
return data

def list(
self,
*filters: FilterTypes | ColumnElement[bool],
Expand Down
26 changes: 26 additions & 0 deletions tests/integration/test_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -1088,6 +1088,32 @@ async def test_repo_upsert_many_method_match_non_id(
assert existing_count_now > existing_count


async def test_repo_upsert_many_method_match_not_on_input(
author_repo: AuthorRepository,
author_model: AuthorModel,
) -> None:
if author_repo._dialect.name.startswith("spanner") and os.environ.get("SPANNER_EMULATOR_HOST"):
pytest.skip(
"Skipped on emulator. See the following: https://github.com/GoogleCloudPlatform/cloud-spanner-emulator/issues/73",
)
existing_count = await maybe_async(author_repo.count())
existing_obj = await maybe_async(author_repo.get_one(name="Agatha Christie"))
existing_obj.name = "Agatha C."
_ = await maybe_async(
author_repo.upsert_many(
data=[
existing_obj,
author_model(name="Inserted Author"),
author_model(name="Custom Author"),
],
match_fields=["id"],
),
)
existing_count_now = await maybe_async(author_repo.count())

assert existing_count_now > existing_count


async def test_repo_filter_before_after(author_repo: AuthorRepository) -> None:
before_filter = BeforeAfter(
field_name="created_at",
Expand Down

0 comments on commit 7a7d755

Please sign in to comment.