diff --git a/management_api_app/db/repositories/workspace_templates.py b/management_api_app/db/repositories/workspace_templates.py index 68d2e23891..4906a476a8 100644 --- a/management_api_app/db/repositories/workspace_templates.py +++ b/management_api_app/db/repositories/workspace_templates.py @@ -2,10 +2,12 @@ from typing import List from azure.cosmos import CosmosClient +from pydantic import parse_obj_as from core import config from db.errors import EntityDoesNotExist from db.repositories.base import BaseRepository +from models.domain.resource import ResourceType from models.domain.resource_template import ResourceTemplate from models.schemas.workspace_template import WorkspaceTemplateInCreate @@ -20,26 +22,28 @@ def _workspace_template_by_name_query(name: str) -> str: def get_workspace_templates_by_name(self, name: str) -> List[ResourceTemplate]: query = self._workspace_template_by_name_query(name) - return self.query(query=query) + resource_templates = self.query(query=query) + print(resource_templates) + return parse_obj_as(List[ResourceTemplate], resource_templates) def get_current_workspace_template_by_name(self, name: str) -> ResourceTemplate: query = self._workspace_template_by_name_query(name) + ' AND c.current = true' workspace_templates = self.query(query=query) + print(workspace_templates) if len(workspace_templates) != 1: raise EntityDoesNotExist - return workspace_templates[0] + return parse_obj_as(ResourceTemplate, workspace_templates[0]) def get_workspace_template_by_name_and_version(self, name: str, version: str) -> ResourceTemplate: query = self._workspace_template_by_name_query(name) + f' AND c.version = "{version}"' workspace_templates = self.query(query=query) if len(workspace_templates) != 1: raise EntityDoesNotExist - return workspace_templates[0] + return parse_obj_as(ResourceTemplate, workspace_templates[0]) def get_workspace_template_names(self) -> List[str]: query = 'SELECT c.name FROM c' workspace_templates = self.query(query=query) - print(workspace_templates) workspace_template_names = [template["name"] for template in workspace_templates] return list(set(workspace_template_names)) @@ -51,7 +55,7 @@ def create_workspace_template_item(self, workspace_template_create: WorkspaceTem description=workspace_template_create.description, version=workspace_template_create.version, parameters=workspace_template_create.parameters, - resourceType=workspace_template_create.resourceType, + resourceType=ResourceType.Workspace, current=workspace_template_create.current ) self.create_item(resource_template) diff --git a/management_api_app/db/repositories/workspaces.py b/management_api_app/db/repositories/workspaces.py index f093101059..c2cc1cd99b 100644 --- a/management_api_app/db/repositories/workspaces.py +++ b/management_api_app/db/repositories/workspaces.py @@ -2,7 +2,7 @@ from typing import List from azure.cosmos import CosmosClient -from pydantic import UUID4 +from pydantic import parse_obj_as, UUID4 from core import config from resources import strings @@ -25,18 +25,19 @@ def _active_workspaces_query(): def _get_template_version(self, template_name): workspace_template_repo = WorkspaceTemplateRepository(self._client) template = workspace_template_repo.get_current_workspace_template_by_name(template_name) - return template["version"] + return template.version def get_all_active_workspaces(self) -> List[Workspace]: query = self._active_workspaces_query() - return self.query(query=query) + workspaces = self.query(query=query) + return parse_obj_as(List[Workspace], workspaces) def get_workspace_by_workspace_id(self, workspace_id: UUID4) -> Workspace: query = self._active_workspaces_query() + f' AND c.id="{workspace_id}"' workspaces = self.query(query=query) if not workspaces: raise EntityDoesNotExist - return workspaces[0] + return parse_obj_as(Workspace, workspaces[0]) def create_workspace_item(self, workspace_create: WorkspaceInCreate) -> Workspace: full_workspace_id = str(uuid.uuid4()) diff --git a/management_api_app/models/schemas/workspace_template.py b/management_api_app/models/schemas/workspace_template.py index 016473c648..5b0e51600b 100644 --- a/management_api_app/models/schemas/workspace_template.py +++ b/management_api_app/models/schemas/workspace_template.py @@ -38,12 +38,10 @@ class Config: class WorkspaceTemplateInCreate(BaseModel): - name: str = Field(title="Name of workspace template") version: str = Field(title="Version of workspace template") description: str = Field(title=" Description of workspace template") parameters: List[Parameter] = Field([], title="Workspace template parameters", description="Values for the parameters required by the workspace template") - resourceType: str = Field(title="Type of workspace template") current: bool = Field(title="Mark this version as current") class Config: @@ -56,7 +54,6 @@ class Config: "name": "azure_location", "type": "string" }], - "resourceType": "workspace", "current": "true" } } diff --git a/management_api_app/tests/test_db/test_repositories/test_workpaces_repository.py b/management_api_app/tests/test_db/test_repositories/test_workpaces_repository.py index b3a948b908..af71de3a73 100644 --- a/management_api_app/tests/test_db/test_repositories/test_workpaces_repository.py +++ b/management_api_app/tests/test_db/test_repositories/test_workpaces_repository.py @@ -25,7 +25,13 @@ def test_get_all_active_workspaces_calls_db_with_correct_query(cosmos_client_moc def test_get_workspace_by_id_calls_db_with_correct_query(cosmos_client_mock): workspace_repo = db.repositories.workspaces.WorkspaceRepository(cosmos_client_mock) workspace_id = uuid.uuid4() - workspace_repo.container.query_items = MagicMock(return_value=[{"id": str(workspace_id)}]) + return_workspace = { + "id": str(workspace_id), + "resourceTemplateName": "some-template-name", + "resourceTemplateVersion": "1.0", + "deployment": {"status": Status.NotDeployed, "message": ""} + } + workspace_repo.container.query_items = MagicMock(return_value=[return_workspace]) expected_query = f'SELECT * FROM c WHERE c.resourceType = "workspace" AND c.isDeleted = false AND c.id="{str(workspace_id)}"' workspace_repo.get_workspace_by_workspace_id(workspace_id) diff --git a/management_api_app/tests/test_db/test_repositories/test_workspace_templates_repository.py b/management_api_app/tests/test_db/test_repositories/test_workspace_templates_repository.py index f7c2fad3e2..ff500f8b03 100644 --- a/management_api_app/tests/test_db/test_repositories/test_workspace_templates_repository.py +++ b/management_api_app/tests/test_db/test_repositories/test_workspace_templates_repository.py @@ -8,7 +8,7 @@ from models.schemas.workspace_template import WorkspaceTemplateInCreate -def get_sample_workspace_template(name: str, version: str = "1.0") -> ResourceTemplate: +def get_sample_workspace_template(name: str, version: str = "1.0") -> dict: return ResourceTemplate( id="a7a7a7bd-7f4e-4a4e-b970-dc86a6b31dfb", name=name, @@ -17,7 +17,7 @@ def get_sample_workspace_template(name: str, version: str = "1.0") -> ResourceTe resourceType=ResourceType.Workspace, parameters=[], current=False - ) + ).dict() @patch('db.repositories.workspace_templates.WorkspaceTemplateRepository.query') @@ -25,6 +25,7 @@ def get_sample_workspace_template(name: str, version: str = "1.0") -> ResourceTe def test_get_by_name_queries_db(cosmos_client_mock, wt_query_mock): template_repo = WorkspaceTemplateRepository(cosmos_client_mock) expected_query = 'SELECT * FROM c WHERE c.resourceType = "workspace" AND c.name = "test"' + wt_query_mock.return_value = [get_sample_workspace_template(name="test")] template_repo.get_workspace_templates_by_name(name="test") @@ -90,7 +91,7 @@ def test_get_by_name_and_version_raises_entity_does_not_exist_if_no_template_fou def test_get_current_by_name_queries_db(cosmos_client_mock, wt_query_mock): template_repo = WorkspaceTemplateRepository(cosmos_client_mock) expected_query = 'SELECT * FROM c WHERE c.resourceType = "workspace" AND c.name = "test" AND c.current = true' - wt_query_mock.return_value = [get_sample_workspace_template(name="test", version="1.0")] + wt_query_mock.return_value = [get_sample_workspace_template(name="test")] template_repo.get_current_workspace_template_by_name(name="test") @@ -102,8 +103,7 @@ def test_get_current_by_name_queries_db(cosmos_client_mock, wt_query_mock): def test_get_current_by_name_returns_matching_template(cosmos_client_mock, wt_query_mock): template_repo = WorkspaceTemplateRepository(cosmos_client_mock) template_name = "test" - workspace_templates_in_db = [get_sample_workspace_template(name=template_name)] - wt_query_mock.return_value = workspace_templates_in_db + wt_query_mock.return_value = [get_sample_workspace_template(name=template_name)] template = template_repo.get_current_workspace_template_by_name(name=template_name)