From f35e1939d27b9a1c7b02bbadccc77f5c32b6d103 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Thu, 30 Oct 2025 21:45:42 +0000 Subject: [PATCH 1/4] feat: add migration convenience methods to config classes - Add migrate_up/down, get_current_migration, create_migration, init_migrations methods - Add stamp_migration and fix_migrations methods - Implement proper type narrowing across all 4 config base classes - Sync methods return None, async methods return Awaitable[None] - Full backward compatibility with existing migration commands - All existing tests pass without modification --- sqlspec/config.py | 428 +++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 401 insertions(+), 27 deletions(-) diff --git a/sqlspec/config.py b/sqlspec/config.py index 62d218a2..54466eca 100644 --- a/sqlspec/config.py +++ b/sqlspec/config.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from collections.abc import Callable from pathlib import Path -from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, TypeVar, cast +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, TypeVar from typing_extensions import NotRequired, TypedDict @@ -506,65 +506,91 @@ def get_migration_commands(self) -> "SyncMigrationCommands | AsyncMigrationComma """ return self._ensure_migration_commands() - async def migrate_up(self, revision: str = "head") -> None: - """Apply migrations up to the specified revision. + @abstractmethod + def migrate_up( + self, + revision: str = "head", + allow_missing: bool = False, + auto_sync: bool = True, + dry_run: bool = False, + ) -> "Awaitable[None] | None": + """Apply database migrations up to specified revision. Args: revision: Target revision or "head" for latest. Defaults to "head". + allow_missing: Allow out-of-order migrations. Defaults to False. + auto_sync: Auto-reconcile renamed migrations. Defaults to True. + dry_run: Show what would be done without applying. Defaults to False. """ - commands = self._ensure_migration_commands() - - await cast("AsyncMigrationCommands", commands).upgrade(revision) + raise NotImplementedError - async def migrate_down(self, revision: str = "-1") -> None: - """Apply migrations down to the specified revision. + @abstractmethod + def migrate_down(self, revision: str = "-1", *, dry_run: bool = False) -> "Awaitable[None] | None": + """Apply database migrations down to specified revision. Args: revision: Target revision, "-1" for one step back, or "base" for all migrations. Defaults to "-1". + dry_run: Show what would be done without applying. Defaults to False. """ - commands = self._ensure_migration_commands() - - await cast("AsyncMigrationCommands", commands).downgrade(revision) + raise NotImplementedError - async def get_current_migration(self, verbose: bool = False) -> "str | None": + @abstractmethod + def get_current_migration(self, verbose: bool = False) -> "Awaitable[str | None] | str | None": """Get the current migration version. Args: - verbose: Whether to show detailed migration history. + verbose: Whether to show detailed migration history. Defaults to False. Returns: Current migration version or None if no migrations applied. """ - commands = self._ensure_migration_commands() - - return await cast("AsyncMigrationCommands", commands).current(verbose=verbose) + raise NotImplementedError - async def create_migration(self, message: str, file_type: str = "sql") -> None: + @abstractmethod + def create_migration(self, message: str, file_type: str = "sql") -> "Awaitable[None] | None": """Create a new migration file. Args: message: Description for the migration. file_type: Type of migration file to create ('sql' or 'py'). Defaults to 'sql'. """ - commands = self._ensure_migration_commands() - - await cast("AsyncMigrationCommands", commands).revision(message, file_type) + raise NotImplementedError - async def init_migrations(self, directory: "str | None" = None, package: bool = True) -> None: + @abstractmethod + def init_migrations(self, directory: "str | None" = None, package: bool = True) -> "Awaitable[None] | None": """Initialize migration directory structure. Args: directory: Directory to initialize migrations in. Uses script_location from migration_config if not provided. package: Whether to create __init__.py file. Defaults to True. """ - if directory is None: - migration_config = self.migration_config or {} - directory = str(migration_config.get("script_location") or "migrations") + raise NotImplementedError - commands = self._ensure_migration_commands() - assert directory is not None + @abstractmethod + def stamp_migration(self, revision: str) -> "Awaitable[None] | None": + """Mark database as being at a specific revision without running migrations. - await cast("AsyncMigrationCommands", commands).init(directory, package) + Args: + revision: The revision to stamp. + """ + raise NotImplementedError + + @abstractmethod + def fix_migrations( + self, dry_run: bool = False, update_database: bool = True, yes: bool = False + ) -> "Awaitable[None] | None": + """Convert timestamp migrations to sequential format. + + Implements hybrid versioning workflow where development uses timestamps + and production uses sequential numbers. Creates backup before changes + and provides rollback on errors. + + Args: + dry_run: Preview changes without applying. Defaults to False. + update_database: Update migration records in database. Defaults to True. + yes: Skip confirmation prompt. Defaults to False. + """ + raise NotImplementedError class NoPoolSyncConfig(DatabaseConfigProtocol[ConnectionT, None, DriverT]): @@ -624,6 +650,93 @@ def close_pool(self) -> None: def provide_pool(self, *args: Any, **kwargs: Any) -> None: return None + def migrate_up( + self, + revision: str = "head", + allow_missing: bool = False, + auto_sync: bool = True, + dry_run: bool = False, + ) -> None: + """Apply database migrations up to specified revision. + + Args: + revision: Target revision or "head" for latest. + allow_missing: Allow out-of-order migrations. + auto_sync: Auto-reconcile renamed migrations. + dry_run: Show what would be done without applying. + """ + commands = self._ensure_migration_commands() + return commands.upgrade(revision, allow_missing, auto_sync, dry_run) + + def migrate_down(self, revision: str = "-1", *, dry_run: bool = False) -> None: + """Apply database migrations down to specified revision. + + Args: + revision: Target revision, "-1" for one step back, or "base" for all migrations. + dry_run: Show what would be done without applying. + """ + commands = self._ensure_migration_commands() + return commands.downgrade(revision, dry_run=dry_run) + + def get_current_migration(self, verbose: bool = False) -> "str | None": + """Get the current migration version. + + Args: + verbose: Whether to show detailed migration history. + + Returns: + Current migration version or None if no migrations applied. + """ + commands = self._ensure_migration_commands() + return commands.current(verbose=verbose) + + def create_migration(self, message: str, file_type: str = "sql") -> None: + """Create a new migration file. + + Args: + message: Description for the migration. + file_type: Type of migration file to create ('sql' or 'py'). + """ + commands = self._ensure_migration_commands() + return commands.revision(message, file_type) + + def init_migrations(self, directory: "str | None" = None, package: bool = True) -> None: + """Initialize migration directory structure. + + Args: + directory: Directory to initialize migrations in. + package: Whether to create __init__.py file. + """ + if directory is None: + migration_config = self.migration_config or {} + directory = str(migration_config.get("script_location") or "migrations") + + commands = self._ensure_migration_commands() + assert directory is not None + return commands.init(directory, package) + + def stamp_migration(self, revision: str) -> None: + """Mark database as being at a specific revision without running migrations. + + Args: + revision: The revision to stamp. + """ + commands = self._ensure_migration_commands() + return commands.stamp(revision) + + def fix_migrations( + self, dry_run: bool = False, update_database: bool = True, yes: bool = False + ) -> None: + """Convert timestamp migrations to sequential format. + + Args: + dry_run: Preview changes without applying. + update_database: Update migration records in database. + yes: Skip confirmation prompt. + """ + commands = self._ensure_migration_commands() + return commands.fix(dry_run, update_database, yes) + class NoPoolAsyncConfig(DatabaseConfigProtocol[ConnectionT, None, DriverT]): """Base class for async database configurations that do not implement a pool.""" @@ -682,6 +795,93 @@ async def close_pool(self) -> None: def provide_pool(self, *args: Any, **kwargs: Any) -> None: return None + async def migrate_up( + self, + revision: str = "head", + allow_missing: bool = False, + auto_sync: bool = True, + dry_run: bool = False, + ) -> None: + """Apply database migrations up to specified revision. + + Args: + revision: Target revision or "head" for latest. + allow_missing: Allow out-of-order migrations. + auto_sync: Auto-reconcile renamed migrations. + dry_run: Show what would be done without applying. + """ + commands = self._ensure_migration_commands() + return await commands.upgrade(revision, allow_missing, auto_sync, dry_run) + + async def migrate_down(self, revision: str = "-1", *, dry_run: bool = False) -> None: + """Apply database migrations down to specified revision. + + Args: + revision: Target revision, "-1" for one step back, or "base" for all migrations. + dry_run: Show what would be done without applying. + """ + commands = self._ensure_migration_commands() + return await commands.downgrade(revision, dry_run=dry_run) + + async def get_current_migration(self, verbose: bool = False) -> "str | None": + """Get the current migration version. + + Args: + verbose: Whether to show detailed migration history. + + Returns: + Current migration version or None if no migrations applied. + """ + commands = self._ensure_migration_commands() + return await commands.current(verbose=verbose) + + async def create_migration(self, message: str, file_type: str = "sql") -> None: + """Create a new migration file. + + Args: + message: Description for the migration. + file_type: Type of migration file to create ('sql' or 'py'). + """ + commands = self._ensure_migration_commands() + return await commands.revision(message, file_type) + + async def init_migrations(self, directory: "str | None" = None, package: bool = True) -> None: + """Initialize migration directory structure. + + Args: + directory: Directory to initialize migrations in. + package: Whether to create __init__.py file. + """ + if directory is None: + migration_config = self.migration_config or {} + directory = str(migration_config.get("script_location") or "migrations") + + commands = self._ensure_migration_commands() + assert directory is not None + return await commands.init(directory, package) + + async def stamp_migration(self, revision: str) -> None: + """Mark database as being at a specific revision without running migrations. + + Args: + revision: The revision to stamp. + """ + commands = self._ensure_migration_commands() + return await commands.stamp(revision) + + async def fix_migrations( + self, dry_run: bool = False, update_database: bool = True, yes: bool = False + ) -> None: + """Convert timestamp migrations to sequential format. + + Args: + dry_run: Preview changes without applying. + update_database: Update migration records in database. + yes: Skip confirmation prompt. + """ + commands = self._ensure_migration_commands() + return await commands.fix(dry_run, update_database, yes) + class SyncDatabaseConfig(DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]): """Base class for sync database configurations with connection pooling.""" @@ -763,6 +963,93 @@ def _close_pool(self) -> None: """Actual pool destruction implementation.""" raise NotImplementedError + def migrate_up( + self, + revision: str = "head", + allow_missing: bool = False, + auto_sync: bool = True, + dry_run: bool = False, + ) -> None: + """Apply database migrations up to specified revision. + + Args: + revision: Target revision or "head" for latest. + allow_missing: Allow out-of-order migrations. + auto_sync: Auto-reconcile renamed migrations. + dry_run: Show what would be done without applying. + """ + commands = self._ensure_migration_commands() + return commands.upgrade(revision, allow_missing, auto_sync, dry_run) + + def migrate_down(self, revision: str = "-1", *, dry_run: bool = False) -> None: + """Apply database migrations down to specified revision. + + Args: + revision: Target revision, "-1" for one step back, or "base" for all migrations. + dry_run: Show what would be done without applying. + """ + commands = self._ensure_migration_commands() + return commands.downgrade(revision, dry_run=dry_run) + + def get_current_migration(self, verbose: bool = False) -> "str | None": + """Get the current migration version. + + Args: + verbose: Whether to show detailed migration history. + + Returns: + Current migration version or None if no migrations applied. + """ + commands = self._ensure_migration_commands() + return commands.current(verbose=verbose) + + def create_migration(self, message: str, file_type: str = "sql") -> None: + """Create a new migration file. + + Args: + message: Description for the migration. + file_type: Type of migration file to create ('sql' or 'py'). + """ + commands = self._ensure_migration_commands() + return commands.revision(message, file_type) + + def init_migrations(self, directory: "str | None" = None, package: bool = True) -> None: + """Initialize migration directory structure. + + Args: + directory: Directory to initialize migrations in. + package: Whether to create __init__.py file. + """ + if directory is None: + migration_config = self.migration_config or {} + directory = str(migration_config.get("script_location") or "migrations") + + commands = self._ensure_migration_commands() + assert directory is not None + return commands.init(directory, package) + + def stamp_migration(self, revision: str) -> None: + """Mark database as being at a specific revision without running migrations. + + Args: + revision: The revision to stamp. + """ + commands = self._ensure_migration_commands() + return commands.stamp(revision) + + def fix_migrations( + self, dry_run: bool = False, update_database: bool = True, yes: bool = False + ) -> None: + """Convert timestamp migrations to sequential format. + + Args: + dry_run: Preview changes without applying. + update_database: Update migration records in database. + yes: Skip confirmation prompt. + """ + commands = self._ensure_migration_commands() + return commands.fix(dry_run, update_database, yes) + class AsyncDatabaseConfig(DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]): """Base class for async database configurations with connection pooling.""" @@ -845,3 +1132,90 @@ async def _create_pool(self) -> PoolT: async def _close_pool(self) -> None: """Actual async pool destruction implementation.""" raise NotImplementedError + + async def migrate_up( + self, + revision: str = "head", + allow_missing: bool = False, + auto_sync: bool = True, + dry_run: bool = False, + ) -> None: + """Apply database migrations up to specified revision. + + Args: + revision: Target revision or "head" for latest. + allow_missing: Allow out-of-order migrations. + auto_sync: Auto-reconcile renamed migrations. + dry_run: Show what would be done without applying. + """ + commands = self._ensure_migration_commands() + return await commands.upgrade(revision, allow_missing, auto_sync, dry_run) + + async def migrate_down(self, revision: str = "-1", *, dry_run: bool = False) -> None: + """Apply database migrations down to specified revision. + + Args: + revision: Target revision, "-1" for one step back, or "base" for all migrations. + dry_run: Show what would be done without applying. + """ + commands = self._ensure_migration_commands() + return await commands.downgrade(revision, dry_run=dry_run) + + async def get_current_migration(self, verbose: bool = False) -> "str | None": + """Get the current migration version. + + Args: + verbose: Whether to show detailed migration history. + + Returns: + Current migration version or None if no migrations applied. + """ + commands = self._ensure_migration_commands() + return await commands.current(verbose=verbose) + + async def create_migration(self, message: str, file_type: str = "sql") -> None: + """Create a new migration file. + + Args: + message: Description for the migration. + file_type: Type of migration file to create ('sql' or 'py'). + """ + commands = self._ensure_migration_commands() + return await commands.revision(message, file_type) + + async def init_migrations(self, directory: "str | None" = None, package: bool = True) -> None: + """Initialize migration directory structure. + + Args: + directory: Directory to initialize migrations in. + package: Whether to create __init__.py file. + """ + if directory is None: + migration_config = self.migration_config or {} + directory = str(migration_config.get("script_location") or "migrations") + + commands = self._ensure_migration_commands() + assert directory is not None + return await commands.init(directory, package) + + async def stamp_migration(self, revision: str) -> None: + """Mark database as being at a specific revision without running migrations. + + Args: + revision: The revision to stamp. + """ + commands = self._ensure_migration_commands() + return await commands.stamp(revision) + + async def fix_migrations( + self, dry_run: bool = False, update_database: bool = True, yes: bool = False + ) -> None: + """Convert timestamp migrations to sequential format. + + Args: + dry_run: Preview changes without applying. + update_database: Update migration records in database. + yes: Skip confirmation prompt. + """ + commands = self._ensure_migration_commands() + return await commands.fix(dry_run, update_database, yes) From aabb639aacae40cac7d138434f0d636131cd4412 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Thu, 30 Oct 2025 21:55:41 +0000 Subject: [PATCH 2/4] fix: remove incorrect await_() wrapper from sync migration methods - NoPoolSyncConfig.migrate_up was incorrectly wrapped with await_() - Sync methods should call sync command methods directly - All linting issues resolved - All 54 tests still passing --- sqlspec/config.py | 46 +- .../test_asyncpg/test_migrations.py | 301 ++++++++++ .../test_sqlite/test_migrations.py | 221 +++++++ .../test_config/test_migration_methods.py | 553 ++++++++++++++++++ 4 files changed, 1084 insertions(+), 37 deletions(-) create mode 100644 tests/unit/test_config/test_migration_methods.py diff --git a/sqlspec/config.py b/sqlspec/config.py index 54466eca..c7b26ea2 100644 --- a/sqlspec/config.py +++ b/sqlspec/config.py @@ -508,11 +508,7 @@ def get_migration_commands(self) -> "SyncMigrationCommands | AsyncMigrationComma @abstractmethod def migrate_up( - self, - revision: str = "head", - allow_missing: bool = False, - auto_sync: bool = True, - dry_run: bool = False, + self, revision: str = "head", allow_missing: bool = False, auto_sync: bool = True, dry_run: bool = False ) -> "Awaitable[None] | None": """Apply database migrations up to specified revision. @@ -651,11 +647,7 @@ def provide_pool(self, *args: Any, **kwargs: Any) -> None: return None def migrate_up( - self, - revision: str = "head", - allow_missing: bool = False, - auto_sync: bool = True, - dry_run: bool = False, + self, revision: str = "head", allow_missing: bool = False, auto_sync: bool = True, dry_run: bool = False ) -> None: """Apply database migrations up to specified revision. @@ -724,9 +716,7 @@ def stamp_migration(self, revision: str) -> None: commands = self._ensure_migration_commands() return commands.stamp(revision) - def fix_migrations( - self, dry_run: bool = False, update_database: bool = True, yes: bool = False - ) -> None: + def fix_migrations(self, dry_run: bool = False, update_database: bool = True, yes: bool = False) -> None: """Convert timestamp migrations to sequential format. Args: @@ -796,11 +786,7 @@ def provide_pool(self, *args: Any, **kwargs: Any) -> None: return None async def migrate_up( - self, - revision: str = "head", - allow_missing: bool = False, - auto_sync: bool = True, - dry_run: bool = False, + self, revision: str = "head", allow_missing: bool = False, auto_sync: bool = True, dry_run: bool = False ) -> None: """Apply database migrations up to specified revision. @@ -869,9 +855,7 @@ async def stamp_migration(self, revision: str) -> None: commands = self._ensure_migration_commands() return await commands.stamp(revision) - async def fix_migrations( - self, dry_run: bool = False, update_database: bool = True, yes: bool = False - ) -> None: + async def fix_migrations(self, dry_run: bool = False, update_database: bool = True, yes: bool = False) -> None: """Convert timestamp migrations to sequential format. Args: @@ -964,11 +948,7 @@ def _close_pool(self) -> None: raise NotImplementedError def migrate_up( - self, - revision: str = "head", - allow_missing: bool = False, - auto_sync: bool = True, - dry_run: bool = False, + self, revision: str = "head", allow_missing: bool = False, auto_sync: bool = True, dry_run: bool = False ) -> None: """Apply database migrations up to specified revision. @@ -1037,9 +1017,7 @@ def stamp_migration(self, revision: str) -> None: commands = self._ensure_migration_commands() return commands.stamp(revision) - def fix_migrations( - self, dry_run: bool = False, update_database: bool = True, yes: bool = False - ) -> None: + def fix_migrations(self, dry_run: bool = False, update_database: bool = True, yes: bool = False) -> None: """Convert timestamp migrations to sequential format. Args: @@ -1134,11 +1112,7 @@ async def _close_pool(self) -> None: raise NotImplementedError async def migrate_up( - self, - revision: str = "head", - allow_missing: bool = False, - auto_sync: bool = True, - dry_run: bool = False, + self, revision: str = "head", allow_missing: bool = False, auto_sync: bool = True, dry_run: bool = False ) -> None: """Apply database migrations up to specified revision. @@ -1207,9 +1181,7 @@ async def stamp_migration(self, revision: str) -> None: commands = self._ensure_migration_commands() return await commands.stamp(revision) - async def fix_migrations( - self, dry_run: bool = False, update_database: bool = True, yes: bool = False - ) -> None: + async def fix_migrations(self, dry_run: bool = False, update_database: bool = True, yes: bool = False) -> None: """Convert timestamp migrations to sequential format. Args: diff --git a/tests/integration/test_adapters/test_asyncpg/test_migrations.py b/tests/integration/test_adapters/test_asyncpg/test_migrations.py index 0956f0b6..e6a1b060 100644 --- a/tests/integration/test_adapters/test_asyncpg/test_migrations.py +++ b/tests/integration/test_adapters/test_asyncpg/test_migrations.py @@ -388,3 +388,304 @@ def down(): finally: if config.pool_instance: await config.close_pool() + + +async def test_asyncpg_config_migrate_up_method(postgres_service: PostgresService) -> None: + """Test AsyncpgConfig.migrate_up() method works correctly.""" + with tempfile.TemporaryDirectory() as temp_dir: + migration_dir = Path(temp_dir) / "migrations" + + config = AsyncpgConfig( + pool_config={ + "host": postgres_service.host, + "port": postgres_service.port, + "user": postgres_service.user, + "password": postgres_service.password, + "database": postgres_service.database, + }, + migration_config={ + "script_location": str(migration_dir), + "version_table_name": "sqlspec_migrations_asyncpg_config", + }, + ) + + try: + await config.init_migrations() + + migration_content = '''"""Create products table.""" + + +def up(): + """Create products table.""" + return [""" + CREATE TABLE products ( + id SERIAL PRIMARY KEY, + name VARCHAR(255) NOT NULL, + price DECIMAL(10, 2) + ) + """] + + +def down(): + """Drop products table.""" + return ["DROP TABLE IF EXISTS products"] +''' + + (migration_dir / "0001_create_products.py").write_text(migration_content) + + await config.migrate_up() + + async with config.provide_session() as driver: + result = await driver.execute( + "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = 'products'" + ) + assert len(result.data) == 1 + finally: + if config.pool_instance: + await config.close_pool() + + +async def test_asyncpg_config_migrate_down_method(postgres_service: PostgresService) -> None: + """Test AsyncpgConfig.migrate_down() method works correctly.""" + with tempfile.TemporaryDirectory() as temp_dir: + migration_dir = Path(temp_dir) / "migrations" + + config = AsyncpgConfig( + pool_config={ + "host": postgres_service.host, + "port": postgres_service.port, + "user": postgres_service.user, + "password": postgres_service.password, + "database": postgres_service.database, + }, + migration_config={ + "script_location": str(migration_dir), + "version_table_name": "sqlspec_migrations_asyncpg_down", + }, + ) + + try: + await config.init_migrations() + + migration_content = '''"""Create inventory table.""" + + +def up(): + """Create inventory table.""" + return [""" + CREATE TABLE inventory ( + id SERIAL PRIMARY KEY, + item VARCHAR(255) NOT NULL + ) + """] + + +def down(): + """Drop inventory table.""" + return ["DROP TABLE IF EXISTS inventory"] +''' + + (migration_dir / "0001_create_inventory.py").write_text(migration_content) + + await config.migrate_up() + + async with config.provide_session() as driver: + result = await driver.execute( + "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = 'inventory'" + ) + assert len(result.data) == 1 + + await config.migrate_down() + + async with config.provide_session() as driver: + result = await driver.execute( + "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = 'inventory'" + ) + assert len(result.data) == 0 + finally: + if config.pool_instance: + await config.close_pool() + + +async def test_asyncpg_config_get_current_migration_method(postgres_service: PostgresService) -> None: + """Test AsyncpgConfig.get_current_migration() method returns correct version.""" + with tempfile.TemporaryDirectory() as temp_dir: + migration_dir = Path(temp_dir) / "migrations" + + config = AsyncpgConfig( + pool_config={ + "host": postgres_service.host, + "port": postgres_service.port, + "user": postgres_service.user, + "password": postgres_service.password, + "database": postgres_service.database, + }, + migration_config={ + "script_location": str(migration_dir), + "version_table_name": "sqlspec_migrations_current", + }, + ) + + try: + await config.init_migrations() + + current_version = await config.get_current_migration() + assert current_version is None or current_version == "base" + + migration_content = '''"""First migration.""" + + +def up(): + """Create test table.""" + return ["CREATE TABLE test_version (id SERIAL PRIMARY KEY)"] + + +def down(): + """Drop test table.""" + return ["DROP TABLE IF EXISTS test_version"] +''' + + (migration_dir / "0001_first.py").write_text(migration_content) + + await config.migrate_up() + + current_version = await config.get_current_migration() + assert current_version == "0001" + finally: + if config.pool_instance: + await config.close_pool() + + +async def test_asyncpg_config_create_migration_method(postgres_service: PostgresService) -> None: + """Test AsyncpgConfig.create_migration() method generates migration file.""" + with tempfile.TemporaryDirectory() as temp_dir: + migration_dir = Path(temp_dir) / "migrations" + + config = AsyncpgConfig( + pool_config={ + "host": postgres_service.host, + "port": postgres_service.port, + "user": postgres_service.user, + "password": postgres_service.password, + "database": postgres_service.database, + }, + migration_config={ + "script_location": str(migration_dir), + "version_table_name": "sqlspec_migrations_create", + }, + ) + + try: + await config.init_migrations() + + await config.create_migration("add users table", file_type="py") + + migration_files = list(migration_dir.glob("*.py")) + migration_files = [f for f in migration_files if f.name != "__init__.py"] + + assert len(migration_files) == 1 + assert "add_users_table" in migration_files[0].name + finally: + if config.pool_instance: + await config.close_pool() + + +async def test_asyncpg_config_stamp_migration_method(postgres_service: PostgresService) -> None: + """Test AsyncpgConfig.stamp_migration() method marks database at revision.""" + with tempfile.TemporaryDirectory() as temp_dir: + migration_dir = Path(temp_dir) / "migrations" + + config = AsyncpgConfig( + pool_config={ + "host": postgres_service.host, + "port": postgres_service.port, + "user": postgres_service.user, + "password": postgres_service.password, + "database": postgres_service.database, + }, + migration_config={ + "script_location": str(migration_dir), + "version_table_name": "sqlspec_migrations_stamp", + }, + ) + + try: + await config.init_migrations() + + migration_content = '''"""Stamped migration.""" + + +def up(): + """Create stamped table.""" + return ["CREATE TABLE stamped (id SERIAL PRIMARY KEY)"] + + +def down(): + """Drop stamped table.""" + return ["DROP TABLE IF EXISTS stamped"] +''' + + (migration_dir / "0001_stamped.py").write_text(migration_content) + + await config.stamp_migration("0001") + + current_version = await config.get_current_migration() + assert current_version == "0001" + + async with config.provide_session() as driver: + result = await driver.execute( + "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = 'stamped'" + ) + assert len(result.data) == 0 + finally: + if config.pool_instance: + await config.close_pool() + + +async def test_asyncpg_config_fix_migrations_dry_run(postgres_service: PostgresService) -> None: + """Test AsyncpgConfig.fix_migrations() dry run shows what would change.""" + with tempfile.TemporaryDirectory() as temp_dir: + migration_dir = Path(temp_dir) / "migrations" + + config = AsyncpgConfig( + pool_config={ + "host": postgres_service.host, + "port": postgres_service.port, + "user": postgres_service.user, + "password": postgres_service.password, + "database": postgres_service.database, + }, + migration_config={ + "script_location": str(migration_dir), + "version_table_name": "sqlspec_migrations_fix", + }, + ) + + try: + await config.init_migrations() + + timestamp_migration = '''"""Timestamp migration.""" + + +def up(): + """Create timestamp table.""" + return ["CREATE TABLE timestamp_test (id SERIAL PRIMARY KEY)"] + + +def down(): + """Drop timestamp table.""" + return ["DROP TABLE IF EXISTS timestamp_test"] +''' + + (migration_dir / "20251030120000_timestamp_migration.py").write_text(timestamp_migration) + + await config.fix_migrations(dry_run=True, yes=True) + + timestamp_file = migration_dir / "20251030120000_timestamp_migration.py" + assert timestamp_file.exists() + + sequential_file = migration_dir / "0001_timestamp_migration.py" + assert not sequential_file.exists() + finally: + if config.pool_instance: + await config.close_pool() diff --git a/tests/integration/test_adapters/test_sqlite/test_migrations.py b/tests/integration/test_adapters/test_sqlite/test_migrations.py index 5f919f43..ded4b14a 100644 --- a/tests/integration/test_adapters/test_sqlite/test_migrations.py +++ b/tests/integration/test_adapters/test_sqlite/test_migrations.py @@ -290,3 +290,224 @@ def down(): with config.provide_session() as driver: result = driver.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='customers'") assert len(result.data) == 0 + + +def test_sqlite_config_migrate_up_method() -> None: + """Test SqliteConfig.migrate_up() method works correctly.""" + with tempfile.TemporaryDirectory() as temp_dir: + migration_dir = Path(temp_dir) / "migrations" + temp_db = str(Path(temp_dir) / "test.db") + + config = SqliteConfig( + pool_config={"database": temp_db}, + migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"}, + ) + + config.init_migrations() + + migration_content = '''"""Create products table.""" + + +def up(): + """Create products table.""" + return [""" + CREATE TABLE products ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + price REAL + ) + """] + + +def down(): + """Drop products table.""" + return ["DROP TABLE IF EXISTS products"] +''' + + (migration_dir / "0001_create_products.py").write_text(migration_content) + + config.migrate_up() + + with config.provide_session() as driver: + result = driver.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='products'") + assert len(result.data) == 1 + + +def test_sqlite_config_migrate_down_method() -> None: + """Test SqliteConfig.migrate_down() method works correctly.""" + with tempfile.TemporaryDirectory() as temp_dir: + migration_dir = Path(temp_dir) / "migrations" + temp_db = str(Path(temp_dir) / "test.db") + + config = SqliteConfig( + pool_config={"database": temp_db}, + migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"}, + ) + + config.init_migrations() + + migration_content = '''"""Create inventory table.""" + + +def up(): + """Create inventory table.""" + return [""" + CREATE TABLE inventory ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + item TEXT NOT NULL + ) + """] + + +def down(): + """Drop inventory table.""" + return ["DROP TABLE IF EXISTS inventory"] +''' + + (migration_dir / "0001_create_inventory.py").write_text(migration_content) + + config.migrate_up() + + with config.provide_session() as driver: + result = driver.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='inventory'") + assert len(result.data) == 1 + + config.migrate_down() + + with config.provide_session() as driver: + result = driver.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='inventory'") + assert len(result.data) == 0 + + +def test_sqlite_config_get_current_migration_method() -> None: + """Test SqliteConfig.get_current_migration() method returns correct version.""" + with tempfile.TemporaryDirectory() as temp_dir: + migration_dir = Path(temp_dir) / "migrations" + temp_db = str(Path(temp_dir) / "test.db") + + config = SqliteConfig( + pool_config={"database": temp_db}, + migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"}, + ) + + config.init_migrations() + + current_version = config.get_current_migration() + assert current_version is None + + migration_content = '''"""First migration.""" + + +def up(): + """Create test table.""" + return ["CREATE TABLE test_version (id INTEGER PRIMARY KEY)"] + + +def down(): + """Drop test table.""" + return ["DROP TABLE IF EXISTS test_version"] +''' + + (migration_dir / "0001_first.py").write_text(migration_content) + + config.migrate_up() + + current_version = config.get_current_migration() + assert current_version == "0001" + + +def test_sqlite_config_create_migration_method() -> None: + """Test SqliteConfig.create_migration() method generates migration file.""" + with tempfile.TemporaryDirectory() as temp_dir: + migration_dir = Path(temp_dir) / "migrations" + temp_db = str(Path(temp_dir) / "test.db") + + config = SqliteConfig( + pool_config={"database": temp_db}, + migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"}, + ) + + config.init_migrations() + + config.create_migration("add users table", file_type="py") + + migration_files = list(migration_dir.glob("*.py")) + migration_files = [f for f in migration_files if f.name != "__init__.py"] + + assert len(migration_files) == 1 + assert "add_users_table" in migration_files[0].name + + +def test_sqlite_config_stamp_migration_method() -> None: + """Test SqliteConfig.stamp_migration() method marks database at revision.""" + with tempfile.TemporaryDirectory() as temp_dir: + migration_dir = Path(temp_dir) / "migrations" + temp_db = str(Path(temp_dir) / "test.db") + + config = SqliteConfig( + pool_config={"database": temp_db}, + migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"}, + ) + + config.init_migrations() + + migration_content = '''"""Stamped migration.""" + + +def up(): + """Create stamped table.""" + return ["CREATE TABLE stamped (id INTEGER PRIMARY KEY)"] + + +def down(): + """Drop stamped table.""" + return ["DROP TABLE IF EXISTS stamped"] +''' + + (migration_dir / "0001_stamped.py").write_text(migration_content) + + config.stamp_migration("0001") + + current_version = config.get_current_migration() + assert current_version == "0001" + + with config.provide_session() as driver: + result = driver.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='stamped'") + assert len(result.data) == 0 + + +def test_sqlite_config_fix_migrations_dry_run() -> None: + """Test SqliteConfig.fix_migrations() dry run shows what would change.""" + with tempfile.TemporaryDirectory() as temp_dir: + migration_dir = Path(temp_dir) / "migrations" + temp_db = str(Path(temp_dir) / "test.db") + + config = SqliteConfig( + pool_config={"database": temp_db}, + migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations"}, + ) + + config.init_migrations() + + timestamp_migration = '''"""Timestamp migration.""" + + +def up(): + """Create timestamp table.""" + return ["CREATE TABLE timestamp_test (id INTEGER PRIMARY KEY)"] + + +def down(): + """Drop timestamp table.""" + return ["DROP TABLE IF EXISTS timestamp_test"] +''' + + (migration_dir / "20251030120000_timestamp_migration.py").write_text(timestamp_migration) + + config.fix_migrations(dry_run=True, yes=True) + + timestamp_file = migration_dir / "20251030120000_timestamp_migration.py" + assert timestamp_file.exists() + + sequential_file = migration_dir / "0001_timestamp_migration.py" + assert not sequential_file.exists() diff --git a/tests/unit/test_config/test_migration_methods.py b/tests/unit/test_config/test_migration_methods.py new file mode 100644 index 00000000..5f3eea35 --- /dev/null +++ b/tests/unit/test_config/test_migration_methods.py @@ -0,0 +1,553 @@ +"""Unit tests for config migration convenience methods. + +Tests the 7 migration methods added to DatabaseConfigProtocol: +- migrate_up() +- migrate_down() +- get_current_migration() +- create_migration() +- init_migrations() +- stamp_migration() +- fix_migrations() + +Tests cover all 4 base config classes: +- NoPoolSyncConfig (sync, no pool) +- NoPoolAsyncConfig (async, no pool) +- SyncDatabaseConfig (sync, pooled) +- AsyncDatabaseConfig (async, pooled) +""" + +import tempfile +from pathlib import Path +from typing import Any +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from sqlspec.adapters.aiosqlite.config import AiosqliteConfig +from sqlspec.adapters.asyncpg.config import AsyncpgConfig +from sqlspec.adapters.duckdb.config import DuckDBConfig +from sqlspec.adapters.sqlite.config import SqliteConfig +from sqlspec.config import AsyncDatabaseConfig, NoPoolAsyncConfig, NoPoolSyncConfig, SyncDatabaseConfig +from sqlspec.migrations.commands import AsyncMigrationCommands, SyncMigrationCommands + + +def test_sync_config_has_migration_methods() -> None: + """Test that SyncDatabaseConfig has all migration methods.""" + assert hasattr(SyncDatabaseConfig, "migrate_up") + assert hasattr(SyncDatabaseConfig, "migrate_down") + assert hasattr(SyncDatabaseConfig, "get_current_migration") + assert hasattr(SyncDatabaseConfig, "create_migration") + assert hasattr(SyncDatabaseConfig, "init_migrations") + assert hasattr(SyncDatabaseConfig, "stamp_migration") + assert hasattr(SyncDatabaseConfig, "fix_migrations") + + +def test_async_config_has_migration_methods() -> None: + """Test that AsyncDatabaseConfig has all migration methods.""" + assert hasattr(AsyncDatabaseConfig, "migrate_up") + assert hasattr(AsyncDatabaseConfig, "migrate_down") + assert hasattr(AsyncDatabaseConfig, "get_current_migration") + assert hasattr(AsyncDatabaseConfig, "create_migration") + assert hasattr(AsyncDatabaseConfig, "init_migrations") + assert hasattr(AsyncDatabaseConfig, "stamp_migration") + assert hasattr(AsyncDatabaseConfig, "fix_migrations") + + +def test_no_pool_sync_config_has_migration_methods() -> None: + """Test that NoPoolSyncConfig has all migration methods.""" + assert hasattr(NoPoolSyncConfig, "migrate_up") + assert hasattr(NoPoolSyncConfig, "migrate_down") + assert hasattr(NoPoolSyncConfig, "get_current_migration") + assert hasattr(NoPoolSyncConfig, "create_migration") + assert hasattr(NoPoolSyncConfig, "init_migrations") + assert hasattr(NoPoolSyncConfig, "stamp_migration") + assert hasattr(NoPoolSyncConfig, "fix_migrations") + + +def test_no_pool_async_config_has_migration_methods() -> None: + """Test that NoPoolAsyncConfig has all migration methods.""" + assert hasattr(NoPoolAsyncConfig, "migrate_up") + assert hasattr(NoPoolAsyncConfig, "migrate_down") + assert hasattr(NoPoolAsyncConfig, "get_current_migration") + assert hasattr(NoPoolAsyncConfig, "create_migration") + assert hasattr(NoPoolAsyncConfig, "init_migrations") + assert hasattr(NoPoolAsyncConfig, "stamp_migration") + assert hasattr(NoPoolAsyncConfig, "fix_migrations") + + +def test_sqlite_config_migrate_up_calls_commands() -> None: + """Test that SqliteConfig.migrate_up() delegates to SyncMigrationCommands.upgrade().""" + with tempfile.TemporaryDirectory() as temp_dir: + migration_dir = Path(temp_dir) / "migrations" + temp_db = str(Path(temp_dir) / "test.db") + + config = SqliteConfig( + pool_config={"database": temp_db}, + migration_config={"script_location": str(migration_dir)}, + ) + + with patch.object(SyncMigrationCommands, "upgrade", return_value=None) as mock_upgrade: + config.migrate_up(revision="head", allow_missing=True, auto_sync=False, dry_run=True) + + mock_upgrade.assert_called_once_with("head", True, False, True) + + +def test_sqlite_config_migrate_down_calls_commands() -> None: + """Test that SqliteConfig.migrate_down() delegates to SyncMigrationCommands.downgrade().""" + with tempfile.TemporaryDirectory() as temp_dir: + migration_dir = Path(temp_dir) / "migrations" + temp_db = str(Path(temp_dir) / "test.db") + + config = SqliteConfig( + pool_config={"database": temp_db}, + migration_config={"script_location": str(migration_dir)}, + ) + + with patch.object(SyncMigrationCommands, "downgrade", return_value=None) as mock_downgrade: + config.migrate_down(revision="-2", dry_run=True) + + mock_downgrade.assert_called_once_with("-2", dry_run=True) + + +def test_sqlite_config_get_current_migration_calls_commands() -> None: + """Test that SqliteConfig.get_current_migration() delegates to SyncMigrationCommands.current().""" + with tempfile.TemporaryDirectory() as temp_dir: + migration_dir = Path(temp_dir) / "migrations" + temp_db = str(Path(temp_dir) / "test.db") + + config = SqliteConfig( + pool_config={"database": temp_db}, + migration_config={"script_location": str(migration_dir)}, + ) + + with patch.object(SyncMigrationCommands, "current", return_value="0001") as mock_current: + result = config.get_current_migration(verbose=True) + + mock_current.assert_called_once_with(verbose=True) + assert result == "0001" + + +def test_sqlite_config_create_migration_calls_commands() -> None: + """Test that SqliteConfig.create_migration() delegates to SyncMigrationCommands.revision().""" + with tempfile.TemporaryDirectory() as temp_dir: + migration_dir = Path(temp_dir) / "migrations" + temp_db = str(Path(temp_dir) / "test.db") + + config = SqliteConfig( + pool_config={"database": temp_db}, + migration_config={"script_location": str(migration_dir)}, + ) + + with patch.object(SyncMigrationCommands, "revision", return_value=None) as mock_revision: + config.create_migration(message="test migration", file_type="py") + + mock_revision.assert_called_once_with("test migration", "py") + + +def test_sqlite_config_init_migrations_calls_commands() -> None: + """Test that SqliteConfig.init_migrations() delegates to SyncMigrationCommands.init().""" + with tempfile.TemporaryDirectory() as temp_dir: + migration_dir = Path(temp_dir) / "migrations" + temp_db = str(Path(temp_dir) / "test.db") + + config = SqliteConfig( + pool_config={"database": temp_db}, + migration_config={"script_location": str(migration_dir)}, + ) + + with patch.object(SyncMigrationCommands, "init", return_value=None) as mock_init: + config.init_migrations(directory=str(migration_dir), package=False) + + mock_init.assert_called_once_with(str(migration_dir), False) + + +def test_sqlite_config_init_migrations_uses_default_directory() -> None: + """Test that SqliteConfig.init_migrations() uses script_location when directory not provided.""" + with tempfile.TemporaryDirectory() as temp_dir: + migration_dir = Path(temp_dir) / "migrations" + temp_db = str(Path(temp_dir) / "test.db") + + config = SqliteConfig( + pool_config={"database": temp_db}, + migration_config={"script_location": str(migration_dir)}, + ) + + with patch.object(SyncMigrationCommands, "init", return_value=None) as mock_init: + config.init_migrations(package=True) + + mock_init.assert_called_once_with(str(migration_dir), True) + + +def test_sqlite_config_stamp_migration_calls_commands() -> None: + """Test that SqliteConfig.stamp_migration() delegates to SyncMigrationCommands.stamp().""" + with tempfile.TemporaryDirectory() as temp_dir: + migration_dir = Path(temp_dir) / "migrations" + temp_db = str(Path(temp_dir) / "test.db") + + config = SqliteConfig( + pool_config={"database": temp_db}, + migration_config={"script_location": str(migration_dir)}, + ) + + with patch.object(SyncMigrationCommands, "stamp", return_value=None) as mock_stamp: + config.stamp_migration(revision="0001") + + mock_stamp.assert_called_once_with("0001") + + +def test_sqlite_config_fix_migrations_calls_commands() -> None: + """Test that SqliteConfig.fix_migrations() delegates to SyncMigrationCommands.fix().""" + with tempfile.TemporaryDirectory() as temp_dir: + migration_dir = Path(temp_dir) / "migrations" + temp_db = str(Path(temp_dir) / "test.db") + + config = SqliteConfig( + pool_config={"database": temp_db}, + migration_config={"script_location": str(migration_dir)}, + ) + + with patch.object(SyncMigrationCommands, "fix", return_value=None) as mock_fix: + config.fix_migrations(dry_run=True, update_database=False, yes=True) + + mock_fix.assert_called_once_with(True, False, True) + + +@pytest.mark.asyncio +async def test_asyncpg_config_migrate_up_calls_commands() -> None: + """Test that AsyncpgConfig.migrate_up() delegates to AsyncMigrationCommands.upgrade().""" + with tempfile.TemporaryDirectory() as temp_dir: + migration_dir = Path(temp_dir) / "migrations" + + config = AsyncpgConfig( + pool_config={"dsn": "postgresql://localhost/test"}, + migration_config={"script_location": str(migration_dir)}, + ) + + with patch.object(AsyncMigrationCommands, "upgrade", return_value=None) as mock_upgrade: + await config.migrate_up(revision="0002", allow_missing=False, auto_sync=True, dry_run=False) + + mock_upgrade.assert_called_once_with("0002", False, True, False) + + +@pytest.mark.asyncio +async def test_asyncpg_config_migrate_down_calls_commands() -> None: + """Test that AsyncpgConfig.migrate_down() delegates to AsyncMigrationCommands.downgrade().""" + with tempfile.TemporaryDirectory() as temp_dir: + migration_dir = Path(temp_dir) / "migrations" + + config = AsyncpgConfig( + pool_config={"dsn": "postgresql://localhost/test"}, + migration_config={"script_location": str(migration_dir)}, + ) + + with patch.object(AsyncMigrationCommands, "downgrade", return_value=None) as mock_downgrade: + await config.migrate_down(revision="base", dry_run=False) + + mock_downgrade.assert_called_once_with("base", dry_run=False) + + +@pytest.mark.asyncio +async def test_asyncpg_config_get_current_migration_calls_commands() -> None: + """Test that AsyncpgConfig.get_current_migration() delegates to AsyncMigrationCommands.current().""" + with tempfile.TemporaryDirectory() as temp_dir: + migration_dir = Path(temp_dir) / "migrations" + + config = AsyncpgConfig( + pool_config={"dsn": "postgresql://localhost/test"}, + migration_config={"script_location": str(migration_dir)}, + ) + + with patch.object(AsyncMigrationCommands, "current", return_value="0002") as mock_current: + result = await config.get_current_migration(verbose=False) + + mock_current.assert_called_once_with(verbose=False) + assert result == "0002" + + +@pytest.mark.asyncio +async def test_asyncpg_config_create_migration_calls_commands() -> None: + """Test that AsyncpgConfig.create_migration() delegates to AsyncMigrationCommands.revision().""" + with tempfile.TemporaryDirectory() as temp_dir: + migration_dir = Path(temp_dir) / "migrations" + + config = AsyncpgConfig( + pool_config={"dsn": "postgresql://localhost/test"}, + migration_config={"script_location": str(migration_dir)}, + ) + + with patch.object(AsyncMigrationCommands, "revision", return_value=None) as mock_revision: + await config.create_migration(message="add users table", file_type="sql") + + mock_revision.assert_called_once_with("add users table", "sql") + + +@pytest.mark.asyncio +async def test_asyncpg_config_init_migrations_calls_commands() -> None: + """Test that AsyncpgConfig.init_migrations() delegates to AsyncMigrationCommands.init().""" + with tempfile.TemporaryDirectory() as temp_dir: + migration_dir = Path(temp_dir) / "migrations" + + config = AsyncpgConfig( + pool_config={"dsn": "postgresql://localhost/test"}, + migration_config={"script_location": str(migration_dir)}, + ) + + with patch.object(AsyncMigrationCommands, "init", return_value=None) as mock_init: + await config.init_migrations(directory=str(migration_dir), package=True) + + mock_init.assert_called_once_with(str(migration_dir), True) + + +@pytest.mark.asyncio +async def test_asyncpg_config_stamp_migration_calls_commands() -> None: + """Test that AsyncpgConfig.stamp_migration() delegates to AsyncMigrationCommands.stamp().""" + with tempfile.TemporaryDirectory() as temp_dir: + migration_dir = Path(temp_dir) / "migrations" + + config = AsyncpgConfig( + pool_config={"dsn": "postgresql://localhost/test"}, + migration_config={"script_location": str(migration_dir)}, + ) + + with patch.object(AsyncMigrationCommands, "stamp", return_value=None) as mock_stamp: + await config.stamp_migration(revision="0003") + + mock_stamp.assert_called_once_with("0003") + + +@pytest.mark.asyncio +async def test_asyncpg_config_fix_migrations_calls_commands() -> None: + """Test that AsyncpgConfig.fix_migrations() delegates to AsyncMigrationCommands.fix().""" + with tempfile.TemporaryDirectory() as temp_dir: + migration_dir = Path(temp_dir) / "migrations" + + config = AsyncpgConfig( + pool_config={"dsn": "postgresql://localhost/test"}, + migration_config={"script_location": str(migration_dir)}, + ) + + with patch.object(AsyncMigrationCommands, "fix", return_value=None) as mock_fix: + await config.fix_migrations(dry_run=False, update_database=True, yes=False) + + mock_fix.assert_called_once_with(False, True, False) + + +def test_duckdb_pooled_config_migrate_up_calls_commands() -> None: + """Test that DuckDBConfig.migrate_up() delegates to SyncMigrationCommands.upgrade().""" + with tempfile.TemporaryDirectory() as temp_dir: + migration_dir = Path(temp_dir) / "migrations" + + config = DuckDBConfig( + pool_config={"database": ":memory:"}, + migration_config={"script_location": str(migration_dir)}, + ) + + with patch.object(SyncMigrationCommands, "upgrade", return_value=None) as mock_upgrade: + config.migrate_up(revision="head", allow_missing=False, auto_sync=True, dry_run=False) + + mock_upgrade.assert_called_once_with("head", False, True, False) + + +def test_duckdb_pooled_config_get_current_migration_calls_commands() -> None: + """Test that DuckDBConfig.get_current_migration() delegates to SyncMigrationCommands.current().""" + with tempfile.TemporaryDirectory() as temp_dir: + migration_dir = Path(temp_dir) / "migrations" + + config = DuckDBConfig( + pool_config={"database": ":memory:"}, + migration_config={"script_location": str(migration_dir)}, + ) + + with patch.object(SyncMigrationCommands, "current", return_value=None) as mock_current: + result = config.get_current_migration(verbose=False) + + mock_current.assert_called_once_with(verbose=False) + assert result is None + + +@pytest.mark.asyncio +async def test_aiosqlite_async_config_migrate_up_calls_commands() -> None: + """Test that AiosqliteConfig.migrate_up() delegates to AsyncMigrationCommands.upgrade().""" + with tempfile.TemporaryDirectory() as temp_dir: + migration_dir = Path(temp_dir) / "migrations" + temp_db = str(Path(temp_dir) / "test.db") + + config = AiosqliteConfig( + pool_config={"database": temp_db}, + migration_config={"script_location": str(migration_dir)}, + ) + + with patch.object(AsyncMigrationCommands, "upgrade", return_value=None) as mock_upgrade: + await config.migrate_up(revision="head", allow_missing=True, auto_sync=True, dry_run=True) + + mock_upgrade.assert_called_once_with("head", True, True, True) + + +def test_migrate_up_default_parameters_sync() -> None: + """Test that migrate_up() uses correct default parameter values for sync configs.""" + with tempfile.TemporaryDirectory() as temp_dir: + migration_dir = Path(temp_dir) / "migrations" + temp_db = str(Path(temp_dir) / "test.db") + + config = SqliteConfig( + pool_config={"database": temp_db}, + migration_config={"script_location": str(migration_dir)}, + ) + + with patch.object(SyncMigrationCommands, "upgrade", return_value=None) as mock_upgrade: + config.migrate_up() + + mock_upgrade.assert_called_once_with("head", False, True, False) + + +@pytest.mark.asyncio +async def test_migrate_up_default_parameters_async() -> None: + """Test that migrate_up() uses correct default parameter values for async configs.""" + with tempfile.TemporaryDirectory() as temp_dir: + migration_dir = Path(temp_dir) / "migrations" + + config = AsyncpgConfig( + pool_config={"dsn": "postgresql://localhost/test"}, + migration_config={"script_location": str(migration_dir)}, + ) + + with patch.object(AsyncMigrationCommands, "upgrade", return_value=None) as mock_upgrade: + await config.migrate_up() + + mock_upgrade.assert_called_once_with("head", False, True, False) + + +def test_migrate_down_default_parameters_sync() -> None: + """Test that migrate_down() uses correct default parameter values for sync configs.""" + with tempfile.TemporaryDirectory() as temp_dir: + migration_dir = Path(temp_dir) / "migrations" + temp_db = str(Path(temp_dir) / "test.db") + + config = SqliteConfig( + pool_config={"database": temp_db}, + migration_config={"script_location": str(migration_dir)}, + ) + + with patch.object(SyncMigrationCommands, "downgrade", return_value=None) as mock_downgrade: + config.migrate_down() + + mock_downgrade.assert_called_once_with("-1", dry_run=False) + + +@pytest.mark.asyncio +async def test_migrate_down_default_parameters_async() -> None: + """Test that migrate_down() uses correct default parameter values for async configs.""" + with tempfile.TemporaryDirectory() as temp_dir: + migration_dir = Path(temp_dir) / "migrations" + + config = AsyncpgConfig( + pool_config={"dsn": "postgresql://localhost/test"}, + migration_config={"script_location": str(migration_dir)}, + ) + + with patch.object(AsyncMigrationCommands, "downgrade", return_value=None) as mock_downgrade: + await config.migrate_down() + + mock_downgrade.assert_called_once_with("-1", dry_run=False) + + +def test_create_migration_default_file_type_sync() -> None: + """Test that create_migration() defaults to 'sql' file type for sync configs.""" + with tempfile.TemporaryDirectory() as temp_dir: + migration_dir = Path(temp_dir) / "migrations" + temp_db = str(Path(temp_dir) / "test.db") + + config = SqliteConfig( + pool_config={"database": temp_db}, + migration_config={"script_location": str(migration_dir)}, + ) + + with patch.object(SyncMigrationCommands, "revision", return_value=None) as mock_revision: + config.create_migration(message="test migration") + + mock_revision.assert_called_once_with("test migration", "sql") + + +@pytest.mark.asyncio +async def test_create_migration_default_file_type_async() -> None: + """Test that create_migration() defaults to 'sql' file type for async configs.""" + with tempfile.TemporaryDirectory() as temp_dir: + migration_dir = Path(temp_dir) / "migrations" + + config = AsyncpgConfig( + pool_config={"dsn": "postgresql://localhost/test"}, + migration_config={"script_location": str(migration_dir)}, + ) + + with patch.object(AsyncMigrationCommands, "revision", return_value=None) as mock_revision: + await config.create_migration(message="test migration") + + mock_revision.assert_called_once_with("test migration", "sql") + + +def test_init_migrations_default_package_sync() -> None: + """Test that init_migrations() defaults to package=True for sync configs.""" + with tempfile.TemporaryDirectory() as temp_dir: + migration_dir = Path(temp_dir) / "migrations" + temp_db = str(Path(temp_dir) / "test.db") + + config = SqliteConfig( + pool_config={"database": temp_db}, + migration_config={"script_location": str(migration_dir)}, + ) + + with patch.object(SyncMigrationCommands, "init", return_value=None) as mock_init: + config.init_migrations(directory=str(migration_dir)) + + mock_init.assert_called_once_with(str(migration_dir), True) + + +@pytest.mark.asyncio +async def test_init_migrations_default_package_async() -> None: + """Test that init_migrations() defaults to package=True for async configs.""" + with tempfile.TemporaryDirectory() as temp_dir: + migration_dir = Path(temp_dir) / "migrations" + + config = AsyncpgConfig( + pool_config={"dsn": "postgresql://localhost/test"}, + migration_config={"script_location": str(migration_dir)}, + ) + + with patch.object(AsyncMigrationCommands, "init", return_value=None) as mock_init: + await config.init_migrations(directory=str(migration_dir)) + + mock_init.assert_called_once_with(str(migration_dir), True) + + +def test_fix_migrations_default_parameters_sync() -> None: + """Test that fix_migrations() uses correct default parameter values for sync configs.""" + with tempfile.TemporaryDirectory() as temp_dir: + migration_dir = Path(temp_dir) / "migrations" + temp_db = str(Path(temp_dir) / "test.db") + + config = SqliteConfig( + pool_config={"database": temp_db}, + migration_config={"script_location": str(migration_dir)}, + ) + + with patch.object(SyncMigrationCommands, "fix", return_value=None) as mock_fix: + config.fix_migrations() + + mock_fix.assert_called_once_with(False, True, False) + + +@pytest.mark.asyncio +async def test_fix_migrations_default_parameters_async() -> None: + """Test that fix_migrations() uses correct default parameter values for async configs.""" + with tempfile.TemporaryDirectory() as temp_dir: + migration_dir = Path(temp_dir) / "migrations" + + config = AsyncpgConfig( + pool_config={"dsn": "postgresql://localhost/test"}, + migration_config={"script_location": str(migration_dir)}, + ) + + with patch.object(AsyncMigrationCommands, "fix", return_value=None) as mock_fix: + await config.fix_migrations() + + mock_fix.assert_called_once_with(False, True, False) From 5c9111871b7565ce4e5267f68f61913c0ee1d2b4 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Thu, 30 Oct 2025 23:55:52 +0000 Subject: [PATCH 3/4] feat: add migration convenience methods to database config classes --- AGENTS.md | 88 ++++++++++ docs/changelog.rst | 66 ++++++++ docs/guides/migrations/hybrid-versioning.md | 86 ++++++++++ docs/usage/migrations.rst | 150 ++++++++++++++++++ sqlspec/config.py | 82 +++++----- .../test_asyncpg/test_migrations.py | 15 +- .../test_config/test_migration_methods.py | 87 ++++------ 7 files changed, 463 insertions(+), 111 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index c64e8f46..a1888235 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -183,6 +183,94 @@ SQLSpec is a type-safe SQL query mapper designed for minimal abstraction between - **Parameter Style Abstraction**: Automatically converts between different parameter styles (?, :name, $1, %s) - **Type Safety**: Supports mapping results to Pydantic, msgspec, attrs, and other typed models - **Single-Pass Processing**: Parse once → transform once → validate once - SQL object is single source of truth +- **Abstract Methods with Concrete Implementations**: Protocol defines abstract methods, base classes provide concrete sync/async implementations + +### Protocol Abstract Methods Pattern + +When adding methods that need to support both sync and async configurations, use this pattern: + +**Step 1: Define abstract method in protocol** + +```python +from abc import abstractmethod +from typing import Awaitable + +class DatabaseConfigProtocol(Protocol): + is_async: ClassVar[bool] # Set by base classes + + @abstractmethod + def migrate_up( + self, revision: str = "head", allow_missing: bool = False, auto_sync: bool = True, dry_run: bool = False + ) -> "Awaitable[None] | None": + """Apply database migrations up to specified revision. + + Args: + revision: Target revision or "head" for latest. + allow_missing: Allow out-of-order migrations. + auto_sync: Auto-reconcile renamed migrations. + dry_run: Show what would be done without applying. + """ + raise NotImplementedError +``` + +**Step 2: Implement in sync base class (no async/await)** + +```python +class NoPoolSyncConfig(DatabaseConfigProtocol): + is_async: ClassVar[bool] = False + + def migrate_up( + self, revision: str = "head", allow_missing: bool = False, auto_sync: bool = True, dry_run: bool = False + ) -> None: + """Apply database migrations up to specified revision.""" + commands = self._ensure_migration_commands() + commands.upgrade(revision, allow_missing, auto_sync, dry_run) +``` + +**Step 3: Implement in async base class (with async/await)** + +```python +class NoPoolAsyncConfig(DatabaseConfigProtocol): + is_async: ClassVar[bool] = True + + async def migrate_up( + self, revision: str = "head", allow_missing: bool = False, auto_sync: bool = True, dry_run: bool = False + ) -> None: + """Apply database migrations up to specified revision.""" + commands = cast("AsyncMigrationCommands", self._ensure_migration_commands()) + await commands.upgrade(revision, allow_missing, auto_sync, dry_run) +``` + +**Key principles:** + +- Protocol defines the interface with union return type (`Awaitable[T] | T`) +- Sync base classes implement without `async def` or `await` +- Async base classes implement with `async def` and `await` +- Each base class has concrete implementation - no need for child classes to override +- Use `cast()` to narrow types when delegating to command objects +- All 4 base classes (NoPoolSyncConfig, NoPoolAsyncConfig, SyncDatabaseConfig, AsyncDatabaseConfig) implement the same way + +**Benefits:** + +- Single source of truth (protocol) for API contract +- Each base class provides complete implementation +- Child adapter classes (AsyncpgConfig, SqliteConfig, etc.) inherit working methods automatically +- Type checkers understand sync vs async based on `is_async` class variable +- No code duplication across adapters + +**When to use:** + +- Adding convenience methods that delegate to external command objects +- Methods that need identical behavior across all adapters +- Operations that differ only in sync vs async execution +- Any protocol method where behavior is determined by sync/async mode + +**Anti-patterns to avoid:** + +- Don't use runtime `if self.is_async:` checks in a single implementation +- Don't make protocol methods concrete (always use `@abstractmethod`) +- Don't duplicate logic across the 4 base classes +- Don't forget to update all 4 base classes when adding new methods ### Database Connection Flow diff --git a/docs/changelog.rst b/docs/changelog.rst index d7431477..9a9a9cfb 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -10,6 +10,72 @@ SQLSpec Changelog Recent Updates ============== +Migration Convenience Methods on Config Classes +------------------------------------------------ + +Added migration methods directly to database configuration classes, eliminating the need to instantiate separate command objects. + +**What's New:** + +All database configs (both sync and async) now provide migration methods: + +- ``migrate_up()`` / ``upgrade()`` - Apply migrations up to a revision +- ``migrate_down()`` / ``downgrade()`` - Rollback migrations +- ``get_current_migration()`` - Check current version +- ``create_migration()`` - Create new migration file +- ``init_migrations()`` - Initialize migrations directory +- ``stamp_migration()`` - Stamp database to specific revision +- ``fix_migrations()`` - Convert timestamp to sequential migrations + +**Before (verbose):** + +.. code-block:: python + + from sqlspec.adapters.asyncpg import AsyncpgConfig + from sqlspec.migrations.commands import AsyncMigrationCommands + + config = AsyncpgConfig( + pool_config={"dsn": "postgresql://..."}, + migration_config={"script_location": "migrations"} + ) + + commands = AsyncMigrationCommands(config) + await commands.upgrade("head") + +**After (recommended):** + +.. code-block:: python + + from sqlspec.adapters.asyncpg import AsyncpgConfig + + config = AsyncpgConfig( + pool_config={"dsn": "postgresql://..."}, + migration_config={"script_location": "migrations"} + ) + + await config.upgrade("head") + +**Key Benefits:** + +- Simpler API - no need to import and instantiate command classes +- Works with both sync and async adapters +- Full backward compatibility - command classes still available +- Cleaner test fixtures and deployment scripts + +**Async Adapters** (AsyncPG, Asyncmy, Aiosqlite, Psqlpy): + +.. code-block:: python + + await config.migrate_up("head") + await config.create_migration("add users") + +**Sync Adapters** (SQLite, DuckDB): + +.. code-block:: python + + config.migrate_up("head") # No await needed + config.create_migration("add users") + SQL Loader Graceful Error Handling ----------------------------------- diff --git a/docs/guides/migrations/hybrid-versioning.md b/docs/guides/migrations/hybrid-versioning.md index ea677060..7e617f55 100644 --- a/docs/guides/migrations/hybrid-versioning.md +++ b/docs/guides/migrations/hybrid-versioning.md @@ -272,6 +272,92 @@ Always preview before applying: sqlspec --config myapp.config fix --dry-run ``` +## Programmatic API + +For Python-based migration automation, use the config method directly instead of CLI commands: + +### Async Configuration + +```python +from sqlspec.adapters.asyncpg import AsyncpgConfig + +config = AsyncpgConfig( + pool_config={"dsn": "postgresql://user:pass@localhost/mydb"}, + migration_config={ + "enabled": True, + "script_location": "migrations", + } +) + +# Preview conversions +await config.fix_migrations(dry_run=True) + +# Apply conversions (auto-approve) +await config.fix_migrations(dry_run=False, update_database=True, yes=True) + +# Files only (skip database update) +await config.fix_migrations(dry_run=False, update_database=False, yes=True) +``` + +### Sync Configuration + +```python +from sqlspec.adapters.sqlite import SqliteConfig + +config = SqliteConfig( + pool_config={"database": "myapp.db"}, + migration_config={ + "enabled": True, + "script_location": "migrations", + } +) + +# Preview conversions (no await needed) +config.fix_migrations(dry_run=True) + +# Apply conversions (auto-approve) +config.fix_migrations(dry_run=False, update_database=True, yes=True) + +# Files only (skip database update) +config.fix_migrations(dry_run=False, update_database=False, yes=True) +``` + +### Use Cases + +The programmatic API is useful for: + +- **Custom deployment scripts** - Integrate migration fixing into deployment automation +- **Testing workflows** - Automate migration testing in CI/CD pipelines +- **Framework integrations** - Build migration support into web framework startup hooks +- **Monitoring tools** - Track migration conversions programmatically + +### Example: Custom Deployment Script + +```python +import asyncio +from sqlspec.adapters.asyncpg import AsyncpgConfig + +async def deploy(): + config = AsyncpgConfig( + pool_config={"dsn": "postgresql://..."}, + migration_config={"script_location": "migrations"} + ) + + # Step 1: Convert migrations to sequential + print("Converting migrations to sequential format...") + await config.fix_migrations(dry_run=False, update_database=True, yes=True) + + # Step 2: Apply all pending migrations + print("Applying migrations...") + await config.upgrade("head") + + # Step 3: Verify current version + current = await config.get_current_migration(verbose=True) + print(f"Deployed to version: {current}") + +asyncio.run(deploy()) +``` + ## Best Practices ### 1. Always Use Version Control diff --git a/docs/usage/migrations.rst b/docs/usage/migrations.rst index 7e549877..82acfb69 100644 --- a/docs/usage/migrations.rst +++ b/docs/usage/migrations.rst @@ -29,6 +29,156 @@ Initialize Migrations # Apply migrations sqlspec --config myapp.config upgrade +Programmatic API (Recommended) +=============================== + +SQLSpec provides migration convenience methods directly on config classes, eliminating +the need to instantiate separate command objects. + +Async Adapters +-------------- + +For async adapters (AsyncPG, Asyncmy, Aiosqlite, Psqlpy), migration methods return awaitables: + +.. code-block:: python + + from sqlspec.adapters.asyncpg import AsyncpgConfig + + config = AsyncpgConfig( + pool_config={"dsn": "postgresql://user:pass@localhost/mydb"}, + migration_config={ + "enabled": True, + "script_location": "migrations", + } + ) + + # Apply migrations + await config.migrate_up("head") + # Or use the alias + await config.upgrade("head") + + # Rollback one revision + await config.migrate_down("-1") + # Or use the alias + await config.downgrade("-1") + + # Check current version + current = await config.get_current_migration(verbose=True) + print(current) + + # Create new migration + await config.create_migration("add users table", file_type="sql") + + # Initialize migrations directory + await config.init_migrations() + + # Stamp database to specific revision + await config.stamp_migration("0003") + + # Convert timestamp to sequential migrations + await config.fix_migrations(dry_run=False, update_database=True, yes=True) + +Sync Adapters +------------- + +For sync adapters (SQLite, DuckDB), migration methods execute immediately without await: + +.. code-block:: python + + from sqlspec.adapters.sqlite import SqliteConfig + + config = SqliteConfig( + pool_config={"database": "myapp.db"}, + migration_config={ + "enabled": True, + "script_location": "migrations", + } + ) + + # Apply migrations (no await needed) + config.migrate_up("head") + # Or use the alias + config.upgrade("head") + + # Rollback one revision + config.migrate_down("-1") + # Or use the alias + config.downgrade("-1") + + # Check current version + current = config.get_current_migration(verbose=True) + print(current) + + # Create new migration + config.create_migration("add users table", file_type="sql") + + # Initialize migrations directory + config.init_migrations() + + # Stamp database to specific revision + config.stamp_migration("0003") + + # Convert timestamp to sequential migrations + config.fix_migrations(dry_run=False, update_database=True, yes=True) + +Available Methods +----------------- + +All database configs (sync and async) provide these migration methods: + +``migrate_up(revision="head", allow_missing=False, auto_sync=True, dry_run=False)`` + Apply migrations up to the specified revision. + + Also available as ``upgrade()`` alias. + +``migrate_down(revision="-1", dry_run=False)`` + Rollback migrations down to the specified revision. + + Also available as ``downgrade()`` alias. + +``get_current_migration(verbose=False)`` + Get the current migration version. + +``create_migration(message, file_type="sql")`` + Create a new migration file. + +``init_migrations(directory=None, package=None)`` + Initialize the migrations directory structure. + +``stamp_migration(revision)`` + Stamp the database to a specific revision without running migrations. + +``fix_migrations(dry_run=False, update_database=True, yes=False)`` + Convert timestamp migrations to sequential format. + +Command Classes (Advanced) +--------------------------- + +For advanced use cases requiring custom logic, you can still use command classes directly: + +.. code-block:: python + + from sqlspec.migrations.commands import AsyncMigrationCommands, SyncMigrationCommands + from sqlspec.adapters.asyncpg import AsyncpgConfig + + config = AsyncpgConfig( + pool_config={"dsn": "postgresql://..."}, + migration_config={"script_location": "migrations"} + ) + + # Create commands instance + commands = AsyncMigrationCommands(config) + + # Use commands directly + await commands.upgrade("head") + +This approach is useful when: + +- Building custom migration runners +- Implementing migration lifecycle hooks +- Integrating with third-party workflow tools +- Need fine-grained control over migration execution + Configuration ============= diff --git a/sqlspec/config.py b/sqlspec/config.py index c7b26ea2..b1ddfb8e 100644 --- a/sqlspec/config.py +++ b/sqlspec/config.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from collections.abc import Callable from pathlib import Path -from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, TypeVar +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, TypeVar, cast from typing_extensions import NotRequired, TypedDict @@ -658,7 +658,7 @@ def migrate_up( dry_run: Show what would be done without applying. """ commands = self._ensure_migration_commands() - return commands.upgrade(revision, allow_missing, auto_sync, dry_run) + commands.upgrade(revision, allow_missing, auto_sync, dry_run) def migrate_down(self, revision: str = "-1", *, dry_run: bool = False) -> None: """Apply database migrations down to specified revision. @@ -668,7 +668,7 @@ def migrate_down(self, revision: str = "-1", *, dry_run: bool = False) -> None: dry_run: Show what would be done without applying. """ commands = self._ensure_migration_commands() - return commands.downgrade(revision, dry_run=dry_run) + commands.downgrade(revision, dry_run=dry_run) def get_current_migration(self, verbose: bool = False) -> "str | None": """Get the current migration version. @@ -679,7 +679,7 @@ def get_current_migration(self, verbose: bool = False) -> "str | None": Returns: Current migration version or None if no migrations applied. """ - commands = self._ensure_migration_commands() + commands = cast("SyncMigrationCommands", self._ensure_migration_commands()) return commands.current(verbose=verbose) def create_migration(self, message: str, file_type: str = "sql") -> None: @@ -690,7 +690,7 @@ def create_migration(self, message: str, file_type: str = "sql") -> None: file_type: Type of migration file to create ('sql' or 'py'). """ commands = self._ensure_migration_commands() - return commands.revision(message, file_type) + commands.revision(message, file_type) def init_migrations(self, directory: "str | None" = None, package: bool = True) -> None: """Initialize migration directory structure. @@ -705,7 +705,7 @@ def init_migrations(self, directory: "str | None" = None, package: bool = True) commands = self._ensure_migration_commands() assert directory is not None - return commands.init(directory, package) + commands.init(directory, package) def stamp_migration(self, revision: str) -> None: """Mark database as being at a specific revision without running migrations. @@ -714,7 +714,7 @@ def stamp_migration(self, revision: str) -> None: revision: The revision to stamp. """ commands = self._ensure_migration_commands() - return commands.stamp(revision) + commands.stamp(revision) def fix_migrations(self, dry_run: bool = False, update_database: bool = True, yes: bool = False) -> None: """Convert timestamp migrations to sequential format. @@ -725,7 +725,7 @@ def fix_migrations(self, dry_run: bool = False, update_database: bool = True, ye yes: Skip confirmation prompt. """ commands = self._ensure_migration_commands() - return commands.fix(dry_run, update_database, yes) + commands.fix(dry_run, update_database, yes) class NoPoolAsyncConfig(DatabaseConfigProtocol[ConnectionT, None, DriverT]): @@ -796,8 +796,8 @@ async def migrate_up( auto_sync: Auto-reconcile renamed migrations. dry_run: Show what would be done without applying. """ - commands = self._ensure_migration_commands() - return await commands.upgrade(revision, allow_missing, auto_sync, dry_run) + commands = cast("AsyncMigrationCommands", self._ensure_migration_commands()) + await commands.upgrade(revision, allow_missing, auto_sync, dry_run) async def migrate_down(self, revision: str = "-1", *, dry_run: bool = False) -> None: """Apply database migrations down to specified revision. @@ -806,8 +806,8 @@ async def migrate_down(self, revision: str = "-1", *, dry_run: bool = False) -> revision: Target revision, "-1" for one step back, or "base" for all migrations. dry_run: Show what would be done without applying. """ - commands = self._ensure_migration_commands() - return await commands.downgrade(revision, dry_run=dry_run) + commands = cast("AsyncMigrationCommands", self._ensure_migration_commands()) + await commands.downgrade(revision, dry_run=dry_run) async def get_current_migration(self, verbose: bool = False) -> "str | None": """Get the current migration version. @@ -818,7 +818,7 @@ async def get_current_migration(self, verbose: bool = False) -> "str | None": Returns: Current migration version or None if no migrations applied. """ - commands = self._ensure_migration_commands() + commands = cast("AsyncMigrationCommands", self._ensure_migration_commands()) return await commands.current(verbose=verbose) async def create_migration(self, message: str, file_type: str = "sql") -> None: @@ -828,8 +828,8 @@ async def create_migration(self, message: str, file_type: str = "sql") -> None: message: Description for the migration. file_type: Type of migration file to create ('sql' or 'py'). """ - commands = self._ensure_migration_commands() - return await commands.revision(message, file_type) + commands = cast("AsyncMigrationCommands", self._ensure_migration_commands()) + await commands.revision(message, file_type) async def init_migrations(self, directory: "str | None" = None, package: bool = True) -> None: """Initialize migration directory structure. @@ -842,9 +842,9 @@ async def init_migrations(self, directory: "str | None" = None, package: bool = migration_config = self.migration_config or {} directory = str(migration_config.get("script_location") or "migrations") - commands = self._ensure_migration_commands() + commands = cast("AsyncMigrationCommands", self._ensure_migration_commands()) assert directory is not None - return await commands.init(directory, package) + await commands.init(directory, package) async def stamp_migration(self, revision: str) -> None: """Mark database as being at a specific revision without running migrations. @@ -852,8 +852,8 @@ async def stamp_migration(self, revision: str) -> None: Args: revision: The revision to stamp. """ - commands = self._ensure_migration_commands() - return await commands.stamp(revision) + commands = cast("AsyncMigrationCommands", self._ensure_migration_commands()) + await commands.stamp(revision) async def fix_migrations(self, dry_run: bool = False, update_database: bool = True, yes: bool = False) -> None: """Convert timestamp migrations to sequential format. @@ -863,8 +863,8 @@ async def fix_migrations(self, dry_run: bool = False, update_database: bool = Tr update_database: Update migration records in database. yes: Skip confirmation prompt. """ - commands = self._ensure_migration_commands() - return await commands.fix(dry_run, update_database, yes) + commands = cast("AsyncMigrationCommands", self._ensure_migration_commands()) + await commands.fix(dry_run, update_database, yes) class SyncDatabaseConfig(DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]): @@ -959,7 +959,7 @@ def migrate_up( dry_run: Show what would be done without applying. """ commands = self._ensure_migration_commands() - return commands.upgrade(revision, allow_missing, auto_sync, dry_run) + commands.upgrade(revision, allow_missing, auto_sync, dry_run) def migrate_down(self, revision: str = "-1", *, dry_run: bool = False) -> None: """Apply database migrations down to specified revision. @@ -969,7 +969,7 @@ def migrate_down(self, revision: str = "-1", *, dry_run: bool = False) -> None: dry_run: Show what would be done without applying. """ commands = self._ensure_migration_commands() - return commands.downgrade(revision, dry_run=dry_run) + commands.downgrade(revision, dry_run=dry_run) def get_current_migration(self, verbose: bool = False) -> "str | None": """Get the current migration version. @@ -980,7 +980,7 @@ def get_current_migration(self, verbose: bool = False) -> "str | None": Returns: Current migration version or None if no migrations applied. """ - commands = self._ensure_migration_commands() + commands = cast("SyncMigrationCommands", self._ensure_migration_commands()) return commands.current(verbose=verbose) def create_migration(self, message: str, file_type: str = "sql") -> None: @@ -991,7 +991,7 @@ def create_migration(self, message: str, file_type: str = "sql") -> None: file_type: Type of migration file to create ('sql' or 'py'). """ commands = self._ensure_migration_commands() - return commands.revision(message, file_type) + commands.revision(message, file_type) def init_migrations(self, directory: "str | None" = None, package: bool = True) -> None: """Initialize migration directory structure. @@ -1006,7 +1006,7 @@ def init_migrations(self, directory: "str | None" = None, package: bool = True) commands = self._ensure_migration_commands() assert directory is not None - return commands.init(directory, package) + commands.init(directory, package) def stamp_migration(self, revision: str) -> None: """Mark database as being at a specific revision without running migrations. @@ -1015,7 +1015,7 @@ def stamp_migration(self, revision: str) -> None: revision: The revision to stamp. """ commands = self._ensure_migration_commands() - return commands.stamp(revision) + commands.stamp(revision) def fix_migrations(self, dry_run: bool = False, update_database: bool = True, yes: bool = False) -> None: """Convert timestamp migrations to sequential format. @@ -1026,7 +1026,7 @@ def fix_migrations(self, dry_run: bool = False, update_database: bool = True, ye yes: Skip confirmation prompt. """ commands = self._ensure_migration_commands() - return commands.fix(dry_run, update_database, yes) + commands.fix(dry_run, update_database, yes) class AsyncDatabaseConfig(DatabaseConfigProtocol[ConnectionT, PoolT, DriverT]): @@ -1122,8 +1122,8 @@ async def migrate_up( auto_sync: Auto-reconcile renamed migrations. dry_run: Show what would be done without applying. """ - commands = self._ensure_migration_commands() - return await commands.upgrade(revision, allow_missing, auto_sync, dry_run) + commands = cast("AsyncMigrationCommands", self._ensure_migration_commands()) + await commands.upgrade(revision, allow_missing, auto_sync, dry_run) async def migrate_down(self, revision: str = "-1", *, dry_run: bool = False) -> None: """Apply database migrations down to specified revision. @@ -1132,8 +1132,8 @@ async def migrate_down(self, revision: str = "-1", *, dry_run: bool = False) -> revision: Target revision, "-1" for one step back, or "base" for all migrations. dry_run: Show what would be done without applying. """ - commands = self._ensure_migration_commands() - return await commands.downgrade(revision, dry_run=dry_run) + commands = cast("AsyncMigrationCommands", self._ensure_migration_commands()) + await commands.downgrade(revision, dry_run=dry_run) async def get_current_migration(self, verbose: bool = False) -> "str | None": """Get the current migration version. @@ -1144,7 +1144,7 @@ async def get_current_migration(self, verbose: bool = False) -> "str | None": Returns: Current migration version or None if no migrations applied. """ - commands = self._ensure_migration_commands() + commands = cast("AsyncMigrationCommands", self._ensure_migration_commands()) return await commands.current(verbose=verbose) async def create_migration(self, message: str, file_type: str = "sql") -> None: @@ -1154,8 +1154,8 @@ async def create_migration(self, message: str, file_type: str = "sql") -> None: message: Description for the migration. file_type: Type of migration file to create ('sql' or 'py'). """ - commands = self._ensure_migration_commands() - return await commands.revision(message, file_type) + commands = cast("AsyncMigrationCommands", self._ensure_migration_commands()) + await commands.revision(message, file_type) async def init_migrations(self, directory: "str | None" = None, package: bool = True) -> None: """Initialize migration directory structure. @@ -1168,9 +1168,9 @@ async def init_migrations(self, directory: "str | None" = None, package: bool = migration_config = self.migration_config or {} directory = str(migration_config.get("script_location") or "migrations") - commands = self._ensure_migration_commands() + commands = cast("AsyncMigrationCommands", self._ensure_migration_commands()) assert directory is not None - return await commands.init(directory, package) + await commands.init(directory, package) async def stamp_migration(self, revision: str) -> None: """Mark database as being at a specific revision without running migrations. @@ -1178,8 +1178,8 @@ async def stamp_migration(self, revision: str) -> None: Args: revision: The revision to stamp. """ - commands = self._ensure_migration_commands() - return await commands.stamp(revision) + commands = cast("AsyncMigrationCommands", self._ensure_migration_commands()) + await commands.stamp(revision) async def fix_migrations(self, dry_run: bool = False, update_database: bool = True, yes: bool = False) -> None: """Convert timestamp migrations to sequential format. @@ -1189,5 +1189,5 @@ async def fix_migrations(self, dry_run: bool = False, update_database: bool = Tr update_database: Update migration records in database. yes: Skip confirmation prompt. """ - commands = self._ensure_migration_commands() - return await commands.fix(dry_run, update_database, yes) + commands = cast("AsyncMigrationCommands", self._ensure_migration_commands()) + await commands.fix(dry_run, update_database, yes) diff --git a/tests/integration/test_adapters/test_asyncpg/test_migrations.py b/tests/integration/test_adapters/test_asyncpg/test_migrations.py index e6a1b060..d57ac45d 100644 --- a/tests/integration/test_adapters/test_asyncpg/test_migrations.py +++ b/tests/integration/test_adapters/test_asyncpg/test_migrations.py @@ -569,10 +569,7 @@ async def test_asyncpg_config_create_migration_method(postgres_service: Postgres "password": postgres_service.password, "database": postgres_service.database, }, - migration_config={ - "script_location": str(migration_dir), - "version_table_name": "sqlspec_migrations_create", - }, + migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations_create"}, ) try: @@ -603,10 +600,7 @@ async def test_asyncpg_config_stamp_migration_method(postgres_service: PostgresS "password": postgres_service.password, "database": postgres_service.database, }, - migration_config={ - "script_location": str(migration_dir), - "version_table_name": "sqlspec_migrations_stamp", - }, + migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations_stamp"}, ) try: @@ -655,10 +649,7 @@ async def test_asyncpg_config_fix_migrations_dry_run(postgres_service: PostgresS "password": postgres_service.password, "database": postgres_service.database, }, - migration_config={ - "script_location": str(migration_dir), - "version_table_name": "sqlspec_migrations_fix", - }, + migration_config={"script_location": str(migration_dir), "version_table_name": "sqlspec_migrations_fix"}, ) try: diff --git a/tests/unit/test_config/test_migration_methods.py b/tests/unit/test_config/test_migration_methods.py index 5f3eea35..a29f7d19 100644 --- a/tests/unit/test_config/test_migration_methods.py +++ b/tests/unit/test_config/test_migration_methods.py @@ -18,8 +18,7 @@ import tempfile from pathlib import Path -from typing import Any -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import patch import pytest @@ -82,8 +81,7 @@ def test_sqlite_config_migrate_up_calls_commands() -> None: temp_db = str(Path(temp_dir) / "test.db") config = SqliteConfig( - pool_config={"database": temp_db}, - migration_config={"script_location": str(migration_dir)}, + pool_config={"database": temp_db}, migration_config={"script_location": str(migration_dir)} ) with patch.object(SyncMigrationCommands, "upgrade", return_value=None) as mock_upgrade: @@ -99,8 +97,7 @@ def test_sqlite_config_migrate_down_calls_commands() -> None: temp_db = str(Path(temp_dir) / "test.db") config = SqliteConfig( - pool_config={"database": temp_db}, - migration_config={"script_location": str(migration_dir)}, + pool_config={"database": temp_db}, migration_config={"script_location": str(migration_dir)} ) with patch.object(SyncMigrationCommands, "downgrade", return_value=None) as mock_downgrade: @@ -116,8 +113,7 @@ def test_sqlite_config_get_current_migration_calls_commands() -> None: temp_db = str(Path(temp_dir) / "test.db") config = SqliteConfig( - pool_config={"database": temp_db}, - migration_config={"script_location": str(migration_dir)}, + pool_config={"database": temp_db}, migration_config={"script_location": str(migration_dir)} ) with patch.object(SyncMigrationCommands, "current", return_value="0001") as mock_current: @@ -134,8 +130,7 @@ def test_sqlite_config_create_migration_calls_commands() -> None: temp_db = str(Path(temp_dir) / "test.db") config = SqliteConfig( - pool_config={"database": temp_db}, - migration_config={"script_location": str(migration_dir)}, + pool_config={"database": temp_db}, migration_config={"script_location": str(migration_dir)} ) with patch.object(SyncMigrationCommands, "revision", return_value=None) as mock_revision: @@ -151,8 +146,7 @@ def test_sqlite_config_init_migrations_calls_commands() -> None: temp_db = str(Path(temp_dir) / "test.db") config = SqliteConfig( - pool_config={"database": temp_db}, - migration_config={"script_location": str(migration_dir)}, + pool_config={"database": temp_db}, migration_config={"script_location": str(migration_dir)} ) with patch.object(SyncMigrationCommands, "init", return_value=None) as mock_init: @@ -168,8 +162,7 @@ def test_sqlite_config_init_migrations_uses_default_directory() -> None: temp_db = str(Path(temp_dir) / "test.db") config = SqliteConfig( - pool_config={"database": temp_db}, - migration_config={"script_location": str(migration_dir)}, + pool_config={"database": temp_db}, migration_config={"script_location": str(migration_dir)} ) with patch.object(SyncMigrationCommands, "init", return_value=None) as mock_init: @@ -185,8 +178,7 @@ def test_sqlite_config_stamp_migration_calls_commands() -> None: temp_db = str(Path(temp_dir) / "test.db") config = SqliteConfig( - pool_config={"database": temp_db}, - migration_config={"script_location": str(migration_dir)}, + pool_config={"database": temp_db}, migration_config={"script_location": str(migration_dir)} ) with patch.object(SyncMigrationCommands, "stamp", return_value=None) as mock_stamp: @@ -202,8 +194,7 @@ def test_sqlite_config_fix_migrations_calls_commands() -> None: temp_db = str(Path(temp_dir) / "test.db") config = SqliteConfig( - pool_config={"database": temp_db}, - migration_config={"script_location": str(migration_dir)}, + pool_config={"database": temp_db}, migration_config={"script_location": str(migration_dir)} ) with patch.object(SyncMigrationCommands, "fix", return_value=None) as mock_fix: @@ -219,8 +210,7 @@ async def test_asyncpg_config_migrate_up_calls_commands() -> None: migration_dir = Path(temp_dir) / "migrations" config = AsyncpgConfig( - pool_config={"dsn": "postgresql://localhost/test"}, - migration_config={"script_location": str(migration_dir)}, + pool_config={"dsn": "postgresql://localhost/test"}, migration_config={"script_location": str(migration_dir)} ) with patch.object(AsyncMigrationCommands, "upgrade", return_value=None) as mock_upgrade: @@ -236,8 +226,7 @@ async def test_asyncpg_config_migrate_down_calls_commands() -> None: migration_dir = Path(temp_dir) / "migrations" config = AsyncpgConfig( - pool_config={"dsn": "postgresql://localhost/test"}, - migration_config={"script_location": str(migration_dir)}, + pool_config={"dsn": "postgresql://localhost/test"}, migration_config={"script_location": str(migration_dir)} ) with patch.object(AsyncMigrationCommands, "downgrade", return_value=None) as mock_downgrade: @@ -253,8 +242,7 @@ async def test_asyncpg_config_get_current_migration_calls_commands() -> None: migration_dir = Path(temp_dir) / "migrations" config = AsyncpgConfig( - pool_config={"dsn": "postgresql://localhost/test"}, - migration_config={"script_location": str(migration_dir)}, + pool_config={"dsn": "postgresql://localhost/test"}, migration_config={"script_location": str(migration_dir)} ) with patch.object(AsyncMigrationCommands, "current", return_value="0002") as mock_current: @@ -271,8 +259,7 @@ async def test_asyncpg_config_create_migration_calls_commands() -> None: migration_dir = Path(temp_dir) / "migrations" config = AsyncpgConfig( - pool_config={"dsn": "postgresql://localhost/test"}, - migration_config={"script_location": str(migration_dir)}, + pool_config={"dsn": "postgresql://localhost/test"}, migration_config={"script_location": str(migration_dir)} ) with patch.object(AsyncMigrationCommands, "revision", return_value=None) as mock_revision: @@ -288,8 +275,7 @@ async def test_asyncpg_config_init_migrations_calls_commands() -> None: migration_dir = Path(temp_dir) / "migrations" config = AsyncpgConfig( - pool_config={"dsn": "postgresql://localhost/test"}, - migration_config={"script_location": str(migration_dir)}, + pool_config={"dsn": "postgresql://localhost/test"}, migration_config={"script_location": str(migration_dir)} ) with patch.object(AsyncMigrationCommands, "init", return_value=None) as mock_init: @@ -305,8 +291,7 @@ async def test_asyncpg_config_stamp_migration_calls_commands() -> None: migration_dir = Path(temp_dir) / "migrations" config = AsyncpgConfig( - pool_config={"dsn": "postgresql://localhost/test"}, - migration_config={"script_location": str(migration_dir)}, + pool_config={"dsn": "postgresql://localhost/test"}, migration_config={"script_location": str(migration_dir)} ) with patch.object(AsyncMigrationCommands, "stamp", return_value=None) as mock_stamp: @@ -322,8 +307,7 @@ async def test_asyncpg_config_fix_migrations_calls_commands() -> None: migration_dir = Path(temp_dir) / "migrations" config = AsyncpgConfig( - pool_config={"dsn": "postgresql://localhost/test"}, - migration_config={"script_location": str(migration_dir)}, + pool_config={"dsn": "postgresql://localhost/test"}, migration_config={"script_location": str(migration_dir)} ) with patch.object(AsyncMigrationCommands, "fix", return_value=None) as mock_fix: @@ -338,8 +322,7 @@ def test_duckdb_pooled_config_migrate_up_calls_commands() -> None: migration_dir = Path(temp_dir) / "migrations" config = DuckDBConfig( - pool_config={"database": ":memory:"}, - migration_config={"script_location": str(migration_dir)}, + pool_config={"database": ":memory:"}, migration_config={"script_location": str(migration_dir)} ) with patch.object(SyncMigrationCommands, "upgrade", return_value=None) as mock_upgrade: @@ -354,8 +337,7 @@ def test_duckdb_pooled_config_get_current_migration_calls_commands() -> None: migration_dir = Path(temp_dir) / "migrations" config = DuckDBConfig( - pool_config={"database": ":memory:"}, - migration_config={"script_location": str(migration_dir)}, + pool_config={"database": ":memory:"}, migration_config={"script_location": str(migration_dir)} ) with patch.object(SyncMigrationCommands, "current", return_value=None) as mock_current: @@ -373,8 +355,7 @@ async def test_aiosqlite_async_config_migrate_up_calls_commands() -> None: temp_db = str(Path(temp_dir) / "test.db") config = AiosqliteConfig( - pool_config={"database": temp_db}, - migration_config={"script_location": str(migration_dir)}, + pool_config={"database": temp_db}, migration_config={"script_location": str(migration_dir)} ) with patch.object(AsyncMigrationCommands, "upgrade", return_value=None) as mock_upgrade: @@ -390,8 +371,7 @@ def test_migrate_up_default_parameters_sync() -> None: temp_db = str(Path(temp_dir) / "test.db") config = SqliteConfig( - pool_config={"database": temp_db}, - migration_config={"script_location": str(migration_dir)}, + pool_config={"database": temp_db}, migration_config={"script_location": str(migration_dir)} ) with patch.object(SyncMigrationCommands, "upgrade", return_value=None) as mock_upgrade: @@ -407,8 +387,7 @@ async def test_migrate_up_default_parameters_async() -> None: migration_dir = Path(temp_dir) / "migrations" config = AsyncpgConfig( - pool_config={"dsn": "postgresql://localhost/test"}, - migration_config={"script_location": str(migration_dir)}, + pool_config={"dsn": "postgresql://localhost/test"}, migration_config={"script_location": str(migration_dir)} ) with patch.object(AsyncMigrationCommands, "upgrade", return_value=None) as mock_upgrade: @@ -424,8 +403,7 @@ def test_migrate_down_default_parameters_sync() -> None: temp_db = str(Path(temp_dir) / "test.db") config = SqliteConfig( - pool_config={"database": temp_db}, - migration_config={"script_location": str(migration_dir)}, + pool_config={"database": temp_db}, migration_config={"script_location": str(migration_dir)} ) with patch.object(SyncMigrationCommands, "downgrade", return_value=None) as mock_downgrade: @@ -441,8 +419,7 @@ async def test_migrate_down_default_parameters_async() -> None: migration_dir = Path(temp_dir) / "migrations" config = AsyncpgConfig( - pool_config={"dsn": "postgresql://localhost/test"}, - migration_config={"script_location": str(migration_dir)}, + pool_config={"dsn": "postgresql://localhost/test"}, migration_config={"script_location": str(migration_dir)} ) with patch.object(AsyncMigrationCommands, "downgrade", return_value=None) as mock_downgrade: @@ -458,8 +435,7 @@ def test_create_migration_default_file_type_sync() -> None: temp_db = str(Path(temp_dir) / "test.db") config = SqliteConfig( - pool_config={"database": temp_db}, - migration_config={"script_location": str(migration_dir)}, + pool_config={"database": temp_db}, migration_config={"script_location": str(migration_dir)} ) with patch.object(SyncMigrationCommands, "revision", return_value=None) as mock_revision: @@ -475,8 +451,7 @@ async def test_create_migration_default_file_type_async() -> None: migration_dir = Path(temp_dir) / "migrations" config = AsyncpgConfig( - pool_config={"dsn": "postgresql://localhost/test"}, - migration_config={"script_location": str(migration_dir)}, + pool_config={"dsn": "postgresql://localhost/test"}, migration_config={"script_location": str(migration_dir)} ) with patch.object(AsyncMigrationCommands, "revision", return_value=None) as mock_revision: @@ -492,8 +467,7 @@ def test_init_migrations_default_package_sync() -> None: temp_db = str(Path(temp_dir) / "test.db") config = SqliteConfig( - pool_config={"database": temp_db}, - migration_config={"script_location": str(migration_dir)}, + pool_config={"database": temp_db}, migration_config={"script_location": str(migration_dir)} ) with patch.object(SyncMigrationCommands, "init", return_value=None) as mock_init: @@ -509,8 +483,7 @@ async def test_init_migrations_default_package_async() -> None: migration_dir = Path(temp_dir) / "migrations" config = AsyncpgConfig( - pool_config={"dsn": "postgresql://localhost/test"}, - migration_config={"script_location": str(migration_dir)}, + pool_config={"dsn": "postgresql://localhost/test"}, migration_config={"script_location": str(migration_dir)} ) with patch.object(AsyncMigrationCommands, "init", return_value=None) as mock_init: @@ -526,8 +499,7 @@ def test_fix_migrations_default_parameters_sync() -> None: temp_db = str(Path(temp_dir) / "test.db") config = SqliteConfig( - pool_config={"database": temp_db}, - migration_config={"script_location": str(migration_dir)}, + pool_config={"database": temp_db}, migration_config={"script_location": str(migration_dir)} ) with patch.object(SyncMigrationCommands, "fix", return_value=None) as mock_fix: @@ -543,8 +515,7 @@ async def test_fix_migrations_default_parameters_async() -> None: migration_dir = Path(temp_dir) / "migrations" config = AsyncpgConfig( - pool_config={"dsn": "postgresql://localhost/test"}, - migration_config={"script_location": str(migration_dir)}, + pool_config={"dsn": "postgresql://localhost/test"}, migration_config={"script_location": str(migration_dir)} ) with patch.object(AsyncMigrationCommands, "fix", return_value=None) as mock_fix: From d5422005126bedbb194a0faa8b2906a67c03614a Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Fri, 31 Oct 2025 01:23:58 +0000 Subject: [PATCH 4/4] feat: group tests by xdist for improved parallel execution --- .../test_adapters/test_adbc/test_data_dictionary.py | 2 ++ .../test_litestar/test_numpy_serialization.py | 2 +- .../test_extensions/test_litestar/test_store.py | 2 +- tests/integration/test_adapters/test_asyncmy/test_arrow.py | 2 ++ .../test_adapters/test_asyncpg/test_data_dictionary.py | 2 ++ .../test_adapters/test_asyncpg/test_schema_migration.py | 2 ++ .../test_adapters/test_bigquery/test_type_handler_config.py | 0 tests/integration/test_adapters/test_oracledb/test_arrow.py | 2 ++ tests/integration/test_adapters/test_sqlite/test_arrow.py | 2 ++ .../test_adapters/test_sqlite/test_data_dictionary.py | 3 +++ .../test_adapters/test_sqlite/test_driver_features.py | 2 ++ .../test_sqlite/test_extensions/test_litestar/test_store.py | 2 +- tests/integration/test_async_migrations.py | 2 ++ tests/integration/test_dishka/conftest.py | 2 ++ tests/integration/test_dishka/test_dishka_integration.py | 2 ++ .../test_extensions/test_fastapi_filters_integration.py | 2 ++ .../integration/test_extensions/test_fastapi_integration.py | 3 +++ tests/integration/test_extensions/test_flask_integration.py | 5 ++--- .../test_extensions/test_starlette_integration.py | 3 +++ .../test_migrations/test_upgrade_downgrade_versions.py | 2 ++ 20 files changed, 38 insertions(+), 6 deletions(-) delete mode 100644 tests/integration/test_adapters/test_bigquery/test_type_handler_config.py diff --git a/tests/integration/test_adapters/test_adbc/test_data_dictionary.py b/tests/integration/test_adapters/test_adbc/test_data_dictionary.py index 1ff71550..41186c1c 100644 --- a/tests/integration/test_adapters/test_adbc/test_data_dictionary.py +++ b/tests/integration/test_adapters/test_adbc/test_data_dictionary.py @@ -9,6 +9,8 @@ if TYPE_CHECKING: from sqlspec.adapters.adbc.driver import AdbcDriver +pytestmark = pytest.mark.xdist_group("adbc") + @pytest.mark.adbc def test_adbc_data_dictionary_version_detection(adbc_sync_driver: "AdbcDriver") -> None: diff --git a/tests/integration/test_adapters/test_aiosqlite/test_extensions/test_litestar/test_numpy_serialization.py b/tests/integration/test_adapters/test_aiosqlite/test_extensions/test_litestar/test_numpy_serialization.py index b65ea6c3..64998602 100644 --- a/tests/integration/test_adapters/test_aiosqlite/test_extensions/test_litestar/test_numpy_serialization.py +++ b/tests/integration/test_adapters/test_aiosqlite/test_extensions/test_litestar/test_numpy_serialization.py @@ -22,7 +22,7 @@ from sqlspec.adapters.aiosqlite import AiosqliteConfig from sqlspec.extensions.litestar.plugin import SQLSpecPlugin -pytestmark = [pytest.mark.integration, pytest.mark.aiosqlite, pytest.mark.xdist_group("aiosqlite-litestar")] +pytestmark = [pytest.mark.integration, pytest.mark.aiosqlite, pytest.mark.xdist_group("sqlite")] def test_litestar_numpy_encoder_registered() -> None: diff --git a/tests/integration/test_adapters/test_aiosqlite/test_extensions/test_litestar/test_store.py b/tests/integration/test_adapters/test_aiosqlite/test_extensions/test_litestar/test_store.py index c207f276..8ba7a321 100644 --- a/tests/integration/test_adapters/test_aiosqlite/test_extensions/test_litestar/test_store.py +++ b/tests/integration/test_adapters/test_aiosqlite/test_extensions/test_litestar/test_store.py @@ -9,7 +9,7 @@ from sqlspec.adapters.aiosqlite.config import AiosqliteConfig from sqlspec.adapters.aiosqlite.litestar.store import AiosqliteStore -pytestmark = [pytest.mark.aiosqlite, pytest.mark.integration] +pytestmark = [pytest.mark.xdist_group("sqlite"), pytest.mark.aiosqlite, pytest.mark.integration] @pytest.fixture diff --git a/tests/integration/test_adapters/test_asyncmy/test_arrow.py b/tests/integration/test_adapters/test_asyncmy/test_arrow.py index 0b800f70..0d720537 100644 --- a/tests/integration/test_adapters/test_asyncmy/test_arrow.py +++ b/tests/integration/test_adapters/test_asyncmy/test_arrow.py @@ -7,6 +7,8 @@ from sqlspec.adapters.asyncmy import AsyncmyConfig +pytestmark = [pytest.mark.xdist_group("mysql")] + @pytest.fixture async def asyncmy_arrow_config(mysql_service: MySQLService) -> AsyncGenerator[AsyncmyConfig, None]: diff --git a/tests/integration/test_adapters/test_asyncpg/test_data_dictionary.py b/tests/integration/test_adapters/test_asyncpg/test_data_dictionary.py index a096e257..68126324 100644 --- a/tests/integration/test_adapters/test_asyncpg/test_data_dictionary.py +++ b/tests/integration/test_adapters/test_asyncpg/test_data_dictionary.py @@ -9,6 +9,8 @@ if TYPE_CHECKING: from sqlspec.adapters.asyncpg.driver import AsyncpgDriver +pytestmark = pytest.mark.xdist_group("postgres") + @pytest.mark.asyncpg async def test_asyncpg_data_dictionary_version_detection(asyncpg_async_driver: "AsyncpgDriver") -> None: diff --git a/tests/integration/test_adapters/test_asyncpg/test_schema_migration.py b/tests/integration/test_adapters/test_asyncpg/test_schema_migration.py index 933221ea..b2e2eaba 100644 --- a/tests/integration/test_adapters/test_asyncpg/test_schema_migration.py +++ b/tests/integration/test_adapters/test_asyncpg/test_schema_migration.py @@ -6,6 +6,8 @@ from sqlspec.adapters.asyncpg import AsyncpgConfig from sqlspec.migrations.tracker import AsyncMigrationTracker +pytestmark = pytest.mark.xdist_group("postgres") + def _create_config(postgres_service: PostgresService) -> AsyncpgConfig: """Create AsyncpgConfig from PostgresService fixture.""" diff --git a/tests/integration/test_adapters/test_bigquery/test_type_handler_config.py b/tests/integration/test_adapters/test_bigquery/test_type_handler_config.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/integration/test_adapters/test_oracledb/test_arrow.py b/tests/integration/test_adapters/test_oracledb/test_arrow.py index 23462ed2..e0fcfab0 100644 --- a/tests/integration/test_adapters/test_oracledb/test_arrow.py +++ b/tests/integration/test_adapters/test_oracledb/test_arrow.py @@ -7,6 +7,8 @@ from sqlspec.adapters.oracledb import OracleAsyncConfig +pytestmark = pytest.mark.xdist_group("oracle") + @pytest.fixture async def oracle_arrow_config(oracle_23ai_service: OracleService) -> AsyncGenerator[OracleAsyncConfig, None]: diff --git a/tests/integration/test_adapters/test_sqlite/test_arrow.py b/tests/integration/test_adapters/test_sqlite/test_arrow.py index 9c3c131a..6bee85e3 100644 --- a/tests/integration/test_adapters/test_sqlite/test_arrow.py +++ b/tests/integration/test_adapters/test_sqlite/test_arrow.py @@ -4,6 +4,8 @@ from sqlspec.adapters.sqlite import SqliteConfig +pytestmark = pytest.mark.xdist_group("sqlite") + @pytest.fixture def sqlite_arrow_config() -> SqliteConfig: diff --git a/tests/integration/test_adapters/test_sqlite/test_data_dictionary.py b/tests/integration/test_adapters/test_sqlite/test_data_dictionary.py index 0be78354..9e3e6818 100644 --- a/tests/integration/test_adapters/test_sqlite/test_data_dictionary.py +++ b/tests/integration/test_adapters/test_sqlite/test_data_dictionary.py @@ -2,10 +2,13 @@ from typing import TYPE_CHECKING +import pytest + from sqlspec.driver import VersionInfo if TYPE_CHECKING: from sqlspec.adapters.sqlite.driver import SqliteDriver +pytestmark = pytest.mark.xdist_group("sqlite") def test_sqlite_data_dictionary_version_detection(sqlite_driver: "SqliteDriver") -> None: diff --git a/tests/integration/test_adapters/test_sqlite/test_driver_features.py b/tests/integration/test_adapters/test_sqlite/test_driver_features.py index d78393a5..3bb4abc7 100644 --- a/tests/integration/test_adapters/test_sqlite/test_driver_features.py +++ b/tests/integration/test_adapters/test_sqlite/test_driver_features.py @@ -9,6 +9,8 @@ from sqlspec import SQLSpec from sqlspec.adapters.sqlite import SqliteConfig +pytestmark = pytest.mark.xdist_group("sqlite") + @pytest.mark.sqlite def test_driver_features_enabled_by_default() -> None: diff --git a/tests/integration/test_adapters/test_sqlite/test_extensions/test_litestar/test_store.py b/tests/integration/test_adapters/test_sqlite/test_extensions/test_litestar/test_store.py index 3252e35c..b80f1466 100644 --- a/tests/integration/test_adapters/test_sqlite/test_extensions/test_litestar/test_store.py +++ b/tests/integration/test_adapters/test_sqlite/test_extensions/test_litestar/test_store.py @@ -9,7 +9,7 @@ from sqlspec.adapters.sqlite.config import SqliteConfig from sqlspec.adapters.sqlite.litestar.store import SQLiteStore -pytestmark = [pytest.mark.sqlite, pytest.mark.integration] +pytestmark = [pytest.mark.sqlite, pytest.mark.integration, pytest.mark.xdist_group("sqlite")] @pytest.fixture diff --git a/tests/integration/test_async_migrations.py b/tests/integration/test_async_migrations.py index 22cb8220..6c09a804 100644 --- a/tests/integration/test_async_migrations.py +++ b/tests/integration/test_async_migrations.py @@ -13,6 +13,8 @@ from sqlspec.utils.config_resolver import resolve_config_async from sqlspec.utils.sync_tools import run_ +pytestmark = pytest.mark.xdist_group("migrations") + class TestAsyncMigrationsIntegration: """Integration tests for async migrations functionality.""" diff --git a/tests/integration/test_dishka/conftest.py b/tests/integration/test_dishka/conftest.py index 80ede1ed..241b9ca5 100644 --- a/tests/integration/test_dishka/conftest.py +++ b/tests/integration/test_dishka/conftest.py @@ -6,6 +6,8 @@ dishka = pytest.importorskip("dishka") +pytestmark = pytest.mark.xdist_group("dishka") + if TYPE_CHECKING: from dishka import Provider # type: ignore[import-not-found] diff --git a/tests/integration/test_dishka/test_dishka_integration.py b/tests/integration/test_dishka/test_dishka_integration.py index 1996787f..13061907 100644 --- a/tests/integration/test_dishka/test_dishka_integration.py +++ b/tests/integration/test_dishka/test_dishka_integration.py @@ -12,6 +12,8 @@ dishka = pytest.importorskip("dishka") +pytestmark = pytest.mark.xdist_group("dishka") + def test_simple_sync_dishka_provider(simple_sqlite_provider: Any) -> None: """Test CLI with a simple synchronous Dishka provider.""" diff --git a/tests/integration/test_extensions/test_fastapi_filters_integration.py b/tests/integration/test_extensions/test_fastapi_filters_integration.py index 51aa8357..99272b07 100644 --- a/tests/integration/test_extensions/test_fastapi_filters_integration.py +++ b/tests/integration/test_extensions/test_fastapi_filters_integration.py @@ -14,6 +14,8 @@ from sqlspec.extensions.fastapi import SQLSpecPlugin from sqlspec.extensions.fastapi.providers import dep_cache +pytestmark = pytest.mark.xdist_group("sqlite") + @pytest.fixture(autouse=True) def _clear_dependency_cache() -> Generator[None, None, None]: diff --git a/tests/integration/test_extensions/test_fastapi_integration.py b/tests/integration/test_extensions/test_fastapi_integration.py index 2fe5662a..5b03dc24 100644 --- a/tests/integration/test_extensions/test_fastapi_integration.py +++ b/tests/integration/test_extensions/test_fastapi_integration.py @@ -3,6 +3,7 @@ import tempfile from typing import Annotated, Any +import pytest from fastapi import Depends, FastAPI from fastapi.testclient import TestClient @@ -10,6 +11,8 @@ from sqlspec.base import SQLSpec from sqlspec.extensions.fastapi import SQLSpecPlugin +pytestmark = pytest.mark.xdist_group("sqlite") + def test_fastapi_dependency_injection() -> None: """Test FastAPI dependency injection with session_dependency.""" diff --git a/tests/integration/test_extensions/test_flask_integration.py b/tests/integration/test_extensions/test_flask_integration.py index ec47110c..d280c969 100644 --- a/tests/integration/test_extensions/test_flask_integration.py +++ b/tests/integration/test_extensions/test_flask_integration.py @@ -3,15 +3,14 @@ from typing import Any import pytest - -pytest.importorskip("flask") - from flask import Flask from sqlspec import SQLSpec from sqlspec.adapters.sqlite import SqliteConfig from sqlspec.extensions.flask import SQLSpecPlugin +pytestmark = pytest.mark.xdist_group("sqlite") + def test_flask_manual_mode_sync_sqlite() -> None: """Test Flask extension with manual commit mode and sync SQLite.""" diff --git a/tests/integration/test_extensions/test_starlette_integration.py b/tests/integration/test_extensions/test_starlette_integration.py index 023def90..e4efa4e3 100644 --- a/tests/integration/test_extensions/test_starlette_integration.py +++ b/tests/integration/test_extensions/test_starlette_integration.py @@ -2,6 +2,7 @@ import tempfile +import pytest from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import JSONResponse, Response @@ -12,6 +13,8 @@ from sqlspec.base import SQLSpec from sqlspec.extensions.starlette import SQLSpecPlugin +pytestmark = pytest.mark.xdist_group("sqlite") + def test_starlette_basic_query() -> None: """Test basic query execution through Starlette extension.""" diff --git a/tests/integration/test_migrations/test_upgrade_downgrade_versions.py b/tests/integration/test_migrations/test_upgrade_downgrade_versions.py index 809bbbdf..5107b72b 100644 --- a/tests/integration/test_migrations/test_upgrade_downgrade_versions.py +++ b/tests/integration/test_migrations/test_upgrade_downgrade_versions.py @@ -8,6 +8,8 @@ from sqlspec.adapters.sqlite import SqliteConfig from sqlspec.migrations.commands import SyncMigrationCommands +pytestmark = pytest.mark.xdist_group("migrations") + @pytest.fixture def sqlite_config(tmp_path: Path) -> Generator[SqliteConfig, None, None]: