Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/dstack/_internal/core/backends/aws/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,7 @@ def create_volume(self, volume: Volume) -> VolumeProvisioningData:
tags = {
"Name": volume.configuration.name,
"owner": "dstack",
"dstack_user": volume.user,
"dstack_project": volume.project_name,
}
tags = merge_tags(tags=tags, backend_tags=self.config.tags)
Expand Down
1 change: 1 addition & 0 deletions src/dstack/_internal/core/backends/gcp/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,7 @@ def create_volume(self, volume: Volume) -> VolumeProvisioningData:
labels = {
"owner": "dstack",
"dstack_project": volume.project_name.lower(),
"dstack_user": volume.user,
}
labels = {k: v for k, v in labels.items() if gcp_resources.is_valid_label_value(v)}
labels = merge_tags(tags=labels, backend_tags=self.config.tags)
Expand Down
3 changes: 3 additions & 0 deletions src/dstack/_internal/core/models/volumes.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ class VolumeAttachmentData(CoreModel):
class Volume(CoreModel):
id: uuid.UUID
name: str
# Default user to "" for client backward compatibility (old 0.18 servers).
# TODO: Remove in 0.19
user: str = ""
project_name: str
configuration: VolumeConfiguration
external: bool
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,10 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel):
# Take lock to prevent attaching volumes that are to be deleted.
# If the volume was deleted before the lock, the volume will fail to attach and the job will fail.
await session.execute(
select(VolumeModel).where(VolumeModel.id.in_(volumes_ids)).with_for_update()
select(VolumeModel)
.where(VolumeModel.id.in_(volumes_ids))
.options(selectinload(VolumeModel.user))
.with_for_update()
)
async with get_locker().lock_ctx(VolumeModel.__tablename__, volumes_ids):
if len(volume_models) > 0:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ async def _process_submitted_volume(session: AsyncSession, volume_model: VolumeM
select(VolumeModel)
.where(VolumeModel.id == volume_model.id)
.options(joinedload(VolumeModel.project).joinedload(ProjectModel.backends))
.options(joinedload(VolumeModel.user))
.execution_options(populate_existing=True)
)
volume_model = res.unique().scalar_one()
Expand Down
52 changes: 52 additions & 0 deletions src/dstack/_internal/server/migrations/versions/82b32a135ea2_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""empty message

Revision ID: 82b32a135ea2
Revises: afbc600ff2b2
Create Date: 2024-11-04 15:46:37.719531

"""

import sqlalchemy as sa
import sqlalchemy_utils
from alembic import op

# revision identifiers, used by Alembic.
revision = "82b32a135ea2"
down_revision = "afbc600ff2b2"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("volumes", schema=None) as batch_op:
batch_op.add_column(
sa.Column("user_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=True)
)
batch_op.create_foreign_key(
batch_op.f("fk_volumes_user_id_users"),
"users",
["user_id"],
["id"],
ondelete="CASCADE",
)

# ### end Alembic commands ###

# update any existing volumes and set the user_id equal to the project_owner.id which created the volume
op.execute(
"UPDATE volumes SET user_id = (SELECT owner_id FROM projects JOIN volumes ON projects.id = volumes.project_id) WHERE user_id IS NULL"
)

# set volumes.user_id to non-nullable
with op.batch_alter_table("volumes", schema=None) as batch_op:
batch_op.alter_column("user_id", nullable=False)


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("volumes", schema=None) as batch_op:
batch_op.drop_constraint(batch_op.f("fk_volumes_user_id_users"), type_="foreignkey")
batch_op.drop_column("user_id")

# ### end Alembic commands ###
3 changes: 3 additions & 0 deletions src/dstack/_internal/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,9 @@ class VolumeModel(BaseModel):
)
name: Mapped[str] = mapped_column(String(100))

user_id: Mapped["UserModel"] = mapped_column(ForeignKey("users.id", ondelete="CASCADE"))
user: Mapped["UserModel"] = relationship()

project_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("projects.id", ondelete="CASCADE"))
project: Mapped["ProjectModel"] = relationship(foreign_keys=[project_id])

Expand Down
3 changes: 2 additions & 1 deletion src/dstack/_internal/server/routers/volumes.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,11 @@ async def create_volume(
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()),
) -> Volume:
_, project = user_project
user, project = user_project
return await volumes_services.create_volume(
session=session,
project=project,
user=user,
configuration=body.configuration,
)

Expand Down
20 changes: 16 additions & 4 deletions src/dstack/_internal/server/services/volumes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from sqlalchemy import and_, func, or_, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import joinedload, selectinload

from dstack._internal.core.backends import BACKENDS_WITH_VOLUMES_SUPPORT
from dstack._internal.core.errors import (
Expand Down Expand Up @@ -102,7 +102,11 @@ async def list_projects_volume_models(
if ascending:
order_by = (VolumeModel.created_at.asc(), VolumeModel.id.desc())
res = await session.execute(
select(VolumeModel).where(*filters).order_by(*order_by).limit(limit)
select(VolumeModel)
.where(*filters)
.order_by(*order_by)
.limit(limit)
.options(joinedload(VolumeModel.user))
)
volume_models = list(res.scalars().all())
return volume_models
Expand Down Expand Up @@ -130,7 +134,9 @@ async def list_project_volume_models(
filters.append(VolumeModel.name.in_(names))
if not include_deleted:
filters.append(VolumeModel.deleted == False)
res = await session.execute(select(VolumeModel).where(*filters))
res = await session.execute(
select(VolumeModel).where(*filters).options(joinedload(VolumeModel.user))
)
return list(res.scalars().all())


Expand All @@ -157,13 +163,16 @@ async def get_project_volume_model_by_name(
]
if not include_deleted:
filters.append(VolumeModel.deleted == False)
res = await session.execute(select(VolumeModel).where(*filters))
res = await session.execute(
select(VolumeModel).where(*filters).options(joinedload(VolumeModel.user))
)
return res.scalar_one_or_none()


async def create_volume(
session: AsyncSession,
project: ProjectModel,
user: UserModel,
configuration: VolumeConfiguration,
) -> Volume:
_validate_volume_configuration(configuration)
Expand Down Expand Up @@ -193,6 +202,7 @@ async def create_volume(
volume_model = VolumeModel(
id=uuid.uuid4(),
name=configuration.name,
user_id=user.id,
project=project,
status=VolumeStatus.SUBMITTED,
configuration=configuration.json(),
Expand Down Expand Up @@ -224,6 +234,7 @@ async def delete_volumes(session: AsyncSession, project: ProjectModel, names: Li
VolumeModel.name.in_(names),
VolumeModel.deleted == False,
)
.options(selectinload(VolumeModel.user))
.options(selectinload(VolumeModel.instances))
.execution_options(populate_existing=True)
.with_for_update()
Expand Down Expand Up @@ -263,6 +274,7 @@ def volume_model_to_volume(volume_model: VolumeModel) -> Volume:
return Volume(
name=volume_model.name,
project_name=volume_model.project.name,
user=volume_model.user.name,
configuration=configuration,
external=configuration.volume_id is not None,
created_at=volume_model.created_at.replace(tzinfo=timezone.utc),
Expand Down
2 changes: 2 additions & 0 deletions src/dstack/_internal/server/testing/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,7 @@ async def create_instance(
async def create_volume(
session: AsyncSession,
project: ProjectModel,
user: UserModel,
status: VolumeStatus = VolumeStatus.SUBMITTED,
created_at: datetime = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc),
configuration: Optional[VolumeConfiguration] = None,
Expand All @@ -532,6 +533,7 @@ async def create_volume(
configuration = get_volume_configuration(backend=backend, region=region)
vm = VolumeModel(
project=project,
user_id=user.id,
name=configuration.name,
status=status,
created_at=created_at,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ async def test_provisioning_shim_with_volumes(
volume = await create_volume(
session=session,
project=project,
user=user,
status=VolumeStatus.ACTIVE,
configuration=get_volume_configuration(
name="my-vol", backend=BackendType.AWS, region="us-east-1"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,7 @@ async def test_assigns_job_to_instance_with_volumes(self, test_db, session: Asyn
volume = await create_volume(
session=session,
project=project,
user=user,
status=VolumeStatus.ACTIVE,
volume_provisioning_data=get_volume_provisioning_data(),
backend=BackendType.AWS,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from dstack._internal.server.background.tasks.process_volumes import process_submitted_volumes
from dstack._internal.server.testing.common import (
create_project,
create_user,
create_volume,
)

Expand All @@ -17,8 +18,9 @@ class TestProcessSubmittedVolumes:
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_fails_job_when_no_backends(self, test_db, session: AsyncSession):
project = await create_project(session=session)
user = await create_user(session=session)
volume = await create_volume(
session=session, project=project, status=VolumeStatus.SUBMITTED
session=session, project=project, user=user, status=VolumeStatus.SUBMITTED
)
await process_submitted_volumes()
await session.refresh(volume)
Expand All @@ -29,8 +31,9 @@ async def test_fails_job_when_no_backends(self, test_db, session: AsyncSession):
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_provisiones_volumes(self, test_db, session: AsyncSession):
project = await create_project(session=session)
user = await create_user(session=session)
volume = await create_volume(
session=session, project=project, status=VolumeStatus.SUBMITTED
session=session, project=project, user=user, status=VolumeStatus.SUBMITTED
)
with patch(
"dstack._internal.server.services.backends.get_project_backend_by_type_or_error"
Expand Down
15 changes: 15 additions & 0 deletions src/tests/_internal/server/routers/test_volumes.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,15 @@ async def test_lists_volumes_across_projects(
volume1 = await create_volume(
session=session,
project=project1,
user=user,
created_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc),
configuration=get_volume_configuration(name="volume1"),
)
project2 = await create_project(session, name="project2", owner=user)
volume2 = await create_volume(
session=session,
project=project2,
user=user,
created_at=datetime(2023, 1, 2, 3, 5, tzinfo=timezone.utc),
configuration=get_volume_configuration(name="volume2"),
)
Expand All @@ -65,6 +67,7 @@ async def test_lists_volumes_across_projects(
"id": str(volume2.id),
"name": volume2.name,
"project_name": project2.name,
"user": user.name,
"configuration": json.loads(volume2.configuration),
"external": False,
"created_at": "2023-01-02T03:05:00+00:00",
Expand All @@ -79,6 +82,7 @@ async def test_lists_volumes_across_projects(
"id": str(volume1.id),
"name": volume1.name,
"project_name": project1.name,
"user": user.name,
"configuration": json.loads(volume1.configuration),
"external": False,
"created_at": "2023-01-02T03:04:00+00:00",
Expand All @@ -104,6 +108,7 @@ async def test_lists_volumes_across_projects(
"id": str(volume1.id),
"name": volume1.name,
"project_name": project1.name,
"user": user.name,
"configuration": json.loads(volume1.configuration),
"external": False,
"created_at": "2023-01-02T03:04:00+00:00",
Expand Down Expand Up @@ -134,12 +139,14 @@ async def test_non_admin_cannot_see_others_projects(
volume1 = await create_volume(
session=session,
project=project1,
user=user1,
created_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc),
configuration=get_volume_configuration(name="volume1"),
)
await create_volume(
session=session,
project=project2,
user=user2,
created_at=datetime(2023, 1, 2, 3, 5, tzinfo=timezone.utc),
configuration=get_volume_configuration(name="volume2"),
)
Expand All @@ -154,6 +161,7 @@ async def test_non_admin_cannot_see_others_projects(
"id": str(volume1.id),
"name": volume1.name,
"project_name": project1.name,
"user": user1.name,
"configuration": json.loads(volume1.configuration),
"external": False,
"created_at": "2023-01-02T03:04:00+00:00",
Expand Down Expand Up @@ -187,6 +195,7 @@ async def test_lists_volumes(self, test_db, session: AsyncSession, client: Async
volume = await create_volume(
session=session,
project=project,
user=user,
created_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc),
)
response = await client.post(
Expand All @@ -199,6 +208,7 @@ async def test_lists_volumes(self, test_db, session: AsyncSession, client: Async
"id": str(volume.id),
"name": volume.name,
"project_name": project.name,
"user": user.name,
"configuration": json.loads(volume.configuration),
"external": False,
"created_at": "2023-01-02T03:04:00+00:00",
Expand Down Expand Up @@ -232,6 +242,7 @@ async def test_returns_volume(self, test_db, session: AsyncSession, client: Asyn
volume = await create_volume(
session=session,
project=project,
user=user,
created_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc),
)
response = await client.post(
Expand All @@ -244,6 +255,7 @@ async def test_returns_volume(self, test_db, session: AsyncSession, client: Asyn
"id": str(volume.id),
"name": volume.name,
"project_name": project.name,
"user": user.name,
"configuration": json.loads(volume.configuration),
"external": False,
"created_at": "2023-01-02T03:04:00+00:00",
Expand Down Expand Up @@ -305,6 +317,7 @@ async def test_creates_volume(self, test_db, session: AsyncSession, client: Asyn
"name": configuration.name,
"project_name": project.name,
"configuration": configuration,
"user": user.name,
"external": False,
"created_at": "2023-01-02T03:04:00+00:00",
"status": "submitted",
Expand Down Expand Up @@ -338,6 +351,7 @@ async def test_deletes_volumes(self, test_db, session: AsyncSession, client: Asy
volume = await create_volume(
session=session,
project=project,
user=user,
volume_provisioning_data=get_volume_provisioning_data(),
)
with patch(
Expand Down Expand Up @@ -369,6 +383,7 @@ async def test_returns_400_when_volumes_in_use(
volume = await create_volume(
session=session,
project=project,
user=user,
volume_provisioning_data=get_volume_provisioning_data(),
)
instance = await create_instance(
Expand Down