Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support computed columns #139

Merged
merged 19 commits into from
Nov 19, 2021
7 changes: 7 additions & 0 deletions google/cloud/sqlalchemy_spanner/requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@

class Requirements(SuiteRequirements):
@property
def computed_columns(self):
return exclusions.open()

@property
def computed_columns_stored(self):
return exclusions.open()

def sane_rowcount(self):
return exclusions.closed()

Expand Down
29 changes: 21 additions & 8 deletions google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,13 @@ def limit_clause(self, select, **kw):
class SpannerDDLCompiler(DDLCompiler):
"""Spanner DDL statements compiler."""

def visit_computed_column(self, generated, **kw):
"""Computed column operator."""
text = "AS (%s) STORED" % self.sql_compiler.process(
generated.sqltext, include_table=False, literal_binds=True
)
return text

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overriding the method to drop GENERATED ALWAYS part of the statement

def visit_drop_table(self, drop_table):
"""
Cloud Spanner doesn't drop tables which have indexes
Expand Down Expand Up @@ -479,7 +486,7 @@ def get_columns(self, connection, table_name, schema=None, **kw):
list: The table every column dict-like description.
"""
sql = """
SELECT column_name, spanner_type, is_nullable
SELECT column_name, spanner_type, is_nullable, generation_expression
FROM information_schema.columns
WHERE
table_catalog = ''
Expand All @@ -499,14 +506,20 @@ def get_columns(self, connection, table_name, schema=None, **kw):
columns = snap.execute_sql(sql)

for col in columns:
cols_desc.append(
{
"name": col[0],
"type": self._designate_type(col[1]),
"nullable": col[2] == "YES",
"default": None,
col_desc = {
"name": col[0],
"type": self._designate_type(col[1]),
"nullable": col[2] == "YES",
"default": None,
}

if col[3] is not None:
col_desc["computed"] = {
"persisted": True,
"sqltext": col[3],
}
)
cols_desc.append(col_desc)

IlyaFaer marked this conversation as resolved.
Show resolved Hide resolved
return cols_desc

def _designate_type(self, str_repr):
Expand Down
98 changes: 98 additions & 0 deletions test/test_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,13 @@
from sqlalchemy import ForeignKey
from sqlalchemy import MetaData
from sqlalchemy.schema import DDL
from sqlalchemy.schema import Computed
from sqlalchemy.testing import config
from sqlalchemy.testing import engines
from sqlalchemy.testing import eq_
from sqlalchemy.testing import provide_metadata, emits_warning
from sqlalchemy.testing import fixtures
from sqlalchemy.testing import is_true
from sqlalchemy.testing.provision import temp_table_keyword_args
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
Expand All @@ -54,6 +56,9 @@
from sqlalchemy.types import Numeric
from sqlalchemy.types import Text
from sqlalchemy.testing import requires
from sqlalchemy.testing.fixtures import (
ComputedReflectionFixtureTest as _ComputedReflectionFixtureTest,
)

from google.api_core.datetime_helpers import DatetimeWithNanoseconds

Expand Down Expand Up @@ -89,6 +94,7 @@
QuotedNameArgumentTest as _QuotedNameArgumentTest,
ComponentReflectionTest as _ComponentReflectionTest,
CompositeKeyReflectionTest as _CompositeKeyReflectionTest,
ComputedReflectionTest as _ComputedReflectionTest,
)
from sqlalchemy.testing.suite.test_results import RowFetchTest as _RowFetchTest
from sqlalchemy.testing.suite.test_types import ( # noqa: F401, F403
Expand Down Expand Up @@ -1594,3 +1600,95 @@ def test_read_only(self):
with self._engine.connect().execution_options(read_only=True) as connection:
connection.execute(select(["*"], from_obj=self._table)).fetchall()
assert connection.connection.read_only is True


class ComputedReflectionFixtureTest(_ComputedReflectionFixtureTest):
@classmethod
def define_tables(cls, metadata):
"""SPANNER OVERRIDE:

Avoid using default values for computed columns.
"""
Table(
"computed_default_table",
metadata,
Column("id", Integer, primary_key=True),
Column("normal", Integer),
Column("computed_col", Integer, Computed("normal + 42")),
Column("with_default", Integer),
)

t = Table(
"computed_column_table",
metadata,
Column("id", Integer, primary_key=True),
Column("normal", Integer),
Column("computed_no_flag", Integer, Computed("normal + 42")),
)

if testing.requires.schemas.enabled:
t2 = Table(
"computed_column_table",
metadata,
Column("id", Integer, primary_key=True),
Column("normal", Integer),
Column("computed_no_flag", Integer, Computed("normal / 42")),
schema=config.test_schema,
)

if testing.requires.computed_columns_virtual.enabled:
t.append_column(
Column(
"computed_virtual",
Integer,
Computed("normal + 2", persisted=False),
)
)
if testing.requires.schemas.enabled:
t2.append_column(
Column(
"computed_virtual",
Integer,
Computed("normal / 2", persisted=False),
)
)
if testing.requires.computed_columns_stored.enabled:
t.append_column(
Column(
"computed_stored", Integer, Computed("normal - 42", persisted=True),
)
)
if testing.requires.schemas.enabled:
t2.append_column(
Column(
"computed_stored",
Integer,
Computed("normal * 42", persisted=True),
)
)


class ComputedReflectionTest(_ComputedReflectionTest, ComputedReflectionFixtureTest):
@pytest.mark.skip("Default values are not supported.")
def test_computed_col_default_not_set(self):
IlyaFaer marked this conversation as resolved.
Show resolved Hide resolved
pass

def test_get_column_returns_computed(self):
IlyaFaer marked this conversation as resolved.
Show resolved Hide resolved
"""
SPANNER OVERRIDE:

In Spanner all the generated columns are STORED,
meaning there are no persisted and not persisted
(in the terms of the SQLAlchemy) columns. The
method override omits the persistence reflection checks.
"""
insp = inspect(config.db)

cols = insp.get_columns("computed_default_table")
data = {c["name"]: c for c in cols}
for key in ("id", "normal", "with_default"):
is_true("computed" not in data[key])
compData = data["computed_col"]
is_true("computed" in compData)
is_true("sqltext" in compData["computed"])
eq_(self.normalize(compData["computed"]["sqltext"]), "normal+42")