Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions flamingo_tools/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,51 @@ def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None
return model


def get_default_segmentation_settings(model_type: str) -> Dict[str, Union[str, float]]:
"""Get the default settings for instance segmentation post-processing for a given model.

Args:
model_type: The model. One of 'SGN', 'SGN-lowres', 'IHC', 'IHC-lowres'.

Returns:
Dictionary with the default segmentation settings.
"""
all_default_kwargs = {
"SGN": {
"center_distance_threshold": 0.4,
"boundary_distance_threshold": 0.5,
"fg_threshold": 0.5,
"distance_smoothing": 0.0,
"seg_class": "sgn",
},
"SGN-lowres": {
"center_distance_threshold": None,
"boundary_distance_threshold": 0.5,
"fg_threshold": 0.5,
"distance_smoothing": 0.0,
"seg_class": "sgn_low",
},
"IHC": {
"center_distance_threshold": 0.5,
"boundary_distance_threshold": 0.6,
"fg_threshold": 0.5,
"distance_smoothing": 0.6,
"seg_class": "ihc",
},
"IHC-lowres": {
"center_distance_threshold": 0.5,
"boundary_distance_threshold": 0.6,
"fg_threshold": 0.5,
"distance_smoothing": 0.6,
"seg_class": "ihc",
},
}
if model_type not in all_default_kwargs:
raise ValueError(f"Invalid model: {model_type}. Choose one of {list(all_default_kwargs.keys())}.")
default_kwargs = all_default_kwargs[model_type]
return default_kwargs


def get_default_tiling() -> Dict[str, Dict[str, int]]:
"""Determine the tile shape and halo depending on the available VRAM.

Expand Down
27 changes: 18 additions & 9 deletions flamingo_tools/plugin/segmentation_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,27 @@

from .base_widget import BaseWidget
from .util import _load_custom_model, _available_devices, _get_current_tiling
from ..model_utils import get_model, get_model_registry, get_device, get_default_tiling
from ..model_utils import (
get_model, get_model_registry, get_device, get_default_tiling, get_default_segmentation_settings
)


# TODO Expose segmentation kwargs.
def _run_segmentation(image, model, model_type, tiling, device):
def _run_segmentation(image, model, model_type, tiling, device, min_size):
block_shape = [tiling["tile"][ax] for ax in "zyx"]
halo = [tiling["halo"][ax] for ax in "zyx"]
prediction = predict_with_halo(
image, model, gpu_ids=[device], block_shape=block_shape, halo=halo,
tqdm_desc="Run prediction"
)
settings = get_default_segmentation_settings(model_type)
foreground_threshold = settings.pop("fg_threshold", 0.5)
settings.pop("seg_class", None)
settings = {name: 1.0 if val is None else val for name, val in settings.items()}
foreground_map, center_distances, boundary_distances = prediction
segmentation = watershed_from_center_and_boundary_distances(
center_distances, boundary_distances, foreground_map,
center_distance_threshold=0.5,
boundary_distance_threshold=0.5,
foreground_threshold=0.5,
distance_smoothing=1.6,
min_size=100,
min_size=min_size, foreground_threshold=foreground_threshold,
**settings,
)
return segmentation

Expand Down Expand Up @@ -110,7 +112,9 @@ def on_predict(self):

# Get the current tiling.
self.tiling = _get_current_tiling(self.tiling, self.default_tiling, model_type)
segmentation = _run_segmentation(image, model=model, model_type=model_type, tiling=self.tiling, device=device)
segmentation = _run_segmentation(
image, model=model, model_type=model_type, tiling=self.tiling, device=device, min_size=self.min_size
)

self.viewer.add_labels(segmentation, name=model_type)
show_info(f"INFO: Segmentation of {model_type} added to layers.")
Expand All @@ -120,6 +124,11 @@ def _create_settings_widget(self):
# setting_values.setToolTip(get_tooltip("embedding", "settings"))
setting_values.setLayout(QVBoxLayout())

# Create UI for the min-size parameter.
self.min_size = 100
self.min_size_menu, layout = self._add_int_param("min_size", self.min_size, 0, 10000)
setting_values.layout().addLayout(layout)

# Create UI for the device.
device = "auto"
device_options = ["auto"] + _available_devices()
Expand Down
29 changes: 2 additions & 27 deletions flamingo_tools/segmentation/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from .unet_prediction import run_unet_prediction
from .synapse_detection import marker_detection
from ..model_utils import get_model_path
from ..model_utils import get_model_path, get_default_segmentation_settings


def _get_model_path(model_type, checkpoint_path=None):
Expand Down Expand Up @@ -49,32 +49,7 @@ def _convert_argval(value):


def _parse_segmentation_kwargs(extra_kwargs, model_type):
if model_type == "SGN":
default_kwargs = {
"center_distance_threshold": 0.4,
"boundary_distance_threshold": 0.5,
"fg_threshold": 0.5,
"distance_smoothing": 0.0,
"seg_class": "sgn",
}
elif model_type == "SGN-lowres":
default_kwargs = {
"center_distance_threshold": None,
"boundary_distance_threshold": 0.5,
"fg_threshold": 0.5,
"distance_smoothing": 0.0,
"seg_class": "sgn_low",
}
else:
assert model_type.startswith("IHC")
default_kwargs = {
"center_distance_threshold": 0.5,
"boundary_distance_threshold": 0.6,
"fg_threshold": 0.5,
"distance_smoothing": 0.6,
"seg_class": "ihc",
}

default_kwargs = get_default_segmentation_settings(model_type)
kwargs = _parse_kwargs(extra_kwargs, **default_kwargs)
return kwargs

Expand Down