diff --git a/src/dstack/_internal/core/backends/aws/compute.py b/src/dstack/_internal/core/backends/aws/compute.py index b52f25f0a6..61544e8bfb 100644 --- a/src/dstack/_internal/core/backends/aws/compute.py +++ b/src/dstack/_internal/core/backends/aws/compute.py @@ -489,6 +489,7 @@ def create_volume(self, volume: Volume) -> VolumeProvisioningData: tags = { "Name": volume.configuration.name, "owner": "dstack", + "dstack_user": volume.user, "dstack_project": volume.project_name, } tags = merge_tags(tags=tags, backend_tags=self.config.tags) diff --git a/src/dstack/_internal/core/backends/gcp/compute.py b/src/dstack/_internal/core/backends/gcp/compute.py index 1f263d14a4..2133ccc833 100644 --- a/src/dstack/_internal/core/backends/gcp/compute.py +++ b/src/dstack/_internal/core/backends/gcp/compute.py @@ -500,6 +500,7 @@ def create_volume(self, volume: Volume) -> VolumeProvisioningData: labels = { "owner": "dstack", "dstack_project": volume.project_name.lower(), + "dstack_user": volume.user, } labels = {k: v for k, v in labels.items() if gcp_resources.is_valid_label_value(v)} labels = merge_tags(tags=labels, backend_tags=self.config.tags) diff --git a/src/dstack/_internal/core/models/volumes.py b/src/dstack/_internal/core/models/volumes.py index b25aff5484..54a63edb9f 100644 --- a/src/dstack/_internal/core/models/volumes.py +++ b/src/dstack/_internal/core/models/volumes.py @@ -64,6 +64,9 @@ class VolumeAttachmentData(CoreModel): class Volume(CoreModel): id: uuid.UUID name: str + # Default user to "" for client backward compatibility (old 0.18 servers). + # TODO: Remove in 0.19 + user: str = "" project_name: str configuration: VolumeConfiguration external: bool diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py index 5c9585489a..1cfb33a63b 100644 --- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py @@ -290,7 +290,10 @@ async def _process_submitted_job(session: AsyncSession, job_model: JobModel): # Take lock to prevent attaching volumes that are to be deleted. # If the volume was deleted before the lock, the volume will fail to attach and the job will fail. await session.execute( - select(VolumeModel).where(VolumeModel.id.in_(volumes_ids)).with_for_update() + select(VolumeModel) + .where(VolumeModel.id.in_(volumes_ids)) + .options(selectinload(VolumeModel.user)) + .with_for_update() ) async with get_locker().lock_ctx(VolumeModel.__tablename__, volumes_ids): if len(volume_models) > 0: diff --git a/src/dstack/_internal/server/background/tasks/process_volumes.py b/src/dstack/_internal/server/background/tasks/process_volumes.py index 9089de5f3b..da6d145bce 100644 --- a/src/dstack/_internal/server/background/tasks/process_volumes.py +++ b/src/dstack/_internal/server/background/tasks/process_volumes.py @@ -49,6 +49,7 @@ async def _process_submitted_volume(session: AsyncSession, volume_model: VolumeM select(VolumeModel) .where(VolumeModel.id == volume_model.id) .options(joinedload(VolumeModel.project).joinedload(ProjectModel.backends)) + .options(joinedload(VolumeModel.user)) .execution_options(populate_existing=True) ) volume_model = res.unique().scalar_one() diff --git a/src/dstack/_internal/server/migrations/versions/82b32a135ea2_.py b/src/dstack/_internal/server/migrations/versions/82b32a135ea2_.py new file mode 100644 index 0000000000..29cd3e8e66 --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/82b32a135ea2_.py @@ -0,0 +1,52 @@ +"""empty message + +Revision ID: 82b32a135ea2 +Revises: afbc600ff2b2 +Create Date: 2024-11-04 15:46:37.719531 + +""" + +import sqlalchemy as sa +import sqlalchemy_utils +from alembic import op + +# revision identifiers, used by Alembic. +revision = "82b32a135ea2" +down_revision = "afbc600ff2b2" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("volumes", schema=None) as batch_op: + batch_op.add_column( + sa.Column("user_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=True) + ) + batch_op.create_foreign_key( + batch_op.f("fk_volumes_user_id_users"), + "users", + ["user_id"], + ["id"], + ondelete="CASCADE", + ) + + # ### end Alembic commands ### + + # update any existing volumes and set the user_id equal to the project_owner.id which created the volume + op.execute( + "UPDATE volumes SET user_id = (SELECT owner_id FROM projects JOIN volumes ON projects.id = volumes.project_id) WHERE user_id IS NULL" + ) + + # set volumes.user_id to non-nullable + with op.batch_alter_table("volumes", schema=None) as batch_op: + batch_op.alter_column("user_id", nullable=False) + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("volumes", schema=None) as batch_op: + batch_op.drop_constraint(batch_op.f("fk_volumes_user_id_users"), type_="foreignkey") + batch_op.drop_column("user_id") + + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index 705d1e5b97..0a83dfc100 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -535,6 +535,9 @@ class VolumeModel(BaseModel): ) name: Mapped[str] = mapped_column(String(100)) + user_id: Mapped["UserModel"] = mapped_column(ForeignKey("users.id", ondelete="CASCADE")) + user: Mapped["UserModel"] = relationship() + project_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("projects.id", ondelete="CASCADE")) project: Mapped["ProjectModel"] = relationship(foreign_keys=[project_id]) diff --git a/src/dstack/_internal/server/routers/volumes.py b/src/dstack/_internal/server/routers/volumes.py index 9004f7c25b..a63fa8ba24 100644 --- a/src/dstack/_internal/server/routers/volumes.py +++ b/src/dstack/_internal/server/routers/volumes.py @@ -68,10 +68,11 @@ async def create_volume( session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), ) -> Volume: - _, project = user_project + user, project = user_project return await volumes_services.create_volume( session=session, project=project, + user=user, configuration=body.configuration, ) diff --git a/src/dstack/_internal/server/services/volumes.py b/src/dstack/_internal/server/services/volumes.py index 9a0359d233..0063b3419f 100644 --- a/src/dstack/_internal/server/services/volumes.py +++ b/src/dstack/_internal/server/services/volumes.py @@ -4,7 +4,7 @@ from sqlalchemy import and_, func, or_, select, update from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import selectinload +from sqlalchemy.orm import joinedload, selectinload from dstack._internal.core.backends import BACKENDS_WITH_VOLUMES_SUPPORT from dstack._internal.core.errors import ( @@ -102,7 +102,11 @@ async def list_projects_volume_models( if ascending: order_by = (VolumeModel.created_at.asc(), VolumeModel.id.desc()) res = await session.execute( - select(VolumeModel).where(*filters).order_by(*order_by).limit(limit) + select(VolumeModel) + .where(*filters) + .order_by(*order_by) + .limit(limit) + .options(joinedload(VolumeModel.user)) ) volume_models = list(res.scalars().all()) return volume_models @@ -130,7 +134,9 @@ async def list_project_volume_models( filters.append(VolumeModel.name.in_(names)) if not include_deleted: filters.append(VolumeModel.deleted == False) - res = await session.execute(select(VolumeModel).where(*filters)) + res = await session.execute( + select(VolumeModel).where(*filters).options(joinedload(VolumeModel.user)) + ) return list(res.scalars().all()) @@ -157,13 +163,16 @@ async def get_project_volume_model_by_name( ] if not include_deleted: filters.append(VolumeModel.deleted == False) - res = await session.execute(select(VolumeModel).where(*filters)) + res = await session.execute( + select(VolumeModel).where(*filters).options(joinedload(VolumeModel.user)) + ) return res.scalar_one_or_none() async def create_volume( session: AsyncSession, project: ProjectModel, + user: UserModel, configuration: VolumeConfiguration, ) -> Volume: _validate_volume_configuration(configuration) @@ -193,6 +202,7 @@ async def create_volume( volume_model = VolumeModel( id=uuid.uuid4(), name=configuration.name, + user_id=user.id, project=project, status=VolumeStatus.SUBMITTED, configuration=configuration.json(), @@ -224,6 +234,7 @@ async def delete_volumes(session: AsyncSession, project: ProjectModel, names: Li VolumeModel.name.in_(names), VolumeModel.deleted == False, ) + .options(selectinload(VolumeModel.user)) .options(selectinload(VolumeModel.instances)) .execution_options(populate_existing=True) .with_for_update() @@ -263,6 +274,7 @@ def volume_model_to_volume(volume_model: VolumeModel) -> Volume: return Volume( name=volume_model.name, project_name=volume_model.project.name, + user=volume_model.user.name, configuration=configuration, external=configuration.volume_id is not None, created_at=volume_model.created_at.replace(tzinfo=timezone.utc), diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 1a42df0949..6c6d9deb20 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -520,6 +520,7 @@ async def create_instance( async def create_volume( session: AsyncSession, project: ProjectModel, + user: UserModel, status: VolumeStatus = VolumeStatus.SUBMITTED, created_at: datetime = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), configuration: Optional[VolumeConfiguration] = None, @@ -532,6 +533,7 @@ async def create_volume( configuration = get_volume_configuration(backend=backend, region=region) vm = VolumeModel( project=project, + user_id=user.id, name=configuration.name, status=status, created_at=created_at, diff --git a/src/tests/_internal/server/background/tasks/test_process_running_jobs.py b/src/tests/_internal/server/background/tasks/test_process_running_jobs.py index c8c02e0c84..149ff6f98d 100644 --- a/src/tests/_internal/server/background/tasks/test_process_running_jobs.py +++ b/src/tests/_internal/server/background/tasks/test_process_running_jobs.py @@ -222,6 +222,7 @@ async def test_provisioning_shim_with_volumes( volume = await create_volume( session=session, project=project, + user=user, status=VolumeStatus.ACTIVE, configuration=get_volume_configuration( name="my-vol", backend=BackendType.AWS, region="us-east-1" diff --git a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py index 8ffb44f9f5..995c6dd37d 100644 --- a/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py +++ b/src/tests/_internal/server/background/tasks/test_process_submitted_jobs.py @@ -392,6 +392,7 @@ async def test_assigns_job_to_instance_with_volumes(self, test_db, session: Asyn volume = await create_volume( session=session, project=project, + user=user, status=VolumeStatus.ACTIVE, volume_provisioning_data=get_volume_provisioning_data(), backend=BackendType.AWS, diff --git a/src/tests/_internal/server/background/tasks/test_process_submitted_volumes.py b/src/tests/_internal/server/background/tasks/test_process_submitted_volumes.py index 970c674b4a..86fbed603c 100644 --- a/src/tests/_internal/server/background/tasks/test_process_submitted_volumes.py +++ b/src/tests/_internal/server/background/tasks/test_process_submitted_volumes.py @@ -8,6 +8,7 @@ from dstack._internal.server.background.tasks.process_volumes import process_submitted_volumes from dstack._internal.server.testing.common import ( create_project, + create_user, create_volume, ) @@ -17,8 +18,9 @@ class TestProcessSubmittedVolumes: @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) async def test_fails_job_when_no_backends(self, test_db, session: AsyncSession): project = await create_project(session=session) + user = await create_user(session=session) volume = await create_volume( - session=session, project=project, status=VolumeStatus.SUBMITTED + session=session, project=project, user=user, status=VolumeStatus.SUBMITTED ) await process_submitted_volumes() await session.refresh(volume) @@ -29,8 +31,9 @@ async def test_fails_job_when_no_backends(self, test_db, session: AsyncSession): @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) async def test_provisiones_volumes(self, test_db, session: AsyncSession): project = await create_project(session=session) + user = await create_user(session=session) volume = await create_volume( - session=session, project=project, status=VolumeStatus.SUBMITTED + session=session, project=project, user=user, status=VolumeStatus.SUBMITTED ) with patch( "dstack._internal.server.services.backends.get_project_backend_by_type_or_error" diff --git a/src/tests/_internal/server/routers/test_volumes.py b/src/tests/_internal/server/routers/test_volumes.py index a39c6e16b6..7b08fe2e6d 100644 --- a/src/tests/_internal/server/routers/test_volumes.py +++ b/src/tests/_internal/server/routers/test_volumes.py @@ -44,6 +44,7 @@ async def test_lists_volumes_across_projects( volume1 = await create_volume( session=session, project=project1, + user=user, created_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), configuration=get_volume_configuration(name="volume1"), ) @@ -51,6 +52,7 @@ async def test_lists_volumes_across_projects( volume2 = await create_volume( session=session, project=project2, + user=user, created_at=datetime(2023, 1, 2, 3, 5, tzinfo=timezone.utc), configuration=get_volume_configuration(name="volume2"), ) @@ -65,6 +67,7 @@ async def test_lists_volumes_across_projects( "id": str(volume2.id), "name": volume2.name, "project_name": project2.name, + "user": user.name, "configuration": json.loads(volume2.configuration), "external": False, "created_at": "2023-01-02T03:05:00+00:00", @@ -79,6 +82,7 @@ async def test_lists_volumes_across_projects( "id": str(volume1.id), "name": volume1.name, "project_name": project1.name, + "user": user.name, "configuration": json.loads(volume1.configuration), "external": False, "created_at": "2023-01-02T03:04:00+00:00", @@ -104,6 +108,7 @@ async def test_lists_volumes_across_projects( "id": str(volume1.id), "name": volume1.name, "project_name": project1.name, + "user": user.name, "configuration": json.loads(volume1.configuration), "external": False, "created_at": "2023-01-02T03:04:00+00:00", @@ -134,12 +139,14 @@ async def test_non_admin_cannot_see_others_projects( volume1 = await create_volume( session=session, project=project1, + user=user1, created_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), configuration=get_volume_configuration(name="volume1"), ) await create_volume( session=session, project=project2, + user=user2, created_at=datetime(2023, 1, 2, 3, 5, tzinfo=timezone.utc), configuration=get_volume_configuration(name="volume2"), ) @@ -154,6 +161,7 @@ async def test_non_admin_cannot_see_others_projects( "id": str(volume1.id), "name": volume1.name, "project_name": project1.name, + "user": user1.name, "configuration": json.loads(volume1.configuration), "external": False, "created_at": "2023-01-02T03:04:00+00:00", @@ -187,6 +195,7 @@ async def test_lists_volumes(self, test_db, session: AsyncSession, client: Async volume = await create_volume( session=session, project=project, + user=user, created_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), ) response = await client.post( @@ -199,6 +208,7 @@ async def test_lists_volumes(self, test_db, session: AsyncSession, client: Async "id": str(volume.id), "name": volume.name, "project_name": project.name, + "user": user.name, "configuration": json.loads(volume.configuration), "external": False, "created_at": "2023-01-02T03:04:00+00:00", @@ -232,6 +242,7 @@ async def test_returns_volume(self, test_db, session: AsyncSession, client: Asyn volume = await create_volume( session=session, project=project, + user=user, created_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), ) response = await client.post( @@ -244,6 +255,7 @@ async def test_returns_volume(self, test_db, session: AsyncSession, client: Asyn "id": str(volume.id), "name": volume.name, "project_name": project.name, + "user": user.name, "configuration": json.loads(volume.configuration), "external": False, "created_at": "2023-01-02T03:04:00+00:00", @@ -305,6 +317,7 @@ async def test_creates_volume(self, test_db, session: AsyncSession, client: Asyn "name": configuration.name, "project_name": project.name, "configuration": configuration, + "user": user.name, "external": False, "created_at": "2023-01-02T03:04:00+00:00", "status": "submitted", @@ -338,6 +351,7 @@ async def test_deletes_volumes(self, test_db, session: AsyncSession, client: Asy volume = await create_volume( session=session, project=project, + user=user, volume_provisioning_data=get_volume_provisioning_data(), ) with patch( @@ -369,6 +383,7 @@ async def test_returns_400_when_volumes_in_use( volume = await create_volume( session=session, project=project, + user=user, volume_provisioning_data=get_volume_provisioning_data(), ) instance = await create_instance(