diff --git a/src/dstack/_internal/server/migrations/versions/06e977bc61c7_add_usermodel_deleted_and_original_name.py b/src/dstack/_internal/server/migrations/versions/06e977bc61c7_add_usermodel_deleted_and_original_name.py new file mode 100644 index 000000000..434c0e286 --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/06e977bc61c7_add_usermodel_deleted_and_original_name.py @@ -0,0 +1,41 @@ +"""Add UserModel.deleted and original_name + +Revision ID: 06e977bc61c7 +Revises: 7d1ec2b920ac +Create Date: 2025-11-26 11:43:34.825686 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "06e977bc61c7" +down_revision = "7d1ec2b920ac" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("users", schema=None) as batch_op: + batch_op.add_column( + sa.Column("deleted", sa.Boolean(), server_default=sa.false(), nullable=False) + ) + batch_op.add_column(sa.Column("original_name", sa.String(length=50), nullable=True)) + + with op.batch_alter_table("projects", schema=None) as batch_op: + batch_op.add_column(sa.Column("original_name", sa.String(length=50), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("users", schema=None) as batch_op: + batch_op.drop_column("original_name") + batch_op.drop_column("deleted") + + with op.batch_alter_table("projects", schema=None) as batch_op: + batch_op.drop_column("original_name") + + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index e88f83d59..04185762a 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -190,6 +190,9 @@ class UserModel(BaseModel): global_role: Mapped[GlobalRole] = mapped_column(EnumAsString(GlobalRole, 100)) # deactivated users cannot access API active: Mapped[bool] = mapped_column(Boolean, default=True) + deleted: Mapped[bool] = mapped_column(Boolean, server_default=false()) + # `original_name` stores the name of a deleted user, while `name` is changed to a unique generated value. + original_name: Mapped[Optional[str]] = mapped_column(String(50), nullable=True) # SSH keys can be null for users created before 0.19.33. # Keys for those users are being gradually generated on /get_my_user calls. @@ -212,8 +215,10 @@ class ProjectModel(BaseModel): ) name: Mapped[str] = mapped_column(String(50), unique=True) created_at: Mapped[datetime] = mapped_column(NaiveDateTime, default=get_current_datetime) - deleted: Mapped[bool] = mapped_column(Boolean, default=False) is_public: Mapped[bool] = mapped_column(Boolean, default=False) + deleted: Mapped[bool] = mapped_column(Boolean, default=False) + # `original_name` stores the name of a deleted project, while `name` is changed to a unique generated value. + original_name: Mapped[Optional[str]] = mapped_column(String(50), nullable=True) owner_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id", ondelete="CASCADE")) owner: Mapped[UserModel] = relationship(lazy="joined") diff --git a/src/dstack/_internal/server/services/projects.py b/src/dstack/_internal/server/services/projects.py index 330fcceb4..cc6c37401 100644 --- a/src/dstack/_internal/server/services/projects.py +++ b/src/dstack/_internal/server/services/projects.py @@ -1,3 +1,4 @@ +import secrets import uuid from typing import Awaitable, Callable, List, Optional, Tuple @@ -194,28 +195,36 @@ async def delete_projects( raise ServerClientError("Cannot delete the only project") res = await session.execute( - select(ProjectModel.id).where(ProjectModel.name.in_(projects_names)) + select(ProjectModel) + .where( + ProjectModel.name.in_(projects_names), + ProjectModel.deleted == False, + ) + .options(load_only(ProjectModel.id, ProjectModel.name)) ) - project_ids = res.scalars().all() - if len(project_ids) != len(projects_names): + projects = res.scalars().all() + if len(projects) != len(projects_names): raise ServerClientError("Failed to delete non-existent projects") - for project_id in project_ids: + for p in projects: # FIXME: The checks are not under lock, # so there can be dangling active resources due to race conditions. - await _check_project_has_active_resources(session=session, project_id=project_id) + await _check_project_has_active_resources(session=session, project_id=p.id) timestamp = str(int(get_current_datetime().timestamp())) - new_project_name = "_deleted_" + timestamp + ProjectModel.name - await session.execute( - update(ProjectModel) - .where(ProjectModel.name.in_(projects_names)) - .values( - deleted=True, - name=new_project_name, + updates = [] + for p in projects: + updates.append( + { + "id": p.id, + "name": f"_deleted_{timestamp}_{secrets.token_hex(8)}", + "original_name": p.name, + "deleted": True, + } ) - ) + await session.execute(update(ProjectModel), updates) await session.commit() + logger.info("Deleted projects %s by user %s", projects_names, user.name) async def set_project_members( @@ -244,12 +253,16 @@ async def set_project_members( } if new_admins_members != current_admins_members: raise ForbiddenError("Access denied: changing project admins") + # FIXME: potentially long write transaction # clear_project_members() issues DELETE without commit await clear_project_members(session=session, project=project) names = [m.username for m in members] res = await session.execute( - select(UserModel).where((UserModel.name.in_(names)) | (UserModel.email.in_(names))) + select(UserModel).where( + (UserModel.name.in_(names)) | (UserModel.email.in_(names)), + UserModel.deleted == False, + ) ) users = res.scalars().all() username_to_user = {user.name: user for user in users} @@ -311,7 +324,10 @@ async def add_project_members( raise ForbiddenError("Access denied: can only join public projects as user role") res = await session.execute( - select(UserModel).where((UserModel.name.in_(usernames)) | (UserModel.email.in_(usernames))) + select(UserModel).where( + (UserModel.name.in_(usernames)) | (UserModel.email.in_(usernames)), + UserModel.deleted == False, + ) ) users_found = res.scalars().all() @@ -700,7 +716,10 @@ async def remove_project_members( raise ForbiddenError("Access denied: insufficient permissions to remove members") res = await session.execute( - select(UserModel).where((UserModel.name.in_(usernames)) | (UserModel.email.in_(usernames))) + select(UserModel).where( + (UserModel.name.in_(usernames)) | (UserModel.email.in_(usernames)), + UserModel.deleted == False, + ) ) users_found = res.scalars().all() diff --git a/src/dstack/_internal/server/services/users.py b/src/dstack/_internal/server/services/users.py index 9fdbe3b4e..aed0a23de 100644 --- a/src/dstack/_internal/server/services/users.py +++ b/src/dstack/_internal/server/services/users.py @@ -1,12 +1,14 @@ import hashlib import os import re +import secrets import uuid from typing import Awaitable, Callable, List, Optional, Tuple from sqlalchemy import delete, select, update from sqlalchemy import func as safunc from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import load_only from dstack._internal.core.errors import ResourceExistsError, ServerClientError from dstack._internal.core.models.users import ( @@ -17,11 +19,11 @@ UserTokenCreds, UserWithCreds, ) -from dstack._internal.server.models import DecryptedString, UserModel +from dstack._internal.server.models import DecryptedString, MemberModel, UserModel from dstack._internal.server.services.permissions import get_default_permissions from dstack._internal.server.utils.routers import error_forbidden from dstack._internal.utils import crypto -from dstack._internal.utils.common import run_async +from dstack._internal.utils.common import get_current_datetime, get_or_error, run_async from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) @@ -53,8 +55,12 @@ async def list_users_for_user( async def list_all_users( session: AsyncSession, + include_deleted: bool = False, ) -> List[User]: - res = await session.execute(select(UserModel)) + filters = [] + if not include_deleted: + filters.append(UserModel.deleted == False) + res = await session.execute(select(UserModel).where(*filters)) user_models = res.scalars().all() user_models = sorted(user_models, key=lambda u: u.created_at) return [user_model_to_user(u) for u in user_models] @@ -116,7 +122,10 @@ async def update_user( ) -> UserModel: await session.execute( update(UserModel) - .where(UserModel.name == username) + .where( + UserModel.name == username, + UserModel.deleted == False, + ) .values( global_role=global_role, email=email, @@ -138,7 +147,10 @@ async def refresh_ssh_key( private_bytes, public_bytes = await run_async(crypto.generate_rsa_key_pair_bytes, username) await session.execute( update(UserModel) - .where(UserModel.name == username) + .where( + UserModel.name == username, + UserModel.deleted == False, + ) .values( ssh_private_key=private_bytes.decode(), ssh_public_key=public_bytes.decode(), @@ -158,7 +170,10 @@ async def refresh_user_token( new_token = str(uuid.uuid4()) await session.execute( update(UserModel) - .where(UserModel.name == username) + .where( + UserModel.name == username, + UserModel.deleted == False, + ) .values( token=DecryptedString(plaintext=new_token), token_hash=get_token_hash(new_token), @@ -173,7 +188,37 @@ async def delete_users( user: UserModel, usernames: List[str], ): - await session.execute(delete(UserModel).where(UserModel.name.in_(usernames))) + if _ADMIN_USERNAME in usernames: + raise ServerClientError("User 'admin' cannot be deleted") + + res = await session.execute( + select(UserModel) + .where( + UserModel.name.in_(usernames), + UserModel.deleted == False, + ) + .options(load_only(UserModel.id, UserModel.name)) + ) + users = res.scalars().all() + if len(users) != len(usernames): + raise ServerClientError("Failed to delete non-existent users") + + user_ids = [u.id for u in users] + timestamp = str(int(get_current_datetime().timestamp())) + updates = [] + for u in users: + updates.append( + { + "id": u.id, + "name": f"_deleted_{timestamp}_{secrets.token_hex(8)}", + "original_name": u.name, + "deleted": True, + "active": False, + } + ) + await session.execute(update(UserModel), updates) + await session.execute(delete(MemberModel).where(MemberModel.user_id.in_(user_ids))) + # Projects are not deleted automatically if owners are deleted. await session.commit() logger.info("Deleted users %s by user %s", usernames, user.name) @@ -183,7 +228,7 @@ async def get_user_model_by_name( username: str, ignore_case: bool = False, ) -> Optional[UserModel]: - filters = [] + filters = [UserModel.deleted == False] if ignore_case: filters.append(safunc.lower(UserModel.name) == safunc.lower(username)) else: @@ -192,9 +237,14 @@ async def get_user_model_by_name( return res.scalar() -async def get_user_model_by_name_or_error(session: AsyncSession, username: str) -> UserModel: - res = await session.execute(select(UserModel).where(UserModel.name == username)) - return res.scalar_one() +async def get_user_model_by_name_or_error( + session: AsyncSession, + username: str, + ignore_case: bool = False, +) -> UserModel: + return get_or_error( + await get_user_model_by_name(session=session, username=username, ignore_case=ignore_case) + ) async def log_in_with_token(session: AsyncSession, token: str) -> Optional[UserModel]: @@ -203,6 +253,7 @@ async def log_in_with_token(session: AsyncSession, token: str) -> Optional[UserM select(UserModel).where( UserModel.token_hash == token_hash, UserModel.active == True, + UserModel.deleted == False, ) ) user = res.scalar() diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 883ce1453..9c9585982 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -135,6 +135,7 @@ async def create_user( ssh_public_key: Optional[str] = None, ssh_private_key: Optional[str] = None, active: bool = True, + deleted: bool = False, ) -> UserModel: if token is None: token = str(uuid.uuid4()) @@ -148,6 +149,7 @@ async def create_user( ssh_public_key=ssh_public_key, ssh_private_key=ssh_private_key, active=active, + deleted=deleted, ) session.add(user) await session.commit() diff --git a/src/tests/_internal/server/routers/test_projects.py b/src/tests/_internal/server/routers/test_projects.py index d3b042696..8e21957f5 100644 --- a/src/tests/_internal/server/routers/test_projects.py +++ b/src/tests/_internal/server/routers/test_projects.py @@ -472,9 +472,12 @@ async def test_cannot_delete_the_only_project( @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) - async def test_deletes_projects(self, test_db, session: AsyncSession, client: AsyncClient): + @pytest.mark.parametrize("project_name", ["project1", "a" * 50]) + async def test_deletes_projects( + self, test_db, session: AsyncSession, client: AsyncClient, project_name: str + ): user = await create_user(session=session, global_role=GlobalRole.USER) - project1 = await create_project(session=session, owner=user, name="project1") + project1 = await create_project(session=session, owner=user, name=project_name) await add_project_member( session=session, project=project1, user=user, project_role=ProjectRole.ADMIN ) diff --git a/src/tests/_internal/server/routers/test_users.py b/src/tests/_internal/server/routers/test_users.py index 54da63803..8b8c7ca2a 100644 --- a/src/tests/_internal/server/routers/test_users.py +++ b/src/tests/_internal/server/routers/test_users.py @@ -8,9 +8,20 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from dstack._internal.core.models.projects import ProjectRole +from dstack._internal.core.models.runs import JobStatus, RunStatus from dstack._internal.core.models.users import GlobalRole -from dstack._internal.server.models import UserModel -from dstack._internal.server.testing.common import create_user, get_auth_headers +from dstack._internal.server.models import MemberModel, UserModel +from dstack._internal.server.services.projects import add_project_member +from dstack._internal.server.testing.common import ( + create_job, + create_probe, + create_project, + create_repo, + create_run, + create_user, + get_auth_headers, +) class TestListUsers: @@ -22,7 +33,9 @@ async def test_returns_40x_if_not_authenticated(self, test_db, client: AsyncClie @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) - async def test_admins_see_all_users(self, test_db, session: AsyncSession, client: AsyncClient): + async def test_admins_see_all_non_deleted_users( + self, test_db, session: AsyncSession, client: AsyncClient + ): admin = await create_user( session=session, name="admin", @@ -35,6 +48,13 @@ async def test_admins_see_all_users(self, test_db, session: AsyncSession, client created_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), global_role=GlobalRole.USER, ) + await create_user( + session=session, + name="deleted_user", + created_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), + global_role=GlobalRole.USER, + deleted=True, + ) response = await client.post("/api/users/list", headers=get_auth_headers(admin.token)) assert response.status_code in [200] assert response.json() == [ @@ -360,9 +380,12 @@ async def test_returns_40x_if_not_authenticated(self, test_db, client: AsyncClie @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) - async def test_deletes_users(self, test_db, session: AsyncSession, client: AsyncClient): + @pytest.mark.parametrize("username", ["test", "a" * 50]) + async def test_deletes_users( + self, test_db, session: AsyncSession, client: AsyncClient, username: str + ): admin = await create_user(name="admin", session=session) - user = await create_user(name="test", session=session) + user = await create_user(name=username, session=session) response = await client.post( "/api/users/delete", headers=get_auth_headers(admin.token), @@ -372,6 +395,78 @@ async def test_deletes_users(self, test_db, session: AsyncSession, client: Async res = await session.execute(select(UserModel).where(UserModel.name == user.name)) assert len(res.scalars().all()) == 0 + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_returns_400_if_users_not_exist( + self, test_db, session: AsyncSession, client: AsyncClient + ): + admin = await create_user(name="admin", session=session) + user1 = await create_user(name="test1", session=session) + user2 = await create_user(name="test2", session=session) + response = await client.post( + "/api/users/delete", + headers=get_auth_headers(admin.token), + json={"users": [user1.name, "non_existing_user"]}, + ) + assert response.status_code == 400 + response = await client.post( + "/api/users/delete", + headers=get_auth_headers(admin.token), + json={"users": [user1.name, user2.name]}, + ) + assert response.status_code == 200 + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + @pytest.mark.usefixtures("image_config_mock") + async def test_deletes_user_with_resources( + self, test_db, session: AsyncSession, client: AsyncClient + ): + admin = await create_user(name="admin", session=session) + user = await create_user(name="temp", session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + status=RunStatus.RUNNING, + ) + job = await create_job(session=session, run=run, status=JobStatus.RUNNING) + await create_probe(session=session, job=job) + response = await client.post( + "/api/users/delete", + headers=get_auth_headers(admin.token), + json={"users": [user.name]}, + ) + assert response.status_code == 200 + res = await session.execute(select(UserModel).where(UserModel.name == user.name)) + assert res.scalar() is None + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + @pytest.mark.usefixtures("image_config_mock") + async def test_deleting_users_deletes_members( + self, test_db, session: AsyncSession, client: AsyncClient + ): + admin = await create_user(name="admin", session=session) + user = await create_user(name="temp", session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + response = await client.post( + "/api/users/delete", + headers=get_auth_headers(admin.token), + json={"users": [user.name]}, + ) + assert response.status_code == 200 + res = await session.execute(select(UserModel).where(UserModel.name == user.name)) + assert res.scalar() is None + res = await session.execute(select(MemberModel).where(MemberModel.user_id == user.id)) + assert res.scalar() is None + class TestRefreshToken: @pytest.mark.asyncio