Skip to content
Permalink
Browse files
feat: Add support for SQLAlchemy 1.4 (#177)
  • Loading branch information
jimfulton committed May 21, 2021
1 parent 9dd3cf4 commit b7b60007c966cd548448d1d6fd5a14d1f89480cd
@@ -28,7 +28,9 @@
BLACK_PATHS = ["docs", "pybigquery", "tests", "noxfile.py", "setup.py"]

DEFAULT_PYTHON_VERSION = "3.8"
SYSTEM_TEST_PYTHON_VERSIONS = ["3.9"]

# We're using two Python versions to test with sqlalchemy 1.3 and 1.4.
SYSTEM_TEST_PYTHON_VERSIONS = ["3.8", "3.9"]
UNIT_TEST_PYTHON_VERSIONS = ["3.6", "3.7", "3.8", "3.9"]

CURRENT_DIRECTORY = pathlib.Path(__file__).parent.absolute()
@@ -47,6 +49,7 @@

# Error if a python version is missing
nox.options.error_on_missing_interpreters = True
nox.options.stop_on_first_error = True


@nox.session(python=DEFAULT_PYTHON_VERSION)
@@ -4,6 +4,9 @@
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.

import functools
import re

from google.api_core import client_info
import google.auth
from google.cloud import bigquery
@@ -58,3 +61,22 @@ def create_bigquery_client(
location=location,
default_query_job_config=default_query_job_config,
)


def substitute_re_method(r, flags=0, repl=None):
if repl is None:
return lambda f: substitute_re_method(r, flags, f)

r = re.compile(r, flags)

if isinstance(repl, str):
return lambda self, s: r.sub(repl, s)

@functools.wraps(repl)
def sub(self, s, *args, **kw):
def repl_(m):
return repl(self, m, *args, **kw)

return r.sub(repl_, s)

return sub
@@ -44,7 +44,7 @@ def parse_boolean(bool_string):


def parse_url(url): # noqa: C901
query = url.query
query = dict(url.query) # need mutable query.

# use_legacy_sql (legacy)
if "use_legacy_sql" in query:
@@ -154,8 +154,14 @@ def comment_reflection(self):
def unicode_ddl(self):
"""Target driver must support some degree of non-ascii symbol
names.
However:
Must contain only letters (a-z, A-Z), numbers (0-9), or underscores (_)
https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#column_name_and_column_schema
"""
return supported()
return unsupported()

@property
def datetime_literals(self):
@@ -219,6 +225,14 @@ def order_by_label_with_expression(self):
"""
return supported()

@property
def sql_expression_limit_offset(self):
"""target database can render LIMIT and/or OFFSET with a complete
SQL expression, such as one that uses the addition operator.
parameter
"""
return unsupported()


class WithSchemas(Requirements):
"""
@@ -34,6 +34,7 @@
from google.cloud.bigquery.table import TableReference
from google.api_core.exceptions import NotFound

import sqlalchemy
import sqlalchemy.sql.sqltypes
import sqlalchemy.sql.type_api
from sqlalchemy.exc import NoSuchTableError
@@ -57,6 +58,11 @@
FIELD_ILLEGAL_CHARACTERS = re.compile(r"[^\w]+")


def assert_(cond, message="Assertion failed"): # pragma: NO COVER
if not cond:
raise AssertionError(message)


class BigQueryIdentifierPreparer(IdentifierPreparer):
"""
Set containing everything
@@ -152,15 +158,25 @@ def get_insert_default(self, column): # pragma: NO COVER
elif isinstance(column.type, String):
return str(uuid.uuid4())

def pre_exec(
self,
in_sub=re.compile(
r" IN UNNEST\(\[ "
r"(%\([^)]+_\d+\)s(?:, %\([^)]+_\d+\)s)*)?" # Placeholders. See below.
r":([A-Z0-9]+)" # Type
r" \]\)"
).sub,
):
__remove_type_from_empty_in = _helpers.substitute_re_method(
r" IN UNNEST\(\[ ("
r"(?:NULL|\(NULL(?:, NULL)+\))\)"
r" (?:AND|OR) \(1 !?= 1"
r")"
r"(?:[:][A-Z0-9]+)?"
r" \]\)",
re.IGNORECASE,
r" IN(\1)",
)

@_helpers.substitute_re_method(
r" IN UNNEST\(\[ "
r"(%\([^)]+_\d+\)s(?:, %\([^)]+_\d+\)s)*)?" # Placeholders. See below.
r":([A-Z0-9]+)" # Type
r" \]\)",
re.IGNORECASE,
)
def __distribute_types_to_expanded_placeholders(self, m):
# If we have an in parameter, it sometimes gets expaned to 0 or more
# parameters and we need to move the type marker to each
# parameter.
@@ -171,29 +187,29 @@ def pre_exec(
# suffixes refect that when an array parameter is expanded,
# numeric suffixes are added. For example, a placeholder like
# `%(foo)s` gets expaneded to `%(foo_0)s, `%(foo_1)s, ...`.
placeholders, type_ = m.groups()
if placeholders:
placeholders = placeholders.replace(")", f":{type_})")
else:
placeholders = ""
return f" IN UNNEST([ {placeholders} ])"

def repl(m):
placeholders, type_ = m.groups()
if placeholders:
placeholders = placeholders.replace(")", f":{type_})")
else:
placeholders = ""
return f" IN UNNEST([ {placeholders} ])"

self.statement = in_sub(repl, self.statement)
def pre_exec(self):
self.statement = self.__distribute_types_to_expanded_placeholders(
self.__remove_type_from_empty_in(self.statement)
)


class BigQueryCompiler(SQLCompiler):

compound_keywords = SQLCompiler.compound_keywords.copy()
compound_keywords[selectable.CompoundSelect.UNION] = "UNION ALL"
compound_keywords[selectable.CompoundSelect.UNION] = "UNION DISTINCT"
compound_keywords[selectable.CompoundSelect.UNION_ALL] = "UNION ALL"

def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs):
def __init__(self, dialect, statement, *args, **kwargs):
if isinstance(statement, Column):
kwargs["compile_kwargs"] = util.immutabledict({"include_table": False})
super(BigQueryCompiler, self).__init__(
dialect, statement, column_keys, inline, **kwargs
)
super(BigQueryCompiler, self).__init__(dialect, statement, *args, **kwargs)

def visit_insert(self, insert_stmt, asfrom=False, **kw):
# The (internal) documentation for `inline` is confusing, but
@@ -260,24 +276,37 @@ def group_by_clause(self, select, **kw):
# no way to tell sqlalchemy that, so it works harder than
# necessary and makes us do the same.

_in_expanding_bind = re.compile(r" IN \((\[EXPANDING_\w+\](:[A-Z0-9]+)?)\)$")
__sqlalchemy_version_info = tuple(map(int, sqlalchemy.__version__.split(".")))

def _unnestify_in_expanding_bind(self, in_text):
return self._in_expanding_bind.sub(r" IN UNNEST([ \1 ])", in_text)
__expandng_text = (
"EXPANDING" if __sqlalchemy_version_info < (1, 4) else "POSTCOMPILE"
)

__in_expanding_bind = _helpers.substitute_re_method(
fr" IN \((\[" fr"{__expandng_text}" fr"_[^\]]+\](:[A-Z0-9]+)?)\)$",
re.IGNORECASE,
r" IN UNNEST([ \1 ])",
)

def visit_in_op_binary(self, binary, operator_, **kw):
return self._unnestify_in_expanding_bind(
return self.__in_expanding_bind(
self._generate_generic_binary(binary, " IN ", **kw)
)

def visit_empty_set_expr(self, element_types):
return ""

def visit_notin_op_binary(self, binary, operator, **kw):
return self._unnestify_in_expanding_bind(
self._generate_generic_binary(binary, " NOT IN ", **kw)
def visit_not_in_op_binary(self, binary, operator, **kw):
return (
"("
+ self.__in_expanding_bind(
self._generate_generic_binary(binary, " NOT IN ", **kw)
)
+ ")"
)

visit_notin_op_binary = visit_not_in_op_binary # before 1.4

############################################################################

############################################################################
@@ -327,6 +356,10 @@ def visit_notendswith_op_binary(self, binary, operator, **kw):

############################################################################

__placeholder = re.compile(r"%\(([^\]:]+)(:[^\]:]+)?\)s$").match

__expanded_param = re.compile(fr"\(\[" fr"{__expandng_text}" fr"_[^\]]+\]\)$").match

def visit_bindparam(
self,
bindparam,
@@ -365,8 +398,20 @@ def visit_bindparam(
# Values get arrayified at a lower level.
bq_type = bq_type[6:-1]

assert param != "%s"
return param.replace(")", f":{bq_type})")
assert_(param != "%s", f"Unexpected param: {param}")

if bindparam.expanding:
assert_(self.__expanded_param(param), f"Unexpected param: {param}")
param = param.replace(")", f":{bq_type})")

else:
m = self.__placeholder(param)
if m:
name, type_ = m.groups()
assert_(type_ is None)
param = f"%({name}:{bq_type})s"

return param


class BigQueryTypeCompiler(GenericTypeCompiler):
@@ -541,7 +586,6 @@ class BigQueryDialect(DefaultDialect):
supports_unicode_statements = True
supports_unicode_binds = True
supports_native_decimal = True
returns_unicode_strings = True
description_encoding = None
supports_native_boolean = True
supports_simple_order_by_label = True
@@ -65,10 +65,10 @@ def readme():
],
platforms="Posix; MacOS X; Windows",
install_requires=[
"sqlalchemy>=1.2.0,<1.4.0dev",
"google-auth>=1.24.0,<2.0dev", # Work around pip wack.
"google-cloud-bigquery>=2.15.0",
"google-api-core>=1.23.0", # Work-around bug in cloud core deps.
"google-auth>=1.24.0,<2.0dev", # Work around pip wack.
"google-cloud-bigquery>=2.16.1",
"sqlalchemy>=1.2.0,<1.5.0dev",
"future",
],
python_requires=">=3.6, <3.10",
@@ -6,5 +6,5 @@
# e.g., if setup.py has "foo >= 1.14.0, < 2.0.0dev",
sqlalchemy==1.2.0
google-auth==1.24.0
google-cloud-bigquery==2.15.0
google-cloud-bigquery==2.16.1
google-api-core==1.23.0
@@ -0,0 +1 @@
sqlalchemy==1.3.24
@@ -0,0 +1 @@
sqlalchemy>=1.4.13
@@ -20,3 +20,8 @@
from sqlalchemy.dialects import registry

registry.register("bigquery", "pybigquery.sqlalchemy_bigquery", "BigQueryDialect")

# sqlalchemy's dialect-testing machinery wants an entry like this. It is wack. :(
registry.register(
"bigquery.bigquery", "pybigquery.sqlalchemy_bigquery", "BigQueryDialect"
)
@@ -19,9 +19,9 @@

import contextlib
import random
import re
import traceback

import sqlalchemy
from sqlalchemy.testing import config
from sqlalchemy.testing.plugin.pytestplugin import * # noqa
from sqlalchemy.testing.plugin.pytestplugin import (
@@ -35,23 +35,28 @@
pybigquery.sqlalchemy_bigquery.BigQueryDialect.preexecute_autoincrement_sequences = True
google.cloud.bigquery.dbapi.connection.Connection.rollback = lambda self: None

_where = re.compile(r"\s+WHERE\s+", re.IGNORECASE).search

# BigQuery requires delete statements to have where clauses. Other
# databases don't and sqlalchemy doesn't include where clauses when
# cleaning up test data. So we add one when we see a delete without a
# where clause when tearing down tests. We only do this during tear
# down, by inspecting the stack, because we don't want to hide bugs
# outside of test house-keeping.
def visit_delete(self, delete_stmt, *args, **kw):
if delete_stmt._whereclause is None and "teardown" in set(
f.name for f in traceback.extract_stack()
):
delete_stmt._whereclause = sqlalchemy.true()

return super(pybigquery.sqlalchemy_bigquery.BigQueryCompiler, self).visit_delete(

def visit_delete(self, delete_stmt, *args, **kw):
text = super(pybigquery.sqlalchemy_bigquery.BigQueryCompiler, self).visit_delete(
delete_stmt, *args, **kw
)

if not _where(text) and any(
"teardown" in f.name.lower() for f in traceback.extract_stack()
):
text += " WHERE true"

return text


pybigquery.sqlalchemy_bigquery.BigQueryCompiler.visit_delete = visit_delete

0 comments on commit b7b6000

Please sign in to comment.