Skip to content

Commit

Permalink
feat(cli): Adds drop-all option to the litestar cli (#226)
Browse files Browse the repository at this point in the history
* feat(cli): Adds drop-all option to the litestar cli

* fix(mypy): CI failure

* tests(cli): Add tests for drop-all

* chore: remove unnecessary import

* feat: updated tests

* fix: 3.8 and 3.9 support

* chore: revert later

* revert: "chore: revert later"

This reverts commit 4091abd.

* chore: 3.8/3.9 support

* chore: add eval type check

* chore: fix lint

* feat: updated linting

---------

Co-authored-by: Alc-Alc <alc@localhost>
Co-authored-by: Cody Fincher <cody.fincher@gmail.com>
  • Loading branch information
3 people committed Jun 30, 2024
1 parent da87543 commit d3f8cfd
Show file tree
Hide file tree
Showing 11 changed files with 326 additions and 231 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ repos:
- id: mixed-line-ending
- id: trailing-whitespace
- repo: https://github.com/provinzkraut/unasyncd
rev: "v0.7.1"
rev: "v0.7.2"
hooks:
- id: unasyncd
additional_dependencies: ["ruff"]
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: "v0.4.10"
rev: "v0.5.0"
hooks:
- id: ruff
args: ["--fix"]
Expand Down
4 changes: 2 additions & 2 deletions advanced_alchemy/alembic/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def ensure_version(self, sql: bool = False) -> None:
def heads(self, verbose: bool = False, resolve_dependencies: bool = False) -> None:
"""Show current available heads in the script directory."""

return migration_command.heads(config=self.config, verbose=verbose, resolve_dependencies=resolve_dependencies) # type: ignore # noqa: PGH003
return migration_command.heads(config=self.config, verbose=verbose, resolve_dependencies=resolve_dependencies)

def history(
self,
Expand Down Expand Up @@ -193,7 +193,7 @@ def show(
) -> None:
"""Show the revision(s) denoted by the given symbol."""

return migration_command.show(config=self.config, rev=rev) # type: ignore # noqa: PGH003
return migration_command.show(config=self.config, rev=rev)

def init(
self,
Expand Down
37 changes: 37 additions & 0 deletions advanced_alchemy/alembic/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from litestar.cli._utils import console
from sqlalchemy import Engine, MetaData, Table
from typing_extensions import TypeIs

if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncEngine


async def drop_all(engine: AsyncEngine | Engine, version_table_name: str, metadata: MetaData) -> None:
def _is_sync(engine: Engine | AsyncEngine) -> TypeIs[Engine]:
return isinstance(engine, Engine)

def _drop_tables_sync(engine: Engine) -> None:
console.rule("[bold red]Connecting to database backend.")
with engine.begin() as db:
console.rule("[bold red]Dropping the db", align="left")
metadata.drop_all(db)
console.rule("[bold red]Dropping the version table", align="left")
Table(version_table_name, metadata).drop(db, checkfirst=True)
console.rule("[bold yellow]Successfully dropped all objects", align="left")

async def _drop_tables_async(engine: AsyncEngine) -> None:
console.rule("[bold red]Connecting to database backend.", align="left")
async with engine.begin() as db:
console.rule("[bold red]Dropping the db", align="left")
await db.run_sync(metadata.drop_all)
console.rule("[bold red]Dropping the version table", align="left")
await db.run_sync(Table(version_table_name, metadata).drop, checkfirst=True)
console.rule("[bold yellow]Successfully dropped all objects", align="left")

if _is_sync(engine):
return _drop_tables_sync(engine)
return await _drop_tables_async(engine)
43 changes: 36 additions & 7 deletions advanced_alchemy/extensions/litestar/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import TYPE_CHECKING, cast

from anyio import run
from click import argument, group, option
from litestar.cli._utils import LitestarGroup, console

Expand Down Expand Up @@ -67,7 +68,7 @@ def downgrade_database(app: Litestar, revision: str, sql: bool, tag: str | None,
input_confirmed = (
True
if no_prompt
else Confirm.ask(f"Are you sure you you want to downgrade the database to the `{revision}` revision?")
else Confirm.ask(f"Are you sure you want to downgrade the database to the `{revision}` revision?")
)
if input_confirmed:
alembic_commands = AlembicCommands(app=app)
Expand Down Expand Up @@ -109,7 +110,7 @@ def upgrade_database(app: Litestar, revision: str, sql: bool, tag: str | None, n
input_confirmed = (
True
if no_prompt
else Confirm.ask(f"[bold]Are you sure you you want migrate the database to the `{revision}` revision?[/]")
else Confirm.ask(f"[bold]Are you sure you want migrate the database to the `{revision}` revision?[/]")
)
if input_confirmed:
alembic_commands = AlembicCommands(app=app)
Expand Down Expand Up @@ -141,11 +142,9 @@ def init_alembic(app: Litestar, directory: str | None, multidb: bool, package: b
console.rule("[yellow]Initializing database migrations.", align="left")
plugin = get_database_migration_plugin(app)
if directory is None:
directory = plugin._alembic_config.script_location # noqa: SLF001
directory = plugin._alembic_config.script_location # pyright: ignore[reportPrivateUsage] # noqa: SLF001
input_confirmed = (
True
if no_prompt
else Confirm.ask(f"[bold]Are you sure you you want initialize the project in `{directory}`?[/]")
True if no_prompt else Confirm.ask(f"[bold]Are you sure you want initialize the project in `{directory}`?[/]")
)
if input_confirmed:
alembic_commands = AlembicCommands(app)
Expand Down Expand Up @@ -312,7 +311,37 @@ def stamp_revision(app: Litestar, revision: str, sql: bool, tag: str | None, pur
from advanced_alchemy.extensions.litestar.alembic import AlembicCommands

console.rule("[yellow]Stamping database revision as current[/]", align="left")
input_confirmed = True if no_prompt else Confirm.ask("Are you sure you you want to stamp revision as current?")
input_confirmed = True if no_prompt else Confirm.ask("Are you sure you want to stamp revision as current?")
if input_confirmed:
alembic_commands = AlembicCommands(app=app)
alembic_commands.stamp(sql=sql, revision=revision, tag=tag, purge=purge)


@database_group.command(name="drop-all", help="Drop all tables from the database.")
@option(
"--no-prompt",
help="Do not prompt for confirmation before upgrading.",
type=bool,
default=False,
required=False,
show_default=True,
is_flag=True,
)
def drop_all(app: Litestar, no_prompt: bool) -> None:
from rich.prompt import Confirm

from advanced_alchemy.alembic.utils import drop_all
from advanced_alchemy.extensions.litestar.alembic import get_database_migration_plugin

console.rule("[yellow]Dropping all tables from the database[/]", align="left")
input_confirmed = no_prompt or Confirm.ask("[bold red]Are you sure you want to drop all tables from the database?")

sqlalchemy_config = get_database_migration_plugin(app)._config # pyright: ignore[reportPrivateUsage] # noqa: SLF001
engine = sqlalchemy_config.get_engine()
if input_confirmed:
run(
drop_all,
engine,
sqlalchemy_config.alembic_config.version_table_name,
sqlalchemy_config.alembic_config.target_metadata,
)
4 changes: 2 additions & 2 deletions advanced_alchemy/repository/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -1649,12 +1649,12 @@ def _merge_on_match_fields(
if match_fields is None:
match_fields = [self.id_attribute]
for existing_datum in existing_data:
for row_id, datum in enumerate(data):
for _row_id, datum in enumerate(data):
match = all(
getattr(datum, field_name) == getattr(existing_datum, field_name) for field_name in match_fields
)
if match and getattr(existing_datum, self.id_attribute) is not None:
setattr(data[row_id], self.id_attribute, getattr(existing_datum, self.id_attribute))
setattr(datum, self.id_attribute, getattr(existing_datum, self.id_attribute))
return data

async def list(
Expand Down
4 changes: 2 additions & 2 deletions advanced_alchemy/repository/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -1650,12 +1650,12 @@ def _merge_on_match_fields(
if match_fields is None:
match_fields = [self.id_attribute]
for existing_datum in existing_data:
for row_id, datum in enumerate(data):
for _row_id, datum in enumerate(data):
match = all(
getattr(datum, field_name) == getattr(existing_datum, field_name) for field_name in match_fields
)
if match and getattr(existing_datum, self.id_attribute) is not None:
setattr(data[row_id], self.id_attribute, getattr(existing_datum, self.id_attribute))
setattr(datum, self.id_attribute, getattr(existing_datum, self.id_attribute))
return data

def list(
Expand Down
21 changes: 7 additions & 14 deletions examples/litestar/litestar_repo_only.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

from typing import TYPE_CHECKING, List
from datetime import date # noqa: TCH003
from typing import TYPE_CHECKING, List, Optional
from uuid import UUID # noqa: TCH003

from litestar import Litestar
from litestar.controller import Controller
Expand All @@ -20,9 +22,6 @@
from advanced_alchemy.repository import SQLAlchemyAsyncRepository

if TYPE_CHECKING:
from datetime import date
from uuid import UUID

from sqlalchemy.ext.asyncio import AsyncSession


Expand All @@ -38,8 +37,8 @@ class AuthorModel(UUIDBase):
# we can optionally provide the table name instead of auto-generating it
__tablename__ = "author"
name: Mapped[str]
dob: Mapped[date | None]
books: Mapped[list[BookModel]] = relationship(back_populates="author", lazy="noload")
dob: Mapped[Optional[date]] # noqa: UP007
books: Mapped[List[BookModel]] = relationship(back_populates="author", lazy="noload") # noqa: UP006


# The `AuditBase` class includes the same UUID` based primary key (`id`) and 2
Expand All @@ -56,7 +55,7 @@ class BookModel(UUIDAuditBase):


class Author(BaseModel):
id: UUID | None
id: Optional[UUID] # noqa: UP007
name: str
dob: date | None = None

Expand Down Expand Up @@ -201,19 +200,13 @@ async def delete_author(
sqlalchemy_config = SQLAlchemyAsyncConfig(
connection_string="sqlite+aiosqlite:///test.sqlite",
session_config=session_config,
create_all=True,
) # Create 'db_session' dependency.
sqlalchemy_plugin = SQLAlchemyPlugin(config=sqlalchemy_config)


async def on_startup() -> None:
"""Initializes the database."""
async with sqlalchemy_config.get_engine().begin() as conn:
await conn.run_sync(UUIDBase.metadata.create_all)


app = Litestar(
route_handlers=[AuthorController],
on_startup=[on_startup],
plugins=[sqlalchemy_plugin],
dependencies={"limit_offset": Provide(provide_limit_offset_pagination, sync_to_thread=False)},
)
2 changes: 1 addition & 1 deletion examples/sanic.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
# The `Base` class includes a `UUID` based primary key (`id`)
class AuthorModel(UUIDBase):
# we can optionally provide the table name instead of auto-generating it
__tablename__ = "author" #
__tablename__ = "author"
name: Mapped[str]
dob: Mapped[date | None]
books: Mapped[list[BookModel]] = relationship(back_populates="author", lazy="noload")
Expand Down
Loading

0 comments on commit d3f8cfd

Please sign in to comment.