Skip to content
Permalink
Browse files
fix: unnest failed in some cases (with table references failed when t…
…here were no other references to refrenced tables in a query) (#290)
  • Loading branch information
jimfulton committed Aug 25, 2021
1 parent 5e9f4c2 commit 9b5b0025ec0b65177c0df02013ac387b3d3de472
@@ -45,7 +45,11 @@ def readme():
return f.read()


extras = dict(geography=["GeoAlchemy2", "shapely"], alembic=["alembic"], tests=["pytz"])
extras = dict(
geography=["GeoAlchemy2", "shapely"],
alembic=["alembic"],
tests=["packaging", "pytz"],
)
extras["all"] = set(itertools.chain.from_iterable(extras.values()))

setup(
@@ -85,7 +89,7 @@ def readme():
],
extras_require=extras,
python_requires=">=3.6, <3.10",
tests_require=["pytz"],
tests_require=["packaging", "pytz"],
entry_points={
"sqlalchemy.dialects": ["bigquery = sqlalchemy_bigquery:BigQueryDialect"]
},
@@ -24,40 +24,42 @@

from .base import BigQueryDialect, dialect # noqa
from .base import (
STRING,
ARRAY,
BIGNUMERIC,
BOOL,
BOOLEAN,
BYTES,
DATE,
DATETIME,
FLOAT,
FLOAT64,
INT64,
INTEGER,
FLOAT64,
FLOAT,
TIMESTAMP,
DATETIME,
DATE,
BYTES,
TIME,
RECORD,
NUMERIC,
BIGNUMERIC,
RECORD,
STRING,
TIME,
TIMESTAMP,
)

__all__ = [
"ARRAY",
"BIGNUMERIC",
"BigQueryDialect",
"STRING",
"BOOL",
"BOOLEAN",
"BYTES",
"DATE",
"DATETIME",
"FLOAT",
"FLOAT64",
"INT64",
"INTEGER",
"FLOAT64",
"FLOAT",
"TIMESTAMP",
"DATETIME",
"DATE",
"BYTES",
"TIME",
"RECORD",
"NUMERIC",
"BIGNUMERIC",
"RECORD",
"STRING",
"TIME",
"TIMESTAMP",
]

try:
@@ -62,6 +62,8 @@

FIELD_ILLEGAL_CHARACTERS = re.compile(r"[^\w]+")

TABLE_VALUED_ALIAS_ALIASES = "bigquery_table_valued_alias_aliases"


def assert_(cond, message="Assertion failed"): # pragma: NO COVER
if not cond:
@@ -114,39 +116,41 @@ def format_label(self, label, name=None):


_type_map = {
"STRING": types.String,
"BOOL": types.Boolean,
"ARRAY": types.ARRAY,
"BIGNUMERIC": types.Numeric,
"BOOLEAN": types.Boolean,
"INT64": types.Integer,
"INTEGER": types.Integer,
"BOOL": types.Boolean,
"BYTES": types.BINARY,
"DATETIME": types.DATETIME,
"DATE": types.DATE,
"FLOAT64": types.Float,
"FLOAT": types.Float,
"INT64": types.Integer,
"INTEGER": types.Integer,
"NUMERIC": types.Numeric,
"RECORD": types.JSON,
"STRING": types.String,
"TIMESTAMP": types.TIMESTAMP,
"DATETIME": types.DATETIME,
"DATE": types.DATE,
"BYTES": types.BINARY,
"TIME": types.TIME,
"RECORD": types.JSON,
"NUMERIC": types.Numeric,
"BIGNUMERIC": types.Numeric,
}

# By convention, dialect-provided types are spelled with all upper case.
STRING = _type_map["STRING"]
BOOL = _type_map["BOOL"]
ARRAY = _type_map["ARRAY"]
BIGNUMERIC = _type_map["NUMERIC"]
BOOLEAN = _type_map["BOOLEAN"]
INT64 = _type_map["INT64"]
INTEGER = _type_map["INTEGER"]
BOOL = _type_map["BOOL"]
BYTES = _type_map["BYTES"]
DATETIME = _type_map["DATETIME"]
DATE = _type_map["DATE"]
FLOAT64 = _type_map["FLOAT64"]
FLOAT = _type_map["FLOAT"]
INT64 = _type_map["INT64"]
INTEGER = _type_map["INTEGER"]
NUMERIC = _type_map["NUMERIC"]
RECORD = _type_map["RECORD"]
STRING = _type_map["STRING"]
TIMESTAMP = _type_map["TIMESTAMP"]
DATETIME = _type_map["DATETIME"]
DATE = _type_map["DATE"]
BYTES = _type_map["BYTES"]
TIME = _type_map["TIME"]
RECORD = _type_map["RECORD"]
NUMERIC = _type_map["NUMERIC"]
BIGNUMERIC = _type_map["NUMERIC"]

try:
_type_map["GEOGRAPHY"] = GEOGRAPHY
@@ -246,6 +250,56 @@ def visit_insert(self, insert_stmt, asfrom=False, **kw):
insert_stmt, asfrom=False, **kw
)

def visit_table_valued_alias(self, element, **kw):
# When using table-valued functions, like UNNEST, BigQuery requires a
# FROM for any table referenced in the function, including expressions
# in function arguments.
#
# For example, given SQLAlchemy code:
#
# print(
# select([func.unnest(foo.c.objects).alias('foo_objects').column])
# .compile(engine))
#
# Left to it's own devices, SQLAlchemy would outout:
#
# SELECT `foo_objects`
# FROM unnest(`foo`.`objects`) AS `foo_objects`
#
# But BigQuery diesn't understand the `foo` reference unless
# we add as reference to `foo` in the FROM:
#
# SELECT foo_objects
# FROM `foo`, UNNEST(`foo`.`objects`) as foo_objects
#
# This is tricky because:
# 1. We have to find the table references.
# 2. We can't know practically if there's already a FROM for a table.
#
# We leverage visit_column to find a table reference. Whenever we find
# one, we create an alias for it, so as not to conflict with an existing
# reference if one is present.
#
# This requires communicating between this function and visit_column.
# We do this by sticking a dictionary in the keyword arguments.
# This dictionary:
# a. Tells visit_column that it's an a table-valued alias expresssion, and
# b. Gives it a place to record the aliases it creates.
#
# This function creates aliases in the FROM list for any aliases recorded
# by visit_column.

kw[TABLE_VALUED_ALIAS_ALIASES] = {}
ret = super().visit_table_valued_alias(element, **kw)
aliases = kw.pop(TABLE_VALUED_ALIAS_ALIASES)
if aliases:
aliases = ", ".join(
f"{self.preparer.quote(tablename)} {self.preparer.quote(alias)}"
for tablename, alias in aliases.items()
)
ret = f"{aliases}, {ret}"
return ret

def visit_column(
self,
column,
@@ -281,6 +335,13 @@ def visit_column(
tablename = table.name
if isinstance(tablename, elements._truncated_label):
tablename = self._truncated_identifier("alias", tablename)
elif TABLE_VALUED_ALIAS_ALIASES in kwargs:
aliases = kwargs[TABLE_VALUED_ALIAS_ALIASES]
if tablename not in aliases:
aliases[tablename] = self.anon_map[
f"{TABLE_VALUED_ALIAS_ALIASES} {tablename}"
]
tablename = aliases[tablename]

return self.preparer.quote(tablename) + "." + name

@@ -19,6 +19,7 @@

import datetime
import mock
import packaging.version
import pytest
import pytz
import sqlalchemy
@@ -41,7 +42,7 @@
)


if sqlalchemy.__version__ < "1.4":
if packaging.version.parse(sqlalchemy.__version__) < packaging.version.parse("1.4"):
from sqlalchemy.testing.suite import LimitOffsetTest as _LimitOffsetTest

class LimitOffsetTest(_LimitOffsetTest):
@@ -28,13 +28,13 @@
from sqlalchemy.sql import expression, select, literal_column
from sqlalchemy.exc import NoSuchTableError
from sqlalchemy.orm import sessionmaker
import packaging.version
from pytz import timezone
import pytest
import sqlalchemy
import datetime
import decimal


ONE_ROW_CONTENTS_EXPANDED = [
588,
datetime.datetime(2013, 10, 10, 11, 27, 16, tzinfo=timezone("UTC")),
@@ -725,3 +725,31 @@ class MyTable(Base):
)

assert sorted(db.query(sqlalchemy.distinct(MyTable.my_column)).all()) == expected


@pytest.mark.skipif(
packaging.version.parse(sqlalchemy.__version__) < packaging.version.parse("1.4"),
reason="unnest (and other table-valued-function) support required version 1.4",
)
def test_unnest(engine, bigquery_dataset):
from sqlalchemy import select, func, String
from sqlalchemy_bigquery import ARRAY

conn = engine.connect()
metadata = MetaData()
table = Table(
f"{bigquery_dataset}.test_unnest", metadata, Column("objects", ARRAY(String)),
)
metadata.create_all(engine)
conn.execute(
table.insert(), [dict(objects=["a", "b", "c"]), dict(objects=["x", "y"])]
)
query = select([func.unnest(table.c.objects).alias("foo_objects").column])
compiled = str(query.compile(engine))
assert " ".join(compiled.strip().split()) == (
f"SELECT `foo_objects`"
f" FROM"
f" `{bigquery_dataset}.test_unnest` `{bigquery_dataset}.test_unnest_1`,"
f" unnest(`{bigquery_dataset}.test_unnest_1`.`objects`) AS `foo_objects`"
)
assert sorted(r[0] for r in conn.execute(query)) == ["a", "b", "c", "x", "y"]
@@ -21,20 +21,24 @@
import mock
import sqlite3

import packaging.version
import pytest
import sqlalchemy

import fauxdbi

sqlalchemy_version_info = tuple(map(int, sqlalchemy.__version__.split(".")))
sqlalchemy_version = packaging.version.parse(sqlalchemy.__version__)
sqlalchemy_1_3_or_higher = pytest.mark.skipif(
sqlalchemy_version_info < (1, 3), reason="requires sqlalchemy 1.3 or higher"
sqlalchemy_version < packaging.version.parse("1.3"),
reason="requires sqlalchemy 1.3 or higher",
)
sqlalchemy_1_4_or_higher = pytest.mark.skipif(
sqlalchemy_version_info < (1, 4), reason="requires sqlalchemy 1.4 or higher"
sqlalchemy_version < packaging.version.parse("1.4"),
reason="requires sqlalchemy 1.4 or higher",
)
sqlalchemy_before_1_4 = pytest.mark.skipif(
sqlalchemy_version_info >= (1, 4), reason="requires sqlalchemy 1.3 or lower"
sqlalchemy_version >= packaging.version.parse("1.4"),
reason="requires sqlalchemy 1.3 or lower",
)


0 comments on commit 9b5b002

Please sign in to comment.