Make grid search functionality more flexible #300

Merged: 3 commits, Jan 3, 2024
2 changes: 1 addition & 1 deletion development/seg_with_decoder.py
@@ -21,7 +21,7 @@
print("Start segmentation ...")
segmenter.initialize(image, image_embeddings)
masks = segmenter.generate(output_mode="binary_mask")
segmentation = mask_data_to_segmentation(masks, image.shape, with_background=True)
segmentation = mask_data_to_segmentation(masks, with_background=True)
print("Segmentation done")

v = napari.Viewer()
3 changes: 1 addition & 2 deletions micro_sam/evaluation/inference.py
@@ -387,9 +387,8 @@ def run_inference_with_prompts(
def _save_segmentation(masks, prediction_path):
# masks to segmentation
masks = masks.cpu().numpy().squeeze().astype("bool")
shape = masks.shape[-2:]
masks = [{"segmentation": mask, "area": mask.sum()} for mask in masks]
segmentation = mask_data_to_segmentation(masks, shape, with_background=True)
segmentation = mask_data_to_segmentation(masks, with_background=True)
imageio.imwrite(prediction_path, segmentation, compression=5)


87 changes: 63 additions & 24 deletions micro_sam/evaluation/instance_segmentation.py
@@ -12,6 +12,7 @@
import pandas as pd

from elf.evaluation import mean_segmentation_accuracy
from elf.io import open_file
from tqdm import tqdm

from ..instance_segmentation import AMGBase, InstanceSegmentationWithDecoder, mask_data_to_segmentation
@@ -56,16 +57,28 @@ def default_grid_search_values_amg(
}


# TODO document the function
# TODO smaller default search range
def default_grid_search_values_instance_segmentation_with_decoder(
center_distance_threshold_values: Optional[List[float]] = None,
boundary_distance_threshold_values: Optional[List[float]] = None,
distance_smoothing_values: Optional[List[float]] = None,
min_size_values: Optional[List[float]] = None,

) -> Dict[str, List[float]]:
"""Default grid-search parameter for decoder-based instance segmentation.

Args:
center_distance_threshold_values: The values for `center_distance_threshold` used in the grid search.
By default, values in the range from 0.5 to 0.9 with a step size of 0.1 are used.
boundary_distance_threshold_values: The values for `boundary_distance_threshold` used in the grid search.
By default, values in the range from 0.5 to 0.9 with a step size of 0.1 are used.
distance_smoothing_values: The values for `distance_smoothing` used in the grid search.
By default, values in the range from 1.0 to 2.0 with a step size of 0.1 are used.
min_size_values: The values for `min_size` used in the grid search.
By default, the values 25, 50, 75, 100 and 200 are used.

Returns:
The values for grid search.
"""
if center_distance_threshold_values is None:
center_distance_threshold_values = _get_range_of_search_values(
[0.5, 0.9], step=0.1
@@ -80,7 +93,6 @@ def default_grid_search_values_instance_segmentation_with_decoder(
)
if min_size_values is None:
min_size_values = [25, 50, 75, 100, 200]

return {
"center_distance_threshold": center_distance_threshold_values,
"boundary_distance_threshold": boundary_distance_threshold_values,
@@ -89,18 +101,22 @@
}
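
A minimal usage sketch for the function above, assuming the import path of this module; narrowing the per-parameter value lists shrinks the search space accordingly:

from micro_sam.evaluation.instance_segmentation import (
    default_grid_search_values_instance_segmentation_with_decoder,
)

# Narrow the per-parameter value lists to get a smaller search space.
grid_search_values = default_grid_search_values_instance_segmentation_with_decoder(
    center_distance_threshold_values=[0.5, 0.7, 0.9],
    boundary_distance_threshold_values=[0.5, 0.7, 0.9],
    distance_smoothing_values=[1.0, 1.5, 2.0],
    min_size_values=[50, 100],
)
# -> dict mapping each parameter of `generate` to its candidate values.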


def _grid_search(
segmenter, gs_combinations, gt, image_name, result_path, fixed_generate_kwargs, verbose,
):
def _grid_search_iteration(
segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
gs_combinations: List[Dict],
gt: np.ndarray,
image_name: str,
fixed_generate_kwargs: Dict[str, Any],
result_path: Optional[Union[str, os.PathLike]],
verbose: bool = False,
) -> pd.DataFrame:
net_list = []
for gs_kwargs in tqdm(gs_combinations, disable=not verbose):
generate_kwargs = gs_kwargs | fixed_generate_kwargs
masks = segmenter.generate(**generate_kwargs)

min_object_size = generate_kwargs.get("min_mask_region_area", 0)
instance_labels = mask_data_to_segmentation(
masks, gt.shape, with_background=True, min_object_size=min_object_size,
)
instance_labels = mask_data_to_segmentation(masks, with_background=True, min_object_size=min_object_size)
m_sas, sas = mean_segmentation_accuracy(instance_labels, gt, return_accuracies=True) # type: ignore

result_dict = {"image_name": image_name, "mSA": m_sas, "SA50": sas[0], "SA75": sas[5]}
@@ -111,16 +127,32 @@ def _grid_search(
img_gs_df = pd.concat(net_list)
img_gs_df.to_csv(result_path, index=False)

return img_gs_df
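
The `gs_combinations` argument is a list of keyword-argument dicts, one per parameter combination. A sketch of how such combinations can be built from the grid-search value dict with `itertools.product` (the actual construction in `run_instance_segmentation_grid_search` is elided from this diff, so this is an assumption):

import itertools

def _all_combinations(grid_search_values):
    # Cartesian product of the per-parameter value lists, e.g.
    # {"a": [1, 2], "b": [3]} -> [{"a": 1, "b": 3}, {"a": 2, "b": 3}].
    names = list(grid_search_values.keys())
    return [
        dict(zip(names, values))
        for values in itertools.product(*(grid_search_values[name] for name in names))
    ]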


def _load_image(path, key, roi):
if key is None:
im = imageio.imread(path)
if roi is not None:
im = im[roi]
return im
with open_file(path, "r") as f:
im = f[key][:] if roi is None else f[key][roi]
return im
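
A usage sketch for the new helper (file names and the dataset key are hypothetical placeholders). For plain image formats the roi is applied in memory after reading, while for container formats like HDF5 it slices the on-disk dataset, so only the requested region is loaded:

roi = (slice(0, 512), slice(0, 512))  # hypothetical region of interest

# Plain tif: no key needed; the roi is applied after the full image is read.
image = _load_image("example_image.tif", key=None, roi=roi)

# HDF5/zarr/n5 container: the key selects the internal dataset and the roi
# slices it directly, so only that region is read from disk.
crop = _load_image("example_data.h5", key="raw", roi=roi)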


def run_instance_segmentation_grid_search(
segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
grid_search_values: Dict[str, List],
image_paths: List[Union[str, os.PathLike]],
gt_paths: List[Union[str, os.PathLike]],
embedding_dir: Union[str, os.PathLike],
result_dir: Union[str, os.PathLike],
embedding_dir: Optional[Union[str, os.PathLike]],
fixed_generate_kwargs: Optional[Dict[str, Any]] = None,
verbose_gs: bool = False,
image_key: Optional[str] = None,
gt_key: Optional[str] = None,
rois: Optional[Tuple[slice, ...]] = None,
) -> None:
"""Run grid search for automatic mask generation.

@@ -144,10 +176,15 @@ def run_instance_segmentation_grid_search(
grid_search_values: The grid search values for parameters of the `generate` function.
image_paths: The input images for the grid search.
gt_paths: The ground-truth segmentation for the grid search.
embedding_dir: Folder to cache the image embeddings.
result_dir: Folder to cache the evaluation results per image.
embedding_dir: Folder to cache the image embeddings.
fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
verbose_gs: Whether to run the grid search for individual images in verbose mode.
image_key: Key for loading the image data from a more complex file format like HDF5.
If not given, a simple image format like tif is assumed.
gt_key: Key for loading the ground-truth data from a more complex file format like HDF5.
If not given, a simple image format like tif is assumed.
rois: Regions of interest to restrict the evaluation to.
"""
assert len(image_paths) == len(gt_paths)
fixed_generate_kwargs = {} if fixed_generate_kwargs is None else fixed_generate_kwargs
@@ -167,10 +204,10 @@
]

os.makedirs(result_dir, exist_ok=True)
predictor = segmenter._predictor
predictor = getattr(segmenter, "_predictor", None)

for image_path, gt_path in tqdm(
zip(image_paths, gt_paths), desc="Run instance segmentation grid-search", total=len(image_paths)
for i, (image_path, gt_path) in tqdm(
enumerate(zip(image_paths, gt_paths)), desc="Run instance segmentation grid-search", total=len(image_paths)
):
image_name = Path(image_path).stem
result_path = os.path.join(result_dir, f"{image_name}.csv")
@@ -182,16 +219,20 @@
assert os.path.exists(image_path), image_path
assert os.path.exists(gt_path), gt_path

image = imageio.imread(image_path)
gt = imageio.imread(gt_path)
image = _load_image(image_path, image_key, roi=None if rois is None else rois[i])
gt = _load_image(gt_path, gt_key, roi=None if rois is None else rois[i])

embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
image_embeddings = util.precompute_image_embeddings(predictor, image, embedding_path, ndim=2)
segmenter.initialize(image, image_embeddings)
if embedding_dir is None:
segmenter.initialize(image)
else:
assert predictor is not None
embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
image_embeddings = util.precompute_image_embeddings(predictor, image, embedding_path, ndim=2)
segmenter.initialize(image, image_embeddings)

_grid_search(
_grid_search_iteration(
segmenter, gs_combinations, gt, image_name,
result_path=result_path, fixed_generate_kwargs=fixed_generate_kwargs, verbose=verbose_gs,
fixed_generate_kwargs=fixed_generate_kwargs, result_path=result_path, verbose=verbose_gs,
)
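
An end-to-end sketch of the more flexible call signature, assuming an already constructed segmenter and hypothetical file paths, keys, and roi; passing embedding_dir=None makes the segmenter compute embeddings on the fly instead of caching them:

from micro_sam.evaluation.instance_segmentation import (
    default_grid_search_values_instance_segmentation_with_decoder,
    run_instance_segmentation_grid_search,
)

# `segmenter` is assumed to be a ready InstanceSegmentationWithDecoder.
grid_search_values = default_grid_search_values_instance_segmentation_with_decoder()
run_instance_segmentation_grid_search(
    segmenter,
    grid_search_values,
    image_paths=["data/image_0.h5"],
    gt_paths=["data/labels_0.h5"],
    result_dir="gs_results",
    embedding_dir=None,  # compute embeddings on the fly
    image_key="raw",     # dataset inside the HDF5 file
    gt_key="labels",
    rois=[(slice(0, 512), slice(0, 512))],  # one roi per image
)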


@@ -232,9 +273,7 @@ def run_instance_segmentation_inference(

segmenter.initialize(image, image_embeddings)
masks = segmenter.generate(**generate_kwargs)
instances = mask_data_to_segmentation(
masks, image.shape, with_background=True, min_object_size=min_object_size,
)
instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=min_object_size)

# It's important to compress here, otherwise the predictions would take up a lot of space.
imageio.imwrite(prediction_path, instances, compression=5)
3 changes: 1 addition & 2 deletions micro_sam/inference.py
@@ -121,7 +121,6 @@ def batched_inference(
# then we need to select the most likely mask (according to the predicted IOU) here.
if reduce_multimasking and multimasking:
_, max_index = batch_ious.max(axis=1)
# How can this be vectorized???
batch_masks = torch.cat([batch_masks[i, max_id][None] for i, max_id in enumerate(max_index)]).unsqueeze(1)
batch_ious = torch.cat([batch_ious[i, max_id][None] for i, max_id in enumerate(max_index)]).unsqueeze(1)
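
The comment removed above asked how to vectorize this selection; one option is advanced indexing, sketched here under the assumption that batch_masks has shape (B, M, H, W) and batch_ious has shape (B, M):

# Pick the best mask per prompt without the Python loop.
batch_indices = torch.arange(batch_ious.shape[0], device=batch_ious.device)
_, max_index = batch_ious.max(axis=1)
batch_masks = batch_masks[batch_indices, max_index].unsqueeze(1)  # (B, 1, H, W)
batch_ious = batch_ious[batch_indices, max_index].unsqueeze(1)    # (B, 1)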

@@ -144,6 +143,6 @@
]

if return_instance_segmentation:
masks = mask_data_to_segmentation(masks, image_shape, with_background=False, min_object_size=0)
masks = mask_data_to_segmentation(masks, with_background=False, min_object_size=0)

return masks
26 changes: 21 additions & 5 deletions micro_sam/instance_segmentation.py
@@ -49,7 +49,6 @@ def __getitem__(self, index):

def mask_data_to_segmentation(
masks: List[Dict[str, Any]],
shape: Tuple[int, ...],
with_background: bool,
min_object_size: int = 0,
max_object_size: Optional[int] = None,
Expand All @@ -59,7 +58,6 @@ def mask_data_to_segmentation(
Args:
masks: The outputs generated by AutomaticMaskGenerator or EmbeddingMaskGenerator.
Only supports output_mode=binary_mask.
shape: The image shape.
with_background: Whether the segmentation has background. If yes this function assures that the largest
object in the output will be mapped to zero (the background value).
min_object_size: The minimal size of an object in pixels.
Expand All @@ -69,7 +67,9 @@ def mask_data_to_segmentation(
"""

masks = sorted(masks, key=(lambda x: x["area"]), reverse=True)
segmentation = np.zeros(shape[:2], dtype="uint32")
# we could also get the shape from the crop box
shape = next(iter(masks))["segmentation"].shape
segmentation = np.zeros(shape, dtype="uint32")

def require_numpy(mask):
return mask.cpu().numpy() if torch.is_tensor(mask) else mask
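
With the shape now inferred from the first mask, a call no longer needs the image shape; a minimal sketch with toy masks in the expected format:

import numpy as np
from micro_sam.instance_segmentation import mask_data_to_segmentation

# Two toy binary masks in the output_mode="binary_mask" format; the output
# shape is taken from the first mask's "segmentation" array.
mask_a = np.zeros((64, 64), dtype=bool)
mask_a[10:20, 10:20] = True
mask_b = np.zeros((64, 64), dtype=bool)
mask_b[40:60, 40:60] = True
masks = [{"segmentation": m, "area": int(m.sum())} for m in (mask_a, mask_b)]
segmentation = mask_data_to_segmentation(masks, with_background=True)
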
@@ -872,16 +872,32 @@ def initialize(
def _to_masks(self, segmentation, output_mode):
if output_mode != "binary_mask":
raise NotImplementedError

props = regionprops(segmentation)
crop_box = [0, segmentation.shape[1], 0, segmentation.shape[0]]
ndim = segmentation.ndim
assert ndim in (2, 3)

shape = segmentation.shape
if ndim == 2:
crop_box = [0, shape[1], 0, shape[0]]
else:
crop_box = [0, shape[2], 0, shape[1], 0, shape[0]]

# go from the skimage bbox format [y0, x0, y1, x1] to the SAM format [x0, w, y0, h], and analogously in 3d
def to_bbox(bbox):
def to_bbox_2d(bbox):
y0, x0 = bbox[0], bbox[1]
w = bbox[3] - x0
h = bbox[2] - y0
return [x0, w, y0, h]

def to_bbox_3d(bbox):
z0, y0, x0 = bbox[0], bbox[1], bbox[2]
w = bbox[5] - x0
h = bbox[4] - y0
d = bbox[3] - z0
return [x0, w, y0, h, z0, d]

to_bbox = to_bbox_2d if ndim == 2 else to_bbox_3d
masks = [
{
"segmentation": segmentation == prop.label,
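
A small sanity check for the 3d branch, relying on skimage's 3d bbox convention (min_z, min_y, min_x, max_z, max_y, max_x); the expected SAM-style output is [x0, w, y0, h, z0, d] with d = z1 - z0. A sketch, not part of this diff:

import numpy as np
from skimage.measure import regionprops

seg = np.zeros((8, 16, 16), dtype="uint32")
seg[2:5, 3:9, 4:12] = 1  # one object: depth 3, height 6, width 8

prop = regionprops(seg)[0]
z0, y0, x0 = prop.bbox[0], prop.bbox[1], prop.bbox[2]
w, h, d = prop.bbox[5] - x0, prop.bbox[4] - y0, prop.bbox[3] - z0
assert [x0, w, y0, h, z0, d] == [4, 8, 3, 6, 2, 3]
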
2 changes: 1 addition & 1 deletion micro_sam/multi_dimensional_segmentation.py
@@ -192,7 +192,7 @@ def segment_3d_from_slice(

seg_z = amg.generate(pred_iou_thresh=pred_iou_thresh, stability_score_thresh=stability_score_thresh)
seg_z = mask_data_to_segmentation(
seg_z, shape=raw.shape[1:], with_background=True,
seg_z, with_background=True,
min_object_size=min_object_size_z,
max_object_size=max_object_size_z,
)
2 changes: 1 addition & 1 deletion micro_sam/sam_annotator/annotator_2d.py
@@ -81,7 +81,7 @@ def _autosegment_widget(
seg = state.amg.generate(pred_iou_thresh=pred_iou_thresh, stability_score_thresh=stability_score_thresh)

seg = instance_segmentation.mask_data_to_segmentation(
seg, shape, with_background=with_background, min_object_size=min_object_size
seg, with_background=with_background, min_object_size=min_object_size
)
assert isinstance(seg, np.ndarray)

2 changes: 1 addition & 1 deletion micro_sam/sam_annotator/annotator_3d.py
@@ -181,7 +181,7 @@ def _autosegment_widget(
seg = state.amg.generate(pred_iou_thresh=pred_iou_thresh, stability_score_thresh=stability_score_thresh)

seg = instance_segmentation.mask_data_to_segmentation(
seg, shape, with_background=with_background, min_object_size=min_object_size
seg, with_background=with_background, min_object_size=min_object_size
)
assert isinstance(seg, np.ndarray)

12 changes: 6 additions & 6 deletions test/test_instance_segmentation.py
@@ -71,20 +71,20 @@ def test_automatic_mask_generator(self):
amg.initialize(image, image_embeddings=image_embeddings, verbose=False)

predicted = amg.generate()
predicted = mask_data_to_segmentation(predicted, image.shape, with_background=True)
predicted = mask_data_to_segmentation(predicted, with_background=True)
self.assertGreater(matching(predicted, mask, threshold=0.75)["segmentation_accuracy"], 0.99)

# check that regenerating the segmentation works
predicted2 = amg.generate()
predicted2 = mask_data_to_segmentation(predicted2, image.shape, with_background=True)
predicted2 = mask_data_to_segmentation(predicted2, with_background=True)
self.assertTrue(np.array_equal(predicted, predicted2))

# check that serializing and reserializing the state works
state = amg.get_state()
amg = AutomaticMaskGenerator(predictor, points_per_side=10, points_per_batch=16)
amg.set_state(state)
predicted3 = amg.generate()
predicted3 = mask_data_to_segmentation(predicted3, image.shape, with_background=True)
predicted3 = mask_data_to_segmentation(predicted3, with_background=True)
self.assertTrue(np.array_equal(predicted, predicted3))

def test_tiled_automatic_mask_generator(self):
@@ -107,19 +107,19 @@ def test_tiled_automatic_mask_generator(self):
amg = TiledAutomaticMaskGenerator(predictor, points_per_side=8)
amg.initialize(image, image_embeddings=image_embeddings, verbose=False)
predicted = amg.generate(pred_iou_thresh=pred_iou_thresh)
predicted = mask_data_to_segmentation(predicted, image.shape, with_background=True)
predicted = mask_data_to_segmentation(predicted, with_background=True)
self.assertGreater(matching(predicted, mask, threshold=0.75)["segmentation_accuracy"], 0.99)

predicted2 = amg.generate(pred_iou_thresh=pred_iou_thresh)
predicted2 = mask_data_to_segmentation(predicted2, image.shape, with_background=True)
predicted2 = mask_data_to_segmentation(predicted2, with_background=True)
self.assertTrue(np.array_equal(predicted, predicted2))

# check that serializing and reserializing the state works
state = amg.get_state()
amg = TiledAutomaticMaskGenerator(predictor)
amg.set_state(state)
predicted3 = amg.generate(pred_iou_thresh=pred_iou_thresh)
predicted3 = mask_data_to_segmentation(predicted3, image.shape, with_background=True)
predicted3 = mask_data_to_segmentation(predicted3, with_background=True)
self.assertTrue(np.array_equal(predicted, predicted3))

