Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Add UserModel.deleted and original_name

Revision ID: 06e977bc61c7
Revises: 7d1ec2b920ac
Create Date: 2025-11-26 11:43:34.825686

"""

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "06e977bc61c7"
down_revision = "7d1ec2b920ac"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("users", schema=None) as batch_op:
batch_op.add_column(
sa.Column("deleted", sa.Boolean(), server_default=sa.false(), nullable=False)
)
batch_op.add_column(sa.Column("original_name", sa.String(length=50), nullable=True))

with op.batch_alter_table("projects", schema=None) as batch_op:
batch_op.add_column(sa.Column("original_name", sa.String(length=50), nullable=True))
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("users", schema=None) as batch_op:
batch_op.drop_column("original_name")
batch_op.drop_column("deleted")

with op.batch_alter_table("projects", schema=None) as batch_op:
batch_op.drop_column("original_name")

# ### end Alembic commands ###
7 changes: 6 additions & 1 deletion src/dstack/_internal/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,9 @@ class UserModel(BaseModel):
global_role: Mapped[GlobalRole] = mapped_column(EnumAsString(GlobalRole, 100))
# deactivated users cannot access API
active: Mapped[bool] = mapped_column(Boolean, default=True)
deleted: Mapped[bool] = mapped_column(Boolean, server_default=false())
# `original_name` stores the name of a deleted user, while `name` is changed to a unique generated value.
original_name: Mapped[Optional[str]] = mapped_column(String(50), nullable=True)

# SSH keys can be null for users created before 0.19.33.
# Keys for those users are being gradually generated on /get_my_user calls.
Expand All @@ -212,8 +215,10 @@ class ProjectModel(BaseModel):
)
name: Mapped[str] = mapped_column(String(50), unique=True)
created_at: Mapped[datetime] = mapped_column(NaiveDateTime, default=get_current_datetime)
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
is_public: Mapped[bool] = mapped_column(Boolean, default=False)
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
# `original_name` stores the name of a deleted project, while `name` is changed to a unique generated value.
original_name: Mapped[Optional[str]] = mapped_column(String(50), nullable=True)

owner_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"))
owner: Mapped[UserModel] = relationship(lazy="joined")
Expand Down
51 changes: 35 additions & 16 deletions src/dstack/_internal/server/services/projects.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import secrets
import uuid
from typing import Awaitable, Callable, List, Optional, Tuple

Expand Down Expand Up @@ -194,28 +195,36 @@ async def delete_projects(
raise ServerClientError("Cannot delete the only project")

res = await session.execute(
select(ProjectModel.id).where(ProjectModel.name.in_(projects_names))
select(ProjectModel)
.where(
ProjectModel.name.in_(projects_names),
ProjectModel.deleted == False,
)
.options(load_only(ProjectModel.id, ProjectModel.name))
)
project_ids = res.scalars().all()
if len(project_ids) != len(projects_names):
projects = res.scalars().all()
if len(projects) != len(projects_names):
raise ServerClientError("Failed to delete non-existent projects")

for project_id in project_ids:
for p in projects:
# FIXME: The checks are not under lock,
# so there can be dangling active resources due to race conditions.
await _check_project_has_active_resources(session=session, project_id=project_id)
await _check_project_has_active_resources(session=session, project_id=p.id)

timestamp = str(int(get_current_datetime().timestamp()))
new_project_name = "_deleted_" + timestamp + ProjectModel.name
await session.execute(
update(ProjectModel)
.where(ProjectModel.name.in_(projects_names))
.values(
deleted=True,
name=new_project_name,
updates = []
for p in projects:
updates.append(
{
"id": p.id,
"name": f"_deleted_{timestamp}_{secrets.token_hex(8)}",
"original_name": p.name,
"deleted": True,
}
)
)
await session.execute(update(ProjectModel), updates)
await session.commit()
logger.info("Deleted projects %s by user %s", projects_names, user.name)


async def set_project_members(
Expand Down Expand Up @@ -244,12 +253,16 @@ async def set_project_members(
}
if new_admins_members != current_admins_members:
raise ForbiddenError("Access denied: changing project admins")

# FIXME: potentially long write transaction
# clear_project_members() issues DELETE without commit
await clear_project_members(session=session, project=project)
names = [m.username for m in members]
res = await session.execute(
select(UserModel).where((UserModel.name.in_(names)) | (UserModel.email.in_(names)))
select(UserModel).where(
(UserModel.name.in_(names)) | (UserModel.email.in_(names)),
UserModel.deleted == False,
)
)
users = res.scalars().all()
username_to_user = {user.name: user for user in users}
Expand Down Expand Up @@ -311,7 +324,10 @@ async def add_project_members(
raise ForbiddenError("Access denied: can only join public projects as user role")

res = await session.execute(
select(UserModel).where((UserModel.name.in_(usernames)) | (UserModel.email.in_(usernames)))
select(UserModel).where(
(UserModel.name.in_(usernames)) | (UserModel.email.in_(usernames)),
UserModel.deleted == False,
)
)
users_found = res.scalars().all()

Expand Down Expand Up @@ -700,7 +716,10 @@ async def remove_project_members(
raise ForbiddenError("Access denied: insufficient permissions to remove members")

res = await session.execute(
select(UserModel).where((UserModel.name.in_(usernames)) | (UserModel.email.in_(usernames)))
select(UserModel).where(
(UserModel.name.in_(usernames)) | (UserModel.email.in_(usernames)),
UserModel.deleted == False,
)
)
users_found = res.scalars().all()

Expand Down
73 changes: 62 additions & 11 deletions src/dstack/_internal/server/services/users.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import hashlib
import os
import re
import secrets
import uuid
from typing import Awaitable, Callable, List, Optional, Tuple

from sqlalchemy import delete, select, update
from sqlalchemy import func as safunc
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import load_only

from dstack._internal.core.errors import ResourceExistsError, ServerClientError
from dstack._internal.core.models.users import (
Expand All @@ -17,11 +19,11 @@
UserTokenCreds,
UserWithCreds,
)
from dstack._internal.server.models import DecryptedString, UserModel
from dstack._internal.server.models import DecryptedString, MemberModel, UserModel
from dstack._internal.server.services.permissions import get_default_permissions
from dstack._internal.server.utils.routers import error_forbidden
from dstack._internal.utils import crypto
from dstack._internal.utils.common import run_async
from dstack._internal.utils.common import get_current_datetime, get_or_error, run_async
from dstack._internal.utils.logging import get_logger

logger = get_logger(__name__)
Expand Down Expand Up @@ -53,8 +55,12 @@ async def list_users_for_user(

async def list_all_users(
session: AsyncSession,
include_deleted: bool = False,
) -> List[User]:
res = await session.execute(select(UserModel))
filters = []
if not include_deleted:
filters.append(UserModel.deleted == False)
res = await session.execute(select(UserModel).where(*filters))
user_models = res.scalars().all()
user_models = sorted(user_models, key=lambda u: u.created_at)
return [user_model_to_user(u) for u in user_models]
Expand Down Expand Up @@ -116,7 +122,10 @@ async def update_user(
) -> UserModel:
await session.execute(
update(UserModel)
.where(UserModel.name == username)
.where(
UserModel.name == username,
UserModel.deleted == False,
)
.values(
global_role=global_role,
email=email,
Expand All @@ -138,7 +147,10 @@ async def refresh_ssh_key(
private_bytes, public_bytes = await run_async(crypto.generate_rsa_key_pair_bytes, username)
await session.execute(
update(UserModel)
.where(UserModel.name == username)
.where(
UserModel.name == username,
UserModel.deleted == False,
)
.values(
ssh_private_key=private_bytes.decode(),
ssh_public_key=public_bytes.decode(),
Expand All @@ -158,7 +170,10 @@ async def refresh_user_token(
new_token = str(uuid.uuid4())
await session.execute(
update(UserModel)
.where(UserModel.name == username)
.where(
UserModel.name == username,
UserModel.deleted == False,
)
.values(
token=DecryptedString(plaintext=new_token),
token_hash=get_token_hash(new_token),
Expand All @@ -173,7 +188,37 @@ async def delete_users(
user: UserModel,
usernames: List[str],
):
await session.execute(delete(UserModel).where(UserModel.name.in_(usernames)))
if _ADMIN_USERNAME in usernames:
raise ServerClientError("User 'admin' cannot be deleted")

res = await session.execute(
select(UserModel)
.where(
UserModel.name.in_(usernames),
UserModel.deleted == False,
)
.options(load_only(UserModel.id, UserModel.name))
)
users = res.scalars().all()
if len(users) != len(usernames):
raise ServerClientError("Failed to delete non-existent users")

user_ids = [u.id for u in users]
timestamp = str(int(get_current_datetime().timestamp()))
updates = []
for u in users:
updates.append(
{
"id": u.id,
"name": f"_deleted_{timestamp}_{secrets.token_hex(8)}",
"original_name": u.name,
"deleted": True,
"active": False,
}
)
await session.execute(update(UserModel), updates)
await session.execute(delete(MemberModel).where(MemberModel.user_id.in_(user_ids)))
# Projects are not deleted automatically if owners are deleted.
await session.commit()
logger.info("Deleted users %s by user %s", usernames, user.name)

Expand All @@ -183,7 +228,7 @@ async def get_user_model_by_name(
username: str,
ignore_case: bool = False,
) -> Optional[UserModel]:
filters = []
filters = [UserModel.deleted == False]
if ignore_case:
filters.append(safunc.lower(UserModel.name) == safunc.lower(username))
else:
Expand All @@ -192,9 +237,14 @@ async def get_user_model_by_name(
return res.scalar()


async def get_user_model_by_name_or_error(session: AsyncSession, username: str) -> UserModel:
res = await session.execute(select(UserModel).where(UserModel.name == username))
return res.scalar_one()
async def get_user_model_by_name_or_error(
session: AsyncSession,
username: str,
ignore_case: bool = False,
) -> UserModel:
return get_or_error(
await get_user_model_by_name(session=session, username=username, ignore_case=ignore_case)
)


async def log_in_with_token(session: AsyncSession, token: str) -> Optional[UserModel]:
Expand All @@ -203,6 +253,7 @@ async def log_in_with_token(session: AsyncSession, token: str) -> Optional[UserM
select(UserModel).where(
UserModel.token_hash == token_hash,
UserModel.active == True,
UserModel.deleted == False,
)
)
user = res.scalar()
Expand Down
2 changes: 2 additions & 0 deletions src/dstack/_internal/server/testing/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ async def create_user(
ssh_public_key: Optional[str] = None,
ssh_private_key: Optional[str] = None,
active: bool = True,
deleted: bool = False,
) -> UserModel:
if token is None:
token = str(uuid.uuid4())
Expand All @@ -148,6 +149,7 @@ async def create_user(
ssh_public_key=ssh_public_key,
ssh_private_key=ssh_private_key,
active=active,
deleted=deleted,
)
session.add(user)
await session.commit()
Expand Down
7 changes: 5 additions & 2 deletions src/tests/_internal/server/routers/test_projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,9 +472,12 @@ async def test_cannot_delete_the_only_project(

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_deletes_projects(self, test_db, session: AsyncSession, client: AsyncClient):
@pytest.mark.parametrize("project_name", ["project1", "a" * 50])
async def test_deletes_projects(
self, test_db, session: AsyncSession, client: AsyncClient, project_name: str
):
user = await create_user(session=session, global_role=GlobalRole.USER)
project1 = await create_project(session=session, owner=user, name="project1")
project1 = await create_project(session=session, owner=user, name=project_name)
await add_project_member(
session=session, project=project1, user=user, project_role=ProjectRole.ADMIN
)
Expand Down
Loading