diff --git a/sqlspec/migrations/fix.py b/sqlspec/migrations/fix.py index 0c598260..2dc72f85 100644 --- a/sqlspec/migrations/fix.py +++ b/sqlspec/migrations/fix.py @@ -140,7 +140,7 @@ def apply_renames(self, renames: "list[MigrationRename]", dry_run: bool = False) rename.old_path.rename(rename.new_path) - def update_file_content(self, file_path: Path, old_version: str, new_version: str) -> None: + def update_file_content(self, file_path: Path, old_version: "str | None", new_version: "str | None") -> None: """Update SQL query names and version comments in file content. Transforms query names and version metadata from old version to new version: @@ -153,10 +153,14 @@ def update_file_content(self, file_path: Path, old_version: str, new_version: st Args: file_path: Path to file to update. - old_version: Old version string. - new_version: New version string. + old_version: Old version string (None values skipped gracefully). + new_version: New version string (None values skipped gracefully). """ + if not old_version or not new_version: + logger.warning("Skipping content update - missing version information") + return + content = file_path.read_text(encoding="utf-8") up_pattern = re.compile(rf"(-- name:\s+migrate-){re.escape(old_version)}(-up)") diff --git a/sqlspec/migrations/validation.py b/sqlspec/migrations/validation.py index 85af34ed..0f332c5c 100644 --- a/sqlspec/migrations/validation.py +++ b/sqlspec/migrations/validation.py @@ -14,6 +14,8 @@ from sqlspec.utils.version import parse_version if TYPE_CHECKING: + from collections.abc import Sequence + from sqlspec.utils.version import MigrationVersion __all__ = ("MigrationGap", "detect_out_of_order_migrations", "format_out_of_order_warning") @@ -39,7 +41,7 @@ class MigrationGap: def detect_out_of_order_migrations( - pending_versions: "list[str]", applied_versions: "list[str]" + pending_versions: "Sequence[str | None]", applied_versions: "Sequence[str | None]" ) -> "list[MigrationGap]": """Detect migrations created before already-applied migrations. @@ -51,29 +53,26 @@ def detect_out_of_order_migrations( independent sequences within their own namespaces. Args: - pending_versions: List of migration versions not yet applied. - applied_versions: List of migration versions already applied. + pending_versions: List of migration versions not yet applied (may contain None). + applied_versions: List of migration versions already applied (may contain None). Returns: - List of migration gaps representing out-of-order migrations. - Empty list if no out-of-order migrations detected. - - Example: - Applied: [20251011120000, 20251012140000] - Pending: [20251011130000, 20251013090000] - Result: Gap for 20251011130000 (created between applied migrations) - - Applied: [ext_litestar_0001, 0001, 0002] - Pending: [ext_adk_0001] - Result: [] (extensions excluded from out-of-order detection) + List of migration gaps where pending versions are older than applied. """ if not applied_versions or not pending_versions: return [] gaps: list[MigrationGap] = [] - parsed_applied = [parse_version(v) for v in applied_versions] - parsed_pending = [parse_version(v) for v in pending_versions] + # Filter out None values, empty strings, and whitespace-only strings + valid_applied = [v for v in applied_versions if v is not None and v.strip()] + valid_pending = [v for v in pending_versions if v is not None and v.strip()] + + if not valid_applied or not valid_pending: + return [] + + parsed_applied = [parse_version(v) for v in valid_applied] + parsed_pending = [parse_version(v) for v in valid_pending] core_applied = [v for v in parsed_applied if v.extension is None] core_pending = [v for v in parsed_pending if v.extension is None] diff --git a/sqlspec/utils/version.py b/sqlspec/utils/version.py index b08bb818..8b065e4f 100644 --- a/sqlspec/utils/version.py +++ b/sqlspec/utils/version.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) -SEQUENTIAL_PATTERN = re.compile(r"^(?!\d{14}$)(\d+)$") +SEQUENTIAL_PATTERN = re.compile(r"^(?!\d{14}$)\d+$") TIMESTAMP_PATTERN = re.compile(r"^(\d{14})$") EXTENSION_PATTERN = re.compile(r"^ext_(\w+)_(.+)$") @@ -135,7 +135,7 @@ def __repr__(self) -> str: return f"MigrationVersion({self.type.value}={self.raw})" -def is_sequential_version(version_str: str) -> bool: +def is_sequential_version(version_str: "str | None") -> bool: """Check if version string is sequential format. Sequential format: Any sequence of digits (0001, 42, 9999, 10000+). @@ -144,7 +144,7 @@ def is_sequential_version(version_str: str) -> bool: version_str: Version string to check. Returns: - True if sequential format. + True if sequential format, False if None or whitespace. Examples: >>> is_sequential_version("0001") @@ -155,11 +155,15 @@ def is_sequential_version(version_str: str) -> bool: True >>> is_sequential_version("20251011120000") False + >>> is_sequential_version(None) + False """ + if version_str is None or not version_str.strip(): + return False return bool(SEQUENTIAL_PATTERN.match(version_str)) -def is_timestamp_version(version_str: str) -> bool: +def is_timestamp_version(version_str: "str | None") -> bool: """Check if version string is timestamp format. Timestamp format: 14-digit YYYYMMDDHHmmss (20251011120000). @@ -168,14 +172,18 @@ def is_timestamp_version(version_str: str) -> bool: version_str: Version string to check. Returns: - True if timestamp format. + True if timestamp format, False if None or whitespace. Examples: >>> is_timestamp_version("20251011120000") True >>> is_timestamp_version("0001") False + >>> is_timestamp_version(None) + False """ + if version_str is None or not version_str.strip(): + return False if not TIMESTAMP_PATTERN.match(version_str): return False @@ -187,7 +195,7 @@ def is_timestamp_version(version_str: str) -> bool: return True -def parse_version(version_str: str) -> MigrationVersion: +def parse_version(version_str: "str | None") -> MigrationVersion: """Parse version string into structured format. Supports: @@ -202,7 +210,7 @@ def parse_version(version_str: str) -> MigrationVersion: Parsed migration version. Raises: - ValueError: If version format is invalid. + ValueError: If version format is invalid, None, or whitespace-only. Examples: >>> v = parse_version("0001") @@ -219,6 +227,10 @@ def parse_version(version_str: str) -> MigrationVersion: >>> v.extension 'litestar' """ + if version_str is None or not version_str.strip(): + msg = "Invalid migration version: version string is None or empty" + raise ValueError(msg) + extension_match = EXTENSION_PATTERN.match(version_str) if extension_match: extension_name = extension_match.group(1) diff --git a/tests/unit/test_migrations/test_null_handling_fixes.py b/tests/unit/test_migrations/test_null_handling_fixes.py new file mode 100644 index 00000000..23a3baed --- /dev/null +++ b/tests/unit/test_migrations/test_null_handling_fixes.py @@ -0,0 +1,111 @@ +"""Test cases for null handling fixes in migration system.""" + +import tempfile +from pathlib import Path + +import pytest + +from sqlspec.migrations.fix import MigrationFixer +from sqlspec.migrations.validation import detect_out_of_order_migrations +from sqlspec.utils.version import is_sequential_version, is_timestamp_version, parse_version + + +class TestNullHandlingFixes: + """Test fixes for None value handling in migrations.""" + + def test_parse_version_with_none(self): + """Test parse_version handles None gracefully.""" + with pytest.raises(ValueError, match="Invalid migration version: version string is None or empty"): + parse_version(None) + + def test_parse_version_with_empty_string(self): + """Test parse_version handles empty string gracefully.""" + with pytest.raises(ValueError, match="Invalid migration version: version string is None or empty"): + parse_version("") + + def test_parse_version_with_whitespace_only(self): + """Test parse_version handles whitespace-only strings.""" + with pytest.raises(ValueError, match="Invalid migration version: version string is None or empty"): + parse_version(" ") + + def test_parse_version_valid_formats_still_work(self): + """Test that valid version formats still work after fixes.""" + # Sequential versions + result = parse_version("0001") + assert result.type.value == "sequential" + assert result.sequence == 1 + + result = parse_version("9999") + assert result.type.value == "sequential" + assert result.sequence == 9999 + + # Timestamp versions + result = parse_version("20251011120000") + assert result.type.value == "timestamp" + assert result.timestamp is not None + + # Extension versions + result = parse_version("ext_litestar_0001") + assert result.type.value == "sequential" # Base is sequential + assert result.extension == "litestar" + + def test_migration_fixer_handles_none_gracefully(self): + """Test MigrationFixer.update_file_content handles None values.""" + with tempfile.TemporaryDirectory() as temp_dir: + migrations_path = Path(temp_dir) + fixer = MigrationFixer(migrations_path) + + test_file = migrations_path / "test.sql" + test_file.write_text("-- Test content") + + # Should not crash with None values + fixer.update_file_content(test_file, None, "0001") + fixer.update_file_content(test_file, "0001", None) + fixer.update_file_content(test_file, None, None) + + # File should remain unchanged + content = test_file.read_text() + assert content == "-- Test content" + + def test_validation_filters_none_values(self): + """Test migration validation filters None values properly.""" + # Should not crash with None values in lists + gaps = detect_out_of_order_migrations( + pending_versions=["0001", None, "0003", ""], applied_versions=[None, "0002", " ", "0004"] + ) + + # Should only process valid versions + assert len(gaps) >= 0 # Should not crash + + def test_sequential_pattern_edge_cases(self): + """Test sequential pattern handles edge cases.""" + assert is_sequential_version("0001") + assert is_sequential_version("9999") + assert is_sequential_version("10000") + assert not is_sequential_version("20251011120000") # Timestamp + assert not is_sequential_version("abc") + assert not is_sequential_version("") + assert not is_sequential_version(None) + + def test_timestamp_pattern_edge_cases(self): + """Test timestamp pattern handles edge cases.""" + assert is_timestamp_version("20251011120000") + assert is_timestamp_version("20250101000000") + assert is_timestamp_version("20251231235959") + assert not is_timestamp_version("0001") # Sequential + assert not is_timestamp_version("2025101112000") # Too short + assert not is_timestamp_version("202510111200000") # Too long + assert not is_timestamp_version("") + assert not is_timestamp_version(None) + + def test_error_messages_are_descriptive(self): + """Test that error messages are helpful for debugging.""" + try: + parse_version(None) + except ValueError as e: + assert "version string is None or empty" in str(e) + + try: + parse_version("") + except ValueError as e: + assert "version string is None or empty" in str(e) diff --git a/tests/unit/test_migrations/test_version.py b/tests/unit/test_migrations/test_version.py index 01e7309f..d9c257fe 100644 --- a/tests/unit/test_migrations/test_version.py +++ b/tests/unit/test_migrations/test_version.py @@ -95,7 +95,7 @@ def test_parse_invalid_version() -> None: with pytest.raises(ValueError, match="Invalid migration version format"): parse_version("abc") - with pytest.raises(ValueError, match="Invalid migration version format"): + with pytest.raises(ValueError, match="Invalid migration version"): parse_version("") with pytest.raises(ValueError, match="Invalid migration version format"):