Skip to content

Commit

Permalink
better placeholder handling
Browse files Browse the repository at this point in the history
  • Loading branch information
jerch committed Apr 6, 2022
1 parent cb76f02 commit 0d6342d
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 84 deletions.
124 changes: 41 additions & 83 deletions fast_update/fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def get_vendor(conn: BaseDatabaseWrapper) -> str:
if (major == 3 and minor > 32) or major > 3:
SEEN_CONNECTIONS[conn] = 'sqlite'
return 'sqlite'
else:
else: # pragma: no cover
logger.warning('unsupported sqlite backend, fast_update will fall back to bulk_update')
SEEN_CONNECTIONS[conn] = ''
return ''
Expand All @@ -62,11 +62,10 @@ def get_vendor(conn: BaseDatabaseWrapper) -> str:
c.execute("SELECT column_1 FROM (VALUES ROW(1, 'zzz'), ROW(2, 'yyy')) as foo")
SEEN_CONNECTIONS[conn] = 'mysql8'
return 'mysql8'
except ProgrammingError:
pass
logger.warning('unsupported mysql backend, fast_update will fall back to bulk_update')
SEEN_CONNECTIONS[conn] = ''
return ''
except ProgrammingError: # pragma: no cover
logger.warning('unsupported mysql backend, fast_update will fall back to bulk_update')
SEEN_CONNECTIONS[conn] = ''
return ''

logger.warning('unsupported db backend, fast_update will fall back to bulk_update')
SEEN_CONNECTIONS[conn] = ''
Expand All @@ -83,10 +82,10 @@ def as_postgresql(
tname: str,
pkname: str,
fields: Sequence[Field],
rows: List[str],
count: int,
compiler: SQLCompiler,
connection: BaseDatabaseWrapper,
row_placeholder: Optional[List[str]] = None
connection: BaseDatabaseWrapper
) -> str:
"""
Uses UPDATE FROM VALUES with column aliasing.
Expand All @@ -96,11 +95,7 @@ def as_postgresql(
"""
dname = 'd' if tname != 'd' else 'c'
cols = ','.join(f'"{f.column}"={pq_cast(dname, f, compiler, connection)}' for f in fields)
if row_placeholder:
values = ','.join(row_placeholder)
else:
value = f'({",".join(["%s"] * (len(fields) + 1))})'
values = ','.join([value] * count)
values = ','.join(rows)
dcols = f'"{pkname}",' + ','.join(f'"{f.column}"' for f in fields)
where = f'"{tname}"."{pkname}"="{dname}"."{pkname}"'
return (
Expand All @@ -113,10 +108,10 @@ def as_sqlite(
tname: str,
pkname: str,
fields: Sequence[Field],
rows: List[str],
count: int,
compiler: SQLCompiler,
connection: BaseDatabaseWrapper,
row_placeholder: Optional[List[str]] = None
connection: BaseDatabaseWrapper
) -> str:
"""
sqlite >= 3.32 implements basic UPDATE FROM VALUES support following postgres' syntax.
Expand All @@ -126,11 +121,7 @@ def as_sqlite(
dname = 'd' if tname != 'd' else 'c'
cols = ','.join(f'"{f.column}"="{dname}"."column{i + 2}"'
for i, f in enumerate(fields))
if row_placeholder:
values = ','.join(row_placeholder)
else:
value = f'({",".join(["%s"] * (len(fields) + 1))})'
values = ','.join([value] * count)
values = ','.join(rows)
where = f'"{tname}"."{pkname}"="{dname}"."column1"'
return f'UPDATE "{tname}" SET {cols} FROM (VALUES {values}) AS "{dname}" WHERE {where}'

Expand All @@ -139,10 +130,10 @@ def as_mysql(
tname: str,
pkname: str,
fields: Sequence[Field],
rows: List[str],
count: int,
compiler: SQLCompiler,
connection: BaseDatabaseWrapper,
row_placeholder: Optional[List[str]] = None
connection: BaseDatabaseWrapper
) -> str:
"""
For mariadb we use TVC, introduced in 10.3.3.
Expand All @@ -156,17 +147,14 @@ def as_mysql(
with an offset by 1 select.
"""
dname = 'd' if tname != 'd' else 'c'
temp = 'temp1' if tname != 'temp1' else 'temp2'
cols = ','.join(f'`{tname}`.`{f.column}`={dname}.{i+1}' for i, f in enumerate(fields))
if row_placeholder:
values = ','.join(row_placeholder)
else:
value = f'({",".join(["%s"] * (len(fields) + 1))})'
values = ",".join([value] * (count + 1))
# mysql only: prepend placeholders for additional (0,1,2,...) row
values = ','.join([f'({",".join(["%s"] * (len(fields) + 1))})'] + rows)
where = f'`{tname}`.`{pkname}` = {dname}.0'
# FIXME: need collision check of 'temp' against tname?
return (
f'UPDATE `{tname}`, '
f'(SELECT * FROM (VALUES {values}) AS temp LIMIT {count} OFFSET 1) AS {dname} '
f'(SELECT * FROM (VALUES {values}) AS {temp} LIMIT {count} OFFSET 1) AS {dname} '
f'SET {cols} WHERE {where}'
)

Expand All @@ -175,10 +163,10 @@ def as_mysql8(
tname: str,
pkname: str,
fields: Sequence[Field],
rows: List[str],
count: int,
compiler: SQLCompiler,
connection: BaseDatabaseWrapper,
row_placeholder: Optional[List[str]] = None
connection: BaseDatabaseWrapper
) -> str:
"""
For MySQL we use the extended VALUES statement, introduced in MySQL 8.0.19.
Expand All @@ -188,11 +176,7 @@ def as_mysql8(
"""
dname = 'd' if tname != 'd' else 'c'
cols = ','.join(f'`{f.column}`={dname}.column_{i+1}' for i, f in enumerate(fields))
if row_placeholder:
values = ','.join('ROW' + r for r in row_placeholder)
else:
value = f'ROW({",".join(["%s"] * (len(fields) + 1))})'
values = ",".join([value] * count)
values = ','.join('ROW' + r for r in rows)
on = f'`{tname}`.`{pkname}` = {dname}.column_0'
return f'UPDATE `{tname}` INNER JOIN (VALUES {values}) AS {dname} ON {on} SET {cols}'

Expand All @@ -206,22 +190,6 @@ def as_mysql8(
}


def row_placeholder(
fields: List[Field],
data: List[Any],
comp: SQLCompiler,
conn: BaseDatabaseWrapper
) -> str:
"""
Generate value placeholders from custom field placeholder functions for given data.
"""
# TODO: prelayout get_placeholder functions to avoid looped checks
placeholders = ','.join(
f.get_placeholder(v, comp, conn) if hasattr(f, 'get_placeholder') else '%s'
for f, v in zip(fields, data))
return f'({placeholders})'


def update_from_values(
c: CursorWrapper,
vendor: str,
Expand All @@ -230,37 +198,30 @@ def update_from_values(
fields: List[Field],
counter: int,
data: List[Any],
has_placeholders: bool,
compiler: SQLCompiler,
connection: BaseDatabaseWrapper
) -> int:
"""
Generate vendor specific sql statement and execute it for given data.
"""
if has_placeholders:
# A custom field placeholder is currently only used by django's BinaryField
# (and in fact only needed for mysql). Since the placeholder interface works
# on data value level and does not allow backend introspection,
# we switch to placeholder mode on the first field with a custom placeholder
# for all backends.
# While this penalizes processing speed alot, it should be safer in general.
row_fields = [pk_field] + fields
row_length = len(row_fields)
values_ph = []
if vendor == 'mysql':
# mysql only: prepend (%s,%s...) for (0,1,2,...) data patch as first row
values_ph.append(f'({",".join(["%s"] * row_length)})')
for i in range(0, len(data), row_length):
values_ph.append(
row_placeholder(row_fields, data[i : i + row_length], compiler, connection)
)
sql = QUERY[vendor](
tname, pk_field.column, fields,
counter, compiler, connection, values_ph
)
else:
# non custom placeholder based faster construction path
sql = QUERY[vendor](tname, pk_field.column, fields, counter, compiler, connection)
# The following placeholder calc is quite cumbersome:
# For fast processing we approach the data col-based (90° turned)
# to save runtime for funcion pointer juggling for every single row
# which is ~90% faster than a more direct row-based evaluation.
# This is still alot slower than direct flat layouting, but not significant
# anymore for the total runtime (<<1%, a row-based approach takes 3-7%).
row_fields = [pk_field] + fields
row_length = len(row_fields)
default_placeholder = ['%s'] * counter
col_placeholders = [
([
f.get_placeholder(data[i], compiler, connection)
for i in range(pos, len(data), row_length)
] if hasattr(f, 'get_placeholder') else default_placeholder)
for pos, f in enumerate(row_fields)
]
rows = [f'({",".join(row)})' for row in zip(*col_placeholders)]
sql = QUERY[vendor](tname, pk_field.column, fields, rows, counter, compiler, connection)
if vendor == 'mysql':
# mysql only: prepend (0,1,2,...) as first row
data = list(range(len(fields) + 1)) + data
Expand Down Expand Up @@ -302,10 +263,7 @@ def fast_update(
pk_field = model._meta.pk
get = attrgetter(pk_field.attname, *(f.attname for f in fields))
prep_save = [pk_field.get_db_prep_save] + [f.get_db_prep_save for f in fields]
has_placeholders = any(hasattr(f, 'get_placeholder') for f in fields)
compiler = None
if vendor == 'postgresql' or has_placeholders:
compiler = models.sql.UpdateQuery(model).get_compiler(conn.alias)
compiler = models.sql.UpdateQuery(model).get_compiler(conn.alias)

rows_updated = 0
with transaction.atomic(using=conn.alias, savepoint=False):
Expand All @@ -319,14 +277,14 @@ def fast_update(
if counter >= batch_size:
rows_updated += update_from_values(
c, vendor, model._meta.db_table, pk_field, fields,
counter, data, has_placeholders, compiler, conn
counter, data, compiler, conn
)
data = []
counter = 0
if data:
rows_updated += update_from_values(
c, vendor, model._meta.db_table, pk_field, fields,
counter, data, has_placeholders, compiler, conn
counter, data, compiler, conn
)

# handle remaining non local fields (done by bulk_update for now)
Expand Down
2 changes: 1 addition & 1 deletion fast_update/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ def sanity_check(model, objs, fields, batch_size):
raise ValueError('Batch size must be a positive integer.')
if not fields:
raise ValueError('Field names must be given to fast_update().')
objs = tuple(objs)
if not objs:
return 0
if any(obj.pk is None for obj in objs):
Expand Down Expand Up @@ -43,6 +42,7 @@ def fast_update(
"""
TODO...
"""
objs = tuple(objs)
sanity_check(self.model, objs, fields, batch_size)
return fast_update(self, objs, fields, batch_size)

Expand Down

0 comments on commit 0d6342d

Please sign in to comment.