From 91e35cf487af8e11e987d52670101c502603acee Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Fri, 21 Nov 2025 03:13:04 +0000 Subject: [PATCH 1/6] feat: topological sorting and foreign key retrieval across multiple database adapters - Added `get_tables_in_topological_order` method to DuckDB, OracleDB, Psycopg, and SQLite data dictionaries to retrieve tables in dependency order. - Introduced `get_foreign_keys` method to DuckDB, OracleDB, Psycopg, and SQLite data dictionaries for fetching foreign key metadata. - Enhanced index retrieval methods in OracleDB, Psycopg, and SQLite data dictionaries. - Created integration tests for topological sorting and foreign key metadata in AioSQLite and AsyncPG adapters. - Added `ForeignKeyMetadata`, `ColumnMetadata`, and `IndexMetadata` classes to the common driver module for structured metadata representation. --- docs/extensions/aiosql/api.rst | 2 + docs/extensions/litestar/api.rst | 1 + docs/guides/architecture/data-dictionary.md | 98 ++++++ docs/reference/driver.rst | 37 +++ docs/usage/drivers_and_querying.rst | 2 +- sqlspec/adapters/adbc/data_dictionary.py | 180 ++++++++++- sqlspec/adapters/aiosqlite/data_dictionary.py | 121 ++++++++ sqlspec/adapters/asyncmy/data_dictionary.py | 140 +++++++-- sqlspec/adapters/asyncpg/data_dictionary.py | 146 +++++++++ sqlspec/adapters/bigquery/data_dictionary.py | 71 +++++ sqlspec/adapters/duckdb/data_dictionary.py | 111 +++++++ sqlspec/adapters/oracledb/data_dictionary.py | 167 ++++++++++ sqlspec/adapters/psycopg/data_dictionary.py | 293 ++++++++++++++++++ sqlspec/adapters/sqlite/data_dictionary.py | 117 +++++++ sqlspec/driver/__init__.py | 8 + sqlspec/driver/_async.py | 36 +++ sqlspec/driver/_common.py | 187 +++++++++++ sqlspec/driver/_sync.py | 36 +++ .../test_aiosqlite/test_data_dictionary.py | 73 +++++ .../test_asyncpg/test_data_dictionary.py | 60 ++++ 20 files changed, 1856 insertions(+), 30 deletions(-) create mode 100644 docs/guides/architecture/data-dictionary.md create mode 100644 tests/integration/test_adapters/test_aiosqlite/test_data_dictionary.py diff --git a/docs/extensions/aiosql/api.rst b/docs/extensions/aiosql/api.rst index 6fe53f376..28f367484 100644 --- a/docs/extensions/aiosql/api.rst +++ b/docs/extensions/aiosql/api.rst @@ -24,6 +24,7 @@ AiosqlAsyncAdapter :members: :undoc-members: :show-inheritance: + :no-index: AiosqlSyncAdapter ----------------- @@ -32,6 +33,7 @@ AiosqlSyncAdapter :members: :undoc-members: :show-inheritance: + :no-index: Query Operators =============== diff --git a/docs/extensions/litestar/api.rst b/docs/extensions/litestar/api.rst index d2c157ede..eb799a01d 100644 --- a/docs/extensions/litestar/api.rst +++ b/docs/extensions/litestar/api.rst @@ -11,6 +11,7 @@ SQLSpecPlugin :members: :undoc-members: :show-inheritance: + :no-index: Configuration ============= diff --git a/docs/guides/architecture/data-dictionary.md b/docs/guides/architecture/data-dictionary.md new file mode 100644 index 000000000..728b9238f --- /dev/null +++ b/docs/guides/architecture/data-dictionary.md @@ -0,0 +1,98 @@ +# Data Dictionary & Introspection + +SQLSpec provides a unified Data Dictionary API to introspect database schemas across all supported adapters. This allows you to retrieve table metadata, columns, indexes, and foreign keys in a consistent format, regardless of the underlying database engine. + +## Core Concepts + +The `DataDictionary` is accessed via the `driver.data_dictionary` property. It provides methods to query the database catalog. + +### Introspection Capabilities + +- **Tables**: List tables in a schema. +- **Columns**: Get column details (name, type, nullable, default). +- **Indexes**: Get index definitions (columns, uniqueness). +- **Foreign Keys**: Get foreign key constraints and relationships. +- **Topological Sorting**: Get tables sorted by dependency order (useful for cleanups or migrations). + +## Usage + +### Basic Introspection + +```python +async with config.provide_session() as session: + # Get all tables in the default schema + tables = await session.data_dictionary.get_tables(session) + print(f"Tables: {tables}") + + # Get columns for a specific table + columns = await session.data_dictionary.get_columns(session, "users") + for col in columns: + print(f"{col['column_name']}: {col['data_type']}") +``` + +### Topological Sort (Dependency Ordering) + +One of the most powerful features is `get_tables_in_topological_order`. This returns table names sorted such that parent tables appear before child tables (tables with foreign keys to parents). + +This is essential for: + +- **Data Loading**: Insert into parents first. +- **Cleanup**: Delete in reverse order to avoid foreign key violations. + +```python +async with config.provide_session() as session: + # Get tables sorted parent -> child + sorted_tables = await session.data_dictionary.get_tables_in_topological_order(session) + + print("Insertion Order:", sorted_tables) + print("Deletion Order:", list(reversed(sorted_tables))) +``` + +**Implementation Details**: + +- **Postgres / SQLite / MySQL 8+**: Uses efficient Recursive CTEs in SQL. +- **Oracle**: Uses `CONNECT BY` queries. +- **Others (BigQuery, MySQL 5.7)**: Falls back to a Python-based topological sort using `graphlib`. + +### Metadata Types + +SQLSpec uses typed dataclasses for metadata results where possible. + +```python +from sqlspec.driver import ForeignKeyMetadata + +async with config.provide_session() as session: + fks: list[ForeignKeyMetadata] = await session.data_dictionary.get_foreign_keys(session, "orders") + + for fk in fks: + print(f"FK: {fk.column_name} -> {fk.referenced_table}.{fk.referenced_column}") +``` + +## Adapter Support Matrix + +| Feature | Postgres | SQLite | Oracle | MySQL | DuckDB | BigQuery | +|---------|----------|--------|--------|-------|--------|----------| +| Tables | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| Columns | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| Indexes | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | +| Foreign Keys | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| Topological Sort | ✅ (CTE) | ✅ (CTE) | ✅ (Connect By) | ✅ (CTE/Python) | ✅ (CTE) | ✅ (Python) | + +## API Reference + +### Data Dictionary Protocol + +The base interface shared by all adapters. + +```python +class DataDictionaryBase: + async def get_tables(self, driver, schema=None) -> list[str]: ... + + async def get_columns(self, driver, table, schema=None) -> list[dict]: ... + + async def get_indexes(self, driver, table, schema=None) -> list[dict]: ... + + async def get_foreign_keys(self, driver, table=None, schema=None) -> list[ForeignKeyMetadata]: ... + + async def get_tables_in_topological_order(self, driver, schema=None) -> list[str]: ... +``` diff --git a/docs/reference/driver.rst b/docs/reference/driver.rst index c404a762d..7a0e89a46 100644 --- a/docs/reference/driver.rst +++ b/docs/reference/driver.rst @@ -103,6 +103,43 @@ Connection Pooling :undoc-members: :show-inheritance: +Data Dictionary +=============== + +The Data Dictionary API provides standardized introspection capabilities across all supported databases. + +.. currentmodule:: sqlspec.driver + +.. autoclass:: DataDictionaryMixin + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: AsyncDataDictionaryBase + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: SyncDataDictionaryBase + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: ForeignKeyMetadata + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: ColumnMetadata + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: IndexMetadata + :members: + :undoc-members: + :show-inheritance: + Driver Protocols ================ diff --git a/docs/usage/drivers_and_querying.rst b/docs/usage/drivers_and_querying.rst index d73bd66ee..3f57f3940 100644 --- a/docs/usage/drivers_and_querying.rst +++ b/docs/usage/drivers_and_querying.rst @@ -471,7 +471,7 @@ Performance Tips :start-after: # start-example :end-before: # end-example :caption: ``asyncpg connection pooling`` - :dedent: 4 + :dedent: 2 **2. Batch Operations** diff --git a/sqlspec/adapters/adbc/data_dictionary.py b/sqlspec/adapters/adbc/data_dictionary.py index 77c4f14cc..e608b90e6 100644 --- a/sqlspec/adapters/adbc/data_dictionary.py +++ b/sqlspec/adapters/adbc/data_dictionary.py @@ -3,7 +3,7 @@ import re from typing import TYPE_CHECKING, Any, cast -from sqlspec.driver import SyncDataDictionaryBase, SyncDriverAdapterBase, VersionInfo +from sqlspec.driver import ForeignKeyMetadata, SyncDataDictionaryBase, SyncDriverAdapterBase, VersionInfo from sqlspec.utils.logging import get_logger if TYPE_CHECKING: @@ -27,6 +27,184 @@ class AdbcDataDictionary(SyncDataDictionaryBase): Delegates to appropriate dialect-specific logic based on the driver's dialect. """ + def get_foreign_keys( + self, driver: "SyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None + ) -> "list[ForeignKeyMetadata]": + """Get foreign key metadata based on detected dialect.""" + + dialect = self._get_dialect(driver) + adbc_driver = cast("AdbcDriver", driver) + + if dialect == "sqlite": + if table: + # Single table + result = adbc_driver.execute(f"PRAGMA foreign_key_list('{table}')") + # SQLite PRAGMA returns: id, seq, table, from, to, on_update, on_delete, match + # We need 'from' (col) and 'table' (ref_table) and 'to' (ref_col) + # Note: PRAGMA results from ADBC might be keyed by name or index depending on driver + return [ + ForeignKeyMetadata( + table_name=table, + column_name=row["from"] if isinstance(row, dict) else row[3], + referenced_table=row["table"] if isinstance(row, dict) else row[2], + referenced_column=row["to"] if isinstance(row, dict) else row[4], + ) + for row in result.data + ] + # For all tables in SQLite we'd have to iterate, which base class doesn't do efficiently. + # We'll just return empty if no table specified for now, or iterate if crucial. + # Base implementation will call this per-table if needed? No, base implementation expects all if table is None. + # For SQLite ADBC, iterating tables is expensive. Let's support single table primarily. + return [] + + # SQL-standard compliant databases (Postgres, MySQL, DuckDB, BigQuery) + # They all support information_schema.key_column_usage roughly the same way + + # Postgres/DuckDB/MySQL query + where_clauses = [] + params = [] + + if dialect == "bigquery": + dataset = schema + if not dataset: + return [] # BigQuery requires dataset for info schema + kcu = f"`{dataset}.INFORMATION_SCHEMA.KEY_COLUMN_USAGE`" + rc = f"`{dataset}.INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS`" + # BQ uses named params usually or positional? ADBC usually positional '?' + # But BQ driver might want named. ADBC standardizes on '?' usually. + sql = f""" + SELECT + kcu.table_name, + kcu.column_name, + pk_kcu.table_name AS referenced_table_name, + pk_kcu.column_name AS referenced_column_name, + kcu.constraint_name, + kcu.table_schema, + pk_kcu.table_schema AS referenced_table_schema + FROM {kcu} kcu + JOIN {rc} rc ON kcu.constraint_name = rc.constraint_name + JOIN {kcu} pk_kcu + ON rc.unique_constraint_name = pk_kcu.constraint_name + AND kcu.ordinal_position = pk_kcu.ordinal_position + """ + if table: + sql += f" WHERE kcu.table_name = '{table}'" # Simple string sub for BQ ADBC safety check needed? + + try: + result = adbc_driver.execute(sql) + return [ + ForeignKeyMetadata( + table_name=row["table_name"], + column_name=row["column_name"], + referenced_table=row["referenced_table_name"], + referenced_column=row["referenced_column_name"], + constraint_name=row["constraint_name"], + schema=row["table_schema"], + referenced_schema=row["referenced_table_schema"], + ) + for row in result.data + ] + except Exception: + return [] + + # Standard ANSI SQL (Postgres, MySQL, DuckDB) + kcu = "information_schema.key_column_usage" + + if dialect == "postgres": + # Postgres joins with constraint_column_usage or referential_constraints + # Let's use the query we verified for asyncpg + sql = """ + SELECT + kcu.table_name, + kcu.column_name, + ccu.table_name AS referenced_table_name, + ccu.column_name AS referenced_column_name, + tc.constraint_name, + tc.table_schema, + ccu.table_schema AS referenced_table_schema + FROM + information_schema.table_constraints AS tc + JOIN information_schema.key_column_usage AS kcu + ON tc.constraint_name = kcu.constraint_name + AND tc.table_schema = kcu.table_schema + JOIN information_schema.constraint_column_usage AS ccu + ON ccu.constraint_name = tc.constraint_name + AND ccu.table_schema = tc.table_schema + WHERE tc.constraint_type = 'FOREIGN KEY' + """ + if schema: + sql += " AND tc.table_schema = ?" + params.append(schema) + if table: + sql += " AND tc.table_name = ?" + params.append(table) + + elif dialect == "mysql": + # MySQL information_schema + sql = """ + SELECT + table_name, + column_name, + referenced_table_name, + referenced_column_name, + constraint_name, + table_schema, + referenced_table_schema + FROM information_schema.key_column_usage + WHERE referenced_table_name IS NOT NULL + """ + if schema: + sql += " AND table_schema = ?" + params.append(schema) + if table: + sql += " AND table_name = ?" + params.append(table) + + elif dialect == "duckdb": + # DuckDB similar to Postgres but sometimes requires referential_constraints join + sql = """ + SELECT + kcu.table_name, + kcu.column_name, + pk_kcu.table_name AS referenced_table_name, + pk_kcu.column_name AS referenced_column_name, + kcu.constraint_name, + kcu.table_schema, + pk_kcu.table_schema AS referenced_table_schema + FROM information_schema.key_column_usage kcu + JOIN information_schema.referential_constraints rc + ON kcu.constraint_name = rc.constraint_name + JOIN information_schema.key_column_usage pk_kcu + ON rc.unique_constraint_name = pk_kcu.constraint_name + AND kcu.ordinal_position = pk_kcu.ordinal_position + WHERE 1=1 + """ + if schema: + sql += " AND kcu.table_schema = ?" + params.append(schema) + if table: + sql += " AND kcu.table_name = ?" + params.append(table) + else: + return [] + + try: + result = adbc_driver.execute(sql, tuple(params)) + return [ + ForeignKeyMetadata( + table_name=row["table_name"], + column_name=row["column_name"], + referenced_table=row["referenced_table_name"], + referenced_column=row["referenced_column_name"], + constraint_name=row["constraint_name"], + schema=row.get("table_schema"), + referenced_schema=row.get("referenced_table_schema"), + ) + for row in result.data + ] + except Exception: + return [] + def _get_dialect(self, driver: SyncDriverAdapterBase) -> str: """Get dialect from ADBC driver. diff --git a/sqlspec/adapters/aiosqlite/data_dictionary.py b/sqlspec/adapters/aiosqlite/data_dictionary.py index d841cce2f..98318a7e6 100644 --- a/sqlspec/adapters/aiosqlite/data_dictionary.py +++ b/sqlspec/adapters/aiosqlite/data_dictionary.py @@ -10,6 +10,7 @@ from collections.abc import Callable from sqlspec.adapters.aiosqlite.driver import AiosqliteDriver + from sqlspec.driver import ForeignKeyMetadata logger = get_logger("adapters.aiosqlite.data_dictionary") @@ -129,6 +130,126 @@ async def get_columns( for row in result.data or [] ] + async def get_tables_in_topological_order( + self, driver: "AsyncDriverAdapterBase", schema: "str | None" = None + ) -> "list[str]": + """Get tables sorted by topological dependency order.""" + aiosqlite_driver = cast("AiosqliteDriver", driver) + + # Assuming modern SQLite with pragma table-valued functions + sql = """ + WITH RECURSIVE dependency_tree AS ( + SELECT + m.name as table_name, + 0 as level, + '/' || m.name || '/' as path + FROM sqlite_schema m + WHERE m.type = 'table' + AND m.name NOT LIKE 'sqlite_%' + AND NOT EXISTS ( + SELECT 1 FROM pragma_foreign_key_list(m.name) + ) + + UNION ALL + + SELECT + m.name as table_name, + dt.level + 1, + dt.path || m.name || '/' + FROM sqlite_schema m + JOIN pragma_foreign_key_list(m.name) fk + JOIN dependency_tree dt ON fk."table" = dt.table_name + WHERE m.type = 'table' + AND m.name NOT LIKE 'sqlite_%' + AND instr(dt.path, '/' || m.name || '/') = 0 + ) + SELECT DISTINCT table_name, level FROM dependency_tree ORDER BY level, table_name; + """ + try: + result = await aiosqlite_driver.execute(sql) + return [row["table_name"] if isinstance(row, dict) else row[0] for row in result.data] + except Exception: + # Fallback to Python sort if TVF not supported or other error + return await super().get_tables_in_topological_order(driver, schema) + + async def get_foreign_keys( + self, driver: "AsyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None + ) -> "list[ForeignKeyMetadata]": + """Get foreign key metadata.""" + from sqlspec.driver import ForeignKeyMetadata + + aiosqlite_driver = cast("AiosqliteDriver", driver) + + if table: + # Single table optimization + sql = f"SELECT '{table}' as table_name, fk.* FROM pragma_foreign_key_list('{table}') fk" + result = await aiosqlite_driver.execute(sql) + else: + # All tables + sql = """ + SELECT m.name as table_name, fk.* + FROM sqlite_schema m, pragma_foreign_key_list(m.name) fk + WHERE m.type = 'table' AND m.name NOT LIKE 'sqlite_%' + """ + result = await aiosqlite_driver.execute(sql) + + fks = [] + for row in result.data: + if isinstance(row, (list, tuple)): + table_name = row[0] + ref_table = row[3] + col = row[4] + ref_col = row[5] + else: + table_name = row["table_name"] + ref_table = row["table"] + col = row["from"] + ref_col = row["to"] + + fks.append( + ForeignKeyMetadata( + table_name=table_name, + column_name=col, + referenced_table=ref_table, + referenced_column=ref_col, + constraint_name=None, + schema=None, + referenced_schema=None, + ) + ) + return fks + + async def get_indexes( + self, driver: "AsyncDriverAdapterBase", table: str, schema: "str | None" = None + ) -> "list[dict[str, Any]]": + """Get index information for a table.""" + aiosqlite_driver = cast("AiosqliteDriver", driver) + + # 1. Get indexes for table + index_list_res = await aiosqlite_driver.execute(f"PRAGMA index_list('{table}')") + indexes = [] + + for idx_row in index_list_res.data: + if isinstance(idx_row, (list, tuple)): + idx_name = idx_row[1] + unique = bool(idx_row[2]) + else: + idx_name = idx_row["name"] + unique = bool(idx_row["unique"]) + + # 2. Get columns for index + info_res = await aiosqlite_driver.execute(f"PRAGMA index_info('{idx_name}')") + cols = [] + for col_row in info_res.data: + if isinstance(col_row, (list, tuple)): + cols.append(col_row[2]) + else: + cols.append(col_row["name"]) + + indexes.append({"name": idx_name, "columns": cols, "unique": unique, "primary": False, "table_name": table}) + + return indexes + def list_available_features(self) -> "list[str]": """List available SQLite feature flags. diff --git a/sqlspec/adapters/asyncmy/data_dictionary.py b/sqlspec/adapters/asyncmy/data_dictionary.py index c8bd81421..75bffcb8b 100644 --- a/sqlspec/adapters/asyncmy/data_dictionary.py +++ b/sqlspec/adapters/asyncmy/data_dictionary.py @@ -10,6 +10,7 @@ from collections.abc import Callable from sqlspec.adapters.asyncmy.driver import AsyncmyDriver + from sqlspec.driver import ForeignKeyMetadata logger = get_logger("adapters.asyncmy.data_dictionary") @@ -104,42 +105,125 @@ async def get_optimal_type(self, driver: AsyncDriverAdapterBase, type_category: } return type_map.get(type_category, "VARCHAR(255)") - async def get_columns( - self, driver: AsyncDriverAdapterBase, table: str, schema: "str | None" = None - ) -> "list[dict[str, Any]]": - """Get column information for a table using information_schema. + async def get_tables_in_topological_order( + self, driver: "AsyncDriverAdapterBase", schema: "str | None" = None + ) -> "list[str]": + """Get tables sorted by topological dependency order.""" + version = await self.get_version(driver) + if version and version >= VersionInfo(8, 0, 1): + # Use Recursive CTE + asyncmy_driver = cast("AsyncmyDriver", driver) + schema_clause = f"'{schema}'" if schema else "DATABASE()" - Args: - driver: AsyncMy driver instance - table: Table name to query columns for - schema: Schema name (database name in MySQL) + sql = f""" + WITH RECURSIVE dependency_tree AS ( + SELECT + table_name, + 0 AS level, + CAST(table_name AS CHAR(4000)) AS path + FROM information_schema.tables t + WHERE t.table_type = 'BASE TABLE' + AND t.table_schema = {schema_clause} + AND NOT EXISTS ( + SELECT 1 + FROM information_schema.key_column_usage kcu + WHERE kcu.table_name = t.table_name + AND kcu.table_schema = t.table_schema + AND kcu.referenced_table_name IS NOT NULL + ) + + UNION ALL + + SELECT + kcu.table_name, + dt.level + 1, + CONCAT(dt.path, ',', kcu.table_name) + FROM information_schema.key_column_usage kcu + JOIN dependency_tree dt ON kcu.referenced_table_name = dt.table_name + WHERE kcu.table_schema = {schema_clause} + AND kcu.referenced_table_name IS NOT NULL + AND NOT FIND_IN_SET(kcu.table_name, dt.path) + ) + SELECT DISTINCT table_name, level + FROM dependency_tree + ORDER BY level, table_name + """ + try: + result = await asyncmy_driver.execute(sql) + return [row["table_name"] for row in result.data] + except Exception as exc: + logger.warning("Failed to get tables in topological order via SQL: %s", exc) + + return await super().get_tables_in_topological_order(driver, schema) + + async def get_foreign_keys( + self, driver: "AsyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None + ) -> "list[ForeignKeyMetadata]": + """Get foreign key metadata.""" + from sqlspec.driver import ForeignKeyMetadata - Returns: - List of column metadata dictionaries with keys: - - column_name: Name of the column - - data_type: MySQL data type - - is_nullable: Whether column allows NULL (YES/NO) - - column_default: Default value if any - """ asyncmy_driver = cast("AsyncmyDriver", driver) + where_clauses = ["referenced_table_name IS NOT NULL"] + if table: + where_clauses.append(f"table_name = '{table}'") if schema: - sql = f""" - SELECT column_name, data_type, is_nullable, column_default - FROM information_schema.columns - WHERE table_name = '{table}' AND table_schema = '{schema}' - ORDER BY ordinal_position - """ + where_clauses.append(f"table_schema = '{schema}'") else: - sql = f""" - SELECT column_name, data_type, is_nullable, column_default - FROM information_schema.columns - WHERE table_name = '{table}' - ORDER BY ordinal_position - """ + where_clauses.append("table_schema = DATABASE()") + + where_str = " AND ".join(where_clauses) + + sql = f""" + SELECT + table_name, + column_name, + referenced_table_name, + referenced_column_name, + constraint_name, + table_schema, + referenced_table_schema + FROM information_schema.key_column_usage + WHERE {where_str} + """ + result = await asyncmy_driver.execute(sql) + + return [ + ForeignKeyMetadata( + table_name=row["table_name"], + column_name=row["column_name"], + referenced_table=row["referenced_table_name"], + referenced_column=row["referenced_column_name"], + constraint_name=row["constraint_name"], + schema=row["table_schema"], + referenced_schema=row.get("referenced_table_schema"), + ) + for row in result.data + ] + + async def get_indexes( + self, driver: "AsyncDriverAdapterBase", table: str, schema: "str | None" = None + ) -> "list[dict[str, Any]]": + """Get index information for a table.""" + asyncmy_driver = cast("AsyncmyDriver", driver) + sql = f"SHOW INDEX FROM {table}" if schema is None else f"SHOW INDEX FROM {table} FROM {schema}" result = await asyncmy_driver.execute(sql) - return result.data or [] + # Parse SHOW INDEX output + indexes: dict[str, dict[str, Any]] = {} + for row in result.data: + idx_name = row["Key_name"] + if idx_name not in indexes: + indexes[idx_name] = { + "name": idx_name, + "columns": [], + "unique": row["Non_unique"] == 0, + "primary": idx_name == "PRIMARY", + "table_name": table, + } + indexes[idx_name]["columns"].append(row["Column_name"]) + + return list(indexes.values()) def list_available_features(self) -> "list[str]": """List available MySQL feature flags. diff --git a/sqlspec/adapters/asyncpg/data_dictionary.py b/sqlspec/adapters/asyncpg/data_dictionary.py index 606dd2fe7..3fbdebfe6 100644 --- a/sqlspec/adapters/asyncpg/data_dictionary.py +++ b/sqlspec/adapters/asyncpg/data_dictionary.py @@ -10,6 +10,7 @@ from collections.abc import Callable from sqlspec.adapters.asyncpg.driver import AsyncpgDriver + from sqlspec.driver import ForeignKeyMetadata logger = get_logger("adapters.asyncpg.data_dictionary") @@ -158,6 +159,151 @@ async def get_columns( result = await asyncpg_driver.execute(sql, (table, schema_name)) return result.data or [] + async def get_tables_in_topological_order( + self, driver: "AsyncDriverAdapterBase", schema: "str | None" = None + ) -> "list[str]": + """Get tables sorted by topological dependency order using Recursive CTE.""" + asyncpg_driver = cast("AsyncpgDriver", driver) + schema_name = schema or "public" + + sql = """ + WITH RECURSIVE dependency_tree AS ( + SELECT + t.table_name::text, + 0 AS level, + ARRAY[t.table_name::text] AS path + FROM information_schema.tables t + WHERE t.table_type = 'BASE TABLE' + AND t.table_schema = $1 + AND NOT EXISTS ( + SELECT 1 + FROM information_schema.table_constraints tc + JOIN information_schema.key_column_usage kcu + ON tc.constraint_name = kcu.constraint_name + AND tc.table_schema = kcu.table_schema + WHERE tc.constraint_type = 'FOREIGN KEY' + AND tc.table_name = t.table_name + AND tc.table_schema = t.table_schema + ) + + UNION ALL + + SELECT + tc.table_name::text, + dt.level + 1, + dt.path || tc.table_name::text + FROM information_schema.table_constraints tc + JOIN information_schema.key_column_usage kcu + ON tc.constraint_name = kcu.constraint_name + AND tc.table_schema = kcu.table_schema + JOIN information_schema.constraint_column_usage ccu + ON ccu.constraint_name = tc.constraint_name + AND ccu.table_schema = tc.table_schema + JOIN dependency_tree dt + ON ccu.table_name = dt.table_name + WHERE tc.constraint_type = 'FOREIGN KEY' + AND tc.table_schema = $1 + AND ccu.table_schema = $1 + AND NOT (tc.table_name = ANY(dt.path)) + ) + SELECT DISTINCT table_name, level + FROM dependency_tree + ORDER BY level, table_name; + """ + result = await asyncpg_driver.execute(sql, (schema_name,)) + return [row["table_name"] for row in result.data] + + async def get_foreign_keys( + self, driver: "AsyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None + ) -> "list[ForeignKeyMetadata]": + """Get foreign key metadata.""" + from sqlspec.driver import ForeignKeyMetadata + + asyncpg_driver = cast("AsyncpgDriver", driver) + schema_name = schema or "public" + + sql = """ + SELECT + kcu.table_name, + kcu.column_name, + ccu.table_name AS referenced_table_name, + ccu.column_name AS referenced_column_name, + tc.constraint_name, + tc.table_schema, + ccu.table_schema AS referenced_table_schema + FROM + information_schema.table_constraints AS tc + JOIN information_schema.key_column_usage AS kcu + ON tc.constraint_name = kcu.constraint_name + AND tc.table_schema = kcu.table_schema + JOIN information_schema.constraint_column_usage AS ccu + ON ccu.constraint_name = tc.constraint_name + AND ccu.table_schema = tc.table_schema + WHERE tc.constraint_type = 'FOREIGN KEY' + AND ($1::text IS NULL OR tc.table_schema = $1) + AND ($2::text IS NULL OR tc.table_name = $2) + """ + result = await asyncpg_driver.execute(sql, (schema_name, table)) + + return [ + ForeignKeyMetadata( + table_name=row["table_name"], + column_name=row["column_name"], + referenced_table=row["referenced_table_name"], + referenced_column=row["referenced_column_name"], + constraint_name=row["constraint_name"], + schema=row["table_schema"], + referenced_schema=row["referenced_table_schema"], + ) + for row in result.data + ] + + async def get_indexes( + self, driver: "AsyncDriverAdapterBase", table: str, schema: "str | None" = None + ) -> "list[dict[str, Any]]": + """Get index information for a table.""" + asyncpg_driver = cast("AsyncpgDriver", driver) + schema_name = schema or "public" + + sql = """ + SELECT + i.relname as index_name, + ix.indisunique as is_unique, + ix.indisprimary as is_primary, + array_agg(a.attname ORDER BY array_position(ix.indkey, a.attnum)) as columns + FROM + pg_class t, + pg_class i, + pg_index ix, + pg_attribute a, + pg_namespace n + WHERE + t.oid = ix.indrelid + AND i.oid = ix.indexrelid + AND a.attrelid = t.oid + AND a.attnum = ANY(ix.indkey) + AND t.relkind = 'r' + AND t.relnamespace = n.oid + AND n.nspname = $1 + AND t.relname = $2 + GROUP BY + i.relname, + ix.indisunique, + ix.indisprimary + """ + result = await asyncpg_driver.execute(sql, (schema_name, table)) + + return [ + { + "name": row["index_name"], + "columns": row["columns"], + "unique": row["is_unique"], + "primary": row["is_primary"], + "table_name": table, + } + for row in result.data + ] + def list_available_features(self) -> "list[str]": """List available PostgreSQL feature flags. diff --git a/sqlspec/adapters/bigquery/data_dictionary.py b/sqlspec/adapters/bigquery/data_dictionary.py index dc4eb5d3f..1f66d77a0 100644 --- a/sqlspec/adapters/bigquery/data_dictionary.py +++ b/sqlspec/adapters/bigquery/data_dictionary.py @@ -7,6 +7,7 @@ if TYPE_CHECKING: from sqlspec.adapters.bigquery.driver import BigQueryDriver + from sqlspec.driver import ForeignKeyMetadata logger = get_logger("adapters.bigquery.data_dictionary") @@ -125,6 +126,76 @@ def get_columns( result = bigquery_driver.execute(sql) return result.data or [] + def get_tables(self, driver: "SyncDriverAdapterBase", schema: "str | None" = None) -> "list[str]": + """Get list of tables in schema.""" + bigquery_driver = cast("BigQueryDriver", driver) + if schema: + sql = f"SELECT table_name FROM `{schema}.INFORMATION_SCHEMA.TABLES` WHERE table_type = 'BASE TABLE'" + else: + sql = "SELECT table_name FROM INFORMATION_SCHEMA.TABLES WHERE table_type = 'BASE TABLE'" + + result = bigquery_driver.execute(sql) + return [row["table_name"] for row in result.data] + + def get_foreign_keys( + self, driver: "SyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None + ) -> "list[ForeignKeyMetadata]": + """Get foreign key metadata.""" + from sqlspec.driver import ForeignKeyMetadata + + bigquery_driver = cast("BigQueryDriver", driver) + + dataset = schema + if dataset: + kcu_table = f"`{dataset}.INFORMATION_SCHEMA.KEY_COLUMN_USAGE`" + rc_table = f"`{dataset}.INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS`" + else: + kcu_table = "INFORMATION_SCHEMA.KEY_COLUMN_USAGE" + rc_table = "INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS" + + where_clauses = [] + if table: + where_clauses.append(f"kcu.table_name = '{table}'") + + where_str = " AND ".join(where_clauses) + if where_str: + where_str = "WHERE " + where_str + + sql = f""" + SELECT + kcu.table_name, + kcu.column_name, + pk_kcu.table_name AS referenced_table_name, + pk_kcu.column_name AS referenced_column_name, + kcu.constraint_name, + kcu.table_schema, + pk_kcu.table_schema AS referenced_table_schema + FROM {kcu_table} kcu + JOIN {rc_table} rc ON kcu.constraint_name = rc.constraint_name + JOIN {kcu_table} pk_kcu + ON rc.unique_constraint_name = pk_kcu.constraint_name + AND kcu.ordinal_position = pk_kcu.ordinal_position + {where_str} + """ + + try: + result = bigquery_driver.execute(sql) + return [ + ForeignKeyMetadata( + table_name=row["table_name"], + column_name=row["column_name"], + referenced_table=row["referenced_table_name"], + referenced_column=row["referenced_column_name"], + constraint_name=row["constraint_name"], + schema=row["table_schema"], + referenced_schema=row["referenced_table_schema"], + ) + for row in result.data + ] + except Exception: + logger.warning("Failed to fetch foreign keys from BigQuery") + return [] + def list_available_features(self) -> "list[str]": """List available BigQuery feature flags. diff --git a/sqlspec/adapters/duckdb/data_dictionary.py b/sqlspec/adapters/duckdb/data_dictionary.py index c4c43957a..efc6b22a1 100644 --- a/sqlspec/adapters/duckdb/data_dictionary.py +++ b/sqlspec/adapters/duckdb/data_dictionary.py @@ -10,6 +10,7 @@ from collections.abc import Callable from sqlspec.adapters.duckdb.driver import DuckDBDriver + from sqlspec.driver import ForeignKeyMetadata logger = get_logger("adapters.duckdb.data_dictionary") @@ -141,6 +142,116 @@ def get_columns( result = duckdb_driver.execute(sql) return result.data or [] + def get_tables_in_topological_order( + self, driver: "SyncDriverAdapterBase", schema: "str | None" = None + ) -> "list[str]": + """Get tables sorted by topological dependency order.""" + duckdb_driver = cast("DuckDBDriver", driver) + schema_clause = f"'{schema}'" if schema else "current_schema()" + + sql = f""" + WITH RECURSIVE dependency_tree AS ( + SELECT + table_name, + 0 AS level, + [table_name] AS path + FROM information_schema.tables t + WHERE t.table_type = 'BASE TABLE' + AND t.table_schema = {schema_clause} + AND NOT EXISTS ( + SELECT 1 + FROM information_schema.key_column_usage kcu + WHERE kcu.table_name = t.table_name + AND kcu.table_schema = t.table_schema + AND kcu.constraint_name IN (SELECT constraint_name FROM information_schema.referential_constraints) + ) + + UNION ALL + + SELECT + kcu.table_name, + dt.level + 1, + list_append(dt.path, kcu.table_name) + FROM information_schema.key_column_usage kcu + JOIN information_schema.referential_constraints rc ON kcu.constraint_name = rc.constraint_name + JOIN information_schema.key_column_usage pk_kcu + ON rc.unique_constraint_name = pk_kcu.constraint_name + AND rc.unique_constraint_schema = pk_kcu.constraint_schema + JOIN dependency_tree dt ON dt.table_name = pk_kcu.table_name + WHERE kcu.table_schema = {schema_clause} + AND NOT list_contains(dt.path, kcu.table_name) + ) + SELECT DISTINCT table_name, level + FROM dependency_tree + ORDER BY level, table_name + """ + try: + result = duckdb_driver.execute(sql) + return [row["table_name"] for row in result.get_data()] + except Exception: + return self.sort_tables_topologically( + self.get_tables(driver, schema), self.get_foreign_keys(driver, schema=schema) + ) + + def get_foreign_keys( + self, driver: "SyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None + ) -> "list[ForeignKeyMetadata]": + """Get foreign key metadata.""" + from sqlspec.driver import ForeignKeyMetadata + + duckdb_driver = cast("DuckDBDriver", driver) + + where_clauses = [] + if schema: + where_clauses.append(f"kcu.table_schema = '{schema}'") + if table: + where_clauses.append(f"kcu.table_name = '{table}'") + + where_str = " AND ".join(where_clauses) if where_clauses else "1=1" + + sql = f""" + SELECT + kcu.table_name, + kcu.column_name, + pk_kcu.table_name AS referenced_table_name, + pk_kcu.column_name AS referenced_column_name, + kcu.constraint_name, + kcu.table_schema, + pk_kcu.table_schema AS referenced_table_schema + FROM information_schema.key_column_usage kcu + JOIN information_schema.referential_constraints rc + ON kcu.constraint_name = rc.constraint_name + JOIN information_schema.key_column_usage pk_kcu + ON rc.unique_constraint_name = pk_kcu.constraint_name + AND kcu.ordinal_position = pk_kcu.ordinal_position + WHERE {where_str} + """ + + result = duckdb_driver.execute(sql) + + return [ + ForeignKeyMetadata( + table_name=row["table_name"], + column_name=row["column_name"], + referenced_table=row["referenced_table_name"], + referenced_column=row["referenced_column_name"], + constraint_name=row["constraint_name"], + schema=row["table_schema"], + referenced_schema=row["referenced_table_schema"], + ) + for row in result.get_data() + ] + + def get_indexes( + self, driver: "SyncDriverAdapterBase", table: str, schema: "str | None" = None + ) -> "list[dict[str, Any]]": + """Get index information for a table.""" + # DuckDB doesn't expose indexes easily in IS yet, usually just constraints? + # Fallback to empty for now or implementation specific. + # PRD mentions it but no specific instruction on implementation detail if missing. + # Returning empty list. + return [] + def list_available_features(self) -> "list[str]": """List available DuckDB feature flags. diff --git a/sqlspec/adapters/oracledb/data_dictionary.py b/sqlspec/adapters/oracledb/data_dictionary.py index 272f43e74..721b30973 100644 --- a/sqlspec/adapters/oracledb/data_dictionary.py +++ b/sqlspec/adapters/oracledb/data_dictionary.py @@ -17,6 +17,7 @@ from collections.abc import Callable from sqlspec.adapters.oracledb.driver import OracleAsyncDriver, OracleSyncDriver + from sqlspec.driver import ForeignKeyMetadata logger = get_logger("adapters.oracledb.data_dictionary") @@ -139,6 +140,52 @@ def _get_columns_sql(self, table: str, schema: "str | None" = None) -> str: ORDER BY column_id """ + def _get_foreign_keys_sql(self, table: "str | None" = None) -> str: + where_clause = f"AND c.table_name = '{table.upper()}'" if table else "" + return f""" + SELECT + c.table_name, + cc.column_name, + r.table_name AS referenced_table_name, + rcc.column_name AS referenced_column_name, + c.constraint_name, + c.owner AS schema, + r.owner AS referenced_schema + FROM user_constraints c + JOIN user_cons_columns cc ON c.constraint_name = cc.constraint_name + JOIN user_constraints r ON c.r_constraint_name = r.constraint_name + JOIN user_cons_columns rcc ON r.constraint_name = rcc.constraint_name + WHERE c.constraint_type = 'R' + AND cc.position = rcc.position + {where_clause} + """ + + def _get_indexes_sql(self, table: "str | None" = None) -> str: + where_clause = f"AND i.table_name = '{table.upper()}'" if table else "" + return f""" + SELECT + i.index_name AS name, + i.table_name, + i.uniqueness, + LISTAGG(ic.column_name, ',') WITHIN GROUP (ORDER BY ic.column_position) AS columns + FROM user_indexes i + JOIN user_ind_columns ic ON i.index_name = ic.index_name + WHERE 1=1 {where_clause} + GROUP BY i.index_name, i.table_name, i.uniqueness + """ + + def _get_topological_sort_sql(self) -> str: + return """ + SELECT table_name, MAX(LEVEL) as lvl + FROM user_constraints + START WITH table_name NOT IN ( + SELECT table_name FROM user_constraints WHERE constraint_type = 'R' + ) + CONNECT BY NOCYCLE PRIOR constraint_name = r_constraint_name + GROUP BY table_name + ORDER BY lvl, table_name + """ + def _select_component_version_row(self, driver: "OracleSyncDriver") -> "dict[str, Any] | None": """Fetch the latest Oracle component version row. @@ -382,6 +429,66 @@ def get_columns( result = oracle_driver.execute(self._get_columns_sql(table, schema)) return result.get_data() + def get_tables_in_topological_order( + self, driver: "SyncDriverAdapterBase", schema: "str | None" = None + ) -> "list[str]": + """Get tables sorted by topological dependency order.""" + oracle_driver = cast("OracleSyncDriver", driver) + # Fetch dependency sorted tables + result = oracle_driver.execute(self._get_topological_sort_sql()) + sorted_tables = [row["table_name"] for row in result.get_data()] + + # Fetch all tables + all_result = oracle_driver.execute("SELECT table_name FROM user_tables") + all_tables = {row["table_name"] for row in all_result.get_data()} + + # Add disconnected tables (level 0 implied) at the beginning + sorted_set = set(sorted_tables) + disconnected = list(all_tables - sorted_set) + disconnected.sort() + + return disconnected + sorted_tables + + def get_foreign_keys( + self, driver: "SyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None + ) -> "list[ForeignKeyMetadata]": + """Get foreign key metadata.""" + from sqlspec.driver import ForeignKeyMetadata + + oracle_driver = cast("OracleSyncDriver", driver) + result = oracle_driver.execute(self._get_foreign_keys_sql(table)) + + return [ + ForeignKeyMetadata( + table_name=row["table_name"], + column_name=row["column_name"], + referenced_table=row["referenced_table_name"], + referenced_column=row["referenced_column_name"], + constraint_name=row["constraint_name"], + schema=row["schema"], + referenced_schema=row["referenced_schema"], + ) + for row in result.get_data() + ] + + def get_indexes( + self, driver: "SyncDriverAdapterBase", table: str, schema: "str | None" = None + ) -> "list[dict[str, Any]]": + """Get index information for a table.""" + oracle_driver = cast("OracleSyncDriver", driver) + result = oracle_driver.execute(self._get_indexes_sql(table)) + + return [ + { + "name": row["name"], + "columns": row["columns"].split(",") if row["columns"] else [], + "unique": row["uniqueness"] == "UNIQUE", + "primary": False, + "table_name": row["table_name"], + } + for row in result.get_data() + ] + def list_available_features(self) -> "list[str]": """List available Oracle feature flags. @@ -544,6 +651,66 @@ async def get_columns( result = await oracle_driver.execute(self._get_columns_sql(table, schema)) return result.get_data() + async def get_tables_in_topological_order( + self, driver: "AsyncDriverAdapterBase", schema: "str | None" = None + ) -> "list[str]": + """Get tables sorted by topological dependency order.""" + oracle_driver = cast("OracleAsyncDriver", driver) + # Fetch dependency sorted tables + result = await oracle_driver.execute(self._get_topological_sort_sql()) + sorted_tables = [row["table_name"] for row in result.get_data()] + + # Fetch all tables + all_result = await oracle_driver.execute("SELECT table_name FROM user_tables") + all_tables = {row["table_name"] for row in all_result.get_data()} + + # Add disconnected tables (level 0 implied) at the beginning + sorted_set = set(sorted_tables) + disconnected = list(all_tables - sorted_set) + disconnected.sort() + + return disconnected + sorted_tables + + async def get_foreign_keys( + self, driver: "AsyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None + ) -> "list[ForeignKeyMetadata]": + """Get foreign key metadata.""" + from sqlspec.driver import ForeignKeyMetadata + + oracle_driver = cast("OracleAsyncDriver", driver) + result = await oracle_driver.execute(self._get_foreign_keys_sql(table)) + + return [ + ForeignKeyMetadata( + table_name=row["table_name"], + column_name=row["column_name"], + referenced_table=row["referenced_table_name"], + referenced_column=row["referenced_column_name"], + constraint_name=row["constraint_name"], + schema=row["schema"], + referenced_schema=row["referenced_schema"], + ) + for row in result.get_data() + ] + + async def get_indexes( + self, driver: "AsyncDriverAdapterBase", table: str, schema: "str | None" = None + ) -> "list[dict[str, Any]]": + """Get index information for a table.""" + oracle_driver = cast("OracleAsyncDriver", driver) + result = await oracle_driver.execute(self._get_indexes_sql(table)) + + return [ + { + "name": row["name"], + "columns": row["columns"].split(",") if row["columns"] else [], + "unique": row["uniqueness"] == "UNIQUE", + "primary": False, + "table_name": row["table_name"], + } + for row in result.get_data() + ] + def list_available_features(self) -> "list[str]": """List available Oracle feature flags. diff --git a/sqlspec/adapters/psycopg/data_dictionary.py b/sqlspec/adapters/psycopg/data_dictionary.py index a52c0eafb..f9c4fc81c 100644 --- a/sqlspec/adapters/psycopg/data_dictionary.py +++ b/sqlspec/adapters/psycopg/data_dictionary.py @@ -16,6 +16,7 @@ from collections.abc import Callable from sqlspec.adapters.psycopg.driver import PsycopgAsyncDriver, PsycopgSyncDriver + from sqlspec.driver import ForeignKeyMetadata logger = get_logger("adapters.psycopg.data_dictionary") @@ -163,6 +164,152 @@ def get_columns( result = psycopg_driver.execute(sql, (table, schema_name)) return result.data or [] + def get_tables_in_topological_order( + self, driver: "SyncDriverAdapterBase", schema: "str | None" = None + ) -> "list[str]": + """Get tables sorted by topological dependency order using Recursive CTE.""" + psycopg_driver = cast("PsycopgSyncDriver", driver) + schema_name = schema or "public" + + sql = """ + WITH RECURSIVE dependency_tree AS ( + SELECT + t.table_name::text, + 0 AS level, + ARRAY[t.table_name::text] AS path + FROM information_schema.tables t + WHERE t.table_type = 'BASE TABLE' + AND t.table_schema = %s + AND NOT EXISTS ( + SELECT 1 + FROM information_schema.table_constraints tc + JOIN information_schema.key_column_usage kcu + ON tc.constraint_name = kcu.constraint_name + AND tc.table_schema = kcu.table_schema + WHERE tc.constraint_type = 'FOREIGN KEY' + AND tc.table_name = t.table_name + AND tc.table_schema = t.table_schema + ) + + UNION ALL + + SELECT + tc.table_name::text, + dt.level + 1, + dt.path || tc.table_name::text + FROM information_schema.table_constraints tc + JOIN information_schema.key_column_usage kcu + ON tc.constraint_name = kcu.constraint_name + AND tc.table_schema = kcu.table_schema + JOIN information_schema.constraint_column_usage ccu + ON ccu.constraint_name = tc.constraint_name + AND ccu.table_schema = tc.table_schema + JOIN dependency_tree dt + ON ccu.table_name = dt.table_name + WHERE tc.constraint_type = 'FOREIGN KEY' + AND tc.table_schema = %s + AND ccu.table_schema = %s + AND NOT (tc.table_name = ANY(dt.path)) + ) + SELECT DISTINCT table_name, level + FROM dependency_tree + ORDER BY level, table_name; + """ + result = psycopg_driver.execute(sql, (schema_name, schema_name, schema_name)) + return [row["table_name"] for row in result.data] + + def get_foreign_keys( + self, driver: "SyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None + ) -> "list[ForeignKeyMetadata]": + """Get foreign key metadata.""" + from sqlspec.driver import ForeignKeyMetadata + + psycopg_driver = cast("PsycopgSyncDriver", driver) + schema_name = schema or "public" + + sql = """ + SELECT + kcu.table_name, + kcu.column_name, + ccu.table_name AS referenced_table_name, + ccu.column_name AS referenced_column_name, + tc.constraint_name, + tc.table_schema, + ccu.table_schema AS referenced_table_schema + FROM + information_schema.table_constraints AS tc + JOIN information_schema.key_column_usage AS kcu + ON tc.constraint_name = kcu.constraint_name + AND tc.table_schema = kcu.table_schema + JOIN information_schema.constraint_column_usage AS ccu + ON ccu.constraint_name = tc.constraint_name + AND ccu.table_schema = tc.table_schema + WHERE tc.constraint_type = 'FOREIGN KEY' + AND (%s::text IS NULL OR tc.table_schema = %s) + AND (%s::text IS NULL OR tc.table_name = %s) + """ + + result = psycopg_driver.execute(sql, (schema_name, schema_name, table, table)) + + return [ + ForeignKeyMetadata( + table_name=row["table_name"], + column_name=row["column_name"], + referenced_table=row["referenced_table_name"], + referenced_column=row["referenced_column_name"], + constraint_name=row["constraint_name"], + schema=row["table_schema"], + referenced_schema=row["referenced_table_schema"], + ) + for row in result.data + ] + + def get_indexes( + self, driver: "SyncDriverAdapterBase", table: str, schema: "str | None" = None + ) -> "list[dict[str, Any]]": + """Get index information for a table.""" + psycopg_driver = cast("PsycopgSyncDriver", driver) + schema_name = schema or "public" + + sql = """ + SELECT + i.relname as index_name, + ix.indisunique as is_unique, + ix.indisprimary as is_primary, + array_agg(a.attname ORDER BY array_position(ix.indkey, a.attnum)) as columns + FROM + pg_class t, + pg_class i, + pg_index ix, + pg_attribute a, + pg_namespace n + WHERE + t.oid = ix.indrelid + AND i.oid = ix.indexrelid + AND a.attrelid = t.oid + AND a.attnum = ANY(ix.indkey) + AND t.relkind = 'r' + AND t.relnamespace = n.oid + AND n.nspname = %s + AND t.relname = %s + GROUP BY + i.relname, + ix.indisunique, + ix.indisprimary + """ + result = psycopg_driver.execute(sql, (schema_name, table)) + + return [ + { + "name": row["index_name"], + "columns": row["columns"], + "unique": row["is_unique"], + "primary": row["is_primary"], + "table_name": table, + } + for row in result.data + ] + def list_available_features(self) -> "list[str]": """List available PostgreSQL feature flags. @@ -323,6 +470,152 @@ async def get_columns( result = await psycopg_driver.execute(sql, (table, schema_name)) return result.data or [] + async def get_tables_in_topological_order( + self, driver: "AsyncDriverAdapterBase", schema: "str | None" = None + ) -> "list[str]": + """Get tables sorted by topological dependency order using Recursive CTE.""" + psycopg_driver = cast("PsycopgAsyncDriver", driver) + schema_name = schema or "public" + + sql = """ + WITH RECURSIVE dependency_tree AS ( + SELECT + t.table_name::text, + 0 AS level, + ARRAY[t.table_name::text] AS path + FROM information_schema.tables t + WHERE t.table_type = 'BASE TABLE' + AND t.table_schema = %s + AND NOT EXISTS ( + SELECT 1 + FROM information_schema.table_constraints tc + JOIN information_schema.key_column_usage kcu + ON tc.constraint_name = kcu.constraint_name + AND tc.table_schema = kcu.table_schema + WHERE tc.constraint_type = 'FOREIGN KEY' + AND tc.table_name = t.table_name + AND tc.table_schema = t.table_schema + ) + + UNION ALL + + SELECT + tc.table_name::text, + dt.level + 1, + dt.path || tc.table_name::text + FROM information_schema.table_constraints tc + JOIN information_schema.key_column_usage kcu + ON tc.constraint_name = kcu.constraint_name + AND tc.table_schema = kcu.table_schema + JOIN information_schema.constraint_column_usage ccu + ON ccu.constraint_name = tc.constraint_name + AND ccu.table_schema = tc.table_schema + JOIN dependency_tree dt + ON ccu.table_name = dt.table_name + WHERE tc.constraint_type = 'FOREIGN KEY' + AND tc.table_schema = %s + AND ccu.table_schema = %s + AND NOT (tc.table_name = ANY(dt.path)) + ) + SELECT DISTINCT table_name, level + FROM dependency_tree + ORDER BY level, table_name; + """ + result = await psycopg_driver.execute(sql, (schema_name, schema_name, schema_name)) + return [row["table_name"] for row in result.data] + + async def get_foreign_keys( + self, driver: "AsyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None + ) -> "list[ForeignKeyMetadata]": + """Get foreign key metadata.""" + from sqlspec.driver import ForeignKeyMetadata + + psycopg_driver = cast("PsycopgAsyncDriver", driver) + schema_name = schema or "public" + + sql = """ + SELECT + kcu.table_name, + kcu.column_name, + ccu.table_name AS referenced_table_name, + ccu.column_name AS referenced_column_name, + tc.constraint_name, + tc.table_schema, + ccu.table_schema AS referenced_table_schema + FROM + information_schema.table_constraints AS tc + JOIN information_schema.key_column_usage AS kcu + ON tc.constraint_name = kcu.constraint_name + AND tc.table_schema = kcu.table_schema + JOIN information_schema.constraint_column_usage AS ccu + ON ccu.constraint_name = tc.constraint_name + AND ccu.table_schema = tc.table_schema + WHERE tc.constraint_type = 'FOREIGN KEY' + AND (%s::text IS NULL OR tc.table_schema = %s) + AND (%s::text IS NULL OR tc.table_name = %s) + """ + + result = await psycopg_driver.execute(sql, (schema_name, schema_name, table, table)) + + return [ + ForeignKeyMetadata( + table_name=row["table_name"], + column_name=row["column_name"], + referenced_table=row["referenced_table_name"], + referenced_column=row["referenced_column_name"], + constraint_name=row["constraint_name"], + schema=row["table_schema"], + referenced_schema=row["referenced_table_schema"], + ) + for row in result.data + ] + + async def get_indexes( + self, driver: "AsyncDriverAdapterBase", table: str, schema: "str | None" = None + ) -> "list[dict[str, Any]]": + """Get index information for a table.""" + psycopg_driver = cast("PsycopgAsyncDriver", driver) + schema_name = schema or "public" + + sql = """ + SELECT + i.relname as index_name, + ix.indisunique as is_unique, + ix.indisprimary as is_primary, + array_agg(a.attname ORDER BY array_position(ix.indkey, a.attnum)) as columns + FROM + pg_class t, + pg_class i, + pg_index ix, + pg_attribute a, + pg_namespace n + WHERE + t.oid = ix.indrelid + AND i.oid = ix.indexrelid + AND a.attrelid = t.oid + AND a.attnum = ANY(ix.indkey) + AND t.relkind = 'r' + AND t.relnamespace = n.oid + AND n.nspname = %s + AND t.relname = %s + GROUP BY + i.relname, + ix.indisunique, + ix.indisprimary + """ + result = await psycopg_driver.execute(sql, (schema_name, table)) + + return [ + { + "name": row["index_name"], + "columns": row["columns"], + "unique": row["is_unique"], + "primary": row["is_primary"], + "table_name": table, + } + for row in result.data + ] + def list_available_features(self) -> "list[str]": """List available PostgreSQL feature flags. diff --git a/sqlspec/adapters/sqlite/data_dictionary.py b/sqlspec/adapters/sqlite/data_dictionary.py index 4377a8801..c543c9649 100644 --- a/sqlspec/adapters/sqlite/data_dictionary.py +++ b/sqlspec/adapters/sqlite/data_dictionary.py @@ -10,6 +10,7 @@ from collections.abc import Callable from sqlspec.adapters.sqlite.driver import SqliteDriver + from sqlspec.driver import ForeignKeyMetadata logger = get_logger("adapters.sqlite.data_dictionary") @@ -129,6 +130,122 @@ def get_columns( for row in result.data or [] ] + def get_tables_in_topological_order( + self, driver: "SyncDriverAdapterBase", schema: "str | None" = None + ) -> "list[str]": + """Get tables sorted by topological dependency order.""" + sqlite_driver = cast("SqliteDriver", driver) + + sql = """ + WITH RECURSIVE dependency_tree AS ( + SELECT + m.name as table_name, + 0 as level, + '/' || m.name || '/' as path + FROM sqlite_schema m + WHERE m.type = 'table' + AND m.name NOT LIKE 'sqlite_%' + AND NOT EXISTS ( + SELECT 1 FROM pragma_foreign_key_list(m.name) + ) + + UNION ALL + + SELECT + m.name as table_name, + dt.level + 1, + dt.path || m.name || '/' + FROM sqlite_schema m + JOIN pragma_foreign_key_list(m.name) fk + JOIN dependency_tree dt ON fk."table" = dt.table_name + WHERE m.type = 'table' + AND m.name NOT LIKE 'sqlite_%' + AND instr(dt.path, '/' || m.name || '/') = 0 + ) + SELECT DISTINCT table_name, level FROM dependency_tree ORDER BY level, table_name; + """ + try: + result = sqlite_driver.execute(sql) + return [row["table_name"] if isinstance(row, dict) else row[0] for row in result.data] + except Exception: + return self.sort_tables_topologically( + self.get_tables(driver, schema), self.get_foreign_keys(driver, schema=schema) + ) + + def get_foreign_keys( + self, driver: "SyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None + ) -> "list[ForeignKeyMetadata]": + """Get foreign key metadata.""" + from sqlspec.driver import ForeignKeyMetadata + + sqlite_driver = cast("SqliteDriver", driver) + + if table: + sql = f"SELECT '{table}' as table_name, fk.* FROM pragma_foreign_key_list('{table}') fk" + result = sqlite_driver.execute(sql) + else: + sql = """ + SELECT m.name as table_name, fk.* + FROM sqlite_schema m, pragma_foreign_key_list(m.name) fk + WHERE m.type = 'table' AND m.name NOT LIKE 'sqlite_%' + """ + result = sqlite_driver.execute(sql) + + fks = [] + for row in result.data: + if isinstance(row, (list, tuple)): + table_name = row[0] + ref_table = row[3] + col = row[4] + ref_col = row[5] + else: + table_name = row["table_name"] + ref_table = row["table"] + col = row["from"] + ref_col = row["to"] + + fks.append( + ForeignKeyMetadata( + table_name=table_name, + column_name=col, + referenced_table=ref_table, + referenced_column=ref_col, + constraint_name=None, + schema=None, + referenced_schema=None, + ) + ) + return fks + + def get_indexes( + self, driver: "SyncDriverAdapterBase", table: str, schema: "str | None" = None + ) -> "list[dict[str, Any]]": + """Get index information for a table.""" + sqlite_driver = cast("SqliteDriver", driver) + + index_list_res = sqlite_driver.execute(f"PRAGMA index_list('{table}')") + indexes = [] + + for idx_row in index_list_res.data: + if isinstance(idx_row, (list, tuple)): + idx_name = idx_row[1] + unique = bool(idx_row[2]) + else: + idx_name = idx_row["name"] + unique = bool(idx_row["unique"]) + + info_res = sqlite_driver.execute(f"PRAGMA index_info('{idx_name}')") + cols = [] + for col_row in info_res.data: + if isinstance(col_row, (list, tuple)): + cols.append(col_row[2]) + else: + cols.append(col_row["name"]) + + indexes.append({"name": idx_name, "columns": cols, "unique": unique, "primary": False, "table_name": table}) + + return indexes + def list_available_features(self) -> "list[str]": """List available SQLite feature flags. diff --git a/sqlspec/driver/__init__.py b/sqlspec/driver/__init__.py index 4bc2e18e0..fae69124f 100644 --- a/sqlspec/driver/__init__.py +++ b/sqlspec/driver/__init__.py @@ -3,8 +3,12 @@ from sqlspec.driver import mixins from sqlspec.driver._async import AsyncDataDictionaryBase, AsyncDriverAdapterBase from sqlspec.driver._common import ( + ColumnMetadata, CommonDriverAttributesMixin, + DataDictionaryMixin, ExecutionResult, + ForeignKeyMetadata, + IndexMetadata, StackExecutionObserver, VersionInfo, describe_stack_statement, @@ -14,9 +18,13 @@ __all__ = ( "AsyncDataDictionaryBase", "AsyncDriverAdapterBase", + "ColumnMetadata", "CommonDriverAttributesMixin", + "DataDictionaryMixin", "DriverAdapterProtocol", "ExecutionResult", + "ForeignKeyMetadata", + "IndexMetadata", "StackExecutionObserver", "SyncDataDictionaryBase", "SyncDriverAdapterBase", diff --git a/sqlspec/driver/_async.py b/sqlspec/driver/_async.py index 0d7497a91..4bd957de6 100644 --- a/sqlspec/driver/_async.py +++ b/sqlspec/driver/_async.py @@ -27,6 +27,7 @@ from sqlspec.builder import QueryBuilder from sqlspec.core import ArrowResult, SQLResult, StatementConfig, StatementFilter + from sqlspec.driver._common import ForeignKeyMetadata from sqlspec.typing import ArrowReturnFormat, SchemaT, StatementParameters @@ -747,6 +748,41 @@ async def get_indexes( _ = driver, table, schema return [] + async def get_foreign_keys( + self, driver: "AsyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None + ) -> "list[ForeignKeyMetadata]": + """Get foreign key metadata. + + Args: + driver: Async database driver instance + table: Optional table name filter + schema: Optional schema name filter + + Returns: + List of foreign key metadata + """ + _ = driver, table, schema + return [] + + async def get_tables_in_topological_order( + self, driver: "AsyncDriverAdapterBase", schema: "str | None" = None + ) -> "list[str]": + """Get tables sorted by topological dependency order. + + Default implementation fetches all tables and foreign keys, + then uses Python's topological sort. + + Args: + driver: Async database driver instance + schema: Optional schema name + + Returns: + List of table names sorted by dependency + """ + tables = await self.get_tables(driver, schema) + foreign_keys = await self.get_foreign_keys(driver, schema=schema) + return self.sort_tables_topologically(tables, foreign_keys) + def list_available_features(self) -> "list[str]": """List all features that can be checked via get_feature_flag. diff --git a/sqlspec/driver/_common.py b/sqlspec/driver/_common.py index 7d686d0a3..661102dea 100644 --- a/sqlspec/driver/_common.py +++ b/sqlspec/driver/_common.py @@ -1,5 +1,6 @@ """Common driver attributes and utilities.""" +import graphlib import hashlib import logging import re @@ -42,9 +43,12 @@ "EXEC_CURSOR_RESULT", "EXEC_ROWCOUNT_OVERRIDE", "EXEC_SPECIAL_DATA", + "ColumnMetadata", "CommonDriverAttributesMixin", "DataDictionaryMixin", "ExecutionResult", + "ForeignKeyMetadata", + "IndexMetadata", "ScriptExecutionResult", "StackExecutionObserver", "VersionInfo", @@ -62,6 +66,163 @@ VERSION_GROUPS_MIN_FOR_PATCH = 2 +class ForeignKeyMetadata: + """Metadata for a foreign key constraint.""" + + __slots__ = ( + "column_name", + "constraint_name", + "referenced_column", + "referenced_schema", + "referenced_table", + "schema", + "table_name", + ) + + def __init__( + self, + table_name: str, + column_name: str, + referenced_table: str, + referenced_column: str, + constraint_name: str | None = None, + schema: str | None = None, + referenced_schema: str | None = None, + ) -> None: + self.table_name = table_name + self.column_name = column_name + self.referenced_table = referenced_table + self.referenced_column = referenced_column + self.constraint_name = constraint_name + self.schema = schema + self.referenced_schema = referenced_schema + + def __repr__(self) -> str: + return ( + f"ForeignKeyMetadata(table_name={self.table_name!r}, column_name={self.column_name!r}, " + f"referenced_table={self.referenced_table!r}, referenced_column={self.referenced_column!r}, " + f"constraint_name={self.constraint_name!r}, schema={self.schema!r}, referenced_schema={self.referenced_schema!r})" + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ForeignKeyMetadata): + return NotImplemented + return ( + self.table_name == other.table_name + and self.column_name == other.column_name + and self.referenced_table == other.referenced_table + and self.referenced_column == other.referenced_column + and self.constraint_name == other.constraint_name + and self.schema == other.schema + and self.referenced_schema == other.referenced_schema + ) + + def __hash__(self) -> int: + return hash(( + self.table_name, + self.column_name, + self.referenced_table, + self.referenced_column, + self.constraint_name, + self.schema, + self.referenced_schema, + )) + + +class ColumnMetadata: + """Metadata for a database column.""" + + __slots__ = ("data_type", "default_value", "max_length", "name", "nullable", "precision", "primary_key", "scale") + + def __init__( + self, + name: str, + data_type: str, + nullable: bool, + default_value: str | None = None, + primary_key: bool = False, + max_length: int | None = None, + precision: int | None = None, + scale: int | None = None, + ) -> None: + self.name = name + self.data_type = data_type + self.nullable = nullable + self.default_value = default_value + self.primary_key = primary_key + self.max_length = max_length + self.precision = precision + self.scale = scale + + def __repr__(self) -> str: + return ( + f"ColumnMetadata(name={self.name!r}, data_type={self.data_type!r}, nullable={self.nullable!r}, " + f"default_value={self.default_value!r}, primary_key={self.primary_key!r}, max_length={self.max_length!r}, " + f"precision={self.precision!r}, scale={self.scale!r})" + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ColumnMetadata): + return NotImplemented + return ( + self.name == other.name + and self.data_type == other.data_type + and self.nullable == other.nullable + and self.default_value == other.default_value + and self.primary_key == other.primary_key + and self.max_length == other.max_length + and self.precision == other.precision + and self.scale == other.scale + ) + + def __hash__(self) -> int: + return hash(( + self.name, + self.data_type, + self.nullable, + self.default_value, + self.primary_key, + self.max_length, + self.precision, + self.scale, + )) + + +class IndexMetadata: + """Metadata for a database index.""" + + __slots__ = ("columns", "name", "primary", "table_name", "unique") + + def __init__( + self, name: str, table_name: str, columns: list[str], unique: bool = False, primary: bool = False + ) -> None: + self.name = name + self.table_name = table_name + self.columns = columns + self.unique = unique + self.primary = primary + + def __repr__(self) -> str: + return ( + f"IndexMetadata(name={self.name!r}, table_name={self.table_name!r}, columns={self.columns!r}, " + f"unique={self.unique!r}, primary={self.primary!r})" + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, IndexMetadata): + return NotImplemented + return ( + self.name == other.name + and self.table_name == other.table_name + and self.columns == other.columns + and self.unique == other.unique + and self.primary == other.primary + ) + + def __hash__(self) -> int: + return hash((self.name, self.table_name, tuple(self.columns), self.unique, self.primary)) + + def make_cache_key_hashable(obj: Any) -> Any: """Recursively convert unhashable types to hashable ones for cache keys. @@ -374,6 +535,32 @@ def get_default_features(self) -> "list[str]": """ return ["supports_transactions", "supports_prepared_statements"] + def sort_tables_topologically(self, tables: "list[str]", foreign_keys: "list[ForeignKeyMetadata]") -> "list[str]": + """Sort tables topologically based on foreign key dependencies using Python. + + Args: + tables: List of table names. + foreign_keys: List of foreign key metadata. + + Returns: + List of table names in topological order (dependencies first). + + Raises: + CycleError: If a dependency cycle is detected. + """ + sorter: graphlib.TopologicalSorter[str] = graphlib.TopologicalSorter() + for table in tables: + sorter.add(table) + + for fk in foreign_keys: + # If self-referencing, ignore for sorting purposes to avoid simple cycles + if fk.table_name == fk.referenced_table: + continue + # table_name depends on referenced_table + sorter.add(fk.table_name, fk.referenced_table) + + return list(sorter.static_order()) + class ScriptExecutionResult(NamedTuple): """Result from script execution with statement count information.""" diff --git a/sqlspec/driver/_sync.py b/sqlspec/driver/_sync.py index d748cbcbb..f55110e63 100644 --- a/sqlspec/driver/_sync.py +++ b/sqlspec/driver/_sync.py @@ -27,6 +27,7 @@ from sqlspec.builder import QueryBuilder from sqlspec.core import ArrowResult, SQLResult, Statement, StatementConfig, StatementFilter + from sqlspec.driver._common import ForeignKeyMetadata from sqlspec.typing import ArrowReturnFormat, SchemaT, StatementParameters _LOGGER_NAME: Final[str] = "sqlspec" @@ -747,6 +748,41 @@ def get_indexes( _ = driver, table, schema return [] + def get_foreign_keys( + self, driver: "SyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None + ) -> "list[ForeignKeyMetadata]": + """Get foreign key metadata. + + Args: + driver: Sync database driver instance + table: Optional table name filter + schema: Optional schema name filter + + Returns: + List of foreign key metadata + """ + _ = driver, table, schema + return [] + + def get_tables_in_topological_order( + self, driver: "SyncDriverAdapterBase", schema: "str | None" = None + ) -> "list[str]": + """Get tables sorted by topological dependency order. + + Default implementation fetches all tables and foreign keys, + then uses Python's topological sort. + + Args: + driver: Sync database driver instance + schema: Optional schema name + + Returns: + List of table names sorted by dependency + """ + tables = self.get_tables(driver, schema) + foreign_keys = self.get_foreign_keys(driver, schema=schema) + return self.sort_tables_topologically(tables, foreign_keys) + def list_available_features(self) -> "list[str]": """List all features that can be checked via get_feature_flag. diff --git a/tests/integration/test_adapters/test_aiosqlite/test_data_dictionary.py b/tests/integration/test_adapters/test_aiosqlite/test_data_dictionary.py new file mode 100644 index 000000000..9dd808546 --- /dev/null +++ b/tests/integration/test_adapters/test_aiosqlite/test_data_dictionary.py @@ -0,0 +1,73 @@ +"""Integration tests for AioSQLite data dictionary.""" + +from typing import TYPE_CHECKING + +import pytest + +if TYPE_CHECKING: + from sqlspec.adapters.aiosqlite.driver import AiosqliteDriver + +pytestmark = pytest.mark.xdist_group("aiosqlite") + + +@pytest.mark.aiosqlite +async def test_aiosqlite_data_dictionary_topology_and_fks(aiosqlite_session: "AiosqliteDriver") -> None: + """Test topological sort and FK metadata.""" + aiosqlite_driver = aiosqlite_session + import uuid + + unique_suffix = uuid.uuid4().hex[:8] + users_table = f"dd_users_{unique_suffix}" + orders_table = f"dd_orders_{unique_suffix}" + items_table = f"dd_items_{unique_suffix}" + + await aiosqlite_driver.execute_script(f""" + CREATE TABLE {users_table} ( + id INTEGER PRIMARY KEY, + name VARCHAR(50) + ); + CREATE TABLE {orders_table} ( + id INTEGER PRIMARY KEY, + user_id INTEGER REFERENCES {users_table}(id), + amount INTEGER + ); + CREATE TABLE {items_table} ( + id INTEGER PRIMARY KEY, + order_id INTEGER REFERENCES {orders_table}(id), + name VARCHAR(50) + ); + """) + + try: + # Test 1: Topological Sort + sorted_tables = await aiosqlite_driver.data_dictionary.get_tables_in_topological_order(aiosqlite_driver) + + test_tables = [t for t in sorted_tables if t in (users_table, orders_table, items_table)] + assert len(test_tables) == 3 + + idx_users = test_tables.index(users_table) + idx_orders = test_tables.index(orders_table) + idx_items = test_tables.index(items_table) + + assert idx_users < idx_orders + assert idx_orders < idx_items + + # Test 2: Foreign Keys + fks = await aiosqlite_driver.data_dictionary.get_foreign_keys(aiosqlite_driver, table=orders_table) + assert len(fks) >= 1 + my_fk = next((fk for fk in fks if fk.referenced_table == users_table), None) + assert my_fk is not None + assert my_fk.column_name == "user_id" + + # Test 3: Indexes + await aiosqlite_driver.execute(f"CREATE INDEX idx_{unique_suffix} ON {users_table}(name)") + indexes = await aiosqlite_driver.data_dictionary.get_indexes(aiosqlite_driver, table=users_table) + assert len(indexes) >= 1 + assert any(idx["name"] == f"idx_{unique_suffix}" for idx in indexes) + + finally: + await aiosqlite_driver.execute_script(f""" + DROP TABLE IF EXISTS {items_table}; + DROP TABLE IF EXISTS {orders_table}; + DROP TABLE IF EXISTS {users_table}; + """) diff --git a/tests/integration/test_adapters/test_asyncpg/test_data_dictionary.py b/tests/integration/test_adapters/test_asyncpg/test_data_dictionary.py index 68126324f..c2bc5578f 100644 --- a/tests/integration/test_adapters/test_asyncpg/test_data_dictionary.py +++ b/tests/integration/test_adapters/test_asyncpg/test_data_dictionary.py @@ -105,3 +105,63 @@ async def test_asyncpg_data_dictionary_available_features(asyncpg_async_driver: for feature in expected_features: assert feature in features + + +@pytest.mark.asyncpg +async def test_asyncpg_data_dictionary_topology_and_fks(asyncpg_async_driver: "AsyncpgDriver") -> None: + """Test topological sort and FK metadata.""" + import uuid + + unique_suffix = uuid.uuid4().hex[:8] + users_table = f"dd_users_{unique_suffix}" + orders_table = f"dd_orders_{unique_suffix}" + items_table = f"dd_items_{unique_suffix}" + + await asyncpg_async_driver.execute_script(f""" + CREATE TABLE {users_table} ( + id SERIAL PRIMARY KEY, + name VARCHAR(50) + ); + CREATE TABLE {orders_table} ( + id SERIAL PRIMARY KEY, + user_id INTEGER REFERENCES {users_table}(id), + amount INTEGER + ); + CREATE TABLE {items_table} ( + id SERIAL PRIMARY KEY, + order_id INTEGER REFERENCES {orders_table}(id), + name VARCHAR(50) + ); + """) + + try: + # Test 1: Topological Sort + sorted_tables = await asyncpg_async_driver.data_dictionary.get_tables_in_topological_order(asyncpg_async_driver) + + test_tables = [t for t in sorted_tables if t in (users_table, orders_table, items_table)] + assert len(test_tables) == 3 + + idx_users = test_tables.index(users_table) + idx_orders = test_tables.index(orders_table) + idx_items = test_tables.index(items_table) + + assert idx_users < idx_orders + assert idx_orders < idx_items + + # Test 2: Foreign Keys + fks = await asyncpg_async_driver.data_dictionary.get_foreign_keys(asyncpg_async_driver, table=orders_table) + assert len(fks) >= 1 + my_fk = next((fk for fk in fks if fk.referenced_table == users_table), None) + assert my_fk is not None + assert my_fk.column_name == "user_id" + + # Test 3: Indexes + indexes = await asyncpg_async_driver.data_dictionary.get_indexes(asyncpg_async_driver, table=users_table) + assert len(indexes) >= 1 # PK index + + finally: + await asyncpg_async_driver.execute_script(f""" + DROP TABLE IF EXISTS {items_table} CASCADE; + DROP TABLE IF EXISTS {orders_table} CASCADE; + DROP TABLE IF EXISTS {users_table} CASCADE; + """) From 1b3ec4d314240b37fee9ec8f8b322a32cf34dd26 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Fri, 21 Nov 2025 04:11:13 +0000 Subject: [PATCH 2/6] feat: updated metadata class pattern and improve test isolation in examples --- AGENTS.md | 52 +++++++++++++ .../usage/usage_drivers_and_querying_10.py | 73 +++++++++++++------ .../usage/usage_drivers_and_querying_6.py | 51 +++++++------ docs/guides/architecture/data-dictionary.md | 23 +----- sqlspec/adapters/adbc/data_dictionary.py | 5 +- 5 files changed, 138 insertions(+), 66 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index d9f031851..7e35fcdae 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -451,6 +451,58 @@ if supports_where(obj): result = obj.where("condition") ``` +### Mypyc-Compatible Metadata Class Pattern + +When defining data-holding classes intended for core modules (`sqlspec/core/`, `sqlspec/driver/`) that will be compiled with MyPyC, use regular classes with `__slots__` and explicitly implement `__init__`, `__repr__`, `__eq__`, and `__hash__`. This approach ensures optimal performance and MyPyC compatibility, as `dataclasses` are not directly supported by MyPyC for compilation. + +**Key Principles:** + +- **`__slots__`**: Reduces memory footprint and speeds up attribute access. +- **Explicit `__init__`**: Defines the constructor for the class. +- **Explicit `__repr__`**: Provides a clear string representation for debugging. +- **Explicit `__eq__`**: Enables correct equality comparisons. +- **Explicit `__hash__`**: Makes instances hashable, allowing them to be used in sets or as dictionary keys. The hash implementation should be based on all fields that define the object's identity. + +**Example Implementation:** + +```python +class MyMetadata: + __slots__ = ("field1", "field2", "optional_field") + + def __init__(self, field1: str, field2: int, optional_field: str | None = None) -> None: + self.field1 = field1 + self.field2 = field2 + self.optional_field = optional_field + + def __repr__(self) -> str: + return f"MyMetadata(field1={self.field1!r}, field2={self.field2!r}, optional_field={self.optional_field!r})" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, MyMetadata): + return NotImplemented + return ( + self.field1 == other.field1 + and self.field2 == other.field2 + and self.optional_field == other.optional_field + ) + + def __hash__(self) -> int: + return hash((self.field1, self.field2, self.optional_field)) +``` + +**When to Use:** + +- For all new data-holding classes in performance-critical paths (e.g., `sqlspec/driver/_common.py`). +- When MyPyC compilation is enabled for the module containing the class. + +**Anti-Patterns to Avoid:** + +- Using `@dataclass` decorators for classes intended for MyPyC compilation. +- Omitting `__slots__` when defining performance-critical data structures. +- Relying on default `__eq__` or `__hash__` behavior for complex objects, especially for equality comparisons in collections. + +--- + ### Performance Patterns (MANDATORY) **PERF401 - List Operations**: diff --git a/docs/examples/usage/usage_drivers_and_querying_10.py b/docs/examples/usage/usage_drivers_and_querying_10.py index 84e06aef5..e05b29f6b 100644 --- a/docs/examples/usage/usage_drivers_and_querying_10.py +++ b/docs/examples/usage/usage_drivers_and_querying_10.py @@ -1,6 +1,7 @@ # Test module converted from docs example - code-block 10 """Minimal smoke test for drivers_and_querying example 10.""" +import tempfile from pathlib import Path import pytest @@ -15,27 +16,53 @@ def test_example_10_duckdb_config(tmp_path: Path) -> None: from sqlspec import SQLSpec from sqlspec.adapters.duckdb import DuckDBConfig - spec = SQLSpec() - # In-memory - config = DuckDBConfig() - - # Persistent - database_file = tmp_path / "analytics.duckdb" - config = DuckDBConfig(pool_config={"database": database_file.name, "read_only": False}) - - with spec.provide_session(config) as session: - # Create table from Parquet - session.execute(f""" - CREATE TABLE if not exists users AS - SELECT * FROM read_parquet('{Path(__file__).parent.parent / "queries/users.parquet"}') - """) - - # Analytical query - session.execute(""" - SELECT date_trunc('day', created_at) as day, - count(*) as user_count - FROM users - GROUP BY day - ORDER BY day - """) + # Use a temporary directory for the DuckDB database for test isolation + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "analytics.duckdb" + + spec = SQLSpec() + # In-memory + in_memory_config = DuckDBConfig() + + # Persistent (using the temporary file in the temporary directory) + persistent_config = DuckDBConfig(pool_config={"database": str(db_path)}) + + try: + # Test with in-memory config + with spec.provide_session(in_memory_config) as session: + # Create table from Parquet + session.execute(f""" + CREATE TABLE if not exists users AS + SELECT * FROM read_parquet('{Path(__file__).parent.parent / "queries/users.parquet"}') + """) + + # Analytical query + session.execute(""" + SELECT date_trunc('day', created_at) as day, + count(*) as user_count + FROM users + GROUP BY day + ORDER BY day + """) + + # Test with persistent config + with spec.provide_session(persistent_config) as session: + # Create table from Parquet + session.execute(f""" + CREATE TABLE if not exists users AS + SELECT * FROM read_parquet('{Path(__file__).parent.parent / "queries/users.parquet"}') + """) + + # Analytical query + session.execute(""" + SELECT date_trunc('day', created_at) as day, + count(*) as user_count + FROM users + GROUP BY day + ORDER BY day + """) + finally: + # Close the pool for the persistent config + persistent_config.close_pool() + # The TemporaryDirectory context manager handles directory cleanup automatically # end-example diff --git a/docs/examples/usage/usage_drivers_and_querying_6.py b/docs/examples/usage/usage_drivers_and_querying_6.py index 18ebaa641..94251eb21 100644 --- a/docs/examples/usage/usage_drivers_and_querying_6.py +++ b/docs/examples/usage/usage_drivers_and_querying_6.py @@ -1,6 +1,7 @@ # Test module converted from docs example - code-block 6 """Minimal smoke test for drivers_and_querying example 6.""" +import tempfile from pathlib import Path from sqlspec import SQLSpec @@ -12,24 +13,32 @@ def test_example_6_sqlite_config(tmp_path: Path) -> None: # start-example from sqlspec.adapters.sqlite import SqliteConfig - spec = SQLSpec() - - database_file = tmp_path / "myapp.db" - config = SqliteConfig(pool_config={"database": database_file.name, "timeout": 5.0, "check_same_thread": False}) - - with spec.provide_session(config) as session: - # Create table - session.execute(""" - CREATE TABLE IF NOT EXISTS usage6_users ( - id INTEGER PRIMARY KEY, - name TEXT NOT NULL - ) - """) - - # Insert with parameters - session.execute("INSERT INTO usage6_users (name) VALUES (?)", "Alice") - - # Query - result = session.execute("SELECT * FROM usage6_users") - result.all() - # end-example + # Use a temporary file for the SQLite database for test isolation + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp_db_file: + db_path = tmp_db_file.name + + spec = SQLSpec() + + config = SqliteConfig(pool_config={"database": db_path, "timeout": 5.0, "check_same_thread": False}) + + try: + with spec.provide_session(config) as session: + # Create table + session.execute(""" + CREATE TABLE IF NOT EXISTS usage6_users ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL + ) + """) + + # Insert with parameters + session.execute("INSERT INTO usage6_users (name) VALUES (?)", "Alice") + + # Query + result = session.execute("SELECT * FROM usage6_users") + result.all() + finally: + # Clean up the temporary database file + config.close_pool() + Path(db_path).unlink() + # end-example diff --git a/docs/guides/architecture/data-dictionary.md b/docs/guides/architecture/data-dictionary.md index 728b9238f..706777424 100644 --- a/docs/guides/architecture/data-dictionary.md +++ b/docs/guides/architecture/data-dictionary.md @@ -43,7 +43,7 @@ This is essential for: async with config.provide_session() as session: # Get tables sorted parent -> child sorted_tables = await session.data_dictionary.get_tables_in_topological_order(session) - + print("Insertion Order:", sorted_tables) print("Deletion Order:", list(reversed(sorted_tables))) ``` @@ -56,14 +56,14 @@ async with config.provide_session() as session: ### Metadata Types -SQLSpec uses typed dataclasses for metadata results where possible. +SQLSpec uses regular classes with __slots__ for metadata results to ensure mypyc compatibility and memory efficiency. ```python from sqlspec.driver import ForeignKeyMetadata async with config.provide_session() as session: fks: list[ForeignKeyMetadata] = await session.data_dictionary.get_foreign_keys(session, "orders") - + for fk in fks: print(f"FK: {fk.column_name} -> {fk.referenced_table}.{fk.referenced_column}") ``` @@ -80,19 +80,4 @@ async with config.provide_session() as session: ## API Reference -### Data Dictionary Protocol - -The base interface shared by all adapters. - -```python -class DataDictionaryBase: - async def get_tables(self, driver, schema=None) -> list[str]: ... - - async def get_columns(self, driver, table, schema=None) -> list[dict]: ... - - async def get_indexes(self, driver, table, schema=None) -> list[dict]: ... - - async def get_foreign_keys(self, driver, table=None, schema=None) -> list[ForeignKeyMetadata]: ... - - async def get_tables_in_topological_order(self, driver, schema=None) -> list[str]: ... -``` +For a complete API reference of the Data Dictionary components, including `DataDictionaryMixin`, `AsyncDataDictionaryBase`, `SyncDataDictionaryBase`, and the metadata classes (`ForeignKeyMetadata`, `ColumnMetadata`, `IndexMetadata`), please refer to the :doc:`/reference/driver`. diff --git a/sqlspec/adapters/adbc/data_dictionary.py b/sqlspec/adapters/adbc/data_dictionary.py index e608b90e6..78b43ceaa 100644 --- a/sqlspec/adapters/adbc/data_dictionary.py +++ b/sqlspec/adapters/adbc/data_dictionary.py @@ -61,7 +61,6 @@ def get_foreign_keys( # They all support information_schema.key_column_usage roughly the same way # Postgres/DuckDB/MySQL query - where_clauses = [] params = [] if dialect == "bigquery": @@ -83,8 +82,8 @@ def get_foreign_keys( pk_kcu.table_schema AS referenced_table_schema FROM {kcu} kcu JOIN {rc} rc ON kcu.constraint_name = rc.constraint_name - JOIN {kcu} pk_kcu - ON rc.unique_constraint_name = pk_kcu.constraint_name + JOIN {kcu} pk_kcu + ON rc.unique_constraint_name = pk_kcu.constraint_name AND kcu.ordinal_position = pk_kcu.ordinal_position """ if table: From 22ab966dadcfcfc20b16c34c04dfe4c9d907a980 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Fri, 21 Nov 2025 18:01:02 +0000 Subject: [PATCH 3/6] refactor: unify database configuration handling in usage examples --- .../usage/usage_drivers_and_querying_10.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/examples/usage/usage_drivers_and_querying_10.py b/docs/examples/usage/usage_drivers_and_querying_10.py index e05b29f6b..138ef5f0d 100644 --- a/docs/examples/usage/usage_drivers_and_querying_10.py +++ b/docs/examples/usage/usage_drivers_and_querying_10.py @@ -1,7 +1,6 @@ # Test module converted from docs example - code-block 10 """Minimal smoke test for drivers_and_querying example 10.""" -import tempfile from pathlib import Path import pytest @@ -13,6 +12,8 @@ def test_example_10_duckdb_config(tmp_path: Path) -> None: # start-example + import tempfile + from sqlspec import SQLSpec from sqlspec.adapters.duckdb import DuckDBConfig @@ -22,14 +23,12 @@ def test_example_10_duckdb_config(tmp_path: Path) -> None: spec = SQLSpec() # In-memory - in_memory_config = DuckDBConfig() - - # Persistent (using the temporary file in the temporary directory) - persistent_config = DuckDBConfig(pool_config={"database": str(db_path)}) + in_memory_db = spec.add_config(DuckDBConfig()) + persistent_db = spec.add_config(DuckDBConfig(pool_config={"database": str(db_path)})) try: # Test with in-memory config - with spec.provide_session(in_memory_config) as session: + with spec.provide_session(in_memory_db) as session: # Create table from Parquet session.execute(f""" CREATE TABLE if not exists users AS @@ -46,7 +45,7 @@ def test_example_10_duckdb_config(tmp_path: Path) -> None: """) # Test with persistent config - with spec.provide_session(persistent_config) as session: + with spec.provide_session(persistent_db) as session: # Create table from Parquet session.execute(f""" CREATE TABLE if not exists users AS @@ -63,6 +62,7 @@ def test_example_10_duckdb_config(tmp_path: Path) -> None: """) finally: # Close the pool for the persistent config - persistent_config.close_pool() + spec.get_config(in_memory_db).close_pool() + spec.get_config(persistent_db).close_pool() # The TemporaryDirectory context manager handles directory cleanup automatically # end-example From 7fd6d1d5283ccedd29cfc9d3641f2d000ad84164 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Fri, 21 Nov 2025 18:10:02 +0000 Subject: [PATCH 4/6] fix: ensure database file path is correctly converted to string in usage examples --- docs/examples/usage/usage_drivers_and_querying_6.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/examples/usage/usage_drivers_and_querying_6.py b/docs/examples/usage/usage_drivers_and_querying_6.py index 94251eb21..419555b33 100644 --- a/docs/examples/usage/usage_drivers_and_querying_6.py +++ b/docs/examples/usage/usage_drivers_and_querying_6.py @@ -19,10 +19,12 @@ def test_example_6_sqlite_config(tmp_path: Path) -> None: spec = SQLSpec() - config = SqliteConfig(pool_config={"database": db_path, "timeout": 5.0, "check_same_thread": False}) + db = spec.add_config( + SqliteConfig(pool_config={"database": db_path, "timeout": 5.0, "check_same_thread": False}) + ) try: - with spec.provide_session(config) as session: + with spec.provide_session(db) as session: # Create table session.execute(""" CREATE TABLE IF NOT EXISTS usage6_users ( @@ -39,6 +41,6 @@ def test_example_6_sqlite_config(tmp_path: Path) -> None: result.all() finally: # Clean up the temporary database file - config.close_pool() + spec.get_config(db).close_pool() Path(db_path).unlink() # end-example From 5aa02f33c10937442b4a8f31a72badeb9b29da1d Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Fri, 21 Nov 2025 18:36:05 +0000 Subject: [PATCH 5/6] feat: refactor data dictionary imports to streamline foreign key metadata handling --- sqlspec/adapters/adbc/data_dictionary.py | 2 -- sqlspec/adapters/aiosqlite/data_dictionary.py | 4 +--- sqlspec/adapters/asyncmy/data_dictionary.py | 4 +--- sqlspec/adapters/asyncpg/data_dictionary.py | 4 +--- sqlspec/adapters/bigquery/data_dictionary.py | 4 +--- sqlspec/adapters/duckdb/data_dictionary.py | 4 +--- sqlspec/adapters/oracledb/data_dictionary.py | 4 +--- sqlspec/adapters/psycopg/data_dictionary.py | 4 +--- sqlspec/adapters/sqlite/data_dictionary.py | 4 +--- 9 files changed, 8 insertions(+), 26 deletions(-) diff --git a/sqlspec/adapters/adbc/data_dictionary.py b/sqlspec/adapters/adbc/data_dictionary.py index 78b43ceaa..dd7fb2b12 100644 --- a/sqlspec/adapters/adbc/data_dictionary.py +++ b/sqlspec/adapters/adbc/data_dictionary.py @@ -110,8 +110,6 @@ def get_foreign_keys( kcu = "information_schema.key_column_usage" if dialect == "postgres": - # Postgres joins with constraint_column_usage or referential_constraints - # Let's use the query we verified for asyncpg sql = """ SELECT kcu.table_name, diff --git a/sqlspec/adapters/aiosqlite/data_dictionary.py b/sqlspec/adapters/aiosqlite/data_dictionary.py index 98318a7e6..5ac5b5ee9 100644 --- a/sqlspec/adapters/aiosqlite/data_dictionary.py +++ b/sqlspec/adapters/aiosqlite/data_dictionary.py @@ -3,14 +3,13 @@ import re from typing import TYPE_CHECKING, Any, cast -from sqlspec.driver import AsyncDataDictionaryBase, AsyncDriverAdapterBase, VersionInfo +from sqlspec.driver import AsyncDataDictionaryBase, AsyncDriverAdapterBase, ForeignKeyMetadata, VersionInfo from sqlspec.utils.logging import get_logger if TYPE_CHECKING: from collections.abc import Callable from sqlspec.adapters.aiosqlite.driver import AiosqliteDriver - from sqlspec.driver import ForeignKeyMetadata logger = get_logger("adapters.aiosqlite.data_dictionary") @@ -176,7 +175,6 @@ async def get_foreign_keys( self, driver: "AsyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None ) -> "list[ForeignKeyMetadata]": """Get foreign key metadata.""" - from sqlspec.driver import ForeignKeyMetadata aiosqlite_driver = cast("AiosqliteDriver", driver) diff --git a/sqlspec/adapters/asyncmy/data_dictionary.py b/sqlspec/adapters/asyncmy/data_dictionary.py index 75bffcb8b..623681869 100644 --- a/sqlspec/adapters/asyncmy/data_dictionary.py +++ b/sqlspec/adapters/asyncmy/data_dictionary.py @@ -3,14 +3,13 @@ import re from typing import TYPE_CHECKING, Any, cast -from sqlspec.driver import AsyncDataDictionaryBase, AsyncDriverAdapterBase, VersionInfo +from sqlspec.driver import AsyncDataDictionaryBase, AsyncDriverAdapterBase, ForeignKeyMetadata, VersionInfo from sqlspec.utils.logging import get_logger if TYPE_CHECKING: from collections.abc import Callable from sqlspec.adapters.asyncmy.driver import AsyncmyDriver - from sqlspec.driver import ForeignKeyMetadata logger = get_logger("adapters.asyncmy.data_dictionary") @@ -160,7 +159,6 @@ async def get_foreign_keys( self, driver: "AsyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None ) -> "list[ForeignKeyMetadata]": """Get foreign key metadata.""" - from sqlspec.driver import ForeignKeyMetadata asyncmy_driver = cast("AsyncmyDriver", driver) diff --git a/sqlspec/adapters/asyncpg/data_dictionary.py b/sqlspec/adapters/asyncpg/data_dictionary.py index 3fbdebfe6..f3decbbdb 100644 --- a/sqlspec/adapters/asyncpg/data_dictionary.py +++ b/sqlspec/adapters/asyncpg/data_dictionary.py @@ -3,14 +3,13 @@ import re from typing import TYPE_CHECKING, Any, cast -from sqlspec.driver import AsyncDataDictionaryBase, AsyncDriverAdapterBase, VersionInfo +from sqlspec.driver import AsyncDataDictionaryBase, AsyncDriverAdapterBase, ForeignKeyMetadata, VersionInfo from sqlspec.utils.logging import get_logger if TYPE_CHECKING: from collections.abc import Callable from sqlspec.adapters.asyncpg.driver import AsyncpgDriver - from sqlspec.driver import ForeignKeyMetadata logger = get_logger("adapters.asyncpg.data_dictionary") @@ -217,7 +216,6 @@ async def get_foreign_keys( self, driver: "AsyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None ) -> "list[ForeignKeyMetadata]": """Get foreign key metadata.""" - from sqlspec.driver import ForeignKeyMetadata asyncpg_driver = cast("AsyncpgDriver", driver) schema_name = schema or "public" diff --git a/sqlspec/adapters/bigquery/data_dictionary.py b/sqlspec/adapters/bigquery/data_dictionary.py index 1f66d77a0..554bc939d 100644 --- a/sqlspec/adapters/bigquery/data_dictionary.py +++ b/sqlspec/adapters/bigquery/data_dictionary.py @@ -2,12 +2,11 @@ from typing import TYPE_CHECKING, Any, cast -from sqlspec.driver import SyncDataDictionaryBase, SyncDriverAdapterBase, VersionInfo +from sqlspec.driver import ForeignKeyMetadata, SyncDataDictionaryBase, SyncDriverAdapterBase, VersionInfo from sqlspec.utils.logging import get_logger if TYPE_CHECKING: from sqlspec.adapters.bigquery.driver import BigQueryDriver - from sqlspec.driver import ForeignKeyMetadata logger = get_logger("adapters.bigquery.data_dictionary") @@ -141,7 +140,6 @@ def get_foreign_keys( self, driver: "SyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None ) -> "list[ForeignKeyMetadata]": """Get foreign key metadata.""" - from sqlspec.driver import ForeignKeyMetadata bigquery_driver = cast("BigQueryDriver", driver) diff --git a/sqlspec/adapters/duckdb/data_dictionary.py b/sqlspec/adapters/duckdb/data_dictionary.py index efc6b22a1..75c68f20c 100644 --- a/sqlspec/adapters/duckdb/data_dictionary.py +++ b/sqlspec/adapters/duckdb/data_dictionary.py @@ -3,14 +3,13 @@ import re from typing import TYPE_CHECKING, Any, cast -from sqlspec.driver import SyncDataDictionaryBase, SyncDriverAdapterBase, VersionInfo +from sqlspec.driver import ForeignKeyMetadata, SyncDataDictionaryBase, SyncDriverAdapterBase, VersionInfo from sqlspec.utils.logging import get_logger if TYPE_CHECKING: from collections.abc import Callable from sqlspec.adapters.duckdb.driver import DuckDBDriver - from sqlspec.driver import ForeignKeyMetadata logger = get_logger("adapters.duckdb.data_dictionary") @@ -197,7 +196,6 @@ def get_foreign_keys( self, driver: "SyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None ) -> "list[ForeignKeyMetadata]": """Get foreign key metadata.""" - from sqlspec.driver import ForeignKeyMetadata duckdb_driver = cast("DuckDBDriver", driver) diff --git a/sqlspec/adapters/oracledb/data_dictionary.py b/sqlspec/adapters/oracledb/data_dictionary.py index 721b30973..64c0b8f9c 100644 --- a/sqlspec/adapters/oracledb/data_dictionary.py +++ b/sqlspec/adapters/oracledb/data_dictionary.py @@ -7,6 +7,7 @@ from sqlspec.driver import ( AsyncDataDictionaryBase, AsyncDriverAdapterBase, + ForeignKeyMetadata, SyncDataDictionaryBase, SyncDriverAdapterBase, VersionInfo, @@ -17,7 +18,6 @@ from collections.abc import Callable from sqlspec.adapters.oracledb.driver import OracleAsyncDriver, OracleSyncDriver - from sqlspec.driver import ForeignKeyMetadata logger = get_logger("adapters.oracledb.data_dictionary") @@ -453,7 +453,6 @@ def get_foreign_keys( self, driver: "SyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None ) -> "list[ForeignKeyMetadata]": """Get foreign key metadata.""" - from sqlspec.driver import ForeignKeyMetadata oracle_driver = cast("OracleSyncDriver", driver) result = oracle_driver.execute(self._get_foreign_keys_sql(table)) @@ -675,7 +674,6 @@ async def get_foreign_keys( self, driver: "AsyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None ) -> "list[ForeignKeyMetadata]": """Get foreign key metadata.""" - from sqlspec.driver import ForeignKeyMetadata oracle_driver = cast("OracleAsyncDriver", driver) result = await oracle_driver.execute(self._get_foreign_keys_sql(table)) diff --git a/sqlspec/adapters/psycopg/data_dictionary.py b/sqlspec/adapters/psycopg/data_dictionary.py index f9c4fc81c..522569336 100644 --- a/sqlspec/adapters/psycopg/data_dictionary.py +++ b/sqlspec/adapters/psycopg/data_dictionary.py @@ -6,6 +6,7 @@ from sqlspec.driver import ( AsyncDataDictionaryBase, AsyncDriverAdapterBase, + ForeignKeyMetadata, SyncDataDictionaryBase, SyncDriverAdapterBase, VersionInfo, @@ -16,7 +17,6 @@ from collections.abc import Callable from sqlspec.adapters.psycopg.driver import PsycopgAsyncDriver, PsycopgSyncDriver - from sqlspec.driver import ForeignKeyMetadata logger = get_logger("adapters.psycopg.data_dictionary") @@ -222,7 +222,6 @@ def get_foreign_keys( self, driver: "SyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None ) -> "list[ForeignKeyMetadata]": """Get foreign key metadata.""" - from sqlspec.driver import ForeignKeyMetadata psycopg_driver = cast("PsycopgSyncDriver", driver) schema_name = schema or "public" @@ -528,7 +527,6 @@ async def get_foreign_keys( self, driver: "AsyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None ) -> "list[ForeignKeyMetadata]": """Get foreign key metadata.""" - from sqlspec.driver import ForeignKeyMetadata psycopg_driver = cast("PsycopgAsyncDriver", driver) schema_name = schema or "public" diff --git a/sqlspec/adapters/sqlite/data_dictionary.py b/sqlspec/adapters/sqlite/data_dictionary.py index c543c9649..ca6062d37 100644 --- a/sqlspec/adapters/sqlite/data_dictionary.py +++ b/sqlspec/adapters/sqlite/data_dictionary.py @@ -3,14 +3,13 @@ import re from typing import TYPE_CHECKING, Any, cast -from sqlspec.driver import SyncDataDictionaryBase, SyncDriverAdapterBase, VersionInfo +from sqlspec.driver import ForeignKeyMetadata, SyncDataDictionaryBase, SyncDriverAdapterBase, VersionInfo from sqlspec.utils.logging import get_logger if TYPE_CHECKING: from collections.abc import Callable from sqlspec.adapters.sqlite.driver import SqliteDriver - from sqlspec.driver import ForeignKeyMetadata logger = get_logger("adapters.sqlite.data_dictionary") @@ -176,7 +175,6 @@ def get_foreign_keys( self, driver: "SyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None ) -> "list[ForeignKeyMetadata]": """Get foreign key metadata.""" - from sqlspec.driver import ForeignKeyMetadata sqlite_driver = cast("SqliteDriver", driver) From 017291745112040479604dfc97c050d2d973ec24 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Fri, 21 Nov 2025 19:08:37 +0000 Subject: [PATCH 6/6] feat: rename `get_tables_in_topological_order` to `get_tables` for consistency across database adapters --- docs/guides/architecture/data-dictionary.md | 4 +- sqlspec/adapters/aiosqlite/data_dictionary.py | 18 +- sqlspec/adapters/asyncmy/data_dictionary.py | 98 +++++---- sqlspec/adapters/asyncpg/data_dictionary.py | 95 +++++---- sqlspec/adapters/bigquery/data_dictionary.py | 48 ++++- sqlspec/adapters/duckdb/data_dictionary.py | 16 +- sqlspec/adapters/oracledb/data_dictionary.py | 30 +-- sqlspec/adapters/psycopg/data_dictionary.py | 186 +++++++++--------- sqlspec/adapters/sqlite/data_dictionary.py | 18 +- sqlspec/driver/_async.py | 19 -- sqlspec/driver/_sync.py | 19 -- .../test_aiosqlite/test_data_dictionary.py | 2 +- .../test_asyncpg/test_data_dictionary.py | 2 +- 13 files changed, 255 insertions(+), 300 deletions(-) diff --git a/docs/guides/architecture/data-dictionary.md b/docs/guides/architecture/data-dictionary.md index 706777424..7e0cddb1a 100644 --- a/docs/guides/architecture/data-dictionary.md +++ b/docs/guides/architecture/data-dictionary.md @@ -32,7 +32,7 @@ async with config.provide_session() as session: ### Topological Sort (Dependency Ordering) -One of the most powerful features is `get_tables_in_topological_order`. This returns table names sorted such that parent tables appear before child tables (tables with foreign keys to parents). +`get_tables` now returns table names sorted such that parent tables appear before child tables (tables with foreign keys to parents). This is essential for: @@ -42,7 +42,7 @@ This is essential for: ```python async with config.provide_session() as session: # Get tables sorted parent -> child - sorted_tables = await session.data_dictionary.get_tables_in_topological_order(session) + sorted_tables = await session.data_dictionary.get_tables(session) print("Insertion Order:", sorted_tables) print("Deletion Order:", list(reversed(sorted_tables))) diff --git a/sqlspec/adapters/aiosqlite/data_dictionary.py b/sqlspec/adapters/aiosqlite/data_dictionary.py index 5ac5b5ee9..f9f0cf732 100644 --- a/sqlspec/adapters/aiosqlite/data_dictionary.py +++ b/sqlspec/adapters/aiosqlite/data_dictionary.py @@ -129,13 +129,10 @@ async def get_columns( for row in result.data or [] ] - async def get_tables_in_topological_order( - self, driver: "AsyncDriverAdapterBase", schema: "str | None" = None - ) -> "list[str]": - """Get tables sorted by topological dependency order.""" + async def get_tables(self, driver: "AsyncDriverAdapterBase", schema: "str | None" = None) -> "list[str]": + """Get tables sorted by topological dependency order using SQLite catalog.""" aiosqlite_driver = cast("AiosqliteDriver", driver) - # Assuming modern SQLite with pragma table-valued functions sql = """ WITH RECURSIVE dependency_tree AS ( SELECT @@ -162,20 +159,15 @@ async def get_tables_in_topological_order( AND m.name NOT LIKE 'sqlite_%' AND instr(dt.path, '/' || m.name || '/') = 0 ) - SELECT DISTINCT table_name, level FROM dependency_tree ORDER BY level, table_name; + SELECT DISTINCT table_name FROM dependency_tree ORDER BY level, table_name; """ - try: - result = await aiosqlite_driver.execute(sql) - return [row["table_name"] if isinstance(row, dict) else row[0] for row in result.data] - except Exception: - # Fallback to Python sort if TVF not supported or other error - return await super().get_tables_in_topological_order(driver, schema) + result = await aiosqlite_driver.execute(sql) + return [row["table_name"] for row in result.get_data()] async def get_foreign_keys( self, driver: "AsyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None ) -> "list[ForeignKeyMetadata]": """Get foreign key metadata.""" - aiosqlite_driver = cast("AiosqliteDriver", driver) if table: diff --git a/sqlspec/adapters/asyncmy/data_dictionary.py b/sqlspec/adapters/asyncmy/data_dictionary.py index 623681869..86727ee1e 100644 --- a/sqlspec/adapters/asyncmy/data_dictionary.py +++ b/sqlspec/adapters/asyncmy/data_dictionary.py @@ -104,62 +104,60 @@ async def get_optimal_type(self, driver: AsyncDriverAdapterBase, type_category: } return type_map.get(type_category, "VARCHAR(255)") - async def get_tables_in_topological_order( - self, driver: "AsyncDriverAdapterBase", schema: "str | None" = None - ) -> "list[str]": - """Get tables sorted by topological dependency order.""" + async def get_tables(self, driver: "AsyncDriverAdapterBase", schema: "str | None" = None) -> "list[str]": + """Get tables sorted by topological dependency order using MySQL catalog. + + Requires MySQL 8.0.1+ for recursive CTE support. + """ version = await self.get_version(driver) - if version and version >= VersionInfo(8, 0, 1): - # Use Recursive CTE - asyncmy_driver = cast("AsyncmyDriver", driver) - schema_clause = f"'{schema}'" if schema else "DATABASE()" - - sql = f""" - WITH RECURSIVE dependency_tree AS ( - SELECT - table_name, - 0 AS level, - CAST(table_name AS CHAR(4000)) AS path - FROM information_schema.tables t - WHERE t.table_type = 'BASE TABLE' - AND t.table_schema = {schema_clause} - AND NOT EXISTS ( - SELECT 1 - FROM information_schema.key_column_usage kcu - WHERE kcu.table_name = t.table_name - AND kcu.table_schema = t.table_schema - AND kcu.referenced_table_name IS NOT NULL - ) - - UNION ALL - - SELECT - kcu.table_name, - dt.level + 1, - CONCAT(dt.path, ',', kcu.table_name) - FROM information_schema.key_column_usage kcu - JOIN dependency_tree dt ON kcu.referenced_table_name = dt.table_name - WHERE kcu.table_schema = {schema_clause} - AND kcu.referenced_table_name IS NOT NULL - AND NOT FIND_IN_SET(kcu.table_name, dt.path) - ) - SELECT DISTINCT table_name, level - FROM dependency_tree - ORDER BY level, table_name - """ - try: - result = await asyncmy_driver.execute(sql) - return [row["table_name"] for row in result.data] - except Exception as exc: - logger.warning("Failed to get tables in topological order via SQL: %s", exc) - - return await super().get_tables_in_topological_order(driver, schema) + asyncmy_driver = cast("AsyncmyDriver", driver) + + if not version or version < VersionInfo(8, 0, 1): + msg = "get_tables requires MySQL 8.0.1+ for dependency ordering" + raise RuntimeError(msg) + + schema_clause = f"'{schema}'" if schema else "DATABASE()" + + sql = f""" + WITH RECURSIVE dependency_tree AS ( + SELECT + table_name, + 0 AS level, + CAST(table_name AS CHAR(4000)) AS path + FROM information_schema.tables t + WHERE t.table_type = 'BASE TABLE' + AND t.table_schema = {schema_clause} + AND NOT EXISTS ( + SELECT 1 + FROM information_schema.key_column_usage kcu + WHERE kcu.table_name = t.table_name + AND kcu.table_schema = t.table_schema + AND kcu.referenced_table_name IS NOT NULL + ) + + UNION ALL + + SELECT + kcu.table_name, + dt.level + 1, + CONCAT(dt.path, ',', kcu.table_name) + FROM information_schema.key_column_usage kcu + JOIN dependency_tree dt ON kcu.referenced_table_name = dt.table_name + WHERE kcu.table_schema = {schema_clause} + AND kcu.referenced_table_name IS NOT NULL + AND NOT FIND_IN_SET(kcu.table_name, dt.path) + ) + SELECT DISTINCT table_name + FROM dependency_tree + ORDER BY level, table_name + """ + result = await asyncmy_driver.execute(sql) + return [row["table_name"] for row in result.get_data()] async def get_foreign_keys( self, driver: "AsyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None ) -> "list[ForeignKeyMetadata]": """Get foreign key metadata.""" - asyncmy_driver = cast("AsyncmyDriver", driver) where_clauses = ["referenced_table_name IS NOT NULL"] diff --git a/sqlspec/adapters/asyncpg/data_dictionary.py b/sqlspec/adapters/asyncpg/data_dictionary.py index f3decbbdb..aad7fefa9 100644 --- a/sqlspec/adapters/asyncpg/data_dictionary.py +++ b/sqlspec/adapters/asyncpg/data_dictionary.py @@ -114,53 +114,7 @@ async def get_optimal_type(self, driver: AsyncDriverAdapterBase, type_category: } return type_map.get(type_category, "TEXT") - async def get_columns( - self, driver: AsyncDriverAdapterBase, table: str, schema: "str | None" = None - ) -> "list[dict[str, Any]]": - """Get column information for a table using pg_catalog. - - Args: - driver: AsyncPG driver instance - table: Table name to query columns for - schema: Schema name (None for default 'public') - - Returns: - List of column metadata dictionaries with keys: - - column_name: Name of the column - - data_type: PostgreSQL data type - - is_nullable: Whether column allows NULL (YES/NO) - - column_default: Default value if any - - Notes: - Uses pg_catalog instead of information_schema to avoid potential - issues with PostgreSQL 'name' type in some drivers. - """ - asyncpg_driver = cast("AsyncpgDriver", driver) - - schema_name = schema or "public" - sql = """ - SELECT - a.attname::text AS column_name, - pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type, - CASE WHEN a.attnotnull THEN 'NO' ELSE 'YES' END AS is_nullable, - pg_catalog.pg_get_expr(d.adbin, d.adrelid)::text AS column_default - FROM pg_catalog.pg_attribute a - JOIN pg_catalog.pg_class c ON a.attrelid = c.oid - JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid - LEFT JOIN pg_catalog.pg_attrdef d ON a.attrelid = d.adrelid AND a.attnum = d.adnum - WHERE c.relname = $1 - AND n.nspname = $2 - AND a.attnum > 0 - AND NOT a.attisdropped - ORDER BY a.attnum - """ - - result = await asyncpg_driver.execute(sql, (table, schema_name)) - return result.data or [] - - async def get_tables_in_topological_order( - self, driver: "AsyncDriverAdapterBase", schema: "str | None" = None - ) -> "list[str]": + async def get_tables(self, driver: "AsyncDriverAdapterBase", schema: "str | None" = None) -> "list[str]": """Get tables sorted by topological dependency order using Recursive CTE.""" asyncpg_driver = cast("AsyncpgDriver", driver) schema_name = schema or "public" @@ -210,13 +164,56 @@ async def get_tables_in_topological_order( ORDER BY level, table_name; """ result = await asyncpg_driver.execute(sql, (schema_name,)) - return [row["table_name"] for row in result.data] + return [row["table_name"] for row in result.get_data()] + + async def get_columns( + self, driver: AsyncDriverAdapterBase, table: str, schema: "str | None" = None + ) -> "list[dict[str, Any]]": + """Get column information for a table using pg_catalog. + + Args: + driver: AsyncPG driver instance + table: Table name to query columns for + schema: Schema name (None for default 'public') + + Returns: + List of column metadata dictionaries with keys: + - column_name: Name of the column + - data_type: PostgreSQL data type + - is_nullable: Whether column allows NULL (YES/NO) + - column_default: Default value if any + + Notes: + Uses pg_catalog instead of information_schema to avoid potential + issues with PostgreSQL 'name' type in some drivers. + """ + asyncpg_driver = cast("AsyncpgDriver", driver) + + schema_name = schema or "public" + sql = """ + SELECT + a.attname::text AS column_name, + pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type, + CASE WHEN a.attnotnull THEN 'NO' ELSE 'YES' END AS is_nullable, + pg_catalog.pg_get_expr(d.adbin, d.adrelid)::text AS column_default + FROM pg_catalog.pg_attribute a + JOIN pg_catalog.pg_class c ON a.attrelid = c.oid + JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid + LEFT JOIN pg_catalog.pg_attrdef d ON a.attrelid = d.adrelid AND a.attnum = d.adnum + WHERE c.relname = $1 + AND n.nspname = $2 + AND a.attnum > 0 + AND NOT a.attisdropped + ORDER BY a.attnum + """ + + result = await asyncpg_driver.execute(sql, (table, schema_name)) + return result.data or [] async def get_foreign_keys( self, driver: "AsyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None ) -> "list[ForeignKeyMetadata]": """Get foreign key metadata.""" - asyncpg_driver = cast("AsyncpgDriver", driver) schema_name = schema or "public" diff --git a/sqlspec/adapters/bigquery/data_dictionary.py b/sqlspec/adapters/bigquery/data_dictionary.py index 554bc939d..7d9f41af2 100644 --- a/sqlspec/adapters/bigquery/data_dictionary.py +++ b/sqlspec/adapters/bigquery/data_dictionary.py @@ -126,21 +126,59 @@ def get_columns( return result.data or [] def get_tables(self, driver: "SyncDriverAdapterBase", schema: "str | None" = None) -> "list[str]": - """Get list of tables in schema.""" + """Get tables sorted by topological dependency order using BigQuery catalog.""" bigquery_driver = cast("BigQueryDriver", driver) + if schema: - sql = f"SELECT table_name FROM `{schema}.INFORMATION_SCHEMA.TABLES` WHERE table_type = 'BASE TABLE'" + tables_table = f"`{schema}.INFORMATION_SCHEMA.TABLES`" + kcu_table = f"`{schema}.INFORMATION_SCHEMA.KEY_COLUMN_USAGE`" + rc_table = f"`{schema}.INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS`" else: - sql = "SELECT table_name FROM INFORMATION_SCHEMA.TABLES WHERE table_type = 'BASE TABLE'" + tables_table = "INFORMATION_SCHEMA.TABLES" + kcu_table = "INFORMATION_SCHEMA.KEY_COLUMN_USAGE" + rc_table = "INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS" + + sql = f""" + WITH RECURSIVE dependency_tree AS ( + SELECT + t.table_name, + 0 AS level, + [t.table_name] AS path + FROM {tables_table} t + WHERE t.table_type = 'BASE TABLE' + AND NOT EXISTS ( + SELECT 1 + FROM {kcu_table} kcu + JOIN {rc_table} rc ON kcu.constraint_name = rc.constraint_name + WHERE kcu.table_name = t.table_name + ) + + UNION ALL + + SELECT + kcu.table_name, + dt.level + 1, + ARRAY_CONCAT(dt.path, [kcu.table_name]) + FROM {kcu_table} kcu + JOIN {rc_table} rc ON kcu.constraint_name = rc.constraint_name + JOIN {kcu_table} pk_kcu + ON rc.unique_constraint_name = pk_kcu.constraint_name + AND kcu.ordinal_position = pk_kcu.ordinal_position + JOIN dependency_tree dt ON pk_kcu.table_name = dt.table_name + WHERE kcu.table_name NOT IN UNNEST(dt.path) + ) + SELECT DISTINCT table_name + FROM dependency_tree + ORDER BY level, table_name + """ result = bigquery_driver.execute(sql) - return [row["table_name"] for row in result.data] + return [row["table_name"] for row in result.get_data()] def get_foreign_keys( self, driver: "SyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None ) -> "list[ForeignKeyMetadata]": """Get foreign key metadata.""" - bigquery_driver = cast("BigQueryDriver", driver) dataset = schema diff --git a/sqlspec/adapters/duckdb/data_dictionary.py b/sqlspec/adapters/duckdb/data_dictionary.py index 75c68f20c..a16da71e4 100644 --- a/sqlspec/adapters/duckdb/data_dictionary.py +++ b/sqlspec/adapters/duckdb/data_dictionary.py @@ -141,10 +141,8 @@ def get_columns( result = duckdb_driver.execute(sql) return result.data or [] - def get_tables_in_topological_order( - self, driver: "SyncDriverAdapterBase", schema: "str | None" = None - ) -> "list[str]": - """Get tables sorted by topological dependency order.""" + def get_tables(self, driver: "SyncDriverAdapterBase", schema: "str | None" = None) -> "list[str]": + """Get tables sorted by topological dependency order using DuckDB catalog.""" duckdb_driver = cast("DuckDBDriver", driver) schema_clause = f"'{schema}'" if schema else "current_schema()" @@ -184,19 +182,13 @@ def get_tables_in_topological_order( FROM dependency_tree ORDER BY level, table_name """ - try: - result = duckdb_driver.execute(sql) - return [row["table_name"] for row in result.get_data()] - except Exception: - return self.sort_tables_topologically( - self.get_tables(driver, schema), self.get_foreign_keys(driver, schema=schema) - ) + result = duckdb_driver.execute(sql) + return [row["table_name"] for row in result.get_data()] def get_foreign_keys( self, driver: "SyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None ) -> "list[ForeignKeyMetadata]": """Get foreign key metadata.""" - duckdb_driver = cast("DuckDBDriver", driver) where_clauses = [] diff --git a/sqlspec/adapters/oracledb/data_dictionary.py b/sqlspec/adapters/oracledb/data_dictionary.py index 64c0b8f9c..6d5b55b4e 100644 --- a/sqlspec/adapters/oracledb/data_dictionary.py +++ b/sqlspec/adapters/oracledb/data_dictionary.py @@ -174,7 +174,7 @@ def _get_indexes_sql(self, table: "str | None" = None) -> str: GROUP BY i.index_name, i.table_name, i.uniqueness """ - def _get_topological_sort_sql(self) -> str: + def _get_tables_sql(self) -> str: return """ SELECT table_name, MAX(LEVEL) as lvl FROM user_constraints @@ -429,20 +429,16 @@ def get_columns( result = oracle_driver.execute(self._get_columns_sql(table, schema)) return result.get_data() - def get_tables_in_topological_order( - self, driver: "SyncDriverAdapterBase", schema: "str | None" = None - ) -> "list[str]": - """Get tables sorted by topological dependency order.""" + def get_tables(self, driver: "SyncDriverAdapterBase", schema: "str | None" = None) -> "list[str]": + """Get tables sorted by topological dependency order using Oracle CONNECT BY.""" oracle_driver = cast("OracleSyncDriver", driver) - # Fetch dependency sorted tables - result = oracle_driver.execute(self._get_topological_sort_sql()) - sorted_tables = [row["table_name"] for row in result.get_data()] - # Fetch all tables + result = oracle_driver.execute(self._get_tables_sql()) + sorted_tables = [row["table_name"] for row in result] + all_result = oracle_driver.execute("SELECT table_name FROM user_tables") all_tables = {row["table_name"] for row in all_result.get_data()} - # Add disconnected tables (level 0 implied) at the beginning sorted_set = set(sorted_tables) disconnected = list(all_tables - sorted_set) disconnected.sort() @@ -453,7 +449,6 @@ def get_foreign_keys( self, driver: "SyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None ) -> "list[ForeignKeyMetadata]": """Get foreign key metadata.""" - oracle_driver = cast("OracleSyncDriver", driver) result = oracle_driver.execute(self._get_foreign_keys_sql(table)) @@ -650,20 +645,16 @@ async def get_columns( result = await oracle_driver.execute(self._get_columns_sql(table, schema)) return result.get_data() - async def get_tables_in_topological_order( - self, driver: "AsyncDriverAdapterBase", schema: "str | None" = None - ) -> "list[str]": - """Get tables sorted by topological dependency order.""" + async def get_tables(self, driver: "AsyncDriverAdapterBase", schema: "str | None" = None) -> "list[str]": + """Get tables sorted by topological dependency order using Oracle CONNECT BY.""" oracle_driver = cast("OracleAsyncDriver", driver) - # Fetch dependency sorted tables - result = await oracle_driver.execute(self._get_topological_sort_sql()) + + result = await oracle_driver.execute(self._get_tables_sql()) sorted_tables = [row["table_name"] for row in result.get_data()] - # Fetch all tables all_result = await oracle_driver.execute("SELECT table_name FROM user_tables") all_tables = {row["table_name"] for row in all_result.get_data()} - # Add disconnected tables (level 0 implied) at the beginning sorted_set = set(sorted_tables) disconnected = list(all_tables - sorted_set) disconnected.sort() @@ -674,7 +665,6 @@ async def get_foreign_keys( self, driver: "AsyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None ) -> "list[ForeignKeyMetadata]": """Get foreign key metadata.""" - oracle_driver = cast("OracleAsyncDriver", driver) result = await oracle_driver.execute(self._get_foreign_keys_sql(table)) diff --git a/sqlspec/adapters/psycopg/data_dictionary.py b/sqlspec/adapters/psycopg/data_dictionary.py index 522569336..2a1768b9b 100644 --- a/sqlspec/adapters/psycopg/data_dictionary.py +++ b/sqlspec/adapters/psycopg/data_dictionary.py @@ -120,53 +120,7 @@ def get_optimal_type(self, driver: SyncDriverAdapterBase, type_category: str) -> } return type_map.get(type_category, "TEXT") - def get_columns( - self, driver: SyncDriverAdapterBase, table: str, schema: "str | None" = None - ) -> "list[dict[str, Any]]": - """Get column information for a table using pg_catalog. - - Args: - driver: Psycopg sync driver instance - table: Table name to query columns for - schema: Schema name (None for default 'public') - - Returns: - List of column metadata dictionaries with keys: - - column_name: Name of the column - - data_type: PostgreSQL data type - - is_nullable: Whether column allows NULL (YES/NO) - - column_default: Default value if any - - Notes: - Uses pg_catalog instead of information_schema to avoid potential - issues with PostgreSQL 'name' type in some drivers. - """ - psycopg_driver = cast("PsycopgSyncDriver", driver) - - schema_name = schema or "public" - sql = """ - SELECT - a.attname::text AS column_name, - pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type, - CASE WHEN a.attnotnull THEN 'NO' ELSE 'YES' END AS is_nullable, - pg_catalog.pg_get_expr(d.adbin, d.adrelid)::text AS column_default - FROM pg_catalog.pg_attribute a - JOIN pg_catalog.pg_class c ON a.attrelid = c.oid - JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid - LEFT JOIN pg_catalog.pg_attrdef d ON a.attrelid = d.adrelid AND a.attnum = d.adnum - WHERE c.relname = %s - AND n.nspname = %s - AND a.attnum > 0 - AND NOT a.attisdropped - ORDER BY a.attnum - """ - - result = psycopg_driver.execute(sql, (table, schema_name)) - return result.data or [] - - def get_tables_in_topological_order( - self, driver: "SyncDriverAdapterBase", schema: "str | None" = None - ) -> "list[str]": + def get_tables(self, driver: "SyncDriverAdapterBase", schema: "str | None" = None) -> "list[str]": """Get tables sorted by topological dependency order using Recursive CTE.""" psycopg_driver = cast("PsycopgSyncDriver", driver) schema_name = schema or "public" @@ -218,11 +172,54 @@ def get_tables_in_topological_order( result = psycopg_driver.execute(sql, (schema_name, schema_name, schema_name)) return [row["table_name"] for row in result.data] + def get_columns( + self, driver: SyncDriverAdapterBase, table: str, schema: "str | None" = None + ) -> "list[dict[str, Any]]": + """Get column information for a table using pg_catalog. + + Args: + driver: Psycopg sync driver instance + table: Table name to query columns for + schema: Schema name (None for default 'public') + + Returns: + List of column metadata dictionaries with keys: + - column_name: Name of the column + - data_type: PostgreSQL data type + - is_nullable: Whether column allows NULL (YES/NO) + - column_default: Default value if any + + Notes: + Uses pg_catalog instead of information_schema to avoid potential + issues with PostgreSQL 'name' type in some drivers. + """ + psycopg_driver = cast("PsycopgSyncDriver", driver) + + schema_name = schema or "public" + sql = """ + SELECT + a.attname::text AS column_name, + pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type, + CASE WHEN a.attnotnull THEN 'NO' ELSE 'YES' END AS is_nullable, + pg_catalog.pg_get_expr(d.adbin, d.adrelid)::text AS column_default + FROM pg_catalog.pg_attribute a + JOIN pg_catalog.pg_class c ON a.attrelid = c.oid + JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid + LEFT JOIN pg_catalog.pg_attrdef d ON a.attrelid = d.adrelid AND a.attnum = d.adnum + WHERE c.relname = %s + AND n.nspname = %s + AND a.attnum > 0 + AND NOT a.attisdropped + ORDER BY a.attnum + """ + + result = psycopg_driver.execute(sql, (table, schema_name)) + return result.data or [] + def get_foreign_keys( self, driver: "SyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None ) -> "list[ForeignKeyMetadata]": """Get foreign key metadata.""" - psycopg_driver = cast("PsycopgSyncDriver", driver) schema_name = schema or "public" @@ -425,53 +422,7 @@ async def get_optimal_type(self, driver: AsyncDriverAdapterBase, type_category: } return type_map.get(type_category, "TEXT") - async def get_columns( - self, driver: AsyncDriverAdapterBase, table: str, schema: "str | None" = None - ) -> "list[dict[str, Any]]": - """Get column information for a table using pg_catalog. - - Args: - driver: Psycopg async driver instance - table: Table name to query columns for - schema: Schema name (None for default 'public') - - Returns: - List of column metadata dictionaries with keys: - - column_name: Name of the column - - data_type: PostgreSQL data type - - is_nullable: Whether column allows NULL (YES/NO) - - column_default: Default value if any - - Notes: - Uses pg_catalog instead of information_schema to avoid potential - issues with PostgreSQL 'name' type in some drivers. - """ - psycopg_driver = cast("PsycopgAsyncDriver", driver) - - schema_name = schema or "public" - sql = """ - SELECT - a.attname::text AS column_name, - pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type, - CASE WHEN a.attnotnull THEN 'NO' ELSE 'YES' END AS is_nullable, - pg_catalog.pg_get_expr(d.adbin, d.adrelid)::text AS column_default - FROM pg_catalog.pg_attribute a - JOIN pg_catalog.pg_class c ON a.attrelid = c.oid - JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid - LEFT JOIN pg_catalog.pg_attrdef d ON a.attrelid = d.adrelid AND a.attnum = d.adnum - WHERE c.relname = %s - AND n.nspname = %s - AND a.attnum > 0 - AND NOT a.attisdropped - ORDER BY a.attnum - """ - - result = await psycopg_driver.execute(sql, (table, schema_name)) - return result.data or [] - - async def get_tables_in_topological_order( - self, driver: "AsyncDriverAdapterBase", schema: "str | None" = None - ) -> "list[str]": + async def get_tables(self, driver: "AsyncDriverAdapterBase", schema: "str | None" = None) -> "list[str]": """Get tables sorted by topological dependency order using Recursive CTE.""" psycopg_driver = cast("PsycopgAsyncDriver", driver) schema_name = schema or "public" @@ -523,11 +474,54 @@ async def get_tables_in_topological_order( result = await psycopg_driver.execute(sql, (schema_name, schema_name, schema_name)) return [row["table_name"] for row in result.data] + async def get_columns( + self, driver: AsyncDriverAdapterBase, table: str, schema: "str | None" = None + ) -> "list[dict[str, Any]]": + """Get column information for a table using pg_catalog. + + Args: + driver: Psycopg async driver instance + table: Table name to query columns for + schema: Schema name (None for default 'public') + + Returns: + List of column metadata dictionaries with keys: + - column_name: Name of the column + - data_type: PostgreSQL data type + - is_nullable: Whether column allows NULL (YES/NO) + - column_default: Default value if any + + Notes: + Uses pg_catalog instead of information_schema to avoid potential + issues with PostgreSQL 'name' type in some drivers. + """ + psycopg_driver = cast("PsycopgAsyncDriver", driver) + + schema_name = schema or "public" + sql = """ + SELECT + a.attname::text AS column_name, + pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type, + CASE WHEN a.attnotnull THEN 'NO' ELSE 'YES' END AS is_nullable, + pg_catalog.pg_get_expr(d.adbin, d.adrelid)::text AS column_default + FROM pg_catalog.pg_attribute a + JOIN pg_catalog.pg_class c ON a.attrelid = c.oid + JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid + LEFT JOIN pg_catalog.pg_attrdef d ON a.attrelid = d.adrelid AND a.attnum = d.adnum + WHERE c.relname = %s + AND n.nspname = %s + AND a.attnum > 0 + AND NOT a.attisdropped + ORDER BY a.attnum + """ + + result = await psycopg_driver.execute(sql, (table, schema_name)) + return result.data or [] + async def get_foreign_keys( self, driver: "AsyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None ) -> "list[ForeignKeyMetadata]": """Get foreign key metadata.""" - psycopg_driver = cast("PsycopgAsyncDriver", driver) schema_name = schema or "public" diff --git a/sqlspec/adapters/sqlite/data_dictionary.py b/sqlspec/adapters/sqlite/data_dictionary.py index ca6062d37..ff0ab28b5 100644 --- a/sqlspec/adapters/sqlite/data_dictionary.py +++ b/sqlspec/adapters/sqlite/data_dictionary.py @@ -129,10 +129,8 @@ def get_columns( for row in result.data or [] ] - def get_tables_in_topological_order( - self, driver: "SyncDriverAdapterBase", schema: "str | None" = None - ) -> "list[str]": - """Get tables sorted by topological dependency order.""" + def get_tables(self, driver: "SyncDriverAdapterBase", schema: "str | None" = None) -> "list[str]": + """Get tables sorted by topological dependency order using SQLite catalog.""" sqlite_driver = cast("SqliteDriver", driver) sql = """ @@ -161,21 +159,15 @@ def get_tables_in_topological_order( AND m.name NOT LIKE 'sqlite_%' AND instr(dt.path, '/' || m.name || '/') = 0 ) - SELECT DISTINCT table_name, level FROM dependency_tree ORDER BY level, table_name; + SELECT DISTINCT table_name FROM dependency_tree ORDER BY level, table_name; """ - try: - result = sqlite_driver.execute(sql) - return [row["table_name"] if isinstance(row, dict) else row[0] for row in result.data] - except Exception: - return self.sort_tables_topologically( - self.get_tables(driver, schema), self.get_foreign_keys(driver, schema=schema) - ) + result = sqlite_driver.execute(sql) + return [row["table_name"] for row in result.get_data()] def get_foreign_keys( self, driver: "SyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None ) -> "list[ForeignKeyMetadata]": """Get foreign key metadata.""" - sqlite_driver = cast("SqliteDriver", driver) if table: diff --git a/sqlspec/driver/_async.py b/sqlspec/driver/_async.py index 4bd957de6..399b2e488 100644 --- a/sqlspec/driver/_async.py +++ b/sqlspec/driver/_async.py @@ -764,25 +764,6 @@ async def get_foreign_keys( _ = driver, table, schema return [] - async def get_tables_in_topological_order( - self, driver: "AsyncDriverAdapterBase", schema: "str | None" = None - ) -> "list[str]": - """Get tables sorted by topological dependency order. - - Default implementation fetches all tables and foreign keys, - then uses Python's topological sort. - - Args: - driver: Async database driver instance - schema: Optional schema name - - Returns: - List of table names sorted by dependency - """ - tables = await self.get_tables(driver, schema) - foreign_keys = await self.get_foreign_keys(driver, schema=schema) - return self.sort_tables_topologically(tables, foreign_keys) - def list_available_features(self) -> "list[str]": """List all features that can be checked via get_feature_flag. diff --git a/sqlspec/driver/_sync.py b/sqlspec/driver/_sync.py index f55110e63..146b7b864 100644 --- a/sqlspec/driver/_sync.py +++ b/sqlspec/driver/_sync.py @@ -764,25 +764,6 @@ def get_foreign_keys( _ = driver, table, schema return [] - def get_tables_in_topological_order( - self, driver: "SyncDriverAdapterBase", schema: "str | None" = None - ) -> "list[str]": - """Get tables sorted by topological dependency order. - - Default implementation fetches all tables and foreign keys, - then uses Python's topological sort. - - Args: - driver: Sync database driver instance - schema: Optional schema name - - Returns: - List of table names sorted by dependency - """ - tables = self.get_tables(driver, schema) - foreign_keys = self.get_foreign_keys(driver, schema=schema) - return self.sort_tables_topologically(tables, foreign_keys) - def list_available_features(self) -> "list[str]": """List all features that can be checked via get_feature_flag. diff --git a/tests/integration/test_adapters/test_aiosqlite/test_data_dictionary.py b/tests/integration/test_adapters/test_aiosqlite/test_data_dictionary.py index 9dd808546..ed0557344 100644 --- a/tests/integration/test_adapters/test_aiosqlite/test_data_dictionary.py +++ b/tests/integration/test_adapters/test_aiosqlite/test_data_dictionary.py @@ -40,7 +40,7 @@ async def test_aiosqlite_data_dictionary_topology_and_fks(aiosqlite_session: "Ai try: # Test 1: Topological Sort - sorted_tables = await aiosqlite_driver.data_dictionary.get_tables_in_topological_order(aiosqlite_driver) + sorted_tables = await aiosqlite_driver.data_dictionary.get_tables(aiosqlite_driver) test_tables = [t for t in sorted_tables if t in (users_table, orders_table, items_table)] assert len(test_tables) == 3 diff --git a/tests/integration/test_adapters/test_asyncpg/test_data_dictionary.py b/tests/integration/test_adapters/test_asyncpg/test_data_dictionary.py index c2bc5578f..547866702 100644 --- a/tests/integration/test_adapters/test_asyncpg/test_data_dictionary.py +++ b/tests/integration/test_adapters/test_asyncpg/test_data_dictionary.py @@ -136,7 +136,7 @@ async def test_asyncpg_data_dictionary_topology_and_fks(asyncpg_async_driver: "A try: # Test 1: Topological Sort - sorted_tables = await asyncpg_async_driver.data_dictionary.get_tables_in_topological_order(asyncpg_async_driver) + sorted_tables = await asyncpg_async_driver.data_dictionary.get_tables(asyncpg_async_driver) test_tables = [t for t in sorted_tables if t in (users_table, orders_table, items_table)] assert len(test_tables) == 3