diff --git a/AGENTS.md b/AGENTS.md index e7739ab1..fb46b839 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,4 +1,4 @@ -# AGENS.md +# AGENTS.md This file provides guidance to Gemini, Claude Code, Codex, and other agents when working with code in this repository. @@ -11,6 +11,96 @@ This file provides guidance to Gemini, Claude Code, Codex, and other agents when - **Propose improvements**: Suggest better patterns, more robust solutions, or cleaner implementations when appropriate - **Be a thoughtful collaborator**: Act as a good teammate who helps improve the overall quality and direction of the project +## Pull Request Guidelines + +### PR Description Standards (MANDATORY) + +Pull request descriptions MUST be concise, factual, and human-readable. Avoid excessive detail that should live in documentation or commit messages. + +**Maximum length**: ~30-40 lines for typical features +**Tone**: Direct, clear, professional - no marketing language or excessive enthusiasm + +**Required sections**: + +1. **Summary** (2-3 sentences): What does this do and why? +2. **The Problem** (2-4 lines): What issue does this solve? +3. **The Solution** (2-4 lines): How does it solve it? +4. **Key Features** (3-5 bullet points): Most important capabilities +5. **Example** (optional): Brief code example if it clarifies usage +6. **Link to docs** (if comprehensive guide exists) + +**PROHIBITED content**: + +- Extensive test coverage tables (this belongs in CI reports) +- Detailed file change lists (GitHub shows this automatically) +- Quality metrics and linting results (CI handles this) +- Commit-by-commit breakdown (git history shows this) +- Implementation details (belongs in code comments/docs) +- Excessive formatting (tables, sections, subsections) +- Marketing language or hype + +**Example of GOOD PR description**: + +```markdown +## Summary + +Adds hybrid versioning for migrations: timestamps in development (no conflicts), +sequential in production (deterministic ordering). Includes an automated +`sqlspec fix` command to convert between formats. + +Closes #116 + +## The Problem + +- Sequential migrations (0001, 0002): merge conflicts when multiple devs create migrations +- Timestamp migrations (20251011120000): no conflicts, but ordering depends on creation time + +## The Solution + +Use timestamps during development, convert to sequential before merging: + + $ sqlspec create-migration -m "add users" + Created: 20251011120000_add_users.sql + + $ sqlspec fix --yes + ✓ Converted to 0003_add_users.sql + +## Key Features + +- Automated conversion via `sqlspec fix` command +- Updates database tracking to prevent errors +- Idempotent - safe to re-run after pulling changes +- Stable checksums through conversions + +See [docs/guides/migrations/hybrid-versioning.md](docs/guides/migrations/hybrid-versioning.md) +for full documentation. +``` + +**Example of BAD PR description**: + +```markdown +## Summary +[800+ lines of excessive detail including test counts, file changes, +quality metrics, implementation details, commit lists, etc.] 
+```
+
+**CI Integration examples** - Keep to 5-10 lines maximum:
+
+```yaml
+# GitHub Actions example
+- run: sqlspec fix --yes
+- run: git add migrations/ && git commit -m "fix: sequential migrations" && git push
+```
+
+**When to include more detail**:
+
+- Breaking changes warrant a "Breaking Changes" section
+- Complex architectural changes may need a "Design Decisions" section
+- Security fixes may need a "Security Impact" section
+
+Keep it focused: the PR description should help reviewers understand WHAT and WHY quickly.
+Implementation details belong in code, commits, and documentation.
+
 ## Common Development Commands
 
 ### Building and Installation
@@ -53,7 +143,7 @@ SQLSpec is a type-safe SQL query mapper designed for minimal abstraction between
 2. **Adapters (`sqlspec/adapters/`)**: Database-specific implementations. Each adapter consists of:
    - `config.py`: Configuration classes specific to the database
    - `driver.py`: Driver implementation (sync/async) that executes queries
-   - `_types.py`: Type definitions specific to the adapter or other uncompilable mypyc onbjects
+   - `_types.py`: Type definitions specific to the adapter or other uncompilable mypyc objects
    - Supported adapters: `adbc`, `aiosqlite`, `asyncmy`, `asyncpg`, `bigquery`, `duckdb`, `oracledb`, `psqlpy`, `psycopg`, `sqlite`
 
 3. **Driver System (`sqlspec/driver/`)**: Base classes and mixins for all database drivers:
diff --git a/docs/changelog.rst b/docs/changelog.rst
index b28c4872..4025f5c5 100644
--- a/docs/changelog.rst
+++ b/docs/changelog.rst
@@ -10,6 +10,58 @@ SQLSpec Changelog
 Recent Updates
 ==============
 
+Hybrid Versioning with Fix Command
+-----------------------------------
+
+Added comprehensive hybrid versioning support for database migrations:
+
+- **Fix Command** - Convert timestamp migrations to sequential format
+- **Hybrid Workflow** - Use timestamps in development, sequential in production
+- **Automatic Conversion** - CI integration for seamless workflow
+- **Safety Features** - Automatic backup, rollback on errors, dry-run preview
+
+Key Features:
+
+- **Zero merge conflicts**: Developers use timestamps (``20251011120000``) during development
+- **Deterministic ordering**: Production uses sequential format (``0001``, ``0002``, etc.)
+- **Database synchronization**: Automatically updates version tracking table
+- **File operations**: Renames files and updates SQL query names
+- **CI-ready**: ``--yes`` flag for automated workflows
+
+.. code-block:: bash
+
+   # Preview changes
+   sqlspec --config myapp.config fix --dry-run
+
+   # Apply conversion
+   sqlspec --config myapp.config fix
+
+   # CI/CD mode
+   sqlspec --config myapp.config fix --yes --no-database
+
+Example conversion:
+
+.. code-block:: text
+
+   Before:                              After:
+   migrations/                          migrations/
+   ├── 0001_initial.sql                 ├── 0001_initial.sql
+   ├── 0002_add_users.sql               ├── 0002_add_users.sql
+   ├── 20251011120000_add_products.sql  →  ├── 0003_add_products.sql
+   └── 20251012130000_add_orders.sql    →  └── 0004_add_orders.sql
+
+**Documentation:**
+
+- Complete CLI reference: :doc:`usage/cli`
+- Workflow guide: :ref:`hybrid-versioning-guide`
+- CI integration examples for GitHub Actions and GitLab CI
+
+**Use Cases:**
+
+- Teams with parallel development avoiding migration number conflicts
+- Projects requiring deterministic migration ordering in production
+- CI/CD pipelines that standardize migrations before deployment
+
 Shell Completion Support
 -------------------------
 
diff --git a/docs/conf.py b/docs/conf.py
index dccf04cc..a0330276 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -62,7 +62,6 @@
     "sphinx_paramlinks",
     "sphinxcontrib.mermaid",
 ]
-exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "guides/*"]
 intersphinx_mapping = {
     "python": ("https://docs.python.org/3", None),
     "msgspec": ("https://jcristharif.com/msgspec/", None),
@@ -156,7 +155,15 @@
 templates_path = ["_templates"]
 html_js_files = ["versioning.js"]
 html_css_files = ["custom.css"]
-exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "PYPI_README.md", "STYLE_GUIDE.md", "VOICE_AUDIT_REPORT.md"]
+exclude_patterns = [
+    "_build",
+    "Thumbs.db",
+    ".DS_Store",
+    "PYPI_README.md",
+    "STYLE_GUIDE.md",
+    "VOICE_AUDIT_REPORT.md",
+    "guides/**",
+]
 html_show_sourcelink = True
 html_copy_source = True
diff --git a/docs/extensions/adk/migrations.rst b/docs/extensions/adk/migrations.rst
index 584e69db..f78ba522 100644
--- a/docs/extensions/adk/migrations.rst
+++ b/docs/extensions/adk/migrations.rst
@@ -148,7 +148,7 @@ SQLSpec includes a built-in migration for ADK tables:
 
     from sqlspec.extensions.adk.migrations import create_adk_tables_migration
 
-Location: ``sqlspec/extensions/adk/migrations/0001_create_adk_tables.py``
+Location: ``sqlspec/extensions/adk/migrations/``
 
 You can copy this template for custom migrations:
diff --git a/docs/guides/README.md b/docs/guides/README.md
index 1747e9d4..854f377e 100644
--- a/docs/guides/README.md
+++ b/docs/guides/README.md
@@ -28,6 +28,12 @@ Optimization guides for SQLSpec:
 - [**SQLglot Guide**](performance/sqlglot.md) - SQL parsing, transformation, and optimization with SQLglot
 - [**MyPyC Guide**](performance/mypyc.md) - Compilation strategies for high-performance Python code
 
+## Migrations
+
+Database migration strategies and workflows:
+
+- [**Hybrid Versioning**](migrations/hybrid-versioning.md) - Combine timestamp and sequential versioning for optimal workflows
+
 ## Testing
 
 Testing strategies and patterns:
@@ -91,6 +97,8 @@ docs/guides/
 │   ├── oracle.md
 │   ├── postgres.md
 │   └── ...
+├── migrations/ # Migration workflows +│ └── hybrid-versioning.md ├── performance/ # Performance optimization │ ├── sqlglot.md │ └── mypyc.md diff --git a/docs/guides/migrations/hybrid-versioning.md b/docs/guides/migrations/hybrid-versioning.md new file mode 100644 index 00000000..ea677060 --- /dev/null +++ b/docs/guides/migrations/hybrid-versioning.md @@ -0,0 +1,989 @@ +(hybrid-versioning-guide)= + +# Hybrid Versioning Guide + +**Combine timestamp and sequential migration numbering for optimal development and production workflows.** + +## Overview + +Hybrid versioning is a migration strategy that uses different version formats for different stages of development: + +- **Development**: Timestamp-based versions (e.g., `20251011120000`) avoid merge conflicts +- **Production**: Sequential versions (e.g., `0001`, `0002`) provide deterministic ordering + +SQLSpec's `fix` command automates the conversion between these formats, enabling teams to work independently without version collisions while maintaining strict ordering in production. + +## The Problem + +Traditional migration versioning strategies have trade-offs: + +### Sequential-Only Approach + +``` +migrations/ +├── 0001_initial.sql +├── 0002_add_users.sql +├── 0003_add_products.sql ← Alice creates this +└── 0003_add_orders.sql ← Bob creates this (CONFLICT!) +``` + +**Problem**: When multiple developers create migrations simultaneously, they pick the same next number, causing merge conflicts. + +### Timestamp-Only Approach + +``` +migrations/ +├── 20251010120000_initial.sql +├── 20251011090000_add_users.sql +├── 20251011120000_add_products.sql ← Alice (created at 12:00) +└── 20251011100000_add_orders.sql ← Bob (created at 10:00, but merged later) +``` + +**Problem**: Migration order depends on timestamp, not merge order. Bob's migration runs first even though Alice's PR merged first. + +## The Solution: Hybrid Versioning + +Hybrid versioning combines the best of both approaches: + +1. **Developers create migrations with timestamps** (no conflicts) +2. **CI converts timestamps to sequential before merge** (deterministic order) +3. **Production sees only sequential migrations** (clean, predictable) + +``` +Development (feature branch): +├── 0001_initial.sql +├── 0002_add_users.sql +└── 20251011120000_add_products.sql ← Alice's new migration + + ↓ PR merged, CI runs `sqlspec fix` + +Main branch: +├── 0001_initial.sql +├── 0002_add_users.sql +└── 0003_add_products.sql ← Converted to sequential +``` + +## Workflow + +### 1. Development Phase + +Developers create migrations normally: + +```bash +# Alice on feature/products +sqlspec --config myapp.config create-migration -m "add products table" +# Creates: 20251011120000_add_products.sql + +# Bob on feature/orders (same time) +sqlspec --config myapp.config create-migration -m "add orders table" +# Creates: 20251011120500_add_orders.sql +``` + +No conflicts! Timestamps are unique. + +### 2. 
Pre-Merge CI Check

Before merging to main, CI converts timestamps to sequential:

```yaml
# .github/workflows/migrations.yml
name: Fix Migrations
on:
  pull_request:
    branches: [main]
    paths: ['migrations/**']

jobs:
  fix-migrations:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
        with:
          ref: ${{ github.head_ref }}

      - uses: actions/setup-python@v5
        with:
          python-version: '3.12'

      - name: Install SQLSpec
        run: pip install sqlspec[cli]

      - name: Convert migrations to sequential
        run: |
          sqlspec --config myapp.config fix --yes --no-database

      - name: Commit changes
        run: |
          git config user.name "GitHub Actions"
          git config user.email "actions@github.com"
          git add migrations/
          if ! git diff --staged --quiet; then
            git commit -m "fix: convert migrations to sequential format"
            git push
          fi
```

### 3. Production Deployment

Production only sees sequential migrations with deterministic ordering:

```
migrations/
├── 0001_initial.sql
├── 0002_add_users.sql
├── 0003_add_products.sql   ← Alice's migration (merged first)
└── 0004_add_orders.sql     ← Bob's migration (merged second)
```

Order is determined by merge order, not timestamp.

## Command Reference

### Preview Changes

See what would be converted without applying:

```bash
sqlspec --config myapp.config fix --dry-run
```

Output:

```
╭─────────────────────────────────────────────────────────╮
│ Migration Conversions                                   │
├───────────────┬───────────────┬─────────────────────────┤
│ Current Ver   │ New Version   │ File                    │
├───────────────┼───────────────┼─────────────────────────┤
│ 20251011120000│ 0003          │ 20251011120000_prod.sql │
│ 20251012130000│ 0004          │ 20251012130000_ord.sql  │
╰───────────────┴───────────────┴─────────────────────────╯

2 migrations will be converted
[Preview Mode - No changes made]
```

### Apply Conversion

Convert with confirmation:

```bash
sqlspec --config myapp.config fix
```

You'll be prompted:

```
Proceed with conversion? [y/N]: y

✓ Created backup in .backup_20251012_143022
✓ Renamed 20251011120000_add_products.sql → 0003_add_products.sql
✓ Renamed 20251012130000_add_orders.sql → 0004_add_orders.sql
✓ Updated 2 database records
✓ Conversion complete!
```

### CI/CD Mode

Auto-approve for automation:

```bash
sqlspec --config myapp.config fix --yes
```

### Files Only (Skip Database)

Useful when the database is not accessible:

```bash
sqlspec --config myapp.config fix --no-database
```

## What Gets Updated

The `fix` command updates three things:

### 1. File Names

```
Before: 20251011120000_add_products.sql
After:  0003_add_products.sql
```

### 2. SQL Query Names (inside .sql files)

```sql
-- Before
-- name: migrate-20251011120000-up
CREATE TABLE products (id INT);

-- name: migrate-20251011120000-down
DROP TABLE products;

-- After
-- name: migrate-0003-up
CREATE TABLE products (id INT);

-- name: migrate-0003-down
DROP TABLE products;
```

### 3. Database Records

The migration tracking table is updated:

```sql
-- Before
INSERT INTO ddl_migrations (version_num, ...) VALUES ('20251011120000', ...);

-- After
UPDATE ddl_migrations SET version_num = '0003' WHERE version_num = '20251011120000';
```

## Safety Features

### Automatic Backups

Before any changes, a timestamped backup is created:

```
migrations/
├── .backup_20251012_143022/          ← Automatic backup
│   ├── 20251011120000_add_products.sql
│   └── 20251012130000_add_orders.sql
├── 0003_add_products.sql
└── 0004_add_orders.sql
```

### Automatic Rollback

If conversion fails, files are automatically restored:

```
Error: Target file already exists: 0003_add_products.sql
Restored files from backup
```

### Dry Run Mode

Always preview before applying:

```bash
sqlspec --config myapp.config fix --dry-run
```

## Best Practices

### 1. Always Use Version Control

Commit migration files before running `fix`:

```bash
git add migrations/
git commit -m "feat: add products migration"

# Then run fix
sqlspec --config myapp.config fix
```

### 2. Run Fix in CI, Not Locally

Let CI handle conversion to avoid inconsistencies:

```yaml
# Good: CI converts before merge
on:
  pull_request:
    branches: [main]

# Bad: Manual conversion on developer machines
```

### 3. Test Migrations Before Fix

Ensure migrations work before converting:

```bash
# Test on development database
sqlspec --config test.config upgrade
sqlspec --config test.config downgrade

# Then convert
sqlspec --config myapp.config fix
```

### 4. Keep Backup Until Verified

Don't delete the backup immediately:

```bash
# Convert
sqlspec --config myapp.config fix

# Test deployment
sqlspec --config prod.config upgrade

# Only then remove backup
rm -rf migrations/.backup_*
```

### 5. Document Your Workflow

Add to your project's CONTRIBUTING.md:

```markdown
## Migrations

- Create migrations with: `sqlspec --config myapp.config create-migration -m "description"`
- Migrations use timestamp format during development
- CI automatically converts to sequential before merge
- Never manually rename migration files
```

## Example Workflows

### GitHub Actions (Recommended)

```yaml
# .github/workflows/fix-migrations.yml
name: Fix Migrations

on:
  pull_request:
    branches: [main]
    paths: ['migrations/**']

jobs:
  fix:
    runs-on: ubuntu-latest

    permissions:
      contents: write

    steps:
      - uses: actions/checkout@v4
        with:
          ref: ${{ github.head_ref }}
          token: ${{ secrets.GITHUB_TOKEN }}

      - uses: actions/setup-python@v5
        with:
          python-version: '3.12'

      - name: Install dependencies
        run: |
          pip install sqlspec[cli]
          pip install -e .

      - name: Fix migrations
        run: |
          sqlspec --config myapp.config fix --yes --no-database

      - name: Commit and push
        run: |
          git config user.name "github-actions[bot]"
          git config user.email "github-actions[bot]@users.noreply.github.com"
          git add migrations/
          if ! git diff --quiet HEAD migrations/; then
            git commit -m "fix: convert migrations to sequential format"
            git push
          else
            echo "No changes to commit"
          fi
```

### GitLab CI

```yaml
# .gitlab-ci.yml
fix-migrations:
  stage: migrate
  image: python:3.12
  before_script:
    - pip install sqlspec[cli]
  script:
    - sqlspec --config myapp.config fix --yes --no-database
    - |
      if ! git diff --quiet migrations/; then
        git config user.name "GitLab CI"
        git config user.email "ci@gitlab.com"
        git add migrations/
        git commit -m "fix: convert migrations to sequential"
        git push origin HEAD:$CI_COMMIT_REF_NAME
      fi
  only:
    refs:
      - merge_requests
    changes:
      - migrations/**
```

### Manual Workflow

If you can't use CI:

```bash
# 1. Create feature branch
git checkout -b feature/new-stuff

# 2. Create migration
sqlspec --config myapp.config create-migration -m "add new stuff"
# Creates: 20251012120000_add_new_stuff.sql

# 3. Commit
git add migrations/
git commit -m "feat: add new stuff migration"

# 4. Before merging, convert to sequential
sqlspec --config myapp.config fix --dry-run   # Preview
sqlspec --config myapp.config fix             # Apply

# 5. Commit converted files
git add migrations/
git commit -m "fix: convert to sequential format"

# 6. Merge to main
git checkout main
git merge feature/new-stuff
```

## Troubleshooting

### Version Collision

**Problem**: `Target file already exists: 0003_add_products.sql`

**Solution**: Someone already has a migration with that number. Pull latest from main:

```bash
git pull origin main
# Then fix will assign the next available number
sqlspec --config myapp.config fix
```

### Database Out of Sync

**Problem**: Database has old timestamp versions after fix

**Solution**: Run fix with database updates:

```bash
sqlspec --config myapp.config fix --yes
```

Or manually update the tracking table:

```sql
UPDATE ddl_migrations
SET version_num = '0003'
WHERE version_num = '20251011120000';
```

### After Pulling Fixed Migrations

**Problem**: Teammate ran `fix` and merged to main. You pull changes and your local database still has the timestamp version.

**Example**:

- Your database: `version_num = '20251011120000'`
- Migration file (after pull): `0003_add_users.sql`

**Solution (Automatic)**: Just run `upgrade` - auto-sync handles it:

```bash
git pull origin main                     # Get renamed migration files
sqlspec --config myapp.config upgrade    # Auto-sync updates: 20251011120000 → 0003
```

**Solution (Manual)**: If you disabled auto-sync, run `fix`:

```bash
git pull origin main                     # Get renamed migration files
sqlspec --config myapp.config fix        # Updates your database: 20251011120000 → 0003
sqlspec --config myapp.config upgrade    # Now sees 0003 already applied
```

**Why this happens**: Migration files were renamed but your local database still references the old timestamp version.

**Best Practice**: Enable auto-sync (default) for automatic reconciliation. See [Auto-Sync section](#auto-sync-the-fix-command-on-autopilot) for details.

### CI Fails to Push

**Problem**: CI can't push converted migrations

**Solution**: Check repository permissions:

- GitHub: Enable "Allow GitHub Actions to create and approve pull requests"
- GitLab: Use an access token with `write_repository` scope

### Mixed Formats After Merge

**Problem**: Some migrations are timestamp, some sequential

**Solution**: This is normal during transition. Run fix to convert the remaining ones:

```bash
sqlspec --config myapp.config fix
```

## Migration from Sequential-Only

If you're currently using sequential-only migrations:

1. Continue using sequential for existing migrations
2. New migrations can use timestamps
3. Run `fix` before each merge to convert

No migration history is lost - the command only converts timestamps that exist.
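+
+Format detection is purely mechanical: a version prefix of exactly 14 digits
+(`YYYYMMDDHHmmss`) is treated as a timestamp, and a shorter run of digits as
+sequential. A minimal sketch of that check - illustrative only, not SQLSpec's
+internal code:
+
+```python
+import re
+from pathlib import Path
+
+TIMESTAMP = re.compile(r"^\d{14}$")  # YYYYMMDDHHmmss
+
+def needs_conversion(filename: str) -> bool:
+    """True when the version prefix is a timestamp rather than sequential."""
+    version = filename.split("_", 1)[0]
+    return bool(TIMESTAMP.match(version))
+
+# Files that `fix` would rename in a hypothetical migrations/ directory
+pending = [f.name for f in sorted(Path("migrations").glob("*.sql")) if needs_conversion(f.name)]
+```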
+ +## Migration from Timestamp-Only + +If you're currently using timestamp-only migrations: + +1. Run `fix` once to convert all existing timestamps +2. Continue using timestamps for new migrations +3. Run `fix` in CI for future conversions + +```bash +# One-time conversion +sqlspec --config myapp.config fix --dry-run # Preview +sqlspec --config myapp.config fix # Convert all + +git add migrations/ +git commit -m "chore: convert all migrations to sequential" +git push +``` + +## Auto-Sync: The Fix Command on Autopilot + +SQLSpec now automatically reconciles renamed migrations when you run `upgrade`. No more manual `fix` commands after pulling changes. + +### How It Works + +When you run `upgrade`, SQLSpec: + +1. Checks if migration files have been renamed (timestamp → sequential) +2. Validates checksums match between old and new versions +3. Auto-updates your database tracking to match the renamed files +4. Proceeds with normal migration workflow + +This happens transparently - you just run `upgrade` and it works. + +### Usage Scenarios + +#### Scenario 1: Pull and Go (The Happy Path) + +Your teammate merged a PR that converted migrations to sequential format. + +```bash +# Your database before pull +SELECT version_num FROM ddl_migrations; +# 20251011120000 ← timestamp format + +git pull origin main + +# Migration files after pull +ls migrations/ +# 0001_initial.sql +# 0002_add_users.sql +# 0003_add_products.sql ← was 20251011120000_add_products.sql + +# Just run upgrade - auto-sync handles everything +sqlspec --config myapp.config upgrade + +# Output: +# Reconciled 1 version record(s) +# Already at latest version + +# Your database after upgrade +SELECT version_num FROM ddl_migrations; +# 0003 ← automatically updated! +``` + +**Before auto-sync**: You'd need to manually run `fix` or update the database yourself. + +**With auto-sync**: Just `upgrade` and continue working. + +#### Scenario 2: Team Workflow (Multiple PRs) + +Three developers working on different features. 
+ +```bash +# Alice (feature/products branch) +sqlspec --config myapp.config create-migration -m "add products table" +# Creates: 20251011120000_add_products.sql + +# Bob (feature/orders branch) +sqlspec --config myapp.config create-migration -m "add orders table" +# Creates: 20251011121500_add_orders.sql + +# Carol (feature/invoices branch) +sqlspec --config myapp.config create-migration -m "add invoices table" +# Creates: 20251011123000_add_invoices.sql +``` + +**Alice's PR merges first:** + +```bash +# CI runs: sqlspec --config myapp.config fix --yes --no-database +# Renames: 20251011120000_add_products.sql → 0003_add_products.sql +# Merged to main +``` + +**Bob pulls and continues:** + +```bash +git pull origin main + +# Bob's local database still has: 20251011120000 (Alice's old timestamp) +# Bob's migration files now have: 0003_add_products.sql (Alice's renamed) + +# Bob just runs upgrade to apply his changes +sqlspec --config myapp.config upgrade + +# Output: +# Reconciled 1 version record(s) ← Alice's migration auto-synced +# Found 1 pending migrations +# Applying 20251011121500: add orders table +# ✓ Applied in 15ms +``` + +**Bob's PR merges second:** + +```bash +# CI converts Bob's timestamp → 0004 +# Merged to main +``` + +**Carol pulls and continues:** + +```bash +git pull origin main + +# Carol's local database has: +# - 20251011120000 (Alice's old timestamp) +# - 20251011121500 (Bob's old timestamp) + +# Carol's migration files now have: +# - 0003_add_products.sql (Alice's renamed) +# - 0004_add_orders.sql (Bob's renamed) + +sqlspec --config myapp.config upgrade + +# Output: +# Reconciled 2 version record(s) ← Both auto-synced! +# Found 1 pending migrations +# Applying 20251011123000: add invoices table +# ✓ Applied in 12ms +``` + +**Key takeaway**: No manual intervention needed. Each developer just pulls and runs `upgrade`. + +#### Scenario 3: Production Deployment + +Your production database has never seen timestamp versions. + +```bash +# Production database +SELECT version_num FROM ddl_migrations; +# 0001 +# 0002 +# No timestamps - only sequential + +# Deploy new version with migrations 0003, 0004, 0005 +sqlspec --config prod.config upgrade + +# Output: +# Found 3 pending migrations +# Applying 0003: add products table +# ✓ Applied in 45ms +# Applying 0004: add orders table +# ✓ Applied in 32ms +# Applying 0005: add invoices table +# ✓ Applied in 28ms +``` + +**Key takeaway**: Production never sees timestamps. Auto-sync is a no-op when all versions are already sequential. + +#### Scenario 4: Staging Environment Sync + +Staging database has old timestamp versions from before you adopted hybrid versioning. 
+ +```bash +# Staging database (mixed state) +SELECT version_num FROM ddl_migrations; +# 0001 +# 0002 +# 20251008100000 ← old timestamp from before hybrid versioning +# 20251009150000 ← old timestamp +# 20251010180000 ← old timestamp + +# Migration files (after fix command ran in CI) +ls migrations/ +# 0001_initial.sql +# 0002_add_users.sql +# 0003_add_feature_x.sql ← was 20251008100000 +# 0004_add_feature_y.sql ← was 20251009150000 +# 0005_add_feature_z.sql ← was 20251010180000 +# 0006_new_feature.sql ← new migration + +sqlspec --config staging.config upgrade + +# Output: +# Reconciled 3 version record(s) +# Found 1 pending migrations +# Applying 0006: new feature +# ✓ Applied in 38ms + +# Staging database (cleaned up) +SELECT version_num FROM ddl_migrations; +# 0001 +# 0002 +# 0003 ← auto-synced from 20251008100000 +# 0004 ← auto-synced from 20251009150000 +# 0005 ← auto-synced from 20251010180000 +# 0006 ← newly applied +``` + +**Key takeaway**: Auto-sync gradually cleans up old timestamp versions as you deploy. No manual database updates needed. + +#### Scenario 5: Checksum Validation (Safety Check) + +Someone manually edited a migration file after it was applied. + +```bash +# Database has timestamp version +SELECT version_num, checksum FROM ddl_migrations WHERE version_num = '20251011120000'; +# 20251011120000 | a1b2c3d4e5f6... + +# Migration file renamed but content changed +cat migrations/0003_add_products.sql +# Different SQL than what was originally applied + +sqlspec --config myapp.config upgrade + +# Output: +# Checksum mismatch for 20251011120000 → 0003, skipping auto-sync +# Found 0 pending migrations + +# Database unchanged - safely prevented incorrect sync +SELECT version_num FROM ddl_migrations; +# 20251011120000 ← still has old version (not auto-synced) +``` + +**Key takeaway**: Auto-sync validates checksums before updating. Protects against corruption or incorrect renames. + +### Configuration Options + +#### Enable/Disable Auto-Sync + +Auto-sync is enabled by default. Disable via config: + +```python +from sqlspec.adapters.asyncpg import AsyncpgConfig + +config = AsyncpgConfig( + pool_config={"dsn": "postgresql://localhost/mydb"}, + migration_config={ + "script_location": "migrations", + "enabled": True, + "auto_sync": False # Disable auto-sync + } +) +``` + +#### Disable Per-Command + +Disable for a single migration run: + +```bash +sqlspec --config myapp.config upgrade --no-auto-sync +``` + +Useful when you want explicit control over version reconciliation. + +### When to Disable Auto-Sync + +Auto-sync is safe for most workflows, but disable if: + +1. **You want explicit control**: Run `fix` manually to see exactly what's being updated +2. **Custom migration workflows**: You're using non-standard migration file organization +3. **Debugging**: Isolate whether auto-sync is causing unexpected behavior + +### Troubleshooting Auto-Sync + +#### Auto-Sync Not Reconciling + +**Problem**: Auto-sync reports 0 reconciled records but you expected some. + +**Possible causes:** + +1. **Already synced**: Database already has sequential versions +2. **No conversion map**: No timestamp migrations found in files +3. 
**Version already exists**: New version already applied (edge case from parallel execution) + +**Debug steps:** + +```bash +# Check what's in your database +sqlspec --config myapp.config current --verbose + +# Check what's in your migration files +ls -la migrations/ + +# Try explicit fix to see what would convert +sqlspec --config myapp.config fix --dry-run +``` + +#### Checksum Mismatch Warnings + +**Problem**: `Checksum mismatch for X → Y, skipping auto-sync` + +**Cause**: Migration content changed between when it was applied and when it was renamed. + +**Solution:** + +```bash +# Option 1: Manual fix (if change was intentional) +sqlspec --config myapp.config fix --yes + +# Option 2: Revert file changes (if change was accidental) +git checkout migrations/0003_add_products.sql +``` + +### Migration from Manual Fix Workflow + +If you're currently using the manual `fix` workflow, auto-sync is backward-compatible: + +```bash +# Old workflow (still works) +git pull origin main +sqlspec --config myapp.config fix # Manual sync +sqlspec --config myapp.config upgrade + +# New workflow (auto-sync handles it) +git pull origin main +sqlspec --config myapp.config upgrade # Auto-sync runs automatically +``` + +Both workflows produce identical results. Auto-sync just eliminates the manual step. + +### Best Practices with Auto-Sync + +1. **Trust auto-sync in dev/staging**: Let it handle reconciliation automatically +2. **Monitor in production**: Check reconciliation output in deployment logs +3. **Use --no-auto-sync for debugging**: Disable temporarily to isolate issues +4. **Keep checksums intact**: Don't edit migration files after they're applied + +## Advanced Topics + +### Extension Migrations + +**Important**: The `fix` command only affects **user-created** migrations, not packaged extension migrations that ship with SQLSpec. + +#### Packaged Extension Migrations (NOT affected by `fix`) + +Migrations included with SQLSpec extensions are **always sequential**: + +``` +sqlspec/extensions/ +├── adk/migrations/ +│ └── 0001_create_adk_tables.py ← Always sequential +└── litestar/migrations/ + └── 0001_create_session_table.py ← Always sequential + +Database tracking: +- ext_adk_0001 +- ext_litestar_0001 +``` + +These are pre-built migrations that ship with the library and are never converted. + +#### User-Created Extension Migrations (Affected by `fix`) + +If you create custom migrations for extension functionality, they follow the standard hybrid workflow: + +``` +Before fix (your development branch): +├── 0001_initial.sql +├── ext_adk_0001_create_adk_tables.sql ← Packaged (sequential) +├── 20251011120000_custom_adk_columns.sql ← Your custom migration (timestamp) + +After fix (merged to main): +├── 0001_initial.sql +├── ext_adk_0001_create_adk_tables.sql ← Unchanged (packaged) +├── 0002_custom_adk_columns.sql ← Converted to sequential +``` + +Each extension has its own sequence counter for user-created migrations. 
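+
+Because each prefix is numbered independently, computing the next sequential
+version is just a scan over existing file names. A rough sketch of the idea
+(hypothetical helper, not part of SQLSpec's API):
+
+```python
+import re
+from pathlib import Path
+
+def next_sequential(migrations_dir: str, prefix: str = "") -> str:
+    """Next zero-padded version for a prefix ('' = main, 'ext_myext_' = an extension)."""
+    pattern = re.compile(rf"^{re.escape(prefix)}(\d+)_")
+    numbers = [
+        int(m.group(1))
+        for f in Path(migrations_dir).iterdir()
+        if (m := pattern.match(f.name)) and len(m.group(1)) < 14  # skip timestamps
+    ]
+    return f"{prefix}{max(numbers, default=0) + 1:04d}"
+```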
+ +### Multiple Databases + +When using multiple database configurations: + +```bash +# Fix migrations for specific database +sqlspec --config myapp.config fix --bind-key postgres + +# Or fix all +sqlspec --config myapp.config fix +``` + +### Custom Migration Paths + +Works with custom migration directories: + +```python +# config.py +AsyncpgConfig( + pool_config={"dsn": "..."}, + migration_config={ + "script_location": "db/migrations", # Custom path + "enabled": True + } +) +``` + +```bash +sqlspec --config myapp.config fix +# Converts migrations in db/migrations/ +``` + +## Performance + +The `fix` command is designed for fast execution: + +- File operations are atomic (rename only) +- Database updates use single transaction +- Backup is file-system copy (instant) +- No migration re-execution + +Typical conversion time: < 1 second for 100 migrations. + +## See Also + +- [CLI Reference](../../usage/cli.rst) - Complete `fix` command documentation +- [Configuration Guide](../../usage/configuration.rst) - Migration configuration options +- [Best Practices](../best-practices.md) - General migration best practices + +## Summary + +Hybrid versioning with the `fix` command provides: + +- **Zero merge conflicts** - Timestamps during development +- **Deterministic ordering** - Sequential in production +- **Automatic conversion** - CI handles the switch +- **Safe operations** - Automatic backup and rollback +- **Database sync** - Version tracking stays current + +Start using hybrid versioning today: + +```bash +# Preview conversion +sqlspec --config myapp.config fix --dry-run + +# Apply conversion +sqlspec --config myapp.config fix + +# Set up CI workflow (see examples above) +``` diff --git a/docs/usage/cli.rst b/docs/usage/cli.rst index 3c86f520..6c8e1056 100644 --- a/docs/usage/cli.rst +++ b/docs/usage/cli.rst @@ -137,7 +137,7 @@ You should see available commands: .. code-block:: text - create-migration downgrade init show-config show-current-revision stamp upgrade + create-migration downgrade fix init show-config show-current-revision stamp upgrade Try option completion: @@ -427,11 +427,16 @@ Apply pending migrations up to a specific revision. ``--no-prompt`` Skip confirmation prompt. +``--no-auto-sync`` + Disable automatic version reconciliation. When enabled (default), SQLSpec automatically + updates database tracking when migrations are renamed from timestamp to sequential format. + Use this flag when you want explicit control over version reconciliation. + **Examples:** .. code-block:: bash - # Upgrade to latest + # Upgrade to latest (with auto-sync enabled by default) sqlspec --config myapp.config upgrade # Upgrade to specific revision @@ -452,6 +457,9 @@ Apply pending migrations up to a specific revision. # No confirmation sqlspec --config myapp.config upgrade --no-prompt + # Disable auto-sync for manual control + sqlspec --config myapp.config upgrade --no-auto-sync + downgrade ^^^^^^^^^ @@ -512,6 +520,128 @@ Rollback migrations to a specific revision. Downgrade operations can result in data loss. Always backup your database before running downgrade commands in production. +fix +^^^ + +Convert timestamp migrations to sequential format for hybrid versioning workflow. + +.. 
code-block:: bash + + sqlspec --config myapp.config fix [OPTIONS] + +**Purpose:** + +The ``fix`` command implements a hybrid versioning workflow that combines the benefits +of both timestamp and sequential migration numbering: + +- **Development**: Use timestamps to avoid merge conflicts +- **Production**: Use sequential numbers for deterministic ordering + +This command converts timestamp-format migrations (YYYYMMDDHHmmss) to sequential +format (0001, 0002, etc.) while preserving migration history in the database. + +**Options:** + +``--bind-key KEY`` + Target specific config. + +``--dry-run`` + Preview changes without applying them. + +``--yes`` + Skip confirmation prompt (useful for CI/CD). + +``--no-database`` + Only rename files, skip database record updates. + +**Examples:** + +.. code-block:: bash + + # Preview what would change + sqlspec --config myapp.config fix --dry-run + + # Apply changes with confirmation + sqlspec --config myapp.config fix + + # CI/CD mode (auto-approve) + sqlspec --config myapp.config fix --yes + + # Only fix files, don't update database + sqlspec --config myapp.config fix --no-database + +**Before Fix:** + +.. code-block:: text + + migrations/ + ├── 0001_initial.sql + ├── 0002_add_users.sql + ├── 20251011120000_add_products.sql # Timestamp format + ├── 20251012130000_add_orders.sql # Timestamp format + +**After Fix:** + +.. code-block:: text + + migrations/ + ├── 0001_initial.sql + ├── 0002_add_users.sql + ├── 0003_add_products.sql # Converted to sequential + ├── 0004_add_orders.sql # Converted to sequential + +**What Gets Updated:** + +1. **File Names**: ``20251011120000_add_products.sql`` → ``0003_add_products.sql`` +2. **SQL Query Names**: ``-- name: migrate-20251011120000-up`` → ``-- name: migrate-0003-up`` +3. **Database Records**: Version tracking table updated to reflect new version numbers + +**Backup & Safety:** + +The command automatically creates a timestamped backup before making changes: + +.. code-block:: text + + migrations/ + ├── .backup_20251012_143022/ # Automatic backup + │ ├── 20251011120000_add_products.sql + │ └── 20251012130000_add_orders.sql + ├── 0003_add_products.sql + └── 0004_add_orders.sql + +If conversion fails, files are automatically restored from backup. +Remove backup with ``rm -rf migrations/.backup_*`` after verifying success. + +**Auto-Sync Integration:** + +As of SQLSpec 0.18+, the ``upgrade`` command automatically reconciles renamed migrations +when you pull changes from teammates. This means developers typically don't need to run +``fix`` manually after pulling - just run ``upgrade`` and it handles reconciliation +automatically. + +The ``fix`` command is still useful for: + +- **Pre-merge CI**: Convert timestamps before merging to main branch +- **Initial conversion**: One-time conversion of existing timestamp migrations +- **Manual control**: When you've disabled auto-sync and want explicit control + +See the :ref:`hybrid-versioning-guide` for complete workflows and examples. + +**Use Cases:** + +- **Pre-merge CI check**: Convert timestamps before merging to main branch +- **Production deployment**: Ensure deterministic migration ordering +- **Repository cleanup**: Standardize on sequential format after development + +.. seealso:: + + :ref:`hybrid-versioning-guide` for complete workflow documentation and best practices. + +.. warning:: + + Always commit migration files before running ``fix`` command. While automatic + backups are created, version control provides the safest recovery option. 
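+
+The query-name rewrite described under **What Gets Updated** is a plain text
+substitution. A minimal sketch of the idea (illustrative only, not SQLSpec's
+internal implementation):
+
+.. code-block:: python
+
+    import re
+    from pathlib import Path
+
+    def rewrite_query_names(path: "Path", old: str, new: str) -> None:
+        """Rewrite ``-- name: migrate-<version>-up/down`` markers in place."""
+        text = path.read_text()
+        text = re.sub(rf"migrate-{re.escape(old)}-(up|down)", rf"migrate-{new}-\1", text)
+        path.write_text(text)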
+ stamp ^^^^^ diff --git a/docs/usage/index.rst b/docs/usage/index.rst index d382c90d..3961c70d 100644 --- a/docs/usage/index.rst +++ b/docs/usage/index.rst @@ -13,6 +13,7 @@ This section provides comprehensive guides on using SQLSpec for database operati query_builder sql_files cli + migrations framework_integrations Overview @@ -38,6 +39,10 @@ SQLSpec provides a unified interface for database operations across multiple bac **Command Line Interface** Use the SQLSpec CLI for migrations, with shell completion support for bash, zsh, and fish. +**Database Migrations** + Manage database schema changes with support for hybrid versioning, automatic schema migration, + and extension migrations. + **Framework Integrations** Integrate SQLSpec with Litestar, FastAPI, and other Python web frameworks. diff --git a/docs/usage/migrations.rst b/docs/usage/migrations.rst new file mode 100644 index 00000000..7e549877 --- /dev/null +++ b/docs/usage/migrations.rst @@ -0,0 +1,562 @@ +.. _migrations-guide: + +=================== +Database Migrations +=================== + +SQLSpec provides a comprehensive migration system for managing database schema changes +over time. The migration system supports both SQL and Python migrations with automatic +tracking, version reconciliation, and hybrid versioning workflows. + +.. contents:: Table of Contents + :local: + :depth: 2 + +Quick Start +=========== + +Initialize Migrations +--------------------- + +.. code-block:: bash + + # Initialize migration directory + sqlspec --config myapp.config init + + # Create your first migration + sqlspec --config myapp.config create-migration -m "Initial schema" + + # Apply migrations + sqlspec --config myapp.config upgrade + +Configuration +============= + +Enable migrations in your SQLSpec configuration: + +.. code-block:: python + + from sqlspec.adapters.asyncpg import AsyncpgConfig + + config = AsyncpgConfig( + pool_config={"dsn": "postgresql://user:pass@localhost/mydb"}, + migration_config={ + "enabled": True, + "script_location": "migrations", + "version_table_name": "ddl_migrations", + "auto_sync": True, # Enable automatic version reconciliation + } + ) + +Configuration Options +--------------------- + +``enabled`` + **Type:** ``bool`` + **Default:** ``False`` + + Enable or disable migrations for this configuration. + +``script_location`` + **Type:** ``str`` + **Default:** ``"migrations"`` + + Path to migration files directory (relative to project root). + +``version_table_name`` + **Type:** ``str`` + **Default:** ``"ddl_migrations"`` + + Name of the table used to track applied migrations. + +``auto_sync`` + **Type:** ``bool`` + **Default:** ``True`` + + Enable automatic version reconciliation when migrations are renamed. + When ``True``, the ``upgrade`` command automatically updates database + tracking when migrations have been converted from timestamp to sequential + format using the ``fix`` command. + +``project_root`` + **Type:** ``Path | str | None`` + **Default:** ``None`` + + Root directory for Python migration imports. If not specified, uses + the parent directory of ``script_location``. + +Migration Files +=============== + +SQL Migrations +-------------- + +SQL migrations use the aiosql-style named query format: + +.. 
code-block:: sql + + -- migrations/0001_initial.sql + + -- name: migrate-0001-up + CREATE TABLE users ( + id SERIAL PRIMARY KEY, + email TEXT NOT NULL UNIQUE, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ); + + CREATE INDEX idx_users_email ON users(email); + + -- name: migrate-0001-down + DROP TABLE users; + +**Naming Convention:** + +- File: ``{version}_{description}.sql`` +- Upgrade query: ``migrate-{version}-up`` +- Downgrade query: ``migrate-{version}-down`` (optional) + +Python Migrations +----------------- + +Python migrations provide more flexibility for complex operations: + +.. code-block:: python + + # migrations/0002_add_user_roles.py + """Add user roles table + + Revision ID: 0002_add_user_roles + Created at: 2025-10-18 12:00:00 + """ + + def upgrade(): + """Apply migration.""" + return """ + CREATE TABLE user_roles ( + id SERIAL PRIMARY KEY, + user_id INTEGER REFERENCES users(id), + role VARCHAR(50) NOT NULL + ); + """ + + def downgrade(): + """Revert migration.""" + return """ + DROP TABLE user_roles; + """ + +**Advanced Usage:** + +Python migrations can also return a list of SQL statements: + +.. code-block:: python + + def upgrade(): + """Apply migration in multiple steps.""" + return [ + "CREATE TABLE products (id SERIAL PRIMARY KEY);", + "CREATE TABLE orders (id SERIAL PRIMARY KEY, product_id INTEGER);", + "CREATE INDEX idx_orders_product ON orders(product_id);", + ] + +.. _hybrid-versioning-guide: + +Hybrid Versioning +================= + +SQLSpec supports a hybrid versioning workflow that combines timestamp-based versions +during development with sequential versions in production. + +Overview +-------- + +**Problem:** Timestamp versions (``20251018120000``) prevent merge conflicts when multiple +developers create migrations simultaneously, but sequential versions (``0001``) provide +more predictable ordering in production. + +**Solution:** Use timestamps during development, then convert to sequential numbers before +deploying to production using the ``fix`` command. + +Workflow +-------- + +**1. Development - Use Timestamps** + +.. code-block:: bash + + # Developer A creates migration + sqlspec --config myapp.config create-migration -m "Add users table" + # Creates: 20251018120000_add_users_table.sql + + # Developer B creates migration (same day) + sqlspec --config myapp.config create-migration -m "Add products table" + # Creates: 20251018123000_add_products_table.sql + +**2. Pre-Merge - Convert to Sequential** + +Before merging to main branch (typically in CI): + +.. code-block:: bash + + # Preview changes + sqlspec --config myapp.config fix --dry-run + + # Apply conversion + sqlspec --config myapp.config fix --yes + + # Results: + # 20251018120000_add_users_table.sql → 0001_add_users_table.sql + # 20251018123000_add_products_table.sql → 0002_add_products_table.sql + +**3. After Pull - Auto-Sync** + +When teammates pull your converted migrations, they don't need to do anything special: + +.. code-block:: bash + + git pull origin main + + # Just run upgrade - auto-sync handles reconciliation + sqlspec --config myapp.config upgrade + +Auto-sync automatically detects renamed migrations using checksums and updates +the database tracking table to reflect the new version numbers. 
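+
+Conceptually, reconciliation is a checksum join between the tracking table and
+the files on disk. A simplified sketch of that matching step (illustrative
+names, not the actual implementation):
+
+.. code-block:: python
+
+    import hashlib
+    from pathlib import Path
+
+    def checksum(path: "Path") -> str:
+        """MD5 of migration content, mirroring the tracking table's checksum column."""
+        return hashlib.md5(path.read_bytes()).hexdigest()
+
+    def reconcile(applied: "dict[str, str]", files: "list[Path]") -> "dict[str, str]":
+        """Map applied-but-renamed versions to their new on-disk versions.
+
+        ``applied`` holds ``{version_num: checksum}`` rows from the tracking
+        table; ``files`` are the current migration files.
+        """
+        on_disk = {checksum(f): f.name.split("_", 1)[0] for f in files}
+        return {
+            old: on_disk[digest]
+            for old, digest in applied.items()
+            if digest in on_disk and on_disk[digest] != old
+        }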
+ +Version Formats +--------------- + +**Sequential Format** + Pattern: ``^(\d+)$`` + + Examples: ``0001``, ``0042``, ``9999``, ``10000`` + + - Used in production + - Deterministic ordering + - Human-readable sequence + - No upper limit (4-digit cap removed) + +**Timestamp Format** + Pattern: ``^(\d{14})$`` + + Example: ``20251018120000`` (2025-10-18 12:00:00 UTC) + + - Used during development + - Prevents merge conflicts + - Chronologically ordered + - UTC timezone + +Version Comparison +------------------ + +SQLSpec uses type-aware version comparison: + +.. code-block:: python + + from sqlspec.utils.version import parse_version + + v1 = parse_version("0001") + v2 = parse_version("20251018120000") + + # Sequential < Timestamp (by design) + assert v1 < v2 + + # Same type comparisons work naturally + assert parse_version("0001") < parse_version("0002") + assert parse_version("20251018120000") < parse_version("20251019120000") + +Migration Tracking +================== + +Schema +------ + +SQLSpec uses a tracking table to record applied migrations: + +.. code-block:: sql + + CREATE TABLE ddl_migrations ( + version_num VARCHAR(32) PRIMARY KEY, + version_type VARCHAR(16), -- 'sequential' or 'timestamp' + execution_sequence INTEGER, -- Order of execution + description TEXT, + applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + execution_time_ms INTEGER, + checksum VARCHAR(64), -- MD5 hash for auto-sync + applied_by VARCHAR(255) + ); + +**Columns:** + +``version_num`` + The migration version (e.g., ``"0001"`` or ``"20251018120000"``). + +``version_type`` + Format indicator: ``"sequential"`` or ``"timestamp"``. + +``execution_sequence`` + Auto-incrementing counter showing actual application order. + Preserves history when out-of-order migrations are applied. + +``checksum`` + MD5 hash of migration content. Used by auto-sync to match + renamed migrations (e.g., timestamp → sequential conversion). + +``applied_by`` + Unix username of user who applied the migration. + +Schema Migration +---------------- + +When upgrading from older SQLSpec versions, the tracking table schema is automatically +migrated to add the new columns (``execution_sequence``, ``version_type``, ``checksum``). + +This happens transparently when you run any migration command: + +.. code-block:: bash + + # First upgrade after updating SQLSpec + sqlspec --config myapp.config upgrade + + # Output: + # Migrating tracking table schema, adding columns: checksum, execution_sequence, version_type + # Migration tracking table schema updated successfully + +The schema migration: + +1. Detects missing columns using database metadata queries +2. Adds columns one by one using ``ALTER TABLE`` +3. Populates ``execution_sequence`` based on ``applied_at`` timestamps +4. Preserves all existing migration history + +Extension Migrations +==================== + +SQLSpec supports independent migration versioning for extensions and plugins. + +Configuration +------------- + +.. code-block:: python + + config = AsyncpgConfig( + pool_config={"dsn": "postgresql://..."}, + migration_config={ + "enabled": True, + "script_location": "migrations", + "include_extensions": ["litestar"], # Enable litestar extension migrations + }, + extension_config={ + "litestar": { + "enable_repository_pattern": True, + "enable_dto_generation": False, + } + } + ) + +Directory Structure +------------------- + +Extension migrations are stored separately: + +.. 
code-block:: text + + migrations/ + ├── 0001_initial.sql # Main migrations + ├── 0002_add_users.sql + └── (extension migrations stored in package) + + # Extension migrations location (in package): + sqlspec/extensions/litestar/migrations/ + ├── 0001_create_litestar_metadata.sql + └── 0002_add_request_logging.sql + +Version Prefixes +---------------- + +Extension migrations are prefixed to avoid conflicts: + +.. code-block:: text + + Main migrations: 0001, 0002, 0003 + Litestar migrations: ext_litestar_0001, ext_litestar_0002 + Custom extension: ext_myext_0001, ext_myext_0002 + +This allows each extension to maintain its own sequential numbering while +preventing version conflicts. + +Commands +-------- + +Extension migrations are managed alongside main migrations: + +.. code-block:: bash + + # Upgrade includes extension migrations + sqlspec --config myapp.config upgrade + + # Show all migrations (including extensions) + sqlspec --config myapp.config show-current-revision --verbose + +Advanced Topics +=============== + +Out-of-Order Migrations +----------------------- + +When migrations are created out of chronological order (e.g., from late-merging branches), +SQLSpec detects this and logs a warning: + +.. code-block:: text + + WARNING: Out-of-order migration detected + Migration 20251017100000_feature_a was created before + already-applied migration 20251018090000_main_branch + + This can happen when: + - A feature branch was created before a migration on main + - Migrations from different branches are merged + +The migration is still applied, and ``execution_sequence`` preserves the actual +application order for auditing. + +Manual Version Reconciliation +------------------------------ + +If auto-sync is disabled, manually reconcile renamed migrations: + +.. code-block:: python + + from sqlspec.migrations.tracker import AsyncMigrationTracker + + tracker = AsyncMigrationTracker() + + async with config.provide_session() as session: + driver = session._driver + + # Update version record + await tracker.update_version_record( + driver, + old_version="20251018120000", + new_version="0003" + ) + +Troubleshooting +=============== + +Migration Not Applied +--------------------- + +**Symptom:** Migration exists but isn't being applied. + +**Checks:** + +1. Verify migration file naming: ``{version}_{description}.sql`` +2. Check query names: ``migrate-{version}-up`` and ``migrate-{version}-down`` +3. Ensure version isn't already in tracking table: + + .. code-block:: bash + + sqlspec --config myapp.config show-current-revision --verbose + +Version Mismatch After Fix +--------------------------- + +**Symptom:** After running ``fix``, database still shows old timestamp versions. + +**Solution:** Ensure auto-sync is enabled (default): + +.. code-block:: bash + + # Should auto-reconcile + sqlspec --config myapp.config upgrade + + # Or manually run fix with database update + sqlspec --config myapp.config fix # (database update is default) + +Schema Migration Fails +----------------------- + +**Symptom:** Error adding columns to tracking table. + +**Cause:** Usually insufficient permissions or incompatible database version. + +**Solution:** + +1. Ensure database user has ``ALTER TABLE`` permissions +2. Check database version compatibility +3. Manually add missing columns if needed: + + .. 
code-block:: sql + + ALTER TABLE ddl_migrations ADD COLUMN execution_sequence INTEGER; + ALTER TABLE ddl_migrations ADD COLUMN version_type VARCHAR(16); + ALTER TABLE ddl_migrations ADD COLUMN checksum VARCHAR(64); + +Best Practices +============== + +1. **Always Use Version Control** + + Commit migration files immediately after creation: + + .. code-block:: bash + + git add migrations/ + git commit -m "Add user authentication migration" + +2. **Test Migrations Both Ways** + + Always test both upgrade and downgrade: + + .. code-block:: bash + + sqlspec --config myapp.config upgrade + sqlspec --config myapp.config downgrade + +3. **Use Dry Run in Production** + + Preview changes before applying: + + .. code-block:: bash + + sqlspec --config myapp.config upgrade --dry-run + +4. **Backup Before Downgrade** + + Downgrades can cause data loss: + + .. code-block:: bash + + pg_dump mydb > backup_$(date +%Y%m%d_%H%M%S).sql + sqlspec --config myapp.config downgrade + +5. **Run Fix in CI** + + Automate timestamp → sequential conversion: + + .. code-block:: yaml + + # .github/workflows/migrations.yml + - name: Convert timestamp migrations + run: | + sqlspec --config myapp.config fix --dry-run + sqlspec --config myapp.config fix --yes + +6. **Descriptive Migration Names** + + Use clear, action-oriented descriptions: + + .. code-block:: bash + + # Good + sqlspec --config myapp.config create-migration -m "Add email index to users" + + # Bad + sqlspec --config myapp.config create-migration -m "update users" + +See Also +======== + +- :doc:`../usage/cli` - Complete CLI command reference +- :doc:`../usage/configuration` - Migration configuration options +- :doc:`../reference/migrations_api` - Migration API reference diff --git a/sqlspec/adapters/adbc/data_dictionary.py b/sqlspec/adapters/adbc/data_dictionary.py index 999e6c5c..4701f912 100644 --- a/sqlspec/adapters/adbc/data_dictionary.py +++ b/sqlspec/adapters/adbc/data_dictionary.py @@ -1,7 +1,7 @@ """ADBC multi-dialect data dictionary for metadata queries.""" import re -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Any, cast from sqlspec.driver import SyncDataDictionaryBase, SyncDriverAdapterBase, VersionInfo from sqlspec.utils.logging import get_logger @@ -268,6 +268,56 @@ def get_optimal_type(self, driver: SyncDriverAdapterBase, type_category: str) -> return type_map.get(type_category, "TEXT") + def get_columns( + self, driver: SyncDriverAdapterBase, table: str, schema: "str | None" = None + ) -> "list[dict[str, Any]]": + """Get column information for a table based on detected dialect. 
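+
+        SQLite is served via ``PRAGMA table_info``; every other detected
+        dialect falls back to ``information_schema.columns``.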
+ + Args: + driver: ADBC driver instance + table: Table name to query columns for + schema: Schema name (None for default) + + Returns: + List of column metadata dictionaries with keys: + - column_name: Name of the column + - data_type: Database data type + - is_nullable or nullable: Whether column allows NULL + - column_default or default_value: Default value if any + """ + dialect = self._get_dialect(driver) + adbc_driver = cast("AdbcDriver", driver) + + if dialect == "sqlite": + result = adbc_driver.execute(f"PRAGMA table_info({table})") + return [ + { + "column_name": row["name"] if isinstance(row, dict) else row[1], + "data_type": row["type"] if isinstance(row, dict) else row[2], + "nullable": not (row["notnull"] if isinstance(row, dict) else row[3]), + "default_value": row["dflt_value"] if isinstance(row, dict) else row[4], + } + for row in result.data or [] + ] + + if schema: + sql = f""" + SELECT column_name, data_type, is_nullable, column_default + FROM information_schema.columns + WHERE table_name = '{table}' AND table_schema = '{schema}' + ORDER BY ordinal_position + """ + else: + sql = f""" + SELECT column_name, data_type, is_nullable, column_default + FROM information_schema.columns + WHERE table_name = '{table}' + ORDER BY ordinal_position + """ + + result = adbc_driver.execute(sql) + return result.data or [] + def list_available_features(self) -> "list[str]": """List available feature flags across all supported dialects. diff --git a/sqlspec/adapters/aiosqlite/data_dictionary.py b/sqlspec/adapters/aiosqlite/data_dictionary.py index 8bdbf080..d841cce2 100644 --- a/sqlspec/adapters/aiosqlite/data_dictionary.py +++ b/sqlspec/adapters/aiosqlite/data_dictionary.py @@ -1,7 +1,7 @@ """SQLite-specific data dictionary for metadata queries via aiosqlite.""" import re -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Any, cast from sqlspec.driver import AsyncDataDictionaryBase, AsyncDriverAdapterBase, VersionInfo from sqlspec.utils.logging import get_logger @@ -99,6 +99,36 @@ async def get_optimal_type(self, driver: AsyncDriverAdapterBase, type_category: type_map = {"uuid": "TEXT", "boolean": "INTEGER", "timestamp": "TIMESTAMP", "text": "TEXT", "blob": "BLOB"} return type_map.get(type_category, "TEXT") + async def get_columns( + self, driver: AsyncDriverAdapterBase, table: str, schema: "str | None" = None + ) -> "list[dict[str, Any]]": + """Get column information for a table using SQLite PRAGMA. + + Args: + driver: AioSQLite driver instance + table: Table name to query columns for + schema: Schema name (unused in SQLite) + + Returns: + List of column metadata dictionaries with keys: + - column_name: Name of the column + - data_type: SQLite data type + - nullable: Whether column allows NULL + - default_value: Default value if any + """ + aiosqlite_driver = cast("AiosqliteDriver", driver) + result = await aiosqlite_driver.execute(f"PRAGMA table_info({table})") + + return [ + { + "column_name": row["name"] if isinstance(row, dict) else row[1], + "data_type": row["type"] if isinstance(row, dict) else row[2], + "nullable": not (row["notnull"] if isinstance(row, dict) else row[3]), + "default_value": row["dflt_value"] if isinstance(row, dict) else row[4], + } + for row in result.data or [] + ] + def list_available_features(self) -> "list[str]": """List available SQLite feature flags. 
diff --git a/sqlspec/adapters/asyncmy/data_dictionary.py b/sqlspec/adapters/asyncmy/data_dictionary.py index 73af49b4..c8bd8142 100644 --- a/sqlspec/adapters/asyncmy/data_dictionary.py +++ b/sqlspec/adapters/asyncmy/data_dictionary.py @@ -1,7 +1,7 @@ """MySQL-specific data dictionary for metadata queries via asyncmy.""" import re -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Any, cast from sqlspec.driver import AsyncDataDictionaryBase, AsyncDriverAdapterBase, VersionInfo from sqlspec.utils.logging import get_logger @@ -104,6 +104,43 @@ async def get_optimal_type(self, driver: AsyncDriverAdapterBase, type_category: } return type_map.get(type_category, "VARCHAR(255)") + async def get_columns( + self, driver: AsyncDriverAdapterBase, table: str, schema: "str | None" = None + ) -> "list[dict[str, Any]]": + """Get column information for a table using information_schema. + + Args: + driver: AsyncMy driver instance + table: Table name to query columns for + schema: Schema name (database name in MySQL) + + Returns: + List of column metadata dictionaries with keys: + - column_name: Name of the column + - data_type: MySQL data type + - is_nullable: Whether column allows NULL (YES/NO) + - column_default: Default value if any + """ + asyncmy_driver = cast("AsyncmyDriver", driver) + + if schema: + sql = f""" + SELECT column_name, data_type, is_nullable, column_default + FROM information_schema.columns + WHERE table_name = '{table}' AND table_schema = '{schema}' + ORDER BY ordinal_position + """ + else: + sql = f""" + SELECT column_name, data_type, is_nullable, column_default + FROM information_schema.columns + WHERE table_name = '{table}' + ORDER BY ordinal_position + """ + + result = await asyncmy_driver.execute(sql) + return result.data or [] + def list_available_features(self) -> "list[str]": """List available MySQL feature flags. diff --git a/sqlspec/adapters/asyncpg/data_dictionary.py b/sqlspec/adapters/asyncpg/data_dictionary.py index a6b3ddb2..89f3aa28 100644 --- a/sqlspec/adapters/asyncpg/data_dictionary.py +++ b/sqlspec/adapters/asyncpg/data_dictionary.py @@ -1,7 +1,7 @@ """PostgreSQL-specific data dictionary for metadata queries via asyncpg.""" import re -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Any, cast from sqlspec.driver import AsyncDataDictionaryBase, AsyncDriverAdapterBase, VersionInfo from sqlspec.utils.logging import get_logger @@ -114,6 +114,43 @@ async def get_optimal_type(self, driver: AsyncDriverAdapterBase, type_category: } return type_map.get(type_category, "TEXT") + async def get_columns( + self, driver: AsyncDriverAdapterBase, table: str, schema: "str | None" = None + ) -> "list[dict[str, Any]]": + """Get column information for a table using information_schema. 
+ + Args: + driver: AsyncPG driver instance + table: Table name to query columns for + schema: Schema name (None for default 'public') + + Returns: + List of column metadata dictionaries with keys: + - column_name: Name of the column + - data_type: PostgreSQL data type + - is_nullable: Whether column allows NULL (YES/NO) + - column_default: Default value if any + """ + asyncpg_driver = cast("AsyncpgDriver", driver) + + if schema: + sql = f""" + SELECT column_name, data_type, is_nullable, column_default + FROM information_schema.columns + WHERE table_name = '{table}' AND table_schema = '{schema}' + ORDER BY ordinal_position + """ + else: + sql = f""" + SELECT column_name, data_type, is_nullable, column_default + FROM information_schema.columns + WHERE table_name = '{table}' AND table_schema = 'public' + ORDER BY ordinal_position + """ + + result = await asyncpg_driver.execute(sql) + return result.data or [] + def list_available_features(self) -> "list[str]": """List available PostgreSQL feature flags. diff --git a/sqlspec/adapters/bigquery/data_dictionary.py b/sqlspec/adapters/bigquery/data_dictionary.py index 80f55809..dc4eb5d3 100644 --- a/sqlspec/adapters/bigquery/data_dictionary.py +++ b/sqlspec/adapters/bigquery/data_dictionary.py @@ -1,8 +1,13 @@ """BigQuery-specific data dictionary for metadata queries.""" +from typing import TYPE_CHECKING, Any, cast + from sqlspec.driver import SyncDataDictionaryBase, SyncDriverAdapterBase, VersionInfo from sqlspec.utils.logging import get_logger +if TYPE_CHECKING: + from sqlspec.adapters.bigquery.driver import BigQueryDriver + logger = get_logger("adapters.bigquery.data_dictionary") __all__ = ("BigQuerySyncDataDictionary",) @@ -83,6 +88,43 @@ def get_optimal_type(self, driver: SyncDriverAdapterBase, type_category: str) -> } return type_map.get(type_category, "STRING") + def get_columns( + self, driver: SyncDriverAdapterBase, table: str, schema: "str | None" = None + ) -> "list[dict[str, Any]]": + """Get column information for a table using INFORMATION_SCHEMA. + + Args: + driver: BigQuery driver instance + table: Table name to query columns for + schema: Schema name (dataset name in BigQuery) + + Returns: + List of column metadata dictionaries with keys: + - column_name: Name of the column + - data_type: BigQuery data type + - is_nullable: Whether column allows NULL (YES/NO) + - column_default: Default value if any + """ + bigquery_driver = cast("BigQueryDriver", driver) + + if schema: + sql = f""" + SELECT column_name, data_type, is_nullable, column_default + FROM `{schema}.INFORMATION_SCHEMA.COLUMNS` + WHERE table_name = '{table}' + ORDER BY ordinal_position + """ + else: + sql = f""" + SELECT column_name, data_type, is_nullable, column_default + FROM INFORMATION_SCHEMA.COLUMNS + WHERE table_name = '{table}' + ORDER BY ordinal_position + """ + + result = bigquery_driver.execute(sql) + return result.data or [] + def list_available_features(self) -> "list[str]": """List available BigQuery feature flags. 
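The `schema` argument means something different per adapter: the asyncpg implementation falls back to `table_schema = 'public'` when it is omitted, asyncmy treats it as the MySQL database name, and BigQuery treats it as the dataset and queries `{schema}.INFORMATION_SCHEMA.COLUMNS`. A hedged sketch of the first two, with driver construction omitted and all table/schema names as placeholders:

```python
from typing import Any


async def show_schema_scoping(pg_driver: Any, mysql_driver: Any) -> None:
    """Illustrate per-dialect schema scoping for get_columns() (sketch)."""
    # asyncpg: schema=None falls back to the 'public' schema.
    public_cols = await pg_driver.data_dictionary.get_columns(pg_driver, "users")

    # asyncmy: schema is the MySQL database (schema) name.
    scoped_cols = await mysql_driver.data_dictionary.get_columns(mysql_driver, "users", schema="myapp")

    print(len(public_cols), len(scoped_cols))
```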
diff --git a/sqlspec/adapters/duckdb/data_dictionary.py b/sqlspec/adapters/duckdb/data_dictionary.py
index 3b9fdc8c..c4c43957 100644
--- a/sqlspec/adapters/duckdb/data_dictionary.py
+++ b/sqlspec/adapters/duckdb/data_dictionary.py
@@ -1,7 +1,7 @@
 """DuckDB-specific data dictionary for metadata queries."""
 
 import re
-from typing import TYPE_CHECKING, cast
+from typing import TYPE_CHECKING, Any, cast
 
 from sqlspec.driver import SyncDataDictionaryBase, SyncDriverAdapterBase, VersionInfo
 from sqlspec.utils.logging import get_logger
@@ -104,6 +104,43 @@ def get_optimal_type(self, driver: SyncDriverAdapterBase, type_category: str) ->
         }
         return type_map.get(type_category, "VARCHAR")
 
+    def get_columns(
+        self, driver: SyncDriverAdapterBase, table: str, schema: "str | None" = None
+    ) -> "list[dict[str, Any]]":
+        """Get column information for a table using information_schema.
+
+        Args:
+            driver: DuckDB driver instance
+            table: Table name to query columns for
+            schema: Schema name (None for default)
+
+        Returns:
+            List of column metadata dictionaries with keys:
+            - column_name: Name of the column
+            - data_type: DuckDB data type
+            - is_nullable: Whether column allows NULL (YES/NO)
+            - column_default: Default value if any
+        """
+        duckdb_driver = cast("DuckDBDriver", driver)
+
+        if schema:
+            sql = f"""
+                SELECT column_name, data_type, is_nullable, column_default
+                FROM information_schema.columns
+                WHERE table_name = '{table}' AND table_schema = '{schema}'
+                ORDER BY ordinal_position
+            """
+        else:
+            sql = f"""
+                SELECT column_name, data_type, is_nullable, column_default
+                FROM information_schema.columns
+                WHERE table_name = '{table}'
+                ORDER BY ordinal_position
+            """
+
+        result = duckdb_driver.execute(sql)
+        return result.data or []
+
     def list_available_features(self) -> "list[str]":
         """List available DuckDB feature flags.
 
diff --git a/sqlspec/adapters/oracledb/data_dictionary.py b/sqlspec/adapters/oracledb/data_dictionary.py
index 30ca9e38..9c708c73 100644
--- a/sqlspec/adapters/oracledb/data_dictionary.py
+++ b/sqlspec/adapters/oracledb/data_dictionary.py
@@ -3,7 +3,7 @@
 
 import re
 from contextlib import suppress
-from typing import TYPE_CHECKING, cast
+from typing import TYPE_CHECKING, Any, cast
 
 from sqlspec.driver import (
     AsyncDataDictionaryBase,
@@ -106,6 +106,26 @@ class OracleDataDictionaryMixin:
 
     __slots__ = ()
 
+    def _get_columns_sql(self, table: str, schema: "str | None" = None) -> str:
+        """Get SQL to query column metadata from Oracle data dictionary.
+
+        Uses USER_TAB_COLUMNS, which returns column names in UPPERCASE.
+
+        Args:
+            table: Table name to query columns for
+            schema: Schema name (unused for USER_TAB_COLUMNS)
+
+        Returns:
+            SQL string for Oracle's USER_TAB_COLUMNS query
+        """
+        _ = schema
+        return f"""
+            SELECT column_name, data_type, data_length, nullable
+            FROM user_tab_columns
+            WHERE table_name = '{table.upper()}'
+            ORDER BY column_id
+        """
+
     def _get_oracle_version(self, driver: "OracleAsyncDriver | OracleSyncDriver") -> "OracleVersionInfo | None":
         """Get Oracle database version information.
 
@@ -272,6 +292,28 @@ def get_optimal_type(self, driver: SyncDriverAdapterBase, type_category: str) ->
         }
         return type_map.get(type_category, "VARCHAR2(255)")
 
+    def get_columns(
+        self, driver: SyncDriverAdapterBase, table: str, schema: "str | None" = None
+    ) -> "list[dict[str, Any]]":
+        """Get column information for a table from Oracle data dictionary.
+ + Args: + driver: Database driver instance + table: Table name to query columns for + schema: Schema name (ignored for USER_TAB_COLUMNS) + + Returns: + List of column metadata dictionaries with keys: + - COLUMN_NAME: Name of the column (UPPERCASE in Oracle) + - DATA_TYPE: Oracle data type + - DATA_LENGTH: Maximum length (for character types) + - NULLABLE: 'Y' or 'N' + """ + + oracle_driver = cast("OracleSyncDriver", driver) + result = oracle_driver.execute(self._get_columns_sql(table, schema)) + return result.get_data() + def list_available_features(self) -> "list[str]": """List available Oracle feature flags. @@ -421,6 +463,28 @@ async def get_optimal_type(self, driver: AsyncDriverAdapterBase, type_category: type_map = {"uuid": "RAW(16)", "boolean": "NUMBER(1)", "timestamp": "TIMESTAMP", "text": "CLOB", "blob": "BLOB"} return type_map.get(type_category, "VARCHAR2(255)") + async def get_columns( + self, driver: AsyncDriverAdapterBase, table: str, schema: "str | None" = None + ) -> "list[dict[str, Any]]": + """Get column information for a table from Oracle data dictionary. + + Args: + driver: Async database driver instance + table: Table name to query columns for + schema: Schema name (ignored for USER_TAB_COLUMNS) + + Returns: + List of column metadata dictionaries with keys: + - COLUMN_NAME: Name of the column (UPPERCASE in Oracle) + - DATA_TYPE: Oracle data type + - DATA_LENGTH: Maximum length (for character types) + - NULLABLE: 'Y' or 'N' + """ + + oracle_driver = cast("OracleAsyncDriver", driver) + result = await oracle_driver.execute(self._get_columns_sql(table, schema)) + return result.get_data() + def list_available_features(self) -> "list[str]": """List available Oracle feature flags. diff --git a/sqlspec/adapters/oracledb/migrations.py b/sqlspec/adapters/oracledb/migrations.py index c15d6fab..2119c15e 100644 --- a/sqlspec/adapters/oracledb/migrations.py +++ b/sqlspec/adapters/oracledb/migrations.py @@ -5,11 +5,14 @@ """ import getpass -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any -from sqlspec.builder import CreateTable, sql +from rich.console import Console + +from sqlspec.builder import CreateTable, Select, sql from sqlspec.migrations.base import BaseMigrationTracker from sqlspec.utils.logging import get_logger +from sqlspec.utils.version import parse_version if TYPE_CHECKING: from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase @@ -17,10 +20,20 @@ __all__ = ("OracleAsyncMigrationTracker", "OracleSyncMigrationTracker") logger = get_logger("migrations.oracle") +console = Console() class OracleMigrationTrackerMixin: - """Mixin providing Oracle-specific migration table creation.""" + """Mixin providing Oracle-specific migration table creation and querying. + + Oracle has unique identifier handling rules: + - Unquoted identifiers are case-insensitive and stored as UPPERCASE + - Quoted identifiers are case-sensitive and stored exactly as written + + This mixin overrides SQL builder methods to add quoted identifiers for + all column references, ensuring they match the lowercase column names + created by the migration table. 
+ """ __slots__ = () @@ -40,6 +53,8 @@ def _get_create_table_sql(self) -> CreateTable: return ( sql.create_table(self.version_table) .column("version_num", "VARCHAR2(32)", primary_key=True) + .column("version_type", "VARCHAR2(16)") + .column("execution_sequence", "INTEGER") .column("description", "VARCHAR2(2000)") .column("applied_at", "TIMESTAMP", default="CURRENT_TIMESTAMP") .column("execution_time_ms", "INTEGER") @@ -47,16 +62,152 @@ def _get_create_table_sql(self) -> CreateTable: .column("applied_by", "VARCHAR2(255)") ) + def _get_current_version_sql(self) -> Select: + """Get Oracle-specific SQL for retrieving current version. + + Uses uppercase column names with lowercase aliases to match Python expectations. + Oracle stores unquoted identifiers as UPPERCASE, so we query UPPERCASE columns + and alias them as quoted "lowercase" for result consistency. + + Returns: + SQL builder object for version query. + """ + return ( + sql.select('VERSION_NUM AS "version_num"') + .from_(self.version_table) + .order_by("EXECUTION_SEQUENCE DESC") + .limit(1) + ) + + def _get_applied_migrations_sql(self) -> Select: + """Get Oracle-specific SQL for retrieving all applied migrations. + + Uses uppercase column names with lowercase aliases to match Python expectations. + Oracle stores unquoted identifiers as UPPERCASE, so we query UPPERCASE columns + and alias them as quoted "lowercase" for result consistency. + + Returns: + SQL builder object for migrations query. + """ + return ( + sql.select( + 'VERSION_NUM AS "version_num"', + 'VERSION_TYPE AS "version_type"', + 'EXECUTION_SEQUENCE AS "execution_sequence"', + 'DESCRIPTION AS "description"', + 'APPLIED_AT AS "applied_at"', + 'EXECUTION_TIME_MS AS "execution_time_ms"', + 'CHECKSUM AS "checksum"', + 'APPLIED_BY AS "applied_by"', + ) + .from_(self.version_table) + .order_by("EXECUTION_SEQUENCE") + ) + + def _get_next_execution_sequence_sql(self) -> Select: + """Get Oracle-specific SQL for retrieving next execution sequence. + + Uses uppercase column names with lowercase alias to match Python expectations. + Oracle stores unquoted identifiers as UPPERCASE, so we query UPPERCASE columns + and alias them as quoted "lowercase" for result consistency. + + Returns: + SQL builder object for sequence query. + """ + return sql.select('COALESCE(MAX(EXECUTION_SEQUENCE), 0) + 1 AS "next_seq"').from_(self.version_table) + + def _get_existing_columns_sql(self) -> str: + """Get SQL to query existing columns in the tracking table. + + Returns: + Raw SQL string for Oracle's USER_TAB_COLUMNS query. + """ + return f""" + SELECT column_name + FROM user_tab_columns + WHERE table_name = '{self.version_table.upper()}' + """ + + def _detect_missing_columns(self, existing_columns: "set[str]") -> "set[str]": + """Detect which columns are missing from the current schema. + + Args: + existing_columns: Set of existing column names (uppercase). + + Returns: + Set of missing column names (lowercase). + """ + target_create = self._get_create_table_sql() + target_columns = {col.name.lower() for col in target_create.columns} + existing_lower = {col.lower() for col in existing_columns} + return target_columns - existing_lower + class OracleSyncMigrationTracker(OracleMigrationTrackerMixin, BaseMigrationTracker["SyncDriverAdapterBase"]): """Oracle-specific sync migration tracker.""" __slots__ = () + def _migrate_schema_if_needed(self, driver: "SyncDriverAdapterBase") -> None: + """Check for and add any missing columns to the tracking table. 
+ + Uses the driver's data dictionary to query existing columns from Oracle's + USER_TAB_COLUMNS metadata table. + + Args: + driver: The database driver to use. + """ + try: + columns_data = driver.data_dictionary.get_columns(driver, self.version_table) + existing_columns = {row["COLUMN_NAME"] for row in columns_data} + missing_columns = self._detect_missing_columns(existing_columns) + + if not missing_columns: + logger.debug("Migration tracking table schema is up-to-date") + return + + console.print( + f"[cyan]Migrating tracking table schema, adding columns: {', '.join(sorted(missing_columns))}[/]" + ) + + for col_name in sorted(missing_columns): + self._add_column(driver, col_name) + + driver.commit() + console.print("[green]Migration tracking table schema updated successfully[/]") + + except Exception as e: + logger.warning("Could not check or migrate tracking table schema: %s", e) + + def _add_column(self, driver: "SyncDriverAdapterBase", column_name: str) -> None: + """Add a single column to the tracking table. + + Args: + driver: The database driver to use. + column_name: Name of the column to add (lowercase). + """ + target_create = self._get_create_table_sql() + column_def = next((col for col in target_create.columns if col.name.lower() == column_name), None) + + if not column_def: + return + + default_clause = f" DEFAULT {column_def.default}" if column_def.default else "" + not_null_clause = " NOT NULL" if column_def.not_null else "" + + alter_sql = f""" + ALTER TABLE {self.version_table} + ADD {column_def.name} {column_def.dtype}{default_clause}{not_null_clause} + """ + + driver.execute(alter_sql) + logger.debug("Added column %s to tracking table", column_name) + def ensure_tracking_table(self, driver: "SyncDriverAdapterBase") -> None: """Create the migration tracking table if it doesn't exist. Uses a PL/SQL block to make the operation atomic and prevent race conditions. + Also checks for and adds missing columns to support schema migrations. Args: driver: The database driver to use. @@ -66,6 +217,8 @@ def ensure_tracking_table(self, driver: "SyncDriverAdapterBase") -> None: EXECUTE IMMEDIATE ' CREATE TABLE {self.version_table} ( version_num VARCHAR2(32) PRIMARY KEY, + version_type VARCHAR2(16), + execution_sequence INTEGER, description VARCHAR2(2000), applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, execution_time_ms INTEGER, @@ -84,6 +237,8 @@ def ensure_tracking_table(self, driver: "SyncDriverAdapterBase") -> None: driver.execute_script(create_script) driver.commit() + self._migrate_schema_if_needed(driver) + def get_current_version(self, driver: "SyncDriverAdapterBase") -> "str | None": """Get the latest applied migration version. @@ -94,7 +249,8 @@ def get_current_version(self, driver: "SyncDriverAdapterBase") -> "str | None": The current migration version or None if no migrations applied. """ result = driver.execute(self._get_current_version_sql()) - return result.data[0]["VERSION_NUM"] if result.data else None + data = result.get_data() + return data[0]["version_num"] if data else None def get_applied_migrations(self, driver: "SyncDriverAdapterBase") -> "list[dict[str, Any]]": """Get all applied migrations in order. @@ -103,15 +259,10 @@ def get_applied_migrations(self, driver: "SyncDriverAdapterBase") -> "list[dict[ driver: The database driver to use. Returns: - List of migration records as dictionaries. + List of migration records as dictionaries with lowercase keys. 
""" result = driver.execute(self._get_applied_migrations_sql()) - if not result.data: - return [] - - normalized_data = [{key.lower(): value for key, value in row.items()} for row in result.data] - - return cast("list[dict[str, Any]]", normalized_data) + return result.get_data() def record_migration( self, driver: "SyncDriverAdapterBase", version: str, description: str, execution_time_ms: int, checksum: str @@ -125,10 +276,17 @@ def record_migration( execution_time_ms: Execution time in milliseconds. checksum: MD5 checksum of the migration content. """ - applied_by = getpass.getuser() + parsed_version = parse_version(version) + version_type = parsed_version.type.value + + next_seq_result = driver.execute(self._get_next_execution_sequence_sql()) + seq_data = next_seq_result.get_data() + execution_sequence = seq_data[0]["next_seq"] if seq_data else 1 - record_sql = self._get_record_migration_sql(version, description, execution_time_ms, checksum, applied_by) + record_sql = self._get_record_migration_sql( + version, version_type, execution_sequence, description, execution_time_ms, checksum, applied_by + ) driver.execute(record_sql) driver.commit() @@ -143,16 +301,107 @@ def remove_migration(self, driver: "SyncDriverAdapterBase", version: str) -> Non driver.execute(remove_sql) driver.commit() + def update_version_record(self, driver: "SyncDriverAdapterBase", old_version: str, new_version: str) -> None: + """Update migration version record from timestamp to sequential. + + Updates version_num and version_type while preserving execution_sequence, + applied_at, and other tracking metadata. Used during fix command. + + Idempotent: If the version is already updated, logs and continues without error. + This allows fix command to be safely re-run after pulling changes. + + Args: + driver: The database driver to use. + old_version: Current timestamp version string. + new_version: New sequential version string. + + Raises: + ValueError: If neither old_version nor new_version found in database. + """ + parsed_new_version = parse_version(new_version) + new_version_type = parsed_new_version.type.value + + result = driver.execute(self._get_update_version_sql(old_version, new_version, new_version_type)) + + if result.rows_affected == 0: + check_result = driver.execute(self._get_applied_migrations_sql()) + applied_versions = {row["version_num"] for row in check_result.data} if check_result.data else set() + + if new_version in applied_versions: + logger.debug("Version already updated: %s -> %s", old_version, new_version) + return + + msg = f"Migration {old_version} not found in database for update to {new_version}" + raise ValueError(msg) + + driver.commit() + class OracleAsyncMigrationTracker(OracleMigrationTrackerMixin, BaseMigrationTracker["AsyncDriverAdapterBase"]): """Oracle-specific async migration tracker.""" __slots__ = () + async def _migrate_schema_if_needed(self, driver: "AsyncDriverAdapterBase") -> None: + """Check for and add any missing columns to the tracking table. + + Uses the driver's data dictionary to query existing columns from Oracle's + USER_TAB_COLUMNS metadata table. + + Args: + driver: The database driver to use. 
+ """ + try: + columns_data = await driver.data_dictionary.get_columns(driver, self.version_table) + existing_columns = {row["COLUMN_NAME"] for row in columns_data} + missing_columns = self._detect_missing_columns(existing_columns) + + if not missing_columns: + logger.debug("Migration tracking table schema is up-to-date") + return + + console.print( + f"[cyan]Migrating tracking table schema, adding columns: {', '.join(sorted(missing_columns))}[/]" + ) + + for col_name in sorted(missing_columns): + await self._add_column(driver, col_name) + + await driver.commit() + console.print("[green]Migration tracking table schema updated successfully[/]") + + except Exception as e: + logger.warning("Could not check or migrate tracking table schema: %s", e) + + async def _add_column(self, driver: "AsyncDriverAdapterBase", column_name: str) -> None: + """Add a single column to the tracking table. + + Args: + driver: The database driver to use. + column_name: Name of the column to add (lowercase). + """ + target_create = self._get_create_table_sql() + column_def = next((col for col in target_create.columns if col.name.lower() == column_name), None) + + if not column_def: + return + + default_clause = f" DEFAULT {column_def.default}" if column_def.default else "" + not_null_clause = " NOT NULL" if column_def.not_null else "" + + alter_sql = f""" + ALTER TABLE {self.version_table} + ADD {column_def.name} {column_def.dtype}{default_clause}{not_null_clause} + """ + + await driver.execute(alter_sql) + logger.debug("Added column %s to tracking table", column_name) + async def ensure_tracking_table(self, driver: "AsyncDriverAdapterBase") -> None: """Create the migration tracking table if it doesn't exist. Uses a PL/SQL block to make the operation atomic and prevent race conditions. + Also checks for and adds missing columns to support schema migrations. Args: driver: The database driver to use. @@ -162,6 +411,8 @@ async def ensure_tracking_table(self, driver: "AsyncDriverAdapterBase") -> None: EXECUTE IMMEDIATE ' CREATE TABLE {self.version_table} ( version_num VARCHAR2(32) PRIMARY KEY, + version_type VARCHAR2(16), + execution_sequence INTEGER, description VARCHAR2(2000), applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, execution_time_ms INTEGER, @@ -180,6 +431,8 @@ async def ensure_tracking_table(self, driver: "AsyncDriverAdapterBase") -> None: await driver.execute_script(create_script) await driver.commit() + await self._migrate_schema_if_needed(driver) + async def get_current_version(self, driver: "AsyncDriverAdapterBase") -> "str | None": """Get the latest applied migration version. @@ -190,7 +443,8 @@ async def get_current_version(self, driver: "AsyncDriverAdapterBase") -> "str | The current migration version or None if no migrations applied. """ result = await driver.execute(self._get_current_version_sql()) - return result.data[0]["VERSION_NUM"] if result.data else None + data = result.get_data() + return data[0]["version_num"] if data else None async def get_applied_migrations(self, driver: "AsyncDriverAdapterBase") -> "list[dict[str, Any]]": """Get all applied migrations in order. @@ -199,15 +453,10 @@ async def get_applied_migrations(self, driver: "AsyncDriverAdapterBase") -> "lis driver: The database driver to use. Returns: - List of migration records as dictionaries. + List of migration records as dictionaries with lowercase keys. 
""" result = await driver.execute(self._get_applied_migrations_sql()) - if not result.data: - return [] - - normalized_data = [{key.lower(): value for key, value in row.items()} for row in result.data] - - return cast("list[dict[str, Any]]", normalized_data) + return result.get_data() async def record_migration( self, driver: "AsyncDriverAdapterBase", version: str, description: str, execution_time_ms: int, checksum: str @@ -223,8 +472,16 @@ async def record_migration( """ applied_by = getpass.getuser() + parsed_version = parse_version(version) + version_type = parsed_version.type.value + + next_seq_result = await driver.execute(self._get_next_execution_sequence_sql()) + seq_data = next_seq_result.get_data() + execution_sequence = seq_data[0]["next_seq"] if seq_data else 1 - record_sql = self._get_record_migration_sql(version, description, execution_time_ms, checksum, applied_by) + record_sql = self._get_record_migration_sql( + version, version_type, execution_sequence, description, execution_time_ms, checksum, applied_by + ) await driver.execute(record_sql) await driver.commit() @@ -238,3 +495,38 @@ async def remove_migration(self, driver: "AsyncDriverAdapterBase", version: str) remove_sql = self._get_remove_migration_sql(version) await driver.execute(remove_sql) await driver.commit() + + async def update_version_record(self, driver: "AsyncDriverAdapterBase", old_version: str, new_version: str) -> None: + """Update migration version record from timestamp to sequential. + + Updates version_num and version_type while preserving execution_sequence, + applied_at, and other tracking metadata. Used during fix command. + + Idempotent: If the version is already updated, logs and continues without error. + This allows fix command to be safely re-run after pulling changes. + + Args: + driver: The database driver to use. + old_version: Current timestamp version string. + new_version: New sequential version string. + + Raises: + ValueError: If neither old_version nor new_version found in database. 
+ """ + parsed_new_version = parse_version(new_version) + new_version_type = parsed_new_version.type.value + + result = await driver.execute(self._get_update_version_sql(old_version, new_version, new_version_type)) + + if result.rows_affected == 0: + check_result = await driver.execute(self._get_applied_migrations_sql()) + applied_versions = {row["version_num"] for row in check_result.data} if check_result.data else set() + + if new_version in applied_versions: + logger.debug("Version already updated: %s -> %s", old_version, new_version) + return + + msg = f"Migration {old_version} not found in database for update to {new_version}" + raise ValueError(msg) + + await driver.commit() diff --git a/sqlspec/adapters/psqlpy/data_dictionary.py b/sqlspec/adapters/psqlpy/data_dictionary.py index daf9ab3d..f49bd408 100644 --- a/sqlspec/adapters/psqlpy/data_dictionary.py +++ b/sqlspec/adapters/psqlpy/data_dictionary.py @@ -1,7 +1,7 @@ """PostgreSQL-specific data dictionary for metadata queries via psqlpy.""" import re -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Any, cast from sqlspec.driver import AsyncDataDictionaryBase, AsyncDriverAdapterBase, VersionInfo from sqlspec.utils.logging import get_logger @@ -113,6 +113,43 @@ async def get_optimal_type(self, driver: AsyncDriverAdapterBase, type_category: } return type_map.get(type_category, "TEXT") + async def get_columns( + self, driver: AsyncDriverAdapterBase, table: str, schema: "str | None" = None + ) -> "list[dict[str, Any]]": + """Get column information for a table using information_schema. + + Args: + driver: Psqlpy async driver instance + table: Table name to query columns for + schema: Schema name (None for default 'public') + + Returns: + List of column metadata dictionaries with keys: + - column_name: Name of the column + - data_type: PostgreSQL data type + - is_nullable: Whether column allows NULL (YES/NO) + - column_default: Default value if any + """ + psqlpy_driver = cast("PsqlpyDriver", driver) + + if schema: + sql = f""" + SELECT column_name, data_type, is_nullable, column_default + FROM information_schema.columns + WHERE table_name = '{table}' AND table_schema = '{schema}' + ORDER BY ordinal_position + """ + else: + sql = f""" + SELECT column_name, data_type, is_nullable, column_default + FROM information_schema.columns + WHERE table_name = '{table}' AND table_schema = 'public' + ORDER BY ordinal_position + """ + + result = await psqlpy_driver.execute(sql) + return result.data or [] + def list_available_features(self) -> "list[str]": """List available PostgreSQL feature flags. diff --git a/sqlspec/adapters/psycopg/data_dictionary.py b/sqlspec/adapters/psycopg/data_dictionary.py index 94171a59..5a78dc40 100644 --- a/sqlspec/adapters/psycopg/data_dictionary.py +++ b/sqlspec/adapters/psycopg/data_dictionary.py @@ -1,7 +1,7 @@ """PostgreSQL-specific data dictionary for metadata queries via psycopg.""" import re -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Any, cast from sqlspec.driver import ( AsyncDataDictionaryBase, @@ -119,6 +119,43 @@ def get_optimal_type(self, driver: SyncDriverAdapterBase, type_category: str) -> } return type_map.get(type_category, "TEXT") + def get_columns( + self, driver: SyncDriverAdapterBase, table: str, schema: "str | None" = None + ) -> "list[dict[str, Any]]": + """Get column information for a table using information_schema. 
+ + Args: + driver: Psycopg sync driver instance + table: Table name to query columns for + schema: Schema name (None for default 'public') + + Returns: + List of column metadata dictionaries with keys: + - column_name: Name of the column + - data_type: PostgreSQL data type + - is_nullable: Whether column allows NULL (YES/NO) + - column_default: Default value if any + """ + psycopg_driver = cast("PsycopgSyncDriver", driver) + + if schema: + sql = f""" + SELECT column_name, data_type, is_nullable, column_default + FROM information_schema.columns + WHERE table_name = '{table}' AND table_schema = '{schema}' + ORDER BY ordinal_position + """ + else: + sql = f""" + SELECT column_name, data_type, is_nullable, column_default + FROM information_schema.columns + WHERE table_name = '{table}' AND table_schema = 'public' + ORDER BY ordinal_position + """ + + result = psycopg_driver.execute(sql) + return result.data or [] + def list_available_features(self) -> "list[str]": """List available PostgreSQL feature flags. @@ -235,6 +272,43 @@ async def get_optimal_type(self, driver: AsyncDriverAdapterBase, type_category: } return type_map.get(type_category, "TEXT") + async def get_columns( + self, driver: AsyncDriverAdapterBase, table: str, schema: "str | None" = None + ) -> "list[dict[str, Any]]": + """Get column information for a table using information_schema. + + Args: + driver: Psycopg async driver instance + table: Table name to query columns for + schema: Schema name (None for default 'public') + + Returns: + List of column metadata dictionaries with keys: + - column_name: Name of the column + - data_type: PostgreSQL data type + - is_nullable: Whether column allows NULL (YES/NO) + - column_default: Default value if any + """ + psycopg_driver = cast("PsycopgAsyncDriver", driver) + + if schema: + sql = f""" + SELECT column_name, data_type, is_nullable, column_default + FROM information_schema.columns + WHERE table_name = '{table}' AND table_schema = '{schema}' + ORDER BY ordinal_position + """ + else: + sql = f""" + SELECT column_name, data_type, is_nullable, column_default + FROM information_schema.columns + WHERE table_name = '{table}' AND table_schema = 'public' + ORDER BY ordinal_position + """ + + result = await psycopg_driver.execute(sql) + return result.data or [] + def list_available_features(self) -> "list[str]": """List available PostgreSQL feature flags. diff --git a/sqlspec/adapters/sqlite/data_dictionary.py b/sqlspec/adapters/sqlite/data_dictionary.py index 3e5e7066..4377a880 100644 --- a/sqlspec/adapters/sqlite/data_dictionary.py +++ b/sqlspec/adapters/sqlite/data_dictionary.py @@ -1,7 +1,7 @@ """SQLite-specific data dictionary for metadata queries.""" import re -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Any, cast from sqlspec.driver import SyncDataDictionaryBase, SyncDriverAdapterBase, VersionInfo from sqlspec.utils.logging import get_logger @@ -99,6 +99,36 @@ def get_optimal_type(self, driver: SyncDriverAdapterBase, type_category: str) -> type_map = {"uuid": "TEXT", "boolean": "INTEGER", "timestamp": "TIMESTAMP", "text": "TEXT", "blob": "BLOB"} return type_map.get(type_category, "TEXT") + def get_columns( + self, driver: SyncDriverAdapterBase, table: str, schema: "str | None" = None + ) -> "list[dict[str, Any]]": + """Get column information for a table using SQLite PRAGMA. 
+ + Args: + driver: SQLite driver instance + table: Table name to query columns for + schema: Schema name (unused in SQLite) + + Returns: + List of column metadata dictionaries with keys: + - column_name: Name of the column + - data_type: SQLite data type + - nullable: Whether column allows NULL + - default_value: Default value if any + """ + sqlite_driver = cast("SqliteDriver", driver) + result = sqlite_driver.execute(f"PRAGMA table_info({table})") + + return [ + { + "column_name": row["name"] if isinstance(row, dict) else row[1], + "data_type": row["type"] if isinstance(row, dict) else row[2], + "nullable": not (row["notnull"] if isinstance(row, dict) else row[3]), + "default_value": row["dflt_value"] if isinstance(row, dict) else row[4], + } + for row in result.data or [] + ] + def list_available_features(self) -> "list[str]": """List available SQLite feature flags. diff --git a/sqlspec/builder/_ddl.py b/sqlspec/builder/_ddl.py index 66074035..47fd77d6 100644 --- a/sqlspec/builder/_ddl.py +++ b/sqlspec/builder/_ddl.py @@ -349,6 +349,15 @@ def partition_by(self, partition_spec: str) -> "Self": self._partition_by = partition_spec return self + @property + def columns(self) -> "list[ColumnDefinition]": + """Get the list of column definitions for this table. + + Returns: + List of ColumnDefinition objects. + """ + return self._columns + def column( self, name: str, diff --git a/sqlspec/cli.py b/sqlspec/cli.py index 9316b78b..2346d82d 100644 --- a/sqlspec/cli.py +++ b/sqlspec/cli.py @@ -124,6 +124,12 @@ def add_migration_commands(database_group: "Group | None" = None) -> "Group": default="auto", help="Force execution mode (auto-detects by default)", ) + no_auto_sync_option = click.option( + "--no-auto-sync", + is_flag=True, + default=False, + help="Disable automatic version reconciliation when migrations have been renamed", + ) def get_config_by_bind_key( ctx: "click.Context", bind_key: str | None @@ -378,6 +384,7 @@ async def _downgrade_database() -> None: @exclude_option @dry_run_option @execution_mode_option + @no_auto_sync_option @click.argument("revision", type=str, default="head") def upgrade_database( # pyright: ignore[reportUnusedFunction] bind_key: str | None, @@ -387,6 +394,7 @@ def upgrade_database( # pyright: ignore[reportUnusedFunction] exclude: "tuple[str, ...]", dry_run: bool, execution_mode: str, + no_auto_sync: bool, ) -> None: """Upgrade the database to the latest revision.""" from rich.prompt import Confirm @@ -424,7 +432,9 @@ async def _upgrade_database() -> None: migration_commands: SyncMigrationCommands[Any] | AsyncMigrationCommands[Any] = ( create_migration_commands(config=config) ) - await maybe_await(migration_commands.upgrade(revision=revision, dry_run=dry_run)) + await maybe_await( + migration_commands.upgrade(revision=revision, auto_sync=not no_auto_sync, dry_run=dry_run) + ) console.print(f"[green]✓ Successfully upgraded: {config_name}[/]") except Exception as e: console.print(f"[red]✗ Failed to upgrade {config_name}: {e}[/]") @@ -441,7 +451,9 @@ async def _upgrade_database() -> None: if input_confirmed: sqlspec_config = get_config_by_bind_key(cast("click.Context", ctx), bind_key) migration_commands = create_migration_commands(config=sqlspec_config) - await maybe_await(migration_commands.upgrade(revision=revision, dry_run=dry_run)) + await maybe_await( + migration_commands.upgrade(revision=revision, auto_sync=not no_auto_sync, dry_run=dry_run) + ) run_(_upgrade_database)() @@ -533,6 +545,28 @@ async def _create_revision() -> None: run_(_create_revision)() + 
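The `--no-auto-sync` flag added above and the `fix` command defined next are two halves of the hybrid-versioning workflow: `upgrade` reconciles renamed versions automatically by default, while `fix` performs the timestamp-to-sequential conversion explicitly. A sketch of driving the same operations programmatically; `create_migration_commands` and the `fix()` keyword arguments are taken from this patch, while the wrapper function and the sync-config assumption are illustrative:

```python
from typing import Any

from sqlspec.migrations.commands import create_migration_commands


def preview_and_apply_fix(config: Any) -> None:
    """Mirror `sqlspec fix --dry-run` followed by `sqlspec fix --yes` (sketch).

    Assumes a sync config; for async configs the CLI awaits the same calls
    via maybe_await().
    """
    commands = create_migration_commands(config=config)
    commands.fix(dry_run=True, update_database=True, yes=False)  # preview only
    commands.fix(dry_run=False, update_database=True, yes=True)  # apply without prompting
```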
@database_group.command(name="fix", help="Convert timestamp migrations to sequential format.") + @bind_key_option + @dry_run_option + @click.option("--yes", is_flag=True, help="Skip confirmation prompt") + @click.option("--no-database", is_flag=True, help="Skip database record updates") + def fix_migrations( # pyright: ignore[reportUnusedFunction] + bind_key: str | None, dry_run: bool, yes: bool, no_database: bool + ) -> None: + """Convert timestamp migrations to sequential format.""" + from sqlspec.migrations.commands import create_migration_commands + from sqlspec.utils.sync_tools import run_ + + ctx = click.get_current_context() + + async def _fix_migrations() -> None: + console.rule("[yellow]Migration Fix Command[/]", align="left") + sqlspec_config = get_config_by_bind_key(cast("click.Context", ctx), bind_key) + migration_commands = create_migration_commands(config=sqlspec_config) + await maybe_await(migration_commands.fix(dry_run=dry_run, update_database=not no_database, yes=yes)) + + run_(_fix_migrations)() + @database_group.command(name="show-config", help="Show all configurations with migrations enabled.") @bind_key_option def show_config(bind_key: str | None = None) -> None: # pyright: ignore[reportUnusedFunction] diff --git a/sqlspec/config.py b/sqlspec/config.py index 3148d61a..87fb2430 100644 --- a/sqlspec/config.py +++ b/sqlspec/config.py @@ -20,6 +20,7 @@ __all__ = ( + "ADKConfig", "AsyncConfigT", "AsyncDatabaseConfig", "ConfigT", @@ -84,6 +85,15 @@ class MigrationConfig(TypedDict): enabled: NotRequired[bool] """Whether this configuration should be included in CLI operations. Defaults to True.""" + auto_sync: NotRequired[bool] + """Enable automatic version reconciliation during upgrade. When enabled (default), SQLSpec automatically updates database tracking when migrations are renamed from timestamp to sequential format. Defaults to True.""" + + strict_ordering: NotRequired[bool] + """Enforce strict migration ordering. When enabled, prevents out-of-order migrations from being applied. Defaults to False.""" + + include_extensions: NotRequired["list[str]"] + """List of extension names whose migrations should be included. Extension migrations maintain separate versioning and are prefixed with 'ext_{name}_'.""" + class LitestarConfig(TypedDict): """Configuration options for Litestar SQLSpec plugin. @@ -113,6 +123,124 @@ class LitestarConfig(TypedDict): """Additional HTTP status codes that trigger rollback. Default: set()""" +class ADKConfig(TypedDict): + """Configuration options for ADK session store extension. + + All fields are optional with sensible defaults. Use in extension_config["adk"]: + + Example: + from sqlspec.adapters.asyncpg import AsyncpgConfig + + config = AsyncpgConfig( + pool_config={"dsn": "postgresql://localhost/mydb"}, + extension_config={ + "adk": { + "session_table": "my_sessions", + "events_table": "my_events", + "owner_id_column": "tenant_id INTEGER REFERENCES tenants(id)" + } + } + ) + + Notes: + This TypedDict provides type safety for extension config but is not required. + You can use plain dicts as well. + """ + + session_table: NotRequired[str] + """Name of the sessions table. Default: 'adk_sessions' + + Examples: + "agent_sessions" + "my_app_sessions" + "tenant_acme_sessions" + """ + + events_table: NotRequired[str] + """Name of the events table. 
Default: 'adk_events' + + Examples: + "agent_events" + "my_app_events" + "tenant_acme_events" + """ + + owner_id_column: NotRequired[str] + """Optional owner ID column definition to link sessions to a user, tenant, team, or other entity. + + Format: "column_name TYPE [NOT NULL] REFERENCES table(column) [options...]" + + The entire definition is passed through to DDL verbatim. We only parse + the column name (first word) for use in INSERT/SELECT statements. + + Supports: + - Foreign key constraints: REFERENCES table(column) + - Nullable or NOT NULL + - CASCADE options: ON DELETE CASCADE, ON UPDATE CASCADE + - Dialect-specific options (DEFERRABLE, ENABLE VALIDATE, etc.) + - Plain columns without FK (just extra column storage) + + Examples: + PostgreSQL with UUID FK: + "account_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE" + + MySQL with BIGINT FK: + "user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE RESTRICT" + + Oracle with NUMBER FK: + "user_id NUMBER(10) REFERENCES users(id) ENABLE VALIDATE" + + SQLite with INTEGER FK: + "tenant_id INTEGER NOT NULL REFERENCES tenants(id)" + + Nullable FK (optional relationship): + "workspace_id UUID REFERENCES workspaces(id) ON DELETE SET NULL" + + No FK (just extra column): + "organization_name VARCHAR(128) NOT NULL" + + Deferred constraint (PostgreSQL): + "user_id UUID REFERENCES users(id) DEFERRABLE INITIALLY DEFERRED" + + Notes: + - Column name (first word) is extracted for INSERT/SELECT queries + - Rest of definition is passed through to CREATE TABLE DDL + - Database validates the DDL syntax (fail-fast on errors) + - Works with all database dialects (PostgreSQL, MySQL, SQLite, Oracle, etc.) + """ + + in_memory: NotRequired[bool] + """Enable in-memory table storage (Oracle-specific). Default: False. + + When enabled, tables are created with the INMEMORY clause for Oracle Database, + which stores table data in columnar format in memory for faster query performance. + + This is an Oracle-specific feature that requires: + - Oracle Database 12.1.0.2 or higher + - Database In-Memory option license (Enterprise Edition) + - Sufficient INMEMORY_SIZE configured in the database instance + + Other database adapters ignore this setting. 
+ + Examples: + Oracle with in-memory enabled: + config = OracleAsyncConfig( + pool_config={"dsn": "oracle://..."}, + extension_config={ + "adk": { + "in_memory": True + } + } + ) + + Notes: + - Improves query performance for analytics (10-100x faster) + - Tables created with INMEMORY clause + - Requires Oracle Database In-Memory option license + - Ignored by non-Oracle adapters + """ + + class DatabaseConfigProtocol(ABC, Generic[ConnectionT, PoolT, DriverT]): """Protocol defining the interface for database configurations.""" diff --git a/sqlspec/exceptions.py b/sqlspec/exceptions.py index 9736bbf4..3e907c49 100644 --- a/sqlspec/exceptions.py +++ b/sqlspec/exceptions.py @@ -11,11 +11,14 @@ "ForeignKeyViolationError", "ImproperConfigurationError", "IntegrityError", + "InvalidVersionFormatError", + "MigrationError", "MissingDependencyError", "MultipleResultsFoundError", "NotFoundError", "NotNullViolationError", "OperationalError", + "OutOfOrderMigrationError", "RepositoryError", "SQLBuilderError", "SQLConversionError", @@ -207,6 +210,26 @@ def __init__(self, name: str, path: str, original_error: "Exception") -> None: self.original_error = original_error +class MigrationError(SQLSpecError): + """Base exception for migration-related errors.""" + + +class InvalidVersionFormatError(MigrationError): + """Raised when a migration version format is invalid. + + Invalid formats include versions that don't match sequential (0001) + or timestamp (YYYYMMDDHHmmss) patterns, or timestamps with invalid dates. + """ + + +class OutOfOrderMigrationError(MigrationError): + """Raised when an out-of-order migration is detected in strict mode. + + Out-of-order migrations occur when a pending migration has a timestamp + earlier than already-applied migrations, typically from late-merging branches. + """ + + @contextmanager def wrap_exceptions( wrap_exceptions: bool = True, suppress: "type[Exception] | tuple[type[Exception], ...] | None" = None diff --git a/sqlspec/extensions/adk/__init__.py b/sqlspec/extensions/adk/__init__.py index f9988a53..f0482726 100644 --- a/sqlspec/extensions/adk/__init__.py +++ b/sqlspec/extensions/adk/__init__.py @@ -38,8 +38,8 @@ ) """ +from sqlspec.config import ADKConfig from sqlspec.extensions.adk._types import EventRecord, SessionRecord -from sqlspec.extensions.adk.config import ADKConfig from sqlspec.extensions.adk.service import SQLSpecSessionService from sqlspec.extensions.adk.store import BaseAsyncADKStore, BaseSyncADKStore diff --git a/sqlspec/extensions/adk/config.py b/sqlspec/extensions/adk/config.py deleted file mode 100644 index e961c598..00000000 --- a/sqlspec/extensions/adk/config.py +++ /dev/null @@ -1,123 +0,0 @@ -"""Configuration types for ADK session store extension.""" - -from typing_extensions import NotRequired, TypedDict - -__all__ = ("ADKConfig",) - - -class ADKConfig(TypedDict): - """Configuration options for ADK session store extension. - - All fields are optional with sensible defaults. Use in extension_config["adk"]: - - Example: - from sqlspec.adapters.asyncpg import AsyncpgConfig - - config = AsyncpgConfig( - pool_config={"dsn": "postgresql://localhost/mydb"}, - extension_config={ - "adk": { - "session_table": "my_sessions", - "events_table": "my_events", - "owner_id_column": "tenant_id INTEGER REFERENCES tenants(id)" - } - } - ) - - Notes: - This TypedDict provides type safety for extension config but is not required. - You can use plain dicts as well. - """ - - session_table: NotRequired[str] - """Name of the sessions table. 
Default: 'adk_sessions' - - Examples: - "agent_sessions" - "my_app_sessions" - "tenant_acme_sessions" - """ - - events_table: NotRequired[str] - """Name of the events table. Default: 'adk_events' - - Examples: - "agent_events" - "my_app_events" - "tenant_acme_events" - """ - - owner_id_column: NotRequired[str] - """Optional owner ID column definition to link sessions to a user, tenant, team, or other entity. - - Format: "column_name TYPE [NOT NULL] REFERENCES table(column) [options...]" - - The entire definition is passed through to DDL verbatim. We only parse - the column name (first word) for use in INSERT/SELECT statements. - - Supports: - - Foreign key constraints: REFERENCES table(column) - - Nullable or NOT NULL - - CASCADE options: ON DELETE CASCADE, ON UPDATE CASCADE - - Dialect-specific options (DEFERRABLE, ENABLE VALIDATE, etc.) - - Plain columns without FK (just extra column storage) - - Examples: - PostgreSQL with UUID FK: - "account_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE" - - MySQL with BIGINT FK: - "user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE RESTRICT" - - Oracle with NUMBER FK: - "user_id NUMBER(10) REFERENCES users(id) ENABLE VALIDATE" - - SQLite with INTEGER FK: - "tenant_id INTEGER NOT NULL REFERENCES tenants(id)" - - Nullable FK (optional relationship): - "workspace_id UUID REFERENCES workspaces(id) ON DELETE SET NULL" - - No FK (just extra column): - "organization_name VARCHAR(128) NOT NULL" - - Deferred constraint (PostgreSQL): - "user_id UUID REFERENCES users(id) DEFERRABLE INITIALLY DEFERRED" - - Notes: - - Column name (first word) is extracted for INSERT/SELECT queries - - Rest of definition is passed through to CREATE TABLE DDL - - Database validates the DDL syntax (fail-fast on errors) - - Works with all database dialects (PostgreSQL, MySQL, SQLite, Oracle, etc.) - """ - - in_memory: NotRequired[bool] - """Enable in-memory table storage (Oracle-specific). Default: False. - - When enabled, tables are created with the INMEMORY clause for Oracle Database, - which stores table data in columnar format in memory for faster query performance. - - This is an Oracle-specific feature that requires: - - Oracle Database 12.1.0.2 or higher - - Database In-Memory option license (Enterprise Edition) - - Sufficient INMEMORY_SIZE configured in the database instance - - Other database adapters ignore this setting. 
- - Examples: - Oracle with in-memory enabled: - config = OracleAsyncConfig( - pool_config={"dsn": "oracle://..."}, - extension_config={ - "adk": { - "in_memory": True - } - } - ) - - Notes: - - Improves query performance for analytics (10-100x faster) - - Tables created with INMEMORY clause - - Requires Oracle Database In-Memory option license - - Ignored by non-Oracle adapters - """ diff --git a/sqlspec/loader.py b/sqlspec/loader.py index 52996c77..af0dd207 100644 --- a/sqlspec/loader.py +++ b/sqlspec/loader.py @@ -354,43 +354,20 @@ def load_sql(self, *paths: str | Path) -> None: correlation_id = CorrelationContext.get() start_time = time.perf_counter() - logger.info("Loading SQL files", extra={"file_count": len(paths), "correlation_id": correlation_id}) - - loaded_count = 0 - query_count_before = len(self._queries) - try: for path in paths: path_str = str(path) if "://" in path_str: self._load_single_file(path, None) - loaded_count += 1 else: path_obj = Path(path) if path_obj.is_dir(): - loaded_count += self._load_directory(path_obj) + self._load_directory(path_obj) elif path_obj.exists(): self._load_single_file(path_obj, None) - loaded_count += 1 elif path_obj.suffix: self._raise_file_not_found(str(path)) - duration = time.perf_counter() - start_time - new_queries = len(self._queries) - query_count_before - - logger.info( - "Loaded %d SQL files with %d new queries in %.3fms", - loaded_count, - new_queries, - duration * 1000, - extra={ - "files_loaded": loaded_count, - "new_queries": new_queries, - "duration_ms": duration * 1000, - "correlation_id": correlation_id, - }, - ) - except Exception as e: duration = time.perf_counter() - start_time logger.exception( @@ -404,34 +381,40 @@ def load_sql(self, *paths: str | Path) -> None: ) raise - def _load_directory(self, dir_path: Path) -> int: - """Load all SQL files from a directory.""" + def _load_directory(self, dir_path: Path) -> None: + """Load all SQL files from a directory. + + Args: + dir_path: Directory path to load SQL files from. + """ sql_files = list(dir_path.rglob("*.sql")) if not sql_files: - return 0 + return for file_path in sql_files: relative_path = file_path.relative_to(dir_path) namespace_parts = relative_path.parent.parts self._load_single_file(file_path, ".".join(namespace_parts) if namespace_parts else None) - return len(sql_files) - def _load_single_file(self, file_path: str | Path, namespace: str | None) -> None: + def _load_single_file(self, file_path: str | Path, namespace: str | None) -> bool: """Load a single SQL file with optional namespace. Args: file_path: Path to the SQL file. namespace: Optional namespace prefix for queries. + + Returns: + True if file was newly loaded, False if already cached. 
""" path_str = str(file_path) if path_str in self._files: - return + return False cache_config = get_cache_config() if not cache_config.compiled_cache_enabled: self._load_file_without_cache(file_path, namespace) - return + return True cache_key_str = self._generate_file_cache_key(file_path) cache = get_cache() @@ -455,7 +438,7 @@ def _load_single_file(self, file_path: str | Path, namespace: str | None) -> Non ) self._queries[namespaced_name] = statement self._query_to_file[namespaced_name] = path_str - return + return True self._load_file_without_cache(file_path, namespace) @@ -472,6 +455,8 @@ def _load_single_file(self, file_path: str | Path, namespace: str | None) -> Non cached_file_data = CachedSQLFile(sql_file=sql_file, parsed_statements=file_statements) cache.put("file", cache_key_str, cached_file_data) + return True + def _load_file_without_cache(self, file_path: str | Path, namespace: str | None) -> None: """Load a single SQL file without using cache. diff --git a/sqlspec/migrations/base.py b/sqlspec/migrations/base.py index 04286400..f06d15aa 100644 --- a/sqlspec/migrations/base.py +++ b/sqlspec/migrations/base.py @@ -4,18 +4,18 @@ """ import hashlib -import operator from abc import ABC, abstractmethod from pathlib import Path from typing import Any, Generic, TypeVar, cast -from sqlspec.builder import Delete, Insert, Select, sql +from sqlspec.builder import Delete, Insert, Select, Update, sql from sqlspec.builder._ddl import CreateTable from sqlspec.loader import SQLFileLoader from sqlspec.migrations.loaders import get_migration_loader from sqlspec.utils.logging import get_logger from sqlspec.utils.module_loader import module_to_os_path from sqlspec.utils.sync_tools import await_ +from sqlspec.utils.version import parse_version __all__ = ("BaseMigrationCommands", "BaseMigrationRunner", "BaseMigrationTracker") @@ -42,6 +42,16 @@ def __init__(self, version_table_name: str = "ddl_migrations") -> None: def _get_create_table_sql(self) -> CreateTable: """Get SQL builder for creating the tracking table. + Schema includes both legacy and new versioning columns: + - version_num: Migration version (sequential or timestamp format) + - version_type: Format indicator ('sequential' or 'timestamp') + - execution_sequence: Auto-incrementing application order + - description: Human-readable migration description + - applied_at: Timestamp when migration was applied + - execution_time_ms: Migration execution duration + - checksum: MD5 hash for content verification + - applied_by: User who applied the migration + Returns: SQL builder object for table creation. """ @@ -49,6 +59,8 @@ def _get_create_table_sql(self) -> CreateTable: sql.create_table(self.version_table) .if_not_exists() .column("version_num", "VARCHAR(32)", primary_key=True) + .column("version_type", "VARCHAR(16)") + .column("execution_sequence", "INTEGER") .column("description", "TEXT") .column("applied_at", "TIMESTAMP", default="CURRENT_TIMESTAMP", not_null=True) .column("execution_time_ms", "INTEGER") @@ -59,26 +71,49 @@ def _get_create_table_sql(self) -> CreateTable: def _get_current_version_sql(self) -> Select: """Get SQL builder for retrieving current version. + Uses execution_sequence to get the last applied migration, + which may differ from version_num order due to out-of-order migrations. + Returns: SQL builder object for version query. 
""" - return sql.select("version_num").from_(self.version_table).order_by("version_num DESC").limit(1) + return sql.select("version_num").from_(self.version_table).order_by("execution_sequence DESC").limit(1) def _get_applied_migrations_sql(self) -> Select: """Get SQL builder for retrieving all applied migrations. + Orders by execution_sequence to show migrations in application order, + which preserves the actual execution history for out-of-order migrations. + Returns: SQL builder object for migrations query. """ - return sql.select("*").from_(self.version_table).order_by("version_num") + return sql.select("*").from_(self.version_table).order_by("execution_sequence") + + def _get_next_execution_sequence_sql(self) -> Select: + """Get SQL builder for retrieving next execution sequence. + + Returns: + SQL builder object for sequence query. + """ + return sql.select("COALESCE(MAX(execution_sequence), 0) + 1 AS next_seq").from_(self.version_table) def _get_record_migration_sql( - self, version: str, description: str, execution_time_ms: int, checksum: str, applied_by: str + self, + version: str, + version_type: str, + execution_sequence: int, + description: str, + execution_time_ms: int, + checksum: str, + applied_by: str, ) -> Insert: """Get SQL builder for recording a migration. Args: version: Version number of the migration. + version_type: Version format type ('sequential' or 'timestamp'). + execution_sequence: Auto-incrementing application order. description: Description of the migration. execution_time_ms: Execution time in milliseconds. checksum: MD5 checksum of the migration content. @@ -89,8 +124,16 @@ def _get_record_migration_sql( """ return ( sql.insert(self.version_table) - .columns("version_num", "description", "execution_time_ms", "checksum", "applied_by") - .values(version, description, execution_time_ms, checksum, applied_by) + .columns( + "version_num", + "version_type", + "execution_sequence", + "description", + "execution_time_ms", + "checksum", + "applied_by", + ) + .values(version, version_type, execution_sequence, description, execution_time_ms, checksum, applied_by) ) def _get_remove_migration_sql(self, version: str) -> Delete: @@ -104,9 +147,90 @@ def _get_remove_migration_sql(self, version: str) -> Delete: """ return sql.delete().from_(self.version_table).where(sql.version_num == version) + def _get_update_version_sql(self, old_version: str, new_version: str, new_version_type: str) -> Update: + """Get SQL builder for updating version record. + + Updates version_num and version_type while preserving execution_sequence, + applied_at, and other metadata. Used during fix command to convert + timestamp versions to sequential format. + + Args: + old_version: Current version string. + new_version: New version string. + new_version_type: New version type ('sequential' or 'timestamp'). + + Returns: + SQL builder object for update. + """ + return ( + sql.update(self.version_table) + .set("version_num", new_version) + .set("version_type", new_version_type) + .where(sql.version_num == old_version) + ) + + def _get_check_column_exists_sql(self) -> Select: + """Get SQL to check what columns exist in the tracking table. + + Returns a query that will fail gracefully if the table doesn't exist, + and returns column names if it does. + + Returns: + SQL builder object for column check query. 
+ """ + return sql.select("*").from_(self.version_table).limit(0) + + def _get_add_missing_columns_sql(self, missing_columns: "set[str]") -> "list[str]": + """Generate ALTER TABLE statements to add missing columns. + + Args: + missing_columns: Set of column names that need to be added. + + Returns: + List of SQL statements to execute. + """ + + statements = [] + target_create = self._get_create_table_sql() + + column_definitions = {col.name.lower(): col for col in target_create.columns} + + for col_name in sorted(missing_columns): + if col_name in column_definitions: + col_def = column_definitions[col_name] + alter = sql.alter_table(self.version_table).add_column( + name=col_def.name, + dtype=col_def.dtype, + default=col_def.default, + not_null=col_def.not_null, + unique=col_def.unique, + comment=col_def.comment, + ) + statements.append(str(alter)) + + return statements + + def _detect_missing_columns(self, existing_columns: "set[str]") -> "set[str]": + """Detect which columns are missing from the current schema. + + Args: + existing_columns: Set of existing column names (may be uppercase/lowercase). + + Returns: + Set of missing column names (lowercase). + """ + target_create = self._get_create_table_sql() + target_columns = {col.name.lower() for col in target_create.columns} + existing_lower = {col.lower() for col in existing_columns} + return target_columns - existing_lower + @abstractmethod def ensure_tracking_table(self, driver: DriverT) -> Any: - """Create the migration tracking table if it doesn't exist.""" + """Create the migration tracking table if it doesn't exist. + + Implementations should also check for and add any missing columns + to support schema migrations from older versions. + """ ... @abstractmethod @@ -168,13 +292,14 @@ def _extract_version(self, filename: str) -> str | None: Returns: The extracted version string or None. """ - # Handle extension-prefixed versions (e.g., "ext_litestar_0001") - if filename.startswith("ext_"): - # This is already a prefixed version, return as-is - return filename + from pathlib import Path + + stem = Path(filename).stem - # Regular version extraction - parts = filename.split("_", 1) + if stem.startswith("ext_"): + return stem + + parts = stem.split("_", 1) return parts[0].zfill(4) if parts and parts[0].isdigit() else None def _calculate_checksum(self, content: str) -> str: @@ -192,6 +317,9 @@ def _calculate_checksum(self, content: str) -> str: def _get_migration_files_sync(self) -> "list[tuple[str, Path]]": """Get all migration files sorted by version. + Uses version-aware sorting that handles both sequential and timestamp + formats correctly, with extension migrations sorted by extension name. + Returns: List of tuples containing (version, file_path). """ @@ -221,7 +349,7 @@ def _get_migration_files_sync(self) -> "list[tuple[str, Path]]": prefixed_version = f"ext_{ext_name}_{version}" migrations.append((prefixed_version, file_path)) - return sorted(migrations, key=operator.itemgetter(0)) + return sorted(migrations, key=lambda m: parse_version(m[0])) def _load_migration_metadata(self, file_path: Path, version: "str | None" = None) -> "dict[str, Any]": """Load migration metadata from file. 
@@ -424,13 +552,13 @@ def _get_init_readme_content(self) -> str: Migration files use SQLFileLoader's named query syntax with versioned names: ```sql --- name: migrate-0001-up +-- name: migrate-20251011120000-up CREATE TABLE example ( id INTEGER PRIMARY KEY, name TEXT NOT NULL ); --- name: migrate-0001-down +-- name: migrate-20251011120000-down DROP TABLE example; ``` @@ -440,16 +568,39 @@ def _get_init_readme_content(self) -> str: Format: `{version}_{description}.sql` -- Version: Zero-padded 4-digit number (0001, 0002, etc.) +- Version: Timestamp in YYYYMMDDHHmmss format (UTC) - Description: Brief description using underscores -- Example: `0001_create_users_table.sql` +- Example: `20251011120000_create_users_table.sql` ### Query Names - Upgrade: `migrate-{version}-up` - Downgrade: `migrate-{version}-down` -This naming ensures proper sorting and avoids conflicts when loading multiple files. +## Version Format + +Migrations use **timestamp-based versioning** (YYYYMMDDHHmmss): + +- **Format**: 14-digit UTC timestamp +- **Example**: `20251011120000` (October 11, 2025 at 12:00:00 UTC) +- **Benefits**: Eliminates merge conflicts when multiple developers create migrations concurrently + +### Creating Migrations + +Use the CLI to generate timestamped migrations: + +```bash +sqlspec create-migration "add user table" +# Creates: 20251011120000_add_user_table.sql +``` + +The timestamp is automatically generated in UTC timezone. + +## Migration Execution + +Migrations are applied in chronological order based on their timestamps. +The database tracks both version and execution order separately to handle +out-of-order migrations gracefully (e.g., from late-merging branches). """ def init_directory(self, directory: str, package: bool = True) -> None: diff --git a/sqlspec/migrations/commands.py b/sqlspec/migrations/commands.py index bb5b8228..d8507d3d 100644 --- a/sqlspec/migrations/commands.py +++ b/sqlspec/migrations/commands.py @@ -11,11 +11,16 @@ from sqlspec.builder import sql from sqlspec.migrations.base import BaseMigrationCommands from sqlspec.migrations.context import MigrationContext +from sqlspec.migrations.fix import MigrationFixer from sqlspec.migrations.runner import AsyncMigrationRunner, SyncMigrationRunner from sqlspec.migrations.utils import create_migration_file +from sqlspec.migrations.validation import validate_migration_order from sqlspec.utils.logging import get_logger +from sqlspec.utils.version import generate_conversion_map, generate_timestamp_version if TYPE_CHECKING: + from pathlib import Path + from sqlspec.config import AsyncConfigT, SyncConfigT __all__ = ("AsyncMigrationCommands", "SyncMigrationCommands", "create_migration_commands") @@ -95,11 +100,114 @@ def current(self, verbose: bool = False) -> "str | None": return cast("str | None", current) - def upgrade(self, revision: str = "head", *, dry_run: bool = False) -> None: + def _load_single_migration_checksum(self, version: str, file_path: "Path") -> "tuple[str, tuple[str, Path]] | None": + """Load checksum for a single migration. + + Args: + version: Migration version. + file_path: Path to migration file. + + Returns: + Tuple of (version, (checksum, file_path)) or None if load fails. 
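+
+        Example:
+            For a hypothetical ``migrations/0001_init.sql``, returns
+            ``("0001", (<md5 hex digest>, Path("migrations/0001_init.sql")))``;
+            returns ``None`` if the file cannot be parsed.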
+ """ + try: + migration = self.runner.load_migration(file_path, version) + return (version, (migration["checksum"], file_path)) + except Exception as e: + logger.debug("Could not load migration %s for auto-sync: %s", version, e) + return None + + def _load_migration_checksums(self, all_migrations: "list[tuple[str, Path]]") -> "dict[str, tuple[str, Path]]": + """Load checksums for all migrations. + + Args: + all_migrations: List of (version, file_path) tuples. + + Returns: + Dictionary mapping version to (checksum, file_path) tuples. + """ + file_checksums = {} + for version, file_path in all_migrations: + result = self._load_single_migration_checksum(version, file_path) + if result: + file_checksums[result[0]] = result[1] + return file_checksums + + def _synchronize_version_records(self, driver: Any) -> int: + """Synchronize database version records with migration files. + + Auto-updates DB tracking when migrations have been renamed by fix command. + This allows developers to just run upgrade after pulling changes without + manually running fix. + + Validates checksums match before updating to prevent incorrect matches. + + Args: + driver: Database driver instance. + + Returns: + Number of version records updated. + """ + all_migrations = self.runner.get_migration_files() + + try: + applied_migrations = self.tracker.get_applied_migrations(driver) + except Exception: + logger.debug("Could not fetch applied migrations for synchronization (table schema may be migrating)") + return 0 + + applied_map = {m["version_num"]: m for m in applied_migrations} + + conversion_map = generate_conversion_map(all_migrations) + + updated_count = 0 + if conversion_map: + for old_version, new_version in conversion_map.items(): + if old_version in applied_map and new_version not in applied_map: + applied_checksum = applied_map[old_version]["checksum"] + + file_path = next((path for v, path in all_migrations if v == new_version), None) + if file_path: + migration = self.runner.load_migration(file_path, new_version) + if migration["checksum"] == applied_checksum: + self.tracker.update_version_record(driver, old_version, new_version) + console.print(f" [dim]Reconciled version:[/] {old_version} → {new_version}") + updated_count += 1 + else: + console.print( + f" [yellow]Warning: Checksum mismatch for {old_version} → {new_version}, skipping auto-sync[/]" + ) + else: + file_checksums = self._load_migration_checksums(all_migrations) + + for applied_version, applied_record in applied_map.items(): + for file_version, (file_checksum, _) in file_checksums.items(): + if file_version not in applied_map and applied_record["checksum"] == file_checksum: + self.tracker.update_version_record(driver, applied_version, file_version) + console.print(f" [dim]Reconciled version:[/] {applied_version} → {file_version}") + updated_count += 1 + break + + if updated_count > 0: + console.print(f"[cyan]Reconciled {updated_count} version record(s)[/]") + + return updated_count + + def upgrade( + self, revision: str = "head", allow_missing: bool = False, auto_sync: bool = True, dry_run: bool = False + ) -> None: """Upgrade to a target revision. + Validates migration order and warns if out-of-order migrations are detected. + Out-of-order migrations can occur when branches merge in different orders + across environments. + Args: revision: Target revision or "head" for latest. + allow_missing: If True, allow out-of-order migrations even in strict mode. + Defaults to False. 
+ auto_sync: If True, automatically reconcile renamed migrations in database. + Defaults to True. Can be disabled via --no-auto-sync flag. dry_run: If True, show what would be done without making changes. """ if dry_run: @@ -108,12 +216,29 @@ def upgrade(self, revision: str = "head", *, dry_run: bool = False) -> None: with self.config.provide_session() as driver: self.tracker.ensure_tracking_table(driver) - current = self.tracker.get_current_version(driver) + if auto_sync: + migration_config = getattr(self.config, "migration_config", {}) or {} + config_auto_sync = migration_config.get("auto_sync", True) + if config_auto_sync: + self._synchronize_version_records(driver) + + applied_migrations = self.tracker.get_applied_migrations(driver) + applied_versions = [m["version_num"] for m in applied_migrations] + applied_set = set(applied_versions) + all_migrations = self.runner.get_migration_files() pending = [] for version, file_path in all_migrations: - if (current is None or version > current) and (revision == "head" or version <= revision): - pending.append((version, file_path)) + if version not in applied_set: + if revision == "head": + pending.append((version, file_path)) + else: + from sqlspec.utils.version import parse_version + + parsed_version = parse_version(version) + parsed_revision = parse_version(revision) + if parsed_version <= parsed_revision: + pending.append((version, file_path)) if not pending: if not all_migrations: @@ -123,6 +248,12 @@ def upgrade(self, revision: str = "head", *, dry_run: bool = False) -> None: else: console.print("[green]Already at latest version[/]") return + pending_versions = [v for v, _ in pending] + + migration_config = getattr(self.config, "migration_config", {}) or {} + strict_ordering = migration_config.get("strict_ordering", False) and not allow_missing + + validate_migration_order(pending_versions, applied_versions, strict_ordering) console.print(f"[yellow]Found {len(pending)} pending migrations[/]") @@ -172,8 +303,12 @@ def downgrade(self, revision: str = "-1", *, dry_run: bool = False) -> None: elif revision == "base": to_revert = list(reversed(applied)) else: + from sqlspec.utils.version import parse_version + + parsed_revision = parse_version(revision) for migration in reversed(applied): - if migration["version_num"] > revision: + parsed_migration_version = parse_version(migration["version_num"]) + if parsed_migration_version > parsed_revision: to_revert.append(migration) if not to_revert: @@ -225,18 +360,105 @@ def stamp(self, revision: str) -> None: console.print(f"[green]Database stamped at revision {revision}[/]") def revision(self, message: str, file_type: str = "sql") -> None: - """Create a new migration file. + """Create a new migration file with timestamp-based versioning. + + Generates a unique timestamp version (YYYYMMDDHHmmss format) to avoid + conflicts when multiple developers create migrations concurrently. Args: message: Description for the migration. file_type: Type of migration file to create ('sql' or 'py'). 
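+
+        Example:
+            ``commands.revision("add users table")`` creates a file such as
+            ``migrations/20251011120000_add_users_table.sql`` (the timestamp
+            varies with the current UTC time).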
""" - existing = self.runner.get_migration_files() - next_num = int(existing[-1][0]) + 1 if existing else 1 - next_version = str(next_num).zfill(4) - file_path = create_migration_file(self.migrations_path, next_version, message, file_type) + version = generate_timestamp_version() + file_path = create_migration_file(self.migrations_path, version, message, file_type) console.print(f"[green]Created migration:[/] {file_path}") + def fix(self, dry_run: bool = False, update_database: bool = True, yes: bool = False) -> None: + """Convert timestamp migrations to sequential format. + + Implements hybrid versioning workflow where development uses timestamps + and production uses sequential numbers. Creates backup before changes + and provides rollback on errors. + + Args: + dry_run: Preview changes without applying. + update_database: Update migration records in database. + yes: Skip confirmation prompt. + + Examples: + >>> commands.fix(dry_run=True) # Preview only + >>> commands.fix(yes=True) # Auto-approve + >>> commands.fix(update_database=False) # Files only + """ + all_migrations = self.runner.get_migration_files() + + conversion_map = generate_conversion_map(all_migrations) + + if not conversion_map: + console.print("[yellow]No timestamp migrations found - nothing to convert[/]") + return + + fixer = MigrationFixer(self.migrations_path) + renames = fixer.plan_renames(conversion_map) + + table = Table(title="Migration Conversions") + table.add_column("Current Version", style="cyan") + table.add_column("New Version", style="green") + table.add_column("File") + + for rename in renames: + table.add_row(rename.old_version, rename.new_version, rename.old_path.name) + + console.print(table) + console.print(f"\n[yellow]{len(renames)} migrations will be converted[/]") + + if dry_run: + console.print("[yellow][Preview Mode - No changes made][/]") + return + + if not yes: + response = input("\nProceed with conversion? 
[y/N]: ") + if response.lower() != "y": + console.print("[yellow]Conversion cancelled[/]") + return + + try: + backup_path = fixer.create_backup() + console.print(f"[green]✓ Created backup in {backup_path.name}[/]") + + fixer.apply_renames(renames) + for rename in renames: + console.print(f"[green]✓ Renamed {rename.old_path.name} → {rename.new_path.name}[/]") + + if update_database: + with self.config.provide_session() as driver: + self.tracker.ensure_tracking_table(driver) + applied_migrations = self.tracker.get_applied_migrations(driver) + applied_versions = {m["version_num"] for m in applied_migrations} + + updated_count = 0 + for old_version, new_version in conversion_map.items(): + if old_version in applied_versions: + self.tracker.update_version_record(driver, old_version, new_version) + updated_count += 1 + + if updated_count > 0: + console.print( + f"[green]✓ Updated {updated_count} version records in migration tracking table[/]" + ) + else: + console.print("[green]✓ No applied migrations to update in tracking table[/]") + + fixer.cleanup() + console.print("[green]✓ Conversion complete![/]") + + except Exception as e: + logger.exception("Fix command failed") + console.print(f"[red]✗ Error: {e}[/]") + fixer.rollback() + console.print("[yellow]Restored files from backup[/]") + raise + class AsyncMigrationCommands(BaseMigrationCommands["AsyncConfigT", Any]): """Asynchronous migration commands.""" @@ -305,11 +527,118 @@ async def current(self, verbose: bool = False) -> "str | None": return cast("str | None", current) - async def upgrade(self, revision: str = "head", *, dry_run: bool = False) -> None: + async def _load_single_migration_checksum( + self, version: str, file_path: "Path" + ) -> "tuple[str, tuple[str, Path]] | None": + """Load checksum for a single migration. + + Args: + version: Migration version. + file_path: Path to migration file. + + Returns: + Tuple of (version, (checksum, file_path)) or None if load fails. + """ + try: + migration = await self.runner.load_migration(file_path, version) + return (version, (migration["checksum"], file_path)) + except Exception as e: + logger.debug("Could not load migration %s for auto-sync: %s", version, e) + return None + + async def _load_migration_checksums( + self, all_migrations: "list[tuple[str, Path]]" + ) -> "dict[str, tuple[str, Path]]": + """Load checksums for all migrations. + + Args: + all_migrations: List of (version, file_path) tuples. + + Returns: + Dictionary mapping version to (checksum, file_path) tuples. + """ + file_checksums = {} + for version, file_path in all_migrations: + result = await self._load_single_migration_checksum(version, file_path) + if result: + file_checksums[result[0]] = result[1] + return file_checksums + + async def _synchronize_version_records(self, driver: Any) -> int: + """Synchronize database version records with migration files. + + Auto-updates DB tracking when migrations have been renamed by fix command. + This allows developers to just run upgrade after pulling changes without + manually running fix. + + Validates checksums match before updating to prevent incorrect matches. + + Args: + driver: Database driver instance. + + Returns: + Number of version records updated. 
+ """ + all_migrations = await self.runner.get_migration_files() + + try: + applied_migrations = await self.tracker.get_applied_migrations(driver) + except Exception: + logger.debug("Could not fetch applied migrations for synchronization (table schema may be migrating)") + return 0 + + applied_map = {m["version_num"]: m for m in applied_migrations} + + conversion_map = generate_conversion_map(all_migrations) + + updated_count = 0 + if conversion_map: + for old_version, new_version in conversion_map.items(): + if old_version in applied_map and new_version not in applied_map: + applied_checksum = applied_map[old_version]["checksum"] + + file_path = next((path for v, path in all_migrations if v == new_version), None) + if file_path: + migration = await self.runner.load_migration(file_path, new_version) + if migration["checksum"] == applied_checksum: + await self.tracker.update_version_record(driver, old_version, new_version) + console.print(f" [dim]Reconciled version:[/] {old_version} → {new_version}") + updated_count += 1 + else: + console.print( + f" [yellow]Warning: Checksum mismatch for {old_version} → {new_version}, skipping auto-sync[/]" + ) + else: + file_checksums = await self._load_migration_checksums(all_migrations) + + for applied_version, applied_record in applied_map.items(): + for file_version, (file_checksum, _) in file_checksums.items(): + if file_version not in applied_map and applied_record["checksum"] == file_checksum: + await self.tracker.update_version_record(driver, applied_version, file_version) + console.print(f" [dim]Reconciled version:[/] {applied_version} → {file_version}") + updated_count += 1 + break + + if updated_count > 0: + console.print(f"[cyan]Reconciled {updated_count} version record(s)[/]") + + return updated_count + + async def upgrade( + self, revision: str = "head", allow_missing: bool = False, auto_sync: bool = True, dry_run: bool = False + ) -> None: """Upgrade to a target revision. + Validates migration order and warns if out-of-order migrations are detected. + Out-of-order migrations can occur when branches merge in different orders + across environments. + Args: revision: Target revision or "head" for latest. + allow_missing: If True, allow out-of-order migrations even in strict mode. + Defaults to False. + auto_sync: If True, automatically reconcile renamed migrations in database. + Defaults to True. Can be disabled via --no-auto-sync flag. dry_run: If True, show what would be done without making changes. 
""" if dry_run: @@ -318,12 +647,29 @@ async def upgrade(self, revision: str = "head", *, dry_run: bool = False) -> Non async with self.config.provide_session() as driver: await self.tracker.ensure_tracking_table(driver) - current = await self.tracker.get_current_version(driver) + if auto_sync: + migration_config = getattr(self.config, "migration_config", {}) or {} + config_auto_sync = migration_config.get("auto_sync", True) + if config_auto_sync: + await self._synchronize_version_records(driver) + + applied_migrations = await self.tracker.get_applied_migrations(driver) + applied_versions = [m["version_num"] for m in applied_migrations] + applied_set = set(applied_versions) + all_migrations = await self.runner.get_migration_files() pending = [] for version, file_path in all_migrations: - if (current is None or version > current) and (revision == "head" or version <= revision): - pending.append((version, file_path)) + if version not in applied_set: + if revision == "head": + pending.append((version, file_path)) + else: + from sqlspec.utils.version import parse_version + + parsed_version = parse_version(version) + parsed_revision = parse_version(revision) + if parsed_version <= parsed_revision: + pending.append((version, file_path)) if not pending: if not all_migrations: console.print( @@ -332,6 +678,13 @@ async def upgrade(self, revision: str = "head", *, dry_run: bool = False) -> Non else: console.print("[green]Already at latest version[/]") return + pending_versions = [v for v, _ in pending] + + migration_config = getattr(self.config, "migration_config", {}) or {} + strict_ordering = migration_config.get("strict_ordering", False) and not allow_missing + + validate_migration_order(pending_versions, applied_versions, strict_ordering) + console.print(f"[yellow]Found {len(pending)} pending migrations[/]") for version, file_path in pending: migration = await self.runner.load_migration(file_path, version) @@ -379,8 +732,12 @@ async def downgrade(self, revision: str = "-1", *, dry_run: bool = False) -> Non elif revision == "base": to_revert = list(reversed(applied)) else: + from sqlspec.utils.version import parse_version + + parsed_revision = parse_version(revision) for migration in reversed(applied): - if migration["version_num"] > revision: + parsed_migration_version = parse_version(migration["version_num"]) + if parsed_migration_version > parsed_revision: to_revert.append(migration) if not to_revert: console.print("[yellow]Nothing to downgrade[/]") @@ -434,18 +791,105 @@ async def stamp(self, revision: str) -> None: console.print(f"[green]Database stamped at revision {revision}[/]") async def revision(self, message: str, file_type: str = "sql") -> None: - """Create a new migration file. + """Create a new migration file with timestamp-based versioning. + + Generates a unique timestamp version (YYYYMMDDHHmmss format) to avoid + conflicts when multiple developers create migrations concurrently. Args: message: Description for the migration. file_type: Type of migration file to create ('sql' or 'py'). 
""" - existing = await self.runner.get_migration_files() - next_num = int(existing[-1][0]) + 1 if existing else 1 - next_version = str(next_num).zfill(4) - file_path = create_migration_file(self.migrations_path, next_version, message, file_type) + version = generate_timestamp_version() + file_path = create_migration_file(self.migrations_path, version, message, file_type) console.print(f"[green]Created migration:[/] {file_path}") + async def fix(self, dry_run: bool = False, update_database: bool = True, yes: bool = False) -> None: + """Convert timestamp migrations to sequential format. + + Implements hybrid versioning workflow where development uses timestamps + and production uses sequential numbers. Creates backup before changes + and provides rollback on errors. + + Args: + dry_run: Preview changes without applying. + update_database: Update migration records in database. + yes: Skip confirmation prompt. + + Examples: + >>> await commands.fix(dry_run=True) # Preview only + >>> await commands.fix(yes=True) # Auto-approve + >>> await commands.fix(update_database=False) # Files only + """ + all_migrations = await self.runner.get_migration_files() + + conversion_map = generate_conversion_map(all_migrations) + + if not conversion_map: + console.print("[yellow]No timestamp migrations found - nothing to convert[/]") + return + + fixer = MigrationFixer(self.migrations_path) + renames = fixer.plan_renames(conversion_map) + + table = Table(title="Migration Conversions") + table.add_column("Current Version", style="cyan") + table.add_column("New Version", style="green") + table.add_column("File") + + for rename in renames: + table.add_row(rename.old_version, rename.new_version, rename.old_path.name) + + console.print(table) + console.print(f"\n[yellow]{len(renames)} migrations will be converted[/]") + + if dry_run: + console.print("[yellow][Preview Mode - No changes made][/]") + return + + if not yes: + response = input("\nProceed with conversion? 
[y/N]: ") + if response.lower() != "y": + console.print("[yellow]Conversion cancelled[/]") + return + + try: + backup_path = fixer.create_backup() + console.print(f"[green]✓ Created backup in {backup_path.name}[/]") + + fixer.apply_renames(renames) + for rename in renames: + console.print(f"[green]✓ Renamed {rename.old_path.name} → {rename.new_path.name}[/]") + + if update_database: + async with self.config.provide_session() as driver: + await self.tracker.ensure_tracking_table(driver) + applied_migrations = await self.tracker.get_applied_migrations(driver) + applied_versions = {m["version_num"] for m in applied_migrations} + + updated_count = 0 + for old_version, new_version in conversion_map.items(): + if old_version in applied_versions: + await self.tracker.update_version_record(driver, old_version, new_version) + updated_count += 1 + + if updated_count > 0: + console.print( + f"[green]✓ Updated {updated_count} version records in migration tracking table[/]" + ) + else: + console.print("[green]✓ No applied migrations to update in tracking table[/]") + + fixer.cleanup() + console.print("[green]✓ Conversion complete![/]") + + except Exception as e: + logger.exception("Fix command failed") + console.print(f"[red]✗ Error: {e}[/]") + fixer.rollback() + console.print("[yellow]Restored files from backup[/]") + raise + def create_migration_commands( config: "SyncConfigT | AsyncConfigT", diff --git a/sqlspec/migrations/fix.py b/sqlspec/migrations/fix.py new file mode 100644 index 00000000..0c598260 --- /dev/null +++ b/sqlspec/migrations/fix.py @@ -0,0 +1,199 @@ +"""Migration file fix operations for converting timestamp to sequential versions. + +This module provides utilities to convert timestamp-format migration files to +sequential format, supporting the hybrid versioning workflow where development +uses timestamps and production uses sequential numbers. +""" + +import logging +import re +import shutil +from dataclasses import dataclass +from datetime import datetime, timezone +from pathlib import Path + +__all__ = ("MigrationFixer", "MigrationRename") + +logger = logging.getLogger(__name__) + + +@dataclass +class MigrationRename: + """Represents a planned migration file rename operation. + + Attributes: + old_path: Current file path. + new_path: Target file path after rename. + old_version: Current version string. + new_version: Target version string. + needs_content_update: Whether file content needs updating. + True for SQL files that contain query names. + """ + + old_path: Path + new_path: Path + old_version: str + new_version: str + needs_content_update: bool + + +class MigrationFixer: + """Handles atomic migration file conversion operations. + + Provides backup/rollback functionality and manages conversion from + timestamp-based migration files to sequential format. + """ + + def __init__(self, migrations_path: Path) -> None: + """Initialize migration fixer. + + Args: + migrations_path: Path to migrations directory. + """ + self.migrations_path = migrations_path + self.backup_path: Path | None = None + + def plan_renames(self, conversion_map: dict[str, str]) -> list[MigrationRename]: + """Plan all file rename operations from conversion map. + + Scans migration directory and builds list of MigrationRename objects + for all files that need conversion. Validates no target collisions. + + Args: + conversion_map: Dictionary mapping old versions to new versions. + + Returns: + List of planned rename operations. + + Raises: + ValueError: If target file already exists or collision detected. 
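+
+        Example:
+            ``fixer.plan_renames({"20251011120000": "0003"})`` plans a rename
+            of a hypothetical ``20251011120000_add_users.sql`` to
+            ``0003_add_users.sql`` and flags the SQL file for a content update.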
+ """ + if not conversion_map: + return [] + + renames: list[MigrationRename] = [] + + for old_version, new_version in conversion_map.items(): + matching_files = list(self.migrations_path.glob(f"{old_version}_*")) + + for old_path in matching_files: + suffix = old_path.suffix + description = old_path.stem.replace(f"{old_version}_", "") + + new_filename = f"{new_version}_{description}{suffix}" + new_path = self.migrations_path / new_filename + + if new_path.exists() and new_path != old_path: + msg = f"Target file already exists: {new_path}" + raise ValueError(msg) + + needs_content_update = suffix == ".sql" + + renames.append( + MigrationRename( + old_path=old_path, + new_path=new_path, + old_version=old_version, + new_version=new_version, + needs_content_update=needs_content_update, + ) + ) + + return renames + + def create_backup(self) -> Path: + """Create timestamped backup directory with all migration files. + + Returns: + Path to created backup directory. + + """ + timestamp = datetime.now(tz=timezone.utc).strftime("%Y%m%d_%H%M%S") + backup_dir = self.migrations_path / f".backup_{timestamp}" + + backup_dir.mkdir(parents=True, exist_ok=False) + + for file_path in self.migrations_path.iterdir(): + if file_path.is_file() and not file_path.name.startswith("."): + shutil.copy2(file_path, backup_dir / file_path.name) + + self.backup_path = backup_dir + return backup_dir + + def apply_renames(self, renames: "list[MigrationRename]", dry_run: bool = False) -> None: + """Execute planned rename operations. + + Args: + renames: List of planned rename operations. + dry_run: If True, log operations without executing. + + """ + if not renames: + return + + for rename in renames: + if dry_run: + continue + + if rename.needs_content_update: + self.update_file_content(rename.old_path, rename.old_version, rename.new_version) + + rename.old_path.rename(rename.new_path) + + def update_file_content(self, file_path: Path, old_version: str, new_version: str) -> None: + """Update SQL query names and version comments in file content. + + Transforms query names and version metadata from old version to new version: + -- name: migrate-{old_version}-up → -- name: migrate-{new_version}-up + -- name: migrate-{old_version}-down → -- name: migrate-{new_version}-down + -- Version: {old_version} → -- Version: {new_version} + + Creates version-specific regex patterns to avoid unintended replacements + of other migrate-* patterns in the file. + + Args: + file_path: Path to file to update. + old_version: Old version string. + new_version: New version string. + + """ + content = file_path.read_text(encoding="utf-8") + + up_pattern = re.compile(rf"(-- name:\s+migrate-){re.escape(old_version)}(-up)") + down_pattern = re.compile(rf"(-- name:\s+migrate-){re.escape(old_version)}(-down)") + version_pattern = re.compile(rf"(-- Version:\s+){re.escape(old_version)}") + + content = up_pattern.sub(rf"\g<1>{new_version}\g<2>", content) + content = down_pattern.sub(rf"\g<1>{new_version}\g<2>", content) + content = version_pattern.sub(rf"\g<1>{new_version}", content) + + file_path.write_text(content, encoding="utf-8") + logger.debug("Updated content in %s", file_path.name) + + def rollback(self) -> None: + """Restore migration files from backup. + + Deletes current migration files and restores from backup directory. + Only restores if backup exists. 
+ """ + if not self.backup_path or not self.backup_path.exists(): + return + + for file_path in self.migrations_path.iterdir(): + if file_path.is_file() and not file_path.name.startswith("."): + file_path.unlink() + + for backup_file in self.backup_path.iterdir(): + if backup_file.is_file(): + shutil.copy2(backup_file, self.migrations_path / backup_file.name) + + def cleanup(self) -> None: + """Remove backup directory after successful conversion. + + Only removes backup if it exists. Logs warning if no backup found. + """ + if not self.backup_path or not self.backup_path.exists(): + return + + shutil.rmtree(self.backup_path) + self.backup_path = None diff --git a/sqlspec/migrations/loaders.py b/sqlspec/migrations/loaders.py index 1114038d..48f0533d 100644 --- a/sqlspec/migrations/loaders.py +++ b/sqlspec/migrations/loaders.py @@ -77,13 +77,22 @@ class SQLFileLoader(BaseMigrationLoader): __slots__ = ("sql_loader",) - def __init__(self) -> None: - """Initialize SQL file loader.""" - self.sql_loader: CoreSQLFileLoader = CoreSQLFileLoader() + def __init__(self, sql_loader: "CoreSQLFileLoader | None" = None) -> None: + """Initialize SQL file loader. + + Args: + sql_loader: Optional shared SQLFileLoader instance to reuse. + If not provided, creates a new instance. + """ + self.sql_loader: CoreSQLFileLoader = sql_loader if sql_loader is not None else CoreSQLFileLoader() async def get_up_sql(self, path: Path) -> list[str]: """Extract the 'up' SQL from a SQL migration file. + The SQL file must already be loaded via validate_migration_file() + before calling this method. This design ensures the file is loaded + exactly once during the migration process. + Args: path: Path to SQL migration file. @@ -93,8 +102,6 @@ async def get_up_sql(self, path: Path) -> list[str]: Raises: MigrationLoadError: If migration file is invalid or missing up query. """ - self.sql_loader.load_sql(path) - version = self._extract_version(path.name) up_query = f"migrate-{version}-up" @@ -108,14 +115,16 @@ async def get_up_sql(self, path: Path) -> list[str]: async def get_down_sql(self, path: Path) -> list[str]: """Extract the 'down' SQL from a SQL migration file. + The SQL file must already be loaded via validate_migration_file() + before calling this method. This design ensures the file is loaded + exactly once during the migration process. + Args: path: Path to SQL migration file. Returns: List containing single SQL statement for downgrade, or empty list. """ - self.sql_loader.load_sql(path) - version = self._extract_version(path.name) down_query = f"migrate-{version}-down" @@ -148,14 +157,31 @@ def validate_migration_file(self, path: Path) -> None: def _extract_version(self, filename: str) -> str: """Extract version from filename. + Supports sequential (0001), timestamp (20251011120000), and extension-prefixed + (ext_litestar_0001) version formats. + Args: filename: Migration filename to parse. Returns: - Zero-padded version string or empty string if invalid. + Version string or empty string if invalid. 
""" - parts = filename.split("_", 1) - return parts[0].zfill(4) if parts and parts[0].isdigit() else "" + extension_version_parts = 3 + timestamp_min_length = 4 + + name_without_ext = filename.rsplit(".", 1)[0] + + if name_without_ext.startswith("ext_"): + parts = name_without_ext.split("_", 3) + if len(parts) >= extension_version_parts: + return f"{parts[0]}_{parts[1]}_{parts[2]}" + return "" + + parts = name_without_ext.split("_", 1) + if parts and parts[0].isdigit(): + return parts[0] if len(parts[0]) > timestamp_min_length else parts[0].zfill(4) + + return "" class PythonFileLoader(BaseMigrationLoader): @@ -391,7 +417,11 @@ def _normalize_and_validate_sql(self, sql: Any, migration_path: Path) -> list[st def get_migration_loader( - file_path: Path, migrations_dir: Path, project_root: "Path | None" = None, context: "Any | None" = None + file_path: Path, + migrations_dir: Path, + project_root: "Path | None" = None, + context: "Any | None" = None, + sql_loader: "CoreSQLFileLoader | None" = None, ) -> BaseMigrationLoader: """Factory function to get appropriate loader for migration file. @@ -400,6 +430,9 @@ def get_migration_loader( migrations_dir: Directory containing migration files. project_root: Optional project root directory for Python imports. context: Optional migration context to pass to Python migrations. + sql_loader: Optional shared SQLFileLoader instance for SQL migrations. + When provided, SQL files are loaded using this shared instance, + avoiding redundant file parsing. Returns: Appropriate loader instance for the file type. @@ -412,6 +445,6 @@ def get_migration_loader( if suffix == ".py": return PythonFileLoader(migrations_dir, project_root, context) if suffix == ".sql": - return SQLFileLoader() + return SQLFileLoader(sql_loader) msg = f"Unsupported migration file type: {suffix}" raise MigrationLoadError(msg) diff --git a/sqlspec/migrations/runner.py b/sqlspec/migrations/runner.py index 5e075f2e..6d3988c2 100644 --- a/sqlspec/migrations/runner.py +++ b/sqlspec/migrations/runner.py @@ -4,7 +4,6 @@ of concerns and proper type safety. """ -import operator import time from abc import ABC, abstractmethod from pathlib import Path @@ -56,24 +55,39 @@ def __init__( def _extract_version(self, filename: str) -> "str | None": """Extract version from filename. + Supports sequential (0001), timestamp (20251011120000), and extension-prefixed + (ext_litestar_0001) version formats. + Args: filename: The migration filename. Returns: The extracted version string or None. """ - # Handle extension-prefixed versions (e.g., "ext_litestar_0001") - if filename.startswith("ext_"): - # This is already a prefixed version, return as-is - return filename + extension_version_parts = 3 + timestamp_min_length = 4 + + name_without_ext = filename.rsplit(".", 1)[0] + + if name_without_ext.startswith("ext_"): + parts = name_without_ext.split("_", 3) + if len(parts) >= extension_version_parts: + return f"{parts[0]}_{parts[1]}_{parts[2]}" + return None + + parts = name_without_ext.split("_", 1) + if parts and parts[0].isdigit(): + return parts[0] if len(parts[0]) > timestamp_min_length else parts[0].zfill(4) - # Regular version extraction - parts = filename.split("_", 1) - return parts[0].zfill(4) if parts and parts[0].isdigit() else None + return None def _calculate_checksum(self, content: str) -> str: """Calculate MD5 checksum of migration content. + Canonicalizes content by excluding query name headers that change during + fix command (migrate-{version}-up/down). 
This ensures checksums remain + stable when converting timestamp versions to sequential format. + Args: content: The migration file content. @@ -81,8 +95,11 @@ def _calculate_checksum(self, content: str) -> str: MD5 checksum hex string. """ import hashlib + import re - return hashlib.md5(content.encode()).hexdigest() # noqa: S324 + canonical_content = re.sub(r"^--\s*name:\s*migrate-[^-]+-(?:up|down)\s*$", "", content, flags=re.MULTILINE) + + return hashlib.md5(canonical_content.encode()).hexdigest() # noqa: S324 @abstractmethod def load_migration(self, file_path: Path) -> Union["dict[str, Any]", "Coroutine[Any, Any, dict[str, Any]]"]: @@ -129,7 +146,16 @@ def _get_migration_files_sync(self) -> "list[tuple[str, Path]]": prefixed_version = f"ext_{ext_name}_{version}" migrations.append((prefixed_version, file_path)) - return sorted(migrations, key=operator.itemgetter(0)) + from sqlspec.utils.version import parse_version + + def version_sort_key(migration_tuple: "tuple[str, Path]") -> "Any": + version_str = migration_tuple[0] + try: + return parse_version(version_str) + except ValueError: + return version_str + + return sorted(migrations, key=version_sort_key) def get_migration_files(self) -> "list[tuple[str, Path]]": """Get all migration files sorted by version. @@ -220,7 +246,7 @@ def load_migration(self, file_path: Path, version: "str | None" = None) -> "dict metadata = self._load_migration_metadata_common(file_path, version) context_to_use = self._get_context_for_migration(file_path) - loader = get_migration_loader(file_path, self.migrations_path, self.project_root, context_to_use) + loader = get_migration_loader(file_path, self.migrations_path, self.project_root, context_to_use, self.loader) loader.validate_migration_file(file_path) has_upgrade, has_downgrade = True, False @@ -228,7 +254,6 @@ def load_migration(self, file_path: Path, version: "str | None" = None) -> "dict if file_path.suffix == ".sql": version = metadata["version"] up_query, down_query = f"migrate-{version}-up", f"migrate-{version}-down" - self.loader.load_sql(file_path) has_upgrade, has_downgrade = self.loader.has_query(up_query), self.loader.has_query(down_query) else: try: @@ -344,7 +369,9 @@ def load_all_migrations(self) -> "dict[str, SQL]": for query_name in self.loader.list_queries(): all_queries[query_name] = self.loader.get_sql(query_name) else: - loader = get_migration_loader(file_path, self.migrations_path, self.project_root, self.context) + loader = get_migration_loader( + file_path, self.migrations_path, self.project_root, self.context, self.loader + ) try: up_sql = await_(loader.get_up_sql)(file_path) @@ -385,7 +412,7 @@ async def load_migration(self, file_path: Path, version: "str | None" = None) -> metadata = self._load_migration_metadata_common(file_path, version) context_to_use = self._get_context_for_migration(file_path) - loader = get_migration_loader(file_path, self.migrations_path, self.project_root, context_to_use) + loader = get_migration_loader(file_path, self.migrations_path, self.project_root, context_to_use, self.loader) loader.validate_migration_file(file_path) has_upgrade, has_downgrade = True, False @@ -393,7 +420,6 @@ async def load_migration(self, file_path: Path, version: "str | None" = None) -> if file_path.suffix == ".sql": version = metadata["version"] up_query, down_query = f"migrate-{version}-up", f"migrate-{version}-down" - await async_(self.loader.load_sql)(file_path) has_upgrade, has_downgrade = self.loader.has_query(up_query), self.loader.has_query(down_query) else: try: @@ -505,7 
+531,9 @@ async def load_all_migrations(self) -> "dict[str, SQL]": for query_name in self.loader.list_queries(): all_queries[query_name] = self.loader.get_sql(query_name) else: - loader = get_migration_loader(file_path, self.migrations_path, self.project_root, self.context) + loader = get_migration_loader( + file_path, self.migrations_path, self.project_root, self.context, self.loader + ) try: up_sql = await loader.get_up_sql(file_path) diff --git a/sqlspec/migrations/tracker.py b/sqlspec/migrations/tracker.py index 8ce9b850..f6d19349 100644 --- a/sqlspec/migrations/tracker.py +++ b/sqlspec/migrations/tracker.py @@ -6,8 +6,12 @@ import os from typing import TYPE_CHECKING, Any +from rich.console import Console + +from sqlspec.builder import sql from sqlspec.migrations.base import BaseMigrationTracker from sqlspec.utils.logging import get_logger +from sqlspec.utils.version import parse_version if TYPE_CHECKING: from sqlspec.driver import AsyncDriverAdapterBase, SyncDriverAdapterBase @@ -20,15 +24,74 @@ class SyncMigrationTracker(BaseMigrationTracker["SyncDriverAdapterBase"]): """Synchronous migration version tracker.""" + def _migrate_schema_if_needed(self, driver: "SyncDriverAdapterBase") -> None: + """Check for and add any missing columns to the tracking table. + + Uses the adapter's data_dictionary to query existing columns, + then compares to the target schema and adds missing columns one by one. + + Args: + driver: The database driver to use. + """ + try: + columns_data = driver.data_dictionary.get_columns(driver, self.version_table) + if not columns_data: + logger.debug("Migration tracking table does not exist yet") + return + + existing_columns = {col["column_name"] for col in columns_data} + missing_columns = self._detect_missing_columns(existing_columns) + + if not missing_columns: + logger.debug("Migration tracking table schema is up-to-date") + return + + console = Console() + console.print( + f"[cyan]Migrating tracking table schema, adding columns: {', '.join(sorted(missing_columns))}[/]" + ) + + for col_name in sorted(missing_columns): + self._add_column(driver, col_name) + + driver.commit() + console.print("[green]Migration tracking table schema updated successfully[/]") + + except Exception as e: + logger.warning("Could not check or migrate tracking table schema: %s", e) + + def _add_column(self, driver: "SyncDriverAdapterBase", column_name: str) -> None: + """Add a single column to the tracking table. + + Args: + driver: The database driver to use. + column_name: Name of the column to add (lowercase). + """ + target_create = self._get_create_table_sql() + column_def = next((col for col in target_create.columns if col.name.lower() == column_name), None) + + if not column_def: + return + + alter_sql = sql.alter_table(self.version_table).add_column( + name=column_def.name, dtype=column_def.dtype, default=column_def.default, not_null=column_def.not_null + ) + driver.execute(alter_sql) + logger.debug("Added column %s to tracking table", column_name) + def ensure_tracking_table(self, driver: "SyncDriverAdapterBase") -> None: """Create the migration tracking table if it doesn't exist. + Also checks for and adds any missing columns to support schema migrations. + Args: driver: The database driver to use. """ driver.execute(self._get_create_table_sql()) self._safe_commit(driver) + self._migrate_schema_if_needed(driver) + def get_current_version(self, driver: "SyncDriverAdapterBase") -> str | None: """Get the latest applied migration version. 
@@ -58,6 +121,9 @@ def record_migration( ) -> None: """Record a successfully applied migration. + Parses version to determine type (sequential or timestamp) and + auto-increments execution_sequence for application order tracking. + Args: driver: The database driver to use. version: Version number of the migration. @@ -65,9 +131,21 @@ def record_migration( execution_time_ms: Execution time in milliseconds. checksum: MD5 checksum of the migration content. """ + parsed_version = parse_version(version) + version_type = parsed_version.type.value + + result = driver.execute(self._get_next_execution_sequence_sql()) + next_sequence = result.data[0]["next_seq"] if result.data else 1 + driver.execute( self._get_record_migration_sql( - version, description, execution_time_ms, checksum, os.environ.get("USER", "unknown") + version, + version_type, + next_sequence, + description, + execution_time_ms, + checksum, + os.environ.get("USER", "unknown"), ) ) self._safe_commit(driver) @@ -82,21 +160,52 @@ def remove_migration(self, driver: "SyncDriverAdapterBase", version: str) -> Non driver.execute(self._get_remove_migration_sql(version)) self._safe_commit(driver) - def _safe_commit(self, driver: "SyncDriverAdapterBase") -> None: - """Safely commit a transaction only if autocommit is disabled. + def update_version_record(self, driver: "SyncDriverAdapterBase", old_version: str, new_version: str) -> None: + """Update migration version record from timestamp to sequential. + + Updates version_num and version_type while preserving execution_sequence, + applied_at, and other tracking metadata. Used during fix command. + + Idempotent: If the version is already updated, logs and continues without error. + This allows fix command to be safely re-run after pulling changes. Args: driver: The database driver to use. + old_version: Current timestamp version string. + new_version: New sequential version string. + + Raises: + ValueError: If neither old_version nor new_version found in database. """ - try: - connection = getattr(driver, "connection", None) - if connection and hasattr(connection, "autocommit") and getattr(connection, "autocommit", False): - return + parsed_new_version = parse_version(new_version) + new_version_type = parsed_new_version.type.value + + result = driver.execute(self._get_update_version_sql(old_version, new_version, new_version_type)) - driver_features = getattr(driver, "driver_features", {}) - if driver_features and driver_features.get("autocommit", False): + if result.rows_affected == 0: + check_result = driver.execute(self._get_applied_migrations_sql()) + applied_versions = {row["version_num"] for row in check_result.data} if check_result.data else set() + + if new_version in applied_versions: + logger.debug("Version already updated: %s -> %s", old_version, new_version) return + msg = f"Migration version {old_version} not found in database" + raise ValueError(msg) + + self._safe_commit(driver) + logger.debug("Updated version record: %s -> %s", old_version, new_version) + + def _safe_commit(self, driver: "SyncDriverAdapterBase") -> None: + """Safely commit a transaction only if autocommit is disabled. + + Args: + driver: The database driver to use. 
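+
+        Note:
+            Commit failures are swallowed and logged at debug level, since
+            they usually indicate autocommit is active on the connection.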
+ """ + if driver.driver_features.get("autocommit", False): + return + + try: driver.commit() except Exception: logger.debug("Failed to commit transaction, likely due to autocommit being enabled") @@ -105,15 +214,76 @@ def _safe_commit(self, driver: "SyncDriverAdapterBase") -> None: class AsyncMigrationTracker(BaseMigrationTracker["AsyncDriverAdapterBase"]): """Asynchronous migration version tracker.""" + async def _migrate_schema_if_needed(self, driver: "AsyncDriverAdapterBase") -> None: + """Check for and add any missing columns to the tracking table. + + Uses the driver's data_dictionary to query existing columns, + then compares to the target schema and adds missing columns one by one. + + Args: + driver: The database driver to use. + """ + try: + columns_data = await driver.data_dictionary.get_columns(driver, self.version_table) + if not columns_data: + logger.debug("Migration tracking table does not exist yet") + return + + existing_columns = {col["column_name"] for col in columns_data} + missing_columns = self._detect_missing_columns(existing_columns) + + if not missing_columns: + logger.debug("Migration tracking table schema is up-to-date") + return + + from rich.console import Console + + console = Console() + console.print( + f"[cyan]Migrating tracking table schema, adding columns: {', '.join(sorted(missing_columns))}[/]" + ) + + for col_name in sorted(missing_columns): + await self._add_column(driver, col_name) + + await driver.commit() + console.print("[green]Migration tracking table schema updated successfully[/]") + + except Exception as e: + logger.warning("Could not check or migrate tracking table schema: %s", e) + + async def _add_column(self, driver: "AsyncDriverAdapterBase", column_name: str) -> None: + """Add a single column to the tracking table. + + Args: + driver: The database driver to use. + column_name: Name of the column to add (lowercase). + """ + target_create = self._get_create_table_sql() + column_def = next((col for col in target_create.columns if col.name.lower() == column_name), None) + + if not column_def: + return + + alter_sql = sql.alter_table(self.version_table).add_column( + name=column_def.name, dtype=column_def.dtype, default=column_def.default, not_null=column_def.not_null + ) + await driver.execute(alter_sql) + logger.debug("Added column %s to tracking table", column_name) + async def ensure_tracking_table(self, driver: "AsyncDriverAdapterBase") -> None: """Create the migration tracking table if it doesn't exist. + Also checks for and adds any missing columns to support schema migrations. + Args: driver: The database driver to use. """ await driver.execute(self._get_create_table_sql()) await self._safe_commit_async(driver) + await self._migrate_schema_if_needed(driver) + async def get_current_version(self, driver: "AsyncDriverAdapterBase") -> str | None: """Get the latest applied migration version. @@ -143,6 +313,9 @@ async def record_migration( ) -> None: """Record a successfully applied migration. + Parses version to determine type (sequential or timestamp) and + auto-increments execution_sequence for application order tracking. + Args: driver: The database driver to use. version: Version number of the migration. @@ -150,9 +323,21 @@ async def record_migration( execution_time_ms: Execution time in milliseconds. checksum: MD5 checksum of the migration content. 
""" + parsed_version = parse_version(version) + version_type = parsed_version.type.value + + result = await driver.execute(self._get_next_execution_sequence_sql()) + next_sequence = result.data[0]["next_seq"] if result.data else 1 + await driver.execute( self._get_record_migration_sql( - version, description, execution_time_ms, checksum, os.environ.get("USER", "unknown") + version, + version_type, + next_sequence, + description, + execution_time_ms, + checksum, + os.environ.get("USER", "unknown"), ) ) await self._safe_commit_async(driver) @@ -167,21 +352,52 @@ async def remove_migration(self, driver: "AsyncDriverAdapterBase", version: str) await driver.execute(self._get_remove_migration_sql(version)) await self._safe_commit_async(driver) - async def _safe_commit_async(self, driver: "AsyncDriverAdapterBase") -> None: - """Safely commit a transaction only if autocommit is disabled. + async def update_version_record(self, driver: "AsyncDriverAdapterBase", old_version: str, new_version: str) -> None: + """Update migration version record from timestamp to sequential. + + Updates version_num and version_type while preserving execution_sequence, + applied_at, and other tracking metadata. Used during fix command. + + Idempotent: If the version is already updated, logs and continues without error. + This allows fix command to be safely re-run after pulling changes. Args: driver: The database driver to use. + old_version: Current timestamp version string. + new_version: New sequential version string. + + Raises: + ValueError: If neither old_version nor new_version found in database. """ - try: - connection = getattr(driver, "connection", None) - if connection and hasattr(connection, "autocommit") and getattr(connection, "autocommit", False): - return + parsed_new_version = parse_version(new_version) + new_version_type = parsed_new_version.type.value + + result = await driver.execute(self._get_update_version_sql(old_version, new_version, new_version_type)) - driver_features = getattr(driver, "driver_features", {}) - if driver_features and driver_features.get("autocommit", False): + if result.rows_affected == 0: + check_result = await driver.execute(self._get_applied_migrations_sql()) + applied_versions = {row["version_num"] for row in check_result.data} if check_result.data else set() + + if new_version in applied_versions: + logger.debug("Version already updated: %s -> %s", old_version, new_version) return + msg = f"Migration version {old_version} not found in database" + raise ValueError(msg) + + await self._safe_commit_async(driver) + logger.debug("Updated version record: %s -> %s", old_version, new_version) + + async def _safe_commit_async(self, driver: "AsyncDriverAdapterBase") -> None: + """Safely commit a transaction only if autocommit is disabled. + + Args: + driver: The database driver to use. + """ + if driver.driver_features.get("autocommit", False): + return + + try: await driver.commit() except Exception: logger.debug("Failed to commit transaction, likely due to autocommit being enabled") diff --git a/sqlspec/migrations/validation.py b/sqlspec/migrations/validation.py new file mode 100644 index 00000000..2a34b6c9 --- /dev/null +++ b/sqlspec/migrations/validation.py @@ -0,0 +1,177 @@ +"""Migration validation and out-of-order detection for SQLSpec. + +This module provides functionality to detect and handle out-of-order migrations, +which can occur when branches with migrations merge in different orders across +staging and production environments. 
+""" + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from rich.console import Console + +from sqlspec.exceptions import OutOfOrderMigrationError +from sqlspec.utils.version import parse_version + +if TYPE_CHECKING: + from sqlspec.utils.version import MigrationVersion + +__all__ = ("MigrationGap", "detect_out_of_order_migrations", "format_out_of_order_warning") + +console = Console() + + +@dataclass(frozen=True) +class MigrationGap: + """Represents a migration that is out of order. + + An out-of-order migration occurs when a pending migration has a timestamp + earlier than already-applied migrations, indicating it was created in a branch + that merged after other migrations were already applied. + + Attributes: + missing_version: The out-of-order migration version. + applied_after: List of already-applied migrations with later timestamps. + """ + + missing_version: "MigrationVersion" + applied_after: "list[MigrationVersion]" + + +def detect_out_of_order_migrations( + pending_versions: "list[str]", applied_versions: "list[str]" +) -> "list[MigrationGap]": + """Detect migrations created before already-applied migrations. + + Identifies pending migrations with timestamps earlier than the latest applied + migration, which indicates they were created in branches that merged late or + were cherry-picked across environments. + + Extension migrations are excluded from out-of-order detection as they maintain + independent sequences within their own namespaces. + + Args: + pending_versions: List of migration versions not yet applied. + applied_versions: List of migration versions already applied. + + Returns: + List of migration gaps representing out-of-order migrations. + Empty list if no out-of-order migrations detected. + + Example: + Applied: [20251011120000, 20251012140000] + Pending: [20251011130000, 20251013090000] + Result: Gap for 20251011130000 (created between applied migrations) + + Applied: [ext_litestar_0001, 0001, 0002] + Pending: [ext_adk_0001] + Result: [] (extensions excluded from out-of-order detection) + """ + if not applied_versions or not pending_versions: + return [] + + gaps: list[MigrationGap] = [] + + parsed_applied = [parse_version(v) for v in applied_versions] + parsed_pending = [parse_version(v) for v in pending_versions] + + core_applied = [v for v in parsed_applied if v.extension is None] + core_pending = [v for v in parsed_pending if v.extension is None] + + if not core_applied or not core_pending: + return [] + + latest_applied = max(core_applied) + + for pending in core_pending: + if pending < latest_applied: + applied_after = [a for a in core_applied if a > pending] + if applied_after: + gaps.append(MigrationGap(missing_version=pending, applied_after=applied_after)) + + return gaps + + +def format_out_of_order_warning(gaps: "list[MigrationGap]") -> str: + """Create user-friendly warning message for out-of-order migrations. + + Formats migration gaps into a clear warning message explaining which migrations + are out of order and what migrations were already applied after them. + + Args: + gaps: List of migration gaps to format. + + Returns: + Formatted warning message string. 
+ + Example: + >>> gaps = [MigrationGap(version1, [version2, version3])] + >>> print(format_out_of_order_warning(gaps)) + Out-of-order migrations detected: + + - 20251011130000 created before: + - 20251012140000 + - 20251013090000 + """ + if not gaps: + return "" + + lines = ["Out-of-order migrations detected:", ""] + + for gap in gaps: + lines.append(f"- {gap.missing_version.raw} created before:") + lines.extend(f" - {applied.raw}" for applied in gap.applied_after) + lines.append("") + + lines.extend( + ( + "These migrations will be applied but may cause issues if they", + "depend on schema changes from later migrations.", + "", + "To prevent this in the future, ensure migrations are merged in", + "chronological order or use strict_ordering mode in migration_config.", + ) + ) + + return "\n".join(lines) + + +def validate_migration_order( + pending_versions: "list[str]", applied_versions: "list[str]", strict_ordering: bool = False +) -> None: + """Validate migration order and raise error if out-of-order in strict mode. + + Checks for out-of-order migrations and either warns or raises an error + depending on the strict_ordering configuration. + + Args: + pending_versions: List of migration versions not yet applied. + applied_versions: List of migration versions already applied. + strict_ordering: If True, raise error for out-of-order migrations. + If False (default), log warning but allow. + + Raises: + OutOfOrderMigrationError: If out-of-order migrations detected and + strict_ordering is True. + + Example: + >>> validate_migration_order( + ... ["20251011130000"], + ... ["20251012140000"], + ... strict_ordering=True, + ... ) + OutOfOrderMigrationError: Out-of-order migrations detected... + """ + gaps = detect_out_of_order_migrations(pending_versions, applied_versions) + + if not gaps: + return + + warning_message = format_out_of_order_warning(gaps) + + if strict_ordering: + msg = f"{warning_message}\n\nStrict ordering is enabled. Use --allow-missing to override." + raise OutOfOrderMigrationError(msg) + + console.print("[yellow]Out-of-order migrations detected[/]") + console.print(f"[yellow]{warning_message}[/]") diff --git a/sqlspec/utils/version.py b/sqlspec/utils/version.py new file mode 100644 index 00000000..6de021f5 --- /dev/null +++ b/sqlspec/utils/version.py @@ -0,0 +1,433 @@ +"""Migration version parsing and comparison utilities. + +Provides structured parsing of migration versions supporting both legacy sequential +(0001) and timestamp-based (20251011120000) formats with type-safe comparison. +""" + +import logging +import re +from dataclasses import dataclass +from datetime import datetime, timezone +from enum import Enum +from typing import Any + +__all__ = ( + "MigrationVersion", + "VersionType", + "convert_to_sequential_version", + "generate_conversion_map", + "generate_timestamp_version", + "get_next_sequential_number", + "is_sequential_version", + "is_timestamp_version", + "parse_version", +) + +logger = logging.getLogger(__name__) + +# Regex patterns for version detection +SEQUENTIAL_PATTERN = re.compile(r"^(?!\d{14}$)(\d+)$") +TIMESTAMP_PATTERN = re.compile(r"^(\d{14})$") +EXTENSION_PATTERN = re.compile(r"^ext_(\w+)_(.+)$") + + +class VersionType(Enum): + """Migration version format type.""" + + SEQUENTIAL = "sequential" + TIMESTAMP = "timestamp" + + +@dataclass(frozen=True) +class MigrationVersion: + """Parsed migration version with structured comparison support. + + Attributes: + raw: Original version string (e.g., "0001", "20251011120000", "ext_litestar_0001"). 
+        type: Version format type (sequential or timestamp).
+        sequence: Numeric value for sequential versions (e.g., 1, 2, 42).
+        timestamp: Parsed datetime for timestamp versions (UTC).
+        extension: Extension name for extension-prefixed versions (e.g., "litestar").
+    """
+
+    raw: str
+    type: VersionType
+    sequence: "int | None"
+    timestamp: "datetime | None"
+    extension: "str | None"
+
+    def __lt__(self, other: "MigrationVersion") -> bool:
+        """Compare versions supporting mixed formats.
+
+        Comparison Rules:
+        1. Extension migrations sort by extension name first, then version
+        2. Sequential < Timestamp (legacy migrations first)
+        3. Sequential vs Sequential: numeric comparison
+        4. Timestamp vs Timestamp: chronological comparison
+
+        Args:
+            other: Version to compare against.
+
+        Returns:
+            True if this version sorts before other.
+
+        Raises:
+            TypeError: If comparing against non-MigrationVersion.
+        """
+        if not isinstance(other, MigrationVersion):
+            return NotImplemented
+
+        if self.extension != other.extension:
+            if self.extension is None:
+                return True
+            if other.extension is None:
+                return False
+            return self.extension < other.extension
+
+        if self.type == other.type:
+            if self.type == VersionType.SEQUENTIAL:
+                return (self.sequence or 0) < (other.sequence or 0)
+            return (self.timestamp or datetime.min.replace(tzinfo=timezone.utc)) < (
+                other.timestamp or datetime.min.replace(tzinfo=timezone.utc)
+            )
+
+        return self.type == VersionType.SEQUENTIAL
+
+    def __le__(self, other: "MigrationVersion") -> bool:
+        """Check if version is less than or equal to another.
+
+        Args:
+            other: Version to compare against.
+
+        Returns:
+            True if this version is less than or equal to other.
+        """
+        return self == other or self < other
+
+    def __eq__(self, other: object) -> bool:
+        """Check version equality.
+
+        Args:
+            other: Version to compare against.
+
+        Returns:
+            True if versions are equal.
+        """
+        if not isinstance(other, MigrationVersion):
+            return NotImplemented
+        return self.raw == other.raw
+
+    def __hash__(self) -> int:
+        """Hash version for use in sets and dicts.
+
+        Returns:
+            Hash value based on raw version string.
+        """
+        return hash(self.raw)
+
+    def __repr__(self) -> str:
+        """Get string representation for debugging.
+
+        Returns:
+            String representation with type and value.
+        """
+        if self.extension:
+            return f"MigrationVersion(ext={self.extension}, {self.type.value}={self.raw})"
+        return f"MigrationVersion({self.type.value}={self.raw})"
+
+
+def is_sequential_version(version_str: str) -> bool:
+    """Check if version string is sequential format.
+
+    Sequential format: any run of digits (0001, 42, 9999, 10000+), excluding
+    14-digit strings, which are treated as timestamps.
+
+    Args:
+        version_str: Version string to check.
+
+    Returns:
+        True if sequential format.
+
+    Examples:
+        >>> is_sequential_version("0001")
+        True
+        >>> is_sequential_version("42")
+        True
+        >>> is_sequential_version("10000")
+        True
+        >>> is_sequential_version("20251011120000")
+        False
+    """
+    return bool(SEQUENTIAL_PATTERN.match(version_str))
+
+
+def is_timestamp_version(version_str: str) -> bool:
+    """Check if version string is timestamp format.
+
+    Timestamp format: 14-digit YYYYMMDDHHmmss (20251011120000).
+
+    Args:
+        version_str: Version string to check.
+
+    Returns:
+        True if timestamp format.
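+
+    Note:
+        The string must also parse as a valid UTC datetime; a 14-digit value
+        with an impossible date (e.g., "20251399999999") is rejected.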
+ + Examples: + >>> is_timestamp_version("20251011120000") + True + >>> is_timestamp_version("0001") + False + """ + if not TIMESTAMP_PATTERN.match(version_str): + return False + + try: + datetime.strptime(version_str, "%Y%m%d%H%M%S").replace(tzinfo=timezone.utc) + except ValueError: + return False + else: + return True + + +def parse_version(version_str: str) -> MigrationVersion: + """Parse version string into structured format. + + Supports: + - Sequential: "0001", "42", "9999" + - Timestamp: "20251011120000" + - Extension: "ext_litestar_0001", "ext_litestar_20251011120000" + + Args: + version_str: Version string to parse. + + Returns: + Parsed migration version. + + Raises: + ValueError: If version format is invalid. + + Examples: + >>> v = parse_version("0001") + >>> v.type == VersionType.SEQUENTIAL + True + >>> v.sequence + 1 + + >>> v = parse_version("20251011120000") + >>> v.type == VersionType.TIMESTAMP + True + + >>> v = parse_version("ext_litestar_0001") + >>> v.extension + 'litestar' + """ + extension_match = EXTENSION_PATTERN.match(version_str) + if extension_match: + extension_name = extension_match.group(1) + base_version = extension_match.group(2) + parsed = parse_version(base_version) + + return MigrationVersion( + raw=version_str, + type=parsed.type, + sequence=parsed.sequence, + timestamp=parsed.timestamp, + extension=extension_name, + ) + + if is_timestamp_version(version_str): + dt = datetime.strptime(version_str, "%Y%m%d%H%M%S").replace(tzinfo=timezone.utc) + return MigrationVersion( + raw=version_str, type=VersionType.TIMESTAMP, sequence=None, timestamp=dt, extension=None + ) + + if is_sequential_version(version_str): + return MigrationVersion( + raw=version_str, type=VersionType.SEQUENTIAL, sequence=int(version_str), timestamp=None, extension=None + ) + + msg = f"Invalid migration version format: {version_str}. Expected sequential (0001) or timestamp (YYYYMMDDHHmmss)." + raise ValueError(msg) + + +def generate_timestamp_version() -> str: + """Generate new timestamp version in UTC. + + Format: YYYYMMDDHHmmss (14 digits). + + Returns: + Timestamp version string. + + Examples: + >>> version = generate_timestamp_version() + >>> len(version) + 14 + >>> is_timestamp_version(version) + True + """ + return datetime.now(tz=timezone.utc).strftime("%Y%m%d%H%M%S") + + +def get_next_sequential_number(migrations: "list[MigrationVersion]", extension: "str | None" = None) -> int: + """Find highest sequential number and return next available. + + Scans migrations for sequential versions and returns the next number in sequence. + When extension is specified, only that extension's migrations are considered. + When extension is None, only core (non-extension) migrations are considered. + + Args: + migrations: List of parsed migration versions. + extension: Optional extension name to filter by (e.g., "litestar", "adk"). + None means core migrations only. + + Returns: + Next available sequential number (1 if no sequential migrations exist). 
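+
+    Note:
+        Core and extension migrations are numbered independently, so each
+        extension namespace keeps its own sequence.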
+ + Examples: + >>> v1 = parse_version("0001") + >>> v2 = parse_version("0002") + >>> get_next_sequential_number([v1, v2]) + 3 + + >>> get_next_sequential_number([]) + 1 + + >>> ext = parse_version("ext_litestar_0001") + >>> core = parse_version("0001") + >>> get_next_sequential_number([ext, core]) + 2 + + >>> ext1 = parse_version("ext_litestar_0001") + >>> get_next_sequential_number([ext1], extension="litestar") + 2 + """ + sequential = [ + m.sequence for m in migrations if m.type == VersionType.SEQUENTIAL and m.extension == extension and m.sequence + ] + + if not sequential: + return 1 + + return max(sequential) + 1 + + +def convert_to_sequential_version(timestamp_version: MigrationVersion, sequence_number: int) -> str: + """Convert timestamp MigrationVersion to sequential string format. + + Preserves extension prefixes during conversion. Format uses zero-padded + 4-digit numbers (0001, 0002, etc.). + + Args: + timestamp_version: Parsed timestamp version to convert. + sequence_number: Sequential number to assign. + + Returns: + Sequential version string with extension prefix if applicable. + + Raises: + ValueError: If input is not a timestamp version. + + Examples: + >>> v = parse_version("20251011120000") + >>> convert_to_sequential_version(v, 3) + '0003' + + >>> v = parse_version("ext_litestar_20251011120000") + >>> convert_to_sequential_version(v, 1) + 'ext_litestar_0001' + + >>> v = parse_version("0001") + >>> convert_to_sequential_version(v, 2) + Traceback (most recent call last): + ... + ValueError: Can only convert timestamp versions to sequential + """ + if timestamp_version.type != VersionType.TIMESTAMP: + msg = "Can only convert timestamp versions to sequential" + raise ValueError(msg) + + seq_str = str(sequence_number).zfill(4) + + if timestamp_version.extension: + return f"ext_{timestamp_version.extension}_{seq_str}" + + return seq_str + + +def generate_conversion_map(migrations: "list[tuple[str, Any]]") -> "dict[str, str]": + """Generate mapping from timestamp versions to sequential versions. + + Separates timestamp migrations from sequential, sorts timestamps chronologically, + and assigns sequential numbers starting after the highest existing sequential + number. Extension migrations maintain separate numbering within their namespace. + + Args: + migrations: List of tuples (version_string, migration_path). + + Returns: + Dictionary mapping old timestamp versions to new sequential versions. + + Examples: + >>> migrations = [ + ... ("0001", Path("0001_init.sql")), + ... ("0002", Path("0002_users.sql")), + ... ("20251011120000", Path("20251011120000_products.sql")), + ... ("20251012130000", Path("20251012130000_orders.sql")), + ... ] + >>> result = generate_conversion_map(migrations) + >>> result + {'20251011120000': '0003', '20251012130000': '0004'} + + >>> migrations = [ + ... ("20251011120000", Path("20251011120000_first.sql")), + ... ("20251010090000", Path("20251010090000_earlier.sql")), + ... 
] + >>> result = generate_conversion_map(migrations) + >>> result + {'20251010090000': '0001', '20251011120000': '0002'} + + >>> migrations = [] + >>> generate_conversion_map(migrations) + {} + """ + if not migrations: + return {} + + def _try_parse_version(version_str: str) -> "MigrationVersion | None": + """Parse version string, returning None for invalid versions.""" + try: + return parse_version(version_str) + except ValueError: + logger.warning("Skipping invalid migration version: %s", version_str) + return None + + parsed_versions = [v for version_str, _path in migrations if (v := _try_parse_version(version_str)) is not None] + + timestamp_migrations = sorted([v for v in parsed_versions if v.type == VersionType.TIMESTAMP]) + + if not timestamp_migrations: + return {} + + core_timestamps = [m for m in timestamp_migrations if m.extension is None] + ext_timestamps_by_name: dict[str, list[MigrationVersion]] = {} + for m in timestamp_migrations: + if m.extension: + ext_timestamps_by_name.setdefault(m.extension, []).append(m) + + conversion_map: dict[str, str] = {} + + if core_timestamps: + next_seq = get_next_sequential_number(parsed_versions) + for timestamp_version in core_timestamps: + sequential_version = convert_to_sequential_version(timestamp_version, next_seq) + conversion_map[timestamp_version.raw] = sequential_version + next_seq += 1 + + for ext_name, ext_migrations in ext_timestamps_by_name.items(): + ext_parsed = [v for v in parsed_versions if v.extension == ext_name] + next_seq = get_next_sequential_number(ext_parsed, extension=ext_name) + for timestamp_version in ext_migrations: + sequential_version = convert_to_sequential_version(timestamp_version, next_seq) + conversion_map[timestamp_version.raw] = sequential_version + next_seq += 1 + + return conversion_map diff --git a/tests/integration/test_adapters/test_asyncpg/test_schema_migration.py b/tests/integration/test_adapters/test_asyncpg/test_schema_migration.py new file mode 100644 index 00000000..933221ea --- /dev/null +++ b/tests/integration/test_adapters/test_asyncpg/test_schema_migration.py @@ -0,0 +1,321 @@ +"""Integration tests for migration tracking table schema migration with PostgreSQL.""" + +import pytest +from pytest_databases.docker.postgres import PostgresService + +from sqlspec.adapters.asyncpg import AsyncpgConfig +from sqlspec.migrations.tracker import AsyncMigrationTracker + + +def _create_config(postgres_service: PostgresService) -> AsyncpgConfig: + """Create AsyncpgConfig from PostgresService fixture.""" + return AsyncpgConfig( + pool_config={ + "host": postgres_service.host, + "port": postgres_service.port, + "user": postgres_service.user, + "password": postgres_service.password, + "database": postgres_service.database, + } + ) + + +@pytest.mark.asyncio +@pytest.mark.postgres +async def test_asyncpg_tracker_creates_full_schema(postgres_service: PostgresService) -> None: + """Test AsyncPG tracker creates complete schema with all columns.""" + config = _create_config(postgres_service) + tracker = AsyncMigrationTracker() + + try: + async with config.provide_session() as driver: + await tracker.ensure_tracking_table(driver) + + result = await driver.execute(f""" + SELECT column_name + FROM information_schema.columns + WHERE table_name = '{tracker.version_table}' + """) + + columns = {row["column_name"] for row in result.data or []} + + expected_columns = { + "version_num", + "version_type", + "execution_sequence", + "description", + "applied_at", + "execution_time_ms", + "checksum", + "applied_by", + } + + 
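+            # The tracker should create the complete hybrid-versioning schema in one pass.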
assert columns == expected_columns + finally: + await config.close_pool() + + +@pytest.mark.asyncio +@pytest.mark.postgres +async def test_asyncpg_tracker_migrates_legacy_schema(postgres_service: PostgresService) -> None: + """Test AsyncPG tracker adds missing columns to legacy schema.""" + config = _create_config(postgres_service) + tracker = AsyncMigrationTracker() + + try: + async with config.provide_session() as driver: + await driver.execute(f"DROP TABLE IF EXISTS {tracker.version_table}") + await driver.execute(f""" + CREATE TABLE {tracker.version_table} ( + version_num VARCHAR(32) PRIMARY KEY, + description TEXT, + applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL + ) + """) + await driver.commit() + + await tracker.ensure_tracking_table(driver) + + result = await driver.execute(f""" + SELECT column_name + FROM information_schema.columns + WHERE table_name = '{tracker.version_table}' + """) + + columns = {row["column_name"] for row in result.data or []} + + assert "version_type" in columns + assert "execution_sequence" in columns + assert "checksum" in columns + assert "execution_time_ms" in columns + assert "applied_by" in columns + finally: + await config.close_pool() + + +@pytest.mark.asyncio +@pytest.mark.postgres +async def test_asyncpg_tracker_migration_preserves_data(postgres_service: PostgresService) -> None: + """Test AsyncPG schema migration preserves existing records.""" + config = _create_config(postgres_service) + tracker = AsyncMigrationTracker() + + try: + async with config.provide_session() as driver: + await driver.execute(f"DROP TABLE IF EXISTS {tracker.version_table}") + await driver.execute(f""" + CREATE TABLE {tracker.version_table} ( + version_num VARCHAR(32) PRIMARY KEY, + description TEXT, + applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL + ) + """) + + await driver.execute(f""" + INSERT INTO {tracker.version_table} (version_num, description) + VALUES ('0001', 'Initial migration') + """) + await driver.commit() + + await tracker.ensure_tracking_table(driver) + + result = await driver.execute(f"SELECT * FROM {tracker.version_table}") + records = result.data or [] + + assert len(records) == 1 + assert records[0]["version_num"] == "0001" + assert records[0]["description"] == "Initial migration" + finally: + await config.close_pool() + + +@pytest.mark.asyncio +@pytest.mark.postgres +async def test_asyncpg_tracker_version_type_recording(postgres_service: PostgresService) -> None: + """Test AsyncPG tracker correctly records version_type.""" + config = _create_config(postgres_service) + tracker = AsyncMigrationTracker() + + try: + async with config.provide_session() as driver: + await driver.execute(f"DROP TABLE IF EXISTS {tracker.version_table}") + await tracker.ensure_tracking_table(driver) + + await tracker.record_migration(driver, "0001", "Sequential", 100, "checksum1") + await tracker.record_migration(driver, "20251011120000", "Timestamp", 150, "checksum2") + + result = await driver.execute(f""" + SELECT version_num, version_type + FROM {tracker.version_table} + ORDER BY execution_sequence + """) + records = result.data or [] + + assert len(records) == 2 + assert records[0]["version_num"] == "0001" + assert records[0]["version_type"] == "sequential" + assert records[1]["version_num"] == "20251011120000" + assert records[1]["version_type"] == "timestamp" + finally: + await config.close_pool() + + +@pytest.mark.asyncio +@pytest.mark.postgres +async def test_asyncpg_tracker_execution_sequence(postgres_service: PostgresService) -> None: + """Test 
AsyncPG tracker execution_sequence auto-increments.""" + config = _create_config(postgres_service) + tracker = AsyncMigrationTracker() + + try: + async with config.provide_session() as driver: + await driver.execute(f"DROP TABLE IF EXISTS {tracker.version_table}") + await tracker.ensure_tracking_table(driver) + + await tracker.record_migration(driver, "0001", "First", 100, "checksum1") + await tracker.record_migration(driver, "0003", "Out of order", 100, "checksum3") + await tracker.record_migration(driver, "0002", "Late merge", 100, "checksum2") + + result = await driver.execute(f""" + SELECT version_num, execution_sequence + FROM {tracker.version_table} + ORDER BY execution_sequence + """) + records = result.data or [] + + assert len(records) == 3 + assert records[0]["execution_sequence"] == 1 + assert records[1]["execution_sequence"] == 2 + assert records[2]["execution_sequence"] == 3 + + assert records[0]["version_num"] == "0001" + assert records[1]["version_num"] == "0003" + assert records[2]["version_num"] == "0002" + finally: + await config.close_pool() + + +@pytest.mark.asyncio +@pytest.mark.postgres +async def test_asyncpg_get_current_version_uses_execution_sequence(postgres_service: PostgresService) -> None: + """Test AsyncPG get_current_version uses execution order.""" + config = _create_config(postgres_service) + tracker = AsyncMigrationTracker() + + try: + async with config.provide_session() as driver: + await driver.execute(f"DROP TABLE IF EXISTS {tracker.version_table}") + await tracker.ensure_tracking_table(driver) + + await tracker.record_migration(driver, "0001", "First", 100, "checksum1") + await tracker.record_migration(driver, "0003", "Out of order", 100, "checksum3") + await tracker.record_migration(driver, "0002", "Late merge", 100, "checksum2") + + current = await tracker.get_current_version(driver) + + assert current == "0002" + finally: + await config.close_pool() + + +@pytest.mark.asyncio +@pytest.mark.postgres +async def test_asyncpg_update_version_record_preserves_metadata(postgres_service: PostgresService) -> None: + """Test AsyncPG update preserves execution_sequence and applied_at.""" + config = _create_config(postgres_service) + tracker = AsyncMigrationTracker() + + try: + async with config.provide_session() as driver: + await driver.execute(f"DROP TABLE IF EXISTS {tracker.version_table}") + await tracker.ensure_tracking_table(driver) + + await tracker.record_migration(driver, "20251011120000", "Migration", 100, "checksum1") + + result_before = await driver.execute(f""" + SELECT execution_sequence, applied_at + FROM {tracker.version_table} + WHERE version_num = '20251011120000' + """) + record_before = (result_before.data or [])[0] + + await tracker.update_version_record(driver, "20251011120000", "0001") + + result_after = await driver.execute(f""" + SELECT version_num, version_type, execution_sequence, applied_at + FROM {tracker.version_table} + WHERE version_num = '0001' + """) + record_after = (result_after.data or [])[0] + + assert record_after["version_num"] == "0001" + assert record_after["version_type"] == "sequential" + assert record_after["execution_sequence"] == record_before["execution_sequence"] + finally: + await config.close_pool() + + +@pytest.mark.asyncio +@pytest.mark.postgres +async def test_asyncpg_update_version_record_idempotent(postgres_service: PostgresService) -> None: + """Test AsyncPG update_version_record is idempotent.""" + config = _create_config(postgres_service) + tracker = AsyncMigrationTracker() + + try: + async with 
config.provide_session() as driver: + await driver.execute(f"DROP TABLE IF EXISTS {tracker.version_table}") + await tracker.ensure_tracking_table(driver) + + await tracker.record_migration(driver, "20251011120000", "Migration", 100, "checksum1") + + await tracker.update_version_record(driver, "20251011120000", "0001") + await tracker.update_version_record(driver, "20251011120000", "0001") + + result = await driver.execute(f"SELECT COUNT(*) as count FROM {tracker.version_table}") + count = (result.data or [])[0]["count"] + + assert count == 1 + finally: + await config.close_pool() + + +@pytest.mark.asyncio +@pytest.mark.postgres +async def test_asyncpg_migration_schema_is_idempotent(postgres_service: PostgresService) -> None: + """Test AsyncPG schema migration can be run multiple times.""" + config = _create_config(postgres_service) + tracker = AsyncMigrationTracker() + + try: + async with config.provide_session() as driver: + await driver.execute(f"DROP TABLE IF EXISTS {tracker.version_table}") + await driver.execute(f""" + CREATE TABLE {tracker.version_table} ( + version_num VARCHAR(32) PRIMARY KEY, + description TEXT + ) + """) + await driver.commit() + + await tracker.ensure_tracking_table(driver) + + result1 = await driver.execute(f""" + SELECT column_name + FROM information_schema.columns + WHERE table_name = '{tracker.version_table}' + """) + columns1 = {row["column_name"] for row in result1.data or []} + + await tracker.ensure_tracking_table(driver) + + result2 = await driver.execute(f""" + SELECT column_name + FROM information_schema.columns + WHERE table_name = '{tracker.version_table}' + """) + columns2 = {row["column_name"] for row in result2.data or []} + + assert columns1 == columns2 + finally: + await config.close_pool() diff --git a/tests/integration/test_adapters/test_oracledb/test_extensions/test_litestar/test_store_async.py b/tests/integration/test_adapters/test_oracledb/test_extensions/test_litestar/test_store_async.py index 4a7a3f6e..6282464e 100644 --- a/tests/integration/test_adapters/test_oracledb/test_extensions/test_litestar/test_store_async.py +++ b/tests/integration/test_adapters/test_oracledb/test_extensions/test_litestar/test_store_async.py @@ -85,11 +85,11 @@ async def test_store_delete_nonexistent(oracle_store: OracleAsyncStore) -> None: async def test_store_expiration_with_int(oracle_store: OracleAsyncStore) -> None: """Test session expiration with integer seconds.""" - await oracle_store.set("expiring_session", b"data", expires_in=1) + await oracle_store.set("expiring_session", b"data", expires_in=2) assert await oracle_store.exists("expiring_session") - await asyncio.sleep(1.1) + await asyncio.sleep(2.1) result = await oracle_store.get("expiring_session") assert result is None @@ -98,11 +98,11 @@ async def test_store_expiration_with_int(oracle_store: OracleAsyncStore) -> None async def test_store_expiration_with_timedelta(oracle_store: OracleAsyncStore) -> None: """Test session expiration with timedelta.""" - await oracle_store.set("expiring_session", b"data", expires_in=timedelta(seconds=1)) + await oracle_store.set("expiring_session", b"data", expires_in=timedelta(seconds=2)) assert await oracle_store.exists("expiring_session") - await asyncio.sleep(1.1) + await asyncio.sleep(2.1) result = await oracle_store.get("expiring_session") assert result is None diff --git a/tests/integration/test_adapters/test_oracledb/test_extensions/test_litestar/test_store_sync.py b/tests/integration/test_adapters/test_oracledb/test_extensions/test_litestar/test_store_sync.py 
index ee45a0c2..45a3cb7e 100644 --- a/tests/integration/test_adapters/test_oracledb/test_extensions/test_litestar/test_store_sync.py +++ b/tests/integration/test_adapters/test_oracledb/test_extensions/test_litestar/test_store_sync.py @@ -85,11 +85,11 @@ async def test_store_delete_nonexistent(oracle_sync_store: OracleSyncStore) -> N async def test_store_expiration_with_int(oracle_sync_store: OracleSyncStore) -> None: """Test session expiration with integer seconds.""" - await oracle_sync_store.set("expiring_session", b"data", expires_in=1) + await oracle_sync_store.set("expiring_session", b"data", expires_in=2) assert await oracle_sync_store.exists("expiring_session") - await asyncio.sleep(1.1) + await asyncio.sleep(2.1) result = await oracle_sync_store.get("expiring_session") assert result is None @@ -98,11 +98,11 @@ async def test_store_expiration_with_int(oracle_sync_store: OracleSyncStore) -> async def test_store_expiration_with_timedelta(oracle_sync_store: OracleSyncStore) -> None: """Test session expiration with timedelta.""" - await oracle_sync_store.set("expiring_session", b"data", expires_in=timedelta(seconds=1)) + await oracle_sync_store.set("expiring_session", b"data", expires_in=timedelta(seconds=2)) assert await oracle_sync_store.exists("expiring_session") - await asyncio.sleep(1.1) + await asyncio.sleep(2.1) result = await oracle_sync_store.get("expiring_session") assert result is None diff --git a/tests/integration/test_adapters/test_oracledb/test_migrations.py b/tests/integration/test_adapters/test_oracledb/test_migrations.py index 23f077f5..1a17b0a7 100644 --- a/tests/integration/test_adapters/test_oracledb/test_migrations.py +++ b/tests/integration/test_adapters/test_oracledb/test_migrations.py @@ -795,3 +795,150 @@ def down(): finally: if config.pool_instance: await config.close_pool() + + +async def test_oracledb_async_schema_migration_from_old_format(oracle_23ai_service: OracleService) -> None: + """Test automatic schema migration from old format (without execution_sequence) to new format. + + This simulates the scenario where a user has an existing database with the old schema + (missing version_type and execution_sequence columns) and runs `db upgrade`. 
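+
+    ensure_tracking_table is expected to add the missing columns in place while
+    preserving the existing migration records.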
+ """ + test_id = "oracledb_async_schema_migration" + migration_table = f"sqlspec_migrations_{test_id}" + + config = OracleAsyncConfig( + pool_config={ + "host": oracle_23ai_service.host, + "port": oracle_23ai_service.port, + "service_name": oracle_23ai_service.service_name, + "user": oracle_23ai_service.user, + "password": oracle_23ai_service.password, + "min": 1, + "max": 5, + }, + migration_config={"version_table_name": migration_table}, + ) + + try: + async with config.provide_session() as driver: + old_schema_sql = f""" + CREATE TABLE {migration_table} ( + version_num VARCHAR2(32) PRIMARY KEY, + description VARCHAR2(2000), + applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + execution_time_ms INTEGER, + checksum VARCHAR2(64), + applied_by VARCHAR2(255) + ) + """ + await driver.execute(old_schema_sql) + await driver.commit() + + insert_sql = f""" + INSERT INTO {migration_table} + (version_num, description, execution_time_ms, checksum, applied_by) + VALUES (:1, :2, :3, :4, :5) + """ + await driver.execute(insert_sql, ("0001", "test_migration", 100, "abc123", "testuser")) + await driver.commit() + + AsyncMigrationCommands(config) + from sqlspec.adapters.oracledb.migrations import OracleAsyncMigrationTracker + + tracker = OracleAsyncMigrationTracker(migration_table) + + async with config.provide_session() as driver: + await tracker.ensure_tracking_table(driver) + + column_check_sql = f""" + SELECT column_name + FROM user_tab_columns + WHERE table_name = '{migration_table.upper()}' + ORDER BY column_name + """ + result = await driver.execute(column_check_sql) + column_names = {row["COLUMN_NAME"] for row in result.data} + + assert "VERSION_TYPE" in column_names, "VERSION_TYPE column should be added" + assert "EXECUTION_SEQUENCE" in column_names, "EXECUTION_SEQUENCE column should be added" + assert "VERSION_NUM" in column_names + assert "DESCRIPTION" in column_names + + migration_data = await driver.execute(f"SELECT * FROM {migration_table}") + assert len(migration_data.data) == 1 + assert migration_data.data[0]["VERSION_NUM"] == "0001" + finally: + if config.pool_instance: + await config.close_pool() + + +def test_oracledb_sync_schema_migration_from_old_format(oracle_23ai_service: OracleService) -> None: + """Test automatic schema migration from old format (without execution_sequence) to new format (sync version). + + This simulates the scenario where a user has an existing database with the old schema + (missing version_type and execution_sequence columns) and runs `db upgrade`. 
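+
+    Mirrors the async variant above, exercising the synchronous driver API.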
+ """ + test_id = "oracledb_sync_schema_migration" + migration_table = f"sqlspec_migrations_{test_id}" + + config = OracleSyncConfig( + pool_config={ + "host": oracle_23ai_service.host, + "port": oracle_23ai_service.port, + "service_name": oracle_23ai_service.service_name, + "user": oracle_23ai_service.user, + "password": oracle_23ai_service.password, + }, + migration_config={"version_table_name": migration_table}, + ) + + try: + with config.provide_session() as driver: + old_schema_sql = f""" + CREATE TABLE {migration_table} ( + version_num VARCHAR2(32) PRIMARY KEY, + description VARCHAR2(2000), + applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + execution_time_ms INTEGER, + checksum VARCHAR2(64), + applied_by VARCHAR2(255) + ) + """ + driver.execute(old_schema_sql) + driver.commit() + + insert_sql = f""" + INSERT INTO {migration_table} + (version_num, description, execution_time_ms, checksum, applied_by) + VALUES (:1, :2, :3, :4, :5) + """ + driver.execute(insert_sql, ("0001", "test_migration", 100, "abc123", "testuser")) + driver.commit() + + from sqlspec.adapters.oracledb.migrations import OracleSyncMigrationTracker + + tracker = OracleSyncMigrationTracker(migration_table) + + with config.provide_session() as driver: + tracker.ensure_tracking_table(driver) + + column_check_sql = f""" + SELECT column_name + FROM user_tab_columns + WHERE table_name = '{migration_table.upper()}' + ORDER BY column_name + """ + result = driver.execute(column_check_sql) + column_names = {row["COLUMN_NAME"] for row in result.data} + + assert "VERSION_TYPE" in column_names, "VERSION_TYPE column should be added" + assert "EXECUTION_SEQUENCE" in column_names, "EXECUTION_SEQUENCE column should be added" + assert "VERSION_NUM" in column_names + assert "DESCRIPTION" in column_names + + migration_data = driver.execute(f"SELECT * FROM {migration_table}") + assert len(migration_data.data) == 1 + assert migration_data.data[0]["VERSION_NUM"] == "0001" + finally: + if config.pool_instance: + config.close_pool() diff --git a/tests/integration/test_migrations/test_auto_sync.py b/tests/integration/test_migrations/test_auto_sync.py new file mode 100644 index 00000000..e2b1f033 --- /dev/null +++ b/tests/integration/test_migrations/test_auto_sync.py @@ -0,0 +1,333 @@ +"""Integration tests for auto-sync functionality in upgrade command.""" + +from collections.abc import Generator +from pathlib import Path + +import pytest + +from sqlspec.adapters.sqlite import SqliteConfig +from sqlspec.migrations.commands import SyncMigrationCommands +from sqlspec.migrations.fix import MigrationFixer +from sqlspec.utils.version import generate_conversion_map + + +@pytest.fixture +def sqlite_config(tmp_path: Path) -> Generator[SqliteConfig, None, None]: + """Create SQLite config with migrations directory.""" + migrations_dir = tmp_path / "migrations" + migrations_dir.mkdir() + + config = SqliteConfig( + pool_config={"database": ":memory:"}, + migration_config={ + "script_location": str(migrations_dir), + "version_table_name": "ddl_migrations", + "auto_sync": True, + }, + ) + yield config + config.close_pool() + + +@pytest.fixture +def migrations_dir(tmp_path: Path) -> Path: + """Get migrations directory.""" + return tmp_path / "migrations" + + +def test_auto_sync_reconciles_renamed_migrations(sqlite_config: SqliteConfig, migrations_dir: Path) -> None: + """Test auto-sync automatically reconciles renamed migrations during upgrade.""" + migrations = [ + ("20251011120000_create_users.sql", "20251011120000", "CREATE TABLE users (id INTEGER PRIMARY 
KEY);"), + ("20251012130000_create_products.sql", "20251012130000", "CREATE TABLE products (id INTEGER PRIMARY KEY);"), + ] + + for filename, version, sql in migrations: + content = f"""-- name: migrate-{version}-up +{sql} + +-- name: migrate-{version}-down +DROP TABLE {filename.split("_")[1]}; +""" + (migrations_dir / filename).write_text(content) + + commands = SyncMigrationCommands(sqlite_config) + + commands.upgrade() + + with sqlite_config.provide_session() as session: + applied_before = commands.tracker.get_applied_migrations(session) + + assert len(applied_before) == 2 + assert applied_before[0]["version_num"] == "20251011120000" + assert applied_before[1]["version_num"] == "20251012130000" + + fixer = MigrationFixer(migrations_dir) + all_files = [(v, p) for v, p in commands.runner.get_migration_files()] + conversion_map = generate_conversion_map(all_files) + renames = fixer.plan_renames(conversion_map) + fixer.apply_renames(renames) + + commands_after_rename = SyncMigrationCommands(sqlite_config) + + commands_after_rename.upgrade() + + with sqlite_config.provide_session() as session: + applied_after = commands_after_rename.tracker.get_applied_migrations(session) + + assert len(applied_after) == 2 + assert applied_after[0]["version_num"] == "0001" + assert applied_after[1]["version_num"] == "0002" + + +def test_auto_sync_validates_checksums(sqlite_config: SqliteConfig, migrations_dir: Path) -> None: + """Test auto-sync validates checksums before reconciling.""" + content = """-- name: migrate-20251011120000-up +CREATE TABLE users (id INTEGER PRIMARY KEY); + +-- name: migrate-20251011120000-down +DROP TABLE users; +""" + (migrations_dir / "20251011120000_create_users.sql").write_text(content) + + commands = SyncMigrationCommands(sqlite_config) + commands.upgrade() + + with sqlite_config.provide_session() as session: + applied = commands.tracker.get_applied_migrations(session) + original_checksum = applied[0]["checksum"] + + (migrations_dir / "20251011120000_create_users.sql").unlink() + + modified_content = """-- name: migrate-0001-up +CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT); + +-- name: migrate-0001-down +DROP TABLE users; +""" + (migrations_dir / "0001_create_users.sql").write_text(modified_content) + + commands_after = SyncMigrationCommands(sqlite_config) + + commands_after.upgrade() + + with sqlite_config.provide_session() as session: + applied_after = commands_after.tracker.get_applied_migrations(session) + + assert applied_after[0]["version_num"] == "20251011120000" + assert applied_after[0]["checksum"] == original_checksum + + +def test_auto_sync_disabled_via_config(sqlite_config: SqliteConfig, migrations_dir: Path) -> None: + """Test auto-sync can be disabled via migration config.""" + sqlite_config.migration_config["auto_sync"] = False + + migrations = [("20251011120000_create_users.sql", "20251011120000", "CREATE TABLE users (id INTEGER PRIMARY KEY);")] + + for filename, version, sql in migrations: + content = f"""-- name: migrate-{version}-up +{sql} + +-- name: migrate-{version}-down +DROP TABLE users; +""" + (migrations_dir / filename).write_text(content) + + commands = SyncMigrationCommands(sqlite_config) + commands.upgrade() + + with sqlite_config.provide_session() as session: + applied_before = commands.tracker.get_applied_migrations(session) + + assert applied_before[0]["version_num"] == "20251011120000" + + fixer = MigrationFixer(migrations_dir) + all_files = [(v, p) for v, p in commands.runner.get_migration_files()] + conversion_map = 
generate_conversion_map(all_files) + renames = fixer.plan_renames(conversion_map) + fixer.apply_renames(renames) + + commands_after = SyncMigrationCommands(sqlite_config) + + commands_after.upgrade() + + with sqlite_config.provide_session() as session: + applied_after = commands_after.tracker.get_applied_migrations(session) + + assert applied_after[0]["version_num"] == "20251011120000" + + +def test_auto_sync_disabled_via_flag(sqlite_config: SqliteConfig, migrations_dir: Path) -> None: + """Test auto-sync can be disabled via upgrade flag.""" + migrations = [("20251011120000_create_users.sql", "20251011120000", "CREATE TABLE users (id INTEGER PRIMARY KEY);")] + + for filename, version, sql in migrations: + content = f"""-- name: migrate-{version}-up +{sql} + +-- name: migrate-{version}-down +DROP TABLE users; +""" + (migrations_dir / filename).write_text(content) + + commands = SyncMigrationCommands(sqlite_config) + commands.upgrade() + + with sqlite_config.provide_session() as session: + applied_before = commands.tracker.get_applied_migrations(session) + + assert applied_before[0]["version_num"] == "20251011120000" + + fixer = MigrationFixer(migrations_dir) + all_files = [(v, p) for v, p in commands.runner.get_migration_files()] + conversion_map = generate_conversion_map(all_files) + renames = fixer.plan_renames(conversion_map) + fixer.apply_renames(renames) + + commands_after = SyncMigrationCommands(sqlite_config) + + commands_after.upgrade(auto_sync=False) + + with sqlite_config.provide_session() as session: + applied_after = commands_after.tracker.get_applied_migrations(session) + + assert applied_after[0]["version_num"] == "20251011120000" + + +def test_auto_sync_handles_multiple_migrations(sqlite_config: SqliteConfig, migrations_dir: Path) -> None: + """Test auto-sync handles multiple migrations being renamed.""" + migrations = [ + ("0001_init.sql", "0001", "CREATE TABLE init (id INTEGER PRIMARY KEY);"), + ("20251011120000_create_users.sql", "20251011120000", "CREATE TABLE users (id INTEGER PRIMARY KEY);"), + ("20251012130000_create_products.sql", "20251012130000", "CREATE TABLE products (id INTEGER PRIMARY KEY);"), + ("20251013140000_create_orders.sql", "20251013140000", "CREATE TABLE orders (id INTEGER PRIMARY KEY);"), + ] + + for filename, version, sql in migrations: + name_without_ext = filename.rsplit(".", 1)[0] + parts = name_without_ext.split("_", 1) + table_name = parts[1].replace("create_", "") if len(parts) > 1 else name_without_ext + content = f"""-- name: migrate-{version}-up +{sql} + +-- name: migrate-{version}-down +DROP TABLE {table_name}; +""" + (migrations_dir / filename).write_text(content) + + commands = SyncMigrationCommands(sqlite_config) + commands.upgrade() + + with sqlite_config.provide_session() as session: + applied_before = commands.tracker.get_applied_migrations(session) + + assert len(applied_before) == 4 + timestamp_versions = [m["version_num"] for m in applied_before if m["version_type"] == "timestamp"] + assert len(timestamp_versions) == 3 + + fixer = MigrationFixer(migrations_dir) + all_files = [(v, p) for v, p in commands.runner.get_migration_files()] + conversion_map = generate_conversion_map(all_files) + renames = fixer.plan_renames(conversion_map) + fixer.apply_renames(renames) + + commands_after = SyncMigrationCommands(sqlite_config) + + commands_after.upgrade() + + with sqlite_config.provide_session() as session: + applied_after = commands_after.tracker.get_applied_migrations(session) + + assert len(applied_after) == 4 + + expected_versions = 
{"0001", "0002", "0003", "0004"} + actual_versions = {m["version_num"] for m in applied_after} + assert actual_versions == expected_versions + + +def test_auto_sync_preserves_execution_sequence(sqlite_config: SqliteConfig, migrations_dir: Path) -> None: + """Test auto-sync preserves original execution sequence.""" + migrations = [ + ("0001_init.sql", "0001", "CREATE TABLE init (id INTEGER PRIMARY KEY);"), + ("20251011120000_create_users.sql", "20251011120000", "CREATE TABLE users (id INTEGER PRIMARY KEY);"), + ("20251012130000_create_products.sql", "20251012130000", "CREATE TABLE products (id INTEGER PRIMARY KEY);"), + ] + + for filename, version, sql in migrations: + name_without_ext = filename.rsplit(".", 1)[0] + parts = name_without_ext.split("_", 1) + table_name = parts[1].replace("create_", "") if len(parts) > 1 else name_without_ext + content = f"""-- name: migrate-{version}-up +{sql} + +-- name: migrate-{version}-down +DROP TABLE {table_name}; +""" + (migrations_dir / filename).write_text(content) + + commands = SyncMigrationCommands(sqlite_config) + commands.upgrade() + + with sqlite_config.provide_session() as session: + applied_before = commands.tracker.get_applied_migrations(session) + + original_sequences = {m["version_num"]: m["execution_sequence"] for m in applied_before} + + fixer = MigrationFixer(migrations_dir) + all_files = [(v, p) for v, p in commands.runner.get_migration_files()] + conversion_map = generate_conversion_map(all_files) + renames = fixer.plan_renames(conversion_map) + fixer.apply_renames(renames) + + commands_after = SyncMigrationCommands(sqlite_config) + + commands_after.upgrade() + + with sqlite_config.provide_session() as session: + applied_after = commands_after.tracker.get_applied_migrations(session) + + assert applied_after[0]["execution_sequence"] == original_sequences["0001"] + assert applied_after[1]["execution_sequence"] == original_sequences["20251011120000"] + assert applied_after[2]["execution_sequence"] == original_sequences["20251012130000"] + + +def test_auto_sync_with_new_migrations_after_rename(sqlite_config: SqliteConfig, migrations_dir: Path) -> None: + """Test auto-sync works when adding new migrations after rename.""" + migrations = [("20251011120000_create_users.sql", "20251011120000", "CREATE TABLE users (id INTEGER PRIMARY KEY);")] + + for filename, version, sql in migrations: + content = f"""-- name: migrate-{version}-up +{sql} + +-- name: migrate-{version}-down +DROP TABLE users; +""" + (migrations_dir / filename).write_text(content) + + commands = SyncMigrationCommands(sqlite_config) + commands.upgrade() + + fixer = MigrationFixer(migrations_dir) + all_files = [(v, p) for v, p in commands.runner.get_migration_files()] + conversion_map = generate_conversion_map(all_files) + renames = fixer.plan_renames(conversion_map) + fixer.apply_renames(renames) + + new_migration = """-- name: migrate-0002-up +CREATE TABLE products (id INTEGER PRIMARY KEY); + +-- name: migrate-0002-down +DROP TABLE products; +""" + (migrations_dir / "0002_create_products.sql").write_text(new_migration) + + commands_after = SyncMigrationCommands(sqlite_config) + + commands_after.upgrade() + + with sqlite_config.provide_session() as session: + applied = commands_after.tracker.get_applied_migrations(session) + + assert len(applied) == 2 + assert applied[0]["version_num"] == "0001" + assert applied[1]["version_num"] == "0002" diff --git a/tests/integration/test_migrations/test_fix_checksum_stability.py 
b/tests/integration/test_migrations/test_fix_checksum_stability.py new file mode 100644 index 00000000..00ce4a81 --- /dev/null +++ b/tests/integration/test_migrations/test_fix_checksum_stability.py @@ -0,0 +1,203 @@ +"""Integration tests for checksum stability during fix operations.""" + +from collections.abc import Generator +from pathlib import Path + +import pytest + +from sqlspec.adapters.sqlite import SqliteConfig, SqliteDriver +from sqlspec.migrations.fix import MigrationFixer +from sqlspec.migrations.runner import SyncMigrationRunner +from sqlspec.migrations.tracker import SyncMigrationTracker +from sqlspec.utils.version import generate_conversion_map + + +@pytest.fixture +def sqlite_config() -> Generator[SqliteConfig, None, None]: + """Create SQLite config for migration testing.""" + config = SqliteConfig(pool_config={"database": ":memory:"}) + yield config + config.close_pool() + + +@pytest.fixture +def sqlite_session(sqlite_config: SqliteConfig) -> Generator[SqliteDriver, None, None]: + """Create SQLite session for migration testing.""" + with sqlite_config.provide_session() as session: + yield session + + +@pytest.fixture +def migrations_dir(tmp_path: Path) -> Path: + """Create temporary migrations directory.""" + migrations_dir = tmp_path / "migrations" + migrations_dir.mkdir() + return migrations_dir + + +def test_checksum_stable_after_fix_sql_migration( + migrations_dir: Path, sqlite_session: SqliteDriver, sqlite_config: SqliteConfig +) -> None: + """Test checksum remains stable when converting SQL migration from timestamp to sequential.""" + sql_content = """-- name: migrate-20251011120000-up +CREATE TABLE users ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + email TEXT UNIQUE +); + +-- name: migrate-20251011120000-down +DROP TABLE users; +""" + + sql_file = migrations_dir / "20251011120000_create_users.sql" + sql_file.write_text(sql_content) + + tracker = SyncMigrationTracker() + tracker.ensure_tracking_table(sqlite_session) + + runner = SyncMigrationRunner(migrations_dir) + migration_files = runner.get_migration_files() + migration = runner.load_migration(migration_files[0][1], version=migration_files[0][0]) + + runner.execute_upgrade(sqlite_session, migration) + tracker.record_migration(sqlite_session, migration["version"], migration["description"], 100, migration["checksum"]) + + applied = tracker.get_applied_migrations(sqlite_session) + original_checksum = applied[0]["checksum"] + + fixer = MigrationFixer(migrations_dir) + conversion_map = generate_conversion_map(migration_files) + renames = fixer.plan_renames(conversion_map) + fixer.apply_renames(renames) + + runner_after = SyncMigrationRunner(migrations_dir) + migration_files_after = runner_after.get_migration_files() + migration_after = runner_after.load_migration(migration_files_after[0][1], version=migration_files_after[0][0]) + + new_checksum = migration_after["checksum"] + + assert new_checksum == original_checksum + + +def test_multiple_migrations_checksums_stable( + migrations_dir: Path, sqlite_session: SqliteDriver, sqlite_config: SqliteConfig +) -> None: + """Test all migration checksums remain stable during batch conversion.""" + migrations = [ + ( + "20251011120000_create_users.sql", + """-- name: migrate-20251011120000-up +CREATE TABLE users (id INTEGER PRIMARY KEY); + +-- name: migrate-20251011120000-down +DROP TABLE users; +""", + ), + ( + "20251012130000_create_products.sql", + """-- name: migrate-20251012130000-up +CREATE TABLE products (id INTEGER PRIMARY KEY); + +-- name: 
migrate-20251012130000-down +DROP TABLE products; +""", + ), + ( + "20251013140000_create_orders.sql", + """-- name: migrate-20251013140000-up +CREATE TABLE orders (id INTEGER PRIMARY KEY); + +-- name: migrate-20251013140000-down +DROP TABLE orders; +""", + ), + ] + + for filename, content in migrations: + (migrations_dir / filename).write_text(content) + + tracker = SyncMigrationTracker() + tracker.ensure_tracking_table(sqlite_session) + + runner = SyncMigrationRunner(migrations_dir) + migration_files = runner.get_migration_files() + + original_checksums = {} + for version, file_path in migration_files: + migration = runner.load_migration(file_path, version=version) + runner.execute_upgrade(sqlite_session, migration) + tracker.record_migration( + sqlite_session, migration["version"], migration["description"], 100, migration["checksum"] + ) + original_checksums[version] = migration["checksum"] + + fixer = MigrationFixer(migrations_dir) + conversion_map = generate_conversion_map(migration_files) + renames = fixer.plan_renames(conversion_map) + fixer.apply_renames(renames) + + runner_after = SyncMigrationRunner(migrations_dir) + migration_files_after = runner_after.get_migration_files() + + for version, file_path in migration_files_after: + migration = runner_after.load_migration(file_path, version=version) + new_checksum = migration["checksum"] + + old_version = next(k for k, v in conversion_map.items() if v == version) + original_checksum = original_checksums[old_version] + + assert new_checksum == original_checksum + + +def test_checksum_stability_with_complex_sql( + migrations_dir: Path, sqlite_session: SqliteDriver, sqlite_config: SqliteConfig +) -> None: + """Test checksum stability with complex SQL containing version references.""" + sql_content = """-- name: migrate-20251011120000-up +-- This migration creates users table +-- Previous migration: migrate-20251010110000-up +CREATE TABLE users ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL CHECK (name != 'migrate-20251011120000-up'), + metadata TEXT DEFAULT '-- name: some-pattern-up' +); + +-- Comment about migrate-20251011120000-up +INSERT INTO users (name) VALUES ('test migrate-20251011120000-up reference'); + +-- name: migrate-20251011120000-down +DROP TABLE users; +""" + + sql_file = migrations_dir / "20251011120000_create_users.sql" + sql_file.write_text(sql_content) + + tracker = SyncMigrationTracker() + tracker.ensure_tracking_table(sqlite_session) + + runner = SyncMigrationRunner(migrations_dir) + migration_files = runner.get_migration_files() + migration = runner.load_migration(migration_files[0][1], version=migration_files[0][0]) + + original_checksum = migration["checksum"] + + fixer = MigrationFixer(migrations_dir) + conversion_map = generate_conversion_map(migration_files) + renames = fixer.plan_renames(conversion_map) + fixer.apply_renames(renames) + + runner_after = SyncMigrationRunner(migrations_dir) + migration_files_after = runner_after.get_migration_files() + migration_after = runner_after.load_migration(migration_files_after[0][1], version=migration_files_after[0][0]) + + new_checksum = migration_after["checksum"] + + assert new_checksum == original_checksum + + converted_content = (migrations_dir / "0001_create_users.sql").read_text() + assert "-- name: migrate-0001-up" in converted_content + assert "-- name: migrate-0001-down" in converted_content + assert "migrate-20251010110000-up" in converted_content + assert "CHECK (name != 'migrate-20251011120000-up')" in converted_content + assert "metadata TEXT DEFAULT 
'-- name: some-pattern-up'" in converted_content diff --git a/tests/integration/test_migrations/test_fix_file_operations.py b/tests/integration/test_migrations/test_fix_file_operations.py new file mode 100644 index 00000000..275069b6 --- /dev/null +++ b/tests/integration/test_migrations/test_fix_file_operations.py @@ -0,0 +1,376 @@ +"""Integration tests for migration fix file operations.""" + +from pathlib import Path + +import pytest + +from sqlspec.migrations.fix import MigrationFixer, MigrationRename + + +@pytest.fixture +def temp_migrations_dir(tmp_path: Path) -> Path: + """Create temporary migrations directory.""" + migrations_dir = tmp_path / "migrations" + migrations_dir.mkdir() + return migrations_dir + + +@pytest.fixture +def sample_sql_migration(temp_migrations_dir: Path) -> Path: + """Create sample SQL migration file.""" + content = """-- name: migrate-20251011120000-up +CREATE TABLE users ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL +); + +-- name: migrate-20251011120000-down +DROP TABLE users; +""" + file_path = temp_migrations_dir / "20251011120000_create_users.sql" + file_path.write_text(content) + return file_path + + +@pytest.fixture +def sample_py_migration(temp_migrations_dir: Path) -> Path: + """Create sample Python migration file.""" + content = '''"""Create products table migration.""" + +async def up(driver): + """Apply migration.""" + await driver.execute("CREATE TABLE products (id INTEGER PRIMARY KEY)") + +async def down(driver): + """Revert migration.""" + await driver.execute("DROP TABLE products") +''' + file_path = temp_migrations_dir / "20251012130000_create_products.py" + file_path.write_text(content) + return file_path + + +def test_plan_renames_empty_map(temp_migrations_dir: Path) -> None: + """Test planning renames with empty conversion map.""" + fixer = MigrationFixer(temp_migrations_dir) + renames = fixer.plan_renames({}) + assert renames == [] + + +def test_plan_renames_no_matching_files(temp_migrations_dir: Path) -> None: + """Test planning renames when no files match.""" + fixer = MigrationFixer(temp_migrations_dir) + conversion_map = {"20251011120000": "0001"} + renames = fixer.plan_renames(conversion_map) + assert renames == [] + + +def test_plan_renames_single_sql_file(temp_migrations_dir: Path, sample_sql_migration: Path) -> None: + """Test planning rename for single SQL migration.""" + fixer = MigrationFixer(temp_migrations_dir) + conversion_map = {"20251011120000": "0001"} + + renames = fixer.plan_renames(conversion_map) + + assert len(renames) == 1 + rename = renames[0] + assert rename.old_version == "20251011120000" + assert rename.new_version == "0001" + assert rename.old_path == sample_sql_migration + assert rename.new_path == temp_migrations_dir / "0001_create_users.sql" + assert rename.needs_content_update is True + + +def test_plan_renames_single_py_file(temp_migrations_dir: Path, sample_py_migration: Path) -> None: + """Test planning rename for single Python migration.""" + fixer = MigrationFixer(temp_migrations_dir) + conversion_map = {"20251012130000": "0001"} + + renames = fixer.plan_renames(conversion_map) + + assert len(renames) == 1 + rename = renames[0] + assert rename.old_version == "20251012130000" + assert rename.new_version == "0001" + assert rename.old_path == sample_py_migration + assert rename.new_path == temp_migrations_dir / "0001_create_products.py" + assert rename.needs_content_update is False + + +def test_plan_renames_multiple_files( + temp_migrations_dir: Path, sample_sql_migration: Path, sample_py_migration: 
Path +) -> None: + """Test planning renames for multiple migrations.""" + fixer = MigrationFixer(temp_migrations_dir) + conversion_map = {"20251011120000": "0001", "20251012130000": "0002"} + + renames = fixer.plan_renames(conversion_map) + + assert len(renames) == 2 + + sql_rename = next(r for r in renames if r.old_path == sample_sql_migration) + assert sql_rename.new_version == "0001" + assert sql_rename.needs_content_update is True + + py_rename = next(r for r in renames if r.old_path == sample_py_migration) + assert py_rename.new_version == "0002" + assert py_rename.needs_content_update is False + + +def test_plan_renames_detects_collision(temp_migrations_dir: Path, sample_sql_migration: Path) -> None: + """Test planning renames detects target file collision.""" + existing_target = temp_migrations_dir / "0001_create_users.sql" + existing_target.write_text("EXISTING FILE") + + fixer = MigrationFixer(temp_migrations_dir) + conversion_map = {"20251011120000": "0001"} + + with pytest.raises(ValueError, match="Target file already exists"): + fixer.plan_renames(conversion_map) + + +def test_create_backup(temp_migrations_dir: Path, sample_sql_migration: Path, sample_py_migration: Path) -> None: + """Test backup creation.""" + fixer = MigrationFixer(temp_migrations_dir) + + backup_path = fixer.create_backup() + + assert backup_path.exists() + assert backup_path.is_dir() + assert backup_path.name.startswith(".backup_") + + backed_up_files = list(backup_path.iterdir()) + assert len(backed_up_files) == 2 + + backup_names = {f.name for f in backed_up_files} + assert "20251011120000_create_users.sql" in backup_names + assert "20251012130000_create_products.py" in backup_names + + +def test_create_backup_only_copies_files(temp_migrations_dir: Path, sample_sql_migration: Path) -> None: + """Test backup only copies files, not subdirectories.""" + subdir = temp_migrations_dir / "subdir" + subdir.mkdir() + (subdir / "file.txt").write_text("test") + + fixer = MigrationFixer(temp_migrations_dir) + backup_path = fixer.create_backup() + + backed_up_files = list(backup_path.iterdir()) + assert len(backed_up_files) == 1 + assert backed_up_files[0].name == "20251011120000_create_users.sql" + + +def test_create_backup_ignores_hidden_files(temp_migrations_dir: Path) -> None: + """Test backup ignores hidden files.""" + hidden_file = temp_migrations_dir / ".hidden" + hidden_file.write_text("hidden content") + + visible_file = temp_migrations_dir / "20251011120000_create_users.sql" + visible_file.write_text("visible content") + + fixer = MigrationFixer(temp_migrations_dir) + backup_path = fixer.create_backup() + + backed_up_files = list(backup_path.iterdir()) + assert len(backed_up_files) == 1 + assert backed_up_files[0].name == "20251011120000_create_users.sql" + + +def test_apply_renames_dry_run(temp_migrations_dir: Path, sample_sql_migration: Path) -> None: + """Test dry-run mode doesn't modify files.""" + original_content = sample_sql_migration.read_text() + + fixer = MigrationFixer(temp_migrations_dir) + renames = [ + MigrationRename( + old_path=sample_sql_migration, + new_path=temp_migrations_dir / "0001_create_users.sql", + old_version="20251011120000", + new_version="0001", + needs_content_update=True, + ) + ] + + fixer.apply_renames(renames, dry_run=True) + + assert sample_sql_migration.exists() + assert not (temp_migrations_dir / "0001_create_users.sql").exists() + assert sample_sql_migration.read_text() == original_content + + +def test_apply_renames_actual(temp_migrations_dir: Path, 
sample_sql_migration: Path) -> None: + """Test actual rename execution.""" + fixer = MigrationFixer(temp_migrations_dir) + new_path = temp_migrations_dir / "0001_create_users.sql" + + renames = [ + MigrationRename( + old_path=sample_sql_migration, + new_path=new_path, + old_version="20251011120000", + new_version="0001", + needs_content_update=True, + ) + ] + + fixer.apply_renames(renames) + + assert not sample_sql_migration.exists() + assert new_path.exists() + + +def test_apply_renames_empty_list(temp_migrations_dir: Path) -> None: + """Test applying empty renames list.""" + fixer = MigrationFixer(temp_migrations_dir) + fixer.apply_renames([]) + + +def test_update_file_content_sql_up_and_down(temp_migrations_dir: Path, sample_sql_migration: Path) -> None: + """Test updating SQL file content updates query names.""" + fixer = MigrationFixer(temp_migrations_dir) + + fixer.update_file_content(sample_sql_migration, "20251011120000", "0001") + + updated_content = sample_sql_migration.read_text() + assert "-- name: migrate-0001-up" in updated_content + assert "-- name: migrate-0001-down" in updated_content + assert "migrate-20251011120000-up" not in updated_content + assert "migrate-20251011120000-down" not in updated_content + + +def test_update_file_content_preserves_sql_statements(temp_migrations_dir: Path, sample_sql_migration: Path) -> None: + """Test updating content preserves actual SQL statements.""" + fixer = MigrationFixer(temp_migrations_dir) + + fixer.update_file_content(sample_sql_migration, "20251011120000", "0001") + + updated_content = sample_sql_migration.read_text() + assert "CREATE TABLE users" in updated_content + assert "DROP TABLE users" in updated_content + + +def test_update_file_content_no_query_names(temp_migrations_dir: Path) -> None: + """Test updating file without query names is a no-op.""" + file_path = temp_migrations_dir / "20251011120000_simple.sql" + original_content = "CREATE TABLE test (id INTEGER);" + file_path.write_text(original_content) + + fixer = MigrationFixer(temp_migrations_dir) + fixer.update_file_content(file_path, "20251011120000", "0001") + + assert file_path.read_text() == original_content + + +def test_rollback_restores_files(temp_migrations_dir: Path, sample_sql_migration: Path) -> None: + """Test rollback restores files from backup.""" + original_content = sample_sql_migration.read_text() + + fixer = MigrationFixer(temp_migrations_dir) + fixer.create_backup() + + sample_sql_migration.unlink() + modified_file = temp_migrations_dir / "0001_create_users.sql" + modified_file.write_text("MODIFIED CONTENT") + + fixer.rollback() + + assert sample_sql_migration.exists() + assert sample_sql_migration.read_text() == original_content + assert not modified_file.exists() + + +def test_rollback_without_backup(temp_migrations_dir: Path) -> None: + """Test rollback without backup is a no-op.""" + fixer = MigrationFixer(temp_migrations_dir) + fixer.rollback() + + +def test_cleanup_removes_backup(temp_migrations_dir: Path, sample_sql_migration: Path) -> None: + """Test cleanup removes backup directory.""" + fixer = MigrationFixer(temp_migrations_dir) + backup_path = fixer.create_backup() + + assert backup_path.exists() + + fixer.cleanup() + + assert not backup_path.exists() + assert fixer.backup_path is None + + +def test_cleanup_without_backup(temp_migrations_dir: Path) -> None: + """Test cleanup without backup is a no-op.""" + fixer = MigrationFixer(temp_migrations_dir) + fixer.cleanup() + + +def test_full_conversion_workflow(temp_migrations_dir: Path) -> 
None: + """Test complete conversion workflow with rollback on error.""" + sql_file = temp_migrations_dir / "20251011120000_create_users.sql" + sql_file.write_text("""-- name: migrate-20251011120000-up +CREATE TABLE users (id INTEGER); + +-- name: migrate-20251011120000-down +DROP TABLE users; +""") + + py_file = temp_migrations_dir / "20251012130000_create_products.py" + py_file.write_text('"""Migration."""\n\nasync def up(driver):\n pass\n') + + fixer = MigrationFixer(temp_migrations_dir) + conversion_map = {"20251011120000": "0001", "20251012130000": "0002"} + + backup_path = fixer.create_backup() + renames = fixer.plan_renames(conversion_map) + + assert len(renames) == 2 + + try: + fixer.apply_renames(renames) + fixer.cleanup() + + assert (temp_migrations_dir / "0001_create_users.sql").exists() + assert (temp_migrations_dir / "0002_create_products.py").exists() + assert not sql_file.exists() + assert not py_file.exists() + assert not backup_path.exists() + + converted_content = (temp_migrations_dir / "0001_create_users.sql").read_text() + assert "migrate-0001-up" in converted_content + assert "migrate-0001-down" in converted_content + + except Exception: + fixer.rollback() + raise + + +def test_extension_migration_rename(temp_migrations_dir: Path) -> None: + """Test renaming extension migrations preserves prefix.""" + ext_file = temp_migrations_dir / "ext_litestar_20251011215440_create_sessions.py" + ext_file.write_text('"""Extension migration."""\n') + + fixer = MigrationFixer(temp_migrations_dir) + conversion_map = {"ext_litestar_20251011215440": "ext_litestar_0001"} + + renames = fixer.plan_renames(conversion_map) + + assert len(renames) == 1 + assert renames[0].new_path.name == "ext_litestar_0001_create_sessions.py" + + +def test_multiple_sql_files_same_version(temp_migrations_dir: Path) -> None: + """Test handling multiple files with same version prefix.""" + file1 = temp_migrations_dir / "20251011120000_users.sql" + file2 = temp_migrations_dir / "20251011120000_products.sql" + file1.write_text("CREATE TABLE users (id INTEGER);") + file2.write_text("CREATE TABLE products (id INTEGER);") + + fixer = MigrationFixer(temp_migrations_dir) + conversion_map = {"20251011120000": "0001"} + + renames = fixer.plan_renames(conversion_map) + + assert len(renames) == 2 + new_names = {r.new_path.name for r in renames} + assert new_names == {"0001_users.sql", "0001_products.sql"} diff --git a/tests/integration/test_migrations/test_fix_idempotency_workflow.py b/tests/integration/test_migrations/test_fix_idempotency_workflow.py new file mode 100644 index 00000000..2fa5a85e --- /dev/null +++ b/tests/integration/test_migrations/test_fix_idempotency_workflow.py @@ -0,0 +1,353 @@ +"""Integration tests for idempotent fix command workflow.""" + +from collections.abc import Generator +from pathlib import Path + +import pytest + +from sqlspec.adapters.sqlite import SqliteConfig, SqliteDriver +from sqlspec.migrations.fix import MigrationFixer +from sqlspec.migrations.runner import SyncMigrationRunner +from sqlspec.migrations.tracker import SyncMigrationTracker +from sqlspec.utils.version import generate_conversion_map + + +@pytest.fixture +def sqlite_config() -> Generator[SqliteConfig, None, None]: + """Create SQLite config for migration testing.""" + config = SqliteConfig(pool_config={"database": ":memory:"}) + yield config + config.close_pool() + + +@pytest.fixture +def sqlite_session(sqlite_config: SqliteConfig) -> Generator[SqliteDriver, None, None]: + """Create SQLite session for migration testing.""" + 
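The full-workflow test above encodes the safety discipline the fixer expects from callers: back up, plan, apply, and either clean up on success or roll back on failure. A minimal sketch of that wrapper, assuming only the `MigrationFixer` methods shown in these tests:

```python
from pathlib import Path

from sqlspec.migrations.fix import MigrationFixer


def convert_safely(migrations_dir: Path, conversion_map: dict[str, str]) -> None:
    """Apply renames with backup/rollback, mirroring test_full_conversion_workflow."""
    fixer = MigrationFixer(migrations_dir)
    fixer.create_backup()
    renames = fixer.plan_renames(conversion_map)
    try:
        fixer.apply_renames(renames)
    except Exception:
        fixer.rollback()  # restore originals from the .backup_* directory
        raise
    fixer.cleanup()  # drop the backup once the conversion succeeded
```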
with sqlite_config.provide_session() as session: + yield session + + +@pytest.fixture +def migrations_dir(tmp_path: Path) -> Path: + """Create temporary migrations directory.""" + migrations_dir = tmp_path / "migrations" + migrations_dir.mkdir() + return migrations_dir + + +def test_fix_command_idempotent_single_migration( + migrations_dir: Path, sqlite_session: SqliteDriver, sqlite_config: SqliteConfig +) -> None: + """Test fix command can be run multiple times without error on single migration.""" + sql_content = """-- name: migrate-20251011120000-up +CREATE TABLE users (id INTEGER PRIMARY KEY); + +-- name: migrate-20251011120000-down +DROP TABLE users; +""" + + sql_file = migrations_dir / "20251011120000_create_users.sql" + sql_file.write_text(sql_content) + + tracker = SyncMigrationTracker() + tracker.ensure_tracking_table(sqlite_session) + + runner = SyncMigrationRunner(migrations_dir) + migration_files = runner.get_migration_files() + migration = runner.load_migration(migration_files[0][1], version=migration_files[0][0]) + + runner.execute_upgrade(sqlite_session, migration) + tracker.record_migration(sqlite_session, migration["version"], migration["description"], 100, migration["checksum"]) + + fixer = MigrationFixer(migrations_dir) + conversion_map = generate_conversion_map(migration_files) + renames = fixer.plan_renames(conversion_map) + fixer.apply_renames(renames) + + tracker.update_version_record(sqlite_session, "20251011120000", "0001") + + applied = tracker.get_applied_migrations(sqlite_session) + assert len(applied) == 1 + assert applied[0]["version_num"] == "0001" + + new_migration_files = runner.get_migration_files() + conversion_map_second = generate_conversion_map(new_migration_files) + + assert conversion_map_second == {} + + tracker.update_version_record(sqlite_session, "20251011120000", "0001") + + applied_after = tracker.get_applied_migrations(sqlite_session) + assert len(applied_after) == 1 + assert applied_after[0]["version_num"] == "0001" + + +def test_fix_command_idempotent_multiple_migrations( + migrations_dir: Path, sqlite_session: SqliteDriver, sqlite_config: SqliteConfig +) -> None: + """Test fix command is idempotent with multiple migrations.""" + migrations = [ + ( + "20251011120000_create_users.sql", + """-- name: migrate-20251011120000-up +CREATE TABLE users (id INTEGER PRIMARY KEY); + +-- name: migrate-20251011120000-down +DROP TABLE users; +""", + ), + ( + "20251012130000_create_products.sql", + """-- name: migrate-20251012130000-up +CREATE TABLE products (id INTEGER PRIMARY KEY); + +-- name: migrate-20251012130000-down +DROP TABLE products; +""", + ), + ] + + for filename, content in migrations: + (migrations_dir / filename).write_text(content) + + tracker = SyncMigrationTracker() + tracker.ensure_tracking_table(sqlite_session) + + runner = SyncMigrationRunner(migrations_dir) + migration_files = runner.get_migration_files() + + for version, file_path in migration_files: + migration = runner.load_migration(file_path, version=version) + runner.execute_upgrade(sqlite_session, migration) + tracker.record_migration( + sqlite_session, migration["version"], migration["description"], 100, migration["checksum"] + ) + + fixer = MigrationFixer(migrations_dir) + conversion_map = generate_conversion_map(migration_files) + renames = fixer.plan_renames(conversion_map) + fixer.apply_renames(renames) + + for old_version, new_version in conversion_map.items(): + tracker.update_version_record(sqlite_session, old_version, new_version) + + applied_first = 
tracker.get_applied_migrations(sqlite_session) + assert len(applied_first) == 2 + assert applied_first[0]["version_num"] == "0001" + assert applied_first[1]["version_num"] == "0002" + + for old_version, new_version in conversion_map.items(): + tracker.update_version_record(sqlite_session, old_version, new_version) + + applied_second = tracker.get_applied_migrations(sqlite_session) + assert len(applied_second) == 2 + assert applied_second[0]["version_num"] == "0001" + assert applied_second[1]["version_num"] == "0002" + + +def test_ci_workflow_simulation( + migrations_dir: Path, sqlite_session: SqliteDriver, sqlite_config: SqliteConfig +) -> None: + """Test simulated CI workflow where fix runs on every commit.""" + sql_content = """-- name: migrate-20251011120000-up +CREATE TABLE users (id INTEGER PRIMARY KEY); + +-- name: migrate-20251011120000-down +DROP TABLE users; +""" + + sql_file = migrations_dir / "20251011120000_create_users.sql" + sql_file.write_text(sql_content) + + tracker = SyncMigrationTracker() + tracker.ensure_tracking_table(sqlite_session) + + runner = SyncMigrationRunner(migrations_dir) + migration_files = runner.get_migration_files() + migration = runner.load_migration(migration_files[0][1], version=migration_files[0][0]) + + runner.execute_upgrade(sqlite_session, migration) + tracker.record_migration(sqlite_session, migration["version"], migration["description"], 100, migration["checksum"]) + + fixer = MigrationFixer(migrations_dir) + conversion_map = generate_conversion_map(migration_files) + renames = fixer.plan_renames(conversion_map) + fixer.apply_renames(renames) + + for old_version, new_version in conversion_map.items(): + tracker.update_version_record(sqlite_session, old_version, new_version) + + runner_after_first_fix = SyncMigrationRunner(migrations_dir) + files_after_first = runner_after_first_fix.get_migration_files() + conversion_map_second = generate_conversion_map(files_after_first) + + assert conversion_map_second == {} + + for old_version, new_version in conversion_map.items(): + tracker.update_version_record(sqlite_session, old_version, new_version) + + applied = tracker.get_applied_migrations(sqlite_session) + assert len(applied) == 1 + assert applied[0]["version_num"] == "0001" + + +def test_developer_pull_workflow( + migrations_dir: Path, sqlite_session: SqliteDriver, sqlite_config: SqliteConfig +) -> None: + """Test developer pulls changes and runs fix on already-converted files.""" + sql_content = """-- name: migrate-0001-up +CREATE TABLE users (id INTEGER PRIMARY KEY); + +-- name: migrate-0001-down +DROP TABLE users; +""" + + sql_file = migrations_dir / "0001_create_users.sql" + sql_file.write_text(sql_content) + + tracker = SyncMigrationTracker() + tracker.ensure_tracking_table(sqlite_session) + + runner = SyncMigrationRunner(migrations_dir) + migration_files = runner.get_migration_files() + migration = runner.load_migration(migration_files[0][1], version=migration_files[0][0]) + + runner.execute_upgrade(sqlite_session, migration) + tracker.record_migration(sqlite_session, migration["version"], migration["description"], 100, migration["checksum"]) + + MigrationFixer(migrations_dir) + conversion_map = generate_conversion_map(migration_files) + + assert conversion_map == {} + + applied = tracker.get_applied_migrations(sqlite_session) + assert len(applied) == 1 + assert applied[0]["version_num"] == "0001" + assert applied[0]["version_type"] == "sequential" + + +def test_partial_conversion_recovery( + migrations_dir: Path, sqlite_session: SqliteDriver, 
sqlite_config: SqliteConfig +) -> None: + """Test recovery when fix partially completes.""" + migrations = [ + ( + "20251011120000_create_users.sql", + """-- name: migrate-20251011120000-up +CREATE TABLE users (id INTEGER PRIMARY KEY); + +-- name: migrate-20251011120000-down +DROP TABLE users; +""", + ), + ( + "20251012130000_create_products.sql", + """-- name: migrate-20251012130000-up +CREATE TABLE products (id INTEGER PRIMARY KEY); + +-- name: migrate-20251012130000-down +DROP TABLE products; +""", + ), + ] + + for filename, content in migrations: + (migrations_dir / filename).write_text(content) + + tracker = SyncMigrationTracker() + tracker.ensure_tracking_table(sqlite_session) + + runner = SyncMigrationRunner(migrations_dir) + migration_files = runner.get_migration_files() + + for version, file_path in migration_files: + migration = runner.load_migration(file_path, version=version) + runner.execute_upgrade(sqlite_session, migration) + tracker.record_migration( + sqlite_session, migration["version"], migration["description"], 100, migration["checksum"] + ) + + fixer = MigrationFixer(migrations_dir) + conversion_map = generate_conversion_map(migration_files) + renames = fixer.plan_renames(conversion_map) + fixer.apply_renames(renames) + + tracker.update_version_record(sqlite_session, "20251011120000", "0001") + + applied_partial = tracker.get_applied_migrations(sqlite_session) + versions_in_db = {row["version_num"] for row in applied_partial} + assert "0001" in versions_in_db + assert "20251012130000" in versions_in_db + + runner_partial = SyncMigrationRunner(migrations_dir) + files_partial = runner_partial.get_migration_files() + generate_conversion_map(files_partial) + + tracker.update_version_record(sqlite_session, "20251012130000", "0002") + + applied_complete = tracker.get_applied_migrations(sqlite_session) + assert len(applied_complete) == 2 + assert all(row["version_num"] in ["0001", "0002"] for row in applied_complete) + + +def test_mixed_sequential_and_timestamp_idempotent( + migrations_dir: Path, sqlite_session: SqliteDriver, sqlite_config: SqliteConfig +) -> None: + """Test fix is idempotent with mixed sequential and timestamp migrations.""" + migrations = [ + ( + "0001_init.sql", + """-- name: migrate-0001-up +CREATE TABLE init (id INTEGER PRIMARY KEY); + +-- name: migrate-0001-down +DROP TABLE init; +""", + ), + ( + "20251011120000_create_users.sql", + """-- name: migrate-20251011120000-up +CREATE TABLE users (id INTEGER PRIMARY KEY); + +-- name: migrate-20251011120000-down +DROP TABLE users; +""", + ), + ] + + for filename, content in migrations: + (migrations_dir / filename).write_text(content) + + tracker = SyncMigrationTracker() + tracker.ensure_tracking_table(sqlite_session) + + runner = SyncMigrationRunner(migrations_dir) + migration_files = runner.get_migration_files() + + for version, file_path in migration_files: + migration = runner.load_migration(file_path, version=version) + runner.execute_upgrade(sqlite_session, migration) + tracker.record_migration( + sqlite_session, migration["version"], migration["description"], 100, migration["checksum"] + ) + + fixer = MigrationFixer(migrations_dir) + conversion_map = generate_conversion_map(migration_files) + renames = fixer.plan_renames(conversion_map) + fixer.apply_renames(renames) + + for old_version, new_version in conversion_map.items(): + tracker.update_version_record(sqlite_session, old_version, new_version) + + applied_first = tracker.get_applied_migrations(sqlite_session) + assert len(applied_first) == 2 + 
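The conversion map is what makes re-running safe: `generate_conversion_map` only assigns new numbers to timestamp-style versions, so a directory that is already sequential yields an empty map. A sketch of the resulting no-op guard, self-contained under the same imports this test module uses:

```python
from pathlib import Path

from sqlspec.migrations.fix import MigrationFixer
from sqlspec.migrations.runner import SyncMigrationRunner
from sqlspec.utils.version import generate_conversion_map

migrations_dir = Path("migrations")  # assumed to exist, as in the fixtures
conversion_map = generate_conversion_map(SyncMigrationRunner(migrations_dir).get_migration_files())

if conversion_map:
    fixer = MigrationFixer(migrations_dir)
    fixer.apply_renames(fixer.plan_renames(conversion_map))
# An empty map means everything is already sequential: the whole fix is a no-op.
```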
versions_first = {row["version_num"] for row in applied_first} + assert versions_first == {"0001", "0002"} + + for old_version, new_version in conversion_map.items(): + tracker.update_version_record(sqlite_session, old_version, new_version) + + applied_second = tracker.get_applied_migrations(sqlite_session) + assert len(applied_second) == 2 + versions_second = {row["version_num"] for row in applied_second} + assert versions_second == {"0001", "0002"} diff --git a/tests/integration/test_migrations/test_schema_migration.py b/tests/integration/test_migrations/test_schema_migration.py new file mode 100644 index 00000000..be499948 --- /dev/null +++ b/tests/integration/test_migrations/test_schema_migration.py @@ -0,0 +1,311 @@ +"""Integration tests for migration tracking table schema migration.""" + +from collections.abc import Generator + +import pytest + +from sqlspec.adapters.sqlite import SqliteConfig, SqliteDriver +from sqlspec.migrations.tracker import SyncMigrationTracker + + +@pytest.fixture +def sqlite_config() -> Generator[SqliteConfig, None, None]: + """Create SQLite config for testing.""" + config = SqliteConfig(pool_config={"database": ":memory:"}) + yield config + config.close_pool() + + +@pytest.fixture +def sqlite_session(sqlite_config: SqliteConfig) -> Generator[SqliteDriver, None, None]: + """Create SQLite session for testing.""" + with sqlite_config.provide_session() as session: + yield session + + +def test_tracker_creates_full_schema_on_fresh_install(sqlite_session: SqliteDriver) -> None: + """Test tracker creates complete schema with all columns on new database.""" + tracker = SyncMigrationTracker() + + tracker.ensure_tracking_table(sqlite_session) + + result = sqlite_session.execute(f"PRAGMA table_info({tracker.version_table})") + columns = {row["name"] if isinstance(row, dict) else row[1] for row in result.data or []} + + expected_columns = { + "version_num", + "version_type", + "execution_sequence", + "description", + "applied_at", + "execution_time_ms", + "checksum", + "applied_by", + } + + assert columns == expected_columns + + +def test_tracker_migrates_legacy_schema(sqlite_session: SqliteDriver) -> None: + """Test tracker adds missing columns to legacy schema.""" + tracker = SyncMigrationTracker() + + sqlite_session.execute(f""" + CREATE TABLE {tracker.version_table} ( + version_num VARCHAR(32) PRIMARY KEY, + description TEXT, + applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL + ) + """) + sqlite_session.commit() + + tracker.ensure_tracking_table(sqlite_session) + + result = sqlite_session.execute(f"PRAGMA table_info({tracker.version_table})") + columns = {row["name"] if isinstance(row, dict) else row[1] for row in result.data or []} + + assert "version_type" in columns + assert "execution_sequence" in columns + assert "checksum" in columns + assert "execution_time_ms" in columns + assert "applied_by" in columns + + +def test_tracker_migration_preserves_existing_data(sqlite_session: SqliteDriver) -> None: + """Test schema migration preserves existing migration records.""" + tracker = SyncMigrationTracker() + + sqlite_session.execute(f""" + CREATE TABLE {tracker.version_table} ( + version_num VARCHAR(32) PRIMARY KEY, + description TEXT, + applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL + ) + """) + + sqlite_session.execute( + f""" + INSERT INTO {tracker.version_table} (version_num, description) + VALUES ('0001', 'Initial migration') + """ + ) + sqlite_session.commit() + + tracker.ensure_tracking_table(sqlite_session) + + result = 
sqlite_session.execute(f"SELECT * FROM {tracker.version_table}") + records = result.data or [] + + assert len(records) == 1 + record = records[0] + assert record["version_num"] == "0001" + assert record["description"] == "Initial migration" + assert "version_type" in record + assert "execution_sequence" in record + + +def test_tracker_migration_is_idempotent(sqlite_session: SqliteDriver) -> None: + """Test schema migration can be run multiple times safely.""" + tracker = SyncMigrationTracker() + + sqlite_session.execute(f""" + CREATE TABLE {tracker.version_table} ( + version_num VARCHAR(32) PRIMARY KEY, + description TEXT, + applied_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL + ) + """) + sqlite_session.commit() + + tracker.ensure_tracking_table(sqlite_session) + + result1 = sqlite_session.execute(f"PRAGMA table_info({tracker.version_table})") + columns1 = {row["name"] if isinstance(row, dict) else row[1] for row in result1.data or []} + + tracker.ensure_tracking_table(sqlite_session) + + result2 = sqlite_session.execute(f"PRAGMA table_info({tracker.version_table})") + columns2 = {row["name"] if isinstance(row, dict) else row[1] for row in result2.data or []} + + assert columns1 == columns2 + + +def test_tracker_uses_version_type_for_recording(sqlite_session: SqliteDriver) -> None: + """Test tracker correctly records version_type when recording migrations.""" + tracker = SyncMigrationTracker() + tracker.ensure_tracking_table(sqlite_session) + + tracker.record_migration(sqlite_session, "0001", "Sequential migration", 100, "checksum1") + tracker.record_migration(sqlite_session, "20251011120000", "Timestamp migration", 150, "checksum2") + + result = sqlite_session.execute( + f"SELECT version_num, version_type FROM {tracker.version_table} ORDER BY execution_sequence" + ) + records = result.data or [] + + assert len(records) == 2 + assert records[0]["version_num"] == "0001" + assert records[0]["version_type"] == "sequential" + assert records[1]["version_num"] == "20251011120000" + assert records[1]["version_type"] == "timestamp" + + +def test_tracker_execution_sequence_auto_increments(sqlite_session: SqliteDriver) -> None: + """Test execution_sequence auto-increments for tracking application order.""" + tracker = SyncMigrationTracker() + tracker.ensure_tracking_table(sqlite_session) + + tracker.record_migration(sqlite_session, "0001", "First", 100, "checksum1") + tracker.record_migration(sqlite_session, "0002", "Second", 100, "checksum2") + tracker.record_migration(sqlite_session, "0003", "Third", 100, "checksum3") + + result = sqlite_session.execute( + f"SELECT version_num, execution_sequence FROM {tracker.version_table} ORDER BY execution_sequence" + ) + records = result.data or [] + + assert len(records) == 3 + assert records[0]["execution_sequence"] == 1 + assert records[1]["execution_sequence"] == 2 + assert records[2]["execution_sequence"] == 3 + + +def test_tracker_get_current_version_uses_execution_sequence(sqlite_session: SqliteDriver) -> None: + """Test get_current_version returns last applied migration by execution order.""" + tracker = SyncMigrationTracker() + tracker.ensure_tracking_table(sqlite_session) + + tracker.record_migration(sqlite_session, "0001", "First", 100, "checksum1") + tracker.record_migration(sqlite_session, "0003", "Out of order", 100, "checksum3") + tracker.record_migration(sqlite_session, "0002", "Late merge", 100, "checksum2") + + current = tracker.get_current_version(sqlite_session) + + assert current == "0002" + + +def 
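The legacy-schema tests above pin down an additive migration: compare the live column set against the expected one and `ALTER TABLE ... ADD COLUMN` whatever is missing, in sorted order. A minimal sketch consistent with those tests; the column types are illustrative assumptions, not the tracker's actual DDL, and `session` stands for a sync driver from `provide_session()`:

```python
EXPECTED_COLUMNS = {
    "version_type": "VARCHAR(16)",
    "execution_sequence": "INTEGER",
    "execution_time_ms": "INTEGER",
    "checksum": "VARCHAR(32)",
    "applied_by": "VARCHAR(255)",
}


def add_missing_columns(session, table: str) -> None:
    """Bring a legacy tracking table up to the full schema (SQLite sketch)."""
    rows = session.execute(f"PRAGMA table_info({table})").data or []
    existing = {row["name"] if isinstance(row, dict) else row[1] for row in rows}
    for column in sorted(EXPECTED_COLUMNS.keys() - existing):
        session.execute(f"ALTER TABLE {table} ADD COLUMN {column} {EXPECTED_COLUMNS[column]}")
```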
test_tracker_update_version_record_preserves_execution_sequence(sqlite_session: SqliteDriver) -> None: + """Test updating version preserves execution_sequence and applied_at.""" + tracker = SyncMigrationTracker() + tracker.ensure_tracking_table(sqlite_session) + + tracker.record_migration(sqlite_session, "20251011120000", "Timestamp migration", 100, "checksum1") + + result_before = sqlite_session.execute( + f"SELECT execution_sequence, applied_at FROM {tracker.version_table} WHERE version_num = '20251011120000'" + ) + record_before = (result_before.data or [])[0] + + tracker.update_version_record(sqlite_session, "20251011120000", "0001") + + result_after = sqlite_session.execute( + f"SELECT version_num, version_type, execution_sequence, applied_at FROM {tracker.version_table} WHERE version_num = '0001'" + ) + record_after = (result_after.data or [])[0] + + assert record_after["version_num"] == "0001" + assert record_after["version_type"] == "sequential" + assert record_after["execution_sequence"] == record_before["execution_sequence"] + assert record_after["applied_at"] == record_before["applied_at"] + + +def test_tracker_update_version_record_idempotent(sqlite_session: SqliteDriver) -> None: + """Test update_version_record is idempotent when version already updated.""" + tracker = SyncMigrationTracker() + tracker.ensure_tracking_table(sqlite_session) + + tracker.record_migration(sqlite_session, "20251011120000", "Migration", 100, "checksum1") + + tracker.update_version_record(sqlite_session, "20251011120000", "0001") + + tracker.update_version_record(sqlite_session, "20251011120000", "0001") + + result = sqlite_session.execute(f"SELECT COUNT(*) as count FROM {tracker.version_table}") + count = (result.data or [])[0]["count"] + + assert count == 1 + + +def test_tracker_update_version_record_raises_on_missing(sqlite_session: SqliteDriver) -> None: + """Test update_version_record raises error when version not found.""" + tracker = SyncMigrationTracker() + tracker.ensure_tracking_table(sqlite_session) + + with pytest.raises(ValueError, match="Migration version .* not found"): # noqa: RUF043 + tracker.update_version_record(sqlite_session, "nonexistent", "0001") + + +def test_tracker_migration_adds_columns_in_sorted_order(sqlite_session: SqliteDriver) -> None: + """Test schema migration adds multiple missing columns consistently.""" + tracker = SyncMigrationTracker() + + sqlite_session.execute(f""" + CREATE TABLE {tracker.version_table} ( + version_num VARCHAR(32) PRIMARY KEY, + description TEXT + ) + """) + sqlite_session.commit() + + tracker.ensure_tracking_table(sqlite_session) + + result = sqlite_session.execute(f"PRAGMA table_info({tracker.version_table})") + columns = [row["name"] if isinstance(row, dict) else row[1] for row in result.data or []] + + version_num_idx = columns.index("version_num") + description_idx = columns.index("description") + version_type_idx = columns.index("version_type") + + assert version_num_idx < description_idx < version_type_idx + + +def test_tracker_checksum_column_stores_md5_hashes(sqlite_session: SqliteDriver) -> None: + """Test checksum column can store migration content checksums.""" + tracker = SyncMigrationTracker() + tracker.ensure_tracking_table(sqlite_session) + + import hashlib + + content = "CREATE TABLE users (id INTEGER PRIMARY KEY);" + checksum = hashlib.md5(content.encode()).hexdigest() + + tracker.record_migration(sqlite_session, "0001", "Create users", 100, checksum) + + result = sqlite_session.execute(f"SELECT checksum FROM 
{tracker.version_table} WHERE version_num = '0001'") + stored_checksum = (result.data or [])[0]["checksum"] + + assert stored_checksum == checksum + assert len(stored_checksum) == 32 + + +def test_tracker_applied_by_column_stores_user(sqlite_session: SqliteDriver) -> None: + """Test applied_by column records who applied the migration.""" + import os + + tracker = SyncMigrationTracker() + tracker.ensure_tracking_table(sqlite_session) + + tracker.record_migration(sqlite_session, "0001", "Migration", 100, "checksum") + + result = sqlite_session.execute(f"SELECT applied_by FROM {tracker.version_table} WHERE version_num = '0001'") + applied_by = (result.data or [])[0]["applied_by"] + + expected_user = os.environ.get("USER", "unknown") + assert applied_by == expected_user + + +def test_tracker_get_applied_migrations_orders_by_execution_sequence(sqlite_session: SqliteDriver) -> None: + """Test get_applied_migrations returns migrations in execution order.""" + tracker = SyncMigrationTracker() + tracker.ensure_tracking_table(sqlite_session) + + tracker.record_migration(sqlite_session, "0001", "First", 100, "checksum1") + tracker.record_migration(sqlite_session, "0003", "Out of order", 100, "checksum3") + tracker.record_migration(sqlite_session, "0002", "Late merge", 100, "checksum2") + + applied = tracker.get_applied_migrations(sqlite_session) + + assert len(applied) == 3 + assert applied[0]["version_num"] == "0001" + assert applied[1]["version_num"] == "0003" + assert applied[2]["version_num"] == "0002" diff --git a/tests/integration/test_migrations/test_upgrade_downgrade_versions.py b/tests/integration/test_migrations/test_upgrade_downgrade_versions.py new file mode 100644 index 00000000..809bbbdf --- /dev/null +++ b/tests/integration/test_migrations/test_upgrade_downgrade_versions.py @@ -0,0 +1,409 @@ +"""Integration tests for upgrade/downgrade commands with hybrid versioning.""" + +from collections.abc import Generator +from pathlib import Path + +import pytest + +from sqlspec.adapters.sqlite import SqliteConfig +from sqlspec.migrations.commands import SyncMigrationCommands + + +@pytest.fixture +def sqlite_config(tmp_path: Path) -> Generator[SqliteConfig, None, None]: + """Create SQLite config with migrations directory.""" + migrations_dir = tmp_path / "migrations" + migrations_dir.mkdir() + + config = SqliteConfig( + pool_config={"database": ":memory:"}, + migration_config={"script_location": str(migrations_dir), "version_table_name": "ddl_migrations"}, + ) + yield config + config.close_pool() + + +@pytest.fixture +def migrations_dir(tmp_path: Path) -> Path: + """Get migrations directory.""" + return tmp_path / "migrations" + + +def test_upgrade_with_sequential_versions(sqlite_config: SqliteConfig, migrations_dir: Path) -> None: + """Test upgrade works with sequential version numbers.""" + migrations = [ + ("0001_create_users.sql", "0001", "CREATE TABLE users (id INTEGER PRIMARY KEY);"), + ("0002_create_products.sql", "0002", "CREATE TABLE products (id INTEGER PRIMARY KEY);"), + ("0003_create_orders.sql", "0003", "CREATE TABLE orders (id INTEGER PRIMARY KEY);"), + ] + + for filename, version, sql in migrations: + content = f"""-- name: migrate-{version}-up +{sql} + +-- name: migrate-{version}-down +DROP TABLE {filename.split("_")[1]}; +""" + (migrations_dir / filename).write_text(content) + + commands = SyncMigrationCommands(sqlite_config) + + commands.upgrade() + + current = commands.current() + assert current == "0003" + + +def test_upgrade_with_timestamp_versions(sqlite_config: 
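The upgrade/downgrade tests drive the higher-level command surface rather than the runner and tracker directly. For orientation, a hedged usage sketch mirroring the fixture configuration above:

```python
from sqlspec.adapters.sqlite import SqliteConfig
from sqlspec.migrations.commands import SyncMigrationCommands

config = SqliteConfig(
    pool_config={"database": ":memory:"},
    migration_config={"script_location": "migrations", "version_table_name": "ddl_migrations"},
)
commands = SyncMigrationCommands(config)

commands.upgrade(revision="0002")    # apply up to a specific version...
commands.upgrade()                   # ...or everything pending
print(commands.current())            # latest applied version, or None

commands.downgrade()                 # step back one migration
commands.downgrade(revision="base")  # or revert everything
```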
SqliteConfig, migrations_dir: Path) -> None: + """Test upgrade works with timestamp version numbers.""" + migrations = [ + ("20251011120000_create_users.sql", "20251011120000", "CREATE TABLE users (id INTEGER PRIMARY KEY);"), + ("20251012130000_create_products.sql", "20251012130000", "CREATE TABLE products (id INTEGER PRIMARY KEY);"), + ] + + for filename, version, sql in migrations: + content = f"""-- name: migrate-{version}-up +{sql} + +-- name: migrate-{version}-down +DROP TABLE {filename.split("_")[1]}; +""" + (migrations_dir / filename).write_text(content) + + commands = SyncMigrationCommands(sqlite_config) + + commands.upgrade() + + current = commands.current() + assert current == "20251012130000" + + +def test_upgrade_with_mixed_versions(sqlite_config: SqliteConfig, migrations_dir: Path) -> None: + """Test upgrade works with mixed sequential and timestamp versions.""" + migrations = [ + ("0001_create_users.sql", "0001", "CREATE TABLE users (id INTEGER PRIMARY KEY);"), + ("0002_create_products.sql", "0002", "CREATE TABLE products (id INTEGER PRIMARY KEY);"), + ("20251011120000_create_orders.sql", "20251011120000", "CREATE TABLE orders (id INTEGER PRIMARY KEY);"), + ("20251012130000_create_payments.sql", "20251012130000", "CREATE TABLE payments (id INTEGER PRIMARY KEY);"), + ] + + for filename, version, sql in migrations: + content = f"""-- name: migrate-{version}-up +{sql} + +-- name: migrate-{version}-down +DROP TABLE {filename.split("_")[1]}; +""" + (migrations_dir / filename).write_text(content) + + commands = SyncMigrationCommands(sqlite_config) + + commands.upgrade() + + current = commands.current() + assert current == "20251012130000" + + with sqlite_config.provide_session() as session: + applied = commands.tracker.get_applied_migrations(session) + + assert len(applied) == 4 + assert applied[0]["version_num"] == "0001" + assert applied[1]["version_num"] == "0002" + assert applied[2]["version_num"] == "20251011120000" + assert applied[3]["version_num"] == "20251012130000" + + +def test_upgrade_to_specific_sequential_version(sqlite_config: SqliteConfig, migrations_dir: Path) -> None: + """Test upgrade to specific sequential version.""" + migrations = [ + ("0001_create_users.sql", "0001", "CREATE TABLE users (id INTEGER PRIMARY KEY);"), + ("0002_create_products.sql", "0002", "CREATE TABLE products (id INTEGER PRIMARY KEY);"), + ("0003_create_orders.sql", "0003", "CREATE TABLE orders (id INTEGER PRIMARY KEY);"), + ] + + for filename, version, sql in migrations: + content = f"""-- name: migrate-{version}-up +{sql} + +-- name: migrate-{version}-down +DROP TABLE {filename.split("_")[1]}; +""" + (migrations_dir / filename).write_text(content) + + commands = SyncMigrationCommands(sqlite_config) + + commands.upgrade(revision="0002") + + current = commands.current() + assert current == "0002" + + +def test_upgrade_to_specific_timestamp_version(sqlite_config: SqliteConfig, migrations_dir: Path) -> None: + """Test upgrade to specific timestamp version.""" + migrations = [ + ("20251011120000_create_users.sql", "20251011120000", "CREATE TABLE users (id INTEGER PRIMARY KEY);"), + ("20251012130000_create_products.sql", "20251012130000", "CREATE TABLE products (id INTEGER PRIMARY KEY);"), + ("20251013140000_create_orders.sql", "20251013140000", "CREATE TABLE orders (id INTEGER PRIMARY KEY);"), + ] + + for filename, version, sql in migrations: + content = f"""-- name: migrate-{version}-up +{sql} + +-- name: migrate-{version}-down +DROP TABLE {filename.split("_")[1]}; +""" + (migrations_dir / 
filename).write_text(content) + + commands = SyncMigrationCommands(sqlite_config) + + commands.upgrade(revision="20251012130000") + + current = commands.current() + assert current == "20251012130000" + + +def test_downgrade_with_sequential_versions(sqlite_config: SqliteConfig, migrations_dir: Path) -> None: + """Test downgrade works with sequential versions.""" + migrations = [ + ("0001_create_users.sql", "0001", "CREATE TABLE users (id INTEGER PRIMARY KEY);"), + ("0002_create_products.sql", "0002", "CREATE TABLE products (id INTEGER PRIMARY KEY);"), + ("0003_create_orders.sql", "0003", "CREATE TABLE orders (id INTEGER PRIMARY KEY);"), + ] + + for filename, version, sql in migrations: + table_name = filename.split("_", 2)[2].rsplit(".", 1)[0] + content = f"""-- name: migrate-{version}-up +{sql} + +-- name: migrate-{version}-down +DROP TABLE {table_name}; +""" + (migrations_dir / filename).write_text(content) + + commands = SyncMigrationCommands(sqlite_config) + + commands.upgrade() + assert commands.current() == "0003" + + commands.downgrade() + assert commands.current() == "0002" + + commands.downgrade() + assert commands.current() == "0001" + + +def test_downgrade_with_timestamp_versions(sqlite_config: SqliteConfig, migrations_dir: Path) -> None: + """Test downgrade works with timestamp versions.""" + migrations = [ + ("20251011120000_create_users.sql", "20251011120000", "CREATE TABLE users (id INTEGER PRIMARY KEY);"), + ("20251012130000_create_products.sql", "20251012130000", "CREATE TABLE products (id INTEGER PRIMARY KEY);"), + ] + + for filename, version, sql in migrations: + table_name = filename.split("_", 2)[2].rsplit(".", 1)[0] + content = f"""-- name: migrate-{version}-up +{sql} + +-- name: migrate-{version}-down +DROP TABLE {table_name}; +""" + (migrations_dir / filename).write_text(content) + + commands = SyncMigrationCommands(sqlite_config) + + commands.upgrade() + assert commands.current() == "20251012130000" + + commands.downgrade() + assert commands.current() == "20251011120000" + + +def test_downgrade_to_specific_version(sqlite_config: SqliteConfig, migrations_dir: Path) -> None: + """Test downgrade to specific version.""" + migrations = [ + ("0001_create_users.sql", "0001", "CREATE TABLE users (id INTEGER PRIMARY KEY);"), + ("0002_create_products.sql", "0002", "CREATE TABLE products (id INTEGER PRIMARY KEY);"), + ("0003_create_orders.sql", "0003", "CREATE TABLE orders (id INTEGER PRIMARY KEY);"), + ("0004_create_payments.sql", "0004", "CREATE TABLE payments (id INTEGER PRIMARY KEY);"), + ] + + for filename, version, sql in migrations: + table_name = filename.split("_", 2)[2].rsplit(".", 1)[0] + content = f"""-- name: migrate-{version}-up +{sql} + +-- name: migrate-{version}-down +DROP TABLE {table_name}; +""" + (migrations_dir / filename).write_text(content) + + commands = SyncMigrationCommands(sqlite_config) + + commands.upgrade() + assert commands.current() == "0004" + + commands.downgrade(revision="0002") + assert commands.current() == "0002" + + with sqlite_config.provide_session() as session: + applied = commands.tracker.get_applied_migrations(session) + + assert len(applied) == 2 + assert applied[0]["version_num"] == "0001" + assert applied[1]["version_num"] == "0002" + + +def test_downgrade_to_base(sqlite_config: SqliteConfig, migrations_dir: Path) -> None: + """Test downgrade to base (removes all migrations).""" + migrations = [ + ("0001_create_users.sql", "0001", "CREATE TABLE users (id INTEGER PRIMARY KEY);"), + ("0002_create_products.sql", "0002", "CREATE TABLE 
products (id INTEGER PRIMARY KEY);"), + ] + + for filename, version, sql in migrations: + table_name = filename.split("_", 2)[2].rsplit(".", 1)[0] + content = f"""-- name: migrate-{version}-up +{sql} + +-- name: migrate-{version}-down +DROP TABLE {table_name}; +""" + (migrations_dir / filename).write_text(content) + + commands = SyncMigrationCommands(sqlite_config) + + commands.upgrade() + assert commands.current() == "0002" + + commands.downgrade(revision="base") + assert commands.current() is None + + +def test_upgrade_with_extension_migrations(sqlite_config: SqliteConfig, migrations_dir: Path) -> None: + """Test upgrade works with extension-prefixed versions.""" + migrations = [ + ("0001_core_init.sql", "0001", "CREATE TABLE core (id INTEGER PRIMARY KEY);"), + ("ext_litestar_0001_init.sql", "ext_litestar_0001", "CREATE TABLE litestar_ext (id INTEGER PRIMARY KEY);"), + ("0002_core_users.sql", "0002", "CREATE TABLE users (id INTEGER PRIMARY KEY);"), + ] + + for filename, version, sql in migrations: + content = f"""-- name: migrate-{version}-up +{sql} + +-- name: migrate-{version}-down +DROP TABLE {filename.split("_")[1]}; +""" + (migrations_dir / filename).write_text(content) + + commands = SyncMigrationCommands(sqlite_config) + + commands.upgrade() + + with sqlite_config.provide_session() as session: + applied = commands.tracker.get_applied_migrations(session) + + assert len(applied) == 3 + assert applied[0]["version_num"] == "0001" + assert applied[1]["version_num"] == "0002" + assert applied[2]["version_num"] == "ext_litestar_0001" + + +def test_upgrade_respects_version_comparison_order(sqlite_config: SqliteConfig, migrations_dir: Path) -> None: + """Test upgrade applies migrations in correct version comparison order.""" + migrations = [ + ("0001_init.sql", "0001", "CREATE TABLE init (id INTEGER PRIMARY KEY);"), + ("9999_large_seq.sql", "9999", "CREATE TABLE large_seq (id INTEGER PRIMARY KEY);"), + ("20200101000000_early_timestamp.sql", "20200101000000", "CREATE TABLE early (id INTEGER PRIMARY KEY);"), + ("20251011120000_late_timestamp.sql", "20251011120000", "CREATE TABLE late (id INTEGER PRIMARY KEY);"), + ("ext_aaa_0001_ext_a.sql", "ext_aaa_0001", "CREATE TABLE ext_a (id INTEGER PRIMARY KEY);"), + ("ext_zzz_0001_ext_z.sql", "ext_zzz_0001", "CREATE TABLE ext_z (id INTEGER PRIMARY KEY);"), + ] + + for filename, version, sql in migrations: + content = f"""-- name: migrate-{version}-up +{sql} + +-- name: migrate-{version}-down +DROP TABLE {filename.split("_")[0]}; +""" + (migrations_dir / filename).write_text(content) + + commands = SyncMigrationCommands(sqlite_config) + + commands.upgrade() + + with sqlite_config.provide_session() as session: + applied = commands.tracker.get_applied_migrations(session) + + applied_versions = [m["version_num"] for m in applied] + + expected_order = ["0001", "9999", "20200101000000", "20251011120000", "ext_aaa_0001", "ext_zzz_0001"] + + assert applied_versions == expected_order + + +def test_upgrade_dry_run_shows_pending_migrations(sqlite_config: SqliteConfig, migrations_dir: Path) -> None: + """Test dry run mode shows what would be applied without making changes.""" + migrations = [ + ("0001_create_users.sql", "0001", "CREATE TABLE users (id INTEGER PRIMARY KEY);"), + ("0002_create_products.sql", "0002", "CREATE TABLE products (id INTEGER PRIMARY KEY);"), + ] + + for filename, version, sql in migrations: + content = f"""-- name: migrate-{version}-up +{sql} + +-- name: migrate-{version}-down +DROP TABLE {filename.split("_")[1]}; +""" + (migrations_dir / 
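The `expected_order` above implies a three-tier comparison: core sequential versions first, then timestamps, then `ext_*` versions lexicographically. Because a 14-digit timestamp always exceeds a zero-padded sequence number, one sort key consistent with that expected order is the sketch below (an illustration, not the actual implementation):

```python
def version_sort_key(version: str) -> tuple[int, str]:
    """Order core versions numerically, extension versions after them."""
    if version.startswith("ext_"):
        return (1, version)  # extensions sort lexicographically after all core versions
    return (0, version.zfill(14))  # "0001" -> "00000000000001", so "9999" < "20200101000000"

# e.g. sorted(versions, key=version_sort_key) reproduces expected_order above
```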
filename).write_text(content) + + commands = SyncMigrationCommands(sqlite_config) + + commands.upgrade(dry_run=True) + + current = commands.current() + assert current is None + + with sqlite_config.provide_session() as session: + result = session.execute("SELECT name FROM sqlite_master WHERE type='table' AND name IN ('users', 'products')") + tables = [row["name"] for row in result.data or []] + + assert len(tables) == 0 + + +def test_downgrade_dry_run_shows_pending_downgrades(sqlite_config: SqliteConfig, migrations_dir: Path) -> None: + """Test downgrade dry run shows what would be reverted without making changes.""" + migrations = [ + ("0001_create_users.sql", "0001", "CREATE TABLE users (id INTEGER PRIMARY KEY);"), + ("0002_create_products.sql", "0002", "CREATE TABLE products (id INTEGER PRIMARY KEY);"), + ] + + for filename, version, sql in migrations: + table_name = filename.split("_", 2)[2].rsplit(".", 1)[0] + content = f"""-- name: migrate-{version}-up +{sql} + +-- name: migrate-{version}-down +DROP TABLE {table_name}; +""" + (migrations_dir / filename).write_text(content) + + commands = SyncMigrationCommands(sqlite_config) + + commands.upgrade() + assert commands.current() == "0002" + + commands.downgrade(dry_run=True) + + current = commands.current() + assert current == "0002" + + with sqlite_config.provide_session() as session: + result = session.execute("SELECT name FROM sqlite_master WHERE type='table' AND name IN ('users', 'products')") + tables = [row["name"] for row in result.data or []] + + assert "users" in tables + assert "products" in tables diff --git a/tests/unit/test_cli/test_migration_commands.py b/tests/unit/test_cli/test_migration_commands.py index 211c77f6..36bfd646 100644 --- a/tests/unit/test_cli/test_migration_commands.py +++ b/tests/unit/test_cli/test_migration_commands.py @@ -386,7 +386,7 @@ def get_config(): os.chdir(original_dir) assert result.exit_code == 0 - mock_commands.upgrade.assert_called_once_with(revision="head", dry_run=False) + mock_commands.upgrade.assert_called_once_with(revision="head", auto_sync=True, dry_run=False) @patch("sqlspec.migrations.commands.create_migration_commands") @@ -421,7 +421,7 @@ def get_config(): os.chdir(original_dir) assert result.exit_code == 0 - mock_commands.upgrade.assert_called_once_with(revision="abc123", dry_run=False) + mock_commands.upgrade.assert_called_once_with(revision="abc123", auto_sync=True, dry_run=False) @patch("sqlspec.migrations.commands.create_migration_commands") @@ -775,4 +775,4 @@ def get_multi_configs(): assert result.exit_code == 0 # Should only process the analytics_db config - mock_commands.upgrade.assert_called_once_with(revision="head", dry_run=False) + mock_commands.upgrade.assert_called_once_with(revision="head", auto_sync=True, dry_run=False) diff --git a/tests/unit/test_migrations/test_checksum_canonicalization.py b/tests/unit/test_migrations/test_checksum_canonicalization.py new file mode 100644 index 00000000..3c5ca3b4 --- /dev/null +++ b/tests/unit/test_migrations/test_checksum_canonicalization.py @@ -0,0 +1,345 @@ +"""Unit tests for canonicalized checksum computation.""" + +# pyright: reportPrivateUsage=false + +from pathlib import Path + +import pytest + +from sqlspec.migrations.runner import SyncMigrationRunner + + +@pytest.fixture +def temp_migrations_dir(tmp_path: Path) -> Path: + """Create temporary migrations directory.""" + migrations_dir = tmp_path / "migrations" + migrations_dir.mkdir() + return migrations_dir + + +def 
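The checksum tests that follow define a canonicalization: `migrate-*` query-name headers are ignored, so converting a version never changes a migration's checksum, while ordinary comments, quoted strings, and non-migrate headers still count. One regex-based canonicalization consistent with every case below, offered as a sketch only (`_calculate_checksum` itself may be implemented differently):

```python
import hashlib
import re

# Matches only whole header lines like "-- name: migrate-<version>-up",
# including the no-space "--name:migrate-..." variant; comment lines that
# merely mention a migrate-* name, and quoted strings, do not match.
_MIGRATE_HEADER = re.compile(r"^--[ \t]*name:[ \t]*migrate-\S+-(?:up|down)[ \t]*$", re.MULTILINE)


def canonical_checksum(content: str) -> str:
    canonical = _MIGRATE_HEADER.sub("", content).strip()
    return hashlib.md5(canonical.encode()).hexdigest()
```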
test_checksum_excludes_timestamp_version_up_header(temp_migrations_dir: Path) -> None: + """Test checksum excludes timestamp version up query name header.""" + runner = SyncMigrationRunner(temp_migrations_dir) + + content = """-- name: migrate-20251011120000-up +CREATE TABLE users ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL +); +""" + + checksum = runner._calculate_checksum(content) + + content_without_header = """ +CREATE TABLE users ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL +); +""" + + checksum_without_header = runner._calculate_checksum(content_without_header) + + assert checksum == checksum_without_header + + +def test_checksum_excludes_timestamp_version_down_header(temp_migrations_dir: Path) -> None: + """Test checksum excludes timestamp version down query name header.""" + runner = SyncMigrationRunner(temp_migrations_dir) + + content = """-- name: migrate-20251011120000-down +DROP TABLE users; +""" + + checksum = runner._calculate_checksum(content) + + content_without_header = """ +DROP TABLE users; +""" + + checksum_without_header = runner._calculate_checksum(content_without_header) + + assert checksum == checksum_without_header + + +def test_checksum_excludes_sequential_version_up_header(temp_migrations_dir: Path) -> None: + """Test checksum excludes sequential version up query name header.""" + runner = SyncMigrationRunner(temp_migrations_dir) + + content = """-- name: migrate-0001-up +CREATE TABLE users ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL +); +""" + + checksum = runner._calculate_checksum(content) + + content_without_header = """ +CREATE TABLE users ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL +); +""" + + checksum_without_header = runner._calculate_checksum(content_without_header) + + assert checksum == checksum_without_header + + +def test_checksum_excludes_sequential_version_down_header(temp_migrations_dir: Path) -> None: + """Test checksum excludes sequential version down query name header.""" + runner = SyncMigrationRunner(temp_migrations_dir) + + content = """-- name: migrate-0001-down +DROP TABLE users; +""" + + checksum = runner._calculate_checksum(content) + + content_without_header = """ +DROP TABLE users; +""" + + checksum_without_header = runner._calculate_checksum(content_without_header) + + assert checksum == checksum_without_header + + +def test_checksum_includes_actual_sql_content(temp_migrations_dir: Path) -> None: + """Test checksum includes actual SQL statements.""" + runner = SyncMigrationRunner(temp_migrations_dir) + + content1 = """-- name: migrate-20251011120000-up +CREATE TABLE users (id INTEGER); +""" + + content2 = """-- name: migrate-20251011120000-up +CREATE TABLE products (id INTEGER); +""" + + checksum1 = runner._calculate_checksum(content1) + checksum2 = runner._calculate_checksum(content2) + + assert checksum1 != checksum2 + + +def test_checksum_stable_after_version_conversion(temp_migrations_dir: Path) -> None: + """Test checksum remains stable when converting timestamp to sequential.""" + runner = SyncMigrationRunner(temp_migrations_dir) + + timestamp_content = """-- name: migrate-20251011120000-up +CREATE TABLE users ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL +); + +-- name: migrate-20251011120000-down +DROP TABLE users; +""" + + sequential_content = """-- name: migrate-0001-up +CREATE TABLE users ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL +); + +-- name: migrate-0001-down +DROP TABLE users; +""" + + timestamp_checksum = runner._calculate_checksum(timestamp_content) + sequential_checksum = 
runner._calculate_checksum(sequential_content)
+
+    assert timestamp_checksum == sequential_checksum
+
+
+def test_checksum_handles_multiple_query_headers(temp_migrations_dir: Path) -> None:
+    """Test checksum excludes all migrate-* query headers in file."""
+    runner = SyncMigrationRunner(temp_migrations_dir)
+
+    content = """-- name: migrate-20251011120000-up
+CREATE TABLE users (id INTEGER);
+
+-- name: migrate-20251011120000-down
+DROP TABLE users;
+"""
+
+    expected_content = """
+CREATE TABLE users (id INTEGER);
+
+
+DROP TABLE users;
+"""
+
+    checksum = runner._calculate_checksum(content)
+    expected_checksum = runner._calculate_checksum(expected_content)
+
+    assert checksum == expected_checksum
+
+
+def test_checksum_preserves_non_migration_name_headers(temp_migrations_dir: Path) -> None:
+    """Test checksum preserves non-migrate query name headers."""
+    runner = SyncMigrationRunner(temp_migrations_dir)
+
+    content1 = """-- name: get-users
+SELECT * FROM users;
+"""
+
+    content2 = """
+SELECT * FROM users;
+"""
+
+    checksum1 = runner._calculate_checksum(content1)
+    checksum2 = runner._calculate_checksum(content2)
+
+    assert checksum1 != checksum2
+
+
+def test_checksum_handles_whitespace_variations(temp_migrations_dir: Path) -> None:
+    """Test checksum handles variations in whitespace around name header."""
+    runner = SyncMigrationRunner(temp_migrations_dir)
+
+    content1 = """-- name: migrate-20251011120000-up
+CREATE TABLE users (id INTEGER);
+"""
+
+    content2 = """-- name:  migrate-20251011120000-up
+CREATE TABLE users (id INTEGER);
+"""
+
+    content3 = """--name:migrate-20251011120000-up
+CREATE TABLE users (id INTEGER);
+"""
+
+    checksum1 = runner._calculate_checksum(content1)
+    checksum2 = runner._calculate_checksum(content2)
+    checksum3 = runner._calculate_checksum(content3)
+
+    assert checksum1 == checksum2 == checksum3
+
+
+def test_checksum_handles_extension_versions(temp_migrations_dir: Path) -> None:
+    """Test checksum excludes extension version headers."""
+    runner = SyncMigrationRunner(temp_migrations_dir)
+
+    timestamp_content = """-- name: migrate-ext_litestar_20251011120000-up
+CREATE TABLE sessions (id INTEGER);
+"""
+
+    sequential_content = """-- name: migrate-ext_litestar_0001-up
+CREATE TABLE sessions (id INTEGER);
+"""
+
+    timestamp_checksum = runner._calculate_checksum(timestamp_content)
+    sequential_checksum = runner._calculate_checksum(sequential_content)
+
+    assert timestamp_checksum == sequential_checksum
+
+
+def test_checksum_empty_file(temp_migrations_dir: Path) -> None:
+    """Test checksum computation for empty file."""
+    runner = SyncMigrationRunner(temp_migrations_dir)
+
+    checksum = runner._calculate_checksum("")
+
+    assert isinstance(checksum, str)
+    assert len(checksum) == 32
+
+
+def test_checksum_only_headers(temp_migrations_dir: Path) -> None:
+    """Test checksum when file contains only query name headers."""
+    runner = SyncMigrationRunner(temp_migrations_dir)
+
+    content = """-- name: migrate-20251011120000-up
+-- name: migrate-20251011120000-down
+"""
+
+    checksum = runner._calculate_checksum(content)
+
+    empty_checksum = runner._calculate_checksum("\n")
+
+    assert checksum == empty_checksum
+
+
+def test_checksum_preserves_sql_comments(temp_migrations_dir: Path) -> None:
+    """Test checksum includes regular SQL comments."""
+    runner = SyncMigrationRunner(temp_migrations_dir)
+
+    content1 = """-- name: migrate-20251011120000-up
+-- This is a regular comment
+CREATE TABLE users (id INTEGER);
+"""
+
+    content2 = """-- name: migrate-20251011120000-up
+CREATE TABLE users (id INTEGER); +""" + + checksum1 = runner._calculate_checksum(content1) + checksum2 = runner._calculate_checksum(content2) + + assert checksum1 != checksum2 + + +def test_checksum_case_sensitive(temp_migrations_dir: Path) -> None: + """Test checksum is case-sensitive for SQL content.""" + runner = SyncMigrationRunner(temp_migrations_dir) + + content1 = """-- name: migrate-20251011120000-up +CREATE TABLE users (id INTEGER); +""" + + content2 = """-- name: migrate-20251011120000-up +create table users (id integer); +""" + + checksum1 = runner._calculate_checksum(content1) + checksum2 = runner._calculate_checksum(content2) + + assert checksum1 != checksum2 + + +def test_checksum_detects_content_changes(temp_migrations_dir: Path) -> None: + """Test checksum changes when SQL content is modified.""" + runner = SyncMigrationRunner(temp_migrations_dir) + + original = """-- name: migrate-20251011120000-up +CREATE TABLE users ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL +); +""" + + modified = """-- name: migrate-20251011120000-up +CREATE TABLE users ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + email TEXT +); +""" + + original_checksum = runner._calculate_checksum(original) + modified_checksum = runner._calculate_checksum(modified) + + assert original_checksum != modified_checksum + + +def test_checksum_regex_pattern_matches_correctly(temp_migrations_dir: Path) -> None: + """Test regex pattern only matches actual query name headers.""" + runner = SyncMigrationRunner(temp_migrations_dir) + + content_with_similar_text = """-- name: migrate-20251011120000-up +-- This comment mentions migrate-20251011120000-up but shouldn't be removed +SELECT 'migrate-20251011120000-up' as text; +CREATE TABLE users (id INTEGER); +""" + + content_header_only_removed = """ +-- This comment mentions migrate-20251011120000-up but shouldn't be removed +SELECT 'migrate-20251011120000-up' as text; +CREATE TABLE users (id INTEGER); +""" + + checksum = runner._calculate_checksum(content_with_similar_text) + expected_checksum = runner._calculate_checksum(content_header_only_removed) + + assert checksum == expected_checksum diff --git a/tests/unit/test_migrations/test_fix_regex_precision.py b/tests/unit/test_migrations/test_fix_regex_precision.py new file mode 100644 index 00000000..2bf405f4 --- /dev/null +++ b/tests/unit/test_migrations/test_fix_regex_precision.py @@ -0,0 +1,368 @@ +"""Unit tests for version-specific regex patterns in fix operations.""" + +import re +from pathlib import Path + +import pytest + +from sqlspec.migrations.fix import MigrationFixer + + +@pytest.fixture +def temp_migrations_dir(tmp_path: Path) -> Path: + """Create temporary migrations directory.""" + migrations_dir = tmp_path / "migrations" + migrations_dir.mkdir() + return migrations_dir + + +def test_update_file_content_only_replaces_specific_version(temp_migrations_dir: Path) -> None: + """Test only specific old_version is replaced, not other versions.""" + fixer = MigrationFixer(temp_migrations_dir) + + file_path = temp_migrations_dir / "20251011120000_test.sql" + content = """-- name: migrate-20251011120000-up +CREATE TABLE users (id INTEGER); + +-- name: migrate-20251011120000-down +DROP TABLE users; + +-- This comment mentions migrate-20251012130000-up +-- And also migrate-0001-up +""" + + file_path.write_text(content) + + fixer.update_file_content(file_path, "20251011120000", "0001") + + updated = file_path.read_text() + + assert "-- name: migrate-0001-up" in updated + assert "-- name: migrate-0001-down" in 
updated
+    assert "migrate-20251012130000-up" in updated
+    assert "And also migrate-0001-up" in updated
+
+
+def test_update_file_content_handles_special_regex_characters(temp_migrations_dir: Path) -> None:
+    """Test version strings with special regex characters are escaped."""
+    fixer = MigrationFixer(temp_migrations_dir)
+
+    file_path = temp_migrations_dir / "test.sql"
+
+    version_with_dots = "1.2.3"
+    content = f"""-- name: migrate-{version_with_dots}-up
+CREATE TABLE test (id INTEGER);
+"""
+
+    file_path.write_text(content)
+
+    fixer.update_file_content(file_path, version_with_dots, "0001")
+
+    updated = file_path.read_text()
+    assert "-- name: migrate-0001-up" in updated
+
+
+def test_update_file_content_does_not_replace_unrelated_migrate_patterns(temp_migrations_dir: Path) -> None:
+    """Test unrelated migrate-* patterns are not replaced."""
+    fixer = MigrationFixer(temp_migrations_dir)
+
+    file_path = temp_migrations_dir / "20251011120000_test.sql"
+    content = """-- name: migrate-20251011120000-up
+CREATE TABLE users (id INTEGER);
+
+-- name: migrate-other-pattern-up
+-- This should not be touched
+
+-- name: migrate-20251011120000-down
+DROP TABLE users;
+"""
+
+    file_path.write_text(content)
+
+    fixer.update_file_content(file_path, "20251011120000", "0001")
+
+    updated = file_path.read_text()
+
+    assert "-- name: migrate-0001-up" in updated
+    assert "-- name: migrate-0001-down" in updated
+    assert "-- name: migrate-other-pattern-up" in updated
+
+
+def test_update_file_content_extension_version_pattern(temp_migrations_dir: Path) -> None:
+    """Test extension version patterns are handled correctly."""
+    fixer = MigrationFixer(temp_migrations_dir)
+
+    file_path = temp_migrations_dir / "ext_litestar_20251011120000_test.py"
+    content = '''"""Migration file."""
+# This references migrate-ext_litestar_20251011120000-up in a comment
+'''
+
+    file_path.write_text(content)
+
+    fixer.update_file_content(file_path, "ext_litestar_20251011120000", "ext_litestar_0001")
+
+    updated = file_path.read_text()
+
+    assert content == updated
+
+
+def test_regex_pattern_matches_exact_version_only(temp_migrations_dir: Path) -> None:
+    """Test regex pattern construction matches exact version only."""
+    MigrationFixer(temp_migrations_dir)
+
+    old_version = "20251011120000"
+    up_pattern = re.compile(rf"(-- name:\s+migrate-){re.escape(old_version)}(-up)")
+
+    test_cases = [
+        ("-- name: migrate-20251011120000-up", True),
+        ("-- name:  migrate-20251011120000-up", True),
+        ("-- name:migrate-20251011120000-up", False),
+        ("-- name: migrate-20251011120000-down", False),
+        ("-- name: migrate-2025101112000-up", False),
+        ("-- name: migrate-202510111200000-up", False),
+        ("-- name: migrate-20251011120001-up", False),
+        ("-- name: migrate-0001-up", False),
+        ("migrate-20251011120000-up", False),
+    ]
+
+    for text, should_match in test_cases:
+        match = up_pattern.search(text)
+        if should_match:
+            assert match is not None, f"Expected match for: {text}"
+        else:
+            assert match is None, f"Should not match: {text}"
+
+
+def test_regex_pattern_handles_down_direction(temp_migrations_dir: Path) -> None:
+    """Test regex pattern correctly handles down direction."""
+    old_version = "20251011120000"
+    down_pattern = re.compile(rf"(-- name:\s+migrate-){re.escape(old_version)}(-down)")
+
+    test_cases = [
+        ("-- name: migrate-20251011120000-down", True),
+        ("-- name:  migrate-20251011120000-down", True),
+        ("-- name:migrate-20251011120000-down", False),
+        ("-- name: migrate-20251011120000-up", False),
+        ("-- name: 
migrate-20251011120001-down", False), + ] + + for text, should_match in test_cases: + match = down_pattern.search(text) + if should_match: + assert match is not None, f"Expected match for: {text}" + else: + assert match is None, f"Should not match: {text}" + + +def test_update_file_content_multiple_versions_in_file(temp_migrations_dir: Path) -> None: + """Test file with multiple different version references only updates target.""" + fixer = MigrationFixer(temp_migrations_dir) + + file_path = temp_migrations_dir / "20251011120000_test.sql" + content = """-- name: migrate-20251011120000-up +-- Migration depends on migrate-20251010110000-up being applied first +CREATE TABLE users (id INTEGER); + +-- name: migrate-20251011120000-down +-- Reverses the changes +DROP TABLE users; +""" + + file_path.write_text(content) + + fixer.update_file_content(file_path, "20251011120000", "0002") + + updated = file_path.read_text() + + assert "-- name: migrate-0002-up" in updated + assert "-- name: migrate-0002-down" in updated + assert "migrate-20251010110000-up" in updated + assert "Reverses the changes" in updated + + +def test_regex_escape_prevents_version_injection(temp_migrations_dir: Path) -> None: + """Test re.escape prevents regex injection in version strings.""" + malicious_version = "20251011.*" + + escaped = re.escape(malicious_version) + + pattern = re.compile(rf"(-- name:\s+migrate-){escaped}(-up)") + + should_not_match = [ + "-- name: migrate-20251011120000-up", + "-- name: migrate-20251011999999-up", + "-- name: migrate-20251011-up", + ] + + for text in should_not_match: + assert pattern.search(text) is None + + should_match = "-- name: migrate-20251011.*-up" + assert pattern.search(should_match) is not None + + +def test_update_file_content_preserves_line_endings(temp_migrations_dir: Path) -> None: + """Test file content update preserves line endings.""" + fixer = MigrationFixer(temp_migrations_dir) + + file_path = temp_migrations_dir / "20251011120000_test.sql" + content = "-- name: migrate-20251011120000-up\nCREATE TABLE users (id INTEGER);\n" + + file_path.write_text(content) + + fixer.update_file_content(file_path, "20251011120000", "0001") + + updated = file_path.read_text() + + assert "\n" in updated + assert updated.endswith("\n") + + +def test_update_file_content_handles_no_matches(temp_migrations_dir: Path) -> None: + """Test update when version does not appear in file.""" + fixer = MigrationFixer(temp_migrations_dir) + + file_path = temp_migrations_dir / "20251011120000_test.sql" + original_content = """-- name: migrate-20251012130000-up +CREATE TABLE products (id INTEGER); +""" + + file_path.write_text(original_content) + + fixer.update_file_content(file_path, "20251011120000", "0001") + + updated = file_path.read_text() + + assert updated == original_content + + +def test_update_file_content_complex_sql_not_affected(temp_migrations_dir: Path) -> None: + """Test complex SQL content is not affected by version replacement.""" + fixer = MigrationFixer(temp_migrations_dir) + + file_path = temp_migrations_dir / "20251011120000_test.sql" + content = """-- name: migrate-20251011120000-up +CREATE TABLE logs ( + id INTEGER PRIMARY KEY, + message TEXT CHECK (message != 'migrate-20251011120000-up'), + pattern VARCHAR(100) DEFAULT '-- name: migrate-pattern-up' +); + +INSERT INTO logs (message) VALUES ('Testing migrate-20251011120000-up reference'); + +-- name: migrate-20251011120000-down +DROP TABLE logs; +""" + + file_path.write_text(content) + + fixer.update_file_content(file_path, 
"20251011120000", "0001") + + updated = file_path.read_text() + + assert "-- name: migrate-0001-up" in updated + assert "-- name: migrate-0001-down" in updated + assert "message != 'migrate-20251011120000-up'" in updated + assert "pattern VARCHAR(100) DEFAULT '-- name: migrate-pattern-up'" in updated + assert "VALUES ('Testing migrate-20251011120000-up reference')" in updated + + +def test_update_file_content_timestamp_vs_sequential_collision(temp_migrations_dir: Path) -> None: + """Test version replacement doesn't confuse timestamp and sequential formats.""" + fixer = MigrationFixer(temp_migrations_dir) + + file_path = temp_migrations_dir / "20251011120000_test.sql" + content = """-- name: migrate-20251011120000-up +-- This migration comes after migrate-0001-up +CREATE TABLE users (id INTEGER); + +-- name: migrate-20251011120000-down +DROP TABLE users; +""" + + file_path.write_text(content) + + fixer.update_file_content(file_path, "20251011120000", "0002") + + updated = file_path.read_text() + + assert "-- name: migrate-0002-up" in updated + assert "-- name: migrate-0002-down" in updated + assert "after migrate-0001-up" in updated + + +def test_regex_pattern_boundary_conditions(temp_migrations_dir: Path) -> None: + """Test regex pattern handles boundary conditions correctly.""" + test_cases = [ + ("0001", "0002"), + ("9999", "10000"), + ("20251011120000", "0001"), + ("ext_litestar_0001", "ext_litestar_0002"), + ("ext_adk_20251011120000", "ext_adk_0001"), + ] + + for old_version, new_version in test_cases: + up_pattern = re.compile(rf"(-- name:\s+migrate-){re.escape(old_version)}(-up)") + down_pattern = re.compile(rf"(-- name:\s+migrate-){re.escape(old_version)}(-down)") + + test_line_up = f"-- name: migrate-{old_version}-up" + test_line_down = f"-- name: migrate-{old_version}-down" + + assert up_pattern.search(test_line_up) is not None + assert down_pattern.search(test_line_down) is not None + + replaced_up = up_pattern.sub(rf"\g<1>{new_version}\g<2>", test_line_up) + replaced_down = down_pattern.sub(rf"\g<1>{new_version}\g<2>", test_line_down) + + assert replaced_up == f"-- name: migrate-{new_version}-up" + assert replaced_down == f"-- name: migrate-{new_version}-down" + + +def test_update_file_content_updates_version_comment(temp_migrations_dir: Path) -> None: + """Test version comment in migration header is updated.""" + fixer = MigrationFixer(temp_migrations_dir) + + file_path = temp_migrations_dir / "20251019204218_create_products.sql" + content = """-- SQLSpec Migration +-- Version: 20251019204218 +-- Description: create products table +-- Created: 2025-10-19T20:42:18.123456 + +-- name: migrate-20251019204218-up +CREATE TABLE products (id INTEGER PRIMARY KEY); + +-- name: migrate-20251019204218-down +DROP TABLE products; +""" + + file_path.write_text(content) + + fixer.update_file_content(file_path, "20251019204218", "0001") + + updated = file_path.read_text() + + assert "-- Version: 0001" in updated + assert "-- Version: 20251019204218" not in updated + assert "-- name: migrate-0001-up" in updated + assert "-- name: migrate-0001-down" in updated + assert "-- Description: create products table" in updated + + +def test_update_file_content_preserves_version_comment_format(temp_migrations_dir: Path) -> None: + """Test version comment format is preserved during update.""" + fixer = MigrationFixer(temp_migrations_dir) + + file_path = temp_migrations_dir / "0001_initial.sql" + content = """-- SQLSpec Migration +-- Version: 0001 +-- Another comment +""" + + file_path.write_text(content) + 
+ fixer.update_file_content(file_path, "0001", "0002") + + updated = file_path.read_text() + + assert "-- Version: 0002" in updated + assert "-- Another comment" in updated diff --git a/tests/unit/test_migrations/test_migration_commands.py b/tests/unit/test_migrations/test_migration_commands.py index 686560b3..ec0c03a7 100644 --- a/tests/unit/test_migrations/test_migration_commands.py +++ b/tests/unit/test_migrations/test_migration_commands.py @@ -322,6 +322,7 @@ async def test_async_upgrade_empty_migration_folder(async_config: AiosqliteConfi commands = AsyncMigrationCommands(async_config) mock_driver = AsyncMock() + mock_driver.driver_features = {} with ( patch.object(async_config, "provide_session") as mock_session, patch("sqlspec.migrations.commands.console") as mock_console, @@ -350,7 +351,7 @@ def test_sync_upgrade_already_at_latest_version(sync_config: SqliteConfig) -> No patch.object(sync_config, "provide_session") as mock_session, patch("sqlspec.migrations.commands.console") as mock_console, patch.object(commands.runner, "get_migration_files", return_value=[("0001", mock_migration_file)]), - patch.object(commands.tracker, "get_current_version", return_value="0001"), + patch.object(commands.tracker, "get_applied_migrations", return_value=[{"version_num": "0001"}]), ): mock_session.return_value.__enter__.return_value = mock_driver @@ -367,13 +368,14 @@ async def test_async_upgrade_already_at_latest_version(async_config: AiosqliteCo commands = AsyncMigrationCommands(async_config) mock_driver = AsyncMock() + mock_driver.driver_features = {} mock_migration_file = Path("/fake/migrations/0001_initial.sql") with ( patch.object(async_config, "provide_session") as mock_session, patch("sqlspec.migrations.commands.console") as mock_console, patch.object(commands.runner, "get_migration_files", return_value=[("0001", mock_migration_file)]), - patch.object(commands.tracker, "get_current_version", return_value="0001"), + patch.object(commands.tracker, "get_applied_migrations", return_value=[{"version_num": "0001"}]), ): mock_session.return_value.__aenter__.return_value = mock_driver diff --git a/tests/unit/test_migrations/test_migration_execution.py b/tests/unit/test_migrations/test_migration_execution.py index 61887f9f..d68f3f0f 100644 --- a/tests/unit/test_migrations/test_migration_execution.py +++ b/tests/unit/test_migrations/test_migration_execution.py @@ -167,7 +167,7 @@ def test_applied_migrations_sql_generation() -> None: assert "*" in stmt.sql assert "test_migrations" in stmt.sql.lower() assert "ORDER BY" in stmt.sql.upper() - assert "version_num" in stmt.sql.lower() + assert "execution_sequence" in stmt.sql.lower() def test_record_migration_sql_generation() -> None: @@ -175,7 +175,13 @@ def test_record_migration_sql_generation() -> None: tracker = MockMigrationTracker("test_migrations") record_sql = tracker._get_record_migration_sql( - version="0001", description="test migration", execution_time_ms=250, checksum="abc123", applied_by="test_user" + version="0001", + version_type="sequential", + execution_sequence=1, + description="test migration", + execution_time_ms=250, + checksum="abc123", + applied_by="test_user", ) assert hasattr(record_sql, "to_statement") diff --git a/tests/unit/test_migrations/test_migration_runner.py b/tests/unit/test_migrations/test_migration_runner.py index e9cd97fe..0394ee0d 100644 --- a/tests/unit/test_migrations/test_migration_runner.py +++ b/tests/unit/test_migrations/test_migration_runner.py @@ -567,12 +567,250 @@ def test_sql_loader_caches_files() -> None: 
sql_loader = SQLFileLoader() async def test_operations() -> None: - await sql_loader.get_up_sql(migration_file) + sql_loader.validate_migration_file(migration_file) path_str = str(migration_file) assert path_str in sql_loader.sql_loader._files assert sql_loader.sql_loader.has_query("migrate-0001-up") assert sql_loader.sql_loader.has_query("migrate-0001-down") + + await sql_loader.get_up_sql(migration_file) + assert path_str in sql_loader.sql_loader._files + await sql_loader.get_down_sql(migration_file) assert path_str in sql_loader.sql_loader._files asyncio.run(test_operations()) + + +def test_no_duplicate_loading_during_migration_execution() -> None: + """Test that SQL files are loaded exactly once during migration execution. + + Verifies fix for issue #118 - validates that running a migration loads + the SQL file only once, not multiple times. Checks that the file is in + the loader's cache after validation and remains there throughout the workflow. + """ + import asyncio + + from sqlspec.migrations.loaders import SQLFileLoader + + with tempfile.TemporaryDirectory() as temp_dir: + migrations_path = Path(temp_dir) + + migration_file = migrations_path / "0001_create_users.sql" + migration_content = """ +-- name: migrate-0001-up +CREATE TABLE users ( + id INTEGER PRIMARY KEY, + username TEXT NOT NULL +); + +-- name: migrate-0001-down +DROP TABLE users; +""" + migration_file.write_text(migration_content) + + sql_loader = SQLFileLoader() + + async def test_migration_workflow() -> None: + sql_loader.validate_migration_file(migration_file) + + path_str = str(migration_file) + assert path_str in sql_loader.sql_loader._files, "File should be loaded after validation" + assert sql_loader.sql_loader.has_query("migrate-0001-up") + assert sql_loader.sql_loader.has_query("migrate-0001-down") + + file_count_after_validation = len(sql_loader.sql_loader._files) + + await sql_loader.get_up_sql(migration_file) + file_count_after_up = len(sql_loader.sql_loader._files) + assert file_count_after_validation == file_count_after_up, "get_up_sql should not load additional files" + + await sql_loader.get_down_sql(migration_file) + file_count_after_down = len(sql_loader.sql_loader._files) + assert file_count_after_up == file_count_after_down, "get_down_sql should not load additional files" + + asyncio.run(test_migration_workflow()) + + +def test_sql_file_loader_counter_accuracy_single_file() -> None: + """Test SQLFileLoader caching behavior for single file loading. + + Verifies fix for issue #118 (Solution 2) - ensures that load_sql() + properly caches files. First call should load and parse the file, + second call should return immediately from cache without reparsing. 
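+
+ The test inspects the loader's internal _files and _queries caches directly.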
+ """ + from sqlspec.loader import SQLFileLoader + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + test_file = temp_path / "test_queries.sql" + test_content = """ +-- name: get_user +SELECT * FROM users WHERE id = :id; + +-- name: list_users +SELECT * FROM users; + +-- name: delete_user +DELETE FROM users WHERE id = :id; +""" + test_file.write_text(test_content) + + loader = SQLFileLoader() + + loader.load_sql(test_file) + path_str = str(test_file) + assert path_str in loader._files, "First load should add file to cache" + assert len(loader._queries) == 3, "First load should parse 3 queries" + + query_count_before_reload = len(loader._queries) + file_count_before_reload = len(loader._files) + + loader.load_sql(test_file) + + assert len(loader._queries) == query_count_before_reload, "Second load should not add new queries (cached)" + assert len(loader._files) == file_count_before_reload, "Second load should not add new files (cached)" + + +def test_sql_file_loader_counter_accuracy_directory() -> None: + """Test SQLFileLoader caching behavior for directory loading. + + Verifies that _load_directory() properly caches files and doesn't + reload them on subsequent calls. + """ + from sqlspec.loader import SQLFileLoader + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + file1 = temp_path / "queries1.sql" + file1.write_text(""" +-- name: query1 +SELECT 1; +""") + + file2 = temp_path / "queries2.sql" + file2.write_text(""" +-- name: query2 +SELECT 2; +""") + + loader = SQLFileLoader() + + loader.load_sql(temp_path) + assert len(loader._files) == 2, "First load should add 2 files to cache" + assert len(loader._queries) == 2, "First load should parse 2 queries" + + query_count_before_reload = len(loader._queries) + file_count_before_reload = len(loader._files) + + loader.load_sql(temp_path) + + assert len(loader._queries) == query_count_before_reload, "Second load should not add new queries (all cached)" + assert len(loader._files) == file_count_before_reload, "Second load should not add new files (all cached)" + + +def test_migration_workflow_single_load_design() -> None: + """Test that migration workflow respects single-load design. + + Verifies fix for issue #118 (Solution 1) - confirms that: + 1. validate_migration_file() loads the file and parses queries + 2. get_up_sql() retrieves queries WITHOUT reloading the file + 3. get_down_sql() retrieves queries WITHOUT reloading the file + + All three operations should use the same cached file. 
+ """ + import asyncio + + from sqlspec.migrations.loaders import SQLFileLoader + + with tempfile.TemporaryDirectory() as temp_dir: + migrations_path = Path(temp_dir) + + migration_file = migrations_path / "0001_test.sql" + migration_content = """ +-- name: migrate-0001-up +CREATE TABLE test_table (id INTEGER); + +-- name: migrate-0001-down +DROP TABLE test_table; +""" + migration_file.write_text(migration_content) + + sql_loader = SQLFileLoader() + + async def test_workflow() -> None: + sql_loader.validate_migration_file(migration_file) + + path_str = str(migration_file) + assert path_str in sql_loader.sql_loader._files, "File should be loaded after validation" + assert sql_loader.sql_loader.has_query("migrate-0001-up") + assert sql_loader.sql_loader.has_query("migrate-0001-down") + + file_count_before_up = len(sql_loader.sql_loader._files) + up_sql = await sql_loader.get_up_sql(migration_file) + file_count_after_up = len(sql_loader.sql_loader._files) + + assert file_count_before_up == file_count_after_up, "get_up_sql() should not load additional files" + assert len(up_sql) == 1 + assert "CREATE TABLE test_table" in up_sql[0] + + file_count_before_down = len(sql_loader.sql_loader._files) + down_sql = await sql_loader.get_down_sql(migration_file) + file_count_after_down = len(sql_loader.sql_loader._files) + + assert file_count_before_down == file_count_after_down, "get_down_sql() should not load additional files" + assert len(down_sql) == 1 + assert "DROP TABLE test_table" in down_sql[0] + + asyncio.run(test_workflow()) + + +def test_migration_loader_does_not_reload_on_get_sql_calls() -> None: + """Test that get_up_sql and get_down_sql do not trigger file reloads. + + Verifies that after validate_migration_file() loads the file, + subsequent calls to get_up_sql() and get_down_sql() retrieve + the cached queries without calling load_sql() again. 
+ """ + import asyncio + + from sqlspec.loader import SQLFileLoader as CoreSQLFileLoader + from sqlspec.migrations.loaders import SQLFileLoader + + with tempfile.TemporaryDirectory() as temp_dir: + migrations_path = Path(temp_dir) + + migration_file = migrations_path / "0001_schema.sql" + migration_content = """ +-- name: migrate-0001-up +CREATE TABLE products (id INTEGER, name TEXT); + +-- name: migrate-0001-down +DROP TABLE products; +""" + migration_file.write_text(migration_content) + + sql_loader = SQLFileLoader() + + call_counts = {"load_sql": 0} + original_load_sql = CoreSQLFileLoader.load_sql + + def counting_load_sql(self: CoreSQLFileLoader, *args: Any, **kwargs: Any) -> None: + call_counts["load_sql"] += 1 + return original_load_sql(self, *args, **kwargs) + + with patch.object(CoreSQLFileLoader, "load_sql", counting_load_sql): + + async def test_no_reload() -> None: + sql_loader.validate_migration_file(migration_file) + assert call_counts["load_sql"] == 1, "validate_migration_file should call load_sql exactly once" + + await sql_loader.get_up_sql(migration_file) + assert call_counts["load_sql"] == 1, "get_up_sql should NOT call load_sql (should use cache)" + + await sql_loader.get_down_sql(migration_file) + assert call_counts["load_sql"] == 1, "get_down_sql should NOT call load_sql (should use cache)" + + asyncio.run(test_no_reload()) diff --git a/tests/unit/test_migrations/test_tracker_idempotency.py b/tests/unit/test_migrations/test_tracker_idempotency.py new file mode 100644 index 00000000..d39e7681 --- /dev/null +++ b/tests/unit/test_migrations/test_tracker_idempotency.py @@ -0,0 +1,317 @@ +"""Unit tests for idempotent update_version_record behavior.""" + +from typing import Any +from unittest.mock import MagicMock, Mock + +import pytest + +from sqlspec.migrations.tracker import AsyncMigrationTracker, SyncMigrationTracker + + +def test_sync_update_version_record_success() -> None: + """Test sync update succeeds when old version exists.""" + tracker = SyncMigrationTracker() + driver = Mock() + + mock_result = Mock() + mock_result.rows_affected = 1 + driver.execute.return_value = mock_result + + tracker.update_version_record(driver, "20251011120000", "0001") + + update_call = driver.execute.call_args_list[0] + update_sql = str(update_call[0][0]) + assert "UPDATE" in update_sql + assert "ddl_migrations" in update_sql + + +def test_sync_update_version_record_idempotent_when_already_updated() -> None: + """Test sync update is idempotent when version already exists.""" + tracker = SyncMigrationTracker() + driver = Mock() + + update_result = Mock() + update_result.rows_affected = 0 + + check_result = Mock() + check_result.data = [ + {"version_num": "0001", "version_type": "sequential"}, + {"version_num": "0002", "version_type": "sequential"}, + ] + + driver.execute.side_effect = [update_result, check_result] + + tracker.update_version_record(driver, "20251011120000", "0001") + + assert driver.execute.call_count == 2 + + +def test_sync_update_version_record_raises_when_neither_version_exists() -> None: + """Test sync update raises ValueError when neither old nor new version exists.""" + tracker = SyncMigrationTracker() + driver = Mock() + + update_result = Mock() + update_result.rows_affected = 0 + + check_result = Mock() + check_result.data = [{"version_num": "0002", "version_type": "sequential"}] + + driver.execute.side_effect = [update_result, check_result] + + with pytest.raises(ValueError, match="Migration version 20251011120000 not found in database"): + 
tracker.update_version_record(driver, "20251011120000", "0001") + + +def test_sync_update_version_record_empty_database() -> None: + """Test sync update raises when database is empty.""" + tracker = SyncMigrationTracker() + driver = Mock() + + update_result = Mock() + update_result.rows_affected = 0 + + check_result = Mock() + check_result.data = [] + + driver.execute.side_effect = [update_result, check_result] + + with pytest.raises(ValueError, match="Migration version 20251011120000 not found in database"): + tracker.update_version_record(driver, "20251011120000", "0001") + + +def test_sync_update_version_record_commits_after_success() -> None: + """Test sync update commits transaction after successful update.""" + tracker = SyncMigrationTracker() + driver = Mock() + driver.connection = None + driver.driver_features = {} + + mock_result = Mock() + mock_result.rows_affected = 1 + driver.execute.return_value = mock_result + + tracker.update_version_record(driver, "20251011120000", "0001") + + driver.commit.assert_called_once() + + +def test_sync_update_version_record_no_commit_on_idempotent_path() -> None: + """Test sync update does not commit when taking idempotent path.""" + tracker = SyncMigrationTracker() + driver = Mock() + driver.connection = Mock() + driver.connection.autocommit = False + + update_result = Mock() + update_result.rows_affected = 0 + + check_result = Mock() + check_result.data = [{"version_num": "0001", "version_type": "sequential"}] + + driver.execute.side_effect = [update_result, check_result] + + tracker.update_version_record(driver, "20251011120000", "0001") + + driver.commit.assert_not_called() + + +@pytest.mark.asyncio +async def test_async_update_version_record_success() -> None: + """Test async update succeeds when old version exists.""" + from unittest.mock import AsyncMock + + tracker = AsyncMigrationTracker() + driver = MagicMock() + + mock_result = Mock() + mock_result.rows_affected = 1 + + async def mock_execute(sql: Any) -> Mock: + return mock_result + + driver.execute = AsyncMock(side_effect=mock_execute) + + await tracker.update_version_record(driver, "20251011120000", "0001") + + update_call = driver.execute.call_args_list[0] + update_sql = str(update_call[0][0]) + assert "UPDATE" in update_sql + assert "ddl_migrations" in update_sql + + +@pytest.mark.asyncio +async def test_async_update_version_record_idempotent_when_already_updated() -> None: + """Test async update is idempotent when version already exists.""" + from unittest.mock import AsyncMock + + tracker = AsyncMigrationTracker() + driver = MagicMock() + + update_result = Mock() + update_result.rows_affected = 0 + + check_result = Mock() + check_result.data = [ + {"version_num": "0001", "version_type": "sequential"}, + {"version_num": "0002", "version_type": "sequential"}, + ] + + call_count = [0] + + async def mock_execute(sql: Any) -> Mock: + call_count[0] += 1 + if call_count[0] == 1: + return update_result + return check_result + + driver.execute = AsyncMock(side_effect=mock_execute) + + await tracker.update_version_record(driver, "20251011120000", "0001") + + assert driver.execute.call_count == 2 + + +@pytest.mark.asyncio +async def test_async_update_version_record_raises_when_neither_version_exists() -> None: + """Test async update raises ValueError when neither old nor new version exists.""" + from unittest.mock import AsyncMock + + tracker = AsyncMigrationTracker() + driver = MagicMock() + + update_result = Mock() + update_result.rows_affected = 0 + + check_result = Mock() + 
check_result.data = [{"version_num": "0002", "version_type": "sequential"}] + + call_count = [0] + + async def mock_execute(sql: Any) -> Mock: + call_count[0] += 1 + if call_count[0] == 1: + return update_result + return check_result + + driver.execute = AsyncMock(side_effect=mock_execute) + + with pytest.raises(ValueError, match="Migration version 20251011120000 not found in database"): + await tracker.update_version_record(driver, "20251011120000", "0001") + + +@pytest.mark.asyncio +async def test_async_update_version_record_empty_database() -> None: + """Test async update raises when database is empty.""" + from unittest.mock import AsyncMock + + tracker = AsyncMigrationTracker() + driver = MagicMock() + + update_result = Mock() + update_result.rows_affected = 0 + + check_result = Mock() + check_result.data = [] + + call_count = [0] + + async def mock_execute(sql: Any) -> Mock: + call_count[0] += 1 + if call_count[0] == 1: + return update_result + return check_result + + driver.execute = AsyncMock(side_effect=mock_execute) + + with pytest.raises(ValueError, match="Migration version 20251011120000 not found in database"): + await tracker.update_version_record(driver, "20251011120000", "0001") + + +@pytest.mark.asyncio +async def test_async_update_version_record_commits_after_success() -> None: + """Test async update commits transaction after successful update.""" + from unittest.mock import AsyncMock + + tracker = AsyncMigrationTracker() + driver = MagicMock() + driver.connection = None + driver.driver_features = {} + + mock_result = Mock() + mock_result.rows_affected = 1 + + async def mock_execute(sql: Any) -> Mock: + return mock_result + + driver.execute = AsyncMock(side_effect=mock_execute) + driver.commit = AsyncMock() + + await tracker.update_version_record(driver, "20251011120000", "0001") + + driver.commit.assert_called_once() + + +@pytest.mark.asyncio +async def test_async_update_version_record_no_commit_on_idempotent_path() -> None: + """Test async update does not commit when taking idempotent path.""" + from unittest.mock import AsyncMock + + tracker = AsyncMigrationTracker() + driver = MagicMock() + driver.connection = None + driver.driver_features = {} + + update_result = Mock() + update_result.rows_affected = 0 + + check_result = Mock() + check_result.data = [{"version_num": "0001", "version_type": "sequential"}] + + call_count = [0] + + async def mock_execute(sql: Any) -> Mock: + call_count[0] += 1 + if call_count[0] == 1: + return update_result + return check_result + + driver.execute = AsyncMock(side_effect=mock_execute) + driver.commit = AsyncMock() + + await tracker.update_version_record(driver, "20251011120000", "0001") + + driver.commit.assert_not_called() + + +def test_sync_update_version_preserves_sequential_type() -> None: + """Test sync update correctly sets version_type to sequential.""" + tracker = SyncMigrationTracker() + driver = Mock() + + mock_result = Mock() + mock_result.rows_affected = 1 + driver.execute.return_value = mock_result + + tracker.update_version_record(driver, "20251011120000", "0001") + + update_call = driver.execute.call_args_list[0] + update_sql = str(update_call[0][0]) + assert "version_type" in update_sql + assert "SET" in update_sql + + +def test_sync_update_version_handles_extension_versions() -> None: + """Test sync update handles extension version format.""" + tracker = SyncMigrationTracker() + driver = Mock() + + mock_result = Mock() + mock_result.rows_affected = 1 + driver.execute.return_value = mock_result + + 
tracker.update_version_record(driver, "ext_litestar_20251011120000", "ext_litestar_0001") + + update_call = driver.execute.call_args_list[0] + update_sql = str(update_call[0][0]) + assert "UPDATE" in update_sql diff --git a/tests/unit/test_migrations/test_validation.py b/tests/unit/test_migrations/test_validation.py new file mode 100644 index 00000000..57b69c87 --- /dev/null +++ b/tests/unit/test_migrations/test_validation.py @@ -0,0 +1,206 @@ +"""Tests for migration validation and out-of-order detection.""" + +from typing import Any + +import pytest + +from sqlspec.exceptions import OutOfOrderMigrationError +from sqlspec.migrations.validation import ( + MigrationGap, + detect_out_of_order_migrations, + format_out_of_order_warning, + validate_migration_order, +) +from sqlspec.utils.version import parse_version + + +def test_detect_out_of_order_no_applied() -> None: + """Test detection with no applied migrations.""" + pending: list[str] = ["20251011120000", "20251012120000"] + applied: list[str] = [] + + gaps = detect_out_of_order_migrations(pending, applied) + + assert gaps == [] + + +def test_detect_out_of_order_no_pending() -> None: + """Test detection with no pending migrations.""" + pending: list[str] = [] + applied: list[str] = ["20251011120000", "20251012120000"] + + gaps = detect_out_of_order_migrations(pending, applied) + + assert gaps == [] + + +def test_detect_out_of_order_no_gaps() -> None: + """Test detection with no out-of-order migrations.""" + pending: list[str] = ["20251013120000", "20251014120000"] + applied: list[str] = ["20251011120000", "20251012120000"] + + gaps = detect_out_of_order_migrations(pending, applied) + + assert gaps == [] + + +def test_detect_out_of_order_single_gap() -> None: + """Test detection with single out-of-order migration.""" + pending: list[str] = ["20251011130000", "20251013120000"] + applied: list[str] = ["20251011120000", "20251012120000"] + + gaps = detect_out_of_order_migrations(pending, applied) + + assert len(gaps) == 1 + assert gaps[0].missing_version == parse_version("20251011130000") + assert gaps[0].applied_after == [parse_version("20251012120000")] + + +def test_detect_out_of_order_multiple_gaps() -> None: + """Test detection with multiple out-of-order migrations.""" + pending: list[str] = ["20251011130000", "20251011140000", "20251013120000"] + applied: list[str] = ["20251011120000", "20251012120000"] + + gaps = detect_out_of_order_migrations(pending, applied) + + assert len(gaps) == 2 + assert gaps[0].missing_version == parse_version("20251011130000") + assert gaps[1].missing_version == parse_version("20251011140000") + + +def test_detect_out_of_order_with_sequential() -> None: + """Test detection works with mixed sequential and timestamp versions.""" + pending: list[str] = ["20251011120000"] + applied: list[str] = ["0001", "0002", "20251012120000"] + + gaps = detect_out_of_order_migrations(pending, applied) + + assert len(gaps) == 1 + assert gaps[0].missing_version == parse_version("20251011120000") + + +def test_detect_out_of_order_extension_versions_excluded() -> None: + """Test that extension migrations are excluded from out-of-order detection. + + Extension migrations maintain independent sequences within their namespaces + and should not be flagged as out-of-order relative to other migrations. 
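+
+ Both pending versions here are extension migrations, so no gaps are expected
+ even though ext_litestar_0002 is already applied.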
+ """ + pending: list[str] = ["ext_adk_0001", "ext_litestar_0001"] + applied: list[str] = ["0001", "0002", "ext_litestar_0002"] + + gaps = detect_out_of_order_migrations(pending, applied) + + assert gaps == [] + + +def test_detect_out_of_order_mixed_core_and_extension() -> None: + """Test detection with mixed core and extension migrations. + + Only core migrations should be checked for out-of-order status. + """ + pending: list[str] = ["20251011130000", "ext_litestar_0001"] + applied: list[str] = ["20251012120000", "ext_adk_0001"] + + gaps = detect_out_of_order_migrations(pending, applied) + + assert len(gaps) == 1 + assert gaps[0].missing_version == parse_version("20251011130000") + assert gaps[0].missing_version.extension is None + + +def test_format_out_of_order_warning_empty() -> None: + """Test formatting with no gaps.""" + warning = format_out_of_order_warning([]) + + assert warning == "" + + +def test_format_out_of_order_warning_single() -> None: + """Test formatting with single gap.""" + gap = MigrationGap(missing_version=parse_version("20251011130000"), applied_after=[parse_version("20251012120000")]) + + warning = format_out_of_order_warning([gap]) + + assert "Out-of-order migrations detected" in warning + assert "20251011130000" in warning + assert "20251012120000" in warning + assert "created before" in warning + + +def test_format_out_of_order_warning_multiple() -> None: + """Test formatting with multiple gaps.""" + gaps = [ + MigrationGap( + missing_version=parse_version("20251011130000"), + applied_after=[parse_version("20251012120000"), parse_version("20251013120000")], + ), + MigrationGap(missing_version=parse_version("20251011140000"), applied_after=[parse_version("20251012120000")]), + ] + + warning = format_out_of_order_warning(gaps) + + assert "20251011130000" in warning + assert "20251011140000" in warning + assert "20251012120000" in warning + assert "20251013120000" in warning + + +def test_validate_migration_order_no_gaps() -> None: + """Test validation with no out-of-order migrations.""" + pending: list[str] = ["20251013120000"] + applied: list[str] = ["20251011120000", "20251012120000"] + + validate_migration_order(pending, applied, strict_ordering=False) + validate_migration_order(pending, applied, strict_ordering=True) + + +def test_validate_migration_order_warns_by_default(capsys: "Any") -> None: + """Test validation warns but allows out-of-order migrations by default.""" + pending: list[str] = ["20251011130000"] + applied: list[str] = ["20251012120000"] + + validate_migration_order(pending, applied, strict_ordering=False) + + captured = capsys.readouterr() + assert "Out-of-order migrations detected" in captured.out + + +def test_validate_migration_order_strict_raises() -> None: + """Test validation raises error in strict mode.""" + pending: list[str] = ["20251011130000"] + applied: list[str] = ["20251012120000"] + + with pytest.raises(OutOfOrderMigrationError) as exc_info: + validate_migration_order(pending, applied, strict_ordering=True) + + assert "Out-of-order migrations detected" in str(exc_info.value) + assert "20251011130000" in str(exc_info.value) + assert "Strict ordering is enabled" in str(exc_info.value) + + +def test_migration_gap_frozen() -> None: + """Test MigrationGap is frozen (immutable).""" + gap = MigrationGap(missing_version=parse_version("20251011130000"), applied_after=[parse_version("20251012120000")]) + + with pytest.raises(Exception): + gap.missing_version = parse_version("20251011140000") # type: ignore[misc] + + +def 
test_detect_out_of_order_complex_scenario() -> None: + """Test detection with complex real-world scenario.""" + pending: list[str] = ["20251011100000", "20251011150000", "20251012100000", "20251015120000"] + applied: list[str] = ["20251011120000", "20251011140000", "20251013120000"] + + gaps = detect_out_of_order_migrations(pending, applied) + + assert len(gaps) == 3 + assert gaps[0].missing_version == parse_version("20251011100000") + assert gaps[1].missing_version == parse_version("20251011150000") + assert gaps[2].missing_version == parse_version("20251012100000") + + pending_versions = {g.missing_version.raw for g in gaps} + assert "20251011100000" in pending_versions + assert "20251011150000" in pending_versions + assert "20251012100000" in pending_versions + + assert parse_version("20251015120000") not in [g.missing_version for g in gaps] diff --git a/tests/unit/test_migrations/test_version.py b/tests/unit/test_migrations/test_version.py new file mode 100644 index 00000000..01e7309f --- /dev/null +++ b/tests/unit/test_migrations/test_version.py @@ -0,0 +1,225 @@ +"""Unit tests for migration version parsing and comparison.""" + +from datetime import datetime, timezone + +import pytest + +from sqlspec.utils.version import ( + VersionType, + generate_timestamp_version, + is_sequential_version, + is_timestamp_version, + parse_version, +) + + +def test_is_sequential_version() -> None: + """Test sequential version detection.""" + assert is_sequential_version("0001") + assert is_sequential_version("42") + assert is_sequential_version("9999") + assert is_sequential_version("1") + assert is_sequential_version("00001") + assert is_sequential_version("10000") + + assert not is_sequential_version("20251011120000") + assert not is_sequential_version("abc") + assert not is_sequential_version("") + + +def test_is_timestamp_version() -> None: + """Test timestamp version detection.""" + assert is_timestamp_version("20251011120000") + assert is_timestamp_version("20200101000000") + assert is_timestamp_version("20991231235959") + + assert not is_timestamp_version("0001") + assert not is_timestamp_version("2025101112") + assert not is_timestamp_version("20259999999999") + assert not is_timestamp_version("") + + +def test_parse_sequential_version() -> None: + """Test parsing sequential versions.""" + v = parse_version("0001") + assert v.raw == "0001" + assert v.type == VersionType.SEQUENTIAL + assert v.sequence == 1 + assert v.timestamp is None + assert v.extension is None + + v = parse_version("42") + assert v.sequence == 42 + + v = parse_version("9999") + assert v.sequence == 9999 + + +def test_parse_timestamp_version() -> None: + """Test parsing timestamp versions.""" + v = parse_version("20251011120000") + assert v.raw == "20251011120000" + assert v.type == VersionType.TIMESTAMP + assert v.sequence is None + assert v.timestamp == datetime(2025, 10, 11, 12, 0, 0, tzinfo=timezone.utc) + assert v.extension is None + + v = parse_version("20200101000000") + assert v.timestamp == datetime(2020, 1, 1, 0, 0, 0, tzinfo=timezone.utc) + + +def test_parse_extension_version_sequential() -> None: + """Test parsing extension versions with sequential format.""" + v = parse_version("ext_litestar_0001") + assert v.raw == "ext_litestar_0001" + assert v.type == VersionType.SEQUENTIAL + assert v.sequence == 1 + assert v.extension == "litestar" + + v = parse_version("ext_myext_42") + assert v.sequence == 42 + assert v.extension == "myext" + + +def test_parse_extension_version_timestamp() -> None: + """Test parsing 
extension versions with timestamp format."""
+ v = parse_version("ext_litestar_20251011120000")
+ assert v.raw == "ext_litestar_20251011120000"
+ assert v.type == VersionType.TIMESTAMP
+ assert v.timestamp == datetime(2025, 10, 11, 12, 0, 0, tzinfo=timezone.utc)
+ assert v.extension == "litestar"
+
+
+def test_parse_invalid_version() -> None:
+ """Test parsing invalid version formats."""
+ with pytest.raises(ValueError, match="Invalid migration version format"):
+ parse_version("abc")
+
+ with pytest.raises(ValueError, match="Invalid migration version format"):
+ parse_version("")
+
+ with pytest.raises(ValueError, match="Invalid migration version format"):
+ parse_version("20259999999999")
+
+
+def test_version_comparison_sequential() -> None:
+ """Test comparing sequential versions."""
+ v1 = parse_version("0001")
+ v2 = parse_version("0002")
+ v42 = parse_version("42")
+
+ assert v1 < v2
+ assert v2 < v42
+ assert not v2 < v1
+ assert not v42 < v2
+
+
+def test_version_comparison_timestamp() -> None:
+ """Test comparing timestamp versions."""
+ v1 = parse_version("20200101000000")
+ v2 = parse_version("20251011120000")
+ v3 = parse_version("20251011130000")
+
+ assert v1 < v2
+ assert v2 < v3
+ assert not v2 < v1
+ assert not v3 < v2
+
+
+def test_version_comparison_mixed() -> None:
+ """Test comparing mixed sequential and timestamp versions.
+
+ Sequential versions should sort before timestamp versions (legacy priority).
+ """
+ sequential = parse_version("9999")
+ timestamp = parse_version("20200101000000")
+
+ assert sequential < timestamp
+ assert not timestamp < sequential
+
+
+def test_version_comparison_extension() -> None:
+ """Test comparing extension versions."""
+ main = parse_version("0001")
+ ext1 = parse_version("ext_aaa_0001")
+ ext2 = parse_version("ext_bbb_0001")
+
+ assert main < ext1
+ assert main < ext2
+ assert ext1 < ext2
+
+
+def test_version_equality() -> None:
+ """Test version equality."""
+ v1 = parse_version("0001")
+ v2 = parse_version("0001")
+ v3 = parse_version("0002")
+
+ assert v1 == v2
+ assert not v1 == v3
+ assert v1 != v3
+
+
+def test_version_hash() -> None:
+ """Test version hashing for use in sets/dicts."""
+ v1 = parse_version("0001")
+ v2 = parse_version("0001")
+ v3 = parse_version("0002")
+
+ assert hash(v1) == hash(v2)
+ assert hash(v1) != hash(v3)
+
+ version_set = {v1, v2, v3}
+ assert len(version_set) == 2
+
+
+def test_version_sorting() -> None:
+ """Test sorting versions."""
+ versions = [
+ parse_version("ext_bbb_0002"),
+ parse_version("20251011120000"),
+ parse_version("0002"),
+ parse_version("ext_aaa_0001"),
+ parse_version("0001"),
+ parse_version("20200101000000"),
+ ]
+
+ sorted_versions = sorted(versions)
+
+ expected_order = ["0001", "0002", "20200101000000", "20251011120000", "ext_aaa_0001", "ext_bbb_0002"]
+
+ assert [v.raw for v in sorted_versions] == expected_order
+
+
+def test_generate_timestamp_version() -> None:
+ """Test timestamp version generation."""
+ version = generate_timestamp_version()
+
+ assert len(version) == 14
+ assert version.isdigit()
+ assert is_timestamp_version(version)
+
+ parsed = parse_version(version)
+ assert parsed.type == VersionType.TIMESTAMP
+ assert parsed.timestamp is not None
+
+
+def test_generate_timestamp_version_monotonic() -> None:
+ """Test that consecutively generated timestamp versions never decrease.
+
+ Two calls within the same second may return equal strings, so monotonic
+ ordering is asserted rather than strict uniqueness.
+ """
+ v1 = generate_timestamp_version()
+ v2 = generate_timestamp_version()
+
+ assert v1 <= v2
+
+
+def test_version_repr() -> None:
+ """Test version string representation."""
+ v = 
parse_version("0001") + assert "sequential" in repr(v) + assert "0001" in repr(v) + + v = parse_version("20251011120000") + assert "timestamp" in repr(v) + + v = parse_version("ext_litestar_0001") + assert "litestar" in repr(v) diff --git a/tests/unit/test_migrations/test_version_conversion.py b/tests/unit/test_migrations/test_version_conversion.py new file mode 100644 index 00000000..587983c1 --- /dev/null +++ b/tests/unit/test_migrations/test_version_conversion.py @@ -0,0 +1,303 @@ +"""Unit tests for migration version conversion utilities.""" + +from pathlib import Path + +import pytest + +from sqlspec.utils.version import ( + convert_to_sequential_version, + generate_conversion_map, + get_next_sequential_number, + parse_version, +) + + +def test_get_next_sequential_number_empty() -> None: + """Test next sequential number with empty list.""" + assert get_next_sequential_number([]) == 1 + + +def test_get_next_sequential_number_single() -> None: + """Test next sequential number with single migration.""" + v1 = parse_version("0001") + assert get_next_sequential_number([v1]) == 2 + + +def test_get_next_sequential_number_multiple() -> None: + """Test next sequential number with multiple sequential migrations.""" + v1 = parse_version("0001") + v2 = parse_version("0002") + v3 = parse_version("0003") + assert get_next_sequential_number([v1, v2, v3]) == 4 + + +def test_get_next_sequential_number_with_timestamps() -> None: + """Test next sequential number ignores timestamp migrations.""" + v1 = parse_version("0001") + v2 = parse_version("0002") + t1 = parse_version("20251011120000") + t2 = parse_version("20251012130000") + + result = get_next_sequential_number([v1, t1, v2, t2]) + assert result == 3 + + +def test_get_next_sequential_number_ignores_extensions() -> None: + """Test next sequential number ignores extension migrations.""" + core = parse_version("0001") + ext1 = parse_version("ext_litestar_0001") + ext2 = parse_version("ext_litestar_0002") + + result = get_next_sequential_number([core, ext1, ext2]) + assert result == 2 + + +def test_get_next_sequential_number_only_timestamps() -> None: + """Test next sequential number with only timestamp migrations.""" + t1 = parse_version("20251011120000") + t2 = parse_version("20251012130000") + + result = get_next_sequential_number([t1, t2]) + assert result == 1 + + +def test_get_next_sequential_number_only_extensions() -> None: + """Test next sequential number with only extension migrations.""" + ext1 = parse_version("ext_litestar_0001") + ext2 = parse_version("ext_adk_0001") + + result = get_next_sequential_number([ext1, ext2]) + assert result == 1 + + +def test_get_next_sequential_number_high_numbers() -> None: + """Test next sequential number with high sequence numbers.""" + v1 = parse_version("9998") + v2 = parse_version("9999") + + result = get_next_sequential_number([v1, v2]) + assert result == 10000 + + +def test_convert_to_sequential_version_basic() -> None: + """Test basic timestamp to sequential conversion.""" + v = parse_version("20251011120000") + result = convert_to_sequential_version(v, 3) + assert result == "0003" + + +def test_convert_to_sequential_version_zero_padding() -> None: + """Test zero padding in sequential version.""" + v = parse_version("20251011120000") + + assert convert_to_sequential_version(v, 1) == "0001" + assert convert_to_sequential_version(v, 10) == "0010" + assert convert_to_sequential_version(v, 100) == "0100" + assert convert_to_sequential_version(v, 1000) == "1000" + + +def 
test_convert_to_sequential_version_with_extension() -> None: + """Test conversion preserves extension prefix.""" + v = parse_version("ext_litestar_20251011120000") + result = convert_to_sequential_version(v, 1) + assert result == "ext_litestar_0001" + + +def test_convert_to_sequential_version_various_extensions() -> None: + """Test conversion with various extension names.""" + v1 = parse_version("ext_adk_20251011120000") + assert convert_to_sequential_version(v1, 2) == "ext_adk_0002" + + v2 = parse_version("ext_myext_20251011120000") + assert convert_to_sequential_version(v2, 42) == "ext_myext_0042" + + +def test_convert_to_sequential_version_rejects_sequential() -> None: + """Test conversion rejects sequential input.""" + v = parse_version("0001") + + with pytest.raises(ValueError, match="Can only convert timestamp versions"): + convert_to_sequential_version(v, 2) + + +def test_generate_conversion_map_empty() -> None: + """Test conversion map generation with empty list.""" + result = generate_conversion_map([]) + assert result == {} + + +def test_generate_conversion_map_no_timestamps() -> None: + """Test conversion map with only sequential migrations.""" + migrations = [("0001", Path("0001_init.sql")), ("0002", Path("0002_users.sql"))] + result = generate_conversion_map(migrations) + assert result == {} + + +def test_generate_conversion_map_basic() -> None: + """Test basic conversion map generation.""" + migrations = [ + ("0001", Path("0001_init.sql")), + ("0002", Path("0002_users.sql")), + ("20251011120000", Path("20251011120000_products.sql")), + ("20251012130000", Path("20251012130000_orders.sql")), + ] + + result = generate_conversion_map(migrations) + + assert result == {"20251011120000": "0003", "20251012130000": "0004"} + + +def test_generate_conversion_map_chronological_order() -> None: + """Test conversion map respects chronological order.""" + migrations = [ + ("20251012130000", Path("20251012130000_later.sql")), + ("20251011120000", Path("20251011120000_earlier.sql")), + ("20251010090000", Path("20251010090000_earliest.sql")), + ] + + result = generate_conversion_map(migrations) + + assert result == {"20251010090000": "0001", "20251011120000": "0002", "20251012130000": "0003"} + + +def test_generate_conversion_map_only_timestamps() -> None: + """Test conversion map with only timestamp migrations.""" + migrations = [ + ("20251011120000", Path("20251011120000_first.sql")), + ("20251012130000", Path("20251012130000_second.sql")), + ] + + result = generate_conversion_map(migrations) + + assert result == {"20251011120000": "0001", "20251012130000": "0002"} + + +def test_generate_conversion_map_mixed_formats() -> None: + """Test conversion map with mixed sequential and timestamp.""" + migrations = [ + ("0001", Path("0001_init.sql")), + ("20251011120000", Path("20251011120000_products.sql")), + ("0002", Path("0002_users.sql")), + ("20251012130000", Path("20251012130000_orders.sql")), + ("0003", Path("0003_settings.sql")), + ] + + result = generate_conversion_map(migrations) + + assert result == {"20251011120000": "0004", "20251012130000": "0005"} + + +def test_generate_conversion_map_with_extensions() -> None: + """Test conversion map handles extension migrations correctly.""" + migrations = [ + ("0001", Path("0001_init.sql")), + ("ext_litestar_20251011120000", Path("ext_litestar_20251011120000_sessions.py")), + ("20251012130000", Path("20251012130000_products.sql")), + ("ext_adk_20251011120000", Path("ext_adk_20251011120000_tables.py")), + ] + + result = 
generate_conversion_map(migrations) + + assert result == { + "20251012130000": "0002", + "ext_litestar_20251011120000": "ext_litestar_0001", + "ext_adk_20251011120000": "ext_adk_0001", + } + + +def test_generate_conversion_map_extension_namespaces() -> None: + """Test extension migrations maintain separate numbering.""" + migrations = [ + ("ext_litestar_0001", Path("ext_litestar_0001_existing.py")), + ("ext_litestar_20251011120000", Path("ext_litestar_20251011120000_new.py")), + ("ext_adk_20251011120000", Path("ext_adk_20251011120000_new.py")), + ] + + result = generate_conversion_map(migrations) + + assert result == {"ext_litestar_20251011120000": "ext_litestar_0002", "ext_adk_20251011120000": "ext_adk_0001"} + + +def test_generate_conversion_map_multiple_extension_timestamps() -> None: + """Test multiple timestamp migrations for same extension.""" + migrations = [ + ("ext_litestar_20251011120000", Path("ext_litestar_20251011120000_first.py")), + ("ext_litestar_20251012130000", Path("ext_litestar_20251012130000_second.py")), + ("ext_litestar_20251013140000", Path("ext_litestar_20251013140000_third.py")), + ] + + result = generate_conversion_map(migrations) + + assert result == { + "ext_litestar_20251011120000": "ext_litestar_0001", + "ext_litestar_20251012130000": "ext_litestar_0002", + "ext_litestar_20251013140000": "ext_litestar_0003", + } + + +def test_generate_conversion_map_ignores_invalid_versions() -> None: + """Test conversion map skips invalid migration versions.""" + migrations = [ + ("0001", Path("0001_init.sql")), + ("invalid", Path("invalid_migration.sql")), + ("20251011120000", Path("20251011120000_products.sql")), + ("", Path("empty.sql")), + ] + + result = generate_conversion_map(migrations) + + assert result == {"20251011120000": "0002"} + + +def test_generate_conversion_map_complex_scenario() -> None: + """Test conversion map with complex real-world scenario.""" + migrations = [ + ("0001", Path("0001_init.sql")), + ("0002", Path("0002_users.sql")), + ("20251011120000", Path("20251011120000_products.sql")), + ("ext_litestar_0001", Path("ext_litestar_0001_sessions.py")), + ("20251012130000", Path("20251012130000_orders.sql")), + ("ext_litestar_20251011215440", Path("ext_litestar_20251011215440_new_session.py")), + ("ext_adk_20251011215914", Path("ext_adk_20251011215914_tables.py")), + ("0003", Path("0003_categories.sql")), + ] + + result = generate_conversion_map(migrations) + + assert result == { + "20251011120000": "0004", + "20251012130000": "0005", + "ext_litestar_20251011215440": "ext_litestar_0002", + "ext_adk_20251011215914": "ext_adk_0001", + } + + +def test_generate_conversion_map_preserves_path_info() -> None: + """Test that conversion map generation doesn't fail with Path objects.""" + migrations = [ + ("20251011120000", Path("/migrations/20251011120000_products.sql")), + ("20251012130000", Path("/migrations/20251012130000_orders.sql")), + ] + + result = generate_conversion_map(migrations) + + assert "20251011120000" in result + assert "20251012130000" in result + + +def test_get_next_sequential_number_unordered() -> None: + """Test next sequential number with unordered input.""" + v3 = parse_version("0003") + v1 = parse_version("0001") + v2 = parse_version("0002") + + result = get_next_sequential_number([v3, v1, v2]) + assert result == 4 + + +def test_convert_to_sequential_version_large_numbers() -> None: + """Test conversion with large sequence numbers.""" + v = parse_version("20251011120000") + result = convert_to_sequential_version(v, 99999) + assert result 
== "99999" + assert len(result) == 5 diff --git a/tests/unit/test_migrations/test_version_parsing_edge_cases.py b/tests/unit/test_migrations/test_version_parsing_edge_cases.py new file mode 100644 index 00000000..933faf5f --- /dev/null +++ b/tests/unit/test_migrations/test_version_parsing_edge_cases.py @@ -0,0 +1,236 @@ +"""Unit tests for edge cases in migration version parsing.""" + +import pytest + +from sqlspec.utils.version import ( + VersionType, + convert_to_sequential_version, + generate_conversion_map, + get_next_sequential_number, + is_sequential_version, + parse_version, +) + + +def test_is_sequential_version_no_digit_cap() -> None: + """Test sequential version detection works without 4-digit limitation.""" + assert is_sequential_version("10000") + assert is_sequential_version("99999") + assert is_sequential_version("12345") + assert is_sequential_version("00001") + + +def test_parse_sequential_version_large_numbers() -> None: + """Test parsing sequential versions beyond 9999.""" + v = parse_version("10000") + assert v.raw == "10000" + assert v.type == VersionType.SEQUENTIAL + assert v.sequence == 10000 + assert v.timestamp is None + + v = parse_version("99999") + assert v.sequence == 99999 + + +def test_version_comparison_large_sequential() -> None: + """Test comparing large sequential versions.""" + v1 = parse_version("9999") + v2 = parse_version("10000") + v3 = parse_version("10001") + + assert v1 < v2 + assert v2 < v3 + assert not v2 < v1 + + +def test_version_comparison_extension_with_large_numbers() -> None: + """Test extension versions with large sequential numbers.""" + ext1 = parse_version("ext_litestar_9999") + ext2 = parse_version("ext_litestar_10000") + + assert ext1 < ext2 + + +def test_get_next_sequential_number_after_9999() -> None: + """Test getting next sequential number after exceeding 9999.""" + v1 = parse_version("9999") + v2 = parse_version("10000") + + next_num = get_next_sequential_number([v1, v2]) + assert next_num == 10001 + + +def test_get_next_sequential_number_with_extension() -> None: + """Test getting next sequential number for extension migrations.""" + core1 = parse_version("0001") + core2 = parse_version("0002") + ext1 = parse_version("ext_litestar_0001") + ext2 = parse_version("ext_litestar_0002") + + next_core = get_next_sequential_number([core1, core2, ext1, ext2], extension=None) + assert next_core == 3 + + next_ext = get_next_sequential_number([core1, core2, ext1, ext2], extension="litestar") + assert next_ext == 3 + + +def test_get_next_sequential_number_only_timestamps() -> None: + """Test getting next sequential number when only timestamp versions exist.""" + v1 = parse_version("20251011120000") + v2 = parse_version("20251012130000") + + next_num = get_next_sequential_number([v1, v2]) + assert next_num == 1 + + +def test_convert_to_sequential_version_preserves_extension() -> None: + """Test converting timestamp to sequential preserves extension prefix.""" + timestamp_version = parse_version("ext_litestar_20251011120000") + sequential = convert_to_sequential_version(timestamp_version, 5) + + assert sequential == "ext_litestar_0005" + + +def test_convert_to_sequential_version_large_sequence() -> None: + """Test converting timestamp to large sequential number.""" + timestamp_version = parse_version("20251011120000") + sequential = convert_to_sequential_version(timestamp_version, 10000) + + assert sequential == "10000" + + +def test_convert_to_sequential_version_rejects_sequential() -> None: + """Test converting sequential version raises 
error.""" + sequential_version = parse_version("0001") + + with pytest.raises(ValueError, match="Can only convert timestamp versions"): + convert_to_sequential_version(sequential_version, 2) + + +def test_generate_conversion_map_with_extensions() -> None: + """Test conversion map generation with extension migrations.""" + from pathlib import Path + + migrations = [ + ("0001", Path("0001_init.sql")), + ("ext_litestar_0001", Path("ext_litestar_0001_init.sql")), + ("20251011120000", Path("20251011120000_users.sql")), + ("ext_litestar_20251012130000", Path("ext_litestar_20251012130000_users.sql")), + ] + + conversion_map = generate_conversion_map(migrations) + + assert conversion_map["20251011120000"] == "0002" + assert conversion_map["ext_litestar_20251012130000"] == "ext_litestar_0002" + assert "0001" not in conversion_map + assert "ext_litestar_0001" not in conversion_map + + +def test_generate_conversion_map_maintains_chronological_order() -> None: + """Test conversion map assigns sequential numbers in chronological order.""" + from pathlib import Path + + migrations = [ + ("20251011120000", Path("20251011120000_third.sql")), + ("20251010100000", Path("20251010100000_first.sql")), + ("20251010120000", Path("20251010120000_second.sql")), + ] + + conversion_map = generate_conversion_map(migrations) + + assert conversion_map["20251010100000"] == "0001" + assert conversion_map["20251010120000"] == "0002" + assert conversion_map["20251011120000"] == "0003" + + +def test_generate_conversion_map_separate_extension_namespaces() -> None: + """Test extension migrations maintain separate sequential namespaces.""" + from pathlib import Path + + migrations = [ + ("0001", Path("0001_init.sql")), + ("ext_aaa_0001", Path("ext_aaa_0001_init.sql")), + ("ext_bbb_0001", Path("ext_bbb_0001_init.sql")), + ("ext_aaa_20251011120000", Path("ext_aaa_20251011120000_users.sql")), + ("ext_bbb_20251012130000", Path("ext_bbb_20251012130000_products.sql")), + ] + + conversion_map = generate_conversion_map(migrations) + + assert conversion_map["ext_aaa_20251011120000"] == "ext_aaa_0002" + assert conversion_map["ext_bbb_20251012130000"] == "ext_bbb_0002" + + +def test_version_sorting_with_large_numbers() -> None: + """Test version sorting works correctly with large sequential numbers.""" + versions = [ + parse_version("10001"), + parse_version("0001"), + parse_version("9999"), + parse_version("10000"), + parse_version("20251011120000"), + ] + + sorted_versions = sorted(versions) + + expected_order = ["0001", "9999", "10000", "10001", "20251011120000"] + assert [v.raw for v in sorted_versions] == expected_order + + +def test_version_comparison_sequential_vs_timestamp_edge_case() -> None: + """Test that even very large sequential numbers sort before timestamps.""" + large_sequential = parse_version("99999") + early_timestamp = parse_version("20000101000000") + + assert large_sequential < early_timestamp + assert not early_timestamp < large_sequential + + +def test_get_next_sequential_number_mixed_extensions() -> None: + """Test getting next sequential with mixed core and extension migrations.""" + core1 = parse_version("0001") + ext_litestar = parse_version("ext_litestar_0001") + ext_adk = parse_version("ext_adk_0001") + timestamp = parse_version("20251011120000") + + next_core = get_next_sequential_number([core1, ext_litestar, ext_adk, timestamp], extension=None) + assert next_core == 2 + + next_litestar = get_next_sequential_number([core1, ext_litestar, ext_adk, timestamp], extension="litestar") + assert next_litestar == 2 
+ + next_adk = get_next_sequential_number([core1, ext_litestar, ext_adk, timestamp], extension="adk") + assert next_adk == 2 + + +def test_generate_conversion_map_empty_list() -> None: + """Test conversion map generation with empty migration list.""" + conversion_map = generate_conversion_map([]) + assert conversion_map == {} + + +def test_generate_conversion_map_only_sequential() -> None: + """Test conversion map generation when only sequential migrations exist.""" + from pathlib import Path + + migrations = [("0001", Path("0001_init.sql")), ("0002", Path("0002_users.sql"))] + + conversion_map = generate_conversion_map(migrations) + assert conversion_map == {} + + +def test_generate_conversion_map_invalid_versions_skipped() -> None: + """Test conversion map skips invalid version strings.""" + from pathlib import Path + + migrations = [ + ("0001", Path("0001_init.sql")), + ("invalid", Path("invalid_migration.sql")), + ("20251011120000", Path("20251011120000_users.sql")), + ] + + conversion_map = generate_conversion_map(migrations) + + assert "20251011120000" in conversion_map + assert conversion_map["20251011120000"] == "0002" + assert "invalid" not in conversion_map
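+ # "invalid" cannot be parsed as a version, so it is skipped rather than raising.
+
+
+def test_generate_conversion_map_deterministic() -> None:
+ """Sketch: repeated calls on the same input should agree.
+
+ Assumes generate_conversion_map is a pure function of its input list,
+ which the chronological-order tests above rely on implicitly.
+ """
+ from pathlib import Path
+
+ migrations = [("0001", Path("0001_init.sql")), ("20251011120000", Path("20251011120000_users.sql"))]
+
+ assert generate_conversion_map(migrations) == generate_conversion_map(migrations)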