diff --git a/hi-ml/src/health_ml/utils/box_utils.py b/hi-ml/src/health_ml/utils/box_utils.py index 5beb39fb1..1e398a55e 100644 --- a/hi-ml/src/health_ml/utils/box_utils.py +++ b/hi-ml/src/health_ml/utils/box_utils.py @@ -1,7 +1,13 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ + from dataclasses import dataclass from typing import Optional, Sequence, Tuple import numpy as np +from scipy import ndimage @dataclass(frozen=True) @@ -123,14 +129,15 @@ def get_bounding_box(mask: np.ndarray) -> Box: :param mask: A 2D array with 0 (or `False`) as background and >0 (or `True`) as foreground. :return: The smallest box covering all non-zero elements of `mask`. + :raises TypeError: When the input mask has more than two dimensions. + :raises RuntimeError: When all elements in the mask are zero. """ if mask.ndim != 2: - raise TypeError(f"Expected a 2D array but got {mask.ndim} dimensions") - - xs = np.sum(mask, 1).nonzero()[0] - ys = np.sum(mask, 0).nonzero()[0] - x_min, x_max = xs.min(), xs.max() - y_min, y_max = ys.min(), ys.max() - width = x_max - x_min + 1 - height = y_max - y_min + 1 - return Box(x_min, y_min, width, height) + raise TypeError(f"Expected a 2D array but got an array with shape {mask.shape}") + + slices = ndimage.find_objects(mask > 0) + if not slices: + raise RuntimeError("The input mask is empty") + assert len(slices) == 1 + + return Box.from_slices(slices[0]) diff --git a/hi-ml/testhiml/testhiml/test_box_utils.py b/hi-ml/testhiml/testhiml/test_box_utils.py index da27474bf..5974e7338 100644 --- a/hi-ml/testhiml/testhiml/test_box_utils.py +++ b/hi-ml/testhiml/testhiml/test_box_utils.py @@ -8,27 +8,78 @@ from health_ml.utils.box_utils import Box, get_bounding_box -def test_get_bounding_box() -> None: +def test_no_zeros() -> None: length_x = 3 length_y = 4 # If no elements are zero, the bounding box will have the same shape as the original - mask = np.random.randint(1, 10, size=(length_x, length_y)) + mask = np.random.randint(1, 10, size=(length_y, length_x)) bbox = get_bounding_box(mask) assert isinstance(bbox, Box) - assert bbox.w == length_x - assert bbox.h == length_y + expected = Box(x=0, y=0, w=length_x, h=length_y) + assert bbox == expected + +def test_bounding_box_3d() -> None: # passing a 3D array should cause an error to be raised - length_z = 5 - mask_3d = np.random.randint(0, 10, size=(length_x, length_y, length_z)) + mask_3d = np.random.randint(0, 10, size=(1, 2, 3)) with pytest.raises(TypeError): get_bounding_box(mask_3d) + +def test_identity_matrix() -> None: # passing an identity matrix will return a bounding box with the same shape as the original, # and xmin and ymin will both be zero - mask_eye = np.eye(length_x) - bbox_eye = get_bounding_box(mask_eye) - assert isinstance(bbox_eye, Box) - assert bbox_eye.w == length_x - assert bbox_eye.h == length_x - assert bbox_eye.x == bbox_eye.y == 0 + length = 5 + mask_eye = np.eye(length) + bbox = get_bounding_box(mask_eye) + expected = Box(x=0, y=0, w=length, h=length) + assert bbox == expected + + +def test_all_zeros() -> None: + mask = np.zeros((2, 3)) + with pytest.raises(RuntimeError): + get_bounding_box(mask) + + +def test_small_rectangle() -> None: + mask = np.zeros((5, 5), int) + row = 0 + height = 1 + width = 2 + col = 3 + mask[row:row + height, col:col + width] = 1 + # array([[0, 0, 0, 1, 1], + # [0, 0, 0, 0, 0], + # [0, 0, 0, 0, 0], + # [0, 0, 0, 0, 0], + # [0, 0, 0, 0, 0]]) + bbox = get_bounding_box(mask) + expected = Box(x=col, y=row, w=width, h=height) + assert bbox == expected + + +def test_tiny_mask() -> None: + mask = np.array(1).reshape(1, 1) + bbox = get_bounding_box(mask) + assert bbox.x == bbox.y == 0 + assert bbox.w == bbox.h == 1 + + +def test_tiny_box() -> None: + mask = np.array(( + (0, 0), + (0, 1), + )) + bbox = get_bounding_box(mask) + assert bbox.x == bbox.y == bbox.w == bbox.h == 1 + + +def test_multiple_components() -> None: + length = 3 + mask = np.zeros((length, length), int) + mask[0, 0] = 1 + mask[length - 1, length - 1] = 1 + bbox = get_bounding_box(mask) + expected = Box(x=0, y=0, w=length, h=length) + assert bbox == expected