From 88a4d158b339b296d84879b53d5c4aa2d51a2ed8 Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Tue, 4 Nov 2025 01:35:40 +0000 Subject: [PATCH] feat: enhance caching mechanism for migration metadata and add tests --- sqlspec/migrations/runner.py | 233 +++++++++++++----- .../test_migrations/test_migration_runner.py | 63 +++++ 2 files changed, 240 insertions(+), 56 deletions(-) diff --git a/sqlspec/migrations/runner.py b/sqlspec/migrations/runner.py index 8ee04b0d..0d5ef2d7 100644 --- a/sqlspec/migrations/runner.py +++ b/sqlspec/migrations/runner.py @@ -4,7 +4,9 @@ of concerns and proper type safety. """ +import hashlib import inspect +import re import time from abc import ABC, abstractmethod from pathlib import Path @@ -15,6 +17,7 @@ from sqlspec.migrations.loaders import get_migration_loader from sqlspec.utils.logging import get_logger from sqlspec.utils.sync_tools import async_, await_ +from sqlspec.utils.version import parse_version if TYPE_CHECKING: from collections.abc import Awaitable, Callable, Coroutine @@ -26,6 +29,30 @@ logger = get_logger("migrations.runner") +class _CachedMigrationMetadata: + """Cached migration metadata keyed by file path.""" + + __slots__ = ("metadata", "mtime_ns", "size") + + def __init__(self, metadata: "dict[str, Any]", mtime_ns: int, size: int) -> None: + self.metadata = metadata + self.mtime_ns = mtime_ns + self.size = size + + def clone(self) -> "dict[str, Any]": + return dict(self.metadata) + + +class _MigrationFileEntry: + """Represents a migration file discovered during directory scanning.""" + + __slots__ = ("extension_name", "path") + + def __init__(self, path: Path, extension_name: "str | None") -> None: + self.path = path + self.extension_name = extension_name + + class BaseMigrationRunner(ABC): """Base migration runner with common functionality shared between sync and async implementations.""" @@ -52,6 +79,100 @@ def __init__( self.project_root: Path | None = None self.context = context self.extension_configs = extension_configs or {} + self._listing_digest: str | None = None + self._listing_cache: list[tuple[str, Path]] | None = None + self._listing_signatures: dict[str, tuple[int, int]] = {} + self._metadata_cache: dict[str, _CachedMigrationMetadata] = {} + + def _iter_directory_entries(self, base_path: Path, extension_name: "str | None") -> "list[_MigrationFileEntry]": + """Collect migration files discovered under a base path.""" + + if not base_path.exists(): + return [] + + entries: list[_MigrationFileEntry] = [] + for pattern in ("*.sql", "*.py"): + for file_path in sorted(base_path.glob(pattern)): + if file_path.name.startswith("."): + continue + entries.append(_MigrationFileEntry(path=file_path, extension_name=extension_name)) + return entries + + def _collect_listing_entries(self) -> "tuple[list[_MigrationFileEntry], dict[str, tuple[int, int]], str]": + """Gather migration files, stat signatures, and digest for cache validation.""" + + entries: list[_MigrationFileEntry] = [] + signatures: dict[str, tuple[int, int]] = {} + digest_source = hashlib.md5(usedforsecurity=False) + + for entry in self._iter_directory_entries(self.migrations_path, None): + self._record_entry(entry, entries, signatures, digest_source) + + for ext_name, ext_path in self.extension_migrations.items(): + for entry in self._iter_directory_entries(ext_path, ext_name): + self._record_entry(entry, entries, signatures, digest_source) + + return entries, signatures, digest_source.hexdigest() + + def _record_entry( + self, + entry: _MigrationFileEntry, + entries: 
"list[_MigrationFileEntry]", + signatures: "dict[str, tuple[int, int]]", + digest_source: Any, + ) -> None: + """Record entry metadata for cache decisions.""" + + try: + stat_result = entry.path.stat() + except FileNotFoundError: + return + + path_str = str(entry.path) + token = (stat_result.st_mtime_ns, stat_result.st_size) + signatures[path_str] = token + digest_source.update(path_str.encode("utf-8")) + digest_source.update(f"{token[0]}:{token[1]}".encode()) + entries.append(entry) + + def _build_sorted_listing(self, entries: "list[_MigrationFileEntry]") -> "list[tuple[str, Path]]": + """Construct sorted migration listing from directory entries.""" + + migrations: list[tuple[str, Path]] = [] + + for entry in entries: + version = self._extract_version(entry.path.name) + if not version: + continue + if entry.extension_name: + version = f"ext_{entry.extension_name}_{version}" + migrations.append((version, entry.path)) + + def version_sort_key(migration_tuple: "tuple[str, Path]") -> "Any": + version_str = migration_tuple[0] + try: + return parse_version(version_str) + except ValueError: + return version_str + + return sorted(migrations, key=version_sort_key) + + def _log_listing_invalidation( + self, previous: "dict[str, tuple[int, int]]", current: "dict[str, tuple[int, int]]" + ) -> None: + """Log cache invalidation details at INFO level.""" + + prev_keys = set(previous) + curr_keys = set(current) + added = curr_keys - prev_keys + removed = prev_keys - curr_keys + modified = {key for key in prev_keys & curr_keys if previous[key] != current[key]} + logger.info( + "Migration listing cache invalidated (added=%d, removed=%d, modified=%d)", + len(added), + len(removed), + len(modified), + ) def _extract_version(self, filename: str) -> "str | None": """Extract version from filename. @@ -95,9 +216,6 @@ def _calculate_checksum(self, content: str) -> str: Returns: MD5 checksum hex string. """ - import hashlib - import re - canonical_content = re.sub(r"^--\s*name:\s*migrate-[^-]+-(?:up|down)\s*$", "", content, flags=re.MULTILINE) return hashlib.md5(canonical_content.encode()).hexdigest() # noqa: S324 @@ -114,57 +232,33 @@ def load_migration(self, file_path: Path) -> Union["dict[str, Any]", "Coroutine[ For async implementations, returns a coroutine. """ - def _get_migration_files_sync(self) -> "list[tuple[str, Path]]": - """Get all migration files sorted by version. + def _load_migration_listing(self) -> "list[tuple[str, Path]]": + """Build the cached migration listing shared by sync/async runners.""" + entries, signatures, digest = self._collect_listing_entries() + cached_listing = self._listing_cache - Returns: - List of tuples containing (version, file_path). 
- """ - - migrations = [] + if cached_listing is not None and self._listing_digest == digest: + logger.debug("Migration listing cache hit (%d files)", len(cached_listing)) + return cached_listing - # Scan primary migration path - if self.migrations_path.exists(): - for pattern in ("*.sql", "*.py"): - for file_path in self.migrations_path.glob(pattern): - if file_path.name.startswith("."): - continue - version = self._extract_version(file_path.name) - if version: - migrations.append((version, file_path)) + files = self._build_sorted_listing(entries) + previous_digest = self._listing_digest + previous_signatures = self._listing_signatures - # Scan extension migration paths - for ext_name, ext_path in self.extension_migrations.items(): - if ext_path.exists(): - for pattern in ("*.sql", "*.py"): - for file_path in ext_path.glob(pattern): - if file_path.name.startswith("."): - continue - # Prefix extension migrations to avoid version conflicts - version = self._extract_version(file_path.name) - if version: - # Use ext_ prefix to distinguish extension migrations - prefixed_version = f"ext_{ext_name}_{version}" - migrations.append((prefixed_version, file_path)) - - from sqlspec.utils.version import parse_version + self._listing_cache = files + self._listing_signatures = signatures + self._listing_digest = digest - def version_sort_key(migration_tuple: "tuple[str, Path]") -> "Any": - version_str = migration_tuple[0] - try: - return parse_version(version_str) - except ValueError: - return version_str + if previous_digest is None: + logger.debug("Primed migration listing cache with %d files", len(files)) + else: + self._log_listing_invalidation(previous_signatures, signatures) - return sorted(migrations, key=version_sort_key) + return files - def get_migration_files(self) -> "list[tuple[str, Path]]": - """Get all migration files sorted by version. - - Returns: - List of (version, path) tuples sorted by version. - """ - return self._get_migration_files_sync() + @abstractmethod + def get_migration_files(self) -> "list[tuple[str, Path]] | Awaitable[list[tuple[str, Path]]]": + """Get all migration files sorted by version.""" def _load_migration_metadata_common(self, file_path: Path, version: "str | None" = None) -> "dict[str, Any]": """Load common migration metadata that doesn't require async operations. @@ -176,7 +270,18 @@ def _load_migration_metadata_common(self, file_path: Path, version: "str | None" Returns: Partial migration metadata dictionary. 
""" - import re + cache_key = str(file_path) + stat_result = file_path.stat() + cached_metadata = self._metadata_cache.get(cache_key) + if ( + cached_metadata + and cached_metadata.mtime_ns == stat_result.st_mtime_ns + and cached_metadata.size == stat_result.st_size + ): + logger.debug("Migration metadata cache hit: %s", cache_key) + metadata = cached_metadata.clone() + metadata["file_path"] = file_path + return metadata content = file_path.read_text(encoding="utf-8") checksum = self._calculate_checksum(content) @@ -191,7 +296,7 @@ def _load_migration_metadata_common(self, file_path: Path, version: "str | None" if transactional_match: transactional = transactional_match.group(1).lower() == "true" - return { + metadata = { "version": version, "description": description, "file_path": file_path, @@ -199,6 +304,14 @@ def _load_migration_metadata_common(self, file_path: Path, version: "str | None" "content": content, "transactional": transactional, } + self._metadata_cache[cache_key] = _CachedMigrationMetadata( + metadata=dict(metadata), mtime_ns=stat_result.st_mtime_ns, size=stat_result.st_size + ) + if cached_metadata: + logger.info("Migration metadata cache invalidated: %s", cache_key) + else: + logger.debug("Cached migration metadata: %s", cache_key) + return metadata def _get_context_for_migration(self, file_path: Path) -> "MigrationContext | None": """Get the appropriate context for a migration file. @@ -263,6 +376,14 @@ def should_use_transaction(self, migration: "dict[str, Any]", config: Any) -> bo class SyncMigrationRunner(BaseMigrationRunner): """Synchronous migration runner with pure sync methods.""" + def get_migration_files(self) -> "list[tuple[str, Path]]": + """Get all migration files sorted by version. + + Returns: + List of (version, path) tuples sorted by version. + """ + return self._load_migration_listing() + def load_migration(self, file_path: Path, version: "str | None" = None) -> "dict[str, Any]": """Load a migration file and extract its components. @@ -287,7 +408,7 @@ def load_migration(self, file_path: Path, version: "str | None" = None) -> "dict has_upgrade, has_downgrade = self.loader.has_query(up_query), self.loader.has_query(down_query) else: try: - has_downgrade = bool(self._get_migration_sql_sync({"loader": loader, "file_path": file_path}, "down")) + has_downgrade = bool(self._get_migration_sql({"loader": loader, "file_path": file_path}, "down")) except Exception: has_downgrade = False @@ -313,7 +434,7 @@ def execute_upgrade( Returns: Tuple of (sql_content, execution_time_ms). """ - upgrade_sql_list = self._get_migration_sql_sync(migration, "up") + upgrade_sql_list = self._get_migration_sql(migration, "up") if upgrade_sql_list is None: return None, 0 @@ -365,7 +486,7 @@ def execute_downgrade( Returns: Tuple of (sql_content, execution_time_ms). """ - downgrade_sql_list = self._get_migration_sql_sync(migration, "down") + downgrade_sql_list = self._get_migration_sql(migration, "down") if downgrade_sql_list is None: return None, 0 @@ -398,7 +519,7 @@ def execute_downgrade( return None, execution_time - def _get_migration_sql_sync(self, migration: "dict[str, Any]", direction: str) -> "list[str] | None": + def _get_migration_sql(self, migration: "dict[str, Any]", direction: str) -> "list[str] | None": """Get migration SQL for given direction (sync version). 
Args: @@ -475,13 +596,13 @@ def load_all_migrations(self) -> "dict[str, SQL]": class AsyncMigrationRunner(BaseMigrationRunner): """Asynchronous migration runner with pure async methods.""" - async def get_migration_files(self) -> "list[tuple[str, Path]]": # type: ignore[override] + async def get_migration_files(self) -> "list[tuple[str, Path]]": """Get all migration files sorted by version. Returns: List of (version, path) tuples sorted by version. """ - return self._get_migration_files_sync() + return await async_(self._load_migration_listing)() async def load_migration(self, file_path: Path, version: "str | None" = None) -> "dict[str, Any]": """Load a migration file and extract its components. diff --git a/tests/unit/test_migrations/test_migration_runner.py b/tests/unit/test_migrations/test_migration_runner.py index 0394ee0d..a919977a 100644 --- a/tests/unit/test_migrations/test_migration_runner.py +++ b/tests/unit/test_migrations/test_migration_runner.py @@ -10,13 +10,16 @@ """ import tempfile +import time from pathlib import Path from typing import Any from unittest.mock import Mock, patch import pytest +from sqlspec.migrations import runner as runner_module from sqlspec.migrations.base import BaseMigrationRunner +from sqlspec.migrations.runner import SyncMigrationRunner pytestmark = pytest.mark.xdist_group("migrations") @@ -101,6 +104,66 @@ def load_all_migrations(self) -> Any: return TestMigrationRunner(migrations_path) +def _write_basic_sql(path: Path, version: str, body: str = "SELECT 1;") -> None: + path.write_text( + f""" +-- name: migrate-{version}-up +{body} + +-- name: migrate-{version}-down +{body} +""".strip() + ) + + +def test_load_migration_metadata_uses_cache(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + """Ensure metadata caching prevents redundant checksum calculations.""" + + file_path = tmp_path / "0001_cached.sql" + _write_basic_sql(file_path, "0001") + runner = SyncMigrationRunner(tmp_path, {}, None, {}) + + checksum_calls = 0 + original_checksum = runner_module.BaseMigrationRunner._calculate_checksum + + def _tracked_checksum(self: Any, content: str) -> str: + nonlocal checksum_calls + checksum_calls += 1 + return original_checksum(self, content) + + monkeypatch.setattr(runner_module.BaseMigrationRunner, "_calculate_checksum", _tracked_checksum) + + runner.load_migration(file_path) + runner.load_migration(file_path) + + assert checksum_calls == 1 + + +def test_load_migration_metadata_invalidates_on_change(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + """Metadata cache invalidates when file content changes.""" + + file_path = tmp_path / "0001_mutated.sql" + _write_basic_sql(file_path, "0001") + runner = SyncMigrationRunner(tmp_path, {}, None, {}) + + checksum_calls = 0 + original_checksum = runner_module.BaseMigrationRunner._calculate_checksum + + def _tracked_checksum(self: Any, content: str) -> str: + nonlocal checksum_calls + checksum_calls += 1 + return original_checksum(self, content) + + monkeypatch.setattr(runner_module.BaseMigrationRunner, "_calculate_checksum", _tracked_checksum) + + runner.load_migration(file_path) + time.sleep(0.01) + _write_basic_sql(file_path, "0001", body="SELECT 2;") + runner.load_migration(file_path) + + assert checksum_calls == 2 + + def test_migration_runner_initialization() -> None: """Test basic MigrationRunner initialization.""" migrations_path = Path("/test/migrations")
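
Reviewer notes on the caching scheme (commentary, not part of the patch):

The listing cache validates itself with a single digest folded over every discovered file's path, mtime_ns, and size, so one string comparison decides whether the cached listing is still current. A condensed standalone sketch of that scheme — the helper name scan_digest is illustrative, not part of sqlspec's API:

    import hashlib
    from pathlib import Path

    def scan_digest(paths: "list[Path]") -> "tuple[str, dict[str, tuple[int, int]]]":
        """Fold each file's (path, mtime_ns, size) into one MD5 digest.

        Adding, removing, renaming, or rewriting any file changes the
        digest, so a single equality check validates the whole listing.
        """
        digest = hashlib.md5(usedforsecurity=False)
        signatures: "dict[str, tuple[int, int]]" = {}
        for path in sorted(paths, key=str):
            stat = path.stat()
            token = (stat.st_mtime_ns, stat.st_size)
            signatures[str(path)] = token
            digest.update(str(path).encode("utf-8"))
            digest.update(f"{token[0]}:{token[1]}".encode())
        return digest.hexdigest(), signatures

On a miss, the runner still holds the per-file signatures from the previous scan, which is what lets _log_listing_invalidation report added/removed/modified counts instead of a bare "cache invalidated" message.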
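
The per-file metadata cache applies the same idea at file granularity: an entry is trusted only while both st_mtime_ns and st_size match, and hits return a copy so callers cannot mutate the cached dict. A minimal sketch of the validation predicate, under the assumption that the expensive work is the read/checksum/header parse (elided below):

    from pathlib import Path
    from typing import Any

    _metadata_cache: "dict[str, tuple[int, int, dict[str, Any]]]" = {}

    def load_metadata(path: Path) -> "dict[str, Any]":
        stat = path.stat()
        hit = _metadata_cache.get(str(path))
        if hit is not None and hit[0] == stat.st_mtime_ns and hit[1] == stat.st_size:
            return dict(hit[2])  # clone: callers must not mutate the cache
        # Cache miss: do the expensive work (checksum/description parsing elided).
        metadata = {"content": path.read_text(encoding="utf-8")}
        _metadata_cache[str(path)] = (stat.st_mtime_ns, stat.st_size, dict(metadata))
        return metadata

One caveat this makes visible: a rewrite that keeps the byte count identical — exactly what test_load_migration_metadata_invalidates_on_change does by swapping "SELECT 1;" for "SELECT 2;" — is detected only through st_mtime_ns, hence the 10 ms sleep before the rewrite. On filesystems with coarse timestamp resolution that window may need to be wider.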
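
On the async side, AsyncMigrationRunner reuses the sync listing builder via await async_(self._load_migration_listing)(). async_ from sqlspec.utils.sync_tools appears to wrap a blocking callable so it can be awaited off the event loop; assuming that reading is right, the standard-library equivalent of the pattern would be:

    import asyncio

    async def get_migration_files(runner) -> "list[tuple[str, Path]]":
        # Run the stat()/glob()-heavy directory scan in a worker thread
        # instead of blocking the event loop.
        return await asyncio.to_thread(runner._load_migration_listing)

Keeping the scan itself sync means the digest and signature bookkeeping lives in one code path shared by both runners.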