Skip to content
Permalink
Browse files
feat: Handle passing of arrays to in statements more efficiently in S…
…QLAlchemy 1.4 and higher (#253)
  • Loading branch information
jimfulton committed Aug 25, 2021
1 parent 9b5b002 commit 76927044aa4d2be9d0f2ec47e917b28b97c18425
Showing with 141 additions and 79 deletions.
  1. +36 −2 sqlalchemy_bigquery/base.py
  2. +21 −0 tests/system/test_sqlalchemy_bigquery.py
  3. +13 −4 tests/unit/fauxdbi.py
  4. +71 −73 tests/unit/test_select.py
@@ -483,6 +483,39 @@ def visit_bindparam(
skip_bind_expression=False,
**kwargs,
):
type_ = bindparam.type
unnest = False
if (
bindparam.expanding
and not isinstance(type_, NullType)
and not literal_binds
):
# Normally, when performing an IN operation, like:
#
# foo IN (some_sequence)
#
# SQLAlchemy passes `foo` as a parameter and unpacks
# `some_sequence` and passes each element as a parameter.
# This mechanism is referred to as "expanding". It's
# inefficient and can't handle large arrays. (It's also
# very complicated, but that's not the issue we care about
# here. :) ) BigQuery lets us use arrays directly in this
# context, we just need to call UNNEST on an array when
# it's used in IN.
#
# So, if we get an `expanding` flag, and if we have a known type
# (and don't have literal binds, which are implemented in-line
# in the SQL), we turn off expanding and we set an unnest flag
# so that we add an UNNEST() call (below).
#
# The NullType/known-type check has to do with some extreme
# edge cases having to do with empty in-lists that get special
# hijinks from SQLAlchemy that we don't want to disturb. :)
if getattr(bindparam, "expand_op", None) is not None:
assert bindparam.expand_op.__name__.endswith("in_op") # in in
bindparam.expanding = False
unnest = True

param = super(BigQueryCompiler, self).visit_bindparam(
bindparam,
within_columns_clause,
@@ -491,7 +524,6 @@ def visit_bindparam(
**kwargs,
)

type_ = bindparam.type
if literal_binds or isinstance(type_, NullType):
return param

@@ -512,7 +544,6 @@ def visit_bindparam(
if bq_type[-1] == ">" and bq_type.startswith("ARRAY<"):
# Values get arrayified at a lower level.
bq_type = bq_type[6:-1]

bq_type = self.__remove_type_parameter(bq_type)

assert_(param != "%s", f"Unexpected param: {param}")
@@ -528,6 +559,9 @@ def visit_bindparam(
assert_(type_ is None)
param = f"%({name}:{bq_type})s"

if unnest:
param = f"UNNEST({param})"

return param


@@ -727,6 +727,27 @@ class MyTable(Base):
assert sorted(db.query(sqlalchemy.distinct(MyTable.my_column)).all()) == expected


@pytest.mark.skipif(
    packaging.version.parse(sqlalchemy.__version__) < packaging.version.parse("1.4"),
    reason="requires sqlalchemy 1.4 or higher",
)
def test_huge_in():
    """Run ``-1 IN (0 .. 99998)`` against BigQuery and expect a single False row.

    Execution failures and wrong results are reported separately: a failure
    to build or execute the statement surfaces as "execution failed" together
    with the exception type, while an unexpected result surfaces as a plain
    comparison failure on the (single-row) result.
    """
    engine = sqlalchemy.create_engine("bigquery://")

    failure = None
    rows = None
    # Use the connection as a context manager so it is released even when
    # the statement fails.
    with engine.connect() as conn:
        try:
            rows = list(
                conn.execute(
                    sqlalchemy.select(
                        [sqlalchemy.literal(-1).in_(list(range(99999)))]
                    )
                )
            )
        except Exception as exc:
            # The exception is intentionally not re-raised; only its type is
            # kept for the failure message.
            # NOTE(review): presumably the original swallowed the error to
            # keep the 99999-element statement out of the test report --
            # confirm before letting the exception propagate.
            failure = type(exc).__name__

    assert failure is None, f"execution failed ({failure})"
    # Checked outside the try block so a wrong answer is not misreported as
    # an execution failure.
    assert rows == [(False,)]


@pytest.mark.skipif(
packaging.version.parse(sqlalchemy.__version__) < packaging.version.parse("1.4"),
reason="unnest (and other table-valued-function) support required version 1.4",
@@ -261,11 +261,20 @@ def __handle_problematic_literal_inserts(
else:
return operation

__handle_unnest = substitute_string_re_method(
r"UNNEST\(\[ ([^\]]+)? \]\)", # UNNEST([ ... ])
flags=re.IGNORECASE,
repl=r"(\1)",
@substitute_re_method(
r"""
UNNEST\(
(
\[ (?P<exp>[^\]]+)? \] # UNNEST([ ... ])
|
([?]) # UNNEST(?)
)
\)
""",
flags=re.IGNORECASE | re.VERBOSE,
)
def __handle_unnest(self, m):
return "(" + (m.group("exp") or "?") + ")"

def __handle_true_false(self, operation):
# Older sqlite versions, like those used on the CI servers
@@ -28,6 +28,7 @@

from conftest import (
setup_table,
sqlalchemy_version,
sqlalchemy_1_3_or_higher,
sqlalchemy_1_4_or_higher,
sqlalchemy_before_1_4,
@@ -214,18 +215,6 @@ def test_disable_quote(faux_conn):
assert faux_conn.test_data["execute"][-1][0] == ("SELECT `t`.foo \nFROM `t`")


def _normalize_in_params(query, params):
# We have to normalize parameter names, because they
# change with sqlalchemy versions.
newnames = sorted(
((p, f"p_{i}") for i, p in enumerate(sorted(params))), key=lambda i: -len(i[0])
)
for old, new in newnames:
query = query.replace(old, new)

return query, {new: params[old] for old, new in newnames}


@sqlalchemy_before_1_4
def test_select_in_lit_13(faux_conn):
[[isin]] = faux_conn.execute(
@@ -240,66 +229,74 @@ def test_select_in_lit_13(faux_conn):


@sqlalchemy_1_4_or_higher
def test_select_in_lit(faux_conn):
[[isin]] = faux_conn.execute(
sqlalchemy.select([sqlalchemy.literal(1).in_([1, 2, 3])])
)
assert isin
assert _normalize_in_params(*faux_conn.test_data["execute"][-1]) == (
"SELECT %(p_0:INT64)s IN "
"UNNEST([ %(p_1:INT64)s, %(p_2:INT64)s, %(p_3:INT64)s ]) AS `anon_1`",
{"p_1": 1, "p_2": 2, "p_3": 3, "p_0": 1},
def test_select_in_lit(faux_conn, last_query):
faux_conn.execute(sqlalchemy.select([sqlalchemy.literal(1).in_([1, 2, 3])]))
last_query(
"SELECT %(param_1:INT64)s IN UNNEST(%(param_2:INT64)s) AS `anon_1`",
{"param_1": 1, "param_2": [1, 2, 3]},
)


def test_select_in_param(faux_conn):
def test_select_in_param(faux_conn, last_query):
[[isin]] = faux_conn.execute(
sqlalchemy.select(
[sqlalchemy.literal(1).in_(sqlalchemy.bindparam("q", expanding=True))]
),
dict(q=[1, 2, 3]),
)
assert isin
assert faux_conn.test_data["execute"][-1] == (
"SELECT %(param_1:INT64)s IN UNNEST("
"[ %(q_1:INT64)s, %(q_2:INT64)s, %(q_3:INT64)s ]"
") AS `anon_1`",
{"param_1": 1, "q_1": 1, "q_2": 2, "q_3": 3},
)
if sqlalchemy_version >= packaging.version.parse("1.4"):
last_query(
"SELECT %(param_1:INT64)s IN UNNEST(%(q:INT64)s) AS `anon_1`",
{"param_1": 1, "q": [1, 2, 3]},
)
else:
assert isin
last_query(
"SELECT %(param_1:INT64)s IN UNNEST("
"[ %(q_1:INT64)s, %(q_2:INT64)s, %(q_3:INT64)s ]"
") AS `anon_1`",
{"param_1": 1, "q_1": 1, "q_2": 2, "q_3": 3},
)


def test_select_in_param1(faux_conn):
def test_select_in_param1(faux_conn, last_query):
[[isin]] = faux_conn.execute(
sqlalchemy.select(
[sqlalchemy.literal(1).in_(sqlalchemy.bindparam("q", expanding=True))]
),
dict(q=[1]),
)
assert isin
assert faux_conn.test_data["execute"][-1] == (
"SELECT %(param_1:INT64)s IN UNNEST(" "[ %(q_1:INT64)s ]" ") AS `anon_1`",
{"param_1": 1, "q_1": 1},
)
if sqlalchemy_version >= packaging.version.parse("1.4"):
last_query(
"SELECT %(param_1:INT64)s IN UNNEST(%(q:INT64)s) AS `anon_1`",
{"param_1": 1, "q": [1]},
)
else:
assert isin
last_query(
"SELECT %(param_1:INT64)s IN UNNEST(" "[ %(q_1:INT64)s ]" ") AS `anon_1`",
{"param_1": 1, "q_1": 1},
)


@sqlalchemy_1_3_or_higher
def test_select_in_param_empty(faux_conn):
def test_select_in_param_empty(faux_conn, last_query):
[[isin]] = faux_conn.execute(
sqlalchemy.select(
[sqlalchemy.literal(1).in_(sqlalchemy.bindparam("q", expanding=True))]
),
dict(q=[]),
)
assert not isin
assert faux_conn.test_data["execute"][-1] == (
"SELECT %(param_1:INT64)s IN(NULL) AND (1 != 1) AS `anon_1`"
if (
packaging.version.parse(sqlalchemy.__version__)
>= packaging.version.parse("1.4")
if sqlalchemy_version >= packaging.version.parse("1.4"):
last_query(
"SELECT %(param_1:INT64)s IN UNNEST(%(q:INT64)s) AS `anon_1`",
{"param_1": 1, "q": []},
)
else:
assert not isin
last_query(
"SELECT %(param_1:INT64)s IN UNNEST([ ]) AS `anon_1`", {"param_1": 1}
)
else "SELECT %(param_1:INT64)s IN UNNEST([ ]) AS `anon_1`",
{"param_1": 1},
)


@sqlalchemy_before_1_4
@@ -316,53 +313,54 @@ def test_select_notin_lit13(faux_conn):


@sqlalchemy_1_4_or_higher
def test_select_notin_lit(faux_conn):
[[isnotin]] = faux_conn.execute(
sqlalchemy.select([sqlalchemy.literal(0).notin_([1, 2, 3])])
def test_select_notin_lit(faux_conn, last_query):
faux_conn.execute(sqlalchemy.select([sqlalchemy.literal(0).notin_([1, 2, 3])]))
last_query(
"SELECT (%(param_1:INT64)s NOT IN UNNEST(%(param_2:INT64)s)) AS `anon_1`",
{"param_1": 0, "param_2": [1, 2, 3]},
)
assert isnotin

assert _normalize_in_params(*faux_conn.test_data["execute"][-1]) == (
"SELECT (%(p_0:INT64)s NOT IN "
"UNNEST([ %(p_1:INT64)s, %(p_2:INT64)s, %(p_3:INT64)s ])) AS `anon_1`",
{"p_0": 0, "p_1": 1, "p_2": 2, "p_3": 3},
)


def test_select_notin_param(faux_conn):
def test_select_notin_param(faux_conn, last_query):
[[isnotin]] = faux_conn.execute(
sqlalchemy.select(
[sqlalchemy.literal(1).notin_(sqlalchemy.bindparam("q", expanding=True))]
),
dict(q=[1, 2, 3]),
)
assert not isnotin
assert faux_conn.test_data["execute"][-1] == (
"SELECT (%(param_1:INT64)s NOT IN UNNEST("
"[ %(q_1:INT64)s, %(q_2:INT64)s, %(q_3:INT64)s ]"
")) AS `anon_1`",
{"param_1": 1, "q_1": 1, "q_2": 2, "q_3": 3},
)
if sqlalchemy_version >= packaging.version.parse("1.4"):
last_query(
"SELECT (%(param_1:INT64)s NOT IN UNNEST(%(q:INT64)s)) AS `anon_1`",
{"param_1": 1, "q": [1, 2, 3]},
)
else:
assert not isnotin
last_query(
"SELECT (%(param_1:INT64)s NOT IN UNNEST("
"[ %(q_1:INT64)s, %(q_2:INT64)s, %(q_3:INT64)s ]"
")) AS `anon_1`",
{"param_1": 1, "q_1": 1, "q_2": 2, "q_3": 3},
)


@sqlalchemy_1_3_or_higher
def test_select_notin_param_empty(faux_conn):
def test_select_notin_param_empty(faux_conn, last_query):
[[isnotin]] = faux_conn.execute(
sqlalchemy.select(
[sqlalchemy.literal(1).notin_(sqlalchemy.bindparam("q", expanding=True))]
),
dict(q=[]),
)
assert isnotin
assert faux_conn.test_data["execute"][-1] == (
"SELECT (%(param_1:INT64)s NOT IN(NULL) OR (1 = 1)) AS `anon_1`"
if (
packaging.version.parse(sqlalchemy.__version__)
>= packaging.version.parse("1.4")
if sqlalchemy_version >= packaging.version.parse("1.4"):
last_query(
"SELECT (%(param_1:INT64)s NOT IN UNNEST(%(q:INT64)s)) AS `anon_1`",
{"param_1": 1, "q": []},
)
else:
assert isnotin
last_query(
"SELECT (%(param_1:INT64)s NOT IN UNNEST([ ])) AS `anon_1`", {"param_1": 1}
)
else "SELECT (%(param_1:INT64)s NOT IN UNNEST([ ])) AS `anon_1`",
{"param_1": 1},
)


def test_literal_binds_kwarg_with_an_IN_operator_252(faux_conn):

0 comments on commit 7692704

Please sign in to comment.