Skip to content

Commit

Permalink
Merge pull request #9 from netzkolchose/filter_duplicates
Browse files Browse the repository at this point in the history
update only first duplicate
  • Loading branch information
jerch committed Apr 29, 2022
2 parents 6d169d9 + 3f16d5f commit 66719bf
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 7 deletions.
36 changes: 36 additions & 0 deletions example/exampleapp/tests.py
Expand Up @@ -303,3 +303,39 @@ def test_sanity_checks(self):
'fast_update() cannot be used with primary key fields.',
lambda : FieldUpdate.objects.fast_update(self.instances, ['f_char', 'id'])
)


class TestFilterDuplicates(TestCase):
def test_apply_first_duplicate_only(self):
a = FieldUpdate.objects.create()
updated = FieldUpdate.objects.fast_update([
FieldUpdate(pk=a.pk, **EXAMPLE), # all values are trueish
FieldUpdate(pk=a.pk) # all values None
], FIELDS)
# only 1 row updated
self.assertEqual(updated, 1)
v = FieldUpdate.objects.all().values().first()
# all values should be trueish
self.assertEqual(all(e for e in v.values()), True)

def test_multiple_duplicates(self):
a = FieldUpdate.objects.create()
b = FieldUpdate.objects.create()
c = FieldUpdate.objects.create()
updated = FieldUpdate.objects.fast_update([
FieldUpdate(pk=a.pk, **EXAMPLE), # all values are trueish
FieldUpdate(pk=a.pk), # all values None
FieldUpdate(pk=a.pk),
FieldUpdate(pk=b.pk, **EXAMPLE), # all values are trueish
FieldUpdate(pk=a.pk),
FieldUpdate(pk=a.pk),
FieldUpdate(pk=b.pk),
FieldUpdate(pk=c.pk, **EXAMPLE) # all values are trueish
], FIELDS)
# 3 row updated
self.assertEqual(updated, 3)
v = list(FieldUpdate.objects.all().values())
# all values should be trueish
self.assertEqual(all(e for e in v[0].values()), True)
self.assertEqual(all(e for e in v[1].values()), True)
self.assertEqual(all(e for e in v[2].values()), True)
20 changes: 13 additions & 7 deletions fast_update/fast.py
Expand Up @@ -17,9 +17,9 @@
To register a fast update implementations, call:
register_implementation('alias', check_function)
register_implementation('vendor', check_function)
where `alias` is the vendor name as returned by `connection.vendor`.
where `vendor` is the vendor name as returned by `connection.vendor`.
The check function gets called once (lazy) with `connection` and is meant
to find a suitable implementation (you can provide multiple for different
server versions), either actively by probing against the db server,
Expand Down Expand Up @@ -75,19 +75,19 @@ def prepare_data_xy(
CHECKER = {}

def register_implementation(
alias: str,
vendor: str,
func: Callable[[BaseDatabaseWrapper], Tuple[Any]]
) -> None:
"""
Register fast update implementation for db vendor.
`alias` is the vendor name as returned by `connection.vendor`.
`vendor` is the vendor name as returned by `connection.vendor`.
`func` is a lazy called function to check support for a certain
implementation at runtime for `connection`. The function should return
a tuple of (create_sql, prepare_data | None) for supported backends,
otherwise an empty tuple (needed to avoid re-eval).
"""
CHECKER[alias] = func
CHECKER[vendor] = func


def get_impl(conn: BaseDatabaseWrapper) -> str:
Expand Down Expand Up @@ -344,16 +344,22 @@ def fast_update(
with conn.cursor() as c:
data = []
counter = 0
seen = set()
for o in objs:
counter += 1
data += [p(v, conn) for p, v in zip(prep_save, get(o))]
row = [p(v, conn) for p, v in zip(prep_save, get(o))]
# filter for first batch occurence
if row[0] not in seen:
counter += 1
data += row
seen.add(row[0])
if counter >= batch_size_adjusted:
rows_updated += update_from_values(
c, model._meta.db_table, pk_field, fields,
counter, data, compiler, conn
)
data = []
counter = 0
seen = set()
if data:
rows_updated += update_from_values(
c, model._meta.db_table, pk_field, fields,
Expand Down

0 comments on commit 66719bf

Please sign in to comment.