diff --git a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py index 78cd4fab..e30f0e1f 100644 --- a/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py +++ b/google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py @@ -1513,40 +1513,45 @@ def get_multi_foreign_keys( tc.table_schema, tc.table_name, tc.constraint_name, - ctu.table_name, - ctu.table_schema, - ARRAY_AGG(DISTINCT ccu.column_name), - ARRAY_AGG( - DISTINCT CONCAT( - CAST(kcu.ordinal_position AS STRING), - '_____', - kcu.column_name - ) + tc_uq.table_name, + tc_uq.table_schema, + -- Find the corresponding pairs of columns for the foreign key constraint + -- and its related unique constraint. + ARRAY( + SELECT (kcu.column_name, kcu_uq.column_name) + FROM information_schema.key_column_usage AS kcu + JOIN information_schema.key_column_usage AS kcu_uq + ON kcu_uq.constraint_catalog = rc.unique_constraint_catalog + AND kcu_uq.constraint_schema = rc.unique_constraint_schema + AND kcu_uq.constraint_name = rc.unique_constraint_name + AND kcu_uq.ordinal_position = kcu.ordinal_position + WHERE + kcu.constraint_catalog = tc.constraint_catalog + AND kcu.constraint_schema = tc.constraint_schema + AND kcu.constraint_name = tc.constraint_name + ORDER BY kcu.ordinal_position ) FROM information_schema.table_constraints AS tc - JOIN information_schema.constraint_column_usage AS ccu - ON ccu.constraint_catalog = tc.table_catalog - and ccu.constraint_schema = tc.table_schema - and ccu.constraint_name = tc.constraint_name - JOIN information_schema.constraint_table_usage AS ctu - ON ctu.constraint_catalog = tc.table_catalog - and ctu.constraint_schema = tc.table_schema - and ctu.constraint_name = tc.constraint_name - JOIN information_schema.key_column_usage AS kcu - ON kcu.table_catalog = tc.table_catalog - and kcu.table_schema = tc.table_schema - and kcu.constraint_name = tc.constraint_name + -- Join the foreign key constraint for the referring table. + JOIN information_schema.referential_constraints AS rc + ON rc.constraint_catalog = tc.constraint_catalog + AND rc.constraint_schema = tc.constraint_schema + AND rc.constraint_name = tc.constraint_name + -- Join the corresponding unique constraint on the referenced table. + JOIN information_schema.table_constraints AS tc_uq + ON tc_uq.constraint_catalog = rc.unique_constraint_catalog + AND tc_uq.constraint_schema = rc.unique_constraint_schema + AND tc_uq.constraint_name = rc.unique_constraint_name + -- Join in the tables view so WHERE filters can reference fields in it. JOIN information_schema.tables AS t ON t.table_catalog = tc.table_catalog - and t.table_schema = tc.table_schema - and t.table_name = tc.table_name + AND t.table_schema = tc.table_schema + AND t.table_name = tc.table_name WHERE {table_filter_query} {table_type_query} {schema_filter_query} tc.constraint_type = "FOREIGN KEY" - GROUP BY tc.table_name, tc.table_schema, tc.constraint_name, - ctu.table_name, ctu.table_schema """.format( table_filter_query=table_filter_query, table_type_query=table_type_query, @@ -1558,29 +1563,16 @@ def get_multi_foreign_keys( result_dict = {} for row in rows: - # Due to Spanner limitations, arrays order is not guaranteed during - # aggregation. Still, for constraints it's vital to keep the order - # of the referred columns, otherwise SQLAlchemy and Alembic may start - # to occasionally drop and recreate constraints. To avoid this, the - # method uses prefixes with the `key_column_usage.ordinal_position` - # values to ensure the columns are aggregated into an array in the - # correct order. Prefixes are only used under the hood. For more details - # see the issue: - # https://github.com/googleapis/python-spanner-sqlalchemy/issues/271 - # - # The solution seem a bit clumsy, and should be improved as soon as a - # better approach found. row[0] = row[0] or None table_info = result_dict.get((row[0], row[1]), []) - for index, value in enumerate(sorted(row[6])): - row[6][index] = value.split("_____")[1] + constrained_columns, referred_columns = zip(*row[5]) fk_info = { "name": row[2], "referred_table": row[3], "referred_schema": row[4] or None, - "referred_columns": row[5], - "constrained_columns": row[6], + "referred_columns": list(referred_columns), + "constrained_columns": list(constrained_columns), } table_info.append(fk_info) diff --git a/test/system/test_basics.py b/test/system/test_basics.py index 75d9682f..bdd7dec3 100644 --- a/test/system/test_basics.py +++ b/test/system/test_basics.py @@ -14,12 +14,15 @@ import datetime import os from typing import Optional + +import pytest from sqlalchemy import ( text, Table, Column, Integer, ForeignKey, + ForeignKeyConstraint, PrimaryKeyConstraint, String, Index, @@ -96,6 +99,25 @@ def define_tables(cls, metadata): Column("color", String(20)), schema="schema", ) + # Add a composite primary key & foreign key example. + Table( + "composite_pk", + metadata, + Column("a", String, primary_key=True), + Column("b", String, primary_key=True), + ) + Table( + "composite_fk", + metadata, + Column("my_a", String, primary_key=True), + Column("my_b", String, primary_key=True), + Column("my_c", String, primary_key=True), + ForeignKeyConstraint( + ["my_a", "my_b"], + ["composite_pk.a", "composite_pk.b"], + name="composite_fk_composite_pk_a_b", + ), + ) def test_hello_world(self, connection): greeting = connection.execute(text("select 'Hello World'")) @@ -115,7 +137,7 @@ def test_reflect(self, connection): engine = connection.engine meta: MetaData = MetaData() meta.reflect(bind=engine) - eq_(3, len(meta.tables)) + eq_(5, len(meta.tables)) table = meta.tables["numbers"] eq_(5, len(table.columns)) eq_("number", table.columns[0].name) @@ -269,6 +291,13 @@ class User(Base): eq_(len(inserted_rows), len(selected_rows)) eq_(set(inserted_rows), set(selected_rows)) + @pytest.mark.skipif( + os.environ.get("SPANNER_EMULATOR_HOST") is not None, + reason=( + "Fails in emulator due to bug: " + "https://github.com/GoogleCloudPlatform/cloud-spanner-emulator/issues/279" + ), + ) def test_cross_schema_fk_lookups(self, connection): """Ensures we introspect FKs within & across schema.""" @@ -306,6 +335,27 @@ def test_cross_schema_fk_lookups(self, connection): ), ) + def test_composite_fk_lookups(self, connection): + """Ensures we introspect composite FKs.""" + + engine = connection.engine + + insp = inspect(engine) + eq_( + { + (None, "composite_fk"): [ + { + "name": "composite_fk_composite_pk_a_b", + "referred_table": "composite_pk", + "referred_schema": None, + "referred_columns": ["a", "b"], + "constrained_columns": ["my_a", "my_b"], + } + ] + }, + insp.get_multi_foreign_keys(filter_names=["composite_fk"]), + ) + def test_commit_timestamp(self, connection): """Ensures commit timestamps are set."""