From 0643002f99e98cfc27cc3e741f18255b1884d084 Mon Sep 17 00:00:00 2001
From: Constantin Pape
Date: Wed, 3 Jan 2024 23:44:09 +0100
Subject: [PATCH] Implement grid-search for distance-based segmentation (#193)

Implement grid-search for instance segmentation
---
 torch_em/util/grid_search.py  | 140 ++++++++++++++++++++++++++++++++++
 torch_em/util/segmentation.py |   6 +-
 2 files changed, 144 insertions(+), 2 deletions(-)
 create mode 100644 torch_em/util/grid_search.py

diff --git a/torch_em/util/grid_search.py b/torch_em/util/grid_search.py
new file mode 100644
index 00000000..18646902
--- /dev/null
+++ b/torch_em/util/grid_search.py
@@ -0,0 +1,140 @@
+import numpy as np
+
+from micro_sam.instance_segmentation import InstanceSegmentationWithDecoder
+from micro_sam.evaluation.instance_segmentation import (
+    default_grid_search_values_instance_segmentation_with_decoder,
+    evaluate_instance_segmentation_grid_search,
+    run_instance_segmentation_grid_search,
+    _get_range_of_search_values,
+)
+
+from ..transform.raw import standardize
+from .prediction import predict_with_padding, predict_with_halo
+from .segmentation import watershed_from_components
+
+
+def default_grid_search_values_boundary_based_instance_segmentation(
+    threshold1_values=None,
+    threshold2_values=None,
+    min_size_values=None,
+):
+    if threshold1_values is None:
+        threshold1_values = [0.5, 0.55, 0.6]
+    if threshold2_values is None:
+        threshold2_values = _get_range_of_search_values(
+            [0.5, 0.9], step=0.1
+        )
+    if min_size_values is None:
+        min_size_values = [25, 50, 75, 100, 200]
+
+    return {
+        "min_size": min_size_values,
+        "threshold1": threshold1_values,
+        "threshold2": threshold2_values,
+    }
+
+
+class BoundaryBasedInstanceSegmentation(InstanceSegmentationWithDecoder):
+    def __init__(self, model, preprocess=None, block_shape=None, halo=None):
+        self._model = model
+        self._preprocess = standardize if preprocess is None else preprocess
+
+        assert (block_shape is None) == (halo is None)
+        self._block_shape = block_shape
+        self._halo = halo
+
+        self._foreground = None
+        self._boundaries = None
+
+        self._is_initialized = False
+
+    def initialize(self, data):
+        device = next(iter(self._model.parameters())).device
+
+        if self._block_shape is None:
+            scale_factors = self._model.init_kwargs["scale_factors"]
+            min_divisible = [int(np.prod([sf[i] for sf in scale_factors])) for i in range(3)]
+            input_ = self._preprocess(data)
+            output = predict_with_padding(self._model, input_, min_divisible, device)
+        else:
+            output = predict_with_halo(
+                data, self._model, [device], self._block_shape, self._halo,
+                preprocess=self._preprocess,
+            )
+
+        self._foreground = output[0]
+        self._boundaries = output[1]
+
+        self._is_initialized = True
+
+    def generate(self, min_size=50, threshold1=0.5, threshold2=0.5, output_mode="binary_mask"):
+        segmentation = watershed_from_components(
+            self._boundaries, self._foreground,
+            min_size=min_size, threshold1=threshold1, threshold2=threshold2,
+        )
+        if output_mode is not None:
+            segmentation = self._to_masks(segmentation, output_mode)
+        return segmentation
+
+
+class DistanceBasedInstanceSegmentation(InstanceSegmentationWithDecoder):
+    """Overwrite micro_sam functionality so that it works for distance-based
+    segmentation with a U-Net.
+    """
+    def __init__(self, model, preprocess=None, block_shape=None, halo=None):
+        self._model = model
+        self._preprocess = standardize if preprocess is None else preprocess
+
+        assert (block_shape is None) == (halo is None)
+        self._block_shape = block_shape
+        self._halo = halo
+
+        self._foreground = None
+        self._center_distances = None
+        self._boundary_distances = None
+
+        self._is_initialized = False
+
+    def initialize(self, data):
+        device = next(iter(self._model.parameters())).device
+
+        if self._block_shape is None:
+            scale_factors = self._model.init_kwargs["scale_factors"]
+            min_divisible = [int(np.prod([sf[i] for sf in scale_factors])) for i in range(3)]
+            input_ = self._preprocess(data)
+            output = predict_with_padding(self._model, input_, min_divisible, device)
+        else:
+            output = predict_with_halo(
+                data, self._model, [device], self._block_shape, self._halo,
+                preprocess=self._preprocess,
+            )
+
+        self._foreground = output[0]
+        self._center_distances = output[1]
+        self._boundary_distances = output[2]
+
+        self._is_initialized = True
+
+
+def instance_segmentation_grid_search(
+    segmenter, image_paths, gt_paths, result_dir,
+    grid_search_values=None, rois=None,
+    image_key=None, gt_key=None,
+):
+    if grid_search_values is None:
+        if isinstance(segmenter, DistanceBasedInstanceSegmentation):
+            grid_search_values = default_grid_search_values_instance_segmentation_with_decoder()
+        elif isinstance(segmenter, BoundaryBasedInstanceSegmentation):
+            grid_search_values = default_grid_search_values_boundary_based_instance_segmentation()
+        else:
+            raise ValueError(f"Could not derive default grid search values for segmenter of type {type(segmenter)}")
+
+    run_instance_segmentation_grid_search(
+        segmenter, grid_search_values, image_paths, gt_paths, result_dir,
+        embedding_dir=None, verbose_gs=True,
+        image_key=image_key, gt_key=gt_key, rois=rois,
+    )
+    best_kwargs, best_score = evaluate_instance_segmentation_grid_search(
+        result_dir, list(grid_search_values.keys())
+    )
+    return best_kwargs, best_score
diff --git a/torch_em/util/segmentation.py b/torch_em/util/segmentation.py
index f4ce71a7..42949974 100644
--- a/torch_em/util/segmentation.py
+++ b/torch_em/util/segmentation.py
@@ -38,7 +38,7 @@ def size_filter(seg, min_size, hmap=None, with_background=False):
     return seg
 
 
-def mutex_watershed_segmentation(foreground, affinities, offsets, min_size=50, threshold=0.5):
+def mutex_watershed_segmentation(foreground, affinities, offsets, min_size=50, threshold=0.5, strides=None):
     """Computes the mutex watershed segmentation using the affinity maps for respective pixel offsets
 
     Arguments:
@@ -49,7 +49,9 @@ def mutex_watershed_segmentation(foreground, affinities, offsets, min_size=50, t
         - threshold: [float] - To threshold foreground predictions
+        - strides: [list[int]] - The strides for the mutex watershed (defaults to 2 along each dimension)
     """
     mask = (foreground >= threshold)
-    strides = [2] * foreground.ndim
+    if strides is None:
+        strides = [2] * foreground.ndim
     seg = mutex_watershed(affinities, offsets=offsets, mask=mask, strides=strides, randomize_strides=True)
     seg = size_filter(seg.astype("uint32"), min_size=min_size, hmap=affinities, with_background=True)
    return seg
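
Usage note (not part of the patch): the sketch below shows how the new grid search could be driven for distance-based segmentation. The untrained stand-in model, the block shape and halo, and the image/label paths are hypothetical placeholders; in practice a U-Net trained on distance targets and real tif (or hdf5/zarr) files would be used.

    from torch_em.model import UNet2d
    from torch_em.util.grid_search import (
        DistanceBasedInstanceSegmentation,
        instance_segmentation_grid_search,
    )

    # Stand-in for a trained model: a 2D U-Net with three output channels
    # (foreground, center distances, boundary distances).
    model = UNet2d(in_channels=1, out_channels=3)

    # Tiled prediction with a halo; block shape and halo are example values.
    segmenter = DistanceBasedInstanceSegmentation(
        model, block_shape=[256, 256], halo=[32, 32]
    )

    # Hypothetical image and ground-truth paths (tif); for hdf5/zarr inputs
    # image_key / gt_key would be passed as well.
    image_paths = ["data/images/im0.tif", "data/images/im1.tif"]
    gt_paths = ["data/labels/im0.tif", "data/labels/im1.tif"]

    # Runs the grid search over the default parameter ranges and returns the
    # best parameter setting together with its score.
    best_kwargs, best_score = instance_segmentation_grid_search(
        segmenter, image_paths, gt_paths, result_dir="grid_search_results"
    )
    print("Best parameters:", best_kwargs)
    print("Best score:", best_score)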