diff --git a/example/exampleapp/tests.py b/example/exampleapp/tests.py
index b66a8c9..0fd22c2 100644
--- a/example/exampleapp/tests.py
+++ b/example/exampleapp/tests.py
@@ -288,7 +288,7 @@ def test_sanity_checks(self):
         self.assertRaisesMessage(
             ValueError,
             'All fast_update() objects must have a primary key set.',
-            lambda : FieldUpdate.objects.fast_update([FieldUpdate(**EXAMPLE), FieldUpdate(**EXAMPLE)], ['f_char'])
+            lambda : FieldUpdate.objects.fast_update([FieldUpdate(**EXAMPLE)], ['f_char'])
         )
         # non concrete field
         mbase = MultiBase.objects.create()
@@ -305,37 +305,18 @@ def test_sanity_checks(self):
         )
 
 
-class TestFilterDuplicates(TestCase):
-    def test_apply_first_duplicate_only(self):
+class TestDuplicates(TestCase):
+    def test_raise_on_duplicates(self):
         a = FieldUpdate.objects.create()
-        updated = FieldUpdate.objects.fast_update([
-            FieldUpdate(pk=a.pk, **EXAMPLE),  # all values are trueish
-            FieldUpdate(pk=a.pk)              # all values None
-        ], FIELDS)
-        # only 1 row updated
-        self.assertEqual(updated, 1)
-        v = FieldUpdate.objects.all().values().first()
-        # all values should be trueish
-        self.assertEqual(all(e for e in v.values()), True)
+        with self.assertRaisesMessage(ValueError, 'cannot update duplicates'):
+            FieldUpdate.objects.fast_update([
+                FieldUpdate(pk=a.pk),
+                FieldUpdate(pk=a.pk)
+            ], FIELDS)
 
-    def test_multiple_duplicates(self):
-        a = FieldUpdate.objects.create()
-        b = FieldUpdate.objects.create()
-        c = FieldUpdate.objects.create()
-        updated = FieldUpdate.objects.fast_update([
-            FieldUpdate(pk=a.pk, **EXAMPLE),  # all values are trueish
-            FieldUpdate(pk=a.pk),             # all values None
-            FieldUpdate(pk=a.pk),
-            FieldUpdate(pk=b.pk, **EXAMPLE),  # all values are trueish
-            FieldUpdate(pk=a.pk),
-            FieldUpdate(pk=a.pk),
-            FieldUpdate(pk=b.pk),
-            FieldUpdate(pk=c.pk, **EXAMPLE)   # all values are trueish
-        ], FIELDS)
-        # 3 row updated
-        self.assertEqual(updated, 3)
-        v = list(FieldUpdate.objects.all().values())
-        # all values should be trueish
-        self.assertEqual(all(e for e in v[0].values()), True)
-        self.assertEqual(all(e for e in v[1].values()), True)
-        self.assertEqual(all(e for e in v[2].values()), True)
+    def test_no_pk_duplicates(self):
+        with self.assertRaisesMessage(ValueError, 'cannot update duplicates'):
+            FieldUpdate.objects.fast_update([
+                FieldUpdate(),
+                FieldUpdate()
+            ], FIELDS)
diff --git a/example/postgres_tests/tests.py b/example/postgres_tests/tests.py
index 63a2c4d..fa563a2 100644
--- a/example/postgres_tests/tests.py
+++ b/example/postgres_tests/tests.py
@@ -436,11 +436,21 @@ def test_updatefull_multiple(self):
         for r in results[1:]:
             for f in CU_FIELDS:
                 self.assertEqual(r[f], first[f])
-        # force threaded write
-        update_c = update_b * 100
-        FieldUpdateNotNull.objects.copy_update(update_c, CU_FIELDS)
+
+    def test_updatefull_multiple_threaded(self):
+        objs = []
+        for _ in range(10000):
+            objs.append(FieldUpdateNotNull())
+        FieldUpdateNotNull.objects.bulk_create(objs)
+        changeset = []
+        for o in objs:
+            changeset.append(FieldUpdateNotNull(pk=o.pk, **CU_EXAMPLE))
+
+        # force copy_update to use threaded logic due to payload >>64kB
+        FieldUpdateNotNull.objects.copy_update(changeset, CU_FIELDS)
         results = list(FieldUpdateNotNull.objects.all().values(*CU_FIELDS))
-        for r in results[201:]:
+        first = results[0]
+        for r in results[1:]:
             for f in CU_FIELDS:
                 self.assertEqual(r[f], first[f])
 
diff --git a/fast_update/fast.py b/fast_update/fast.py
index 7a1337d..ffd44a9 100644
--- a/fast_update/fast.py
+++ b/fast_update/fast.py
@@ -15,7 +15,7 @@
 """
 DB vendor low level interfaces
 
-To register a fast update implementations, call:
+To register fast update implementations, call:
 
     register_implementation('vendor', check_function)
 
@@ -70,44 +70,7 @@ def prepare_data_xy(
     """
 
 
-# memorize fast_update vendor on connection object
-SEEN_CONNECTIONS = cast(Dict[BaseDatabaseWrapper, str], WeakKeyDictionary())
-CHECKER = {}
-
-def register_implementation(
-    vendor: str,
-    func: Callable[[BaseDatabaseWrapper], Tuple[Any]]
-) -> None:
-    """
-    Register fast update implementation for db vendor.
-
-    `vendor` is the vendor name as returned by `connection.vendor`.
-    `func` is a lazy called function to check support for a certain
-    implementation at runtime for `connection`. The function should return
-    a tuple of (create_sql, prepare_data | None) for supported backends,
-    otherwise an empty tuple (needed to avoid re-eval).
-    """
-    CHECKER[vendor] = func
-
-
-def get_impl(conn: BaseDatabaseWrapper) -> str:
-    """
-    Try to get a fast update implementation for `conn`.
-    Calls once the the check function of `register_implementation` and
-    memorizes its result for `conn`.
-    Returns a tuple (create_sql, prepare_data | None) for supported backends,
-    otherwise an empty tuple.
-    """
-    impl = SEEN_CONNECTIONS.get(conn)
-    if impl is not None:
-        return impl
-    check = CHECKER.get(conn.vendor)
-    if not check:  # pragma: no cover
-        SEEN_CONNECTIONS[conn] = tuple()
-        return tuple()
-    impl = check(conn) or tuple()  # NOTE: in case check returns something nullish
-    SEEN_CONNECTIONS[conn] = impl
-    return impl
+# Fast update implementations for postgres, sqlite and mysql.
 
 
 def pq_cast(tname: str, field: Field, compiler: SQLCompiler, connection: Any) -> str:
@@ -243,14 +206,57 @@ def as_mysql(
     )
 
-# Register our default db implementations.
+# Implementation registry.
+
+
+# memorize fast_update vendor on connection object
+SEEN_CONNECTIONS = cast(Dict[BaseDatabaseWrapper, str], WeakKeyDictionary())
+CHECKER = {}
+
+def register_implementation(
+    vendor: str,
+    func: Callable[[BaseDatabaseWrapper], Tuple[Any]]
+) -> None:
+    """
+    Register fast update implementation for db vendor.
+
+    `vendor` is the vendor name as returned by `connection.vendor`.
+    `func` is a lazy called function to check support for a certain
+    implementation at runtime for `connection`. The function should return
+    a tuple of (create_sql, prepare_data | None) for supported backends,
+    otherwise an empty tuple (needed to avoid re-eval).
+    """
+    CHECKER[vendor] = func
+
+
+def get_impl(conn: BaseDatabaseWrapper) -> str:
+    """
+    Try to get a fast update implementation for `conn`.
+    Calls once the check function of `register_implementation` and
+    memorizes its result for `conn`.
+    Returns a tuple (create_sql, prepare_data | None) for supported backends,
+    otherwise an empty tuple.
+    """
+    impl = SEEN_CONNECTIONS.get(conn)
+    if impl is not None:
+        return impl
+    check = CHECKER.get(conn.vendor)
+    if not check:  # pragma: no cover
+        SEEN_CONNECTIONS[conn] = tuple()
+        return tuple()
+    impl = check(conn) or tuple()  # NOTE: in case check returns something nullish
+    SEEN_CONNECTIONS[conn] = impl
+    return impl
+
+
+# Register default db implementations from above.
 register_implementation(
     'postgresql',
     lambda _: (as_postgresql, None)
 )
 register_implementation(
     'sqlite',
-    # NOTE: check function does not handle versions <2.15 anymore
+    # NOTE: check function does not handle versions <3.15 anymore
     lambda conn: (as_sqlite, None)
         if conn.Database.sqlite_version_info >= (3, 33)
         else (as_sqlite_cte, prepare_data_sqlite_cte)
 )
@@ -260,6 +266,9 @@ def as_mysql(
 )
 
 
+# Update implementation.
+
+
 def update_from_values(
     c: CursorWrapper,
     tname: str,
@@ -344,14 +353,9 @@ def fast_update(
     with conn.cursor() as c:
         data = []
         counter = 0
-        seen = set()
         for o in objs:
-            row = [p(v, conn) for p, v in zip(prep_save, get(o))]
-            # filter for first batch occurence
-            if row[0] not in seen:
-                counter += 1
-                data += row
-                seen.add(row[0])
+            counter += 1
+            data += [p(v, conn) for p, v in zip(prep_save, get(o))]
             if counter >= batch_size_adjusted:
                 rows_updated += update_from_values(
                     c, model._meta.db_table, pk_field, fields,
@@ -359,7 +363,6 @@ def fast_update(
                 )
                 data = []
                 counter = 0
-                seen = set()
         if data:
             rows_updated += update_from_values(
                 c, model._meta.db_table, pk_field, fields,
diff --git a/fast_update/query.py b/fast_update/query.py
index 10b00fd..f315789 100644
--- a/fast_update/query.py
+++ b/fast_update/query.py
@@ -10,28 +10,33 @@ def sanity_check(
     model: Type[Model],
     objs: Iterable[Model],
     fields: Iterable[str],
+    op: str,
     batch_size: Optional[int] = None
 ) -> None:
     # basic sanity checks (most taken from bulk_update)
     if batch_size is not None and batch_size < 0:
         raise ValueError('Batch size must be a positive integer.')
     if not fields:
-        raise ValueError('Field names must be given to fast_update().')
-    if any(obj.pk is None for obj in objs):
-        raise ValueError('All fast_update() objects must have a primary key set.')
+        raise ValueError(f'Field names must be given to {op}.')
+    pks = set(obj.pk for obj in objs)
+    if len(pks) < len(objs):
+        raise ValueError(f'{op} cannot update duplicates.')
+    if None in pks:
+        raise ValueError(f'All {op} objects must have a primary key set.')
     fields_ = [model._meta.get_field(name) for name in fields]
     if any(not f.concrete or f.many_to_many for f in fields_):
-        raise ValueError('fast_update() can only be used with concrete fields.')
+        raise ValueError(f'{op} can only be used with concrete fields.')
     if any(f.primary_key for f in fields_):
-        raise ValueError('fast_update() cannot be used with primary key fields.')
+        raise ValueError(f'{op} cannot be used with primary key fields.')
     for obj in objs:
+        # TODO: This is really heavy at runtime, is there an elegant way to speed it up?
         # TODO: django main has an additional argument 'fields' (saves some runtime?)
         obj._prepare_related_fields_for_save(operation_name='fast_update')
         # additionally raise on f-expression
         for field in fields_:
-            attr = getattr(obj, field.attname)
-            if hasattr(attr, 'resolve_expression'):
-                raise ValueError('fast_update() cannot be used with f-expressions.')
+            # TODO: use faster attrgetter
+            if hasattr(getattr(obj, field.attname), 'resolve_expression'):
+                raise ValueError(f'{op} cannot be used with f-expressions.')
 
 
 class FastUpdateQuerySet(QuerySet):
@@ -62,7 +67,7 @@ def fast_update(
             return 0
         objs = tuple(objs)
         fields_ = set(fields or [])
-        sanity_check(self.model, objs, fields_, batch_size)
+        sanity_check(self.model, objs, fields_, 'fast_update()', batch_size)
         return fast_update(self, objs, fields_, batch_size)
 
     fast_update.alters_data = True
@@ -111,7 +116,7 @@ def copy_update(
             return 0
         objs = tuple(objs)
         fields_ = set(fields or [])
-        sanity_check(self.model, objs, fields_)
+        sanity_check(self.model, objs, fields_, 'copy_update()')
         return copy_update(self, objs, fields_, field_encoders, encoding)
 
     copy_update.alters_data = True
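
Note on the behavior change (commentary, not part of the patch): fast_update() and copy_update() now reject changesets containing duplicate pks instead of silently applying only the first occurrence per batch. Callers that may produce duplicates should de-duplicate before calling. A minimal caller-side sketch follows; the "last write wins" policy and the Item model are assumptions for illustration only:

```python
from typing import Iterable, List
from django.db.models import Model

def dedup_changeset(objs: Iterable[Model]) -> List[Model]:
    # keep only the last occurrence per pk ("last write wins"),
    # so the duplicate check in sanity_check() no longer raises;
    # objects with pk=None still raise, as before
    by_pk = {o.pk: o for o in objs}  # later duplicates overwrite earlier ones
    return list(by_pk.values())

# usage with a hypothetical model Item:
# Item.objects.fast_update(dedup_changeset(changeset), ['name'])
```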