From 5afbb61a8feda8c9c754290c8415a2f928236c58 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Mon, 18 Aug 2025 16:04:50 +0000 Subject: [PATCH] fix: updated migration adapter discovery --- sqlspec/adapters/oracledb/config.py | 3 + sqlspec/config.py | 8 ++- sqlspec/migrations/adapter_discovery.py | 93 ------------------------- sqlspec/migrations/commands.py | 9 +-- sqlspec/migrations/runner.py | 8 +-- sqlspec/migrations/tracker.py | 10 --- 6 files changed, 14 insertions(+), 117 deletions(-) delete mode 100644 sqlspec/migrations/adapter_discovery.py diff --git a/sqlspec/adapters/oracledb/config.py b/sqlspec/adapters/oracledb/config.py index b27d589b2..248921743 100644 --- a/sqlspec/adapters/oracledb/config.py +++ b/sqlspec/adapters/oracledb/config.py @@ -16,6 +16,7 @@ OracleSyncDriver, oracledb_statement_config, ) +from sqlspec.adapters.oracledb.migrations import OracleAsyncMigrationTracker, OracleSyncMigrationTracker from sqlspec.config import AsyncDatabaseConfig, SyncDatabaseConfig if TYPE_CHECKING: @@ -77,6 +78,7 @@ class OracleSyncConfig(SyncDatabaseConfig[OracleSyncConnection, "ConnectionPool" driver_type: ClassVar[type[OracleSyncDriver]] = OracleSyncDriver connection_type: "ClassVar[type[OracleSyncConnection]]" = OracleSyncConnection + migration_tracker_type: "ClassVar[type[OracleSyncMigrationTracker]]" = OracleSyncMigrationTracker def __init__( self, @@ -199,6 +201,7 @@ class OracleAsyncConfig(AsyncDatabaseConfig[OracleAsyncConnection, "AsyncConnect connection_type: "ClassVar[type[OracleAsyncConnection]]" = OracleAsyncConnection driver_type: ClassVar[type[OracleAsyncDriver]] = OracleAsyncDriver + migration_tracker_type: "ClassVar[type[OracleAsyncMigrationTracker]]" = OracleAsyncMigrationTracker def __init__( self, diff --git a/sqlspec/config.py b/sqlspec/config.py index 99bd3675f..74554f9a8 100644 --- a/sqlspec/config.py +++ b/sqlspec/config.py @@ -5,6 +5,7 @@ from sqlspec.core.parameters import ParameterStyle, ParameterStyleConfig from sqlspec.core.statement import StatementConfig +from sqlspec.migrations.tracker import AsyncMigrationTracker, SyncMigrationTracker from sqlspec.utils.logging import get_logger if TYPE_CHECKING: @@ -160,6 +161,7 @@ class NoPoolSyncConfig(DatabaseConfigProtocol[ConnectionT, None, DriverT]): __slots__ = ("connection_config",) is_async: "ClassVar[bool]" = False supports_connection_pooling: "ClassVar[bool]" = False + migration_tracker_type: "ClassVar[type[Any]]" = SyncMigrationTracker def __init__( self, @@ -210,9 +212,9 @@ class NoPoolAsyncConfig(DatabaseConfigProtocol[ConnectionT, None, DriverT]): """Base class for an async database configurations that do not implement a pool.""" __slots__ = ("connection_config",) - is_async: "ClassVar[bool]" = True supports_connection_pooling: "ClassVar[bool]" = False + migration_tracker_type: "ClassVar[type[Any]]" = AsyncMigrationTracker def __init__( self, @@ -263,9 +265,9 @@ class SyncDatabaseConfig(DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]): """Generic Sync Database Configuration.""" __slots__ = ("pool_config",) - is_async: "ClassVar[bool]" = False supports_connection_pooling: "ClassVar[bool]" = True + migration_tracker_type: "ClassVar[type[Any]]" = SyncMigrationTracker def __init__( self, @@ -339,9 +341,9 @@ class AsyncDatabaseConfig(DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]): """Generic Async Database Configuration.""" __slots__ = ("pool_config",) - is_async: "ClassVar[bool]" = True supports_connection_pooling: "ClassVar[bool]" = True + migration_tracker_type: "ClassVar[type[Any]]" = AsyncMigrationTracker def __init__( self, diff --git a/sqlspec/migrations/adapter_discovery.py b/sqlspec/migrations/adapter_discovery.py deleted file mode 100644 index 42012deac..000000000 --- a/sqlspec/migrations/adapter_discovery.py +++ /dev/null @@ -1,93 +0,0 @@ -"""Adapter-specific migration discovery and loading. - -This module provides functionality to discover and load adapter-specific -migration implementations when available. -""" - -import importlib -from typing import TYPE_CHECKING, Any, cast - -from sqlspec.migrations.tracker import AsyncMigrationTracker, SyncMigrationTracker -from sqlspec.utils.logging import get_logger - -if TYPE_CHECKING: - from sqlspec.migrations.base import BaseMigrationTracker - -logger = get_logger("migrations.adapter_discovery") - -__all__ = ("discover_migration_tracker",) - - -def discover_migration_tracker(config: Any, sync: bool = True) -> "type[BaseMigrationTracker[Any]]": - """Discover and return adapter-specific migration tracker if available. - - Args: - config: The SQLSpec configuration object. - sync: Whether to discover sync (True) or async (False) tracker. - - Returns: - Adapter-specific tracker class or default tracker class. - """ - # Extract adapter name from config class - config_class_name = type(config).__name__ - - # Map config class names to adapter module names - adapter_mapping = { - "SqliteConfig": "sqlite", - "DuckDBConfig": "duckdb", - "PsycopgSyncConfig": "psycopg", - "PsycopgAsyncConfig": "psycopg", - "AsyncpgConfig": "asyncpg", - "PsqlpyConfig": "psqlpy", - "AsyncmyConfig": "asyncmy", - "AiosqliteConfig": "aiosqlite", - "OracleSyncConfig": "oracledb", - "OracleAsyncConfig": "oracledb", - "ADBCConfig": "adbc", - "BigQueryConfig": "bigquery", - } - - adapter_name = adapter_mapping.get(config_class_name) - - if not adapter_name: - logger.debug("No adapter mapping found for config %s, using default tracker", config_class_name) - return SyncMigrationTracker if sync else AsyncMigrationTracker - - # Try to import adapter-specific migrations module - try: - module_path = f"sqlspec.adapters.{adapter_name}.migrations" - migrations_module = importlib.import_module(module_path) - - # Look for adapter-specific tracker classes - if sync: - tracker_class_names = [ - "OracleSyncMigrationTracker" - if adapter_name == "oracledb" - else f"{adapter_name.title()}SyncMigrationTracker", - f"{adapter_name.upper()}SyncMigrationTracker", - "SyncMigrationTracker", - ] - else: - tracker_class_names = [ - "OracleAsyncMigrationTracker" - if adapter_name == "oracledb" - else f"{adapter_name.title()}AsyncMigrationTracker", - f"{adapter_name.upper()}AsyncMigrationTracker", - "AsyncMigrationTracker", - ] - - for class_name in tracker_class_names: - if hasattr(migrations_module, class_name): - tracker_class = getattr(migrations_module, class_name) - logger.debug("Using adapter-specific tracker: %s.%s", module_path, class_name) - return cast("type[BaseMigrationTracker[Any]]", tracker_class) - - logger.debug("No suitable tracker class found in %s, using default", module_path) - - except ImportError: - logger.debug("No adapter-specific migrations module found for %s, using default tracker", adapter_name) - except Exception as e: - logger.warning("Error loading adapter-specific migrations for %s: %s", adapter_name, e) - - # Fall back to default tracker - return SyncMigrationTracker if sync else AsyncMigrationTracker diff --git a/sqlspec/migrations/commands.py b/sqlspec/migrations/commands.py index 02b294924..7263971d6 100644 --- a/sqlspec/migrations/commands.py +++ b/sqlspec/migrations/commands.py @@ -9,7 +9,6 @@ from rich.table import Table from sqlspec._sql import sql -from sqlspec.migrations.adapter_discovery import discover_migration_tracker from sqlspec.migrations.base import BaseMigrationCommands from sqlspec.migrations.runner import AsyncMigrationRunner, SyncMigrationRunner from sqlspec.migrations.utils import create_migration_file @@ -35,8 +34,7 @@ def __init__(self, config: "SyncConfigT") -> None: config: The SQLSpec configuration. """ super().__init__(config) - tracker_class = discover_migration_tracker(config, sync=True) - self.tracker = tracker_class(self.version_table) + self.tracker = config.migration_tracker_type(self.version_table) self.runner = SyncMigrationRunner(self.migrations_path) def init(self, directory: str, package: bool = True) -> None: @@ -144,7 +142,6 @@ def downgrade(self, revision: str = "-1") -> None: if revision == "-1": to_revert = [applied[-1]] elif revision == "base": - # Revert all migrations to get back to base state to_revert = list(reversed(applied)) else: for migration in reversed(applied): @@ -213,8 +210,7 @@ def __init__(self, sqlspec_config: "AsyncConfigT") -> None: sqlspec_config: The SQLSpec configuration. """ super().__init__(sqlspec_config) - tracker_class = discover_migration_tracker(sqlspec_config, sync=False) - self.tracker = tracker_class(self.version_table) + self.tracker = sqlspec_config.migration_tracker_type(self.version_table) self.runner = AsyncMigrationRunner(self.migrations_path) async def init(self, directory: str, package: bool = True) -> None: @@ -313,7 +309,6 @@ async def downgrade(self, revision: str = "-1") -> None: if revision == "-1": to_revert = [applied[-1]] elif revision == "base": - # Revert all migrations to get back to base state to_revert = list(reversed(applied)) else: for migration in reversed(applied): diff --git a/sqlspec/migrations/runner.py b/sqlspec/migrations/runner.py index 05c3594f7..c7e050782 100644 --- a/sqlspec/migrations/runner.py +++ b/sqlspec/migrations/runner.py @@ -60,7 +60,7 @@ def execute_upgrade( return None, 0 start_time = time.time() - # Execute each SQL statement separately + for sql_statement in upgrade_sql_list: if sql_statement.strip(): driver.execute_script(sql_statement) @@ -84,7 +84,7 @@ def execute_downgrade( return None, 0 start_time = time.time() - # Execute each SQL statement separately + for sql_statement in downgrade_sql_list: if sql_statement.strip(): driver.execute_script(sql_statement) @@ -235,7 +235,7 @@ async def execute_upgrade( return None, 0 start_time = time.time() - # Execute each SQL statement separately + for sql_statement in upgrade_sql_list: if sql_statement.strip(): await driver.execute_script(sql_statement) @@ -259,7 +259,7 @@ async def execute_downgrade( return None, 0 start_time = time.time() - # Execute each SQL statement separately + for sql_statement in downgrade_sql_list: if sql_statement.strip(): await driver.execute_script(sql_statement) diff --git a/sqlspec/migrations/tracker.py b/sqlspec/migrations/tracker.py index 68b9fe101..00b587156 100644 --- a/sqlspec/migrations/tracker.py +++ b/sqlspec/migrations/tracker.py @@ -89,21 +89,16 @@ def _safe_commit(self, driver: "SyncDriverAdapterBase") -> None: driver: The database driver to use. """ try: - # Check if the connection has autocommit enabled connection = getattr(driver, "connection", None) if connection and hasattr(connection, "autocommit") and getattr(connection, "autocommit", False): return - # For ADBC and other drivers, check the driver_features driver_features = getattr(driver, "driver_features", {}) if driver_features and driver_features.get("autocommit", False): return - # Safe to commit manually driver.commit() except Exception: - # If commit fails due to no active transaction, that's acceptable - # Some drivers with autocommit will fail when trying to commit logger.debug("Failed to commit transaction, likely due to autocommit being enabled") @@ -179,19 +174,14 @@ async def _safe_commit_async(self, driver: "AsyncDriverAdapterBase") -> None: driver: The database driver to use. """ try: - # Check if the connection has autocommit enabled connection = getattr(driver, "connection", None) if connection and hasattr(connection, "autocommit") and getattr(connection, "autocommit", False): return - # For ADBC and other drivers, check the driver_features driver_features = getattr(driver, "driver_features", {}) if driver_features and driver_features.get("autocommit", False): return - # Safe to commit manually await driver.commit() except Exception: - # If commit fails due to no active transaction, that's acceptable - # Some drivers with autocommit will fail when trying to commit logger.debug("Failed to commit transaction, likely due to autocommit being enabled")