Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions management_api_app/api/dependencies/authentication.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,25 @@
from fastapi import Depends, HTTPException, status

from models.schemas.workspace import AuthenticationConfiguration, AuthProvider
from resources import strings
from services.aad_authentication import authorize
from services.aad_access_service import AADAccessService
from services.authentication import User
from services.access_service import AccessService, AuthConfigValidationError


def extract_auth_information(auth_config: AuthenticationConfiguration) -> dict:
access_service = get_access_service(auth_config.provider)
try:
return access_service.extract_workspace_auth_information(auth_config.data)
except AuthConfigValidationError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))


def get_access_service(provider: str) -> AccessService:
if provider == AuthProvider.AAD:
return AADAccessService()
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=strings.INVALID_AUTH_PROVIDER)


async def get_current_user(user: User = Depends(authorize)) -> User:
Expand Down
3 changes: 2 additions & 1 deletion management_api_app/api/errors/generic_error.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import traceback

from starlette.requests import Request
from starlette.responses import PlainTextResponse
Expand All @@ -11,5 +12,5 @@ async def generic_error_handler(_: Request, exception: Exception) -> PlainTextRe
logging.debug("=====================================")
logging.exception(exception)
logging.debug("=====================================")
error_string = exception if config.DEBUG else strings.UNABLE_TO_PROCESS_REQUEST
error_string = traceback.format_exc() if config.DEBUG else strings.UNABLE_TO_PROCESS_REQUEST
return PlainTextResponse(error_string, status_code=500)
13 changes: 6 additions & 7 deletions management_api_app/api/routes/workspaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from fastapi import APIRouter, Depends, HTTPException
from starlette import status

from api.dependencies.authentication import extract_auth_information
from api.dependencies.database import get_repository
from api.dependencies.workspaces import get_workspace_by_workspace_id_from_path
from api.dependencies.authentication import get_current_user, get_current_admin_user
Expand All @@ -25,14 +26,12 @@ async def retrieve_active_workspaces(
return WorkspacesInList(workspaces=workspaces)


@router.post("/workspaces", status_code=status.HTTP_202_ACCEPTED, response_model=WorkspaceIdInResponse, name=strings.API_CREATE_WORKSPACE,
dependencies=[Depends(get_current_admin_user)])
async def create_workspace(
workspace_create: WorkspaceInCreate,
workspace_repo: WorkspaceRepository = Depends(get_repository(WorkspaceRepository)),
) -> WorkspaceIdInResponse:
@router.post("/workspaces", status_code=status.HTTP_202_ACCEPTED, response_model=WorkspaceIdInResponse, name=strings.API_CREATE_WORKSPACE, dependencies=[Depends(get_current_admin_user)])
async def create_workspace(workspace_create: WorkspaceInCreate, workspace_repo: WorkspaceRepository = Depends(get_repository(WorkspaceRepository))) -> WorkspaceIdInResponse:
auth_information = extract_auth_information(workspace_create.authConfig)

try:
workspace = workspace_repo.create_workspace_item(workspace_create)
workspace = workspace_repo.create_workspace_item(workspace_create, auth_information)
except WorkspaceValidationError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=e.errors)
except ValueError as e:
Expand Down
1 change: 1 addition & 0 deletions management_api_app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

# Authentication
API_CLIENT_ID: str = config("API_CLIENT_ID", default="")
API_CLIENT_SECRET: str = config("API_CLIENT_SECRET", default="")
SWAGGER_UI_CLIENT_ID: str = config("SWAGGER_UI_CLIENT_ID", default="")
AAD_TENANT_ID: str = config("AAD_TENANT_ID", default="")

Expand Down
4 changes: 2 additions & 2 deletions management_api_app/db/repositories/workspace_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def get_workspace_template_names(self) -> List[str]:
workspace_template_names = [template["name"] for template in workspace_templates]
return list(set(workspace_template_names))

def create_workspace_template_item(self, workspace_template_create: WorkspaceTemplateInCreate):
def create_workspace_template_item(self, workspace_template_create: WorkspaceTemplateInCreate) -> ResourceTemplate:
item_id = str(uuid.uuid4())
resource_template = ResourceTemplate(
id=item_id,
Expand All @@ -56,7 +56,7 @@ def create_workspace_template_item(self, workspace_template_create: WorkspaceTem
version=workspace_template_create.version,
parameters=workspace_template_create.parameters,
resourceType=ResourceType.Workspace,
current=workspace_template_create.current
current=workspace_template_create.current,
)
self.create_item(resource_template)
return resource_template
Expand Down
5 changes: 3 additions & 2 deletions management_api_app/db/repositories/workspaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def get_workspace_by_workspace_id(self, workspace_id: UUID4) -> Workspace:
raise EntityDoesNotExist
return parse_obj_as(Workspace, workspaces[0])

def create_workspace_item(self, workspace_create: WorkspaceInCreate) -> Workspace:
def create_workspace_item(self, workspace_create: WorkspaceInCreate, auth_info: dict) -> Workspace:
full_workspace_id = str(uuid.uuid4())

try:
Expand Down Expand Up @@ -131,7 +131,8 @@ def create_workspace_item(self, workspace_create: WorkspaceInCreate) -> Workspac
resourceTemplateName=workspace_create.workspaceType,
resourceTemplateVersion=template_version,
resourceTemplateParameters=resource_spec_parameters,
deployment=Deployment(status=Status.NotDeployed, message=strings.RESOURCE_STATUS_NOT_DEPLOYED_MESSAGE)
deployment=Deployment(status=Status.NotDeployed, message=strings.RESOURCE_STATUS_NOT_DEPLOYED_MESSAGE),
authInformation=auth_info
)

self._validate_workspace_parameters(current_template.parameters, workspace.resourceTemplateParameters)
Expand Down
1 change: 1 addition & 0 deletions management_api_app/models/domain/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ class Workspace(Resource):
"""
workspaceURL: str = Field("", title="Workspace URL", description="Main endpoint for workspace users")
resourceType = ResourceType.Workspace
authInformation: dict = Field({})
23 changes: 21 additions & 2 deletions management_api_app/models/schemas/workspace.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from enum import Enum
from typing import List
from pydantic import BaseModel, Field

Expand All @@ -23,10 +24,23 @@ def get_sample_workspace(workspace_id: str, spec_workspace_id: str = "0001") ->
},
"isDeleted": False,
"resourceType": "workspace",
"workspaceURL": ""
"workspaceURL": "",
"authInformation": {}
}


class AuthProvider(str, Enum):
"""
Auth Provider
"""
AAD = "AAD"


class AuthenticationConfiguration(BaseModel):
provider: AuthProvider = Field(AuthProvider.AAD, title="Authentication Provider")
data: dict = Field({}, title="Authentication information")


class WorkspaceInResponse(BaseModel):
workspace: Workspace

Expand Down Expand Up @@ -57,14 +71,19 @@ class WorkspaceInCreate(BaseModel):
workspaceType: str = Field(title="Workspace type", description="Bundle name")
description: str = Field(title="Workspace description")
parameters: dict = Field({}, title="Workspace parameters", description="Values for the parameters required by the workspace resource specification")
authConfig: AuthenticationConfiguration = Field(title="Authentication configuration", description="Authentication configuration for the workspace")

class Config:
schema_extra = {
"example": {
"displayName": "My workspace",
"description": "workspace for team X",
"workspaceType": "tre-workspace-vanilla",
"parameters": {}
"parameters": {},
"authConfig": {
"provider": "AAD",
"data": {"app_id": "1212445c-aae6-41ec-a539-30dfa90ab1ae"}
}
}
}

Expand Down
1 change: 1 addition & 0 deletions management_api_app/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ azure-identity==1.6.0
azure-servicebus==7.2.0
opencensus-ext-azure==1.0.8
opencensus-ext-logging==0.1.0
msal==1.12.0
5 changes: 5 additions & 0 deletions management_api_app/resources/strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,19 @@
UNSPECIFIED_ERROR = "Unspecified error"

# Error strings
ACCESS_APP_IS_MISSING_ROLE = "The App is missing role"
ACCESS_PLEASE_SUPPLY_APP_ID = "Please supply the app_id for the AAD application"
ACCESS_UNABLE_TO_GET_INFO_FOR_APP = "Unable to get app info for app:"
AUTH_NOT_ASSIGNED_TO_ADMIN_ROLE = "Not assigned to admin role"
AUTH_COULD_NOT_VALIDATE_CREDENTIALS = "Could not validate credentials"
INVALID_AUTH_PROVIDER = "Invalid authentication provider"
UNABLE_TO_REPLACE_CURRENT_TEMPLATE = "Unable to replace the existing 'current' template with this name"
UNABLE_TO_PROCESS_REQUEST = "Unable to process request"
WORKSPACE_DOES_NOT_EXIST = "Workspace does not exist"
WORKSPACE_TEMPLATE_DOES_NOT_EXIST = "Could not retrieve the 'current' template with this name"
WORKSPACE_TEMPLATE_VERSION_EXISTS = "A template with this version already exists"


# Resource Status
RESOURCE_STATUS_NOT_DEPLOYED = "not_deployed"
RESOURCE_STATUS_DEPLOYING = "deploying"
Expand Down
67 changes: 67 additions & 0 deletions management_api_app/services/aad_access_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import logging

import requests
from msal import ConfidentialClientApplication

from core import config
from resources import strings
from services.access_service import AccessService, AuthConfigValidationError


class AADAccessService(AccessService):
@staticmethod
def _get_msgraph_token() -> str:
scopes = ["https://graph.microsoft.com/.default"]
app = ConfidentialClientApplication(client_id=config.API_CLIENT_ID, client_credential=config.API_CLIENT_SECRET, authority=f"{config.AAD_INSTANCE}/{config.AAD_TENANT_ID}")
result = app.acquire_token_silent(scopes=scopes, account=None)
if not result:
logging.info('No suitable token exists in cache, getting a new one from AAD')
result = app.acquire_token_for_client(scopes=scopes)
if "access_token" not in result:
logging.debug(result.get('error'))
logging.debug(result.get('error_description'))
logging.debug(result.get('correlation_id'))
raise Exception(result.get('error'))
return result["access_token"]

@staticmethod
def _get_auth_header(msgraph_token: str) -> dict:
return {'Authorization': 'Bearer ' + msgraph_token}

@staticmethod
def _get_service_principal_endpoint(app_id) -> str:
return f"https://graph.microsoft.com/v1.0/serviceprincipals?$filter=appid eq '{app_id}'"

def _get_app_sp_graph_data(self, app_id: str) -> dict:
msgraph_token = self._get_msgraph_token()
sp_endpoint = self._get_service_principal_endpoint(app_id)
graph_data = requests.get(sp_endpoint, headers=self._get_auth_header(msgraph_token)).json()
return graph_data

def _get_app_auth_info(self, app_id: str) -> dict:
graph_data = self._get_app_sp_graph_data(app_id)
if 'value' not in graph_data or len(graph_data['value']) == 0:
logging.debug(graph_data)
raise AuthConfigValidationError(f"{strings.ACCESS_UNABLE_TO_GET_INFO_FOR_APP} {app_id}")

app_info = graph_data['value'][0]
sp_id = app_info['id']
roles = app_info['appRoles']

return {
'sp_id': sp_id,
'roles': {role['value']: role['id'] for role in roles}
}

def extract_workspace_auth_information(self, data: dict) -> dict:
if "app_id" not in data:
raise AuthConfigValidationError(strings.ACCESS_PLEASE_SUPPLY_APP_ID)

auth_info = self._get_app_auth_info(data["app_id"])
print(auth_info)

for role in ['WorkspaceOwner', 'WorkspaceResearcher']:
if role not in auth_info['roles']:
raise AuthConfigValidationError(f"{strings.ACCESS_APP_IS_MISSING_ROLE} {role}")

return auth_info
11 changes: 11 additions & 0 deletions management_api_app/services/access_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from abc import ABC, abstractmethod


class AuthConfigValidationError(Exception):
"""Raised when the input auth information is invalid"""


class AccessService(ABC):
@abstractmethod
def extract_workspace_auth_information(self, data: dict) -> dict:
pass
34 changes: 20 additions & 14 deletions management_api_app/tests/test_api/test_routes/test_workspaces.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from mock import AsyncMock, patch
from mock import patch

from fastapi import FastAPI
from httpx import AsyncClient
Expand All @@ -22,7 +22,8 @@ def create_sample_workspace_object(workspace_id):
resourceTemplateName="tre-workspace-vanilla",
resourceTemplateVersion="0.1.0",
resourceTemplateParameters={},
deployment=Deployment(status=Status.NotDeployed, message="")
deployment=Deployment(status=Status.NotDeployed, message=""),
authInformation={}
)


Expand All @@ -31,7 +32,13 @@ def create_sample_workspace_input_data():
"displayName": "My workspace",
"description": "workspace for team X",
"workspaceType": "tre-workspace-vanilla",
"parameters": {}
"parameters": {},
"authConfig": {
"provider": "AAD",
"data": {
"app_id": "1212445c-aae6-41ec-a539-30dfa90ab1ae"
}
}
}


Expand Down Expand Up @@ -84,12 +91,11 @@ async def test_workspaces_id_get_returns_workspace_if_found(get_workspace_mock,
assert actual_resource["id"] == sample_workspace["id"]


@patch('service_bus.resource_request_sender.ServiceBusClient')
@patch("api.routes.workspaces.send_resource_request_message")
@patch("api.routes.workspaces.WorkspaceRepository.save_workspace")
@patch("api.routes.workspaces.WorkspaceRepository.create_workspace_item")
async def test_workspaces_post_creates_workspace(create_workspace_item_mock, save_workspace_mock, send_resource_request_message_mock, service_bus_client_mock, app: FastAPI, client: AsyncClient):
service_bus_client_mock().get_queue_sender().send_messages = AsyncMock()
@patch("api.routes.workspaces.extract_auth_information", return_value={})
async def test_workspaces_post_creates_workspace(extract_auth_info_mock, create_workspace_item_mock, save_workspace_mock, send_resource_request_message_mock, app: FastAPI, client: AsyncClient):
workspace_id = "000000d3-82da-4bfc-b6e9-9a7853ef753e"
create_workspace_item_mock.return_value = create_sample_workspace_object(workspace_id)
input_data = create_sample_workspace_input_data()
Expand All @@ -100,15 +106,14 @@ async def test_workspaces_post_creates_workspace(create_workspace_item_mock, sav
assert response.json()["workspaceId"] == workspace_id


@patch('service_bus.resource_request_sender.ServiceBusClient')
@patch("api.routes.workspaces.send_resource_request_message")
@patch("api.routes.workspaces.WorkspaceRepository.save_workspace")
@patch("api.routes.workspaces.WorkspaceRepository.create_workspace_item")
@patch("api.routes.workspaces.WorkspaceRepository._validate_workspace_parameters")
async def test_workspaces_post_calls_db_and_service_bus(validate_workspace_parameters_mock, create_workspace_item_mock, save_workspace_mock, send_resource_request_message_mock, service_bus_client_mock, app: FastAPI, client: AsyncClient):
@patch("api.routes.workspaces.extract_auth_information", return_value={})
async def test_workspaces_post_calls_db_and_service_bus(extract_auth_info_mock, validate_workspace_parameters_mock, create_workspace_item_mock, save_workspace_mock, send_resource_request_message_mock, app: FastAPI, client: AsyncClient):
workspace_id = "000000d3-82da-4bfc-b6e9-9a7853ef753e"
validate_workspace_parameters_mock.return_value = None
service_bus_client_mock().get_queue_sender().send_messages = AsyncMock()
create_workspace_item_mock.return_value = create_sample_workspace_object(workspace_id)
input_data = create_sample_workspace_input_data()

Expand All @@ -118,15 +123,14 @@ async def test_workspaces_post_calls_db_and_service_bus(validate_workspace_param
send_resource_request_message_mock.assert_called_once()


@patch('service_bus.resource_request_sender.ServiceBusClient')
@patch("api.routes.workspaces.send_resource_request_message")
@patch("api.routes.workspaces.WorkspaceRepository.save_workspace")
@patch("api.routes.workspaces.WorkspaceRepository.create_workspace_item")
@patch("api.routes.workspaces.WorkspaceRepository._validate_workspace_parameters")
async def test_workspaces_post_returns_202_on_successful_create(validate_workspace_parameters_mock, create_workspace_item_mock, save_workspace_mock, send_resource_request_message_mock, service_bus_client_mock, app: FastAPI, client: AsyncClient):
@patch("api.routes.workspaces.extract_auth_information", return_value={})
async def test_workspaces_post_returns_202_on_successful_create(extract_auth_info_mock, validate_workspace_parameters_mock, create_workspace_item_mock, save_workspace_mock, send_resource_request_message_mock, app: FastAPI, client: AsyncClient):
workspace_id = "000000d3-82da-4bfc-b6e9-9a7853ef753e"
validate_workspace_parameters_mock.return_value = None
service_bus_client_mock().get_queue_sender().send_messages = AsyncMock()
create_workspace_item_mock.return_value = create_sample_workspace_object(workspace_id)
input_data = create_sample_workspace_input_data()

Expand All @@ -140,7 +144,8 @@ async def test_workspaces_post_returns_202_on_successful_create(validate_workspa
@patch("api.routes.workspaces.WorkspaceRepository.save_workspace")
@patch("api.routes.workspaces.WorkspaceRepository.create_workspace_item")
@patch("api.routes.workspaces.WorkspaceRepository._validate_workspace_parameters")
async def test_workspaces_post_returns_503_if_service_bus_call_fails(validate_workspace_parameters_mock, create_workspace_item_mock, save_workspace_mock, send_resource_request_message_mock, app: FastAPI, client: AsyncClient):
@patch("api.routes.workspaces.extract_auth_information", return_value={})
async def test_workspaces_post_returns_503_if_service_bus_call_fails(extract_auth_info_mock, validate_workspace_parameters_mock, create_workspace_item_mock, save_workspace_mock, send_resource_request_message_mock, app: FastAPI, client: AsyncClient):
workspace_id = "000000d3-82da-4bfc-b6e9-9a7853ef753e"
validate_workspace_parameters_mock.return_value = None
create_workspace_item_mock.return_value = create_sample_workspace_object(workspace_id)
Expand All @@ -155,7 +160,8 @@ async def test_workspaces_post_returns_503_if_service_bus_call_fails(validate_wo

@patch("api.routes.workspaces.WorkspaceRepository._get_current_workspace_template")
@patch("api.routes.workspaces.WorkspaceRepository._validate_workspace_parameters")
async def test_workspaces_post_returns_400_if_template_does_not_exist(validate_workspace_parameters_mock, get_current_workspace_template_mock, app: FastAPI, client: AsyncClient):
@patch("api.routes.workspaces.extract_auth_information", return_value={})
async def test_workspaces_post_returns_400_if_template_does_not_exist(extract_auth_info_mock, validate_workspace_parameters_mock, get_current_workspace_template_mock, app: FastAPI, client: AsyncClient):
validate_workspace_parameters_mock.return_value = None
get_current_workspace_template_mock.side_effect = EntityDoesNotExist
input_data = create_sample_workspace_input_data()
Expand Down
Loading