diff --git a/AGENTS.md b/AGENTS.md index d9f031851..7e35fcdae 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -451,6 +451,58 @@ if supports_where(obj): result = obj.where("condition") ``` +### Mypyc-Compatible Metadata Class Pattern + +When defining data-holding classes intended for core modules (`sqlspec/core/`, `sqlspec/driver/`) that will be compiled with MyPyC, use regular classes with `__slots__` and explicitly implement `__init__`, `__repr__`, `__eq__`, and `__hash__`. This approach ensures optimal performance and MyPyC compatibility, as `dataclasses` are not directly supported by MyPyC for compilation. + +**Key Principles:** + +- **`__slots__`**: Reduces memory footprint and speeds up attribute access. +- **Explicit `__init__`**: Defines the constructor for the class. +- **Explicit `__repr__`**: Provides a clear string representation for debugging. +- **Explicit `__eq__`**: Enables correct equality comparisons. +- **Explicit `__hash__`**: Makes instances hashable, allowing them to be used in sets or as dictionary keys. The hash implementation should be based on all fields that define the object's identity. 
+ +**Example Implementation:** + +```python +class MyMetadata: + __slots__ = ("field1", "field2", "optional_field") + + def __init__(self, field1: str, field2: int, optional_field: str | None = None) -> None: + self.field1 = field1 + self.field2 = field2 + self.optional_field = optional_field + + def __repr__(self) -> str: + return f"MyMetadata(field1={self.field1!r}, field2={self.field2!r}, optional_field={self.optional_field!r})" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, MyMetadata): + return NotImplemented + return ( + self.field1 == other.field1 + and self.field2 == other.field2 + and self.optional_field == other.optional_field + ) + + def __hash__(self) -> int: + return hash((self.field1, self.field2, self.optional_field)) +``` + +**When to Use:** + +- For all new data-holding classes in performance-critical paths (e.g., `sqlspec/driver/_common.py`). +- When MyPyC compilation is enabled for the module containing the class. + +**Anti-Patterns to Avoid:** + +- Using `@dataclass` decorators for classes intended for MyPyC compilation. +- Omitting `__slots__` when defining performance-critical data structures. +- Relying on default `__eq__` or `__hash__` behavior for complex objects, especially for equality comparisons in collections. 
+ +--- + ### Performance Patterns (MANDATORY) **PERF401 - List Operations**: diff --git a/docs/examples/usage/usage_drivers_and_querying_10.py b/docs/examples/usage/usage_drivers_and_querying_10.py index 84e06aef5..138ef5f0d 100644 --- a/docs/examples/usage/usage_drivers_and_querying_10.py +++ b/docs/examples/usage/usage_drivers_and_querying_10.py @@ -12,30 +12,57 @@ def test_example_10_duckdb_config(tmp_path: Path) -> None: # start-example + import tempfile + from sqlspec import SQLSpec from sqlspec.adapters.duckdb import DuckDBConfig - spec = SQLSpec() - # In-memory - config = DuckDBConfig() - - # Persistent - database_file = tmp_path / "analytics.duckdb" - config = DuckDBConfig(pool_config={"database": database_file.name, "read_only": False}) - - with spec.provide_session(config) as session: - # Create table from Parquet - session.execute(f""" - CREATE TABLE if not exists users AS - SELECT * FROM read_parquet('{Path(__file__).parent.parent / "queries/users.parquet"}') - """) - - # Analytical query - session.execute(""" - SELECT date_trunc('day', created_at) as day, - count(*) as user_count - FROM users - GROUP BY day - ORDER BY day - """) + # Use a temporary directory for the DuckDB database for test isolation + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "analytics.duckdb" + + spec = SQLSpec() + # In-memory + in_memory_db = spec.add_config(DuckDBConfig()) + persistent_db = spec.add_config(DuckDBConfig(pool_config={"database": str(db_path)})) + + try: + # Test with in-memory config + with spec.provide_session(in_memory_db) as session: + # Create table from Parquet + session.execute(f""" + CREATE TABLE if not exists users AS + SELECT * FROM read_parquet('{Path(__file__).parent.parent / "queries/users.parquet"}') + """) + + # Analytical query + session.execute(""" + SELECT date_trunc('day', created_at) as day, + count(*) as user_count + FROM users + GROUP BY day + ORDER BY day + """) + + # Test with persistent config + with 
spec.provide_session(persistent_db) as session: + # Create table from Parquet + session.execute(f""" + CREATE TABLE if not exists users AS + SELECT * FROM read_parquet('{Path(__file__).parent.parent / "queries/users.parquet"}') + """) + + # Analytical query + session.execute(""" + SELECT date_trunc('day', created_at) as day, + count(*) as user_count + FROM users + GROUP BY day + ORDER BY day + """) + finally: + # Close the pool for the persistent config + spec.get_config(in_memory_db).close_pool() + spec.get_config(persistent_db).close_pool() + # The TemporaryDirectory context manager handles directory cleanup automatically # end-example diff --git a/docs/examples/usage/usage_drivers_and_querying_6.py b/docs/examples/usage/usage_drivers_and_querying_6.py index 18ebaa641..419555b33 100644 --- a/docs/examples/usage/usage_drivers_and_querying_6.py +++ b/docs/examples/usage/usage_drivers_and_querying_6.py @@ -1,6 +1,7 @@ # Test module converted from docs example - code-block 6 """Minimal smoke test for drivers_and_querying example 6.""" +import tempfile from pathlib import Path from sqlspec import SQLSpec @@ -12,24 +13,34 @@ def test_example_6_sqlite_config(tmp_path: Path) -> None: # start-example from sqlspec.adapters.sqlite import SqliteConfig - spec = SQLSpec() - - database_file = tmp_path / "myapp.db" - config = SqliteConfig(pool_config={"database": database_file.name, "timeout": 5.0, "check_same_thread": False}) - - with spec.provide_session(config) as session: - # Create table - session.execute(""" - CREATE TABLE IF NOT EXISTS usage6_users ( - id INTEGER PRIMARY KEY, - name TEXT NOT NULL - ) - """) - - # Insert with parameters - session.execute("INSERT INTO usage6_users (name) VALUES (?)", "Alice") - - # Query - result = session.execute("SELECT * FROM usage6_users") - result.all() - # end-example + # Use a temporary file for the SQLite database for test isolation + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp_db_file: + db_path = 
tmp_db_file.name + + spec = SQLSpec() + + db = spec.add_config( + SqliteConfig(pool_config={"database": db_path, "timeout": 5.0, "check_same_thread": False}) + ) + + try: + with spec.provide_session(db) as session: + # Create table + session.execute(""" + CREATE TABLE IF NOT EXISTS usage6_users ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL + ) + """) + + # Insert with parameters + session.execute("INSERT INTO usage6_users (name) VALUES (?)", "Alice") + + # Query + result = session.execute("SELECT * FROM usage6_users") + result.all() + finally: + # Clean up the temporary database file + spec.get_config(db).close_pool() + Path(db_path).unlink() + # end-example diff --git a/docs/extensions/aiosql/api.rst b/docs/extensions/aiosql/api.rst index 6fe53f376..28f367484 100644 --- a/docs/extensions/aiosql/api.rst +++ b/docs/extensions/aiosql/api.rst @@ -24,6 +24,7 @@ AiosqlAsyncAdapter :members: :undoc-members: :show-inheritance: + :no-index: AiosqlSyncAdapter ----------------- @@ -32,6 +33,7 @@ AiosqlSyncAdapter :members: :undoc-members: :show-inheritance: + :no-index: Query Operators =============== diff --git a/docs/extensions/litestar/api.rst b/docs/extensions/litestar/api.rst index d2c157ede..eb799a01d 100644 --- a/docs/extensions/litestar/api.rst +++ b/docs/extensions/litestar/api.rst @@ -11,6 +11,7 @@ SQLSpecPlugin :members: :undoc-members: :show-inheritance: + :no-index: Configuration ============= diff --git a/docs/guides/architecture/data-dictionary.md b/docs/guides/architecture/data-dictionary.md new file mode 100644 index 000000000..7e0cddb1a --- /dev/null +++ b/docs/guides/architecture/data-dictionary.md @@ -0,0 +1,83 @@ +# Data Dictionary & Introspection + +SQLSpec provides a unified Data Dictionary API to introspect database schemas across all supported adapters. This allows you to retrieve table metadata, columns, indexes, and foreign keys in a consistent format, regardless of the underlying database engine. 
+ +## Core Concepts + +The `DataDictionary` is accessed via the `driver.data_dictionary` property. It provides methods to query the database catalog. + +### Introspection Capabilities + +- **Tables**: List tables in a schema. +- **Columns**: Get column details (name, type, nullable, default). +- **Indexes**: Get index definitions (columns, uniqueness). +- **Foreign Keys**: Get foreign key constraints and relationships. +- **Topological Sorting**: Get tables sorted by dependency order (useful for cleanups or migrations). + +## Usage + +### Basic Introspection + +```python +async with config.provide_session() as session: + # Get all tables in the default schema + tables = await session.data_dictionary.get_tables(session) + print(f"Tables: {tables}") + + # Get columns for a specific table + columns = await session.data_dictionary.get_columns(session, "users") + for col in columns: + print(f"{col['column_name']}: {col['data_type']}") +``` + +### Topological Sort (Dependency Ordering) + +`get_tables` now returns table names sorted such that parent tables appear before child tables (tables with foreign keys to parents). + +This is essential for: + +- **Data Loading**: Insert into parents first. +- **Cleanup**: Delete in reverse order to avoid foreign key violations. + +```python +async with config.provide_session() as session: + # Get tables sorted parent -> child + sorted_tables = await session.data_dictionary.get_tables(session) + + print("Insertion Order:", sorted_tables) + print("Deletion Order:", list(reversed(sorted_tables))) +``` + +**Implementation Details**: + +- **Postgres / SQLite / MySQL 8+**: Uses efficient Recursive CTEs in SQL. +- **Oracle**: Uses `CONNECT BY` queries. +- **Others (BigQuery, MySQL 5.7)**: Falls back to a Python-based topological sort using `graphlib`. + +### Metadata Types + +SQLSpec uses regular classes with __slots__ for metadata results to ensure mypyc compatibility and memory efficiency. 
+ +```python +from sqlspec.driver import ForeignKeyMetadata + +async with config.provide_session() as session: + fks: list[ForeignKeyMetadata] = await session.data_dictionary.get_foreign_keys(session, "orders") + + for fk in fks: + print(f"FK: {fk.column_name} -> {fk.referenced_table}.{fk.referenced_column}") +``` + +## Adapter Support Matrix + +| Feature | Postgres | SQLite | Oracle | MySQL | DuckDB | BigQuery | +|---------|----------|--------|--------|-------|--------|----------| +| Tables | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| Columns | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| Indexes | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | +| Foreign Keys | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | +| Topological Sort | ✅ (CTE) | ✅ (CTE) | ✅ (Connect By) | ✅ (CTE/Python) | ✅ (CTE) | ✅ (Python) | + +## API Reference + +For a complete API reference of the Data Dictionary components, including `DataDictionaryMixin`, `AsyncDataDictionaryBase`, `SyncDataDictionaryBase`, and the metadata classes (`ForeignKeyMetadata`, `ColumnMetadata`, `IndexMetadata`), please refer to the :doc:`/reference/driver`. diff --git a/docs/reference/driver.rst b/docs/reference/driver.rst index c404a762d..7a0e89a46 100644 --- a/docs/reference/driver.rst +++ b/docs/reference/driver.rst @@ -103,6 +103,43 @@ Connection Pooling :undoc-members: :show-inheritance: +Data Dictionary +=============== + +The Data Dictionary API provides standardized introspection capabilities across all supported databases. + +.. currentmodule:: sqlspec.driver + +.. autoclass:: DataDictionaryMixin + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: AsyncDataDictionaryBase + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: SyncDataDictionaryBase + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: ForeignKeyMetadata + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: ColumnMetadata + :members: + :undoc-members: + :show-inheritance: + +.. 
autoclass:: IndexMetadata + :members: + :undoc-members: + :show-inheritance: + Driver Protocols ================ diff --git a/docs/usage/drivers_and_querying.rst b/docs/usage/drivers_and_querying.rst index d73bd66ee..3f57f3940 100644 --- a/docs/usage/drivers_and_querying.rst +++ b/docs/usage/drivers_and_querying.rst @@ -471,7 +471,7 @@ Performance Tips :start-after: # start-example :end-before: # end-example :caption: ``asyncpg connection pooling`` - :dedent: 4 + :dedent: 2 **2. Batch Operations** diff --git a/sqlspec/adapters/adbc/data_dictionary.py b/sqlspec/adapters/adbc/data_dictionary.py index 77c4f14cc..dd7fb2b12 100644 --- a/sqlspec/adapters/adbc/data_dictionary.py +++ b/sqlspec/adapters/adbc/data_dictionary.py @@ -3,7 +3,7 @@ import re from typing import TYPE_CHECKING, Any, cast -from sqlspec.driver import SyncDataDictionaryBase, SyncDriverAdapterBase, VersionInfo +from sqlspec.driver import ForeignKeyMetadata, SyncDataDictionaryBase, SyncDriverAdapterBase, VersionInfo from sqlspec.utils.logging import get_logger if TYPE_CHECKING: @@ -27,6 +27,181 @@ class AdbcDataDictionary(SyncDataDictionaryBase): Delegates to appropriate dialect-specific logic based on the driver's dialect. 
""" + def get_foreign_keys( + self, driver: "SyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None + ) -> "list[ForeignKeyMetadata]": + """Get foreign key metadata based on detected dialect.""" + + dialect = self._get_dialect(driver) + adbc_driver = cast("AdbcDriver", driver) + + if dialect == "sqlite": + if table: + # Single table + result = adbc_driver.execute(f"PRAGMA foreign_key_list('{table}')") + # SQLite PRAGMA returns: id, seq, table, from, to, on_update, on_delete, match + # We need 'from' (col) and 'table' (ref_table) and 'to' (ref_col) + # Note: PRAGMA results from ADBC might be keyed by name or index depending on driver + return [ + ForeignKeyMetadata( + table_name=table, + column_name=row["from"] if isinstance(row, dict) else row[3], + referenced_table=row["table"] if isinstance(row, dict) else row[2], + referenced_column=row["to"] if isinstance(row, dict) else row[4], + ) + for row in result.data + ] + # For all tables in SQLite we'd have to iterate, which base class doesn't do efficiently. + # We'll just return empty if no table specified for now, or iterate if crucial. + # Base implementation will call this per-table if needed? No, base implementation expects all if table is None. + # For SQLite ADBC, iterating tables is expensive. Let's support single table primarily. + return [] + + # SQL-standard compliant databases (Postgres, MySQL, DuckDB, BigQuery) + # They all support information_schema.key_column_usage roughly the same way + + # Postgres/DuckDB/MySQL query + params = [] + + if dialect == "bigquery": + dataset = schema + if not dataset: + return [] # BigQuery requires dataset for info schema + kcu = f"`{dataset}.INFORMATION_SCHEMA.KEY_COLUMN_USAGE`" + rc = f"`{dataset}.INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS`" + # BQ uses named params usually or positional? ADBC usually positional '?' + # But BQ driver might want named. ADBC standardizes on '?' usually. 
+ sql = f""" + SELECT + kcu.table_name, + kcu.column_name, + pk_kcu.table_name AS referenced_table_name, + pk_kcu.column_name AS referenced_column_name, + kcu.constraint_name, + kcu.table_schema, + pk_kcu.table_schema AS referenced_table_schema + FROM {kcu} kcu + JOIN {rc} rc ON kcu.constraint_name = rc.constraint_name + JOIN {kcu} pk_kcu + ON rc.unique_constraint_name = pk_kcu.constraint_name + AND kcu.ordinal_position = pk_kcu.ordinal_position + """ + if table: + sql += f" WHERE kcu.table_name = '{table}'" # Simple string sub for BQ ADBC safety check needed? + + try: + result = adbc_driver.execute(sql) + return [ + ForeignKeyMetadata( + table_name=row["table_name"], + column_name=row["column_name"], + referenced_table=row["referenced_table_name"], + referenced_column=row["referenced_column_name"], + constraint_name=row["constraint_name"], + schema=row["table_schema"], + referenced_schema=row["referenced_table_schema"], + ) + for row in result.data + ] + except Exception: + return [] + + # Standard ANSI SQL (Postgres, MySQL, DuckDB) + kcu = "information_schema.key_column_usage" + + if dialect == "postgres": + sql = """ + SELECT + kcu.table_name, + kcu.column_name, + ccu.table_name AS referenced_table_name, + ccu.column_name AS referenced_column_name, + tc.constraint_name, + tc.table_schema, + ccu.table_schema AS referenced_table_schema + FROM + information_schema.table_constraints AS tc + JOIN information_schema.key_column_usage AS kcu + ON tc.constraint_name = kcu.constraint_name + AND tc.table_schema = kcu.table_schema + JOIN information_schema.constraint_column_usage AS ccu + ON ccu.constraint_name = tc.constraint_name + AND ccu.table_schema = tc.table_schema + WHERE tc.constraint_type = 'FOREIGN KEY' + """ + if schema: + sql += " AND tc.table_schema = ?" + params.append(schema) + if table: + sql += " AND tc.table_name = ?" 
+ params.append(table) + + elif dialect == "mysql": + # MySQL information_schema + sql = """ + SELECT + table_name, + column_name, + referenced_table_name, + referenced_column_name, + constraint_name, + table_schema, + referenced_table_schema + FROM information_schema.key_column_usage + WHERE referenced_table_name IS NOT NULL + """ + if schema: + sql += " AND table_schema = ?" + params.append(schema) + if table: + sql += " AND table_name = ?" + params.append(table) + + elif dialect == "duckdb": + # DuckDB similar to Postgres but sometimes requires referential_constraints join + sql = """ + SELECT + kcu.table_name, + kcu.column_name, + pk_kcu.table_name AS referenced_table_name, + pk_kcu.column_name AS referenced_column_name, + kcu.constraint_name, + kcu.table_schema, + pk_kcu.table_schema AS referenced_table_schema + FROM information_schema.key_column_usage kcu + JOIN information_schema.referential_constraints rc + ON kcu.constraint_name = rc.constraint_name + JOIN information_schema.key_column_usage pk_kcu + ON rc.unique_constraint_name = pk_kcu.constraint_name + AND kcu.ordinal_position = pk_kcu.ordinal_position + WHERE 1=1 + """ + if schema: + sql += " AND kcu.table_schema = ?" + params.append(schema) + if table: + sql += " AND kcu.table_name = ?" + params.append(table) + else: + return [] + + try: + result = adbc_driver.execute(sql, tuple(params)) + return [ + ForeignKeyMetadata( + table_name=row["table_name"], + column_name=row["column_name"], + referenced_table=row["referenced_table_name"], + referenced_column=row["referenced_column_name"], + constraint_name=row["constraint_name"], + schema=row.get("table_schema"), + referenced_schema=row.get("referenced_table_schema"), + ) + for row in result.data + ] + except Exception: + return [] + def _get_dialect(self, driver: SyncDriverAdapterBase) -> str: """Get dialect from ADBC driver. 
diff --git a/sqlspec/adapters/aiosqlite/data_dictionary.py b/sqlspec/adapters/aiosqlite/data_dictionary.py index d841cce2f..f9f0cf732 100644 --- a/sqlspec/adapters/aiosqlite/data_dictionary.py +++ b/sqlspec/adapters/aiosqlite/data_dictionary.py @@ -3,7 +3,7 @@ import re from typing import TYPE_CHECKING, Any, cast -from sqlspec.driver import AsyncDataDictionaryBase, AsyncDriverAdapterBase, VersionInfo +from sqlspec.driver import AsyncDataDictionaryBase, AsyncDriverAdapterBase, ForeignKeyMetadata, VersionInfo from sqlspec.utils.logging import get_logger if TYPE_CHECKING: @@ -129,6 +129,117 @@ async def get_columns( for row in result.data or [] ] + async def get_tables(self, driver: "AsyncDriverAdapterBase", schema: "str | None" = None) -> "list[str]": + """Get tables sorted by topological dependency order using SQLite catalog.""" + aiosqlite_driver = cast("AiosqliteDriver", driver) + + sql = """ + WITH RECURSIVE dependency_tree AS ( + SELECT + m.name as table_name, + 0 as level, + '/' || m.name || '/' as path + FROM sqlite_schema m + WHERE m.type = 'table' + AND m.name NOT LIKE 'sqlite_%' + AND NOT EXISTS ( + SELECT 1 FROM pragma_foreign_key_list(m.name) + ) + + UNION ALL + + SELECT + m.name as table_name, + dt.level + 1, + dt.path || m.name || '/' + FROM sqlite_schema m + JOIN pragma_foreign_key_list(m.name) fk + JOIN dependency_tree dt ON fk."table" = dt.table_name + WHERE m.type = 'table' + AND m.name NOT LIKE 'sqlite_%' + AND instr(dt.path, '/' || m.name || '/') = 0 + ) + SELECT DISTINCT table_name FROM dependency_tree ORDER BY level, table_name; + """ + result = await aiosqlite_driver.execute(sql) + return [row["table_name"] for row in result.get_data()] + + async def get_foreign_keys( + self, driver: "AsyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None + ) -> "list[ForeignKeyMetadata]": + """Get foreign key metadata.""" + aiosqlite_driver = cast("AiosqliteDriver", driver) + + if table: + # Single table optimization + sql = f"SELECT 
'{table}' as table_name, fk.* FROM pragma_foreign_key_list('{table}') fk" + result = await aiosqlite_driver.execute(sql) + else: + # All tables + sql = """ + SELECT m.name as table_name, fk.* + FROM sqlite_schema m, pragma_foreign_key_list(m.name) fk + WHERE m.type = 'table' AND m.name NOT LIKE 'sqlite_%' + """ + result = await aiosqlite_driver.execute(sql) + + fks = [] + for row in result.data: + if isinstance(row, (list, tuple)): + table_name = row[0] + ref_table = row[3] + col = row[4] + ref_col = row[5] + else: + table_name = row["table_name"] + ref_table = row["table"] + col = row["from"] + ref_col = row["to"] + + fks.append( + ForeignKeyMetadata( + table_name=table_name, + column_name=col, + referenced_table=ref_table, + referenced_column=ref_col, + constraint_name=None, + schema=None, + referenced_schema=None, + ) + ) + return fks + + async def get_indexes( + self, driver: "AsyncDriverAdapterBase", table: str, schema: "str | None" = None + ) -> "list[dict[str, Any]]": + """Get index information for a table.""" + aiosqlite_driver = cast("AiosqliteDriver", driver) + + # 1. Get indexes for table + index_list_res = await aiosqlite_driver.execute(f"PRAGMA index_list('{table}')") + indexes = [] + + for idx_row in index_list_res.data: + if isinstance(idx_row, (list, tuple)): + idx_name = idx_row[1] + unique = bool(idx_row[2]) + else: + idx_name = idx_row["name"] + unique = bool(idx_row["unique"]) + + # 2. Get columns for index + info_res = await aiosqlite_driver.execute(f"PRAGMA index_info('{idx_name}')") + cols = [] + for col_row in info_res.data: + if isinstance(col_row, (list, tuple)): + cols.append(col_row[2]) + else: + cols.append(col_row["name"]) + + indexes.append({"name": idx_name, "columns": cols, "unique": unique, "primary": False, "table_name": table}) + + return indexes + def list_available_features(self) -> "list[str]": """List available SQLite feature flags. 
diff --git a/sqlspec/adapters/asyncmy/data_dictionary.py b/sqlspec/adapters/asyncmy/data_dictionary.py index c8bd81421..86727ee1e 100644 --- a/sqlspec/adapters/asyncmy/data_dictionary.py +++ b/sqlspec/adapters/asyncmy/data_dictionary.py @@ -3,7 +3,7 @@ import re from typing import TYPE_CHECKING, Any, cast -from sqlspec.driver import AsyncDataDictionaryBase, AsyncDriverAdapterBase, VersionInfo +from sqlspec.driver import AsyncDataDictionaryBase, AsyncDriverAdapterBase, ForeignKeyMetadata, VersionInfo from sqlspec.utils.logging import get_logger if TYPE_CHECKING: @@ -104,42 +104,122 @@ async def get_optimal_type(self, driver: AsyncDriverAdapterBase, type_category: } return type_map.get(type_category, "VARCHAR(255)") - async def get_columns( - self, driver: AsyncDriverAdapterBase, table: str, schema: "str | None" = None - ) -> "list[dict[str, Any]]": - """Get column information for a table using information_schema. + async def get_tables(self, driver: "AsyncDriverAdapterBase", schema: "str | None" = None) -> "list[str]": + """Get tables sorted by topological dependency order using MySQL catalog. - Args: - driver: AsyncMy driver instance - table: Table name to query columns for - schema: Schema name (database name in MySQL) + Requires MySQL 8.0.1+ for recursive CTE support. 
+ """ + version = await self.get_version(driver) + asyncmy_driver = cast("AsyncmyDriver", driver) - Returns: - List of column metadata dictionaries with keys: - - column_name: Name of the column - - data_type: MySQL data type - - is_nullable: Whether column allows NULL (YES/NO) - - column_default: Default value if any + if not version or version < VersionInfo(8, 0, 1): + msg = "get_tables requires MySQL 8.0.1+ for dependency ordering" + raise RuntimeError(msg) + + schema_clause = f"'{schema}'" if schema else "DATABASE()" + + sql = f""" + WITH RECURSIVE dependency_tree AS ( + SELECT + table_name, + 0 AS level, + CAST(table_name AS CHAR(4000)) AS path + FROM information_schema.tables t + WHERE t.table_type = 'BASE TABLE' + AND t.table_schema = {schema_clause} + AND NOT EXISTS ( + SELECT 1 + FROM information_schema.key_column_usage kcu + WHERE kcu.table_name = t.table_name + AND kcu.table_schema = t.table_schema + AND kcu.referenced_table_name IS NOT NULL + ) + + UNION ALL + + SELECT + kcu.table_name, + dt.level + 1, + CONCAT(dt.path, ',', kcu.table_name) + FROM information_schema.key_column_usage kcu + JOIN dependency_tree dt ON kcu.referenced_table_name = dt.table_name + WHERE kcu.table_schema = {schema_clause} + AND kcu.referenced_table_name IS NOT NULL + AND NOT FIND_IN_SET(kcu.table_name, dt.path) + ) + SELECT DISTINCT table_name + FROM dependency_tree + ORDER BY level, table_name """ + result = await asyncmy_driver.execute(sql) + return [row["table_name"] for row in result.get_data()] + + async def get_foreign_keys( + self, driver: "AsyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None + ) -> "list[ForeignKeyMetadata]": + """Get foreign key metadata.""" asyncmy_driver = cast("AsyncmyDriver", driver) + where_clauses = ["referenced_table_name IS NOT NULL"] + if table: + where_clauses.append(f"table_name = '{table}'") if schema: - sql = f""" - SELECT column_name, data_type, is_nullable, column_default - FROM information_schema.columns - 
WHERE table_name = '{table}' AND table_schema = '{schema}' - ORDER BY ordinal_position - """ + where_clauses.append(f"table_schema = '{schema}'") else: - sql = f""" - SELECT column_name, data_type, is_nullable, column_default - FROM information_schema.columns - WHERE table_name = '{table}' - ORDER BY ordinal_position - """ + where_clauses.append("table_schema = DATABASE()") + + where_str = " AND ".join(where_clauses) + + sql = f""" + SELECT + table_name, + column_name, + referenced_table_name, + referenced_column_name, + constraint_name, + table_schema, + referenced_table_schema + FROM information_schema.key_column_usage + WHERE {where_str} + """ + result = await asyncmy_driver.execute(sql) + + return [ + ForeignKeyMetadata( + table_name=row["table_name"], + column_name=row["column_name"], + referenced_table=row["referenced_table_name"], + referenced_column=row["referenced_column_name"], + constraint_name=row["constraint_name"], + schema=row["table_schema"], + referenced_schema=row.get("referenced_table_schema"), + ) + for row in result.data + ] + + async def get_indexes( + self, driver: "AsyncDriverAdapterBase", table: str, schema: "str | None" = None + ) -> "list[dict[str, Any]]": + """Get index information for a table.""" + asyncmy_driver = cast("AsyncmyDriver", driver) + sql = f"SHOW INDEX FROM {table}" if schema is None else f"SHOW INDEX FROM {table} FROM {schema}" result = await asyncmy_driver.execute(sql) - return result.data or [] + # Parse SHOW INDEX output + indexes: dict[str, dict[str, Any]] = {} + for row in result.data: + idx_name = row["Key_name"] + if idx_name not in indexes: + indexes[idx_name] = { + "name": idx_name, + "columns": [], + "unique": row["Non_unique"] == 0, + "primary": idx_name == "PRIMARY", + "table_name": table, + } + indexes[idx_name]["columns"].append(row["Column_name"]) + + return list(indexes.values()) def list_available_features(self) -> "list[str]": """List available MySQL feature flags. 
async def get_tables(self, driver: "AsyncDriverAdapterBase", schema: "str | None" = None) -> "list[str]":
    """Get tables sorted by topological dependency order using a recursive CTE.

    Tables with no outgoing foreign keys come first (level 0); each dependent
    table follows the tables it references. Cycles are cut with the ``path``
    array accumulated along each branch.

    Args:
        driver: Async database driver instance.
        schema: Schema name; defaults to ``public``.

    Returns:
        Table names ordered so referenced tables precede referencing tables.
    """
    asyncpg_driver = cast("AsyncpgDriver", driver)
    schema_name = schema or "public"

    sql = """
        WITH RECURSIVE dependency_tree AS (
            SELECT
                t.table_name::text,
                0 AS level,
                ARRAY[t.table_name::text] AS path
            FROM information_schema.tables t
            WHERE t.table_type = 'BASE TABLE'
              AND t.table_schema = $1
              AND NOT EXISTS (
                  SELECT 1
                  FROM information_schema.table_constraints tc
                  JOIN information_schema.key_column_usage kcu
                      ON tc.constraint_name = kcu.constraint_name
                      AND tc.table_schema = kcu.table_schema
                  WHERE tc.constraint_type = 'FOREIGN KEY'
                    AND tc.table_name = t.table_name
                    AND tc.table_schema = t.table_schema
              )

            UNION ALL

            SELECT
                tc.table_name::text,
                dt.level + 1,
                dt.path || tc.table_name::text
            FROM information_schema.table_constraints tc
            JOIN information_schema.key_column_usage kcu
                ON tc.constraint_name = kcu.constraint_name
                AND tc.table_schema = kcu.table_schema
            JOIN information_schema.constraint_column_usage ccu
                ON ccu.constraint_name = tc.constraint_name
                AND ccu.table_schema = tc.table_schema
            JOIN dependency_tree dt
                ON ccu.table_name = dt.table_name
            WHERE tc.constraint_type = 'FOREIGN KEY'
              AND tc.table_schema = $1
              AND ccu.table_schema = $1
              AND NOT (tc.table_name = ANY(dt.path))
        )
        -- GROUP BY dedupes tables reachable at several levels (DISTINCT on
        -- (table_name, level) would emit the same table more than once);
        -- MAX(level) keeps a table after ALL of its dependencies.
        SELECT table_name
        FROM dependency_tree
        GROUP BY table_name
        ORDER BY MAX(level), table_name;
    """
    result = await asyncpg_driver.execute(sql, (schema_name,))
    return [row["table_name"] for row in result.get_data()]

async def get_foreign_keys(
    self, driver: "AsyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None
) -> "list[ForeignKeyMetadata]":
    """Get foreign key metadata for a schema, optionally filtered to one table.

    Args:
        driver: Async database driver instance.
        table: Optional table name filter.
        schema: Schema name; defaults to ``public``.

    Returns:
        ForeignKeyMetadata entries for each matching constraint column pair.
    """
    asyncpg_driver = cast("AsyncpgDriver", driver)
    schema_name = schema or "public"

    sql = """
        SELECT
            kcu.table_name,
            kcu.column_name,
            ccu.table_name AS referenced_table_name,
            ccu.column_name AS referenced_column_name,
            tc.constraint_name,
            tc.table_schema,
            ccu.table_schema AS referenced_table_schema
        FROM
            information_schema.table_constraints AS tc
            JOIN information_schema.key_column_usage AS kcu
                ON tc.constraint_name = kcu.constraint_name
                AND tc.table_schema = kcu.table_schema
            JOIN information_schema.constraint_column_usage AS ccu
                ON ccu.constraint_name = tc.constraint_name
                AND ccu.table_schema = tc.table_schema
        WHERE tc.constraint_type = 'FOREIGN KEY'
          AND ($1::text IS NULL OR tc.table_schema = $1)
          AND ($2::text IS NULL OR tc.table_name = $2)
    """
    result = await asyncpg_driver.execute(sql, (schema_name, table))

    # Guard against a falsy/None payload, matching get_columns' `result.data or []`.
    return [
        ForeignKeyMetadata(
            table_name=row["table_name"],
            column_name=row["column_name"],
            referenced_table=row["referenced_table_name"],
            referenced_column=row["referenced_column_name"],
            constraint_name=row["constraint_name"],
            schema=row["table_schema"],
            referenced_schema=row["referenced_table_schema"],
        )
        for row in result.data or []
    ]

async def get_indexes(
    self, driver: "AsyncDriverAdapterBase", table: str, schema: "str | None" = None
) -> "list[dict[str, Any]]":
    """Get index information for a table from the pg_catalog tables.

    Args:
        driver: Async database driver instance.
        table: Table name to inspect.
        schema: Schema name; defaults to ``public``.

    Returns:
        Dicts with ``name``, ``columns``, ``unique``, ``primary``, ``table_name``.
    """
    asyncpg_driver = cast("AsyncpgDriver", driver)
    schema_name = schema or "public"

    sql = """
        SELECT
            i.relname as index_name,
            ix.indisunique as is_unique,
            ix.indisprimary as is_primary,
            array_agg(a.attname ORDER BY array_position(ix.indkey, a.attnum)) as columns
        FROM
            pg_class t,
            pg_class i,
            pg_index ix,
            pg_attribute a,
            pg_namespace n
        WHERE
            t.oid = ix.indrelid
            AND i.oid = ix.indexrelid
            AND a.attrelid = t.oid
            AND a.attnum = ANY(ix.indkey)
            AND t.relkind = 'r'
            AND t.relnamespace = n.oid
            AND n.nspname = $1
            AND t.relname = $2
        GROUP BY
            i.relname,
            ix.indisunique,
            ix.indisprimary
    """
    result = await asyncpg_driver.execute(sql, (schema_name, table))

    # Guard against a falsy/None payload, matching get_columns' `result.data or []`.
    return [
        {
            "name": row["index_name"],
            "columns": row["columns"],
            "unique": row["is_unique"],
            "primary": row["is_primary"],
            "table_name": table,
        }
        for row in result.data or []
    ]
def get_tables(self, driver: "SyncDriverAdapterBase", schema: "str | None" = None) -> "list[str]":
    """Get tables sorted by topological dependency order using BigQuery catalog views.

    Args:
        driver: Sync database driver instance.
        schema: Optional dataset name; when omitted, the connection's default
            dataset's INFORMATION_SCHEMA views are used.

    Returns:
        Table names ordered so referenced tables precede referencing tables.
    """
    bigquery_driver = cast("BigQueryDriver", driver)

    if schema:
        tables_table = f"`{schema}.INFORMATION_SCHEMA.TABLES`"
        kcu_table = f"`{schema}.INFORMATION_SCHEMA.KEY_COLUMN_USAGE`"
        rc_table = f"`{schema}.INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS`"
    else:
        tables_table = "INFORMATION_SCHEMA.TABLES"
        kcu_table = "INFORMATION_SCHEMA.KEY_COLUMN_USAGE"
        rc_table = "INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS"

    sql = f"""
        WITH RECURSIVE dependency_tree AS (
            SELECT
                t.table_name,
                0 AS level,
                [t.table_name] AS path
            FROM {tables_table} t
            WHERE t.table_type = 'BASE TABLE'
              AND NOT EXISTS (
                  SELECT 1
                  FROM {kcu_table} kcu
                  JOIN {rc_table} rc ON kcu.constraint_name = rc.constraint_name
                  WHERE kcu.table_name = t.table_name
              )

            UNION ALL

            SELECT
                kcu.table_name,
                dt.level + 1,
                ARRAY_CONCAT(dt.path, [kcu.table_name])
            FROM {kcu_table} kcu
            JOIN {rc_table} rc ON kcu.constraint_name = rc.constraint_name
            JOIN {kcu_table} pk_kcu
                ON rc.unique_constraint_name = pk_kcu.constraint_name
                AND kcu.ordinal_position = pk_kcu.ordinal_position
            JOIN dependency_tree dt ON pk_kcu.table_name = dt.table_name
            WHERE kcu.table_name NOT IN UNNEST(dt.path)
        )
        -- BigQuery rejects ORDER BY expressions that are not in the SELECT
        -- DISTINCT list, and DISTINCT on (table_name, level) would duplicate
        -- table names; GROUP BY + MAX(level) fixes both.
        SELECT table_name
        FROM dependency_tree
        GROUP BY table_name
        ORDER BY MAX(level), table_name
    """

    result = bigquery_driver.execute(sql)
    return [row["table_name"] for row in result.get_data()]

def get_foreign_keys(
    self, driver: "SyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None
) -> "list[ForeignKeyMetadata]":
    """Get foreign key metadata from BigQuery INFORMATION_SCHEMA views.

    Args:
        driver: Sync database driver instance.
        table: Optional table name filter.
        schema: Optional dataset name.

    Returns:
        ForeignKeyMetadata entries; empty list if the query fails (best-effort).
    """
    bigquery_driver = cast("BigQueryDriver", driver)

    dataset = schema
    if dataset:
        kcu_table = f"`{dataset}.INFORMATION_SCHEMA.KEY_COLUMN_USAGE`"
        rc_table = f"`{dataset}.INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS`"
    else:
        kcu_table = "INFORMATION_SCHEMA.KEY_COLUMN_USAGE"
        rc_table = "INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS"

    where_clauses = []
    if table:
        # table is interpolated into SQL; escape embedded quotes so a quoted
        # value cannot break out of the literal. NOTE(review): bind parameters
        # would be preferable if the driver call sites support them — confirm.
        safe_table = table.replace("'", "''")
        where_clauses.append(f"kcu.table_name = '{safe_table}'")

    where_str = " AND ".join(where_clauses)
    if where_str:
        where_str = "WHERE " + where_str

    sql = f"""
        SELECT
            kcu.table_name,
            kcu.column_name,
            pk_kcu.table_name AS referenced_table_name,
            pk_kcu.column_name AS referenced_column_name,
            kcu.constraint_name,
            kcu.table_schema,
            pk_kcu.table_schema AS referenced_table_schema
        FROM {kcu_table} kcu
        JOIN {rc_table} rc ON kcu.constraint_name = rc.constraint_name
        JOIN {kcu_table} pk_kcu
            ON rc.unique_constraint_name = pk_kcu.constraint_name
            AND kcu.ordinal_position = pk_kcu.ordinal_position
        {where_str}
    """

    try:
        result = bigquery_driver.execute(sql)
        return [
            ForeignKeyMetadata(
                table_name=row["table_name"],
                column_name=row["column_name"],
                referenced_table=row["referenced_table_name"],
                referenced_column=row["referenced_column_name"],
                constraint_name=row["constraint_name"],
                schema=row["table_schema"],
                referenced_schema=row["referenced_table_schema"],
            )
            for row in result.data or []
        ]
    except Exception:
        # Best-effort: older datasets may lack the referential-constraint views.
        logger.warning("Failed to fetch foreign keys from BigQuery")
        return []
def get_tables(self, driver: "SyncDriverAdapterBase", schema: "str | None" = None) -> "list[str]":
    """Get tables sorted by topological dependency order using DuckDB catalog views.

    Args:
        driver: Sync database driver instance.
        schema: Optional schema name; defaults to ``current_schema()``.

    Returns:
        Table names ordered so referenced tables precede referencing tables.
    """
    duckdb_driver = cast("DuckDBDriver", driver)
    if schema:
        # Escape embedded quotes since the value is interpolated into SQL.
        schema_clause = "'{}'".format(schema.replace("'", "''"))
    else:
        schema_clause = "current_schema()"

    sql = f"""
        WITH RECURSIVE dependency_tree AS (
            SELECT
                table_name,
                0 AS level,
                [table_name] AS path
            FROM information_schema.tables t
            WHERE t.table_type = 'BASE TABLE'
              AND t.table_schema = {schema_clause}
              AND NOT EXISTS (
                  SELECT 1
                  FROM information_schema.key_column_usage kcu
                  WHERE kcu.table_name = t.table_name
                    AND kcu.table_schema = t.table_schema
                    AND kcu.constraint_name IN (SELECT constraint_name FROM information_schema.referential_constraints)
              )

            UNION ALL

            SELECT
                kcu.table_name,
                dt.level + 1,
                list_append(dt.path, kcu.table_name)
            FROM information_schema.key_column_usage kcu
            JOIN information_schema.referential_constraints rc ON kcu.constraint_name = rc.constraint_name
            JOIN information_schema.key_column_usage pk_kcu
                ON rc.unique_constraint_name = pk_kcu.constraint_name
                AND rc.unique_constraint_schema = pk_kcu.constraint_schema
            JOIN dependency_tree dt ON dt.table_name = pk_kcu.table_name
            WHERE kcu.table_schema = {schema_clause}
              AND NOT list_contains(dt.path, kcu.table_name)
        )
        -- GROUP BY dedupes tables reachable at several levels; MAX(level)
        -- keeps a table after ALL of its dependencies.
        SELECT table_name
        FROM dependency_tree
        GROUP BY table_name
        ORDER BY MAX(level), table_name
    """
    result = duckdb_driver.execute(sql)
    return [row["table_name"] for row in result.get_data()]

def get_foreign_keys(
    self, driver: "SyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None
) -> "list[ForeignKeyMetadata]":
    """Get foreign key metadata from DuckDB information_schema views.

    Args:
        driver: Sync database driver instance.
        table: Optional table name filter.
        schema: Optional schema name filter.

    Returns:
        ForeignKeyMetadata entries for each matching constraint column pair.
    """
    duckdb_driver = cast("DuckDBDriver", driver)

    # Values are interpolated into SQL; escape embedded quotes so a quoted
    # value cannot break out of the literal. NOTE(review): bind parameters
    # would be preferable if the driver call sites support them — confirm.
    where_clauses = []
    if schema:
        where_clauses.append("kcu.table_schema = '{}'".format(schema.replace("'", "''")))
    if table:
        where_clauses.append("kcu.table_name = '{}'".format(table.replace("'", "''")))

    where_str = " AND ".join(where_clauses) if where_clauses else "1=1"

    sql = f"""
        SELECT
            kcu.table_name,
            kcu.column_name,
            pk_kcu.table_name AS referenced_table_name,
            pk_kcu.column_name AS referenced_column_name,
            kcu.constraint_name,
            kcu.table_schema,
            pk_kcu.table_schema AS referenced_table_schema
        FROM information_schema.key_column_usage kcu
        JOIN information_schema.referential_constraints rc
            ON kcu.constraint_name = rc.constraint_name
        JOIN information_schema.key_column_usage pk_kcu
            ON rc.unique_constraint_name = pk_kcu.constraint_name
            AND kcu.ordinal_position = pk_kcu.ordinal_position
        WHERE {where_str}
    """

    result = duckdb_driver.execute(sql)

    return [
        ForeignKeyMetadata(
            table_name=row["table_name"],
            column_name=row["column_name"],
            referenced_table=row["referenced_table_name"],
            referenced_column=row["referenced_column_name"],
            constraint_name=row["constraint_name"],
            schema=row["table_schema"],
            referenced_schema=row["referenced_table_schema"],
        )
        for row in result.get_data()
    ]

def get_indexes(
    self, driver: "SyncDriverAdapterBase", table: str, schema: "str | None" = None
) -> "list[dict[str, Any]]":
    """Get index information for a table.

    DuckDB does not expose index metadata through the information_schema views
    this module relies on, so no rows can be produced here.
    NOTE(review): newer DuckDB versions provide a duckdb_indexes() catalog
    function — confirm availability before implementing.
    """
    return []
async def get_tables(self, driver: "AsyncDriverAdapterBase", schema: "str | None" = None) -> "list[str]":
    """Get tables sorted by topological dependency order using Oracle CONNECT BY.

    Tables that participate in no foreign-key hierarchy are prepended
    (alphabetically) before the dependency-ordered tables.
    """
    oracle_driver = cast("OracleAsyncDriver", driver)

    result = await oracle_driver.execute(self._get_tables_sql())
    sorted_tables = [row["table_name"] for row in result.get_data()]

    all_result = await oracle_driver.execute("SELECT table_name FROM user_tables")
    all_tables = {row["table_name"] for row in all_result.get_data()}

    sorted_set = set(sorted_tables)
    disconnected = sorted(all_tables - sorted_set)

    return disconnected + sorted_tables

async def get_foreign_keys(
    self, driver: "AsyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None
) -> "list[ForeignKeyMetadata]":
    """Get foreign key metadata for the current Oracle user's schema."""
    oracle_driver = cast("OracleAsyncDriver", driver)
    result = await oracle_driver.execute(self._get_foreign_keys_sql(table))

    return [
        ForeignKeyMetadata(
            table_name=row["table_name"],
            column_name=row["column_name"],
            referenced_table=row["referenced_table_name"],
            referenced_column=row["referenced_column_name"],
            constraint_name=row["constraint_name"],
            schema=row["schema"],
            referenced_schema=row["referenced_schema"],
        )
        for row in result.get_data()
    ]

async def get_indexes(
    self, driver: "AsyncDriverAdapterBase", table: str, schema: "str | None" = None
) -> "list[dict[str, Any]]":
    """Get index information for a table from USER_INDEXES/USER_IND_COLUMNS."""
    oracle_driver = cast("OracleAsyncDriver", driver)
    result = await oracle_driver.execute(self._get_indexes_sql(table))

    return [
        {
            "name": row["name"],
            "columns": row["columns"].split(",") if row["columns"] else [],
            "unique": row["uniqueness"] == "UNIQUE",
            # USER_INDEXES does not distinguish PK-backing indexes here.
            "primary": False,
            "table_name": row["table_name"],
        }
        for row in result.get_data()
    ]

def _get_foreign_keys_sql(self, table: "str | None" = None) -> str:
    """Build the USER_CONSTRAINTS foreign-key query, optionally table-filtered.

    The table name is interpolated into SQL; embedded quotes are escaped.
    NOTE(review): bind parameters would be preferable if call sites allow.
    """
    where_clause = ""
    if table:
        safe_table = table.upper().replace("'", "''")
        where_clause = f"AND c.table_name = '{safe_table}'"
    return f"""
        SELECT
            c.table_name,
            cc.column_name,
            r.table_name AS referenced_table_name,
            rcc.column_name AS referenced_column_name,
            c.constraint_name,
            c.owner AS schema,
            r.owner AS referenced_schema
        FROM user_constraints c
        JOIN user_cons_columns cc ON c.constraint_name = cc.constraint_name
        JOIN user_constraints r ON c.r_constraint_name = r.constraint_name
        JOIN user_cons_columns rcc ON r.constraint_name = rcc.constraint_name
        WHERE c.constraint_type = 'R'
          AND cc.position = rcc.position
          {where_clause}
    """

def _get_indexes_sql(self, table: "str | None" = None) -> str:
    """Build the USER_INDEXES query with LISTAGG'd column lists."""
    where_clause = ""
    if table:
        safe_table = table.upper().replace("'", "''")
        where_clause = f"AND i.table_name = '{safe_table}'"
    return f"""
        SELECT
            i.index_name AS name,
            i.table_name,
            i.uniqueness,
            LISTAGG(ic.column_name, ',') WITHIN GROUP (ORDER BY ic.column_position) AS columns
        FROM user_indexes i
        JOIN user_ind_columns ic ON i.index_name = ic.index_name
        WHERE 1=1 {where_clause}
        GROUP BY i.index_name, i.table_name, i.uniqueness
    """

def _get_tables_sql(self) -> str:
    """Build the CONNECT BY NOCYCLE query ordering tables by FK depth."""
    return """
        SELECT table_name, MAX(LEVEL) as lvl
        FROM user_constraints
        START WITH table_name NOT IN (
            SELECT table_name FROM user_constraints WHERE constraint_type = 'R'
        )
        CONNECT BY NOCYCLE PRIOR constraint_name = r_constraint_name
        GROUP BY table_name
        ORDER BY lvl, table_name
    """

def get_tables(self, driver: "SyncDriverAdapterBase", schema: "str | None" = None) -> "list[str]":
    """Get tables sorted by topological dependency order using Oracle CONNECT BY.

    Tables that participate in no foreign-key hierarchy are prepended
    (alphabetically) before the dependency-ordered tables.
    """
    oracle_driver = cast("OracleSyncDriver", driver)

    result = oracle_driver.execute(self._get_tables_sql())
    # Iterate result.get_data() — not the result object itself — matching the
    # async twin and every other call site in this module.
    sorted_tables = [row["table_name"] for row in result.get_data()]

    all_result = oracle_driver.execute("SELECT table_name FROM user_tables")
    all_tables = {row["table_name"] for row in all_result.get_data()}

    sorted_set = set(sorted_tables)
    disconnected = sorted(all_tables - sorted_set)

    return disconnected + sorted_tables

def get_foreign_keys(
    self, driver: "SyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None
) -> "list[ForeignKeyMetadata]":
    """Get foreign key metadata for the current Oracle user's schema."""
    oracle_driver = cast("OracleSyncDriver", driver)
    result = oracle_driver.execute(self._get_foreign_keys_sql(table))

    return [
        ForeignKeyMetadata(
            table_name=row["table_name"],
            column_name=row["column_name"],
            referenced_table=row["referenced_table_name"],
            referenced_column=row["referenced_column_name"],
            constraint_name=row["constraint_name"],
            schema=row["schema"],
            referenced_schema=row["referenced_schema"],
        )
        for row in result.get_data()
    ]

def get_indexes(
    self, driver: "SyncDriverAdapterBase", table: str, schema: "str | None" = None
) -> "list[dict[str, Any]]":
    """Get index information for a table from USER_INDEXES/USER_IND_COLUMNS."""
    oracle_driver = cast("OracleSyncDriver", driver)
    result = oracle_driver.execute(self._get_indexes_sql(table))

    return [
        {
            "name": row["name"],
            "columns": row["columns"].split(",") if row["columns"] else [],
            "unique": row["uniqueness"] == "UNIQUE",
            # USER_INDEXES does not distinguish PK-backing indexes here.
            "primary": False,
            "table_name": row["table_name"],
        }
        for row in result.get_data()
    ]
def get_tables(self, driver: "SyncDriverAdapterBase", schema: "str | None" = None) -> "list[str]":
    """Get tables sorted by topological dependency order using a recursive CTE.

    Args:
        driver: Sync database driver instance.
        schema: Schema name; defaults to ``public``.

    Returns:
        Table names ordered so referenced tables precede referencing tables.
    """
    psycopg_driver = cast("PsycopgSyncDriver", driver)
    schema_name = schema or "public"

    sql = """
        WITH RECURSIVE dependency_tree AS (
            SELECT
                t.table_name::text,
                0 AS level,
                ARRAY[t.table_name::text] AS path
            FROM information_schema.tables t
            WHERE t.table_type = 'BASE TABLE'
              AND t.table_schema = %s
              AND NOT EXISTS (
                  SELECT 1
                  FROM information_schema.table_constraints tc
                  JOIN information_schema.key_column_usage kcu
                      ON tc.constraint_name = kcu.constraint_name
                      AND tc.table_schema = kcu.table_schema
                  WHERE tc.constraint_type = 'FOREIGN KEY'
                    AND tc.table_name = t.table_name
                    AND tc.table_schema = t.table_schema
              )

            UNION ALL

            SELECT
                tc.table_name::text,
                dt.level + 1,
                dt.path || tc.table_name::text
            FROM information_schema.table_constraints tc
            JOIN information_schema.key_column_usage kcu
                ON tc.constraint_name = kcu.constraint_name
                AND tc.table_schema = kcu.table_schema
            JOIN information_schema.constraint_column_usage ccu
                ON ccu.constraint_name = tc.constraint_name
                AND ccu.table_schema = tc.table_schema
            JOIN dependency_tree dt
                ON ccu.table_name = dt.table_name
            WHERE tc.constraint_type = 'FOREIGN KEY'
              AND tc.table_schema = %s
              AND ccu.table_schema = %s
              AND NOT (tc.table_name = ANY(dt.path))
        )
        -- GROUP BY dedupes tables reachable at several levels; MAX(level)
        -- keeps a table after ALL of its dependencies.
        SELECT table_name
        FROM dependency_tree
        GROUP BY table_name
        ORDER BY MAX(level), table_name;
    """
    result = psycopg_driver.execute(sql, (schema_name, schema_name, schema_name))
    # Guard against a falsy/None payload, matching get_columns' `result.data or []`.
    return [row["table_name"] for row in result.data or []]

def get_foreign_keys(
    self, driver: "SyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None
) -> "list[ForeignKeyMetadata]":
    """Get foreign key metadata for a schema, optionally filtered to one table.

    Args:
        driver: Sync database driver instance.
        table: Optional table name filter.
        schema: Schema name; defaults to ``public``.

    Returns:
        ForeignKeyMetadata entries for each matching constraint column pair.
    """
    psycopg_driver = cast("PsycopgSyncDriver", driver)
    schema_name = schema or "public"

    sql = """
        SELECT
            kcu.table_name,
            kcu.column_name,
            ccu.table_name AS referenced_table_name,
            ccu.column_name AS referenced_column_name,
            tc.constraint_name,
            tc.table_schema,
            ccu.table_schema AS referenced_table_schema
        FROM
            information_schema.table_constraints AS tc
            JOIN information_schema.key_column_usage AS kcu
                ON tc.constraint_name = kcu.constraint_name
                AND tc.table_schema = kcu.table_schema
            JOIN information_schema.constraint_column_usage AS ccu
                ON ccu.constraint_name = tc.constraint_name
                AND ccu.table_schema = tc.table_schema
        WHERE tc.constraint_type = 'FOREIGN KEY'
          AND (%s::text IS NULL OR tc.table_schema = %s)
          AND (%s::text IS NULL OR tc.table_name = %s)
    """

    result = psycopg_driver.execute(sql, (schema_name, schema_name, table, table))

    # Guard against a falsy/None payload, matching get_columns' `result.data or []`.
    return [
        ForeignKeyMetadata(
            table_name=row["table_name"],
            column_name=row["column_name"],
            referenced_table=row["referenced_table_name"],
            referenced_column=row["referenced_column_name"],
            constraint_name=row["constraint_name"],
            schema=row["table_schema"],
            referenced_schema=row["referenced_table_schema"],
        )
        for row in result.data or []
    ]

def get_indexes(
    self, driver: "SyncDriverAdapterBase", table: str, schema: "str | None" = None
) -> "list[dict[str, Any]]":
    """Get index information for a table from the pg_catalog tables.

    Args:
        driver: Sync database driver instance.
        table: Table name to inspect.
        schema: Schema name; defaults to ``public``.

    Returns:
        Dicts with ``name``, ``columns``, ``unique``, ``primary``, ``table_name``.
    """
    psycopg_driver = cast("PsycopgSyncDriver", driver)
    schema_name = schema or "public"

    sql = """
        SELECT
            i.relname as index_name,
            ix.indisunique as is_unique,
            ix.indisprimary as is_primary,
            array_agg(a.attname ORDER BY array_position(ix.indkey, a.attnum)) as columns
        FROM
            pg_class t,
            pg_class i,
            pg_index ix,
            pg_attribute a,
            pg_namespace n
        WHERE
            t.oid = ix.indrelid
            AND i.oid = ix.indexrelid
            AND a.attrelid = t.oid
            AND a.attnum = ANY(ix.indkey)
            AND t.relkind = 'r'
            AND t.relnamespace = n.oid
            AND n.nspname = %s
            AND t.relname = %s
        GROUP BY
            i.relname,
            ix.indisunique,
            ix.indisprimary
    """
    result = psycopg_driver.execute(sql, (schema_name, table))

    # Guard against a falsy/None payload, matching get_columns' `result.data or []`.
    return [
        {
            "name": row["index_name"],
            "columns": row["columns"],
            "unique": row["is_unique"],
            "primary": row["is_primary"],
            "table_name": table,
        }
        for row in result.data or []
    ]
async def get_tables(self, driver: "AsyncDriverAdapterBase", schema: "str | None" = None) -> "list[str]":
    """Get tables sorted by topological dependency order using a recursive CTE.

    Args:
        driver: Async database driver instance.
        schema: Schema name; defaults to ``public``.

    Returns:
        Table names ordered so referenced tables precede referencing tables.
    """
    psycopg_driver = cast("PsycopgAsyncDriver", driver)
    schema_name = schema or "public"

    sql = """
        WITH RECURSIVE dependency_tree AS (
            SELECT
                t.table_name::text,
                0 AS level,
                ARRAY[t.table_name::text] AS path
            FROM information_schema.tables t
            WHERE t.table_type = 'BASE TABLE'
              AND t.table_schema = %s
              AND NOT EXISTS (
                  SELECT 1
                  FROM information_schema.table_constraints tc
                  JOIN information_schema.key_column_usage kcu
                      ON tc.constraint_name = kcu.constraint_name
                      AND tc.table_schema = kcu.table_schema
                  WHERE tc.constraint_type = 'FOREIGN KEY'
                    AND tc.table_name = t.table_name
                    AND tc.table_schema = t.table_schema
              )

            UNION ALL

            SELECT
                tc.table_name::text,
                dt.level + 1,
                dt.path || tc.table_name::text
            FROM information_schema.table_constraints tc
            JOIN information_schema.key_column_usage kcu
                ON tc.constraint_name = kcu.constraint_name
                AND tc.table_schema = kcu.table_schema
            JOIN information_schema.constraint_column_usage ccu
                ON ccu.constraint_name = tc.constraint_name
                AND ccu.table_schema = tc.table_schema
            JOIN dependency_tree dt
                ON ccu.table_name = dt.table_name
            WHERE tc.constraint_type = 'FOREIGN KEY'
              AND tc.table_schema = %s
              AND ccu.table_schema = %s
              AND NOT (tc.table_name = ANY(dt.path))
        )
        -- GROUP BY dedupes tables reachable at several levels; MAX(level)
        -- keeps a table after ALL of its dependencies.
        SELECT table_name
        FROM dependency_tree
        GROUP BY table_name
        ORDER BY MAX(level), table_name;
    """
    result = await psycopg_driver.execute(sql, (schema_name, schema_name, schema_name))
    # Guard against a falsy/None payload, matching get_columns' `result.data or []`.
    return [row["table_name"] for row in result.data or []]

async def get_foreign_keys(
    self, driver: "AsyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None
) -> "list[ForeignKeyMetadata]":
    """Get foreign key metadata for a schema, optionally filtered to one table.

    Args:
        driver: Async database driver instance.
        table: Optional table name filter.
        schema: Schema name; defaults to ``public``.

    Returns:
        ForeignKeyMetadata entries for each matching constraint column pair.
    """
    psycopg_driver = cast("PsycopgAsyncDriver", driver)
    schema_name = schema or "public"

    sql = """
        SELECT
            kcu.table_name,
            kcu.column_name,
            ccu.table_name AS referenced_table_name,
            ccu.column_name AS referenced_column_name,
            tc.constraint_name,
            tc.table_schema,
            ccu.table_schema AS referenced_table_schema
        FROM
            information_schema.table_constraints AS tc
            JOIN information_schema.key_column_usage AS kcu
                ON tc.constraint_name = kcu.constraint_name
                AND tc.table_schema = kcu.table_schema
            JOIN information_schema.constraint_column_usage AS ccu
                ON ccu.constraint_name = tc.constraint_name
                AND ccu.table_schema = tc.table_schema
        WHERE tc.constraint_type = 'FOREIGN KEY'
          AND (%s::text IS NULL OR tc.table_schema = %s)
          AND (%s::text IS NULL OR tc.table_name = %s)
    """

    result = await psycopg_driver.execute(sql, (schema_name, schema_name, table, table))

    # Guard against a falsy/None payload, matching get_columns' `result.data or []`.
    return [
        ForeignKeyMetadata(
            table_name=row["table_name"],
            column_name=row["column_name"],
            referenced_table=row["referenced_table_name"],
            referenced_column=row["referenced_column_name"],
            constraint_name=row["constraint_name"],
            schema=row["table_schema"],
            referenced_schema=row["referenced_table_schema"],
        )
        for row in result.data or []
    ]

async def get_indexes(
    self, driver: "AsyncDriverAdapterBase", table: str, schema: "str | None" = None
) -> "list[dict[str, Any]]":
    """Get index information for a table from the pg_catalog tables.

    Args:
        driver: Async database driver instance.
        table: Table name to inspect.
        schema: Schema name; defaults to ``public``.

    Returns:
        Dicts with ``name``, ``columns``, ``unique``, ``primary``, ``table_name``.
    """
    psycopg_driver = cast("PsycopgAsyncDriver", driver)
    schema_name = schema or "public"

    sql = """
        SELECT
            i.relname as index_name,
            ix.indisunique as is_unique,
            ix.indisprimary as is_primary,
            array_agg(a.attname ORDER BY array_position(ix.indkey, a.attnum)) as columns
        FROM
            pg_class t,
            pg_class i,
            pg_index ix,
            pg_attribute a,
            pg_namespace n
        WHERE
            t.oid = ix.indrelid
            AND i.oid = ix.indexrelid
            AND a.attrelid = t.oid
            AND a.attnum = ANY(ix.indkey)
            AND t.relkind = 'r'
            AND t.relnamespace = n.oid
            AND n.nspname = %s
            AND t.relname = %s
        GROUP BY
            i.relname,
            ix.indisunique,
            ix.indisprimary
    """
    result = await psycopg_driver.execute(sql, (schema_name, table))

    # Guard against a falsy/None payload, matching get_columns' `result.data or []`.
    return [
        {
            "name": row["index_name"],
            "columns": row["columns"],
            "unique": row["is_unique"],
            "primary": row["is_primary"],
            "table_name": table,
        }
        for row in result.data or []
    ]

def get_tables(self, driver: "SyncDriverAdapterBase", schema: "str | None" = None) -> "list[str]":
    """Get tables sorted by topological dependency order using the SQLite catalog.

    Args:
        driver: Sync database driver instance.
        schema: Ignored — SQLite has no schemas in this sense.

    Returns:
        Table names ordered so referenced tables precede referencing tables.
    """
    sqlite_driver = cast("SqliteDriver", driver)

    sql = """
        WITH RECURSIVE dependency_tree AS (
            SELECT
                m.name as table_name,
                0 as level,
                '/' || m.name || '/' as path
            FROM sqlite_schema m
            WHERE m.type = 'table'
              AND m.name NOT LIKE 'sqlite_%'
              AND NOT EXISTS (
                  SELECT 1 FROM pragma_foreign_key_list(m.name)
              )

            UNION ALL

            SELECT
                m.name as table_name,
                dt.level + 1,
                dt.path || m.name || '/'
            FROM sqlite_schema m
            JOIN pragma_foreign_key_list(m.name) fk
            JOIN dependency_tree dt ON fk."table" = dt.table_name
            WHERE m.type = 'table'
              AND m.name NOT LIKE 'sqlite_%'
              AND instr(dt.path, '/' || m.name || '/') = 0
        )
        -- SQLite rejects DISTINCT with an ORDER BY term (level) outside the
        -- result set; GROUP BY + MAX(level) also dedupes tables reachable at
        -- several levels while keeping them after all their dependencies.
        SELECT table_name
        FROM dependency_tree
        GROUP BY table_name
        ORDER BY MAX(level), table_name;
    """
    result = sqlite_driver.execute(sql)
    return [row["table_name"] for row in result.get_data()]
async def get_foreign_keys(
    self, driver: "AsyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None
) -> "list[ForeignKeyMetadata]":
    """Get foreign key metadata.

    Base-class default: returns no foreign keys. Adapters that can enumerate
    foreign keys override this.

    Args:
        driver: Async database driver instance
        table: Optional table name filter
        schema: Optional schema name filter

    Returns:
        List of foreign key metadata
    """
    _ = driver, table, schema
    return []

def get_foreign_keys(
    self, driver: "SyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None
) -> "list[ForeignKeyMetadata]":
    """Get foreign key metadata via pragma_foreign_key_list.

    SQLite foreign keys have no constraint names or schemas, so those fields
    are None.

    Args:
        driver: Sync database driver instance.
        table: Optional table name filter; when omitted all tables are scanned.
        schema: Ignored — SQLite has no schemas in this sense.

    Returns:
        ForeignKeyMetadata entries for each foreign-key column pair.
    """
    sqlite_driver = cast("SqliteDriver", driver)

    if table:
        # table is interpolated into SQL (PRAGMA table-valued functions);
        # escape embedded quotes so the value cannot break out of the literal.
        safe_table = table.replace("'", "''")
        sql = f"SELECT '{safe_table}' as table_name, fk.* FROM pragma_foreign_key_list('{safe_table}') fk"
        result = sqlite_driver.execute(sql)
    else:
        sql = """
            SELECT m.name as table_name, fk.*
            FROM sqlite_schema m, pragma_foreign_key_list(m.name) fk
            WHERE m.type = 'table' AND m.name NOT LIKE 'sqlite_%'
        """
        result = sqlite_driver.execute(sql)

    fks = []
    # Rows may arrive as tuples or dicts depending on the driver's row factory.
    for row in result.data or []:
        if isinstance(row, (list, tuple)):
            # Positional layout: table_name, id, seq, table, from, to, ...
            table_name = row[0]
            ref_table = row[3]
            col = row[4]
            ref_col = row[5]
        else:
            table_name = row["table_name"]
            ref_table = row["table"]
            col = row["from"]
            ref_col = row["to"]

        fks.append(
            ForeignKeyMetadata(
                table_name=table_name,
                column_name=col,
                referenced_table=ref_table,
                referenced_column=ref_col,
                constraint_name=None,
                schema=None,
                referenced_schema=None,
            )
        )
    return fks

def get_indexes(
    self, driver: "SyncDriverAdapterBase", table: str, schema: "str | None" = None
) -> "list[dict[str, Any]]":
    """Get index information for a table via PRAGMA index_list/index_info.

    Args:
        driver: Sync database driver instance.
        table: Table name to inspect.
        schema: Ignored — SQLite has no schemas in this sense.

    Returns:
        Dicts with ``name``, ``columns``, ``unique``, ``primary``, ``table_name``.
    """
    sqlite_driver = cast("SqliteDriver", driver)

    # Escape embedded quotes in interpolated identifiers (PRAGMA statements
    # cannot take bind parameters).
    safe_table = table.replace("'", "''")
    index_list_res = sqlite_driver.execute(f"PRAGMA index_list('{safe_table}')")
    indexes = []

    for idx_row in index_list_res.data or []:
        # Rows may arrive as tuples or dicts depending on the row factory.
        if isinstance(idx_row, (list, tuple)):
            idx_name = idx_row[1]
            unique = bool(idx_row[2])
        else:
            idx_name = idx_row["name"]
            unique = bool(idx_row["unique"])

        safe_idx = idx_name.replace("'", "''")
        info_res = sqlite_driver.execute(f"PRAGMA index_info('{safe_idx}')")
        cols = []
        for col_row in info_res.data or []:
            if isinstance(col_row, (list, tuple)):
                cols.append(col_row[2])
            else:
                cols.append(col_row["name"])

        # PRAGMA index_list does not flag PK-backing indexes here.
        indexes.append({"name": idx_name, "columns": cols, "unique": unique, "primary": False, "table_name": table})

    return indexes
+ + Args: + driver: Async database driver instance + table: Optional table name filter + schema: Optional schema name filter + + Returns: + List of foreign key metadata + """ + _ = driver, table, schema + return [] + def list_available_features(self) -> "list[str]": """List all features that can be checked via get_feature_flag. diff --git a/sqlspec/driver/_common.py b/sqlspec/driver/_common.py index 7d686d0a3..661102dea 100644 --- a/sqlspec/driver/_common.py +++ b/sqlspec/driver/_common.py @@ -1,5 +1,6 @@ """Common driver attributes and utilities.""" +import graphlib import hashlib import logging import re @@ -42,9 +43,12 @@ "EXEC_CURSOR_RESULT", "EXEC_ROWCOUNT_OVERRIDE", "EXEC_SPECIAL_DATA", + "ColumnMetadata", "CommonDriverAttributesMixin", "DataDictionaryMixin", "ExecutionResult", + "ForeignKeyMetadata", + "IndexMetadata", "ScriptExecutionResult", "StackExecutionObserver", "VersionInfo", @@ -62,6 +66,163 @@ VERSION_GROUPS_MIN_FOR_PATCH = 2 +class ForeignKeyMetadata: + """Metadata for a foreign key constraint.""" + + __slots__ = ( + "column_name", + "constraint_name", + "referenced_column", + "referenced_schema", + "referenced_table", + "schema", + "table_name", + ) + + def __init__( + self, + table_name: str, + column_name: str, + referenced_table: str, + referenced_column: str, + constraint_name: str | None = None, + schema: str | None = None, + referenced_schema: str | None = None, + ) -> None: + self.table_name = table_name + self.column_name = column_name + self.referenced_table = referenced_table + self.referenced_column = referenced_column + self.constraint_name = constraint_name + self.schema = schema + self.referenced_schema = referenced_schema + + def __repr__(self) -> str: + return ( + f"ForeignKeyMetadata(table_name={self.table_name!r}, column_name={self.column_name!r}, " + f"referenced_table={self.referenced_table!r}, referenced_column={self.referenced_column!r}, " + f"constraint_name={self.constraint_name!r}, schema={self.schema!r}, 
referenced_schema={self.referenced_schema!r})" + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ForeignKeyMetadata): + return NotImplemented + return ( + self.table_name == other.table_name + and self.column_name == other.column_name + and self.referenced_table == other.referenced_table + and self.referenced_column == other.referenced_column + and self.constraint_name == other.constraint_name + and self.schema == other.schema + and self.referenced_schema == other.referenced_schema + ) + + def __hash__(self) -> int: + return hash(( + self.table_name, + self.column_name, + self.referenced_table, + self.referenced_column, + self.constraint_name, + self.schema, + self.referenced_schema, + )) + + +class ColumnMetadata: + """Metadata for a database column.""" + + __slots__ = ("data_type", "default_value", "max_length", "name", "nullable", "precision", "primary_key", "scale") + + def __init__( + self, + name: str, + data_type: str, + nullable: bool, + default_value: str | None = None, + primary_key: bool = False, + max_length: int | None = None, + precision: int | None = None, + scale: int | None = None, + ) -> None: + self.name = name + self.data_type = data_type + self.nullable = nullable + self.default_value = default_value + self.primary_key = primary_key + self.max_length = max_length + self.precision = precision + self.scale = scale + + def __repr__(self) -> str: + return ( + f"ColumnMetadata(name={self.name!r}, data_type={self.data_type!r}, nullable={self.nullable!r}, " + f"default_value={self.default_value!r}, primary_key={self.primary_key!r}, max_length={self.max_length!r}, " + f"precision={self.precision!r}, scale={self.scale!r})" + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, ColumnMetadata): + return NotImplemented + return ( + self.name == other.name + and self.data_type == other.data_type + and self.nullable == other.nullable + and self.default_value == other.default_value + and self.primary_key == 
other.primary_key + and self.max_length == other.max_length + and self.precision == other.precision + and self.scale == other.scale + ) + + def __hash__(self) -> int: + return hash(( + self.name, + self.data_type, + self.nullable, + self.default_value, + self.primary_key, + self.max_length, + self.precision, + self.scale, + )) + + +class IndexMetadata: + """Metadata for a database index.""" + + __slots__ = ("columns", "name", "primary", "table_name", "unique") + + def __init__( + self, name: str, table_name: str, columns: list[str], unique: bool = False, primary: bool = False + ) -> None: + self.name = name + self.table_name = table_name + self.columns = columns + self.unique = unique + self.primary = primary + + def __repr__(self) -> str: + return ( + f"IndexMetadata(name={self.name!r}, table_name={self.table_name!r}, columns={self.columns!r}, " + f"unique={self.unique!r}, primary={self.primary!r})" + ) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, IndexMetadata): + return NotImplemented + return ( + self.name == other.name + and self.table_name == other.table_name + and self.columns == other.columns + and self.unique == other.unique + and self.primary == other.primary + ) + + def __hash__(self) -> int: + return hash((self.name, self.table_name, tuple(self.columns), self.unique, self.primary)) + + def make_cache_key_hashable(obj: Any) -> Any: """Recursively convert unhashable types to hashable ones for cache keys. @@ -374,6 +535,32 @@ def get_default_features(self) -> "list[str]": """ return ["supports_transactions", "supports_prepared_statements"] + def sort_tables_topologically(self, tables: "list[str]", foreign_keys: "list[ForeignKeyMetadata]") -> "list[str]": + """Sort tables topologically based on foreign key dependencies using Python. + + Args: + tables: List of table names. + foreign_keys: List of foreign key metadata. + + Returns: + List of table names in topological order (dependencies first). 
+ + Raises: + CycleError: If a dependency cycle is detected. + """ + sorter: graphlib.TopologicalSorter[str] = graphlib.TopologicalSorter() + for table in tables: + sorter.add(table) + + for fk in foreign_keys: + # If self-referencing, ignore for sorting purposes to avoid simple cycles + if fk.table_name == fk.referenced_table: + continue + # table_name depends on referenced_table + sorter.add(fk.table_name, fk.referenced_table) + + return list(sorter.static_order()) + class ScriptExecutionResult(NamedTuple): """Result from script execution with statement count information.""" diff --git a/sqlspec/driver/_sync.py b/sqlspec/driver/_sync.py index d748cbcbb..146b7b864 100644 --- a/sqlspec/driver/_sync.py +++ b/sqlspec/driver/_sync.py @@ -27,6 +27,7 @@ from sqlspec.builder import QueryBuilder from sqlspec.core import ArrowResult, SQLResult, Statement, StatementConfig, StatementFilter + from sqlspec.driver._common import ForeignKeyMetadata from sqlspec.typing import ArrowReturnFormat, SchemaT, StatementParameters _LOGGER_NAME: Final[str] = "sqlspec" @@ -747,6 +748,22 @@ def get_indexes( _ = driver, table, schema return [] + def get_foreign_keys( + self, driver: "SyncDriverAdapterBase", table: "str | None" = None, schema: "str | None" = None + ) -> "list[ForeignKeyMetadata]": + """Get foreign key metadata. + + Args: + driver: Sync database driver instance + table: Optional table name filter + schema: Optional schema name filter + + Returns: + List of foreign key metadata + """ + _ = driver, table, schema + return [] + def list_available_features(self) -> "list[str]": """List all features that can be checked via get_feature_flag. 
diff --git a/tests/integration/test_adapters/test_aiosqlite/test_data_dictionary.py b/tests/integration/test_adapters/test_aiosqlite/test_data_dictionary.py new file mode 100644 index 000000000..ed0557344 --- /dev/null +++ b/tests/integration/test_adapters/test_aiosqlite/test_data_dictionary.py @@ -0,0 +1,73 @@ +"""Integration tests for AioSQLite data dictionary.""" + +from typing import TYPE_CHECKING + +import pytest + +if TYPE_CHECKING: + from sqlspec.adapters.aiosqlite.driver import AiosqliteDriver + +pytestmark = pytest.mark.xdist_group("aiosqlite") + + +@pytest.mark.aiosqlite +async def test_aiosqlite_data_dictionary_topology_and_fks(aiosqlite_session: "AiosqliteDriver") -> None: + """Test topological sort and FK metadata.""" + aiosqlite_driver = aiosqlite_session + import uuid + + unique_suffix = uuid.uuid4().hex[:8] + users_table = f"dd_users_{unique_suffix}" + orders_table = f"dd_orders_{unique_suffix}" + items_table = f"dd_items_{unique_suffix}" + + await aiosqlite_driver.execute_script(f""" + CREATE TABLE {users_table} ( + id INTEGER PRIMARY KEY, + name VARCHAR(50) + ); + CREATE TABLE {orders_table} ( + id INTEGER PRIMARY KEY, + user_id INTEGER REFERENCES {users_table}(id), + amount INTEGER + ); + CREATE TABLE {items_table} ( + id INTEGER PRIMARY KEY, + order_id INTEGER REFERENCES {orders_table}(id), + name VARCHAR(50) + ); + """) + + try: + # Test 1: Topological Sort + sorted_tables = await aiosqlite_driver.data_dictionary.get_tables(aiosqlite_driver) + + test_tables = [t for t in sorted_tables if t in (users_table, orders_table, items_table)] + assert len(test_tables) == 3 + + idx_users = test_tables.index(users_table) + idx_orders = test_tables.index(orders_table) + idx_items = test_tables.index(items_table) + + assert idx_users < idx_orders + assert idx_orders < idx_items + + # Test 2: Foreign Keys + fks = await aiosqlite_driver.data_dictionary.get_foreign_keys(aiosqlite_driver, table=orders_table) + assert len(fks) >= 1 + my_fk = next((fk for fk in 
fks if fk.referenced_table == users_table), None) + assert my_fk is not None + assert my_fk.column_name == "user_id" + + # Test 3: Indexes + await aiosqlite_driver.execute(f"CREATE INDEX idx_{unique_suffix} ON {users_table}(name)") + indexes = await aiosqlite_driver.data_dictionary.get_indexes(aiosqlite_driver, table=users_table) + assert len(indexes) >= 1 + assert any(idx["name"] == f"idx_{unique_suffix}" for idx in indexes) + + finally: + await aiosqlite_driver.execute_script(f""" + DROP TABLE IF EXISTS {items_table}; + DROP TABLE IF EXISTS {orders_table}; + DROP TABLE IF EXISTS {users_table}; + """) diff --git a/tests/integration/test_adapters/test_asyncpg/test_data_dictionary.py b/tests/integration/test_adapters/test_asyncpg/test_data_dictionary.py index 68126324f..547866702 100644 --- a/tests/integration/test_adapters/test_asyncpg/test_data_dictionary.py +++ b/tests/integration/test_adapters/test_asyncpg/test_data_dictionary.py @@ -105,3 +105,63 @@ async def test_asyncpg_data_dictionary_available_features(asyncpg_async_driver: for feature in expected_features: assert feature in features + + +@pytest.mark.asyncpg +async def test_asyncpg_data_dictionary_topology_and_fks(asyncpg_async_driver: "AsyncpgDriver") -> None: + """Test topological sort and FK metadata.""" + import uuid + + unique_suffix = uuid.uuid4().hex[:8] + users_table = f"dd_users_{unique_suffix}" + orders_table = f"dd_orders_{unique_suffix}" + items_table = f"dd_items_{unique_suffix}" + + await asyncpg_async_driver.execute_script(f""" + CREATE TABLE {users_table} ( + id SERIAL PRIMARY KEY, + name VARCHAR(50) + ); + CREATE TABLE {orders_table} ( + id SERIAL PRIMARY KEY, + user_id INTEGER REFERENCES {users_table}(id), + amount INTEGER + ); + CREATE TABLE {items_table} ( + id SERIAL PRIMARY KEY, + order_id INTEGER REFERENCES {orders_table}(id), + name VARCHAR(50) + ); + """) + + try: + # Test 1: Topological Sort + sorted_tables = await 
asyncpg_async_driver.data_dictionary.get_tables(asyncpg_async_driver) + + test_tables = [t for t in sorted_tables if t in (users_table, orders_table, items_table)] + assert len(test_tables) == 3 + + idx_users = test_tables.index(users_table) + idx_orders = test_tables.index(orders_table) + idx_items = test_tables.index(items_table) + + assert idx_users < idx_orders + assert idx_orders < idx_items + + # Test 2: Foreign Keys + fks = await asyncpg_async_driver.data_dictionary.get_foreign_keys(asyncpg_async_driver, table=orders_table) + assert len(fks) >= 1 + my_fk = next((fk for fk in fks if fk.referenced_table == users_table), None) + assert my_fk is not None + assert my_fk.column_name == "user_id" + + # Test 3: Indexes + indexes = await asyncpg_async_driver.data_dictionary.get_indexes(asyncpg_async_driver, table=users_table) + assert len(indexes) >= 1 # PK index + + finally: + await asyncpg_async_driver.execute_script(f""" + DROP TABLE IF EXISTS {items_table} CASCADE; + DROP TABLE IF EXISTS {orders_table} CASCADE; + DROP TABLE IF EXISTS {users_table} CASCADE; + """)