From 7d727ada396395d80e340d164921063e337536fe Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?J=C3=B6rg=20Breitbart?=
Date: Mon, 11 Apr 2022 16:37:47 +0200
Subject: [PATCH] prepare for first beta

---
 README.md               |  39 ++--
 fast_update/__init__.py |   2 +-
 fast_update/copy.py     | 431 +++++++++++++++++++++++++++-------
 fast_update/fast.py     |  17 +-
 fast_update/query.py    |  90 ++++++---
 setup.py                |   4 +-
 6 files changed, 398 insertions(+), 185 deletions(-)

diff --git a/README.md b/README.md
index b9e79a6..879a2bb 100644
--- a/README.md
+++ b/README.md
@@ -6,12 +6,9 @@
 Faster db updates using `UPDATE FROM VALUES` sql variants.
 
-### fast_update ###
+### Installation & Usage ###
 
-`fast_update` is meant to be used as `bulk_update` replacement.
-
-
-#### Example Usage ####
+Run `pip install django-fast-update` and add `fast_update` to INSTALLED_APPS.
 
 After attaching `FastUpdateManager` as a manager to your model, `fast_update` can be used
 instead of `bulk_update`, e.g.:
@@ -31,15 +28,23 @@ class MyModel(models.Model):
 
 MyModel.objects.fast_update(bunch_of_instances, ['field_a', 'field_b', 'field_c'])
 ```
 
+Alternatively `fast.fast_update` can be used directly with a queryset as its first argument
+(warning: this skips most sanity checks for an up to 30% speed gain,
+so make sure not to feed it malformed data).
+
+
+### Compatibility ###
 
-#### Compatibility ####
+Since `fast_update` relies on UPDATE FROM VALUES SQL variants, it depends
+on fairly recent database versions. It is currently implemented for:
 
-`fast_update` is implemented for these database backends:
 - SQLite 3.33+
 - PostgreSQL
 - MariaDB 10.3.3+
 - MySQL 8.0.19+
 
+For unsupported database backends or outdated versions, `fast_update` will fall back to `bulk_update`.
+
 Note that f-expressions cannot be used with `fast_update`. This is a design
 decision to not penalize update performance with swiss-army-knife functionality.
 If you have f-expressions in your update data, consider re-grouping the update steps and update those
@@ -49,24 +54,28 @@ fields with `update` or `bulk_update` instead.
 ### copy_update ###
 
 This is a PostgreSQL-only update implementation based on `COPY FROM`. This runs even faster
-than `fast_update` for medium to big changesets.
+than `fast_update` for medium to big changesets (but tends to be slower than `fast_update` for <100 objects).
 
-Note that this will probably never leave the alpha/PoC-state, as psycopg3 brings great COPY support,
-which does a more secure value conversion and even runs faster in the C-version.
+`copy_update` follows the same interface idea as `bulk_update` and `fast_update`, minus a `batch_size`
+argument (data is always transferred in one big batch). It can be used likewise from the `FastUpdateManager`.
+`copy_update` also has no support for f-expressions.
 
-TODO - describe usage and limitations...
+**Note:** `copy_update` will probably never leave the alpha/PoC state, as psycopg3 brings great COPY support,
+which does more secure value conversion and runs even faster in its C version.
 
 ### Status ###
 
-Currently alpha, left to do:
-- finish `copy_update` (array null cascading, some tests)
-- some better docs
+Currently beta, still some TODOs left (including better docs).
+
+The whole package is tested with Django 3.2 & 4.0 on Python 3.8 & 3.10.
 
 ### Performance ###
 
-There is a management command in the example app testing performance of updates on the `FieldUpdate` model.
+There is a management command in the example app testing the performance of updates on the `FieldUpdate`
+model (`./manage.py perf`).
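+
+For a quick ad-hoc comparison of the different code paths, a little timing sketch
+like the following can be used (assuming `MyModel` from above and update-ready
+instances in `objs`; `copy_update` additionally needs PostgreSQL):
+
+```python
+import time
+
+def timed(label, fn):
+    start = time.time()
+    fn()
+    print(f'{label}: {time.time() - start:.2f}s')
+
+fields = ['field_a', 'field_b', 'field_c']
+timed('bulk_update', lambda: MyModel.objects.bulk_update(objs, fields))
+timed('fast_update', lambda: MyModel.objects.fast_update(objs, fields))
+timed('copy_update', lambda: MyModel.objects.copy_update(objs, fields))
+```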
+
+Here are some numbers from my laptop
+(tested with `settings.DEBUG=False`, db engines freshly bootstrapped from docker
+as mentioned in `settings.py`):

diff --git a/fast_update/__init__.py b/fast_update/__init__.py
index 156d6f9..b794fd4 100644
--- a/fast_update/__init__.py
+++ b/fast_update/__init__.py
@@ -1 +1 @@
-__version__ = '0.0.4'
+__version__ = '0.1.0'
diff --git a/fast_update/copy.py b/fast_update/copy.py
index 3a2041b..5f9a471 100644
--- a/fast_update/copy.py
+++ b/fast_update/copy.py
@@ -1,21 +1,37 @@
 import os
+import sys
 from threading import Thread
 from io import BytesIO
 from binascii import b2a_hex
 from operator import attrgetter
-from django.db import connections, transaction, models
-from typing import Any, Dict, Optional, Sequence
 from decimal import Decimal as Decimal
 from datetime import date, datetime, timedelta, time as dt_time
 from json import dumps
 from uuid import UUID
+from django.db import connections, transaction, models
+from django.db.models.fields.related import RelatedField
 from django.contrib.postgres.fields import (HStoreField, ArrayField, IntegerRangeField,
     BigIntegerRangeField, DecimalRangeField, DateTimeRangeField, DateRangeField)
 from psycopg2.extras import Range
 
+# typings imports
+from django.db.backends.utils import CursorWrapper
+from typing import Any, BinaryIO, Callable, Dict, Generic, Iterable, List, Optional, Sequence, Tuple, Type, TypeVar, Union
+from typing import cast
+
+EncoderProto = TypeVar("EncoderProto", bound=Callable[[Any, str, List[Any]], Any])
 
-# TODO: typings, docs cleanup
-# TODO: document encoder interface
+if sys.version_info >= (3, 8):
+    from typing import Protocol
+    class FieldEncoder(Protocol[EncoderProto]):
+        array_escape: bool
+        __call__: EncoderProto
+else:
+    class FieldEncoder(Generic[EncoderProto]):
+        array_escape: bool
+        __call__: EncoderProto
+
+# TODO: proper typing for lazy encoders and lazy container
 
 
 # postgres connection encodings mapped to python
@@ -69,21 +85,142 @@
 # NULL placeholder for COPY FROM
 NULL = '\\N'
+# SQL NULL
+SQL_NULL = 'NULL'
 # lazy placeholder: NUL - not allowed in postgres data (cannot slip through)
 LAZY_PLACEHOLDER = '\x00'
 LAZY_PLACEHOLDER_BYTE = b'\x00'
 
 
-def text_escape(v):
+def text_escape(v: str) -> str:
     """
     Escape str-like data for postgres' TEXT format.
+
+    This does all basic TEXT escaping. For nested types such as hstore
+    or arrays, array_escape must be applied on top.
     """
     return (v.replace('\\', '\\\\')
         .replace('\b', '\\b').replace('\f', '\\f').replace('\n', '\\n')
         .replace('\r', '\\r').replace('\t', '\\t').replace('\v', '\\v'))
 
 
-def Int(v, fname, lazy):
+
+# Rules for nested types:
+# - a backslash in nested strings needs 2x2 escaping, e.g. \ --> \\ --> \\\\
+# - nested str values may need explicit quoting to avoid collisions
+#   --> simply always quote
+# - due to quoting, " needs \\" escape in nested strings
+# - null value is again sql NULL (handled in nesting encoders)
+
+
+def array_escape(v: Any) -> str:
+    """
+    Escape vulnerable str data for nested types as in arrays or hstore.
+
+    To meet the nested rules above, the data must already be escaped
+    by text_escape, e.g. \ --> text_escape --> \\ --> array_escape --> \\\\.
+    """
+    if not isinstance(v, str):
+        v = str(v)
+    return '"' + v.replace('\\\\', '\\\\\\\\').replace('"', '\\\\"') + '"'
+
+
+# TODO: move the interface description into proper place...
+"""
+Encoder Interface
+
+Field value encoding for postgres' TEXT format happens in 2 main steps:
+
+- string stage:
+  A string representation of a data row (e.g. for one model instance)
+  is constructed by f-string formatting from the return values of the field encoders.
+  A byte intensive encoder may place a ``LAZY_PLACEHOLDER`` in the row string
+  to indicate later data injection during the byte stage.
+  Once the row is finished, it gets encoded to bytes by the selected encoding.
+
+- byte stage:
+  The bytes are written in 64k chunks to a file object digested by
+  psycopg2's copy_from function, either directly or threaded for bigger payloads.
+  If an encoder marked its data as lazy, the encoder's lazy part gets called
+  and writes its data directly into the file object.
+
+Within the default encoders the lazy encoding is only used for binary fields to
+avoid costly back-and-forth unicode/bytes encoding of potentially big data.
+
+
+Field Encoder
+
+A field encoder is responsible for converting a python type into a string format
+that Postgres understands and that does not collide with the TEXT format.
+The basic encoder interface is
+
+def some_encoder(value: Any, fname: str, lazy: List[Any]) -> ReturnType: ...
+some_encoder.array_escape: bool = ...
+
+``value`` denotes the incoming field value as currently set on the model instance for
+the field ``fname``. ``lazy`` is a context helper object for lazy encoders
+(see Lazy Encoder below).
+
+The return type should be a string in postgres' TEXT format. For proper escaping
+the helper ``text_escape`` can be used. For None/nullish values ``NULL`` is predefined.
+It is also possible to return a python type directly, if it is known to translate
+correctly into the TEXT format from ``__str__`` output (some default encoders
+use this to speed up processing).
+
+The ``array_escape`` flag indicates whether additional escaping for
+array handling is needed (see ``array_escape`` function). Typically this is
+the case for any `text_escape`'d values. Besides those, the array syntax has
+further collisions for characters like comma or curly braces.
+
+For better typings support there is a decorator ``encoder``, which can be applied
+to the encoding function like this:
+
+@encoder
+def some_encoder(value: Any, fname: str, lazy: List[Any]) -> ReturnType: ...
+
+The decorator sets the ``array_escape`` flag to ``True`` by default.
+
+
+Lazy Encoder
+
+Lazy encoding can be achieved by returning a ``LAZY_PLACEHOLDER`` and
+appending a tuple like ``(callback, value)`` to ``lazy``. The callback is meant
+to write bytes directly to the file object and has the following interface:
+
+def lazy_callback(file: BinaryIO, value: Any) -> None: ...
+
+
+Encoder Registration
+
+On startup django's standard field types are globally mapped to default encoders.
+With ``register_fieldclass`` these encoders can be changed, or encoders can be
+registered for a custom field type.
+
+Furthermore encoders can be overridden for individual `copy_update` calls with
+``field_encoders``.
+
+
+Array Support
+
+The default encoder implementations do not hardcode array support, but use
+the factory function ``array_factory`` instead. The factory should be generic enough
+to be used with custom encoders as well. Currently the factory has to re-run for every
+array field, and the empty and balance checks are not optimised further yet,
+thus array fields currently carry some performance penalty.
+"""
+
+def encoder(func: EncoderProto) -> FieldEncoder[EncoderProto]:
+    """
+    Decorator for encoders, attaching the needed
+    array_escape attribute (default is True).
+ """ + f = cast(FieldEncoder[EncoderProto], func) + f.array_escape = True + return f + + +@encoder +def Int(v: Any, fname: str, lazy: List[Any]): """Test and pass along ``int``, raise for any other.""" if isinstance(v, int): return v @@ -91,7 +228,8 @@ def Int(v, fname, lazy): Int.array_escape = False -def IntOrNone(v, fname, lazy): +@encoder +def IntOrNone(v: Any, fname: str, lazy: List[Any]): """Same as ``Int``, additionally handling ``None`` as NULL.""" if v is None: return NULL @@ -101,7 +239,7 @@ def IntOrNone(v, fname, lazy): IntOrNone.array_escape = False -def _lazy_binary(f, v): +def _lazy_binary(f: BinaryIO, v: Union[memoryview, bytes]) -> None: length = len(v) if length <= 65536: f.write(b2a_hex(v)) @@ -112,7 +250,8 @@ def _lazy_binary(f, v): byte_pos += 65536 -def Binary(v, fname, lazy): +@encoder +def Binary(v: Any, fname: str, lazy: List[Any]): """ Test and pass along ``(memoryview, bytes)`` types, raise for any other. @@ -128,10 +267,10 @@ def Binary(v, fname, lazy): return '\\\\x' + LAZY_PLACEHOLDER return '\\\\x' + v.hex() raise TypeError(f'expected types {memoryview} or {bytes} for field "{fname}", got {type(v)}') -Binary.array_escape = True -def BinaryOrNone(v, fname, lazy): +@encoder +def BinaryOrNone(v: Any, fname: str, lazy: List[Any]): """Same as ``Binary``, additionally handling ``None`` as NULL.""" if v is None: return NULL @@ -141,10 +280,10 @@ def BinaryOrNone(v, fname, lazy): return '\\\\x' + LAZY_PLACEHOLDER return '\\\\x' + v.hex() raise TypeError(f'expected types {memoryview}, {bytes} or None for field "{fname}", got {type(v)}') -BinaryOrNone.array_escape = True -def Boolean(v, fname, lazy): +@encoder +def Boolean(v: Any, fname: str, lazy: List[Any]): """Test and pass along ``bool``, raise for any other.""" if isinstance(v, bool): return v @@ -152,7 +291,8 @@ def Boolean(v, fname, lazy): Boolean.array_escape = False -def BooleanOrNone(v, fname, lazy): +@encoder +def BooleanOrNone(v: Any, fname: str, lazy: List[Any]): """Same as ``Boolean``, additionally handling ``None`` as NULL.""" if v is None: return NULL @@ -162,7 +302,8 @@ def BooleanOrNone(v, fname, lazy): BooleanOrNone.array_escape = False -def Date(v, fname, lazy): +@encoder +def Date(v: Any, fname: str, lazy: List[Any]): """Test and pass along ``datetime.date``, raise for any other.""" if isinstance(v, date): return v @@ -170,7 +311,8 @@ def Date(v, fname, lazy): Date.array_escape = False -def DateOrNone(v, fname, lazy): +@encoder +def DateOrNone(v: Any, fname: str, lazy: List[Any]): """Same as ``Date``, additionally handling ``None`` as NULL.""" if v is None: return NULL @@ -180,25 +322,26 @@ def DateOrNone(v, fname, lazy): DateOrNone.array_escape = False -def Datetime(v, fname, lazy): +@encoder +def Datetime(v: Any, fname: str, lazy: List[Any]): """Test and pass along ``datetime``, raise for any other.""" if isinstance(v, datetime): return v raise TypeError(f'expected type {datetime} for field "{fname}", got {type(v)}') -Datetime.array_escape = True -def DatetimeOrNone(v, fname, lazy): +@encoder +def DatetimeOrNone(v: Any, fname: str, lazy: List[Any]): """Same as ``Datetime``, additionally handling ``None`` as NULL.""" if v is None: return NULL if isinstance(v, datetime): return v raise TypeError(f'expected type {datetime} or None for field "{fname}", got {type(v)}') -DatetimeOrNone.array_escape = True -def Numeric(v, fname, lazy): +@encoder +def Numeric(v: Any, fname: str, lazy: List[Any]): """Test and pass along ``Decimal``, raise for any other.""" if isinstance(v, Decimal): return v @@ -206,7 
 Numeric.array_escape = False
 
 
-def NumericOrNone(v, fname, lazy):
+@encoder
+def NumericOrNone(v: Any, fname: str, lazy: List[Any]):
     """Same as ``Numeric``, additionally handling ``None`` as NULL."""
     if v is None:
         return NULL
@@ -216,25 +360,26 @@
 NumericOrNone.array_escape = False
 
 
-def Duration(v, fname, lazy):
+@encoder
+def Duration(v: Any, fname: str, lazy: List[Any]):
     """Test and pass along ``timedelta``, raise for any other."""
     if isinstance(v, timedelta):
         return v
     raise TypeError(f'expected type {timedelta} for field "{fname}", got {type(v)}')
-Duration.array_escape = True
 
 
-def DurationOrNone(v, fname, lazy):
+@encoder
+def DurationOrNone(v: Any, fname: str, lazy: List[Any]):
     """Same as ``Duration``, additionally handling ``None`` as NULL."""
     if v is None:
         return NULL
     if isinstance(v, timedelta):
         return v
     raise TypeError(f'expected type {timedelta} or None for field "{fname}", got {type(v)}')
-DurationOrNone.array_escape = True
 
 
-def Float(v, fname, lazy):
+@encoder
+def Float(v: Any, fname: str, lazy: List[Any]):
     """Test and pass along ``float`` or ``int``, raise for any other."""
     if isinstance(v, (float, int)):
         return v
@@ -242,7 +387,8 @@ def Float(v, fname, lazy):
 Float.array_escape = False
 
 
-def FloatOrNone(v, fname, lazy):
+@encoder
+def FloatOrNone(v: Any, fname: str, lazy: List[Any]):
     """Same as ``Float``, additionally handling ``None`` as NULL."""
     if v is None:
         return NULL
@@ -253,17 +399,19 @@
 # TODO: test and document Json vs. JsonOrNone behavior (sql null vs. json null)
-def Json(v, fname, lazy):
+# TODO: document alternative json field encoder with orjson?
+@encoder
+def Json(v: Any, fname: str, lazy: List[Any]):
     """
     Default JSON encoder using ``json.dumps``.
 
     This version encodes ``None`` as json null value.
     """
     return text_escape(dumps(v))
-Json.array_escape = True
 
 
-def JsonOrNone(v, fname, lazy):
+@encoder
+def JsonOrNone(v: Any, fname: str, lazy: List[Any]):
     """
     Default JSON encoder using ``json.dumps``.
 
@@ -272,10 +420,10 @@
     if v is None:
         return NULL
     return text_escape(dumps(v))
-JsonOrNone.array_escape = True
 
 
-def Text(v, fname, lazy):
+@encoder
+def Text(v: Any, fname: str, lazy: List[Any]):
     """
     Test and encode ``str``, raise for any other.
@@ -285,20 +433,20 @@
     if isinstance(v, str):
         return text_escape(v)
     raise TypeError(f'expected type {str} for field "{fname}", got {type(v)}')
-Text.array_escape = True
 
 
-def TextOrNone(v, fname, lazy):
+@encoder
+def TextOrNone(v: Any, fname: str, lazy: List[Any]):
     """Same as ``Text``, additionally handling ``None`` as NULL."""
     if v is None:
         return NULL
     if isinstance(v, str):
         return text_escape(v)
     raise TypeError(f'expected type {str} or None for field "{fname}", got {type(v)}')
-TextOrNone.array_escape = True
 
 
-def Time(v, fname, lazy):
+@encoder
+def Time(v: Any, fname: str, lazy: List[Any]):
     """Test and pass along ``datetime.time``, raise for any other."""
     if isinstance(v, dt_time):
         return v
@@ -306,7 +454,8 @@ def Time(v, fname, lazy):
 Time.array_escape = False
 
 
-def TimeOrNone(v, fname, lazy):
+@encoder
+def TimeOrNone(v: Any, fname: str, lazy: List[Any]):
     """Same as ``Time``, additionally handling ``None`` as NULL."""
     if v is None:
         return NULL
@@ -316,7 +465,8 @@
 TimeOrNone.array_escape = False
 
 
-def Uuid(v, fname, lazy):
+@encoder
+def Uuid(v: Any, fname: str, lazy: List[Any]):
     """Test and pass along ``UUID``, raise for any other."""
     if isinstance(v, UUID):
         return v
@@ -324,7 +474,8 @@ def Uuid(v, fname, lazy):
 Uuid.array_escape = False
 
 
-def UuidOrNone(v, fname, lazy):
+@encoder
+def UuidOrNone(v: Any, fname: str, lazy: List[Any]):
     """Same as ``Uuid``, additionally handling ``None`` as NULL."""
     if v is None:
         return NULL
@@ -334,40 +485,8 @@
 UuidOrNone.array_escape = False
 
 
-"""
-Special handling of nested types
-
-Nested types behave way different in COPY FROM TEXT format,
-than on top level (kinda falling back on SQL syntax format).
-From the tests this applies to values in arrays and hstore
-(prolly applies to all custom composite types, not tested).
-
-Rules for nested types:
-- a backslash in nested strings needs 4x escaping, e.g. \ --> \\\\
-- nested str values may need explicit quoting --> always quoted
-- due to quoting, " needs \\" escape in nested strings
-- null value is again sql NULL
-"""
-def quote(v):
-    return '"' + v.replace('"', '\\\\"') + '"'
-
-def text_escape_nested(v):
-    """
-    Escape nested str-like data for postgres' TEXT format.
-    The nested variant is needed for array and hstore data
-    (prolly for any custom composite types, untested).
-    """
-    return (v.replace('\\', '\\\\\\\\')
-        .replace('\b', '\\b').replace('\f', '\\f').replace('\n', '\\n')
-        .replace('\r', '\\r').replace('\t', '\\t').replace('\v', '\\v'))
-
-SQL_NULL = 'NULL'
-
-def array_escape(v):
-    return '"' + v.replace('\\\\', '\\\\\\\\').replace('"', '\\\\"') + '"'
-
-
-def HStore(v, fname, lazy):
+@encoder
+def HStore(v: Any, fname: str, lazy: List[Any]):
     """
     HStore field encoder. Expects a ``dict`` as input type with
     str keys and str|None values. Any other types will raise.
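As an aside, custom field types can hook into the same encoder machinery via ``register_fieldclass``; a minimal sketch following the interface described above (``UpperCaseTextField`` and both encoder names are made up for illustration, not part of this patch):

```python
from fast_update.copy import NULL, encoder, register_fieldclass, text_escape

from myapp.fields import UpperCaseTextField  # hypothetical custom field class

@encoder
def UpperText(v, fname, lazy):
    """Encode str uppercased, raise for any other type."""
    if isinstance(v, str):
        return text_escape(v.upper())
    raise TypeError(f'expected type {str} for field "{fname}", got {type(v)}')

@encoder
def UpperTextOrNone(v, fname, lazy):
    """Same as UpperText, additionally handling None as NULL."""
    if v is None:
        return NULL
    return UpperText(v, fname, lazy)

# register globally, for null=False and null=True fields respectively
register_fieldclass(UpperCaseTextField, UpperText, UpperTextOrNone)
```

Both encoders keep the default ``array_escape = True`` set by the decorator, so the values would also be escaped correctly inside array columns.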
@@ -383,25 +502,34 @@
         if v is not None and not isinstance(v, str):
             raise TypeError(f'expected type {str} or None for values of field "{fname}"')
         parts.append(
-            f'{quote(text_escape_nested(k))}=>'
-            f'{SQL_NULL if v is None else quote(text_escape_nested(v))}'
+            f'{array_escape(text_escape(k))}=>'
+            f'{SQL_NULL if v is None else array_escape(text_escape(v))}'
         )
         return ','.join(parts)
     raise TypeError(f'expected type {dict} for field "{fname}", got {type(v)}')
-HStore.array_escape = True
 
 
-def HStoreOrNone(v, fname, lazy):
+@encoder
+def HStoreOrNone(v: Any, fname: str, lazy: List[Any]):
     """Same as ``Hstore``, additionally handling ``None`` as NULL."""
     if v is None:
         return NULL
     if isinstance(v, dict):
-        return HStore(v, fname, lazy)
+        parts = []
+        for k, v in v.items():
+            if not isinstance(k, str):
+                raise TypeError(f'expected type {str} for keys of field "{fname}"')
+            if v is not None and not isinstance(v, str):
+                raise TypeError(f'expected type {str} or None for values of field "{fname}"')
+            parts.append(
+                f'{array_escape(text_escape(k))}=>'
+                f'{SQL_NULL if v is None else array_escape(text_escape(v))}'
+            )
+        return ','.join(parts)
     raise TypeError(f'expected type {dict} or None for field "{fname}", got {type(v)}')
-HStoreOrNone.array_escape = True
 
 
-def range_factory(basetype, text_safe):
+def range_factory(basetype, text_safe) -> Tuple[FieldEncoder, FieldEncoder]:
     """
     Factory for range type encoders.
 
@@ -418,49 +546,51 @@ def range_factory(basetype, text_safe):
     Returns a tuple of (Range, RangeOrNone) encoders.
     """
-    def encode_range(v, fname, lazy):
+    @encoder
+    def encode_range(v: Any, fname: str, lazy: List[Any]):
         if isinstance(v, Range) and isinstance(v.lower, basetype) and isinstance(v.upper, basetype):
             return v if text_safe else text_escape(str(v))
         raise TypeError(f'expected type {basetype} for field "{fname}", got {type(v)}')
-    encode_range.array_escape = True
 
-    def encode_range_none(v, fname, lazy):
+
+    @encoder
+    def encode_range_none(v: Any, fname: str, lazy: List[Any]):
         if v is None:
             return NULL
         if isinstance(v, Range) and isinstance(v.lower, basetype) and isinstance(v.upper, basetype):
             return v if text_safe else text_escape(str(v))
         raise TypeError(f'expected type {basetype} or None for field "{fname}", got {type(v)}')
-    encode_range_none.array_escape = True
+
     return encode_range, encode_range_none
 
 
-"""
-Edge cases with arrays:
-- empty top level array works for both
-    select ARRAY[]::integer[];  --> {}
-    select '{}'::integer[];     --> {}
-
-- nested empty sub arrays
-    select ARRAY[ARRAY[ARRAY[]]]::integer[];  --> {}
-    select '{{{}}}'::integer[];  --> malformed array literal: "{{{}}}"
-    --> no direct aquivalent in text notation
-
-- complicated mixture with null
-    select ARRAY[null,ARRAY[ARRAY[],null]]::integer[];  --> {}
-    no direct aquivalent in text notation
-
-Is the ARRAY notation broken in postgres?
-
-Observations:
-- if all values are nullish + at least one empty sub array, enclosing array is set to empty (losing all inner info?)
-  --> cascades up to top level leaving an empty array {} with no dimension or inner info at all
-- if all values are nullish and there is no empty sub array, the values manifest as null with dimension set
-- ARRAY[] raises for unbalanced multidimension arrays, TEXT format returns nonsense syntax error
-
-The following 2 functions is_empty_array and is_balanced try to restore some of the ARRAY[] behavior.
-This happens to a rather high price of 2 additional deep scan of array values.
-A better implementation prolly could do that in one pass combined.
-"""
-def is_empty_array(v):
+# Edge cases with arrays:
+# - empty top level array works for both
+#     select ARRAY[]::integer[];  --> {}
+#     select '{}'::integer[];     --> {}
+#
+# - nested empty sub arrays
+#     select ARRAY[ARRAY[ARRAY[]]]::integer[];  --> {}
+#     select '{{{}}}'::integer[];  --> malformed array literal: "{{{}}}"
+#     --> no direct equivalent in text notation
+#
+# - complicated mixture with null
+#     select ARRAY[null,ARRAY[ARRAY[],null]]::integer[];  --> {}
+#     no direct equivalent in text notation
+#
+# Is the ARRAY notation broken in postgres?
+#
+# Observations:
+# - if all values are nullish + at least one empty sub array, the enclosing array is set to empty (losing all inner info?)
+#   --> cascades up to top level leaving an empty array {} with no dimension or inner info at all
+# - if all values are nullish and there is no empty sub array, the values manifest as null with dimension set
+# - ARRAY[] raises for unbalanced multidimension arrays, TEXT format returns a nonsense syntax error
+#
+# The following 2 functions is_empty_array and is_balanced try to restore some of the ARRAY[] behavior.
+# This comes at the rather high price of 2 additional deep scans of the array values.
+# A better implementation could probably do both in one combined pass.
+
+
+def is_empty_array(v: Any) -> bool:
     """
     Special handling of nullish array values, that reduce to a single {}.
@@ -476,19 +606,19 @@ def is_empty_array(v):
     return False
 
 
-def _balanced(v, depth, dim=0):
+def _balanced(v: Any, depth: int, dim: int = 0) -> Union[Tuple, None]:
     if not isinstance(v, (list, tuple)) or dim>=depth:
-        return
+        return None
     return tuple(_balanced(e, depth, dim+1) for e in v)
 
 
-def is_balanced(v, depth):
+def is_balanced(v: Any, depth: int) -> bool:
     """
     Check if array value is balanced over multiple dimensions.
     """
-    return len(set(_balanced(v, depth))) < 2
+    return len(set(_balanced(v, depth) or [])) < 2
 
 
-def array_factory(encoder, depth=1, null=False):
+def array_factory(encoder, depth=1, null=False) -> FieldEncoder:
     """
     Factory for array value encoder.
 
@@ -505,7 +635,7 @@ def array_factory(encoder, depth=1, null=False):
     ``depth`` is the max array dimensions, the encoder will try to descend into subarrays.
     ``null`` denotes whether the encoder should allow None at top level.
     """
-    def encode_array(v, fname, lazy, dim=0):
+    def _encode_array(v: Any, fname: str, lazy: List[Any], dim: int = 0) -> str:
         if not dim:
             # handle top level separately, as it differs for certain aspects:
             # - null respected (postgres always allows null in nested arrays)
@@ -535,10 +665,13 @@ def encode_array(v, fname, lazy, dim=0):
         #   - always at bottom (dim==depth)
         final = str(encoder(v, fname, lazy))
         return array_escape(final) if encoder.array_escape else final
+    encode_array = cast(FieldEncoder[Any], _encode_array)
+    encode_array.array_escape = False
     return encode_array
 
 
-ENCODERS = {
+# global field type -> encoder mapping
+ENCODERS: Dict[Type[models.Field], Tuple[FieldEncoder, FieldEncoder]] = {
     models.AutoField: (Int, IntOrNone),
     models.BigAutoField: (Int, IntOrNone),
     models.BigIntegerField: (Int, IntOrNone),
@@ -578,7 +711,11 @@
 }
 
 
-def register_fieldclass(field_cls, encoder, encoder_none=None):
+def register_fieldclass(
+    field_cls: Type[models.Field],
+    encoder: FieldEncoder,
+    encoder_none: Optional[FieldEncoder] = None
+) -> None:
     """
     Register a fieldclass globally with value encoders.
@@ -586,16 +723,16 @@
     ``encoder_none`` for fields with ``null=True``.
 
     If only one encoder is provided, it will be used for both field settings.
-    In that case make sure, that the encoder correctly translates ``None``.
+    In that case make sure that the encoder correctly translates None values.
     """
     ENCODERS[field_cls] = (encoder, encoder_none or encoder)
 
 
-def get_encoder(field, null=None):
+def get_encoder(field: models.Field, null: Optional[bool] = None) -> FieldEncoder:
     """Get registered encoder for field."""
     if null is None:
         null = field.null
-    if field.is_relation:
+    if isinstance(field, RelatedField):
         return get_encoder(field.target_field, null)
     if isinstance(field, ArrayField):
         # TODO: cache array encoder for later calls?
@@ -612,8 +749,8 @@
     raise NotImplementedError(f'no suitable encoder found for field {field}')
 
 
-def write_lazy(f, data, stack):
-    """Execute lazy value encoders."""
+def write_lazy(f: BinaryIO, data: bytearray, stack: List[Any]) -> None:
+    """Execute lazy field encoders."""
     m = memoryview(data)
     idx = 0
     for writer, byte_object in stack:
@@ -625,14 +762,31 @@
     f.write(m[idx:])
 
 
-def threaded_copy(cur, fr, tname, columns):
-    cur.copy_from(fr, tname, size=65536, columns=columns)
-
-
-def copy_from(c, tname, data, fnames, columns, get, encs, encoding):
+def threaded_copy(
+    c: CursorWrapper,
+    fr: BinaryIO,
+    tname: str,
+    columns: Tuple[str, ...]
+) -> None:
+    c.copy_from(fr, tname, size=65536, columns=columns)
+
+
+def copy_from(
+    c: CursorWrapper,
+    tname: str,
+    data: Sequence[models.Model],
+    fnames: Tuple[str, ...],
+    columns: Tuple[str, ...],
+    get: attrgetter,
+    encs: List[Any],
+    encoding: str
+) -> None:
+    """
+    Optimized call of cursor.copy_from with threading for bigger payloads.
+    """
     use_thread = False
     payload = bytearray()
-    lazy = []
+    lazy: List[Any] = []
     for o in data:
         payload += '\t'.join([
             f'{enc(el, fname, lazy)}'
@@ -689,20 +843,25 @@ def copy_from(c, tname, data, fnames, columns, get, encs, encoding):
     f.close()
 
 
-def create_columns(column_def):
+def create_columns(column_def: Iterable[Tuple[str, str]]) -> str:
     """
     Prepare columns for table create as follows:
     - types copied from target table
     - no indexes or constraints (no serial, no unique, no primary key etc.)
""" - return (",".join(f'{k} {v}' for k, v in column_def) + return (','.join(f'{k} {v}' for k, v in column_def) .replace('bigserial', 'bigint') .replace('smallserial', 'smallint') .replace('serial', 'integer') ) -def update_sql(tname, temp_table, pkname, copy_fields): +def update_sql( + tname: str, + temp_table: str, + pkname: str, + copy_fields: List[models.Field] +) -> str: cols = ','.join(f'"{f.column}"="{temp_table}"."{f.column}"' for f in copy_fields) where = f'"{tname}"."{pkname}"="{temp_table}"."{pkname}"' return f'UPDATE "{tname}" SET {cols} FROM "{temp_table}" WHERE {where}' @@ -711,7 +870,7 @@ def update_sql(tname, temp_table, pkname, copy_fields): def copy_update( qs: models.QuerySet, objs: Sequence[models.Model], - fieldnames: Sequence[str], + fieldnames: Iterable[str], field_encoders: Optional[Dict[str, Any]] = None, encoding: Optional[str] = None ) -> int: diff --git a/fast_update/fast.py b/fast_update/fast.py index 9cf6db7..450ca39 100644 --- a/fast_update/fast.py +++ b/fast_update/fast.py @@ -1,4 +1,4 @@ -from weakref import WeakKeyDictionary +from weakref import WeakKeyDictionary, ReferenceType from django.db import transaction, models, connections from django.db.utils import ProgrammingError from django.db.models.functions import Cast @@ -8,7 +8,7 @@ # typing imports from django.db.models import Field -from typing import List, Optional, Sequence, Any, Union +from typing import Dict, Iterable, List, Optional, Sequence, Any, Union, cast from django.db.models.sql.compiler import SQLCompiler from django.db.backends.utils import CursorWrapper from django.db.backends.base.base import BaseDatabaseWrapper @@ -18,7 +18,7 @@ # memorize fast_update vendor on connection object -SEEN_CONNECTIONS = WeakKeyDictionary() +SEEN_CONNECTIONS = cast(Dict[BaseDatabaseWrapper, str], WeakKeyDictionary()) def get_vendor(conn: BaseDatabaseWrapper) -> str: @@ -40,7 +40,8 @@ def get_vendor(conn: BaseDatabaseWrapper) -> str: return 'postgresql' if conn.vendor == 'sqlite': - major, minor, _ = conn.Database.sqlite_version_info + _conn = cast(Any, conn) + major, minor, _ = _conn.Database.sqlite_version_info if (major == 3 and minor > 32) or major > 3: SEEN_CONNECTIONS[conn] = 'sqlite' return 'sqlite' @@ -74,7 +75,7 @@ def get_vendor(conn: BaseDatabaseWrapper) -> str: def pq_cast(tname: str, field: Field, compiler: SQLCompiler, connection: Any) -> str: """Column type cast for postgres.""" - # FIXME: compare to as_postgresql in v4 + # TODO: compare to as_postgresql in v4 return Cast(Col(tname, field), output_field=field).as_sql(compiler, connection)[0] @@ -232,7 +233,7 @@ def update_from_values( def fast_update( qs: models.QuerySet, objs: Sequence[models.Model], - fieldnames: Sequence[str], + fieldnames: Iterable[str], batch_size: Union[int, None] ) -> int: qs._for_write = True @@ -258,7 +259,7 @@ def fast_update( # prepare all needed arguments for update max_batch_size = conn.ops.bulk_batch_size(['pk'] + local_fieldnames, objs) - batch_size = min(batch_size or 2 ** 31, max_batch_size) + batch_size_adjusted = min(batch_size or 2 ** 31, max_batch_size) fields = [model._meta.get_field(f) for f in local_fieldnames] pk_field = model._meta.pk get = attrgetter(pk_field.attname, *(f.attname for f in fields)) @@ -274,7 +275,7 @@ def fast_update( for o in objs: counter += 1 data += [p(v, conn) for p, v in zip(prep_save, get(o))] - if counter >= batch_size: + if counter >= batch_size_adjusted: rows_updated += update_from_values( c, vendor, model._meta.db_table, pk_field, fields, counter, data, compiler, conn diff 
diff --git a/fast_update/query.py b/fast_update/query.py
index 750f9b9..10b00fd 100644
--- a/fast_update/query.py
+++ b/fast_update/query.py
@@ -1,74 +1,118 @@
 from django.db import connections
 from django.db.models import QuerySet, Model, Manager
 from django.db.utils import NotSupportedError
-from typing import Any, Dict, Optional, Sequence
+from typing import Any, Dict, Iterable, Optional, Sequence, Type
 from .fast import fast_update
 
 
-def sanity_check(model, objs, fields, batch_size):
+def sanity_check(
+    model: Type[Model],
+    objs: Iterable[Model],
+    fields: Iterable[str],
+    batch_size: Optional[int] = None
+) -> None:
     # basic sanity checks (most taken from bulk_update)
     if batch_size is not None and batch_size < 0:
         raise ValueError('Batch size must be a positive integer.')
     if not fields:
         raise ValueError('Field names must be given to fast_update().')
-    if not objs:
-        return 0
     if any(obj.pk is None for obj in objs):
         raise ValueError('All fast_update() objects must have a primary key set.')
-    fields = [model._meta.get_field(name) for name in fields]
-    if any(not f.concrete or f.many_to_many for f in fields):
+    fields_ = [model._meta.get_field(name) for name in fields]
+    if any(not f.concrete or f.many_to_many for f in fields_):
         raise ValueError('fast_update() can only be used with concrete fields.')
-    if any(f.primary_key for f in fields):
+    if any(f.primary_key for f in fields_):
         raise ValueError('fast_update() cannot be used with primary key fields.')
     for obj in objs:
-        # FIXME: django main has an additional argument 'fields'
+        # TODO: django main has an additional argument 'fields' (saves some runtime?)
         obj._prepare_related_fields_for_save(operation_name='fast_update')
         # additionally raise on f-expression
-        for field in fields:
+        for field in fields_:
             attr = getattr(obj, field.attname)
             if hasattr(attr, 'resolve_expression'):
                 raise ValueError('fast_update() cannot be used with f-expressions.')
-    return fields
 
 
 class FastUpdateQuerySet(QuerySet):
 
     def fast_update(
         self,
-        objs: Sequence[Model],
-        fields: Sequence[str],
+        objs: Iterable[Model],
+        fields: Iterable[str],
         batch_size: Optional[int] = None
     ) -> int:
         """
-        TODO...
+        Faster alternative for ``bulk_update`` with the same method signature.
+
+        Due to the way the update works internally with constant VALUES tables,
+        f-expressions cannot be used anymore. Besides that, it has similar
+        restrictions to ``bulk_update`` (e.g. primary keys cannot be updated).
+
+        The internal implementation relies on recent versions of database
+        backends and will fall back to ``bulk_update`` if the backend is not
+        supported. It will also invoke ``bulk_update`` for non-local fields
+        (e.g. for multi-table inheritance).
+
+        ``batch_size`` can be set to much higher values than is typical
+        for ``bulk_update`` (if needed at all).
+
+        Returns the number of affected rows.
         """
+        if not objs:
+            return 0
         objs = tuple(objs)
-        fields = set(fields or [])
-        sanity_check(self.model, objs, fields, batch_size)
-        return fast_update(self, objs, fields, batch_size)
+        fields_ = set(fields or [])
+        sanity_check(self.model, objs, fields_, batch_size)
+        return fast_update(self, objs, fields_, batch_size)
     fast_update.alters_data = True
 
     def copy_update(
         self,
-        objs: Sequence[Model],
-        fields: Sequence[str],
+        objs: Iterable[Model],
+        fields: Iterable[str],
         field_encoders: Optional[Dict[str, Any]] = None,
         encoding: Optional[str] = None
     ) -> int:
         """
-        TODO...
+        PostgreSQL-only method (raises an exception on any other backend)
+        to update a large number of model instances via COPY FROM.
+        The method follows the same interface idea as ``bulk_update`` or ``fast_update``,
+        but performs much better for bigger updates, even beating ``fast_update``.
+
+        Unlike ``fast_update``, there is no ``batch_size`` argument,
+        as the update is always done in one single big batch by copying the data into
+        a temporary table and running the update from there.
+
+        For the data transport, postgres' TEXT format is used, with the field values
+        encoded by special encoders. The encoders are globally registered for
+        django's standard field types (works similarly to `get_db_prep_value`).
+        With ``field_encoders`` custom encoders can be attached to update fields
+        for a single call. This might come in handy for additional conversion work or
+        further speedup by omitting the base type checks of the default encoders
+        (do this only if the data was checked by other means, otherwise malformed
+        updates may happen).
+
+        ``encoding`` overrides the text encoding used in the COPY FROM transmission
+        (default is psycopg's connection encoding).
+
+        Returns the number of affected rows.
+
+        NOTE: The underlying implementation is only a PoC and will probably be
+        replaced soon by the much safer and superior COPY support of psycopg3.
         """
         self._for_write = True
         connection = connections[self.db]
         if connection.vendor != 'postgresql':
             raise NotSupportedError(
                 f'copy_update() is not supported on "{connection.vendor}" backend')
-        from .copy import copy_update
+        from .copy import copy_update  # TODO: better as conditional import?
+        if not objs:
+            return 0
         objs = tuple(objs)
-        fields = set(fields or [])
-        sanity_check(self.model, objs, fields, 123)
-        return copy_update(self, objs, fields, field_encoders, encoding)
+        fields_ = set(fields or [])
+        sanity_check(self.model, objs, fields_)
+        return copy_update(self, objs, fields_, field_encoders, encoding)
     copy_update.alters_data = True

diff --git a/setup.py b/setup.py
index 9eb17f2..b73f9db 100644
--- a/setup.py
+++ b/setup.py
@@ -24,10 +24,10 @@ def get_version(path):
     author='netzkolchose',
     author_email='j.breitbart@netzkolchose.de',
     url='https://github.com/netzkolchose/django-fast-update',
-    download_url='https://github.com/netzkolchose/django-fast-update/archive/v0.0.3.tar.gz',
+    download_url='https://github.com/netzkolchose/django-fast-update/archive/v0.1.0.tar.gz',
     keywords=['django', 'bulk_update', 'fast', 'update', 'fast_update'],
     classifiers=[
-        'Development Status :: 3 - Alpha',
+        'Development Status :: 4 - Beta',
         'Intended Audience :: Developers',
         'Topic :: Database',
         'Topic :: Database :: Front-Ends',
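To complement the ``copy_update`` docstring above, a small usage sketch for the per-call ``field_encoders`` override (model, field, and encoder names are hypothetical; skipping the base type checks is only safe for pre-validated data):

```python
from fast_update.copy import encoder

@encoder
def TrustedInt(v, fname, lazy):
    # pass the value through unchecked - unlike the default Int encoder,
    # this skips the isinstance test (only safe for pre-validated data)
    return v

# MyModel uses the FastUpdateManager as shown in the README example
MyModel.objects.copy_update(
    bunch_of_instances,
    ['field_a', 'field_b'],
    field_encoders={'field_a': TrustedInt},
)
```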