diff --git a/flamingo_tools/model_utils.py b/flamingo_tools/model_utils.py index 64b0f67..714f6ff 100644 --- a/flamingo_tools/model_utils.py +++ b/flamingo_tools/model_utils.py @@ -117,6 +117,51 @@ def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None return model +def get_default_segmentation_settings(model_type: str) -> Dict[str, Union[str, float]]: + """Get the default settings for instance segmentation post-processing for a given model. + + Args: + model_type: The model. One of 'SGN', 'SGN-lowres', 'IHC', 'IHC-lowres'. + + Returns: + Dictionary with the default segmentation settings. + """ + all_default_kwargs = { + "SGN": { + "center_distance_threshold": 0.4, + "boundary_distance_threshold": 0.5, + "fg_threshold": 0.5, + "distance_smoothing": 0.0, + "seg_class": "sgn", + }, + "SGN-lowres": { + "center_distance_threshold": None, + "boundary_distance_threshold": 0.5, + "fg_threshold": 0.5, + "distance_smoothing": 0.0, + "seg_class": "sgn_low", + }, + "IHC": { + "center_distance_threshold": 0.5, + "boundary_distance_threshold": 0.6, + "fg_threshold": 0.5, + "distance_smoothing": 0.6, + "seg_class": "ihc", + }, + "IHC-lowres": { + "center_distance_threshold": 0.5, + "boundary_distance_threshold": 0.6, + "fg_threshold": 0.5, + "distance_smoothing": 0.6, + "seg_class": "ihc", + }, + } + if model_type not in all_default_kwargs: + raise ValueError(f"Invalid model: {model_type}. Choose one of {list(all_default_kwargs.keys())}.") + default_kwargs = all_default_kwargs[model_type] + return default_kwargs + + def get_default_tiling() -> Dict[str, Dict[str, int]]: """Determine the tile shape and halo depending on the available VRAM. diff --git a/flamingo_tools/plugin/segmentation_widget.py b/flamingo_tools/plugin/segmentation_widget.py index 5d22fed..d629bd9 100644 --- a/flamingo_tools/plugin/segmentation_widget.py +++ b/flamingo_tools/plugin/segmentation_widget.py @@ -9,25 +9,27 @@ from .base_widget import BaseWidget from .util import _load_custom_model, _available_devices, _get_current_tiling -from ..model_utils import get_model, get_model_registry, get_device, get_default_tiling +from ..model_utils import ( + get_model, get_model_registry, get_device, get_default_tiling, get_default_segmentation_settings +) -# TODO Expose segmentation kwargs. -def _run_segmentation(image, model, model_type, tiling, device): +def _run_segmentation(image, model, model_type, tiling, device, min_size): block_shape = [tiling["tile"][ax] for ax in "zyx"] halo = [tiling["halo"][ax] for ax in "zyx"] prediction = predict_with_halo( image, model, gpu_ids=[device], block_shape=block_shape, halo=halo, tqdm_desc="Run prediction" ) + settings = get_default_segmentation_settings(model_type) + foreground_threshold = settings.pop("fg_threshold", 0.5) + settings.pop("seg_class", None) + settings = {name: 1.0 if val is None else val for name, val in settings.items()} foreground_map, center_distances, boundary_distances = prediction segmentation = watershed_from_center_and_boundary_distances( center_distances, boundary_distances, foreground_map, - center_distance_threshold=0.5, - boundary_distance_threshold=0.5, - foreground_threshold=0.5, - distance_smoothing=1.6, - min_size=100, + min_size=min_size, foreground_threshold=foreground_threshold, + **settings, ) return segmentation @@ -110,7 +112,9 @@ def on_predict(self): # Get the current tiling. self.tiling = _get_current_tiling(self.tiling, self.default_tiling, model_type) - segmentation = _run_segmentation(image, model=model, model_type=model_type, tiling=self.tiling, device=device) + segmentation = _run_segmentation( + image, model=model, model_type=model_type, tiling=self.tiling, device=device, min_size=self.min_size + ) self.viewer.add_labels(segmentation, name=model_type) show_info(f"INFO: Segmentation of {model_type} added to layers.") @@ -120,6 +124,11 @@ def _create_settings_widget(self): # setting_values.setToolTip(get_tooltip("embedding", "settings")) setting_values.setLayout(QVBoxLayout()) + # Create UI for the min-size parameter. + self.min_size = 100 + self.min_size_menu, layout = self._add_int_param("min_size", self.min_size, 0, 10000) + setting_values.layout().addLayout(layout) + # Create UI for the device. device = "auto" device_options = ["auto"] + _available_devices() diff --git a/flamingo_tools/segmentation/cli.py b/flamingo_tools/segmentation/cli.py index b971ebc..953af2e 100644 --- a/flamingo_tools/segmentation/cli.py +++ b/flamingo_tools/segmentation/cli.py @@ -4,7 +4,7 @@ from .unet_prediction import run_unet_prediction from .synapse_detection import marker_detection -from ..model_utils import get_model_path +from ..model_utils import get_model_path, get_default_segmentation_settings def _get_model_path(model_type, checkpoint_path=None): @@ -49,32 +49,7 @@ def _convert_argval(value): def _parse_segmentation_kwargs(extra_kwargs, model_type): - if model_type == "SGN": - default_kwargs = { - "center_distance_threshold": 0.4, - "boundary_distance_threshold": 0.5, - "fg_threshold": 0.5, - "distance_smoothing": 0.0, - "seg_class": "sgn", - } - elif model_type == "SGN-lowres": - default_kwargs = { - "center_distance_threshold": None, - "boundary_distance_threshold": 0.5, - "fg_threshold": 0.5, - "distance_smoothing": 0.0, - "seg_class": "sgn_low", - } - else: - assert model_type.startswith("IHC") - default_kwargs = { - "center_distance_threshold": 0.5, - "boundary_distance_threshold": 0.6, - "fg_threshold": 0.5, - "distance_smoothing": 0.6, - "seg_class": "ihc", - } - + default_kwargs = get_default_segmentation_settings(model_type) kwargs = _parse_kwargs(extra_kwargs, **default_kwargs) return kwargs