diff --git a/prediction/src/algorithms/segment/trained_model.py b/prediction/src/algorithms/segment/trained_model.py index d67452ec..8d61ac51 100644 --- a/prediction/src/algorithms/segment/trained_model.py +++ b/prediction/src/algorithms/segment/trained_model.py @@ -27,7 +27,32 @@ def predict(dicom_path, centroids): 'z': int} Returns: - str: a path to the serialized binary mask that can be used for - segmentation + dict: Dictionary containing path to serialized binary masks and + volumes per centroid with form:: + {'binary_mask_path': str, + 'volumes': list[float]} """ - return 'path/to/segmentation' + segment_path = 'path/to/segmentation' + volumes = calculate_volume(segment_path, centroids) + return_value = { + 'binary_mask_path': segment_path, + 'volumes': volumes + } + return return_value + + +def calculate_volume(segment_path, centroids): + """ Calculates tumor volume from pixel masks + + Args: + segment_path (str): A path to the serialized binary mask for + each centroid + centroids (list[dict]): A list of centroids of the form:: + {'x': int, + 'y': int, + 'z': int} + + Returns: + list[float]: List of volumes per centroid + """ + return [0.5 for centroid in centroids] diff --git a/prediction/src/tests/test_endpoints.py b/prediction/src/tests/test_endpoints.py index 3384a79a..089ee947 100644 --- a/prediction/src/tests/test_endpoints.py +++ b/prediction/src/tests/test_endpoints.py @@ -96,7 +96,8 @@ def test_segment(client): data = get_data(r) - assert isinstance(data['prediction'], str) + assert isinstance(data['prediction']['binary_mask_path'], str) + assert isinstance(data['prediction']['volumes'], list) def test_error(client): diff --git a/prediction/src/views.py b/prediction/src/views.py index 37eab05e..89920cec 100644 --- a/prediction/src/views.py +++ b/prediction/src/views.py @@ -81,11 +81,7 @@ def predict(algorithm): try: predict_method = PREDICTORS[algorithm] - if 'centroids' in payload: - prediction = predict_method(payload['dicom_path'], - payload['centroids']) - else: - prediction = predict_method(payload['dicom_path']) + prediction = predict_method(**payload) response.update({ 'prediction': prediction