diff --git a/management_api_app/api/dependencies/authentication.py b/management_api_app/api/dependencies/authentication.py index b6637b6103..99dc898ace 100644 --- a/management_api_app/api/dependencies/authentication.py +++ b/management_api_app/api/dependencies/authentication.py @@ -1,8 +1,25 @@ from fastapi import Depends, HTTPException, status +from models.schemas.workspace import AuthenticationConfiguration, AuthProvider from resources import strings from services.aad_authentication import authorize +from services.aad_access_service import AADAccessService from services.authentication import User +from services.access_service import AccessService, AuthConfigValidationError + + +def extract_auth_information(auth_config: AuthenticationConfiguration) -> dict: + access_service = get_access_service(auth_config.provider) + try: + return access_service.extract_workspace_auth_information(auth_config.data) + except AuthConfigValidationError as e: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) + + +def get_access_service(provider: str) -> AccessService: + if provider == AuthProvider.AAD: + return AADAccessService() + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=strings.INVALID_AUTH_PROVIDER) async def get_current_user(user: User = Depends(authorize)) -> User: diff --git a/management_api_app/api/errors/generic_error.py b/management_api_app/api/errors/generic_error.py index 42fb0e070a..1b98f8fbfa 100644 --- a/management_api_app/api/errors/generic_error.py +++ b/management_api_app/api/errors/generic_error.py @@ -1,4 +1,5 @@ import logging +import traceback from starlette.requests import Request from starlette.responses import PlainTextResponse @@ -11,5 +12,5 @@ async def generic_error_handler(_: Request, exception: Exception) -> PlainTextRe logging.debug("=====================================") logging.exception(exception) logging.debug("=====================================") - error_string = exception if config.DEBUG else strings.UNABLE_TO_PROCESS_REQUEST + error_string = traceback.format_exc() if config.DEBUG else strings.UNABLE_TO_PROCESS_REQUEST return PlainTextResponse(error_string, status_code=500) diff --git a/management_api_app/api/routes/workspaces.py b/management_api_app/api/routes/workspaces.py index 35e42d3892..501e07c9fd 100644 --- a/management_api_app/api/routes/workspaces.py +++ b/management_api_app/api/routes/workspaces.py @@ -3,6 +3,7 @@ from fastapi import APIRouter, Depends, HTTPException from starlette import status +from api.dependencies.authentication import extract_auth_information from api.dependencies.database import get_repository from api.dependencies.workspaces import get_workspace_by_workspace_id_from_path from api.dependencies.authentication import get_current_user, get_current_admin_user @@ -25,14 +26,12 @@ async def retrieve_active_workspaces( return WorkspacesInList(workspaces=workspaces) -@router.post("/workspaces", status_code=status.HTTP_202_ACCEPTED, response_model=WorkspaceIdInResponse, name=strings.API_CREATE_WORKSPACE, - dependencies=[Depends(get_current_admin_user)]) -async def create_workspace( - workspace_create: WorkspaceInCreate, - workspace_repo: WorkspaceRepository = Depends(get_repository(WorkspaceRepository)), -) -> WorkspaceIdInResponse: +@router.post("/workspaces", status_code=status.HTTP_202_ACCEPTED, response_model=WorkspaceIdInResponse, name=strings.API_CREATE_WORKSPACE, dependencies=[Depends(get_current_admin_user)]) +async def create_workspace(workspace_create: WorkspaceInCreate, workspace_repo: WorkspaceRepository = Depends(get_repository(WorkspaceRepository))) -> WorkspaceIdInResponse: + auth_information = extract_auth_information(workspace_create.authConfig) + try: - workspace = workspace_repo.create_workspace_item(workspace_create) + workspace = workspace_repo.create_workspace_item(workspace_create, auth_information) except WorkspaceValidationError as e: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=e.errors) except ValueError as e: diff --git a/management_api_app/core/config.py b/management_api_app/core/config.py index f7e4e63f65..054f256c25 100644 --- a/management_api_app/core/config.py +++ b/management_api_app/core/config.py @@ -33,6 +33,7 @@ # Authentication API_CLIENT_ID: str = config("API_CLIENT_ID", default="") +API_CLIENT_SECRET: str = config("API_CLIENT_SECRET", default="") SWAGGER_UI_CLIENT_ID: str = config("SWAGGER_UI_CLIENT_ID", default="") AAD_TENANT_ID: str = config("AAD_TENANT_ID", default="") diff --git a/management_api_app/db/repositories/workspace_templates.py b/management_api_app/db/repositories/workspace_templates.py index eb24e9002f..cdb69ea03a 100644 --- a/management_api_app/db/repositories/workspace_templates.py +++ b/management_api_app/db/repositories/workspace_templates.py @@ -47,7 +47,7 @@ def get_workspace_template_names(self) -> List[str]: workspace_template_names = [template["name"] for template in workspace_templates] return list(set(workspace_template_names)) - def create_workspace_template_item(self, workspace_template_create: WorkspaceTemplateInCreate): + def create_workspace_template_item(self, workspace_template_create: WorkspaceTemplateInCreate) -> ResourceTemplate: item_id = str(uuid.uuid4()) resource_template = ResourceTemplate( id=item_id, @@ -56,7 +56,7 @@ def create_workspace_template_item(self, workspace_template_create: WorkspaceTem version=workspace_template_create.version, parameters=workspace_template_create.parameters, resourceType=ResourceType.Workspace, - current=workspace_template_create.current + current=workspace_template_create.current, ) self.create_item(resource_template) return resource_template diff --git a/management_api_app/db/repositories/workspaces.py b/management_api_app/db/repositories/workspaces.py index df1125ba15..2e5de1a11d 100644 --- a/management_api_app/db/repositories/workspaces.py +++ b/management_api_app/db/repositories/workspaces.py @@ -103,7 +103,7 @@ def get_workspace_by_workspace_id(self, workspace_id: UUID4) -> Workspace: raise EntityDoesNotExist return parse_obj_as(Workspace, workspaces[0]) - def create_workspace_item(self, workspace_create: WorkspaceInCreate) -> Workspace: + def create_workspace_item(self, workspace_create: WorkspaceInCreate, auth_info: dict) -> Workspace: full_workspace_id = str(uuid.uuid4()) try: @@ -131,7 +131,8 @@ def create_workspace_item(self, workspace_create: WorkspaceInCreate) -> Workspac resourceTemplateName=workspace_create.workspaceType, resourceTemplateVersion=template_version, resourceTemplateParameters=resource_spec_parameters, - deployment=Deployment(status=Status.NotDeployed, message=strings.RESOURCE_STATUS_NOT_DEPLOYED_MESSAGE) + deployment=Deployment(status=Status.NotDeployed, message=strings.RESOURCE_STATUS_NOT_DEPLOYED_MESSAGE), + authInformation=auth_info ) self._validate_workspace_parameters(current_template.parameters, workspace.resourceTemplateParameters) diff --git a/management_api_app/models/domain/workspace.py b/management_api_app/models/domain/workspace.py index 7b3592ee8c..7fd58a7bde 100644 --- a/management_api_app/models/domain/workspace.py +++ b/management_api_app/models/domain/workspace.py @@ -17,3 +17,4 @@ class Workspace(Resource): """ workspaceURL: str = Field("", title="Workspace URL", description="Main endpoint for workspace users") resourceType = ResourceType.Workspace + authInformation: dict = Field({}) diff --git a/management_api_app/models/schemas/workspace.py b/management_api_app/models/schemas/workspace.py index 9509e854d7..e4fac1f86b 100644 --- a/management_api_app/models/schemas/workspace.py +++ b/management_api_app/models/schemas/workspace.py @@ -1,3 +1,4 @@ +from enum import Enum from typing import List from pydantic import BaseModel, Field @@ -23,10 +24,23 @@ def get_sample_workspace(workspace_id: str, spec_workspace_id: str = "0001") -> }, "isDeleted": False, "resourceType": "workspace", - "workspaceURL": "" + "workspaceURL": "", + "authInformation": {} } +class AuthProvider(str, Enum): + """ + Auth Provider + """ + AAD = "AAD" + + +class AuthenticationConfiguration(BaseModel): + provider: AuthProvider = Field(AuthProvider.AAD, title="Authentication Provider") + data: dict = Field({}, title="Authentication information") + + class WorkspaceInResponse(BaseModel): workspace: Workspace @@ -57,6 +71,7 @@ class WorkspaceInCreate(BaseModel): workspaceType: str = Field(title="Workspace type", description="Bundle name") description: str = Field(title="Workspace description") parameters: dict = Field({}, title="Workspace parameters", description="Values for the parameters required by the workspace resource specification") + authConfig: AuthenticationConfiguration = Field(title="Authentication configuration", description="Authentication configuration for the workspace") class Config: schema_extra = { @@ -64,7 +79,11 @@ class Config: "displayName": "My workspace", "description": "workspace for team X", "workspaceType": "tre-workspace-vanilla", - "parameters": {} + "parameters": {}, + "authConfig": { + "provider": "AAD", + "data": {"app_id": "1212445c-aae6-41ec-a539-30dfa90ab1ae"} + } } } diff --git a/management_api_app/requirements.txt b/management_api_app/requirements.txt index ca385e5402..ab61e166b9 100644 --- a/management_api_app/requirements.txt +++ b/management_api_app/requirements.txt @@ -10,3 +10,4 @@ azure-identity==1.6.0 azure-servicebus==7.2.0 opencensus-ext-azure==1.0.8 opencensus-ext-logging==0.1.0 +msal==1.12.0 diff --git a/management_api_app/resources/strings.py b/management_api_app/resources/strings.py index 99df73d4a7..a54d42f4fe 100644 --- a/management_api_app/resources/strings.py +++ b/management_api_app/resources/strings.py @@ -21,14 +21,19 @@ UNSPECIFIED_ERROR = "Unspecified error" # Error strings +ACCESS_APP_IS_MISSING_ROLE = "The App is missing role" +ACCESS_PLEASE_SUPPLY_APP_ID = "Please supply the app_id for the AAD application" +ACCESS_UNABLE_TO_GET_INFO_FOR_APP = "Unable to get app info for app:" AUTH_NOT_ASSIGNED_TO_ADMIN_ROLE = "Not assigned to admin role" AUTH_COULD_NOT_VALIDATE_CREDENTIALS = "Could not validate credentials" +INVALID_AUTH_PROVIDER = "Invalid authentication provider" UNABLE_TO_REPLACE_CURRENT_TEMPLATE = "Unable to replace the existing 'current' template with this name" UNABLE_TO_PROCESS_REQUEST = "Unable to process request" WORKSPACE_DOES_NOT_EXIST = "Workspace does not exist" WORKSPACE_TEMPLATE_DOES_NOT_EXIST = "Could not retrieve the 'current' template with this name" WORKSPACE_TEMPLATE_VERSION_EXISTS = "A template with this version already exists" + # Resource Status RESOURCE_STATUS_NOT_DEPLOYED = "not_deployed" RESOURCE_STATUS_DEPLOYING = "deploying" diff --git a/management_api_app/services/aad_access_service.py b/management_api_app/services/aad_access_service.py new file mode 100644 index 0000000000..0e68787326 --- /dev/null +++ b/management_api_app/services/aad_access_service.py @@ -0,0 +1,67 @@ +import logging + +import requests +from msal import ConfidentialClientApplication + +from core import config +from resources import strings +from services.access_service import AccessService, AuthConfigValidationError + + +class AADAccessService(AccessService): + @staticmethod + def _get_msgraph_token() -> str: + scopes = ["https://graph.microsoft.com/.default"] + app = ConfidentialClientApplication(client_id=config.API_CLIENT_ID, client_credential=config.API_CLIENT_SECRET, authority=f"{config.AAD_INSTANCE}/{config.AAD_TENANT_ID}") + result = app.acquire_token_silent(scopes=scopes, account=None) + if not result: + logging.info('No suitable token exists in cache, getting a new one from AAD') + result = app.acquire_token_for_client(scopes=scopes) + if "access_token" not in result: + logging.debug(result.get('error')) + logging.debug(result.get('error_description')) + logging.debug(result.get('correlation_id')) + raise Exception(result.get('error')) + return result["access_token"] + + @staticmethod + def _get_auth_header(msgraph_token: str) -> dict: + return {'Authorization': 'Bearer ' + msgraph_token} + + @staticmethod + def _get_service_principal_endpoint(app_id) -> str: + return f"https://graph.microsoft.com/v1.0/serviceprincipals?$filter=appid eq '{app_id}'" + + def _get_app_sp_graph_data(self, app_id: str) -> dict: + msgraph_token = self._get_msgraph_token() + sp_endpoint = self._get_service_principal_endpoint(app_id) + graph_data = requests.get(sp_endpoint, headers=self._get_auth_header(msgraph_token)).json() + return graph_data + + def _get_app_auth_info(self, app_id: str) -> dict: + graph_data = self._get_app_sp_graph_data(app_id) + if 'value' not in graph_data or len(graph_data['value']) == 0: + logging.debug(graph_data) + raise AuthConfigValidationError(f"{strings.ACCESS_UNABLE_TO_GET_INFO_FOR_APP} {app_id}") + + app_info = graph_data['value'][0] + sp_id = app_info['id'] + roles = app_info['appRoles'] + + return { + 'sp_id': sp_id, + 'roles': {role['value']: role['id'] for role in roles} + } + + def extract_workspace_auth_information(self, data: dict) -> dict: + if "app_id" not in data: + raise AuthConfigValidationError(strings.ACCESS_PLEASE_SUPPLY_APP_ID) + + auth_info = self._get_app_auth_info(data["app_id"]) + print(auth_info) + + for role in ['WorkspaceOwner', 'WorkspaceResearcher']: + if role not in auth_info['roles']: + raise AuthConfigValidationError(f"{strings.ACCESS_APP_IS_MISSING_ROLE} {role}") + + return auth_info diff --git a/management_api_app/services/access_service.py b/management_api_app/services/access_service.py new file mode 100644 index 0000000000..bc6f7465b5 --- /dev/null +++ b/management_api_app/services/access_service.py @@ -0,0 +1,11 @@ +from abc import ABC, abstractmethod + + +class AuthConfigValidationError(Exception): + """Raised when the input auth information is invalid""" + + +class AccessService(ABC): + @abstractmethod + def extract_workspace_auth_information(self, data: dict) -> dict: + pass diff --git a/management_api_app/tests/test_api/test_routes/test_workspaces.py b/management_api_app/tests/test_api/test_routes/test_workspaces.py index 84c78b0971..ae78853bd1 100644 --- a/management_api_app/tests/test_api/test_routes/test_workspaces.py +++ b/management_api_app/tests/test_api/test_routes/test_workspaces.py @@ -1,5 +1,5 @@ import pytest -from mock import AsyncMock, patch +from mock import patch from fastapi import FastAPI from httpx import AsyncClient @@ -22,7 +22,8 @@ def create_sample_workspace_object(workspace_id): resourceTemplateName="tre-workspace-vanilla", resourceTemplateVersion="0.1.0", resourceTemplateParameters={}, - deployment=Deployment(status=Status.NotDeployed, message="") + deployment=Deployment(status=Status.NotDeployed, message=""), + authInformation={} ) @@ -31,7 +32,13 @@ def create_sample_workspace_input_data(): "displayName": "My workspace", "description": "workspace for team X", "workspaceType": "tre-workspace-vanilla", - "parameters": {} + "parameters": {}, + "authConfig": { + "provider": "AAD", + "data": { + "app_id": "1212445c-aae6-41ec-a539-30dfa90ab1ae" + } + } } @@ -84,12 +91,11 @@ async def test_workspaces_id_get_returns_workspace_if_found(get_workspace_mock, assert actual_resource["id"] == sample_workspace["id"] -@patch('service_bus.resource_request_sender.ServiceBusClient') @patch("api.routes.workspaces.send_resource_request_message") @patch("api.routes.workspaces.WorkspaceRepository.save_workspace") @patch("api.routes.workspaces.WorkspaceRepository.create_workspace_item") -async def test_workspaces_post_creates_workspace(create_workspace_item_mock, save_workspace_mock, send_resource_request_message_mock, service_bus_client_mock, app: FastAPI, client: AsyncClient): - service_bus_client_mock().get_queue_sender().send_messages = AsyncMock() +@patch("api.routes.workspaces.extract_auth_information", return_value={}) +async def test_workspaces_post_creates_workspace(extract_auth_info_mock, create_workspace_item_mock, save_workspace_mock, send_resource_request_message_mock, app: FastAPI, client: AsyncClient): workspace_id = "000000d3-82da-4bfc-b6e9-9a7853ef753e" create_workspace_item_mock.return_value = create_sample_workspace_object(workspace_id) input_data = create_sample_workspace_input_data() @@ -100,15 +106,14 @@ async def test_workspaces_post_creates_workspace(create_workspace_item_mock, sav assert response.json()["workspaceId"] == workspace_id -@patch('service_bus.resource_request_sender.ServiceBusClient') @patch("api.routes.workspaces.send_resource_request_message") @patch("api.routes.workspaces.WorkspaceRepository.save_workspace") @patch("api.routes.workspaces.WorkspaceRepository.create_workspace_item") @patch("api.routes.workspaces.WorkspaceRepository._validate_workspace_parameters") -async def test_workspaces_post_calls_db_and_service_bus(validate_workspace_parameters_mock, create_workspace_item_mock, save_workspace_mock, send_resource_request_message_mock, service_bus_client_mock, app: FastAPI, client: AsyncClient): +@patch("api.routes.workspaces.extract_auth_information", return_value={}) +async def test_workspaces_post_calls_db_and_service_bus(extract_auth_info_mock, validate_workspace_parameters_mock, create_workspace_item_mock, save_workspace_mock, send_resource_request_message_mock, app: FastAPI, client: AsyncClient): workspace_id = "000000d3-82da-4bfc-b6e9-9a7853ef753e" validate_workspace_parameters_mock.return_value = None - service_bus_client_mock().get_queue_sender().send_messages = AsyncMock() create_workspace_item_mock.return_value = create_sample_workspace_object(workspace_id) input_data = create_sample_workspace_input_data() @@ -118,15 +123,14 @@ async def test_workspaces_post_calls_db_and_service_bus(validate_workspace_param send_resource_request_message_mock.assert_called_once() -@patch('service_bus.resource_request_sender.ServiceBusClient') @patch("api.routes.workspaces.send_resource_request_message") @patch("api.routes.workspaces.WorkspaceRepository.save_workspace") @patch("api.routes.workspaces.WorkspaceRepository.create_workspace_item") @patch("api.routes.workspaces.WorkspaceRepository._validate_workspace_parameters") -async def test_workspaces_post_returns_202_on_successful_create(validate_workspace_parameters_mock, create_workspace_item_mock, save_workspace_mock, send_resource_request_message_mock, service_bus_client_mock, app: FastAPI, client: AsyncClient): +@patch("api.routes.workspaces.extract_auth_information", return_value={}) +async def test_workspaces_post_returns_202_on_successful_create(extract_auth_info_mock, validate_workspace_parameters_mock, create_workspace_item_mock, save_workspace_mock, send_resource_request_message_mock, app: FastAPI, client: AsyncClient): workspace_id = "000000d3-82da-4bfc-b6e9-9a7853ef753e" validate_workspace_parameters_mock.return_value = None - service_bus_client_mock().get_queue_sender().send_messages = AsyncMock() create_workspace_item_mock.return_value = create_sample_workspace_object(workspace_id) input_data = create_sample_workspace_input_data() @@ -140,7 +144,8 @@ async def test_workspaces_post_returns_202_on_successful_create(validate_workspa @patch("api.routes.workspaces.WorkspaceRepository.save_workspace") @patch("api.routes.workspaces.WorkspaceRepository.create_workspace_item") @patch("api.routes.workspaces.WorkspaceRepository._validate_workspace_parameters") -async def test_workspaces_post_returns_503_if_service_bus_call_fails(validate_workspace_parameters_mock, create_workspace_item_mock, save_workspace_mock, send_resource_request_message_mock, app: FastAPI, client: AsyncClient): +@patch("api.routes.workspaces.extract_auth_information", return_value={}) +async def test_workspaces_post_returns_503_if_service_bus_call_fails(extract_auth_info_mock, validate_workspace_parameters_mock, create_workspace_item_mock, save_workspace_mock, send_resource_request_message_mock, app: FastAPI, client: AsyncClient): workspace_id = "000000d3-82da-4bfc-b6e9-9a7853ef753e" validate_workspace_parameters_mock.return_value = None create_workspace_item_mock.return_value = create_sample_workspace_object(workspace_id) @@ -155,7 +160,8 @@ async def test_workspaces_post_returns_503_if_service_bus_call_fails(validate_wo @patch("api.routes.workspaces.WorkspaceRepository._get_current_workspace_template") @patch("api.routes.workspaces.WorkspaceRepository._validate_workspace_parameters") -async def test_workspaces_post_returns_400_if_template_does_not_exist(validate_workspace_parameters_mock, get_current_workspace_template_mock, app: FastAPI, client: AsyncClient): +@patch("api.routes.workspaces.extract_auth_information", return_value={}) +async def test_workspaces_post_returns_400_if_template_does_not_exist(extract_auth_info_mock, validate_workspace_parameters_mock, get_current_workspace_template_mock, app: FastAPI, client: AsyncClient): validate_workspace_parameters_mock.return_value = None get_current_workspace_template_mock.side_effect = EntityDoesNotExist input_data = create_sample_workspace_input_data() diff --git a/management_api_app/tests/test_db/test_repositories/test_workpaces_repository.py b/management_api_app/tests/test_db/test_repositories/test_workpaces_repository.py index 29868e7d9d..fe543f4305 100644 --- a/management_api_app/tests/test_db/test_repositories/test_workpaces_repository.py +++ b/management_api_app/tests/test_db/test_repositories/test_workpaces_repository.py @@ -8,7 +8,7 @@ from models.domain.resource import Deployment, Status, ResourceType from models.domain.resource_template import ResourceTemplate, Parameter from models.domain.workspace import Workspace -from models.schemas.workspace import WorkspaceInCreate +from models.schemas.workspace import WorkspaceInCreate, AuthenticationConfiguration, AuthProvider @patch('azure.cosmos.CosmosClient') @@ -52,14 +52,19 @@ def test_get_workspace_by_id_throws_entity_does_not_exist_if_item_does_not_exist @patch('db.repositories.workspaces.WorkspaceRepository._get_current_workspace_template') @patch('azure.cosmos.CosmosClient') -@patch("api.routes.workspaces.WorkspaceRepository._validate_workspace_parameters") +@patch("db.repositories.workspaces.WorkspaceRepository._validate_workspace_parameters") def test_create_workspace_item_creates_a_workspace_with_the_right_values(validate_workspace_parameters_mock, cosmos_client_mock, _get_current_workspace_template_mock): workspace_repo = db.repositories.workspaces.WorkspaceRepository(cosmos_client_mock) workspace_type = "vanilla-tre" display_name = "my workspace" description = "some description" - workspace_to_create = WorkspaceInCreate(workspaceType=workspace_type, displayName=display_name, description=description) + workspace_to_create = WorkspaceInCreate( + workspaceType=workspace_type, + displayName=display_name, + description=description, + authConfig=AuthenticationConfiguration(provider=AuthProvider.AAD, data={}) + ) validate_workspace_parameters_mock.return_value = None _get_current_workspace_template_mock.return_value = ResourceTemplate( id="a7a7a7bd-7f4e-4a4e-b970-dc86a6b31dfb", @@ -68,9 +73,10 @@ def test_create_workspace_item_creates_a_workspace_with_the_right_values(validat version="0.1.0", resourceType=ResourceType.Workspace, parameters=[], - current=False) + current=False + ) - workspace = workspace_repo.create_workspace_item(workspace_to_create) + workspace = workspace_repo.create_workspace_item(workspace_to_create, {}) assert workspace.displayName == display_name assert workspace.description == description @@ -88,11 +94,16 @@ def test_create_workspace_item_creates_a_workspace_with_the_right_values(validat def test_create_workspace_item_raises_value_error_if_template_is_invalid(cosmos_client_mock, _get_current_workspace_template_mock): workspace_repo = db.repositories.workspaces.WorkspaceRepository(cosmos_client_mock) - workspace_to_create = WorkspaceInCreate(workspaceType="vanilla-tre", displayName="my workspace", description="some description") + workspace_to_create = WorkspaceInCreate( + workspaceType="vanilla-tre", + displayName="my workspace", + description="some description", + authConfig=AuthenticationConfiguration(provider=AuthProvider.AAD, data={}) + ) _get_current_workspace_template_mock.side_effect = EntityDoesNotExist with pytest.raises(ValueError): - workspace_repo.create_workspace_item(workspace_to_create) + workspace_repo.create_workspace_item(workspace_to_create, {}) @patch('azure.cosmos.CosmosClient') diff --git a/management_api_app/tests/test_services/test_aad_access_service.py b/management_api_app/tests/test_services/test_aad_access_service.py new file mode 100644 index 0000000000..430da23166 --- /dev/null +++ b/management_api_app/tests/test_services/test_aad_access_service.py @@ -0,0 +1,56 @@ +import pytest +from mock import patch + +from services.aad_access_service import AADAccessService +from services.access_service import AuthConfigValidationError + + +def test_extract_workspace__raises_error_if_app_id_not_available(): + access_service = AADAccessService() + with pytest.raises(AuthConfigValidationError): + access_service.extract_workspace_auth_information(data={}) + + +@patch("services.aad_access_service.AADAccessService._get_app_auth_info", return_value={"roles": {"WorkspaceResearcher": "1234"}}) +def test_extract_workspace__raises_error_if_owner_not_in_roles(get_app_auth_info_mock): + access_service = AADAccessService() + with pytest.raises(AuthConfigValidationError): + access_service.extract_workspace_auth_information(data={"app_id": "1234"}) + + +@patch("services.aad_access_service.AADAccessService._get_app_auth_info", return_value={"roles": {"WorkspaceOwner": "1234"}}) +def test_extract_workspace__raises_error_if_researcher_not_in_roles(get_app_auth_info_mock): + access_service = AADAccessService() + with pytest.raises(AuthConfigValidationError): + access_service.extract_workspace_auth_information(data={"app_id": "1234"}) + + +@patch("services.aad_access_service.AADAccessService._get_app_sp_graph_data", return_value={}) +def test_extract_workspace__raises_error_if_graph_data_is_invalid(get_app_sp_graph_data_mock): + access_service = AADAccessService() + with pytest.raises(AuthConfigValidationError): + access_service.extract_workspace_auth_information(data={"app_id": "1234"}) + + +@patch("services.aad_access_service.AADAccessService._get_app_sp_graph_data") +def test_extract_workspace__returns_sp_id_and_roles(get_app_sp_graph_data_mock): + get_app_sp_graph_data_mock.return_value = { + 'value': [ + { + 'id': '12345', + 'appRoles': [ + {'id': '1abc3', 'value': 'WorkspaceResearcher'}, + {'id': '1abc4', 'value': 'WorkspaceOwner'}, + ] + } + ] + } + expected_auth_info = { + 'sp_id': '12345', + 'roles': {'WorkspaceResearcher': '1abc3', 'WorkspaceOwner': '1abc4'} + } + + access_service = AADAccessService() + actual_auth_info = access_service.extract_workspace_auth_information(data={"app_id": "1234"}) + + assert actual_auth_info == expected_auth_info