Commit e4b2cc5: Merge 4ddad12 into 66719bf
jerch committed Apr 30, 2022 (2 parents: 66719bf + 4ddad12)

Showing 4 changed files with 95 additions and 96 deletions.
47 changes: 14 additions & 33 deletions example/exampleapp/tests.py
@@ -288,7 +288,7 @@ def test_sanity_checks(self):
         self.assertRaisesMessage(
             ValueError,
             'All fast_update() objects must have a primary key set.',
-            lambda : FieldUpdate.objects.fast_update([FieldUpdate(**EXAMPLE), FieldUpdate(**EXAMPLE)], ['f_char'])
+            lambda : FieldUpdate.objects.fast_update([FieldUpdate(**EXAMPLE)], ['f_char'])
         )
         # non concrete field
         mbase = MultiBase.objects.create()
@@ -305,37 +305,18 @@ def test_sanity_checks(self):
         )


-class TestFilterDuplicates(TestCase):
-    def test_apply_first_duplicate_only(self):
+class TestDuplicates(TestCase):
+    def test_raise_on_duplicates(self):
         a = FieldUpdate.objects.create()
-        updated = FieldUpdate.objects.fast_update([
-            FieldUpdate(pk=a.pk, **EXAMPLE), # all values are trueish
-            FieldUpdate(pk=a.pk) # all values None
-        ], FIELDS)
-        # only 1 row updated
-        self.assertEqual(updated, 1)
-        v = FieldUpdate.objects.all().values().first()
-        # all values should be trueish
-        self.assertEqual(all(e for e in v.values()), True)
+        with self.assertRaisesMessage(ValueError, 'cannot update duplicates'):
+            FieldUpdate.objects.fast_update([
+                FieldUpdate(pk=a.pk),
+                FieldUpdate(pk=a.pk)
+            ], FIELDS)

-    def test_multiple_duplicates(self):
-        a = FieldUpdate.objects.create()
-        b = FieldUpdate.objects.create()
-        c = FieldUpdate.objects.create()
-        updated = FieldUpdate.objects.fast_update([
-            FieldUpdate(pk=a.pk, **EXAMPLE), # all values are trueish
-            FieldUpdate(pk=a.pk), # all values None
-            FieldUpdate(pk=a.pk),
-            FieldUpdate(pk=b.pk, **EXAMPLE), # all values are trueish
-            FieldUpdate(pk=a.pk),
-            FieldUpdate(pk=a.pk),
-            FieldUpdate(pk=b.pk),
-            FieldUpdate(pk=c.pk, **EXAMPLE) # all values are trueish
-        ], FIELDS)
-        # 3 rows updated
-        self.assertEqual(updated, 3)
-        v = list(FieldUpdate.objects.all().values())
-        # all values should be trueish
-        self.assertEqual(all(e for e in v[0].values()), True)
-        self.assertEqual(all(e for e in v[1].values()), True)
-        self.assertEqual(all(e for e in v[2].values()), True)
+    def test_no_pk_duplicates(self):
+        with self.assertRaisesMessage(ValueError, 'cannot update duplicates'):
+            FieldUpdate.objects.fast_update([
+                FieldUpdate(),
+                FieldUpdate()
+            ], FIELDS)
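
The rewritten tests pin down the changed contract: duplicate pks in a changeset now raise up front instead of silently applying the first occurrence. A minimal sketch of the new behavior, assuming the FieldUpdate model and FIELDS list from this test module:

    a = FieldUpdate.objects.create()
    try:
        FieldUpdate.objects.fast_update(
            [FieldUpdate(pk=a.pk), FieldUpdate(pk=a.pk)],  # duplicate pk
            FIELDS
        )
    except ValueError as e:
        print(e)  # fast_update() cannot update duplicates.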
18 changes: 14 additions & 4 deletions example/postgres_tests/tests.py
@@ -436,11 +436,21 @@ def test_updatefull_multiple(self):
         for r in results[1:]:
             for f in CU_FIELDS:
                 self.assertEqual(r[f], first[f])
-        # force threaded write
-        update_c = update_b * 100
-        FieldUpdateNotNull.objects.copy_update(update_c, CU_FIELDS)
+
+    def test_updatefull_multiple_threaded(self):
+        objs = []
+        for _ in range(10000):
+            objs.append(FieldUpdateNotNull())
+        FieldUpdateNotNull.objects.bulk_create(objs)
+        changeset = []
+        for o in objs:
+            changeset.append(FieldUpdateNotNull(pk=o.pk, **CU_EXAMPLE))
+
+        # force copy_update to use threaded logic due to payload >>64kB
+        FieldUpdateNotNull.objects.copy_update(changeset, CU_FIELDS)
         results = list(FieldUpdateNotNull.objects.all().values(*CU_FIELDS))
-        for r in results[201:]:
+        first = results[0]
+        for r in results[1:]:
             for f in CU_FIELDS:
                 self.assertEqual(r[f], first[f])

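
Note that the old test forced the threaded write via update_c = update_b * 100, which repeats every pk 100 times; under the new duplicate check that would raise, so the replacement test builds a changeset in which each pk occurs exactly once. A condensed sketch of the pattern it exercises, assuming the FieldUpdateNotNull model and the CU_EXAMPLE/CU_FIELDS fixtures from this test module:

    # bulk_create returns the created objects (pks are set on postgres)
    objs = FieldUpdateNotNull.objects.bulk_create(
        [FieldUpdateNotNull() for _ in range(10000)]
    )
    changeset = [FieldUpdateNotNull(pk=o.pk, **CU_EXAMPLE) for o in objs]
    # 10000 rows push the payload well past 64kB, forcing the threaded copy path
    FieldUpdateNotNull.objects.copy_update(changeset, CU_FIELDS)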
101 changes: 52 additions & 49 deletions fast_update/fast.py
@@ -15,7 +15,7 @@
 """
 DB vendor low level interfaces

-To register a fast update implementations, call:
+To register fast update implementations, call:

     register_implementation('vendor', check_function)
@@ -70,44 +70,7 @@ def prepare_data_xy(
     """


-# memorize fast_update vendor on connection object
-SEEN_CONNECTIONS = cast(Dict[BaseDatabaseWrapper, str], WeakKeyDictionary())
-CHECKER = {}
-
-def register_implementation(
-    vendor: str,
-    func: Callable[[BaseDatabaseWrapper], Tuple[Any]]
-) -> None:
-    """
-    Register fast update implementation for db vendor.
-    `vendor` is the vendor name as returned by `connection.vendor`.
-    `func` is a lazy called function to check support for a certain
-    implementation at runtime for `connection`. The function should return
-    a tuple of (create_sql, prepare_data | None) for supported backends,
-    otherwise an empty tuple (needed to avoid re-eval).
-    """
-    CHECKER[vendor] = func
-
-
-def get_impl(conn: BaseDatabaseWrapper) -> str:
-    """
-    Try to get a fast update implementation for `conn`.
-    Calls once the the check function of `register_implementation` and
-    memorizes its result for `conn`.
-    Returns a tuple (create_sql, prepare_data | None) for supported backends,
-    otherwise an empty tuple.
-    """
-    impl = SEEN_CONNECTIONS.get(conn)
-    if impl is not None:
-        return impl
-    check = CHECKER.get(conn.vendor)
-    if not check: # pragma: no cover
-        SEEN_CONNECTIONS[conn] = tuple()
-        return tuple()
-    impl = check(conn) or tuple() # NOTE: in case check returns something nullish
-    SEEN_CONNECTIONS[conn] = impl
-    return impl
+# Fast update implementations for postgres, sqlite and mysql.


 def pq_cast(tname: str, field: Field, compiler: SQLCompiler, connection: Any) -> str:
@@ -243,14 +206,57 @@ def as_mysql(
     )


-# Register our default db implementations.
+# Implementation registry.
+
+
+# memorize fast_update vendor on connection object
+SEEN_CONNECTIONS = cast(Dict[BaseDatabaseWrapper, str], WeakKeyDictionary())
+CHECKER = {}
+
+def register_implementation(
+    vendor: str,
+    func: Callable[[BaseDatabaseWrapper], Tuple[Any]]
+) -> None:
+    """
+    Register fast update implementation for db vendor.
+    `vendor` is the vendor name as returned by `connection.vendor`.
+    `func` is a lazy called function to check support for a certain
+    implementation at runtime for `connection`. The function should return
+    a tuple of (create_sql, prepare_data | None) for supported backends,
+    otherwise an empty tuple (needed to avoid re-eval).
+    """
+    CHECKER[vendor] = func
+
+
+def get_impl(conn: BaseDatabaseWrapper) -> str:
+    """
+    Try to get a fast update implementation for `conn`.
+    Calls once the check function of `register_implementation` and
+    memorizes its result for `conn`.
+    Returns a tuple (create_sql, prepare_data | None) for supported backends,
+    otherwise an empty tuple.
+    """
+    impl = SEEN_CONNECTIONS.get(conn)
+    if impl is not None:
+        return impl
+    check = CHECKER.get(conn.vendor)
+    if not check: # pragma: no cover
+        SEEN_CONNECTIONS[conn] = tuple()
+        return tuple()
+    impl = check(conn) or tuple() # NOTE: in case check returns something nullish
+    SEEN_CONNECTIONS[conn] = impl
+    return impl
+
+
+# Register default db implementations from above.
 register_implementation(
     'postgresql',
     lambda _: (as_postgresql, None)
 )
 register_implementation(
     'sqlite',
-    # NOTE: check function does not handle versions <2.15 anymore
+    # NOTE: check function does not handle versions <3.15 anymore
     lambda conn: (as_sqlite, None) if conn.Database.sqlite_version_info >= (3, 33)
     else (as_sqlite_cte, prepare_data_sqlite_cte)
 )
@@ -260,6 +266,9 @@ def as_mysql(
 )


+# Update implementation.
+
+
 def update_from_values(
     c: CursorWrapper,
     tname: str,
@@ -344,22 +353,16 @@ def fast_update(
     with conn.cursor() as c:
         data = []
         counter = 0
-        seen = set()
         for o in objs:
-            row = [p(v, conn) for p, v in zip(prep_save, get(o))]
-            # filter for first batch occurence
-            if row[0] not in seen:
-                counter += 1
-                data += row
-                seen.add(row[0])
+            counter += 1
+            data += [p(v, conn) for p, v in zip(prep_save, get(o))]
             if counter >= batch_size_adjusted:
                 rows_updated += update_from_values(
                     c, model._meta.db_table, pk_field, fields,
                     counter, data, compiler, conn
                 )
                 data = []
                 counter = 0
-                seen = set()
         if data:
             rows_updated += update_from_values(
                 c, model._meta.db_table, pk_field, fields,
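
The registry section moved below the vendor implementations, so the default registrations can reference as_postgresql, as_sqlite and as_mysql directly instead of sitting above their definitions. Per the docstring, a check function returns (create_sql, prepare_data | None) for a supported backend and an empty tuple otherwise. A hypothetical sketch of registering a further vendor (the 'cockroachdb' vendor name is an assumption for illustration, not something this commit ships):

    from fast_update.fast import register_implementation, as_postgresql

    register_implementation(
        'cockroachdb',
        # assume postgres-compatible SQL; no extra data preparation step needed
        lambda conn: (as_postgresql, None)
    )

The corrected sqlite NOTE ties into the version gate above it: SQLite only gained UPDATE ... FROM in 3.33, so the plain as_sqlite implementation requires 3.33+, while older versions fall back to the CTE-based as_sqlite_cte variant.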
25 changes: 15 additions & 10 deletions fast_update/query.py
@@ -10,28 +10,33 @@ def sanity_check(
     model: Type[Model],
     objs: Iterable[Model],
     fields: Iterable[str],
+    op: str,
     batch_size: Optional[int] = None
 ) -> None:
     # basic sanity checks (most taken from bulk_update)
     if batch_size is not None and batch_size < 0:
         raise ValueError('Batch size must be a positive integer.')
     if not fields:
-        raise ValueError('Field names must be given to fast_update().')
-    if any(obj.pk is None for obj in objs):
-        raise ValueError('All fast_update() objects must have a primary key set.')
+        raise ValueError(f'Field names must be given to {op}.')
+    pks = set(obj.pk for obj in objs)
+    if len(pks) < len(objs):
+        raise ValueError(f'{op} cannot update duplicates.')
+    if None in pks:
+        raise ValueError(f'All {op} objects must have a primary key set.')
     fields_ = [model._meta.get_field(name) for name in fields]
     if any(not f.concrete or f.many_to_many for f in fields_):
-        raise ValueError('fast_update() can only be used with concrete fields.')
+        raise ValueError(f'{op} can only be used with concrete fields.')
     if any(f.primary_key for f in fields_):
-        raise ValueError('fast_update() cannot be used with primary key fields.')
+        raise ValueError(f'{op} cannot be used with primary key fields.')
     for obj in objs:
         # TODO: This is really heavy in the runtime books, any elegant way to speedup?
         # TODO: django main has an additional argument 'fields' (saves some runtime?)
         obj._prepare_related_fields_for_save(operation_name='fast_update')
         # additionally raise on f-expression
         for field in fields_:
-            attr = getattr(obj, field.attname)
-            if hasattr(attr, 'resolve_expression'):
-                raise ValueError('fast_update() cannot be used with f-expressions.')
+            # TODO: use faster attrgetter
+            if hasattr(getattr(obj, field.attname), 'resolve_expression'):
+                raise ValueError(f'{op} cannot be used with f-expressions.')


 class FastUpdateQuerySet(QuerySet):
@@ -62,7 +67,7 @@ def fast_update(
             return 0
         objs = tuple(objs)
         fields_ = set(fields or [])
-        sanity_check(self.model, objs, fields_, batch_size)
+        sanity_check(self.model, objs, fields_, 'fast_update()', batch_size)
         return fast_update(self, objs, fields_, batch_size)

     fast_update.alters_data = True
@@ -111,7 +116,7 @@ def copy_update(
             return 0
         objs = tuple(objs)
         fields_ = set(fields or [])
-        sanity_check(self.model, objs, fields_)
+        sanity_check(self.model, objs, fields_, 'copy_update()')
         return copy_update(self, objs, fields_, field_encoders, encoding)

     copy_update.alters_data = True
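
With the new op argument threaded through sanity_check(), both entry points report errors under their own name. A small sketch of the resulting message, with SomeModel and obj as placeholders:

    try:
        SomeModel.objects.copy_update([obj, obj], ['f_char'])  # same pk twice
    except ValueError as e:
        assert str(e) == 'copy_update() cannot update duplicates.'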
