Skip to content

Commit

Permalink
module cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
jerch committed Apr 30, 2022
1 parent 353baef commit 32dc155
Showing 1 changed file with 50 additions and 41 deletions.
91 changes: 50 additions & 41 deletions fast_update/fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""
DB vendor low level interfaces
To register a fast update implementations, call:
To register fast update implementations, call:
register_implementation('vendor', check_function)
Expand Down Expand Up @@ -70,44 +70,7 @@ def prepare_data_xy(
"""


# memorize fast_update vendor on connection object
SEEN_CONNECTIONS = cast(Dict[BaseDatabaseWrapper, str], WeakKeyDictionary())
CHECKER = {}

def register_implementation(
vendor: str,
func: Callable[[BaseDatabaseWrapper], Tuple[Any]]
) -> None:
"""
Register fast update implementation for db vendor.
`vendor` is the vendor name as returned by `connection.vendor`.
`func` is a lazy called function to check support for a certain
implementation at runtime for `connection`. The function should return
a tuple of (create_sql, prepare_data | None) for supported backends,
otherwise an empty tuple (needed to avoid re-eval).
"""
CHECKER[vendor] = func


def get_impl(conn: BaseDatabaseWrapper) -> str:
"""
Try to get a fast update implementation for `conn`.
Calls once the the check function of `register_implementation` and
memorizes its result for `conn`.
Returns a tuple (create_sql, prepare_data | None) for supported backends,
otherwise an empty tuple.
"""
impl = SEEN_CONNECTIONS.get(conn)
if impl is not None:
return impl
check = CHECKER.get(conn.vendor)
if not check: # pragma: no cover
SEEN_CONNECTIONS[conn] = tuple()
return tuple()
impl = check(conn) or tuple() # NOTE: in case check returns something nullish
SEEN_CONNECTIONS[conn] = impl
return impl
# Fast update implementations for postgres, sqlite and mysql.


def pq_cast(tname: str, field: Field, compiler: SQLCompiler, connection: Any) -> str:
Expand Down Expand Up @@ -243,14 +206,57 @@ def as_mysql(
)


# Register our default db implementations.
# Implementation registry.


# memorize fast_update vendor on connection object
SEEN_CONNECTIONS = cast(Dict[BaseDatabaseWrapper, str], WeakKeyDictionary())
CHECKER = {}

def register_implementation(
vendor: str,
func: Callable[[BaseDatabaseWrapper], Tuple[Any]]
) -> None:
"""
Register fast update implementation for db vendor.
`vendor` is the vendor name as returned by `connection.vendor`.
`func` is a lazy called function to check support for a certain
implementation at runtime for `connection`. The function should return
a tuple of (create_sql, prepare_data | None) for supported backends,
otherwise an empty tuple (needed to avoid re-eval).
"""
CHECKER[vendor] = func


def get_impl(conn: BaseDatabaseWrapper) -> str:
"""
Try to get a fast update implementation for `conn`.
Calls once the check function of `register_implementation` and
memorizes its result for `conn`.
Returns a tuple (create_sql, prepare_data | None) for supported backends,
otherwise an empty tuple.
"""
impl = SEEN_CONNECTIONS.get(conn)
if impl is not None:
return impl
check = CHECKER.get(conn.vendor)
if not check: # pragma: no cover
SEEN_CONNECTIONS[conn] = tuple()
return tuple()
impl = check(conn) or tuple() # NOTE: in case check returns something nullish
SEEN_CONNECTIONS[conn] = impl
return impl


# Register default db implementations from above.
register_implementation(
'postgresql',
lambda _: (as_postgresql, None)
)
register_implementation(
'sqlite',
# NOTE: check function does not handle versions <2.15 anymore
# NOTE: check function does not handle versions <3.15 anymore
lambda conn: (as_sqlite, None) if conn.Database.sqlite_version_info >= (3, 33)
else (as_sqlite_cte, prepare_data_sqlite_cte)
)
Expand All @@ -260,6 +266,9 @@ def as_mysql(
)


# Update implementation.


def update_from_values(
c: CursorWrapper,
tname: str,
Expand Down

0 comments on commit 32dc155

Please sign in to comment.