233 changes: 177 additions & 56 deletions sqlspec/migrations/runner.py
@@ -4,7 +4,9 @@
 of concerns and proper type safety.
 """
 
+import hashlib
 import inspect
+import re
 import time
 from abc import ABC, abstractmethod
 from pathlib import Path
@@ -15,6 +17,7 @@
 from sqlspec.migrations.loaders import get_migration_loader
 from sqlspec.utils.logging import get_logger
 from sqlspec.utils.sync_tools import async_, await_
+from sqlspec.utils.version import parse_version
 
 if TYPE_CHECKING:
     from collections.abc import Awaitable, Callable, Coroutine
@@ -26,6 +29,30 @@
 logger = get_logger("migrations.runner")
 
 
+class _CachedMigrationMetadata:
+    """Cached migration metadata keyed by file path."""
+
+    __slots__ = ("metadata", "mtime_ns", "size")
+
+    def __init__(self, metadata: "dict[str, Any]", mtime_ns: int, size: int) -> None:
+        self.metadata = metadata
+        self.mtime_ns = mtime_ns
+        self.size = size
+
+    def clone(self) -> "dict[str, Any]":
+        return dict(self.metadata)
+
+
+class _MigrationFileEntry:
+    """Represents a migration file discovered during directory scanning."""
+
+    __slots__ = ("extension_name", "path")
+
+    def __init__(self, path: Path, extension_name: "str | None") -> None:
+        self.path = path
+        self.extension_name = extension_name
+
+
 class BaseMigrationRunner(ABC):
     """Base migration runner with common functionality shared between sync and async implementations."""

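Note: both cache-entry classes above feed the same freshness rule used throughout this PR: a cached value is trusted only while the file's (st_mtime_ns, st_size) pair is unchanged. A minimal standalone sketch of that check, assuming only the standard library (the function name is illustrative, not PR code):

from pathlib import Path

def is_fresh(path: Path, cached_mtime_ns: int, cached_size: int) -> bool:
    # A cache entry stays valid only while both the nanosecond mtime and the
    # byte size reported by stat() match what was recorded at cache time.
    stat_result = path.stat()
    return stat_result.st_mtime_ns == cached_mtime_ns and stat_result.st_size == cached_size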
@@ -52,6 +79,100 @@ def __init__(
         self.project_root: Path | None = None
         self.context = context
         self.extension_configs = extension_configs or {}
+        self._listing_digest: str | None = None
+        self._listing_cache: list[tuple[str, Path]] | None = None
+        self._listing_signatures: dict[str, tuple[int, int]] = {}
+        self._metadata_cache: dict[str, _CachedMigrationMetadata] = {}
+
+    def _iter_directory_entries(self, base_path: Path, extension_name: "str | None") -> "list[_MigrationFileEntry]":
+        """Collect migration files discovered under a base path."""
+
+        if not base_path.exists():
+            return []
+
+        entries: list[_MigrationFileEntry] = []
+        for pattern in ("*.sql", "*.py"):
+            for file_path in sorted(base_path.glob(pattern)):
+                if file_path.name.startswith("."):
+                    continue
+                entries.append(_MigrationFileEntry(path=file_path, extension_name=extension_name))
+        return entries
+
+    def _collect_listing_entries(self) -> "tuple[list[_MigrationFileEntry], dict[str, tuple[int, int]], str]":
+        """Gather migration files, stat signatures, and digest for cache validation."""
+
+        entries: list[_MigrationFileEntry] = []
+        signatures: dict[str, tuple[int, int]] = {}
+        digest_source = hashlib.md5(usedforsecurity=False)
+
+        for entry in self._iter_directory_entries(self.migrations_path, None):
+            self._record_entry(entry, entries, signatures, digest_source)
+
+        for ext_name, ext_path in self.extension_migrations.items():
+            for entry in self._iter_directory_entries(ext_path, ext_name):
+                self._record_entry(entry, entries, signatures, digest_source)
+
+        return entries, signatures, digest_source.hexdigest()
+
+    def _record_entry(
+        self,
+        entry: _MigrationFileEntry,
+        entries: "list[_MigrationFileEntry]",
+        signatures: "dict[str, tuple[int, int]]",
+        digest_source: Any,
+    ) -> None:
+        """Record entry metadata for cache decisions."""
+
+        try:
+            stat_result = entry.path.stat()
+        except FileNotFoundError:
+            return
+
+        path_str = str(entry.path)
+        token = (stat_result.st_mtime_ns, stat_result.st_size)
+        signatures[path_str] = token
+        digest_source.update(path_str.encode("utf-8"))
+        digest_source.update(f"{token[0]}:{token[1]}".encode())
+        entries.append(entry)
+
+    def _build_sorted_listing(self, entries: "list[_MigrationFileEntry]") -> "list[tuple[str, Path]]":
+        """Construct sorted migration listing from directory entries."""
+
+        migrations: list[tuple[str, Path]] = []
+
+        for entry in entries:
+            version = self._extract_version(entry.path.name)
+            if not version:
+                continue
+            if entry.extension_name:
+                version = f"ext_{entry.extension_name}_{version}"
+            migrations.append((version, entry.path))
+
+        def version_sort_key(migration_tuple: "tuple[str, Path]") -> "Any":
+            version_str = migration_tuple[0]
+            try:
+                return parse_version(version_str)
+            except ValueError:
+                return version_str
+
+        return sorted(migrations, key=version_sort_key)
+
+    def _log_listing_invalidation(
+        self, previous: "dict[str, tuple[int, int]]", current: "dict[str, tuple[int, int]]"
+    ) -> None:
+        """Log cache invalidation details at INFO level."""
+
+        prev_keys = set(previous)
+        curr_keys = set(current)
+        added = curr_keys - prev_keys
+        removed = prev_keys - curr_keys
+        modified = {key for key in prev_keys & curr_keys if previous[key] != current[key]}
+        logger.info(
+            "Migration listing cache invalidated (added=%d, removed=%d, modified=%d)",
+            len(added),
+            len(removed),
+            len(modified),
+        )
 
     def _extract_version(self, filename: str) -> "str | None":
         """Extract version from filename.
@@ -95,9 +216,6 @@ def _calculate_checksum(self, content: str) -> str:
         Returns:
             MD5 checksum hex string.
         """
-        import hashlib
-        import re
-
         canonical_content = re.sub(r"^--\s*name:\s*migrate-[^-]+-(?:up|down)\s*$", "", content, flags=re.MULTILINE)
 
         return hashlib.md5(canonical_content.encode()).hexdigest()  # noqa: S324
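Note: the checksum canonicalizes content by stripping the migrate-up/down name markers first, so reformatting those comment lines does not flag a migration as changed. A small demonstration using the same regex (the SQL strings are invented):

import hashlib
import re

def checksum(content: str) -> str:
    canonical = re.sub(r"^--\s*name:\s*migrate-[^-]+-(?:up|down)\s*$", "", content, flags=re.MULTILINE)
    return hashlib.md5(canonical.encode()).hexdigest()  # noqa: S324

a = "-- name: migrate-0001-up\nCREATE TABLE t (id INTEGER);"
b = "--  name: migrate-0001-up  \nCREATE TABLE t (id INTEGER);"
assert checksum(a) == checksum(b)  # marker whitespace does not perturb the checksum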
@@ -114,57 +232,33 @@ def load_migration(self, file_path: Path) -> Union["dict[str, Any]", "Coroutine[
         For async implementations, returns a coroutine.
         """
 
-    def _get_migration_files_sync(self) -> "list[tuple[str, Path]]":
-        """Get all migration files sorted by version.
+    def _load_migration_listing(self) -> "list[tuple[str, Path]]":
+        """Build the cached migration listing shared by sync/async runners."""
+        entries, signatures, digest = self._collect_listing_entries()
+        cached_listing = self._listing_cache
 
-        Returns:
-            List of tuples containing (version, file_path).
-        """
-
-        migrations = []
+        if cached_listing is not None and self._listing_digest == digest:
+            logger.debug("Migration listing cache hit (%d files)", len(cached_listing))
+            return cached_listing
 
-        # Scan primary migration path
-        if self.migrations_path.exists():
-            for pattern in ("*.sql", "*.py"):
-                for file_path in self.migrations_path.glob(pattern):
-                    if file_path.name.startswith("."):
-                        continue
-                    version = self._extract_version(file_path.name)
-                    if version:
-                        migrations.append((version, file_path))
+        files = self._build_sorted_listing(entries)
+        previous_digest = self._listing_digest
+        previous_signatures = self._listing_signatures
 
-        # Scan extension migration paths
-        for ext_name, ext_path in self.extension_migrations.items():
-            if ext_path.exists():
-                for pattern in ("*.sql", "*.py"):
-                    for file_path in ext_path.glob(pattern):
-                        if file_path.name.startswith("."):
-                            continue
-                        # Prefix extension migrations to avoid version conflicts
-                        version = self._extract_version(file_path.name)
-                        if version:
-                            # Use ext_ prefix to distinguish extension migrations
-                            prefixed_version = f"ext_{ext_name}_{version}"
-                            migrations.append((prefixed_version, file_path))
-
-        from sqlspec.utils.version import parse_version
+        self._listing_cache = files
+        self._listing_signatures = signatures
+        self._listing_digest = digest
 
-        def version_sort_key(migration_tuple: "tuple[str, Path]") -> "Any":
-            version_str = migration_tuple[0]
-            try:
-                return parse_version(version_str)
-            except ValueError:
-                return version_str
+        if previous_digest is None:
+            logger.debug("Primed migration listing cache with %d files", len(files))
+        else:
+            self._log_listing_invalidation(previous_signatures, signatures)
 
-        return sorted(migrations, key=version_sort_key)
+        return files
 
-    def get_migration_files(self) -> "list[tuple[str, Path]]":
-        """Get all migration files sorted by version.
-
-        Returns:
-            List of (version, path) tuples sorted by version.
-        """
-        return self._get_migration_files_sync()
+    @abstractmethod
+    def get_migration_files(self) -> "list[tuple[str, Path]] | Awaitable[list[tuple[str, Path]]]":
+        """Get all migration files sorted by version."""
 
     def _load_migration_metadata_common(self, file_path: Path, version: "str | None" = None) -> "dict[str, Any]":
         """Load common migration metadata that doesn't require async operations.
@@ -176,7 +270,18 @@ def _load_migration_metadata_common(self, file_path: Path, version: "str | None"
         Returns:
             Partial migration metadata dictionary.
         """
-        import re
+        cache_key = str(file_path)
+        stat_result = file_path.stat()
+        cached_metadata = self._metadata_cache.get(cache_key)
+        if (
+            cached_metadata
+            and cached_metadata.mtime_ns == stat_result.st_mtime_ns
+            and cached_metadata.size == stat_result.st_size
+        ):
+            logger.debug("Migration metadata cache hit: %s", cache_key)
+            metadata = cached_metadata.clone()
+            metadata["file_path"] = file_path
+            return metadata
 
         content = file_path.read_text(encoding="utf-8")
         checksum = self._calculate_checksum(content)
@@ -191,14 +296,22 @@ def _load_migration_metadata_common(self, file_path: Path, version: "str | None"
         if transactional_match:
             transactional = transactional_match.group(1).lower() == "true"
 
-        return {
+        metadata = {
             "version": version,
             "description": description,
             "file_path": file_path,
             "checksum": checksum,
             "content": content,
             "transactional": transactional,
         }
+        self._metadata_cache[cache_key] = _CachedMigrationMetadata(
+            metadata=dict(metadata), mtime_ns=stat_result.st_mtime_ns, size=stat_result.st_size
+        )
+        if cached_metadata:
+            logger.info("Migration metadata cache invalidated: %s", cache_key)
+        else:
+            logger.debug("Cached migration metadata: %s", cache_key)
+        return metadata
 
     def _get_context_for_migration(self, file_path: Path) -> "MigrationContext | None":
         """Get the appropriate context for a migration file.
@@ -263,6 +376,14 @@ def should_use_transaction(self, migration: "dict[str, Any]", config: Any) -> bo
 class SyncMigrationRunner(BaseMigrationRunner):
     """Synchronous migration runner with pure sync methods."""
 
+    def get_migration_files(self) -> "list[tuple[str, Path]]":
+        """Get all migration files sorted by version.
+
+        Returns:
+            List of (version, path) tuples sorted by version.
+        """
+        return self._load_migration_listing()
+
     def load_migration(self, file_path: Path, version: "str | None" = None) -> "dict[str, Any]":
         """Load a migration file and extract its components.
 
@@ -287,7 +408,7 @@ def load_migration(self, file_path: Path, version: "str | None" = None) -> "dict
             has_upgrade, has_downgrade = self.loader.has_query(up_query), self.loader.has_query(down_query)
         else:
             try:
-                has_downgrade = bool(self._get_migration_sql_sync({"loader": loader, "file_path": file_path}, "down"))
+                has_downgrade = bool(self._get_migration_sql({"loader": loader, "file_path": file_path}, "down"))
             except Exception:
                 has_downgrade = False
 
Expand All @@ -313,7 +434,7 @@ def execute_upgrade(
Returns:
Tuple of (sql_content, execution_time_ms).
"""
upgrade_sql_list = self._get_migration_sql_sync(migration, "up")
upgrade_sql_list = self._get_migration_sql(migration, "up")
if upgrade_sql_list is None:
return None, 0

@@ -365,7 +486,7 @@ def execute_downgrade(
         Returns:
             Tuple of (sql_content, execution_time_ms).
         """
-        downgrade_sql_list = self._get_migration_sql_sync(migration, "down")
+        downgrade_sql_list = self._get_migration_sql(migration, "down")
         if downgrade_sql_list is None:
             return None, 0
 
@@ -398,7 +519,7 @@ def execute_downgrade(
 
         return None, execution_time
 
-    def _get_migration_sql_sync(self, migration: "dict[str, Any]", direction: str) -> "list[str] | None":
+    def _get_migration_sql(self, migration: "dict[str, Any]", direction: str) -> "list[str] | None":
         """Get migration SQL for given direction (sync version).
 
         Args:
@@ -475,13 +596,13 @@ def load_all_migrations(self) -> "dict[str, SQL]":
 class AsyncMigrationRunner(BaseMigrationRunner):
     """Asynchronous migration runner with pure async methods."""
 
-    async def get_migration_files(self) -> "list[tuple[str, Path]]":  # type: ignore[override]
+    async def get_migration_files(self) -> "list[tuple[str, Path]]":
         """Get all migration files sorted by version.
 
         Returns:
             List of (version, path) tuples sorted by version.
         """
-        return self._get_migration_files_sync()
+        return await async_(self._load_migration_listing)()
 
     async def load_migration(self, file_path: Path, version: "str | None" = None) -> "dict[str, Any]":
         """Load a migration file and extract its components.