Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions sqlspec/adapters/oracledb/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
OracleSyncDriver,
oracledb_statement_config,
)
from sqlspec.adapters.oracledb.migrations import OracleAsyncMigrationTracker, OracleSyncMigrationTracker
from sqlspec.config import AsyncDatabaseConfig, SyncDatabaseConfig

if TYPE_CHECKING:
Expand Down Expand Up @@ -77,6 +78,7 @@ class OracleSyncConfig(SyncDatabaseConfig[OracleSyncConnection, "ConnectionPool"

driver_type: ClassVar[type[OracleSyncDriver]] = OracleSyncDriver
connection_type: "ClassVar[type[OracleSyncConnection]]" = OracleSyncConnection
migration_tracker_type: "ClassVar[type[OracleSyncMigrationTracker]]" = OracleSyncMigrationTracker

def __init__(
self,
Expand Down Expand Up @@ -199,6 +201,7 @@ class OracleAsyncConfig(AsyncDatabaseConfig[OracleAsyncConnection, "AsyncConnect

connection_type: "ClassVar[type[OracleAsyncConnection]]" = OracleAsyncConnection
driver_type: ClassVar[type[OracleAsyncDriver]] = OracleAsyncDriver
migration_tracker_type: "ClassVar[type[OracleAsyncMigrationTracker]]" = OracleAsyncMigrationTracker

def __init__(
self,
Expand Down
8 changes: 5 additions & 3 deletions sqlspec/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from sqlspec.core.parameters import ParameterStyle, ParameterStyleConfig
from sqlspec.core.statement import StatementConfig
from sqlspec.migrations.tracker import AsyncMigrationTracker, SyncMigrationTracker
from sqlspec.utils.logging import get_logger

if TYPE_CHECKING:
Expand Down Expand Up @@ -160,6 +161,7 @@ class NoPoolSyncConfig(DatabaseConfigProtocol[ConnectionT, None, DriverT]):
__slots__ = ("connection_config",)
is_async: "ClassVar[bool]" = False
supports_connection_pooling: "ClassVar[bool]" = False
migration_tracker_type: "ClassVar[type[Any]]" = SyncMigrationTracker

def __init__(
self,
Expand Down Expand Up @@ -210,9 +212,9 @@ class NoPoolAsyncConfig(DatabaseConfigProtocol[ConnectionT, None, DriverT]):
"""Base class for an async database configurations that do not implement a pool."""

__slots__ = ("connection_config",)

is_async: "ClassVar[bool]" = True
supports_connection_pooling: "ClassVar[bool]" = False
migration_tracker_type: "ClassVar[type[Any]]" = AsyncMigrationTracker

def __init__(
self,
Expand Down Expand Up @@ -263,9 +265,9 @@ class SyncDatabaseConfig(DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]):
"""Generic Sync Database Configuration."""

__slots__ = ("pool_config",)

is_async: "ClassVar[bool]" = False
supports_connection_pooling: "ClassVar[bool]" = True
migration_tracker_type: "ClassVar[type[Any]]" = SyncMigrationTracker

def __init__(
self,
Expand Down Expand Up @@ -339,9 +341,9 @@ class AsyncDatabaseConfig(DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]):
"""Generic Async Database Configuration."""

__slots__ = ("pool_config",)

is_async: "ClassVar[bool]" = True
supports_connection_pooling: "ClassVar[bool]" = True
migration_tracker_type: "ClassVar[type[Any]]" = AsyncMigrationTracker

def __init__(
self,
Expand Down
93 changes: 0 additions & 93 deletions sqlspec/migrations/adapter_discovery.py

This file was deleted.

9 changes: 2 additions & 7 deletions sqlspec/migrations/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from rich.table import Table

from sqlspec._sql import sql
from sqlspec.migrations.adapter_discovery import discover_migration_tracker
from sqlspec.migrations.base import BaseMigrationCommands
from sqlspec.migrations.runner import AsyncMigrationRunner, SyncMigrationRunner
from sqlspec.migrations.utils import create_migration_file
Expand All @@ -35,8 +34,7 @@ def __init__(self, config: "SyncConfigT") -> None:
config: The SQLSpec configuration.
"""
super().__init__(config)
tracker_class = discover_migration_tracker(config, sync=True)
self.tracker = tracker_class(self.version_table)
self.tracker = config.migration_tracker_type(self.version_table)
self.runner = SyncMigrationRunner(self.migrations_path)

def init(self, directory: str, package: bool = True) -> None:
Expand Down Expand Up @@ -144,7 +142,6 @@ def downgrade(self, revision: str = "-1") -> None:
if revision == "-1":
to_revert = [applied[-1]]
elif revision == "base":
# Revert all migrations to get back to base state
to_revert = list(reversed(applied))
else:
for migration in reversed(applied):
Expand Down Expand Up @@ -213,8 +210,7 @@ def __init__(self, sqlspec_config: "AsyncConfigT") -> None:
sqlspec_config: The SQLSpec configuration.
"""
super().__init__(sqlspec_config)
tracker_class = discover_migration_tracker(sqlspec_config, sync=False)
self.tracker = tracker_class(self.version_table)
self.tracker = sqlspec_config.migration_tracker_type(self.version_table)
self.runner = AsyncMigrationRunner(self.migrations_path)

async def init(self, directory: str, package: bool = True) -> None:
Expand Down Expand Up @@ -313,7 +309,6 @@ async def downgrade(self, revision: str = "-1") -> None:
if revision == "-1":
to_revert = [applied[-1]]
elif revision == "base":
# Revert all migrations to get back to base state
to_revert = list(reversed(applied))
else:
for migration in reversed(applied):
Expand Down
8 changes: 4 additions & 4 deletions sqlspec/migrations/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def execute_upgrade(
return None, 0

start_time = time.time()
# Execute each SQL statement separately

for sql_statement in upgrade_sql_list:
if sql_statement.strip():
driver.execute_script(sql_statement)
Expand All @@ -84,7 +84,7 @@ def execute_downgrade(
return None, 0

start_time = time.time()
# Execute each SQL statement separately

for sql_statement in downgrade_sql_list:
if sql_statement.strip():
driver.execute_script(sql_statement)
Expand Down Expand Up @@ -235,7 +235,7 @@ async def execute_upgrade(
return None, 0

start_time = time.time()
# Execute each SQL statement separately

for sql_statement in upgrade_sql_list:
if sql_statement.strip():
await driver.execute_script(sql_statement)
Expand All @@ -259,7 +259,7 @@ async def execute_downgrade(
return None, 0

start_time = time.time()
# Execute each SQL statement separately

for sql_statement in downgrade_sql_list:
if sql_statement.strip():
await driver.execute_script(sql_statement)
Expand Down
10 changes: 0 additions & 10 deletions sqlspec/migrations/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,21 +89,16 @@ def _safe_commit(self, driver: "SyncDriverAdapterBase") -> None:
driver: The database driver to use.
"""
try:
# Check if the connection has autocommit enabled
connection = getattr(driver, "connection", None)
if connection and hasattr(connection, "autocommit") and getattr(connection, "autocommit", False):
return

# For ADBC and other drivers, check the driver_features
driver_features = getattr(driver, "driver_features", {})
if driver_features and driver_features.get("autocommit", False):
return

# Safe to commit manually
driver.commit()
except Exception:
# If commit fails due to no active transaction, that's acceptable
# Some drivers with autocommit will fail when trying to commit
logger.debug("Failed to commit transaction, likely due to autocommit being enabled")


Expand Down Expand Up @@ -179,19 +174,14 @@ async def _safe_commit_async(self, driver: "AsyncDriverAdapterBase") -> None:
driver: The database driver to use.
"""
try:
# Check if the connection has autocommit enabled
connection = getattr(driver, "connection", None)
if connection and hasattr(connection, "autocommit") and getattr(connection, "autocommit", False):
return

# For ADBC and other drivers, check the driver_features
driver_features = getattr(driver, "driver_features", {})
if driver_features and driver_features.get("autocommit", False):
return

# Safe to commit manually
await driver.commit()
except Exception:
# If commit fails due to no active transaction, that's acceptable
# Some drivers with autocommit will fail when trying to commit
logger.debug("Failed to commit transaction, likely due to autocommit being enabled")