Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
jerch committed Apr 29, 2022
1 parent 190adfb commit 6d169d9
Showing 1 changed file with 18 additions and 20 deletions.
38 changes: 18 additions & 20 deletions fast_update/fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from django.db.models.functions import Cast
from django.db.models.expressions import Col
from operator import attrgetter
import logging

# typing imports
from django.db.models import Field
Expand All @@ -13,12 +12,21 @@
from django.db.backends.base.base import BaseDatabaseWrapper


logger = logging.getLogger(__name__)


"""
DB vendor low level interfaces
To register a fast update implementations, call:
register_implementation('alias', check_function)
where `alias` is the vendor name as returned by `connection.vendor`.
The check function gets called once (lazy) with `connection` and is meant
to find a suitable implementation (you can provide multiple for different
server versions), either actively by probing against the db server,
or directly if it can be determined upfront.
The check function should return a tuple of (create_sql, prepare_data | None)
for supported backends, or an empty tuple, if the backend is unsupported.
create_sql function:
def as_sql_xy(
Expand Down Expand Up @@ -59,18 +67,6 @@ def prepare_data_xy(
the flat data table preparation is needed. The data is row based,
and contains field values in [pk] + fields order.
Return the altered data listing according to the SQL needs.
To register a fast update implementations, call:
register_implementation('alias', check_function)
where `alias` is the vendor name as returned by `connection.vendor`.
The check function gets called once (lazy) with `connection` and is meant
to find a suitable implementation (you can provide multiple for different
server versions, if needed), either actively by probing against
the db server, or directly if it can be determined upfront.
The check function must return a tuple of (create_sql, prepare_data | None)
for supported backends, or an empty tuple, if the backend is unsupported.
"""


Expand Down Expand Up @@ -109,7 +105,7 @@ def get_impl(conn: BaseDatabaseWrapper) -> str:
if not check: # pragma: no cover
SEEN_CONNECTIONS[conn] = tuple()
return tuple()
impl = check(conn)
impl = check(conn) or tuple() # NOTE: in case check returns something nullish
SEEN_CONNECTIONS[conn] = impl
return impl

Expand Down Expand Up @@ -207,6 +203,7 @@ def as_sqlite_cte(
f'WHERE "{tname}"."{pkname}" in ({pks})'
)


def prepare_data_sqlite_cte(data, width, height):
return data + [data[i] for i in range(0, len(data), width)]

Expand Down Expand Up @@ -246,24 +243,25 @@ def as_mysql(
)


# Register our default db implementations.
register_implementation(
'postgresql',
lambda conn: (as_postgresql, None)
lambda _: (as_postgresql, None)
)
register_implementation(
'sqlite',
# NOTE: check function does not handle versions <2.15 anymore
lambda conn: (as_sqlite, None) if conn.Database.sqlite_version_info >= (3, 33)
else (as_sqlite_cte, prepare_data_sqlite_cte)
)
register_implementation(
'mysql',
lambda conn: (as_mysql, None)
lambda _: (as_mysql, None)
)


def update_from_values(
c: CursorWrapper,
#vendor: str,
tname: str,
pk_field: Field,
fields: List[Field],
Expand Down

0 comments on commit 6d169d9

Please sign in to comment.