From 9e726aab135a137f08d63a5243a2a0d1adbe3a60 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Wed, 7 Aug 2024 10:36:52 +0500 Subject: [PATCH] Implement API endpoint for listing volumes across projects --- src/dstack/_internal/core/models/volumes.py | 2 +- src/dstack/_internal/server/app.py | 3 +- .../_internal/server/routers/volumes.py | 32 ++++- .../_internal/server/schemas/volumes.py | 15 +- .../_internal/server/services/volumes.py | 82 ++++++++++- .../_internal/server/routers/test_volumes.py | 135 +++++++++++++++++- 6 files changed, 253 insertions(+), 16 deletions(-) diff --git a/src/dstack/_internal/core/models/volumes.py b/src/dstack/_internal/core/models/volumes.py index e0f0c4aed..0215d5ca8 100644 --- a/src/dstack/_internal/core/models/volumes.py +++ b/src/dstack/_internal/core/models/volumes.py @@ -56,6 +56,7 @@ class VolumeAttachmentData(CoreModel): class Volume(CoreModel): + id: uuid.UUID name: str project_name: str configuration: VolumeConfiguration @@ -66,7 +67,6 @@ class Volume(CoreModel): volume_id: Optional[str] = None # id of the volume in the cloud provisioning_data: Optional[VolumeProvisioningData] = None attachment_data: Optional[VolumeAttachmentData] = None - volume_model_id: uuid.UUID # uuid of VolumeModel class VolumeMountPoint(CoreModel): diff --git a/src/dstack/_internal/server/app.py b/src/dstack/_internal/server/app.py index e507bf379..60e3af07a 100644 --- a/src/dstack/_internal/server/app.py +++ b/src/dstack/_internal/server/app.py @@ -162,7 +162,8 @@ def register_routes(app: FastAPI): app.include_router(logs.router) app.include_router(secrets.router) app.include_router(gateways.router) - app.include_router(volumes.router) + app.include_router(volumes.root_router) + app.include_router(volumes.project_router) @app.exception_handler(ForbiddenError) async def forbidden_error_handler(request: Request, exc: ForbiddenError): diff --git a/src/dstack/_internal/server/routers/volumes.py b/src/dstack/_internal/server/routers/volumes.py index 28aea74fc..9004f7c25 100644 --- a/src/dstack/_internal/server/routers/volumes.py +++ b/src/dstack/_internal/server/routers/volumes.py @@ -12,14 +12,34 @@ CreateVolumeRequest, DeleteVolumesRequest, GetVolumeRequest, + ListVolumesRequest, ) -from dstack._internal.server.security.permissions import ProjectMember +from dstack._internal.server.security.permissions import Authenticated, ProjectMember -router = APIRouter(prefix="/api/project/{project_name}/volumes", tags=["volumes"]) +root_router = APIRouter(prefix="/api/volumes", tags=["volumes"]) +project_router = APIRouter(prefix="/api/project/{project_name}/volumes", tags=["volumes"]) -@router.post("/list") +@root_router.post("/list") async def list_volumes( + body: ListVolumesRequest, + session: AsyncSession = Depends(get_session), + user: UserModel = Depends(Authenticated()), +) -> List[Volume]: + return await volumes_services.list_volumes( + session=session, + user=user, + project_name=body.project_name, + only_active=body.only_active, + prev_created_at=body.prev_created_at, + prev_id=body.prev_id, + limit=body.limit, + ascending=body.ascending, + ) + + +@project_router.post("/list") +async def list_project_volumes( session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMember()), ) -> List[Volume]: @@ -27,7 +47,7 @@ async def list_volumes( return await volumes_services.list_project_volumes(session=session, project=project) -@router.post("/get") +@project_router.post("/get") async def get_volume( body: GetVolumeRequest, session: AsyncSession = Depends(get_session), @@ -42,7 +62,7 @@ async def get_volume( return volume -@router.post("/create") +@project_router.post("/create") async def create_volume( body: CreateVolumeRequest, session: AsyncSession = Depends(get_session), @@ -56,7 +76,7 @@ async def create_volume( ) -@router.post("/delete") +@project_router.post("/delete") async def delete_volumes( body: DeleteVolumesRequest, session: AsyncSession = Depends(get_session), diff --git a/src/dstack/_internal/server/schemas/volumes.py b/src/dstack/_internal/server/schemas/volumes.py index 1ca82467c..1a63c49b9 100644 --- a/src/dstack/_internal/server/schemas/volumes.py +++ b/src/dstack/_internal/server/schemas/volumes.py @@ -1,9 +1,22 @@ -from typing import List +from datetime import datetime +from typing import List, Optional +from uuid import UUID + +from pydantic import Field from dstack._internal.core.models.common import CoreModel from dstack._internal.core.models.volumes import VolumeConfiguration +class ListVolumesRequest(CoreModel): + project_name: Optional[str] + only_active: bool = False + prev_created_at: Optional[datetime] + prev_id: Optional[UUID] + limit: int = Field(100, ge=0, le=100) + ascending: bool = False + + class GetVolumeRequest(CoreModel): name: str diff --git a/src/dstack/_internal/server/services/volumes.py b/src/dstack/_internal/server/services/volumes.py index 537340655..f906ace05 100644 --- a/src/dstack/_internal/server/services/volumes.py +++ b/src/dstack/_internal/server/services/volumes.py @@ -1,9 +1,9 @@ import asyncio import uuid -from datetime import timezone +from datetime import datetime, timezone from typing import List, Optional -from sqlalchemy import select, update +from sqlalchemy import and_, or_, select, update from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload @@ -13,6 +13,7 @@ ResourceExistsError, ServerClientError, ) +from dstack._internal.core.models.users import GlobalRole from dstack._internal.core.models.volumes import ( Volume, VolumeAttachmentData, @@ -21,8 +22,9 @@ VolumeStatus, ) from dstack._internal.core.services import validate_dstack_resource_name -from dstack._internal.server.models import ProjectModel, VolumeModel +from dstack._internal.server.models import ProjectModel, UserModel, VolumeModel from dstack._internal.server.services import backends as backends_services +from dstack._internal.server.services.projects import list_project_models, list_user_project_models from dstack._internal.server.utils.common import run_async, wait_to_lock_many from dstack._internal.utils import common, random_names from dstack._internal.utils.logging import get_logger @@ -34,6 +36,78 @@ PROCESSING_VOLUMES_IDS = set() +async def list_volumes( + session: AsyncSession, + user: UserModel, + project_name: Optional[str], + only_active: bool, + prev_created_at: Optional[datetime], + prev_id: Optional[uuid.UUID], + limit: int, + ascending: bool, +) -> List[Volume]: + if user.global_role == GlobalRole.ADMIN: + projects = await list_project_models(session=session) + else: + projects = await list_user_project_models(session=session, user=user) + if project_name is not None: + projects = [p for p in projects if p.name == project_name] + volume_models = await list_projects_volume_models( + session=session, + projects=projects, + only_active=only_active, + prev_created_at=prev_created_at, + prev_id=prev_id, + limit=limit, + ascending=ascending, + ) + return [volume_model_to_volume(v) for v in volume_models] + + +async def list_projects_volume_models( + session: AsyncSession, + projects: List[ProjectModel], + only_active: bool, + prev_created_at: Optional[datetime], + prev_id: Optional[uuid.UUID], + limit: int, + ascending: bool, +) -> List[VolumeModel]: + filters = [] + filters.append(VolumeModel.project_id.in_(p.id for p in projects)) + if only_active: + filters.append(VolumeModel.deleted == False) + if prev_created_at is not None: + if ascending: + if prev_id is None: + filters.append(VolumeModel.created_at > prev_created_at) + else: + filters.append( + or_( + VolumeModel.created_at > prev_created_at, + and_(VolumeModel.created_at == prev_created_at, VolumeModel.id < prev_id), + ) + ) + else: + if prev_id is None: + filters.append(VolumeModel.created_at < prev_created_at) + else: + filters.append( + or_( + VolumeModel.created_at < prev_created_at, + and_(VolumeModel.created_at == prev_created_at, VolumeModel.id > prev_id), + ) + ) + order_by = (VolumeModel.created_at.desc(), VolumeModel.id) + if ascending: + order_by = (VolumeModel.created_at.asc(), VolumeModel.id.desc()) + res = await session.execute( + select(VolumeModel).where(*filters).order_by(*order_by).limit(limit) + ) + volume_models = list(res.scalars().all()) + return volume_models + + async def list_project_volumes( session: AsyncSession, project: ProjectModel, @@ -187,7 +261,7 @@ def volume_model_to_volume(volume_model: VolumeModel) -> Volume: volume_id=vpd.volume_id if vpd is not None else None, provisioning_data=vpd, attachment_data=vad, - volume_model_id=volume_model.id, + id=volume_model.id, ) diff --git a/src/tests/_internal/server/routers/test_volumes.py b/src/tests/_internal/server/routers/test_volumes.py index 69e4c1b80..a16ad32c3 100644 --- a/src/tests/_internal/server/routers/test_volumes.py +++ b/src/tests/_internal/server/routers/test_volumes.py @@ -29,6 +29,135 @@ class TestListVolumes: + @pytest.mark.asyncio + async def test_returns_40x_if_not_authenticated(self, test_db, session: AsyncSession): + response = client.post("/api/volumes/list") + assert response.status_code == 403 + + @pytest.mark.asyncio + async def test_lists_volumes_across_projects(self, test_db, session: AsyncSession): + user = await create_user(session, global_role=GlobalRole.ADMIN) + project1 = await create_project(session, name="project1", owner=user) + volume1 = await create_volume( + session=session, + project=project1, + created_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), + configuration=get_volume_configuration(name="volume1"), + ) + project2 = await create_project(session, name="project2", owner=user) + volume2 = await create_volume( + session=session, + project=project2, + created_at=datetime(2023, 1, 2, 3, 5, tzinfo=timezone.utc), + configuration=get_volume_configuration(name="volume2"), + ) + response = client.post( + "/api/volumes/list", + headers=get_auth_headers(user.token), + json={}, + ) + assert response.status_code == 200, response.json() + assert response.json() == [ + { + "id": str(volume2.id), + "name": volume2.name, + "project_name": project2.name, + "configuration": json.loads(volume2.configuration), + "external": False, + "created_at": "2023-01-02T03:05:00+00:00", + "status": "submitted", + "status_message": None, + "volume_id": None, + "provisioning_data": None, + "attachment_data": None, + }, + { + "id": str(volume1.id), + "name": volume1.name, + "project_name": project1.name, + "configuration": json.loads(volume1.configuration), + "external": False, + "created_at": "2023-01-02T03:04:00+00:00", + "status": "submitted", + "status_message": None, + "volume_id": None, + "provisioning_data": None, + "attachment_data": None, + }, + ] + response = client.post( + "/api/volumes/list", + headers=get_auth_headers(user.token), + json={ + "prev_created_at": "2023-01-02T03:05:00+00:00", + "prev_id": str(volume2.id), + }, + ) + assert response.status_code == 200 + assert response.json() == [ + { + "id": str(volume1.id), + "name": volume1.name, + "project_name": project1.name, + "configuration": json.loads(volume1.configuration), + "external": False, + "created_at": "2023-01-02T03:04:00+00:00", + "status": "submitted", + "status_message": None, + "volume_id": None, + "provisioning_data": None, + "attachment_data": None, + }, + ] + + @pytest.mark.asyncio + async def test_non_admin_cannot_see_others_projects(self, test_db, session: AsyncSession): + user1 = await create_user(session, name="user1", global_role=GlobalRole.USER) + user2 = await create_user(session, name="user2", global_role=GlobalRole.USER) + project1 = await create_project(session, name="project1", owner=user1) + project2 = await create_project(session, name="project2", owner=user2) + await add_project_member( + session=session, project=project1, user=user1, project_role=ProjectRole.USER + ) + await add_project_member( + session=session, project=project2, user=user2, project_role=ProjectRole.USER + ) + volume1 = await create_volume( + session=session, + project=project1, + created_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), + configuration=get_volume_configuration(name="volume1"), + ) + await create_volume( + session=session, + project=project2, + created_at=datetime(2023, 1, 2, 3, 5, tzinfo=timezone.utc), + configuration=get_volume_configuration(name="volume2"), + ) + response = client.post( + "/api/volumes/list", + headers=get_auth_headers(user1.token), + json={}, + ) + assert response.status_code == 200, response.json() + assert response.json() == [ + { + "id": str(volume1.id), + "name": volume1.name, + "project_name": project1.name, + "configuration": json.loads(volume1.configuration), + "external": False, + "created_at": "2023-01-02T03:04:00+00:00", + "status": "submitted", + "status_message": None, + "volume_id": None, + "provisioning_data": None, + "attachment_data": None, + }, + ] + + +class TestListProjectVolumes: @pytest.mark.asyncio async def test_returns_40x_if_not_authenticated(self, test_db, session: AsyncSession): response = client.post("/api/project/main/volumes/list") @@ -53,6 +182,7 @@ async def test_lists_volumes(self, test_db, session: AsyncSession): assert response.status_code == 200 assert response.json() == [ { + "id": str(volume.id), "name": volume.name, "project_name": project.name, "configuration": json.loads(volume.configuration), @@ -63,7 +193,6 @@ async def test_lists_volumes(self, test_db, session: AsyncSession): "volume_id": None, "provisioning_data": None, "attachment_data": None, - "volume_model_id": str(volume.id), } ] @@ -93,6 +222,7 @@ async def test_returns_volume(self, test_db, session: AsyncSession): ) assert response.status_code == 200 assert response.json() == { + "id": str(volume.id), "name": volume.name, "project_name": project.name, "configuration": json.loads(volume.configuration), @@ -103,7 +233,6 @@ async def test_returns_volume(self, test_db, session: AsyncSession): "volume_id": None, "provisioning_data": None, "attachment_data": None, - "volume_model_id": str(volume.id), } @pytest.mark.asyncio @@ -145,6 +274,7 @@ async def test_creates_volume(self, test_db, session: AsyncSession): ) assert response.status_code == 200 assert response.json() == { + "id": "1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e", "name": configuration.name, "project_name": project.name, "configuration": configuration, @@ -155,7 +285,6 @@ async def test_creates_volume(self, test_db, session: AsyncSession): "volume_id": None, "provisioning_data": None, "attachment_data": None, - "volume_model_id": "1b0e1b45-2f8c-4ab6-8010-a0d1a3e44e0e", } res = await session.execute(select(VolumeModel)) assert res.scalar_one()